From c21c3b0befeb46a51b6bf3758ffa30813bea0ff0 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 9 Mar 2024 14:19:22 +0100 Subject: Adding upstream version 1.44.3. Signed-off-by: Daniel Baumann --- ml/Config.cc | 2 +- ml/dlib/.gitignore | 9 + ml/dlib/.hgignore | 43 + ml/dlib/.hgtags | 41 + ml/dlib/.travis.yml | 107 + ml/dlib/CMakeLists.txt | 36 + ml/dlib/ISSUE_TEMPLATE.md | 31 + ml/dlib/MANIFEST.in | 15 + ml/dlib/README.md | 71 + ml/dlib/dlib/CMakeLists.txt | 841 + ml/dlib/dlib/LICENSE.txt | 23 + ml/dlib/dlib/algs.h | 1157 + ml/dlib/dlib/all/source.cpp | 98 + ml/dlib/dlib/any.h | 13 + ml/dlib/dlib/any/any.h | 183 + ml/dlib/dlib/any/any_abstract.h | 210 + ml/dlib/dlib/any/any_decision_function.h | 209 + ml/dlib/dlib/any/any_decision_function_abstract.h | 224 + ml/dlib/dlib/any/any_function.h | 885 + ml/dlib/dlib/any/any_function_abstract.h | 292 + ml/dlib/dlib/any/any_function_impl.h | 516 + ml/dlib/dlib/any/any_function_impl2.h | 44 + ml/dlib/dlib/any/any_trainer.h | 217 + ml/dlib/dlib/any/any_trainer_abstract.h | 234 + ml/dlib/dlib/appveyor/dtest.yml | 19 + ml/dlib/dlib/appveyor/dtest_vc2017.yml | 21 + ml/dlib/dlib/appveyor/examples.yml | 16 + ml/dlib/dlib/appveyor/python.yml | 33 + ml/dlib/dlib/array.h | 10 + ml/dlib/dlib/array/array_kernel.h | 810 + ml/dlib/dlib/array/array_kernel_abstract.h | 360 + ml/dlib/dlib/array/array_tools.h | 38 + ml/dlib/dlib/array/array_tools_abstract.h | 33 + ml/dlib/dlib/array2d.h | 12 + ml/dlib/dlib/array2d/array2d_generic_image.h | 67 + ml/dlib/dlib/array2d/array2d_kernel.h | 498 + ml/dlib/dlib/array2d/array2d_kernel_abstract.h | 301 + ml/dlib/dlib/array2d/serialize_pixel_overloads.h | 371 + ml/dlib/dlib/assert.h | 216 + ml/dlib/dlib/base64.h | 9 + ml/dlib/dlib/base64/base64_kernel_1.cpp | 403 + ml/dlib/dlib/base64/base64_kernel_1.h | 92 + ml/dlib/dlib/base64/base64_kernel_abstract.h | 121 + ml/dlib/dlib/bayes_utils.h | 11 + ml/dlib/dlib/bayes_utils/bayes_utils.h | 1678 ++ ml/dlib/dlib/bayes_utils/bayes_utils_abstract.h | 1042 + ml/dlib/dlib/bigint.h | 43 + ml/dlib/dlib/bigint/bigint_kernel_1.cpp | 1720 ++ ml/dlib/dlib/bigint/bigint_kernel_1.h | 544 + ml/dlib/dlib/bigint/bigint_kernel_2.cpp | 1945 ++ ml/dlib/dlib/bigint/bigint_kernel_2.h | 570 + ml/dlib/dlib/bigint/bigint_kernel_abstract.h | 670 + ml/dlib/dlib/bigint/bigint_kernel_c.h | 1141 + ml/dlib/dlib/binary_search_tree.h | 50 + .../binary_search_tree_kernel_1.h | 2064 ++ .../binary_search_tree_kernel_2.h | 1897 ++ .../binary_search_tree_kernel_abstract.h | 311 + .../binary_search_tree_kernel_c.h | 235 + ml/dlib/dlib/bit_stream.h | 42 + ml/dlib/dlib/bit_stream/bit_stream_kernel_1.cpp | 200 + ml/dlib/dlib/bit_stream/bit_stream_kernel_1.h | 120 + .../dlib/bit_stream/bit_stream_kernel_abstract.h | 185 + ml/dlib/dlib/bit_stream/bit_stream_kernel_c.h | 172 + ml/dlib/dlib/bit_stream/bit_stream_multi_1.h | 103 + .../dlib/bit_stream/bit_stream_multi_abstract.h | 77 + ml/dlib/dlib/bit_stream/bit_stream_multi_c.h | 101 + ml/dlib/dlib/bits/c++config.h | 1 + ml/dlib/dlib/bound_function_pointer.h | 10 + .../bound_function_pointer_kernel_1.h | 774 + .../bound_function_pointer_kernel_abstract.h | 456 + ml/dlib/dlib/bridge.h | 17 + ml/dlib/dlib/bridge/bridge.h | 669 + ml/dlib/dlib/bridge/bridge_abstract.h | 347 + ml/dlib/dlib/bsp.h | 12 + ml/dlib/dlib/bsp/bsp.cpp | 496 + ml/dlib/dlib/bsp/bsp.h | 1043 + ml/dlib/dlib/bsp/bsp_abstract.h | 912 + ml/dlib/dlib/byte_orderer.h | 10 + ml/dlib/dlib/byte_orderer/byte_orderer_kernel_1.h | 176 + .../byte_orderer/byte_orderer_kernel_abstract.h | 149 + ml/dlib/dlib/cassert | 1 + ml/dlib/dlib/clustering.h | 13 + ml/dlib/dlib/clustering/bottom_up_cluster.h | 253 + .../dlib/clustering/bottom_up_cluster_abstract.h | 136 + ml/dlib/dlib/clustering/chinese_whispers.h | 135 + .../dlib/clustering/chinese_whispers_abstract.h | 97 + ml/dlib/dlib/clustering/modularity_clustering.h | 515 + .../clustering/modularity_clustering_abstract.h | 125 + ml/dlib/dlib/clustering/spectral_cluster.h | 80 + .../dlib/clustering/spectral_cluster_abstract.h | 43 + ml/dlib/dlib/cmake | 5 + .../cmake_utils/add_global_compiler_switch.cmake | 35 + .../dlib/cmake_utils/check_if_neon_available.cmake | 20 + ml/dlib/dlib/cmake_utils/dlib.pc.in | 9 + ml/dlib/dlib/cmake_utils/dlibConfig.cmake.in | 50 + ml/dlib/dlib/cmake_utils/find_blas.cmake | 385 + ml/dlib/dlib/cmake_utils/release_build_by_default | 9 + .../set_compiler_specific_options.cmake | 131 + .../tell_visual_studio_to_use_static_runtime.cmake | 19 + .../dlib/cmake_utils/test_for_cpp11/CMakeLists.txt | 17 + .../dlib/cmake_utils/test_for_cpp11/cpp11_test.cpp | 51 + .../dlib/cmake_utils/test_for_cuda/CMakeLists.txt | 14 + .../dlib/cmake_utils/test_for_cuda/cuda_test.cu | 21 + .../dlib/cmake_utils/test_for_cudnn/CMakeLists.txt | 19 + .../dlib/cmake_utils/test_for_cudnn/find_cudnn.txt | 24 + .../dlib/cmake_utils/test_for_neon/CMakeLists.txt | 6 + .../dlib/cmake_utils/test_for_neon/neon_test.cpp | 9 + ml/dlib/dlib/cmake_utils/use_cpp_11.cmake | 113 + ml/dlib/dlib/cmd_line_parser.h | 84 + .../dlib/cmd_line_parser/cmd_line_parser_check_1.h | 580 + .../dlib/cmd_line_parser/cmd_line_parser_check_c.h | 453 + .../cmd_line_parser/cmd_line_parser_kernel_1.h | 799 + .../cmd_line_parser_kernel_abstract.h | 673 + .../cmd_line_parser/cmd_line_parser_kernel_c.h | 203 + .../dlib/cmd_line_parser/cmd_line_parser_print_1.h | 205 + ml/dlib/dlib/cmd_line_parser/get_option.h | 181 + ml/dlib/dlib/cmd_line_parser/get_option_abstract.h | 146 + ml/dlib/dlib/compress_stream.h | 133 + .../compress_stream/compress_stream_kernel_1.h | 252 + .../compress_stream/compress_stream_kernel_2.h | 431 + .../compress_stream/compress_stream_kernel_3.h | 381 + .../compress_stream_kernel_abstract.h | 94 + ml/dlib/dlib/conditioning_class.h | 80 + .../conditioning_class_kernel_1.h | 333 + .../conditioning_class_kernel_2.h | 500 + .../conditioning_class_kernel_3.h | 438 + .../conditioning_class_kernel_4.h | 533 + .../conditioning_class_kernel_abstract.h | 228 + .../conditioning_class_kernel_c.h | 162 + ml/dlib/dlib/config.h | 31 + ml/dlib/dlib/config.h.in | 34 + ml/dlib/dlib/config_reader.h | 39 + .../dlib/config_reader/config_reader_kernel_1.h | 738 + .../config_reader/config_reader_kernel_abstract.h | 363 + .../config_reader/config_reader_thread_safe_1.h | 456 + .../config_reader_thread_safe_abstract.h | 45 + ml/dlib/dlib/console_progress_indicator.h | 207 + ml/dlib/dlib/control.h | 11 + ml/dlib/dlib/control/approximate_linear_models.h | 128 + .../control/approximate_linear_models_abstract.h | 213 + ml/dlib/dlib/control/lspi.h | 188 + ml/dlib/dlib/control/lspi_abstract.h | 193 + ml/dlib/dlib/control/mpc.h | 370 + ml/dlib/dlib/control/mpc_abstract.h | 276 + ml/dlib/dlib/cpp_pretty_printer.h | 39 + .../cpp_pretty_printer_kernel_1.h | 583 + .../cpp_pretty_printer_kernel_2.h | 520 + .../cpp_pretty_printer_kernel_abstract.h | 88 + ml/dlib/dlib/cpp_tokenizer.h | 40 + .../dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h | 675 + .../cpp_tokenizer/cpp_tokenizer_kernel_abstract.h | 224 + .../dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h | 137 + ml/dlib/dlib/crc32.h | 10 + ml/dlib/dlib/crc32/crc32_kernel_1.h | 262 + ml/dlib/dlib/crc32/crc32_kernel_abstract.h | 132 + ml/dlib/dlib/cstring | 1 + ml/dlib/dlib/data_io.h | 18 + ml/dlib/dlib/data_io/image_dataset_metadata.cpp | 411 + ml/dlib/dlib/data_io/image_dataset_metadata.h | 174 + ml/dlib/dlib/data_io/libsvm_io.h | 276 + ml/dlib/dlib/data_io/libsvm_io_abstract.h | 125 + ml/dlib/dlib/data_io/load_image_dataset.h | 510 + ml/dlib/dlib/data_io/load_image_dataset_abstract.h | 358 + ml/dlib/dlib/data_io/mnist.cpp | 133 + ml/dlib/dlib/data_io/mnist.h | 32 + ml/dlib/dlib/data_io/mnist_abstract.h | 46 + ml/dlib/dlib/dir_nav.h | 21 + ml/dlib/dlib/dir_nav/dir_nav_extensions.cpp | 121 + ml/dlib/dlib/dir_nav/dir_nav_extensions.h | 172 + ml/dlib/dlib/dir_nav/dir_nav_extensions_abstract.h | 203 + ml/dlib/dlib/dir_nav/dir_nav_kernel_1.cpp | 258 + ml/dlib/dlib/dir_nav/dir_nav_kernel_1.h | 634 + ml/dlib/dlib/dir_nav/dir_nav_kernel_2.cpp | 254 + ml/dlib/dlib/dir_nav/dir_nav_kernel_2.h | 659 + ml/dlib/dlib/dir_nav/dir_nav_kernel_abstract.h | 515 + ml/dlib/dlib/dir_nav/posix.h | 6 + ml/dlib/dlib/dir_nav/windows.h | 6 + ml/dlib/dlib/directed_graph.h | 37 + .../dlib/directed_graph/directed_graph_kernel_1.h | 704 + .../directed_graph_kernel_abstract.h | 383 + ml/dlib/dlib/disjoint_subsets.h | 12 + ml/dlib/dlib/disjoint_subsets/disjoint_subsets.h | 141 + .../disjoint_subsets/disjoint_subsets_abstract.h | 96 + .../dlib/disjoint_subsets/disjoint_subsets_sized.h | 130 + .../disjoint_subsets_sized_abstract.h | 123 + ml/dlib/dlib/dlib_basic_cpp_build_tutorial.txt | 13 + ml/dlib/dlib/dlib_include_path_tutorial.txt | 20 + ml/dlib/dlib/dnn.h | 37 + ml/dlib/dlib/dnn/core.h | 3599 +++ ml/dlib/dlib/dnn/core_abstract.h | 1700 ++ ml/dlib/dlib/dnn/cpu_dlib.cpp | 2170 ++ ml/dlib/dlib/dnn/cpu_dlib.h | 505 + ml/dlib/dlib/dnn/cublas_dlibapi.cpp | 165 + ml/dlib/dlib/dnn/cublas_dlibapi.h | 50 + ml/dlib/dlib/dnn/cuda_data_ptr.cpp | 71 + ml/dlib/dlib/dnn/cuda_data_ptr.h | 184 + ml/dlib/dlib/dnn/cuda_dlib.cu | 1630 ++ ml/dlib/dlib/dnn/cuda_dlib.h | 469 + ml/dlib/dlib/dnn/cuda_errors.h | 70 + ml/dlib/dlib/dnn/cuda_utils.h | 413 + ml/dlib/dlib/dnn/cudnn_dlibapi.cpp | 1604 ++ ml/dlib/dlib/dnn/cudnn_dlibapi.h | 518 + ml/dlib/dlib/dnn/curand_dlibapi.cpp | 113 + ml/dlib/dlib/dnn/curand_dlibapi.h | 75 + ml/dlib/dlib/dnn/cusolver_dlibapi.cu | 204 + ml/dlib/dlib/dnn/cusolver_dlibapi.h | 75 + ml/dlib/dlib/dnn/gpu_data.cpp | 228 + ml/dlib/dlib/dnn/gpu_data.h | 266 + ml/dlib/dlib/dnn/gpu_data_abstract.h | 266 + ml/dlib/dlib/dnn/input.h | 808 + ml/dlib/dlib/dnn/input_abstract.h | 467 + ml/dlib/dlib/dnn/layers.h | 3244 +++ ml/dlib/dlib/dnn/layers_abstract.h | 2631 +++ ml/dlib/dlib/dnn/loss.h | 2870 +++ ml/dlib/dlib/dnn/loss_abstract.h | 1542 ++ ml/dlib/dlib/dnn/solvers.h | 405 + ml/dlib/dlib/dnn/solvers_abstract.h | 204 + ml/dlib/dlib/dnn/tensor.h | 686 + ml/dlib/dlib/dnn/tensor_abstract.h | 727 + ml/dlib/dlib/dnn/tensor_tools.cpp | 985 + ml/dlib/dlib/dnn/tensor_tools.h | 1711 ++ ml/dlib/dlib/dnn/trainer.h | 1333 ++ ml/dlib/dlib/dnn/trainer_abstract.h | 765 + ml/dlib/dlib/dnn/utilities.h | 281 + ml/dlib/dlib/dnn/utilities_abstract.h | 127 + ml/dlib/dlib/dnn/validation.h | 122 + ml/dlib/dlib/enable_if.h | 62 + ml/dlib/dlib/entropy_decoder.h | 44 + .../entropy_decoder/entropy_decoder_kernel_1.cpp | 220 + .../entropy_decoder/entropy_decoder_kernel_1.h | 132 + .../entropy_decoder/entropy_decoder_kernel_2.cpp | 224 + .../entropy_decoder/entropy_decoder_kernel_2.h | 127 + .../entropy_decoder_kernel_abstract.h | 207 + .../entropy_decoder/entropy_decoder_kernel_c.h | 123 + ml/dlib/dlib/entropy_decoder_model.h | 108 + .../entropy_decoder_model_kernel_1.h | 173 + .../entropy_decoder_model_kernel_2.h | 245 + .../entropy_decoder_model_kernel_3.h | 335 + .../entropy_decoder_model_kernel_4.h | 622 + .../entropy_decoder_model_kernel_5.h | 793 + .../entropy_decoder_model_kernel_6.h | 131 + .../entropy_decoder_model_kernel_abstract.h | 116 + ml/dlib/dlib/entropy_encoder.h | 43 + .../entropy_encoder/entropy_encoder_kernel_1.cpp | 239 + .../entropy_encoder/entropy_encoder_kernel_1.h | 119 + .../entropy_encoder/entropy_encoder_kernel_2.cpp | 233 + .../entropy_encoder/entropy_encoder_kernel_2.h | 112 + .../entropy_encoder_kernel_abstract.h | 161 + .../entropy_encoder/entropy_encoder_kernel_c.h | 112 + ml/dlib/dlib/entropy_encoder_model.h | 146 + .../entropy_encoder_model_kernel_1.h | 167 + .../entropy_encoder_model_kernel_2.h | 246 + .../entropy_encoder_model_kernel_3.h | 341 + .../entropy_encoder_model_kernel_4.h | 553 + .../entropy_encoder_model_kernel_5.h | 817 + .../entropy_encoder_model_kernel_6.h | 127 + .../entropy_encoder_model_kernel_abstract.h | 118 + .../entropy_encoder_model_kernel_c.h | 65 + ml/dlib/dlib/error.h | 449 + ml/dlib/dlib/external/cblas/CMakeLists.txt | 182 + ml/dlib/dlib/external/cblas/README | 7 + ml/dlib/dlib/external/cblas/cblas.h | 575 + ml/dlib/dlib/external/cblas/cblas_caxpy.c | 22 + ml/dlib/dlib/external/cblas/cblas_ccopy.c | 22 + ml/dlib/dlib/external/cblas/cblas_cdotc_sub.c | 23 + ml/dlib/dlib/external/cblas/cblas_cdotu_sub.c | 23 + ml/dlib/dlib/external/cblas/cblas_cgbmv.c | 154 + ml/dlib/dlib/external/cblas/cblas_cgemm.c | 94 + ml/dlib/dlib/external/cblas/cblas_cgemv.c | 151 + ml/dlib/dlib/external/cblas/cblas_cgerc.c | 77 + ml/dlib/dlib/external/cblas/cblas_cgeru.c | 38 + ml/dlib/dlib/external/cblas/cblas_chbmv.c | 145 + ml/dlib/dlib/external/cblas/cblas_chemm.c | 91 + ml/dlib/dlib/external/cblas/cblas_chemv.c | 146 + ml/dlib/dlib/external/cblas/cblas_cher.c | 103 + ml/dlib/dlib/external/cblas/cblas_cher2.c | 139 + ml/dlib/dlib/external/cblas/cblas_cher2k.c | 96 + ml/dlib/dlib/external/cblas/cblas_cherk.c | 90 + ml/dlib/dlib/external/cblas/cblas_chpmv.c | 146 + ml/dlib/dlib/external/cblas/cblas_chpr.c | 102 + ml/dlib/dlib/external/cblas/cblas_chpr2.c | 136 + ml/dlib/dlib/external/cblas/cblas_cscal.c | 21 + ml/dlib/dlib/external/cblas/cblas_csscal.c | 21 + ml/dlib/dlib/external/cblas/cblas_cswap.c | 22 + ml/dlib/dlib/external/cblas/cblas_csymm.c | 91 + ml/dlib/dlib/external/cblas/cblas_csyr2k.c | 93 + ml/dlib/dlib/external/cblas/cblas_csyrk.c | 93 + ml/dlib/dlib/external/cblas/cblas_ctbmv.c | 139 + ml/dlib/dlib/external/cblas/cblas_ctbsv.c | 143 + ml/dlib/dlib/external/cblas/cblas_ctpmv.c | 133 + ml/dlib/dlib/external/cblas/cblas_ctpsv.c | 138 + ml/dlib/dlib/external/cblas/cblas_ctrmm.c | 123 + ml/dlib/dlib/external/cblas/cblas_ctrmv.c | 136 + ml/dlib/dlib/external/cblas/cblas_ctrsm.c | 132 + ml/dlib/dlib/external/cblas/cblas_ctrsv.c | 137 + ml/dlib/dlib/external/cblas/cblas_dasum.c | 23 + ml/dlib/dlib/external/cblas/cblas_daxpy.c | 22 + ml/dlib/dlib/external/cblas/cblas_dcopy.c | 22 + ml/dlib/dlib/external/cblas/cblas_ddot.c | 25 + ml/dlib/dlib/external/cblas/cblas_dgbmv.c | 70 + ml/dlib/dlib/external/cblas/cblas_dgemm.c | 94 + ml/dlib/dlib/external/cblas/cblas_dgemv.c | 67 + ml/dlib/dlib/external/cblas/cblas_dger.c | 40 + ml/dlib/dlib/external/cblas/cblas_dnrm2.c | 23 + ml/dlib/dlib/external/cblas/cblas_drot.c | 23 + ml/dlib/dlib/external/cblas/cblas_drotg.c | 14 + ml/dlib/dlib/external/cblas/cblas_drotm.c | 14 + ml/dlib/dlib/external/cblas/cblas_drotmg.c | 15 + ml/dlib/dlib/external/cblas/cblas_dsbmv.c | 66 + ml/dlib/dlib/external/cblas/cblas_dscal.c | 21 + ml/dlib/dlib/external/cblas/cblas_dsdot.c | 25 + ml/dlib/dlib/external/cblas/cblas_dspmv.c | 65 + ml/dlib/dlib/external/cblas/cblas_dspr.c | 59 + ml/dlib/dlib/external/cblas/cblas_dspr2.c | 59 + ml/dlib/dlib/external/cblas/cblas_dswap.c | 22 + ml/dlib/dlib/external/cblas/cblas_dsymm.c | 91 + ml/dlib/dlib/external/cblas/cblas_dsymv.c | 65 + ml/dlib/dlib/external/cblas/cblas_dsyr.c | 60 + ml/dlib/dlib/external/cblas/cblas_dsyr2.c | 65 + ml/dlib/dlib/external/cblas/cblas_dsyr2k.c | 94 + ml/dlib/dlib/external/cblas/cblas_dsyrk.c | 93 + ml/dlib/dlib/external/cblas/cblas_dtbmv.c | 103 + ml/dlib/dlib/external/cblas/cblas_dtbsv.c | 103 + ml/dlib/dlib/external/cblas/cblas_dtpmv.c | 98 + ml/dlib/dlib/external/cblas/cblas_dtpsv.c | 99 + ml/dlib/dlib/external/cblas/cblas_dtrmm.c | 125 + ml/dlib/dlib/external/cblas/cblas_dtrmv.c | 103 + ml/dlib/dlib/external/cblas/cblas_dtrsm.c | 130 + ml/dlib/dlib/external/cblas/cblas_dtrsv.c | 102 + ml/dlib/dlib/external/cblas/cblas_dzasum.c | 23 + ml/dlib/dlib/external/cblas/cblas_dznrm2.c | 23 + ml/dlib/dlib/external/cblas/cblas_f77.h | 701 + ml/dlib/dlib/external/cblas/cblas_icamax.c | 23 + ml/dlib/dlib/external/cblas/cblas_idamax.c | 23 + ml/dlib/dlib/external/cblas/cblas_isamax.c | 23 + ml/dlib/dlib/external/cblas/cblas_izamax.c | 23 + ml/dlib/dlib/external/cblas/cblas_sasum.c | 23 + ml/dlib/dlib/external/cblas/cblas_saxpy.c | 23 + ml/dlib/dlib/external/cblas/cblas_scasum.c | 23 + ml/dlib/dlib/external/cblas/cblas_scnrm2.c | 23 + ml/dlib/dlib/external/cblas/cblas_scopy.c | 22 + ml/dlib/dlib/external/cblas/cblas_sdot.c | 25 + ml/dlib/dlib/external/cblas/cblas_sdsdot.c | 25 + ml/dlib/dlib/external/cblas/cblas_sgbmv.c | 72 + ml/dlib/dlib/external/cblas/cblas_sgemm.c | 95 + ml/dlib/dlib/external/cblas/cblas_sgemv.c | 67 + ml/dlib/dlib/external/cblas/cblas_sger.c | 39 + ml/dlib/dlib/external/cblas/cblas_snrm2.c | 23 + ml/dlib/dlib/external/cblas/cblas_srot.c | 22 + ml/dlib/dlib/external/cblas/cblas_srotg.c | 14 + ml/dlib/dlib/external/cblas/cblas_srotm.c | 22 + ml/dlib/dlib/external/cblas/cblas_srotmg.c | 15 + ml/dlib/dlib/external/cblas/cblas_ssbmv.c | 65 + ml/dlib/dlib/external/cblas/cblas_sscal.c | 21 + ml/dlib/dlib/external/cblas/cblas_sspmv.c | 62 + ml/dlib/dlib/external/cblas/cblas_sspr.c | 61 + ml/dlib/dlib/external/cblas/cblas_sspr2.c | 60 + ml/dlib/dlib/external/cblas/cblas_sswap.c | 22 + ml/dlib/dlib/external/cblas/cblas_ssymm.c | 93 + ml/dlib/dlib/external/cblas/cblas_ssymv.c | 65 + ml/dlib/dlib/external/cblas/cblas_ssyr.c | 59 + ml/dlib/dlib/external/cblas/cblas_ssyr2.c | 65 + ml/dlib/dlib/external/cblas/cblas_ssyr2k.c | 96 + ml/dlib/dlib/external/cblas/cblas_ssyrk.c | 95 + ml/dlib/dlib/external/cblas/cblas_stbmv.c | 103 + ml/dlib/dlib/external/cblas/cblas_stbsv.c | 103 + ml/dlib/dlib/external/cblas/cblas_stpmv.c | 99 + ml/dlib/dlib/external/cblas/cblas_stpsv.c | 99 + ml/dlib/dlib/external/cblas/cblas_strmm.c | 125 + ml/dlib/dlib/external/cblas/cblas_strmv.c | 103 + ml/dlib/dlib/external/cblas/cblas_strsm.c | 120 + ml/dlib/dlib/external/cblas/cblas_strsv.c | 102 + ml/dlib/dlib/external/cblas/cblas_xerbla.c | 66 + ml/dlib/dlib/external/cblas/cblas_zaxpy.c | 22 + ml/dlib/dlib/external/cblas/cblas_zcopy.c | 22 + ml/dlib/dlib/external/cblas/cblas_zdotc_sub.c | 24 + ml/dlib/dlib/external/cblas/cblas_zdotu_sub.c | 24 + ml/dlib/dlib/external/cblas/cblas_zdscal.c | 21 + ml/dlib/dlib/external/cblas/cblas_zgbmv.c | 155 + ml/dlib/dlib/external/cblas/cblas_zgemm.c | 94 + ml/dlib/dlib/external/cblas/cblas_zgemv.c | 153 + ml/dlib/dlib/external/cblas/cblas_zgerc.c | 77 + ml/dlib/dlib/external/cblas/cblas_zgeru.c | 37 + ml/dlib/dlib/external/cblas/cblas_zhbmv.c | 145 + ml/dlib/dlib/external/cblas/cblas_zhemm.c | 91 + ml/dlib/dlib/external/cblas/cblas_zhemv.c | 146 + ml/dlib/dlib/external/cblas/cblas_zher.c | 99 + ml/dlib/dlib/external/cblas/cblas_zher2.c | 140 + ml/dlib/dlib/external/cblas/cblas_zher2k.c | 95 + ml/dlib/dlib/external/cblas/cblas_zherk.c | 90 + ml/dlib/dlib/external/cblas/cblas_zhpmv.c | 146 + ml/dlib/dlib/external/cblas/cblas_zhpr.c | 102 + ml/dlib/dlib/external/cblas/cblas_zhpr2.c | 137 + ml/dlib/dlib/external/cblas/cblas_zscal.c | 21 + ml/dlib/dlib/external/cblas/cblas_zswap.c | 22 + ml/dlib/dlib/external/cblas/cblas_zsymm.c | 91 + ml/dlib/dlib/external/cblas/cblas_zsyr2k.c | 93 + ml/dlib/dlib/external/cblas/cblas_zsyrk.c | 92 + ml/dlib/dlib/external/cblas/cblas_ztbmv.c | 139 + ml/dlib/dlib/external/cblas/cblas_ztbsv.c | 143 + ml/dlib/dlib/external/cblas/cblas_ztpmv.c | 133 + ml/dlib/dlib/external/cblas/cblas_ztpsv.c | 138 + ml/dlib/dlib/external/cblas/cblas_ztrmm.c | 126 + ml/dlib/dlib/external/cblas/cblas_ztrmv.c | 137 + ml/dlib/dlib/external/cblas/cblas_ztrsm.c | 132 + ml/dlib/dlib/external/cblas/cblas_ztrsv.c | 137 + ml/dlib/dlib/external/cblas/cdotcsub.f | 15 + ml/dlib/dlib/external/cblas/cdotusub.f | 15 + ml/dlib/dlib/external/cblas/dasumsub.f | 15 + ml/dlib/dlib/external/cblas/ddotsub.f | 15 + ml/dlib/dlib/external/cblas/dnrm2sub.f | 15 + ml/dlib/dlib/external/cblas/dsdotsub.f | 15 + ml/dlib/dlib/external/cblas/dzasumsub.f | 15 + ml/dlib/dlib/external/cblas/dznrm2sub.f | 15 + ml/dlib/dlib/external/cblas/icamaxsub.f | 15 + ml/dlib/dlib/external/cblas/idamaxsub.f | 15 + ml/dlib/dlib/external/cblas/isamaxsub.f | 15 + ml/dlib/dlib/external/cblas/izamaxsub.f | 15 + ml/dlib/dlib/external/cblas/sasumsub.f | 15 + ml/dlib/dlib/external/cblas/scasumsub.f | 15 + ml/dlib/dlib/external/cblas/scnrm2sub.f | 15 + ml/dlib/dlib/external/cblas/sdotsub.f | 15 + ml/dlib/dlib/external/cblas/sdsdotsub.f | 15 + ml/dlib/dlib/external/cblas/snrm2sub.f | 15 + ml/dlib/dlib/external/cblas/zdotcsub.f | 15 + ml/dlib/dlib/external/cblas/zdotusub.f | 15 + ml/dlib/dlib/external/libjpeg/README | 385 + ml/dlib/dlib/external/libjpeg/jcapimin.cpp | 280 + ml/dlib/dlib/external/libjpeg/jcapistd.cpp | 161 + ml/dlib/dlib/external/libjpeg/jccoefct.cpp | 449 + ml/dlib/dlib/external/libjpeg/jccolor.cpp | 459 + ml/dlib/dlib/external/libjpeg/jcdctmgr.cpp | 387 + ml/dlib/dlib/external/libjpeg/jchuff.cpp | 909 + ml/dlib/dlib/external/libjpeg/jchuff.h | 47 + ml/dlib/dlib/external/libjpeg/jcinit.cpp | 72 + ml/dlib/dlib/external/libjpeg/jcmainct.cpp | 293 + ml/dlib/dlib/external/libjpeg/jcmarker.cpp | 664 + ml/dlib/dlib/external/libjpeg/jcmaster.cpp | 590 + ml/dlib/dlib/external/libjpeg/jcomapi.cpp | 106 + ml/dlib/dlib/external/libjpeg/jconfig.h | 45 + ml/dlib/dlib/external/libjpeg/jcparam.cpp | 610 + ml/dlib/dlib/external/libjpeg/jcphuff.cpp | 833 + ml/dlib/dlib/external/libjpeg/jcprepct.cpp | 354 + ml/dlib/dlib/external/libjpeg/jcsample.cpp | 519 + ml/dlib/dlib/external/libjpeg/jdapimin.cpp | 395 + ml/dlib/dlib/external/libjpeg/jdapistd.cpp | 275 + ml/dlib/dlib/external/libjpeg/jdatadst.cpp | 151 + ml/dlib/dlib/external/libjpeg/jdatasrc.cpp | 212 + ml/dlib/dlib/external/libjpeg/jdcoefct.cpp | 736 + ml/dlib/dlib/external/libjpeg/jdcolor.cpp | 396 + ml/dlib/dlib/external/libjpeg/jdct.h | 176 + ml/dlib/dlib/external/libjpeg/jddctmgr.cpp | 269 + ml/dlib/dlib/external/libjpeg/jdhuff.cpp | 654 + ml/dlib/dlib/external/libjpeg/jdhuff.h | 201 + ml/dlib/dlib/external/libjpeg/jdinput.cpp | 381 + ml/dlib/dlib/external/libjpeg/jdmainct.cpp | 512 + ml/dlib/dlib/external/libjpeg/jdmarker.cpp | 1360 ++ ml/dlib/dlib/external/libjpeg/jdmaster.cpp | 557 + ml/dlib/dlib/external/libjpeg/jdmerge.cpp | 400 + ml/dlib/dlib/external/libjpeg/jdphuff.cpp | 671 + ml/dlib/dlib/external/libjpeg/jdpostct.cpp | 290 + ml/dlib/dlib/external/libjpeg/jdsample.cpp | 478 + ml/dlib/dlib/external/libjpeg/jerror.cpp | 252 + ml/dlib/dlib/external/libjpeg/jerror.h | 291 + ml/dlib/dlib/external/libjpeg/jfdctflt.cpp | 168 + ml/dlib/dlib/external/libjpeg/jfdctfst.cpp | 224 + ml/dlib/dlib/external/libjpeg/jfdctint.cpp | 283 + ml/dlib/dlib/external/libjpeg/jidctflt.cpp | 242 + ml/dlib/dlib/external/libjpeg/jidctfst.cpp | 368 + ml/dlib/dlib/external/libjpeg/jidctint.cpp | 389 + ml/dlib/dlib/external/libjpeg/jidctred.cpp | 398 + ml/dlib/dlib/external/libjpeg/jinclude.h | 91 + ml/dlib/dlib/external/libjpeg/jmemmgr.cpp | 1118 + ml/dlib/dlib/external/libjpeg/jmemnobs.cpp | 109 + ml/dlib/dlib/external/libjpeg/jmemsys.h | 198 + ml/dlib/dlib/external/libjpeg/jmorecfg.h | 356 + ml/dlib/dlib/external/libjpeg/jpegint.h | 392 + ml/dlib/dlib/external/libjpeg/jpeglib.h | 1096 + ml/dlib/dlib/external/libjpeg/jquant1.cpp | 856 + ml/dlib/dlib/external/libjpeg/jquant2.cpp | 1310 ++ ml/dlib/dlib/external/libjpeg/jutils.cpp | 179 + ml/dlib/dlib/external/libjpeg/jversion.h | 14 + ml/dlib/dlib/external/libpng/LICENSE | 111 + ml/dlib/dlib/external/libpng/README | 202 + ml/dlib/dlib/external/libpng/arm/arm_init.c | 232 + ml/dlib/dlib/external/libpng/arm/filter_neon.S | 245 + .../external/libpng/arm/filter_neon_intrinsics.c | 372 + ml/dlib/dlib/external/libpng/png.c | 4299 ++++ ml/dlib/dlib/external/libpng/png.h | 3319 +++ ml/dlib/dlib/external/libpng/pngconf.h | 626 + ml/dlib/dlib/external/libpng/pngdebug.h | 157 + ml/dlib/dlib/external/libpng/pngerror.c | 932 + ml/dlib/dlib/external/libpng/pngget.c | 1177 + ml/dlib/dlib/external/libpng/pnginfo.h | 260 + ml/dlib/dlib/external/libpng/pnglibconf.h | 211 + ml/dlib/dlib/external/libpng/pngmem.c | 277 + ml/dlib/dlib/external/libpng/pngpread.c | 1291 ++ ml/dlib/dlib/external/libpng/pngpriv.h | 2047 ++ ml/dlib/dlib/external/libpng/pngread.c | 4000 ++++ ml/dlib/dlib/external/libpng/pngrio.c | 118 + ml/dlib/dlib/external/libpng/pngrtran.c | 5110 +++++ ml/dlib/dlib/external/libpng/pngrutil.c | 4475 ++++ ml/dlib/dlib/external/libpng/pngset.c | 1597 ++ ml/dlib/dlib/external/libpng/pngstruct.h | 489 + ml/dlib/dlib/external/libpng/pngtrans.c | 841 + ml/dlib/dlib/external/libpng/pngwio.c | 164 + ml/dlib/dlib/external/libpng/pngwrite.c | 2330 ++ ml/dlib/dlib/external/libpng/pngwtran.c | 637 + ml/dlib/dlib/external/libpng/pngwutil.c | 3023 +++ ml/dlib/dlib/external/pybind11/CMakeLists.txt | 155 + ml/dlib/dlib/external/pybind11/CONTRIBUTING.md | 47 + ml/dlib/dlib/external/pybind11/LICENSE | 29 + ml/dlib/dlib/external/pybind11/README.md | 129 + .../dlib/external/pybind11/include/pybind11/attr.h | 489 + .../pybind11/include/pybind11/buffer_info.h | 108 + .../dlib/external/pybind11/include/pybind11/cast.h | 2063 ++ .../external/pybind11/include/pybind11/chrono.h | 162 + .../external/pybind11/include/pybind11/common.h | 2 + .../external/pybind11/include/pybind11/complex.h | 61 + .../pybind11/include/pybind11/detail/class.h | 626 + .../pybind11/include/pybind11/detail/common.h | 802 + .../pybind11/include/pybind11/detail/descr.h | 185 + .../pybind11/include/pybind11/detail/init.h | 335 + .../pybind11/include/pybind11/detail/internals.h | 249 + .../pybind11/include/pybind11/detail/typeid.h | 53 + .../external/pybind11/include/pybind11/eigen.h | 612 + .../external/pybind11/include/pybind11/embed.h | 194 + .../dlib/external/pybind11/include/pybind11/eval.h | 117 + .../pybind11/include/pybind11/functional.h | 85 + .../external/pybind11/include/pybind11/iostream.h | 200 + .../external/pybind11/include/pybind11/numpy.h | 1600 ++ .../external/pybind11/include/pybind11/operators.h | 168 + .../external/pybind11/include/pybind11/options.h | 65 + .../external/pybind11/include/pybind11/pybind11.h | 1963 ++ .../external/pybind11/include/pybind11/pytypes.h | 1332 ++ .../dlib/external/pybind11/include/pybind11/stl.h | 370 + .../external/pybind11/include/pybind11/stl_bind.h | 599 + .../dlib/external/pybind11/tools/FindCatch.cmake | 57 + .../dlib/external/pybind11/tools/FindEigen3.cmake | 81 + .../pybind11/tools/FindPythonLibsNew.cmake | 195 + .../dlib/external/pybind11/tools/check-style.sh | 70 + ml/dlib/dlib/external/pybind11/tools/libsize.py | 38 + ml/dlib/dlib/external/pybind11/tools/mkdoc.py | 304 + .../pybind11/tools/pybind11Config.cmake.in | 100 + .../external/pybind11/tools/pybind11Tools.cmake | 202 + ml/dlib/dlib/external/zlib/README | 115 + ml/dlib/dlib/external/zlib/adler32.c | 179 + ml/dlib/dlib/external/zlib/compress.c | 80 + ml/dlib/dlib/external/zlib/crc32.c | 425 + ml/dlib/dlib/external/zlib/crc32.h | 441 + ml/dlib/dlib/external/zlib/deflate.c | 1967 ++ ml/dlib/dlib/external/zlib/deflate.h | 346 + ml/dlib/dlib/external/zlib/gzclose.c | 25 + ml/dlib/dlib/external/zlib/gzguts.h | 219 + ml/dlib/dlib/external/zlib/gzlib.c | 634 + ml/dlib/dlib/external/zlib/gzread.c | 594 + ml/dlib/dlib/external/zlib/gzwrite.c | 577 + ml/dlib/dlib/external/zlib/infback.c | 640 + ml/dlib/dlib/external/zlib/inffast.c | 340 + ml/dlib/dlib/external/zlib/inffast.h | 11 + ml/dlib/dlib/external/zlib/inffixed.h | 94 + ml/dlib/dlib/external/zlib/inflate.c | 1512 ++ ml/dlib/dlib/external/zlib/inflate.h | 122 + ml/dlib/dlib/external/zlib/inftrees.c | 306 + ml/dlib/dlib/external/zlib/inftrees.h | 62 + ml/dlib/dlib/external/zlib/trees.c | 1226 ++ ml/dlib/dlib/external/zlib/trees.h | 128 + ml/dlib/dlib/external/zlib/uncompr.c | 59 + ml/dlib/dlib/external/zlib/zconf.h | 511 + ml/dlib/dlib/external/zlib/zlib.h | 1768 ++ ml/dlib/dlib/external/zlib/zutil.c | 324 + ml/dlib/dlib/external/zlib/zutil.h | 253 + ml/dlib/dlib/filtering.h | 12 + ml/dlib/dlib/filtering/kalman_filter.cpp | 104 + ml/dlib/dlib/filtering/kalman_filter.h | 382 + ml/dlib/dlib/filtering/kalman_filter_abstract.h | 492 + ml/dlib/dlib/filtering/rls_filter.h | 198 + ml/dlib/dlib/filtering/rls_filter_abstract.h | 171 + ml/dlib/dlib/float_details.h | 161 + ml/dlib/dlib/fstream | 1 + ml/dlib/dlib/general_hash/count_bits.h | 82 + ml/dlib/dlib/general_hash/count_bits_abstract.h | 48 + ml/dlib/dlib/general_hash/general_hash.h | 80 + ml/dlib/dlib/general_hash/hash.h | 142 + ml/dlib/dlib/general_hash/hash_abstract.h | 182 + ml/dlib/dlib/general_hash/murmur_hash3.h | 519 + ml/dlib/dlib/general_hash/murmur_hash3_abstract.h | 125 + ml/dlib/dlib/general_hash/random_hashing.h | 877 + .../dlib/general_hash/random_hashing_abstract.h | 58 + ml/dlib/dlib/geometry.h | 14 + ml/dlib/dlib/geometry/border_enumerator.h | 186 + ml/dlib/dlib/geometry/border_enumerator_abstract.h | 126 + ml/dlib/dlib/geometry/drectangle.h | 488 + ml/dlib/dlib/geometry/drectangle_abstract.h | 628 + ml/dlib/dlib/geometry/point_transforms.h | 989 + ml/dlib/dlib/geometry/point_transforms_abstract.h | 797 + ml/dlib/dlib/geometry/rectangle.h | 824 + ml/dlib/dlib/geometry/rectangle_abstract.h | 836 + ml/dlib/dlib/geometry/vector.h | 1330 ++ ml/dlib/dlib/geometry/vector_abstract.h | 489 + ml/dlib/dlib/global_optimization.h | 14 + ml/dlib/dlib/global_optimization/find_max_global.h | 511 + .../global_optimization/find_max_global_abstract.h | 496 + .../global_optimization/global_function_search.cpp | 942 + .../global_optimization/global_function_search.h | 245 + .../global_function_search_abstract.h | 605 + .../global_optimization/upper_bound_function.h | 286 + .../upper_bound_function_abstract.h | 212 + ml/dlib/dlib/graph.h | 37 + ml/dlib/dlib/graph/graph_kernel_1.h | 629 + ml/dlib/dlib/graph/graph_kernel_abstract.h | 329 + ml/dlib/dlib/graph_cuts.h | 14 + .../dlib/graph_cuts/find_max_factor_graph_potts.h | 959 + .../find_max_factor_graph_potts_abstract.h | 636 + ml/dlib/dlib/graph_cuts/general_flow_graph.h | 172 + ml/dlib/dlib/graph_cuts/general_potts_problem.h | 99 + ml/dlib/dlib/graph_cuts/graph_labeler.h | 211 + ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h | 185 + ml/dlib/dlib/graph_cuts/min_cut.h | 571 + ml/dlib/dlib/graph_cuts/min_cut_abstract.h | 476 + ml/dlib/dlib/graph_utils.h | 12 + ml/dlib/dlib/graph_utils/edge_list_graphs.h | 593 + .../dlib/graph_utils/edge_list_graphs_abstract.h | 358 + .../graph_utils/find_k_nearest_neighbors_lsh.h | 217 + .../find_k_nearest_neighbors_lsh_abstract.h | 102 + ml/dlib/dlib/graph_utils/function_objects.h | 129 + .../dlib/graph_utils/function_objects_abstract.h | 209 + ml/dlib/dlib/graph_utils/graph_utils.h | 1227 ++ ml/dlib/dlib/graph_utils/graph_utils_abstract.h | 452 + ml/dlib/dlib/graph_utils/ordered_sample_pair.h | 125 + .../graph_utils/ordered_sample_pair_abstract.h | 128 + ml/dlib/dlib/graph_utils/sample_pair.h | 179 + ml/dlib/dlib/graph_utils/sample_pair_abstract.h | 192 + ml/dlib/dlib/graph_utils_threaded.h | 12 + ml/dlib/dlib/gui_core.h | 20 + ml/dlib/dlib/gui_core/gui_core_kernel_1.cpp | 2204 ++ ml/dlib/dlib/gui_core/gui_core_kernel_1.h | 420 + ml/dlib/dlib/gui_core/gui_core_kernel_2.cpp | 1996 ++ ml/dlib/dlib/gui_core/gui_core_kernel_2.h | 419 + ml/dlib/dlib/gui_core/gui_core_kernel_abstract.h | 792 + ml/dlib/dlib/gui_core/windows.h | 6 + ml/dlib/dlib/gui_core/xlib.h | 6 + ml/dlib/dlib/gui_widgets.h | 18 + ml/dlib/dlib/gui_widgets/base_widgets.cpp | 3343 +++ ml/dlib/dlib/gui_widgets/base_widgets.h | 2678 +++ ml/dlib/dlib/gui_widgets/base_widgets_abstract.h | 2290 ++ ml/dlib/dlib/gui_widgets/canvas_drawing.cpp | 101 + ml/dlib/dlib/gui_widgets/canvas_drawing.h | 964 + ml/dlib/dlib/gui_widgets/canvas_drawing_abstract.h | 364 + ml/dlib/dlib/gui_widgets/drawable.cpp | 544 + ml/dlib/dlib/gui_widgets/drawable.h | 527 + ml/dlib/dlib/gui_widgets/drawable_abstract.h | 717 + ml/dlib/dlib/gui_widgets/fonts.cpp | 673 + ml/dlib/dlib/gui_widgets/fonts.h | 628 + ml/dlib/dlib/gui_widgets/fonts_abstract.h | 492 + ml/dlib/dlib/gui_widgets/nativefont.h | 612 + ml/dlib/dlib/gui_widgets/style.cpp | 998 + ml/dlib/dlib/gui_widgets/style.h | 825 + ml/dlib/dlib/gui_widgets/style_abstract.h | 777 + ml/dlib/dlib/gui_widgets/widgets.cpp | 7341 +++++++ ml/dlib/dlib/gui_widgets/widgets.h | 4165 ++++ ml/dlib/dlib/gui_widgets/widgets_abstract.h | 3461 +++ ml/dlib/dlib/hash.h | 14 + ml/dlib/dlib/hash_map.h | 63 + ml/dlib/dlib/hash_map/hash_map_kernel_1.h | 460 + ml/dlib/dlib/hash_map/hash_map_kernel_abstract.h | 247 + ml/dlib/dlib/hash_map/hash_map_kernel_c.h | 276 + ml/dlib/dlib/hash_set.h | 63 + ml/dlib/dlib/hash_set/hash_set_kernel_1.h | 391 + ml/dlib/dlib/hash_set/hash_set_kernel_abstract.h | 207 + ml/dlib/dlib/hash_set/hash_set_kernel_c.h | 190 + ml/dlib/dlib/hash_table.h | 60 + ml/dlib/dlib/hash_table/hash_table_kernel_1.h | 819 + ml/dlib/dlib/hash_table/hash_table_kernel_2.h | 612 + .../dlib/hash_table/hash_table_kernel_abstract.h | 253 + ml/dlib/dlib/hash_table/hash_table_kernel_c.h | 194 + ml/dlib/dlib/http_client/http_client.cpp | 743 + ml/dlib/dlib/http_client/http_client.h | 101 + ml/dlib/dlib/http_client/http_client_abstract.h | 218 + ml/dlib/dlib/image_io.h | 20 + ml/dlib/dlib/image_keypoint.h | 16 + .../image_keypoint/binned_vector_feature_image.h | 433 + .../binned_vector_feature_image_abstract.h | 287 + .../image_keypoint/build_separable_poly_filters.h | 186 + ml/dlib/dlib/image_keypoint/draw_surf_points.h | 40 + .../image_keypoint/draw_surf_points_abstract.h | 30 + ml/dlib/dlib/image_keypoint/fine_hog_image.h | 378 + .../dlib/image_keypoint/fine_hog_image_abstract.h | 276 + ml/dlib/dlib/image_keypoint/hashed_feature_image.h | 518 + .../image_keypoint/hashed_feature_image_abstract.h | 303 + ml/dlib/dlib/image_keypoint/hessian_pyramid.h | 531 + .../dlib/image_keypoint/hessian_pyramid_abstract.h | 244 + ml/dlib/dlib/image_keypoint/hog.h | 514 + ml/dlib/dlib/image_keypoint/hog_abstract.h | 335 + .../nearest_neighbor_feature_image.h | 408 + .../nearest_neighbor_feature_image_abstract.h | 254 + ml/dlib/dlib/image_keypoint/poly_image.h | 649 + ml/dlib/dlib/image_keypoint/poly_image_abstract.h | 335 + ml/dlib/dlib/image_keypoint/surf.h | 295 + ml/dlib/dlib/image_keypoint/surf_abstract.h | 163 + ml/dlib/dlib/image_loader/image_loader.h | 863 + ml/dlib/dlib/image_loader/image_loader_abstract.h | 136 + ml/dlib/dlib/image_loader/jpeg_loader.cpp | 173 + ml/dlib/dlib/image_loader/jpeg_loader.h | 109 + ml/dlib/dlib/image_loader/jpeg_loader_abstract.h | 133 + ml/dlib/dlib/image_loader/load_image.h | 226 + ml/dlib/dlib/image_loader/load_image_abstract.h | 37 + ml/dlib/dlib/image_loader/png_loader.cpp | 222 + ml/dlib/dlib/image_loader/png_loader.h | 223 + ml/dlib/dlib/image_loader/png_loader_abstract.h | 162 + ml/dlib/dlib/image_processing.h | 28 + .../dlib/image_processing/box_overlap_testing.h | 215 + .../box_overlap_testing_abstract.h | 201 + .../dlib/image_processing/correlation_tracker.h | 404 + .../correlation_tracker_abstract.h | 162 + .../image_processing/detection_template_tools.h | 113 + .../detection_template_tools_abstract.h | 95 + .../dlib/image_processing/frontal_face_detector.h | 2373 ++ .../frontal_face_detector_abstract.h | 25 + .../dlib/image_processing/full_object_detection.h | 191 + .../full_object_detection_abstract.h | 203 + ml/dlib/dlib/image_processing/generic_image.h | 431 + ml/dlib/dlib/image_processing/object_detector.h | 626 + .../image_processing/object_detector_abstract.h | 404 + .../remove_unobtainable_rectangles.h | 317 + .../remove_unobtainable_rectangles_abstract.h | 56 + .../dlib/image_processing/render_face_detections.h | 99 + .../render_face_detections_abstract.h | 59 + ml/dlib/dlib/image_processing/scan_fhog_pyramid.h | 1348 ++ .../image_processing/scan_fhog_pyramid_abstract.h | 784 + ml/dlib/dlib/image_processing/scan_image.h | 368 + .../dlib/image_processing/scan_image_abstract.h | 227 + ml/dlib/dlib/image_processing/scan_image_boxes.h | 630 + .../image_processing/scan_image_boxes_abstract.h | 394 + ml/dlib/dlib/image_processing/scan_image_custom.h | 401 + .../image_processing/scan_image_custom_abstract.h | 390 + ml/dlib/dlib/image_processing/scan_image_pyramid.h | 1101 + .../image_processing/scan_image_pyramid_abstract.h | 495 + .../image_processing/scan_image_pyramid_tools.h | 180 + .../scan_image_pyramid_tools_abstract.h | 118 + .../dlib/image_processing/setup_hashed_features.h | 219 + .../setup_hashed_features_abstract.h | 210 + ml/dlib/dlib/image_processing/shape_predictor.h | 524 + .../image_processing/shape_predictor_abstract.h | 195 + .../image_processing/shape_predictor_trainer.h | 852 + .../shape_predictor_trainer_abstract.h | 418 + ml/dlib/dlib/image_saver/dng_shared.h | 288 + ml/dlib/dlib/image_saver/image_saver.h | 688 + ml/dlib/dlib/image_saver/image_saver_abstract.h | 129 + ml/dlib/dlib/image_saver/save_jpeg.cpp | 175 + ml/dlib/dlib/image_saver/save_jpeg.h | 82 + ml/dlib/dlib/image_saver/save_jpeg_abstract.h | 52 + ml/dlib/dlib/image_saver/save_png.cpp | 124 + ml/dlib/dlib/image_saver/save_png.h | 162 + ml/dlib/dlib/image_saver/save_png_abstract.h | 50 + ml/dlib/dlib/image_transforms.h | 31 + ml/dlib/dlib/image_transforms/assign_image.h | 385 + .../dlib/image_transforms/assign_image_abstract.h | 196 + ml/dlib/dlib/image_transforms/colormaps.h | 269 + ml/dlib/dlib/image_transforms/colormaps_abstract.h | 152 + ml/dlib/dlib/image_transforms/draw.h | 396 + ml/dlib/dlib/image_transforms/draw_abstract.h | 150 + ml/dlib/dlib/image_transforms/edge_detector.h | 302 + .../dlib/image_transforms/edge_detector_abstract.h | 112 + ml/dlib/dlib/image_transforms/equalize_histogram.h | 143 + .../image_transforms/equalize_histogram_abstract.h | 91 + ml/dlib/dlib/image_transforms/fhog.h | 1404 ++ ml/dlib/dlib/image_transforms/fhog_abstract.h | 346 + ml/dlib/dlib/image_transforms/hough_transform.h | 358 + .../image_transforms/hough_transform_abstract.h | 145 + ml/dlib/dlib/image_transforms/image_pyramid.h | 1238 ++ .../dlib/image_transforms/image_pyramid_abstract.h | 384 + ml/dlib/dlib/image_transforms/integral_image.h | 190 + .../image_transforms/integral_image_abstract.h | 169 + ml/dlib/dlib/image_transforms/interpolation.h | 2193 ++ .../dlib/image_transforms/interpolation_abstract.h | 1480 ++ .../dlib/image_transforms/label_connected_blobs.h | 188 + .../label_connected_blobs_abstract.h | 199 + ml/dlib/dlib/image_transforms/lbp.h | 307 + ml/dlib/dlib/image_transforms/lbp_abstract.h | 139 + .../image_transforms/morphological_operations.h | 846 + .../morphological_operations_abstract.h | 316 + .../dlib/image_transforms/random_color_transform.h | 157 + .../random_color_transform_abstract.h | 94 + ml/dlib/dlib/image_transforms/random_cropper.h | 361 + .../image_transforms/random_cropper_abstract.h | 346 + ml/dlib/dlib/image_transforms/segment_image.h | 730 + .../dlib/image_transforms/segment_image_abstract.h | 126 + ml/dlib/dlib/image_transforms/spatial_filtering.h | 1580 ++ .../image_transforms/spatial_filtering_abstract.h | 487 + ml/dlib/dlib/image_transforms/thresholding.h | 340 + .../dlib/image_transforms/thresholding_abstract.h | 139 + ml/dlib/dlib/interfaces/cmd_line_parser_option.h | 107 + ml/dlib/dlib/interfaces/enumerable.h | 130 + ml/dlib/dlib/interfaces/map_pair.h | 74 + ml/dlib/dlib/interfaces/remover.h | 220 + ml/dlib/dlib/iomanip | 1 + ml/dlib/dlib/iosfwd | 1 + ml/dlib/dlib/iosockstream.h | 11 + ml/dlib/dlib/iosockstream/iosockstream.h | 171 + ml/dlib/dlib/iosockstream/iosockstream_abstract.h | 171 + ml/dlib/dlib/iostream | 1 + ml/dlib/dlib/is_kind.h | 162 + ml/dlib/dlib/istream | 1 + ml/dlib/dlib/java/CMakeLists.txt | 32 + ml/dlib/dlib/java/cmake_swig_jni | 265 + ml/dlib/dlib/java/java_array.h | 605 + ml/dlib/dlib/java/run_test.sh | 17 + ml/dlib/dlib/java/swig_api.h | 126 + ml/dlib/dlib/java/swig_test.java | 254 + ml/dlib/dlib/linker.h | 9 + ml/dlib/dlib/linker/linker_kernel_1.cpp | 357 + ml/dlib/dlib/linker/linker_kernel_1.h | 141 + ml/dlib/dlib/linker/linker_kernel_abstract.h | 141 + ml/dlib/dlib/locale | 1 + ml/dlib/dlib/logger.h | 11 + ml/dlib/dlib/logger/extra_logger_headers.cpp | 40 + ml/dlib/dlib/logger/extra_logger_headers.h | 41 + ml/dlib/dlib/logger/logger_config_file.cpp | 214 + ml/dlib/dlib/logger/logger_config_file.h | 135 + ml/dlib/dlib/logger/logger_kernel_1.cpp | 498 + ml/dlib/dlib/logger/logger_kernel_1.h | 687 + ml/dlib/dlib/logger/logger_kernel_abstract.h | 429 + ml/dlib/dlib/lsh.h | 14 + ml/dlib/dlib/lsh/create_random_projection_hash.h | 232 + .../lsh/create_random_projection_hash_abstract.h | 148 + ml/dlib/dlib/lsh/hashes.h | 219 + ml/dlib/dlib/lsh/hashes_abstract.h | 286 + ml/dlib/dlib/lsh/projection_hash.h | 118 + ml/dlib/dlib/lsh/projection_hash_abstract.h | 119 + ml/dlib/dlib/lz77_buffer.h | 47 + ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_1.h | 263 + ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_2.h | 504 + .../dlib/lz77_buffer/lz77_buffer_kernel_abstract.h | 210 + ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_c.h | 169 + ml/dlib/dlib/lzp_buffer.h | 46 + ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_1.h | 236 + ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_2.h | 319 + .../dlib/lzp_buffer/lzp_buffer_kernel_abstract.h | 130 + ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_c.h | 101 + ml/dlib/dlib/manifold_regularization.h | 13 + .../linear_manifold_regularizer.h | 328 + .../linear_manifold_regularizer_abstract.h | 137 + ml/dlib/dlib/map.h | 59 + ml/dlib/dlib/map/map_kernel_1.h | 436 + ml/dlib/dlib/map/map_kernel_abstract.h | 235 + ml/dlib/dlib/map/map_kernel_c.h | 248 + ml/dlib/dlib/matlab/CMakeLists.txt | 22 + ml/dlib/dlib/matlab/README.txt | 20 + ml/dlib/dlib/matlab/call_matlab.h | 852 + ml/dlib/dlib/matlab/cmake_mex_wrapper | 103 + ml/dlib/dlib/matlab/example.m | 16 + ml/dlib/dlib/matlab/example_mex_callback.cpp | 52 + ml/dlib/dlib/matlab/example_mex_class.cpp | 72 + ml/dlib/dlib/matlab/example_mex_function.cpp | 84 + ml/dlib/dlib/matlab/example_mex_struct.cpp | 55 + ml/dlib/dlib/matlab/mex_wrapper.cpp | 5144 +++++ ml/dlib/dlib/matlab/subprocess_stream.cpp | 537 + ml/dlib/dlib/matlab/subprocess_stream.h | 223 + ml/dlib/dlib/matrix.h | 24 + ml/dlib/dlib/matrix/cblas_constants.h | 22 + ml/dlib/dlib/matrix/lapack/fortran_id.h | 62 + ml/dlib/dlib/matrix/lapack/gees.h | 264 + ml/dlib/dlib/matrix/lapack/geev.h | 234 + ml/dlib/dlib/matrix/lapack/geqrf.h | 168 + ml/dlib/dlib/matrix/lapack/gesdd.h | 364 + ml/dlib/dlib/matrix/lapack/gesvd.h | 323 + ml/dlib/dlib/matrix/lapack/getrf.h | 132 + ml/dlib/dlib/matrix/lapack/ormqr.h | 224 + ml/dlib/dlib/matrix/lapack/pbtrf.h | 178 + ml/dlib/dlib/matrix/lapack/potrf.h | 174 + ml/dlib/dlib/matrix/lapack/syev.h | 218 + ml/dlib/dlib/matrix/lapack/syevr.h | 445 + ml/dlib/dlib/matrix/matrix.h | 2162 ++ ml/dlib/dlib/matrix/matrix_abstract.h | 857 + ml/dlib/dlib/matrix/matrix_assign.h | 978 + ml/dlib/dlib/matrix/matrix_assign_fwd.h | 413 + ml/dlib/dlib/matrix/matrix_blas_bindings.h | 1637 ++ ml/dlib/dlib/matrix/matrix_cholesky.h | 231 + ml/dlib/dlib/matrix/matrix_conj_trans.h | 71 + ml/dlib/dlib/matrix/matrix_conv.h | 358 + ml/dlib/dlib/matrix/matrix_conv_abstract.h | 158 + ml/dlib/dlib/matrix/matrix_data_layout.h | 1271 ++ ml/dlib/dlib/matrix/matrix_data_layout_abstract.h | 40 + ml/dlib/dlib/matrix/matrix_default_mul.h | 134 + ml/dlib/dlib/matrix/matrix_eigenvalue.h | 1379 ++ ml/dlib/dlib/matrix/matrix_exp.h | 271 + ml/dlib/dlib/matrix/matrix_exp_abstract.h | 210 + ml/dlib/dlib/matrix/matrix_expressions.h | 280 + ml/dlib/dlib/matrix/matrix_fft.h | 846 + ml/dlib/dlib/matrix/matrix_fft_abstract.h | 118 + ml/dlib/dlib/matrix/matrix_fwd.h | 31 + ml/dlib/dlib/matrix/matrix_generic_image.h | 110 + ml/dlib/dlib/matrix/matrix_la.h | 1807 ++ ml/dlib/dlib/matrix/matrix_la_abstract.h | 1005 + ml/dlib/dlib/matrix/matrix_lu.h | 361 + ml/dlib/dlib/matrix/matrix_mat.h | 733 + ml/dlib/dlib/matrix/matrix_mat_abstract.h | 243 + ml/dlib/dlib/matrix/matrix_math_functions.h | 448 + .../dlib/matrix/matrix_math_functions_abstract.h | 595 + ml/dlib/dlib/matrix/matrix_op.h | 479 + ml/dlib/dlib/matrix/matrix_qr.h | 466 + ml/dlib/dlib/matrix/matrix_read_from_istream.h | 108 + ml/dlib/dlib/matrix/matrix_subexp.h | 1566 ++ ml/dlib/dlib/matrix/matrix_subexp_abstract.h | 570 + ml/dlib/dlib/matrix/matrix_trsm.h | 654 + ml/dlib/dlib/matrix/matrix_utilities.h | 4544 ++++ ml/dlib/dlib/matrix/matrix_utilities_abstract.h | 1874 ++ ml/dlib/dlib/matrix/symmetric_matrix_cache.h | 464 + .../dlib/matrix/symmetric_matrix_cache_abstract.h | 63 + ml/dlib/dlib/md5.h | 3 + ml/dlib/dlib/md5/md5_kernel_1.cpp | 617 + ml/dlib/dlib/md5/md5_kernel_1.h | 50 + ml/dlib/dlib/md5/md5_kernel_abstract.h | 83 + ml/dlib/dlib/member_function_pointer.h | 10 + ml/dlib/dlib/member_function_pointer/make_mfp.h | 179 + .../member_function_pointer/make_mfp_abstract.h | 207 + .../member_function_pointer_kernel_1.h | 498 + .../member_function_pointer_kernel_abstract.h | 483 + ml/dlib/dlib/memory_manager.h | 73 + .../dlib/memory_manager/memory_manager_kernel_1.h | 305 + .../dlib/memory_manager/memory_manager_kernel_2.h | 253 + .../dlib/memory_manager/memory_manager_kernel_3.h | 385 + .../memory_manager_kernel_abstract.h | 146 + ml/dlib/dlib/memory_manager_global.h | 38 + .../memory_manager_global_kernel_1.h | 113 + .../memory_manager_global_kernel_abstract.h | 181 + ml/dlib/dlib/memory_manager_stateless.h | 72 + .../memory_manager_stateless_kernel_1.h | 86 + .../memory_manager_stateless_kernel_2.h | 119 + .../memory_manager_stateless_kernel_abstract.h | 142 + ml/dlib/dlib/metaprogramming.h | 71 + ml/dlib/dlib/misc_api.h | 20 + ml/dlib/dlib/misc_api/misc_api_kernel_1.cpp | 149 + ml/dlib/dlib/misc_api/misc_api_kernel_1.h | 110 + ml/dlib/dlib/misc_api/misc_api_kernel_2.cpp | 123 + ml/dlib/dlib/misc_api/misc_api_kernel_2.h | 81 + ml/dlib/dlib/misc_api/misc_api_kernel_abstract.h | 159 + ml/dlib/dlib/misc_api/misc_api_shared.h | 57 + ml/dlib/dlib/misc_api/posix.h | 6 + ml/dlib/dlib/misc_api/windows.h | 6 + ml/dlib/dlib/mlp.h | 30 + ml/dlib/dlib/mlp/mlp_kernel_1.h | 394 + ml/dlib/dlib/mlp/mlp_kernel_abstract.h | 225 + ml/dlib/dlib/mlp/mlp_kernel_c.h | 151 + ml/dlib/dlib/noncopyable.h | 32 + ml/dlib/dlib/numeric_constants.h | 53 + ml/dlib/dlib/numerical_integration.h | 8 + .../integrate_function_adapt_simpson.h | 93 + .../integrate_function_adapt_simpson_abstract.h | 34 + ml/dlib/dlib/opencv.h | 17 + ml/dlib/dlib/opencv/cv_image.h | 225 + ml/dlib/dlib/opencv/cv_image_abstract.h | 280 + ml/dlib/dlib/opencv/to_open_cv.h | 46 + ml/dlib/dlib/opencv/to_open_cv_abstract.h | 34 + ml/dlib/dlib/optimization.h | 24 + ml/dlib/dlib/optimization/elastic_net.h | 389 + ml/dlib/dlib/optimization/elastic_net_abstract.h | 190 + .../optimization/find_max_factor_graph_nmplp.h | 337 + .../find_max_factor_graph_nmplp_abstract.h | 365 + .../optimization/find_max_factor_graph_viterbi.h | 232 + .../find_max_factor_graph_viterbi_abstract.h | 131 + ml/dlib/dlib/optimization/find_max_parse_cky.h | 414 + .../optimization/find_max_parse_cky_abstract.h | 388 + .../dlib/optimization/find_optimal_parameters.h | 117 + .../find_optimal_parameters_abstract.h | 58 + ml/dlib/dlib/optimization/isotonic_regression.h | 169 + .../optimization/isotonic_regression_abstract.h | 128 + ml/dlib/dlib/optimization/max_cost_assignment.h | 288 + .../optimization/max_cost_assignment_abstract.h | 63 + ml/dlib/dlib/optimization/max_sum_submatrix.h | 285 + .../dlib/optimization/max_sum_submatrix_abstract.h | 49 + ml/dlib/dlib/optimization/optimization.h | 714 + ml/dlib/dlib/optimization/optimization_abstract.h | 468 + ml/dlib/dlib/optimization/optimization_bobyqa.h | 3423 +++ .../optimization/optimization_bobyqa_abstract.h | 120 + .../dlib/optimization/optimization_least_squares.h | 345 + .../optimization_least_squares_abstract.h | 112 + .../dlib/optimization/optimization_line_search.h | 888 + .../optimization_line_search_abstract.h | 361 + ml/dlib/dlib/optimization/optimization_oca.h | 407 + .../dlib/optimization/optimization_oca_abstract.h | 334 + .../optimization/optimization_search_strategies.h | 324 + .../optimization_search_strategies_abstract.h | 330 + .../optimization_solve_qp2_using_smo.h | 468 + .../optimization_solve_qp2_using_smo_abstract.h | 150 + .../optimization_solve_qp3_using_smo.h | 455 + .../optimization_solve_qp3_using_smo_abstract.h | 139 + .../optimization/optimization_solve_qp_using_smo.h | 937 + .../optimization_solve_qp_using_smo_abstract.h | 282 + .../optimization/optimization_stop_strategies.h | 173 + .../optimization_stop_strategies_abstract.h | 157 + .../dlib/optimization/optimization_trust_region.h | 564 + .../optimization_trust_region_abstract.h | 233 + ml/dlib/dlib/ostream | 1 + ml/dlib/dlib/pipe.h | 10 + ml/dlib/dlib/pipe/pipe_kernel_1.h | 756 + ml/dlib/dlib/pipe/pipe_kernel_abstract.h | 323 + ml/dlib/dlib/pixel.h | 1649 ++ ml/dlib/dlib/platform.h | 65 + ml/dlib/dlib/python.h | 14 + ml/dlib/dlib/python/numpy.h | 214 + ml/dlib/dlib/python/numpy_image.h | 129 + ml/dlib/dlib/python/pyassert.h | 17 + ml/dlib/dlib/python/pybind_utils.h | 82 + ml/dlib/dlib/python/serialize_pickle.h | 66 + ml/dlib/dlib/quantum_computing.h | 12 + ml/dlib/dlib/quantum_computing/quantum_computing.h | 863 + .../quantum_computing/quantum_computing_abstract.h | 590 + ml/dlib/dlib/queue.h | 84 + ml/dlib/dlib/queue/queue_kernel_1.h | 554 + ml/dlib/dlib/queue/queue_kernel_2.h | 600 + ml/dlib/dlib/queue/queue_kernel_abstract.h | 196 + ml/dlib/dlib/queue/queue_kernel_c.h | 187 + ml/dlib/dlib/queue/queue_sort_1.h | 165 + ml/dlib/dlib/queue/queue_sort_abstract.h | 74 + ml/dlib/dlib/rand.h | 9 + ml/dlib/dlib/rand/mersenne_twister.h | 210 + ml/dlib/dlib/rand/rand_kernel_1.h | 354 + ml/dlib/dlib/rand/rand_kernel_abstract.h | 218 + ml/dlib/dlib/random_forest.h | 10 + .../dlib/random_forest/random_forest_regression.h | 738 + .../random_forest_regression_abstract.h | 460 + ml/dlib/dlib/ref.h | 84 + ml/dlib/dlib/reference_counter.h | 31 + .../reference_counter/reference_counter_kernel_1.h | 298 + .../reference_counter_kernel_abstract.h | 141 + ml/dlib/dlib/revision.h.in | 6 + ml/dlib/dlib/sequence.h | 83 + ml/dlib/dlib/sequence/sequence_compare_1.h | 102 + ml/dlib/dlib/sequence/sequence_compare_abstract.h | 75 + ml/dlib/dlib/sequence/sequence_kernel_1.h | 1340 ++ ml/dlib/dlib/sequence/sequence_kernel_2.h | 682 + ml/dlib/dlib/sequence/sequence_kernel_abstract.h | 199 + ml/dlib/dlib/sequence/sequence_kernel_c.h | 253 + ml/dlib/dlib/sequence/sequence_sort_1.h | 182 + ml/dlib/dlib/sequence/sequence_sort_2.h | 65 + ml/dlib/dlib/sequence/sequence_sort_abstract.h | 65 + ml/dlib/dlib/serialize.h | 1779 ++ ml/dlib/dlib/server.h | 12 + ml/dlib/dlib/server/server_http.cpp | 409 + ml/dlib/dlib/server/server_http.h | 242 + ml/dlib/dlib/server/server_http_abstract.h | 390 + ml/dlib/dlib/server/server_iostream.cpp | 14 + ml/dlib/dlib/server/server_iostream.h | 155 + ml/dlib/dlib/server/server_iostream_abstract.h | 84 + ml/dlib/dlib/server/server_kernel.cpp | 595 + ml/dlib/dlib/server/server_kernel.h | 234 + ml/dlib/dlib/server/server_kernel_abstract.h | 310 + ml/dlib/dlib/set.h | 74 + ml/dlib/dlib/set/set_compare_1.h | 122 + ml/dlib/dlib/set/set_compare_abstract.h | 96 + ml/dlib/dlib/set/set_kernel_1.h | 372 + ml/dlib/dlib/set/set_kernel_abstract.h | 192 + ml/dlib/dlib/set/set_kernel_c.h | 194 + ml/dlib/dlib/set_utils.h | 11 + ml/dlib/dlib/set_utils/set_utils.h | 246 + ml/dlib/dlib/set_utils/set_utils_abstract.h | 98 + ml/dlib/dlib/simd.h | 12 + ml/dlib/dlib/simd/simd4f.h | 685 + ml/dlib/dlib/simd/simd4i.h | 566 + ml/dlib/dlib/simd/simd8f.h | 402 + ml/dlib/dlib/simd/simd8i.h | 339 + ml/dlib/dlib/simd/simd_check.h | 177 + ml/dlib/dlib/sliding_buffer.h | 38 + ml/dlib/dlib/sliding_buffer/circular_buffer.h | 235 + .../dlib/sliding_buffer/circular_buffer_abstract.h | 257 + .../dlib/sliding_buffer/sliding_buffer_kernel_1.h | 227 + .../sliding_buffer_kernel_abstract.h | 205 + .../dlib/sliding_buffer/sliding_buffer_kernel_c.h | 222 + ml/dlib/dlib/smart_pointers.h | 22 + ml/dlib/dlib/smart_pointers/scoped_ptr.h | 16 + ml/dlib/dlib/smart_pointers/shared_ptr.h | 492 + ml/dlib/dlib/smart_pointers/shared_ptr_abstract.h | 374 + .../dlib/smart_pointers/shared_ptr_thread_safe.h | 462 + .../shared_ptr_thread_safe_abstract.h | 352 + ml/dlib/dlib/smart_pointers/weak_ptr.h | 225 + ml/dlib/dlib/smart_pointers/weak_ptr_abstract.h | 193 + ml/dlib/dlib/smart_pointers_thread_safe.h | 21 + ml/dlib/dlib/sockets.h | 20 + ml/dlib/dlib/sockets/posix.h | 6 + ml/dlib/dlib/sockets/sockets_extensions.cpp | 341 + ml/dlib/dlib/sockets/sockets_extensions.h | 151 + ml/dlib/dlib/sockets/sockets_extensions_abstract.h | 300 + ml/dlib/dlib/sockets/sockets_kernel_1.cpp | 979 + ml/dlib/dlib/sockets/sockets_kernel_1.h | 351 + ml/dlib/dlib/sockets/sockets_kernel_2.cpp | 1109 + ml/dlib/dlib/sockets/sockets_kernel_2.h | 396 + ml/dlib/dlib/sockets/sockets_kernel_abstract.h | 495 + ml/dlib/dlib/sockets/windows.h | 6 + ml/dlib/dlib/sockstreambuf.h | 11 + ml/dlib/dlib/sockstreambuf/sockstreambuf.cpp | 177 + ml/dlib/dlib/sockstreambuf/sockstreambuf.h | 172 + .../dlib/sockstreambuf/sockstreambuf_abstract.h | 127 + .../sockstreambuf/sockstreambuf_unbuffered.cpp | 168 + .../dlib/sockstreambuf/sockstreambuf_unbuffered.h | 118 + ml/dlib/dlib/sort.h | 490 + ml/dlib/dlib/sparse_vector.h | 10 + ml/dlib/dlib/sqlite.h | 11 + ml/dlib/dlib/sqlite/sqlite.h | 625 + ml/dlib/dlib/sqlite/sqlite_abstract.h | 506 + ml/dlib/dlib/sqlite/sqlite_tools.h | 189 + ml/dlib/dlib/sqlite/sqlite_tools_abstract.h | 164 + ml/dlib/dlib/sstream | 1 + ml/dlib/dlib/stack.h | 34 + ml/dlib/dlib/stack/stack_kernel_1.h | 504 + ml/dlib/dlib/stack/stack_kernel_abstract.h | 180 + ml/dlib/dlib/stack/stack_kernel_c.h | 189 + ml/dlib/dlib/stack_trace.cpp | 91 + ml/dlib/dlib/stack_trace.h | 118 + ml/dlib/dlib/static_map.h | 43 + ml/dlib/dlib/static_map/static_map_kernel_1.h | 756 + .../dlib/static_map/static_map_kernel_abstract.h | 181 + ml/dlib/dlib/static_map/static_map_kernel_c.h | 89 + ml/dlib/dlib/static_set.h | 49 + ml/dlib/dlib/static_set/static_set_compare_1.h | 122 + .../dlib/static_set/static_set_compare_abstract.h | 93 + ml/dlib/dlib/static_set/static_set_kernel_1.h | 446 + .../dlib/static_set/static_set_kernel_abstract.h | 154 + ml/dlib/dlib/static_set/static_set_kernel_c.h | 88 + ml/dlib/dlib/statistics.h | 19 + ml/dlib/dlib/statistics/average_precision.h | 66 + .../dlib/statistics/average_precision_abstract.h | 67 + ml/dlib/dlib/statistics/cca.h | 186 + ml/dlib/dlib/statistics/cca_abstract.h | 191 + ml/dlib/dlib/statistics/dpca.h | 541 + ml/dlib/dlib/statistics/dpca_abstract.h | 365 + ml/dlib/dlib/statistics/image_feature_sampling.h | 82 + .../statistics/image_feature_sampling_abstract.h | 45 + ml/dlib/dlib/statistics/lda.h | 237 + ml/dlib/dlib/statistics/lda_abstract.h | 118 + ml/dlib/dlib/statistics/random_subset_selector.h | 372 + .../statistics/random_subset_selector_abstract.h | 388 + ml/dlib/dlib/statistics/running_gradient.h | 370 + .../dlib/statistics/running_gradient_abstract.h | 276 + ml/dlib/dlib/statistics/sammon.h | 269 + ml/dlib/dlib/statistics/sammon_abstract.h | 117 + ml/dlib/dlib/statistics/statistics.h | 1890 ++ ml/dlib/dlib/statistics/statistics_abstract.h | 1387 ++ .../dlib/statistics/vector_normalizer_frobmetric.h | 618 + .../vector_normalizer_frobmetric_abstract.h | 328 + ml/dlib/dlib/std_allocator.h | 199 + ml/dlib/dlib/stl_checked.h | 10 + ml/dlib/dlib/stl_checked/std_vector_c.h | 333 + ml/dlib/dlib/stl_checked/std_vector_c_abstract.h | 470 + ml/dlib/dlib/string.h | 9 + ml/dlib/dlib/string/cassert | 1 + ml/dlib/dlib/string/iomanip | 1 + ml/dlib/dlib/string/iosfwd | 1 + ml/dlib/dlib/string/iostream | 1 + ml/dlib/dlib/string/locale | 1 + ml/dlib/dlib/string/string.h | 1004 + ml/dlib/dlib/string/string_abstract.h | 652 + ml/dlib/dlib/svm.h | 60 + ml/dlib/dlib/svm/active_learning.h | 162 + ml/dlib/dlib/svm/active_learning_abstract.h | 75 + ml/dlib/dlib/svm/assignment_function.h | 255 + ml/dlib/dlib/svm/assignment_function_abstract.h | 342 + .../dlib/svm/cross_validate_assignment_trainer.h | 181 + .../cross_validate_assignment_trainer_abstract.h | 69 + .../svm/cross_validate_graph_labeling_trainer.h | 258 + ...ross_validate_graph_labeling_trainer_abstract.h | 147 + .../dlib/svm/cross_validate_multiclass_trainer.h | 208 + .../cross_validate_multiclass_trainer_abstract.h | 99 + .../svm/cross_validate_object_detection_trainer.h | 430 + ...ss_validate_object_detection_trainer_abstract.h | 297 + .../dlib/svm/cross_validate_regression_trainer.h | 155 + .../cross_validate_regression_trainer_abstract.h | 82 + ml/dlib/dlib/svm/cross_validate_sequence_labeler.h | 152 + .../svm/cross_validate_sequence_labeler_abstract.h | 83 + .../dlib/svm/cross_validate_sequence_segmenter.h | 187 + .../cross_validate_sequence_segmenter_abstract.h | 80 + .../svm/cross_validate_track_association_trainer.h | 163 + ...s_validate_track_association_trainer_abstract.h | 69 + ml/dlib/dlib/svm/empirical_kernel_map.h | 429 + ml/dlib/dlib/svm/empirical_kernel_map_abstract.h | 430 + ml/dlib/dlib/svm/feature_ranking.h | 477 + ml/dlib/dlib/svm/feature_ranking_abstract.h | 136 + ml/dlib/dlib/svm/function.h | 882 + ml/dlib/dlib/svm/function_abstract.h | 997 + ml/dlib/dlib/svm/kcentroid.h | 614 + ml/dlib/dlib/svm/kcentroid_abstract.h | 339 + ml/dlib/dlib/svm/kcentroid_overloads.h | 1324 ++ ml/dlib/dlib/svm/kernel.h | 569 + ml/dlib/dlib/svm/kernel_abstract.h | 681 + ml/dlib/dlib/svm/kernel_matrix.h | 268 + ml/dlib/dlib/svm/kernel_matrix_abstract.h | 115 + ml/dlib/dlib/svm/kkmeans.h | 654 + ml/dlib/dlib/svm/kkmeans_abstract.h | 365 + ml/dlib/dlib/svm/krls.h | 358 + ml/dlib/dlib/svm/krls_abstract.h | 202 + ml/dlib/dlib/svm/krr_trainer.h | 368 + ml/dlib/dlib/svm/krr_trainer_abstract.h | 322 + .../dlib/svm/linearly_independent_subset_finder.h | 540 + .../linearly_independent_subset_finder_abstract.h | 327 + ml/dlib/dlib/svm/multiclass_tools.h | 68 + ml/dlib/dlib/svm/multiclass_tools_abstract.h | 45 + ml/dlib/dlib/svm/null_df.h | 33 + ml/dlib/dlib/svm/null_trainer.h | 61 + ml/dlib/dlib/svm/null_trainer_abstract.h | 101 + ml/dlib/dlib/svm/num_nonnegative_weights.h | 76 + ml/dlib/dlib/svm/one_vs_all_decision_function.h | 265 + .../svm/one_vs_all_decision_function_abstract.h | 214 + ml/dlib/dlib/svm/one_vs_all_trainer.h | 234 + ml/dlib/dlib/svm/one_vs_all_trainer_abstract.h | 163 + ml/dlib/dlib/svm/one_vs_one_decision_function.h | 291 + .../svm/one_vs_one_decision_function_abstract.h | 213 + ml/dlib/dlib/svm/one_vs_one_trainer.h | 249 + ml/dlib/dlib/svm/one_vs_one_trainer_abstract.h | 166 + ml/dlib/dlib/svm/pegasos.h | 710 + ml/dlib/dlib/svm/pegasos_abstract.h | 514 + ml/dlib/dlib/svm/ranking_tools.h | 448 + ml/dlib/dlib/svm/ranking_tools_abstract.h | 247 + ml/dlib/dlib/svm/rbf_network.h | 162 + ml/dlib/dlib/svm/rbf_network_abstract.h | 132 + ml/dlib/dlib/svm/reduced.h | 613 + ml/dlib/dlib/svm/reduced_abstract.h | 267 + ml/dlib/dlib/svm/rls.h | 232 + ml/dlib/dlib/svm/rls_abstract.h | 175 + ml/dlib/dlib/svm/roc_trainer.h | 149 + ml/dlib/dlib/svm/roc_trainer_abstract.h | 135 + ml/dlib/dlib/svm/rr_trainer.h | 456 + ml/dlib/dlib/svm/rr_trainer_abstract.h | 255 + ml/dlib/dlib/svm/rvm.h | 1018 + ml/dlib/dlib/svm/rvm_abstract.h | 278 + ml/dlib/dlib/svm/sequence_labeler.h | 339 + ml/dlib/dlib/svm/sequence_labeler_abstract.h | 396 + ml/dlib/dlib/svm/sequence_segmenter.h | 468 + ml/dlib/dlib/svm/sequence_segmenter_abstract.h | 452 + .../dlib/svm/simplify_linear_decision_function.h | 110 + .../simplify_linear_decision_function_abstract.h | 74 + ml/dlib/dlib/svm/sort_basis_vectors.h | 224 + ml/dlib/dlib/svm/sort_basis_vectors_abstract.h | 59 + ml/dlib/dlib/svm/sparse_kernel.h | 384 + ml/dlib/dlib/svm/sparse_kernel_abstract.h | 486 + ml/dlib/dlib/svm/sparse_vector.h | 1170 + ml/dlib/dlib/svm/sparse_vector_abstract.h | 688 + ml/dlib/dlib/svm/structural_assignment_trainer.h | 294 + .../svm/structural_assignment_trainer_abstract.h | 299 + .../dlib/svm/structural_graph_labeling_trainer.h | 282 + .../structural_graph_labeling_trainer_abstract.h | 265 + .../dlib/svm/structural_object_detection_trainer.h | 402 + .../structural_object_detection_trainer_abstract.h | 390 + .../svm/structural_sequence_labeling_trainer.h | 271 + ...structural_sequence_labeling_trainer_abstract.h | 266 + .../svm/structural_sequence_segmentation_trainer.h | 281 + ...ctural_sequence_segmentation_trainer_abstract.h | 264 + .../dlib/svm/structural_svm_assignment_problem.h | 288 + .../structural_svm_assignment_problem_abstract.h | 87 + ml/dlib/dlib/svm/structural_svm_distributed.h | 700 + .../dlib/svm/structural_svm_distributed_abstract.h | 357 + .../svm/structural_svm_graph_labeling_problem.h | 542 + ...tructural_svm_graph_labeling_problem_abstract.h | 249 + .../svm/structural_svm_object_detection_problem.h | 531 + ...uctural_svm_object_detection_problem_abstract.h | 178 + ml/dlib/dlib/svm/structural_svm_problem.h | 649 + ml/dlib/dlib/svm/structural_svm_problem_abstract.h | 348 + ml/dlib/dlib/svm/structural_svm_problem_threaded.h | 157 + .../svm/structural_svm_problem_threaded_abstract.h | 68 + .../svm/structural_svm_sequence_labeling_problem.h | 281 + ...ctural_svm_sequence_labeling_problem_abstract.h | 110 + .../svm/structural_track_association_trainer.h | 404 + ...structural_track_association_trainer_abstract.h | 268 + ml/dlib/dlib/svm/svm.h | 1205 + ml/dlib/dlib/svm/svm_abstract.h | 604 + ml/dlib/dlib/svm/svm_c_ekm_trainer.h | 636 + ml/dlib/dlib/svm/svm_c_ekm_trainer_abstract.h | 384 + ml/dlib/dlib/svm/svm_c_linear_dcd_trainer.h | 712 + .../dlib/svm/svm_c_linear_dcd_trainer_abstract.h | 382 + ml/dlib/dlib/svm/svm_c_linear_trainer.h | 706 + ml/dlib/dlib/svm/svm_c_linear_trainer_abstract.h | 359 + ml/dlib/dlib/svm/svm_c_trainer.h | 359 + ml/dlib/dlib/svm/svm_c_trainer_abstract.h | 237 + ml/dlib/dlib/svm/svm_multiclass_linear_trainer.h | 432 + .../svm/svm_multiclass_linear_trainer_abstract.h | 275 + ml/dlib/dlib/svm/svm_nu_trainer.h | 326 + ml/dlib/dlib/svm/svm_nu_trainer_abstract.h | 210 + ml/dlib/dlib/svm/svm_one_class_trainer.h | 284 + ml/dlib/dlib/svm/svm_one_class_trainer_abstract.h | 201 + ml/dlib/dlib/svm/svm_rank_trainer.h | 495 + ml/dlib/dlib/svm/svm_rank_trainer_abstract.h | 298 + ml/dlib/dlib/svm/svm_threaded.h | 253 + ml/dlib/dlib/svm/svm_threaded_abstract.h | 62 + ml/dlib/dlib/svm/svr_linear_trainer.h | 424 + ml/dlib/dlib/svm/svr_linear_trainer_abstract.h | 269 + ml/dlib/dlib/svm/svr_trainer.h | 393 + ml/dlib/dlib/svm/svr_trainer_abstract.h | 209 + ml/dlib/dlib/svm/track_association_function.h | 154 + .../dlib/svm/track_association_function_abstract.h | 271 + ml/dlib/dlib/svm_threaded.h | 36 + ml/dlib/dlib/sync_extension.h | 31 + .../dlib/sync_extension/sync_extension_kernel_1.h | 67 + .../sync_extension_kernel_abstract.h | 190 + ml/dlib/dlib/test/CMakeLists.txt | 181 + .../test/WINDOWS_build_and_run_all_unit_tests.bat | 42 + ml/dlib/dlib/test/active_learning.cpp | 165 + ml/dlib/dlib/test/any.cpp | 139 + ml/dlib/dlib/test/any_function.cpp | 253 + ml/dlib/dlib/test/array.cpp | 669 + ml/dlib/dlib/test/array2d.cpp | 580 + ml/dlib/dlib/test/assignment_learning.cpp | 379 + ml/dlib/dlib/test/base64.cpp | 208 + ml/dlib/dlib/test/bayes_nets.cpp | 411 + ml/dlib/dlib/test/bigint.cpp | 522 + ml/dlib/dlib/test/binary_search_tree.h | 889 + ml/dlib/dlib/test/binary_search_tree_kernel_1a.cpp | 47 + ml/dlib/dlib/test/binary_search_tree_kernel_2a.cpp | 45 + ml/dlib/dlib/test/binary_search_tree_mm1.cpp | 66 + ml/dlib/dlib/test/binary_search_tree_mm2.cpp | 48 + ml/dlib/dlib/test/blas_bindings/CMakeLists.txt | 33 + .../dlib/test/blas_bindings/blas_bindings_dot.cpp | 314 + .../dlib/test/blas_bindings/blas_bindings_gemm.cpp | 311 + .../dlib/test/blas_bindings/blas_bindings_gemv.cpp | 226 + .../dlib/test/blas_bindings/blas_bindings_ger.cpp | 200 + .../test/blas_bindings/blas_bindings_scal_axpy.cpp | 261 + ml/dlib/dlib/test/blas_bindings/vector.cpp | 115 + ml/dlib/dlib/test/bridge.cpp | 259 + ml/dlib/dlib/test/bsp.cpp | 566 + ml/dlib/dlib/test/byte_orderer.cpp | 111 + ml/dlib/dlib/test/cca.cpp | 460 + ml/dlib/dlib/test/checkerboard.h | 55 + ml/dlib/dlib/test/clustering.cpp | 410 + ml/dlib/dlib/test/cmd_line_parser.cpp | 40 + ml/dlib/dlib/test/cmd_line_parser.h | 901 + ml/dlib/dlib/test/cmd_line_parser_wchar_t.cpp | 40 + ml/dlib/dlib/test/compress_stream.cpp | 306 + ml/dlib/dlib/test/conditioning_class.cpp | 86 + ml/dlib/dlib/test/conditioning_class.h | 841 + ml/dlib/dlib/test/conditioning_class_c.cpp | 87 + ml/dlib/dlib/test/config_reader.cpp | 509 + ml/dlib/dlib/test/correlation_tracker.cpp | 955 + ml/dlib/dlib/test/crc32.cpp | 74 + ml/dlib/dlib/test/create_iris_datafile.cpp | 65 + ml/dlib/dlib/test/create_iris_datafile.h | 19 + ml/dlib/dlib/test/cublas.cpp | 198 + ml/dlib/dlib/test/data_io.cpp | 227 + ml/dlib/dlib/test/directed_graph.cpp | 541 + ml/dlib/dlib/test/discriminant_pca.cpp | 365 + ml/dlib/dlib/test/disjoint_subsets.cpp | 102 + ml/dlib/dlib/test/disjoint_subsets_sized.cpp | 143 + ml/dlib/dlib/test/dnn.cpp | 3261 +++ ml/dlib/dlib/test/ekm_and_lisf.cpp | 306 + ml/dlib/dlib/test/elastic_net.cpp | 122 + ml/dlib/dlib/test/empirical_kernel_map.cpp | 444 + ml/dlib/dlib/test/entropy_coder.cpp | 587 + ml/dlib/dlib/test/entropy_encoder_model.cpp | 198 + ml/dlib/dlib/test/example.cpp | 72 + ml/dlib/dlib/test/example_args.cpp | 75 + ml/dlib/dlib/test/examples/CMakeLists.txt | 8 + ml/dlib/dlib/test/face.cpp | 360 + ml/dlib/dlib/test/fft.cpp | 553 + ml/dlib/dlib/test/fhog.cpp | 684 + ml/dlib/dlib/test/filtering.cpp | 166 + ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp | 787 + .../dlib/test/find_max_factor_graph_viterbi.cpp | 217 + ml/dlib/dlib/test/find_optimal_parameters.cpp | 58 + ml/dlib/dlib/test/geometry.cpp | 883 + ml/dlib/dlib/test/global_optimization.cpp | 302 + ml/dlib/dlib/test/graph.cpp | 414 + ml/dlib/dlib/test/graph_cuts.cpp | 1217 ++ ml/dlib/dlib/test/graph_labeler.cpp | 472 + ml/dlib/dlib/test/gui/CMakeLists.txt | 20 + ml/dlib/dlib/test/gui/main.cpp | 840 + ml/dlib/dlib/test/hash.cpp | 369 + ml/dlib/dlib/test/hash_map.cpp | 450 + ml/dlib/dlib/test/hash_set.cpp | 387 + ml/dlib/dlib/test/hash_table.cpp | 663 + ml/dlib/dlib/test/hog_image.cpp | 126 + ml/dlib/dlib/test/image.cpp | 1903 ++ ml/dlib/dlib/test/iosockstream.cpp | 181 + ml/dlib/dlib/test/is_same_object.cpp | 141 + ml/dlib/dlib/test/isotonic_regression.cpp | 103 + ml/dlib/dlib/test/kcentroid.cpp | 684 + ml/dlib/dlib/test/kernel_matrix.cpp | 161 + ml/dlib/dlib/test/kmeans.cpp | 163 + ml/dlib/dlib/test/learning_to_track.cpp | 306 + ml/dlib/dlib/test/least_squares.cpp | 452 + ml/dlib/dlib/test/linear_manifold_regularizer.cpp | 408 + ml/dlib/dlib/test/lspi.cpp | 258 + ml/dlib/dlib/test/lz77_buffer.cpp | 569 + ml/dlib/dlib/test/main.cpp | 217 + ml/dlib/dlib/test/makefile | 185 + ml/dlib/dlib/test/map.cpp | 441 + ml/dlib/dlib/test/matrix.cpp | 1519 ++ ml/dlib/dlib/test/matrix2.cpp | 1158 + ml/dlib/dlib/test/matrix3.cpp | 1134 + ml/dlib/dlib/test/matrix4.cpp | 1119 + ml/dlib/dlib/test/matrix_chol.cpp | 182 + ml/dlib/dlib/test/matrix_eig.cpp | 245 + ml/dlib/dlib/test/matrix_lu.cpp | 223 + ml/dlib/dlib/test/matrix_qr.cpp | 208 + ml/dlib/dlib/test/max_cost_assignment.cpp | 157 + ml/dlib/dlib/test/max_sum_submatrix.cpp | 177 + ml/dlib/dlib/test/md5.cpp | 71 + ml/dlib/dlib/test/member_function_pointer.cpp | 553 + ml/dlib/dlib/test/metaprogramming.cpp | 94 + ml/dlib/dlib/test/mpc.cpp | 346 + ml/dlib/dlib/test/multithreaded_object.cpp | 321 + ml/dlib/dlib/test/numerical_integration.cpp | 228 + ml/dlib/dlib/test/object_detector.cpp | 1028 + ml/dlib/dlib/test/oca.cpp | 244 + ml/dlib/dlib/test/one_vs_all_trainer.cpp | 305 + ml/dlib/dlib/test/one_vs_one_trainer.cpp | 218 + ml/dlib/dlib/test/opt_qp_solver.cpp | 813 + ml/dlib/dlib/test/optimization.cpp | 1231 ++ ml/dlib/dlib/test/optimization_test_functions.cpp | 425 + ml/dlib/dlib/test/optimization_test_functions.h | 310 + ml/dlib/dlib/test/parallel_for.cpp | 334 + ml/dlib/dlib/test/parse.cpp | 233 + ml/dlib/dlib/test/pipe.cpp | 688 + ml/dlib/dlib/test/pixel.cpp | 777 + ml/dlib/dlib/test/probabilistic.cpp | 123 + ml/dlib/dlib/test/pyramid_down.cpp | 424 + ml/dlib/dlib/test/queue.cpp | 426 + ml/dlib/dlib/test/rand.cpp | 436 + ml/dlib/dlib/test/random_forest.cpp | 405 + ml/dlib/dlib/test/ranking.cpp | 485 + ml/dlib/dlib/test/read_write_mutex.cpp | 208 + ml/dlib/dlib/test/reference_counter.cpp | 122 + ml/dlib/dlib/test/rls.cpp | 196 + ml/dlib/dlib/test/sammon.cpp | 211 + ml/dlib/dlib/test/scan_image.cpp | 713 + ml/dlib/dlib/test/sequence.cpp | 312 + ml/dlib/dlib/test/sequence_labeler.cpp | 461 + ml/dlib/dlib/test/sequence_segmenter.cpp | 294 + ml/dlib/dlib/test/serialize.cpp | 1087 + ml/dlib/dlib/test/set.cpp | 464 + ml/dlib/dlib/test/sldf.cpp | 296 + ml/dlib/dlib/test/sliding_buffer.cpp | 439 + ml/dlib/dlib/test/smart_pointers.cpp | 449 + ml/dlib/dlib/test/sockets.cpp | 247 + ml/dlib/dlib/test/sockets2.cpp | 204 + ml/dlib/dlib/test/sockstreambuf.cpp | 253 + ml/dlib/dlib/test/sparse_vector.cpp | 301 + ml/dlib/dlib/test/stack.cpp | 294 + ml/dlib/dlib/test/static_map.cpp | 323 + ml/dlib/dlib/test/static_set.cpp | 206 + ml/dlib/dlib/test/statistics.cpp | 915 + ml/dlib/dlib/test/std_vector_c.cpp | 101 + ml/dlib/dlib/test/string.cpp | 329 + ml/dlib/dlib/test/svm.cpp | 661 + ml/dlib/dlib/test/svm_c_linear.cpp | 392 + ml/dlib/dlib/test/svm_c_linear_dcd.cpp | 545 + ml/dlib/dlib/test/svm_multiclass_linear.cpp | 226 + ml/dlib/dlib/test/svm_struct.cpp | 641 + ml/dlib/dlib/test/svr_linear_trainer.cpp | 161 + ml/dlib/dlib/test/symmetric_matrix_cache.cpp | 212 + ml/dlib/dlib/test/tester.cpp | 175 + ml/dlib/dlib/test/tester.h | 187 + ml/dlib/dlib/test/thread_pool.cpp | 428 + ml/dlib/dlib/test/threads.cpp | 158 + ml/dlib/dlib/test/timer.cpp | 347 + ml/dlib/dlib/test/tokenizer.cpp | 378 + ml/dlib/dlib/test/tools/CMakeLists.txt | 5 + ml/dlib/dlib/test/trust_region.cpp | 329 + ml/dlib/dlib/test/tuple.cpp | 186 + ml/dlib/dlib/test/type_safe_union.cpp | 455 + ml/dlib/dlib/test/vectorstream.cpp | 142 + ml/dlib/dlib/test_for_odr_violations.cpp | 47 + ml/dlib/dlib/test_for_odr_violations.h | 57 + ml/dlib/dlib/threads.h | 28 + ml/dlib/dlib/threads/async.cpp | 48 + ml/dlib/dlib/threads/async.h | 105 + ml/dlib/dlib/threads/async_abstract.h | 67 + ml/dlib/dlib/threads/auto_mutex_extension.h | 180 + .../dlib/threads/auto_mutex_extension_abstract.h | 185 + ml/dlib/dlib/threads/auto_unlock_extension.h | 116 + .../dlib/threads/auto_unlock_extension_abstract.h | 116 + ml/dlib/dlib/threads/create_new_thread_extension.h | 46 + .../threads/create_new_thread_extension_abstract.h | 33 + .../threads/multithreaded_object_extension.cpp | 241 + .../dlib/threads/multithreaded_object_extension.h | 153 + .../multithreaded_object_extension_abstract.h | 186 + ml/dlib/dlib/threads/parallel_for_extension.h | 676 + .../dlib/threads/parallel_for_extension_abstract.h | 469 + ml/dlib/dlib/threads/posix.h | 6 + ml/dlib/dlib/threads/read_write_mutex_extension.h | 177 + .../threads/read_write_mutex_extension_abstract.h | 146 + ml/dlib/dlib/threads/rmutex_extension.h | 109 + ml/dlib/dlib/threads/rmutex_extension_abstract.h | 107 + ml/dlib/dlib/threads/rsignaler_extension.h | 90 + .../dlib/threads/rsignaler_extension_abstract.h | 123 + ml/dlib/dlib/threads/thread_function_extension.h | 215 + .../threads/thread_function_extension_abstract.h | 146 + ml/dlib/dlib/threads/thread_pool_extension.cpp | 347 + ml/dlib/dlib/threads/thread_pool_extension.h | 1392 ++ .../dlib/threads/thread_pool_extension_abstract.h | 842 + .../dlib/threads/thread_specific_data_extension.h | 141 + .../thread_specific_data_extension_abstract.h | 87 + ml/dlib/dlib/threads/threaded_object_extension.cpp | 290 + ml/dlib/dlib/threads/threaded_object_extension.h | 123 + .../threads/threaded_object_extension_abstract.h | 199 + ml/dlib/dlib/threads/threads_kernel.h | 18 + ml/dlib/dlib/threads/threads_kernel_1.cpp | 83 + ml/dlib/dlib/threads/threads_kernel_1.h | 158 + ml/dlib/dlib/threads/threads_kernel_2.cpp | 75 + ml/dlib/dlib/threads/threads_kernel_2.h | 180 + ml/dlib/dlib/threads/threads_kernel_abstract.h | 302 + ml/dlib/dlib/threads/threads_kernel_shared.cpp | 318 + ml/dlib/dlib/threads/threads_kernel_shared.h | 274 + ml/dlib/dlib/threads/windows.h | 6 + ml/dlib/dlib/time_this.h | 36 + ml/dlib/dlib/timeout.h | 10 + ml/dlib/dlib/timeout/timeout.h | 200 + ml/dlib/dlib/timeout/timeout_abstract.h | 188 + ml/dlib/dlib/timer.h | 10 + ml/dlib/dlib/timer/timer.cpp | 235 + ml/dlib/dlib/timer/timer.h | 427 + ml/dlib/dlib/timer/timer_abstract.h | 190 + ml/dlib/dlib/timer/timer_heavy.h | 392 + ml/dlib/dlib/timing.h | 196 + ml/dlib/dlib/tokenizer.h | 33 + ml/dlib/dlib/tokenizer/tokenizer_kernel_1.cpp | 295 + ml/dlib/dlib/tokenizer/tokenizer_kernel_1.h | 155 + ml/dlib/dlib/tokenizer/tokenizer_kernel_abstract.h | 289 + ml/dlib/dlib/tokenizer/tokenizer_kernel_c.h | 167 + ml/dlib/dlib/travis/build-and-test.sh | 45 + ml/dlib/dlib/tuple.h | 10 + ml/dlib/dlib/tuple/tuple.h | 410 + ml/dlib/dlib/tuple/tuple_abstract.h | 302 + ml/dlib/dlib/type_safe_union.h | 11 + .../dlib/type_safe_union/type_safe_union_kernel.h | 711 + .../type_safe_union_kernel_abstract.h | 329 + ml/dlib/dlib/uintn.h | 96 + ml/dlib/dlib/unicode.h | 9 + ml/dlib/dlib/unicode/unicode.cpp | 175 + ml/dlib/dlib/unicode/unicode.h | 622 + ml/dlib/dlib/unicode/unicode_abstract.h | 233 + ml/dlib/dlib/unordered_pair.h | 176 + ml/dlib/dlib/vectorstream.h | 11 + ml/dlib/dlib/vectorstream/unserialize.h | 98 + ml/dlib/dlib/vectorstream/unserialize_abstract.h | 58 + ml/dlib/dlib/vectorstream/vectorstream.h | 138 + ml/dlib/dlib/vectorstream/vectorstream_abstract.h | 62 + ml/dlib/dlib/windows_magic.h | 50 + ml/dlib/dlib/xml_parser.h | 13 + ml/dlib/dlib/xml_parser/xml_parser_kernel_1.h | 1532 ++ .../dlib/xml_parser/xml_parser_kernel_abstract.h | 276 + .../dlib/xml_parser/xml_parser_kernel_interfaces.h | 244 + ml/dlib/docs/.logger_revnum | 1 + ml/dlib/docs/README.txt | 72 + ml/dlib/docs/bash_helper_functions | 30 + ml/dlib/docs/docs/algorithms.xml | 1118 + ml/dlib/docs/docs/api.xml | 1289 ++ ml/dlib/docs/docs/bayes.xml | 377 + ml/dlib/docs/docs/bayesopt_vs_lipo.svg | 21764 +++++++++++++++++++ ml/dlib/docs/docs/bigminus.gif | Bin 0 -> 91 bytes ml/dlib/docs/docs/bigplus.gif | Bin 0 -> 99 bytes ml/dlib/docs/docs/books.xml | 306 + ml/dlib/docs/docs/boost.png | Bin 0 -> 6308 bytes ml/dlib/docs/docs/change_log.xml | 11 + ... README. DO NOT EDIT THE TABLE OF CONTENTS FILE | 0 ...README. DO NOT EDIT THE TABLE OF CONTENTS FILE2 | 0 ...README. DO NOT EDIT THE TABLE OF CONTENTS FILE3 | 0 ml/dlib/docs/docs/chm/README.txt | 5 + ml/dlib/docs/docs/chm/documentation.html | 20 + ml/dlib/docs/docs/chm/htmlhelp/hha.dll | Bin 0 -> 837904 bytes ml/dlib/docs/docs/chm/htmlhelp/hhc.exe | Bin 0 -> 51472 bytes ml/dlib/docs/docs/chm/htmlhelp/htmlhelp.reg | 5 + ml/dlib/docs/docs/chm/htmlhelp/itcc.dll | Bin 0 -> 154352 bytes ml/dlib/docs/docs/chm/htmlhelp/itircl.dll | Bin 0 -> 155552 bytes ml/dlib/docs/docs/chm/htmlhelp/itss.dll | Bin 0 -> 138048 bytes ml/dlib/docs/docs/chm/htmlhelp/setup_htmlhelp.sh | 10 + ml/dlib/docs/docs/chm/htmlhelp_stylesheet.xsl | 223 + ml/dlib/docs/docs/chm/lib.hhp | 77 + ml/dlib/docs/docs/chm/toc.xml | 10 + ml/dlib/docs/docs/compile.xml | 227 + ml/dlib/docs/docs/compression.xml | 881 + ml/dlib/docs/docs/containers.xml | 1201 + ml/dlib/docs/docs/dlib-icon-30x32.png | Bin 0 -> 1278 bytes ml/dlib/docs/docs/dlib-icon-32.png | Bin 0 -> 1291 bytes ml/dlib/docs/docs/dlib-icon-48.png | Bin 0 -> 2040 bytes ml/dlib/docs/docs/dlib-icon-64.png | Bin 0 -> 2768 bytes ml/dlib/docs/docs/dlib-icon.ico | Bin 0 -> 1150 bytes ml/dlib/docs/docs/dlib-logo-and-icons.svg | 1602 ++ ml/dlib/docs/docs/dlib-logo-small.png | Bin 0 -> 2780 bytes ml/dlib/docs/docs/dlib-logo.png | Bin 0 -> 5701 bytes ml/dlib/docs/docs/dlib.css | 369 + ml/dlib/docs/docs/dlib.js | 94 + ml/dlib/docs/docs/down.gif | Bin 0 -> 61 bytes ml/dlib/docs/docs/enable_if.html | 387 + ml/dlib/docs/docs/face_landmarking_example.png | Bin 0 -> 113093 bytes ml/dlib/docs/docs/faq.xml | 547 + ml/dlib/docs/docs/find_max_global_example.mp4 | Bin 0 -> 610283 bytes ml/dlib/docs/docs/find_max_global_example.png | Bin 0 -> 13294 bytes ml/dlib/docs/docs/find_max_global_example.webm | Bin 0 -> 355489 bytes .../docs/docs/find_max_global_results_table.svg | 3398 +++ ml/dlib/docs/docs/graph_tools.xml | 678 + ml/dlib/docs/docs/guipics/button.png | Bin 0 -> 327 bytes ml/dlib/docs/docs/guipics/check_box.png | Bin 0 -> 438 bytes .../docs/docs/guipics/directed_graph_drawer.png | Bin 0 -> 3532 bytes ml/dlib/docs/docs/guipics/image_window.jpg | Bin 0 -> 24648 bytes ml/dlib/docs/docs/guipics/label.png | Bin 0 -> 322 bytes ml/dlib/docs/docs/guipics/list_box.png | Bin 0 -> 1187 bytes ml/dlib/docs/docs/guipics/menu_bar.png | Bin 0 -> 4588 bytes ml/dlib/docs/docs/guipics/message_box.png | Bin 0 -> 5209 bytes ml/dlib/docs/docs/guipics/mouse_tracker.png | Bin 0 -> 705 bytes ml/dlib/docs/docs/guipics/named_rectangle.png | Bin 0 -> 730 bytes .../docs/docs/guipics/open_existing_file_box.png | Bin 0 -> 8368 bytes ml/dlib/docs/docs/guipics/open_file_box.png | Bin 0 -> 8267 bytes ml/dlib/docs/docs/guipics/perspective_window.png | Bin 0 -> 29372 bytes ml/dlib/docs/docs/guipics/popup_menu.png | Bin 0 -> 1412 bytes ml/dlib/docs/docs/guipics/radio_button.png | Bin 0 -> 474 bytes ml/dlib/docs/docs/guipics/save_file_box.png | Bin 0 -> 8138 bytes ml/dlib/docs/docs/guipics/scroll_bar.png | Bin 0 -> 358 bytes ml/dlib/docs/docs/guipics/tabbed_display.png | Bin 0 -> 769 bytes ml/dlib/docs/docs/guipics/text_box.png | Bin 0 -> 798 bytes ml/dlib/docs/docs/guipics/text_field.png | Bin 0 -> 451 bytes ml/dlib/docs/docs/guipics/text_grid.png | Bin 0 -> 1516 bytes ml/dlib/docs/docs/heatmap.png | Bin 0 -> 259 bytes ml/dlib/docs/docs/howto_contribute.xml | 604 + ml/dlib/docs/docs/imaging.xml | 2608 +++ ml/dlib/docs/docs/index.xml | 226 + ml/dlib/docs/docs/intro.xml | 431 + ml/dlib/docs/docs/jet.png | Bin 0 -> 273 bytes ml/dlib/docs/docs/kernel_1a.txt | 78 + ml/dlib/docs/docs/kernel_1a.xml | 8 + ml/dlib/docs/docs/kernel_1b.txt | 77 + ml/dlib/docs/docs/kernel_1b.xml | 8 + ml/dlib/docs/docs/kernel_1c.txt | 78 + ml/dlib/docs/docs/kernel_1c.xml | 8 + ml/dlib/docs/docs/kernel_1da.txt | 78 + ml/dlib/docs/docs/kernel_1da.xml | 8 + ml/dlib/docs/docs/kernel_1db.txt | 78 + ml/dlib/docs/docs/kernel_1db.xml | 8 + ml/dlib/docs/docs/kernel_1ea.txt | 78 + ml/dlib/docs/docs/kernel_1ea.xml | 8 + ml/dlib/docs/docs/kernel_1eb.txt | 78 + ml/dlib/docs/docs/kernel_1eb.xml | 8 + ml/dlib/docs/docs/kernel_1ec.txt | 78 + ml/dlib/docs/docs/kernel_1ec.xml | 8 + ml/dlib/docs/docs/kernel_2a.txt | 78 + ml/dlib/docs/docs/kernel_2a.xml | 8 + ml/dlib/docs/docs/kernel_3a.txt | 78 + ml/dlib/docs/docs/kernel_3a.xml | 8 + ml/dlib/docs/docs/kernel_3b.txt | 78 + ml/dlib/docs/docs/kernel_3b.xml | 8 + ml/dlib/docs/docs/license.xml | 36 + ml/dlib/docs/docs/linear_algebra.xml | 1382 ++ ml/dlib/docs/docs/main_menu.xml | 665 + ml/dlib/docs/docs/metaprogramming.xml | 813 + ml/dlib/docs/docs/minus.gif | Bin 0 -> 56 bytes ml/dlib/docs/docs/ml.xml | 3957 ++++ ml/dlib/docs/docs/ml_guide.dia | Bin 0 -> 15962 bytes ml/dlib/docs/docs/ml_guide.svg | 4345 ++++ ml/dlib/docs/docs/network.xml | 259 + ml/dlib/docs/docs/old_change_log.xml | 7 + ml/dlib/docs/docs/old_release_notes.xml | 10 + ml/dlib/docs/docs/optimization.xml | 1338 ++ ml/dlib/docs/docs/other.xml | 1166 + ml/dlib/docs/docs/parsing.xml | 652 + ml/dlib/docs/docs/plus.gif | Bin 0 -> 59 bytes ml/dlib/docs/docs/python/conf.py | 246 + ml/dlib/docs/docs/python/generate_dlib_listing.py | 32 + ml/dlib/docs/docs/python/index.rst | 45 + ml/dlib/docs/docs/rbf_big_gamma.gif | Bin 0 -> 2131 bytes ml/dlib/docs/docs/rbf_normal.gif | Bin 0 -> 3050 bytes ml/dlib/docs/docs/rbf_small_gamma.gif | Bin 0 -> 1503 bytes ml/dlib/docs/docs/release_notes.xml | 4437 ++++ ml/dlib/docs/docs/right.gif | Bin 0 -> 67 bytes ml/dlib/docs/docs/stylesheet.xsl | 1201 + ml/dlib/docs/docs/term_index.xml | 1801 ++ ml/dlib/docs/docs/tiled_pyramid_example.jpg | Bin 0 -> 12237 bytes ml/dlib/docs/docs/vs-cmake-gui.png | Bin 0 -> 93460 bytes ml/dlib/docs/docs/vs_mode_1.png | Bin 0 -> 13372 bytes ml/dlib/docs/docs/vs_mode_2.png | Bin 0 -> 13885 bytes ml/dlib/docs/docs/vs_mode_3.png | Bin 0 -> 11228 bytes ml/dlib/docs/makedocs | 282 + ml/dlib/docs/makerel | 91 + ml/dlib/docs/testenv | 31 + ml/dlib/docs/testenv_rel | 24 + ml/dlib/examples/3d_point_cloud_ex.cpp | 50 + ml/dlib/examples/CMakeLists.txt | 250 + ml/dlib/examples/LICENSE_FOR_EXAMPLE_PROGRAMS.txt | 22 + ml/dlib/examples/assignment_learning_ex.cpp | 325 + ml/dlib/examples/bayes_net_ex.cpp | 307 + ml/dlib/examples/bayes_net_from_disk_ex.cpp | 83 + ml/dlib/examples/bayes_net_gui_ex.cpp | 989 + ml/dlib/examples/bridge_ex.cpp | 365 + ml/dlib/examples/bsp_ex.cpp | 282 + ml/dlib/examples/compress_stream_ex.cpp | 245 + ml/dlib/examples/config.txt | 30 + ml/dlib/examples/config_reader_ex.cpp | 146 + ml/dlib/examples/custom_trainer_ex.cpp | 277 + ml/dlib/examples/dir_nav_ex.cpp | 75 + ml/dlib/examples/dnn_face_recognition_ex.cpp | 220 + ml/dlib/examples/dnn_imagenet_ex.cpp | 171 + ml/dlib/examples/dnn_imagenet_train_ex.cpp | 368 + ml/dlib/examples/dnn_inception_ex.cpp | 154 + ml/dlib/examples/dnn_introduction2_ex.cpp | 388 + ml/dlib/examples/dnn_introduction_ex.cpp | 170 + ml/dlib/examples/dnn_metric_learning_ex.cpp | 128 + .../examples/dnn_metric_learning_on_images_ex.cpp | 340 + ml/dlib/examples/dnn_mmod_dog_hipsterizer.cpp | 180 + ml/dlib/examples/dnn_mmod_ex.cpp | 230 + ml/dlib/examples/dnn_mmod_face_detection_ex.cpp | 114 + ml/dlib/examples/dnn_mmod_find_cars2_ex.cpp | 96 + ml/dlib/examples/dnn_mmod_find_cars_ex.cpp | 236 + ml/dlib/examples/dnn_mmod_train_find_cars_ex.cpp | 425 + ml/dlib/examples/dnn_semantic_segmentation_ex.cpp | 172 + ml/dlib/examples/dnn_semantic_segmentation_ex.h | 200 + .../dnn_semantic_segmentation_train_ex.cpp | 390 + ml/dlib/examples/empirical_kernel_map_ex.cpp | 355 + ml/dlib/examples/face_detection_ex.cpp | 103 + ml/dlib/examples/face_landmark_detection_ex.cpp | 144 + ml/dlib/examples/faces/2007_007763.jpg | Bin 0 -> 89619 bytes ml/dlib/examples/faces/2008_001009.jpg | Bin 0 -> 41770 bytes ml/dlib/examples/faces/2008_001322.jpg | Bin 0 -> 65344 bytes ml/dlib/examples/faces/2008_002079.jpg | Bin 0 -> 92641 bytes ml/dlib/examples/faces/2008_002470.jpg | Bin 0 -> 91349 bytes ml/dlib/examples/faces/2008_002506.jpg | Bin 0 -> 79316 bytes ml/dlib/examples/faces/2008_004176.jpg | Bin 0 -> 93821 bytes ml/dlib/examples/faces/2008_007676.jpg | Bin 0 -> 110034 bytes ml/dlib/examples/faces/2009_004587.jpg | Bin 0 -> 79462 bytes ml/dlib/examples/faces/Tom_Cruise_avp_2014_4.jpg | Bin 0 -> 66360 bytes ml/dlib/examples/faces/bald_guys.jpg | Bin 0 -> 648373 bytes ml/dlib/examples/faces/dogs.jpg | Bin 0 -> 175216 bytes .../examples/faces/image_metadata_stylesheet.xsl | 109 + ml/dlib/examples/faces/testing.xml | 43 + .../examples/faces/testing_with_face_landmarks.xml | 1772 ++ ml/dlib/examples/faces/training.xml | 34 + .../faces/training_with_face_landmarks.xml | 1280 ++ ml/dlib/examples/fhog_ex.cpp | 88 + ml/dlib/examples/fhog_object_detector_ex.cpp | 269 + ml/dlib/examples/file_to_code_ex.cpp | 111 + ml/dlib/examples/graph_labeling_ex.cpp | 259 + ml/dlib/examples/gui_api_ex.cpp | 231 + ml/dlib/examples/hough_transform_ex.cpp | 84 + ml/dlib/examples/image_ex.cpp | 104 + .../examples/integrate_function_adapt_simp_ex.cpp | 89 + ml/dlib/examples/iosockstream_ex.cpp | 47 + .../examples/johns/John_Salley/000179_02159509.jpg | Bin 0 -> 9192 bytes .../examples/johns/John_Salley/000183_02159543.jpg | Bin 0 -> 9811 bytes .../examples/johns/John_Salley/000186_02159346.jpg | Bin 0 -> 8161 bytes .../examples/johns/John_Salley/000189_02159361.jpg | Bin 0 -> 9000 bytes .../examples/johns/John_Salley/000190_02159501.jpg | Bin 0 -> 8133 bytes .../examples/johns/John_Salley/000192_02159531.jpg | Bin 0 -> 9465 bytes .../examples/johns/John_Salley/000194_02159572.jpg | Bin 0 -> 7450 bytes .../examples/johns/John_Salley/000197_02159322.jpg | Bin 0 -> 9227 bytes .../examples/johns/John_Salley/000197_02159525.jpg | Bin 0 -> 7935 bytes .../examples/johns/John_Salley/000198_02159470.jpg | Bin 0 -> 10581 bytes .../examples/johns/John_Salley/000200_02159354.jpg | Bin 0 -> 8485 bytes .../examples/johns/John_Savage/000264_01099001.jpg | Bin 0 -> 6494 bytes .../examples/johns/John_Savage/000274_01099061.jpg | Bin 0 -> 6031 bytes .../examples/johns/John_Savage/000277_01099000.jpg | Bin 0 -> 6636 bytes .../examples/johns/John_Savage/000289_01099139.jpg | Bin 0 -> 5746 bytes .../examples/johns/John_Savage/000290_01099067.jpg | Bin 0 -> 6812 bytes .../examples/johns/John_Savage/000290_01099090.jpg | Bin 0 -> 5937 bytes .../examples/johns/John_Savage/000291_01099023.jpg | Bin 0 -> 6374 bytes .../examples/johns/John_Savage/000291_01099214.jpg | Bin 0 -> 5640 bytes .../examples/johns/John_Savage/000293_01099081.jpg | Bin 0 -> 6849 bytes .../examples/johns/John_Savage/000296_01099007.jpg | Bin 0 -> 6576 bytes .../examples/johns/John_Savage/000299_01099008.jpg | Bin 0 -> 5924 bytes .../johns/John_Schneider/000288_00925786.jpg | Bin 0 -> 7542 bytes .../johns/John_Schneider/000302_00925785.jpg | Bin 0 -> 6806 bytes .../johns/John_Schneider/000307_00925823.jpg | Bin 0 -> 7004 bytes .../johns/John_Schneider/000325_00925954.jpg | Bin 0 -> 7627 bytes .../johns/John_Schneider/000326_00925765.jpg | Bin 0 -> 7325 bytes .../johns/John_Schneider/000326_00926089.jpg | Bin 0 -> 7167 bytes .../johns/John_Schneider/000326_00926128.jpg | Bin 0 -> 6057 bytes .../johns/John_Schneider/000326_00926139.jpg | Bin 0 -> 6233 bytes .../johns/John_Schneider/000329_00925859.jpg | Bin 0 -> 6870 bytes .../johns/John_Schneider/000329_00925963.jpg | Bin 0 -> 7393 bytes .../johns/John_Schneider/000331_00926012.jpg | Bin 0 -> 6852 bytes .../johns/John_Shimkus/000373_03228153.jpg | Bin 0 -> 6910 bytes .../johns/John_Shimkus/000375_03227651.jpg | Bin 0 -> 6952 bytes .../johns/John_Shimkus/000376_02340068.jpg | Bin 0 -> 6810 bytes .../johns/John_Shimkus/000378_02340151.jpg | Bin 0 -> 7215 bytes .../johns/John_Shimkus/000378_03227610.jpg | Bin 0 -> 7215 bytes .../johns/John_Shimkus/000383_03227939.jpg | Bin 0 -> 5846 bytes .../johns/John_Shimkus/000385_03227766.jpg | Bin 0 -> 6084 bytes .../johns/John_Shimkus/000388_03227773.jpg | Bin 0 -> 6510 bytes .../johns/John_Shimkus/000390_03227666.jpg | Bin 0 -> 7838 bytes .../johns/John_Shimkus/000394_02340150.jpg | Bin 0 -> 10182 bytes .../johns/John_Shimkus/000396_03227722.jpg | Bin 0 -> 5802 bytes .../examples/johns/John_Simm/000288_00470387.jpg | Bin 0 -> 6513 bytes .../examples/johns/John_Simm/000297_00470170.jpg | Bin 0 -> 7194 bytes .../examples/johns/John_Simm/000300_00470148.jpg | Bin 0 -> 7289 bytes .../examples/johns/John_Simm/000304_00470122.jpg | Bin 0 -> 6582 bytes .../examples/johns/John_Simm/000305_00470162.jpg | Bin 0 -> 7965 bytes .../examples/johns/John_Simm/000305_00470717.jpg | Bin 0 -> 8694 bytes .../examples/johns/John_Simm/000306_00470222.jpg | Bin 0 -> 6306 bytes .../examples/johns/John_Simm/000306_00470223.jpg | Bin 0 -> 6274 bytes .../examples/johns/John_Simm/000309_00470287.jpg | Bin 0 -> 6195 bytes .../examples/johns/John_Simm/000310_00470421.jpg | Bin 0 -> 5563 bytes .../examples/johns/John_Simm/000310_00470511.jpg | Bin 0 -> 7574 bytes ml/dlib/examples/kcentroid_ex.cpp | 129 + ml/dlib/examples/kkmeans_ex.cpp | 154 + ml/dlib/examples/krls_ex.cpp | 94 + ml/dlib/examples/krls_filter_ex.cpp | 109 + ml/dlib/examples/krr_classification_ex.cpp | 205 + ml/dlib/examples/krr_regression_ex.cpp | 104 + ml/dlib/examples/learning_to_track_ex.cpp | 354 + ml/dlib/examples/least_squares_ex.cpp | 228 + .../examples/linear_manifold_regularizer_ex.cpp | 284 + ml/dlib/examples/logger_custom_output_ex.cpp | 73 + ml/dlib/examples/logger_ex.cpp | 70 + ml/dlib/examples/logger_ex_2.cpp | 153 + ml/dlib/examples/matrix_ex.cpp | 276 + ml/dlib/examples/matrix_expressions_ex.cpp | 406 + ml/dlib/examples/max_cost_assignment_ex.cpp | 47 + ml/dlib/examples/member_function_pointer_ex.cpp | 78 + ml/dlib/examples/mlp_ex.cpp | 86 + ml/dlib/examples/mmod_cars_test_image.jpg | Bin 0 -> 100135 bytes ml/dlib/examples/mmod_cars_test_image2.jpg | Bin 0 -> 259439 bytes ml/dlib/examples/model_selection_ex.cpp | 148 + ml/dlib/examples/mpc_ex.cpp | 156 + ml/dlib/examples/multiclass_classification_ex.cpp | 248 + ml/dlib/examples/multithreaded_object_ex.cpp | 138 + ml/dlib/examples/object_detector_advanced_ex.cpp | 302 + ml/dlib/examples/object_detector_ex.cpp | 263 + ml/dlib/examples/one_class_classifiers_ex.cpp | 245 + ml/dlib/examples/optimization_ex.cpp | 319 + ml/dlib/examples/parallel_for_ex.cpp | 158 + ml/dlib/examples/pipe_ex.cpp | 172 + ml/dlib/examples/pipe_ex_2.cpp | 160 + ml/dlib/examples/quantum_computing_ex.cpp | 337 + ml/dlib/examples/queue_ex.cpp | 78 + ml/dlib/examples/random_cropper_ex.cpp | 99 + ml/dlib/examples/rank_features_ex.cpp | 152 + ml/dlib/examples/running_stats_ex.cpp | 58 + ml/dlib/examples/rvm_ex.cpp | 217 + ml/dlib/examples/rvm_regression_ex.cpp | 101 + ml/dlib/examples/sequence_labeler_ex.cpp | 392 + ml/dlib/examples/sequence_segmenter_ex.cpp | 238 + ml/dlib/examples/server_http_ex.cpp | 108 + ml/dlib/examples/server_iostream_ex.cpp | 84 + ml/dlib/examples/sockets_ex.cpp | 63 + ml/dlib/examples/sockstreambuf_ex.cpp | 92 + ml/dlib/examples/sqlite_ex.cpp | 137 + ml/dlib/examples/std_allocator_ex.cpp | 57 + ml/dlib/examples/surf_ex.cpp | 82 + ml/dlib/examples/svm_c_ex.cpp | 266 + ml/dlib/examples/svm_ex.cpp | 255 + ml/dlib/examples/svm_pegasos_ex.cpp | 160 + ml/dlib/examples/svm_rank_ex.cpp | 151 + ml/dlib/examples/svm_sparse_ex.cpp | 120 + ml/dlib/examples/svm_struct_ex.cpp | 414 + ml/dlib/examples/svr_ex.cpp | 96 + ml/dlib/examples/thread_function_ex.cpp | 71 + ml/dlib/examples/thread_pool_ex.cpp | 183 + ml/dlib/examples/threaded_object_ex.cpp | 79 + ml/dlib/examples/threads_ex.cpp | 93 + ml/dlib/examples/timer_ex.cpp | 56 + ml/dlib/examples/train_object_detector.cpp | 422 + ml/dlib/examples/train_shape_predictor_ex.cpp | 198 + ml/dlib/examples/using_custom_kernels_ex.cpp | 208 + ml/dlib/examples/video_frames/frame_000100.jpg | Bin 0 -> 4674 bytes ml/dlib/examples/video_frames/frame_000101.jpg | Bin 0 -> 4756 bytes ml/dlib/examples/video_frames/frame_000102.jpg | Bin 0 -> 4683 bytes ml/dlib/examples/video_frames/frame_000103.jpg | Bin 0 -> 4653 bytes ml/dlib/examples/video_frames/frame_000104.jpg | Bin 0 -> 4807 bytes ml/dlib/examples/video_frames/frame_000105.jpg | Bin 0 -> 4760 bytes ml/dlib/examples/video_frames/frame_000106.jpg | Bin 0 -> 4640 bytes ml/dlib/examples/video_frames/frame_000107.jpg | Bin 0 -> 4713 bytes ml/dlib/examples/video_frames/frame_000108.jpg | Bin 0 -> 4908 bytes ml/dlib/examples/video_frames/frame_000109.jpg | Bin 0 -> 4854 bytes ml/dlib/examples/video_frames/frame_000110.jpg | Bin 0 -> 4775 bytes ml/dlib/examples/video_frames/frame_000111.jpg | Bin 0 -> 4587 bytes ml/dlib/examples/video_frames/frame_000112.jpg | Bin 0 -> 4759 bytes ml/dlib/examples/video_frames/frame_000113.jpg | Bin 0 -> 4686 bytes ml/dlib/examples/video_frames/frame_000114.jpg | Bin 0 -> 4740 bytes ml/dlib/examples/video_frames/frame_000115.jpg | Bin 0 -> 4667 bytes ml/dlib/examples/video_frames/frame_000116.jpg | Bin 0 -> 5027 bytes ml/dlib/examples/video_frames/frame_000117.jpg | Bin 0 -> 5160 bytes ml/dlib/examples/video_frames/frame_000118.jpg | Bin 0 -> 5033 bytes ml/dlib/examples/video_frames/frame_000119.jpg | Bin 0 -> 5262 bytes ml/dlib/examples/video_frames/frame_000120.jpg | Bin 0 -> 5213 bytes ml/dlib/examples/video_frames/frame_000121.jpg | Bin 0 -> 5229 bytes ml/dlib/examples/video_frames/frame_000122.jpg | Bin 0 -> 5076 bytes ml/dlib/examples/video_frames/frame_000123.jpg | Bin 0 -> 5162 bytes ml/dlib/examples/video_frames/frame_000124.jpg | Bin 0 -> 5068 bytes ml/dlib/examples/video_frames/frame_000125.jpg | Bin 0 -> 5108 bytes ml/dlib/examples/video_frames/frame_000126.jpg | Bin 0 -> 4987 bytes ml/dlib/examples/video_frames/frame_000127.jpg | Bin 0 -> 5068 bytes ml/dlib/examples/video_frames/frame_000128.jpg | Bin 0 -> 4973 bytes ml/dlib/examples/video_frames/frame_000129.jpg | Bin 0 -> 4931 bytes ml/dlib/examples/video_frames/frame_000130.jpg | Bin 0 -> 5087 bytes ml/dlib/examples/video_frames/frame_000131.jpg | Bin 0 -> 4982 bytes ml/dlib/examples/video_frames/frame_000132.jpg | Bin 0 -> 4965 bytes ml/dlib/examples/video_frames/frame_000133.jpg | Bin 0 -> 4944 bytes ml/dlib/examples/video_frames/frame_000134.jpg | Bin 0 -> 4854 bytes ml/dlib/examples/video_frames/frame_000135.jpg | Bin 0 -> 4803 bytes ml/dlib/examples/video_frames/frame_000136.jpg | Bin 0 -> 4793 bytes ml/dlib/examples/video_frames/frame_000137.jpg | Bin 0 -> 4863 bytes ml/dlib/examples/video_frames/frame_000138.jpg | Bin 0 -> 4969 bytes ml/dlib/examples/video_frames/frame_000139.jpg | Bin 0 -> 4960 bytes ml/dlib/examples/video_frames/frame_000140.jpg | Bin 0 -> 5064 bytes ml/dlib/examples/video_frames/frame_000141.jpg | Bin 0 -> 5115 bytes ml/dlib/examples/video_frames/frame_000142.jpg | Bin 0 -> 5112 bytes ml/dlib/examples/video_frames/frame_000143.jpg | Bin 0 -> 5095 bytes ml/dlib/examples/video_frames/frame_000144.jpg | Bin 0 -> 5082 bytes ml/dlib/examples/video_frames/frame_000145.jpg | Bin 0 -> 4971 bytes ml/dlib/examples/video_frames/frame_000146.jpg | Bin 0 -> 4828 bytes ml/dlib/examples/video_frames/frame_000147.jpg | Bin 0 -> 4813 bytes ml/dlib/examples/video_frames/frame_000148.jpg | Bin 0 -> 4804 bytes ml/dlib/examples/video_frames/frame_000149.jpg | Bin 0 -> 4686 bytes ml/dlib/examples/video_frames/frame_000150.jpg | Bin 0 -> 4859 bytes ml/dlib/examples/video_frames/frame_000151.jpg | Bin 0 -> 4780 bytes ml/dlib/examples/video_frames/frame_000152.jpg | Bin 0 -> 4733 bytes ml/dlib/examples/video_frames/frame_000153.jpg | Bin 0 -> 4619 bytes ml/dlib/examples/video_frames/frame_000154.jpg | Bin 0 -> 4661 bytes ml/dlib/examples/video_frames/frame_000155.jpg | Bin 0 -> 4584 bytes ml/dlib/examples/video_frames/frame_000156.jpg | Bin 0 -> 4577 bytes ml/dlib/examples/video_frames/frame_000157.jpg | Bin 0 -> 4680 bytes ml/dlib/examples/video_frames/frame_000158.jpg | Bin 0 -> 4759 bytes ml/dlib/examples/video_frames/frame_000159.jpg | Bin 0 -> 4671 bytes ml/dlib/examples/video_frames/frame_000160.jpg | Bin 0 -> 4776 bytes ml/dlib/examples/video_frames/frame_000161.jpg | Bin 0 -> 4767 bytes ml/dlib/examples/video_frames/frame_000162.jpg | Bin 0 -> 4763 bytes ml/dlib/examples/video_frames/frame_000163.jpg | Bin 0 -> 4793 bytes ml/dlib/examples/video_frames/frame_000164.jpg | Bin 0 -> 4809 bytes ml/dlib/examples/video_frames/frame_000165.jpg | Bin 0 -> 4774 bytes ml/dlib/examples/video_frames/frame_000166.jpg | Bin 0 -> 4801 bytes ml/dlib/examples/video_frames/frame_000167.jpg | Bin 0 -> 4724 bytes ml/dlib/examples/video_frames/frame_000168.jpg | Bin 0 -> 4656 bytes ml/dlib/examples/video_frames/frame_000169.jpg | Bin 0 -> 4544 bytes ml/dlib/examples/video_frames/frame_000170.jpg | Bin 0 -> 4554 bytes ml/dlib/examples/video_frames/frame_000171.jpg | Bin 0 -> 4574 bytes ml/dlib/examples/video_frames/frame_000172.jpg | Bin 0 -> 4379 bytes ml/dlib/examples/video_frames/frame_000173.jpg | Bin 0 -> 4185 bytes ml/dlib/examples/video_frames/frame_000174.jpg | Bin 0 -> 4457 bytes ml/dlib/examples/video_frames/frame_000175.jpg | Bin 0 -> 4596 bytes ml/dlib/examples/video_frames/frame_000176.jpg | Bin 0 -> 4630 bytes ml/dlib/examples/video_frames/frame_000177.jpg | Bin 0 -> 4539 bytes ml/dlib/examples/video_frames/frame_000178.jpg | Bin 0 -> 4582 bytes ml/dlib/examples/video_frames/frame_000179.jpg | Bin 0 -> 4522 bytes ml/dlib/examples/video_frames/frame_000180.jpg | Bin 0 -> 4599 bytes ml/dlib/examples/video_frames/frame_000181.jpg | Bin 0 -> 4523 bytes ml/dlib/examples/video_frames/frame_000182.jpg | Bin 0 -> 4694 bytes ml/dlib/examples/video_frames/frame_000183.jpg | Bin 0 -> 4729 bytes ml/dlib/examples/video_frames/frame_000184.jpg | Bin 0 -> 4916 bytes ml/dlib/examples/video_frames/frame_000185.jpg | Bin 0 -> 4759 bytes ml/dlib/examples/video_frames/frame_000186.jpg | Bin 0 -> 4963 bytes ml/dlib/examples/video_frames/frame_000187.jpg | Bin 0 -> 5026 bytes ml/dlib/examples/video_frames/frame_000188.jpg | Bin 0 -> 5150 bytes ml/dlib/examples/video_frames/frame_000189.jpg | Bin 0 -> 5233 bytes ml/dlib/examples/video_frames/frame_000190.jpg | Bin 0 -> 4999 bytes ml/dlib/examples/video_frames/frame_000191.jpg | Bin 0 -> 5043 bytes ml/dlib/examples/video_frames/frame_000192.jpg | Bin 0 -> 4730 bytes ml/dlib/examples/video_frames/frame_000193.jpg | Bin 0 -> 4773 bytes ml/dlib/examples/video_frames/frame_000194.jpg | Bin 0 -> 4959 bytes ml/dlib/examples/video_frames/frame_000195.jpg | Bin 0 -> 4775 bytes ml/dlib/examples/video_frames/frame_000196.jpg | Bin 0 -> 5078 bytes ml/dlib/examples/video_frames/frame_000197.jpg | Bin 0 -> 5424 bytes ml/dlib/examples/video_frames/frame_000198.jpg | Bin 0 -> 5373 bytes ml/dlib/examples/video_frames/frame_000199.jpg | Bin 0 -> 5797 bytes ml/dlib/examples/video_frames/frame_000200.jpg | Bin 0 -> 6121 bytes ml/dlib/examples/video_frames/frame_000201.jpg | Bin 0 -> 6208 bytes ml/dlib/examples/video_frames/frame_000202.jpg | Bin 0 -> 6116 bytes ml/dlib/examples/video_frames/frame_000203.jpg | Bin 0 -> 6070 bytes ml/dlib/examples/video_frames/frame_000204.jpg | Bin 0 -> 6069 bytes ml/dlib/examples/video_frames/frame_000205.jpg | Bin 0 -> 5959 bytes ml/dlib/examples/video_frames/frame_000206.jpg | Bin 0 -> 5717 bytes ml/dlib/examples/video_frames/frame_000207.jpg | Bin 0 -> 5751 bytes ml/dlib/examples/video_frames/frame_000208.jpg | Bin 0 -> 5529 bytes ml/dlib/examples/video_frames/frame_000209.jpg | Bin 0 -> 5404 bytes ml/dlib/examples/video_frames/frame_000210.jpg | Bin 0 -> 5458 bytes ml/dlib/examples/video_frames/frame_000211.jpg | Bin 0 -> 5320 bytes ml/dlib/examples/video_frames/frame_000212.jpg | Bin 0 -> 5257 bytes ml/dlib/examples/video_frames/frame_000213.jpg | Bin 0 -> 5462 bytes ml/dlib/examples/video_frames/frame_000214.jpg | Bin 0 -> 5434 bytes ml/dlib/examples/video_frames/frame_000215.jpg | Bin 0 -> 5822 bytes ml/dlib/examples/video_frames/frame_000216.jpg | Bin 0 -> 6131 bytes ml/dlib/examples/video_frames/frame_000217.jpg | Bin 0 -> 6031 bytes ml/dlib/examples/video_frames/frame_000218.jpg | Bin 0 -> 6105 bytes ml/dlib/examples/video_frames/frame_000219.jpg | Bin 0 -> 6136 bytes ml/dlib/examples/video_frames/frame_000220.jpg | Bin 0 -> 5870 bytes ml/dlib/examples/video_frames/frame_000221.jpg | Bin 0 -> 5694 bytes ml/dlib/examples/video_frames/frame_000222.jpg | Bin 0 -> 5430 bytes ml/dlib/examples/video_frames/frame_000223.jpg | Bin 0 -> 5222 bytes ml/dlib/examples/video_frames/frame_000224.jpg | Bin 0 -> 4880 bytes ml/dlib/examples/video_frames/frame_000225.jpg | Bin 0 -> 5090 bytes ml/dlib/examples/video_frames/frame_000226.jpg | Bin 0 -> 4821 bytes ml/dlib/examples/video_frames/frame_000227.jpg | Bin 0 -> 4738 bytes ml/dlib/examples/video_frames/frame_000228.jpg | Bin 0 -> 4500 bytes ml/dlib/examples/video_frames/frame_000229.jpg | Bin 0 -> 4360 bytes ml/dlib/examples/video_frames/frame_000230.jpg | Bin 0 -> 4236 bytes ml/dlib/examples/video_frames/frame_000231.jpg | Bin 0 -> 4243 bytes ml/dlib/examples/video_frames/frame_000232.jpg | Bin 0 -> 4191 bytes ml/dlib/examples/video_frames/frame_000233.jpg | Bin 0 -> 4232 bytes ml/dlib/examples/video_frames/frame_000234.jpg | Bin 0 -> 4250 bytes ml/dlib/examples/video_frames/frame_000235.jpg | Bin 0 -> 4119 bytes ml/dlib/examples/video_frames/frame_000236.jpg | Bin 0 -> 4004 bytes ml/dlib/examples/video_frames/frame_000237.jpg | Bin 0 -> 4248 bytes ml/dlib/examples/video_frames/frame_000238.jpg | Bin 0 -> 4283 bytes ml/dlib/examples/video_frames/frame_000239.jpg | Bin 0 -> 4325 bytes ml/dlib/examples/video_frames/frame_000240.jpg | Bin 0 -> 4458 bytes ml/dlib/examples/video_frames/frame_000241.jpg | Bin 0 -> 4577 bytes ml/dlib/examples/video_frames/frame_000242.jpg | Bin 0 -> 4699 bytes ml/dlib/examples/video_frames/frame_000243.jpg | Bin 0 -> 4773 bytes ml/dlib/examples/video_frames/frame_000244.jpg | Bin 0 -> 4956 bytes ml/dlib/examples/video_frames/frame_000245.jpg | Bin 0 -> 5054 bytes ml/dlib/examples/video_frames/frame_000246.jpg | Bin 0 -> 5200 bytes ml/dlib/examples/video_frames/frame_000247.jpg | Bin 0 -> 5210 bytes ml/dlib/examples/video_frames/frame_000248.jpg | Bin 0 -> 5252 bytes ml/dlib/examples/video_frames/frame_000249.jpg | Bin 0 -> 5249 bytes ml/dlib/examples/video_frames/frame_000250.jpg | Bin 0 -> 5148 bytes ml/dlib/examples/video_frames/license.txt | 6 + ml/dlib/examples/video_tracking_ex.cpp | 72 + ml/dlib/examples/webcam_face_pose_ex.cpp | 100 + ml/dlib/examples/xml_parser_ex.cpp | 115 + .../LICENSE_FOR_EXAMPLE_PROGRAMS.txt | 20 + ml/dlib/python_examples/cnn_face_detector.py | 85 + ml/dlib/python_examples/correlation_tracker.py | 72 + ml/dlib/python_examples/face_alignment.py | 91 + ml/dlib/python_examples/face_clustering.py | 127 + ml/dlib/python_examples/face_detector.py | 84 + ml/dlib/python_examples/face_jitter.py | 97 + ml/dlib/python_examples/face_landmark_detection.py | 100 + ml/dlib/python_examples/face_recognition.py | 123 + .../find_candidate_object_locations.py | 54 + ml/dlib/python_examples/global_optimization.py | 47 + ml/dlib/python_examples/max_cost_assignment.py | 57 + ml/dlib/python_examples/requirements.txt | 3 + ml/dlib/python_examples/sequence_segmenter.py | 197 + ml/dlib/python_examples/svm_binary_classifier.py | 68 + ml/dlib/python_examples/svm_rank.py | 155 + ml/dlib/python_examples/svm_struct.py | 343 + ml/dlib/python_examples/train_object_detector.py | 183 + ml/dlib/python_examples/train_shape_predictor.py | 135 + ml/dlib/setup.py | 251 + ml/dlib/tools/archive/train_face_5point_model.cpp | 159 + .../convert_dlib_nets_to_caffe/CMakeLists.txt | 25 + ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp | 792 + .../running_a_dlib_model_with_caffe_example.py | 77 + ml/dlib/tools/htmlify/CMakeLists.txt | 31 + ml/dlib/tools/htmlify/htmlify.cpp | 632 + ml/dlib/tools/htmlify/to_xml.cpp | 1599 ++ ml/dlib/tools/htmlify/to_xml.h | 22 + ml/dlib/tools/htmlify/to_xml_example/bigminus.gif | Bin 0 -> 91 bytes ml/dlib/tools/htmlify/to_xml_example/bigplus.gif | Bin 0 -> 99 bytes ml/dlib/tools/htmlify/to_xml_example/example.xml | 8 + ml/dlib/tools/htmlify/to_xml_example/minus.gif | Bin 0 -> 56 bytes ml/dlib/tools/htmlify/to_xml_example/output.xml | 49 + ml/dlib/tools/htmlify/to_xml_example/plus.gif | Bin 0 -> 59 bytes .../tools/htmlify/to_xml_example/stylesheet.xsl | 354 + ml/dlib/tools/htmlify/to_xml_example/test.cpp | 78 + ml/dlib/tools/imglab/CMakeLists.txt | 41 + ml/dlib/tools/imglab/README.txt | 40 + .../tools/imglab/convert_imglab_paths_to_relative | 24 + ml/dlib/tools/imglab/copy_imglab_dataset | 22 + ml/dlib/tools/imglab/src/cluster.cpp | 260 + ml/dlib/tools/imglab/src/cluster.h | 11 + ml/dlib/tools/imglab/src/common.cpp | 60 + ml/dlib/tools/imglab/src/common.h | 45 + ml/dlib/tools/imglab/src/convert_idl.cpp | 184 + ml/dlib/tools/imglab/src/convert_idl.h | 14 + ml/dlib/tools/imglab/src/convert_pascal_v1.cpp | 177 + ml/dlib/tools/imglab/src/convert_pascal_v1.h | 13 + ml/dlib/tools/imglab/src/convert_pascal_xml.cpp | 239 + ml/dlib/tools/imglab/src/convert_pascal_xml.h | 12 + ml/dlib/tools/imglab/src/flip_dataset.cpp | 249 + ml/dlib/tools/imglab/src/flip_dataset.h | 12 + ml/dlib/tools/imglab/src/main.cpp | 1145 + ml/dlib/tools/imglab/src/metadata_editor.cpp | 671 + ml/dlib/tools/imglab/src/metadata_editor.h | 116 + ml/dlib/tools/python/CMakeLists.txt | 106 + ml/dlib/tools/python/src/basic.cpp | 272 + ml/dlib/tools/python/src/cca.cpp | 137 + ml/dlib/tools/python/src/cnn_face_detector.cpp | 183 + ml/dlib/tools/python/src/conversion.h | 52 + ml/dlib/tools/python/src/correlation_tracker.cpp | 167 + ml/dlib/tools/python/src/decision_functions.cpp | 263 + ml/dlib/tools/python/src/dlib.cpp | 110 + ml/dlib/tools/python/src/face_recognition.cpp | 245 + ml/dlib/tools/python/src/global_optimization.cpp | 442 + ml/dlib/tools/python/src/gui.cpp | 128 + ml/dlib/tools/python/src/image.cpp | 40 + .../tools/python/src/image_dataset_metadata.cpp | 279 + ml/dlib/tools/python/src/indexing.h | 11 + ml/dlib/tools/python/src/matrix.cpp | 209 + ml/dlib/tools/python/src/numpy_returns.cpp | 158 + ml/dlib/tools/python/src/numpy_returns_stub.cpp | 59 + ml/dlib/tools/python/src/object_detection.cpp | 376 + ml/dlib/tools/python/src/opaque_types.h | 55 + ml/dlib/tools/python/src/other.cpp | 268 + ml/dlib/tools/python/src/rectangles.cpp | 268 + ml/dlib/tools/python/src/sequence_segmenter.cpp | 827 + .../tools/python/src/serialize_object_detector.h | 49 + ml/dlib/tools/python/src/shape_predictor.cpp | 319 + ml/dlib/tools/python/src/shape_predictor.h | 259 + ml/dlib/tools/python/src/simple_object_detector.h | 318 + .../tools/python/src/simple_object_detector_py.h | 290 + ml/dlib/tools/python/src/svm_c_trainer.cpp | 311 + ml/dlib/tools/python/src/svm_rank_trainer.cpp | 161 + ml/dlib/tools/python/src/svm_struct.cpp | 151 + ml/dlib/tools/python/src/testing_results.h | 50 + ml/dlib/tools/python/src/vector.cpp | 182 + ml/dlib/tools/python/test/.gitignore | 1 + ml/dlib/tools/python/test/test_array.py | 107 + .../tools/python/test/test_global_optimization.py | 69 + ml/dlib/tools/python/test/test_matrix.py | 100 + ml/dlib/tools/python/test/test_point.py | 48 + ml/dlib/tools/python/test/test_range.py | 97 + ml/dlib/tools/python/test/test_rgb_pixel.py | 26 + ml/dlib/tools/python/test/test_sparse_vector.py | 101 + ml/dlib/tools/python/test/test_svm_c_trainer.py | 65 + ml/dlib/tools/python/test/test_vector.py | 170 + ml/dlib/tools/visual_studio_natvis/README.txt | 12 + ml/dlib/tools/visual_studio_natvis/dlib.natvis | 51 + ml/ml-dummy.c | 2 +- ml/ml.cc | 17 +- ml/ml.h | 2 +- 2104 files changed, 641380 insertions(+), 9 deletions(-) create mode 100644 ml/dlib/.gitignore create mode 100644 ml/dlib/.hgignore create mode 100644 ml/dlib/.hgtags create mode 100644 ml/dlib/.travis.yml create mode 100644 ml/dlib/CMakeLists.txt create mode 100644 ml/dlib/ISSUE_TEMPLATE.md create mode 100644 ml/dlib/MANIFEST.in create mode 100644 ml/dlib/README.md create mode 100644 ml/dlib/dlib/CMakeLists.txt create mode 100644 ml/dlib/dlib/LICENSE.txt create mode 100644 ml/dlib/dlib/algs.h create mode 100644 ml/dlib/dlib/all/source.cpp create mode 100644 ml/dlib/dlib/any.h create mode 100644 ml/dlib/dlib/any/any.h create mode 100644 ml/dlib/dlib/any/any_abstract.h create mode 100644 ml/dlib/dlib/any/any_decision_function.h create mode 100644 ml/dlib/dlib/any/any_decision_function_abstract.h create mode 100644 ml/dlib/dlib/any/any_function.h create mode 100644 ml/dlib/dlib/any/any_function_abstract.h create mode 100644 ml/dlib/dlib/any/any_function_impl.h create mode 100644 ml/dlib/dlib/any/any_function_impl2.h create mode 100644 ml/dlib/dlib/any/any_trainer.h create mode 100644 ml/dlib/dlib/any/any_trainer_abstract.h create mode 100644 ml/dlib/dlib/appveyor/dtest.yml create mode 100644 ml/dlib/dlib/appveyor/dtest_vc2017.yml create mode 100644 ml/dlib/dlib/appveyor/examples.yml create mode 100644 ml/dlib/dlib/appveyor/python.yml create mode 100644 ml/dlib/dlib/array.h create mode 100644 ml/dlib/dlib/array/array_kernel.h create mode 100644 ml/dlib/dlib/array/array_kernel_abstract.h create mode 100644 ml/dlib/dlib/array/array_tools.h create mode 100644 ml/dlib/dlib/array/array_tools_abstract.h create mode 100644 ml/dlib/dlib/array2d.h create mode 100644 ml/dlib/dlib/array2d/array2d_generic_image.h create mode 100644 ml/dlib/dlib/array2d/array2d_kernel.h create mode 100644 ml/dlib/dlib/array2d/array2d_kernel_abstract.h create mode 100644 ml/dlib/dlib/array2d/serialize_pixel_overloads.h create mode 100644 ml/dlib/dlib/assert.h create mode 100644 ml/dlib/dlib/base64.h create mode 100644 ml/dlib/dlib/base64/base64_kernel_1.cpp create mode 100644 ml/dlib/dlib/base64/base64_kernel_1.h create mode 100644 ml/dlib/dlib/base64/base64_kernel_abstract.h create mode 100644 ml/dlib/dlib/bayes_utils.h create mode 100644 ml/dlib/dlib/bayes_utils/bayes_utils.h create mode 100644 ml/dlib/dlib/bayes_utils/bayes_utils_abstract.h create mode 100644 ml/dlib/dlib/bigint.h create mode 100644 ml/dlib/dlib/bigint/bigint_kernel_1.cpp create mode 100644 ml/dlib/dlib/bigint/bigint_kernel_1.h create mode 100644 ml/dlib/dlib/bigint/bigint_kernel_2.cpp create mode 100644 ml/dlib/dlib/bigint/bigint_kernel_2.h create mode 100644 ml/dlib/dlib/bigint/bigint_kernel_abstract.h create mode 100644 ml/dlib/dlib/bigint/bigint_kernel_c.h create mode 100644 ml/dlib/dlib/binary_search_tree.h create mode 100644 ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_1.h create mode 100644 ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_2.h create mode 100644 ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h create mode 100644 ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_c.h create mode 100644 ml/dlib/dlib/bit_stream.h create mode 100644 ml/dlib/dlib/bit_stream/bit_stream_kernel_1.cpp create mode 100644 ml/dlib/dlib/bit_stream/bit_stream_kernel_1.h create mode 100644 ml/dlib/dlib/bit_stream/bit_stream_kernel_abstract.h create mode 100644 ml/dlib/dlib/bit_stream/bit_stream_kernel_c.h create mode 100644 ml/dlib/dlib/bit_stream/bit_stream_multi_1.h create mode 100644 ml/dlib/dlib/bit_stream/bit_stream_multi_abstract.h create mode 100644 ml/dlib/dlib/bit_stream/bit_stream_multi_c.h create mode 100644 ml/dlib/dlib/bits/c++config.h create mode 100644 ml/dlib/dlib/bound_function_pointer.h create mode 100644 ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h create mode 100644 ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h create mode 100644 ml/dlib/dlib/bridge.h create mode 100644 ml/dlib/dlib/bridge/bridge.h create mode 100644 ml/dlib/dlib/bridge/bridge_abstract.h create mode 100644 ml/dlib/dlib/bsp.h create mode 100644 ml/dlib/dlib/bsp/bsp.cpp create mode 100644 ml/dlib/dlib/bsp/bsp.h create mode 100644 ml/dlib/dlib/bsp/bsp_abstract.h create mode 100644 ml/dlib/dlib/byte_orderer.h create mode 100644 ml/dlib/dlib/byte_orderer/byte_orderer_kernel_1.h create mode 100644 ml/dlib/dlib/byte_orderer/byte_orderer_kernel_abstract.h create mode 100644 ml/dlib/dlib/cassert create mode 100644 ml/dlib/dlib/clustering.h create mode 100644 ml/dlib/dlib/clustering/bottom_up_cluster.h create mode 100644 ml/dlib/dlib/clustering/bottom_up_cluster_abstract.h create mode 100644 ml/dlib/dlib/clustering/chinese_whispers.h create mode 100644 ml/dlib/dlib/clustering/chinese_whispers_abstract.h create mode 100644 ml/dlib/dlib/clustering/modularity_clustering.h create mode 100644 ml/dlib/dlib/clustering/modularity_clustering_abstract.h create mode 100644 ml/dlib/dlib/clustering/spectral_cluster.h create mode 100644 ml/dlib/dlib/clustering/spectral_cluster_abstract.h create mode 100644 ml/dlib/dlib/cmake create mode 100644 ml/dlib/dlib/cmake_utils/add_global_compiler_switch.cmake create mode 100644 ml/dlib/dlib/cmake_utils/check_if_neon_available.cmake create mode 100644 ml/dlib/dlib/cmake_utils/dlib.pc.in create mode 100644 ml/dlib/dlib/cmake_utils/dlibConfig.cmake.in create mode 100644 ml/dlib/dlib/cmake_utils/find_blas.cmake create mode 100644 ml/dlib/dlib/cmake_utils/release_build_by_default create mode 100644 ml/dlib/dlib/cmake_utils/set_compiler_specific_options.cmake create mode 100644 ml/dlib/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake create mode 100644 ml/dlib/dlib/cmake_utils/test_for_cpp11/CMakeLists.txt create mode 100644 ml/dlib/dlib/cmake_utils/test_for_cpp11/cpp11_test.cpp create mode 100644 ml/dlib/dlib/cmake_utils/test_for_cuda/CMakeLists.txt create mode 100644 ml/dlib/dlib/cmake_utils/test_for_cuda/cuda_test.cu create mode 100644 ml/dlib/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt create mode 100644 ml/dlib/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt create mode 100644 ml/dlib/dlib/cmake_utils/test_for_neon/CMakeLists.txt create mode 100644 ml/dlib/dlib/cmake_utils/test_for_neon/neon_test.cpp create mode 100644 ml/dlib/dlib/cmake_utils/use_cpp_11.cmake create mode 100644 ml/dlib/dlib/cmd_line_parser.h create mode 100644 ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_1.h create mode 100644 ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_c.h create mode 100644 ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h create mode 100644 ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h create mode 100644 ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h create mode 100644 ml/dlib/dlib/cmd_line_parser/cmd_line_parser_print_1.h create mode 100644 ml/dlib/dlib/cmd_line_parser/get_option.h create mode 100644 ml/dlib/dlib/cmd_line_parser/get_option_abstract.h create mode 100644 ml/dlib/dlib/compress_stream.h create mode 100644 ml/dlib/dlib/compress_stream/compress_stream_kernel_1.h create mode 100644 ml/dlib/dlib/compress_stream/compress_stream_kernel_2.h create mode 100644 ml/dlib/dlib/compress_stream/compress_stream_kernel_3.h create mode 100644 ml/dlib/dlib/compress_stream/compress_stream_kernel_abstract.h create mode 100644 ml/dlib/dlib/conditioning_class.h create mode 100644 ml/dlib/dlib/conditioning_class/conditioning_class_kernel_1.h create mode 100644 ml/dlib/dlib/conditioning_class/conditioning_class_kernel_2.h create mode 100644 ml/dlib/dlib/conditioning_class/conditioning_class_kernel_3.h create mode 100644 ml/dlib/dlib/conditioning_class/conditioning_class_kernel_4.h create mode 100644 ml/dlib/dlib/conditioning_class/conditioning_class_kernel_abstract.h create mode 100644 ml/dlib/dlib/conditioning_class/conditioning_class_kernel_c.h create mode 100644 ml/dlib/dlib/config.h create mode 100644 ml/dlib/dlib/config.h.in create mode 100644 ml/dlib/dlib/config_reader.h create mode 100644 ml/dlib/dlib/config_reader/config_reader_kernel_1.h create mode 100644 ml/dlib/dlib/config_reader/config_reader_kernel_abstract.h create mode 100644 ml/dlib/dlib/config_reader/config_reader_thread_safe_1.h create mode 100644 ml/dlib/dlib/config_reader/config_reader_thread_safe_abstract.h create mode 100644 ml/dlib/dlib/console_progress_indicator.h create mode 100644 ml/dlib/dlib/control.h create mode 100644 ml/dlib/dlib/control/approximate_linear_models.h create mode 100644 ml/dlib/dlib/control/approximate_linear_models_abstract.h create mode 100644 ml/dlib/dlib/control/lspi.h create mode 100644 ml/dlib/dlib/control/lspi_abstract.h create mode 100644 ml/dlib/dlib/control/mpc.h create mode 100644 ml/dlib/dlib/control/mpc_abstract.h create mode 100644 ml/dlib/dlib/cpp_pretty_printer.h create mode 100644 ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h create mode 100644 ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h create mode 100644 ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h create mode 100644 ml/dlib/dlib/cpp_tokenizer.h create mode 100644 ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h create mode 100644 ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h create mode 100644 ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h create mode 100644 ml/dlib/dlib/crc32.h create mode 100644 ml/dlib/dlib/crc32/crc32_kernel_1.h create mode 100644 ml/dlib/dlib/crc32/crc32_kernel_abstract.h create mode 100644 ml/dlib/dlib/cstring create mode 100644 ml/dlib/dlib/data_io.h create mode 100644 ml/dlib/dlib/data_io/image_dataset_metadata.cpp create mode 100644 ml/dlib/dlib/data_io/image_dataset_metadata.h create mode 100644 ml/dlib/dlib/data_io/libsvm_io.h create mode 100644 ml/dlib/dlib/data_io/libsvm_io_abstract.h create mode 100644 ml/dlib/dlib/data_io/load_image_dataset.h create mode 100644 ml/dlib/dlib/data_io/load_image_dataset_abstract.h create mode 100644 ml/dlib/dlib/data_io/mnist.cpp create mode 100644 ml/dlib/dlib/data_io/mnist.h create mode 100644 ml/dlib/dlib/data_io/mnist_abstract.h create mode 100644 ml/dlib/dlib/dir_nav.h create mode 100644 ml/dlib/dlib/dir_nav/dir_nav_extensions.cpp create mode 100644 ml/dlib/dlib/dir_nav/dir_nav_extensions.h create mode 100644 ml/dlib/dlib/dir_nav/dir_nav_extensions_abstract.h create mode 100644 ml/dlib/dlib/dir_nav/dir_nav_kernel_1.cpp create mode 100644 ml/dlib/dlib/dir_nav/dir_nav_kernel_1.h create mode 100644 ml/dlib/dlib/dir_nav/dir_nav_kernel_2.cpp create mode 100644 ml/dlib/dlib/dir_nav/dir_nav_kernel_2.h create mode 100644 ml/dlib/dlib/dir_nav/dir_nav_kernel_abstract.h create mode 100644 ml/dlib/dlib/dir_nav/posix.h create mode 100644 ml/dlib/dlib/dir_nav/windows.h create mode 100644 ml/dlib/dlib/directed_graph.h create mode 100644 ml/dlib/dlib/directed_graph/directed_graph_kernel_1.h create mode 100644 ml/dlib/dlib/directed_graph/directed_graph_kernel_abstract.h create mode 100644 ml/dlib/dlib/disjoint_subsets.h create mode 100644 ml/dlib/dlib/disjoint_subsets/disjoint_subsets.h create mode 100644 ml/dlib/dlib/disjoint_subsets/disjoint_subsets_abstract.h create mode 100644 ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized.h create mode 100644 ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h create mode 100644 ml/dlib/dlib/dlib_basic_cpp_build_tutorial.txt create mode 100644 ml/dlib/dlib/dlib_include_path_tutorial.txt create mode 100644 ml/dlib/dlib/dnn.h create mode 100644 ml/dlib/dlib/dnn/core.h create mode 100644 ml/dlib/dlib/dnn/core_abstract.h create mode 100644 ml/dlib/dlib/dnn/cpu_dlib.cpp create mode 100644 ml/dlib/dlib/dnn/cpu_dlib.h create mode 100644 ml/dlib/dlib/dnn/cublas_dlibapi.cpp create mode 100644 ml/dlib/dlib/dnn/cublas_dlibapi.h create mode 100644 ml/dlib/dlib/dnn/cuda_data_ptr.cpp create mode 100644 ml/dlib/dlib/dnn/cuda_data_ptr.h create mode 100644 ml/dlib/dlib/dnn/cuda_dlib.cu create mode 100644 ml/dlib/dlib/dnn/cuda_dlib.h create mode 100644 ml/dlib/dlib/dnn/cuda_errors.h create mode 100644 ml/dlib/dlib/dnn/cuda_utils.h create mode 100644 ml/dlib/dlib/dnn/cudnn_dlibapi.cpp create mode 100644 ml/dlib/dlib/dnn/cudnn_dlibapi.h create mode 100644 ml/dlib/dlib/dnn/curand_dlibapi.cpp create mode 100644 ml/dlib/dlib/dnn/curand_dlibapi.h create mode 100644 ml/dlib/dlib/dnn/cusolver_dlibapi.cu create mode 100644 ml/dlib/dlib/dnn/cusolver_dlibapi.h create mode 100644 ml/dlib/dlib/dnn/gpu_data.cpp create mode 100644 ml/dlib/dlib/dnn/gpu_data.h create mode 100644 ml/dlib/dlib/dnn/gpu_data_abstract.h create mode 100644 ml/dlib/dlib/dnn/input.h create mode 100644 ml/dlib/dlib/dnn/input_abstract.h create mode 100644 ml/dlib/dlib/dnn/layers.h create mode 100644 ml/dlib/dlib/dnn/layers_abstract.h create mode 100644 ml/dlib/dlib/dnn/loss.h create mode 100644 ml/dlib/dlib/dnn/loss_abstract.h create mode 100644 ml/dlib/dlib/dnn/solvers.h create mode 100644 ml/dlib/dlib/dnn/solvers_abstract.h create mode 100644 ml/dlib/dlib/dnn/tensor.h create mode 100644 ml/dlib/dlib/dnn/tensor_abstract.h create mode 100644 ml/dlib/dlib/dnn/tensor_tools.cpp create mode 100644 ml/dlib/dlib/dnn/tensor_tools.h create mode 100644 ml/dlib/dlib/dnn/trainer.h create mode 100644 ml/dlib/dlib/dnn/trainer_abstract.h create mode 100644 ml/dlib/dlib/dnn/utilities.h create mode 100644 ml/dlib/dlib/dnn/utilities_abstract.h create mode 100644 ml/dlib/dlib/dnn/validation.h create mode 100644 ml/dlib/dlib/enable_if.h create mode 100644 ml/dlib/dlib/entropy_decoder.h create mode 100644 ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.cpp create mode 100644 ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.h create mode 100644 ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.cpp create mode 100644 ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.h create mode 100644 ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_abstract.h create mode 100644 ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_c.h create mode 100644 ml/dlib/dlib/entropy_decoder_model.h create mode 100644 ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_1.h create mode 100644 ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_2.h create mode 100644 ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_3.h create mode 100644 ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_4.h create mode 100644 ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_5.h create mode 100644 ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_6.h create mode 100644 ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_abstract.h create mode 100644 ml/dlib/dlib/entropy_encoder.h create mode 100644 ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.cpp create mode 100644 ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.h create mode 100644 ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.cpp create mode 100644 ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.h create mode 100644 ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_abstract.h create mode 100644 ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_c.h create mode 100644 ml/dlib/dlib/entropy_encoder_model.h create mode 100644 ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_1.h create mode 100644 ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_2.h create mode 100644 ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_3.h create mode 100644 ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_4.h create mode 100644 ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_5.h create mode 100644 ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_6.h create mode 100644 ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_abstract.h create mode 100644 ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_c.h create mode 100644 ml/dlib/dlib/error.h create mode 100644 ml/dlib/dlib/external/cblas/CMakeLists.txt create mode 100644 ml/dlib/dlib/external/cblas/README create mode 100644 ml/dlib/dlib/external/cblas/cblas.h create mode 100644 ml/dlib/dlib/external/cblas/cblas_caxpy.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ccopy.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cdotc_sub.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cdotu_sub.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cgbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cgemm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cgemv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cgerc.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cgeru.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_chbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_chemm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_chemv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cher.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cher2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cher2k.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cherk.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_chpmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_chpr.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_chpr2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cscal.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_csscal.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_cswap.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_csymm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_csyr2k.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_csyrk.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ctbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ctbsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ctpmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ctpsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ctrmm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ctrmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ctrsm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ctrsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dasum.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_daxpy.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dcopy.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ddot.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dgbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dgemm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dgemv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dger.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dnrm2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_drot.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_drotg.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_drotm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_drotmg.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dsbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dscal.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dsdot.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dspmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dspr.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dspr2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dswap.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dsymm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dsymv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dsyr.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dsyr2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dsyr2k.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dsyrk.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dtbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dtbsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dtpmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dtpsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dtrmm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dtrmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dtrsm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dtrsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dzasum.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_dznrm2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_f77.h create mode 100644 ml/dlib/dlib/external/cblas/cblas_icamax.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_idamax.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_isamax.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_izamax.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sasum.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_saxpy.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_scasum.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_scnrm2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_scopy.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sdot.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sdsdot.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sgbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sgemm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sgemv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sger.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_snrm2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_srot.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_srotg.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_srotm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_srotmg.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ssbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sscal.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sspmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sspr.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sspr2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_sswap.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ssymm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ssymv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ssyr.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ssyr2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ssyr2k.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ssyrk.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_stbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_stbsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_stpmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_stpsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_strmm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_strmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_strsm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_strsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_xerbla.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zaxpy.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zcopy.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zdotc_sub.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zdotu_sub.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zdscal.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zgbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zgemm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zgemv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zgerc.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zgeru.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zhbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zhemm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zhemv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zher.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zher2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zher2k.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zherk.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zhpmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zhpr.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zhpr2.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zscal.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zswap.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zsymm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zsyr2k.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_zsyrk.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ztbmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ztbsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ztpmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ztpsv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ztrmm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ztrmv.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ztrsm.c create mode 100644 ml/dlib/dlib/external/cblas/cblas_ztrsv.c create mode 100644 ml/dlib/dlib/external/cblas/cdotcsub.f create mode 100644 ml/dlib/dlib/external/cblas/cdotusub.f create mode 100644 ml/dlib/dlib/external/cblas/dasumsub.f create mode 100644 ml/dlib/dlib/external/cblas/ddotsub.f create mode 100644 ml/dlib/dlib/external/cblas/dnrm2sub.f create mode 100644 ml/dlib/dlib/external/cblas/dsdotsub.f create mode 100644 ml/dlib/dlib/external/cblas/dzasumsub.f create mode 100644 ml/dlib/dlib/external/cblas/dznrm2sub.f create mode 100644 ml/dlib/dlib/external/cblas/icamaxsub.f create mode 100644 ml/dlib/dlib/external/cblas/idamaxsub.f create mode 100644 ml/dlib/dlib/external/cblas/isamaxsub.f create mode 100644 ml/dlib/dlib/external/cblas/izamaxsub.f create mode 100644 ml/dlib/dlib/external/cblas/sasumsub.f create mode 100644 ml/dlib/dlib/external/cblas/scasumsub.f create mode 100644 ml/dlib/dlib/external/cblas/scnrm2sub.f create mode 100644 ml/dlib/dlib/external/cblas/sdotsub.f create mode 100644 ml/dlib/dlib/external/cblas/sdsdotsub.f create mode 100644 ml/dlib/dlib/external/cblas/snrm2sub.f create mode 100644 ml/dlib/dlib/external/cblas/zdotcsub.f create mode 100644 ml/dlib/dlib/external/cblas/zdotusub.f create mode 100644 ml/dlib/dlib/external/libjpeg/README create mode 100644 ml/dlib/dlib/external/libjpeg/jcapimin.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jcapistd.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jccoefct.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jccolor.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jcdctmgr.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jchuff.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jchuff.h create mode 100644 ml/dlib/dlib/external/libjpeg/jcinit.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jcmainct.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jcmarker.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jcmaster.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jcomapi.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jconfig.h create mode 100644 ml/dlib/dlib/external/libjpeg/jcparam.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jcphuff.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jcprepct.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jcsample.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdapimin.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdapistd.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdatadst.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdatasrc.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdcoefct.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdcolor.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdct.h create mode 100644 ml/dlib/dlib/external/libjpeg/jddctmgr.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdhuff.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdhuff.h create mode 100644 ml/dlib/dlib/external/libjpeg/jdinput.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdmainct.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdmarker.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdmaster.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdmerge.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdphuff.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdpostct.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jdsample.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jerror.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jerror.h create mode 100644 ml/dlib/dlib/external/libjpeg/jfdctflt.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jfdctfst.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jfdctint.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jidctflt.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jidctfst.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jidctint.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jidctred.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jinclude.h create mode 100644 ml/dlib/dlib/external/libjpeg/jmemmgr.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jmemnobs.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jmemsys.h create mode 100644 ml/dlib/dlib/external/libjpeg/jmorecfg.h create mode 100644 ml/dlib/dlib/external/libjpeg/jpegint.h create mode 100644 ml/dlib/dlib/external/libjpeg/jpeglib.h create mode 100644 ml/dlib/dlib/external/libjpeg/jquant1.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jquant2.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jutils.cpp create mode 100644 ml/dlib/dlib/external/libjpeg/jversion.h create mode 100644 ml/dlib/dlib/external/libpng/LICENSE create mode 100644 ml/dlib/dlib/external/libpng/README create mode 100644 ml/dlib/dlib/external/libpng/arm/arm_init.c create mode 100644 ml/dlib/dlib/external/libpng/arm/filter_neon.S create mode 100644 ml/dlib/dlib/external/libpng/arm/filter_neon_intrinsics.c create mode 100644 ml/dlib/dlib/external/libpng/png.c create mode 100644 ml/dlib/dlib/external/libpng/png.h create mode 100644 ml/dlib/dlib/external/libpng/pngconf.h create mode 100644 ml/dlib/dlib/external/libpng/pngdebug.h create mode 100644 ml/dlib/dlib/external/libpng/pngerror.c create mode 100644 ml/dlib/dlib/external/libpng/pngget.c create mode 100644 ml/dlib/dlib/external/libpng/pnginfo.h create mode 100644 ml/dlib/dlib/external/libpng/pnglibconf.h create mode 100644 ml/dlib/dlib/external/libpng/pngmem.c create mode 100644 ml/dlib/dlib/external/libpng/pngpread.c create mode 100644 ml/dlib/dlib/external/libpng/pngpriv.h create mode 100644 ml/dlib/dlib/external/libpng/pngread.c create mode 100644 ml/dlib/dlib/external/libpng/pngrio.c create mode 100644 ml/dlib/dlib/external/libpng/pngrtran.c create mode 100644 ml/dlib/dlib/external/libpng/pngrutil.c create mode 100644 ml/dlib/dlib/external/libpng/pngset.c create mode 100644 ml/dlib/dlib/external/libpng/pngstruct.h create mode 100644 ml/dlib/dlib/external/libpng/pngtrans.c create mode 100644 ml/dlib/dlib/external/libpng/pngwio.c create mode 100644 ml/dlib/dlib/external/libpng/pngwrite.c create mode 100644 ml/dlib/dlib/external/libpng/pngwtran.c create mode 100644 ml/dlib/dlib/external/libpng/pngwutil.c create mode 100644 ml/dlib/dlib/external/pybind11/CMakeLists.txt create mode 100644 ml/dlib/dlib/external/pybind11/CONTRIBUTING.md create mode 100644 ml/dlib/dlib/external/pybind11/LICENSE create mode 100644 ml/dlib/dlib/external/pybind11/README.md create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/attr.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/buffer_info.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/cast.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/chrono.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/common.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/complex.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/detail/class.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/detail/common.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/detail/descr.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/detail/init.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/detail/internals.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/detail/typeid.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/eigen.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/embed.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/eval.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/functional.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/iostream.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/numpy.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/operators.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/options.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/pybind11.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/pytypes.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/stl.h create mode 100644 ml/dlib/dlib/external/pybind11/include/pybind11/stl_bind.h create mode 100644 ml/dlib/dlib/external/pybind11/tools/FindCatch.cmake create mode 100644 ml/dlib/dlib/external/pybind11/tools/FindEigen3.cmake create mode 100644 ml/dlib/dlib/external/pybind11/tools/FindPythonLibsNew.cmake create mode 100755 ml/dlib/dlib/external/pybind11/tools/check-style.sh create mode 100644 ml/dlib/dlib/external/pybind11/tools/libsize.py create mode 100644 ml/dlib/dlib/external/pybind11/tools/mkdoc.py create mode 100644 ml/dlib/dlib/external/pybind11/tools/pybind11Config.cmake.in create mode 100644 ml/dlib/dlib/external/pybind11/tools/pybind11Tools.cmake create mode 100644 ml/dlib/dlib/external/zlib/README create mode 100644 ml/dlib/dlib/external/zlib/adler32.c create mode 100644 ml/dlib/dlib/external/zlib/compress.c create mode 100644 ml/dlib/dlib/external/zlib/crc32.c create mode 100644 ml/dlib/dlib/external/zlib/crc32.h create mode 100644 ml/dlib/dlib/external/zlib/deflate.c create mode 100644 ml/dlib/dlib/external/zlib/deflate.h create mode 100644 ml/dlib/dlib/external/zlib/gzclose.c create mode 100644 ml/dlib/dlib/external/zlib/gzguts.h create mode 100644 ml/dlib/dlib/external/zlib/gzlib.c create mode 100644 ml/dlib/dlib/external/zlib/gzread.c create mode 100644 ml/dlib/dlib/external/zlib/gzwrite.c create mode 100644 ml/dlib/dlib/external/zlib/infback.c create mode 100644 ml/dlib/dlib/external/zlib/inffast.c create mode 100644 ml/dlib/dlib/external/zlib/inffast.h create mode 100644 ml/dlib/dlib/external/zlib/inffixed.h create mode 100644 ml/dlib/dlib/external/zlib/inflate.c create mode 100644 ml/dlib/dlib/external/zlib/inflate.h create mode 100644 ml/dlib/dlib/external/zlib/inftrees.c create mode 100644 ml/dlib/dlib/external/zlib/inftrees.h create mode 100644 ml/dlib/dlib/external/zlib/trees.c create mode 100644 ml/dlib/dlib/external/zlib/trees.h create mode 100644 ml/dlib/dlib/external/zlib/uncompr.c create mode 100644 ml/dlib/dlib/external/zlib/zconf.h create mode 100644 ml/dlib/dlib/external/zlib/zlib.h create mode 100644 ml/dlib/dlib/external/zlib/zutil.c create mode 100644 ml/dlib/dlib/external/zlib/zutil.h create mode 100644 ml/dlib/dlib/filtering.h create mode 100644 ml/dlib/dlib/filtering/kalman_filter.cpp create mode 100644 ml/dlib/dlib/filtering/kalman_filter.h create mode 100644 ml/dlib/dlib/filtering/kalman_filter_abstract.h create mode 100644 ml/dlib/dlib/filtering/rls_filter.h create mode 100644 ml/dlib/dlib/filtering/rls_filter_abstract.h create mode 100644 ml/dlib/dlib/float_details.h create mode 100644 ml/dlib/dlib/fstream create mode 100644 ml/dlib/dlib/general_hash/count_bits.h create mode 100644 ml/dlib/dlib/general_hash/count_bits_abstract.h create mode 100644 ml/dlib/dlib/general_hash/general_hash.h create mode 100644 ml/dlib/dlib/general_hash/hash.h create mode 100644 ml/dlib/dlib/general_hash/hash_abstract.h create mode 100644 ml/dlib/dlib/general_hash/murmur_hash3.h create mode 100644 ml/dlib/dlib/general_hash/murmur_hash3_abstract.h create mode 100644 ml/dlib/dlib/general_hash/random_hashing.h create mode 100644 ml/dlib/dlib/general_hash/random_hashing_abstract.h create mode 100644 ml/dlib/dlib/geometry.h create mode 100644 ml/dlib/dlib/geometry/border_enumerator.h create mode 100644 ml/dlib/dlib/geometry/border_enumerator_abstract.h create mode 100644 ml/dlib/dlib/geometry/drectangle.h create mode 100644 ml/dlib/dlib/geometry/drectangle_abstract.h create mode 100644 ml/dlib/dlib/geometry/point_transforms.h create mode 100644 ml/dlib/dlib/geometry/point_transforms_abstract.h create mode 100644 ml/dlib/dlib/geometry/rectangle.h create mode 100644 ml/dlib/dlib/geometry/rectangle_abstract.h create mode 100644 ml/dlib/dlib/geometry/vector.h create mode 100644 ml/dlib/dlib/geometry/vector_abstract.h create mode 100644 ml/dlib/dlib/global_optimization.h create mode 100644 ml/dlib/dlib/global_optimization/find_max_global.h create mode 100644 ml/dlib/dlib/global_optimization/find_max_global_abstract.h create mode 100644 ml/dlib/dlib/global_optimization/global_function_search.cpp create mode 100644 ml/dlib/dlib/global_optimization/global_function_search.h create mode 100644 ml/dlib/dlib/global_optimization/global_function_search_abstract.h create mode 100644 ml/dlib/dlib/global_optimization/upper_bound_function.h create mode 100644 ml/dlib/dlib/global_optimization/upper_bound_function_abstract.h create mode 100644 ml/dlib/dlib/graph.h create mode 100644 ml/dlib/dlib/graph/graph_kernel_1.h create mode 100644 ml/dlib/dlib/graph/graph_kernel_abstract.h create mode 100644 ml/dlib/dlib/graph_cuts.h create mode 100644 ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h create mode 100644 ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts_abstract.h create mode 100644 ml/dlib/dlib/graph_cuts/general_flow_graph.h create mode 100644 ml/dlib/dlib/graph_cuts/general_potts_problem.h create mode 100644 ml/dlib/dlib/graph_cuts/graph_labeler.h create mode 100644 ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h create mode 100644 ml/dlib/dlib/graph_cuts/min_cut.h create mode 100644 ml/dlib/dlib/graph_cuts/min_cut_abstract.h create mode 100644 ml/dlib/dlib/graph_utils.h create mode 100644 ml/dlib/dlib/graph_utils/edge_list_graphs.h create mode 100644 ml/dlib/dlib/graph_utils/edge_list_graphs_abstract.h create mode 100644 ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh.h create mode 100644 ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h create mode 100644 ml/dlib/dlib/graph_utils/function_objects.h create mode 100644 ml/dlib/dlib/graph_utils/function_objects_abstract.h create mode 100644 ml/dlib/dlib/graph_utils/graph_utils.h create mode 100644 ml/dlib/dlib/graph_utils/graph_utils_abstract.h create mode 100644 ml/dlib/dlib/graph_utils/ordered_sample_pair.h create mode 100644 ml/dlib/dlib/graph_utils/ordered_sample_pair_abstract.h create mode 100644 ml/dlib/dlib/graph_utils/sample_pair.h create mode 100644 ml/dlib/dlib/graph_utils/sample_pair_abstract.h create mode 100644 ml/dlib/dlib/graph_utils_threaded.h create mode 100644 ml/dlib/dlib/gui_core.h create mode 100644 ml/dlib/dlib/gui_core/gui_core_kernel_1.cpp create mode 100644 ml/dlib/dlib/gui_core/gui_core_kernel_1.h create mode 100644 ml/dlib/dlib/gui_core/gui_core_kernel_2.cpp create mode 100644 ml/dlib/dlib/gui_core/gui_core_kernel_2.h create mode 100644 ml/dlib/dlib/gui_core/gui_core_kernel_abstract.h create mode 100644 ml/dlib/dlib/gui_core/windows.h create mode 100644 ml/dlib/dlib/gui_core/xlib.h create mode 100644 ml/dlib/dlib/gui_widgets.h create mode 100644 ml/dlib/dlib/gui_widgets/base_widgets.cpp create mode 100644 ml/dlib/dlib/gui_widgets/base_widgets.h create mode 100644 ml/dlib/dlib/gui_widgets/base_widgets_abstract.h create mode 100644 ml/dlib/dlib/gui_widgets/canvas_drawing.cpp create mode 100644 ml/dlib/dlib/gui_widgets/canvas_drawing.h create mode 100644 ml/dlib/dlib/gui_widgets/canvas_drawing_abstract.h create mode 100644 ml/dlib/dlib/gui_widgets/drawable.cpp create mode 100644 ml/dlib/dlib/gui_widgets/drawable.h create mode 100644 ml/dlib/dlib/gui_widgets/drawable_abstract.h create mode 100644 ml/dlib/dlib/gui_widgets/fonts.cpp create mode 100644 ml/dlib/dlib/gui_widgets/fonts.h create mode 100644 ml/dlib/dlib/gui_widgets/fonts_abstract.h create mode 100644 ml/dlib/dlib/gui_widgets/nativefont.h create mode 100644 ml/dlib/dlib/gui_widgets/style.cpp create mode 100644 ml/dlib/dlib/gui_widgets/style.h create mode 100644 ml/dlib/dlib/gui_widgets/style_abstract.h create mode 100644 ml/dlib/dlib/gui_widgets/widgets.cpp create mode 100644 ml/dlib/dlib/gui_widgets/widgets.h create mode 100644 ml/dlib/dlib/gui_widgets/widgets_abstract.h create mode 100644 ml/dlib/dlib/hash.h create mode 100644 ml/dlib/dlib/hash_map.h create mode 100644 ml/dlib/dlib/hash_map/hash_map_kernel_1.h create mode 100644 ml/dlib/dlib/hash_map/hash_map_kernel_abstract.h create mode 100644 ml/dlib/dlib/hash_map/hash_map_kernel_c.h create mode 100644 ml/dlib/dlib/hash_set.h create mode 100644 ml/dlib/dlib/hash_set/hash_set_kernel_1.h create mode 100644 ml/dlib/dlib/hash_set/hash_set_kernel_abstract.h create mode 100644 ml/dlib/dlib/hash_set/hash_set_kernel_c.h create mode 100644 ml/dlib/dlib/hash_table.h create mode 100644 ml/dlib/dlib/hash_table/hash_table_kernel_1.h create mode 100644 ml/dlib/dlib/hash_table/hash_table_kernel_2.h create mode 100644 ml/dlib/dlib/hash_table/hash_table_kernel_abstract.h create mode 100644 ml/dlib/dlib/hash_table/hash_table_kernel_c.h create mode 100644 ml/dlib/dlib/http_client/http_client.cpp create mode 100644 ml/dlib/dlib/http_client/http_client.h create mode 100644 ml/dlib/dlib/http_client/http_client_abstract.h create mode 100644 ml/dlib/dlib/image_io.h create mode 100644 ml/dlib/dlib/image_keypoint.h create mode 100644 ml/dlib/dlib/image_keypoint/binned_vector_feature_image.h create mode 100644 ml/dlib/dlib/image_keypoint/binned_vector_feature_image_abstract.h create mode 100644 ml/dlib/dlib/image_keypoint/build_separable_poly_filters.h create mode 100644 ml/dlib/dlib/image_keypoint/draw_surf_points.h create mode 100644 ml/dlib/dlib/image_keypoint/draw_surf_points_abstract.h create mode 100644 ml/dlib/dlib/image_keypoint/fine_hog_image.h create mode 100644 ml/dlib/dlib/image_keypoint/fine_hog_image_abstract.h create mode 100644 ml/dlib/dlib/image_keypoint/hashed_feature_image.h create mode 100644 ml/dlib/dlib/image_keypoint/hashed_feature_image_abstract.h create mode 100644 ml/dlib/dlib/image_keypoint/hessian_pyramid.h create mode 100644 ml/dlib/dlib/image_keypoint/hessian_pyramid_abstract.h create mode 100644 ml/dlib/dlib/image_keypoint/hog.h create mode 100644 ml/dlib/dlib/image_keypoint/hog_abstract.h create mode 100644 ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image.h create mode 100644 ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h create mode 100644 ml/dlib/dlib/image_keypoint/poly_image.h create mode 100644 ml/dlib/dlib/image_keypoint/poly_image_abstract.h create mode 100644 ml/dlib/dlib/image_keypoint/surf.h create mode 100644 ml/dlib/dlib/image_keypoint/surf_abstract.h create mode 100644 ml/dlib/dlib/image_loader/image_loader.h create mode 100644 ml/dlib/dlib/image_loader/image_loader_abstract.h create mode 100644 ml/dlib/dlib/image_loader/jpeg_loader.cpp create mode 100644 ml/dlib/dlib/image_loader/jpeg_loader.h create mode 100644 ml/dlib/dlib/image_loader/jpeg_loader_abstract.h create mode 100644 ml/dlib/dlib/image_loader/load_image.h create mode 100644 ml/dlib/dlib/image_loader/load_image_abstract.h create mode 100644 ml/dlib/dlib/image_loader/png_loader.cpp create mode 100644 ml/dlib/dlib/image_loader/png_loader.h create mode 100644 ml/dlib/dlib/image_loader/png_loader_abstract.h create mode 100644 ml/dlib/dlib/image_processing.h create mode 100644 ml/dlib/dlib/image_processing/box_overlap_testing.h create mode 100644 ml/dlib/dlib/image_processing/box_overlap_testing_abstract.h create mode 100644 ml/dlib/dlib/image_processing/correlation_tracker.h create mode 100644 ml/dlib/dlib/image_processing/correlation_tracker_abstract.h create mode 100644 ml/dlib/dlib/image_processing/detection_template_tools.h create mode 100644 ml/dlib/dlib/image_processing/detection_template_tools_abstract.h create mode 100644 ml/dlib/dlib/image_processing/frontal_face_detector.h create mode 100644 ml/dlib/dlib/image_processing/frontal_face_detector_abstract.h create mode 100644 ml/dlib/dlib/image_processing/full_object_detection.h create mode 100644 ml/dlib/dlib/image_processing/full_object_detection_abstract.h create mode 100644 ml/dlib/dlib/image_processing/generic_image.h create mode 100644 ml/dlib/dlib/image_processing/object_detector.h create mode 100644 ml/dlib/dlib/image_processing/object_detector_abstract.h create mode 100644 ml/dlib/dlib/image_processing/remove_unobtainable_rectangles.h create mode 100644 ml/dlib/dlib/image_processing/remove_unobtainable_rectangles_abstract.h create mode 100644 ml/dlib/dlib/image_processing/render_face_detections.h create mode 100644 ml/dlib/dlib/image_processing/render_face_detections_abstract.h create mode 100644 ml/dlib/dlib/image_processing/scan_fhog_pyramid.h create mode 100644 ml/dlib/dlib/image_processing/scan_fhog_pyramid_abstract.h create mode 100644 ml/dlib/dlib/image_processing/scan_image.h create mode 100644 ml/dlib/dlib/image_processing/scan_image_abstract.h create mode 100644 ml/dlib/dlib/image_processing/scan_image_boxes.h create mode 100644 ml/dlib/dlib/image_processing/scan_image_boxes_abstract.h create mode 100644 ml/dlib/dlib/image_processing/scan_image_custom.h create mode 100644 ml/dlib/dlib/image_processing/scan_image_custom_abstract.h create mode 100644 ml/dlib/dlib/image_processing/scan_image_pyramid.h create mode 100644 ml/dlib/dlib/image_processing/scan_image_pyramid_abstract.h create mode 100644 ml/dlib/dlib/image_processing/scan_image_pyramid_tools.h create mode 100644 ml/dlib/dlib/image_processing/scan_image_pyramid_tools_abstract.h create mode 100644 ml/dlib/dlib/image_processing/setup_hashed_features.h create mode 100644 ml/dlib/dlib/image_processing/setup_hashed_features_abstract.h create mode 100644 ml/dlib/dlib/image_processing/shape_predictor.h create mode 100644 ml/dlib/dlib/image_processing/shape_predictor_abstract.h create mode 100644 ml/dlib/dlib/image_processing/shape_predictor_trainer.h create mode 100644 ml/dlib/dlib/image_processing/shape_predictor_trainer_abstract.h create mode 100644 ml/dlib/dlib/image_saver/dng_shared.h create mode 100644 ml/dlib/dlib/image_saver/image_saver.h create mode 100644 ml/dlib/dlib/image_saver/image_saver_abstract.h create mode 100644 ml/dlib/dlib/image_saver/save_jpeg.cpp create mode 100644 ml/dlib/dlib/image_saver/save_jpeg.h create mode 100644 ml/dlib/dlib/image_saver/save_jpeg_abstract.h create mode 100644 ml/dlib/dlib/image_saver/save_png.cpp create mode 100644 ml/dlib/dlib/image_saver/save_png.h create mode 100644 ml/dlib/dlib/image_saver/save_png_abstract.h create mode 100644 ml/dlib/dlib/image_transforms.h create mode 100644 ml/dlib/dlib/image_transforms/assign_image.h create mode 100644 ml/dlib/dlib/image_transforms/assign_image_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/colormaps.h create mode 100644 ml/dlib/dlib/image_transforms/colormaps_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/draw.h create mode 100644 ml/dlib/dlib/image_transforms/draw_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/edge_detector.h create mode 100644 ml/dlib/dlib/image_transforms/edge_detector_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/equalize_histogram.h create mode 100644 ml/dlib/dlib/image_transforms/equalize_histogram_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/fhog.h create mode 100644 ml/dlib/dlib/image_transforms/fhog_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/hough_transform.h create mode 100644 ml/dlib/dlib/image_transforms/hough_transform_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/image_pyramid.h create mode 100644 ml/dlib/dlib/image_transforms/image_pyramid_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/integral_image.h create mode 100644 ml/dlib/dlib/image_transforms/integral_image_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/interpolation.h create mode 100644 ml/dlib/dlib/image_transforms/interpolation_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/label_connected_blobs.h create mode 100644 ml/dlib/dlib/image_transforms/label_connected_blobs_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/lbp.h create mode 100644 ml/dlib/dlib/image_transforms/lbp_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/morphological_operations.h create mode 100644 ml/dlib/dlib/image_transforms/morphological_operations_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/random_color_transform.h create mode 100644 ml/dlib/dlib/image_transforms/random_color_transform_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/random_cropper.h create mode 100644 ml/dlib/dlib/image_transforms/random_cropper_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/segment_image.h create mode 100644 ml/dlib/dlib/image_transforms/segment_image_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/spatial_filtering.h create mode 100644 ml/dlib/dlib/image_transforms/spatial_filtering_abstract.h create mode 100644 ml/dlib/dlib/image_transforms/thresholding.h create mode 100644 ml/dlib/dlib/image_transforms/thresholding_abstract.h create mode 100644 ml/dlib/dlib/interfaces/cmd_line_parser_option.h create mode 100644 ml/dlib/dlib/interfaces/enumerable.h create mode 100644 ml/dlib/dlib/interfaces/map_pair.h create mode 100644 ml/dlib/dlib/interfaces/remover.h create mode 100644 ml/dlib/dlib/iomanip create mode 100644 ml/dlib/dlib/iosfwd create mode 100644 ml/dlib/dlib/iosockstream.h create mode 100644 ml/dlib/dlib/iosockstream/iosockstream.h create mode 100644 ml/dlib/dlib/iosockstream/iosockstream_abstract.h create mode 100644 ml/dlib/dlib/iostream create mode 100644 ml/dlib/dlib/is_kind.h create mode 100644 ml/dlib/dlib/istream create mode 100644 ml/dlib/dlib/java/CMakeLists.txt create mode 100644 ml/dlib/dlib/java/cmake_swig_jni create mode 100644 ml/dlib/dlib/java/java_array.h create mode 100755 ml/dlib/dlib/java/run_test.sh create mode 100644 ml/dlib/dlib/java/swig_api.h create mode 100644 ml/dlib/dlib/java/swig_test.java create mode 100644 ml/dlib/dlib/linker.h create mode 100644 ml/dlib/dlib/linker/linker_kernel_1.cpp create mode 100644 ml/dlib/dlib/linker/linker_kernel_1.h create mode 100644 ml/dlib/dlib/linker/linker_kernel_abstract.h create mode 100644 ml/dlib/dlib/locale create mode 100644 ml/dlib/dlib/logger.h create mode 100644 ml/dlib/dlib/logger/extra_logger_headers.cpp create mode 100644 ml/dlib/dlib/logger/extra_logger_headers.h create mode 100644 ml/dlib/dlib/logger/logger_config_file.cpp create mode 100644 ml/dlib/dlib/logger/logger_config_file.h create mode 100644 ml/dlib/dlib/logger/logger_kernel_1.cpp create mode 100644 ml/dlib/dlib/logger/logger_kernel_1.h create mode 100644 ml/dlib/dlib/logger/logger_kernel_abstract.h create mode 100644 ml/dlib/dlib/lsh.h create mode 100644 ml/dlib/dlib/lsh/create_random_projection_hash.h create mode 100644 ml/dlib/dlib/lsh/create_random_projection_hash_abstract.h create mode 100644 ml/dlib/dlib/lsh/hashes.h create mode 100644 ml/dlib/dlib/lsh/hashes_abstract.h create mode 100644 ml/dlib/dlib/lsh/projection_hash.h create mode 100644 ml/dlib/dlib/lsh/projection_hash_abstract.h create mode 100644 ml/dlib/dlib/lz77_buffer.h create mode 100644 ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_1.h create mode 100644 ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_2.h create mode 100644 ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_abstract.h create mode 100644 ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_c.h create mode 100644 ml/dlib/dlib/lzp_buffer.h create mode 100644 ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_1.h create mode 100644 ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_2.h create mode 100644 ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_abstract.h create mode 100644 ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_c.h create mode 100644 ml/dlib/dlib/manifold_regularization.h create mode 100644 ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer.h create mode 100644 ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer_abstract.h create mode 100644 ml/dlib/dlib/map.h create mode 100644 ml/dlib/dlib/map/map_kernel_1.h create mode 100644 ml/dlib/dlib/map/map_kernel_abstract.h create mode 100644 ml/dlib/dlib/map/map_kernel_c.h create mode 100644 ml/dlib/dlib/matlab/CMakeLists.txt create mode 100644 ml/dlib/dlib/matlab/README.txt create mode 100644 ml/dlib/dlib/matlab/call_matlab.h create mode 100644 ml/dlib/dlib/matlab/cmake_mex_wrapper create mode 100644 ml/dlib/dlib/matlab/example.m create mode 100644 ml/dlib/dlib/matlab/example_mex_callback.cpp create mode 100644 ml/dlib/dlib/matlab/example_mex_class.cpp create mode 100644 ml/dlib/dlib/matlab/example_mex_function.cpp create mode 100644 ml/dlib/dlib/matlab/example_mex_struct.cpp create mode 100644 ml/dlib/dlib/matlab/mex_wrapper.cpp create mode 100644 ml/dlib/dlib/matlab/subprocess_stream.cpp create mode 100644 ml/dlib/dlib/matlab/subprocess_stream.h create mode 100644 ml/dlib/dlib/matrix.h create mode 100644 ml/dlib/dlib/matrix/cblas_constants.h create mode 100644 ml/dlib/dlib/matrix/lapack/fortran_id.h create mode 100644 ml/dlib/dlib/matrix/lapack/gees.h create mode 100644 ml/dlib/dlib/matrix/lapack/geev.h create mode 100644 ml/dlib/dlib/matrix/lapack/geqrf.h create mode 100644 ml/dlib/dlib/matrix/lapack/gesdd.h create mode 100644 ml/dlib/dlib/matrix/lapack/gesvd.h create mode 100644 ml/dlib/dlib/matrix/lapack/getrf.h create mode 100644 ml/dlib/dlib/matrix/lapack/ormqr.h create mode 100644 ml/dlib/dlib/matrix/lapack/pbtrf.h create mode 100644 ml/dlib/dlib/matrix/lapack/potrf.h create mode 100644 ml/dlib/dlib/matrix/lapack/syev.h create mode 100644 ml/dlib/dlib/matrix/lapack/syevr.h create mode 100644 ml/dlib/dlib/matrix/matrix.h create mode 100644 ml/dlib/dlib/matrix/matrix_abstract.h create mode 100644 ml/dlib/dlib/matrix/matrix_assign.h create mode 100644 ml/dlib/dlib/matrix/matrix_assign_fwd.h create mode 100644 ml/dlib/dlib/matrix/matrix_blas_bindings.h create mode 100644 ml/dlib/dlib/matrix/matrix_cholesky.h create mode 100644 ml/dlib/dlib/matrix/matrix_conj_trans.h create mode 100644 ml/dlib/dlib/matrix/matrix_conv.h create mode 100644 ml/dlib/dlib/matrix/matrix_conv_abstract.h create mode 100644 ml/dlib/dlib/matrix/matrix_data_layout.h create mode 100644 ml/dlib/dlib/matrix/matrix_data_layout_abstract.h create mode 100644 ml/dlib/dlib/matrix/matrix_default_mul.h create mode 100644 ml/dlib/dlib/matrix/matrix_eigenvalue.h create mode 100644 ml/dlib/dlib/matrix/matrix_exp.h create mode 100644 ml/dlib/dlib/matrix/matrix_exp_abstract.h create mode 100644 ml/dlib/dlib/matrix/matrix_expressions.h create mode 100644 ml/dlib/dlib/matrix/matrix_fft.h create mode 100644 ml/dlib/dlib/matrix/matrix_fft_abstract.h create mode 100644 ml/dlib/dlib/matrix/matrix_fwd.h create mode 100644 ml/dlib/dlib/matrix/matrix_generic_image.h create mode 100644 ml/dlib/dlib/matrix/matrix_la.h create mode 100644 ml/dlib/dlib/matrix/matrix_la_abstract.h create mode 100644 ml/dlib/dlib/matrix/matrix_lu.h create mode 100644 ml/dlib/dlib/matrix/matrix_mat.h create mode 100644 ml/dlib/dlib/matrix/matrix_mat_abstract.h create mode 100644 ml/dlib/dlib/matrix/matrix_math_functions.h create mode 100644 ml/dlib/dlib/matrix/matrix_math_functions_abstract.h create mode 100644 ml/dlib/dlib/matrix/matrix_op.h create mode 100644 ml/dlib/dlib/matrix/matrix_qr.h create mode 100644 ml/dlib/dlib/matrix/matrix_read_from_istream.h create mode 100644 ml/dlib/dlib/matrix/matrix_subexp.h create mode 100644 ml/dlib/dlib/matrix/matrix_subexp_abstract.h create mode 100644 ml/dlib/dlib/matrix/matrix_trsm.h create mode 100644 ml/dlib/dlib/matrix/matrix_utilities.h create mode 100644 ml/dlib/dlib/matrix/matrix_utilities_abstract.h create mode 100644 ml/dlib/dlib/matrix/symmetric_matrix_cache.h create mode 100644 ml/dlib/dlib/matrix/symmetric_matrix_cache_abstract.h create mode 100644 ml/dlib/dlib/md5.h create mode 100644 ml/dlib/dlib/md5/md5_kernel_1.cpp create mode 100644 ml/dlib/dlib/md5/md5_kernel_1.h create mode 100644 ml/dlib/dlib/md5/md5_kernel_abstract.h create mode 100644 ml/dlib/dlib/member_function_pointer.h create mode 100644 ml/dlib/dlib/member_function_pointer/make_mfp.h create mode 100644 ml/dlib/dlib/member_function_pointer/make_mfp_abstract.h create mode 100644 ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_1.h create mode 100644 ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_abstract.h create mode 100644 ml/dlib/dlib/memory_manager.h create mode 100644 ml/dlib/dlib/memory_manager/memory_manager_kernel_1.h create mode 100644 ml/dlib/dlib/memory_manager/memory_manager_kernel_2.h create mode 100644 ml/dlib/dlib/memory_manager/memory_manager_kernel_3.h create mode 100644 ml/dlib/dlib/memory_manager/memory_manager_kernel_abstract.h create mode 100644 ml/dlib/dlib/memory_manager_global.h create mode 100644 ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_1.h create mode 100644 ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_abstract.h create mode 100644 ml/dlib/dlib/memory_manager_stateless.h create mode 100644 ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_1.h create mode 100644 ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_2.h create mode 100644 ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h create mode 100644 ml/dlib/dlib/metaprogramming.h create mode 100644 ml/dlib/dlib/misc_api.h create mode 100644 ml/dlib/dlib/misc_api/misc_api_kernel_1.cpp create mode 100644 ml/dlib/dlib/misc_api/misc_api_kernel_1.h create mode 100644 ml/dlib/dlib/misc_api/misc_api_kernel_2.cpp create mode 100644 ml/dlib/dlib/misc_api/misc_api_kernel_2.h create mode 100644 ml/dlib/dlib/misc_api/misc_api_kernel_abstract.h create mode 100644 ml/dlib/dlib/misc_api/misc_api_shared.h create mode 100644 ml/dlib/dlib/misc_api/posix.h create mode 100644 ml/dlib/dlib/misc_api/windows.h create mode 100644 ml/dlib/dlib/mlp.h create mode 100644 ml/dlib/dlib/mlp/mlp_kernel_1.h create mode 100644 ml/dlib/dlib/mlp/mlp_kernel_abstract.h create mode 100644 ml/dlib/dlib/mlp/mlp_kernel_c.h create mode 100644 ml/dlib/dlib/noncopyable.h create mode 100644 ml/dlib/dlib/numeric_constants.h create mode 100644 ml/dlib/dlib/numerical_integration.h create mode 100644 ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson.h create mode 100644 ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson_abstract.h create mode 100644 ml/dlib/dlib/opencv.h create mode 100644 ml/dlib/dlib/opencv/cv_image.h create mode 100644 ml/dlib/dlib/opencv/cv_image_abstract.h create mode 100644 ml/dlib/dlib/opencv/to_open_cv.h create mode 100644 ml/dlib/dlib/opencv/to_open_cv_abstract.h create mode 100644 ml/dlib/dlib/optimization.h create mode 100644 ml/dlib/dlib/optimization/elastic_net.h create mode 100644 ml/dlib/dlib/optimization/elastic_net_abstract.h create mode 100644 ml/dlib/dlib/optimization/find_max_factor_graph_nmplp.h create mode 100644 ml/dlib/dlib/optimization/find_max_factor_graph_nmplp_abstract.h create mode 100644 ml/dlib/dlib/optimization/find_max_factor_graph_viterbi.h create mode 100644 ml/dlib/dlib/optimization/find_max_factor_graph_viterbi_abstract.h create mode 100644 ml/dlib/dlib/optimization/find_max_parse_cky.h create mode 100644 ml/dlib/dlib/optimization/find_max_parse_cky_abstract.h create mode 100644 ml/dlib/dlib/optimization/find_optimal_parameters.h create mode 100644 ml/dlib/dlib/optimization/find_optimal_parameters_abstract.h create mode 100644 ml/dlib/dlib/optimization/isotonic_regression.h create mode 100644 ml/dlib/dlib/optimization/isotonic_regression_abstract.h create mode 100644 ml/dlib/dlib/optimization/max_cost_assignment.h create mode 100644 ml/dlib/dlib/optimization/max_cost_assignment_abstract.h create mode 100644 ml/dlib/dlib/optimization/max_sum_submatrix.h create mode 100644 ml/dlib/dlib/optimization/max_sum_submatrix_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization.h create mode 100644 ml/dlib/dlib/optimization/optimization_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_bobyqa.h create mode 100644 ml/dlib/dlib/optimization/optimization_bobyqa_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_least_squares.h create mode 100644 ml/dlib/dlib/optimization/optimization_least_squares_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_line_search.h create mode 100644 ml/dlib/dlib/optimization/optimization_line_search_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_oca.h create mode 100644 ml/dlib/dlib/optimization/optimization_oca_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_search_strategies.h create mode 100644 ml/dlib/dlib/optimization/optimization_search_strategies_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo.h create mode 100644 ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo.h create mode 100644 ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_solve_qp_using_smo.h create mode 100644 ml/dlib/dlib/optimization/optimization_solve_qp_using_smo_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_stop_strategies.h create mode 100644 ml/dlib/dlib/optimization/optimization_stop_strategies_abstract.h create mode 100644 ml/dlib/dlib/optimization/optimization_trust_region.h create mode 100644 ml/dlib/dlib/optimization/optimization_trust_region_abstract.h create mode 100644 ml/dlib/dlib/ostream create mode 100644 ml/dlib/dlib/pipe.h create mode 100644 ml/dlib/dlib/pipe/pipe_kernel_1.h create mode 100644 ml/dlib/dlib/pipe/pipe_kernel_abstract.h create mode 100644 ml/dlib/dlib/pixel.h create mode 100644 ml/dlib/dlib/platform.h create mode 100644 ml/dlib/dlib/python.h create mode 100644 ml/dlib/dlib/python/numpy.h create mode 100644 ml/dlib/dlib/python/numpy_image.h create mode 100644 ml/dlib/dlib/python/pyassert.h create mode 100644 ml/dlib/dlib/python/pybind_utils.h create mode 100644 ml/dlib/dlib/python/serialize_pickle.h create mode 100644 ml/dlib/dlib/quantum_computing.h create mode 100644 ml/dlib/dlib/quantum_computing/quantum_computing.h create mode 100644 ml/dlib/dlib/quantum_computing/quantum_computing_abstract.h create mode 100644 ml/dlib/dlib/queue.h create mode 100644 ml/dlib/dlib/queue/queue_kernel_1.h create mode 100644 ml/dlib/dlib/queue/queue_kernel_2.h create mode 100644 ml/dlib/dlib/queue/queue_kernel_abstract.h create mode 100644 ml/dlib/dlib/queue/queue_kernel_c.h create mode 100644 ml/dlib/dlib/queue/queue_sort_1.h create mode 100644 ml/dlib/dlib/queue/queue_sort_abstract.h create mode 100644 ml/dlib/dlib/rand.h create mode 100644 ml/dlib/dlib/rand/mersenne_twister.h create mode 100644 ml/dlib/dlib/rand/rand_kernel_1.h create mode 100644 ml/dlib/dlib/rand/rand_kernel_abstract.h create mode 100644 ml/dlib/dlib/random_forest.h create mode 100644 ml/dlib/dlib/random_forest/random_forest_regression.h create mode 100644 ml/dlib/dlib/random_forest/random_forest_regression_abstract.h create mode 100644 ml/dlib/dlib/ref.h create mode 100644 ml/dlib/dlib/reference_counter.h create mode 100644 ml/dlib/dlib/reference_counter/reference_counter_kernel_1.h create mode 100644 ml/dlib/dlib/reference_counter/reference_counter_kernel_abstract.h create mode 100644 ml/dlib/dlib/revision.h.in create mode 100644 ml/dlib/dlib/sequence.h create mode 100644 ml/dlib/dlib/sequence/sequence_compare_1.h create mode 100644 ml/dlib/dlib/sequence/sequence_compare_abstract.h create mode 100644 ml/dlib/dlib/sequence/sequence_kernel_1.h create mode 100644 ml/dlib/dlib/sequence/sequence_kernel_2.h create mode 100644 ml/dlib/dlib/sequence/sequence_kernel_abstract.h create mode 100644 ml/dlib/dlib/sequence/sequence_kernel_c.h create mode 100644 ml/dlib/dlib/sequence/sequence_sort_1.h create mode 100644 ml/dlib/dlib/sequence/sequence_sort_2.h create mode 100644 ml/dlib/dlib/sequence/sequence_sort_abstract.h create mode 100644 ml/dlib/dlib/serialize.h create mode 100644 ml/dlib/dlib/server.h create mode 100644 ml/dlib/dlib/server/server_http.cpp create mode 100644 ml/dlib/dlib/server/server_http.h create mode 100644 ml/dlib/dlib/server/server_http_abstract.h create mode 100644 ml/dlib/dlib/server/server_iostream.cpp create mode 100644 ml/dlib/dlib/server/server_iostream.h create mode 100644 ml/dlib/dlib/server/server_iostream_abstract.h create mode 100644 ml/dlib/dlib/server/server_kernel.cpp create mode 100644 ml/dlib/dlib/server/server_kernel.h create mode 100644 ml/dlib/dlib/server/server_kernel_abstract.h create mode 100644 ml/dlib/dlib/set.h create mode 100644 ml/dlib/dlib/set/set_compare_1.h create mode 100644 ml/dlib/dlib/set/set_compare_abstract.h create mode 100644 ml/dlib/dlib/set/set_kernel_1.h create mode 100644 ml/dlib/dlib/set/set_kernel_abstract.h create mode 100644 ml/dlib/dlib/set/set_kernel_c.h create mode 100644 ml/dlib/dlib/set_utils.h create mode 100644 ml/dlib/dlib/set_utils/set_utils.h create mode 100644 ml/dlib/dlib/set_utils/set_utils_abstract.h create mode 100644 ml/dlib/dlib/simd.h create mode 100644 ml/dlib/dlib/simd/simd4f.h create mode 100644 ml/dlib/dlib/simd/simd4i.h create mode 100644 ml/dlib/dlib/simd/simd8f.h create mode 100644 ml/dlib/dlib/simd/simd8i.h create mode 100644 ml/dlib/dlib/simd/simd_check.h create mode 100644 ml/dlib/dlib/sliding_buffer.h create mode 100644 ml/dlib/dlib/sliding_buffer/circular_buffer.h create mode 100644 ml/dlib/dlib/sliding_buffer/circular_buffer_abstract.h create mode 100644 ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_1.h create mode 100644 ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_abstract.h create mode 100644 ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_c.h create mode 100644 ml/dlib/dlib/smart_pointers.h create mode 100644 ml/dlib/dlib/smart_pointers/scoped_ptr.h create mode 100644 ml/dlib/dlib/smart_pointers/shared_ptr.h create mode 100644 ml/dlib/dlib/smart_pointers/shared_ptr_abstract.h create mode 100644 ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe.h create mode 100644 ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe_abstract.h create mode 100644 ml/dlib/dlib/smart_pointers/weak_ptr.h create mode 100644 ml/dlib/dlib/smart_pointers/weak_ptr_abstract.h create mode 100644 ml/dlib/dlib/smart_pointers_thread_safe.h create mode 100644 ml/dlib/dlib/sockets.h create mode 100644 ml/dlib/dlib/sockets/posix.h create mode 100644 ml/dlib/dlib/sockets/sockets_extensions.cpp create mode 100644 ml/dlib/dlib/sockets/sockets_extensions.h create mode 100644 ml/dlib/dlib/sockets/sockets_extensions_abstract.h create mode 100644 ml/dlib/dlib/sockets/sockets_kernel_1.cpp create mode 100644 ml/dlib/dlib/sockets/sockets_kernel_1.h create mode 100644 ml/dlib/dlib/sockets/sockets_kernel_2.cpp create mode 100644 ml/dlib/dlib/sockets/sockets_kernel_2.h create mode 100644 ml/dlib/dlib/sockets/sockets_kernel_abstract.h create mode 100644 ml/dlib/dlib/sockets/windows.h create mode 100644 ml/dlib/dlib/sockstreambuf.h create mode 100644 ml/dlib/dlib/sockstreambuf/sockstreambuf.cpp create mode 100644 ml/dlib/dlib/sockstreambuf/sockstreambuf.h create mode 100644 ml/dlib/dlib/sockstreambuf/sockstreambuf_abstract.h create mode 100644 ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.cpp create mode 100644 ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.h create mode 100644 ml/dlib/dlib/sort.h create mode 100644 ml/dlib/dlib/sparse_vector.h create mode 100644 ml/dlib/dlib/sqlite.h create mode 100644 ml/dlib/dlib/sqlite/sqlite.h create mode 100644 ml/dlib/dlib/sqlite/sqlite_abstract.h create mode 100644 ml/dlib/dlib/sqlite/sqlite_tools.h create mode 100644 ml/dlib/dlib/sqlite/sqlite_tools_abstract.h create mode 100644 ml/dlib/dlib/sstream create mode 100644 ml/dlib/dlib/stack.h create mode 100644 ml/dlib/dlib/stack/stack_kernel_1.h create mode 100644 ml/dlib/dlib/stack/stack_kernel_abstract.h create mode 100644 ml/dlib/dlib/stack/stack_kernel_c.h create mode 100644 ml/dlib/dlib/stack_trace.cpp create mode 100644 ml/dlib/dlib/stack_trace.h create mode 100644 ml/dlib/dlib/static_map.h create mode 100644 ml/dlib/dlib/static_map/static_map_kernel_1.h create mode 100644 ml/dlib/dlib/static_map/static_map_kernel_abstract.h create mode 100644 ml/dlib/dlib/static_map/static_map_kernel_c.h create mode 100644 ml/dlib/dlib/static_set.h create mode 100644 ml/dlib/dlib/static_set/static_set_compare_1.h create mode 100644 ml/dlib/dlib/static_set/static_set_compare_abstract.h create mode 100644 ml/dlib/dlib/static_set/static_set_kernel_1.h create mode 100644 ml/dlib/dlib/static_set/static_set_kernel_abstract.h create mode 100644 ml/dlib/dlib/static_set/static_set_kernel_c.h create mode 100644 ml/dlib/dlib/statistics.h create mode 100644 ml/dlib/dlib/statistics/average_precision.h create mode 100644 ml/dlib/dlib/statistics/average_precision_abstract.h create mode 100644 ml/dlib/dlib/statistics/cca.h create mode 100644 ml/dlib/dlib/statistics/cca_abstract.h create mode 100644 ml/dlib/dlib/statistics/dpca.h create mode 100644 ml/dlib/dlib/statistics/dpca_abstract.h create mode 100644 ml/dlib/dlib/statistics/image_feature_sampling.h create mode 100644 ml/dlib/dlib/statistics/image_feature_sampling_abstract.h create mode 100644 ml/dlib/dlib/statistics/lda.h create mode 100644 ml/dlib/dlib/statistics/lda_abstract.h create mode 100644 ml/dlib/dlib/statistics/random_subset_selector.h create mode 100644 ml/dlib/dlib/statistics/random_subset_selector_abstract.h create mode 100644 ml/dlib/dlib/statistics/running_gradient.h create mode 100644 ml/dlib/dlib/statistics/running_gradient_abstract.h create mode 100644 ml/dlib/dlib/statistics/sammon.h create mode 100644 ml/dlib/dlib/statistics/sammon_abstract.h create mode 100644 ml/dlib/dlib/statistics/statistics.h create mode 100644 ml/dlib/dlib/statistics/statistics_abstract.h create mode 100644 ml/dlib/dlib/statistics/vector_normalizer_frobmetric.h create mode 100644 ml/dlib/dlib/statistics/vector_normalizer_frobmetric_abstract.h create mode 100644 ml/dlib/dlib/std_allocator.h create mode 100644 ml/dlib/dlib/stl_checked.h create mode 100644 ml/dlib/dlib/stl_checked/std_vector_c.h create mode 100644 ml/dlib/dlib/stl_checked/std_vector_c_abstract.h create mode 100644 ml/dlib/dlib/string.h create mode 100644 ml/dlib/dlib/string/cassert create mode 100644 ml/dlib/dlib/string/iomanip create mode 100644 ml/dlib/dlib/string/iosfwd create mode 100644 ml/dlib/dlib/string/iostream create mode 100644 ml/dlib/dlib/string/locale create mode 100644 ml/dlib/dlib/string/string.h create mode 100644 ml/dlib/dlib/string/string_abstract.h create mode 100644 ml/dlib/dlib/svm.h create mode 100644 ml/dlib/dlib/svm/active_learning.h create mode 100644 ml/dlib/dlib/svm/active_learning_abstract.h create mode 100644 ml/dlib/dlib/svm/assignment_function.h create mode 100644 ml/dlib/dlib/svm/assignment_function_abstract.h create mode 100644 ml/dlib/dlib/svm/cross_validate_assignment_trainer.h create mode 100644 ml/dlib/dlib/svm/cross_validate_assignment_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer.h create mode 100644 ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h create mode 100644 ml/dlib/dlib/svm/cross_validate_multiclass_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/cross_validate_object_detection_trainer.h create mode 100644 ml/dlib/dlib/svm/cross_validate_object_detection_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/cross_validate_regression_trainer.h create mode 100644 ml/dlib/dlib/svm/cross_validate_regression_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/cross_validate_sequence_labeler.h create mode 100644 ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h create mode 100644 ml/dlib/dlib/svm/cross_validate_sequence_segmenter.h create mode 100644 ml/dlib/dlib/svm/cross_validate_sequence_segmenter_abstract.h create mode 100644 ml/dlib/dlib/svm/cross_validate_track_association_trainer.h create mode 100644 ml/dlib/dlib/svm/cross_validate_track_association_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/empirical_kernel_map.h create mode 100644 ml/dlib/dlib/svm/empirical_kernel_map_abstract.h create mode 100644 ml/dlib/dlib/svm/feature_ranking.h create mode 100644 ml/dlib/dlib/svm/feature_ranking_abstract.h create mode 100644 ml/dlib/dlib/svm/function.h create mode 100644 ml/dlib/dlib/svm/function_abstract.h create mode 100644 ml/dlib/dlib/svm/kcentroid.h create mode 100644 ml/dlib/dlib/svm/kcentroid_abstract.h create mode 100644 ml/dlib/dlib/svm/kcentroid_overloads.h create mode 100644 ml/dlib/dlib/svm/kernel.h create mode 100644 ml/dlib/dlib/svm/kernel_abstract.h create mode 100644 ml/dlib/dlib/svm/kernel_matrix.h create mode 100644 ml/dlib/dlib/svm/kernel_matrix_abstract.h create mode 100644 ml/dlib/dlib/svm/kkmeans.h create mode 100644 ml/dlib/dlib/svm/kkmeans_abstract.h create mode 100644 ml/dlib/dlib/svm/krls.h create mode 100644 ml/dlib/dlib/svm/krls_abstract.h create mode 100644 ml/dlib/dlib/svm/krr_trainer.h create mode 100644 ml/dlib/dlib/svm/krr_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/linearly_independent_subset_finder.h create mode 100644 ml/dlib/dlib/svm/linearly_independent_subset_finder_abstract.h create mode 100644 ml/dlib/dlib/svm/multiclass_tools.h create mode 100644 ml/dlib/dlib/svm/multiclass_tools_abstract.h create mode 100644 ml/dlib/dlib/svm/null_df.h create mode 100644 ml/dlib/dlib/svm/null_trainer.h create mode 100644 ml/dlib/dlib/svm/null_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/num_nonnegative_weights.h create mode 100644 ml/dlib/dlib/svm/one_vs_all_decision_function.h create mode 100644 ml/dlib/dlib/svm/one_vs_all_decision_function_abstract.h create mode 100644 ml/dlib/dlib/svm/one_vs_all_trainer.h create mode 100644 ml/dlib/dlib/svm/one_vs_all_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/one_vs_one_decision_function.h create mode 100644 ml/dlib/dlib/svm/one_vs_one_decision_function_abstract.h create mode 100644 ml/dlib/dlib/svm/one_vs_one_trainer.h create mode 100644 ml/dlib/dlib/svm/one_vs_one_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/pegasos.h create mode 100644 ml/dlib/dlib/svm/pegasos_abstract.h create mode 100644 ml/dlib/dlib/svm/ranking_tools.h create mode 100644 ml/dlib/dlib/svm/ranking_tools_abstract.h create mode 100644 ml/dlib/dlib/svm/rbf_network.h create mode 100644 ml/dlib/dlib/svm/rbf_network_abstract.h create mode 100644 ml/dlib/dlib/svm/reduced.h create mode 100644 ml/dlib/dlib/svm/reduced_abstract.h create mode 100644 ml/dlib/dlib/svm/rls.h create mode 100644 ml/dlib/dlib/svm/rls_abstract.h create mode 100644 ml/dlib/dlib/svm/roc_trainer.h create mode 100644 ml/dlib/dlib/svm/roc_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/rr_trainer.h create mode 100644 ml/dlib/dlib/svm/rr_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/rvm.h create mode 100644 ml/dlib/dlib/svm/rvm_abstract.h create mode 100644 ml/dlib/dlib/svm/sequence_labeler.h create mode 100644 ml/dlib/dlib/svm/sequence_labeler_abstract.h create mode 100644 ml/dlib/dlib/svm/sequence_segmenter.h create mode 100644 ml/dlib/dlib/svm/sequence_segmenter_abstract.h create mode 100644 ml/dlib/dlib/svm/simplify_linear_decision_function.h create mode 100644 ml/dlib/dlib/svm/simplify_linear_decision_function_abstract.h create mode 100644 ml/dlib/dlib/svm/sort_basis_vectors.h create mode 100644 ml/dlib/dlib/svm/sort_basis_vectors_abstract.h create mode 100644 ml/dlib/dlib/svm/sparse_kernel.h create mode 100644 ml/dlib/dlib/svm/sparse_kernel_abstract.h create mode 100644 ml/dlib/dlib/svm/sparse_vector.h create mode 100644 ml/dlib/dlib/svm/sparse_vector_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_assignment_trainer.h create mode 100644 ml/dlib/dlib/svm/structural_assignment_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_graph_labeling_trainer.h create mode 100644 ml/dlib/dlib/svm/structural_graph_labeling_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_object_detection_trainer.h create mode 100644 ml/dlib/dlib/svm/structural_object_detection_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_sequence_labeling_trainer.h create mode 100644 ml/dlib/dlib/svm/structural_sequence_labeling_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_sequence_segmentation_trainer.h create mode 100644 ml/dlib/dlib/svm/structural_sequence_segmentation_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_svm_assignment_problem.h create mode 100644 ml/dlib/dlib/svm/structural_svm_assignment_problem_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_svm_distributed.h create mode 100644 ml/dlib/dlib/svm/structural_svm_distributed_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_svm_graph_labeling_problem.h create mode 100644 ml/dlib/dlib/svm/structural_svm_graph_labeling_problem_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_svm_object_detection_problem.h create mode 100644 ml/dlib/dlib/svm/structural_svm_object_detection_problem_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_svm_problem.h create mode 100644 ml/dlib/dlib/svm/structural_svm_problem_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_svm_problem_threaded.h create mode 100644 ml/dlib/dlib/svm/structural_svm_problem_threaded_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem.h create mode 100644 ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem_abstract.h create mode 100644 ml/dlib/dlib/svm/structural_track_association_trainer.h create mode 100644 ml/dlib/dlib/svm/structural_track_association_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svm.h create mode 100644 ml/dlib/dlib/svm/svm_abstract.h create mode 100644 ml/dlib/dlib/svm/svm_c_ekm_trainer.h create mode 100644 ml/dlib/dlib/svm/svm_c_ekm_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svm_c_linear_dcd_trainer.h create mode 100644 ml/dlib/dlib/svm/svm_c_linear_dcd_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svm_c_linear_trainer.h create mode 100644 ml/dlib/dlib/svm/svm_c_linear_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svm_c_trainer.h create mode 100644 ml/dlib/dlib/svm/svm_c_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svm_multiclass_linear_trainer.h create mode 100644 ml/dlib/dlib/svm/svm_multiclass_linear_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svm_nu_trainer.h create mode 100644 ml/dlib/dlib/svm/svm_nu_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svm_one_class_trainer.h create mode 100644 ml/dlib/dlib/svm/svm_one_class_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svm_rank_trainer.h create mode 100644 ml/dlib/dlib/svm/svm_rank_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svm_threaded.h create mode 100644 ml/dlib/dlib/svm/svm_threaded_abstract.h create mode 100644 ml/dlib/dlib/svm/svr_linear_trainer.h create mode 100644 ml/dlib/dlib/svm/svr_linear_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/svr_trainer.h create mode 100644 ml/dlib/dlib/svm/svr_trainer_abstract.h create mode 100644 ml/dlib/dlib/svm/track_association_function.h create mode 100644 ml/dlib/dlib/svm/track_association_function_abstract.h create mode 100644 ml/dlib/dlib/svm_threaded.h create mode 100644 ml/dlib/dlib/sync_extension.h create mode 100644 ml/dlib/dlib/sync_extension/sync_extension_kernel_1.h create mode 100644 ml/dlib/dlib/sync_extension/sync_extension_kernel_abstract.h create mode 100644 ml/dlib/dlib/test/CMakeLists.txt create mode 100644 ml/dlib/dlib/test/WINDOWS_build_and_run_all_unit_tests.bat create mode 100644 ml/dlib/dlib/test/active_learning.cpp create mode 100644 ml/dlib/dlib/test/any.cpp create mode 100644 ml/dlib/dlib/test/any_function.cpp create mode 100644 ml/dlib/dlib/test/array.cpp create mode 100644 ml/dlib/dlib/test/array2d.cpp create mode 100644 ml/dlib/dlib/test/assignment_learning.cpp create mode 100644 ml/dlib/dlib/test/base64.cpp create mode 100644 ml/dlib/dlib/test/bayes_nets.cpp create mode 100644 ml/dlib/dlib/test/bigint.cpp create mode 100644 ml/dlib/dlib/test/binary_search_tree.h create mode 100644 ml/dlib/dlib/test/binary_search_tree_kernel_1a.cpp create mode 100644 ml/dlib/dlib/test/binary_search_tree_kernel_2a.cpp create mode 100644 ml/dlib/dlib/test/binary_search_tree_mm1.cpp create mode 100644 ml/dlib/dlib/test/binary_search_tree_mm2.cpp create mode 100644 ml/dlib/dlib/test/blas_bindings/CMakeLists.txt create mode 100644 ml/dlib/dlib/test/blas_bindings/blas_bindings_dot.cpp create mode 100644 ml/dlib/dlib/test/blas_bindings/blas_bindings_gemm.cpp create mode 100644 ml/dlib/dlib/test/blas_bindings/blas_bindings_gemv.cpp create mode 100644 ml/dlib/dlib/test/blas_bindings/blas_bindings_ger.cpp create mode 100644 ml/dlib/dlib/test/blas_bindings/blas_bindings_scal_axpy.cpp create mode 100644 ml/dlib/dlib/test/blas_bindings/vector.cpp create mode 100644 ml/dlib/dlib/test/bridge.cpp create mode 100644 ml/dlib/dlib/test/bsp.cpp create mode 100644 ml/dlib/dlib/test/byte_orderer.cpp create mode 100644 ml/dlib/dlib/test/cca.cpp create mode 100644 ml/dlib/dlib/test/checkerboard.h create mode 100644 ml/dlib/dlib/test/clustering.cpp create mode 100644 ml/dlib/dlib/test/cmd_line_parser.cpp create mode 100644 ml/dlib/dlib/test/cmd_line_parser.h create mode 100644 ml/dlib/dlib/test/cmd_line_parser_wchar_t.cpp create mode 100644 ml/dlib/dlib/test/compress_stream.cpp create mode 100644 ml/dlib/dlib/test/conditioning_class.cpp create mode 100644 ml/dlib/dlib/test/conditioning_class.h create mode 100644 ml/dlib/dlib/test/conditioning_class_c.cpp create mode 100644 ml/dlib/dlib/test/config_reader.cpp create mode 100644 ml/dlib/dlib/test/correlation_tracker.cpp create mode 100644 ml/dlib/dlib/test/crc32.cpp create mode 100644 ml/dlib/dlib/test/create_iris_datafile.cpp create mode 100644 ml/dlib/dlib/test/create_iris_datafile.h create mode 100644 ml/dlib/dlib/test/cublas.cpp create mode 100644 ml/dlib/dlib/test/data_io.cpp create mode 100644 ml/dlib/dlib/test/directed_graph.cpp create mode 100644 ml/dlib/dlib/test/discriminant_pca.cpp create mode 100644 ml/dlib/dlib/test/disjoint_subsets.cpp create mode 100644 ml/dlib/dlib/test/disjoint_subsets_sized.cpp create mode 100644 ml/dlib/dlib/test/dnn.cpp create mode 100644 ml/dlib/dlib/test/ekm_and_lisf.cpp create mode 100644 ml/dlib/dlib/test/elastic_net.cpp create mode 100644 ml/dlib/dlib/test/empirical_kernel_map.cpp create mode 100644 ml/dlib/dlib/test/entropy_coder.cpp create mode 100644 ml/dlib/dlib/test/entropy_encoder_model.cpp create mode 100644 ml/dlib/dlib/test/example.cpp create mode 100644 ml/dlib/dlib/test/example_args.cpp create mode 100644 ml/dlib/dlib/test/examples/CMakeLists.txt create mode 100644 ml/dlib/dlib/test/face.cpp create mode 100644 ml/dlib/dlib/test/fft.cpp create mode 100644 ml/dlib/dlib/test/fhog.cpp create mode 100644 ml/dlib/dlib/test/filtering.cpp create mode 100644 ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp create mode 100644 ml/dlib/dlib/test/find_max_factor_graph_viterbi.cpp create mode 100644 ml/dlib/dlib/test/find_optimal_parameters.cpp create mode 100644 ml/dlib/dlib/test/geometry.cpp create mode 100644 ml/dlib/dlib/test/global_optimization.cpp create mode 100644 ml/dlib/dlib/test/graph.cpp create mode 100644 ml/dlib/dlib/test/graph_cuts.cpp create mode 100644 ml/dlib/dlib/test/graph_labeler.cpp create mode 100644 ml/dlib/dlib/test/gui/CMakeLists.txt create mode 100644 ml/dlib/dlib/test/gui/main.cpp create mode 100644 ml/dlib/dlib/test/hash.cpp create mode 100644 ml/dlib/dlib/test/hash_map.cpp create mode 100644 ml/dlib/dlib/test/hash_set.cpp create mode 100644 ml/dlib/dlib/test/hash_table.cpp create mode 100644 ml/dlib/dlib/test/hog_image.cpp create mode 100644 ml/dlib/dlib/test/image.cpp create mode 100644 ml/dlib/dlib/test/iosockstream.cpp create mode 100644 ml/dlib/dlib/test/is_same_object.cpp create mode 100644 ml/dlib/dlib/test/isotonic_regression.cpp create mode 100644 ml/dlib/dlib/test/kcentroid.cpp create mode 100644 ml/dlib/dlib/test/kernel_matrix.cpp create mode 100644 ml/dlib/dlib/test/kmeans.cpp create mode 100644 ml/dlib/dlib/test/learning_to_track.cpp create mode 100644 ml/dlib/dlib/test/least_squares.cpp create mode 100644 ml/dlib/dlib/test/linear_manifold_regularizer.cpp create mode 100644 ml/dlib/dlib/test/lspi.cpp create mode 100644 ml/dlib/dlib/test/lz77_buffer.cpp create mode 100644 ml/dlib/dlib/test/main.cpp create mode 100644 ml/dlib/dlib/test/makefile create mode 100644 ml/dlib/dlib/test/map.cpp create mode 100644 ml/dlib/dlib/test/matrix.cpp create mode 100644 ml/dlib/dlib/test/matrix2.cpp create mode 100644 ml/dlib/dlib/test/matrix3.cpp create mode 100644 ml/dlib/dlib/test/matrix4.cpp create mode 100644 ml/dlib/dlib/test/matrix_chol.cpp create mode 100644 ml/dlib/dlib/test/matrix_eig.cpp create mode 100644 ml/dlib/dlib/test/matrix_lu.cpp create mode 100644 ml/dlib/dlib/test/matrix_qr.cpp create mode 100644 ml/dlib/dlib/test/max_cost_assignment.cpp create mode 100644 ml/dlib/dlib/test/max_sum_submatrix.cpp create mode 100644 ml/dlib/dlib/test/md5.cpp create mode 100644 ml/dlib/dlib/test/member_function_pointer.cpp create mode 100644 ml/dlib/dlib/test/metaprogramming.cpp create mode 100644 ml/dlib/dlib/test/mpc.cpp create mode 100644 ml/dlib/dlib/test/multithreaded_object.cpp create mode 100644 ml/dlib/dlib/test/numerical_integration.cpp create mode 100644 ml/dlib/dlib/test/object_detector.cpp create mode 100644 ml/dlib/dlib/test/oca.cpp create mode 100644 ml/dlib/dlib/test/one_vs_all_trainer.cpp create mode 100644 ml/dlib/dlib/test/one_vs_one_trainer.cpp create mode 100644 ml/dlib/dlib/test/opt_qp_solver.cpp create mode 100644 ml/dlib/dlib/test/optimization.cpp create mode 100644 ml/dlib/dlib/test/optimization_test_functions.cpp create mode 100644 ml/dlib/dlib/test/optimization_test_functions.h create mode 100644 ml/dlib/dlib/test/parallel_for.cpp create mode 100644 ml/dlib/dlib/test/parse.cpp create mode 100644 ml/dlib/dlib/test/pipe.cpp create mode 100644 ml/dlib/dlib/test/pixel.cpp create mode 100644 ml/dlib/dlib/test/probabilistic.cpp create mode 100644 ml/dlib/dlib/test/pyramid_down.cpp create mode 100644 ml/dlib/dlib/test/queue.cpp create mode 100644 ml/dlib/dlib/test/rand.cpp create mode 100644 ml/dlib/dlib/test/random_forest.cpp create mode 100644 ml/dlib/dlib/test/ranking.cpp create mode 100644 ml/dlib/dlib/test/read_write_mutex.cpp create mode 100644 ml/dlib/dlib/test/reference_counter.cpp create mode 100644 ml/dlib/dlib/test/rls.cpp create mode 100644 ml/dlib/dlib/test/sammon.cpp create mode 100644 ml/dlib/dlib/test/scan_image.cpp create mode 100644 ml/dlib/dlib/test/sequence.cpp create mode 100644 ml/dlib/dlib/test/sequence_labeler.cpp create mode 100644 ml/dlib/dlib/test/sequence_segmenter.cpp create mode 100644 ml/dlib/dlib/test/serialize.cpp create mode 100644 ml/dlib/dlib/test/set.cpp create mode 100644 ml/dlib/dlib/test/sldf.cpp create mode 100644 ml/dlib/dlib/test/sliding_buffer.cpp create mode 100644 ml/dlib/dlib/test/smart_pointers.cpp create mode 100644 ml/dlib/dlib/test/sockets.cpp create mode 100644 ml/dlib/dlib/test/sockets2.cpp create mode 100644 ml/dlib/dlib/test/sockstreambuf.cpp create mode 100644 ml/dlib/dlib/test/sparse_vector.cpp create mode 100644 ml/dlib/dlib/test/stack.cpp create mode 100644 ml/dlib/dlib/test/static_map.cpp create mode 100644 ml/dlib/dlib/test/static_set.cpp create mode 100644 ml/dlib/dlib/test/statistics.cpp create mode 100644 ml/dlib/dlib/test/std_vector_c.cpp create mode 100644 ml/dlib/dlib/test/string.cpp create mode 100644 ml/dlib/dlib/test/svm.cpp create mode 100644 ml/dlib/dlib/test/svm_c_linear.cpp create mode 100644 ml/dlib/dlib/test/svm_c_linear_dcd.cpp create mode 100644 ml/dlib/dlib/test/svm_multiclass_linear.cpp create mode 100644 ml/dlib/dlib/test/svm_struct.cpp create mode 100644 ml/dlib/dlib/test/svr_linear_trainer.cpp create mode 100644 ml/dlib/dlib/test/symmetric_matrix_cache.cpp create mode 100644 ml/dlib/dlib/test/tester.cpp create mode 100644 ml/dlib/dlib/test/tester.h create mode 100644 ml/dlib/dlib/test/thread_pool.cpp create mode 100644 ml/dlib/dlib/test/threads.cpp create mode 100644 ml/dlib/dlib/test/timer.cpp create mode 100644 ml/dlib/dlib/test/tokenizer.cpp create mode 100644 ml/dlib/dlib/test/tools/CMakeLists.txt create mode 100644 ml/dlib/dlib/test/trust_region.cpp create mode 100644 ml/dlib/dlib/test/tuple.cpp create mode 100644 ml/dlib/dlib/test/type_safe_union.cpp create mode 100644 ml/dlib/dlib/test/vectorstream.cpp create mode 100644 ml/dlib/dlib/test_for_odr_violations.cpp create mode 100644 ml/dlib/dlib/test_for_odr_violations.h create mode 100644 ml/dlib/dlib/threads.h create mode 100644 ml/dlib/dlib/threads/async.cpp create mode 100644 ml/dlib/dlib/threads/async.h create mode 100644 ml/dlib/dlib/threads/async_abstract.h create mode 100644 ml/dlib/dlib/threads/auto_mutex_extension.h create mode 100644 ml/dlib/dlib/threads/auto_mutex_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/auto_unlock_extension.h create mode 100644 ml/dlib/dlib/threads/auto_unlock_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/create_new_thread_extension.h create mode 100644 ml/dlib/dlib/threads/create_new_thread_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/multithreaded_object_extension.cpp create mode 100644 ml/dlib/dlib/threads/multithreaded_object_extension.h create mode 100644 ml/dlib/dlib/threads/multithreaded_object_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/parallel_for_extension.h create mode 100644 ml/dlib/dlib/threads/parallel_for_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/posix.h create mode 100644 ml/dlib/dlib/threads/read_write_mutex_extension.h create mode 100644 ml/dlib/dlib/threads/read_write_mutex_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/rmutex_extension.h create mode 100644 ml/dlib/dlib/threads/rmutex_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/rsignaler_extension.h create mode 100644 ml/dlib/dlib/threads/rsignaler_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/thread_function_extension.h create mode 100644 ml/dlib/dlib/threads/thread_function_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/thread_pool_extension.cpp create mode 100644 ml/dlib/dlib/threads/thread_pool_extension.h create mode 100644 ml/dlib/dlib/threads/thread_pool_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/thread_specific_data_extension.h create mode 100644 ml/dlib/dlib/threads/thread_specific_data_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/threaded_object_extension.cpp create mode 100644 ml/dlib/dlib/threads/threaded_object_extension.h create mode 100644 ml/dlib/dlib/threads/threaded_object_extension_abstract.h create mode 100644 ml/dlib/dlib/threads/threads_kernel.h create mode 100644 ml/dlib/dlib/threads/threads_kernel_1.cpp create mode 100644 ml/dlib/dlib/threads/threads_kernel_1.h create mode 100644 ml/dlib/dlib/threads/threads_kernel_2.cpp create mode 100644 ml/dlib/dlib/threads/threads_kernel_2.h create mode 100644 ml/dlib/dlib/threads/threads_kernel_abstract.h create mode 100644 ml/dlib/dlib/threads/threads_kernel_shared.cpp create mode 100644 ml/dlib/dlib/threads/threads_kernel_shared.h create mode 100644 ml/dlib/dlib/threads/windows.h create mode 100644 ml/dlib/dlib/time_this.h create mode 100644 ml/dlib/dlib/timeout.h create mode 100644 ml/dlib/dlib/timeout/timeout.h create mode 100644 ml/dlib/dlib/timeout/timeout_abstract.h create mode 100644 ml/dlib/dlib/timer.h create mode 100644 ml/dlib/dlib/timer/timer.cpp create mode 100644 ml/dlib/dlib/timer/timer.h create mode 100644 ml/dlib/dlib/timer/timer_abstract.h create mode 100644 ml/dlib/dlib/timer/timer_heavy.h create mode 100644 ml/dlib/dlib/timing.h create mode 100644 ml/dlib/dlib/tokenizer.h create mode 100644 ml/dlib/dlib/tokenizer/tokenizer_kernel_1.cpp create mode 100644 ml/dlib/dlib/tokenizer/tokenizer_kernel_1.h create mode 100644 ml/dlib/dlib/tokenizer/tokenizer_kernel_abstract.h create mode 100644 ml/dlib/dlib/tokenizer/tokenizer_kernel_c.h create mode 100755 ml/dlib/dlib/travis/build-and-test.sh create mode 100644 ml/dlib/dlib/tuple.h create mode 100644 ml/dlib/dlib/tuple/tuple.h create mode 100644 ml/dlib/dlib/tuple/tuple_abstract.h create mode 100644 ml/dlib/dlib/type_safe_union.h create mode 100644 ml/dlib/dlib/type_safe_union/type_safe_union_kernel.h create mode 100644 ml/dlib/dlib/type_safe_union/type_safe_union_kernel_abstract.h create mode 100644 ml/dlib/dlib/uintn.h create mode 100644 ml/dlib/dlib/unicode.h create mode 100644 ml/dlib/dlib/unicode/unicode.cpp create mode 100644 ml/dlib/dlib/unicode/unicode.h create mode 100644 ml/dlib/dlib/unicode/unicode_abstract.h create mode 100644 ml/dlib/dlib/unordered_pair.h create mode 100644 ml/dlib/dlib/vectorstream.h create mode 100644 ml/dlib/dlib/vectorstream/unserialize.h create mode 100644 ml/dlib/dlib/vectorstream/unserialize_abstract.h create mode 100644 ml/dlib/dlib/vectorstream/vectorstream.h create mode 100644 ml/dlib/dlib/vectorstream/vectorstream_abstract.h create mode 100644 ml/dlib/dlib/windows_magic.h create mode 100644 ml/dlib/dlib/xml_parser.h create mode 100644 ml/dlib/dlib/xml_parser/xml_parser_kernel_1.h create mode 100644 ml/dlib/dlib/xml_parser/xml_parser_kernel_abstract.h create mode 100644 ml/dlib/dlib/xml_parser/xml_parser_kernel_interfaces.h create mode 100644 ml/dlib/docs/.logger_revnum create mode 100644 ml/dlib/docs/README.txt create mode 100755 ml/dlib/docs/bash_helper_functions create mode 100644 ml/dlib/docs/docs/algorithms.xml create mode 100644 ml/dlib/docs/docs/api.xml create mode 100644 ml/dlib/docs/docs/bayes.xml create mode 100644 ml/dlib/docs/docs/bayesopt_vs_lipo.svg create mode 100644 ml/dlib/docs/docs/bigminus.gif create mode 100644 ml/dlib/docs/docs/bigplus.gif create mode 100644 ml/dlib/docs/docs/books.xml create mode 100644 ml/dlib/docs/docs/boost.png create mode 100644 ml/dlib/docs/docs/change_log.xml create mode 100644 ml/dlib/docs/docs/chm/READ THE README. DO NOT EDIT THE TABLE OF CONTENTS FILE create mode 100644 ml/dlib/docs/docs/chm/READ THE README. DO NOT EDIT THE TABLE OF CONTENTS FILE2 create mode 100644 ml/dlib/docs/docs/chm/READ THE README. DO NOT EDIT THE TABLE OF CONTENTS FILE3 create mode 100644 ml/dlib/docs/docs/chm/README.txt create mode 100644 ml/dlib/docs/docs/chm/documentation.html create mode 100644 ml/dlib/docs/docs/chm/htmlhelp/hha.dll create mode 100644 ml/dlib/docs/docs/chm/htmlhelp/hhc.exe create mode 100644 ml/dlib/docs/docs/chm/htmlhelp/htmlhelp.reg create mode 100644 ml/dlib/docs/docs/chm/htmlhelp/itcc.dll create mode 100644 ml/dlib/docs/docs/chm/htmlhelp/itircl.dll create mode 100644 ml/dlib/docs/docs/chm/htmlhelp/itss.dll create mode 100755 ml/dlib/docs/docs/chm/htmlhelp/setup_htmlhelp.sh create mode 100644 ml/dlib/docs/docs/chm/htmlhelp_stylesheet.xsl create mode 100644 ml/dlib/docs/docs/chm/lib.hhp create mode 100644 ml/dlib/docs/docs/chm/toc.xml create mode 100644 ml/dlib/docs/docs/compile.xml create mode 100644 ml/dlib/docs/docs/compression.xml create mode 100644 ml/dlib/docs/docs/containers.xml create mode 100644 ml/dlib/docs/docs/dlib-icon-30x32.png create mode 100644 ml/dlib/docs/docs/dlib-icon-32.png create mode 100644 ml/dlib/docs/docs/dlib-icon-48.png create mode 100644 ml/dlib/docs/docs/dlib-icon-64.png create mode 100644 ml/dlib/docs/docs/dlib-icon.ico create mode 100644 ml/dlib/docs/docs/dlib-logo-and-icons.svg create mode 100644 ml/dlib/docs/docs/dlib-logo-small.png create mode 100644 ml/dlib/docs/docs/dlib-logo.png create mode 100644 ml/dlib/docs/docs/dlib.css create mode 100644 ml/dlib/docs/docs/dlib.js create mode 100644 ml/dlib/docs/docs/down.gif create mode 100644 ml/dlib/docs/docs/enable_if.html create mode 100644 ml/dlib/docs/docs/face_landmarking_example.png create mode 100644 ml/dlib/docs/docs/faq.xml create mode 100644 ml/dlib/docs/docs/find_max_global_example.mp4 create mode 100644 ml/dlib/docs/docs/find_max_global_example.png create mode 100644 ml/dlib/docs/docs/find_max_global_example.webm create mode 100644 ml/dlib/docs/docs/find_max_global_results_table.svg create mode 100644 ml/dlib/docs/docs/graph_tools.xml create mode 100644 ml/dlib/docs/docs/guipics/button.png create mode 100644 ml/dlib/docs/docs/guipics/check_box.png create mode 100644 ml/dlib/docs/docs/guipics/directed_graph_drawer.png create mode 100644 ml/dlib/docs/docs/guipics/image_window.jpg create mode 100644 ml/dlib/docs/docs/guipics/label.png create mode 100644 ml/dlib/docs/docs/guipics/list_box.png create mode 100644 ml/dlib/docs/docs/guipics/menu_bar.png create mode 100644 ml/dlib/docs/docs/guipics/message_box.png create mode 100644 ml/dlib/docs/docs/guipics/mouse_tracker.png create mode 100644 ml/dlib/docs/docs/guipics/named_rectangle.png create mode 100644 ml/dlib/docs/docs/guipics/open_existing_file_box.png create mode 100644 ml/dlib/docs/docs/guipics/open_file_box.png create mode 100644 ml/dlib/docs/docs/guipics/perspective_window.png create mode 100644 ml/dlib/docs/docs/guipics/popup_menu.png create mode 100644 ml/dlib/docs/docs/guipics/radio_button.png create mode 100644 ml/dlib/docs/docs/guipics/save_file_box.png create mode 100644 ml/dlib/docs/docs/guipics/scroll_bar.png create mode 100644 ml/dlib/docs/docs/guipics/tabbed_display.png create mode 100644 ml/dlib/docs/docs/guipics/text_box.png create mode 100644 ml/dlib/docs/docs/guipics/text_field.png create mode 100644 ml/dlib/docs/docs/guipics/text_grid.png create mode 100644 ml/dlib/docs/docs/heatmap.png create mode 100644 ml/dlib/docs/docs/howto_contribute.xml create mode 100644 ml/dlib/docs/docs/imaging.xml create mode 100644 ml/dlib/docs/docs/index.xml create mode 100644 ml/dlib/docs/docs/intro.xml create mode 100644 ml/dlib/docs/docs/jet.png create mode 100644 ml/dlib/docs/docs/kernel_1a.txt create mode 100644 ml/dlib/docs/docs/kernel_1a.xml create mode 100644 ml/dlib/docs/docs/kernel_1b.txt create mode 100644 ml/dlib/docs/docs/kernel_1b.xml create mode 100644 ml/dlib/docs/docs/kernel_1c.txt create mode 100644 ml/dlib/docs/docs/kernel_1c.xml create mode 100644 ml/dlib/docs/docs/kernel_1da.txt create mode 100644 ml/dlib/docs/docs/kernel_1da.xml create mode 100644 ml/dlib/docs/docs/kernel_1db.txt create mode 100644 ml/dlib/docs/docs/kernel_1db.xml create mode 100644 ml/dlib/docs/docs/kernel_1ea.txt create mode 100644 ml/dlib/docs/docs/kernel_1ea.xml create mode 100644 ml/dlib/docs/docs/kernel_1eb.txt create mode 100644 ml/dlib/docs/docs/kernel_1eb.xml create mode 100644 ml/dlib/docs/docs/kernel_1ec.txt create mode 100644 ml/dlib/docs/docs/kernel_1ec.xml create mode 100644 ml/dlib/docs/docs/kernel_2a.txt create mode 100644 ml/dlib/docs/docs/kernel_2a.xml create mode 100644 ml/dlib/docs/docs/kernel_3a.txt create mode 100644 ml/dlib/docs/docs/kernel_3a.xml create mode 100644 ml/dlib/docs/docs/kernel_3b.txt create mode 100644 ml/dlib/docs/docs/kernel_3b.xml create mode 100644 ml/dlib/docs/docs/license.xml create mode 100644 ml/dlib/docs/docs/linear_algebra.xml create mode 100644 ml/dlib/docs/docs/main_menu.xml create mode 100644 ml/dlib/docs/docs/metaprogramming.xml create mode 100644 ml/dlib/docs/docs/minus.gif create mode 100644 ml/dlib/docs/docs/ml.xml create mode 100644 ml/dlib/docs/docs/ml_guide.dia create mode 100644 ml/dlib/docs/docs/ml_guide.svg create mode 100644 ml/dlib/docs/docs/network.xml create mode 100644 ml/dlib/docs/docs/old_change_log.xml create mode 100644 ml/dlib/docs/docs/old_release_notes.xml create mode 100644 ml/dlib/docs/docs/optimization.xml create mode 100644 ml/dlib/docs/docs/other.xml create mode 100644 ml/dlib/docs/docs/parsing.xml create mode 100644 ml/dlib/docs/docs/plus.gif create mode 100644 ml/dlib/docs/docs/python/conf.py create mode 100644 ml/dlib/docs/docs/python/generate_dlib_listing.py create mode 100644 ml/dlib/docs/docs/python/index.rst create mode 100644 ml/dlib/docs/docs/rbf_big_gamma.gif create mode 100644 ml/dlib/docs/docs/rbf_normal.gif create mode 100644 ml/dlib/docs/docs/rbf_small_gamma.gif create mode 100644 ml/dlib/docs/docs/release_notes.xml create mode 100644 ml/dlib/docs/docs/right.gif create mode 100644 ml/dlib/docs/docs/stylesheet.xsl create mode 100644 ml/dlib/docs/docs/term_index.xml create mode 100644 ml/dlib/docs/docs/tiled_pyramid_example.jpg create mode 100755 ml/dlib/docs/docs/vs-cmake-gui.png create mode 100755 ml/dlib/docs/docs/vs_mode_1.png create mode 100755 ml/dlib/docs/docs/vs_mode_2.png create mode 100755 ml/dlib/docs/docs/vs_mode_3.png create mode 100755 ml/dlib/docs/makedocs create mode 100755 ml/dlib/docs/makerel create mode 100755 ml/dlib/docs/testenv create mode 100755 ml/dlib/docs/testenv_rel create mode 100644 ml/dlib/examples/3d_point_cloud_ex.cpp create mode 100644 ml/dlib/examples/CMakeLists.txt create mode 100644 ml/dlib/examples/LICENSE_FOR_EXAMPLE_PROGRAMS.txt create mode 100644 ml/dlib/examples/assignment_learning_ex.cpp create mode 100644 ml/dlib/examples/bayes_net_ex.cpp create mode 100644 ml/dlib/examples/bayes_net_from_disk_ex.cpp create mode 100644 ml/dlib/examples/bayes_net_gui_ex.cpp create mode 100644 ml/dlib/examples/bridge_ex.cpp create mode 100644 ml/dlib/examples/bsp_ex.cpp create mode 100644 ml/dlib/examples/compress_stream_ex.cpp create mode 100644 ml/dlib/examples/config.txt create mode 100644 ml/dlib/examples/config_reader_ex.cpp create mode 100644 ml/dlib/examples/custom_trainer_ex.cpp create mode 100644 ml/dlib/examples/dir_nav_ex.cpp create mode 100644 ml/dlib/examples/dnn_face_recognition_ex.cpp create mode 100644 ml/dlib/examples/dnn_imagenet_ex.cpp create mode 100644 ml/dlib/examples/dnn_imagenet_train_ex.cpp create mode 100644 ml/dlib/examples/dnn_inception_ex.cpp create mode 100644 ml/dlib/examples/dnn_introduction2_ex.cpp create mode 100644 ml/dlib/examples/dnn_introduction_ex.cpp create mode 100644 ml/dlib/examples/dnn_metric_learning_ex.cpp create mode 100644 ml/dlib/examples/dnn_metric_learning_on_images_ex.cpp create mode 100644 ml/dlib/examples/dnn_mmod_dog_hipsterizer.cpp create mode 100644 ml/dlib/examples/dnn_mmod_ex.cpp create mode 100644 ml/dlib/examples/dnn_mmod_face_detection_ex.cpp create mode 100644 ml/dlib/examples/dnn_mmod_find_cars2_ex.cpp create mode 100644 ml/dlib/examples/dnn_mmod_find_cars_ex.cpp create mode 100644 ml/dlib/examples/dnn_mmod_train_find_cars_ex.cpp create mode 100644 ml/dlib/examples/dnn_semantic_segmentation_ex.cpp create mode 100644 ml/dlib/examples/dnn_semantic_segmentation_ex.h create mode 100644 ml/dlib/examples/dnn_semantic_segmentation_train_ex.cpp create mode 100644 ml/dlib/examples/empirical_kernel_map_ex.cpp create mode 100644 ml/dlib/examples/face_detection_ex.cpp create mode 100644 ml/dlib/examples/face_landmark_detection_ex.cpp create mode 100755 ml/dlib/examples/faces/2007_007763.jpg create mode 100755 ml/dlib/examples/faces/2008_001009.jpg create mode 100755 ml/dlib/examples/faces/2008_001322.jpg create mode 100755 ml/dlib/examples/faces/2008_002079.jpg create mode 100755 ml/dlib/examples/faces/2008_002470.jpg create mode 100755 ml/dlib/examples/faces/2008_002506.jpg create mode 100755 ml/dlib/examples/faces/2008_004176.jpg create mode 100755 ml/dlib/examples/faces/2008_007676.jpg create mode 100755 ml/dlib/examples/faces/2009_004587.jpg create mode 100644 ml/dlib/examples/faces/Tom_Cruise_avp_2014_4.jpg create mode 100644 ml/dlib/examples/faces/bald_guys.jpg create mode 100644 ml/dlib/examples/faces/dogs.jpg create mode 100644 ml/dlib/examples/faces/image_metadata_stylesheet.xsl create mode 100644 ml/dlib/examples/faces/testing.xml create mode 100644 ml/dlib/examples/faces/testing_with_face_landmarks.xml create mode 100644 ml/dlib/examples/faces/training.xml create mode 100644 ml/dlib/examples/faces/training_with_face_landmarks.xml create mode 100644 ml/dlib/examples/fhog_ex.cpp create mode 100644 ml/dlib/examples/fhog_object_detector_ex.cpp create mode 100644 ml/dlib/examples/file_to_code_ex.cpp create mode 100644 ml/dlib/examples/graph_labeling_ex.cpp create mode 100644 ml/dlib/examples/gui_api_ex.cpp create mode 100644 ml/dlib/examples/hough_transform_ex.cpp create mode 100644 ml/dlib/examples/image_ex.cpp create mode 100644 ml/dlib/examples/integrate_function_adapt_simp_ex.cpp create mode 100644 ml/dlib/examples/iosockstream_ex.cpp create mode 100644 ml/dlib/examples/johns/John_Salley/000179_02159509.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000183_02159543.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000186_02159346.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000189_02159361.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000190_02159501.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000192_02159531.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000194_02159572.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000197_02159322.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000197_02159525.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000198_02159470.jpg create mode 100644 ml/dlib/examples/johns/John_Salley/000200_02159354.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000264_01099001.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000274_01099061.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000277_01099000.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000289_01099139.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000290_01099067.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000290_01099090.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000291_01099023.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000291_01099214.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000293_01099081.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000296_01099007.jpg create mode 100644 ml/dlib/examples/johns/John_Savage/000299_01099008.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000288_00925786.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000302_00925785.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000307_00925823.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000325_00925954.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000326_00925765.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000326_00926089.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000326_00926128.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000326_00926139.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000329_00925859.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000329_00925963.jpg create mode 100644 ml/dlib/examples/johns/John_Schneider/000331_00926012.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000373_03228153.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000375_03227651.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000376_02340068.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000378_02340151.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000378_03227610.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000383_03227939.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000385_03227766.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000388_03227773.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000390_03227666.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000394_02340150.jpg create mode 100644 ml/dlib/examples/johns/John_Shimkus/000396_03227722.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000288_00470387.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000297_00470170.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000300_00470148.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000304_00470122.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000305_00470162.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000305_00470717.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000306_00470222.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000306_00470223.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000309_00470287.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000310_00470421.jpg create mode 100644 ml/dlib/examples/johns/John_Simm/000310_00470511.jpg create mode 100644 ml/dlib/examples/kcentroid_ex.cpp create mode 100644 ml/dlib/examples/kkmeans_ex.cpp create mode 100644 ml/dlib/examples/krls_ex.cpp create mode 100644 ml/dlib/examples/krls_filter_ex.cpp create mode 100644 ml/dlib/examples/krr_classification_ex.cpp create mode 100644 ml/dlib/examples/krr_regression_ex.cpp create mode 100644 ml/dlib/examples/learning_to_track_ex.cpp create mode 100644 ml/dlib/examples/least_squares_ex.cpp create mode 100644 ml/dlib/examples/linear_manifold_regularizer_ex.cpp create mode 100644 ml/dlib/examples/logger_custom_output_ex.cpp create mode 100644 ml/dlib/examples/logger_ex.cpp create mode 100644 ml/dlib/examples/logger_ex_2.cpp create mode 100644 ml/dlib/examples/matrix_ex.cpp create mode 100644 ml/dlib/examples/matrix_expressions_ex.cpp create mode 100755 ml/dlib/examples/max_cost_assignment_ex.cpp create mode 100644 ml/dlib/examples/member_function_pointer_ex.cpp create mode 100644 ml/dlib/examples/mlp_ex.cpp create mode 100644 ml/dlib/examples/mmod_cars_test_image.jpg create mode 100644 ml/dlib/examples/mmod_cars_test_image2.jpg create mode 100644 ml/dlib/examples/model_selection_ex.cpp create mode 100644 ml/dlib/examples/mpc_ex.cpp create mode 100644 ml/dlib/examples/multiclass_classification_ex.cpp create mode 100644 ml/dlib/examples/multithreaded_object_ex.cpp create mode 100644 ml/dlib/examples/object_detector_advanced_ex.cpp create mode 100644 ml/dlib/examples/object_detector_ex.cpp create mode 100644 ml/dlib/examples/one_class_classifiers_ex.cpp create mode 100644 ml/dlib/examples/optimization_ex.cpp create mode 100644 ml/dlib/examples/parallel_for_ex.cpp create mode 100644 ml/dlib/examples/pipe_ex.cpp create mode 100644 ml/dlib/examples/pipe_ex_2.cpp create mode 100644 ml/dlib/examples/quantum_computing_ex.cpp create mode 100644 ml/dlib/examples/queue_ex.cpp create mode 100644 ml/dlib/examples/random_cropper_ex.cpp create mode 100644 ml/dlib/examples/rank_features_ex.cpp create mode 100644 ml/dlib/examples/running_stats_ex.cpp create mode 100644 ml/dlib/examples/rvm_ex.cpp create mode 100644 ml/dlib/examples/rvm_regression_ex.cpp create mode 100644 ml/dlib/examples/sequence_labeler_ex.cpp create mode 100644 ml/dlib/examples/sequence_segmenter_ex.cpp create mode 100644 ml/dlib/examples/server_http_ex.cpp create mode 100644 ml/dlib/examples/server_iostream_ex.cpp create mode 100644 ml/dlib/examples/sockets_ex.cpp create mode 100644 ml/dlib/examples/sockstreambuf_ex.cpp create mode 100644 ml/dlib/examples/sqlite_ex.cpp create mode 100644 ml/dlib/examples/std_allocator_ex.cpp create mode 100644 ml/dlib/examples/surf_ex.cpp create mode 100644 ml/dlib/examples/svm_c_ex.cpp create mode 100644 ml/dlib/examples/svm_ex.cpp create mode 100644 ml/dlib/examples/svm_pegasos_ex.cpp create mode 100644 ml/dlib/examples/svm_rank_ex.cpp create mode 100644 ml/dlib/examples/svm_sparse_ex.cpp create mode 100644 ml/dlib/examples/svm_struct_ex.cpp create mode 100644 ml/dlib/examples/svr_ex.cpp create mode 100644 ml/dlib/examples/thread_function_ex.cpp create mode 100644 ml/dlib/examples/thread_pool_ex.cpp create mode 100644 ml/dlib/examples/threaded_object_ex.cpp create mode 100644 ml/dlib/examples/threads_ex.cpp create mode 100644 ml/dlib/examples/timer_ex.cpp create mode 100644 ml/dlib/examples/train_object_detector.cpp create mode 100644 ml/dlib/examples/train_shape_predictor_ex.cpp create mode 100644 ml/dlib/examples/using_custom_kernels_ex.cpp create mode 100644 ml/dlib/examples/video_frames/frame_000100.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000101.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000102.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000103.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000104.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000105.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000106.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000107.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000108.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000109.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000110.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000111.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000112.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000113.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000114.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000115.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000116.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000117.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000118.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000119.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000120.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000121.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000122.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000123.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000124.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000125.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000126.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000127.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000128.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000129.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000130.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000131.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000132.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000133.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000134.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000135.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000136.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000137.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000138.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000139.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000140.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000141.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000142.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000143.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000144.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000145.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000146.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000147.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000148.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000149.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000150.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000151.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000152.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000153.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000154.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000155.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000156.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000157.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000158.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000159.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000160.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000161.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000162.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000163.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000164.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000165.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000166.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000167.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000168.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000169.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000170.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000171.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000172.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000173.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000174.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000175.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000176.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000177.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000178.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000179.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000180.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000181.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000182.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000183.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000184.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000185.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000186.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000187.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000188.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000189.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000190.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000191.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000192.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000193.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000194.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000195.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000196.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000197.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000198.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000199.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000200.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000201.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000202.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000203.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000204.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000205.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000206.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000207.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000208.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000209.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000210.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000211.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000212.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000213.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000214.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000215.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000216.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000217.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000218.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000219.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000220.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000221.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000222.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000223.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000224.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000225.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000226.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000227.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000228.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000229.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000230.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000231.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000232.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000233.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000234.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000235.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000236.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000237.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000238.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000239.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000240.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000241.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000242.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000243.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000244.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000245.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000246.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000247.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000248.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000249.jpg create mode 100644 ml/dlib/examples/video_frames/frame_000250.jpg create mode 100644 ml/dlib/examples/video_frames/license.txt create mode 100644 ml/dlib/examples/video_tracking_ex.cpp create mode 100644 ml/dlib/examples/webcam_face_pose_ex.cpp create mode 100644 ml/dlib/examples/xml_parser_ex.cpp create mode 100644 ml/dlib/python_examples/LICENSE_FOR_EXAMPLE_PROGRAMS.txt create mode 100755 ml/dlib/python_examples/cnn_face_detector.py create mode 100755 ml/dlib/python_examples/correlation_tracker.py create mode 100755 ml/dlib/python_examples/face_alignment.py create mode 100755 ml/dlib/python_examples/face_clustering.py create mode 100755 ml/dlib/python_examples/face_detector.py create mode 100755 ml/dlib/python_examples/face_jitter.py create mode 100755 ml/dlib/python_examples/face_landmark_detection.py create mode 100755 ml/dlib/python_examples/face_recognition.py create mode 100755 ml/dlib/python_examples/find_candidate_object_locations.py create mode 100755 ml/dlib/python_examples/global_optimization.py create mode 100755 ml/dlib/python_examples/max_cost_assignment.py create mode 100644 ml/dlib/python_examples/requirements.txt create mode 100755 ml/dlib/python_examples/sequence_segmenter.py create mode 100755 ml/dlib/python_examples/svm_binary_classifier.py create mode 100755 ml/dlib/python_examples/svm_rank.py create mode 100755 ml/dlib/python_examples/svm_struct.py create mode 100755 ml/dlib/python_examples/train_object_detector.py create mode 100755 ml/dlib/python_examples/train_shape_predictor.py create mode 100644 ml/dlib/setup.py create mode 100644 ml/dlib/tools/archive/train_face_5point_model.cpp create mode 100644 ml/dlib/tools/convert_dlib_nets_to_caffe/CMakeLists.txt create mode 100644 ml/dlib/tools/convert_dlib_nets_to_caffe/main.cpp create mode 100755 ml/dlib/tools/convert_dlib_nets_to_caffe/running_a_dlib_model_with_caffe_example.py create mode 100644 ml/dlib/tools/htmlify/CMakeLists.txt create mode 100644 ml/dlib/tools/htmlify/htmlify.cpp create mode 100644 ml/dlib/tools/htmlify/to_xml.cpp create mode 100644 ml/dlib/tools/htmlify/to_xml.h create mode 100644 ml/dlib/tools/htmlify/to_xml_example/bigminus.gif create mode 100644 ml/dlib/tools/htmlify/to_xml_example/bigplus.gif create mode 100644 ml/dlib/tools/htmlify/to_xml_example/example.xml create mode 100644 ml/dlib/tools/htmlify/to_xml_example/minus.gif create mode 100644 ml/dlib/tools/htmlify/to_xml_example/output.xml create mode 100644 ml/dlib/tools/htmlify/to_xml_example/plus.gif create mode 100644 ml/dlib/tools/htmlify/to_xml_example/stylesheet.xsl create mode 100644 ml/dlib/tools/htmlify/to_xml_example/test.cpp create mode 100644 ml/dlib/tools/imglab/CMakeLists.txt create mode 100644 ml/dlib/tools/imglab/README.txt create mode 100755 ml/dlib/tools/imglab/convert_imglab_paths_to_relative create mode 100755 ml/dlib/tools/imglab/copy_imglab_dataset create mode 100644 ml/dlib/tools/imglab/src/cluster.cpp create mode 100644 ml/dlib/tools/imglab/src/cluster.h create mode 100644 ml/dlib/tools/imglab/src/common.cpp create mode 100644 ml/dlib/tools/imglab/src/common.h create mode 100644 ml/dlib/tools/imglab/src/convert_idl.cpp create mode 100644 ml/dlib/tools/imglab/src/convert_idl.h create mode 100644 ml/dlib/tools/imglab/src/convert_pascal_v1.cpp create mode 100644 ml/dlib/tools/imglab/src/convert_pascal_v1.h create mode 100644 ml/dlib/tools/imglab/src/convert_pascal_xml.cpp create mode 100644 ml/dlib/tools/imglab/src/convert_pascal_xml.h create mode 100644 ml/dlib/tools/imglab/src/flip_dataset.cpp create mode 100644 ml/dlib/tools/imglab/src/flip_dataset.h create mode 100644 ml/dlib/tools/imglab/src/main.cpp create mode 100644 ml/dlib/tools/imglab/src/metadata_editor.cpp create mode 100644 ml/dlib/tools/imglab/src/metadata_editor.h create mode 100644 ml/dlib/tools/python/CMakeLists.txt create mode 100644 ml/dlib/tools/python/src/basic.cpp create mode 100644 ml/dlib/tools/python/src/cca.cpp create mode 100644 ml/dlib/tools/python/src/cnn_face_detector.cpp create mode 100644 ml/dlib/tools/python/src/conversion.h create mode 100644 ml/dlib/tools/python/src/correlation_tracker.cpp create mode 100644 ml/dlib/tools/python/src/decision_functions.cpp create mode 100644 ml/dlib/tools/python/src/dlib.cpp create mode 100644 ml/dlib/tools/python/src/face_recognition.cpp create mode 100644 ml/dlib/tools/python/src/global_optimization.cpp create mode 100644 ml/dlib/tools/python/src/gui.cpp create mode 100644 ml/dlib/tools/python/src/image.cpp create mode 100644 ml/dlib/tools/python/src/image_dataset_metadata.cpp create mode 100644 ml/dlib/tools/python/src/indexing.h create mode 100644 ml/dlib/tools/python/src/matrix.cpp create mode 100644 ml/dlib/tools/python/src/numpy_returns.cpp create mode 100644 ml/dlib/tools/python/src/numpy_returns_stub.cpp create mode 100644 ml/dlib/tools/python/src/object_detection.cpp create mode 100644 ml/dlib/tools/python/src/opaque_types.h create mode 100644 ml/dlib/tools/python/src/other.cpp create mode 100644 ml/dlib/tools/python/src/rectangles.cpp create mode 100644 ml/dlib/tools/python/src/sequence_segmenter.cpp create mode 100644 ml/dlib/tools/python/src/serialize_object_detector.h create mode 100644 ml/dlib/tools/python/src/shape_predictor.cpp create mode 100644 ml/dlib/tools/python/src/shape_predictor.h create mode 100644 ml/dlib/tools/python/src/simple_object_detector.h create mode 100644 ml/dlib/tools/python/src/simple_object_detector_py.h create mode 100644 ml/dlib/tools/python/src/svm_c_trainer.cpp create mode 100644 ml/dlib/tools/python/src/svm_rank_trainer.cpp create mode 100644 ml/dlib/tools/python/src/svm_struct.cpp create mode 100644 ml/dlib/tools/python/src/testing_results.h create mode 100644 ml/dlib/tools/python/src/vector.cpp create mode 100644 ml/dlib/tools/python/test/.gitignore create mode 100644 ml/dlib/tools/python/test/test_array.py create mode 100644 ml/dlib/tools/python/test/test_global_optimization.py create mode 100644 ml/dlib/tools/python/test/test_matrix.py create mode 100644 ml/dlib/tools/python/test/test_point.py create mode 100644 ml/dlib/tools/python/test/test_range.py create mode 100644 ml/dlib/tools/python/test/test_rgb_pixel.py create mode 100644 ml/dlib/tools/python/test/test_sparse_vector.py create mode 100644 ml/dlib/tools/python/test/test_svm_c_trainer.py create mode 100644 ml/dlib/tools/python/test/test_vector.py create mode 100644 ml/dlib/tools/visual_studio_natvis/README.txt create mode 100644 ml/dlib/tools/visual_studio_natvis/dlib.natvis (limited to 'ml') diff --git a/ml/Config.cc b/ml/Config.cc index 8f2ef894c..c6a750995 100644 --- a/ml/Config.cc +++ b/ml/Config.cc @@ -51,7 +51,7 @@ void ml_config_load(ml_config_t *cfg) { size_t suppression_window = config_get_number(config_section_ml, "dimension anomaly rate suppression window", 900); size_t suppression_threshold = config_get_number(config_section_ml, "dimension anomaly rate suppression threshold", suppression_window / 2); - bool enable_statistics_charts = config_get_boolean(config_section_ml, "enable statistics charts", true); + bool enable_statistics_charts = config_get_boolean(config_section_ml, "enable statistics charts", false); /* * Clamp diff --git a/ml/dlib/.gitignore b/ml/dlib/.gitignore new file mode 100644 index 000000000..235fff575 --- /dev/null +++ b/ml/dlib/.gitignore @@ -0,0 +1,9 @@ +**/.idea +*~ +*.swp +*.o +*.so +build +dist +*.egg-info/ + diff --git a/ml/dlib/.hgignore b/ml/dlib/.hgignore new file mode 100644 index 000000000..463303389 --- /dev/null +++ b/ml/dlib/.hgignore @@ -0,0 +1,43 @@ +/build/ +/build2/ +/build_clang/ +\.swp$ +\.swo$ +\.o$ +\.a$ +\.so$ +\.orig$ +\.obj$ +\.pyc$ +^build/ +^dist/ +^\.cache/ +^\.eggs/ +^dlib\.egg-info/ +^docs/release/ +^docs/docs/web/ +^docs/docs/chm/ +^docs/docs/cache/ +^docs/docs/log.txt$ +^docs/docs/old_log.txt$ +^dlib/test/debug.txt$ +^dlib/test/test$ +^dlib/test/makefile.bak$ +python_examples/.*.so$ +python_examples/.*.pyd$ +python_examples/.*.dll$ +python_examples/.*.lib$ +docs/docs/python/classes.txt +docs/docs/python/functions.txt +syntax: glob +dlib/test/build64/* +*.svm +dlib/test/build_python +dlib/test/test_log.txt +dlib/test/build_vc2015_64/* +dlib/test/build_vc2013_64/* +dlib/test/build_vc2012_64/* +dlib/java/libmyproject.so +dlib/java/myproject.jar +dlib/java/swig_test.class + diff --git a/ml/dlib/.hgtags b/ml/dlib/.hgtags new file mode 100644 index 000000000..e9932b1e0 --- /dev/null +++ b/ml/dlib/.hgtags @@ -0,0 +1,41 @@ +7f7ffcda900ae3c8de0a562cc7e9d3f51e523a39 v17.39 +43cbb1c92eaeba09bd0e9c72bb0783044492e651 v17.40 +f6c79ee4083449640e45cc7e48ef2e7636687d90 v17.41 +e516e232b94254db264e34b61458e0eabd78bea8 v17.42 +43d280b34aa9dd68c62208d72f3b8e48ed047684 v17.43 +9ed76e89b6644208d4536b3396c7664445bd4520 v17.44 +d20f5ce805d53b5fec9d5dd250159d28480041cf v17.45 +cfbd4102b1ee25ca3531819abbbe93971d6dc65a v17.46 +f58ff52144bddbccb5a372e6c4befb4b6ad98021 v17.47 +b90ab60d8a18fc8616fb895b0d059c1cf07db3ce v17.48 +4312a45be8b45634f8ffafcb2b5a425bd426642b v17.49 +df60c7686f3982791e218977edb64d638151ca3b v18.0 +8d0762ab49b9ee8d1d3fc1fc02926c6bde6d5542 v18.1 +5de237bc41c1c2e63d0731731eee231a4213a31b v18.2 +7f21bd92812d2d08fe9a881a401cbf0f6b104081 v18.3 +78be73b57b829adb20a452ad910f7039a09c9474 v18.4 +3026cfaf82c5a878b2c81ee4d6445237d453e372 v18.5 +6a929c3ad782f17e34d364665e6e277e3ff99912 v18.6 +5a3fb1f81041978948e6148a1f56ab56cf678c69 v18.7 +a6c2b16111b8023dbded7299dcc7e6acd26671b8 v18.8 +4de62892e10850e8f0205b4857cf48b31fd730c8 v18.9 +5a14394843c04628990857e5db94ff6bc43c2da0 v18.10 +dd8e950033d5026373acce9ed4b2ffb85908d3b5 v18.11 +4e3941b13ca859f788853cfcef9973ac4b161e65 v18.12 +67c3ad208aae9537cf16f64936dd62e2210caa96 v18.13 +cae7fcc9e6a9b28b44a703e4598f44286fec734d v18.14 +feaff82884ded598bde93c635eb3ded9c0933a07 v18.15 +42a25c606cf924a8d41d2fc96b9c85f839d21a04 v18.16 +ce6f364987865b19bdb1b4730ac5403e2bb55dc4 v18.17 +7ae1775f61a44b7f07866050b50ad3ade581f019 v18.18 +4d6b102506bb9e2f195c7ddf984cc2d86b8643e7 before_dnn_serialization_cleanup +7210589728f6d83f6cb7d21cd24d114a5364d9e2 v19.0 +ad6cd2a3bfd54d48f6b4a1b6d3ef8c0ce278a8d9 v19.1 +f8fa027c760270d8122427838b89e95ccf0b80a1 v19.2 +26cdc89f4795a1f924d80d208b9ed22437c01600 v19.3 +74c4985dfb28f1b91286ab38f35bc026326ec995 v19.4 +9121e039950df93507fdf27bbb102bb5bc1ab429 v19.5 +3eaa0e35b1b4b912897b664abd78a23cc7705c9b v19.6 +fb51c77524ff13ca58b4846f8778a38f35b6f986 v19.7 +fef491c3b8182c68df0a80fa727ccb5929a45821 v19.8 +0cbf133b31c13f665fc32227952315844a66ff85 v19.9 diff --git a/ml/dlib/.travis.yml b/ml/dlib/.travis.yml new file mode 100644 index 000000000..4604fbaac --- /dev/null +++ b/ml/dlib/.travis.yml @@ -0,0 +1,107 @@ +sudo: required + +matrix: + include: + ################### + - language: cpp + compiler: clang + os: linux + env: + - VARIANT=test + script: + - dlib/travis/build-and-test.sh + + ################### + - language: cpp + compiler: clang + os: linux + env: + - VARIANT=examples + script: + - dlib/travis/build-and-test.sh + + ################### + - language: cpp + compiler: gcc + os: linux + env: + - VARIANT=test + script: + - dlib/travis/build-and-test.sh + + ################### + - language: cpp + compiler: gcc + os: linux + env: + - VARIANT=tools + script: + - dlib/travis/build-and-test.sh + + ################### + - language: cpp + compiler: gcc + os: linux + env: + - VARIANT=dlib_all_source_cpp + script: + - dlib/travis/build-and-test.sh + + ########### test with C++17 ######## + - language: cpp + compiler: gcc + os: linux + env: + - VARIANT=test + - CXXFLAGS=-std=c++17 + # Need to set MATRIX_EVAL to set CC and CXX env vars. You would + # think you could just set them in the env area like any other, but + # travis is wonky about CC and CXX vars so you have to do it this way. + - MATRIX_EVAL="CC=gcc-7 && CXX=g++-7" + addons: + apt: + sources: + - ubuntu-toolchain-r-test + packages: + - g++-7 + script: + - dlib/travis/build-and-test.sh + + ################### + - language: cpp + compiler: gcc + os: linux + env: + - VARIANT=examples + script: + - dlib/travis/build-and-test.sh + + ################### + - language: python + python: 2.7 + env: + - VARIANT=python-api + script: + - dlib/travis/build-and-test.sh + + ################### + - language: python + python: 3.5 + env: + - VARIANT=python-api + script: + - dlib/travis/build-and-test.sh + + ################### + # # Disabled because travis's OS X machines take hours (or days) to begin + # running. Or maybe they are just broken entirely. Who knows. + #- language: cpp + # os: osx + # osx_image: xcode9.2 + # env: + # - VARIANT=test + # script: + # - dlib/travis/build-and-test.sh + + + diff --git a/ml/dlib/CMakeLists.txt b/ml/dlib/CMakeLists.txt new file mode 100644 index 000000000..d3cf123f6 --- /dev/null +++ b/ml/dlib/CMakeLists.txt @@ -0,0 +1,36 @@ +cmake_minimum_required(VERSION 2.8.12) + + + + +############################################################################# +# # +# READ examples/CMakeLists.txt TO SEE HOW TO USE DLIB FROM C++ WITH CMAKE # +# # +############################################################################# + + + + + +get_directory_property(has_parent PARENT_DIRECTORY) +if(NOT has_parent) + # When you call add_subdirectory(dlib) from a parent CMake project dlib's + # CMake scripts will assume you want to statically compile dlib into + # whatever you are building rather than create a standalone copy of dlib. + # This means CMake will build dlib as a static library, disable dlib's + # install targets so they don't clutter your project, and adjust a few other + # minor things that are convenient when statically building dlib as part of + # your own projects. + # + # On the other hand, if there is no parent CMake project or if + # DLIB_IN_PROJECT_BUILD is set to false, CMake will compile dlib as a normal + # standalone library (either shared or static, based on the state of CMake's + # BUILD_SHARED_LIBS flag), and include the usual install targets so you can + # install dlib on your computer via `make install`. Since the only reason + # to build this CMakeLists.txt (the one you are reading right now) by itself + # is if you want to install dlib, we indicate as such by setting + # DLIB_IN_PROJECT_BUILD to false. + set(DLIB_IN_PROJECT_BUILD false) +endif() +add_subdirectory(dlib) diff --git a/ml/dlib/ISSUE_TEMPLATE.md b/ml/dlib/ISSUE_TEMPLATE.md new file mode 100644 index 000000000..faea67768 --- /dev/null +++ b/ml/dlib/ISSUE_TEMPLATE.md @@ -0,0 +1,31 @@ +IF YOU ARE REPORTING A BUG OR PROBLEM WITH DLIB THEN FILL OUT THE ENTIRE TEMPLATE BELOW. ISSUES ASKING QUESTIONS ABOUT WHY SOMETHING DOESN'T WORK THAT FAIL TO FILL OUT THE ENTIRE TEMPLATE WILL BE CLOSED. + +It is OK to suggest interesting improvements to dlib, even if you are not volunteering to implement them. **However, the issue tracker is not a code writing service, do not ask for someone to write code for you.** E.g. Do not ask for feature improvements to the example programs. **If there is some feature improvement you want in an example program then it's up to you to write it**. + +Before you ask a question, check Google for a solution, [the dlib FAQ](http://dlib.net/faq.html), or consult the dlib documentation. Every single function in dlib is documented in detail. If you obviously haven't read the documentation your issue will be closed. + +If you aren't reporting a bug or problem with dlib then delete this template and write whatever you want here. + + + + + + + + +## Expected Behavior + + +## Current Behavior + + +## Steps to Reproduce + + + + +* **Version**: +* **Where did you get dlib**: +* **Platform**: +* **Compiler**: diff --git a/ml/dlib/MANIFEST.in b/ml/dlib/MANIFEST.in new file mode 100644 index 000000000..9ef7fe6d0 --- /dev/null +++ b/ml/dlib/MANIFEST.in @@ -0,0 +1,15 @@ +# +# MANIFEST.in +# +# Manifest template for creating the dlib source distribution. + +include MANIFEST.in +include setup.py +include README.md + +# sources +recursive-include dlib ** +recursive-include python_examples *.txt *.py +recursive-include tools/python ** + + diff --git a/ml/dlib/README.md b/ml/dlib/README.md new file mode 100644 index 000000000..d9103d757 --- /dev/null +++ b/ml/dlib/README.md @@ -0,0 +1,71 @@ +# dlib C++ library [![Travis Status](https://travis-ci.org/davisking/dlib.svg?branch=master)](https://travis-ci.org/davisking/dlib) + +Dlib is a modern C++ toolkit containing machine learning algorithms and tools for creating complex software in C++ to solve real world problems. See [http://dlib.net](http://dlib.net) for the main project documentation and API reference. + + + +## Compiling dlib C++ example programs + +Go into the examples folder and type: + +```bash +mkdir build; cd build; cmake .. ; cmake --build . +``` + +That will build all the examples. +If you have a CPU that supports AVX instructions then turn them on like this: + +```bash +mkdir build; cd build; cmake .. -DUSE_AVX_INSTRUCTIONS=1; cmake --build . +``` + +Doing so will make some things run faster. + +Finally, Visual Studio users should usually do everything in 64bit mode. By default Visual Studio is 32bit, both in its outputs and its own execution, so you have to explicitly tell it to use 64bits. Since it's not the 1990s anymore you probably want to use 64bits. Do that with a cmake invocation like this: +```bash +cmake .. -G "Visual Studio 14 2015 Win64" -T host=x64 +``` + +## Compiling your own C++ programs that use dlib + +The examples folder has a [CMake tutorial](https://github.com/davisking/dlib/blob/master/examples/CMakeLists.txt) that tells you what to do. There are also additional instructions on the [dlib web site](http://dlib.net/compile.html). + +## Compiling dlib Python API + +Before you can run the Python example programs you must compile dlib. Type: + +```bash +python setup.py install +``` + +or type + +```bash +python setup.py install --yes USE_AVX_INSTRUCTIONS +``` + +if you have a CPU that supports AVX instructions, since this makes some things run faster. + + + +## Running the unit test suite + +Type the following to compile and run the dlib unit test suite: + +```bash +cd dlib/test +mkdir build +cd build +cmake .. +cmake --build . --config Release +./dtest --runall +``` + +Note that on windows your compiler might put the test executable in a subfolder called `Release`. If that's the case then you have to go to that folder before running the test. + +This library is licensed under the Boost Software License, which can be found in [dlib/LICENSE.txt](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt). The long and short of the license is that you can use dlib however you like, even in closed source commercial software. + +## dlib sponsors + +This research is based in part upon work supported by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA) under contract number 2014-14071600010. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of ODNI, IARPA, or the U.S. Government. + diff --git a/ml/dlib/dlib/CMakeLists.txt b/ml/dlib/dlib/CMakeLists.txt new file mode 100644 index 000000000..15208064c --- /dev/null +++ b/ml/dlib/dlib/CMakeLists.txt @@ -0,0 +1,841 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + + +cmake_minimum_required(VERSION 2.8.12) +project(dlib) + + +include(cmake_utils/set_compiler_specific_options.cmake) + + +# Adhere to GNU filesystem layout conventions +include(GNUInstallDirs) + +# default to a Release build (except if CMAKE_BUILD_TYPE is set) +include(cmake_utils/release_build_by_default) +include(cmake_utils/use_cpp_11.cmake) + + +set(CPACK_PACKAGE_VERSION_MAJOR "19") +set(CPACK_PACKAGE_VERSION_MINOR "10") +set(CPACK_PACKAGE_VERSION_PATCH "0") +set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH}) +# Set DLIB_VERSION in the including CMake file so they can use it to do whatever they want. +get_directory_property(has_parent PARENT_DIRECTORY) +if(has_parent) + set(DLIB_VERSION ${VERSION} PARENT_SCOPE) + if (NOT DEFINED DLIB_IN_PROJECT_BUILD) + set(DLIB_IN_PROJECT_BUILD true) + endif() +endif() + + +if (DLIB_IN_PROJECT_BUILD) + # DLIB_IN_PROJECT_BUILD==true means you are using dlib by invoking + # add_subdirectory(dlib) in the parent project. In this case, we always want + # to build dlib as a static library so the parent project doesn't need to + # deal with some random dlib shared library file. It is much better to + # statically compile dlib into the parent project. So the following bit of + # CMake ensures that happens. However, we have to take care to compile dlib + # with position independent code if appropriate (i.e. if the parent project + # is a shared library). + if (BUILD_SHARED_LIBS) + if (CMAKE_COMPILER_IS_GNUCXX) + # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set + # -fPIC for GCC but sometimes it still doesn't get set, so make sure it + # does. + add_definitions("-fPIC") + endif() + set(CMAKE_POSITION_INDEPENDENT_CODE true) + endif() + + # Tell cmake to build dlib as a static library + set(BUILD_SHARED_LIBS false) +endif() + + +if (CMAKE_VERSION VERSION_LESS "3.9.0") + # Set only because there are old target_link_libraries() statements in the + # FindCUDA.cmake file that comes with CMake that error out if the new behavior + # is used. In newer versions of CMake we can instead set CUDA_LINK_LIBRARIES_KEYWORD which fixes this issue. + cmake_policy(SET CMP0023 OLD) +else() + set(CUDA_LINK_LIBRARIES_KEYWORD PUBLIC) +endif() + + +macro (enable_preprocessor_switch option_name) + list(APPEND active_preprocessor_switches "-D${option_name}") +endmacro() + +macro (disable_preprocessor_switch option_name) + if (active_preprocessor_switches) + list(REMOVE_ITEM active_preprocessor_switches "-D${option_name}") + endif() +endmacro() + +macro (toggle_preprocessor_switch option_name) + if (${option_name}) + enable_preprocessor_switch(${option_name}) + else() + disable_preprocessor_switch(${option_name}) + endif() +endmacro() + + + +# Suppress superfluous randlib warnings about libdlib.a having no symbols on MacOSX. +if (APPLE) + set(CMAKE_C_ARCHIVE_CREATE " Scr ") + set(CMAKE_CXX_ARCHIVE_CREATE " Scr ") + set(CMAKE_C_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") + set(CMAKE_CXX_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") +endif() + +# Don't try to call add_library(dlib) and setup dlib's stuff if it has already +# been done by some other part of the current cmake project. We do this +# because it avoids getting warnings/errors about cmake policy CMP0002. This +# happens when a project tries to call add_subdirectory() on dlib more than +# once. This most often happens when the top level of a project depends on two +# or more other things which both depend on dlib. +if (NOT TARGET dlib) + + set (DLIB_ISO_CPP_ONLY_STR + "Enable this if you don't want to compile any non-ISO C++ code (i.e. you don't use any of the API Wrappers)" ) + set (DLIB_NO_GUI_SUPPORT_STR + "Enable this if you don't want to compile any of the dlib GUI code" ) + set (DLIB_ENABLE_STACK_TRACE_STR + "Enable this if you want to turn on the DLIB_STACK_TRACE macros" ) + set (DLIB_USE_BLAS_STR + "Disable this if you don't want to use a BLAS library" ) + set (DLIB_USE_LAPACK_STR + "Disable this if you don't want to use a LAPACK library" ) + set (DLIB_USE_CUDA_STR + "Disable this if you don't want to use NVIDIA CUDA" ) + set (DLIB_PNG_SUPPORT_STR + "Disable this if you don't want to link against libpng" ) + set (DLIB_GIF_SUPPORT_STR + "Disable this if you don't want to link against libgif" ) + set (DLIB_JPEG_SUPPORT_STR + "Disable this if you don't want to link against libjpeg" ) + set (DLIB_LINK_WITH_SQLITE3_STR + "Disable this if you don't want to link against sqlite3" ) + #set (DLIB_USE_FFTW_STR "Disable this if you don't want to link against fftw" ) + set (DLIB_USE_MKL_FFT_STR + "Disable this is you don't want to use the MKL DFTI FFT implementation" ) + set (DLIB_ENABLE_ASSERTS_STR + "Enable this if you want to turn on the DLIB_ASSERT macro" ) + + + option(DLIB_ENABLE_ASSERTS ${DLIB_ENABLE_ASSERTS_STR} OFF) + option(DLIB_ISO_CPP_ONLY ${DLIB_ISO_CPP_ONLY_STR} OFF) + toggle_preprocessor_switch(DLIB_ISO_CPP_ONLY) + option(DLIB_NO_GUI_SUPPORT ${DLIB_NO_GUI_SUPPORT_STR} OFF) + toggle_preprocessor_switch(DLIB_NO_GUI_SUPPORT) + option(DLIB_ENABLE_STACK_TRACE ${DLIB_ENABLE_STACK_TRACE_STR} OFF) + toggle_preprocessor_switch(DLIB_ENABLE_STACK_TRACE) + + if(DLIB_ENABLE_ASSERTS) + # Set these variables so they are set in the config.h.in file when dlib + # is installed. + set (DLIB_DISABLE_ASSERTS false) + set (ENABLE_ASSERTS true) + enable_preprocessor_switch(ENABLE_ASSERTS) + disable_preprocessor_switch(DLIB_DISABLE_ASSERTS) + else() + # Set these variables so they are set in the config.h.in file when dlib + # is installed. + set (DLIB_DISABLE_ASSERTS true) + set (ENABLE_ASSERTS false) + disable_preprocessor_switch(ENABLE_ASSERTS) + # Never force the asserts off when doing an in project build. The only + # time this matters is when using visual studio. The visual studio IDE + # has a drop down that lets the user select either release or debug + # builds. The DLIB_ASSERT macro is setup to enable/disable automatically + # based on this drop down (via preprocessor magic). However, if + # DLIB_DISABLE_ASSERTS is defined it permanently disables asserts no + # matter what, which would defeat the visual studio drop down. So here + # we make a point to not do that kind of severe disabling when in a + # project build. It should also be pointed out that DLIB_DISABLE_ASSERTS + # is only needed when building and installing dlib as a separately + # installed library. It doesn't matter when doing an in project build. + if (NOT DLIB_IN_PROJECT_BUILD) + enable_preprocessor_switch(DLIB_DISABLE_ASSERTS) + endif() + endif() + + if (DLIB_ISO_CPP_ONLY) + option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} OFF) + option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} OFF) + option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} OFF) + option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} OFF) + option(DLIB_USE_CUDA ${DLIB_USE_CUDA_STR} OFF) + option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} OFF) + option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} OFF) + #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} OFF) + option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} OFF) + else() + option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} ON) + option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} ON) + option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} ON) + option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} ON) + option(DLIB_USE_CUDA ${DLIB_USE_CUDA_STR} ON) + option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} ON) + option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} ON) + #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} ON) + option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} ON) + endif() + toggle_preprocessor_switch(DLIB_JPEG_SUPPORT) + toggle_preprocessor_switch(DLIB_USE_BLAS) + toggle_preprocessor_switch(DLIB_USE_LAPACK) + toggle_preprocessor_switch(DLIB_USE_CUDA) + toggle_preprocessor_switch(DLIB_PNG_SUPPORT) + toggle_preprocessor_switch(DLIB_GIF_SUPPORT) + #toggle_preprocessor_switch(DLIB_USE_FFTW) + toggle_preprocessor_switch(DLIB_USE_MKL_FFT) + + + set(source_files + base64/base64_kernel_1.cpp + bigint/bigint_kernel_1.cpp + bigint/bigint_kernel_2.cpp + bit_stream/bit_stream_kernel_1.cpp + entropy_decoder/entropy_decoder_kernel_1.cpp + entropy_decoder/entropy_decoder_kernel_2.cpp + entropy_encoder/entropy_encoder_kernel_1.cpp + entropy_encoder/entropy_encoder_kernel_2.cpp + md5/md5_kernel_1.cpp + tokenizer/tokenizer_kernel_1.cpp + unicode/unicode.cpp + data_io/image_dataset_metadata.cpp + data_io/mnist.cpp + global_optimization/global_function_search.cpp + filtering/kalman_filter.cpp + test_for_odr_violations.cpp + ) + + + set(dlib_needed_libraries) + set(dlib_needed_includes) + + if (DLIB_ISO_CPP_ONLY) + add_library(dlib ${source_files} ) + else() + + set(source_files ${source_files} + sockets/sockets_kernel_1.cpp + bsp/bsp.cpp + dir_nav/dir_nav_kernel_1.cpp + dir_nav/dir_nav_kernel_2.cpp + dir_nav/dir_nav_extensions.cpp + linker/linker_kernel_1.cpp + logger/extra_logger_headers.cpp + logger/logger_kernel_1.cpp + logger/logger_config_file.cpp + misc_api/misc_api_kernel_1.cpp + misc_api/misc_api_kernel_2.cpp + sockets/sockets_extensions.cpp + sockets/sockets_kernel_2.cpp + sockstreambuf/sockstreambuf.cpp + sockstreambuf/sockstreambuf_unbuffered.cpp + server/server_kernel.cpp + server/server_iostream.cpp + server/server_http.cpp + threads/multithreaded_object_extension.cpp + threads/threaded_object_extension.cpp + threads/threads_kernel_1.cpp + threads/threads_kernel_2.cpp + threads/threads_kernel_shared.cpp + threads/thread_pool_extension.cpp + threads/async.cpp + timer/timer.cpp + stack_trace.cpp + dnn/cpu_dlib.cpp + dnn/tensor_tools.cpp + ) + + if(UNIX) + set(CMAKE_THREAD_PREFER_PTHREAD ON) + find_package(Threads REQUIRED) + set(dlib_needed_libraries ${dlib_needed_libraries} ${CMAKE_THREAD_LIBS_INIT}) + endif() + + # we want to link to the right stuff depending on our platform. + if (WIN32 AND NOT CYGWIN) ############################################################################### + if (DLIB_NO_GUI_SUPPORT) + set (dlib_needed_libraries ws2_32 winmm) + else() + set (dlib_needed_libraries ws2_32 winmm comctl32 gdi32 imm32) + endif() + elseif(APPLE) ############################################################################ + set(CMAKE_MACOSX_RPATH 1) + if (NOT DLIB_NO_GUI_SUPPORT) + find_package(X11 QUIET) + if (X11_FOUND) + # If both X11 and anaconda are installed, it's possible for the + # anaconda path to appear before /opt/X11, so we remove anaconda. + foreach (ITR ${X11_INCLUDE_DIR}) + if ("${ITR}" MATCHES "(.*)(Ana|ana|mini)conda(.*)") + list (REMOVE_ITEM X11_INCLUDE_DIR ${ITR}) + endif () + endforeach(ITR) + include_directories(${X11_INCLUDE_DIR}) + set (dlib_needed_libraries ${dlib_needed_libraries} ${X11_LIBRARIES}) + else() + find_library(xlib X11) + # Make sure X11 is in the include path. Note that we look for + # Xlocale.h rather than Xlib.h because it avoids finding a partial + # copy of the X11 headers on systems with anaconda installed. + find_path(xlib_path Xlocale.h + PATHS + /Developer/SDKs/MacOSX10.4u.sdk/usr/X11R6/include + /opt/local/include + PATH_SUFFIXES X11 + ) + if (xlib AND xlib_path) + get_filename_component(x11_path ${xlib_path} PATH CACHE) + include_directories(${x11_path}) + set(dlib_needed_libraries ${dlib_needed_libraries} ${xlib} ) + set(X11_FOUND 1) + endif() + endif() + if (NOT X11_FOUND) + message(" *****************************************************************************") + message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***") + message(" *** Make sure XQuartz is installed if you want GUI support. ***") + message(" *** You can download XQuartz from: https://www.xquartz.org/ ***") + message(" *****************************************************************************") + set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE ) + enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT) + endif() + endif() + + mark_as_advanced(pthreadlib xlib xlib_path x11_path) + else () ################################################################################## + # link to the nsl library if it exists. this is something you need sometimes + find_library(nsllib nsl) + if (nsllib) + set (dlib_needed_libraries ${dlib_needed_libraries} ${nsllib}) + endif () + + # link to the socket library if it exists. this is something you need on solaris + find_library(socketlib socket) + if (socketlib) + set (dlib_needed_libraries ${dlib_needed_libraries} ${socketlib}) + endif () + + if (NOT DLIB_NO_GUI_SUPPORT) + include(FindX11) + if (X11_FOUND) + include_directories(${X11_INCLUDE_DIR}) + set (dlib_needed_libraries ${dlib_needed_libraries} ${X11_LIBRARIES}) + else() + message(" *****************************************************************************") + message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***") + message(" *** Make sure libx11-dev is installed if you want GUI support. ***") + message(" *** On Ubuntu run: sudo apt-get install libx11-dev ***") + message(" *****************************************************************************") + set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE ) + enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT) + endif() + endif() + + mark_as_advanced(nsllib pthreadlib socketlib) + endif () ################################################################################## + + if (NOT DLIB_NO_GUI_SUPPORT) + set(source_files ${source_files} + gui_widgets/fonts.cpp + gui_widgets/widgets.cpp + gui_widgets/drawable.cpp + gui_widgets/canvas_drawing.cpp + gui_widgets/style.cpp + gui_widgets/base_widgets.cpp + gui_core/gui_core_kernel_1.cpp + gui_core/gui_core_kernel_2.cpp + ) + endif() + + INCLUDE (CheckFunctionExists) + + if (DLIB_GIF_SUPPORT) + find_package(GIF QUIET) + if (GIF_FOUND) + set (dlib_needed_includes ${dlib_needed_includes} ${GIF_INCLUDE_DIR}) + set (dlib_needed_libraries ${dlib_needed_libraries} ${GIF_LIBRARY}) + else() + set(DLIB_GIF_SUPPORT OFF CACHE STRING ${DLIB_GIF_SUPPORT_STR} FORCE ) + toggle_preprocessor_switch(DLIB_GIF_SUPPORT) + endif() + endif() + + if (DLIB_PNG_SUPPORT) + # try to find libpng + find_package(PNG QUIET) + # Make sure there isn't something wrong with the version of LIBPNG + # installed on this system. + if (PNG_FOUND) + set(CMAKE_REQUIRED_LIBRARIES ${PNG_LIBRARIES}) + CHECK_FUNCTION_EXISTS(png_create_read_struct LIBPNG_IS_GOOD) + endif() + if (PNG_FOUND AND LIBPNG_IS_GOOD) + include_directories(${PNG_INCLUDE_DIR}) + set (dlib_needed_libraries ${dlib_needed_libraries} ${PNG_LIBRARIES}) + set(REQUIRES_LIBS " libpng") + else() + # If we can't find libpng then statically compile it in. + include_directories(external/libpng external/zlib) + set(source_files ${source_files} + external/libpng/png.c + external/libpng/pngerror.c + external/libpng/pngget.c + external/libpng/pngmem.c + external/libpng/pngpread.c + external/libpng/pngread.c + external/libpng/pngrio.c + external/libpng/pngrtran.c + external/libpng/pngrutil.c + external/libpng/pngset.c + external/libpng/pngtrans.c + external/libpng/pngwio.c + external/libpng/pngwrite.c + external/libpng/pngwtran.c + external/libpng/pngwutil.c + external/zlib/adler32.c + external/zlib/compress.c + external/zlib/crc32.c + external/zlib/deflate.c + external/zlib/gzclose.c + external/zlib/gzlib.c + external/zlib/gzread.c + external/zlib/gzwrite.c + external/zlib/infback.c + external/zlib/inffast.c + external/zlib/inflate.c + external/zlib/inftrees.c + external/zlib/trees.c + external/zlib/uncompr.c + external/zlib/zutil.c + ) + + include(cmake_utils/check_if_neon_available.cmake) + if (ARM_NEON_IS_AVAILABLE) + message (STATUS "NEON instructions will be used for libpng.") + enable_language(ASM) + set(source_files ${source_files} + external/libpng/arm/arm_init.c + external/libpng/arm/filter_neon_intrinsics.c + external/libpng/arm/filter_neon.S + ) + set_source_files_properties(external/libpng/arm/filter_neon.S PROPERTIES COMPILE_FLAGS "${CMAKE_ASM_FLAGS} ${CMAKE_CXX_FLAGS} -x assembler-with-cpp") + endif() + set(REQUIRES_LIBS "") + endif() + set(source_files ${source_files} + image_loader/png_loader.cpp + image_saver/save_png.cpp + ) + endif() + + if (DLIB_JPEG_SUPPORT) + # try to find libjpeg + find_package(JPEG QUIET) + # Make sure there isn't something wrong with the version of libjpeg + # installed on this system. Also don't use the installed libjpeg + # if this is an APPLE system because apparently it's broken (as of 2015/01/01). + if (JPEG_FOUND AND NOT ("${JPEG_INCLUDE_DIR}" MATCHES "(.*)(Ana|ana|mini)conda(.*)")) + set(CMAKE_REQUIRED_LIBRARIES ${JPEG_LIBRARY}) + CHECK_FUNCTION_EXISTS(jpeg_read_header LIBJPEG_IS_GOOD) + endif() + if (JPEG_FOUND AND LIBJPEG_IS_GOOD AND NOT APPLE) + include_directories(${JPEG_INCLUDE_DIR}) + set (dlib_needed_libraries ${dlib_needed_libraries} ${JPEG_LIBRARY}) + else() + # If we can't find libjpeg then statically compile it in. + add_definitions(-DDLIB_JPEG_STATIC) + set(source_files ${source_files} + external/libjpeg/jcomapi.cpp + external/libjpeg/jdapimin.cpp + external/libjpeg/jdapistd.cpp + external/libjpeg/jdatasrc.cpp + external/libjpeg/jdcoefct.cpp + external/libjpeg/jdcolor.cpp + external/libjpeg/jddctmgr.cpp + external/libjpeg/jdhuff.cpp + external/libjpeg/jdinput.cpp + external/libjpeg/jdmainct.cpp + external/libjpeg/jdmarker.cpp + external/libjpeg/jdmaster.cpp + external/libjpeg/jdmerge.cpp + external/libjpeg/jdphuff.cpp + external/libjpeg/jdpostct.cpp + external/libjpeg/jdsample.cpp + external/libjpeg/jerror.cpp + external/libjpeg/jidctflt.cpp + external/libjpeg/jidctfst.cpp + external/libjpeg/jidctint.cpp + external/libjpeg/jidctred.cpp + external/libjpeg/jmemmgr.cpp + external/libjpeg/jmemnobs.cpp + external/libjpeg/jquant1.cpp + external/libjpeg/jquant2.cpp + external/libjpeg/jutils.cpp + external/libjpeg/jcapimin.cpp + external/libjpeg/jdatadst.cpp + external/libjpeg/jcparam.cpp + external/libjpeg/jcapistd.cpp + external/libjpeg/jcmarker.cpp + external/libjpeg/jcinit.cpp + external/libjpeg/jcmaster.cpp + external/libjpeg/jcdctmgr.cpp + external/libjpeg/jccoefct.cpp + external/libjpeg/jccolor.cpp + external/libjpeg/jchuff.cpp + external/libjpeg/jcmainct.cpp + external/libjpeg/jcphuff.cpp + external/libjpeg/jcprepct.cpp + external/libjpeg/jcsample.cpp + external/libjpeg/jfdctint.cpp + external/libjpeg/jfdctflt.cpp + external/libjpeg/jfdctfst.cpp + ) + endif() + set(source_files ${source_files} + image_loader/jpeg_loader.cpp + image_saver/save_jpeg.cpp + ) + endif() + + + if (DLIB_USE_BLAS OR DLIB_USE_LAPACK OR DLIB_USE_MKL_FFT) + # Try to find BLAS, LAPACK and MKL + include(cmake_utils/find_blas.cmake) + + if (DLIB_USE_BLAS) + if (blas_found) + set (dlib_needed_libraries ${dlib_needed_libraries} ${blas_libraries}) + else() + set(DLIB_USE_BLAS OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE ) + toggle_preprocessor_switch(DLIB_USE_BLAS) + endif() + endif() + + if (DLIB_USE_LAPACK) + if (lapack_found) + set (dlib_needed_libraries ${dlib_needed_libraries} ${lapack_libraries}) + if (lapack_with_underscore) + set(LAPACK_FORCE_UNDERSCORE 1) + enable_preprocessor_switch(LAPACK_FORCE_UNDERSCORE) + elseif (lapack_without_underscore) + set(LAPACK_FORCE_NOUNDERSCORE 1) + enable_preprocessor_switch(LAPACK_FORCE_NOUNDERSCORE) + endif () + else() + set(DLIB_USE_LAPACK OFF CACHE STRING ${DLIB_USE_LAPACK_STR} FORCE ) + toggle_preprocessor_switch(DLIB_USE_LAPACK) + endif() + endif() + + if (DLIB_USE_MKL_FFT) + if (found_intel_mkl AND found_intel_mkl_headers) + set (dlib_needed_includes ${dlib_needed_includes} ${mkl_include_dir}) + set (dlib_needed_libraries ${dlib_needed_libraries} ${mkl_libraries}) + else() + set(DLIB_USE_MKL_FFT OFF CACHE STRING ${DLIB_USE_MKL_FFT_STR} FORCE ) + toggle_preprocessor_switch(DLIB_USE_MKL_FFT) + endif() + endif() + endif() + + + if (DLIB_USE_CUDA) + find_package(CUDA 7.5) + + if (CUDA_FOUND AND MSVC AND NOT CUDA_CUBLAS_LIBRARIES AND "${CMAKE_SIZEOF_VOID_P}" EQUAL "4") + message(WARNING "You have CUDA installed, but we can't use it unless you put visual studio in 64bit mode.") + set(CUDA_FOUND 0) + endif() + + if (CUDA_FOUND AND (NOT USING_OLD_VISUAL_STUDIO_COMPILER)) + + # There is some bug in cmake that causes it to mess up the + # -std=c++11 option if you let it propagate it to nvcc in some + # cases. So instead we disable this and manually include + # things from CMAKE_CXX_FLAGS in the CUDA_NVCC_FLAGS list below. + if (APPLE) + set(CUDA_PROPAGATE_HOST_FLAGS OFF) + # Grab all the -D flags from CMAKE_CXX_FLAGS so we can pass them + # to nvcc. + string(REGEX MATCHALL "-D[^ ]*" FLAGS_FOR_NVCC "${CMAKE_CXX_FLAGS}") + endif() + + + set(CUDA_HOST_COMPILATION_CPP ON) + # Note that we add __STRICT_ANSI__ to avoid freaking out nvcc with gcc specific + # magic in the standard C++ header files (since nvcc uses gcc headers on + # linux). + list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-D__STRICT_ANSI__;-D_MWAITXINTRIN_H_INCLUDED;-D_FORCE_INLINES;${FLAGS_FOR_NVCC}") + list(APPEND CUDA_NVCC_FLAGS ${active_preprocessor_switches}) + if (NOT MSVC) + list(APPEND CUDA_NVCC_FLAGS "-std=c++11") + endif() + if (CMAKE_POSITION_INDEPENDENT_CODE) + # sometimes this setting isn't propagated to NVCC, which then causes the + # compile to fail. So make sure it's propagated. + if (NOT MSVC) # Visual studio doesn't have -fPIC so don't do it in that case. + list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") + endif() + endif() + + include(cmake_utils/test_for_cudnn/find_cudnn.txt) + + if (cudnn AND cudnn_include AND NOT DEFINED cuda_test_compile_worked AND NOT DEFINED cudnn_test_compile_worked) + # make sure cuda is really working by doing a test compile + message(STATUS "Building a CUDA test project to see if your compiler is compatible with CUDA...") + + set(CUDA_TEST_CMAKE_FLAGS + "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" + "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" + "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") + + if (NOT MSVC) # see https://github.com/davisking/dlib/issues/363 + list(APPEND CUDA_TEST_CMAKE_FLAGS "-DCUDA_HOST_COMPILER=${CUDA_HOST_COMPILER}") + endif() + + try_compile(cuda_test_compile_worked + ${PROJECT_BINARY_DIR}/cuda_test_build + ${PROJECT_SOURCE_DIR}/cmake_utils/test_for_cuda cuda_test + CMAKE_FLAGS ${CUDA_TEST_CMAKE_FLAGS} + OUTPUT_VARIABLE try_compile_output_message + ) + if (NOT cuda_test_compile_worked) + string(REPLACE "\n" "\n *** " try_compile_output_message "${try_compile_output_message}") + message(STATUS "*****************************************************************************************************************") + message(STATUS "*** CUDA was found but your compiler failed to compile a simple CUDA program so dlib isn't going to use CUDA. ") + message(STATUS "*** The output of the failed CUDA test compile is shown below: ") + message(STATUS "*** ${try_compile_output_message}") + message(STATUS "*****************************************************************************************************************") + else() + message(STATUS "Checking if you have the right version of cuDNN installed.") + try_compile(cudnn_test_compile_worked + ${PROJECT_BINARY_DIR}/cudnn_test_build + ${PROJECT_SOURCE_DIR}/cmake_utils/test_for_cudnn cudnn_test + CMAKE_FLAGS ${CUDA_TEST_CMAKE_FLAGS} + ) + if (NOT cudnn_test_compile_worked) + message(STATUS "*** Found cuDNN, but it looks like the wrong version so dlib will not use it. ***") + message(STATUS "*** Dlib requires cuDNN V5.0 OR GREATER. Since cuDNN is not found DLIB WILL NOT USE CUDA. ***") + message(STATUS "*** If you have cuDNN then set CMAKE_PREFIX_PATH to include cuDNN's folder. ***") + endif() + endif() + endif() + + # Find where cuSOLVER is since the FindCUDA cmake package doesn't + # bother to look for it. + get_filename_component(cuda_blas_path "${CUDA_CUBLAS_LIBRARIES}" DIRECTORY) + find_library(cusolver cusolver HINTS ${cuda_blas_path}) + mark_as_advanced(cusolver) + # Also find OpenMP since cuSOLVER needs it. Importantly, we only + # look for one to link to if our use of BLAS, specifically the + # Intel MKL, hasn't already decided what to use. This is because + # it makes the MKL bug out if you link to another openmp lib other + # than Intel's when you use the MKL. + if (NOT openmp_libraries AND NOT MSVC AND NOT XCODE) + find_package(OpenMP) + if (OPENMP_FOUND) + set(openmp_libraries ${OpenMP_CXX_FLAGS}) + else() + message(STATUS "*** Didn't find OpenMP, which is required to use CUDA. ***") + set(CUDA_FOUND 0) + endif() + endif() + endif() + + if (CUDA_FOUND AND cudnn AND (NOT USING_OLD_VISUAL_STUDIO_COMPILER) AND cuda_test_compile_worked AND cudnn_test_compile_worked AND cudnn_include) + set(source_files ${source_files} + dnn/cuda_dlib.cu + dnn/cudnn_dlibapi.cpp + dnn/cublas_dlibapi.cpp + dnn/cusolver_dlibapi.cu + dnn/curand_dlibapi.cpp + dnn/cuda_data_ptr.cpp + dnn/gpu_data.cpp + ) + set(dlib_needed_libraries ${dlib_needed_libraries} + ${CUDA_CUBLAS_LIBRARIES} + ${cudnn} + ${CUDA_curand_LIBRARY} + ${cusolver} + ) + if(openmp_libraries) + list(APPEND dlib_needed_libraries ${openmp_libraries}) + endif() + + include_directories(${cudnn_include}) + message(STATUS "Enabling CUDA support for dlib. DLIB WILL USE CUDA") + else() + set(DLIB_USE_CUDA OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE ) + toggle_preprocessor_switch(DLIB_USE_CUDA) + if (USING_OLD_VISUAL_STUDIO_COMPILER) + message(STATUS "*** Dlib CUDA support requires C++11 but your compiler doesn't support it. ***") + endif() + message(STATUS "Disabling CUDA support for dlib. DLIB WILL NOT USE CUDA") + endif() + endif() + + + if (DLIB_LINK_WITH_SQLITE3) + find_library(sqlite sqlite3) + # make sure sqlite3.h is in the include path + find_path(sqlite_path sqlite3.h) + if (sqlite AND sqlite_path) + set(dlib_needed_includes ${dlib_needed_includes} ${sqlite_path}) + set(dlib_needed_libraries ${dlib_needed_libraries} ${sqlite} ) + else() + set(DLIB_LINK_WITH_SQLITE3 OFF CACHE STRING ${DLIB_LINK_WITH_SQLITE3_STR} FORCE ) + endif() + mark_as_advanced(sqlite sqlite_path) + endif() + + + + if (DLIB_USE_FFTW) + find_library(fftw fftw3) + # make sure fftw3.h is in the include path + find_path(fftw_path fftw3.h) + if (fftw AND fftw_path) + set(dlib_needed_includes ${dlib_needed_includes} ${fftw_path}) + set(dlib_needed_libraries ${dlib_needed_libraries} ${fftw} ) + else() + set(DLIB_USE_FFTW OFF CACHE STRING ${DLIB_USE_FFTW_STR} FORCE ) + toggle_preprocessor_switch(DLIB_USE_FFTW) + endif() + mark_as_advanced(fftw fftw_path) + endif() + + + + # Tell CMake to build dlib via add_library()/cuda_add_library() + if (DLIB_USE_CUDA) + # The old cuda_add_library() command doesn't support CMake's newer dependency + # stuff, so we have to set the include path manually still, which we do here. + include_directories(${dlib_needed_includes}) + cuda_add_library(dlib ${source_files} ) + else() + add_library(dlib ${source_files} ) + endif() + + endif () ##### end of if NOT DLIB_ISO_CPP_ONLY ########################################################## + + + target_include_directories(dlib + INTERFACE $ + INTERFACE $ + PUBLIC ${dlib_needed_includes} + ) + target_link_libraries(dlib PUBLIC ${dlib_needed_libraries}) + if (DLIB_IN_PROJECT_BUILD) + target_compile_options(dlib PUBLIC ${active_preprocessor_switches}) + else() + # These are private in this case because they will be controlled by the + # contents of dlib/config.h once it's installed. But for in project + # builds, there is no real config.h so they are public in the above case. + target_compile_options(dlib PRIVATE ${active_preprocessor_switches}) + # Do this so that dlib/config.h won't set DLIB_NOT_CONFIGURED. This will then allow + # the code in dlib/threads_kernel_shared.cpp to emit a linker error for users who + # don't use the configured config.h file generated by cmake. + target_compile_options(dlib PRIVATE -DDLIB__CMAKE_GENERATED_A_CONFIG_H_FILE) + + # Do this so that dlib/config.h can record the version of dlib it's configured with + # and ultimately issue a linker error to people who try to use a binary dlib that is + # the wrong version. + set(DLIB_CHECK_FOR_VERSION_MISMATCH + DLIB_VERSION_MISMATCH_CHECK__EXPECTED_VERSION_${CPACK_PACKAGE_VERSION_MAJOR}_${CPACK_PACKAGE_VERSION_MINOR}_${CPACK_PACKAGE_VERSION_PATCH}) + target_compile_options(dlib PRIVATE "-DDLIB_CHECK_FOR_VERSION_MISMATCH=${DLIB_CHECK_FOR_VERSION_MISMATCH}") + endif() + + + # Allow the unit tests to ask us to compile the all/source.cpp file just to make sure it compiles. + if (DLIB_TEST_COMPILE_ALL_SOURCE_CPP) + add_library(dlib_all_source_cpp STATIC all/source.cpp) + target_link_libraries(dlib_all_source_cpp dlib) + target_compile_options(dlib_all_source_cpp PUBLIC ${active_preprocessor_switches}) + enable_cpp11_for_target(dlib_all_source_cpp) + endif() + + if (TARGET dlib) + enable_cpp11_for_target(dlib) + target_compile_options(dlib PUBLIC ${active_compile_opts}) + endif() + + # Install the library + if (NOT DLIB_IN_PROJECT_BUILD) + set_target_properties(dlib PROPERTIES + VERSION ${VERSION}) + install(TARGETS dlib + EXPORT dlib + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} # Windows considers .dll to be runtime artifacts + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) + + install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib + FILES_MATCHING PATTERN "*.h" PATTERN "*.cmake" + REGEX "${CMAKE_CURRENT_BINARY_DIR}" EXCLUDE) + + + configure_file(${PROJECT_SOURCE_DIR}/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/config.h) + # overwrite config.h with the configured one + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/config.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib) + + configure_file(${PROJECT_SOURCE_DIR}/revision.h.in ${CMAKE_CURRENT_BINARY_DIR}/revision.h) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/revision.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib) + + ## Config.cmake generation and installation + + set(ConfigPackageLocation "${CMAKE_INSTALL_LIBDIR}/cmake/dlib") + install(EXPORT dlib + NAMESPACE dlib:: + DESTINATION ${ConfigPackageLocation}) + + configure_file(cmake_utils/dlibConfig.cmake.in "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake" @ONLY) + + include(CMakePackageConfigHelpers) + write_basic_package_version_file( + "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake" + VERSION ${VERSION} + COMPATIBILITY AnyNewerVersion + ) + + install(FILES + "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake" + DESTINATION ${ConfigPackageLocation}) + + ## dlib-1.pc generation and installation + + configure_file("cmake_utils/dlib.pc.in" "dlib-1.pc" @ONLY) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/dlib-1.pc" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") + + endif() + +endif() + +if (MSVC) + # Give the output library files names that are unique functions of the + # visual studio mode that compiled them. We do this so that people who + # compile dlib and then copy the .lib files around (which they shouldn't be + # doing in the first place!) will hopefully be slightly less confused by + # what happens since, at the very least, the filenames will indicate what + # visual studio runtime they go with. + math(EXPR numbits ${CMAKE_SIZEOF_VOID_P}*8) + set_target_properties(dlib PROPERTIES DEBUG_POSTFIX "${VERSION}_debug_${numbits}bit_msvc${MSVC_VERSION}") + set_target_properties(dlib PROPERTIES RELEASE_POSTFIX "${VERSION}_release_${numbits}bit_msvc${MSVC_VERSION}") + set_target_properties(dlib PROPERTIES MINSIZEREL_POSTFIX "${VERSION}_minsizerel_${numbits}bit_msvc${MSVC_VERSION}") + set_target_properties(dlib PROPERTIES RELWITHDEBINFO_POSTFIX "${VERSION}_relwithdebinfo_${numbits}bit_msvc${MSVC_VERSION}") +endif() + +add_library(dlib::dlib ALIAS dlib) diff --git a/ml/dlib/dlib/LICENSE.txt b/ml/dlib/dlib/LICENSE.txt new file mode 100644 index 000000000..127a5bc39 --- /dev/null +++ b/ml/dlib/dlib/LICENSE.txt @@ -0,0 +1,23 @@ +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/ml/dlib/dlib/algs.h b/ml/dlib/dlib/algs.h new file mode 100644 index 000000000..d0f74b1f2 --- /dev/null +++ b/ml/dlib/dlib/algs.h @@ -0,0 +1,1157 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_ALGs_ +#define DLIB_ALGs_ + +// this file contains miscellaneous stuff + +// Give people who forget the -std=c++11 option a reminder +#if (defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))) || \ + (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4) || (__clang_major__ >= 3))) + #if __cplusplus < 201103 + #error "Dlib requires C++11 support. Give your compiler the -std=c++11 option to enable it." + #endif +#endif + +#if defined __NVCC__ + // Disable the "statement is unreachable" message since it will go off on code that is + // actually reachable but just happens to not be reachable sometimes during certain + // template instantiations. + #pragma diag_suppress code_is_unreachable +#endif + + +#ifdef _MSC_VER + +#if _MSC_VER < 1900 +#error "dlib versions newer than v19.1 use C++11 and therefore require Visual Studio 2015 or newer." +#endif + +// Disable the following warnings for Visual Studio + +// this is to disable the "'this' : used in base member initializer list" +// warning you get from some of the GUI objects since all the objects +// require that their parent class be passed into their constructor. +// In this case though it is totally safe so it is ok to disable this warning. +#pragma warning(disable : 4355) + +// This is a warning you get sometimes when Visual Studio performs a Koenig Lookup. +// This is a bug in visual studio. It is a totally legitimate thing to +// expect from a compiler. +#pragma warning(disable : 4675) + +// This is a warning you get from visual studio 2005 about things in the standard C++ +// library being "deprecated." I checked the C++ standard and it doesn't say jack +// about any of them (I checked the searchable PDF). So this warning is total Bunk. +#pragma warning(disable : 4996) + +// This is a warning you get from visual studio 2003: +// warning C4345: behavior change: an object of POD type constructed with an initializer +// of the form () will be default-initialized. +// I love it when this compiler gives warnings about bugs in previous versions of itself. +#pragma warning(disable : 4345) + + +// Disable warnings about conversion from size_t to unsigned long and long. +#pragma warning(disable : 4267) + +// Disable warnings about conversion from double to float +#pragma warning(disable : 4244) +#pragma warning(disable : 4305) + +// Disable "warning C4180: qualifier applied to function type has no meaning; ignored". +// This warning happens often in generic code that works with functions and isn't useful. +#pragma warning(disable : 4180) + +// Disable "warning C4290: C++ exception specification ignored except to indicate a function is not __declspec(nothrow)" +#pragma warning(disable : 4290) + + +// DNN module uses template-based network declaration that leads to very long +// type names. Visual Studio will produce Warning C4503 in such cases. https://msdn.microsoft.com/en-us/library/074af4b6.aspx says +// that correct binaries are still produced even when this warning happens, but linker errors from visual studio, if they occurr could be confusing. +#pragma warning( disable: 4503 ) + + +#endif + +#ifdef __BORLANDC__ +// Disable the following warnings for the Borland Compilers +// +// These warnings just say that the compiler is refusing to inline functions with +// loops or try blocks in them. +// +#pragma option -w-8027 +#pragma option -w-8026 +#endif + +#include // for the exceptions + +#ifdef __CYGWIN__ +namespace std +{ + typedef std::basic_string wstring; +} +#endif + +#include "platform.h" +#include "windows_magic.h" + + +#include // for std::swap +#include // for std::bad_alloc +#include +#include // for std::numeric_limits for is_finite() +#include "assert.h" +#include "error.h" +#include "noncopyable.h" +#include "enable_if.h" +#include "uintn.h" +#include "numeric_constants.h" +#include "memory_manager_stateless/memory_manager_stateless_kernel_1.h" // for the default memory manager + + + +// ---------------------------------------------------------------------------------------- +/*!A _dT !*/ + +template +inline charT _dTcast (const char a, const wchar_t b); +template <> +inline char _dTcast (const char a, const wchar_t ) { return a; } +template <> +inline wchar_t _dTcast (const char , const wchar_t b) { return b; } + +template +inline const charT* _dTcast ( const char* a, const wchar_t* b); +template <> +inline const char* _dTcast ( const char* a, const wchar_t* ) { return a; } +template <> +inline const wchar_t* _dTcast ( const char* , const wchar_t* b) { return b; } + + +#define _dT(charT,str) _dTcast(str,L##str) +/*! + requires + - charT == char or wchar_t + - str == a string or character literal + ensures + - returns the literal in the form of a charT type literal. +!*/ + +// ---------------------------------------------------------------------------------------- + + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*!A default_memory_manager + + This memory manager just calls new and delete directly. + + !*/ + typedef memory_manager_stateless_kernel_1 default_memory_manager; + +// ---------------------------------------------------------------------------------------- + + /*!A swap !*/ + // make swap available in the dlib namespace + using std::swap; + +// ---------------------------------------------------------------------------------------- + + /*! + Here is where I define my return codes. It is + important that they all be < 0. + !*/ + + enum general_return_codes + { + TIMEOUT = -1, + WOULDBLOCK = -2, + OTHER_ERROR = -3, + SHUTDOWN = -4, + PORTINUSE = -5 + }; + +// ---------------------------------------------------------------------------------------- + + inline unsigned long square_root ( + unsigned long value + ) + /*! + requires + - value <= 2^32 - 1 + ensures + - returns the square root of value. if the square root is not an + integer then it will be rounded up to the nearest integer. + !*/ + { + unsigned long x; + + // set the initial guess for what the root is depending on + // how big value is + if (value < 3) + return value; + else if (value < 4096) // 12 + x = 45; + else if (value < 65536) // 16 + x = 179; + else if (value < 1048576) // 20 + x = 717; + else if (value < 16777216) // 24 + x = 2867; + else if (value < 268435456) // 28 + x = 11469; + else // 32 + x = 45875; + + + + // find the root + x = (x + value/x)>>1; + x = (x + value/x)>>1; + x = (x + value/x)>>1; + x = (x + value/x)>>1; + + + + if (x*x < value) + return x+1; + else + return x; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void median ( + T& one, + T& two, + T& three + ); + /*! + requires + - T implements operator< + - T is swappable by a global swap() + ensures + - #one is the median + - #one, #two, and #three is some permutation of one, two, and three. + !*/ + + + template < + typename T + > + void median ( + T& one, + T& two, + T& three + ) + { + using std::swap; + using dlib::swap; + + if ( one < two ) + { + // one < two + if ( two < three ) + { + // one < two < three : two + swap(one,two); + + } + else + { + // one < two >= three + if ( one < three) + { + // three + swap(three,one); + } + } + + } + else + { + // one >= two + if ( three < one ) + { + // three <= one >= two + if ( three < two ) + { + // two + swap(two,one); + } + else + { + // three + swap(three,one); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + namespace relational_operators + { + template < + typename A, + typename B + > + constexpr bool operator> ( + const A& a, + const B& b + ) { return b < a; } + + // --------------------------------- + + template < + typename A, + typename B + > + constexpr bool operator!= ( + const A& a, + const B& b + ) { return !(a == b); } + + // --------------------------------- + + template < + typename A, + typename B + > + constexpr bool operator<= ( + const A& a, + const B& b + ) { return !(b < a); } + + // --------------------------------- + + template < + typename A, + typename B + > + constexpr bool operator>= ( + const A& a, + const B& b + ) { return !(a < b); } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void exchange ( + T& a, + T& b + ) + /*! + This function does the exact same thing that global swap does and it does it by + just calling swap. But a lot of compilers have problems doing a Koenig Lookup + and the fact that this has a different name (global swap has the same name as + the member functions called swap) makes them compile right. + + So this is a workaround but not too ugly of one. But hopefully I get get + rid of this in a few years. So this function is already deprecated. + + This also means you should NOT use this function in your own code unless + you have to support an old buggy compiler that benefits from this hack. + !*/ + { + using std::swap; + using dlib::swap; + swap(a,b); + } + +// ---------------------------------------------------------------------------------------- + + /*!A is_pointer_type + + This is a template where is_pointer_type::value == true when T is a pointer + type and false otherwise. + !*/ + + template < + typename T + > + class is_pointer_type + { + public: + enum { value = false }; + private: + is_pointer_type(); + }; + + template < + typename T + > + class is_pointer_type + { + public: + enum { value = true }; + private: + is_pointer_type(); + }; + +// ---------------------------------------------------------------------------------------- + + /*!A is_const_type + + This is a template where is_const_type::value == true when T is a const + type and false otherwise. + !*/ + + template + struct is_const_type + { + static const bool value = false; + }; + template + struct is_const_type + { + static const bool value = true; + }; + template + struct is_const_type + { + static const bool value = true; + }; + +// ---------------------------------------------------------------------------------------- + + /*!A is_reference_type + + This is a template where is_reference_type::value == true when T is a reference + type and false otherwise. + !*/ + + template + struct is_reference_type + { + static const bool value = false; + }; + + template struct is_reference_type { static const bool value = true; }; + template struct is_reference_type { static const bool value = true; }; + +// ---------------------------------------------------------------------------------------- + + /*!A is_same_type + + This is a template where is_same_type::value == true when T and U are the + same type and false otherwise. + !*/ + + template < + typename T, + typename U + > + class is_same_type + { + public: + enum {value = false}; + private: + is_same_type(); + }; + + template + class is_same_type + { + public: + enum {value = true}; + private: + is_same_type(); + }; + +// ---------------------------------------------------------------------------------------- + + /*!A is_float_type + + This is a template that can be used to determine if a type is one of the built + int floating point types (i.e. float, double, or long double). + !*/ + + template < typename T > struct is_float_type { const static bool value = false; }; + template <> struct is_float_type { const static bool value = true; }; + template <> struct is_float_type { const static bool value = true; }; + template <> struct is_float_type { const static bool value = true; }; + +// ---------------------------------------------------------------------------------------- + + /*!A is_convertible + + This is a template that can be used to determine if one type is convertible + into another type. + + For example: + is_convertible::value == true // because ints are convertible to floats + is_convertible::value == false // because int pointers are NOT convertible to floats + !*/ + + template + struct is_convertible + { + struct yes_type { char a; }; + struct no_type { yes_type a[2]; }; + static const from& from_helper(); + static yes_type test(to); + static no_type test(...); + const static bool value = sizeof(test(from_helper())) == sizeof(yes_type); + }; + +// ---------------------------------------------------------------------------------------- + + struct general_ {}; + struct special_ : general_ {}; + template struct int_ { typedef int type; }; + +// ---------------------------------------------------------------------------------------- + + + /*!A is_same_object + + This is a templated function which checks if both of its arguments are actually + references to the same object. It returns true if they are and false otherwise. + + !*/ + + // handle the case where T and U are unrelated types. + template < typename T, typename U > + typename disable_if_c::value || is_convertible::value, bool>::type + is_same_object ( + const T& a, + const U& b + ) + { + return ((void*)&a == (void*)&b); + } + + // handle the case where T and U are related types because their pointers can be + // implicitly converted into one or the other. E.g. a derived class and its base class. + // Or where both T and U are just the same type. This way we make sure that if there is a + // valid way to convert between these two pointer types then we will take that route rather + // than the void* approach used otherwise. + template < typename T, typename U > + typename enable_if_c::value || is_convertible::value, bool>::type + is_same_object ( + const T& a, + const U& b + ) + { + return (&a == &b); + } + +// ---------------------------------------------------------------------------------------- + + /*!A is_unsigned_type + + This is a template where is_unsigned_type::value == true when T is an unsigned + scalar type and false when T is a signed scalar type. + !*/ + template < + typename T + > + struct is_unsigned_type + { + static const bool value = static_cast((static_cast(0)-static_cast(1))) > 0; + }; + template <> struct is_unsigned_type { static const bool value = false; }; + template <> struct is_unsigned_type { static const bool value = false; }; + template <> struct is_unsigned_type { static const bool value = false; }; + +// ---------------------------------------------------------------------------------------- + + /*!A is_signed_type + + This is a template where is_signed_type::value == true when T is a signed + scalar type and false when T is an unsigned scalar type. + !*/ + template < + typename T + > + struct is_signed_type + { + static const bool value = !is_unsigned_type::value; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class copy_functor + { + public: + void operator() ( + const T& source, + T& destination + ) const + { + destination = source; + } + }; + +// ---------------------------------------------------------------------------------------- + + /*!A static_switch + + To use this template you give it some number of boolean expressions and it + tells you which one of them is true. If more than one of them is true then + it causes a compile time error. + + for example: + static_switch<1 + 1 == 2, 4 - 1 == 4>::value == 1 // because the first expression is true + static_switch<1 + 1 == 3, 4 == 4>::value == 2 // because the second expression is true + static_switch<1 + 1 == 3, 4 == 5>::value == 0 // 0 here because none of them are true + static_switch<1 + 1 == 2, 4 == 4>::value == compiler error // because more than one expression is true + !*/ + + template < bool v1 = 0, bool v2 = 0, bool v3 = 0, bool v4 = 0, bool v5 = 0, + bool v6 = 0, bool v7 = 0, bool v8 = 0, bool v9 = 0, bool v10 = 0, + bool v11 = 0, bool v12 = 0, bool v13 = 0, bool v14 = 0, bool v15 = 0 > + struct static_switch; + + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 0; }; + template <> struct static_switch<1,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 1; }; + template <> struct static_switch<0,1,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 2; }; + template <> struct static_switch<0,0,1,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 3; }; + template <> struct static_switch<0,0,0,1,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 4; }; + template <> struct static_switch<0,0,0,0,1,0,0,0,0,0,0,0,0,0,0> { const static int value = 5; }; + template <> struct static_switch<0,0,0,0,0,1,0,0,0,0,0,0,0,0,0> { const static int value = 6; }; + template <> struct static_switch<0,0,0,0,0,0,1,0,0,0,0,0,0,0,0> { const static int value = 7; }; + template <> struct static_switch<0,0,0,0,0,0,0,1,0,0,0,0,0,0,0> { const static int value = 8; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,1,0,0,0,0,0,0> { const static int value = 9; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,1,0,0,0,0,0> { const static int value = 10; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,1,0,0,0,0> { const static int value = 11; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,1,0,0,0> { const static int value = 12; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,1,0,0> { const static int value = 13; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,1,0> { const static int value = 14; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,1> { const static int value = 15; }; + +// ---------------------------------------------------------------------------------------- + /*!A is_built_in_scalar_type + + This is a template that allows you to determine if the given type is a built + in scalar type such as an int, char, float, short, etc. + + For example, is_built_in_scalar_type::value == true + For example, is_built_in_scalar_type::value == false + !*/ + + template struct is_built_in_scalar_type { const static bool value = false; }; + + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + template <> struct is_built_in_scalar_type { const static bool value = true; }; + // Don't define one for wchar_t when using a version of visual studio + // older than 8.0 (visual studio 2005) since before then they improperly set + // wchar_t to be a typedef rather than its own type as required by the C++ + // standard. +#if !defined(_MSC_VER) || _NATIVE_WCHAR_T_DEFINED + template <> struct is_built_in_scalar_type { const static bool value = true; }; +#endif + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename enable_if,bool>::type is_finite ( + const T& value + ) + /*! + requires + - value must be some kind of scalar type such as int or double + ensures + - returns true if value is a finite value (e.g. not infinity or NaN) and false + otherwise. + !*/ + { + if (is_float_type::value) + return -std::numeric_limits::infinity() < value && value < std::numeric_limits::infinity(); + else + return true; + } + +// ---------------------------------------------------------------------------------------- + + /*!A promote + + This is a template that takes one of the built in scalar types and gives you another + scalar type that should be big enough to hold sums of values from the original scalar + type. The new scalar type will also always be signed. + + For example, promote::type == int32 + !*/ + + template struct promote; + template struct promote { typedef int32 type; }; + template struct promote { typedef int32 type; }; + template struct promote { typedef int64 type; }; + template struct promote { typedef int64 type; }; + + template <> struct promote { typedef double type; }; + template <> struct promote { typedef double type; }; + template <> struct promote { typedef long double type; }; + +// ---------------------------------------------------------------------------------------- + + /*!A assign_zero_if_built_in_scalar_type + + This function assigns its argument the value of 0 if it is a built in scalar + type according to the is_built_in_scalar_type<> template. If it isn't a + built in scalar type then it does nothing. + !*/ + + template inline typename disable_if,void>::type assign_zero_if_built_in_scalar_type (T&){} + template inline typename enable_if,void>::type assign_zero_if_built_in_scalar_type (T& a){a=0;} + +// ---------------------------------------------------------------------------------------- + + /*!A basic_type + + This is a template that takes a type and strips off any const, volatile, or reference + qualifiers and gives you back the basic underlying type. So for example: + + basic_type::type == int + !*/ + + template struct basic_type { typedef T type; }; + template struct basic_type { typedef T type; }; + template struct basic_type { typedef T type; }; + template struct basic_type { typedef T type; }; + template struct basic_type { typedef T type; }; + template struct basic_type { typedef T type; }; + template struct basic_type { typedef T type; }; + template struct basic_type { typedef T type; }; + +// ---------------------------------------------------------------------------------------- + + template + T put_in_range ( + const T& a, + const T& b, + const T& val + ) + /*! + requires + - T is a type that looks like double, float, int, or so forth + ensures + - if (val is within the range [a,b]) then + - returns val + - else + - returns the end of the range [a,b] that is closest to val + !*/ + { + if (a < b) + { + if (val < a) + return a; + else if (val > b) + return b; + } + else + { + if (val < b) + return b; + else if (val > a) + return a; + } + + return val; + } + + // overload for double + inline double put_in_range(const double& a, const double& b, const double& val) + { return put_in_range(a,b,val); } + +// ---------------------------------------------------------------------------------------- + + /*!A tabs + + This is a template to compute the absolute value a number at compile time. + + For example, + abs<-4>::value == 4 + abs<4>::value == 4 + !*/ + + template + struct tabs { const static long value = x; }; + template + struct tabs::type> { const static long value = -x; }; + +// ---------------------------------------------------------------------------------------- + + /*!A tmax + + This is a template to compute the max of two values at compile time + + For example, + abs<4,7>::value == 7 + !*/ + + template + struct tmax { const static long value = x; }; + template + struct tmax x)>::type> { const static long value = y; }; + +// ---------------------------------------------------------------------------------------- + + /*!A tmin + + This is a template to compute the min of two values at compile time + + For example, + abs<4,7>::value == 4 + !*/ + + template + struct tmin { const static long value = x; }; + template + struct tmin::type> { const static long value = y; }; + +// ---------------------------------------------------------------------------------------- + +#define DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(testname, returnT, funct_name, args) \ + struct _two_bytes_##testname { char a[2]; }; \ + template < typename T, returnT (T::*funct)args > \ + struct _helper_##testname { typedef char type; }; \ + template \ + static char _has_##testname##_helper( typename _helper_##testname::type ) { return 0;} \ + template \ + static _two_bytes_##testname _has_##testname##_helper(int) { return _two_bytes_##testname();} \ + template struct _##testname##workaroundbug { \ + const static unsigned long U = sizeof(_has_##testname##_helper('a')); }; \ + template ::U > \ + struct testname { static const bool value = false; }; \ + template \ + struct testname { static const bool value = true; }; + /*!A DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST + + The DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST() macro is used to define traits templates + that tell you if a class has a certain member function. For example, to make a + test to see if a class has a public method with the signature void print(int) you + would say: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, (int)) + + Then you can check if a class, T, has this method by looking at the boolean value: + has_print::value + which will be true if the member function is in the T class. + + Note that you can test for member functions taking no arguments by simply passing + in empty () like so: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ()) + This would test for a member of the form: + void print(). + + To test for const member functions you would use a statement such as this: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ()const) + This would test for a member of the form: + void print() const. + + To test for const templated member functions you would use a statement such as this: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, template print, ()) + This would test for a member of the form: + template void print(). + !*/ + +// ---------------------------------------------------------------------------------------- + + /*!A is_function + + This is a template that allows you to determine if the given type is a function. + + For example, + void funct(); + + is_built_in_scalar_type::value == true + is_built_in_scalar_type::value == false + !*/ + + template struct is_function { static const bool value = false; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + template + struct is_function { static const bool value = true; }; + + + template class funct_wrap0 + { + public: + funct_wrap0(T (&f_)()):f(f_){} + T operator()() const { return f(); } + private: + T (&f)(); + }; + template class funct_wrap1 + { + public: + funct_wrap1(T (&f_)(A0)):f(f_){} + T operator()(A0 a0) const { return f(a0); } + private: + T (&f)(A0); + }; + template class funct_wrap2 + { + public: + funct_wrap2(T (&f_)(A0,A1)):f(f_){} + T operator()(A0 a0, A1 a1) const { return f(a0,a1); } + private: + T (&f)(A0,A1); + }; + template class funct_wrap3 + { + public: + funct_wrap3(T (&f_)(A0,A1,A2)):f(f_){} + T operator()(A0 a0, A1 a1, A2 a2) const { return f(a0,a1,a2); } + private: + T (&f)(A0,A1,A2); + }; + template class funct_wrap4 + { + public: + funct_wrap4(T (&f_)(A0,A1,A2,A3)):f(f_){} + T operator()(A0 a0, A1 a1, A2 a2, A3 a3) const { return f(a0,a1,a2,a3); } + private: + T (&f)(A0,A1,A2,A3); + }; + template class funct_wrap5 + { + public: + funct_wrap5(T (&f_)(A0,A1,A2,A3,A4)):f(f_){} + T operator()(A0 a0, A1 a1, A2 a2, A3 a3, A4 a4) const { return f(a0,a1,a2,a3,a4); } + private: + T (&f)(A0,A1,A2,A3,A4); + }; + + /*!A wrap_function + + This is a template that allows you to turn a global function into a + function object. The reason for this template's existance is so you can + do stuff like this: + + template + void call_funct(const T& funct) + { cout << funct(); } + + std::string test() { return "asdfasf"; } + + int main() + { + call_funct(wrap_function(test)); + } + + The above code doesn't work right on some compilers if you don't + use wrap_function. + !*/ + + template + funct_wrap0 wrap_function(T (&f)()) { return funct_wrap0(f); } + template + funct_wrap1 wrap_function(T (&f)(A0)) { return funct_wrap1(f); } + template + funct_wrap2 wrap_function(T (&f)(A0, A1)) { return funct_wrap2(f); } + template + funct_wrap3 wrap_function(T (&f)(A0, A1, A2)) { return funct_wrap3(f); } + template + funct_wrap4 wrap_function(T (&f)(A0, A1, A2, A3)) { return funct_wrap4(f); } + template + funct_wrap5 wrap_function(T (&f)(A0, A1, A2, A3, A4)) { return funct_wrap5(f); } + +// ---------------------------------------------------------------------------------------- + + template + class stack_based_memory_block : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple container for a block of memory + of bSIZE bytes. This memory block is located on the stack + and properly aligned to hold any kind of object. + !*/ + public: + static const unsigned long size = bSIZE; + + stack_based_memory_block(): data(mem.data) {} + + void* get () { return data; } + /*! + ensures + - returns a pointer to the block of memory contained in this object + !*/ + + const void* get () const { return data; } + /*! + ensures + - returns a pointer to the block of memory contained in this object + !*/ + + private: + + // You obviously can't have a block of memory that has zero bytes in it. + COMPILE_TIME_ASSERT(bSIZE > 0); + + union mem_block + { + // All of this garbage is to make sure this union is properly aligned + // (a union is always aligned such that everything in it would be properly + // aligned. So the assumption here is that one of these objects has + // a large enough alignment requirement to satisfy any object this + // block of memory might be cast into). + void* void_ptr; + int integer; + struct { + void (stack_based_memory_block::*callback)(); + stack_based_memory_block* o; + } stuff; + long double more_stuff; + + uint64 var1; + uint32 var2; + double var3; + + char data[size]; + } mem; + + // The reason for having this variable is that doing it this way avoids + // warnings from gcc about violations of strict-aliasing rules. + void* const data; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename F + > + auto max_scoring_element( + const T& container, + F score_func + ) -> decltype(std::make_pair(*container.begin(), 0.0)) + /*! + requires + - container has .begin() and .end(), allowing it to be enumerated. + - score_func() is a function that takes an element of the container and returns a double. + ensures + - This function finds the element of container that has the largest score, + according to score_func(), and returns a std::pair containing that maximal + element along with the score. + - If the container is empty then make_pair(a default initialized object, -infinity) is returned. + !*/ + { + double best_score = -std::numeric_limits::infinity(); + auto best_i = container.begin(); + for (auto i = container.begin(); i != container.end(); ++i) + { + auto score = score_func(*i); + if (score > best_score) + { + best_score = score; + best_i = i; + } + } + + using item_type = typename std::remove_reference::type; + + if (best_i == container.end()) + return std::make_pair(item_type(), best_score); + else + return std::make_pair(*best_i, best_score); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename F + > + auto min_scoring_element( + const T& container, + F score_func + ) -> decltype(std::make_pair(*container.begin(), 0.0)) + /*! + requires + - container has .begin() and .end(), allowing it to be enumerated. + - score_func() is a function that takes an element of the container and returns a double. + ensures + - This function finds the element of container that has the smallest score, + according to score_func(), and returns a std::pair containing that minimal + element along with the score. + - If the container is empty then make_pair(a default initialized object, infinity) is returned. + !*/ + { + double best_score = std::numeric_limits::infinity(); + auto best_i = container.begin(); + for (auto i = container.begin(); i != container.end(); ++i) + { + auto score = score_func(*i); + if (score < best_score) + { + best_score = score; + best_i = i; + } + } + + using item_type = typename std::remove_reference::type; + + if (best_i == container.end()) + return std::make_pair(item_type(), best_score); + else + return std::make_pair(*best_i, best_score); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ALGs_ + diff --git a/ml/dlib/dlib/all/source.cpp b/ml/dlib/dlib/all/source.cpp new file mode 100644 index 000000000..1fc646a86 --- /dev/null +++ b/ml/dlib/dlib/all/source.cpp @@ -0,0 +1,98 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ALL_SOURCe_ +#define DLIB_ALL_SOURCe_ + +#if defined(DLIB_ALGs_) || defined(DLIB_PLATFORm_) +#include "../dlib_basic_cpp_build_tutorial.txt" +#endif + +// ISO C++ code +#include "../base64/base64_kernel_1.cpp" +#include "../bigint/bigint_kernel_1.cpp" +#include "../bigint/bigint_kernel_2.cpp" +#include "../bit_stream/bit_stream_kernel_1.cpp" +#include "../entropy_decoder/entropy_decoder_kernel_1.cpp" +#include "../entropy_decoder/entropy_decoder_kernel_2.cpp" +#include "../entropy_encoder/entropy_encoder_kernel_1.cpp" +#include "../entropy_encoder/entropy_encoder_kernel_2.cpp" +#include "../md5/md5_kernel_1.cpp" +#include "../tokenizer/tokenizer_kernel_1.cpp" +#include "../unicode/unicode.cpp" +#include "../test_for_odr_violations.cpp" + + + + +#ifndef DLIB_ISO_CPP_ONLY +// Code that depends on OS specific APIs + +// include this first so that it can disable the older version +// of the winsock API when compiled in windows. +#include "../sockets/sockets_kernel_1.cpp" +#include "../bsp/bsp.cpp" + +#include "../dir_nav/dir_nav_kernel_1.cpp" +#include "../dir_nav/dir_nav_kernel_2.cpp" +#include "../dir_nav/dir_nav_extensions.cpp" +#include "../linker/linker_kernel_1.cpp" +#include "../logger/extra_logger_headers.cpp" +#include "../logger/logger_kernel_1.cpp" +#include "../logger/logger_config_file.cpp" +#include "../misc_api/misc_api_kernel_1.cpp" +#include "../misc_api/misc_api_kernel_2.cpp" +#include "../sockets/sockets_extensions.cpp" +#include "../sockets/sockets_kernel_2.cpp" +#include "../sockstreambuf/sockstreambuf.cpp" +#include "../sockstreambuf/sockstreambuf_unbuffered.cpp" +#include "../server/server_kernel.cpp" +#include "../server/server_iostream.cpp" +#include "../server/server_http.cpp" +#include "../threads/multithreaded_object_extension.cpp" +#include "../threads/threaded_object_extension.cpp" +#include "../threads/threads_kernel_1.cpp" +#include "../threads/threads_kernel_2.cpp" +#include "../threads/threads_kernel_shared.cpp" +#include "../threads/thread_pool_extension.cpp" +#include "../threads/async.cpp" +#include "../timer/timer.cpp" +#include "../stack_trace.cpp" + +#ifdef DLIB_PNG_SUPPORT +#include "../image_loader/png_loader.cpp" +#include "../image_saver/save_png.cpp" +#endif + +#ifdef DLIB_JPEG_SUPPORT +#include "../image_loader/jpeg_loader.cpp" +#include "../image_saver/save_jpeg.cpp" +#endif + +#ifndef DLIB_NO_GUI_SUPPORT +#include "../gui_widgets/fonts.cpp" +#include "../gui_widgets/widgets.cpp" +#include "../gui_widgets/drawable.cpp" +#include "../gui_widgets/canvas_drawing.cpp" +#include "../gui_widgets/style.cpp" +#include "../gui_widgets/base_widgets.cpp" +#include "../gui_core/gui_core_kernel_1.cpp" +#include "../gui_core/gui_core_kernel_2.cpp" +#endif // DLIB_NO_GUI_SUPPORT + +#include "../dnn/cpu_dlib.cpp" +#include "../dnn/tensor_tools.cpp" + +#endif // DLIB_ISO_CPP_ONLY + + + +#include "../data_io/image_dataset_metadata.cpp" +#include "../data_io/mnist.cpp" +#include "../global_optimization/global_function_search.cpp" +#include "../filtering/kalman_filter.cpp" + + +#define DLIB_ALL_SOURCE_END + +#endif // DLIB_ALL_SOURCe_ + diff --git a/ml/dlib/dlib/any.h b/ml/dlib/dlib/any.h new file mode 100644 index 000000000..01f047066 --- /dev/null +++ b/ml/dlib/dlib/any.h @@ -0,0 +1,13 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_ +#define DLIB_AnY_ + +#include "any/any.h" +#include "any/any_trainer.h" +#include "any/any_decision_function.h" +#include "any/any_function.h" + +#endif // DLIB_AnY_ + + diff --git a/ml/dlib/dlib/any/any.h b/ml/dlib/dlib/any/any.h new file mode 100644 index 000000000..b5ef1bc8b --- /dev/null +++ b/ml/dlib/dlib/any/any.h @@ -0,0 +1,183 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_H_ +#define DLIB_AnY_H_ + +#include "any_abstract.h" +#include "../algs.h" + +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bad_any_cast : public std::bad_cast + { + public: + virtual const char * what() const throw() + { + return "bad_any_cast"; + } + }; + +// ---------------------------------------------------------------------------------------- + + class any + { + + public: + + any() + { + } + + any ( + const any& item + ) + { + if (item.data) + { + item.data->copy_to(data); + } + } + + template + any ( + const T& item + ) + { + typedef typename basic_type::type U; + data.reset(new derived(item)); + } + + void clear ( + ) + { + data.reset(); + } + + template + bool contains ( + ) const + { + typedef typename basic_type::type U; + return dynamic_cast*>(data.get()) != 0; + } + + bool is_empty( + ) const + { + return data.get() == 0; + } + + template + T& cast_to( + ) + { + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + throw bad_any_cast(); + } + + return d->item; + } + + template + const T& cast_to( + ) const + { + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + throw bad_any_cast(); + } + + return d->item; + } + + template + T& get( + ) + { + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + d = new derived(); + data.reset(d); + } + + return d->item; + } + + any& operator= ( + const any& item + ) + { + any(item).swap(*this); + return *this; + } + + void swap ( + any& item + ) + { + data.swap(item.data); + } + + private: + + struct base + { + virtual ~base() {} + + virtual void copy_to ( + std::unique_ptr& dest + ) const = 0; + }; + + template + struct derived : public base + { + T item; + derived() {} + derived(const T& val) : item(val) {} + + virtual void copy_to ( + std::unique_ptr& dest + ) const + { + dest.reset(new derived(item)); + } + }; + + std::unique_ptr data; + }; + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + any& a, + any& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template T& any_cast(any& a) { return a.cast_to(); } + template const T& any_cast(const any& a) { return a.cast_to(); } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_AnY_H_ + + + diff --git a/ml/dlib/dlib/any/any_abstract.h b/ml/dlib/dlib/any/any_abstract.h new file mode 100644 index 000000000..2fea96381 --- /dev/null +++ b/ml/dlib/dlib/any/any_abstract.h @@ -0,0 +1,210 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_ABSTRACT_H_ +#ifdef DLIB_AnY_ABSTRACT_H_ + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bad_any_cast : public std::bad_cast + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is the exception class used by the any object. + It is used to indicate when someone attempts to cast an any + object into a type which isn't contained in the any object. + !*/ + + public: + virtual const char* what() const throw() { return "bad_any_cast"; } + }; + +// ---------------------------------------------------------------------------------------- + + class any + { + /*! + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is basically a type-safe version of a void*. In particular, + it is a container which can contain only one object but the object may + be of any type. + + It is somewhat like the type_safe_union except you don't have to declare + the set of possible content types beforehand. So in some sense this is + like a less type-strict version of the type_safe_union. + !*/ + + public: + + any( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any ( + const any& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + template < typename T > + any ( + const T& item + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. a copy of item will be stored in *this) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any& operator= ( + const any& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + void swap ( + any& item + ); + /*! + ensures + - swaps *this and item + - does not invalidate pointers or references to the object contained + inside *this or item. Moreover, a pointer or reference to the object in + *this will now refer to the contents of #item and vice versa. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + any& a, + any& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T& any_cast( + any& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const T& any_cast( + const any& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/any/any_decision_function.h b/ml/dlib/dlib/any/any_decision_function.h new file mode 100644 index 000000000..771e9302b --- /dev/null +++ b/ml/dlib/dlib/any/any_decision_function.h @@ -0,0 +1,209 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_DECISION_FUNCTION_Hh_ +#define DLIB_AnY_DECISION_FUNCTION_Hh_ + +#include "any.h" + +#include "any_decision_function_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type_, + typename result_type_ = double + > + class any_decision_function + { + + public: + + typedef sample_type_ sample_type; + typedef result_type_ result_type; + typedef default_memory_manager mem_manager_type; + + any_decision_function() + { + } + + any_decision_function ( + const any_decision_function& item + ) + { + if (item.data) + { + item.data->copy_to(data); + } + } + + template + any_decision_function ( + const T& item + ) + { + typedef typename basic_type::type U; + data.reset(new derived(item)); + } + + void clear ( + ) + { + data.reset(); + } + + template + bool contains ( + ) const + { + typedef typename basic_type::type U; + return dynamic_cast*>(data.get()) != 0; + } + + bool is_empty( + ) const + { + return data.get() == 0; + } + + result_type operator() ( + const sample_type& item + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_empty() == false, + "\t result_type any_decision_function::operator()" + << "\n\t You can't call operator() on an empty any_decision_function" + << "\n\t this: " << this + ); + + return data->evaluate(item); + } + + template + T& cast_to( + ) + { + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + throw bad_any_cast(); + } + + return d->item; + } + + template + const T& cast_to( + ) const + { + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + throw bad_any_cast(); + } + + return d->item; + } + + template + T& get( + ) + { + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + d = new derived(); + data.reset(d); + } + + return d->item; + } + + any_decision_function& operator= ( + const any_decision_function& item + ) + { + any_decision_function(item).swap(*this); + return *this; + } + + void swap ( + any_decision_function& item + ) + { + data.swap(item.data); + } + + private: + + struct base + { + virtual ~base() {} + + virtual void copy_to ( + std::unique_ptr& dest + ) const = 0; + + virtual result_type evaluate ( + const sample_type& samp + ) const = 0; + }; + + template + struct derived : public base + { + T item; + derived() {} + derived(const T& val) : item(val) {} + + virtual void copy_to ( + std::unique_ptr& dest + ) const + { + dest.reset(new derived(item)); + } + + virtual result_type evaluate ( + const sample_type& samp + ) const + { + return item(samp); + } + }; + + std::unique_ptr data; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename result_type + > + inline void swap ( + any_decision_function& a, + any_decision_function& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template + T& any_cast(any_decision_function& a) { return a.template cast_to(); } + + template + const T& any_cast(const any_decision_function& a) { return a.template cast_to(); } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_AnY_DECISION_FUNCTION_Hh_ + + diff --git a/ml/dlib/dlib/any/any_decision_function_abstract.h b/ml/dlib/dlib/any/any_decision_function_abstract.h new file mode 100644 index 000000000..8b6644210 --- /dev/null +++ b/ml/dlib/dlib/any/any_decision_function_abstract.h @@ -0,0 +1,224 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ +#ifdef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ + +#include "any_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type_, + typename result_type_ = double + > + class any_decision_function + { + /*! + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is a version of dlib::any that is restricted to containing + elements which are some kind of function object with an operator() with + the following signature: + result_type operator()(const sample_type&) const + + It is intended to be used to contain dlib::decision_function objects and + other types which represent learned decision functions. It allows you + to write code which contains and processes these decision functions + without needing to know the specific types of decision functions used. + !*/ + + public: + + typedef sample_type_ sample_type; + typedef result_type_ result_type; + typedef default_memory_manager mem_manager_type; + + any_decision_function( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any_decision_function ( + const any_decision_function& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + template < typename T > + any_decision_function ( + const T& item + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. a copy of item will be stored in *this) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + result_type operator() ( + const sample_type& item + ) const; + /*! + requires + - is_empty() == false + ensures + - Let F denote the function object contained within *this. Then + this function performs: + return F(item) + !*/ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any_decision_function object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any_decision_function& operator= ( + const any_decision_function& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + void swap ( + any_decision_function& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename result_type + > + inline void swap ( + any_decision_function& a, + any_decision_function& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename sample_type, + typename result_type + > + T& any_cast( + any_decision_function& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename sample_type, + typename result_type + > + const T& any_cast( + const any_decision_function& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ + + + diff --git a/ml/dlib/dlib/any/any_function.h b/ml/dlib/dlib/any/any_function.h new file mode 100644 index 000000000..f186b4d3f --- /dev/null +++ b/ml/dlib/dlib/any/any_function.h @@ -0,0 +1,885 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_FUNCTION_Hh_ +#define DLIB_AnY_FUNCTION_Hh_ + +#include "any.h" + +#include "any_function_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + struct sig_traits {}; + + template < + typename T + > + struct sig_traits + { + typedef T result_type; + typedef void arg1_type; + typedef void arg2_type; + typedef void arg3_type; + typedef void arg4_type; + typedef void arg5_type; + typedef void arg6_type; + typedef void arg7_type; + typedef void arg8_type; + typedef void arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 0; + }; + + template < + typename T, + typename A1 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef void arg2_type; + typedef void arg3_type; + typedef void arg4_type; + typedef void arg5_type; + typedef void arg6_type; + typedef void arg7_type; + typedef void arg8_type; + typedef void arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 1; + }; + + template < + typename T, + typename A1, typename A2 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef void arg3_type; + typedef void arg4_type; + typedef void arg5_type; + typedef void arg6_type; + typedef void arg7_type; + typedef void arg8_type; + typedef void arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 2; + }; + + template < + typename T, + typename A1, typename A2, typename A3 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef void arg4_type; + typedef void arg5_type; + typedef void arg6_type; + typedef void arg7_type; + typedef void arg8_type; + typedef void arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 3; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef void arg5_type; + typedef void arg6_type; + typedef void arg7_type; + typedef void arg8_type; + typedef void arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 4; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef void arg6_type; + typedef void arg7_type; + typedef void arg8_type; + typedef void arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 5; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef void arg7_type; + typedef void arg8_type; + typedef void arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 6; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef void arg8_type; + typedef void arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 7; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef void arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 8; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef void arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 9; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef void arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 10; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef void arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 11; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11, + typename A12 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef A12 arg12_type; + typedef void arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 12; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11, + typename A12, + typename A13 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef A12 arg12_type; + typedef A13 arg13_type; + typedef void arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 13; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11, + typename A12, + typename A13, + typename A14 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef A12 arg12_type; + typedef A13 arg13_type; + typedef A14 arg14_type; + typedef void arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 14; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11, + typename A12, + typename A13, + typename A14, + typename A15 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef A12 arg12_type; + typedef A13 arg13_type; + typedef A14 arg14_type; + typedef A15 arg15_type; + typedef void arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 15; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11, + typename A12, + typename A13, + typename A14, + typename A15, + typename A16 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef A12 arg12_type; + typedef A13 arg13_type; + typedef A14 arg14_type; + typedef A15 arg15_type; + typedef A16 arg16_type; + typedef void arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 16; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11, + typename A12, + typename A13, + typename A14, + typename A15, + typename A16, + typename A17 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef A12 arg12_type; + typedef A13 arg13_type; + typedef A14 arg14_type; + typedef A15 arg15_type; + typedef A16 arg16_type; + typedef A17 arg17_type; + typedef void arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 17; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11, + typename A12, + typename A13, + typename A14, + typename A15, + typename A16, + typename A17, + typename A18 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef A12 arg12_type; + typedef A13 arg13_type; + typedef A14 arg14_type; + typedef A15 arg15_type; + typedef A16 arg16_type; + typedef A17 arg17_type; + typedef A18 arg18_type; + typedef void arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 18; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11, + typename A12, + typename A13, + typename A14, + typename A15, + typename A16, + typename A17, + typename A18, + typename A19 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef A12 arg12_type; + typedef A13 arg13_type; + typedef A14 arg14_type; + typedef A15 arg15_type; + typedef A16 arg16_type; + typedef A17 arg17_type; + typedef A18 arg18_type; + typedef A19 arg19_type; + typedef void arg20_type; + + const static unsigned long num_args = 19; + }; + + template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10, + typename A11, + typename A12, + typename A13, + typename A14, + typename A15, + typename A16, + typename A17, + typename A18, + typename A19, + typename A20 + > + struct sig_traits + { + typedef T result_type; + typedef A1 arg1_type; + typedef A2 arg2_type; + typedef A3 arg3_type; + typedef A4 arg4_type; + typedef A5 arg5_type; + typedef A6 arg6_type; + typedef A7 arg7_type; + typedef A8 arg8_type; + typedef A9 arg9_type; + typedef A10 arg10_type; + typedef A11 arg11_type; + typedef A12 arg12_type; + typedef A13 arg13_type; + typedef A14 arg14_type; + typedef A15 arg15_type; + typedef A16 arg16_type; + typedef A17 arg17_type; + typedef A18 arg18_type; + typedef A19 arg19_type; + typedef A20 arg20_type; + + const static unsigned long num_args = 20; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename function_type, + // These arguments are used to control the overloading. A user should + // not mess with them. + typename Enabled = void, + unsigned long Num_args = sig_traits::num_args + > + class any_function + { + private: + any_function() {} + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + An error on this line means you are trying to use a function signature + with more than the supported number of arguments. The current version + of dlib only supports up to 10 arguments. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ + }; + + + // The following preprocessor commands build the various overloaded versions + // of any_function for different numbers of commands and void vs. non-void return + // types. + +// 0 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST +#define DLIB_ANY_FUNCTION_ARGS +#define DLIB_ANY_FUNCTION_NUM_ARGS 0 +#include "any_function_impl2.h" + +// 1 argument +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1 +#define DLIB_ANY_FUNCTION_ARGS a1 +#define DLIB_ANY_FUNCTION_NUM_ARGS 1 +#include "any_function_impl2.h" + +// 2 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2 +#define DLIB_ANY_FUNCTION_ARGS a1,a2 +#define DLIB_ANY_FUNCTION_NUM_ARGS 2 +#include "any_function_impl2.h" + +// 3 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3 +#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3 +#define DLIB_ANY_FUNCTION_NUM_ARGS 3 +#include "any_function_impl2.h" + +// 4 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4 +#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4 +#define DLIB_ANY_FUNCTION_NUM_ARGS 4 +#include "any_function_impl2.h" + +// 5 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \ + arg5_type a5 +#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5 +#define DLIB_ANY_FUNCTION_NUM_ARGS 5 +#include "any_function_impl2.h" + +// 6 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \ + arg5_type a5, arg6_type a6 +#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6 +#define DLIB_ANY_FUNCTION_NUM_ARGS 6 +#include "any_function_impl2.h" + +// 7 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \ + arg5_type a5, arg6_type a6, arg7_type a7 +#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6,a7 +#define DLIB_ANY_FUNCTION_NUM_ARGS 7 +#include "any_function_impl2.h" + +// 8 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \ + arg5_type a5, arg6_type a6, arg7_type a7, arg8_type a8 +#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6,a7,a8 +#define DLIB_ANY_FUNCTION_NUM_ARGS 8 +#include "any_function_impl2.h" + +// 9 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \ + arg5_type a5, arg6_type a6, arg7_type a7, arg8_type a8, \ + arg9_type a9 +#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6,a7,a8,a9 +#define DLIB_ANY_FUNCTION_NUM_ARGS 9 +#include "any_function_impl2.h" + +// 10 arguments +#define DLIB_ANY_FUNCTION_ARG_LIST arg1_type a1, arg2_type a2, arg3_type a3, arg4_type a4, \ + arg5_type a5, arg6_type a6, arg7_type a7, arg8_type a8, \ + arg9_type a9, arg10_type a10 +#define DLIB_ANY_FUNCTION_ARGS a1,a2,a3,a4,a5,a6,a7,a8,a9,a10 +#define DLIB_ANY_FUNCTION_NUM_ARGS 10 +#include "any_function_impl2.h" + +// ---------------------------------------------------------------------------------------- + + template + inline void swap ( + any_function& a, + any_function& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template + T& any_cast(any_function& a) { return a.template cast_to(); } + + template + const T& any_cast(const any_function& a) { return a.template cast_to(); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_FUNCTION_Hh_ + diff --git a/ml/dlib/dlib/any/any_function_abstract.h b/ml/dlib/dlib/any/any_function_abstract.h new file mode 100644 index 000000000..1fc129edb --- /dev/null +++ b/ml/dlib/dlib/any/any_function_abstract.h @@ -0,0 +1,292 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_FUNCTION_ABSTRACT_H_ +#ifdef DLIB_AnY_FUNCTION_ABSTRACT_H_ + +#include "any_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename function_type + > + class any_function + { + /*! + REQUIREMENTS ON function_type + This type should be a function signature. Some examples are: + void (int,int) // a function returning nothing and taking two ints + void () // a function returning nothing and taking no arguments + char (double&) // a function returning a char and taking a reference to a double + + The number of arguments in the function must be no greater than 10. + + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is a version of dlib::any that is restricted to containing + elements which are some kind of function object with an operator() which + matches the function signature defined by function_type. + + + Here is an example: + #include + #include + #include "dlib/any.h" + using namespace std; + void print_message(string str) { cout << str << endl; } + + int main() + { + dlib::any_function f; + f = print_message; + f("hello world"); // calls print_message("hello world") + } + + Note that any_function objects can be used to store general function + objects (i.e. defined by a class with an overloaded operator()) in + addition to regular global functions. + !*/ + + public: + + // This is the type of object returned by function_type functions. + typedef result_type_for_function_type result_type; + // Typedefs defining the argument types. If an argument does not exist + // then it is set to void. + typedef type_of_first_argument_in_funct_type arg1_type; + typedef type_of_second_argument_in_funct_type arg2_type; + ... + typedef type_of_last_argument_in_funct_type arg10_type; + const static unsigned long num_args = total_number_of_non_void_arguments; + + any_function( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any_function ( + const any_function& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + template < typename T > + any_function ( + const T& item + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. a copy of item will be stored in *this) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + bool is_set ( + ) const; + /*! + ensures + - returns !is_empty() + !*/ + + result_type operator() ( + ) const; + /*! + requires + - is_empty() == false + - the signature defined by function_type takes no arguments + ensures + - Let F denote the function object contained within *this. Then + this function performs: + return F() + or if result_type is void then this function performs: + F() + !*/ + + result_type operator() ( + const arg1_type& a1 + ) const; + /*! + requires + - is_empty() == false + - the signature defined by function_type takes one argument + ensures + - Let F denote the function object contained within *this. Then + this function performs: + return F(a1) + or if result_type is void then this function performs: + F(a1) + !*/ + + result_type operator() ( + const arg1_type& a1, + const arg2_type& a2 + ) const; + /*! + requires + - is_empty() == false + - the signature defined by function_type takes two arguments + ensures + - Let F denote the function object contained within *this. Then + this function performs: + return F(a1,a2) + or if result_type is void then this function performs: + F(a1,a2) + !*/ + + /* !!!!!!!!! NOTE !!!!!!!!! + + In addition to the above, operator() is defined for up to 10 arguments. + They are not listed here because it would clutter the documentation. + + !!!!!!!!! NOTE !!!!!!!!! */ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any_function object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any_function& operator= ( + const any_function& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + void swap ( + any_function& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename function_type + > + inline void swap ( + any_function& a, + any_function& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename function_type + > + T& any_cast( + any_function& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename function_type + > + const T& any_cast( + const any_function& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_FUNCTION_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/any/any_function_impl.h b/ml/dlib/dlib/any/any_function_impl.h new file mode 100644 index 000000000..fec66cde7 --- /dev/null +++ b/ml/dlib/dlib/any/any_function_impl.h @@ -0,0 +1,516 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ANY_FUNCTION_RETURN +#error "You aren't supposed to directly #include this file. #include instead." +#endif + +#ifdef _MSC_VER +// When using visual studio 2012, disable the warning "warning C4180: qualifier applied to function type has no meaning; ignored" +// that you get about some template expansions applying & to function types. +#pragma warning(disable : 4180) +#endif + +#ifdef DLIB_ANY_FUNCTION_RETURN + +// This file contains the body of the any_function class. We use the +// preprocessor to generate many different versions. There are +// versions which return a value and those which return void. For +// each of these types there are versions with differing numbers +// of arguments. + +public: +typedef typename sig_traits::result_type result_type; +typedef typename sig_traits::arg1_type arg1_type; +typedef typename sig_traits::arg2_type arg2_type; +typedef typename sig_traits::arg3_type arg3_type; +typedef typename sig_traits::arg4_type arg4_type; +typedef typename sig_traits::arg5_type arg5_type; +typedef typename sig_traits::arg6_type arg6_type; +typedef typename sig_traits::arg7_type arg7_type; +typedef typename sig_traits::arg8_type arg8_type; +typedef typename sig_traits::arg9_type arg9_type; +typedef typename sig_traits::arg10_type arg10_type; +const static unsigned long num_args = sig_traits::num_args; + +any_function() +{ +} + +any_function ( + const any_function& item +) +{ + if (item.data) + { + item.data->copy_to(data); + } +} + +template +any_function ( + const T& item +) +{ + typedef typename basic_type::type U; + data.reset(new derived(item)); +} + +void clear ( +) +{ + data.reset(); +} + +template +bool contains ( +) const +{ + typedef typename basic_type::type U; + return dynamic_cast*>(data.get()) != 0; +} + +bool is_empty( +) const +{ + return data.get() == 0; +} + +bool is_set( +) const +{ + return !is_empty(); +} + +template +T& cast_to( +) +{ + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + throw bad_any_cast(); + } + + return d->item; +} + +template +const T& cast_to( +) const +{ + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + throw bad_any_cast(); + } + + return d->item; +} + +template +T& get( +) +{ + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + d = new derived(); + data.reset(d); + } + + return d->item; +} + +any_function& operator= ( + const any_function& item +) +{ + any_function(item).swap(*this); + return *this; +} + +void swap ( + any_function& item +) +{ + data.swap(item.data); +} + +result_type operator()(DLIB_ANY_FUNCTION_ARG_LIST) const +{ validate(); DLIB_ANY_FUNCTION_RETURN data->evaluate(DLIB_ANY_FUNCTION_ARGS); } +/* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to call a dlib::any_function but you have supplied + arguments which don't match the function signature used by the + dlib::any_function. +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ + +private: + +void validate () const +{ + // make sure requires clause is not broken + DLIB_ASSERT(is_empty() == false, + "\t result_type any_function::operator()" + << "\n\t You can't call operator() on an empty any_function" + << "\n\t this: " << this + ); +} + + +template +struct Tbase +{ + virtual ~Tbase() {} + virtual result_type evaluate () const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate ( A1) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1, typename A2 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate (A1,A2) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1, typename A2, typename A3 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate (A1,A2,A3) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1, typename A2, typename A3, + typename A4 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate (A1,A2,A3,A4) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate (A1,A2,A3,A4,A5) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate (A1,A2,A3,A4,A5,A6) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate (A1,A2,A3,A4,A5,A6,A7) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate (A1,A2,A3,A4,A5,A6,A7,A8) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate (A1,A2,A3,A4,A5,A6,A7,A8,A9) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +template < + typename T, + typename A1, typename A2, typename A3, + typename A4, typename A5, typename A6, + typename A7, typename A8, typename A9, + typename A10 + > +struct Tbase +{ + virtual ~Tbase() {} + virtual T evaluate (A1,A2,A3,A4,A5,A6,A7,A8,A9,A10) const = 0; + virtual void copy_to ( std::unique_ptr& dest) const = 0; +}; + +typedef Tbase base; + +// ----------------------------------------------- + +// Some templates to help deal with the weirdness of storing C function types (rather than pointer to functions). +// Basically, we make sure things always get turned into function pointers even if the user gives a function reference. +template +struct funct_type { typedef T type; }; +template +struct funct_type >::type> { typedef T* type; }; + +template +static typename enable_if,const T*>::type copy (const T& item) { return &item; } +template +static typename disable_if,const T&>::type copy (const T& item) { return item; } + +template +static typename enable_if,const T&>::type deref (const U& item) { return *item; } +template +static typename disable_if,const T&>::type deref (const U& item) { return item; } + +// ----------------------------------------------- + +#define DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE \ + typename funct_type::type item; \ + derived() {} \ + derived(const T& val) : item(copy(val)) {} \ + virtual void copy_to ( std::unique_ptr& dest) const \ + { dest.reset(new derived(deref(item))); } + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + ) const { DLIB_ANY_FUNCTION_RETURN item(); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1, A2 a2 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1, A2 a2, A3 a3 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1, A2 a2, A3 a3, A4 a4 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1, A2 a2, A3 a3, A4 a4, A5 a5 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6, A7 a7 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6,a7); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6, A7 a7, A8 a8 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6,a7,a8); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6, A7 a7, A8 a8, A9 a9 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6,a7,a8,a9); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +template +struct derived : public base +{ + DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + + virtual result_type evaluate ( + A1 a1, A2 a2, A3 a3, A4 a4, A5 a5, A6 a6, A7 a7, A8 a8, A9 a9, A10 a10 + ) const { DLIB_ANY_FUNCTION_RETURN item(a1,a2,a3,a4,a5,a6,a7,a8,a9,a10); } + /* !!!!!!!! ERRORS ON THE ABOVE LINE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + If you are getting an error on the above line then it means you + have attempted to assign a function or function object to a + dlib::any_function but the signatures of the source and + destination functions don't match. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ +}; + +std::unique_ptr data; + +#undef DLIB_ANY_FUNCTION_DERIVED_BOILERPLATE + +#endif // DLIB_ANY_FUNCTION_RETURN + diff --git a/ml/dlib/dlib/any/any_function_impl2.h b/ml/dlib/dlib/any/any_function_impl2.h new file mode 100644 index 000000000..e1801ddc1 --- /dev/null +++ b/ml/dlib/dlib/any/any_function_impl2.h @@ -0,0 +1,44 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ANY_FUNCTION_ARG_LIST +#error "You aren't supposed to directly #include this file. #include instead." +#endif + +#ifdef DLIB_ANY_FUNCTION_ARG_LIST + +// The case where function_type has a non-void return type + template + class any_function + { +#define DLIB_ANY_FUNCTION_RETURN return +#include "any_function_impl.h" +#undef DLIB_ANY_FUNCTION_RETURN + + private: + // You get a compiler error about this function being private if you try to assign + // or copy between any_functions with different types. You must only copy between + // any_functions that represent functions with the same signature. + template any_function(const any_function&); + }; + +// The case where function_type has a void return type + template + class any_function::type, DLIB_ANY_FUNCTION_NUM_ARGS> + { +#define DLIB_ANY_FUNCTION_RETURN +#include "any_function_impl.h" +#undef DLIB_ANY_FUNCTION_RETURN + + private: + // You get a compiler error about this function being private if you try to assign + // or copy between any_functions with different types. You must only copy between + // any_functions that represent functions with the same signature. + template any_function(const any_function&); + }; + +#undef DLIB_ANY_FUNCTION_ARG_LIST +#undef DLIB_ANY_FUNCTION_ARGS +#undef DLIB_ANY_FUNCTION_NUM_ARGS + +#endif // DLIB_ANY_FUNCTION_ARG_LIST + diff --git a/ml/dlib/dlib/any/any_trainer.h b/ml/dlib/dlib/any/any_trainer.h new file mode 100644 index 000000000..4df10a140 --- /dev/null +++ b/ml/dlib/dlib/any/any_trainer.h @@ -0,0 +1,217 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_TRAINER_H_ +#define DLIB_AnY_TRAINER_H_ + +#include "any.h" + +#include "any_decision_function.h" + +#include "any_trainer_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type_, + typename scalar_type_ = double + > + class any_trainer + { + public: + typedef sample_type_ sample_type; + typedef scalar_type_ scalar_type; + typedef default_memory_manager mem_manager_type; + typedef any_decision_function trained_function_type; + + + any_trainer() + { + } + + any_trainer ( + const any_trainer& item + ) + { + if (item.data) + { + item.data->copy_to(data); + } + } + + template + any_trainer ( + const T& item + ) + { + typedef typename basic_type::type U; + data.reset(new derived(item)); + } + + void clear ( + ) + { + data.reset(); + } + + template + bool contains ( + ) const + { + typedef typename basic_type::type U; + return dynamic_cast*>(data.get()) != 0; + } + + bool is_empty( + ) const + { + return data.get() == 0; + } + + trained_function_type train ( + const std::vector& samples, + const std::vector& labels + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_empty() == false, + "\t trained_function_type any_trainer::train()" + << "\n\t You can't call train() on an empty any_trainer" + << "\n\t this: " << this + ); + + return data->train(samples, labels); + } + + template + T& cast_to( + ) + { + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + throw bad_any_cast(); + } + + return d->item; + } + + template + const T& cast_to( + ) const + { + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + throw bad_any_cast(); + } + + return d->item; + } + + template + T& get( + ) + { + typedef typename basic_type::type U; + derived* d = dynamic_cast*>(data.get()); + if (d == 0) + { + d = new derived(); + data.reset(d); + } + + return d->item; + } + + any_trainer& operator= ( + const any_trainer& item + ) + { + any_trainer(item).swap(*this); + return *this; + } + + void swap ( + any_trainer& item + ) + { + data.swap(item.data); + } + + private: + + struct base + { + virtual ~base() {} + + virtual trained_function_type train ( + const std::vector& samples, + const std::vector& labels + ) const = 0; + + virtual void copy_to ( + std::unique_ptr& dest + ) const = 0; + }; + + template + struct derived : public base + { + T item; + derived() {} + derived(const T& val) : item(val) {} + + virtual void copy_to ( + std::unique_ptr& dest + ) const + { + dest.reset(new derived(item)); + } + + virtual trained_function_type train ( + const std::vector& samples, + const std::vector& labels + ) const + { + return item.train(samples, labels); + } + }; + + std::unique_ptr data; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename scalar_type + > + inline void swap ( + any_trainer& a, + any_trainer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template + T& any_cast(any_trainer& a) { return a.template cast_to(); } + + template + const T& any_cast(const any_trainer& a) { return a.template cast_to(); } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_AnY_TRAINER_H_ + + + + diff --git a/ml/dlib/dlib/any/any_trainer_abstract.h b/ml/dlib/dlib/any/any_trainer_abstract.h new file mode 100644 index 000000000..877792fc1 --- /dev/null +++ b/ml/dlib/dlib/any/any_trainer_abstract.h @@ -0,0 +1,234 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_TRAINER_ABSTRACT_H_ +#ifdef DLIB_AnY_TRAINER_ABSTRACT_H_ + +#include "any_abstract.h" +#include "../algs.h" +#include "any_decision_function_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type_, + typename scalar_type_ = double + > + class any_trainer + { + /*! + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is a version of dlib::any that is restricted to containing + elements which are some kind of object with a .train() method compatible + with the following signature: + + decision_function train( + const std::vector& samples, + const std::vector& labels + ) const + + Where decision_function is a type capable of being stored in an + any_decision_function object. + + any_trainer is intended to be used to contain objects such as the svm_nu_trainer + and other similar types which represent supervised machine learning algorithms. + It allows you to write code which contains and processes these trainer objects + without needing to know the specific types of trainer objects used. + !*/ + + public: + + typedef sample_type_ sample_type; + typedef scalar_type_ scalar_type; + typedef default_memory_manager mem_manager_type; + typedef any_decision_function trained_function_type; + + any_trainer( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any_trainer ( + const any_trainer& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + template < typename T > + any_trainer ( + const T& item + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. a copy of item will be stored in *this) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + trained_function_type train ( + const std::vector& samples, + const std::vector& labels + ) const + /*! + requires + - is_empty() == false + ensures + - Let TRAINER denote the object contained within *this. Then + this function performs: + return TRAINER.train(samples, labels) + !*/ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any_trainer object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any_trainer& operator= ( + const any_trainer& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + void swap ( + any_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename scalar_type + > + inline void swap ( + any_trainer& a, + any_trainer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename sample_type, + typename scalar_type + > + T& any_cast( + any_trainer& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename sample_type, + typename scalar_type + > + const T& any_cast( + const any_trainer& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_TRAINER_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/appveyor/dtest.yml b/ml/dlib/dlib/appveyor/dtest.yml new file mode 100644 index 000000000..212528eb1 --- /dev/null +++ b/ml/dlib/dlib/appveyor/dtest.yml @@ -0,0 +1,19 @@ +version: "{build}" + +configuration: Release + +build_script: + # build test + - mkdir %APPVEYOR_BUILD_FOLDER%\build_test + - cd %APPVEYOR_BUILD_FOLDER%\build_test + - cmake -G "Visual Studio 14 2015 Win64" -T host=x64 ../dlib/test + - cmake --build . --config %CONFIGURATION% --target dlib_all_source_cpp + - cmake --build . --config %CONFIGURATION% --target imglab + - cmake --build . --config %CONFIGURATION% --target htmlify + - cmake --build . --config %CONFIGURATION% --target gui + - cmake --build . --config %CONFIGURATION% --target dtest + +test_script: + # run test + - cd %APPVEYOR_BUILD_FOLDER%\build_test\%CONFIGURATION% + - dtest --runall diff --git a/ml/dlib/dlib/appveyor/dtest_vc2017.yml b/ml/dlib/dlib/appveyor/dtest_vc2017.yml new file mode 100644 index 000000000..9e5e72ee9 --- /dev/null +++ b/ml/dlib/dlib/appveyor/dtest_vc2017.yml @@ -0,0 +1,21 @@ +version: "{build}" + +configuration: Release + +image: Visual Studio 2017 + +build_script: + # build test + - mkdir %APPVEYOR_BUILD_FOLDER%\build_test + - cd %APPVEYOR_BUILD_FOLDER%\build_test + - cmake -G "Visual Studio 15 2017 Win64" -T host=x64 ../dlib/test + - cmake --build . --config %CONFIGURATION% --target dlib_all_source_cpp + - cmake --build . --config %CONFIGURATION% --target imglab + - cmake --build . --config %CONFIGURATION% --target htmlify + - cmake --build . --config %CONFIGURATION% --target gui + - cmake --build . --config %CONFIGURATION% --target dtest + +test_script: + # run test + - cd %APPVEYOR_BUILD_FOLDER%\build_test\%CONFIGURATION% + - dtest --runall diff --git a/ml/dlib/dlib/appveyor/examples.yml b/ml/dlib/dlib/appveyor/examples.yml new file mode 100644 index 000000000..55791d224 --- /dev/null +++ b/ml/dlib/dlib/appveyor/examples.yml @@ -0,0 +1,16 @@ +version: "{build}" + +configuration: Release + +image: Visual Studio 2017 + +build_script: + # build test + - mkdir %APPVEYOR_BUILD_FOLDER%\build_examples + - cd %APPVEYOR_BUILD_FOLDER%\build_examples + #- cmake -G "Visual Studio 14 2015 Win64" -T host=x64 ../examples + - cmake -G "Visual Studio 15 2017 Win64" -T host=x64 ../examples + - cmake --build . --config %CONFIGURATION% + +test_script: + # run test diff --git a/ml/dlib/dlib/appveyor/python.yml b/ml/dlib/dlib/appveyor/python.yml new file mode 100644 index 000000000..2e1466084 --- /dev/null +++ b/ml/dlib/dlib/appveyor/python.yml @@ -0,0 +1,33 @@ + +environment: + matrix: + - PYTHON: "C:\\Python27" + PYTHON_VERSION: "2.7.x" + PYTHON_ARCH: "32" + + - PYTHON: "C:\\Python27-x64" + PYTHON_VERSION: "2.7.x" + PYTHON_ARCH: "64" + + - PYTHON: "C:\\Python33" + PYTHON_VERSION: "3.3.x" + PYTHON_ARCH: "32" + + - PYTHON: "C:\\Python33-x64" + PYTHON_VERSION: "3.3.x" + PYTHON_ARCH: "64" + + - PYTHON: "C:\\Python35" + PYTHON_VERSION: "3.5.x" + PYTHON_ARCH: "32" + + - PYTHON: "C:\\Python35-x64" + PYTHON_VERSION: "3.5.x" + PYTHON_ARCH: "64" + + +build_script: + - python setup.py build + +test_script: + - python setup.py test diff --git a/ml/dlib/dlib/array.h b/ml/dlib/dlib/array.h new file mode 100644 index 000000000..ecdafc497 --- /dev/null +++ b/ml/dlib/dlib/array.h @@ -0,0 +1,10 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAy_ +#define DLIB_ARRAy_ + +#include "array/array_kernel.h" +#include "array/array_tools.h" + +#endif // DLIB_ARRAy_ + diff --git a/ml/dlib/dlib/array/array_kernel.h b/ml/dlib/dlib/array/array_kernel.h new file mode 100644 index 000000000..48160941b --- /dev/null +++ b/ml/dlib/dlib/array/array_kernel.h @@ -0,0 +1,810 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY_KERNEl_2_ +#define DLIB_ARRAY_KERNEl_2_ + +#include "array_kernel_abstract.h" +#include "../interfaces/enumerable.h" +#include "../algs.h" +#include "../serialize.h" +#include "../sort.h" +#include "../is_kind.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class array : public enumerable + { + + /*! + INITIAL VALUE + - array_size == 0 + - max_array_size == 0 + - array_elements == 0 + - pos == 0 + - last_pos == 0 + - _at_start == true + + CONVENTION + - array_size == size() + - max_array_size == max_size() + - if (max_array_size > 0) + - array_elements == pointer to max_array_size elements of type T + - else + - array_elements == 0 + + - if (array_size > 0) + - last_pos == array_elements + array_size - 1 + - else + - last_pos == 0 + + + - at_start() == _at_start + - current_element_valid() == pos != 0 + - if (current_element_valid()) then + - *pos == element() + !*/ + + public: + + // These typedefs are here for backwards compatibility with old versions of dlib. + typedef array kernel_1a; + typedef array kernel_1a_c; + typedef array kernel_2a; + typedef array kernel_2a_c; + typedef array sort_1a; + typedef array sort_1a_c; + typedef array sort_1b; + typedef array sort_1b_c; + typedef array sort_2a; + typedef array sort_2a_c; + typedef array sort_2b; + typedef array sort_2b_c; + typedef array expand_1a; + typedef array expand_1a_c; + typedef array expand_1b; + typedef array expand_1b_c; + typedef array expand_1c; + typedef array expand_1c_c; + typedef array expand_1d; + typedef array expand_1d_c; + + + + + typedef T type; + typedef T value_type; + typedef mem_manager mem_manager_type; + + array ( + ) : + array_size(0), + max_array_size(0), + array_elements(0), + pos(0), + last_pos(0), + _at_start(true) + {} + + array( + array&& item + ) : array() + { + swap(item); + } + + array& operator=( + array&& item + ) + { + swap(item); + return *this; + } + + explicit array ( + size_t new_size + ) : + array_size(0), + max_array_size(0), + array_elements(0), + pos(0), + last_pos(0), + _at_start(true) + { + resize(new_size); + } + + ~array ( + ); + + void clear ( + ); + + inline const T& operator[] ( + size_t pos + ) const; + + inline T& operator[] ( + size_t pos + ); + + void set_size ( + size_t size + ); + + inline size_t max_size( + ) const; + + void set_max_size( + size_t max + ); + + void swap ( + array& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + inline const T& element ( + ) const; + + inline T& element ( + ); + + bool move_next ( + ) const; + + void sort ( + ); + + void resize ( + size_t new_size + ); + + const T& back ( + ) const; + + T& back ( + ); + + void pop_back ( + ); + + void pop_back ( + T& item + ); + + void push_back ( + T& item + ); + + void push_back ( + T&& item + ); + + typedef T* iterator; + typedef const T* const_iterator; + iterator begin() { return array_elements; } + const_iterator begin() const { return array_elements; } + iterator end() { return array_elements+array_size; } + const_iterator end() const { return array_elements+array_size; } + + private: + + typename mem_manager::template rebind::other pool; + + // data members + size_t array_size; + size_t max_array_size; + T* array_elements; + + mutable T* pos; + T* last_pos; + mutable bool _at_start; + + // restricted functions + array(array&); // copy constructor + array& operator=(array&); // assignment operator + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + array& a, + array& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void serialize ( + const array& item, + std::ostream& out + ) + { + try + { + serialize(item.max_size(),out); + serialize(item.size(),out); + + for (size_t i = 0; i < item.size(); ++i) + serialize(item[i],out); + } + catch (serialization_error e) + { + throw serialization_error(e.info + "\n while serializing object of type array"); + } + } + + template < + typename T, + typename mem_manager + > + void deserialize ( + array& item, + std::istream& in + ) + { + try + { + size_t max_size, size; + deserialize(max_size,in); + deserialize(size,in); + item.set_max_size(max_size); + item.set_size(size); + for (size_t i = 0; i < size; ++i) + deserialize(item[i],in); + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + array:: + ~array ( + ) + { + if (array_elements) + { + pool.deallocate_array(array_elements); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + clear ( + ) + { + reset(); + last_pos = 0; + array_size = 0; + if (array_elements) + { + pool.deallocate_array(array_elements); + } + array_elements = 0; + max_array_size = 0; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& array:: + operator[] ( + size_t pos + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( pos < this->size() , + "\tconst T& array::operator[]" + << "\n\tpos must < size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return array_elements[pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& array:: + operator[] ( + size_t pos + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( pos < this->size() , + "\tT& array::operator[]" + << "\n\tpos must be < size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return array_elements[pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + set_size ( + size_t size + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( size <= this->max_size() ), + "\tvoid array::set_size" + << "\n\tsize must be <= max_size()" + << "\n\tsize: " << size + << "\n\tmax size: " << this->max_size() + << "\n\tthis: " << this + ); + + reset(); + array_size = size; + if (size > 0) + last_pos = array_elements + size - 1; + else + last_pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t array:: + size ( + ) const + { + return array_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + set_max_size( + size_t max + ) + { + reset(); + array_size = 0; + last_pos = 0; + if (max != 0) + { + // if new max size is different + if (max != max_array_size) + { + if (array_elements) + { + pool.deallocate_array(array_elements); + } + // try to get more memroy + try { array_elements = pool.allocate_array(max); } + catch (...) { array_elements = 0; max_array_size = 0; throw; } + max_array_size = max; + } + + } + // if the array is being made to be zero + else + { + if (array_elements) + pool.deallocate_array(array_elements); + max_array_size = 0; + array_elements = 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t array:: + max_size ( + ) const + { + return max_array_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + swap ( + array& item + ) + { + auto array_size_temp = item.array_size; + auto max_array_size_temp = item.max_array_size; + T* array_elements_temp = item.array_elements; + + item.array_size = array_size; + item.max_array_size = max_array_size; + item.array_elements = array_elements; + + array_size = array_size_temp; + max_array_size = max_array_size_temp; + array_elements = array_elements_temp; + + exchange(_at_start,item._at_start); + exchange(pos,item.pos); + exchange(last_pos,item.last_pos); + pool.swap(item.pool); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool array:: + at_start ( + ) const + { + return _at_start; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + reset ( + ) const + { + _at_start = true; + pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool array:: + current_element_valid ( + ) const + { + return pos != 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& array:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(this->current_element_valid(), + "\tconst T& array::element()" + << "\n\tThe current element must be valid if you are to access it." + << "\n\tthis: " << this + ); + + return *pos; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& array:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(this->current_element_valid(), + "\tT& array::element()" + << "\n\tThe current element must be valid if you are to access it." + << "\n\tthis: " << this + ); + + return *pos; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool array:: + move_next ( + ) const + { + if (!_at_start) + { + if (pos < last_pos) + { + ++pos; + return true; + } + else + { + pos = 0; + return false; + } + } + else + { + _at_start = false; + if (array_size > 0) + { + pos = array_elements; + return true; + } + else + { + return false; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Yet more functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + sort ( + ) + { + if (this->size() > 1) + { + // call the quick sort function for arrays that is in algs.h + dlib::qsort_array(*this,0,this->size()-1); + } + this->reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + resize ( + size_t new_size + ) + { + if (this->max_size() < new_size) + { + array temp; + temp.set_max_size(new_size); + temp.set_size(new_size); + for (size_t i = 0; i < this->size(); ++i) + { + exchange((*this)[i],temp[i]); + } + temp.swap(*this); + } + else + { + this->set_size(new_size); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& array:: + back ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tT& array::back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return (*this)[this->size()-1]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& array:: + back ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tconst T& array::back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return (*this)[this->size()-1]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + pop_back ( + T& item + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tvoid array::pop_back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + exchange(item,(*this)[this->size()-1]); + this->set_size(this->size()-1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + pop_back ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tvoid array::pop_back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + this->set_size(this->size()-1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + push_back ( + T& item + ) + { + if (this->max_size() == this->size()) + { + // double the size of the array + array temp; + temp.set_max_size(this->size()*2 + 1); + temp.set_size(this->size()+1); + for (size_t i = 0; i < this->size(); ++i) + { + exchange((*this)[i],temp[i]); + } + exchange(item,temp[temp.size()-1]); + temp.swap(*this); + } + else + { + this->set_size(this->size()+1); + exchange(item,(*this)[this->size()-1]); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + push_back ( + T&& item + ) { push_back(item); } + +// ---------------------------------------------------------------------------------------- + + template + struct is_array > + { + const static bool value = true; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ARRAY_KERNEl_2_ + diff --git a/ml/dlib/dlib/array/array_kernel_abstract.h b/ml/dlib/dlib/array/array_kernel_abstract.h new file mode 100644 index 000000000..5cfdd483a --- /dev/null +++ b/ml/dlib/dlib/array/array_kernel_abstract.h @@ -0,0 +1,360 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ARRAY_KERNEl_ABSTRACT_ +#ifdef DLIB_ARRAY_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../serialize.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class array : public enumerable + { + + /*! + REQUIREMENTS ON T + T must have a default constructor. + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + front(), back(), swap(), max_size(), set_size(), and operator[] + functions do not invalidate pointers or references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + size() == 0 + max_size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements of the array in the + order (*this)[0], (*this)[1], (*this)[2], ... + + WHAT THIS OBJECT REPRESENTS + This object represents an ordered 1-dimensional array of items, + each item is associated with an integer value. The items are + numbered from 0 though size() - 1 and the operator[] functions + run in constant time. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + !*/ + + public: + + typedef T type; + typedef T value_type; + typedef mem_manager mem_manager_type; + + array ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + explicit array ( + size_t new_size + ); + /*! + ensures + - #*this is properly initialized + - #size() == new_size + - #max_size() == new_size + - All elements of the array will have initial values for their type. + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + ~array ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + array( + array&& item + ); + /*! + ensures + - move constructs *this from item. Therefore, the state of item is + moved into *this and #item has a valid but unspecified state. + !*/ + + array& operator=( + array&& item + ); + /*! + ensures + - move assigns *this from item. Therefore, the state of item is + moved into *this and #item has a valid but unspecified state. + - returns a reference to #*this + !*/ + + void clear ( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor + if this exception is thrown then the array object is unusable + until clear() is called and succeeds + !*/ + + const T& operator[] ( + size_t pos + ) const; + /*! + requires + - pos < size() + ensures + - returns a const reference to the element at position pos + !*/ + + T& operator[] ( + size_t pos + ); + /*! + requires + - pos < size() + ensures + - returns a non-const reference to the element at position pos + !*/ + + void set_size ( + size_t size + ); + /*! + requires + - size <= max_size() + ensures + - #size() == size + - any element with index between 0 and size - 1 which was in the + array before the call to set_size() retains its value and index. + All other elements have undetermined (but valid for their type) + values. (e.g. this object might buffer old T objects and reuse + them without reinitializing them between calls to set_size()) + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + may throw this exception if there is not enough memory and + if it does throw then the call to set_size() has no effect + !*/ + + size_t max_size( + ) const; + /*! + ensures + - returns the maximum size of *this + !*/ + + void set_max_size( + size_t max + ); + /*! + ensures + - #max_size() == max + - #size() == 0 + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + may throw this exception if there is not enough + memory and if it does throw then max_size() == 0 + !*/ + + void swap ( + array& item + ); + /*! + ensures + - swaps *this and item + !*/ + + void sort ( + ); + /*! + requires + - T must be a type with that is comparable via operator< + ensures + - for all elements in #*this the ith element is <= the i+1 element + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + data may be lost if sort() throws + !*/ + + void resize ( + size_t new_size + ); + /*! + ensures + - #size() == new_size + - #max_size() == max(new_size,max_size()) + - for all i < size() && i < new_size: + - #(*this)[i] == (*this)[i] + (i.e. All the original elements of *this which were at index + values less than new_size are unmodified.) + - for all valid i >= size(): + - #(*this)[i] has an undefined value + (i.e. any new elements of the array have an undefined value) + throws + - std::bad_alloc or any exception thrown by T's constructor. + If an exception is thrown then it has no effect on *this. + !*/ + + + const T& back ( + ) const; + /*! + requires + - size() != 0 + ensures + - returns a const reference to (*this)[size()-1] + !*/ + + T& back ( + ); + /*! + requires + - size() != 0 + ensures + - returns a non-const reference to (*this)[size()-1] + !*/ + + void pop_back ( + T& item + ); + /*! + requires + - size() != 0 + ensures + - #size() == size() - 1 + - swaps (*this)[size()-1] into item + - All elements with an index less than size()-1 are + unmodified by this operation. + !*/ + + void pop_back ( + ); + /*! + requires + - size() != 0 + ensures + - #size() == size() - 1 + - All elements with an index less than size()-1 are + unmodified by this operation. + !*/ + + void push_back ( + T& item + ); + /*! + ensures + - #size() == size()+1 + - swaps item into (*this)[#size()-1] + - #back() == item + - #item has some undefined value (whatever happens to + get swapped out of the array) + throws + - std::bad_alloc or any exception thrown by T's constructor. + If an exception is thrown then it has no effect on *this. + !*/ + + void push_back (T&& item) { push_back(item); } + /*! + enable push_back from rvalues + !*/ + + typedef T* iterator; + typedef const T* const_iterator; + + iterator begin( + ); + /*! + ensures + - returns an iterator that points to the first element in this array or + end() if the array is empty. + !*/ + + const_iterator begin( + ) const; + /*! + ensures + - returns a const iterator that points to the first element in this + array or end() if the array is empty. + !*/ + + iterator end( + ); + /*! + ensures + - returns an iterator that points to one past the end of the array. + !*/ + + const_iterator end( + ) const; + /*! + ensures + - returns a const iterator that points to one past the end of the + array. + !*/ + + private: + + // restricted functions + array(array&); // copy constructor + array& operator=(array&); // assignment operator + + }; + + template < + typename T + > + inline void swap ( + array& a, + array& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T + > + void serialize ( + const array& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename T + > + void deserialize ( + array& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_ARRAY_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/array/array_tools.h b/ml/dlib/dlib/array/array_tools.h new file mode 100644 index 000000000..fce634396 --- /dev/null +++ b/ml/dlib/dlib/array/array_tools.h @@ -0,0 +1,38 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY_tOOLS_H_ +#define DLIB_ARRAY_tOOLS_H_ + +#include "../assert.h" +#include "array_tools_abstract.h" + +namespace dlib +{ + template + void split_array ( + T& a, + T& b, + double frac + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= frac && frac <= 1, + "\t void split_array()" + << "\n\t frac must be between 0 and 1." + << "\n\t frac: " << frac + ); + + const unsigned long asize = static_cast(a.size()*frac); + const unsigned long bsize = a.size()-asize; + + b.resize(bsize); + for (unsigned long i = 0; i < b.size(); ++i) + { + swap(b[i], a[i+asize]); + } + a.resize(asize); + } +} + +#endif // DLIB_ARRAY_tOOLS_H_ + diff --git a/ml/dlib/dlib/array/array_tools_abstract.h b/ml/dlib/dlib/array/array_tools_abstract.h new file mode 100644 index 000000000..e9b957518 --- /dev/null +++ b/ml/dlib/dlib/array/array_tools_abstract.h @@ -0,0 +1,33 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ARRAY_tOOLS_ABSTRACT_H_ +#ifdef DLIB_ARRAY_tOOLS_ABSTRACT_H_ + +#include "array_kernel_abstract.h" + +namespace dlib +{ + template + void split_array ( + T& a, + T& b, + double frac + ); + /*! + requires + - 0 <= frac <= 1 + - T must be an array type such as dlib::array or std::vector + ensures + - This function takes the elements of a and splits them into two groups. The + first group remains in a and the second group is put into b. The ordering of + elements in a is preserved. In particular, concatenating #a with #b will + reproduce the original contents of a. + - The elements in a are moved around using global swap(). So they must be + swappable, but do not need to be copyable. + - #a.size() == floor(a.size()*frac) + - #b.size() == a.size()-#a.size() + !*/ +} + +#endif // DLIB_ARRAY_tOOLS_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/array2d.h b/ml/dlib/dlib/array2d.h new file mode 100644 index 000000000..f5325e4a2 --- /dev/null +++ b/ml/dlib/dlib/array2d.h @@ -0,0 +1,12 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2d_ +#define DLIB_ARRAY2d_ + + +#include "array2d/array2d_kernel.h" +#include "array2d/serialize_pixel_overloads.h" +#include "array2d/array2d_generic_image.h" + +#endif // DLIB_ARRAY2d_ + diff --git a/ml/dlib/dlib/array2d/array2d_generic_image.h b/ml/dlib/dlib/array2d/array2d_generic_image.h new file mode 100644 index 000000000..a96f5e3c2 --- /dev/null +++ b/ml/dlib/dlib/array2d/array2d_generic_image.h @@ -0,0 +1,67 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ +#define DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ + +#include "array2d_kernel.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + template + struct image_traits > + { + typedef T pixel_type; + }; + template + struct image_traits > + { + typedef T pixel_type; + }; + + template + inline long num_rows( const array2d& img) { return img.nr(); } + template + inline long num_columns( const array2d& img) { return img.nc(); } + + template + inline void set_image_size( + array2d& img, + long rows, + long cols + ) { img.set_size(rows,cols); } + + template + inline void* image_data( + array2d& img + ) + { + if (img.size() != 0) + return &img[0][0]; + else + return 0; + } + + template + inline const void* image_data( + const array2d& img + ) + { + if (img.size() != 0) + return &img[0][0]; + else + return 0; + } + + template + inline long width_step( + const array2d& img + ) + { + return img.width_step(); + } + +} + +#endif // DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ + diff --git a/ml/dlib/dlib/array2d/array2d_kernel.h b/ml/dlib/dlib/array2d/array2d_kernel.h new file mode 100644 index 000000000..597112341 --- /dev/null +++ b/ml/dlib/dlib/array2d/array2d_kernel.h @@ -0,0 +1,498 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2D_KERNEl_1_ +#define DLIB_ARRAY2D_KERNEl_1_ + +#include "array2d_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../serialize.h" +#include "../geometry/rectangle.h" + +namespace dlib +{ + template < + typename T, + typename mem_manager = default_memory_manager + > + class array2d : public enumerable + { + + /*! + INITIAL VALUE + - nc_ == 0 + - nr_ == 0 + - data == 0 + - at_start_ == true + - cur == 0 + - last == 0 + + CONVENTION + - nc_ == nc() + - nr_ == nc() + - if (data != 0) then + - last == a pointer to the last element in the data array + - data == pointer to an array of nc_*nr_ T objects + - else + - nc_ == 0 + - nr_ == 0 + - data == 0 + - last == 0 + + + - nr_ * nc_ == size() + - if (cur == 0) then + - current_element_valid() == false + - else + - current_element_valid() == true + - *cur == element() + + - at_start_ == at_start() + !*/ + + + class row_helper; + public: + + // These typedefs are here for backwards compatibility with older versions of dlib. + typedef array2d kernel_1a; + typedef array2d kernel_1a_c; + + typedef T type; + typedef mem_manager mem_manager_type; + + // ----------------------------------- + + class row + { + /*! + CONVENTION + - nc_ == nc() + - for all x < nc_: + - (*this)[x] == data[x] + !*/ + + friend class array2d; + friend class row_helper; + + public: + long nc ( + ) const { return nc_; } + + const T& operator[] ( + long column + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(column < nc() && column >= 0, + "\tconst T& array2d::operator[](long column) const" + << "\n\tThe column index given must be less than the number of columns." + << "\n\tthis: " << this + << "\n\tcolumn: " << column + << "\n\tnc(): " << nc() + ); + + return data[column]; + } + + T& operator[] ( + long column + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(column < nc() && column >= 0, + "\tT& array2d::operator[](long column)" + << "\n\tThe column index given must be less than the number of columns." + << "\n\tthis: " << this + << "\n\tcolumn: " << column + << "\n\tnc(): " << nc() + ); + + return data[column]; + } + + private: + + row(T* data_, long cols) : data(data_), nc_(cols) {} + + T* data; + long nc_; + + + // restricted functions + row(){} + row& operator=(row&); + }; + + // ----------------------------------- + + array2d ( + ) : + data(0), + nc_(0), + nr_(0), + cur(0), + last(0), + at_start_(true) + { + } + + array2d( + long rows, + long cols + ) : + data(0), + nc_(0), + nr_(0), + cur(0), + last(0), + at_start_(true) + { + // make sure requires clause is not broken + DLIB_ASSERT((cols >= 0 && rows >= 0), + "\t array2d::array2d(long rows, long cols)" + << "\n\t The array2d can't have negative rows or columns." + << "\n\t this: " << this + << "\n\t cols: " << cols + << "\n\t rows: " << rows + ); + + set_size(rows,cols); + } + + array2d(const array2d&) = delete; // copy constructor + array2d& operator=(const array2d&) = delete; // assignment operator + +#ifdef DLIB_HAS_RVALUE_REFERENCES + array2d(array2d&& item) : array2d() + { + swap(item); + } + + array2d& operator= ( + array2d&& rhs + ) + { + swap(rhs); + return *this; + } +#endif + + virtual ~array2d ( + ) { clear(); } + + long nc ( + ) const { return nc_; } + + long nr ( + ) const { return nr_; } + + row operator[] ( + long row_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(row_ < nr() && row_ >= 0, + "\trow array2d::operator[](long row_)" + << "\n\tThe row index given must be less than the number of rows." + << "\n\tthis: " << this + << "\n\trow_: " << row_ + << "\n\tnr(): " << nr() + ); + + return row(data+row_*nc_, nc_); + } + + const row operator[] ( + long row_ + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(row_ < nr() && row_ >= 0, + "\tconst row array2d::operator[](long row_) const" + << "\n\tThe row index given must be less than the number of rows." + << "\n\tthis: " << this + << "\n\trow_: " << row_ + << "\n\tnr(): " << nr() + ); + + return row(data+row_*nc_, nc_); + } + + void swap ( + array2d& item + ) + { + exchange(data,item.data); + exchange(nr_,item.nr_); + exchange(nc_,item.nc_); + exchange(at_start_,item.at_start_); + exchange(cur,item.cur); + exchange(last,item.last); + pool.swap(item.pool); + } + + void clear ( + ) + { + if (data != 0) + { + pool.deallocate_array(data); + nc_ = 0; + nr_ = 0; + data = 0; + at_start_ = true; + cur = 0; + last = 0; + } + } + + void set_size ( + long rows, + long cols + ); + + bool at_start ( + ) const { return at_start_; } + + void reset ( + ) const { at_start_ = true; cur = 0; } + + bool current_element_valid ( + ) const { return (cur != 0); } + + const T& element ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tconst T& array2d::element()()" + << "\n\tYou can only call element() when you are at a valid one." + << "\n\tthis: " << this + ); + + return *cur; + } + + T& element ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tT& array2d::element()()" + << "\n\tYou can only call element() when you are at a valid one." + << "\n\tthis: " << this + ); + + return *cur; + } + + bool move_next ( + ) const + { + if (cur != 0) + { + if (cur != last) + { + ++cur; + return true; + } + cur = 0; + return false; + } + else if (at_start_) + { + cur = data; + at_start_ = false; + return (data != 0); + } + else + { + return false; + } + } + + size_t size ( + ) const { return static_cast(nc_) * static_cast(nr_); } + + long width_step ( + ) const + { + return nc_*sizeof(T); + } + + private: + + + T* data; + long nc_; + long nr_; + + typename mem_manager::template rebind::other pool; + mutable T* cur; + T* last; + mutable bool at_start_; + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + inline void swap ( + array2d& a, + array2d& b + ) { a.swap(b); } + + + template < + typename T, + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + item.reset(); + while (item.move_next()) + serialize(item.element(),out); + item.reset(); + } + catch (serialization_error e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename T, + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + item.set_size(nr,nc); + + while (item.move_next()) + deserialize(item.element(),in); + item.reset(); + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array2d:: + set_size ( + long rows, + long cols + ) + { + // make sure requires clause is not broken + DLIB_ASSERT((cols >= 0 && rows >= 0) , + "\tvoid array2d::set_size(long rows, long cols)" + << "\n\tThe array2d can't have negative rows or columns." + << "\n\tthis: " << this + << "\n\tcols: " << cols + << "\n\trows: " << rows + ); + + // set the enumerator back at the start + at_start_ = true; + cur = 0; + + // don't do anything if we are already the right size. + if (nc_ == cols && nr_ == rows) + { + return; + } + + nc_ = cols; + nr_ = rows; + + // free any existing memory + if (data != 0) + { + pool.deallocate_array(data); + data = 0; + } + + // now setup this object to have the new size + try + { + if (nr_ > 0) + { + data = pool.allocate_array(nr_*nc_); + last = data + nr_*nc_ - 1; + } + } + catch (...) + { + if (data) + pool.deallocate_array(data); + + data = 0; + nc_ = 0; + nr_ = 0; + last = 0; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template + struct is_array2d > + { + const static bool value = true; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ARRAY2D_KERNEl_1_ + diff --git a/ml/dlib/dlib/array2d/array2d_kernel_abstract.h b/ml/dlib/dlib/array2d/array2d_kernel_abstract.h new file mode 100644 index 000000000..cbb0e9b2b --- /dev/null +++ b/ml/dlib/dlib/array2d/array2d_kernel_abstract.h @@ -0,0 +1,301 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ARRAY2D_KERNEl_ABSTRACT_ +#ifdef DLIB_ARRAY2D_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../serialize.h" +#include "../algs.h" +#include "../geometry/rectangle_abstract.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class array2d : public enumerable + { + + /*! + REQUIREMENTS ON T + T must have a default constructor. + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + No member functions in this object will invalidate pointers + or references to internal data except for the set_size() + and clear() member functions. + + INITIAL VALUE + nr() == 0 + nc() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements of the array starting + with row 0 and then proceeding to row 1 and so on. Each row will be + fully enumerated before proceeding on to the next row and the elements + in a row will be enumerated beginning with the 0th column, then the 1st + column and so on. + + WHAT THIS OBJECT REPRESENTS + This object represents a 2-Dimensional array of objects of + type T. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + + Finally, note that this object stores its data contiguously and in + row major order. Moreover, there is no padding at the end of each row. + This means that its width_step() value is always equal to sizeof(type)*nc(). + !*/ + + + public: + + // ---------------------------------------- + + typedef T type; + typedef mem_manager mem_manager_type; + + // ---------------------------------------- + + class row + { + /*! + POINTERS AND REFERENCES TO INTERNAL DATA + No member functions in this object will invalidate pointers + or references to internal data. + + WHAT THIS OBJECT REPRESENTS + This object represents a row of Ts in an array2d object. + !*/ + public: + long nc ( + ) const; + /*! + ensures + - returns the number of columns in this row + !*/ + + const T& operator[] ( + long column + ) const; + /*! + requires + - 0 <= column < nc() + ensures + - returns a const reference to the T in the given column + !*/ + + T& operator[] ( + long column + ); + /*! + requires + - 0 <= column < nc() + ensures + - returns a non-const reference to the T in the given column + !*/ + + private: + // restricted functions + row(); + row& operator=(row&); + }; + + // ---------------------------------------- + + array2d ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + array2d(const array2d&) = delete; // copy constructor + array2d& operator=(const array2d&) = delete; // assignment operator + + array2d( + array2d&& item + ); + /*! + ensures + - Moves the state of item into *this. + - #item is in a valid but unspecified state. + !*/ + + array2d ( + long rows, + long cols + ); + /*! + requires + - rows >= 0 && cols >= 0 + ensures + - #nc() == cols + - #nr() == rows + - #at_start() == true + - all elements in this array have initial values for their type + throws + - std::bad_alloc + !*/ + + virtual ~array2d ( + ); + /*! + ensures + - all resources associated with *this has been released + !*/ + + void clear ( + ); + /*! + ensures + - #*this has an initial value for its type + !*/ + + long nc ( + ) const; + /*! + ensures + - returns the number of elements there are in a row. i.e. returns + the number of columns in *this + !*/ + + long nr ( + ) const; + /*! + ensures + - returns the number of rows in *this + !*/ + + void set_size ( + long rows, + long cols + ); + /*! + requires + - rows >= 0 && cols >= 0 + ensures + - #nc() == cols + - #nr() == rows + - #at_start() == true + - if (the call to set_size() doesn't change the dimensions of this array) then + - all elements in this array retain their values from before this function was called + - else + - all elements in this array have initial values for their type + throws + - std::bad_alloc + If this exception is thrown then #*this will have an initial + value for its type. + !*/ + + row operator[] ( + long row_index + ); + /*! + requires + - 0 <= row_index < nr() + ensures + - returns a non-const row of nc() elements that represents the + given row_index'th row in *this. + !*/ + + const row operator[] ( + long row_index + ) const; + /*! + requires + - 0 <= row_index < nr() + ensures + - returns a const row of nc() elements that represents the + given row_index'th row in *this. + !*/ + + void swap ( + array2d& item + ); + /*! + ensures + - swaps *this and item + !*/ + + array2d& operator= ( + array2d&& rhs + ); + /*! + ensures + - Moves the state of item into *this. + - #item is in a valid but unspecified state. + - returns #*this + !*/ + + long width_step ( + ) const; + /*! + ensures + - returns the size of one row of the image, in bytes. + More precisely, return a number N such that: + (char*)&item[0][0] + N == (char*)&item[1][0]. + - for dlib::array2d objects, the returned value + is always equal to sizeof(type)*nc(). However, + other objects which implement dlib::array2d style + interfaces might have padding at the ends of their + rows and therefore might return larger numbers. + An example of such an object is the dlib::cv_image. + !*/ + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + array2d& a, + array2d& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ); + /*! + Provides serialization support. Note that the serialization formats used by the + dlib::matrix and dlib::array2d objects are compatible. That means you can load the + serialized data from one into another and it will work properly. + !*/ + + template < + typename T, + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_ARRAY2D_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/array2d/serialize_pixel_overloads.h b/ml/dlib/dlib/array2d/serialize_pixel_overloads.h new file mode 100644 index 000000000..9ce2c4a13 --- /dev/null +++ b/ml/dlib/dlib/array2d/serialize_pixel_overloads.h @@ -0,0 +1,371 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ +#define DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ + +#include "array2d_kernel.h" +#include "../pixel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /* + This file contains overloads of the serialize functions for array2d object + for the case where they contain simple 8bit POD pixel types. In these + cases we can perform a much faster serialization by writing data in chunks + instead of one pixel at a time (this avoids a lot of function call overhead + inside the iostreams). + */ + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(rgb_pixel)*item.size()); + } + catch (serialization_error e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(rgb_pixel)*item.size()); + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(bgr_pixel)*item.size()); + } + catch (serialization_error e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(bgr_pixel)*item.size()); + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(hsi_pixel)*item.size()); + } + catch (serialization_error e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(hsi_pixel)*item.size()); + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size()); + } + catch (serialization_error e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size()); + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(unsigned char)*item.size()); + } + catch (serialization_error e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(unsigned char)*item.size()); + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ + diff --git a/ml/dlib/dlib/assert.h b/ml/dlib/dlib/assert.h new file mode 100644 index 000000000..2220dd73a --- /dev/null +++ b/ml/dlib/dlib/assert.h @@ -0,0 +1,216 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ASSERt_ +#define DLIB_ASSERt_ + +#include "config.h" +#include +#include +#include "error.h" + +// ----------------------------- + +// Use some stuff from boost here +// (C) Copyright John Maddock 2001 - 2003. +// (C) Copyright Darin Adler 2001. +// (C) Copyright Peter Dimov 2001. +// (C) Copyright Bill Kempf 2002. +// (C) Copyright Jens Maurer 2002. +// (C) Copyright David Abrahams 2002 - 2003. +// (C) Copyright Gennaro Prota 2003. +// (C) Copyright Eric Friedman 2003. +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef BOOST_JOIN +#define BOOST_JOIN( X, Y ) BOOST_DO_JOIN( X, Y ) +#define BOOST_DO_JOIN( X, Y ) BOOST_DO_JOIN2(X,Y) +#define BOOST_DO_JOIN2( X, Y ) X##Y +#endif + +// figure out if the compiler has rvalue references. +#if defined(__clang__) +# if __has_feature(cxx_rvalue_references) +# define DLIB_HAS_RVALUE_REFERENCES +# endif +# if __has_feature(cxx_generalized_initializers) +# define DLIB_HAS_INITIALIZER_LISTS +# endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__) +# define DLIB_HAS_RVALUE_REFERENCES +# define DLIB_HAS_INITIALIZER_LISTS +#elif defined(_MSC_VER) && _MSC_VER >= 1800 +# define DLIB_HAS_INITIALIZER_LISTS +# define DLIB_HAS_RVALUE_REFERENCES +#elif defined(_MSC_VER) && _MSC_VER >= 1600 +# define DLIB_HAS_RVALUE_REFERENCES +#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X) +# define DLIB_HAS_RVALUE_REFERENCES +# define DLIB_HAS_INITIALIZER_LISTS +#endif + +#if defined(__APPLE__) && defined(__GNUC_LIBSTD__) && ((__GNUC_LIBSTD__-0) * 100 + __GNUC_LIBSTD_MINOR__-0 <= 402) + // Apple has not updated libstdc++ in some time and anything under 4.02 does not have for sure. +# undef DLIB_HAS_INITIALIZER_LISTS +#endif + +// figure out if the compiler has static_assert. +#if defined(__clang__) +# if __has_feature(cxx_static_assert) +# define DLIB_HAS_STATIC_ASSERT +# endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__) +# define DLIB_HAS_STATIC_ASSERT +#elif defined(_MSC_VER) && _MSC_VER >= 1600 +# define DLIB_HAS_STATIC_ASSERT +#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X) +# define DLIB_HAS_STATIC_ASSERT +#endif + + +// ----------------------------- + +namespace dlib +{ + template struct compile_time_assert; + template <> struct compile_time_assert { enum {value=1}; }; + + template struct assert_are_same_type; + template struct assert_are_same_type {enum{value=1};}; + template struct assert_are_not_same_type {enum{value=1}; }; + template struct assert_are_not_same_type {}; + + template struct assert_types_match {enum{value=0};}; + template struct assert_types_match {enum{value=1};}; +} + + +// gcc 4.8 will warn about unused typedefs. But we use typedefs in some of the compile +// time assert macros so we need to make it not complain about them "not being used". +#ifdef __GNUC__ +#define DLIB_NO_WARN_UNUSED __attribute__ ((unused)) +#else +#define DLIB_NO_WARN_UNUSED +#endif + +// Use the newer static_assert if it's available since it produces much more readable error +// messages. +#ifdef DLIB_HAS_STATIC_ASSERT + #define COMPILE_TIME_ASSERT(expression) static_assert(expression, "Failed assertion") + #define ASSERT_ARE_SAME_TYPE(type1, type2) static_assert(::dlib::assert_types_match::value, "These types should be the same but aren't.") + #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) static_assert(!::dlib::assert_types_match::value, "These types should NOT be the same.") +#else + #define COMPILE_TIME_ASSERT(expression) \ + DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DLIB_CTA, __LINE__)[::dlib::compile_time_assert<(bool)(expression)>::value] + + #define ASSERT_ARE_SAME_TYPE(type1, type2) \ + DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DLIB_AAST, __LINE__)[::dlib::assert_are_same_type::value] + + #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) \ + DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DLIB_AANST, __LINE__)[::dlib::assert_are_not_same_type::value] +#endif + +// ----------------------------- + +#if defined DLIB_DISABLE_ASSERTS + // if DLIB_DISABLE_ASSERTS is on then never enable DLIB_ASSERT no matter what. + #undef ENABLE_ASSERTS +#endif + +#if !defined(DLIB_DISABLE_ASSERTS) && ( defined DEBUG || defined _DEBUG) + // make sure ENABLE_ASSERTS is defined if we are indeed using them. + #ifndef ENABLE_ASSERTS + #define ENABLE_ASSERTS + #endif +#endif + +// ----------------------------- + +#ifdef __GNUC__ +// There is a bug in version 4.4.5 of GCC on Ubuntu which causes GCC to segfault +// when __PRETTY_FUNCTION__ is used within certain templated functions. So just +// don't use it with this version of GCC. +# if !(__GNUC__ == 4 && __GNUC_MINOR__ == 4 && __GNUC_PATCHLEVEL__ == 5) +# define DLIB_FUNCTION_NAME __PRETTY_FUNCTION__ +# else +# define DLIB_FUNCTION_NAME "unknown function" +# endif +#elif defined(_MSC_VER) +#define DLIB_FUNCTION_NAME __FUNCSIG__ +#else +#define DLIB_FUNCTION_NAME "unknown function" +#endif + +#define DLIBM_CASSERT(_exp,_message) \ + {if ( !(_exp) ) \ + { \ + dlib_assert_breakpoint(); \ + std::ostringstream dlib_o_out; \ + dlib_o_out << "\n\nError detected at line " << __LINE__ << ".\n"; \ + dlib_o_out << "Error detected in file " << __FILE__ << ".\n"; \ + dlib_o_out << "Error detected in function " << DLIB_FUNCTION_NAME << ".\n\n"; \ + dlib_o_out << "Failing expression was " << #_exp << ".\n"; \ + dlib_o_out << std::boolalpha << _message << "\n"; \ + throw dlib::fatal_error(dlib::EBROKEN_ASSERT,dlib_o_out.str()); \ + }} + +// This macro is not needed if you have a real C++ compiler. It's here to work around bugs in Visual Studio's preprocessor. +#define DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(x) x +// Make it so the 2nd argument of DLIB_CASSERT is optional. That is, you can call it like +// DLIB_CASSERT(exp) or DLIB_CASSERT(exp,message). +#define DLIBM_CASSERT_1_ARGS(exp) DLIBM_CASSERT(exp,"") +#define DLIBM_CASSERT_2_ARGS(exp,message) DLIBM_CASSERT(exp,message) +#define DLIBM_GET_3TH_ARG(arg1, arg2, arg3, ...) arg3 +#define DLIBM_CASSERT_CHOOSER(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_GET_3TH_ARG(__VA_ARGS__, DLIBM_CASSERT_2_ARGS, DLIBM_CASSERT_1_ARGS)) +#define DLIB_CASSERT(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_CASSERT_CHOOSER(__VA_ARGS__)(__VA_ARGS__)) + + +#ifdef ENABLE_ASSERTS + #define DLIB_ASSERT(...) DLIB_CASSERT(__VA_ARGS__) + #define DLIB_IF_ASSERT(exp) exp +#else + #define DLIB_ASSERT(...) {} + #define DLIB_IF_ASSERT(exp) +#endif + +// ---------------------------------------------------------------------------------------- + + /*!A DLIB_ASSERT_HAS_STANDARD_LAYOUT + + This macro is meant to cause a compiler error if a type doesn't have a simple + memory layout (like a C struct). In particular, types with simple layouts are + ones which can be copied via memcpy(). + + + This was called a POD type in C++03 and in C++0x we are looking to check if + it is a "standard layout type". Once we can use C++0x we can change this macro + to something that uses the std::is_standard_layout type_traits class. + See: http://www2.research.att.com/~bs/C++0xFAQ.html#PODs + !*/ + // Use the fact that in C++03 you can't put non-PODs into a union. +#define DLIB_ASSERT_HAS_STANDARD_LAYOUT(type) \ + union BOOST_JOIN(DAHSL_,__LINE__) { type TYPE_NOT_STANDARD_LAYOUT; }; \ + DLIB_NO_WARN_UNUSED typedef char BOOST_JOIN(DAHSL2_,__LINE__)[sizeof(BOOST_JOIN(DAHSL_,__LINE__))]; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// breakpoints +extern "C" +{ + inline void dlib_assert_breakpoint( + ) {} + /*! + ensures + - this function does nothing + It exists just so you can put breakpoints on it in a debugging tool. + It is called only when an DLIB_ASSERT or DLIB_CASSERT fails and is about to + throw an exception. + !*/ +} + +// ----------------------------- + +#include "stack_trace.h" + +#endif // DLIB_ASSERt_ + diff --git a/ml/dlib/dlib/base64.h b/ml/dlib/dlib/base64.h new file mode 100644 index 000000000..8308920d6 --- /dev/null +++ b/ml/dlib/dlib/base64.h @@ -0,0 +1,9 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASe64_ +#define DLIB_BASe64_ + +#include "base64/base64_kernel_1.h" + +#endif // DLIB_BASe64_ + diff --git a/ml/dlib/dlib/base64/base64_kernel_1.cpp b/ml/dlib/dlib/base64/base64_kernel_1.cpp new file mode 100644 index 000000000..5b48c789e --- /dev/null +++ b/ml/dlib/dlib/base64/base64_kernel_1.cpp @@ -0,0 +1,403 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASE64_KERNEL_1_CPp_ +#define DLIB_BASE64_KERNEL_1_CPp_ + +#include "base64_kernel_1.h" +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + base64::line_ending_type base64:: + line_ending ( + ) const + { + return eol_style; + } + +// ---------------------------------------------------------------------------------------- + + void base64:: + set_line_ending ( + line_ending_type eol_style_ + ) + { + eol_style = eol_style_; + } + +// ---------------------------------------------------------------------------------------- + + base64:: + base64 ( + ) : + encode_table(0), + decode_table(0), + bad_value(100), + eol_style(LF) + { + try + { + encode_table = new char[64]; + decode_table = new unsigned char[UCHAR_MAX]; + } + catch (...) + { + if (encode_table) delete [] encode_table; + if (decode_table) delete [] decode_table; + throw; + } + + // now set up the tables with the right stuff + encode_table[0] = 'A'; + encode_table[17] = 'R'; + encode_table[34] = 'i'; + encode_table[51] = 'z'; + + encode_table[1] = 'B'; + encode_table[18] = 'S'; + encode_table[35] = 'j'; + encode_table[52] = '0'; + + encode_table[2] = 'C'; + encode_table[19] = 'T'; + encode_table[36] = 'k'; + encode_table[53] = '1'; + + encode_table[3] = 'D'; + encode_table[20] = 'U'; + encode_table[37] = 'l'; + encode_table[54] = '2'; + + encode_table[4] = 'E'; + encode_table[21] = 'V'; + encode_table[38] = 'm'; + encode_table[55] = '3'; + + encode_table[5] = 'F'; + encode_table[22] = 'W'; + encode_table[39] = 'n'; + encode_table[56] = '4'; + + encode_table[6] = 'G'; + encode_table[23] = 'X'; + encode_table[40] = 'o'; + encode_table[57] = '5'; + + encode_table[7] = 'H'; + encode_table[24] = 'Y'; + encode_table[41] = 'p'; + encode_table[58] = '6'; + + encode_table[8] = 'I'; + encode_table[25] = 'Z'; + encode_table[42] = 'q'; + encode_table[59] = '7'; + + encode_table[9] = 'J'; + encode_table[26] = 'a'; + encode_table[43] = 'r'; + encode_table[60] = '8'; + + encode_table[10] = 'K'; + encode_table[27] = 'b'; + encode_table[44] = 's'; + encode_table[61] = '9'; + + encode_table[11] = 'L'; + encode_table[28] = 'c'; + encode_table[45] = 't'; + encode_table[62] = '+'; + + encode_table[12] = 'M'; + encode_table[29] = 'd'; + encode_table[46] = 'u'; + encode_table[63] = '/'; + + encode_table[13] = 'N'; + encode_table[30] = 'e'; + encode_table[47] = 'v'; + + encode_table[14] = 'O'; + encode_table[31] = 'f'; + encode_table[48] = 'w'; + + encode_table[15] = 'P'; + encode_table[32] = 'g'; + encode_table[49] = 'x'; + + encode_table[16] = 'Q'; + encode_table[33] = 'h'; + encode_table[50] = 'y'; + + + + // we can now fill out the decode_table by using the encode_table + for (int i = 0; i < UCHAR_MAX; ++i) + { + decode_table[i] = bad_value; + } + for (unsigned char i = 0; i < 64; ++i) + { + decode_table[(unsigned char)encode_table[i]] = i; + } + } + +// ---------------------------------------------------------------------------------------- + + base64:: + ~base64 ( + ) + { + delete [] encode_table; + delete [] decode_table; + } + +// ---------------------------------------------------------------------------------------- + + void base64:: + encode ( + std::istream& in_, + std::ostream& out_ + ) const + { + using namespace std; + streambuf& in = *in_.rdbuf(); + streambuf& out = *out_.rdbuf(); + + unsigned char inbuf[3]; + unsigned char outbuf[4]; + streamsize status = in.sgetn(reinterpret_cast(&inbuf),3); + + unsigned char c1, c2, c3, c4, c5, c6; + + int counter = 19; + + // while we haven't hit the end of the input stream + while (status != 0) + { + if (counter == 0) + { + counter = 19; + // write a newline + char ch; + switch (eol_style) + { + case CR: + ch = '\r'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + break; + case LF: + ch = '\n'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + break; + case CRLF: + ch = '\r'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + ch = '\n'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + break; + default: + DLIB_CASSERT(false,"this should never happen"); + } + } + --counter; + + if (status == 3) + { + // encode the bytes in inbuf to base64 and write them to the output stream + c1 = inbuf[0]&0xfc; + c2 = inbuf[0]&0x03; + c3 = inbuf[1]&0xf0; + c4 = inbuf[1]&0x0f; + c5 = inbuf[2]&0xc0; + c6 = inbuf[2]&0x3f; + + outbuf[0] = c1>>2; + outbuf[1] = (c2<<4)|(c3>>4); + outbuf[2] = (c4<<2)|(c5>>6); + outbuf[3] = c6; + + + outbuf[0] = encode_table[outbuf[0]]; + outbuf[1] = encode_table[outbuf[1]]; + outbuf[2] = encode_table[outbuf[2]]; + outbuf[3] = encode_table[outbuf[3]]; + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),4)!=4) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + + // get 3 more input bytes + status = in.sgetn(reinterpret_cast(&inbuf),3); + continue; + } + else if (status == 2) + { + // we are at the end of the input stream and need to add some padding + + // encode the bytes in inbuf to base64 and write them to the output stream + c1 = inbuf[0]&0xfc; + c2 = inbuf[0]&0x03; + c3 = inbuf[1]&0xf0; + c4 = inbuf[1]&0x0f; + c5 = 0; + + outbuf[0] = c1>>2; + outbuf[1] = (c2<<4)|(c3>>4); + outbuf[2] = (c4<<2)|(c5>>6); + outbuf[3] = '='; + + outbuf[0] = encode_table[outbuf[0]]; + outbuf[1] = encode_table[outbuf[1]]; + outbuf[2] = encode_table[outbuf[2]]; + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),4)!=4) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + + + break; + } + else // in this case status must be 1 + { + // we are at the end of the input stream and need to add some padding + + // encode the bytes in inbuf to base64 and write them to the output stream + c1 = inbuf[0]&0xfc; + c2 = inbuf[0]&0x03; + c3 = 0; + + outbuf[0] = c1>>2; + outbuf[1] = (c2<<4)|(c3>>4); + outbuf[2] = '='; + outbuf[3] = '='; + + outbuf[0] = encode_table[outbuf[0]]; + outbuf[1] = encode_table[outbuf[1]]; + + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),4)!=4) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + + break; + } + } // while (status != 0) + + + // make sure the stream buffer flushes to its I/O channel + out.pubsync(); + } + +// ---------------------------------------------------------------------------------------- + + void base64:: + decode ( + std::istream& in_, + std::ostream& out_ + ) const + { + using namespace std; + streambuf& in = *in_.rdbuf(); + streambuf& out = *out_.rdbuf(); + + unsigned char inbuf[4]; + unsigned char outbuf[3]; + int inbuf_pos = 0; + streamsize status = in.sgetn(reinterpret_cast(inbuf),1); + + // only count this character if it isn't some kind of filler + if (status == 1 && decode_table[inbuf[0]] != bad_value ) + ++inbuf_pos; + + unsigned char c1, c2, c3, c4, c5, c6; + streamsize outsize; + + // while we haven't hit the end of the input stream + while (status != 0) + { + // if we have 4 valid characters + if (inbuf_pos == 4) + { + inbuf_pos = 0; + + // this might be the end of the encoded data so we need to figure out if + // there was any padding applied. + outsize = 3; + if (inbuf[3] == '=') + { + if (inbuf[2] == '=') + outsize = 1; + else + outsize = 2; + } + + // decode the incoming characters + inbuf[0] = decode_table[inbuf[0]]; + inbuf[1] = decode_table[inbuf[1]]; + inbuf[2] = decode_table[inbuf[2]]; + inbuf[3] = decode_table[inbuf[3]]; + + + // now pack these guys into bytes rather than 6 bit chunks + c1 = inbuf[0]<<2; + c2 = inbuf[1]>>4; + c3 = inbuf[1]<<4; + c4 = inbuf[2]>>2; + c5 = inbuf[2]<<6; + c6 = inbuf[3]; + + outbuf[0] = c1|c2; + outbuf[1] = c3|c4; + outbuf[2] = c5|c6; + + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),outsize)!=outsize) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + } + + // get more input characters + status = in.sgetn(reinterpret_cast(inbuf + inbuf_pos),1); + // only count this character if it isn't some kind of filler + if ((decode_table[inbuf[inbuf_pos]] != bad_value || inbuf[inbuf_pos] == '=') && + status != 0) + ++inbuf_pos; + } // while (status != 0) + + if (inbuf_pos != 0) + { + ostringstream sout; + sout << inbuf_pos << " extra characters were found at the end of the encoded data." + << " This may indicate that the data stream has been truncated."; + // this happens if we hit EOF in the middle of decoding a 24bit block. + throw decode_error(sout.str()); + } + + // make sure the stream buffer flushes to its I/O channel + out.pubsync(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BASE64_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/base64/base64_kernel_1.h b/ml/dlib/dlib/base64/base64_kernel_1.h new file mode 100644 index 000000000..d8f49b1b8 --- /dev/null +++ b/ml/dlib/dlib/base64/base64_kernel_1.h @@ -0,0 +1,92 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASE64_KERNEl_1_ +#define DLIB_BASE64_KERNEl_1_ + +#include "../algs.h" +#include "base64_kernel_abstract.h" +#include + +namespace dlib +{ + + class base64 + { + /*! + INITIAL VALUE + - bad_value == 100 + - encode_table == a pointer to an array of 64 chars + - where x is a 6 bit value the following is true: + - encode_table[x] == the base64 encoding of x + - decode_table == a pointer to an array of UCHAR_MAX chars + - where x is any char value: + - if (x is a valid character in the base64 coding scheme) then + - decode_table[x] == the 6 bit value that x encodes + - else + - decode_table[x] == bad_value + + CONVENTION + - The state of this object never changes so just refer to its + initial value. + + + !*/ + + public: + // this is here for backwards compatibility with older versions of dlib. + typedef base64 kernel_1a; + + class decode_error : public dlib::error { public: + decode_error( const std::string& e) : error(e) {}}; + + base64 ( + ); + + virtual ~base64 ( + ); + + enum line_ending_type + { + CR, // i.e. "\r" + LF, // i.e. "\n" + CRLF // i.e. "\r\n" + }; + + line_ending_type line_ending ( + ) const; + + void set_line_ending ( + line_ending_type eol_style_ + ); + + void encode ( + std::istream& in, + std::ostream& out + ) const; + + void decode ( + std::istream& in, + std::ostream& out + ) const; + + private: + + char* encode_table; + unsigned char* decode_table; + const unsigned char bad_value; + line_ending_type eol_style; + + // restricted functions + base64(base64&); // copy constructor + base64& operator=(base64&); // assignment operator + + }; + +} + +#ifdef NO_MAKEFILE +#include "base64_kernel_1.cpp" +#endif + +#endif // DLIB_BASE64_KERNEl_1_ + diff --git a/ml/dlib/dlib/base64/base64_kernel_abstract.h b/ml/dlib/dlib/base64/base64_kernel_abstract.h new file mode 100644 index 000000000..0a63d3b87 --- /dev/null +++ b/ml/dlib/dlib/base64/base64_kernel_abstract.h @@ -0,0 +1,121 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BASE64_KERNEl_ABSTRACT_ +#ifdef DLIB_BASE64_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include + +namespace dlib +{ + + class base64 + { + /*! + INITIAL VALUE + - line_ending() == LF + + WHAT THIS OBJECT REPRESENTS + This object consists of the two functions encode and decode. + These functions allow you to encode and decode data to and from + the Base64 Content-Transfer-Encoding defined in section 6.8 of + rfc2045. + !*/ + + public: + + class decode_error : public dlib::error {}; + + base64 ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~base64 ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + enum line_ending_type + { + CR, // i.e. "\r" + LF, // i.e. "\n" + CRLF // i.e. "\r\n" + }; + + line_ending_type line_ending ( + ) const; + /*! + ensures + - returns the type of end of line bytes the encoder + will use when encoding data to base64 blocks. Note that + the ostream object you use might apply some sort of transform + to line endings as well. For example, C++ ofstream objects + usually convert '\n' into whatever a normal newline is for + your platform unless you open a file in binary mode. But + aside from file streams the ostream objects usually don't + modify the data you pass to them. + !*/ + + void set_line_ending ( + line_ending_type eol_style + ); + /*! + ensures + - #line_ending() == eol_style + !*/ + + void encode ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads all data from in (until EOF is reached) and encodes it + and writes it to out + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + void decode ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads data from in (until EOF is reached), decodes it, + and writes it to out. + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - decode_error + if an error was detected in the encoded data that prevented + it from being correctly decoded then this exception is + thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + private: + + // restricted functions + base64(base64&); // copy constructor + base64& operator=(base64&); // assignment operator + + }; + +} + +#endif // DLIB_BASE64_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/bayes_utils.h b/ml/dlib/dlib/bayes_utils.h new file mode 100644 index 000000000..51ef6d2ed --- /dev/null +++ b/ml/dlib/dlib/bayes_utils.h @@ -0,0 +1,11 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BAYES_UTILs_H_ +#define DLIB_BAYES_UTILs_H_ + +#include "bayes_utils/bayes_utils.h" + +#endif // DLIB_BAYES_UTILs_H_ + + + diff --git a/ml/dlib/dlib/bayes_utils/bayes_utils.h b/ml/dlib/dlib/bayes_utils/bayes_utils.h new file mode 100644 index 000000000..04b3d1187 --- /dev/null +++ b/ml/dlib/dlib/bayes_utils/bayes_utils.h @@ -0,0 +1,1678 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BAYES_UTILs_ +#define DLIB_BAYES_UTILs_ + +#include "bayes_utils_abstract.h" + +#include +#include +#include +#include + +#include "../string.h" +#include "../map.h" +#include "../matrix.h" +#include "../rand.h" +#include "../array.h" +#include "../set.h" +#include "../algs.h" +#include "../noncopyable.h" +#include "../graph.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class assignment + { + public: + + assignment() + { + } + + assignment( + const assignment& a + ) + { + a.reset(); + while (a.move_next()) + { + unsigned long idx = a.element().key(); + unsigned long value = a.element().value(); + vals.add(idx,value); + } + } + + assignment& operator = ( + const assignment& rhs + ) + { + if (this == &rhs) + return *this; + + assignment(rhs).swap(*this); + return *this; + } + + void clear() + { + vals.clear(); + } + + bool operator < ( + const assignment& item + ) const + { + if (size() < item.size()) + return true; + else if (size() > item.size()) + return false; + + reset(); + item.reset(); + while (move_next()) + { + item.move_next(); + if (element().key() < item.element().key()) + return true; + else if (element().key() > item.element().key()) + return false; + else if (element().value() < item.element().value()) + return true; + else if (element().value() > item.element().value()) + return false; + } + + return false; + } + + bool has_index ( + unsigned long idx + ) const + { + return vals.is_in_domain(idx); + } + + void add ( + unsigned long idx, + unsigned long value = 0 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == false , + "\tvoid assignment::add(idx)" + << "\n\tYou can't add the same index to an assignment object more than once" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + vals.add(idx, value); + } + + unsigned long& operator[] ( + const long idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == true , + "\tunsigned long assignment::operator[](idx)" + << "\n\tYou can't access an index value if it isn't already in the object" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + return vals[idx]; + } + + const unsigned long& operator[] ( + const long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == true , + "\tunsigned long assignment::operator[](idx)" + << "\n\tYou can't access an index value if it isn't already in the object" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + return vals[idx]; + } + + void swap ( + assignment& item + ) + { + vals.swap(item.vals); + } + + void remove ( + unsigned long idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == true , + "\tunsigned long assignment::remove(idx)" + << "\n\tYou can't remove an index value if it isn't already in the object" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + vals.destroy(idx); + } + + unsigned long size() const { return vals.size(); } + + void reset() const { vals.reset(); } + + bool move_next() const { return vals.move_next(); } + + map_pair& element() + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tmap_pair& assignment::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + return vals.element(); + } + + const map_pair& element() const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tconst map_pair& assignment::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return vals.element(); + } + + bool at_start() const { return vals.at_start(); } + + bool current_element_valid() const { return vals.current_element_valid(); } + + friend inline void serialize ( + const assignment& item, + std::ostream& out + ) + { + serialize(item.vals, out); + } + + friend inline void deserialize ( + assignment& item, + std::istream& in + ) + { + deserialize(item.vals, in); + } + + private: + mutable dlib::map::kernel_1b_c vals; + }; + + inline std::ostream& operator << ( + std::ostream& out, + const assignment& a + ) + { + a.reset(); + out << "("; + if (a.move_next()) + out << a.element().key() << ":" << a.element().value(); + + while (a.move_next()) + { + out << ", " << a.element().key() << ":" << a.element().value(); + } + + out << ")"; + return out; + } + + + inline void swap ( + assignment& a, + assignment& b + ) + { + a.swap(b); + } + + +// ------------------------------------------------------------------------ + + class joint_probability_table + { + /*! + INITIAL VALUE + - table.size() == 0 + + CONVENTION + - size() == table.size() + - probability(a) == table[a] + !*/ + public: + + joint_probability_table ( + const joint_probability_table& t + ) + { + t.reset(); + while (t.move_next()) + { + assignment a = t.element().key(); + double p = t.element().value(); + set_probability(a,p); + } + } + + joint_probability_table() {} + + joint_probability_table& operator= ( + const joint_probability_table& rhs + ) + { + if (this == &rhs) + return *this; + joint_probability_table(rhs).swap(*this); + return *this; + } + + void set_probability ( + const assignment& a, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0.0 <= p && p <= 1.0, + "\tvoid& joint_probability_table::set_probability(a,p)" + << "\n\tyou have given an invalid probability value" + << "\n\tp: " << p + << "\n\ta: " << a + << "\n\tthis: " << this + ); + + if (table.is_in_domain(a)) + { + table[a] = p; + } + else + { + assignment temp(a); + table.add(temp,p); + } + } + + bool has_entry_for ( + const assignment& a + ) const + { + return table.is_in_domain(a); + } + + void add_probability ( + const assignment& a, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0.0 <= p && p <= 1.0, + "\tvoid& joint_probability_table::add_probability(a,p)" + << "\n\tyou have given an invalid probability value" + << "\n\tp: " << p + << "\n\ta: " << a + << "\n\tthis: " << this + ); + + if (table.is_in_domain(a)) + { + table[a] += p; + if (table[a] > 1.0) + table[a] = 1.0; + } + else + { + assignment temp(a); + table.add(temp,p); + } + } + + double probability ( + const assignment& a + ) const + { + return table[a]; + } + + void clear() + { + table.clear(); + } + + size_t size () const { return table.size(); } + bool move_next() const { return table.move_next(); } + void reset() const { table.reset(); } + map_pair& element() + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tmap_pair& joint_probability_table::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return table.element(); + } + + const map_pair& element() const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tconst map_pair& joint_probability_table::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return table.element(); + } + + bool at_start() const { return table.at_start(); } + + bool current_element_valid() const { return table.current_element_valid(); } + + + template + void marginalize ( + const T& vars, + joint_probability_table& out + ) const + { + out.clear(); + double p; + reset(); + while (move_next()) + { + assignment a; + const assignment& asrc = element().key(); + p = element().value(); + + asrc.reset(); + while (asrc.move_next()) + { + if (vars.is_member(asrc.element().key())) + a.add(asrc.element().key(), asrc.element().value()); + } + + out.add_probability(a,p); + } + } + + void marginalize ( + const unsigned long var, + joint_probability_table& out + ) const + { + out.clear(); + double p; + reset(); + while (move_next()) + { + assignment a; + const assignment& asrc = element().key(); + p = element().value(); + + asrc.reset(); + while (asrc.move_next()) + { + if (var == asrc.element().key()) + a.add(asrc.element().key(), asrc.element().value()); + } + + out.add_probability(a,p); + } + } + + void normalize ( + ) + { + double sum = 0; + + reset(); + while (move_next()) + sum += element().value(); + + reset(); + while (move_next()) + element().value() /= sum; + } + + void swap ( + joint_probability_table& item + ) + { + table.swap(item.table); + } + + friend inline void serialize ( + const joint_probability_table& item, + std::ostream& out + ) + { + serialize(item.table, out); + } + + friend inline void deserialize ( + joint_probability_table& item, + std::istream& in + ) + { + deserialize(item.table, in); + } + + private: + + dlib::map::kernel_1b_c table; + }; + + inline void swap ( + joint_probability_table& a, + joint_probability_table& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + class conditional_probability_table : noncopyable + { + /*! + INITIAL VALUE + - table.size() == 0 + + CONVENTION + - if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) then + - has_entry_for(value,ps) == true + - probability(value,ps) == table[ps](value) + - else + - has_entry_for(value,ps) == false + + - num_values() == num_vals + !*/ + public: + + conditional_probability_table() + { + clear(); + } + + void set_num_values ( + unsigned long num + ) + { + num_vals = num; + table.clear(); + } + + bool has_entry_for ( + unsigned long value, + const assignment& ps + ) const + { + if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) + return true; + else + return false; + } + + unsigned long num_values ( + ) const { return num_vals; } + + void set_probability ( + unsigned long value, + const assignment& ps, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( value < num_values() && 0.0 <= p && p <= 1.0 , + "\tvoid conditional_probability_table::set_probability()" + << "\n\tinvalid arguments to set_probability" + << "\n\tvalue: " << value + << "\n\tnum_values(): " << num_values() + << "\n\tp: " << p + << "\n\tps: " << ps + << "\n\tthis: " << this + ); + + if (table.is_in_domain(ps)) + { + table[ps](value) = p; + } + else + { + matrix dist(num_vals); + set_all_elements(dist,-1); + dist(value) = p; + assignment temp(ps); + table.add(temp,dist); + } + } + + double probability( + unsigned long value, + const assignment& ps + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( value < num_values() && has_entry_for(value,ps) , + "\tvoid conditional_probability_table::probability()" + << "\n\tinvalid arguments to probability" + << "\n\tvalue: " << value + << "\n\tnum_values(): " << num_values() + << "\n\tps: " << ps + << "\n\tthis: " << this + ); + + return table[ps](value); + } + + void clear() + { + table.clear(); + num_vals = 0; + } + + void empty_table () + { + table.clear(); + } + + void swap ( + conditional_probability_table& item + ) + { + exchange(num_vals, item.num_vals); + table.swap(item.table); + } + + friend inline void serialize ( + const conditional_probability_table& item, + std::ostream& out + ) + { + serialize(item.table, out); + serialize(item.num_vals, out); + } + + friend inline void deserialize ( + conditional_probability_table& item, + std::istream& in + ) + { + deserialize(item.table, in); + deserialize(item.num_vals, in); + } + + private: + dlib::map >::kernel_1b_c table; + unsigned long num_vals; + }; + + inline void swap ( + conditional_probability_table& a, + conditional_probability_table& b + ) { a.swap(b); } + +// ------------------------------------------------------------------------ + + class bayes_node : noncopyable + { + public: + bayes_node () + { + is_instantiated = false; + value_ = 0; + } + + unsigned long value ( + ) const { return value_;} + + void set_value ( + unsigned long new_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( new_value < table().num_values(), + "\tvoid bayes_node::set_value(new_value)" + << "\n\tnew_value must be less than the number of possible values for this node" + << "\n\tnew_value: " << new_value + << "\n\ttable().num_values(): " << table().num_values() + << "\n\tthis: " << this + ); + + value_ = new_value; + } + + conditional_probability_table& table ( + ) { return table_; } + + const conditional_probability_table& table ( + ) const { return table_; } + + bool is_evidence ( + ) const { return is_instantiated; } + + void set_as_nonevidence ( + ) { is_instantiated = false; } + + void set_as_evidence ( + ) { is_instantiated = true; } + + void swap ( + bayes_node& item + ) + { + exchange(value_, item.value_); + exchange(is_instantiated, item.is_instantiated); + table_.swap(item.table_); + } + + friend inline void serialize ( + const bayes_node& item, + std::ostream& out + ) + { + serialize(item.value_, out); + serialize(item.is_instantiated, out); + serialize(item.table_, out); + } + + friend inline void deserialize ( + bayes_node& item, + std::istream& in + ) + { + deserialize(item.value_, in); + deserialize(item.is_instantiated, in); + deserialize(item.table_, in); + } + + private: + + unsigned long value_; + bool is_instantiated; + conditional_probability_table table_; + }; + + inline void swap ( + bayes_node& a, + bayes_node& b + ) { a.swap(b); } + +// ------------------------------------------------------------------------ + + namespace bayes_node_utils + { + + template + unsigned long node_num_values ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::node_num_values(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + return bn.node(n).data.table().num_values(); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_value ( + T& bn, + unsigned long n, + unsigned long val + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes() && val < node_num_values(bn,n), + "\tvoid bayes_node_utils::set_node_value(bn, n, val)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tval: " << val + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) + ); + + bn.node(n).data.set_value(val); + } + + // ---------------------------------------------------------------------------------------- + template + unsigned long node_value ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tunsigned long bayes_node_utils::node_value(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + return bn.node(n).data.value(); + } + // ---------------------------------------------------------------------------------------- + + template + bool node_is_evidence ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tbool bayes_node_utils::node_is_evidence(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + return bn.node(n).data.is_evidence(); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_as_evidence ( + T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::set_node_as_evidence(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + bn.node(n).data.set_as_evidence(); + } + + // ---------------------------------------------------------------------------------------- + template + void set_node_as_nonevidence ( + T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::set_node_as_nonevidence(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + bn.node(n).data.set_as_nonevidence(); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_num_values ( + T& bn, + unsigned long n, + unsigned long num + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::set_node_num_values(bn, n, num)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + bn.node(n).data.table().set_num_values(num); + } + + // ---------------------------------------------------------------------------------------- + + template + double node_probability ( + const T& bn, + unsigned long n, + unsigned long value, + const assignment& parents + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tvalue: " << value + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) + ); + + DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tparents.size(): " << parents.size() + << "\n\tb.node(n).number_of_parents(): " << bn.node(n).number_of_parents() + ); + +#ifdef ENABLE_ASSERTS + parents.reset(); + while (parents.move_next()) + { + const unsigned long x = parents.element().key(); + DLIB_ASSERT( bn.has_edge(x, n), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + ); + DLIB_ASSERT( parents[x] < node_num_values(bn,x), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + << "\n\tparents[x]: " << parents[x] + << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) + ); + } +#endif + + return bn.node(n).data.table().probability(value, parents); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_probability ( + T& bn, + unsigned long n, + unsigned long value, + const assignment& parents, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tp: " << p + << "\n\tvalue: " << value + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) + ); + + DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tp: " << p + << "\n\tparents.size(): " << parents.size() + << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() + ); + + DLIB_ASSERT( 0.0 <= p && p <= 1.0, + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tp: " << p + ); + +#ifdef ENABLE_ASSERTS + parents.reset(); + while (parents.move_next()) + { + const unsigned long x = parents.element().key(); + DLIB_ASSERT( bn.has_edge(x, n), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + ); + DLIB_ASSERT( parents[x] < node_num_values(bn,x), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + << "\n\tparents[x]: " << parents[x] + << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) + ); + } +#endif + + bn.node(n).data.table().set_probability(value,parents,p); + } + +// ---------------------------------------------------------------------------------------- + + template + const assignment node_first_parent_assignment ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tconst assignment bayes_node_utils::node_first_parent_assignment(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + ); + + assignment a; + const unsigned long num_parents = bn.node(n).number_of_parents(); + for (unsigned long i = 0; i < num_parents; ++i) + { + a.add(bn.node(n).parent(i).index(), 0); + } + return a; + } + +// ---------------------------------------------------------------------------------------- + + template + bool node_next_parent_assignment ( + const T& bn, + unsigned long n, + assignment& a + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + ); + + DLIB_ASSERT( a.size() == bn.node(n).number_of_parents(), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\ta.size(): " << a.size() + << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() + ); + +#ifdef ENABLE_ASSERTS + a.reset(); + while (a.move_next()) + { + const unsigned long x = a.element().key(); + DLIB_ASSERT( bn.has_edge(x, n), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + ); + DLIB_ASSERT( a[x] < node_num_values(bn,x), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + << "\n\ta[x]: " << a[x] + << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) + ); + } +#endif + + // basically this loop just adds 1 to the assignment but performs + // carries if necessary + for (unsigned long p = 0; p < a.size(); ++p) + { + const unsigned long pindex = bn.node(n).parent(p).index(); + a[pindex] += 1; + + // if we need to perform a carry + if (a[pindex] >= node_num_values(bn,pindex)) + { + a[pindex] = 0; + } + else + { + // no carry necessary so we are done + return true; + } + } + + // we got through the entire loop which means a carry propagated all the way out + // so there must not be any more valid assignments left + return false; + } + +// ---------------------------------------------------------------------------------------- + + template + bool node_cpt_filled_out ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tbool bayes_node_utils::node_cpt_filled_out(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + const unsigned long num_values = node_num_values(bn,n); + + + const conditional_probability_table& table = bn.node(n).data.table(); + + // now loop over all the possible parent assignments for this node + assignment a(node_first_parent_assignment(bn,n)); + do + { + double sum = 0; + // make sure that this assignment has an entry for all the values this node can take one + for (unsigned long value = 0; value < num_values; ++value) + { + if (table.has_entry_for(value,a) == false) + return false; + else + sum += table.probability(value,a); + } + + // check if the sum of probabilities equals 1 as it should + if (std::abs(sum-1.0) > 1e-5) + return false; + } while (node_next_parent_assignment(bn,n,a)); + + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + class bayesian_network_gibbs_sampler : noncopyable + { + public: + + bayesian_network_gibbs_sampler () + { + rnd.set_seed(cast_to_string(std::time(0))); + } + + + template < + typename T + > + void sample_graph ( + T& bn + ) + { + using namespace bayes_node_utils; + for (unsigned long n = 0; n < bn.number_of_nodes(); ++n) + { + if (node_is_evidence(bn, n)) + continue; + + samples.set_size(node_num_values(bn,n)); + // obtain the probability distribution for this node + for (long i = 0; i < samples.nc(); ++i) + { + set_node_value(bn, n, i); + samples(i) = node_probability(bn, n); + + for (unsigned long j = 0; j < bn.node(n).number_of_children(); ++j) + samples(i) *= node_probability(bn, bn.node(n).child(j).index()); + } + + //normalize samples + samples /= sum(samples); + + + // select a random point in the probability distribution + double prob = rnd.get_random_double(); + + // now find the point in the distribution this probability corresponds to + long j; + for (j = 0; j < samples.nc()-1; ++j) + { + if (prob <= samples(j)) + break; + else + prob -= samples(j); + } + + set_node_value(bn, n, j); + } + } + + + private: + + template < + typename T + > + double node_probability ( + const T& bn, + unsigned long n + ) + /*! + requires + - n < bn.number_of_nodes() + ensures + - computes the probability of node n having its current value given + the current values of its parents in the network bn + !*/ + { + v.clear(); + for (unsigned long i = 0; i < bn.node(n).number_of_parents(); ++i) + { + v.add(bn.node(n).parent(i).index(), bn.node(n).parent(i).data.value()); + } + return bn.node(n).data.table().probability(bn.node(n).data.value(), v); + } + + assignment v; + + dlib::rand rnd; + matrix samples; + }; + +// ---------------------------------------------------------------------------------------- + + namespace bayesian_network_join_tree_helpers + { + class bnjt + { + /*! + this object is the base class used in this pimpl idiom + !*/ + public: + virtual ~bnjt() {} + + virtual const matrix probability( + unsigned long idx + ) const = 0; + }; + + template + class bnjt_impl : public bnjt + { + /*! + This object is the implementation in the pimpl idiom + !*/ + + public: + + bnjt_impl ( + const T& bn, + const U& join_tree + ) + { + create_bayesian_network_join_tree(bn, join_tree, join_tree_values); + + cliques.resize(bn.number_of_nodes()); + + // figure out which cliques contain each node + for (unsigned long i = 0; i < cliques.size(); ++i) + { + // find the smallest clique that contains node with index i + unsigned long smallest_clique = 0; + unsigned long size = std::numeric_limits::max(); + + for (unsigned long n = 0; n < join_tree.number_of_nodes(); ++n) + { + if (join_tree.node(n).data.is_member(i) && join_tree.node(n).data.size() < size) + { + size = join_tree.node(n).data.size(); + smallest_clique = n; + } + } + + cliques[i] = smallest_clique; + } + } + + virtual const matrix probability( + unsigned long idx + ) const + { + join_tree_values.node(cliques[idx]).data.marginalize(idx, table); + table.normalize(); + var.clear(); + var.add(idx); + dist.set_size(table.size()); + + // read the probabilities out of the table and into the row matrix + for (unsigned long i = 0; i < table.size(); ++i) + { + var[idx] = i; + dist(i) = table.probability(var); + } + + return dist; + } + + private: + + graph< joint_probability_table, joint_probability_table >::kernel_1a_c join_tree_values; + array cliques; + mutable joint_probability_table table; + mutable assignment var; + mutable matrix dist; + + + // ---------------------------------------------------------------------------------------- + + template + bool set_contains_all_parents_of_node ( + const set_type& set, + const node_type& node + ) + { + for (unsigned long i = 0; i < node.number_of_parents(); ++i) + { + if (set.is_member(node.parent(i).index()) == false) + return false; + } + return true; + } + + // ---------------------------------------------------------------------------------------- + + template < + typename V + > + void pass_join_tree_message ( + const U& join_tree, + V& bn_join_tree , + unsigned long from, + unsigned long to + ) + { + using namespace bayes_node_utils; + const typename U::edge_type& e = edge(join_tree, from, to); + typename V::edge_type& old_s = edge(bn_join_tree, from, to); + + typedef typename V::edge_type joint_prob_table; + + joint_prob_table new_s; + bn_join_tree.node(from).data.marginalize(e, new_s); + + joint_probability_table temp(new_s); + // divide new_s by old_s and store the result in temp. + // if old_s is empty then that is the same as if it was all 1s + // so we don't have to do this if that is the case. + if (old_s.size() > 0) + { + temp.reset(); + old_s.reset(); + while (temp.move_next()) + { + old_s.move_next(); + if (old_s.element().value() != 0) + temp.element().value() /= old_s.element().value(); + } + } + + // now multiply temp by d and store the results in d + joint_probability_table& d = bn_join_tree.node(to).data; + d.reset(); + while (d.move_next()) + { + assignment a; + const assignment& asrc = d.element().key(); + asrc.reset(); + while (asrc.move_next()) + { + if (e.is_member(asrc.element().key())) + a.add(asrc.element().key(), asrc.element().value()); + } + + d.element().value() *= temp.probability(a); + + } + + // store new_s in old_s + new_s.swap(old_s); + + } + + // ---------------------------------------------------------------------------------------- + + template < + typename V + > + void create_bayesian_network_join_tree ( + const T& bn, + const U& join_tree, + V& bn_join_tree + ) + /*! + requires + - bn is a proper bayesian network + - join_tree is the join tree for that bayesian network + ensures + - bn_join_tree == the output of the join tree algorithm for bayesian network inference. + So each node in this graph contains a joint_probability_table for the clique + in the corresponding node in the join_tree graph. + !*/ + { + using namespace bayes_node_utils; + bn_join_tree.clear(); + copy_graph_structure(join_tree, bn_join_tree); + + // we need to keep track of which node is "in" each clique for the purposes of + // initializing the tables in each clique. So this vector will be used to do that + // and a value of join_tree.number_of_nodes() means that the node with + // that index is unassigned. + std::vector node_assigned_to(bn.number_of_nodes(),join_tree.number_of_nodes()); + + // populate evidence with all the evidence node indices and their values + dlib::map::kernel_1b_c evidence; + for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) + { + if (node_is_evidence(bn, i)) + { + unsigned long idx = i; + unsigned long value = node_value(bn, i); + evidence.add(idx,value); + } + } + + + // initialize the bn join tree + for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i) + { + bool contains_evidence = false; + std::vector indices; + assignment value; + + // loop over all the nodes in this clique in the join tree. In this loop + // we are making an assignment with all the values of the nodes it represents set to 0 + join_tree.node(i).data.reset(); + while (join_tree.node(i).data.move_next()) + { + const unsigned long idx = join_tree.node(i).data.element(); + indices.push_back(idx); + value.add(idx); + + if (evidence.is_in_domain(join_tree.node(i).data.element())) + contains_evidence = true; + } + + // now loop over all possible combinations of values that the nodes this + // clique in the join tree can take on. We do this by counting by one through all + // legal values + bool more_assignments = true; + while (more_assignments) + { + bn_join_tree.node(i).data.set_probability(value,1); + + // account for any evidence + if (contains_evidence) + { + // loop over all the nodes in this cluster + for (unsigned long j = 0; j < indices.size(); ++j) + { + // if the current node is an evidence node + if (evidence.is_in_domain(indices[j])) + { + const unsigned long idx = indices[j]; + const unsigned long evidence_value = evidence[idx]; + if (value[idx] != evidence_value) + bn_join_tree.node(i).data.set_probability(value , 0); + } + } + } + + + // now check if any of the nodes in this cluster also have their parents in this cluster + join_tree.node(i).data.reset(); + while (join_tree.node(i).data.move_next()) + { + const unsigned long idx = join_tree.node(i).data.element(); + // if this clique contains all the parents of this node and also hasn't + // been assigned to another clique + if (set_contains_all_parents_of_node(join_tree.node(i).data, bn.node(idx)) && + (i == node_assigned_to[idx] || node_assigned_to[idx] == join_tree.number_of_nodes()) ) + { + // note that this node is now assigned to this clique + node_assigned_to[idx] = i; + // node idx has all its parents in the cluster + assignment parent_values; + for (unsigned long j = 0; j < bn.node(idx).number_of_parents(); ++j) + { + const unsigned long pidx = bn.node(idx).parent(j).index(); + parent_values.add(pidx, value[pidx]); + } + + double temp = bn_join_tree.node(i).data.probability(value); + bn_join_tree.node(i).data.set_probability(value, temp * node_probability(bn, idx, value[idx], parent_values)); + + } + } + + + // now advance the value variable to its next possible state if there is one + more_assignments = false; + value.reset(); + while (value.move_next()) + { + value.element().value() += 1; + // if overflow + if (value.element().value() == node_num_values(bn, value.element().key())) + { + value.element().value() = 0; + } + else + { + more_assignments = true; + break; + } + } + + } // end while (more_assignments) + } + + + + + // the tree is now initialized. Now all we need to do is perform the propagation and + // we are done + dlib::array::compare_1b_c> remaining_msg_to_send; + dlib::array::compare_1b_c> remaining_msg_to_receive; + remaining_msg_to_receive.resize(join_tree.number_of_nodes()); + remaining_msg_to_send.resize(join_tree.number_of_nodes()); + for (unsigned long i = 0; i < remaining_msg_to_receive.size(); ++i) + { + for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j) + { + const unsigned long idx = join_tree.node(i).neighbor(j).index(); + unsigned long temp; + temp = idx; remaining_msg_to_receive[i].add(temp); + temp = idx; remaining_msg_to_send[i].add(temp); + } + } + + // now remaining_msg_to_receive[i] contains all the nodes that node i hasn't yet received + // a message from. + // we will consider node 0 to be the root node. + + + bool message_sent = true; + std::vector::iterator iter; + while (message_sent) + { + message_sent = false; + for (unsigned long i = 1; i < remaining_msg_to_send.size(); ++i) + { + // if node i hasn't sent any messages but has received all but one then send a message to the one + // node who hasn't sent i a message + if (remaining_msg_to_send[i].size() == join_tree.node(i).number_of_neighbors() && remaining_msg_to_receive[i].size() == 1) + { + unsigned long to; + // get the last remaining thing from this set + remaining_msg_to_receive[i].remove_any(to); + + // send the message + pass_join_tree_message(join_tree, bn_join_tree, i, to); + + // record that we sent this message + remaining_msg_to_send[i].destroy(to); + remaining_msg_to_receive[to].destroy(i); + + // put to back in since we still need to receive it + remaining_msg_to_receive[i].add(to); + message_sent = true; + } + else if (remaining_msg_to_receive[i].size() == 0 && remaining_msg_to_send[i].size() > 0) + { + unsigned long to; + remaining_msg_to_send[i].remove_any(to); + remaining_msg_to_receive[to].destroy(i); + pass_join_tree_message(join_tree, bn_join_tree, i, to); + message_sent = true; + } + } + + if (remaining_msg_to_receive[0].size() == 0) + { + // send a message to all of the root nodes neighbors unless we have already sent out he messages + while (remaining_msg_to_send[0].size() > 0) + { + unsigned long to; + remaining_msg_to_send[0].remove_any(to); + remaining_msg_to_receive[to].destroy(0); + pass_join_tree_message(join_tree, bn_join_tree, 0, to); + message_sent = true; + } + } + + + } + + } + + }; + } + + class bayesian_network_join_tree : noncopyable + { + /*! + use the pimpl idiom to push the template arguments from the class level to the + constructor level + !*/ + + public: + + template < + typename T, + typename U + > + bayesian_network_join_tree ( + const T& bn, + const U& join_tree + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( bn.number_of_nodes() > 0 , + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network" + << "\n\tthis: " << this + ); + + DLIB_ASSERT( is_join_tree(bn, join_tree) == true , + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid join tree for the supplied bayesian network" + << "\n\tthis: " << this + ); + DLIB_ASSERT( graph_contains_length_one_cycle(bn) == false, + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network" + << "\n\tthis: " << this + ); + DLIB_ASSERT( graph_is_connected(bn) == true, + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network" + << "\n\tthis: " << this + ); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) + { + DLIB_ASSERT(bayes_node_utils::node_cpt_filled_out(bn,i) == true, + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network. " + << "\n\tYou must finish filling out the conditional_probability_table of node " << i + << "\n\tthis: " << this + ); + } +#endif + + impl.reset(new bayesian_network_join_tree_helpers::bnjt_impl(bn, join_tree)); + num_nodes = bn.number_of_nodes(); + } + + const matrix probability( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( idx < number_of_nodes() , + "\tconst matrix bayesian_network_join_tree::probability(idx)" + << "\n\tYou have specified an invalid node index" + << "\n\tidx: " << idx + << "\n\tnumber_of_nodes(): " << number_of_nodes() + << "\n\tthis: " << this + ); + + return impl->probability(idx); + } + + unsigned long number_of_nodes ( + ) const { return num_nodes; } + + void swap ( + bayesian_network_join_tree& item + ) + { + exchange(num_nodes, item.num_nodes); + impl.swap(item.impl); + } + + private: + + std::unique_ptr impl; + unsigned long num_nodes; + + }; + + inline void swap ( + bayesian_network_join_tree& a, + bayesian_network_join_tree& b + ) { a.swap(b); } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_BAYES_UTILs_ + diff --git a/ml/dlib/dlib/bayes_utils/bayes_utils_abstract.h b/ml/dlib/dlib/bayes_utils/bayes_utils_abstract.h new file mode 100644 index 000000000..b19e6e1da --- /dev/null +++ b/ml/dlib/dlib/bayes_utils/bayes_utils_abstract.h @@ -0,0 +1,1042 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BAYES_UTILs_ABSTRACT_ +#ifdef DLIB_BAYES_UTILs_ABSTRACT_ + +#include "../algs.h" +#include "../noncopyable.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/map_pair.h" +#include "../serialize.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class assignment : public enumerable > + { + /*! + INITIAL VALUE + - size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the entries in the assignment in + ascending order according to index values. (i.e. the elements are + enumerated in sorted order according to the value of their keys) + + WHAT THIS OBJECT REPRESENTS + This object models an assignment of random variables to particular values. + It is used with the joint_probability_table and conditional_probability_table + objects to represent assignments of various random variables to actual values. + + So for example, if you had a joint_probability_table that represented the + following table: + P(A = 0, B = 0) = 0.2 + P(A = 0, B = 1) = 0.3 + P(A = 1, B = 0) = 0.1 + P(A = 1, B = 1) = 0.4 + + Also lets define an enum so we have concrete index numbers for A and B + enum { A = 0, B = 1}; + + Then you could query the value of P(A=1, B=0) as follows: + assignment a; + a.set(A, 1); + a.set(B, 0); + // and now it is the case that: + table.probability(a) == 0.1 + a[A] == 1 + a[B] == 0 + + + Also note that when enumerating the elements of an assignment object + the key() refers to the index and the value() refers to the value at that + index. For example: + + // assume a is an assignment object + a.reset(); + while (a.move_next()) + { + // in this loop it is always the case that: + // a[a.element().key()] == a.element().value() + } + !*/ + + public: + + assignment( + ); + /*! + ensures + - this object is properly initialized + !*/ + + assignment( + const assignment& a + ); + /*! + ensures + - #*this is a copy of a + !*/ + + assignment& operator = ( + const assignment& rhs + ); + /*! + ensures + - #*this is a copy of rhs + - returns *this + !*/ + + void clear( + ); + /*! + ensures + - this object has been returned to its initial value + !*/ + + bool operator < ( + const assignment& item + ) const; + /*! + ensures + - The exact functioning of this operator is undefined. The only guarantee + is that it establishes a total ordering on all possible assignment objects. + In other words, this operator makes it so that you can use assignment + objects in the associative containers but otherwise isn't of any + particular use. + !*/ + + bool has_index ( + unsigned long idx + ) const; + /*! + ensures + - if (this assignment object has an entry for index idx) then + - returns true + - else + - returns false + !*/ + + void add ( + unsigned long idx, + unsigned long value = 0 + ); + /*! + requires + - has_index(idx) == false + ensures + - #has_index(idx) == true + - #(*this)[idx] == value + !*/ + + void remove ( + unsigned long idx + ); + /*! + requires + - has_index(idx) == true + ensures + - #has_index(idx) == false + !*/ + + unsigned long& operator[] ( + const long idx + ); + /*! + requires + - has_index(idx) == true + ensures + - returns a reference to the value associated with index idx + !*/ + + const unsigned long& operator[] ( + const long idx + ) const; + /*! + requires + - has_index(idx) == true + ensures + - returns a const reference to the value associated with index idx + !*/ + + void swap ( + assignment& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + assignment& a, + assignment& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + std::ostream& operator << ( + std::ostream& out, + const assignment& a + ); + /*! + ensures + - writes a to the given output stream in the following format: + (index1:value1, index2:value2, ..., indexN:valueN) + !*/ + + void serialize ( + const assignment& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + assignment& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ------------------------------------------------------------------------ + + class joint_probability_table : public enumerable > + { + /*! + INITIAL VALUE + - size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the entries in the probability table + in no particular order but they will all be visited. + + WHAT THIS OBJECT REPRESENTS + This object models a joint probability table. That is, it models + the function p(X). So this object models the probability of a particular + set of variables (referred to as X). + !*/ + + public: + + joint_probability_table( + ); + /*! + ensures + - this object is properly initialized + !*/ + + joint_probability_table ( + const joint_probability_table& t + ); + /*! + ensures + - this object is a copy of t + !*/ + + void clear( + ); + /*! + ensures + - this object has its initial value + !*/ + + joint_probability_table& operator= ( + const joint_probability_table& rhs + ); + /*! + ensures + - this object is a copy of rhs + - returns a reference to *this + !*/ + + bool has_entry_for ( + const assignment& a + ) const; + /*! + ensures + - if (this joint_probability_table has an entry for p(X = a)) then + - returns true + - else + - returns false + !*/ + + void set_probability ( + const assignment& a, + double p + ); + /*! + requires + - 0 <= p <= 1 + ensures + - if (has_entry_for(a) == false) then + - #size() == size() + 1 + - #probability(a) == p + - #has_entry_for(a) == true + !*/ + + void add_probability ( + const assignment& a, + double p + ); + /*! + requires + - 0 <= p <= 1 + ensures + - if (has_entry_for(a) == false) then + - #size() == size() + 1 + - #probability(a) == p + - else + - #probability(a) == min(probability(a) + p, 1.0) + (i.e. does a saturating add) + - #has_entry_for(a) == true + !*/ + + const double probability ( + const assignment& a + ) const; + /*! + ensures + - returns the probability p(X == a) + !*/ + + template < + typename T + > + void marginalize ( + const T& vars, + joint_probability_table& output_table + ) const; + /*! + requires + - T is an implementation of set/set_kernel_abstract.h + ensures + - marginalizes *this by summing over all variables not in vars. The + result is stored in output_table. + !*/ + + void marginalize ( + const unsigned long var, + joint_probability_table& output_table + ) const; + /*! + ensures + - is identical to calling the above marginalize() function with a set + that contains only var. Or in other words, performs a marginalization + with just one variable var. So that output_table will contain a table giving + the marginal probability of var all by itself. + !*/ + + void normalize ( + ); + /*! + ensures + - let sum == the sum of all the probabilities in this table + - after normalize() has finished it will be the case that the sum of all + the entries in this table is 1.0. This is accomplished by dividing all + the entries by the sum described above. + !*/ + + void swap ( + joint_probability_table& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + joint_probability_table& a, + joint_probability_table& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + void serialize ( + const joint_probability_table& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + joint_probability_table& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + class conditional_probability_table : noncopyable + { + /*! + INITIAL VALUE + - num_values() == 0 + - has_value_for(x, y) == false for all values of x and y + + WHAT THIS OBJECT REPRESENTS + This object models a conditional probability table. That is, it models + the function p( X | parents). So this object models the conditional + probability of a particular variable (referred to as X) given another set + of variables (referred to as parents). + !*/ + + public: + + conditional_probability_table( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear( + ); + /*! + ensures + - this object has its initial value + !*/ + + void empty_table ( + ); + /*! + ensures + - for all possible v and p: + - #has_entry_for(v,p) == false + (i.e. this function clears out the table when you call it but doesn't + change the value of num_values()) + !*/ + + void set_num_values ( + unsigned long num + ); + /*! + ensures + - #num_values() == num + - for all possible v and p: + - #has_entry_for(v,p) == false + (i.e. this function clears out the table when you call it) + !*/ + + unsigned long num_values ( + ) const; + /*! + ensures + - This object models the probability table p(X | parents). This + function returns the number of values X can take on. + !*/ + + bool has_entry_for ( + unsigned long value, + const assignment& ps + ) const; + /*! + ensures + - if (this conditional_probability_table has an entry for p(X = value, parents = ps)) then + - returns true + - else + - returns false + !*/ + + void set_probability ( + unsigned long value, + const assignment& ps, + double p + ); + /*! + requires + - value < num_values() + - 0 <= p <= 1 + ensures + - #probability(ps, value) == p + - #has_entry_for(value, ps) == true + !*/ + + double probability( + unsigned long value, + const assignment& ps + ) const; + /*! + requires + - value < num_values() + - has_entry_for(value, ps) == true + ensures + - returns the probability p( X = value | parents = ps). + !*/ + + void swap ( + conditional_probability_table& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + + inline void swap ( + conditional_probability_table& a, + conditional_probability_table& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + void serialize ( + const conditional_probability_table& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + conditional_probability_table& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ------------------------------------------------------------------------ +// ------------------------------------------------------------------------ +// ------------------------------------------------------------------------ + + class bayes_node : noncopyable + { + /*! + INITIAL VALUE + - is_evidence() == false + - value() == 0 + - table().num_values() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a node in a bayesian network. It is + intended to be used inside the dlib::directed_graph object to + represent bayesian networks. + !*/ + + public: + bayes_node ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + unsigned long value ( + ) const; + /*! + ensures + - returns the current value of this node + !*/ + + void set_value ( + unsigned long new_value + ); + /*! + requires + - new_value < table().num_values() + ensures + - #value() == new_value + !*/ + + conditional_probability_table& table ( + ); + /*! + ensures + - returns a reference to the conditional_probability_table associated with this node + !*/ + + const conditional_probability_table& table ( + ) const; + /*! + ensures + - returns a const reference to the conditional_probability_table associated with this + node. + !*/ + + bool is_evidence ( + ) const; + /*! + ensures + - if (this is an evidence node) then + - returns true + - else + - returns false + !*/ + + void set_as_nonevidence ( + ); + /*! + ensures + - #is_evidence() == false + !*/ + + void set_as_evidence ( + ); + /*! + ensures + - #is_evidence() == true + !*/ + + void swap ( + bayes_node& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + bayes_node& a, + bayes_node& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + void serialize ( + const bayes_node& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + bayes_node& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + /* + The following group of functions are convenience functions for manipulating + bayes_node objects while they are inside a directed_graph. These functions + also have additional requires clauses that, in debug mode, will protect you + from attempts to manipulate a bayesian network in an inappropriate way. + */ + + namespace bayes_node_utils + { + + template < + typename T + > + void set_node_value ( + T& bn, + unsigned long n, + unsigned long val + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - val < node_num_values(bn, n) + ensures + - #bn.node(n).data.value() = val + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + unsigned long node_value ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns bn.node(n).data.value() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + bool node_is_evidence ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns bn.node(n).data.is_evidence() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void set_node_as_evidence ( + T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - executes: bn.node(n).data.set_as_evidence() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void set_node_as_nonevidence ( + T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - executes: bn.node(n).data.set_as_nonevidence() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void set_node_num_values ( + T& bn, + unsigned long n, + unsigned long num + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - #bn.node(n).data.table().num_values() == num + (i.e. sets the number of different values this node can take) + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + unsigned long node_num_values ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns bn.node(n).data.table().num_values() + (i.e. returns the number of different values this node can take) + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + const double node_probability ( + const T& bn, + unsigned long n, + unsigned long value, + const assignment& parents + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - value < node_num_values(bn,n) + - parents.size() == bn.node(n).number_of_parents() + - if (parents.has_index(x)) then + - bn.has_edge(x, n) + - parents[x] < node_num_values(bn,x) + ensures + - returns bn.node(n).data.table().probability(value, parents) + (i.e. returns the probability of node n having the given value when + its parents have the given assignment) + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + const double set_node_probability ( + const T& bn, + unsigned long n, + unsigned long value, + const assignment& parents, + double p + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - value < node_num_values(bn,n) + - 0 <= p <= 1 + - parents.size() == bn.node(n).number_of_parents() + - if (parents.has_index(x)) then + - bn.has_edge(x, n) + - parents[x] < node_num_values(bn,x) + ensures + - #bn.node(n).data.table().probability(value, parents) == p + (i.e. sets the probability of node n having the given value when + its parents have the given assignment to the probability p) + !*/ + + // ------------------------------------------------------------------------------------ + + template + const assignment node_first_parent_assignment ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns an assignment A such that: + - A.size() == bn.node(n).number_of_parents() + - if (P is a parent of bn.node(n)) then + - A.has_index(P) + - A[P] == 0 + - I.e. this function returns an assignment that contains all + the parents of the given node. Also, all the values of each + parent in the assignment is set to zero. + !*/ + + // ------------------------------------------------------------------------------------ + + template + bool node_next_parent_assignment ( + const T& bn, + unsigned long n, + assignment& A + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - A.size() == bn.node(n).number_of_parents() + - if (A.has_index(x)) then + - bn.has_edge(x, n) + - A[x] < node_num_values(bn,x) + ensures + - The behavior of this function is defined by the following code: + assignment a(node_first_parent_assignment(bn,n); + do { + // this loop loops over all possible parent assignments + // of the node bn.node(n). Each time through the loop variable a + // will be the next assignment. + } while (node_next_parent_assignment(bn,n,a)) + !*/ + + // ------------------------------------------------------------------------------------ + + template + bool node_cpt_filled_out ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - if (the conditional_probability_table bn.node(n).data.table() is + fully filled out for this node) then + - returns true + - This means that each parent assignment for the given node + along with all possible values of this node shows up in the + table. + - It also means that all the probabilities conditioned on the + same parent assignment sum to 1.0 + - else + - returns false + !*/ + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class bayesian_network_gibbs_sampler : noncopyable + { + /*! + INITIAL VALUE + This object has no state + + WHAT THIS OBJECT REPRESENTS + This object performs Markov Chain Monte Carlo sampling of a bayesian + network using the Gibbs sampling technique. + + Note that this object is limited to only bayesian networks that + don't contain deterministic nodes. That is, incorrect results may + be computed if this object is used when the bayesian network contains + any nodes that have a probability of 1 in their conditional probability + tables for any event. So don't use this object for networks with + deterministic nodes. + !*/ + public: + + bayesian_network_gibbs_sampler ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template < + typename T + > + void sample_graph ( + T& bn + ) + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + ensures + - modifies randomly (via the Gibbs sampling technique) samples all the nodes + in the network and updates their values with the newly sampled values + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class bayesian_network_join_tree : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an implementation of the join tree algorithm + for inference in bayesian networks. It doesn't have any mutable state. + To you use you just give it a directed_graph that contains a bayesian + network and a graph object that contains that networks corresponding + join tree. Then you may query this object to determine the probabilities + of any variables in the original bayesian network. + !*/ + + public: + + template < + typename bn_type, + typename join_tree_type + > + bayesian_network_join_tree ( + const bn_type& bn, + const join_tree_type& join_tree + ); + /*! + requires + - bn_type is an implementation of directed_graph/directed_graph_kernel_abstract.h + - bn_type::type == bayes_node + - join_tree_type is an implementation of graph/graph_kernel_abstract.h + - join_tree_type::type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - is_join_tree(bn, join_tree) == true + - bn == a valid bayesian network with all its conditional probability tables + filled out + - for all valid n: + - node_cpt_filled_out(bn,n) == true + - graph_contains_length_one_cycle(bn) == false + - graph_is_connected(bn) == true + - bn.number_of_nodes() > 0 + ensures + - this object is properly initialized + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the bayesian network that this + object was instantiated from. + !*/ + + const matrix probability( + unsigned long idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns the probability distribution for the node with index idx that was in the bayesian + network that *this was instantiated from. Let D represent this distribution, then: + - D.nc() == the number of values the node idx ranges over + - D.nr() == 1 + - D(i) == the probability of node idx taking on the value i + !*/ + + void swap ( + bayesian_network_join_tree& item + ); + /*! + ensures + - swaps *this with item + !*/ + + }; + + inline void swap ( + bayesian_network_join_tree& a, + bayesian_network_join_tree& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BAYES_UTILs_ABSTRACT_ + + diff --git a/ml/dlib/dlib/bigint.h b/ml/dlib/dlib/bigint.h new file mode 100644 index 000000000..73496689a --- /dev/null +++ b/ml/dlib/dlib/bigint.h @@ -0,0 +1,43 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINt_ +#define DLIB_BIGINt_ + +#include "bigint/bigint_kernel_1.h" +#include "bigint/bigint_kernel_2.h" +#include "bigint/bigint_kernel_c.h" + + + + +namespace dlib +{ + + + class bigint + { + bigint() {} + + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef bigint_kernel_1 + kernel_1a; + typedef bigint_kernel_c + kernel_1a_c; + + // kernel_2a + typedef bigint_kernel_2 + kernel_2a; + typedef bigint_kernel_c + kernel_2a_c; + + + }; +} + +#endif // DLIB_BIGINt_ + diff --git a/ml/dlib/dlib/bigint/bigint_kernel_1.cpp b/ml/dlib/dlib/bigint/bigint_kernel_1.cpp new file mode 100644 index 000000000..feef761c2 --- /dev/null +++ b/ml/dlib/dlib/bigint/bigint_kernel_1.cpp @@ -0,0 +1,1720 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEL_1_CPp_ +#define DLIB_BIGINT_KERNEL_1_CPp_ +#include "bigint_kernel_1.h" + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member/friend function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + bigint_kernel_1 ( + ) : + slack(25), + data(new data_record(slack)) + {} + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + bigint_kernel_1 ( + uint32 value + ) : + slack(25), + data(new data_record(slack)) + { + *(data->number) = static_cast(value&0xFFFF); + *(data->number+1) = static_cast((value>>16)&0xFFFF); + if (*(data->number+1) != 0) + data->digits_used = 2; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + bigint_kernel_1 ( + const bigint_kernel_1& item + ) : + slack(25), + data(item.data) + { + data->references += 1; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + ~bigint_kernel_1 ( + ) + { + if (data->references == 1) + { + delete data; + } + else + { + data->references -= 1; + } + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator+ ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record ( + std::max(rhs.data->digits_used,data->digits_used) + slack + ); + long_add(data,rhs.data,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator+= ( + const bigint_kernel_1& rhs + ) + { + // if there are other references to our data + if (data->references != 1) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + data->references -= 1; + long_add(data,rhs.data,temp); + data = temp; + } + // if data is not big enough for the result + else if (data->size <= std::max(data->digits_used,rhs.data->digits_used)) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + long_add(data,rhs.data,temp); + delete data; + data = temp; + } + // there is enough size and no references + else + { + long_add(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator- ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + slack + ); + long_sub(data,rhs.data,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator-= ( + const bigint_kernel_1& rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + long_sub(data,rhs.data,temp); + data = temp; + } + else + { + long_sub(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator* ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + rhs.data->digits_used + slack + ); + long_mul(data,rhs.data,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator*= ( + const bigint_kernel_1& rhs + ) + { + // create a data_record to store the result of the multiplication in + data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack); + long_mul(data,rhs.data,temp); + + // if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + else + { + delete data; + } + data = temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator/ ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete remainder; + + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator/= ( + const bigint_kernel_1& rhs + ) + { + + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = temp; + delete remainder; + + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator% ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete temp; + return bigint_kernel_1(remainder,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator%= ( + const bigint_kernel_1& rhs + ) + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = remainder; + delete temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + operator < ( + const bigint_kernel_1& rhs + ) const + { + return is_less_than(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + operator == ( + const bigint_kernel_1& rhs + ) const + { + return is_equal_to(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator= ( + const bigint_kernel_1& rhs + ) + { + if (this == &rhs) + return *this; + + // if we have the only reference to our data then delete it + if (data->references == 1) + { + delete data; + data = rhs.data; + data->references += 1; + } + else + { + data->references -= 1; + data = rhs.data; + data->references += 1; + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out_, + const bigint_kernel_1& rhs + ) + { + std::ostream out(out_.rdbuf()); + + typedef bigint_kernel_1 bigint; + + bigint::data_record* temp = new bigint::data_record(*rhs.data,0); + + + + // get a char array big enough to hold the number in ascii format + char* str; + try { + str = new char[(rhs.data->digits_used)*5+10]; + } catch (...) { delete temp; throw; } + + char* str_start = str; + str += (rhs.data->digits_used)*5+9; + *str = 0; --str; + + + uint16 remainder; + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + + + // keep looping until temp represents zero + while (temp->digits_used != 1 || *(temp->number) != 0) + { + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + } + + // throw away and extra leading zeros + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + + + + + out << str; + delete [] str_start; + delete temp; + return out_; + + } + +// ---------------------------------------------------------------------------------------- + + std::istream& operator>> ( + std::istream& in_, + bigint_kernel_1& rhs + ) + { + std::istream in(in_.rdbuf()); + + // ignore any leading whitespaces + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n') + { + in.get(); + } + + // if the first digit is not an integer then this is an error + if ( !(in.peek() >= '0' && in.peek() <= '9')) + { + in_.clear(std::ios::failbit); + return in_; + } + + int num_read; + bigint_kernel_1 temp; + do + { + + // try to get 4 chars from in + num_read = 1; + char a = 0; + char b = 0; + char c = 0; + char d = 0; + + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + a = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + b = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + c = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + d = in.get(); + } + + // merge the for digits into an uint16 + uint16 num = 0; + if (a != 0) + { + num = a - '0'; + } + if (b != 0) + { + num *= 10; + num += b - '0'; + } + if (c != 0) + { + num *= 10; + num += c - '0'; + } + if (d != 0) + { + num *= 10; + num += d - '0'; + } + + + if (num_read != 1) + { + // shift the digits in temp left by the number of new digits we just read + temp *= num_read; + // add in new digits + temp += num; + } + + } while (num_read == 10000); + + + rhs = temp; + return in_; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator+ ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_add(rhs.data,lhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator+ ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_add(lhs.data,rhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator+= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_add(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_add(data,rhs,temp); + delete data; + data = temp; + } + // or if there is plenty of space and no references + else + { + short_add(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator- ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + *(temp->number) = lhs - *(rhs.data->number); + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator- ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_sub(lhs.data,rhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator-= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_sub(data,rhs,temp); + data = temp; + } + else + { + short_sub(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator* ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_mul(rhs.data,lhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator* ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_mul(lhs.data,rhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator*= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_mul(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_mul(data,rhs,temp); + delete data; + data = temp; + } + else + { + short_mul(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator/ ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + // if rhs might not be bigger than lhs + if (rhs.data->digits_used == 1) + { + *(temp->number) = lhs/ *(rhs.data->number); + } + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator/ ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + uint16 remainder; + lhs.short_div(lhs.data,rhs,temp,remainder); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator/= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator% ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + // temp is zero by default + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + if (rhs.data->digits_used == 1) + { + // if rhs is just an uint16 inside then perform the modulus + *(temp->number) = lhs % *(rhs.data->number); + } + else + { + // if rhs is bigger than lhs then the answer is lhs + *(temp->number) = lhs; + } + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator% ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack); + + uint16 remainder; + + lhs.short_div(lhs.data,rhs,temp,remainder); + temp->digits_used = 1; + *(temp->number) = remainder; + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator%= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + + data->digits_used = 1; + *(data->number) = remainder; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) ); + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator= ( + uint16 rhs + ) + { + // check if there are other references to our data + if (data->references != 1) + { + data->references -= 1; + try { + data = new data_record(slack); + } catch (...) { data->references += 1; throw; } + } + else + { + data->digits_used = 1; + } + + *(data->number) = rhs; + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator++ ( + ) + { + // if there are other references to this data then make a copy of it + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + increment(data,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + increment(data,temp); + delete data; + data = temp; + } + else + { + increment(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator++ ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + increment(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator-- ( + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + decrement(data,temp); + data = temp; + } + else + { + decrement(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator-- ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + decrement(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp = value; + temp <<= 16; + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used; // one past the end of number + uint16* r = result->number; + + while (number != end) + { + // add *number and the current carry + temp = *number + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // store the carry in the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used - 1; + uint16* r = result->number; + + uint32 temp = *number - value; + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + + while (number != end) + { + ++number; + ++r; + + // subtract the carry from *number + temp = *number - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + } + + // if we lost a digit in the subtraction + if (*r == 0) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + uint32 temp = 0; + + + const uint16* number = data->number; + uint16* r = result->number; + const uint16* end = r + data->digits_used; + + + + while ( r != end) + { + + // multiply *data and value and add in the carry + temp = *number*(uint32)value + (temp>>16); + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // put the final carry into the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& rem + ) const + { + + uint16 remainder = 0; + uint32 temp; + + + + const uint16* number = data->number + data->digits_used - 1; + const uint16* end = number - data->digits_used; + uint16* r = result->number + data->digits_used - 1; + + + // if we are losing a digit in this division + if (*number < value) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + // perform the actual division + while (number != end) + { + + temp = *number + (((uint32)remainder)<<16); + + *r = static_cast(temp/value); + remainder = static_cast(temp%value); + + --number; + --r; + } + + rem = remainder; + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp=0; + + uint16* min_num; // the number with the least digits used + uint16* max_num; // the number with the most digits used + uint16* min_end; // one past the end of min_num + uint16* max_end; // one past the end of max_num + uint16* r = result->number; + + uint32 max_digits_used; + if (lhs->digits_used < rhs->digits_used) + { + max_digits_used = rhs->digits_used; + min_num = lhs->number; + max_num = rhs->number; + min_end = min_num + lhs->digits_used; + max_end = max_num + rhs->digits_used; + } + else + { + max_digits_used = lhs->digits_used; + min_num = rhs->number; + max_num = lhs->number; + min_end = min_num + rhs->digits_used; + max_end = max_num + lhs->digits_used; + } + + + + + while (min_num != min_end) + { + // add *min_num, *max_num and the current carry + temp = *min_num + *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++min_num; + ++max_num; + ++r; + } + + + while (max_num != max_end) + { + // add *max_num and the current carry + temp = *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++max_num; + ++r; + } + + // check if there was a final carry + if ((temp>>16) != 0) + { + result->digits_used = max_digits_used + 1; + // put the carry into the most significant digit in the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = max_digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + + + const uint16* number1 = lhs->number; + const uint16* number2 = rhs->number; + const uint16* end = number2 + rhs->digits_used; + uint16* r = result->number; + + + + uint32 temp =0; + + + while (number2 != end) + { + + // subtract *number2 from *number1 and then subtract any carry + temp = *number1 - *number2 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++number2; + ++r; + } + + end = lhs->number + lhs->digits_used; + while (number1 != end) + { + + // subtract the carry from *number1 + temp = *number1 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++r; + } + + result->digits_used = lhs->digits_used; + // adjust the number of digits used appropriately + --r; + while (*r == 0 && result->digits_used > 1) + { + --r; + --result->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const + { + // zero result + result->digits_used = 1; + *(result->number) = 0; + + uint16* a; + uint16* b; + uint16* end; + + // copy lhs into remainder + remainder->digits_used = lhs->digits_used; + a = remainder->number; + end = a + remainder->digits_used; + b = lhs->number; + while (a != end) + { + *a = *b; + ++a; + ++b; + } + + + // if rhs is bigger than lhs then result == 0 and remainder == lhs + // so then we can quit right now + if (is_less_than(lhs,rhs)) + { + return; + } + + + // make a temporary number + data_record temp(lhs->digits_used + slack); + + + // shift rhs left until it is one shift away from being larger than lhs and + // put the number of left shifts necessary into shifts + uint32 shifts; + shifts = (lhs->digits_used - rhs->digits_used) * 16; + + shift_left(rhs,&temp,shifts); + + + // while (lhs > temp) + while (is_less_than(&temp,lhs)) + { + shift_left(&temp,&temp,1); + ++shifts; + } + // make sure lhs isn't smaller than temp + while (is_less_than(lhs,&temp)) + { + shift_right(&temp,&temp); + --shifts; + } + + + + // we want to execute the loop shifts +1 times + ++shifts; + while (shifts != 0) + { + shift_left(result,result,1); + // if (temp <= remainder) + if (!is_less_than(remainder,&temp)) + { + long_sub(remainder,&temp,remainder); + + // increment result + uint16* r = result->number; + uint16* end = r + result->digits_used; + while (true) + { + ++(*r); + // if there was no carry then we are done + if (*r != 0) + break; + + ++r; + + // if we hit the end of r and there is still a carry then + // the next digit of r is 1 and there is one more digit used + if (r == end) + { + *r = 1; + ++(result->digits_used); + break; + } + } + } + shift_right(&temp,&temp); + --shifts; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // make result be zero + result->digits_used = 1; + *(result->number) = 0; + + + const data_record* aa; + const data_record* bb; + + if (lhs->digits_used < rhs->digits_used) + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = lhs; + bb = rhs; + } + else + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = rhs; + bb = lhs; + } + // this is where we actually copy lhs and rhs + data_record b(*bb,aa->digits_used+slack); // the larger(approximately) of lhs and rhs + + + uint32 shift_value = 0; + uint16* anum = aa->number; + uint16* end = anum + aa->digits_used; + while (anum != end ) + { + uint16 bit = 0x0001; + + for (int i = 0; i < 16; ++i) + { + // if the specified bit of a is 1 + if ((*anum & bit) != 0) + { + shift_left(&b,&b,shift_value); + shift_value = 0; + long_add(&b,result,result); + } + ++shift_value; + bit <<= 1; + } + + ++anum; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const + { + uint32 offset = shift_amount/16; + shift_amount &= 0xf; // same as shift_amount %= 16; + + uint16* r = result->number + data->digits_used + offset; // result + uint16* end = data->number; + uint16* s = end + data->digits_used; // source + const uint32 temp = 16 - shift_amount; + + *r = (*(--s) >> temp); + // set the number of digits used in the result + // if the upper bits from *s were zero then don't count this first word + if (*r == 0) + { + result->digits_used = data->digits_used + offset; + } + else + { + result->digits_used = data->digits_used + offset + 1; + } + --r; + + while (s != end) + { + *r = ((*s << shift_amount) | ( *(s-1) >> temp)); + --r; + --s; + } + *r = *s << shift_amount; + + // now zero the rest of the result + end = result->number; + while (r != end) + *(--r) = 0; + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + shift_right ( + const data_record* data, + data_record* result + ) const + { + + uint16* r = result->number; // result + uint16* s = data->number; // source + uint16* end = s + data->digits_used - 1; + + while (s != end) + { + *r = (*s >> 1) | (*(s+1) << 15); + ++r; + ++s; + } + *r = *s >> 1; + + + // calculate the new number for digits_used + if (*r == 0) + { + if (data->digits_used != 1) + result->digits_used = data->digits_used - 1; + else + result->digits_used = 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const + { + uint32 lhs_digits_used = lhs->digits_used; + uint32 rhs_digits_used = rhs->digits_used; + + // if lhs is definitely less than rhs + if (lhs_digits_used < rhs_digits_used ) + return true; + // if lhs is definitely greater than rhs + else if (lhs_digits_used > rhs_digits_used) + return false; + else + { + uint16* end = lhs->number; + uint16* l = end + lhs_digits_used; + uint16* r = rhs->number + rhs_digits_used; + + while (l != end) + { + --l; + --r; + if (*l < *r) + return true; + else if (*l > *r) + return false; + } + + // at this point we know that they are equal + return false; + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const + { + // if lhs and rhs are definitely not equal + if (lhs->digits_used != rhs->digits_used ) + { + return false; + } + else + { + uint16* l = lhs->number; + uint16* r = rhs->number; + uint16* end = l + lhs->digits_used; + + while (l != end) + { + if (*l != *r) + return false; + ++l; + ++r; + } + + // at this point we know that they are equal + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + increment ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s + 1; + + // if there was no carry then break out of the loop + if (*d != 0) + { + dest->digits_used = source->digits_used; + + // copy the rest of the digits over to d + ++d; ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + + break; + } + + + ++s; + + // if we have hit the end of s and there was a carry up to this point + // then just make the next digit 1 and add one to the digits used + if (s == end) + { + ++d; + dest->digits_used = source->digits_used + 1; + *d = 1; + break; + } + + ++d; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + decrement ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s - 1; + + // if there was no carry then break out of the loop + if (*d != 0xFFFF) + { + // if we lost a digit in the subtraction + if (*d == 0 && s+1 == end) + { + if (source->digits_used == 1) + dest->digits_used = 1; + else + dest->digits_used = source->digits_used - 1; + } + else + { + dest->digits_used = source->digits_used; + } + break; + } + else + { + ++d; + ++s; + } + + } + + // copy the rest of the digits over to d + ++d; + ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_BIGINT_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/bigint/bigint_kernel_1.h b/ml/dlib/dlib/bigint/bigint_kernel_1.h new file mode 100644 index 000000000..3e7f3d851 --- /dev/null +++ b/ml/dlib/dlib/bigint/bigint_kernel_1.h @@ -0,0 +1,544 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEl_1_ +#define DLIB_BIGINT_KERNEl_1_ + +#include "bigint_kernel_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "../uintn.h" +#include + +namespace dlib +{ + + + class bigint_kernel_1 + { + /*! + INITIAL VALUE + slack == 25 + data->number[0] == 0 + data->size == slack + data->references == 1 + data->digits_used == 1 + + + CONVENTION + slack == the number of extra digits placed into the number when it is + created. the slack value should never be less than 1 + + data->number == pointer to an array of data->size uint16s. + data represents a string of base 65535 numbers with data[0] being + the least significant bit and data[data->digits_used-1] being the most + significant + + + NOTE: In the comments I will consider a word to be a 16 bit value + + + data->digits_used == the number of significant digits in the number. + data->digits_used tells us the number of used elements in the + data->number array so everything beyond data->number[data->digits_used-1] + is undefined + + data->references == the number of bigint_kernel_1 objects which refer + to this data_record + + + + !*/ + + + struct data_record + { + + + explicit data_record( + uint32 size_ + ) : + size(size_), + number(new uint16[size_]), + references(1), + digits_used(1) + {*number = 0;} + /*! + ensures + - initializes *this to represent zero + !*/ + + data_record( + const data_record& item, + uint32 additional_size + ) : + size(item.digits_used + additional_size), + number(new uint16[size]), + references(1), + digits_used(item.digits_used) + { + uint16* source = item.number; + uint16* dest = number; + uint16* end = source + digits_used; + while (source != end) + { + *dest = *source; + ++dest; + ++source; + } + } + /*! + ensures + - *this is a copy of item except with + size == item.digits_used + additional_size + !*/ + + ~data_record( + ) + { + delete [] number; + } + + + const uint32 size; + uint16* number; + uint32 references; + uint32 digits_used; + + private: + // no copy constructor + data_record ( data_record&); + }; + + + + // note that the second parameter is just there + // to resolve the ambiguity between this constructor and + // bigint_kernel_1(uint32) + explicit bigint_kernel_1 ( + data_record* data_, int + ): slack(25),data(data_) {} + /*! + ensures + - *this is initialized with data_ as its data member + !*/ + + + public: + + bigint_kernel_1 ( + ); + + bigint_kernel_1 ( + uint32 value + ); + + bigint_kernel_1 ( + const bigint_kernel_1& item + ); + + virtual ~bigint_kernel_1 ( + ); + + const bigint_kernel_1 operator+ ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator+= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator- ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator-= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator* ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator*= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator/ ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator/= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator% ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator%= ( + const bigint_kernel_1& rhs + ); + + bool operator < ( + const bigint_kernel_1& rhs + ) const; + + bool operator == ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator= ( + const bigint_kernel_1& rhs + ); + + friend std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_1& rhs + ); + + friend std::istream& operator>> ( + std::istream& in, + bigint_kernel_1& rhs + ); + + bigint_kernel_1& operator++ ( + ); + + const bigint_kernel_1 operator++ ( + int + ); + + bigint_kernel_1& operator-- ( + ); + + const bigint_kernel_1 operator-- ( + int + ); + + friend const bigint_kernel_1 operator+ ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator+ ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator+= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator- ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator- ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator-= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator* ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator* ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator*= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator/ ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator/ ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator/= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator% ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator% ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator%= ( + uint16 rhs + ); + + friend bool operator < ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend bool operator < ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + friend bool operator == ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + friend bool operator == ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + bigint_kernel_1& operator= ( + uint16 rhs + ); + + + void swap ( + bigint_kernel_1& item + ) { data_record* temp = data; data = item.data; item.data = temp; } + + + private: + + void long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= max(lhs->digits_used,rhs->digits_used) + 1 + ensures + - result == lhs + rhs + !*/ + + void long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - lhs >= rhs + - result->size >= lhs->digits_used + ensures + - result == lhs - rhs + !*/ + + void long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const; + /*! + requires + - rhs != 0 + - result->size >= lhs->digits_used + - remainder->size >= lhs->digits_used + - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.) + ensures + - result == lhs / rhs + - remainder == lhs % rhs + !*/ + + void long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= lhs->digits_used + rhs->digits_used + - result != lhs + - result != rhs + ensures + - result == lhs * rhs + !*/ + + void short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->size + 1 + ensures + - result == data + value + !*/ + + void short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - data >= value + - result->size >= data->digits_used + ensures + - result == data - value + !*/ + + void short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + 1 + ensures + - result == data * value + !*/ + + void short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& remainder + ) const; + /*! + requires + - value != 0 + - result->size >= data->digits_used + ensures + - result = data*value + - remainder = data%value + !*/ + + void shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const; + /*! + requires + - result->size >= data->digits_used + shift_amount/8 + 1 + ensures + - result == data << shift_amount + !*/ + + void shift_right ( + const data_record* data, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + ensures + - result == data >> 1 + !*/ + + bool is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs < rhs + - returns false otherwise + !*/ + + bool is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs == rhs + - returns false otherwise + !*/ + + void increment ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + - dest->size >= source->digits_used + 1 + ensures + - dest = source + 1 + !*/ + + void decrement ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + source != 0 + ensuers + dest = source - 1 + !*/ + + // member data + const uint32 slack; + data_record* data; + + + + }; + + inline void swap ( + bigint_kernel_1& a, + bigint_kernel_1& b + ) { a.swap(b); } + + inline void serialize ( + const bigint_kernel_1& item, + std::ostream& out + ) + { + std::ios::fmtflags oldflags = out.flags(); + out.flags(); + out << item << ' '; + out.flags(oldflags); + if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); + } + + inline void deserialize ( + bigint_kernel_1& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in.flags(); + in >> item; in.flags(oldflags); + if (in.get() != ' ') + { + item = 0; + throw serialization_error("Error deserializing object of type bigint_kernel_c"); + } + } + + inline bool operator> (const bigint_kernel_1& a, const bigint_kernel_1& b) { return b < a; } + inline bool operator!= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a == b); } + inline bool operator<= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(b < a); } + inline bool operator>= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a < b); } +} + +#ifdef NO_MAKEFILE +#include "bigint_kernel_1.cpp" +#endif + +#endif // DLIB_BIGINT_KERNEl_1_ + diff --git a/ml/dlib/dlib/bigint/bigint_kernel_2.cpp b/ml/dlib/dlib/bigint/bigint_kernel_2.cpp new file mode 100644 index 000000000..005e080af --- /dev/null +++ b/ml/dlib/dlib/bigint/bigint_kernel_2.cpp @@ -0,0 +1,1945 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEL_2_CPp_ +#define DLIB_BIGINT_KERNEL_2_CPp_ +#include "bigint_kernel_2.h" + +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member/friend function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + bigint_kernel_2 ( + ) : + slack(25), + data(new data_record(slack)) + {} + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + bigint_kernel_2 ( + uint32 value + ) : + slack(25), + data(new data_record(slack)) + { + *(data->number) = static_cast(value&0xFFFF); + *(data->number+1) = static_cast((value>>16)&0xFFFF); + if (*(data->number+1) != 0) + data->digits_used = 2; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + bigint_kernel_2 ( + const bigint_kernel_2& item + ) : + slack(25), + data(item.data) + { + data->references += 1; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + ~bigint_kernel_2 ( + ) + { + if (data->references == 1) + { + delete data; + } + else + { + data->references -= 1; + } + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator+ ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record ( + std::max(rhs.data->digits_used,data->digits_used) + slack + ); + long_add(data,rhs.data,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator+= ( + const bigint_kernel_2& rhs + ) + { + // if there are other references to our data + if (data->references != 1) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + data->references -= 1; + long_add(data,rhs.data,temp); + data = temp; + } + // if data is not big enough for the result + else if (data->size <= std::max(data->digits_used,rhs.data->digits_used)) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + long_add(data,rhs.data,temp); + delete data; + data = temp; + } + // there is enough size and no references + else + { + long_add(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator- ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + slack + ); + long_sub(data,rhs.data,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator-= ( + const bigint_kernel_2& rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + long_sub(data,rhs.data,temp); + data = temp; + } + else + { + long_sub(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator* ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + rhs.data->digits_used + slack + ); + long_mul(data,rhs.data,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator*= ( + const bigint_kernel_2& rhs + ) + { + // create a data_record to store the result of the multiplication in + data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack); + long_mul(data,rhs.data,temp); + + // if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + else + { + delete data; + } + data = temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator/ ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete remainder; + + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator/= ( + const bigint_kernel_2& rhs + ) + { + + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = temp; + delete remainder; + + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator% ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete temp; + return bigint_kernel_2(remainder,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator%= ( + const bigint_kernel_2& rhs + ) + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = remainder; + delete temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + operator < ( + const bigint_kernel_2& rhs + ) const + { + return is_less_than(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + operator == ( + const bigint_kernel_2& rhs + ) const + { + return is_equal_to(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator= ( + const bigint_kernel_2& rhs + ) + { + if (this == &rhs) + return *this; + + // if we have the only reference to our data then delete it + if (data->references == 1) + { + delete data; + data = rhs.data; + data->references += 1; + } + else + { + data->references -= 1; + data = rhs.data; + data->references += 1; + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out_, + const bigint_kernel_2& rhs + ) + { + std::ostream out(out_.rdbuf()); + + typedef bigint_kernel_2 bigint; + + bigint::data_record* temp = new bigint::data_record(*rhs.data,0); + + + + // get a char array big enough to hold the number in ascii format + char* str; + try { + str = new char[(rhs.data->digits_used)*5+10]; + } catch (...) { delete temp; throw; } + + char* str_start = str; + str += (rhs.data->digits_used)*5+9; + *str = 0; --str; + + + uint16 remainder; + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + + + // keep looping until temp represents zero + while (temp->digits_used != 1 || *(temp->number) != 0) + { + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + } + + // throw away and extra leading zeros + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + + + + + out << str; + delete [] str_start; + delete temp; + return out_; + + } + +// ---------------------------------------------------------------------------------------- + + std::istream& operator>> ( + std::istream& in_, + bigint_kernel_2& rhs + ) + { + std::istream in(in_.rdbuf()); + + // ignore any leading whitespaces + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n') + { + in.get(); + } + + // if the first digit is not an integer then this is an error + if ( !(in.peek() >= '0' && in.peek() <= '9')) + { + in_.clear(std::ios::failbit); + return in_; + } + + int num_read; + bigint_kernel_2 temp; + do + { + + // try to get 4 chars from in + num_read = 1; + char a = 0; + char b = 0; + char c = 0; + char d = 0; + + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + a = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + b = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + c = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + d = in.get(); + } + + // merge the for digits into an uint16 + uint16 num = 0; + if (a != 0) + { + num = a - '0'; + } + if (b != 0) + { + num *= 10; + num += b - '0'; + } + if (c != 0) + { + num *= 10; + num += c - '0'; + } + if (d != 0) + { + num *= 10; + num += d - '0'; + } + + + if (num_read != 1) + { + // shift the digits in temp left by the number of new digits we just read + temp *= num_read; + // add in new digits + temp += num; + } + + } while (num_read == 10000); + + + rhs = temp; + return in_; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator+ ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_add(rhs.data,lhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator+ ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_add(lhs.data,rhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator+= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_add(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_add(data,rhs,temp); + delete data; + data = temp; + } + // or if there is plenty of space and no references + else + { + short_add(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator- ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + *(temp->number) = lhs - *(rhs.data->number); + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator- ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_sub(lhs.data,rhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator-= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_sub(data,rhs,temp); + data = temp; + } + else + { + short_sub(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator* ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_mul(rhs.data,lhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator* ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_mul(lhs.data,rhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator*= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_mul(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_mul(data,rhs,temp); + delete data; + data = temp; + } + else + { + short_mul(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator/ ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + // if rhs might not be bigger than lhs + if (rhs.data->digits_used == 1) + { + *(temp->number) = lhs/ *(rhs.data->number); + } + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator/ ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + uint16 remainder; + lhs.short_div(lhs.data,rhs,temp,remainder); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator/= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator% ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + // temp is zero by default + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + if (rhs.data->digits_used == 1) + { + // if rhs is just an uint16 inside then perform the modulus + *(temp->number) = lhs % *(rhs.data->number); + } + else + { + // if rhs is bigger than lhs then the answer is lhs + *(temp->number) = lhs; + } + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator% ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack); + + uint16 remainder; + + lhs.short_div(lhs.data,rhs,temp,remainder); + temp->digits_used = 1; + *(temp->number) = remainder; + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator%= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + + data->digits_used = 1; + *(data->number) = remainder; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) ); + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator= ( + uint16 rhs + ) + { + // check if there are other references to our data + if (data->references != 1) + { + data->references -= 1; + try { + data = new data_record(slack); + } catch (...) { data->references += 1; throw; } + } + else + { + data->digits_used = 1; + } + + *(data->number) = rhs; + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator++ ( + ) + { + // if there are other references to this data then make a copy of it + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + increment(data,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + increment(data,temp); + delete data; + data = temp; + } + else + { + increment(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator++ ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + increment(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator-- ( + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + decrement(data,temp); + data = temp; + } + else + { + decrement(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator-- ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + decrement(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp = value; + temp <<= 16; + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used; // one past the end of number + uint16* r = result->number; + + while (number != end) + { + // add *number and the current carry + temp = *number + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // store the carry in the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used - 1; + uint16* r = result->number; + + uint32 temp = *number - value; + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + + while (number != end) + { + ++number; + ++r; + + // subtract the carry from *number + temp = *number - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + } + + // if we lost a digit in the subtraction + if (*r == 0) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + uint32 temp = 0; + + + const uint16* number = data->number; + uint16* r = result->number; + const uint16* end = r + data->digits_used; + + + + while ( r != end) + { + + // multiply *data and value and add in the carry + temp = *number*(uint32)value + (temp>>16); + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // put the final carry into the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& rem + ) const + { + + uint16 remainder = 0; + uint32 temp; + + + + const uint16* number = data->number + data->digits_used - 1; + const uint16* end = number - data->digits_used; + uint16* r = result->number + data->digits_used - 1; + + + // if we are losing a digit in this division + if (*number < value) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + // perform the actual division + while (number != end) + { + + temp = *number + (((uint32)remainder)<<16); + + *r = static_cast(temp/value); + remainder = static_cast(temp%value); + + --number; + --r; + } + + rem = remainder; + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp=0; + + uint16* min_num; // the number with the least digits used + uint16* max_num; // the number with the most digits used + uint16* min_end; // one past the end of min_num + uint16* max_end; // one past the end of max_num + uint16* r = result->number; + + uint32 max_digits_used; + if (lhs->digits_used < rhs->digits_used) + { + max_digits_used = rhs->digits_used; + min_num = lhs->number; + max_num = rhs->number; + min_end = min_num + lhs->digits_used; + max_end = max_num + rhs->digits_used; + } + else + { + max_digits_used = lhs->digits_used; + min_num = rhs->number; + max_num = lhs->number; + min_end = min_num + rhs->digits_used; + max_end = max_num + lhs->digits_used; + } + + + + + while (min_num != min_end) + { + // add *min_num, *max_num and the current carry + temp = *min_num + *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++min_num; + ++max_num; + ++r; + } + + + while (max_num != max_end) + { + // add *max_num and the current carry + temp = *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++max_num; + ++r; + } + + // check if there was a final carry + if ((temp>>16) != 0) + { + result->digits_used = max_digits_used + 1; + // put the carry into the most significant digit in the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = max_digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + + + const uint16* number1 = lhs->number; + const uint16* number2 = rhs->number; + const uint16* end = number2 + rhs->digits_used; + uint16* r = result->number; + + + + uint32 temp =0; + + + while (number2 != end) + { + + // subtract *number2 from *number1 and then subtract any carry + temp = *number1 - *number2 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++number2; + ++r; + } + + end = lhs->number + lhs->digits_used; + while (number1 != end) + { + + // subtract the carry from *number1 + temp = *number1 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++r; + } + + result->digits_used = lhs->digits_used; + // adjust the number of digits used appropriately + --r; + while (*r == 0 && result->digits_used > 1) + { + --r; + --result->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const + { + // zero result + result->digits_used = 1; + *(result->number) = 0; + + uint16* a; + uint16* b; + uint16* end; + + // copy lhs into remainder + remainder->digits_used = lhs->digits_used; + a = remainder->number; + end = a + remainder->digits_used; + b = lhs->number; + while (a != end) + { + *a = *b; + ++a; + ++b; + } + + + // if rhs is bigger than lhs then result == 0 and remainder == lhs + // so then we can quit right now + if (is_less_than(lhs,rhs)) + { + return; + } + + + // make a temporary number + data_record temp(lhs->digits_used + slack); + + + // shift rhs left until it is one shift away from being larger than lhs and + // put the number of left shifts necessary into shifts + uint32 shifts; + shifts = (lhs->digits_used - rhs->digits_used) * 16; + + shift_left(rhs,&temp,shifts); + + + // while (lhs > temp) + while (is_less_than(&temp,lhs)) + { + shift_left(&temp,&temp,1); + ++shifts; + } + // make sure lhs isn't smaller than temp + while (is_less_than(lhs,&temp)) + { + shift_right(&temp,&temp); + --shifts; + } + + + + // we want to execute the loop shifts +1 times + ++shifts; + while (shifts != 0) + { + shift_left(result,result,1); + // if (temp <= remainder) + if (!is_less_than(remainder,&temp)) + { + long_sub(remainder,&temp,remainder); + + // increment result + uint16* r = result->number; + uint16* end = r + result->digits_used; + while (true) + { + ++(*r); + // if there was no carry then we are done + if (*r != 0) + break; + + ++r; + + // if we hit the end of r and there is still a carry then + // the next digit of r is 1 and there is one more digit used + if (r == end) + { + *r = 1; + ++(result->digits_used); + break; + } + } + } + shift_right(&temp,&temp); + --shifts; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // if one of the numbers is small then use this simple but O(n^2) algorithm + if (std::min(lhs->digits_used, rhs->digits_used) < 10) + { + // make result be zero + result->digits_used = 1; + *(result->number) = 0; + + + const data_record* aa; + const data_record* bb; + + if (lhs->digits_used < rhs->digits_used) + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = lhs; + bb = rhs; + } + else + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = rhs; + bb = lhs; + } + + // copy the larger(approximately) of lhs and rhs into b + data_record b(*bb,aa->digits_used+slack); + + + uint32 shift_value = 0; + uint16* anum = aa->number; + uint16* end = anum + aa->digits_used; + while (anum != end ) + { + uint16 bit = 0x0001; + + for (int i = 0; i < 16; ++i) + { + // if the specified bit of a is 1 + if ((*anum & bit) != 0) + { + shift_left(&b,&b,shift_value); + shift_value = 0; + long_add(&b,result,result); + } + ++shift_value; + bit <<= 1; + } + + ++anum; + } + } + else // else if both lhs and rhs are large then use the more complex + // O(n*logn) algorithm + { + uint32 size = 1; + // make size a power of 2 + while (size < (lhs->digits_used + rhs->digits_used)*2) + { + size *= 2; + } + + // allocate some temporary space so we can do the FFT + ct* a = new ct[size]; + ct* b; try {b = new ct[size]; } catch (...) { delete [] a; throw; } + + // load lhs into the a array. We are breaking the input number into + // 8bit chunks for the purpose of using this fft algorithm. The reason + // for this is so that we have smaller numbers coming out of the final + // ifft. This helps avoid overflow. + for (uint32 i = 0; i < lhs->digits_used; ++i) + { + a[i*2] = ct((t)(lhs->number[i]&0xFF),0); + a[i*2+1] = ct((t)(lhs->number[i]>>8),0); + } + for (uint32 i = lhs->digits_used*2; i < size; ++i) + { + a[i] = 0; + } + + // load rhs into the b array + for (uint32 i = 0; i < rhs->digits_used; ++i) + { + b[i*2] = ct((t)(rhs->number[i]&0xFF),0); + b[i*2+1] = ct((t)(rhs->number[i]>>8),0); + } + for (uint32 i = rhs->digits_used*2; i < size; ++i) + { + b[i] = 0; + } + + // perform the forward fft of a and b + fft(a,size); + fft(b,size); + + const double l = 1.0/size; + + // do the pointwise multiply of a and b and also apply the scale + // factor in this loop too. + for (unsigned long i = 0; i < size; ++i) + { + a[i] = l*a[i]*b[i]; + } + + // Now compute the inverse fft of the pointwise multiplication of a and b. + // This is basically the result. We just have to take care of any carries + // that should happen. + ifft(a,size); + + // loop over the result and propagate any carries that need to take place. + // We will also be moving the resulting numbers into result->number at + // the same time. + uint64 carry = 0; + result->digits_used = 0; + int zeros = 0; + const uint32 len = lhs->digits_used + rhs->digits_used; + for (unsigned long i = 0; i < len; ++i) + { + uint64 num1 = static_cast(std::floor(a[i*2].real()+0.5)); + num1 += carry; + carry = 0; + if (num1 > 255) + { + carry = num1 >> 8; + num1 = (num1&0xFF); + } + + uint64 num2 = static_cast(std::floor(a[i*2+1].real()+0.5)); + num2 += carry; + carry = 0; + if (num2 > 255) + { + carry = num2 >> 8; + num2 = (num2&0xFF); + } + + // put the new number into its final place + num1 = (num2<<8) | num1; + result->number[i] = static_cast(num1); + + // keep track of the number of leading zeros + if (num1 == 0) + ++zeros; + else + zeros = 0; + ++(result->digits_used); + } + + // adjust digits_used so that it reflects the actual number + // of non-zero digits in our representation. + result->digits_used -= zeros; + + // if the result was zero then adjust the result accordingly + if (result->digits_used == 0) + { + // make result be zero + result->digits_used = 1; + *(result->number) = 0; + } + + // free all the temporary buffers + delete [] a; + delete [] b; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const + { + uint32 offset = shift_amount/16; + shift_amount &= 0xf; // same as shift_amount %= 16; + + uint16* r = result->number + data->digits_used + offset; // result + uint16* end = data->number; + uint16* s = end + data->digits_used; // source + const uint32 temp = 16 - shift_amount; + + *r = (*(--s) >> temp); + // set the number of digits used in the result + // if the upper bits from *s were zero then don't count this first word + if (*r == 0) + { + result->digits_used = data->digits_used + offset; + } + else + { + result->digits_used = data->digits_used + offset + 1; + } + --r; + + while (s != end) + { + *r = ((*s << shift_amount) | ( *(s-1) >> temp)); + --r; + --s; + } + *r = *s << shift_amount; + + // now zero the rest of the result + end = result->number; + while (r != end) + *(--r) = 0; + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + shift_right ( + const data_record* data, + data_record* result + ) const + { + + uint16* r = result->number; // result + uint16* s = data->number; // source + uint16* end = s + data->digits_used - 1; + + while (s != end) + { + *r = (*s >> 1) | (*(s+1) << 15); + ++r; + ++s; + } + *r = *s >> 1; + + + // calculate the new number for digits_used + if (*r == 0) + { + if (data->digits_used != 1) + result->digits_used = data->digits_used - 1; + else + result->digits_used = 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const + { + uint32 lhs_digits_used = lhs->digits_used; + uint32 rhs_digits_used = rhs->digits_used; + + // if lhs is definitely less than rhs + if (lhs_digits_used < rhs_digits_used ) + return true; + // if lhs is definitely greater than rhs + else if (lhs_digits_used > rhs_digits_used) + return false; + else + { + uint16* end = lhs->number; + uint16* l = end + lhs_digits_used; + uint16* r = rhs->number + rhs_digits_used; + + while (l != end) + { + --l; + --r; + if (*l < *r) + return true; + else if (*l > *r) + return false; + } + + // at this point we know that they are equal + return false; + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const + { + // if lhs and rhs are definitely not equal + if (lhs->digits_used != rhs->digits_used ) + { + return false; + } + else + { + uint16* l = lhs->number; + uint16* r = rhs->number; + uint16* end = l + lhs->digits_used; + + while (l != end) + { + if (*l != *r) + return false; + ++l; + ++r; + } + + // at this point we know that they are equal + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + increment ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s + 1; + + // if there was no carry then break out of the loop + if (*d != 0) + { + dest->digits_used = source->digits_used; + + // copy the rest of the digits over to d + ++d; ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + + break; + } + + + ++s; + + // if we have hit the end of s and there was a carry up to this point + // then just make the next digit 1 and add one to the digits used + if (s == end) + { + ++d; + dest->digits_used = source->digits_used + 1; + *d = 1; + break; + } + + ++d; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + decrement ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s - 1; + + // if there was no carry then break out of the loop + if (*d != 0xFFFF) + { + // if we lost a digit in the subtraction + if (*d == 0 && s+1 == end) + { + if (source->digits_used == 1) + dest->digits_used = 1; + else + dest->digits_used = source->digits_used - 1; + } + else + { + dest->digits_used = source->digits_used; + } + break; + } + else + { + ++d; + ++s; + } + + } + + // copy the rest of the digits over to d + ++d; + ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + fft ( + ct* data, + unsigned long len + ) const + { + const t pi2 = -2.0*3.1415926535897932384626433832795028841971693993751; + + const unsigned long half = len/2; + + std::vector twiddle_factors; + twiddle_factors.resize(half); + + // compute the complex root of unity w + const t temp = pi2/len; + ct w = ct(std::cos(temp),std::sin(temp)); + + ct w_pow = 1; + + // compute the twiddle factors + for (std::vector::size_type j = 0; j < twiddle_factors.size(); ++j) + { + twiddle_factors[j] = w_pow; + w_pow *= w; + } + + ct a, b; + + // now compute the decimation in frequency. This first + // outer loop loops log2(len) number of times + unsigned long skip = 1; + for (unsigned long step = half; step != 0; step >>= 1) + { + // do blocks of butterflies in this loop + for (unsigned long j = 0; j < len; j += step*2) + { + // do step butterflies + for (unsigned long k = 0; k < step; ++k) + { + const unsigned long a_idx = j+k; + const unsigned long b_idx = j+k+step; + a = data[a_idx] + data[b_idx]; + b = (data[a_idx] - data[b_idx])*twiddle_factors[k*skip]; + data[a_idx] = a; + data[b_idx] = b; + } + } + skip *= 2; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + ifft( + ct* data, + unsigned long len + ) const + { + const t pi2 = 2.0*3.1415926535897932384626433832795028841971693993751; + + const unsigned long half = len/2; + + std::vector twiddle_factors; + twiddle_factors.resize(half); + + // compute the complex root of unity w + const t temp = pi2/len; + ct w = ct(std::cos(temp),std::sin(temp)); + + ct w_pow = 1; + + // compute the twiddle factors + for (std::vector::size_type j = 0; j < twiddle_factors.size(); ++j) + { + twiddle_factors[j] = w_pow; + w_pow *= w; + } + + ct a, b; + + // now compute the inverse decimation in frequency. This first + // outer loop loops log2(len) number of times + unsigned long skip = half; + for (unsigned long step = 1; step <= half; step <<= 1) + { + // do blocks of butterflies in this loop + for (unsigned long j = 0; j < len; j += step*2) + { + // do step butterflies + for (unsigned long k = 0; k < step; ++k) + { + const unsigned long a_idx = j+k; + const unsigned long b_idx = j+k+step; + data[b_idx] *= twiddle_factors[k*skip]; + a = data[a_idx] + data[b_idx]; + b = data[a_idx] - data[b_idx]; + data[a_idx] = a; + data[b_idx] = b; + } + } + skip /= 2; + } + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_BIGINT_KERNEL_2_CPp_ + diff --git a/ml/dlib/dlib/bigint/bigint_kernel_2.h b/ml/dlib/dlib/bigint/bigint_kernel_2.h new file mode 100644 index 000000000..cbd8f895d --- /dev/null +++ b/ml/dlib/dlib/bigint/bigint_kernel_2.h @@ -0,0 +1,570 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEl_2_ +#define DLIB_BIGINT_KERNEl_2_ + +#include "bigint_kernel_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "../uintn.h" +#include +#include +#include +#include + +namespace dlib +{ + + class bigint_kernel_2 + { + /*! + INITIAL VALUE + slack == 25 + data->number[0] == 0 + data->size == slack + data->references == 1 + data->digits_used == 1 + + + CONVENTION + slack == the number of extra digits placed into the number when it is + created. the slack value should never be less than 1 + + data->number == pointer to an array of data->size uint16s. + data represents a string of base 65535 numbers with data[0] being + the least significant bit and data[data->digits_used-1] being the most + significant + + + NOTE: In the comments I will consider a word to be a 16 bit value + + + data->digits_used == the number of significant digits in the number. + data->digits_used tells us the number of used elements in the + data->number array so everything beyond data->number[data->digits_used-1] + is undefined + + data->references == the number of bigint_kernel_2 objects which refer + to this data_record + !*/ + + + struct data_record + { + + + explicit data_record( + uint32 size_ + ) : + size(size_), + number(new uint16[size_]), + references(1), + digits_used(1) + {*number = 0;} + /*! + ensures + - initializes *this to represent zero + !*/ + + data_record( + const data_record& item, + uint32 additional_size + ) : + size(item.digits_used + additional_size), + number(new uint16[size]), + references(1), + digits_used(item.digits_used) + { + uint16* source = item.number; + uint16* dest = number; + uint16* end = source + digits_used; + while (source != end) + { + *dest = *source; + ++dest; + ++source; + } + } + /*! + ensures + - *this is a copy of item except with + size == item.digits_used + additional_size + !*/ + + ~data_record( + ) + { + delete [] number; + } + + + const uint32 size; + uint16* number; + uint32 references; + uint32 digits_used; + + private: + // no copy constructor + data_record ( data_record&); + }; + + + // note that the second parameter is just there + // to resolve the ambiguity between this constructor and + // bigint_kernel_2(uint32) + explicit bigint_kernel_2 ( + data_record* data_, int + ): slack(25),data(data_) {} + /*! + ensures + - *this is initialized with data_ as its data member + !*/ + + public: + + bigint_kernel_2 ( + ); + + bigint_kernel_2 ( + uint32 value + ); + + bigint_kernel_2 ( + const bigint_kernel_2& item + ); + + virtual ~bigint_kernel_2 ( + ); + + const bigint_kernel_2 operator+ ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator+= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator- ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator-= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator* ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator*= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator/ ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator/= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator% ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator%= ( + const bigint_kernel_2& rhs + ); + + bool operator < ( + const bigint_kernel_2& rhs + ) const; + + bool operator == ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator= ( + const bigint_kernel_2& rhs + ); + + friend std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_2& rhs + ); + + friend std::istream& operator>> ( + std::istream& in, + bigint_kernel_2& rhs + ); + + bigint_kernel_2& operator++ ( + ); + + const bigint_kernel_2 operator++ ( + int + ); + + bigint_kernel_2& operator-- ( + ); + + const bigint_kernel_2 operator-- ( + int + ); + + friend const bigint_kernel_2 operator+ ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator+ ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator+= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator- ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator- ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator-= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator* ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator* ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator*= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator/ ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator/ ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator/= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator% ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator% ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator%= ( + uint16 rhs + ); + + friend bool operator < ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend bool operator < ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + friend bool operator == ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + friend bool operator == ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + bigint_kernel_2& operator= ( + uint16 rhs + ); + + + void swap ( + bigint_kernel_2& item + ) { data_record* temp = data; data = item.data; item.data = temp; } + + + private: + + typedef double t; + typedef std::complex ct; + + void fft( + ct* data, + unsigned long len + ) const; + /*! + requires + - len == x^n for some integer n (i.e. len is a power of 2) + - len > 0 + ensures + - #data == the FT decimation in frequency of data + !*/ + + void ifft( + ct* data, + unsigned long len + ) const; + /*! + requires + - len == x^n for some integer n (i.e. len is a power of 2) + - len > 0 + ensures + - #data == the inverse decimation in frequency of data. + (i.e. the inverse of what fft(data,len,-1) does to data) + !*/ + + void long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= max(lhs->digits_used,rhs->digits_used) + 1 + ensures + - result == lhs + rhs + !*/ + + void long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - lhs >= rhs + - result->size >= lhs->digits_used + ensures + - result == lhs - rhs + !*/ + + void long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const; + /*! + requires + - rhs != 0 + - result->size >= lhs->digits_used + - remainder->size >= lhs->digits_used + - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.) + ensures + - result == lhs / rhs + - remainder == lhs % rhs + !*/ + + void long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= lhs->digits_used + rhs->digits_used + - result != lhs + - result != rhs + ensures + - result == lhs * rhs + !*/ + + void short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->size + 1 + ensures + - result == data + value + !*/ + + void short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - data >= value + - result->size >= data->digits_used + ensures + - result == data - value + !*/ + + void short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + 1 + ensures + - result == data * value + !*/ + + void short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& remainder + ) const; + /*! + requires + - value != 0 + - result->size >= data->digits_used + ensures + - result = data*value + - remainder = data%value + !*/ + + void shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const; + /*! + requires + - result->size >= data->digits_used + shift_amount/8 + 1 + ensures + - result == data << shift_amount + !*/ + + void shift_right ( + const data_record* data, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + ensures + - result == data >> 1 + !*/ + + bool is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs < rhs + - returns false otherwise + !*/ + + bool is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs == rhs + - returns false otherwise + !*/ + + void increment ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + - dest->size >= source->digits_used + 1 + ensures + - dest = source + 1 + !*/ + + void decrement ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + source != 0 + ensuers + dest = source - 1 + !*/ + + // member data + const uint32 slack; + data_record* data; + + + + }; + + inline void swap ( + bigint_kernel_2& a, + bigint_kernel_2& b + ) { a.swap(b); } + + inline void serialize ( + const bigint_kernel_2& item, + std::ostream& out + ) + { + std::ios::fmtflags oldflags = out.flags(); + out.flags(); + out << item << ' '; + out.flags(oldflags); + if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); + } + + inline void deserialize ( + bigint_kernel_2& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in.flags(); + in >> item; in.flags(oldflags); + if (in.get() != ' ') + { + item = 0; + throw serialization_error("Error deserializing object of type bigint_kernel_c"); + } + } + + inline bool operator> (const bigint_kernel_2& a, const bigint_kernel_2& b) { return b < a; } + inline bool operator!= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a == b); } + inline bool operator<= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(b < a); } + inline bool operator>= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a < b); } + +} + +#ifdef NO_MAKEFILE +#include "bigint_kernel_2.cpp" +#endif + +#endif // DLIB_BIGINT_KERNEl_2_ + diff --git a/ml/dlib/dlib/bigint/bigint_kernel_abstract.h b/ml/dlib/dlib/bigint/bigint_kernel_abstract.h new file mode 100644 index 000000000..99a54520b --- /dev/null +++ b/ml/dlib/dlib/bigint/bigint_kernel_abstract.h @@ -0,0 +1,670 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BIGINT_KERNEl_ABSTRACT_ +#ifdef DLIB_BIGINT_KERNEl_ABSTRACT_ + +#include +#include "../algs.h" +#include "../serialize.h" +#include "../uintn.h" + +namespace dlib +{ + + class bigint + { + /*! + INITIAL VALUE + *this == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents an arbitrary precision unsigned integer + + the following operators are supported: + operator + + operator += + operator - + operator -= + operator * + operator *= + operator / + operator /= + operator % + operator %= + operator == + operator < + operator = + operator << (for writing to ostreams) + operator >> (for reading from istreams) + operator++ // pre increment + operator++(int) // post increment + operator-- // pre decrement + operator--(int) // post decrement + + + the other comparason operators(>, !=, <=, and >=) are + available and come from the templates in dlib::relational_operators + + THREAD SAFETY + bigint may be reference counted so it is very unthread safe. + use with care in a multithreaded program + + !*/ + + public: + + bigint ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + if this is thrown the bigint will be unusable but + will not leak memory + !*/ + + bigint ( + uint32 value + ); + /*! + requires + - value <= (2^32)-1 + ensures + - #*this is properly initialized + - #*this == value + throws + - std::bad_alloc + if this is thrown the bigint will be unusable but + will not leak memory + !*/ + + bigint ( + const bigint& item + ); + /*! + ensures + - #*this is properly initialized + - #*this == value + throws + - std::bad_alloc + if this is thrown the bigint will be unusable but + will not leak memory + !*/ + + virtual ~bigint ( + ); + /*! + ensures + - all resources associated with #*this have been released + !*/ + + const bigint operator+ ( + const bigint& rhs + ) const; + /*! + ensures + - returns the result of adding rhs to *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator+= ( + const bigint& rhs + ); + /*! + ensures + - #*this == *this + rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator- ( + const bigint& rhs + ) const; + /*! + requires + - *this >= rhs + ensures + - returns the result of subtracting rhs from *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator-= ( + const bigint& rhs + ); + /*! + requires + - *this >= rhs + ensures + - #*this == *this - rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator* ( + const bigint& rhs + ) const; + /*! + ensures + - returns the result of multiplying *this and rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator*= ( + const bigint& rhs + ); + /*! + ensures + - #*this == *this * rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator/ ( + const bigint& rhs + ) const; + /*! + requires + - rhs != 0 + ensures + - returns the result of dividing *this by rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator/= ( + const bigint& rhs + ); + /*! + requires + - rhs != 0 + ensures + - #*this == *this / rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator% ( + const bigint& rhs + ) const; + /*! + requires + - rhs != 0 + ensures + - returns the result of *this mod rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator%= ( + const bigint& rhs + ); + /*! + requires + - rhs != 0 + ensures + - #*this == *this % rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bool operator < ( + const bigint& rhs + ) const; + /*! + ensures + - returns true if *this is less than rhs + - returns false otherwise + !*/ + + bool operator == ( + const bigint& rhs + ) const; + /*! + ensures + - returns true if *this and rhs represent the same number + - returns false otherwise + !*/ + + bigint& operator= ( + const bigint& rhs + ); + /*! + ensures + - #*this == rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + + friend std::ostream& operator<< ( + std::ostream& out, + const bigint& rhs + ); + /*! + ensures + - the number in *this has been written to #out as a base ten number + throws + - std::bad_alloc + if this function throws then it has no effect (nothing + is written to out) + !*/ + + friend std::istream& operator>> ( + std::istream& in, + bigint& rhs + ); + /*! + ensures + - reads a number from in and puts it into #*this + - if (there is no positive base ten number on the input stream ) then + - #in.fail() == true + throws + - std::bad_alloc + if this function throws the value in rhs is undefined and some + characters may have been read from in. rhs is still usable though, + its value is just unknown. + !*/ + + + bigint& operator++ ( + ); + /*! + ensures + - #*this == *this + 1 + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator++ ( + int + ); + /*! + ensures + - #*this == *this + 1 + - returns *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator-- ( + ); + /*! + requires + - *this != 0 + ensures + - #*this == *this - 1 + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator-- ( + int + ); + /*! + requires + - *this != 0 + ensures + - #*this == *this - 1 + - returns *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + void swap ( + bigint& item + ); + /*! + ensures + - swaps *this and item + !*/ + + + // ------------------------------------------------------------------ + // ---- The following functions are identical to the above ----- + // ---- but take uint16 as one of their arguments. They --- + // ---- exist only to allow for a more efficient implementation --- + // ------------------------------------------------------------------ + + + friend const bigint operator+ ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns the result of adding rhs to lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator+ ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns the result of adding rhs to lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator+= ( + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - #*this == *this + rhs + - returns #this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator- ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs >= rhs + - lhs <= 65535 + ensures + - returns the result of subtracting rhs from lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator- ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - lhs >= rhs + - rhs <= 65535 + ensures + - returns the result of subtracting rhs from lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator-= ( + uint16 rhs + ); + /*! + requires + - *this >= rhs + - rhs <= 65535 + ensures + - #*this == *this - rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator* ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns the result of multiplying lhs and rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator* ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns the result of multiplying lhs and rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator*= ( + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - #*this == *this * rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator/ ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - rhs != 0 + - lhs <= 65535 + ensures + - returns the result of dividing lhs by rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator/ ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - returns the result of dividing lhs by rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator/= ( + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - #*this == *this / rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator% ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - rhs != 0 + - lhs <= 65535 + ensures + - returns the result of lhs mod rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator% ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - returns the result of lhs mod rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator%= ( + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - #*this == *this % rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + + friend bool operator < ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns true if lhs is less than rhs + - returns false otherwise + !*/ + + friend bool operator < ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns true if lhs is less than rhs + - returns false otherwise + !*/ + + friend bool operator == ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns true if lhs and rhs represent the same number + - returns false otherwise + !*/ + + friend bool operator == ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns true if lhs and rhs represent the same number + - returns false otherwise + !*/ + + bigint& operator= ( + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - #*this == rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + }; + + inline void swap ( + bigint& a, + bigint& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + void serialize ( + const bigint& item, + std::istream& in + ); + /*! + provides serialization support + !*/ + + void deserialize ( + bigint& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + inline bool operator> (const bigint& a, const bigint& b) { return b < a; } + inline bool operator!= (const bigint& a, const bigint& b) { return !(a == b); } + inline bool operator<= (const bigint& a, const bigint& b) { return !(b < a); } + inline bool operator>= (const bigint& a, const bigint& b) { return !(a < b); } +} + +#endif // DLIB_BIGINT_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/bigint/bigint_kernel_c.h b/ml/dlib/dlib/bigint/bigint_kernel_c.h new file mode 100644 index 000000000..954869a38 --- /dev/null +++ b/ml/dlib/dlib/bigint/bigint_kernel_c.h @@ -0,0 +1,1141 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEl_C_ +#define DLIB_BIGINT_KERNEl_C_ + +#include "bigint_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + class bigint_kernel_c + { + bigint_base data; + + explicit bigint_kernel_c ( + const bigint_base& item + ) : data(item) {} + + public: + + + bigint_kernel_c ( + ); + + bigint_kernel_c ( + uint32 value + ); + + bigint_kernel_c ( + const bigint_kernel_c& item + ); + + ~bigint_kernel_c ( + ); + + const bigint_kernel_c operator+ ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator+= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator- ( + const bigint_kernel_c& rhs + ) const; + bigint_kernel_c& operator-= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator* ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator*= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator/ ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator/= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator% ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator%= ( + const bigint_kernel_c& rhs + ); + + bool operator < ( + const bigint_kernel_c& rhs + ) const; + + bool operator == ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator= ( + const bigint_kernel_c& rhs + ); + + template + friend std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_c& rhs + ); + + template + friend std::istream& operator>> ( + std::istream& in, + bigint_kernel_c& rhs + ); + + bigint_kernel_c& operator++ ( + ); + + const bigint_kernel_c operator++ ( + int + ); + + bigint_kernel_c& operator-- ( + ); + + const bigint_kernel_c operator-- ( + int + ); + + template + friend const bigint_kernel_c operator+ ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator+ ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator+= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator- ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator- ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator-= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator* ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator* ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator*= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator/ ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator/ ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator/= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator% ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator% ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator%= ( + uint16 rhs + ); + + template + friend bool operator < ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend bool operator < ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + template + friend bool operator == ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + template + friend bool operator == ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + bigint_kernel_c& operator= ( + uint16 rhs + ); + + + void swap ( + bigint_kernel_c& item + ) { data.swap(item.data); } + + }; + + template < + typename bigint_base + > + void swap ( + bigint_kernel_c& a, + bigint_kernel_c& b + ) { a.swap(b); } + + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + inline void serialize ( + const bigint_kernel_c& item, + std::ostream& out + ) + { + std::ios::fmtflags oldflags = out.flags(); + out.flags(); + out << item << ' '; + out.flags(oldflags); + if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); + } + + template < + typename bigint_base + > + inline void deserialize ( + bigint_kernel_c& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in.flags(); + in >> item; in.flags(oldflags); + if (in.get() != ' ') + { + item = 0; + throw serialization_error("Error deserializing object of type bigint_kernel_c"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + bigint_kernel_c ( + ) + {} + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + bigint_kernel_c ( + uint32 value + ) : + data(value) + { + // make sure requires clause is not broken + DLIB_CASSERT( value <= 0xFFFFFFFF , + "\tbigint::bigint(uint16)" + << "\n\t value must be <= (2^32)-1" + << "\n\tthis: " << this + << "\n\tvalue: " << value + ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + bigint_kernel_c ( + const bigint_kernel_c& item + ) : + data(item.data) + {} + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + ~bigint_kernel_c ( + ) + {} + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator+ ( + const bigint_kernel_c& rhs + ) const + { + return bigint_kernel_c(data + rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator+= ( + const bigint_kernel_c& rhs + ) + { + data += rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator- ( + const bigint_kernel_c& rhs + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this < rhs), + "\tconst bigint bigint::operator-(const bigint&)" + << "\n\t *this should not be less than rhs" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + // call the real function + return bigint_kernel_c(data-rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator-= ( + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this < rhs), + "\tbigint& bigint::operator-=(const bigint&)" + << "\n\t *this should not be less than rhs" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + // call the real function + data -= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator* ( + const bigint_kernel_c& rhs + ) const + { + return bigint_kernel_c(data * rhs.data ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator*= ( + const bigint_kernel_c& rhs + ) + { + data *= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator/ ( + const bigint_kernel_c& rhs + ) const + { + //make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tconst bigint bigint::operator/(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + return bigint_kernel_c(data/rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator/= ( + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tbigint& bigint::operator/=(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + data /= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator% ( + const bigint_kernel_c& rhs + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tconst bigint bigint::operator%(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + return bigint_kernel_c(data%rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator%= ( + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tbigint& bigint::operator%=(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + data %= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool bigint_kernel_c:: + operator < ( + const bigint_kernel_c& rhs + ) const + { + return data < rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool bigint_kernel_c:: + operator == ( + const bigint_kernel_c& rhs + ) const + { + return data == rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator= ( + const bigint_kernel_c& rhs + ) + { + data = rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_c& rhs + ) + { + out << rhs.data; + return out; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + std::istream& operator>> ( + std::istream& in, + bigint_kernel_c& rhs + ) + { + in >> rhs.data; + return in; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator++ ( + ) + { + ++data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator++ ( + int + ) + { + return bigint_kernel_c(data++); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator-- ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this == 0), + "\tbigint& bigint::operator--()" + << "\n\t *this to subtract from *this it must not be zero to begin with" + << "\n\tthis: " << this + ); + + // call the real function + --data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator-- ( + int + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this == 0), + "\tconst bigint bigint::operator--(int)" + << "\n\t *this to subtract from *this it must not be zero to begin with" + << "\n\tthis: " << this + ); + + // call the real function + return bigint_kernel_c(data--); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator+ ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tconst bigint operator+(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(static_cast(lhs)+rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator+ ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tconst bigint operator+(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(lhs.data+static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator+= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbigint& bigint::operator+=(uint16)" + << "\n\t rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + data += rhs; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator- ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( !(static_cast(lhs) < rhs) && lhs <= 65535, + "\tconst bigint operator-(uint16,const bigint&)" + << "\n\t lhs must be greater than or equal to rhs and lhs <= 65535" + << "\n\tlhs: " << lhs + << "\n\trhs: " << rhs + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + // call the real function + return bigint_kernel_c(static_cast(lhs)-rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator- ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(lhs < static_cast(rhs)) && rhs <= 65535, + "\tconst bigint operator-(const bigint&,uint16)" + << "\n\t lhs must be greater than or equal to rhs and rhs <= 65535" + << "\n\tlhs: " << lhs + << "\n\trhs: " << rhs + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + // call the real function + return bigint_kernel_c(lhs.data-static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator-= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(*this < static_cast(rhs)) && rhs <= 65535, + "\tbigint& bigint::operator-=(uint16)" + << "\n\t *this must not be less than rhs and rhs <= 65535" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + // call the real function + data -= static_cast(rhs); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator* ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tconst bigint operator*(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(lhs*rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator* ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tconst bigint operator*(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(lhs.data*rhs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator*= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\t bigint bigint::operator*=(uint16)" + << "\n\t rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + data *= static_cast(rhs); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator/ ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && lhs <= 65535, + "\tconst bigint operator/(uint16,const bigint&)" + << "\n\t you can't divide by zero and lhs <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\tlhs: " << lhs + ); + + // call the real function + return bigint_kernel_c(lhs/rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator/ ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, + "\tconst bigint operator/(const bigint&,uint16)" + << "\n\t you can't divide by zero and rhs <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\trhs: " << rhs + ); + + // call the real function + return bigint_kernel_c(lhs.data/static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator/= ( + uint16 rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && static_cast(rhs) <= 65535, + "\tbigint& bigint::operator/=(uint16)" + << "\n\t you can't divide by zero and rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\trhs: " << rhs + ); + + // call the real function + data /= rhs; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator% ( + uint16 lhs, + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && static_cast(lhs) <= 65535, + "\tconst bigint operator%(uint16,const bigint&)" + << "\n\t you can't divide by zero and lhs must be <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\tlhs: " << lhs + ); + + // call the real function + return bigint_kernel_c(lhs%rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator% ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, + "\tconst bigint operator%(const bigint&,uint16)" + << "\n\t you can't divide by zero and rhs must be <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\trhs: " << rhs + ); + + // call the real function + return bigint_kernel_c(lhs.data%static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator%= ( + uint16 r + ) + { + + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, + "\tbigint& bigint::operator%=(uint16)" + << "\n\t you can't divide by zero and rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\trhs: " << rhs + ); + + // call the real function + data %= rhs; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator < ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tbool operator<(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return static_cast(lhs) < rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator < ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbool operator<(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return lhs.data < static_cast(rhs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator == ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbool operator==(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return lhs.data == static_cast(rhs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator == ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tbool operator==(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return static_cast(lhs) == rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbigint bigint::operator=(uint16)" + << "\n\t rhs must be <= 65535" + << "\n\t*this: " << *this + << "\n\tthis: " << this + << "\n\tlhs: " << rhs + ); + + data = static_cast(rhs); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < typename bigint_base > + inline bool operator> (const bigint_kernel_c& a, const bigint_kernel_c& b) { return b < a; } + template < typename bigint_base > + inline bool operator!= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(a == b); } + template < typename bigint_base > + inline bool operator<= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(b < a); } + template < typename bigint_base > + inline bool operator>= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(a < b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIGINT_KERNEl_C_ + diff --git a/ml/dlib/dlib/binary_search_tree.h b/ml/dlib/dlib/binary_search_tree.h new file mode 100644 index 000000000..5273e8ce9 --- /dev/null +++ b/ml/dlib/dlib/binary_search_tree.h @@ -0,0 +1,50 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREe_ +#define DLIB_BINARY_SEARCH_TREe_ + + +#include "binary_search_tree/binary_search_tree_kernel_1.h" +#include "binary_search_tree/binary_search_tree_kernel_2.h" +#include "binary_search_tree/binary_search_tree_kernel_c.h" + + +#include "algs.h" +#include + + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class binary_search_tree + { + binary_search_tree() {} + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef binary_search_tree_kernel_1 + kernel_1a; + typedef binary_search_tree_kernel_c + kernel_1a_c; + + + // kernel_2a + typedef binary_search_tree_kernel_2 + kernel_2a; + typedef binary_search_tree_kernel_c + kernel_2a_c; + + }; +} + +#endif // DLIB_BINARY_SEARCH_TREe_ + diff --git a/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_1.h b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_1.h new file mode 100644 index 000000000..418eb07d0 --- /dev/null +++ b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_1.h @@ -0,0 +1,2064 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_1_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_1_ + +#include "binary_search_tree_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare = std::less + > + class binary_search_tree_kernel_1 : public enumerable >, + public asc_pair_remover + { + + /*! + INITIAL VALUE + tree_size == 0 + tree_root == 0 + tree_height == 0 + at_start_ == true + current_element == 0 + stack == array of 50 node pointers + stack_pos == 0 + + + CONVENTION + tree_size == size() + tree_height == height() + + stack[stack_pos-1] == pop() + + current_element_valid() == (current_element != 0) + if (current_element_valid()) then + element() == current_element->d and current_element->r + at_start_ == at_start() + if (current_element != 0 && current_element != tree_root) then + stack[stack_pos-1] == the parent of the node pointed to by current_element + + if (tree_size != 0) + tree_root == pointer to the root node of the binary search tree + else + tree_root == 0 + + + for all nodes: + { + left points to the left subtree or 0 if there is no left subtree and + right points to the right subtree or 0 if there is no right subtree and + all elements in a left subtree are <= the root and + all elements in a right subtree are >= the root and + d is the item in the domain of *this contained in the node + r is the item in the range of *this contained in the node + balance: + balance == 0 if both subtrees have the same height + balance == -1 if the left subtree has a height that is greater + than the height of the right subtree by 1 + balance == 1 if the right subtree has a height that is greater + than the height of the left subtree by 1 + for all trees: + the height of the left and right subtrees differ by at most one + } + + !*/ + + class node + { + public: + node* left; + node* right; + domain d; + range r; + signed char balance; + }; + + class mpair : public map_pair + { + public: + const domain* d; + range* r; + + const domain& key( + ) const { return *d; } + + const range& value( + ) const { return *r; } + + range& value( + ) { return *r; } + }; + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + binary_search_tree_kernel_1( + ) : + tree_size(0), + tree_root(0), + current_element(0), + tree_height(0), + at_start_(true), + stack_pos(0), + stack(ppool.allocate_array(50)) + { + } + + virtual ~binary_search_tree_kernel_1( + ); + + inline void clear( + ); + + inline short height ( + ) const; + + inline unsigned long count ( + const domain& item + ) const; + + inline void add ( + domain& d, + range& r + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& item + ); + + inline const range* operator[] ( + const domain& item + ) const; + + inline range* operator[] ( + const domain& item + ); + + inline void swap ( + binary_search_tree_kernel_1& item + ); + + // function from the asc_pair_remover interface + void remove_any ( + domain& d, + range& r + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const map_pair& element ( + ) const; + + map_pair& element ( + ); + + bool move_next ( + ) const; + + void remove_last_in_order ( + domain& d, + range& r + ); + + void remove_current_element ( + domain& d, + range& r + ); + + void position_enumerator ( + const domain& d + ) const; + + private: + + + inline void rotate_left ( + node*& t + ); + /*! + requires + - t->balance == 2 + - t->right->balance == 0 or 1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t->balance is between 1 and -1 + - #t now has a height smaller by 1 if #t->balance == 0 + !*/ + + inline void rotate_right ( + node*& t + ); + /*! + requires + - t->balance == -2 + - t->left->balance == 0 or -1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t->balance is between 1 and -1 + - #t now has a height smaller by 1 if #t->balance == 0 + + !*/ + + inline void double_rotate_right ( + node*& t + ); + /*! + requires + - t->balance == -2 + - t->left->balance == 1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t now has a balance of 0 + - #t now has a height smaller by 1 + !*/ + + inline void double_rotate_left ( + node*& t + ); + /*! + requires + - t->balance == 2 + - t->right->balance == -1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t now has a balance of 0 + - #t now has a height smaller by 1 + !*/ + + bool remove_biggest_element_in_tree ( + node*& t, + domain& d, + range& r + ); + /*! + requires + - t != 0 (i.e. there must be something in the tree to remove) + - t == reference to the pointer in t's parent node that points to t + ensures + - the biggest node in t has been removed + - the biggest node domain element in t has been put into #d + - the biggest node range element in t has been put into #r + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool remove_least_element_in_tree ( + node*& t, + domain& d, + range& r + ); + /*! + requires + - t != 0 (i.e. there must be something in the tree to remove) + - t == reference to the pointer in t's parent node that points to t + ensures + - the least node in t has been removed + - the least node domain element in t has been put into #d + - the least node range element in t has been put into #r + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool add_to_tree ( + node*& t, + domain& d, + range& r + ); + /*! + requires + - t == reference to the pointer in t's parent node that points to t + ensures + - the mapping (d --> r) has been added to #t + - #d and #r have initial values for their types + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has grown by one + !*/ + + bool remove_from_tree ( + node*& t, + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - return_reference(t,d) != 0 + - t == reference to the pointer in t's parent node that points to t + ensures + - #d_copy is equivalent to d + - an element in t equivalent to d has been removed and swapped + into #d_copy and its associated range object has been + swapped into #r + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool remove_from_tree ( + node*& t, + const domain& item + ); + /*! + requires + - return_reference(t,item) != 0 + - t == reference to the pointer in t's parent node that points to t + ensures + - an element in t equivalent to item has been removed + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + const range* return_reference ( + const node* t, + const domain& d + ) const; + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + range* return_reference ( + node* t, + const domain& d + ); + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + + inline bool keep_node_balanced ( + node*& t + ); + /*! + requires + - t != 0 + - t == reference to the pointer in t's parent node that points to t + ensures + - if (t->balance is < 1 or > 1) then + - keep_node_balanced() will ensure that #t->balance == 0, -1, or 1 + - #t is still a binary search tree + - returns true if it made the tree one height shorter + - returns false if it didn't change the height + !*/ + + + unsigned long get_count ( + const domain& item, + node* tree_root + ) const; + /*! + requires + - tree_root == the root of a binary search tree or 0 + ensures + - if (tree_root == 0) then + - returns 0 + - else + - returns the number of elements in tree_root that are + equivalent to item + !*/ + + + void delete_tree ( + node* t + ); + /*! + requires + - t != 0 + ensures + - deallocates the node pointed to by t and all of t's left and right children + !*/ + + + void push ( + node* n + ) const { stack[stack_pos] = n; ++stack_pos; } + /*! + ensures + - pushes n onto the stack + !*/ + + + node* pop ( + ) const { --stack_pos; return stack[stack_pos]; } + /*! + ensures + - pops the top of the stack and returns it + !*/ + + + + bool fix_stack ( + node* t, + unsigned char depth = 0 + ); + /*! + requires + - current_element != 0 + - depth == 0 + - t == tree_root + ensures + - makes the stack contain the correct set of parent pointers. + also adjusts stack_pos so it is correct. + - #t is still a binary search tree + !*/ + + bool remove_current_element_from_tree ( + node*& t, + domain& d, + range& r, + unsigned long cur_stack_pos = 1 + ); + /*! + requires + - t == tree_root + - cur_stack_pos == 1 + - current_element != 0 + ensures + - removes the data in the node given by current_element and swaps it into + #d and #r. + - #t is still a binary search tree + - the enumerator is advances on to the next element but its stack is + potentially corrupted. so you must call fix_stack(tree_root) to fix + it. + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + + // data members + + mutable mpair p; + unsigned long tree_size; + node* tree_root; + mutable node* current_element; + typename mem_manager::template rebind::other pool; + typename mem_manager::template rebind::other ppool; + short tree_height; + mutable bool at_start_; + mutable unsigned char stack_pos; + mutable node** stack; + compare comp; + + // restricted functions + binary_search_tree_kernel_1(binary_search_tree_kernel_1&); + binary_search_tree_kernel_1& operator=(binary_search_tree_kernel_1&); + + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + binary_search_tree_kernel_1& a, + binary_search_tree_kernel_1& b + ) { a.swap(b); } + + + + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + binary_search_tree_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_1"); + } + } + + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + binary_search_tree_kernel_1:: + ~binary_search_tree_kernel_1 ( + ) + { + ppool.deallocate_array(stack); + if (tree_size != 0) + { + delete_tree(tree_root); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + clear ( + ) + { + if (tree_size > 0) + { + delete_tree(tree_root); + tree_root = 0; + tree_size = 0; + tree_height = 0; + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + size_t binary_search_tree_kernel_1:: + size ( + ) const + { + return tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + short binary_search_tree_kernel_1:: + height ( + ) const + { + return tree_height; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_1:: + count ( + const domain& item + ) const + { + return get_count(item,tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + add ( + domain& d, + range& r + ) + { + tree_height += add_to_tree(tree_root,d,r); + ++tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + tree_height -= remove_from_tree(tree_root,d,d_copy,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + destroy ( + const domain& item + ) + { + tree_height -= remove_from_tree(tree_root,item); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove_any ( + domain& d, + range& r + ) + { + tree_height -= remove_least_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_1:: + operator[] ( + const domain& item + ) + { + return return_reference(tree_root,item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_1:: + operator[] ( + const domain& item + ) const + { + return return_reference(tree_root,item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + swap ( + binary_search_tree_kernel_1& item + ) + { + pool.swap(item.pool); + ppool.swap(item.ppool); + exchange(p,item.p); + exchange(stack,item.stack); + exchange(stack_pos,item.stack_pos); + exchange(comp,item.comp); + + + node* tree_root_temp = item.tree_root; + unsigned long tree_size_temp = item.tree_size; + short tree_height_temp = item.tree_height; + node* current_element_temp = item.current_element; + bool at_start_temp = item.at_start_; + + item.tree_root = tree_root; + item.tree_size = tree_size; + item.tree_height = tree_height; + item.current_element = current_element; + item.at_start_ = at_start_; + + tree_root = tree_root_temp; + tree_size = tree_size_temp; + tree_height = tree_height_temp; + current_element = current_element_temp; + at_start_ = at_start_temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove_last_in_order ( + domain& d, + range& r + ) + { + tree_height -= remove_biggest_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove_current_element ( + domain& d, + range& r + ) + { + tree_height -= remove_current_element_from_tree(tree_root,d,r); + --tree_size; + + // fix the enumerator stack if we need to + if (current_element) + fix_stack(tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + position_enumerator ( + const domain& d + ) const + { + // clear the enumerator state and make sure the stack is empty + reset(); + at_start_ = false; + node* t = tree_root; + bool went_left = false; + while (t != 0) + { + if ( comp(d , t->d) ) + { + push(t); + // if item is on the left then look in left + t = t->left; + went_left = true; + } + else if (comp(t->d , d)) + { + push(t); + // if item is on the right then look in right + t = t->right; + went_left = false; + } + else + { + current_element = t; + return; + } + } + + // if we didn't find any matches but there might be something after the + // d in this tree. + if (stack_pos > 0) + { + current_element = pop(); + // if we went left from this node then this node is the next + // biggest. + if (went_left) + { + return; + } + else + { + move_next(); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + stack_pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const map_pair& binary_search_tree_kernel_1:: + element ( + ) const + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + map_pair& binary_search_tree_kernel_1:: + element ( + ) + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + move_next ( + ) const + { + // if we haven't started iterating yet + if (at_start_) + { + at_start_ = false; + if (tree_size == 0) + { + return false; + } + else + { + // find the first element in the tree + current_element = tree_root; + node* temp = current_element->left; + while (temp != 0) + { + push(current_element); + current_element = temp; + temp = current_element->left; + } + return true; + } + } + else + { + if (current_element == 0) + { + return false; + } + else + { + node* temp; + bool went_up; // true if we went up the tree from a child node to parent + bool from_left = false; // true if we went up and were coming from a left child node + // find the next element in the tree + if (current_element->right != 0) + { + // go right and down + temp = current_element; + push(current_element); + current_element = temp->right; + went_up = false; + } + else + { + // go up to the parent if we can + if (current_element == tree_root) + { + // in this case we have iterated over all the element of the tree + current_element = 0; + return false; + } + went_up = true; + node* parent = pop(); + + + from_left = (parent->left == current_element); + // go up to parent + current_element = parent; + } + + + while (true) + { + if (went_up) + { + if (from_left) + { + // in this case we have found the next node + break; + } + else + { + if (current_element == tree_root) + { + // in this case we have iterated over all the elements + // in the tree + current_element = 0; + return false; + } + // we should go up + node* parent = pop(); + from_left = (parent->left == current_element); + current_element = parent; + } + } + else + { + // we just went down to a child node + if (current_element->left != 0) + { + // go left + went_up = false; + temp = current_element; + push(current_element); + current_element = temp->left; + } + else + { + // if there is no left child then we have found the next node + break; + } + } + } + + return true; + } + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + delete_tree ( + node* t + ) + { + if (t->left != 0) + delete_tree(t->left); + if (t->right != 0) + delete_tree(t->right); + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + rotate_left ( + node*& t + ) + { + + // set the new balance numbers + if (t->right->balance == 1) + { + t->balance = 0; + t->right->balance = 0; + } + else + { + t->balance = 1; + t->right->balance = -1; + } + + // perform the rotation + node* temp = t->right; + t->right = temp->left; + temp->left = t; + t = temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + rotate_right ( + node*& t + ) + { + // set the new balance numbers + if (t->left->balance == -1) + { + t->balance = 0; + t->left->balance = 0; + } + else + { + t->balance = -1; + t->left->balance = 1; + } + + // preform the rotation + node* temp = t->left; + t->left = temp->right; + temp->right = t; + t = temp; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + double_rotate_right ( + node*& t + ) + { + + node* temp = t; + t = t->left->right; + + temp->left->right = t->left; + t->left = temp->left; + + temp->left = t->right; + t->right = temp; + + if (t->balance < 0) + { + t->left->balance = 0; + t->right->balance = 1; + } + else if (t->balance > 0) + { + t->left->balance = -1; + t->right->balance = 0; + } + else + { + t->left->balance = 0; + t->right->balance = 0; + } + t->balance = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + double_rotate_left ( + node*& t + ) + { + node* temp = t; + t = t->right->left; + + temp->right->left = t->right; + t->right = temp->right; + + temp->right = t->left; + t->left = temp; + + if (t->balance < 0) + { + t->left->balance = 0; + t->right->balance = 1; + } + else if (t->balance > 0) + { + t->left->balance = -1; + t->right->balance = 0; + } + else + { + t->left->balance = 0; + t->right->balance = 0; + } + + t->balance = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_biggest_element_in_tree ( + node*& t, + domain& d, + range& r + ) + { + // make a reference to the current node so we don't have to dereference a + // pointer a bunch of times + node& tree = *t; + + // if the right tree is an empty tree + if ( tree.right == 0) + { + // swap nodes domain and range elements into d and r + exchange(d,tree.d); + exchange(r,tree.r); + + // plug hole left by removing this node + t = tree.left; + + // delete the node that was just removed + pool.deallocate(&tree); + + // return that the height of this part of the tree has decreased + return true; + } + else + { + + // keep going right + + // if remove made the tree one height shorter + if ( remove_biggest_element_in_tree(tree.right,d,r) ) + { + // if this caused the current tree to strink then report that + if ( tree.balance == 1) + { + --tree.balance; + return true; + } + else + { + --tree.balance; + return keep_node_balanced(t); + } + } + + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_least_element_in_tree ( + node*& t, + domain& d, + range& r + ) + { + // make a reference to the current node so we don't have to dereference a + // pointer a bunch of times + node& tree = *t; + + // if the left tree is an empty tree + if ( tree.left == 0) + { + // swap nodes domain and range elements into d and r + exchange(d,tree.d); + exchange(r,tree.r); + + // plug hole left by removing this node + t = tree.right; + + // delete the node that was just removed + pool.deallocate(&tree); + + // return that the height of this part of the tree has decreased + return true; + } + else + { + + // keep going left + + // if remove made the tree one height shorter + if ( remove_least_element_in_tree(tree.left,d,r) ) + { + // if this caused the current tree to strink then report that + if ( tree.balance == -1) + { + ++tree.balance; + return true; + } + else + { + ++tree.balance; + return keep_node_balanced(t); + } + } + + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + add_to_tree ( + node*& t, + domain& d, + range& r + ) + { + + // if found place to add + if (t == 0) + { + // create a node to add new item into + t = pool.allocate(); + + // make a reference to the current node so we don't have to dereference a + // pointer a bunch of times + node& tree = *t; + + + // set left and right pointers to NULL to indicate that there are no + // left or right subtrees + tree.left = 0; + tree.right = 0; + tree.balance = 0; + + // put d and r into t + exchange(tree.d,d); + exchange(tree.r,r); + + // indicate that the height of this tree has increased + return true; + } + else // keep looking for a place to add the new item + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + signed char old_balance = tree.balance; + + // add the new item to whatever subtree it should go into + if (comp( d , tree.d) ) + tree.balance -= add_to_tree(tree.left,d,r); + else + tree.balance += add_to_tree(tree.right,d,r); + + + // if the tree was balanced to start with + if (old_balance == 0) + { + // if its not balanced anymore then it grew in height + if (tree.balance != 0) + return true; + else + return false; + } + else + { + // if the tree is now balanced then it didn't grow + if (tree.balance == 0) + { + return false; + } + else + { + // if the tree needs to be balanced + if (tree.balance != old_balance) + { + return !keep_node_balanced(t); + } + // if there has been no change in the heights + else + { + return false; + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + fix_stack ( + node* t, + unsigned char depth + ) + { + // if we found the node we were looking for + if (t == current_element) + { + stack_pos = depth; + return true; + } + else if (t == 0) + { + return false; + } + + if (!( comp(t->d , current_element->d))) + { + // go left + if (fix_stack(t->left,depth+1)) + { + stack[depth] = t; + return true; + } + } + if (!(comp(current_element->d , t->d))) + { + // go right + if (fix_stack(t->right,depth+1)) + { + stack[depth] = t; + return true; + } + } + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_current_element_from_tree ( + node*& t, + domain& d, + range& r, + unsigned long cur_stack_pos + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if we found the node we were looking for + if (t == current_element) + { + + // swap nodes domain and range elements into d_copy and r + exchange(d,tree.d); + exchange(r,tree.r); + + // if there is no left node + if (tree.left == 0) + { + // move the enumerator on to the next element before we mess with the + // tree + move_next(); + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + // move the enumerator on to the next element before we mess with the + // tree + move_next(); + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + + // in this case the next current element is going to get swapped back + // into this t node. + current_element = t; + + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,tree.d,tree.r)) + { + // adjust the tree height + --tree.balance; + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + return false; + } + + } + + } + else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.left) || + tree.left == current_element ) + { + // go left + if (tree.balance == -1) + { + int balance = tree.balance; + balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1); + tree.balance = balance; + return keep_node_balanced(t); + } + } + else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.right) || + tree.right == current_element ) + { + // go right + if (tree.balance == 1) + { + int balance = tree.balance; + balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1); + tree.balance = balance; + return keep_node_balanced(t); + } + } + + // this return should never happen but do it anyway to suppress compiler warnings + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_from_tree ( + node*& t, + const domain& d, + domain& d_copy, + range& r + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if item is on the left + if (comp(d , tree.d)) + { + // if the left side of the tree has the greatest height + if (tree.balance == -1) + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d,d_copy,r); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d,d_copy,r); + tree.balance = balance; + return keep_node_balanced(t); + } + + } + // if item is on the right + else if (comp(tree.d , d)) + { + + // if the right side of the tree has the greatest height + if (tree.balance == 1) + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d,d_copy,r); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d,d_copy,r); + tree.balance = balance; + return keep_node_balanced(t); + } + } + // if item is found + else + { + + // swap nodes domain and range elements into d_copy and r + exchange(d_copy,tree.d); + exchange(r,tree.r); + + // if there is no left node + if (tree.left == 0) + { + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,tree.d,tree.r)) + { + // adjust the tree height + --tree.balance; + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + return false; + } + + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_from_tree ( + node*& t, + const domain& d + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if item is on the left + if (comp(d , tree.d)) + { + // if the left side of the tree has the greatest height + if (tree.balance == -1) + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d); + tree.balance = balance; + return keep_node_balanced(t); + } + + } + // if item is on the right + else if (comp(tree.d , d)) + { + + // if the right side of the tree has the greatest height + if (tree.balance == 1) + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d); + tree.balance = balance; + return keep_node_balanced(t); + } + } + // if item is found + else + { + + // if there is no left node + if (tree.left == 0) + { + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,tree.d,tree.r)) + { + // adjust the tree height + --tree.balance; + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + return false; + } + + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_1:: + return_reference ( + node* t, + const domain& d + ) + { + while (t != 0) + { + + if ( comp(d , t->d )) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_1:: + return_reference ( + const node* t, + const domain& d + ) const + { + while (t != 0) + { + + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + keep_node_balanced ( + node*& t + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if tree does not need to be balanced then return false + if (tree.balance == 0) + return false; + + + // if tree needs to be rotated left + if (tree.balance == 2) + { + if (tree.right->balance >= 0) + rotate_left(t); + else + double_rotate_left(t); + } + // else if the tree needs to be rotated right + else if (tree.balance == -2) + { + if (tree.left->balance <= 0) + rotate_right(t); + else + double_rotate_right(t); + } + + + if (t->balance == 0) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_1:: + get_count ( + const domain& d, + node* tree_root + ) const + { + if (tree_root != 0) + { + if (comp(d , tree_root->d)) + { + // go left + return get_count(d,tree_root->left); + } + else if (comp(tree_root->d , d)) + { + // go right + return get_count(d,tree_root->right); + } + else + { + // go left and right to look for more matches + return get_count(d,tree_root->left) + + get_count(d,tree_root->right) + + 1; + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_1_ + diff --git a/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_2.h b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_2.h new file mode 100644 index 000000000..098d38c2e --- /dev/null +++ b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_2.h @@ -0,0 +1,1897 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_2_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_2_ + +#include "binary_search_tree_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare = std::less + > + class binary_search_tree_kernel_2 : public enumerable >, + public asc_pair_remover + { + + /*! + INITIAL VALUE + NIL == pointer to a node that represents a leaf + tree_size == 0 + tree_root == NIL + at_start == true + current_element == 0 + + + CONVENTION + current_element_valid() == (current_element != 0) + if (current_element_valid()) then + element() == current_element->d and current_element->r + at_start_ == at_start() + + + tree_size == size() + + NIL == pointer to a node that represents a leaf + + if (tree_size != 0) + tree_root == pointer to the root node of the binary search tree + else + tree_root == NIL + + tree_root->color == black + Every leaf is black and all leafs are the NIL node. + The number of black nodes in any path from the root to a leaf is the + same. + + for all nodes: + { + - left points to the left subtree or NIL if there is no left subtree + - right points to the right subtree or NIL if there is no right + subtree + - parent points to the parent node or NIL if the node is the root + - ordering of nodes is determined by comparing each node's d member + - all elements in a left subtree are <= the node + - all elements in a right subtree are >= the node + - color == red or black + - if (color == red) + - the node's children are black + } + + !*/ + + class node + { + public: + node* left; + node* right; + node* parent; + domain d; + range r; + char color; + }; + + class mpair : public map_pair + { + public: + const domain* d; + range* r; + + const domain& key( + ) const { return *d; } + + const range& value( + ) const { return *r; } + + range& value( + ) { return *r; } + }; + + + const static char red = 0; + const static char black = 1; + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + binary_search_tree_kernel_2( + ) : + NIL(pool.allocate()), + tree_size(0), + tree_root(NIL), + current_element(0), + at_start_(true) + { + NIL->color = black; + NIL->left = 0; + NIL->right = 0; + NIL->parent = 0; + } + + virtual ~binary_search_tree_kernel_2( + ); + + inline void clear( + ); + + inline short height ( + ) const; + + inline unsigned long count ( + const domain& d + ) const; + + inline void add ( + domain& d, + range& r + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + void remove_any ( + domain& d, + range& r + ); + + inline const range* operator[] ( + const domain& item + ) const; + + inline range* operator[] ( + const domain& item + ); + + inline void swap ( + binary_search_tree_kernel_2& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const map_pair& element ( + ) const; + + map_pair& element ( + ); + + bool move_next ( + ) const; + + void remove_last_in_order ( + domain& d, + range& r + ); + + void remove_current_element ( + domain& d, + range& r + ); + + void position_enumerator ( + const domain& d + ) const; + + private: + + inline void rotate_left ( + node* t + ); + /*! + requires + - t != NIL + - t->right != NIL + ensures + - performs a left rotation around t and its right child + !*/ + + inline void rotate_right ( + node* t + ); + /*! + requires + - t != NIL + - t->left != NIL + ensures + - performs a right rotation around t and its left child + !*/ + + inline void double_rotate_right ( + node* t + ); + /*! + requires + - t != NIL + - t->left != NIL + - t->left->right != NIL + - double_rotate_right() is only called in fix_after_add() + ensures + - performs a left rotation around t->left + - then performs a right rotation around t + !*/ + + inline void double_rotate_left ( + node* t + ); + /*! + requires + - t != NIL + - t->right != NIL + - t->right->left != NIL + - double_rotate_left() is only called in fix_after_add() + ensures + - performs a right rotation around t->right + - then performs a left rotation around t + !*/ + + void remove_biggest_element_in_tree ( + node* t, + domain& d, + range& r + ); + /*! + requires + - t != NIL (i.e. there must be something in the tree to remove) + ensures + - the biggest node in t has been removed + - the biggest node element in t has been put into #d and #r + - #t is still a binary search tree + !*/ + + bool remove_least_element_in_tree ( + node* t, + domain& d, + range& r + ); + /*! + requires + - t != NIL (i.e. there must be something in the tree to remove) + ensures + - the least node in t has been removed + - the least node element in t has been put into #d and #r + - #t is still a binary search tree + - if (the node that was removed was the one pointed to by current_element) then + - returns true + - else + - returns false + !*/ + + void add_to_tree ( + node* t, + domain& d, + range& r + ); + /*! + requires + - t != NIL + ensures + - d and r are now in #t + - there is a mapping from d to r in #t + - #d and #r have initial values for their types + - #t is still a binary search tree + !*/ + + void remove_from_tree ( + node* t, + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - return_reference(t,d) != 0 + ensures + - #d_copy is equivalent to d + - the first element in t equivalent to d that is encountered when searching down the tree + from t has been removed and swapped into #d_copy. Also, the associated range element + has been removed and swapped into #r. + - if (the node that got removed wasn't current_element) then + - adjusts the current_element pointer if the data in the node that it points to gets moved. + - else + - the value of current_element is now invalid + - #t is still a binary search tree + !*/ + + void remove_from_tree ( + node* t, + const domain& d + ); + /*! + requires + - return_reference(t,d) != 0 + ensures + - an element in t equivalent to d has been removed + - #t is still a binary search tree + !*/ + + const range* return_reference ( + const node* t, + const domain& d + ) const; + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + range* return_reference ( + node* t, + const domain& d + ); + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + void fix_after_add ( + node* t + ); + /*! + requires + - t == pointer to the node just added + - t->color == red + - t->parent != NIL (t must not be the root) + - fix_after_add() is only called after a new node has been added + to t + ensures + - fixes any deviations from the CONVENTION caused by adding a node + !*/ + + void fix_after_remove ( + node* t + ); + /*! + requires + - t == pointer to the only child of the node that was spliced out + - fix_after_remove() is only called after a node has been removed + from t + - the color of the spliced out node was black + ensures + - fixes any deviations from the CONVENTION causes by removing a node + !*/ + + + short tree_height ( + node* t + ) const; + /*! + ensures + - returns the number of nodes in the longest path from the root of the + tree to a leaf + !*/ + + void delete_tree ( + node* t + ); + /*! + requires + - t == root of binary search tree + - t != NIL + ensures + - deletes all nodes in t except for NIL + !*/ + + unsigned long get_count ( + const domain& item, + node* tree_root + ) const; + /*! + requires + - tree_root == the root of a binary search tree or NIL + ensures + - if (tree_root == NIL) then + - returns 0 + - else + - returns the number of elements in tree_root that are + equivalent to item + !*/ + + + + // data members + typename mem_manager::template rebind::other pool; + node* NIL; + unsigned long tree_size; + node* tree_root; + mutable node* current_element; + mutable bool at_start_; + mutable mpair p; + compare comp; + + + + // restricted functions + binary_search_tree_kernel_2(binary_search_tree_kernel_2&); + binary_search_tree_kernel_2& operator=(binary_search_tree_kernel_2&); + + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + binary_search_tree_kernel_2& a, + binary_search_tree_kernel_2& b + ) { a.swap(b); } + + + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + binary_search_tree_kernel_2& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_2"); + } + } + + + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + binary_search_tree_kernel_2:: + ~binary_search_tree_kernel_2 ( + ) + { + if (tree_root != NIL) + delete_tree(tree_root); + pool.deallocate(NIL); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + clear ( + ) + { + if (tree_size > 0) + { + delete_tree(tree_root); + tree_root = NIL; + tree_size = 0; + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + size_t binary_search_tree_kernel_2:: + size ( + ) const + { + return tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + short binary_search_tree_kernel_2:: + height ( + ) const + { + return tree_height(tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_2:: + count ( + const domain& item + ) const + { + return get_count(item,tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + add ( + domain& d, + range& r + ) + { + if (tree_size == 0) + { + tree_root = pool.allocate(); + tree_root->color = black; + tree_root->left = NIL; + tree_root->right = NIL; + tree_root->parent = NIL; + exchange(tree_root->d,d); + exchange(tree_root->r,r); + } + else + { + add_to_tree(tree_root,d,r); + } + ++tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + remove_from_tree(tree_root,d,d_copy,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + destroy ( + const domain& item + ) + { + remove_from_tree(tree_root,item); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_any ( + domain& d, + range& r + ) + { + remove_least_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_2:: + operator[] ( + const domain& d + ) + { + return return_reference(tree_root,d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_2:: + operator[] ( + const domain& d + ) const + { + return return_reference(tree_root,d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + swap ( + binary_search_tree_kernel_2& item + ) + { + pool.swap(item.pool); + + exchange(p,item.p); + exchange(comp,item.comp); + + node* tree_root_temp = item.tree_root; + unsigned long tree_size_temp = item.tree_size; + node* const NIL_temp = item.NIL; + node* current_element_temp = item.current_element; + bool at_start_temp = item.at_start_; + + item.tree_root = tree_root; + item.tree_size = tree_size; + item.NIL = NIL; + item.current_element = current_element; + item.at_start_ = at_start_; + + tree_root = tree_root_temp; + tree_size = tree_size_temp; + NIL = NIL_temp; + current_element = current_element_temp; + at_start_ = at_start_temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_last_in_order ( + domain& d, + range& r + ) + { + remove_biggest_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_current_element ( + domain& d, + range& r + ) + { + node* t = current_element; + move_next(); + remove_from_tree(t,t->d,d,r); + --tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + position_enumerator ( + const domain& d + ) const + { + // clear the enumerator state and make sure the stack is empty + reset(); + at_start_ = false; + node* t = tree_root; + node* parent = NIL; + bool went_left = false; + while (t != NIL) + { + if ( comp(d , t->d )) + { + // if item is on the left then look in left + parent = t; + t = t->left; + went_left = true; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + parent = t; + t = t->right; + went_left = false; + } + else + { + current_element = t; + return; + } + } + + // if we didn't find any matches but there might be something after the + // d in this tree. + if (parent != NIL) + { + current_element = parent; + // if we went left from this node then this node is the next + // biggest. + if (went_left) + { + return; + } + else + { + move_next(); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const map_pair& binary_search_tree_kernel_2:: + element ( + ) const + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + map_pair& binary_search_tree_kernel_2:: + element ( + ) + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + move_next ( + ) const + { + // if we haven't started iterating yet + if (at_start_) + { + at_start_ = false; + if (tree_size == 0) + { + return false; + } + else + { + // find the first element in the tree + current_element = tree_root; + node* temp = current_element->left; + while (temp != NIL) + { + current_element = temp; + temp = current_element->left; + } + return true; + } + } + else + { + if (current_element == 0) + { + return false; + } + else + { + bool went_up; // true if we went up the tree from a child node to parent + bool from_left = false; // true if we went up and were coming from a left child node + // find the next element in the tree + if (current_element->right != NIL) + { + // go right and down + current_element = current_element->right; + went_up = false; + } + else + { + went_up = true; + node* parent = current_element->parent; + if (parent == NIL) + { + // in this case we have iterated over all the element of the tree + current_element = 0; + return false; + } + + from_left = (parent->left == current_element); + // go up to parent + current_element = parent; + } + + + while (true) + { + if (went_up) + { + if (from_left) + { + // in this case we have found the next node + break; + } + else + { + // we should go up + node* parent = current_element->parent; + from_left = (parent->left == current_element); + current_element = parent; + if (current_element == NIL) + { + // in this case we have iterated over all the elements + // in the tree + current_element = 0; + return false; + } + } + } + else + { + // we just went down to a child node + if (current_element->left != NIL) + { + // go left + went_up = false; + current_element = current_element->left; + } + else + { + // if there is no left child then we have found the next node + break; + } + } + } + + return true; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + delete_tree ( + node* t + ) + { + if (t->left != NIL) + delete_tree(t->left); + if (t->right != NIL) + delete_tree(t->right); + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + rotate_left ( + node* t + ) + { + + // perform the rotation + node* temp = t->right; + t->right = temp->left; + if (temp->left != NIL) + temp->left->parent = t; + temp->left = t; + temp->parent = t->parent; + + + if (t == tree_root) + tree_root = temp; + else + { + // if t was on the left + if (t->parent->left == t) + t->parent->left = temp; + else + t->parent->right = temp; + } + + t->parent = temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + rotate_right ( + node* t + ) + { + // perform the rotation + node* temp = t->left; + t->left = temp->right; + if (temp->right != NIL) + temp->right->parent = t; + temp->right = t; + temp->parent = t->parent; + + if (t == tree_root) + tree_root = temp; + else + { + // if t is a left child + if (t->parent->left == t) + t->parent->left = temp; + else + t->parent->right = temp; + } + + t->parent = temp; + } + + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + double_rotate_right ( + node* t + ) + { + + // preform the rotation + node& temp = *(t->left->right); + t->left = temp.right; + temp.right->parent = t; + temp.left->parent = temp.parent; + temp.parent->right = temp.left; + temp.parent->parent = &temp; + temp.right = t; + temp.left = temp.parent; + temp.parent = t->parent; + + + if (tree_root == t) + tree_root = &temp; + else + { + // t is a left child + if (t->parent->left == t) + t->parent->left = &temp; + else + t->parent->right = &temp; + } + t->parent = &temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + double_rotate_left ( + node* t + ) + { + + + // preform the rotation + node& temp = *(t->right->left); + t->right = temp.left; + temp.left->parent = t; + temp.right->parent = temp.parent; + temp.parent->left = temp.right; + temp.parent->parent = &temp; + temp.left = t; + temp.right = temp.parent; + temp.parent = t->parent; + + + if (tree_root == t) + tree_root = &temp; + else + { + // t is a left child + if (t->parent->left == t) + t->parent->left = &temp; + else + t->parent->right = &temp; + } + t->parent = &temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_biggest_element_in_tree ( + node* t, + domain& d, + range& r + ) + { + + node* next = t->right; + node* child; // the child node of the one we will slice out + + if (next == NIL) + { + // need to determine if t is a right or left child + if (t->parent->right == t) + child = t->parent->right = t->left; + else + child = t->parent->left = t->left; + + // update tree_root if necessary + if (t == tree_root) + tree_root = child; + } + else + { + // find the least node + do + { + t = next; + next = next->right; + } while (next != NIL); + // t is a right child + child = t->parent->right = t->left; + + } + + // swap the item from this node into d and r + exchange(d,t->d); + exchange(r,t->r); + + // plug hole right by removing this node + child->parent = t->parent; + + // keep the red-black properties true + if (t->color == black) + fix_after_remove(child); + + // free the memory for this removed node + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + remove_least_element_in_tree ( + node* t, + domain& d, + range& r + ) + { + + node* next = t->left; + node* child; // the child node of the one we will slice out + + if (next == NIL) + { + // need to determine if t is a left or right child + if (t->parent->left == t) + child = t->parent->left = t->right; + else + child = t->parent->right = t->right; + + // update tree_root if necessary + if (t == tree_root) + tree_root = child; + } + else + { + // find the least node + do + { + t = next; + next = next->left; + } while (next != NIL); + // t is a left child + child = t->parent->left = t->right; + + } + + // swap the item from this node into d and r + exchange(d,t->d); + exchange(r,t->r); + + // plug hole left by removing this node + child->parent = t->parent; + + // keep the red-black properties true + if (t->color == black) + fix_after_remove(child); + + bool rvalue = (t == current_element); + // free the memory for this removed node + pool.deallocate(t); + return rvalue; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + add_to_tree ( + node* t, + domain& d, + range& r + ) + { + // parent of the current node + node* parent; + + // find a place to add node + while (true) + { + parent = t; + // if item should be put on the left then go left + if (comp(d , t->d)) + { + t = t->left; + if (t == NIL) + { + t = parent->left = pool.allocate(); + break; + } + } + // if item should be put on the right then go right + else + { + t = t->right; + if (t == NIL) + { + t = parent->right = pool.allocate(); + break; + } + } + } + + // t is now the node where we will add item and + // parent is the parent of t + + t->parent = parent; + t->left = NIL; + t->right = NIL; + t->color = red; + exchange(t->d,d); + exchange(t->r,r); + + + // keep the red-black properties true + fix_after_add(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_from_tree ( + node* t, + const domain& d, + domain& d_copy, + range& r + ) + { + while (true) + { + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // found the node we want to remove + + // swap out the item into d_copy and r + exchange(d_copy,t->d); + exchange(r,t->r); + + if (t->left == NIL) + { + // if there is no left subtree + + node* parent = t->parent; + + // plug hole with right subtree + + + // if t is on the left + if (parent->left == t) + parent->left = t->right; + else + parent->right = t->right; + t->right->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->right; + + if (t->color == black) + fix_after_remove(t->right); + + // delete old node + pool.deallocate(t); + } + else if (t->right == NIL) + { + // if there is no right subtree + + node* parent = t->parent; + + // plug hole with left subtree + if (parent->left == t) + parent->left = t->left; + else + parent->right = t->left; + t->left->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->left; + + if (t->color == black) + fix_after_remove(t->left); + + // delete old node + pool.deallocate(t); + } + else + { + // if there is both a left and right subtree + // get an element to fill this node now that its been swapped into + // item_copy + if (remove_least_element_in_tree(t->right,t->d,t->r)) + { + // the node removed was the one pointed to by current_element so we + // need to update it so that it points to the right spot. + current_element = t; + } + } + + // quit loop + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_from_tree ( + node* t, + const domain& d + ) + { + while (true) + { + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // found the node we want to remove + + + if (t->left == NIL) + { + // if there is no left subtree + + node* parent = t->parent; + + // plug hole with right subtree + + + if (parent->left == t) + parent->left = t->right; + else + parent->right = t->right; + t->right->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->right; + + if (t->color == black) + fix_after_remove(t->right); + + // delete old node + pool.deallocate(t); + } + else if (t->right == NIL) + { + // if there is no right subtree + + node* parent = t->parent; + + // plug hole with left subtree + if (parent->left == t) + parent->left = t->left; + else + parent->right = t->left; + t->left->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->left; + + if (t->color == black) + fix_after_remove(t->left); + + // delete old node + pool.deallocate(t); + } + else + { + // if there is both a left and right subtree + // get an element to fill this node now that its been swapped into + // item_copy + remove_least_element_in_tree(t->right,t->d,t->r); + + } + + // quit loop + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_2:: + return_reference ( + node* t, + const domain& d + ) + { + while (t != NIL) + { + if ( comp(d , t->d )) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_2:: + return_reference ( + const node* t, + const domain& d + ) const + { + while (t != NIL) + { + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + fix_after_add ( + node* t + ) + { + + while (t->parent->color == red) + { + node& grandparent = *(t->parent->parent); + + // if both t's parent and its sibling are red + if (grandparent.left->color == grandparent.right->color) + { + grandparent.color = red; + grandparent.left->color = black; + grandparent.right->color = black; + t = &grandparent; + } + else + { + // if t is a left child + if (t == t->parent->left) + { + // if t's parent is a left child + if (t->parent == grandparent.left) + { + grandparent.color = red; + grandparent.left->color = black; + rotate_right(&grandparent); + } + // if t's parent is a right child + else + { + t->color = black; + grandparent.color = red; + double_rotate_left(&grandparent); + } + } + // if t is a right child + else + { + // if t's parent is a left child + if (t->parent == grandparent.left) + { + t->color = black; + grandparent.color = red; + double_rotate_right(&grandparent); + } + // if t's parent is a right child + else + { + grandparent.color = red; + grandparent.right->color = black; + rotate_left(&grandparent); + } + } + break; + } + } + tree_root->color = black; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + fix_after_remove ( + node* t + ) + { + + while (t != tree_root && t->color == black) + { + if (t->parent->left == t) + { + node* sibling = t->parent->right; + if (sibling->color == red) + { + sibling->color = black; + t->parent->color = red; + rotate_left(t->parent); + sibling = t->parent->right; + } + + if (sibling->left->color == black && sibling->right->color == black) + { + sibling->color = red; + t = t->parent; + } + else + { + if (sibling->right->color == black) + { + sibling->left->color = black; + sibling->color = red; + rotate_right(sibling); + sibling = t->parent->right; + } + + sibling->color = t->parent->color; + t->parent->color = black; + sibling->right->color = black; + rotate_left(t->parent); + t = tree_root; + + } + + + } + else + { + + node* sibling = t->parent->left; + if (sibling->color == red) + { + sibling->color = black; + t->parent->color = red; + rotate_right(t->parent); + sibling = t->parent->left; + } + + if (sibling->left->color == black && sibling->right->color == black) + { + sibling->color = red; + t = t->parent; + } + else + { + if (sibling->left->color == black) + { + sibling->right->color = black; + sibling->color = red; + rotate_left(sibling); + sibling = t->parent->left; + } + + sibling->color = t->parent->color; + t->parent->color = black; + sibling->left->color = black; + rotate_right(t->parent); + t = tree_root; + + } + + + } + + } + t->color = black; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + short binary_search_tree_kernel_2:: + tree_height ( + node* t + ) const + { + if (t == NIL) + return 0; + + short height1 = tree_height(t->left); + short height2 = tree_height(t->right); + if (height1 > height2) + return height1 + 1; + else + return height2 + 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_2:: + get_count ( + const domain& d, + node* tree_root + ) const + { + if (tree_root != NIL) + { + if (comp(d , tree_root->d)) + { + // go left + return get_count(d,tree_root->left); + } + else if (comp(tree_root->d , d)) + { + // go right + return get_count(d,tree_root->right); + } + else + { + // go left and right to look for more matches + return get_count(d,tree_root->left) + + get_count(d,tree_root->right) + + 1; + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_2_ + diff --git a/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h new file mode 100644 index 000000000..2abfe7e39 --- /dev/null +++ b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h @@ -0,0 +1,311 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ +#ifdef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ + +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class binary_search_tree : public enumerable >, + public asc_pair_remover + { + + /*! + REQUIREMENTS ON domain + domain must be comparable by compare where compare is a functor compatible with std::less and + domain is swappable by a global swap() and + domain must have a default constructor + + REQUIREMENTS ON range + range is swappable by a global swap() and + range must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + + POINTERS AND REFERENCES TO INTERNAL DATA + swap(), count(), height(), and operator[] functions + do not invalidate pointers or references to internal data. + + position_enumerator() invalidates pointers or references to + data returned by element() and only by element() (i.e. pointers and + references returned by operator[] are still valid). + + All other functions have no such guarantees. + + INITIAL VALUE + size() == 0 + height() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the domain (and each associated + range element) elements in ascending order according to the compare functor. + (i.e. the elements are enumerated in sorted order) + + WHAT THIS OBJECT REPRESENTS + this object represents a data dictionary that is built on top of some + kind of binary search tree. It maps objects of type domain to objects + of type range. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + NOTE: + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + binary_search_tree( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + !*/ + + virtual ~binary_search_tree( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + short height ( + ) const; + /*! + ensures + - returns the number of elements in the longest path from the root + of the tree to a leaf + !*/ + + unsigned long count ( + const domain& d + ) const; + /*! + ensures + - returns the number of elements in the domain of *this that are + equivalent to d + !*/ + + void add ( + domain& d, + range& r + ); + /*! + requires + - &d != &r (i.e. d and r cannot be the same variable) + ensures + - adds a mapping between d and r to *this + - if (count(d) == 0) then + - #*(*this)[d] == r + - else + - #(*this)[d] != 0 + - #d and #r have initial values for their types + - #count(d) == count(d) + 1 + - #at_start() == true + - #size() == size() + 1 + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if add() throws then it has no effect + !*/ + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - (*this)[d] != 0 + - &d != &r (i.e. d and r cannot be the same variable) + - &d != &d_copy (i.e. d and d_copy cannot be the same variable) + - &r != &d_copy (i.e. r and d_copy cannot be the same variable) + ensures + - some element in the domain of *this that is equivalent to d has + been removed and swapped into #d_copy. Additionally, its + associated range element has been removed and swapped into #r. + - #count(d) == count(d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void destroy ( + const domain& d + ); + /*! + requires + - (*this)[d] != 0 + ensures + - an element in the domain of *this equivalent to d has been removed. + The element in the range of *this associated with d has also been + removed. + - #count(d) == count(d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void remove_last_in_order ( + domain& d, + range& r + ); + /*! + requires + - size() > 0 + ensures + - the last/biggest (according to the compare functor) element in the domain of *this has + been removed and swapped into #d. The element in the range of *this + associated with #d has also been removed and swapped into #r. + - #count(#d) == count(#d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void remove_current_element ( + domain& d, + range& r + ); + /*! + requires + - current_element_valid() == true + ensures + - the current element given by element() has been removed and swapped into d and r. + - #d == element().key() + - #r == element().value() + - #count(#d) == count(#d) - 1 + - #size() == size() - 1 + - moves the enumerator to the next element. If element() was the last + element in enumeration order then #current_element_valid() == false + and #at_start() == false. + !*/ + + void position_enumerator ( + const domain& d + ) const; + /*! + ensures + - #at_start() == false + - if (count(d) > 0) then + - #element().key() == d + - else if (there are any items in the domain of *this that are bigger than + d according to the compare functor) then + - #element().key() == the smallest item in the domain of *this that is + bigger than d according to the compare functor. + - else + - #current_element_valid() == false + !*/ + + const range* operator[] ( + const domain& d + ) const; + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + range* operator[] ( + const domain& d + ); + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + void swap ( + binary_search_tree& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + binary_search_tree(binary_search_tree&); + binary_search_tree& operator=(binary_search_tree&); + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + binary_search_tree& a, + binary_search_tree& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + binary_search_tree& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_c.h b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_c.h new file mode 100644 index 000000000..0dc153961 --- /dev/null +++ b/ml/dlib/dlib/binary_search_tree/binary_search_tree_kernel_c.h @@ -0,0 +1,235 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_C_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_C_ + +#include "../interfaces/map_pair.h" +#include "binary_search_tree_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename bst_base + > + class binary_search_tree_kernel_c : public bst_base + { + typedef typename bst_base::domain_type domain; + typedef typename bst_base::range_type range; + + public: + + binary_search_tree_kernel_c () {} + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + void add ( + domain& d, + range& r + ); + + void remove_any ( + domain& d, + range& r + ); + + const map_pair& element( + ) const + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst map_pair& binary_search_tree::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return bst_base::element(); + } + + map_pair& element( + ) + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tmap_pair& binary_search_tree::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return bst_base::element(); + } + + void remove_last_in_order ( + domain& d, + range& r + ); + + void remove_current_element ( + domain& d, + range& r + ); + + + }; + + + template < + typename bst_base + > + inline void swap ( + binary_search_tree_kernel_c& a, + binary_search_tree_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + add ( + domain& d, + range& r + ) + { + DLIB_CASSERT( static_cast(&d) != static_cast(&r), + "\tvoid binary_search_tree::add" + << "\n\tyou can't call add() and give the same object to both parameters." + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&r: " << &r + << "\n\tsize(): " << this->size() + ); + + bst_base::add(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + destroy ( + const domain& d + ) + { + DLIB_CASSERT(this->operator[](d) != 0, + "\tvoid binary_search_tree::destroy" + << "\n\tthe element must be in the tree for it to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + ); + + bst_base::destroy(d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + DLIB_CASSERT(this->operator[](d) != 0 && + (static_cast(&d) != static_cast(&d_copy)) && + (static_cast(&d) != static_cast(&r)) && + (static_cast(&r) != static_cast(&d_copy)), + "\tvoid binary_search_tree::remove" + << "\n\tthe element must be in the tree for it to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&d_copy: " << &d_copy + << "\n\t&r: " << &r + ); + + bst_base::remove(d,d_copy,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove_any( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->size() != 0 && + (static_cast(&d) != static_cast(&r)), + "\tvoid binary_search_tree::remove_any" + << "\n\ttree must not be empty if something is going to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&r: " << &r + ); + + bst_base::remove_any(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove_last_in_order ( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->size() > 0, + "\tvoid binary_search_tree::remove_last_in_order()" + << "\n\tyou can't remove an element if it doesn't exist" + << "\n\tthis: " << this + ); + + bst_base::remove_last_in_order(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove_current_element ( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tvoid binary_search_tree::remove_current_element()" + << "\n\tyou can't remove the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + bst_base::remove_current_element(d,r); + } + + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_C_ + diff --git a/ml/dlib/dlib/bit_stream.h b/ml/dlib/dlib/bit_stream.h new file mode 100644 index 000000000..8885f3515 --- /dev/null +++ b/ml/dlib/dlib/bit_stream.h @@ -0,0 +1,42 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAm_ +#define DLIB_BIT_STREAm_ + +#include "bit_stream/bit_stream_kernel_1.h" +#include "bit_stream/bit_stream_kernel_c.h" + +#include "bit_stream/bit_stream_multi_1.h" +#include "bit_stream/bit_stream_multi_c.h" + +namespace dlib +{ + + + class bit_stream + { + bit_stream() {} + public: + + //----------- kernels --------------- + + // kernel_1a + typedef bit_stream_kernel_1 + kernel_1a; + typedef bit_stream_kernel_c + kernel_1a_c; + + //---------- extensions ------------ + + + // multi_1 extend kernel_1a + typedef bit_stream_multi_1 + multi_1a; + typedef bit_stream_multi_c > + multi_1a_c; + + }; +} + +#endif // DLIB_BIT_STREAm_ + diff --git a/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.cpp b/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.cpp new file mode 100644 index 000000000..f49db14d5 --- /dev/null +++ b/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.cpp @@ -0,0 +1,200 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_KERNEL_1_CPp_ +#define DLIB_BIT_STREAM_KERNEL_1_CPp_ + + +#include "bit_stream_kernel_1.h" +#include "../algs.h" + +#include + +namespace dlib +{ + + inline void swap ( + bit_stream_kernel_1& a, + bit_stream_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + clear ( + ) + { + if (write_mode) + { + write_mode = false; + + // flush output buffer + if (buffer_size > 0) + { + buffer <<= 8 - buffer_size; + osp->write(reinterpret_cast(&buffer),1); + } + } + else + read_mode = false; + + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + set_input_stream ( + std::istream& is + ) + { + isp = &is; + read_mode = true; + + buffer_size = 0; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + set_output_stream ( + std::ostream& os + ) + { + osp = &os; + write_mode = true; + + buffer_size = 0; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + close ( + ) + { + if (write_mode) + { + write_mode = false; + + // flush output buffer + if (buffer_size > 0) + { + buffer <<= 8 - buffer_size; + osp->write(reinterpret_cast(&buffer),1); + } + } + else + read_mode = false; + } + +// ---------------------------------------------------------------------------------------- + + bool bit_stream_kernel_1:: + is_in_write_mode ( + ) const + { + return write_mode; + } + +// ---------------------------------------------------------------------------------------- + + bool bit_stream_kernel_1:: + is_in_read_mode ( + ) const + { + return read_mode; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + write ( + int bit + ) + { + // flush buffer if necessary + if (buffer_size == 8) + { + buffer <<= 8 - buffer_size; + if (osp->rdbuf()->sputn(reinterpret_cast(&buffer),1) == 0) + { + throw std::ios_base::failure("error occurred in the bit_stream object"); + } + + buffer_size = 0; + } + + ++buffer_size; + buffer <<= 1; + buffer += static_cast(bit); + } + +// ---------------------------------------------------------------------------------------- + + bool bit_stream_kernel_1:: + read ( + int& bit + ) + { + // get new byte if necessary + if (buffer_size == 0) + { + if (isp->rdbuf()->sgetn(reinterpret_cast(&buffer), 1) == 0) + { + // if we didn't read anything then return false + return false; + } + + buffer_size = 8; + } + + // put the most significant bit from buffer into bit + bit = static_cast(buffer >> 7); + + // shift out the bit that was just read + buffer <<= 1; + --buffer_size; + + return true; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + swap ( + bit_stream_kernel_1& item + ) + { + + std::istream* isp_temp = item.isp; + std::ostream* osp_temp = item.osp; + bool write_mode_temp = item.write_mode; + bool read_mode_temp = item.read_mode; + unsigned char buffer_temp = item.buffer; + unsigned short buffer_size_temp = item.buffer_size; + + item.isp = isp; + item.osp = osp; + item.write_mode = write_mode; + item.read_mode = read_mode; + item.buffer = buffer; + item.buffer_size = buffer_size; + + + isp = isp_temp; + osp = osp_temp; + write_mode = write_mode_temp; + read_mode = read_mode_temp; + buffer = buffer_temp; + buffer_size = buffer_size_temp; + + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_BIT_STREAM_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.h b/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.h new file mode 100644 index 000000000..801e93e0a --- /dev/null +++ b/ml/dlib/dlib/bit_stream/bit_stream_kernel_1.h @@ -0,0 +1,120 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_KERNEl_1_ +#define DLIB_BIT_STREAM_KERNEl_1_ + +#include "bit_stream_kernel_abstract.h" +#include + +namespace dlib +{ + + class bit_stream_kernel_1 + { + + /*! + INITIAL VALUE + write_mode == false + read_mode == false + + CONVENTION + write_mode == is_in_write_mode() + read_mode == is_in_read_mode() + + if (write_mode) + { + osp == pointer to an ostream object + buffer == the low order bits of buffer are the bits to be + written + buffer_size == the number of low order bits in buffer that are + bits that should be written + the lowest order bit is the last bit entered by the user + } + + if (read_mode) + { + isp == pointer to an istream object + buffer == the high order bits of buffer are the bits + waiting to be read by the user + buffer_size == the number of high order bits in buffer that + are bits that are waiting to be read + the highest order bit is the next bit to give to the user + } + !*/ + + + public: + + bit_stream_kernel_1 ( + ) : + write_mode(false), + read_mode(false) + {} + + virtual ~bit_stream_kernel_1 ( + ) + {} + + void clear ( + ); + + void set_input_stream ( + std::istream& is + ); + + void set_output_stream ( + std::ostream& os + ); + + void close ( + ); + + inline bool is_in_write_mode ( + ) const; + + inline bool is_in_read_mode ( + ) const; + + inline void write ( + int bit + ); + + bool read ( + int& bit + ); + + void swap ( + bit_stream_kernel_1& item + ); + + private: + + // member data + std::istream* isp; + std::ostream* osp; + bool write_mode; + bool read_mode; + unsigned char buffer; + unsigned short buffer_size; + + // restricted functions + bit_stream_kernel_1(bit_stream_kernel_1&); // copy constructor + bit_stream_kernel_1& operator=(bit_stream_kernel_1&); // assignment operator + + }; + + inline void swap ( + bit_stream_kernel_1& a, + bit_stream_kernel_1& b + ); + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "bit_stream_kernel_1.cpp" +#endif + +#endif // DLIB_BIT_STREAM_KERNEl_1_ + diff --git a/ml/dlib/dlib/bit_stream/bit_stream_kernel_abstract.h b/ml/dlib/dlib/bit_stream/bit_stream_kernel_abstract.h new file mode 100644 index 000000000..00c2ae3b9 --- /dev/null +++ b/ml/dlib/dlib/bit_stream/bit_stream_kernel_abstract.h @@ -0,0 +1,185 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BIT_STREAM_KERNEl_ABSTRACT_ +#ifdef DLIB_BIT_STREAM_KERNEl_ABSTRACT_ + +#include + +namespace dlib +{ + + class bit_stream + { + + /*! + INITIAL VALUE + is_in_write_mode() == false + is_in_read_mode() == false + + WHAT THIS OBJECT REPRESENTS + this object is a middle man between a user and the iostream classes. + it allows single bits to be read/written easily to/from + the iostream classes + + BUFFERING: + This object will only read/write single bytes at a time from/to the + iostream objects. Any buffered bits still in the bit_stream object + when it is closed or destructed are lost if it is in read mode. If + it is in write mode then any remaining bits are guaranteed to be + written to the output stream by the time it is closed or destructed. + !*/ + + + public: + + bit_stream ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~bit_stream ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear ( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + + void set_input_stream ( + std::istream& is + ); + /*! + requires + - is_in_write_mode() == false + - is_in_read_mode() == false + - is is ready to give input + ensures + - #is_in_write_mode() == false + - #is_in_read_mode() == true + - #*this will now be reading from is + throws + - std::bad_alloc + !*/ + + void set_output_stream ( + std::ostream& os + ); + /*! + requires + - is_in_write_mode() == false + - is_in_read_mode() == false + - os is ready to take output + ensures + - #is_in_write_mode() == true + - #is_in_read_mode() == false + - #*this will now write to os + throws + - std::bad_alloc + !*/ + + + + void close ( + ); + /*! + requires + - is_in_write_mode() == true || is_in_read_mode() == true + ensures + - #is_in_write_mode() == false + - #is_in_read_mode() == false + !*/ + + bool is_in_write_mode ( + ) const; + /*! + ensures + - returns true if *this is associated with an output stream object + - returns false otherwise + !*/ + + bool is_in_read_mode ( + ) const; + /*! + ensures + - returns true if *this is associated with an input stream object + - returns false otherwise + !*/ + + void write ( + int bit + ); + /*! + requires + - is_in_write_mode() == true + - bit == 0 || bit == 1 + ensures + - bit will be written to the ostream object associated with *this + throws + - std::ios_base::failure + if (there was a problem writing to the output stream) then + this exception will be thrown. #*this will be unusable until + clear() is called and succeeds + - any other exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + bool read ( + int& bit + ); + /*! + requires + - is_in_read_mode() == true + ensures + - the next bit has been read and placed into #bit + - returns true if the read was successful, else false + (ex. false if EOF has been reached) + throws + - any exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void swap ( + bit_stream& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + bit_stream(bit_stream&); // copy constructor + bit_stream& operator=(bit_stream&); // assignment operator + + }; + + inline void swap ( + bit_stream& a, + bit_stream& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_BIT_STREAM_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/bit_stream/bit_stream_kernel_c.h b/ml/dlib/dlib/bit_stream/bit_stream_kernel_c.h new file mode 100644 index 000000000..1d52bff20 --- /dev/null +++ b/ml/dlib/dlib/bit_stream/bit_stream_kernel_c.h @@ -0,0 +1,172 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_KERNEl_C_ +#define DLIB_BIT_STREAM_KERNEl_C_ + +#include "bit_stream_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename bit_stream_base // implements bit_stream/bit_stream_kernel_abstract.h + > + class bit_stream_kernel_c : public bit_stream_base + { + public: + + + void set_input_stream ( + std::istream& is + ); + + void set_output_stream ( + std::ostream& os + ); + + void close ( + ); + + void write ( + int bit + ); + + bool read ( + int& bit + ); + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_kernel_c& a, + bit_stream_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + set_input_stream ( + std::istream& is + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ), + "\tvoid bit_stream::set_intput_stream" + << "\n\tbit_stream must not be in write or read mode" + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::set_input_stream(is); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + set_output_stream ( + std::ostream& os + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ), + "\tvoid bit_stream::set_output_stream" + << "\n\tbit_stream must not be in write or read mode" + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::set_output_stream(os); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + close ( + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == true ) || ( this->is_in_read_mode() == true ), + "\tvoid bit_stream::close" + << "\n\tyou can't close a bit_stream that isn't open" + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::close(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + write ( + int bit + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == true ) && ( bit == 0 || bit == 1 ), + "\tvoid bit_stream::write" + << "\n\tthe bit stream bust be in write mode and bit must be either 1 or 0" + << "\n\tis_in_write_mode() == " << this->is_in_write_mode() + << "\n\tbit == " << bit + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::write(bit); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + bool bit_stream_kernel_c:: + read ( + int& bit + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_read_mode() == true ), + "\tbool bit_stream::read" + << "\n\tyou can't read from a bit_stream that isn't in read mode" + << "\n\tthis: " << this + ); + + // call the real function + return bit_stream_base::read(bit); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIT_STREAM_KERNEl_C_ + diff --git a/ml/dlib/dlib/bit_stream/bit_stream_multi_1.h b/ml/dlib/dlib/bit_stream/bit_stream_multi_1.h new file mode 100644 index 000000000..bf1cc0357 --- /dev/null +++ b/ml/dlib/dlib/bit_stream/bit_stream_multi_1.h @@ -0,0 +1,103 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_MULTi_1_ +#define DLIB_BIT_STREAM_MULTi_1_ + +#include "bit_stream_multi_abstract.h" + +namespace dlib +{ + template < + typename bit_stream_base + > + class bit_stream_multi_1 : public bit_stream_base + { + + public: + + void multi_write ( + unsigned long data, + int num_to_write + ); + + int multi_read ( + unsigned long& data, + int num_to_read + ); + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_multi_1& a, + bit_stream_multi_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_multi_1:: + multi_write ( + unsigned long data, + int num_to_write + ) + { + // move the first bit into the most significant position + data <<= 32 - num_to_write; + + for (int i = 0; i < num_to_write; ++i) + { + // write the first bit from data + this->write(static_cast(data >> 31)); + + // shift the next bit into position + data <<= 1; + + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + int bit_stream_multi_1:: + multi_read ( + unsigned long& data, + int num_to_read + ) + { + int bit, i; + data = 0; + for (i = 0; i < num_to_read; ++i) + { + + // get a bit + if (this->read(bit) == false) + break; + + // shift data to make room for this new bit + data <<= 1; + + // put bit into the least significant position in data + data += static_cast(bit); + + } + + return i; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIT_STREAM_MULTi_1_ + diff --git a/ml/dlib/dlib/bit_stream/bit_stream_multi_abstract.h b/ml/dlib/dlib/bit_stream/bit_stream_multi_abstract.h new file mode 100644 index 000000000..061af94f4 --- /dev/null +++ b/ml/dlib/dlib/bit_stream/bit_stream_multi_abstract.h @@ -0,0 +1,77 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BIT_STREAM_MULTi_ABSTRACT_ +#ifdef DLIB_BIT_STREAM_MULTi_ABSTRACT_ + +#include "bit_stream_kernel_abstract.h" + +namespace dlib +{ + template < + typename bit_stream_base + > + class bit_stream_multi : public bit_stream_base + { + + /*! + REQUIREMENTS ON BIT_STREAM_BASE + it is an implementation of bit_stream/bit_stream_kernel_abstract.h + + + WHAT THIS EXTENSION DOES FOR BIT_STREAM + this gives a bit_stream object the ability to read/write multible bits + at a time + !*/ + + + public: + + void multi_write ( + unsigned long data, + int num_to_write + ); + /*! + requires + - is_in_write_mode() == true + - 0 <= num_to_write <= 32 + ensures + - num_to_write low order bits from data will be written to the ostream + - object associated with *this + example: if data is 10010 then the bits will be written in the + order 1,0,0,1,0 + !*/ + + + int multi_read ( + unsigned long& data, + int num_to_read + ); + /*! + requires + - is_in_read_mode() == true + - 0 <= num_to_read <= 32 + ensures + - tries to read num_to_read bits into the low order end of #data + example: if the incoming bits were 10010 then data would end + up with 10010 as its low order bits + - all of the bits in #data not filled in by multi_read() are zero + - returns the number of bits actually read into #data + !*/ + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_multi& a, + bit_stream_multi& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_BIT_STREAM_MULTi_ABSTRACT_ + diff --git a/ml/dlib/dlib/bit_stream/bit_stream_multi_c.h b/ml/dlib/dlib/bit_stream/bit_stream_multi_c.h new file mode 100644 index 000000000..de80c6328 --- /dev/null +++ b/ml/dlib/dlib/bit_stream/bit_stream_multi_c.h @@ -0,0 +1,101 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_MULTi_C_ +#define DLIB_BIT_STREAM_MULTi_C_ + +#include "bit_stream_multi_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + template < + typename bit_stream_base // implements bit_stream/bit_stream_multi_abstract.h + > + class bit_stream_multi_c : public bit_stream_base + { + public: + + void multi_write ( + unsigned long data, + int num_to_write + ); + + int multi_read ( + unsigned long& data, + int num_to_read + ); + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_multi_c& a, + bit_stream_multi_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_multi_c:: + multi_write ( + unsigned long data, + int num_to_write + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (this->is_in_write_mode() == true) && (num_to_write >= 0 && num_to_write <=32), + "\tvoid bit_stream::write" + << "\n\tthe bit stream bust be in write mode and" + << "\n\tnum_to_write must be between 0 and 32 inclusive" + << "\n\tnum_to_write == " << num_to_write + << "\n\tis_in_write_mode() == " << this->is_in_write_mode() + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::multi_write(data,num_to_write); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + int bit_stream_multi_c:: + multi_read ( + unsigned long& data, + int num_to_read + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_read_mode() == true && ( num_to_read >= 0 && num_to_read <=32 ) ), + "\tvoid bit_stream::read" + << "\n\tyou can't read from a bit_stream that isn't in read mode and" + << "\n\tnum_to_read must be between 0 and 32 inclusive" + << "\n\tnum_to_read == " << num_to_read + << "\n\tis_in_read_mode() == " << this->is_in_read_mode() + << "\n\tthis: " << this + ); + + // call the real function + return bit_stream_base::multi_read(data,num_to_read); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIT_STREAM_MULTi_C_ + diff --git a/ml/dlib/dlib/bits/c++config.h b/ml/dlib/dlib/bits/c++config.h new file mode 100644 index 000000000..6139ba823 --- /dev/null +++ b/ml/dlib/dlib/bits/c++config.h @@ -0,0 +1 @@ +#include "../dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/bound_function_pointer.h b/ml/dlib/dlib/bound_function_pointer.h new file mode 100644 index 000000000..a482919c6 --- /dev/null +++ b/ml/dlib/dlib/bound_function_pointer.h @@ -0,0 +1,10 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOUND_FUNCTION_POINTEr_ +#define DLIB_BOUND_FUNCTION_POINTEr_ + +#include "bound_function_pointer/bound_function_pointer_kernel_1.h" + +#endif // DLIB_BOUND_FUNCTION_POINTEr_ + + diff --git a/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h b/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h new file mode 100644 index 000000000..a39592742 --- /dev/null +++ b/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h @@ -0,0 +1,774 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ +#define DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ + +#include "../algs.h" +#include "../member_function_pointer.h" +#include "bound_function_pointer_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace bfp1_helpers + { + template struct strip { typedef T type; }; + template struct strip { typedef T type; }; + + // ------------------------------------------------------------------------------------ + + class bound_function_helper_base_base + { + public: + virtual ~bound_function_helper_base_base(){} + virtual void call() const = 0; + virtual bool is_set() const = 0; + virtual void clone(void* ptr) const = 0; + }; + + // ------------------------------------------------------------------------------------ + + template + class bound_function_helper_base : public bound_function_helper_base_base + { + public: + bound_function_helper_base():arg1(0), arg2(0), arg3(0), arg4(0) {} + + typename strip::type* arg1; + typename strip::type* arg2; + typename strip::type* arg3; + typename strip::type* arg4; + + + member_function_pointer mfp; + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1, *this->arg2, *this->arg3, *this->arg4); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3, *this->arg4); + else if (fp) fp(*this->arg1, *this->arg2, *this->arg3, *this->arg4); + } + + void (*fp)(T1, T2, T3, T4); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(); + } + + typename strip::type* fp; + }; + + template <> + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(); + else if (fp) fp(); + } + + void (*fp)(); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(*this->arg1); + else if (fp) fp(*this->arg1); + } + + void (*fp)(T1); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1, *this->arg2); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(*this->arg1, *this->arg2); + else if (fp) fp(*this->arg1, *this->arg2); + } + + void (*fp)(T1, T2); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1, *this->arg2, *this->arg3); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + + void call() const + { + if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3); + else if (fp) fp(*this->arg1, *this->arg2, *this->arg3); + } + + void (*fp)(T1, T2, T3); + }; + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + template + class bound_function_helper_T : public T + { + public: + bound_function_helper_T(){ this->fp = 0;} + + bool is_set() const + { + return this->fp != 0 || this->mfp.is_set(); + } + + template + void safe_clone(stack_based_memory_block& buf) + { + // This is here just to validate the assumption that our block of memory we have made + // in bf_memory is the right size to store the data for this object. If you + // get a compiler error on this line then email me :) + COMPILE_TIME_ASSERT(sizeof(bound_function_helper_T) <= mem_size); + clone(buf.get()); + } + + void clone (void* ptr) const + { + bound_function_helper_T* p = new(ptr) bound_function_helper_T(); + p->arg1 = this->arg1; + p->arg2 = this->arg2; + p->arg3 = this->arg3; + p->arg4 = this->arg4; + p->fp = this->fp; + p->mfp = this->mfp; + } + }; + + } + +// ---------------------------------------------------------------------------------------- + + class bound_function_pointer + { + typedef bfp1_helpers::bound_function_helper_T > bf_null_type; + + public: + + // These typedefs are here for backwards compatibility with previous versions of + // dlib. + typedef bound_function_pointer kernel_1a; + typedef bound_function_pointer kernel_1a_c; + + + bound_function_pointer ( + ) { bf_null_type().safe_clone(bf_memory); } + + bound_function_pointer ( + const bound_function_pointer& item + ) { item.bf()->clone(bf_memory.get()); } + + ~bound_function_pointer() + { destroy_bf_memory(); } + + bound_function_pointer& operator= ( + const bound_function_pointer& item + ) { bound_function_pointer(item).swap(*this); return *this; } + + void clear ( + ) { bound_function_pointer().swap(*this); } + + bool is_set ( + ) const + { + return bf()->is_set(); + } + + void swap ( + bound_function_pointer& item + ) + { + // make a temp copy of item + bound_function_pointer temp(item); + + // destory the stuff in item + item.destroy_bf_memory(); + // copy *this into item + bf()->clone(item.bf_memory.get()); + + // destory the stuff in this + destroy_bf_memory(); + // copy temp into *this + temp.bf()->clone(bf_memory.get()); + } + + void operator() ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_set() == true , + "\tvoid bound_function_pointer::operator()" + << "\n\tYou must call set() before you can use this function" + << "\n\tthis: " << this + ); + + bf()->call(); + } + + private: + struct dummy{ void nonnull() {}}; + typedef void (dummy::*safe_bool)(); + + public: + operator safe_bool () const { return is_set() ? &dummy::nonnull : 0; } + bool operator!() const { return !is_set(); } + + // ------------------------------------------- + // set function object overloads + // ------------------------------------------- + + template + void set ( + F& function_object + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1 + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2 + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + // set mfp overloads + // ------------------------------------------- + + template + void set ( + T& object, + void (T::*funct)() + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)()const + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + + template + void set ( + T& object, + void (T::*funct)(T1), + A1& arg1 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1)const, + A1& arg1 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ---------------- + + template + void set ( + T& object, + void (T::*funct)(T1, T2), + A1& arg1, + A2& arg2 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1, T2)const, + A1& arg1, + A2& arg2 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ---------------- + + template + void set ( + T& object, + void (T::*funct)(T1, T2, T3), + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1, T2, T3)const, + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ---------------- + + template + void set ( + T& object, + void (T::*funct)(T1, T2, T3, T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1, T2, T3, T4)const, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + // set fp overloads + // ------------------------------------------- + + void set ( + void (*funct)() + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1), + A1& arg1 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1, T2), + A1& arg1, + A2& arg2 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1, T2, T3), + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1, T2, T3, T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + + private: + + stack_based_memory_block bf_memory; + + void destroy_bf_memory ( + ) + { + // Honestly, this probably doesn't even do anything but I'm putting + // it here just for good measure. + bf()->~bound_function_helper_base_base(); + } + + bfp1_helpers::bound_function_helper_base_base* bf () + { return static_cast(bf_memory.get()); } + + const bfp1_helpers::bound_function_helper_base_base* bf () const + { return static_cast(bf_memory.get()); } + + }; + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + bound_function_pointer& a, + bound_function_pointer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ + diff --git a/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h b/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h new file mode 100644 index 000000000..b5356d6e0 --- /dev/null +++ b/ml/dlib/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h @@ -0,0 +1,456 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ +#ifdef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bound_function_pointer + { + /*! + INITIAL VALUE + is_set() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a function with all its arguments bound to + specific objects. For example: + + void test(int& var) { var = var+1; } + + bound_function_pointer funct; + + int a = 4; + funct.set(test,a); // bind the variable a to the first argument of the test() function + + // at this point a == 4 + funct(); + // after funct() is called a == 5 + !*/ + + public: + + bound_function_pointer ( + ); + /*! + ensures + - #*this is properly initialized + !*/ + + bound_function_pointer( + const bound_function_pointer& item + ); + /*! + ensures + - *this == item + !*/ + + ~bound_function_pointer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + bound_function_pointer& operator=( + const bound_function_pointer& item + ); + /*! + ensures + - *this == item + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + !*/ + + bool is_set ( + ) const; + /*! + ensures + - if (this->set() has been called) then + - returns true + - else + - returns false + !*/ + + operator some_undefined_pointer_type ( + ) const; + /*! + ensures + - if (is_set()) then + - returns a non 0 value + - else + - returns a 0 value + !*/ + + bool operator! ( + ) const; + /*! + ensures + - returns !is_set() + !*/ + + void operator () ( + ) const; + /*! + requires + - is_set() == true + ensures + - calls the bound function on the object(s) specified by the last + call to this->set() + throws + - any exception thrown by the function specified by + the previous call to this->set(). + If any of these exceptions are thrown then the call to this + function will have no effect on *this. + !*/ + + void swap ( + bound_function_pointer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + // ---------------------- + + template + void set ( + F& function_object + ); + /*! + requires + - function_object() is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object() + (This seems pointless but it is a useful base case) + !*/ + + template < typename T> + void set ( + T& object, + void (T::*funct)() + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)() + !*/ + + template < typename T> + void set ( + const T& object, + void (T::*funct)()const + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)() + !*/ + + void set ( + void (*funct)() + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct() + !*/ + + // ---------------------- + + template + void set ( + F& function_object, + A1& arg1 + ); + /*! + requires + - function_object(arg1) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1) + !*/ + + template < typename T, typename T1, typename A1 > + void set ( + T& object, + void (T::*funct)(T1), + A1& arg1 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1) + !*/ + + template < typename T, typename T1, typename A1 > + void set ( + const T& object, + void (T::*funct)(T1)const, + A1& arg1 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1) + !*/ + + template + void set ( + void (*funct)(T1), + A1& arg1 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1) + !*/ + + // ---------------------- + template + void set ( + F& function_object, + A1& arg1, + A2& arg2 + ); + /*! + requires + - function_object(arg1,arg2) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1,arg2) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2> + void set ( + T& object, + void (T::*funct)(T1,T2), + A1& arg1, + A2& arg2 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2> + void set ( + const T& object, + void (T::*funct)(T1,T2)const, + A1& arg1, + A2& arg2 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2) + !*/ + + template + void set ( + void (*funct)(T1,T2), + A1& arg1, + A2& arg2 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1,arg2) + !*/ + + // ---------------------- + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - function_object(arg1,arg2,arg3) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1,arg2,arg3) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3> + void set ( + T& object, + void (T::*funct)(T1,T2,T3), + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3> + void set ( + const T& object, + void (T::*funct)(T1,T2,T3)const, + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3) + !*/ + + template + void set ( + void (*funct)(T1,T2,T3), + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1,arg2,arg3) + !*/ + + // ---------------------- + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - function_object(arg1,arg2,arg3,arg4) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1,arg2,arg3,arg4) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3, + typename T4, typename A4> + void set ( + T& object, + void (T::*funct)(T1,T2,T3,T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3, + typename T4, typename A4> + void set ( + const T& object, + void (T::*funct)(T1,T2,T3,T4)const, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4) + !*/ + + template + void set ( + void (*funct)(T1,T2,T3,T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1,arg2,arg3,arg4) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + bound_function_pointer& a, + bound_function_pointer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/bridge.h b/ml/dlib/dlib/bridge.h new file mode 100644 index 000000000..4b633c405 --- /dev/null +++ b/ml/dlib/dlib/bridge.h @@ -0,0 +1,17 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + + +#ifndef DLIB_BRIdGE_ +#define DLIB_BRIdGE_ + + +#include "bridge/bridge.h" + +#endif // DLIB_BRIdGE_ + + diff --git a/ml/dlib/dlib/bridge/bridge.h b/ml/dlib/dlib/bridge/bridge.h new file mode 100644 index 000000000..da4e0bd7e --- /dev/null +++ b/ml/dlib/dlib/bridge/bridge.h @@ -0,0 +1,669 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BRIDGe_Hh_ +#define DLIB_BRIDGe_Hh_ + +#include +#include +#include + +#include "bridge_abstract.h" +#include "../pipe.h" +#include "../threads.h" +#include "../serialize.h" +#include "../sockets.h" +#include "../sockstreambuf.h" +#include "../logger.h" +#include "../algs.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct connect_to_ip_and_port + { + connect_to_ip_and_port ( + const std::string& ip_, + unsigned short port_ + ): ip(ip_), port(port_) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ip_address(ip) && port != 0, + "\t connect_to_ip_and_port()" + << "\n\t Invalid inputs were given to this function" + << "\n\t ip: " << ip + << "\n\t port: " << port + << "\n\t this: " << this + ); + } + + private: + friend class bridge; + const std::string ip; + const unsigned short port; + }; + + inline connect_to_ip_and_port connect_to ( + const network_address& addr + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(addr.port != 0, + "\t connect_to_ip_and_port()" + << "\n\t The TCP port to connect to can't be 0." + << "\n\t addr.port: " << addr.port + ); + + if (is_ip_address(addr.host_address)) + { + return connect_to_ip_and_port(addr.host_address, addr.port); + } + else + { + std::string ip; + if(hostname_to_ip(addr.host_address,ip)) + throw socket_error(ERESOLVE,"unable to resolve '" + addr.host_address + "' in connect_to()"); + + return connect_to_ip_and_port(ip, addr.port); + } + } + + struct listen_on_port + { + listen_on_port( + unsigned short port_ + ) : port(port_) + { + // make sure requires clause is not broken + DLIB_ASSERT( port != 0, + "\t listen_on_port()" + << "\n\t Invalid inputs were given to this function" + << "\n\t port: " << port + << "\n\t this: " << this + ); + } + + private: + friend class bridge; + const unsigned short port; + }; + + template + struct bridge_transmit_decoration + { + bridge_transmit_decoration ( + pipe_type& p_ + ) : p(p_) {} + + private: + friend class bridge; + pipe_type& p; + }; + + template + bridge_transmit_decoration transmit ( pipe_type& p) { return bridge_transmit_decoration(p); } + + template + struct bridge_receive_decoration + { + bridge_receive_decoration ( + pipe_type& p_ + ) : p(p_) {} + + private: + friend class bridge; + pipe_type& p; + }; + + template + bridge_receive_decoration receive ( pipe_type& p) { return bridge_receive_decoration(p); } + +// ---------------------------------------------------------------------------------------- + + struct bridge_status + { + bridge_status() : is_connected(false), foreign_port(0){} + + bool is_connected; + unsigned short foreign_port; + std::string foreign_ip; + }; + + inline void serialize ( const bridge_status& , std::ostream& ) + { + throw serialization_error("It is illegal to serialize bridge_status objects."); + } + + inline void deserialize ( bridge_status& , std::istream& ) + { + throw serialization_error("It is illegal to serialize bridge_status objects."); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl_brns + { + class impl_bridge_base + { + public: + + virtual ~impl_bridge_base() {} + + virtual bridge_status get_bridge_status ( + ) const = 0; + }; + + template < + typename transmit_pipe_type, + typename receive_pipe_type + > + class impl_bridge : public impl_bridge_base, private noncopyable, private multithreaded_object + { + /*! + CONVENTION + - if (list) then + - this object is supposed to be listening on the list object for incoming + connections when not connected. + - else + - this object is supposed to be attempting to connect to ip:port when + not connected. + + - get_bridge_status() == current_bs + !*/ + public: + + impl_bridge ( + unsigned short listen_port, + transmit_pipe_type* transmit_pipe_, + receive_pipe_type* receive_pipe_ + ) : + s(m), + receive_thread_active(false), + transmit_thread_active(false), + port(0), + transmit_pipe(transmit_pipe_), + receive_pipe(receive_pipe_), + dlog("dlib.bridge"), + keepalive_code(0), + message_code(1) + { + int status = create_listener(list, listen_port); + if (status == PORTINUSE) + { + std::ostringstream sout; + sout << "Error, the port " << listen_port << " is already in use."; + throw socket_error(EPORT_IN_USE, sout.str()); + } + else if (status == OTHER_ERROR) + { + throw socket_error("Unable to create listening socket for an unknown reason."); + } + + register_thread(*this, &impl_bridge::transmit_thread); + register_thread(*this, &impl_bridge::receive_thread); + register_thread(*this, &impl_bridge::connect_thread); + + start(); + } + + impl_bridge ( + const std::string ip_, + unsigned short port_, + transmit_pipe_type* transmit_pipe_, + receive_pipe_type* receive_pipe_ + ) : + s(m), + receive_thread_active(false), + transmit_thread_active(false), + port(port_), + ip(ip_), + transmit_pipe(transmit_pipe_), + receive_pipe(receive_pipe_), + dlog("dlib.bridge"), + keepalive_code(0), + message_code(1) + { + register_thread(*this, &impl_bridge::transmit_thread); + register_thread(*this, &impl_bridge::receive_thread); + register_thread(*this, &impl_bridge::connect_thread); + + start(); + } + + ~impl_bridge() + { + // tell the threads to terminate + stop(); + + // save current pipe enabled status so we can restore it to however + // it was before this destructor ran. + bool transmit_enabled = true; + bool receive_enabled = true; + + // make any calls blocked on a pipe return immediately. + if (transmit_pipe) + { + transmit_enabled = transmit_pipe->is_dequeue_enabled(); + transmit_pipe->disable_dequeue(); + } + if (receive_pipe) + { + receive_enabled = receive_pipe->is_enqueue_enabled(); + receive_pipe->disable_enqueue(); + } + + { + auto_mutex lock(m); + s.broadcast(); + // Shutdown the connection if we have one. This will cause + // all blocked I/O calls to return an error. + if (con) + con->shutdown(); + } + + // wait for all the threads to terminate. + wait(); + + if (transmit_pipe && transmit_enabled) + transmit_pipe->enable_dequeue(); + if (receive_pipe && receive_enabled) + receive_pipe->enable_enqueue(); + } + + bridge_status get_bridge_status ( + ) const + { + auto_mutex lock(current_bs_mutex); + return current_bs; + } + + private: + + + template + typename enable_if >::type enqueue_bridge_status ( + pipe_type* p, + const bridge_status& status + ) + { + if (p) + { + typename pipe_type::type temp(status); + p->enqueue(temp); + } + } + + template + typename disable_if >::type enqueue_bridge_status ( + pipe_type* , + const bridge_status& + ) + { + } + + void connect_thread ( + ) + { + while (!should_stop()) + { + auto_mutex lock(m); + int status = OTHER_ERROR; + if (list) + { + do + { + status = list->accept(con, 1000); + } while (status == TIMEOUT && !should_stop()); + } + else + { + status = create_connection(con, port, ip); + } + + if (should_stop()) + break; + + if (status != 0) + { + // The last connection attempt failed. So pause for a little bit before making another attempt. + s.wait_or_timeout(2000); + continue; + } + + dlog << LINFO << "Established new connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << "."; + + bridge_status temp_bs; + { auto_mutex lock(current_bs_mutex); + current_bs.is_connected = true; + current_bs.foreign_port = con->get_foreign_port(); + current_bs.foreign_ip = con->get_foreign_ip(); + temp_bs = current_bs; + } + enqueue_bridge_status(receive_pipe, temp_bs); + + + receive_thread_active = true; + transmit_thread_active = true; + + s.broadcast(); + + // Wait for the transmit and receive threads to end before we continue. + // This way we don't invalidate the con pointer while it is in use. + while (receive_thread_active || transmit_thread_active) + s.wait(); + + + dlog << LINFO << "Closed connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << "."; + { auto_mutex lock(current_bs_mutex); + current_bs.is_connected = false; + current_bs.foreign_port = con->get_foreign_port(); + current_bs.foreign_ip = con->get_foreign_ip(); + temp_bs = current_bs; + } + enqueue_bridge_status(receive_pipe, temp_bs); + } + + } + + + void receive_thread ( + ) + { + while (true) + { + // wait until we have a connection + { auto_mutex lock(m); + while (!receive_thread_active && !should_stop()) + { + s.wait(); + } + + if (should_stop()) + break; + } + + + + try + { + if (receive_pipe) + { + sockstreambuf buf(con); + std::istream in(&buf); + typename receive_pipe_type::type item; + // This isn't necessary but doing it avoids a warning about + // item being uninitialized sometimes. + assign_zero_if_built_in_scalar_type(item); + + while (in.peek() != EOF) + { + unsigned char code; + in.read((char*)&code, sizeof(code)); + if (code == message_code) + { + deserialize(item, in); + receive_pipe->enqueue(item); + } + } + } + else + { + // Since we don't have a receive pipe to put messages into we will + // just read the bytes from the connection and ignore them. + char buf[1000]; + while (con->read(buf, sizeof(buf)) > 0) ; + } + } + catch (std::bad_alloc& ) + { + dlog << LERROR << "std::bad_alloc thrown while deserializing message from " + << con->get_foreign_ip() << ":" << con->get_foreign_port(); + } + catch (dlib::serialization_error& e) + { + dlog << LERROR << "dlib::serialization_error thrown while deserializing message from " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + catch (std::exception& e) + { + dlog << LERROR << "std::exception thrown while deserializing message from " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + + + + + con->shutdown(); + auto_mutex lock(m); + receive_thread_active = false; + s.broadcast(); + } + + auto_mutex lock(m); + receive_thread_active = false; + s.broadcast(); + } + + void transmit_thread ( + ) + { + while (true) + { + // wait until we have a connection + { auto_mutex lock(m); + while (!transmit_thread_active && !should_stop()) + { + s.wait(); + } + + if (should_stop()) + break; + } + + + + try + { + sockstreambuf buf(con); + std::ostream out(&buf); + typename transmit_pipe_type::type item; + // This isn't necessary but doing it avoids a warning about + // item being uninitialized sometimes. + assign_zero_if_built_in_scalar_type(item); + + + while (out) + { + bool dequeue_timed_out = false; + if (transmit_pipe ) + { + if (transmit_pipe->dequeue_or_timeout(item,1000)) + { + out.write((char*)&message_code, sizeof(message_code)); + serialize(item, out); + if (transmit_pipe->size() == 0) + out.flush(); + + continue; + } + + dequeue_timed_out = (transmit_pipe->is_enabled() && transmit_pipe->is_dequeue_enabled()); + } + + // Pause for about a second. Note that we use a wait_or_timeout() call rather + // than sleep() here because we want to wake up immediately if this object is + // being destructed rather than hang for a second. + if (!dequeue_timed_out) + { + auto_mutex lock(m); + if (should_stop()) + break; + + s.wait_or_timeout(1000); + } + // Just send the keepalive byte periodically so we can + // tell if the connection is alive. + out.write((char*)&keepalive_code, sizeof(keepalive_code)); + out.flush(); + } + } + catch (std::bad_alloc& ) + { + dlog << LERROR << "std::bad_alloc thrown while serializing message to " + << con->get_foreign_ip() << ":" << con->get_foreign_port(); + } + catch (dlib::serialization_error& e) + { + dlog << LERROR << "dlib::serialization_error thrown while serializing message to " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + catch (std::exception& e) + { + dlog << LERROR << "std::exception thrown while serializing message to " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + + + + + con->shutdown(); + auto_mutex lock(m); + transmit_thread_active = false; + s.broadcast(); + } + + auto_mutex lock(m); + transmit_thread_active = false; + s.broadcast(); + } + + mutex m; + signaler s; + bool receive_thread_active; + bool transmit_thread_active; + std::unique_ptr con; + std::unique_ptr list; + const unsigned short port; + const std::string ip; + transmit_pipe_type* const transmit_pipe; + receive_pipe_type* const receive_pipe; + logger dlog; + const unsigned char keepalive_code; + const unsigned char message_code; + + mutex current_bs_mutex; + bridge_status current_bs; + }; + } + + +// ---------------------------------------------------------------------------------------- + + class bridge : noncopyable + { + public: + + bridge () {} + + template < typename T, typename U, typename V > + bridge ( + T network_parameters, + U pipe1, + V pipe2 + ) { reconfigure(network_parameters,pipe1,pipe2); } + + template < typename T, typename U> + bridge ( + T network_parameters, + U pipe + ) { reconfigure(network_parameters,pipe); } + + + void clear ( + ) + { + pimpl.reset(); + } + + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename T > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, 0)); } + + template < typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, 0, &receive_pipe.p)); } + + + + + template < typename T, typename R > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename T, typename R > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename R > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, 0, &receive_pipe.p)); } + + template < typename T > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, 0)); } + + + bridge_status get_bridge_status ( + ) const + { + if (pimpl) + return pimpl->get_bridge_status(); + else + return bridge_status(); + } + + private: + + std::unique_ptr pimpl; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BRIDGe_Hh_ + diff --git a/ml/dlib/dlib/bridge/bridge_abstract.h b/ml/dlib/dlib/bridge/bridge_abstract.h new file mode 100644 index 000000000..76ed21153 --- /dev/null +++ b/ml/dlib/dlib/bridge/bridge_abstract.h @@ -0,0 +1,347 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BRIDGe_ABSTRACT_ +#ifdef DLIB_BRIDGe_ABSTRACT_ + +#include +#include "../pipe/pipe_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct connect_to_ip_and_port + { + connect_to_ip_and_port ( + const std::string& ip, + unsigned short port + ); + /*! + requires + - is_ip_address(ip) == true + - port != 0 + ensures + - this object will represent a request to make a TCP connection + to the given IP address and port number. + !*/ + }; + + connect_to_ip_and_port connect_to ( + const network_address& addr + ); + /*! + requires + - addr.port != 0 + ensures + - converts the given network_address object into a connect_to_ip_and_port + object. + !*/ + + struct listen_on_port + { + listen_on_port( + unsigned short port + ); + /*! + requires + - port != 0 + ensures + - this object will represent a request to listen on the given + port number for incoming TCP connections. + !*/ + }; + + template < + typename pipe_type + > + bridge_transmit_decoration transmit ( + pipe_type& p + ); + /*! + requires + - pipe_type is some kind of dlib::pipe object + - the objects in the pipe must be serializable + ensures + - Adds a type decoration to the given pipe, marking it as a transmit pipe, and + then returns it. + !*/ + + template < + typename pipe_type + > + bridge_receive_decoration receive ( + pipe_type& p + ); + /*! + requires + - pipe_type is some kind of dlib::pipe object + - the objects in the pipe must be serializable + ensures + - Adds a type decoration to the given pipe, marking it as a receive pipe, and + then returns it. + !*/ + +// ---------------------------------------------------------------------------------------- + + struct bridge_status + { + /*! + WHAT THIS OBJECT REPRESENTS + This simple struct represents the state of a bridge object. A + bridge is either connected or not. If it is connected then it + is connected to a foreign host with an IP address and port number + as indicated by this object. + !*/ + + bridge_status( + ); + /*! + ensures + - #is_connected == false + - #foreign_port == 0 + - #foreign_ip == "" + !*/ + + bool is_connected; + unsigned short foreign_port; + std::string foreign_ip; + }; + +// ---------------------------------------------------------------------------------------- + + class bridge : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for bridging a dlib::pipe object between + two network connected applications. + + + Note also that this object contains a dlib::logger object + which will log various events taking place inside a bridge. + If you want to see these log messages then enable the logger + named "dlib.bridge". + + + BRIDGE PROTOCOL DETAILS + The bridge object creates a single TCP connection between + two applications. Whenever it sends an object from a pipe + over a TCP connection it sends a byte with the value 1 followed + immediately by the serialized copy of the object from the pipe. + The serialization is performed by calling the global serialize() + function. + + Additionally, a bridge object will periodically send bytes with + a value of 0 to ensure the TCP connection remains alive. These + are just read and ignored. + !*/ + + public: + + bridge ( + ); + /*! + ensures + - this object is properly initialized + - #get_bridge_status().is_connected == false + !*/ + + template + bridge ( + T network_parameters, + U pipe1, + V pipe2 + ); + /*! + requires + - T is of type connect_to_ip_and_port or listen_on_port + - U and V are of type bridge_transmit_decoration or bridge_receive_decoration, + however, U and V must be of different types (i.e. one is a receive type and + another a transmit type). + ensures + - this object is properly initialized + - performs: reconfigure(network_parameters, pipe1, pipe2) + (i.e. using this constructor is identical to using the default constructor + and then calling reconfigure()) + !*/ + + template + bridge ( + T network_parameters, + U pipe + ); + /*! + requires + - T is of type connect_to_ip_and_port or listen_on_port + - U is of type bridge_transmit_decoration or bridge_receive_decoration. + ensures + - this object is properly initialized + - performs: reconfigure(network_parameters, pipe) + (i.e. using this constructor is identical to using the default constructor + and then calling reconfigure()) + !*/ + + ~bridge ( + ); + /*! + ensures + - blocks until all resources associated with this object have been destroyed. + !*/ + + void clear ( + ); + /*! + ensures + - returns this object to its default constructed state. That is, it will + be inactive, neither maintaining a connection nor attempting to acquire one. + - Any active connections or listening sockets will be closed. + !*/ + + bridge_status get_bridge_status ( + ) const; + /*! + ensures + - returns the current status of this bridge object. In particular, returns + an object BS such that: + - BS.is_connected == true if and only if the bridge has an active TCP + connection to another computer. + - if (BS.is_connected) then + - BS.foreign_ip == the IP address of the remote host we are connected to. + - BS.foreign_port == the port number on the remote host we are connected to. + - else if (the bridge has previously been connected to a remote host but hasn't been + reconfigured or cleared since) then + - BS.foreign_ip == the IP address of the remote host we were connected to. + - BS.foreign_port == the port number on the remote host we were connected to. + - else + - BS.foreign_ip == "" + - BS.foreign_port == 0 + !*/ + + + + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This object will begin listening on the port specified by network_parameters + for incoming TCP connections. Any previous bridge state is cleared out. + - Onces a connection is established we will: + - Stop accepting new connections. + - Begin dequeuing objects from the transmit pipe and serializing them over + the TCP connection. + - Begin deserializing objects from the TCP connection and enqueueing them + onto the receive pipe. + - if (the current TCP connection is lost) then + - This object goes back to listening for a new connection. + - if (the receive pipe can contain bridge_status objects) then + - Whenever the bridge's status changes the updated bridge_status will be + enqueued onto the receive pipe unless the change was a TCP disconnect + resulting from a user calling reconfigure(), clear(), or destructing this + bridge. The status contents are defined by get_bridge_status(). + throws + - socket_error + This exception is thrown if we are unable to open the listening socket. + !*/ + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - performs reconfigure(network_parameters, transmit_pipe, receive_pipe) + !*/ + template < typename T > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - This function is identical to the above two reconfigure() functions + except that there is no receive pipe. + !*/ + template < typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This function is identical to the above three reconfigure() functions + except that there is no transmit pipe. + !*/ + + + + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This object will begin making TCP connection attempts to the IP address and port + specified by network_parameters. Any previous bridge state is cleared out. + - Onces a connection is established we will: + - Stop attempting new connections. + - Begin dequeuing objects from the transmit pipe and serializing them over + the TCP connection. + - Begin deserializing objects from the TCP connection and enqueueing them + onto the receive pipe. + - if (the current TCP connection is lost) then + - This object goes back to attempting to make a TCP connection with the + IP address and port specified by network_parameters. + - if (the receive pipe can contain bridge_status objects) then + - Whenever the bridge's status changes the updated bridge_status will be + enqueued onto the receive pipe unless the change was a TCP disconnect + resulting from a user calling reconfigure(), clear(), or destructing this + bridge. The status contents are defined by get_bridge_status(). + !*/ + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - performs reconfigure(network_parameters, transmit_pipe, receive_pipe) + !*/ + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - This function is identical to the above two reconfigure() functions + except that there is no receive pipe. + !*/ + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This function is identical to the above three reconfigure() functions + except that there is no transmit pipe. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BRIDGe_ABSTRACT_ + + diff --git a/ml/dlib/dlib/bsp.h b/ml/dlib/dlib/bsp.h new file mode 100644 index 000000000..899b6a405 --- /dev/null +++ b/ml/dlib/dlib/bsp.h @@ -0,0 +1,12 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BSPh_ +#define DLIB_BSPh_ + + +#include "bsp/bsp.h" + +#endif // DLIB_BSPh_ + + + diff --git a/ml/dlib/dlib/bsp/bsp.cpp b/ml/dlib/dlib/bsp/bsp.cpp new file mode 100644 index 000000000..32e23519e --- /dev/null +++ b/ml/dlib/dlib/bsp/bsp.cpp @@ -0,0 +1,496 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BSP_CPph_ +#define DLIB_BSP_CPph_ + +#include "bsp.h" +#include +#include + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + namespace impl1 + { + + void connect_all ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id + ) + { + cons.clear(); + for (unsigned long i = 0; i < hosts.size(); ++i) + { + std::unique_ptr con(new bsp_con(hosts[i])); + dlib::serialize(node_id, con->stream); // tell the other end our node_id + unsigned long id = i+1; + cons.add(id, con); + } + } + + void connect_all_hostinfo ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id, + std::string& error_string + ) + { + cons.clear(); + for (unsigned long i = 0; i < hosts.size(); ++i) + { + try + { + std::unique_ptr con(new bsp_con(hosts[i].addr)); + dlib::serialize(node_id, con->stream); // tell the other end our node_id + con->stream.flush(); + unsigned long id = hosts[i].node_id; + cons.add(id, con); + } + catch (std::exception&) + { + std::ostringstream sout; + sout << "Could not connect to " << hosts[i].addr; + error_string = sout.str(); + break; + } + } + } + + + void send_out_connection_orders ( + map_id_to_con& cons, + const std::vector& hosts + ) + { + // tell everyone their node ids + cons.reset(); + while (cons.move_next()) + { + dlib::serialize(cons.element().key(), cons.element().value()->stream); + } + + // now tell them who to connect to + std::vector targets; + for (unsigned long i = 0; i < hosts.size(); ++i) + { + hostinfo info(hosts[i], i+1); + + dlib::serialize(targets, cons[info.node_id]->stream); + targets.push_back(info); + + // let the other host know how many incoming connections to expect + const unsigned long num = hosts.size()-targets.size(); + dlib::serialize(num, cons[info.node_id]->stream); + cons[info.node_id]->stream.flush(); + } + } + + // ------------------------------------------------------------------------------------ + + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl2 + { + // These control bytes are sent before each message between nodes. Note that many + // of these are only sent between the control node (node 0) and the other nodes. + // This is because the controller node is responsible for handling the + // synchronization that needs to happen when all nodes block on calls to + // receive_data() + // at the same time. + + // denotes a normal content message. + const static char MESSAGE_HEADER = 0; + + // sent to the controller node when someone receives a message via receive_data(). + const static char GOT_MESSAGE = 1; + + // sent to the controller node when someone sends a message via send(). + const static char SENT_MESSAGE = 2; + + // sent to the controller node when someone enters a call to receive_data() + const static char IN_WAITING_STATE = 3; + + // broadcast when a node terminates itself. + const static char NODE_TERMINATE = 5; + + // broadcast by the controller node when it determines that all nodes are blocked + // on calls to receive_data() and there aren't any messages in flight. This is also + // what makes us go to the next epoch. + const static char SEE_ALL_IN_WAITING_STATE = 6; + + // This isn't ever transmitted between nodes. It is used internally to indicate + // that an error occurred. + const static char READ_ERROR = 7; + + // ------------------------------------------------------------------------------------ + + void read_thread ( + impl1::bsp_con* con, + unsigned long node_id, + unsigned long sender_id, + impl1::thread_safe_message_queue& msg_buffer + ) + { + try + { + while(true) + { + impl1::msg_data msg; + deserialize(msg.msg_type, con->stream); + msg.sender_id = sender_id; + + if (msg.msg_type == MESSAGE_HEADER) + { + msg.data.reset(new std::vector); + deserialize(msg.epoch, con->stream); + deserialize(*msg.data, con->stream); + } + + msg_buffer.push_and_consume(msg); + + if (msg.msg_type == NODE_TERMINATE) + break; + } + } + catch (std::exception& e) + { + impl1::msg_data msg; + msg.data.reset(new std::vector); + vectorstream sout(*msg.data); + sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; + sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; + sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; + sout << " Receiving processing node id: " << node_id << std::endl; + sout << " Error message in the exception: " << e.what() << std::endl; + + msg.sender_id = sender_id; + msg.msg_type = READ_ERROR; + + msg_buffer.push_and_consume(msg); + } + catch (...) + { + impl1::msg_data msg; + msg.data.reset(new std::vector); + vectorstream sout(*msg.data); + sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; + sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; + sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; + sout << " Receiving processing node id: " << node_id << std::endl; + + msg.sender_id = sender_id; + msg.msg_type = READ_ERROR; + + msg_buffer.push_and_consume(msg); + } + } + + // ------------------------------------------------------------------------------------ + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION OF bsp_context OBJECT MEMBERS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + close_all_connections_gracefully( + ) + { + if (node_id() != 0) + { + _cons.reset(); + while (_cons.move_next()) + { + // tell the other end that we are intentionally dropping the connection + serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); + _cons.element().value()->stream.flush(); + } + } + + impl1::msg_data msg; + // now wait for all the other nodes to terminate + while (num_terminated_nodes < _cons.size() ) + { + if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0) + { + num_waiting_nodes = 0; + broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); + ++current_epoch; + } + + if (!msg_buffer.pop(msg)) + throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); + + if (msg.msg_type == impl2::NODE_TERMINATE) + { + ++num_terminated_nodes; + _cons[msg.sender_id]->terminated = true; + } + else if (msg.msg_type == impl2::READ_ERROR) + { + throw dlib::socket_error(msg.data_to_string()); + } + else if (msg.msg_type == impl2::MESSAGE_HEADER) + { + throw dlib::socket_error("A BSP node received a message after it has terminated."); + } + else if (msg.msg_type == impl2::GOT_MESSAGE) + { + --num_waiting_nodes; + --outstanding_messages; + } + else if (msg.msg_type == impl2::SENT_MESSAGE) + { + ++outstanding_messages; + } + else if (msg.msg_type == impl2::IN_WAITING_STATE) + { + ++num_waiting_nodes; + } + } + + if (node_id() == 0) + { + _cons.reset(); + while (_cons.move_next()) + { + // tell the other end that we are intentionally dropping the connection + serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); + _cons.element().value()->stream.flush(); + } + + if (outstanding_messages != 0) + { + std::ostringstream sout; + sout << "A BSP job was allowed to terminate before all sent messages have been received.\n"; + sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n"; + sout << "have a corresponding call to receive()."; + throw dlib::socket_error(sout.str()); + } + } + } + +// ---------------------------------------------------------------------------------------- + + bsp_context:: + ~bsp_context() + { + _cons.reset(); + while (_cons.move_next()) + { + _cons.element().value()->con->shutdown(); + } + + msg_buffer.disable(); + + // this will wait for all the threads to terminate + threads.clear(); + } + +// ---------------------------------------------------------------------------------------- + + bsp_context:: + bsp_context( + unsigned long node_id_, + impl1::map_id_to_con& cons_ + ) : + outstanding_messages(0), + num_waiting_nodes(0), + num_terminated_nodes(0), + current_epoch(1), + _cons(cons_), + _node_id(node_id_) + { + // spawn a bunch of read threads, one for each connection + _cons.reset(); + while (_cons.move_next()) + { + std::unique_ptr ptr(new thread_function(&impl2::read_thread, + _cons.element().value().get(), + _node_id, + _cons.element().key(), + ref(msg_buffer))); + threads.push_back(ptr); + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bsp_context:: + receive_data ( + std::shared_ptr >& item, + unsigned long& sending_node_id + ) + { + notify_control_node(impl2::IN_WAITING_STATE); + + while (true) + { + // If there aren't any nodes left to give us messages then return right now. + // We need to check the msg_buffer size to make sure there aren't any + // unprocessed message there. Recall that this can happen because status + // messages always jump to the front of the message buffer. So we might have + // learned about the node terminations before processing their messages for us. + if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0) + { + return false; + } + + // if all running nodes are currently blocking forever on receive_data() + if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size()) + { + num_waiting_nodes = 0; + broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); + + // Note that the reason we have this epoch counter is so we can tell if a + // sent message is from before or after one of these "all nodes waiting" + // synchronization events. If we didn't have the epoch count we would have + // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE + // message before others and then sends out a message to another node + // before that node got the SEE_ALL_IN_WAITING_STATE message. Then that + // node would think the normal message came before SEE_ALL_IN_WAITING_STATE + // which would be bad. + ++current_epoch; + return false; + } + + impl1::msg_data data; + if (!msg_buffer.pop(data, current_epoch)) + throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); + + + switch(data.msg_type) + { + case impl2::MESSAGE_HEADER: { + item = data.data; + sending_node_id = data.sender_id; + notify_control_node(impl2::GOT_MESSAGE); + return true; + } break; + + case impl2::IN_WAITING_STATE: { + ++num_waiting_nodes; + } break; + + case impl2::GOT_MESSAGE: { + --outstanding_messages; + --num_waiting_nodes; + } break; + + case impl2::SENT_MESSAGE: { + ++outstanding_messages; + } break; + + case impl2::NODE_TERMINATE: { + ++num_terminated_nodes; + _cons[data.sender_id]->terminated = true; + } break; + + case impl2::SEE_ALL_IN_WAITING_STATE: { + ++current_epoch; + return false; + } break; + + case impl2::READ_ERROR: { + throw dlib::socket_error(data.data_to_string()); + } break; + + default: { + throw dlib::socket_error("Unknown message received by dlib::bsp_context"); + } break; + } // end switch() + } // end while (true) + } + +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + notify_control_node ( + char val + ) + { + if (node_id() == 0) + { + using namespace impl2; + switch(val) + { + case SENT_MESSAGE: { + ++outstanding_messages; + } break; + + case GOT_MESSAGE: { + --outstanding_messages; + } break; + + case IN_WAITING_STATE: { + // nothing to do in this case + } break; + + default: + DLIB_CASSERT(false,"This should never happen"); + } + } + else + { + serialize(val, _cons[0]->stream); + _cons[0]->stream.flush(); + } + } + +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + broadcast_byte ( + char val + ) + { + for (unsigned long i = 0; i < number_of_nodes(); ++i) + { + // don't send to yourself or to terminated nodes + if (i == node_id() || _cons[i]->terminated) + continue; + + serialize(val, _cons[i]->stream); + _cons[i]->stream.flush(); + } + } + +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + send_data( + const std::vector& item, + unsigned long target_node_id + ) + { + using namespace impl2; + if (_cons[target_node_id]->terminated) + throw socket_error("Attempt to send a message to a node that has terminated."); + + serialize(MESSAGE_HEADER, _cons[target_node_id]->stream); + serialize(current_epoch, _cons[target_node_id]->stream); + serialize(item, _cons[target_node_id]->stream); + _cons[target_node_id]->stream.flush(); + + notify_control_node(SENT_MESSAGE); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BSP_CPph_ + diff --git a/ml/dlib/dlib/bsp/bsp.h b/ml/dlib/dlib/bsp/bsp.h new file mode 100644 index 000000000..f0732c153 --- /dev/null +++ b/ml/dlib/dlib/bsp/bsp.h @@ -0,0 +1,1043 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BsP_Hh_ +#define DLIB_BsP_Hh_ + +#include "bsp_abstract.h" + +#include +#include +#include + +#include "../sockets.h" +#include "../array.h" +#include "../sockstreambuf.h" +#include "../string.h" +#include "../serialize.h" +#include "../map.h" +#include "../ref.h" +#include "../vectorstream.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl1 + { + inline void null_notify( + unsigned short + ) {} + + struct bsp_con + { + bsp_con( + const network_address& dest + ) : + con(connect(dest)), + buf(con), + stream(&buf), + terminated(false) + { + con->disable_nagle(); + } + + bsp_con( + std::unique_ptr& conptr + ) : + buf(conptr), + stream(&buf), + terminated(false) + { + // make sure we own the connection + conptr.swap(con); + + con->disable_nagle(); + } + + std::unique_ptr con; + sockstreambuf buf; + std::iostream stream; + bool terminated; + }; + + typedef dlib::map >::kernel_1a_c map_id_to_con; + + void connect_all ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id + ); + /*! + ensures + - creates connections to all the given hosts and stores them into cons + !*/ + + void send_out_connection_orders ( + map_id_to_con& cons, + const std::vector& hosts + ); + + // ------------------------------------------------------------------------------------ + + struct hostinfo + { + hostinfo() {} + hostinfo ( + const network_address& addr_, + unsigned long node_id_ + ) : + addr(addr_), + node_id(node_id_) + { + } + + network_address addr; + unsigned long node_id; + }; + + inline void serialize ( + const hostinfo& item, + std::ostream& out + ) + { + dlib::serialize(item.addr, out); + dlib::serialize(item.node_id, out); + } + + inline void deserialize ( + hostinfo& item, + std::istream& in + ) + { + dlib::deserialize(item.addr, in); + dlib::deserialize(item.node_id, in); + } + + // ------------------------------------------------------------------------------------ + + void connect_all_hostinfo ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id, + std::string& error_string + ); + + // ------------------------------------------------------------------------------------ + + template < + typename port_notify_function_type + > + void listen_and_connect_all( + unsigned long& node_id, + map_id_to_con& cons, + unsigned short port, + port_notify_function_type port_notify_function + ) + { + cons.clear(); + std::unique_ptr list; + const int status = create_listener(list, port); + if (status == PORTINUSE) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) + + ". The port is already in use"); + } + else if (status != 0) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) ); + } + + port_notify_function(list->get_listening_port()); + + std::unique_ptr con; + if (list->accept(con)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + std::unique_ptr temp(new bsp_con(con)); + + unsigned long remote_node_id; + dlib::deserialize(remote_node_id, temp->stream); + dlib::deserialize(node_id, temp->stream); + std::vector targets; + dlib::deserialize(targets, temp->stream); + unsigned long num_incoming_connections; + dlib::deserialize(num_incoming_connections, temp->stream); + + cons.add(remote_node_id,temp); + + // make a thread that will connect to all the targets + map_id_to_con cons2; + std::string error_string; + thread_function thread(connect_all_hostinfo, dlib::ref(cons2), dlib::ref(targets), node_id, dlib::ref(error_string)); + if (error_string.size() != 0) + throw socket_error(error_string); + + // accept any incoming connections + for (unsigned long i = 0; i < num_incoming_connections; ++i) + { + // If it takes more than 10 seconds for the other nodes to connect to us + // then something has gone horribly wrong and it almost certainly will + // never connect at all. So just give up if that happens. + const unsigned long timeout_milliseconds = 10000; + if (list->accept(con, timeout_milliseconds)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + temp.reset(new bsp_con(con)); + + dlib::deserialize(remote_node_id, temp->stream); + cons.add(remote_node_id,temp); + } + + + // put all the connections created by the thread into cons + thread.wait(); + while (cons2.size() > 0) + { + unsigned long id; + std::unique_ptr temp; + cons2.remove_any(id,temp); + cons.add(id,temp); + } + } + + // ------------------------------------------------------------------------------------ + + struct msg_data + { + std::shared_ptr > data; + unsigned long sender_id; + char msg_type; + dlib::uint64 epoch; + + msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {} + + std::string data_to_string() const + { + if (data && data->size() != 0) + return std::string(&(*data)[0], data->size()); + else + return ""; + } + }; + + // ------------------------------------------------------------------------------------ + + class thread_safe_message_queue : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple message queue for msg_data objects. Note that it + has the special property that, while messages will generally leave + the queue in the order they are inserted, any message with a smaller + epoch value will always be popped out first. But for all messages + with equal epoch values the queue functions as a normal FIFO queue. + !*/ + private: + struct msg_wrap + { + msg_wrap( + const msg_data& data_, + const dlib::uint64& sequence_number_ + ) : data(data_), sequence_number(sequence_number_) {} + + msg_wrap() : sequence_number(0){} + + msg_data data; + dlib::uint64 sequence_number; + + // Make it so that when msg_wrap objects are in a std::priority_queue, + // messages with a smaller epoch number always come first. Then, within an + // epoch, messages are ordered by their sequence number (so smaller first + // there as well). + bool operator<(const msg_wrap& item) const + { + if (data.epoch < item.data.epoch) + { + return false; + } + else if (data.epoch > item.data.epoch) + { + return true; + } + else + { + if (sequence_number < item.sequence_number) + return false; + else + return true; + } + } + }; + + public: + thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {} + + ~thread_safe_message_queue() + { + disable(); + } + + void disable() + { + auto_mutex lock(class_mutex); + disabled = true; + sig.broadcast(); + } + + unsigned long size() const + { + auto_mutex lock(class_mutex); + return data.size(); + } + + void push_and_consume( msg_data& item) + { + auto_mutex lock(class_mutex); + data.push(msg_wrap(item, next_seq_num++)); + // do this here so that we don't have to worry about different threads touching the shared_ptr. + item.data.reset(); + sig.signal(); + } + + bool pop ( + msg_data& item + ) + /*! + ensures + - if (this function returns true) then + - #item == the next thing from the queue + - else + - this object is disabled + !*/ + { + auto_mutex lock(class_mutex); + while (data.size() == 0 && !disabled) + sig.wait(); + + if (disabled) + return false; + + item = data.top().data; + data.pop(); + + return true; + } + + bool pop ( + msg_data& item, + const dlib::uint64& max_epoch + ) + /*! + ensures + - if (this function returns true) then + - #item == the next thing from the queue that has an epoch <= max_epoch + - else + - this object is disabled + !*/ + { + auto_mutex lock(class_mutex); + while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled) + sig.wait(); + + if (disabled) + return false; + + item = data.top().data; + data.pop(); + + return true; + } + + private: + std::priority_queue data; + dlib::mutex class_mutex; + dlib::signaler sig; + bool disabled; + dlib::uint64 next_seq_num; + }; + + + } + +// ---------------------------------------------------------------------------------------- + + class bsp_context : noncopyable + { + + public: + + template + void send( + const T& item, + unsigned long target_node_id + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(target_node_id < number_of_nodes() && + target_node_id != node_id(), + "\t void bsp_context::send()" + << "\n\t Invalid arguments were given to this function." + << "\n\t target_node_id: " << target_node_id + << "\n\t node_id(): " << node_id() + << "\n\t number_of_nodes(): " << number_of_nodes() + << "\n\t this: " << this + ); + + std::vector buf; + vectorstream sout(buf); + serialize(item, sout); + send_data(buf, target_node_id); + } + + template + void broadcast ( + const T& item + ) + { + std::vector buf; + vectorstream sout(buf); + serialize(item, sout); + for (unsigned long i = 0; i < number_of_nodes(); ++i) + { + // Don't send to yourself. + if (i == node_id()) + continue; + + send_data(buf, i); + } + } + + unsigned long node_id ( + ) const { return _node_id; } + + unsigned long number_of_nodes ( + ) const { return _cons.size()+1; } + + void receive ( + ) + { + unsigned long id; + std::shared_ptr > temp; + if (receive_data(temp,id)) + throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message."); + } + + template + void receive ( + T& item + ) + { + if(!try_receive(item)) + throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); + } + + template + bool try_receive ( + T& item + ) + { + unsigned long sending_node_id; + return try_receive(item, sending_node_id); + } + + template + void receive ( + T& item, + unsigned long& sending_node_id + ) + { + if(!try_receive(item, sending_node_id)) + throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); + } + + template + bool try_receive ( + T& item, + unsigned long& sending_node_id + ) + { + std::shared_ptr > temp; + if (receive_data(temp, sending_node_id)) + { + vectorstream sin(*temp); + deserialize(item, sin); + if (sin.peek() != EOF) + throw serialization_error("deserialize() did not consume all bytes produced by serialize(). " + "This probably means you are calling a receive method with a different type " + "of object than the one which was sent."); + return true; + } + else + { + return false; + } + } + + ~bsp_context(); + + private: + + bsp_context(); + + bsp_context( + unsigned long node_id_, + impl1::map_id_to_con& cons_ + ); + + void close_all_connections_gracefully(); + /*! + ensures + - closes all the connections to other nodes and lets them know that + we are terminating normally rather than as the result of some kind + of error. + !*/ + + bool receive_data ( + std::shared_ptr >& item, + unsigned long& sending_node_id + ); + + + void notify_control_node ( + char val + ); + + void broadcast_byte ( + char val + ); + + void send_data( + const std::vector& item, + unsigned long target_node_id + ); + /*! + requires + - target_node_id < number_of_nodes() + - target_node_id != node_id() + ensures + - sends a copy of item to the node with the given id. + !*/ + + + + + unsigned long outstanding_messages; + unsigned long num_waiting_nodes; + unsigned long num_terminated_nodes; + dlib::uint64 current_epoch; + + impl1::thread_safe_message_queue msg_buffer; + + impl1::map_id_to_con& _cons; + const unsigned long _node_id; + array > threads; + + // ----------------------------------- + + template < + typename funct_type + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct + ); + + template < + typename funct_type, + typename ARG1 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + + // ----------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + + // ----------------------------------- + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3,arg4); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3, arg4); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3,arg4); + obj.close_all_connections_gracefully(); + } +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "bsp.cpp" +#endif + +#endif // DLIB_BsP_Hh_ + diff --git a/ml/dlib/dlib/bsp/bsp_abstract.h b/ml/dlib/dlib/bsp/bsp_abstract.h new file mode 100644 index 000000000..b87f3a0c3 --- /dev/null +++ b/ml/dlib/dlib/bsp/bsp_abstract.h @@ -0,0 +1,912 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BsP_ABSTRACT_Hh_ +#ifdef DLIB_BsP_ABSTRACT_Hh_ + +#include "../noncopyable.h" +#include "../sockets/sockets_extensions_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bsp_context : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool used to implement algorithms using the Bulk Synchronous + Parallel (BSP) computing model. A BSP algorithm is composed of a number of + processing nodes, each executing in parallel. The general flow of + execution in each processing node is the following: + 1. Do work locally on some data. + 2. Send some messages to other nodes. + 3. Receive messages from other nodes. + 4. Go to step 1 or terminate if complete. + + To do this, each processing node needs an API used to send and receive + messages. This API is implemented by the bsp_connect object which provides + these services to a BSP node. + + Note that BSP processing nodes are spawned using the bsp_connect() and + bsp_listen() routines defined at the bottom of this file. For example, to + start a BSP algorithm consisting of N processing nodes, you would make N-1 + calls to bsp_listen() and one call to bsp_connect(). The call to + bsp_connect() then initiates the computation on all nodes. + + Finally, note that there is no explicit barrier synchronization function + you call at the end of step 3. Instead, you can simply call a method such + as try_receive() until it returns false. That is, the bsp_context's + receive methods incorporate a barrier synchronization that happens once all + the BSP nodes are blocked on receive calls and there are no more messages + in flight. + + + THREAD SAFETY + This object is not thread-safe. In particular, you should only ever have + one thread that works with an instance of this object. This means that, + for example, you should not spawn sub-threads from within a BSP processing + node and have them invoke methods on this object. Instead, you should only + invoke this object's methods from within the BSP processing node's main + thread (i.e. the thread that executes the user supplied function funct()). + !*/ + + public: + + template + void send( + const T& item, + unsigned long target_node_id + ); + /*! + requires + - item is serializable + - target_node_id < number_of_nodes() + - target_node_id != node_id() + ensures + - sends a copy of item to the node with the given id. + throws + - dlib::socket_error: + This exception is thrown if there is an error which prevents us from + delivering the message to the given node. One way this might happen is + if the target node has already terminated its execution or has lost + network connectivity. + !*/ + + template + void broadcast ( + const T& item + ); + /*! + ensures + - item is serializable + - sends a copy of item to all other processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents us from + delivering a message to one of the other nodes. This might happen, for + example, if one of the nodes has terminated its execution or has lost + network connectivity. + !*/ + + unsigned long node_id ( + ) const; + /*! + ensures + - Returns the id of the current processing node. That is, + returns a number N such that: + - N < number_of_nodes() + - N == the node id of the processing node that called node_id(). This + is a number that uniquely identifies the processing node. + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of processing nodes participating in the BSP + computation. + !*/ + + template + bool try_receive ( + T& item + ); + /*! + requires + - item is serializable + ensures + - if (this function returns true) then + - #item == the next message which was sent to the calling processing + node. + - else + - The following must have been true for this function to return false: + - All other nodes were blocked on calls to receive(), + try_receive(), or have terminated. + - There were not any messages in flight between any nodes. + - That is, if all the nodes had continued to block on receive + methods then they all would have blocked forever. Therefore, + this function only returns false once there are no more messages + to process by any node and there is no possibility of more being + generated until control is returned to the callers of receive + methods. + - When one BSP node's receive method returns because of the above + conditions then all of them will also return. That is, it is NOT the + case that just a subset of BSP nodes unblock. Moreover, they all + unblock at the same time. + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + try_receive(). + !*/ + + template + void receive ( + T& item + ); + /*! + requires + - item is serializable + ensures + - #item == the next message which was sent to the calling processing + node. + - This function is just a wrapper around try_receive() that throws an + exception if a message is not received (i.e. if try_receive() returns + false). + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes or if there was not a message + to receive. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + receive(). + !*/ + + template + bool try_receive ( + T& item, + unsigned long& sending_node_id + ); + /*! + requires + - item is serializable + ensures + - if (this function returns true) then + - #item == the next message which was sent to the calling processing + node. + - #sending_node_id == the node id of the node that sent this message. + - #sending_node_id < number_of_nodes() + - else + - The following must have been true for this function to return false: + - All other nodes were blocked on calls to receive(), + try_receive(), or have terminated. + - There were not any messages in flight between any nodes. + - That is, if all the nodes had continued to block on receive + methods then they all would have blocked forever. Therefore, + this function only returns false once there are no more messages + to process by any node and there is no possibility of more being + generated until control is returned to the callers of receive + methods. + - When one BSP node's receive method returns because of the above + conditions then all of them will also return. That is, it is NOT the + case that just a subset of BSP nodes unblock. Moreover, they all + unblock at the same time. + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + try_receive(). + !*/ + + template + void receive ( + T& item, + unsigned long& sending_node_id + ); + /*! + requires + - item is serializable + ensures + - #item == the next message which was sent to the calling processing node. + - #sending_node_id == the node id of the node that sent this message. + - #sending_node_id < number_of_nodes() + - This function is just a wrapper around try_receive() that throws an + exception if a message is not received (i.e. if try_receive() returns + false). + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes or if there was not a message + to receive. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + receive(). + !*/ + + void receive ( + ); + /*! + ensures + - Waits for the following to all be true: + - All other nodes were blocked on calls to receive(), try_receive(), or + have terminated. + - There are not any messages in flight between any nodes. + - That is, if all the nodes had continued to block on receive methods + then they all would have blocked forever. Therefore, this function + only returns once there are no more messages to process by any node + and there is no possibility of more being generated until control is + returned to the callers of receive methods. + - When one BSP node's receive method returns because of the above + conditions then all of them will also return. That is, it is NOT the + case that just a subset of BSP nodes unblock. Moreover, they all unblock + at the same time. + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes or if a message is received + before this function would otherwise return. + + !*/ + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3,arg4), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT) will be executed and it will + then be able to participate in the BSP computation as one of the processing + nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1) will be executed and it will + then be able to participate in the BSP computation as one of the processing + nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2) will be executed and + it will then be able to participate in the BSP computation as one of the + processing nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT) will be executed and it will + then be able to participate in the BSP computation as one of the processing + nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1) will be executed and it + will then be able to participate in the BSP computation as one of the + processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2) will be executed and + it will then be able to participate in the BSP computation as one of the + processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BsP_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/byte_orderer.h b/ml/dlib/dlib/byte_orderer.h new file mode 100644 index 000000000..bc8f6108d --- /dev/null +++ b/ml/dlib/dlib/byte_orderer.h @@ -0,0 +1,10 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BYTE_ORDEREr_ +#define DLIB_BYTE_ORDEREr_ + + +#include "byte_orderer/byte_orderer_kernel_1.h" + +#endif // DLIB_BYTE_ORDEREr_ + diff --git a/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_1.h b/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_1.h new file mode 100644 index 000000000..9f8e8342f --- /dev/null +++ b/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_1.h @@ -0,0 +1,176 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BYTE_ORDEREr_KERNEL_1_ +#define DLIB_BYTE_ORDEREr_KERNEL_1_ + +#include "byte_orderer_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + class byte_orderer + { + /*! + INITIAL VALUE + - if (this machine is little endian) then + - little_endian == true + - else + - little_endian == false + + CONVENTION + - host_is_big_endian() == !little_endian + - host_is_little_endian() == little_endian + + - if (this machine is little endian) then + - little_endian == true + - else + - little_endian == false + + + !*/ + + + public: + + // this is here for backwards compatibility with older versions of dlib. + typedef byte_orderer kernel_1a; + + byte_orderer ( + ) + { + // This will probably never be false but if it is then it means chars are not 8bits + // on this system. Which is a problem for this object. + COMPILE_TIME_ASSERT(sizeof(short) >= 2); + + unsigned long temp = 1; + unsigned char* ptr = reinterpret_cast(&temp); + if (*ptr == 1) + little_endian = true; + else + little_endian = false; + } + + virtual ~byte_orderer ( + ){} + + bool host_is_big_endian ( + ) const { return !little_endian; } + + bool host_is_little_endian ( + ) const { return little_endian; } + + template < + typename T + > + inline void host_to_network ( + T& item + ) const + { if (little_endian) flip(item); } + + template < + typename T + > + inline void network_to_host ( + T& item + ) const { if (little_endian) flip(item); } + + template < + typename T + > + void host_to_big ( + T& item + ) const { if (little_endian) flip(item); } + + template < + typename T + > + void big_to_host ( + T& item + ) const { if (little_endian) flip(item); } + + template < + typename T + > + void host_to_little ( + T& item + ) const { if (!little_endian) flip(item); } + + template < + typename T + > + void little_to_host ( + T& item + ) const { if (!little_endian) flip(item); } + + + private: + + template < + typename T, + size_t size + > + inline void flip ( + T (&array)[size] + ) const + /*! + ensures + - flips the bytes in every element of this array + !*/ + { + for (size_t i = 0; i < size; ++i) + { + flip(array[i]); + } + } + + template < + typename T + > + inline void flip ( + T& item + ) const + /*! + ensures + - reverses the byte ordering in item + !*/ + { + DLIB_ASSERT_HAS_STANDARD_LAYOUT(T); + + T value; + + // If you are getting this as an error then you are probably using + // this object wrong. If you think you aren't then send me (Davis) an + // email and I'll either set you straight or change/remove this check so + // your stuff works :) + COMPILE_TIME_ASSERT(sizeof(T) <= sizeof(long double)); + + // If you are getting a compile error on this line then it means T is + // a pointer type. It doesn't make any sense to byte swap pointers + // since they have no meaning outside the context of their own process. + // So you probably just forgot to dereference that pointer before passing + // it to this function :) + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + + const size_t size = sizeof(T); + unsigned char* const ptr = reinterpret_cast(&item); + unsigned char* const ptr_temp = reinterpret_cast(&value); + for (size_t i = 0; i < size; ++i) + ptr_temp[size-i-1] = ptr[i]; + + item = value; + } + + bool little_endian; + }; + + // make flip not do anything at all for chars + template <> inline void byte_orderer::flip ( char& ) const {} + template <> inline void byte_orderer::flip ( unsigned char& ) const {} + template <> inline void byte_orderer::flip ( signed char& ) const {} +} + +#endif // DLIB_BYTE_ORDEREr_KERNEL_1_ + diff --git a/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_abstract.h b/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_abstract.h new file mode 100644 index 000000000..f7ea15103 --- /dev/null +++ b/ml/dlib/dlib/byte_orderer/byte_orderer_kernel_abstract.h @@ -0,0 +1,149 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BYTE_ORDEREr_ABSTRACT_ +#ifdef DLIB_BYTE_ORDEREr_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + class byte_orderer + { + /*! + INITIAL VALUE + This object has no state. + + WHAT THIS OBJECT REPRESENTS + This object simply provides a mechanism to convert data from a + host machine's own byte ordering to big or little endian and to + also do the reverse. + + It also provides a pair of functions to convert to/from network byte + order where network byte order is big endian byte order. This pair of + functions does the exact same thing as the host_to_big() and big_to_host() + functions and is provided simply so that client code can use the most + self documenting name appropriate. + + Also note that this object is capable of correctly flipping the contents + of arrays when the arrays are declared on the stack. e.g. You can + say things like: + int array[10]; + bo.host_to_network(array); + !*/ + + public: + + byte_orderer ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~byte_orderer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + bool host_is_big_endian ( + ) const; + /*! + ensures + - if (the host computer is a big endian machine) then + - returns true + - else + - returns false + !*/ + + bool host_is_little_endian ( + ) const; + /*! + ensures + - if (the host computer is a little endian machine) then + - returns true + - else + - returns false + !*/ + + template < + typename T + > + void host_to_network ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from host byte order + to network byte order. + !*/ + + template < + typename T + > + void network_to_host ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from network byte order + to host byte order. + !*/ + + template < + typename T + > + void host_to_big ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from host byte order + to big endian byte order. + !*/ + + template < + typename T + > + void big_to_host ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from big endian byte order + to host byte order. + !*/ + + template < + typename T + > + void host_to_little ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from host byte order + to little endian byte order. + !*/ + + template < + typename T + > + void little_to_host ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from little endian byte order + to host byte order. + !*/ + + }; +} + +#endif // DLIB_BYTE_ORDEREr_ABSTRACT_ + diff --git a/ml/dlib/dlib/cassert b/ml/dlib/dlib/cassert new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/cassert @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/clustering.h b/ml/dlib/dlib/clustering.h new file mode 100644 index 000000000..3cbd6cfd4 --- /dev/null +++ b/ml/dlib/dlib/clustering.h @@ -0,0 +1,13 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CLuSTERING_ +#define DLIB_CLuSTERING_ + +#include "clustering/modularity_clustering.h" +#include "clustering/chinese_whispers.h" +#include "clustering/spectral_cluster.h" +#include "clustering/bottom_up_cluster.h" +#include "svm/kkmeans.h" + +#endif // DLIB_CLuSTERING_ + diff --git a/ml/dlib/dlib/clustering/bottom_up_cluster.h b/ml/dlib/dlib/clustering/bottom_up_cluster.h new file mode 100644 index 000000000..f80b65108 --- /dev/null +++ b/ml/dlib/dlib/clustering/bottom_up_cluster.h @@ -0,0 +1,253 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_ +#define DLIB_BOTTOM_uP_CLUSTER_Hh_ + +#include +#include + +#include "bottom_up_cluster_abstract.h" +#include "../algs.h" +#include "../matrix.h" +#include "../disjoint_subsets.h" +#include "../graph_utils.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace buc_impl + { + inline void merge_sets ( + matrix& dists, + unsigned long dest, + unsigned long src + ) + { + for (long r = 0; r < dists.nr(); ++r) + dists(dest,r) = dists(r,dest) = std::max(dists(r,dest), dists(r,src)); + } + + struct compare_dist + { + bool operator() ( + const sample_pair& a, + const sample_pair& b + ) const + { + return a.distance() > b.distance(); + } + }; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + unsigned long bottom_up_cluster ( + const matrix_exp& dists_, + std::vector& labels, + unsigned long min_num_clusters, + double max_dist = std::numeric_limits::infinity() + ) + { + matrix dists = matrix_cast(dists_); + // make sure requires clause is not broken + DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0, + "\t unsigned long bottom_up_cluster()" + << "\n\t Invalid inputs were given to this function." + << "\n\t dists.nr(): " << dists.nr() + << "\n\t dists.nc(): " << dists.nc() + << "\n\t min_num_clusters: " << min_num_clusters + ); + + using namespace buc_impl; + + labels.resize(dists.nr()); + disjoint_subsets sets; + sets.set_size(dists.nr()); + if (labels.size() == 0) + return 0; + + // push all the edges in the graph into a priority queue so the best edges to merge + // come first. + std::priority_queue, compare_dist> que; + for (long r = 0; r < dists.nr(); ++r) + for (long c = r+1; c < dists.nc(); ++c) + que.push(sample_pair(r,c,dists(r,c))); + + // Now start merging nodes. + for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter) + { + // find the next best thing to merge. + double best_dist = que.top().distance(); + unsigned long a = sets.find_set(que.top().index1()); + unsigned long b = sets.find_set(que.top().index2()); + que.pop(); + // we have been merging and modifying the distances, so make sure this distance + // is still valid and these guys haven't been merged already. + while(a == b || best_dist < dists(a,b)) + { + // Haven't merged it yet, so put it back in with updated distance for + // reconsideration later. + if (a != b) + que.push(sample_pair(a, b, dists(a, b))); + + best_dist = que.top().distance(); + a = sets.find_set(que.top().index1()); + b = sets.find_set(que.top().index2()); + que.pop(); + } + + + // now merge these sets if the best distance is small enough + if (best_dist > max_dist) + break; + unsigned long news = sets.merge_sets(a,b); + unsigned long olds = (news==a)?b:a; + merge_sets(dists, news, olds); + } + + // figure out which cluster each element is in. Also make sure the labels are + // contiguous. + std::map relabel; + for (unsigned long r = 0; r < labels.size(); ++r) + { + unsigned long l = sets.find_set(r); + // relabel to make contiguous + if (relabel.count(l) == 0) + { + unsigned long next = relabel.size(); + relabel[l] = next; + } + labels[r] = relabel[l]; + } + + + return relabel.size(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct snl_range + { + snl_range() = default; + snl_range(double val) : lower(val), upper(val) {} + snl_range(double l, double u) : lower(l), upper(u) { DLIB_ASSERT(lower <= upper)} + + double lower = 0; + double upper = 0; + + double width() const { return upper-lower; } + bool operator<(const snl_range& item) const { return lower < item.lower; } + }; + + inline snl_range merge(const snl_range& a, const snl_range& b) + { + return snl_range(std::min(a.lower, b.lower), std::max(a.upper, b.upper)); + } + + inline double distance (const snl_range& a, const snl_range& b) + { + return std::max(a.lower,b.lower) - std::min(a.upper,b.upper); + } + + inline std::ostream& operator<< (std::ostream& out, const snl_range& item ) + { + out << "["< segment_number_line ( + const std::vector& x, + const double max_range_width + ) + { + DLIB_CASSERT(max_range_width >= 0); + + // create initial ranges, one for each value in x. So initially, all the ranges have + // width of 0. + std::vector ranges; + for (auto v : x) + ranges.push_back(v); + std::sort(ranges.begin(), ranges.end()); + + std::vector greedy_final_ranges; + if (ranges.size() == 0) + return greedy_final_ranges; + // We will try two different clustering strategies. One that does a simple greedy left + // to right sweep and another that does a bottom up agglomerative clustering. This + // first loop runs the greedy left to right sweep. Then at the end of this routine we + // will return the results that produced the tightest clustering. + greedy_final_ranges.push_back(ranges[0]); + for (size_t i = 1; i < ranges.size(); ++i) + { + auto m = merge(greedy_final_ranges.back(), ranges[i]); + if (m.width() <= max_range_width) + greedy_final_ranges.back() = m; + else + greedy_final_ranges.push_back(ranges[i]); + } + + + // Here we do the bottom up clustering. So compute the edges connecting our ranges. + // We will simply say there are edges between ranges if and only if they are + // immediately adjacent on the number line. + std::vector edges; + for (size_t i = 1; i < ranges.size(); ++i) + edges.push_back(sample_pair(i-1,i, distance(ranges[i-1],ranges[i]))); + std::sort(edges.begin(), edges.end(), order_by_distance); + + disjoint_subsets sets; + sets.set_size(ranges.size()); + + // Now start merging nodes. + for (auto edge : edges) + { + // find the next best thing to merge. + unsigned long a = sets.find_set(edge.index1()); + unsigned long b = sets.find_set(edge.index2()); + + // merge it if it doesn't result in an interval that's too big. + auto m = merge(ranges[a], ranges[b]); + if (m.width() <= max_range_width) + { + unsigned long news = sets.merge_sets(a,b); + ranges[news] = m; + } + } + + // Now create a list of the final ranges. We will do this by keeping track of which + // range we already added to final_ranges. + std::vector final_ranges; + std::vector already_output(ranges.size(), false); + for (unsigned long i = 0; i < sets.size(); ++i) + { + auto s = sets.find_set(i); + if (!already_output[s]) + { + final_ranges.push_back(ranges[s]); + already_output[s] = true; + } + } + + // only use the greedy clusters if they found a clustering with fewer clusters. + // Otherwise, the bottom up clustering probably produced a more sensible clustering. + if (final_ranges.size() <= greedy_final_ranges.size()) + return final_ranges; + else + return greedy_final_ranges; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOTTOM_uP_CLUSTER_Hh_ + diff --git a/ml/dlib/dlib/clustering/bottom_up_cluster_abstract.h b/ml/dlib/dlib/clustering/bottom_up_cluster_abstract.h new file mode 100644 index 000000000..72d362c12 --- /dev/null +++ b/ml/dlib/dlib/clustering/bottom_up_cluster_abstract.h @@ -0,0 +1,136 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ +#ifdef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + unsigned long bottom_up_cluster ( + const matrix_exp& dists, + std::vector& labels, + unsigned long min_num_clusters, + double max_dist = std::numeric_limits::infinity() + ); + /*! + requires + - dists.nr() == dists.nc() + - min_num_clusters > 0 + - dists == trans(dists) + (l.e. dists should be symmetric) + ensures + - Runs a bottom up agglomerative clustering algorithm. + - Interprets dists as a matrix that gives the distances between dists.nr() + items. In particular, we take dists(i,j) to be the distance between the ith + and jth element of some set. This function clusters the elements of this set + into at least min_num_clusters (or dists.nr() if there aren't enough + elements). Additionally, within each cluster, the maximum pairwise distance + between any two cluster elements is <= max_dist. + - returns the number of clusters found. + - #labels.size() == dists.nr() + - for all valid i: + - #labels[i] == the cluster ID of the node with index i (i.e. the node + corresponding to the distances dists(i,*)). + - 0 <= #labels[i] < the number of clusters found + (i.e. cluster IDs are assigned contiguously and start at 0) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct snl_range + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an interval on the real number line. It is used + to store the outputs of the segment_number_line() routine defined below. + !*/ + + snl_range( + ); + /*! + ensures + - #lower == 0 + - #upper == 0 + !*/ + + snl_range( + double val + ); + /*! + ensures + - #lower == val + - #upper == val + !*/ + + snl_range( + double l, + double u + ); + /*! + requires + - l <= u + ensures + - #lower == l + - #upper == u + !*/ + + double lower; + double upper; + + double width( + ) const { return upper-lower; } + /*! + ensures + - returns the width of this interval on the number line. + !*/ + + bool operator<(const snl_range& item) const { return lower < item.lower; } + /*! + ensures + - provides a total ordering of snl_range objects assuming they are + non-overlapping. + !*/ + }; + + std::ostream& operator<< (std::ostream& out, const snl_range& item ); + /*! + ensures + - prints item to out in the form [lower,upper]. + !*/ + +// ---------------------------------------------------------------------------------------- + + std::vector segment_number_line ( + const std::vector& x, + const double max_range_width + ); + /*! + requires + - max_range_width >= 0 + ensures + - Finds a clustering of the values in x and returns the ranges that define the + clustering. This routine uses a combination of bottom up clustering and a + simple greedy scan to try and find the most compact set of ranges that + contain all the values in x. + - This routine has approximately linear runtime. + - Every value in x will be contained inside one of the returned snl_range + objects; + - All returned snl_range object's will have a width() <= max_range_width and + will also be non-overlapping. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/clustering/chinese_whispers.h b/ml/dlib/dlib/clustering/chinese_whispers.h new file mode 100644 index 000000000..332cce1a0 --- /dev/null +++ b/ml/dlib/dlib/clustering/chinese_whispers.h @@ -0,0 +1,135 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CHINESE_WHISPErS_Hh_ +#define DLIB_CHINESE_WHISPErS_Hh_ + +#include "chinese_whispers_abstract.h" +#include +#include "../rand.h" +#include "../graph_utils/edge_list_graphs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ordered_by_index(edges), + "\t unsigned long chinese_whispers()" + << "\n\t Invalid inputs were given to this function" + ); + + labels.clear(); + if (edges.size() == 0) + return 0; + + std::vector > neighbors; + find_neighbor_ranges(edges, neighbors); + + // Initialize the labels, each node gets a different label. + labels.resize(neighbors.size()); + for (unsigned long i = 0; i < labels.size(); ++i) + labels[i] = i; + + + for (unsigned long iter = 0; iter < neighbors.size()*num_iterations; ++iter) + { + // Pick a random node. + const unsigned long idx = rnd.get_random_64bit_number()%neighbors.size(); + + // Count how many times each label happens amongst our neighbors. + std::map labels_to_counts; + const unsigned long end = neighbors[idx].second; + for (unsigned long i = neighbors[idx].first; i != end; ++i) + { + labels_to_counts[labels[edges[i].index2()]] += edges[i].distance(); + } + + // find the most common label + std::map::iterator i; + double best_score = -std::numeric_limits::infinity(); + unsigned long best_label = labels[idx]; + for (i = labels_to_counts.begin(); i != labels_to_counts.end(); ++i) + { + if (i->second > best_score) + { + best_score = i->second; + best_label = i->first; + } + } + + labels[idx] = best_label; + } + + + // Remap the labels into a contiguous range. First we find the + // mapping. + std::map label_remap; + for (unsigned long i = 0; i < labels.size(); ++i) + { + const unsigned long next_id = label_remap.size(); + if (label_remap.count(labels[i]) == 0) + label_remap[labels[i]] = next_id; + } + // now apply the mapping to all the labels. + for (unsigned long i = 0; i < labels.size(); ++i) + { + labels[i] = label_remap[labels[i]]; + } + + return label_remap.size(); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ) + { + std::vector oedges; + convert_unordered_to_ordered(edges, oedges); + std::sort(oedges.begin(), oedges.end(), &order_by_index); + + return chinese_whispers(oedges, labels, num_iterations, rnd); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ) + { + dlib::rand rnd; + return chinese_whispers(edges, labels, num_iterations, rnd); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ) + { + dlib::rand rnd; + return chinese_whispers(edges, labels, num_iterations, rnd); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CHINESE_WHISPErS_Hh_ + diff --git a/ml/dlib/dlib/clustering/chinese_whispers_abstract.h b/ml/dlib/dlib/clustering/chinese_whispers_abstract.h new file mode 100644 index 000000000..7a184c6f9 --- /dev/null +++ b/ml/dlib/dlib/clustering/chinese_whispers_abstract.h @@ -0,0 +1,97 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ +#ifdef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ + +#include +#include "../rand.h" +#include "../graph_utils/ordered_sample_pair_abstract.h" +#include "../graph_utils/sample_pair_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ); + /*! + requires + - is_ordered_by_index(edges) == true + ensures + - This function implements the graph clustering algorithm described in the + paper: Chinese Whispers - an Efficient Graph Clustering Algorithm and its + Application to Natural Language Processing Problems by Chris Biemann. + - Interprets edges as a directed graph. That is, it contains the edges on the + said graph and the ordered_sample_pair::distance() values define the edge + weights (larger values indicating a stronger edge connection between the + nodes). If an edge has a distance() value of infinity then it is considered + a "must link" edge. + - returns the number of clusters found. + - #labels.size() == max_index_plus_one(edges) + - for all valid i: + - #labels[i] == the cluster ID of the node with index i in the graph. + - 0 <= #labels[i] < the number of clusters found + (i.e. cluster IDs are assigned contiguously and start at 0) + - Duplicate edges are interpreted as if there had been just one edge with a + distance value equal to the sum of all the duplicate edge's distance values. + - The algorithm performs exactly num_iterations passes over the graph before + terminating. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ); + /*! + ensures + - This function is identical to the above chinese_whispers() routine except + that it operates on a vector of sample_pair objects instead of + ordered_sample_pairs. Therefore, this is simply a convenience routine. In + particular, it is implemented by transforming the given edges into + ordered_sample_pairs and then calling the chinese_whispers() routine defined + above. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ); + /*! + requires + - is_ordered_by_index(edges) == true + ensures + - performs: return chinese_whispers(edges, labels, num_iterations, rnd) + where rnd is a default initialized dlib::rand object. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ); + /*! + ensures + - performs: return chinese_whispers(edges, labels, num_iterations, rnd) + where rnd is a default initialized dlib::rand object. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/clustering/modularity_clustering.h b/ml/dlib/dlib/clustering/modularity_clustering.h new file mode 100644 index 000000000..8b8a0b0a5 --- /dev/null +++ b/ml/dlib/dlib/clustering/modularity_clustering.h @@ -0,0 +1,515 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MODULARITY_ClUSTERING__H__ +#define DLIB_MODULARITY_ClUSTERING__H__ + +#include "modularity_clustering_abstract.h" +#include "../sparse_vector.h" +#include "../graph_utils/edge_list_graphs.h" +#include "../matrix.h" +#include "../rand.h" + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------- + + namespace impl + { + inline double newman_cluster_split ( + dlib::rand& rnd, + const std::vector& edges, + const matrix& node_degrees, // k from the Newman paper + const matrix& Bdiag, // diag(B) from the Newman paper + const double& edge_sum, // m from the Newman paper + matrix& labels, + const double eps, + const unsigned long max_iterations + ) + /*! + requires + - node_degrees.size() == max_index_plus_one(edges) + - Bdiag.size() == max_index_plus_one(edges) + - edges must be sorted according to order_by_index() + ensures + - This routine splits a graph into two subgraphs using the Newman + clustering method. + - returns the modularity obtained when the graph is split according + to the contents of #labels. + - #labels.size() == node_degrees.size() + - for all valid i: #labels(i) == -1 or +1 + - if (this function returns 0) then + - all the labels are equal, i.e. the graph is not split. + !*/ + { + // Scale epsilon so that it is relative to the expected value of an element of a + // unit vector of length node_degrees.size(). + const double power_iter_eps = eps * std::sqrt(1.0/node_degrees.size()); + + // Make a random unit vector and put in labels. + labels.set_size(node_degrees.size()); + for (long i = 0; i < labels.size(); ++i) + labels(i) = rnd.get_random_gaussian(); + labels /= length(labels); + + matrix Bv, Bv_unit; + + // Do the power iteration for a while. + double eig = -1; + double offset = 0; + while (eig < 0) + { + + // any number larger than power_iter_eps + double iteration_change = power_iter_eps*2+1; + for (unsigned long i = 0; i < max_iterations && iteration_change > power_iter_eps; ++i) + { + sparse_matrix_vector_multiply(edges, labels, Bv); + Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees; + + if (offset != 0) + { + Bv -= offset*labels; + } + + + const double len = length(Bv); + if (len != 0) + { + Bv_unit = Bv/len; + iteration_change = max(abs(labels-Bv_unit)); + labels.swap(Bv_unit); + } + else + { + // Had a bad time, pick another random vector and try it with the + // power iteration. + for (long i = 0; i < labels.size(); ++i) + labels(i) = rnd.get_random_gaussian(); + } + } + + eig = dot(Bv,labels); + // we will repeat this loop if the largest eigenvalue is negative + offset = eig; + } + + + for (long i = 0; i < labels.size(); ++i) + { + if (labels(i) > 0) + labels(i) = 1; + else + labels(i) = -1; + } + + + // compute B*labels, store result in Bv. + sparse_matrix_vector_multiply(edges, labels, Bv); + Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees; + + // Do some label refinement. In this step we swap labels if it + // improves the modularity score. + bool flipped_label = true; + while(flipped_label) + { + flipped_label = false; + unsigned long idx = 0; + for (long i = 0; i < labels.size(); ++i) + { + const double val = -2*labels(i); + const double increase = 4*Bdiag(i) + 2*val*Bv(i); + + // if there is an increase in modularity for swapping this label + if (increase > 0) + { + labels(i) *= -1; + while (idx < edges.size() && edges[idx].index1() == (unsigned long)i) + { + const long j = edges[idx].index2(); + Bv(j) += val*edges[idx].distance(); + ++idx; + } + + Bv -= (val*node_degrees(i)/(2*edge_sum))*node_degrees; + + flipped_label = true; + } + else + { + while (idx < edges.size() && edges[idx].index1() == (unsigned long)i) + { + ++idx; + } + } + } + } + + + const double modularity = dot(Bv, labels)/(4*edge_sum); + + return modularity; + } + + // ------------------------------------------------------------------------------------- + + inline unsigned long newman_cluster_helper ( + dlib::rand& rnd, + const std::vector& edges, + const matrix& node_degrees, // k from the Newman paper + const matrix& Bdiag, // diag(B) from the Newman paper + const double& edge_sum, // m from the Newman paper + std::vector& labels, + double modularity_threshold, + const double eps, + const unsigned long max_iterations + ) + /*! + ensures + - returns the number of clusters the data was split into + !*/ + { + matrix l; + const double modularity = newman_cluster_split(rnd,edges,node_degrees,Bdiag,edge_sum,l,eps,max_iterations); + + + // We need to collapse the node index values down to contiguous values. So + // we use the following two vectors to contain the mappings from input index + // values to their corresponding index values in each split. + std::vector left_idx_map(node_degrees.size()); + std::vector right_idx_map(node_degrees.size()); + + // figure out how many nodes went into each side of the split. + unsigned long num_left_split = 0; + unsigned long num_right_split = 0; + for (long i = 0; i < l.size(); ++i) + { + if (l(i) > 0) + { + left_idx_map[i] = num_left_split; + ++num_left_split; + } + else + { + right_idx_map[i] = num_right_split; + ++num_right_split; + } + } + + // do a recursive split if it will improve the modularity. + if (modularity > modularity_threshold && num_left_split > 0 && num_right_split > 0) + { + + // split the node_degrees and Bdiag matrices into left and right split parts + matrix left_node_degrees(num_left_split); + matrix right_node_degrees(num_right_split); + matrix left_Bdiag(num_left_split); + matrix right_Bdiag(num_right_split); + for (long i = 0; i < l.size(); ++i) + { + if (l(i) > 0) + { + left_node_degrees(left_idx_map[i]) = node_degrees(i); + left_Bdiag(left_idx_map[i]) = Bdiag(i); + } + else + { + right_node_degrees(right_idx_map[i]) = node_degrees(i); + right_Bdiag(right_idx_map[i]) = Bdiag(i); + } + } + + + // put the edges from one side of the split into split_edges + std::vector split_edges; + modularity_threshold = 0; + for (unsigned long k = 0; k < edges.size(); ++k) + { + const unsigned long i = edges[k].index1(); + const unsigned long j = edges[k].index2(); + const double d = edges[k].distance(); + if (l(i) > 0 && l(j) > 0) + { + split_edges.push_back(ordered_sample_pair(left_idx_map[i], left_idx_map[j], d)); + modularity_threshold += d; + } + } + modularity_threshold -= sum(left_node_degrees*sum(left_node_degrees))/(2*edge_sum); + modularity_threshold /= 4*edge_sum; + + unsigned long num_left_clusters; + std::vector left_labels; + num_left_clusters = newman_cluster_helper(rnd,split_edges,left_node_degrees,left_Bdiag, + edge_sum,left_labels,modularity_threshold, + eps, max_iterations); + + // now load the other side into split_edges and cluster it as well + split_edges.clear(); + modularity_threshold = 0; + for (unsigned long k = 0; k < edges.size(); ++k) + { + const unsigned long i = edges[k].index1(); + const unsigned long j = edges[k].index2(); + const double d = edges[k].distance(); + if (l(i) < 0 && l(j) < 0) + { + split_edges.push_back(ordered_sample_pair(right_idx_map[i], right_idx_map[j], d)); + modularity_threshold += d; + } + } + modularity_threshold -= sum(right_node_degrees*sum(right_node_degrees))/(2*edge_sum); + modularity_threshold /= 4*edge_sum; + + unsigned long num_right_clusters; + std::vector right_labels; + num_right_clusters = newman_cluster_helper(rnd,split_edges,right_node_degrees,right_Bdiag, + edge_sum,right_labels,modularity_threshold, + eps, max_iterations); + + // Now merge the labels from the two splits. + labels.resize(node_degrees.size()); + for (unsigned long i = 0; i < labels.size(); ++i) + { + // if this node was in the left split + if (l(i) > 0) + { + labels[i] = left_labels[left_idx_map[i]]; + } + else // if this node was in the right split + { + labels[i] = right_labels[right_idx_map[i]] + num_left_clusters; + } + } + + + return num_left_clusters + num_right_clusters; + } + else + { + labels.assign(node_degrees.size(),0); + return 1; + } + + } + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ordered_by_index(edges), + "\t unsigned long newman_cluster()" + << "\n\t Invalid inputs were given to this function" + ); + + labels.clear(); + if (edges.size() == 0) + return 0; + + const unsigned long num_nodes = max_index_plus_one(edges); + + // compute the node_degrees vector, edge_sum value, and diag(B). + matrix node_degrees(num_nodes); + matrix Bdiag(num_nodes); + Bdiag = 0; + double edge_sum = 0; + node_degrees = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + node_degrees(edges[i].index1()) += edges[i].distance(); + edge_sum += edges[i].distance(); + if (edges[i].index1() == edges[i].index2()) + Bdiag(edges[i].index1()) += edges[i].distance(); + } + edge_sum /= 2; + Bdiag -= squared(node_degrees)/(2*edge_sum); + + + dlib::rand rnd; + return impl::newman_cluster_helper(rnd,edges,node_degrees,Bdiag,edge_sum,labels,0,eps,max_iterations); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ) + { + std::vector oedges; + convert_unordered_to_ordered(edges, oedges); + std::sort(oedges.begin(), oedges.end(), &order_by_index); + + return newman_cluster(oedges, labels, eps, max_iterations); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline std::vector remap_labels ( + const std::vector& labels, + unsigned long& num_labels + ) + /*! + ensures + - This function takes labels and produces a mapping which maps elements of + labels into the most compact range in [0, max] as possible. In particular, + there won't be any unused integers in the mapped range. + - #num_labels == the number of distinct values in labels. + - returns a vector V such that: + - V.size() == labels.size() + - max(mat(V))+1 == num_labels. + - for all valid i,j: + - if (labels[i] == labels[j]) then + - V[i] == V[j] + - else + - V[i] != V[j] + !*/ + { + std::map temp; + for (unsigned long i = 0; i < labels.size(); ++i) + { + if (temp.count(labels[i]) == 0) + { + const unsigned long next = temp.size(); + temp[labels[i]] = next; + } + } + + num_labels = temp.size(); + + std::vector result(labels.size()); + for (unsigned long i = 0; i < labels.size(); ++i) + { + result[i] = temp[labels[i]]; + } + return result; + } + } + +// ---------------------------------------------------------------------------------------- + + inline double modularity ( + const std::vector& edges, + const std::vector& labels + ) + { + const unsigned long num_nodes = max_index_plus_one(edges); + // make sure requires clause is not broken + DLIB_ASSERT(labels.size() == num_nodes, + "\t double modularity()" + << "\n\t Invalid inputs were given to this function" + ); + + unsigned long num_labels; + const std::vector& labels_ = dlib::impl::remap_labels(labels,num_labels); + + std::vector cluster_sums(num_labels,0); + std::vector k(num_nodes,0); + + double Q = 0; + double m = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + const unsigned long n1 = edges[i].index1(); + const unsigned long n2 = edges[i].index2(); + k[n1] += edges[i].distance(); + if (n1 != n2) + k[n2] += edges[i].distance(); + + if (n1 != n2) + m += edges[i].distance(); + else + m += edges[i].distance()/2; + + if (labels_[n1] == labels_[n2]) + { + if (n1 != n2) + Q += 2*edges[i].distance(); + else + Q += edges[i].distance(); + } + } + + if (m == 0) + return 0; + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + cluster_sums[labels_[i]] += k[i]; + } + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + Q -= k[i]*cluster_sums[labels_[i]]/(2*m); + } + + return 1.0/(2*m)*Q; + } + +// ---------------------------------------------------------------------------------------- + + inline double modularity ( + const std::vector& edges, + const std::vector& labels + ) + { + const unsigned long num_nodes = max_index_plus_one(edges); + // make sure requires clause is not broken + DLIB_ASSERT(labels.size() == num_nodes, + "\t double modularity()" + << "\n\t Invalid inputs were given to this function" + ); + + + unsigned long num_labels; + const std::vector& labels_ = dlib::impl::remap_labels(labels,num_labels); + + std::vector cluster_sums(num_labels,0); + std::vector k(num_nodes,0); + + double Q = 0; + double m = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + const unsigned long n1 = edges[i].index1(); + const unsigned long n2 = edges[i].index2(); + k[n1] += edges[i].distance(); + m += edges[i].distance(); + if (labels_[n1] == labels_[n2]) + { + Q += edges[i].distance(); + } + } + + if (m == 0) + return 0; + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + cluster_sums[labels_[i]] += k[i]; + } + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + Q -= k[i]*cluster_sums[labels_[i]]/m; + } + + return 1.0/m*Q; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MODULARITY_ClUSTERING__H__ + diff --git a/ml/dlib/dlib/clustering/modularity_clustering_abstract.h b/ml/dlib/dlib/clustering/modularity_clustering_abstract.h new file mode 100644 index 000000000..c1e7c20c4 --- /dev/null +++ b/ml/dlib/dlib/clustering/modularity_clustering_abstract.h @@ -0,0 +1,125 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_ +#ifdef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_ + +#include +#include "../graph_utils/ordered_sample_pair_abstract.h" +#include "../graph_utils/sample_pair_abstract.h" + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------- + + double modularity ( + const std::vector& edges, + const std::vector& labels + ); + /*! + requires + - labels.size() == max_index_plus_one(edges) + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - Interprets edges as an undirected graph. That is, it contains the edges on + the said graph and the sample_pair::distance() values define the edge weights + (larger values indicating a stronger edge connection between the nodes). + - This function returns the modularity value obtained when the given input + graph is broken into subgraphs according to the contents of labels. In + particular, we say that two nodes with indices i and j are in the same + subgraph or community if and only if labels[i] == labels[j]. + - Duplicate edges are interpreted as if there had been just one edge with a + distance value equal to the sum of all the duplicate edge's distance values. + - See the paper Modularity and community structure in networks by M. E. J. Newman + for a detailed definition. + !*/ + +// ---------------------------------------------------------------------------------------- + + double modularity ( + const std::vector& edges, + const std::vector& labels + ); + /*! + requires + - labels.size() == max_index_plus_one(edges) + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - Interprets edges as a directed graph. That is, it contains the edges on the + said graph and the ordered_sample_pair::distance() values define the edge + weights (larger values indicating a stronger edge connection between the + nodes). Note that, generally, modularity is only really defined for + undirected graphs. Therefore, the "directed graph" given to this function + should have symmetric edges between all nodes. The reason this function is + provided at all is because sometimes a vector of ordered_sample_pair objects + is a useful representation of an undirected graph. + - This function returns the modularity value obtained when the given input + graph is broken into subgraphs according to the contents of labels. In + particular, we say that two nodes with indices i and j are in the same + subgraph or community if and only if labels[i] == labels[j]. + - Duplicate edges are interpreted as if there had been just one edge with a + distance value equal to the sum of all the duplicate edge's distance values. + - See the paper Modularity and community structure in networks by M. E. J. Newman + for a detailed definition. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ); + /*! + requires + - is_ordered_by_index(edges) == true + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - This function performs the clustering algorithm described in the paper + Modularity and community structure in networks by M. E. J. Newman. + - This function interprets edges as a graph and attempts to find the labeling + that maximizes modularity(edges, #labels). + - returns the number of clusters found. + - #labels.size() == max_index_plus_one(edges) + - for all valid i: + - #labels[i] == the cluster ID of the node with index i in the graph. + - 0 <= #labels[i] < the number of clusters found + (i.e. cluster IDs are assigned contiguously and start at 0) + - The main computation of the algorithm is involved in finding an eigenvector + of a certain matrix. To do this, we use the power iteration. In particular, + each time we try to find an eigenvector we will let the power iteration loop + at most max_iterations times or until it reaches an accuracy of eps. + Whichever comes first. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ); + /*! + requires + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - This function is identical to the above newman_cluster() routine except that + it operates on a vector of sample_pair objects instead of + ordered_sample_pairs. Therefore, this is simply a convenience routine. In + particular, it is implemented by transforming the given edges into + ordered_sample_pairs and then calling the newman_cluster() routine defined + above. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/clustering/spectral_cluster.h b/ml/dlib/dlib/clustering/spectral_cluster.h new file mode 100644 index 000000000..2cac9870f --- /dev/null +++ b/ml/dlib/dlib/clustering/spectral_cluster.h @@ -0,0 +1,80 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SPECTRAL_CLUSTEr_H_ +#define DLIB_SPECTRAL_CLUSTEr_H_ + +#include "spectral_cluster_abstract.h" +#include +#include "../matrix.h" +#include "../svm/kkmeans.h" + +namespace dlib +{ + template < + typename kernel_type, + typename vector_type + > + std::vector spectral_cluster ( + const kernel_type& k, + const vector_type& samples, + const unsigned long num_clusters + ) + { + DLIB_CASSERT(num_clusters > 0, + "\t std::vector spectral_cluster(k,samples,num_clusters)" + << "\n\t num_clusters can't be 0." + ); + + if (num_clusters == 1) + { + // nothing to do, just assign everything to the 0 cluster. + return std::vector(samples.size(), 0); + } + + // compute the similarity matrix. + matrix K(samples.size(), samples.size()); + for (long r = 0; r < K.nr(); ++r) + for (long c = r+1; c < K.nc(); ++c) + K(r,c) = K(c,r) = (double)k(samples[r], samples[c]); + for (long r = 0; r < K.nr(); ++r) + K(r,r) = 0; + + matrix D(K.nr()); + for (long r = 0; r < K.nr(); ++r) + D(r) = sum(rowm(K,r)); + D = sqrt(reciprocal(D)); + K = diagm(D)*K*diagm(D); + matrix u,w,v; + // Use the normal SVD routine unless the matrix is really big, then use the fast + // approximate version. + if (K.nr() < 1000) + svd3(K,u,w,v); + else + svd_fast(K,u,w,v, num_clusters+100, 5); + // Pick out the eigenvectors associated with the largest eigenvalues. + rsort_columns(v,w); + v = colm(v, range(0,num_clusters-1)); + // Now build the normalized spectral vectors, one for each input vector. + std::vector > spec_samps, centers; + for (long r = 0; r < v.nr(); ++r) + { + spec_samps.push_back(trans(rowm(v,r))); + const double len = length(spec_samps.back()); + if (len != 0) + spec_samps.back() /= len; + } + // Finally do the K-means clustering + pick_initial_centers(num_clusters, centers, spec_samps); + find_clusters_using_kmeans(spec_samps, centers); + // And then compute the cluster assignments based on the output of K-means. + std::vector assignments; + for (unsigned long i = 0; i < spec_samps.size(); ++i) + assignments.push_back(nearest_center(centers, spec_samps[i])); + + return assignments; + } + +} + +#endif // DLIB_SPECTRAL_CLUSTEr_H_ + diff --git a/ml/dlib/dlib/clustering/spectral_cluster_abstract.h b/ml/dlib/dlib/clustering/spectral_cluster_abstract.h new file mode 100644 index 000000000..880ad80af --- /dev/null +++ b/ml/dlib/dlib/clustering/spectral_cluster_abstract.h @@ -0,0 +1,43 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_ +#ifdef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_ + +#include + +namespace dlib +{ + template < + typename kernel_type, + typename vector_type + > + std::vector spectral_cluster ( + const kernel_type& k, + const vector_type& samples, + const unsigned long num_clusters + ); + /*! + requires + - samples must be something with an interface compatible with std::vector. + - The following expression must evaluate to a double or float: + k(samples[i], samples[j]) + - num_clusters > 0 + ensures + - Performs the spectral clustering algorithm described in the paper: + On spectral clustering: Analysis and an algorithm by Ng, Jordan, and Weiss. + and returns the results. + - This function clusters the input data samples into num_clusters clusters and + returns a vector that indicates which cluster each sample falls into. In + particular, we return an array A such that: + - A.size() == samples.size() + - A[i] == the cluster assignment of samples[i]. + - for all valid i: 0 <= A[i] < num_clusters + - The "similarity" of samples[i] with samples[j] is given by + k(samples[i],samples[j]). This means that k() should output a number >= 0 + and the number should be larger for samples that are more similar. + !*/ +} + +#endif // DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/cmake b/ml/dlib/dlib/cmake new file mode 100644 index 000000000..d3695b30e --- /dev/null +++ b/ml/dlib/dlib/cmake @@ -0,0 +1,5 @@ + +cmake_minimum_required(VERSION 2.8.12) + +add_subdirectory(${CMAKE_CURRENT_LIST_DIR} dlib_build) + diff --git a/ml/dlib/dlib/cmake_utils/add_global_compiler_switch.cmake b/ml/dlib/dlib/cmake_utils/add_global_compiler_switch.cmake new file mode 100644 index 000000000..5f3d83ce4 --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/add_global_compiler_switch.cmake @@ -0,0 +1,35 @@ + + +cmake_minimum_required(VERSION 2.8.12) + +message(WARNING "add_global_compiler_switch() is deprecated. Use target_compile_options() instead") + +# Make macros that can add compiler switches to the entire project. Not just +# to the current cmake folder being built. +macro ( add_global_compiler_switch switch_name ) + # If removing the switch would change the flags then it's already present + # and we don't need to do anything. + string(REPLACE "${switch_name}" "" tempstr "${CMAKE_CXX_FLAGS}") + if ("${CMAKE_CXX_FLAGS}" STREQUAL "${tempstr}" ) + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${switch_name}" + CACHE STRING "Flags used by the compiler during all C++ builds." + FORCE) + endif () +endmacro() + +macro ( remove_global_compiler_switch switch_name ) + string(REPLACE "${switch_name}" "" tempstr "${CMAKE_CXX_FLAGS}") + if (NOT "${CMAKE_CXX_FLAGS}" STREQUAL "${tempstr}" ) + set (CMAKE_CXX_FLAGS "${tempstr}" + CACHE STRING "Flags used by the compiler during all C++ builds." + FORCE) + endif () +endmacro() + +macro (add_global_define def_name) + add_global_compiler_switch(-D${def_name}) +endmacro() + +macro (remove_global_define def_name) + remove_global_compiler_switch(-D${def_name}) +endmacro() diff --git a/ml/dlib/dlib/cmake_utils/check_if_neon_available.cmake b/ml/dlib/dlib/cmake_utils/check_if_neon_available.cmake new file mode 100644 index 000000000..0510707df --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/check_if_neon_available.cmake @@ -0,0 +1,20 @@ +# This script checks if __ARM_NEON__ is defined for your compiler + +cmake_minimum_required(VERSION 2.8.12) + +# Don't rerun this script if its already been executed. +if (DEFINED ARM_NEON_IS_AVAILABLE) + return() +endif() + +# Set to false unless we find out otherwise in the code below. +set(ARM_NEON_IS_AVAILABLE 0) + +# test if __ARM_NEON__ is defined +try_compile(test_for_neon_worked ${PROJECT_BINARY_DIR}/neon_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_neon + neon_test) + +if(test_for_neon_worked) + message (STATUS "__ARM_NEON__ defined.") + set(ARM_NEON_IS_AVAILABLE 1) +endif() diff --git a/ml/dlib/dlib/cmake_utils/dlib.pc.in b/ml/dlib/dlib/cmake_utils/dlib.pc.in new file mode 100644 index 000000000..188a673c3 --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/dlib.pc.in @@ -0,0 +1,9 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: @PROJECT_NAME@ +Description: Numerical and networking C++ library +Version: @VERSION@ +Libs: -L${libdir} -ldlib +Cflags: -I${includedir} +Requires:@REQUIRES_LIBS@ diff --git a/ml/dlib/dlib/cmake_utils/dlibConfig.cmake.in b/ml/dlib/dlib/cmake_utils/dlibConfig.cmake.in new file mode 100644 index 000000000..df427e40e --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/dlibConfig.cmake.in @@ -0,0 +1,50 @@ +# =================================================================================== +# The dlib CMake configuration file +# +# ** File generated automatically, do not modify ** +# +# Usage from an external project: +# In your CMakeLists.txt, add these lines: +# +# FIND_PACKAGE(dlib REQUIRED) +# TARGET_LINK_LIBRARIES(MY_TARGET_NAME ${dlib_LIBRARIES}) +# +# This file will define the following variables: +# - dlib_LIBRARIES : The list of all imported targets for dlib modules. +# - dlib_INCLUDE_DIRS : The dlib include directories. +# - dlib_VERSION : The version of this dlib build. +# - dlib_VERSION_MAJOR : Major version part of this dlib revision. +# - dlib_VERSION_MINOR : Minor version part of this dlib revision. +# +# =================================================================================== + + + + +# Our library dependencies (contains definitions for IMPORTED targets) +if(NOT TARGET dlib-shared AND NOT dlib_BINARY_DIR) + # Compute paths + get_filename_component(dlib_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) + include("${dlib_CMAKE_DIR}/dlib.cmake") +endif() + +set(dlib_LIBRARIES dlib::dlib) +set(dlib_LIBS dlib::dlib) +set(dlib_INCLUDE_DIRS "@CMAKE_INSTALL_FULL_INCLUDEDIR@" "@dlib_needed_includes@") + +mark_as_advanced(dlib_LIBRARIES) +mark_as_advanced(dlib_LIBS) +mark_as_advanced(dlib_INCLUDE_DIRS) + +# Mark these variables above as deprecated. +function(__deprecated_var var access) + if(access STREQUAL "READ_ACCESS") + message(WARNING "The variable '${var}' is deprecated! Instead, simply use target_link_libraries(your_app dlib::dlib). See http://dlib.net/examples/CMakeLists.txt.html for an example.") + endif() +endfunction() +variable_watch(dlib_LIBRARIES __deprecated_var) +variable_watch(dlib_LIBS __deprecated_var) +variable_watch(dlib_INCLUDE_DIRS __deprecated_var) + + + diff --git a/ml/dlib/dlib/cmake_utils/find_blas.cmake b/ml/dlib/dlib/cmake_utils/find_blas.cmake new file mode 100644 index 000000000..24fca7123 --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/find_blas.cmake @@ -0,0 +1,385 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# +# +# This cmake file tries to find installed BLAS and LAPACK libraries. +# It looks for an installed copy of the Intel MKL library first and then +# attempts to find some other BLAS and LAPACK libraries if you don't have +# the Intel MKL. +# +# blas_found - True if BLAS is available +# lapack_found - True if LAPACK is available +# found_intel_mkl - True if the Intel MKL library is available +# found_intel_mkl_headers - True if Intel MKL headers are available +# blas_libraries - link against these to use BLAS library +# lapack_libraries - link against these to use LAPACK library +# mkl_libraries - link against these to use the MKL library +# mkl_include_dir - add to the include path to use the MKL library +# openmp_libraries - Set to Intel's OpenMP library if and only if we +# find the MKL. + +# setting this makes CMake allow normal looking if else statements +SET(CMAKE_ALLOW_LOOSE_LOOP_CONSTRUCTS true) + +SET(blas_found 0) +SET(lapack_found 0) +SET(found_intel_mkl 0) +SET(found_intel_mkl_headers 0) +SET(lapack_with_underscore 0) +SET(lapack_without_underscore 0) + +message(STATUS "Searching for BLAS and LAPACK") + +if (UNIX OR MINGW) + message(STATUS "Searching for BLAS and LAPACK") + + if (BUILDING_MATLAB_MEX_FILE) + # # This commented out stuff would link directly to MATLAB's built in + # BLAS and LAPACK. But it's better to not link to anything and do a + #find_library(MATLAB_BLAS_LIBRARY mwblas PATHS ${MATLAB_LIB_FOLDERS} ) + #find_library(MATLAB_LAPACK_LIBRARY mwlapack PATHS ${MATLAB_LIB_FOLDERS} ) + #if (MATLAB_BLAS_LIBRARY AND MATLAB_LAPACK_LIBRARY) + # add_subdirectory(external/cblas) + # set(blas_libraries ${MATLAB_BLAS_LIBRARY} cblas ) + # set(lapack_libraries ${MATLAB_LAPACK_LIBRARY} ) + # set(blas_found 1) + # set(lapack_found 1) + # message(STATUS "Found MATLAB's BLAS and LAPACK libraries") + #endif() + + # We need cblas since MATLAB doesn't provide cblas symbols. + add_subdirectory(external/cblas) + set(blas_libraries cblas ) + set(blas_found 1) + set(lapack_found 1) + message(STATUS "Will link with MATLAB's BLAS and LAPACK at runtime (hopefully!)") + + + ## Don't try to link to anything other than MATLAB's own internal blas + ## and lapack libraries because doing so generally upsets MATLAB. So + ## we just end here no matter what. + return() + endif() + + # First, search for libraries via pkg-config, which is the cleanest path + find_package(PkgConfig) + pkg_check_modules(BLAS_REFERENCE cblas) + pkg_check_modules(LAPACK_REFERENCE lapack) + if (BLAS_REFERENCE_FOUND AND LAPACK_REFERENCE_FOUND) + set(blas_libraries "${BLAS_REFERENCE_LDFLAGS}") + set(lapack_libraries "${LAPACK_REFERENCE_LDFLAGS}") + set(blas_found 1) + set(lapack_found 1) + set(REQUIRES_LIBS "${REQUIRES_LIBS} cblas lapack") + message(STATUS "Found BLAS and LAPACK via pkg-config") + return() + endif() + + include(CheckTypeSize) + check_type_size( "void*" SIZE_OF_VOID_PTR) + + if (SIZE_OF_VOID_PTR EQUAL 8) + set( mkl_search_path + /opt/intel/mkl/*/lib/em64t + /opt/intel/mkl/lib/intel64 + /opt/intel/lib/intel64 + /opt/intel/mkl/lib + ) + + find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path}) + mark_as_advanced(mkl_intel) + else() + set( mkl_search_path + /opt/intel/mkl/*/lib/32 + /opt/intel/mkl/lib/ia32 + /opt/intel/lib/ia32 + ) + + find_library(mkl_intel mkl_intel ${mkl_search_path}) + mark_as_advanced(mkl_intel) + endif() + + include(CheckLibraryExists) + + # Get mkl_include_dir + set(mkl_include_search_path + /opt/intel/mkl/include + /opt/intel/include + ) + find_path(mkl_include_dir mkl_version.h ${mkl_include_search_path}) + mark_as_advanced(mkl_include_dir) + + # Search for the needed libraries from the MKL. We will try to link against the mkl_rt + # file first since this way avoids linking bugs in some cases. + find_library(mkl_rt mkl_rt ${mkl_search_path}) + find_library(openmp_libraries iomp5 ${mkl_search_path}) + mark_as_advanced( mkl_rt openmp_libraries ) + # if we found the MKL + if ( mkl_rt) + set(mkl_libraries ${mkl_rt} ) + set(blas_libraries ${mkl_rt} ) + set(lapack_libraries ${mkl_rt} ) + set(blas_found 1) + set(lapack_found 1) + set(found_intel_mkl 1) + message(STATUS "Found Intel MKL BLAS/LAPACK library") + endif() + + if (NOT found_intel_mkl) + # Search for the needed libraries from the MKL. This time try looking for a different + # set of MKL files and try to link against those. + find_library(mkl_core mkl_core ${mkl_search_path}) + find_library(mkl_thread mkl_intel_thread ${mkl_search_path}) + find_library(mkl_iomp iomp5 ${mkl_search_path}) + find_library(mkl_pthread pthread ${mkl_search_path}) + + mark_as_advanced( mkl_intel mkl_core mkl_thread mkl_iomp mkl_pthread) + # If we found the MKL + if (mkl_intel AND mkl_core AND mkl_thread AND mkl_iomp AND mkl_pthread) + set(mkl_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread}) + set(blas_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread}) + set(lapack_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ${mkl_pthread}) + set(blas_found 1) + set(lapack_found 1) + set(found_intel_mkl 1) + message(STATUS "Found Intel MKL BLAS/LAPACK library") + endif() + endif() + + if (found_intel_mkl AND mkl_include_dir) + set(found_intel_mkl_headers 1) + endif() + + # try to find some other LAPACK libraries if we didn't find the MKL + set(extra_paths + /usr/lib64 + /usr/lib64/atlas-sse3 + /usr/lib64/atlas-sse2 + /usr/lib64/atlas + /usr/lib + /usr/lib/atlas-sse3 + /usr/lib/atlas-sse2 + /usr/lib/atlas + /usr/lib/openblas-base + /opt/OpenBLAS/lib + $ENV{OPENBLAS_HOME}/lib + ) + + INCLUDE (CheckFunctionExists) + + if (NOT blas_found) + find_library(cblas_lib openblas PATHS ${extra_paths}) + if (cblas_lib) + set(blas_libraries ${cblas_lib}) + set(blas_found 1) + message(STATUS "Found OpenBLAS library") + set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) + # If you compiled OpenBLAS with LAPACK in it then it should have the + # sgetrf_single function in it. So if we find that function in + # OpenBLAS then just use OpenBLAS's LAPACK. + CHECK_FUNCTION_EXISTS(sgetrf_single OPENBLAS_HAS_LAPACK) + if (OPENBLAS_HAS_LAPACK) + message(STATUS "Using OpenBLAS's built in LAPACK") + # set(lapack_libraries gfortran) + set(lapack_found 1) + endif() + endif() + mark_as_advanced( cblas_lib) + endif() + + + if (NOT lapack_found) + find_library(lapack_lib NAMES lapack lapack-3 PATHS ${extra_paths}) + if (lapack_lib) + set(lapack_libraries ${lapack_lib}) + set(lapack_found 1) + message(STATUS "Found LAPACK library") + endif() + mark_as_advanced( lapack_lib) + endif() + + + # try to find some other BLAS libraries if we didn't find the MKL + + if (NOT blas_found) + find_library(atlas_lib atlas PATHS ${extra_paths}) + find_library(cblas_lib cblas PATHS ${extra_paths}) + if (atlas_lib AND cblas_lib) + set(blas_libraries ${atlas_lib} ${cblas_lib}) + set(blas_found 1) + message(STATUS "Found ATLAS BLAS library") + endif() + mark_as_advanced( atlas_lib cblas_lib) + endif() + + # CentOS 7 atlas + if (NOT blas_found) + find_library(tatlas_lib tatlas PATHS ${extra_paths}) + find_library(satlas_lib satlas PATHS ${extra_paths}) + if (tatlas_lib AND satlas_lib ) + set(blas_libraries ${tatlas_lib} ${satlas_lib}) + set(blas_found 1) + message(STATUS "Found ATLAS BLAS library") + endif() + mark_as_advanced( tatlas_lib satlas_lib) + endif() + + + if (NOT blas_found) + find_library(cblas_lib cblas PATHS ${extra_paths}) + if (cblas_lib) + set(blas_libraries ${cblas_lib}) + set(blas_found 1) + message(STATUS "Found CBLAS library") + endif() + mark_as_advanced( cblas_lib) + endif() + + + if (NOT blas_found) + find_library(generic_blas blas PATHS ${extra_paths}) + if (generic_blas) + set(blas_libraries ${generic_blas}) + set(blas_found 1) + message(STATUS "Found BLAS library") + endif() + mark_as_advanced( generic_blas) + endif() + + + + + # Make sure we really found a CBLAS library. That is, it needs to expose + # the proper cblas link symbols. So here we test if one of them is present + # and assume everything is good if it is. Note that we don't do this check if + # we found the Intel MKL since for some reason CHECK_FUNCTION_EXISTS doesn't work + # with it. But it's fine since the MKL should always have cblas. + if (blas_found AND NOT found_intel_mkl) + set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) + CHECK_FUNCTION_EXISTS(cblas_ddot HAVE_CBLAS) + if (NOT HAVE_CBLAS) + message(STATUS "BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK") + set(blas_found 0) + set(lapack_found 0) + endif() + endif() + + + +elseif(WIN32 AND NOT MINGW) + message(STATUS "Searching for BLAS and LAPACK") + + include(CheckTypeSize) + check_type_size( "void*" SIZE_OF_VOID_PTR) + if (SIZE_OF_VOID_PTR EQUAL 8) + set( mkl_search_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/intel64" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/intel64" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/intel64" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/intel64" + "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/intel64" + "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/intel64" + "C:/Program Files/Intel/Composer XE/mkl/lib/intel64" + "C:/Program Files/Intel/Composer XE/compiler/lib/intel64" + ) + find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path}) + else() + set( mkl_search_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/ia32" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/ia32" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/ia32" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/ia32" + "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/ia32" + "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/ia32" + "C:/Program Files/Intel/Composer XE/mkl/lib/ia32" + "C:/Program Files/Intel/Composer XE/compiler/lib/ia32" + ) + find_library(mkl_intel mkl_intel_c ${mkl_search_path}) + endif() + + INCLUDE (CheckFunctionExists) + + # Search for the needed libraries from the MKL. + find_library(mkl_core mkl_core ${mkl_search_path}) + find_library(mkl_thread mkl_intel_thread ${mkl_search_path}) + find_library(mkl_iomp libiomp5md ${mkl_search_path}) + + mark_as_advanced( mkl_intel mkl_core mkl_thread mkl_iomp) + # If we found the MKL + if (mkl_intel AND mkl_core AND mkl_thread AND mkl_iomp ) + set(blas_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ) + set(lapack_libraries ${mkl_intel} ${mkl_core} ${mkl_thread} ${mkl_iomp} ) + set(blas_found 1) + set(lapack_found 1) + message(STATUS "Found Intel MKL BLAS/LAPACK library") + + # Make sure the version of the Intel MKL we found is compatible with + # the compiler we are using. One way to do this check is to see if we can + # link to it right now. + set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) + CHECK_FUNCTION_EXISTS(cblas_ddot HAVE_CBLAS) + if (NOT HAVE_CBLAS) + message("BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK") + set(blas_found 0) + set(lapack_found 0) + endif() + + endif() + + +endif() + + +# When all else fails use CMake's built in functions to find BLAS and LAPACK +if (NOT blas_found) + find_package(BLAS QUIET) + if (${BLAS_FOUND}) + set(blas_libraries ${BLAS_LIBRARIES}) + set(blas_found 1) + if (NOT lapack_found) + find_package(LAPACK QUIET) + if (${LAPACK_FOUND}) + set(lapack_libraries ${LAPACK_LIBRARIES}) + set(lapack_found 1) + endif() + endif() + endif() +endif() + + +# If using lapack, determine whether to mangle functions +if (lapack_found) + include(CheckFunctionExists) + include(CheckFortranFunctionExists) + set(CMAKE_REQUIRED_LIBRARIES ${lapack_libraries}) + + check_function_exists("sgesv" LAPACK_FOUND_C_UNMANGLED) + check_function_exists("sgesv_" LAPACK_FOUND_C_MANGLED) + if (CMAKE_Fortran_COMPILER_LOADED) + check_fortran_function_exists("sgesv" LAPACK_FOUND_FORTRAN_UNMANGLED) + check_fortran_function_exists("sgesv_" LAPACK_FOUND_FORTRAN_MANGLED) + endif () + if (LAPACK_FOUND_C_MANGLED OR LAPACK_FOUND_FORTRAN_MANGLED) + set(lapack_with_underscore 1) + elseif (LAPACK_FOUND_C_UNMANGLED OR LAPACK_FOUND_FORTRAN_UNMANGLED) + set(lapack_without_underscore 1) + endif () +endif() + + +if (UNIX OR MINGW) + if (NOT blas_found) + message(" *****************************************************************************") + message(" *** No BLAS library found so using dlib's built in BLAS. However, if you ***") + message(" *** install an optimized BLAS such as OpenBLAS or the Intel MKL your code ***") + message(" *** will run faster. On Ubuntu you can install OpenBLAS by executing: ***") + message(" *** sudo apt-get install libopenblas-dev liblapack-dev ***") + message(" *** Or you can easily install OpenBLAS from source by downloading the ***") + message(" *** source tar file from http://www.openblas.net, extracting it, and ***") + message(" *** running: ***") + message(" *** make; sudo make install ***") + message(" *****************************************************************************") + endif() +endif() + diff --git a/ml/dlib/dlib/cmake_utils/release_build_by_default b/ml/dlib/dlib/cmake_utils/release_build_by_default new file mode 100644 index 000000000..1b0e95831 --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/release_build_by_default @@ -0,0 +1,9 @@ + +#set default build type to Release +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING + "Choose the type of build, options are: Debug Release + RelWithDebInfo MinSizeRel." FORCE) +endif() + + diff --git a/ml/dlib/dlib/cmake_utils/set_compiler_specific_options.cmake b/ml/dlib/dlib/cmake_utils/set_compiler_specific_options.cmake new file mode 100644 index 000000000..dd8e3a7a5 --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/set_compiler_specific_options.cmake @@ -0,0 +1,131 @@ + +cmake_minimum_required(VERSION 2.8.12) + +if (POLICY CMP0054) + cmake_policy(SET CMP0054 NEW) +endif() + +set(USING_OLD_VISUAL_STUDIO_COMPILER 0) +if(MSVC AND MSVC_VERSION VERSION_LESS 1900) + message(FATAL_ERROR "C++11 is required to use dlib, but the version of Visual Studio you are using is too old and doesn't support C++11. You need Visual Studio 2015 or newer. ") +elseif(MSVC AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 19.0.24210.0 ) + message(STATUS "NOTE: Visual Studio didn't have good enough C++11 support until Visual Studio 2015 update 3 (v19.0.24210.0)") + message(STATUS "So we aren't enabling things that require full C++11 support (e.g. the deep learning tools).") + message(STATUS "Also, be aware that Visual Studio's version naming is confusing, in particular, there are multiple versions of 'update 3'") + message(STATUS "So if you are getting this message you need to update to the newer version of Visual Studio to use full C++11.") + set(USING_OLD_VISUAL_STUDIO_COMPILER 1) +elseif(MSVC AND (MSVC_VERSION EQUAL 1911 OR MSVC_VERSION EQUAL 1910)) + message(STATUS "******************************************************************************************") + message(STATUS "Your version of Visual Studio has incomplete C++11 support and is unable to compile the ") + message(STATUS "DNN examples. So we are disabling the deep learning tools. If you want to use the DNN ") + message(STATUS "tools in dlib then update your copy of Visual Studio.") + message(STATUS "******************************************************************************************") + set(USING_OLD_VISUAL_STUDIO_COMPILER 1) +endif() + +if(CMAKE_COMPILER_IS_GNUCXX) + execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpversion OUTPUT_VARIABLE GCC_VERSION) + if (GCC_VERSION VERSION_LESS 4.8) + message(FATAL_ERROR "C++11 is required to use dlib, but the version of GCC you are using is too old and doesn't support C++11. You need GCC 4.8 or newer. ") + endif() +endif() + + +# push USING_OLD_VISUAL_STUDIO_COMPILER to the parent so we can use it in the +# examples CMakeLists.txt file. +get_directory_property(has_parent PARENT_DIRECTORY) +if(has_parent) + set(USING_OLD_VISUAL_STUDIO_COMPILER ${USING_OLD_VISUAL_STUDIO_COMPILER} PARENT_SCOPE) +endif() + + + +set(gcc_like_compilers GNU Clang Intel) +set(intel_archs x86_64 i386 i686 AMD64 amd64 x86) + + +# Setup some options to allow a user to enable SSE and AVX instruction use. +if ((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND + (";${intel_archs};" MATCHES ";${CMAKE_SYSTEM_PROCESSOR};") AND NOT USE_AUTO_VECTOR) + option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" OFF) + option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF) + option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF) + if(USE_AVX_INSTRUCTIONS) + list(APPEND active_compile_opts -mavx) + message(STATUS "Enabling AVX instructions") + elseif (USE_SSE4_INSTRUCTIONS) + list(APPEND active_compile_opts -msse4) + message(STATUS "Enabling SSE4 instructions") + elseif(USE_SSE2_INSTRUCTIONS) + list(APPEND active_compile_opts -msse2) + message(STATUS "Enabling SSE2 instructions") + endif() +elseif (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # else if using Visual Studio + # Use SSE2 by default when using Visual Studio. + option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" ON) + option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF) + option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF) + + include(CheckTypeSize) + check_type_size( "void*" SIZE_OF_VOID_PTR) + if(USE_AVX_INSTRUCTIONS) + list(APPEND active_compile_opts /arch:AVX) + message(STATUS "Enabling AVX instructions") + elseif (USE_SSE4_INSTRUCTIONS) + # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes. + # So only give it when we are doing a 32 bit build. + if (SIZE_OF_VOID_PTR EQUAL 4) + list(APPEND active_compile_opts /arch:SSE2) + endif() + message(STATUS "Enabling SSE4 instructions") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE3") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE41") + elseif(USE_SSE2_INSTRUCTIONS) + # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes. + # So only give it when we are doing a 32 bit build. + if (SIZE_OF_VOID_PTR EQUAL 4) + list(APPEND active_compile_opts /arch:SSE2) + endif() + message(STATUS "Enabling SSE2 instructions") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2") + endif() + +elseif((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND + ("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "^arm")) + option(USE_NEON_INSTRUCTIONS "Compile your program with ARM-NEON instructions" OFF) + if(USE_NEON_INSTRUCTIONS) + list(APPEND active_compile_opts -mfpu=neon) + message(STATUS "Enabling ARM-NEON instructions") + endif() +endif() + + + + +if (CMAKE_COMPILER_IS_GNUCXX) + # By default, g++ won't warn or error if you forget to return a value in a + # function which requires you to do so. This option makes it give a warning + # for doing this. + list(APPEND active_compile_opts "-Wreturn-type") +endif() + +if ("Clang" MATCHES ${CMAKE_CXX_COMPILER_ID}) + # Increase clang's default tempalte recurision depth so the dnn examples don't error out. + list(APPEND active_compile_opts "-ftemplate-depth=500") +endif() + +if (MSVC) + # By default Visual Studio does not support .obj files with more than 65k sections. + # However, code generated by file_to_code_ex and code using DNN module can have + # them. So this flag enables > 65k sections, but produces .obj files + # that will not be readable by VS 2005. + list(APPEND active_compile_opts "/bigobj") + + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 3.3) + # Clang can compile all Dlib's code at Windows platform. Tested with Clang 5 + list(APPEND active_compile_opts "-Xclang -fcxx-exceptions") + endif() +endif() + + diff --git a/ml/dlib/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake b/ml/dlib/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake new file mode 100644 index 000000000..e5fb09129 --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake @@ -0,0 +1,19 @@ + +# Including this cmake script into your cmake project will cause visual studio +# to build your project against the static C runtime. + +cmake_minimum_required(VERSION 2.8.12) +if (POLICY CMP0054) + cmake_policy(SET CMP0054 NEW) +endif() + +if (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif() + endforeach(flag_var) +endif() + diff --git a/ml/dlib/dlib/cmake_utils/test_for_cpp11/CMakeLists.txt b/ml/dlib/dlib/cmake_utils/test_for_cpp11/CMakeLists.txt new file mode 100644 index 000000000..bc6f02563 --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/test_for_cpp11/CMakeLists.txt @@ -0,0 +1,17 @@ + +cmake_minimum_required(VERSION 2.8.12) +project(cpp11_test) + +# Try to enable C++11 +include(CheckCXXCompilerFlag) +CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORTS_CXX11) +CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORTS_CXX0X) +if(COMPILER_SUPPORTS_CXX11) + message(STATUS "C++11 activated.") + ADD_DEFINITIONS("-std=c++11") +elseif(COMPILER_SUPPORTS_CXX0X) + message(STATUS "C++0x activated.") + ADD_DEFINITIONS("-std=c++0x") +endif() + +add_library(cpp11_test STATIC cpp11_test.cpp ) diff --git a/ml/dlib/dlib/cmake_utils/test_for_cpp11/cpp11_test.cpp b/ml/dlib/dlib/cmake_utils/test_for_cpp11/cpp11_test.cpp new file mode 100644 index 000000000..6cc4f479b --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/test_for_cpp11/cpp11_test.cpp @@ -0,0 +1,51 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include + +using namespace std; + +class testme +{ +public: + + testme(testme&&) = default; + testme(const testme&) = delete; + + + template + auto auto_return(T f) -> decltype(f(4)) { return f(4); } + + template + auto auto_return(T f) -> decltype(f()) { return f(); } + + static int returnint() { return 0; } + + void dostuff() + { + thread_local int stuff1 = 999; + auto x = 4; + + decltype(x) asdf = 9; + + auto f = []() { cout << "in a lambda!" << endl; }; + f(); + + auto_return(returnint); + } + + template + void variadic_template( + T&& ...args + ) + { + } + + + + std::shared_ptr asdf; +}; + +// ------------------------------------------------------------------------------------ + diff --git a/ml/dlib/dlib/cmake_utils/test_for_cuda/CMakeLists.txt b/ml/dlib/dlib/cmake_utils/test_for_cuda/CMakeLists.txt new file mode 100644 index 000000000..5f6af245e --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/test_for_cuda/CMakeLists.txt @@ -0,0 +1,14 @@ + +cmake_minimum_required(VERSION 2.8.12) +project(cuda_test) + +include_directories(../../dnn) +add_definitions(-DDLIB_USE_CUDA) + +# Override the FindCUDA.cmake setting to avoid duplication of host flags if using a toolchain: +option(CUDA_PROPAGATE_HOST_FLAGS "Propage C/CXX_FLAGS and friends to the host compiler via -Xcompile" OFF) +find_package(CUDA 7.5 REQUIRED) +set(CUDA_HOST_COMPILATION_CPP ON) +list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-std=c++11;-D__STRICT_ANSI__;-D_MWAITXINTRIN_H_INCLUDED;-D_FORCE_INLINES") + +cuda_add_library(cuda_test STATIC cuda_test.cu ) diff --git a/ml/dlib/dlib/cmake_utils/test_for_cuda/cuda_test.cu b/ml/dlib/dlib/cmake_utils/test_for_cuda/cuda_test.cu new file mode 100644 index 000000000..fb1ffe0da --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/test_for_cuda/cuda_test.cu @@ -0,0 +1,21 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "cuda_utils.h" +#include "cuda_dlib.h" + + +// ------------------------------------------------------------------------------------ + +__global__ void cuda_add_arrays(const float* a, const float* b, float* out, size_t n) +{ + out[0] += a[0]+b[0]; +} + +void add_arrays() +{ + cuda_add_arrays<<<512,512>>>(0,0,0,0); +} + +// ------------------------------------------------------------------------------------ + diff --git a/ml/dlib/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt b/ml/dlib/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt new file mode 100644 index 000000000..556088259 --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt @@ -0,0 +1,19 @@ + +cmake_minimum_required(VERSION 2.8.12) +project(cudnn_test) +include(../use_cpp_11.cmake) + +# Override the FindCUDA.cmake setting to avoid duplication of host flags if using a toolchain: +option(CUDA_PROPAGATE_HOST_FLAGS "Propage C/CXX_FLAGS and friends to the host compiler via -Xcompile" OFF) +find_package(CUDA 7.5 REQUIRED) +set(CUDA_HOST_COMPILATION_CPP ON) +list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-std=c++11;-D__STRICT_ANSI__") +add_definitions(-DDLIB_USE_CUDA) + +include(find_cudnn.txt) + +if (cudnn_include AND cudnn) + include_directories(${cudnn_include}) + cuda_add_library(cudnn_test STATIC ../../dnn/cudnn_dlibapi.cpp ${cudnn} ) + enable_cpp11_for_target(cudnn_test) +endif() diff --git a/ml/dlib/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt b/ml/dlib/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt new file mode 100644 index 000000000..dd5f14e3f --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt @@ -0,0 +1,24 @@ + +message(STATUS "Looking for cuDNN install...") +# Look for cudnn, we will look in the same place as other CUDA +# libraries and also a few other places as well. +find_path(cudnn_include cudnn.h + HINTS ${CUDA_INCLUDE_DIRS} ENV CUDNN_INCLUDE_DIR ENV CUDNN_HOME + PATHS /usr/local ENV CPATH + PATH_SUFFIXES include + ) +get_filename_component(cudnn_hint_path "${CUDA_CUBLAS_LIBRARIES}" PATH) +find_library(cudnn cudnn + HINTS ${cudnn_hint_path} ENV CUDNN_LIBRARY_DIR ENV CUDNN_HOME + PATHS /usr/local /usr/local/cuda ENV LD_LIBRARY_PATH + PATH_SUFFIXES lib64 lib x64 + ) +mark_as_advanced(cudnn cudnn_include) + +if (cudnn AND cudnn_include) + message(STATUS "Found cuDNN: " ${cudnn}) +else() + message(STATUS "*** cuDNN V5.0 OR GREATER NOT FOUND. ***") + message(STATUS "*** Dlib requires cuDNN V5.0 OR GREATER. Since cuDNN is not found DLIB WILL NOT USE CUDA. ***") + message(STATUS "*** If you have cuDNN then set CMAKE_PREFIX_PATH to include cuDNN's folder. ***") +endif() diff --git a/ml/dlib/dlib/cmake_utils/test_for_neon/CMakeLists.txt b/ml/dlib/dlib/cmake_utils/test_for_neon/CMakeLists.txt new file mode 100644 index 000000000..0b6eb6f28 --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/test_for_neon/CMakeLists.txt @@ -0,0 +1,6 @@ + +cmake_minimum_required(VERSION 2.8.12) +project(neon_test) + +add_library(neon_test STATIC neon_test.cpp ) + diff --git a/ml/dlib/dlib/cmake_utils/test_for_neon/neon_test.cpp b/ml/dlib/dlib/cmake_utils/test_for_neon/neon_test.cpp new file mode 100644 index 000000000..a4abdbade --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/test_for_neon/neon_test.cpp @@ -0,0 +1,9 @@ +#ifdef __ARM_NEON__ +#else +#error "No NEON" +#endif +int main(){} + + +// ------------------------------------------------------------------------------------ + diff --git a/ml/dlib/dlib/cmake_utils/use_cpp_11.cmake b/ml/dlib/dlib/cmake_utils/use_cpp_11.cmake new file mode 100644 index 000000000..e49e30f2a --- /dev/null +++ b/ml/dlib/dlib/cmake_utils/use_cpp_11.cmake @@ -0,0 +1,113 @@ +# This script creates a function, enable_cpp11_for_target(), which checks if your +# compiler has C++11 support and enables it if it does. + + +cmake_minimum_required(VERSION 2.8.12) + +if (POLICY CMP0054) + cmake_policy(SET CMP0054 NEW) +endif() + + +set(_where_is_cmake_utils_dir ${CMAKE_CURRENT_LIST_DIR}) + +function(enable_cpp11_for_target target_name) + + +# Set to false unless we find out otherwise in the code below. +set(COMPILER_CAN_DO_CPP_11 0) + + + +macro(test_compiler_for_cpp11) + message(STATUS "Building a C++11 test project to see if your compiler supports C++11") + try_compile(test_for_cpp11_worked ${PROJECT_BINARY_DIR}/cpp11_test_build + ${_where_is_cmake_utils_dir}/test_for_cpp11 cpp11_test) + if (test_for_cpp11_worked) + message(STATUS "C++11 activated.") + set(COMPILER_CAN_DO_CPP_11 1) + else() + set(COMPILER_CAN_DO_CPP_11 0) + message(STATUS "********** Your compiler failed to build a C++11 project. C++11 is required to use all parts of dlib! **********") + endif() +endmacro() + +# Now turn on the appropriate compiler switch to enable C++11 if you have a +# C++11 compiler. In CMake 3.1 there is a simple flag you can set, but earlier +# verions of CMake are not so convenient. +if (CMAKE_VERSION VERSION_LESS "3.1.2") + if(CMAKE_COMPILER_IS_GNUCXX) + execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpversion OUTPUT_VARIABLE GCC_VERSION) + if (GCC_VERSION VERSION_GREATER 4.8 OR GCC_VERSION VERSION_EQUAL 4.8) + message(STATUS "C++11 activated.") + target_compile_options(${target_name} PUBLIC "-std=gnu++11") + set(COMPILER_CAN_DO_CPP_11 1) + endif() + elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + execute_process( COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string ) + string (REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION ${clang_full_version_string}) + if (CLANG_VERSION VERSION_GREATER 3.3) + message(STATUS "C++11 activated.") + target_compile_options(${target_name} PUBLIC "-std=c++11") + set(COMPILER_CAN_DO_CPP_11 1) + endif() + else() + # Since we don't know what compiler this is just try to build a c++11 project and see if it compiles. + test_compiler_for_cpp11() + endif() +else() + + # Set a flag if the compiler you are using is capable of providing C++11 features. + get_property(cxx_features GLOBAL PROPERTY CMAKE_CXX_KNOWN_FEATURES) + if (";${cxx_features};" MATCHES ";cxx_rvalue_references;" AND + ";${cxx_features};" MATCHES ";cxx_variadic_templates;" AND + ";${cxx_features};" MATCHES ";cxx_lambdas;" AND + ";${cxx_features};" MATCHES ";cxx_defaulted_move_initializers;" AND + ";${cxx_features};" MATCHES ";cxx_delegating_constructors;" AND + ";${cxx_features};" MATCHES ";cxx_thread_local;" AND + ";${cxx_features};" MATCHES ";cxx_constexpr;" AND + ";${cxx_features};" MATCHES ";cxx_decltype_incomplete_return_types;" AND + ";${cxx_features};" MATCHES ";cxx_auto_type;") + + set(COMPILER_CAN_DO_CPP_11 1) + # Tell cmake that we need C++11 for dlib + target_compile_features(${target_name} + PUBLIC + cxx_rvalue_references + cxx_variadic_templates + cxx_lambdas + cxx_defaulted_move_initializers + cxx_delegating_constructors + cxx_thread_local + cxx_constexpr + # cxx_decltype_incomplete_return_types # purposfully commented out because cmake errors out on this when using visual studio and cmake 3.8.0 + cxx_auto_type + ) + + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + # Sometimes clang will lie and report that it supports C++11 when + # really it doesn't support thread_local. So check for that. + test_compiler_for_cpp11() + else() + message(STATUS "C++11 activated.") + endif() + endif() +endif() + +# Always enable whatever partial C++11 support we have, even if it isn't full +# support, and just hope for the best. +if (NOT COMPILER_CAN_DO_CPP_11) + include(CheckCXXCompilerFlag) + CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORTS_CXX11) + CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORTS_CXX0X) + if(COMPILER_SUPPORTS_CXX11) + message(STATUS "C++11 activated (compiler doesn't have full C++11 support).") + target_compile_options(${target_name} PUBLIC "-std=c++11") + elseif(COMPILER_SUPPORTS_CXX0X) + message(STATUS "C++0x activated (compiler doesn't have full C++11 support).") + target_compile_options(${target_name} PUBLIC "-std=c++0x") + endif() +endif() + +endfunction() + diff --git a/ml/dlib/dlib/cmd_line_parser.h b/ml/dlib/dlib/cmd_line_parser.h new file mode 100644 index 000000000..fd1148038 --- /dev/null +++ b/ml/dlib/dlib/cmd_line_parser.h @@ -0,0 +1,84 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSEr_ +#define DLIB_CMD_LINE_PARSEr_ + +#include "cmd_line_parser/cmd_line_parser_kernel_1.h" +#include "cmd_line_parser/cmd_line_parser_kernel_c.h" +#include "cmd_line_parser/cmd_line_parser_print_1.h" +#include "cmd_line_parser/cmd_line_parser_check_1.h" +#include "cmd_line_parser/cmd_line_parser_check_c.h" +#include +#include "cmd_line_parser/get_option.h" + +#include "map.h" +#include "sequence.h" + + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT + > + class impl_cmd_line_parser + { + /*! + This class is basically just a big templated typedef for building + a complete command line parser type out of all the parts it needs. + !*/ + + impl_cmd_line_parser() {} + + typedef typename sequence >::kernel_2a sequence_2a; + typedef typename sequence*>::kernel_2a psequence_2a; + typedef typename map,void*>::kernel_1a map_1a_string; + + public: + + typedef cmd_line_parser_kernel_1 kernel_1a; + typedef cmd_line_parser_kernel_c kernel_1a_c; + typedef cmd_line_parser_print_1 print_1a_c; + typedef cmd_line_parser_check_c > check_1a_c; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename charT + > + class cmd_line_parser : public impl_cmd_line_parser::check_1a_c + { + public: + + // These typedefs are here for backwards compatibility with previous versions of dlib. + typedef cmd_line_parser kernel_1a; + typedef cmd_line_parser kernel_1a_c; + typedef cmd_line_parser print_1a; + typedef cmd_line_parser print_1a_c; + typedef cmd_line_parser check_1a; + typedef cmd_line_parser check_1a_c; + }; + + template < + typename charT + > + inline void swap ( + cmd_line_parser& a, + cmd_line_parser& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + typedef cmd_line_parser command_line_parser; + typedef cmd_line_parser wcommand_line_parser; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSEr_ + diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_1.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_1.h new file mode 100644 index 000000000..1736b4b56 --- /dev/null +++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_1.h @@ -0,0 +1,580 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_CHECk_1_ +#define DLIB_CMD_LINE_PARSER_CHECk_1_ + +#include "cmd_line_parser_kernel_abstract.h" +#include +#include +#include "../string.h" +#include + +namespace dlib +{ + + template < + typename clp_base + > + class cmd_line_parser_check_1 : public clp_base + { + + /*! + This extension doesn't add any state. + + !*/ + + + public: + typedef typename clp_base::char_type char_type; + typedef typename clp_base::string_type string_type; + + // ------------------------------------------------------------------------------------ + + class cmd_line_check_error : public dlib::error + { + friend class cmd_line_parser_check_1; + + cmd_line_check_error( + error_type t, + const string_type& opt_, + const string_type& arg_ + ) : + dlib::error(t), + opt(opt_), + opt2(), + arg(arg_), + required_opts() + { set_info_string(); } + + cmd_line_check_error( + error_type t, + const string_type& opt_, + const string_type& opt2_, + int // this is just to make this constructor different from the one above + ) : + dlib::error(t), + opt(opt_), + opt2(opt2_), + arg(), + required_opts() + { set_info_string(); } + + cmd_line_check_error ( + error_type t, + const string_type& opt_, + const std::vector& vect + ) : + dlib::error(t), + opt(opt_), + opt2(), + arg(), + required_opts(vect) + { set_info_string(); } + + cmd_line_check_error( + error_type t, + const string_type& opt_ + ) : + dlib::error(t), + opt(opt_), + opt2(), + arg(), + required_opts() + { set_info_string(); } + + ~cmd_line_check_error() throw() {} + + void set_info_string ( + ) + { + std::ostringstream sout; + switch (type) + { + case EINVALID_OPTION_ARG: + sout << "Command line error: '" << narrow(arg) << "' is not a valid argument to " + << "the '" << narrow(opt) << "' option."; + break; + case EMISSING_REQUIRED_OPTION: + if (required_opts.size() == 1) + { + sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of " + << "the '" << required_opts[0] << "' option."; + } + else + { + sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of " + << "one of the following options: "; + for (unsigned long i = 0; i < required_opts.size(); ++i) + { + if (i == required_opts.size()-2) + sout << "'" << required_opts[i] << "' or "; + else if (i == required_opts.size()-1) + sout << "'" << required_opts[i] << "'."; + else + sout << "'" << required_opts[i] << "', "; + } + } + break; + case EINCOMPATIBLE_OPTIONS: + sout << "Command line error: The '" << narrow(opt) << "' and '" << narrow(opt2) + << "' options cannot be given together on the command line."; + break; + case EMULTIPLE_OCCURANCES: + sout << "Command line error: The '" << narrow(opt) << "' option can only " + << "be given on the command line once."; + break; + default: + sout << "Command line error."; + break; + } + const_cast(info) = wrap_string(sout.str(),0,0); + } + + public: + const string_type opt; + const string_type opt2; + const string_type arg; + const std::vector required_opts; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void check_option_arg_type ( + const string_type& option_name + ) const; + + template < + typename T + > + void check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const; + + template < + typename T, + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_incompatible_options ( + const char_type* (&option_set)[length] + ) const; + + template < + size_t length + > + void check_one_time_options ( + const char_type* (&option_set)[length] + ) const; + + void check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const; + + void check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const; + + template < + size_t length + > + void check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const; + + template < + size_t length + > + void check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const; + + template < + size_t parent_length, + size_t sub_length + > + void check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const; + }; + + template < + typename clp_base + > + inline void swap ( + cmd_line_parser_check_1& a, + cmd_line_parser_check_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_1:: + check_option_arg_type ( + const string_type& option_name + ) const + { + try + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + string_cast(opt.argument(i,j)); + } + } + } + catch (string_cast_error& e) + { + throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_1:: + check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const + { + try + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + T temp(string_cast(opt.argument(i,j))); + if (temp < first || last < temp) + { + throw cmd_line_check_error( + EINVALID_OPTION_ARG, + option_name, + opt.argument(i,j) + ); + } + } + } + } + catch (string_cast_error& e) + { + throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < typename T, size_t length > + void cmd_line_parser_check_1:: + check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const + { + try + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + T temp(string_cast(opt.argument(i,j))); + size_t k = 0; + for (; k < length; ++k) + { + if (arg_set[k] == temp) + break; + } + if (k == length) + { + throw cmd_line_check_error( + EINVALID_OPTION_ARG, + option_name, + opt.argument(i,j) + ); + } + } + } + } + catch (string_cast_error& e) + { + throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + size_t k = 0; + for (; k < length; ++k) + { + if (arg_set[k] == opt.argument(i,j)) + break; + } + if (k == length) + { + throw cmd_line_check_error( + EINVALID_OPTION_ARG, + option_name, + opt.argument(i,j) + ); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_incompatible_options ( + const char_type* (&option_set)[length] + ) const + { + for (size_t i = 0; i < length; ++i) + { + for (size_t j = i+1; j < length; ++j) + { + if (this->option(option_set[i]).count() > 0 && + this->option(option_set[j]).count() > 0 ) + { + throw cmd_line_check_error( + EINCOMPATIBLE_OPTIONS, + option_set[i], + option_set[j], + 0 // this argument has no meaning and is only here to make this + // call different from the other constructor + ); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_1:: + check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const + { + if (this->option(option_name1).count() > 0 && + this->option(option_name2).count() > 0 ) + { + throw cmd_line_check_error( + EINCOMPATIBLE_OPTIONS, + option_name1, + option_name2, + 0 // this argument has no meaning and is only here to make this + // call different from the other constructor + ); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_1:: + check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const + { + if (this->option(parent_option).count() == 0) + { + if (this->option(sub_option).count() != 0) + { + std::vector vect; + vect.resize(1); + vect[0] = parent_option; + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const + { + if (this->option(parent_option).count() == 0) + { + size_t i = 0; + for (; i < length; ++i) + { + if (this->option(sub_option_set[i]).count() > 0) + break; + } + if (i != length) + { + std::vector vect; + vect.resize(1); + vect[0] = parent_option; + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const + { + // first check if the sub_option is present + if (this->option(sub_option).count() > 0) + { + // now check if any of the parents are present + bool parents_present = false; + for (size_t i = 0; i < length; ++i) + { + if (this->option(parent_option_set[i]).count() > 0) + { + parents_present = true; + break; + } + } + + if (!parents_present) + { + std::vector vect(parent_option_set, parent_option_set+length); + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t parent_length, size_t sub_length > + void cmd_line_parser_check_1:: + check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const + { + // first check if any of the parent options are present + bool parents_present = false; + for (size_t i = 0; i < parent_length; ++i) + { + if (this->option(parent_option_set[i]).count() > 0) + { + parents_present = true; + break; + } + } + + if (!parents_present) + { + // none of these sub options should be present + size_t i = 0; + for (; i < sub_length; ++i) + { + if (this->option(sub_option_set[i]).count() > 0) + break; + } + if (i != sub_length) + { + std::vector vect(parent_option_set, parent_option_set+parent_length); + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_one_time_options ( + const char_type* (&option_set)[length] + ) const + { + size_t i = 0; + for (; i < length; ++i) + { + if (this->option(option_set[i]).count() > 1) + break; + } + if (i != length) + { + throw cmd_line_check_error( + EMULTIPLE_OCCURANCES, + option_set[i] + ); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_CHECk_1_ + + diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_c.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_c.h new file mode 100644 index 000000000..7ff858e89 --- /dev/null +++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_check_c.h @@ -0,0 +1,453 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_CHECk_C_ +#define DLIB_CMD_LINE_PARSER_CHECk_C_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include +#include "../interfaces/cmd_line_parser_option.h" +#include "../string.h" + +namespace dlib +{ + + template < + typename clp_check + > + class cmd_line_parser_check_c : public clp_check + { + public: + + typedef typename clp_check::char_type char_type; + typedef typename clp_check::string_type string_type; + + template < + typename T + > + void check_option_arg_type ( + const string_type& option_name + ) const; + + template < + typename T + > + void check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const; + + template < + typename T, + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_incompatible_options ( + const char_type* (&option_set)[length] + ) const; + + template < + size_t length + > + void check_one_time_options ( + const char_type* (&option_set)[length] + ) const; + + void check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const; + + void check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const; + + template < + size_t length + > + void check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const; + + template < + size_t length + > + void check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const; + + template < + size_t parent_length, + size_t sub_length + > + void check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const; + }; + + + template < + typename clp_check + > + inline void swap ( + cmd_line_parser_check_c& a, + cmd_line_parser_check_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_c:: + check_option_arg_type ( + const string_type& option_name + ) const + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name), + "\tvoid cmd_line_parser_check::check_option_arg_type()" + << "\n\tYou must have already parsed the command line and option_name must be valid." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + ); + + clp_check::template check_option_arg_type(option_name); + } + +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_c:: + check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name) && + first <= last, + "\tvoid cmd_line_parser_check::check_option_arg_range()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + << "\n\tfirst: " << first + << "\n\tlast: " << last + ); + + clp_check::check_option_arg_range(option_name,first,last); + } + +// ---------------------------------------------------------------------------------------- + + template + template < typename T, size_t length > + void cmd_line_parser_check_c:: + check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name), + "\tvoid cmd_line_parser_check::check_option_arg_range()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + ); + + clp_check::check_option_arg_range(option_name,arg_set); + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name), + "\tvoid cmd_line_parser_check::check_option_arg_range()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + ); + + clp_check::check_option_arg_range(option_name,arg_set); + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_incompatible_options ( + const char_type* (&option_set)[length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]), + "\tvoid cmd_line_parser_check::check_incompatible_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_set[i]: " << option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + clp_check::check_incompatible_options(option_set); + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_c:: + check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name1) && + this->option_is_defined(option_name2), + "\tvoid cmd_line_parser_check::check_incompatible_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name1): " << ((this->option_is_defined(option_name1))?"true":"false") + << "\n\toption_is_defined(option_name2): " << ((this->option_is_defined(option_name2))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name1: " << option_name1 + << "\n\toption_name2: " << option_name2 + ); + + clp_check::check_incompatible_options(option_name1,option_name2); + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_c:: + check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option) && + this->option_is_defined(sub_option), + "\tvoid cmd_line_parser_check::check_sub_option()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\tparsed_line(): " << this->parsed_line() + << "\n\toption_is_defined(parent_option): " << this->option_is_defined(parent_option) + << "\n\toption_is_defined(sub_option): " << this->option_is_defined(sub_option) + << "\n\tparent_option: " << parent_option + << "\n\tsub_option: " << sub_option + ); + clp_check::check_sub_option(parent_option,sub_option); + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->option_is_defined(sub_option_set[i]), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(sub_option_set[i]): " + << ((this->option_is_defined(sub_option_set[i]))?"true":"false") + << "\n\tsub_option_set[i]: " << sub_option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(parent_option): " << ((this->option_is_defined(parent_option))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\tparent_option: " << parent_option + ); + clp_check::check_sub_options(parent_option,sub_option_set); + + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->option_is_defined(parent_option_set[i]), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(parent_option_set[i]): " + << ((this->option_is_defined(parent_option_set[i]))?"true":"false") + << "\n\tparent_option_set[i]: " << parent_option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(sub_option), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(sub_option): " << ((this->option_is_defined(sub_option))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\tsub_option: " << sub_option + ); + clp_check::check_sub_options(parent_option_set,sub_option); + + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t parent_length, size_t sub_length > + void cmd_line_parser_check_c:: + check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < sub_length; ++i) + { + DLIB_CASSERT( this->option_is_defined(sub_option_set[i]), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(sub_option_set[i]): " + << ((this->option_is_defined(sub_option_set[i]))?"true":"false") + << "\n\tsub_option_set[i]: " << sub_option_set[i] + << "\n\ti: " << static_cast(i) + ); + } + + for (size_t i = 0; i < parent_length; ++i) + { + DLIB_CASSERT( this->option_is_defined(parent_option_set[i]), + "\tvoid cmd_line_parser_check::check_parent_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(parent_option_set[i]): " + << ((this->option_is_defined(parent_option_set[i]))?"true":"false") + << "\n\tparent_option_set[i]: " << parent_option_set[i] + << "\n\ti: " << static_cast(i) + ); + } + + + + DLIB_CASSERT( this->parsed_line() == true , + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tYou must have parsed the command line before you call this function." + << "\n\tthis: " << this + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + ); + + clp_check::check_sub_options(parent_option_set,sub_option_set); + + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_one_time_options ( + const char_type* (&option_set)[length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]), + "\tvoid cmd_line_parser_check::check_one_time_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_set[i]: " << option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + clp_check::check_one_time_options(option_set); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_CHECk_C_ + diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h new file mode 100644 index 000000000..68ea5a135 --- /dev/null +++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h @@ -0,0 +1,799 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_KERNEl_1_ +#define DLIB_CMD_LINE_PARSER_KERNEl_1_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include +#include +#include "../interfaces/enumerable.h" +#include "../interfaces/cmd_line_parser_option.h" +#include "../assert.h" +#include "../string.h" + +namespace dlib +{ + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + class cmd_line_parser_kernel_1 : public enumerable > + { + /*! + REQUIREMENTS ON map + is an implementation of map/map_kernel_abstract.h + is instantiated to map items of type std::basic_string to void* + + REQUIREMENTS ON sequence + is an implementation of sequence/sequence_kernel_abstract.h and + is instantiated with std::basic_string + + REQUIREMENTS ON sequence2 + is an implementation of sequence/sequence_kernel_abstract.h and + is instantiated with std::basic_string* + + INITIAL VALUE + options.size() == 0 + argv.size() == 0 + have_parsed_line == false + + CONVENTION + have_parsed_line == parsed_line() + argv[index] == operator[](index) + argv.size() == number_of_arguments() + *((option_t*)options[name]) == option(name) + options.is_in_domain(name) == option_is_defined(name) + !*/ + + + + + public: + + typedef charT char_type; + typedef std::basic_string string_type; + typedef cmd_line_parser_option option_type; + + // exception class + class cmd_line_parse_error : public dlib::error + { + void set_info_string ( + ) + { + std::ostringstream sout; + switch (type) + { + case EINVALID_OPTION: + sout << "Command line error: '" << narrow(item) << "' is not a valid option."; + break; + case ETOO_FEW_ARGS: + if (num > 1) + { + sout << "Command line error: The '" << narrow(item) << "' option requires " << num + << " arguments."; + } + else + { + sout << "Command line error: The '" << narrow(item) << "' option requires " << num + << " argument."; + } + break; + case ETOO_MANY_ARGS: + sout << "Command line error: The '" << narrow(item) << "' option does not take any arguments.\n"; + break; + default: + sout << "Command line error."; + break; + } + const_cast(info) = wrap_string(sout.str(),0,0); + } + + public: + cmd_line_parse_error( + error_type t, + const std::basic_string& _item + ) : + dlib::error(t), + item(_item), + num(0) + { set_info_string();} + + cmd_line_parse_error( + error_type t, + const std::basic_string& _item, + unsigned long _num + ) : + dlib::error(t), + item(_item), + num(_num) + { set_info_string();} + + cmd_line_parse_error( + ) : + dlib::error(), + item(), + num(0) + { set_info_string();} + + ~cmd_line_parse_error() throw() {} + + const std::basic_string item; + const unsigned long num; + }; + + + private: + + class option_t : public cmd_line_parser_option + { + /*! + INITIAL VALUE + options.size() == 0 + + CONVENTION + name_ == name() + description_ == description() + number_of_arguments_ == number_of_arguments() + options[N][arg] == argument(arg,N) + num_present == count() + !*/ + + friend class cmd_line_parser_kernel_1; + + public: + + const std::basic_string& name ( + ) const { return name_; } + + const std::basic_string& group_name ( + ) const { return group_name_; } + + const std::basic_string& description ( + ) const { return description_; } + + unsigned long number_of_arguments( + ) const { return number_of_arguments_; } + + unsigned long count ( + ) const { return num_present; } + + const std::basic_string& argument ( + unsigned long arg, + unsigned long N + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( N < count() && arg < number_of_arguments(), + "\tconst string_type& cmd_line_parser_option::argument(unsigned long,unsigned long)" + << "\n\tInvalid arguments were given to this function." + << "\n\tthis: " << this + << "\n\tN: " << N + << "\n\targ: " << arg + << "\n\tname(): " << narrow(name()) + << "\n\tcount(): " << count() + << "\n\tnumber_of_arguments(): " << number_of_arguments() + ); + + return options[N][arg]; + } + + protected: + + option_t ( + ) : + num_present(0) + {} + + ~option_t() + { + clear(); + } + + private: + + void clear() + /*! + ensures + - #count() == 0 + - clears everything out of options and frees memory + !*/ + { + for (unsigned long i = 0; i < options.size(); ++i) + { + delete [] options[i]; + } + options.clear(); + num_present = 0; + } + + // data members + std::basic_string name_; + std::basic_string group_name_; + std::basic_string description_; + sequence2 options; + unsigned long number_of_arguments_; + unsigned long num_present; + + + + // restricted functions + option_t(option_t&); // copy constructor + option_t& operator=(option_t&); // assignment operator + }; + + // -------------------------- + + public: + + cmd_line_parser_kernel_1 ( + ); + + virtual ~cmd_line_parser_kernel_1 ( + ); + + void clear( + ); + + void parse ( + int argc, + const charT** argv + ); + + void parse ( + int argc, + charT** argv + ) + { + parse(argc, const_cast(argv)); + } + + bool parsed_line( + ) const; + + bool option_is_defined ( + const string_type& name + ) const; + + void add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments = 0 + ); + + void set_group_name ( + const string_type& group_name + ); + + string_type get_group_name ( + ) const { return group_name; } + + const cmd_line_parser_option& option ( + const string_type& name + ) const; + + unsigned long number_of_arguments( + ) const; + + const string_type& operator[] ( + unsigned long index + ) const; + + void swap ( + cmd_line_parser_kernel_1& item + ); + + // functions from the enumerable interface + bool at_start ( + ) const { return options.at_start(); } + + void reset ( + ) const { options.reset(); } + + bool current_element_valid ( + ) const { return options.current_element_valid(); } + + const cmd_line_parser_option& element ( + ) const { return *static_cast*>(options.element().value()); } + + cmd_line_parser_option& element ( + ) { return *static_cast*>(options.element().value()); } + + bool move_next ( + ) const { return options.move_next(); } + + size_t size ( + ) const { return options.size(); } + + private: + + // data members + map options; + sequence argv; + bool have_parsed_line; + string_type group_name; + + // restricted functions + cmd_line_parser_kernel_1(cmd_line_parser_kernel_1&); // copy constructor + cmd_line_parser_kernel_1& operator=(cmd_line_parser_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + inline void swap ( + cmd_line_parser_kernel_1& a, + cmd_line_parser_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + cmd_line_parser_kernel_1:: + cmd_line_parser_kernel_1 ( + ) : + have_parsed_line(false) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + cmd_line_parser_kernel_1:: + ~cmd_line_parser_kernel_1 ( + ) + { + // delete all option_t objects in options + options.reset(); + while (options.move_next()) + { + delete static_cast(options.element().value()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + clear( + ) + { + have_parsed_line = false; + argv.clear(); + + + // delete all option_t objects in options + options.reset(); + while (options.move_next()) + { + delete static_cast(options.element().value()); + } + options.clear(); + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + parse ( + int argc_, + const charT** argv + ) + { + using namespace std; + + // make sure there aren't any arguments hanging around from the last time + // parse was called + this->argv.clear(); + + // make sure that the options have been cleared of any arguments since + // the last time parse() was called + if (have_parsed_line) + { + options.reset(); + while (options.move_next()) + { + static_cast(options.element().value())->clear(); + } + options.reset(); + } + + // this tells us if we have seen -- on the command line all by itself + // or not. + bool escape = false; + + const unsigned long argc = static_cast(argc_); + try + { + + for (unsigned long i = 1; i < argc; ++i) + { + if (argv[i][0] == _dT(charT,'-') && !escape) + { + // we are looking at the start of an option + + // -------------------------------------------------------------------- + if (argv[i][1] == _dT(charT,'-')) + { + // we are looking at the start of a "long named" option + string_type temp = &argv[i][2]; + string_type first_argument; + typename string_type::size_type pos = temp.find_first_of(_dT(charT,'=')); + // This variable will be 1 if there is an argument supplied via the = sign + // and 0 otherwise. + unsigned long extra_argument = 0; + if (pos != string_type::npos) + { + // there should be an extra argument + extra_argument = 1; + first_argument = temp.substr(pos+1); + temp = temp.substr(0,pos); + } + + // make sure this name is defined + if (!options.is_in_domain(temp)) + { + // the long name is not a valid option + if (argv[i][2] == _dT(charT,'\0')) + { + // there was nothing after the -- on the command line + escape = true; + continue; + } + else + { + // there was something after the command line but it + // wasn't a valid option + throw cmd_line_parse_error(EINVALID_OPTION,temp); + } + } + + + option_t* o = static_cast(options[temp]); + + // check the number of arguments after this option and make sure + // it is correct + if (argc + extra_argument <= o->number_of_arguments() + i) + { + // there are too few arguments + throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments()); + } + if (extra_argument && first_argument.size() == 0 ) + { + // if there would be exactly the right number of arguments if + // the first_argument wasn't empty + if (argc == o->number_of_arguments() + i) + throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments()); + else + { + // in this case we just ignore the trailing = and parse everything + // the same. + extra_argument = 0; + } + } + // you can't force an option that doesn't have any arguments to take + // one by using the --option=arg syntax + if (extra_argument == 1 && o->number_of_arguments() == 0) + { + throw cmd_line_parse_error(ETOO_MANY_ARGS,temp); + } + + + + + + + // at this point we know that the option is ok and we should + // populate its options object + if (o->number_of_arguments() > 0) + { + + string_type* stemp = new string_type[o->number_of_arguments()]; + unsigned long j = 0; + + // add the argument after the = sign if one is present + if (extra_argument) + { + stemp[0] = first_argument; + ++j; + } + + for (; j < o->number_of_arguments(); ++j) + { + stemp[j] = argv[i+j+1-extra_argument]; + } + o->options.add(o->options.size(),stemp); + } + o->num_present += 1; + + + // adjust the value of i to account for the arguments to + // this option + i += o->number_of_arguments() - extra_argument; + } + // -------------------------------------------------------------------- + else + { + // we are looking at the start of a list of a single char options + + // make sure there is something in this string other than - + if (argv[i][1] == _dT(charT,'\0')) + { + throw cmd_line_parse_error(); + } + + string_type temp = &argv[i][1]; + const typename string_type::size_type num = temp.size(); + for (unsigned long k = 0; k < num; ++k) + { + string_type name; + // Doing this instead of name = temp[k] seems to avoid a bug in g++ (Ubuntu/Linaro 4.5.2-8ubuntu4) 4.5.2 + // which results in name[0] having the wrong value. + name.resize(1); + name[0] = temp[k]; + + + // make sure this name is defined + if (!options.is_in_domain(name)) + { + // the name is not a valid option + throw cmd_line_parse_error(EINVALID_OPTION,name); + } + + option_t* o = static_cast(options[name]); + + // if there are chars immediately following this option + int delta = 0; + if (num != k+1) + { + delta = 1; + } + + // check the number of arguments after this option and make sure + // it is correct + if (argc + delta <= o->number_of_arguments() + i) + { + // there are too few arguments + std::ostringstream sout; + throw cmd_line_parse_error(ETOO_FEW_ARGS,name,o->number_of_arguments()); + } + + + o->num_present += 1; + + // at this point we know that the option is ok and we should + // populate its options object + if (o->number_of_arguments() > 0) + { + string_type* stemp = new string_type[o->number_of_arguments()]; + if (delta == 1) + { + temp = &argv[i][2+k]; + k = (unsigned long)num; // this ensures that the argument to this + // option isn't going to be treated as a + // list of options + + stemp[0] = temp; + } + for (unsigned long j = 0; j < o->number_of_arguments()-delta; ++j) + { + stemp[j+delta] = argv[i+j+1]; + } + o->options.add(o->options.size(),stemp); + + // adjust the value of i to account for the arguments to + // this option + i += o->number_of_arguments()-delta; + } + } // for (unsigned long k = 0; k < num; ++k) + } + // -------------------------------------------------------------------- + + } + else + { + // this is just a normal argument + string_type temp = argv[i]; + this->argv.add(this->argv.size(),temp); + } + + } + have_parsed_line = true; + + } + catch (...) + { + have_parsed_line = false; + + // clear all the option objects + options.reset(); + while (options.move_next()) + { + static_cast(options.element().value())->clear(); + } + options.reset(); + + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + bool cmd_line_parser_kernel_1:: + parsed_line( + ) const + { + return have_parsed_line; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + bool cmd_line_parser_kernel_1:: + option_is_defined ( + const string_type& name + ) const + { + return options.is_in_domain(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + set_group_name ( + const string_type& group_name_ + ) + { + group_name = group_name_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments + ) + { + option_t* temp = new option_t; + try + { + temp->name_ = name; + temp->group_name_ = group_name; + temp->description_ = description; + temp->number_of_arguments_ = number_of_arguments; + void* t = temp; + string_type n(name); + options.add(n,t); + }catch (...) { delete temp; throw;} + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + const cmd_line_parser_option& cmd_line_parser_kernel_1:: + option ( + const string_type& name + ) const + { + return *static_cast*>(options[name]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + unsigned long cmd_line_parser_kernel_1:: + number_of_arguments( + ) const + { + return argv.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + const std::basic_string& cmd_line_parser_kernel_1:: + operator[] ( + unsigned long index + ) const + { + return argv[index]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + swap ( + cmd_line_parser_kernel_1& item + ) + { + options.swap(item.options); + argv.swap(item.argv); + exchange(have_parsed_line,item.have_parsed_line); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_1_ + diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h new file mode 100644 index 000000000..8461ffb26 --- /dev/null +++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h @@ -0,0 +1,673 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ +#ifdef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include +#include "../interfaces/enumerable.h" +#include "../interfaces/cmd_line_parser_option.h" +#include +#include + +namespace dlib +{ + + template < + typename charT + > + class cmd_line_parser : public enumerable > + { + /*! + REQUIREMENTS ON charT + Must be an integral type suitable for storing characters. (e.g. char + or wchar_t) + + INITIAL VALUE + - parsed_line() == false + - option_is_defined(x) == false, for all values of x + - get_group_name() == "" + + ENUMERATION ORDER + The enumerator will enumerate over all the options defined in *this + in alphebetical order according to the name of the option. + + POINTERS AND REFERENCES TO INTERNAL DATA + parsed_line(), option_is_defined(), option(), number_of_arguments(), + operator[](), and swap() functions do not invalidate pointers or + references to internal data. All other functions have no such guarantee. + + + WHAT THIS OBJECT REPRESENTS + This object represents a command line parser. + The command lines must match the following BNF. + + command_line ::= { | } [ -- {} ] + program_name ::= + arg ::= any that does not start with - + option_arg ::= + option_name ::= + long_option_name ::= { | - } + options ::= - {} {} | + -- [=] { } + char ::= any character other than - or = + word ::= any string from argv where argv is the second + parameter to main() + sword ::= any suffix of a string from argv where argv is the + second parameter to main() + bword ::= This is an empty string which denotes the begining of a + . + + + Options with arguments: + An option with N arguments will consider the next N swords to be + its arguments. + + so for example, if we have an option o that expects 2 arguments + then the following are a few legal examples: + + program -o arg1 arg2 general_argument + program -oarg1 arg2 general_argument + + arg1 and arg2 are associated with the option o and general_argument + is not. + + Arguments not associated with an option: + An argument that is not associated with an option is considered a + general command line argument and is indexed by operator[] defined + by the cmd_line_parser object. Additionally, if the string + "--" appears in the command line all by itself then all words + following it are considered to be general command line arguments. + + + Consider the following two examples involving a command line and + a cmd_line_parser object called parser. + + Example 1: + command line: program general_arg1 -o arg1 arg2 general_arg2 + Then the following is true (assuming the o option is defined + and takes 2 arguments). + + parser[0] == "general_arg1" + parser[1] == "general_arg2" + parser.number_of_arguments() == 2 + parser.option("o").argument(0) == "arg1" + parser.option("o").argument(1) == "arg2" + parser.option("o").count() == 1 + + Example 2: + command line: program general_arg1 -- -o arg1 arg2 general_arg2 + Then the following is true (the -- causes everything following + it to be treated as a general argument). + + parser[0] == "general_arg1" + parser[1] == "-o" + parser[2] == "arg1" + parser[3] == "arg2" + parser[4] == "general_arg2" + parser.number_of_arguments() == 5 + parser.option("o").count() == 0 + !*/ + + public: + + typedef charT char_type; + typedef std::basic_string string_type; + typedef cmd_line_parser_option option_type; + + // exception class + class cmd_line_parse_error : public dlib::error + { + /*! + GENERAL + This exception is thrown if there is an error detected in a + command line while it is being parsed. You can consult this + object's type and item members to determine the nature of the + error. (note that the type member is inherited from dlib::error). + + INTERPRETING THIS EXCEPTION + - if (type == EINVALID_OPTION) then + - There was an undefined option on the command line + - item == The invalid option that was on the command line + - if (type == ETOO_FEW_ARGS) then + - An option was given on the command line but it was not + supplied with the required number of arguments. + - item == The name of this option. + - num == The number of arguments expected by this option. + - if (type == ETOO_MANY_ARGS) then + - An option was given on the command line such as --option=arg + but this option doesn't take any arguments. + - item == The name of this option. + !*/ + public: + const std::basic_string item; + const unsigned long num; + }; + + // -------------------------- + + cmd_line_parser ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~cmd_line_parser ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void parse ( + int argc, + const charT** argv + ); + /*! + requires + - argv == an array of strings that was obtained from the second argument + of the function main(). + (i.e. argv[0] should be the token, argv[1] should be + an or token, etc.) + - argc == the number of strings in argv + ensures + - parses the command line given by argc and argv + - #parsed_line() == true + - #at_start() == true + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable until clear() + is called successfully + - cmd_line_parse_error + This exception is thrown if there is an error parsing the command line. + If this exception is thrown then #parsed_line() == false and all + options will have their count() set to 0 but otherwise there will + be no effect (i.e. all registered options will remain registered). + !*/ + + void parse ( + int argc, + charT** argv + ); + /*! + This just calls this->parse(argc,argv) and performs the necessary const_cast + on argv. + !*/ + + bool parsed_line( + ) const; + /*! + ensures + - returns true if parse() has been called successfully + - returns false otherwise + !*/ + + bool option_is_defined ( + const string_type& name + ) const; + /*! + ensures + - returns true if the option has been added to the parser object + by calling add_option(name). + - returns false otherwise + !*/ + + void add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments = 0 + ); + /*! + requires + - parsed_line() == false + - option_is_defined(name) == false + - name does not contain any ' ', '\t', '\n', or '=' characters + - name[0] != '-' + - name.size() > 0 + ensures + - #option_is_defined(name) == true + - #at_start() == true + - #option(name).count() == 0 + - #option(name).description() == description + - #option(name).number_of_arguments() == number_of_arguments + - #option(name).group_name() == get_group_name() + throws + - std::bad_alloc + if this exception is thrown then the add_option() function has no + effect + !*/ + + const option_type& option ( + const string_type& name + ) const; + /*! + requires + - option_is_defined(name) == true + ensures + - returns the option specified by name + !*/ + + unsigned long number_of_arguments( + ) const; + /*! + requires + - parsed_line() == true + ensures + - returns the number of arguments present in the command line. + This count does not include options or their arguments. Only + arguments unrelated to any option are counted. + !*/ + + const string_type& operator[] ( + unsigned long N + ) const; + /*! + requires + - parsed_line() == true + - N < number_of_arguments() + ensures + - returns the Nth command line argument + !*/ + + void swap ( + cmd_line_parser& item + ); + /*! + ensures + - swaps *this and item + !*/ + + void print_options ( + std::basic_ostream& out + ) const; + /*! + ensures + - prints all the command line options to out. + - #at_start() == true + throws + - any exception. + if an exception is thrown then #at_start() == true but otherwise + it will have no effect on the state of #*this. + !*/ + + void print_options ( + ) const; + /*! + ensures + - prints all the command line options to cout. + - #at_start() == true + throws + - any exception. + if an exception is thrown then #at_start() == true but otherwise + it will have no effect on the state of #*this. + !*/ + + string_type get_group_name ( + ) const; + /*! + ensures + - returns the current group name. This is the group new options will be + added into when added via add_option(). + - The group name of an option is used by print_options(). In particular, + it groups all options with the same group name together and displays them + under a title containing the text of the group name. This allows you to + group similar options together in the output of print_options(). + - A group name of "" (i.e. the empty string) means that no group name is + set. + !*/ + + void set_group_name ( + const string_type& group_name + ); + /*! + ensures + - #get_group_name() == group_name + !*/ + + // ------------------------------------------------------------- + // Input Validation Tools + // ------------------------------------------------------------- + + class cmd_line_check_error : public dlib::error + { + /*! + This is the exception thrown by the check_*() routines if they find a + command line error. The interpretation of the member variables is defined + below in each check_*() routine. + !*/ + + public: + const string_type opt; + const string_type opt2; + const string_type arg; + const std::vector required_opts; + }; + + template < + typename T + > + void check_option_arg_type ( + const string_type& option_name + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + - T is not a pointer type + ensures + - all the arguments for the given option are convertible + by string_cast() to an object of type T. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + typename T + > + void check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + - first <= last + - T is not a pointer type + ensures + - all the arguments for the given option are convertible + by string_cast() to an object of type T and the resulting value is + in the range first to last inclusive. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + typename T, + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + - T is not a pointer type + ensures + - for each argument to the given option: + - this argument is convertible by string_cast() to an object of + type T and the resulting value is equal to some element in the + arg_set array. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + ensures + - for each argument to the given option: + - there is a string in the arg_set array that is equal to this argument. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + size_t length + > + void check_one_time_options ( + const char_type* (&option_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - for all valid i: + - option_is_defined(option_set[i]) == true + ensures + - all the options in the option_set array occur at most once on the + command line. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMULTIPLE_OCCURANCES + - opt == the option that occurred more than once on the command line. + !*/ + + void check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name1) == true + - option_is_defined(option_name2) == true + ensures + - option(option_name1).count() == 0 || option(option_name2).count() == 0 + (i.e. at most, only one of the options is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINCOMPATIBLE_OPTIONS + - opt == option_name1 + - opt2 == option_name2 + !*/ + + template < + size_t length + > + void check_incompatible_options ( + const char_type* (&option_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - for all valid i: + - option_is_defined(option_set[i]) == true + ensures + - At most only one of the options in the array option_set has a count() + greater than 0. (i.e. at most, only one of the options is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINCOMPATIBLE_OPTIONS + - opt == One of the incompatible options found. + - opt2 == The next incompatible option found. + !*/ + + void check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(parent_option) == true + - option_is_defined(sub_option) == true + ensures + - if (option(parent_option).count() == 0) then + - option(sub_option).count() == 0 + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == sub_option. + - required_opts == a vector that contains only parent_option. + !*/ + + template < + size_t length + > + void check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(sub_option) == true + - for all valid i: + - option_is_defined(parent_option_set[i] == true + ensures + - if (option(sub_option).count() > 0) then + - At least one of the options in the array parent_option_set has a count() + greater than 0. (i.e. at least one of the options in parent_option_set + is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == the first option from the sub_option that is present. + - required_opts == a vector containing everything from parent_option_set. + !*/ + + template < + size_t length + > + void check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(parent_option) == true + - for all valid i: + - option_is_defined(sub_option_set[i]) == true + ensures + - if (option(parent_option).count() == 0) then + - for all valid i: + - option(sub_option_set[i]).count() == 0 + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == the first option from the sub_option_set that is present. + - required_opts == a vector that contains only parent_option. + !*/ + + template < + size_t parent_length, + size_t sub_length + > + void check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const; + /*! + requires + - parsed_line() == true + - for all valid i: + - option_is_defined(parent_option_set[i] == true + - for all valid j: + - option_is_defined(sub_option_set[j]) == true + ensures + - for all valid j: + - if (option(sub_option_set[j]).count() > 0) then + - At least one of the options in the array parent_option_set has a count() + greater than 0. (i.e. at least one of the options in parent_option_set + is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == the first option from the sub_option_set that is present. + - required_opts == a vector containing everything from parent_option_set. + !*/ + + + private: + + // restricted functions + cmd_line_parser(cmd_line_parser&); // copy constructor + cmd_line_parser& operator=(cmd_line_parser&); // assignment operator + + }; + +// ----------------------------------------------------------------------------------------- + + typedef cmd_line_parser command_line_parser; + typedef cmd_line_parser wcommand_line_parser; + +// ----------------------------------------------------------------------------------------- + + template < + typename charT + > + inline void swap ( + cmd_line_parser& a, + cmd_line_parser& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ----------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h new file mode 100644 index 000000000..e80543018 --- /dev/null +++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h @@ -0,0 +1,203 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_KERNEl_C_ +#define DLIB_CMD_LINE_PARSER_KERNEl_C_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include +#include "../interfaces/cmd_line_parser_option.h" +#include "../string.h" + +namespace dlib +{ + + template < + typename clp_base + > + class cmd_line_parser_kernel_c : public clp_base + { + public: + + typedef typename clp_base::char_type char_type; + typedef typename clp_base::string_type string_type; + typedef typename clp_base::option_type option_type; + + void add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments = 0 + ); + + const option_type& option ( + const string_type& name + ) const; + + unsigned long number_of_arguments( + ) const; + + const option_type& element ( + ) const; + + option_type& element ( + ); + + const string_type& operator[] ( + unsigned long N + ) const; + + }; + + + template < + typename clp_base + > + inline void swap ( + cmd_line_parser_kernel_c& a, + cmd_line_parser_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + const typename clp_base::string_type& cmd_line_parser_kernel_c:: + operator[] ( + unsigned long N + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && N < number_of_arguments(), + "\tvoid cmd_line_parser::operator[](unsigned long N)" + << "\n\tYou must specify a valid index N and the parser must have run already." + << "\n\tthis: " << this + << "\n\tN: " << N + << "\n\tparsed_line(): " << this->parsed_line() + << "\n\tnumber_of_arguments(): " << number_of_arguments() + ); + + return clp_base::operator[](N); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + void cmd_line_parser_kernel_c:: + add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == false && + name.size() > 0 && + this->option_is_defined(name) == false && + name.find_first_of(_dT(char_type," \t\n=")) == string_type::npos && + name[0] != '-', + "\tvoid cmd_line_parser::add_option(const string_type&,const string_type&,unsigned long)" + << "\n\tsee the requires clause of add_option()" + << "\n\tthis: " << this + << "\n\tname.size(): " << static_cast(name.size()) + << "\n\tname: \"" << narrow(name) << "\"" + << "\n\tparsed_line(): " << (this->parsed_line()? "true" : "false") + << "\n\tis_option_defined(\"" << narrow(name) << "\"): " << (this->option_is_defined(name)? "true" : "false") + ); + + clp_base::add_option(name,description,number_of_arguments); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + const typename clp_base::option_type& cmd_line_parser_kernel_c:: + option ( + const string_type& name + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->option_is_defined(name) == true, + "\toption cmd_line_parser::option(const string_type&)" + << "\n\tto get an option it must be defined by a call to add_option()" + << "\n\tthis: " << this + << "\n\tname: \"" << narrow(name) << "\"" + ); + + return clp_base::option(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + unsigned long cmd_line_parser_kernel_c:: + number_of_arguments( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true , + "\tunsigned long cmd_line_parser::number_of_arguments()" + << "\n\tyou must parse the command line before you can find out how many arguments it has" + << "\n\tthis: " << this + ); + + return clp_base::number_of_arguments(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + const typename clp_base::option_type& cmd_line_parser_kernel_c:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst cmd_line_parser_option& cmd_line_parser::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return clp_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + typename clp_base::option_type& cmd_line_parser_kernel_c:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tcmd_line_parser_option& cmd_line_parser::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return clp_base::element(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_C_ + diff --git a/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_print_1.h b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_print_1.h new file mode 100644 index 000000000..3f52c842f --- /dev/null +++ b/ml/dlib/dlib/cmd_line_parser/cmd_line_parser_print_1.h @@ -0,0 +1,205 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_PRINt_1_ +#define DLIB_CMD_LINE_PARSER_PRINt_1_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include "../string.h" +#include +#include +#include +#include +#include + +namespace dlib +{ + + template < + typename clp_base + > + class cmd_line_parser_print_1 : public clp_base + { + + public: + + void print_options ( + std::basic_ostream& out + ) const; + + void print_options ( + ) const + { + print_options(std::cout); + } + + }; + + template < + typename clp_base + > + inline void swap ( + cmd_line_parser_print_1& a, + cmd_line_parser_print_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + void cmd_line_parser_print_1:: + print_options ( + std::basic_ostream& out + ) const + { + typedef typename clp_base::char_type ct; + typedef std::basic_string string; + typedef typename string::size_type size_type; + + typedef std::basic_ostringstream ostringstream; + + try + { + + + size_type max_len = 0; + this->reset(); + + // this loop here is just the bottom loop but without the print statements. + // I'm doing this to figure out what len should be. + while (this->move_next()) + { + size_type len = 0; + len += 3; + if (this->element().name().size() > 1) + { + ++len; + } + len += this->element().name().size(); + + if (this->element().number_of_arguments() == 1) + { + len += 6; + } + else + { + for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i) + { + len += 7; + if (i+1 > 9) + ++len; + } + } + + len += 3; + if (len < 33) + max_len = std::max(max_len,len); + } + + + // Make a separate ostringstream for each option group. We are going to write + // the output for each group to a separate ostringstream so that we can keep + // them grouped together in the final output. + std::map > groups; + this->reset(); + while(this->move_next()) + { + if (!groups[this->element().group_name()]) + groups[this->element().group_name()].reset(new ostringstream); + } + + + + + this->reset(); + + while (this->move_next()) + { + ostringstream& sout = *groups[this->element().group_name()]; + + size_type len = 0; + sout << _dT(ct,"\n -"); + len += 3; + if (this->element().name().size() > 1) + { + sout << _dT(ct,"-"); + ++len; + } + sout << this->element().name(); + len += this->element().name().size(); + + if (this->element().number_of_arguments() == 1) + { + sout << _dT(ct," "); + len += 6; + } + else + { + for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i) + { + sout << _dT(ct," "); + len += 7; + if (i+1 > 9) + ++len; + } + } + + sout << _dT(ct," "); + len += 3; + + while (len < max_len) + { + ++len; + sout << _dT(ct," "); + } + + const unsigned long ml = static_cast(max_len); + // now print the description but make it wrap around nicely if it + // is to long to fit on one line. + if (len <= max_len) + sout << wrap_string(this->element().description(),0,ml); + else + sout << _dT(ct,"\n") << wrap_string(this->element().description(),ml,ml); + } + + // Only print out a generic Options: group name if there is an unnamed option + // present. + if (groups.count(string()) == 1) + out << _dT(ct,"Options:"); + + // Now print everything out + typename std::map >::iterator i; + for (i = groups.begin(); i != groups.end(); ++i) + { + // print the group name if we have one + if (i->first.size() != 0) + { + if (i != groups.begin()) + out << _dT(ct,"\n\n"); + out << i->first << _dT(ct,":"); + } + + // print the options in the group + out << i->second->str(); + } + out << _dT(ct,"\n\n"); + this->reset(); + } + catch (...) + { + this->reset(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_PRINt_1_ + diff --git a/ml/dlib/dlib/cmd_line_parser/get_option.h b/ml/dlib/dlib/cmd_line_parser/get_option.h new file mode 100644 index 000000000..2c8d1644f --- /dev/null +++ b/ml/dlib/dlib/cmd_line_parser/get_option.h @@ -0,0 +1,181 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GET_OPTiON_Hh_ +#define DLIB_GET_OPTiON_Hh_ + +#include "get_option_abstract.h" +#include "../string.h" +#include "../is_kind.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class option_parse_error : public error + { + public: + option_parse_error(const std::string& option_string, const std::string& str): + error(EOPTION_PARSE,"Error parsing argument for option '" + option_string + "', offending string is '" + str + "'.") {} + }; + +// ---------------------------------------------------------------------------------------- + + template + T impl_config_reader_get_option ( + const config_reader_type& cr, + const std::string& option_name, + const std::string& full_option_name, + T default_value + ) + { + std::string::size_type pos = option_name.find_first_of("."); + if (pos == std::string::npos) + { + if (cr.is_key_defined(option_name)) + { + try{ return string_cast(cr[option_name]); } + catch (string_cast_error&) { throw option_parse_error(full_option_name, cr[option_name]); } + } + } + else + { + std::string block_name = option_name.substr(0,pos); + if (cr.is_block_defined(block_name)) + { + return impl_config_reader_get_option(cr.block(block_name), + option_name.substr(pos+1), + full_option_name, + default_value); + } + } + + return default_value; + } + +// ---------------------------------------------------------------------------------------- + + template + typename enable_if,T>::type get_option ( + const cr_type& cr, + const std::string& option_name, + T default_value + ) + { + return impl_config_reader_get_option(cr, option_name, option_name, default_value); + } + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,T>::type get_option ( + const parser_type& parser, + const std::string& option_name, + T default_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( parser.option_is_defined(option_name) == true && + parser.option(option_name).number_of_arguments() == 1, + "\t T get_option()" + << "\n\t option_name: " << option_name + << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) + << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() + ); + + if (parser.option(option_name)) + { + try + { + default_value = string_cast(parser.option(option_name).argument()); + } + catch (string_cast_error&) + { + throw option_parse_error(option_name, parser.option(option_name).argument()); + } + } + return default_value; + } + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,T>::type get_option ( + const parser_type& parser, + const cr_type& cr, + const std::string& option_name, + T default_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( parser.option_is_defined(option_name) == true && + parser.option(option_name).number_of_arguments() == 1, + "\t T get_option()" + << "\n\t option_name: " << option_name + << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) + << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() + ); + + if (parser.option(option_name)) + return get_option(parser, option_name, default_value); + else + return get_option(cr, option_name, default_value); + } + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,T>::type get_option ( + const cr_type& cr, + const parser_type& parser, + const std::string& option_name, + T default_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( parser.option_is_defined(option_name) == true && + parser.option(option_name).number_of_arguments() == 1, + "\t T get_option()" + << "\n\t option_name: " << option_name + << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) + << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() + ); + + if (parser.option(option_name)) + return get_option(parser, option_name, default_value); + else + return get_option(cr, option_name, default_value); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + inline std::string get_option ( + const T& cr, + const std::string& option_name, + const char* default_value + ) + { + return get_option(cr, option_name, std::string(default_value)); + } + +// ---------------------------------------------------------------------------------------- + + template + inline std::string get_option ( + const T& parser, + const U& cr, + const std::string& option_name, + const char* default_value + ) + { + return get_option(parser, cr, option_name, std::string(default_value)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GET_OPTiON_Hh_ + diff --git a/ml/dlib/dlib/cmd_line_parser/get_option_abstract.h b/ml/dlib/dlib/cmd_line_parser/get_option_abstract.h new file mode 100644 index 000000000..90dc16721 --- /dev/null +++ b/ml/dlib/dlib/cmd_line_parser/get_option_abstract.h @@ -0,0 +1,146 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GET_OPTiON_ABSTRACT_Hh_ +#ifdef DLIB_GET_OPTiON_ABSTRACT_Hh_ + +#inclue + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class option_parse_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown by the get_option() functions. It is + thrown when the option string given by a command line parser or + config reader can't be converted into the type T. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_type, + typename T + > + T get_option ( + const config_reader_type& cr, + const std::string& option_name, + T default_value + ); + /*! + requires + - T is a type which can be read from an input stream + - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h + ensures + - option_name is used to index into the given config_reader. + - if (cr contains an entry corresponding to option_name) then + - converts the string value in cr corresponding to option_name into + an object of type T and returns it. + - else + - returns default_value + - The scheme for indexing into cr based on option_name is best + understood by looking at a few examples: + - an option name of "name" corresponds to cr["name"] + - an option name of "block1.name" corresponds to cr.block("block1")["name"] + - an option name of "block1.block2.name" corresponds to cr.block("block1").block("block2")["name"] + throws + - option_parse_error + This exception is thrown if we attempt but fail to convert the string value + in cr into an object of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename command_line_parser_type, + typename T + > + T get_option ( + const command_line_parser_type& parser, + const std::string& option_name, + T default_value + ); + /*! + requires + - parser.option_is_defined(option_name) == true + - parser.option(option_name).number_of_arguments() == 1 + - T is a type which can be read from an input stream + - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h + ensures + - if (parser.option(option_name)) then + - converts parser.option(option_name).argument() into an object + of type T and returns it. That is, the string argument to this + command line option is converted into a T and returned. + - else + - returns default_value + throws + - option_parse_error + This exception is thrown if we attempt but fail to convert the string + argument into an object of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename command_line_parser_type, + typename config_reader_type, + typename T + > + T get_option ( + const command_line_parser_type& parser, + const config_reader_type& cr, + const std::string& option_name, + T default_value + ); + /*! + requires + - parser.option_is_defined(option_name) == true + - parser.option(option_name).number_of_arguments() == 1 + - T is a type which can be read from an input stream + - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h + - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h + ensures + - if (parser.option(option_name)) then + - returns get_option(parser, option_name, default_value) + - else + - returns get_option(cr, option_name, default_value) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename command_line_parser_type, + typename config_reader_type, + typename T + > + T get_option ( + const config_reader_type& cr, + const command_line_parser_type& parser, + const std::string& option_name, + T default_value + ); + /*! + requires + - parser.option_is_defined(option_name) == true + - parser.option(option_name).number_of_arguments() == 1 + - T is a type which can be read from an input stream + - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h + - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h + ensures + - if (parser.option(option_name)) then + - returns get_option(parser, option_name, default_value) + - else + - returns get_option(cr, option_name, default_value) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GET_OPTiON_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/compress_stream.h b/ml/dlib/dlib/compress_stream.h new file mode 100644 index 000000000..8ccc1d52f --- /dev/null +++ b/ml/dlib/dlib/compress_stream.h @@ -0,0 +1,133 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAm_ +#define DLIB_COMPRESS_STREAm_ + +#include "compress_stream/compress_stream_kernel_1.h" +#include "compress_stream/compress_stream_kernel_2.h" +#include "compress_stream/compress_stream_kernel_3.h" + +#include "conditioning_class.h" +#include "entropy_encoder.h" +#include "entropy_decoder.h" + +#include "entropy_encoder_model.h" +#include "entropy_decoder_model.h" +#include "lz77_buffer.h" +#include "sliding_buffer.h" +#include "lzp_buffer.h" +#include "crc32.h" + + +namespace dlib +{ + + class compress_stream + { + compress_stream() {} + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_1b fce1; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_1b fcd1; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2b fce2; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2b fcd2; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_3b fce3; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_3b fcd3; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4a fce4a; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4a fcd4a; + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4b fce4b; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4b fcd4b; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5a fce5a; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5a fcd5a; + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5b fce5b; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5b fcd5b; + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5c fce5c; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5c fcd5c; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_6a fce6; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_6a fcd6; + + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2d fce2d; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2d fcd2d; + + typedef sliding_buffer::kernel_1a sliding_buffer1; + typedef lz77_buffer::kernel_2a lz77_buffer2a; + + + typedef lzp_buffer::kernel_1a lzp_buf_1; + typedef lzp_buffer::kernel_2a lzp_buf_2; + + + typedef entropy_encoder_model<513,entropy_encoder::kernel_2a>::kernel_1b fce_length; + typedef entropy_decoder_model<513,entropy_decoder::kernel_2a>::kernel_1b fcd_length; + + typedef entropy_encoder_model<65534,entropy_encoder::kernel_2a>::kernel_1b fce_length_2; + typedef entropy_decoder_model<65534,entropy_decoder::kernel_2a>::kernel_1b fcd_length_2; + + + typedef entropy_encoder_model<32257,entropy_encoder::kernel_2a>::kernel_1b fce_index; + typedef entropy_decoder_model<32257,entropy_decoder::kernel_2a>::kernel_1b fcd_index; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef compress_stream_kernel_1 + kernel_1a; + + // kernel_1b + typedef compress_stream_kernel_1 + kernel_1b; + + // kernel_1c + typedef compress_stream_kernel_1 + kernel_1c; + + // kernel_1da + typedef compress_stream_kernel_1 + kernel_1da; + + // kernel_1ea + typedef compress_stream_kernel_1 + kernel_1ea; + + // kernel_1db + typedef compress_stream_kernel_1 + kernel_1db; + + // kernel_1eb + typedef compress_stream_kernel_1 + kernel_1eb; + + // kernel_1ec + typedef compress_stream_kernel_1 + kernel_1ec; + + + + + // kernel_2a + typedef compress_stream_kernel_2 + kernel_2a; + + + + + // kernel_3a + typedef compress_stream_kernel_3 + kernel_3a; + // kernel_3b + typedef compress_stream_kernel_3 + kernel_3b; + + + }; +} + +#endif // DLIB_COMPRESS_STREAm_ + diff --git a/ml/dlib/dlib/compress_stream/compress_stream_kernel_1.h b/ml/dlib/dlib/compress_stream/compress_stream_kernel_1.h new file mode 100644 index 000000000..1a75ec6ce --- /dev/null +++ b/ml/dlib/dlib/compress_stream/compress_stream_kernel_1.h @@ -0,0 +1,252 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAM_KERNEl_1_ +#define DLIB_COMPRESS_STREAM_KERNEl_1_ + +#include "../algs.h" +#include +#include +#include +#include "compress_stream_kernel_abstract.h" + +namespace dlib +{ + + template < + typename fce, + typename fcd, + typename crc32 + > + class compress_stream_kernel_1 + { + /*! + REQUIREMENTS ON fce + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON fcd + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON crc32 + is an implementation of crc32/crc32_kernel_abstract.h + + + + INITIAL VALUE + this object has no state + + CONVENTION + this object has no state + !*/ + + const static unsigned long eof_symbol = 256; + + public: + + class decompression_error : public dlib::error + { + public: + decompression_error( + const char* i + ) : + dlib::error(std::string(i)) + {} + + decompression_error( + const std::string& i + ) : + dlib::error(i) + {} + }; + + + compress_stream_kernel_1 ( + ) + {} + + ~compress_stream_kernel_1 ( + ) + {} + + void compress ( + std::istream& in, + std::ostream& out + ) const; + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + + private: + + // restricted functions + compress_stream_kernel_1(compress_stream_kernel_1&); // copy constructor + compress_stream_kernel_1& operator=(compress_stream_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename crc32 + > + void compress_stream_kernel_1:: + compress ( + std::istream& in_, + std::ostream& out_ + ) const + { + std::streambuf::int_type temp; + + std::streambuf& in = *in_.rdbuf(); + + typename fce::entropy_encoder_type coder; + coder.set_stream(out_); + + fce model(coder); + + crc32 crc; + + unsigned long count = 0; + + while (true) + { + // write out a known value every 20000 symbols + if (count == 20000) + { + count = 0; + coder.encode(1500,1501,8000); + } + ++count; + + // get the next character + temp = in.sbumpc(); + + // if we have hit EOF then encode the marker symbol + if (temp != EOF) + { + // encode the symbol + model.encode(static_cast(temp)); + crc.add(static_cast(temp)); + continue; + } + else + { + model.encode(eof_symbol); + + // now write the checksum + unsigned long checksum = crc.get_checksum(); + unsigned char byte1 = static_cast((checksum>>24)&0xFF); + unsigned char byte2 = static_cast((checksum>>16)&0xFF); + unsigned char byte3 = static_cast((checksum>>8)&0xFF); + unsigned char byte4 = static_cast((checksum)&0xFF); + + model.encode(byte1); + model.encode(byte2); + model.encode(byte3); + model.encode(byte4); + + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename crc32 + > + void compress_stream_kernel_1:: + decompress ( + std::istream& in_, + std::ostream& out_ + ) const + { + + std::streambuf& out = *out_.rdbuf(); + + typename fcd::entropy_decoder_type coder; + coder.set_stream(in_); + + fcd model(coder); + + unsigned long symbol; + unsigned long count = 0; + + crc32 crc; + + // decode until we hit the marker symbol + while (true) + { + // make sure this is the value we expect + if (count == 20000) + { + if (coder.get_target(8000) != 1500) + { + throw decompression_error("Error detected in compressed data stream."); + } + count = 0; + coder.decode(1500,1501); + } + ++count; + + // decode the next symbol + model.decode(symbol); + if (symbol != eof_symbol) + { + crc.add(static_cast(symbol)); + // write this symbol to out + if (out.sputc(static_cast(symbol)) != static_cast(symbol)) + { + throw std::ios::failure("error occurred in compress_stream_kernel_1::decompress"); + } + continue; + } + else + { + // we read eof from the encoded data. now we just have to check the checksum and we are done. + unsigned char byte1; + unsigned char byte2; + unsigned char byte3; + unsigned char byte4; + + model.decode(symbol); byte1 = static_cast(symbol); + model.decode(symbol); byte2 = static_cast(symbol); + model.decode(symbol); byte3 = static_cast(symbol); + model.decode(symbol); byte4 = static_cast(symbol); + + unsigned long checksum = byte1; + checksum <<= 8; + checksum |= byte2; + checksum <<= 8; + checksum |= byte3; + checksum <<= 8; + checksum |= byte4; + + if (checksum != crc.get_checksum()) + throw decompression_error("Error detected in compressed data stream."); + + break; + } + } // while (true) + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_1_ + diff --git a/ml/dlib/dlib/compress_stream/compress_stream_kernel_2.h b/ml/dlib/dlib/compress_stream/compress_stream_kernel_2.h new file mode 100644 index 000000000..e46b23fad --- /dev/null +++ b/ml/dlib/dlib/compress_stream/compress_stream_kernel_2.h @@ -0,0 +1,431 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAM_KERNEl_2_ +#define DLIB_COMPRESS_STREAM_KERNEl_2_ + +#include "../algs.h" +#include +#include +#include "compress_stream_kernel_abstract.h" + +namespace dlib +{ + + template < + typename fce, + typename fcd, + typename lz77_buffer, + typename sliding_buffer, + typename fce_length, + typename fcd_length, + typename fce_index, + typename fcd_index, + typename crc32 + > + class compress_stream_kernel_2 + { + /*! + REQUIREMENTS ON fce + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON fcd + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON lz77_buffer + is an implementation of lz77_buffer/lz77_buffer_kernel_abstract.h + + REQUIREMENTS ON sliding_buffer + is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h + is instantiated with T = unsigned char + + REQUIREMENTS ON fce_length + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 513. This will be used to encode the length of lz77 matches. + fce_length and fcd share the same kernel number. + + REQUIREMENTS ON fcd_length + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 513. This will be used to decode the length of lz77 matches. + fce_length and fcd share the same kernel number. + + REQUIREMENTS ON fce_index + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 32257. This will be used to encode the index of lz77 matches. + fce_index and fcd share the same kernel number. + + REQUIREMENTS ON fcd_index + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 32257. This will be used to decode the index of lz77 matches. + fce_index and fcd share the same kernel number. + + REQUIREMENTS ON crc32 + is an implementation of crc32/crc32_kernel_abstract.h + + INITIAL VALUE + this object has no state + + CONVENTION + this object has no state + !*/ + + const static unsigned long eof_symbol = 256; + + public: + + class decompression_error : public dlib::error + { + public: + decompression_error( + const char* i + ) : + dlib::error(std::string(i)) + {} + + decompression_error( + const std::string& i + ) : + dlib::error(i) + {} + }; + + + compress_stream_kernel_2 ( + ) + {} + + ~compress_stream_kernel_2 ( + ) + {} + + void compress ( + std::istream& in, + std::ostream& out + ) const; + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + + private: + + // restricted functions + compress_stream_kernel_2(compress_stream_kernel_2&); // copy constructor + compress_stream_kernel_2& operator=(compress_stream_kernel_2&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename lz77_buffer, + typename sliding_buffer, + typename fce_length, + typename fcd_length, + typename fce_index, + typename fcd_index, + typename crc32 + > + void compress_stream_kernel_2:: + compress ( + std::istream& in_, + std::ostream& out_ + ) const + { + std::streambuf::int_type temp; + + std::streambuf& in = *in_.rdbuf(); + + typename fce::entropy_encoder_type coder; + coder.set_stream(out_); + + fce model(coder); + fce_length model_length(coder); + fce_index model_index(coder); + + const unsigned long LOOKAHEAD_LIMIT = 512; + lz77_buffer buffer(15,LOOKAHEAD_LIMIT); + + crc32 crc; + + + unsigned long count = 0; + + unsigned long lz77_count = 1; // number of times we used lz77 to encode + unsigned long ppm_count = 1; // number of times we used ppm to encode + + + while (true) + { + // write out a known value every 20000 symbols + if (count == 20000) + { + count = 0; + coder.encode(150,151,400); + } + ++count; + + // try to fill the lookahead buffer + if (buffer.get_lookahead_buffer_size() < buffer.get_lookahead_buffer_limit()) + { + temp = in.sbumpc(); + while (temp != EOF) + { + crc.add(static_cast(temp)); + buffer.add(static_cast(temp)); + if (buffer.get_lookahead_buffer_size() == buffer.get_lookahead_buffer_limit()) + break; + temp = in.sbumpc(); + } + } + + // compute the sum of ppm_count and lz77_count but make sure + // it is less than 65536 + unsigned long sum = ppm_count + lz77_count; + if (sum >= 65536) + { + ppm_count >>= 1; + lz77_count >>= 1; + ppm_count |= 1; + lz77_count |= 1; + sum = ppm_count+lz77_count; + } + + // if there are still more symbols in the lookahead buffer to encode + if (buffer.get_lookahead_buffer_size() > 0) + { + unsigned long match_index, match_length; + buffer.find_match(match_index,match_length,6); + if (match_length != 0) + { + + // signal the decoder that we are using lz77 + coder.encode(0,lz77_count,sum); + ++lz77_count; + + // encode the index and length pair + model_index.encode(match_index); + model_length.encode(match_length); + + } + else + { + + // signal the decoder that we are using ppm + coder.encode(lz77_count,sum,sum); + ++ppm_count; + + // encode the symbol using the ppm model + model.encode(buffer.lookahead_buffer(0)); + buffer.shift_buffers(1); + } + } + else + { + // signal the decoder that we are using ppm + coder.encode(lz77_count,sum,sum); + + + model.encode(eof_symbol); + // now write the checksum + unsigned long checksum = crc.get_checksum(); + unsigned char byte1 = static_cast((checksum>>24)&0xFF); + unsigned char byte2 = static_cast((checksum>>16)&0xFF); + unsigned char byte3 = static_cast((checksum>>8)&0xFF); + unsigned char byte4 = static_cast((checksum)&0xFF); + + model.encode(byte1); + model.encode(byte2); + model.encode(byte3); + model.encode(byte4); + + break; + } + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename lz77_buffer, + typename sliding_buffer, + typename fce_length, + typename fcd_length, + typename fce_index, + typename fcd_index, + typename crc32 + > + void compress_stream_kernel_2:: + decompress ( + std::istream& in_, + std::ostream& out_ + ) const + { + + std::streambuf& out = *out_.rdbuf(); + + typename fcd::entropy_decoder_type coder; + coder.set_stream(in_); + + fcd model(coder); + fcd_length model_length(coder); + fcd_index model_index(coder); + + unsigned long symbol; + unsigned long count = 0; + + sliding_buffer buffer; + buffer.set_size(15); + + // Initialize the buffer to all zeros. There is no algorithmic reason to + // do this. But doing so avoids a warning from valgrind so that is why + // I'm doing this. + for (unsigned long i = 0; i < buffer.size(); ++i) + buffer[i] = 0; + + crc32 crc; + + unsigned long lz77_count = 1; // number of times we used lz77 to encode + unsigned long ppm_count = 1; // number of times we used ppm to encode + bool next_block_lz77; + + + // decode until we hit the marker symbol + while (true) + { + // make sure this is the value we expect + if (count == 20000) + { + if (coder.get_target(400) != 150) + { + throw decompression_error("Error detected in compressed data stream."); + } + count = 0; + coder.decode(150,151); + } + ++count; + + + // compute the sum of ppm_count and lz77_count but make sure + // it is less than 65536 + unsigned long sum = ppm_count + lz77_count; + if (sum >= 65536) + { + ppm_count >>= 1; + lz77_count >>= 1; + ppm_count |= 1; + lz77_count |= 1; + sum = ppm_count+lz77_count; + } + + // check if we are decoding a lz77 or ppm block + if (coder.get_target(sum) < lz77_count) + { + coder.decode(0,lz77_count); + next_block_lz77 = true; + ++lz77_count; + } + else + { + coder.decode(lz77_count,sum); + next_block_lz77 = false; + ++ppm_count; + } + + + if (next_block_lz77) + { + + unsigned long match_length, match_index; + // decode the match index + model_index.decode(match_index); + + // decode the match length + model_length.decode(match_length); + + + match_index += match_length; + buffer.rotate_left(match_length); + for (unsigned long i = 0; i < match_length; ++i) + { + unsigned char ch = buffer[match_index-i]; + buffer[match_length-i-1] = ch; + + crc.add(ch); + // write this ch to out + if (out.sputc(static_cast(ch)) != static_cast(ch)) + { + throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress"); + } + } + + } + else + { + + // decode the next symbol + model.decode(symbol); + if (symbol != eof_symbol) + { + buffer.rotate_left(1); + buffer[0] = static_cast(symbol); + + + crc.add(static_cast(symbol)); + // write this symbol to out + if (out.sputc(static_cast(symbol)) != static_cast(symbol)) + { + throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress"); + } + } + else + { + // this was the eof marker symbol so we are done. now check the checksum + + // now get the checksum and make sure it matches + unsigned char byte1; + unsigned char byte2; + unsigned char byte3; + unsigned char byte4; + + model.decode(symbol); byte1 = static_cast(symbol); + model.decode(symbol); byte2 = static_cast(symbol); + model.decode(symbol); byte3 = static_cast(symbol); + model.decode(symbol); byte4 = static_cast(symbol); + + unsigned long checksum = byte1; + checksum <<= 8; + checksum |= byte2; + checksum <<= 8; + checksum |= byte3; + checksum <<= 8; + checksum |= byte4; + + if (checksum != crc.get_checksum()) + throw decompression_error("Error detected in compressed data stream."); + + break; + } + } + + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_2_ + diff --git a/ml/dlib/dlib/compress_stream/compress_stream_kernel_3.h b/ml/dlib/dlib/compress_stream/compress_stream_kernel_3.h new file mode 100644 index 000000000..ed4eee290 --- /dev/null +++ b/ml/dlib/dlib/compress_stream/compress_stream_kernel_3.h @@ -0,0 +1,381 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAM_KERNEl_3_ +#define DLIB_COMPRESS_STREAM_KERNEl_3_ + +#include "../algs.h" +#include "compress_stream_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename lzp_buf, + typename crc32, + unsigned long buffer_size + > + class compress_stream_kernel_3 + { + /*! + REQUIREMENTS ON lzp_buf + is an implementation of lzp_buffer/lzp_buffer_kernel_abstract.h + + REQUIREMENTS ON buffer_size + 10 < buffer_size < 32 + + REQUIREMENTS ON crc32 + is an implementation of crc32/crc32_kernel_abstract.h + + + INITIAL VALUE + this object has no state + + CONVENTION + this object has no state + + + This implementation uses the lzp_buffer and writes out matches + in a byte aligned format. + + !*/ + + + public: + + class decompression_error : public dlib::error + { + public: + decompression_error( + const char* i + ) : + dlib::error(std::string(i)) + {} + + decompression_error( + const std::string& i + ) : + dlib::error(i) + {} + }; + + + compress_stream_kernel_3 ( + ) + { + COMPILE_TIME_ASSERT(10 < buffer_size && buffer_size < 32); + } + + ~compress_stream_kernel_3 ( + ) + {} + + void compress ( + std::istream& in, + std::ostream& out + ) const; + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + + + + private: + + inline void write ( + unsigned char symbol + ) const + { + if (out->sputn(reinterpret_cast(&symbol),1)==0) + throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); + } + + inline void decode ( + unsigned char& symbol, + unsigned char& flag + ) const + { + if (count == 0) + { + if (((size_t)in->sgetn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) + throw decompression_error("Error detected in compressed data stream."); + count = 8; + } + --count; + symbol = buffer[8-count]; + flag = buffer[0] >> 7; + buffer[0] <<= 1; + } + + inline void encode ( + unsigned char symbol, + unsigned char flag + ) const + /*! + requires + - 0 <= flag <= 1 + ensures + - writes symbol with the given one bit flag + !*/ + { + // add this symbol and flag to the buffer + ++count; + buffer[0] <<= 1; + buffer[count] = symbol; + buffer[0] |= flag; + + if (count == 8) + { + if (((size_t)out->sputn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) + throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); + count = 0; + buffer[0] = 0; + } + } + + void clear ( + ) const + /*! + ensures + - resets the buffers + !*/ + { + count = 0; + } + + void flush ( + ) const + /*! + ensures + - flushes any data in the buffers to out + !*/ + { + if (count != 0) + { + buffer[0] <<= (8-count); + if (((size_t)out->sputn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) + throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); + } + } + + mutable unsigned int count; + // count tells us how many bytes are buffered in buffer and how many flag + // bit are currently in buffer[0] + mutable unsigned char buffer[9]; + // buffer[0] holds the flag bits to be writen. + // the rest of the buffer holds the bytes to be writen. + + mutable std::streambuf* in; + mutable std::streambuf* out; + + // restricted functions + compress_stream_kernel_3(compress_stream_kernel_3&); // copy constructor + compress_stream_kernel_3& operator=(compress_stream_kernel_3&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename lzp_buf, + typename crc32, + unsigned long buffer_size + > + void compress_stream_kernel_3:: + compress ( + std::istream& in_, + std::ostream& out_ + ) const + { + in = in_.rdbuf(); + out = out_.rdbuf(); + clear(); + + crc32 crc; + + lzp_buf buffer(buffer_size); + + std::streambuf::int_type temp = in->sbumpc(); + unsigned long index; + unsigned char symbol; + unsigned char length; + + while (temp != EOF) + { + symbol = static_cast(temp); + if (buffer.predict_match(index)) + { + if (buffer[index] == symbol) + { + // this is a match so we must find out how long it is + length = 1; + + buffer.add(symbol); + crc.add(symbol); + + temp = in->sbumpc(); + while (length < 255) + { + if (temp == EOF) + { + break; + } + else if (static_cast(length) >= index) + { + break; + } + else if (static_cast(temp) == buffer[index]) + { + ++length; + buffer.add(static_cast(temp)); + crc.add(static_cast(temp)); + temp = in->sbumpc(); + } + else + { + break; + } + } + + encode(length,1); + } + else + { + // this is also not a match + encode(symbol,0); + buffer.add(symbol); + crc.add(symbol); + + // get the next symbol + temp = in->sbumpc(); + } + } + else + { + // there wasn't a match so just write this symbol + encode(symbol,0); + buffer.add(symbol); + crc.add(symbol); + + // get the next symbol + temp = in->sbumpc(); + } + } + + // use a match of zero length to indicate EOF + encode(0,1); + + // now write the checksum + unsigned long checksum = crc.get_checksum(); + unsigned char byte1 = static_cast((checksum>>24)&0xFF); + unsigned char byte2 = static_cast((checksum>>16)&0xFF); + unsigned char byte3 = static_cast((checksum>>8)&0xFF); + unsigned char byte4 = static_cast((checksum)&0xFF); + + encode(byte1,0); + encode(byte2,0); + encode(byte3,0); + encode(byte4,0); + + flush(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename lzp_buf, + typename crc32, + unsigned long buffer_size + > + void compress_stream_kernel_3:: + decompress ( + std::istream& in_, + std::ostream& out_ + ) const + { + in = in_.rdbuf(); + out = out_.rdbuf(); + clear(); + + crc32 crc; + + lzp_buf buffer(buffer_size); + + + unsigned long index = 0; + unsigned char symbol; + unsigned char length; + unsigned char flag; + + decode(symbol,flag); + while (flag == 0 || symbol != 0) + { + buffer.predict_match(index); + + if (flag == 1) + { + length = symbol; + do + { + --length; + symbol = buffer[index]; + write(symbol); + buffer.add(symbol); + crc.add(symbol); + } while (length != 0); + } + else + { + // this is just a literal + write(symbol); + buffer.add(symbol); + crc.add(symbol); + } + decode(symbol,flag); + } + + + // now get the checksum and make sure it matches + unsigned char byte1; + unsigned char byte2; + unsigned char byte3; + unsigned char byte4; + + decode(byte1,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + decode(byte2,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + decode(byte3,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + decode(byte4,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + + unsigned long checksum = byte1; + checksum <<= 8; + checksum |= byte2; + checksum <<= 8; + checksum |= byte3; + checksum <<= 8; + checksum |= byte4; + + if (checksum != crc.get_checksum()) + throw decompression_error("Error detected in compressed data stream."); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_3_ + diff --git a/ml/dlib/dlib/compress_stream/compress_stream_kernel_abstract.h b/ml/dlib/dlib/compress_stream/compress_stream_kernel_abstract.h new file mode 100644 index 000000000..48f46d9e1 --- /dev/null +++ b/ml/dlib/dlib/compress_stream/compress_stream_kernel_abstract.h @@ -0,0 +1,94 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_ +#ifdef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include + +namespace dlib +{ + + class compress_stream + { + /*! + INITIAL VALUE + This object does not have any state associated with it. + + WHAT THIS OBJECT REPRESENTS + This object consists of the two functions compress and decompress. + These functions allow you to compress and decompress data. + !*/ + + public: + + class decompression_error : public dlib::error {}; + + compress_stream ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~compress_stream ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + + void compress ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads all data from in (until EOF is reached) and compresses it + and writes it to out + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads data from in, decompresses it and writes it to out. note that + it stops reading data from in when it encounters the end of the + compressed data, not when it encounters EOF. + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - decompression_error + if an error was detected in the compressed data that prevented + it from being correctly decompressed then this exception is + thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + + private: + + // restricted functions + compress_stream(compress_stream&); // copy constructor + compress_stream& operator=(compress_stream&); // assignment operator + + }; + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/conditioning_class.h b/ml/dlib/dlib/conditioning_class.h new file mode 100644 index 000000000..409b98716 --- /dev/null +++ b/ml/dlib/dlib/conditioning_class.h @@ -0,0 +1,80 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASs_ +#define DLIB_CONDITIONING_CLASs_ + +#include "conditioning_class/conditioning_class_kernel_1.h" +#include "conditioning_class/conditioning_class_kernel_2.h" +#include "conditioning_class/conditioning_class_kernel_3.h" +#include "conditioning_class/conditioning_class_kernel_4.h" +#include "conditioning_class/conditioning_class_kernel_c.h" + + +#include "memory_manager.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class + { + conditioning_class() {} + + typedef memory_manager::kernel_2b mm; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef conditioning_class_kernel_1 + kernel_1a; + typedef conditioning_class_kernel_c + kernel_1a_c; + + // kernel_2a + typedef conditioning_class_kernel_2 + kernel_2a; + typedef conditioning_class_kernel_c + kernel_2a_c; + + // kernel_3a + typedef conditioning_class_kernel_3 + kernel_3a; + typedef conditioning_class_kernel_c + kernel_3a_c; + + + // -------- kernel_4 --------- + + // kernel_4a + typedef conditioning_class_kernel_4 + kernel_4a; + typedef conditioning_class_kernel_c + kernel_4a_c; + + // kernel_4b + typedef conditioning_class_kernel_4 + kernel_4b; + typedef conditioning_class_kernel_c + kernel_4b_c; + + // kernel_4c + typedef conditioning_class_kernel_4 + kernel_4c; + typedef conditioning_class_kernel_c + kernel_4c_c; + + // kernel_4d + typedef conditioning_class_kernel_4 + kernel_4d; + typedef conditioning_class_kernel_c + kernel_4d_c; + + }; +} + +#endif // DLIB_CONDITIONING_CLASS_ + diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_1.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_1.h new file mode 100644 index 000000000..d26d80244 --- /dev/null +++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_1.h @@ -0,0 +1,333 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_1_ +#define DLIB_CONDITIONING_CLASS_KERNEl_1_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class_kernel_1 + { + /*! + INITIAL VALUE + total == 1 + counts == pointer to an array of alphabet_size unsigned shorts + for all i except i == alphabet_size-1: counts[i] == 0 + counts[alphabet_size-1] == 1 + + CONVENTION + counts == pointer to an array of alphabet_size unsigned shorts + get_total() == total + get_count(symbol) == counts[symbol] + + LOW_COUNT(symbol) == sum of counts[0] though counts[symbol-1] + or 0 if symbol == 0 + + get_memory_usage() == global_state.memory_usage + !*/ + + public: + + class global_state_type + { + public: + global_state_type () : memory_usage(0) {} + private: + unsigned long memory_usage; + + friend class conditioning_class_kernel_1; + }; + + conditioning_class_kernel_1 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_1 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + + private: + + // restricted functions + conditioning_class_kernel_1(conditioning_class_kernel_1&); // copy constructor + conditioning_class_kernel_1& operator=(conditioning_class_kernel_1&); // assignment operator + + // data members + unsigned short total; + unsigned short* counts; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_1:: + conditioning_class_kernel_1 ( + global_state_type& global_state_ + ) : + total(1), + counts(new unsigned short[alphabet_size]), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + unsigned short* start = counts; + unsigned short* end = counts+alphabet_size-1; + while (start != end) + { + *start = 0; + ++start; + } + *start = 1; + + // update memory usage + global_state.memory_usage += sizeof(unsigned short)*alphabet_size + + sizeof(conditioning_class_kernel_1); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_1:: + ~conditioning_class_kernel_1 ( + ) + { + delete [] counts; + // update memory usage + global_state.memory_usage -= sizeof(unsigned short)*alphabet_size + + sizeof(conditioning_class_kernel_1); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_1:: + clear( + ) + { + total = 1; + unsigned short* start = counts; + unsigned short* end = counts+alphabet_size-1; + while (start != end) + { + *start = 0; + ++start; + } + *start = 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + typename conditioning_class_kernel_1::global_state_type& conditioning_class_kernel_1:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + bool conditioning_class_kernel_1:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // if we are going over a total of 65535 then scale down all counts by 2 + if (static_cast(total)+static_cast(amount) >= 65536) + { + total = 0; + unsigned short* start = counts; + unsigned short* end = counts+alphabet_size; + while (start != end) + { + *start >>= 1; + total += *start; + ++start; + } + // make sure it is at least one + if (counts[alphabet_size-1]==0) + { + ++total; + counts[alphabet_size-1] = 1; + } + } + counts[symbol] += amount; + total += amount; + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_count ( + unsigned long symbol + ) const + { + return counts[symbol]; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (counts[symbol] == 0) + return 0; + + total_count = total; + + const unsigned short* start = counts; + const unsigned short* end = counts+symbol; + unsigned short high_count_temp = *start; + while (start != end) + { + ++start; + high_count_temp += *start; + } + low_count = high_count_temp - *start; + high_count = high_count_temp; + return *start; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_1:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + unsigned long high_count_temp = *counts; + const unsigned short* start = counts; + while (target >= high_count_temp) + { + ++start; + high_count_temp += *start; + } + + low_count = high_count_temp - *start; + high_count = high_count_temp; + symbol = static_cast(start-counts); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_1_ + diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_2.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_2.h new file mode 100644 index 000000000..c9b38c8e3 --- /dev/null +++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_2.h @@ -0,0 +1,500 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_2_ +#define DLIB_CONDITIONING_CLASS_KERNEl_2_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class_kernel_2 + { + /*! + INITIAL VALUE + total == 1 + symbols == pointer to array of alphabet_size data structs + for all i except i == alphabet_size-1: symbols[i].count == 0 + symbols[i].left_count == 0 + + symbols[alphabet_size-1].count == 1 + symbols[alpahbet_size-1].left_count == 0 + + CONVENTION + symbols == pointer to array of alphabet_size data structs + get_total() == total + get_count(symbol) == symbols[symbol].count + + symbols is organized as a tree with symbols[0] as the root. + + the left subchild of symbols[i] is symbols[i*2+1] and + the right subchild is symbols[i*2+2]. + the partent of symbols[i] == symbols[(i-1)/2] + + symbols[i].left_count == the sum of the counts of all the + symbols to the left of symbols[i] + + get_memory_usage() == global_state.memory_usage + !*/ + + public: + + class global_state_type + { + public: + global_state_type () : memory_usage(0) {} + private: + unsigned long memory_usage; + + friend class conditioning_class_kernel_2; + }; + + conditioning_class_kernel_2 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_2 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + inline unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + private: + + // restricted functions + conditioning_class_kernel_2(conditioning_class_kernel_2&); // copy constructor + conditioning_class_kernel_2& operator=(conditioning_class_kernel_2&); // assignment operator + + // data members + unsigned short total; + struct data + { + unsigned short count; + unsigned short left_count; + }; + + data* symbols; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_2:: + conditioning_class_kernel_2 ( + global_state_type& global_state_ + ) : + total(1), + symbols(new data[alphabet_size]), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + data* start = symbols; + data* end = symbols + alphabet_size-1; + + while (start != end) + { + start->count = 0; + start->left_count = 0; + ++start; + } + + start->count = 1; + start->left_count = 0; + + + // update the left_counts for the symbol alphabet_size-1 + unsigned short temp; + unsigned long symbol = alphabet_size-1; + while (symbol != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(symbol&0x1); + + // set symbol to its parent + symbol = (symbol-1)>>1; + + // note that all left subchidren are odd and also that + // if symbol was a left subchild then we want to increment + // its parents left_count + if (temp) + ++symbols[symbol].left_count; + } + + global_state.memory_usage += sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_2); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_2:: + ~conditioning_class_kernel_2 ( + ) + { + delete [] symbols; + global_state.memory_usage -= sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_2); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_2:: + clear( + ) + { + data* start = symbols; + data* end = symbols + alphabet_size-1; + + total = 1; + + while (start != end) + { + start->count = 0; + start->left_count = 0; + ++start; + } + + start->count = 1; + start->left_count = 0; + + // update the left_counts + unsigned short temp; + unsigned long symbol = alphabet_size-1; + while (symbol != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(symbol&0x1); + + // set symbol to its parent + symbol = (symbol-1)>>1; + + // note that all left subchidren are odd and also that + // if symbol was a left subchild then we want to increment + // its parents left_count + symbols[symbol].left_count += temp; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + typename conditioning_class_kernel_2::global_state_type& conditioning_class_kernel_2:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + bool conditioning_class_kernel_2:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // if we need to renormalize then do so + if (static_cast(total)+static_cast(amount) >= 65536) + { + unsigned long s; + unsigned short temp; + for (unsigned short i = 0; i < alphabet_size-1; ++i) + { + s = i; + + // divide the count for this symbol by 2 + symbols[i].count >>= 1; + + symbols[i].left_count = 0; + + // bubble this change up though the tree + while (s != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(s&0x1); + + // set s to its parent + s = (s-1)>>1; + + // note that all left subchidren are odd and also that + // if s was a left subchild then we want to increment + // its parents left_count + if (temp) + symbols[s].left_count += symbols[i].count; + } + } + + // update symbols alphabet_size-1 + { + s = alphabet_size-1; + + // divide alphabet_size-1 symbol by 2 if it's > 1 + if (symbols[alphabet_size-1].count > 1) + symbols[alphabet_size-1].count >>= 1; + + // bubble this change up though the tree + while (s != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(s&0x1); + + // set s to its parent + s = (s-1)>>1; + + // note that all left subchidren are odd and also that + // if s was a left subchild then we want to increment + // its parents left_count + if (temp) + symbols[s].left_count += symbols[alphabet_size-1].count; + } + } + + + + + + + // calculate the new total + total = 0; + unsigned long m = 0; + while (m < alphabet_size) + { + total += symbols[m].count + symbols[m].left_count; + m = (m<<1) + 2; + } + + } + + + + + // increment the count for the specified symbol + symbols[symbol].count += amount;; + total += amount; + + + unsigned short temp; + while (symbol != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(symbol&0x1); + + // set symbol to its parent + symbol = (symbol-1)>>1; + + // note that all left subchidren are odd and also that + // if symbol was a left subchild then we want to increment + // its parents left_count + if (temp) + symbols[symbol].left_count += amount; + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_count ( + unsigned long symbol + ) const + { + return symbols[symbol].count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (symbols[symbol].count == 0) + return 0; + + unsigned long current = symbol; + total_count = total; + unsigned long high_count_temp = 0; + bool came_from_right = true; + while (true) + { + if (came_from_right) + { + high_count_temp += symbols[current].count + symbols[current].left_count; + } + + // note that if current is even then it is a right child + came_from_right = !(current&0x1); + + if (current == 0) + break; + + // set current to its parent + current = (current-1)>>1 ; + } + + + low_count = high_count_temp - symbols[symbol].count; + high_count = high_count_temp; + + return symbols[symbol].count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_2:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + unsigned long current = 0; + unsigned long low_count_temp = 0; + + while (true) + { + if (static_cast(target) < symbols[current].left_count) + { + // we should go left + current = (current<<1) + 1; + } + else + { + target -= symbols[current].left_count; + low_count_temp += symbols[current].left_count; + if (static_cast(target) < symbols[current].count) + { + // we have found our target + symbol = current; + high_count = low_count_temp + symbols[current].count; + low_count = low_count_temp; + break; + } + else + { + // go right + target -= symbols[current].count; + low_count_temp += symbols[current].count; + current = (current<<1) + 2; + } + } + + } + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_1_ + diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_3.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_3.h new file mode 100644 index 000000000..b6de48555 --- /dev/null +++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_3.h @@ -0,0 +1,438 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_3_ +#define DLIB_CONDITIONING_CLASS_KERNEl_3_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class_kernel_3 + { + /*! + INITIAL VALUE + total == 1 + counts == pointer to an array of alphabet_size data structs + for all i except i == 0: counts[i].count == 0 + counts[0].count == 1 + counts[0].symbol == alphabet_size-1 + for all i except i == alphabet_size-1: counts[i].present == false + counts[alphabet_size-1].present == true + + CONVENTION + counts == pointer to an array of alphabet_size data structs + get_total() == total + get_count(symbol) == counts[x].count where + counts[x].symbol == symbol + + + LOW_COUNT(symbol) == sum of counts[0].count though counts[x-1].count + where counts[x].symbol == symbol + if (counts[0].symbol == symbol) LOW_COUNT(symbol)==0 + + + if (counts[i].count == 0) then + counts[i].symbol == undefined value + + if (symbol has a nonzero count) then + counts[symbol].present == true + + get_memory_usage() == global_state.memory_usage + !*/ + + public: + + class global_state_type + { + public: + global_state_type () : memory_usage(0) {} + private: + unsigned long memory_usage; + + friend class conditioning_class_kernel_3; + }; + + conditioning_class_kernel_3 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_3 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + private: + + // restricted functions + conditioning_class_kernel_3(conditioning_class_kernel_3&); // copy constructor + conditioning_class_kernel_3& operator=(conditioning_class_kernel_3&); // assignment operator + + struct data + { + unsigned short count; + unsigned short symbol; + bool present; + }; + + // data members + unsigned short total; + data* counts; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_3:: + conditioning_class_kernel_3 ( + global_state_type& global_state_ + ) : + total(1), + counts(new data[alphabet_size]), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + data* start = counts; + data* end = counts+alphabet_size; + start->count = 1; + start->symbol = alphabet_size-1; + start->present = false; + ++start; + while (start != end) + { + start->count = 0; + start->present = false; + ++start; + } + counts[alphabet_size-1].present = true; + + // update memory usage + global_state.memory_usage += sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_3); + + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_3:: + ~conditioning_class_kernel_3 ( + ) + { + delete [] counts; + // update memory usage + global_state.memory_usage -= sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_3); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_3:: + clear( + ) + { + total = 1; + data* start = counts; + data* end = counts+alphabet_size; + start->count = 1; + start->symbol = alphabet_size-1; + start->present = false; + ++start; + while (start != end) + { + start->count = 0; + start->present = false; + ++start; + } + counts[alphabet_size-1].present = true; + + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + typename conditioning_class_kernel_3::global_state_type& conditioning_class_kernel_3:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + bool conditioning_class_kernel_3:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // if we are going over a total of 65535 then scale down all counts by 2 + if (static_cast(total)+static_cast(amount) >= 65536) + { + total = 0; + data* start = counts; + data* end = counts+alphabet_size; + + while (start != end) + { + if (start->count == 1) + { + if (start->symbol == alphabet_size-1) + { + // this symbol must never be zero so we will leave its count at 1 + ++total; + } + else + { + start->count = 0; + counts[start->symbol].present = false; + } + } + else + { + start->count >>= 1; + total += start->count; + } + + ++start; + } + } + + + data* start = counts; + data* swap_spot = counts; + + if (counts[symbol].present) + { + while (true) + { + if (start->symbol == symbol && start->count!=0) + { + unsigned short temp = start->count + amount; + + start->symbol = swap_spot->symbol; + start->count = swap_spot->count; + + swap_spot->symbol = static_cast(symbol); + swap_spot->count = temp; + break; + } + + if ( (start->count) < (swap_spot->count)) + { + swap_spot = start; + } + + + ++start; + } + } + else + { + counts[symbol].present = true; + while (true) + { + if (start->count == 0) + { + start->symbol = swap_spot->symbol; + start->count = swap_spot->count; + + swap_spot->symbol = static_cast(symbol); + swap_spot->count = amount; + break; + } + + if ((start->count) < (swap_spot->count)) + { + swap_spot = start; + } + + ++start; + } + } + + total += amount; + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_count ( + unsigned long symbol + ) const + { + if (counts[symbol].present == false) + return 0; + + data* start = counts; + while (start->symbol != symbol) + { + ++start; + } + return start->count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (counts[symbol].present == false) + return 0; + + total_count = total; + unsigned long low_count_temp = 0; + data* start = counts; + while (start->symbol != symbol) + { + low_count_temp += start->count; + ++start; + } + + low_count = low_count_temp; + high_count = low_count_temp + start->count; + return start->count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_3:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + unsigned long high_count_temp = counts->count; + const data* start = counts; + while (target >= high_count_temp) + { + ++start; + high_count_temp += start->count; + } + + low_count = high_count_temp - start->count; + high_count = high_count_temp; + symbol = static_cast(start->symbol); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_3_ + diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_4.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_4.h new file mode 100644 index 000000000..cb48ac196 --- /dev/null +++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_4.h @@ -0,0 +1,533 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_4_ +#define DLIB_CONDITIONING_CLASS_KERNEl_4_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + +namespace dlib +{ + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + class conditioning_class_kernel_4 + { + /*! + REQUIREMENTS ON pool_size + pool_size > 0 + this will be the number of nodes contained in our memory pool + + REQUIREMENTS ON mem_manager + mem_manager is an implementation of memory_manager/memory_manager_kernel_abstract.h + + INITIAL VALUE + total == 1 + escapes == 1 + next == 0 + + CONVENTION + get_total() == total + get_count(alphabet_size-1) == escapes + + if (next != 0) then + next == pointer to the start of a linked list and the linked list + is terminated by a node with a next pointer of 0. + + get_count(symbol) == node::count for the node where node::symbol==symbol + or 0 if no such node currently exists. + + if (there is a node for the symbol) then + LOW_COUNT(symbol) == the sum of all node's counts in the linked list + up to but not including the node for the symbol. + + get_memory_usage() == global_state.memory_usage + !*/ + + + struct node + { + unsigned short symbol; + unsigned short count; + node* next; + }; + + public: + + class global_state_type + { + public: + global_state_type ( + ) : + memory_usage(pool_size*sizeof(node)+sizeof(global_state_type)) + {} + private: + unsigned long memory_usage; + + typename mem_manager::template rebind::other pool; + + friend class conditioning_class_kernel_4; + }; + + conditioning_class_kernel_4 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_4 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + inline unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + + private: + + void half_counts ( + ); + /*! + ensures + - divides all counts by 2 but ensures that escapes is always at least 1 + !*/ + + // restricted functions + conditioning_class_kernel_4(conditioning_class_kernel_4&); // copy constructor + conditioning_class_kernel_4& operator=(conditioning_class_kernel_4&); // assignment operator + + // data members + unsigned short total; + unsigned short escapes; + node* next; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + conditioning_class_kernel_4:: + conditioning_class_kernel_4 ( + global_state_type& global_state_ + ) : + total(1), + escapes(1), + next(0), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + // update memory usage + global_state.memory_usage += sizeof(conditioning_class_kernel_4); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + conditioning_class_kernel_4:: + ~conditioning_class_kernel_4 ( + ) + { + clear(); + // update memory usage + global_state.memory_usage -= sizeof(conditioning_class_kernel_4); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + void conditioning_class_kernel_4:: + clear( + ) + { + total = 1; + escapes = 1; + while (next) + { + node* temp = next; + next = next->next; + global_state.pool.deallocate(temp); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + typename conditioning_class_kernel_4::global_state_type& conditioning_class_kernel_4:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + bool conditioning_class_kernel_4:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + if (symbol == alphabet_size-1) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + escapes += amount; + total += amount; + return true; + } + + + // find the symbol and increment it or add a new node to the list + if (next) + { + node* temp = next; + node* previous = 0; + while (true) + { + if (temp->symbol == static_cast(symbol)) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + // we have found the symbol + total += amount; + temp->count += amount; + + // if this node now has a count greater than its parent node + if (previous && temp->count > previous->count) + { + // swap the nodes so that the nodes will be in semi-sorted order + swap(temp->count,previous->count); + swap(temp->symbol,previous->symbol); + } + return true; + } + else if (temp->next == 0) + { + // we did not find the symbol so try to add it to the list + if (global_state.pool.get_number_of_allocations() < pool_size) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + node* t = global_state.pool.allocate(); + t->next = 0; + t->symbol = static_cast(symbol); + t->count = amount; + temp->next = t; + total += amount; + return true; + } + else + { + // no memory left + return false; + } + } + else if (temp->count == 0) + { + // remove nodes that have a zero count + if (previous) + { + previous->next = temp->next; + node* t = temp; + temp = temp->next; + global_state.pool.deallocate(t); + } + else + { + next = temp->next; + node* t = temp; + temp = temp->next; + global_state.pool.deallocate(t); + } + } + else + { + previous = temp; + temp = temp->next; + } + } // while (true) + } + // if there aren't any nodes in the list yet then do this instead + else + { + if (global_state.pool.get_number_of_allocations() < pool_size) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + next = global_state.pool.allocate(); + next->next = 0; + next->symbol = static_cast(symbol); + next->count = amount; + total += amount; + return true; + } + else + { + // no memory left + return false; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_count ( + unsigned long symbol + ) const + { + if (symbol == alphabet_size-1) + { + return escapes; + } + else + { + node* temp = next; + while (temp) + { + if (temp->symbol == symbol) + return temp->count; + temp = temp->next; + } + return 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (symbol != alphabet_size-1) + { + node* temp = next; + unsigned long low = 0; + while (temp) + { + if (temp->symbol == static_cast(symbol)) + { + high_count = temp->count + low; + low_count = low; + total_count = total; + return temp->count; + } + low += temp->count; + temp = temp->next; + } + return 0; + } + else + { + total_count = total; + high_count = total; + low_count = total-escapes; + return escapes; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + void conditioning_class_kernel_4:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + node* temp = next; + unsigned long high = 0; + while (true) + { + if (temp != 0) + { + high += temp->count; + if (target < high) + { + symbol = temp->symbol; + high_count = high; + low_count = high - temp->count; + return; + } + temp = temp->next; + } + else + { + // this must be the escape symbol + symbol = alphabet_size-1; + low_count = total-escapes; + high_count = total; + return; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + void conditioning_class_kernel_4:: + half_counts ( + ) + { + total = 0; + if (escapes > 1) + escapes >>= 1; + + //divide all counts by 2 + node* temp = next; + while (temp) + { + temp->count >>= 1; + total += temp->count; + temp = temp->next; + } + total += escapes; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_4_ + diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_abstract.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_abstract.h new file mode 100644 index 000000000..411aea566 --- /dev/null +++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_abstract.h @@ -0,0 +1,228 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_ +#ifdef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class + { + /*! + REQUIREMENTS ON alphabet_size + 1 < alphabet_size < 65536 + + INITIAL VALUE + get_total() == 1 + get_count(X) == 0 : for all valid values of X except alphabet_size-1 + get_count(alphabet_size-1) == 1 + + WHAT THIS OBJECT REPRESENTS + This object represents a conditioning class used for arithmetic style + compression. It maintains the cumulative counts which are needed + by the entropy_coder and entropy_decoder objects. + + At any moment a conditioning_class object represents a set of + alphabet_size symbols. Each symbol is associated with an integer + called its count. + + All symbols start out with a count of zero except for alphabet_size-1. + This last symbol will always have a count of at least one. It is + intended to be used as an escape into a lower context when coding + and so it must never have a zero probability or the decoder won't + be able to identify the escape symbol. + + NOTATION: + Let MAP(i) be a function which maps integers to symbols. MAP(i) is + one to one and onto. Its domain is 1 to alphabet_size inclusive. + + Let RMAP(s) be the inverse of MAP(i). + ( i.e. RMAP(MAP(i)) == i and MAP(RMAP(s)) == s ) + + Let COUNT(i) give the count for the symbol MAP(i). + ( i.e. COUNT(i) == get_count(MAP(i)) ) + + + Let LOW_COUNT(s) == the sum of COUNT(x) for x == 1 to x == RMAP(s)-1 + (note that the sum of COUNT(x) for x == 1 to x == 0 is 0) + Let HIGH_COUNT(s) == LOW_COUNT(s) + get_count(s) + + + + Basically what this is saying is just that you shoudln't assume you know + what order the symbols are placed in when calculating the cumulative + sums. The specific mapping provided by the MAP() function is unspecified. + + THREAD SAFETY + This object can be used safely in a multithreaded program as long as the + global state is not shared between conditioning classes which run on + different threads. + + GLOBAL_STATE_TYPE + The global_state_type obejct allows instances of the conditioning_class + object to share any kind of global state the implementer desires. + However, the global_state_type object exists primarily to facilitate the + sharing of a memory pool between many instances of a conditioning_class + object. But note that it is not required that there be any kind of + memory pool at all, it is just a possibility. + !*/ + + public: + + class global_state_type + { + global_state_type ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + // my contents are implementation specific. + }; + + conditioning_class ( + global_state_type& global_state + ); + /*! + ensures + - #*this is properly initialized + - &#get_global_state() == &global_state + throws + - std::bad_alloc + !*/ + + ~conditioning_class ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + !*/ + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + /*! + requires + - 0 <= symbol < alphabet_size + - 0 < amount < 32768 + ensures + - if (sufficient memory is available to complete this operation) then + - returns true + - if (get_total()+amount < 65536) then + - #get_count(symbol) == get_count(symbol) + amount + - else + - #get_count(symbol) == get_count(symbol)/2 + amount + - if (get_count(alphabet_size-1) == 1) then + - #get_count(alphabet_size-1) == 1 + - else + - #get_count(alphabet_size-1) == get_count(alphabet_size-1)/2 + - for all X where (X != symbol)&&(X != alpahbet_size-1): + #get_count(X) == get_count(X)/2 + - else + - returns false + !*/ + + unsigned long get_count ( + unsigned long symbol + ) const; + /*! + requires + - 0 <= symbol < alphabet_size + ensures + - returns the count for the specified symbol + !*/ + + unsigned long get_total ( + ) const; + /*! + ensures + - returns the sum of get_count(X) for all valid values of X + (i.e. returns the sum of the counts for all the symbols) + !*/ + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + /*! + requires + - 0 <= symbol < alphabet_size + ensures + - returns get_count(symbol) + - if (get_count(symbol) != 0) then + - #total_count == get_total() + - #low_count == LOW_COUNT(symbol) + - #high_count == HIGH_COUNT(symbol) + - #low_count < #high_count <= #total_count + !*/ + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + /*! + requires + - 0 <= target < get_total() + ensures + - LOW_COUNT(#symbol) <= target < HIGH_COUNT(#symbol) + - #low_count == LOW_COUNT(#symbol) + - #high_count == HIGH_COUNT(#symbol) + - #low_count < #high_count <= get_total() + !*/ + + global_state_type& get_global_state ( + ); + /*! + ensures + - returns a reference to the global state used by *this + !*/ + + unsigned long get_memory_usage ( + ) const; + /*! + ensures + - returns the number of bytes of memory allocated by all conditioning_class + objects that share the global state given by get_global_state() + !*/ + + static unsigned long get_alphabet_size ( + ); + /*! + ensures + - returns alphabet_size + !*/ + + private: + + // restricted functions + conditioning_class(conditioning_class&); // copy constructor + conditioning_class& operator=(conditioning_class&); // assignment operator + + }; + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_c.h b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_c.h new file mode 100644 index 000000000..964240be8 --- /dev/null +++ b/ml/dlib/dlib/conditioning_class/conditioning_class_kernel_c.h @@ -0,0 +1,162 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_C_ +#define DLIB_CONDITIONING_CLASS_KERNEl_C_ + +#include "conditioning_class_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename cc_base + > + class conditioning_class_kernel_c : public cc_base + { + const unsigned long alphabet_size; + + public: + + conditioning_class_kernel_c ( + typename cc_base::global_state_type& global_state + ) : cc_base(global_state),alphabet_size(cc_base::get_alphabet_size()) {} + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + bool conditioning_class_kernel_c:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size && + 0 < amount && amount < 32768, + "\tvoid conditioning_class::increment_count()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1. and" + << "\n\tamount must be in the range 1 to 32767" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tamount: " << amount + << "\n\tthis: " << this + ); + + // call the real function + return cc_base::increment_count(symbol,amount); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + unsigned long conditioning_class_kernel_c:: + get_count ( + unsigned long symbol + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size, + "\tvoid conditioning_class::get_count()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tthis: " << this + ); + + // call the real function + return cc_base::get_count(symbol); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + unsigned long conditioning_class_kernel_c:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size, + "\tvoid conditioning_class::get_range()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tthis: " << this + ); + + // call the real function + return cc_base::get_range(symbol,low_count,high_count,total_count); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + void conditioning_class_kernel_c:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( target < this->get_total(), + "\tvoid conditioning_class::get_symbol()" + << "\n\tthe target must be in the range 0 to get_total()-1" + << "\n\tget_total(): " << this->get_total() + << "\n\ttarget: " << target + << "\n\tthis: " << this + ); + + // call the real function + cc_base::get_symbol(target,symbol,low_count,high_count); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_C_ + diff --git a/ml/dlib/dlib/config.h b/ml/dlib/dlib/config.h new file mode 100644 index 000000000..a48ac415f --- /dev/null +++ b/ml/dlib/dlib/config.h @@ -0,0 +1,31 @@ + + +// If you are compiling dlib as a shared library and installing it somewhere on your system +// then it is important that any programs that use dlib agree on the state of the +// DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore, +// uncomment one of the following lines to force all DLIB_ASSERTs to either always on or +// always off. If you don't define one of these two macros then DLIB_ASSERT will toggle +// automatically depending on the state of certain other macros, which is not what you want +// when creating a shared library. +//#define ENABLE_ASSERTS // asserts always enabled +//#define DLIB_DISABLE_ASSERTS // asserts always disabled + +//#define DLIB_ISO_CPP_ONLY +//#define DLIB_NO_GUI_SUPPORT +//#define DLIB_ENABLE_STACK_TRACE + +// You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA, +// and a BLAS and LAPACK library. To do this you need to uncomment the following #defines. +// #define DLIB_JPEG_SUPPORT +// #define DLIB_PNG_SUPPORT +// #define DLIB_GIF_SUPPORT +// #define DLIB_USE_FFTW +// #define DLIB_USE_BLAS +// #define DLIB_USE_LAPACK +// #define DLIB_USE_CUDA + + +// Define this so the code in dlib/test_for_odr_violations.h can detect ODR violations +// related to users doing bad things with config.h +#define DLIB_NOT_CONFIGURED + diff --git a/ml/dlib/dlib/config.h.in b/ml/dlib/dlib/config.h.in new file mode 100644 index 000000000..27dff2b27 --- /dev/null +++ b/ml/dlib/dlib/config.h.in @@ -0,0 +1,34 @@ + + +// If you are compiling dlib as a shared library and installing it somewhere on your system +// then it is important that any programs that use dlib agree on the state of the +// DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore, +// uncomment one of the following lines to force all DLIB_ASSERTs to either always on or +// always off. If you don't define one of these two macros then DLIB_ASSERT will toggle +// automatically depending on the state of certain other macros, which is not what you want +// when creating a shared library. +#cmakedefine ENABLE_ASSERTS // asserts always enabled +#cmakedefine DLIB_DISABLE_ASSERTS // asserts always disabled + +#cmakedefine DLIB_ISO_CPP_ONLY +#cmakedefine DLIB_NO_GUI_SUPPORT +#cmakedefine DLIB_ENABLE_STACK_TRACE + +#cmakedefine LAPACK_FORCE_UNDERSCORE +#cmakedefine LAPACK_FORCE_NOUNDERSCORE + +// You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA, +// and a BLAS and LAPACK library. To do this you need to uncomment the following #defines. +#cmakedefine DLIB_JPEG_SUPPORT +#cmakedefine DLIB_PNG_SUPPORT +#cmakedefine DLIB_GIF_SUPPORT +#cmakedefine DLIB_USE_FFTW +#cmakedefine DLIB_USE_BLAS +#cmakedefine DLIB_USE_LAPACK +#cmakedefine DLIB_USE_CUDA +#cmakedefine DLIB_USE_MKL_FFT + +// This variable allows dlib/test_for_odr_violations.h to catch people who mistakenly use +// headers from one version of dlib with a compiled dlib binary from a different dlib version. +#cmakedefine DLIB_CHECK_FOR_VERSION_MISMATCH @DLIB_CHECK_FOR_VERSION_MISMATCH@ + diff --git a/ml/dlib/dlib/config_reader.h b/ml/dlib/dlib/config_reader.h new file mode 100644 index 000000000..d140a310c --- /dev/null +++ b/ml/dlib/dlib/config_reader.h @@ -0,0 +1,39 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONFIG_READEr_ +#define DLIB_CONFIG_READEr_ + +#include "config_reader/config_reader_kernel_1.h" +#include "map.h" +#include "tokenizer.h" +#include "cmd_line_parser/get_option.h" + +#include "algs.h" +#include "is_kind.h" + + +namespace dlib +{ + + typedef config_reader_kernel_1< + map::kernel_1b, + map::kernel_1b, + tokenizer::kernel_1a + > config_reader; + + template <> struct is_config_reader { const static bool value = true; }; + +#ifndef DLIB_ISO_CPP_ONLY + typedef config_reader_thread_safe_1< + config_reader, + map::kernel_1b + > config_reader_thread_safe; + + template <> struct is_config_reader { const static bool value = true; }; +#endif // DLIB_ISO_CPP_ONLY + + +} + +#endif // DLIB_CONFIG_READEr_ + diff --git a/ml/dlib/dlib/config_reader/config_reader_kernel_1.h b/ml/dlib/dlib/config_reader/config_reader_kernel_1.h new file mode 100644 index 000000000..c0f9e5a71 --- /dev/null +++ b/ml/dlib/dlib/config_reader/config_reader_kernel_1.h @@ -0,0 +1,738 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONFIG_READER_KERNEl_1_ +#define DLIB_CONFIG_READER_KERNEl_1_ + +#include "config_reader_kernel_abstract.h" +#include +#include +#include +#include +#include "../algs.h" +#include "../stl_checked/std_vector_c.h" + +#ifndef DLIB_ISO_CPP_ONLY +#include "config_reader_thread_safe_1.h" +#endif + +namespace dlib +{ + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + class config_reader_kernel_1 + { + + /*! + REQUIREMENTS ON map_string_string + is an implementation of map/map_kernel_abstract.h that maps std::string to std::string + + REQUIREMENTS ON map_string_void + is an implementation of map/map_kernel_abstract.h that maps std::string to void* + + REQUIREMENTS ON tokenizer + is an implementation of tokenizer/tokenizer_kernel_abstract.h + + CONVENTION + key_table.is_in_domain(x) == is_key_defined(x) + block_table.is_in_domain(x) == is_block_defined(x) + + key_table[x] == operator[](x) + block_table[x] == (void*)&block(x) + !*/ + + public: + + // These two typedefs are defined for backwards compatibility with older versions of dlib. + typedef config_reader_kernel_1 kernel_1a; +#ifndef DLIB_ISO_CPP_ONLY + typedef config_reader_thread_safe_1< + config_reader_kernel_1, + map_string_void + > thread_safe_1a; +#endif // DLIB_ISO_CPP_ONLY + + + config_reader_kernel_1(); + + class config_reader_error : public dlib::error + { + friend class config_reader_kernel_1; + config_reader_error( + unsigned long ln, + bool r = false + ) : + dlib::error(ECONFIG_READER), + line_number(ln), + redefinition(r) + { + std::ostringstream sout; + sout << "Error in config_reader while parsing at line number " << line_number << "."; + if (redefinition) + sout << "\nThe identifier on this line has already been defined in this scope."; + const_cast(info) = sout.str(); + } + public: + const unsigned long line_number; + const bool redefinition; + }; + + class file_not_found : public dlib::error + { + friend class config_reader_kernel_1; + file_not_found( + const std::string& file_name_ + ) : + dlib::error(ECONFIG_READER, "Error in config_reader, unable to open file " + file_name_), + file_name(file_name_) + {} + + ~file_not_found() throw() {} + + public: + const std::string file_name; + }; + + class config_reader_access_error : public dlib::error + { + public: + config_reader_access_error( + const std::string& block_name_, + const std::string& key_name_ + ) : + dlib::error(ECONFIG_READER), + block_name(block_name_), + key_name(key_name_) + { + std::ostringstream sout; + sout << "Error in config_reader.\n"; + if (block_name.size() > 0) + sout << " A block with the name '" << block_name << "' was expected but not found."; + else if (key_name.size() > 0) + sout << " A key with the name '" << key_name << "' was expected but not found."; + + const_cast(info) = sout.str(); + } + + ~config_reader_access_error() throw() {} + const std::string block_name; + const std::string key_name; + }; + + config_reader_kernel_1( + const std::string& config_file + ); + + config_reader_kernel_1( + std::istream& in + ); + + virtual ~config_reader_kernel_1( + ); + + void clear ( + ); + + void load_from ( + std::istream& in + ); + + void load_from ( + const std::string& config_file + ); + + bool is_key_defined ( + const std::string& key + ) const; + + bool is_block_defined ( + const std::string& name + ) const; + + typedef config_reader_kernel_1 this_type; + const this_type& block ( + const std::string& name + ) const; + + const std::string& operator[] ( + const std::string& key + ) const; + + template < + typename queue_of_strings + > + void get_keys ( + queue_of_strings& keys + ) const; + + template < + typename alloc + > + void get_keys ( + std::vector& keys + ) const; + + template < + typename alloc + > + void get_keys ( + std_vector_c& keys + ) const; + + template < + typename queue_of_strings + > + void get_blocks ( + queue_of_strings& blocks + ) const; + + template < + typename alloc + > + void get_blocks ( + std::vector& blocks + ) const; + + template < + typename alloc + > + void get_blocks ( + std_vector_c& blocks + ) const; + + private: + + static void parse_config_file ( + config_reader_kernel_1& cr, + tokenizer& tok, + unsigned long& line_number, + const bool top_of_recursion = true + ); + /*! + requires + - line_number == 1 + - cr == *this + - top_of_recursion == true + ensures + - parses the data coming from tok and puts it into cr. + throws + - config_reader_error + !*/ + + map_string_string key_table; + map_string_void block_table; + + // restricted functions + config_reader_kernel_1(config_reader_kernel_1&); + config_reader_kernel_1& operator=(config_reader_kernel_1&); + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + config_reader_kernel_1( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + clear( + ) + { + // free all our blocks + block_table.reset(); + while (block_table.move_next()) + { + delete static_cast(block_table.element().value()); + } + block_table.clear(); + key_table.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + load_from( + std::istream& in + ) + { + clear(); + + tokenizer tok; + tok.set_stream(in); + tok.set_identifier_token( + tok.lowercase_letters() + tok.uppercase_letters(), + tok.lowercase_letters() + tok.uppercase_letters() + tok.numbers() + "_-." + ); + + unsigned long line_number = 1; + try + { + parse_config_file(*this,tok,line_number); + } + catch (...) + { + clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + load_from( + const std::string& config_file + ) + { + clear(); + std::ifstream fin(config_file.c_str()); + if (!fin) + throw file_not_found(config_file); + + load_from(fin); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + config_reader_kernel_1( + std::istream& in + ) + { + load_from(in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + config_reader_kernel_1( + const std::string& config_file + ) + { + load_from(config_file); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + parse_config_file( + config_reader_kernel_1& cr, + tokenizer& tok, + unsigned long& line_number, + const bool top_of_recursion + ) + { + int type; + std::string token; + bool in_comment = false; + bool seen_identifier = false; + std::string identifier; + while (true) + { + tok.get_token(type,token); + // ignore white space + if (type == tokenizer::WHITE_SPACE) + continue; + + // basically ignore end of lines + if (type == tokenizer::END_OF_LINE) + { + ++line_number; + in_comment = false; + continue; + } + + // we are in a comment still so ignore this + if (in_comment) + continue; + + // if this is the start of a comment + if (type == tokenizer::CHAR && token[0] == '#') + { + in_comment = true; + continue; + } + + // if this is the case then we have just finished parsing a block so we should + // quit this function + if ( (type == tokenizer::CHAR && token[0] == '}' && !top_of_recursion) || + (type == tokenizer::END_OF_FILE && top_of_recursion) ) + { + break; + } + + if (seen_identifier) + { + seen_identifier = false; + // the next character should be either a '=' or a '{' + if (type != tokenizer::CHAR || (token[0] != '=' && token[0] != '{')) + throw config_reader_error(line_number); + + if (token[0] == '=') + { + // we should parse the value out now + // first discard any white space + if (tok.peek_type() == tokenizer::WHITE_SPACE) + tok.get_token(type,token); + + std::string value; + type = tok.peek_type(); + token = tok.peek_token(); + while (true) + { + if (type == tokenizer::END_OF_FILE || type == tokenizer::END_OF_LINE) + break; + + if (type == tokenizer::CHAR && token[0] == '\\') + { + tok.get_token(type,token); + if (tok.peek_type() == tokenizer::CHAR && + tok.peek_token()[0] == '#') + { + tok.get_token(type,token); + value += '#'; + } + else if (tok.peek_type() == tokenizer::CHAR && + tok.peek_token()[0] == '}') + { + tok.get_token(type,token); + value += '}'; + } + else + { + value += '\\'; + } + } + else if (type == tokenizer::CHAR && + (token[0] == '#' || token[0] == '}')) + { + break; + } + else + { + value += token; + tok.get_token(type,token); + } + type = tok.peek_type(); + token = tok.peek_token(); + } // while(true) + + // strip of any tailing white space from value + std::string::size_type pos = value.find_last_not_of(" \t\r\n"); + if (pos == std::string::npos) + value.clear(); + else + value.erase(pos+1); + + // make sure this key isn't already in the key_table + if (cr.key_table.is_in_domain(identifier)) + throw config_reader_error(line_number,true); + + // add this key/value pair to the key_table + cr.key_table.add(identifier,value); + + } + else // when token[0] == '{' + { + // make sure this identifier isn't already in the block_table + if (cr.block_table.is_in_domain(identifier)) + throw config_reader_error(line_number,true); + + config_reader_kernel_1* new_cr = new config_reader_kernel_1; + void* vtemp = new_cr; + try { cr.block_table.add(identifier,vtemp); } + catch (...) { delete new_cr; throw; } + + // now parse this block + parse_config_file(*new_cr,tok,line_number,false); + } + } + else + { + // the next thing should be an identifier but if it isn't this is an error + if (type != tokenizer::IDENTIFIER) + throw config_reader_error(line_number); + + seen_identifier = true; + identifier = token; + } + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + ~config_reader_kernel_1( + ) + { + clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + bool config_reader_kernel_1:: + is_key_defined ( + const std::string& key + ) const + { + return key_table.is_in_domain(key); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + bool config_reader_kernel_1:: + is_block_defined ( + const std::string& name + ) const + { + return block_table.is_in_domain(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mss, + typename msv, + typename tokenizer + > + const config_reader_kernel_1& config_reader_kernel_1:: + block ( + const std::string& name + ) const + { + if (is_block_defined(name) == false) + { + throw config_reader_access_error(name,""); + } + + return *static_cast(block_table[name]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + const std::string& config_reader_kernel_1:: + operator[] ( + const std::string& key + ) const + { + if (is_key_defined(key) == false) + { + throw config_reader_access_error("",key); + } + + return key_table[key]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename queue_of_strings + > + void config_reader_kernel_1:: + get_keys ( + queue_of_strings& keys + ) const + { + keys.clear(); + key_table.reset(); + std::string temp; + while (key_table.move_next()) + { + temp = key_table.element().key(); + keys.enqueue(temp); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_keys ( + std::vector& keys + ) const + { + keys.clear(); + key_table.reset(); + while (key_table.move_next()) + { + keys.push_back(key_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_keys ( + std_vector_c& keys + ) const + { + keys.clear(); + key_table.reset(); + while (key_table.move_next()) + { + keys.push_back(key_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename queue_of_strings + > + void config_reader_kernel_1:: + get_blocks ( + queue_of_strings& blocks + ) const + { + blocks.clear(); + block_table.reset(); + std::string temp; + while (block_table.move_next()) + { + temp = block_table.element().key(); + blocks.enqueue(temp); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_blocks ( + std::vector& blocks + ) const + { + blocks.clear(); + block_table.reset(); + while (block_table.move_next()) + { + blocks.push_back(block_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_blocks ( + std_vector_c& blocks + ) const + { + blocks.clear(); + block_table.reset(); + while (block_table.move_next()) + { + blocks.push_back(block_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONFIG_READER_KERNEl_1_ + diff --git a/ml/dlib/dlib/config_reader/config_reader_kernel_abstract.h b/ml/dlib/dlib/config_reader/config_reader_kernel_abstract.h new file mode 100644 index 000000000..e8c44c2b2 --- /dev/null +++ b/ml/dlib/dlib/config_reader/config_reader_kernel_abstract.h @@ -0,0 +1,363 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CONFIG_READER_KERNEl_ABSTRACT_ +#ifdef DLIB_CONFIG_READER_KERNEl_ABSTRACT_ + +#include +#include + +namespace dlib +{ + + class config_reader + { + + /*! + INITIAL VALUE + - there aren't any keys defined for this object + - there aren't any blocks defined for this object + + POINTERS AND REFERENCES TO INTERNAL DATA + The destructor, clear(), and load_from() invalidate pointers + and references to internal data. All other functions are guaranteed + to NOT invalidate pointers or references to internal data. + + WHAT THIS OBJECT REPRESENTS + This object represents something which is intended to be used to read + text configuration files that are defined by the following EBNF (with + config_file as the starting symbol): + + config_file = block; + block = { key_value_pair | sub_block }; + key_value_pair = key_name, "=", value; + sub_block = block_name, "{", block, "}"; + + key_name = identifier; + block_name = identifier; + value = matches any string of text that ends with a newline character, # or }. + note that the trailing newline, # or } is not part of the value though. + identifier = Any string that matches the following regular expression: + [a-zA-Z][a-zA-Z0-9_-\.]* + i.e. Any string that starts with a letter and then is continued + with any number of letters, numbers, _ . or - characters. + + Whitespace and comments are ignored. A comment is text that starts with # (but not \# + since the \ escapes the # so that you can have a # symbol in a value if you want) and + ends in a new line. You can also escape a } (e.g. "\}") if you want to have one in a + value. + + Note that in a value the leading and trailing white spaces are stripped off but any + white space inside the value is preserved. + + Also note that all key_names and block_names within a block syntax group must be unique + but don't have to be globally unique. I.e. different blocks can reuse names. + + EXAMPLE CONFIG FILES: + + Example 1: + #comment. This line is ignored because it starts with # + + #here we have key1 which will have the value of "my value" + key1 = my value + + another_key= another value # this is another key called "another_key" with + # a value of "another value" + + # this key's value is the empty string. I.e. "" + key2= + + Example 2: + #this example illustrates the use of blocks + some_key = blah blah + + # now here is a block + our_block + { + # here we can define some keys and values that are local to this block. + a_key = something + foo = bar + some_key = more stuff # note that it is ok to name our key this even though + # there is a key called some_key above. This is because + # we are doing so inside a different block + } + + another_block { foo = bar2 } # this block has only one key and is all on a single line + !*/ + + public: + + // exception classes + class config_reader_error : public dlib::error + { + /*! + GENERAL + This exception is thrown if there is an error while parsing the + config file. The type member of this exception will be set + to ECONFIG_READER. + + INTERPRETING THIS EXCEPTION + - line_number == the line number the parser was at when the + error occurred. + - if (redefinition) then + - The key or block name on line line_number has already + been defined in this scope which is an error. + - else + - Some other general syntax error was detected + !*/ + public: + const unsigned long line_number; + const bool redefinition; + }; + + class file_not_found : public dlib::error + { + /*! + GENERAL + This exception is thrown if the config file can't be opened for + some reason. The type member of this exception will be set + to ECONFIG_READER. + + INTERPRETING THIS EXCEPTION + - file_name == the name of the config file which we failed to open + !*/ + public: + const std::string file_name; + }; + + + class config_reader_access_error : public dlib::error + { + /*! + GENERAL + This exception is thrown if you try to access a key or + block that doesn't exist inside a config reader. The type + member of this exception will be set to ECONFIG_READER. + !*/ + public: + config_reader_access_error( + const std::string& block_name_, + const std::string& key_name_ + ); + /*! + ensures + - #block_name == block_name_ + - #key_name == key_name_ + !*/ + + const std::string block_name; + const std::string key_name; + }; + + // -------------------------- + + config_reader( + ); + /*! + ensures + - #*this is properly initialized + - This object will not have any keys or blocks defined in it. + throws + - std::bad_alloc + - config_reader_error + !*/ + + config_reader( + std::istream& in + ); + /*! + ensures + - #*this is properly initialized + - reads the config file to parse from the given input stream, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - This object will represent the top most block of the config file. + throws + - std::bad_alloc + - config_reader_error + !*/ + + config_reader( + const std::string& config_file + ); + /*! + ensures + - #*this is properly initialized + - parses the config file named by the config_file string. Specifically, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds in the file. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - This object will represent the top most block of the config file. + throws + - std::bad_alloc + - config_reader_error + - file_not_found + !*/ + + virtual ~config_reader( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + If this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void load_from ( + std::istream& in + ); + /*! + ensures + - reads the config file to parse from the given input stream, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - *this will represent the top most block of the config file contained + in the input stream in. + throws + - std::bad_alloc + If this exception is thrown then *this is unusable + until clear() is called and succeeds + - config_reader_error + If this exception is thrown then this object will + revert to its initial value. + !*/ + + void load_from ( + const std::string& config_file + ); + /*! + ensures + - parses the config file named by the config_file string. Specifically, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds in the file. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - This object will represent the top most block of the config file. + throws + - std::bad_alloc + If this exception is thrown then *this is unusable + until clear() is called and succeeds + - config_reader_error + If this exception is thrown then this object will + revert to its initial value. + - file_not_found + If this exception is thrown then this object will + revert to its initial value. + !*/ + + bool is_key_defined ( + const std::string& key_name + ) const; + /*! + ensures + - if (there is a key with the given name defined within this config_reader's block) then + - returns true + - else + - returns false + !*/ + + bool is_block_defined ( + const std::string& block_name + ) const; + /*! + ensures + - if (there is a sub block with the given name defined within this config_reader's block) then + - returns true + - else + - returns false + !*/ + + typedef config_reader this_type; + const this_type& block ( + const std::string& block_name + ) const; + /*! + ensures + - if (is_block_defined(block_name) == true) then + - returns a const reference to the config_reader that represents the given named sub block + - else + - throws config_reader_access_error + throws + - config_reader_access_error + if this exception is thrown then its block_name field will be set to the + given block_name string. + !*/ + + const std::string& operator[] ( + const std::string& key_name + ) const; + /*! + ensures + - if (is_key_defined(key_name) == true) then + - returns a const reference to the value string associated with the given key in + this config_reader's block. + - else + - throws config_reader_access_error + throws + - config_reader_access_error + if this exception is thrown then its key_name field will be set to the + given key_name string. + !*/ + + template < + typename queue_of_strings + > + void get_keys ( + queue_of_strings& keys + ) const; + /*! + requires + - queue_of_strings is an implementation of queue/queue_kernel_abstract.h + with T set to std::string, or std::vector, or + dlib::std_vector_c + ensures + - #keys == a collection containing all the keys defined in this config_reader's block. + (i.e. for all strings str in keys it is the case that is_key_defined(str) == true) + !*/ + + template < + typename queue_of_strings + > + void get_blocks ( + queue_of_strings& blocks + ) const; + /*! + requires + - queue_of_strings is an implementation of queue/queue_kernel_abstract.h + with T set to std::string, or std::vector, or + dlib::std_vector_c + ensures + - #blocks == a collection containing the names of all the blocks defined in this + config_reader's block. + (i.e. for all strings str in blocks it is the case that is_block_defined(str) == true) + !*/ + + private: + + // restricted functions + config_reader(config_reader&); // copy constructor + config_reader& operator=(config_reader&); // assignment operator + + }; + +} + +#endif // DLIB_CONFIG_READER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/config_reader/config_reader_thread_safe_1.h b/ml/dlib/dlib/config_reader/config_reader_thread_safe_1.h new file mode 100644 index 000000000..1ad250c99 --- /dev/null +++ b/ml/dlib/dlib/config_reader/config_reader_thread_safe_1.h @@ -0,0 +1,456 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONFIG_READER_THREAD_SAFe_ +#define DLIB_CONFIG_READER_THREAD_SAFe_ + +#include "config_reader_kernel_abstract.h" +#include +#include +#include +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../threads.h" +#include "config_reader_thread_safe_abstract.h" + +namespace dlib +{ + + template < + typename config_reader_base, + typename map_string_void + > + class config_reader_thread_safe_1 + { + + /*! + CONVENTION + - get_mutex() == *m + - *cr == the config reader being extended + - block_table[x] == (void*)&block(x) + - block_table.size() == the number of blocks in *cr + - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key) + - if (own_pointers) then + - this object owns the m and cr pointers and should delete them when destructed + !*/ + + public: + + config_reader_thread_safe_1 ( + const config_reader_base* base, + rmutex* m_ + ); + + config_reader_thread_safe_1(); + + typedef typename config_reader_base::config_reader_error config_reader_error; + typedef typename config_reader_base::config_reader_access_error config_reader_access_error; + + config_reader_thread_safe_1( + std::istream& in + ); + + config_reader_thread_safe_1( + const std::string& config_file + ); + + virtual ~config_reader_thread_safe_1( + ); + + void clear ( + ); + + void load_from ( + std::istream& in + ); + + void load_from ( + const std::string& config_file + ); + + bool is_key_defined ( + const std::string& key + ) const; + + bool is_block_defined ( + const std::string& name + ) const; + + typedef config_reader_thread_safe_1 this_type; + const this_type& block ( + const std::string& name + ) const; + + const std::string& operator[] ( + const std::string& key + ) const; + + template < + typename queue_of_strings + > + void get_keys ( + queue_of_strings& keys + ) const; + + template < + typename queue_of_strings + > + void get_blocks ( + queue_of_strings& blocks + ) const; + + inline const rmutex& get_mutex ( + ) const; + + private: + + void fill_block_table ( + ); + /*! + ensures + - block_table.size() == the number of blocks in cr + - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key) + !*/ + + rmutex* m; + config_reader_base* cr; + map_string_void block_table; + const bool own_pointers; + + // restricted functions + config_reader_thread_safe_1(config_reader_thread_safe_1&); + config_reader_thread_safe_1& operator=(config_reader_thread_safe_1&); + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + const config_reader_base* base, + rmutex* m_ + ) : + m(m_), + cr(const_cast(base)), + own_pointers(false) + { + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + ) : + m(0), + cr(0), + own_pointers(true) + { + try + { + m = new rmutex; + cr = new config_reader_base; + } + catch (...) + { + if (m) delete m; + if (cr) delete cr; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + clear( + ) + { + auto_mutex M(*m); + cr->clear(); + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + load_from( + std::istream& in + ) + { + auto_mutex M(*m); + cr->load_from(in); + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + load_from( + const std::string& config_file + ) + { + auto_mutex M(*m); + cr->load_from(config_file); + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + std::istream& in + ) : + m(0), + cr(0), + own_pointers(true) + { + try + { + m = new rmutex; + cr = new config_reader_base(in); + fill_block_table(); + } + catch (...) + { + if (m) delete m; + if (cr) delete cr; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + const std::string& config_file + ) : + m(0), + cr(0), + own_pointers(true) + { + try + { + m = new rmutex; + cr = new config_reader_base(config_file); + fill_block_table(); + } + catch (...) + { + if (m) delete m; + if (cr) delete cr; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + ~config_reader_thread_safe_1( + ) + { + if (own_pointers) + { + delete m; + delete cr; + } + + // clear out the block table + block_table.reset(); + while (block_table.move_next()) + { + delete static_cast(block_table.element().value()); + } + block_table.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + bool config_reader_thread_safe_1:: + is_key_defined ( + const std::string& key + ) const + { + auto_mutex M(*m); + return cr->is_key_defined(key); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + bool config_reader_thread_safe_1:: + is_block_defined ( + const std::string& name + ) const + { + auto_mutex M(*m); + return cr->is_block_defined(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + const config_reader_thread_safe_1& config_reader_thread_safe_1:: + block ( + const std::string& name + ) const + { + auto_mutex M(*m); + if (block_table.is_in_domain(name) == false) + { + throw config_reader_access_error(name,""); + } + + return *static_cast(block_table[name]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + const std::string& config_reader_thread_safe_1:: + operator[] ( + const std::string& key + ) const + { + auto_mutex M(*m); + return (*cr)[key]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + template < + typename queue_of_strings + > + void config_reader_thread_safe_1:: + get_keys ( + queue_of_strings& keys + ) const + { + auto_mutex M(*m); + cr->get_keys(keys); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + template < + typename queue_of_strings + > + void config_reader_thread_safe_1:: + get_blocks ( + queue_of_strings& blocks + ) const + { + auto_mutex M(*m); + cr->get_blocks(blocks); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + const rmutex& config_reader_thread_safe_1:: + get_mutex ( + ) const + { + return *m; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// private member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + fill_block_table ( + ) + { + using namespace std; + // first empty out the block table + block_table.reset(); + while (block_table.move_next()) + { + delete static_cast(block_table.element().value()); + } + block_table.clear(); + + std::vector blocks; + cr->get_blocks(blocks); + + // now fill the block table up to match what is in cr + for (unsigned long i = 0; i < blocks.size(); ++i) + { + config_reader_thread_safe_1* block = new config_reader_thread_safe_1(&cr->block(blocks[i]),m); + void* temp = block; + block_table.add(blocks[i],temp); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONFIG_READER_THREAD_SAFe_ + + diff --git a/ml/dlib/dlib/config_reader/config_reader_thread_safe_abstract.h b/ml/dlib/dlib/config_reader/config_reader_thread_safe_abstract.h new file mode 100644 index 000000000..25bcbae4a --- /dev/null +++ b/ml/dlib/dlib/config_reader/config_reader_thread_safe_abstract.h @@ -0,0 +1,45 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ +#ifdef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ + +#include +#include +#include "config_reader_kernel_abstract.h" +#include "../threads/threads_kernel_abstract.h" + +namespace dlib +{ + + class config_reader_thread_safe + { + + /*! + WHAT THIS EXTENSION DOES FOR config_reader + This object extends a normal config_reader by simply wrapping all + its member functions inside mutex locks to make it safe to use + in a threaded program. + + So this object provides an interface identical to the one defined + in the config_reader/config_reader_kernel_abstract.h file except that + the rmutex returned by get_mutex() is always locked when this + object's member functions are called. + !*/ + + public: + + const rmutex& get_mutex ( + ) const; + /*! + ensures + - returns the rmutex used to make this object thread safe. i.e. returns + the rmutex that is locked when this object's functions are called. + !*/ + + }; + +} + +#endif // DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ + + diff --git a/ml/dlib/dlib/console_progress_indicator.h b/ml/dlib/dlib/console_progress_indicator.h new file mode 100644 index 000000000..8f04aa533 --- /dev/null +++ b/ml/dlib/dlib/console_progress_indicator.h @@ -0,0 +1,207 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ +#define DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ + +#include +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class console_progress_indicator + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for reporting how long a task will take + to complete. + + For example, consider the following bit of code: + + console_progress_indicator pbar(100) + for (int i = 1; i <= 100; ++i) + { + pbar.print_status(i); + long_running_operation(); + } + + The above code will print a message to the console each iteration + which shows how much time is remaining until the loop terminates. + !*/ + + public: + + inline explicit console_progress_indicator ( + double target_value + ); + /*! + ensures + - #target() == target_value + !*/ + + inline void reset ( + double target_value + ); + /*! + ensures + - #target() == target_value + - performs the equivalent of: + *this = console_progress_indicator(target_value) + (i.e. resets this object with a new target value) + + !*/ + + inline double target ( + ) const; + /*! + ensures + - This object attempts to measure how much time is + left until we reach a certain targeted value. This + function returns that targeted value. + !*/ + + inline bool print_status ( + double cur, + bool always_print = false + ); + /*! + ensures + - print_status() assumes it is called with values which are linearly + approaching target(). It will attempt to predict how much time is + remaining until cur becomes equal to target(). + - prints a status message to the screen which indicates how much + more time is left until cur is equal to target() + - if (always_print) then + - This function prints to the screen each time it is called. + - else + - This function throttles the printing so that at most 1 message is + printed each second. Note that it won't print anything to the screen + until about one second has elapsed. This means that the first call + to print_status() never prints to the screen. + - This function returns true if it prints to the screen and false + otherwise. + !*/ + + private: + + double target_val; + + time_t start_time; + double first_val; + double seen_first_val; + time_t last_time; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION DETAILS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + console_progress_indicator:: + console_progress_indicator ( + double target_value + ) : + target_val(target_value), + start_time(0), + first_val(0), + seen_first_val(false), + last_time(0) + { + } + +// ---------------------------------------------------------------------------------------- + + bool console_progress_indicator:: + print_status ( + double cur, + bool always_print + ) + { + const time_t cur_time = std::time(0); + + // if this is the first time print_status has been called + // then collect some information and exit. We will print status + // on the next call. + if (!seen_first_val) + { + start_time = cur_time; + last_time = cur_time; + first_val = cur; + seen_first_val = true; + return false; + } + + if (cur_time != last_time || always_print) + { + last_time = cur_time; + double delta_t = static_cast(cur_time - start_time); + double delta_val = std::abs(cur - first_val); + + // don't do anything if cur is equal to first_val + if (delta_val < std::numeric_limits::epsilon()) + return false; + + double seconds = delta_t/delta_val * std::abs(target_val - cur); + + std::ios::fmtflags oldflags = std::cout.flags(); + std::cout.flags(); + std::cout.setf(std::ios::fixed,std::ios::floatfield); + std::streamsize ss; + + if (seconds < 60) + { + ss = std::cout.precision(0); + std::cout << "Time remaining: " << seconds << " seconds. \r" << std::flush; + } + else if (seconds < 60*60) + { + ss = std::cout.precision(2); + std::cout << "Time remaining: " << seconds/60 << " minutes. \r" << std::flush; + } + else + { + ss = std::cout.precision(2); + std::cout << "Time remaining: " << seconds/60/60 << " hours. \r" << std::flush; + } + + // restore previous output flags and precision settings + std::cout.flags(oldflags); + std::cout.precision(ss); + + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + double console_progress_indicator:: + target ( + ) const + { + return target_val; + } + +// ---------------------------------------------------------------------------------------- + + void console_progress_indicator:: + reset ( + double target_value + ) + { + *this = console_progress_indicator(target_value); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ + diff --git a/ml/dlib/dlib/control.h b/ml/dlib/dlib/control.h new file mode 100644 index 000000000..85d00817d --- /dev/null +++ b/ml/dlib/dlib/control.h @@ -0,0 +1,11 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONTRoL_ +#define DLIB_CONTRoL_ + +#include "control/lspi.h" +#include "control/mpc.h" + +#endif // DLIB_CONTRoL_ + + diff --git a/ml/dlib/dlib/control/approximate_linear_models.h b/ml/dlib/dlib/control/approximate_linear_models.h new file mode 100644 index 000000000..9732d71e9 --- /dev/null +++ b/ml/dlib/dlib/control/approximate_linear_models.h @@ -0,0 +1,128 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ +#define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ + +#include "approximate_linear_models_abstract.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + struct process_sample + { + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::state_type state_type; + typedef typename feature_extractor::action_type action_type; + + process_sample(){} + + process_sample( + const state_type& s, + const action_type& a, + const state_type& n, + const double& r + ) : state(s), action(a), next_state(n), reward(r) {} + + state_type state; + action_type action; + state_type next_state; + double reward; + }; + + template < typename feature_extractor > + void serialize (const process_sample& item, std::ostream& out) + { + serialize(item.state, out); + serialize(item.action, out); + serialize(item.next_state, out); + serialize(item.reward, out); + } + + template < typename feature_extractor > + void deserialize (process_sample& item, std::istream& in) + { + deserialize(item.state, in); + deserialize(item.action, in); + deserialize(item.next_state, in); + deserialize(item.reward, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class policy + { + public: + + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::state_type state_type; + typedef typename feature_extractor::action_type action_type; + + + policy ( + ) + { + w.set_size(fe.num_features()); + w = 0; + } + + policy ( + const matrix& weights_, + const feature_extractor& fe_ + ) : w(weights_), fe(fe_) {} + + action_type operator() ( + const state_type& state + ) const + { + return fe.find_best_action(state,w); + } + + const feature_extractor& get_feature_extractor ( + ) const { return fe; } + + const matrix& get_weights ( + ) const { return w; } + + + private: + matrix w; + feature_extractor fe; + }; + + template < typename feature_extractor > + inline void serialize(const policy& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.get_feature_extractor(), out); + serialize(item.get_weights(), out); + } + template < typename feature_extractor > + inline void deserialize(policy& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::policy object."); + feature_extractor fe; + matrix w; + deserialize(fe, in); + deserialize(w, in); + item = policy(w,fe); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ + diff --git a/ml/dlib/dlib/control/approximate_linear_models_abstract.h b/ml/dlib/dlib/control/approximate_linear_models_abstract.h new file mode 100644 index 000000000..59dac4276 --- /dev/null +++ b/ml/dlib/dlib/control/approximate_linear_models_abstract.h @@ -0,0 +1,213 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_ +#ifdef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct example_feature_extractor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines the interface a feature extractor must implement if it + is to be used with the process_sample and policy objects defined at the + bottom of this file. Moreover, it is meant to represent the core part + of a model used in a reinforcement learning algorithm. + + In particular, this object models a Q(state,action) function where + Q(state,action) == dot(w, PSI(state,action)) + where PSI(state,action) is a feature vector and w is a parameter + vector. + + Therefore, a feature extractor defines how the PSI(x,y) feature vector is + calculated. It also defines the types used to represent the state and + action objects. + + + THREAD SAFETY + Instances of this object are required to be threadsafe, that is, it should + be safe for multiple threads to make concurrent calls to the member + functions of this object. + !*/ + + // The state and actions can be any types so long as you provide typedefs for them. + typedef T state_type; + typedef U action_type; + // We can also say that the last element in the weight vector w must be 1. This + // can be useful for including a prior into your model. + const static bool force_last_weight_to_1 = false; + + example_feature_extractor( + ); + /*! + ensures + - this object is properly initialized. + !*/ + + unsigned long num_features( + ) const; + /*! + ensures + - returns the dimensionality of the PSI() feature vector. + !*/ + + action_type find_best_action ( + const state_type& state, + const matrix& w + ) const; + /*! + ensures + - returns the action A that maximizes Q(state,A) = dot(w,PSI(state,A)). + That is, this function finds the best action to take in the given state + when our model is parameterized by the given weight vector w. + !*/ + + void get_features ( + const state_type& state, + const action_type& action, + matrix& feats + ) const; + /*! + ensures + - #feats.size() == num_features() + - #feats == PSI(state,action) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + struct process_sample + { + /*! + REQUIREMENTS ON feature_extractor + feature_extractor should implement the example_feature_extractor interface + defined at the top of this file. + + WHAT THIS OBJECT REPRESENTS + This object holds a training sample for a reinforcement learning algorithm. + In particular, it should be a sample from some process where the process + was in state this->state, then took this->action action which resulted in + receiving this->reward and ending up in the state this->next_state. + !*/ + + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::state_type state_type; + typedef typename feature_extractor::action_type action_type; + + process_sample(){} + + process_sample( + const state_type& s, + const action_type& a, + const state_type& n, + const double& r + ) : state(s), action(a), next_state(n), reward(r) {} + + state_type state; + action_type action; + state_type next_state; + double reward; + }; + + template < typename feature_extractor > + void serialize (const process_sample& item, std::ostream& out); + template < typename feature_extractor > + void deserialize (process_sample& item, std::istream& in); + /*! + provides serialization support. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class policy + { + /*! + REQUIREMENTS ON feature_extractor + feature_extractor should implement the example_feature_extractor interface + defined at the top of this file. + + WHAT THIS OBJECT REPRESENTS + This is a policy based on the supplied feature_extractor model. In + particular, it maps from feature_extractor::state_type to the best action + to take in that state. + !*/ + + public: + + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::state_type state_type; + typedef typename feature_extractor::action_type action_type; + + + policy ( + ); + /*! + ensures + - #get_feature_extractor() == feature_extractor() + (i.e. it will have its default value) + - #get_weights().size() == #get_feature_extractor().num_features() + - #get_weights() == 0 + !*/ + + policy ( + const matrix& weights, + const feature_extractor& fe + ); + /*! + requires + - fe.num_features() == weights.size() + ensures + - #get_feature_extractor() == fe + - #get_weights() == weights + !*/ + + action_type operator() ( + const state_type& state + ) const; + /*! + ensures + - returns get_feature_extractor().find_best_action(state,w); + !*/ + + const feature_extractor& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used by this object + !*/ + + const matrix& get_weights ( + ) const; + /*! + ensures + - returns the parameter vector (w) associated with this object. The length + of the vector is get_feature_extractor().num_features(). + !*/ + + }; + + template < typename feature_extractor > + void serialize(const policy& item, std::ostream& out); + template < typename feature_extractor > + void deserialize(policy& item, std::istream& in); + /*! + provides serialization support. + !*/ + +// ---------------------------------------------------------------------------------------- + + +#endif // DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/control/lspi.h b/ml/dlib/dlib/control/lspi.h new file mode 100644 index 000000000..b21a501d2 --- /dev/null +++ b/ml/dlib/dlib/control/lspi.h @@ -0,0 +1,188 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LSPI_Hh_ +#define DLIB_LSPI_Hh_ + +#include "lspi_abstract.h" +#include "approximate_linear_models.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class lspi + { + public: + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::state_type state_type; + typedef typename feature_extractor::action_type action_type; + + explicit lspi( + const feature_extractor& fe_ + ) : fe(fe_) + { + init(); + } + + lspi( + ) + { + init(); + } + + double get_discount ( + ) const { return discount; } + + void set_discount ( + double value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < value && value <= 1, + "\t void lspi::set_discount(value)" + << "\n\t invalid inputs were given to this function" + << "\n\t value: " << value + ); + discount = value; + } + + const feature_extractor& get_feature_extractor ( + ) const { return fe; } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void lspi::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps_: " << eps_ + ); + eps = eps_; + } + + double get_epsilon ( + ) const + { + return eps; + } + + void set_lambda ( + double lambda_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(lambda_ >= 0, + "\t void lspi::set_lambda(lambda_)" + << "\n\t invalid inputs were given to this function" + << "\n\t lambda_: " << lambda_ + ); + lambda = lambda_; + } + + double get_lambda ( + ) const + { + return lambda; + } + + void set_max_iterations ( + unsigned long max_iter + ) { max_iterations = max_iter; } + + unsigned long get_max_iterations ( + ) { return max_iterations; } + + template + policy train ( + const vector_type& samples + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() > 0, + "\t policy lspi::train(samples)" + << "\n\t invalid inputs were given to this function" + ); + + matrix w(fe.num_features()); + w = 0; + matrix prev_w, b, f1, f2; + + matrix A; + + double change; + unsigned long iter = 0; + do + { + A = identity_matrix(fe.num_features())*lambda; + b = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + fe.get_features(samples[i].state, samples[i].action, f1); + fe.get_features(samples[i].next_state, + fe.find_best_action(samples[i].next_state,w), + f2); + A += f1*trans(f1 - discount*f2); + b += f1*samples[i].reward; + } + + prev_w = w; + if (feature_extractor::force_last_weight_to_1) + w = join_cols(pinv(colm(A,range(0,A.nc()-2)))*(b-colm(A,A.nc()-1)),mat(1.0)); + else + w = pinv(A)*b; + + change = length(w-prev_w); + ++iter; + + if (verbose) + std::cout << "iteration: " << iter << "\tchange: " << change << std::endl; + + } while(change > eps && iter < max_iterations); + + return policy(w,fe); + } + + + private: + + void init() + { + lambda = 0.01; + discount = 0.8; + eps = 0.01; + verbose = false; + max_iterations = 100; + } + + double lambda; + double discount; + double eps; + bool verbose; + unsigned long max_iterations; + feature_extractor fe; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LSPI_Hh_ + diff --git a/ml/dlib/dlib/control/lspi_abstract.h b/ml/dlib/dlib/control/lspi_abstract.h new file mode 100644 index 000000000..f262d16f4 --- /dev/null +++ b/ml/dlib/dlib/control/lspi_abstract.h @@ -0,0 +1,193 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LSPI_ABSTRACT_Hh_ +#ifdef DLIB_LSPI_ABSTRACT_Hh_ + +#include "approximate_linear_models_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class lspi + { + /*! + REQUIREMENTS ON feature_extractor + feature_extractor should implement the example_feature_extractor interface + defined at the top of dlib/control/approximate_linear_models_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object is an implementation of the reinforcement learning algorithm + described in the following paper: + Lagoudakis, Michail G., and Ronald Parr. "Least-squares policy + iteration." The Journal of Machine Learning Research 4 (2003): + 1107-1149. + + This means that it takes a bunch of training data in the form of + process_samples and outputs a policy that hopefully performs well when run + on the process that generated those samples. + !*/ + + public: + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::state_type state_type; + typedef typename feature_extractor::action_type action_type; + + explicit lspi( + const feature_extractor& fe_ + ); + /*! + ensures + - #get_feature_extractor() == fe_ + - #get_lambda() == 0.01 + - #get_discount == 0.8 + - #get_epsilon() == 0.01 + - is not verbose + - #get_max_iterations() == 100 + !*/ + + lspi( + ); + /*! + ensures + - #get_feature_extractor() == feature_extractor() + (i.e. it will have its default value) + - #get_lambda() == 0.01 + - #get_discount == 0.8 + - #get_epsilon() == 0.01 + - is not verbose + - #get_max_iterations() == 100 + !*/ + + double get_discount ( + ) const; + /*! + ensures + - returns the discount applied to the sum of rewards in the Bellman + equation. + !*/ + + void set_discount ( + double value + ); + /*! + requires + - 0 < value <= 1 + ensures + - #get_discount() == value + !*/ + + const feature_extractor& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used by this object + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer to + train. + !*/ + + void set_lambda ( + double lambda_ + ); + /*! + requires + - lambda >= 0 + ensures + - #get_lambda() == lambda + !*/ + + double get_lambda ( + ) const; + /*! + ensures + - returns the regularization parameter. It is the parameter that + determines the trade off between trying to fit the training data + exactly or allowing more errors but hopefully improving the + generalization ability of the resulting function. Smaller values + encourage exact fitting while larger values of lambda may encourage + better generalization. + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + unsigned long get_max_iterations ( + ); + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + template < + typename vector_type + > + policy train ( + const vector_type& samples + ) const; + /*! + requires + - samples.size() > 0 + - samples is something with an interface that looks like + std::vector>. That is, it should + be some kind of array of process_sample objects. + ensures + - Trains a policy based on the given data and returns the results. The + idea is to find a policy that will obtain the largest possible reward + when run on the process that generated the samples. In particular, + if the returned policy is P then: + - P(S) == the best action to take when in state S. + - if (feature_extractor::force_last_weight_to_1) then + - The last element of P.get_weights() is 1. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LSPI_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/control/mpc.h b/ml/dlib/dlib/control/mpc.h new file mode 100644 index 000000000..48ef2b72d --- /dev/null +++ b/ml/dlib/dlib/control/mpc.h @@ -0,0 +1,370 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MPC_Hh_ +#define DLIB_MPC_Hh_ + +#include "mpc_abstract.h" +#include "../matrix.h" +#include "../algs.h" + + +namespace dlib +{ + template < + long S_, + long I_, + unsigned long horizon_ + > + class mpc + { + + public: + + const static long S = S_; + const static long I = I_; + const static unsigned long horizon = horizon_; + + mpc( + ) + { + A = 0; + B = 0; + C = 0; + Q = 0; + R = 0; + lower = 0; + upper = 0; + + max_iterations = 0; + eps = 0.01; + for (unsigned long i = 0; i < horizon; ++i) + { + target[i].set_size(A.nr()); + target[i] = 0; + + controls[i].set_size(B.nc()); + controls[i] = 0; + } + lambda = 0; + } + + mpc ( + const matrix& A_, + const matrix& B_, + const matrix& C_, + const matrix& Q_, + const matrix& R_, + const matrix& lower_, + const matrix& upper_ + ) : A(A_), B(B_), C(C_), Q(Q_), R(R_), lower(lower_), upper(upper_) + { + // make sure requires clause is not broken + DLIB_ASSERT(A.nr() > 0 && B.nc() > 0, + "\t mpc::mpc()" + << "\n\t invalid inputs were given to this function" + << "\n\t A.nr(): " << A.nr() + << "\n\t B.nc(): " << B.nc() + ); + + DLIB_ASSERT(A.nr() == A.nc() && + A.nr() == B.nr() && + A.nr() == C.nr() && + A.nr() == Q.nr(), + "\t mpc::mpc()" + << "\n\t invalid inputs were given to this function" + << "\n\t A.nr(): " << A.nr() + << "\n\t A.nc(): " << A.nc() + << "\n\t B.nr(): " << B.nr() + << "\n\t C.nr(): " << C.nr() + << "\n\t Q.nr(): " << Q.nr() + ); + DLIB_ASSERT( + B.nc() == R.nr() && + B.nc() == lower.nr() && + B.nc() == upper.nr() , + "\t mpc::mpc()" + << "\n\t invalid inputs were given to this function" + << "\n\t B.nr(): " << B.nr() + << "\n\t B.nc(): " << B.nc() + << "\n\t lower.nr(): " << lower.nr() + << "\n\t upper.nr(): " << upper.nr() + ); + DLIB_ASSERT(min(Q) >= 0 && + min(R) > 0 && + min(upper-lower) >= 0, + "\t mpc::mpc()" + << "\n\t invalid inputs were given to this function" + << "\n\t min(Q): " << min(Q) + << "\n\t min(R): " << min(R) + << "\n\t min(upper-lower): " << min(upper-lower) + ); + + + max_iterations = 10000; + eps = 0.01; + for (unsigned long i = 0; i < horizon; ++i) + { + target[i].set_size(A.nr()); + target[i] = 0; + + controls[i].set_size(B.nc()); + controls[i] = 0; + } + + // Bound the maximum eigenvalue of the hessian by computing the trace of the + // hessian matrix. + lambda = sum(R)*horizon; + matrix temp = diagm(Q); + for (unsigned long c = 0; c < horizon; ++c) + { + lambda += trace(trans(B)*temp*B); + Q_diag[horizon-c-1] = diag(trans(B)*temp*B); + temp = trans(A)*temp*A + diagm(Q); + } + + } + + const matrix& get_A ( + ) const { return A; } + const matrix& get_B ( + ) const { return B; } + const matrix& get_C ( + ) const { return C; } + const matrix& get_Q ( + ) const { return Q; } + const matrix& get_R ( + ) const { return R; } + const matrix& get_lower_constraints ( + ) const { return lower; } + const matrix& get_upper_constraints ( + ) const { return upper; } + + void set_target ( + const matrix& val, + const unsigned long time + ) + { + DLIB_ASSERT(time < horizon, + "\t void mpc::set_target(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t time: " << time + << "\n\t horizon: " << horizon + ); + + target[time] = val; + } + + void set_target ( + const matrix& val + ) + { + for (unsigned long i = 0; i < horizon; ++i) + target[i] = val; + } + + void set_last_target ( + const matrix& val + ) + { + set_target(val, horizon-1); + } + + const matrix& get_target ( + const unsigned long time + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(time < horizon, + "\t matrix mpc::get_target(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t time: " << time + << "\n\t horizon: " << horizon + ); + + return target[time]; + } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void mpc::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps_: " << eps_ + ); + eps = eps_; + } + + double get_epsilon ( + ) const + { + return eps; + } + + matrix operator() ( + const matrix& current_state + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(min(R) > 0 && A.nr() == current_state.size(), + "\t matrix mpc::operator(current_state)" + << "\n\t invalid inputs were given to this function" + << "\n\t min(R): " << min(R) + << "\n\t A.nr(): " << A.nr() + << "\n\t current_state.size(): " << current_state.size() + ); + + // Shift the inputs over by one time step so we can use them to warm start the + // optimizer. + for (unsigned long i = 1; i < horizon; ++i) + controls[i-1] = controls[i]; + + solve_linear_mpc(current_state); + + for (unsigned long i = 1; i < horizon; ++i) + target[i-1] = target[i]; + + return controls[0]; + } + + private: + + + // These temporary variables here just to avoid reallocating them on each call to + // operator(). + matrix M[horizon]; + matrix MM[horizon]; + matrix df[horizon]; + matrix v[horizon]; + matrix v_old[horizon]; + + void solve_linear_mpc ( + const matrix& initial_state + ) + { + // make it so MM == trans(K)*Q*(M-target) + M[0] = A*initial_state + C; + for (unsigned long i = 1; i < horizon; ++i) + M[i] = A*M[i-1] + C; + for (unsigned long i = 0; i < horizon; ++i) + M[i] = diagm(Q)*(M[i]-target[i]); + for (long i = (long)horizon-2; i >= 0; --i) + M[i] += trans(A)*M[i+1]; + for (unsigned long i = 0; i < horizon; ++i) + MM[i] = trans(B)*M[i]; + + + + unsigned long iter = 0; + for (; iter < max_iterations; ++iter) + { + // compute current gradient and put it into df. + // df == H*controls + MM; + M[0] = B*controls[0]; + for (unsigned long i = 1; i < horizon; ++i) + M[i] = A*M[i-1] + B*controls[i]; + for (unsigned long i = 0; i < horizon; ++i) + M[i] = diagm(Q)*M[i]; + for (long i = (long)horizon-2; i >= 0; --i) + M[i] += trans(A)*M[i+1]; + for (unsigned long i = 0; i < horizon; ++i) + df[i] = MM[i] + trans(B)*M[i] + diagm(R)*controls[i]; + + + + // Check the stopping condition, which is the magnitude of the largest element + // of the gradient. + double max_df = 0; + unsigned long max_t = 0; + long max_v = 0; + for (unsigned long i = 0; i < horizon; ++i) + { + for (long j = 0; j < controls[i].size(); ++j) + { + // if this variable isn't an active constraint then we care about it's + // derivative. + if (!((controls[i](j) <= lower(j) && df[i](j) > 0) || + (controls[i](j) >= upper(j) && df[i](j) < 0))) + { + if (std::abs(df[i](j)) > max_df) + { + max_df = std::abs(df[i](j)); + max_t = i; + max_v = j; + } + } + } + } + if (max_df < eps) + break; + + + + // We will start out by doing a little bit of coordinate descent because it + // allows us to optimize individual variables exactly. Since we are warm + // starting each iteration with a really good solution this helps speed + // things up a lot. + const unsigned long smo_iters = 50; + if (iter < smo_iters) + { + if (Q_diag[max_t](max_v) == 0) continue; + + // Take the optimal step but just for one variable. + controls[max_t](max_v) = -(df[max_t](max_v)-Q_diag[max_t](max_v)*controls[max_t](max_v))/Q_diag[max_t](max_v); + controls[max_t](max_v) = put_in_range(lower(max_v), upper(max_v), controls[max_t](max_v)); + + // If this is the last SMO iteration then don't forget to initialize v + // for the gradient steps. + if (iter+1 == smo_iters) + { + for (unsigned long i = 0; i < horizon; ++i) + v[i] = controls[i]; + } + } + else + { + // Take a projected gradient step. + for (unsigned long i = 0; i < horizon; ++i) + { + v_old[i] = v[i]; + v[i] = dlib::clamp(controls[i] - 1.0/lambda * df[i], lower, upper); + controls[i] = dlib::clamp(v[i] + (std::sqrt(lambda)-1)/(std::sqrt(lambda)+1)*(v[i]-v_old[i]), lower, upper); + } + } + } + } + + unsigned long max_iterations; + double eps; + + matrix A; + matrix B; + matrix C; + matrix Q; + matrix R; + matrix lower; + matrix upper; + matrix target[horizon]; + + double lambda; // abound on the largest eigenvalue of the hessian matrix. + matrix Q_diag[horizon]; + matrix controls[horizon]; + + }; + +} + +#endif // DLIB_MPC_Hh_ + diff --git a/ml/dlib/dlib/control/mpc_abstract.h b/ml/dlib/dlib/control/mpc_abstract.h new file mode 100644 index 000000000..b4421c076 --- /dev/null +++ b/ml/dlib/dlib/control/mpc_abstract.h @@ -0,0 +1,276 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MPC_ABSTRACT_Hh_ +#ifdef DLIB_MPC_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + template < + long S_, + long I_, + unsigned long horizon_ + > + class mpc + { + /*! + REQUIREMENTS ON horizon_ + horizon_ > 0 + + REQUIREMENTS ON S_ + S_ >= 0 + + REQUIREMENTS ON I_ + I_ >= 0 + + WHAT THIS OBJECT REPRESENTS + This object implements a linear model predictive controller. To explain + what that means, suppose you have some process you want to control and the + process dynamics are described by the linear equation: + x_{i+1} = A*x_i + B*u_i + C + That is, the next state the system goes into is a linear function of its + current state (x_i) and the current control (u_i) plus some constant bias + or disturbance. + + A model predictive controller can find the control (u) you should apply to + drive the state (x) to some reference value, or alternatively to make the + state track some reference time-varying sequence. It does this by + simulating the process for horizon_ time steps and selecting the control + that leads to the best performance over the next horizon_ steps. + + To be precise, each time you ask this object for a control, it solves the + following quadratic program: + + min sum_i trans(x_i-target_i)*Q*(x_i-target_i) + trans(u_i)*R*u_i + x_i,u_i + + such that: x_0 == current_state + x_{i+1} == A*x_i + B*u_i + C + lower <= u_i <= upper + 0 <= i < horizon_ + + and reports u_0 as the control you should take given that you are currently + in current_state. Q and R are user supplied matrices that define how we + penalize variations away from the target state as well as how much we want + to avoid generating large control signals. + + Finally, the algorithm we use to solve this quadratic program is based + largely on the method described in: + A Fast Gradient method for embedded linear predictive control (2011) + by Markus Kogel and Rolf Findeisen + !*/ + + public: + + const static long S = S_; + const static long I = I_; + const static unsigned long horizon = horizon_; + + mpc( + ); + /*! + ensures + - #get_max_iterations() == 0 + - The A,B,C,Q,R,lower, and upper parameter matrices are filled with zeros. + Therefore, to use this object you must initialize it via the constructor + that supplies these parameters. + !*/ + + mpc ( + const matrix& A, + const matrix& B, + const matrix& C, + const matrix& Q, + const matrix& R, + const matrix& lower, + const matrix& upper + ); + /*! + requires + - A.nr() > 0 + - B.nc() > 0 + - A.nr() == A.nc() == B.nr() == C.nr() == Q.nr() + - B.nc() == R.nr() == lower.nr() == upper.nr() + - min(Q) >= 0 + - min(R) > 0 + - min(upper-lower) >= 0 + ensures + - #get_A() == A + - #get_B() == B + - #get_C() == C + - #get_Q() == Q + - #get_R() == R + - #get_lower_constraints() == lower + - #get_upper_constraints() == upper + - for all valid i: + - get_target(i) == a vector of all zeros + - get_target(i).size() == A.nr() + - #get_max_iterations() == 10000 + - #get_epsilon() == 0.01 + !*/ + + const matrix& get_A ( + ) const; + /*! + ensures + - returns the A matrix from the quadratic program defined above. + !*/ + + const matrix& get_B ( + ) const; + /*! + ensures + - returns the B matrix from the quadratic program defined above. + !*/ + + const matrix& get_C ( + ) const; + /*! + ensures + - returns the C matrix from the quadratic program defined above. + !*/ + + const matrix& get_Q ( + ) const; + /*! + ensures + - returns the diagonal of the Q matrix from the quadratic program defined + above. + !*/ + + const matrix& get_R ( + ) const; + /*! + ensures + - returns the diagonal of the R matrix from the quadratic program defined + above. + !*/ + + const matrix& get_lower_constraints ( + ) const; + /*! + ensures + - returns the lower matrix from the quadratic program defined above. All + controls generated by this object will have values no less than this + lower bound. That is, any control u will satisfy min(u-lower) >= 0. + !*/ + + const matrix& get_upper_constraints ( + ) const; + /*! + ensures + - returns the upper matrix from the quadratic program defined above. All + controls generated by this object will have values no larger than this + upper bound. That is, any control u will satisfy min(upper-u) >= 0. + !*/ + + const matrix& get_target ( + const unsigned long time + ) const; + /*! + requires + - time < horizon + ensures + - This object will try to find the control sequence that results in the + process obtaining get_target(time) state at the indicated time. Note + that the next time instant after "right now" is time 0. + !*/ + + void set_target ( + const matrix& val, + const unsigned long time + ); + /*! + requires + - time < horizon + ensures + - #get_target(time) == val + !*/ + + void set_target ( + const matrix& val + ); + /*! + ensures + - for all valid t: + - #get_target(t) == val + !*/ + + void set_last_target ( + const matrix& val + ); + /*! + ensures + - performs: set_target(val, horizon-1) + !*/ + + unsigned long get_max_iterations ( + ) const; + /*! + ensures + - When operator() is called it solves an optimization problem to + get_epsilon() precision to determine the next control action. In + particular, we run the optimizer until the magnitude of each element of + the gradient vector is less than get_epsilon() or until + get_max_iterations() solver iterations have been executed. + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + void set_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - When operator() is called it solves an optimization problem to + get_epsilon() precision to determine the next control action. In + particular, we run the optimizer until the magnitude of each element of + the gradient vector is less than get_epsilon() or until + get_max_iterations() solver iterations have been executed. This means + that smaller epsilon values will give more accurate outputs but may take + longer to compute. + !*/ + + matrix operator() ( + const matrix& current_state + ); + /*! + requires + - min(R) > 0 + - A.nr() == current_state.size() + ensures + - Solves the model predictive control problem defined by the arguments to + this objects constructor, assuming that the starting state is given by + current_state. Then we return the control that should be taken in the + current state that best optimizes the quadratic objective function + defined above. + - We also shift over the target states so that you only need to update the + last one (if you are using non-zero target states) via a call to + set_last_target()). In particular, for all valid t, it will be the case + that: + - #get_target(t) == get_target(t+1) + - #get_target(horizon-1) == get_target(horizon-1) + !*/ + + }; + +} + +#endif // DLIB_MPC_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/cpp_pretty_printer.h b/ml/dlib/dlib/cpp_pretty_printer.h new file mode 100644 index 000000000..5315559ba --- /dev/null +++ b/ml/dlib/dlib/cpp_pretty_printer.h @@ -0,0 +1,39 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CPP_PRETTY_PRINTEr_ +#define DLIB_CPP_PRETTY_PRINTEr_ + + +#include "cpp_pretty_printer/cpp_pretty_printer_kernel_1.h" +#include "cpp_pretty_printer/cpp_pretty_printer_kernel_2.h" +#include "cpp_tokenizer.h" +#include "stack.h" + +namespace dlib +{ + + class cpp_pretty_printer + { + cpp_pretty_printer() {} + + + typedef stack::kernel_1a stack; + typedef cpp_tokenizer::kernel_1a tok; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef cpp_pretty_printer_kernel_1 + kernel_1a; + + // kernel_2a + typedef cpp_pretty_printer_kernel_2 + kernel_2a; + + }; +} + +#endif // DLIB_CPP_PRETTY_PRINTEr_ + diff --git a/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h new file mode 100644 index 000000000..668d5049d --- /dev/null +++ b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h @@ -0,0 +1,583 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CPP_PRETTY_PRINTER_KERNEl_1_ +#define DLIB_CPP_PRETTY_PRINTER_KERNEl_1_ + +#include +#include +#include +#include "cpp_pretty_printer_kernel_abstract.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename stack, + typename tok + > + class cpp_pretty_printer_kernel_1 + { + /*! + REQUIREMENTS ON stack + must be an implementation of stack/stack_kernel_abstract.h and + stack::type == unsigned long + + REQUIREMENTS ON tok + must be an implementation of tokenizer/tokenizer_kernel_abstract.h + + INFO + This implementation applies a color scheme, turns include directives + such as #include "file.h" into links to file.h.html, and it also puts + HTML anchor points on function and class declarations. + !*/ + + public: + + cpp_pretty_printer_kernel_1 ( + ); + + virtual ~cpp_pretty_printer_kernel_1 ( + ); + + void print ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const; + + void print_and_number ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const; + + private: + + const std::string htmlify ( + const std::string& str + ) const; + /*! + ensures + - str == str but with any '<' replaced with '<', any '>' replaced + with '>', and any '&' replaced with '&' + !*/ + + // data members + mutable tok t; + + void number ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - prints in to out and adds line numbers + !*/ + + // restricted functions + cpp_pretty_printer_kernel_1(const cpp_pretty_printer_kernel_1&); // copy constructor + cpp_pretty_printer_kernel_1& operator=(const cpp_pretty_printer_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + cpp_pretty_printer_kernel_1:: + cpp_pretty_printer_kernel_1 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + cpp_pretty_printer_kernel_1:: + ~cpp_pretty_printer_kernel_1 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + void cpp_pretty_printer_kernel_1:: + print ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const + { + using namespace std; + + if (!out) + throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::print"); + + t.set_stream(in); + + out << "" << title << "
\n";
+        if (!out)
+            throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::print");
+
+        unsigned long scope = 0; // counts the number of new scopes we have entered 
+                        // since we were at a scope where functions can be declared
+
+        bool recently_seen_class_keyword = false;
+            // true if we have seen the keywords class, struct, or enum and
+            // we have not seen any identifiers or { characters
+
+        bool recently_seen_include = false;
+            // true if we have seen the #include keyword and have not seen double
+            // quoted text or >
+
+        bool recently_seen_new_scope = false;  
+            // true if we have seen the keywords class, namespace, or struct and
+            // we have not seen the characters {, ), or ; since then
+
+        bool recently_seen_paren = false;
+            // true if we have seen a ) and we have only seen white_space or comments since
+
+        bool in_initialization_list = false;
+            // true if we have seen a ) followed by any white space or comments and then
+            // followed by a : (in scope==0 with recently_seen_preprocessor==false) and we 
+            // have not yet seen the character { or ;
+
+        bool recently_seen_preprocessor = false;
+            // true if we have seen the #pragma or #if or #define or #elif keywords and have 
+            // not seen an end of line.
+
+        bool recently_seen_extern = false;
+            // true if we have seen the extern keyword and haven't seen a ; or { yet.
+
+        unsigned long paren_count = 0; 
+            // this is the number of ( we have seen minus the number of ) we have
+            // seen.
+            
+
+        int type;
+        stack scopes; // a stack to hold old scopes
+        string token, temp;
+        t.get_token(type,token);
+        while (type != tok::END_OF_FILE)
+        {
+            switch (type)
+            {
+            case tok::IDENTIFIER: // ------------------------------------------
+                if ( recently_seen_class_keyword)
+                {
+                    // this might be a class name so check if there is a 
+                    // ; or identifier or * or & coming up.
+                    type = t.peek_type();
+                    temp.clear();
+                    if (type == tok::WHITE_SPACE)
+                    {
+                        t.get_token(type,temp);
+                        if (temp.find_first_of("\n\r") != string::npos)
+                            recently_seen_preprocessor = false;
+                    }
+                    if (t.peek_token() != ";" && t.peek_type() != tok::IDENTIFIER &&
+                        t.peek_token() != "*" && t.peek_token() != "&")
+                    {
+                        // this is the name of a class or struct in a class or
+                        // struct declaration.
+                        out << "" << token << "" << temp;
+                    }
+                    else
+                    {
+                        out << token << temp;
+                    }
+                }
+                else if ( !in_initialization_list &&
+                     !recently_seen_preprocessor )
+                {
+                    // this might be a function name so check if there is a 
+                    // ( coming up.
+                    type = t.peek_type();
+                    temp.clear();
+                    if (type == tok::WHITE_SPACE)
+                    {
+                        t.get_token(type,temp);
+                        type = t.peek_type();
+                    }
+                    if (type == tok::OTHER && t.peek_token() == "(")
+                    {
+                        if (scope == 0 && paren_count == 0)
+                        {
+                            // this is a function definition or prototype
+                            out << "" << token << "" << temp;
+                        }
+                        else
+                        {
+                            // this is a function call (probably) 
+                            out << "" << token << "" << temp;
+                        }
+                    }
+                    else
+                    {
+                        out << token << temp;
+                    }
+                }
+                else
+                {
+                    out << token;
+                }
+                
+
+
+                recently_seen_class_keyword = false;
+                recently_seen_paren = false;
+                break;
+
+            case tok::KEYWORD: // ---------------------------------------------
+                if (scope == 0 && token == "operator")
+                {
+                    // Doing this is sort of weird since operator is really a keyword
+                    // but I just like how this looks.
+                    out << "" << token << "";
+                }
+                // this isn't a keyword if it is something like #include 
+                else if ( token == "true" || token == "false")
+                {
+                    // color 'true' and 'false' the same way we color numbers
+                    out << "" << token << "";
+                }
+                else if (!recently_seen_include) 
+                {
+                    // This is a normal keyword
+                    if (token == "char" || token == "unsigned" || token == "signed" ||
+                        token == "short" || token == "int" || token == "long" || 
+                        token == "float" || token == "double" || token == "bool" ||
+                        token == "void" || token == "size_t" || token == "wchar_t")
+                    {
+                        out << "" << token << "";
+                    }
+                    else
+                    {
+                        out << "" << token << "";
+                    }
+                }
+                else
+                {
+                    out << token;
+                }
+
+                if (token == "#include") 
+                {
+                    recently_seen_include = true;
+                }
+                else if (token == "class")
+                {
+                    recently_seen_new_scope = true;
+                    recently_seen_class_keyword = true;
+                }
+                else if (token == "namespace")
+                {
+                    recently_seen_new_scope = true;
+                }
+                else if (token == "enum")
+                {
+                    recently_seen_class_keyword = true;
+                }
+                else if (token == "struct")
+                {
+                    recently_seen_new_scope = true;
+                    recently_seen_class_keyword = true;
+                }
+                else if (token == "#pragma" || token == "#if" || token == "#define" || token == "#elif")
+                {
+                    recently_seen_preprocessor = true;
+                }
+                else if (token == "extern")
+                {
+                    recently_seen_extern = true;
+                }
+                recently_seen_paren = false;
+                break;
+
+            case tok::COMMENT: // ---------------------------------------------
+                {
+                    // if this is a special anchor comment
+                    if (token.size() > 4 &&
+                        token[0] == '/' &&
+                        token[1] == '*' &&
+                        token[2] == '!' &&
+                        token[3] == 'A' &&
+                        token[4] == ' '
+                    )
+                    {
+                        temp = token;
+                        istringstream sin(token);
+                        sin >> temp;
+                        sin >> temp;
+                        sin.get();
+                        // if there was still more stuff in the token then we are ok.
+                        if (sin)
+                            out << "";
+                    }
+                    out << "" << htmlify(token) << "";
+                }
+                break;
+
+            case tok::SINGLE_QUOTED_TEXT: // ----------------------------------
+                {
+                    out << "" << htmlify(token) << "";
+                    recently_seen_paren = false;
+                }
+                break;
+
+            case tok::NUMBER: // -----------------------------------------
+                {
+                    out << "" << token << "";
+                    recently_seen_include = false;
+                }
+                break;
+
+            case tok::WHITE_SPACE: // -----------------------------------------
+                {
+                    out << token;
+                    if (token.find_first_of("\n\r") != string::npos)
+                        recently_seen_preprocessor = false;
+                }
+                break;
+
+            case tok::DOUBLE_QUOTED_TEXT: // ----------------------------------
+                {
+                    if (recently_seen_include)
+                    {
+                        // this is the name of an included file
+                        recently_seen_include = false;
+                        out << "" << htmlify(token) << "";                
+                    }
+                    else
+                    {
+                        // this is just a normal quoted string
+                        out << "" << htmlify(token) << "";
+                    }
+                    recently_seen_paren = false;
+                }
+                break;
+
+            case tok::OTHER: // -----------------------------------------------               
+                switch (token[0])
+                {
+                case '{':
+                    out << "{";  
+                    // if we are entering a new scope
+                    if (recently_seen_new_scope || recently_seen_extern)
+                    {
+                        recently_seen_new_scope = false;
+                        scopes.push(scope);
+                        scope = 0;
+                    }
+                    else
+                    {
+                        ++scope;
+                    }
+                    in_initialization_list = false;
+                    recently_seen_paren = false;
+                    recently_seen_class_keyword = false;
+                    recently_seen_extern = false;
+                    break;
+                case '}':
+                    out << "}";
+                    if (scope > 0)
+                    {
+                        --scope;
+                    }
+                    else if (scopes.size())
+                    {
+                        scopes.pop(scope);
+                    }
+                    recently_seen_paren = false;
+                    break;
+
+                case ':':
+                    out << ':';
+                    if (recently_seen_paren && scope == 0 && 
+                        recently_seen_preprocessor == false)
+                    {
+                        in_initialization_list = true;
+                    }
+                    recently_seen_paren = false;
+                    break;
+
+                case ';': 
+                    out << ';';
+                    recently_seen_new_scope = false;
+                    recently_seen_paren = false;
+                    recently_seen_extern = false;
+                    break;
+
+                case ')':
+                    out << ")";
+                    recently_seen_paren = true;
+                    recently_seen_new_scope = false;
+                    --paren_count;
+                    break;
+
+                case '(':
+                    out << "(";
+                    recently_seen_paren = false;
+                    ++paren_count;
+                    break;
+
+                case '>':
+                    recently_seen_include = false;
+                    out << ">";
+                    recently_seen_paren = false;
+                    break;
+
+                case '<':
+                    out << "<";
+                    recently_seen_paren = false;
+                    break;
+
+                case '&':
+                    out << "&";
+                    recently_seen_paren = false;
+                    break;
+
+                case '=':
+                case '+':
+                case '-':
+                case '/':
+                case '*':
+                case '!':
+                case '|':
+                case '%':
+                    out << "" << token << "";
+                    recently_seen_paren = false;
+                    break;
+
+                default:
+                    out << token;
+                    recently_seen_paren = false;
+                    break;
+
+                } // switch (token[0])
+                break;
+
+            } // switch (type)
+
+            t.get_token(type,token);
+        } // while (type != tok::END_OF_FILE)
+
+
+        out << "\n
"; + if (!out) + throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::print"); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + void cpp_pretty_printer_kernel_1:: + print_and_number ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const + { + using namespace std; + ostringstream sout; + print(in,sout,title); + istringstream sin(sout.str()); + number(sin,out); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + void cpp_pretty_printer_kernel_1:: + number ( + std::istream& in, + std::ostream& out + ) const + { + if (!out) + throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::number"); + + std::string space = "   "; + std::ios::int_type ch; + unsigned long count = 1; + while ((ch=in.get()) != EOF) + { + if (ch != '\n') + { + out << (char)ch; + } + else + { + out << "\n" << count << " " + space; + ++count; + if (count == 10) + space = "  "; + if (count == 100) + space = " "; + if (count == 1000) + space = ""; + } + } + if (!out) + throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::number"); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + const std::string cpp_pretty_printer_kernel_1:: + htmlify ( + const std::string& str + ) const + { + std::string::size_type i; + std::string temp; + for (i = 0; i < str.size(); ++i) + { + if (str[i] == '<') + temp += "<"; + else if (str[i] == '>') + temp += ">"; + else if (str[i] == '&') + temp += "&"; + else + temp += str[i]; + } + return temp; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CPP_PRETTY_PRINTER_KERNEl_1_ + diff --git a/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h new file mode 100644 index 000000000..5ac894b33 --- /dev/null +++ b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h @@ -0,0 +1,520 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CPP_PRETTY_PRINTER_KERNEl_2_ +#define DLIB_CPP_PRETTY_PRINTER_KERNEl_2_ + +#include +#include +#include +#include "cpp_pretty_printer_kernel_abstract.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename stack, + typename tok + > + class cpp_pretty_printer_kernel_2 + { + /*! + REQUIREMENTS ON stack + must be an implementation of stack/stack_kernel_abstract.h and + stack::type == unsigned long + + REQUIREMENTS ON tok + must be an implementation of tokenizer/tokenizer_kernel_abstract.h + + INFO + This implementation applies a black and white color scheme suitable + for printing on a black and white printer. It also places the document + title prominently at the top of the pretty printed source file. + !*/ + + public: + + cpp_pretty_printer_kernel_2 ( + ); + + virtual ~cpp_pretty_printer_kernel_2 ( + ); + + void print ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const; + + void print_and_number ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const; + + private: + + // data members + mutable tok t; + + const std::string htmlify ( + const std::string& str + ) const; + /*! + ensures + - str == str but with any '<' replaced with '<', any '>' replaced + with '>', and any '&' replaced with '&' + !*/ + + void number ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - prints in to out and adds line numbers + !*/ + + // restricted functions + cpp_pretty_printer_kernel_2(const cpp_pretty_printer_kernel_2&); // copy constructor + cpp_pretty_printer_kernel_2& operator=(const cpp_pretty_printer_kernel_2&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + cpp_pretty_printer_kernel_2:: + cpp_pretty_printer_kernel_2 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + cpp_pretty_printer_kernel_2:: + ~cpp_pretty_printer_kernel_2 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + void cpp_pretty_printer_kernel_2:: + print ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const + { + using namespace std; + + if (!out) + throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::print"); + + t.set_stream(in); + + out << "" + << "" << title << "" + << "

" << title << "

\n"
+            << "\n";
+        if (!out)
+            throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::print");
+
+        unsigned long scope = 0; // counts the number of new scopes we have entered 
+                        // since we were at a scope where functions can be declared
+
+        bool recently_seen_class_keyword = false;
+            // true if we have seen the keywords class or struct and
+            // we have not seen any identifiers or { characters
+
+        bool recently_seen_include = false;
+            // true if we have seen the #include keyword and have not seen double
+            // quoted text or >
+
+        bool recently_seen_new_scope = false;  
+            // true if we have seen the keywords class, namespace, or struct and
+            // we have not seen the characters {, ), or ; since then
+
+        bool recently_seen_paren = false;
+            // true if we have seen a ) and we have only seen white_space or comments since
+
+        bool in_initialization_list = false;
+            // true if we have seen a ) followed by any white space or comments and then
+            // followed by a : (in scope==0 with recently_seen_preprocessor==false) and we 
+            // have not yet seen the character { or ;
+
+        bool recently_seen_preprocessor = false;
+            // true if we have seen the #pragma or #if or #define or #elif keyword and 
+            // have not seen an identifier.
+
+
+        bool recently_seen_extern = false;
+            // true if we have seen the extern keyword and haven't yet seen a 
+            // { or ; character.
+
+        unsigned long paren_count = 0; 
+            // this is the number of ( we have seen minus the number of ) we have
+            // seen.
+            
+
+        int type;
+        stack scopes; // a stack to hold old scopes
+        string token, temp;
+        t.get_token(type,token);
+        while (type != tok::END_OF_FILE)
+        {
+            switch (type)
+            {
+            case tok::IDENTIFIER: // ------------------------------------------
+                if ( recently_seen_class_keyword)
+                {
+                    // this might be a class name so check if there is a 
+                    // ; or identifier or * or & coming up.
+                    type = t.peek_type();
+                    temp.clear();
+                    if (type == tok::WHITE_SPACE)
+                    {
+                        t.get_token(type,temp);
+                        if (temp.find_first_of("\n\r") != string::npos)
+                            recently_seen_preprocessor = false;
+                    }
+                    if (t.peek_token() != ";" && t.peek_type() != tok::IDENTIFIER &&
+                        t.peek_token() != "*" && t.peek_token() != "&")
+                    {
+                        // this is the name of a class or struct in a class or
+                        // struct declaration.
+                        out << "" << token << "" << temp;
+                    }
+                    else
+                    {
+                        out << token << temp;
+                    }
+                }
+                else if ( !in_initialization_list &&
+                     !recently_seen_preprocessor &&
+                     scope == 0 &&
+                     paren_count == 0)
+                {
+                    // this might be a function name so check if there is a 
+                    // ( coming up.
+                    type = t.peek_type();
+                    temp.clear();
+                    if (type == tok::WHITE_SPACE)
+                    {
+                        t.get_token(type,temp);
+                        type = t.peek_type();
+                    }
+                    if (type == tok::OTHER && t.peek_token() == "(")
+                    {
+                        // this is a function definition or prototype
+                        out << "" << token << "" << temp;
+                    }
+                    else
+                    {
+                        out << token << temp;
+                    }
+                }
+                else
+                {
+                    out << token;
+                }
+                
+
+
+                recently_seen_class_keyword = false;
+                recently_seen_paren = false;
+                break;
+
+            case tok::KEYWORD: // ---------------------------------------------
+                if (scope == 0 && token == "operator")
+                {
+                    // Doing this is sort of weird since operator is really a keyword
+                    // but I just like how this looks.
+                    out << "" << token << "";
+                }
+                // this isn't a keyword if it is something like #include 
+                else if (!recently_seen_include) 
+                {
+                    // This is a normal keyword
+                    out << "" << token << "";
+                }
+                else
+                {
+                    out << token;
+                }
+
+                if (token == "#include") 
+                {
+                    recently_seen_include = true;
+                }
+                else if (token == "class")
+                {
+                    recently_seen_new_scope = true;
+                    recently_seen_class_keyword = true;
+                }
+                else if (token == "namespace")
+                {
+                    recently_seen_new_scope = true;
+                }
+                else if (token == "struct")
+                {
+                    recently_seen_new_scope = true;
+                    recently_seen_class_keyword = true;
+                }
+                else if (token == "#pragma" || token == "#define" || token == "#elif" || token == "#if")
+                {
+                    recently_seen_preprocessor = true;
+                }
+                else if (token == "extern")
+                {
+                    recently_seen_extern = true;
+                }
+                recently_seen_paren = false;
+                break;
+
+            case tok::COMMENT: // ---------------------------------------------
+                {
+                    out << "" << htmlify(token) << "";
+                }
+                break;
+
+            case tok::SINGLE_QUOTED_TEXT: // ----------------------------------
+                {
+                    out << htmlify(token);
+                    recently_seen_paren = false;
+                }
+                break;
+
+            case tok::WHITE_SPACE: // -----------------------------------------
+                {
+                    out << token;
+                    if (token.find_first_of("\n\r") != string::npos)
+                        recently_seen_preprocessor = false;
+                }
+                break;
+
+            case tok::DOUBLE_QUOTED_TEXT: // ----------------------------------
+                {                    
+                    out << htmlify(token);
+                    recently_seen_paren = false;
+                    recently_seen_include = false;
+                }
+                break;
+
+            case tok::NUMBER:
+            case tok::OTHER: // -----------------------------------------------               
+                switch (token[0])
+                {
+                case '{':
+                    out << "{";  
+                    // if we are entering a new scope
+                    if (recently_seen_new_scope || recently_seen_extern)
+                    {
+                        recently_seen_new_scope = false;
+                        scopes.push(scope);
+                        scope = 0;
+                    }
+                    else
+                    {
+                        ++scope;
+                    }
+                    in_initialization_list = false;
+                    recently_seen_paren = false;
+                    recently_seen_class_keyword = false;
+                    recently_seen_extern = false;
+                    break;
+                case '}':
+                    out << "}";
+                    if (scope > 0)
+                    {
+                        --scope;
+                    }
+                    else if (scopes.size())
+                    {
+                        scopes.pop(scope);
+                    }
+                    recently_seen_paren = false;
+                    break;
+
+                case ':':
+                    out << ':';
+                    if (recently_seen_paren && scope == 0 &&
+                        recently_seen_preprocessor == false)
+                    {
+                        in_initialization_list = true;
+                    }
+                    recently_seen_paren = false;
+                    break;
+
+                case ';': 
+                    out << ';';
+                    recently_seen_new_scope = false;
+                    recently_seen_paren = false;
+                    recently_seen_extern = false;
+                    break;
+
+                case ')':
+                    out << ')';
+                    recently_seen_paren = true;
+                    recently_seen_new_scope = false;
+                    --paren_count;
+                    break;
+
+                case '(':
+                    out << '(';
+                    recently_seen_paren = false;
+                    ++paren_count;
+                    break;
+
+                case '>':
+                    recently_seen_include = false;
+                    out << ">";
+                    recently_seen_paren = true;
+                    break;
+
+                case '<':
+                    out << "<";
+                    recently_seen_paren = true;
+                    break;
+
+                case '&':
+                    out << "&";
+                    recently_seen_paren = true;
+                    break;
+
+                default:
+                    out << token;
+                    recently_seen_paren = false;
+                    if (token == ">")
+                        recently_seen_include = false;
+                    break;
+
+                } // switch (token[0])
+                break;
+
+            } // switch (type)
+
+            t.get_token(type,token);
+        } // while (type != tok::END_OF_FILE)
+
+
+        out << "
"; + if (!out) + throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::print"); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + void cpp_pretty_printer_kernel_2:: + print_and_number ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const + { + using namespace std; + ostringstream sout; + print(in,sout,title); + istringstream sin(sout.str()); + number(sin,out); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + void cpp_pretty_printer_kernel_2:: + number ( + std::istream& in, + std::ostream& out + ) const + { + if (!out) + throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::number"); + + std::string space = "   "; + std::ios::int_type ch; + unsigned long count = 1; + while ((ch=in.get()) != EOF) + { + if (ch != '\n') + { + out << (char)ch; + } + else + { + out << "\n" << count << " " + space; + ++count; + if (count == 10) + space = "  "; + if (count == 100) + space = " "; + if (count == 1000) + space = ""; + } + } + if (!out) + throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::number"); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack, + typename tok + > + const std::string cpp_pretty_printer_kernel_2:: + htmlify ( + const std::string& str + ) const + { + std::string::size_type i; + std::string temp; + for (i = 0; i < str.size(); ++i) + { + if (str[i] == '<') + temp += "<"; + else if (str[i] == '>') + temp += ">"; + else if (str[i] == '&') + temp += "&"; + else + temp += str[i]; + } + return temp; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CPP_PRETTY_PRINTER_KERNEl_2_ + diff --git a/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h new file mode 100644 index 000000000..7d572d4fd --- /dev/null +++ b/ml/dlib/dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h @@ -0,0 +1,88 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CPP_PRETTY_PRINTER_KERNEl_ABSTRACT_ +#ifdef DLIB_CPP_PRETTY_PRINTER_KERNEl_ABSTRACT_ + +#include +#include + +namespace dlib +{ + + class cpp_pretty_printer + { + /*! + INITIAL VALUE + This object does not have any state associated with it. + + WHAT THIS OBJECT REPRESENTS + This object represents an HTML pretty printer for C++ source code. + + !*/ + + public: + + cpp_pretty_printer ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~cpp_pretty_printer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + void print ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const; + /*! + ensures + - treats data from in as C++ source code and pretty prints it in + HTML and writes it to out. + - The title of the HTML document writen to out will be title + throws + - std::ios_base::failure + If there was a problem writing to out then this exception will + be thrown. + - any other exception + This exception may be thrown if there is any other problem. + !*/ + + void print_and_number ( + std::istream& in, + std::ostream& out, + const std::string& title + ) const; + /*! + ensures + - treats data from in as C++ source code and pretty prints it in + HTML with line numbers and writes it to out. + - The title of the HTML document writen to out will be title + throws + - std::ios_base::failure + If there was a problem writing to out then this exception will + be thrown. + - any other exception + This exception may be thrown if there is any other problem. + !*/ + + private: + + // restricted functions + cpp_pretty_printer(const cpp_pretty_printer&); // copy constructor + cpp_pretty_printer& operator=(const cpp_pretty_printer&); // assignment operator + + }; + +} + +#endif // DLIB_CPP_PRETTY_PRINTER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/cpp_tokenizer.h b/ml/dlib/dlib/cpp_tokenizer.h new file mode 100644 index 000000000..676ad7a52 --- /dev/null +++ b/ml/dlib/dlib/cpp_tokenizer.h @@ -0,0 +1,40 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CPP_TOKENIZEr_ +#define DLIB_CPP_TOKENIZEr_ + +#include +#include "cpp_tokenizer/cpp_tokenizer_kernel_1.h" +#include "cpp_tokenizer/cpp_tokenizer_kernel_c.h" +#include "tokenizer.h" +#include "queue.h" +#include "set.h" + +namespace dlib +{ + + class cpp_tokenizer + { + cpp_tokenizer() {} + + + typedef set::kernel_1a set; + typedef queue::kernel_2a queue; + typedef tokenizer::kernel_1a tok; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef cpp_tokenizer_kernel_1 + kernel_1a; + typedef cpp_tokenizer_kernel_c + kernel_1a_c; + + + }; +} + +#endif // DLIB_CPP_TOKENIZEr_ + diff --git a/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h new file mode 100644 index 000000000..8a244faa7 --- /dev/null +++ b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h @@ -0,0 +1,675 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CPP_TOKENIZER_KERNEl_1_ +#define DLIB_CPP_TOKENIZER_KERNEl_1_ + +#include +#include +#include "cpp_tokenizer_kernel_abstract.h" +#include "../algs.h" + +namespace dlib +{ + + namespace cpp_tok_kernel_1_helper + { + struct token_text_pair + { + std::string token; + int type; + }; + + } + + template < + typename tok, + typename queue, + typename set + > + class cpp_tokenizer_kernel_1 + { + /*! + REQUIREMENTS ON tok + tok must be an implementation of tokenizer/tokenizer_kernel_abstract.h + + REQUIREMENTS ON queue + queue must be an implementation of queue/queue_kernel_abstract.h + and must have T==cpp_tok_kernel_1_helper::token_text_pair + + REQUIREMENTS ON set + set must be an implemention of set/set_kernel_abstract.h or + hash_set/hash_set_kernel_abstract.h and must have T==std::string. + + INITIAL VALUE + - keywords == a set of all the C++ keywords + - tokenizer.stream_is_set() == false + - buffer.size() == 0 + - tokenizer.get_identifier_head() == "$_" + tokenizer.lowercase_letters() + + tokenizer.uppercase_letters() + - tokenizer.get_identifier_body() == "$_" + tokenizer.lowercase_letters() + + tokenizer.uppercase_letters() + tokenizer.numbers() + - have_peeked == false + + + CONVENTION + - tokenizer.stream_is_set() == stream_is_set() + - tokenizer.get_stream() == get_stream() + - keywords == a set of all the C++ keywords + + - tokenizer.get_identifier_head() == "$_" + tokenizer.lowercase_letters() + + tokenizer.uppercase_letters() + - tokenizer.get_identifier_body() == "$_" + tokenizer.lowercase_letters() + + tokenizer.uppercase_letters() + tokenizer.numbers() + + - buffer == a queue of tokens. This is where we put tokens + we gathered early due to looking ahead. + + + - if (have_peeked) then + - next_token == the next token to be returned from get_token() + - next_type == the type of token in peek_token + !*/ + + typedef cpp_tok_kernel_1_helper::token_text_pair token_text_pair; + + public: + + enum + { + END_OF_FILE, + KEYWORD, + COMMENT, + SINGLE_QUOTED_TEXT, + DOUBLE_QUOTED_TEXT, + IDENTIFIER, + OTHER, + NUMBER, + WHITE_SPACE + }; + + cpp_tokenizer_kernel_1 ( + ); + + virtual ~cpp_tokenizer_kernel_1 ( + ); + + void clear( + ); + + void set_stream ( + std::istream& in + ); + + bool stream_is_set ( + ) const; + + std::istream& get_stream ( + ) const; + + void get_token ( + int& type, + std::string& token + ); + + int peek_type ( + ) const; + + const std::string& peek_token ( + ) const; + + void swap ( + cpp_tokenizer_kernel_1& item + ); + + private: + + void buffer_token( + int type, + const std::string& token + ) + /*! + ensures + - stores the token and its type into buffer + !*/ + { + token_text_pair temp; + temp.token = token; + temp.type = type; + buffer.enqueue(temp); + } + + void buffer_token( + int type, + char token + ) + /*! + ensures + - stores the token and its type into buffer + !*/ + { + token_text_pair temp; + temp.token = token; + temp.type = type; + buffer.enqueue(temp); + } + + // restricted functions + cpp_tokenizer_kernel_1(const cpp_tokenizer_kernel_1&); // copy constructor + cpp_tokenizer_kernel_1& operator=(const cpp_tokenizer_kernel_1&); // assignment operator + + // data members + set keywords; + queue buffer; + tok tokenizer; + + mutable std::string next_token; + mutable int next_type; + mutable bool have_peeked; + + + }; + + template < + typename tok, + typename queue, + typename set + > + inline void swap ( + cpp_tokenizer_kernel_1& a, + cpp_tokenizer_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + cpp_tokenizer_kernel_1:: + cpp_tokenizer_kernel_1( + ) : + have_peeked(false) + { + // add C++ keywords to keywords + std::string temp; + temp = "#include"; keywords.add(temp); + temp = "__asm"; keywords.add(temp); + temp = "_asm"; keywords.add(temp); + temp = "if"; keywords.add(temp); + temp = "int"; keywords.add(temp); + temp = "else"; keywords.add(temp); + temp = "template"; keywords.add(temp); + temp = "void"; keywords.add(temp); + temp = "false"; keywords.add(temp); + temp = "class"; keywords.add(temp); + temp = "public"; keywords.add(temp); + temp = "while"; keywords.add(temp); + temp = "bool"; keywords.add(temp); + temp = "new"; keywords.add(temp); + temp = "delete"; keywords.add(temp); + temp = "true"; keywords.add(temp); + temp = "typedef"; keywords.add(temp); + temp = "const"; keywords.add(temp); + temp = "virtual"; keywords.add(temp); + temp = "inline"; keywords.add(temp); + temp = "for"; keywords.add(temp); + temp = "break"; keywords.add(temp); + temp = "struct"; keywords.add(temp); + temp = "float"; keywords.add(temp); + temp = "case"; keywords.add(temp); + temp = "enum"; keywords.add(temp); + temp = "this"; keywords.add(temp); + temp = "typeid"; keywords.add(temp); + temp = "double"; keywords.add(temp); + temp = "char"; keywords.add(temp); + temp = "typename"; keywords.add(temp); + temp = "signed"; keywords.add(temp); + temp = "friend"; keywords.add(temp); + temp = "wint_t"; keywords.add(temp); + temp = "default"; keywords.add(temp); + temp = "asm"; keywords.add(temp); + temp = "reinterpret_cast"; keywords.add(temp); + temp = "#define"; keywords.add(temp); + temp = "do"; keywords.add(temp); + temp = "continue"; keywords.add(temp); + temp = "auto"; keywords.add(temp); + temp = "unsigned"; keywords.add(temp); + temp = "size_t"; keywords.add(temp); + temp = "#undef"; keywords.add(temp); + temp = "#pragma"; keywords.add(temp); + temp = "namespace"; keywords.add(temp); + temp = "private"; keywords.add(temp); + temp = "#endif"; keywords.add(temp); + temp = "catch"; keywords.add(temp); + temp = "#else"; keywords.add(temp); + temp = "register"; keywords.add(temp); + temp = "volatile"; keywords.add(temp); + temp = "const_cast"; keywords.add(temp); + temp = "#end"; keywords.add(temp); + temp = "mutable"; keywords.add(temp); + temp = "static_cast"; keywords.add(temp); + temp = "wchar_t"; keywords.add(temp); + temp = "#if"; keywords.add(temp); + temp = "protected"; keywords.add(temp); + temp = "throw"; keywords.add(temp); + temp = "using"; keywords.add(temp); + temp = "dynamic_cast"; keywords.add(temp); + temp = "#ifdef"; keywords.add(temp); + temp = "return"; keywords.add(temp); + temp = "short"; keywords.add(temp); + temp = "#error"; keywords.add(temp); + temp = "#line"; keywords.add(temp); + temp = "explicit"; keywords.add(temp); + temp = "union"; keywords.add(temp); + temp = "#ifndef"; keywords.add(temp); + temp = "try"; keywords.add(temp); + temp = "sizeof"; keywords.add(temp); + temp = "goto"; keywords.add(temp); + temp = "long"; keywords.add(temp); + temp = "#elif"; keywords.add(temp); + temp = "static"; keywords.add(temp); + temp = "operator"; keywords.add(temp); + temp = "switch"; keywords.add(temp); + temp = "extern"; keywords.add(temp); + + + // set the tokenizer's IDENTIFIER token for C++ identifiers + tokenizer.set_identifier_token( + "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters(), + "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() + + tokenizer.numbers() + ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + cpp_tokenizer_kernel_1:: + ~cpp_tokenizer_kernel_1 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + void cpp_tokenizer_kernel_1:: + clear( + ) + { + tokenizer.clear(); + buffer.clear(); + have_peeked = false; + + // set the tokenizer's IDENTIFIER token for C++ identifiers + tokenizer.set_identifier_token( + "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters(), + "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() + + tokenizer.numbers() + ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + void cpp_tokenizer_kernel_1:: + set_stream ( + std::istream& in + ) + { + tokenizer.set_stream(in); + buffer.clear(); + have_peeked = false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + bool cpp_tokenizer_kernel_1:: + stream_is_set ( + ) const + { + return tokenizer.stream_is_set(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + std::istream& cpp_tokenizer_kernel_1:: + get_stream ( + ) const + { + return tokenizer.get_stream(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + void cpp_tokenizer_kernel_1:: + get_token ( + int& type, + std::string& token + ) + { + using namespace std; + + if (!have_peeked) + { + + if (buffer.size() > 0) + { + // just return what is in the buffer + token_text_pair temp; + buffer.dequeue(temp); + type = temp.type; + token = temp.token; + return; + } + + tokenizer.get_token(type,token); + + switch (type) + { + case tok::END_OF_FILE: + { + type = END_OF_FILE; + } break; + + case tok::END_OF_LINE: + case tok::WHITE_SPACE: + { + type = tokenizer.peek_type(); + if (type == tok::END_OF_LINE || type == tok::WHITE_SPACE) + { + std::string temp; + do + { + tokenizer.get_token(type,temp); + token += temp; + type = tokenizer.peek_type(); + }while (type == tok::END_OF_LINE || type == tok::WHITE_SPACE); + } + type = WHITE_SPACE; + + } break; + + case tok::NUMBER: + { + // this could be a hex number such as 0xa33. we should check for this. + if (tokenizer.peek_type() == tok::IDENTIFIER && token == "0" && + (tokenizer.peek_token()[0] == 'x' || tokenizer.peek_token()[0] == 'X')) + { + // this is a hex number so accumulate all the numbers and identifiers that follow + // because they have to be part of the number + std::string temp; + tokenizer.get_token(type,temp); + token = "0" + temp; + + // get the rest of the hex number + while (tokenizer.peek_type() == tok::IDENTIFIER || + tokenizer.peek_type() == tok::NUMBER + ) + { + tokenizer.get_token(type,temp); + token += temp; + } + + } + // or this could be a floating point value or something with an 'e' or 'E' in it. + else if ((tokenizer.peek_type() == tok::CHAR && tokenizer.peek_token()[0] == '.') || + (tokenizer.peek_type() == tok::IDENTIFIER && std::tolower(tokenizer.peek_token()[0]) == 'e')) + { + std::string temp; + tokenizer.get_token(type,temp); + token += temp; + // now get the rest of the floating point value + while (tokenizer.peek_type() == tok::IDENTIFIER || + tokenizer.peek_type() == tok::NUMBER + ) + { + tokenizer.get_token(type,temp); + token += temp; + } + } + type = NUMBER; + + } break; + + case tok::IDENTIFIER: + { + if (keywords.is_member(token)) + { + type = KEYWORD; + } + else + { + type = IDENTIFIER; + } + } break; + + case tok::CHAR: + type = OTHER; + switch (token[0]) + { + case '#': + { + // this might be a preprocessor keyword so we should check the + // next token + if (tokenizer.peek_type() == tok::IDENTIFIER && + keywords.is_member('#'+tokenizer.peek_token())) + { + tokenizer.get_token(type,token); + token = '#' + token; + type = KEYWORD; + } + else + { + token = '#'; + type = OTHER; + } + } + break; + + case '"': + { + string temp; + tokenizer.get_token(type,token); + while (type != tok::END_OF_FILE) + { + // if this is the end of the quoted string + if (type == tok::CHAR && token[0] == '"' && + (temp.size() == 0 || temp[temp.size()-1] != '\\' || + (temp.size() > 1 && temp[temp.size()-2] == '\\') )) + { + buffer_token(DOUBLE_QUOTED_TEXT,temp); + buffer_token(OTHER,"\""); + break; + } + else + { + temp += token; + } + tokenizer.get_token(type,token); + } + + + type = OTHER; + token = '"'; + } break; + + case '\'': + { + string temp; + tokenizer.get_token(type,token); + if (type == tok::CHAR && token[0] == '\\') + { + temp += '\\'; + tokenizer.get_token(type,token); + } + temp += token; + buffer_token(SINGLE_QUOTED_TEXT,temp); + + // The next character should be a ' so take it out and put it in + // the buffer. + tokenizer.get_token(type,token); + buffer_token(OTHER,token); + + type = OTHER; + token = '\''; + } break; + + case '/': + { + // look ahead to see if this is the start of a comment + if (tokenizer.peek_type() == tok::CHAR) + { + if (tokenizer.peek_token()[0] == '/') + { + tokenizer.get_token(type,token); + // this is the start of a line comment + token = "//"; + string temp; + tokenizer.get_token(type,temp); + while (type != tok::END_OF_FILE) + { + // if this is the end of the comment + if (type == tok::END_OF_LINE && + token[token.size()-1] != '\\' ) + { + token += '\n'; + break; + } + else + { + token += temp; + } + tokenizer.get_token(type,temp); + } + type = COMMENT; + + } + else if (tokenizer.peek_token()[0] == '*') + { + tokenizer.get_token(type,token); + // this is the start of a block comment + token = "/*"; + string temp; + tokenizer.get_token(type,temp); + while (type != tok::END_OF_FILE) + { + // if this is the end of the comment + if (type == tok::CHAR && temp[0] == '/' && + token[token.size()-1] == '*') + { + token += '/'; + break; + } + else + { + token += temp; + } + tokenizer.get_token(type,temp); + } + type = COMMENT; + } + } + } break; + + default: + break; + } // switch (token[0]) + } // switch (type) + } + else + { + // if we get this far it means we have peeked so we should + // return the peek data. + type = next_type; + token = next_token; + have_peeked = false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + int cpp_tokenizer_kernel_1:: + peek_type ( + ) const + { + const_cast*>(this)->get_token(next_type,next_token); + have_peeked = true; + return next_type; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + const std::string& cpp_tokenizer_kernel_1:: + peek_token ( + ) const + { + const_cast*>(this)->get_token(next_type,next_token); + have_peeked = true; + return next_token; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tok, + typename queue, + typename set + > + void cpp_tokenizer_kernel_1:: + swap ( + cpp_tokenizer_kernel_1& item + ) + { + tokenizer.swap(item.tokenizer); + buffer.swap(item.buffer); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CPP_TOKENIZER_KERNEl_1_ + diff --git a/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h new file mode 100644 index 000000000..e7ac23284 --- /dev/null +++ b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h @@ -0,0 +1,224 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CPP_TOKENIZER_KERNEl_ABSTRACT_ +#ifdef DLIB_CPP_TOKENIZER_KERNEl_ABSTRACT_ + +#include +#include + +namespace dlib +{ + + class cpp_tokenizer + { + /*! + INITIAL VALUE + stream_is_set() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a simple tokenizer for C++ source code. + + BUFFERING + This object is allowed to buffer data from the input stream. + Thus if you clear it or switch streams (via calling set_stream()) + any buffered data will be lost. + + TOKENS + When picking out tokens the cpp_tokenizer will always extract the + longest token it can. For example, if faced with the string + "AAA" it will consider the three As to be a single IDENTIFIER + token not three smaller IDENTIFIER tokens. + + Also note that no characters in the input stream are discarded. + They will all be returned in the text of some token. + Additionally, each character will never be returned more than once. + This means that if you concatenated all returned tokens it would exactly + reproduce the contents of the input stream. + + The tokens are defined as follows: + + END_OF_FILE + This token represents the end of file. It doesn't have any + actual characters associated with it. + + KEYWORD + This token matches a C++ keyword. (This includes the preprocessor + directives). + + COMMENT + This token matches a C++ comment. + + SINGLE_QUOTED_TEXT + This token matches the text of any single quoted literal. + For example, 'a' would be a match and the text of this token + would be the single character a. + + DOUBLE_QUOTED_TEXT + This token matches the text of any double quoted string. + For example, "C++" would be a match and the text of this token + would be the three character string C++. + + WHITE_SPACE + This is a multi character token. It is defined as a sequence of + one or more spaces, carrage returns, newlines, and tabs. I.e. It + is composed of characters from the following string " \r\n\t". + + IDENTIFIER + This token matches any C++ identifier that isn't matched by any + of the above tokens. (A C++ identifier being a string matching + the regular expression [_$a-zA-Z][_$a-zA-Z0-9]*). + + NUMBER + This token matches any C++ numerical constant. + + OTHER + This matches anything that isn't part of one of the above tokens. + It is always a single character. + !*/ + + public: + + enum + { + END_OF_FILE, + KEYWORD, + COMMENT, + SINGLE_QUOTED_TEXT, + DOUBLE_QUOTED_TEXT, + IDENTIFIER, + OTHER, + NUMBER, + WHITE_SPACE + }; + + cpp_tokenizer ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~cpp_tokenizer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + If this exception is thrown then #*this is unusable + until clear() is called and succeeds. + !*/ + + void set_stream ( + std::istream& in + ); + /*! + ensures + - #*this will read data from in and tokenize it + - #stream_is_set() == true + - #get_stream() == in + !*/ + + bool stream_is_set ( + ) const; + /*! + ensures + - returns true if a stream has been associated with *this by calling + set_stream() + !*/ + + std::istream& get_stream ( + ) const; + /*! + requires + - stream_is_set() == true + ensures + - returns a reference to the istream object that *this is reading + from. + !*/ + + void get_token ( + int& type, + std::string& token + ); + /*! + requires + - stream_is_set() == true + ensures + - #token == the next token from the input stream get_stream() + - #type == the type of the token in #token + throws + - bad_alloc + If this exception is thrown then the call to this function will + have no effect on *this but the values of #type and #token will be + undefined. Additionally, some characters may have been read + from the stream get_stream() and lost. + !*/ + + int peek_type ( + ) const; + /*! + requires + - stream_is_set() == true + ensures + - returns the type of the token that will be returned from + the next call to get_token() + throws + - bad_alloc + If this exception is thrown then the call to this function will + have no effect on *this. However, some characters may have been + read from the stream get_stream() and lost. + !*/ + + const std::string& peek_token ( + ) const; + /*! + requires + - stream_is_set() == true + ensures + - returns the text of the token that will be returned from + the next call to get_token() + throws + - bad_alloc + If this exception is thrown then the call to this function will + have no effect on *this. However, some characters may have been + read from the stream get_stream() and lost. + !*/ + + void swap ( + cpp_tokenizer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + cpp_tokenizer(const cpp_tokenizer&); // copy constructor + cpp_tokenizer& operator=(const cpp_tokenizer&); // assignment operator + + }; + + inline void swap ( + cpp_tokenizer& a, + cpp_tokenizer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_CPP_TOKENIZER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h new file mode 100644 index 000000000..0073a5680 --- /dev/null +++ b/ml/dlib/dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h @@ -0,0 +1,137 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CPP_TOKENIZER_KERNEl_C_ +#define DLIB_CPP_TOKENIZER_KERNEl_C_ + +#include "cpp_tokenizer_kernel_abstract.h" +#include "../assert.h" +#include +#include + +namespace dlib +{ + + template < + typename tokenizer + > + class cpp_tokenizer_kernel_c : public tokenizer + { + + public: + std::istream& get_stream ( + ) const; + + void get_token ( + int& type, + std::string& token + ); + + int peek_type ( + ) const; + + const std::string& peek_token ( + ) const; + + }; + + template < + typename tokenizer + > + inline void swap ( + cpp_tokenizer_kernel_c& a, + cpp_tokenizer_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename tokenizer + > + std::istream& cpp_tokenizer_kernel_c:: + get_stream ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tstd::istream& cpp_tokenizer::get_stream()" + << "\n\tyou must set a stream for this object before you can get it" + << "\n\tthis: " << this + ); + + // call the real function + return tokenizer::get_stream(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tokenizer + > + const std::string& cpp_tokenizer_kernel_c:: + peek_token ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tconst std::string& cpp_tokenizer::peek_token()" + << "\n\tyou must set a stream for this object before you can peek at what it contains" + << "\n\tthis: " << this + ); + + // call the real function + return tokenizer::peek_token(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tokenizer + > + int cpp_tokenizer_kernel_c:: + peek_type ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tint cpp_tokenizer::peek_type()" + << "\n\tyou must set a stream for this object before you can peek at what it contains" + << "\n\tthis: " << this + ); + + // call the real function + return tokenizer::peek_type(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tokenizer + > + void cpp_tokenizer_kernel_c:: + get_token ( + int& type, + std::string& token + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tvoid cpp_tokenizer::get_token()" + << "\n\tyou must set a stream for this object before you can get tokens from it." + << "\n\tthis: " << this + ); + + // call the real function + tokenizer::get_token(type,token); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TOKENIZER_KERNEl_C_ + + diff --git a/ml/dlib/dlib/crc32.h b/ml/dlib/dlib/crc32.h new file mode 100644 index 000000000..004aaeba4 --- /dev/null +++ b/ml/dlib/dlib/crc32.h @@ -0,0 +1,10 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CRc32_ +#define DLIB_CRc32_ + + +#include "crc32/crc32_kernel_1.h" + +#endif // DLIB_CRc32_ + diff --git a/ml/dlib/dlib/crc32/crc32_kernel_1.h b/ml/dlib/dlib/crc32/crc32_kernel_1.h new file mode 100644 index 000000000..4c679d2f2 --- /dev/null +++ b/ml/dlib/dlib/crc32/crc32_kernel_1.h @@ -0,0 +1,262 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CRC32_KERNEl_1_ +#define DLIB_CRC32_KERNEl_1_ + +#include "../algs.h" +#include +#include +#include "crc32_kernel_abstract.h" + +namespace dlib +{ + + class crc32 + { + /*! + INITIAL VALUE + checksum == 0xFFFFFFFF + + CONVENTION + get_checksum() == checksum ^ 0xFFFFFFFF + !*/ + + public: + + // this is here for backwards compatibility with older versions of dlib. + typedef crc32 kernel_1a; + + inline crc32 ( + ); + + inline crc32 ( + const std::string& item + ); + + inline crc32 ( + const std::vector& item + ); + + inline virtual ~crc32 ( + ); + + inline void clear( + ); + + inline void add ( + unsigned char item + ); + + inline void add ( + const std::string& item + ); + + inline void add ( + const std::vector& item + ); + + inline operator unsigned long ( + ) const { return get_checksum(); } + + inline unsigned long get_checksum ( + ) const; + + inline void swap ( + crc32& item + ); + + inline crc32& operator=( + const crc32& + ); + + private: + + unsigned long checksum; + + inline unsigned long table ( + unsigned int idx + ) const + { + /* + // This code generates the crc_table used below. + unsigned long crc_table[256]; + for (unsigned long i = 0; i < 256; ++i) + { + unsigned long temp = i; + for (unsigned long j = 0; j < 8; ++j) + { + if (temp&1) + temp = (temp>>1)^0xedb88320; + else + temp >>= 1; + } + crc_table[i] = temp; + std::cout << std::hex << crc_table[i] << std::endl; + } + */ + + const static unsigned long crc_table[256] = { + 0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x76dc419, 0x706af48f, 0xe963a535, 0x9e6495a3, + 0xedb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x9b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, + 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7, + 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5, + 0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, + 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, + 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f, + 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, + 0x76dc4190, 0x1db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x6b6b51f, 0x9fbfe4a5, 0xe8b8d433, + 0x7807c9a2, 0xf00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x86d3d2d, 0x91646c97, 0xe6635c01, + 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457, + 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, + 0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb, + 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9, + 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, + 0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad, + 0xedb88320, 0x9abfb3b6, 0x3b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x4db2615, 0x73dc1683, + 0xe3630b12, 0x94643b84, 0xd6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0xa00ae27, 0x7d079eb1, + 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7, + 0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, + 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b, + 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79, + 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, + 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, + 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x26d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x5005713, + 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0xcb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0xbdbdf21, + 0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, + 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45, + 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db, + 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, + 0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf, + 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d + }; + + return crc_table[idx]; + } + + }; + + inline void swap ( + crc32& a, + crc32& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + crc32:: + crc32 ( + ) + { + checksum = 0xFFFFFFFF; + } + +// ---------------------------------------------------------------------------------------- + + crc32:: + crc32 ( + const std::string& item + ) + { + checksum = 0xFFFFFFFF; + add(item); + } + +// ---------------------------------------------------------------------------------------- + + crc32:: + crc32 ( + const std::vector& item + ) + { + checksum = 0xFFFFFFFF; + add(item); + } + +// ---------------------------------------------------------------------------------------- + + crc32:: + ~crc32 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + void crc32:: + clear( + ) + { + checksum = 0xFFFFFFFF; + } + +// ---------------------------------------------------------------------------------------- + + void crc32:: + add ( + unsigned char item + ) + { + checksum = (checksum>>8) ^ table((checksum^item) & 0xFF); + } + +// ---------------------------------------------------------------------------------------- + + void crc32:: + add ( + const std::string& item + ) + { + for (std::string::size_type i = 0; i < item.size(); ++i) + checksum = (checksum>>8) ^ table((checksum^item[i]) & 0xFF); + } + +// ---------------------------------------------------------------------------------------- + + void crc32:: + add ( + const std::vector& item + ) + { + for (unsigned long i = 0; i < item.size(); ++i) + checksum = (checksum>>8) ^ table((checksum^item[i]) & 0xFF); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long crc32:: + get_checksum ( + ) const + { + return checksum ^ 0xFFFFFFFF; + } + +// ---------------------------------------------------------------------------------------- + + void crc32:: + swap ( + crc32& item + ) + { + exchange(checksum,item.checksum); + } + +// ---------------------------------------------------------------------------------------- + + crc32& crc32:: + operator=( + const crc32& item + ) + { + checksum = item.checksum; + return *this; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CRC32_KERNEl_1_ + diff --git a/ml/dlib/dlib/crc32/crc32_kernel_abstract.h b/ml/dlib/dlib/crc32/crc32_kernel_abstract.h new file mode 100644 index 000000000..76da49fbc --- /dev/null +++ b/ml/dlib/dlib/crc32/crc32_kernel_abstract.h @@ -0,0 +1,132 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CRC32_KERNEl_ABSTRACT_ +#ifdef DLIB_CRC32_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include +#include + +namespace dlib +{ + + class crc32 + { + /*! + INITIAL VALUE + The current checksum covers zero bytes. + get_checksum() == 0x00000000 + + WHAT THIS OBJECT REPRESENTS + This object represents the CRC32 algorithm for calculating + checksums. + !*/ + + public: + + crc32 ( + ); + /*! + ensures + - #*this is properly initialized + !*/ + + crc32 ( + const std::string& item + ); + /*! + ensures + - #*this is properly initialized + - calls this->add(item). + (i.e. Using this constructor is the same as using the default + constructor and then calling add() on item) + !*/ + + crc32 ( + const std::vector& item + ); + /*! + ensures + - #*this is properly initialized + - calls this->add(item). + (i.e. Using this constructor is the same as using the default + constructor and then calling add() on item) + !*/ + + virtual ~crc32 ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + !*/ + + void add ( + unsigned char item + ); + /*! + ensures + - #get_checksum() == The checksum of all items added to *this previously + concatenated with item. + !*/ + + void add ( + const std::string& item + ); + /*! + ensures + - #get_checksum() == The checksum of all items added to *this previously + concatenated with item. + !*/ + + void add ( + const std::vector& item + ); + /*! + ensures + - #get_checksum() == The checksum of all items added to *this previously + concatenated with item. + !*/ + + unsigned long get_checksum ( + ) const; + /*! + ensures + - returns the current checksum + !*/ + + operator unsigned long ( + ) const; + /*! + ensures + - returns get_checksum() + !*/ + + void swap ( + crc32& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + void swap ( + crc32& a, + crc32& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_CRC32_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/cstring b/ml/dlib/dlib/cstring new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/cstring @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/data_io.h b/ml/dlib/dlib/data_io.h new file mode 100644 index 000000000..845e95f40 --- /dev/null +++ b/ml/dlib/dlib/data_io.h @@ -0,0 +1,18 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DATA_Io_HEADER +#define DLIB_DATA_Io_HEADER + +#include "data_io/libsvm_io.h" +#include "data_io/image_dataset_metadata.h" +#include "data_io/mnist.h" + +#ifndef DLIB_ISO_CPP_ONLY +#include "data_io/load_image_dataset.h" +#endif + +#endif // DLIB_DATA_Io_HEADER + + + + diff --git a/ml/dlib/dlib/data_io/image_dataset_metadata.cpp b/ml/dlib/dlib/data_io/image_dataset_metadata.cpp new file mode 100644 index 000000000..390ef6a0a --- /dev/null +++ b/ml/dlib/dlib/data_io/image_dataset_metadata.cpp @@ -0,0 +1,411 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMAGE_DAtASET_METADATA_CPPh_ +#define DLIB_IMAGE_DAtASET_METADATA_CPPh_ + +#include "image_dataset_metadata.h" + +#include +#include +#include "../compress_stream.h" +#include "../base64.h" +#include "../xml_parser.h" +#include "../string.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + namespace image_dataset_metadata + { + + // ------------------------------------------------------------------------------------ + + const std::string get_decoded_string(); + void create_image_metadata_stylesheet_file(const std::string& main_filename) + { + std::string path; + std::string::size_type pos = main_filename.find_last_of("/\\"); + if (pos != std::string::npos) + path = main_filename.substr(0,pos+1); + + std::ofstream fout((path + "image_metadata_stylesheet.xsl").c_str()); + if (!fout) + throw dlib::error("ERROR: Unable to open image_metadata_stylesheet.xsl for writing."); + + fout << get_decoded_string(); + + if (!fout) + throw dlib::error("ERROR: Unable to write to image_metadata_stylesheet.xsl."); + } + + void save_image_dataset_metadata ( + const dataset& meta, + const std::string& filename + ) + { + create_image_metadata_stylesheet_file(filename); + + const std::vector& images = meta.images; + + std::ofstream fout(filename.c_str()); + if (!fout) + throw dlib::error("ERROR: Unable to open " + filename + " for writing."); + + fout << "\n"; + fout << "\n"; + fout << "\n"; + fout << "" << meta.name << "\n"; + fout << "" << meta.comment << "\n"; + fout << "\n"; + for (unsigned long i = 0; i < images.size(); ++i) + { + fout << " \n"; + + // save all the boxes + for (unsigned long j = 0; j < images[i].boxes.size(); ++j) + { + const box& b = images[i].boxes[j]; + fout << " \n"; + + if (b.has_label()) + fout << " \n"; + + // save all the parts + std::map::const_iterator itr; + for (itr = b.parts.begin(); itr != b.parts.end(); ++itr) + { + fout << " \n"; + } + + fout << " \n"; + } + else + { + fout << "/>\n"; + } + } + + + + fout << " \n"; + + if (!fout) + throw dlib::error("ERROR: Unable to write to " + filename + "."); + } + fout << "\n"; + fout << ""; + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + class doc_handler : public document_handler + { + std::vector ts; + image temp_image; + box temp_box; + + dataset& meta; + + public: + + doc_handler( + dataset& metadata_ + ): + meta(metadata_) + {} + + + virtual void start_document ( + ) + { + meta = dataset(); + ts.clear(); + temp_image = image(); + temp_box = box(); + } + + virtual void end_document ( + ) + { + } + + virtual void start_element ( + const unsigned long line_number, + const std::string& name, + const dlib::attribute_list& atts + ) + { + try + { + if (ts.size() == 0) + { + if (name != "dataset") + { + std::ostringstream sout; + sout << "Invalid XML document. Root tag must be . Found <" << name << "> instead."; + throw dlib::error(sout.str()); + } + else + { + ts.push_back(name); + return; + } + } + + + if (name == "box") + { + if (atts.is_in_list("top")) temp_box.rect.top() = sa = atts["top"]; + else throw dlib::error(" missing required attribute 'top'"); + + if (atts.is_in_list("left")) temp_box.rect.left() = sa = atts["left"]; + else throw dlib::error(" missing required attribute 'left'"); + + if (atts.is_in_list("width")) temp_box.rect.right() = sa = atts["width"]; + else throw dlib::error(" missing required attribute 'width'"); + + if (atts.is_in_list("height")) temp_box.rect.bottom() = sa = atts["height"]; + else throw dlib::error(" missing required attribute 'height'"); + + if (atts.is_in_list("difficult")) temp_box.difficult = sa = atts["difficult"]; + if (atts.is_in_list("truncated")) temp_box.truncated = sa = atts["truncated"]; + if (atts.is_in_list("occluded")) temp_box.occluded = sa = atts["occluded"]; + if (atts.is_in_list("ignore")) temp_box.ignore = sa = atts["ignore"]; + if (atts.is_in_list("angle")) temp_box.angle = sa = atts["angle"]; + if (atts.is_in_list("age")) temp_box.age = sa = atts["age"]; + if (atts.is_in_list("gender")) + { + if (atts["gender"] == "male") + temp_box.gender = MALE; + else if (atts["gender"] == "female") + temp_box.gender = FEMALE; + else if (atts["gender"] == "unknown") + temp_box.gender = UNKNOWN; + else + throw dlib::error("Invalid gender string in box attribute."); + } + if (atts.is_in_list("pose")) temp_box.pose = sa = atts["pose"]; + if (atts.is_in_list("detection_score")) temp_box.detection_score = sa = atts["detection_score"]; + + temp_box.rect.bottom() += temp_box.rect.top()-1; + temp_box.rect.right() += temp_box.rect.left()-1; + } + else if (name == "part" && ts.back() == "box") + { + point temp; + if (atts.is_in_list("x")) temp.x() = sa = atts["x"]; + else throw dlib::error(" missing required attribute 'x'"); + + if (atts.is_in_list("y")) temp.y() = sa = atts["y"]; + else throw dlib::error(" missing required attribute 'y'"); + + if (atts.is_in_list("name")) + { + if (temp_box.parts.count(atts["name"])==0) + { + temp_box.parts[atts["name"]] = temp; + } + else + { + throw dlib::error(" with name '" + atts["name"] + "' is defined more than one time in a single box."); + } + } + else + { + throw dlib::error(" missing required attribute 'name'"); + } + } + else if (name == "image") + { + temp_image.boxes.clear(); + + if (atts.is_in_list("file")) temp_image.filename = atts["file"]; + else throw dlib::error(" missing required attribute 'file'"); + } + + ts.push_back(name); + } + catch (error& e) + { + throw dlib::error("Error on line " + cast_to_string(line_number) + ": " + e.what()); + } + } + + virtual void end_element ( + const unsigned long , + const std::string& name + ) + { + ts.pop_back(); + if (ts.size() == 0) + return; + + if (name == "box" && ts.back() == "image") + { + temp_image.boxes.push_back(temp_box); + temp_box = box(); + } + else if (name == "image" && ts.back() == "images") + { + meta.images.push_back(temp_image); + temp_image = image(); + } + } + + virtual void characters ( + const std::string& data + ) + { + if (ts.size() == 2 && ts[1] == "name") + { + meta.name = trim(data); + } + else if (ts.size() == 2 && ts[1] == "comment") + { + meta.comment = trim(data); + } + else if (ts.size() >= 2 && ts[ts.size()-1] == "label" && + ts[ts.size()-2] == "box") + { + temp_box.label = trim(data); + } + } + + virtual void processing_instruction ( + const unsigned long , + const std::string& , + const std::string& + ) + { + } + }; + + // ---------------------------------------------------------------------------------------- + + class xml_error_handler : public error_handler + { + public: + virtual void error ( + const unsigned long + ) { } + + virtual void fatal_error ( + const unsigned long line_number + ) + { + std::ostringstream sout; + sout << "There is a fatal error on line " << line_number << " so parsing will now halt."; + throw dlib::error(sout.str()); + } + }; + + // ------------------------------------------------------------------------------------ + + void load_image_dataset_metadata ( + dataset& meta, + const std::string& filename + ) + { + xml_error_handler eh; + doc_handler dh(meta); + + std::ifstream fin(filename.c_str()); + if (!fin) + throw dlib::error("ERROR: unable to open " + filename + " for reading."); + + xml_parser parser; + parser.add_document_handler(dh); + parser.add_error_handler(eh); + parser.parse(fin); + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + // This function returns the contents of the file 'images.xsl' + const std::string get_decoded_string() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'image_metadata_stylesheet.xsl' we want to decode and return. + sout << "PFWfgmWfCHr1DkV63lbjjeY2dCc2FbHDOVh0Kd7dkvaOfRYrOG24f0x77/5iMVq8FtE3UBxtGwSd"; + sout << "1ZHOHRSHgieNoeBv8ssJQ75RRxYtFKRY3OTPX5eKQoCN9jUaUnHnR4QZtEHgmKqXSs50Yrdd+2Ah"; + sout << "gNyarPZCiR6nvqNvCjtP2MP5FxleqNf8Fylatm2KdsXmrv5K87LYVN7i7JMkmZ++cTXYSOxDmxZi"; + sout << "OiCH8funXUdF9apDW547gCjz9HOQUI6dkz5dYUeFjfp6dFugpnaJyyprFLKq048Qk7+QiL4CNF/G"; + sout << "7e0VpBw8dMpiyRNi2fSQGSZGfIAUQKKT6+rPwQoRH2spdjsdXVWj4XQAqBX87nmqMnqjMhn/Vd1s"; + sout << "W5aoC0drwRGu3Xe3gn9vBL8hBkRXcJvEy6q/lb9bYnsLemhE5Zp/+nTmTBjfT9UFYLcsmgsjC+4n"; + sout << "Bq6h9QlpuyMYqJ8RvW8pp3mFlvXc3Yg+18t5F0hSMQfaIFYAuDPU2lVzPpY+ba0B39iu9IrPCLsS"; + sout << "+tUtSNSmQ74CtzZgKKjkTMA3nwYP2SDmZE3firq42pihT7hdU5vYkes69K8AQl8WZyLPpMww+r0z"; + sout << "+veEHPlAuxF7kL3ZvVjdB+xABwwqDe0kSRHRZINYdUfJwJdfYLyDnYoMjj6afqIJZ7QOBPZ42tV5"; + sout << "3hYOQTFwTNovOastzJJXQe1kxPg1AQ8ynmfjjJZqD0xKedlyeJybP919mVAA23UryHsq9TVlabou"; + sout << "qNl3xZW/mKKktvVsd/nuH62HIv/kgomyhaEUY5HgupupBUbQFZfyljZ5bl3g3V3Y1400Z1xTM/LL"; + sout << "LJpeLdlqoGzIe/19vAN1zUUVId9F/OLNUl3Zoar63yZERSJHcsuq/Pasisp0HIGi7rfI9EIQF7C/"; + sout << "IhLKLZsJ+LOycreQGOJALZIEZHOqxYLSXG0qaPM5bQL/MQJ2OZfwEhQgYOrjaM7oPOHHEfTq5kcO"; + sout << "daMwzefKfxrF2GXbUs0bYsEXsIGwENIUKMliFaAI4qKLxxb94oc+O3BRjWueZjZty2zKawQyTHNd"; + sout << "ltFJBUzfffdZN9Wq4zbPzntkM3U6Ys4LRztx5M15dtbhFeKx5rAf2tPXT6wU01hx7EJxBJzpvoDE"; + sout << "YwEoYVDSYulRKpgk82cHFzzUDgWXbl4paFSe1L1w8r9KHr67SYJDTUG86Lrm6LJ0rw73Xp0NAFcU"; + sout << "MKpiG9g1cHW74HYbUb/yAbtVWt40eB7M637umdo2jWz/r/vP5WnfSMXEbkyWebsa1fFceg/TLWy6"; + sout << "E8OTc4XKB48h1oFIlGagOiprxho3+F3TIcxDSwA="; + + + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + + } +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_IMAGE_DAtASET_METADATA_CPPh_ + + diff --git a/ml/dlib/dlib/data_io/image_dataset_metadata.h b/ml/dlib/dlib/data_io/image_dataset_metadata.h new file mode 100644 index 000000000..3dac29ba6 --- /dev/null +++ b/ml/dlib/dlib/data_io/image_dataset_metadata.h @@ -0,0 +1,174 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMAGE_DAtASET_METADATA_Hh_ +#define DLIB_IMAGE_DAtASET_METADATA_Hh_ + +#include +#include +#include "../geometry.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + namespace image_dataset_metadata + { + + // ------------------------------------------------------------------------------------ + + enum gender_t + { + UNKNOWN, + MALE, + FEMALE + }; + + // ------------------------------------------------------------------------------------ + + struct box + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an annotated rectangular area of an image. + It is typically used to mark the location of an object such as a + person, car, etc. + + The main variable of interest is rect. It gives the location of + the box. All the other variables are optional. + !*/ + + box( + ) : + difficult(false), + truncated(false), + occluded(false), + ignore(false), + pose(0), + detection_score(0), + angle(0), + gender(UNKNOWN), + age(0) + {} + + box ( + const rectangle& rect_ + ) : + rect(rect_), + difficult(false), + truncated(false), + occluded(false), + ignore(false), + pose(0), + detection_score(0), + angle(0), + gender(UNKNOWN), + age(0) + {} + + rectangle rect; + + std::map parts; + + // optional fields + std::string label; + bool difficult; + bool truncated; + bool occluded; + bool ignore; + double pose; + double detection_score; + + // The angle of the object in radians. Positive values indicate that the + // object at the center of the box is rotated clockwise by angle radians. A + // value of 0 would indicate that the object is in its "standard" upright pose. + // Therefore, to make the object appear upright we would have to rotate the + // image counter-clockwise by angle radians. + double angle; + + gender_t gender; + double age; + + bool has_label() const { return label.size() != 0; } + /*! + ensures + - returns true if label metadata is present and false otherwise. + !*/ + }; + + // ------------------------------------------------------------------------------------ + + struct image + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an annotated image. + !*/ + + image() {} + image(const std::string& f) : filename(f) {} + + std::string filename; + std::vector boxes; + }; + + // ------------------------------------------------------------------------------------ + + struct dataset + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a labeled set of images. In particular, it + contains the filename for each image as well as annotated boxes. + !*/ + + std::vector images; + std::string comment; + std::string name; + }; + + // ------------------------------------------------------------------------------------ + + void save_image_dataset_metadata ( + const dataset& meta, + const std::string& filename + ); + /*! + ensures + - Writes the contents of the meta object to a file with the given + filename. The file will be in an XML format. + throws + - dlib::error + This exception is thrown if there is an error which prevents + this function from succeeding. + !*/ + + // ------------------------------------------------------------------------------------ + + void load_image_dataset_metadata ( + dataset& meta, + const std::string& filename + ); + /*! + ensures + - Attempts to interpret filename as a file containing XML formatted data + as produced by the save_image_dataset_metadata() function. Then + meta is loaded with the contents of the file. + throws + - dlib::error + This exception is thrown if there is an error which prevents + this function from succeeding. + !*/ + + // ------------------------------------------------------------------------------------ + + } +} + +// ---------------------------------------------------------------------------------------- + +#ifdef NO_MAKEFILE +#include "image_dataset_metadata.cpp" +#endif + +#endif // DLIB_IMAGE_DAtASET_METADATA_Hh_ + diff --git a/ml/dlib/dlib/data_io/libsvm_io.h b/ml/dlib/dlib/data_io/libsvm_io.h new file mode 100644 index 000000000..f365e82d7 --- /dev/null +++ b/ml/dlib/dlib/data_io/libsvm_io.h @@ -0,0 +1,276 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LIBSVM_iO_Hh_ +#define DLIB_LIBSVM_iO_Hh_ + +#include "libsvm_io_abstract.h" + +#include +#include +#include +#include "../algs.h" +#include "../matrix.h" +#include "../string.h" +#include "../svm/sparse_vector.h" +#include + +namespace dlib +{ + struct sample_data_io_error : public error + { + sample_data_io_error(const std::string& message): error(message) {} + }; + +// ---------------------------------------------------------------------------------------- + + template + void load_libsvm_formatted_data ( + const std::string& file_name, + std::vector& samples, + std::vector& labels + ) + { + using namespace std; + typedef typename sample_type::value_type pair_type; + typedef typename basic_type::type key_type; + typedef typename pair_type::second_type value_type; + + // You must use unsigned integral key types in your sparse vectors + COMPILE_TIME_ASSERT(is_unsigned_type::value); + + samples.clear(); + labels.clear(); + + ifstream fin(file_name.c_str()); + + if (!fin) + throw sample_data_io_error("Unable to open file " + file_name); + + string line; + istringstream sin; + key_type key; + value_type value; + label_type label; + sample_type sample; + long line_num = 0; + while (fin.peek() != EOF) + { + ++line_num; + getline(fin, line); + + string::size_type pos = line.find_first_not_of(" \t\r\n"); + + // ignore empty lines or comment lines + if (pos == string::npos || line[pos] == '#') + continue; + + sin.clear(); + sin.str(line); + sample.clear(); + + sin >> label; + + if (!sin) + throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name ); + + // eat whitespace + sin >> ws; + + while (sin.peek() != EOF && sin.peek() != '#') + { + + sin >> key >> ws; + + // ignore what should be a : character + if (sin.get() != ':') + throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name); + + sin >> value; + + if (sin && value != 0) + { + sample.insert(sample.end(), make_pair(key, value)); + } + + sin >> ws; + } + + samples.push_back(sample); + labels.push_back(label); + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + typename enable_if >::type + fix_nonzero_indexing ( + std::vector& samples + ) + { + typedef typename sample_type::value_type pair_type; + typedef typename basic_type::type key_type; + + if (samples.size() == 0) + return; + + // figure out the min index value + key_type min_idx = samples[0].begin()->first; + for (unsigned long i = 0; i < samples.size(); ++i) + min_idx = std::min(min_idx, samples[i].begin()->first); + + // Now adjust all the samples so that their min index value is zero. + if (min_idx != 0) + { + sample_type temp; + for (unsigned long i = 0; i < samples.size(); ++i) + { + // copy samples[i] into temp but make sure it has a min index of zero. + temp.clear(); + typename sample_type::iterator j; + for (j = samples[i].begin(); j != samples[i].end(); ++j) + { + temp.insert(temp.end(), std::make_pair(j->first-min_idx, j->second)); + } + + // replace the current sample with temp. + samples[i].swap(temp); + } + } + } + +// ---------------------------------------------------------------------------------------- + +// If the "first" values in the std::pair objects are not const then we can modify them +// directly and that is what this version of fix_nonzero_indexing() does. + template + typename disable_if >::type + fix_nonzero_indexing ( + std::vector& samples + ) + { + typedef typename sample_type::value_type pair_type; + typedef typename basic_type::type key_type; + + if (samples.size() == 0) + return; + + // figure out the min index value + key_type min_idx = samples[0].begin()->first; + for (unsigned long i = 0; i < samples.size(); ++i) + min_idx = std::min(min_idx, samples[i].begin()->first); + + // Now adjust all the samples so that their min index value is zero. + if (min_idx != 0) + { + for (unsigned long i = 0; i < samples.size(); ++i) + { + typename sample_type::iterator j; + for (j = samples[i].begin(); j != samples[i].end(); ++j) + { + j->first -= min_idx; + } + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// This is an overload for sparse vectors + template + typename disable_if,void>::type save_libsvm_formatted_data ( + const std::string& file_name, + const std::vector& samples, + const std::vector& labels + ) + { + typedef typename sample_type::value_type pair_type; + typedef typename basic_type::type key_type; + + // You must use unsigned integral key types in your sparse vectors + COMPILE_TIME_ASSERT(is_unsigned_type::value); + + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() == labels.size(), + "\t void save_libsvm_formatted_data()" + << "\n\t You have to have labels for each sample and vice versa" + << "\n\t samples.size(): " << samples.size() + << "\n\t labels.size(): " << labels.size() + ); + + + using namespace std; + ofstream fout(file_name.c_str()); + fout.precision(14); + + if (!fout) + throw sample_data_io_error("Unable to open file " + file_name); + + for (unsigned long i = 0; i < samples.size(); ++i) + { + fout << labels[i]; + + for (typename sample_type::const_iterator j = samples[i].begin(); j != samples[i].end(); ++j) + { + if (j->second != 0) + fout << " " << j->first << ":" << j->second; + } + fout << "\n"; + + if (!fout) + throw sample_data_io_error("Error while writing to file " + file_name); + } + + } + +// ---------------------------------------------------------------------------------------- + +// This is an overload for dense vectors + template + typename enable_if,void>::type save_libsvm_formatted_data ( + const std::string& file_name, + const std::vector& samples, + const std::vector& labels + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() == labels.size(), + "\t void save_libsvm_formatted_data()" + << "\n\t You have to have labels for each sample and vice versa" + << "\n\t samples.size(): " << samples.size() + << "\n\t labels.size(): " << labels.size() + ); + + using namespace std; + ofstream fout(file_name.c_str()); + fout.precision(14); + + if (!fout) + throw sample_data_io_error("Unable to open file " + file_name); + + for (unsigned long i = 0; i < samples.size(); ++i) + { + fout << labels[i]; + + for (long j = 0; j < samples[i].size(); ++j) + { + if (samples[i](j) != 0) + fout << " " << j << ":" << samples[i](j); + } + fout << "\n"; + + if (!fout) + throw sample_data_io_error("Error while writing to file " + file_name); + } + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LIBSVM_iO_Hh_ + diff --git a/ml/dlib/dlib/data_io/libsvm_io_abstract.h b/ml/dlib/dlib/data_io/libsvm_io_abstract.h new file mode 100644 index 000000000..88d934fdb --- /dev/null +++ b/ml/dlib/dlib/data_io/libsvm_io_abstract.h @@ -0,0 +1,125 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LIBSVM_iO_ABSTRACT_Hh_ +#ifdef DLIB_LIBSVM_iO_ABSTRACT_Hh_ + +#include +#include +#include +#include "../algs.h" +#include "../matrix.h" +#include + +namespace dlib +{ + struct sample_data_io_error : public error + { + /*! + This is the exception class used by the file IO functions defined below. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename label_type, + typename alloc1, + typename alloc2 + > + void load_libsvm_formatted_data ( + const std::string& file_name, + std::vector& samples, + std::vector& labels + ); + /*! + requires + - sample_type must be an STL container + - sample_type::value_type == std::pair where T is some kind of + unsigned integral type + ensures + - attempts to read a file of the given name that should contain libsvm + formatted data. We turn the data into sparse vectors and store it + in samples + - #labels.size() == #samples.size() + - for all valid i: #labels[i] is the label for #samples[i] + throws + - sample_data_io_error + This exception is thrown if there is any problem loading data from file + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename label_type, + typename alloc1, + typename alloc2 + > + void save_libsvm_formatted_data ( + const std::string& file_name, + const std::vector& samples, + const std::vector& labels + ); + /*! + requires + - sample_type must be an STL container + - sample_type::value_type == std::pair where T is some kind of + unsigned integral type + - samples.size() == labels.size() + ensures + - saves the data to the given file in libsvm format + throws + - sample_data_io_error + This exception is thrown if there is any problem saving data to file + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename label_type, + typename alloc1, + typename alloc2 + > + void save_libsvm_formatted_data ( + const std::string& file_name, + const std::vector& samples, + const std::vector& labels + ); + /*! + requires + - sample_type == a dense matrix (i.e. dlib::matrix) + - for all valid i: is_vector(samples[i]) == true + - samples.size() == labels.size() + ensures + - saves the data to the given file in libsvm format + throws + - sample_data_io_error + This exception is thrown if there is any problem saving data to file + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void fix_nonzero_indexing ( + std::vector& samples + ); + /*! + requires + - samples must only contain valid sparse vectors. The definition of + a sparse vector can be found at the top of dlib/svm/sparse_vector_abstract.h + ensures + - Adjusts the sparse vectors in samples so that they are zero-indexed. + Or in other words, assume the smallest used index value in any of the sparse + vectors is N. Then this function subtracts N from all the index values in + samples. This is useful, for example, if you load a libsvm formatted datafile + with features indexed from 1 rather than 0 and you would like to fix this. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LIBSVM_iO_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/data_io/load_image_dataset.h b/ml/dlib/dlib/data_io/load_image_dataset.h new file mode 100644 index 000000000..5664d96b2 --- /dev/null +++ b/ml/dlib/dlib/data_io/load_image_dataset.h @@ -0,0 +1,510 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LOAD_IMAGE_DaTASET_Hh_ +#define DLIB_LOAD_IMAGE_DaTASET_Hh_ + +#include "load_image_dataset_abstract.h" +#include "../misc_api.h" +#include "../dir_nav.h" +#include "../image_io.h" +#include "../array.h" +#include +#include "../geometry.h" +#include "image_dataset_metadata.h" +#include +#include +#include "../image_processing/full_object_detection.h" +#include +#include +#include "../image_transforms/image_pyramid.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class image_dataset_file + { + public: + image_dataset_file(const std::string& filename) + { + _skip_empty_images = false; + _have_parts = false; + _filename = filename; + _box_area_thresh = std::numeric_limits::infinity(); + } + + image_dataset_file boxes_match_label( + const std::string& label + ) const + { + image_dataset_file temp(*this); + temp._labels.insert(label); + return temp; + } + + image_dataset_file skip_empty_images( + ) const + { + image_dataset_file temp(*this); + temp._skip_empty_images = true; + return temp; + } + + image_dataset_file boxes_have_parts( + ) const + { + image_dataset_file temp(*this); + temp._have_parts = true; + return temp; + } + + image_dataset_file shrink_big_images( + double new_box_area_thresh = 150*150 + ) const + { + image_dataset_file temp(*this); + temp._box_area_thresh = new_box_area_thresh; + return temp; + } + + bool should_load_box ( + const image_dataset_metadata::box& box + ) const + { + if (_have_parts && box.parts.size() == 0) + return false; + if (_labels.size() == 0) + return true; + if (_labels.count(box.label) != 0) + return true; + return false; + } + + const std::string& get_filename() const { return _filename; } + bool should_skip_empty_images() const { return _skip_empty_images; } + bool should_boxes_have_parts() const { return _have_parts; } + double box_area_thresh() const { return _box_area_thresh; } + const std::set& get_selected_box_labels() const { return _labels; } + + private: + std::string _filename; + std::set _labels; + bool _skip_empty_images; + bool _have_parts; + double _box_area_thresh; + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const image_dataset_file& source + ) + { + images.clear(); + object_locations.clear(); + + std::vector > ignored_rects; + + using namespace dlib::image_dataset_metadata; + dataset data; + load_image_dataset_metadata(data, source.get_filename()); + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. + locally_change_current_dir chdir(get_parent_directory(file(source.get_filename()))); + + + typedef typename array_type::value_type image_type; + + + image_type img; + std::vector rects, ignored; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + double min_rect_size = std::numeric_limits::infinity(); + rects.clear(); + ignored.clear(); + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (source.should_load_box(data.images[i].boxes[j])) + { + if (data.images[i].boxes[j].ignore) + { + ignored.push_back(data.images[i].boxes[j].rect); + } + else + { + rects.push_back(data.images[i].boxes[j].rect); + min_rect_size = std::min(min_rect_size, rects.back().area()); + } + } + } + + if (!source.should_skip_empty_images() || rects.size() != 0) + { + load_image(img, data.images[i].filename); + if (rects.size() != 0) + { + // if shrinking the image would still result in the smallest box being + // bigger than the box area threshold then shrink the image. + while(min_rect_size/2/2 > source.box_area_thresh()) + { + pyramid_down<2> pyr; + pyr(img); + min_rect_size *= (1.0/2.0)*(1.0/2.0); + for (auto&& r : rects) + r = pyr.rect_down(r); + for (auto&& r : ignored) + r = pyr.rect_down(r); + } + while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh()) + { + pyramid_down<3> pyr; + pyr(img); + min_rect_size *= (2.0/3.0)*(2.0/3.0); + for (auto&& r : rects) + r = pyr.rect_down(r); + for (auto&& r : ignored) + r = pyr.rect_down(r); + } + } + images.push_back(img); + object_locations.push_back(rects); + ignored_rects.push_back(ignored); + } + } + + return ignored_rects; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline size_t num_non_ignored_boxes (const std::vector& rects) + { + size_t cnt = 0; + for (auto& b : rects) + { + if (!b.ignore) + cnt++; + } + return cnt; + } + } + + template < + typename array_type + > + void load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const image_dataset_file& source + ) + { + images.clear(); + object_locations.clear(); + + using namespace dlib::image_dataset_metadata; + dataset data; + load_image_dataset_metadata(data, source.get_filename()); + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. + locally_change_current_dir chdir(get_parent_directory(file(source.get_filename()))); + + typedef typename array_type::value_type image_type; + + image_type img; + std::vector rects; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + double min_rect_size = std::numeric_limits::infinity(); + rects.clear(); + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (source.should_load_box(data.images[i].boxes[j])) + { + if (data.images[i].boxes[j].ignore) + { + rects.push_back(ignored_mmod_rect(data.images[i].boxes[j].rect)); + } + else + { + rects.push_back(mmod_rect(data.images[i].boxes[j].rect)); + min_rect_size = std::min(min_rect_size, rects.back().rect.area()); + } + rects.back().label = data.images[i].boxes[j].label; + + } + } + + if (!source.should_skip_empty_images() || impl::num_non_ignored_boxes(rects) != 0) + { + load_image(img, data.images[i].filename); + if (rects.size() != 0) + { + // if shrinking the image would still result in the smallest box being + // bigger than the box area threshold then shrink the image. + while(min_rect_size/2/2 > source.box_area_thresh()) + { + pyramid_down<2> pyr; + pyr(img); + min_rect_size *= (1.0/2.0)*(1.0/2.0); + for (auto&& r : rects) + r.rect = pyr.rect_down(r.rect); + } + while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh()) + { + pyramid_down<3> pyr; + pyr(img); + min_rect_size *= (2.0/3.0)*(2.0/3.0); + for (auto&& r : rects) + r.rect = pyr.rect_down(r.rect); + } + } + images.push_back(std::move(img)); + object_locations.push_back(std::move(rects)); + } + } + } + +// ---------------------------------------------------------------------------------------- + +// ******* THIS FUNCTION IS DEPRECATED, you should use another version of load_image_dataset() ******* + template < + typename image_type, + typename MM + > + std::vector > load_image_dataset ( + array& images, + std::vector >& object_locations, + const std::string& filename, + const std::string& label, + bool skip_empty_images = false + ) + { + image_dataset_file f(filename); + if (label.size() != 0) + f = f.boxes_match_label(label); + if (skip_empty_images) + f = f.skip_empty_images(); + return load_image_dataset(images, object_locations, f); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const std::string& filename + ) + { + return load_image_dataset(images, object_locations, image_dataset_file(filename)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + void load_image_dataset ( + array_type& images, + std::vector>& object_locations, + const std::string& filename + ) + { + load_image_dataset(images, object_locations, image_dataset_file(filename)); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const image_dataset_file& source, + std::vector& parts_list + ) + { + typedef typename array_type::value_type image_type; + parts_list.clear(); + images.clear(); + object_locations.clear(); + + using namespace dlib::image_dataset_metadata; + dataset data; + load_image_dataset_metadata(data, source.get_filename()); + + // Set the current directory to be the one that contains the + // metadata file. We do this because the file might contain + // file paths which are relative to this folder. + locally_change_current_dir chdir(get_parent_directory(file(source.get_filename()))); + + + std::set all_parts; + + // find out what parts are being used in the dataset. Store results in all_parts. + for (unsigned long i = 0; i < data.images.size(); ++i) + { + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (source.should_load_box(data.images[i].boxes[j])) + { + const std::map& parts = data.images[i].boxes[j].parts; + std::map::const_iterator itr; + + for (itr = parts.begin(); itr != parts.end(); ++itr) + { + all_parts.insert(itr->first); + } + } + } + } + + // make a mapping between part names and the integers [0, all_parts.size()) + std::map parts_idx; + for (std::set::iterator i = all_parts.begin(); i != all_parts.end(); ++i) + { + parts_idx[*i] = parts_list.size(); + parts_list.push_back(*i); + } + + std::vector > ignored_rects; + std::vector ignored; + image_type img; + std::vector object_dets; + for (unsigned long i = 0; i < data.images.size(); ++i) + { + double min_rect_size = std::numeric_limits::infinity(); + object_dets.clear(); + ignored.clear(); + for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) + { + if (source.should_load_box(data.images[i].boxes[j])) + { + if (data.images[i].boxes[j].ignore) + { + ignored.push_back(data.images[i].boxes[j].rect); + } + else + { + std::vector partlist(parts_idx.size(), OBJECT_PART_NOT_PRESENT); + + // populate partlist with all the parts present in this box. + const std::map& parts = data.images[i].boxes[j].parts; + std::map::const_iterator itr; + for (itr = parts.begin(); itr != parts.end(); ++itr) + { + partlist[parts_idx[itr->first]] = itr->second; + } + + object_dets.push_back(full_object_detection(data.images[i].boxes[j].rect, partlist)); + min_rect_size = std::min(min_rect_size, object_dets.back().get_rect().area()); + } + } + } + + if (!source.should_skip_empty_images() || object_dets.size() != 0) + { + load_image(img, data.images[i].filename); + if (object_dets.size() != 0) + { + // if shrinking the image would still result in the smallest box being + // bigger than the box area threshold then shrink the image. + while(min_rect_size/2/2 > source.box_area_thresh()) + { + pyramid_down<2> pyr; + pyr(img); + min_rect_size *= (1.0/2.0)*(1.0/2.0); + for (auto&& r : object_dets) + { + r.get_rect() = pyr.rect_down(r.get_rect()); + for (unsigned long k = 0; k < r.num_parts(); ++k) + r.part(k) = pyr.point_down(r.part(k)); + } + for (auto&& r : ignored) + { + r = pyr.rect_down(r); + } + } + while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh()) + { + pyramid_down<3> pyr; + pyr(img); + min_rect_size *= (2.0/3.0)*(2.0/3.0); + for (auto&& r : object_dets) + { + r.get_rect() = pyr.rect_down(r.get_rect()); + for (unsigned long k = 0; k < r.num_parts(); ++k) + r.part(k) = pyr.point_down(r.part(k)); + } + for (auto&& r : ignored) + { + r = pyr.rect_down(r); + } + } + } + images.push_back(img); + object_locations.push_back(object_dets); + ignored_rects.push_back(ignored); + } + } + + + return ignored_rects; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const image_dataset_file& source + ) + { + std::vector parts_list; + return load_image_dataset(images, object_locations, source, parts_list); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const std::string& filename + ) + { + std::vector parts_list; + return load_image_dataset(images, object_locations, image_dataset_file(filename), parts_list); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LOAD_IMAGE_DaTASET_Hh_ + diff --git a/ml/dlib/dlib/data_io/load_image_dataset_abstract.h b/ml/dlib/dlib/data_io/load_image_dataset_abstract.h new file mode 100644 index 000000000..b06252098 --- /dev/null +++ b/ml/dlib/dlib/data_io/load_image_dataset_abstract.h @@ -0,0 +1,358 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_ +#ifdef DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_ + +#include "image_dataset_metadata.h" +#include "../array/array_kernel_abstract.h" +#include +#include +#include "../image_processing/full_object_detection_abstract.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class image_dataset_file + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool used to tell the load_image_dataset() functions which + boxes and images to load from an XML based image dataset file. By default, + this object tells load_image_dataset() to load all images and object boxes. + !*/ + + public: + image_dataset_file( + const std::string& filename + ); + /*! + ensures + - #get_filename() == filename + - #should_skip_empty_images() == false + - #get_selected_box_labels().size() == 0 + This means that, initially, all boxes will be loaded. Therefore, for all + possible boxes B we have: + - #should_load_box(B) == true + - #box_area_thresh() == infinity + !*/ + + const std::string& get_filename( + ) const; + /*! + ensures + - returns the name of the XML image dataset metadata file given to this + object's constructor. + !*/ + + bool should_skip_empty_images( + ) const; + /*! + ensures + - returns true if we are supposed to skip images that don't have any + non-ignored boxes to load when loading an image dataset using + load_image_dataset(). + !*/ + + image_dataset_file boxes_match_label( + const std::string& label + ) const; + /*! + ensures + - returns a copy of *this that is identical in all respects to *this except + that label will be included in the labels set (i.e. the set returned by + get_selected_box_labels()). + !*/ + + const std::set& get_selected_box_labels( + ) const; + /*! + ensures + - returns the set of box labels currently selected by the should_load_box() + method. Note that if the set is empty then we select all boxes. + !*/ + + image_dataset_file skip_empty_images( + ) const; + /*! + ensures + - returns a copy of *this that is identical in all respects to *this except + that #should_skip_empty_images() == true. + !*/ + + bool should_boxes_have_parts( + ) const; + /*! + ensures + - returns true if boxes must have some parts defined for them to be loaded. + !*/ + + image_dataset_file boxes_have_parts( + ) const; + /*! + ensures + - returns a copy of *this that is identical in all respects to *this except + that #should_boxes_have_parts() == true. + !*/ + + bool should_load_box ( + const image_dataset_metadata::box& box + ) const; + /*! + ensures + - returns true if we are supposed to load the given box from an image + dataset XML file. In particular, if should_load_box() returns false then + the load_image_dataset() routines will not return the box at all, neither + in the ignore rectangles list or in the primary object_locations vector. + The behavior of this function is defined as follows: + - if (should_boxes_have_parts() && boxes.parts.size() == 0) then + - returns false + - else if (get_selected_box_labels().size() == 0) then + - returns true + - else if (get_selected_box_labels().count(box.label) != 0) then + - returns true + - else + - returns false + !*/ + + image_dataset_file shrink_big_images( + double new_box_area_thresh = 150*150 + ) const; + /*! + ensures + - returns a copy of *this that is identical in all respects to *this except + that #box_area_thresh() == new_box_area_thresh + !*/ + + double box_area_thresh( + ) const; + /*! + ensures + - If the smallest non-ignored rectangle in an image has an area greater + than box_area_thresh() then we will shrink the image until the area of + the box is about equal to box_area_thresh(). This is useful if you have + a dataset containing very high resolution images and you don't want to + load it in its native high resolution. Setting the box_area_thresh() + allows you to control the resolution of the loaded images. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const image_dataset_file& source + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - This routine loads the images and their associated object boxes from the + image metadata file indicated by source.get_filename(). This metadata file + should be in the XML format used by the save_image_dataset_metadata() routine. + - #images.size() == The number of images loaded from the metadata file. This + is all the images listed in the file unless source.should_skip_empty_images() + is set to true. + - #images.size() == #object_locations.size() + - This routine is capable of loading any image format which can be read by the + load_image() routine. + - let IGNORED_RECTS denote the vector returned from this function. + - IGNORED_RECTS.size() == #object_locations.size() + - IGNORED_RECTS == a list of the rectangles which have the "ignore" flag set to + true in the input XML file. + - for all valid i: + - #images[i] == a copy of the i-th image from the dataset. + - #object_locations[i] == a vector of all the rectangles associated with + #images[i]. These are the rectangles for which source.should_load_box() + returns true and are also not marked as "ignore" in the XML file. + - IGNORED_RECTS[i] == A vector of all the rectangles associated with #images[i] + that are marked as "ignore" but not discarded by source.should_load_box(). + - if (source.should_skip_empty_images() == true) then + - #object_locations[i].size() != 0 + (i.e. we won't load images that don't end up having any object locations) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const std::string& filename + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - performs: return load_image_dataset(images, object_locations, image_dataset_file(filename)); + (i.e. it ignores box labels and therefore loads all the boxes in the dataset) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + void load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const image_dataset_file& source + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - This function has essentially the same behavior as the above + load_image_dataset() routines, except here we output to a vector of + mmod_rects instead of rectangles. In this case, both ignore and non-ignore + rectangles go into object_locations since mmod_rect has an ignore boolean + field that records the ignored/non-ignored state of each rectangle. We also store + a each box's string label into the mmod_rect::label field as well. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + void load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const std::string& filename + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - performs: load_image_dataset(images, object_locations, image_dataset_file(filename)); + (i.e. it ignores box labels and therefore loads all the boxes in the dataset) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const image_dataset_file& source, + std::vector& parts_list + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - This routine loads the images and their associated object locations from the + image metadata file indicated by source.get_filename(). This metadata file + should be in the XML format used by the save_image_dataset_metadata() routine. + - The difference between this function and the version of load_image_dataset() + defined above is that this version will also load object part information and + thus fully populates the full_object_detection objects. + - #images.size() == The number of images loaded from the metadata file. This + is all the images listed in the file unless source.should_skip_empty_images() + is set to true. + - #images.size() == #object_locations.size() + - This routine is capable of loading any image format which can be read + by the load_image() routine. + - #parts_list == a vector that contains the list of object parts found in the + input file and loaded into object_locations. + - #parts_list is in lexicographic sorted order. + - let IGNORED_RECTS denote the vector returned from this function. + - IGNORED_RECTS.size() == #object_locations.size() + - IGNORED_RECTS == a list of the rectangles which have the "ignore" flag set to + true in the input XML file. + - for all valid i: + - #images[i] == a copy of the i-th image from the dataset. + - #object_locations[i] == a vector of all the rectangles associated with + #images[i]. These are the rectangles for which source.should_load_box() + returns true and are also not marked as "ignore" in the XML file. + - IGNORED_RECTS[i] == A vector of all the rectangles associated with #images[i] + that are marked as "ignore" but not discarded by source.should_load_box(). + - if (source.should_skip_empty_images() == true) then + - #object_locations[i].size() != 0 + (i.e. we won't load images that don't end up having any object locations) + - for all valid j: + - #object_locations[i][j].num_parts() == #parts_list.size() + - for all valid k: + - #object_locations[i][j].part(k) == the location of the part + with name #parts_list[k] or OBJECT_PART_NOT_PRESENT if the + part was not indicated for object #object_locations[i][j]. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const image_dataset_file& source + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - performs: return load_image_dataset(images, object_locations, source, parts_list); + (i.e. this function simply calls the above function and discards the output + parts_list. So it is just a convenience function you can call if you don't + care about getting the parts list.) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + std::vector > load_image_dataset ( + array_type& images, + std::vector >& object_locations, + const std::string& filename + ); + /*! + requires + - array_type == An array of images. This is anything with an interface that + looks like std::vector where a "generic image" is + anything that implements the generic image interface defined in + dlib/image_processing/generic_image.h. + ensures + - performs: return load_image_dataset(images, object_locations, image_dataset_file(filename)); + (i.e. it ignores box labels and therefore loads all the boxes in the dataset) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/data_io/mnist.cpp b/ml/dlib/dlib/data_io/mnist.cpp new file mode 100644 index 000000000..d6a62fb67 --- /dev/null +++ b/ml/dlib/dlib/data_io/mnist.cpp @@ -0,0 +1,133 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MNIST_CPp_ +#define DLIB_MNIST_CPp_ + +#include "mnist.h" +#include +#include "../byte_orderer.h" +#include "../uintn.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + void load_mnist_dataset ( + const std::string& folder_name, + std::vector >& training_images, + std::vector& training_labels, + std::vector >& testing_images, + std::vector& testing_labels + ) + { + using namespace std; + ifstream fin1((folder_name+"/train-images-idx3-ubyte").c_str(), ios::binary); + if (!fin1) + { + fin1.open((folder_name + "/train-images.idx3-ubyte").c_str(), ios::binary); + } + + ifstream fin2((folder_name+"/train-labels-idx1-ubyte").c_str(), ios::binary); + if (!fin2) + { + fin2.open((folder_name + "/train-labels.idx1-ubyte").c_str(), ios::binary); + } + + ifstream fin3((folder_name+"/t10k-images-idx3-ubyte").c_str(), ios::binary); + if (!fin3) + { + fin3.open((folder_name + "/t10k-images.idx3-ubyte").c_str(), ios::binary); + } + + ifstream fin4((folder_name+"/t10k-labels-idx1-ubyte").c_str(), ios::binary); + if (!fin4) + { + fin4.open((folder_name + "/t10k-labels.idx1-ubyte").c_str(), ios::binary); + } + + if (!fin1) throw error("Unable to open file train-images-idx3-ubyte or train-images.idx3-ubyte"); + if (!fin2) throw error("Unable to open file train-labels-idx1-ubyte or train-labels.idx1-ubyte"); + if (!fin3) throw error("Unable to open file t10k-images-idx3-ubyte or t10k-images.idx3-ubyte"); + if (!fin4) throw error("Unable to open file t10k-labels-idx1-ubyte or t10k-labels.idx1-ubyte"); + + byte_orderer bo; + + // make sure the files have the contents we expect. + uint32 magic, num, nr, nc, num2, num3, num4; + fin1.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); + fin1.read((char*)&num, sizeof(num)); bo.big_to_host(num); + fin1.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr); + fin1.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc); + if (magic != 2051 || num != 60000 || nr != 28 || nc != 28) + throw error("mndist dat files are corrupted."); + + fin2.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); + fin2.read((char*)&num2, sizeof(num2)); bo.big_to_host(num2); + if (magic != 2049 || num2 != 60000) + throw error("mndist dat files are corrupted."); + + fin3.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); + fin3.read((char*)&num3, sizeof(num3)); bo.big_to_host(num3); + fin3.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr); + fin3.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc); + if (magic != 2051 || num3 != 10000 || nr != 28 || nc != 28) + throw error("mndist dat files are corrupted."); + + fin4.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); + fin4.read((char*)&num4, sizeof(num4)); bo.big_to_host(num4); + if (magic != 2049 || num4 != 10000) + throw error("mndist dat files are corrupted."); + + if (!fin1) throw error("Unable to read train-images-idx3-ubyte"); + if (!fin2) throw error("Unable to read train-labels-idx1-ubyte"); + if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte"); + if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte"); + + + training_images.resize(60000); + training_labels.resize(60000); + testing_images.resize(10000); + testing_labels.resize(10000); + + for (size_t i = 0; i < training_images.size(); ++i) + { + training_images[i].set_size(nr,nc); + fin1.read((char*)&training_images[i](0,0), nr*nc); + } + for (size_t i = 0; i < training_labels.size(); ++i) + { + char l; + fin2.read(&l, 1); + training_labels[i] = l; + } + + for (size_t i = 0; i < testing_images.size(); ++i) + { + testing_images[i].set_size(nr,nc); + fin3.read((char*)&testing_images[i](0,0), nr*nc); + } + for (size_t i = 0; i < testing_labels.size(); ++i) + { + char l; + fin4.read(&l, 1); + testing_labels[i] = l; + } + + if (!fin1) throw error("Unable to read train-images-idx3-ubyte"); + if (!fin2) throw error("Unable to read train-labels-idx1-ubyte"); + if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte"); + if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte"); + + if (fin1.get() != EOF) throw error("Unexpected bytes at end of train-images-idx3-ubyte"); + if (fin2.get() != EOF) throw error("Unexpected bytes at end of train-labels-idx1-ubyte"); + if (fin3.get() != EOF) throw error("Unexpected bytes at end of t10k-images-idx3-ubyte"); + if (fin4.get() != EOF) throw error("Unexpected bytes at end of t10k-labels-idx1-ubyte"); + } +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_MNIST_CPp_ + + + diff --git a/ml/dlib/dlib/data_io/mnist.h b/ml/dlib/dlib/data_io/mnist.h new file mode 100644 index 000000000..e71be6f2b --- /dev/null +++ b/ml/dlib/dlib/data_io/mnist.h @@ -0,0 +1,32 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MNIST_Hh_ +#define DLIB_MNIST_Hh_ + +#include "mnist_abstract.h" +#include +#include +#include "../matrix.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + void load_mnist_dataset ( + const std::string& folder_name, + std::vector >& training_images, + std::vector& training_labels, + std::vector >& testing_images, + std::vector& testing_labels + ); +} + +// ---------------------------------------------------------------------------------------- + +#ifdef NO_MAKEFILE +#include "mnist.cpp" +#endif + +#endif // DLIB_MNIST_Hh_ + + diff --git a/ml/dlib/dlib/data_io/mnist_abstract.h b/ml/dlib/dlib/data_io/mnist_abstract.h new file mode 100644 index 000000000..09121633e --- /dev/null +++ b/ml/dlib/dlib/data_io/mnist_abstract.h @@ -0,0 +1,46 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MNIST_ABSTRACT_Hh_ +#ifdef DLIB_MNIST_ABSTRACT_Hh_ + +#include +#include +#include "../matrix.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + void load_mnist_dataset ( + const std::string& folder_name, + std::vector >& training_images, + std::vector& training_labels, + std::vector >& testing_images, + std::vector& testing_labels + ); + /*! + ensures + - Attempts to load the MNIST dataset from the hard drive. This is the dataset + of handwritten digits available from http://yann.lecun.com/exdb/mnist/. In + particular, the 4 files comprising the MNIST dataset should be present in the + folder indicated by folder_name. These four files are: + - train-images-idx3-ubyte + - train-labels-idx1-ubyte + - t10k-images-idx3-ubyte + - t10k-labels-idx1-ubyte + - #training_images == The 60,000 training images from the dataset. + - #training_labels == The labels for the contents of #training_images. + I.e. #training_labels[i] is the label of #training_images[i]. + - #testing_images == The 10,000 testing images from the dataset. + - #testing_labels == The labels for the contents of #testing_images. + I.e. #testing_labels[i] is the label of #testing_images[i]. + throws + - dlib::error if some problem prevents us from loading the data or the files + can't be found. + !*/ +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_MNIST_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/dir_nav.h b/ml/dlib/dlib/dir_nav.h new file mode 100644 index 000000000..c5956615d --- /dev/null +++ b/ml/dlib/dlib/dir_nav.h @@ -0,0 +1,21 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIR_NAv_ +#define DLIB_DIR_NAv_ + + +#include "platform.h" + + +#ifdef WIN32 +#include "dir_nav/windows.h" +#endif + +#ifndef WIN32 +#include "dir_nav/posix.h" +#endif + +#include "dir_nav/dir_nav_extensions.h" + +#endif // DLIB_DIR_NAv_ + diff --git a/ml/dlib/dlib/dir_nav/dir_nav_extensions.cpp b/ml/dlib/dlib/dir_nav/dir_nav_extensions.cpp new file mode 100644 index 000000000..db05e4cc4 --- /dev/null +++ b/ml/dlib/dlib/dir_nav/dir_nav_extensions.cpp @@ -0,0 +1,121 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIR_NAV_EXTENSIONs_CPP_ +#define DLIB_DIR_NAV_EXTENSIONs_CPP_ + +#include "dir_nav_extensions.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace implementation_details + { + void get_all_sub_dirs ( + const directory& top_of_tree, + unsigned long max_depth, + std::vector& result, + std::vector& temp + ) + { + if (max_depth > 0) + { + top_of_tree.get_dirs(temp); + const unsigned long start = result.size(); + result.insert(result.end(), temp.begin(), temp.end()); + const unsigned long end = start + temp.size(); + + for (unsigned long i = start; i < end; ++i) + { + get_all_sub_dirs(result[i], max_depth-1, result, temp); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + bool file_exists ( + const std::string& filename + ) + { + try + { + dlib::file temp(filename); + return true; + } + catch (file::file_not_found&) + { + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + directory get_parent_directory ( + const directory& dir + ) + { + return dir.get_parent(); + } + +// ---------------------------------------------------------------------------------------- + + directory get_parent_directory ( + const file& f + ) + { + if (f.full_name().size() == 0) + return directory(); + + std::string::size_type pos = f.full_name().find_last_of("\\/"); + + if (pos == std::string::npos) + return directory(); + + return directory(f.full_name().substr(0,pos)); + } + +// ---------------------------------------------------------------------------------------- + + std::string select_oldest_file ( + const std::string& filename1, + const std::string& filename2 + ) + { + file f1, f2; + try{f1 = file(filename1);} catch(file::file_not_found&) { return filename1; } + try{f2 = file(filename2);} catch(file::file_not_found&) { return filename2; } + + if (f1.last_modified() < f2.last_modified()) + return filename1; + else + return filename2; + } + +// ---------------------------------------------------------------------------------------- + + std::string select_newest_file ( + const std::string& filename1, + const std::string& filename2 + ) + { + file f1, f2; + try{f1 = file(filename1);} catch(file::file_not_found&) { return filename2; } + try{f2 = file(filename2);} catch(file::file_not_found&) { return filename1; } + + if (f1.last_modified() > f2.last_modified()) + return filename1; + else + return filename2; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DIR_NAV_EXTENSIONs_CPP_ + + + diff --git a/ml/dlib/dlib/dir_nav/dir_nav_extensions.h b/ml/dlib/dlib/dir_nav/dir_nav_extensions.h new file mode 100644 index 000000000..93dde1159 --- /dev/null +++ b/ml/dlib/dlib/dir_nav/dir_nav_extensions.h @@ -0,0 +1,172 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIR_NAV_EXTENSIONs_H_ +#define DLIB_DIR_NAV_EXTENSIONs_H_ + +#include +#include +#include +#include "dir_nav_extensions_abstract.h" +#include "../dir_nav.h" +#include "../string.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + bool file_exists ( + const std::string& filename + ); + +// ---------------------------------------------------------------------------------------- + + namespace implementation_details + { + void get_all_sub_dirs ( + const directory& top_of_tree, + unsigned long max_depth, + std::vector& result, + std::vector& temp + ); + } + +// ---------------------------------------------------------------------------------------- + + template + const std::vector get_files_in_directory_tree ( + const directory& top_of_tree, + const T& add_file, + unsigned long max_depth = 30 + ) + { + std::vector result, temp; + std::vector dirs, dirs_temp; + dirs.push_back(top_of_tree); + + // get all the directories in the tree first + implementation_details::get_all_sub_dirs(top_of_tree, max_depth, dirs, dirs_temp); + + // now just loop over all the directories and pick out the files we want to keep + for (unsigned long d = 0; d < dirs.size(); ++d) + { + dirs[d].get_files(temp); + + // pick out the members of temp that we should keep + for (unsigned long i = 0; i < temp.size(); ++i) + { + if (add_file(temp[i])) + result.push_back(temp[i]); + } + } + + return result; + } + +// ---------------------------------------------------------------------------------------- + + class match_ending + { + + public: + match_ending ( + const std::string& ending_ + ) : ending(ending_) {} + + bool operator() ( + const file& f + ) const + { + // if the ending is bigger than f's name then it obviously doesn't match + if (ending.size() > f.name().size()) + return false; + + // now check if the actual characters that make up the end of the file name + // matches what is in ending. + return std::equal(ending.begin(), ending.end(), f.name().end()-ending.size()); + } + + private: + std::string ending; + }; + +// ---------------------------------------------------------------------------------------- + + class match_endings + { + + public: + match_endings ( + const std::string& endings_ + ) + { + const std::vector& s = split(endings_); + for (unsigned long i = 0; i < s.size(); ++i) + { + endings.push_back(match_ending(s[i])); + } + } + + bool operator() ( + const file& f + ) const + { + for (unsigned long i = 0; i < endings.size(); ++i) + { + if (endings[i](f)) + return true; + } + + return false; + } + + private: + std::vector endings; + }; + +// ---------------------------------------------------------------------------------------- + + class match_all + { + public: + bool operator() ( + const file& + ) const { return true; } + }; + +// ---------------------------------------------------------------------------------------- + + directory get_parent_directory ( + const directory& dir + ); + +// ---------------------------------------------------------------------------------------- + + directory get_parent_directory ( + const file& f + ); + +// ---------------------------------------------------------------------------------------- + + std::string select_oldest_file ( + const std::string& filename1, + const std::string& filename2 + ); + +// ---------------------------------------------------------------------------------------- + + std::string select_newest_file ( + const std::string& filename1, + const std::string& filename2 + ); + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "dir_nav_extensions.cpp" +#endif + +#endif // DLIB_DIR_NAV_EXTENSIONs_H_ + diff --git a/ml/dlib/dlib/dir_nav/dir_nav_extensions_abstract.h b/ml/dlib/dlib/dir_nav/dir_nav_extensions_abstract.h new file mode 100644 index 000000000..4aa6cc4f2 --- /dev/null +++ b/ml/dlib/dlib/dir_nav/dir_nav_extensions_abstract.h @@ -0,0 +1,203 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DIR_NAV_EXTENSIONs_ABSTRACT_ +#ifdef DLIB_DIR_NAV_EXTENSIONs_ABSTRACT_ + +#include +#include +#include "dir_nav_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + bool file_exists ( + const std::string& filename + ); + /*! + ensures + - if (a file with the given filename exists) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const std::vector get_files_in_directory_tree ( + const directory& top_of_tree, + const T& add_file, + unsigned long max_depth = 30 + ); + /*! + requires + - add_file must be a function object with the following prototype: + bool add_file (file f); + ensures + - performs a recursive search through the directory top_of_tree and all + its sub-directories (up to the given max depth). All files in these + directories are examined by passing them to add_file() and if it + returns true then they will be included in the returned std::vector + object. + - Note that a max_depth of 0 means that only the files in the directory + top_of_tree will be considered. A depth of 1 means that only files in + top_of_tree and its immediate sub-directories will be considered. And + so on... + !*/ + +// ---------------------------------------------------------------------------------------- + + class match_ending + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple function object that can be used with the + above get_files_in_directory_tree() function. This object + just looks for files with a certain ending. + !*/ + + public: + match_ending ( + const std::string& ending + ); + /*! + ensures + - this object will be a function that checks if a file has a + name that ends with the given ending string. + !*/ + + bool operator() ( + const file& f + ) const; + /*! + ensures + - if (the file f has a name that ends with the ending string given + to this object's constructor) then + - returns true + - else + - returns false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class match_endings + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple function object that can be used with the + above get_files_in_directory_tree() function. This object + allows you to look for files with a number of different + endings. + !*/ + + public: + match_endings ( + const std::string& ending_list + ); + /*! + ensures + - ending_list is interpreted as a whitespace separated list + of file endings. + - this object will be a function that checks if a file has a + name that ends with one of the strings in ending_list. + !*/ + + bool operator() ( + const file& f + ) const; + /*! + ensures + - if (the file f has a name that ends with one of the ending strings + given to this object's constructor) then + - returns true + - else + - returns false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class match_all + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple function object that can be used with the + above get_files_in_directory_tree() function. This object + matches all files. + !*/ + + public: + bool operator() ( + const file& f + ) const; + /*! + ensures + - returns true + (i.e. this function doesn't do anything. It just says it + matches all files no matter what) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + directory get_parent_directory ( + const directory& dir + ); + /*! + ensures + - returns the parent directory of dir. In particular, this + function returns the value of dir.get_parent() + !*/ + +// ---------------------------------------------------------------------------------------- + + directory get_parent_directory ( + const file& f + ); + /*! + ensures + - if (f.full_name() != "") then + - returns the directory which contains the given file + - else + - returns a default initialized directory (i.e. directory()) + !*/ + +// ---------------------------------------------------------------------------------------- + + std::string select_oldest_file ( + const std::string& filename1, + const std::string& filename2 + ); + /*! + ensures + - Checks the last modification times of the two given files and returns the + filename of the oldest file, i.e., the file that has gone longest since being + modified. Ties are broken arbitrarily. + - For the purpose of comparison, a file that doesn't exist is presumed to have + a last modification time of -infinity (i.e. very far in the past). + !*/ + +// ---------------------------------------------------------------------------------------- + + std::string select_newest_file ( + const std::string& filename1, + const std::string& filename2 + ); + /*! + ensures + - Checks the last modification times of the two given files and returns the + filename that was most recently modified. Ties are broken arbitrarily. + - For the purpose of comparison, a file that doesn't exist is presumed to have + a last modification time of -infinity (i.e. very far in the past). + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DIR_NAV_EXTENSIONs_ABSTRACT_ + + diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.cpp b/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.cpp new file mode 100644 index 000000000..9891d5dff --- /dev/null +++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.cpp @@ -0,0 +1,258 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIR_NAV_KERNEL_1_CPp_ +#define DLIB_DIR_NAV_KERNEL_1_CPp_ +#include "../platform.h" + +#ifdef WIN32 + +#include "dir_nav_kernel_1.h" +#include "../string.h" + + +#ifdef __BORLANDC__ +// Apparently the borland compiler doesn't define this. +#define INVALID_FILE_ATTRIBUTES ((DWORD)-1) +#endif + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // file object implementation +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void file:: + init ( + const std::string& name + ) + { + using namespace std; + + + char buf[3000]; + char* str; + if (GetFullPathNameA(name.c_str(),sizeof(buf),buf,&str) == 0) + { + // the file was not found + throw file_not_found("Unable to find file " + name); + } + state.full_name = buf; + + + string::size_type pos = state.full_name.find_last_of(directory::get_separator()); + if (pos == string::npos) + { + // no valid full path has no separator characters. + throw file_not_found("Unable to find file " + name); + } + state.name = state.full_name.substr(pos+1); + + + // now find the size of this file + WIN32_FIND_DATAA data; + HANDLE ffind = FindFirstFileA(state.full_name.c_str(), &data); + if (ffind == INVALID_HANDLE_VALUE || + (data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) != 0) + { + throw file_not_found("Unable to find file " + name); + } + else + { + uint64 temp = data.nFileSizeHigh; + temp <<= 32; + temp |= data.nFileSizeLow; + state.file_size = temp; + FindClose(ffind); + + ULARGE_INTEGER ull; + ull.LowPart = data.ftLastWriteTime.dwLowDateTime; + ull.HighPart = data.ftLastWriteTime.dwHighDateTime; + std::chrono::nanoseconds epoch(100 * (ull.QuadPart - 116444736000000000)); + state.last_modified = std::chrono::time_point(std::chrono::duration_cast(epoch)); + } + + } + +// ---------------------------------------------------------------------------------------- + + bool file:: + operator == ( + const file& rhs + ) const + { + using namespace std; + + if (state.full_name.size() != rhs.state.full_name.size()) + return false; + + // compare the strings but ignore the case because file names + // are not case sensitive on windows + return tolower(state.full_name) == tolower(rhs.state.full_name); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // directory object implementation +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void directory:: + init ( + const std::string& name + ) + { + using namespace std; + + + char buf[3000]; + char* str; + if (GetFullPathNameA(name.c_str(),sizeof(buf),buf,&str) == 0) + { + // the directory was not found + throw dir_not_found("Unable to find directory " + name); + } + state.full_name = buf; + + + const char sep = get_separator(); + if (is_root_path(state.full_name) == false) + { + // ensure that thre is not a trialing separator + if (state.full_name[state.full_name.size()-1] == sep) + state.full_name.erase(state.full_name.size()-1); + + // pick out the directory name + string::size_type pos = state.full_name.find_last_of(sep); + state.name = state.full_name.substr(pos+1); + } + else + { + // ensure that there is a trailing separator + if (state.full_name[state.full_name.size()-1] != sep) + state.full_name += sep; + } + + + // now check that this is actually a valid directory + DWORD attribs = GetFileAttributesA(state.full_name.c_str()); + if (attribs == INVALID_FILE_ATTRIBUTES || + (attribs&FILE_ATTRIBUTE_DIRECTORY) == 0) + { + // the directory was not found + throw dir_not_found("Unable to find directory " + name); + } + + } + +// ---------------------------------------------------------------------------------------- + + char directory:: + get_separator ( + ) + { + return '\\'; + } + +// ---------------------------------------------------------------------------------------- + + bool directory:: + operator == ( + const directory& rhs + ) const + { + using namespace std; + + if (state.full_name.size() != rhs.state.full_name.size()) + return false; + + // compare the strings but ignore the case because file names + // are not case sensitive on windows + return tolower(state.full_name) == tolower(rhs.state.full_name); + } + +// ---------------------------------------------------------------------------------------- + + const directory directory:: + get_parent ( + ) const + { + using namespace std; + // if *this is the root then just return *this + if (is_root()) + { + return *this; + } + else + { + directory temp; + + const char sep = get_separator(); + + string::size_type pos = state.full_name.find_last_of(sep); + temp.state.full_name = state.full_name.substr(0,pos); + + if ( is_root_path(temp.state.full_name)) + { + temp.state.full_name += sep; + } + else + { + pos = temp.state.full_name.find_last_of(sep); + if (pos != string::npos) + { + temp.state.name = temp.state.full_name.substr(pos+1); + } + else + { + temp.state.full_name += sep; + } + } + return temp; + } + } + +// ---------------------------------------------------------------------------------------- + + bool directory:: + is_root_path ( + const std::string& path + ) const + { + using namespace std; + const char sep = get_separator(); + bool root_path = false; + if (path.size() > 2 && path[0] == sep && path[1] == sep) + { + // in this case this is a windows share path + string::size_type pos = path.find_first_of(sep,2); + if (pos != string::npos) + { + pos = path.find_first_of(sep,pos+1); + + if (pos == string::npos && path[path.size()-1] != sep) + root_path = true; + else if (pos == path.size()-1) + root_path = true; + } + + } + else if ( (path.size() == 2 || path.size() == 3) && path[1] == ':') + { + // if this is a valid windows path then it must be a root path + root_path = true; + } + + return root_path; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // WIN32 + +#endif // DLIB_DIR_NAV_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.h b/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.h new file mode 100644 index 000000000..a31f689de --- /dev/null +++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_1.h @@ -0,0 +1,634 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIR_NAV_KERNEl_1_ +#define DLIB_DIR_NAV_KERNEl_1_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + +#include "../platform.h" + + +#include "dir_nav_kernel_abstract.h" +#include +#include "../uintn.h" +#include "../algs.h" + +#include "../windows_magic.h" +#include +#include +#include "../stl_checked.h" +#include "../enable_if.h" +#include "../queue.h" +#include + +namespace dlib +{ + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // file object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class file + { + /*! + INITIAL VALUES + state.name == name() + state.full_name == full_name() + state.file_size == size() + state.last_modified == last_modified() + + CONVENTION + state.name == name() + state.full_name == full_name() + state.file_size == size() + state.last_modified == last_modified() + + !*/ + + friend class directory; + + struct data + { + uint64 file_size; + std::string name; + std::string full_name; + std::chrono::time_point last_modified; + }; + + + void init ( const std::string& name); + + public: + + struct private_constructor{}; + inline file ( + const std::string& name, + const std::string& full_name, + const uint64 file_size, + const std::chrono::time_point& last_modified, + private_constructor + ) + { + state.file_size = file_size; + state.name = name; + state.full_name = full_name; + state.last_modified = last_modified; + } + + + + + class file_not_found : public error { + public: file_not_found(const std::string& s): error(s){} + }; + + inline file ( + ) + { + state.file_size = 0; + } + + file ( + const std::string& name + ) { init(name); } + + file ( + const char* name + ) { init(name); } + + inline const std::string& name ( + ) const { return state.name; } + + inline const std::string& full_name ( + ) const { return state.full_name; } + + operator std::string ( + ) const { return full_name(); } + + inline uint64 size ( + ) const { return state.file_size; } + + inline std::chrono::time_point last_modified ( + ) const { return state.last_modified; } + + bool operator == ( + const file& rhs + ) const; + + bool operator != ( + const file& rhs + ) const { return !(*this == rhs); } + + inline bool operator < ( + const file& item + ) const { return full_name() < item.full_name(); } + + inline void swap ( + file& item + ) + { + exchange(state,item.state); + } + + private: + + data state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // directory object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class directory + { + /*! + INITIAL VALUES + state.name == name() + state.full_name == full_name() + + CONVENTION + state.name == name() + state.full_name == full_name() + is_root() == state.name.size() == 0 + + !*/ + + void init (const std::string& name); + + public: + + struct data + { + std::string name; + std::string full_name; + }; + + + /* + The reason we don't just make this constructor actually + private is because doing it this way avoids a bug that + sometimes occurs in visual studio 7.1. The bug has + something to do with templated friend functions + such as the get_filesystem_roots() function below if + it was declared as a friend template of this class. + */ + struct private_constructor{}; + inline directory ( + const std::string& name, + const std::string& full_name, + private_constructor + ) + { + state.name = name; + state.full_name = full_name; + } + + + class dir_not_found : public error { + public: dir_not_found(const std::string& s):error(s){} + }; + class listing_error : public error { + public: listing_error(const std::string& s):error(s){} + }; + + inline directory ( + ) + { + } + + directory ( + const std::string& name + ) { init(name); } + + directory ( + const char* name + ) { init(name); } + + + static char get_separator ( + ); + + + template < + typename queue_of_files + > + void get_files ( + queue_of_files& files + ) const; + + template < + typename queue_of_dirs + > + void get_dirs ( + queue_of_dirs& dirs + ) const; + + std::vector get_files ( + ) const + { + std::vector temp_vector; + get_files(temp_vector); + return temp_vector; + } + + std::vector get_dirs ( + ) const + { + std::vector temp_vector; + get_dirs(temp_vector); + return temp_vector; + } + + const directory get_parent ( + ) const; + + inline bool is_root ( + ) const { return state.name.size() == 0; } + + inline const std::string& name ( + ) const { return state.name; } + + inline const std::string& full_name ( + ) const { return state.full_name; } + + operator std::string ( + ) const { return full_name(); } + + bool operator == ( + const directory& rhs + ) const; + + bool operator != ( + const directory& rhs + ) const { return !(*this == rhs); } + + inline bool operator < ( + const directory& item + ) const { return full_name() < item.full_name(); } + + inline void swap ( + directory& item + ) + { + exchange(state,item.state); + } + + private: + + // member data + data state; + + bool is_root_path ( + const std::string& path + ) const; + /*! + ensures + - returns true if path is a root path. + Note that this function considers root paths that don't + have a trailing separator to also be valid. + !*/ + + + }; + +// ---------------------------------------------------------------------------------------- + + inline std::ostream& operator<< ( + std::ostream& out, + const directory& item + ) { out << (std::string)item; return out; } + + inline std::ostream& operator<< ( + std::ostream& out, + const file& item + ) { out << (std::string)item; return out; } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_dir + > + typename disable_if,void>::type get_filesystem_roots ( + queue_of_dir& roots + ) + { + roots.clear(); + const DWORD mask = GetLogicalDrives(); + DWORD bit = 1; + char buf[] = "A:\\"; + + do + { + if (mask & bit) + { + directory dir("",buf,directory::private_constructor()); + roots.enqueue(dir); + } + bit <<= 1; + ++buf[0]; + } while (buf[0] != 'Z'); + } + + template < + typename queue_of_dir + > + typename enable_if,void>::type get_filesystem_roots ( + queue_of_dir& roots + ) + { + roots.clear(); + const DWORD mask = GetLogicalDrives(); + DWORD bit = 1; + char buf[] = "A:\\"; + + do + { + if (mask & bit) + { + directory dir("",buf,directory::private_constructor()); + roots.push_back(dir); + } + bit <<= 1; + ++buf[0]; + } while (buf[0] != 'Z'); + } + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + file& a, + file& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + directory& a, + directory& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // templated member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_files + > + typename disable_if,void>::type + directory_helper_get_files ( + const directory::data& state, + queue_of_files& files + ) + { + using namespace std; + typedef directory::listing_error listing_error; + typedef file::private_constructor private_constructor; + + files.clear(); + if (state.full_name.size() == 0) + throw listing_error("This directory object currently doesn't represent any directory."); + + HANDLE ffind = INVALID_HANDLE_VALUE; + try + { + WIN32_FIND_DATAA data; + string path = state.full_name; + // ensure that the path ends with a separator + if (path[path.size()-1] != directory::get_separator()) + path += directory::get_separator(); + + ffind = FindFirstFileA((path+"*").c_str(), &data); + if (ffind == INVALID_HANDLE_VALUE) + { + throw listing_error("Unable to list the contents of " + state.full_name); + } + + + bool no_more_files = false; + do + { + if ((data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) == 0) + { + uint64 file_size = data.nFileSizeHigh; + file_size <<= 32; + file_size |= data.nFileSizeLow; + + ULARGE_INTEGER ull; + ull.LowPart = data.ftLastWriteTime.dwLowDateTime; + ull.HighPart = data.ftLastWriteTime.dwHighDateTime; + std::chrono::nanoseconds epoch(100 * (ull.QuadPart - 116444736000000000)); + auto last_modified = std::chrono::time_point(std::chrono::duration_cast(epoch)); + + // this is a file so add it to the queue + file temp(data.cFileName,path+data.cFileName,file_size, last_modified, private_constructor()); + files.enqueue(temp); + } + + if (FindNextFileA(ffind,&data) == 0) + { + // an error occurred + if ( GetLastError() == ERROR_NO_MORE_FILES) + { + // there are no more files + no_more_files = true; + } + else + { + // there was an error + throw listing_error("Unable to list the contents of " + state.full_name); + } + } + } while (no_more_files == false); + + FindClose(ffind); + ffind = INVALID_HANDLE_VALUE; + } + catch (...) + { + if (ffind != INVALID_HANDLE_VALUE) + FindClose(ffind); + files.clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_files + > + typename enable_if,void>::type + directory_helper_get_files ( + const directory::data& state, + queue_of_files& files + ) + { + queue::kernel_2a temp_files; + directory_helper_get_files(state,temp_files); + + files.clear(); + + // copy the queue of files into the vector + temp_files.reset(); + while (temp_files.move_next()) + { + files.push_back(temp_files.element()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_files + > + void directory:: + get_files ( + queue_of_files& files + ) const + { + // the reason for this indirection here is because it avoids a bug in + // the mingw version of gcc + directory_helper_get_files(state,files); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_dirs + > + typename disable_if,void>::type + directory_helper_get_dirs ( + const directory::data& state, + queue_of_dirs& dirs + ) + { + using namespace std; + typedef directory::listing_error listing_error; + typedef directory::private_constructor private_constructor; + + dirs.clear(); + if (state.full_name.size() == 0) + throw listing_error("This directory object currently doesn't represent any directory."); + + HANDLE dfind = INVALID_HANDLE_VALUE; + try + { + WIN32_FIND_DATAA data; + string path = state.full_name; + // ensure that the path ends with a separator + if (path[path.size()-1] != directory::get_separator()) + path += directory::get_separator(); + + dfind = FindFirstFileA((path+"*").c_str(), &data); + if (dfind == INVALID_HANDLE_VALUE) + { + throw listing_error("Unable to list the contents of " + state.full_name); + } + + + bool no_more_files = false; + do + { + string tname(data.cFileName); + if ((data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) != 0 && + tname != "." && + tname != "..") + { + // this is a directory so add it to the queue + directory temp(tname,path+tname,private_constructor()); + dirs.enqueue(temp); + } + + if (FindNextFileA(dfind,&data) == 0) + { + // an error occurred + if ( GetLastError() == ERROR_NO_MORE_FILES) + { + // there are no more files + no_more_files = true; + } + else + { + // there was an error + throw listing_error("Unable to list the contents of " + state.full_name); + } + } + } while (no_more_files == false); + + FindClose(dfind); + dfind = INVALID_HANDLE_VALUE; + } + catch (...) + { + if (dfind != INVALID_HANDLE_VALUE) + FindClose(dfind); + dirs.clear(); + throw; + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_dirs + > + typename enable_if,void>::type + directory_helper_get_dirs ( + const directory::data& state, + queue_of_dirs& dirs + ) + { + queue::kernel_2a temp_dirs; + directory_helper_get_dirs(state,temp_dirs); + + dirs.clear(); + + // copy the queue of dirs into the vector + temp_dirs.reset(); + while (temp_dirs.move_next()) + { + dirs.push_back(temp_dirs.element()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_dirs + > + void directory:: + get_dirs ( + queue_of_dirs& dirs + ) const + { + // the reason for this indirection here is because it avoids a bug in + // the mingw version of gcc + directory_helper_get_dirs(state,dirs); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + + +#ifdef NO_MAKEFILE +#include "dir_nav_kernel_1.cpp" +#endif + +#endif // DLIB_DIR_NAV_KERNEl_1_ + diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.cpp b/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.cpp new file mode 100644 index 000000000..be97b984c --- /dev/null +++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.cpp @@ -0,0 +1,254 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIR_NAV_KERNEL_2_CPp_ +#define DLIB_DIR_NAV_KERNEL_2_CPp_ + +#include "../platform.h" + +#ifdef POSIX + + +#include "dir_nav_kernel_2.h" + + + + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // file object implementation +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void file:: + init ( + const std::string& name + ) + { + using namespace std; + + + + char buf[PATH_MAX]; + if (realpath(name.c_str(),buf) == 0) + { + // the file was not found + throw file_not_found("Unable to find file " + name); + } + state.full_name = buf; + + + string::size_type pos = state.full_name.find_last_of(directory::get_separator()); + if (pos == string::npos) + { + // no valid full path has no separtor characters. + throw file_not_found("Unable to find file " + name); + } + state.name = state.full_name.substr(pos+1); + + + // now find the size of this file + struct stat64 buffer; + if (::stat64(state.full_name.c_str(), &buffer) || + S_ISDIR(buffer.st_mode)) + { + // there was an error during the call to stat64 or + // name is actually a directory + throw file_not_found("Unable to find file " + name); + } + else + { + state.file_size = static_cast(buffer.st_size); + + + state.last_modified = std::chrono::system_clock::from_time_t(buffer.st_mtime); +#ifdef _BSD_SOURCE + state.last_modified += std::chrono::duration_cast(std::chrono::nanoseconds(buffer.st_atim.tv_nsec)); +#endif + } + + } + +// ---------------------------------------------------------------------------------------- + + bool file:: + operator == ( + const file& rhs + ) const + { + using namespace std; + if (state.full_name.size() == 0 && rhs.state.full_name.size() == 0) + return true; + + // These files might have different names but actually represent the same + // file due to the presence of symbolic links. + char buf[PATH_MAX]; + string left, right; + if (realpath(state.full_name.c_str(),buf) == 0) + return false; + left = buf; + + if (realpath(rhs.state.full_name.c_str(),buf) == 0) + return false; + right = buf; + + return (left == right); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // directory object implementation +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void directory:: + init ( + const std::string& name + ) + { + using namespace std; + + + char buf[PATH_MAX]; + if (realpath(name.c_str(),buf) == 0) + { + // the directory was not found + throw dir_not_found("Unable to find directory " + name); + } + state.full_name = buf; + + + const char sep = get_separator(); + if (is_root_path(state.full_name) == false) + { + // ensure that thre is not a trialing separator + if (state.full_name[state.full_name.size()-1] == sep) + state.full_name.erase(state.full_name.size()-1); + + // pick out the directory name + string::size_type pos = state.full_name.find_last_of(sep); + state.name = state.full_name.substr(pos+1); + } + else + { + // ensure that there is a trailing separator + if (state.full_name[state.full_name.size()-1] != sep) + state.full_name += sep; + } + + + struct stat64 buffer; + // now check that this is actually a valid directory + if (::stat64(state.full_name.c_str(),&buffer)) + { + // the directory was not found + throw dir_not_found("Unable to find directory " + name); + } + else if (S_ISDIR(buffer.st_mode) == 0) + { + // It is not a directory + throw dir_not_found("Unable to find directory " + name); + } + } + +// ---------------------------------------------------------------------------------------- + + char directory:: + get_separator ( + ) + { + return '/'; + } + +// ---------------------------------------------------------------------------------------- + + bool directory:: + operator == ( + const directory& rhs + ) const + { + using namespace std; + if (state.full_name.size() == 0 && rhs.state.full_name.size() == 0) + return true; + + // These directories might have different names but actually represent the same + // directory due to the presence of symbolic links. + char buf[PATH_MAX]; + string left, right; + if (realpath(state.full_name.c_str(),buf) == 0) + return false; + left = buf; + + if (realpath(rhs.state.full_name.c_str(),buf) == 0) + return false; + right = buf; + + return (left == right); + } + +// ---------------------------------------------------------------------------------------- + + const directory directory:: + get_parent ( + ) const + { + using namespace std; + // if *this is the root then just return *this + if (is_root()) + { + return *this; + } + else + { + directory temp; + + const char sep = get_separator(); + + string::size_type pos = state.full_name.find_last_of(sep); + temp.state.full_name = state.full_name.substr(0,pos); + + if ( is_root_path(temp.state.full_name)) + { + temp.state.full_name += sep; + } + else + { + pos = temp.state.full_name.find_last_of(sep); + if (pos != string::npos) + { + temp.state.name = temp.state.full_name.substr(pos+1); + } + else + { + temp.state.full_name += sep; + } + } + return temp; + } + } + +// ---------------------------------------------------------------------------------------- + + bool directory:: + is_root_path ( + const std::string& path + ) const + { + const char sep = get_separator(); + if (path.size() == 1 && path[0] == sep) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // POSIX + +#endif // DLIB_DIR_NAV_KERNEL_2_CPp_ + diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.h b/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.h new file mode 100644 index 000000000..af2f3d5dc --- /dev/null +++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_2.h @@ -0,0 +1,659 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIR_NAV_KERNEl_2_ +#define DLIB_DIR_NAV_KERNEl_2_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + + +#include "dir_nav_kernel_abstract.h" + +#include +#include "../uintn.h" +#include "../algs.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if !defined(__USE_LARGEFILE64 ) && !defined(_LARGEFILE64_SOURCE) +#define stat64 stat +#endif + +#include +#include "../stl_checked.h" +#include "../enable_if.h" +#include "../queue.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // file object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class file + { + /*! + INITIAL VALUES + state.name == name() + state.full_name == full_name() + state.file_size == size() + state.last_modified == last_modified() + + CONVENTION + state.name == name() + state.full_name == full_name() + state.file_size == size() + state.last_modified == last_modified() + + !*/ + + friend class directory; + + struct data + { + uint64 file_size; + std::string name; + std::string full_name; + std::chrono::time_point last_modified; + }; + + void init(const std::string& name); + + public: + + struct private_constructor{}; + inline file ( + const std::string& name, + const std::string& full_name, + const uint64 file_size, + const std::chrono::time_point& last_modified, + private_constructor + ) + { + state.file_size = file_size; + state.name = name; + state.full_name = full_name; + state.last_modified = last_modified; + } + + + class file_not_found : public error { + public: file_not_found(const std::string& s): error(s){} + }; + + inline file ( + ) + { + state.file_size = 0; + } + + file ( + const std::string& name + ) { init(name); } + + file ( + const char* name + ) { init(name); } + + inline const std::string& name ( + ) const { return state.name; } + + inline const std::string& full_name ( + ) const { return state.full_name; } + + inline uint64 size ( + ) const { return state.file_size; } + + inline std::chrono::time_point last_modified ( + ) const { return state.last_modified; } + + operator std::string ( + ) const { return full_name(); } + + bool operator == ( + const file& rhs + ) const; + + bool operator != ( + const file& rhs + ) const { return !(*this == rhs); } + + inline bool operator < ( + const file& item + ) const { return full_name() < item.full_name(); } + + inline void swap ( + file& item + ) + { + exchange(state,item.state); + } + + private: + + // member data + data state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // directory object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class directory + { + /*! + INITIAL VALUES + state.name == name() + state.full_name == full_name() + + CONVENTION + state.name == name() + state.full_name == full_name() + is_root() == state.name.size() == 0 + + !*/ + + void init(const std::string& name); + + public: + struct private_constructor{}; + inline directory ( + const std::string& name, + const std::string& full_name, + private_constructor + ) + { + state.name = name; + state.full_name = full_name; + } + + struct data + { + std::string name; + std::string full_name; + }; + + class dir_not_found : public error { + public: dir_not_found(const std::string& s):error(s){} + }; + class listing_error : public error { + public: listing_error(const std::string& s):error(s){} + }; + + inline directory ( + ) + { + } + + directory ( + const std::string& name + ) { init(name); } + + directory ( + const char* name + ) { init(name); } + + static char get_separator ( + ); + + template < + typename queue_of_files + > + void get_files ( + queue_of_files& files + ) const; + + template < + typename queue_of_dirs + > + void get_dirs ( + queue_of_dirs& dirs + ) const; + + std::vector get_files ( + ) const + { + std::vector temp_vector; + get_files(temp_vector); + return temp_vector; + } + + std::vector get_dirs ( + ) const + { + std::vector temp_vector; + get_dirs(temp_vector); + return temp_vector; + } + + const directory get_parent ( + ) const; + + inline bool is_root ( + ) const { return state.name.size() == 0; } + + inline const std::string& name ( + ) const { return state.name; } + + inline const std::string& full_name ( + ) const { return state.full_name; } + + operator std::string ( + ) const { return full_name(); } + + bool operator == ( + const directory& rhs + ) const; + + bool operator != ( + const directory& rhs + ) const { return !(*this == rhs); } + + inline bool operator < ( + const directory& item + ) const { return full_name() < item.full_name(); } + + inline void swap ( + directory& item + ) + { + exchange(state,item.state); + } + + private: + + // member data + data state; + + bool is_root_path ( + const std::string& path + ) const; + /*! + ensures + - returns true if path is a root path. + Note that this function considers root paths that don't + have a trailing separator to also be valid. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + inline std::ostream& operator<< ( + std::ostream& out, + const directory& item + ) { out << (std::string)item; return out; } + + inline std::ostream& operator<< ( + std::ostream& out, + const file& item + ) { out << (std::string)item; return out; } + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + file& a, + file& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + directory& a, + directory& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // templated member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_files + > + typename disable_if,void>::type + directory_helper_get_files ( + const directory::data& state, + queue_of_files& files + ) + { + using namespace std; + + files.clear(); + if (state.full_name.size() == 0) + throw directory::listing_error("This directory object currently doesn't represent any directory."); + + DIR* ffind = 0; + struct dirent* data; + struct stat64 buffer; + + try + { + string path = state.full_name; + // ensure that the path ends with a separator + if (path[path.size()-1] != directory::get_separator()) + path += directory::get_separator(); + + // get a handle to something we can search with + ffind = opendir(state.full_name.c_str()); + if (ffind == 0) + { + throw directory::listing_error("Unable to list the contents of " + state.full_name); + } + + while(true) + { + errno = 0; + if ( (data = readdir(ffind)) == 0) + { + // there was an error or no more files + if ( errno == 0) + { + // there are no more files + break; + } + else + { + // there was an error + throw directory::listing_error("Unable to list the contents of " + state.full_name); + } + } + + uint64 file_size; + // get a stat64 structure so we can see if this is a file + if (::stat64((path+data->d_name).c_str(), &buffer) != 0) + { + // this might be a broken symbolic link. We can check by calling + // readlink and seeing if it finds anything. + char buf[PATH_MAX]; + ssize_t temp = readlink((path+data->d_name).c_str(),buf,sizeof(buf)); + if (temp == -1) + throw directory::listing_error("Unable to list the contents of " + state.full_name); + else + file_size = static_cast(temp); + } + else + { + file_size = static_cast(buffer.st_size); + } + auto last_modified = std::chrono::system_clock::from_time_t(buffer.st_mtime); +#ifdef _BSD_SOURCE + last_modified += std::chrono::duration_cast(std::chrono::nanoseconds(buffer.st_atim.tv_nsec)); +#endif + + if (S_ISDIR(buffer.st_mode) == 0) + { + // this is actually a file + file temp( + data->d_name, + path+data->d_name, + file_size, + last_modified, + file::private_constructor() + ); + files.enqueue(temp); + } + } // while (true) + + if (ffind != 0) + { + while (closedir(ffind)) + { + if (errno != EINTR) + break; + } + ffind = 0; + } + + } + catch (...) + { + if (ffind != 0) + { + while (closedir(ffind)) + { + if (errno != EINTR) + break; + } + ffind = 0; + } + files.clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_files + > + typename enable_if,void>::type + directory_helper_get_files ( + const directory::data& state, + queue_of_files& files + ) + { + queue::kernel_2a temp_files; + directory_helper_get_files(state,temp_files); + + files.clear(); + + // copy the queue of files into the vector + temp_files.reset(); + while (temp_files.move_next()) + { + files.push_back(temp_files.element()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_files + > + void directory:: + get_files ( + queue_of_files& files + ) const + { + // the reason for this indirection here is because it avoids a bug in + // the cygwin version of gcc + directory_helper_get_files(state,files); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_dirs + > + typename disable_if,void>::type + directory_helper_get_dirs ( + const directory::data& state, + queue_of_dirs& dirs + ) + { + using namespace std; + + dirs.clear(); + if (state.full_name.size() == 0) + throw directory::listing_error("This directory object currently doesn't represent any directory."); + + DIR* ffind = 0; + struct dirent* data; + struct stat64 buffer; + + try + { + string path = state.full_name; + // ensure that the path ends with a separator + if (path[path.size()-1] != directory::get_separator()) + path += directory::get_separator(); + + // get a handle to something we can search with + ffind = opendir(state.full_name.c_str()); + if (ffind == 0) + { + throw directory::listing_error("Unable to list the contents of " + state.full_name); + } + + while(true) + { + errno = 0; + if ( (data = readdir(ffind)) == 0) + { + // there was an error or no more files + if ( errno == 0) + { + // there are no more files + break; + } + else + { + // there was an error + throw directory::listing_error("Unable to list the contents of " + state.full_name); + } + } + + // get a stat64 structure so we can see if this is a file + if (::stat64((path+data->d_name).c_str(), &buffer) != 0) + { + // just assume this isn't a directory. It is probably a broken + // symbolic link. + continue; + } + + string dtemp(data->d_name); + if (S_ISDIR(buffer.st_mode) && + dtemp != "." && + dtemp != ".." ) + { + // this is a directory so add it to dirs + directory temp(dtemp,path+dtemp, directory::private_constructor()); + dirs.enqueue(temp); + } + } // while (true) + + if (ffind != 0) + { + while (closedir(ffind)) + { + if (errno != EINTR) + break; + } + ffind = 0; + } + + } + catch (...) + { + if (ffind != 0) + { + while (closedir(ffind)) + { + if (errno != EINTR) + break; + } + ffind = 0; + } + dirs.clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_dirs + > + typename enable_if,void>::type + directory_helper_get_dirs ( + const directory::data& state, + queue_of_dirs& dirs + ) + { + queue::kernel_2a temp_dirs; + directory_helper_get_dirs(state,temp_dirs); + + dirs.clear(); + + // copy the queue of dirs into the vector + temp_dirs.reset(); + while (temp_dirs.move_next()) + { + dirs.push_back(temp_dirs.element()); + } + } + +// ---------------------------------------------------------------------------------------- + + + template < + typename queue_of_dirs + > + void directory:: + get_dirs ( + queue_of_dirs& dirs + ) const + { + // the reason for this indirection here is because it avoids a bug in + // the cygwin version of gcc + directory_helper_get_dirs(state,dirs); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_dir + > + typename disable_if,void>::type get_filesystem_roots ( + queue_of_dir& roots + ) + { + roots.clear(); + directory dir("/"); + roots.enqueue(dir); + } + + template < + typename queue_of_dir + > + typename enable_if,void>::type get_filesystem_roots ( + std::vector& roots + ) + { + roots.clear(); + directory dir("/"); + roots.push_back(dir); + } + +// ---------------------------------------------------------------------------------------- + +} + + +#ifdef NO_MAKEFILE +#include "dir_nav_kernel_2.cpp" +#endif + +#endif // DLIB_DIR_NAV_KERNEl_2_ + diff --git a/ml/dlib/dlib/dir_nav/dir_nav_kernel_abstract.h b/ml/dlib/dlib/dir_nav/dir_nav_kernel_abstract.h new file mode 100644 index 000000000..53254ee03 --- /dev/null +++ b/ml/dlib/dlib/dir_nav/dir_nav_kernel_abstract.h @@ -0,0 +1,515 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DIR_NAV_KERNEl_ABSTRACT_ +#ifdef DLIB_DIR_NAV_KERNEl_ABSTRACT_ + +#include +#include +#include "../uintn.h" +#include "../algs.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*! + GENERAL WARNING + Don't call any of these functions or make any of these objects + before main() has been entered. That means no instances + of file or directory at the global scope. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_of_dir + > + void get_filesystem_roots ( + queue_of_dir& roots + ); + /*! + requires + - queue_of_dirs == an implementation of queue/queue_kernel_abstract.h with T + set to directory or a std::vector or dlib::std_vector_c. + ensures + - #roots == a queue containing directories that represent all the roots + of the filesystem on this machine. (e.g. in windows you have c:\, d:\ + etc.) + throws + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // file object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class file + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a file. + + Note that the size of a file is determined at the time the file + object is constructed. Thus if a file changes sizes after its + file object has been created its file object's size() method + will not reflect the new file size. + !*/ + + public: + + class file_not_found : public error {}; + + file ( + ); + /*! + ensures + - #*this has been properly initialized + - #name() == "" + - #full_name() == "" + - #size() == 0 + - #*this does not represent any file + throws + - std::bad_alloc + !*/ + + file ( + const std::string& name + ); + /*! + ensures + - #*this has been properly initialized + - #*this represents the file given by name + Note that name can be a fully qualified path or just a path + relative to the current working directory. Also, any symbolic + links in name will be resolved. + throws + - std::bad_alloc + - file_not_found + This exception is thrown if the file can not be found or + accessed. + !*/ + + file ( + const char* name + ); + /*! + ensures + - this function is identical to file(const std::string& name) + !*/ + + file ( + const file& item + ); + /*! + ensures + - #*this == item + throws + - std::bad_alloc + !*/ + + ~file ( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + const std::string& name ( + ) const; + /*! + ensures + - returns the name of the file. This is full_name() minus + the path to the file. + !*/ + + const std::string& full_name ( + ) const; + /*! + ensures + - returns the fully qualified name for the file represented by *this + !*/ + + uint64 size ( + ) const; + /*! + ensures + - returns the size of this file in bytes. + !*/ + + std::chrono::time_point last_modified ( + ) const; + /*! + ensures + - returns the time the file was last modified. + !*/ + + operator std::string ( + ) const; + /*! + ensures + - returns full_name() + (i.e. provides an implicit conversion to string from dlib::file) + !*/ + + file& operator= ( + const file& rhs + ); + /*! + ensures + - #*this == rhs + !*/ + + bool operator == ( + const file& rhs + ) const; + /*! + ensures + - if (*this and rhs represent the same file) then + - returns true + - else + - returns false + !*/ + + bool operator != ( + const file& rhs + ) const; + /*! + ensures + - if (*this and rhs represent the same file) then + - returns false + - else + - returns true + !*/ + + bool operator < ( + const file& item + ) const; + /*! + ensures + - if (full_name() < item.full_name()) then + - returns true + - else + - returns false + !*/ + + void swap ( + file& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // directory object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class directory + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a directory in a file system. It gives + the ability to traverse a directory tree. + + Note that the directories . and .. are not returned by get_dirs() + !*/ + + public: + + class dir_not_found : public error {}; + class listing_error : public error {}; + + directory ( + ); + /*! + ensures + - #*this has been properly initialized + - #full_name() == "" + - #name() == "" + - #is_root() == true + - #*this does not represent any directory + throws + - std::bad_alloc + !*/ + + directory ( + const std::string& name + ); + /*! + ensures + - #*this has been properly initialized + - #*this represents the directory given by name. + Note that name can be a fully qualified path or just a path + relative to the current working directory. Also, any symbolic + links in name will be resolved. + throws + - std::bad_alloc + - dir_not_found + This exception is thrown if the directory can not be found or + accessed. + !*/ + + directory ( + const char* name + ); + /*! + ensures + - this function is identical to directory(const std::string& name) + !*/ + + directory ( + const directory& item + ); + /*! + ensures + - #*this == item + throws + - std::bad_alloc + !*/ + + ~directory ( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + static char get_separator ( + ); + /*! + ensures + - returns the character used to separate directories and file names in a + path. (i.e. \ on windows and / in unix) + !*/ + + template < + typename queue_of_files + > + void get_files ( + queue_of_files& files + ) const; + /*! + requires + - queue_of_files == an implementation of queue/queue_kernel_abstract.h with T + set to file or a std::vector or dlib::std_vector_c. + ensures + - #files == A queue containing all the files present in this directory. + (Note that symbolic links will not have been resolved in the names + of the returned files.) + - #files.size() == the number of files in this directory + throws + - bad_alloc + If this exception is thrown then the call to get_files() has + no effect on *this and #files is unusable until files.clear() + is called and succeeds. + - listing_error + This exception is thrown if listing access has been denied to this + directory or if some error occurred that prevented us from successfully + getting the contents of this directory. + If this exception is thrown then the call to get_files() has + no effect on *this and #files.size()==0. + !*/ + + std::vector get_files ( + ) const; + /*! + ensures + - This function simply calls get_files(temp_vector) and then returns temp_vector. + !*/ + + template < + typename queue_of_dirs + > + void get_dirs ( + queue_of_dirs& dirs + ) const; + /*! + requires + - queue_of_dirs == an implementation of queue/queue_kernel_abstract.h with T + set to directory or a std::vector or dlib::std_vector_c. + ensures + - #dirs == a queue containing all the directories present in this directory. + (note that symbolic links will not have been resolved in the names + of the returned directories.) + - #dirs.size() == the number of subdirectories in this directory + throws + - bad_alloc + If this exception is thrown then the call to get_files() has + no effect on *this and #files is unusable until files.clear() + is called and succeeds. + - listing_error + This exception is thrown if listing access has been denied to this + directory or if some error occurred that prevented us from successfully + getting the contents of this directory. + If this exception is thrown then the call to get_dirs() has + no effect on *this and #dirs.size()==0. + !*/ + + std::vector get_dirs ( + ) const; + /*! + ensures + - This function simply calls get_dirs(temp_vector) and then returns temp_vector. + !*/ + + bool is_root ( + ) const; + /*! + ensures + - if (*this represents the root of this directory tree) then + - returns true + - else + - returns false + !*/ + + const directory get_parent ( + ) const; + /*! + ensures + - if (is_root()) then + - returns a copy of *this + - else + - returns the parent directory of *this + throws + - bad_alloc + If this exception is thrown then the call to get_parent() will + have no effect. + !*/ + + const std::string& name ( + ) const; + /*! + ensures + - if (is_root()) then + - returns "" + - else + - returns the name of the directory. This is full_name() minus + the path to the directory. + !*/ + + const std::string& full_name ( + ) const; + /*! + ensures + - returns the fully qualified directory name for *this + - if (is_root()) then + - the last character of #full_name() is get_separator() + - else + - the last character of #full_name() is NOT get_separator() + !*/ + + operator std::string ( + ) const; + /*! + ensures + - returns full_name() + (i.e. provides an implicit conversion to string from dlib::directory) + !*/ + + directory& operator= ( + const directory& rhs + ); + /*! + ensures + - #*this == rhs + !*/ + + bool operator == ( + const directory& rhs + ) const; + /*! + ensures + - if (*this and rhs represent the same directory) then + - returns true + - else + - returns false + !*/ + + bool operator != ( + const directory& rhs + ) const; + /*! + ensures + - if (*this and rhs represent the same directory) then + - returns false + - else + - returns true + !*/ + + bool operator < ( + const directory& item + ) const; + /*! + ensures + - if (full_name() < item.full_name()) then + - returns true + - else + - returns false + !*/ + + void swap ( + directory& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + inline std::ostream& operator<< ( + std::ostream& out, + const directory& item + ); + /*! + ensures + - performs: out << item.full_name() + - returns out + !*/ + + inline std::ostream& operator<< ( + std::ostream& out, + const file& item + ); + /*! + ensures + - performs: out << item.full_name() + - returns out + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + file& a, + file& b + ) { a.swap(b); } + /*! + provides a global swap function for file objects + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + directory& a, + directory& b + ) { a.swap(b); } + /*! + provides a global swap function for directory objects + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DIR_NAV_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/dir_nav/posix.h b/ml/dlib/dlib/dir_nav/posix.h new file mode 100644 index 000000000..8d499064f --- /dev/null +++ b/ml/dlib/dlib/dir_nav/posix.h @@ -0,0 +1,6 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIR_NAV_KERNEl_1_ +#include "dir_nav_kernel_2.h" +#endif + diff --git a/ml/dlib/dlib/dir_nav/windows.h b/ml/dlib/dlib/dir_nav/windows.h new file mode 100644 index 000000000..b0f1e1bf5 --- /dev/null +++ b/ml/dlib/dlib/dir_nav/windows.h @@ -0,0 +1,6 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIR_NAV_KERNEl_2_ +#include "dir_nav_kernel_1.h" +#endif + diff --git a/ml/dlib/dlib/directed_graph.h b/ml/dlib/dlib/directed_graph.h new file mode 100644 index 000000000..a452521dc --- /dev/null +++ b/ml/dlib/dlib/directed_graph.h @@ -0,0 +1,37 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIRECTED_GRAPh_ +#define DLIB_DIRECTED_GRAPh_ + +#include "directed_graph/directed_graph_kernel_1.h" + +#include "algs.h" + +namespace dlib +{ + + template < + typename T, + typename E = char, + typename mem_manager = default_memory_manager + > + class directed_graph + { + directed_graph() {} + public: + + + //----------- kernels --------------- + + // kernel_1a + typedef directed_graph_kernel_1 + kernel_1a; + typedef directed_graph_kernel_1 + kernel_1a_c; + + }; +} + +#endif // DLIB_DIRECTED_GRAPh_ + + diff --git a/ml/dlib/dlib/directed_graph/directed_graph_kernel_1.h b/ml/dlib/dlib/directed_graph/directed_graph_kernel_1.h new file mode 100644 index 000000000..b0cc6a2c2 --- /dev/null +++ b/ml/dlib/dlib/directed_graph/directed_graph_kernel_1.h @@ -0,0 +1,704 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DIRECTED_GRAPH_KERNEl_1_ +#define DLIB_DIRECTED_GRAPH_KERNEl_1_ + +#include +#include + +#include "../serialize.h" +#include "../noncopyable.h" +#include "../std_allocator.h" +#include "../algs.h" +#include "directed_graph_kernel_abstract.h" +#include "../is_kind.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + struct directed_graph_checker_helper + { + /*! + This object is used to check preconditions based on the value of is_checked + !*/ + + static void check_parent_edge ( + unsigned long edge_index, + const node_type& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(edge_index < self.number_of_parents(), + "\tnode_type& directed_graph::node_type::parent_edge(edge_index)" + << "\n\tYou have specified an invalid index" + << "\n\tedge_index: " << edge_index + << "\n\tnumber_of_parents(): " << self.number_of_parents() + << "\n\tthis: " << &self + ); + } + + static void check_child_edge ( + unsigned long edge_index, + const node_type& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(edge_index < self.number_of_children(), + "\tnode_type& directed_graph::node_type::child_edge(edge_index)" + << "\n\tYou have specified an invalid index" + << "\n\tedge_index: " << edge_index + << "\n\tnumber_of_children(): " << self.number_of_children() + << "\n\tthis: " << &self + ); + } + + static void check_parent ( + unsigned long edge_index, + const node_type& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(edge_index < self.number_of_parents(), + "\tnode_type& directed_graph::node_type::parent(edge_index)" + << "\n\tYou have specified an invalid index" + << "\n\tedge_index: " << edge_index + << "\n\tnumber_of_parents(): " << self.number_of_parents() + << "\n\tthis: " << &self + ); + } + + static void check_child ( + unsigned long edge_index, + const node_type& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(edge_index < self.number_of_children(), + "\tnode_type& directed_graph::node_type::child(edge_index)" + << "\n\tYou have specified an invalid index" + << "\n\tedge_index: " << edge_index + << "\n\tnumber_of_children(): " << self.number_of_children() + << "\n\tthis: " << &self + ); + } + + static void check_node ( + unsigned long index, + const directed_graph& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(index < self.number_of_nodes(), + "\tnode_type& directed_graph::node(index)" + << "\n\tYou have specified an invalid index" + << "\n\tindex: " << index + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + } + + static void check_has_edge ( + unsigned long parent_node_index, + unsigned long child_node_index, + const directed_graph& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(parent_node_index < self.number_of_nodes() && + child_node_index < self.number_of_nodes(), + "\tvoid directed_graph::has_edge(parent_node_index, child_node_index)" + << "\n\tYou have specified an invalid index" + << "\n\tparent_node_index: " << parent_node_index + << "\n\tchild_node_index: " << child_node_index + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + } + + static void check_add_edge ( + unsigned long parent_node_index, + unsigned long child_node_index, + const directed_graph& self + ) + { + DLIB_CASSERT(parent_node_index < self.number_of_nodes() && + child_node_index < self.number_of_nodes(), + "\tvoid directed_graph::add_edge(parent_node_index, child_node_index)" + << "\n\tYou have specified an invalid index" + << "\n\tparent_node_index: " << parent_node_index + << "\n\tchild_node_index: " << child_node_index + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + + DLIB_CASSERT( self.has_edge(parent_node_index, child_node_index) == false, + "\tvoid directed_graph::add_edge(parent_node_index, child_node_index)" + << "\n\tYou can't add an edge if it already exists in the graph" + << "\n\tparent_node_index: " << parent_node_index + << "\n\tchild_node_index: " << child_node_index + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + + } + + static void check_remove_edge ( + unsigned long parent_node_index, + unsigned long child_node_index, + const directed_graph& self + ) + { + DLIB_CASSERT(parent_node_index < self.number_of_nodes() && + child_node_index < self.number_of_nodes(), + "\tvoid directed_graph::remove_edge(parent_node_index, child_node_index)" + << "\n\tYou have specified an invalid index" + << "\n\tparent_node_index: " << parent_node_index + << "\n\tchild_node_index: " << child_node_index + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + + DLIB_CASSERT( self.has_edge(parent_node_index, child_node_index) == true, + "\tvoid directed_graph::remove_edge(parent_node_index, child_node_index)" + << "\n\tYou can't remove an edge if it isn't in the graph" + << "\n\tparent_node_index: " << parent_node_index + << "\n\tchild_node_index: " << child_node_index + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + + } + + static void check_remove_node ( + unsigned long index, + const directed_graph& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(index < self.number_of_nodes(), + "\tvoid directed_graph::remove_node(index)" + << "\n\tYou have specified an invalid index" + << "\n\tindex: " << index + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + } + }; + + template + struct directed_graph_checker_helper + { + static inline void check_parent ( unsigned long , const node_type&) { } + static inline void check_child ( unsigned long , const node_type& ) { } + static inline void check_parent_edge ( unsigned long , const node_type&) { } + static inline void check_child_edge ( unsigned long , const node_type& ) { } + static inline void check_node ( unsigned long , const directed_graph& ) { } + static inline void check_has_edge ( unsigned long , unsigned long , const directed_graph& ) { } + static inline void check_add_edge ( unsigned long , unsigned long , const directed_graph& ) { } + static inline void check_remove_edge ( unsigned long , unsigned long , const directed_graph& ) { } + static inline void check_remove_node ( unsigned long , const directed_graph& ) { } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E = char, + typename mem_manager = default_memory_manager, + bool is_checked = true + > + class directed_graph_kernel_1 : noncopyable + { + + /*! + INITIAL VALUE + - nodes.size() == 0 + + CONVENTION + - nodes.size() == number_of_nodes() + - for all valid i: + - *nodes[i] == node(i) + - nodes[i]->parents.size() == nodes[i]->number_of_parents(i) + - nodes[i]->children.size() == nodes[i]->number_of_children(i) + - nodes[i]->edge_parents.size() == nodes[i]->number_of_parents(i) + - nodes[i]->edge_children.size() == nodes[i]->number_of_children(i) + - nodes[i]->idx == i == nodes[i]->index() + - for all valid p: + - nodes[i]->parents[p] == pointer to the p'th parent node of i + - *nodes[i]->parents[p] == nodes[i]->parent(p) + - *nodes[i]->edge_parents[p] == nodes[i]->parent_edge(p) + - for all valid c: + - nodes[i]->children[c] == pointer to the c'th child node of i + - *nodes[i]->children[c] == nodes[i]->child(c) + - *nodes[i]->edge_children[c] == nodes[i]->child_edge(c) + !*/ + + public: + struct node_type; + + private: + typedef directed_graph_checker_helper checker; + + + public: + + typedef T type; + typedef E edge_type; + typedef mem_manager mem_manager_type; + + template + struct rebind { + typedef directed_graph_kernel_1 other; + }; + + directed_graph_kernel_1( + ) {} + + virtual ~directed_graph_kernel_1( + ) {} + + void clear( + ) { nodes.clear(); } + + void set_number_of_nodes ( + unsigned long new_size + ); + + unsigned long number_of_nodes ( + ) const { return nodes.size(); } + + node_type& node ( + unsigned long index + ) { checker::check_node(index,*this); return *nodes[index]; } + + const node_type& node ( + unsigned long index + ) const { checker::check_node(index,*this); return *nodes[index]; } + + bool has_edge ( + unsigned long parent_node_index, + unsigned long child_node_index + ) const; + + void add_edge ( + unsigned long parent_node_index, + unsigned long child_node_index + ); + + void remove_edge ( + unsigned long parent_node_index, + unsigned long child_node_index + ); + + unsigned long add_node ( + ); + + void remove_node ( + unsigned long index + ); + + void swap ( + directed_graph_kernel_1& item + ) { nodes.swap(item.nodes); } + + private: + + + public: + + struct node_type + { + T data; + typedef directed_graph_kernel_1 graph_type; + + unsigned long index( + ) const { return idx; } + + unsigned long number_of_parents ( + ) const { return parents.size(); } + + unsigned long number_of_children ( + ) const { return children.size(); } + + const node_type& parent ( + unsigned long edge_index + ) const { checker::check_parent(edge_index,*this); return *parents[edge_index]; } + + node_type& parent ( + unsigned long edge_index + ) { checker::check_parent(edge_index,*this); return *parents[edge_index]; } + + const node_type& child ( + unsigned long edge_index + ) const { checker::check_child(edge_index,*this); return *children[edge_index]; } + + node_type& child ( + unsigned long edge_index + ) { checker::check_child(edge_index,*this); return *children[edge_index]; } + + const E& parent_edge ( + unsigned long edge_index + ) const { checker::check_parent_edge(edge_index,*this); return *edge_parents[edge_index]; } + + E& parent_edge ( + unsigned long edge_index + ) { checker::check_parent_edge(edge_index,*this); return *edge_parents[edge_index]; } + + const E& child_edge ( + unsigned long edge_index + ) const { checker::check_child_edge(edge_index,*this); return *edge_children[edge_index]; } + + E& child_edge ( + unsigned long edge_index + ) { checker::check_child_edge(edge_index,*this); return *edge_children[edge_index]; } + + private: + friend class directed_graph_kernel_1; + typedef std_allocator alloc_type; + typedef std_allocator,mem_manager> alloc_edge_type; + std::vector parents; + std::vector children; + std::vector,alloc_edge_type> edge_parents; + std::vector,alloc_edge_type> edge_children; + unsigned long idx; + }; + + private: + + typedef std_allocator,mem_manager> alloc_type; + typedef std::vector, alloc_type> vector_type; + vector_type nodes; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + struct is_directed_graph > + { + static const bool value = true; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + inline void swap ( + directed_graph_kernel_1& a, + directed_graph_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void serialize ( + const directed_graph_kernel_1& item, + std::ostream& out + ) + { + try + { + serialize(item.number_of_nodes(), out); + + // serialize each node + for (unsigned long i = 0; i < item.number_of_nodes(); ++i) + { + serialize(item.node(i).data, out); + + // serialize all the child edges + serialize(item.node(i).number_of_children(), out); + for (unsigned long c = 0; c < item.node(i).number_of_children(); ++c) + { + serialize(item.node(i).child(c).index(), out); + serialize(item.node(i).child_edge(c), out); + } + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type directed_graph_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void deserialize ( + directed_graph_kernel_1& item, + std::istream& in + ) + { + try + { + unsigned long size; + deserialize(size, in); + + item.clear(); + item.set_number_of_nodes(size); + + // deserialize each node + for (unsigned long i = 0; i < item.number_of_nodes(); ++i) + { + deserialize(item.node(i).data, in); + + unsigned long num_children; + deserialize(num_children, in); + + // Add all the edges going to this nodes children nodes + for (unsigned long c = 0; c < num_children; ++c) + { + unsigned long child_index; + deserialize(child_index, in); + + item.add_edge(i, child_index); + + // find the edge we just added + for (unsigned long j = 0; j < item.node(i).number_of_children(); ++j) + { + if (item.node(i).child(j).index() == child_index) + { + deserialize(item.node(i).child_edge(j), in); + break; + } + } + } + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type directed_graph_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void directed_graph_kernel_1:: + set_number_of_nodes ( + unsigned long new_size + ) + { + try + { + nodes.resize(new_size); + for (unsigned long i = 0; i < nodes.size(); ++i) + { + nodes[i].reset(new node_type); + nodes[i]->idx = i; + } + } + catch (...) + { + clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + bool directed_graph_kernel_1:: + has_edge ( + unsigned long parent_node_index, + unsigned long child_node_index + ) const + { + checker::check_has_edge(parent_node_index, child_node_index, *this); + + node_type& n = *nodes[parent_node_index]; + + // search all the child nodes to see if there is a link to the right node + for (unsigned long i = 0; i < n.children.size(); ++i) + { + if (n.children[i]->idx == child_node_index) + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void directed_graph_kernel_1:: + add_edge ( + unsigned long parent_node_index, + unsigned long child_node_index + ) + { + checker::check_add_edge(parent_node_index, child_node_index, *this); + try + { + node_type& p = *nodes[parent_node_index]; + node_type& c = *nodes[child_node_index]; + + p.children.push_back(&c); + c.parents.push_back(&p); + + p.edge_children.push_back(std::shared_ptr(new E)); + c.edge_parents.push_back(p.edge_children.back()); + } + catch (...) + { + clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void directed_graph_kernel_1:: + remove_edge ( + unsigned long parent_node_index, + unsigned long child_node_index + ) + { + checker::check_remove_edge(parent_node_index, child_node_index, *this); + + node_type& p = *nodes[parent_node_index]; + node_type& c = *nodes[child_node_index]; + + // remove the record of the link from the parent node + unsigned long pos = static_cast(find( p.children.begin(), + p.children.end(), + &c) - p.children.begin()); + p.children.erase(p.children.begin()+pos); + p.edge_children.erase(p.edge_children.begin()+pos); + + // remove the record of the link from the child node + pos = static_cast(find( c.parents.begin(), + c.parents.end(), + &p) - c.parents.begin()); + c.parents.erase(c.parents.begin() + pos); + c.edge_parents.erase(c.edge_parents.begin() + pos); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + unsigned long directed_graph_kernel_1:: + add_node ( + ) + { + try + { + std::shared_ptr n(new node_type); + n->idx = nodes.size(); + nodes.push_back(n); + return n->idx; + } + catch (...) + { + clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void directed_graph_kernel_1:: + remove_node ( + unsigned long index + ) + { + checker::check_remove_node(index,*this); + + node_type& n = *nodes[index]; + + // remove all edges pointing to this node from its parents + for (unsigned long i = 0; i < n.parents.size(); ++i) + { + // remove the edge from this specific parent + unsigned long pos = static_cast(find(n.parents[i]->children.begin(), + n.parents[i]->children.end(), + &n) - n.parents[i]->children.begin()); + + n.parents[i]->children.erase(n.parents[i]->children.begin() + pos); + n.parents[i]->edge_children.erase(n.parents[i]->edge_children.begin() + pos); + } + + // remove all edges pointing to this node from its children + for (unsigned long i = 0; i < n.children.size(); ++i) + { + // remove the edge from this specific child + unsigned long pos = static_cast(find(n.children[i]->parents.begin(), + n.children[i]->parents.end(), + &n) - n.children[i]->parents.begin()); + + n.children[i]->parents.erase(n.children[i]->parents.begin() + pos); + n.children[i]->edge_parents.erase(n.children[i]->edge_parents.begin() + pos); + } + + // now remove this node by replacing it with the last node in the nodes vector + nodes[index] = nodes[nodes.size()-1]; + + // update the index for the node we just moved + nodes[index]->idx = index; + + // now remove the duplicated node at the end of the vector + nodes.pop_back(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DIRECTED_GRAPH_KERNEl_1_ + diff --git a/ml/dlib/dlib/directed_graph/directed_graph_kernel_abstract.h b/ml/dlib/dlib/directed_graph/directed_graph_kernel_abstract.h new file mode 100644 index 000000000..70dd66efd --- /dev/null +++ b/ml/dlib/dlib/directed_graph/directed_graph_kernel_abstract.h @@ -0,0 +1,383 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DIRECTED_GRAPH_KERNEl_ABSTRACT_ +#ifdef DLIB_DIRECTED_GRAPH_KERNEl_ABSTRACT_ + +#include "../serialize.h" +#include "../algs.h" +#include "../noncopyable.h" + +namespace dlib +{ + + template < + typename T, + typename E = char, + typename mem_manager = default_memory_manager + > + class directed_graph : noncopyable + { + + /*! + REQUIREMENTS ON T + T must be swappable by a global swap() and + T must have a default constructor + + REQUIREMENTS ON E + E must be swappable by a global swap() and + E must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + The only time pointers or references to nodes or edges become invalid is when + they reference nodes or edges that have been removed from a graph. + + INITIAL VALUE + number_of_nodes() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a directed graph which is a set of nodes with directed + edges connecting various nodes. + + In this object if there is a directed edge from a node A to a node B then I say + that A is the parent of B and B is the child of A. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + !*/ + + public: + + typedef T type; + typedef E edge_type; + typedef mem_manager mem_manager_type; + + template + struct rebind { + typedef directed_graph other; + }; + + directed_graph( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor. + !*/ + + virtual ~directed_graph( + ); + /*! + ensures + - all resources associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void set_number_of_nodes ( + unsigned long new_size + ); + /*! + ensures + - #number_of_nodes() == new_size + - for all i < new_size: + - number_of_parents(i) == 0 + - number_of_children(i) == 0 + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in this graph + !*/ + + struct node_type + { + T data; + typedef directed_graph graph_type; + + unsigned long index( + ) const; + /*! + ensures + - let G be the graph that contains the node *this + - returns a number N such that G.node(N) == *this + (i.e. returns the index of this node in the graph) + !*/ + + unsigned long number_of_parents ( + ) const; + /*! + ensures + - returns the number of parents of this node + !*/ + + unsigned long number_of_children ( + ) const; + /*! + ensures + - returns the number of children of this node + !*/ + + const node_type& parent ( + unsigned long edge_index + ) const; + /*! + requires + - edge_index < number_of_parents() + ensures + - returns a const reference to the edge_index'th parent of *this + !*/ + + node_type& parent ( + unsigned long edge_index + ); + /*! + requires + - edge_index < number_of_parents() + ensures + - returns a non-const reference to the edge_index'th parent of *this + !*/ + + const node_type& child ( + unsigned long edge_index + ) const; + /*! + requires + - edge_index < number_of_children() + ensures + - returns a const reference to the edge_index'th child of *this + !*/ + + node_type& child ( + unsigned long edge_index + ); + /*! + requires + - edge_index < number_of_children() + ensures + - returns a non-const reference to the edge_index'th child of *this + !*/ + + const E& parent_edge ( + unsigned long edge_index + ) const; + /*! + requires + - edge_index < number_of_parents() + ensures + - returns a const reference to the edge_index'th edge data for the + edge connecting to node this->parent(edge_index) + !*/ + + E& parent_edge ( + unsigned long edge_index + ); + /*! + requires + - edge_index < number_of_parents() + ensures + - returns a non-const reference to the edge_index'th edge data for the + edge connecting to node this->parent(edge_index) + !*/ + + const E& child_edge ( + unsigned long edge_index + ) const; + /*! + requires + - edge_index < number_of_children() + ensures + - returns a const reference to the edge_index'th edge data for the + edge connecting to node this->child(edge_index) + !*/ + + E& child_edge ( + unsigned long edge_index + ); + /*! + requires + - edge_index < number_of_children() + ensures + - returns a non-const reference to the edge_index'th edge data for the + edge connecting to node this->child(edge_index) + !*/ + }; + + node_type& node ( + unsigned long index + ); + /*! + requires + - index < number_of_nodes() + ensures + - returns a non-const reference to the node with the given index + !*/ + + const node_type& node ( + unsigned long index + ) const; + /*! + requires + - index < number_of_nodes() + ensures + - returns a const reference to the node with the given index + !*/ + + bool has_edge ( + unsigned long parent_node_index, + unsigned long child_node_index + ) const; + /*! + requires + - parent_node_index < number_of_nodes() + - child_node_index < number_of_nodes() + ensures + - if (there is an edge leading from node(parent_node_index) to + node(child_node_index)) then + - returns true + - else + - returns false + !*/ + + void add_edge ( + unsigned long parent_node_index, + unsigned long child_node_index + ); + /*! + requires + - parent_node_index < number_of_nodes() + - child_node_index < number_of_nodes() + - has_edge(parent_node_index, child_node_index) == false + ensures + - #has_edge(parent_node_index, child_node_index) == true + throws + - std::bad_alloc + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + void remove_edge ( + unsigned long parent_node_index, + unsigned long child_node_index + ); + /*! + requires + - parent_node_index < number_of_nodes() + - child_node_index < number_of_nodes() + - has_edge(parent_node_index, child_node_index) == true + ensures + - #has_edge(parent_node_index, child_node_index) == false + throws + - std::bad_alloc + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + unsigned long add_node ( + ); + /*! + ensures + - does not change the index number of existing nodes + - adds a node with index N == number_of_nodes() such that: + - #node(N).number_of_parents() == 0 + - #node(N).number_of_children() == 0 + - #number_of_nodes() == number_of_nodes() + 1 + - returns N + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + void remove_node ( + unsigned long index + ); + /*! + requires + - index < number_of_nodes() + ensures + - removes the node with the given index from the graph. + - removes all edges linking the removed node to the rest + of the graph. + - the remaining node indexes are remapped so that they remain + contiguous. (This means that for all valid N, node(N) doesn't + necessarily reference the same node as #node(N)) + - #number_of_nodes() == number_of_nodes() - 1 + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + void swap ( + directed_graph& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + directed_graph& a, + directed_graph& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename mem_manager + > + void serialize ( + const directed_graph& item, + std::ostream& out + ); + /*! + provides deserialization support + !*/ + + template < + typename T, + typename mem_manager + > + void deserialize ( + directed_graph& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_DIRECTED_GRAPH_KERNEl_ABSTRACT_ + + diff --git a/ml/dlib/dlib/disjoint_subsets.h b/ml/dlib/dlib/disjoint_subsets.h new file mode 100644 index 000000000..d33ef63f4 --- /dev/null +++ b/ml/dlib/dlib/disjoint_subsets.h @@ -0,0 +1,12 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DISJOINt_SUBSETS_ +#define DLIB_DISJOINt_SUBSETS_ + + +#include "disjoint_subsets/disjoint_subsets.h" +#include "disjoint_subsets/disjoint_subsets_sized.h" + +#endif // DLIB_DISJOINt_SUBSETS_ + + diff --git a/ml/dlib/dlib/disjoint_subsets/disjoint_subsets.h b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets.h new file mode 100644 index 000000000..7fab9eba3 --- /dev/null +++ b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets.h @@ -0,0 +1,141 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DISJOINT_SUBsETS_Hh_ +#define DLIB_DISJOINT_SUBsETS_Hh_ + +#include "disjoint_subsets_abstract.h" +#include +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class disjoint_subsets + { + public: + + void clear ( + ) noexcept + { + items.clear(); + } + + void set_size ( + unsigned long new_size + ) + { + items.resize(new_size); + for (unsigned long i = 0; i < items.size(); ++i) + { + items[i].parent = i; + items[i].rank = 0; + } + } + + size_t size ( + ) const noexcept + { + return items.size(); + } + + unsigned long find_set ( + unsigned long item + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(item < size(), + "\t unsigned long disjoint_subsets::find_set()" + << "\n\t item must be less than size()" + << "\n\t item: " << item + << "\n\t size(): " << size() + << "\n\t this: " << this + ); + + if (items[item].parent == item) + { + return item; + } + else + { + // find root of item + unsigned long x = item; + do + { + x = items[x].parent; + } while (items[x].parent != x); + + // do path compression + const unsigned long root = x; + x = item; + while (items[x].parent != x) + { + const unsigned long prev = x; + x = items[x].parent; + items[prev].parent = root; + } + + return root; + } + } + + unsigned long merge_sets ( + unsigned long a, + unsigned long b + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(a != b && + a < size() && + b < size() && + find_set(a) == a && + find_set(b) == b, + "\t unsigned long disjoint_subsets::merge_sets(a,b)" + << "\n\t invalid arguments were given to this function" + << "\n\t a: " << a + << "\n\t b: " << b + << "\n\t size(): " << size() + << "\n\t find_set(a): " << find_set(a) + << "\n\t find_set(b): " << find_set(b) + << "\n\t this: " << this + ); + + if (items[a].rank > items[b].rank) + { + items[b].parent = a; + return a; + } + else + { + items[a].parent = b; + if (items[a].rank == items[b].rank) + { + items[b].rank = items[b].rank + 1; + } + return b; + } + } + + private: + + /* + See the book Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein + for a discussion of how this algorithm works. + */ + + struct data + { + unsigned long rank; + unsigned long parent; + }; + + mutable std::vector items; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DISJOINT_SUBsETS_Hh_ diff --git a/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_abstract.h b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_abstract.h new file mode 100644 index 000000000..bd67d0d5c --- /dev/null +++ b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_abstract.h @@ -0,0 +1,96 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_ +#ifdef DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_ + +#include +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class disjoint_subsets + { + /*! + INITIAL VALUE + - size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a set of integers which is partitioned into + a number of disjoint subsets. It supports the two fundamental operations + of finding which subset a particular integer belongs to as well as + merging subsets. + !*/ + public: + + void clear ( + ) noexcept; + /*! + ensures + - #size() == 0 + - returns this object to its initial value + !*/ + + void set_size ( + unsigned long new_size + ); + /*! + ensures + - #size() == new_size + - for all valid i: + - #find_set(i) == i + (i.e. this object contains new_size subsets, each containing exactly one element) + !*/ + + size_t size ( + ) const noexcept; + /*! + ensures + - returns the total number of integer elements represented + by this object. + !*/ + + unsigned long find_set ( + unsigned long item + ) const; + /*! + requires + - item < size() + ensures + - Each disjoint subset can be represented by any of its elements (since + the sets are all disjoint). In particular, for each subset we define + a special "representative element" which is used to represent it. + Therefore, this function returns the representative element for the + set which contains item. + - find_set(find_set(item)) == find_set(item) + - Note that if A and B are both elements of the same subset then we always + have find_set(A) == find_set(B). + !*/ + + unsigned long merge_sets ( + unsigned long a, + unsigned long b + ); + /*! + requires + - a != b + - a < size() + - b < size() + - find_set(a) == a + (i.e. a is the representative element of some set) + - find_set(b) == b + (i.e. b is the representative element of some set) + ensures + - #find_set(a) == #find_set(b) + (i.e. merges the set's containing a and b) + - returns #find_set(a) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_ diff --git a/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized.h b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized.h new file mode 100644 index 000000000..9aa657f43 --- /dev/null +++ b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized.h @@ -0,0 +1,130 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DISJOINT_SUBsETS_SIZED_Hh_ +#define DLIB_DISJOINT_SUBsETS_SIZED_Hh_ + +#include "disjoint_subsets_sized_abstract.h" +#include "disjoint_subsets.h" +#include +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class disjoint_subsets_sized + { + public: + + void clear ( + ) noexcept + { + disjoint_subsets_.clear(); + sets_size.clear(); + number_of_sets = 0; + } + + void set_size ( + unsigned long new_size + ) + { + disjoint_subsets_.set_size(new_size); + sets_size.assign(new_size, 1); + number_of_sets = new_size; + } + + size_t size ( + ) const noexcept + { + return disjoint_subsets_.size(); + } + + unsigned long find_set ( + unsigned long item + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(item < size(), + "\t unsigned long disjoint_subsets::find_set()" + << "\n\t item must be less than size()" + << "\n\t item: " << item + << "\n\t size(): " << size() + << "\n\t this: " << this + ); + + return disjoint_subsets_.find_set(item); + } + + unsigned long merge_sets ( + unsigned long a, + unsigned long b + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(a != b && + a < size() && + b < size() && + find_set(a) == a && + find_set(b) == b, + "\t unsigned long disjoint_subsets::merge_sets(a,b)" + << "\n\t invalid arguments were given to this function" + << "\n\t a: " << a + << "\n\t b: " << b + << "\n\t size(): " << size() + << "\n\t find_set(a): " << find_set(a) + << "\n\t find_set(b): " << find_set(b) + << "\n\t this: " << this + ); + + disjoint_subsets_.merge_sets(a, b); + + if (find_set(a) == a) sets_size[a] += sets_size[b]; + else sets_size[b] += sets_size[a]; + --number_of_sets; + + return find_set(a); + } + + unsigned long get_number_of_sets ( + ) const noexcept + { + return number_of_sets; + } + + unsigned long get_size_of_set( + unsigned long item + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(item < size() && + find_set(item) == item, + "\t unsigned long disjoint_subsets::get_size_of_set()" + << "\n\t invalid arguments were given to this function" + << "\n\t item: " << item + << "\n\t size(): " << size() + << "\n\t find_set(item): " << find_set(item) + << "\n\t this: " << this + ); + + return sets_size[item]; + } + + private: + + /* + See the book Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein + for a discussion of how this algorithm works. + */ + + mutable std::vector sets_size; + unsigned long number_of_sets{0}; + disjoint_subsets disjoint_subsets_; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DISJOINT_SUBsETS_SIZED_Hh_ diff --git a/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h new file mode 100644 index 000000000..ecc5ef005 --- /dev/null +++ b/ml/dlib/dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h @@ -0,0 +1,123 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DISJOINT_SUBsETS_SIZED_ABSTRACT_Hh_ +#ifdef DLIB_DISJOINT_SUBsETS_SIZED_ABSTRACT_Hh_ + +#include +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class disjoint_subsets_sized + { + /*! + INITIAL VALUE + - size() == 0 + - get_number_of_sets() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a set of integers which is partitioned into + a number of disjoint subsets. It supports the two fundamental operations + of finding which subset a particular integer belongs to as well as + merging subsets. It also allows you to find out how big each subset is. It + is therefore essentially the same thing as dlib::disjoint_subsets, except + it also keeps track of of the size of each subset. + !*/ + public: + + void clear ( + ) noexcept; + /*! + ensures + - #size() == 0 + - #get_number_of_sets() == 0 + - returns this object to its initial value + !*/ + + void set_size ( + unsigned long new_size + ); + /*! + ensures + - #size() == new_size + - #get_number_of_sets() == new_size + - for all valid i: + - #find_set(i) == i + (i.e. this object contains new_size subsets, each containing exactly one element) + - #get_size_of_set(i) == 1 + !*/ + + size_t size ( + ) const noexcept; + /*! + ensures + - returns the total number of integer elements represented + by this object. + !*/ + + unsigned long find_set ( + unsigned long item + ) const; + /*! + requires + - item < size() + ensures + - Each disjoint subset can be represented by any of its elements (since + the sets are all disjoint). In particular, for each subset we define + a special "representative element" which is used to represent it. + Therefore, this function returns the representative element for the + set which contains item. + - find_set(find_set(item)) == find_set(item) + - Note that if A and B are both elements of the same subset then we always + have find_set(A) == find_set(B). + !*/ + + unsigned long merge_sets ( + unsigned long a, + unsigned long b + ); + /*! + requires + - a != b + - a < size() + - b < size() + - find_set(a) == a + (i.e. a is the representative element of some set) + - find_set(b) == b + (i.e. b is the representative element of some set) + ensures + - #find_set(a) == #find_set(b) + (i.e. merges the set's containing a and b) + - #get_size_of_set(#find_set(a)) == get_size_of_set(a) + get_size_of_set(b) + - #get_number_of_sets() == get_number_of_sets() - 1 + - returns #find_set(a) + !*/ + + unsigned long get_number_of_sets ( + ) const noexcept; + /*! + ensures + - returns the current number of different subsets. + !*/ + + unsigned long get_size_of_set( + unsigned long item + ) const; + /*! + requires + - item < size() + - find_set(item) == item + (i.e. item is the representative element of some set) + ensures + - returns the number of elements which belongs to the set where item is the representative element. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_ diff --git a/ml/dlib/dlib/dlib_basic_cpp_build_tutorial.txt b/ml/dlib/dlib/dlib_basic_cpp_build_tutorial.txt new file mode 100644 index 000000000..fa1a970ab --- /dev/null +++ b/ml/dlib/dlib/dlib_basic_cpp_build_tutorial.txt @@ -0,0 +1,13 @@ +#error "Don't write #include in your code." +/* + In C++, it is generally an error to #include .cpp files. This is because it + can lead to what are called multiply defined symbol errors. Therefore, you + should compile dlib/all/source.cpp into your application just like you would + compile any other .cpp file. + + If you are using Visual Studio you add .cpp files to your application using + the solution explorer window. Specifically, right click on Source Files, + then select Add -> Existing Item and select the .cpp files you want to add. + + For general information on compiling dlib see http://dlib.net/compile.html +*/ diff --git a/ml/dlib/dlib/dlib_include_path_tutorial.txt b/ml/dlib/dlib/dlib_include_path_tutorial.txt new file mode 100644 index 000000000..f279ce103 --- /dev/null +++ b/ml/dlib/dlib/dlib_include_path_tutorial.txt @@ -0,0 +1,20 @@ +#error "Don't put the dlib folder in your include path" +/* + You are getting this error because you have added the dlib folder to your + compiler's include search path. + + You should *NOT* add the dlib folder itself to your compiler's include path. + Doing so will cause the build to fail because of name collisions (such as + dlib/string.h and string.h from the standard library). Instead you should + add the folder that contains the dlib folder to your include search path + and then use include statements of the form #include or + #include "dlib/queue.h". This will ensure that everything builds correctly. + + XCode: + The XCode IDE often puts all folders that it knows about into + the compiler search path. So if you are using XCode then either + don't drag the whole dlib folder into the project or alternatively + modify your XCode project settings to not auto-add all folders to + the include path. Instead just make sure that the dlib folder is + itself inside a folder in your include path. +*/ diff --git a/ml/dlib/dlib/dnn.h b/ml/dlib/dlib/dnn.h new file mode 100644 index 000000000..db59948fb --- /dev/null +++ b/ml/dlib/dlib/dnn.h @@ -0,0 +1,37 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_ +#define DLIB_DNn_ + +// DNN module uses template-based network declaration that leads to very long +// type names. Visual Studio will produce Warning C4503 in such cases +#ifdef _MSC_VER +# pragma warning( disable: 4503 ) +#endif + +#include "dnn/tensor.h" +#include "dnn/input.h" + +// Problem: Visual Studio's vcpkgsrv.exe constantly uses a single CPU core, +// apparently never finishing whatever it's trying to do. Moreover, +// this issue prevents some operations like switching from Debug to +// Release (and vice versa) in the IDE. (Your mileage may vary.) +// Workaround: Keep manually killing the vcpkgsrv.exe process. +// Solution: Disable IntelliSense for some files. Which files? Unfortunately +// this seems to be a trial-and-error process. +#ifndef __INTELLISENSE__ +#include "dnn/layers.h" +#endif // __INTELLISENSE__ + +#include "dnn/loss.h" +#include "dnn/core.h" +#include "dnn/solvers.h" +#include "dnn/trainer.h" +#include "dnn/cpu_dlib.h" +#include "dnn/tensor_tools.h" +#include "dnn/utilities.h" +#include "dnn/validation.h" + +#endif // DLIB_DNn_ + + diff --git a/ml/dlib/dlib/dnn/core.h b/ml/dlib/dlib/dnn/core.h new file mode 100644 index 000000000..5f1d05498 --- /dev/null +++ b/ml/dlib/dlib/dnn/core.h @@ -0,0 +1,3599 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_CORE_H_ +#define DLIB_DNn_CORE_H_ + +#include "core_abstract.h" +#include "tensor.h" +#include +#include +#include +#include +#include "../statistics.h" +#include "../rand.h" +#include "../algs.h" +#include +#include +#include +#include +#include "tensor_tools.h" +#include +#include "../metaprogramming.h" + +#ifdef _MSC_VER +// Tell Visual Studio not to recursively inline functions very much because otherwise it +// takes hours to compile the DNN code sometimes. It's crazy. Hopefully we can remove +// this some day when the visual studio compiler is more efficient. +#pragma inline_depth(2) +#endif + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template ::type = 0> + double get_learning_rate_multiplier ( + const T& obj, + special_ + ) { return obj.get_learning_rate_multiplier(); } + + template + double get_learning_rate_multiplier ( const T& , general_) { return 1; } + } + template + double get_learning_rate_multiplier(const T& obj) { return impl::get_learning_rate_multiplier(obj, special_()); } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template ::type = 0> + double get_weight_decay_multiplier ( + const T& obj, + special_ + ) { return obj.get_weight_decay_multiplier(); } + + template + double get_weight_decay_multiplier ( const T& , general_) { return 1; } + } + template + double get_weight_decay_multiplier(const T& obj) { return impl::get_weight_decay_multiplier(obj, special_()); } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + // The reason we return an int for this version rather than doing the more straight forward thing (like we do above) is to avoid a bug in visual studio 2015. + template + auto call_clean_method_if_exists ( + T& obj, + special_ + ) -> typename int_::type { obj.clean(); return 0; } + + template + void call_clean_method_if_exists (T& , general_) {} + } + template + void call_clean_method_if_exists(T& obj) { impl::call_clean_method_if_exists(obj, special_()); } + /*! + ensures + - calls obj.clean() if obj has a .clean() method. + !*/ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + class repeat_input_layer + { + /*! + None of the declarations in this object are really used. The only reason it + exists is to allow the repeat object to use a special input layer in its + internal networks which will cause add_tag_layer objects that happen to be + right at the input to not create copies of their input tensors. So + introducing the repeat_input_layer object allows us to optimize the + implementation of add_tag_layer for a special case that arises when it's + used in the context of the repeat layer. + !*/ + public: + typedef int input_type; + + template + void to_tensor ( + forward_iterator , + forward_iterator , + resizable_tensor& + ) const + { + } + + friend void serialize(const repeat_input_layer&, std::ostream&){} + friend void deserialize(repeat_input_layer&, std::istream&){} + friend std::ostream& operator<<(std::ostream& out, const repeat_input_layer&) { return out; } + }; + + inline std::string tensor_to_str ( + const tensor& t, + int& min_length + ) + { + if (t.size() == 0) + return ""; + + std::ostringstream sout; + sout << "output size=(num:"<< t.num_samples() << ", "; + sout << "k:" << t.k() << ","; + while (sout.tellp() < 28) sout << " "; + sout << "nr:" << t.nr() << ","; + while (sout.tellp() < 28+8) sout << " "; + sout << "nc:" << t.nc() << ")"; + while (sout.tellp() < min_length) sout << " "; + min_length = sout.tellp(); + sout << "\t"; + return sout.str(); + } + } + +// ---------------------------------------------------------------------------------------- + + // Tell us if T is one of the special layer types (i.e. add_layer, repeat, add_tag_layer, or + // add_skip_layer). + template struct is_nonloss_layer_type : std::false_type {}; + // Tell us if T is an instance of add_loss_layer. + template struct is_loss_layer_type : std::false_type {}; + // Tell us if T is an instance of add_layer + template struct is_add_layer : std::false_type {}; + + namespace impl + { + template + auto tuple_subset( + const Tuple& item, + compile_time_integer_list + ) -> decltype(std::make_tuple(std::get(item)...)) + { + return std::make_tuple(std::get(item)...); + } + + template + std::tuple basic_tuple_tail( + const std::tuple& item + ) + { + return tuple_subset(item, typename make_compile_time_integer_range::type()); + } + + template + std::tuple tuple_flatten(const T& t) + { + return std::make_tuple(t); + } + + template + auto tuple_flatten( + const std::tuple& item + ) -> decltype(tuple_flatten(item, typename make_compile_time_integer_range::type())) + { + return tuple_flatten(item, typename make_compile_time_integer_range::type()); + } + + template + auto tuple_flatten( + const std::tuple& item, + compile_time_integer_list + ) -> decltype(std::tuple_cat(tuple_flatten(std::get(item))...)) + { + return std::tuple_cat(tuple_flatten(std::get(item))...); + } + + template + struct tuple_head_helper + { + typedef T type; + static const type& get(const T& item) + { + return item; + } + }; + + template + struct tuple_head_helper> + { + typedef typename tuple_head_helper::type type; + static const type& get(const std::tuple& item) + { + return tuple_head_helper::get(std::get<0>(item)); + } + }; + + template struct alwaysbool { typedef bool type; }; + // one more structure for VS 2015 UP3 support workaround + template struct alwaysbool2 { typedef bool type; }; + + resizable_tensor& rt(); + + // The significance of a layer's backward method requiring forward's outputs is + // that such as layer can't have an in-place layer stacked on top of it because + // in-place layers overwrite the output of the layer they sit on top of. + template + constexpr auto backward_requires_forward_output( + layer_type& layer, + SUBNET& sub + ) -> typename alwaysbool::type + { + return true; + } + + template + constexpr auto backward_requires_forward_output( + layer_type& layer, + SUBNET& sub + ) -> typename alwaysbool::type + { + return false; + } + + template + constexpr auto backward_requires_forward_output( + layer_type& layer, + SUBNET& sub + ) -> typename alwaysbool::type + { + return true; + } + + template + constexpr auto backward_requires_forward_output( + layer_type& layer, + SUBNET& sub + ) -> typename alwaysbool::type + { + return false; + } + + template + constexpr auto has_inplace_backward( + layer_type& layer, + SUBNET& sub + ) -> typename alwaysbool2::type + { + return false; + } + + template + constexpr auto has_inplace_backward( + layer_type& layer, + SUBNET& sub + ) -> typename alwaysbool2::type + { + return false; + } + + template + constexpr auto has_inplace_backward( + layer_type& layer, + SUBNET& sub + ) -> typename alwaysbool2::type + { + return true; + } + + template + constexpr auto has_inplace_backward( + layer_type& layer, + SUBNET& sub + ) -> typename alwaysbool2::type + { + return true; + } + + template + constexpr auto is_inplace_layer( + layer_type& layer, + const SUBNET& sub + ) -> typename alwaysbool2::type + { + return false; + } + + template + constexpr auto is_inplace_layer( + layer_type& layer, + const SUBNET& sub + ) -> typename alwaysbool::type + { + return true; + } + + template + auto call_layer_backward( + layer_type& layer, + const tensor& computed_output, + const tensor& gradient_input, + SUBNET& sub, + tensor& params_grad + ) -> decltype(layer.backward(computed_output,gradient_input,sub,params_grad)) + { + layer.backward(computed_output,gradient_input,sub,params_grad); + } + + template + auto call_layer_backward( + layer_type& layer, + const tensor& , + const tensor& gradient_input, + SUBNET& sub, + tensor& params_grad + ) -> decltype(layer.backward(gradient_input,sub,params_grad)) + { + layer.backward(gradient_input,sub,params_grad); + } + + template + auto call_layer_backward( + layer_type& layer, + const tensor& computed_output, + const tensor& gradient_input, + SUBNET& sub, + tensor& params_grad + ) -> decltype(layer.backward_inplace(computed_output,gradient_input,sub.get_gradient_input(),params_grad)) + { + layer.backward_inplace(computed_output,gradient_input,sub.get_gradient_input(),params_grad); + } + + template + auto call_layer_backward( + layer_type& layer, + const tensor& , + const tensor& gradient_input, + SUBNET& sub, + tensor& params_grad + ) -> decltype(layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad)) + { + layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad); + } + + + template + auto call_layer_forward( + layer_type& layer, + const SUBNET& sub, + tensor& /*data_output*/ + ) -> decltype(layer.forward(sub,rt())) + { + // This overload of call_layer_forward() is here because this template + // naturally gets instantiated but only on code paths that never get executed. + // So rather than writing a bunch of hard to read template magic around call + // sites we just have this overload that doesn't do anything (and an assert to + // make sure that's the case). + DLIB_CASSERT(false, "This should never happen"); + } + + template + auto call_layer_forward( + layer_type& layer, + const SUBNET& sub, + resizable_tensor& data_output + ) -> decltype(layer.forward(sub,data_output)) + { + layer.forward(sub,data_output); + } + + template + auto call_layer_forward( + layer_type& layer, + const SUBNET& sub, + tensor& data_output + ) -> decltype(layer.forward_inplace(sub.get_output(),data_output)) + { + layer.forward_inplace(sub.get_output(),data_output); + } + + template + auto call_layer_forward( + layer_type& layer, + const SUBNET& sub, + resizable_tensor& data_output + ) -> decltype(layer.forward_inplace(sub.get_output(),data_output)) + { + if (!have_same_dimensions(data_output, sub.get_output())) + data_output.copy_size(sub.get_output()); + layer.forward_inplace(sub.get_output(),static_cast(data_output)); + } + + + } // end namespace impl + + template + typename impl::tuple_head_helper>::type tuple_head ( + const std::tuple& item + ) + { + return impl::tuple_head_helper>::get(item); + } + + template + auto tuple_tail( + const std::tuple& item + ) -> decltype(impl::basic_tuple_tail(impl::tuple_flatten(item))) + { + return impl::basic_tuple_tail(impl::tuple_flatten(item)); + } + + inline std::tuple<> tuple_tail( + const std::tuple<>& item + ) + { + return item; + } +// ---------------------------------------------------------------------------------------- + + template + class sstack + { + public: + typedef T value_type; + + sstack() = delete; + + sstack ( + T* data_, + size_t s + ) : data(data_), mysize(s) {} + + const T& top() const + { + DLIB_CASSERT(size() != 0, "You can't call top() on an empty stack"); + return *data; + } + T& top() + { + DLIB_CASSERT(size() != 0, "You can't call top() on an empty stack"); + return *data; + } + + size_t size() const { return mysize; } + + sstack pop(size_t num=1) + { + DLIB_CASSERT(num <= size(), "You can't pop more things from the stack than it has in it."); + return sstack(data+num, mysize-num); + } + + private: + + T* data; + size_t mysize; + }; + + template + sstack make_sstack(std::vector& item) + { + return sstack(item.data(), item.size()); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace dimpl + { + template + class subnet_wrapper + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool that makes an add_layer or add_loss_layer object + expose only the part of its interface defined by the SUBNET + type in layers_abstract.h. This way, when we pass subnetwork + objects to the layer callbacks those callbacks won't be able to + interact with the subnetworks in a way other than specified + by the SUBNET interface spec. + + We also allow the top layer of a subnet_wrapper stack to call the + private_get_output() and private_get_gradient_input() functions. This + way, layers that have had their output/gradient overwritten by in-place + layers can only be accessed from the in-place layers that sit directly + on top of them since those in-place layers are the only layers that + know how to interact with them properly. + !*/ + + public: + subnet_wrapper(const subnet_wrapper&) = delete; + subnet_wrapper& operator=(const subnet_wrapper&) = delete; + + subnet_wrapper(T& l_, unsigned int sef) : l(l_),_sample_expansion_factor(sef) {} + // Not much here because in this case T is one of the input layer types + // that doesn't have anything in it. + typedef T layer_details_type; + const layer_details_type& layer_details() const { return l; } + unsigned int sample_expansion_factor() const { return _sample_expansion_factor; } + private: + T& l; + unsigned int _sample_expansion_factor; + }; + + template + class subnet_wrapper::value>::type> + { + + public: + subnet_wrapper(const subnet_wrapper&) = delete; + subnet_wrapper& operator=(const subnet_wrapper&) = delete; + + typedef T wrapped_type; + const static size_t num_computational_layers = T::num_computational_layers; + const static size_t num_layers = T::num_layers; + typedef typename T::layer_details_type layer_details_type; + + subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {} + + const tensor& get_output() const { return l.private_get_output(); } + tensor& get_gradient_input() { return l.private_get_gradient_input(); } + + const layer_details_type& layer_details() const { return l.layer_details(); } + + const subnet_wrapper& subnet() const { return subnetwork; } + subnet_wrapper& subnet() { return subnetwork; } + unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); } + + private: + T& l; + subnet_wrapper subnetwork; + }; + + template + class subnet_wrapper::value>::type> + { + + public: + subnet_wrapper(const subnet_wrapper&) = delete; + subnet_wrapper& operator=(const subnet_wrapper&) = delete; + + typedef T wrapped_type; + const static size_t num_computational_layers = T::num_computational_layers; + const static size_t num_layers = T::num_layers; + typedef typename T::layer_details_type layer_details_type; + + subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {} + + const tensor& get_output() const { return l.get_output(); } + tensor& get_gradient_input() { return l.get_gradient_input(); } + + const layer_details_type& layer_details() const { return l.layer_details(); } + + const subnet_wrapper& subnet() const { return subnetwork; } + subnet_wrapper& subnet() { return subnetwork; } + unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); } + + private: + T& l; + subnet_wrapper subnetwork; + }; + } + +// ---------------------------------------------------------------------------------------- + + template + class add_layer; + + template + void serialize(const add_layer& item, std::ostream& out); + template + void deserialize(add_layer& item, std::istream& in); + + template + struct is_nonloss_layer_type> : std::true_type {}; + + template + class add_layer::value>::type> + { + public: + typedef LAYER_DETAILS layer_details_type; + typedef SUBNET subnet_type; + typedef typename subnet_type::input_type input_type; + const static size_t num_layers = subnet_type::num_layers + 1; + const static size_t num_computational_layers = subnet_type::num_computational_layers + 1; + + add_layer( + ): + subnetwork(new subnet_type()), + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false) + { + if (this_layer_operates_inplace()) + subnetwork->disable_output_and_gradient_getters(); + } + + add_layer(const add_layer& item) + { + details = item.details; + subnetwork.reset(new subnet_type(*item.subnetwork)); + this_layer_setup_called = item.this_layer_setup_called; + gradient_input_is_stale = item.gradient_input_is_stale; + get_output_and_gradient_input_disabled = item.get_output_and_gradient_input_disabled; + x_grad = item.x_grad; + cached_output = item.cached_output; + params_grad = item.params_grad; + temp_tensor = item.temp_tensor; + } + add_layer& operator=(const add_layer& item) { add_layer(item).swap(*this); return *this;} + add_layer(add_layer&& item) : add_layer() { swap(item); } + add_layer& operator=(add_layer&& item) { swap(item); return *this; } + + template + friend class add_layer; + template + friend class dimpl::subnet_wrapper; + template + friend class add_tag_layer; + template class T, typename U> + friend class add_skip_layer; + template class L, typename S> + friend class repeat; + + // Allow copying networks from one to another as long as their corresponding + // layers can be constructed from each other. + template + add_layer( + const add_layer& item + ) : + details(item.layer_details()), + subnetwork(new subnet_type(item.subnet())), + this_layer_setup_called(item.this_layer_setup_called), + gradient_input_is_stale(item.gradient_input_is_stale), + get_output_and_gradient_input_disabled(item.get_output_and_gradient_input_disabled), + x_grad(item.x_grad), + cached_output(item.cached_output) + { + if (this_layer_operates_inplace()) + subnetwork->disable_output_and_gradient_getters(); + } + + template + add_layer( + const LAYER_DETAILS& layer_det, + T&& ...args + ) : + details(layer_det), + subnetwork(new subnet_type(std::forward(args)...)), + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false) + { + if (this_layer_operates_inplace()) + subnetwork->disable_output_and_gradient_getters(); + } + + template + struct disable_forwarding_constr + { + const static bool value = std::is_constructible::value; + }; + template + struct disable_forwarding_constr,U...> + { + const static bool value = disable_forwarding_constr::type...>::value; + }; + template + struct disable_forwarding_constr,U...> + { + const static bool value = disable_forwarding_constr::type>::value; + }; + template + struct disable_forwarding_constr,U...> + { + const static bool value = true; + }; + template + struct disable_forwarding_constr> + { + const static bool value = true; + }; + + template < + typename ...T, + typename = typename std::enable_if::type...>::value>::type + > + add_layer( + T&& ...args + ) : + subnetwork(new subnet_type(std::forward(args)...)), + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false) + { + if (this_layer_operates_inplace()) + subnetwork->disable_output_and_gradient_getters(); + } + + template + add_layer( + LAYER_DETAILS&& layer_det, + T&& ...args + ) : + details(std::move(layer_det)), + subnetwork(new subnet_type(std::forward(args)...)), + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false) + { + if (this_layer_operates_inplace()) + subnetwork->disable_output_and_gradient_getters(); + } + + template + add_layer( + const std::tuple& layer_det, + T&& ...args + ) : + details(tuple_head(layer_det)), + subnetwork(new subnet_type(tuple_tail(layer_det),std::forward(args)...)), + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false) + { + if (this_layer_operates_inplace()) + subnetwork->disable_output_and_gradient_getters(); + } + + template + add_layer( + std::tuple<>, + const std::tuple& layer_det, + T&& ...args + ) : add_layer(layer_det,args...) { } + + add_layer ( + std::tuple<> + ) : add_layer() {} + + template + add_layer( + std::tuple<>, + LAYER_DETAILS&& layer_det, + T&& ...args + ) : add_layer(layer_det, args...) { } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + subnetwork->to_tensor(ibegin,iend,data); + } + + template + const tensor& operator() ( + forward_iterator ibegin, + forward_iterator iend + ) + { + to_tensor(ibegin,iend,temp_tensor); + return forward(temp_tensor); + } + + + const tensor& operator() (const input_type& x) + { + return (*this)(&x, &x+1); + } + + const tensor& forward(const tensor& x) + { + subnetwork->forward(x); + const dimpl::subnet_wrapper wsub(*subnetwork); + if (!this_layer_setup_called) + { + details.setup(wsub); + this_layer_setup_called = true; + } + if (this_layer_operates_inplace()) + impl::call_layer_forward(details, wsub, private_get_output()); + else + impl::call_layer_forward(details, wsub, cached_output); + + gradient_input_is_stale = true; + return private_get_output(); + } + + private: + tensor& private_get_output() const + { + if (const_cast(*this).this_layer_operates_inplace()) + return subnetwork->private_get_output(); + else + return const_cast(cached_output); + } + tensor& private_get_gradient_input() + { + if (this_layer_operates_inplace()) + { + return subnetwork->private_get_gradient_input(); + } + else + { + if (gradient_input_is_stale) + { + gradient_input_is_stale = false; + x_grad.copy_size(private_get_output()); + x_grad = 0; + } + return x_grad; + } + } + void disable_output_and_gradient_getters ( + ) { get_output_and_gradient_input_disabled = true; } + public: + const tensor& get_output() const + { + if (get_output_and_gradient_input_disabled) + throw dlib::error("Accessing this layer's get_output() is disabled because an in-place layer has been stacked on top of it."); + return private_get_output(); + } + tensor& get_gradient_input() + { + if (get_output_and_gradient_input_disabled) + throw dlib::error("Accessing this layer's get_gradient_input() is disabled because an in-place layer has been stacked on top of it."); + return private_get_gradient_input(); + } + + const tensor& get_final_data_gradient( + ) const { return subnetwork->get_final_data_gradient(); } + + void back_propagate_error(const tensor& x) + { + back_propagate_error(x, private_get_gradient_input()); + } + void back_propagate_error(const tensor& x, const tensor& gradient_input) + { + dimpl::subnet_wrapper wsub(*subnetwork); + params_grad.copy_size(details.get_layer_params()); + impl::call_layer_backward(details, private_get_output(), + gradient_input, wsub, static_cast(params_grad)); + + subnetwork->back_propagate_error(x); + + // zero out get_gradient_input() + gradient_input_is_stale = true; + } + + template + void update_parameters(sstack solvers, double learning_rate) + { + DLIB_CASSERT(solvers.size()>=num_computational_layers); + // Don't try to adjust the parameters if this layer doesn't have any or the + // learning rate is disabled for this layer. + if (params_grad.size() != 0 && get_learning_rate_multiplier(details) != 0) + { + const tensor& step = solvers.top()(learning_rate, details, static_cast(params_grad)); + tt::add(details.get_layer_params(), details.get_layer_params(), step); + } + subnetwork->update_parameters(solvers.pop(), learning_rate); + } + + const tensor& get_parameter_gradient( + ) const { return params_grad; } + + tensor& get_parameter_gradient ( + ) { return params_grad; } + + const subnet_type& subnet() const { return *subnetwork; } + subnet_type& subnet() { return *subnetwork; } + + const layer_details_type& layer_details() const { return details; } + layer_details_type& layer_details() { return details; } + + unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } + + void clean() + { + x_grad.clear(); + cached_output.clear(); + params_grad.clear(); + temp_tensor.clear(); + gradient_input_is_stale = true; + subnetwork->clean(); + call_clean_method_if_exists(details); + } + + friend void serialize(const add_layer& item, std::ostream& out) + { + int version = 2; + serialize(version, out); + serialize(*item.subnetwork, out); + serialize(item.details, out); + serialize(item.this_layer_setup_called, out); + serialize(item.gradient_input_is_stale, out); + serialize(item.get_output_and_gradient_input_disabled, out); + serialize(item.x_grad, out); + serialize(item.cached_output, out); + serialize(item.params_grad, out); + } + + friend void deserialize(add_layer& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (!(1 <= version && version <= 2)) + throw serialization_error("Unexpected version found while deserializing dlib::add_layer."); + deserialize(*item.subnetwork, in); + deserialize(item.details, in); + deserialize(item.this_layer_setup_called, in); + deserialize(item.gradient_input_is_stale, in); + deserialize(item.get_output_and_gradient_input_disabled, in); + deserialize(item.x_grad, in); + deserialize(item.cached_output, in); + if (version == 2) + deserialize(item.params_grad, in); + } + + friend std::ostream& operator<< (std::ostream& out, const add_layer& item) + { + int min_length = 0; + item.print(out, 0, min_length); + return out; + } + + void print (std::ostream& out, unsigned long idx, int& min_length) const + { + out << "layer<" << idx << ">\t" << impl::tensor_to_str(private_get_output(), min_length) << layer_details() << "\n"; + subnet().print(out, idx+1, min_length); + } + + private: + + bool this_layer_operates_inplace( + ) + { + // This layer can run in-place if it's an in-place capable layer and also if + // the layer it's on top of doesn't need its own output tensor (since in-place + // layers overwrite that tensor) + return impl::is_inplace_layer(details, *subnetwork) && !subnetwork->this_layer_requires_forward_output(); + } + bool this_layer_requires_forward_output( + ) + { + return impl::backward_requires_forward_output(details, *subnetwork); + } + + void swap(add_layer& item) + { + std::swap(subnetwork,item.subnetwork); + std::swap(details, item.details); + std::swap(this_layer_setup_called, item.this_layer_setup_called); + std::swap(gradient_input_is_stale, item.gradient_input_is_stale); + std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled); + std::swap(x_grad, item.x_grad); + std::swap(cached_output, item.cached_output); + std::swap(params_grad, item.params_grad); + } + + + LAYER_DETAILS details; + std::unique_ptr subnetwork; + bool this_layer_setup_called; + bool gradient_input_is_stale; + bool get_output_and_gradient_input_disabled; + // Note that if this_layer_operates_inplace()==true then x_grad and cached_output + // are not used at all. Instead, this layer uses these variables from the lower + // layer. + resizable_tensor x_grad; + resizable_tensor cached_output; + + resizable_tensor params_grad; + + // temp_tensor doesn't logically contribute to the state of this object. + // It is here only to prevent it from being reallocated over and over. + resizable_tensor temp_tensor; + + }; + + template + struct is_add_layer> : std::true_type {}; + template + struct is_add_layer> : std::true_type {}; + template + struct is_add_layer&> : std::true_type {}; + template + struct is_add_layer&> : std::true_type {}; + +// ---------------------------------------------------------------------------------------- + +// This version of add_layer handles the special case where the subnetwork being given is +// just an input layer object. + template + class add_layer + { + public: + typedef LAYER_DETAILS layer_details_type; + typedef INPUT_LAYER subnet_type; + typedef typename INPUT_LAYER::input_type input_type; + const static size_t num_layers = 2; + const static size_t num_computational_layers = 1; + + add_layer( + ): + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false), + _sample_expansion_factor(0) + {} + + add_layer(const add_layer&) = default; + add_layer(add_layer&& item) : add_layer() { swap(item); } + add_layer& operator=(const add_layer&) = default; + add_layer& operator=(add_layer&& item) { swap(item); return *this; } + + template + friend class add_layer; + template + friend class dimpl::subnet_wrapper; + template + friend class add_tag_layer; + template class T, typename U> + friend class add_skip_layer; + template class L, typename S> + friend class repeat; + + // Allow copying networks from one to another as long as their corresponding + // layers can be constructed from each other. + template + add_layer( + const add_layer& item + ): + input_layer(item.subnet()), + details(item.layer_details()), + this_layer_setup_called(item.this_layer_setup_called), + gradient_input_is_stale(item.gradient_input_is_stale), + get_output_and_gradient_input_disabled(false), + _sample_expansion_factor(item._sample_expansion_factor), + x_grad(item.x_grad), + cached_output(item.cached_output), + grad_final(item.grad_final) + { + } + + add_layer( + const LAYER_DETAILS& layer_det + ) : + details(layer_det), + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false), + _sample_expansion_factor(0) + {} + + add_layer( + const INPUT_LAYER& il + ) : + input_layer(il), + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false), + _sample_expansion_factor(0) + {} + + add_layer( + LAYER_DETAILS&& layer_det + ) : + details(std::move(layer_det)), + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false), + _sample_expansion_factor(0) + {} + + add_layer( + LAYER_DETAILS layer_det, + INPUT_LAYER il + ) : + details(std::move(layer_det)), + input_layer(std::move(il)), + this_layer_setup_called(false), + gradient_input_is_stale(true), + get_output_and_gradient_input_disabled(false), + _sample_expansion_factor(0) + {} + + add_layer( + std::tuple<>, + const LAYER_DETAILS& layer_det + ) : add_layer(layer_det) {} + + add_layer( + std::tuple<>, + LAYER_DETAILS&& layer_det + ) : add_layer(layer_det) {} + + add_layer( + std::tuple<>, + LAYER_DETAILS layer_det, + INPUT_LAYER il + ) : add_layer(layer_det,il) {} + + add_layer( + const std::tuple& layer_det + ) : add_layer(tuple_head(layer_det)) {} + + add_layer( + const std::tuple& layer_det, + INPUT_LAYER il + ) : add_layer(tuple_head(layer_det),il) {} + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + input_layer.to_tensor(ibegin, iend, data); + // make sure the input layer's to_tensor() function is implemented properly. + DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend), + "The input layer can't produce fewer output tensors than there are inputs."); + DLIB_CASSERT(data.num_samples()%std::distance(ibegin,iend) == 0, + "The number of tensors produced by the input layer must be an integer multiple of the number of input objects."); + + _sample_expansion_factor = data.num_samples()/std::distance(ibegin,iend); + data.async_copy_to_device(); + } + + + template + const tensor& operator() ( + forward_iterator ibegin, + forward_iterator iend + ) + { + to_tensor(ibegin,iend,temp_tensor); + return forward(temp_tensor); + } + + + const tensor& operator() (const input_type& x) + { + return (*this)(&x, &x+1); + } + + const tensor& forward (const tensor& x) + { + DLIB_CASSERT(sample_expansion_factor() != 0, "You must call to_tensor() before this function can be used."); + DLIB_CASSERT(x.num_samples()%sample_expansion_factor() == 0); + subnet_wrapper wsub(x, grad_final, _sample_expansion_factor); + if (!this_layer_setup_called) + { + details.setup(wsub); + this_layer_setup_called = true; + } + impl::call_layer_forward(details, wsub, cached_output); + gradient_input_is_stale = true; + return private_get_output(); + } + + private: + tensor& private_get_output() const { return const_cast(cached_output); } + tensor& private_get_gradient_input() + { + if (gradient_input_is_stale) + { + gradient_input_is_stale = false; + x_grad.copy_size(private_get_output()); + x_grad = 0; + } + return x_grad; + } + void disable_output_and_gradient_getters ( + ) { get_output_and_gradient_input_disabled = true; } + public: + const tensor& get_output() const + { + if (get_output_and_gradient_input_disabled) + throw dlib::error("Accessing this layer's get_output() is disabled because an in-place layer has been stacked on top of it."); + return private_get_output(); + } + tensor& get_gradient_input() + { + if (get_output_and_gradient_input_disabled) + throw dlib::error("Accessing this layer's get_gradient_input() is disabled because an in-place layer has been stacked on top of it."); + return private_get_gradient_input(); + } + + const tensor& get_final_data_gradient( + ) const { return grad_final; } + + void back_propagate_error(const tensor& x) + { + back_propagate_error(x, private_get_gradient_input()); + } + void back_propagate_error(const tensor& x, const tensor& gradient_input) + { + // make sure grad_final is initialized to 0 + if (!have_same_dimensions(x, grad_final)) + grad_final.copy_size(x); + grad_final = 0; + + subnet_wrapper wsub(x, grad_final, _sample_expansion_factor); + params_grad.copy_size(details.get_layer_params()); + impl::call_layer_backward(details, private_get_output(), + gradient_input, wsub, static_cast(params_grad)); + + // zero out get_gradient_input() + gradient_input_is_stale = true; + } + + template + void update_parameters(sstack solvers, double learning_rate) + { + DLIB_CASSERT(solvers.size()>=num_computational_layers); + // Don't try to adjust the parameters if this layer doesn't have any or the + // learning rate is disabled for this layer. + if (params_grad.size() != 0 && get_learning_rate_multiplier(details) != 0) + { + const tensor& step = solvers.top()(learning_rate, details, static_cast(params_grad)); + tt::add(details.get_layer_params(), details.get_layer_params(), step); + } + } + + const tensor& get_parameter_gradient( + ) const { return params_grad; } + + tensor& get_parameter_gradient ( + ) { return params_grad; } + + const subnet_type& subnet() const { return input_layer; } + subnet_type& subnet() { return input_layer; } + + const layer_details_type& layer_details() const { return details; } + layer_details_type& layer_details() { return details; } + + unsigned int sample_expansion_factor() const { return _sample_expansion_factor; } + + void clean() + { + x_grad.clear(); + grad_final.clear(); + cached_output.clear(); + params_grad.clear(); + temp_tensor.clear(); + gradient_input_is_stale = true; + call_clean_method_if_exists(details); + } + + friend void serialize(const add_layer& item, std::ostream& out) + { + int version = 3; + serialize(version, out); + serialize(item.input_layer, out); + serialize(item.details, out); + serialize(item.this_layer_setup_called, out); + serialize(item.gradient_input_is_stale, out); + serialize(item.get_output_and_gradient_input_disabled, out); + serialize(item.x_grad, out); + serialize(item.cached_output, out); + serialize(item.grad_final, out); + serialize(item._sample_expansion_factor, out); + } + + friend void deserialize(add_layer& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (!(2 <= version && version <= 3)) + throw serialization_error("Unexpected version found while deserializing dlib::add_layer."); + deserialize(item.input_layer, in); + deserialize(item.details, in); + deserialize(item.this_layer_setup_called, in); + deserialize(item.gradient_input_is_stale, in); + deserialize(item.get_output_and_gradient_input_disabled, in); + deserialize(item.x_grad, in); + deserialize(item.cached_output, in); + deserialize(item.grad_final, in); + if (version >= 3) + deserialize(item._sample_expansion_factor, in); + else + item._sample_expansion_factor = 1; // all layer types set this to 1 in older dlib versions, so that's what we put here. + } + + friend std::ostream& operator<< (std::ostream& out, const add_layer& item) + { + int min_length = 0; + item.print(out, 0, min_length); + return out; + } + + void print (std::ostream& out, unsigned long idx, int& min_length) const + { + out << "layer<" << idx << ">\t" << impl::tensor_to_str(private_get_output(), min_length) << layer_details() << "\n"; + + // Don't print the repeat_input_layer since it doesn't exist from the user's + // point of view. It's just an artifact of how repeat<> works. + if (!std::is_same::value) + out << "layer<" << idx+1 << ">\t" << subnet() << "\n"; + } + + private: + + bool this_layer_requires_forward_output( + ) + { + subnet_wrapper wsub(grad_final, grad_final, _sample_expansion_factor); + return impl::backward_requires_forward_output(details, wsub); + } + + class subnet_wrapper + { + public: + subnet_wrapper(const tensor& x_, resizable_tensor& grad_final_, unsigned int sef) : + x(x_), grad_final(grad_final_), _sample_expansion_factor(sef) {} + + subnet_wrapper(const subnet_wrapper&) = delete; + subnet_wrapper& operator=(const subnet_wrapper&) = delete; + + unsigned int sample_expansion_factor() const { return _sample_expansion_factor;} + const tensor& get_output() const { return x; } + tensor& get_gradient_input() + { + if (!have_same_dimensions(x, grad_final)) + { + grad_final.copy_size(x); + grad_final = 0; + } + return grad_final; + } + + private: + const tensor& x; + resizable_tensor& grad_final; + unsigned int _sample_expansion_factor; + }; + + void swap(add_layer& item) + { + std::swap(input_layer, item.input_layer); + std::swap(details, item.details); + std::swap(this_layer_setup_called, item.this_layer_setup_called); + std::swap(gradient_input_is_stale, item.gradient_input_is_stale); + std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled); + std::swap(x_grad, item.x_grad); + std::swap(cached_output, item.cached_output); + std::swap(grad_final, item.grad_final); + std::swap(_sample_expansion_factor, item._sample_expansion_factor); + } + + subnet_type input_layer; + LAYER_DETAILS details; + bool this_layer_setup_called; + bool gradient_input_is_stale; + bool get_output_and_gradient_input_disabled; + mutable unsigned int _sample_expansion_factor; + resizable_tensor x_grad; + resizable_tensor cached_output; + resizable_tensor grad_final; + + // The following 2 objects don't logically contribute to the state of this class. + // They are only here to prevent them from being reallocated over and over in + // member functions. + resizable_tensor params_grad; + resizable_tensor temp_tensor; + }; + +// ---------------------------------------------------------------------------------------- + + template + class add_tag_layer; + + template class tag> + struct tag_id + { + const static unsigned long id = tag::id; + }; + + template + class add_tag_layer::value>::type> + { + public: + typedef SUBNET subnet_type; + typedef typename subnet_type::input_type input_type; + typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. + const static size_t num_layers = subnet_type::num_layers + 1; + const static size_t num_computational_layers = subnet_type::num_computational_layers; + const static unsigned long id = ID; + + add_tag_layer() {}; + add_tag_layer(const add_tag_layer&) = default; + add_tag_layer(add_tag_layer&&) = default; + add_tag_layer& operator=(add_tag_layer&&) = default; + add_tag_layer& operator=(const add_tag_layer&) = default; + + template + add_tag_layer( + const add_tag_layer& item + ) : subnetwork(item.subnet()) + {} + + template + add_tag_layer( + T ...args + ) : + subnetwork(std::move(args)...) + { + } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + subnetwork.to_tensor(ibegin,iend,data); + } + + template + const tensor& operator() ( + forward_iterator ibegin, + forward_iterator iend + ) + { + return subnetwork(ibegin,iend); + } + + const tensor& operator() (const input_type& x) + { + return subnetwork(x); + } + + const tensor& forward(const tensor& x) + { + return subnetwork.forward(x); + } + + const tensor& get_output() const { return subnetwork.get_output(); } + + tensor& get_gradient_input() + { + return subnetwork.get_gradient_input(); + } + + const tensor& get_final_data_gradient( + ) const { return subnetwork.get_final_data_gradient(); } + + void back_propagate_error(const tensor& x) + { + subnetwork.back_propagate_error(x); + } + void back_propagate_error(const tensor& x, const tensor& gradient_input) + { + subnetwork.back_propagate_error(x,gradient_input); + } + + template + void update_parameters(sstack solvers, double learning_rate) + { + subnetwork.update_parameters(solvers, learning_rate); + } + + const tensor& get_parameter_gradient( + ) const { return params_grad; } + + tensor& get_parameter_gradient ( + ) { return params_grad; } + + const subnet_type& subnet() const { return subnetwork; } + subnet_type& subnet() { return subnetwork; } + + unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } + + void clean() + { + subnetwork.clean(); + } + + friend void serialize(const add_tag_layer& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.subnetwork, out); + } + + friend void deserialize(add_tag_layer& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer."); + deserialize(item.subnetwork, in); + } + + friend std::ostream& operator<< (std::ostream& out, const add_tag_layer& item) + { + int min_length = 0; + item.print(out, 0, min_length); + return out; + } + + void print (std::ostream& out, unsigned long idx, int& min_length) const + { + out << "layer<" << idx << ">\t" << impl::tensor_to_str(private_get_output(), min_length) << "tag" << ID << "\n"; + subnet().print(out, idx+1, min_length); + } + + private: + + template + friend class add_layer; + template + friend class dimpl::subnet_wrapper; + template + friend class add_tag_layer; + template class T, typename U> + friend class add_skip_layer; + template class L, typename S> + friend class repeat; + + // You wouldn't put a tag on a layer if you didn't want to access its forward + // outputs. So this is always true. + bool this_layer_requires_forward_output( + ) { return true; } + + void disable_output_and_gradient_getters ( + ) + { + // This should never happen because only inplace layers call + // disable_output_and_gradient_getters(), however, putting a tag layer right + // before an inplace layer basically means you don't want the following layer + // to operate in place. So the inplace layer should turn itself into an + // out-of-place layer and not call disable_output_and_gradient_getters(). + DLIB_CASSERT(false,"This should never happen"); + } + + tensor& private_get_output() const + { return subnetwork.private_get_output(); } + tensor& private_get_gradient_input() + { return subnetwork.private_get_gradient_input(); } + + subnet_type subnetwork; + + // This member doesn't logically contribute to the state of the object since it is + // always empty. It's just here so we can have the get_parameter_gradient() methods + // which have to return something. So they return this empty tensor. + resizable_tensor params_grad; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct decorator_repeat_group + { + decorator_repeat_group( + T&& ...args + ) : data(std::forward(args)...) {} + + std::tuple data; + }; + template + decorator_repeat_group repeat_group ( + T&& ...args + ) + { + return decorator_repeat_group(std::forward(args)...); + } + + template < + size_t num, + template class REPEATED_LAYER, + typename SUBNET + > + class repeat + { + static_assert(num > 0, "You can't have a layer repeated 0 times."); + public: + typedef SUBNET subnet_type; + typedef typename SUBNET::input_type input_type; + typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. + const static size_t comp_layers_in_each_group = (REPEATED_LAYER::num_computational_layers-SUBNET::num_computational_layers); + const static size_t comp_layers_in_repeated_group = comp_layers_in_each_group*num; + const static size_t num_computational_layers = comp_layers_in_repeated_group + SUBNET::num_computational_layers; + + const static size_t layers_in_each_group = (REPEATED_LAYER::num_layers-SUBNET::num_layers); + const static size_t layers_in_repeated_group = layers_in_each_group*num; + const static size_t num_layers = subnet_type::num_layers + layers_in_repeated_group; + + + typedef REPEATED_LAYER repeated_layer_type; + + repeat( + ) : + details(num) + { + } + + size_t num_repetitions ( + ) const { return num; } + + const repeated_layer_type& get_repeated_layer ( + size_t i + ) const + { + DLIB_CASSERT(i < num_repetitions()); + return details[i]; + } + + repeated_layer_type& get_repeated_layer ( + size_t i + ) + { + DLIB_CASSERT(i < num_repetitions()); + return details[i]; + } + + repeat(const repeat&) = default; + repeat(repeat&&) = default; + repeat& operator=(repeat&&) = default; + repeat& operator=(const repeat&) = default; + + template class T, typename U> + repeat( + const repeat& item + ) : + subnetwork(item.subnetwork) + { + for (auto&& d : item.details) + details.emplace_back(d); + } + + template + repeat( + T arg1, + U ...args2 + ): + details(num, std::move(arg1)), + subnetwork(std::move(args2)...) + { + } + + template + repeat( + decorator_repeat_group&& arg1, + U ...args2 + ): + details(num, arg1.data), + subnetwork(std::move(args2)...) + { + } + + template + repeat( + std::tuple<>, + T arg1, + U ...args2 + ): + details(num, std::move(arg1)), + subnetwork(std::move(args2)...) + { + } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + subnetwork.to_tensor(ibegin,iend,data); + // call to_tensor on the networks in details just to populate the + // _sample_expansion_factor values in those networks. Other than that this + // call is a noop. + for (auto& d : details) + d.to_tensor(ibegin, iend, data); + } + + template + const tensor& operator() ( + forward_iterator ibegin, + forward_iterator iend + ) + { + to_tensor(ibegin,iend,temp_tensor); + return forward(temp_tensor); + } + + const tensor& operator() (const input_type& x) + { + return (*this)(&x, &x+1); + } + + const tensor& forward(const tensor& x) + { + subnetwork.forward(x); + details[details.size()-1].forward(subnetwork.get_output()); + for (long i = details.size()-2; i >= 0; --i) + details[i].forward(details[i+1].get_output()); + return private_get_output(); + } + + private: + tensor& private_get_output() const + { + return details[0].private_get_output(); + } + tensor& private_get_gradient_input() + { + return details[0].private_get_gradient_input(); + } + public: + const tensor& get_output() const + { + return details[0].get_output(); + } + tensor& get_gradient_input() + { + return details[0].get_gradient_input(); + } + + const tensor& get_parameter_gradient( + ) const { return details[0].get_parameter_gradient(); } + + tensor& get_parameter_gradient ( + ) { return details[0].get_parameter_gradient(); } + + void back_propagate_error(const tensor& x) + { + back_propagate_error(x, private_get_gradient_input()); + } + void back_propagate_error(const tensor& x, const tensor& gradient_input) + { + if (details.size() > 1) + { + details[0].back_propagate_error(details[1].get_output(), gradient_input); + for (size_t i = 1; i < details.size(); ++i) + { + if (i+1 < details.size()) + details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient()); + else + details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient()); + } + } + else + { + details[0].back_propagate_error(subnetwork.get_output(), gradient_input); + } + subnetwork.back_propagate_error(x, details.back().get_final_data_gradient()); + } + + template + void update_parameters(sstack solvers, double learning_rate) + { + for (size_t i = 0; i < details.size(); ++i) + details[i].update_parameters(solvers.pop(comp_layers_in_each_group*i),learning_rate); + subnetwork.update_parameters(solvers.pop(comp_layers_in_each_group*details.size()),learning_rate); + } + + const subnet_type& subnet() const { return subnetwork; } + subnet_type& subnet() { return subnetwork; } + + unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } + + void clean() + { + temp_tensor.clear(); + subnetwork.clean(); + for (auto&& d : details) + d.clean(); + } + + friend void serialize(const repeat& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.details, out); + serialize(item.subnetwork, out); + } + + friend void deserialize(repeat& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::repeat."); + deserialize(item.details, in); + deserialize(item.subnetwork, in); + } + + friend std::ostream& operator<< (std::ostream& out, const repeat& item) + { + int min_length = 0; + item.print(out, 0, min_length); + return out; + } + + void print (std::ostream& out, unsigned long idx, int& min_length) const + { + for (size_t i = 0; i < num_repetitions(); ++i) + { + get_repeated_layer(i).print(out, idx, min_length); + idx += layers_in_each_group; + } + subnet().print(out, idx, min_length); + } + private: + + + template + friend class add_layer; + template + friend class dimpl::subnet_wrapper; + template + friend class add_tag_layer; + template class T, typename U> + friend class add_skip_layer; + template class L, typename S> + friend class repeat; + + bool this_layer_requires_forward_output( + ) + { + return details[0].this_layer_requires_forward_output(); + } + + void disable_output_and_gradient_getters ( + ) + { + details[0].disable_output_and_gradient_getters(); + } + + + std::vector details; + subnet_type subnetwork; + + // temp_tensor doesn't logically contribute to the state of this class. + // It is here only to void needing to reallocate it over and over. + resizable_tensor temp_tensor; + }; + + template < + size_t num, + template class REPEATED_LAYER, + typename SUBNET + > + struct is_nonloss_layer_type> : std::true_type {}; + +// ---------------------------------------------------------------------------------------- + +// This version of add_tag_layer handles the special case where the subnetwork being given +// is just an input layer object. + template + class add_tag_layer + { + public: + typedef INPUT_LAYER subnet_type; + typedef typename subnet_type::input_type input_type; + typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. + const static size_t num_computational_layers = 0; + const static size_t num_layers = 2; + const static unsigned long id = ID; + + add_tag_layer():cached_output_ptr(nullptr),gradient_input_is_stale(true),_sample_expansion_factor(0) {} + + add_tag_layer(const add_tag_layer&) = default; + add_tag_layer& operator=(const add_tag_layer&) = default; + add_tag_layer(add_tag_layer&& item) : add_tag_layer() { swap(item); } + add_tag_layer& operator=(add_tag_layer&& item) { swap(item); return *this; } + + template + add_tag_layer( + const add_tag_layer& item + ) : input_layer(item.subnet()), + cached_output(item.cached_output), + cached_output_ptr(nullptr), + grad_final(item.grad_final), + gradient_input_is_stale(item.gradient_input_is_stale), + _sample_expansion_factor(0) + {} + + template + add_tag_layer( + T ...args + ) : + input_layer(std::move(args)...), + cached_output_ptr(nullptr), + gradient_input_is_stale(true), + _sample_expansion_factor(0) + { + } + + add_tag_layer ( + std::tuple<> + ) : + cached_output_ptr(nullptr), + gradient_input_is_stale(true), + _sample_expansion_factor(0) + {} + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + input_layer.to_tensor(ibegin,iend,data); + + // make sure the input layer's to_tensor() function is implemented properly. + DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend), + "The input layer can't produce fewer output tensors than there are inputs."); + DLIB_CASSERT(data.num_samples()%std::distance(ibegin,iend) == 0, + "The number of tensors produced by the input layer must be an integer multiple of the number of input objects."); + + _sample_expansion_factor = data.num_samples()/std::distance(ibegin,iend); + data.async_copy_to_device(); + } + + unsigned int sample_expansion_factor() const { return _sample_expansion_factor; } + + template + const tensor& operator() ( + forward_iterator ibegin, + forward_iterator iend + ) + { + input_layer.to_tensor(ibegin,iend,cached_output); + cached_output_ptr = nullptr; + return get_output(); + } + + const tensor& operator() (const input_type& x) + { + return (*this)(&x, &x+1); + } + + const tensor& forward(const tensor& x) + { + // If this tag is the first layer in one of the sub networks inside a repeat + // layer then we don't want it to be creating copies of x. This is because, we + // can just hold a pointer to x since the way repeat is constructed guarantees + // that x will have a lifetime larger than this pointer. + if (is_same_type::value) + cached_output_ptr = const_cast(&x); + else + cached_output = x; + gradient_input_is_stale = true; + return get_output(); + } + + const tensor& get_output() const + { + if (cached_output_ptr) + return *cached_output_ptr; + else + return cached_output; + } + + const tensor& get_final_data_gradient( + ) const { return grad_final; } + + tensor& get_gradient_input() + { + if (!have_same_dimensions(get_output(), grad_final) || + gradient_input_is_stale) + { + grad_final.copy_size(get_output()); + grad_final = 0; + gradient_input_is_stale = false; + } + return grad_final; + } + + void back_propagate_error(const tensor& /*x*/) + { + // nothing to do + } + void back_propagate_error(const tensor& /*x*/, const tensor& /*gradient_input*/) + { + // nothing to do + } + + template + void update_parameters(sstack /*solvers*/, double /*learning_rate*/) + { + // nothing to do + } + + const subnet_type& subnet() const { return input_layer; } + subnet_type& subnet() { return input_layer; } + + void clean() + { + grad_final.clear(); + cached_output.clear(); + cached_output_ptr = 0; + } + + friend void serialize(const add_tag_layer& item, std::ostream& out) + { + int version = 2; + serialize(version, out); + serialize(item.input_layer, out); + serialize(item.cached_output, out); + serialize(item.grad_final, out); + serialize(item.gradient_input_is_stale, out); + serialize(item._sample_expansion_factor, out); + } + + friend void deserialize(add_tag_layer& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (!(1 <= version && version <= 2)) + throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer."); + deserialize(item.input_layer, in); + deserialize(item.cached_output, in); + deserialize(item.grad_final, in); + deserialize(item.gradient_input_is_stale, in); + item.cached_output_ptr = nullptr; + if (version >= 2) + deserialize(item._sample_expansion_factor, in); + else + item._sample_expansion_factor = 1; // all layer types set this to 1 in older dlib versions, so that's what we put here. + + } + + friend std::ostream& operator<< (std::ostream& out, const add_tag_layer& item) + { + int min_length = 0; + item.print(out, 0, min_length); + return out; + } + + void print (std::ostream& out, unsigned long idx, int& min_length) const + { + out << "layer<"<\t"< works. + if (!std::is_same::value) + out << "layer<"<< idx+1 << ">\t" << subnet() << "\n"; + } + + private: + + template + friend class add_layer; + template + friend class dimpl::subnet_wrapper; + template + friend class add_tag_layer; + template class T, typename U> + friend class add_skip_layer; + template class L, typename S> + friend class repeat; + + // You woudln't put a tag on a layer if you didn't want to access its forward + // outputs. So this is always true. + bool this_layer_requires_forward_output( + ) { return true; } + + void disable_output_and_gradient_getters ( + ) + { + // This should never happen because only inplace layers call + // disable_output_and_gradient_getters(), however, putting a tag layer right + // before an inplace layer basically means you don't want the following layer + // to operate in place. So the inplace layer should turn itself into an + // out-of-place layer and not call disable_output_and_gradient_getters(). + DLIB_CASSERT(false,"This should never happen"); + } + + tensor& private_get_output() const + { return const_cast(get_output()); } + tensor& private_get_gradient_input() + { return get_gradient_input(); } + + void swap(add_tag_layer& item) + { + std::swap(input_layer, item.input_layer); + std::swap(cached_output, item.cached_output); + std::swap(cached_output_ptr, item.cached_output_ptr); + std::swap(grad_final, item.grad_final); + std::swap(gradient_input_is_stale, item.gradient_input_is_stale); + std::swap(_sample_expansion_factor, item._sample_expansion_factor); + } + + subnet_type input_layer; + resizable_tensor cached_output; + tensor* cached_output_ptr; + resizable_tensor grad_final; + bool gradient_input_is_stale; + mutable unsigned int _sample_expansion_factor; + }; + + template + struct is_nonloss_layer_type> : std::true_type {}; + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class add_loss_layer; + + class no_label_type + { + private: + // We don't want anyone making these no_label_type objects. They are here only to + // allow add_loss_layer::training_label_type and dnn_trainer::training_label_type + // to exist which avoids needing to overload add_loss_layer and dnn_trainer for + // supervised an unsupervised losses. It also can be a type to use in template + // metaprogramming to indicate "no label". So here we make the constructor private + // with the exception that add_loss_layer objects can make it (again, just to + // simplify add_loss_layer's implementation). + no_label_type(){}; + template friend class add_loss_layer; + template < typename net_type, typename solver_type > friend class dnn_trainer; + }; + +// ---------------------------------------------------------------------------------------- + + template + class add_loss_layer + { + template + struct get_loss_layer_training_label_type + { + typedef no_label_type type; + }; + template + struct get_loss_layer_training_label_type::type> + { + typedef typename T::training_label_type type; + }; + + template + struct get_loss_layer_output_label_type + { + typedef no_label_type type; + }; + template + struct get_loss_layer_output_label_type::type> + { + typedef typename T::output_label_type type; + }; + + public: + typedef LOSS_DETAILS loss_details_type; + typedef SUBNET subnet_type; + typedef typename subnet_type::input_type input_type; + const static size_t num_layers = subnet_type::num_layers + 1; + // Note that the loss layer doesn't count as an additional computational layer. + const static size_t num_computational_layers = subnet_type::num_computational_layers; + typedef typename get_loss_layer_training_label_type::type training_label_type; + typedef typename get_loss_layer_output_label_type::type output_label_type; + + static_assert(is_nonloss_layer_type::value, + "SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer."); + + + add_loss_layer() {}; + add_loss_layer(const add_loss_layer&) = default; + add_loss_layer& operator=(const add_loss_layer&) = default; + add_loss_layer(add_loss_layer&& item) : add_loss_layer() { swap(item); } + add_loss_layer& operator=(add_loss_layer&& item) { swap(item); return *this; } + + template + add_loss_layer( + const add_loss_layer& item + ) : + loss(item.loss_details()), + subnetwork(item.subnet()) + {} + + template + add_loss_layer( + const LOSS_DETAILS& layer_det, + T&& ...args + ) : + loss(layer_det), + subnetwork(std::forward(args)...) + { + } + + template + add_loss_layer( + LOSS_DETAILS&& layer_det, + T&& ...args + ) : + loss(std::move(layer_det)), + subnetwork(std::forward(args)...) + { + } + + template + struct disable_forwarding_constr + { + const static bool value = std::is_constructible::value; + }; + template + struct disable_forwarding_constr> + { + const static bool value = true; + }; + + template < + typename ...T, + typename = typename std::enable_if::type...>::value>::type + > + add_loss_layer( + T&& ...args + ) : + subnetwork(std::forward(args)...) + { + } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + subnetwork.to_tensor(ibegin,iend,data); + } + + unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } + + template + void operator() ( + const tensor& x, + output_iterator obegin + ) + { + subnetwork.forward(x); + const dimpl::subnet_wrapper wsub(subnetwork); + loss.to_label(x, wsub, obegin); + } + + template + void operator() ( + forward_iterator ibegin, + forward_iterator iend, + output_iterator obegin + ) + { + to_tensor(ibegin,iend,temp_tensor); + (*this)(temp_tensor, obegin); + } + + const output_label_type& operator() (const input_type& x) + { + (*this)(&x, &x+1, &temp_label); + return temp_label; + } + + template + const output_label_type& process (const input_type& x, T&& ...args) + { + to_tensor(&x,&x+1,temp_tensor); + subnetwork.forward(temp_tensor); + const dimpl::subnet_wrapper wsub(subnetwork); + loss.to_label(temp_tensor, wsub, &temp_label, std::forward(args)...); + return temp_label; + } + + template + std::vector process_batch (const iterable_type& data, size_t batch_size, T&& ...args) + { + std::vector results(std::distance(data.begin(), data.end())); + auto o = results.begin(); + auto i = data.begin(); + auto num_remaining = results.size(); + while(num_remaining != 0) + { + auto inc = std::min(batch_size, num_remaining); + to_tensor(i,i+inc,temp_tensor); + subnetwork.forward(temp_tensor); + const dimpl::subnet_wrapper wsub(subnetwork); + loss.to_label(temp_tensor, wsub, o, std::forward(args)...); + + i += inc; + o += inc; + num_remaining -= inc; + } + return results; + } + + template + std::vector operator() ( + const iterable_type& data, + size_t batch_size = 128 + ) + { + std::vector results(std::distance(data.begin(), data.end())); + auto o = results.begin(); + auto i = data.begin(); + auto num_remaining = results.size(); + while(num_remaining != 0) + { + auto inc = std::min(batch_size, num_remaining); + (*this)(i, i+inc, o); + i += inc; + o += inc; + num_remaining -= inc; + } + return results; + } + + template + double compute_loss ( + const tensor& x, + label_iterator lbegin + ) + { + subnetwork.forward(x); + dimpl::subnet_wrapper wsub(subnetwork); + return loss.compute_loss_value_and_gradient(x, lbegin, wsub); + } + + template + double compute_loss ( + forward_iterator ibegin, + forward_iterator iend, + label_iterator lbegin + ) + { + to_tensor(ibegin,iend,temp_tensor); + return compute_loss(temp_tensor, lbegin); + } + + double compute_loss ( + const tensor& x + ) + { + subnetwork.forward(x); + dimpl::subnet_wrapper wsub(subnetwork); + return loss.compute_loss_value_and_gradient(x, wsub); + } + + template + double compute_loss ( + forward_iterator ibegin, + forward_iterator iend + ) + { + to_tensor(ibegin,iend,temp_tensor); + return compute_loss(temp_tensor); + } + + template + double compute_parameter_gradients ( + const tensor& x, + label_iterator lbegin + ) + { + subnetwork.forward(x); + dimpl::subnet_wrapper wsub(subnetwork); + double l = loss.compute_loss_value_and_gradient(x, lbegin, wsub); + subnetwork.back_propagate_error(x); + return l; + } + template + double compute_parameter_gradients ( + forward_iterator ibegin, + forward_iterator iend, + label_iterator lbegin + ) + { + to_tensor(ibegin,iend,temp_tensor); + return compute_parameter_gradients(temp_tensor, lbegin); + } + double compute_parameter_gradients ( + const tensor& x + ) + { + subnetwork.forward(x); + dimpl::subnet_wrapper wsub(subnetwork); + double l = loss.compute_loss_value_and_gradient(x, wsub); + subnetwork.back_propagate_error(x); + return l; + } + template + double compute_parameter_gradients ( + forward_iterator ibegin, + forward_iterator iend + ) + { + to_tensor(ibegin,iend,temp_tensor); + return compute_parameter_gradients(temp_tensor); + } + + template + void update_parameters ( + sstack solvers, + double learning_rate + ) + { + subnetwork.update_parameters(solvers, learning_rate); + } + + const subnet_type& subnet() const { return subnetwork; } + subnet_type& subnet() { return subnetwork; } + const loss_details_type& loss_details() const { return loss; } + loss_details_type& loss_details() { return loss; } + + void clean ( + ) + { + temp_tensor.clear(); + subnetwork.clean(); + } + + template + friend void serialize(const add_loss_layer& item, std::ostream& out); + template + friend void deserialize(add_loss_layer& item, std::istream& in); + + friend std::ostream& operator<< (std::ostream& out, const add_loss_layer& item) + { + int min_length = 0; + item.print(out, 0, min_length); + return out; + } + + void print (std::ostream& out, unsigned long idx, int& min_length) const + { + out << "layer<" << idx << ">\t" << loss_details() << "\n"; + subnet().print(out, idx+1, min_length); + } + + private: + + + void swap(add_loss_layer& item) + { + std::swap(loss, item.loss); + std::swap(subnetwork, item.subnetwork); + } + + loss_details_type loss; + subnet_type subnetwork; + + // These two objects don't logically contribute to the state of this object. They + // are here to prevent them from being reallocated over and over. + output_label_type temp_label; + resizable_tensor temp_tensor; + }; + + template + void serialize(const add_loss_layer& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.loss, out); + serialize(item.subnetwork, out); + } + + template + void deserialize(add_loss_layer& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::add_loss_layer."); + deserialize(item.loss, in); + deserialize(item.subnetwork, in); + } + + + template + struct is_loss_layer_type> : std::true_type {}; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + struct layer_helper + { + static_assert(i < T::num_layers, "Call to layer() attempted to access non-existing layer in neural network."); + static T& makeT(); + using next_type = typename std::remove_reference::type; + using type = typename layer_helper::type; + static type& layer(T& n) + { + return layer_helper::layer(n.subnet()); + } + }; + template < + unsigned int i, + size_t N, template class L, typename S + > + struct layer_helper, typename std::enable_if<(i!=0&&i>=repeat::layers_in_repeated_group)>::type> + { + const static size_t layers_in_repeated_group = repeat::layers_in_repeated_group; + + static repeat& makeT(); + using next_type = typename std::remove_reference::type; + using type = typename layer_helper::type; + static type& layer(repeat& n) + { + return layer_helper::layer(n.subnet()); + } + }; + template < + unsigned int i, + size_t N, template class L, typename S + > + struct layer_helper, typename std::enable_if<(i!=0&&i::layers_in_repeated_group)>::type> + { + const static size_t layers_in_each_group = repeat::layers_in_each_group; + typedef typename repeat::repeated_layer_type repeated_layer_type; + using next_type = repeated_layer_type; + using type = typename layer_helper::type; + static type& layer(repeat& n) + { + return layer_helper::layer(n.get_repeated_layer(i/layers_in_each_group)); + } + }; + template < + size_t N, template class L, typename S + > + struct layer_helper<0,repeat, void> + { + typedef typename repeat::repeated_layer_type repeated_layer_type; + using type = repeated_layer_type; + static type& layer(repeat& n) + { + return n.get_repeated_layer(0); + } + }; + + + + template < + unsigned int i, + size_t N, template class L, typename S + > + struct layer_helper, typename std::enable_if<(i!=0&&i>=repeat::layers_in_repeated_group)>::type> + { + const static size_t layers_in_repeated_group = repeat::layers_in_repeated_group; + + static const repeat& makeT(); + using next_type = const typename std::remove_reference::type; + using type = const typename layer_helper::type; + static type& layer(const repeat& n) + { + return layer_helper::layer(n.subnet()); + } + }; + template < + unsigned int i, + size_t N, template class L, typename S + > + struct layer_helper, typename std::enable_if<(i!=0&&i::layers_in_repeated_group)>::type> + { + const static size_t layers_in_each_group = repeat::layers_in_each_group; + typedef typename repeat::repeated_layer_type repeated_layer_type; + using next_type = const repeated_layer_type; + using type = const typename layer_helper::type; + static type& layer(const repeat& n) + { + return layer_helper::layer(n.get_repeated_layer(i/layers_in_each_group)); + } + }; + template < + size_t N, template class L, typename S + > + struct layer_helper<0,const repeat, void> + { + typedef typename repeat::repeated_layer_type repeated_layer_type; + using type = const repeated_layer_type; + static type& layer(const repeat& n) + { + return n.get_repeated_layer(0); + } + }; + + + + template + struct layer_helper<0,T,void> + { + using type = T; + static type& layer(T& n) + { + return n; + } + }; + + template class Match, typename T, unsigned int i, typename enabled = void> + struct layer_helper_match + { + static T& makeT(); + using next_type = typename std::remove_reference::type; + using type = typename layer_helper_match::type; + static type& layer(T& n) + { + return layer_helper_match::layer(n.subnet()); + } + }; + // This overload catches add_layer and add_loss_layer templates. + template class Match, typename T, unsigned int i> + struct layer_helper_match>::value>::type> + { + using type = typename layer_helper::type; + static type& layer(T& n) + { + return layer_helper::layer(n); + } + }; + // This overload catches input templates. + template class Match, typename T, unsigned int i> + struct layer_helper_match>::value>::type> + { + using type = typename layer_helper::type; + static type& layer(T& n) + { + return layer_helper::layer(n); + } + }; + // This overload catches subnet_wrapper templates. + template class Match, typename T, unsigned int i> + struct layer_helper_match>::value>::type> + { + using type = typename layer_helper::type; + static type& layer(T& n) + { + return layer_helper::layer(n); + } + }; + } + + template + typename impl::layer_helper::type& layer (T& n) + { + return impl::layer_helper::layer(n); + } + + template class Match, typename T> + typename impl::layer_helper_match::type& layer (T& n) + { + return impl::layer_helper_match::layer(n); + } + + template class Match, unsigned int i, typename T> + typename impl::layer_helper_match::type& layer (T& n) + { + return impl::layer_helper_match::layer(n); + } + +// ---------------------------------------------------------------------------------------- + + + namespace dimpl + { + template + T& get_input_details ( + T& net + ) + { + return net; + } + + template + auto get_input_details ( + dimpl::subnet_wrapper& net + ) -> decltype(net.layer_details())& + { + return net.layer_details(); + } + + template + auto get_input_details ( + const dimpl::subnet_wrapper& net + ) -> decltype(net.layer_details())& + { + return net.layer_details(); + } + } + + template + auto input_layer ( + net_type& net + ) -> decltype(dimpl::get_input_details(layer(net)))& + { + // Calling input_layer() on a subnet_wrapper is a little funny since the behavior of + // .subnet() returns another subnet_wrapper rather than an input details object as it + // does in add_layer. + return dimpl::get_input_details(layer(net)); + } + +// ---------------------------------------------------------------------------------------- + + template class TAG_TYPE, typename SUBNET> + class add_skip_layer + { + public: + typedef SUBNET subnet_type; + typedef typename subnet_type::input_type input_type; + typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. + const static size_t num_layers = subnet_type::num_layers + 1; + const static size_t num_computational_layers = subnet_type::num_computational_layers; + const static unsigned long id = tag_id::id; + + add_skip_layer() {}; + add_skip_layer(const add_skip_layer&) = default; + add_skip_layer(add_skip_layer&&) = default; + add_skip_layer& operator=(add_skip_layer&&) = default; + add_skip_layer& operator=(const add_skip_layer&) = default; + + template + add_skip_layer( + const add_skip_layer& item + ) : subnetwork(item.subnet()) + {} + + template + add_skip_layer( + T ...args + ) : + subnetwork(std::move(args)...) + { + } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + subnetwork.to_tensor(ibegin,iend,data); + } + + template + const tensor& operator() ( + forward_iterator ibegin, + forward_iterator iend + ) + { + subnetwork(ibegin,iend); + return layer(subnetwork).get_output(); + } + + const tensor& operator() (const input_type& x) + { + subnetwork(x); + return layer(subnetwork).get_output(); + } + + const tensor& forward(const tensor& x) + { + subnetwork.forward(x); + return layer(subnetwork).get_output(); + } + + const tensor& get_output() const + { + return layer(subnetwork).get_output(); + } + + tensor& get_gradient_input() + { + return layer(subnetwork).get_gradient_input(); + } + + const tensor& get_final_data_gradient( + ) const + { + return subnetwork.get_final_data_gradient(); + } + + void back_propagate_error(const tensor& x) + { + subnetwork.back_propagate_error(x); + } + + template + void update_parameters(sstack solvers, double learning_rate) + { + subnetwork.update_parameters(solvers, learning_rate); + } + + const tensor& get_parameter_gradient( + ) const { return params_grad; } + + tensor& get_parameter_gradient ( + ) { return params_grad; } + + + const subnet_type& subnet() const + { + return subnetwork; + } + + subnet_type& subnet() + { + return subnetwork; + } + + unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } + + void clean() + { + subnetwork.clean(); + } + + friend void serialize(const add_skip_layer& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.subnetwork, out); + } + + friend void deserialize(add_skip_layer& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::add_skip_layer."); + deserialize(item.subnetwork, in); + } + + friend std::ostream& operator<< (std::ostream& out, const add_skip_layer& item) + { + int min_length = 0; + item.print(out, 0, min_length); + return out; + } + + void print (std::ostream& out, unsigned long idx, int& min_length) const + { + out << "layer<" << idx << ">\t"< + friend class add_layer; + template + friend class dimpl::subnet_wrapper; + template + friend class add_tag_layer; + template class T, typename U> + friend class add_skip_layer; + template class L, typename S> + friend class repeat; + + bool this_layer_requires_forward_output( + ) { return layer(subnetwork).this_layer_requires_forward_output(); } + + void disable_output_and_gradient_getters ( + ) { layer(subnetwork).disable_output_and_gradient_getters(); } + + tensor& private_get_output() const + { return layer(subnetwork).private_get_output(); } + tensor& private_get_gradient_input() + { return layer(subnetwork).private_get_gradient_input(); } + + subnet_type subnetwork; + + // This member doesn't logically contribute to the state of the object since it is + // always empty. It's just here so we can have the get_parameter_gradient() methods + // which have to return something. So they return this empty tensor. + resizable_tensor params_grad; + }; + template class T, typename U> + struct is_nonloss_layer_type> : std::true_type {}; + + template using tag1 = add_tag_layer< 1, SUBNET>; + template using tag2 = add_tag_layer< 2, SUBNET>; + template using tag3 = add_tag_layer< 3, SUBNET>; + template using tag4 = add_tag_layer< 4, SUBNET>; + template using tag5 = add_tag_layer< 5, SUBNET>; + template using tag6 = add_tag_layer< 6, SUBNET>; + template using tag7 = add_tag_layer< 7, SUBNET>; + template using tag8 = add_tag_layer< 8, SUBNET>; + template using tag9 = add_tag_layer< 9, SUBNET>; + template using tag10 = add_tag_layer<10, SUBNET>; + + template using skip1 = add_skip_layer< tag1, SUBNET>; + template using skip2 = add_skip_layer< tag2, SUBNET>; + template using skip3 = add_skip_layer< tag3, SUBNET>; + template using skip4 = add_skip_layer< tag4, SUBNET>; + template using skip5 = add_skip_layer< tag5, SUBNET>; + template using skip6 = add_skip_layer< tag6, SUBNET>; + template using skip7 = add_skip_layer< tag7, SUBNET>; + template using skip8 = add_skip_layer< tag8, SUBNET>; + template using skip9 = add_skip_layer< tag9, SUBNET>; + template using skip10 = add_skip_layer; + +// ---------------------------------------------------------------------------------------- + + namespace timpl + { + inline void fill_with_gassuan_random_numbers ( + tensor& t, + dlib::rand& rnd, + double sigma = 1 + ) + { + float* data = t.host(); + for (size_t i = 0; i < t.size(); ++i) + data[i] = rnd.get_random_gaussian()*sigma; + } + + class test_layer_subnet + { + public: + test_layer_subnet ( + dlib::rand& rnd_ + ) : rnd(rnd_) + { + // Output and gradient_input have to have the same dimensions in each + // layer. + const long num_samples = rnd.get_random_32bit_number()%4+3; + const long k = rnd.get_random_32bit_number()%4+2; + const long nr = rnd.get_random_32bit_number()%4+2; + const long nc = rnd.get_random_32bit_number()%4+2; + + output.set_size(num_samples, k, nr, nc); + gradient_input.set_size(num_samples, k, nr, nc); + + // Use a non-zero initial gradient to make sure the layers add to it + // rather than assign and blow away the initial value. + fill_with_gassuan_random_numbers(gradient_input, rnd, 0.01); + + fill_with_gassuan_random_numbers(output, rnd); + } + + + tensor& get_mutable_output() { return output; } + const tensor& get_output() const { return output; } + const tensor& private_get_output() const { return get_output(); } + const test_layer_subnet& subnet() const { init_sub(); return *subnetwork; } + + tensor& get_gradient_input() { return gradient_input; } + tensor& private_get_gradient_input() { return get_gradient_input(); } + test_layer_subnet& subnet() { init_sub(); return *subnetwork; } + + + + unsigned long count_outputs() const + { + if (subnetwork) + return subnetwork->count_outputs() + output.size(); + else + return output.size(); + } + + float& get_output_element(unsigned long i) + { + if (i < output.size()) + return output.host()[i]; + else + return subnet().get_output_element(i-output.size()); + } + + float get_gradient_input_element(unsigned long i) const + { + if (i < gradient_input.size()) + return gradient_input.host()[i]; + else + return subnet().get_gradient_input_element(i-gradient_input.size()); + } + + + private: + // We lazily initialize sub-layers as needed when someone tries to call + // subnet() + void init_sub() const + { + if (!subnetwork) + subnetwork.reset(new test_layer_subnet(rnd)); + } + + dlib::rand& rnd; + mutable std::unique_ptr subnetwork; + resizable_tensor output; + resizable_tensor gradient_input; + }; + + } + + struct layer_test_results + { + layer_test_results() : was_good(true) {} + explicit layer_test_results(const std::string& l) : log(l),was_good(false) {} + + std::string log; + bool was_good; + + operator bool() const { return was_good; } + }; + + inline std::ostream& operator<< (std::ostream& out, const layer_test_results& item) + { + out << item.log; + return out; + } + + template < + typename layer_details_type + > + layer_test_results impl_test_layer ( + layer_details_type l, + const float base_eps + ) + { + using namespace timpl; + // Do some setup + running_stats rs_data, rs_params; + dlib::rand rnd; + std::ostringstream sout; + for (int iter = 0; iter < 10; ++iter) + { + test_layer_subnet subnetwork(rnd); + resizable_tensor output, out2, out3; + // Run setup() and forward() as well to make sure any calls to subnet() have + // happened before we start assuming we know how many data elements there are + // (since we do a lazy layer creation thing based on calls to subnet() inside + // test_layer_subnet). + l.setup(subnetwork); + impl::call_layer_forward(l, subnetwork, output); + + resizable_tensor input_grad; + input_grad.copy_size(output); + fill_with_gassuan_random_numbers(input_grad, rnd); + + + // The f() we are computing gradients of is this thing. It's value at the current + // parameter and data values is: + //sout << "f(data,params): " << dot(output, input_grad) << std::endl; + + // We are going to save a copy of the subnetwork.get_gradient_input() data before we do + // backpropagation since the backward() function is supposed to *add* to the + // gradients rather than overwrite them. We will use this saved data to check if + // that is the case. + const unsigned long num_data_inputs = subnetwork.count_outputs(); + std::vector initial_gradient_input(num_data_inputs); + for (unsigned long i = 0; i < num_data_inputs; ++i) + initial_gradient_input[i] = subnetwork.get_gradient_input_element(i); + + + // Now tell the layer to compute all the gradients. In the rest of this function + // we will just be checking that these gradients were computed correctly by + // comparing them to a central differences approximation. + resizable_tensor params_grad; + params_grad.copy_size(l.get_layer_params()); + // But first, set the params grad to something crazy so that it's very obvious if + // it doesn't get fully assigned. + params_grad = std::numeric_limits::infinity(); + impl::call_layer_backward(l, output, input_grad, subnetwork, params_grad); + + static_assert(impl::is_inplace_layer(l, subnetwork) == impl::has_inplace_backward(l, subnetwork), + "Layer not defined correctly. forward and backward methods must either both be in-place or both out-of-place. "); + + // Make sure the outputs of forward() and backward() are the same when they are run + // in in-place mode. + if (impl::is_inplace_layer(l, subnetwork)) + { + test_layer_subnet subnetwork2(rnd); + layer_details_type ll(l); + ll.setup(subnetwork2); + resizable_tensor ip_out; + impl::call_layer_forward(ll, subnetwork2, ip_out); + impl::call_layer_forward(ll, subnetwork2, subnetwork2.get_mutable_output()); + const auto forward_error = max(abs(mat(ip_out) - mat(subnetwork2.get_output()))); + if (forward_error > 0.00001) + { + using namespace std; + sout << "This layer is supposed to support in-place computations but the output of forward_inplace()\n"; + sout << "changes when invoked in-place vs. out-of-place. The error was: " << forward_error << endl; + return layer_test_results(sout.str()); + } + + resizable_tensor params_grad; + params_grad.copy_size(ll.get_layer_params()); + params_grad = std::numeric_limits::infinity(); + + resizable_tensor input_grad; + input_grad.copy_size(ip_out); + fill_with_gassuan_random_numbers(input_grad, rnd); + resizable_tensor params_grad1, params_grad2, data_grad1, data_grad2; + params_grad1 = params_grad; + params_grad2 = params_grad; + // Now call backward() and make sure it works as well. Recall that when an + // in-place layer works in-place it assigns to it's outputs but when it's + // not running in-place it adds. So we initialize to a non-zero value to + // check that this is the behavior that really executes. + subnetwork2.get_gradient_input() = 9; + impl::call_layer_backward(ll, ip_out, input_grad, subnetwork2, params_grad1); + data_grad1 = subnetwork2.get_gradient_input(); + + subnetwork2.get_gradient_input() = mat(input_grad); + impl::call_layer_backward(ll, ip_out, subnetwork2.get_gradient_input(), subnetwork2, params_grad2); + data_grad2 = subnetwork2.get_gradient_input(); + if (params_grad.size() != 0) + { + const auto backward_param_error = max(abs(mat(params_grad1) - mat(params_grad2))); + if (backward_param_error > 0.00001) + { + using namespace std; + sout << "This layer is supposed to support in-place computations but the output of backward_inplace()\n"; + sout << "changes when invoked in-place vs. out-of-place. The error was: " << backward_param_error << endl; + return layer_test_results(sout.str()); + } + } + const auto backward_data_error = max(abs(mat(data_grad1)-9 - mat(data_grad2))); + if (backward_data_error > 0.00001) + { + using namespace std; + sout << "This layer is supposed to support in-place computations but the output of backward_inplace()\n"; + sout << "changes when invoked in-place vs. out-of-place. The error was: " << backward_data_error << endl; + return layer_test_results(sout.str()); + } + } + + // ================================================================== + // first validate the way the parameter gradients are computed + for (unsigned long i = 0; i < params_grad.size(); ++i) + { + layer_details_type l1(l); + + float eps = l1.get_layer_params().host()[i]*base_eps; + if (eps == 0) + eps = base_eps; + const float oldval = l1.get_layer_params().host()[i]; + l1.get_layer_params().host()[i] = oldval+eps; + impl::call_layer_forward(l1, subnetwork, out2); + l1.get_layer_params().host()[i] = oldval-eps; + impl::call_layer_forward(l1, subnetwork, out3); + l1.get_layer_params().host()[i] = oldval; + + // Compute a reference derivative via a central differences approximation and + // compare it to the one output by the layer and make sure they match. + double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps); + double output_derivative = params_grad.host()[i]; + double relative_error; + if (reference_derivative*output_derivative != 0) + relative_error = (reference_derivative - output_derivative)/(reference_derivative); + else + relative_error = (reference_derivative - output_derivative); + double absolute_error = (reference_derivative - output_derivative); + rs_params.add(std::abs(relative_error)); + if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006) + { + using namespace std; + sout << "Gradient error in parameter #" << i <<". Relative error: "<< relative_error << endl; + sout << "expected derivative: " << reference_derivative << endl; + sout << "output derivative: " << output_derivative << endl; + sout << "iteration: " << iter << endl; + return layer_test_results(sout.str()); + } + } + + // ================================================================== + // now validate the data gradients + for (unsigned long i = 0; i < num_data_inputs; ++i) + { + const float oldval = subnetwork.get_output_element(i); + float eps = oldval*base_eps; + if (eps == 0) + eps = base_eps; + subnetwork.get_output_element(i) = oldval+eps; + impl::call_layer_forward(l, subnetwork, out2); + subnetwork.get_output_element(i) = oldval-eps; + impl::call_layer_forward(l, subnetwork, out3); + subnetwork.get_output_element(i) = oldval; + + // Compute a reference derivative via a central differences approximation and + // compare it to the one output by the layer and make sure they match. + double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps); + double output_derivative = subnetwork.get_gradient_input_element(i); + output_derivative -= initial_gradient_input[i]; + double relative_error; + if (reference_derivative*output_derivative != 0) + relative_error = (reference_derivative - output_derivative)/(reference_derivative); + else + relative_error = (reference_derivative - output_derivative); + double absolute_error = (reference_derivative - output_derivative); + rs_data.add(std::abs(relative_error)); + if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006) + { + using namespace std; + sout << "Gradient error in data variable #" << i <<". Relative error: "<< relative_error << endl; + sout << "expected derivative: " << reference_derivative << endl; + sout << "output derivative: " << output_derivative << endl; + sout << "iteration: " << iter << endl; + return layer_test_results(sout.str()); + } + } + + } // end for (int iter = 0; iter < 10; ++iter) + + if (rs_params.mean() > 0.003) + { + using namespace std; + sout << "Average parameter gradient error is somewhat large at: "<< rs_params.mean() << endl; + return layer_test_results(sout.str()); + } + if (rs_data.mean() > 0.003) + { + using namespace std; + sout << "Average data gradient error is somewhat large at: "<< rs_data.mean() << endl; + return layer_test_results(sout.str()); + } + + return layer_test_results(); + } + + template < + typename layer_details_type + > + layer_test_results test_layer ( + layer_details_type l + ) + { + // Try a few different derivative step sizes to see if any work. + for (float base_eps = 0.0001; base_eps < 0.1; base_eps *= 2) + { + auto result = impl_test_layer(l, base_eps); + if (result) + return result; + } + // However, if none of the step sizes worked then try this one and probably result + // in returning an error. + return impl_test_layer(l, 0.01); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + struct vlp_loop + { + template + static typename std::enable_if::value>::type invoke_functor(T&& , size_t& , U&& ) + { + // intentionally left empty + } + + template + static typename std::enable_if::value>::type invoke_functor(T&& v , size_t& comp_i, U&& l ) + { + v(comp_i, l.layer_details().get_layer_params()); + ++comp_i; + } + + template < + typename net_type, + typename visitor + > + static void visit( + size_t comp_i, + net_type& net, + visitor&& v + ) + { + invoke_functor(v, comp_i, layer(net)); + vlp_loop::visit(comp_i, net,v); + } + }; + + template + struct vlp_loop + { + template < + typename net_type, + typename visitor + > + static void visit( + size_t, + net_type&, + visitor&& + ) + { + // Base case of recursion. Don't do anything. + } + }; + + } + + template < + typename net_type, + typename visitor + > + void visit_layer_parameters( + net_type& net, + visitor v + ) + { + size_t comp_i = 0; + impl::vlp_loop<0, net_type::num_layers>::visit(comp_i, net, v); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + struct vlpg_loop + { + template + static typename std::enable_if::value>::type invoke_functor(T&& , size_t& , U&& ) + { + // intentionally left empty + } + + template + static typename std::enable_if::value>::type invoke_functor(T&& v , size_t& comp_i, U&& l ) + { + v(comp_i, l.get_parameter_gradient()); + ++comp_i; + } + + template < + typename net_type, + typename visitor + > + static void visit( + size_t comp_i, + net_type& net, + visitor&& v + ) + { + invoke_functor(v, comp_i, layer(net)); + vlpg_loop::visit(comp_i, net,v); + } + }; + + template + struct vlpg_loop + { + template < + typename net_type, + typename visitor + > + static void visit( + size_t, + net_type&, + visitor&& + ) + { + // Base case of recursion. Don't do anything. + } + }; + + } + + template < + typename net_type, + typename visitor + > + void visit_layer_parameter_gradients( + net_type& net, + visitor v + ) + { + size_t comp_i = 0; + impl::vlpg_loop<0, net_type::num_layers>::visit(comp_i, net, v); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + struct vl_loop + { + template < + typename net_type, + typename visitor + > + static void visit( + net_type& net, + visitor&& v + ) + { + v(i, layer(net)); + vl_loop::visit(net,v); + } + }; + + template + struct vl_loop + { + template < + typename net_type, + typename visitor + > + static void visit( + net_type&, + visitor&& + ) + { + // Base case of recursion. Don't do anything. + } + }; + + template + struct vl_loop_backwards + { + template < + typename net_type, + typename visitor + > + static void visit( + net_type& net, + visitor&& v + ) + { + vl_loop_backwards::visit(net,v); + v(i, layer(net)); + } + }; + + template + struct vl_loop_backwards + { + template < + typename net_type, + typename visitor + > + static void visit( + net_type&, + visitor&& + ) + { + // Base case of recursion. Don't do anything. + } + }; + + } + + template < + typename net_type, + typename visitor + > + void visit_layers( + net_type& net, + visitor v + ) + { + impl::vl_loop<0, net_type::num_layers>::visit(net, v); + } + + template < + typename net_type, + typename visitor + > + void visit_layers_backwards( + net_type& net, + visitor v + ) + { + impl::vl_loop_backwards<0, net_type::num_layers>::visit(net, v); + } + + template < + size_t begin, + size_t end, + typename net_type, + typename visitor + > + void visit_layers_range( + net_type& net, + visitor v + ) + { + static_assert(begin <= end, "Invalid range"); + static_assert(end <= net_type::num_layers, "Invalid range"); + impl::vl_loop::visit(net, v); + } + + template < + size_t begin, + size_t end, + typename net_type, + typename visitor + > + void visit_layers_backwards_range( + net_type& net, + visitor v + ) + { + static_assert(begin <= end, "Invalid range"); + static_assert(end <= net_type::num_layers, "Invalid range"); + impl::vl_loop_backwards::visit(net, v); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + struct vl_until_tag + { + template < + typename net_type, + typename next_net_type, + typename visitor + > + static void visit( + net_type& net, + next_net_type& next_net, + visitor&& v + ) + { + v(next_net); + vl_until_tag::visit(net,layer(net),v); + } + + template < + typename net_type, + typename SUBNET, + typename visitor + > + static void visit( + net_type& net, + const add_tag_layer& next_net, + visitor&& v + ) + { + v(next_net); + } + + template < + typename net_type, + typename SUBNET, + typename visitor + > + static void visit( + net_type& net, + add_tag_layer& next_net, + visitor&& v + ) + { + v(next_net); + } + }; + } + + template < + unsigned long tag_id, + typename net_type, + typename visitor + > + void visit_layers_until_tag( + net_type& net, + visitor v + ) + { + impl::vl_until_tag<0,tag_id>::visit(net, net, v); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_CORE_H_ + + diff --git a/ml/dlib/dlib/dnn/core_abstract.h b/ml/dlib/dlib/dnn/core_abstract.h new file mode 100644 index 000000000..db168a88b --- /dev/null +++ b/ml/dlib/dlib/dnn/core_abstract.h @@ -0,0 +1,1700 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNn_CORE_ABSTRACT_H_ +#ifdef DLIB_DNn_CORE_ABSTRACT_H_ + +#include "tensor_abstract.h" +#include +#include +#include +#include +#include "../rand.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename... T + > + auto tuple_tail( + const std::tuple& item + ); + /*! + ensures + - returns a tuple that contains everything in item except for tuple_head(item). + The items will be in the same order as they are in item, just without + tuple_head(item). + - This function will correctly handle nested tuples. + !*/ + + template + auto tuple_head ( + const std::tuple& item + ); + /*! + ensures + - returns a copy of the first thing in the tuple that isn't a std::tuple. + Essentially, this function calls std::get<0>() recursively on item until + a non-std::tuple object is found. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + double get_learning_rate_multiplier( + const T& obj + ); + /*! + ensures + - if (obj has a get_learning_rate_multiplier() member function) then + - returns obj.get_learning_rate_multiplier() + - else + - returns 1 + !*/ + + template + double get_weight_decay_multiplier( + const T& obj + ); + /*! + ensures + - if (obj has a get_weight_decay_multiplier() member function) then + - returns obj.get_weight_decay_multiplier() + - else + - returns 1 + !*/ + +// ---------------------------------------------------------------------------------------- + + bool dnn_prefer_fastest_algorithms( + ); + /*! + ensures + - If dlib should prefer to use fast algorithms rather than ones that use less + RAM then this function returns true and false otherwise. + - On program startup this function will default to true. + !*/ + + void set_dnn_prefer_fastest_algorithms( + ); + /*! + ensures + - #dnn_prefer_fastest_algorithms() == true + !*/ + + void set_dnn_prefer_smallest_algorithms( + ); + /*! + ensures + - #dnn_prefer_fastest_algorithms() == false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class sstack + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a basic stack of T objects. It contains no data itself but simply + points to a memory range of T object and allows you to access that block of + T objects as a stack. + !*/ + + public: + typedef T value_type; + + sstack() = delete; + + sstack ( + T* data, + size_t s + ); + /*! + ensures + - #size() == s + - #top() == *data + - #pop(i).top() == data[i] + !*/ + + const T& top( + ) const; + /*! + requires + - size() != 0 + ensures + - returns the top element of the stack. + !*/ + + T& top( + ); + /*! + requires + - size() != 0 + ensures + - returns the top element of the stack. + !*/ + + size_t size( + ) const; + /*! + ensures + - returns the number of elements in this stack. + !*/ + + sstack pop( + size_t num = 1 + ); + /*! + requires + - num <= size() + ensures + - returns a reference to the sub-stack S such that: + - S.size() == size()-num. + - S.top() is num elements down the stack. + !*/ + }; + + template < + typename T + > + sstack make_sstack( + std::vector& item + ) { return sstack(item.data(), item.size()); } + /*! + ensures + - returns a sstack that sits on top of the given std::vector. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename LAYER_DETAILS, + typename SUBNET + > + class add_layer + { + /*! + REQUIREMENTS ON LAYER_DETAILS + - Must be a type that implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined in layers_abstract.h + + REQUIREMENTS ON SUBNET + - One of the following must be true: + - SUBNET implements the EXAMPLE_INPUT_LAYER interface defined in + input_abstract.h. + - SUBNET is an add_layer object. + - SUBNET is an add_tag_layer object. + - SUBNET is an add_skip_layer object. + - SUBNET is a repeat object. + + WHAT THIS OBJECT REPRESENTS + This object represents a deep neural network. In particular, it is a tool + for adding another layer on top of the neural network of type SUBNET, which + is specified as a template argument. The specific layer added is defined + by the LAYER_DETAILS details template argument. + !*/ + + public: + typedef LAYER_DETAILS layer_details_type; + typedef SUBNET subnet_type; + typedef typename subnet_type::input_type input_type; + // num_computational_layers will always give the number of layers in the network + // that transform tensors (i.e. layers defined by something that implements the + // EXAMPLE_COMPUTATIONAL_LAYER_ interface). This is all the layers except for + // loss, tag, and skip layers. + const static size_t num_computational_layers = subnet_type::num_computational_layers + 1; + // num_layers counts all the layers in the network regardless of their type. + const static size_t num_layers = subnet_type::num_layers + 1; + + add_layer( + ); + /*! + ensures + - default constructs all the layers in this network. + - #sample_expansion_factor() == 0 + !*/ + + add_layer(const add_layer&) = default; + add_layer(add_layer&&) = default; + add_layer& operator=(add_layer&&) = default; + add_layer& operator=(const add_layer&) = default; + /*! + ensures + - this object is copyable and movable. + !*/ + + template + add_layer( + const add_layer& item + ); + /*! + ensures + - This constructor allows you to copy neural network objects from one to + another as long as their corresponding layers can be constructed from + each other. + - #layer_details() == layer_details_type(item.layer_details()) + - #subnet() == subnet_type(item.subnet()) + - #sample_expansion_factor() == item.sample_expansion_factor() + !*/ + + template + add_layer( + const std::tuple& layer_det, + T&& ...args + ); + /*! + ensures + - #layer_details() == layer_details_type(tuple_head(layer_det)) + - #subnet() == subnet_type(tuple_tail(layer_det),args) + - #sample_expansion_factor() == 0 + !*/ + + template + add_layer( + const layer_details_type& layer_det, + T&& ...args + ); + /*! + ensures + - #layer_details() == layer_details_type(layer_det) + - #subnet() == subnet_type(args) + - #sample_expansion_factor() == 0 + !*/ + + template + add_layer( + T&& ...args + ); + /*! + ensures + - This version of the constructor is only called if layer_details_type + can't be constructed from the first thing in args. In this case, the + args are simply passed on to the sub layers in their entirety. + - #layer_details() == layer_details_type() + - #subnet() == subnet_type(args) + - #sample_expansion_factor() == 0 + !*/ + + template + add_layer( + layer_details_type&& layer_det, + T&& ...args + ); + /*! + ensures + - #layer_details() == layer_det + - #subnet() == subnet_type(args) + - #sample_expansion_factor() == 0 + !*/ + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const; + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + ensures + - Converts the iterator range into a tensor and stores it into #data. + - #data.num_samples()%distance(ibegin,iend) == 0. + - #sample_expansion_factor() == #data.num_samples()/distance(ibegin,iend). + - #sample_expansion_factor() > 0 + - The data in the ith sample of #data corresponds to the input_type object + *(ibegin+i/#sample_expansion_factor()). + - Invokes data.async_copy_to_device() so that the data begins transferring + to the GPU device, if present. + - This function is implemented by calling the to_tensor() routine defined + at the input layer of this network. + !*/ + + unsigned int sample_expansion_factor ( + ) const; + /*! + ensures + - When to_tensor() is invoked on this network's input layer it converts N + input objects into M samples, all stored inside a resizable_tensor. It + is always the case that M is some integer multiple of N. + sample_expansion_factor() returns the value of this multiplier. To be + very specific, it is always true that M==I*N where I is some integer. + This integer I is what is returned by sample_expansion_factor(). + !*/ + + const subnet_type& subnet( + ) const; + /*! + ensures + - returns the immediate subnetwork of *this network. + !*/ + + subnet_type& subnet( + ); + /*! + ensures + - returns the immediate subnetwork of *this network. + !*/ + + const layer_details_type& layer_details( + ) const; + /*! + ensures + - returns the layer_details_type instance that defines the behavior of the + layer at the top of this network. I.e. returns the layer details that + defines the behavior of the layer nearest to the network output rather + than the input layer. + !*/ + + layer_details_type& layer_details( + ); + /*! + ensures + - returns the layer_details_type instance that defines the behavior of the + layer at the top of this network. I.e. returns the layer details that + defines the behavior of the layer nearest to the network output rather + than the input layer. + !*/ + + template + const tensor& operator() ( + forward_iterator ibegin, + forward_iterator iend + ); + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + ensures + - runs [ibegin,iend) through the network and returns the results. + In particular, this function performs: + to_tensor(ibegin,iend,temp_tensor); + return forward(temp_tensor); + - The return value from this function is also available in #get_output(). + i.e. this function returns #get_output(). + - have_same_dimensions(#get_gradient_input(), #get_output()) == true. + - All elements of #get_gradient_input() are set to 0. + i.e. calling this function clears out #get_gradient_input() and ensures + it has the same dimensions as the most recent output. + !*/ + + const tensor& operator() ( + const input_type& x + ); + /*! + ensures + - runs a single x through the network and returns the output. + I.e. returns (*this)(&x, &x+1); + !*/ + + const tensor& forward( + const tensor& x + ); + /*! + requires + - sample_expansion_factor() != 0 + (i.e. to_tensor() must have been called to set sample_expansion_factor() + to something non-zero.) + - x.num_samples()%sample_expansion_factor() == 0 + - x.num_samples() > 0 + ensures + - Runs x through the network and returns the results. In particular, this + function performs the equivalent of: + subnet().forward(x); + if (this is the first time forward() has been called) then + layer_details().setup(subnet()); + layer_details().forward(subnet(), get_output()); + - The return value from this function is also available in #get_output(). + i.e. this function returns #get_output(). + - have_same_dimensions(#get_gradient_input(), #get_output()) == true + - All elements of #get_gradient_input() are set to 0. + i.e. calling this function clears out #get_gradient_input() and ensures + it has the same dimensions as the most recent output. + !*/ + + const tensor& get_output( + ) const; + /*! + ensures + - returns the output for the last tensor that was run through the network. + If nothing has been run through the network yet then returns an empty + tensor. + !*/ + + tensor& get_gradient_input( + ); + /*! + ensures + - returns the error gradient for this network. That is, this is the error + gradient that this network will use to compute parameter gradients when + back_propagate_error() is called. Therefore, when performing back + propagation, layers that sit on top of this network layer write their + back-propagated error gradients into get_gradient_input(). Or to put it + another way, during back-propagation, layers take the contents of their + get_gradient_input() and back-propagate it through themselves and store + the result into their subnetwork's get_gradient_input(). + + This means you should consider get_gradient_input() as an input to the + back_propagate_error() method. + !*/ + + const tensor& get_final_data_gradient( + ) const; + /*! + ensures + - if back_propagate_error() has been called to back-propagate a gradient + through this network then you can call get_final_data_gradient() to + obtain the last data gradient computed. That is, this function returns + the gradient of the network with respect to its inputs. + - Note that there is only one "final data gradient" for an entire network, + not one per layer, since there is only one input to the entire network. + !*/ + + const tensor& get_parameter_gradient( + ) const; + /*! + ensures + - if back_propagate_error() has been called then you can call + get_parameter_gradient() to find the gradient of this layer's parameters. + When we update the parameters by calling update_parameters(), it will use + the gradient in get_parameter_gradient() to perform the update. + Therefore, you should consider get_parameter_gradient() as an input to + update_parameters(). + !*/ + + tensor& get_parameter_gradient ( + ); + /*! + ensures + - returns a non-const reference to the tensor returned by the above + get_parameter_gradient() method. You could use this method to modify the + parameter gradient in some way before invoking update_parameters(). + !*/ + + void back_propagate_error( + const tensor& x + ); + /*! + requires + - forward(x) was called to forward propagate x though the network. + Moreover, this was the most recent call to forward() and x has not been + subsequently modified in any way. + - get_gradient_input() has been set equal to the gradient of this network's + output with respect to some loss function. + ensures + - Back propagates the error gradient, get_gradient_input(), through this + network and computes parameter and data gradients, via backpropagation. + Specifically, this function populates get_final_data_gradient() and also, + for each layer, the tensor returned by get_parameter_gradient(). + - All elements of #get_gradient_input() are set to 0. + - have_same_dimensions(#get_final_data_gradient(), x) == true. + - have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true. + - #get_final_data_gradient() contains the gradient of the network with + respect to x. + !*/ + + void back_propagate_error( + const tensor& x, + const tensor& gradient_input + ); + /*! + requires + - forward(x) was called to forward propagate x though the network. + Moreover, this was the most recent call to forward() and x has not been + subsequently modified in any way. + - have_same_dimensions(gradient_input, get_output()) == true + ensures + - This function is identical to the version of back_propagate_error() + defined immediately above except that it back-propagates gradient_input + through the network instead of get_gradient_input(). Therefore, this + version of back_propagate_error() is equivalent to performing: + get_gradient_input() = gradient_input; + back_propagate_error(x); + Except that calling back_propagate_error(x,gradient_input) avoids the + copy and is therefore slightly more efficient. + - All elements of #get_gradient_input() are set to 0. + - have_same_dimensions(#get_final_data_gradient(), x) == true. + - have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true. + - #get_final_data_gradient() contains the gradient of the network with + respect to x. + !*/ + + template + void update_parameters( + sstack solvers, + double learning_rate + ); + /*! + requires + - solver_type is an implementation of the EXAMPLE_SOLVER interface defined + in solvers_abstract.h + - back_propagate_error() has been called. + - The given solvers have only ever been used with this network. That is, + if you want to call update_parameters() on some other neural network + object then you must NOT reuse the same solvers object. + - solvers.size() >= num_computational_layers + - 0 < learning_rate <= 1 + ensures + - Updates all the parameters in the network. In particular, we pass each + layer's parameter gradient (i.e. the tensor returned by the layer's + get_parameter_gradient() member) through that layer's corresponding + solver object. This produces a parameter delta vector which we add to + the layer's parameters. + - The solvers use the given learning rate. + !*/ + + void clean( + ); + /*! + ensures + - Causes the network to forget about everything but its parameters. + That is, for each layer we will have: + - get_output().num_samples() == 0 + - get_gradient_input().num_samples() == 0 + However, running new input data though this network will still produce + the same output it would have produced regardless of any calls to + clean(). The purpose of clean() is to compact the network object prior + to saving it to disk so that it takes up less space and the IO is + quicker. + - This also calls the .clean() method on any layer details objects that + define a .clean() method. + !*/ + + }; + + template + std::ostream& operator<<(std::ostream& out, const add_layer& item); + /*! + prints the network architecture to the given output stream. + !*/ + + template + void serialize(const add_layer& item, std::ostream& out); + template + void deserialize(add_layer& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class no_label_type; + + template < + typename LOSS_DETAILS, + typename SUBNET + > + class add_loss_layer + { + /*! + REQUIREMENTS ON LOSS_DETAILS + - Must be a type that implements the EXAMPLE_LOSS_LAYER_ interface defined + in loss_abstract.h + + REQUIREMENTS ON SUBNET + - One of the following must be true: + - SUBNET is an add_layer object. + - SUBNET is an add_tag_layer object. + - SUBNET is an add_skip_layer object. + - SUBNET is a repeat object. + + WHAT THIS OBJECT REPRESENTS + This object represents a deep neural network. In particular, it is a tool + for adding a loss layer on top of the neural network of type SUBNET, which + is specified as a template argument. The specific layer added is defined + by the LOSS_DETAILS details template argument. Importantly, a loss layer + is the last layer in a deep neural network. So once it is added you can't + add any other layers of any type. + !*/ + + public: + typedef LOSS_DETAILS loss_details_type; + typedef SUBNET subnet_type; + typedef typename subnet_type::input_type input_type; + const static size_t num_computational_layers = subnet_type::num_computational_layers; + const static size_t num_layers = subnet_type::num_layers + 1; + // If LOSS_DETAILS is an unsupervised loss then training_label_type==no_label_type. + // Otherwise it is defined as follows: + typedef typename LOSS_DETAILS::training_label_type training_label_type; + // Similarly, if LOSS_DETAILS doesn't provide any output conversion then + // output_label_type==no_label_type. + typedef typename LOSS_DETAILS::output_label_type output_label_type; + + + + add_loss_layer() = default; + /*! + ensures + - default constructs all the layers in this network. + !*/ + + add_loss_layer(const add_loss_layer&) = default; + add_loss_layer(add_loss_layer&&) = default; + add_loss_layer& operator=(add_loss_layer&&) = default; + add_loss_layer& operator=(const add_loss_layer&) = default; + /*! + ensures + - this object is copyable and movable. + !*/ + + template + add_loss_layer( + const add_loss_layer& item + ); + /*! + ensures + - This constructor allows you to copy neural network objects from one to + another as long as their corresponding layers can be constructed from + each other. + - #loss_details() == loss_details_type(item.loss_details()) + - #subnet() == subnet_type(item.subnet()) + !*/ + + template + add_loss_layer( + const LOSS_DETAILS& layer_det, + T&& ...args + ); + /*! + ensures + - #loss_details() == loss_details_type(layer_det) + - #subnet() == subnet_type(args) + !*/ + + template + add_loss_layer( + LOSS_DETAILS&& layer_det, + T&& ...args + ); + /*! + ensures + - #loss_details() == loss_details_type(layer_det) + - #subnet() == subnet_type(args) + !*/ + + template + add_loss_layer( + T&& ...args + ); + /*! + ensures + - This version of the constructor is only called if loss_details_type can't + be constructed from the first thing in args. In this case, the args are + simply passed on to the sub layers in their entirety. + - #loss_details() == loss_details_type() + - #subnet() == subnet_type(args) + !*/ + + const subnet_type& subnet( + ) const; + /*! + ensures + - returns the immediate subnetwork of *this network. + !*/ + + subnet_type& subnet( + ); + /*! + ensures + - returns the immediate subnetwork of *this network. + !*/ + + const loss_details_type& loss_details( + ) const; + /*! + ensures + - returns the loss_details_type instance that defines the behavior of the + loss layer used by this network. + !*/ + + loss_details_type& loss_details( + ); + /*! + ensures + - returns the loss_details_type instance that defines the behavior of the + loss layer used by this network. + !*/ + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const; + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + ensures + - Converts the iterator range into a tensor and stores it into #data. + - #data.num_samples()%distance(ibegin,iend) == 0. + - #sample_expansion_factor() == #data.num_samples()/distance(ibegin,iend). + - #sample_expansion_factor() > 0 + - The data in the ith sample of #data corresponds to the input_type object + *(ibegin+i/sample_expansion_factor()). + - Invokes data.async_copy_to_device() so that the data begins transferring + to the GPU device, if present. + - This function is implemented by calling the to_tensor() routine defined + at the input layer of this network. + !*/ + + unsigned int sample_expansion_factor ( + ) const; + /*! + ensures + - When to_tensor() is invoked on this network's input layer it converts N + input objects into M samples, all stored inside a resizable_tensor. It + is always the case that M is some integer multiple of N. + sample_expansion_factor() returns the value of this multiplier. To be + very specific, it is always true that M==I*N where I is some integer. + This integer I is what is returned by sample_expansion_factor(). + !*/ + + // ------------- + + template + void operator() ( + const tensor& x, + output_iterator obegin + ); + /*! + requires + - sample_expansion_factor() != 0 + (i.e. to_tensor() must have been called to set sample_expansion_factor() + to something non-zero.) + - x.num_samples()%sample_expansion_factor() == 0 + - x.num_samples() > 0 + - obegin == iterator pointing to the start of a range of + x.num_samples()/sample_expansion_factor() output_label_type elements. + ensures + - runs x through the network and writes the output to the range at obegin. + - loss_details().to_label() is used to write the network output into + obegin. + !*/ + + template + void operator() ( + forward_iterator ibegin, + forward_iterator iend, + label_iterator obegin + ); + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + - obegin == iterator pointing to the start of a range of + std::distance(ibegin,iend) output_label_type elements. + ensures + - runs [ibegin,iend) through the network and writes the output to the range + at obegin. + - loss_details().to_label() is used to write the network output into + obegin. + !*/ + + // ------------- + + const output_label_type& operator() ( + const input_type& x + ); + /*! + ensures + - runs a single object, x, through the network and returns the output. + - loss_details().to_label() is used to convert the network output into a + output_label_type. + !*/ + + template + std::vector operator() ( + const iterable_type& data, + size_t batch_size = 128 + ); + /*! + requires + - batch_size > 0 + - data must have a .begin() and .end() that supply iterators over a + sequence of input_type elements. E.g. data could have a type of + std::vector + ensures + - runs all the objects in data through the network and returns their + predicted labels. This means this function returns a vector V such that: + - V.size() == data.size() + - for all valid i: V[i] == the predicted label of data[i]. + - Elements of data are run through the network in batches of batch_size + items. Using a batch_size > 1 can be faster because it better exploits + the available hardware parallelism. + - loss_details().to_label() is used to convert the network output into a + output_label_type. + !*/ + + template + const output_label_type& process ( + const input_type& x, + T&& ...args + ); + /*! + ensures + - This function is just like (*this)(x), i.e. it runs a single object, x, + through the network and returns the output. But we additionally pass the + given args to loss_details().to_label() as the 4th argument (or more, + depending on how many things are in args) when converting the network + output to an output_label_type. This is useful, for instance, with loss + layers like loss_mmod_ which has an optional adjust_threshold argument to + to_label() that adjusts the detection threshold. Therefore, for such + networks you could call them like: net.process(some_image, -0.5), and -0.5 + would be passed so the adjust_threshold argument of to_tensor(). + !*/ + + template + std::vector process_batch ( + const iterable_type& data, + size_t batch_size, + T&& ...args + ); + /*! + requires + - batch_size > 0 + - data must have a .begin() and .end() that supply iterators over a + sequence of input_type elements. E.g. data could have a type of + std::vector + ensures + - This function is just like (*this)(data,batch_size), i.e. it runs a + bunch of objects through the network and returns the outputs. But we + additionally pass the given args to loss_details().to_label() as the 4th + argument (or more, depending on how many things are in args) when + converting the network output to output_label_types. This is useful, + for instance, with loss layers like loss_mmod_ which has an optional + adjust_threshold argument to to_label() that adjusts the detection + threshold. Therefore, for such networks you could call them like: + net.process_batch(std::vector({some_image, another_image}), 128, -0.5), + and -0.5 would be passed so the adjust_threshold argument of to_tensor(). + !*/ + + // ------------- + + template + double compute_loss ( + const tensor& x, + label_iterator lbegin + ); + /*! + requires + - sample_expansion_factor() != 0 + (i.e. to_tensor() must have been called to set sample_expansion_factor() + to something non-zero.) + - x.num_samples()%sample_expansion_factor() == 0 + - x.num_samples() > 0 + - lbegin == iterator pointing to the start of a range of + x.num_samples()/sample_expansion_factor() training_label_type elements. + ensures + - runs x through the network, compares the output to the expected output + pointed to by lbegin, and returns the resulting loss. + - for all valid k: + - the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()). + - This function does not update the network parameters. + !*/ + + template + double compute_loss ( + forward_iterator ibegin, + forward_iterator iend, + label_iterator lbegin + ); + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + - lbegin == iterator pointing to the start of a range of + std::distance(ibegin,iend) training_label_type elements. + ensures + - runs [ibegin,iend) through the network, compares the output to the + expected output pointed to by lbegin, and returns the resulting loss. + - for all valid k: + - the expected label of *(ibegin+k) is *(lbegin+k). + - This function does not update the network parameters. + !*/ + + // ------------- + + double compute_loss ( + const tensor& x + ); + /*! + requires + - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type. + - sample_expansion_factor() != 0 + (i.e. to_tensor() must have been called to set sample_expansion_factor() + to something non-zero.) + - x.num_samples()%sample_expansion_factor() == 0 + - x.num_samples() > 0 + ensures + - runs x through the network and returns the resulting loss. + - This function does not update the network parameters. + !*/ + + template + double compute_loss ( + forward_iterator ibegin, + forward_iterator iend, + ); + /*! + requires + - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type. + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + ensures + - runs [ibegin,iend) through the network and returns the resulting loss. + - This function does not update the network parameters. + !*/ + + // ------------- + + template + double compute_parameter_gradients ( + const tensor& x, + label_iterator lbegin + ); + /*! + requires + - sample_expansion_factor() != 0 + (i.e. to_tensor() must have been called to set sample_expansion_factor() + to something non-zero.) + - x.num_samples()%sample_expansion_factor() == 0 + - x.num_samples() > 0 + - lbegin == iterator pointing to the start of a range of + x.num_samples()/sample_expansion_factor() training_label_type elements. + ensures + - runs x through the network, compares the output to the expected output + pointed to by lbegin, and computes parameter and data gradients with + respect to the loss, via backpropagation. Specifically, this function + updates get_final_data_gradient() and also, for each layer, the tensor + returned by get_parameter_gradient(). + - for all valid k: + - the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()). + - returns compute_loss(x,lbegin) + !*/ + + template + double compute_parameter_gradients ( + forward_iterator ibegin, + forward_iterator iend, + label_iterator lbegin + ); + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + - lbegin == iterator pointing to the start of a range of + std::distance(ibegin,iend) training_label_type elements. + ensures + - runs [ibegin,iend) through the network, compares the output to the + expected output pointed to by lbegin, and computes parameter and data + gradients with respect to the loss, via backpropagation. Specifically, + this function updates get_final_data_gradient() and also, for each layer, + the tensor returned by get_parameter_gradient(). + - for all valid k: + - the expected label of *(ibegin+k) is *(lbegin+k). + - returns compute_loss(ibegin,iend,lbegin) + !*/ + + double compute_parameter_gradients ( + const tensor& x + ); + /*! + requires + - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type. + - sample_expansion_factor() != 0 + (i.e. to_tensor() must have been called to set sample_expansion_factor() + to something non-zero.) + - x.num_samples()%sample_expansion_factor() == 0 + - x.num_samples() > 0 + ensures + - runs x through the network and computes parameter and data gradients with + respect to the loss, via backpropagation. Specifically, this function + updates get_final_data_gradient() and also, for each layer, the tensor + returned by get_parameter_gradient(). + - returns compute_loss(x) + !*/ + + template + double compute_parameter_gradients ( + forward_iterator ibegin, + forward_iterator iend + ); + /*! + requires + - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type. + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + ensures + - runs [ibegin,iend) through the network and computes parameter and data + gradients with respect to the loss, via backpropagation. Specifically, + this function updates get_final_data_gradient() and also, for each layer, + the tensor returned by get_parameter_gradient(). + - returns compute_loss(ibegin,iend) + !*/ + + template + void update_parameters ( + sstack solvers, + double learning_rate + ); + /*! + requires + - solver_type is an implementation of the EXAMPLE_SOLVER interface defined + in solvers_abstract.h + - compute_parameter_gradients() has been called. + - The given solvers have only ever been used with this network. That + is, if you want to call update_parameters() on some other neural network + object then you must NOT reuse the same solvers object. + - solvers.size() >= num_computational_layers + - 0 < learning_rate <= 1 + ensures + - Updates all the parameters in the network. In particular, we pass each + layer's parameter gradient (i.e. the tensor returned by the layer's + get_parameter_gradient() member) through that layer's corresponding + solver object. This produces a parameter delta vector which we add to + the layer's parameters. + - The solvers use the given learning rate. + !*/ + + // ------------- + + void clean ( + ); + /*! + ensures + - Causes the network to forget about everything but its parameters. + - invokes subnet().clean() + !*/ + }; + + template + std::ostream& operator<<(std::ostream& out, const add_loss_layer& item); + /*! + prints the network architecture to the given output stream. + !*/ + + template + void serialize(const add_loss_layer& item, std::ostream& out); + template + void deserialize(add_loss_layer& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + decorator_repeat_group repeat_group ( + T&& ...args + ); + /*! + ensures + - Decorates a group of variables. This is essentially like std::make_tuple() + except it's only purpose is to group variables together so they can be passed + to the repeat object's constructor. + !*/ + + template < + size_t num, + template class REPEATED_LAYER, + typename SUBNET + > + class repeat + { + /*! + REQUIREMENTS ON num + - num > 0 + + REQUIREMENTS ON REPEATED_LAYER + - REPEATED_LAYER must be a template that stacks more layers onto a deep neural + network. For example, if net_type were a network without a loss layer, + then it should be legal to create a deeper network with a type of + REPEATED_LAYER. + + REQUIREMENTS ON SUBNET + - One of the following must be true: + - SUBNET is an add_layer object. + - SUBNET is an add_tag_layer object. + - SUBNET is an add_skip_layer object. + - SUBNET is a repeat object. + + WHAT THIS OBJECT REPRESENTS + This object adds more layers to a deep neural network. In particular, it + adds REPEATED_LAYER on top of SUBNET num times. So for example, if num were 2 then + repeat<2,REPEATED_LAYER,SUBNET> would create a network equivalent to REPEATED_LAYER>. + + Also, this object provides an interface identical to the one defined by the + add_layer object except that we add the num_repetitions() and + get_repeated_layer() methods. These additions are shown below along with + some additional explanatory comments. + !*/ + + public: + + typedef SUBNET subnet_type; + typedef typename SUBNET::input_type input_type; + const static size_t num_computational_layers = (REPEATED_LAYER::num_computational_layers-SUBNET::num_computational_layers)*num + SUBNET::num_computational_layers; + const static size_t num_layers = (REPEATED_LAYER::num_layers-SUBNET::num_layers)*num + SUBNET::num_layers; + typedef REPEATED_LAYER repeated_layer_type; + + template + repeat( + T arg1, + U ...args2 + ); + /*! + ensures + - arg1 is used to initialize the num_repetitions() copies of REPEATED_LAYER inside + this object. That is, all the REPEATED_LAYER elements are initialized identically + by being given copies of arg1. + - The rest of the arguments to the constructor, i.e. args2, are passed to + SUBNET's constructor. + !*/ + + template + repeat( + decorator_repeat_group&& arg1, + U ...args2 + ); + /*! + ensures + - arg1 is used to initialize the num_repetitions() copies of REPEATED_LAYER inside + this object. That is, all the REPEATED_LAYER elements are initialized identically + by being given copies of an undecorated arg1. + - The rest of the arguments to the constructor, i.e. args2, are passed to + SUBNET's constructor. + !*/ + + size_t num_repetitions ( + ) const; + /*! + ensures + - returns num (i.e. the number of times REPEATED_LAYER was stacked on top of SUBNET) + !*/ + + const repeated_layer_type& get_repeated_layer ( + size_t i + ) const; + /*! + requires + - i < num_repetitions() + ensures + - returns a reference to the i-th instance of REPEATED_LAYER. For example, + get_repeated_layer(0) returns the instance of REPEATED_LAYER that is on the top of + the network while get_repeated_layer(num_repetitions()-1) returns the + instance of REPEATED_LAYER that is stacked immediately on top of SUBNET. + !*/ + + repeated_layer_type& get_repeated_layer ( + size_t i + ); + /*! + requires + - i < num_repetitions() + ensures + - returns a reference to the i-th instance of REPEATED_LAYER. For example, + get_repeated_layer(0) returns the instance of REPEATED_LAYER that is on the top of + the network while get_repeated_layer(num_repetitions()-1) returns the + instance of REPEATED_LAYER that is stacked immediately on top of SUBNET. + !*/ + + const subnet_type& subnet( + ) const; + /*! + ensures + - returns the SUBNET base network that repeat sits on top of. If you want + to access the REPEATED_LAYER components then you must use get_repeated_layer(). + !*/ + + subnet_type& subnet( + ); + /*! + ensures + - returns the SUBNET base network that repeat sits on top of. If you want + to access the REPEATED_LAYER components then you must use get_repeated_layer(). + !*/ + }; + + template < size_t num, template class T, typename U > + std::ostream& operator<<(std::ostream& out, const repeat& item); + /*! + prints the network architecture to the given output stream. + !*/ + + template < size_t num, template class T, typename U > + void serialize(const repeat& item, std::ostream& out); + template < size_t num, template class T, typename U > + void deserialize(repeat& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long ID, + typename SUBNET + > + class add_tag_layer + { + /*! + REQUIREMENTS ON SUBNET + - One of the following must be true: + - SUBNET implements the EXAMPLE_INPUT_LAYER interface defined in + input_abstract.h. + - SUBNET is an add_layer object. + - SUBNET is an add_tag_layer object. + - SUBNET is an add_skip_layer object. + - SUBNET is a repeat object. + + WHAT THIS OBJECT REPRESENTS + This object adds a new layer to a deep neural network. However, this layer + simply performs the identity transform. This means it is a no-op and its + presence does not change the behavior of the network. It exists solely to + be used by add_skip_layer to reference a particular part of a network. + + Also, this object provides an interface identical to the one defined by the + add_layer object. + !*/ + }; + + template + std::ostream& operator<<(std::ostream& out, const add_tag_layer& item); + /*! + prints the network architecture to the given output stream. + !*/ + + template + void serialize(const add_tag_layer& item, std::ostream& out); + template + void deserialize(add_tag_layer& item, std::istream& in); + /*! + provides serialization support + !*/ + + template using tag1 = add_tag_layer< 1, SUBNET>; + template using tag2 = add_tag_layer< 2, SUBNET>; + template using tag3 = add_tag_layer< 3, SUBNET>; + template using tag4 = add_tag_layer< 4, SUBNET>; + template using tag5 = add_tag_layer< 5, SUBNET>; + template using tag6 = add_tag_layer< 6, SUBNET>; + template using tag7 = add_tag_layer< 7, SUBNET>; + template using tag8 = add_tag_layer< 8, SUBNET>; + template using tag9 = add_tag_layer< 9, SUBNET>; + template using tag10 = add_tag_layer<10, SUBNET>; + + template class tag> + struct tag_id + { + /*! + REQUIREMENTS ON tag + Tag should be an add_tag_layer template such as tag1, tag2, etc. + + WHAT THIS OBJECT REPRESENTS + This is a tool for finding the numeric ID of a tag layer. For example, + tag_id::id == 3. + !*/ + + const static unsigned long id; + }; + +// ---------------------------------------------------------------------------------------- + + template < + template class TAG_TYPE, + typename SUBNET + > + class add_skip_layer + { + /*! + REQUIREMENTS ON SUBNET + - One of the following must be true: + - SUBNET is an add_layer object. + - SUBNET is an add_tag_layer object. + - SUBNET is an add_skip_layer object. + - SUBNET is a repeat object. + + WHAT THIS OBJECT REPRESENTS + This object adds a new layer to a deep neural network which draws its + inputs from layer(subnet()) and performs the identity transform. + + Also, this object provides an interface identical to the one defined by the + add_layer object. + !*/ + }; + + template class T, typename U> + std::ostream& operator<<(std::ostream& out, const add_skip_layer& item); + /*! + prints the network architecture to the given output stream. + !*/ + + template class T, typename U> + void serialize(const add_skip_layer& item, std::ostream& out); + template class T, typename U> + void deserialize(add_skip_layer& item, std::istream& in); + /*! + provides serialization support + !*/ + + template using skip1 = add_skip_layer< tag1, SUBNET>; + template using skip2 = add_skip_layer< tag2, SUBNET>; + template using skip3 = add_skip_layer< tag3, SUBNET>; + template using skip4 = add_skip_layer< tag4, SUBNET>; + template using skip5 = add_skip_layer< tag5, SUBNET>; + template using skip6 = add_skip_layer< tag6, SUBNET>; + template using skip7 = add_skip_layer< tag7, SUBNET>; + template using skip8 = add_skip_layer< tag8, SUBNET>; + template using skip9 = add_skip_layer< tag9, SUBNET>; + template using skip10 = add_skip_layer; + +// ---------------------------------------------------------------------------------------- + + template < + unsigned int i, + typename net_type + > + auto& layer ( + net_type& n + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - i < net_type::num_layers + ensures + - This function allows you to access any layer in a network by its layer index + i. Therefore, it will walk i steps down the network and return the layer + object there. Since networks can be big, the best way to find layer index + numbers is to print a network to the screen since the print out will include + indexes for each layer. + - In general, this function chains together i calls to n.subnet() and returns + the result. So for example: + - if (i == 0) + - returns n + - else if (i == 1) + - returns n.subnet() + - else if (i == 2) + - returns n.subnet().subnet() + - else if (i == 3) + - returns n.subnet().subnet().subnet() + - else + - etc. + Except that when it hits a repeat layer it recurses into the repeated layers + contained inside. That is, if the layer index indicates a layer in a repeat + object this function will make the appropriate call to get_repeated_layer() + and do the right thing. + !*/ + + template < + template class Match, + typename net_type + > + auto& layer ( + net_type& n + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + ensures + - returns the first layer in n that is of type Match. E.g. if net_type is + fc>>> then calling layer(n) would return + layer<1>(n), that is, a reference to the relu layer. + !*/ + + template < + template class Match, + unsigned int i, + typename net_type + > + auto& layer ( + net_type& n + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + ensures + - returns layer(layer(n)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + auto& input_layer ( + net_type& net + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + ensures + - returns the input later of the given network object. Specifically, this + function is equivalent to calling: + layer(net); + That is, you get the input layer details object for the network. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename net_type, + typename visitor + > + void visit_layer_parameters( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(size_t idx, tensor& t) + ensures + - Loops over all the computational layers (i.e. layers with parameters, as + opposed to loss, tag, or input layers) in net and passes their parameters to + v(). To be specific, this function essentially performs the following: + + size_t computational_layer_idx = 0; + for (size_t i = 0; i < net_type::num_layers; ++i) + { + if (layer(net) is a computational layer) + { + v(computational_layer_idx, layer(net).layer_details().get_layer_params()); + ++computational_layer_idx; + } + } + - When v() is called, the first argument is always < net_type::num_computational_layers. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename net_type, + typename visitor + > + void visit_layer_parameter_gradients( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(size_t idx, tensor& t) + ensures + - Loops over all the computational layers (i.e. layers with parameters, as + opposed to loss, tag, or input layers) in net and passes their parameter + gradients to v(). To be specific, this function essentially performs the + following: + + size_t computational_layer_idx = 0; + for (size_t i = 0; i < net_type::num_layers; ++i) + { + if (layer(net) is a computational layer) + { + v(computational_layer_idx, layer(net).get_parameter_gradient()); + ++computational_layer_idx; + } + } + - When v() is called, the first argument is always < net_type::num_computational_layers. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename net_type, + typename visitor + > + void visit_layers( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(size_t idx, any_net_type& t) + That is, it must take a size_t and then any of the network types such as + add_layer, add_loss_layer, etc. + ensures + - Loops over all the layers in net and calls v() on them. To be specific, this + function essentially performs the following: + + for (size_t i = 0; i < net_type::num_layers; ++i) + v(i, layer(net)); + !*/ + + template < + typename net_type, + typename visitor + > + void visit_layers_backwards( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(size_t idx, any_net_type& t) + That is, it must take a size_t and then any of the network types such as + add_layer, add_loss_layer, etc. + ensures + - Loops over all the layers in net and calls v() on them. The loop happens in + the reverse order of visit_layers(). To be specific, this function + essentially performs the following: + + for (size_t i = net_type::num_layers; i != 0; --i) + v(i-1, layer(net)); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + size_t begin, + size_t end, + typename net_type, + typename visitor + > + void visit_layers_range( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(size_t idx, any_net_type& t) + That is, it must take a size_t and then any of the network types such as + add_layer, add_loss_layer, etc. + - begin <= end <= net_type::num_layers + ensures + - Loops over the layers in the range [begin,end) in net and calls v() on them. + The loop happens in the reverse order of visit_layers(). To be specific, + this function essentially performs the following: + + for (size_t i = begin; i < end; ++i) + v(i, layer(net)); + !*/ + + template < + size_t begin, + size_t end, + typename net_type, + typename visitor + > + void visit_layers_backwards_range( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(size_t idx, any_net_type& t) + That is, it must take a size_t and then any of the network types such as + add_layer, add_loss_layer, etc. + - begin <= end <= net_type::num_layers + ensures + - Loops over the layers in the range [begin,end) in net and calls v() on them. + The loop happens in the reverse order of visit_layers_range(). To be specific, + this function essentially performs the following: + + for (size_t i = end; i != begin; --i) + v(i-1, layer(net)); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long tag_id, + typename net_type, + typename visitor + > + void visit_layers_until_tag( + net_type& net, + visitor v + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - v is a function object with a signature equivalent to: + v(any_net_type& t) + That is, it must take any of the network types such as add_layer, + add_loss_layer, etc. + ensures + - Loops over all the layers in net beginning with layer<0>(net) and going until + a tag layer with an ID of tag_id is encountered. To be specific, this + function essentially performs the following: + + size_t i = 0; + while(layer(net) isn't an add_tag_layer with ID == tag_id) { + v(layer(net)); + ++i; + } + v(layer(net)); // also visits the tag layer itself at the very end. + !*/ + +// ---------------------------------------------------------------------------------------- + + struct layer_test_results + { + std::string log; + bool was_good; + + operator bool() const { return was_good; } + }; + + inline std::ostream& operator<< (std::ostream& out, const layer_test_results& item) + { + out << item.log; + return out; + } + + template < + typename layer_details_type + > + layer_test_results test_layer ( + layer_details_type l + ); + /*! + ensures + - Checks if l correctly implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined in layers_abstract.h. Importantly, it computes numerical approximations + to the gradients and compares them to the outputs of the layer. + - The results of the testing are returned. In particular, if the returned object + is RESULT then we will have: + - RESULT.was_good == false if and only if the layer failed the testing. + - RESULT.log == a string describing why the testing failed if was_good==false. + - Note that this function is only capable of checking layers that take + arbitrary subnetworks as input. So if you have designed a layer that expects + only a certain restricted type of subnetwork then you might get a compile or + runtime error when you call this function. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_CORE_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/dnn/cpu_dlib.cpp b/ml/dlib/dlib/dnn/cpu_dlib.cpp new file mode 100644 index 000000000..ed5661102 --- /dev/null +++ b/ml/dlib/dlib/dnn/cpu_dlib.cpp @@ -0,0 +1,2170 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CPU_cPP_ +#define DLIB_DNN_CPU_cPP_ + +// This file contains CPU implementations of the GPU based functions in cuda_dlib.h + +#include "cpu_dlib.h" +#include "tensor_tools.h" +#include "../image_transforms/interpolation.h" +#include "../threads.h" + +namespace dlib +{ + namespace cpu + { + + // ----------------------------------------------------------------------------------- + + void multiply ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { + DLIB_CASSERT(dest.k() == src1.k() && src1.k() == src2.k() && + dest.nr() == src1.nr() && src1.nr() == src2.nr() && + dest.nc() == src1.nc() && src1.nc() == src2.nc() ); + const long MD = std::max(std::max(dest.num_samples(),src1.num_samples()),src2.num_samples()); + DLIB_CASSERT((dest.num_samples()==1 || dest.num_samples()==MD) && + (src1.num_samples()==1 || src1.num_samples()==MD) && + (src2.num_samples()==1 || src2.num_samples()==MD) ); + + if (dest.size() == 0) + return; + + const size_t max_size = std::max(std::max(dest.size(),src1.size()),src2.size()); + const auto d = dest.host(); + const auto s1 = src1.host(); + const auto s2 = src2.host(); + if (dest.size() == src1.size() && src1.size() == src2.size()) + { + if (add_to) + { + for (size_t i = 0; i < src1.size(); ++i) + d[i] += s1[i]*s2[i]; + } + else + { + for (size_t i = 0; i < src1.size(); ++i) + d[i] = s1[i]*s2[i]; + } + } + else if (dest.num_samples() == 1) + { + if (!add_to) + { + for (size_t i = 0; i < dest.size(); ++i) + d[i] = 0; + } + for (size_t i = 0; i < max_size; ++i) + d[i%dest.size()] += s1[i%src1.size()]*s2[i%src2.size()]; + } + else + { + if (add_to) + { + for (size_t i = 0; i < max_size; ++i) + d[i] += s1[i%src1.size()]*s2[i%src2.size()]; + } + else + { + for (size_t i = 0; i < max_size; ++i) + d[i] = s1[i%src1.size()]*s2[i%src2.size()]; + } + } + } + + // ------------------------------------------------------------------------------------ + + void multiply_conv ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { + auto d = dest.host(); + auto s1 = src1.host(); + auto s2 = src2.host(); + if (have_same_dimensions(dest,src1)) + { + DLIB_CASSERT(src2.num_samples() == 1 && src2.nr() == 1 && src2.nc() == 1 && src2.k() == src1.k()); + + if (add_to) + { + for (long n = 0; n < dest.num_samples(); ++n) + { + for (long k = 0; k < dest.k(); ++k) + { + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + *d++ += (*s1++)*s2[k]; + } + } + } + } + } + else + { + for (long n = 0; n < dest.num_samples(); ++n) + { + for (long k = 0; k < dest.k(); ++k) + { + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + *d++ = (*s1++)*s2[k]; + } + } + } + } + } + } + else + { + DLIB_CASSERT(have_same_dimensions(src1,src2)); + DLIB_CASSERT(dest.num_samples() == 1 && dest.nr() == 1 && dest.nc() == 1 && dest.k() == src1.k()); + + if (!add_to) + { + for (long k = 0; k < src1.k(); ++k) + d[k] = 0; + } + + for (long n = 0; n < src1.num_samples(); ++n) + { + for (long k = 0; k < src1.k(); ++k) + { + for (long r = 0; r < src1.nr(); ++r) + { + for (long c = 0; c < src1.nc(); ++c) + { + d[k] += (*s1++)*(*s2++); + } + } + } + } + } + } + + // ------------------------------------------------------------------------------------ + + void scale_channels ( + bool add_to, + tensor& dest, + const tensor& src, + const tensor& scales + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src) && + scales.num_samples() == src.num_samples() && + scales.k() == src.k() && + scales.nr() == 1 && + scales.nc() == 1 ); + + if (dest.size() == 0) + return; + + if (add_to) + { + auto d = dest.host(); + auto s = src.host(); + auto scal = scales.host(); + + for (long n = 0; n < src.num_samples(); ++n) + { + for (long k = 0; k < src.k(); ++k) + { + const auto scale = scal[n*scales.k() + k]; + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + *d++ += (*s++) * scale; + } + } + } + } + + + } + else + { + auto d = dest.host_write_only(); + auto s = src.host(); + auto scal = scales.host(); + + for (long n = 0; n < src.num_samples(); ++n) + { + for (long k = 0; k < src.k(); ++k) + { + const auto scale = scal[n*scales.k() + k]; + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + *d++ = (*s++) * scale; + } + } + } + } + } + } + + // ------------------------------------------------------------------------------------ + + void add( + float beta, + tensor& dest, + float alpha, + const tensor& src + ) + { + DLIB_CASSERT( + (have_same_dimensions(src, dest) || + (src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) || + (src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) || + (src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()) || + (src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1)) && + is_same_object(src,dest) == false , + "\n\t dest.num_samples(): " << dest.num_samples() + <<"\n\t dest.k(): " << dest.k() + <<"\n\t dest.nr(): " << dest.nr() + <<"\n\t dest.nc(): " << dest.nc() + <<"\n\t src.num_samples(): " << src.num_samples() + <<"\n\t src.k(): " << src.k() + <<"\n\t src.nr(): " << src.nr() + <<"\n\t src.nc(): " << src.nc() + ); + + + if (beta == 0 && alpha == 0) + { + dest = 0; + return; + } + + auto d = dest.host(); + auto s = src.host(); + for (long n = 0; n < dest.num_samples(); ++n) + { + const auto sn = src.num_samples()==1 ? 0:n; + for (long k = 0; k < dest.k(); ++k) + { + const auto sk = src.k()==1 ? 0:k; + for (long r = 0; r < dest.nr(); ++r) + { + const auto sr = src.nr()==1 ? 0:r; + for (long c = 0; c < dest.nc(); ++c) + { + const auto sc = src.nc()==1 ? 0:c; + + const auto s_idx = ((sn*src.k() + sk)*src.nr() + sr)*src.nc() + sc; + *d = beta*(*d) + alpha*s[s_idx]; + ++d; + } + } + } + } + } + + // ---------------------------------------------------------------------------------------- + + void add ( + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { + auto d = dest.host(); + auto s1 = src1.host(); + auto s2 = src2.host(); + + // Do the simple and fast version if everything has the same dimensions + if (have_same_dimensions(dest, src1) && + have_same_dimensions(dest, src2)) + { + for (size_t i = 0; i < dest.size(); ++i) + d[i] = s1[i] + s2[i]; + return; + } + + // Otherwise, do the more complex version with bounds checking. + for (long n = 0; n < dest.num_samples(); ++n) + { + for (long k = 0; k < dest.k(); ++k) + { + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + float v1 = 0; + float v2 = 0; + + // if this index is inside src1 + if (n < src1.num_samples() && + k < src1.k() && + r < src1.nr() && + c < src1.nc() ) + { + const auto s_idx = ((n*src1.k() + k)*src1.nr() + r)*src1.nc() + c; + v1 = s1[s_idx]; + } + + // if this index is inside src2 + if (n < src2.num_samples() && + k < src2.k() && + r < src2.nr() && + c < src2.nc() ) + { + const auto s_idx = ((n*src2.k() + k)*src2.nr() + r)*src2.nc() + c; + v2 = s2[s_idx]; + } + + *d = v1 + v2; + ++d; + } + } + } + } + } + + // ---------------------------------------------------------------------------------------- + + void multiply_zero_padded ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { + auto d = dest.host(); + auto s1 = src1.host(); + auto s2 = src2.host(); + + // Do the simple and fast version if everything has the same dimensions + if (have_same_dimensions(dest, src1) && + have_same_dimensions(dest, src2)) + { + if (add_to) + { + for (size_t i = 0; i < dest.size(); ++i) + d[i] += s1[i] * s2[i]; + } + else + { + for (size_t i = 0; i < dest.size(); ++i) + d[i] = s1[i] * s2[i]; + } + return; + } + + // Otherwise, do the more complex version with bounds checking. + for (long n = 0; n < dest.num_samples(); ++n) + { + for (long k = 0; k < dest.k(); ++k) + { + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + float v1 = 0; + float v2 = 0; + + // if this index is inside src1 + if (n < src1.num_samples() && + k < src1.k() && + r < src1.nr() && + c < src1.nc() ) + { + const auto s_idx = ((n*src1.k() + k)*src1.nr() + r)*src1.nc() + c; + v1 = s1[s_idx]; + } + + // if this index is inside src2 + if (n < src2.num_samples() && + k < src2.k() && + r < src2.nr() && + c < src2.nc() ) + { + const auto s_idx = ((n*src2.k() + k)*src2.nr() + r)*src2.nc() + c; + v2 = s2[s_idx]; + } + + if (add_to) + *d += v1 * v2; + else + *d = v1 * v2; + ++d; + } + } + } + } + } + + // ---------------------------------------------------------------------------------------- + + void assign_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ) + { + DLIB_CASSERT( + grad.num_samples() == 1 && + gradient_input.k() == grad.k() && + gradient_input.nr() == grad.nr() && + gradient_input.nc() == grad.nc() && + gradient_input.size() > 0); + + auto out = grad.host(); + auto in = gradient_input.host(); + + for (size_t i = 0; i < grad.size(); ++i) + out[i] = *in++; + + for (long j = 1; j < gradient_input.num_samples(); ++j) + { + for (size_t i = 0; i < grad.size(); ++i) + out[i] += *in++; + } + } + + // ------------------------------------------------------------------------------------ + + void assign_conv_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ) + { + DLIB_CASSERT( + grad.num_samples() == 1 && + grad.k() >= 1 && + grad.nr() == 1 && + grad.nc() == 1 && + gradient_input.k() == grad.k() && + gradient_input.size() > 0 && + is_same_object(grad,gradient_input) == false + ); + + auto g = grad.host(); + auto gi = gradient_input.host(); + + for (long k = 0; k < gradient_input.k(); ++k) + g[k] = 0; + + for (long n = 0; n < gradient_input.num_samples(); ++n) + { + for (long k = 0; k < gradient_input.k(); ++k) + { + for (long r = 0; r < gradient_input.nr(); ++r) + { + for (long c = 0; c < gradient_input.nc(); ++c) + { + g[k] += (*gi++); + } + } + } + } + } + + // ----------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const float A, + const float B + ) + { + DLIB_CASSERT(dest.size()==src.size()); + const auto d = dest.host(); + const auto s = src.host(); + for (size_t i = 0; i < src.size(); ++i) + d[i] = A*s[i] + B; + } + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B, + const float C + ) + { + DLIB_CASSERT(dest.size()==src1.size()); + DLIB_CASSERT(dest.size()==src2.size()); + const auto d = dest.host(); + const auto s1 = src1.host(); + const auto s2 = src2.host(); + for (size_t i = 0; i < src1.size(); ++i) + d[i] = A*s1[i] + B*s2[i] + C; + } + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C, + const float D + ) + { + DLIB_CASSERT(dest.size()==src1.size()); + DLIB_CASSERT(dest.size()==src2.size()); + DLIB_CASSERT(dest.size()==src3.size()); + const auto d = dest.host(); + const auto s1 = src1.host(); + const auto s2 = src2.host(); + const auto s3 = src3.host(); + for (size_t i = 0; i < src1.size(); ++i) + d[i] = A*s1[i] + B*s2[i] + C*s3[i] + D; + } + + void affine_transform_range( + size_t begin, + size_t end, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C + ) + { + DLIB_CASSERT(dest.size()==src1.size()); + DLIB_CASSERT(dest.size()==src2.size()); + DLIB_CASSERT(dest.size()==src3.size()); + DLIB_CASSERT(begin <= end && end <= dest.size()); + const auto d = dest.host(); + const auto s1 = src1.host(); + const auto s2 = src2.host(); + const auto s3 = src3.host(); + for (size_t i = begin; i < end; ++i) + d[i] = A*s1[i] + B*s2[i] + C*s3[i]; + } + + // ----------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src)); + DLIB_CASSERT( + ((A.num_samples()==1 && B.num_samples()==1) || + (A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples())) && + A.nr()==B.nr() && B.nr()==src.nr() && + A.nc()==B.nc() && B.nc()==src.nc() && + A.k() ==B.k() && B.k()==src.k()); + + auto d = dest.host(); + auto s = src.host(); + const auto a = A.host(); + const auto b = B.host(); + if (A.num_samples() == 1) + { + const long num = src.size()/src.num_samples(); + for (long i = 0; i < src.num_samples(); ++i) + { + for (long j = 0; j < num; ++j) + { + *d = a[j]*(*s) + b[j]; + d++; + s++; + } + } + } + else + { + for (size_t i = 0; i < src.size(); ++i) + d[i] = a[i]*s[i] + b[i]; + } + } + + // ----------------------------------------------------------------------------------- + + void affine_transform_conv( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src)); + DLIB_CASSERT(have_same_dimensions(A,B)); + DLIB_CASSERT(A.num_samples() == 1 && + A.nr() == 1 && + A.nc() == 1 && + A.k() == src.k()); + + auto d = dest.host(); + auto s = src.host(); + const auto a = A.host(); + const auto b = B.host(); + for (long n = 0; n < dest.num_samples(); ++n) + { + for (long k = 0; k < dest.k(); ++k) + { + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + *d++ = a[k]*(*s++) + b[k]; + } + } + } + } + } + + // ---------------------------------------------------------------------------------------- + + void affine_transform( + const rectangle& rect, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + float A, + float B, + float C + ) + { + DLIB_CASSERT(dest.size() == src1.size()); + DLIB_CASSERT(dest.size() == src2.size()); + DLIB_CASSERT(dest.size() == src3.size()); + DLIB_CASSERT(dest.num_samples() == src1.num_samples()); + DLIB_CASSERT(dest.num_samples() == src2.num_samples()); + DLIB_CASSERT(dest.num_samples() == src3.num_samples()); + DLIB_CASSERT(rectangle(0,0, dest.size()/dest.num_samples()-1, dest.num_samples()-1).contains(rect)); + + + auto d = dest.host(); + auto s1 = src1.host(); + auto s2 = src2.host(); + auto s3 = src3.host(); + + const auto nc = dest.size()/dest.num_samples(); + + for (long r = rect.top(); r <= rect.bottom(); ++r) + { + for (long c = rect.left(); c <= rect.right(); ++c) + { + auto idx = r*nc + c; + d[idx] = s1[idx]*A + s2[idx]*B + s3[idx]*C; + } + } + + } + + // ----------------------------------------------------------------------------------- + + void compute_adam_update ( + size_t begin, + size_t end, + tensor& s, + tensor& m, + tensor& v, + const float t, + const float learning_rate, + const float weight_decay, + const float momentum1, + const float momentum2, + const tensor& params, + const tensor& params_grad + ) + { + DLIB_CASSERT(s.size() == m.size() && + s.size() == v.size() && + s.size() == params.size() && + s.size() == params_grad.size()); + DLIB_CASSERT(begin <= end && end <= params.size()); + const float eps = 1e-8; + const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1, t)); + + // The loop is equivalent to doing this: + // m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad); + // v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad); + // s = -alpha*m/(sqrt(v) + eps); + auto pm = m.host(); + auto pv = v.host(); + auto ps = s.host_write_only(); + auto pparams = params.host(); + auto ppgrad = params_grad.host(); + for (size_t i = begin; i < end; ++i) + { + float g = weight_decay*pparams[i] + ppgrad[i]; + pm[i] = momentum1*pm[i] + (1-momentum1)*g; + pv[i] = momentum2*pv[i] + (1-momentum2)*g*g; + ps[i] = -alpha*pm[i]/(std::sqrt(pv[i]) + eps); + } + } + + // ----------------------------------------------------------------------------------- + + void batch_normalize_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ) + { + DLIB_CASSERT( + gamma.num_samples() == 1 && + gamma.nr() == src.nr() && + gamma.nc() == src.nc() && + gamma.k() == src.k() && + have_same_dimensions(gamma, beta) && + have_same_dimensions(gamma, running_means) && + have_same_dimensions(gamma, running_variances) && + eps > 0, + "\ngamma.num_samples(): " << gamma.num_samples() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\nbeta.num_samples(): " << beta.num_samples() << + "\nbeta.k(): " << beta.k() << + "\nbeta.nr(): " << beta.nr() << + "\nbeta.nc(): " << beta.nc() << + "\nrunning_means.num_samples(): " << running_means.num_samples() << + "\nrunning_means.k(): " << running_means.k() << + "\nrunning_means.nr(): " << running_means.nr() << + "\nrunning_means.nc(): " << running_means.nc() << + "\nrunning_variances.num_samples(): " << running_variances.num_samples() << + "\nrunning_variances.k(): " << running_variances.k() << + "\nrunning_variances.nr(): " << running_variances.nr() << + "\nrunning_variances.nc(): " << running_variances.nc() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\neps: " << eps + ); + dest.copy_size(src); + + auto d = dest.host(); + auto s = src.host(); + auto g = gamma.host(); + auto b = beta.host(); + auto m = running_means.host(); + auto v = running_variances.host(); + + const long num = src.k()*src.nr()*src.nc(); + for (long n = 0; n < src.num_samples(); ++n) + { + for (long k = 0; k < num; ++k) + { + *d = g[k]*(*s - m[k])/std::sqrt(v[k]+eps) + b[k]; + ++d; + ++s; + } + } + } + + void batch_normalize ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ) + { + DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor); + DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means)); + DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds)); + DLIB_CASSERT( + src.num_samples() > 1 && + gamma.num_samples() == 1 && + beta.num_samples() == 1 && + gamma.nr() == beta.nr() && beta.nr() == src.nr() && + gamma.nc() == beta.nc() && beta.nc() == src.nc() && + gamma.k() == beta.k() && beta.k() == src.k() && + eps > 0, + "\ngamma.num_samples(): " << gamma.num_samples() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\nbeta.num_samples(): " << beta.num_samples() << + "\nbeta.k(): " << beta.k() << + "\nbeta.nr(): " << beta.nr() << + "\nbeta.nc(): " << beta.nc() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\neps: " << eps + ); + + dest.copy_size(src); + means.set_size(1, src.k(), src.nr(), src.nc()); + invstds.set_size(1, src.k(), src.nr(), src.nc()); + + // first compute means and invstds + means = 0; + invstds = 0; + const auto p_invstds = invstds.host(); + const auto p_means = means.host(); + auto p_src = src.host(); + const long num = src.k()*src.nr()*src.nc(); + // compute means, and sum of squares + for (long i = 0; i < num; ++i) + { + for (long n = 0; n < src.num_samples(); ++n) + { + float val = p_src[n*num+i]; + p_means[i] += val; + p_invstds[i] += val*val; + } + } + means /= src.num_samples(); + invstds /= src.num_samples(); + // copy data back to host + invstds.host(); means.host(); + + // compute variances + running_variances.copy_size(invstds); + auto rvar = running_variances.host(); + // This scale makes the running variances unbiased. + const double scale = (src.num_samples())/(src.num_samples()-1.0); + for (long i = 0; i < num; ++i) + { + auto actual_var = p_invstds[i] - p_means[i]*p_means[i]; + if (averaging_factor == 1) + rvar[i] = scale*actual_var; + else + rvar[i] = (1-averaging_factor)*rvar[i] + scale*averaging_factor*actual_var; + + p_invstds[i] = 1.0f/std::sqrt(actual_var + eps); + } + + p_src = src.host(); + auto p_dest = dest.host(); + const auto p_gamma = gamma.host(); + const auto p_beta = beta.host(); + for (long n = 0; n < src.num_samples(); ++n) + { + for (long i = 0; i < num; ++i) + { + *p_dest = (*p_src - p_means[i])*p_invstds[i]; + *p_dest = (*p_dest)*p_gamma[i] + p_beta[i]; + ++p_src; + ++p_dest; + } + } + + // now keep track of the running means + running_means.copy_size(means); + if (averaging_factor != 1) + running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means); + else + running_means = means; + } + + void batch_normalize_gradient ( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ) + { + + const long num = src.k()*src.nr()*src.nc(); + DLIB_CASSERT(src.num_samples() > 1); + DLIB_CASSERT(num == (long)means.size()); + DLIB_CASSERT(num == (long)invstds.size()); + DLIB_CASSERT(num == (long)gamma.size()); + DLIB_CASSERT(num == (long)gamma_grad.size()); + DLIB_CASSERT(num == (long)beta_grad.size()); + DLIB_CASSERT(have_same_dimensions(gradient_input, src)); + DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); + DLIB_CASSERT(eps > 0); + + beta_grad = 0; + gamma_grad = 0; + auto p_grad = gradient_input.host(); + auto p_src = src.host(); + const auto p_gamma = gamma.host(); + const auto p_gamma_grad = gamma_grad.host(); + const auto p_beta_grad = beta_grad.host(); + const auto p_invstds = invstds.host(); + const auto p_means = means.host(); + + resizable_tensor dvars, dmeans; + dvars.copy_size(invstds); + dmeans.copy_size(means); + dvars = 0; + dmeans = 0; + const auto p_dvars = dvars.host(); + const auto p_dmeans = dmeans.host(); + + for (long n = 0; n < src.num_samples(); ++n) + { + for (long i = 0; i < num; ++i) + { + const float x_hat = (*p_src - p_means[i])*p_invstds[i]; + p_beta_grad[i] += *p_grad; + p_gamma_grad[i] += (*p_grad)*x_hat; + + const float dx = *p_grad * p_gamma[i]; + + p_dvars[i] += dx*(*p_src - p_means[i])*-0.5*std::pow(p_invstds[i], 3.0f); + + ++p_grad; + ++p_src; + } + } + + const float invnum = 1.0f/src.num_samples(); + p_grad = gradient_input.host(); + p_src = src.host(); + for (long n = 0; n < src.num_samples(); ++n) + { + for (long i = 0; i < num; ++i) + { + const float dx = *p_grad * p_gamma[i]; + + p_dmeans[i] += dx*-p_invstds[i] + p_dvars[i] * -2*(*p_src - p_means[i])*invnum; + + ++p_grad; + ++p_src; + } + } + p_grad = gradient_input.host(); + p_src = src.host(); + auto p_src_grad = src_grad.host(); + for (long n = 0; n < src.num_samples(); ++n) + { + for (long i = 0; i < num; ++i) + { + const float dx = *p_grad * p_gamma[i]; + + *p_src_grad += dx*p_invstds[i] + + p_dvars[i] *2*(*p_src - p_means[i])*invnum + + p_dmeans[i]*invnum; + + + ++p_grad; + ++p_src; + ++p_src_grad; + } + } + } + + // ---------------------------------------------------------------------------------------- + + void batch_normalize_conv_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ) + { + DLIB_CASSERT( + gamma.num_samples() == 1 && + gamma.nr() == 1 && + gamma.nc() == 1 && + gamma.k() == src.k() && + have_same_dimensions(gamma, beta) && + have_same_dimensions(gamma, running_means) && + have_same_dimensions(gamma, running_variances) && + eps > 0, + "\ngamma.num_samples(): " << gamma.num_samples() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\nbeta.num_samples(): " << beta.num_samples() << + "\nbeta.k(): " << beta.k() << + "\nbeta.nr(): " << beta.nr() << + "\nbeta.nc(): " << beta.nc() << + "\nrunning_means.num_samples(): " << running_means.num_samples() << + "\nrunning_means.k(): " << running_means.k() << + "\nrunning_means.nr(): " << running_means.nr() << + "\nrunning_means.nc(): " << running_means.nc() << + "\nrunning_variances.num_samples(): " << running_variances.num_samples() << + "\nrunning_variances.k(): " << running_variances.k() << + "\nrunning_variances.nr(): " << running_variances.nr() << + "\nrunning_variances.nc(): " << running_variances.nc() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\neps: " << eps + ); + dest.copy_size(src); + + auto d = dest.host(); + auto s = src.host(); + auto g = gamma.host(); + auto b = beta.host(); + auto m = running_means.host(); + auto v = running_variances.host(); + + const long num = src.nr()*src.nc(); + for (long n = 0; n < src.num_samples(); ++n) + { + for (long k = 0; k < src.k(); ++k) + { + const float invstd = 1.0f/std::sqrt(v[k] + eps); + for (long j = 0; j < num; ++j) + { + *d = g[k]*(*s - m[k])*invstd + b[k]; + ++d; + ++s; + } + } + } + } + + void batch_normalize_conv ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ) + { + DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor); + DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means)); + DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds)); + DLIB_CASSERT( + src.num_samples() > 1 && + gamma.num_samples() == 1 && + beta.num_samples() == 1 && + gamma.nr() == 1 && + beta.nr() == 1 && + gamma.nc() == 1 && + beta.nc() == 1 && + gamma.k() == beta.k() && beta.k() == src.k() && + eps > 0, + "\ngamma.num_samples(): " << gamma.num_samples() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\nbeta.num_samples(): " << beta.num_samples() << + "\nbeta.k(): " << beta.k() << + "\nbeta.nr(): " << beta.nr() << + "\nbeta.nc(): " << beta.nc() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\neps: " << eps + ); + + dest.copy_size(src); + means.set_size(1, src.k()); + invstds.set_size(1, src.k()); + + // first compute means and invstds + means = 0; + invstds = 0; + const auto p_invstds = invstds.host(); + const auto p_means = means.host(); + const auto p_gamma = gamma.host(); + const auto p_beta = beta.host(); + auto p_src = src.host(); + const long num = src.nr()*src.nc(); + // compute means, and sum of squares + for (long n = 0; n < src.num_samples(); ++n) + { + for (long k = 0; k < src.k(); ++k) + { + for (long i = 0; i < num; ++i) + { + p_means[k] += *p_src; + p_invstds[k] += (*p_src)*(*p_src); + ++p_src; + } + } + } + means /= src.num_samples()*num; + invstds /= src.num_samples()*num; + // copy data back to host + invstds.host(); means.host(); + + p_src = src.host(); + // compute variances + running_variances.copy_size(invstds); + auto rvar = running_variances.host(); + // This scale makes the running variances unbiased. + const double scale = (src.num_samples()*num)/(src.num_samples()*num-1.0); + for (long k = 0; k < src.k(); ++k) + { + float actual_var = p_invstds[k] - p_means[k]*p_means[k]; + if (averaging_factor == 1) + rvar[k] = scale*actual_var; + else + rvar[k] = (1-averaging_factor)*rvar[k] + scale*averaging_factor*actual_var; + + p_invstds[k] = 1.0f/std::sqrt(actual_var + eps); + } + + p_src = src.host(); + auto p_dest = dest.host(); + for (long n = 0; n < src.num_samples(); ++n) + { + for (long k = 0; k < src.k(); ++k) + { + for (long i = 0; i < num; ++i) + { + *p_dest = (*p_src - p_means[k])*p_invstds[k]; + *p_dest = (*p_dest)*p_gamma[k] + p_beta[k]; + ++p_src; + ++p_dest; + } + } + } + + // now keep track of the running means + running_means.copy_size(means); + if (averaging_factor != 1) + running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means); + else + running_means = means; + } + + void batch_normalize_conv_gradient( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ) + { + + const long num = src.nr()*src.nc(); + DLIB_CASSERT(src.num_samples() > 1); + DLIB_CASSERT(src.k() == (long)means.size()); + DLIB_CASSERT(src.k() == (long)invstds.size()); + DLIB_CASSERT(src.k() == (long)gamma.size()); + DLIB_CASSERT(src.k() == (long)gamma_grad.size()); + DLIB_CASSERT(src.k() == (long)beta_grad.size()); + DLIB_CASSERT(have_same_dimensions(gradient_input, src)); + DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); + DLIB_CASSERT(eps > 0); + + beta_grad = 0; + gamma_grad = 0; + + auto p_grad = gradient_input.host(); + auto p_src = src.host(); + const auto p_gamma = gamma.host(); + const auto p_gamma_grad = gamma_grad.host(); + const auto p_beta_grad = beta_grad.host(); + const auto p_invstds = invstds.host(); + const auto p_means = means.host(); + + resizable_tensor dvars, dmeans; + dvars.copy_size(invstds); + dmeans.copy_size(means); + dvars = 0; + dmeans = 0; + const auto p_dvars = dvars.host(); + const auto p_dmeans = dmeans.host(); + + for (long n = 0; n < src.num_samples(); ++n) + { + for (long k = 0; k < src.k(); ++k) + { + const float invstd_pow = -0.5*std::pow(p_invstds[k], 3.0f); + for (long i = 0; i < num; ++i) + { + const float x_hat = (*p_src - p_means[k])*p_invstds[k]; + p_beta_grad[k] += *p_grad; + p_gamma_grad[k] += (*p_grad)*x_hat; + + const float dx = *p_grad * p_gamma[k]; + + p_dvars[k] += dx*(*p_src - p_means[k])*invstd_pow; + + ++p_grad; + ++p_src; + } + } + } + + p_grad = gradient_input.host(); + p_src = src.host(); + const float invnum = 1.0f/(src.num_samples()*num); + for (long n = 0; n < src.num_samples(); ++n) + { + for (long k = 0; k < src.k(); ++k) + { + for (long i = 0; i < num; ++i) + { + const float dx = *p_grad * p_gamma[k]; + + p_dmeans[k] += -dx*p_invstds[k] + p_dvars[k] * -2*(*p_src - p_means[k])*invnum; + + ++p_grad; + ++p_src; + } + } + } + p_grad = gradient_input.host(); + p_src = src.host(); + auto p_src_grad = src_grad.host(); + for (long n = 0; n < src.num_samples(); ++n) + { + for (long k = 0; k < src.k(); ++k) + { + for (long i = 0; i < num; ++i) + { + const float dx = *p_grad * p_gamma[k]; + + *p_src_grad += dx*p_invstds[k] + + p_dvars[k]*2*(*p_src - p_means[k])*invnum + + p_dmeans[k]*invnum; + + + ++p_grad; + ++p_src; + ++p_src_grad; + } + } + } + } + + // ----------------------------------------------------------------------------------- + + void threshold ( + tensor& data, + float thresh + ) + { + const auto d = data.host(); + for (size_t i = 0; i < data.size(); ++i) + d[i] = d[i]>thresh ? 1:0; + } + + void dot ( + const tensor& a, + const tensor& b, + tensor& result, + size_t idx + ) + { + DLIB_CASSERT(a.size() == b.size()); + DLIB_CASSERT(idx < result.size()); + + const auto aa = a.host(); + const auto bb = b.host(); + auto r = result.host(); + for (size_t i = 0; i < a.size(); ++i) + r[idx] += aa[i]*bb[i]; + } + + // ----------------------------------------------------------------------------------- + // ----------------------------------------------------------------------------------- + // ----------------------------------------------------------------------------------- + + namespace ttimpl + { + void softmax ( + const long num_locations, + const long num_channels, + tensor& dest, + const tensor& src + ) + { + DLIB_ASSERT(num_channels*num_locations == src.nr()*src.nc()*src.k()); + DLIB_CASSERT(have_same_dimensions(dest,src)); + const auto d = dest.host(); + const auto s = src.host(); + + // Note that we subtract out the max values in each channel before applying + // exp() to avoid numeric overflow in the subsequent computations. Doing this + // doesn't change the resulting output, it just makes it more numerically + // stable. + for (long n = 0; n < src.num_samples(); ++n) + { + auto ss = s + num_locations*num_channels*n; + auto dd = d + num_locations*num_channels*n; + for (long i = 0; i < num_locations; ++i) + { + float max_val = -std::numeric_limits::infinity(); + for (long k = 0; k < num_channels; ++k) + max_val = std::max(max_val, ss[k*num_locations]); + + for (long k = 0; k < num_channels; ++k) + dd[k*num_locations] = std::exp(ss[k*num_locations]-max_val); + + ++ss; + ++dd; + } + } + + // Now normalize each channel so they sum to 1. + for (long n = 0; n < src.num_samples(); ++n) + { + const auto dd = d + num_locations*num_channels*n; + for (long i = 0; i < num_locations; ++i) + { + const auto ddd = dd+i; + + float temp = 0; + for (long k = 0; k < num_channels; ++k) + temp += ddd[k*num_locations]; + for (long k = 0; k < num_channels; ++k) + ddd[k*num_locations] /= temp; + } + } + } + + void softmax_gradient ( + const long num_locations, + const long num_channels, + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + DLIB_ASSERT(num_channels*num_locations == grad.nr()*grad.nc()*grad.k()); + DLIB_CASSERT(have_same_dimensions(grad,dest)); + DLIB_CASSERT(have_same_dimensions(grad,gradient_input)); + const auto d = dest.host(); + const auto g = grad.host(); + const auto in = gradient_input.host(); + + + for (long n = 0; n < grad.num_samples(); ++n) + { + const auto d2 = d + num_locations*num_channels*n; + const auto g2 = g + num_locations*num_channels*n; + const auto in2 = in + num_locations*num_channels*n; + for (long i = 0; i < num_locations; ++i) + { + const auto d3 = d2+i; + const auto g3 = g2+i; + const auto in3 = in2+i; + + float temp = 0; + for (long k = 0; k < num_channels; ++k) + temp += -d3[k*num_locations]*in3[k*num_locations]; + if (is_same_object(gradient_input, grad)) + { + for (long k = 0; k < num_channels; ++k) + g3[k*num_locations] = d3[k*num_locations]*(temp+in3[k*num_locations]); + } + else + { + for (long k = 0; k < num_channels; ++k) + g3[k*num_locations] += d3[k*num_locations]*(temp+in3[k*num_locations]); + } + } + } + } + } + + // ---------------------------------------------------------------------------------------- + + void softmax ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src)); + ttimpl::softmax(src.nr()*src.nc(), src.k(), dest, src); + } + + void softmax_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + DLIB_CASSERT(have_same_dimensions(grad,dest)); + DLIB_CASSERT(have_same_dimensions(grad,gradient_input)); + ttimpl::softmax_gradient(grad.nr()*grad.nc(), grad.k(), grad, dest, gradient_input); + } + + // ------------------------------------------------------------------------------------ + + void softmax_all ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src)); + ttimpl::softmax(1, src.nr()*src.nc()*src.k(), dest, src); + } + + void softmax_all_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + DLIB_CASSERT(have_same_dimensions(grad,dest)); + DLIB_CASSERT(have_same_dimensions(grad,gradient_input)); + ttimpl::softmax_gradient(1, grad.nr()*grad.nc()*grad.k(), grad, dest, gradient_input); + } + + // ------------------------------------------------------------------------------------ + + void sigmoid ( + tensor& dest, + const tensor& src + ) + { + const auto d = dest.host(); + const auto s = src.host(); + for (size_t i = 0; i < src.size(); ++i) + d[i] = 1/(1+std::exp(-s[i])); + } + + void sigmoid_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + const auto g = grad.host(); + const auto d = dest.host(); + const auto in = gradient_input.host(); + if (is_same_object(gradient_input, grad)) + { + for (size_t i = 0; i < dest.size(); ++i) + g[i] = in[i]*d[i]*(1-d[i]); + } + else + { + for (size_t i = 0; i < dest.size(); ++i) + g[i] += in[i]*d[i]*(1-d[i]); + } + } + + // ------------------------------------------------------------------------------------ + + void relu ( + tensor& dest, + const tensor& src + ) + { + dest = lowerbound(mat(src), 0); + } + + void relu_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + const float* gi = gradient_input.host(); + const float* in = dest.host(); + float* out = grad.host(); + if (is_same_object(grad, gradient_input)) + { + for (size_t i = 0; i < dest.size(); ++i) + { + if (in[i] > 0) + out[i] = gi[i]; + else + out[i] = 0; + } + } + else + { + for (size_t i = 0; i < dest.size(); ++i) + { + if (in[i] > 0) + out[i] += gi[i]; + } + } + } + + // ---------------------------------------------------------------------------------------- + + void prelu ( + tensor& dest, + const tensor& src, + const tensor& param + ) + { + const float p = param.host()[0]; + const float* s = src.host(); + float* d = dest.host(); + for (size_t i = 0; i < dest.size(); ++i) + { + if (s[i] > 0) + d[i] = s[i]; + else + d[i] = p*s[i]; + } + } + + void prelu_gradient ( + tensor& grad, + const tensor& src, + const tensor& gradient_input, + const tensor& param, + tensor& params_grad + ) + { + DLIB_CASSERT(is_same_object(grad, gradient_input) == false); + const float p = param.host()[0]; + const float* gi = gradient_input.host(); + const float* s = src.host(); + float* out = grad.host(); + float pgrad = 0; + for (size_t i = 0; i < src.size(); ++i) + { + if (s[i] > 0) + { + out[i] += gi[i]; + } + else + { + out[i] += p*gi[i]; + pgrad += gi[i]*s[i]; + } + } + params_grad.host()[0] = pgrad; + } + + // ------------------------------------------------------------------------------------ + + void tanh ( + tensor& dest, + const tensor& src + ) + { + const auto d = dest.host(); + const auto s = src.host(); + for (size_t i = 0; i < src.size(); ++i) + d[i] = std::tanh(s[i]); + } + + void tanh_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + const auto g = grad.host(); + const auto d = dest.host(); + const auto in = gradient_input.host(); + if (is_same_object(grad, gradient_input)) + { + for (size_t i = 0; i < dest.size(); ++i) + g[i] = in[i]*(1-d[i]*d[i]); + } + else + { + for (size_t i = 0; i < dest.size(); ++i) + g[i] += in[i]*(1-d[i]*d[i]); + } + } + + // ---------------------------------------------------------------------------------------- + + void resize_bilinear ( + tensor& dest, + long dest_row_stride, + long dest_channel_stride, + const tensor& src, + long src_row_stride, + long src_channel_stride + ) + { + DLIB_CASSERT(is_same_object(dest, src)==false); + DLIB_CASSERT(dest.num_samples() == src.num_samples()); + DLIB_CASSERT(dest.k() == src.k()); + + if (dest.size() == 0 || src.size() == 0) + return; + + const float* s = src.host(); + float* d = dest.host(); + + parallel_for(0, dest.k()*dest.num_samples(), [&](long i) + { + auto simg = sub_image(s+i*src_channel_stride, src.nr(), src.nc(), src_row_stride); + auto dimg = sub_image(d+i*dest_channel_stride, dest.nr(), dest.nc(), dest_row_stride); + + resize_image(simg, dimg); + }); + } + + void resize_bilinear_gradient ( + tensor& grad, + long grad_row_stride, + long grad_channel_stride, + const tensor& gradient_input, + long gradient_input_row_stride, + long gradient_input_channel_stride + ) + { + DLIB_CASSERT(is_same_object(grad, gradient_input)==false); + DLIB_CASSERT(gradient_input.num_samples() == grad.num_samples()); + DLIB_CASSERT(gradient_input.k() == grad.k()); + + if (gradient_input.size() == 0 || grad.size() == 0) + return; + + const float* gi = gradient_input.host(); + float* g = grad.host(); + const float x_scale = (grad.nc()-1)/(float)std::max((gradient_input.nc()-1),1); + const float y_scale = (grad.nr()-1)/(float)std::max((gradient_input.nr()-1),1); + for (long long samp = 0; samp < gradient_input.num_samples(); ++samp) + { + for (long long k = 0; k < gradient_input.k(); ++k) + { + for (long long r = 0; r < gradient_input.nr(); ++r) + { + const float y = r*y_scale; + const long long top = static_cast(std::floor(y)); + const long long bottom = std::min(top+1, grad.nr()-1); + const float tb_frac = y - top; + for (long long c = 0; c < gradient_input.nc(); ++c) + { + const float x = c*x_scale; + const long long left = static_cast(std::floor(x)); + const long long right = std::min(left+1, grad.nc()-1); + const float lr_frac = x - left; + + const float tmp = gi[r*gradient_input_row_stride+c]; + + g[top*grad_row_stride+left] += tmp*(1-tb_frac)*(1-lr_frac); + g[top*grad_row_stride+right] += tmp*(1-tb_frac)*(lr_frac); + g[bottom*grad_row_stride+left] += tmp*(tb_frac)*(1-lr_frac); + g[bottom*grad_row_stride+right] += tmp*(tb_frac)*(lr_frac); + } + } + + g += grad_channel_stride; + gi += gradient_input_channel_stride; + } + } + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + pooling::pooling ( + ) : window_height(0),window_width(0),stride_y(0),stride_x(0),padding_y(0),padding_x(0),do_max_pooling(true) + { + } + + void pooling:: + clear( + ) + { + window_height = 0; + window_width = 0; + stride_y = 0; + stride_x = 0; + padding_y = 0; + padding_x = 0; + } + + void pooling:: + setup_max_pooling( + int window_height_, + int window_width_, + int stride_y_, + int stride_x_, + int padding_y_, + int padding_x_ + ) + { + DLIB_CASSERT(window_width_ > 0); + DLIB_CASSERT(window_height_ > 0); + DLIB_CASSERT(stride_y_ > 0); + DLIB_CASSERT(stride_x_ > 0); + DLIB_CASSERT(0 <= padding_y_ && padding_y_ < window_height_); + DLIB_CASSERT(0 <= padding_x_ && padding_x_ < window_width_); + + window_height = window_height_; + window_width = window_width_; + stride_y = stride_y_; + stride_x = stride_x_; + padding_y = padding_y_; + padding_x = padding_x_; + do_max_pooling = true; + } + + void pooling:: + setup_avg_pooling( + int window_height_, + int window_width_, + int stride_y_, + int stride_x_, + int padding_y_, + int padding_x_ + ) + { + DLIB_CASSERT(window_width_ > 0); + DLIB_CASSERT(window_height_ > 0); + DLIB_CASSERT(stride_y_ > 0); + DLIB_CASSERT(stride_x_ > 0); + DLIB_CASSERT(0 <= padding_y_ && padding_y_ < window_height_); + DLIB_CASSERT(0 <= padding_x_ && padding_x_ < window_width_); + + window_height = window_height_; + window_width = window_width_; + stride_y = stride_y_; + stride_x = stride_x_; + padding_y = padding_y_; + padding_x = padding_x_; + do_max_pooling = false; + } + + void pooling:: + operator() ( + resizable_tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(window_width > 0); + DLIB_CASSERT(window_height > 0); + DLIB_CASSERT(stride_y > 0); + DLIB_CASSERT(stride_x > 0); + DLIB_CASSERT(0 <= padding_y && padding_y < window_height); + DLIB_CASSERT(0 <= padding_x && padding_x < window_width); + DLIB_CASSERT(window_width <= src.nc() + 2*padding_x, + "Pooling windows must be small enough to fit into the padded image."); + DLIB_CASSERT(window_height <= src.nr() + 2*padding_y, + "Pooling windows must be small enough to fit into the padded image."); + + dest.set_size( + src.num_samples(), + src.k(), + 1+(src.nr()+2*padding_y-window_height)/stride_y, + 1+(src.nc()+2*padding_x-window_width)/stride_x + ); + + if (src.size() == 0) + { + dest = 0; + return; + } + + + auto d = dest.host(); + const long x_offset = window_width/2 - padding_x; + const long y_offset = window_height/2 - padding_y; + if (does_max_pooling()) + { + for (long n = 0; n < dest.num_samples(); ++n) + { + for (long k = 0; k < dest.k(); ++k) + { + auto simg = image_plane(src,n,k); + auto dimg = d + (n*dest.k() + k)*dest.nr()*dest.nc(); + + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + auto win = centered_rect(c*stride_x+x_offset, + r*stride_y+y_offset, + window_width, + window_height); + dimg[r*dest.nc() + c] = max(subm_clipped(simg,win)); + } + } + } + } + } + else + { + for (long n = 0; n < dest.num_samples(); ++n) + { + for (long k = 0; k < dest.k(); ++k) + { + auto simg = image_plane(src,n,k); + auto dimg = d + (n*dest.k() + k)*dest.nr()*dest.nc(); + + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + auto win = centered_rect(c*stride_x+x_offset, + r*stride_y+y_offset, + window_width, + window_height); + dimg[r*dest.nc() + c] = mean(subm_clipped(simg,win)); + } + } + } + } + } + + } + + void pooling::get_gradient( + const tensor& gradient_input, + const tensor& dest, + const tensor& src, + tensor& grad + ) + { + DLIB_CASSERT(have_same_dimensions(gradient_input,dest)); + DLIB_CASSERT(have_same_dimensions(src,grad)); + + + if (src.size() == 0) + { + return; + } + + + auto gi = gradient_input.host(); + auto g = grad.host(); + const long x_offset = window_width/2 - padding_x; + const long y_offset = window_height/2 - padding_y; + if (does_max_pooling()) + { + for (long n = 0; n < dest.num_samples(); ++n) + { + for (long k = 0; k < dest.k(); ++k) + { + auto simg = image_plane(src,n,k); + auto gimg = g + (n*grad.k() + k)*grad.nr()*grad.nc(); + auto giimg = gi + (n*dest.k() + k)*dest.nr()*dest.nc(); + auto imgbox = get_rect(simg); + + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + auto win = centered_rect(c*stride_x+x_offset, + r*stride_y+y_offset, + window_width, + window_height).intersect(imgbox); + auto p = max_point(subm(simg,win))+win.tl_corner(); + gimg[p.y()*grad.nc()+p.x()] += giimg[r*dest.nc()+c]; + } + } + } + } + } + else + { + for (long n = 0; n < dest.num_samples(); ++n) + { + for (long k = 0; k < dest.k(); ++k) + { + auto simg = image_plane(src,n,k); + auto gimg = g + (n*grad.k() + k)*grad.nr()*grad.nc(); + auto giimg = gi + (n*dest.k() + k)*dest.nr()*dest.nc(); + auto imgbox = get_rect(simg); + + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + auto win = centered_rect(c*stride_x+x_offset, + r*stride_y+y_offset, + window_width, + window_height).intersect(imgbox); + const float delta = giimg[r*dest.nc()+c]/win.area(); + for (long y = win.top(); y <= win.bottom(); ++y) + { + for (long x = win.left(); x <= win.right(); ++x) + { + gimg[y*grad.nc()+x] += delta; + } + } + } + } + } + } + } + + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + void img2col( + matrix& output, + const tensor& data, + long n, + long filter_nr, + long filter_nc, + long stride_y, + long stride_x, + long padding_y, + long padding_x + ) + { + const auto d = data.host() + data.k()*data.nr()*data.nc()*n; + const rectangle boundary = get_rect(data); + + const long out_nr = 1+(data.nr()+2*padding_y-filter_nr)/stride_y; + const long out_nc = 1+(data.nc()+2*padding_x-filter_nc)/stride_x; + + output.set_size(out_nr*out_nc, + data.k()*filter_nr*filter_nc); + DLIB_CASSERT(output.size() != 0); + float* t = &output(0,0); + + // now fill in the Toeplitz output matrix for the n-th sample in data. + size_t cnt = 0; + const long max_r = data.nr() + padding_y-(filter_nr-1); + const long max_c = data.nc() + padding_x-(filter_nc-1); + for (long r = -padding_y; r < max_r; r+=stride_y) + { + for (long c = -padding_x; c < max_c; c+=stride_x) + { + for (long k = 0; k < data.k(); ++k) + { + for (long y = 0; y < filter_nr; ++y) + { + for (long x = 0; x < filter_nc; ++x) + { + DLIB_ASSERT(cnt < output.size()); + long xx = c+x; + long yy = r+y; + if (boundary.contains(xx,yy)) + *t = d[(k*data.nr() + yy)*data.nc() + xx]; + else + *t = 0; + ++t; + ++cnt; + } + } + } + } + } + } + + void col2img( + const matrix& output, + tensor& data, + long n, + long filter_nr, + long filter_nc, + long stride_y, + long stride_x, + long padding_y, + long padding_x + ) + { + const auto d = data.host() + data.k()*data.nr()*data.nc()*n; + const rectangle boundary = get_rect(data); + + DLIB_CASSERT(output.size() != 0); + const float* t = &output(0,0); + + // now fill in the Toeplitz output matrix for the n-th sample in data. + const long max_r = data.nr() + padding_y-(filter_nr-1); + const long max_c = data.nc() + padding_x-(filter_nc-1); + for (long r = -padding_y; r < max_r; r+=stride_y) + { + for (long c = -padding_x; c < max_c; c+=stride_x) + { + for (long k = 0; k < data.k(); ++k) + { + for (long y = 0; y < filter_nr; ++y) + { + for (long x = 0; x < filter_nc; ++x) + { + long xx = c+x; + long yy = r+y; + if (boundary.contains(xx,yy)) + d[(k*data.nr() + yy)*data.nc() + xx] += *t; + ++t; + } + } + } + } + } + } + + void tensor_conv::operator() ( + const bool add_to_output, + resizable_tensor& output, + const tensor& data, + const tensor& filters + ) + { + DLIB_CASSERT(last_stride_y > 0 && last_stride_x > 0, "You must call setup() before calling this function."); + output.set_size(data.num_samples(), + filters.num_samples(), + 1+(data.nr()+2*last_padding_y-filters.nr())/last_stride_y, + 1+(data.nc()+2*last_padding_x-filters.nc())/last_stride_x); + (*this)(add_to_output, static_cast(output),data,filters); + } + + void tensor_conv::operator() ( + const bool add_to_output, + tensor& output, + const tensor& data, + const tensor& filters + ) + { + DLIB_CASSERT(is_same_object(output,data) == false); + DLIB_CASSERT(is_same_object(output,filters) == false); + DLIB_CASSERT(filters.k() == data.k()); + DLIB_CASSERT(last_stride_y > 0 && last_stride_x > 0, "You must call setup() before calling this function."); + DLIB_CASSERT(filters.nr() <= data.nr() + 2*last_padding_y, + "Filter windows must be small enough to fit into the padded image."); + DLIB_CASSERT(filters.nc() <= data.nc() + 2*last_padding_x, + "Filter windows must be small enough to fit into the padded image."); + + DLIB_CASSERT(output.num_samples() == data.num_samples()); + DLIB_CASSERT(output.k() == filters.num_samples()); + DLIB_CASSERT(output.nr() == 1+(data.nr()+2*last_padding_y-filters.nr())/last_stride_y); + DLIB_CASSERT(output.nc() == 1+(data.nc()+2*last_padding_x-filters.nc())/last_stride_x); + + + matrix temp; + for (long n = 0; n < data.num_samples(); ++n) + { + img2col(temp, data, n, filters.nr(), filters.nc(), last_stride_y, last_stride_x, last_padding_y, last_padding_x); + + if (add_to_output) + output.add_to_sample(n, mat(filters)*trans(temp)); + else + output.set_sample(n, mat(filters)*trans(temp)); + } + } + + // ------------------------------------------------------------------------------------ + + void tensor_conv:: + get_gradient_for_data ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& filters, + tensor& data_gradient + ) + { + matrix temp; + if (!add_to_output) + data_gradient = 0; + for (long n = 0; n < gradient_input.num_samples(); ++n) + { + auto gi = mat(gradient_input.host()+gradient_input.k()*gradient_input.nr()*gradient_input.nc()*n, + gradient_input.k(), + gradient_input.nr()*gradient_input.nc()); + + + temp = trans(gi)*mat(filters); + col2img(temp, data_gradient, n, filters.nr(), filters.nc(), last_stride_y, last_stride_x, last_padding_y, last_padding_x); + } + } + + // ------------------------------------------------------------------------------------ + + void tensor_conv:: + get_gradient_for_filters ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& data, + tensor& filters_gradient + ) + { + matrix temp; + for (long n = 0; n < gradient_input.num_samples(); ++n) + { + auto gi = mat(gradient_input.host()+gradient_input.k()*gradient_input.nr()*gradient_input.nc()*n, + gradient_input.k(), + gradient_input.nr()*gradient_input.nc()); + + + img2col(temp, data, n, filters_gradient.nr(), filters_gradient.nc(), last_stride_y, last_stride_x, last_padding_y, last_padding_x); + if (n == 0) + { + if (add_to_output) + filters_gradient += gi*temp; + else + filters_gradient = gi*temp; + } + else + { + filters_gradient += gi*temp; + } + } + } + + // ------------------------------------------------------------------------------------ + + void copy_tensor( + bool add_to, + tensor& dest, + size_t dest_k_offset, + const tensor& src, + size_t src_k_offset, + size_t count_k + ) + { + const size_t dest_sample_size = static_cast(dest.nc() * dest.nr() * dest.k()); + const size_t src_sample_size = static_cast(src.nc() * src.nr() * src.k()); + + const size_t block_size = count_k * dest.nc() * dest.nr(); + + DLIB_CASSERT(dest.num_samples() == src.num_samples() && + dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size"); + DLIB_CASSERT(dest.k() - dest_k_offset >= count_k, "Not enough space in dest tensor"); + DLIB_CASSERT(src.k() - src_k_offset >= count_k, "Not enough space in src tensor"); + + float* dest_p = dest.host() + dest_k_offset * dest.nc() * dest.nr(); + const float* src_p = src.host() + src_k_offset * src.nc() * src.nr(); + + for (long i = 0; i < src.num_samples(); ++i) + { + if (add_to) + { + for (size_t j = 0; j < block_size; ++j) + dest_p[j] += src_p[j]; + } + else + { + ::memcpy(dest_p, src_p, block_size * sizeof(float)); + } + + dest_p += dest_sample_size; + src_p += src_sample_size; + } + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + } +} + + +#endif // DLIB_DNN_CPU_cPP_ + + diff --git a/ml/dlib/dlib/dnn/cpu_dlib.h b/ml/dlib/dlib/dnn/cpu_dlib.h new file mode 100644 index 000000000..330df01a2 --- /dev/null +++ b/ml/dlib/dlib/dnn/cpu_dlib.h @@ -0,0 +1,505 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CPU_H_ +#define DLIB_DNN_CPU_H_ + +// This file contains CPU implementations of the GPU based functions in cuda_dlib.h +// and cudnn_dlibapi.h + +#include "tensor.h" +#include "../geometry/rectangle.h" + +namespace dlib +{ + namespace cpu + { + + // ----------------------------------------------------------------------------------- + + void multiply ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + + void multiply_conv ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + + void multiply_zero_padded ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + + void scale_channels ( + bool add_to, + tensor& dest, + const tensor& src, + const tensor& scales + ); + + void add( + float beta, + tensor& dest, + float alpha, + const tensor& src + ); + + void assign_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ); + + void add ( + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + + void assign_conv_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ); + + // ----------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const float A, + const float B + ); + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B, + const float C + ); + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C, + const float D + ); + + void affine_transform_range( + size_t begin, + size_t end, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C + ); + + // ----------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ); + + // ----------------------------------------------------------------------------------- + + void affine_transform_conv( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ); + + // ----------------------------------------------------------------------------------- + + void affine_transform( + const rectangle& rect, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + float A, + float B, + float C + ); + + // ----------------------------------------------------------------------------------- + + void compute_adam_update ( + size_t begin, + size_t end, + tensor& s, + tensor& m, + tensor& v, + const float t, + const float learning_rate, + const float weight_decay, + const float momentum1, + const float momentum2, + const tensor& params, + const tensor& params_grad + ); + + // ----------------------------------------------------------------------------------- + + void batch_normalize_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ); + + void batch_normalize ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ); + + void batch_normalize_gradient ( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ); + + void batch_normalize_conv_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ); + + void batch_normalize_conv ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ); + + void batch_normalize_conv_gradient ( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ); + + // ----------------------------------------------------------------------------------- + + void threshold ( + tensor& data, + float thresh + ); + + void dot ( + const tensor& a, + const tensor& b, + tensor& result, + size_t idx + ); + + // ----------------------------------------------------------------------------------- + + void softmax ( + tensor& dest, + const tensor& src + ); + + void softmax_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + + // ------------------------------------------------------------------------------------ + + void softmax_all ( + tensor& dest, + const tensor& src + ); + + void softmax_all_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + + // ------------------------------------------------------------------------------------ + + void sigmoid ( + tensor& dest, + const tensor& src + ); + + void sigmoid_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + + // ------------------------------------------------------------------------------------ + + void relu ( + tensor& dest, + const tensor& src + ); + + void relu_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + + // ---------------------------------------------------------------------------------------- + + void prelu ( + tensor& dest, + const tensor& src, + const tensor& param + ); + + void prelu_gradient ( + tensor& grad, + const tensor& src, + const tensor& gradient_input, + const tensor& param, + tensor& params_grad + ); + + // ------------------------------------------------------------------------------------ + + void tanh ( + tensor& dest, + const tensor& src + ); + + void tanh_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + + // ---------------------------------------------------------------------------------------- + + void resize_bilinear ( + tensor& dest, + long dest_row_stride, + long dest_channel_stride, + const tensor& src, + long src_row_stride, + long src_channel_stride + ); + + void resize_bilinear_gradient ( + tensor& grad, + long grad_row_stride, + long grad_channel_stride, + const tensor& gradient_input, + long gradient_input_row_stride, + long gradient_input_channel_stride + ); + + inline void resize_bilinear ( + tensor& dest, + const tensor& src + ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); } + + inline void resize_bilinear_gradient ( + tensor& grad, + const tensor& gradient_input + ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); } + + // ----------------------------------------------------------------------------------- + + class pooling + { + public: + + pooling(const pooling&) = delete; + pooling& operator=(const pooling&) = delete; + + pooling ( + ); + + void clear( + ); + + void setup_max_pooling( + int window_height, + int window_width, + int stride_y, + int stride_x, + int padding_y, + int padding_x + ); + + void setup_avg_pooling( + int window_height, + int window_width, + int stride_y, + int stride_x, + int padding_y, + int padding_x + ); + + bool does_max_pooling( + ) const { return do_max_pooling; } + + void operator() ( + resizable_tensor& dest, + const tensor& src + ); + + void get_gradient( + const tensor& gradient_input, + const tensor& dest, + const tensor& src, + tensor& grad + ); + + private: + int window_height; + int window_width; + int stride_y; + int stride_x; + int padding_y; + int padding_x; + bool do_max_pooling; + + }; + + // ----------------------------------------------------------------------------------- + + class tensor_conv + { + public: + tensor_conv(const tensor_conv&) = delete; + tensor_conv& operator=(const tensor_conv&) = delete; + + tensor_conv() {} + + void clear( + ) {} + + void setup( + const tensor& data, /* not used but required for interface */ + const tensor& filters, /* not used but required for interface */ + int stride_y, + int stride_x, + int padding_y, + int padding_x + ) + { + (void)data; /* silence compiler */ + DLIB_CASSERT(stride_y > 0 && stride_x > 0); + DLIB_CASSERT(0 <= padding_y && padding_y < filters.nr()); + DLIB_CASSERT(0 <= padding_x && padding_x < filters.nc()); + last_stride_y = stride_y; + last_stride_x = stride_x; + last_padding_y = padding_y; + last_padding_x = padding_x; + } + + void operator() ( + const bool add_to_output, + resizable_tensor& output, + const tensor& data, + const tensor& filters + ); + + void operator() ( + const bool add_to_output, + tensor& output, + const tensor& data, + const tensor& filters + ); + + void get_gradient_for_data ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& filters, + tensor& data_gradient + ); + + void get_gradient_for_filters ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& data, + tensor& filters_gradient + ); + + private: + + long last_stride_y = 0; + long last_stride_x = 0; + long last_padding_y = 0; + long last_padding_x = 0; + }; + + // ----------------------------------------------------------------------------------- + + void copy_tensor( + bool add_to, + tensor& dest, + size_t dest_k_offset, + const tensor& src, + size_t src_k_offset, + size_t count_k + ); + + // ----------------------------------------------------------------------------------- + + } +} + +#ifdef NO_MAKEFILE +#include "cpu_dlib.cpp" +#endif + +#endif // DLIB_DNN_CPU_H_ + + diff --git a/ml/dlib/dlib/dnn/cublas_dlibapi.cpp b/ml/dlib/dlib/dnn/cublas_dlibapi.cpp new file mode 100644 index 000000000..376cc9f00 --- /dev/null +++ b/ml/dlib/dlib/dnn/cublas_dlibapi.cpp @@ -0,0 +1,165 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuBLAS_CPP_ +#define DLIB_DNN_CuBLAS_CPP_ + +#ifdef DLIB_USE_CUDA + +#include "cublas_dlibapi.h" +#include "cuda_utils.h" + +#include +#include + +static const char* cublas_get_error_string(cublasStatus_t s) +{ + switch(s) + { + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUDA Runtime API initialization failed."; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUDA Resources could not be allocated."; + default: + return "A call to cuBLAS failed"; + } +} + +// Check the return value of a call to the cuBLAS runtime for an error condition. +#define CHECK_CUBLAS(call) \ +do{ \ + const cublasStatus_t error = call; \ + if (error != CUBLAS_STATUS_SUCCESS) \ + { \ + std::ostringstream sout; \ + sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\ + sout << "code: " << error << ", reason: " << cublas_get_error_string(error);\ + throw dlib::cublas_error(sout.str()); \ + } \ +}while(false) + +namespace dlib +{ + namespace cuda + { + + // ----------------------------------------------------------------------------------- + + class cublas_context + { + public: + // not copyable + cublas_context(const cublas_context&) = delete; + cublas_context& operator=(const cublas_context&) = delete; + + cublas_context() + { + handles.resize(16); + } + ~cublas_context() + { + for (auto h : handles) + { + if (h) + cublasDestroy(h); + } + } + + cublasHandle_t get_handle ( + ) + { + int new_device_id; + CHECK_CUDA(cudaGetDevice(&new_device_id)); + // make room for more devices if needed + if (new_device_id >= (long)handles.size()) + handles.resize(new_device_id+16); + + // If we don't have a handle already for this device then make one + if (!handles[new_device_id]) + CHECK_CUBLAS(cublasCreate(&handles[new_device_id])); + + // Finally, return the handle for the current device + return handles[new_device_id]; + } + + private: + + std::vector handles; + }; + + static cublasHandle_t context() + { + thread_local cublas_context c; + return c.get_handle(); + } + + // ----------------------------------------------------------------------------------- + + void gemm ( + float beta, + tensor& dest, + float alpha, + const tensor& lhs, + bool trans_lhs, + const tensor& rhs, + bool trans_rhs + ) + { + // Recall that BLAS uses column major order so to deal with that we flip the + // order of the lhs and rhs arguments. + const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N; + const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N; + + const int dest_nr = dest.num_samples(); + const int dest_nc = dest.size()/dest_nr; + const int lhs_nr = lhs.num_samples(); + const int lhs_nc = lhs.size()/lhs_nr; + const int rhs_nr = rhs.num_samples(); + const int rhs_nc = rhs.size()/rhs_nr; + if (trans_lhs && trans_rhs) + { + DLIB_ASSERT( dest_nr == lhs_nc && + dest_nc == rhs_nr && + lhs_nr == rhs_nc) + } + else if (!trans_lhs && trans_rhs) + { + DLIB_ASSERT( dest_nr == lhs_nr && + dest_nc == rhs_nr && + lhs_nc == rhs_nc) + } + else if (trans_lhs && !trans_rhs) + { + DLIB_ASSERT( dest_nr == lhs_nc && + dest_nc == rhs_nc && + lhs_nr == rhs_nr) + } + else + { + DLIB_ASSERT( dest_nr == lhs_nr && + dest_nc == rhs_nc && + lhs_nc == rhs_nr) + } + + const int k = trans_rhs ? rhs_nc : rhs_nr; + CHECK_CUBLAS(cublasSgemm(context(), + transb, + transa, + dest_nc, dest_nr, k, + &alpha, + rhs.device(), rhs_nc, + lhs.device(), lhs_nc, + &beta, + dest.device(),dest_nc)); + } + + // ------------------------------------------------------------------------------------ + + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuBLAS_CPP_ + + + diff --git a/ml/dlib/dlib/dnn/cublas_dlibapi.h b/ml/dlib/dlib/dnn/cublas_dlibapi.h new file mode 100644 index 000000000..b46fd25ca --- /dev/null +++ b/ml/dlib/dlib/dnn/cublas_dlibapi.h @@ -0,0 +1,50 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuBLAS_H_ +#define DLIB_DNN_CuBLAS_H_ + +#ifdef DLIB_USE_CUDA + +#include "tensor.h" +#include "cuda_errors.h" + +namespace dlib +{ + namespace cuda + { + + // ----------------------------------------------------------------------------------- + + void gemm ( + float beta, + tensor& dest, + float alpha, + const tensor& lhs, + bool trans_lhs, + const tensor& rhs, + bool trans_rhs + ); + /*! + requires + - The dimensions of lhs and rhs must be compatible for matrix + multiplication. In particular: + - Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs) + - Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs) + - Let D == mat(dest) + - D.nr() == L.nr() && D.nc() == R.nc() + (i.e. dest must be preallocated and have the correct output dimensions) + - L.nc() == R.nr() + ensures + - performs: dest = alpha*L*R + beta*mat(dest) + !*/ + + // ------------------------------------------------------------------------------------ + + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuBLAS_H_ + + diff --git a/ml/dlib/dlib/dnn/cuda_data_ptr.cpp b/ml/dlib/dlib/dnn/cuda_data_ptr.cpp new file mode 100644 index 000000000..8abce0695 --- /dev/null +++ b/ml/dlib/dlib/dnn/cuda_data_ptr.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuDA_DATA_PTR_CPP_ +#define DLIB_DNN_CuDA_DATA_PTR_CPP_ + +#ifdef DLIB_USE_CUDA + +#include "cuda_data_ptr.h" +#include "cuda_utils.h" + +namespace dlib +{ + namespace cuda + { + + // ----------------------------------------------------------------------------------- + + cuda_data_void_ptr:: + cuda_data_void_ptr( + size_t n + ) : num(n) + { + if (n == 0) + return; + + void* data = nullptr; + + CHECK_CUDA(cudaMalloc(&data, n)); + pdata.reset(data, [](void* ptr){ + auto err = cudaFree(ptr); + if(err!=cudaSuccess) + std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(err) << std::endl; + }); + } + + // ------------------------------------------------------------------------------------ + + void memcpy( + void* dest, + const cuda_data_void_ptr& src + ) + { + if (src.size() != 0) + { + CHECK_CUDA(cudaMemcpy(dest, src.data(), src.size(), cudaMemcpyDefault)); + } + } + + // ------------------------------------------------------------------------------------ + + void memcpy( + cuda_data_void_ptr& dest, + const void* src + ) + { + if (dest.size() != 0) + { + CHECK_CUDA(cudaMemcpy(dest.data(), src, dest.size(), cudaMemcpyDefault)); + } + } + + // ------------------------------------------------------------------------------------ + + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuDA_DATA_PTR_CPP_ + + diff --git a/ml/dlib/dlib/dnn/cuda_data_ptr.h b/ml/dlib/dlib/dnn/cuda_data_ptr.h new file mode 100644 index 000000000..7eca608a0 --- /dev/null +++ b/ml/dlib/dlib/dnn/cuda_data_ptr.h @@ -0,0 +1,184 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuDA_DATA_PTR_H_ +#define DLIB_DNN_CuDA_DATA_PTR_H_ + +#ifdef DLIB_USE_CUDA + +#include +#include + +namespace dlib +{ + namespace cuda + { + + // ------------------------------------------------------------------------------------ + + class cuda_data_void_ptr + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a block of memory on a CUDA device. + !*/ + public: + + cuda_data_void_ptr() = default; + + cuda_data_void_ptr(size_t n); + /*! + ensures + - This object will allocate a device memory buffer of n bytes. + - #size() == n + !*/ + + void* data() { return pdata.get(); } + const void* data() const { return pdata.get(); } + operator void*() { return pdata.get(); } + operator const void*() const { return pdata.get(); } + + void reset() { pdata.reset(); } + + size_t size() const { return num; } + /*! + ensures + - returns the length of this buffer, in bytes. + !*/ + + private: + + size_t num = 0; + std::shared_ptr pdata; + }; + + // ------------------------------------------------------------------------------------ + + void memcpy( + void* dest, + const cuda_data_void_ptr& src + ); + /*! + requires + - dest == a pointer to at least src.size() bytes on the host machine. + ensures + - copies the GPU data from src into dest. + !*/ + + // ------------------------------------------------------------------------------------ + + void memcpy( + cuda_data_void_ptr& dest, + const void* src + ); + /*! + requires + - dest == a pointer to at least src.size() bytes on the host machine. + ensures + - copies the host data from src to the GPU memory buffer dest. + !*/ + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + template + class cuda_data_ptr + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a block of memory on a CUDA device. It is just a type safe + version of cuda_data_void_ptr. + !*/ + + public: + + static_assert(std::is_standard_layout::value, "You can only create basic standard layout types on the GPU"); + + cuda_data_ptr() = default; + cuda_data_ptr(size_t n) : num(n) + /*! + ensures + - This object will allocate a device memory buffer of n T objects. + - #size() == n + !*/ + { + if (n == 0) + return; + + pdata = cuda_data_void_ptr(n*sizeof(T)); + } + + T* data() { return (T*)pdata.data(); } + const T* data() const { return (T*)pdata.data(); } + + operator T*() { return (T*)pdata.data(); } + operator const T*() const { return (T*)pdata.data(); } + + void reset() { pdata.reset(); } + + size_t size() const { return num; } + + + friend void memcpy( + std::vector& dest, + const cuda_data_ptr& src + ) + { + dest.resize(src.size()); + if (src.size() != 0) + memcpy(dest.data(), src.pdata); + } + + friend void memcpy( + cuda_data_ptr& src, + const std::vector& dest + ) + { + if (dest.size() != src.size()) + dest = cuda_data_ptr(src.size()); + + if (src.size() != 0) + memcpy(src.pdata, dest.data()); + } + + private: + + size_t num = 0; + cuda_data_void_ptr pdata; + }; + + // ------------------------------------------------------------------------------------ + + class resizable_cuda_buffer + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a block of memory on a CUDA device that will be automatically + resized if requested size is larger than allocated. + !*/ + public: + cuda_data_void_ptr get(size_t size) + /*! + ensures + - This object will return the buffer of requested size of larger + - buffer.size() >= size + !*/ + { + if (buffer.size() < size) + { + buffer.reset(); + buffer = cuda_data_void_ptr(size); + } + return buffer; + } + private: + cuda_data_void_ptr buffer; + }; + + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuDA_DATA_PTR_H_ + diff --git a/ml/dlib/dlib/dnn/cuda_dlib.cu b/ml/dlib/dlib/dnn/cuda_dlib.cu new file mode 100644 index 000000000..6c37593f1 --- /dev/null +++ b/ml/dlib/dlib/dnn/cuda_dlib.cu @@ -0,0 +1,1630 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "cuda_utils.h" +#include "cuda_dlib.h" + + +namespace dlib +{ + namespace cuda + { + + // ----------------------------------------------------------------------------------- + + void set_device ( + int dev + ) + { + CHECK_CUDA(cudaSetDevice(dev)); + } + + int get_device ( + ) + { + int dev = 0; + CHECK_CUDA(cudaGetDevice(&dev)); + return dev; + } + + std::string get_device_name ( + int device + ) + { + cudaDeviceProp props; + CHECK_CUDA(cudaGetDeviceProperties(&props, device)); + return props.name; + } + + void set_current_device_blocking_sync( + ) + { + CHECK_CUDA(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync)); + } + + int get_num_devices ( + ) + { + int num_devices; + CHECK_CUDA(cudaGetDeviceCount(&num_devices)); + return num_devices; + } + + bool can_access_peer (int device_id, int peer_device_id) + { + int can_access; + CHECK_CUDA(cudaDeviceCanAccessPeer(&can_access, device_id, peer_device_id)); + return can_access != 0; + } + bool can_access_peer (const tensor& device, const tensor& peer_device) + { + return can_access_peer(device.device_id(), peer_device.device_id()); + } + + void device_synchronize (int dev) + { + raii_set_device set_dev(dev); + CHECK_CUDA(cudaDeviceSynchronize()); + } + void device_synchronize (const tensor& dev) { device_synchronize(dev.device_id()); } + + enable_peer_access:: + enable_peer_access( + int device_id, + int peer_device_id + ) : call_disable(false), device_id(device_id), peer_device_id(peer_device_id) + { + raii_set_device set_dev(device_id); + + auto err = cudaDeviceEnablePeerAccess(peer_device_id, 0); + if (err == cudaSuccess) + { + call_disable = true; + } + else if (err == cudaErrorPeerAccessAlreadyEnabled) + { + // call cudaGetLastError() to dispose of this error since we don't + // care. + auto err2 = cudaGetLastError(); + if (err2 != cudaErrorPeerAccessAlreadyEnabled) + CHECK_CUDA(err2); + } + else + { + CHECK_CUDA(err); + } + } + + + enable_peer_access:: + ~enable_peer_access() noexcept(false) + { + if (call_disable) + { + raii_set_device set_dev(device_id); + CHECK_CUDA(cudaDeviceDisablePeerAccess(peer_device_id)); + } + } + + // ----------------------------------------------------------------------------------- + // ----------------------------------------------------------------------------------- + // ----------------------------------------------------------------------------------- + + __global__ void _cuda_inverse_norms(float* invnorms, const float* data, size_t nr, size_t nc, const float eps) + { + // initialize invnorms before we begin. + for (auto i : grid_stride_range_y(0, nr)) + for (auto j : grid_stride_range(0, 1)) + invnorms[i] = eps; + __syncthreads(); + + for (auto i : grid_stride_range_y(0, nr)) + { + auto p = data + i*nc; + float temp = 0; + for (auto j : grid_stride_range(0, nc)) + temp += p[j]*p[j]; + + // and store the sum into invnorms[i] + warp_reduce_atomic_add(invnorms[i], temp); + } + __syncthreads(); + + for (auto i : grid_stride_range_y(0, nr)) + for (auto j : grid_stride_range(0, 1)) + invnorms[i] = 1.0/std::sqrt(invnorms[i]); + } + + void inverse_norms ( + resizable_tensor& invnorms, + const tensor& data, + const double eps + ) + { + invnorms.set_size(data.num_samples()); + launch_kernel(_cuda_inverse_norms, max_jobs(data.size()/data.num_samples(), data.num_samples()), + invnorms.device(), data.device(), data.num_samples(), data.size()/data.num_samples(), eps); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_dot_prods(float* out, const float* lhs, const float* rhs, size_t nr, size_t nc) + { + // initialize out before we begin. + for (auto i : grid_stride_range_y(0, nr)) + for (auto j : grid_stride_range(0, 1)) + out[i] = 0; + __syncthreads(); + + for (auto i : grid_stride_range_y(0, nr)) + { + auto l = lhs + i*nc; + auto r = rhs + i*nc; + float temp = 0; + for (auto j : grid_stride_range(0, nc)) + temp += l[j]*r[j]; + + // and store the sum into out[i] + warp_reduce_atomic_add(out[i], temp); + } + } + + __global__ void _cuda_dot_prods_add_to(float* out, const float* lhs, const float* rhs, size_t nr, size_t nc) + { + for (auto i : grid_stride_range_y(0, nr)) + { + auto l = lhs + i*nc; + auto r = rhs + i*nc; + float temp = 0; + for (auto j : grid_stride_range(0, nc)) + temp += l[j]*r[j]; + + // and store the sum into out[i] + warp_reduce_atomic_add(out[i], temp); + } + } + + void dot_prods ( + resizable_tensor& out, + const tensor& lhs, + const tensor& rhs + ) + { + DLIB_CASSERT(have_same_dimensions(lhs,rhs)); + + out.set_size(lhs.num_samples()); + if (out.size() == 0) + return; + + const auto nr = lhs.num_samples(); + const auto nc = lhs.size()/lhs.num_samples(); + + launch_kernel(_cuda_dot_prods, max_jobs(nc,nr), out.device_write_only(), lhs.device(), rhs.device(), nr, nc); + } + + void dot_prods ( + bool add_to, + tensor& out, + const tensor& lhs, + const tensor& rhs + ) + { + DLIB_CASSERT(have_same_dimensions(lhs,rhs)); + DLIB_CASSERT(out.k() == 1 && out.nr() == 1 && out.nc() == 1); + DLIB_CASSERT(out.size() == lhs.num_samples()); + + const auto nr = lhs.num_samples(); + const auto nc = lhs.size()/lhs.num_samples(); + + if (add_to) + launch_kernel(_cuda_dot_prods_add_to, max_jobs(nc,nr), out.device(), lhs.device(), rhs.device(), nr, nc); + else + launch_kernel(_cuda_dot_prods, max_jobs(nc,nr), out.device_write_only(), lhs.device(), rhs.device(), nr, nc); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_scale_columns(float* out, const float* m, const float* v, size_t nr, size_t nc) + { + for (auto j : grid_stride_range(0, nr*nc)) + { + out[j] = m[j]*v[j%nc]; + } + } + + void scale_columns ( + tensor& out, + const tensor& m, + const tensor& v + ) + { + launch_kernel(_cuda_scale_columns, max_jobs(m.size()), out.device(), m.device(), v.device(), m.num_samples(), m.size()/m.num_samples()); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_scale_rows(float* out, const float* m, const float* v, size_t nr, size_t nc) + { + for (auto j : grid_stride_range(0, nr*nc)) + { + out[j] = m[j]*v[j/nc]; + } + } + + void scale_rows ( + tensor& out, + const tensor& m, + const tensor& v + ) + { + launch_kernel(_cuda_scale_rows, max_jobs(m.size()), out.device(), m.device(), v.device(), m.num_samples(), m.size()/m.num_samples()); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_scale_rows2(float* out, const float* m1, const float* m2, const float* v1, const float* v2, size_t nr, size_t nc) + { + for (auto j : grid_stride_range(0, nr*nc)) + { + out[j] = (m1[j] - m2[j]*v1[j/nc]) * v2[j/nc]; + } + } + + __global__ void _cuda_scale_rows2_beta(const float beta, float* out, const float* m1, const float* m2, const float* v1, const float* v2, size_t nr, size_t nc) + { + for (auto j : grid_stride_range(0, nr*nc)) + { + out[j] = beta*out[j] + (m1[j] - m2[j]*v1[j/nc]) * v2[j/nc]; + } + } + + void scale_rows2 ( + float beta, + tensor& out, + const tensor& m1, + const tensor& m2, + const tensor& v1, + const tensor& v2 + ) + { + if (beta == 0) + { + launch_kernel(_cuda_scale_rows2, max_jobs(m1.size()), out.device(), + m1.device(), m2.device(), v1.device(), v2.device(), m1.num_samples(), + m1.size()/m1.num_samples()); + } + else + { + launch_kernel(_cuda_scale_rows2_beta, max_jobs(m1.size()), beta, + out.device(), m1.device(), m2.device(), v1.device(), v2.device(), + m1.num_samples(), m1.size()/m1.num_samples()); + } + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_exp(float* dest, const float* src, size_t n) + { + for (auto i : grid_stride_range(0, n)) + dest[i] = ::exp(src[i]); + } + + void exp ( + tensor& dest, + const tensor& src + ) + { + DLIB_ASSERT(dest.size() == src.size()); + launch_kernel(_cuda_exp, max_jobs(src.size()), dest.device(), src.device(), src.size()); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_log(float* dest, const float* src, size_t n) + { + for (auto i : grid_stride_range(0, n)) + dest[i] = ::log(src[i]); + } + + void log ( + tensor& dest, + const tensor& src + ) + { + DLIB_ASSERT(dest.size() == src.size()); + launch_kernel(_cuda_log, max_jobs(src.size()), dest.device(), src.device(), src.size()); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_log10(float* dest, const float* src, size_t n) + { + for (auto i : grid_stride_range(0, n)) + dest[i] = ::log10(src[i]); + } + + void log10 ( + tensor& dest, + const tensor& src + ) + { + DLIB_ASSERT(dest.size() == src.size()); + launch_kernel(_cuda_log10, max_jobs(src.size()), dest.device(), src.device(), src.size()); + } + + // ----------------------------------------------------------------------------------- + + __global__ void _cuda_multiply1(float* d, const float* s1, const float* s2, size_t n) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = s1[i]*s2[i]; + } + } + __global__ void _cuda_multiply2(float* d, const float* s1, const float* s2, + size_t n, size_t s1_n, size_t s2_n, size_t max_size) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = 0; + for (size_t j = i; j < max_size; j += n) + d[i] += s1[j%s1_n]*s2[j%s2_n]; + } + } + + __global__ void _cuda_multiply3(float* d, const float* s1, const float* s2, + size_t n, size_t s1_n, size_t s2_n) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = s1[i%s1_n]*s2[i%s2_n]; + } + } + + __global__ void _cuda_multiply1_add_to(float* d, const float* s1, const float* s2, size_t n) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] += s1[i]*s2[i]; + } + } + __global__ void _cuda_multiply2_add_to(float* d, const float* s1, const float* s2, + size_t n, size_t s1_n, size_t s2_n, size_t max_size) + { + for (auto i : grid_stride_range(0, n)) + { + for (size_t j = i; j < max_size; j += n) + d[i] += s1[j%s1_n]*s2[j%s2_n]; + } + } + + __global__ void _cuda_multiply3_add_to(float* d, const float* s1, const float* s2, + size_t n, size_t s1_n, size_t s2_n) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] += s1[i%s1_n]*s2[i%s2_n]; + } + } + + void multiply ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { + + DLIB_CASSERT(dest.k() == src1.k() && src1.k() == src2.k() && + dest.nr() == src1.nr() && src1.nr() == src2.nr() && + dest.nc() == src1.nc() && src1.nc() == src2.nc() ); + const long MD = std::max(std::max(dest.num_samples(),src1.num_samples()),src2.num_samples()); + DLIB_CASSERT((dest.num_samples()==1 || dest.num_samples()==MD) && + (src1.num_samples()==1 || src1.num_samples()==MD) && + (src2.num_samples()==1 || src2.num_samples()==MD) ); + + if (dest.size() == 0) + return; + + const size_t max_size = std::max(std::max(dest.size(),src1.size()),src2.size()); + const auto d = dest.host(); + const auto s1 = src1.host(); + const auto s2 = src2.host(); + if (dest.size() == src1.size() && src1.size() == src2.size()) + { + if (add_to) + launch_kernel(_cuda_multiply1_add_to,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), src1.size()); + else + launch_kernel(_cuda_multiply1,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), src1.size()); + } + else if (dest.num_samples() == 1) + { + if (add_to) + launch_kernel(_cuda_multiply2_add_to,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), + dest.size(), src1.size(), src2.size(), max_size); + else + launch_kernel(_cuda_multiply2,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), + dest.size(), src1.size(), src2.size(), max_size); + } + else + { + if (add_to) + launch_kernel(_cuda_multiply3_add_to,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), + dest.size(), src1.size(), src2.size()); + else + launch_kernel(_cuda_multiply3,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), + dest.size(), src1.size(), src2.size()); + } + } + + // ------------------------------------------------------------------------------------ + + __global__ void _cuda_multiply_conv(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks) + { + for (auto i : grid_stride_range(0, n)) + { + auto k = (i/bs)%ks; + d[i] = s1[i]*s2[k]; + } + } + + __global__ void _cuda_multiply_conv2(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks) + { + // zero initialize d before we begin. + for (auto i : grid_stride_range_y(0, ks)) + for (auto j : grid_stride_range(0, 1)) + d[i] = 0; + __syncthreads(); + + // loop over all the image planes + for (auto i : grid_stride_range_y(0, n)) + { + // sum all the elements in the i-th image plane + float temp = 0; + for (auto j : grid_stride_range(i*bs, (i+1)*bs)) + temp += s1[j]*s2[j]; + auto k = i%ks; + // and store the sum into d[k] + warp_reduce_atomic_add(d[k], temp); + } + } + + __global__ void _cuda_multiply_conv_add_to(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks) + { + for (auto i : grid_stride_range(0, n)) + { + auto k = (i/bs)%ks; + d[i] += s1[i]*s2[k]; + } + } + + __global__ void _cuda_multiply_conv2_add_to(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks) + { + // loop over all the image planes + for (auto i : grid_stride_range_y(0, n)) + { + // sum all the elements in the i-th image plane + float temp = 0; + for (auto j : grid_stride_range(i*bs, (i+1)*bs)) + temp += s1[j]*s2[j]; + auto k = i%ks; + // and store the sum into d[k] + warp_reduce_atomic_add(d[k], temp); + } + } + + + void multiply_conv ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { + if (have_same_dimensions(dest,src1)) + { + DLIB_CASSERT(src2.num_samples() == 1 && src2.nr() == 1 && src2.nc() == 1 && src2.k() == src1.k()); + if (dest.size() == 0) + return; + + if (add_to) + launch_kernel(_cuda_multiply_conv_add_to,max_jobs(dest.size()), + dest.device(), src1.device(), src1.size(), src2.device(), src1.nr()*src1.nc(), src1.k()); + else + launch_kernel(_cuda_multiply_conv,max_jobs(dest.size()), + dest.device(), src1.device(), src1.size(), src2.device(), src1.nr()*src1.nc(), src1.k()); + } + else + { + DLIB_CASSERT(have_same_dimensions(src1,src2)); + DLIB_CASSERT(dest.num_samples() == 1 && dest.nr() == 1 && dest.nc() == 1 && dest.k() == src1.k()); + if (dest.size() == 0) + return; + + + const auto bs = src1.nr()*src1.nc(); + const auto n = src1.num_samples()*src1.k(); + if (add_to) + launch_kernel(_cuda_multiply_conv2_add_to, max_jobs(bs,n), + dest.device(), src1.device(), n, src2.device(), bs, src1.k()); + else + launch_kernel(_cuda_multiply_conv2, max_jobs(bs,n), + dest.device(), src1.device(), n, src2.device(), bs, src1.k()); + } + + } + + // ------------------------------------------------------------------------------------ + + __global__ void _cuda_scale_channels_add_to(float* d, const float* src, size_t n, const float* scales, size_t bs) + { + for (auto i : grid_stride_range(0, n)) + { + auto k = i/bs; + d[i] += src[i]*scales[k]; + } + } + + __global__ void _cuda_scale_channels(float* d, const float* src, size_t n, const float* scales, size_t bs) + { + for (auto i : grid_stride_range(0, n)) + { + auto k = i/bs; + d[i] = src[i]*scales[k]; + } + } + + void scale_channels ( + bool add_to, + tensor& dest, + const tensor& src, + const tensor& scales + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src) && + scales.num_samples() == src.num_samples() && + scales.k() == src.k() && + scales.nr() == 1 && + scales.nc() == 1 ); + + if (dest.size() == 0) + return; + + if (add_to) + launch_kernel(_cuda_scale_channels_add_to,max_jobs(dest.size()), + dest.device(), src.device(), src.size(), scales.device(), src.nr()*src.nc()); + else + launch_kernel(_cuda_scale_channels,max_jobs(dest.size()), + dest.device_write_only(), src.device(), src.size(), scales.device(), src.nr()*src.nc()); + } + + // ------------------------------------------------------------------------------------ + + __global__ void _cuda_mult1(float* d, const float* s1, const float* s2, size_t n) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = s1[i]*s2[i]; + } + } + + __global__ void _cuda_mult1_add_to(float* d, const float* s1, const float* s2, size_t n) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] += s1[i]*s2[i]; + } + } + + __global__ void _cuda_mult2(float* d, const float* s1, const float* s2, + size_t dn, size_t dk, size_t dr, size_t dc, + size_t s1n, size_t s1k, size_t s1r, size_t s1c, + size_t s2n, size_t s2k, size_t s2r, size_t s2c) + { + for (auto i : grid_stride_range(0, dn*dk*dr*dc)) + { + size_t n,k,r,c; + unpack_idx(i, dk,dr,dc, n,k,r,c); + + float v1 = 0; + float v2 = 0; + + if (n < s1n && + k < s1k && + r < s1r && + c < s1c ) + { + v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)]; + } + + if (n < s2n && + k < s2k && + r < s2r && + c < s2c ) + { + v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)]; + } + + d[i] = v1*v2; + } + } + + __global__ void _cuda_mult2_add_to(float* d, const float* s1, const float* s2, + size_t dn, size_t dk, size_t dr, size_t dc, + size_t s1n, size_t s1k, size_t s1r, size_t s1c, + size_t s2n, size_t s2k, size_t s2r, size_t s2c) + { + for (auto i : grid_stride_range(0, dn*dk*dr*dc)) + { + size_t n,k,r,c; + unpack_idx(i, dk,dr,dc, n,k,r,c); + + float v1 = 0; + float v2 = 0; + + if (n < s1n && + k < s1k && + r < s1r && + c < s1c ) + { + v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)]; + } + + if (n < s2n && + k < s2k && + r < s2r && + c < s2c ) + { + v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)]; + } + + d[i] += v1*v2; + } + } + + void multiply_zero_padded ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { + if (dest.size() == 0) + return; + + // Do the simple and fast version if everything has the same dimensions + if (have_same_dimensions(dest, src1) && + have_same_dimensions(dest, src2)) + { + if (add_to) + launch_kernel(_cuda_mult1_add_to,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size()); + else + launch_kernel(_cuda_mult1,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size()); + } + else + { + if (add_to) + { + // Otherwise, do the more complex version with bounds checking. + launch_kernel(_cuda_mult2_add_to,max_jobs(dest.size()), + dest.device(), src1.device(), src2.device(), + dest.num_samples(), dest.k(), dest.nr(), dest.nc(), + src1.num_samples(), src1.k(), src1.nr(), src1.nc(), + src2.num_samples(), src2.k(), src2.nr(), src2.nc() + ); + } + else + { + // Otherwise, do the more complex version with bounds checking. + launch_kernel(_cuda_mult2,max_jobs(dest.size()), + dest.device(), src1.device(), src2.device(), + dest.num_samples(), dest.k(), dest.nr(), dest.nc(), + src1.num_samples(), src1.k(), src1.nr(), src1.nc(), + src2.num_samples(), src2.k(), src2.nr(), src2.nc() + ); + } + } + } + + // ------------------------------------------------------------------------------------ + + __global__ void _cuda_add1(float* d, const float* s1, const float* s2, size_t n) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = s1[i]+s2[i]; + } + } + + __global__ void _cuda_add2(float* d, const float* s1, const float* s2, + size_t dn, size_t dk, size_t dr, size_t dc, + size_t s1n, size_t s1k, size_t s1r, size_t s1c, + size_t s2n, size_t s2k, size_t s2r, size_t s2c) + { + for (auto i : grid_stride_range(0, dn*dk*dr*dc)) + { + size_t n,k,r,c; + unpack_idx(i, dk,dr,dc, n,k,r,c); + + float v1 = 0; + float v2 = 0; + + if (n < s1n && + k < s1k && + r < s1r && + c < s1c ) + { + v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)]; + } + + if (n < s2n && + k < s2k && + r < s2r && + c < s2c ) + { + v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)]; + } + + d[i] = v1+v2; + } + } + + void add ( + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { + if (dest.size() == 0) + return; + + // Do the simple and fast version if everything has the same dimensions + if (have_same_dimensions(dest, src1) && + have_same_dimensions(dest, src2)) + { + launch_kernel(_cuda_add1,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size()); + } + else + { + // Otherwise, do the more complex version with bounds checking. + launch_kernel(_cuda_add2,max_jobs(dest.size()), + dest.device(), src1.device(), src2.device(), + dest.num_samples(), dest.k(), dest.nr(), dest.nc(), + src1.num_samples(), src1.k(), src1.nr(), src1.nc(), + src2.num_samples(), src2.k(), src2.nr(), src2.nc() + ); + } + + } + + // ------------------------------------------------------------------------------------ + + __global__ void _cuda_affine_transform1(float* d, const float* s, size_t n, float A, float B) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = A*s[i] + B; + } + } + + __global__ void _cuda_affine_transform1_0(float* d, const float* s, size_t n, float A) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = A*s[i]; + } + } + + void affine_transform( + tensor& dest, + const tensor& src, + const float A, + const float B + ) + { + DLIB_CASSERT(dest.size()==src.size()); + if (B != 0) + launch_kernel(_cuda_affine_transform1,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A, B); + else + launch_kernel(_cuda_affine_transform1_0,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A); + } + + void affine_transform( + tensor& dest, + const tensor& src, + const float A + ) + { + DLIB_CASSERT(dest.size()==src.size()); + launch_kernel(_cuda_affine_transform1_0,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_affine_transform_rect( + float* d, + const float* s1, + const float* s2, + const float* s3, + float A, + float B, + float C, + size_t start_idx, + size_t n, + size_t rect_nc, + size_t total_nc + ) + { + for (auto i : grid_stride_range(0, n)) + { + size_t r = i/rect_nc; + size_t c = i%rect_nc; + size_t idx = r*total_nc + c + start_idx; + d[idx] = A*s1[idx] + B*s2[idx] + C*s3[idx]; + } + } + + void affine_transform( + const rectangle& rect, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + float A, + float B, + float C + ) + { + DLIB_CASSERT(dest.size() == src1.size()); + DLIB_CASSERT(dest.size() == src2.size()); + DLIB_CASSERT(dest.size() == src3.size()); + DLIB_CASSERT(dest.num_samples() == src1.num_samples()); + DLIB_CASSERT(dest.num_samples() == src2.num_samples()); + DLIB_CASSERT(dest.num_samples() == src3.num_samples()); + DLIB_CASSERT(rectangle(0,0, dest.size()/dest.num_samples()-1, dest.num_samples()-1).contains(rect)); + launch_kernel(_cuda_affine_transform_rect,max_jobs(rect.area()), + dest.device(), src1.device(), src2.device(), src3.device(), A, B, C, + rect.left() + rect.top()*(dest.size()/dest.num_samples()), + rect.area(), + rect.width(), + dest.size()/dest.num_samples()); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_affine_transform4(float* d, const float* s1, const float* s2, size_t n, float A, float B, float C) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = A*s1[i] + B*s2[i] + C; + } + } + + __global__ void _cuda_affine_transform4_0(float* d, const float* s1, const float* s2, size_t n, float A, float B) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = A*s1[i] + B*s2[i]; + } + } + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B, + const float C + ) + { + DLIB_CASSERT(dest.size()==src1.size()); + DLIB_CASSERT(dest.size()==src2.size()); + if (C != 0) + launch_kernel(_cuda_affine_transform4,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B, C); + else + launch_kernel(_cuda_affine_transform4_0,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B); + } + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B + ) + { + DLIB_CASSERT(dest.size()==src1.size()); + DLIB_CASSERT(dest.size()==src2.size()); + launch_kernel(_cuda_affine_transform4_0,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_add_scaled(float* d, const float* s, size_t n, float scale) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] += scale*s[i]; + } + } + + void add_scaled( + tensor& dest, + const float scale, + const tensor& src + ) + { + DLIB_CASSERT(dest.size()==src.size()); + launch_kernel(_cuda_add_scaled,max_jobs(dest.size()),dest.device(), src.device(), dest.size(), scale); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_add_cv_to_all_columns(float beta, float* dest, float alpha, const float* src, size_t size, size_t stride) + { + for (auto i : grid_stride_range(0, size)) + { + dest[i] = beta*dest[i] + alpha*src[i/stride]; + } + } + + __global__ void _cuda_add_cv_to_all_columns_no_beta(float* dest, float alpha, const float* src, size_t size, size_t stride) + { + for (auto i : grid_stride_range(0, size)) + { + dest[i] = alpha*src[i/stride]; + } + } + + void add_cv_to_all_columns( + float beta, + tensor& dest, + float alpha, + const tensor& src + ) + { + DLIB_CASSERT(dest.num_samples() == src.num_samples() && src.num_samples() == src.size()); + if (beta == 0) + launch_kernel(_cuda_add_cv_to_all_columns_no_beta, max_jobs(dest.size()), dest.device(), alpha, src.device(), dest.size(), dest.size()/dest.num_samples()); + else + launch_kernel(_cuda_add_cv_to_all_columns, max_jobs(dest.size()), beta, dest.device(), alpha, src.device(), dest.size(), dest.size()/dest.num_samples()); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_affine_transform5( + float* d, const float* s1, const float* s2, const float* s3, size_t n, float A, float B, float C, float D + ) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = A*s1[i] + B*s2[i] + C*s3[i] + D; + } + } + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C, + const float D + ) + { + DLIB_CASSERT(dest.size()==src1.size()); + DLIB_CASSERT(dest.size()==src2.size()); + DLIB_CASSERT(dest.size()==src3.size()); + launch_kernel(_cuda_affine_transform5,max_jobs(dest.size()),dest.device(), src1.device(), + src2.device(), src3.device(), dest.size(), A, B, C, D); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_affine_transform_range( + float* d, const float* s1, const float* s2, const float* s3, size_t begin, size_t end, float A, float B, float C + ) + { + for (auto i : grid_stride_range(begin, end)) + { + d[i] = A*s1[i] + B*s2[i] + C*s3[i]; + } + } + + + void affine_transform_range( + size_t begin, + size_t end, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C + ) + { + DLIB_CASSERT(dest.size()==src1.size()); + DLIB_CASSERT(dest.size()==src2.size()); + DLIB_CASSERT(dest.size()==src3.size()); + DLIB_CASSERT(begin <= end && end <= dest.size()); + launch_kernel(_cuda_affine_transform_range,max_jobs(end-begin), + dest.device(), src1.device(), + src2.device(), src3.device(), begin, end, A, B, C); + } + + // ----------------------------------------------------------------------------------- + + __global__ void _cuda_affine_transform2(float* d, const float* s, size_t n, const float* A, const float* B) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = A[i]*s[i] + B[i]; + } + } + __global__ void _cuda_affine_transform3(float* d, const float* s, size_t n, const float* A, const float* B, size_t bs) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = A[i%bs]*s[i] + B[i%bs]; + } + } + + void affine_transform( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ) + { + DLIB_CASSERT(have_same_dimensions(dest, src)); + DLIB_CASSERT( + ((A.num_samples()==1 && B.num_samples()==1) || + (A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples()))); + DLIB_CASSERT( + A.nr()==B.nr() && B.nr()==src.nr() && + A.nc()==B.nc() && B.nc()==src.nc() && + A.k() ==B.k() && B.k()==src.k(), + "\nA.nr(): " << A.nr() << "\nB.nr(): " << B.nr() << "\nsrc.nr(): " << src.nr() + <<"\nA.nc(): " << A.nc() << "\nB.nc(): " << B.nc() << "\nsrc.nc(): " << src.nc() + <<"\nA.k(): " << A.k() << "\nB.k(): " << B.k() << "\nsrc.k(): " << src.k() + ); + + if (A.num_samples() == 1) + { + launch_kernel(_cuda_affine_transform3,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A.device(), B.device(), A.size()); + } + else + { + launch_kernel(_cuda_affine_transform2,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A.device(), B.device()); + } + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_compute_adam_update( + size_t begin, + size_t end, + float* s, + float* m, + float* v, + const float alpha, + const float weight_decay, + const float momentum1, + const float momentum2, + const float* params, + const float* params_grad + ) + { + const float eps = 1e-8; + // The loop is equivalent to doing this: + // m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad); + // v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad); + // s = -alpha*m/(sqrt(v) + eps); + for (auto i : grid_stride_range(begin, end)) + { + float g = (weight_decay*params[i] + params_grad[i]); + m[i] = momentum1*m[i] + (1-momentum1)*g; + v[i] = momentum2*v[i] + (1-momentum2)*g*g; + s[i] = -alpha*m[i]/(std::sqrt(v[i]) + eps); + } + } + + void compute_adam_update ( + size_t begin, + size_t end, + tensor& s, + tensor& m, + tensor& v, + const float t, + const float learning_rate, + const float weight_decay, + const float momentum1, + const float momentum2, + const tensor& params, + const tensor& params_grad + ) + { + DLIB_CASSERT(s.size() == m.size() && + s.size() == v.size() && + s.size() == params.size() && + s.size() == params_grad.size()); + DLIB_CASSERT(begin <= end && end <= params.size()); + const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1, t)); + + launch_kernel(_cuda_compute_adam_update,max_jobs(end-begin), + begin, end, s.device(), m.device(), v.device(), alpha, weight_decay, + momentum1, momentum2, params.device(), params_grad.device()); + } + + // ----------------------------------------------------------------------------------- + + __global__ void _cuda_affine_transform_conv(float* d, const float* s, size_t n, const float* A, const float* B, size_t bs, size_t ks) + { + for (auto i : grid_stride_range(0, n)) + { + auto k = (i/bs)%ks; + d[i] = A[k]*s[i] + B[k]; + } + } + + void affine_transform_conv( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ) + { + DLIB_CASSERT(have_same_dimensions(dest, src)); + DLIB_CASSERT(have_same_dimensions(A, B)); + DLIB_CASSERT(A.num_samples() == 1 && A.nr() == 1 && A.nc() == 1 && A.k() == src.k()); + + launch_kernel(_cuda_affine_transform_conv,max_jobs(dest.size()), + dest.device(), src.device(), src.size(), A.device(), B.device(), src.nr()*src.nc(), src.k()); + } + + // ----------------------------------------------------------------------------------- + + __global__ void _add_bias_gradient(float* out, const float* in, size_t n, size_t total_n) + { + for (auto i : grid_stride_range(0, n)) + { + out[i] = in[i]; + for (size_t j = i+n; j < total_n; j+=n) + out[i] += in[j]; + } + } + + void assign_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ) + { + DLIB_CASSERT( + grad.num_samples() == 1 && + gradient_input.k() == grad.k() && + gradient_input.nr() == grad.nr() && + gradient_input.nc() == grad.nc() && + gradient_input.size() > 0); + + launch_kernel(_add_bias_gradient,max_jobs(grad.size()),grad.device(), gradient_input.device(), grad.size(), gradient_input.size()); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _set_tensor(float* out, size_t n, const float val) + { + for (auto i : grid_stride_range(0, n)) + out[i] = val; + } + + void set_tensor ( + tensor& t, + float value + ) + { + launch_kernel(_set_tensor, max_jobs(t.size()), t.device(), t.size(), value); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _scale_tensor(float* out, size_t n, const float val) + { + for (auto i : grid_stride_range(0, n)) + out[i] *= val; + } + + void scale_tensor ( + tensor& t, + float value + ) + { + launch_kernel(_scale_tensor, max_jobs(t.size()), t.device(), t.size(), value); + } + + // ----------------------------------------------------------------------------------- + // ----------------------------------------------------------------------------------- + + __global__ void _cuda_threshold(float* d, size_t n, float thresh) + { + for (auto i : grid_stride_range(0, n)) + { + d[i] = d[i]>thresh ? 1:0; + } + } + + void threshold ( + tensor& data, + float thresh + ) + { + launch_kernel(_cuda_threshold,max_jobs(data.size()),data.device(), data.size(), thresh); + } + + // ------------------------------------------------------------------------------------ + + __global__ void _cuda_dot(const float* a, const float* b, size_t n, float* result) + { + // Parallel sum everything into local temp variables. + float temp = 0; + for(auto i : grid_stride_range(0, n)) + temp += a[i]*b[i]; + + // Then do the warp reduce add thing to merge into one output value. + warp_reduce_atomic_add(*result, temp); + } + + + void dot ( + const tensor& a, + const tensor& b, + tensor& result, + size_t idx + ) + { + DLIB_CASSERT(a.size() == b.size()); + DLIB_CASSERT(idx < result.size()); + + launch_kernel(_cuda_dot, max_jobs(a.size()), a.device(), b.device(), a.size(), result.device()+idx); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_prelu(const float* s, float* d, size_t n, const float* pp) + { + const float p = *pp; + for (auto i : grid_stride_range(0, n)) + { + if (s[i] > 0) + d[i] = s[i]; + else + d[i] = p*s[i]; + } + } + + void prelu ( + tensor& dest, + const tensor& src, + const tensor& param + ) + { + launch_kernel(_cuda_prelu, max_jobs(dest.size()), + src.device(), dest.device(), src.size(), param.device()); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_prelu_gradient(float* out, const float* s, const float* gi, size_t n, const float* pp, float* ppgrad) + { + const float p = *pp; + float pgrad = 0; + for(auto i : grid_stride_range(0, n)) + { + if (s[i] > 0) + { + out[i] += gi[i]; + } + else + { + out[i] += p*gi[i]; + pgrad += gi[i]*s[i]; + } + } + + // Then do the warp reduce add thing to merge into one output value. + warp_reduce_atomic_add(*ppgrad, pgrad); + } + + void prelu_gradient ( + tensor& grad, + const tensor& src, + const tensor& gradient_input, + const tensor& param, + tensor& params_grad + ) + { + params_grad = 0; + launch_kernel(_cuda_prelu_gradient, max_jobs(grad.size()), + grad.device(), src.device(), gradient_input.device(), grad.size(), + param.device(), params_grad.device()); + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d, + size_t schan_size, int snr, int snc, const float* s, + const float x_scale, const float y_scale) + { + for(auto i : grid_stride_range(0, dsize)) + { + const int idx = i%dchan_size; + const int channel = i/dchan_size; + const int sidx = channel*schan_size; + const int r = idx/dnc; + const int c = idx%dnc; + + const float y = r*y_scale; + const int top = static_cast(::floor(y)); + const int bottom = ::min(top+1, snr-1); + const float tb_frac = y - top; + + const float x = c*x_scale; + const int left = static_cast(::floor(x)); + const int right = ::min(left+1, snc-1); + const float lr_frac = x - left; + + float tl = s[sidx+top*snc+left]; + float tr = s[sidx+top*snc+right]; + float bl = s[sidx+bottom*snc+left]; + float br = s[sidx+bottom*snc+right]; + + float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) + + tb_frac*((1-lr_frac)*bl + lr_frac*br); + + d[i] = temp; + } + } + + __global__ void _cuda_resize_bilinear_strided(size_t dsize, size_t dchan_size, size_t dnc, float* d, + size_t schan_size, int snr, int snc, const float* s, + const float x_scale, const float y_scale, + size_t dest_row_stride, size_t src_row_stride, size_t dest_chan_size_strided + ) + { + for(auto i : grid_stride_range(0, dsize)) + { + const int idx = i%dchan_size; + const int channel = i/dchan_size; + const int sidx = channel*schan_size; + const int r = idx/dnc; + const int c = idx%dnc; + const int didx = channel*dest_chan_size_strided + r*dest_row_stride+c; + + const float y = r*y_scale; + const int top = static_cast(::floor(y)); + const int bottom = ::min(top+1, snr-1); + const float tb_frac = y - top; + + const float x = c*x_scale; + const int left = static_cast(::floor(x)); + const int right = ::min(left+1, snc-1); + const float lr_frac = x - left; + + float tl = s[sidx+top*src_row_stride+left]; + float tr = s[sidx+top*src_row_stride+right]; + float bl = s[sidx+bottom*src_row_stride+left]; + float br = s[sidx+bottom*src_row_stride+right]; + + float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) + + tb_frac*((1-lr_frac)*bl + lr_frac*br); + + d[didx] = temp; + } + } + + void resize_bilinear ( + tensor& dest, + long dest_row_stride, + long dest_channel_stride, + const tensor& src, + long src_row_stride, + long src_channel_stride + ) + { + DLIB_CASSERT(is_same_object(dest, src)==false); + DLIB_CASSERT(dest.num_samples() == src.num_samples()); + DLIB_CASSERT(dest.k() == src.k()); + + if (dest.size() == 0 || src.size() == 0) + return; + + const float x_scale = (src.nc()-1)/(float)std::max((dest.nc()-1),1); + const float y_scale = (src.nr()-1)/(float)std::max((dest.nr()-1),1); + + if (dest.nc() == dest_row_stride && dest.nr()*dest.nc()==dest_channel_stride && + src.nc() == src_row_stride && src.nr()*src.nc()==src_channel_stride) + { + launch_kernel(_cuda_resize_bilinear, + dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(), + src.nr()*src.nc(), src.nr(), src.nc(), src.device(), + x_scale, y_scale); + } + else + { + launch_kernel(_cuda_resize_bilinear_strided, + dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(), + src_channel_stride, src.nr(), src.nc(), src.device(), + x_scale, y_scale, dest_row_stride, src_row_stride, dest_channel_stride); + } + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_resize_bilinear_gradient(size_t dsize, size_t dchan_size, size_t dnc, const float* d, + size_t schan_size, int snr, int snc, float* s, + const float x_scale, const float y_scale) + { + for(auto i : grid_stride_range(0, dsize)) + { + const float tmp = d[i]; + + const int idx = i%dchan_size; + const int channel = i/dchan_size; + const int sidx = channel*schan_size; + const int r = idx/dnc; + const int c = idx%dnc; + + const float y = r*y_scale; + const int top = static_cast(::floor(y)); + const int bottom = ::min(top+1, snr-1); + const float tb_frac = y - top; + + const float x = c*x_scale; + const int left = static_cast(::floor(x)); + const int right = ::min(left+1, snc-1); + const float lr_frac = x - left; + + + atomicAdd(s+sidx+top*snc+left, tmp*(1-tb_frac)*(1-lr_frac)); + atomicAdd(s+sidx+top*snc+right, tmp*(1-tb_frac)*(lr_frac)); + atomicAdd(s+sidx+bottom*snc+left, tmp*(tb_frac)*(1-lr_frac)); + atomicAdd(s+sidx+bottom*snc+right, tmp*(tb_frac)*(lr_frac)); + } + } + + __global__ void _cuda_resize_bilinear_gradient_strided(size_t dsize, size_t dchan_size, size_t dnc, const float* d, + size_t schan_size, int snr, int snc, float* s, + const float x_scale, const float y_scale, + size_t dest_row_stride, size_t src_row_stride, size_t dest_chan_size_strided + ) + { + for(auto i : grid_stride_range(0, dsize)) + { + + const int idx = i%dchan_size; + const int channel = i/dchan_size; + const int didx = channel*dest_chan_size_strided; + const int sidx = channel*schan_size; + const int r = idx/dnc; + const int c = idx%dnc; + + const float tmp = d[didx + r*dest_row_stride+c]; + + const float y = r*y_scale; + const int top = static_cast(::floor(y)); + const int bottom = ::min(top+1, snr-1); + const float tb_frac = y - top; + + const float x = c*x_scale; + const int left = static_cast(::floor(x)); + const int right = ::min(left+1, snc-1); + const float lr_frac = x - left; + + + atomicAdd(s+sidx+top*src_row_stride+left, tmp*(1-tb_frac)*(1-lr_frac)); + atomicAdd(s+sidx+top*src_row_stride+right, tmp*(1-tb_frac)*(lr_frac)); + atomicAdd(s+sidx+bottom*src_row_stride+left, tmp*(tb_frac)*(1-lr_frac)); + atomicAdd(s+sidx+bottom*src_row_stride+right, tmp*(tb_frac)*(lr_frac)); + } + } + + void resize_bilinear_gradient ( + tensor& grad, + long grad_row_stride, + long grad_channel_stride, + const tensor& gradient_input, + long gradient_input_row_stride, + long gradient_input_channel_stride + ) + { + DLIB_CASSERT(is_same_object(grad, gradient_input)==false); + DLIB_CASSERT(gradient_input.num_samples() == grad.num_samples()); + DLIB_CASSERT(gradient_input.k() == grad.k()); + + if (grad.size() == 0 || gradient_input.size() == 0) + return; + + const float x_scale = (grad.nc()-1)/(float)std::max((gradient_input.nc()-1),1); + const float y_scale = (grad.nr()-1)/(float)std::max((gradient_input.nr()-1),1); + + if (grad.nc() == grad_row_stride && grad.nr()*grad.nc()==grad_channel_stride && + gradient_input.nc() == gradient_input_row_stride && gradient_input.nr()*gradient_input.nc()==gradient_input_channel_stride) + { + launch_kernel(_cuda_resize_bilinear_gradient, + gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(), + grad.nr()*grad.nc(), grad.nr(), grad.nc(), grad.device(), + x_scale, y_scale); + } + else + { + launch_kernel(_cuda_resize_bilinear_gradient_strided, + gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(), + grad_channel_stride, grad.nr(), grad.nc(), grad.device(), + x_scale, y_scale, gradient_input_row_stride, grad_row_stride, gradient_input_channel_stride); + } + } + + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size) + { + for(auto i : grid_stride_range(0, size)) + { + size_t blk = i/block_size; + size_t j = i%block_size; + dest[blk*dest_stride + j] += src[blk*src_stride + j]; + } + } + + __global__ void _cuda_copy_tensor (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size) + { + for(auto i : grid_stride_range(0, size)) + { + size_t blk = i/block_size; + size_t j = i%block_size; + dest[blk*dest_stride + j] = src[blk*src_stride + j]; + } + } + + void copy_tensor( + bool add_to, + tensor& dest, + size_t dest_k_offset, + const tensor& src, + size_t src_k_offset, + size_t count_k + ) + { + const size_t dest_sample_size = static_cast(dest.nc() * dest.nr() * dest.k()); + const size_t src_sample_size = static_cast(src.nc() * src.nr() * src.k()); + + const size_t block_size = count_k * dest.nc() * dest.nr(); + + DLIB_CASSERT(dest.num_samples() == src.num_samples() && + dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size"); + DLIB_CASSERT(dest.k() - dest_k_offset >= count_k, "Not enough space in dest tensor"); + DLIB_CASSERT(src.k() - src_k_offset >= count_k, "Not enough space in src tensor"); + + float* dest_p = dest.device() + dest_k_offset * dest.nc() * dest.nr(); + const float* src_p = src.device() + src_k_offset * src.nc() * src.nr();; + + if (add_to) + { + launch_kernel(_cuda_copy_tensor_add_to, max_jobs(dest.size()), + dest_p, block_size*dest.num_samples(), + src_p, dest_sample_size, src_sample_size, block_size); + } + else + { + launch_kernel(_cuda_copy_tensor, max_jobs(dest.size()), + dest_p, block_size*dest.num_samples(), + src_p, dest_sample_size, src_sample_size, block_size); + } + } + + // ---------------------------------------------------------------------------------------- + + } +} + diff --git a/ml/dlib/dlib/dnn/cuda_dlib.h b/ml/dlib/dlib/dnn/cuda_dlib.h new file mode 100644 index 000000000..3a057ffc4 --- /dev/null +++ b/ml/dlib/dlib/dnn/cuda_dlib.h @@ -0,0 +1,469 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuDA_H_ +#define DLIB_DNN_CuDA_H_ + + +#include "tensor.h" +#include "../geometry/rectangle.h" + +namespace dlib +{ + namespace cuda + { + + // ---------------------------------------------------------------------------------------- + + void set_device ( + int dev + ); + + int get_device ( + ); + + int get_num_devices ( + ); + + std::string get_device_name ( + int device + ); + + void set_current_device_blocking_sync( + ); + + bool can_access_peer (int device_id, int peer_device_id); + bool can_access_peer (const tensor& device, const tensor& peer_device); + + void device_synchronize (int dev); + void device_synchronize (const tensor& dev); + + + class raii_set_device + { + public: + raii_set_device() = delete; + raii_set_device(const raii_set_device&) = delete; + raii_set_device& operator=(const raii_set_device&) = delete; + + raii_set_device(int dev) + { + prev_dev = get_device(); + set_device(dev); + } + + raii_set_device(const tensor& dev) + { + prev_dev = get_device(); + set_device(dev.device_id()); + } + + void operator() (int dev) + { + set_device(dev); + } + + void operator() (const tensor& dev) + { + set_device(dev.device_id()); + } + + ~raii_set_device() noexcept(false) + { + set_device(prev_dev); + } + + private: + int prev_dev; + }; + + +#ifdef DLIB_USE_CUDA + + class enable_peer_access + { + public: + + enable_peer_access() = delete; + enable_peer_access(const enable_peer_access&) = delete; + enable_peer_access& operator=(const enable_peer_access&) = delete; + + enable_peer_access( + int device_id, + int peer_device_id + ); + + enable_peer_access( + const tensor& device, + const tensor& peer_device + ) : enable_peer_access(device.device_id(), peer_device.device_id()) + {} + + ~enable_peer_access() noexcept(false); + + private: + + bool call_disable; + int device_id; + int peer_device_id; + }; + + // ----------------------------------------------------------------------------------- + + void inverse_norms ( + resizable_tensor& invnorms, + const tensor& data, + const double eps + ); + + void dot_prods ( + resizable_tensor& out, + const tensor& lhs, + const tensor& rhs + ); + + void dot_prods ( + bool add_to, + tensor& out, + const tensor& lhs, + const tensor& rhs + ); + + void scale_columns ( + tensor& out, + const tensor& m, + const tensor& v + ); + + void scale_rows ( + tensor& out, + const tensor& m, + const tensor& v + ); + + void scale_rows2 ( + float beta, + tensor& out, + const tensor& m1, + const tensor& m2, + const tensor& v1, + const tensor& v2 + ); + + void exp ( + tensor& dest, + const tensor& src + ); + + void log ( + tensor& dest, + const tensor& src + ); + + void log10 ( + tensor& dest, + const tensor& src + ); + + // ------------------------------------------------------------------------------------ + + void set_tensor ( + tensor& t, + float value + ); + + void scale_tensor ( + tensor& t, + float value + ); + + // ------------------------------------------------------------------------------------ + + void multiply ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + + void multiply_conv ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + + void multiply_zero_padded ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + + void scale_channels ( + bool add_to, + tensor& dest, + const tensor& src, + const tensor& scales + ); + + void add ( + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + + // ----------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const float A, + const float B + ); + + void affine_transform( + tensor& dest, + const tensor& src, + const float A + ); + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B, + const float C + ); + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B + ); + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C, + const float D + ); + + void affine_transform_range( + size_t begin, + size_t end, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C + ); + + void affine_transform( + const rectangle& rect, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + float A, + float B, + float C + ); + + // Note that this function isn't in the tt:: namespace because add_scaled() is + // called by cuda::add() so we don't need a tt:: version of add_scaled(). + void add_scaled( + tensor& dest, + const float scale, + const tensor& src + ); + + void add_cv_to_all_columns( + float beta, + tensor& dest, + float alpha, + const tensor& src + ); + + // ----------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ); + + // ----------------------------------------------------------------------------------- + + void affine_transform_conv( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ); + + // ---------------------------------------------------------------------------------------- + + void compute_adam_update ( + size_t begin, + size_t end, + tensor& s, + tensor& m, + tensor& v, + const float t, + const float learning_rate, + const float weight_decay, + const float momentum1, + const float momentum2, + const tensor& params, + const tensor& params_grad + ); + + // ----------------------------------------------------------------------------------- + + void assign_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ); + + // ----------------------------------------------------------------------------------- + + void threshold ( + tensor& data, + float thresh + ); + + // ---------------------------------------------------------------------------------------- + + void dot ( + const tensor& a, + const tensor& b, + tensor& result, + size_t idx + ); + + // ---------------------------------------------------------------------------------------- + + void prelu ( + tensor& dest, + const tensor& src, + const tensor& param + ); + + void prelu_gradient ( + tensor& grad, + const tensor& src, + const tensor& gradient_input, + const tensor& param, + tensor& params_grad + ); + + + // ---------------------------------------------------------------------------------------- + + void resize_bilinear ( + tensor& dest, + long dest_row_stride, + long dest_channel_stride, + const tensor& src, + long src_row_stride, + long src_channel_stride + ); + + void resize_bilinear_gradient ( + tensor& grad, + long grad_row_stride, + long grad_channel_stride, + const tensor& gradient_input, + long gradient_input_row_stride, + long gradient_input_channel_stride + ); + + inline void resize_bilinear ( + tensor& dest, + const tensor& src + ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); } + + inline void resize_bilinear_gradient ( + tensor& grad, + const tensor& gradient_input + ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); } + + // ---------------------------------------------------------------------------------------- + + void copy_tensor( + bool add_to, + tensor& dest, + size_t dest_k_offset, + const tensor& src, + size_t src_k_offset, + size_t count_k + ); + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + +#else // if DLIB_USE_CUDA NOT DEFINED + + inline void set_device ( + int id + ) + { + DLIB_CASSERT(id == 0, "dlib::cuda::set_device(id) called with an invalid device id."); + } + + inline int get_device ( + ){ return 0; } + + inline int get_num_devices ( + ) { return 1; } + + inline std::string get_device_name ( + int device + ) + { + DLIB_CASSERT(device == 0, "dlib::cuda::set_device(id) called with an invalid device id."); + return "CUDA_DISABLED"; + } + + inline void set_current_device_blocking_sync( + ) {} + + + inline bool can_access_peer (int , int ) + { return false; } + inline bool can_access_peer (const tensor& , const tensor& ) + { return false; } + + inline void device_synchronize (int ){} + inline void device_synchronize (const tensor& ){} + + class enable_peer_access + { + public: + enable_peer_access() = delete; + enable_peer_access(const enable_peer_access&) = delete; + enable_peer_access& operator=(const enable_peer_access&) = delete; + enable_peer_access( int, int ){} + enable_peer_access( const tensor&, const tensor& ) {} + }; + +#endif // DLIB_USE_CUDA + + } +} + + +#endif // DLIB_DNN_CuDA_H_ + diff --git a/ml/dlib/dlib/dnn/cuda_errors.h b/ml/dlib/dlib/dnn/cuda_errors.h new file mode 100644 index 000000000..fd28693c2 --- /dev/null +++ b/ml/dlib/dlib/dnn/cuda_errors.h @@ -0,0 +1,70 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CUDA_ERRORs_H_ +#define DLIB_CUDA_ERRORs_H_ + + +#include "../error.h" + +namespace dlib +{ + struct cuda_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown if any calls to the NVIDIA CUDA runtime + returns an error. + !*/ + + cuda_error(const std::string& message): error(message) {} + }; + + + struct cudnn_error : public cuda_error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown if any calls to the NVIDIA cuDNN library + returns an error. + !*/ + + cudnn_error(const std::string& message): cuda_error(message) {} + }; + + struct curand_error : public cuda_error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown if any calls to the NVIDIA cuRAND library + returns an error. + !*/ + + curand_error(const std::string& message): cuda_error(message) {} + }; + + struct cublas_error : public cuda_error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown if any calls to the NVIDIA cuBLAS library + returns an error. + !*/ + + cublas_error(const std::string& message): cuda_error(message) {} + }; + + struct cusolver_error : public cuda_error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown if any calls to the NVIDIA cuSolver library + returns an error. + !*/ + + cusolver_error(const std::string& message): cuda_error(message) {} + }; +} + + +#endif // DLIB_CUDA_ERRORs_H_ + diff --git a/ml/dlib/dlib/dnn/cuda_utils.h b/ml/dlib/dlib/dnn/cuda_utils.h new file mode 100644 index 000000000..673a4e8ad --- /dev/null +++ b/ml/dlib/dlib/dnn/cuda_utils.h @@ -0,0 +1,413 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CUDA_UtILS_H_ +#define DLIB_CUDA_UtILS_H_ + +#ifndef DLIB_USE_CUDA +#error "This file shouldn't be #included unless DLIB_USE_CUDA is #defined" +#endif + +#include "cuda_errors.h" +#include "../algs.h" +#include + +#include +#include +#include +#include +#include +#include + + +// Check the return value of a call to the CUDA runtime for an error condition. +#define CHECK_CUDA(call) \ +do{ \ + const cudaError_t error = call; \ + if (error != cudaSuccess) \ + { \ + std::ostringstream sout; \ + sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\ + sout << "code: " << error << ", reason: " << cudaGetErrorString(error);\ + throw dlib::cuda_error(sout.str()); \ + } \ +}while(false) + +// ---------------------------------------------------------------------------------------- + +#ifdef __CUDACC__ + +namespace dlib +{ + namespace cuda + { + + // ------------------------------------------------------------------------------------ + + __inline__ __device__ size_t pack_idx ( + size_t dim_size3, + size_t dim_size2, + size_t dim_size1, + size_t idx4, + size_t idx3, + size_t idx2, + size_t idx1 + ) + /*! + ensures + - Converts a 4D array index into a 1D index assuming row major layout. To + understand precisely what this function does, imagine we had an array + declared like this: + int ARRAY[anything][dim_size3][dim_size2][dim_size1]; + Then we could index it like this: + ARRAY[idx4][idx3][idx2][idx1] + or equivalently like this: + ((int*)ARRAY)[pack_idx(dim_size3,dim_size2,dim_size1, idx4,idx3,idx2,idx1)] + !*/ + { + return ((idx4*dim_size3 + idx3)*dim_size2 + idx2)*dim_size1 + idx1; + } + + __inline__ __device__ void unpack_idx ( + size_t idx, + size_t dim_size3, + size_t dim_size2, + size_t dim_size1, + size_t& idx4, + size_t& idx3, + size_t& idx2, + size_t& idx1 + ) + /*! + ensures + - This function computes the inverse of pack_idx(). Therefore, + if PACKED == pack_idx(dim_size3,dim_size2,dim_size1, idx4,idx3,idx2,idx1) + then unpack_idx(PACKED,dim_size3,dim_size2,dim_size1, IDX4,IDX3,IDX2,IDX1) + results in: + - IDX1 == idx1 + - IDX2 == idx2 + - IDX3 == idx3 + - IDX4 == idx4 + !*/ + { + idx1 = idx%dim_size1; + + idx /= dim_size1; + idx2 = idx%dim_size2; + + idx /= dim_size2; + idx3 = idx%dim_size3; + + idx /= dim_size3; + idx4 = idx; + } + + // ------------------------------------------------------------------------------------ + + // This function is from the article: + // http://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/ + __inline__ __device__ float warp_reduce_sum(float val) + { + for (int offset = warpSize/2; offset > 0; offset /= 2) +#if CUDART_VERSION >= 9000 + val += __shfl_down_sync(0xFFFFFFFF,val, offset); +#else + val += __shfl_down(val, offset); +#endif + return val; + } + + __inline__ __device__ bool is_first_thread_in_warp() + { + return (threadIdx.x & (warpSize - 1)) == 0; + } + + __inline__ __device__ void warp_reduce_atomic_add( + float& out, + float val + ) + /*! + ensures + - Atomically adds all the val variables in the current warp to out. + See this page for an extended discussion: + http://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/ + !*/ + { + val = warp_reduce_sum(val); + if (is_first_thread_in_warp()) + atomicAdd(&out, val); + } + + // ------------------------------------------------------------------------------------ + + struct max_jobs + { + max_jobs(int x) : num_x(x) {} + max_jobs(int x, int y) : num_x(x), num_y(y) {} + int num_x; + int num_y = 1; + }; + + template + void launch_kernel ( + Kernel K, + T ...args + ) + /*! + ensures + - launches the given kernel K(args...). The point of this function is to + automatically set the kernel launch parameters to something reasonable + based on the properties of the kernel and the current GPU card. + !*/ + { + int num_blocks, num_threads; + CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K)); + K<<>>(args...); + } + + template + void launch_kernel ( + Kernel K, + max_jobs m, + T ...args + ) + /*! + ensures + - This function is just like launch_kernel(K,args...) except that you can + additionally supply a max_jobs number that tells it how many possible + total threads could be used. This is useful when launching potentially + small jobs that might not need the number of threads suggested by + launch_kernel(). + !*/ + { + if (m.num_x == 0 || m.num_y == 0) + return; + int num_blocks, num_threads; + CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K)); + // Check if the job is really small and we don't really need to launch a kernel + // with this many blocks and threads. + if (num_blocks*num_threads > m.num_x*m.num_y) + num_blocks = (m.num_x*m.num_y+num_threads-1)/num_threads; + + if (m.num_y == 1) + { + K<<>>(args...); + } + else + { + /* + In general, the reason m.num_y!=1 (i.e. the reason you are in this + code path) is because we are using nested grid-stride loops. There are + two important things to note about what we are doing here. To + illustrate them we will talk about this little CUDA code snippet: + + // initialize out before we begin. + for (auto i : grid_stride_range_y(0, nr)) + for (auto j : grid_stride_range(0, 1)) + out[i] = 0; + + __syncthreads(); // synchronize threads in block + + // loop over some 2D thing and sum and store things into out. + for (auto i : grid_stride_range_y(0, nr)) + { + float temp = 0; + for (auto j : grid_stride_range(0, nc)) + temp += whatever[i*nc+j]; + + // store the sum into out[i] + warp_reduce_atomic_add(out[i], temp); + } + + First, we make sure the number of x threads is a multiple of 32 so that + you can use warp_reduce_atomic_add() inside the y loop. + + Second, we put the x block size to 1 so inter-block synchronization is + easier. For example, if the number of x blocks wasn't 1 the above code + would have a race condition in it. This is because the execution of + out[i]=0 would be done by blocks with blockIdx.x==0, but then in the + second set of loops, *all* the x blocks use out[i]. Since + __syncthreads() doesn't do any synchronization between blocks some of + the blocks might begin before the out[i]=0 statements finished and that + would be super bad. + */ + + // Try and make sure that the ratio of x to y threads is reasonable based + // on the respective size of our loops. + int x_threads = 32; + int y_threads = num_threads/32; + const int ratio = static_cast(std::round(put_in_range(1, y_threads, m.num_x/(double)m.num_y))); + x_threads *= ratio; + y_threads /= ratio; + + dim3 blocks(1,num_blocks); + dim3 threads(x_threads,y_threads); + K<<>>(args...); + } + } + + // ------------------------------------------------------------------------------------ + + class grid_stride_range + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool for making a for loop that loops over an entire block of + memory inside a kernel, but doing so in a way that parallelizes + appropriately across all the threads in a kernel launch. For example, + the following kernel would add the vector a to the vector b and store + the output in out (assuming all vectors are of dimension n): + __global__ void add_arrays( + const float* a, + const float* b, + float* out, + size_t n + ) + { + for (auto i : grid_stride_range(0, n)) + { + out[i] = a[i]+b[i]; + } + } + !*/ + + public: + __device__ grid_stride_range( + size_t ibegin_, + size_t iend_ + ) : + ibegin(ibegin_), + iend(iend_) + {} + + class iterator + { + public: + __device__ iterator() {} + __device__ iterator(size_t pos_) : pos(pos_) {} + + __device__ size_t operator*() const + { + return pos; + } + + __device__ iterator& operator++() + { + pos += gridDim.x * blockDim.x; + return *this; + } + + __device__ bool operator!=(const iterator& item) const + { return pos < item.pos; } + + private: + size_t pos; + }; + + __device__ iterator begin() const + { + return iterator(ibegin+blockDim.x * blockIdx.x + threadIdx.x); + } + __device__ iterator end() const + { + return iterator(iend); + } + private: + + size_t ibegin; + size_t iend; + }; + + // ------------------------------------------------------------------------------------ + + class grid_stride_range_y + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is just like grid_stride_range except that it looks at + CUDA's y thread index (e.g. threadIdx.y) instead of the x index. + Therefore, if you launch a cuda kernel with a statement like: + dim3 blocks(1,10); + dim3 threads(32,32); // You need to have x and y not equal to 1 to get parallelism over both loops. + add_arrays<<>>(a,b,out,nr,nc); + You can perform a nested 2D parallel for loop rather than doing just a + 1D for loop. + + So the code in the kernel would look like this if you wanted to add two + 2D matrices: + __global__ void add_arrays( + const float* a, + const float* b, + float* out, + size_t nr, + size_t nc + ) + { + for (auto r : grid_stride_range_y(0, nr)) + { + for (auto c : grid_stride_range(0, nc)) + { + auto i = r*nc+c; + out[i] = a[i]+b[i]; + } + } + } + !*/ + + public: + __device__ grid_stride_range_y( + size_t ibegin_, + size_t iend_ + ) : + ibegin(ibegin_), + iend(iend_) + {} + + class iterator + { + public: + __device__ iterator() {} + __device__ iterator(size_t pos_) : pos(pos_) {} + + __device__ size_t operator*() const + { + return pos; + } + + __device__ iterator& operator++() + { + pos += gridDim.y * blockDim.y; + return *this; + } + + __device__ bool operator!=(const iterator& item) const + { return pos < item.pos; } + + private: + size_t pos; + }; + + __device__ iterator begin() const + { + return iterator(ibegin+blockDim.y * blockIdx.y + threadIdx.y); + } + __device__ iterator end() const + { + return iterator(iend); + } + private: + + size_t ibegin; + size_t iend; + }; + + // ------------------------------------------------------------------------------------ + + } +} + +#endif // __CUDACC__ + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_CUDA_UtILS_H_ + diff --git a/ml/dlib/dlib/dnn/cudnn_dlibapi.cpp b/ml/dlib/dlib/dnn/cudnn_dlibapi.cpp new file mode 100644 index 000000000..6926561f1 --- /dev/null +++ b/ml/dlib/dlib/dnn/cudnn_dlibapi.cpp @@ -0,0 +1,1604 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuDNN_CPP_ +#define DLIB_DNN_CuDNN_CPP_ + +#ifdef DLIB_USE_CUDA + +#include "cudnn_dlibapi.h" +#include "tensor.h" +#include +#include +#include +#include +#include "cuda_utils.h" +#include "cpu_dlib.h" +#include "cuda_dlib.h" +#include "tensor_tools.h" + +static const char* cudnn_get_error_string(cudnnStatus_t s) +{ + switch(s) + { + case CUDNN_STATUS_NOT_INITIALIZED: + return "CUDA Runtime API initialization failed."; + case CUDNN_STATUS_ALLOC_FAILED: + return "CUDA Resources could not be allocated."; + case CUDNN_STATUS_BAD_PARAM: + return "CUDNN_STATUS_BAD_PARAM"; + case CUDNN_STATUS_EXECUTION_FAILED: + return "CUDNN_STATUS_EXECUTION_FAILED"; + case CUDNN_STATUS_NOT_SUPPORTED: + return "CUDNN_STATUS_NOT_SUPPORTED"; + case CUDNN_STATUS_ARCH_MISMATCH: + return "CUDNN_STATUS_ARCH_MISMATCH: Your GPU is too old and not supported by cuDNN"; + default: + return "A call to cuDNN failed"; + } +} + +// Check the return value of a call to the cuDNN runtime for an error condition. +#define CHECK_CUDNN(call) \ +do{ \ + const cudnnStatus_t error = call; \ + if (error != CUDNN_STATUS_SUCCESS) \ + { \ + std::ostringstream sout; \ + sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\ + sout << "code: " << error << ", reason: " << cudnn_get_error_string(error);\ + throw dlib::cudnn_error(sout.str()); \ + } \ +}while(false) + + +namespace dlib +{ + + namespace cuda + { + + // ------------------------------------------------------------------------------------ + + static cudnnTensorDescriptor_t descriptor(const tensor& t) + { + return (const cudnnTensorDescriptor_t)t.get_cudnn_tensor_descriptor().get_handle(); + } + static cudnnTensorDescriptor_t descriptor(const tensor_descriptor& t) + { + return (const cudnnTensorDescriptor_t)t.get_handle(); + } + + // ------------------------------------------------------------------------------------ + + class cudnn_context + { + public: + // not copyable + cudnn_context(const cudnn_context&) = delete; + cudnn_context& operator=(const cudnn_context&) = delete; + + cudnn_context() + { + handles.resize(16); + } + ~cudnn_context() + { + for (auto h : handles) + { + if (h) + cudnnDestroy(h); + } + } + + cudnnHandle_t get_handle ( + ) + { + int new_device_id; + CHECK_CUDA(cudaGetDevice(&new_device_id)); + // make room for more devices if needed + if (new_device_id >= (long)handles.size()) + handles.resize(new_device_id+16); + + // If we don't have a handle already for this device then make one + if (!handles[new_device_id]) + CHECK_CUDNN(cudnnCreate(&handles[new_device_id])); + + // Finally, return the handle for the current device + return handles[new_device_id]; + } + + private: + + std::vector handles; + }; + + static cudnnHandle_t context() + { + thread_local cudnn_context c; + return c.get_handle(); + } + // ------------------------------------------------------------------------------------ + + class cudnn_device_buffer + { + public: + // not copyable + cudnn_device_buffer(const cudnn_device_buffer&) = delete; + cudnn_device_buffer& operator=(const cudnn_device_buffer&) = delete; + + cudnn_device_buffer() + { + buffers.resize(16); + } + ~cudnn_device_buffer() + { + } + + std::shared_ptr get_buffer ( + ) + { + int new_device_id; + CHECK_CUDA(cudaGetDevice(&new_device_id)); + // make room for more devices if needed + if (new_device_id >= (long)buffers.size()) + buffers.resize(new_device_id+16); + + // If we don't have a buffer already for this device then make one + std::shared_ptr buff = buffers[new_device_id].lock(); + if (!buff) + { + buff = std::make_shared(); + buffers[new_device_id] = buff; + } + + // Finally, return the buffer for the current device + return buff; + } + + private: + + std::vector> buffers; + }; + + + static std::shared_ptr device_global_buffer() + { + thread_local cudnn_device_buffer buffer; + return buffer.get_buffer(); + } + // ------------------------------------------------------------------------------------ + + class cudnn_activation_descriptor + { + public: + // not copyable + cudnn_activation_descriptor(const cudnn_activation_descriptor&) = delete; + cudnn_activation_descriptor& operator=(const cudnn_activation_descriptor&) = delete; + + cudnn_activation_descriptor( + cudnnActivationMode_t mode, + cudnnNanPropagation_t reluNanOpt, + double reluCeiling + ) + { + CHECK_CUDNN(cudnnCreateActivationDescriptor(&handle)); + CHECK_CUDNN(cudnnSetActivationDescriptor(handle, mode, reluNanOpt, reluCeiling)); + } + + ~cudnn_activation_descriptor() + { + cudnnDestroyActivationDescriptor(handle); + } + + cudnnActivationDescriptor_t get_handle ( + ) + { + return handle; + } + private: + cudnnActivationDescriptor_t handle; + }; + + static cudnnActivationDescriptor_t relu_activation_descriptor() + { + thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN,0); + return des.get_handle(); + } + + static cudnnActivationDescriptor_t sigmoid_activation_descriptor() + { + thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN,0); + return des.get_handle(); + } + + static cudnnActivationDescriptor_t tanh_activation_descriptor() + { + thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN,0); + return des.get_handle(); + } + + // ------------------------------------------------------------------------------------ + + tensor_descriptor:: + tensor_descriptor( + ) : handle(nullptr) + { + } + + tensor_descriptor:: + ~tensor_descriptor() + { + set_size(0,0,0,0); + } + + void tensor_descriptor:: + set_size( + int n, + int k, + int nr, + int nc + ) + { + if (handle) + { + cudnnDestroyTensorDescriptor((cudnnTensorDescriptor_t)handle); + handle = nullptr; + } + + if (n != 0 && nr != 0 && nc != 0 && k != 0) + { + cudnnTensorDescriptor_t h; + CHECK_CUDNN(cudnnCreateTensorDescriptor(&h)); + handle = h; + + CHECK_CUDNN(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle, + CUDNN_TENSOR_NCHW, + CUDNN_DATA_FLOAT, + n, + k, + nr, + nc)); + } + } + + void tensor_descriptor:: + get_size ( + int& n, + int& k, + int& nr, + int& nc + ) const + { + if (handle) + { + int nStride, cStride, hStride, wStride; + cudnnDataType_t datatype; + CHECK_CUDNN(cudnnGetTensor4dDescriptor((cudnnTensorDescriptor_t)handle, + &datatype, + &n, + &k, + &nr, + &nc, + &nStride, + &cStride, + &hStride, + &wStride)); + } + else + { + n = 0; + k = 0; + nr = 0; + nc = 0; + } + } + + // ------------------------------------------------------------------------------------ + + void add( + float beta, + tensor& dest, + float alpha, + const tensor& src + ) + { + DLIB_CASSERT( + (have_same_dimensions(src, dest) || + (src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) || + (src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) || + (src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()) || + (src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1)) && + is_same_object(src,dest) == false , + "\n\t dest.num_samples(): " << dest.num_samples() + <<"\n\t dest.k(): " << dest.k() + <<"\n\t dest.nr(): " << dest.nr() + <<"\n\t dest.nc(): " << dest.nc() + <<"\n\t src.num_samples(): " << src.num_samples() + <<"\n\t src.k(): " << src.k() + <<"\n\t src.nr(): " << src.nr() + <<"\n\t src.nc(): " << src.nc() + ); + + if (dest.size() == src.size() && beta == 1) + { + // Call the dlib function in this case since it's faster than the one that + // comes with cuDNN (at least as of cuDNN v4). + add_scaled(dest, alpha, src); + return; + } + else if (src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1) + { + add_cv_to_all_columns(beta, dest, alpha, src); + return; + } + + CHECK_CUDNN(cudnnAddTensor(context(), + &alpha, + descriptor(src), + src.device(), + &beta, + descriptor(dest), + dest.device())); + } + + void assign_conv_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ) + { + DLIB_CASSERT( + grad.num_samples() == 1 && + grad.k() >= 1 && + grad.nr() == 1 && + grad.nc() == 1 && + gradient_input.k() == grad.k() && + gradient_input.size() > 0 && + is_same_object(grad,gradient_input) == false + ); + + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN(cudnnConvolutionBackwardBias(context(), + &alpha, + descriptor(gradient_input), + gradient_input.device(), + &beta, + descriptor(grad), + grad.device())); + } + + // ------------------------------------------------------------------------------------ + + void batch_normalize_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ) + { + DLIB_CASSERT( + gamma.num_samples() == 1 && + gamma.nr() == src.nr() && + gamma.nc() == src.nc() && + gamma.k() == src.k() && + have_same_dimensions(gamma, beta) && + have_same_dimensions(gamma, running_means) && + have_same_dimensions(gamma, running_variances) && + eps > 0, + "\ngamma.num_samples(): " << gamma.num_samples() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\nbeta.num_samples(): " << beta.num_samples() << + "\nbeta.k(): " << beta.k() << + "\nbeta.nr(): " << beta.nr() << + "\nbeta.nc(): " << beta.nc() << + "\nrunning_means.num_samples(): " << running_means.num_samples() << + "\nrunning_means.k(): " << running_means.k() << + "\nrunning_means.nr(): " << running_means.nr() << + "\nrunning_means.nc(): " << running_means.nc() << + "\nrunning_variances.num_samples(): " << running_variances.num_samples() << + "\nrunning_variances.k(): " << running_variances.k() << + "\nrunning_variances.nr(): " << running_variances.nr() << + "\nrunning_variances.nc(): " << running_variances.nc() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\neps: " << eps + ); + const float in_scale = 1; + const float out_scale = 0; + + dest.copy_size(src); + + CHECK_CUDNN(cudnnBatchNormalizationForwardInference( + context(), + CUDNN_BATCHNORM_PER_ACTIVATION, + &in_scale, + &out_scale, + descriptor(src), + src.device(), + descriptor(dest), + dest.device(), + descriptor(gamma), + gamma.device(), + beta.device(), + running_means.device(), + running_variances.device(), + eps)); + } + + void batch_normalize ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ) + { + DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor); + DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means)); + DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds)); + DLIB_CASSERT( + src.num_samples() > 1 && + gamma.num_samples() == 1 && + beta.num_samples() == 1 && + gamma.nr() == beta.nr() && beta.nr() == src.nr() && + gamma.nc() == beta.nc() && beta.nc() == src.nc() && + gamma.k() == beta.k() && beta.k() == src.k() && + eps > 0, + "\ngamma.num_samples(): " << gamma.num_samples() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\nbeta.num_samples(): " << beta.num_samples() << + "\nbeta.k(): " << beta.k() << + "\nbeta.nr(): " << beta.nr() << + "\nbeta.nc(): " << beta.nc() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\neps: " << eps + ); + + const float in_scale = 1; + const float out_scale = 0; + + dest.copy_size(src); + means.set_size(1, src.k(), src.nr(), src.nc()); + invstds.copy_size(means); + running_means.copy_size(means); + running_variances.copy_size(means); + // cuDNN requires that running_means and running_variances be initialized to + // some valid float values even if the averaging factor would have ignored + // them. + if (averaging_factor == 1) + { + running_means = 0; + running_variances = 1; + } + + CHECK_CUDNN(cudnnBatchNormalizationForwardTraining( + context(), + CUDNN_BATCHNORM_PER_ACTIVATION, + &in_scale, + &out_scale, + descriptor(src), + src.device(), + descriptor(dest), + dest.device(), + descriptor(gamma), + gamma.device(), + beta.device(), + averaging_factor, + running_means.device(), + running_variances.device(), + eps, + means.device(), + invstds.device())); + } + + void batch_normalize_gradient( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ) + { + const long num = src.k()*src.nr()*src.nc(); + DLIB_CASSERT(src.num_samples() > 1); + DLIB_CASSERT(num == (long)means.size()); + DLIB_CASSERT(num == (long)invstds.size()); + DLIB_CASSERT(num == (long)gamma.size()); + DLIB_CASSERT(num == (long)gamma_grad.size()); + DLIB_CASSERT(num == (long)beta_grad.size()); + DLIB_CASSERT(have_same_dimensions(gradient_input, src)); + DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); + DLIB_CASSERT(eps > 0); + + const float in_scale = 1; + const float out_scale = 1; + const float in_scale_params = 1; + const float out_scale_params = 0; + + CHECK_CUDNN(cudnnBatchNormalizationBackward( + context(), + CUDNN_BATCHNORM_PER_ACTIVATION, + &in_scale, + &out_scale, + &in_scale_params, + &out_scale_params, + descriptor(src), + src.device(), + descriptor(gradient_input), + gradient_input.device(), + descriptor(src_grad), + src_grad.device(), + descriptor(gamma), + gamma.device(), + gamma_grad.device(), + beta_grad.device(), + eps, + means.device(), + invstds.device())); + } + + // ------------------------------------------------------------------------------------ + + void batch_normalize_conv_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ) + { + DLIB_CASSERT( + gamma.num_samples() == 1 && + gamma.nr() == 1 && + gamma.nc() == 1 && + gamma.k() == src.k() && + have_same_dimensions(gamma, beta) && + have_same_dimensions(gamma, running_means) && + have_same_dimensions(gamma, running_variances) && + eps > 0, + "\ngamma.num_samples(): " << gamma.num_samples() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\nbeta.num_samples(): " << beta.num_samples() << + "\nbeta.k(): " << beta.k() << + "\nbeta.nr(): " << beta.nr() << + "\nbeta.nc(): " << beta.nc() << + "\nrunning_means.num_samples(): " << running_means.num_samples() << + "\nrunning_means.k(): " << running_means.k() << + "\nrunning_means.nr(): " << running_means.nr() << + "\nrunning_means.nc(): " << running_means.nc() << + "\nrunning_variances.num_samples(): " << running_variances.num_samples() << + "\nrunning_variances.k(): " << running_variances.k() << + "\nrunning_variances.nr(): " << running_variances.nr() << + "\nrunning_variances.nc(): " << running_variances.nc() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\neps: " << eps + ); + const float in_scale = 1; + const float out_scale = 0; + + dest.copy_size(src); + + CHECK_CUDNN(cudnnBatchNormalizationForwardInference( + context(), + CUDNN_BATCHNORM_SPATIAL, + &in_scale, + &out_scale, + descriptor(src), + src.device(), + descriptor(dest), + dest.device(), + descriptor(gamma), + gamma.device(), + beta.device(), + running_means.device(), + running_variances.device(), + eps)); + } + + void batch_normalize_conv ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ) + { + DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor); + DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means)); + DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds)); + DLIB_CASSERT( + src.num_samples() > 1 && + gamma.num_samples() == 1 && + beta.num_samples() == 1 && + gamma.nr() == 1 && + beta.nr() == 1 && + gamma.nc() == 1 && + beta.nc() == 1 && + gamma.k() == beta.k() && beta.k() == src.k() && + eps > 0, + "\ngamma.num_samples(): " << gamma.num_samples() << + "\ngamma.k(): " << gamma.k() << + "\ngamma.nr(): " << gamma.nr() << + "\ngamma.nc(): " << gamma.nc() << + "\nbeta.num_samples(): " << beta.num_samples() << + "\nbeta.k(): " << beta.k() << + "\nbeta.nr(): " << beta.nr() << + "\nbeta.nc(): " << beta.nc() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\neps: " << eps + ); + const float in_scale = 1; + const float out_scale = 0; + + dest.copy_size(src); + means.set_size(1, src.k()); + invstds.copy_size(means); + running_means.copy_size(means); + running_variances.copy_size(means); + // cuDNN requires that running_means and running_variances be initialized to + // some valid float values even if the averaging factor would have ignored + // them. + if (averaging_factor == 1) + { + running_means = 0; + running_variances = 1; + } + + CHECK_CUDNN(cudnnBatchNormalizationForwardTraining( + context(), + CUDNN_BATCHNORM_SPATIAL, + &in_scale, + &out_scale, + descriptor(src), + src.device(), + descriptor(dest), + dest.device(), + descriptor(gamma), + gamma.device(), + beta.device(), + averaging_factor, + running_means.device(), + running_variances.device(), + eps, + means.device(), + invstds.device())); + } + + void batch_normalize_conv_gradient( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ) + { + DLIB_CASSERT(src.k() == (long)means.size()); + DLIB_CASSERT(src.k() == (long)invstds.size()); + DLIB_CASSERT(src.k() == (long)gamma.size()); + DLIB_CASSERT(src.k() == (long)gamma_grad.size()); + DLIB_CASSERT(src.k() == (long)beta_grad.size()); + DLIB_CASSERT(have_same_dimensions(gradient_input, src)); + DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); + DLIB_CASSERT(eps > 0); + + const float in_scale = 1; + const float out_scale = 1; + const float in_scale_params = 1; + const float out_scale_params = 0; + + CHECK_CUDNN(cudnnBatchNormalizationBackward( + context(), + CUDNN_BATCHNORM_SPATIAL, + &in_scale, + &out_scale, + &in_scale_params, + &out_scale_params, + descriptor(src), + src.device(), + descriptor(gradient_input), + gradient_input.device(), + descriptor(src_grad), + src_grad.device(), + descriptor(gamma), + gamma.device(), + gamma_grad.device(), + beta_grad.device(), + eps, + means.device(), + invstds.device())); + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + tensor_conv:: + tensor_conv( + ) : + filter_handle(nullptr), + conv_handle(nullptr), + forward_algo(0), + backward_data_algo(0), + backward_filters_algo(0) + { + clear(); + } + + void tensor_conv:: + clear ( + ) + { + if (filter_handle) + cudnnDestroyFilterDescriptor((cudnnFilterDescriptor_t)filter_handle); + if (conv_handle) + cudnnDestroyConvolutionDescriptor((cudnnConvolutionDescriptor_t)conv_handle); + filter_handle = nullptr; + conv_handle = nullptr; + out_num_samples = 0; + out_k = 0; + out_nr = 0; + out_nc = 0; + + stride_y = 0; + stride_x = 0; + padding_y = 0; + padding_x = 0; + data_num_samples = 0; + data_k = 0; + data_nr = 0; + data_nc = 0; + filters_num_samples = 0; + filters_k = 0; + filters_nr = 0; + filters_nc = 0; + + forward_algo = 0; + backward_data_algo = 0; + backward_filters_algo = 0; + + forward_workspace_size_in_bytes = 0; + backward_data_workspace_size_in_bytes = 0; + backward_filters_workspace_size_in_bytes = 0; + + forward_workspace.reset(); + backward_data_workspace.reset(); + backward_filters_workspace.reset(); + workspace.reset(); + } + + void tensor_conv:: + setup( + const tensor& data, + const tensor& filters, + int stride_y_, + int stride_x_, + int padding_y_, + int padding_x_ + ) + { + DLIB_CASSERT(data.k() == filters.k()); + + // if the last call to setup gave the same exact settings then don't do + // anything. + if (stride_y_ == stride_y && + stride_x_ == stride_x && + padding_y_ == padding_y && + padding_x_ == padding_x && + data_num_samples == data.num_samples() && + data_k == data.k() && + data_nr == data.nr() && + data_nc == data.nc() && + filters_num_samples == filters.num_samples() && + filters_k == filters.k() && + filters_nr == filters.nr() && + filters_nc == filters.nc()) + { + return; + } + + clear(); + try + { + stride_y = stride_y_; + stride_x = stride_x_; + padding_y = padding_y_; + padding_x = padding_x_; + data_num_samples = data.num_samples(); + data_k = data.k(); + data_nr = data.nr(); + data_nc = data.nc(); + filters_num_samples = filters.num_samples(); + filters_k = filters.k(); + filters_nr = filters.nr(); + filters_nc = filters.nc(); + + CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle, + CUDNN_DATA_FLOAT, + CUDNN_TENSOR_NCHW, + filters.num_samples(), + filters.k(), + filters.nr(), + filters.nc())); + + CHECK_CUDNN(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle)); +#if CUDNN_MAJOR >= 6 + CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle, + padding_y, // vertical padding + padding_x, // horizontal padding + stride_y, + stride_x, + 1, 1, // must be 1,1 + CUDNN_CROSS_CORRELATION, + CUDNN_DATA_FLOAT)); // could also be CUDNN_CONVOLUTION +#else + CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle, + padding_y, // vertical padding + padding_x, // horizontal padding + stride_y, + stride_x, + 1, 1, // must be 1,1 + CUDNN_CROSS_CORRELATION)); // could also be CUDNN_CONVOLUTION +#endif + + CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim( + (const cudnnConvolutionDescriptor_t)conv_handle, + descriptor(data), + (const cudnnFilterDescriptor_t)filter_handle, + &out_num_samples, + &out_k, + &out_nr, + &out_nc)); + + tensor_descriptor dest_desc; + dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc); + + // Pick which forward algorithm we will use and allocate the necessary + // workspace buffer. + cudnnConvolutionFwdAlgo_t forward_best_algo; + CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm( + context(), + descriptor(data), + (const cudnnFilterDescriptor_t)filter_handle, + (const cudnnConvolutionDescriptor_t)conv_handle, + descriptor(dest_desc), + dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, + std::numeric_limits::max(), + &forward_best_algo)); + forward_algo = forward_best_algo; + CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize( + context(), + descriptor(data), + (const cudnnFilterDescriptor_t)filter_handle, + (const cudnnConvolutionDescriptor_t)conv_handle, + descriptor(dest_desc), + forward_best_algo, + &forward_workspace_size_in_bytes)); + + // Pick which backward data algorithm we will use and allocate the + // necessary workspace buffer. + cudnnConvolutionBwdDataAlgo_t backward_data_best_algo; + CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm( + context(), + (const cudnnFilterDescriptor_t)filter_handle, + descriptor(dest_desc), + (const cudnnConvolutionDescriptor_t)conv_handle, + descriptor(data), + dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE, + std::numeric_limits::max(), + &backward_data_best_algo)); + backward_data_algo = backward_data_best_algo; + + CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize( + context(), + (const cudnnFilterDescriptor_t)filter_handle, + descriptor(dest_desc), + (const cudnnConvolutionDescriptor_t)conv_handle, + descriptor(data), + backward_data_best_algo, + &backward_data_workspace_size_in_bytes)); + + // Pick which backward filters algorithm we will use and allocate the + // necessary workspace buffer. + cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo; + CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm( + context(), + descriptor(data), + descriptor(dest_desc), + (const cudnnConvolutionDescriptor_t)conv_handle, + (const cudnnFilterDescriptor_t)filter_handle, + dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE, + std::numeric_limits::max(), + &backward_filters_best_algo)); + // cuDNN 5.1 has a bug that causes + // cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd + // algorithm even for cases where cuDNN doesn't support it, leading to + // incorrect outputs. So here we check if we are in a case where winograd + // isn't supported and manually overrule + // cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe + // algorithm. + if (dnn_prefer_fastest_algorithms() && + !(stride_x == 1 && stride_y == 1 && ((filters_nr==3&&filters_nc==3) || (filters_nr==5&&filters_nc==5))) + ) + { + backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; + } + backward_filters_algo = backward_filters_best_algo; + + CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize( + context(), + descriptor(data), + descriptor(dest_desc), + (const cudnnConvolutionDescriptor_t)conv_handle, + (const cudnnFilterDescriptor_t)filter_handle, + backward_filters_best_algo, + &backward_filters_workspace_size_in_bytes)); + + workspace = device_global_buffer(); + } + catch(...) + { + clear(); + throw; + } + } + + tensor_conv:: + ~tensor_conv ( + ) + { + clear(); + } + + void tensor_conv::operator() ( + const bool add_to_output, + resizable_tensor& output, + const tensor& data, + const tensor& filters + ) + { + DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function"); + + output.set_size(out_num_samples, out_k, out_nr, out_nc); + (*this)(add_to_output, static_cast(output), data, filters); + } + + void tensor_conv::operator() ( + const bool add_to_output, + tensor& output, + const tensor& data, + const tensor& filters + ) + { + DLIB_CASSERT(is_same_object(output,data) == false); + DLIB_CASSERT(is_same_object(output,filters) == false); + DLIB_CASSERT(filters.k() == data.k()); + DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function"); + DLIB_CASSERT(filters.nc() <= data.nc() + 2*padding_x, + "Filter windows must be small enough to fit into the padded image." + << "\n\t filters.nc(): " << filters.nc() + << "\n\t data.nc(): " << data.nc() + << "\n\t padding_x: " << padding_x + ); + DLIB_CASSERT(filters.nr() <= data.nr() + 2*padding_y, + "Filter windows must be small enough to fit into the padded image." + << "\n\t filters.nr(): " << filters.nr() + << "\n\t data.nr(): " << data.nr() + << "\n\t padding_y: " << padding_y + ); + + + DLIB_CASSERT(output.num_samples() == data.num_samples(),out_num_samples << " " << data.num_samples()); + DLIB_CASSERT(output.k() == filters.num_samples()); + DLIB_CASSERT(output.nr() == 1+(data.nr()+2*padding_y-filters.nr())/stride_y); + DLIB_CASSERT(output.nc() == 1+(data.nc()+2*padding_x-filters.nc())/stride_x); + + + + const float alpha = 1; + const float beta = add_to_output ? 1 : 0; + + // Since cudnnConvolutionForward() is an asynchronous call, we need to hold a + // reference to the workspace buffer so we can be sure it isn't reallocated + // while the function is still executing on the device. But each time we come + // here, we make sure to grab the latest workspace buffer so that, globally, we + // minimize the number of such buffers. + forward_workspace = workspace->get(forward_workspace_size_in_bytes); + + CHECK_CUDNN(cudnnConvolutionForward( + context(), + &alpha, + descriptor(data), + data.device(), + (const cudnnFilterDescriptor_t)filter_handle, + filters.device(), + (const cudnnConvolutionDescriptor_t)conv_handle, + (cudnnConvolutionFwdAlgo_t)forward_algo, + forward_workspace, + forward_workspace_size_in_bytes, + &beta, + descriptor(output), + output.device())); + } + + void tensor_conv::get_gradient_for_data ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& filters, + tensor& data_gradient + ) + { + const float alpha = 1; + const float beta = add_to_output ? 1 : 0; + + // Since cudnnConvolutionBackwardData() is an asynchronous call, we need to hold a + // reference to the workspace buffer so we can be sure it isn't reallocated + // while the function is still executing on the device. But each time we come + // here, we make sure to grab the latest workspace buffer so that, globally, we + // minimize the number of such buffers. + backward_data_workspace = workspace->get(backward_data_workspace_size_in_bytes); + + + CHECK_CUDNN(cudnnConvolutionBackwardData(context(), + &alpha, + (const cudnnFilterDescriptor_t)filter_handle, + filters.device(), + descriptor(gradient_input), + gradient_input.device(), + (const cudnnConvolutionDescriptor_t)conv_handle, + (cudnnConvolutionBwdDataAlgo_t)backward_data_algo, + backward_data_workspace, + backward_data_workspace_size_in_bytes, + &beta, + descriptor(data_gradient), + data_gradient.device())); + } + + void tensor_conv:: + get_gradient_for_filters ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& data, + tensor& filters_gradient + ) + { + const float alpha = 1; + const float beta = add_to_output ? 1 : 0; + + // Since cudnnConvolutionBackwardFilter() is an asynchronous call, we need to hold a + // reference to the workspace buffer so we can be sure it isn't reallocated + // while the function is still executing on the device. But each time we come + // here, we make sure to grab the latest workspace buffer so that, globally, we + // minimize the number of such buffers. + backward_filters_workspace = workspace->get(backward_filters_workspace_size_in_bytes); + + CHECK_CUDNN(cudnnConvolutionBackwardFilter(context(), + &alpha, + descriptor(data), + data.device(), + descriptor(gradient_input), + gradient_input.device(), + (const cudnnConvolutionDescriptor_t)conv_handle, + (cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo, + backward_filters_workspace, + backward_filters_workspace_size_in_bytes, + &beta, + (const cudnnFilterDescriptor_t)filter_handle, + filters_gradient.device())); + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + pooling::pooling ( + ) : handle(nullptr),window_height(0),window_width(0),stride_y(0),stride_x(0),padding_y(0), padding_x(0) + { + } + + pooling::~pooling( + ) + { + clear(); + } + + void pooling:: + clear( + ) + { + if (handle) + cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle); + handle = nullptr; + window_height = 0; + window_width = 0; + stride_y = 0; + stride_x = 0; + padding_y = 0; + padding_x = 0; + } + + void pooling:: + setup_max_pooling( + int window_height_, + int window_width_, + int stride_y_, + int stride_x_, + int padding_y_, + int padding_x_ + ) + { + setup(window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_MAX); + do_max_pooling = true; + } + + void pooling:: + setup_avg_pooling( + int window_height_, + int window_width_, + int stride_y_, + int stride_x_, + int padding_y_, + int padding_x_ + ) + { + setup(window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING); + do_max_pooling = false; + } + + void pooling:: + setup( + int window_height_, + int window_width_, + int stride_y_, + int stride_x_, + int padding_y_, + int padding_x_, + int pooling_mode + ) + { + DLIB_CASSERT (window_height_ > 0 && window_width_ > 0 && + stride_y_ > 0 && stride_x_ > 0 , + "window_height_: " << window_height_ + << "\t\n window_width_: " << window_width_ + << "\t\n stride_y_: " << stride_y_ + << "\t\n stride_x_: " << stride_x_ ); + DLIB_CASSERT( 0 <= padding_y_ && padding_y_ < window_height_ && + 0 <= padding_x_ && padding_x_ < window_width_, + "window_height_: " << window_height_ + << "\t\n window_width_: " << window_width_ + << "\t\n padding_y_: " << padding_y_ + << "\t\n padding_x_: " << padding_x_ ); + + if (window_height == window_height_ && + window_width == window_width_ && + stride_y == stride_y_ && + stride_x == stride_x_ && + padding_y == padding_y_ && + padding_x == padding_x_ + ) + { + return; + } + + clear(); + try + { + window_height = window_height_; + window_width = window_width_; + stride_x = stride_x_; + stride_y = stride_y_; + padding_y = padding_y_; + padding_x = padding_x_; + cudnnPoolingDescriptor_t poolingDesc; + CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc)); + handle = poolingDesc; + + CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc, + (cudnnPoolingMode_t)pooling_mode, + CUDNN_PROPAGATE_NAN, + window_height, + window_width, + padding_y, + padding_x, + stride_y, + stride_x)); + } + catch(...) + { + clear(); + throw; + } + } + + void pooling:: + operator() ( + resizable_tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(window_width <= src.nc() + 2*padding_x, + "Pooling windows must be small enough to fit into the padded image." + << "\n\t window_width: " << window_width + << "\n\t src.nc(): " << src.nc() + << "\n\t padding_x: " << padding_x + ); + DLIB_CASSERT(window_height <= src.nr() + 2*padding_y, + "Pooling windows must be small enough to fit into the padded image." + << "\n\t window_height: " << window_height + << "\n\t src.nr(): " << src.nr() + << "\n\t padding_y: " << padding_y + ); + const float alpha = 1; + const float beta = 0; + int outN; + int outC; + int outH; + int outW; + CHECK_CUDNN(cudnnGetPooling2dForwardOutputDim((const cudnnPoolingDescriptor_t)handle, + descriptor(src), + &outN, + &outC, + &outH, + &outW)); + + + dest.set_size(outN,outC,outH,outW); + + DLIB_CASSERT(dest.num_samples() == src.num_samples()); + DLIB_CASSERT(dest.k() == src.k()); + DLIB_CASSERT(dest.nr() == 1 + (src.nr() + 2*padding_y - window_height)/stride_y, + "\n stride_y: " << stride_y << + "\n padding_y: " << padding_y << + "\n window_height: " << window_height << + "\n src.nr(): " << src.nr() << + "\n dest.nr(): " << dest.nr() << + "\n src.nr()/stride_y: " << src.nr()/stride_y); + DLIB_CASSERT(dest.nc() == 1 + (src.nc() + 2*padding_x - window_width)/stride_x, + "\n stride_x: " << stride_x << + "\n padding_x: " << padding_x << + "\n window_width: " << window_width << + "\n src.nc(): " << src.nc() << + "\n dest.nc(): " << dest.nc() << + "\n src.nc()/stride_x: " << src.nc()/stride_x); + + CHECK_CUDNN(cudnnPoolingForward(context(), + (const cudnnPoolingDescriptor_t)handle, + &alpha, + descriptor(src), + src.device(), + &beta, + descriptor(dest), + dest.device())); + } + + void pooling::get_gradient( + const tensor& gradient_input, + const tensor& dest, + const tensor& src, + tensor& grad + ) + { + DLIB_CASSERT(have_same_dimensions(gradient_input,dest)); + DLIB_CASSERT(have_same_dimensions(src,grad)); + + const float alpha = 1; + const float beta = 1; + CHECK_CUDNN(cudnnPoolingBackward(context(), + (const cudnnPoolingDescriptor_t)handle, + &alpha, + descriptor(dest), + dest.device(), + descriptor(gradient_input), + gradient_input.device(), + descriptor(src), + src.device(), + &beta, + descriptor(grad), + grad.device())); + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + void softmax ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src)); + if (src.size() == 0) + return; + + const float alpha = 1; + const float beta = 0; + + CHECK_CUDNN(cudnnSoftmaxForward(context(), + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, + descriptor(src), + src.device(), + &beta, + descriptor(dest), + dest.device())); + } + + + void softmax_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + DLIB_CASSERT( + have_same_dimensions(dest,gradient_input) == true && + have_same_dimensions(dest,grad) == true ); + if (dest.size() == 0) + return; + + const float alpha = 1; + const float beta = is_same_object(grad,gradient_input) ? 0 : 1; + CHECK_CUDNN(cudnnSoftmaxBackward(context(), + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, + descriptor(dest), + dest.device(), + descriptor(gradient_input), + gradient_input.device(), + &beta, + descriptor(grad), + grad.device())); + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + void softmax_all ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src)); + if (src.size() == 0) + return; + + const float alpha = 1; + const float beta = 0; + + CHECK_CUDNN(cudnnSoftmaxForward(context(), + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_INSTANCE, + &alpha, + descriptor(src), + src.device(), + &beta, + descriptor(dest), + dest.device())); + } + + + void softmax_all_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + DLIB_CASSERT( + have_same_dimensions(dest,gradient_input) == true && + have_same_dimensions(dest,grad) == true ); + if (dest.size() == 0) + return; + + const float alpha = 1; + const float beta = is_same_object(grad,gradient_input) ? 0 : 1; + CHECK_CUDNN(cudnnSoftmaxBackward(context(), + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_INSTANCE, + &alpha, + descriptor(dest), + dest.device(), + descriptor(gradient_input), + gradient_input.device(), + &beta, + descriptor(grad), + grad.device())); + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + void sigmoid ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src)); + if (src.size() == 0) + return; + + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN(cudnnActivationForward(context(), + sigmoid_activation_descriptor(), + &alpha, + descriptor(src), + src.device(), + &beta, + descriptor(dest), + dest.device())); + } + + void sigmoid_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + DLIB_CASSERT( + have_same_dimensions(dest,gradient_input) == true && + have_same_dimensions(dest,grad) == true ); + if (dest.size() == 0) + return; + + const float alpha = 1; + const float beta = is_same_object(grad,gradient_input) ? 0 : 1; + CHECK_CUDNN(cudnnActivationBackward(context(), + sigmoid_activation_descriptor(), + &alpha, + descriptor(dest), + dest.device(), + descriptor(gradient_input), + gradient_input.device(), + descriptor(dest), + dest.device(), + &beta, + descriptor(grad), + grad.device())); + } + + // ------------------------------------------------------------------------------------ + + void relu ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src)); + if (src.size() == 0) + return; + + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN(cudnnActivationForward(context(), + relu_activation_descriptor(), + &alpha, + descriptor(src), + src.device(), + &beta, + descriptor(dest), + dest.device())); + } + + void relu_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + DLIB_CASSERT( + have_same_dimensions(dest,gradient_input) == true && + have_same_dimensions(dest,grad) == true ); + if (dest.size() == 0) + return; + + const float alpha = 1; + const float beta = is_same_object(grad,gradient_input) ? 0 : 1; + CHECK_CUDNN(cudnnActivationBackward(context(), + relu_activation_descriptor(), + &alpha, + descriptor(dest), + dest.device(), + descriptor(gradient_input), + gradient_input.device(), + descriptor(dest), + dest.device(), + &beta, + descriptor(grad), + grad.device())); + } + + // ------------------------------------------------------------------------------------ + + void tanh ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(have_same_dimensions(dest,src)); + if (src.size() == 0) + return; + + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN(cudnnActivationForward(context(), + tanh_activation_descriptor(), + &alpha, + descriptor(src), + src.device(), + &beta, + descriptor(dest), + dest.device())); + } + + void tanh_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { + DLIB_CASSERT( + have_same_dimensions(dest,gradient_input) == true && + have_same_dimensions(dest,grad) == true); + if (dest.size() == 0) + return; + + const float alpha = 1; + const float beta = is_same_object(grad,gradient_input) ? 0 : 1; + CHECK_CUDNN(cudnnActivationBackward(context(), + tanh_activation_descriptor(), + &alpha, + descriptor(dest), + dest.device(), + descriptor(gradient_input), + gradient_input.device(), + descriptor(dest), + dest.device(), + &beta, + descriptor(grad), + grad.device())); + } + + // ------------------------------------------------------------------------------------ + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuDNN_CPP_ + + diff --git a/ml/dlib/dlib/dnn/cudnn_dlibapi.h b/ml/dlib/dlib/dnn/cudnn_dlibapi.h new file mode 100644 index 000000000..e9ffe5f6d --- /dev/null +++ b/ml/dlib/dlib/dnn/cudnn_dlibapi.h @@ -0,0 +1,518 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuDNN_H_ +#define DLIB_DNN_CuDNN_H_ + +#ifdef DLIB_USE_CUDA + +#include "cuda_errors.h" +#include +#include "cuda_data_ptr.h" + +namespace dlib +{ + class tensor; + class resizable_tensor; + + namespace cuda + { + + // ----------------------------------------------------------------------------------- + + class tensor_descriptor + { + /*! + Each tensor object will carry a tensor_descriptor in it when compiled with + CUDA. + !*/ + + public: + // not copyable + tensor_descriptor(const tensor_descriptor&) = delete; + tensor_descriptor& operator=(const tensor_descriptor&) = delete; + // but is movable + tensor_descriptor(tensor_descriptor&& item) : tensor_descriptor() { swap(item); } + tensor_descriptor& operator=(tensor_descriptor&& item) { swap(item); return *this; } + + tensor_descriptor(); + ~tensor_descriptor(); + + void set_size( + int n, + int k, + int nr, + int nc + ); + /*! + ensures + - if any of the arguments are 0 then they are all set to 0 in the tensor. + !*/ + + void get_size ( + int& n, + int& k, + int& nr, + int& nc + ) const; + + const void* get_handle ( + ) const { return handle; } + + private: + + void swap(tensor_descriptor& item) { std::swap(handle, item.handle); } + + void* handle; + }; + + // ------------------------------------------------------------------------------------ + + void add( + float beta, + tensor& dest, + float alpha, + const tensor& src + ); + /*! + requires + - One of the following is true: + - have_same_dimensions(src, dest) + - src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1 + - src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc() + - src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc() + - is_same_object(src,dest) == false + ensures + - performs: dest = beta*dest + alpha*src + However, how the addition happens depends on the dimensions of src. In + particular, this function adds the scaled values of one src tensor to + dest. Each dimension of the src tensor must match the corresponding + dimension of the dest tensor or must be equal to 1. In the latter case, + the same value from the src tensor, for those dimensions, will be used to + add into the dest tensor. + !*/ + + // ------------------------------------------------------------------------------------ + + void assign_conv_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ); + /*! + requires + - grad.num_samples() == 1 + - grad.k() >= 1 + - grad.nr() == 1 + - grad.nc() == 1 + - gradient_input.k() == grad.k() + - gradient_input.size() > 0 + - is_same_object(grad,gradient_input) == false + ensures + - let BIAS be a tensor with all dimensions equal to 1 except for k which is >= 1. + - let OUT be the output of add(1,OUT,1,BIAS) + - let f(gradient_input,BIAS) == dot(gradient_input,OUT) + - Then this function computes the gradient of f() with respect to BIAS and + assigns it to grad. + !*/ + + // ------------------------------------------------------------------------------------ + + void batch_normalize_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ); + + void batch_normalize ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ); + + void batch_normalize_gradient( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ); + + // ------------------------------------------------------------------------------------ + + void batch_normalize_conv_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ); + + void batch_normalize_conv ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ); + + void batch_normalize_conv_gradient( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ); + + // ------------------------------------------------------------------------------------ + + class tensor_conv + { + public: + tensor_conv(const tensor_conv&) = delete; + tensor_conv& operator=(const tensor_conv&) = delete; + + tensor_conv(); + + void clear( + ); + + ~tensor_conv ( + ); + + void operator() ( + const bool add_to_output, + tensor& output, + const tensor& data, + const tensor& filters + ); + + void operator() ( + const bool add_to_output, + resizable_tensor& output, + const tensor& data, + const tensor& filters + ); + + void get_gradient_for_data ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& filters, + tensor& data_gradient + ); + + void get_gradient_for_filters ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& data, + tensor& filters_gradient + ); + + void setup( + const tensor& data, + const tensor& filters, + int stride_y, + int stride_x, + int padding_y, + int padding_x + ); + + private: + + // These variables record the type of data given to the last call to setup(). + int stride_y; + int stride_x; + int padding_y; + int padding_x; + long data_num_samples, data_k, data_nr, data_nc; + long filters_num_samples, filters_k, filters_nr, filters_nc; + + + void* filter_handle; + void* conv_handle; + + // dimensions of the output tensor from operator() + int out_num_samples; + int out_k; + int out_nr; + int out_nc; + + int forward_algo; + int backward_data_algo; + int backward_filters_algo; + + size_t forward_workspace_size_in_bytes; + size_t backward_data_workspace_size_in_bytes; + size_t backward_filters_workspace_size_in_bytes; + std::shared_ptr workspace; + cuda_data_void_ptr forward_workspace; + cuda_data_void_ptr backward_data_workspace; + cuda_data_void_ptr backward_filters_workspace; + }; + + // ------------------------------------------------------------------------------------ + + class pooling + { + public: + + pooling(const pooling&) = delete; + pooling& operator=(const pooling&) = delete; + + pooling ( + ); + + ~pooling( + ); + + void clear( + ); + + void setup_max_pooling( + int window_height, + int window_width, + int stride_y, + int stride_x, + int padding_y, + int padding_x + ); + + void setup_avg_pooling( + int window_height, + int window_width, + int stride_y, + int stride_x, + int padding_y, + int padding_x + ); + + bool does_max_pooling( + ) const { return do_max_pooling; } + + void operator() ( + resizable_tensor& dest, + const tensor& src + ); + + void get_gradient( + const tensor& gradient_input, + const tensor& dest, + const tensor& src, + tensor& grad + ); + + private: + + void setup( + int window_height, + int window_width, + int stride_y, + int stride_x, + int padding_y, + int padding_x, + int pooling_mode + ); + + void* handle; + int window_height; + int window_width; + int stride_y; + int stride_x; + int padding_y; + int padding_x; + bool do_max_pooling; + }; + + // ------------------------------------------------------------------------------------ + + void softmax ( + tensor& dest, + const tensor& src + ); + /*! + requires + - have_same_dimensions(dest, src) == true + ensures + - Note that the softmax function is a vector valued function: + s(x) == exp(x)/sum(exp(x)) + - Computes the softmax function on src and writes the results to dest. The + softmax is computed per spatial location across the different channels at + each location. That is, softmax() outputs a new tensor, #dest, where + each of the spatial locations in dest (i.e. image idx, row idx, and + column idx) contains the output of s() evaluated over the channel values + at each location. + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void softmax_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + /*! + requires + - have_same_dimensions(dest,gradient_input) == true + - have_same_dimensions(dest,grad) == true + - is_same_object(grad, dest)==false + ensures + - We interpret dest as the output of softmax(dest,SRC) for some SRC tensor. + Then let f(SRC) == dot(gradient_input,dest) Then this function computes + the gradient of f() with respect to SRC and assigns it to grad. + - This function supports in-place operation, i.e. having + is_same_object(grad, gradient_input)==true + !*/ + + // ------------------------------------------------------------------------------------ + + void softmax_all ( + tensor& dest, + const tensor& src + ); + + void softmax_all_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + + // ------------------------------------------------------------------------------------ + + void sigmoid ( + tensor& dest, + const tensor& src + ); + /*! + requires + - have_same_dimensions(dest, src) == true + ensures + - for all valid i: + - #dest.host()[i] == 1/(1+std::exp(-src.host()[i])) + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void sigmoid_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + /*! + requires + - have_same_dimensions(dest,gradient_input) == true + - have_same_dimensions(dest,grad) == true + - is_same_object(grad,dest) == false + ensures + - Recalling that dest is the output of sigmoid(dest,SRC) for some SRC tensor, + let f(SRC) == dot(gradient_input,dest) + - Then this function computes the gradient of f() with respect to SRC and + assigns it to grad. + - This function supports in-place operation, i.e. having + is_same_object(grad, gradient_input)==true + !*/ + + // ------------------------------------------------------------------------------------ + + void relu ( + tensor& dest, + const tensor& src + ); + /*! + requires + - have_same_dimensions(dest, src) == true + ensures + - for all valid i: + - #dest.host()[i] == std::max(0,src.host()[i]) + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void relu_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + /*! + requires + - have_same_dimensions(dest,gradient_input) == true + - have_same_dimensions(dest,grad) == true + - is_same_object(grad,dest) == false + ensures + - Recalling that dest is the output of relu(dest,SRC) for some SRC tensor, + let f(SRC) == dot(gradient_input,dest) + - Then this function computes the gradient of f() with respect to SRC and + assigns it to grad. + - This function supports in-place operation, i.e. having + is_same_object(grad, gradient_input)==true + !*/ + + // ------------------------------------------------------------------------------------ + + void tanh ( + tensor& dest, + const tensor& src + ); + /*! + requires + - have_same_dimensions(dest, src) == true + ensures + - for all valid i: + - #dest.host()[i] == std::tanh(src.host()[i]) + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void tanh_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + /*! + requires + - have_same_dimensions(dest,gradient_input) == true + - have_same_dimensions(dest,grad) == true + - is_same_object(grad,dest) == false + ensures + - Recalling that dest is the output of tanh(dest,SRC) for some SRC tensor, + let f(SRC) == dot(gradient_input,dest) + - Then this function computes the gradient of f() with respect to SRC and + assigns it to grad. + - This function supports in-place operation, i.e. having + is_same_object(grad, gradient_input)==true + !*/ + + + + // ------------------------------------------------------------------------------------ + + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuDNN_H_ + diff --git a/ml/dlib/dlib/dnn/curand_dlibapi.cpp b/ml/dlib/dlib/dnn/curand_dlibapi.cpp new file mode 100644 index 000000000..67828e664 --- /dev/null +++ b/ml/dlib/dlib/dnn/curand_dlibapi.cpp @@ -0,0 +1,113 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuRAND_CPP_ +#define DLIB_DNN_CuRAND_CPP_ + +#ifdef DLIB_USE_CUDA + +#include "curand_dlibapi.h" +#include +#include "../string.h" + +static const char* curand_get_error_string(curandStatus_t s) +{ + switch(s) + { + case CURAND_STATUS_NOT_INITIALIZED: + return "CUDA Runtime API initialization failed."; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "The requested length must be a multiple of two."; + default: + return "A call to cuRAND failed"; + } +} + +// Check the return value of a call to the cuDNN runtime for an error condition. +#define CHECK_CURAND(call) \ +do{ \ + const curandStatus_t error = call; \ + if (error != CURAND_STATUS_SUCCESS) \ + { \ + std::ostringstream sout; \ + sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\ + sout << "code: " << error << ", reason: " << curand_get_error_string(error);\ + throw dlib::curand_error(sout.str()); \ + } \ +}while(false) + +namespace dlib +{ + namespace cuda + { + + // ---------------------------------------------------------------------------------------- + + curand_generator:: + curand_generator( + unsigned long long seed + ) : handle(nullptr) + { + curandGenerator_t gen; + CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT)); + handle = gen; + + CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, seed)); + } + + curand_generator:: + ~curand_generator() + { + if (handle) + { + curandDestroyGenerator((curandGenerator_t)handle); + } + } + + void curand_generator:: + fill_gaussian ( + tensor& data, + float mean, + float stddev + ) + { + if (data.size() == 0) + return; + + CHECK_CURAND(curandGenerateNormal((curandGenerator_t)handle, + data.device(), + data.size(), + mean, + stddev)); + } + + void curand_generator:: + fill_uniform ( + tensor& data + ) + { + if (data.size() == 0) + return; + + CHECK_CURAND(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size())); + } + + void curand_generator:: + fill ( + cuda_data_ptr& data + ) + { + if (data.size() == 0) + return; + + CHECK_CURAND(curandGenerate((curandGenerator_t)handle, data, data.size())); + } + + // ----------------------------------------------------------------------------------- + + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuRAND_CPP_ + diff --git a/ml/dlib/dlib/dnn/curand_dlibapi.h b/ml/dlib/dlib/dnn/curand_dlibapi.h new file mode 100644 index 000000000..cd51fecee --- /dev/null +++ b/ml/dlib/dlib/dnn/curand_dlibapi.h @@ -0,0 +1,75 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuRAND_H_ +#define DLIB_DNN_CuRAND_H_ + +#ifdef DLIB_USE_CUDA + +#include "tensor.h" +#include "cuda_errors.h" +#include "cuda_data_ptr.h" + +namespace dlib +{ + namespace cuda + { + + // ----------------------------------------------------------------------------------- + + class curand_generator + { + public: + // not copyable + curand_generator(const curand_generator&) = delete; + curand_generator& operator=(const curand_generator&) = delete; + + curand_generator() : curand_generator(0) {} + curand_generator(unsigned long long seed); + ~curand_generator(); + + void fill ( + cuda_data_ptr& data + ); + /*! + ensures + - Fills data with random 32-bit unsigned integers. + !*/ + + void fill_gaussian ( + tensor& data, + float mean = 0, + float stddev = 1 + ); + /*! + requires + - data.size()%2 == 0 + - stddev >= 0 + ensures + - Fills data with random numbers drawn from a Gaussian distribution + with the given mean and standard deviation. + !*/ + + void fill_uniform ( + tensor& data + ); + /*! + ensures + - Fills data with uniform random numbers in the range (0.0, 1.0]. + !*/ + + private: + + void* handle; + }; + + // ----------------------------------------------------------------------------------- + + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuRAND_H_ + + + diff --git a/ml/dlib/dlib/dnn/cusolver_dlibapi.cu b/ml/dlib/dlib/dnn/cusolver_dlibapi.cu new file mode 100644 index 000000000..942613134 --- /dev/null +++ b/ml/dlib/dlib/dnn/cusolver_dlibapi.cu @@ -0,0 +1,204 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuSOLVER_CU_ +#define DLIB_DNN_CuSOLVER_CU_ + +#ifdef DLIB_USE_CUDA + +#include "cusolver_dlibapi.h" +#include +#include +#include "cuda_utils.h" + +// ---------------------------------------------------------------------------------------- + +static const char* cusolver_get_error_string(cusolverStatus_t s) +{ + switch(s) + { + case CUSOLVER_STATUS_NOT_INITIALIZED: + return "CUDA Runtime API initialization failed."; + case CUSOLVER_STATUS_ALLOC_FAILED: + return "CUDA Resources could not be allocated."; + default: + return "A call to cuSolver failed"; + } +} + +// Check the return value of a call to the cuSolver runtime for an error condition. +#define CHECK_CUSOLVER(call) \ +do{ \ + const cusolverStatus_t error = call; \ + if (error != CUSOLVER_STATUS_SUCCESS) \ + { \ + std::ostringstream sout; \ + sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\ + sout << "code: " << error << ", reason: " << cusolver_get_error_string(error);\ + throw dlib::cusolver_error(sout.str()); \ + } \ +}while(false) + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + namespace cuda + { + + // ----------------------------------------------------------------------------------- + + class cusolver_context + { + public: + // not copyable + cusolver_context(const cusolver_context&) = delete; + cusolver_context& operator=(const cusolver_context&) = delete; + + cusolver_context() + { + handles.resize(16); + } + ~cusolver_context() + { + for (auto h : handles) + { + if (h) + cusolverDnDestroy(h); + } + } + + cusolverDnHandle_t get_handle ( + ) + { + int new_device_id; + CHECK_CUDA(cudaGetDevice(&new_device_id)); + // make room for more devices if needed + if (new_device_id >= (long)handles.size()) + handles.resize(new_device_id+16); + + // If we don't have a handle already for this device then make one + if (!handles[new_device_id]) + CHECK_CUSOLVER(cusolverDnCreate(&handles[new_device_id])); + + // Finally, return the handle for the current device + return handles[new_device_id]; + } + + private: + + std::vector handles; + }; + + static cusolverDnHandle_t context() + { + thread_local cusolver_context c; + return c.get_handle(); + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + __global__ void _cuda_set_to_identity_matrix(float* m, size_t nr) + { + for (auto j : grid_stride_range(0, nr*nr)) + { + if (j%(nr+1) == 0) + m[j] = 1; + else + m[j] = 0; + } + } + + void set_to_identity_matrix ( + tensor& m + ) + { + DLIB_CASSERT(m.size() == m.num_samples()*m.num_samples()); + launch_kernel(_cuda_set_to_identity_matrix, max_jobs(m.size()), m.device(), m.num_samples()); + } + + // ------------------------------------------------------------------------------------ + + inv::~inv() + { + sync_if_needed(); + } + + // ------------------------------------------------------------------------------------ + + void inv:: + operator() ( + const tensor& m_, + resizable_tensor& out + ) + { + DLIB_CASSERT(m_.size() == m_.num_samples()*m_.num_samples(), "Input matrix must be square if you want to invert it."); + m = m_; + + out.copy_size(m); + set_to_identity_matrix(out); + + const int nc = m.num_samples(); + int Lwork; + CHECK_CUSOLVER(cusolverDnSgetrf_bufferSize(context(), nc , nc, m.device(), nc, &Lwork)); + + if (Lwork > (int)workspace.size()) + { + sync_if_needed(); + workspace = cuda_data_ptr(Lwork); + } + if (nc > (int)Ipiv.size()) + { + sync_if_needed(); + Ipiv = cuda_data_ptr(nc); + } + if (info.size() != 1) + { + info = cuda_data_ptr(1); + } + + CHECK_CUSOLVER(cusolverDnSgetrf(context(), nc, nc, m.device(), nc, workspace, Ipiv, info)); + CHECK_CUSOLVER(cusolverDnSgetrs(context(), CUBLAS_OP_N, nc, nc, m.device(), nc, Ipiv, out.device(), nc, info)); + did_work_lately = true; + } + + // ------------------------------------------------------------------------------------ + + int inv:: + get_last_status( + ) + { + std::vector linfo; + memcpy(linfo, info); + if (linfo.size() != 0) + return linfo[0]; + else + return 0; + } + + // ------------------------------------------------------------------------------------ + + void inv:: + sync_if_needed() + { + if (did_work_lately) + { + did_work_lately = false; + // make sure we wait until any previous kernel launches have finished + // before we do something like deallocate the GPU memory. + cudaDeviceSynchronize(); + } + } + + // ------------------------------------------------------------------------------------ + + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuSOLVER_CU_ + + diff --git a/ml/dlib/dlib/dnn/cusolver_dlibapi.h b/ml/dlib/dlib/dnn/cusolver_dlibapi.h new file mode 100644 index 000000000..e5c77c151 --- /dev/null +++ b/ml/dlib/dlib/dnn/cusolver_dlibapi.h @@ -0,0 +1,75 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_CuSOLVER_H_ +#define DLIB_DNN_CuSOLVER_H_ + +#ifdef DLIB_USE_CUDA + +#include "tensor.h" +#include "cuda_errors.h" +#include "cuda_data_ptr.h" +#include "../noncopyable.h" + +namespace dlib +{ + namespace cuda + { + + // ----------------------------------------------------------------------------------- + + class inv : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a functor for doing matrix inversion on the GPU. The only + reason it's an object is to avoid the reallocation of some GPU memory + blocks if you want to do a bunch of matrix inversions in a row. + !*/ + + public: + + inv() = default; + ~inv(); + + void operator() ( + const tensor& m, + resizable_tensor& out + ); + /*! + requires + - m.size() == m.num_samples()*m.num_samples() + (i.e. mat(m) must be a square matrix) + ensures + - out == inv(mat(m)); + !*/ + + int get_last_status( + ); + /*! + ensures + - returns 0 if the last matrix inversion was successful and != 0 + otherwise. + !*/ + + private: + + void sync_if_needed(); + + bool did_work_lately = false; + resizable_tensor m; + cuda_data_ptr workspace; + cuda_data_ptr Ipiv; + cuda_data_ptr info; + }; + + // ------------------------------------------------------------------------------------ + + } +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_DNN_CuSOLVER_H_ + + + diff --git a/ml/dlib/dlib/dnn/gpu_data.cpp b/ml/dlib/dlib/dnn/gpu_data.cpp new file mode 100644 index 000000000..6e7cec6be --- /dev/null +++ b/ml/dlib/dlib/dnn/gpu_data.cpp @@ -0,0 +1,228 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GPU_DaTA_CPP_ +#define DLIB_GPU_DaTA_CPP_ + +// Only things that require CUDA are declared in this cpp file. Everything else is in the +// gpu_data.h header so that it can operate as "header-only" code when using just the CPU. +#ifdef DLIB_USE_CUDA + +#include "gpu_data.h" +#include +#include "cuda_utils.h" +#include + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + void memcpy ( + gpu_data& dest, + const gpu_data& src + ) + { + DLIB_CASSERT(dest.size() == src.size()); + if (src.size() == 0 || &dest == &src) + return; + + memcpy(dest,0, src, 0, src.size()); + } + + void memcpy ( + gpu_data& dest, + size_t dest_offset, + const gpu_data& src, + size_t src_offset, + size_t num + ) + { + DLIB_CASSERT(dest_offset + num <= dest.size()); + DLIB_CASSERT(src_offset + num <= src.size()); + if (num == 0) + return; + + // if there is aliasing + if (&dest == &src && std::max(dest_offset, src_offset) < std::min(dest_offset,src_offset)+num) + { + // if they perfectly alias each other then there is nothing to do + if (dest_offset == src_offset) + return; + else + std::memmove(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num); + } + else + { + // if we write to the entire thing then we can use device_write_only() + if (dest_offset == 0 && num == dest.size()) + { + // copy the memory efficiently based on which copy is current in each object. + if (src.device_ready()) + CHECK_CUDA(cudaMemcpy(dest.device_write_only(), src.device()+src_offset, num*sizeof(float), cudaMemcpyDeviceToDevice)); + else + CHECK_CUDA(cudaMemcpy(dest.device_write_only(), src.host()+src_offset, num*sizeof(float), cudaMemcpyHostToDevice)); + } + else + { + // copy the memory efficiently based on which copy is current in each object. + if (dest.device_ready() && src.device_ready()) + CHECK_CUDA(cudaMemcpy(dest.device()+dest_offset, src.device()+src_offset, num*sizeof(float), cudaMemcpyDeviceToDevice)); + else if (!dest.device_ready() && src.device_ready()) + CHECK_CUDA(cudaMemcpy(dest.host()+dest_offset, src.device()+src_offset, num*sizeof(float), cudaMemcpyDeviceToHost)); + else if (dest.device_ready() && !src.device_ready()) + CHECK_CUDA(cudaMemcpy(dest.device()+dest_offset, src.host()+src_offset, num*sizeof(float), cudaMemcpyHostToDevice)); + else + CHECK_CUDA(cudaMemcpy(dest.host()+dest_offset, src.host()+src_offset, num*sizeof(float), cudaMemcpyHostToHost)); + } + } + } +// ---------------------------------------------------------------------------------------- + + void gpu_data:: + wait_for_transfer_to_finish() const + { + if (have_active_transfer) + { + CHECK_CUDA(cudaStreamSynchronize((cudaStream_t)cuda_stream.get())); + have_active_transfer = false; + // Check for errors. These calls to cudaGetLastError() are what help us find + // out if our kernel launches have been failing. + CHECK_CUDA(cudaGetLastError()); + } + } + + void gpu_data:: + copy_to_device() const + { + // We want transfers to the device to always be concurrent with any device + // computation. So we use our non-default stream to do the transfer. + async_copy_to_device(); + wait_for_transfer_to_finish(); + } + + void gpu_data:: + copy_to_host() const + { + if (!host_current) + { + wait_for_transfer_to_finish(); + CHECK_CUDA(cudaMemcpy(data_host.get(), data_device.get(), data_size*sizeof(float), cudaMemcpyDeviceToHost)); + host_current = true; + // At this point we know our RAM block isn't in use because cudaMemcpy() + // implicitly syncs with the device. + device_in_use = false; + // Check for errors. These calls to cudaGetLastError() are what help us find + // out if our kernel launches have been failing. + CHECK_CUDA(cudaGetLastError()); + } + } + + void gpu_data:: + async_copy_to_device() const + { + if (!device_current) + { + if (device_in_use) + { + // Wait for any possible CUDA kernels that might be using our memory block to + // complete before we overwrite the memory. + CHECK_CUDA(cudaStreamSynchronize(0)); + device_in_use = false; + } + CHECK_CUDA(cudaMemcpyAsync(data_device.get(), data_host.get(), data_size*sizeof(float), cudaMemcpyHostToDevice, (cudaStream_t)cuda_stream.get())); + have_active_transfer = true; + device_current = true; + } + } + + void gpu_data:: + set_size( + size_t new_size + ) + { + if (new_size == 0) + { + if (device_in_use) + { + // Wait for any possible CUDA kernels that might be using our memory block to + // complete before we free the memory. + CHECK_CUDA(cudaStreamSynchronize(0)); + device_in_use = false; + } + wait_for_transfer_to_finish(); + data_size = 0; + host_current = true; + device_current = true; + device_in_use = false; + data_host.reset(); + data_device.reset(); + } + else if (new_size != data_size) + { + if (device_in_use) + { + // Wait for any possible CUDA kernels that might be using our memory block to + // complete before we free the memory. + CHECK_CUDA(cudaStreamSynchronize(0)); + device_in_use = false; + } + wait_for_transfer_to_finish(); + data_size = new_size; + host_current = true; + device_current = true; + device_in_use = false; + + try + { + CHECK_CUDA(cudaGetDevice(&the_device_id)); + + // free memory blocks before we allocate new ones. + data_host.reset(); + data_device.reset(); + + void* data; + CHECK_CUDA(cudaMallocHost(&data, new_size*sizeof(float))); + // Note that we don't throw exceptions since the free calls are invariably + // called in destructors. They also shouldn't fail anyway unless someone + // is resetting the GPU card in the middle of their program. + data_host.reset((float*)data, [](float* ptr){ + auto err = cudaFreeHost(ptr); + if(err!=cudaSuccess) + std::cerr << "cudaFreeHost() failed. Reason: " << cudaGetErrorString(err) << std::endl; + }); + + CHECK_CUDA(cudaMalloc(&data, new_size*sizeof(float))); + data_device.reset((float*)data, [](float* ptr){ + auto err = cudaFree(ptr); + if(err!=cudaSuccess) + std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(err) << std::endl; + }); + + if (!cuda_stream) + { + cudaStream_t cstream; + CHECK_CUDA(cudaStreamCreateWithFlags(&cstream, cudaStreamNonBlocking)); + cuda_stream.reset(cstream, [](void* ptr){ + auto err = cudaStreamDestroy((cudaStream_t)ptr); + if(err!=cudaSuccess) + std::cerr << "cudaStreamDestroy() failed. Reason: " << cudaGetErrorString(err) << std::endl; + }); + } + + } + catch(...) + { + set_size(0); + throw; + } + } + } + +// ---------------------------------------------------------------------------------------- +} + +#endif // DLIB_USE_CUDA + +#endif // DLIB_GPU_DaTA_CPP_ + diff --git a/ml/dlib/dlib/dnn/gpu_data.h b/ml/dlib/dlib/dnn/gpu_data.h new file mode 100644 index 000000000..022a05f71 --- /dev/null +++ b/ml/dlib/dlib/dnn/gpu_data.h @@ -0,0 +1,266 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GPU_DaTA_H_ +#define DLIB_GPU_DaTA_H_ + +#include "gpu_data_abstract.h" +#include +#include +#include "cuda_errors.h" +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class gpu_data + { + /*! + CONVENTION + - if (size() != 0) then + - data_host == a pointer to size() floats in CPU memory. + - if (data_device) then + - data_device == a pointer to size() floats in device memory. + + - if (there might be an active async transfer from host to device) then + - have_active_transfer == true + + - We use the host_current and device_current bools to keep track of which + copy of the data (or both) are most current. e.g. if the CPU has + modified the data and it hasn't been copied to the device yet then + host_current==true and device_current==false. + + Similarly, we use device_in_use==true to indicate that device() has been + called and no operation to wait for all CUDA kernel completion has been + executed. So if device_in_use==true then there might be a CUDA kernel + executing that is using the device memory block contained in this object. + + !*/ + public: + + gpu_data( + ) : data_size(0), host_current(true), device_current(true),have_active_transfer(false),device_in_use(false), the_device_id(0) + { + } + + // Not copyable + gpu_data(const gpu_data&) = delete; + gpu_data& operator=(const gpu_data&) = delete; + + // but is movable + gpu_data(gpu_data&& item) : gpu_data() { swap(item); } + gpu_data& operator=(gpu_data&& item) { swap(item); return *this; } + + int device_id() const { return the_device_id; } + +#ifdef DLIB_USE_CUDA + void async_copy_to_device() const; + void set_size(size_t new_size); +#else + // Note that calls to host() or device() will block until any async transfers are complete. + void async_copy_to_device() const{} + + void set_size(size_t new_size) + { + if (new_size == 0) + { + data_size = 0; + host_current = true; + device_current = true; + device_in_use = false; + data_host.reset(); + data_device.reset(); + } + else if (new_size != data_size) + { + data_size = new_size; + host_current = true; + device_current = true; + device_in_use = false; + data_host.reset(new float[new_size], std::default_delete()); + data_device.reset(); + } + } +#endif + + const float* host() const + { + copy_to_host(); + return data_host.get(); + } + + float* host() + { + copy_to_host(); + device_current = false; + return data_host.get(); + } + + float* host_write_only() + { + host_current = true; + device_current = false; + return data_host.get(); + } + + const float* device() const + { +#ifndef DLIB_USE_CUDA + DLIB_CASSERT(false, "CUDA NOT ENABLED"); +#endif + copy_to_device(); + device_in_use = true; + return data_device.get(); + } + + float* device() + { +#ifndef DLIB_USE_CUDA + DLIB_CASSERT(false, "CUDA NOT ENABLED"); +#endif + copy_to_device(); + host_current = false; + device_in_use = true; + return data_device.get(); + } + + float* device_write_only() + { +#ifndef DLIB_USE_CUDA + DLIB_CASSERT(false, "CUDA NOT ENABLED"); +#endif + wait_for_transfer_to_finish(); + host_current = false; + device_current = true; + device_in_use = true; + return data_device.get(); + } + + bool host_ready ( + ) const { return host_current; } + + bool device_ready ( + ) const { return device_current && !have_active_transfer; } + + size_t size() const { return data_size; } + + void swap (gpu_data& item) + { + std::swap(data_size, item.data_size); + std::swap(host_current, item.host_current); + std::swap(device_current, item.device_current); + std::swap(have_active_transfer, item.have_active_transfer); + std::swap(data_host, item.data_host); + std::swap(data_device, item.data_device); + std::swap(cuda_stream, item.cuda_stream); + std::swap(the_device_id, item.the_device_id); + } + + private: + +#ifdef DLIB_USE_CUDA + void copy_to_device() const; + void copy_to_host() const; + void wait_for_transfer_to_finish() const; +#else + void copy_to_device() const{} + void copy_to_host() const{} + void wait_for_transfer_to_finish() const{} +#endif + + + size_t data_size; + mutable bool host_current; + mutable bool device_current; + mutable bool have_active_transfer; + mutable bool device_in_use; + + std::shared_ptr data_host; + std::shared_ptr data_device; + std::shared_ptr cuda_stream; + int the_device_id; + }; + + inline void serialize(const gpu_data& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.size(), out); + auto data = item.host(); + for (size_t i = 0; i < item.size(); ++i) + serialize(data[i], out); + } + + inline void deserialize(gpu_data& item, std::istream& in) + { + int version; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::gpu_data."); + size_t s; + deserialize(s, in); + item.set_size(s); + auto data = item.host(); + for (size_t i = 0; i < item.size(); ++i) + deserialize(data[i], in); + } + +#ifdef DLIB_USE_CUDA + void memcpy (gpu_data& dest, const gpu_data& src); + + void memcpy ( + gpu_data& dest, + size_t dest_offset, + const gpu_data& src, + size_t src_offset, + size_t num + ); + +#else + + inline void memcpy (gpu_data& dest, const gpu_data& src) + { + DLIB_CASSERT(dest.size() == src.size()); + if (src.size() == 0 || &dest == &src) + return; + std::memcpy(dest.host_write_only(), src.host(), sizeof(float)*src.size()); + } + + inline void memcpy ( + gpu_data& dest, + size_t dest_offset, + const gpu_data& src, + size_t src_offset, + size_t num + ) + { + DLIB_CASSERT(dest_offset + num <= dest.size()); + DLIB_CASSERT(src_offset + num <= src.size()); + if (num == 0) + return; + if (&dest == &src && std::max(dest_offset, src_offset) < std::min(dest_offset,src_offset)+num) + { + // if they perfectly alias each other then there is nothing to do + if (dest_offset == src_offset) + return; + else + std::memmove(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num); + } + else + { + // if we write to the entire thing then we can use host_write_only() + if (dest_offset == 0 && num == dest.size()) + std::memcpy(dest.host_write_only(), src.host()+src_offset, sizeof(float)*num); + else + std::memcpy(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num); + } + } +#endif + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GPU_DaTA_H_ + diff --git a/ml/dlib/dlib/dnn/gpu_data_abstract.h b/ml/dlib/dlib/dnn/gpu_data_abstract.h new file mode 100644 index 000000000..f2423dee1 --- /dev/null +++ b/ml/dlib/dlib/dnn/gpu_data_abstract.h @@ -0,0 +1,266 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GPU_DaTA_ABSTRACT_H_ +#ifdef DLIB_GPU_DaTA_ABSTRACT_H_ + +#include "cuda_errors.h" +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class gpu_data + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a block of size() floats, all stored contiguously in memory. + Importantly, it keeps two copies of the floats, one on the host CPU side + and another on the GPU device side. It automatically performs the necessary + host/device transfers to keep these two copies of the data in sync. + + All transfers to the device happen asynchronously with respect to the + default CUDA stream so that CUDA kernel computations can overlap with data + transfers. However, any transfers from the device to the host happen + synchronously in the default CUDA stream. Therefore, you should perform + all your CUDA kernel launches on the default stream so that transfers back + to the host do not happen before the relevant computations have completed. + + If DLIB_USE_CUDA is not #defined then this object will not use CUDA at all. + Instead, it will simply store one host side memory block of floats. + + THREAD SAFETY + Instances of this object are not thread-safe. So don't touch one from + multiple threads at the same time. + !*/ + public: + + gpu_data( + ); + /*! + ensures + - #size() == 0 + - #host() == nullptr + - #device() == nullptr + - #host_ready() == true + - #device_ready() == true + - #device_id() == 0 + !*/ + + // This object is not copyable, however, it is movable. + gpu_data(const gpu_data&) = delete; + gpu_data& operator=(const gpu_data&) = delete; + gpu_data(gpu_data&& item); + gpu_data& operator=(gpu_data&& item); + + int device_id( + ) const; + /*! + ensures + - returns the ID of the CUDA device that allocated this memory. I.e. the + number returned by cudaGetDevice() when the memory was allocated. + - If CUDA is not being used then this function always returns 0. + !*/ + + void async_copy_to_device( + ); + /*! + ensures + - if (!device_ready()) then + - Begins asynchronously copying host data to the device once it is safe + to do so. I.e. This function will wait until any previously + scheduled CUDA kernels, which are using the device() memory block, + have completed before transferring the new data to the device. + - A call to device() that happens before the transfer completes will + block until the transfer is complete. That is, it is safe to call + async_copy_to_device() and then immediately call device(). + !*/ + + void set_size( + size_t new_size + ); + /*! + ensures + - #size() == new_size + !*/ + + bool host_ready ( + ) const; + /*! + ensures + - returns true if and only if the host's copy of the data is current. The + host's data is current if there aren't any modifications to the data + which were made on the device side that have yet to be copied to the + host. + !*/ + + bool device_ready ( + ) const; + /*! + ensures + - returns true if and only if the device's copy of the data is current. + The device's data is current if there aren't any modifications to the + data which were made on the host side that have yet to be copied to the + device. + !*/ + + const float* host( + ) const; + /*! + ensures + - returns a pointer to the host memory block of size() contiguous float + values or nullptr if size()==0. + - if (!host_ready()) then + - copies the data from the device to the host, while this is happening + the call to host() blocks. + - #host_ready() == true + !*/ + + float* host( + ); + /*! + ensures + - returns a pointer to the host memory block of size() contiguous float + values or nullptr if size()==0. + - if (!host_ready()) then + - copies the data from the device to the host, while this is happening + the call to host() blocks. + - #host_ready() == true + - #device_ready() == false + I.e. Marks the device side data as out of date so that the next call to + device() will perform a host to device transfer. If you want to begin + the transfer immediately then you can call async_copy_to_device() after + calling host(). + !*/ + + float* host_write_only( + ); + /*! + ensures + - This function returns the same pointer as host(), except that it never + performs a device to host memory copy. Instead, it immediately marks the + device side data as out of date, effectively discarding it. Therefore, + the values in the data pointed to by host_write_only() are undefined and + you should only call host_write_only() if you are going to assign to + every memory location in the returned memory block. + - #host_ready() == true + - #device_ready() == false + !*/ + + const float* device( + ) const; + /*! + requires + - DLIB_USE_CUDA is #defined + ensures + - returns a pointer to the device memory block of size() contiguous float + values or nullptr if size()==0. + - if (!device_ready()) then + - copies the data from the host to the device, while this is happening + the call to device() blocks. + - #device_ready() == true + !*/ + + float* device( + ); + /*! + requires + - DLIB_USE_CUDA is #defined + ensures + - returns a pointer to the device memory block of size() contiguous float + values or nullptr if size()==0. + - if (!device_ready()) then + - copies the data from the host to the device, while this is happening + the call to device() blocks. + - #host_ready() == false + - #device_ready() == true + !*/ + + float* device_write_only( + ); + /*! + requires + - DLIB_USE_CUDA is #defined + ensures + - This function returns the same pointer as device(), except that it never + performs a host to device memory copy. Instead, it immediately marks the + host side data as out of date, effectively discarding it. Therefore, the + values in the data pointed to by device_write_only() are undefined and + you should only call device_write_only() if you are going to assign to + every memory location in the returned memory block. + - #host_ready() == false + - #device_ready() == true + !*/ + + + size_t size( + ) const; + /*! + ensures + - returns the number of floats contained in this object. + !*/ + + void swap ( + gpu_data& item + ); + /*! + ensures + - swaps the state of *this and item + !*/ + + }; + + void serialize(const gpu_data& item, std::ostream& out); + void deserialize(gpu_data& item, std::istream& in); + /*! + provides serialization support + !*/ + + void memcpy ( + gpu_data& dest, + const gpu_data& src + ); + /*! + requires + - dest.size() == src.size() + ensures + - Copies the data in src to dest. If the device data is current (i.e. + device_ready()==true) on both src and dest then the copy will happen entirely + on the device side. + - It doesn't matter what GPU device is selected by cudaSetDevice(). You can + always copy gpu_data objects to and from each other regardless. + - This function blocks until the copy has completed. + !*/ + + void memcpy ( + gpu_data& dest, + size_t dest_offset, + const gpu_data& src, + size_t src_offset, + size_t num + ); + /*! + requires + - dest_offset + num <= dest.size() + - src_offset + num <= src.size() + ensures + - Copies the data in src to dest, but only copies data in the range + [src.host()+src_offset, src.host()+src_offset+num) to + [dest.host()+dest_offset, dest.host()+dest_offset+num). Therefore, it is + just like the above memcpy() except that you can specify some subset of data + in a gpu_data object to be copied. + - Like the above version of memcpy(), the copy will happen in the most + efficient way, automatically using the appropriate type of host/device + transfers based on where data is currently resident. + - It doesn't matter what GPU device is selected by cudaSetDevice(). You can + always copy gpu_data objects to and from each other regardless. + - This function blocks until the copy has completed. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GPU_DaTA_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/dnn/input.h b/ml/dlib/dlib/dnn/input.h new file mode 100644 index 000000000..3b5c954e6 --- /dev/null +++ b/ml/dlib/dlib/dnn/input.h @@ -0,0 +1,808 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_INPUT_H_ +#define DLIB_DNn_INPUT_H_ + +#include "input_abstract.h" +#include "../matrix.h" +#include "../array2d.h" +#include "../pixel.h" +#include "../image_processing.h" +#include +#include +#include "tensor_tools.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + class input + { + const static bool always_false = sizeof(T)!=sizeof(T); + static_assert(always_false, "Unsupported type given to input<>. input<> only supports " + "dlib::matrix and dlib::array2d objects."); + }; + +// ---------------------------------------------------------------------------------------- + + template + class input_rgb_image_sized; + + class input_rgb_image + { + public: + typedef matrix input_type; + + input_rgb_image ( + ) : + avg_red(122.782), + avg_green(117.001), + avg_blue(104.298) + { + } + + input_rgb_image ( + float avg_red_, + float avg_green_, + float avg_blue_ + ) : avg_red(avg_red_), avg_green(avg_green_), avg_blue(avg_blue_) + {} + + template + inline input_rgb_image ( + const input_rgb_image_sized& item + ); + + float get_avg_red() const { return avg_red; } + float get_avg_green() const { return avg_green; } + float get_avg_blue() const { return avg_blue; } + + bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } + drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } + drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + DLIB_CASSERT(std::distance(ibegin,iend) > 0); + const auto nr = ibegin->nr(); + const auto nc = ibegin->nc(); + // make sure all the input matrices have the same dimensions + for (auto i = ibegin; i != iend; ++i) + { + DLIB_CASSERT(i->nr()==nr && i->nc()==nc, + "\t input_rgb_image::to_tensor()" + << "\n\t All matrices given to to_tensor() must have the same dimensions." + << "\n\t nr: " << nr + << "\n\t nc: " << nc + << "\n\t i->nr(): " << i->nr() + << "\n\t i->nc(): " << i->nc() + ); + } + + + // initialize data to the right size to contain the stuff in the iterator range. + data.set_size(std::distance(ibegin,iend), 3, nr, nc); + + + const size_t offset = nr*nc; + auto ptr = data.host(); + for (auto i = ibegin; i != iend; ++i) + { + for (long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + rgb_pixel temp = (*i)(r,c); + auto p = ptr++; + *p = (temp.red-avg_red)/256.0; + p += offset; + *p = (temp.green-avg_green)/256.0; + p += offset; + *p = (temp.blue-avg_blue)/256.0; + p += offset; + } + } + ptr += offset*(data.k()-1); + } + + } + + friend void serialize(const input_rgb_image& item, std::ostream& out) + { + serialize("input_rgb_image", out); + serialize(item.avg_red, out); + serialize(item.avg_green, out); + serialize(item.avg_blue, out); + } + + friend void deserialize(input_rgb_image& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "input_rgb_image" && version != "input_rgb_image_sized") + throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image."); + deserialize(item.avg_red, in); + deserialize(item.avg_green, in); + deserialize(item.avg_blue, in); + + // read and discard the sizes if this was really a sized input layer. + if (version == "input_rgb_image_sized") + { + size_t nr, nc; + deserialize(nr, in); + deserialize(nc, in); + } + } + + friend std::ostream& operator<<(std::ostream& out, const input_rgb_image& item) + { + out << "input_rgb_image("<"; + } + + private: + float avg_red; + float avg_green; + float avg_blue; + }; + +// ---------------------------------------------------------------------------------------- + + template + class input_rgb_image_sized + { + public: + static_assert(NR != 0 && NC != 0, "The input image can't be empty."); + + typedef matrix input_type; + + input_rgb_image_sized ( + ) : + avg_red(122.782), + avg_green(117.001), + avg_blue(104.298) + { + } + + input_rgb_image_sized ( + const input_rgb_image& item + ) : avg_red(item.get_avg_red()), + avg_green(item.get_avg_green()), + avg_blue(item.get_avg_blue()) + {} + + input_rgb_image_sized ( + float avg_red_, + float avg_green_, + float avg_blue_ + ) : avg_red(avg_red_), avg_green(avg_green_), avg_blue(avg_blue_) + {} + + float get_avg_red() const { return avg_red; } + float get_avg_green() const { return avg_green; } + float get_avg_blue() const { return avg_blue; } + + bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } + drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } + drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + DLIB_CASSERT(std::distance(ibegin,iend) > 0); + // make sure all input images have the correct size + for (auto i = ibegin; i != iend; ++i) + { + DLIB_CASSERT(i->nr()==NR && i->nc()==NC, + "\t input_rgb_image_sized::to_tensor()" + << "\n\t All input images must have "<nr()<<" rows and "<nc()<<" columns." + ); + } + + + // initialize data to the right size to contain the stuff in the iterator range. + data.set_size(std::distance(ibegin,iend), 3, NR, NC); + + + const size_t offset = NR*NC; + auto ptr = data.host(); + for (auto i = ibegin; i != iend; ++i) + { + for (size_t r = 0; r < NR; ++r) + { + for (size_t c = 0; c < NC; ++c) + { + rgb_pixel temp = (*i)(r,c); + auto p = ptr++; + *p = (temp.red-avg_red)/256.0; + p += offset; + *p = (temp.green-avg_green)/256.0; + p += offset; + *p = (temp.blue-avg_blue)/256.0; + p += offset; + } + } + ptr += offset*(data.k()-1); + } + + } + + friend void serialize(const input_rgb_image_sized& item, std::ostream& out) + { + serialize("input_rgb_image_sized", out); + serialize(item.avg_red, out); + serialize(item.avg_green, out); + serialize(item.avg_blue, out); + serialize(NR, out); + serialize(NC, out); + } + + friend void deserialize(input_rgb_image_sized& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "input_rgb_image_sized") + throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image_sized."); + deserialize(item.avg_red, in); + deserialize(item.avg_green, in); + deserialize(item.avg_blue, in); + size_t nr, nc; + deserialize(nr, in); + deserialize(nc, in); + if (nr != NR || nc != NC) + { + std::ostringstream sout; + sout << "Wrong image dimensions found while deserializing dlib::input_rgb_image_sized.\n"; + sout << "Expected "<"; + } + + private: + float avg_red; + float avg_green; + float avg_blue; + }; + +// ---------------------------------------------------------------------------------------- + + template + input_rgb_image:: + input_rgb_image ( + const input_rgb_image_sized& item + ) : avg_red(item.get_avg_red()), + avg_green(item.get_avg_green()), + avg_blue(item.get_avg_blue()) + {} + +// ---------------------------------------------------------------------------------------- + + template + class input> + { + public: + typedef matrix input_type; + + input() {} + input(const input&) {} + + template + input(const input>&) {} + + bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } + drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } + drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + DLIB_CASSERT(std::distance(ibegin,iend) > 0); + const auto nr = ibegin->nr(); + const auto nc = ibegin->nc(); + // make sure all the input matrices have the same dimensions + for (auto i = ibegin; i != iend; ++i) + { + DLIB_CASSERT(i->nr()==nr && i->nc()==nc, + "\t input::to_tensor()" + << "\n\t All matrices given to to_tensor() must have the same dimensions." + << "\n\t nr: " << nr + << "\n\t nc: " << nc + << "\n\t i->nr(): " << i->nr() + << "\n\t i->nc(): " << i->nc() + ); + } + + + // initialize data to the right size to contain the stuff in the iterator range. + data.set_size(std::distance(ibegin,iend), pixel_traits::num, nr, nc); + + typedef typename pixel_traits::basic_pixel_type bptype; + + const size_t offset = nr*nc; + auto ptr = data.host(); + for (auto i = ibegin; i != iend; ++i) + { + for (long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + auto temp = pixel_to_vector((*i)(r,c)); + auto p = ptr++; + for (long j = 0; j < temp.size(); ++j) + { + if (is_same_type::value) + *p = temp(j)/256.0; + else + *p = temp(j); + p += offset; + } + } + } + ptr += offset*(data.k()-1); + } + + } + + friend void serialize(const input& /*item*/, std::ostream& out) + { + serialize("input", out); + } + + friend void deserialize(input& /*item*/, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "input") + throw serialization_error("Unexpected version found while deserializing dlib::input."); + } + + friend std::ostream& operator<<(std::ostream& out, const input& /*item*/) + { + out << "input"; + return out; + } + + friend void to_xml(const input& /*item*/, std::ostream& out) + { + out << ""; + } + }; + +// ---------------------------------------------------------------------------------------- + + template + class input,K>> + { + public: + typedef std::array,K> input_type; + + input() {} + input(const input&) {} + + bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } + drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } + drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + DLIB_CASSERT(std::distance(ibegin,iend) > 0); + DLIB_CASSERT(ibegin->size() != 0, "When using std::array inputs you can't give 0 sized arrays."); + const auto nr = (*ibegin)[0].nr(); + const auto nc = (*ibegin)[0].nc(); + // make sure all the input matrices have the same dimensions + for (auto i = ibegin; i != iend; ++i) + { + for (size_t k = 0; k < K; ++k) + { + const auto& arr = *i; + DLIB_CASSERT(arr[k].nr()==nr && arr[k].nc()==nc, + "\t input::to_tensor()" + << "\n\t When using std::array as input, all matrices in a batch must have the same dimensions." + << "\n\t nr: " << nr + << "\n\t nc: " << nc + << "\n\t k: " << k + << "\n\t arr[k].nr(): " << arr[k].nr() + << "\n\t arr[k].nc(): " << arr[k].nc() + ); + } + } + + + // initialize data to the right size to contain the stuff in the iterator range. + data.set_size(std::distance(ibegin,iend), K, nr, nc); + + auto ptr = data.host(); + for (auto i = ibegin; i != iend; ++i) + { + for (size_t k = 0; k < K; ++k) + { + for (long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + if (is_same_type::value) + *ptr++ = (*i)[k](r,c)/256.0; + else + *ptr++ = (*i)[k](r,c); + } + } + } + } + + } + + friend void serialize(const input& /*item*/, std::ostream& out) + { + serialize("input>", out); + } + + friend void deserialize(input& /*item*/, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "input>") + throw serialization_error("Unexpected version found while deserializing dlib::input>."); + } + + friend std::ostream& operator<<(std::ostream& out, const input& /*item*/) + { + out << "input>"; + return out; + } + + friend void to_xml(const input& /*item*/, std::ostream& out) + { + out << ""; + } + }; + +// ---------------------------------------------------------------------------------------- + + template + class input> + { + public: + typedef array2d input_type; + + input() {} + input(const input&) {} + + template + input(const input>&) {} + + bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } + drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } + drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + DLIB_CASSERT(std::distance(ibegin,iend) > 0); + const auto nr = ibegin->nr(); + const auto nc = ibegin->nc(); + // make sure all the input matrices have the same dimensions + for (auto i = ibegin; i != iend; ++i) + { + DLIB_CASSERT(i->nr()==nr && i->nc()==nc, + "\t input::to_tensor()" + << "\n\t All array2d objects given to to_tensor() must have the same dimensions." + << "\n\t nr: " << nr + << "\n\t nc: " << nc + << "\n\t i->nr(): " << i->nr() + << "\n\t i->nc(): " << i->nc() + ); + } + + + // initialize data to the right size to contain the stuff in the iterator range. + data.set_size(std::distance(ibegin,iend), pixel_traits::num, nr, nc); + typedef typename pixel_traits::basic_pixel_type bptype; + + const size_t offset = nr*nc; + auto ptr = data.host(); + for (auto i = ibegin; i != iend; ++i) + { + for (long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + auto temp = pixel_to_vector((*i)[r][c]); + auto p = ptr++; + for (long j = 0; j < temp.size(); ++j) + { + if (is_same_type::value) + *p = temp(j)/256.0; + else + *p = temp(j); + p += offset; + } + } + } + ptr += offset*(data.k()-1); + } + + } + + friend void serialize(const input& item, std::ostream& out) + { + serialize("input", out); + } + + friend void deserialize(input& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "input") + throw serialization_error("Unexpected version found while deserializing dlib::input."); + } + friend std::ostream& operator<<(std::ostream& out, const input& item) + { + out << "input"; + return out; + } + + friend void to_xml(const input& item, std::ostream& out) + { + out << ""; + } + }; + +// ---------------------------------------------------------------------------------------- + + template + class input_rgb_image_pyramid + { + public: + typedef matrix input_type; + typedef PYRAMID_TYPE pyramid_type; + + input_rgb_image_pyramid ( + ) : + avg_red(122.782), + avg_green(117.001), + avg_blue(104.298) + { + } + + input_rgb_image_pyramid ( + float avg_red_, + float avg_green_, + float avg_blue_ + ) : avg_red(avg_red_), avg_green(avg_green_), avg_blue(avg_blue_) + {} + + float get_avg_red() const { return avg_red; } + float get_avg_green() const { return avg_green; } + float get_avg_blue() const { return avg_blue; } + + unsigned long get_pyramid_padding () const { return pyramid_padding; } + void set_pyramid_padding (unsigned long value) { pyramid_padding = value; } + + unsigned long get_pyramid_outer_padding () const { return pyramid_outer_padding; } + void set_pyramid_outer_padding (unsigned long value) { pyramid_outer_padding = value; } + + bool image_contained_point ( + const tensor& data, + const point& p + ) const + { + auto&& rects = any_cast>(data.annotation()); + DLIB_CASSERT(rects.size() > 0); + return rects[0].contains(p+rects[0].tl_corner()); + } + + drectangle tensor_space_to_image_space ( + const tensor& data, + drectangle r + ) const + { + auto&& rects = any_cast>(data.annotation()); + return tiled_pyramid_to_image(rects, r); + } + + drectangle image_space_to_tensor_space ( + const tensor& data, + double scale, + drectangle r + ) const + { + DLIB_CASSERT(0 < scale && scale <= 1 , "scale: "<< scale); + auto&& rects = any_cast>(data.annotation()); + return image_to_tiled_pyramid(rects, scale, r); + } + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const + { + DLIB_CASSERT(std::distance(ibegin,iend) > 0); + auto nr = ibegin->nr(); + auto nc = ibegin->nc(); + // make sure all the input matrices have the same dimensions + for (auto i = ibegin; i != iend; ++i) + { + DLIB_CASSERT(i->nr()==nr && i->nc()==nc, + "\t input_rgb_image_pyramid::to_tensor()" + << "\n\t All matrices given to to_tensor() must have the same dimensions." + << "\n\t nr: " << nr + << "\n\t nc: " << nc + << "\n\t i->nr(): " << i->nr() + << "\n\t i->nc(): " << i->nc() + ); + } + + long NR, NC; + pyramid_type pyr; + auto& rects = data.annotation().get>(); + impl::compute_tiled_image_pyramid_details(pyr, nr, nc, pyramid_padding, pyramid_outer_padding, rects, NR, NC); + + // initialize data to the right size to contain the stuff in the iterator range. + data.set_size(std::distance(ibegin,iend), 3, NR, NC); + + // We need to zero the image before doing the pyramid, since the pyramid + // creation code doesn't write to all parts of the image. We also take + // care to avoid triggering any device to hosts copies. + auto ptr = data.host_write_only(); + for (size_t i = 0; i < data.size(); ++i) + ptr[i] = 0; + + if (rects.size() == 0) + return; + + // copy the first raw image into the top part of the tiled pyramid. We need to + // do this for each of the input images/samples in the tensor. + for (auto i = ibegin; i != iend; ++i) + { + auto& img = *i; + ptr += rects[0].top()*data.nc(); + for (long r = 0; r < img.nr(); ++r) + { + auto p = ptr+rects[0].left(); + for (long c = 0; c < img.nc(); ++c) + p[c] = (img(r,c).red-avg_red)/256.0; + ptr += data.nc(); + } + ptr += data.nc()*(data.nr()-rects[0].bottom()-1); + + ptr += rects[0].top()*data.nc(); + for (long r = 0; r < img.nr(); ++r) + { + auto p = ptr+rects[0].left(); + for (long c = 0; c < img.nc(); ++c) + p[c] = (img(r,c).green-avg_green)/256.0; + ptr += data.nc(); + } + ptr += data.nc()*(data.nr()-rects[0].bottom()-1); + + ptr += rects[0].top()*data.nc(); + for (long r = 0; r < img.nr(); ++r) + { + auto p = ptr+rects[0].left(); + for (long c = 0; c < img.nc(); ++c) + p[c] = (img(r,c).blue-avg_blue)/256.0; + ptr += data.nc(); + } + ptr += data.nc()*(data.nr()-rects[0].bottom()-1); + } + + // now build the image pyramid into data. This does the same thing as + // create_tiled_pyramid(), except we use the GPU if one is available. + for (size_t i = 1; i < rects.size(); ++i) + { + alias_tensor src(data.num_samples(),data.k(),rects[i-1].height(),rects[i-1].width()); + alias_tensor dest(data.num_samples(),data.k(),rects[i].height(),rects[i].width()); + + auto asrc = src(data, data.nc()*rects[i-1].top() + rects[i-1].left()); + auto adest = dest(data, data.nc()*rects[i].top() + rects[i].left()); + + tt::resize_bilinear(adest, data.nc(), data.nr()*data.nc(), + asrc, data.nc(), data.nr()*data.nc()); + } + } + + friend void serialize(const input_rgb_image_pyramid& item, std::ostream& out) + { + serialize("input_rgb_image_pyramid2", out); + serialize(item.avg_red, out); + serialize(item.avg_green, out); + serialize(item.avg_blue, out); + serialize(item.pyramid_padding, out); + serialize(item.pyramid_outer_padding, out); + } + + friend void deserialize(input_rgb_image_pyramid& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "input_rgb_image_pyramid" && version != "input_rgb_image_pyramid2") + throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image_pyramid."); + deserialize(item.avg_red, in); + deserialize(item.avg_green, in); + deserialize(item.avg_blue, in); + if (version == "input_rgb_image_pyramid2") + { + deserialize(item.pyramid_padding, in); + deserialize(item.pyramid_outer_padding, in); + } + else + { + item.pyramid_padding = 10; + item.pyramid_outer_padding = 11; + } + } + + friend std::ostream& operator<<(std::ostream& out, const input_rgb_image_pyramid& item) + { + out << "input_rgb_image_pyramid("<"; + } + + private: + float avg_red; + float avg_green; + float avg_blue; + unsigned long pyramid_padding = 10; + unsigned long pyramid_outer_padding = 11; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_INPUT_H_ + diff --git a/ml/dlib/dlib/dnn/input_abstract.h b/ml/dlib/dlib/dnn/input_abstract.h new file mode 100644 index 000000000..7130efb17 --- /dev/null +++ b/ml/dlib/dlib/dnn/input_abstract.h @@ -0,0 +1,467 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNn_INPUT_ABSTRACT_H_ +#ifdef DLIB_DNn_INPUT_ABSTRACT_H_ + +#include "../matrix.h" +#include "../pixel.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class EXAMPLE_INPUT_LAYER + { + /*! + WHAT THIS OBJECT REPRESENTS + Each deep neural network model in dlib begins with an input layer. The job + of the input layer is to convert an input_type into a tensor. Nothing more + and nothing less. + + Note that there is no dlib::EXAMPLE_INPUT_LAYER type. It is shown here + purely to document the interface that an input layer object must implement. + If you are using some kind of image or matrix object as your input_type + then you can use the provided dlib::input layer defined below. Otherwise, + you need to define your own custom input layer. + + THREAD SAFETY + to_tensor() must be thread safe. That is, multiple threads must be able to + make calls to to_tensor() on a single instance of this object at the same + time. + !*/ + public: + + EXAMPLE_INPUT_LAYER( + ); + /*! + ensures + - Default constructs this object. This function is not required to do + anything in particular but it must exist, that is, it is required that + layer objects be default constructable. + !*/ + + EXAMPLE_INPUT_LAYER ( + const EXAMPLE_INPUT_LAYER& item + ); + /*! + ensures + - EXAMPLE_INPUT_LAYER objects are copy constructable + !*/ + + EXAMPLE_INPUT_LAYER( + const some_other_input_layer_type& item + ); + /*! + ensures + - Constructs this object from item. This form of constructor is optional + but it allows you to provide a conversion from one input layer type to + another. For example, the following code is valid only if my_input_layer2 can + be constructed from my_input_layer1: + relu>>> my_dnn1; + relu>>> my_dnn2(my_dnn1); + This kind of pattern is useful if you want to use one type of input layer + during training but a different type of layer during testing since it + allows you to easily convert between related deep neural network types. + !*/ + + typedef whatever_type_to_tensor_expects input_type; + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const; + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + ensures + - Converts the iterator range into a tensor and stores it into #data. + - #data.num_samples()%distance(ibegin,iend) == 0. + Normally you would have #data.num_samples() == distance(ibegin,iend) but + you can also expand the output by some integer factor so long as the loss + you use can deal with it correctly. + - The data in the ith sample of #data corresponds to the input_type object + *(ibegin+i/sample_expansion_factor). + where sample_expansion_factor==#data.num_samples()/distance(ibegin,iend). + !*/ + }; + + std::ostream& operator<<(std::ostream& out, const EXAMPLE_INPUT_LAYER& item); + /*! + print a string describing this layer. + !*/ + + void to_xml(const EXAMPLE_INPUT_LAYER& item, std::ostream& out); + /*! + This function is optional, but required if you want to print your networks with + net_to_xml(). Therefore, to_xml() prints a layer as XML. + !*/ + + void serialize(const EXAMPLE_INPUT_LAYER& item, std::ostream& out); + void deserialize(EXAMPLE_INPUT_LAYER& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class input + { + /*! + REQUIREMENTS ON T + One of the following must be true: + - T is a matrix or array2d object and it must contain some kind of + pixel type. I.e. pixel_traits must be defined. + - T is a std::array> where U is any built in scalar type like + float, double, or unsigned char. + + WHAT THIS OBJECT REPRESENTS + This is a basic input layer that simply copies images into a tensor. + !*/ + + public: + typedef T input_type; + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const; + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + - The input range should contain image objects that all have the same + dimensions. + ensures + - Converts the iterator range into a tensor and stores it into #data. In + particular, if the input images have R rows, C columns, and K channels + (where K is given by pixel_traits::num or std::array::size() if + std::array inputs are used) then we will have: + - #data.num_samples() == std::distance(ibegin,iend) + - #data.nr() == R + - #data.nc() == C + - #data.k() == K + For example, a matrix would turn into a tensor with 3 rows, 3 + columns, and k()==1. Or a matrix would turn into a tensor + with 4 rows, 5 columns, and k()==3 (since rgb_pixels have 3 channels). + Or a std::array,5> would turn into a tensor with 3 rows + and columns, and k()==5 channels. + - If the input data contains pixels of type unsigned char, rgb_pixel, or + other pixel types with a basic_pixel_type of unsigned char then each + value written to the output tensor is first divided by 256.0 so that the + resulting outputs are all in the range [0,1]. + !*/ + + // Provided for compatibility with input_rgb_image_pyramid's interface + bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } + drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } + drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } + }; + +// ---------------------------------------------------------------------------------------- + + class input_rgb_image + { + /*! + WHAT THIS OBJECT REPRESENTS + This input layer works with RGB images of type matrix. It is + very similar to the dlib::input layer except that it allows you to subtract + the average color value from each color channel when converting an image to + a tensor. + !*/ + public: + typedef matrix input_type; + + input_rgb_image ( + ); + /*! + ensures + - #get_avg_red() == 122.782 + - #get_avg_green() == 117.001 + - #get_avg_blue() == 104.298 + !*/ + + input_rgb_image ( + float avg_red, + float avg_green, + float avg_blue + ); + /*! + ensures + - #get_avg_red() == avg_red + - #get_avg_green() == avg_green + - #get_avg_blue() == avg_blue + !*/ + + float get_avg_red( + ) const; + /*! + ensures + - returns the value subtracted from the red color channel. + !*/ + + float get_avg_green( + ) const; + /*! + ensures + - returns the value subtracted from the green color channel. + !*/ + + float get_avg_blue( + ) const; + /*! + ensures + - returns the value subtracted from the blue color channel. + !*/ + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const; + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + - The input range should contain images that all have the same + dimensions. + ensures + - Converts the iterator range into a tensor and stores it into #data. In + particular, if the input images have R rows, C columns then we will have: + - #data.num_samples() == std::distance(ibegin,iend) + - #data.nr() == R + - #data.nc() == C + - #data.k() == 3 + Moreover, each color channel is normalized by having its average value + subtracted (according to get_avg_red(), get_avg_green(), or + get_avg_blue()) and then is divided by 256.0. + !*/ + + + // Provided for compatibility with input_rgb_image_pyramid's interface + bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } + drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } + drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } + }; + +// ---------------------------------------------------------------------------------------- + + template + class input_rgb_image_sized + { + /*! + WHAT THIS OBJECT REPRESENTS + This layer has an interface and behavior identical to input_rgb_image + except that it requires input images to have NR rows and NC columns. This + is checked by a DLIB_CASSERT inside to_tensor(). + + You can also convert between input_rgb_image and input_rgb_image_sized by + copy construction or assignment. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename PYRAMID_TYPE + > + class input_rgb_image_pyramid + { + /*! + REQUIREMENTS ON PYRAMID_TYPE + PYRAMID_TYPE must be an instance of the dlib::pyramid_down template. + + WHAT THIS OBJECT REPRESENTS + This input layer works with RGB images of type matrix. It is + identical to input_rgb_image except that it outputs a tensor containing a + tiled image pyramid of each input image rather than a simple copy of each + image. The tiled image pyramid is created using create_tiled_pyramid(). + !*/ + + public: + + typedef matrix input_type; + typedef PYRAMID_TYPE pyramid_type; + + input_rgb_image_pyramid ( + ); + /*! + ensures + - #get_avg_red() == 122.782 + - #get_avg_green() == 117.001 + - #get_avg_blue() == 104.298 + - #get_pyramid_padding() == 10 + - #get_pyramid_outer_padding() == 11 + !*/ + + input_rgb_image_pyramid ( + float avg_red, + float avg_green, + float avg_blue + ); + /*! + ensures + - #get_avg_red() == avg_red + - #get_avg_green() == avg_green + - #get_avg_blue() == avg_blue + - #get_pyramid_padding() == 10 + - #get_pyramid_outer_padding() == 11 + !*/ + + float get_avg_red( + ) const; + /*! + ensures + - returns the value subtracted from the red color channel. + !*/ + + float get_avg_green( + ) const; + /*! + ensures + - returns the value subtracted from the green color channel. + !*/ + + float get_avg_blue( + ) const; + /*! + ensures + - returns the value subtracted from the blue color channel. + !*/ + + unsigned long get_pyramid_padding ( + ) const; + /*! + ensures + - When this object creates a pyramid it will call create_tiled_pyramid() and + set create_tiled_pyramid's pyramid_padding parameter to get_pyramid_padding(). + !*/ + void set_pyramid_padding ( + unsigned long value + ); + /*! + ensures + - #get_pyramid_padding() == value + !*/ + + unsigned long get_pyramid_outer_padding ( + ) const; + /*! + ensures + - When this object creates a pyramid it will call create_tiled_pyramid() + and set create_tiled_pyramid's pyramid_outer_padding parameter to + get_pyramid_outer_padding(). + !*/ + void set_pyramid_outer_padding ( + unsigned long value + ); + /*! + ensures + - #get_pyramid_outer_padding() == value + !*/ + + template + void to_tensor ( + forward_iterator ibegin, + forward_iterator iend, + resizable_tensor& data + ) const; + /*! + requires + - [ibegin, iend) is an iterator range over input_type objects. + - std::distance(ibegin,iend) > 0 + - The input range should contain images that all have the same + dimensions. + ensures + - Converts the iterator range into a tensor and stores it into #data. In + particular, we will have: + - #data.num_samples() == std::distance(ibegin,iend) + - #data.k() == 3 + - Each sample in #data contains a tiled image pyramid of the + corresponding input image. The tiled pyramid is created by + create_tiled_pyramid(). + Moreover, each color channel is normalized by having its average value + subtracted (according to get_avg_red(), get_avg_green(), or + get_avg_blue()) and then is divided by 256.0. + !*/ + + bool image_contained_point ( + const tensor& data, + const point& p + ) const; + /*! + requires + - data is a tensor that was produced by this->to_tensor() + ensures + - Since data is a tensor that is built from a bunch of identically sized + images, we can ask if those images were big enough to contain the point + p. This function returns the answer to that question. + !*/ + + drectangle image_space_to_tensor_space ( + const tensor& data, + double scale, + drectangle r + ) const; + /*! + requires + - data is a tensor that was produced by this->to_tensor() + - 0 < scale <= 1 + ensures + - This function maps from to_tensor()'s input image space to its output + tensor space. Therefore, given that data is a tensor produced by + to_tensor(), image_space_to_tensor_space() allows you to ask for the + rectangle in data that corresponds to a rectangle in the original image + space. + + Note that since the output tensor contains an image pyramid, there are + multiple points in the output tensor that correspond to any input + location. So you must also specify a scale so we know what level of the + pyramid is needed. So given a rectangle r in an input image, you can + ask, what rectangle in data corresponds to r when things are scale times + smaller? That rectangle is returned by this function. + - A scale of 1 means we don't move anywhere in the pyramid scale space relative + to the input image while smaller values of scale mean we move down the + pyramid. + !*/ + + drectangle tensor_space_to_image_space ( + const tensor& data, + drectangle r + ) const; + /*! + requires + - data is a tensor that was produced by this->to_tensor() + ensures + - This function maps from to_tensor()'s output tensor space to its input + image space. Therefore, given that data is a tensor produced by + to_tensor(), tensor_space_to_image_space() allows you to ask for the + rectangle in the input image that corresponds to a rectangle in data. + - It should be noted that this function isn't always an inverse of + image_space_to_tensor_space(). This is because you can ask + image_space_to_tensor_space() for the coordinates of points outside the input + image and they will be mapped to somewhere that doesn't have an inverse. + But for points actually inside the input image this function performs an + approximate inverse mapping. I.e. when image_contained_point(data,center(r))==true + there is an approximate inverse. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_INPUT_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/dnn/layers.h b/ml/dlib/dlib/dnn/layers.h new file mode 100644 index 000000000..91436f635 --- /dev/null +++ b/ml/dlib/dlib/dnn/layers.h @@ -0,0 +1,3244 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_LAYERS_H_ +#define DLIB_DNn_LAYERS_H_ + +#include "layers_abstract.h" +#include "tensor.h" +#include "core.h" +#include +#include +#include "../rand.h" +#include "../string.h" +#include "tensor_tools.h" +#include "../vectorstream.h" +#include "utilities.h" +#include + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct num_con_outputs + { + num_con_outputs(unsigned long n) : num_outputs(n) {} + unsigned long num_outputs; + }; + + template < + long _num_filters, + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y = _stride_y!=1? 0 : _nr/2, + int _padding_x = _stride_x!=1? 0 : _nc/2 + > + class con_ + { + public: + + static_assert(_num_filters > 0, "The number of filters must be > 0"); + static_assert(_nr >= 0, "The number of rows in a filter must be >= 0"); + static_assert(_nc >= 0, "The number of columns in a filter must be >= 0"); + static_assert(_stride_y > 0, "The filter stride must be > 0"); + static_assert(_stride_x > 0, "The filter stride must be > 0"); + static_assert(_nr==0 || (0 <= _padding_y && _padding_y < _nr), "The padding must be smaller than the filter size."); + static_assert(_nc==0 || (0 <= _padding_x && _padding_x < _nc), "The padding must be smaller than the filter size."); + static_assert(_nr!=0 || 0 == _padding_y, "If _nr==0 then the padding must be set to 0 as well."); + static_assert(_nc!=0 || 0 == _padding_x, "If _nr==0 then the padding must be set to 0 as well."); + + con_( + num_con_outputs o + ) : + learning_rate_multiplier(1), + weight_decay_multiplier(1), + bias_learning_rate_multiplier(1), + bias_weight_decay_multiplier(0), + num_filters_(o.num_outputs), + padding_y_(_padding_y), + padding_x_(_padding_x) + { + DLIB_CASSERT(num_filters_ > 0); + } + + con_() : con_(num_con_outputs(_num_filters)) {} + + long num_filters() const { return num_filters_; } + long nr() const + { + if (_nr==0) + return filters.nr(); + else + return _nr; + } + long nc() const + { + if (_nc==0) + return filters.nc(); + else + return _nc; + } + long stride_y() const { return _stride_y; } + long stride_x() const { return _stride_x; } + long padding_y() const { return padding_y_; } + long padding_x() const { return padding_x_; } + + void set_num_filters(long num) + { + DLIB_CASSERT(num > 0); + if (num != num_filters_) + { + DLIB_CASSERT(get_layer_params().size() == 0, + "You can't change the number of filters in con_ if the parameter tensor has already been allocated."); + num_filters_ = num; + } + } + + double get_learning_rate_multiplier () const { return learning_rate_multiplier; } + double get_weight_decay_multiplier () const { return weight_decay_multiplier; } + void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } + void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } + + double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; } + double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } + void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } + void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } + + inline dpoint map_input_to_output ( + dpoint p + ) const + { + p.x() = (p.x()+padding_x()-nc()/2)/stride_x(); + p.y() = (p.y()+padding_y()-nr()/2)/stride_y(); + return p; + } + + inline dpoint map_output_to_input ( + dpoint p + ) const + { + p.x() = p.x()*stride_x() - padding_x() + nc()/2; + p.y() = p.y()*stride_y() - padding_y() + nr()/2; + return p; + } + + con_ ( + const con_& item + ) : + params(item.params), + filters(item.filters), + biases(item.biases), + learning_rate_multiplier(item.learning_rate_multiplier), + weight_decay_multiplier(item.weight_decay_multiplier), + bias_learning_rate_multiplier(item.bias_learning_rate_multiplier), + bias_weight_decay_multiplier(item.bias_weight_decay_multiplier), + num_filters_(item.num_filters_), + padding_y_(item.padding_y_), + padding_x_(item.padding_x_) + { + // this->conv is non-copyable and basically stateless, so we have to write our + // own copy to avoid trying to copy it and getting an error. + } + + con_& operator= ( + const con_& item + ) + { + if (this == &item) + return *this; + + // this->conv is non-copyable and basically stateless, so we have to write our + // own copy to avoid trying to copy it and getting an error. + params = item.params; + filters = item.filters; + biases = item.biases; + padding_y_ = item.padding_y_; + padding_x_ = item.padding_x_; + learning_rate_multiplier = item.learning_rate_multiplier; + weight_decay_multiplier = item.weight_decay_multiplier; + bias_learning_rate_multiplier = item.bias_learning_rate_multiplier; + bias_weight_decay_multiplier = item.bias_weight_decay_multiplier; + num_filters_ = item.num_filters_; + return *this; + } + + template + void setup (const SUBNET& sub) + { + const long filt_nr = _nr!=0 ? _nr : sub.get_output().nr(); + const long filt_nc = _nc!=0 ? _nc : sub.get_output().nc(); + + long num_inputs = filt_nr*filt_nc*sub.get_output().k(); + long num_outputs = num_filters_; + // allocate params for the filters and also for the filter bias values. + params.set_size(num_inputs*num_filters_ + num_filters_); + + dlib::rand rnd(std::rand()); + randomize_parameters(params, num_inputs+num_outputs, rnd); + + filters = alias_tensor(num_filters_, sub.get_output().k(), filt_nr, filt_nc); + biases = alias_tensor(1,num_filters_); + + // set the initial bias values to zero + biases(params,filters.size()) = 0; + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + conv.setup(sub.get_output(), + filters(params,0), + _stride_y, + _stride_x, + padding_y_, + padding_x_); + conv(false, output, + sub.get_output(), + filters(params,0)); + + tt::add(1,output,1,biases(params,filters.size())); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) + { + conv.get_gradient_for_data (true, gradient_input, filters(params,0), sub.get_gradient_input()); + // no dpoint computing the parameter gradients if they won't be used. + if (learning_rate_multiplier != 0) + { + auto filt = filters(params_grad,0); + conv.get_gradient_for_filters (false, gradient_input, sub.get_output(), filt); + auto b = biases(params_grad, filters.size()); + tt::assign_conv_bias_gradient(b, gradient_input); + } + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const con_& item, std::ostream& out) + { + serialize("con_4", out); + serialize(item.params, out); + serialize(item.num_filters_, out); + serialize(_nr, out); + serialize(_nc, out); + serialize(_stride_y, out); + serialize(_stride_x, out); + serialize(item.padding_y_, out); + serialize(item.padding_x_, out); + serialize(item.filters, out); + serialize(item.biases, out); + serialize(item.learning_rate_multiplier, out); + serialize(item.weight_decay_multiplier, out); + serialize(item.bias_learning_rate_multiplier, out); + serialize(item.bias_weight_decay_multiplier, out); + } + + friend void deserialize(con_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + long nr; + long nc; + int stride_y; + int stride_x; + if (version == "con_4") + { + deserialize(item.params, in); + deserialize(item.num_filters_, in); + deserialize(nr, in); + deserialize(nc, in); + deserialize(stride_y, in); + deserialize(stride_x, in); + deserialize(item.padding_y_, in); + deserialize(item.padding_x_, in); + deserialize(item.filters, in); + deserialize(item.biases, in); + deserialize(item.learning_rate_multiplier, in); + deserialize(item.weight_decay_multiplier, in); + deserialize(item.bias_learning_rate_multiplier, in); + deserialize(item.bias_weight_decay_multiplier, in); + if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_"); + if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_"); + if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_"); + if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_"); + if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_"); + if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_"); + } + else + { + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_."); + } + } + + + friend std::ostream& operator<<(std::ostream& out, const con_& item) + { + out << "con\t (" + << "num_filters="<\n"; + out << mat(item.params); + out << ""; + } + + private: + + resizable_tensor params; + alias_tensor filters, biases; + + tt::tensor_conv conv; + double learning_rate_multiplier; + double weight_decay_multiplier; + double bias_learning_rate_multiplier; + double bias_weight_decay_multiplier; + long num_filters_; + + // These are here only because older versions of con (which you might encounter + // serialized to disk) used different padding settings. + int padding_y_; + int padding_x_; + + }; + + template < + long num_filters, + long nr, + long nc, + int stride_y, + int stride_x, + typename SUBNET + > + using con = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + template < + long _num_filters, + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y = _stride_y!=1? 0 : _nr/2, + int _padding_x = _stride_x!=1? 0 : _nc/2 + > + class cont_ + { + public: + + static_assert(_num_filters > 0, "The number of filters must be > 0"); + static_assert(_nr > 0, "The number of rows in a filter must be > 0"); + static_assert(_nc > 0, "The number of columns in a filter must be > 0"); + static_assert(_stride_y > 0, "The filter stride must be > 0"); + static_assert(_stride_x > 0, "The filter stride must be > 0"); + static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size."); + static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size."); + + cont_( + num_con_outputs o + ) : + learning_rate_multiplier(1), + weight_decay_multiplier(1), + bias_learning_rate_multiplier(1), + bias_weight_decay_multiplier(0), + num_filters_(o.num_outputs), + padding_y_(_padding_y), + padding_x_(_padding_x) + { + DLIB_CASSERT(num_filters_ > 0); + } + + cont_() : cont_(num_con_outputs(_num_filters)) {} + + long num_filters() const { return num_filters_; } + long nr() const { return _nr; } + long nc() const { return _nc; } + long stride_y() const { return _stride_y; } + long stride_x() const { return _stride_x; } + long padding_y() const { return padding_y_; } + long padding_x() const { return padding_x_; } + + void set_num_filters(long num) + { + DLIB_CASSERT(num > 0); + if (num != num_filters_) + { + DLIB_CASSERT(get_layer_params().size() == 0, + "You can't change the number of filters in cont_ if the parameter tensor has already been allocated."); + num_filters_ = num; + } + } + + double get_learning_rate_multiplier () const { return learning_rate_multiplier; } + double get_weight_decay_multiplier () const { return weight_decay_multiplier; } + void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } + void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } + + double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; } + double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } + void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } + void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } + + inline dpoint map_output_to_input ( + dpoint p + ) const + { + p.x() = (p.x()+padding_x()-nc()/2)/stride_x(); + p.y() = (p.y()+padding_y()-nr()/2)/stride_y(); + return p; + } + + inline dpoint map_input_to_output ( + dpoint p + ) const + { + p.x() = p.x()*stride_x() - padding_x() + nc()/2; + p.y() = p.y()*stride_y() - padding_y() + nr()/2; + return p; + } + + cont_ ( + const cont_& item + ) : + params(item.params), + filters(item.filters), + biases(item.biases), + learning_rate_multiplier(item.learning_rate_multiplier), + weight_decay_multiplier(item.weight_decay_multiplier), + bias_learning_rate_multiplier(item.bias_learning_rate_multiplier), + bias_weight_decay_multiplier(item.bias_weight_decay_multiplier), + num_filters_(item.num_filters_), + padding_y_(item.padding_y_), + padding_x_(item.padding_x_) + { + // this->conv is non-copyable and basically stateless, so we have to write our + // own copy to avoid trying to copy it and getting an error. + } + + cont_& operator= ( + const cont_& item + ) + { + if (this == &item) + return *this; + + // this->conv is non-copyable and basically stateless, so we have to write our + // own copy to avoid trying to copy it and getting an error. + params = item.params; + filters = item.filters; + biases = item.biases; + padding_y_ = item.padding_y_; + padding_x_ = item.padding_x_; + learning_rate_multiplier = item.learning_rate_multiplier; + weight_decay_multiplier = item.weight_decay_multiplier; + bias_learning_rate_multiplier = item.bias_learning_rate_multiplier; + bias_weight_decay_multiplier = item.bias_weight_decay_multiplier; + num_filters_ = item.num_filters_; + return *this; + } + + template + void setup (const SUBNET& sub) + { + long num_inputs = _nr*_nc*sub.get_output().k(); + long num_outputs = num_filters_; + // allocate params for the filters and also for the filter bias values. + params.set_size(num_inputs*num_filters_ + num_filters_); + + dlib::rand rnd(std::rand()); + randomize_parameters(params, num_inputs+num_outputs, rnd); + + filters = alias_tensor(sub.get_output().k(), num_filters_, _nr, _nc); + biases = alias_tensor(1,num_filters_); + + // set the initial bias values to zero + biases(params,filters.size()) = 0; + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + auto filt = filters(params,0); + unsigned int gnr = _stride_y * (sub.get_output().nr() - 1) + filt.nr() - 2 * padding_y_; + unsigned int gnc = _stride_x * (sub.get_output().nc() - 1) + filt.nc() - 2 * padding_x_; + unsigned int gnsamps = sub.get_output().num_samples(); + unsigned int gk = filt.k(); + output.set_size(gnsamps,gk,gnr,gnc); + conv.setup(output,filt,_stride_y,_stride_x,padding_y_,padding_x_); + conv.get_gradient_for_data(false, sub.get_output(),filt,output); + tt::add(1,output,1,biases(params,filters.size())); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) + { + auto filt = filters(params,0); + conv(true, sub.get_gradient_input(),gradient_input, filt); + // no point computing the parameter gradients if they won't be used. + if (learning_rate_multiplier != 0) + { + auto filt = filters(params_grad,0); + conv.get_gradient_for_filters (false, sub.get_output(),gradient_input, filt); + auto b = biases(params_grad, filters.size()); + tt::assign_conv_bias_gradient(b, gradient_input); + } + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const cont_& item, std::ostream& out) + { + serialize("cont_1", out); + serialize(item.params, out); + serialize(item.num_filters_, out); + serialize(_nr, out); + serialize(_nc, out); + serialize(_stride_y, out); + serialize(_stride_x, out); + serialize(item.padding_y_, out); + serialize(item.padding_x_, out); + serialize(item.filters, out); + serialize(item.biases, out); + serialize(item.learning_rate_multiplier, out); + serialize(item.weight_decay_multiplier, out); + serialize(item.bias_learning_rate_multiplier, out); + serialize(item.bias_weight_decay_multiplier, out); + } + + friend void deserialize(cont_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + long nr; + long nc; + int stride_y; + int stride_x; + if (version == "cont_1") + { + deserialize(item.params, in); + deserialize(item.num_filters_, in); + deserialize(nr, in); + deserialize(nc, in); + deserialize(stride_y, in); + deserialize(stride_x, in); + deserialize(item.padding_y_, in); + deserialize(item.padding_x_, in); + deserialize(item.filters, in); + deserialize(item.biases, in); + deserialize(item.learning_rate_multiplier, in); + deserialize(item.weight_decay_multiplier, in); + deserialize(item.bias_learning_rate_multiplier, in); + deserialize(item.bias_weight_decay_multiplier, in); + if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_"); + if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_"); + if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_"); + if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_"); + if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_"); + if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_"); + } + else + { + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_."); + } + } + + + friend std::ostream& operator<<(std::ostream& out, const cont_& item) + { + out << "cont\t (" + << "num_filters="<\n"; + out << mat(item.params); + out << ""; + } + + private: + + resizable_tensor params; + alias_tensor filters, biases; + + tt::tensor_conv conv; + double learning_rate_multiplier; + double weight_decay_multiplier; + double bias_learning_rate_multiplier; + double bias_weight_decay_multiplier; + long num_filters_; + + int padding_y_; + int padding_x_; + + }; + + template < + long num_filters, + long nr, + long nc, + int stride_y, + int stride_x, + typename SUBNET + > + using cont = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + template < + int scale_y, + int scale_x + > + class upsample_ + { + public: + static_assert(scale_y >= 1, "upsampling scale factor can't be less than 1."); + static_assert(scale_x >= 1, "upsampling scale factor can't be less than 1."); + + upsample_() + { + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + output.set_size( + sub.get_output().num_samples(), + sub.get_output().k(), + scale_y*sub.get_output().nr(), + scale_x*sub.get_output().nc()); + tt::resize_bilinear(output, sub.get_output()); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + tt::resize_bilinear_gradient(sub.get_gradient_input(), gradient_input); + } + + inline dpoint map_input_to_output (dpoint p) const + { + p.x() = p.x()*scale_x; + p.y() = p.y()*scale_y; + return p; + } + inline dpoint map_output_to_input (dpoint p) const + { + p.x() = p.x()/scale_x; + p.y() = p.y()/scale_y; + return p; + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const upsample_& , std::ostream& out) + { + serialize("upsample_", out); + serialize(scale_y, out); + serialize(scale_x, out); + } + + friend void deserialize(upsample_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "upsample_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::upsample_."); + + int _scale_y; + int _scale_x; + deserialize(_scale_y, in); + deserialize(_scale_x, in); + if (_scale_y != scale_y || _scale_x != scale_x) + throw serialization_error("Wrong scale found while deserializing dlib::upsample_"); + } + + friend std::ostream& operator<<(std::ostream& out, const upsample_& ) + { + out << "upsample\t (" + << "scale_y="<\n"; + } + + private: + resizable_tensor params; + }; + + template < + int scale, + typename SUBNET + > + using upsample = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + template < + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y = _stride_y!=1? 0 : _nr/2, + int _padding_x = _stride_x!=1? 0 : _nc/2 + > + class max_pool_ + { + static_assert(_nr >= 0, "The number of rows in a filter must be >= 0"); + static_assert(_nc >= 0, "The number of columns in a filter must be >= 0"); + static_assert(_stride_y > 0, "The filter stride must be > 0"); + static_assert(_stride_x > 0, "The filter stride must be > 0"); + static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)), + "The padding must be smaller than the filter size, unless the filters size is 0."); + static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)), + "The padding must be smaller than the filter size, unless the filters size is 0."); + public: + + + max_pool_( + ) : + padding_y_(_padding_y), + padding_x_(_padding_x) + {} + + long nr() const { return _nr; } + long nc() const { return _nc; } + long stride_y() const { return _stride_y; } + long stride_x() const { return _stride_x; } + long padding_y() const { return padding_y_; } + long padding_x() const { return padding_x_; } + + inline dpoint map_input_to_output ( + dpoint p + ) const + { + p.x() = (p.x()+padding_x()-nc()/2)/stride_x(); + p.y() = (p.y()+padding_y()-nr()/2)/stride_y(); + return p; + } + + inline dpoint map_output_to_input ( + dpoint p + ) const + { + p.x() = p.x()*stride_x() - padding_x() + nc()/2; + p.y() = p.y()*stride_y() - padding_y() + nr()/2; + return p; + } + + max_pool_ ( + const max_pool_& item + ) : + padding_y_(item.padding_y_), + padding_x_(item.padding_x_) + { + // this->mp is non-copyable so we have to write our own copy to avoid trying to + // copy it and getting an error. + } + + max_pool_& operator= ( + const max_pool_& item + ) + { + if (this == &item) + return *this; + + padding_y_ = item.padding_y_; + padding_x_ = item.padding_x_; + + // this->mp is non-copyable so we have to write our own copy to avoid trying to + // copy it and getting an error. + return *this; + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(), + _nc!=0?_nc:sub.get_output().nc(), + _stride_y, _stride_x, padding_y_, padding_x_); + + mp(output, sub.get_output()); + } + + template + void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(), + _nc!=0?_nc:sub.get_output().nc(), + _stride_y, _stride_x, padding_y_, padding_x_); + + mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input()); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const max_pool_& item, std::ostream& out) + { + serialize("max_pool_2", out); + serialize(_nr, out); + serialize(_nc, out); + serialize(_stride_y, out); + serialize(_stride_x, out); + serialize(item.padding_y_, out); + serialize(item.padding_x_, out); + } + + friend void deserialize(max_pool_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + long nr; + long nc; + int stride_y; + int stride_x; + if (version == "max_pool_2") + { + deserialize(nr, in); + deserialize(nc, in); + deserialize(stride_y, in); + deserialize(stride_x, in); + deserialize(item.padding_y_, in); + deserialize(item.padding_x_, in); + } + else + { + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_."); + } + + if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::max_pool_"); + if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::max_pool_"); + if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_"); + if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_"); + if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_"); + if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_"); + } + + friend std::ostream& operator<<(std::ostream& out, const max_pool_& item) + { + out << "max_pool (" + << "nr="<<_nr + << ", nc="<<_nc + << ", stride_y="<<_stride_y + << ", stride_x="<<_stride_x + << ", padding_y="<\n"; + } + + + private: + + + tt::pooling mp; + resizable_tensor params; + + int padding_y_; + int padding_x_; + }; + + template < + long nr, + long nc, + int stride_y, + int stride_x, + typename SUBNET + > + using max_pool = add_layer, SUBNET>; + + template < + typename SUBNET + > + using max_pool_everything = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + template < + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y = _stride_y!=1? 0 : _nr/2, + int _padding_x = _stride_x!=1? 0 : _nc/2 + > + class avg_pool_ + { + public: + static_assert(_nr >= 0, "The number of rows in a filter must be >= 0"); + static_assert(_nc >= 0, "The number of columns in a filter must be >= 0"); + static_assert(_stride_y > 0, "The filter stride must be > 0"); + static_assert(_stride_x > 0, "The filter stride must be > 0"); + static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)), + "The padding must be smaller than the filter size, unless the filters size is 0."); + static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)), + "The padding must be smaller than the filter size, unless the filters size is 0."); + + avg_pool_( + ) : + padding_y_(_padding_y), + padding_x_(_padding_x) + {} + + long nr() const { return _nr; } + long nc() const { return _nc; } + long stride_y() const { return _stride_y; } + long stride_x() const { return _stride_x; } + long padding_y() const { return padding_y_; } + long padding_x() const { return padding_x_; } + + inline dpoint map_input_to_output ( + dpoint p + ) const + { + p.x() = (p.x()+padding_x()-nc()/2)/stride_x(); + p.y() = (p.y()+padding_y()-nr()/2)/stride_y(); + return p; + } + + inline dpoint map_output_to_input ( + dpoint p + ) const + { + p.x() = p.x()*stride_x() - padding_x() + nc()/2; + p.y() = p.y()*stride_y() - padding_y() + nr()/2; + return p; + } + + avg_pool_ ( + const avg_pool_& item + ) : + padding_y_(item.padding_y_), + padding_x_(item.padding_x_) + { + // this->ap is non-copyable so we have to write our own copy to avoid trying to + // copy it and getting an error. + } + + avg_pool_& operator= ( + const avg_pool_& item + ) + { + if (this == &item) + return *this; + + padding_y_ = item.padding_y_; + padding_x_ = item.padding_x_; + + // this->ap is non-copyable so we have to write our own copy to avoid trying to + // copy it and getting an error. + return *this; + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(), + _nc!=0?_nc:sub.get_output().nc(), + _stride_y, _stride_x, padding_y_, padding_x_); + + ap(output, sub.get_output()); + } + + template + void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(), + _nc!=0?_nc:sub.get_output().nc(), + _stride_y, _stride_x, padding_y_, padding_x_); + + ap.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input()); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const avg_pool_& item, std::ostream& out) + { + serialize("avg_pool_2", out); + serialize(_nr, out); + serialize(_nc, out); + serialize(_stride_y, out); + serialize(_stride_x, out); + serialize(item.padding_y_, out); + serialize(item.padding_x_, out); + } + + friend void deserialize(avg_pool_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + + long nr; + long nc; + int stride_y; + int stride_x; + if (version == "avg_pool_2") + { + deserialize(nr, in); + deserialize(nc, in); + deserialize(stride_y, in); + deserialize(stride_x, in); + deserialize(item.padding_y_, in); + deserialize(item.padding_x_, in); + } + else + { + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_."); + } + + if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::avg_pool_"); + if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::avg_pool_"); + if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_"); + if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_"); + if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_"); + if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_"); + } + + friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item) + { + out << "avg_pool (" + << "nr="<<_nr + << ", nc="<<_nc + << ", stride_y="<<_stride_y + << ", stride_x="<<_stride_x + << ", padding_y="<\n"; + } + private: + + tt::pooling ap; + resizable_tensor params; + + int padding_y_; + int padding_x_; + }; + + template < + long nr, + long nc, + int stride_y, + int stride_x, + typename SUBNET + > + using avg_pool = add_layer, SUBNET>; + + template < + typename SUBNET + > + using avg_pool_everything = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + enum layer_mode + { + CONV_MODE = 0, + FC_MODE = 1 + }; + + const double DEFAULT_BATCH_NORM_EPS = 0.0001; + + template < + layer_mode mode + > + class bn_ + { + public: + explicit bn_( + unsigned long window_size, + double eps_ = DEFAULT_BATCH_NORM_EPS + ) : + num_updates(0), + running_stats_window_size(window_size), + learning_rate_multiplier(1), + weight_decay_multiplier(0), + bias_learning_rate_multiplier(1), + bias_weight_decay_multiplier(1), + eps(eps_) + { + DLIB_CASSERT(window_size > 0, "The batch normalization running stats window size can't be 0."); + } + + bn_() : bn_(100) {} + + layer_mode get_mode() const { return mode; } + unsigned long get_running_stats_window_size () const { return running_stats_window_size; } + void set_running_stats_window_size (unsigned long new_window_size ) + { + DLIB_CASSERT(new_window_size > 0, "The batch normalization running stats window size can't be 0."); + running_stats_window_size = new_window_size; + } + double get_eps() const { return eps; } + + double get_learning_rate_multiplier () const { return learning_rate_multiplier; } + double get_weight_decay_multiplier () const { return weight_decay_multiplier; } + void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } + void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } + + double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; } + double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } + void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } + void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } + + inline dpoint map_input_to_output (const dpoint& p) const { return p; } + inline dpoint map_output_to_input (const dpoint& p) const { return p; } + + + template + void setup (const SUBNET& sub) + { + if (mode == FC_MODE) + { + gamma = alias_tensor(1, + sub.get_output().k(), + sub.get_output().nr(), + sub.get_output().nc()); + } + else + { + gamma = alias_tensor(1, sub.get_output().k()); + } + beta = gamma; + + params.set_size(gamma.size()+beta.size()); + + gamma(params,0) = 1; + beta(params,gamma.size()) = 0; + + running_means.copy_size(gamma(params,0)); + running_variances.copy_size(gamma(params,0)); + running_means = 0; + running_variances = 1; + num_updates = 0; + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + auto g = gamma(params,0); + auto b = beta(params,gamma.size()); + if (sub.get_output().num_samples() > 1) + { + const double decay = 1.0 - num_updates/(num_updates+1.0); + ++num_updates; + if (num_updates > running_stats_window_size) + num_updates = running_stats_window_size; + + if (mode == FC_MODE) + tt::batch_normalize(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b); + else + tt::batch_normalize_conv(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b); + } + else // we are running in testing mode so we just linearly scale the input tensor. + { + if (mode == FC_MODE) + tt::batch_normalize_inference(eps, output, sub.get_output(), g, b, running_means, running_variances); + else + tt::batch_normalize_conv_inference(eps, output, sub.get_output(), g, b, running_means, running_variances); + } + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) + { + auto g = gamma(params,0); + auto g_grad = gamma(params_grad, 0); + auto b_grad = beta(params_grad, gamma.size()); + if (mode == FC_MODE) + tt::batch_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad ); + else + tt::batch_normalize_conv_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad ); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const bn_& item, std::ostream& out) + { + if (mode == CONV_MODE) + serialize("bn_con2", out); + else // if FC_MODE + serialize("bn_fc2", out); + serialize(item.params, out); + serialize(item.gamma, out); + serialize(item.beta, out); + serialize(item.means, out); + serialize(item.invstds, out); + serialize(item.running_means, out); + serialize(item.running_variances, out); + serialize(item.num_updates, out); + serialize(item.running_stats_window_size, out); + serialize(item.learning_rate_multiplier, out); + serialize(item.weight_decay_multiplier, out); + serialize(item.bias_learning_rate_multiplier, out); + serialize(item.bias_weight_decay_multiplier, out); + serialize(item.eps, out); + } + + friend void deserialize(bn_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (mode == CONV_MODE) + { + if (version != "bn_con2") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_."); + } + else // must be in FC_MODE + { + if (version != "bn_fc2") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_."); + } + + deserialize(item.params, in); + deserialize(item.gamma, in); + deserialize(item.beta, in); + deserialize(item.means, in); + deserialize(item.invstds, in); + deserialize(item.running_means, in); + deserialize(item.running_variances, in); + deserialize(item.num_updates, in); + deserialize(item.running_stats_window_size, in); + deserialize(item.learning_rate_multiplier, in); + deserialize(item.weight_decay_multiplier, in); + deserialize(item.bias_learning_rate_multiplier, in); + deserialize(item.bias_weight_decay_multiplier, in); + deserialize(item.eps, in); + } + + friend std::ostream& operator<<(std::ostream& out, const bn_& item) + { + if (mode == CONV_MODE) + out << "bn_con "; + else + out << "bn_fc "; + out << " eps="<\n"; + + out << mat(item.params); + + if (mode==CONV_MODE) + out << "\n"; + else + out << "\n"; + } + + private: + + friend class affine_; + + resizable_tensor params; + alias_tensor gamma, beta; + resizable_tensor means, running_means; + resizable_tensor invstds, running_variances; + unsigned long num_updates; + unsigned long running_stats_window_size; + double learning_rate_multiplier; + double weight_decay_multiplier; + double bias_learning_rate_multiplier; + double bias_weight_decay_multiplier; + double eps; + }; + + template + using bn_con = add_layer, SUBNET>; + template + using bn_fc = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + class visitor_bn_running_stats_window_size + { + public: + + visitor_bn_running_stats_window_size(unsigned long new_window_size_) : new_window_size(new_window_size_) {} + + template + void set_window_size(T&) const + { + // ignore other layer detail types + } + + template < layer_mode mode > + void set_window_size(bn_& l) const + { + l.set_running_stats_window_size(new_window_size); + } + + template + void operator()(size_t , input_layer_type& ) const + { + // ignore other layers + } + + template + void operator()(size_t , add_layer& l) const + { + set_window_size(l.layer_details()); + } + + private: + + unsigned long new_window_size; + }; + } + + template + void set_all_bn_running_stats_window_sizes ( + net_type& net, + unsigned long new_window_size + ) + { + visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size)); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + enum fc_bias_mode + { + FC_HAS_BIAS = 0, + FC_NO_BIAS = 1 + }; + + struct num_fc_outputs + { + num_fc_outputs(unsigned long n) : num_outputs(n) {} + unsigned long num_outputs; + }; + + template < + unsigned long num_outputs_, + fc_bias_mode bias_mode + > + class fc_ + { + static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0"); + + public: + fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0), + learning_rate_multiplier(1), + weight_decay_multiplier(1), + bias_learning_rate_multiplier(1), + bias_weight_decay_multiplier(0) + {} + + fc_() : fc_(num_fc_outputs(num_outputs_)) {} + + double get_learning_rate_multiplier () const { return learning_rate_multiplier; } + double get_weight_decay_multiplier () const { return weight_decay_multiplier; } + void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } + void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } + + double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; } + double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } + void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } + void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } + + unsigned long get_num_outputs ( + ) const { return num_outputs; } + + void set_num_outputs(long num) + { + DLIB_CASSERT(num > 0); + if (num != (long)num_outputs) + { + DLIB_CASSERT(get_layer_params().size() == 0, + "You can't change the number of filters in fc_ if the parameter tensor has already been allocated."); + num_outputs = num; + } + } + + fc_bias_mode get_bias_mode ( + ) const { return bias_mode; } + + template + void setup (const SUBNET& sub) + { + num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k(); + if (bias_mode == FC_HAS_BIAS) + params.set_size(num_inputs+1, num_outputs); + else + params.set_size(num_inputs, num_outputs); + + dlib::rand rnd(std::rand()); + randomize_parameters(params, num_inputs+num_outputs, rnd); + + weights = alias_tensor(num_inputs, num_outputs); + + if (bias_mode == FC_HAS_BIAS) + { + biases = alias_tensor(1,num_outputs); + // set the initial bias values to zero + biases(params,weights.size()) = 0; + } + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + DLIB_CASSERT((long)num_inputs == sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k(), + "The size of the input tensor to this fc layer doesn't match the size the fc layer was trained with."); + output.set_size(sub.get_output().num_samples(), num_outputs); + + auto w = weights(params, 0); + tt::gemm(0,output, 1,sub.get_output(),false, w,false); + if (bias_mode == FC_HAS_BIAS) + { + auto b = biases(params, weights.size()); + tt::add(1,output,1,b); + } + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) + { + // no point computing the parameter gradients if they won't be used. + if (learning_rate_multiplier != 0) + { + // compute the gradient of the weight parameters. + auto pw = weights(params_grad, 0); + tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false); + + if (bias_mode == FC_HAS_BIAS) + { + // compute the gradient of the bias parameters. + auto pb = biases(params_grad, weights.size()); + tt::assign_bias_gradient(pb, gradient_input); + } + } + + // compute the gradient for the data + auto w = weights(params, 0); + tt::gemm(1,sub.get_gradient_input(), 1,gradient_input,false, w,true); + } + + alias_tensor_instance get_weights() + { + return weights(params, 0); + } + + alias_tensor_const_instance get_weights() const + { + return weights(params, 0); + } + + alias_tensor_instance get_biases() + { + static_assert(bias_mode == FC_HAS_BIAS, "This fc_ layer doesn't have a bias vector " + "to be retrieved, as per template parameter 'bias_mode'."); + return biases(params, weights.size()); + } + + alias_tensor_const_instance get_biases() const + { + static_assert(bias_mode == FC_HAS_BIAS, "This fc_ layer doesn't have a bias vector " + "to be retrieved, as per template parameter 'bias_mode'."); + return biases(params, weights.size()); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const fc_& item, std::ostream& out) + { + serialize("fc_2", out); + serialize(item.num_outputs, out); + serialize(item.num_inputs, out); + serialize(item.params, out); + serialize(item.weights, out); + serialize(item.biases, out); + serialize((int)bias_mode, out); + serialize(item.learning_rate_multiplier, out); + serialize(item.weight_decay_multiplier, out); + serialize(item.bias_learning_rate_multiplier, out); + serialize(item.bias_weight_decay_multiplier, out); + } + + friend void deserialize(fc_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "fc_2") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_."); + + deserialize(item.num_outputs, in); + deserialize(item.num_inputs, in); + deserialize(item.params, in); + deserialize(item.weights, in); + deserialize(item.biases, in); + int bmode = 0; + deserialize(bmode, in); + if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_"); + deserialize(item.learning_rate_multiplier, in); + deserialize(item.weight_decay_multiplier, in); + deserialize(item.bias_learning_rate_multiplier, in); + deserialize(item.bias_weight_decay_multiplier, in); + } + + friend std::ostream& operator<<(std::ostream& out, const fc_& item) + { + if (bias_mode == FC_HAS_BIAS) + { + out << "fc\t (" + << "num_outputs="<\n"; + out << mat(item.params); + out << "\n"; + } + else + { + out << "\n"; + out << mat(item.params); + out << "\n"; + } + } + + private: + + unsigned long num_outputs; + unsigned long num_inputs; + resizable_tensor params; + alias_tensor weights, biases; + double learning_rate_multiplier; + double weight_decay_multiplier; + double bias_learning_rate_multiplier; + double bias_weight_decay_multiplier; + }; + + template < + unsigned long num_outputs, + typename SUBNET + > + using fc = add_layer, SUBNET>; + + template < + unsigned long num_outputs, + typename SUBNET + > + using fc_no_bias = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + class dropout_ + { + public: + explicit dropout_( + float drop_rate_ = 0.5 + ) : + drop_rate(drop_rate_), + rnd(std::rand()) + { + DLIB_CASSERT(0 <= drop_rate && drop_rate <= 1); + } + + // We have to add a copy constructor and assignment operator because the rnd object + // is non-copyable. + dropout_( + const dropout_& item + ) : drop_rate(item.drop_rate), mask(item.mask), rnd(std::rand()) + {} + + dropout_& operator= ( + const dropout_& item + ) + { + if (this == &item) + return *this; + + drop_rate = item.drop_rate; + mask = item.mask; + return *this; + } + + float get_drop_rate ( + ) const { return drop_rate; } + + template + void setup (const SUBNET& /*sub*/) + { + } + + void forward_inplace(const tensor& input, tensor& output) + { + // create a random mask and use it to filter the data + mask.copy_size(input); + rnd.fill_uniform(mask); + tt::threshold(mask, drop_rate); + tt::multiply(false, output, input, mask); + } + + void backward_inplace( + const tensor& gradient_input, + tensor& data_grad, + tensor& /*params_grad*/ + ) + { + if (is_same_object(gradient_input, data_grad)) + tt::multiply(false, data_grad, mask, gradient_input); + else + tt::multiply(true, data_grad, mask, gradient_input); + } + + inline dpoint map_input_to_output (const dpoint& p) const { return p; } + inline dpoint map_output_to_input (const dpoint& p) const { return p; } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const dropout_& item, std::ostream& out) + { + serialize("dropout_", out); + serialize(item.drop_rate, out); + serialize(item.mask, out); + } + + friend void deserialize(dropout_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "dropout_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::dropout_."); + deserialize(item.drop_rate, in); + deserialize(item.mask, in); + } + + void clean( + ) + { + mask.clear(); + } + + friend std::ostream& operator<<(std::ostream& out, const dropout_& item) + { + out << "dropout\t (" + << "drop_rate="<\n"; + } + + private: + float drop_rate; + resizable_tensor mask; + + tt::tensor_rand rnd; + resizable_tensor params; // unused + }; + + + template + using dropout = add_layer; + +// ---------------------------------------------------------------------------------------- + + class multiply_ + { + public: + explicit multiply_( + float val_ = 0.5 + ) : + val(val_) + { + } + + multiply_ ( + const dropout_& item + ) : val(1-item.get_drop_rate()) {} + + float get_multiply_value ( + ) const { return val; } + + template + void setup (const SUBNET& /*sub*/) + { + } + + void forward_inplace(const tensor& input, tensor& output) + { + tt::affine_transform(output, input, val); + } + + inline dpoint map_input_to_output (const dpoint& p) const { return p; } + inline dpoint map_output_to_input (const dpoint& p) const { return p; } + + void backward_inplace( + const tensor& gradient_input, + tensor& data_grad, + tensor& /*params_grad*/ + ) + { + if (is_same_object(gradient_input, data_grad)) + tt::affine_transform(data_grad, gradient_input, val); + else + tt::affine_transform(data_grad, data_grad, gradient_input, 1, val); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const multiply_& item, std::ostream& out) + { + serialize("multiply_", out); + serialize(item.val, out); + } + + friend void deserialize(multiply_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version == "dropout_") + { + // Since we can build a multiply_ from a dropout_ we check if that's what + // is in the stream and if so then just convert it right here. + unserialize sin(version, in); + dropout_ temp; + deserialize(temp, sin); + item = temp; + return; + } + + if (version != "multiply_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::multiply_."); + deserialize(item.val, in); + } + + friend std::ostream& operator<<(std::ostream& out, const multiply_& item) + { + out << "multiply (" + << "val="<\n"; + } + private: + float val; + resizable_tensor params; // unused + }; + + template + using multiply = add_layer; + +// ---------------------------------------------------------------------------------------- + + class affine_ + { + public: + affine_( + ) : mode(FC_MODE) + { + } + + affine_( + layer_mode mode_ + ) : mode(mode_) + { + } + + template < + layer_mode bnmode + > + affine_( + const bn_& item + ) + { + gamma = item.gamma; + beta = item.beta; + mode = bnmode; + + params.copy_size(item.params); + + auto g = gamma(params,0); + auto b = beta(params,gamma.size()); + + resizable_tensor temp(item.params); + auto sg = gamma(temp,0); + auto sb = beta(temp,gamma.size()); + + g = pointwise_multiply(mat(sg), 1.0f/sqrt(mat(item.running_variances)+item.get_eps())); + b = mat(sb) - pointwise_multiply(mat(g), mat(item.running_means)); + } + + layer_mode get_mode() const { return mode; } + + inline dpoint map_input_to_output (const dpoint& p) const { return p; } + inline dpoint map_output_to_input (const dpoint& p) const { return p; } + + template + void setup (const SUBNET& sub) + { + if (mode == FC_MODE) + { + gamma = alias_tensor(1, + sub.get_output().k(), + sub.get_output().nr(), + sub.get_output().nc()); + } + else + { + gamma = alias_tensor(1, sub.get_output().k()); + } + beta = gamma; + + params.set_size(gamma.size()+beta.size()); + + gamma(params,0) = 1; + beta(params,gamma.size()) = 0; + } + + void forward_inplace(const tensor& input, tensor& output) + { + auto g = gamma(params,0); + auto b = beta(params,gamma.size()); + if (mode == FC_MODE) + tt::affine_transform(output, input, g, b); + else + tt::affine_transform_conv(output, input, g, b); + } + + void backward_inplace( + const tensor& gradient_input, + tensor& data_grad, + tensor& /*params_grad*/ + ) + { + auto g = gamma(params,0); + auto b = beta(params,gamma.size()); + + // We are computing the gradient of dot(gradient_input, computed_output*g + b) + if (mode == FC_MODE) + { + if (is_same_object(gradient_input, data_grad)) + tt::multiply(false, data_grad, gradient_input, g); + else + tt::multiply(true, data_grad, gradient_input, g); + } + else + { + if (is_same_object(gradient_input, data_grad)) + tt::multiply_conv(false, data_grad, gradient_input, g); + else + tt::multiply_conv(true, data_grad, gradient_input, g); + } + } + + const tensor& get_layer_params() const { return empty_params; } + tensor& get_layer_params() { return empty_params; } + + friend void serialize(const affine_& item, std::ostream& out) + { + serialize("affine_", out); + serialize(item.params, out); + serialize(item.gamma, out); + serialize(item.beta, out); + serialize((int)item.mode, out); + } + + friend void deserialize(affine_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version == "bn_con2") + { + // Since we can build an affine_ from a bn_ we check if that's what is in + // the stream and if so then just convert it right here. + unserialize sin(version, in); + bn_ temp; + deserialize(temp, sin); + item = temp; + return; + } + else if (version == "bn_fc2") + { + // Since we can build an affine_ from a bn_ we check if that's what is in + // the stream and if so then just convert it right here. + unserialize sin(version, in); + bn_ temp; + deserialize(temp, sin); + item = temp; + return; + } + + if (version != "affine_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_."); + deserialize(item.params, in); + deserialize(item.gamma, in); + deserialize(item.beta, in); + int mode; + deserialize(mode, in); + item.mode = (layer_mode)mode; + } + + friend std::ostream& operator<<(std::ostream& out, const affine_& ) + { + out << "affine"; + return out; + } + + friend void to_xml(const affine_& item, std::ostream& out) + { + if (item.mode==CONV_MODE) + out << "\n"; + else + out << "\n"; + + out << mat(item.params); + + if (item.mode==CONV_MODE) + out << "\n"; + else + out << "\n"; + } + + private: + resizable_tensor params, empty_params; + alias_tensor gamma, beta; + layer_mode mode; + }; + + template + using affine = add_layer; + +// ---------------------------------------------------------------------------------------- + + template < + template class tag + > + class add_prev_ + { + public: + const static unsigned long id = tag_id::id; + + add_prev_() + { + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + auto&& t1 = sub.get_output(); + auto&& t2 = layer(sub).get_output(); + output.set_size(std::max(t1.num_samples(),t2.num_samples()), + std::max(t1.k(),t2.k()), + std::max(t1.nr(),t2.nr()), + std::max(t1.nc(),t2.nc())); + tt::add(output, t1, t2); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + // The gradient just flows backwards to the two layers that forward() added + // together. + tt::add(sub.get_gradient_input(), sub.get_gradient_input(), gradient_input); + tt::add(layer(sub).get_gradient_input(), layer(sub).get_gradient_input(), gradient_input); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + inline dpoint map_input_to_output (const dpoint& p) const { return p; } + inline dpoint map_output_to_input (const dpoint& p) const { return p; } + + friend void serialize(const add_prev_& , std::ostream& out) + { + serialize("add_prev_", out); + } + + friend void deserialize(add_prev_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "add_prev_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::add_prev_."); + } + + friend std::ostream& operator<<(std::ostream& out, const add_prev_& item) + { + out << "add_prev"<\n"; + } + + private: + resizable_tensor params; + }; + + template < + template class tag, + typename SUBNET + > + using add_prev = add_layer, SUBNET>; + + template using add_prev1 = add_prev; + template using add_prev2 = add_prev; + template using add_prev3 = add_prev; + template using add_prev4 = add_prev; + template using add_prev5 = add_prev; + template using add_prev6 = add_prev; + template using add_prev7 = add_prev; + template using add_prev8 = add_prev; + template using add_prev9 = add_prev; + template using add_prev10 = add_prev; + + using add_prev1_ = add_prev_; + using add_prev2_ = add_prev_; + using add_prev3_ = add_prev_; + using add_prev4_ = add_prev_; + using add_prev5_ = add_prev_; + using add_prev6_ = add_prev_; + using add_prev7_ = add_prev_; + using add_prev8_ = add_prev_; + using add_prev9_ = add_prev_; + using add_prev10_ = add_prev_; + +// ---------------------------------------------------------------------------------------- + + template < + template class tag + > + class mult_prev_ + { + public: + const static unsigned long id = tag_id::id; + + mult_prev_() + { + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + auto&& t1 = sub.get_output(); + auto&& t2 = layer(sub).get_output(); + output.set_size(std::max(t1.num_samples(),t2.num_samples()), + std::max(t1.k(),t2.k()), + std::max(t1.nr(),t2.nr()), + std::max(t1.nc(),t2.nc())); + tt::multiply_zero_padded(false, output, t1, t2); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + auto&& t1 = sub.get_output(); + auto&& t2 = layer(sub).get_output(); + // The gradient just flows backwards to the two layers that forward() + // multiplied together. + tt::multiply_zero_padded(true, sub.get_gradient_input(), t2, gradient_input); + tt::multiply_zero_padded(true, layer(sub).get_gradient_input(), t1, gradient_input); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const mult_prev_& , std::ostream& out) + { + serialize("mult_prev_", out); + } + + friend void deserialize(mult_prev_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "mult_prev_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::mult_prev_."); + } + + friend std::ostream& operator<<(std::ostream& out, const mult_prev_& item) + { + out << "mult_prev"<\n"; + } + + private: + resizable_tensor params; + }; + + template < + template class tag, + typename SUBNET + > + using mult_prev = add_layer, SUBNET>; + + template using mult_prev1 = mult_prev; + template using mult_prev2 = mult_prev; + template using mult_prev3 = mult_prev; + template using mult_prev4 = mult_prev; + template using mult_prev5 = mult_prev; + template using mult_prev6 = mult_prev; + template using mult_prev7 = mult_prev; + template using mult_prev8 = mult_prev; + template using mult_prev9 = mult_prev; + template using mult_prev10 = mult_prev; + + using mult_prev1_ = mult_prev_; + using mult_prev2_ = mult_prev_; + using mult_prev3_ = mult_prev_; + using mult_prev4_ = mult_prev_; + using mult_prev5_ = mult_prev_; + using mult_prev6_ = mult_prev_; + using mult_prev7_ = mult_prev_; + using mult_prev8_ = mult_prev_; + using mult_prev9_ = mult_prev_; + using mult_prev10_ = mult_prev_; + +// ---------------------------------------------------------------------------------------- + + template < + template class tag + > + class scale_ + { + public: + const static unsigned long id = tag_id::id; + + scale_() + { + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + auto&& scales = sub.get_output(); + auto&& src = layer(sub).get_output(); + DLIB_CASSERT(scales.num_samples() == src.num_samples() && + scales.k() == src.k() && + scales.nr() == 1 && + scales.nc() == 1, + "scales.k(): " << scales.k() << + "\nsrc.k(): " << src.k() + ); + + output.copy_size(src); + tt::scale_channels(false, output, src, scales); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + auto&& scales = sub.get_output(); + auto&& src = layer(sub).get_output(); + // The gradient just flows backwards to the two layers that forward() + // read from. + tt::scale_channels(true, layer(sub).get_gradient_input(), gradient_input, scales); + + if (reshape_src.num_samples() != src.num_samples()) + { + reshape_scales = alias_tensor(src.num_samples()*src.k()); + reshape_src = alias_tensor(src.num_samples()*src.k(),src.nr()*src.nc()); + } + + auto&& scales_grad = sub.get_gradient_input(); + auto sgrad = reshape_scales(scales_grad); + tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input)); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const scale_& item, std::ostream& out) + { + serialize("scale_", out); + serialize(item.reshape_scales, out); + serialize(item.reshape_src, out); + } + + friend void deserialize(scale_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "scale_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::scale_."); + deserialize(item.reshape_scales, in); + deserialize(item.reshape_src, in); + } + + friend std::ostream& operator<<(std::ostream& out, const scale_& item) + { + out << "scale"<\n"; + } + + private: + alias_tensor reshape_scales; + alias_tensor reshape_src; + resizable_tensor params; + }; + + template < + template class tag, + typename SUBNET + > + using scale = add_layer, SUBNET>; + + template using scale1 = scale; + template using scale2 = scale; + template using scale3 = scale; + template using scale4 = scale; + template using scale5 = scale; + template using scale6 = scale; + template using scale7 = scale; + template using scale8 = scale; + template using scale9 = scale; + template using scale10 = scale; + + using scale1_ = scale_; + using scale2_ = scale_; + using scale3_ = scale_; + using scale4_ = scale_; + using scale5_ = scale_; + using scale6_ = scale_; + using scale7_ = scale_; + using scale8_ = scale_; + using scale9_ = scale_; + using scale10_ = scale_; + +// ---------------------------------------------------------------------------------------- + + class relu_ + { + public: + relu_() + { + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + void forward_inplace(const tensor& input, tensor& output) + { + tt::relu(output, input); + } + + void backward_inplace( + const tensor& computed_output, + const tensor& gradient_input, + tensor& data_grad, + tensor& + ) + { + tt::relu_gradient(data_grad, computed_output, gradient_input); + } + + inline dpoint map_input_to_output (const dpoint& p) const { return p; } + inline dpoint map_output_to_input (const dpoint& p) const { return p; } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const relu_& , std::ostream& out) + { + serialize("relu_", out); + } + + friend void deserialize(relu_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "relu_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_."); + } + + friend std::ostream& operator<<(std::ostream& out, const relu_& ) + { + out << "relu"; + return out; + } + + friend void to_xml(const relu_& /*item*/, std::ostream& out) + { + out << "\n"; + } + + private: + resizable_tensor params; + }; + + + template + using relu = add_layer; + +// ---------------------------------------------------------------------------------------- + + class prelu_ + { + public: + explicit prelu_( + float initial_param_value_ = 0.25 + ) : initial_param_value(initial_param_value_) + { + } + + float get_initial_param_value ( + ) const { return initial_param_value; } + + template + void setup (const SUBNET& /*sub*/) + { + params.set_size(1); + params = initial_param_value; + } + + template + void forward( + const SUBNET& sub, + resizable_tensor& data_output + ) + { + data_output.copy_size(sub.get_output()); + tt::prelu(data_output, sub.get_output(), params); + } + + template + void backward( + const tensor& gradient_input, + SUBNET& sub, + tensor& params_grad + ) + { + tt::prelu_gradient(sub.get_gradient_input(), sub.get_output(), + gradient_input, params, params_grad); + } + + inline dpoint map_input_to_output (const dpoint& p) const { return p; } + inline dpoint map_output_to_input (const dpoint& p) const { return p; } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const prelu_& item, std::ostream& out) + { + serialize("prelu_", out); + serialize(item.params, out); + serialize(item.initial_param_value, out); + } + + friend void deserialize(prelu_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "prelu_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::prelu_."); + deserialize(item.params, in); + deserialize(item.initial_param_value, in); + } + + friend std::ostream& operator<<(std::ostream& out, const prelu_& item) + { + out << "prelu\t (" + << "initial_param_value="<\n"; + out << mat(item.params); + out << "\n"; + } + + private: + resizable_tensor params; + float initial_param_value; + }; + + template + using prelu = add_layer; + +// ---------------------------------------------------------------------------------------- + + class sig_ + { + public: + sig_() + { + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + void forward_inplace(const tensor& input, tensor& output) + { + tt::sigmoid(output, input); + } + + void backward_inplace( + const tensor& computed_output, + const tensor& gradient_input, + tensor& data_grad, + tensor& + ) + { + tt::sigmoid_gradient(data_grad, computed_output, gradient_input); + } + + inline dpoint map_input_to_output (const dpoint& p) const { return p; } + inline dpoint map_output_to_input (const dpoint& p) const { return p; } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const sig_& , std::ostream& out) + { + serialize("sig_", out); + } + + friend void deserialize(sig_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "sig_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::sig_."); + } + + friend std::ostream& operator<<(std::ostream& out, const sig_& ) + { + out << "sig"; + return out; + } + + friend void to_xml(const sig_& /*item*/, std::ostream& out) + { + out << "\n"; + } + + + private: + resizable_tensor params; + }; + + + template + using sig = add_layer; + +// ---------------------------------------------------------------------------------------- + + class htan_ + { + public: + htan_() + { + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + inline dpoint map_input_to_output (const dpoint& p) const { return p; } + inline dpoint map_output_to_input (const dpoint& p) const { return p; } + + void forward_inplace(const tensor& input, tensor& output) + { + tt::tanh(output, input); + } + + void backward_inplace( + const tensor& computed_output, + const tensor& gradient_input, + tensor& data_grad, + tensor& + ) + { + tt::tanh_gradient(data_grad, computed_output, gradient_input); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const htan_& , std::ostream& out) + { + serialize("htan_", out); + } + + friend void deserialize(htan_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "htan_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::htan_."); + } + + friend std::ostream& operator<<(std::ostream& out, const htan_& ) + { + out << "htan"; + return out; + } + + friend void to_xml(const htan_& /*item*/, std::ostream& out) + { + out << "\n"; + } + + + private: + resizable_tensor params; + }; + + + template + using htan = add_layer; + +// ---------------------------------------------------------------------------------------- + + class softmax_ + { + public: + softmax_() + { + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + void forward_inplace(const tensor& input, tensor& output) + { + tt::softmax(output, input); + } + + void backward_inplace( + const tensor& computed_output, + const tensor& gradient_input, + tensor& data_grad, + tensor& + ) + { + tt::softmax_gradient(data_grad, computed_output, gradient_input); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const softmax_& , std::ostream& out) + { + serialize("softmax_", out); + } + + friend void deserialize(softmax_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "softmax_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_."); + } + + friend std::ostream& operator<<(std::ostream& out, const softmax_& ) + { + out << "softmax"; + return out; + } + + friend void to_xml(const softmax_& /*item*/, std::ostream& out) + { + out << "\n"; + } + + private: + resizable_tensor params; + }; + + template + using softmax = add_layer; + +// ---------------------------------------------------------------------------------------- + + class softmax_all_ + { + public: + softmax_all_() + { + } + + template + void setup (const SUBNET& /*sub*/) + { + } + + void forward_inplace(const tensor& input, tensor& output) + { + tt::softmax_all(output, input); + } + + void backward_inplace( + const tensor& computed_output, + const tensor& gradient_input, + tensor& data_grad, + tensor& + ) + { + tt::softmax_all_gradient(data_grad, computed_output, gradient_input); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const softmax_all_& , std::ostream& out) + { + serialize("softmax_all_", out); + } + + friend void deserialize(softmax_all_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "softmax_all_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_all_."); + } + + friend std::ostream& operator<<(std::ostream& out, const softmax_all_& ) + { + out << "softmax_all"; + return out; + } + + friend void to_xml(const softmax_all_& /*item*/, std::ostream& out) + { + out << "\n"; + } + + private: + resizable_tensor params; + }; + + template + using softmax_all = add_layer; + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template class TAG_TYPE, template class... TAG_TYPES> + struct concat_helper_impl{ + + constexpr static size_t tag_count() {return 1 + concat_helper_impl::tag_count();} + static void list_tags(std::ostream& out) + { + out << tag_id::id << (tag_count() > 1 ? "," : ""); + concat_helper_impl::list_tags(out); + } + + template + static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k) + { + auto& t = layer(sub).get_output(); + concat_helper_impl::resize_out(out, sub, sum_k + t.k()); + } + template + static void concat(tensor& out, const SUBNET& sub, size_t k_offset) + { + auto& t = layer(sub).get_output(); + tt::copy_tensor(false, out, k_offset, t, 0, t.k()); + k_offset += t.k(); + concat_helper_impl::concat(out, sub, k_offset); + } + template + static void split(const tensor& input, SUBNET& sub, size_t k_offset) + { + auto& t = layer(sub).get_gradient_input(); + tt::copy_tensor(true, t, 0, input, k_offset, t.k()); + k_offset += t.k(); + concat_helper_impl::split(input, sub, k_offset); + } + }; + template class TAG_TYPE> + struct concat_helper_impl{ + constexpr static size_t tag_count() {return 1;} + static void list_tags(std::ostream& out) + { + out << tag_id::id; + } + + template + static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k) + { + auto& t = layer(sub).get_output(); + out.set_size(t.num_samples(), t.k() + sum_k, t.nr(), t.nc()); + } + template + static void concat(tensor& out, const SUBNET& sub, size_t k_offset) + { + auto& t = layer(sub).get_output(); + tt::copy_tensor(false, out, k_offset, t, 0, t.k()); + } + template + static void split(const tensor& input, SUBNET& sub, size_t k_offset) + { + auto& t = layer(sub).get_gradient_input(); + tt::copy_tensor(true, t, 0, input, k_offset, t.k()); + } + }; + } + // concat layer + template< + template class... TAG_TYPES + > + class concat_ + { + static void list_tags(std::ostream& out) { impl::concat_helper_impl::list_tags(out);}; + + public: + constexpr static size_t tag_count() {return impl::concat_helper_impl::tag_count();}; + + template + void setup (const SUBNET&) + { + // do nothing + } + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + // the total depth of result is the sum of depths from all tags + impl::concat_helper_impl::resize_out(output, sub, 0); + + // copy output from each tag into different part result + impl::concat_helper_impl::concat(output, sub, 0); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor&) + { + // Gradient is split into parts for each tag layer + impl::concat_helper_impl::split(gradient_input, sub, 0); + } + + dpoint map_input_to_output(dpoint p) const { return p; } + dpoint map_output_to_input(dpoint p) const { return p; } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const concat_& item, std::ostream& out) + { + serialize("concat_", out); + size_t count = tag_count(); + serialize(count, out); + } + + friend void deserialize(concat_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "concat_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::concat_."); + size_t count_tags; + deserialize(count_tags, in); + if (count_tags != tag_count()) + throw serialization_error("Invalid count of tags "+ std::to_string(count_tags) +", expecting " + + std::to_string(tag_count()) + + " found while deserializing dlib::concat_."); + } + + friend std::ostream& operator<<(std::ostream& out, const concat_& item) + { + out << "concat\t ("; + list_tags(out); + out << ")"; + return out; + } + + friend void to_xml(const concat_& item, std::ostream& out) + { + out << "\n"; + } + + private: + resizable_tensor params; // unused + }; + + + // concat layer definitions + template class TAG1, + template class TAG2, + typename SUBNET> + using concat2 = add_layer, SUBNET>; + + template class TAG1, + template class TAG2, + template class TAG3, + typename SUBNET> + using concat3 = add_layer, SUBNET>; + + template class TAG1, + template class TAG2, + template class TAG3, + template class TAG4, + typename SUBNET> + using concat4 = add_layer, SUBNET>; + + template class TAG1, + template class TAG2, + template class TAG3, + template class TAG4, + template class TAG5, + typename SUBNET> + using concat5 = add_layer, SUBNET>; + + // inception layer will use tags internally. If user will use tags too, some conflicts + // possible to exclude them, here are new tags specially for inceptions + template using itag0 = add_tag_layer< 1000 + 0, SUBNET>; + template using itag1 = add_tag_layer< 1000 + 1, SUBNET>; + template using itag2 = add_tag_layer< 1000 + 2, SUBNET>; + template using itag3 = add_tag_layer< 1000 + 3, SUBNET>; + template using itag4 = add_tag_layer< 1000 + 4, SUBNET>; + template using itag5 = add_tag_layer< 1000 + 5, SUBNET>; + // skip to inception input + template using iskip = add_skip_layer< itag0, SUBNET>; + + // here are some templates to be used for creating inception layer groups + template class B1, + templateclass B2, + typename SUBNET> + using inception2 = concat2>>>>>>; + + template class B1, + templateclass B2, + templateclass B3, + typename SUBNET> + using inception3 = concat3>>>>>>>>>; + + template class B1, + templateclass B2, + templateclass B3, + templateclass B4, + typename SUBNET> + using inception4 = concat4>>>>>>>>>>>>; + + template class B1, + templateclass B2, + templateclass B3, + templateclass B4, + templateclass B5, + typename SUBNET> + using inception5 = concat5>>>>>>>>>>>>>>>; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + const double DEFAULT_L2_NORM_EPS = 1e-5; + + class l2normalize_ + { + public: + explicit l2normalize_( + double eps_ = DEFAULT_L2_NORM_EPS + ) : + eps(eps_) + { + } + + double get_eps() const { return eps; } + + template + void setup (const SUBNET& /*sub*/) + { + } + + void forward_inplace(const tensor& input, tensor& output) + { + tt::inverse_norms(norm, input, eps); + tt::scale_rows(output, input, norm); + } + + void backward_inplace( + const tensor& computed_output, + const tensor& gradient_input, + tensor& data_grad, + tensor& /*params_grad*/ + ) + { + if (is_same_object(gradient_input, data_grad)) + { + tt::dot_prods(temp, gradient_input, computed_output); + tt::scale_rows2(0, data_grad, gradient_input, computed_output, temp, norm); + } + else + { + tt::dot_prods(temp, gradient_input, computed_output); + tt::scale_rows2(1, data_grad, gradient_input, computed_output, temp, norm); + } + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const l2normalize_& item, std::ostream& out) + { + serialize("l2normalize_", out); + serialize(item.eps, out); + } + + friend void deserialize(l2normalize_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "l2normalize_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::l2normalize_."); + deserialize(item.eps, in); + } + + friend std::ostream& operator<<(std::ostream& out, const l2normalize_& item) + { + out << "l2normalize"; + out << " eps="<\n"; + } + private: + double eps; + + resizable_tensor params; // unused + // Here only to avoid reallocation and as a cache between forward/backward + // functions. + resizable_tensor norm; + resizable_tensor temp; + }; + + template + using l2normalize = add_layer; + +// ---------------------------------------------------------------------------------------- + + template < + long _offset, + long _k, + long _nr, + long _nc + > + class extract_ + { + static_assert(_offset >= 0, "The offset must be >= 0."); + static_assert(_k > 0, "The number of channels must be > 0."); + static_assert(_nr > 0, "The number of rows must be > 0."); + static_assert(_nc > 0, "The number of columns must be > 0."); + public: + extract_( + ) + { + } + + template + void setup (const SUBNET& sub) + { + DLIB_CASSERT((long)sub.get_output().size() >= sub.get_output().num_samples()*(_offset+_k*_nr*_nc), + "The tensor we are trying to extract from the input tensor is too big to fit into the input tensor."); + + aout = alias_tensor(sub.get_output().num_samples(), _k*_nr*_nc); + ain = alias_tensor(sub.get_output().num_samples(), sub.get_output().size()/sub.get_output().num_samples()); + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + if (aout.num_samples() != sub.get_output().num_samples()) + { + aout = alias_tensor(sub.get_output().num_samples(), _k*_nr*_nc); + ain = alias_tensor(sub.get_output().num_samples(), sub.get_output().size()/sub.get_output().num_samples()); + } + + output.set_size(sub.get_output().num_samples(), _k, _nr, _nc); + auto out = aout(output,0); + auto in = ain(sub.get_output(),0); + tt::copy_tensor(false, out, 0, in, _offset, _k*_nr*_nc); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + auto out = ain(sub.get_gradient_input(),0); + auto in = aout(gradient_input,0); + tt::copy_tensor(true, out, _offset, in, 0, _k*_nr*_nc); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const extract_& item, std::ostream& out) + { + serialize("extract_", out); + serialize(_offset, out); + serialize(_k, out); + serialize(_nr, out); + serialize(_nc, out); + } + + friend void deserialize(extract_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "extract_") + throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::extract_."); + + long offset; + long k; + long nr; + long nc; + deserialize(offset, in); + deserialize(k, in); + deserialize(nr, in); + deserialize(nc, in); + + if (offset != _offset) throw serialization_error("Wrong offset found while deserializing dlib::extract_"); + if (k != _k) throw serialization_error("Wrong k found while deserializing dlib::extract_"); + if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::extract_"); + if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::extract_"); + } + + friend std::ostream& operator<<(std::ostream& out, const extract_& item) + { + out << "extract\t (" + << "offset="<<_offset + << ", k="<<_k + << ", nr="<<_nr + << ", nc="<<_nc + << ")"; + return out; + } + + friend void to_xml(const extract_& item, std::ostream& out) + { + out << "\n"; + } + private: + alias_tensor aout, ain; + + resizable_tensor params; // unused + }; + + template < + long offset, + long k, + long nr, + long nc, + typename SUBNET + > + using extract = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_LAYERS_H_ + + diff --git a/ml/dlib/dlib/dnn/layers_abstract.h b/ml/dlib/dlib/dnn/layers_abstract.h new file mode 100644 index 000000000..f07025ff8 --- /dev/null +++ b/ml/dlib/dlib/dnn/layers_abstract.h @@ -0,0 +1,2631 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNn_LAYERS_ABSTRACT_H_ +#ifdef DLIB_DNn_LAYERS_ABSTRACT_H_ + +#include "tensor_abstract.h" +#include "core_abstract.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class SUBNET + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a deep neural network. In particular, it is + the simplified interface through which layer objects interact with their + subnetworks. A layer's two important tasks are to (1) take outputs from its + subnetwork and forward propagate them through itself and (2) to backwards + propagate an error gradient through itself and onto its subnetwork. + The idea of a subnetwork is illustrated in the following diagram: + + +---------------------------------------------------------+ + | loss <-- layer1 <-- layer2 <-- ... <-- layern <-- input | + +---------------------------------------------------------+ + ^ ^ + \__ subnetwork for layer1 __/ + + Therefore, by "subnetwork" we mean the part of the network closer to the + input. + + Note that there is no dlib::SUBNET type. It is shown here purely to + document the interface layer objects expect to see when they interact + with a network. + !*/ + + public: + // You aren't allowed to copy subnetworks from inside a layer. + SUBNET(const SUBNET&) = delete; + SUBNET& operator=(const SUBNET&) = delete; + + const tensor& get_output( + ) const; + /*! + ensures + - returns the output of this subnetwork. This is the data that the next + layer in the network will take as input. + - have_same_dimensions(#get_gradient_input(), get_output()) == true + !*/ + + tensor& get_gradient_input( + ); + /*! + ensures + - returns the error gradient for this subnetwork. That is, this is the + error gradient that this network will use to update itself. Therefore, + when performing back propagation, layers that sit on top of this + subnetwork write their back propagated error gradients into + get_gradient_input(). Or to put it another way, during back propagation, + layers take the contents of their get_gradient_input() and back propagate + it through themselves and store the results into their subnetwork's + get_gradient_input(). + !*/ + + const NEXT_SUBNET& subnet( + ) const; + /*! + ensures + - returns the subnetwork of *this network. With respect to the diagram + above, if *this was layer1 then subnet() would return the network that + begins with layer2. + !*/ + + NEXT_SUBNET& subnet( + ); + /*! + ensures + - returns the subnetwork of *this network. With respect to the diagram + above, if *this was layer1 then subnet() would return the network that + begins with layer2. + !*/ + + const layer_details_type& layer_details( + ) const; + /*! + ensures + - returns the layer_details_type instance that defines the behavior of the + layer at the top of this network. I.e. returns the layer details that + defines the behavior of the layer nearest to the network output rather + than the input layer. For computational layers, this is the object + implementing the EXAMPLE_COMPUTATIONAL_LAYER_ interface that defines the + layer's behavior. + !*/ + + unsigned int sample_expansion_factor ( + ) const; + /*! + ensures + - When to_tensor() is invoked on this network's input layer it converts N + input objects into M samples, all stored inside a resizable_tensor. It + is always the case that M is some integer multiple of N. + sample_expansion_factor() returns the value of this multiplier. To be + very specific, it is always true that M==I*N where I is some integer. + This integer I is what is returned by sample_expansion_factor(). + + It should be noted that computational layers likely do not care about the + sample expansion factor. It is only really of concern inside a loss + layer where you need to know its value so that tensor samples can be + matched against truth objects. Moreover, in most cases the sample + expansion factor is 1. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + class EXAMPLE_COMPUTATIONAL_LAYER_ + { + /*! + WHAT THIS OBJECT REPRESENTS + Each computational layer in a deep neural network can be thought of as a + function, f(data,parameters), that takes in a data tensor, some parameters, + and produces an output tensor. You create an entire deep network by + composing these functions. Importantly, you are able to use a wide range + of different functions to accommodate the task you are trying to + accomplish. Therefore, dlib includes a number of common layer types but if + you want to define your own then you simply implement a class with the same + interface as EXAMPLE_COMPUTATIONAL_LAYER_. + + Note that there is no dlib::EXAMPLE_COMPUTATIONAL_LAYER_ type. It is shown + here purely to document the interface that a layer object must implement. + + The central work of defining a layer is implementing the forward and backward + methods. When you do this you have four options: + - Implement the forward() and backward() methods according to the + specification shown below. Do not implement forward_inplace() and + backward_inplace(). + - Implement the forward() and backward() methods according to the + specification shown below, except exclude the computed_output + parameter from backward(). Doing this will allow dlib to make some + layers execute in-place and therefore run a little faster and use + less memory. Do not implement forward_inplace() and + backward_inplace(). + - Implement the forward_inplace() and backward_inplace() methods + according to the specification shown below. Do not implement + forward() and backward(). These in-place methods allow some types of + layers to be implemented more efficiently. + - Implement the forward_inplace() and backward_inplace() methods + according to the specification shown below, except exclude the + computed_output parameter from backward_inplace(). Doing this will + allow dlib to make some layers execute in-place and therefore run a + little faster and use less memory. Do not implement forward() and + backward(). + + + It should also be noted that layers may define additional layer specific + fields and the solvers can use these fields as they see fit. For example, + some layers define get_learning_rate_multiplier() and + get_weight_decay_multiplier() methods. The solvers that come with dlib + look at these methods, if they exist, and adjust the learning rate or + weight decay for that layer according to the multiplier. Therefore, you + can add these methods to your layer types if you want, or even define new + fields and new solvers that use those fields in some way. + !*/ + + public: + + EXAMPLE_COMPUTATIONAL_LAYER_( + ); + /*! + ensures + - Default constructs this object. This function is not required to do + anything in particular but it must exist, that is, it is required that + layer objects be default constructable. + !*/ + + EXAMPLE_COMPUTATIONAL_LAYER_ ( + const EXAMPLE_COMPUTATIONAL_LAYER_& item + ); + /*! + ensures + - EXAMPLE_COMPUTATIONAL_LAYER_ objects are copy constructable + !*/ + + EXAMPLE_COMPUTATIONAL_LAYER_( + const some_other_layer_type& item + ); + /*! + ensures + - Constructs this object from item. This form of constructor is optional + but it allows you to provide a conversion from one layer type to another. + For example, the following code is valid only if my_layer2 can be + constructed from my_layer1: + relu>>>>> my_dnn1; + relu>>>>> my_dnn2(my_dnn1); + This kind of pattern is useful if you want to use one type of layer + during training but a different type of layer during testing since it + allows you to easily convert between related deep neural network types. + + Additionally, if you provide a constructor to build a layer from another + layer type you should also write your layer's deserialize() routine such + that it can read that other layer's serialized data in addition to your + own serialized data. + !*/ + + template + void setup ( + const SUBNET& sub + ); + /*! + requires + - SUBNET implements the SUBNET interface defined at the top of this file. + ensures + - performs any necessary initial memory allocations and/or sets parameters + to their initial values prior to learning. Therefore, calling setup + destroys any previously learned parameters. Also, typically setup() + would look at the dimensions of the outputs of sub and configure the + number of parameters in *this accordingly. + !*/ + + template + void forward( + const SUBNET& sub, + resizable_tensor& data_output + ); + /*! + requires + - SUBNET implements the SUBNET interface defined at the top of this file. + - setup() has been called. + ensures + - Runs the output of the subnetwork through this layer and stores the + results into #data_output. In particular, forward() can use any of the + outputs in sub (e.g. sub.get_output(), sub.subnet().get_output(), etc.) + to compute whatever it wants. + !*/ + + template + void backward( + const tensor& computed_output, // this parameter is optional + const tensor& gradient_input, + SUBNET& sub, + tensor& params_grad + ); + /*! + requires + - SUBNET implements the SUBNET interface defined at the top of this file. + - setup() has been called. + - computed_output is the tensor resulting from calling forward(sub,computed_output). + Moreover, this was the most recent call to forward(). This means that + forward() is allowed to cache intermediate results so they can be used + during the backward computation. + - have_same_dimensions(gradient_input, computed_output) == true + - have_same_dimensions(sub.get_gradient_input(), sub.get_output()) == true + - have_same_dimensions(params_grad, get_layer_params()) == true + ensures + - This function outputs the gradients of this layer with respect to the + input data from sub and also with respect to this layer's parameters. + These gradients are stored into #sub and #params_grad, respectively. To be + precise, the gradients are taken of a function f(sub,get_layer_params()) + which is defined thusly: + - Recalling that computed_output is a function of both sub and get_layer_params(), + since it is the result of calling forward(sub,computed_output): + let f(sub,get_layer_params()) == dot(computed_output, gradient_input) + Then we define the following gradient vectors: + - PARAMETER_GRADIENT == gradient of f(sub,get_layer_params()) with + respect to get_layer_params(). + - for all valid I: + - DATA_GRADIENT_I == gradient of f(sub,get_layer_params()) with + respect to layer(sub).get_output() (recall that forward() can + draw inputs from the immediate sub layer, sub.subnet(), or + any earlier layer. So you must consider the gradients with + respect to all inputs drawn from sub) + Finally, backward() outputs these gradients by performing: + - params_grad = PARAMETER_GRADIENT + - for all valid I: + - layer(sub).get_gradient_input() += DATA_GRADIENT_I + !*/ + + void forward_inplace( + const tensor& data_input, + tensor& data_output + ); + /*! + requires + - have_same_dimensions(data_input,data_output) == true + - setup() has been called. + ensures + - Runs the data_input tensor through this layer and stores the output into + #data_output. + - This function supports in-place operation, i.e. having + is_same_object(data_input, data_output)==true + !*/ + + void backward_inplace( + const tensor& computed_output, // this parameter is optional + const tensor& gradient_input, + tensor& data_grad, + tensor& params_grad + ); + /*! + requires + - setup() has been called. + - computed_output is the tensor resulting from the most recent call to + forward_inplace(). This means that forward_inplace() is allowed to cache + intermediate results so they can be used during the backward computation. + - have_same_dimensions(gradient_input, data_grad) == true + - have_same_dimensions(gradient_input, computed_output) == true + - have_same_dimensions(params_grad, get_layer_params()) == true + ensures + - This function supports in-place operation, i.e. having + is_same_object(gradient_input, data_grad)==true + - This function outputs the gradients of this layer with respect to the + input data from a sublayer and also with respect to this layer's parameters. + These gradients are stored into #data_grad and #params_grad, respectively. To be + precise, the gradients are taken of a function f(data_input,get_layer_params()) + which is defined thusly: + - Recalling that computed_output is a function of both the input to + forward_inplace() and get_layer_params(), since it is the result of + calling forward_inplace(data_input,computed_output): + let f(data_input,get_layer_params()) == dot(computed_output, gradient_input) + Then we define the following gradient vectors: + - PARAMETER_GRADIENT == gradient of f(data_input,get_layer_params()) with + respect to get_layer_params(). + - DATA_GRADIENT == gradient of f(data_input,get_layer_params()) with respect + to data_input. + Finally, backward_inplace() outputs these gradients by performing: + - params_grad = PARAMETER_GRADIENT + - if (is_same_object(gradient_input, data_grad)) then + - data_grad = DATA_GRADIENT + - else + - data_grad += DATA_GRADIENT + !*/ + + const tensor& get_layer_params( + ) const; + /*! + ensures + - returns the parameters that define the behavior of forward(). + !*/ + + tensor& get_layer_params( + ); + /*! + ensures + - returns the parameters that define the behavior of forward(). + !*/ + + + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + /*! + These two functions are optional. If provided, they should map between + (column,row) coordinates in input and output tensors of forward(). Providing + these functions allows you to use global utility functions like + input_tensor_to_output_tensor(). + !*/ + + void clean ( + ); + /*! + Implementing this function is optional. If you don't need it then you don't + have to provide a clean(). But if you do provide it then it must behave as + follows: + + ensures + - calling clean() Causes this object to forget about everything except its + parameters. This is useful if your layer caches information between + forward and backward passes and you want to clean out that cache + information before saving the network to disk. + !*/ + + }; + + std::ostream& operator<<(std::ostream& out, const EXAMPLE_COMPUTATIONAL_LAYER_& item); + /*! + print a string describing this layer. + !*/ + + void to_xml(const EXAMPLE_COMPUTATIONAL_LAYER_& item, std::ostream& out); + /*! + This function is optional, but required if you want to print your networks with + net_to_xml(). Therefore, to_xml() prints a layer as XML. + !*/ + + void serialize(const EXAMPLE_COMPUTATIONAL_LAYER_& item, std::ostream& out); + void deserialize(EXAMPLE_COMPUTATIONAL_LAYER_& item, std::istream& in); + /*! + provides serialization support + !*/ + + // For each layer you define, always define an add_layer template so that layers can be + // easily composed. Moreover, the convention is that the layer class ends with an _ + // while the add_layer template has the same name but without the trailing _. + template + using EXAMPLE_COMPUTATIONAL_LAYER = add_layer; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + enum fc_bias_mode + { + FC_HAS_BIAS = 0, + FC_NO_BIAS = 1 + }; + + struct num_fc_outputs + { + num_fc_outputs(unsigned long n) : num_outputs(n) {} + unsigned long num_outputs; + }; + + template < + unsigned long num_outputs, + fc_bias_mode bias_mode + > + class fc_ + { + /*! + REQUIREMENTS ON num_outputs + num_outputs > 0 + + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a fully connected layer that + takes an input tensor and multiplies it by a weight matrix and outputs the + results. + + The dimensions of the tensors output by this layer are as follows (letting + IN be the input tensor and OUT the output tensor): + - OUT.num_samples() == IN.num_samples() + - OUT.k() == get_num_outputs() + - OUT.nr() == 1 + - OUT.nc() == 1 + !*/ + + public: + + fc_( + ); + /*! + ensures + - #get_num_outputs() == num_outputs + - #get_bias_mode() == bias_mode + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 1 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 0 + !*/ + + fc_( + num_fc_outputs o + ); + /*! + ensures + - #get_num_outputs() == o.num_outputs + - #get_bias_mode() == bias_mode + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 1 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 0 + !*/ + + unsigned long get_num_outputs ( + ) const; + /*! + ensures + - This layer outputs column vectors that contain get_num_outputs() + elements. That is, the output tensor T from forward() will be such that: + - T.num_samples() == however many samples were given to forward(). + - T.k() == get_num_outputs() + - The rest of the dimensions of T will be 1. + !*/ + + void set_num_outputs( + long num + ); + /*! + requires + - num > 0 + - get_layer_params().size() == 0 || get_num_outputs() == num + (i.e. You can't change the number of outputs in fc_ if the parameter + tensor has already been allocated.) + ensures + - #get_num_outputs() == num + !*/ + + fc_bias_mode get_bias_mode ( + ) const; + /*! + ensures + - returns the bias mode which determines if this layer includes bias terms. + That is, if the bias mode is FC_HAS_BIAS then a different constant scalar + is added to each of the outputs of this layer. + !*/ + + double get_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its parameters be + multiplied by get_learning_rate_multiplier(). + !*/ + + double get_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its parameters be + multiplied by get_weight_decay_multiplier(). + !*/ + + void set_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_learning_rate_multiplier() == val + !*/ + + void set_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_weight_decay_multiplier() == val + !*/ + + double get_bias_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its bias parameters be + multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). + !*/ + + double get_bias_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its bias parameters be + multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). + !*/ + + void set_bias_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_learning_rate_multiplier() == val + !*/ + + void set_bias_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_weight_decay_multiplier() == val + !*/ + + alias_tensor_const_instance get_weights( + ) const; + /*! + ensures + - returns an alias of get_layer_params(), containing the weights matrix of + the fully connected layer. + - #get_weights().num_samples() is the number of elements in input sample, + i.e. sublayer's output's k * nc * nr. + - #get_bias().k() == #get_num_outputs() + - if get_bias_mode() == FC_HAS_BIAS: + - #get_layer_params().size() == (#get_weights().size() + #get_biases().size()) + - else: + - #get_layer_params().size() == #get_weights().size() + !*/ + + alias_tensor_instance get_weights( + ); + /*! + ensures + - returns an alias of get_layer_params(), containing the weights matrix of + the fully connected layer. + - #get_weights().num_samples() is the number of elements in input sample, + i.e. sublayer's output's k * nc * nr. + - #get_bias().k() == #get_num_outputs() + - if get_bias_mode() == FC_HAS_BIAS: + - #get_layer_params().size() == (#get_weights().size() + #get_biases().size()) + - else: + - #get_layer_params().size() == #get_weights().size() + !*/ + + alias_tensor_const_instance get_biases( + ) const; + /*! + requires + - #get_bias_mode() == FC_HAS_BIAS + ensures + - returns an alias of get_layer_params(), containing the bias vector of + the fully connected layer. + - #get_bias().num_samples() == 1 + - #get_bias().k() == #get_num_outputs() + - #get_layer_params().size() == (#get_weights().size() + #get_biases().size()) + !*/ + + alias_tensor_instance get_biases( + ); + /*! + requires + - #get_bias_mode() == FC_HAS_BIAS + ensures + - returns an alias of get_layer_params(), containing the bias vector of + the fully connected layer. + - #get_bias().num_samples() == 1 + - #get_bias().k() == #get_num_outputs() + - #get_layer_params().size() == (#get_weights().size() + #get_biases().size()) + !*/ + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + + }; + + template < + unsigned long num_outputs, + typename SUBNET + > + using fc = add_layer, SUBNET>; + + template < + unsigned long num_outputs, + typename SUBNET + > + using fc_no_bias = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + struct num_con_outputs + { + num_con_outputs(unsigned long n) : num_outputs(n) {} + unsigned long num_outputs; + }; + + template < + long _num_filters, + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y = _stride_y!=1? 0 : _nr/2, + int _padding_x = _stride_x!=1? 0 : _nc/2 + > + class con_ + { + /*! + REQUIREMENTS ON TEMPLATE ARGUMENTS + - _num_filters > 0 + - _nr >= 0 + - _nc >= 0 + - _stride_y > 0 + - _stride_x > 0 + - _padding_y >= 0 + - _padding_x >= 0 + - Also, we require that: + - if (_nr == 0) then + - _padding_y == 0 + - else + - _padding_y < _nr + - if (_nc == 0) then + - _padding_x == 0 + - else + - _padding_x < _nc + + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a convolution layer that takes an + input tensor (nominally representing an image) and convolves it with a set + of filters and then outputs the results. + + The dimensions of the tensors output by this layer are as follows (letting + IN be the input tensor and OUT the output tensor): + - OUT.num_samples() == IN.num_samples() + - OUT.k() == num_filters() + - OUT.nr() == 1+(IN.nr() + 2*padding_y() - nr())/stride_y() + - OUT.nc() == 1+(IN.nc() + 2*padding_x() - nc())/stride_x() + + Note also that setting _nr or _nc to 0 has a special meaning of "set the + filter size equal to the input image size". Specifically, it means: + - if (_nr == 0) then + - nr() == IN.nr() + - OUT.nr() == 1 + - if (_nc == 0) then + - nc() == IN.nc() + - OUT.nc() == 1 + !*/ + + public: + con_( + ); + /*! + ensures + - #num_filters() == _num_filters + - #nr() == _nr + - #nc() == _nc + - #stride_y() == _stride_y + - #stride_x() == _stride_x + - #padding_y() == _padding_y + - #padding_x() == _padding_x + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 1 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 0 + !*/ + + con_( + num_con_outputs o + ); + /*! + ensures + - #num_filters() == o.num_outputs + - #nr() == _nr + - #nc() == _nc + - #stride_y() == _stride_y + - #stride_x() == _stride_x + - #padding_y() == _padding_y + - #padding_x() == _padding_x + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 1 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 0 + !*/ + + long num_filters( + ) const; + /*! + ensures + - returns the number of filters contained in this layer. The k dimension + of the output tensors produced by this layer will be equal to the number + of filters. + !*/ + + void set_num_filters( + long num + ); + /*! + requires + - num > 0 + - get_layer_params().size() == 0 || num_filters() == num + (i.e. You can't change the number of filters in con_ if the parameter + tensor has already been allocated.) + ensures + - #num_filters() == num + !*/ + + long nr( + ) const; + /*! + ensures + - returns the number of rows in the filters in this layer. Note that if + nr()==0 then it means the size of the filter is not yet assigned, but + once setup() is called nr() will be set to the input tensor's nr(). + Therefore, nr()==0 has the special interpretation of "be the same size as + the input tensor". + !*/ + + long nc( + ) const; + /*! + ensures + - returns the number of columns in the filters in this layer. Note that if + nc()==0 then it means the size of the filter is not yet assigned, but + once setup() is called nc() will be set to the input tensor's nc(). + Therefore, nc()==0 has the special interpretation of "be the same size as + the input tensor". + !*/ + + long stride_y( + ) const; + /*! + ensures + - returns the vertical stride used when convolving the filters over an + image. That is, each filter will be moved stride_y() pixels down at a + time when it moves over the image. + !*/ + + long stride_x( + ) const; + /*! + ensures + - returns the horizontal stride used when convolving the filters over an + image. That is, each filter will be moved stride_x() pixels right at a + time when it moves over the image. + !*/ + + long padding_y( + ) const; + /*! + ensures + - returns the number of pixels of zero padding added to the top and bottom + sides of the image. + !*/ + + long padding_x( + ) const; + /*! + ensures + - returns the number of pixels of zero padding added to the left and right + sides of the image. + !*/ + + double get_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its parameters be + multiplied by get_learning_rate_multiplier(). + !*/ + + double get_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its parameters be + multiplied by get_weight_decay_multiplier(). + !*/ + + void set_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_learning_rate_multiplier() == val + !*/ + + void set_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_weight_decay_multiplier() == val + !*/ + + double get_bias_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its bias parameters be + multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). + !*/ + + double get_bias_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its bias parameters be + multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). + !*/ + + void set_bias_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_learning_rate_multiplier() == val + !*/ + + void set_bias_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_weight_decay_multiplier() == val + !*/ + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + + }; + + template < + long num_filters, + long nr, + long nc, + int stride_y, + int stride_x, + typename SUBNET + > + using con = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + template < + long _num_filters, + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y = _stride_y!=1? 0 : _nr/2, + int _padding_x = _stride_x!=1? 0 : _nc/2 + > + class cont_ + { + /*! + REQUIREMENTS ON TEMPLATE ARGUMENTS + All of them must be > 0. + Also, we require that: + - 0 <= _padding_y && _padding_y < _nr + - 0 <= _padding_x && _padding_x < _nc + + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a transposed convolution layer + that takes an input tensor and transpose convolves (sometimes called + "deconvolution") it with a set of filters and then outputs the results. + + This is essentially a convolutional layer that allows fractional strides. + Therefore, you can make output tensors that are larger than the input + tensors using this layer type. + + + The dimensions of the tensors output by this layer are as follows (letting + IN be the input tensor and OUT the output tensor): + - OUT.num_samples() == IN.num_samples() + - OUT.k() == num_filters() + - OUT.nr() == stride_y()*(IN.nr()-1) + nr() - 2*padding_y() + - OUT.nc() == stride_x()*(IN.nc()-1) + nc() - 2*padding_x() + !*/ + + public: + cont_( + ); + /*! + ensures + - #num_filters() == _num_filters + - #nr() == _nr + - #nc() == _nc + - #stride_y() == _stride_y + - #stride_x() == _stride_x + - #padding_y() == _padding_y + - #padding_x() == _padding_x + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 1 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 0 + !*/ + + cont_( + num_con_outputs o + ); + /*! + ensures + - #num_filters() == o.num_outputs + - #nr() == _nr + - #nc() == _nc + - #stride_y() == _stride_y + - #stride_x() == _stride_x + - #padding_y() == _padding_y + - #padding_x() == _padding_x + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 1 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 0 + !*/ + + long num_filters( + ) const; + /*! + ensures + - returns the number of filters contained in this layer. The k dimension + of the output tensors produced by this layer will be equal to the number + of filters. + !*/ + + void set_num_filters( + long num + ); + /*! + requires + - num > 0 + - get_layer_params().size() == 0 || num_filters() == num + (i.e. You can't change the number of filters in cont_ if the parameter + tensor has already been allocated.) + ensures + - #num_filters() == num + !*/ + + long nr( + ) const; + /*! + ensures + - returns the number of rows in the filters in this layer. + !*/ + + long nc( + ) const; + /*! + ensures + - returns the number of columns in the filters in this layer. + !*/ + + long stride_y( + ) const; + /*! + ensures + - returns the vertical stride used when convolving the filters over an + image. That is, each filter will be moved 1.0/stride_y() pixels down at + a time when it moves over the image. + !*/ + + long stride_x( + ) const; + /*! + ensures + - returns the horizontal stride used when convolving the filters over an + image. That is, each filter will be moved 1.0/stride_x() pixels right at + a time when it moves over the image. + !*/ + + long padding_y( + ) const; + /*! + ensures + - returns the number of pixels of zero padding added to the top and bottom + sides of the image. + !*/ + + long padding_x( + ) const; + /*! + ensures + - returns the number of pixels of zero padding added to the left and right + sides of the image. + !*/ + + double get_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its parameters be + multiplied by get_learning_rate_multiplier(). + !*/ + + double get_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its parameters be + multiplied by get_weight_decay_multiplier(). + !*/ + + void set_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_learning_rate_multiplier() == val + !*/ + + void set_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_weight_decay_multiplier() == val + !*/ + + double get_bias_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its bias parameters be + multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). + !*/ + + double get_bias_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its bias parameters be + multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). + !*/ + + void set_bias_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_learning_rate_multiplier() == val + !*/ + + void set_bias_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_weight_decay_multiplier() == val + !*/ + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + + }; + + template < + long num_filters, + long nr, + long nc, + int stride_y, + int stride_x, + typename SUBNET + > + using cont = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + template < + int scale_y, + int scale_x + > + class upsample_ + { + /*! + REQUIREMENTS ON TEMPLATE ARGUMENTS + All of them must be >= 1. + + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it allows you to upsample a layer using + bilinear interpolation. To be very specific, it upsamples each of the + channels in an input tensor. Therefore, if IN is the input tensor to this + layer and OUT the output tensor, then we will have: + - OUT.num_samples() == IN.num_samples() + - OUT.k() == IN.k() + - OUT.nr() == IN.nr()*scale_y + - OUT.nc() == IN.nr()*scale_x + - for all valid i,k: image_plane(OUT,i,k) is a copy of + image_plane(IN,i,k) that has been bilinearly interpolated to fit into + the shape of image_plane(OUT,i,k). + !*/ + public: + + upsample_( + ); + /*! + ensures + - This object has no state, so the constructor does nothing, aside from + providing default constructability. + !*/ + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + template < + int scale, + typename SUBNET + > + using upsample = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + class dropout_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a dropout layer. Therefore, it + passes its inputs through the stochastic function f(x) which outputs either + 0 or x. The probability of 0 being output is given by the drop_rate + argument to this object's constructor. + + Note that, after you finish training a network with dropout, it is a good + idea to replace each dropout_ layer with a multiply_ layer because the + multiply_ layer is faster and deterministic. + !*/ + + public: + + explicit dropout_( + float drop_rate = 0.5 + ); + /*! + requires + - 0 <= drop_rate <= 1 + ensures + - #get_drop_rate() == drop_rate + !*/ + + float get_drop_rate ( + ) const; + /*! + ensures + - returns the probability that an individual input value to this layer will + be replaced with 0. + !*/ + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + template + using dropout = add_layer; + +// ---------------------------------------------------------------------------------------- + + class multiply_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a basic layer that just + multiplies its input tensor with a constant value and returns the result. + It therefore has no learnable parameters. + !*/ + + public: + explicit multiply_( + float val = 0.5 + ); + /*! + ensures + - #get_multiply_value() == val + !*/ + + multiply_ ( + const dropout_& item + ); + /*! + ensures + - #get_multiply_value() == 1-item.get_drop_rate() + (i.e. We construct the multiply_ layer so that it is essentially a + deterministic version of the given dropout_ layer) + !*/ + + float get_multiply_value ( + ) const; + /*! + ensures + - this layer simply multiplies its input tensor by get_multiply_value() and + produces the result as output. + !*/ + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + template + using multiply = add_layer; + +// ---------------------------------------------------------------------------------------- + + enum layer_mode + { + CONV_MODE = 0, // convolutional mode + FC_MODE = 1 // fully connected mode + }; + + const double DEFAULT_BATCH_NORM_EPS = 0.0001; + + template < + layer_mode mode + > + class bn_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a batch normalization layer that + implements the method described in the paper: + Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift by Sergey Ioffe and Christian Szegedy + + In particular, this layer produces output tensors with the same + dimensionality as the input tensors, except that the mean and variances of + the elements have been standardized to 0 and 1 respectively. + + It should also be noted that when tensors with a num_samples() dimension of + 1 are passed to this layer it doesn't perform batch normalization. + Instead, it runs in "inference mode" where the learned linear normalizing + transformation is used to transform the tensor. + + Finally, after you finish training a batch normalized network, it is a good + idea to replace each bn_ layer with an affine_ layer because the affine_ + layer is faster and will never surprise you by performing batch + normalization on tensors that have a num_samples() dimension > 1. This allows + you to run large mini-batches of samples through your final network without + batch normalization executing at all. + !*/ + + public: + bn_( + ); + /*! + ensures + - #get_mode() == mode + - #get_running_stats_window_size() == 100 + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 0 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 1 + - #get_eps() == tt::DEFAULT_BATCH_NORM_EPS + !*/ + + explicit bn_( + unsigned long window_size, + double eps = tt::DEFAULT_BATCH_NORM_EPS + ); + /*! + requires + - eps > 0 + - window_size > 0 + ensures + - #get_mode() == mode + - #get_running_stats_window_size() == window_size + - #get_learning_rate_multiplier() == 1 + - #get_weight_decay_multiplier() == 0 + - #get_bias_learning_rate_multiplier() == 1 + - #get_bias_weight_decay_multiplier() == 1 + - #get_eps() == eps + !*/ + + layer_mode get_mode( + ) const; + /*! + ensures + - returns the mode of this layer, either CONV_MODE or FC_MODE. + If the mode is FC_MODE then the normalization is applied across the + samples in a tensor (i.e. k()*nr()*nc() different things will be + normalized). Otherwise, normalization is applied across everything + except for the k() dimension, resulting in there being only k() + normalization equations that are applied spatially over the tensor. + + Therefore, if you are putting batch normalization after a fully connected + layer you should use FC_MODE. Otherwise, if you are putting batch + normalization after a convolutional layer you should use CONV_MODE. + !*/ + + double get_eps( + ) const; + /*! + ensures + - When doing batch normalization, we are dividing by the standard + deviation. This epsilon value returned by this function is added to the + variance to prevent the division from dividing by zero. + !*/ + + unsigned long get_running_stats_window_size ( + ) const; + /*! + ensures + - Just as recommended in the batch normalization paper, this object keeps a + running average of the mean and standard deviations of the features. + These averages are used during "inference mode" so you can run a single + object through a batch normalized network. They are also what is used to + initialize an affine_ layer that is constructed from a bn_ layer. This + function returns the effective number of recent samples used to compute + the running average. + !*/ + + void set_running_stats_window_size ( + unsigned long new_window_size + ); + /*! + requires + - new_window_size > 0 + ensures + - #get_running_stats_window_size() == new_window_size + !*/ + + double get_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its parameters be + multiplied by get_learning_rate_multiplier(). + !*/ + + double get_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its parameters be + multiplied by get_weight_decay_multiplier(). + !*/ + + void set_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_learning_rate_multiplier() == val + !*/ + + void set_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_weight_decay_multiplier() == val + !*/ + + double get_bias_learning_rate_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the learning rate used to optimize its bias parameters be + multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). + !*/ + + double get_bias_weight_decay_multiplier( + ) const; + /*! + ensures + - returns a multiplier number. The interpretation is that this object is + requesting that the weight decay used to optimize its bias parameters be + multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). + !*/ + + void set_bias_learning_rate_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_learning_rate_multiplier() == val + !*/ + + void set_bias_weight_decay_multiplier( + double val + ); + /*! + requires + - val >= 0 + ensures + - #get_bias_weight_decay_multiplier() == val + !*/ + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + template + using bn_con = add_layer, SUBNET>; + template + using bn_fc = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + template + void set_all_bn_running_stats_window_sizes ( + const net_type& net, + unsigned long new_window_size + ); + /*! + requires + - new_window_size > 0 + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + ensures + - Sets the get_running_stats_window_size() field of all bn_ layers in net to + new_window_size. + !*/ + +// ---------------------------------------------------------------------------------------- + + class affine_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it applies a simple pointwise linear + transformation to an input tensor. You can think of it as having two + parameter tensors, A and B. If the input tensor is called INPUT then the + output of this layer is: + A*INPUT+B + where all operations are performed element wise and each sample in the + INPUT tensor is processed separately. + + Moreover, this object has two modes that effect the dimensionalities of A + and B and how they are applied to compute A*INPUT+B. If + get_mode()==FC_MODE then A and B each have the same dimensionality as the + input tensor, except their num_samples() dimensions are 1. If + get_mode()==CONV_MODE then A and B have all their dimensions set to 1 + except for k(), which is equal to INPUT.k(). + + In either case, the computation of A*INPUT+B is performed pointwise over all + the elements of INPUT using either: + OUTPUT(n,k,r,c) == A(1,k,r,c)*INPUT(n,k,r,c)+B(1,k,r,c) + or + OUTPUT(n,k,r,c) == A(1,k,1,1)*INPUT(n,k,r,c)+B(1,k,1,1) + as appropriate. + + + Finally, note that the parameters of this layer are not learnable and + therefore not modified during network updates. Instead, the layer will + perform the identity transformation unless it is initialized with a bn_ + layer, in which case it will perform whatever transformation the bn_ layer + has learned. + !*/ + + public: + + affine_( + ); + /*! + ensures + - #get_mode() == FC_MODE + !*/ + + affine_( + layer_mode mode + ); + /*! + ensures + - #get_mode() == mode + !*/ + + template < + layer_mode mode + > + affine_( + const bn_& layer + ); + /*! + ensures + - Constructs affine_ so that it performs the same transformation as the + supplied batch normalization layer. You would want to do this after you + finish training a network with bn_ layers because the affine_ layer will + execute faster. + - #get_mode() == layer.get_mode() + !*/ + + layer_mode get_mode( + ) const; + /*! + ensures + - returns the mode of this layer, either CONV_MODE or FC_MODE. + !*/ + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the + EXAMPLE_COMPUTATIONAL_LAYER_ interface. Also note that get_layer_params() + always returns an empty tensor since there are no learnable parameters in this + object. + !*/ + + }; + + template + using affine = add_layer; + +// ---------------------------------------------------------------------------------------- + + template < + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y = _stride_y!=1? 0 : _nr/2, + int _padding_x = _stride_x!=1? 0 : _nc/2 + > + class max_pool_ + { + /*! + REQUIREMENTS ON TEMPLATE ARGUMENTS + - _nr >= 0 + - _nc >= 0 + - _stride_y > 0 + - _stride_x > 0 + - _padding_y >= 0 + - _padding_x >= 0 + - if (_nr != 0) then + - _padding_y < _nr + - else + - _padding_y == 0 + - if (_nc != 0) then + - _padding_x < _nr + - else + - _padding_x == 0 + + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a max pooling layer that takes an + input tensor and downsamples it. It does this by sliding a window over the + images in an input tensor and outputting, for each channel, the maximum + element within the window. + + If _nr == 0 then it means the filter size covers all the rows in the input + tensor, similarly for the _nc parameter. To be precise, if we call the + input tensor IN and the output tensor OUT, then OUT is defined as follows: + - let FILT_NR == (nr()==0) ? IN.nr() : nr() + - let FILT_NC == (nc()==0) ? IN.nc() : nc() + - OUT.num_samples() == IN.num_samples() + - OUT.k() == IN.k() + - OUT.nr() == 1+(IN.nr() + 2*padding_y() - FILT_NR)/stride_y() + - OUT.nc() == 1+(IN.nc() + 2*padding_x() - FILT_NC)/stride_x() + - for all valid s, k, r, and c: + - image_plane(OUT,s,k)(r,c) == max(subm_clipped(image_plane(IN,s,k), + centered_rect(x*stride_x() + FILT_NC/2 - padding_x(), + y*stride_y() + FILT_NR/2 - padding_y(), + FILT_NC, + FILT_NR))) + !*/ + + public: + + max_pool_ ( + ); + /*! + ensures + - #nr() == _nr + - #nc() == _nc + - #stride_y() == _stride_y + - #stride_x() == _stride_x + - #padding_y() == _padding_y + - #padding_x() == _padding_x + !*/ + + long nr( + ) const; + /*! + ensures + - returns the number of rows in the pooling window or 0 if the window size + is "the entire input tensor". + !*/ + + long nc( + ) const; + /*! + ensures + - returns the number of rows in the pooling window or 0 if the window size + is "the entire input tensor". + !*/ + + long stride_y( + ) const; + /*! + ensures + - returns the vertical stride used when scanning the max pooling window + over an image. That is, each window will be moved stride_y() pixels down + at a time when it moves over the image. + !*/ + + long stride_x( + ) const; + /*! + ensures + - returns the horizontal stride used when scanning the max pooling window + over an image. That is, each window will be moved stride_x() pixels down + at a time when it moves over the image. + !*/ + + long padding_y( + ) const; + /*! + ensures + - returns the number of pixels of zero padding added to the top and bottom + sides of the image. + !*/ + + long padding_x( + ) const; + /*! + ensures + - returns the number of pixels of zero padding added to the left and right + sides of the image. + !*/ + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ + interface. Note that this layer doesn't have any parameters, so the tensor + returned by get_layer_params() is always empty. + !*/ + }; + + template < + long nr, + long nc, + int stride_y, + int stride_x, + typename SUBNET + > + using max_pool = add_layer, SUBNET>; + + template < + typename SUBNET + > + using max_pool_everything = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + template < + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y = _stride_y!=1? 0 : _nr/2, + int _padding_x = _stride_x!=1? 0 : _nc/2 + > + class avg_pool_ + { + /*! + REQUIREMENTS ON TEMPLATE ARGUMENTS + - _nr >= 0 + - _nc >= 0 + - _stride_y > 0 + - _stride_x > 0 + - _padding_y >= 0 + - _padding_x >= 0 + - if (_nr != 0) then + - _padding_y < _nr + - else + - _padding_y == 0 + - if (_nc != 0) then + - _padding_x < _nr + - else + - _padding_x == 0 + + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines an average pooling layer that + takes an input tensor and downsamples it. It does this by sliding a window + over the images in an input tensor and outputting, for each channel, the + average element within the window. + + If _nr == 0 then it means the filter size covers all the rows in the input + tensor, similarly for the _nc parameter. To be precise, if we call the + input tensor IN and the output tensor OUT, then OUT is defined as follows: + - let FILT_NR == (nr()==0) ? IN.nr() : nr() + - let FILT_NC == (nc()==0) ? IN.nc() : nc() + - OUT.num_samples() == IN.num_samples() + - OUT.k() == IN.k() + - OUT.nr() == 1+(IN.nr() + 2*padding_y() - FILT_NR)/stride_y() + - OUT.nc() == 1+(IN.nc() + 2*padding_x() - FILT_NC)/stride_x() + - for all valid s, k, r, and c: + - image_plane(OUT,s,k)(r,c) == mean(subm_clipped(image_plane(IN,s,k), + centered_rect(x*stride_x() + FILT_NC/2 - padding_x(), + y*stride_y() + FILT_NR/2 - padding_y(), + FILT_NC, + FILT_NR))) + !*/ + + public: + + avg_pool_ ( + ); + /*! + ensures + - #nr() == _nr + - #nc() == _nc + - #stride_y() == _stride_y + - #stride_x() == _stride_x + - #padding_y() == _padding_y + - #padding_x() == _padding_x + !*/ + + long nr( + ) const; + /*! + ensures + - returns the number of rows in the pooling window or 0 if the window size + is "the entire input tensor". + !*/ + + long nc( + ) const; + /*! + ensures + - returns the number of rows in the pooling window or 0 if the window size + is "the entire input tensor". + !*/ + + long stride_y( + ) const; + /*! + ensures + - returns the vertical stride used when scanning the pooling window + over an image. That is, each window will be moved stride_y() pixels down + at a time when it moves over the image. + !*/ + + long stride_x( + ) const; + /*! + ensures + - returns the horizontal stride used when scanning the pooling window + over an image. That is, each window will be moved stride_x() pixels down + at a time when it moves over the image. + !*/ + + long padding_y( + ) const; + /*! + ensures + - returns the number of pixels of zero padding added to the top and bottom + sides of the image. + !*/ + + long padding_x( + ) const; + /*! + ensures + - returns the number of pixels of zero padding added to the left and right + sides of the image. + !*/ + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ + interface. Note that this layer doesn't have any parameters, so the tensor + returned by get_layer_params() is always empty. + !*/ + + }; + + template < + long nr, + long nc, + int stride_y, + int stride_x, + typename SUBNET + > + using avg_pool = add_layer, SUBNET>; + + template < + typename SUBNET + > + using avg_pool_everything = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + class relu_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a rectified linear layer. + Therefore, it passes its inputs through the function + f(x)=max(x,0) + where f() is applied pointwise across the input tensor. + !*/ + + public: + + relu_( + ); + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ + interface. Note that this layer doesn't have any parameters, so the tensor + returned by get_layer_params() is always empty. + !*/ + }; + + template + using relu = add_layer; + +// ---------------------------------------------------------------------------------------- + + class prelu_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a parametric rectified linear + layer. Therefore, it passes its inputs through the function + f(x) = x>0 ? x : p*x + where f() is applied pointwise across the input tensor and p is a scalar + parameter learned by this layer. + + + This is the layer type introduced in the paper: + He, Kaiming, et al. "Delving deep into rectifiers: Surpassing + human-level performance on imagenet classification." Proceedings of the + IEEE International Conference on Computer Vision. 2015. + !*/ + + public: + + explicit prelu_( + float initial_param_value = 0.25 + ); + /*! + ensures + - The p parameter will be initialized with initial_param_value. + - #get_initial_param_value() == initial_param_value. + !*/ + + float get_initial_param_value ( + ) const; + /*! + ensures + - returns the initial value of the prelu parameter. + !*/ + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + template + using prelu = add_layer; + +// ---------------------------------------------------------------------------------------- + + class sig_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a sigmoid layer. Therefore, it + passes its inputs through the function + f(x)=1/(1+exp(-x)) + where f() is applied pointwise across the input tensor. + !*/ + + public: + + sig_( + ); + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ + interface. Note that this layer doesn't have any parameters, so the tensor + returned by get_layer_params() is always empty. + !*/ + }; + + template + using sig = add_layer; + +// ---------------------------------------------------------------------------------------- + + class htan_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a hyperbolic tangent layer. + Therefore, it passes its inputs through the function + f(x)=std::tanh(x) + where f() is applied pointwise across the input tensor. + !*/ + + public: + + htan_( + ); + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ + interface. Note that this layer doesn't have any parameters, so the tensor + returned by get_layer_params() is always empty. + !*/ + }; + + template + using htan = add_layer; + +// ---------------------------------------------------------------------------------------- + + class softmax_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a softmax layer. To be precise, + we define the softmax function s(x) as: + s(x) == exp(x)/sum(exp(x)) + where x is a vector. Then this layer treats its input tensor as a + collection of multi-channel images and applies s() to each spatial location + in each image. In each application, the tensor::k() channel elements at + each position are input to s() and then replaced by the outputs of s(). + + This means that, for example, if you collapsed each output image to a 1 + channel image by adding the channels then you would end up with images + where each pixel value was 1. This is because the sum of the outputs of + s() will always be equal to 1. + !*/ + + public: + + softmax_( + ); + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ + interface. Note that this layer doesn't have any parameters, so the tensor + returned by get_layer_params() is always empty. + !*/ + }; + + template + using softmax = add_layer; + +// ---------------------------------------------------------------------------------------- + + class softmax_all_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, it defines a softmax layer. To be precise, + we define the softmax function s(x) as: + s(x) == exp(x)/sum(exp(x)) + where x is a vector. Then this layer treats its input tensor as a + collection of tensor::num_samples() vectors and applies s() to each vector + in the tensor. Therefore, there are logically tensor::num_samples() + invocations of s(). + !*/ + + public: + + softmax_all_( + ); + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ + interface. Note that this layer doesn't have any parameters, so the tensor + returned by get_layer_params() is always empty. + !*/ + }; + + template + using softmax_all = add_layer; + +// ---------------------------------------------------------------------------------------- + + template < + template class tag + > + class add_prev_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. This layer simply adds the output of two previous layers. + In particular, it adds the tensor from its immediate predecessor layer, + sub.get_output(), with the tensor from a deeper layer, + layer(sub).get_output(). + + Therefore, you supply a tag via add_prev_'s template argument that tells it + what layer to add to the output of the previous layer. The result of this + addition is output by add_prev_. Finally, the addition happens pointwise + according to 4D tensor arithmetic. If the dimensions don't match then + missing elements are presumed to be equal to 0. Moreover, each dimension + of the output tensor is equal to the maximum dimension of either of the + inputs. That is, if the tensors A and B are being added to produce C then: + - C.num_samples() == max(A.num_samples(), B.num_samples()) + - C.k() == max(A.k(), B.k()) + - C.nr() == max(A.nr(), B.nr()) + - C.nc() == max(A.nc(), B.nc()) + !*/ + + public: + add_prev_( + ); + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + + template < + template class tag, + typename SUBNET + > + using add_prev = add_layer, SUBNET>; + + // Here we add some convenient aliases for using add_prev_ with the tag layers. + template using add_prev1 = add_prev; + template using add_prev2 = add_prev; + template using add_prev3 = add_prev; + template using add_prev4 = add_prev; + template using add_prev5 = add_prev; + template using add_prev6 = add_prev; + template using add_prev7 = add_prev; + template using add_prev8 = add_prev; + template using add_prev9 = add_prev; + template using add_prev10 = add_prev; + using add_prev1_ = add_prev_; + using add_prev2_ = add_prev_; + using add_prev3_ = add_prev_; + using add_prev4_ = add_prev_; + using add_prev5_ = add_prev_; + using add_prev6_ = add_prev_; + using add_prev7_ = add_prev_; + using add_prev8_ = add_prev_; + using add_prev9_ = add_prev_; + using add_prev10_ = add_prev_; + +// ---------------------------------------------------------------------------------------- + + template < + template class tag + > + class mult_prev_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. This layer simply multiplies the output of two previous + layers. In particular, it multiplies the tensor from its immediate + predecessor layer, sub.get_output(), with the tensor from a deeper layer, + layer(sub).get_output(). + + Therefore, you supply a tag via mult_prev_'s template argument that tells + it what layer to multiply with the output of the previous layer. The + result of this multiplication is output by mult_prev_. Finally, the + multiplication happens pointwise according to 4D tensor arithmetic. If the + dimensions don't match then missing elements are presumed to be equal to 0. + Moreover, each dimension of the output tensor is equal to the maximum + dimension of either of the inputs. That is, if the tensors A and B are + being multiplied to produce C then: + - C.num_samples() == max(A.num_samples(), B.num_samples()) + - C.k() == max(A.k(), B.k()) + - C.nr() == max(A.nr(), B.nr()) + - C.nc() == max(A.nc(), B.nc()) + !*/ + + public: + mult_prev_( + ); + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + + template < + template class tag, + typename SUBNET + > + using mult_prev = add_layer, SUBNET>; + + // Here we add some convenient aliases for using mult_prev_ with the tag layers. + template using mult_prev1 = mult_prev; + template using mult_prev2 = mult_prev; + template using mult_prev3 = mult_prev; + template using mult_prev4 = mult_prev; + template using mult_prev5 = mult_prev; + template using mult_prev6 = mult_prev; + template using mult_prev7 = mult_prev; + template using mult_prev8 = mult_prev; + template using mult_prev9 = mult_prev; + template using mult_prev10 = mult_prev; + using mult_prev1_ = mult_prev_; + using mult_prev2_ = mult_prev_; + using mult_prev3_ = mult_prev_; + using mult_prev4_ = mult_prev_; + using mult_prev5_ = mult_prev_; + using mult_prev6_ = mult_prev_; + using mult_prev7_ = mult_prev_; + using mult_prev8_ = mult_prev_; + using mult_prev9_ = mult_prev_; + using mult_prev10_ = mult_prev_; + +// ---------------------------------------------------------------------------------------- + + template < + template class tag + > + class scale_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. This layer scales the output channels of the tagged layer + by multiplying it with the output of the previous layer. To be specific: + - Let INPUT == layer(sub).get_output() + - Let SCALES == sub.get_output() + - This layer takes INPUT and SCALES as input. + - The output of this layer has the same dimensions as INPUT. + - This layer requires: + - SCALES.num_samples() == INPUT.num_samples() + - SCALES.k() == INPUT.k() + - SCALES.nr() == 1 + - SCALES.nc() == 1 + - The output tensor is produced by pointwise multiplying SCALES with + INPUT at each spatial location. Therefore, if OUT is the output of + this layer then we would have: + OUT(n,k,r,c) == INPUT(n,k,r,c)*SCALES(n,k) + !*/ + + public: + scale_( + ); + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + + template < + template class tag, + typename SUBNET + > + using scale = add_layer, SUBNET>; + + // Here we add some convenient aliases for using scale_ with the tag layers. + template using scale1 = scale; + template using scale2 = scale; + template using scale3 = scale; + template using scale4 = scale; + template using scale5 = scale; + template using scale6 = scale; + template using scale7 = scale; + template using scale8 = scale; + template using scale9 = scale; + template using scale10 = scale; + using scale1_ = scale_; + using scale2_ = scale_; + using scale3_ = scale_; + using scale4_ = scale_; + using scale5_ = scale_; + using scale6_ = scale_; + using scale7_ = scale_; + using scale8_ = scale_; + using scale9_ = scale_; + using scale10_ = scale_; + +// ---------------------------------------------------------------------------------------- + + template< + template class... TAG_TYPES + > + class concat_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. This layer simply concatenates the output of tagged layers. + Importantly, each input layer must have the same dimensions (i.e. + num_samples, nr, and nc) except for the k channel, which may vary. This is + because the concatenation happens along the k dimension. That is, the + output of this network is a tensor, OUT, that is the concatenation of the + tensors: + for each (tag in TAG_TYPES) + layer(subnet).get_output() + Therefore, out.num_samples(), out.nr(), and out.nc() match the dimensions + of the input tensors while OUT.k() is the sum of the input layer's k() + dimensions. + !*/ + + public: + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + dpoint map_input_to_output(dpoint p) const; + dpoint map_output_to_input(dpoint p) const; + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + + // concat layer definitions + template class TAG1, + template class TAG2, + typename SUBNET> + using concat2 = add_layer, SUBNET>; + + template class TAG1, + template class TAG2, + template class TAG3, + typename SUBNET> + using concat3 = add_layer, SUBNET>; + + template class TAG1, + template class TAG2, + template class TAG3, + template class TAG4, + typename SUBNET> + using concat4 = add_layer, SUBNET>; + + template class TAG1, + template class TAG2, + template class TAG3, + template class TAG4, + template class TAG5, + typename SUBNET> + using concat5 = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + + /*!A inception layer definitions !*/ + + // Now define inception layer tag types. These layer aliases allow creating + // the networks described in the paper: + // Szegedy, Christian, et al. "Going deeper with convolutions." Proceedings of + // the IEEE Conference on Computer Vision and Pattern Recognition. 2015. + // See the dnn_inception_ex.cpp example for a complete example of their use. Note also + // that we use tag ID numbers >= 1000 to avoid conflict with user's tag layers. + template using itag0 = add_tag_layer< 1000 + 0, SUBNET>; + template using itag1 = add_tag_layer< 1000 + 1, SUBNET>; + template using itag2 = add_tag_layer< 1000 + 2, SUBNET>; + template using itag3 = add_tag_layer< 1000 + 3, SUBNET>; + template using itag4 = add_tag_layer< 1000 + 4, SUBNET>; + template using itag5 = add_tag_layer< 1000 + 5, SUBNET>; + // skip to inception input + template using iskip = add_skip_layer< itag0, SUBNET>; + + // here are some templates to be used for creating inception layer groups + template class B1, + templateclass B2, + typename SUBNET> + using inception2 = concat2>>>>>>; + + template class B1, + templateclass B2, + templateclass B3, + typename SUBNET> + using inception3 = concat3>>>>>>>>>; + + template class B1, + templateclass B2, + templateclass B3, + templateclass B4, + typename SUBNET> + using inception4 = concat4>>>>>>>>>>>>; + + template class B1, + templateclass B2, + templateclass B3, + templateclass B4, + templateclass B5, + typename SUBNET> + using inception5 = concat5>>>>>>>>>>>>>>>; + +// ---------------------------------------------------------------------------------------- + + const double DEFAULT_L2_NORM_EPS = 1e-5; + + class l2normalize_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. It takes tensors as input and L2 normalizes them. In particular, + it has the following properties: + - The output tensors from this layer have the same dimensions as the + input tensors. + - If you think of each input tensor as a set of tensor::num_samples() + vectors, then the output tensor contains the same vectors except they + have been length normalized so that their L2 norms are all 1. I.e. + for each vector v we will have ||v||==1. + !*/ + + public: + + explicit l2normalize_( + double eps = tt::DEFAULT_L2_NORM_EPS + ); + /*! + requires + - eps > 0 + ensures + - #get_eps() == eps + !*/ + + double get_eps( + ) const; + /*! + ensures + - When we normalize a vector we divide it by its L2 norm. However, the + get_eps() value is added to the squared norm prior to division to avoid + ever dividing by zero. + !*/ + + template void setup (const SUBNET& sub); + void forward_inplace(const tensor& input, tensor& output); + void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + long _offset, + long _k, + long _nr, + long _nc + > + class extract_ + { + /*! + REQUIREMENTS ON TEMPLATE ARGUMENTS + - 0 <= _offset + - 0 < _k + - 0 < _nr + - 0 < _nc + + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, the output of this layer is simply a copy of + the input tensor. However, you can configure the extract layer to output + only some subset of the input tensor and also to reshape it. Therefore, + the dimensions of the tensor output by this layer are as follows (letting + IN be the input tensor and OUT the output tensor): + - OUT.num_samples() == IN.num_samples() + - OUT.k() == _k + - OUT.nr() == _nr + - OUT.nc() == _nc + + So the output will always have the same number of samples as the input, but + within each sample (the k,nr,nc part) we will copy only a subset of the + values. Moreover, the _offset parameter controls which part of each sample + we take. To be very precise, we will have: + - let IN_SIZE = IN.k()*IN.nr()*IN.nc() + - let OUT_SIZE = _k*_nr*_nc + - for i in range[0,IN.num_samples()) and j in range[0,OUT_SIZE): + - OUT.host()[i*OUT_SIZE+j] == IN.host()[i*IN_SIZE+_offset+j] + + + Finally, all this means that the input tensor to this layer must have a big + enough size to accommodate taking a _k*_nr*_nc slice from each of its + samples. + !*/ + + public: + + template void setup (const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + const tensor& get_layer_params() const; + tensor& get_layer_params(); + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + template < + long offset, + long k, + long nr, + long nc, + typename SUBNET + > + using extract = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_LAYERS_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/dnn/loss.h b/ml/dlib/dlib/dnn/loss.h new file mode 100644 index 000000000..1b09b85c3 --- /dev/null +++ b/ml/dlib/dlib/dnn/loss.h @@ -0,0 +1,2870 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_LOSS_H_ +#define DLIB_DNn_LOSS_H_ + +#include "loss_abstract.h" +#include "core.h" +#include "../matrix.h" +#include "tensor_tools.h" +#include "../geometry.h" +#include "../image_processing/box_overlap_testing.h" +#include "../image_processing/full_object_detection.h" +#include "../svm/ranking_tools.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class loss_binary_hinge_ + { + public: + + typedef float training_label_type; + typedef float output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + + const tensor& output_tensor = sub.get_output(); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + *iter++ = out_data[i]; + } + } + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + + // The loss we output is the average loss over the mini-batch. + const double scale = 1.0/output_tensor.num_samples(); + double loss = 0; + const float* out_data = output_tensor.host(); + float* g = grad.host_write_only(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + const float y = *truth++; + DLIB_CASSERT(y == +1 || y == -1, "y: " << y); + const float temp = 1-y*out_data[i]; + if (temp > 0) + { + loss += scale*temp; + g[i] = -scale*y; + } + else + { + g[i] = 0; + } + } + return loss; + } + + friend void serialize(const loss_binary_hinge_& , std::ostream& out) + { + serialize("loss_binary_hinge_", out); + } + + friend void deserialize(loss_binary_hinge_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_binary_hinge_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_hinge_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_binary_hinge_& ) + { + out << "loss_binary_hinge"; + return out; + } + + friend void to_xml(const loss_binary_hinge_& /*item*/, std::ostream& out) + { + out << ""; + } + + }; + + template + using loss_binary_hinge = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_binary_log_ + { + public: + + typedef float training_label_type; + typedef float output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + + const tensor& output_tensor = sub.get_output(); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + *iter++ = out_data[i]; + } + } + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + DLIB_CASSERT(grad.nr() == 1 && + grad.nc() == 1 && + grad.k() == 1); + + tt::sigmoid(grad, output_tensor); + + // The loss we output is the average loss over the mini-batch. + const double scale = 1.0/output_tensor.num_samples(); + double loss = 0; + float* g = grad.host(); + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + const float y = *truth++; + DLIB_CASSERT(y == +1 || y == -1, "y: " << y); + float temp; + if (y > 0) + { + temp = log1pexp(-out_data[i]); + loss += scale*temp; + g[i] = scale*(g[i]-1); + } + else + { + temp = -(-out_data[i]-log1pexp(-out_data[i])); + loss += scale*temp; + g[i] = scale*g[i]; + } + } + return loss; + } + + friend void serialize(const loss_binary_log_& , std::ostream& out) + { + serialize("loss_binary_log_", out); + } + + friend void deserialize(loss_binary_log_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_binary_log_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_log_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_binary_log_& ) + { + out << "loss_binary_log"; + return out; + } + + friend void to_xml(const loss_binary_log_& /*item*/, std::ostream& out) + { + out << ""; + } + + }; + + template + T safe_log(T input, T epsilon = 1e-10) + { + // Prevent trying to calculate the logarithm of a very small number (let alone zero) + return std::log(std::max(input, epsilon)); + } + + template + using loss_binary_log = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_multiclass_log_ + { + public: + + typedef unsigned long training_label_type; + typedef unsigned long output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + const tensor& output_tensor = sub.get_output(); + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 ); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + + // Note that output_tensor.k() should match the number of labels. + + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + // The index of the largest output for this sample is the label. + *iter++ = index_of_max(rowm(mat(output_tensor),i)); + } + } + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1); + DLIB_CASSERT(grad.nr() == 1 && + grad.nc() == 1); + + tt::softmax(grad, output_tensor); + + // The loss we output is the average loss over the mini-batch. + const double scale = 1.0/output_tensor.num_samples(); + double loss = 0; + float* g = grad.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + const long y = (long)*truth++; + // The network must produce a number of outputs that is equal to the number + // of labels when using this type of loss. + DLIB_CASSERT(y < output_tensor.k(), "y: " << y << ", output_tensor.k(): " << output_tensor.k()); + for (long k = 0; k < output_tensor.k(); ++k) + { + const unsigned long idx = i*output_tensor.k()+k; + if (k == y) + { + loss += scale*-safe_log(g[idx]); + g[idx] = scale*(g[idx]-1); + } + else + { + g[idx] = scale*g[idx]; + } + } + } + return loss; + } + + friend void serialize(const loss_multiclass_log_& , std::ostream& out) + { + serialize("loss_multiclass_log_", out); + } + + friend void deserialize(loss_multiclass_log_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_multiclass_log_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_& ) + { + out << "loss_multiclass_log"; + return out; + } + + friend void to_xml(const loss_multiclass_log_& /*item*/, std::ostream& out) + { + out << ""; + } + + }; + + template + using loss_multiclass_log = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_multimulticlass_log_ + { + + public: + + loss_multimulticlass_log_ () = default; + + loss_multimulticlass_log_ ( + const std::map>& labels + ) + { + for (auto& l : labels) + { + possible_labels[l.first] = std::make_shared(l.second); + DLIB_CASSERT(l.second.size() >= 2, "Each classifier must have at least two possible labels."); + + for (size_t i = 0; i < l.second.size(); ++i) + { + label_idx_lookup[l.first][l.second[i]] = i; + ++total_num_labels; + } + } + } + + unsigned long number_of_labels() const { return total_num_labels; } + + unsigned long number_of_classifiers() const { return possible_labels.size(); } + + std::map> get_labels ( + ) const + { + std::map> info; + for (auto& i : possible_labels) + { + for (auto& label : *i.second) + info[i.first].emplace_back(label); + } + return info; + } + + class classifier_output + { + + public: + classifier_output() = default; + + size_t num_classes() const { return class_probs.size(); } + + double probability_of_class ( + size_t i + ) const + { + DLIB_CASSERT(i < num_classes()); + return class_probs(i); + } + + const std::string& label( + size_t i + ) const + { + DLIB_CASSERT(i < num_classes()); + return (*_labels)[i]; + } + + operator std::string( + ) const + { + DLIB_CASSERT(num_classes() != 0); + return (*_labels)[index_of_max(class_probs)]; + } + + friend std::ostream& operator<< (std::ostream& out, const classifier_output& item) + { + DLIB_ASSERT(item.num_classes() != 0); + out << static_cast(item); + return out; + } + + private: + + friend class loss_multimulticlass_log_; + + template + classifier_output( + const matrix_exp& class_probs, + const std::shared_ptr>& _labels + ) : + class_probs(class_probs), + _labels(_labels) + { + } + + matrix class_probs; + std::shared_ptr> _labels; + }; + + typedef std::map training_label_type; + typedef std::map output_label_type; + + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter_begin + ) const + { + const tensor& output_tensor = sub.get_output(); + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 ); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + DLIB_CASSERT(number_of_labels() != 0, "You must give the loss_multimulticlass_log_'s constructor label data before you can use it!"); + DLIB_CASSERT(output_tensor.k() == (long)number_of_labels(), "The output tensor must have " << number_of_labels() << " channels."); + + + long k_offset = 0; + for (auto& l : possible_labels) + { + auto iter = iter_begin; + const std::string& classifier_name = l.first; + const auto& labels = (*l.second); + scratch.set_size(output_tensor.num_samples(), labels.size()); + tt::copy_tensor(false, scratch, 0, output_tensor, k_offset, labels.size()); + + tt::softmax(scratch, scratch); + + for (long i = 0; i < scratch.num_samples(); ++i) + (*iter++)[classifier_name] = classifier_output(rowm(mat(scratch),i), l.second); + + k_offset += labels.size(); + } + } + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth_begin, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1); + DLIB_CASSERT(grad.nr() == 1 && + grad.nc() == 1); + DLIB_CASSERT(number_of_labels() != 0, "You must give the loss_multimulticlass_log_'s constructor label data before you can use it!"); + DLIB_CASSERT(output_tensor.k() == (long)number_of_labels(), "The output tensor must have " << number_of_labels() << " channels."); + + // The loss we output is the average loss over the mini-batch. + const double scale = 1.0/output_tensor.num_samples(); + double loss = 0; + long k_offset = 0; + for (auto& l : label_idx_lookup) + { + const std::string& classifier_name = l.first; + const auto& int_labels = l.second; + scratch.set_size(output_tensor.num_samples(), int_labels.size()); + tt::copy_tensor(false, scratch, 0, output_tensor, k_offset, int_labels.size()); + + tt::softmax(scratch, scratch); + + + auto truth = truth_begin; + float* g = scratch.host(); + for (long i = 0; i < scratch.num_samples(); ++i) + { + const long y = int_labels.at(truth->at(classifier_name)); + ++truth; + + for (long k = 0; k < scratch.k(); ++k) + { + const unsigned long idx = i*scratch.k()+k; + if (k == y) + { + loss += scale*-std::log(g[idx]); + g[idx] = scale*(g[idx]-1); + } + else + { + g[idx] = scale*g[idx]; + } + } + } + + tt::copy_tensor(false, grad, k_offset, scratch, 0, int_labels.size()); + + k_offset += int_labels.size(); + } + return loss; + } + + + friend void serialize(const loss_multimulticlass_log_& item, std::ostream& out) + { + serialize("loss_multimulticlass_log_", out); + serialize(item.get_labels(), out); + } + + friend void deserialize(loss_multimulticlass_log_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_multimulticlass_log_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_multimulticlass_log_."); + + std::map> info; + deserialize(info, in); + item = loss_multimulticlass_log_(info); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_multimulticlass_log_& item) + { + out << "loss_multimulticlass_log, labels={"; + for (auto i = item.possible_labels.begin(); i != item.possible_labels.end(); ) + { + auto& category = i->first; + auto& labels = *(i->second); + out << category << ":("; + for (size_t j = 0; j < labels.size(); ++j) + { + out << labels[j]; + if (j+1 < labels.size()) + out << ","; + } + + out << ")"; + if (++i != item.possible_labels.end()) + out << ", "; + } + out << "}"; + return out; + } + + friend void to_xml(const loss_multimulticlass_log_& item, std::ostream& out) + { + out << "\n"; + out << item; + out << "\n"; + } + + private: + + std::map>> possible_labels; + unsigned long total_num_labels = 0; + + // We make it true that: possible_labels[classifier][label_idx_lookup[classifier][label]] == label + std::map> label_idx_lookup; + + + // Scratch doesn't logically contribute to the state of this object. It's just + // temporary scratch space used by this class. + mutable resizable_tensor scratch; + + + }; + + template + using loss_multimulticlass_log = add_loss_layer; + + inline bool operator== (const std::string& lhs, const loss_multimulticlass_log_::classifier_output& rhs) + { return lhs == static_cast(rhs); } + inline bool operator== (const loss_multimulticlass_log_::classifier_output& lhs, const std::string& rhs) + { return rhs == static_cast(lhs); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + enum class use_image_pyramid : uint8_t + { + no, + yes + }; + + struct mmod_options + { + public: + + struct detector_window_details + { + detector_window_details() = default; + detector_window_details(unsigned long w, unsigned long h) : width(w), height(h) {} + detector_window_details(unsigned long w, unsigned long h, const std::string& l) : width(w), height(h), label(l) {} + + unsigned long width = 0; + unsigned long height = 0; + std::string label; + + friend inline void serialize(const detector_window_details& item, std::ostream& out) + { + int version = 2; + serialize(version, out); + serialize(item.width, out); + serialize(item.height, out); + serialize(item.label, out); + } + + friend inline void deserialize(detector_window_details& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1 && version != 2) + throw serialization_error("Unexpected version found while deserializing dlib::mmod_options::detector_window_details"); + deserialize(item.width, in); + deserialize(item.height, in); + if (version == 2) + deserialize(item.label, in); + } + + }; + + mmod_options() = default; + + std::vector detector_windows; + double loss_per_false_alarm = 1; + double loss_per_missed_target = 1; + double truth_match_iou_threshold = 0.5; + test_box_overlap overlaps_nms = test_box_overlap(0.4); + test_box_overlap overlaps_ignore; + + use_image_pyramid assume_image_pyramid = use_image_pyramid::yes; + + mmod_options ( + const std::vector>& boxes, + const unsigned long target_size, // We want the length of the longest dimension of the detector window to be this. + const unsigned long min_target_size, // But we require that the smallest dimension of the detector window be at least this big. + const double min_detector_window_overlap_iou = 0.75 + ) + { + DLIB_CASSERT(0 < min_target_size && min_target_size <= target_size); + DLIB_CASSERT(0.5 < min_detector_window_overlap_iou && min_detector_window_overlap_iou < 1); + + // Figure out what detector windows we will need. + for (auto& label : get_labels(boxes)) + { + for (auto ratio : find_covering_aspect_ratios(boxes, test_box_overlap(min_detector_window_overlap_iou), label)) + { + double detector_width; + double detector_height; + if (ratio < 1) + { + detector_height = target_size; + detector_width = ratio*target_size; + if (detector_width < min_target_size) + { + detector_height = min_target_size/ratio; + detector_width = min_target_size; + } + } + else + { + detector_width = target_size; + detector_height = target_size/ratio; + if (detector_height < min_target_size) + { + detector_width = min_target_size*ratio; + detector_height = min_target_size; + } + } + + detector_window_details p((unsigned long)std::round(detector_width), (unsigned long)std::round(detector_height), label); + detector_windows.push_back(p); + } + } + + DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes)."); + + set_overlap_nms(boxes); + } + + mmod_options( + use_image_pyramid assume_image_pyramid, + const std::vector>& boxes, + const double min_detector_window_overlap_iou = 0.75 + ) + : assume_image_pyramid(assume_image_pyramid) + { + DLIB_CASSERT(assume_image_pyramid == use_image_pyramid::no); + DLIB_CASSERT(0.5 < min_detector_window_overlap_iou && min_detector_window_overlap_iou < 1); + + // Figure out what detector windows we will need. + for (auto& label : get_labels(boxes)) + { + for (auto rectangle : find_covering_rectangles(boxes, test_box_overlap(min_detector_window_overlap_iou), label)) + { + detector_windows.push_back(detector_window_details(rectangle.width(), rectangle.height(), label)); + } + } + + DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes)."); + + set_overlap_nms(boxes); + } + + private: + + void set_overlap_nms(const std::vector>& boxes) + { + // Convert from mmod_rect to rectangle so we can call + // find_tight_overlap_tester(). + std::vector> temp; + for (auto&& bi : boxes) + { + std::vector rtemp; + for (auto&& b : bi) + { + if (b.ignore) + continue; + rtemp.push_back(b.rect); + } + temp.push_back(std::move(rtemp)); + } + overlaps_nms = find_tight_overlap_tester(temp); + // Relax the non-max-suppression a little so that it doesn't accidentally make + // it impossible for the detector to output boxes matching the training data. + // This could be a problem with the tightest possible nms test since there is + // some small variability in how boxes get positioned between the training data + // and the coordinate system used by the detector when it runs. So relaxing it + // here takes care of that. + auto iou_thresh = advance_toward_1(overlaps_nms.get_iou_thresh()); + auto percent_covered_thresh = advance_toward_1(overlaps_nms.get_percent_covered_thresh()); + overlaps_nms = test_box_overlap(iou_thresh, percent_covered_thresh); + } + + static double advance_toward_1 ( + double val + ) + { + if (val < 1) + val += (1-val)*0.1; + return val; + } + + static size_t count_overlaps ( + const std::vector& rects, + const test_box_overlap& overlaps, + const rectangle& ref_box + ) + { + size_t cnt = 0; + for (auto& b : rects) + { + if (overlaps(b, ref_box)) + ++cnt; + } + return cnt; + } + + static std::vector find_rectangles_overlapping_all_others ( + std::vector rects, + const test_box_overlap& overlaps + ) + { + std::vector exemplars; + dlib::rand rnd; + + while(rects.size() > 0) + { + // Pick boxes at random and see if they overlap a lot of other boxes. We will try + // 500 different boxes each iteration and select whichever hits the most others to + // add to our exemplar set. + rectangle best_ref_box; + size_t best_cnt = 0; + for (int iter = 0; iter < 500; ++iter) + { + rectangle ref_box = rects[rnd.get_random_64bit_number()%rects.size()]; + size_t cnt = count_overlaps(rects, overlaps, ref_box); + if (cnt >= best_cnt) + { + best_cnt = cnt; + best_ref_box = ref_box; + } + } + + // Now mark all the boxes the new ref box hit as hit. + for (size_t i = 0; i < rects.size(); ++i) + { + if (overlaps(rects[i], best_ref_box)) + { + // remove box from rects so we don't hit it again later + swap(rects[i], rects.back()); + rects.pop_back(); + --i; + } + } + + exemplars.push_back(best_ref_box); + } + + return exemplars; + } + + static std::set get_labels ( + const std::vector>& rects + ) + { + std::set labels; + for (auto& rr : rects) + { + for (auto& r : rr) + labels.insert(r.label); + } + return labels; + } + + static std::vector find_covering_aspect_ratios ( + const std::vector>& rects, + const test_box_overlap& overlaps, + const std::string& label + ) + { + std::vector boxes; + // Make sure all the boxes have the same size and position, so that the only thing our + // checks for overlap will care about is aspect ratio (i.e. scale and x,y position are + // ignored). + for (auto& bb : rects) + { + for (auto&& b : bb) + { + if (!b.ignore && b.label == label) + boxes.push_back(move_rect(set_rect_area(b.rect,400*400), point(0,0))); + } + } + + std::vector ratios; + for (auto r : find_rectangles_overlapping_all_others(boxes, overlaps)) + ratios.push_back(r.width()/(double)r.height()); + return ratios; + } + + static std::vector find_covering_rectangles ( + const std::vector>& rects, + const test_box_overlap& overlaps, + const std::string& label + ) + { + std::vector boxes; + // Make sure all the boxes have the same position, so that the we only check for + // width and height. + for (auto& bb : rects) + { + for (auto&& b : bb) + { + if (!b.ignore && b.label == label) + boxes.push_back(rectangle(b.rect.width(), b.rect.height())); + } + } + + return find_rectangles_overlapping_all_others(boxes, overlaps); + } + }; + + inline void serialize(const mmod_options& item, std::ostream& out) + { + int version = 3; + + serialize(version, out); + serialize(item.detector_windows, out); + serialize(item.loss_per_false_alarm, out); + serialize(item.loss_per_missed_target, out); + serialize(item.truth_match_iou_threshold, out); + serialize(item.overlaps_nms, out); + serialize(item.overlaps_ignore, out); + serialize(static_cast(item.assume_image_pyramid), out); + } + + inline void deserialize(mmod_options& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 3 && version != 2 && version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::mmod_options"); + if (version == 1) + { + unsigned long width; + unsigned long height; + deserialize(width, in); + deserialize(height, in); + item.detector_windows = {mmod_options::detector_window_details(width, height)}; + } + else + { + deserialize(item.detector_windows, in); + } + deserialize(item.loss_per_false_alarm, in); + deserialize(item.loss_per_missed_target, in); + deserialize(item.truth_match_iou_threshold, in); + deserialize(item.overlaps_nms, in); + deserialize(item.overlaps_ignore, in); + item.assume_image_pyramid = use_image_pyramid::yes; + if (version >= 3) + { + uint8_t assume_image_pyramid = 0; + deserialize(assume_image_pyramid, in); + item.assume_image_pyramid = static_cast(assume_image_pyramid); + } + } + +// ---------------------------------------------------------------------------------------- + + class loss_mmod_ + { + struct intermediate_detection + { + intermediate_detection() = default; + + intermediate_detection( + rectangle rect_ + ) : rect(rect_) {} + + intermediate_detection( + rectangle rect_, + double detection_confidence_, + size_t tensor_offset_, + long channel + ) : rect(rect_), detection_confidence(detection_confidence_), tensor_offset(tensor_offset_), tensor_channel(channel) {} + + rectangle rect; + double detection_confidence = 0; + size_t tensor_offset = 0; + long tensor_channel = 0; + + bool operator<(const intermediate_detection& item) const { return detection_confidence < item.detection_confidence; } + }; + + public: + + typedef std::vector training_label_type; + typedef std::vector output_label_type; + + loss_mmod_() {} + + loss_mmod_(mmod_options options_) : options(options_) {} + + const mmod_options& get_options ( + ) const { return options; } + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter, + double adjust_threshold = 0 + ) const + { + const tensor& output_tensor = sub.get_output(); + DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(sub.sample_expansion_factor() == 1, sub.sample_expansion_factor()); + + std::vector dets_accum; + output_label_type final_dets; + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + tensor_to_dets(input_tensor, output_tensor, i, dets_accum, adjust_threshold, sub); + + // Do non-max suppression + final_dets.clear(); + for (unsigned long i = 0; i < dets_accum.size(); ++i) + { + if (overlaps_any_box_nms(final_dets, dets_accum[i].rect)) + continue; + + final_dets.push_back(mmod_rect(dets_accum[i].rect, + dets_accum[i].detection_confidence, + options.detector_windows[dets_accum[i].tensor_channel].label)); + } + + *iter++ = std::move(final_dets); + } + } + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()); + + double det_thresh_speed_adjust = 0; + + + // we will scale the loss so that it doesn't get really huge + const double scale = 1.0/output_tensor.size(); + double loss = 0; + + float* g = grad.host_write_only(); + for (size_t i = 0; i < grad.size(); ++i) + g[i] = 0; + + const float* out_data = output_tensor.host(); + + std::vector truth_idxs; truth_idxs.reserve(truth->size()); + std::vector dets; + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + tensor_to_dets(input_tensor, output_tensor, i, dets, -options.loss_per_false_alarm + det_thresh_speed_adjust, sub); + + const unsigned long max_num_dets = 50 + truth->size()*5; + // Prevent calls to tensor_to_dets() from running for a really long time + // due to the production of an obscene number of detections. + const unsigned long max_num_initial_dets = max_num_dets*100; + if (dets.size() >= max_num_initial_dets) + { + det_thresh_speed_adjust = std::max(det_thresh_speed_adjust,dets[max_num_initial_dets].detection_confidence + options.loss_per_false_alarm); + } + + + // The loss will measure the number of incorrect detections. A detection is + // incorrect if it doesn't hit a truth rectangle or if it is a duplicate detection + // on a truth rectangle. + loss += truth->size()*options.loss_per_missed_target; + for (auto&& x : *truth) + { + if (!x.ignore) + { + size_t k; + point p; + if(image_rect_to_feat_coord(p, input_tensor, x, x.label, sub, k, options.assume_image_pyramid)) + { + // Ignore boxes that can't be detected by the CNN. + loss -= options.loss_per_missed_target; + continue; + } + const size_t idx = (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x(); + loss -= out_data[idx]; + // compute gradient + g[idx] = -scale; + truth_idxs.push_back(idx); + } + else + { + // This box was ignored so shouldn't have been counted in the loss. + loss -= options.loss_per_missed_target; + truth_idxs.push_back(0); + } + } + + // Measure the loss augmented score for the detections which hit a truth rect. + std::vector truth_score_hits(truth->size(), 0); + + // keep track of which truth boxes we have hit so far. + std::vector hit_truth_table(truth->size(), false); + + std::vector final_dets; + // The point of this loop is to fill out the truth_score_hits array. + for (unsigned long i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i) + { + if (overlaps_any_box_nms(final_dets, dets[i].rect)) + continue; + + const auto& det_label = options.detector_windows[dets[i].tensor_channel].label; + + const std::pair hittruth = find_best_match(*truth, dets[i].rect, det_label); + + final_dets.push_back(dets[i].rect); + + const double truth_match = hittruth.first; + // if hit truth rect + if (truth_match > options.truth_match_iou_threshold) + { + // if this is the first time we have seen a detect which hit (*truth)[hittruth.second] + const double score = dets[i].detection_confidence; + if (hit_truth_table[hittruth.second] == false) + { + hit_truth_table[hittruth.second] = true; + truth_score_hits[hittruth.second] += score; + } + else + { + truth_score_hits[hittruth.second] += score + options.loss_per_false_alarm; + } + } + } + + // Check if any of the truth boxes are unobtainable because the NMS is + // killing them. If so, automatically set those unobtainable boxes to + // ignore and print a warning message to the user. + for (size_t i = 0; i < hit_truth_table.size(); ++i) + { + if (!hit_truth_table[i] && !(*truth)[i].ignore) + { + // So we didn't hit this truth box. Is that because there is + // another, different truth box, that overlaps it according to NMS? + const std::pair hittruth = find_best_match(*truth, (*truth)[i], i); + if (hittruth.second == i || (*truth)[hittruth.second].ignore) + continue; + rectangle best_matching_truth_box = (*truth)[hittruth.second]; + if (options.overlaps_nms(best_matching_truth_box, (*truth)[i])) + { + const size_t idx = truth_idxs[i]; + // We are ignoring this box so we shouldn't have counted it in the + // loss in the first place. So we subtract out the loss values we + // added for it in the code above. + loss -= options.loss_per_missed_target-out_data[idx]; + g[idx] = 0; + std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect; + std::cout << " that is suppressed by non-max-suppression "; + std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box + << " (IoU:"<< box_intersection_over_union(best_matching_truth_box,(*truth)[i]) <<", Percent covered:" + << box_percent_covered(best_matching_truth_box,(*truth)[i]) << ")." << std::endl; + } + } + } + + hit_truth_table.assign(hit_truth_table.size(), false); + final_dets.clear(); + + + // Now figure out which detections jointly maximize the loss and detection score sum. We + // need to take into account the fact that allowing a true detection in the output, while + // initially reducing the loss, may allow us to increase the loss later with many duplicate + // detections. + for (unsigned long i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i) + { + if (overlaps_any_box_nms(final_dets, dets[i].rect)) + continue; + + const auto& det_label = options.detector_windows[dets[i].tensor_channel].label; + + const std::pair hittruth = find_best_match(*truth, dets[i].rect, det_label); + + const double truth_match = hittruth.first; + if (truth_match > options.truth_match_iou_threshold) + { + if (truth_score_hits[hittruth.second] > options.loss_per_missed_target) + { + if (!hit_truth_table[hittruth.second]) + { + hit_truth_table[hittruth.second] = true; + final_dets.push_back(dets[i]); + loss -= options.loss_per_missed_target; + } + else + { + final_dets.push_back(dets[i]); + loss += options.loss_per_false_alarm; + } + } + } + else if (!overlaps_ignore_box(*truth, dets[i].rect)) + { + // didn't hit anything + final_dets.push_back(dets[i]); + loss += options.loss_per_false_alarm; + } + } + + for (auto&& x : final_dets) + { + loss += out_data[x.tensor_offset]; + g[x.tensor_offset] += scale; + } + + ++truth; + g += output_tensor.k()*output_tensor.nr()*output_tensor.nc(); + out_data += output_tensor.k()*output_tensor.nr()*output_tensor.nc(); + } // END for (long i = 0; i < output_tensor.num_samples(); ++i) + + + // Here we scale the loss so that it's roughly equal to the number of mistakes + // in an image. Note that this scaling is different than the scaling we + // applied to the gradient but it doesn't matter since the loss value isn't + // used to update parameters. It's used only for display and to check if we + // have converged. So it doesn't matter that they are scaled differently and + // this way the loss that is displayed is readily interpretable to the user. + return loss/output_tensor.num_samples(); + } + + + friend void serialize(const loss_mmod_& item, std::ostream& out) + { + serialize("loss_mmod_", out); + serialize(item.options, out); + } + + friend void deserialize(loss_mmod_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_mmod_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_mmod_."); + deserialize(item.options, in); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_mmod_& item) + { + out << "loss_mmod\t ("; + + out << "detector_windows:("; + auto& opts = item.options; + for (size_t i = 0; i < opts.detector_windows.size(); ++i) + { + out << opts.detector_windows[i].width << "x" << opts.detector_windows[i].height; + if (i+1 < opts.detector_windows.size()) + out << ","; + } + out << ")"; + out << ", loss per FA:" << opts.loss_per_false_alarm; + out << ", loss per miss:" << opts.loss_per_missed_target; + out << ", truth match IOU thresh:" << opts.truth_match_iou_threshold; + out << ", overlaps_nms:("<"; + } + + private: + + template + void tensor_to_dets ( + const tensor& input_tensor, + const tensor& output_tensor, + long i, + std::vector& dets_accum, + double adjust_threshold, + const net_type& net + ) const + { + DLIB_CASSERT(net.sample_expansion_factor() == 1,net.sample_expansion_factor()); + DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()); + const float* out_data = output_tensor.host() + output_tensor.k()*output_tensor.nr()*output_tensor.nc()*i; + // scan the final layer and output the positive scoring locations + dets_accum.clear(); + for (long k = 0; k < output_tensor.k(); ++k) + { + for (long r = 0; r < output_tensor.nr(); ++r) + { + for (long c = 0; c < output_tensor.nc(); ++c) + { + double score = out_data[(k*output_tensor.nr() + r)*output_tensor.nc() + c]; + if (score > adjust_threshold) + { + dpoint p = output_tensor_to_input_tensor(net, point(c,r)); + drectangle rect = centered_drect(p, options.detector_windows[k].width, options.detector_windows[k].height); + rect = input_layer(net).tensor_space_to_image_space(input_tensor,rect); + + dets_accum.push_back(intermediate_detection(rect, score, (k*output_tensor.nr() + r)*output_tensor.nc() + c, k)); + } + } + } + } + std::sort(dets_accum.rbegin(), dets_accum.rend()); + } + + size_t find_best_detection_window ( + rectangle rect, + const std::string& label, + use_image_pyramid assume_image_pyramid + ) const + { + if (assume_image_pyramid == use_image_pyramid::yes) + { + rect = move_rect(set_rect_area(rect, 400*400), point(0,0)); + } + else + { + rect = rectangle(rect.width(), rect.height()); + } + + // Figure out which detection window in options.detector_windows is most similar to rect + // (in terms of aspect ratio, if assume_image_pyramid == use_image_pyramid::yes). + size_t best_i = 0; + double best_ratio_diff = -std::numeric_limits::infinity(); + for (size_t i = 0; i < options.detector_windows.size(); ++i) + { + if (options.detector_windows[i].label != label) + continue; + + rectangle det_window; + + if (options.assume_image_pyramid == use_image_pyramid::yes) + { + det_window = centered_rect(point(0,0), options.detector_windows[i].width, options.detector_windows[i].height); + det_window = move_rect(set_rect_area(det_window, 400*400), point(0,0)); + } + else + { + det_window = rectangle(options.detector_windows[i].width, options.detector_windows[i].height); + } + + double iou = box_intersection_over_union(rect, det_window); + if (iou > best_ratio_diff) + { + best_ratio_diff = iou; + best_i = i; + } + } + return best_i; + } + + template + bool image_rect_to_feat_coord ( + point& tensor_p, + const tensor& input_tensor, + const rectangle& rect, + const std::string& label, + const net_type& net, + size_t& det_idx, + use_image_pyramid assume_image_pyramid + ) const + { + using namespace std; + if (!input_layer(net).image_contained_point(input_tensor,center(rect))) + { + std::ostringstream sout; + sout << "Encountered a truth rectangle located at " << rect << " that is outside the image." << endl; + sout << "The center of each truth rectangle must be within the image." << endl; + throw impossible_labeling_error(sout.str()); + } + + det_idx = find_best_detection_window(rect,label,assume_image_pyramid); + + double scale = 1.0; + if (options.assume_image_pyramid == use_image_pyramid::yes) + { + // Compute the scale we need to be at to get from rect to our detection window. + // Note that we compute the scale as the max of two numbers. It doesn't + // actually matter which one we pick, because if they are very different then + // it means the box can't be matched by the sliding window. But picking the + // max causes the right error message to be selected in the logic below. + scale = std::max(options.detector_windows[det_idx].width/(double)rect.width(), options.detector_windows[det_idx].height/(double)rect.height()); + } + else + { + // We don't want invariance to scale. + scale = 1.0; + } + + const rectangle mapped_rect = input_layer(net).image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect); + + // compute the detection window that we would use at this position. + tensor_p = center(mapped_rect); + rectangle det_window = centered_rect(tensor_p, options.detector_windows[det_idx].width,options.detector_windows[det_idx].height); + det_window = input_layer(net).tensor_space_to_image_space(input_tensor, det_window); + + // make sure the rect can actually be represented by the image pyramid we are + // using. + if (box_intersection_over_union(rect, det_window) <= options.truth_match_iou_threshold) + { + std::cout << "Warning, ignoring object. We encountered a truth rectangle with a width and height of " << rect.width() << " and " << rect.height() << ". "; + std::cout << "The image pyramid and sliding windows can't output a rectangle of this shape. "; + const double detector_area = options.detector_windows[det_idx].width*options.detector_windows[det_idx].height; + if (mapped_rect.area()/detector_area <= options.truth_match_iou_threshold) + { + std::cout << "This is because the rectangle is smaller than the best matching detection window, which has a width "; + std::cout << "and height of " << options.detector_windows[det_idx].width << " and " << options.detector_windows[det_idx].height << "." << std::endl; + } + else + { + std::cout << "This is either because (1) the final layer's features have too large of a stride across the image, limiting the possible locations the sliding window can search "; + std::cout << "or (2) because the rectangle's aspect ratio is too different from the best matching detection window, "; + std::cout << "which has a width and height of " << options.detector_windows[det_idx].width << " and " << options.detector_windows[det_idx].height << "." << std::endl; + } + return true; + } + + // now map through the CNN to the output layer. + tensor_p = input_tensor_to_output_tensor(net,tensor_p); + + const tensor& output_tensor = net.get_output(); + if (!get_rect(output_tensor).contains(tensor_p)) + { + std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << rect << " that is too close to the edge "; + std::cout << "of the image to be captured by the CNN features." << std::endl; + return true; + } + + return false; + } + + + bool overlaps_ignore_box ( + const std::vector& boxes, + const rectangle& rect + ) const + { + for (auto&& b : boxes) + { + if (b.ignore && options.overlaps_ignore(b, rect)) + return true; + } + return false; + } + + std::pair find_best_match( + const std::vector& boxes, + const rectangle& rect, + const std::string& label + ) const + { + double match = 0; + unsigned int best_idx = 0; + for (unsigned long i = 0; i < boxes.size(); ++i) + { + if (boxes[i].ignore || boxes[i].label != label) + continue; + + const double new_match = box_intersection_over_union(rect, boxes[i]); + if (new_match > match) + { + match = new_match; + best_idx = i; + } + } + + return std::make_pair(match,best_idx); + } + + std::pair find_best_match( + const std::vector& boxes, + const rectangle& rect, + const size_t excluded_idx + ) const + { + double match = 0; + unsigned int best_idx = 0; + for (unsigned long i = 0; i < boxes.size(); ++i) + { + if (boxes[i].ignore || excluded_idx == i) + continue; + + const double new_match = box_intersection_over_union(rect, boxes[i]); + if (new_match > match) + { + match = new_match; + best_idx = i; + } + } + + return std::make_pair(match,best_idx); + } + + template + inline bool overlaps_any_box_nms ( + const std::vector& rects, + const rectangle& rect + ) const + { + for (auto&& r : rects) + { + if (options.overlaps_nms(r.rect, rect)) + return true; + } + return false; + } + + + mmod_options options; + + }; + + template + using loss_mmod = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_metric_ + { + public: + + typedef unsigned long training_label_type; + typedef matrix output_label_type; + + loss_metric_() = default; + + loss_metric_( + float margin_, + float dist_thresh_ + ) : margin(margin_), dist_thresh(dist_thresh_) + { + DLIB_CASSERT(margin_ > 0); + DLIB_CASSERT(dist_thresh_ > 0); + } + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + const tensor& output_tensor = sub.get_output(); + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1); + + const float* p = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + *iter = mat(p,output_tensor.k(),1); + + ++iter; + p += output_tensor.k(); + } + } + + + float get_margin() const { return margin; } + float get_distance_threshold() const { return dist_thresh; } + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1); + DLIB_CASSERT(grad.nr() == 1 && + grad.nc() == 1); + + + + temp.set_size(output_tensor.num_samples(), output_tensor.num_samples()); + grad_mul.copy_size(temp); + + tt::gemm(0, temp, 1, output_tensor, false, output_tensor, true); + + + std::vector temp_threshs; + const float* d = temp.host(); + double loss = 0; + double num_pos_samps = 0.0001; + double num_neg_samps = 0.0001; + for (long r = 0; r < temp.num_samples(); ++r) + { + auto xx = d[r*temp.num_samples() + r]; + const auto x_label = *(truth + r); + for (long c = r+1; c < temp.num_samples(); ++c) + { + const auto y_label = *(truth + c); + if (x_label == y_label) + { + ++num_pos_samps; + } + else + { + ++num_neg_samps; + + // Figure out what distance threshold, when applied to the negative pairs, + // causes there to be an equal number of positive and negative pairs. + auto yy = d[c*temp.num_samples() + c]; + auto xy = d[r*temp.num_samples() + c]; + // compute the distance between x and y samples. + auto d2 = xx + yy - 2*xy; + if (d2 < 0) + d2 = 0; + temp_threshs.push_back(d2); + } + } + } + // The whole objective function is multiplied by this to scale the loss + // relative to the number of things in the mini-batch. + const double scale = 0.5/num_pos_samps; + DLIB_CASSERT(num_pos_samps>=1, "Make sure each mini-batch contains both positive pairs and negative pairs"); + DLIB_CASSERT(num_neg_samps>=1, "Make sure each mini-batch contains both positive pairs and negative pairs"); + + std::sort(temp_threshs.begin(), temp_threshs.end()); + const float neg_thresh = std::sqrt(temp_threshs[std::min(num_pos_samps,num_neg_samps)-1]); + + // loop over all the pairs of training samples and compute the loss and + // gradients. Note that we only use the hardest negative pairs and that in + // particular we pick the number of negative pairs equal to the number of + // positive pairs so everything is balanced. + float* gm = grad_mul.host(); + for (long r = 0; r < temp.num_samples(); ++r) + { + gm[r*temp.num_samples() + r] = 0; + const auto x_label = *(truth + r); + auto xx = d[r*temp.num_samples() + r]; + for (long c = 0; c < temp.num_samples(); ++c) + { + if (r==c) + continue; + const auto y_label = *(truth + c); + auto yy = d[c*temp.num_samples() + c]; + auto xy = d[r*temp.num_samples() + c]; + + // compute the distance between x and y samples. + auto d2 = xx + yy - 2*xy; + if (d2 <= 0) + d2 = 0; + else + d2 = std::sqrt(d2); + + // It should be noted that the derivative of length(x-y) with respect + // to the x vector is the unit vector (x-y)/length(x-y). If you stare + // at the code below long enough you will see that it's just an + // application of this formula. + + if (x_label == y_label) + { + // Things with the same label should have distances < dist_thresh between + // them. If not then we experience non-zero loss. + if (d2 < dist_thresh-margin) + { + gm[r*temp.num_samples() + c] = 0; + } + else + { + loss += scale*(d2 - (dist_thresh-margin)); + gm[r*temp.num_samples() + r] += scale/d2; + gm[r*temp.num_samples() + c] = -scale/d2; + } + } + else + { + // Things with different labels should have distances > dist_thresh between + // them. If not then we experience non-zero loss. + if (d2 > dist_thresh+margin || d2 > neg_thresh) + { + gm[r*temp.num_samples() + c] = 0; + } + else + { + loss += scale*((dist_thresh+margin) - d2); + // don't divide by zero (or a really small number) + d2 = std::max(d2, 0.001f); + gm[r*temp.num_samples() + r] -= scale/d2; + gm[r*temp.num_samples() + c] = scale/d2; + } + } + } + } + + + tt::gemm(0, grad, 1, grad_mul, false, output_tensor, false); + + return loss; + } + + friend void serialize(const loss_metric_& item, std::ostream& out) + { + serialize("loss_metric_2", out); + serialize(item.margin, out); + serialize(item.dist_thresh, out); + } + + friend void deserialize(loss_metric_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version == "loss_metric_") + { + // These values used to be hard coded, so for this version of the metric + // learning loss we just use these values. + item.margin = 0.1; + item.dist_thresh = 0.75; + return; + } + else if (version == "loss_metric_2") + { + deserialize(item.margin, in); + deserialize(item.dist_thresh, in); + } + else + { + throw serialization_error("Unexpected version found while deserializing dlib::loss_metric_. Instead found " + version); + } + } + + friend std::ostream& operator<<(std::ostream& out, const loss_metric_& item ) + { + out << "loss_metric (margin="<"; + } + + private: + float margin = 0.04; + float dist_thresh = 0.6; + + + // These variables are only here to avoid being reallocated over and over in + // compute_loss_value_and_gradient() + mutable resizable_tensor temp, grad_mul; + + }; + + template + using loss_metric = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_ranking_ + { + public: + + typedef float training_label_type; // nominally +1/-1 + typedef float output_label_type; // ranking score + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + + const tensor& output_tensor = sub.get_output(); + + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + *iter++ = out_data[i]; + } + } + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + DLIB_CASSERT(grad.nr() == 1 && + grad.nc() == 1 && + grad.k() == 1); + + + std::vector rel_scores; + std::vector nonrel_scores; + std::vector rel_idx, nonrel_idx; + + const float* out_data = output_tensor.host(); + float* g = grad.host_write_only(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + const float y = *truth++; + if (y > 0) + { + rel_scores.push_back(out_data[i]-y); + rel_idx.push_back(i); + } + else if (y < 0) + { + nonrel_scores.push_back(out_data[i]-y); + nonrel_idx.push_back(i); + } + else + { + g[i] = 0; + } + } + + + std::vector rel_counts; + std::vector nonrel_counts; + count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts); + const unsigned long total_pairs = rel_scores.size()*nonrel_scores.size(); + DLIB_CASSERT(total_pairs > 0, "You can't give a ranking mini-batch that contains only one class. Both classes must be represented."); + const double scale = 1.0/total_pairs; + + + double loss = 0; + for (unsigned long k = 0; k < rel_counts.size(); ++k) + { + loss -= rel_counts[k]*rel_scores[k]; + g[rel_idx[k]] = -1.0*rel_counts[k]*scale; + } + + for (unsigned long k = 0; k < nonrel_counts.size(); ++k) + { + loss += nonrel_counts[k]*nonrel_scores[k]; + g[nonrel_idx[k]] = nonrel_counts[k]*scale; + } + + return loss*scale; + } + + friend void serialize(const loss_ranking_& , std::ostream& out) + { + serialize("loss_ranking_", out); + } + + friend void deserialize(loss_ranking_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_ranking_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_ranking_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_ranking_& ) + { + out << "loss_ranking"; + return out; + } + + friend void to_xml(const loss_ranking_& /*item*/, std::ostream& out) + { + out << ""; + } + + }; + + template + using loss_ranking = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_mean_squared_ + { + public: + + typedef float training_label_type; + typedef float output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + + const tensor& output_tensor = sub.get_output(); + + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + *iter++ = out_data[i]; + } + } + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + DLIB_CASSERT(grad.nr() == 1 && + grad.nc() == 1 && + grad.k() == 1); + + // The loss we output is the average loss over the mini-batch. + const double scale = 1.0/output_tensor.num_samples(); + double loss = 0; + float* g = grad.host_write_only(); + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + const float y = *truth++; + const float temp1 = y - out_data[i]; + const float temp2 = scale*temp1; + loss += temp2*temp1; + g[i] = -temp2; + + } + return loss; + } + + friend void serialize(const loss_mean_squared_& , std::ostream& out) + { + serialize("loss_mean_squared_", out); + } + + friend void deserialize(loss_mean_squared_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_mean_squared_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_& ) + { + out << "loss_mean_squared"; + return out; + } + + friend void to_xml(const loss_mean_squared_& /*item*/, std::ostream& out) + { + out << ""; + } + + }; + + template + using loss_mean_squared = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_epsilon_insensitive_ + { + public: + + typedef float training_label_type; + typedef float output_label_type; + + loss_epsilon_insensitive_() = default; + loss_epsilon_insensitive_(double eps) : eps(eps) + { + DLIB_CASSERT(eps >= 0, "You can't set a negative error epsilon."); + } + + double get_epsilon () const { return eps; } + void set_epsilon(double e) + { + DLIB_CASSERT(e >= 0, "You can't set a negative error epsilon."); + eps = e; + } + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + + const tensor& output_tensor = sub.get_output(); + + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + *iter++ = out_data[i]; + } + } + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1 && + output_tensor.k() == 1); + DLIB_CASSERT(grad.nr() == 1 && + grad.nc() == 1 && + grad.k() == 1); + + // The loss we output is the average loss over the mini-batch. + const double scale = 1.0/output_tensor.num_samples(); + double loss = 0; + float* g = grad.host_write_only(); + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + const float y = *truth++; + const float err = out_data[i]-y; + if (err > eps) + { + loss += scale*(err-eps); + g[i] = scale; + } + else if (err < -eps) + { + loss += scale*(eps-err); + g[i] = -scale; + } + } + return loss; + } + + friend void serialize(const loss_epsilon_insensitive_& item, std::ostream& out) + { + serialize("loss_epsilon_insensitive_", out); + serialize(item.eps, out); + } + + friend void deserialize(loss_epsilon_insensitive_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_epsilon_insensitive_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_epsilon_insensitive_."); + deserialize(item.eps, in); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_epsilon_insensitive_& item) + { + out << "loss_epsilon_insensitive epsilon: " << item.eps; + return out; + } + + friend void to_xml(const loss_epsilon_insensitive_& item, std::ostream& out) + { + out << ""; + } + + private: + double eps = 1; + + }; + + template + using loss_epsilon_insensitive = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_mean_squared_multioutput_ + { + public: + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + + const tensor& output_tensor = sub.get_output(); + + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1) + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + *iter++ = mat(out_data, output_tensor.k(), 1); + out_data += output_tensor.k(); + } + } + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == 1 && + output_tensor.nc() == 1); + DLIB_CASSERT(grad.nr() == 1 && + grad.nc() == 1); + DLIB_CASSERT(grad.k() == output_tensor.k()); + const long k = output_tensor.k(); + for (long idx = 0; idx < output_tensor.num_samples(); ++idx) + { + const_label_iterator truth_matrix_ptr = (truth + idx); + DLIB_CASSERT((*truth_matrix_ptr).nr() == k && + (*truth_matrix_ptr).nc() == 1); + } + + // The loss we output is the average loss over the mini-batch. + const double scale = 1.0/output_tensor.num_samples(); + double loss = 0; + float* g = grad.host_write_only(); + const float* out_data = output_tensor.host(); + matrix ytrue; + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + ytrue = *truth++; + for (long j = 0; j < output_tensor.k(); ++j) + { + const float y = ytrue(j, 0); + const float temp1 = y - *out_data++; + const float temp2 = scale*temp1; + loss += temp2*temp1; + *g = -temp2; + ++g; + } + + } + return loss; + } + + friend void serialize(const loss_mean_squared_multioutput_& , std::ostream& out) + { + serialize("loss_mean_squared_multioutput_", out); + } + + friend void deserialize(loss_mean_squared_multioutput_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_mean_squared_multioutput_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_multioutput_& ) + { + out << "loss_mean_squared_multioutput"; + return out; + } + + friend void to_xml(const loss_mean_squared_multioutput_& /*item*/, std::ostream& out) + { + out << ""; + } + + }; + + template + using loss_mean_squared_multioutput = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_multiclass_log_per_pixel_ + { + public: + + // In semantic segmentation, if you don't know the ground-truth of some pixel, + // set the label of that pixel to this value. When you do so, the pixel will be + // ignored when computing gradients. + static const uint16_t label_to_ignore = std::numeric_limits::max(); + + + // In semantic segmentation, 65535 classes ought to be enough for anybody. + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + static void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) + { + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + + const tensor& output_tensor = sub.get_output(); + + DLIB_CASSERT(output_tensor.k() >= 1); // Note that output_tensor.k() should match the number of labels. + DLIB_CASSERT(output_tensor.k() < std::numeric_limits::max()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const float* const out_data = output_tensor.host(); + + // The index of the largest output for each element is the label. + const auto find_label = [&](long sample, long r, long c) + { + uint16_t label = 0; + float max_value = out_data[tensor_index(output_tensor, sample, 0, r, c)]; + for (long k = 1; k < output_tensor.k(); ++k) + { + const float value = out_data[tensor_index(output_tensor, sample, k, r, c)]; + if (value > max_value) + { + label = static_cast(k); + max_value = value; + } + } + return label; + }; + + for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter) + { + iter->set_size(output_tensor.nr(), output_tensor.nc()); + for (long r = 0; r < output_tensor.nr(); ++r) + { + for (long c = 0; c < output_tensor.nc(); ++c) + { + // The index of the largest output for this element is the label. + iter->operator()(r, c) = find_label(i, r, c); + } + } + } + } + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.k() >= 1); + DLIB_CASSERT(output_tensor.k() < std::numeric_limits::max()); + DLIB_CASSERT(output_tensor.nr() == grad.nr() && + output_tensor.nc() == grad.nc() && + output_tensor.k() == grad.k()); + for (long idx = 0; idx < output_tensor.num_samples(); ++idx) + { + const_label_iterator truth_matrix_ptr = (truth + idx); + DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() && + truth_matrix_ptr->nc() == output_tensor.nc(), + "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", " + "output size = " << output_tensor.nr() << " x " << output_tensor.nc()); + } + + tt::softmax(grad, output_tensor); + + // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output. + const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc()); + double loss = 0; + float* const g = grad.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) + { + for (long r = 0; r < output_tensor.nr(); ++r) + { + for (long c = 0; c < output_tensor.nc(); ++c) + { + const uint16_t y = truth->operator()(r, c); + // The network must produce a number of outputs that is equal to the number + // of labels when using this type of loss. + DLIB_CASSERT(static_cast(y) < output_tensor.k() || y == label_to_ignore, + "y: " << y << ", output_tensor.k(): " << output_tensor.k()); + for (long k = 0; k < output_tensor.k(); ++k) + { + const size_t idx = tensor_index(output_tensor, i, k, r, c); + if (k == y) + { + loss += scale*-safe_log(g[idx]); + g[idx] = scale*(g[idx] - 1); + } + else if (y == label_to_ignore) + { + g[idx] = 0.f; + } + else + { + g[idx] = scale*g[idx]; + } + } + } + } + } + return loss; + } + + friend void serialize(const loss_multiclass_log_per_pixel_& , std::ostream& out) + { + serialize("loss_multiclass_log_per_pixel_", out); + } + + friend void deserialize(loss_multiclass_log_per_pixel_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_multiclass_log_per_pixel_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_per_pixel_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_per_pixel_& ) + { + out << "loss_multiclass_log_per_pixel"; + return out; + } + + friend void to_xml(const loss_multiclass_log_per_pixel_& /*item*/, std::ostream& out) + { + out << ""; + } + + private: + static size_t tensor_index(const tensor& t, long sample, long k, long row, long column) + { + // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38 + return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column; + } + + }; + + template + using loss_multiclass_log_per_pixel = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_multiclass_log_per_pixel_weighted_ + { + public: + + struct weighted_label + { + weighted_label() + {} + + weighted_label(uint16_t label, float weight = 1.f) + : label(label), weight(weight) + {} + + // In semantic segmentation, 65536 classes ought to be enough for anybody. + uint16_t label = 0; + float weight = 1.f; + }; + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + static void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) + { + loss_multiclass_log_per_pixel_::to_label(input_tensor, sub, iter); + } + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.k() >= 1); + DLIB_CASSERT(output_tensor.k() < std::numeric_limits::max()); + DLIB_CASSERT(output_tensor.nr() == grad.nr() && + output_tensor.nc() == grad.nc() && + output_tensor.k() == grad.k()); + for (long idx = 0; idx < output_tensor.num_samples(); ++idx) + { + const_label_iterator truth_matrix_ptr = (truth + idx); + DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() && + truth_matrix_ptr->nc() == output_tensor.nc(), + "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", " + "output size = " << output_tensor.nr() << " x " << output_tensor.nc()); + } + + tt::softmax(grad, output_tensor); + + // The loss we output is the weighted average loss over the mini-batch, and also over each element of the matrix output. + const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc()); + double loss = 0; + float* const g = grad.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) + { + for (long r = 0; r < output_tensor.nr(); ++r) + { + for (long c = 0; c < output_tensor.nc(); ++c) + { + const weighted_label& weighted_label = truth->operator()(r, c); + const uint16_t y = weighted_label.label; + const float weight = weighted_label.weight; + // The network must produce a number of outputs that is equal to the number + // of labels when using this type of loss. + DLIB_CASSERT(static_cast(y) < output_tensor.k() || weight == 0.f, + "y: " << y << ", output_tensor.k(): " << output_tensor.k()); + for (long k = 0; k < output_tensor.k(); ++k) + { + const size_t idx = tensor_index(output_tensor, i, k, r, c); + if (k == y) + { + loss += weight*scale*-safe_log(g[idx]); + g[idx] = weight*scale*(g[idx] - 1); + } + else + { + g[idx] = weight*scale*g[idx]; + } + } + } + } + } + return loss; + } + + friend void serialize(const loss_multiclass_log_per_pixel_weighted_& , std::ostream& out) + { + serialize("loss_multiclass_log_per_pixel_weighted_", out); + } + + friend void deserialize(loss_multiclass_log_per_pixel_weighted_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_multiclass_log_per_pixel_weighted_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_per_pixel_weighted_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_per_pixel_weighted_& ) + { + out << "loss_multiclass_log_per_pixel_weighted"; + return out; + } + + friend void to_xml(const loss_multiclass_log_per_pixel_weighted_& /*item*/, std::ostream& out) + { + out << ""; + } + + private: + static size_t tensor_index(const tensor& t, long sample, long k, long row, long column) + { + // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38 + return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column; + } + + }; + + template + using loss_multiclass_log_per_pixel_weighted = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_mean_squared_per_pixel_ + { + public: + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + + const tensor& output_tensor = sub.get_output(); + + DLIB_CASSERT(output_tensor.k() == 1, "output k = " << output_tensor.k()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter) + { + iter->set_size(output_tensor.nr(), output_tensor.nc()); + for (long r = 0; r < output_tensor.nr(); ++r) + { + for (long c = 0; c < output_tensor.nc(); ++c) + { + iter->operator()(r, c) = out_data[tensor_index(output_tensor, i, 0, r, c)]; + } + } + } + } + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples() % sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.k() >= 1); + DLIB_CASSERT(output_tensor.k() < std::numeric_limits::max()); + DLIB_CASSERT(output_tensor.nr() == grad.nr() && + output_tensor.nc() == grad.nc() && + output_tensor.k() == grad.k()); + for (long idx = 0; idx < output_tensor.num_samples(); ++idx) + { + const_label_iterator truth_matrix_ptr = (truth + idx); + DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() && + truth_matrix_ptr->nc() == output_tensor.nc(), + "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", " + "output size = " << output_tensor.nr() << " x " << output_tensor.nc()); + } + + // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output. + const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc()); + double loss = 0; + float* const g = grad.host(); + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) + { + for (long r = 0; r < output_tensor.nr(); ++r) + { + for (long c = 0; c < output_tensor.nc(); ++c) + { + const float y = truth->operator()(r, c); + const size_t idx = tensor_index(output_tensor, i, 0, r, c); + const float temp1 = y - out_data[idx]; + const float temp2 = scale*temp1; + loss += temp2*temp1; + g[idx] = -temp2; + } + } + } + return loss; + } + + friend void serialize(const loss_mean_squared_per_pixel_& , std::ostream& out) + { + serialize("loss_mean_squared_per_pixel_", out); + } + + friend void deserialize(loss_mean_squared_per_pixel_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_mean_squared_per_pixel_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_per_pixel_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_per_pixel_& ) + { + out << "loss_mean_squared_per_pixel"; + return out; + } + + friend void to_xml(const loss_mean_squared_per_pixel_& /*item*/, std::ostream& out) + { + out << ""; + } + + private: + static size_t tensor_index(const tensor& t, long sample, long k, long row, long column) + { + // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38 + return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column; + } + }; + + template + using loss_mean_squared_per_pixel = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_dot_ + { + public: + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const + { + const tensor& output_tensor = sub.get_output(); + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + for (long i = 0; i < output_tensor.num_samples(); ++i) + *iter++ = trans(rowm(mat(output_tensor),i)); + } + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + + const long network_output_dims = output_tensor.size()/output_tensor.num_samples(); + + + // The loss we output is the average loss over the mini-batch. + const double scale = 1.0/output_tensor.num_samples(); + double loss = 0; + float* g = grad.host(); + const float* out_data = output_tensor.host(); + for (long i = 0; i < output_tensor.num_samples(); ++i) + { + DLIB_CASSERT(truth->size() == network_output_dims, "The network must output a vector with the same dimensionality as the training labels. " + << "\ntruth->size(): " << truth->size() + << "\nnetwork_output_dims: " << network_output_dims); + + const float* t = &(*truth++)(0); + + for (long j = 0; j < network_output_dims; ++j) + { + g[j] = -t[j]*scale; + loss -= out_data[j]*t[j]; + } + + g += network_output_dims; + out_data += network_output_dims; + } + return loss*scale; + } + + friend void serialize(const loss_dot_& , std::ostream& out) + { + serialize("loss_dot_", out); + } + + friend void deserialize(loss_dot_& , std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_dot_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_dot_."); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_dot_& ) + { + out << "loss_dot"; + return out; + } + + friend void to_xml(const loss_dot_& /*item*/, std::ostream& out) + { + out << ""; + } + + }; + + template + using loss_dot = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_LOSS_H_ + diff --git a/ml/dlib/dlib/dnn/loss_abstract.h b/ml/dlib/dlib/dnn/loss_abstract.h new file mode 100644 index 000000000..0dd043677 --- /dev/null +++ b/ml/dlib/dlib/dnn/loss_abstract.h @@ -0,0 +1,1542 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNn_LOSS_ABSTRACT_H_ +#ifdef DLIB_DNn_LOSS_ABSTRACT_H_ + +#include "core_abstract.h" +#include "../image_processing/full_object_detection_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class EXAMPLE_LOSS_LAYER_ + { + /*! + WHAT THIS OBJECT REPRESENTS + A loss layer is the final layer in a deep neural network. It computes the + task loss. That is, it computes a number that tells us how well the + network is performing on some task, such as predicting a binary label. + + You can use one of the loss layers that comes with dlib (defined below). + But importantly, you are able to define your own loss layers to suit your + needs. You do this by creating a class that defines an interface matching + the one described by this EXAMPLE_LOSS_LAYER_ class. Note that there is no + dlib::EXAMPLE_LOSS_LAYER_ type. It is shown here purely to document the + interface that a loss layer must implement. + + A loss layer can optionally provide a to_label() method that converts the + output of a network into a user defined type. If to_label() is not + provided then the operator() methods of add_loss_layer will not be + available, but otherwise everything will function as normal. + + Finally, note that there are two broad flavors of loss layer, supervised + and unsupervised. The EXAMPLE_LOSS_LAYER_ as shown here is a supervised + layer. To make an unsupervised loss you simply leave out the + training_label_type typedef and the truth iterator argument to + compute_loss_value_and_gradient(). + !*/ + + public: + + // In most cases training_label_type and output_label_type will be the same type. + typedef whatever_type_you_use_for_training_labels training_label_type; + typedef whatever_type_you_use_for_outout_labels output_label_type; + + EXAMPLE_LOSS_LAYER_ ( + ); + /*! + ensures + - EXAMPLE_LOSS_LAYER_ objects are default constructable. + !*/ + + EXAMPLE_LOSS_LAYER_ ( + const EXAMPLE_LOSS_LAYER_& item + ); + /*! + ensures + - EXAMPLE_LOSS_LAYER_ objects are copy constructable. + !*/ + + // Implementing to_label() is optional. + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + requires + - SUBNET implements the SUBNET interface defined at the top of + layers_abstract.h. + - input_tensor was given as input to the network sub and the outputs are + now visible in layer(sub).get_output(), for all valid i. + - input_tensor.num_samples() > 0 + - input_tensor.num_samples()%sub.sample_expansion_factor() == 0. + - iter == an iterator pointing to the beginning of a range of + input_tensor.num_samples()/sub.sample_expansion_factor() elements. Moreover, + they must be output_label_type elements. + ensures + - Converts the output of the provided network to output_label_type objects and + stores the results into the range indicated by iter. In particular, for + all valid i, it will be the case that: + *(iter+i/sub.sample_expansion_factor()) is populated based on the output of + sub and corresponds to the ith sample in input_tensor. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + requires + - SUBNET implements the SUBNET interface defined at the top of + layers_abstract.h. + - input_tensor was given as input to the network sub and the outputs are + now visible in layer(sub).get_output(), for all valid i. + - input_tensor.num_samples() > 0 + - input_tensor.num_samples()%sub.sample_expansion_factor() == 0. + - for all valid i: + - layer(sub).get_gradient_input() has the same dimensions as + layer(sub).get_output(). + - layer(sub).get_gradient_input() contains all zeros (i.e. + initially, all input gradients are 0). + - truth == an iterator pointing to the beginning of a range of + input_tensor.num_samples()/sub.sample_expansion_factor() elements. Moreover, + they must be training_label_type elements. + - for all valid i: + - *(truth+i/sub.sample_expansion_factor()) is the label of the ith sample in + input_tensor. + ensures + - This function computes a loss function that describes how well the output + of sub matches the expected labels given by truth. Let's write the loss + function as L(input_tensor, truth, sub). + - Then compute_loss_value_and_gradient() computes the gradient of L() with + respect to the outputs in sub. Specifically, compute_loss_value_and_gradient() + assigns the gradients into sub by performing the following tensor + assignments, for all valid i: + - layer(sub).get_gradient_input() = the gradient of + L(input_tensor,truth,sub) with respect to layer(sub).get_output(). + Note that, since get_gradient_input() is zero initialized, you don't + have to write gradient information to layers that have a zero + loss gradient. + - returns L(input_tensor,truth,sub) + !*/ + }; + + std::ostream& operator<<(std::ostream& out, const EXAMPLE_LOSS_LAYER_& item); + /*! + print a string describing this layer. + !*/ + + void to_xml(const EXAMPLE_LOSS_LAYER_& item, std::ostream& out); + /*! + This function is optional, but required if you want to print your networks with + net_to_xml(). Therefore, to_xml() prints a layer as XML. + !*/ + + void serialize(const EXAMPLE_LOSS_LAYER_& item, std::ostream& out); + void deserialize(EXAMPLE_LOSS_LAYER_& item, std::istream& in); + /*! + provides serialization support + !*/ + + // For each loss layer you define, always define an add_loss_layer template so that + // layers can be easily composed. Moreover, the convention is that the layer class + // ends with an _ while the add_loss_layer template has the same name but without the + // trailing _. + template + using EXAMPLE_LOSS_LAYER = add_loss_layer; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class loss_binary_hinge_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the hinge loss, which is + appropriate for binary classification problems. Therefore, the possible + labels when using this loss are +1 and -1. Moreover, it will cause the + network to produce outputs > 0 when predicting a member of the +1 class and + values < 0 otherwise. + !*/ + public: + + typedef float training_label_type; + typedef float output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the raw score for each classified object. If the score + is > 0 then the classifier is predicting the +1 class, otherwise it is + predicting the -1 class. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - all values pointed to by truth are +1 or -1. + !*/ + + }; + + template + using loss_binary_hinge = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_binary_log_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the log loss, which is + appropriate for binary classification problems. Therefore, the possible + labels when using this loss are +1 and -1. Moreover, it will cause the + network to produce outputs > 0 when predicting a member of the +1 class and + values < 0 otherwise. + + To be more specific, this object contains a sigmoid layer followed by a + cross-entropy layer. + !*/ + public: + + typedef float training_label_type; + typedef float output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the raw score for each classified object. If the score + is > 0 then the classifier is predicting the +1 class, otherwise it is + predicting the -1 class. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - all values pointed to by truth are +1 or -1. + !*/ + + }; + + template + using loss_binary_log = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_multiclass_log_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the multiclass logistic + regression loss (e.g. negative log-likelihood loss), which is appropriate + for multiclass classification problems. This means that the possible + labels when using this loss are integers >= 0. + + Moreover, if after training you were to replace the loss layer of the + network with a softmax layer, the network outputs would give the + probabilities of each class assignment. That is, if you have K classes + then the network should output tensors with the tensor::k()'th dimension + equal to K. Applying softmax to these K values gives the probabilities of + each class. The index into that K dimensional vector with the highest + probability is the predicted class label. + !*/ + + public: + + typedef unsigned long training_label_type; + typedef unsigned long output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the predicted class for each classified object. The number + of possible output classes is sub.get_output().k(). + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - all values pointed to by truth are < sub.get_output().k() + !*/ + + }; + + template + using loss_multiclass_log = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_multimulticlass_log_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements a collection of + multiclass classifiers. An example will make its use clear. So suppose, + for example, that you want to make something that takes a picture of a + vehicle and answers the following questions: + - What type of vehicle is it? A sedan or a truck? + - What color is it? red, green, blue, gray, or black? + You need two separate multi-class classifiers to do this. One to decide + the type of vehicle, and another to decide the color. The + loss_multimulticlass_log_ allows you to pack these two classifiers into one + neural network. This means that when you use the network to process an + image it will output 2 labels for each image, the type label and the color + label. + + To create a loss_multimulticlass_log_ for the above case you would + construct it as follows: + std::map> labels; + labels["type"] = {"sedan", "truck"}; + labels["color"] = {"red", "green", "blue", "gray", "black"}; + loss_multimulticlass_log_ myloss(labels); + Then you could use myloss with a network object and train it to do this + task. More generally, you can use any number of classifiers and labels + when using this object. Finally, each of the classifiers uses a standard + multi-class logistic regression loss. + !*/ + + public: + + loss_multimulticlass_log_( + ); + /*! + ensures + - #number_of_labels() == 0 + - #get_labels().size() == 0 + !*/ + + loss_multimulticlass_log_ ( + const std::map>& labels + ); + /*! + requires + - Each vector in labels must contain at least 2 strings. I.e. each + classifier must have at least two possible labels. + ensures + - #number_of_labels() == the total number of strings in all the + std::vectors in labels. + - #number_of_classifiers() == labels.size() + - #get_labels() == labels + !*/ + + unsigned long number_of_labels( + ) const; + /*! + ensures + - returns the total number of labels known to this loss. This is the count of + all the labels in each classifier. + !*/ + + unsigned long number_of_classifiers( + ) const; + /*! + ensures + - returns the number of classifiers defined by this loss. + !*/ + + std::map> get_labels ( + ) const; + /*! + ensures + - returns the names of the classifiers and labels used by this loss. In + particular, if the returned object is L then: + - L[CLASS] == the set of labels used by the classifier CLASS. + - L.size() == number_of_classifiers() + - The count of strings in the vectors in L == number_of_labels() + !*/ + + class classifier_output + { + /*! + WHAT THIS OBJECT REPRESENTS + This object stores the predictions from one of the classifiers in + loss_multimulticlass_log_. It allows you to find out the most likely + string label predicted by that classifier, as well as get the class + conditional probability of any of the classes in the classifier. + !*/ + + public: + + classifier_output( + ); + /*! + ensures + - #num_classes() == 0 + !*/ + + size_t num_classes( + ) const; + /*! + ensures + - returns the number of possible classes output by this classifier. + !*/ + + double probability_of_class ( + size_t i + ) const; + /*! + requires + - i < num_classes() + ensures + - returns the probability that the true class has a label of label(i). + - The sum of probability_of_class(j) for j in the range [0, num_classes()) is always 1. + !*/ + + const std::string& label( + size_t i + ) const; + /*! + requires + - i < num_classes() + ensures + - returns the string label for the ith class. + !*/ + + operator std::string( + ) const; + /*! + requires + - num_classes() != 0 + ensures + - returns the string label for the most probable class. + !*/ + + friend std::ostream& operator<< (std::ostream& out, const classifier_output& item); + /*! + requires + - num_classes() != 0 + ensures + - prints the most probable class label to out. + !*/ + + }; + + // Both training_label_type and output_label_type should always have sizes equal to + // number_of_classifiers(). That is, the std::map should have an entry for every + // classifier known to this loss. + typedef std::map training_label_type; + typedef std::map output_label_type; + + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - number_of_labels() != 0 + - sub.get_output().k() == number_of_labels() + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - number_of_labels() != 0 + - sub.get_output().k() == number_of_labels() + It should be noted that the last layer in your network should usually + be an fc layer. If so, you can satisfy this requirement of k() being + number_of_labels() by calling set_num_outputs() prior to training your + network like so: + your_network.subnet().layer_details().set_num_outputs(your_network.loss_details().number_of_labels()); + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - All the std::maps pointed to by truth contain entries for all the + classifiers known to this loss. That is, it must be valid to call + truth[i][classifier] for any of the classifiers known to this loss. To + say this another way, all the training samples must contain labels for + each of the classifiers defined by this loss. + + To really belabor this, this also means that truth[i].size() == + get_labels().size() and that both truth[i] and get_labels() have the same + set of key strings. It also means that the value strings in truth[i] + must be strings known to the loss, i.e. they are valid labels according + to get_labels(). + !*/ + }; + + template + using loss_multimulticlass_log = add_loss_layer; + + // Allow comparison between classifier_outputs and std::string to check if the + // predicted class is a particular string. + inline bool operator== (const std::string& lhs, const loss_multimulticlass_log_::classifier_output& rhs) + { return lhs == static_cast(rhs); } + inline bool operator== (const loss_multimulticlass_log_::classifier_output& lhs, const std::string& rhs) + { return rhs == static_cast(lhs); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + enum class use_image_pyramid : uint8_t + { + no, + yes + }; + + struct mmod_options + { + /*! + WHAT THIS OBJECT REPRESENTS + This object contains all the parameters that control the behavior of loss_mmod_. + !*/ + + public: + + struct detector_window_details + { + detector_window_details() = default; + detector_window_details(unsigned long w, unsigned long h) : width(w), height(h) {} + detector_window_details(unsigned long w, unsigned long h, const std::string& l) : width(w), height(h), label(l) {} + + unsigned long width = 0; + unsigned long height = 0; + std::string label; + + friend inline void serialize(const detector_window_details& item, std::ostream& out); + friend inline void deserialize(detector_window_details& item, std::istream& in); + }; + + mmod_options() = default; + + // This kind of object detector is a sliding window detector. The detector_windows + // field determines how many sliding windows we will use and what the shape of each + // window is. It also determines the output label applied to each detection + // identified by each window. Since you will usually use the MMOD loss with an + // image pyramid, the detector sizes also determine the size of the smallest object + // you can detect. + std::vector detector_windows; + + // These parameters control how we penalize different kinds of mistakes. See + // Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046) + // for further details. + double loss_per_false_alarm = 1; + double loss_per_missed_target = 1; + + // A detection must have an intersection-over-union value greater than this for us + // to consider it a match against a ground truth box. + double truth_match_iou_threshold = 0.5; + + // When doing non-max suppression, we use overlaps_nms to decide if a box overlaps + // an already output detection and should therefore be thrown out. + test_box_overlap overlaps_nms = test_box_overlap(0.4); + + // Any mmod_rect in the training data that has its ignore field set to true defines + // an "ignore zone" in an image. Any detection from that area is totally ignored + // by the optimizer. Therefore, this overlaps_ignore field defines how we decide + // if a box falls into an ignore zone. You use these ignore zones if there are + // objects in your dataset that you are unsure if you want to detect or otherwise + // don't care if the detector gets them or not. + test_box_overlap overlaps_ignore; + + // Usually the detector would be scale-invariant, and used with an image pyramid. + // However, sometimes scale-invariance may not be desired. + use_image_pyramid assume_image_pyramid = use_image_pyramid::yes; + + mmod_options ( + const std::vector>& boxes, + const unsigned long target_size, + const unsigned long min_target_size, + const double min_detector_window_overlap_iou = 0.75 + ); + /*! + requires + - 0 < min_target_size <= target_size + - 0.5 < min_detector_window_overlap_iou < 1 + ensures + - use_image_pyramid_ == use_image_pyramid::yes + - This function should be used when scale-invariance is desired, and + input_rgb_image_pyramid is therefore used as the input layer. + - This function tries to automatically set the MMOD options to reasonable + values, assuming you have a training dataset of boxes.size() images, where + the ith image contains objects boxes[i] you want to detect. + - The most important thing this function does is decide what detector + windows should be used. This is done by finding a set of detector + windows that are sized such that: + - When slid over an image pyramid, each box in boxes will have an + intersection-over-union with one of the detector windows of at least + min_detector_window_overlap_iou. That is, we will make sure that + each box in boxes could potentially be detected by one of the + detector windows. This essentially comes down to picking detector + windows with aspect ratios similar to the aspect ratios in boxes. + Note that we also make sure that each box can be detected by a window + with the same label. For example, if all the boxes had the same + aspect ratio but there were 4 different labels used in boxes then + there would be 4 resulting detector windows, one for each label. + - The longest edge of each detector window is target_size pixels in + length, unless the window's shortest side would be less than + min_target_size pixels in length. In this case the shortest side + will be set to min_target_size length, and the other side sized to + preserve the aspect ratio of the window. + This means that target_size and min_target_size control the size of the + detector windows, while the aspect ratios of the detector windows are + automatically determined by the contents of boxes. It should also be + emphasized that the detector isn't going to be able to detect objects + smaller than any of the detector windows. So consider that when setting + these sizes. + - This function will also set the overlaps_nms tester to the most + restrictive tester that doesn't reject anything in boxes. + !*/ + + mmod_options ( + use_image_pyramid use_image_pyramid, + const std::vector>& boxes, + const double min_detector_window_overlap_iou = 0.75 + ); + /*! + requires + - use_image_pyramid == use_image_pyramid::no + - 0.5 < min_detector_window_overlap_iou < 1 + ensures + - This function should be used when scale-invariance is not desired, and + there is no intention to apply an image pyramid. + - This function tries to automatically set the MMOD options to reasonable + values, assuming you have a training dataset of boxes.size() images, where + the ith image contains objects boxes[i] you want to detect. + - The most important thing this function does is decide what detector + windows should be used. This is done by finding a set of detector + windows that are sized such that: + - When slid over an image, each box in boxes will have an + intersection-over-union with one of the detector windows of at least + min_detector_window_overlap_iou. That is, we will make sure that + each box in boxes could potentially be detected by one of the + detector windows. + - This function will also set the overlaps_nms tester to the most + restrictive tester that doesn't reject anything in boxes. + !*/ + }; + + void serialize(const mmod_options& item, std::ostream& out); + void deserialize(mmod_options& item, std::istream& in); + +// ---------------------------------------------------------------------------------------- + + class loss_mmod_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the Max Margin Object + Detection loss defined in the paper: + Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046). + + This means you use this loss if you want to detect the locations of objects + in images. + + It should also be noted that this loss layer requires an input layer that + defines the following functions: + - image_contained_point() + - tensor_space_to_image_space() + - image_space_to_tensor_space() + A reference implementation of them and their definitions can be found in + the input_rgb_image_pyramid object, which is the recommended input layer to + be used with loss_mmod_. + !*/ + + public: + + typedef std::vector training_label_type; + typedef std::vector output_label_type; + + loss_mmod_( + ); + /*! + ensures + - #get_options() == mmod_options() + !*/ + + loss_mmod_( + mmod_options options_ + ); + /*! + ensures + - #get_options() == options_ + !*/ + + const mmod_options& get_options ( + ) const; + /*! + ensures + - returns the options object that defines the general behavior of this loss layer. + !*/ + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter, + double adjust_threshold = 0 + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + Also, the output labels are std::vectors of mmod_rects where, for each mmod_rect R, + we have the following interpretations: + - R.rect == the location of an object in the image. + - R.detection_confidence the score for the object, the bigger the score the + more confident the detector is that an object is really there. Only + objects with a detection_confidence > adjust_threshold are output. So if + you want to output more objects (that are also of less confidence) you + can call to_label() with a smaller value of adjust_threshold. + - R.ignore == false (this value is unused by to_label()). + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + Also, the loss value returned is roughly equal to the average number of + mistakes made per image. This is the sum of false alarms and missed + detections, weighted by the loss weights for these types of mistakes specified + in the mmod_options. + !*/ + }; + + template + using loss_mmod = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_metric_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it allows you to learn to map objects + into a vector space where objects sharing the same class label are close to + each other, while objects with different labels are far apart. + + To be specific, it optimizes the following loss function which considers + all pairs of objects in a mini-batch and computes a different loss depending + on their respective class labels. So if objects A1 and A2 in a mini-batch + share the same class label then their contribution to the loss is: + max(0, length(A1-A2)-get_distance_threshold() + get_margin()) + + While if A1 and B1 have different class labels then their contribution to + the loss function is: + max(0, get_distance_threshold()-length(A1-B1) + get_margin()) + + Therefore, this loss layer optimizes a version of the hinge loss. + Moreover, the loss is trying to make sure that all objects with the same + label are within get_distance_threshold() distance of each other. + Conversely, if two objects have different labels then they should be more + than get_distance_threshold() distance from each other in the learned + embedding. So this loss function gives you a natural decision boundary for + deciding if two objects are from the same class. + + Finally, the loss balances the number of negative pairs relative to the + number of positive pairs. Therefore, if there are N pairs that share the + same identity in a mini-batch then the algorithm will only include the N + worst non-matching pairs in the loss. That is, the algorithm performs hard + negative mining on the non-matching pairs. This is important since there + are in general way more non-matching pairs than matching pairs. So to + avoid imbalance in the loss this kind of hard negative mining is useful. + !*/ + public: + + typedef unsigned long training_label_type; + typedef matrix output_label_type; + + loss_metric_( + ); + /*! + ensures + - #get_margin() == 0.04 + - #get_distance_threshold() == 0.6 + !*/ + + loss_metric_( + float margin, + float dist_thresh + ); + /*! + requires + - margin > 0 + - dist_thresh > 0 + ensures + - #get_margin() == margin + - #get_distance_threshold() == dist_thresh + !*/ + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + This loss expects the network to produce a single vector (per sample) as + output. This vector is the learned embedding. Therefore, to_label() just + copies these output vectors from the network into the output label_iterators + given to this function, one for each sample in the input_tensor. + !*/ + + float get_margin() const; + /*! + ensures + - returns the margin value used by the loss function. See the discussion + in WHAT THIS OBJECT REPRESENTS for details. + !*/ + + float get_distance_threshold() const; + /*! + ensures + - returns the distance threshold value used by the loss function. See the discussion + in WHAT THIS OBJECT REPRESENTS for details. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + !*/ + + }; + + template + using loss_metric = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_ranking_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the pairwise ranking + loss described in the paper: + Optimizing Search Engines using Clickthrough Data by Thorsten Joachims + + This is the same loss function used by the dlib::svm_rank_trainer object. + Therefore, it is generally appropriate when you have a two class problem + and you want to learn a function that ranks one class before the other. + + So for example, suppose you have two classes of data. Objects of type A + and objects of type B. Moreover, suppose that you want to sort the objects + so that A objects always come before B objects. This loss will help you + learn a function that assigns a real number to each object such that A + objects get a larger number assigned to them than B objects. This lets you + then sort the objects according to the output of the neural network and + obtain the desired result of having A objects come before B objects. + + The training labels should be positive values for objects you want to get + high scores and negative for objects that should get small scores. So + relative to our A/B example, you would give A objects labels of +1 and B + objects labels of -1. This should cause the learned network to give A + objects large positive values and B objects negative values. + + + Finally, the specific loss function is: + For all pairs of positive vs negative training examples A_i and B_j respectively: + sum_ij: max(0, B_i - A_j + margin_ij) + where margin_ij = the label for A_j minus the label for B_i. If you + always use +1 and -1 labels then the margin is always 2. However, this + formulation allows you to give certain training samples different weight by + adjusting the training labels appropriately. + !*/ + + public: + + typedef float training_label_type; + typedef float output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the predicted ranking score. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + !*/ + + }; + + template + using loss_ranking = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_epsilon_insensitive_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the epsilon insensitive + loss, which is appropriate for regression problems. In particular, this + loss function is; + loss(y1,y2) = abs(y1-y2)= 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - returns the epsilon value used in the loss function. Mistakes in the + regressor smaller than get_epsilon() are ignored by the loss function. + !*/ + + void set_epsilon( + double eps + ); + /*! + requires + - eps >= 0 + ensures + - #get_epsilon() == eps + !*/ + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the predicted continuous variable. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + !*/ + + }; + + template + using loss_epsilon_insensitive = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_mean_squared_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the mean squared loss, which is + appropriate for regression problems. + !*/ + public: + + typedef float training_label_type; + typedef float output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the predicted continuous variable. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + !*/ + + }; + + template + using loss_mean_squared = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_mean_squared_multioutput_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the mean squared loss, + which is appropriate for regression problems. It is basically just like + loss_mean_squared_ except that it lets you define multiple outputs instead + of just 1. + !*/ + public: + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the predicted continuous variable. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().nr() == 1 + - sub.get_output().nc() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - (*(truth + idx)).nc() == 1 for all idx such that 0 <= idx < sub.get_output().num_samples() + - (*(truth + idx)).nr() == sub.get_output().k() for all idx such that 0 <= idx < sub.get_output().num_samples() + !*/ + + }; + + template + using loss_mean_squared_multioutput = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_multiclass_log_per_pixel_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the multiclass logistic + regression loss (e.g. negative log-likelihood loss), which is appropriate + for multiclass classification problems. It is basically just like + loss_multiclass_log_ except that it lets you define matrix outputs instead + of scalar outputs. It should be useful, for example, in semantic + segmentation where we want to classify each pixel of an image. + !*/ + public: + + // In semantic segmentation, if you don't know the ground-truth of some pixel, + // set the label of that pixel to this value. When you do so, the pixel will be + // ignored when computing gradients. + static const uint16_t label_to_ignore = std::numeric_limits::max(); + + // In semantic segmentation, 65535 classes ought to be enough for anybody. + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the predicted class for each classified element. The number + of possible output classes is sub.get_output().k(). + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - all values pointed to by truth are < sub.get_output().k() or are equal to label_to_ignore. + !*/ + + }; + + template + using loss_multiclass_log_per_pixel = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_multiclass_log_per_pixel_weighted_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the multiclass logistic + regression loss (e.g. negative log-likelihood loss), which is appropriate + for multiclass classification problems. It is basically just like + loss_multiclass_log_per_pixel_ except that it lets you define per-pixel + weights, which may be useful e.g. if you want to emphasize rare classes + while training. (If the classification problem is difficult, a flat weight + structure may lead the network to always predict the most common label, in + particular if the degree of imbalance is high. To emphasize a certain + class or classes, simply increase the weights of the corresponding pixels, + relative to the weights of the other pixels.) + + Note that if you set the weight to 0 whenever a pixel's label is equal to + loss_multiclass_log_per_pixel_::label_to_ignore, and to 1 otherwise, then + you essentially get loss_multiclass_log_per_pixel_ as a special case. + !*/ + public: + + struct weighted_label + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents the truth label of a single pixel, together with + an associated weight (the higher the weight, the more emphasis the + corresponding pixel is given during the training). + !*/ + + weighted_label(); + weighted_label(uint16_t label, float weight = 1.f); + + // The ground-truth label. In semantic segmentation, 65536 classes ought to be + // enough for anybody. + uint16_t label = 0; + + // The weight of the corresponding pixel. + float weight = 1.f; + }; + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output label is the predicted class for each classified element. The number + of possible output classes is sub.get_output().k(). + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - all labels pointed to by truth are < sub.get_output().k(), or the corresponding weight + is zero. + !*/ + + }; + + template + using loss_multiclass_log_per_pixel_weighted = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_mean_squared_per_pixel_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, it implements the mean squared loss, + which is appropriate for regression problems. It is basically just like + loss_mean_squared_multioutput_ except that it lets you define matrix or + image outputs, instead of vector. + !*/ + public: + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output labels are the predicted continuous variables. + !*/ + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().k() == 1 + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - for all idx such that 0 <= idx < sub.get_output().num_samples(): + - sub.get_output().nr() == (*(truth + idx)).nr() + - sub.get_output().nc() == (*(truth + idx)).nc() + !*/ + }; + + template + using loss_mean_squared_per_pixel = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + + class loss_dot_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the loss layer interface defined above by + EXAMPLE_LOSS_LAYER_. In particular, selecting this loss means you want + maximize the dot product between the output of a network and a set of + training vectors. The loss is therefore the negative dot product. To be + very specific, if X is the output vector of a network and Y is a training + label (also a vector), then the loss for this training sample is: -dot(X,Y) + !*/ + + public: + + typedef matrix training_label_type; + typedef matrix output_label_type; + + template < + typename SUB_TYPE, + typename label_iterator + > + void to_label ( + const tensor& input_tensor, + const SUB_TYPE& sub, + label_iterator iter + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except + it has the additional calling requirements that: + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + and the output labels are simply the final network outputs stuffed into a + vector. To be very specific, the output is the following for all valid i: + *(iter+i) == trans(rowm(mat(sub.get_output()),i)) + !*/ + + + template < + typename const_label_iterator, + typename SUBNET + > + double compute_loss_value_and_gradient ( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const; + /*! + This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient() + except it has the additional calling requirements that: + - sub.get_output().num_samples() == input_tensor.num_samples() + - sub.sample_expansion_factor() == 1 + - Let NETWORK_OUTPUT_DIMS == sub.get_output().size()/sub.get_output().num_samples() + - for all idx such that 0 <= idx < sub.get_output().num_samples(): + - NETWORK_OUTPUT_DIMS == (*(truth + idx)).size() + !*/ + }; + + template + using loss_dot = add_loss_layer; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_LOSS_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/dnn/solvers.h b/ml/dlib/dlib/dnn/solvers.h new file mode 100644 index 000000000..204541a7e --- /dev/null +++ b/ml/dlib/dlib/dnn/solvers.h @@ -0,0 +1,405 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_SOLVERS_H_ +#define DLIB_DNn_SOLVERS_H_ + +#include "solvers_abstract.h" +#include "tensor.h" +#include +#include "layers.h" + +namespace dlib +{ + class sgd + { + public: + + explicit sgd( + float weight_decay_, + float momentum_ = 0.9 + ) + { + weight_decay = weight_decay_; + momentum = momentum_; + } + + sgd( + ) : sgd(0.0005, 0.9) + { + } + + float get_momentum ( + ) const { return momentum; } + + float get_weight_decay ( + ) const { return weight_decay; } + + template + const tensor& operator() ( + const float learning_rate, + const layer_type& l, + const tensor& params_grad + ) + { + const tensor& params = l.get_layer_params(); + + DLIB_CASSERT(params.size() != 0); + if (v.size() == 0) + { + v.copy_size(params_grad); + v = 0; + } + + const double lr = learning_rate*get_learning_rate_multiplier(l); + const double wd = weight_decay*get_weight_decay_multiplier(l); + + //perform: v = momentum*mat(v) - wd*lr*mat(params) - lr*mat(params_grad); + tt::affine_transform(v, v, params, params_grad, momentum, -wd*lr, -lr); + + return v; + } + + template + const tensor& operator() ( + const float learning_rate, + const fc_& l, + const tensor& params_grad + ) + { + update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.get_num_outputs()); + return v; + } + + template < + long _num_filters, + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y, + int _padding_x + > + const tensor& operator() ( + const float learning_rate, + const con_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l, + const tensor& params_grad + ) + { + update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.num_filters()); + return v; + } + + template < + long _num_filters, + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y, + int _padding_x + > + const tensor& operator() ( + const float learning_rate, + const cont_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l, + const tensor& params_grad + ) + { + update_considering_bias(learning_rate, l, params_grad, params_grad.size()-l.num_filters()); + return v; + } + + template < layer_mode mode > + const tensor& operator() ( + const float learning_rate, + const bn_& l, + const tensor& params_grad + ) + { + update_considering_bias(learning_rate, l, params_grad, params_grad.size()/2); + return v; + } + + friend void serialize(const sgd& item, std::ostream& out) + { + serialize("sgd2", out); + serialize(item.v, out); + serialize(item.weight_decay, out); + serialize(item.momentum, out); + } + + friend void deserialize(sgd& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "sgd2") + throw serialization_error("Unexpected version found while deserializing dlib::sgd."); + deserialize(item.v, in); + deserialize(item.weight_decay, in); + deserialize(item.momentum, in); + } + + friend std::ostream& operator<< (std::ostream& out, const sgd& item) + { + out << "sgd: weight_decay="< + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class EXAMPLE_SOLVER + { + /*! + WHAT THIS OBJECT REPRESENTS + A solver defines the parameter update rule for a single layer in a deep + neural network. It takes a parameter gradient vector and the layer's + parameters and tells you how the parameters should be updated. + Importantly, each solver instance is used with only one layer in a network. + This allows us to define solvers that have per layer state, for example, a + solver may keep a momentum term and apply it to its update rule. + + Note that there is no dlib::EXAMPLE_SOLVER type. It is shown here purely + to document the interface a solver object must implement. + !*/ + + public: + + EXAMPLE_SOLVER( + ); + + template + const tensor& operator() ( + const float learning_rate, + const layer_type& l, + const tensor& params_grad + ) + /*! + requires + - l.get_layer_params().size() != 0 + - have_same_dimensions(l.get_layer_params(), params_grad) == true. + - When this function is invoked on a particular solver instance, it is + always supplied with the same layer instance, l. That is, the solver is + allowed to remember things from one invocation to another and to assume + that it is being serially applied to optimize the same layer's + parameters. + ensures + - Returns a step vector V that is intended to be used to update the + parameters by adding V to l.get_layer_params(). + - This function will use the given "learning rate" to compute V. How the + learning rate is used is solver dependent. But in general the learning + rate should be used to select the step size, i.e. to somehow determine + the magnitude of V. + !*/ + }; + + void serialize(const EXAMPLE_SOLVER& item, std::ostream& out); + void deserialize(EXAMPLE_SOLVER& item, std::istream& in); + /*! + provides serialization support + !*/ + + std::ostream& operator<< (std::ostream& out, const EXAMPLE_SOLVER& item); + /*! + Prints the solver's name and parameters to out. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class sgd + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the EXAMPLE_SOLVER interface defined above. It is a + basic stochastic gradient descent solver which uses momentum and weight + decay. In particular, it computes the update vector V according to: + V = momentum*V - weight_decay*learning_rate*l.get_layer_params() - learning_rate*params_grad; + Here V is a momentum term that is remembered by the solver from one + invocation of operator() to the next. + + + Note that the actual learning rate and weight decay used by the solver are + multiplied by the per layer multipliers. That is, the solver will call + get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and + multiply these values with the nominal learning rate and weight decay, + respectively, to determine the values it will use during each step. It is + also overloaded to allow additional learning rate multipliers to be applied + to fc_ and con_ bias parameters. + !*/ + public: + + sgd( + ); + /*! + ensures + - #get_weight_decay() == 0.0005 + - #get_momentum() == 0.9 + !*/ + + explicit sgd( + float weight_decay, + float momentum = 0.9 + ); + /*! + requires + - weight_decay >= 0 + - momentum >= 0 + ensures + - #get_weight_decay() == weight_decay + - #get_momentum() == momentum + !*/ + + float get_weight_decay () const; + float get_momentum () const; + }; + + void serialize(const sgd& item, std::ostream& out); + void deserialize(sgd& item, std::istream& in); + /*! + provides serialization support + !*/ + + std::ostream& operator<< (std::ostream& out, const sgd& item); + /*! + Prints the solver's name and parameters to out. + !*/ + +// ---------------------------------------------------------------------------------------- + + class adam + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the EXAMPLE_SOLVER interface defined above. In + particular, it implements the ADAM parameter update method described in the + paper: + Kingma, Diederik P., and Jimmy Ba Adam. "A method for stochastic + optimization." International Conference on Learning Representation. 2015. + + + Note that the actual learning rate and weight decay used by the solver are + multiplied by the per layer multipliers. That is, the solver will call + get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and + multiply these values with the nominal learning rate and weight decay, + respectively, to determine the values it will use during each step. It is + also overloaded to allow additional learning rate multipliers to be applied + to fc_ and con_ bias parameters. + !*/ + + public: + + adam( + ); + /*! + ensures + - #get_weight_decay() == 0.0005 + - #get_momentum1() == 0.9 + - #get_momentum2() == 0.999 + !*/ + + adam( + float weight_decay, + float momentum1, + float momentum2 + ); + /*! + requires + - weight_decay >= 0 + - 0 <= momentum1 < 1 + - 0 <= momentum2 < 1 + ensures + - #get_weight_decay() == weight_decay + - #get_momentum1() == momentum1 + - #get_momentum2() == momentum2 + !*/ + + float get_weight_decay () const; + float get_momentum1 () const; + float get_momentum2 () const; + }; + + void serialize(const adam& item, std::ostream& out); + void deserialize(adam& item, std::istream& in); + /*! + provides serialization support + !*/ + + std::ostream& operator<< (std::ostream& out, const adam& item); + /*! + Prints the solver's name and parameters to out. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_SOLVERS_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/dnn/tensor.h b/ml/dlib/dlib/dnn/tensor.h new file mode 100644 index 000000000..8039fe666 --- /dev/null +++ b/ml/dlib/dlib/dnn/tensor.h @@ -0,0 +1,686 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_TENSOR_H_ +#define DLIB_DNn_TENSOR_H_ + +#include "tensor_abstract.h" +#include +#include "../matrix.h" +#include "cudnn_dlibapi.h" +#include "gpu_data.h" +#include "../byte_orderer.h" +#include +#include "../any.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class tensor; + namespace cuda + { + void set_tensor ( + tensor& t, + float value + ); + + void scale_tensor ( + tensor& t, + float value + ); + } + +// ---------------------------------------------------------------------------------------- + + class tensor + { + public: + + tensor ( + ) : + m_n(0), m_k(0), m_nr(0), m_nc(0), m_size(0) + { + } + + virtual ~tensor() {} + + long long num_samples() const { return m_n; } + long long k() const { return m_k; } + long long nr() const { return m_nr; } + long long nc() const { return m_nc; } + size_t size() const { return m_size; } + + typedef float* iterator; + typedef const float* const_iterator; + iterator begin() { return host(); } + const_iterator begin() const { return host(); } + iterator end() { return host()+size(); } + const_iterator end() const { return host()+size(); } + + void async_copy_to_device() const + { + data().async_copy_to_device(); + } + + virtual const float* host() const = 0; + virtual float* host() = 0; + virtual float* host_write_only() = 0; + virtual const float* device() const = 0; + virtual float* device() = 0; + virtual float* device_write_only() = 0; + + virtual const any& annotation() const = 0; + virtual any& annotation() = 0; + + int device_id() const { return data().device_id(); } + + tensor& operator= (float val) + { +#ifdef DLIB_USE_CUDA + // If you are using CUDA then presumably you will be mostly using tensors on + // the GPU. So unless you seem to be actively working with the host side's + // data then we do this initialization on the device side since this avoids a + // host to device transfer that would likely immediately follow. + if (data().device_ready()) + { + cuda::set_tensor(*this, val); + return *this; + } +#endif + auto d = host_write_only(); + for (size_t i = 0; i < size(); ++i) + d[i] = val; + + return *this; + } + + tensor& operator*= (float val) + { +#ifdef DLIB_USE_CUDA + cuda::scale_tensor(*this, val); + return *this; +#else + for (auto& d : *this) + d *= val; + + return *this; +#endif + } + + tensor& operator/= (float val) + { + *this *= 1.0/val; + return *this; + } + + template + tensor& operator= (const matrix_exp& item) + { + DLIB_CASSERT(num_samples() == item.nr() && + nr()*nc()*k() == item.nc()); + static_assert((is_same_type::value == true), + "To assign a matrix to a tensor the matrix must contain float values"); + + set_ptrm(host_write_only(), m_n, m_nr*m_nc*m_k) = item; + return *this; + } + + template + tensor& operator+= (const matrix_exp& item) + { + DLIB_CASSERT(num_samples() == item.nr() && + nr()*nc()*k() == item.nc()); + static_assert((is_same_type::value == true), + "To assign a matrix to a tensor the matrix must contain float values"); + set_ptrm(host(), m_n, m_nr*m_nc*m_k) += item; + return *this; + } + + template + tensor& operator-= (const matrix_exp& item) + { + DLIB_CASSERT(num_samples() == item.nr() && + nr()*nc()*k() == item.nc()); + static_assert((is_same_type::value == true), + "To assign a matrix to a tensor the matrix must contain float values"); + set_ptrm(host(), m_n, m_nr*m_nc*m_k) -= item; + return *this; + } + + template + void set_sample ( + unsigned long long idx, + const matrix_exp& item + ) + { + DLIB_CASSERT(idx < (unsigned long long)num_samples()); + DLIB_CASSERT(item.size() == nr()*nc()*k()); + static_assert((is_same_type::value == true), + "To assign a matrix to a tensor the matrix must contain float values"); + set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) = item; + } + + + template + void add_to_sample ( + unsigned long long idx, + const matrix_exp& item + ) + { + DLIB_CASSERT(idx < (unsigned long long)num_samples()); + DLIB_CASSERT(item.size() == nr()*nc()*k()); + static_assert((is_same_type::value == true), + "To assign a matrix to a tensor the matrix must contain float values"); + set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) += item; + } + + +#ifdef DLIB_USE_CUDA + virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor ( + ) const = 0; +#endif + + friend void memcpy ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(dest.size() == src.size()); + memcpy(dest.data(), dest.get_alias_offset(), + src.data(), src.get_alias_offset(), + src.size()); + } + + + protected: + + friend class alias_tensor; + + virtual gpu_data& data() = 0; + virtual const gpu_data& data() const = 0; + virtual size_t get_alias_offset() const { return 0; } // needed by alias_tensor. + + long long m_n; + long long m_k; + long long m_nr; + long long m_nc; + long long m_size; // always equal to m_n*m_k*m_nr*m_nc + }; + +// ---------------------------------------------------------------------------------------- + + inline bool is_vector ( + const tensor& t + ) + { + return t.size() == (size_t)t.num_samples() || + t.size() == (size_t)t.k() || + t.size() == (size_t)t.nr() || + t.size() == (size_t)t.nc(); + } + +// ---------------------------------------------------------------------------------------- + + inline const matrix_op > mat ( + const tensor& t, + long long nr, + long long nc + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0 , + "\tconst matrix_exp mat(tensor, nr, nc)" + << "\n\t nr and nc must be >= 0" + << "\n\t nr: " << nr + << "\n\t nc: " << nc + ); + DLIB_ASSERT(nr*nc == (long long)t.size() , + "\tconst matrix_exp mat(tensor, nr, nc)" + << "\n\t The sizes don't match up." + << "\n\t nr*nc: " << nr*nc + << "\n\t t.size(): " << t.size() + ); + typedef op_pointer_to_mat op; + return matrix_op(op(t.host(),nr,nc)); + } + + inline const matrix_op > mat ( + const tensor& t + ) + { + if (t.size() != 0) + return mat(t, t.num_samples(), t.size()/t.num_samples()); + else + return mat((float*)0,0,0); + } + + inline const matrix_op > image_plane ( + const tensor& t, + long long sample = 0, + long long k = 0 + ) + { + DLIB_ASSERT(0 <= sample && sample < t.num_samples() && + 0 <= k && k < t.k() && + t.size() != 0, + "\tconst matrix_exp image_plane(tensor,sample,k)" + << "\n\t Invalid arguments were given to this function." + << "\n\t sample: " << sample + << "\n\t k: " << k + << "\n\t t.num_samples(): " << t.num_samples() + << "\n\t t.k(): " << t.k() + << "\n\t t.size(): " << t.size() + ); + + + typedef op_pointer_to_mat op; + return matrix_op(op(t.host() + ((sample*t.k() + k)*t.nr())*t.nc(), + t.nr(), + t.nc())); + } + +// ---------------------------------------------------------------------------------------- + + inline bool have_same_dimensions ( + const tensor& a, + const tensor& b + ) + { + return a.num_samples() == b.num_samples() && + a.k() == b.k() && + a.nr() == b.nr() && + a.nc() == b.nc(); + } + +// ---------------------------------------------------------------------------------------- + + class resizable_tensor : public tensor + { + public: + resizable_tensor( + ) + {} + + template + resizable_tensor( + const matrix_exp& item + ) + { + set_size(item.nr(), item.nc()); + *this = item; + } + + explicit resizable_tensor( + long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 + ) + { + DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0); + + set_size(n_,k_,nr_,nc_); + } + + resizable_tensor(const resizable_tensor& item) : _annotation(item.annotation()) + { + copy_size(item); + memcpy(*this, item); + } + resizable_tensor(const tensor& item) : _annotation(item.annotation()) + { + copy_size(item); + memcpy(*this, item); + } + + resizable_tensor(resizable_tensor&& item) { swap(item); } + resizable_tensor& operator=(resizable_tensor&& item) { swap(item); return *this; } + + virtual const float* host() const { return data_instance.host(); } + virtual float* host() { return data_instance.host(); } + virtual float* host_write_only() { return data_instance.host_write_only(); } + virtual const float* device() const { return data_instance.device(); } + virtual float* device() { return data_instance.device(); } + virtual float* device_write_only() { return data_instance.device_write_only(); } + + virtual const any& annotation() const { return _annotation; } + virtual any& annotation() { return _annotation; } + + void clear( + ) + { + set_size(0,0,0,0); + _annotation.clear(); + // free underlying memory + data_instance.set_size(0); + } + + void copy_size ( + const tensor& item + ) + { + set_size(item.num_samples(), item.k(), item.nr(), item.nc()); + } + + resizable_tensor& operator= (float val) + { + tensor::operator=(val); + return *this; + } + + template + resizable_tensor& operator= ( + const matrix_exp& item + ) + { + if (!(num_samples() == item.nr() && k()*nr()*nc() == item.nc())) + set_size(item.nr(), item.nc()); + tensor::operator=(item); + return *this; + } + + void set_size( + long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 + ) + { + DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0); + + m_n = n_; + m_k = k_; + m_nr = nr_; + m_nc = nc_; + m_size = n_*k_*nr_*nc_; + if ((long long)data_instance.size() < m_size) + data_instance.set_size(m_size); +#ifdef DLIB_USE_CUDA + cudnn_descriptor.set_size(m_n,m_k,m_nr,m_nc); +#endif + } + + + resizable_tensor& operator= (const resizable_tensor& item) + { + resizable_tensor temp(item); + temp.swap(*this); + return *this; + } + + resizable_tensor& operator= (const tensor& item) + { + resizable_tensor temp(item); + temp.swap(*this); + return *this; + } + + + void swap(resizable_tensor& item) + { + std::swap(m_n, item.m_n); + std::swap(m_k, item.m_k); + std::swap(m_nr, item.m_nr); + std::swap(m_nc, item.m_nc); + std::swap(m_size, item.m_size); + std::swap(data_instance, item.data_instance); + std::swap(_annotation, item._annotation); +#ifdef DLIB_USE_CUDA + std::swap(cudnn_descriptor, item.cudnn_descriptor); +#endif + } + +#ifdef DLIB_USE_CUDA + virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor ( + ) const { return cudnn_descriptor; } +#endif + + private: + +#ifdef DLIB_USE_CUDA + cuda::tensor_descriptor cudnn_descriptor; +#endif + + gpu_data data_instance; + any _annotation; + virtual gpu_data& data() { return data_instance; } + virtual const gpu_data& data() const { return data_instance; } + }; + + inline void serialize(const tensor& item, std::ostream& out) + { + int version = 2; + serialize(version, out); + serialize(item.num_samples(), out); + serialize(item.k(), out); + serialize(item.nr(), out); + serialize(item.nc(), out); + byte_orderer bo; + auto sbuf = out.rdbuf(); + for (auto d : item) + { + // Write out our data as 4byte little endian IEEE floats rather than using + // dlib's default float serialization. We do this because it will result in + // more compact outputs. It's slightly less portable but it seems doubtful + // that any CUDA enabled platform isn't going to use IEEE floats. But if one + // does we can just update the serialization code here to handle it if such a + // platform is encountered. + bo.host_to_little(d); + static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats"); + sbuf->sputn((char*)&d, sizeof(d)); + } + } + + inline void deserialize(resizable_tensor& item, std::istream& in) + { + int version; + deserialize(version, in); + if (version != 2) + throw serialization_error("Unexpected version found while deserializing dlib::resizable_tensor."); + + long long num_samples=0, k=0, nr=0, nc=0; + deserialize(num_samples, in); + deserialize(k, in); + deserialize(nr, in); + deserialize(nc, in); + item.set_size(num_samples, k, nr, nc); + byte_orderer bo; + auto sbuf = in.rdbuf(); + for (auto& d : item) + { + static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats"); + if (sbuf->sgetn((char*)&d,sizeof(d)) != sizeof(d)) + { + in.setstate(std::ios::badbit); + throw serialization_error("Error reading data while deserializing dlib::resizable_tensor."); + } + bo.little_to_host(d); + } + } + +// ---------------------------------------------------------------------------------------- + + inline double dot( + const tensor& a, + const tensor& b + ) + { + DLIB_CASSERT(a.size() == b.size()); + const float* da = a.host(); + const float* db = b.host(); + double sum = 0; + for (size_t i = 0; i < a.size(); ++i) + sum += da[i]*db[i]; + return sum; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class alias_tensor_instance : public tensor + { + alias_tensor_instance( + ) : data_instance(0), _annotation(0), data_offset(0) {} + + public: + friend class alias_tensor; + friend class alias_tensor_const_instance; + + alias_tensor_instance& operator= (float val) + { + tensor::operator=(val); + return *this; + } + + template + alias_tensor_instance& operator= (const matrix_exp& item) + { + tensor::operator=(item); + return *this; + } + + virtual const float* host() const { return data_instance->host()+data_offset; } + virtual float* host() { return data_instance->host()+data_offset; } + virtual float* host_write_only() { return data_instance->host()+data_offset; } + virtual const float* device() const { return data_instance->device()+data_offset; } + virtual float* device() { return data_instance->device()+data_offset; } + virtual float* device_write_only() { return data_instance->device()+data_offset; } + + virtual const any& annotation() const { return *_annotation; } + virtual any& annotation() { return *_annotation; } + +#ifdef DLIB_USE_CUDA + virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor ( + ) const { return *cudnn_descriptor; } +#endif + private: + + virtual size_t get_alias_offset() const { return data_offset; } + +#ifdef DLIB_USE_CUDA + std::shared_ptr cudnn_descriptor; +#endif + gpu_data* data_instance; + any* _annotation; + size_t data_offset; + virtual gpu_data& data() { return *data_instance; } + virtual const gpu_data& data() const { return *data_instance; } + }; + +// ---------------------------------------------------------------------------------------- + + class alias_tensor_const_instance + { + public: + const tensor& get() const { return inst; } + operator const tensor& () { return inst; } + + alias_tensor_const_instance(const alias_tensor_instance& item) : inst(item) {} + + private: + alias_tensor_instance inst; + + friend class alias_tensor; + alias_tensor_const_instance() {} + }; + +// ---------------------------------------------------------------------------------------- + + class alias_tensor + { + public: + + alias_tensor ( + ) {} + + alias_tensor ( + long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 + ) + { + DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0); + + inst.m_n = n_; + inst.m_k = k_; + inst.m_nr = nr_; + inst.m_nc = nc_; + inst.m_size = n_*k_*nr_*nc_; + } + + long long num_samples( + ) const { return inst.m_n; } + + long long k( + ) const { return inst.m_k; } + + long long nr( + ) const { return inst.m_nr; } + + long long nc( + ) const { return inst.m_nc; } + + size_t size( + ) const { return inst.m_size; } + + alias_tensor_instance operator() ( + tensor& t, + size_t offset = 0 + ) const + { + DLIB_CASSERT(offset+size() <= t.size(), + "offset: "<(); + inst.cudnn_descriptor->set_size(inst.m_n, inst.m_k, inst.m_nr, inst.m_nc); + } +#endif + inst.data_instance = &t.data(); + inst._annotation = &t.annotation(); + // Note that t might already be an aliasing tensor so we need to take that into + // account. + inst.data_offset = t.get_alias_offset()+offset; + return inst; + } + + alias_tensor_const_instance operator() ( + const tensor& t, + size_t offset = 0 + ) const + { + alias_tensor_const_instance temp; + temp.inst = (*this)(const_cast(t),offset); + return temp; + } + + private: + mutable alias_tensor_instance inst; + }; + + inline void serialize(const alias_tensor& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.num_samples(), out); + serialize(item.k(), out); + serialize(item.nr(), out); + serialize(item.nc(), out); + } + + inline void deserialize(alias_tensor& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::alias_tensor."); + long long num_samples, k, nr, nc; + deserialize(num_samples, in); + deserialize(k, in); + deserialize(nr, in); + deserialize(nc, in); + item = alias_tensor(num_samples, k, nr, nc); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_TENSOR_H_ + diff --git a/ml/dlib/dlib/dnn/tensor_abstract.h b/ml/dlib/dlib/dnn/tensor_abstract.h new file mode 100644 index 000000000..73a9fff77 --- /dev/null +++ b/ml/dlib/dlib/dnn/tensor_abstract.h @@ -0,0 +1,727 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNn_TENSOR_ABSTRACT_H_ +#ifdef DLIB_DNn_TENSOR_ABSTRACT_H_ + +#include "../matrix.h" +#include "../any/any_abstract.h" + +namespace dlib +{ +// ---------------------------------------------------------------------------------------- + + class tensor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a 4D array of float values, all stored contiguously + in memory. Importantly, it keeps two copies of the floats, one on the host + CPU side and another on the GPU device side. It automatically performs the + necessary host/device transfers to keep these two copies of the data in + sync. + + All transfers to the device happen asynchronously with respect to the + default CUDA stream so that CUDA kernel computations can overlap with data + transfers. However, any transfers from the device to the host happen + synchronously in the default CUDA stream. Therefore, you should perform + all your CUDA kernel launches on the default stream so that transfers back + to the host do not happen before the relevant computations have completed. + + If DLIB_USE_CUDA is not #defined then this object will not use CUDA at all. + Instead, it will simply store one host side memory block of floats. + + Finally, the convention in dlib code is to interpret the tensor as a set of + num_samples() 3D arrays, each of dimension k() by nr() by nc(). Also, + while this class does not specify a memory layout, the convention is to + assume that indexing into an element at coordinates (sample,k,r,c) can be + accomplished via: + host()[((sample*t.k() + k)*t.nr() + r)*t.nc() + c] + + THREAD SAFETY + Instances of this object are not thread-safe. So don't touch one from + multiple threads at the same time. + !*/ + + public: + + virtual ~tensor(); + + long long num_samples( + ) const; + /*! + ensures + - returns the number of 3D arrays of dimension k() by nr() by nc() there + are in this object. + !*/ + + long long k( + ) const; + /*! + ensures + - returns the k dimension of this tensor. Generally, we think of a tensor + as containing num_samples() images of nr() by nc() rows and columns, each + with k() channels. + !*/ + + long long nr( + ) const; + /*! + ensures + - returns the number of rows in this tensor. + !*/ + + long long nc( + ) const; + /*! + ensures + - returns the number of columns in this tensor. + !*/ + + size_t size( + ) const; + /*! + ensures + - returns num_samples()*k()*nr()*nc() + (i.e. the total number of floats in this tensor) + !*/ + + void async_copy_to_device( + ) const; + /*! + ensures + - This function does not block. + - if (the host version of the data is newer than the device's copy) then + - Begins asynchronously copying host data to the device. + - A call to device() that happens before the transfer completes will + block until the transfer is complete. That is, it is safe to call + async_copy_to_device() and then immediately call device(). + !*/ + + typedef float* iterator; + typedef const float* const_iterator; + iterator begin() { return host(); } + const_iterator begin() const { return host(); } + iterator end() { return host()+size(); } + const_iterator end() const { return host()+size(); } + /*! + ensures + - makes a tensor iterable just like the STL containers. + !*/ + + virtual const float* host( + ) const = 0; + /*! + ensures + - returns a pointer to the host memory block of size() contiguous float + values or nullptr if size()==0. + - if (the host's copy of the data is out of date) then + - copies the data from the device to the host, while this is happening + the call to host() blocks. + !*/ + + virtual float* host( + ) = 0; + /*! + ensures + - returns a pointer to the host memory block of size() contiguous float + values or nullptr if size()==0. + - if (the host's copy of the data is out of date) then + - copies the data from the device to the host, while this is happening + the call to host() blocks. + - Marks the device side data as out of date so that the next call to + device() will perform a host to device transfer. If you want to begin + the transfer immediately then you can call async_copy_to_device() after + calling host(). + !*/ + + virtual float* host_write_only( + ) = 0; + /*! + ensures + - This function returns the same pointer as host(), except that it never + performs a device to host memory copy. Instead, it immediately marks the + device side data as out of date, effectively discarding it. Therefore, + the values in the data pointed to by host_write_only() are undefined and + you should only call host_write_only() if you are going to assign to + every memory location in the returned memory block. + !*/ + + virtual const float* device( + ) const = 0; + /*! + requires + - DLIB_USE_CUDA is #defined + ensures + - returns a pointer to the device memory block of size() contiguous float + values or nullptr if size()==0. + - if (the device's copy of the data is out of date) then + - copies the data from the host to the device, while this is happening + the call to device() blocks. + !*/ + + virtual float* device( + ) = 0; + /*! + requires + - DLIB_USE_CUDA is #defined + ensures + - returns a pointer to the device memory block of size() contiguous float + values or nullptr if size()==0. + - if (the device's copy of the data is out of date) then + - copies the data from the host to the device, while this is happening + the call to device() blocks. + - Marks the host side data as out of date so that the next call to + host() will perform a device to host transfer. + !*/ + + virtual float* device_write_only( + ) = 0; + /*! + requires + - DLIB_USE_CUDA is #defined + ensures + - This function returns the same pointer as device(), except that it never + performs a host to device memory copy. Instead, it immediately marks the + host side data as out of date, effectively discarding it. Therefore, the + values in the data pointed to by device_write_only() are undefined and + you should only call device_write_only() if you are going to assign to + every memory location in the returned memory block. + !*/ + + virtual const any& annotation( + ) const = 0; + /*! + ensures + - returns a const reference to the any object in this tensor. The any + object can be used to store any additional annotation you like in a + tensor. However, it should be noted that the annotation() is ignored by + serialize() and therefore not saved when a tensor is serialized. + !*/ + + virtual any& annotation( + ) = 0; + /*! + ensures + - returns a non-const reference to the any object in this tensor. The any + object can be used to store any additional annotation you like in a + tensor. However, it should be noted that the annotation() is ignored by + serialize() and therefore not saved when a tensor is serialized. + !*/ + + int device_id( + ) const; + /*! + ensures + - returns the ID of the CUDA device that allocated this memory. I.e. the + number returned by cudaGetDevice() when the memory was allocated. + - If CUDA is not being used then this function always returns 0. + !*/ + + tensor& operator= ( + float val + ); + /*! + ensures + - sets all elements of this tensor equal to val. + - returns *this + !*/ + + tensor& operator*= ( + float val + ); + /*! + ensures + - pointwise multiplies all elements of *this tensor with val. + - returns *this + !*/ + + tensor& operator/= ( + float val + ); + /*! + ensures + - pointwise divides all elements of *this tensor with val. + - returns *this + !*/ + + template + tensor& operator= ( + const matrix_exp& item + ); + /*! + requires + - num_samples() == item.nr() + - k()*nr()*nc() == item.nc() + - item contains float values + ensures + - Assigns item to *this tensor by performing: + set_ptrm(host(), num_samples(), k()*nr()*nc()) = item; + !*/ + + template + tensor& operator+= ( + const matrix_exp& item + ); + /*! + requires + - num_samples() == item.nr() + - k()*nr()*nc() == item.nc() + - item contains float values + ensures + - Adds item to *this tensor by performing: + set_ptrm(host(), num_samples(), k()*nr()*nc()) += item; + !*/ + + template + tensor& operator-= ( + const matrix_exp& item + ); + /*! + requires + - num_samples() == item.nr() + - k()*nr()*nc() == item.nc() + - item contains float values + ensures + - Subtracts item from *this tensor by performing: + set_ptrm(host(), num_samples(), k()*nr()*nc()) -= item; + !*/ + + template + void set_sample ( + unsigned long long idx, + const matrix_exp& item + ); + /*! + requires + - idx < num_samples() + - k()*nr()*nc() == item.size() + - item contains float values + ensures + - Assigns item to the idx'th sample in *this by performing: + set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) = item; + !*/ + + + template + void add_to_sample ( + unsigned long long idx, + const matrix_exp& item + ); + /*! + requires + - idx < num_samples() + - k()*nr()*nc() == item.size() + - item contains float values + ensures + - Adds item to the idx'th sample in *this by performing: + set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) += item; + !*/ + + protected: + + // You can't move or copy another tensor into *this since that might modify the + // tensor's dimensions. If you want to do that sort of thing then use a + // resizable_tensor. + tensor(const tensor& item); + tensor& operator= (const tensor& item); + tensor(tensor&& item); + tensor& operator=(tensor&& item); + }; + +// ---------------------------------------------------------------------------------------- + + void memcpy ( + tensor& dest, + const tensor& src + ); + /*! + requires + - dest.size() == src.size() + ensures + - Copies the data in src to dest. If the device data is current on both src + and dest then the copy will happen entirely on the device side. + - It doesn't matter what GPU device is selected by cudaSetDevice(). You can + always copy tensor objects to and from each other regardless. + - This function blocks until the copy has completed. + !*/ + +// ---------------------------------------------------------------------------------------- + + bool is_vector ( + const tensor& t + ); + /*! + ensures + - returns true if and only if one of the following is true: + - t.size() == t.num_samples() + - t.size() == t.k() + - t.size() == t.nr() + - t.size() == t.nc() + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp mat ( + const tensor& t, + long long nr, + long long nc + ); + /*! + requires + - nr >= 0 + - nc >= 0 + - nr*nc == t.size() + ensures + - returns a matrix M such that: + - M.nr() == nr + - m.nc() == nc + - for all valid r and c: + M(r,c) == t.host()[r*nc + c] + (i.e. the tensor is interpreted as a matrix laid out in memory + in row major order) + !*/ + + const matrix_exp mat ( + const tensor& t + ); + /*! + ensures + - if (t.size() != 0) then + - returns mat(t, t.num_samples(), t.size()/t.num_samples()) + - else + - returns an empty matrix. + !*/ + + const matrix_exp image_plane ( + const tensor& t, + long long sample = 0, + long long k = 0 + ); + /*! + requires + - t.size() != 0 + - 0 <= sample < t.num_samples() + - 0 <= k < t.k() + ensures + - returns the k-th image plane from the sample-th image in t. That is, + returns a matrix M such that: + - M contains float valued elements. + - M.nr() == t.nr() + - M.nc() == t.nc() + - for all valid r and c: + - M(r,c) == t.host()[((sample*t.k() + k)*t.nr() + r)*t.nc() + c] + !*/ + +// ---------------------------------------------------------------------------------------- + + bool have_same_dimensions ( + const tensor& a, + const tensor& b + ); + /*! + ensures + - returns true if and only if all of the fallowing are satisfied: + - a.num_samples() == b.num_samples() + - a.k() == b.k() + - a.nr() == b.nr() + - a.nc() == b.nc() + !*/ + +// ---------------------------------------------------------------------------------------- + + class resizable_tensor : public tensor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is just a tensor with the additional ability to be resized. + !*/ + + public: + resizable_tensor( + ); + /*! + ensures + - #size() == 0 + - #num_samples() == 0 + - #k() == 0 + - #nr() == 0 + - #nc() == 0 + - #capacity() == 0 + !*/ + + template + resizable_tensor( + const matrix_exp& item + ); + /*! + requires + - item contains float values + ensures + - #num_samples() == item.nr() + - #k() == item.nc() + - #nr() == 1 + - #nc() == 1 + - Assigns item to *this tensor by performing: + set_ptrm(host(), num_samples(), k()*nr()*nc()) = item; + - #capacity() == size() + !*/ + + explicit resizable_tensor( + long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 + ); + /*! + requires + - n_ >= 0 + - k_ >= 0 + - nr_ >= 0 + - nc_ >= 0 + ensures + - #size() == n_*k_*nr_*nc_ + - #num_samples() == n_ + - #k() == k_ + - #nr() == nr_ + - #nc() == nc_ + - #capacity() == size() + !*/ + + // This object is copyable and movable + resizable_tensor(const resizable_tensor&) = default; + resizable_tensor(resizable_tensor&&) = default; + resizable_tensor& operator= (const resizable_tensor&) = default; + resizable_tensor& operator= (resizable_tensor&&) = default; + + size_t capacity ( + ) const; + /*! + ensures + - returns the total number of floats allocated. This might be different + from the size() since calls to set_size() that make a tensor smaller + don't trigger reallocations. They simply adjust the nominal dimensions + while keeping the same allocated memory block. This makes calls to + set_size() very fast. If you need to deallocate a tensor then use + clear(). + !*/ + + void clear( + ); + /*! + ensures + - #size() == 0 + - #num_samples() == 0 + - #k() == 0 + - #nr() == 0 + - #nc() == 0 + - #annotation().is_empty() == true + - #capacity() == 0 + !*/ + + void copy_size ( + const tensor& item + ); + /*! + ensures + - resizes *this so that: have_same_dimensions(#*this, item)==true + !*/ + + void set_size( + long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 + ); + /*! + requires + - n_ >= 0 + - k_ >= 0 + - nr_ >= 0 + - nc_ >= 0 + ensures + - #size() == n_*k_*nr_*nc_ + - #num_samples() == n_ + - #k() == k_ + - #nr() == nr_ + - #nc() == nc_ + - #capacity() == max(#size(), capacity()) + (i.e. capacity() never goes down when calling set_size().) + !*/ + + template + resizable_tensor& operator= ( + const matrix_exp& item + ); + /*! + requires + - item contains float values + ensures + - if (num_samples() == item.nr() && k()*nr()*nc() == item.nc()) then + - the dimensions of this tensor are not changed + - else + - #num_samples() == item.nr() + - #k() == item.nc() + - #nr() == 1 + - #nc() == 1 + - Assigns item to *this tensor by performing: + set_ptrm(host(), num_samples(), k()*nr()*nc()) = item; + !*/ + }; + + void serialize(const tensor& item, std::ostream& out); + void deserialize(resizable_tensor& item, std::istream& in); + /*! + provides serialization support for tensor and resizable_tensor. Note that you can + serialize to/from any combination of tenor and resizable_tensor objects. + !*/ + +// ---------------------------------------------------------------------------------------- + + double dot( + const tensor& a, + const tensor& b + ); + /*! + requires + - a.size() == b.size() + ensures + - returns the dot product between a and b when they are both treated as + a.size() dimensional vectors. That is, this function pointwise multiplies + the vectors together, then sums the result and returns it. + + !*/ + +// ---------------------------------------------------------------------------------------- + + class alias_tensor_instance : public tensor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tensor that aliases another tensor. That is, it doesn't + have its own block of memory but instead simply holds pointers to the + memory of another tensor object. It therefore allows you to efficiently + break a tensor into pieces and pass those pieces into functions. + + An alias_tensor_instance doesn't own the resources it points to in any sense. + So it is important to make sure that the underlying owning tensor doesn't get + destructed before any alias tensors which point to it are destructed. + !*/ + + // You can't default initialize this object. You can only get instances of it from + // alias_tensor::operator(). + alias_tensor_instance( + ); + }; + + class alias_tensor_const_instance + { + /*! + WHAT THIS OBJECT REPRESENTS + This is essentially a const version of alias_tensor_instance and therefore + represents a tensor. However, due to the mechanics of C++, this object + can't inherit from tensor. So instead it provides a get() and an implicit + conversion to const tensor. + !*/ + + public: + + // non-const alias tensors are convertible to const ones. + alias_tensor_const_instance(const alias_tensor_instance& item); + + // Methods that cast the alias to a tensor. + const tensor& get() const; + operator const tensor& (); + + private: + // You can't default initialize this object. You can only get instances of it from + // alias_tensor::operator(). + alias_tensor_const_instance(); + }; + + class alias_tensor + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool for creating tensor objects that alias other tensor objects. + That is, it allows you to make a tensor that references the memory space of + another tensor object rather than owning its own memory. This allows you + to do things like interpret a single tensor in different ways or even as a + group of multiple tensors. + !*/ + public: + + alias_tensor ( + ); + /*! + ensures + - #size() == 0 + - #num_samples() == 0 + - #k() == 0 + - #nr() == 0 + - #nc() == 0 + !*/ + + alias_tensor ( + long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 + ); + /*! + requires + - n_ >= 0 + - k_ >= 0 + - nr_ >= 0 + - nc_ >= 0 + ensures + - #size() == n_*k_*nr_*nc_ + - #num_samples() == n_ + - #k() == k_ + - #nr() == nr_ + - #nc() == nc_ + !*/ + + long long num_samples() const; + long long k() const; + long long nr() const; + long long nc() const; + size_t size() const; + + alias_tensor_instance operator() ( + tensor& t, + size_t offset = 0 + ) const; + /*! + requires + - offset+size() <= t.size() + ensures + - Returns a tensor that simply aliases the elements of t beginning with t's + offset'th element. Specifically, this function returns an aliasing + tensor T such that: + - T.size() == size() + - T.num_samples() == num_samples() + - T.k() == k() + - T.nr() == nr() + - T.nc() == nc() + - T.host() == t.host()+offset + - T.device() == t.device()+offset + - &T.annotation() == &t.annotation() + !*/ + + alias_tensor_const_instance operator() ( + const tensor& t, + size_t offset = 0 + ) const; + /*! + requires + - offset+size() <= t.size() + ensures + - This function is identical to the above version of operator() except that + it takes and returns const tensors instead of non-const tensors. + !*/ + }; + + void serialize(const alias_tensor& item, std::ostream& out); + void deserialize(alias_tensor& item, std::istream& in); + /*! + provides serialization support for alias_tensor. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_TENSOR_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/dnn/tensor_tools.cpp b/ml/dlib/dlib/dnn/tensor_tools.cpp new file mode 100644 index 000000000..c0f7fd69d --- /dev/null +++ b/ml/dlib/dlib/dnn/tensor_tools.cpp @@ -0,0 +1,985 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TeNSOR_TOOLS_CPP_ +#define DLIB_TeNSOR_TOOLS_CPP_ + +#include "tensor_tools.h" +#include "../string.h" +#include + +namespace dlib +{ + namespace + { + std::atomic& dnn_prefer_fastest_algo ( + ) + { + static std::atomic var(true); + return var; + } + } + + bool dnn_prefer_fastest_algorithms ( + ) + { + return dnn_prefer_fastest_algo(); + } + + void set_dnn_prefer_fastest_algorithms( + ) + { + dnn_prefer_fastest_algo() = true; + } + + void set_dnn_prefer_smallest_algorithms( + ) + { + dnn_prefer_fastest_algo() = false; + } +} + +namespace dlib { namespace tt +{ + +// ---------------------------------------------------------------------------------------- + + void inverse_norms ( + resizable_tensor& invnorms, + const tensor& data, + const double eps + ) + { +#ifdef DLIB_USE_CUDA + cuda::inverse_norms(invnorms, data, eps); +#else + invnorms = reciprocal(sqrt(sum_cols(squared(mat(data))) + eps)); +#endif + } + + void dot_prods ( + resizable_tensor& out, + const tensor& lhs, + const tensor& rhs + ) + { +#ifdef DLIB_USE_CUDA + cuda::dot_prods(out, lhs, rhs); +#else + out = sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); +#endif + } + + void dot_prods ( + bool add_to, + tensor& out, + const tensor& lhs, + const tensor& rhs + ) + { +#ifdef DLIB_USE_CUDA + cuda::dot_prods(add_to, out, lhs, rhs); +#else + if (add_to) + out += sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); + else + out = sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); +#endif + } + + void scale_columns ( + tensor& out, + const tensor& m, + const tensor& v + ) + { + DLIB_CASSERT(have_same_dimensions(out,m)); + DLIB_CASSERT(is_vector(v)); + if (m.size() == 0 && v.size() == 0) + return; + DLIB_CASSERT(m.size() != 0); + DLIB_CASSERT(m.size()/m.num_samples() == v.size()); + +#ifdef DLIB_USE_CUDA + cuda::scale_columns(out, m, v); +#else + DLIB_CASSERT(false, "shouldn't be called right now"); + out = scale_columns(mat(m), mat(v)); +#endif + } + + void scale_rows ( + tensor& out, + const tensor& m, + const tensor& v + ) + { + DLIB_CASSERT(have_same_dimensions(out,m)); + DLIB_CASSERT(is_vector(v)); + if (m.size() == 0 && v.size() == 0) + return; + DLIB_CASSERT(m.size() != 0); + DLIB_CASSERT(m.num_samples() == v.size()); + +#ifdef DLIB_USE_CUDA + cuda::scale_rows(out, m, v); +#else + out = scale_rows(mat(m), mat(v)); +#endif + } + + void scale_rows2 ( + float beta, + tensor& out, + const tensor& m1, + const tensor& m2, + const tensor& v1, + const tensor& v2 + ) + { + DLIB_CASSERT(have_same_dimensions(out,m1)); + DLIB_CASSERT(have_same_dimensions(out,m2)); + DLIB_CASSERT(have_same_dimensions(v1,v2)); + DLIB_CASSERT(is_vector(mat(v1))); + DLIB_CASSERT(v1.size() == m1.num_samples()); + +#ifdef DLIB_USE_CUDA + cuda::scale_rows2(beta, out, m1, m2, v1, v2); +#else + if (beta == 0) + out = scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2)); + else + out = beta*mat(out) + scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2)); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void exp ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(dest.size() == src.size()); + +#ifdef DLIB_USE_CUDA + cuda::exp(dest,src); +#else + dest = exp(mat(src)); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void log ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(dest.size() == src.size()); + +#ifdef DLIB_USE_CUDA + cuda::log(dest,src); +#else + dest = log(mat(src)); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void log10 ( + tensor& dest, + const tensor& src + ) + { + DLIB_CASSERT(dest.size() == src.size()); + +#ifdef DLIB_USE_CUDA + cuda::log10(dest,src); +#else + dest = log10(mat(src)); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void gemm ( + float beta, + tensor& dest, + float alpha, + const tensor& lhs, + bool trans_lhs, + const tensor& rhs, + bool trans_rhs + ) + { +#ifdef DLIB_USE_CUDA + cuda::gemm(beta, dest, alpha, lhs, trans_lhs, rhs, trans_rhs); +#else + if (beta != 0) + { + if (trans_lhs && trans_rhs) + dest = alpha*trans(mat(lhs))*trans(mat(rhs)) + beta*mat(dest); + else if (!trans_lhs && trans_rhs) + dest = alpha*mat(lhs)*trans(mat(rhs)) + beta*mat(dest); + else if (trans_lhs && !trans_rhs) + dest = alpha*trans(mat(lhs))*mat(rhs) + beta*mat(dest); + else + dest = alpha*mat(lhs)*mat(rhs) + beta*mat(dest); + } + else + { + if (trans_lhs && trans_rhs) + dest = alpha*trans(mat(lhs))*trans(mat(rhs)); + else if (!trans_lhs && trans_rhs) + dest = alpha*mat(lhs)*trans(mat(rhs)); + else if (trans_lhs && !trans_rhs) + dest = alpha*trans(mat(lhs))*mat(rhs); + else + dest = alpha*mat(lhs)*mat(rhs); + } +#endif + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + tensor_rand:: + tensor_rand( + unsigned long long seed + ) +#ifdef DLIB_USE_CUDA + :rnd(seed){} +#else + {rnd.set_seed(cast_to_string(seed)); } +#endif + + void tensor_rand:: + fill_gaussian ( + tensor& data, + float mean, + float stddev + ) + { + DLIB_CASSERT(data.size()%2 == 0); +#ifdef DLIB_USE_CUDA + rnd.fill_gaussian(data, mean, stddev); +#else + for (auto& x : data) + x = rnd.get_random_gaussian()*stddev + mean; +#endif + } + + void tensor_rand:: + fill_uniform ( + tensor& data + ) + { +#ifdef DLIB_USE_CUDA + rnd.fill_uniform(data); +#else + for (auto& x : data) + x = rnd.get_random_float(); +#endif + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void multiply ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { + DLIB_CASSERT(dest.k() == src1.k() && src1.k() == src2.k() && + dest.nr() == src1.nr() && src1.nr() == src2.nr() && + dest.nc() == src1.nc() && src1.nc() == src2.nc() ); + const long MD = std::max(std::max(dest.num_samples(),src1.num_samples()),src2.num_samples()); + DLIB_CASSERT((dest.num_samples()==1 || dest.num_samples()==MD) && + (src1.num_samples()==1 || src1.num_samples()==MD) && + (src2.num_samples()==1 || src2.num_samples()==MD) ); +#ifdef DLIB_USE_CUDA + cuda::multiply(add_to, dest, src1, src2); +#else + cpu::multiply(add_to, dest, src1, src2); +#endif + + } + + void scale_channels ( + bool add_to, + tensor& dest, + const tensor& src, + const tensor& scales + ) + { +#ifdef DLIB_USE_CUDA + cuda::scale_channels(add_to, dest, src, scales); +#else + cpu::scale_channels(add_to, dest, src, scales); +#endif + } + + void multiply_conv ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { +#ifdef DLIB_USE_CUDA + cuda::multiply_conv(add_to, dest, src1, src2); +#else + cpu::multiply_conv(add_to, dest, src1, src2); +#endif + } + + void multiply_zero_padded ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { +#ifdef DLIB_USE_CUDA + cuda::multiply_zero_padded(add_to, dest, src1, src2); +#else + cpu::multiply_zero_padded(add_to, dest, src1, src2); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const float A, + const float B + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform(dest,src,A,B); +#else + cpu::affine_transform(dest,src,A,B); +#endif + } + + void affine_transform( + tensor& dest, + const tensor& src, + const float A + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform(dest,src,A); +#else + cpu::affine_transform(dest,src,A,0); +#endif + } + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B, + const float C + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform(dest,src1,src2,A,B,C); +#else + cpu::affine_transform(dest,src1,src2,A,B,C); +#endif + } + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform(dest,src1,src2,A,B); +#else + cpu::affine_transform(dest,src1,src2,A,B,0); +#endif + } + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C, + const float D + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform(dest,src1,src2,src3,A,B,C,D); +#else + cpu::affine_transform(dest,src1,src2,src3,A,B,C,D); +#endif + } + + void affine_transform_range( + size_t begin, + size_t end, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform_range(begin, end, dest,src1,src2,src3,A,B,C); +#else + cpu::affine_transform_range(begin, end, dest,src1,src2,src3,A,B,C); +#endif + } + + void affine_transform( + const rectangle& rect, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + float A, + float B, + float C + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform(rect, dest,src1,src2,src3,A,B,C); +#else + cpu::affine_transform(rect, dest,src1,src2,src3,A,B,C); +#endif + } + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform_range(0,dest.size(),dest,src1,src2,src3,A,B,C); +#else + cpu::affine_transform_range(0,dest.size(),dest,src1,src2,src3,A,B,C); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform(dest,src,A,B); +#else + cpu::affine_transform(dest,src,A,B); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void affine_transform_conv( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ) + { +#ifdef DLIB_USE_CUDA + cuda::affine_transform_conv(dest,src,A,B); +#else + cpu::affine_transform_conv(dest,src,A,B); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void compute_adam_update ( + size_t begin, + size_t end, + tensor& s, + tensor& m, + tensor& v, + const float t, + const float learning_rate, + const float weight_decay, + const float momentum1, + const float momentum2, + const tensor& params, + const tensor& params_grad + ) + { +#ifdef DLIB_USE_CUDA + cuda::compute_adam_update(begin, end, s, m, v, t, learning_rate, weight_decay, momentum1, + momentum2, params, params_grad); +#else + cpu::compute_adam_update(begin, end, s, m, v, t, learning_rate, weight_decay, momentum1, + momentum2, params, params_grad); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void batch_normalize_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ) + { +#ifdef DLIB_USE_CUDA + cuda::batch_normalize_inference(eps,dest,src,gamma,beta,running_means,running_variances); +#else + cpu::batch_normalize_inference(eps,dest,src,gamma,beta,running_means,running_variances); +#endif + } + + void batch_normalize ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& vars, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ) + { +#ifdef DLIB_USE_CUDA + cuda::batch_normalize(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta); +#else + cpu::batch_normalize(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta); +#endif + } + + void batch_normalize_gradient ( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ) + { + +#ifdef DLIB_USE_CUDA + cuda::batch_normalize_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); +#else + cpu::batch_normalize_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void batch_normalize_conv_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ) + { +#ifdef DLIB_USE_CUDA + cuda::batch_normalize_conv_inference(eps,dest,src,gamma,beta,running_means,running_variances); +#else + cpu::batch_normalize_conv_inference(eps,dest,src,gamma,beta,running_means,running_variances); +#endif + } + + void batch_normalize_conv ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& vars, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ) + { +#ifdef DLIB_USE_CUDA + cuda::batch_normalize_conv(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta); +#else + cpu::batch_normalize_conv(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta); +#endif + } + + void batch_normalize_conv_gradient ( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ) + { + +#ifdef DLIB_USE_CUDA + cuda::batch_normalize_conv_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); +#else + cpu::batch_normalize_conv_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void threshold ( + tensor& data, + float thresh + ) + { +#ifdef DLIB_USE_CUDA + cuda::threshold(data,thresh); +#else + cpu::threshold(data,thresh); +#endif + } + + void dot ( + const tensor& a, + const tensor& b, + tensor& result, + size_t idx + ) + { +#ifdef DLIB_USE_CUDA + cuda::dot(a,b,result,idx); +#else + cpu::dot(a,b,result,idx); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void add( + float beta, + tensor& dest, + float alpha, + const tensor& src + ) + { +#ifdef DLIB_USE_CUDA + cuda::add(beta,dest,alpha,src); +#else + cpu::add(beta,dest,alpha,src); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void add ( + tensor& dest, + const tensor& src1, + const tensor& src2 + ) + { +#ifdef DLIB_USE_CUDA + cuda::add(dest, src1, src2); +#else + cpu::add(dest, src1, src2); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void assign_conv_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ) + { +#ifdef DLIB_USE_CUDA + cuda::assign_conv_bias_gradient(grad,gradient_input); +#else + cpu::assign_conv_bias_gradient(grad,gradient_input); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void assign_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ) + { +#ifdef DLIB_USE_CUDA + cuda::assign_bias_gradient(grad,gradient_input); +#else + cpu::assign_bias_gradient(grad,gradient_input); +#endif + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void softmax ( + tensor& dest, + const tensor& src + ) + { +#ifdef DLIB_USE_CUDA + cuda::softmax(dest,src); +#else + cpu::softmax(dest,src); +#endif + } + + void softmax_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { +#ifdef DLIB_USE_CUDA + cuda::softmax_gradient(grad, dest, gradient_input); +#else + cpu::softmax_gradient(grad, dest, gradient_input); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void softmax_all ( + tensor& dest, + const tensor& src + ) + { +#ifdef DLIB_USE_CUDA + cuda::softmax_all(dest,src); +#else + cpu::softmax_all(dest,src); +#endif + } + + void softmax_all_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { +#ifdef DLIB_USE_CUDA + cuda::softmax_all_gradient(grad, dest, gradient_input); +#else + cpu::softmax_all_gradient(grad, dest, gradient_input); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void sigmoid ( + tensor& dest, + const tensor& src + ) + { +#ifdef DLIB_USE_CUDA + cuda::sigmoid(dest,src); +#else + cpu::sigmoid(dest,src); +#endif + } + + void sigmoid_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { +#ifdef DLIB_USE_CUDA + cuda::sigmoid_gradient(grad, dest, gradient_input); +#else + cpu::sigmoid_gradient(grad, dest, gradient_input); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void relu ( + tensor& dest, + const tensor& src + ) + { +#ifdef DLIB_USE_CUDA + cuda::relu(dest,src); +#else + cpu::relu(dest,src); +#endif + } + + void relu_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { +#ifdef DLIB_USE_CUDA + cuda::relu_gradient(grad, dest, gradient_input); +#else + cpu::relu_gradient(grad, dest, gradient_input); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void prelu ( + tensor& dest, + const tensor& src, + const tensor& param + ) + { +#ifdef DLIB_USE_CUDA + cuda::prelu(dest, src, param); +#else + cpu::prelu(dest, src, param); +#endif + } + + void prelu_gradient ( + tensor& grad, + const tensor& src, + const tensor& gradient_input, + const tensor& param, + tensor& params_grad + ) + { +#ifdef DLIB_USE_CUDA + cuda::prelu_gradient(grad, src, gradient_input, param, params_grad); +#else + cpu::prelu_gradient(grad, src, gradient_input, param, params_grad); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void tanh ( + tensor& dest, + const tensor& src + ) + { +#ifdef DLIB_USE_CUDA + cuda::tanh(dest,src); +#else + cpu::tanh(dest,src); +#endif + } + + void tanh_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ) + { +#ifdef DLIB_USE_CUDA + cuda::tanh_gradient(grad, dest, gradient_input); +#else + cpu::tanh_gradient(grad, dest, gradient_input); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void resize_bilinear ( + tensor& dest, + long dest_row_stride, + long dest_channel_stride, + const tensor& src, + long src_row_stride, + long src_channel_stride + ) + { +#ifdef DLIB_USE_CUDA + cuda::resize_bilinear(dest,dest_row_stride,dest_channel_stride, src,src_row_stride,src_channel_stride); +#else + cpu::resize_bilinear(dest,dest_row_stride,dest_channel_stride, src,src_row_stride,src_channel_stride); +#endif + } + + void resize_bilinear_gradient ( + tensor& grad, + long grad_row_stride, + long grad_channel_stride, + const tensor& gradient_input, + long gradient_input_row_stride, + long gradient_input_channel_stride + ) + { +#ifdef DLIB_USE_CUDA + cuda::resize_bilinear_gradient(grad,grad_row_stride,grad_channel_stride, gradient_input,gradient_input_row_stride,gradient_input_channel_stride); +#else + cpu::resize_bilinear_gradient(grad,grad_row_stride,grad_channel_stride, gradient_input,gradient_input_row_stride,gradient_input_channel_stride); +#endif + } + +// ------------------------------------------------------------------------------------ + + void copy_tensor( + bool add_to, + tensor& dest, + size_t dest_k_offset, + const tensor& src, + size_t src_k_offset, + size_t count_k + ) + { +#ifdef DLIB_USE_CUDA + cuda::copy_tensor(add_to, dest, dest_k_offset, src, src_k_offset, count_k); +#else + cpu::copy_tensor(add_to, dest, dest_k_offset, src, src_k_offset, count_k); +#endif + } + +// ---------------------------------------------------------------------------------------- + + void inv:: + operator() ( + const tensor& m, + resizable_tensor& out + ) + { +#ifdef DLIB_USE_CUDA + finv(m,out); +#else + out = dlib::inv(mat(m)); +#endif + } + +// ---------------------------------------------------------------------------------------- + +}} + +#endif // DLIB_TeNSOR_TOOLS_CPP_ + diff --git a/ml/dlib/dlib/dnn/tensor_tools.h b/ml/dlib/dlib/dnn/tensor_tools.h new file mode 100644 index 000000000..9ba3154e5 --- /dev/null +++ b/ml/dlib/dlib/dnn/tensor_tools.h @@ -0,0 +1,1711 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TeNSOR_TOOLS_H_ +#define DLIB_TeNSOR_TOOLS_H_ + +#include "tensor.h" +#include "cudnn_dlibapi.h" +#include "cublas_dlibapi.h" +#include "cusolver_dlibapi.h" +#include "curand_dlibapi.h" +#include "cpu_dlib.h" +#include "cuda_dlib.h" +#include "../rand.h" +#include +#include "../geometry/rectangle.h" +#include "../test_for_odr_violations.h" + +namespace dlib +{ + bool dnn_prefer_fastest_algorithms(); + void set_dnn_prefer_fastest_algorithms(); + void set_dnn_prefer_smallest_algorithms(); +} + +namespace dlib { namespace tt +{ + +// ---------------------------------------------------------------------------------------- + + void inverse_norms ( + resizable_tensor& invnorms, + const tensor& data, + const double eps + ); + /*! + ensures + - #invnorms == reciprocal(sqrt(sum_cols(squared(mat(data))) + eps)) + !*/ + + void dot_prods ( + resizable_tensor& out, + const tensor& lhs, + const tensor& rhs + ); + /*! + requires + - have_same_dimensions(lhs,rhs) == true + ensures + - #out.num_samples() == lhs.num_samples() + - #out.k() == #out.nr() == #out.nc() == 1 + - #out == sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); + !*/ + + void dot_prods ( + bool add_to, + tensor& out, + const tensor& lhs, + const tensor& rhs + ); + /*! + requires + - have_same_dimensions(lhs,rhs) == true + - out.size() == lhs.num_samples() + - out.k() == out.nr() == out.nc() == 1 + ensures + - if (add_to) then + - #out == mat(out) + sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); + - else + - #out == sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); + !*/ + + void scale_columns ( + tensor& out, + const tensor& m, + const tensor& v + ); + /*! + requires + - have_same_dimensions(out,m) == true + - is_vector(v) == true + - v.size() == mat(m).nc() + ensures + - performs: out = scale_columns(mat(m),mat(v)); + !*/ + + void scale_rows ( + tensor& out, + const tensor& m, + const tensor& v + ); + /*! + requires + - have_same_dimensions(out,m) == true + - is_vector(v) == true + - v.size() == m.num_samples() + ensures + - performs: out = scale_rows(mat(m),mat(v)); + !*/ + + void scale_rows2 ( + float beta, + tensor& out, + const tensor& m1, + const tensor& m2, + const tensor& v1, + const tensor& v2 + ); + /*! + requires + - have_same_dimensions(out,m1) == true + - have_same_dimensions(out,m2) == true + - have_same_dimensions(v1,v2) == true + - is_vector(v1) == true + - v1.size() == m1.num_samples() + ensures + - performs: + out = beta*out + scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2)); + !*/ + +// ---------------------------------------------------------------------------------------- + + void exp ( + tensor& dest, + const tensor& src + ); + /*! + requires + - dest.size() == src.size() + ensures + - performs: dest = exp(mat(src)) + !*/ + +// ---------------------------------------------------------------------------------------- + + void log ( + tensor& dest, + const tensor& src + ); + /*! + requires + - dest.size() == src.size() + ensures + - performs: dest = log(mat(src)) + !*/ + +// ---------------------------------------------------------------------------------------- + + void log10 ( + tensor& dest, + const tensor& src + ); + /*! + requires + - dest.size() == src.size() + ensures + - performs: dest = log10(mat(src)) + !*/ + +// ---------------------------------------------------------------------------------------- + + void gemm ( + float beta, + tensor& dest, + float alpha, + const tensor& lhs, + bool trans_lhs, + const tensor& rhs, + bool trans_rhs + ); + /*! + requires + - dest does not alias the memory of lhs or rhs + - The dimensions of lhs and rhs must be compatible for matrix multiplication. + In particular: + - Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs) + - Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs) + - Let D == mat(dest) + - D.nr() == L.nr() && D.nc() == R.nc() + (i.e. dest must be preallocated and have the correct output dimensions) + - L.nc() == R.nr() + ensures + - performs: dest = alpha*L*R + beta*mat(dest) + !*/ + +// ---------------------------------------------------------------------------------------- + + class inv + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a functor for doing matrix inversion on the GPU. The only + reason it's an object is to avoid the reallocation of some GPU memory + blocks if you want to do a bunch of matrix inversions in a row. + !*/ + public: + + void operator() ( + const tensor& m, + resizable_tensor& out + ); + /*! + requires + - m.size() == m.num_samples()*m.num_samples() + (i.e. mat(m) must be a square matrix) + ensures + - out == inv(mat(m)); + !*/ + + private: +#ifdef DLIB_USE_CUDA + cuda::inv finv; +#endif + }; + +// ---------------------------------------------------------------------------------------- + + class tensor_rand + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool for filling a tensor with random numbers. + + Note that the sequence of random numbers output by this object is different + when dlib is compiled with DLIB_USE_CUDA. So you should not write code + that depends on any specific sequence of numbers coming out of a + tensor_rand. + + !*/ + + public: + // not copyable + tensor_rand(const tensor_rand&) = delete; + tensor_rand& operator=(const tensor_rand&) = delete; + + tensor_rand() : tensor_rand(0) {} + tensor_rand(unsigned long long seed); + + void fill_gaussian ( + tensor& data, + float mean = 0, + float stddev = 1 + ); + /*! + requires + - data.size()%2 == 0 + ensures + - Fills data with random numbers drawn from a Gaussian distribution + with the given mean and standard deviation. + !*/ + + void fill_uniform ( + tensor& data + ); + /*! + ensures + - Fills data with uniform random numbers in the range (0.0, 1.0]. + !*/ + +#ifdef DLIB_USE_CUDA + cuda::curand_generator rnd; +#else + dlib::rand rnd; +#endif + }; + +// ---------------------------------------------------------------------------------------- + + void multiply ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + /*! + requires + - dest.k() == src1.k() == src2.k() + - dest.nr() == src1.nr() == src2.nr() + - dest.nc() == src1.nc() == src2.nc() + - dest.num_samples(), src1.num_samples(), and src2.num_samples() must each + either be 1 or whichever ones aren't equal to 1 must have the same values. + ensures + - let MD = max(dest.num_samples(), src1.num_samples(), src2.num_samples) + - This function pointwise multiplies src1 with src2 and stores the result into + #dest. However, how the multiplication happens depends on the dimensions of + the tensors. First, when src1 and src2 are multiplied together, if either + has a num_samples() dimension that is != MD, then it is first replicated to + produce a tensor with num_samples()==MD dimensions and then they are + pointwise multiplied together. + + Second, if dest.num_samples()==1, then after the pointwise multiplication of + src1 with src2, the result has its samples summed to produce an output tensor + with num_samples()==1 which is then assigned to #dest. + - if (add_to) then + - Instead of assigning the result to dest, this function adds the result to dest. + !*/ + + void scale_channels ( + bool add_to, + tensor& dest, + const tensor& src, + const tensor& scales + ); + /*! + requires + - have_same_dimensions(dest, src) == true + - scales.num_samples() == src.num_samples() + - scales.k() == src.k() + - scales.nr() == 1 + - scales.nc() == 1 + ensures + - Scales each channel of src by the corresponding value in scales. To be + precise, we will have: + - #dest(n,k,r,c) == src(n,k,r,c)*scales(n,k,1,1) + - if (add_to) then + - Instead of assigning the result to dest, this function adds the result to dest. + !*/ + + void multiply_conv ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + /*! + requires + - if (have_same_dimensions(dest, src1) == true) then + - src2.num_samples() == 1 + - src2.nr() == 1 + - src2.nc() == 1 + - src2.k() == src1.k() + - else + - have_same_dimensions(src1, src2) == true) + - dest.num_samples() == 1 + - dest.nr() == 1 + - dest.nc() == 1 + - dest.k() == src1.k() + ensures + - Performs #dest == src1*src2 + In particular, if the elements of dest, src1, and src2 were indexed by (n,k,r,c) then + we would have: + - if (have_same_dimensions(dest,src1)) then + #dest(n,k,r,c) == src1(n,k,r,c)*src2(k) + - else + #dest(k) == sum over {n,r,c} of src1(n,k,r,c)*src2(n,k,r,c) + - if (add_to) then + - Instead of assigning the result to dest, this function adds the result to dest. + !*/ + + void multiply_zero_padded ( + bool add_to, + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + /*! + ensures + - if (add_to) then + - performs: dest += src1 * src2 + - else + - performs: dest = src1 * src2 + - In either case, the multiplication happens pointwise according to 4D tensor + arithmetic. If the dimensions don't match then missing elements are presumed + to be equal to 0. + !*/ + +// ---------------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const float A, + const float B + ); + /*! + requires + - dest.size()==src.size() + ensures + - #dest == A*src + B + !*/ + + void affine_transform( + tensor& dest, + const tensor& src, + const float A + ); + /*! + requires + - dest.size()==src.size() + ensures + - #dest == A*src + !*/ + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B, + const float C + ); + /*! + requires + - dest.size()==src1.size() + - dest.size()==src2.size() + ensures + - #dest == A*src1 + B*src2 + C + !*/ + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const float A, + const float B + ); + /*! + requires + - dest.size()==src1.size() + - dest.size()==src2.size() + ensures + - #dest == A*src1 + B*src2 + !*/ + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C, + const float D + ); + /*! + requires + - dest.size()==src1.size() + - dest.size()==src2.size() + - dest.size()==src3.size() + ensures + - #dest == A*src1 + B*src2 + C*src3 + D + !*/ + + void affine_transform( + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C + ); + /*! + requires + - dest.size()==src1.size() + - dest.size()==src2.size() + - dest.size()==src3.size() + ensures + - #dest == A*src1 + B*src2 + C*src3 + !*/ + + void affine_transform_range( + size_t begin, + size_t end, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + const float A, + const float B, + const float C + ); + /*! + requires + - dest.size()==src1.size() + - dest.size()==src2.size() + - dest.size()==src3.size() + - begin <= end <= dest.size() + ensures + - This function operates much like + affine_transform(dest,src1,src2,src3,A,B,C,0), except that it runs over only + the half open range [begin,end) rather than processing the entire tensor. + Specifically, it does this: + - for i in the range [begin, end): + - #dest.host()[i] == A*src1.host()[i] + B*src2.host()[i] + C*src3.host()[i] + !*/ + + void affine_transform( + const rectangle& rect, + tensor& dest, + const tensor& src1, + const tensor& src2, + const tensor& src3, + float A, + float B, + float C + ); + /*! + requires + - dest.size()==src1.size() + - dest.size()==src2.size() + - dest.size()==src3.size() + - dest.num_samples()==src1.num_samples() + - dest.num_samples()==src2.num_samples() + - dest.num_samples()==src3.num_samples() + - get_rect(mat(dest)).contains(rect) == true + (i.e. rect must be entirely contained within dest) + ensures + - This function operates much like + affine_transform(dest,src1,src2,src3,A,B,C,0), except that it runs over only + the sub-rectangle indicated by rect. In particular, this function is equivalent + to: + set_subm(dest,rect) = A*subm(mat(src1),rect) + B*subm(mat(src2),rect) + C*subm(mat(src3),rect) + !*/ + +// ---------------------------------------------------------------------------------------- + + void affine_transform( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ); + /*! + requires + - have_same_dimensions(dest,src) == true + - if (A.num_samples() == 1) then + - B.num_samples() == 1 + - else + - A.num_samples() == src.num_samples() + - B.num_samples() == src.num_samples() + - A.nr() == B.nr() == src.nr() + - A.nc() == B.nc() == src.nc() + - A.k() == B.k() == src.k() + ensures + - if (A.num_samples() == 1) then + - #dest == A*src + B + (done for each sample in src) + - else + - for all valid i: + - #dest.host()[i] == A.host()[i]*src.host()[i] + B.host()[i] + !*/ + +// ---------------------------------------------------------------------------------------- + + void affine_transform_conv( + tensor& dest, + const tensor& src, + const tensor& A, + const tensor& B + ); + /*! + requires + - have_same_dimensions(dest,src) == true + - have_same_dimensions(A, B) == true + - A.num_samples() == 1 + - A.nr() == 1 + - A.nc() == 1 + - A.k() == src.k() + ensures + - Performs #dest == A*src + B + In particular, if the elements of dest and src were indexed by (n,k,r,c) then + we would have: + #dest(n,k,r,c) == A(k)*src(n,k,r,c) + B(k). + !*/ + +// ---------------------------------------------------------------------------------------- + + void compute_adam_update ( + size_t begin, + size_t end, + tensor& s, + tensor& m, + tensor& v, + const float t, + const float learning_rate, + const float weight_decay, + const float momentum1, + const float momentum2, + const tensor& params, + const tensor& params_grad + ); + /*! + requires + - s.size() == m.size() = v.size() == params.size() == params_grad.size() + - t > 0 + - learning_rate > 0 + - weight_decay >= 0 + - 0 <= momentum1 < 1 + - 0 <= momentum2 < 1 + - begin <= end <= params.size() + ensures + - This function implements the ADAM parameter update method described in the paper: + Kingma, Diederik P., and Jimmy Ba Adam. "A method for stochastic + optimization." International Conference on Learning Representation. 2015. + Specifically, it implements the method shown as Algorithm 1. + - #s is the update vector that should be added to the parameters. + - The function only operates in the half open range [begin,end) of the memory + blocks of each tensor. E.g. to make this function run on the entire tensor + set begin to 0 and end to params.size(). + !*/ + +// ---------------------------------------------------------------------------------------- + + void batch_normalize_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ); + /*! + requires + - eps > 0 + - gamma.num_samples() == 1 + - gamma.nr() == src.nr() + - gamma.nc() == src.nc() + - gamma.k() == src.k() + - have_same_dimensions(gamma, beta) + - have_same_dimensions(gamma, running_means) + - have_same_dimensions(gamma, running_variances) + ensures + - Linearly transforms src as a call to batch_normalize() would if src had means + and variances as given by running_means and running_variances. That is, this + function performs: + dest = gamma*(src-running_means)/sqrt(running_variances+eps) + beta + Note that it does it in a pointwise fashion over the samples in src. + !*/ + + void batch_normalize ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ); + /*! + requires + - eps > 0 + - src.num_samples() > 1 + - gamma.num_samples() == 1 + - beta.num_samples() == 1 + - gamma.nr() == beta.nr() == src.nr() + - gamma.nc() == beta.nc() == src.nc() + - gamma.k() == beta.k() == src.k() + - 0 <= averaging_factor <= 1 + - if (averaging_factor != 1) + - have_same_dimensions(running_means, means) == true + - have_same_dimensions(running_variances, invstds) == true + ensures + - have_same_dimensions(#dest, src) == true + - #means.num_samples() == 1 + - #invstds.num_samples() == 1 + - means.nr() == invstds.nr() == src.nr() + - means.nc() == invstds.nc() == src.nc() + - means.k() == invstds.k() == src.k() + - #src == the batch normalized version of src. + - #means == the mean values of the contents of src. + - #invstds == 1/(the standard deviation values of the contents of src). + - #running_means = (1-averaging_factor)*mat(#running_means) + averaging_factor*mat(#means); + - #running_variances = (1-averaging_factor)*mat(#running_variances) + averaging_factor*(variance of contents of src); + !*/ + + void batch_normalize_gradient ( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ); + /*! + requires + - eps > 0 + - invstds and means should be the output of a call to + batch_normalize(eps,dest,means,invstds,src,gamma,beta) + - have_same_dimensions(gradient_input, src) == true + - have_same_dimensions(src, src_grad) == true + - src.num_samples() > 1 + - gamma.num_samples() == 1 + - have_same_dimensions(gamma, gamma_grad) == true + - have_same_dimensions(gamma, beta_grad) == true + - gamma.nr() == src.nr() + - gamma.nc() == src.nc() + - gamma.k() == src.k() + - have_same_dimensions(means, gamma) == true + - have_same_dimensions(invstds, gamma) == true + ensures + - Let f(src,gamma,beta) == dot(gradient_input, dest output of + batch_normalize(eps,dest,means,invstds,src,gamma,beta)) + - Adds the gradient of f() with respect to src to #src_grad. + - Assigns the gradient of f() with respect to gamma to #gamma_grad. + - Assigns the gradient of f() with respect to beta to #beta_grad. + !*/ + +// ---------------------------------------------------------------------------------------- + + void batch_normalize_conv_inference ( + const double eps, + resizable_tensor& dest, + const tensor& src, + const tensor& gamma, + const tensor& beta, + const tensor& running_means, + const tensor& running_variances + ); + /*! + requires + - eps > 0 + - gamma.num_samples() == 1 + - gamma.nr() == 1 + - gamma.nc() == 1 + - gamma.k() == src.k() + - have_same_dimensions(gamma, beta) + - have_same_dimensions(gamma, running_means) + - have_same_dimensions(gamma, running_variances) + ensures + - Linearly transforms src as a call to batch_normalize_conv() would if src had + means and variances as given by running_means and running_variances. That + is, this function performs: + dest = gamma*(src-running_means)/sqrt(running_variances+eps) + beta + Note that it does this in a pointwise fashion over the samples, rows, and + columns in src. + !*/ + + void batch_normalize_conv ( + const double eps, + resizable_tensor& dest, + resizable_tensor& means, + resizable_tensor& invstds, + const double averaging_factor, + resizable_tensor& running_means, + resizable_tensor& running_variances, + const tensor& src, + const tensor& gamma, + const tensor& beta + ); + /*! + requires + - eps > 0 + - src.num_samples() > 1 + - gamma.num_samples()==gamma.nr()==gamma.nc() == 1 + - beta.num_samples() ==beta.nr() ==gamma.nc() == 1 + - gamma.k() == beta.k() == src.k() + - 0 <= averaging_factor <= 1 + - if (averaging_factor != 1) + - have_same_dimensions(running_means, means) == true + - have_same_dimensions(running_variances, invstds) == true + ensures + - have_same_dimensions(#dest, src) == true + - #means.num_samples()==means.nr()==means.nc() == 1 + - #invstds.num_samples() ==invstds.nr() ==invstds.nc() == 1 + - means.k() == invstds.k() == src.k() + - #src == the batch normalized version of src. + - #means == the mean values of the contents of src. + - #invstds == 1/(the standard deviation values of the contents of src). + - #running_means = (1-averaging_factor)*mat(#running_means) + averaging_factor*mat(#means); + - #running_variances = (1-averaging_factor)*mat(#running_variances) + averaging_factor*(variance of contents of src); + !*/ + + void batch_normalize_conv_gradient ( + const double eps, + const tensor& gradient_input, + const tensor& means, + const tensor& invstds, + const tensor& src, + const tensor& gamma, + tensor& src_grad, + tensor& gamma_grad, + tensor& beta_grad + ); + /*! + requires + - eps > 0 + - invstds and means should be the output of a call to + batch_normalize_conv(eps,dest,means,invstds,src,gamma,beta) + - have_same_dimensions(gradient_input, src) == true + - have_same_dimensions(src, src_grad) == true + - src.num_samples() > 1 + - gamma.num_samples()==gamma.nr()==gamma.nc() == 1 + - have_same_dimensions(gamma, gamma_grad) == true + - have_same_dimensions(gamma, beta_grad) == true + - gamma.k() == src.k() + - have_same_dimensions(means, gamma) == true + - have_same_dimensions(invstds, gamma) == true + ensures + - Let f(src,gamma,beta) == dot(gradient_input, dest output of + batch_normalize_conv(eps,dest,means,invstds,src,gamma,beta)) + - Adds the gradient of f() with respect to src to #src_grad. + - Assigns the gradient of f() with respect to gamma to #gamma_grad. + - Assigns the gradient of f() with respect to beta to #beta_grad. + !*/ + +// ----------------------------------------------------------------------------------- + + void threshold ( + tensor& data, + float thresh + ); + /*! + ensures + - Sets all elements of data to 1 or 0 depending on if they are above or below + the given threshold. Specifically, for all valid i: + - #data.host()[i] == data.host()[i]>thresh ? 1 : 0 + !*/ + + void dot ( + const tensor& a, + const tensor& b, + tensor& result, + size_t idx + ); + /*! + requires + - a.size() == b.size() + - idx < result.size() + ensures + - #result.host()[idx] == result.host()[idx] + dot(a,b); + I.e. Adds the dot product between a and b into the idx-th element of result. + The reason you might want to use this more complex version of dot() is + because, when using CUDA, it runs by generating asynchronous kernel launches + whereas the version of dot() that returns the result immediately as a scalar + must block the host while we wait for the result to be computed and then + transfered from the GPU do the host for return by dot(). So this version of + dot() might be much faster in some cases. + !*/ + +// ---------------------------------------------------------------------------------------- + + void add( + float beta, + tensor& dest, + float alpha, + const tensor& src + ); + /*! + requires + - One of the following is true: + - have_same_dimensions(src, dest) + - src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1 + - src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc() + - src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc() + - src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1 + - is_same_object(src,dest) == false + ensures + - performs: dest = beta*dest + alpha*src + However, how the addition happens depends on the dimensions of src. In + particular, this function adds the scaled values of one src tensor to dest. + Each dimension of the src tensor must match the corresponding dimension of + the dest tensor or must be equal to 1. In the latter case, the same value + from the src tensor, for those dimensions, will be used to add into the dest + tensor. + !*/ + +// ---------------------------------------------------------------------------------------- + + void add ( + tensor& dest, + const tensor& src1, + const tensor& src2 + ); + /*! + ensures + - performs: dest = src1 + src2 + The addition happens pointwise according to 4D tensor arithmetic. If the + dimensions don't match then missing elements are presumed to be equal to 0. + !*/ + +// ---------------------------------------------------------------------------------------- + + void assign_conv_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ); + /*! + requires + - grad.num_samples() == 1 + - grad.k() >= 1 + - grad.nr() == 1 + - grad.nc() == 1 + - gradient_input.k() == grad.k() + - gradient_input.size() > 0 + - is_same_object(grad,gradient_input) == false + ensures + - let BIAS be a tensor with the same dimensions as grad. + - let OUT be the output of add(1,OUT,1,BIAS) + - let f(gradient_input,BIAS) == dot(gradient_input,OUT) + - Then this function computes the gradient of f() with respect to BIAS and + assigns it to grad. + !*/ + +// ---------------------------------------------------------------------------------------- + + void assign_bias_gradient ( + tensor& grad, + const tensor& gradient_input + ); + /*! + requires + - grad.num_samples() == 1 + - gradient_input.k() == grad.k() + - gradient_input.nr() == grad.nr() + - gradient_input.nc() == grad.nc() + - gradient_input.size() > 0 + - is_same_object(grad,gradient_input) == false + ensures + - let BIAS be a tensor with the same dimensions as grad. + - let OUT be the output of add(1,OUT,1,BIAS) + - let f(gradient_input,BIAS) == dot(gradient_input,OUT) + - Then this function computes the gradient of f() with respect to BIAS and + assigns it to grad. + !*/ + +// ---------------------------------------------------------------------------------------- + + class tensor_conv + { + public: + tensor_conv(const tensor_conv&) = delete; + tensor_conv& operator=(const tensor_conv&) = delete; + + tensor_conv() {} + + void clear( + ) { impl.clear(); } + + void operator() ( + const bool add_to_output, + tensor& output, + const tensor& data, + const tensor& filters + ) { impl(add_to_output,output,data,filters); } + /*! + requires + - setup() has been called. Specifically, setup() has been called like this: + this->setup(data, filters, stride_y, stride_x, padding_y, padding_x); + - is_same_object(output,data) == false + - is_same_object(output,filters) == false + - filters.k() == data.k() + - filters.nr() <= src.nr() + 2*padding_y + - filters.nc() <= src.nc() + 2*padding_x + - #output.num_samples() == data.num_samples() + - #output.k() == filters.num_samples() + - #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y + - #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x + ensures + - Convolves filters over data. If add_to_output==true then we add the + results to output, otherwise we assign to output, overwriting the + previous values in output. + - filters contains filters.num_samples() filters. + !*/ + + void operator() ( + const bool add_to_output, + resizable_tensor& output, + const tensor& data, + const tensor& filters + ) { impl(add_to_output,output,data,filters); } + /*! + requires + - setup() has been called. Specifically, setup() has been called like this: + this->setup(data, filters, stride_y, stride_x, padding_y, padding_x); + - is_same_object(output,data) == false + - is_same_object(output,filters) == false + - filters.k() == data.k() + - filters.nr() <= src.nr() + 2*padding_y + - filters.nc() <= src.nc() + 2*padding_x + ensures + - Convolves filters over data. If add_to_output==true then we add the + results to output, otherwise we assign to output, overwriting the + previous values in output. + - filters contains filters.num_samples() filters. + - #output.num_samples() == data.num_samples() + - #output.k() == filters.num_samples() + - #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y + - #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x + !*/ + + void get_gradient_for_data ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& filters, + tensor& data_gradient + ) { impl.get_gradient_for_data(add_to_output,gradient_input,filters,data_gradient); } + /*! + requires + - One of the following must be true: + - filters has the same dimensions as the filters object given to the + last call to operator(). Also, data_gradient has the same dimensions + as the data object given to the last call to operator(). + - setup() has been called. Specifically, setup() has been called like this: + this->setup(data_gradient, filters, stride_y, stride_x, padding_y, padding_x); + - gradient_input has the following dimensions: + - gradient_input.num_samples() == data_gradient.num_samples() + - gradient_input.k() == filters.num_samples() + - gradient_input.nr() == 1+(data_gradient.nr() + 2*padding_y - filters.nr())/stride_y + - gradient_input.nc() == 1+(data_gradient.nc() + 2*padding_x - filters.nc())/stride_x + - NOTE, these dimensions are what you would obtain if gradient_input + has the same dimensions as the last output of operator(). + - is_same_object(data_gradient,filters) == false + - is_same_object(data_gradient,gradient_input) == false + ensures + - let OUT be the output of (*this)(OUT,data,filters,sx,sy). + - let f(data,filters) == dot(OUT, gradient_input) + - if (add_to_output) then + - This function finds the gradient of f() with respect to data and adds + this gradient to data_gradient. + - else + - This function finds the gradient of f() with respect to data and + assigns this gradient to data_gradient, overwriting the previous + values in data_gradient. + !*/ + + void get_gradient_for_filters ( + const bool add_to_output, + const tensor& gradient_input, + const tensor& data, + tensor& filters_gradient + ) { impl.get_gradient_for_filters(add_to_output,gradient_input,data,filters_gradient); } + /*! + requires + - One of the following must be true: + - filters_gradient has the same dimensions as the filters object given + to the last call to operator(). Also, data has the same dimensions + as the data object given to the last call to operator(). + - setup() has been called. Specifically, setup() has been called like this: + this->setup(data, filters_gradient, stride_y, stride_x, padding_y, padding_x); + - gradient_input has the following dimensions: + - gradient_input.num_samples() == data.num_samples() + - gradient_input.k() == filters.num_samples() + - gradient_input.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y + - gradient_input.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x + - NOTE, these dimensions are what you would obtain if gradient_input + has the same dimensions as the last output of operator(). + - is_same_object(filters_gradient,data) == false + - is_same_object(filters_gradient,gradient_input) == false + ensures + - let OUT be the output of (*this)(OUT,data,filters,sx,sy). + - let f(data,filters) == dot(OUT, gradient_input) + - if (add_to_output) then + - This function finds the gradient of f() with respect to filters and + adds this gradient to filters_gradient. + - else + - This function finds the gradient of f() with respect to filters and + assigns this gradient to filters_gradient, overwriting the previous + values in filters_gradient. + !*/ + + + void setup( + const tensor& data, + const tensor& filters, + int stride_y, + int stride_x, + int padding_y, + int padding_x + ) {impl.setup(data,filters,stride_y,stride_x,padding_y,padding_x); } + /*! + requires + - filters.k() == data.k() + - stride_y > 0 + - stride_x > 0 + - 0 <= padding_y < filters.nr() + - 0 <= padding_x < filters.nc() + ensures + - When operator() is called, the output tensor will have these dimensions: + - output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y + - output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x + - output.num_samples() == data.num_samples() + - output.k() == filters.num_samples() + - The point of setup() is to allow this object to gather information about + all the tensor sizes and filter layouts involved in the computation. In + particular, the reason the tensors are input into setup() is just to + observe their sizes. setup() doesn't do anything with the contents of + the tensors, or store any kind of references to the data or filter + tensors. + !*/ + + private: +#ifdef DLIB_USE_CUDA + cuda::tensor_conv impl; +#else + cpu::tensor_conv impl; +#endif + + }; + +// ---------------------------------------------------------------------------------------- + + class pooling + { + /*! + WHAT THIS OBJECT REPRESENTS + The pooling object is a tool for performing spatial pooling over a tensor. + It can be configured to do either max or average pooling. + !*/ + public: + + pooling(const pooling&) = delete; + pooling& operator=(const pooling&) = delete; + + pooling ( + ) = default; + + void clear( + ) { impl.clear(); } + + void setup_max_pooling( + int window_height, + int window_width, + int stride_y, + int stride_x, + int padding_y, + int padding_x + ) { impl.setup_max_pooling(window_height, window_width, stride_y, stride_x, padding_y, padding_x); } + /*! + requires + - window_height > 0 + - window_width > 0 + - stride_y > 0 + - stride_x > 0 + - 0 <= padding_y < window_height + - 0 <= padding_x < window_width + ensures + - When you call operator() it will do max pooling with the given + parameters. + !*/ + + void setup_avg_pooling( + int window_height, + int window_width, + int stride_y, + int stride_x, + int padding_y, + int padding_x + ) { impl.setup_avg_pooling(window_height, window_width, stride_y, stride_x, padding_y, padding_x); } + /*! + requires + - window_height > 0 + - window_width > 0 + - stride_y > 0 + - stride_x > 0 + - 0 <= padding_y < window_height + - 0 <= padding_x < window_width + ensures + - When you call operator() it will do average pooling with the given + parameters. + !*/ + + bool does_max_pooling( + ) const { return impl.does_max_pooling(); } + + void operator() ( + resizable_tensor& dest, + const tensor& src + ) { impl(dest, src); } + /*! + requires + - is_same_object(dest,src) == false + - either setup_max_pooling() or setup_avg_pooling() has been called. + - window_width <= src.nc() + 2*padding_x + - window_height <= src.nr() + 2*padding_y + ensures + - #dest.num_samples() == src.num_samples() + - #dest.k() == src.k() + - #dest.nr() == 1 + (src.nr() + 2*padding_y - window_height)/stride_y + - #dest.nc() == 1 + (src.nc() + 2*padding_x - window_width)/stride_x + - WINDOW == centered_rect(x*stride_x + window_width/2 - padding_x, + y*stride_y + window_height/2 - padding_y, + window_width, + window_height) + - for all valid s, k, r, and c: + - if (does_max_pooling()) then + - image_plane(#dest,s,k)(r,c) == max(subm_clipped(image_plane(src,s,k),WINDOW(c,r))) + - else + - image_plane(#dest,s,k)(r,c) == mean(subm_clipped(image_plane(src,s,k),WINDOW(c,r))) + !*/ + + void get_gradient( + const tensor& gradient_input, + const tensor& dest, + const tensor& src, + tensor& grad + ) { impl.get_gradient(gradient_input, dest, src, grad); } + /*! + requires + - have_same_dimensions(gradient_input,dest) == true + - have_same_dimensions(src,grad) == true + - dest contains the result of calling (*this)(dest,src) + - is_same_object(grad,gradient_input) == false + - is_same_object(grad,dest) == false + - is_same_object(grad,src) == false + ensures + - Recalling that dest is the output of (*this)(dest,src), + let f(src) == dot(gradient_input,dest) + - Then this function computes the gradient of f() with respect to src and + adds it to grad. + !*/ + + private: +#ifdef DLIB_USE_CUDA + cuda::pooling impl; +#else + cpu::pooling impl; +#endif + }; + +// ---------------------------------------------------------------------------------------- + + void softmax ( + tensor& dest, + const tensor& src + ); + /*! + requires + - have_same_dimensions(dest, src) == true + ensures + - Note that the softmax function is a vector valued function: + s(x) == exp(x)/sum(exp(x)) + - Computes the softmax function on src and writes the results to dest. The + softmax is computed per spatial location across the different channels at + each location. That is, softmax() outputs a new tensor, #dest, where each of + the spatial locations in dest (i.e. image idx, row idx, and column idx) + contains the output of s() evaluated over the channel values at each + location. + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void softmax_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + /*! + requires + - have_same_dimensions(dest,gradient_input) == true + - have_same_dimensions(dest,grad) == true + ensures + - We interpret dest as the output of softmax(dest,SRC) for some SRC tensor. + Then let f(SRC) == dot(gradient_input,dest). Then this function computes the + gradient of f() with respect to SRC and stores it to grad. Moreover, if + is_same_object(grad,gradient_input)==true then the output is assigned to + grad, replacing its previous contents. Otherwise the output is added to + grad. + - This function supports in-place operation, i.e. having + is_same_object(grad, gradient_input)==true + !*/ + +// ---------------------------------------------------------------------------------------- + + void softmax_all ( + tensor& dest, + const tensor& src + ); + /*! + requires + - have_same_dimensions(dest, src) == true + ensures + - Note that the softmax function is a vector valued function: + s(x) == exp(x)/sum(exp(x)) + - Computes the softmax function on src and writes the results to dest. The + softmax is computed over the entire tensor with one invocation of s(). So + unlike softmax() which computes many s() evaluations, one for each spatial + location, softmax_all() calls s() once for the entire tensor. + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void softmax_all_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + /*! + requires + - have_same_dimensions(dest,gradient_input) == true + - have_same_dimensions(dest,grad) == true + - is_same_object(grad, dest)==false + ensures + - We interpret dest as the output of softmax_all(dest,SRC) for some SRC tensor. + Then let f(SRC) == dot(gradient_input,dest) Then this function computes the + gradient of f() with respect to SRC and assigns it to grad. + - This function supports in-place operation, i.e. having + is_same_object(grad, gradient_input)==true + !*/ + +// ---------------------------------------------------------------------------------------- + + void sigmoid ( + tensor& dest, + const tensor& src + ); + /*! + requires + - have_same_dimensions(dest, src) == true + ensures + - for all valid i: + - #dest.host()[i] == 1/(1+std::exp(-src.host()[i])) + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void sigmoid_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + /*! + requires + - have_same_dimensions(dest,gradient_input) == true + - have_same_dimensions(dest,grad) == true + ensures + - Recalling that dest is the output of sigmoid(dest,SRC) for some SRC tensor, + let f(SRC) == dot(gradient_input,dest). Then this function computes the + gradient of f() with respect to SRC and stores it to grad. Moreover, if + is_same_object(grad,gradient_input)==true then the output is assigned to + grad, replacing its previous contents. Otherwise the output is added to + grad. + - This function supports in-place operation, i.e. having + is_same_object(grad, gradient_input)==true + !*/ + +// ---------------------------------------------------------------------------------------- + + void relu ( + tensor& dest, + const tensor& src + ); + /*! + requires + - have_same_dimensions(dest, src) == true + ensures + - for all valid i: + - #dest.host()[i] == std::max(0,src.host()[i]) + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void relu_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + /*! + requires + - have_same_dimensions(dest,gradient_input) == true + - have_same_dimensions(dest,grad) == true + ensures + - Recalling that dest is the output of relu(dest,SRC) for some SRC tensor, + let f(SRC) == dot(gradient_input,dest). Then this function computes the + gradient of f() with respect to SRC and stores it to grad. Moreover, if + is_same_object(grad,gradient_input)==true then the output is assigned to + grad, replacing its previous contents. Otherwise the output is added to + grad. + - This function supports in-place operation, i.e. having + is_same_object(grad, gradient_input)==true + !*/ + +// ---------------------------------------------------------------------------------------- + + void prelu ( + tensor& dest, + const tensor& src, + const tensor& param + ); + /*! + requires + - have_same_dimensions(dest, src) == true + - param.size() == 1 + ensures + - for all valid i: + - if (src.host()[i] > 0) then + - #dest.host()[i] == src.host()[i] + - else + - #dest.host()[i] == src.host()[i] * param.host()[0] + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void prelu_gradient ( + tensor& grad, + const tensor& src, + const tensor& gradient_input, + const tensor& param, + tensor& params_grad + ); + /*! + requires + - have_same_dimensions(grad,src) == true + - have_same_dimensions(grad,gradient_input) == true + - param.size() == 1 + - params_grad.size() == 1 + - is_same_object(grad, gradient_input) == false + ensures + - Recalling that dest is the output of prelu(dest,src,param) let + f(src,param) == dot(gradient_input,dest) + - Then this function computes the gradient of f() with respect to src and + param. It assigns the gradient with respect to param to #params_grad and + adds the gradient with respect to src to #grad. + !*/ + +// ---------------------------------------------------------------------------------------- + + void tanh ( + tensor& dest, + const tensor& src + ); + /*! + requires + - have_same_dimensions(dest, src) == true + ensures + - for all valid i: + - #dest.host()[i] == std::tanh(src.host()[i]) + - This function supports in-place operation, i.e. having + is_same_object(dest, src)==true + !*/ + + void tanh_gradient ( + tensor& grad, + const tensor& dest, + const tensor& gradient_input + ); + /*! + requires + - have_same_dimensions(dest,gradient_input) == true + - have_same_dimensions(dest,grad) == true + ensures + - Recalling that dest is the output of tanh(dest,SRC) for some SRC tensor, + let f(SRC) == dot(gradient_input,dest). Then this function computes the + gradient of f() with respect to SRC and stores it to grad. Moreover, if + is_same_object(grad,gradient_input)==true then the output is assigned to + grad, replacing its previous contents. Otherwise the output is added to + grad. + - This function supports in-place operation, i.e. having + is_same_object(grad, gradient_input)==true + !*/ + +// ---------------------------------------------------------------------------------------- + + void resize_bilinear ( + tensor& dest, + long dest_row_stride, + long dest_channel_stride, + const tensor& src, + long src_row_stride, + long src_channel_stride + ); + /*! + requires + - is_same_object(dest, src)==false + - dest.num_samples() == src.num_samples() + - dest.k() == src.k() + ensures + - for all valid i,k: image_plane(dest,i,k) is a copy of image_plane(src,i,k) + that has been bilinearly interpolated to fit into the shape of + image_plane(dest,i,k). + - Instead of supposing the row stride and channel stride in the tensors is + given by tensor::nc() and tensor::nr()*tensor::nc() respectively, we use the + provided stride values to transition from one row and channel to the next. + This is useful in combination with alias_tensor objects since it allows you + to operate on subwindows in an image. + !*/ + + void resize_bilinear_gradient ( + tensor& grad, + long grad_row_stride, + long grad_channel_stride, + const tensor& gradient_input, + long gradient_input_row_stride, + long gradient_input_channel_stride + ); + /*! + requires + - is_same_object(grad, gradient_input)==false + - gradient_input.num_samples() == grad.num_samples() + - gradient_input.k() == grad.k() + ensures + - Suppose that DEST is the output of resize_bilinear(DEST,SRC) for some SRC + tensor, let f(SRC) == dot(gradient_input,DEST). Then this function computes + the gradient of f() with respect to SRC and adds it to grad. It should be + noted that we don't need to know the contents of DEST to compute this + gradient. All that matters is that gradient_input have the same dimensions + as DEST. + - Instead of supposing the row stride and channel stride in the tensors is + given by tensor::nc() and tensor::nr()*tensor::nc() respectively, we use the + provided stride values to transition from one row and channel to the next. + This is useful in combination with alias_tensor objects since it allows you + to operate on subwindows in an image. + !*/ + + inline void resize_bilinear ( + tensor& dest, + const tensor& src + ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); } + /*! + requires + - is_same_object(dest, src)==false + - dest.num_samples() == src.num_samples() + - dest.k() == src.k() + ensures + - for all valid i,k: image_plane(dest,i,k) is a copy of image_plane(src,i,k) + that has been bilinearly interpolated to fit into the shape of + image_plane(dest,i,k). + !*/ + + inline void resize_bilinear_gradient ( + tensor& grad, + const tensor& gradient_input + ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); } + /*! + requires + - is_same_object(grad, gradient_input)==false + - gradient_input.num_samples() == grad.num_samples() + - gradient_input.k() == grad.k() + ensures + - Suppose that DEST is the output of resize_bilinear(DEST,SRC) for some SRC + tensor, let f(SRC) == dot(gradient_input,DEST). Then this function computes + the gradient of f() with respect to SRC and adds it to grad. It should be + noted that we don't need to know the contents of DEST to compute this + gradient. All that matters is that gradient_input have the same dimensions + as DEST. + !*/ + +// ---------------------------------------------------------------------------------------- + + class multi_device_tensor_averager + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for very quickly averaging a bunch of tensors + together. + !*/ + public: + + multi_device_tensor_averager(const multi_device_tensor_averager&) = delete; + multi_device_tensor_averager& operator=(const multi_device_tensor_averager&) = delete; + + multi_device_tensor_averager() = default; + + void set( + std::vector items + ) + /*! + requires + - All the tensors in items are the same size + ensures + - When you call average() we will average the tensors in items. + - It's important that the tensors already be allocated to their devices + before you call set(). This is because set() will setup the types of + between device transfers now and use them when you call average(). + !*/ + { + using namespace ::dlib::cuda; + accessible_groups.clear(); + epa.clear(); + if (items.size() < 1) + return; + + scale = 1.0/items.size(); + + // split item into groups of accessible devices + std::vector group, unused; + while(items.size() > 0) + { + group.push_back(items[0]); + for(size_t i = 1; i < items.size(); ++i) + { + if (can_access_peer(*items[0], *items[i])) + group.push_back(items[i]); + else + unused.push_back(items[i]); + } + accessible_groups.push_back(group); + unused.swap(items); + unused.clear(); + group.clear(); + } + for (auto&& g : accessible_groups) + { + for (size_t i = 1; i < g.size(); ++i) + { + epa.emplace_back(new enable_peer_access(*g[0], *g[i])); + } + } + } + + size_t num_device_groups( + ) const { return accessible_groups.size(); } + /*! + ensures + - The devices given to set() are grouped together when they can directly + access each other using GPUDirect. This function returns the number of + such groups. For example, if all devices can directly access each other + then the number of groups is 1. + !*/ + + void average() + /*! + requires + - All the devices have stopped writing to the tensors given to set(). So + you should probably call cudaDeviceSynchronize() on each of the relevant + devices before calling average(). + ensures + - Computes the average of all the tensors given to set() and then sets them + all equal to the average. + !*/ + { + using namespace ::dlib::cuda; + + + // First we average things within each group + for (auto&& g : accessible_groups) + { + raii_set_device set_dev(*g[0]); + if (g.size() == 1) + tt::affine_transform(*g[0], *g[0], scale); + else + tt::affine_transform(*g[0], *g[0], *g[1], scale, scale); + + for (size_t i = 2; i < g.size(); ++i) + tt::affine_transform(*g[0], *g[0], *g[i], 1, scale); + } + + if (accessible_groups.size() > 1) + { + tensor& total_avg = *accessible_groups[0][0]; + raii_set_device set_dev(total_avg); + accum_buffer.copy_size(total_avg); + // now we need to average things across groups + for (size_t i = 1; i < accessible_groups.size(); ++i) + { + memcpy(accum_buffer, *accessible_groups[i][0]); + tt::add(total_avg, total_avg, accum_buffer); + } + + // Now total_avg has the final average in it. So we need to send + // copies of it back to each of the groups. + for (size_t i = 1; i < accessible_groups.size(); ++i) + { + memcpy(*accessible_groups[i][0], total_avg); + } + } + + + // Now propagate averages back out to each element using point to point + // communication inside a group. + for (auto&& g : accessible_groups) + { + raii_set_device set_dev(*g[0]); + for (size_t i = 1; i < g.size(); ++i) + memcpy(*g[i], *g[0]); + } + } + + private: + std::vector> epa; + std::vector> accessible_groups; + float scale; + + resizable_tensor accum_buffer; + }; + +// ---------------------------------------------------------------------------------------- + + void copy_tensor( + bool add_to, + tensor& dest, + size_t dest_k_offset, + const tensor& src, + size_t src_k_offset, + size_t count_k + ); + /*! + requires + - dest.nc() == src.nc() + - dest.nr() == src.nr() + - dest.num_samples() == src.num_samples() + - dest.k() - dest_k_offset >= count_k + - src.k() - src_k_offset >= count_k + - is_same_object(dest,src) == false + - The memory areas of src and dest do not overlap. + ensures + - if (add_to) then + - performs: dest[i, k + dest_k_offset, r, c] += src[i, k + src_k_offset, r, c], where k in [0..count_k] + i.e., adds content of each sample from src in to corresponding place of sample at dest. + - else + - performs: dest[i, k + dest_k_offset, r, c] = src[i, k + src_k_offset, r, c], where k in [0..count_k] + i.e., copies content of each sample from src in to corresponding place of sample at dest. + !*/ + +// ---------------------------------------------------------------------------------------- + +}} + +#ifdef NO_MAKEFILE +#include "tensor_tools.cpp" +#endif + +#endif // DLIB_TeNSOR_TOOLS_H_ + + diff --git a/ml/dlib/dlib/dnn/trainer.h b/ml/dlib/dlib/dnn/trainer.h new file mode 100644 index 000000000..7cb2bf5e5 --- /dev/null +++ b/ml/dlib/dlib/dnn/trainer.h @@ -0,0 +1,1333 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_TRAINER_H_ +#define DLIB_DNn_TRAINER_H_ + +#include "trainer_abstract.h" +#include "core.h" +#include "solvers.h" +#include "../statistics.h" +#include +#include +#include +#include "../serialize.h" + +#include "../pipe.h" +#include "../threads.h" +#include "cuda_dlib.h" +#include "../statistics/running_gradient.h" +#include +#include +#include +#include +#include +#include +#include "../dir_nav.h" +#include "../md5.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + struct dnn_job_t + { + dnn_job_t() = default; + dnn_job_t(const dnn_job_t&) = delete; + dnn_job_t& operator=(const dnn_job_t&) = delete; + + std::vector> labels; + std::vector t; + std::vector have_data; // have_data[i] is true if there is data in labels[i] and t[i]. + bool test_only = false; + }; + + template + void swap(dnn_job_t& a, dnn_job_t& b) + { + a.labels.swap(b.labels); + a.t.swap(b.t); + a.have_data.swap(b.have_data); + std::swap(a.test_only,b.test_only); + } + } + + enum class force_flush_to_disk { + no = 0, + yes = 1 + }; + + template < + typename net_type, + typename solver_type = sgd + > + class dnn_trainer : private threaded_object + { + public: + + static_assert(is_loss_layer_type::value, + "The last layer in a network must be a loss layer."); + + typedef typename net_type::training_label_type training_label_type; + typedef typename net_type::input_type input_type; + const static size_t num_computational_layers = net_type::num_computational_layers; + const static size_t num_layers = net_type::num_layers; + private: + typedef impl::dnn_job_t job_t; + public: + + dnn_trainer() = delete; + dnn_trainer(const dnn_trainer&) = delete; + dnn_trainer& operator=(const dnn_trainer&) = delete; + + explicit dnn_trainer(net_type& net_) : job_pipe(0), net(net_) + { + solver_type default_solver; + devices.push_back(std::make_shared(dlib::cuda::get_device(), net, default_solver)); + + init(); + } + + dnn_trainer( + net_type& net_, + const solver_type& solver_ + ) : job_pipe(0), net(net_) + { + devices.push_back(std::make_shared(dlib::cuda::get_device(), net, solver_)); + + init(); + } + + dnn_trainer( + net_type& net_, + const solver_type& solver_, + const std::vector& cuda_extra_devices + ) : job_pipe(0), net(net_) + { + devices.push_back(std::make_shared(dlib::cuda::get_device(), net, solver_)); + + const int total_devices = dlib::cuda::get_num_devices(); + + // Make device contexts for the extra device ids but be careful to avoid any + // duplicate ids. + std::set temp(cuda_extra_devices.begin(), cuda_extra_devices.end()); + temp.erase(devices[0]->device_id); + for (auto id : temp) + { + DLIB_CASSERT(0 <= id && id < total_devices, "Invalid CUDA device id given to dnn_trainer."); + // Switch to this device so that any tensor objects that get allocated when + // we create the device context happen on this device. + dlib::cuda::set_device(id); + devices.push_back(std::make_shared(id, net, solver_, clone_net())); + } + // Set the current device back to what it was before this constructor was + // called. + dlib::cuda::set_device(devices[0]->device_id); + + init(); + } + + ~dnn_trainer( + ) + { + job_pipe.disable(); + stop(); + wait(); + } + + net_type& get_net ( + force_flush_to_disk force_flush = force_flush_to_disk::yes + ) + { + wait_for_thread_to_pause(); + sync_to_disk(force_flush == force_flush_to_disk::yes); + propagate_exception(); + return net; + } + + + unsigned long get_mini_batch_size ( + ) const { return mini_batch_size; } + + void set_mini_batch_size ( + unsigned long batch_size + ) + { + DLIB_CASSERT(batch_size > 0); + mini_batch_size = batch_size; + } + + unsigned long get_max_num_epochs ( + ) const { return max_num_epochs; } + + void set_max_num_epochs ( + unsigned long num + ) + { + DLIB_CASSERT(num > 0); + max_num_epochs = num; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + + const std::vector& get_solvers ( + ) const + { + wait_for_thread_to_pause(); + propagate_exception(); + return devices[0]->solvers; + } + + void train_one_step ( + const std::vector& data, + const std::vector& labels + ) + { + DLIB_CASSERT(data.size() == labels.size()); + + train_one_step(data.begin(), data.end(), labels.begin()); + } + + template < + typename data_iterator, + typename label_iterator + > + void train_one_step ( + data_iterator dbegin, + data_iterator dend, + label_iterator lbegin + ) + { + DLIB_CASSERT(std::distance(dbegin, dend) > 0); + + print_periodic_verbose_status(); + sync_to_disk(); + send_job(false, dbegin, dend, lbegin); + + ++train_one_step_calls; + } + + void train_one_step ( + const std::vector& data + ) + { + train_one_step(data.begin(), data.end()); + } + + template < + typename data_iterator + > + void train_one_step ( + data_iterator dbegin, + data_iterator dend + ) + { + DLIB_CASSERT(std::distance(dbegin, dend) > 0); + print_periodic_verbose_status(); + sync_to_disk(); + send_job(false, dbegin, dend); + ++train_one_step_calls; + } + + void test_one_step ( + const std::vector& data, + const std::vector& labels + ) + { + DLIB_CASSERT(data.size() == labels.size()); + + test_one_step(data.begin(), data.end(), labels.begin()); + } + + template < + typename data_iterator, + typename label_iterator + > + void test_one_step ( + data_iterator dbegin, + data_iterator dend, + label_iterator lbegin + ) + { + DLIB_CASSERT(std::distance(dbegin, dend) > 0); + + print_periodic_verbose_status(); + sync_to_disk(); + send_job(true, dbegin, dend, lbegin); + + ++test_one_step_calls; + } + + void test_one_step ( + const std::vector& data + ) + { + test_one_step(data.begin(), data.end()); + } + + template < + typename data_iterator + > + void test_one_step ( + data_iterator dbegin, + data_iterator dend + ) + { + DLIB_CASSERT(std::distance(dbegin, dend) > 0); + print_periodic_verbose_status(); + sync_to_disk(); + send_job(true, dbegin, dend); + ++test_one_step_calls; + } + + void train ( + const std::vector& data, + const std::vector& labels + ) + { + DLIB_CASSERT(data.size() == labels.size() && data.size() > 0); + + // The reason these two loops don't initialize their counter variables but + // instead use class members is so we can include the state of the loops in the + // stuff written by sync_to_disk() + for (; + epoch_iteration < max_num_epochs && learning_rate >= min_learning_rate; + ++epoch_iteration) + { + using namespace std::chrono; + last_time = system_clock::now(); + clear_average_loss(); + for (; epoch_pos < data.size() && learning_rate >= min_learning_rate; epoch_pos += mini_batch_size) + { + if (verbose) + { + auto now_time = system_clock::now(); + if (now_time-last_time > seconds(20)) + { + last_time = now_time; + auto iter = epoch_iteration + epoch_pos/(double)data.size(); + std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " " + << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " + << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "; + print_progress(); + } + } + + sync_to_disk(); + send_job(false, data.begin()+epoch_pos, + data.begin()+std::min(epoch_pos+mini_batch_size,data.size()), + labels.begin()+epoch_pos); + } + epoch_pos = 0; + + if (verbose) + { + // Capitalize the E in Epoch so it's easy to grep out the lines that + // are for full epoch status statements. + std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " " + << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " + << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "; + print_progress(); + } + } + wait_for_thread_to_pause(); + // if we modified the network at all then be sure to sync the final result. + sync_to_disk(true); + } + + void train ( + const std::vector& data + ) + { + DLIB_CASSERT(data.size() > 0); + + const bool has_unsupervised_loss = std::is_same::value; + static_assert(has_unsupervised_loss, + "You can only call this version of train() when using an unsupervised loss."); + + // The reason these two loops don't initialize their counter variables but + // instead use class members is so we can include the state of the loops in the + // stuff written by sync_to_disk() + for (; + epoch_iteration < max_num_epochs && learning_rate >= min_learning_rate; + ++epoch_iteration) + { + using namespace std::chrono; + last_time = system_clock::now(); + clear_average_loss(); + for (; epoch_pos < data.size() && learning_rate >= min_learning_rate; epoch_pos += mini_batch_size) + { + if (verbose) + { + auto now_time = system_clock::now(); + if (now_time-last_time > seconds(20)) + { + last_time = now_time; + auto iter = epoch_iteration + epoch_pos/(double)data.size(); + std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " " + << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " + << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "; + print_progress(); + } + } + + sync_to_disk(); + send_job(false, data.begin()+epoch_pos, + data.begin()+std::min(epoch_pos+mini_batch_size,data.size())); + } + epoch_pos = 0; + + if (verbose) + { + // Capitalize the E in Epoch so it's easy to grep out the lines that + // are for full epoch status statements. + std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " " + << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " + << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "; + print_progress(); + } + } + wait_for_thread_to_pause(); + // if we modified the network at all then be sure to sync the final result. + sync_to_disk(true); + } + + void set_synchronization_file ( + const std::string& filename, + std::chrono::seconds time_between_syncs_ = std::chrono::minutes(15) + ) + { + last_sync_time = std::chrono::system_clock::now(); + sync_filename = filename; + time_between_syncs = time_between_syncs_; + + // check if the sync file already exists, if it does we should load it. + std::ifstream fin(newest_syncfile(), std::ios::binary); + if (fin) + deserialize(*this, fin); + } + + const std::string& get_synchronization_file ( + ) + { + return sync_filename; + } + + double get_average_loss ( + ) const + { + wait_for_thread_to_pause(); + return rs.mean(); + } + + double get_average_test_loss ( + ) const + { + wait_for_thread_to_pause(); + return rs_test.mean(); + } + + void clear_average_loss ( + ) + { + wait_for_thread_to_pause(); + rs.clear(); + } + + void set_learning_rate ( + double lr + ) + { + DLIB_CASSERT(lr > 0); + wait_for_thread_to_pause(); + if (learning_rate != lr) + { + steps_without_progress = 0; + test_steps_without_progress = 0; + previous_loss_values.clear(); + test_previous_loss_values.clear(); + } + learning_rate = lr; + lr_schedule.set_size(0); + } + + double get_learning_rate( + ) const + { + return learning_rate; + } + + void set_min_learning_rate ( + double lr + ) + { + DLIB_CASSERT(lr > 0); + wait_for_thread_to_pause(); + lr_schedule.set_size(0); + min_learning_rate = lr; + } + + double get_min_learning_rate ( + ) const + { + return min_learning_rate; + } + + template + void set_learning_rate_schedule ( + const matrix_exp& schedule + ) + { + DLIB_CASSERT(schedule.size() > 0); + DLIB_CASSERT(min(schedule) > 0); + set_learning_rate(schedule(0,0)); + set_min_learning_rate(min(schedule)); + set_learning_rate_shrink_factor(1); + lr_schedule = matrix_cast(reshape_to_column_vector(schedule)); + lr_schedule_pos = 0; + } + + const matrix& get_learning_rate_schedule ( + ) const + { + return lr_schedule; + } + + void set_iterations_without_progress_threshold ( + unsigned long thresh + ) + { + wait_for_thread_to_pause(); + lr_schedule.set_size(0); + iter_without_progress_thresh = thresh; + } + + unsigned long get_iterations_without_progress_threshold ( + ) const + { + return iter_without_progress_thresh; + } + + unsigned long get_steps_without_progress ( + ) const + { + return steps_without_progress; + } + + void set_test_iterations_without_progress_threshold ( + unsigned long thresh + ) + { + wait_for_thread_to_pause(); + lr_schedule.set_size(0); + test_iter_without_progress_thresh = thresh; + } + + unsigned long get_test_iterations_without_progress_threshold ( + ) const + { + return test_iter_without_progress_thresh; + } + + unsigned long get_test_steps_without_progress ( + ) const + { + return test_steps_without_progress; + } + + void set_learning_rate_shrink_factor ( + double shrink + ) + { + DLIB_CASSERT(0 < shrink && shrink <= 1); + wait_for_thread_to_pause(); + lr_schedule.set_size(0); + learning_rate_shrink = shrink; + steps_without_progress = 0; + test_steps_without_progress = 0; + } + + double get_learning_rate_shrink_factor ( + ) const + { + return learning_rate_shrink; + } + + unsigned long long get_train_one_step_calls ( + ) const + { + return train_one_step_calls; + } + + unsigned long long get_test_one_step_calls ( + ) const + { + return test_one_step_calls; + } + + private: + + void record_test_loss(double loss) + { + test_previous_loss_values.push_back(loss); + if (is_finite(loss)) + rs_test.add(loss); + // discard really old loss values. + while (test_previous_loss_values.size() > test_iter_without_progress_thresh) + test_previous_loss_values.pop_front(); + } + + void record_loss(double loss) + { + // This kind of budgeting causes our gradient checking to use a fixed amount of + // computational resources, regardless of the size of iter_without_progress_thresh. + gradient_check_budget += 200; + + rs.add(loss); + previous_loss_values.push_back(loss); + // discard really old loss values. + while (previous_loss_values.size() > iter_without_progress_thresh) + previous_loss_values.pop_front(); + } + + template + double compute_parameter_gradients(size_t device, job_t& next_job, const T&) + { + if (next_job.have_data[device]) + { + auto&& dev = *devices[device]; + dlib::cuda::set_device(dev.device_id); + if (next_job.test_only) + return dev.net.compute_loss(next_job.t[device], next_job.labels[device].begin()); + else + return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin()); + } + else + { + return 0; + } + } + + double compute_parameter_gradients(size_t device, job_t& next_job, const no_label_type&) + { + if (next_job.have_data[device]) + { + auto&& dev = *devices[device]; + dlib::cuda::set_device(dev.device_id); + no_label_type pick_which_run_update; + if (next_job.test_only) + return dev.net.compute_loss(next_job.t[device]); + else + return dev.net.compute_parameter_gradients(next_job.t[device]); + } + else + { + return 0; + } + } + + void update_parameters(size_t device) + { + auto&& dev = *devices[device]; + dlib::cuda::set_device(dev.device_id); + dev.net.update_parameters(make_sstack(dev.solvers), learning_rate); + } + + void thread() try + { + training_label_type pick_which_run_update; + job_t next_job; + + std::vector> losses(devices.size()); + + std::vector averagers; + // An array of all the parameter tensors in the first network. We will + // periodically copy these tensors to all the other devices to make sure the + // different GPUs don't go out of sync. + std::vector reference_params; + visit_layer_parameters(devices[0]->net, [&](size_t, tensor& t) { reference_params.push_back(&t); }); + + // We make separate thread pools with just one thread in them because we want + // to make sure each device is always executed on the same thread. We care + // about this because there are thread_local context variables for some cuda + // components and they get allocated for each combination of thread and device. + // So if we make sure the same device always uses the same thread this will + // reduce the number of contexts we allocate from num_devices*num_devices to + // just num_devices. + std::vector> tp; + for (size_t i = 0; i < devices.size(); ++i) + tp.push_back(std::make_shared(1)); + + + main_iteration_counter = 0; + while(job_pipe.dequeue(next_job)) + { + if (next_job.test_only) + { + // compute the testing loss + for (size_t i = 0; i < devices.size(); ++i) + tp[i]->add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]); + // aggregate loss values from all the network computations. + double theloss = 0; + for (auto&& loss : losses) + theloss += loss.get(); + record_test_loss(theloss/losses.size()); + + // Check if we should shrink the learning rate based on how the test + // error has been doing lately. + if (learning_rate_shrink != 1) + { + test_steps_without_progress = count_steps_without_decrease(test_previous_loss_values); + if (test_steps_without_progress >= test_iter_without_progress_thresh) + { + test_steps_without_progress = count_steps_without_decrease_robust(test_previous_loss_values); + if (test_steps_without_progress >= test_iter_without_progress_thresh) + { + // optimization has flattened out, so drop the learning rate. + learning_rate = learning_rate_shrink*learning_rate; + test_steps_without_progress = 0; + // Empty out some of the previous loss values so that test_steps_without_progress + // will decrease below test_iter_without_progress_thresh. + for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt) + test_previous_loss_values.pop_front(); + } + } + } + continue; + } + + updated_net_since_last_sync = true; + ++main_iteration_counter; + // Call compute_parameter_gradients() and update_parameters() but pick the + // right version for unsupervised or supervised training based on the type + // of training_label_type. + for (size_t i = 0; i < devices.size(); ++i) + tp[i]->add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]); + // aggregate loss values from all the network computations. + double theloss = 0; + for (auto&& loss : losses) + theloss += loss.get(); + record_loss(theloss/losses.size()); + + // Now, if there is more than one active device we need to synchronize the + // gradient updates between devices. So we do that now. + if (devices.size() > 1) + { + // if this is the first iteration then we need to setup the averagers. + // We can't do this outside the loop because the tensors that get + // averaged need to be allocated to their devices before we call set() + // so that the averagers can determine how best to average them. + if (averagers.size() == 0 || sync_file_reloaded) + { + averagers = std::vector(net_type::num_computational_layers); + // setup the averagers to point to the tensors in the networks. + std::vector> all_tensors(devices.size()); + for (size_t i = 0; i < all_tensors.size(); ++i) + { + all_tensors[i].resize(net_type::num_computational_layers); + visit_layer_parameter_gradients(devices[i]->net, [&](size_t j, tensor& t){ + all_tensors[i][j] = &t; + }); + } + // Now set each averager to average the tensors at the same layer in each + // network. + for (size_t i = 0; i < net_type::num_computational_layers; ++i) + { + std::vector temp(all_tensors.size()); + for (size_t j = 0; j < all_tensors.size(); ++j) + temp[j] = all_tensors[j][i]; + // ignore layers that don't have parameters + if (temp[0]->size() != 0) + averagers[i].set(temp); + } + + sync_file_reloaded = false; + } + + + for (auto&& d : devices) + cuda::device_synchronize(d->device_id); + + for (auto&& avg : averagers) + avg.average(); + } + + + // Now apply all the updates to each device. + for (size_t i = 0; i < devices.size(); ++i) + tp[i]->add_task_by_value([&,i](){ if (next_job.have_data[i]) update_parameters(i); }); + // and wait for the updates to all happen. + for (size_t i = 0; i < devices.size(); ++i) + tp[i]->wait_for_all_tasks(); + + + // Every now and then force all the parameters to be the same just to make + // sure they aren't drifting apart due to any non-deterministic behavior on + // the GPU. It's also important to do this on the first iteration because + // the different networks may be initialized differently when tensor data + // is first passed through them. So this code block deals with these + // issues. + if (devices.size() > 1 && main_iteration_counter%2000 == 1) + { + for (size_t i = 1; i < devices.size(); ++i) + { + visit_layer_parameters(devices[i]->net, [&](size_t j, tensor& t) + { + memcpy(t, *reference_params[j]); + }); + } + } + + // If we have been running for a while then check if the loss is still + // dropping. If it isn't then we will reduce the learning rate. Note that we + // have a "budget" that prevents us from calling + // count_steps_without_decrease() every iteration. We do this because + // it can be expensive to compute when previous_loss_values is large. + if (gradient_check_budget > iter_without_progress_thresh && learning_rate_shrink != 1) + { + gradient_check_budget = 0; + steps_without_progress = count_steps_without_decrease(previous_loss_values); + if (steps_without_progress >= iter_without_progress_thresh) + { + // Double check that we aren't seeing decrease. This second check + // discards the top 10% largest values and checks again. We do + // this because sometimes a mini-batch might be bad and cause the + // loss to suddenly jump up, making count_steps_without_decrease() + // return a large number. But if we discard the top 10% of the + // values in previous_loss_values then we are robust to that kind + // of noise. Another way of looking at it, if the reason + // count_steps_without_decrease() returns a large value is only + // because the most recent loss values have suddenly been large, + // then we shouldn't stop or lower the learning rate. We should + // keep going until whatever disturbance we hit is damped down. + steps_without_progress = count_steps_without_decrease_robust(previous_loss_values); + if (steps_without_progress >= iter_without_progress_thresh) + { + // optimization has flattened out, so drop the learning rate. + learning_rate = learning_rate_shrink*learning_rate; + steps_without_progress = 0; + // Empty out some of the previous loss values so that steps_without_progress + // will decrease below iter_without_progress_thresh. + for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt) + previous_loss_values.pop_front(); + } + } + } + else if (lr_schedule.size() != 0) // or use the learning rate schedule if we have one. + { + if (lr_schedule_pos < lr_schedule.size()) + learning_rate = lr_schedule(lr_schedule_pos++); + else + learning_rate = lr_schedule(lr_schedule.size()-1)*0.99; + } + } + } + catch(...) + { + // If an exception happens then permanently disable the trainer object. + job_pipe.disable(); + std::lock_guard lock(eptr_mutex); + eptr = std::current_exception(); + } + + void wait_for_thread_to_pause() const + { + job_pipe.wait_for_num_blocked_dequeues(1); + } + + const static long string_pad = 11; + const static long epoch_string_pad = 4; + const static long lr_string_pad = 4; + + void init() + { + max_num_epochs = 10000; + mini_batch_size = 128; + verbose = false; + learning_rate = 1e-2; + min_learning_rate = 1e-5; + iter_without_progress_thresh = 2000; + steps_without_progress = 0; + test_iter_without_progress_thresh = 500; + test_steps_without_progress = 0; + + learning_rate_shrink = 0.1; + epoch_iteration = 0; + epoch_pos = 0; + train_one_step_calls = 0; + test_one_step_calls = 0; + gradient_check_budget = 0; + lr_schedule_pos = 0; + + main_iteration_counter = 0; + main_iteration_counter_at_last_disk_sync = 0; + prob_loss_increasing_thresh_default_value = 0.99; + prob_loss_increasing_thresh_max_value = 0.99999; + prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value; + updated_net_since_last_sync = false; + sync_file_reloaded = false; + previous_loss_values_dump_amount = 400; + test_previous_loss_values_dump_amount = 100; + + rs_test = running_stats_decayed(200); + + start(); + } + + // serialize and deserialize are private because we hold net by reference so + // allowing someone to serialize this training object is weird and will likely + // result in user errors. However, we use these functions as part of the automatic + // sync code in this object. + friend void serialize(const dnn_trainer& item, std::ostream& out) + { + item.wait_for_thread_to_pause(); + int version = 12; + serialize(version, out); + + size_t nl = dnn_trainer::num_layers; + serialize(nl, out); + serialize(item.rs, out); + serialize(item.rs_test, out); + serialize(item.previous_loss_values, out); + serialize(item.max_num_epochs, out); + serialize(item.mini_batch_size, out); + serialize(item.verbose, out); + serialize(item.net, out); + serialize(item.devices[0]->solvers, out); + serialize(item.learning_rate.load(), out); + serialize(item.min_learning_rate, out); + serialize(item.iter_without_progress_thresh.load(), out); + serialize(item.steps_without_progress.load(), out); + serialize(item.learning_rate_shrink.load(), out); + serialize(item.epoch_iteration, out); + serialize(item.epoch_pos, out); + serialize(item.train_one_step_calls, out); + serialize(item.test_one_step_calls, out); + serialize(item.lr_schedule, out); + serialize(item.lr_schedule_pos, out); + serialize(item.test_iter_without_progress_thresh.load(), out); + serialize(item.test_steps_without_progress.load(), out); + serialize(item.test_previous_loss_values, out); + serialize(item.previous_loss_values_dump_amount, out); + serialize(item.test_previous_loss_values_dump_amount, out); + + } + friend void deserialize(dnn_trainer& item, std::istream& in) + { + item.wait_for_thread_to_pause(); + int version = 0; + deserialize(version, in); + if (version != 12) + throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer."); + + size_t num_layers = 0; + deserialize(num_layers, in); + if (num_layers != dnn_trainer::num_layers) + { + std::ostringstream sout; + sout << "Error deserializing dlib::dnn_trainer. The saved sync file is for a network with " << std::endl; + sout << "a different number of layers. We expected the number of layers to be " << dnn_trainer::num_layers << " but" << std::endl; + sout << "instead the file contains " << num_layers << " layers." << std::endl; + throw serialization_error(sout.str()); + } + + double dtemp; long ltemp; + deserialize(item.rs, in); + deserialize(item.rs_test, in); + deserialize(item.previous_loss_values, in); + deserialize(item.max_num_epochs, in); + deserialize(item.mini_batch_size, in); + deserialize(item.verbose, in); + deserialize(item.net, in); + deserialize(item.devices[0]->solvers, in); + deserialize(dtemp, in); item.learning_rate = dtemp; + deserialize(item.min_learning_rate, in); + deserialize(ltemp, in); item.iter_without_progress_thresh = ltemp; + deserialize(ltemp, in); item.steps_without_progress = ltemp; + deserialize(dtemp, in); item.learning_rate_shrink = dtemp; + deserialize(item.epoch_iteration, in); + deserialize(item.epoch_pos, in); + deserialize(item.train_one_step_calls, in); + deserialize(item.test_one_step_calls, in); + deserialize(item.lr_schedule, in); + deserialize(item.lr_schedule_pos, in); + deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp; + deserialize(ltemp, in); item.test_steps_without_progress = ltemp; + deserialize(item.test_previous_loss_values, in); + deserialize(item.previous_loss_values_dump_amount, in); + deserialize(item.test_previous_loss_values_dump_amount, in); + + if (item.devices.size() > 1) + { + const auto prev_dev = dlib::cuda::get_device(); + // initialize all the other device networks and solver objects + for (size_t i = 1; i < item.devices.size(); ++i) + { + // Switch to this device so that any tensor objects that get allocated when + // we copy this stuff happen on this device. + dlib::cuda::set_device(item.devices[i]->device_id); + item.devices[i]->solvers = item.devices[0]->solvers; + item.devices[i]->net = item.devices[0]->net; + } + dlib::cuda::set_device(prev_dev); + } + } + + void sync_to_disk ( + bool do_it_now = false + ) + { + // don't sync anything if we haven't updated the network since the last sync + if (!updated_net_since_last_sync) + return; + + // If the sync file isn't set then don't do anything. + if (sync_filename.size() == 0) + return; + + // Only sync if it has been long enough since the last sync or we are being + // explicitly forced to do it. + if (std::chrono::system_clock::now() - last_sync_time > time_between_syncs || + do_it_now) + { + wait_for_thread_to_pause(); + + // compact network before saving to disk. + this->net.clean(); + + // if the loss has actually been going up since the last time we saved our + // state to disk then something has probably gone wrong in the + // optimization. So in this case we do the opposite and recall the + // previously saved state in the hopes that the problem won't reoccur. + if (loss_increased_since_last_disk_sync()) + { + std::ifstream fin(newest_syncfile(), std::ios::binary); + deserialize(*this, fin); + sync_file_reloaded = true; + if (verbose) + std::cout << "Loss has been increasing, reloading saved state from " << newest_syncfile() << std::endl; + } + else + { + + const std::string filename = oldest_syncfile(); + serialize(filename) << *this; + + if (verbose) + std::cout << "Saved state to " << filename << std::endl; + } + + last_sync_time = std::chrono::system_clock::now(); + main_iteration_counter_at_last_disk_sync = main_iteration_counter; + updated_net_since_last_sync = false; + } + } + + std::string newest_syncfile ( + ) + { + return select_newest_file(sync_filename, sync_filename + "_"); + } + + std::string oldest_syncfile ( + ) + { + return select_oldest_file(sync_filename, sync_filename + "_"); + } + + bool loss_increased_since_last_disk_sync() + { + size_t gradient_updates_since_last_sync = main_iteration_counter - main_iteration_counter_at_last_disk_sync; + + // if we haven't synced anything to disk yet then return false. + if (!std::ifstream(newest_syncfile(), std::ios::binary)) + return false; + + for (auto x : previous_loss_values) + { + // If we get a NaN value of loss assume things have gone horribly wrong and + // we should reload the state of the trainer. + if (std::isnan(x)) + return true; + } + + // if we haven't seen much data yet then just say false. Or, alternatively, if + // it's been too long since the last sync then don't reload either. + if (gradient_updates_since_last_sync < 30 || previous_loss_values.size() < 2*gradient_updates_since_last_sync) + return false; + + // Now look at the data since a little before the last disk sync. We will + // check if the loss is getting bettor or worse. + running_gradient g; + for (size_t i = previous_loss_values.size() - 2*gradient_updates_since_last_sync; i < previous_loss_values.size(); ++i) + g.add(previous_loss_values[i]); + + // if the loss is very likely to be increasing then return true + const double prob = g.probability_gradient_greater_than(0); + if (prob > prob_loss_increasing_thresh && prob_loss_increasing_thresh <= prob_loss_increasing_thresh_max_value) + { + // Exponentially decay the threshold towards 1 so that if we keep finding + // the loss to be increasing over and over we will make the test + // progressively harder and harder until it fails, therefore ensuring we + // can't get stuck reloading from a previous state over and over. + prob_loss_increasing_thresh = 0.1*prob_loss_increasing_thresh + 0.9*1; + return true; + } + else + { + // decay back to the default threshold + prob_loss_increasing_thresh = std::pow(prob_loss_increasing_thresh, 10.0); + // but don't decay below the default value + prob_loss_increasing_thresh = std::max(prob_loss_increasing_thresh, prob_loss_increasing_thresh_default_value); + + return false; + } + } + + + struct clone_net{}; + + // per device state. All the containers have the same number of objects in them. + struct device_data + { + device_data( + int device_id_, + net_type& net_, + const solver_type& solver_ + ) : device_id(device_id_), net(net_), solvers(num_computational_layers, solver_) {} + + device_data( + int device_id_, + net_type& net_, + const solver_type& solver_, + clone_net + ) : device_id(device_id_), net_copy(std::make_shared(net_)), net(*net_copy), solvers(num_computational_layers, solver_) {} + + int device_id; + std::shared_ptr net_copy; + net_type& net; + std::vector solvers; + }; + + template < + typename data_iterator, + typename label_iterator + > + void send_job ( + bool test_only, + data_iterator dbegin, + data_iterator dend, + label_iterator lbegin + ) + { + propagate_exception(); + size_t num = std::distance(dbegin, dend); + size_t devs = devices.size(); + job.t.resize(devs); + job.labels.resize(devs); + job.have_data.resize(devs); + job.test_only = test_only; + + // chop the data into devs blocks, each of about block_size elements. + size_t block_size = (num+devs-1)/devs; + + const auto prev_dev = dlib::cuda::get_device(); + for (size_t i = 0; i < devs; ++i) + { + dlib::cuda::set_device(devices[i]->device_id); + + size_t start = i*block_size; + size_t stop = std::min(num, start+block_size); + + if (start < stop) + { + devices[i]->net.to_tensor(dbegin+start, dbegin+stop, job.t[i]); + job.labels[i].assign(lbegin+start, lbegin+stop); + job.have_data[i] = true; + } + else + { + job.have_data[i] = false; + } + } + + dlib::cuda::set_device(prev_dev); + job_pipe.enqueue(job); + } + + template < + typename data_iterator + > + void send_job ( + bool test_only, + data_iterator dbegin, + data_iterator dend + ) + { + typename std::vector::iterator nothing; + send_job(test_only, dbegin, dend, nothing); + } + + void print_progress() + { + if (lr_schedule.size() == 0) + { + if (test_previous_loss_values.size() == 0) + std::cout << "steps without apparent progress: " << steps_without_progress; + else + std::cout << "steps without apparent progress: train=" << steps_without_progress << ", test=" << test_steps_without_progress; + } + else + { + std::ostringstream sout; + sout << "percent complete: " << std::fixed << std::setprecision(2) << 100.0*lr_schedule_pos/(double)lr_schedule.size() << "%"; + std::cout << sout.str(); + } + std::cout << std::endl; + } + + void print_periodic_verbose_status() + { + if (verbose) + { + using namespace std::chrono; + auto now_time = system_clock::now(); + if (now_time-last_time > seconds(40)) + { + last_time = now_time; + std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " " + << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "; + if (test_previous_loss_values.size() == 0) + { + std::cout << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "; + } + else + { + std::cout << "train loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "; + std::cout << "test loss: " << rpad(cast_to_string(get_average_test_loss()),string_pad) << " "; + } + print_progress(); + clear_average_loss(); + } + } + } + + std::vector> devices; + dlib::pipe job_pipe; + job_t job; + + + running_stats rs; + running_stats_decayed rs_test; + std::deque previous_loss_values; + unsigned long max_num_epochs; + size_t mini_batch_size; + bool verbose; + net_type& net; + std::atomic learning_rate; + double min_learning_rate; + std::atomic iter_without_progress_thresh; + std::atomic steps_without_progress; + + std::atomic test_iter_without_progress_thresh; + std::atomic test_steps_without_progress; + std::deque test_previous_loss_values; + + std::atomic learning_rate_shrink; + std::chrono::time_point last_sync_time; + std::string sync_filename; + std::chrono::seconds time_between_syncs; + unsigned long epoch_iteration; + size_t epoch_pos; + std::chrono::time_point last_time; + unsigned long long train_one_step_calls; + unsigned long long test_one_step_calls; + matrix lr_schedule; + long lr_schedule_pos; + unsigned long gradient_check_budget; + + std::exception_ptr eptr = nullptr; + mutable std::mutex eptr_mutex; + void propagate_exception() const + { + std::lock_guard lock(eptr_mutex); + if (eptr) + std::rethrow_exception(eptr); + } + + // These 5 variables are not serialized + size_t main_iteration_counter; + size_t main_iteration_counter_at_last_disk_sync; + double prob_loss_increasing_thresh_default_value; + double prob_loss_increasing_thresh_max_value; + double prob_loss_increasing_thresh; + std::atomic updated_net_since_last_sync; + + bool sync_file_reloaded; + unsigned long previous_loss_values_dump_amount; + unsigned long test_previous_loss_values_dump_amount; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename net_type, + typename solver_type + > + std::ostream& operator<< ( + std::ostream& out, + dnn_trainer& trainer + ) + { + using std::endl; + out << "dnn_trainer details: \n"; + out << " net_type::num_layers: " << net_type::num_layers << endl; + // figure out how big the net is in MB. + std::ostringstream sout; + net_type temp = trainer.get_net(); // make a copy so that we can clean it without mutating the trainer's net. + temp.clean(); + serialize(temp, sout); + out << " net size: " << sout.str().size()/1024.0/1024.0 << "MB" << endl; + // Don't include the loss params in the hash since we print them on the next line. + // They also aren't really part of the "architecture" of the network. + out << " net architecture hash: " << md5(cast_to_string(trainer.get_net().subnet())) << endl; + out << " loss: " << trainer.get_net().loss_details() << endl; + + out << " synchronization file: " << trainer.get_synchronization_file() << endl; + out << " trainer.get_solvers()[0]: " << trainer.get_solvers()[0] << endl; + auto sched = trainer.get_learning_rate_schedule(); + if (sched.size() != 0) + { + out << " using explicit user-supplied learning rate schedule" << endl; + } + else + { + out << " learning rate: "<< trainer.get_learning_rate() << endl; + out << " learning rate shrink factor: "<< trainer.get_learning_rate_shrink_factor() << endl; + out << " min learning rate: "<< trainer.get_min_learning_rate() << endl; + out << " iterations without progress threshold: "<< trainer.get_iterations_without_progress_threshold() << endl; + out << " test iterations without progress threshold: "<< trainer.get_test_iterations_without_progress_threshold() << endl; + } + return out; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_TRAINER_H_ + diff --git a/ml/dlib/dlib/dnn/trainer_abstract.h b/ml/dlib/dlib/dnn/trainer_abstract.h new file mode 100644 index 000000000..3bfb6dc99 --- /dev/null +++ b/ml/dlib/dlib/dnn/trainer_abstract.h @@ -0,0 +1,765 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNn_TRAINER_ABSTRACT_H_ +#ifdef DLIB_DNn_TRAINER_ABSTRACT_H_ + +#include "core_abstract.h" +#include "solvers_abstract.h" +#include +#include + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + enum class force_flush_to_disk { + no = 0, + yes = 1 + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename net_type, + typename solver_type = sgd + > + class dnn_trainer + { + /*! + REQUIREMENTS ON net_type + - net_type is an add_loss_layer object. + + REQUIREMENTS ON solver_type + - solver_type is an implementation of the EXAMPLE_SOLVER interface defined + in solvers_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object is a tool training a deep neural network. To use it you supply + a neural network type and a solver, then you call train() with your + training data and it will output a new network instance that has hopefully + learned something useful from your training data. + + If you are compiling with CUDA then this object will use the GPU that is + currently selected (i.e. the one indicated by cudaGetDevice()) when + dnn_trainer is constructed. It will continue to use that device even if + you later change it by a call to cudaSetDevice(). + + EXCEPTIONS + If an exception is thrown by any part of the neural network during training + then the exception will be propagated out of the trainer to the user. + Moreover, the trainer instance will be unusable and should be destroyed. + !*/ + + public: + + typedef typename net_type::training_label_type training_label_type; + typedef typename net_type::input_type input_type; + const static size_t num_computational_layers = net_type::num_computational_layers; + + dnn_trainer() = delete; + dnn_trainer(const dnn_trainer&) = delete; + dnn_trainer& operator=(const dnn_trainer&) = delete; + + dnn_trainer( + net_type& net, + const solver_type& solver = solver_type(), + const std::vector& cuda_extra_devices = {} + ); + /*! + requires + - for all valid i: + - 0 <= cuda_extra_devices[i] < dlib::cuda::get_num_devices() + ensures + - &#get_net() == &net + (i.e. The dnn_trainer holds a reference to net, it does not copy it. + Therefore, you must ensure net has a lifetime at least as long as the + dnn_trainer). + - #get_solvers() == a set of solvers that are all initialized with the + provided solver instance. + - #get_max_num_epochs() == 10000 + - #get_mini_batch_size() == 128 + - #get_learning_rate() == 1e-2 + - #get_min_learning_rate() == 1e-5 + - #get_iterations_without_progress_threshold() == 2000 + - #get_test_iterations_without_progress_threshold() == 500 + - #get_learning_rate_shrink_factor() == 0.1 + - #get_learning_rate_schedule().size() == 0 + - #get_train_one_step_calls() == 0 + - #get_test_one_step_calls() == 0 + - #get_synchronization_file() == "" + - if (cuda_extra_devices.size() > 0) then + - This object will use multiple graphics cards to run the learning + algorithms. In particular, it will always use whatever device is + currently selected on the calling thread (the device indicated by + cudaGetDevice()). In addition, you can ask to use additional + devices, which you do by putting their device numbers into + cuda_extra_devices. + !*/ + + net_type& get_net ( + force_flush_to_disk force_flush = force_flush_to_disk::yes + ); + /*! + ensures + - returns the neural network object used by this trainer. This is the + network that is optimized when you call train() or train_one_step(). + Recall that the dnn_trainer doesn't contain the net_type object but + simply holds a reference to an external network which was provided to the + dnn_trainer's constructor. + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + - If force_flush is yes, then this function will sync the trainer state to + disk if the current state hasn't already been synced to disk since the + last network modification. + !*/ + + const std::vector& get_solvers ( + ) const; + /*! + ensures + - returns the solvers used to optimize each layer of the neural network + get_net(). In particular, the first layer's solver is + get_solvers()[0], the second layer's solver is + get_solvers()[1], and so on. + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + unsigned long get_mini_batch_size ( + ) const; + /*! + ensures + - During training, we call the network's update() routine over and over + with training data. The number of training samples we give to each call + to update is the "mini-batch size", which is defined by + get_mini_batch_size(). + !*/ + + void set_mini_batch_size ( + unsigned long batch_size + ); + /*! + requires + - batch_size > 0 + ensures + - #get_mini_batch_size() == batch_size + !*/ + + unsigned long get_max_num_epochs ( + ) const; + /*! + ensures + - train() will execute at most get_max_num_epochs() iterations over the + training data before returning. + !*/ + + void set_max_num_epochs ( + unsigned long num + ); + /*! + requires + - num > 0 + ensures + - #get_max_num_epochs() == num + !*/ + + void set_learning_rate ( + double lr + ); + /*! + requires + - lr > 0 + ensures + - #get_learning_rate() == lr + - #get_learning_rate_schedule().size() == 0 + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + double get_learning_rate( + ) const; + /*! + ensures + - During each training step, a solver tells us how to modify the parameters + of each layer in the network. It does this by outputting a step vector + that, when added to the parameters, will hopefully result in improved + network performance. The learning rate is one of the inputs to the + solver and influences the size of this step vector. This function + returns the current learning rate, that is, the learning rate that will + be used during the next training step. + !*/ + + void set_min_learning_rate ( + double lr + ); + /*! + requires + - lr > 0 + ensures + - #get_min_learning_rate() == lr + - #get_learning_rate_schedule().size() == 0 + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + double get_min_learning_rate ( + ) const; + /*! + ensures + - During training via this->train(), this object will test if progress is + still being made and if it isn't then it will reduce get_learning_rate() + by setting it to get_learning_rate()*get_learning_rate_shrink_factor(). + However, it will not reduce it below get_min_learning_rate(). Once this + minimum learning rate is crossed the training will terminate. + - get_min_learning_rate() doesn't apply if you are using train_one_step(). + You can keep calling train_one_step() as many times as you want and the + learning rate will drop infinitely close to 0 if you run long enough. + !*/ + + template + void set_learning_rate_schedule ( + const matrix_exp& schedule + ); + /*! + requires + - schedule.size() > 0 + - min(schedule) > 0 + ensures + - #get_learning_rate_schedule() == reshape_to_column_vector(schedule) + - #get_learning_rate() == schedule(0,0) + - #get_min_learning_rate() == min(schedule) + - #set_learning_rate_shrink_factor() == 1 + !*/ + + const matrix& get_learning_rate_schedule ( + ) const; + /*! + ensures + - if (this function returns a non-empty matrix) then + - This trainer will use an explicit learning rate schedule defined by + the learning rate values in get_learning_rate_schedule(). For + example, if get_learning_rate_schedule() returned {0.1, 0.09, 0.08, + 0.07, 0.06} then the first training mini-batch would use a learning + rate of 0.1, then the next training mini-batch uses 0.09, and then + 0.8, and so on until the end of the schedule is reached. + + If you continue to run training after the end of the schedule has + been reached then the learning rate will be fixed to 0.99 times the + final value. So in our example, eventually the learning rate would + be fixed to 0.99*0.06. This allows you to test if we have reached the + end of the schedule by checking if get_learning_rate() >= 0.06. + !*/ + + unsigned long get_steps_without_progress ( + ) const; + /*! + ensures + - if (get_learning_rate_shrink_factor() != 1) then + - returns an estimate of how many mini-batches have executed without us + observing a statistically significant decrease in the training error. + - else + - returns 0 + !*/ + + void set_iterations_without_progress_threshold ( + unsigned long thresh + ); + /*! + ensures + - #get_iterations_without_progress_threshold() == thresh + - #get_learning_rate_schedule().size() == 0 + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + unsigned long get_iterations_without_progress_threshold ( + ) const; + /*! + ensures + - This object monitors the progress of training and estimates if the + training error is being reduced. It does this by looking at the previous + get_iterations_without_progress_threshold() mini-batch results and + applying the statistical test defined by the running_gradient object to + see if the training error is getting smaller. If it isn't being reduced + then get_learning_rate() is made smaller by a factor of get_learning_rate_shrink_factor(). + + Therefore, get_iterations_without_progress_threshold() should always be + set to something sensibly large so that this test can be done with + reasonably high confidence. Think of this test as saying "if the loss + hasn't decreased for the previous get_iterations_without_progress_threshold() + then shrink the learning rate". + !*/ + + void set_learning_rate_shrink_factor ( + double shrink + ); + /*! + requires + - 0 < shrink && shrink <= 1 + ensures + - #get_learning_rate_shrink_factor() == shrink + - #get_learning_rate_schedule().size() == 0 + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + double get_learning_rate_shrink_factor ( + ) const; + /*! + ensures + - Whenever the training routine thinks it isn't making progress anymore it + will reduce get_learning_rate() by multiplying it by get_learning_rate_shrink_factor(). + - You can disable the automatic learning rate reduction by setting + get_learning_rate_shrink_factor() to 1. + !*/ + + unsigned long long get_train_one_step_calls ( + ) const; + /*! + ensures + - returns the number of times train_one_step() has been called. + !*/ + + unsigned long long get_test_one_step_calls ( + ) const; + /*! + ensures + - returns the number of times test_one_step() has been called. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - This object will not print anything to standard out + !*/ + + void set_synchronization_file ( + const std::string& filename, + std::chrono::seconds time_between_syncs = std::chrono::minutes(15) + ); + /*! + ensures + - #get_synchronization_file() == filename + - While training is running, either via train() or repeated calls to + train_one_step(), this object will save its entire state, including the + state of get_net(), to disk in the file named filename every + time_between_syncs seconds. + - If the filename file already exists then the state of this trainer will + be loaded from that file by this call to set_synchronization_file(). + This allows you to resume a training session which was previously + interrupted. + - It should be noted that when saving, the trainer will alternate between + saving to a file called filename and another file called filename+"_". + We do this because it's possible that your computer might crash (not + because of dlib, just in general) before the data is safely saved to + disk. This way, you will always have a backup file if the write to disk + gets corrupted or is incomplete. Moreover, when loading, we will always + load from the newest of the two possible files. + !*/ + + const std::string& get_synchronization_file ( + ); + /*! + ensures + - Returns the name of the file the dnn_trainer will periodically save it's + state to. If the return value is "" then synchronization is disabled. + !*/ + + void train ( + const std::vector& data, + const std::vector& labels + ); + /*! + requires + - data.size() == labels.size() + - data.size() > 0 + - net_type uses a supervised loss. + i.e. net_type::training_label_type != no_label_type. + ensures + - Trains a supervised neural network based on the given training data. + The goal of training is to find the network parameters that minimize + get_net().compute_loss(data.begin(), data.end(), labels.begin()). + - The optimizer will run until get_learning_rate() < get_min_learning_rate() + or get_max_num_epochs() training epochs have been executed. + - Each layer in the network will be optimized by its corresponding solver + in get_solvers(). + - Each call to train DOES NOT reinitialize the state of get_net() or + get_solvers(). That is, the existing state of the solvers and network is + the starting point for the optimization each time train() is called. In + particular, if you use the set_synchronization_file() method you can + resume an interrupted train() call by simply calling train() again and it + will pick up from the last synchronization point. + - You can obtain the average loss value during the final training epoch by + calling get_average_loss(). + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + void train ( + const std::vector& data + ); + /*! + requires + - data.size() > 0 + - net_type uses an unsupervised loss. + i.e. net_type::training_label_type == no_label_type. + ensures + - Trains an unsupervised neural network based on the given training data. + The goal of training is to find the network parameters that minimize + get_net().compute_loss(data.begin(), data.end()). + - The optimizer will run until get_learning_rate() < get_min_learning_rate() + or get_max_num_epochs() training epochs have been executed. + - Each layer in the network will be optimized by its corresponding solver + in get_solvers(). + - Each call to train DOES NOT reinitialize the state of get_net() or + get_solvers(). That is, the existing state of the solvers and network is + the starting point for the optimization each time train() is called. In + particular, if you use the set_synchronization_file() method you can + resume an interrupted train() call by simply calling train() again and it + will pick up from the last synchronization point. + - You can obtain the average loss value during the final training epoch by + calling get_average_loss(). + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + void train_one_step ( + const std::vector& data, + const std::vector& labels + ); + /*! + requires + - data.size() == labels.size() + - data.size() > 0 + - net_type uses a supervised loss. + i.e. net_type::training_label_type != no_label_type. + ensures + - Performs one stochastic gradient update step based on the mini-batch of + data and labels supplied to this function. In particular, calling + train_one_step() in a loop is equivalent to calling the train() method + defined above. However, train_one_step() allows you to stream data from + disk into the training process while train() requires you to first load + all the training data into RAM. Otherwise, these training methods are + equivalent. + - You can observe the current average loss value by calling get_average_loss(). + - The network training will happen in another thread. Therefore, after + calling this function you should call get_net() before you touch the net + object from the calling thread to ensure no other threads are still + accessing the network. + - #get_train_one_step_calls() == get_train_one_step_calls() + 1. + !*/ + + template < + typename data_iterator, + typename label_iterator + > + void train_one_step ( + data_iterator dbegin, + data_iterator dend, + label_iterator lbegin + ); + /*! + requires + - std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferencable + - std::distance(dbegin, dend) > 0 + - net_type uses a supervised loss. + i.e. net_type::training_label_type != no_label_type. + ensures + - Performs one stochastic gradient update step based on the mini-batch of + data and labels supplied to this function. In particular, calling + train_one_step() in a loop is equivalent to calling the train() method + defined above. However, train_one_step() allows you to stream data from + disk into the training process while train() requires you to first load + all the training data into RAM. Otherwise, these training methods are + equivalent. + - You can observe the current average loss value by calling get_average_loss(). + - The network training will happen in another thread. Therefore, after + calling this function you should call get_net() before you touch the net + object from the calling thread to ensure no other threads are still + accessing the network. + - #get_train_one_step_calls() == get_train_one_step_calls() + 1. + !*/ + + void train_one_step ( + const std::vector& data + ); + /*! + requires + - data.size() > 0 + - net_type uses an unsupervised loss. + i.e. net_type::training_label_type == no_label_type. + ensures + - Performs one stochastic gradient update step based on the mini-batch of + data supplied to this function. In particular, calling train_one_step() + in a loop is equivalent to calling the train() method defined above. + However, train_one_step() allows you to stream data from disk into the + training process while train() requires you to first load all the + training data into RAM. Otherwise, these training methods are + equivalent. + - You can observe the current average loss value by calling get_average_loss(). + - The network training will happen in another thread. Therefore, after + calling this function you should call get_net() before you touch the net + object from the calling thread to ensure no other threads are still + accessing the network. + - #get_train_one_step_calls() == get_train_one_step_calls() + 1. + !*/ + + template < + typename data_iterator + > + void train_one_step ( + data_iterator dbegin, + data_iterator dend + ); + /*! + requires + - std::distance(dbegin, dend) > 0 + - net_type uses an unsupervised loss. + i.e. net_type::training_label_type == no_label_type. + ensures + - Performs one stochastic gradient update step based on the mini-batch of + data supplied to this function. In particular, calling train_one_step() + in a loop is equivalent to calling the train() method defined above. + However, train_one_step() allows you to stream data from disk into the + training process while train() requires you to first load all the + training data into RAM. Otherwise, these training methods are + equivalent. + - You can observe the current average loss value by calling get_average_loss(). + - The network training will happen in another thread. Therefore, after + calling this function you should call get_net() before you touch the net + object from the calling thread to ensure no other threads are still + accessing the network. + - #get_train_one_step_calls() == get_train_one_step_calls() + 1. + !*/ + + double get_average_loss ( + ) const; + /*! + ensures + - returns the average loss value observed during previous calls to + train_one_step() or train(). That is, the average output of + net_type::update() during the previous mini-batch updates. + - Note that, if be_verbose() has been called, then this object will + automatically call clear_average_loss() periodically when it logs the + loss to the console. + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + void clear_average_loss ( + ); + /*! + ensures + - #get_average_loss() == 0 + - get_average_loss() uses a dlib::running_stats object to keep a running + average of the loss values seen during the previous mini-batch updates + applied during training. Calling clear_average_loss() resets the + running_stats object so it forgets about all previous loss values + observed. + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + // ---------------------- + + double get_average_test_loss ( + ) const; + /*! + ensures + - returns the average loss value observed during previous calls to + test_one_step(). + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + void test_one_step ( + const std::vector& data, + const std::vector& labels + ); + /*! + requires + - data.size() == labels.size() + - data.size() > 0 + - net_type uses a supervised loss. + i.e. net_type::training_label_type != no_label_type. + ensures + - Runs the given data through the network and computes and records the loss. + - This call does not modify network parameters. The point of + test_one_step() is two fold, to allow you to observe the accuracy of the + network on hold out data during training, and to allow the trainer to + automatically adjust the learning rate when the test loss stops + improving. It should be noted that you are not required to use + test_one_step() at all, but if you want to do this kind of thing it is + available. + - You can observe the current average loss value by calling get_average_test_loss(). + - The computation will happen in another thread. Therefore, after calling + this function you should call get_net() before you touch the net object + from the calling thread to ensure no other threads are still accessing + the network. + - #get_test_one_step_calls() == get_test_one_step_calls() + 1. + !*/ + + template < + typename data_iterator, + typename label_iterator + > + void test_one_step ( + data_iterator dbegin, + data_iterator dend, + label_iterator lbegin + ); + /*! + requires + - std::advance(lbegin, std::distance(dbegin, dend) - 1) is dereferencable + - std::distance(dbegin, dend) > 0 + - net_type uses a supervised loss. + i.e. net_type::training_label_type != no_label_type. + ensures + - Runs the given data through the network and computes and records the loss. + - This call does not modify network parameters. The point of + test_one_step() is two fold, to allow you to observe the accuracy of the + network on hold out data during training, and to allow the trainer to + automatically adjust the learning rate when the test loss stops + improving. It should be noted that you are not required to use + test_one_step() at all, but if you want to do this kind of thing it is + available. + - You can observe the current average loss value by calling get_average_test_loss(). + - The computation will happen in another thread. Therefore, after calling + this function you should call get_net() before you touch the net object + from the calling thread to ensure no other threads are still accessing + the network. + - #get_test_one_step_calls() == get_test_one_step_calls() + 1. + !*/ + + void test_one_step ( + const std::vector& data + ); + /*! + requires + - data.size() > 0 + - net_type uses an unsupervised loss. + i.e. net_type::training_label_type == no_label_type. + ensures + - Runs the given data through the network and computes and records the loss. + - This call does not modify network parameters. The point of + test_one_step() is two fold, to allow you to observe the accuracy of the + network on hold out data during training, and to allow the trainer to + automatically adjust the learning rate when the test loss stops + improving. It should be noted that you are not required to use + test_one_step() at all, but if you want to do this kind of thing it is + available. + - You can observe the current average loss value by calling get_average_test_loss(). + - The computation will happen in another thread. Therefore, after calling + this function you should call get_net() before you touch the net object + from the calling thread to ensure no other threads are still accessing + the network. + - #get_test_one_step_calls() == get_test_one_step_calls() + 1. + !*/ + + template < + typename data_iterator + > + void test_one_step ( + data_iterator dbegin, + data_iterator dend + ); + /*! + requires + - std::distance(dbegin, dend) > 0 + - net_type uses an unsupervised loss. + i.e. net_type::training_label_type == no_label_type. + ensures + - Runs the given data through the network and computes and records the loss. + - This call does not modify network parameters. The point of + test_one_step() is two fold, to allow you to observe the accuracy of the + network on hold out data during training, and to allow the trainer to + automatically adjust the learning rate when the test loss stops + improving. It should be noted that you are not required to use + test_one_step() at all, but if you want to do this kind of thing it is + available. + - You can observe the current average loss value by calling get_average_test_loss(). + - The computation will happen in another thread. Therefore, after calling + this function you should call get_net() before you touch the net object + from the calling thread to ensure no other threads are still accessing + the network. + - #get_test_one_step_calls() == get_test_one_step_calls() + 1. + !*/ + + void set_test_iterations_without_progress_threshold ( + unsigned long thresh + ); + /*! + ensures + - #get_test_iterations_without_progress_threshold() == thresh + - #get_learning_rate_schedule().size() == 0 + - This function blocks until all threads inside the dnn_trainer have + stopped touching the net. + !*/ + + unsigned long get_test_iterations_without_progress_threshold ( + ) const; + /*! + ensures + - This object monitors the progress of training and estimates if the + testing error is being reduced. It does this by looking at the previous + get_test_iterations_without_progress_threshold() mini-batch results from + test_one_step() and applying the statistical test defined by the + running_gradient object to see if the testing error is getting smaller. + If it isn't being reduced then get_learning_rate() is made smaller by a + factor of get_learning_rate_shrink_factor(). + + Therefore, get_test_iterations_without_progress_threshold() should always be + set to something sensibly large so that this test can be done with + reasonably high confidence. Think of this test as saying "if the testing loss + hasn't decreased for the previous get_test_iterations_without_progress_threshold() + calls to test_one_step() then shrink the learning rate". + !*/ + + unsigned long get_test_steps_without_progress ( + ) const; + /*! + ensures + - if (get_learning_rate_shrink_factor() != 1) then + - returns an estimate of how many mini-batches have executed without us + observing a statistically significant decrease in the testing error + (i.e. the error on the data given to the trainer via test_one_step() + calls). + - else + - returns 0 + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename net_type, + typename solver_type + > + std::ostream& operator<< ( + std::ostream& out, + dnn_trainer& trainer + ); + /*! + ensures + - Prints a log of the current parameters of trainer to out. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_TRAINER_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/dnn/utilities.h b/ml/dlib/dlib/dnn/utilities.h new file mode 100644 index 000000000..976128c81 --- /dev/null +++ b/ml/dlib/dlib/dnn/utilities.h @@ -0,0 +1,281 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_UTILITIES_H_ +#define DLIB_DNn_UTILITIES_H_ + +#include "core.h" +#include "utilities_abstract.h" +#include "../geometry.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline double log1pexp(double x) + { + using std::exp; + using namespace std; // Do this instead of using std::log1p because some compilers + // error out otherwise (E.g. gcc 4.9 in cygwin) + if (x <= -37) + return exp(x); + else if (-37 < x && x <= 18) + return log1p(exp(x)); + else if (18 < x && x <= 33.3) + return x + exp(-x); + else + return x; + } + +// ---------------------------------------------------------------------------------------- + + inline void randomize_parameters ( + tensor& params, + unsigned long num_inputs_and_outputs, + dlib::rand& rnd + ) + { + for (auto& val : params) + { + // Draw a random number to initialize the layer according to formula (16) + // from Understanding the difficulty of training deep feedforward neural + // networks by Xavier Glorot and Yoshua Bengio. + val = 2*rnd.get_random_float()-1; + val *= std::sqrt(6.0/(num_inputs_and_outputs)); + } + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + class visitor_net_to_xml + { + public: + + visitor_net_to_xml(std::ostream& out_) : out(out_) {} + + template + void operator()(size_t idx, const input_layer_type& l) + { + out << "\n"; + to_xml(l,out); + out << "\n"; + } + + template + void operator()(size_t idx, const add_loss_layer& l) + { + out << "\n"; + to_xml(l.loss_details(),out); + out << "\n"; + } + + template + void operator()(size_t idx, const add_layer& l) + { + out << "\n"; + to_xml(l.layer_details(),out); + out << "\n"; + } + + template + void operator()(size_t idx, const add_tag_layer& l) + { + out << "\n"; + } + + template class T, typename U> + void operator()(size_t idx, const add_skip_layer& l) + { + out << "\n"; + } + + private: + + std::ostream& out; + }; + } + + template + void net_to_xml ( + const net_type& net, + std::ostream& out + ) + { + auto old_precision = out.precision(9); + out << "\n"; + visit_layers(net, impl::visitor_net_to_xml(out)); + out << "\n"; + // restore the original stream precision. + out.precision(old_precision); + } + + template + void net_to_xml ( + const net_type& net, + const std::string& filename + ) + { + std::ofstream fout(filename); + net_to_xml(net, fout); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + class visitor_net_map_input_to_output + { + public: + + visitor_net_map_input_to_output(dpoint& p_) : p(p_) {} + + dpoint& p; + + template + void operator()(const input_layer_type& net) + { + } + + template + void operator()(const add_loss_layer& net) + { + (*this)(net.subnet()); + } + + template + void operator()(const add_layer& net) + { + (*this)(net.subnet()); + p = net.layer_details().map_input_to_output(p); + } + template + void operator()(const dimpl::subnet_wrapper,B>& net) + { + (*this)(net.subnet()); + p = net.layer_details().map_input_to_output(p); + } + + + template + void operator()(const add_tag_layer& net) + { + // tag layers are an identity transform, so do nothing + (*this)(net.subnet()); + } + template + void operator()(const dimpl::subnet_wrapper,is_first>& net) + { + // tag layers are an identity transform, so do nothing + (*this)(net.subnet()); + } + + + template class TAG_TYPE, typename U> + void operator()(const add_skip_layer& net) + { + (*this)(layer(net)); + } + template class TAG_TYPE, typename SUBNET> + void operator()(const dimpl::subnet_wrapper,is_first>& net) + { + // skip layers are an identity transform, so do nothing + (*this)(layer(net)); + } + + }; + + class visitor_net_map_output_to_input + { + public: + visitor_net_map_output_to_input(dpoint& p_) : p(p_) {} + + dpoint& p; + + template + void operator()(const input_layer_type& net) + { + } + + template + void operator()(const add_loss_layer& net) + { + (*this)(net.subnet()); + } + + template + void operator()(const add_layer& net) + { + p = net.layer_details().map_output_to_input(p); + (*this)(net.subnet()); + } + template + void operator()(const dimpl::subnet_wrapper,B>& net) + { + p = net.layer_details().map_output_to_input(p); + (*this)(net.subnet()); + } + + + template + void operator()(const add_tag_layer& net) + { + // tag layers are an identity transform, so do nothing + (*this)(net.subnet()); + } + template + void operator()(const dimpl::subnet_wrapper,is_first>& net) + { + // tag layers are an identity transform, so do nothing + (*this)(net.subnet()); + } + + + template class TAG_TYPE, typename U> + void operator()(const add_skip_layer& net) + { + (*this)(layer(net)); + } + template class TAG_TYPE, typename SUBNET> + void operator()(const dimpl::subnet_wrapper,is_first>& net) + { + // skip layers are an identity transform, so do nothing + (*this)(layer(net)); + } + + }; + } + + template + inline dpoint input_tensor_to_output_tensor( + const net_type& net, + dpoint p + ) + { + impl::visitor_net_map_input_to_output temp(p); + temp(net); + return p; + } + + template + inline dpoint output_tensor_to_input_tensor( + const net_type& net, + dpoint p + ) + { + impl::visitor_net_map_output_to_input temp(p); + temp(net); + return p; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_UTILITIES_H_ + + + diff --git a/ml/dlib/dlib/dnn/utilities_abstract.h b/ml/dlib/dlib/dnn/utilities_abstract.h new file mode 100644 index 000000000..2a9a3d3fc --- /dev/null +++ b/ml/dlib/dlib/dnn/utilities_abstract.h @@ -0,0 +1,127 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNn_UTILITIES_ABSTRACT_H_ +#ifdef DLIB_DNn_UTILITIES_ABSTRACT_H_ + +#include "core_abstract.h" +#include "../geometry/vector_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + double log1pexp( + double x + ); + /*! + ensures + - returns log(1+exp(x)) + (except computes it using a numerically accurate method) + !*/ + +// ---------------------------------------------------------------------------------------- + + void randomize_parameters ( + tensor& params, + unsigned long num_inputs_and_outputs, + dlib::rand& rnd + ); + /*! + ensures + - This function assigns random values into params based on the given random + number generator. In particular, it uses the parameter initialization method + of formula 16 from the paper "Understanding the difficulty of training deep + feedforward neural networks" by Xavier Glorot and Yoshua Bengio. + - It is assumed that the total number of inputs and outputs from the layer is + num_inputs_and_outputs. That is, you should set num_inputs_and_outputs to + the sum of the dimensionalities of the vectors going into and out of the + layer that uses params as its parameters. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void net_to_xml ( + const net_type& net, + std::ostream& out + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - All layers in the net must provide to_xml() functions. + ensures + - Prints the given neural network object as an XML document to the given output + stream. + !*/ + + template + void net_to_xml ( + const net_type& net, + const std::string& filename + ); + /*! + requires + - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or + add_tag_layer. + - All layers in the net must provide to_xml() functions. + ensures + - This function is just like the above net_to_xml(), except it writes to a file + rather than an ostream. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + dpoint input_tensor_to_output_tensor( + const net_type& net, + dpoint p + ); + /*! + requires + - net_type is an object of type add_layer, add_skip_layer, or add_tag_layer. + - All layers in the net must provide map_input_to_output() functions. + ensures + - Given a dpoint (i.e. a row,column coordinate) in the input tensor given to + net, this function returns the corresponding dpoint in the output tensor + net.get_output(). This kind of mapping is useful when working with fully + convolutional networks as you will often want to know what parts of the + output feature maps correspond to what parts of the input. + - If the network contains skip layers then any layers skipped over by the skip + layer are ignored for the purpose of computing this coordinate mapping. That + is, if you walk the network from the output layer to the input layer, where + each time you encounter a skip layer you jump to the layer indicated by the + skip layer, you will visit exactly the layers in the network involved in the + input_tensor_to_output_tensor() calculation. This behavior is useful since it + allows you to compute some auxiliary DNN as a separate branch of computation + that is separate from the main network's job of running some kind of fully + convolutional network over an image. For instance, you might want to have a + branch in your network that computes some global image level + summarization/feature. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + dpoint output_tensor_to_input_tensor( + const net_type& net, + dpoint p + ); + /*! + requires + - net_type is an object of type add_layer, add_skip_layer, or add_tag_layer. + - All layers in the net must provide map_output_to_input() functions. + ensures + - This function provides the reverse mapping of input_tensor_to_output_tensor(). + That is, given a dpoint in net.get_output(), what is the corresponding dpoint + in the input tensor? + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_UTILITIES_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/dnn/validation.h b/ml/dlib/dlib/dnn/validation.h new file mode 100644 index 000000000..c65cb4526 --- /dev/null +++ b/ml/dlib/dlib/dnn/validation.h @@ -0,0 +1,122 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNn_VALIDATION_H_ +#define DLIB_DNn_VALIDATION_H_ + +#include "../svm/cross_validate_object_detection_trainer_abstract.h" +#include "../svm/cross_validate_object_detection_trainer.h" +#include "layers.h" +#include + +namespace dlib +{ + namespace impl + { + inline std::set get_labels ( + const std::vector& rects1, + const std::vector& rects2 + ) + { + std::set labels; + for (auto& rr : rects1) + labels.insert(rr.label); + for (auto& rr : rects2) + labels.insert(rr.label); + return labels; + } + } + + template < + typename SUBNET, + typename image_array_type + > + const matrix test_object_detection_function ( + loss_mmod& detector, + const image_array_type& images, + const std::vector>& truth_dets, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0, + const test_box_overlap& overlaps_ignore_tester = test_box_overlap() + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( is_learning_problem(images,truth_dets) == true , + "\t matrix test_object_detection_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets) + << "\n\t images.size(): " << images.size() + ); + + + + double correct_hits = 0; + double total_true_targets = 0; + + std::vector > all_dets; + unsigned long missing_detections = 0; + + resizable_tensor temp; + + for (unsigned long i = 0; i < images.size(); ++i) + { + std::vector hits; + detector.to_tensor(&images[i], &images[i]+1, temp); + detector.subnet().forward(temp); + detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold); + + + for (auto& label : impl::get_labels(truth_dets[i], hits)) + { + std::vector truth_boxes; + std::vector ignore; + std::vector> boxes; + // copy hits and truth_dets into the above three objects + for (auto&& b : truth_dets[i]) + { + if (b.ignore) + { + ignore.push_back(b); + } + else if (b.label == label) + { + truth_boxes.push_back(full_object_detection(b.rect)); + ++total_true_targets; + } + } + for (auto&& b : hits) + { + if (b.label == label) + boxes.push_back(std::make_pair(b.detection_confidence, b.rect)); + } + + correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlaps_ignore_tester); + } + } + + std::sort(all_dets.rbegin(), all_dets.rend()); + + double precision, recall; + + double total_hits = all_dets.size(); + + if (total_hits == 0) + precision = 1; + else + precision = correct_hits / total_hits; + + if (total_true_targets == 0) + recall = 1; + else + recall = correct_hits / total_true_targets; + + matrix res; + res = precision, recall, average_precision(all_dets, missing_detections); + return res; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_VALIDATION_H_ + diff --git a/ml/dlib/dlib/enable_if.h b/ml/dlib/dlib/enable_if.h new file mode 100644 index 000000000..f081dea6d --- /dev/null +++ b/ml/dlib/dlib/enable_if.h @@ -0,0 +1,62 @@ +// Copyright 2003 (C) The Trustees of Indiana University. +// Use, modification, and distribution is subject to the Boost Software +// License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) +// Authors: Jaakko Jarvi (jajarvi at osl.iu.edu) +// Jeremiah Willcock (jewillco at osl.iu.edu) +// Andrew Lumsdaine (lums at osl.iu.edu) +#ifndef DLIB_BOOST_UTILITY_ENABLE_IF_HPP +#define DLIB_BOOST_UTILITY_ENABLE_IF_HPP + +namespace dlib +{ + + template + struct enable_if_c { + typedef T type; + }; + + template + struct enable_if_c {}; + + template + struct enable_if : public enable_if_c {}; + + template + struct lazy_enable_if_c { + typedef typename T::type type; + }; + + template + struct lazy_enable_if_c {}; + + template + struct lazy_enable_if : public lazy_enable_if_c {}; + + + template + struct disable_if_c { + typedef T type; + }; + + template + struct disable_if_c {}; + + template + struct disable_if : public disable_if_c {}; + + template + struct lazy_disable_if_c { + typedef typename T::type type; + }; + + template + struct lazy_disable_if_c {}; + + template + struct lazy_disable_if : public lazy_disable_if_c {}; + +} // namespace dlib + +#endif // DLIB_BOOST_UTILITY_ENABLE_IF_HPP + diff --git a/ml/dlib/dlib/entropy_decoder.h b/ml/dlib/dlib/entropy_decoder.h new file mode 100644 index 000000000..345dffa8c --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder.h @@ -0,0 +1,44 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODEr_ +#define DLIB_ENTROPY_DECODEr_ + +#include "entropy_decoder/entropy_decoder_kernel_1.h" +#include "entropy_decoder/entropy_decoder_kernel_2.h" +#include "entropy_decoder/entropy_decoder_kernel_c.h" + + + + +namespace dlib +{ + + + class entropy_decoder + { + entropy_decoder() {} + + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef entropy_decoder_kernel_1 + kernel_1a; + typedef entropy_decoder_kernel_c + kernel_1a_c; + + + // kernel_2a + typedef entropy_decoder_kernel_2 + kernel_2a; + typedef entropy_decoder_kernel_c + kernel_2a_c; + + + }; +} + +#endif // DLIB_ENTROPY_DECODEr_ + diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.cpp b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.cpp new file mode 100644 index 000000000..82c583634 --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.cpp @@ -0,0 +1,220 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_KERNEL_1_CPp_ +#define DLIB_ENTROPY_DECODER_KERNEL_1_CPp_ +#include "entropy_decoder_kernel_1.h" +#include +#include +#include + +#include "../assert.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + entropy_decoder_kernel_1:: + entropy_decoder_kernel_1( + ) : + initial_low(0x00000001), + initial_high(0xffffffff), + in(0), + low(initial_low), + high(initial_high), + buf(0), + buf_used(0), + target(0x00000000), + r(0) + { + } + +// ---------------------------------------------------------------------------------------- + + entropy_decoder_kernel_1:: + ~entropy_decoder_kernel_1 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + void entropy_decoder_kernel_1:: + clear( + ) + { + in = 0; + buf_used = 0; + buf = 0; + r = 0; + low = initial_low; + high = initial_high; + target = 0x00000000; + } + +// ---------------------------------------------------------------------------------------- + + void entropy_decoder_kernel_1:: + set_stream ( + std::istream& in_ + ) + { + buf_used = 0; + buf = 0; + r = 0; + low = initial_low; + high = initial_high; + target = 0x00000000; + + in = &in_; + streambuf = in_.rdbuf(); + + + + unsigned char ch; + + + streambuf->sgetn((char*)&ch,1); + target = ch; + + target <<= 8; + if (streambuf->sgetn((char*)&ch,1)) + target += ch; + + + target <<= 8; + if (streambuf->sgetn((char*)&ch,1)) + target += ch; + + + target <<= 8; + if (streambuf->sgetn((char*)&ch,1)) + target += ch; + + } + +// ---------------------------------------------------------------------------------------- + + bool entropy_decoder_kernel_1:: + stream_is_set ( + ) const + { + if (in != 0) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + std::istream& entropy_decoder_kernel_1:: + get_stream ( + ) const + { + return *in; + } + +// ---------------------------------------------------------------------------------------- + + void entropy_decoder_kernel_1:: + decode ( + uint32 low_count, + uint32 high_count + ) + { + // note that we must subtract 1 to preserve the convention that + // high == the real upper range - 1 + high = low + r*high_count - 1; + low = low + r*low_count; + r = 0; + + + + while (true) + { + + // if the highest order bit in high and low is the same + if ( low >= 0x80000000 || high < 0x80000000) + { + // make sure buf isn't empty + if (buf_used == 0) + { + buf_used = 8; + if (streambuf->sgetn(reinterpret_cast(&buf),1)==0) + { + // if there isn't anything else in the streambuffer then just + // make buf zero. + buf = 0; + } + } + + // we will be taking one bit from buf to replace the one we threw away + --buf_used; + + // roll off the bit in target + target <<= 1; + + // roll off the bit + high <<= 1; + low <<= 1; + high |= 1; // note that it is ok to add one to high here because + // of the convention that high == real upper range - 1. + // so that means that if we want to shift the upper range + // left by one then we must shift a one into high also + // since real upper range == high + 0.999999999... + + // make sure low is never zero + if (low == 0) + low = 1; + + // take a bit from buf to fill in the one we threw away + target += (buf>>buf_used)&0x01; + } + // if the distance between high and low is small and there aren't + // any bits we can roll off then round low up or high down. + else if (high-low < 0x10000) + { + if (high == 0x80000000) + high = 0x7fffffff; + else + low = 0x80000000; + } + else + { + break; + } + } // while (true) + + } + +// ---------------------------------------------------------------------------------------- + + bool entropy_decoder_kernel_1:: + get_target_called ( + ) const + { + return (r != 0); + } + +// ---------------------------------------------------------------------------------------- + + uint32 entropy_decoder_kernel_1:: + get_target ( + uint32 total + ) + { + // note that we must add one because of the convention that + // high == the real upper range minus 1 + r = (high-low+1)/total; + uint32 temp = (target-low)/r; + if (temp < total) + return temp; + else + return total-1; + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_ENTROPY_DECODER_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.h b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.h new file mode 100644 index 000000000..daf6d11ec --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_1.h @@ -0,0 +1,132 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_KERNEl_1_ +#define DLIB_ENTROPY_DECODER_KERNEl_1_ + +#include "../algs.h" +#include "entropy_decoder_kernel_abstract.h" +#include +#include "../uintn.h" + +namespace dlib +{ + + class entropy_decoder_kernel_1 + { + /*! + GENERAL NOTES + this decoder is implemented using arithmetic coding + + INITIAL VALUE + in == 0 + buf_used == 0 + buf == 0 + initial_low == 0x00000001 (slightly more than zero) + initial_high == 0xffffffff (slightly less than one, 0.99999999976717) + target == 0x00000000 (zero) + low == initial_low + high == initial_high + r == 0 + + CONVENTION + if (in != 0) + *in == get_stream() + true == stream_is_set() + streambuf == in->rdbuf() + else + false == stream_is_set() + + buf == used to hold fractional byte values which are fed to target. + buf_used == the number of low order bits in buf that are currently + in use + low == the low end of the range used for arithmetic encoding. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so it is + always in the range [0,1) + + low is also never allowed to be zero to avoid overflow + in the calculation (high-low+1)/total. + + high == the high end of the range - 1 used for arithmetic encoding. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so when we + interpret high as a real number then it is always in the + range [0,1) + + the range for arithmetic encoding is always + [low,high + 0.9999999...) the 0.9999999... is why + high == real upper range - 1 + + target == 32 bits of the fraction produced from an arithmetic encoder. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so it is + always in the range [0,1) + + r == the value (high-low+1)/total from the last call to + get_target() or 0 if get_target_called() should be false + + get_target_called() == (r != 0) + + !*/ + + public: + + entropy_decoder_kernel_1 ( + ); + + virtual ~entropy_decoder_kernel_1 ( + ); + + void clear( + ); + + void set_stream ( + std::istream& in + ); + + bool stream_is_set ( + ) const; + + std::istream& get_stream ( + ) const; + + void decode ( + uint32 low_count, + uint32 high_count + ); + + bool get_target_called ( + ) const; + + uint32 get_target ( + uint32 total + ); + + private: + + // restricted functions + entropy_decoder_kernel_1(entropy_decoder_kernel_1&); // copy constructor + entropy_decoder_kernel_1& operator=(entropy_decoder_kernel_1&); // assignment operator + + // data members + const uint32 initial_low; + const uint32 initial_high; + std::istream* in; + uint32 low; + uint32 high; + unsigned char buf; + uint32 buf_used; + uint32 target; + uint32 r; + std::streambuf* streambuf; + + }; + +} + +#ifdef NO_MAKEFILE +#include "entropy_decoder_kernel_1.cpp" +#endif + +#endif // DLIB_ENTROPY_DECODER_KERNEl_1_ + diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.cpp b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.cpp new file mode 100644 index 000000000..5b986273e --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.cpp @@ -0,0 +1,224 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_KERNEL_2_CPp_ +#define DLIB_ENTROPY_DECODER_KERNEL_2_CPp_ +#include "entropy_decoder_kernel_2.h" +#include +#include +#include + +#include "../assert.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + entropy_decoder_kernel_2:: + entropy_decoder_kernel_2( + ) : + initial_low(0x00000001), + initial_high(0xffffffff), + in(0), + low(initial_low), + high(initial_high), + target(0x00000000), + r(0) + { + } + +// ---------------------------------------------------------------------------------------- + + entropy_decoder_kernel_2:: + ~entropy_decoder_kernel_2 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + void entropy_decoder_kernel_2:: + clear( + ) + { + in = 0; + r = 0; + low = initial_low; + high = initial_high; + target = 0x00000000; + } + +// ---------------------------------------------------------------------------------------- + + void entropy_decoder_kernel_2:: + set_stream ( + std::istream& in_ + ) + { + r = 0; + low = initial_low; + high = initial_high; + target = 0x00000000; + + in = &in_; + streambuf = in_.rdbuf(); + + + + unsigned char ch; + + + streambuf->sgetn((char*)&ch,1); + target = ch; + + target <<= 8; + if (streambuf->sgetn((char*)&ch,1)) + target += ch; + + + target <<= 8; + if (streambuf->sgetn((char*)&ch,1)) + target += ch; + + + target <<= 8; + if (streambuf->sgetn((char*)&ch,1)) + target += ch; + } + +// ---------------------------------------------------------------------------------------- + + bool entropy_decoder_kernel_2:: + stream_is_set ( + ) const + { + if (in != 0) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + std::istream& entropy_decoder_kernel_2:: + get_stream ( + ) const + { + return *in; + } + +// ---------------------------------------------------------------------------------------- + + void entropy_decoder_kernel_2:: + decode ( + uint32 low_count, + uint32 high_count + ) + { + // note that we must subtract 1 to preserve the convention that + // high == the real upper range - 1 + high = low + r*high_count - 1; + low = low + r*low_count; + r = 0; + + + while (true ) + { + + // if high and low don't have the same 8 high order bits + if ((high&0xFF000000) != (low&0xFF000000)) + { + // if the distance between high and low is small and there aren't + // any bits we can roll off then force high and low to have common high + // order bits. + if ((high-low < 0x10000)) + { + if (high-low > 0x1000) + { + high>>=1; + low>>=1; + high = low = high+low; + high += 0xFF; + low -= 0xFF; + } + else /**/ + { + high>>=1; + low>>=1; + high = low = high+low; + } + } + else + { + // there are no bits to roll off and high and low are not + // too close so just quit the loop + break; + } + + } + // else if there are 8 bits we can roll off + else + { + unsigned char buf; + if (streambuf->sgetn(reinterpret_cast(&buf),1)==0) + { + // if there isn't anything else in the streambuffer then just + // make buf zero. + buf = 0; + } + + // also roll off the bits in target + target <<= 8; + + // roll off the bits + high <<= 8; + low <<= 8; + high |= 0xFF; // note that it is ok to add 0xFF to high here because + // of the convention that high == real upper range - 1. + // so that means that if we want to shift the upper range + // left by one then we must shift a one into high also + // since real upper range == high + 0.999999999... + + // make sure low is never zero + if (low == 0) + low = 1; + + + // put the new bits into target + target |= static_cast(buf); + } + + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + + bool entropy_decoder_kernel_2:: + get_target_called ( + ) const + { + return (r != 0); + } + +// ---------------------------------------------------------------------------------------- + + uint32 entropy_decoder_kernel_2:: + get_target ( + uint32 total + ) + { + // note that we must add one because of the convention that + // high == the real upper range minus 1 + r = (high-low+1)/total; + uint32 temp = (target-low)/r; + if (temp < total) + return temp; + else + return total-1; + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_ENTROPY_DECODER_KERNEL_2_CPp_ + diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.h b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.h new file mode 100644 index 000000000..7284cec3c --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_2.h @@ -0,0 +1,127 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_KERNEl_2_ +#define DLIB_ENTROPY_DECODER_KERNEl_2_ + +#include "../algs.h" +#include "entropy_decoder_kernel_abstract.h" +#include +#include "../uintn.h" + +namespace dlib +{ + + class entropy_decoder_kernel_2 + { + /*! + GENERAL NOTES + this decoder is implemented using "range" coding + + INITIAL VALUE + in == 0 + initial_low == 0x00000001 (slightly more than zero) + initial_high == 0xffffffff (slightly less than one, 0.99999999976717) + target == 0x00000000 (zero) + low == initial_low + high == initial_high + r == 0 + + CONVENTION + if (in != 0) + *in == get_stream() + true == stream_is_set() + streambuf == in->rdbuf() + else + false == stream_is_set() + + + low == the low end of the range used for arithmetic encoding. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so it is + always in the range [0,1) + + low is also never allowed to be zero to avoid overflow + in the calculation (high-low+1)/total. + + high == the high end of the range - 1 used for arithmetic encoding. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so when we + interpret high as a real number then it is always in the + range [0,1) + + the range for arithmetic encoding is always + [low,high + 0.9999999...) the 0.9999999... is why + high == real upper range - 1 + + target == 32 bits of the fraction produced from an arithmetic encoder. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so it is + always in the range [0,1) + + r == the value (high-low+1)/total from the last call to + get_target() or 0 if get_target_called() should be false + + get_target_called() == (r != 0) + + !*/ + + public: + + entropy_decoder_kernel_2 ( + ); + + virtual ~entropy_decoder_kernel_2 ( + ); + + void clear( + ); + + void set_stream ( + std::istream& in + ); + + bool stream_is_set ( + ) const; + + std::istream& get_stream ( + ) const; + + void decode ( + uint32 low_count, + uint32 high_count + ); + + bool get_target_called ( + ) const; + + uint32 get_target ( + uint32 total + ); + + private: + + // restricted functions + entropy_decoder_kernel_2(entropy_decoder_kernel_2&); // copy constructor + entropy_decoder_kernel_2& operator=(entropy_decoder_kernel_2&); // assignment operator + + // data members + const uint32 initial_low; + const uint32 initial_high; + std::istream* in; + uint32 low; + uint32 high; + uint32 target; + uint32 r; + std::streambuf* streambuf; + + }; + + +} + +#ifdef NO_MAKEFILE +#include "entropy_decoder_kernel_2.cpp" +#endif + +#endif // DLIB_ENTROPY_DECODER_KERNEl_2_ + diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_abstract.h b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_abstract.h new file mode 100644 index 000000000..89906b9ae --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_abstract.h @@ -0,0 +1,207 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ENTROPY_DECODER_KERNEl_ABSTRACT_ +#ifdef DLIB_ENTROPY_DECODER_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include +#include "../uintn.h" + +namespace dlib +{ + + class entropy_decoder + { + /*! + INITIAL VALUE + stream_is_set() == false + get_target_called() == false + + + WHAT THIS OBJECT REPRESENTS + This object represents an entropy decoder (could be implemented as an + arithmetic decoder for example). + + Note that all implementations of entropy_encoder and entropy_decoder + are paired. This means that if you use entropy_encoder_kernel_n to + encode something then you must use the corresponding + entropy_decoder_kernel_n to decode it. + + + WHERE IS EOF? + It is important to note that this object will not give any indication + that is has hit the end of the input stream when it occurs. It is + up to you to use some kind of coding scheme to detect this in the + compressed data stream. + + Another important thing to know is that decode() must be called + exactly the same number of times as encode() and with the same values + supplied for TOTAL, high_count, and low_count. Doing this ensures + that the decoder consumes exactly all the bytes from the input + stream that were written by the entropy_encoder. + + NOTATION: + At any moment each symbol has a certain probability of appearing in + the input stream. These probabilities may change as each symbol is + decoded and the probability model is updated accordingly. + + + - Before considering current symbol: + + let P(i) be a function which gives the probability of seeing the ith + symbol of an N symbol alphabet. Note that P(i) refers to the probability + of seeing the ith symbol WITHOUT considering the symbol currently given + by get_target(TOTAL). ( The domain of P(i) is from 0 to N-1. ) + + for each i: P(i) == COUNT/TOTAL where COUNT and TOTAL are integers + and TOTAL is the same number for all P(i) but COUNT may vary. + + let LOW_COUNT(i) be the sum of all P(x)*TOTAL from x == 0 to x == i-1 + (note that LOW_COUNT(0) == 0) + let HIGH_COUNT(i) be the sum of all P(x)*TOTAL from x == 0 to x == i + + + - After considering current symbol: + + let #P(i) be a function which gives the probability of seeing the ith + symbol after we have updated our probability model to take the symbol + given by get_target(TOTAL) into account. + + for each i: #P(i) == #COUNT/#TOTAL where #COUNT and #TOTAL are integers + and #TOTAL is the same number for all #P(i) but #COUNT may vary. + !*/ + + public: + + entropy_decoder ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~entropy_decoder ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + - if (stream_is_set()) + - clears any state accumulated in *this from decoding data from + the stream get_stream() + throws + - any exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void set_stream ( + std::istream& in + ); + /*! + ensures + - #*this will read data from in and decode it + - #stream_is_set() == true + - #get_target() == a number representing the first symbol from in + - #get_target_called() == false + - if (stream_is_set()) + - clears any state accumulated in *this from decoding data from + the stream get_stream() + throws + - any exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + bool stream_is_set ( + ) const; + /*! + ensures + - returns true if a stream has been associated with *this by calling + set_stream() + !*/ + + std::istream& get_stream ( + ) const; + /*! + requires + - stream_is_set() == true + ensures + - returns a reference to the istream object that *this is reading + encoded data from + !*/ + + + void decode ( + uint32 low_count, + uint32 high_count + ); + /*! + requires + - get_target_called() == true + - stream_is_set() == true + - low_count == LOW_COUNT(S) where S is the symbol represented + by get_target(TOTAL) + - high_count == HIGH_COUNT(S) where S is the symbol represented + by get_target(TOTAL) + - low_count <= get_target(TOTAL) < high_count <= TOTAL + ensures + - #get_target(#TOTAL) == a number which represents the next symbol + - #get_target_called() == false + throws + - any exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + bool get_target_called ( + ) const; + /*! + ensures + - returns true if get_target() has been called and since then decode() + and set_stream() have not been called + - returns false otherwise + !*/ + + uint32 get_target ( + uint32 total + ); + /*! + requires + - 0 < total < 65536 (2^16) + - total == TOTAL + - stream_is_set() == true + ensures + - in the next call to decode() the value of TOTAL will be + considered to be total + - #get_target_called() == true + - returns a number N such that: + - N is in the range 0 to total - 1 + - N represents a symbol S where + LOW_COUNT(S) <= N < HIGH_COUNT(S) + throws + - any exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + private: + + // restricted functions + entropy_decoder(entropy_decoder&); // copy constructor + entropy_decoder& operator=(entropy_decoder&); // assignment operator + + }; + +} + +#endif // DLIB_ENTROPY_DECODER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_c.h b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_c.h new file mode 100644 index 000000000..4b94791f3 --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder/entropy_decoder_kernel_c.h @@ -0,0 +1,123 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_KERNEl_C_ +#define DLIB_ENTROPY_DECODER_KERNEl_C_ + +#include "entropy_decoder_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename decoder + > + class entropy_decoder_kernel_c : public decoder + { + + public: + std::istream& get_stream ( + ) const; + + void decode ( + uint32 low_count, + uint32 high_count + ); + + uint32 get_target ( + uint32 total + ); + + private: + uint32 _get_target; + uint32 TOTAL; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename decoder + > + std::istream& entropy_decoder_kernel_c:: + get_stream ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tstd::istream& entropy_decoder::get_stream()" + << "\n\tyou must set a stream for this object before you can get it" + << "\n\tthis: " << this + ); + + // call the real function + return decoder::get_stream(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename decoder + > + void entropy_decoder_kernel_c:: + decode ( + uint32 low_count, + uint32 high_count + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (low_count <= _get_target) && (_get_target < high_count) && + (high_count <= TOTAL) && + (this->stream_is_set() == true) && (this->get_target_called() == true), + "\tvoid entropy_decoder::decode()" + << "\n\tRefer to the ensures clause for this function for further information." + << "\n\tNote that _get_target refers to get_target(TOTAL)" + << "\n\tthis: " << this + << "\n\tlow_count: " << low_count + << "\n\thigh_count: " << high_count + << "\n\tTOTAL: " << TOTAL + << "\n\tget_target(TOTAL): " << _get_target + << "\n\tis_stream_set(): " << (this->stream_is_set() ? "true" : "false" ) + << "\n\tget_target_called(): " << (this->get_target_called() ? "true" : "false" ) + ); + + // call the real function + decoder::decode(low_count,high_count); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename decoder + > + uint32 entropy_decoder_kernel_c:: + get_target ( + uint32 total + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (total > 0) && (total < 65536) && (this->stream_is_set() == true), + "\tvoid entropy_decoder::get_target()" + << "\n\tyou must set a stream for this object before you can get the " + << "\n\rnext target." + << "\n\tthis: " << this + << "\n\ttotal: " << total + ); + + // call the real function + _get_target = decoder::get_target(total); + TOTAL = total; + return _get_target; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_ENCODER_KERNEl_C_ + diff --git a/ml/dlib/dlib/entropy_decoder_model.h b/ml/dlib/dlib/entropy_decoder_model.h new file mode 100644 index 000000000..d898161aa --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder_model.h @@ -0,0 +1,108 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_MODEl_ +#define DLIB_ENTROPY_DECODER_MODEl_ + +#include "entropy_decoder_model/entropy_decoder_model_kernel_1.h" +#include "entropy_decoder_model/entropy_decoder_model_kernel_2.h" +#include "entropy_decoder_model/entropy_decoder_model_kernel_3.h" +#include "entropy_decoder_model/entropy_decoder_model_kernel_4.h" +#include "entropy_decoder_model/entropy_decoder_model_kernel_5.h" +#include "entropy_decoder_model/entropy_decoder_model_kernel_6.h" + +#include "conditioning_class.h" +#include "memory_manager.h" + +namespace dlib +{ + + + template < + unsigned long alphabet_size, + typename entropy_decoder + > + class entropy_decoder_model + { + entropy_decoder_model() {} + + typedef typename conditioning_class::kernel_1a cc1; + typedef typename conditioning_class::kernel_2a cc2; + typedef typename conditioning_class::kernel_3a cc3; + typedef typename conditioning_class::kernel_4a cc4a; + typedef typename conditioning_class::kernel_4b cc4b; + typedef typename conditioning_class::kernel_4c cc4c; + typedef typename conditioning_class::kernel_4d cc4d; + + public: + + //----------- kernels --------------- + + // kernel_1 + typedef entropy_decoder_model_kernel_1 + kernel_1a; + + typedef entropy_decoder_model_kernel_1 + kernel_1b; + + typedef entropy_decoder_model_kernel_1 + kernel_1c; + + // -------------------- + + // kernel_2 + typedef entropy_decoder_model_kernel_2 + kernel_2a; + + typedef entropy_decoder_model_kernel_2 + kernel_2b; + + typedef entropy_decoder_model_kernel_2 + kernel_2c; + + typedef entropy_decoder_model_kernel_2 + kernel_2d; + + // -------------------- + + // kernel_3 + typedef entropy_decoder_model_kernel_3 + kernel_3a; + + typedef entropy_decoder_model_kernel_3 + kernel_3b; + + typedef entropy_decoder_model_kernel_3 + kernel_3c; + + // -------------------- + + // kernel_4 + typedef entropy_decoder_model_kernel_4 + kernel_4a; + typedef entropy_decoder_model_kernel_4 + kernel_4b; + + + // -------------------- + + // kernel_5 + typedef entropy_decoder_model_kernel_5 + kernel_5a; + typedef entropy_decoder_model_kernel_5 + kernel_5b; + typedef entropy_decoder_model_kernel_5 + kernel_5c; + + + // -------------------- + + // kernel_6 + typedef entropy_decoder_model_kernel_6 + kernel_6a; + + + }; +} + +#endif // DLIB_ENTROPY_DECODER_MODEl_ + diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_1.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_1.h new file mode 100644 index 000000000..a0a94c948 --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_1.h @@ -0,0 +1,173 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_1_ +#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_1_ + +#include "../algs.h" +#include "entropy_decoder_model_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc + > + class entropy_decoder_model_kernel_1 + { + /*! + REQUIREMENTS ON cc + cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + cc::get_alphabet_size() == alphabet_size+1 + + INITIAL VALUE + Initially this object's finite context model is empty + + CONVENTION + &get_entropy_decoder() == coder + &order_0.get_global_state() == &gs + + This is an order-0 model. The last symbol in the order-0 context is + an escape into the order minus 1 context. + !*/ + + public: + + typedef entropy_decoder entropy_decoder_type; + + entropy_decoder_model_kernel_1 ( + entropy_decoder& coder + ); + + virtual ~entropy_decoder_model_kernel_1 ( + ); + + inline void clear( + ); + + inline void decode ( + unsigned long& symbol + ); + + entropy_decoder& get_entropy_decoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + entropy_decoder& coder; + typename cc::global_state_type gs; + cc order_0; + + // restricted functions + entropy_decoder_model_kernel_1(entropy_decoder_model_kernel_1&); // copy constructor + entropy_decoder_model_kernel_1& operator=(entropy_decoder_model_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc + > + entropy_decoder_model_kernel_1:: + entropy_decoder_model_kernel_1 ( + entropy_decoder& coder_ + ) : + coder(coder_), + order_0(gs) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 ); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc + > + entropy_decoder_model_kernel_1:: + ~entropy_decoder_model_kernel_1 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc + > + void entropy_decoder_model_kernel_1:: + clear( + ) + { + order_0.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc + > + void entropy_decoder_model_kernel_1:: + decode ( + unsigned long& symbol + ) + { + unsigned long current_symbol, low_count, high_count, target; + + // look in the order-0 context + target = coder.get_target(order_0.get_total()); + order_0.get_symbol(target,current_symbol,low_count,high_count); + + + // have coder decode the next symbol + coder.decode(low_count,high_count); + + // if current_symbol is not an escape from the order-0 context + if (current_symbol != alphabet_size) + { + // update the count for this symbol + order_0.increment_count(current_symbol,2); + + symbol = current_symbol; + return; + } + + // update the count for the escape symbol + order_0.increment_count(alphabet_size); + + + // go into the order minus one context + target = coder.get_target(alphabet_size); + coder.decode(target,target+1); + + + // update the count for this symbol in the order-0 context + order_0.increment_count(target,2); + + symbol = target; + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_1_ + diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_2.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_2.h new file mode 100644 index 000000000..6841db391 --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_2.h @@ -0,0 +1,245 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_2_ +#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_2_ + +#include "../algs.h" +#include "entropy_decoder_model_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename ccbig + > + class entropy_decoder_model_kernel_2 + { + /*! + REQUIREMENTS ON cc + cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + cc::get_alphabet_size() == alphabet_size+1 + this will be used for the order-0 context + + REQUIREMENTS ON ccbig + ccbig is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + ccbig::get_alphabet_size() == alphabet_size+1 + this will be used for the order-1 context + + INITIAL VALUE + Initially this object's finite context model is empty + previous_symbol == 0 + + CONVENTION + &get_entropy_decoder() == coder + &order_0.get_global_state() == &gs + &order_1[i]->get_global_state() == &gsbig + + + This is an order-1-0 model. The last symbol in the order-0 and order-1 + context is an escape into the lower context. + + previous_symbol == the last symbol seen + !*/ + + public: + + typedef entropy_decoder entropy_decoder_type; + + entropy_decoder_model_kernel_2 ( + entropy_decoder& coder + ); + + virtual ~entropy_decoder_model_kernel_2 ( + ); + + inline void clear( + ); + + inline void decode ( + unsigned long& symbol + ); + + entropy_decoder& get_entropy_decoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + entropy_decoder& coder; + typename cc::global_state_type gs; + typename ccbig::global_state_type gsbig; + cc order_0; + ccbig* order_1[alphabet_size]; + unsigned long previous_symbol; + + + // restricted functions + entropy_decoder_model_kernel_2(entropy_decoder_model_kernel_2&); // copy constructor + entropy_decoder_model_kernel_2& operator=(entropy_decoder_model_kernel_2&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename ccbig + > + entropy_decoder_model_kernel_2:: + entropy_decoder_model_kernel_2 ( + entropy_decoder& coder_ + ) : + coder(coder_), + order_0(gs), + previous_symbol(0) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535); + + unsigned long i; + try + { + for (i = 0; i < alphabet_size; ++i) + { + order_1[i] = new ccbig(gsbig); + } + } + catch (...) + { + for (unsigned long j = 0; j < i; ++j) + { + delete order_1[j]; + } + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename ccbig + > + entropy_decoder_model_kernel_2:: + ~entropy_decoder_model_kernel_2 ( + ) + { + for (unsigned long i = 0; i < alphabet_size; ++i) + { + delete order_1[i]; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename ccbig + > + void entropy_decoder_model_kernel_2:: + clear( + ) + { + previous_symbol = 0; + order_0.clear(); + for (unsigned long i = 0; i < alphabet_size; ++i) + { + order_1[i]->clear(); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename ccbig + > + void entropy_decoder_model_kernel_2:: + decode ( + unsigned long& symbol + ) + { + unsigned long current_symbol, low_count, high_count, target; + + // look in the order-1 context + target = coder.get_target(order_1[previous_symbol]->get_total()); + order_1[previous_symbol]->get_symbol(target,current_symbol,low_count,high_count); + + // have the coder decode the next symbol + coder.decode(low_count,high_count); + + // if the current_symbol is not an escape from the order-1 context + if (current_symbol != alphabet_size) + { + symbol = current_symbol; + order_1[previous_symbol]->increment_count(current_symbol,2); + previous_symbol = current_symbol; + return; + } + + // since this is an escape to order-0 we should increment + // the escape symbol + order_1[previous_symbol]->increment_count(alphabet_size); + + + + // look in the order-0 context + target = coder.get_target(order_0.get_total()); + order_0.get_symbol(target,current_symbol,low_count,high_count); + + // have coder decode the next symbol + coder.decode(low_count,high_count); + + // if current_symbol is not an escape from the order-0 context + if (current_symbol != alphabet_size) + { + // update the count for this symbol + order_1[previous_symbol]->increment_count(current_symbol,2); + order_0.increment_count(current_symbol,2); + + symbol = current_symbol; + previous_symbol = current_symbol; + return; + } + + // update the count for the escape symbol + order_0.increment_count(current_symbol); + + + // go into the order minus one context + target = coder.get_target(alphabet_size); + coder.decode(target,target+1); + + + // update the count for this symbol + order_1[previous_symbol]->increment_count(target,2); + order_0.increment_count(target,2); + + symbol = target; + previous_symbol = target; + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_2_ + diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_3.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_3.h new file mode 100644 index 000000000..c55c09e85 --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_3.h @@ -0,0 +1,335 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_3_ +#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_3_ + +#include "../algs.h" +#include "entropy_decoder_model_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename cc_high + > + class entropy_decoder_model_kernel_3 + { + /*! + REQUIREMENTS ON cc + cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + cc::get_alphabet_size() == alphabet_size+1 + + REQUIREMENTS ON cc_high + cc_high is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + cc_high::get_alphabet_size() == alphabet_size+1 + + INITIAL VALUE + - Initially this object's finite context model is empty + - previous_symbol == 0 + - previous_symbol2 == 0 + - order_1 == pointer to an array of alphabet_size elements + - order_2 == pointer to an array of alphabet_size*alphabet_size elements + - for all values of i: order_2[i] == 0 + + CONVENTION + &get_entropy_encoder() == coder + &order_0.get_global_state() == &gs + &order_1[i]->get_global_state() == &gs + + if (order_2[i] != 0) then + &order_2[i]->get_global_state() == &gs_high + + This is an order-2-1-0 model. The last symbol in the order-2, order-1 and + order-0 contexts is an escape into the lower context. + + previous_symbol == the last symbol seen + previous_symbol2 == the symbol we saw before previous_symbol + !*/ + + public: + + typedef entropy_decoder entropy_decoder_type; + + entropy_decoder_model_kernel_3 ( + entropy_decoder& coder + ); + + virtual ~entropy_decoder_model_kernel_3 ( + ); + + inline void clear( + ); + + inline void decode ( + unsigned long& symbol + ); + + entropy_decoder& get_entropy_decoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + entropy_decoder& coder; + typename cc::global_state_type gs; + typename cc_high::global_state_type gs_high; + cc order_0; + cc** order_1; + unsigned long previous_symbol; + cc_high** order_2; + unsigned long previous_symbol2; + + // restricted functions + entropy_decoder_model_kernel_3(entropy_decoder_model_kernel_3&); // copy constructor + entropy_decoder_model_kernel_3& operator=(entropy_decoder_model_kernel_3&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename cc_high + > + entropy_decoder_model_kernel_3:: + entropy_decoder_model_kernel_3 ( + entropy_decoder& coder_ + ) : + coder(coder_), + order_0(gs), + order_1(0), + previous_symbol(0), + order_2(0), + previous_symbol2(0) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535); + + try + { + order_1 = new cc*[alphabet_size]; + order_2 = new cc_high*[alphabet_size*alphabet_size]; + } + catch (...) + { + if (order_1) delete [] order_1; + if (order_2) delete [] order_2; + throw; + } + + + unsigned long i; + + for (i = 0; i < alphabet_size*alphabet_size; ++i) + { + order_2[i] = 0; + } + + try + { + for (i = 0; i < alphabet_size; ++i) + { + order_1[i] = new cc(gs); + } + } + catch (...) + { + for (unsigned long j = 0; j < i; ++j) + { + delete order_1[j]; + } + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename cc_high + > + entropy_decoder_model_kernel_3:: + ~entropy_decoder_model_kernel_3 ( + ) + { + for (unsigned long i = 0; i < alphabet_size; ++i) + { + delete order_1[i]; + } + + for (unsigned long i = 0; i < alphabet_size*alphabet_size; ++i) + { + if (order_2[i] != 0) + delete order_2[i]; + } + delete [] order_1; + delete [] order_2; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename cc_high + > + void entropy_decoder_model_kernel_3:: + clear( + ) + { + previous_symbol = 0; + previous_symbol2 = 0; + order_0.clear(); + for (unsigned long i = 0; i < alphabet_size; ++i) + { + order_1[i]->clear(); + } + + for (unsigned long i = 0; i < alphabet_size*alphabet_size; ++i) + { + if (order_2[i] != 0) + { + delete order_2[i]; + order_2[i] = 0; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + typename cc, + typename cc_high + > + void entropy_decoder_model_kernel_3:: + decode ( + unsigned long& symbol + ) + { + unsigned long current_symbol, low_count, high_count, target; + + + // look in the order-2 context + unsigned long temp = previous_symbol + (previous_symbol2 * alphabet_size); + if (order_2[temp] != 0) + { + target = coder.get_target(order_2[temp]->get_total()); + order_2[temp]->get_symbol(target,current_symbol,low_count,high_count); + + // have the coder decode the next symbol + coder.decode(low_count,high_count); + + // if the current_symbol is not an escape from the order-2 context + if (current_symbol != alphabet_size) + { + symbol = current_symbol; + order_2[temp]->increment_count(current_symbol,2); + previous_symbol2 = previous_symbol; + previous_symbol = current_symbol; + return; + } + + // since this is an escape to order-1 we should increment + // the escape symbol + order_2[temp]->increment_count(alphabet_size); + } + else + { + order_2[temp] = new cc_high(gs_high); + } + + + + + + + // look in the order-1 context + target = coder.get_target(order_1[previous_symbol]->get_total()); + order_1[previous_symbol]->get_symbol(target,current_symbol,low_count,high_count); + + // have the coder decode the next symbol + coder.decode(low_count,high_count); + + // if the current_symbol is not an escape from the order-1 context + if (current_symbol != alphabet_size) + { + symbol = current_symbol; + order_2[temp]->increment_count(current_symbol,2); + order_1[previous_symbol]->increment_count(current_symbol,2); + previous_symbol2 = previous_symbol; + previous_symbol = current_symbol; + return; + } + + // since this is an escape to order-0 we should increment + // the escape symbol + order_1[previous_symbol]->increment_count(alphabet_size); + + + + // look in the order-0 context + target = coder.get_target(order_0.get_total()); + order_0.get_symbol(target,current_symbol,low_count,high_count); + + // have coder decode the next symbol + coder.decode(low_count,high_count); + + // if current_symbol is not an escape from the order-0 context + if (current_symbol != alphabet_size) + { + // update the count for this symbol + order_2[temp]->increment_count(current_symbol,2); + order_1[previous_symbol]->increment_count(current_symbol,2); + order_0.increment_count(current_symbol,2); + + + symbol = current_symbol; + previous_symbol2 = previous_symbol; + previous_symbol = current_symbol; + return; + } + + // update the count for the escape symbol + order_0.increment_count(current_symbol); + + + // go into the order minus one context + target = coder.get_target(alphabet_size); + coder.decode(target,target+1); + + + // update the count for this symbol + order_2[temp]->increment_count(target,2); + order_1[previous_symbol]->increment_count(target,2); + order_0.increment_count(target,2); + + + symbol = target; + previous_symbol2 = previous_symbol; + previous_symbol = target; + + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_3_ + diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_4.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_4.h new file mode 100644 index 000000000..dcbfef6a0 --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_4.h @@ -0,0 +1,622 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_4_ +#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_4_ + +#include "../algs.h" +#include "entropy_decoder_model_kernel_abstract.h" +#include "../assert.h" + + +namespace dlib +{ + + namespace edmk4 + { + struct node + { + node* next; + node* child_context; + node* parent_context; + + unsigned short symbol; + unsigned short count; + unsigned short total; + unsigned short escapes; + }; + } + + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + class entropy_decoder_model_kernel_4 + { + /*! + REQUIREMENTS ON total_nodes + - 4096 < total_nodes + - this is the total number of nodes that we will use in the tree + + REQUIREMENTS ON order + - 0 <= order + - this is the maximum depth-1 the tree will be allowed to go (note + that the root level is depth 0). + + + GENERAL NOTES + This implementation follows more or less the implementation + strategy laid out by Alistair Moffat in his paper + Implementing the PPM data compression scheme. Published in IEEE + Transactions on Communications, 38(11):1917-1921, 1990. + + The escape method used will be method D. + + + INITIAL VALUE + - root == pointer to an array of total_nodes nodes + - next_node == 1 + - cur == root + - cur_order = 0 + - root->next == 0 + - root->parent_context == 0 + - root->child_context == 0 + - root->escapes == 0 + - root->total == 0 + - stack_size == 0 + + CONVENTION + - pop() == stack[stack_size-1] + - &get_entropy_decoder() == coder + - root == pointer to an array of total_nodes nodes. + this is also the root of the tree. + + - if (next_node < total_nodes) then + - next_node == the next node in root that has not yet been allocated + + + - root->next == 0 + - root->parent_context == 0 + + + - for every node in the tree: + { + - NOTATION: + - The "context" of a node is the string of symbols seen + when you go from the root of the tree down (down though + child context pointers) to the node, including the symbol at + the node itself. (note that the context of the root node + is "" or the empty string) + - A set of nodes is in the same "context set" if all the node's + contexts are of length n and all the node's contexts share + the same prefix of length n-1. + - The "child context set" of a node is a set of nodes with + contexts that are one symbol longer and prefixed by the node's + context. For example, if a node has a context "abc" then the + nodes for contexts "abca", "abcb", "abcc", etc. are all in + the child context set of the node. + - The "parent context" of a node is the context that is one + symbol shorter than the node's context and includes the + symbol in the node. So the parent context of a node with + context "abcd" would be the context "bcd". + + + - if (next != 0) then + - next == pointer to the next node in the same context set + - if (child_context != 0) then + - child_context == pointer to the first node of the child + context set for this node. + - if (parent_context != 0) then + - parent_context == pointer to the parent context of this node. + - else + - this node is the root node of the tree + + + - if (this is not the root node) then + - symbol == the symbol represented with this node + - count == the number of times this symbol has been seen in its + parent context. + - else + - the root doesn't have a symbol. i.e. the context for the + root node is "" or the empty string. + + - total == The sum of the counts of all the nodes + in the child context set + escapes. + - escapes == the escape count for the context represented + by the node. + } + + + - cur_order < order + - cur_order == the depth of the node cur in the tree. + (note that the root node has depth 0) + - cur == pointer to the node in the tree who's context matches + the most recent symbols we have seen. + + + !*/ + + typedef edmk4::node node; + + public: + + typedef entropy_decoder entropy_decoder_type; + + entropy_decoder_model_kernel_4 ( + entropy_decoder& coder + ); + + virtual ~entropy_decoder_model_kernel_4 ( + ); + + inline void clear( + ); + + inline void decode ( + unsigned long& symbol + ); + + entropy_decoder& get_entropy_decoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + + inline void push ( + edmk4::node* n + ); + /*! + requires + - stack_size <= order + ensures + - #pop() == n + !*/ + + inline edmk4::node* pop ( + ); + /*! + requires + - stack_size > 0 + ensures + - returns the node at the top of the stack + !*/ + + inline edmk4::node* allocate_node ( + ); + /*! + requires + - space_left() == true + ensures + - returns a pointer to a new node + !*/ + + inline void destroy_tree ( + ); + /*! + ensures + - deallocates all nodes except the root + - #root->child_context == 0 + - #root->escapes == 0 + - #root->total == 0 + - #cur == root + - #cur_order == 0 + - #stack_size == 0 + !*/ + + + inline bool space_left ( + ) const; + /*! + ensures + - returns true if there is at least 1 free node left. + - returns false otherwise + !*/ + + + inline void scale_counts ( + node* n + ); + /*! + ensures + - divides all the counts in the child context set of n by 2. + - none of the nodes in the child context set will have a count of 0 + !*/ + + + entropy_decoder& coder; + unsigned long next_node; + node* root; + node* cur; + unsigned long cur_order; + node* stack[order+1]; + unsigned long stack_size; + + // restricted functions + entropy_decoder_model_kernel_4(entropy_decoder_model_kernel_4&); // copy constructor + entropy_decoder_model_kernel_4& operator=(entropy_decoder_model_kernel_4&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + entropy_decoder_model_kernel_4:: + entropy_decoder_model_kernel_4 ( + entropy_decoder& coder_ + ) : + coder(coder_), + next_node(1), + cur_order(0), + stack_size(0) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535); + COMPILE_TIME_ASSERT( 4096 < total_nodes ); + + root = new node[total_nodes]; + cur = root; + + root->child_context = 0; + root->escapes = 0; + root->next = 0; + root->parent_context = 0; + root->total = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + entropy_decoder_model_kernel_4:: + ~entropy_decoder_model_kernel_4 ( + ) + { + delete [] root; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_4:: + clear( + ) + { + destroy_tree(); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_4:: + decode ( + unsigned long& symbol + ) + { + node* temp = cur; + cur = 0; + unsigned long low_count, high_count, total_count; + unsigned long target; + node* new_node = 0; + + // local_order will track the level of temp in the tree + unsigned long local_order = cur_order; + + + while (true) + { + high_count = 0; + if (space_left()) + { + total_count = temp->total; + + if (total_count > 0) + { + // check if we need to scale the counts + if (total_count > 10000) + { + scale_counts(temp); + total_count = temp->total; + } + + target = coder.get_target(total_count); + + // find either the symbol we are looking for or the + // end of the context set + node* n = temp->child_context; + node* last = 0; + while (true) + { + high_count += n->count; + + if (high_count > target || n->next == 0) + break; + last = n; + n = n->next; + } + + low_count = high_count - n->count; + + // if we found the symbol + if (high_count > target) + { + if (new_node != 0) + { + new_node->parent_context = n; + } + + symbol = n->symbol; + + coder.decode(low_count,high_count); + n->count += 8; + temp->total += 8; + + // move this node to the front + if (last) + { + last->next = n->next; + n->next = temp->child_context; + temp->child_context = n; + } + + + if (cur == 0) + { + if (local_order < order) + { + cur_order = local_order+1; + cur = n; + } + else + { + cur = n->parent_context; + cur_order = local_order; + } + } + + break; + + } + // if we hit the end of the context set without finding the symbol + else + { + if (new_node != 0) + { + new_node->parent_context = allocate_node(); + new_node = new_node->parent_context; + } + else + { + new_node = allocate_node(); + } + + n->next = new_node; + + // get the escape code + coder.decode(high_count,total_count); + } + + } + else // if (total_count == 0) + { + // this means that temp->child_context == 0 so we should make + // a new node here. + if (new_node != 0) + { + new_node->parent_context = allocate_node(); + new_node = new_node->parent_context; + } + else + { + new_node = allocate_node(); + } + + temp->child_context = new_node; + } + + if (cur == 0 && local_order < order) + { + cur = new_node; + cur_order = local_order+1; + } + + // fill out the new node + new_node->child_context = 0; + new_node->count = 4; + new_node->escapes = 0; + new_node->next = 0; + push(new_node); + new_node->total = 0; + temp->escapes += 4; + temp->total += 8; + + + if (temp != root) + { + temp = temp->parent_context; + --local_order; + continue; + } + + // since this is the root we are going to the order-(-1) context + // so we can just take care of that here. + target = coder.get_target(alphabet_size); + new_node->parent_context = root; + coder.decode(target,target+1); + symbol = target; + + if (cur == 0) + { + cur = root; + cur_order = 0; + } + break; + } + else + { + // there isn't enough space so we should rebuild the tree + destroy_tree(); + temp = cur; + local_order = cur_order; + cur = 0; + new_node = 0; + } + } // while (true) + + while (stack_size > 0) + { + pop()->symbol = static_cast(symbol); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + edmk4::node* entropy_decoder_model_kernel_4:: + allocate_node ( + ) + { + node* temp; + temp = root + next_node; + ++next_node; + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_4:: + destroy_tree ( + ) + { + next_node = 1; + root->child_context = 0; + root->escapes = 0; + root->total = 0; + cur = root; + cur_order = 0; + stack_size = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + bool entropy_decoder_model_kernel_4:: + space_left ( + ) const + { + return (next_node < total_nodes); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_4:: + push ( + edmk4::node* n + ) + { + stack[stack_size] = n; + ++stack_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + edmk4::node* entropy_decoder_model_kernel_4:: + pop ( + ) + { + --stack_size; + return stack[stack_size]; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_4:: + scale_counts ( + node* temp + ) + { + if (temp->escapes > 1) + temp->escapes >>= 1; + temp->total = temp->escapes; + + node* n = temp->child_context; + while (n != 0) + { + if (n->count > 1) + n->count >>= 1; + + temp->total += n->count; + n = n->next; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_4_ + diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_5.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_5.h new file mode 100644 index 000000000..9253e950b --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_5.h @@ -0,0 +1,793 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_5_ +#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_5_ + +#include "../algs.h" +#include "entropy_decoder_model_kernel_abstract.h" +#include "../assert.h" + + +namespace dlib +{ + + namespace edmk5 + { + struct node + { + node* next; + node* child_context; + node* parent_context; + + unsigned short symbol; + unsigned short count; + unsigned short total; + unsigned short escapes; + }; + } + + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + class entropy_decoder_model_kernel_5 + { + /*! + REQUIREMENTS ON total_nodes + - 4096 < total_nodes + - this is the total number of nodes that we will use in the tree + + REQUIREMENTS ON order + - 0 <= order + - this is the maximum depth-1 the tree will be allowed to go (note + that the root level is depth 0). + + + GENERAL NOTES + This implementation follows more or less the implementation + strategy laid out by Alistair Moffat in his paper + Implementing the PPM data compression scheme. Published in IEEE + Transactions on Communications, 38(11):1917-1921, 1990. + + The escape method used will be method D. + + This also uses Dmitry Shkarin's Information Inheritance scheme. + (described in "PPM: one step to practicality" and "Improving the + Efficiency of the PPM Algorithm") + + + INITIAL VALUE + - root == pointer to an array of total_nodes nodes + - next_node == 1 + - cur == root + - cur_order = 0 + - root->next == 0 + - root->parent_context == 0 + - root->child_context == 0 + - root->escapes == 0 + - root->total == 0 + - stack_size == 0 + - exc_used == false + - for all i: exc[i] == 0 + + CONVENTION + - exc_used == something_is_excluded() + - pop() == stack[stack_size-1].n and stack[stack_size-1].nc + - is_excluded(symbol) == bit symbol&0x1F from exc[symbol>>5] + - &get_entropy_decoder() == coder + - root == pointer to an array of total_nodes nodes. + this is also the root of the tree. + - if (next_node < total_nodes) then + - next_node == the next node in root that has not yet been allocated + + - root->next == 0 + - root->parent_context == 0 + + + - for every node in the tree: + { + - NOTATION: + - The "context" of a node is the string of symbols seen + when you go from the root of the tree down (down though + child context pointers) to the node, including the symbol at + the node itself. (note that the context of the root node + is "" or the empty string) + - A set of nodes is in the same "context set" if all the node's + contexts are of length n and all the node's contexts share + the same prefix of length n-1. + - The "child context set" of a node is a set of nodes with + contexts that are one symbol longer and prefixed by the node's + context. For example, if a node has a context "abc" then the + nodes for contexts "abca", "abcb", "abcc", etc. are all in + the child context set of the node. + - The "parent context" of a node is the context that is one + symbol shorter than the node's context and includes the + symbol in the node. So the parent context of a node with + context "abcd" would be the context "bcd". + + + - if (next != 0) then + - next == pointer to the next node in the same context set + - if (child_context != 0) then + - child_context == pointer to the first node of the child + context set for this node. + - escapes > 0 + - if (parent_context != 0) then + - parent_context == pointer to the parent context of this node. + - else + - this node is the root node of the tree + + + - if (this is not the root node) then + - symbol == the symbol represented with this node + - count == the number of times this symbol has been seen in its + parent context. + - else + - the root doesn't have a symbol. i.e. the context for the + root node is "" or the empty string. + + - total == The sum of the counts of all the nodes + in the child context set + escapes. + - escapes == the escape count for the context represented + by the node. + - count > 0 + } + + + - cur_order < order + - cur_order == the depth of the node cur in the tree. + (note that the root node has depth 0) + - cur == pointer to the node in the tree who's context matches + the most recent symbols we have seen. + + + !*/ + + typedef edmk5::node node; + + public: + + typedef entropy_decoder entropy_decoder_type; + + entropy_decoder_model_kernel_5 ( + entropy_decoder& coder + ); + + virtual ~entropy_decoder_model_kernel_5 ( + ); + + inline void clear( + ); + + inline void decode ( + unsigned long& symbol + ); + + entropy_decoder& get_entropy_decoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + + inline void push ( + node* n, + node* nc + ); + /*! + requires + - stack_size < order + ensures + - #pop(a,b): a == n && b == nc + !*/ + + inline void pop ( + node*& n, + node*& nc + ); + /*! + requires + - stack_size > 0 + ensures + - returns the two nodes at the top of the stack + !*/ + + inline edmk5::node* allocate_node ( + ); + /*! + requires + - space_left() == true + ensures + - returns a pointer to a new node + !*/ + + inline bool space_left ( + ) const; + /*! + ensures + - returns true if there is at least 1 free node left. + - returns false otherwise + !*/ + + inline void exclude ( + unsigned short symbol + ); + /*! + ensures + - #is_excluded(symbol) == true + - #something_is_excluded() == true + !*/ + + inline bool is_excluded ( + unsigned short symbol + ); + /*! + ensures + - if (symbol has been excluded) then + - returns true + - else + - returns false + !*/ + + inline bool something_is_excluded ( + ); + /*! + ensures + - returns true if some symbol has been excluded. + returns false otherwise + !*/ + + inline void clear_exclusions ( + ); + /*! + ensures + - for all symbols #is_excluded(symbol) == false + - #something_is_excluded() == false + !*/ + + inline void scale_counts ( + node* n + ); + /*! + ensures + - divides all the counts in the child context set of n by 2. + - none of the nodes in the child context set will have a count of 0 + !*/ + + struct nodes + { + node* n; + node* nc; + }; + + entropy_decoder& coder; + unsigned long next_node; + node* root; + node* cur; + unsigned long cur_order; + unsigned long exc[alphabet_size/32+1]; + nodes stack[order+1]; + unsigned long stack_size; + bool exc_used; + + // restricted functions + entropy_decoder_model_kernel_5(entropy_decoder_model_kernel_5&); // copy constructor + entropy_decoder_model_kernel_5& operator=(entropy_decoder_model_kernel_5&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + entropy_decoder_model_kernel_5:: + entropy_decoder_model_kernel_5 ( + entropy_decoder& coder_ + ) : + coder(coder_), + next_node(1), + cur_order(0), + stack_size(0) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535); + COMPILE_TIME_ASSERT( 4096 < total_nodes ); + + root = new node[total_nodes]; + cur = root; + + root->child_context = 0; + root->escapes = 0; + root->next = 0; + root->parent_context = 0; + root->total = 0; + + clear_exclusions(); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + entropy_decoder_model_kernel_5:: + ~entropy_decoder_model_kernel_5 ( + ) + { + delete [] root; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_5:: + clear( + ) + { + next_node = 1; + root->child_context = 0; + root->escapes = 0; + root->total = 0; + cur = root; + cur_order = 0; + stack_size = 0; + + clear_exclusions(); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_5:: + decode ( + unsigned long& symbol + ) + { + node* temp = cur; + cur = 0; + unsigned long low_count, high_count, total_count; + unsigned long target; + node* new_node = 0; + + // local_order will track the level of temp in the tree + unsigned long local_order = cur_order; + + + unsigned short c; // c == t(a|sk) + unsigned short t; // t == T(sk) + + + if (something_is_excluded()) + clear_exclusions(); + + while (true) + { + high_count = 0; + if (space_left()) + { + total_count = temp->total; + + if (total_count > 0) + { + // check if we need to scale the counts + if (total_count > 10000) + { + scale_counts(temp); + total_count = temp->total; + } + + if (something_is_excluded()) + { + node* n = temp->child_context; + total_count = temp->escapes; + while (true) + { + if (is_excluded(n->symbol) == false) + { + total_count += n->count; + } + if (n->next == 0) + break; + n = n->next; + } + } + + + + target = coder.get_target(total_count); + + // find either the symbol we are looking for or the + // end of the context set + node* n = temp->child_context; + node* last = 0; + while (true) + { + if (is_excluded(n->symbol) == false) + { + high_count += n->count; + exclude(n->symbol); + } + + + if (high_count > target || n->next == 0) + break; + last = n; + n = n->next; + } + + + // if we found the symbol + if (high_count > target) + { + low_count = high_count - n->count; + + if (new_node != 0) + { + new_node->parent_context = n; + } + + symbol = n->symbol; + + coder.decode(low_count,high_count); + c = n->count += 8; + t = temp->total += 8; + + + // move this node to the front + if (last) + { + last->next = n->next; + n->next = temp->child_context; + temp->child_context = n; + } + + if (cur == 0) + { + if (local_order < order) + { + cur_order = local_order+1; + cur = n; + } + else + { + cur = n->parent_context; + cur_order = local_order; + } + } + + break; + + + } + // if we hit the end of the context set without finding the symbol + else + { + if (new_node != 0) + { + new_node->parent_context = allocate_node(); + new_node = new_node->parent_context; + } + else + { + new_node = allocate_node(); + } + + n->next = new_node; + + // get the escape code + coder.decode(high_count,total_count); + } + + } + else // if (total_count == 0) + { + // this means that temp->child_context == 0 so we should make + // a new node here. + if (new_node != 0) + { + new_node->parent_context = allocate_node(); + new_node = new_node->parent_context; + } + else + { + new_node = allocate_node(); + } + + temp->child_context = new_node; + } + + if (cur == 0 && local_order < order) + { + cur = new_node; + cur_order = local_order+1; + } + + // fill out the new node + new_node->child_context = 0; + new_node->escapes = 0; + new_node->next = 0; + push(new_node,temp); + new_node->total = 0; + + + + if (temp != root) + { + temp = temp->parent_context; + --local_order; + continue; + } + + t = 2056; + c = 8; + + // since this is the root we are going to the order-(-1) context + // so we can just take care of that here. + target = coder.get_target(alphabet_size); + new_node->parent_context = root; + coder.decode(target,target+1); + symbol = target; + + if (cur == 0) + { + cur = root; + cur_order = 0; + } + break; + } + else + { + // there isn't enough space so we should rebuild the tree + clear(); + temp = cur; + local_order = cur_order; + cur = 0; + new_node = 0; + } + } // while (true) + + // initialize the counts and symbol for any new nodes we have added + // to the tree. + node* n, *nc; + while (stack_size > 0) + { + pop(n,nc); + + n->symbol = static_cast(symbol); + + // if nc is not a determnistic context + if (nc->total) + { + unsigned long temp2 = t-c+nc->total - nc->escapes - nc->escapes; + unsigned long temp = nc->total; + temp *= c; + temp /= (temp2|1); // this oring by 1 is just to make sure that temp2 is never zero + temp += 2; + if (temp > 50000) temp = 50000; + n->count = static_cast(temp); + + + nc->escapes += 4; + nc->total += static_cast(temp) + 4; + } + else + { + n->count = 3 + 5*(c)/(t-c); + + nc->escapes = 4; + nc->total = n->count + 4; + } + + while (nc->total > 10000) + { + scale_counts(nc); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + edmk5::node* entropy_decoder_model_kernel_5:: + allocate_node ( + ) + { + node* temp; + temp = root + next_node; + ++next_node; + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + bool entropy_decoder_model_kernel_5:: + space_left ( + ) const + { + return (next_node < total_nodes); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_5:: + exclude ( + unsigned short symbol + ) + { + exc_used = true; + unsigned long temp = 1; + temp <<= symbol&0x1F; + exc[symbol>>5] |= temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + bool entropy_decoder_model_kernel_5:: + is_excluded ( + unsigned short symbol + ) + { + unsigned long temp = 1; + temp <<= symbol&0x1F; + return ((exc[symbol>>5]&temp) != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_5:: + clear_exclusions ( + ) + { + exc_used = false; + for (unsigned long i = 0; i < alphabet_size/32+1; ++i) + { + exc[i] = 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_5:: + push ( + node* n, + node* nc + ) + { + stack[stack_size].n = n; + stack[stack_size].nc = nc; + ++stack_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_5:: + pop ( + node*& n, + node*& nc + ) + { + --stack_size; + n = stack[stack_size].n; + nc = stack[stack_size].nc; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + bool entropy_decoder_model_kernel_5:: + something_is_excluded ( + ) + { + return exc_used; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_decoder_model_kernel_5:: + scale_counts ( + node* temp + ) + { + if (temp->escapes > 1) + temp->escapes >>= 1; + temp->total = temp->escapes; + + node* n = temp->child_context; + while (n != 0) + { + if (n->count > 1) + n->count >>= 1; + + temp->total += n->count; + n = n->next; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_5_ + + diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_6.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_6.h new file mode 100644 index 000000000..dc23f10eb --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_6.h @@ -0,0 +1,131 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_DECODER_MODEL_KERNEl_6_ +#define DLIB_ENTROPY_DECODER_MODEL_KERNEl_6_ + +#include "../algs.h" +#include "entropy_decoder_model_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_decoder + > + class entropy_decoder_model_kernel_6 + { + /*! + INITIAL VALUE + This object has no state + + CONVENTION + &get_entropy_decoder() == coder + + This is an order-(-1) model. So it doesn't really do anything. + Every symbol has the same probability. + !*/ + + public: + + typedef entropy_decoder entropy_decoder_type; + + entropy_decoder_model_kernel_6 ( + entropy_decoder& coder + ); + + virtual ~entropy_decoder_model_kernel_6 ( + ); + + inline void clear( + ); + + inline void decode ( + unsigned long& symbol + ); + + entropy_decoder& get_entropy_decoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + entropy_decoder& coder; + + // restricted functions + entropy_decoder_model_kernel_6(entropy_decoder_model_kernel_6&); // copy constructor + entropy_decoder_model_kernel_6& operator=(entropy_decoder_model_kernel_6&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder + > + entropy_decoder_model_kernel_6:: + entropy_decoder_model_kernel_6 ( + entropy_decoder& coder_ + ) : + coder(coder_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 ); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder + > + entropy_decoder_model_kernel_6:: + ~entropy_decoder_model_kernel_6 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder + > + void entropy_decoder_model_kernel_6:: + clear( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_decoder + > + void entropy_decoder_model_kernel_6:: + decode ( + unsigned long& symbol + ) + { + unsigned long target; + + target = coder.get_target(alphabet_size); + coder.decode(target,target+1); + + symbol = target; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_6_ + diff --git a/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_abstract.h b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_abstract.h new file mode 100644 index 000000000..5b2deabd7 --- /dev/null +++ b/ml/dlib/dlib/entropy_decoder_model/entropy_decoder_model_kernel_abstract.h @@ -0,0 +1,116 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ENTROPY_DECODER_MODEL_KERNEl_ABSTRACT_ +#ifdef DLIB_ENTROPY_DECODER_MODEL_KERNEl_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_decoder + > + class entropy_decoder_model + { + /*! + REQUIREMENTS ON alphabet_size + 1 < alphabet_size < 65535 + + REQUIREMENTS ON entropy_decoder + is an implementation of entropy_decoder/entropy_decoder_kernel_abstract.h + + INITIAL VALUE + Initially this object is at some predefined empty or ground state. + + WHAT THIS OBJECT REPRESENTS + This object represents some kind of statistical model. You + can use it to read symbols from an entropy_decoder and it will calculate + the cumulative counts/probabilities and manage contexts for you. + + Note that all implementations of entropy_encoder_model and + entropy_decoder_model are paired. This means that if you use + entropy_encoder_model_kernel_n to encode something then you must + use the corresponding entropy_decoder_model_kernel_n to decode it. + + Also note that this object does not perform any buffering of symbols. It + reads them from its associated entropy_decoder simply as it needs them. + This makes it safe to use multiple entropy_decoder_model objects with + a single entropy_decoder without them trampling each other. + !*/ + + public: + + typedef entropy_decoder entropy_decoder_type; + + entropy_decoder_model ( + entropy_decoder& coder + ); + /*! + ensures + - #*this is properly initialized + - &#get_entropy_decoder() == &coder + throws + - any exception + !*/ + + virtual ~entropy_decoder_model ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + - does not modify get_entropy_decoder() + throws + - any exception + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void decode ( + unsigned long& symbol + ); + /*! + ensures + - decodes the next symbol + - #symbol == the next symbol + - #symbol < alphabet_size + throws + - any exception + If this exception is thrown then #*this is unusable until + clear() is called and succeeds. + !*/ + + entropy_decoder& get_entropy_decoder ( + ); + /*! + ensures + - returns a reference to the entropy_decoder used by *this + !*/ + + static unsigned long get_alphabet_size ( + ); + /*! + ensures + - returns alphabet_size + !*/ + + private: + + // restricted functions + entropy_decoder_model(entropy_decoder_model&); // copy constructor + entropy_decoder_model& operator=(entropy_decoder_model&); // assignment operator + + }; + +} + +#endif // DLIB_ENTROPY_DECODER_MODEL_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/entropy_encoder.h b/ml/dlib/dlib/entropy_encoder.h new file mode 100644 index 000000000..5afda1976 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder.h @@ -0,0 +1,43 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODEr_ +#define DLIB_ENTROPY_ENCODEr_ + +#include "entropy_encoder/entropy_encoder_kernel_1.h" +#include "entropy_encoder/entropy_encoder_kernel_2.h" +#include "entropy_encoder/entropy_encoder_kernel_c.h" + + + + +namespace dlib +{ + + + class entropy_encoder + { + entropy_encoder() {} + + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef entropy_encoder_kernel_1 + kernel_1a; + typedef entropy_encoder_kernel_c + kernel_1a_c; + + + // kernel_2a + typedef entropy_encoder_kernel_2 + kernel_2a; + typedef entropy_encoder_kernel_c + kernel_2a_c; + + }; +} + +#endif // DLIB_ENTROPY_ENCODEr_ + diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.cpp b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.cpp new file mode 100644 index 000000000..effcf3123 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.cpp @@ -0,0 +1,239 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_KERNEL_1_CPp_ +#define DLIB_ENTROPY_ENCODER_KERNEL_1_CPp_ +#include "entropy_encoder_kernel_1.h" +#include +#include + +namespace dlib +{ + + +// ---------------------------------------------------------------------------------------- + + entropy_encoder_kernel_1:: + entropy_encoder_kernel_1( + ) : + initial_low(0x00000001), + initial_high(0xffffffff), + out(0), + low(initial_low), + high(initial_high), + buf(0), + buf_used(0) + { + } + +// ---------------------------------------------------------------------------------------- + + entropy_encoder_kernel_1:: + ~entropy_encoder_kernel_1 ( + ) + { + try { + if (out != 0) + { + flush(); + } + } catch (...) {} + } + +// ---------------------------------------------------------------------------------------- + + void entropy_encoder_kernel_1:: + clear( + ) + { + if (out != 0) + { + flush(); + } + out = 0; + } + +// ---------------------------------------------------------------------------------------- + + void entropy_encoder_kernel_1:: + set_stream ( + std::ostream& out_ + ) + { + if (out != 0) + { + // if a stream is currently set then flush the buffers to it before + // we switch to the new stream + flush(); + } + + out = &out_; + streambuf = out_.rdbuf(); + + // reset the encoder state + buf_used = 0; + buf = 0; + low = initial_low; + high = initial_high; + } + +// ---------------------------------------------------------------------------------------- + + bool entropy_encoder_kernel_1:: + stream_is_set ( + ) const + { + if (out != 0) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& entropy_encoder_kernel_1:: + get_stream ( + ) const + { + return *out; + } + +// ---------------------------------------------------------------------------------------- + + void entropy_encoder_kernel_1:: + encode ( + uint32 low_count, + uint32 high_count, + uint32 total + ) + { + // note that we must add one because of the convention that + // high == the real upper range minus 1 + uint32 r = (high-low+1)/total; + + // note that we must subtract 1 to preserve the convention that + // high == the real upper range - 1 + high = low + r*high_count-1; + low = low + r*low_count; + + + while (true) + { + + // if the highest order bit in high and low is the same + if ( low >= 0x80000000 || high < 0x80000000) + { + // if buf is full then write it out + if (buf_used == 8) + { + if (streambuf->sputn(reinterpret_cast(&buf),1)==0) + { + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + } + buf = 0; + buf_used = 0; + } + + + // write the high order bit from low into buf + buf <<= 1; + ++buf_used; + if (low&0x80000000) + buf |= 0x1; + + // roll off the bit we just wrote to buf + low <<= 1; + high <<= 1; + high |= 1; // note that it is ok to add one to high here because + // of the convention that high == real upper range - 1. + // so that means that if we want to shift the upper range + // left by one then we must shift a one into high also + // since real upper range == high + 0.999999999... + + // make sure low is never zero + if (low == 0) + low = 1; + } + // if the distance between high and low is small and there aren't + // any bits we can roll off then round low up or high down. + else if (high-low < 0x10000) + { + if (high == 0x80000000) + high = 0x7fffffff; + else + low = 0x80000000; + } + else + { + break; + } + } // while (true) + + } + +// ---------------------------------------------------------------------------------------- + + void entropy_encoder_kernel_1:: + flush ( + ) + { + // flush the next 4 or 5 bytes that are buffered + // thats whatever is contained in buf and then all of low plus any extra + // bits needed to pad that to be an even 4 or 5 bytes + + + if (buf_used != 8) + { + buf <<= (8-buf_used); + buf |= static_cast(low>>(24+buf_used)); + low <<= (8-buf_used); + } + + if (streambuf->sputn(reinterpret_cast(&buf),1) == 0) + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + + + + buf = static_cast((low >> 24)&0xFF); + if (streambuf->sputn(reinterpret_cast(&buf),1) == 0) + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + + + + + buf = static_cast((low >> 16)&0xFF); + if (streambuf->sputn(reinterpret_cast(&buf),1)==0) + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + + + + buf = static_cast((low >> 8)&0xFF); + if (streambuf->sputn(reinterpret_cast(&buf),1)==0) + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + + + + if (buf_used != 0) + { + buf = static_cast((low)&0xFF); + if (streambuf->sputn(reinterpret_cast(&buf),1)==0) + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + } + + + + // make sure the stream buffer flushes to its I/O channel + streambuf->pubsync(); + + + // reset the encoder state + buf_used = 0; + buf = 0; + low = initial_low; + high = initial_high; + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_ENTROPY_ENCODER_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.h b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.h new file mode 100644 index 000000000..ccaaf9824 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_1.h @@ -0,0 +1,119 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_KERNEl_1_ +#define DLIB_ENTROPY_ENCODER_KERNEl_1_ + +#include "../algs.h" +#include "entropy_encoder_kernel_abstract.h" +#include +#include "../uintn.h" + +namespace dlib +{ + + class entropy_encoder_kernel_1 + { + /*! + GENERAL NOTES + this encoder is implemented using arithmetic coding + + INITIAL VALUE + out == 0 + buf_used == 0 + buf == 0 + initial_low == 0x00000001 (slightly more than zero) + initial_high == 0xffffffff (slightly less than one, 0.99999999976717) + low == initial_low + high == initial_high + + CONVENTION + if (out != 0) + *out == get_stream() + true == stream_is_set() + streambuf == out->rdbuf() + else + false == stream_is_set() + + buf == used to accumulate bits before writing them to out. + buf_used == the number of low order bits in buf that are currently + in use + low == the low end of the range used for arithmetic encoding. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so it is + always in the range [0,1) + + low is also never allowed to be zero to avoid overflow + in the calculation (high-low+1)/total. + + high == the high end of the range - 1 used for arithmetic encoding. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so when we + interpret high as a real number then it is always in the + range [0,1) + + the range for arithmetic encoding is always + [low,high + 0.9999999...) the 0.9999999... is why + high == real upper range - 1 + + !*/ + + public: + + entropy_encoder_kernel_1 ( + ); + + virtual ~entropy_encoder_kernel_1 ( + ); + + void clear( + ); + + void set_stream ( + std::ostream& out + ); + + bool stream_is_set ( + ) const; + + std::ostream& get_stream ( + ) const; + + void encode ( + uint32 low_count, + uint32 high_count, + uint32 total + ); + + private: + + void flush ( + ); + /*! + requires + out != 0 (i.e. there is a stream object to flush the data to + !*/ + + // restricted functions + entropy_encoder_kernel_1(entropy_encoder_kernel_1&); // copy constructor + entropy_encoder_kernel_1& operator=(entropy_encoder_kernel_1&); // assignment operator + + // data members + const uint32 initial_low; + const uint32 initial_high; + std::ostream* out; + uint32 low; + uint32 high; + unsigned char buf; + uint32 buf_used; + std::streambuf* streambuf; + + }; + +} + +#ifdef NO_MAKEFILE +#include "entropy_encoder_kernel_1.cpp" +#endif + +#endif // DLIB_ENTROPY_ENCODER_KERNEl_1_ + diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.cpp b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.cpp new file mode 100644 index 000000000..4f64a6155 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.cpp @@ -0,0 +1,233 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_KERNEL_2_CPp_ +#define DLIB_ENTROPY_ENCODER_KERNEL_2_CPp_ +#include "entropy_encoder_kernel_2.h" +#include +#include + +namespace dlib +{ + + +// ---------------------------------------------------------------------------------------- + + entropy_encoder_kernel_2:: + entropy_encoder_kernel_2( + ) : + initial_low(0x00000001), + initial_high(0xffffffff), + out(0), + low(initial_low), + high(initial_high) + { + } + +// ---------------------------------------------------------------------------------------- + + entropy_encoder_kernel_2:: + ~entropy_encoder_kernel_2 ( + ) + { + try { + if (out != 0) + { + flush(); + } + } catch (...) {} + } + +// ---------------------------------------------------------------------------------------- + + void entropy_encoder_kernel_2:: + clear( + ) + { + if (out != 0) + { + flush(); + } + out = 0; + } + +// ---------------------------------------------------------------------------------------- + + void entropy_encoder_kernel_2:: + set_stream ( + std::ostream& out_ + ) + { + if (out != 0) + { + // if a stream is currently set then flush the buffers to it before + // we switch to the new stream + flush(); + } + + out = &out_; + streambuf = out_.rdbuf(); + + // reset the encoder state + low = initial_low; + high = initial_high; + } + +// ---------------------------------------------------------------------------------------- + + bool entropy_encoder_kernel_2:: + stream_is_set ( + ) const + { + if (out != 0) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& entropy_encoder_kernel_2:: + get_stream ( + ) const + { + return *out; + } + +// ---------------------------------------------------------------------------------------- + + void entropy_encoder_kernel_2:: + encode ( + uint32 low_count, + uint32 high_count, + uint32 total + ) + { + // note that we must add one because of the convention that + // high == the real upper range minus 1 + uint32 r = (high-low+1)/total; + + // note that we must subtract 1 to preserve the convention that + // high == the real upper range - 1 + high = low + r*high_count-1; + low = low + r*low_count; + + + + while (true ) + { + + // if high and low don't have the same 8 high order bits + if ((high&0xFF000000) != (low&0xFF000000)) + { + // if the distance between high and low is small and there aren't + // any bits we can roll off then force high and low to have common high + // order bits. + if ((high-low < 0x10000)) + { + if (high-low > 0x1000) + { + high>>=1; + low>>=1; + high = low = high+low; + high += 0xFF; + low -= 0xFF; + } + else /**/ + { + high>>=1; + low>>=1; + high = low = high+low; + } + } + else + { + // there are no bits to roll off and high and low are not + // too close so just quit the loop + break; + } + + } + // else if there are 8 bits we can roll off + else + { + // write the 8 high order bits from low into buf + unsigned char buf = static_cast(low>>24); + + + // roll off the bits we just wrote to buf + high <<= 8; + low <<= 8; + high |= 0xFF; // note that it is ok to add 0xFF to high here because + // of the convention that high == real upper range - 1. + // so that means that if we want to shift the upper range + // left by one then we must shift a one into high also + // since real upper range == high + 0.999999999... + + // make sure low is never zero + if (low == 0) + low = 1; + + // write buf to the output stream + if (streambuf->sputn(reinterpret_cast(&buf),1)==0) + { + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + } + + } + + } // while (true) + + } + +// ---------------------------------------------------------------------------------------- + + void entropy_encoder_kernel_2:: + flush ( + ) + { + + // flush low to the output stream + + + unsigned char buf; + + + buf = static_cast((low >> 24)&0xFF); + if (streambuf->sputn(reinterpret_cast(&buf),1) == 0) + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + + + + + buf = static_cast((low >> 16)&0xFF); + if (streambuf->sputn(reinterpret_cast(&buf),1)==0) + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + + + + buf = static_cast((low >> 8)&0xFF); + if (streambuf->sputn(reinterpret_cast(&buf),1)==0) + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + + + buf = static_cast((low)&0xFF); + if (streambuf->sputn(reinterpret_cast(&buf),1)==0) + throw std::ios_base::failure("error occurred in the entropy_encoder object"); + + + + + // make sure the stream buffer flushes to its I/O channel + streambuf->pubsync(); + + + // reset the encoder state + low = initial_low; + high = initial_high; + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_ENTROPY_ENCODER_KERNEL_2_CPp_ + diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.h b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.h new file mode 100644 index 000000000..71a7503c6 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_2.h @@ -0,0 +1,112 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_KERNEl_2_ +#define DLIB_ENTROPY_ENCODER_KERNEl_2_ + +#include "../algs.h" +#include "entropy_encoder_kernel_abstract.h" +#include +#include "../uintn.h" + +namespace dlib +{ + + class entropy_encoder_kernel_2 + { + /*! + GENERAL NOTES + this encoder is implemented using "range" coding + + INITIAL VALUE + out == 0 + initial_low == 0x00000001 (slightly more than zero) + initial_high == 0xffffffff (slightly less than one, 0.99999999976717) + low == initial_low + high == initial_high + + CONVENTION + if (out != 0) + *out == get_stream() + true == stream_is_set() + streambuf == out->rdbuf() + else + false == stream_is_set() + + + low == the low end of the range used for arithmetic encoding. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so it is + always in the range [0,1) + + low is also never allowed to be zero to avoid overflow + in the calculation (high-low+1)/total. + + high == the high end of the range - 1 used for arithmetic encoding. + this number is used as a 32bit fixed point real number. + the point is fixed just before the first bit, so when we + interpret high as a real number then it is always in the + range [0,1) + + the range for arithmetic encoding is always + [low,high + 0.9999999...) the 0.9999999... is why + high == real upper range - 1 + !*/ + + public: + + entropy_encoder_kernel_2 ( + ); + + virtual ~entropy_encoder_kernel_2 ( + ); + + void clear( + ); + + void set_stream ( + std::ostream& out + ); + + bool stream_is_set ( + ) const; + + std::ostream& get_stream ( + ) const; + + void encode ( + uint32 low_count, + uint32 high_count, + uint32 total + ); + + private: + + void flush ( + ); + /*! + requires + out != 0 (i.e. there is a stream object to flush the data to + !*/ + + // restricted functions + entropy_encoder_kernel_2(entropy_encoder_kernel_2&); // copy constructor + entropy_encoder_kernel_2& operator=(entropy_encoder_kernel_2&); // assignment operator + + // data members + const uint32 initial_low; + const uint32 initial_high; + std::ostream* out; + uint32 low; + uint32 high; + std::streambuf* streambuf; + + }; + +} + +#ifdef NO_MAKEFILE +#include "entropy_encoder_kernel_2.cpp" +#endif + +#endif // DLIB_ENTROPY_ENCODER_KERNEl_2_ + diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_abstract.h b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_abstract.h new file mode 100644 index 000000000..48af93307 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_abstract.h @@ -0,0 +1,161 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ENTROPY_ENCODER_KERNEl_ABSTRACT_ +#ifdef DLIB_ENTROPY_ENCODER_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include +#include "../uintn.h" + +namespace dlib +{ + + class entropy_encoder + { + /*! + INITIAL VALUE + stream_is_set() == false + + + WHAT THIS OBJECT REPRESENTS + This object represents an entropy encoder (could be implemented as an + arithmetic encoder for example). + + Note that all implementations of entropy_encoder and entropy_decoder + are paired. This means that if you use entropy_encoder_kernel_n to + encode something then you must use the corresponding + entropy_decoder_kernel_n to decode it. + + NOTATION: + At any moment each symbol has a certain probability of appearing in + the input stream. These probabilities may change as each symbol is + encountered and the probability model is updated accordingly. + + + let P(i) be a function which gives the probability of seeing the ith + symbol of an N symbol alphabet BEFORE the probability model is updated + to account for the current symbol. ( The domain of P(i) is from 0 to N-1. ) + + for each i: P(i) == COUNT/TOTAL where COUNT and TOTAL are integers. + and TOTAL is the same number for all P(i) but COUNT may vary. + + let LOW_COUNT(i) be the sum of all P(x)*TOTAL from x == 0 to x == i-1 + (note that LOW_COUNT(0) == 0) + let HIGH_COUNT(i) be the sum of all P(x)*TOTAL from x == 0 to x == i + !*/ + + public: + + entropy_encoder ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~entropy_encoder ( + ); + /*! + ensures + - all memory associated with *this has been released + - if (stream_is_set()) then + - any buffered data in *this will be written to get_stream() + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + - if (stream_is_set()) then + - any buffered data in *this will be written to get_stream() + - clears any memory of all previous calls to encode() from #*this + throws + - std::ios_base::failure + if (stream_is_set() && there was a problem writing to get_stream()) + then this exception will be thrown. #*this will be unusable until + clear() is called and succeeds + - any other exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + + void set_stream ( + std::ostream& out + ); + /*! + ensures + - #get_stream() == out + - #stream_is_set() == true + - if (stream_is_set()) then + - any buffered data in *this will be written to get_stream() + - clears any memory of all previous calls to encode() from #*this + throws + - std::ios_base::failure + if (stream_is_set() && there was a problem writing to get_stream()) + then this exception will be thrown. #*this will be unusable until + clear() is called and succeeds + - any other exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + bool stream_is_set ( + ) const; + /*! + ensures + - returns true if a stream has been associated with *this by calling + set_stream() + !*/ + + std::ostream& get_stream ( + ) const; + /*! + requires + - stream_is_set() == true + ensures + - returns a reference to the ostream object that *this writes its + encoded data to + !*/ + + void encode ( + uint32 low_count, + uint32 high_count, + uint32 total + ); + /*! + requires + - 0 < total < 65536 (2^16) + - total == TOTAL + - low_count < high_count <= total + - stream_is_set() == true + ensures + - encodes the symbol S where: + - LOW_COUNT(S) == low_count + - HIGH_COUNT(S) == high_count + throws + - std::ios_base::failure + if (there was a problem writing to get_stream()) then + this exception will be thrown. #*this will be unusable until + clear() is called and succeeds + - any other exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + + private: + + // restricted functions + entropy_encoder(entropy_encoder&); // copy constructor + entropy_encoder& operator=(entropy_encoder&); // assignment operator + + }; + +} + +#endif // DLIB_ENTROPY_ENCODER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_c.h b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_c.h new file mode 100644 index 000000000..f11241ecc --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder/entropy_encoder_kernel_c.h @@ -0,0 +1,112 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_KERNEl_C_ +#define DLIB_ENTROPY_ENCODER_KERNEl_C_ + +#include "entropy_encoder_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename encoder + > + class entropy_encoder_kernel_c : public encoder + { + + public: + std::ostream& get_stream ( + ) const; + + void encode ( + uint32 low_count, + uint32 high_count, + uint32 total + ); + + void flush ( + ); + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename encoder + > + std::ostream& entropy_encoder_kernel_c:: + get_stream ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tstd::ostream& entropy_encoder::get_stream()" + << "\n\tyou must set a stream for this object before you can get it" + << "\n\tthis: " << this + ); + + // call the real function + return encoder::get_stream(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename encoder + > + void entropy_encoder_kernel_c:: + encode ( + uint32 low_count, + uint32 high_count, + uint32 total + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (0 < total) && (total < 65536) && (low_count < high_count) && (high_count <= total) && + (this->stream_is_set() == true), + "\tvoid entropy_encoder::encode()" + << "\n\trefer to the ensures clause for this function for further information" + << "\n\tthis: " << this + << "\n\ttotal: " << total + << "\n\tlow_count: " << low_count + << "\n\thigh_count: " << high_count + << "\n\tis_stream_set(): " << (this->stream_is_set() ? "true" : "false" ) + ); + + // call the real function + encoder::encode(low_count,high_count,total); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename encoder + > + void entropy_encoder_kernel_c:: + flush ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tvoid entropy_encoder::flush()" + << "\n\tyou must set a stream for this object before you can flush to it" + << "\n\tthis: " << this + ); + + // call the real function + encoder::flush(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_ENCODER_KERNEl_C_ + diff --git a/ml/dlib/dlib/entropy_encoder_model.h b/ml/dlib/dlib/entropy_encoder_model.h new file mode 100644 index 000000000..465377e97 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder_model.h @@ -0,0 +1,146 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_MODEl_ +#define DLIB_ENTROPY_ENCODER_MODEl_ + +#include "entropy_encoder_model/entropy_encoder_model_kernel_1.h" +#include "entropy_encoder_model/entropy_encoder_model_kernel_2.h" +#include "entropy_encoder_model/entropy_encoder_model_kernel_3.h" +#include "entropy_encoder_model/entropy_encoder_model_kernel_4.h" +#include "entropy_encoder_model/entropy_encoder_model_kernel_5.h" +#include "entropy_encoder_model/entropy_encoder_model_kernel_6.h" +#include "entropy_encoder_model/entropy_encoder_model_kernel_c.h" + +#include "conditioning_class.h" +#include "memory_manager.h" +#include "sliding_buffer.h" + + +namespace dlib +{ + + + template < + unsigned long alphabet_size, + typename entropy_encoder + > + class entropy_encoder_model + { + entropy_encoder_model() {} + + typedef typename conditioning_class::kernel_1a cc1; + typedef typename conditioning_class::kernel_2a cc2; + typedef typename conditioning_class::kernel_3a cc3; + typedef typename conditioning_class::kernel_4a cc4a; + typedef typename conditioning_class::kernel_4b cc4b; + typedef typename conditioning_class::kernel_4c cc4c; + typedef typename conditioning_class::kernel_4d cc4d; + + + public: + + //----------- kernels --------------- + + // kernel_1 + typedef entropy_encoder_model_kernel_1 + kernel_1a; + typedef entropy_encoder_model_kernel_c + kernel_1a_c; + + typedef entropy_encoder_model_kernel_1 + kernel_1b; + typedef entropy_encoder_model_kernel_c + kernel_1b_c; + + typedef entropy_encoder_model_kernel_1 + kernel_1c; + typedef entropy_encoder_model_kernel_c + kernel_1c_c; + + // -------------------- + + // kernel_2 + typedef entropy_encoder_model_kernel_2 + kernel_2a; + typedef entropy_encoder_model_kernel_c + kernel_2a_c; + + typedef entropy_encoder_model_kernel_2 + kernel_2b; + typedef entropy_encoder_model_kernel_c + kernel_2b_c; + + typedef entropy_encoder_model_kernel_2 + kernel_2c; + typedef entropy_encoder_model_kernel_c + kernel_2c_c; + + typedef entropy_encoder_model_kernel_2 + kernel_2d; + typedef entropy_encoder_model_kernel_c + kernel_2d_c; + + // -------------------- + + // kernel_3 + typedef entropy_encoder_model_kernel_3 + kernel_3a; + typedef entropy_encoder_model_kernel_c + kernel_3a_c; + + typedef entropy_encoder_model_kernel_3 + kernel_3b; + typedef entropy_encoder_model_kernel_c + kernel_3b_c; + + typedef entropy_encoder_model_kernel_3 + kernel_3c; + typedef entropy_encoder_model_kernel_c + kernel_3c_c; + + // -------------------- + + // kernel_4 + typedef entropy_encoder_model_kernel_4 + kernel_4a; + typedef entropy_encoder_model_kernel_c + kernel_4a_c; + + typedef entropy_encoder_model_kernel_4 + kernel_4b; + typedef entropy_encoder_model_kernel_c + kernel_4b_c; + + // -------------------- + + // kernel_5 + typedef entropy_encoder_model_kernel_5 + kernel_5a; + typedef entropy_encoder_model_kernel_c + kernel_5a_c; + + typedef entropy_encoder_model_kernel_5 + kernel_5b; + typedef entropy_encoder_model_kernel_c + kernel_5b_c; + + typedef entropy_encoder_model_kernel_5 + kernel_5c; + typedef entropy_encoder_model_kernel_c + kernel_5c_c; + + // -------------------- + + // kernel_6 + typedef entropy_encoder_model_kernel_6 + kernel_6a; + typedef entropy_encoder_model_kernel_c + kernel_6a_c; + + + + }; +} + +#endif // DLIB_ENTROPY_ENCODER_MODEl_ + diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_1.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_1.h new file mode 100644 index 000000000..29c82e5be --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_1.h @@ -0,0 +1,167 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_1_ +#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_1_ + +#include "../algs.h" +#include "entropy_encoder_model_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc + > + class entropy_encoder_model_kernel_1 + { + /*! + REQUIREMENTS ON cc + cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + cc::get_alphabet_size() == alphabet_size+1 + + INITIAL VALUE + Initially this object's finite context model is empty + + CONVENTION + &get_entropy_encoder() == coder + &order_0.get_global_state() == &gs + + This is an order-0 model. The last symbol in the order-0 context is + an escape into the order minus 1 context. + !*/ + + public: + + typedef entropy_encoder entropy_encoder_type; + + entropy_encoder_model_kernel_1 ( + entropy_encoder& coder + ); + + virtual ~entropy_encoder_model_kernel_1 ( + ); + + inline void clear( + ); + + inline void encode ( + unsigned long symbol + ); + + entropy_encoder& get_entropy_encoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + entropy_encoder& coder; + typename cc::global_state_type gs; + cc order_0; + + // restricted functions + entropy_encoder_model_kernel_1(entropy_encoder_model_kernel_1&); // copy constructor + entropy_encoder_model_kernel_1& operator=(entropy_encoder_model_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc + > + entropy_encoder_model_kernel_1:: + entropy_encoder_model_kernel_1 ( + entropy_encoder& coder_ + ) : + coder(coder_), + order_0(gs) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 ); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc + > + entropy_encoder_model_kernel_1:: + ~entropy_encoder_model_kernel_1 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc + > + void entropy_encoder_model_kernel_1:: + clear( + ) + { + order_0.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc + > + void entropy_encoder_model_kernel_1:: + encode ( + unsigned long symbol + ) + { + unsigned long low_count = 0, high_count = 0, total_count = 0; + + // if we have seen this symbol in the order-0 context + if (order_0.get_range(symbol,low_count,high_count,total_count)) + { + // update the count for this symbol + order_0.increment_count(symbol,2); + // encode this symbol + coder.encode(low_count,high_count,total_count); + return; + } + + // if we are here then the symbol does not appear in the order-0 context + + + // since we have never seen the current symbol in this context + // escape from order-0 context + order_0.get_range(alphabet_size,low_count,high_count,total_count); + coder.encode(low_count,high_count,total_count); + // increment the count for the escape symbol + order_0.increment_count(alphabet_size); + + // update the count for this symbol + order_0.increment_count(symbol,2); + + // use order minus one context + coder.encode(symbol,symbol+1,alphabet_size); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_1_ + diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_2.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_2.h new file mode 100644 index 000000000..08a16cae0 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_2.h @@ -0,0 +1,246 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_2_ +#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_2_ + +#include "../algs.h" +#include "entropy_encoder_model_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename ccbig + > + class entropy_encoder_model_kernel_2 + { + /*! + REQUIREMENTS ON cc + cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + cc::get_alphabet_size() == alphabet_size+1 + this will be used for the order-0 context + + REQUIREMENTS ON ccbig + ccbig is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + ccbig::get_alphabet_size() == alphabet_size+1 + this will be used for the order-1 context + + INITIAL VALUE + Initially this object's finite context model is empty + previous_symbol == 0 + + CONVENTION + &get_entropy_encoder() == coder + &order_0.get_global_state() == &gs + &order_1[i]->get_global_state() == &gsbig + + + This is an order-1-0 model. The last symbol in the order-0 and order-1 + context is an escape into the lower context. + + previous_symbol == the last symbol seen + !*/ + + public: + + typedef entropy_encoder entropy_encoder_type; + + entropy_encoder_model_kernel_2 ( + entropy_encoder& coder + ); + + virtual ~entropy_encoder_model_kernel_2 ( + ); + + inline void clear( + ); + + inline void encode ( + unsigned long symbol + ); + + entropy_encoder& get_entropy_encoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + entropy_encoder& coder; + typename cc::global_state_type gs; + typename ccbig::global_state_type gsbig; + cc order_0; + ccbig* order_1[alphabet_size]; + unsigned long previous_symbol; + + + // restricted functions + entropy_encoder_model_kernel_2(entropy_encoder_model_kernel_2&); // copy constructor + entropy_encoder_model_kernel_2& operator=(entropy_encoder_model_kernel_2&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename ccbig + > + entropy_encoder_model_kernel_2:: + entropy_encoder_model_kernel_2 ( + entropy_encoder& coder_ + ) : + coder(coder_), + order_0(gs), + previous_symbol(0) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 ); + + unsigned long i; + try + { + for (i = 0; i < alphabet_size; ++i) + { + order_1[i] = new ccbig(gsbig); + } + } + catch (...) + { + for (unsigned long j = 0; j < i; ++j) + { + delete order_1[j]; + } + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename ccbig + > + entropy_encoder_model_kernel_2:: + ~entropy_encoder_model_kernel_2 ( + ) + { + for (unsigned long i = 0; i < alphabet_size; ++i) + { + delete order_1[i]; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename ccbig + > + void entropy_encoder_model_kernel_2:: + clear( + ) + { + previous_symbol = 0; + order_0.clear(); + for (unsigned long i = 0; i < alphabet_size; ++i) + { + order_1[i]->clear(); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename ccbig + > + void entropy_encoder_model_kernel_2:: + encode ( + unsigned long symbol + ) + { + unsigned long low_count = 0, high_count = 0, total_count = 0; + + + ccbig& context = *order_1[previous_symbol]; + + // if we have seen this symbol in the order-1 context + if (context.get_range(symbol,low_count,high_count,total_count)) + { + // update the count for this symbol + context.increment_count(symbol,2); + // encode this symbol + coder.encode(low_count,high_count,total_count); + previous_symbol = symbol; + return; + } + + // we didn't find the symbol in the order-1 context so we must escape to a + // lower context. + + // escape to the order-0 context + context.get_range(alphabet_size,low_count,high_count,total_count); + coder.encode(low_count,high_count,total_count); + + + // increment counts for the escape symbol and the current symbol + context.increment_count(alphabet_size); + context.increment_count(symbol,2); + + previous_symbol = symbol; + + + + + + // if we have seen this symbol in the order-0 context + if (order_0.get_range(symbol,low_count,high_count,total_count)) + { + // update the count for this symbol + order_0.increment_count(symbol,2); + // encode this symbol + coder.encode(low_count,high_count,total_count); + return; + } + + // if we are here then the symbol does not appear in the order-0 context + + + // since we have never seen the current symbol in this context + // escape from order-0 context + order_0.get_range(alphabet_size,low_count,high_count,total_count); + coder.encode(low_count,high_count,total_count); + // increment the count for the escape symbol + order_0.increment_count(alphabet_size); + + // update the count for this symbol + order_0.increment_count(symbol,2); + + // use order minus one context + coder.encode(symbol,symbol+1,alphabet_size); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_2_ + diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_3.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_3.h new file mode 100644 index 000000000..0df28f201 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_3.h @@ -0,0 +1,341 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_3_ +#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_3_ + +#include "../algs.h" +#include "entropy_encoder_model_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename cc_high + > + class entropy_encoder_model_kernel_3 + { + /*! + REQUIREMENTS ON cc + cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + cc::get_alphabet_size() == alphabet_size+1 + + REQUIREMENTS ON cc_high + cc_high is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + cc_high::get_alphabet_size() == alphabet_size+1 + + INITIAL VALUE + - Initially this object's finite context model is empty + - previous_symbol == 0 + - previous_symbol2 == 0 + - order_1 == pointer to an array of alphabet_size elements + - order_2 == pointer to an array of alphabet_size*alphabet_size elements + - order_2[i] == 0 + + CONVENTION + &get_entropy_encoder() == coder + &order_0.get_global_state() == &gs + &order_1[i]->get_global_state() == &gs + + if (order_2[i] != 0) then + &order_2[i]->get_global_state() == &gs_high + + This is an order-2-1-0 model. The last symbol in the order-2, order-1 and + order-0 contexts is an escape into the lower context. + + previous_symbol == the last symbol seen + previous_symbol2 == the symbol we saw before previous_symbol + !*/ + + public: + + typedef entropy_encoder entropy_encoder_type; + + entropy_encoder_model_kernel_3 ( + entropy_encoder& coder + ); + + virtual ~entropy_encoder_model_kernel_3 ( + ); + + inline void clear( + ); + + inline void encode ( + unsigned long symbol + ); + + entropy_encoder& get_entropy_encoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + entropy_encoder& coder; + typename cc::global_state_type gs; + typename cc_high::global_state_type gs_high; + cc order_0; + cc** order_1; + unsigned long previous_symbol; + cc_high** order_2; + unsigned long previous_symbol2; + + + // restricted functions + entropy_encoder_model_kernel_3(entropy_encoder_model_kernel_3&); // copy constructor + entropy_encoder_model_kernel_3& operator=(entropy_encoder_model_kernel_3&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename cc_high + > + entropy_encoder_model_kernel_3:: + entropy_encoder_model_kernel_3 ( + entropy_encoder& coder_ + ) : + coder(coder_), + order_0(gs), + order_1(0), + previous_symbol(0), + order_2(0), + previous_symbol2(0) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 ); + + try + { + order_1 = new cc*[alphabet_size]; + order_2 = new cc_high*[alphabet_size*alphabet_size]; + } + catch (...) + { + if (order_1) delete [] order_1; + if (order_2) delete [] order_2; + throw; + } + + unsigned long i; + + for (i = 0; i < (alphabet_size*alphabet_size); ++i) + { + order_2[i] = 0; + } + + try + { + for (i = 0; i < alphabet_size; ++i) + { + order_1[i] = new cc(gs); + } + } + catch (...) + { + for (unsigned long j = 0; j < i; ++j) + { + delete order_1[j]; + } + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename cc_high + > + entropy_encoder_model_kernel_3:: + ~entropy_encoder_model_kernel_3 ( + ) + { + for (unsigned long i = 0; i < alphabet_size; ++i) + { + delete order_1[i]; + } + + for (unsigned long i = 0; i < alphabet_size*alphabet_size; ++i) + { + if (order_2[i] != 0) + delete order_2[i]; + } + delete [] order_1; + delete [] order_2; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename cc_high + > + void entropy_encoder_model_kernel_3:: + clear( + ) + { + previous_symbol = 0; + previous_symbol2 = 0; + order_0.clear(); + for (unsigned long i = 0; i < alphabet_size; ++i) + { + order_1[i]->clear(); + } + + for (unsigned long i = 0; i < alphabet_size*alphabet_size; ++i) + { + if (order_2[i] != 0) + { + delete order_2[i]; + order_2[i] = 0; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + typename cc, + typename cc_high + > + void entropy_encoder_model_kernel_3:: + encode ( + unsigned long symbol + ) + { + unsigned long low_count = 0, high_count = 0, total_count = 0; + + + // order-2 context stuff + { + unsigned long temp = previous_symbol + (previous_symbol2 * alphabet_size); + previous_symbol2 = previous_symbol; + + if (order_2[temp] != 0) + { + if (order_2[temp]->get_range(symbol,low_count,high_count,total_count)) + { + // there was an entry for this symbol in this context + + // update the count for this symbol + order_2[temp]->increment_count(symbol,2); + // encode this symbol + coder.encode(low_count,high_count,total_count); + previous_symbol = symbol; + return; + } + + // there was no entry for this symbol in this context so we must + // escape to order-1 + + // escape to the order-1 context + order_2[temp]->get_range(alphabet_size,low_count,high_count,total_count); + coder.encode(low_count,high_count,total_count); + + // increment the count for the escape symbol + order_2[temp]->increment_count(alphabet_size); + + } + else + { + order_2[temp] = new cc_high(gs_high); + + // in this case the decoder knows to escape to order-1 because + // there was no conditioning_class object in this context yet. + // so we don't need to actually write the escape symbol + } + + // update the count for this symbol in this context + order_2[temp]->increment_count(symbol,2); + } + + + + + // order-1 context stuff + { + cc& context = *order_1[previous_symbol]; + + // if we have seen this symbol in the order-1 context + if (context.get_range(symbol,low_count,high_count,total_count)) + { + // update the count for this symbol + context.increment_count(symbol,2); + // encode this symbol + coder.encode(low_count,high_count,total_count); + previous_symbol = symbol; + return; + } + + // we didn't find the symbol in the order-1 context so we must escape to a + // lower context. + + // escape to the order-0 context + context.get_range(alphabet_size,low_count,high_count,total_count); + coder.encode(low_count,high_count,total_count); + + + // increment counts for the escape symbol and the current symbol + context.increment_count(alphabet_size); + context.increment_count(symbol,2); + } + + previous_symbol = symbol; + + + + + + // if we have seen this symbol in the order-0 context + if (order_0.get_range(symbol,low_count,high_count,total_count)) + { + // update the count for this symbol + order_0.increment_count(symbol,2); + // encode this symbol + coder.encode(low_count,high_count,total_count); + return; + } + + // if we are here then the symbol does not appear in the order-0 context + + + // since we have never seen the current symbol in this context + // escape from order-0 context + order_0.get_range(alphabet_size,low_count,high_count,total_count); + coder.encode(low_count,high_count,total_count); + // increment the count for the escape symbol + order_0.increment_count(alphabet_size); + + // update the count for this symbol + order_0.increment_count(symbol,2); + + // use order minus one context + coder.encode(symbol,symbol+1,alphabet_size); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_3_ + diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_4.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_4.h new file mode 100644 index 000000000..0e5ae46d3 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_4.h @@ -0,0 +1,553 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_4_ +#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_4_ + +#include "../algs.h" +#include "entropy_encoder_model_kernel_abstract.h" +#include "../assert.h" + + + +namespace dlib +{ + + namespace eemk4 + { + struct node + { + node* next; + node* child_context; + node* parent_context; + + unsigned short symbol; + unsigned short count; + unsigned short total; + unsigned short escapes; + }; + } + + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + class entropy_encoder_model_kernel_4 + { + /*! + REQUIREMENTS ON total_nodes + - 4096 < total_nodes + - this is the total number of nodes that we will use in the tree + + REQUIREMENTS ON order + - 0 <= order + - this is the maximum depth-1 the tree will be allowed to go (note + that the root level is depth 0). + + GENERAL NOTES + This implementation follows more or less the implementation + strategy laid out by Alistair Moffat in his paper + Implementing the PPM data compression scheme. Published in IEEE + Transactions on Communications, 38(11):1917-1921, 1990. + + The escape method used will be method D. + + + INITIAL VALUE + - root == pointer to an array of total_nodes nodes + - next_node == 1 + - cur == root + - cur_order = 0 + - root->next == 0 + - root->parent_context == 0 + - root->child_context == 0 + - root->escapes == 0 + - root->total == 0 + + CONVENTION + - &get_entropy_encoder() == coder + - root == pointer to an array of total_nodes nodes. + this is also the root of the tree. + + - if (next_node < total_nodes) then + - next_node == the next node in root that has not yet been allocated + + - root->next == 0 + - root->parent_context == 0 + + + - for every node in the tree: + { + - NOTATION: + - The "context" of a node is the string of symbols seen + when you go from the root of the tree down (down though + child context pointers) to the node, including the symbol at + the node itself. (note that the context of the root node + is "" or the empty string) + - A set of nodes is in the same "context set" if all the node's + contexts are of length n and all the node's contexts share + the same prefix of length n-1. + - The "child context set" of a node is a set of nodes with + contexts that are one symbol longer and prefixed by the node's + context. For example, if a node has a context "abc" then the + nodes for contexts "abca", "abcb", "abcc", etc. are all in + the child context set of the node. + - The "parent context" of a node is the context that is one + symbol shorter than the node's context and includes the + symbol in the node. So the parent context of a node with + context "abcd" would be the context "bcd". + + + - if (next != 0) then + - next == pointer to the next node in the same context set + - if (child_context != 0) then + - child_context == pointer to the first node of the child + context set for this node. + - if (parent_context != 0) then + - parent_context == pointer to the parent context of this node. + - else + - this node is the root node of the tree + + + - if (this is not the root node) then + - symbol == the symbol represented with this node + - count == the number of times this symbol has been seen in its + parent context. + - else + - the root doesn't have a symbol. i.e. the context for the + root node is "" or the empty string. + + - total == The sum of the counts of all the nodes + in the child context set + escapes. + - escapes == the escape count for the context represented + by the node. + } + + + - cur_order < order + - cur_order == the depth of the node cur in the tree. + (note that the root node has depth 0) + - cur == pointer to the node in the tree who's context matches + the most recent symbols we have seen. + + + !*/ + + typedef eemk4::node node; + + + public: + + typedef entropy_encoder entropy_encoder_type; + + entropy_encoder_model_kernel_4 ( + entropy_encoder& coder + ); + + virtual ~entropy_encoder_model_kernel_4 ( + ); + + inline void clear( + ); + + inline void encode ( + unsigned long symbol + ); + + entropy_encoder& get_entropy_encoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + inline eemk4::node* allocate_node ( + ); + /*! + requires + - space_left() == true + ensures + - returns a pointer to a new node + !*/ + + inline void destroy_tree ( + ); + /*! + ensures + - deallocates all nodes except the root + - #root->child_context == 0 + - #root->escapes == 0 + - #root->total == 0 + - #cur == root + - #cur_order == 0 + !*/ + + + inline bool space_left ( + ) const; + /*! + ensures + - returns true if there is at least 1 free node left. + - returns false otherwise + !*/ + + + inline void scale_counts ( + node* n + ); + /*! + ensures + - divides all the counts in the child context set of n by 2. + - none of the nodes in the child context set will have a count of 0 + !*/ + + + unsigned long next_node; + entropy_encoder& coder; + node* root; + node* cur; + unsigned long cur_order; + + + // restricted functions + entropy_encoder_model_kernel_4(entropy_encoder_model_kernel_4&); // copy constructor + entropy_encoder_model_kernel_4& operator=(entropy_encoder_model_kernel_4&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + entropy_encoder_model_kernel_4:: + entropy_encoder_model_kernel_4 ( + entropy_encoder& coder_ + ) : + next_node(1), + coder(coder_), + cur_order(0) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 ); + COMPILE_TIME_ASSERT( 4096 < total_nodes ); + + root = new node[total_nodes]; + cur = root; + + root->child_context = 0; + root->escapes = 0; + root->next = 0; + root->parent_context = 0; + root->total = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + entropy_encoder_model_kernel_4:: + ~entropy_encoder_model_kernel_4 ( + ) + { + delete [] root; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_4:: + clear( + ) + { + destroy_tree(); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_4:: + encode ( + unsigned long sym + ) + { + unsigned short symbol = static_cast(sym); + node* temp = cur; + cur = 0; + unsigned short low_count, high_count, total_count; + node* new_node = 0; + + // local_order will track the level of temp in the tree + unsigned long local_order = cur_order; + + while (true) + { + high_count = 0; + if (space_left()) + { + total_count = temp->total; + + if (total_count > 0) + { + // check if we need to scale the counts + if (total_count > 10000) + { + scale_counts(temp); + total_count = temp->total; + } + + // find either the symbol we are looking for or the + // end of the context set + node* n = temp->child_context; + node* last = 0; + while (true) + { + high_count += n->count; + + if (n->symbol == symbol || n->next == 0) + break; + last = n; + n = n->next; + } + + low_count = high_count - n->count; + + // if we found the symbol + if (n->symbol == symbol) + { + if (new_node != 0) + { + new_node->parent_context = n; + } + + coder.encode(low_count,high_count,total_count); + n->count += 8; + temp->total += 8; + + + // move this node to the front + if (last) + { + last->next = n->next; + n->next = temp->child_context; + temp->child_context = n; + } + + + if (cur == 0) + { + if (local_order < order) + { + cur_order = local_order+1; + cur = n; + } + else + { + cur = n->parent_context; + cur_order = local_order; + } + } + + break; + + } + // if we hit the end of the context set without finding the symbol + else + { + if (new_node != 0) + { + new_node->parent_context = allocate_node(); + new_node = new_node->parent_context; + } + else + { + new_node = allocate_node(); + } + + n->next = new_node; + + // write an escape to a lower context + coder.encode(high_count,total_count,total_count); + } + + } + else // if (total_count == 0) + { + // this means that temp->child_context == 0 so we should make + // a new node here. + if (new_node != 0) + { + new_node->parent_context = allocate_node(); + new_node = new_node->parent_context; + } + else + { + new_node = allocate_node(); + } + + temp->child_context = new_node; + } + + if (cur == 0 && local_order < order) + { + cur = new_node; + cur_order = local_order+1; + } + + // fill out the new node + new_node->child_context = 0; + new_node->count = 4; + new_node->escapes = 0; + new_node->next = 0; + new_node->symbol = static_cast(symbol); + new_node->total = 0; + temp->escapes += 4; + temp->total += 8; + + + if (temp != root) + { + temp = temp->parent_context; + --local_order; + continue; + } + + // since this is the root we are going to the order-(-1) context + // so we can just take care of that here. + new_node->parent_context = root; + coder.encode(symbol,symbol+1,alphabet_size); + + if (cur == 0) + { + cur = root; + cur_order = 0; + } + break; + } + else + { + // there isn't enough space so we should rebuild the tree + destroy_tree(); + temp = cur; + local_order = cur_order; + cur = 0; + new_node = 0; + } + } // while (true) + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + eemk4::node* entropy_encoder_model_kernel_4:: + allocate_node ( + ) + { + node* temp; + temp = root + next_node; + ++next_node; + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_4:: + destroy_tree ( + ) + { + next_node = 1; + root->child_context = 0; + root->escapes = 0; + root->total = 0; + cur = root; + cur_order = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + bool entropy_encoder_model_kernel_4:: + space_left ( + ) const + { + return (next_node < total_nodes); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_4:: + scale_counts ( + node* temp + ) + { + if (temp->escapes > 1) + temp->escapes >>= 1; + temp->total = temp->escapes; + + node* n = temp->child_context; + while (n != 0) + { + if (n->count > 1) + n->count >>= 1; + + temp->total += n->count; + n = n->next; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_4_ + diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_5.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_5.h new file mode 100644 index 000000000..6c0c30426 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_5.h @@ -0,0 +1,817 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_5_ +#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_5_ + +#include "../algs.h" +#include "entropy_encoder_model_kernel_abstract.h" +#include "../assert.h" + + + +namespace dlib +{ + + namespace eemk5 + { + struct node + { + node* next; + node* child_context; + node* parent_context; + + unsigned short symbol; + unsigned short count; + unsigned short total; + unsigned short escapes; + }; + } + + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + class entropy_encoder_model_kernel_5 + { + /*! + REQUIREMENTS ON total_nodes + - 4096 < total_nodes + - this is the total number of nodes that we will use in the tree + + REQUIREMENTS ON order + - 0 <= order + - this is the maximum depth-1 the tree will be allowed to go (note + that the root level is depth 0). + + GENERAL NOTES + This implementation follows more or less the implementation + strategy laid out by Alistair Moffat in his paper + Implementing the PPM data compression scheme. Published in IEEE + Transactions on Communications, 38(11):1917-1921, 1990. + + The escape method used will be method D. + + This also uses Dmitry Shkarin's Information Inheritance scheme. + (described in "PPM: one step to practicality" and "Improving the + Efficiency of the PPM Algorithm") + + + INITIAL VALUE + - root == pointer to an array of total_nodes nodes + - next_node == 1 + - cur == root + - cur_order = 0 + - root->next == 0 + - root->parent_context == 0 + - root->child_context == 0 + - root->escapes == 0 + - root->total == 0 + - stack_size == 0 + - exc_used == false + - for all i: exc[i] == 0 + + CONVENTION + - pop() == stack[stack_size-1].n and stack[stack_size-1].nc + - exc_used == something_is_excluded() + - is_excluded(symbol) == bit symbol&0x1F from exc[symbol>>5] + - &get_entropy_encoder() == coder + - root == pointer to an array of total_nodes nodes. + this is also the root of the tree. + - if (next_node < total_nodes) then + - next_node == the next node in root that has not yet been allocated + + - root->next == 0 + - root->parent_context == 0 + + + - for every node in the tree: + { + - NOTATION: + - The "context" of a node is the string of symbols seen + when you go from the root of the tree down (down though + child context pointers) to the node, including the symbol at + the node itself. (note that the context of the root node + is "" or the empty string) + - A set of nodes is in the same "context set" if all the node's + contexts are of length n and all the node's contexts share + the same prefix of length n-1. + - The "child context set" of a node is a set of nodes with + contexts that are one symbol longer and prefixed by the node's + context. For example, if a node has a context "abc" then the + nodes for contexts "abca", "abcb", "abcc", etc. are all in + the child context set of the node. + - The "parent context" of a node is the context that is one + symbol shorter than the node's context and includes the + symbol in the node. So the parent context of a node with + context "abcd" would be the context "bcd". + + + - if (next != 0) then + - next == pointer to the next node in the same context set + - if (child_context != 0) then + - child_context == pointer to the first node of the child + context set for this node. + - escapes > 0 + - if (parent_context != 0) then + - parent_context == pointer to the parent context of this node. + - else + - this node is the root node of the tree + + + - if (this is not the root node) then + - symbol == the symbol represented with this node + - count == the number of times this symbol has been seen in its + parent context. + - else + - the root doesn't have a symbol. i.e. the context for the + root node is "" or the empty string. + + - total == The sum of the counts of all the nodes + in the child context set + escapes. + - escapes == the escape count for the context represented + by the node. + - count > 0 + } + + + - cur_order < order + - cur_order == the depth of the node cur in the tree. + (note that the root node has depth 0) + - cur == pointer to the node in the tree who's context matches + the most recent symbols we have seen. + + + !*/ + + typedef eemk5::node node; + + + public: + + typedef entropy_encoder entropy_encoder_type; + + entropy_encoder_model_kernel_5 ( + entropy_encoder& coder + ); + + virtual ~entropy_encoder_model_kernel_5 ( + ); + + inline void clear( + ); + + inline void encode ( + unsigned long symbol + ); + + entropy_encoder& get_entropy_encoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + inline eemk5::node* allocate_node ( + ); + /*! + requires + - space_left() == true + ensures + - returns a pointer to a new node + !*/ + + inline bool space_left ( + ) const; + /*! + ensures + - returns true if there is at least 1 free node left. + - returns false otherwise + !*/ + + inline void exclude ( + unsigned short symbol + ); + /*! + ensures + - #is_excluded(symbol) == true + - #something_is_excluded() == true + !*/ + + inline bool something_is_excluded ( + ); + /*! + ensures + - returns true if some symbol has been excluded. + returns false otherwise + !*/ + + inline bool is_excluded ( + unsigned short symbol + ); + /*! + ensures + - if (symbol has been excluded) then + - returns true + - else + - returns false + !*/ + + inline void clear_exclusions ( + ); + /*! + ensures + - for all symbols #is_excluded(symbol) == false + - #something_is_excluded() == true + !*/ + + inline void scale_counts ( + node* n + ); + /*! + ensures + - divides all the counts in the child context set of n by 2. + - none of the nodes in the child context set will have a count of 0 + !*/ + + inline void push ( + node* n, + node* nc + ); + /*! + requires + - stack_size < order + ensures + - #pop(a,b): a == n && b == nc + !*/ + + inline void pop ( + node*& n, + node*& nc + ); + /*! + requires + - stack_size > 0 + ensures + - returns the two nodes at the top of the stack + !*/ + + struct nodes + { + node* n; + node* nc; + }; + + unsigned long next_node; + entropy_encoder& coder; + node* root; + node* cur; + unsigned long cur_order; + unsigned long exc[alphabet_size/32+1]; + bool exc_used; + nodes stack[order+1]; + unsigned long stack_size; + + // restricted functions + entropy_encoder_model_kernel_5(entropy_encoder_model_kernel_5&); // copy constructor + entropy_encoder_model_kernel_5& operator=(entropy_encoder_model_kernel_5&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + entropy_encoder_model_kernel_5:: + entropy_encoder_model_kernel_5 ( + entropy_encoder& coder_ + ) : + next_node(1), + coder(coder_), + cur_order(0), + stack_size(0) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 ); + COMPILE_TIME_ASSERT( 4096 < total_nodes ); + + root = new node[total_nodes]; + cur = root; + + root->child_context = 0; + root->escapes = 0; + root->next = 0; + root->parent_context = 0; + root->total = 0; + + clear_exclusions(); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + entropy_encoder_model_kernel_5:: + ~entropy_encoder_model_kernel_5 ( + ) + { + delete [] root; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_5:: + clear( + ) + { + next_node = 1; + root->child_context = 0; + root->escapes = 0; + root->total = 0; + cur = root; + cur_order = 0; + stack_size = 0; + + clear_exclusions(); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_5:: + encode ( + unsigned long sym + ) + { + unsigned short symbol = static_cast(sym); + node* temp = cur; + cur = 0; + unsigned short low_count, high_count, total_count; + node* new_node = 0; + + // local_order will track the level of temp in the tree + unsigned long local_order = cur_order; + + + unsigned short c; // c == t(a|sk) + unsigned short t; // t == T(sk) + + + if (something_is_excluded()) + clear_exclusions(); + + while (true) + { + low_count = 0; + high_count = 0; + if (space_left()) + { + total_count = temp->total; + + if (total_count > 0) + { + // check if we need to scale the counts + if (total_count > 10000) + { + scale_counts(temp); + total_count = temp->total; + } + + + // find the symbol we are looking for and put a pointer to it + // into found_symbol. If it isn't found then found_symbol == 0. + // also, low_count and high_count will be correctly set. + node* n = temp->child_context; + node* found_symbol = 0; + node* last = 0; + if (something_is_excluded()) + { + node* templast = 0; + while (true) + { + if (is_excluded(n->symbol) == false) + { + exclude(n->symbol); + if (found_symbol == 0) + { + high_count += n->count; + if (n->symbol == symbol) + { + found_symbol = n; + last = templast; + low_count = high_count - n->count; + } + } + } + else + { + total_count -= n->count; + } + + if (n->next == 0) + break; + templast = n; + n = n->next; + } + } + else + { + while (true) + { + high_count += n->count; + exclude(n->symbol); + + if (n->symbol == symbol) + { + found_symbol = n; + low_count = high_count - n->count; + break; + } + + if (n->next == 0) + break; + last = n; + n = n->next; + } + } + + + + + + // if we found the symbol + if (found_symbol) + { + n = found_symbol; + if (new_node != 0) + { + new_node->parent_context = found_symbol; + } + + + coder.encode(low_count,high_count,total_count); + c = n->count += 8; + t = temp->total += 8; + + // move this node to the front + if (last) + { + last->next = n->next; + n->next = temp->child_context; + temp->child_context = n; + } + + + if (cur == 0) + { + if (local_order >= order) + { + cur = n->parent_context; + cur_order = local_order; + } + else + { + cur_order = local_order+1; + cur = n; + } + } + + break; + + } + // if we hit the end of the context set without finding the symbol + else + { + // finish excluding all the symbols + while (n->next) + { + exclude(n->symbol); + n = n->next; + } + + if (new_node != 0) + { + new_node->parent_context = allocate_node(); + new_node = new_node->parent_context; + } + else + { + new_node = allocate_node(); + } + + n->next = new_node; + + // write an escape to a lower context + coder.encode(high_count,total_count,total_count); + } + + } + else // if (total_count == 0) + { + // this means that temp->child_context == 0 so we should make + // a new node here. + if (new_node != 0) + { + new_node->parent_context = allocate_node(); + new_node = new_node->parent_context; + } + else + { + new_node = allocate_node(); + } + + temp->child_context = new_node; + } + + if (cur == 0 && local_order < order) + { + cur = new_node; + cur_order = local_order+1; + } + + // fill out the new node + new_node->child_context = 0; + new_node->escapes = 0; + new_node->next = 0; + new_node->total = 0; + push(new_node,temp); + + if (temp != root) + { + temp = temp->parent_context; + --local_order; + continue; + } + + t = 2056; + c = 8; + + // since this is the root we are going to the order-(-1) context + // so we can just take care of that here. + new_node->parent_context = root; + coder.encode(symbol,symbol+1,alphabet_size); + + if (cur == 0) + { + cur = root; + cur_order = 0; + } + break; + } + else + { + // there isn't enough space so we should throw away the tree + clear(); + temp = cur; + local_order = cur_order; + cur = 0; + new_node = 0; + } + } // while (true) + + + // initialize the counts and symbol for any new nodes we have added + // to the tree. + node* n, *nc; + while (stack_size > 0) + { + pop(n,nc); + + n->symbol = static_cast(symbol); + + // if nc is not a determnistic context + if (nc->total) + { + unsigned long temp2 = t-c+nc->total - nc->escapes - nc->escapes; + unsigned long temp = nc->total; + temp *= c; + temp /= (temp2|1); // this oring by 1 is just to make sure that temp2 is never zero + temp += 2; + if (temp > 50000) temp = 50000; + n->count = static_cast(temp); + + + nc->escapes += 4; + nc->total += static_cast(temp) + 4; + } + else + { + n->count = 3 + 5*(c)/(t-c); + + nc->escapes = 4; + nc->total = n->count + 4; + } + + while (nc->total > 10000) + { + scale_counts(nc); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + eemk5::node* entropy_encoder_model_kernel_5:: + allocate_node ( + ) + { + node* temp; + temp = root + next_node; + ++next_node; + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + bool entropy_encoder_model_kernel_5:: + space_left ( + ) const + { + return (next_node < total_nodes); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_5:: + exclude ( + unsigned short symbol + ) + { + exc_used = true; + unsigned long temp = 1; + temp <<= symbol&0x1F; + exc[symbol>>5] |= temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + bool entropy_encoder_model_kernel_5:: + is_excluded ( + unsigned short symbol + ) + { + unsigned long temp = 1; + temp <<= symbol&0x1F; + return ((exc[symbol>>5]&temp) != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_5:: + clear_exclusions ( + ) + { + exc_used = false; + for (unsigned long i = 0; i < alphabet_size/32+1; ++i) + { + exc[i] = 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + bool entropy_encoder_model_kernel_5:: + something_is_excluded ( + ) + { + return exc_used; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_5:: + push ( + node* n, + node* nc + ) + { + stack[stack_size].n = n; + stack[stack_size].nc = nc; + ++stack_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_5:: + pop ( + node*& n, + node*& nc + ) + { + --stack_size; + n = stack[stack_size].n; + nc = stack[stack_size].nc; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder, + unsigned long total_nodes, + unsigned long order + > + void entropy_encoder_model_kernel_5:: + scale_counts ( + node* temp + ) + { + if (temp->escapes > 1) + temp->escapes >>= 1; + temp->total = temp->escapes; + + node* n = temp->child_context; + while (n != 0) + { + if (n->count > 1) + n->count >>= 1; + + temp->total += n->count; + n = n->next; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_5_ + + diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_6.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_6.h new file mode 100644 index 000000000..2199bfbe4 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_6.h @@ -0,0 +1,127 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_6_ +#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_6_ + +#include "../algs.h" +#include "entropy_encoder_model_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_encoder + > + class entropy_encoder_model_kernel_6 + { + /*! + INITIAL VALUE + Initially this object's finite context model is empty + + CONVENTION + &get_entropy_encoder() == coder + + This is an order-(-1) model. So it doesn't really do anything. + Every symbol has the same probability. + !*/ + + public: + + typedef entropy_encoder entropy_encoder_type; + + entropy_encoder_model_kernel_6 ( + entropy_encoder& coder + ); + + virtual ~entropy_encoder_model_kernel_6 ( + ); + + inline void clear( + ); + + inline void encode ( + unsigned long symbol + ); + + entropy_encoder& get_entropy_encoder ( + ) { return coder; } + + static unsigned long get_alphabet_size ( + ) { return alphabet_size; } + + private: + + entropy_encoder& coder; + + // restricted functions + entropy_encoder_model_kernel_6(entropy_encoder_model_kernel_6&); // copy constructor + entropy_encoder_model_kernel_6& operator=(entropy_encoder_model_kernel_6&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder + > + entropy_encoder_model_kernel_6:: + entropy_encoder_model_kernel_6 ( + entropy_encoder& coder_ + ) : + coder(coder_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65535 ); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder + > + entropy_encoder_model_kernel_6:: + ~entropy_encoder_model_kernel_6 ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder + > + void entropy_encoder_model_kernel_6:: + clear( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + typename entropy_encoder + > + void entropy_encoder_model_kernel_6:: + encode ( + unsigned long symbol + ) + { + // use order minus one context + coder.encode(symbol,symbol+1,alphabet_size); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_6_ + diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_abstract.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_abstract.h new file mode 100644 index 000000000..fb5f01bc7 --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_abstract.h @@ -0,0 +1,118 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_ABSTRACT_ +#ifdef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size, + typename entropy_encoder + > + class entropy_encoder_model + { + /*! + REQUIREMENTS ON alphabet_size + 1 < alphabet_size < 65535 + + REQUIREMENTS ON entropy_encoder + is an implementation of entropy_encoder/entropy_encoder_kernel_abstract.h + + INITIAL VALUE + Initially this object is at some predefined empty or ground state. + + WHAT THIS OBJECT REPRESENTS + This object represents some kind of statistical model. You + can use it to write symbols to an entropy_encoder and it will calculate + the cumulative counts/probabilities and manage contexts for you. + + Note that all implementations of entropy_encoder_model and + entropy_decoder_model are paired. This means that if you use + entropy_encoder_model_kernel_n to encode something then you must + use the corresponding entropy_decoder_model_kernel_n to decode it. + + Also note that this object does not perform any buffering of symbols. It + writes them to its associated entropy_encoder immediately. + This makes it safe to use multiple entropy_encoder_model objects with + a single entropy_encoder without them trampling each other. + !*/ + + public: + + typedef entropy_encoder entropy_encoder_type; + + entropy_encoder_model ( + entropy_encoder& coder + ); + /*! + ensures + - #*this is properly initialized + - &#get_entropy_encoder() == &coder + throws + - any exception + !*/ + + virtual ~entropy_encoder_model ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + - does not modify get_entropy_encoder() + throws + - any exception + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void encode ( + unsigned long symbol + ); + /*! + requires + - symbol < alphabet_size + ensures + - encodes and writes the symbol to get_entropy_encoder(). + This also means that there is no internal buffering. symbol is + written immediately to the entropy_encoder. + throws + - any exception + If this exception is thrown then #*this is unusable until + clear() is called and succeeds. + !*/ + + entropy_encoder& get_entropy_encoder ( + ); + /*! + ensures + - returns a reference to the entropy_encoder used by *this + !*/ + + static unsigned long get_alphabet_size ( + ); + /*! + ensures + - returns alphabet_size + !*/ + + private: + + // restricted functions + entropy_encoder_model(entropy_encoder_model&); // copy constructor + entropy_encoder_model& operator=(entropy_encoder_model&); // assignment operator + + }; + +} + +#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_c.h b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_c.h new file mode 100644 index 000000000..4637ddd1e --- /dev/null +++ b/ml/dlib/dlib/entropy_encoder_model/entropy_encoder_model_kernel_c.h @@ -0,0 +1,65 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENTROPY_ENCODER_MODEL_KERNEl_C_ +#define DLIB_ENTROPY_ENCODER_MODEL_KERNEl_C_ + +#include "entropy_encoder_model_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename eem_base + > + class entropy_encoder_model_kernel_c : public eem_base + { + const unsigned long alphabet_size; + typedef typename eem_base::entropy_encoder_type entropy_encoder; + + public: + + entropy_encoder_model_kernel_c ( + entropy_encoder& coder + ) : eem_base(coder), alphabet_size(eem_base::get_alphabet_size()) {} + + void encode ( + unsigned long symbol + ); + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename eem_base + > + void entropy_encoder_model_kernel_c:: + encode ( + unsigned long symbol + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size, + "\tvoid entropy_encoder_model::encode()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tthis: " << this + ); + + // call the real function + eem_base::encode(symbol); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENTROPY_ENCODER_MODEL_KERNEl_C_ + diff --git a/ml/dlib/dlib/error.h b/ml/dlib/dlib/error.h new file mode 100644 index 000000000..ce9b95b1a --- /dev/null +++ b/ml/dlib/dlib/error.h @@ -0,0 +1,449 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ERROr_ +#define DLIB_ERROr_ + +#include +#include // for std::bad_alloc +#include +#include +#include +#include + +// ------------------------------- +// ------ exception classes ------ +// ------------------------------- + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + enum error_type + { + EPORT_IN_USE, + ETIMEOUT, + ECONNECTION, + ELISTENER, + ERESOLVE, + EMONITOR, + ECREATE_THREAD, + ECREATE_MUTEX, + ECREATE_SIGNALER, + EUNSPECIFIED, + EGENERAL_TYPE1, + EGENERAL_TYPE2, + EGENERAL_TYPE3, + EINVALID_OPTION, + ETOO_FEW_ARGS, + ETOO_MANY_ARGS, + ESOCKET, + ETHREAD, + EGUI, + EFATAL, + EBROKEN_ASSERT, + EIMAGE_LOAD, + EDIR_CREATE, + EINCOMPATIBLE_OPTIONS, + EMISSING_REQUIRED_OPTION, + EINVALID_OPTION_ARG, + EMULTIPLE_OCCURANCES, + ECONFIG_READER, + EIMAGE_SAVE, + ECAST_TO_STRING, + ESTRING_CAST, + EUTF8_TO_UTF32, + EOPTION_PARSE + }; + +// ---------------------------------------------------------------------------------------- + + // the base exception class + class error : public std::exception + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the base exception class for the dlib library. i.e. all + exceptions in this library inherit from this class. + !*/ + + public: + error( + error_type t, + const std::string& a + ): info(a), type(t) {} + /*! + ensures + - #type == t + - #info == a + !*/ + + error( + error_type t + ): type(t) {} + /*! + ensures + - #type == t + - #info == "" + !*/ + + error( + const std::string& a + ): info(a), type(EUNSPECIFIED) {} + /*! + ensures + - #type == EUNSPECIFIED + - #info == a + !*/ + + error( + ): type(EUNSPECIFIED) {} + /*! + ensures + - #type == EUNSPECIFIED + - #info == "" + !*/ + + virtual ~error( + ) throw() {} + /*! + ensures + - does nothing + !*/ + + const char* what( + ) const throw() + /*! + ensures + - if (info.size() != 0) then + - returns info.c_str() + - else + - returns type_to_string(type) + !*/ + { + if (info.size() > 0) + return info.c_str(); + else + return type_to_string(); + } + + const char* type_to_string ( + ) const throw() + /*! + ensures + - returns a string that names the contents of the type member. + !*/ + { + if ( type == EPORT_IN_USE) return "EPORT_IN_USE"; + else if ( type == ETIMEOUT) return "ETIMEOUT"; + else if ( type == ECONNECTION) return "ECONNECTION"; + else if ( type == ELISTENER) return "ELISTENER"; + else if ( type == ERESOLVE) return "ERESOLVE"; + else if ( type == EMONITOR) return "EMONITOR"; + else if ( type == ECREATE_THREAD) return "ECREATE_THREAD"; + else if ( type == ECREATE_MUTEX) return "ECREATE_MUTEX"; + else if ( type == ECREATE_SIGNALER) return "ECREATE_SIGNALER"; + else if ( type == EUNSPECIFIED) return "EUNSPECIFIED"; + else if ( type == EGENERAL_TYPE1) return "EGENERAL_TYPE1"; + else if ( type == EGENERAL_TYPE2) return "EGENERAL_TYPE2"; + else if ( type == EGENERAL_TYPE3) return "EGENERAL_TYPE3"; + else if ( type == EINVALID_OPTION) return "EINVALID_OPTION"; + else if ( type == ETOO_FEW_ARGS) return "ETOO_FEW_ARGS"; + else if ( type == ETOO_MANY_ARGS) return "ETOO_MANY_ARGS"; + else if ( type == ESOCKET) return "ESOCKET"; + else if ( type == ETHREAD) return "ETHREAD"; + else if ( type == EGUI) return "EGUI"; + else if ( type == EFATAL) return "EFATAL"; + else if ( type == EBROKEN_ASSERT) return "EBROKEN_ASSERT"; + else if ( type == EIMAGE_LOAD) return "EIMAGE_LOAD"; + else if ( type == EDIR_CREATE) return "EDIR_CREATE"; + else if ( type == EINCOMPATIBLE_OPTIONS) return "EINCOMPATIBLE_OPTIONS"; + else if ( type == EMISSING_REQUIRED_OPTION) return "EMISSING_REQUIRED_OPTION"; + else if ( type == EINVALID_OPTION_ARG) return "EINVALID_OPTION_ARG"; + else if ( type == EMULTIPLE_OCCURANCES) return "EMULTIPLE_OCCURANCES"; + else if ( type == ECONFIG_READER) return "ECONFIG_READER"; + else if ( type == EIMAGE_SAVE) return "EIMAGE_SAVE"; + else if ( type == ECAST_TO_STRING) return "ECAST_TO_STRING"; + else if ( type == ESTRING_CAST) return "ESTRING_CAST"; + else if ( type == EUTF8_TO_UTF32) return "EUTF8_TO_UTF32"; + else if ( type == EOPTION_PARSE) return "EOPTION_PARSE"; + else return "undefined error type"; + } + + const std::string info; // info about the error + const error_type type; // the type of the error + + private: + const error& operator=(const error&); + }; + +// ---------------------------------------------------------------------------------------- + + class fatal_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + As the name says, this object represents some kind of fatal error. + That is, it represents an unrecoverable error and any program that + throws this exception is, by definition, buggy and needs to be fixed. + + Note that a fatal_error exception can only be thrown once. The second + time an application attempts to construct a fatal_error it will be + immediately aborted and an error message will be printed to std::cerr. + The reason for this is because the first fatal_error was apparently ignored + so the second fatal_error is going to make itself impossible to ignore + by calling abort. The lesson here is that you should not try to ignore + fatal errors. + + This is also the exception thrown by the DLIB_ASSERT and DLIB_CASSERT macros. + !*/ + + public: + fatal_error( + error_type t, + const std::string& a + ): error(t,a) {check_for_previous_fatal_errors();} + /*! + ensures + - #type == t + - #info == a + !*/ + + fatal_error( + error_type t + ): error(t) {check_for_previous_fatal_errors();} + /*! + ensures + - #type == t + - #info == "" + !*/ + + fatal_error( + const std::string& a + ): error(EFATAL,a) {check_for_previous_fatal_errors();} + /*! + ensures + - #type == EFATAL + - #info == a + !*/ + + fatal_error( + ): error(EFATAL) {check_for_previous_fatal_errors();} + /*! + ensures + - #type == EFATAL + - #info == "" + !*/ + + private: + + static inline char* message () + { + static char buf[2000]; + buf[1999] = '\0'; // just to be extra safe + return buf; + } + + static inline void dlib_fatal_error_terminate ( + ) + { + std::cerr << "\n**************************** FATAL ERROR DETECTED ****************************"; + std::cerr << message() << std::endl; + std::cerr << "******************************************************************************\n" << std::endl; + } + + void check_for_previous_fatal_errors() + { + // If dlib is being use to create plugins for some other application, like + // MATLAB, then don't do these checks since it terminates the over arching + // system. Just let the errors go to the plugin handler and it will deal with + // them. +#if defined(MATLAB_MEX_FILE) || defined(DLIB_NO_ABORT_ON_2ND_FATAL_ERROR) + return; +#else + static bool is_first_fatal_error = true; + if (is_first_fatal_error == false) + { + std::cerr << "\n\n ************************** FATAL ERROR DETECTED ************************** " << std::endl; + std::cerr << " ************************** FATAL ERROR DETECTED ************************** " << std::endl; + std::cerr << " ************************** FATAL ERROR DETECTED ************************** \n" << std::endl; + std::cerr << "Two fatal errors have been detected, the first was inappropriately ignored. \n" + << "To prevent further fatal errors from being ignored this application will be \n" + << "terminated immediately and you should go fix this buggy program.\n\n" + << "The error message from this fatal error was:\n" << this->what() << "\n\n" << std::endl; + using namespace std; + assert(false); + abort(); + } + else + { + // copy the message into the fixed message buffer so that it can be recalled by dlib_fatal_error_terminate + // if needed. + char* msg = message(); + unsigned long i; + for (i = 0; i < 2000-1 && i < this->info.size(); ++i) + msg[i] = info[i]; + msg[i] = '\0'; + + // set this termination handler so that if the user doesn't catch this dlib::fatal_error that is being + // thrown then it will eventually be printed to standard error + std::set_terminate(&dlib_fatal_error_terminate); + } + is_first_fatal_error = false; +#endif + } + }; + +// ---------------------------------------------------------------------------------------- + + class gui_error : public error + { + public: + gui_error( + error_type t, + const std::string& a + ): error(t,a) {} + /*! + ensures + - #type == t + - #info == a + !*/ + + gui_error( + error_type t + ): error(t) {} + /*! + ensures + - #type == t + - #info == "" + !*/ + + gui_error( + const std::string& a + ): error(EGUI,a) {} + /*! + ensures + - #type == EGUI + - #info == a + !*/ + + gui_error( + ): error(EGUI) {} + /*! + ensures + - #type == EGUI + - #info == "" + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class socket_error : public error + { + public: + socket_error( + error_type t, + const std::string& a + ): error(t,a) {} + /*! + ensures + - #type == t + - #info == a + !*/ + + socket_error( + error_type t + ): error(t) {} + /*! + ensures + - #type == t + - #info == "" + !*/ + + socket_error( + const std::string& a + ): error(ESOCKET,a) {} + /*! + ensures + - #type == ESOCKET + - #info == a + !*/ + + socket_error( + ): error(ESOCKET) {} + /*! + ensures + - #type == ESOCKET + - #info == "" + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class thread_error : public error + { + public: + thread_error( + error_type t, + const std::string& a + ): error(t,a) {} + /*! + ensures + - #type == t + - #info == a + !*/ + + thread_error( + error_type t + ): error(t) {} + /*! + ensures + - #type == t + - #info == "" + !*/ + + thread_error( + const std::string& a + ): error(ETHREAD,a) {} + /*! + ensures + - #type == ETHREAD + - #info == a + !*/ + + thread_error( + ): error(ETHREAD) {} + /*! + ensures + - #type == ETHREAD + - #info == "" + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class impossible_labeling_error : public dlib::error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown by code that trains object detectors (e.g. + structural_svm_object_detection_problem) when they detect that the set of + truth boxes given to the training algorithm contains some impossible to + obtain outputs. + + This kind of problem can happen when the set of image positions scanned by + the underlying object detection method doesn't include the truth rectangle + as a possible output. Another possibility is when two truth boxes are very + close together and hard coded non-max suppression logic would prevent two + boxes in such close proximity from being output. + !*/ + public: + impossible_labeling_error(const std::string& msg) : dlib::error(msg) {}; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ERROr_ + diff --git a/ml/dlib/dlib/external/cblas/CMakeLists.txt b/ml/dlib/dlib/external/cblas/CMakeLists.txt new file mode 100644 index 000000000..0d800ae13 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/CMakeLists.txt @@ -0,0 +1,182 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + + +cmake_minimum_required(VERSION 2.8.12) +project(cblas) + + +enable_language (Fortran) + +set(CMAKE_POSITION_INDEPENDENT_CODE True) +add_definitions(-DADD_ -DF77_INT=ptrdiff_t) + +add_library(cblas STATIC + cblas_caxpy.c + #cblas_ccopy.c + cblas_cdotc_sub.c + cblas_cdotu_sub.c + #cblas_cgbmv.c + cblas_cgemm.c + cblas_cgemv.c + cblas_cgerc.c + cblas_cgeru.c + #cblas_chbmv.c + #cblas_chemm.c + #cblas_chemv.c + #cblas_cher2.c + #cblas_cher2k.c + #cblas_cher.c + #cblas_cherk.c + #cblas_chpmv.c + #cblas_chpr2.c + #cblas_chpr.c + cblas_cscal.c + #cblas_csscal.c + #cblas_cswap.c + #cblas_csymm.c + #cblas_csyr2k.c + #cblas_csyrk.c + #cblas_ctbmv.c + #cblas_ctbsv.c + #cblas_ctpmv.c + #cblas_ctpsv.c + #cblas_ctrmm.c + #cblas_ctrmv.c + cblas_ctrsm.c + #cblas_ctrsv.c + #cblas_dasum.c + cblas_daxpy.c + #cblas_dcopy.c + cblas_ddot.c + #cblas_dgbmv.c + cblas_dgemm.c + cblas_dgemv.c + cblas_dger.c + #cblas_dnrm2.c + #cblas_drot.c + #cblas_drotg.c + #cblas_drotm.c + #cblas_drotmg.c + #cblas_dsbmv.c + cblas_dscal.c + #cblas_dsdot.c + #cblas_dspmv.c + #cblas_dspr2.c + #cblas_dspr.c + #cblas_dswap.c + #cblas_dsymm.c + #cblas_dsymv.c + #cblas_dsyr2.c + #cblas_dsyr2k.c + #cblas_dsyr.c + #cblas_dsyrk.c + #cblas_dtbmv.c + #cblas_dtbsv.c + #cblas_dtpmv.c + #cblas_dtpsv.c + #cblas_dtrmm.c + #cblas_dtrmv.c + cblas_dtrsm.c + #cblas_dtrsv.c + #cblas_dzasum.c + #cblas_dznrm2.c + #cblas_icamax.c + #cblas_idamax.c + #cblas_isamax.c + #cblas_izamax.c + #cblas_sasum.c + cblas_saxpy.c + #cblas_scasum.c + #cblas_scnrm2.c + #cblas_scopy.c + cblas_sdot.c + #cblas_sdsdot.c + #cblas_sgbmv.c + cblas_sgemm.c + cblas_sgemv.c + cblas_sger.c + #cblas_snrm2.c + #cblas_srot.c + #cblas_srotg.c + #cblas_srotm.c + #cblas_srotmg.c + #cblas_ssbmv.c + cblas_sscal.c + #cblas_sspmv.c + #cblas_sspr2.c + #cblas_sspr.c + #cblas_sswap.c + #cblas_ssymm.c + #cblas_ssymv.c + #cblas_ssyr2.c + #cblas_ssyr2k.c + #cblas_ssyr.c + #cblas_ssyrk.c + #cblas_stbmv.c + #cblas_stbsv.c + #cblas_stpmv.c + #cblas_stpsv.c + #cblas_strmm.c + #cblas_strmv.c + cblas_strsm.c + #cblas_strsv.c + cblas_xerbla.c + cblas_zaxpy.c + #cblas_zcopy.c + cblas_zdotc_sub.c + cblas_zdotu_sub.c + #cblas_zdscal.c + #cblas_zgbmv.c + cblas_zgemm.c + cblas_zgemv.c + cblas_zgerc.c + cblas_zgeru.c + #cblas_zhbmv.c + #cblas_zhemm.c + #cblas_zhemv.c + #cblas_zher2.c + #cblas_zher2k.c + #cblas_zher.c + #cblas_zherk.c + #cblas_zhpmv.c + #cblas_zhpr2.c + #cblas_zhpr.c + cblas_zscal.c + #cblas_zswap.c + #cblas_zsymm.c + #cblas_zsyr2k.c + #cblas_zsyrk.c + #cblas_ztbmv.c + #cblas_ztbsv.c + #cblas_ztpmv.c + #cblas_ztpsv.c + #cblas_ztrmm.c + #cblas_ztrmv.c + cblas_ztrsm.c + #cblas_ztrsv.c + + cdotcsub.f + cdotusub.f + dasumsub.f + ddotsub.f + dnrm2sub.f + dsdotsub.f + dzasumsub.f + dznrm2sub.f + icamaxsub.f + idamaxsub.f + isamaxsub.f + izamaxsub.f + sasumsub.f + scasumsub.f + scnrm2sub.f + sdotsub.f + sdsdotsub.f + snrm2sub.f + zdotcsub.f + zdotusub.f + ) + diff --git a/ml/dlib/dlib/external/cblas/README b/ml/dlib/dlib/external/cblas/README new file mode 100644 index 000000000..a89feaffe --- /dev/null +++ b/ml/dlib/dlib/external/cblas/README @@ -0,0 +1,7 @@ +This folder contains a copy of CBLAS (from http://www.netlib.org/blas/) which +has been setup so you can compile it with CMake. It also only compiles the +part of CBLAS needed by dlib. + +Most BLAS libraries come with CBLAS, however, some don't. In particular, if +you are using the BLAS that comes with MATLAB then you will need this CBLAS +code linked into your own to get dlib working with MATLAB's built in BLAS. diff --git a/ml/dlib/dlib/external/cblas/cblas.h b/ml/dlib/dlib/external/cblas/cblas.h new file mode 100644 index 000000000..f91557e74 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas.h @@ -0,0 +1,575 @@ +#ifndef CBLAS_H +#define CBLAS_H +#include + +/* + * Enumerated and derived types + */ +#define CBLAS_INDEX size_t /* this may vary between platforms */ + +enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102}; +enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; +enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; +enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; +enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * =========================================================================== + * Prototypes for level 1 BLAS functions (complex are recast as routines) + * =========================================================================== + */ +float cblas_sdsdot(const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY); +double cblas_dsdot(const int N, const float *X, const int incX, const float *Y, + const int incY); +float cblas_sdot(const int N, const float *X, const int incX, + const float *Y, const int incY); +double cblas_ddot(const int N, const double *X, const int incX, + const double *Y, const int incY); + +/* + * Functions having prefixes Z and C only + */ +void cblas_cdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); +void cblas_cdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + +void cblas_zdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); +void cblas_zdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + + +/* + * Functions having prefixes S D SC DZ + */ +float cblas_snrm2(const int N, const float *X, const int incX); +float cblas_sasum(const int N, const float *X, const int incX); + +double cblas_dnrm2(const int N, const double *X, const int incX); +double cblas_dasum(const int N, const double *X, const int incX); + +float cblas_scnrm2(const int N, const void *X, const int incX); +float cblas_scasum(const int N, const void *X, const int incX); + +double cblas_dznrm2(const int N, const void *X, const int incX); +double cblas_dzasum(const int N, const void *X, const int incX); + + +/* + * Functions having standard 4 prefixes (S D C Z) + */ +CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX); +CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX); +CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX); +CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX); + +/* + * =========================================================================== + * Prototypes for level 1 BLAS routines + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (s, d, c, z) + */ +void cblas_sswap(const int N, float *X, const int incX, + float *Y, const int incY); +void cblas_scopy(const int N, const float *X, const int incX, + float *Y, const int incY); +void cblas_saxpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY); + +void cblas_dswap(const int N, double *X, const int incX, + double *Y, const int incY); +void cblas_dcopy(const int N, const double *X, const int incX, + double *Y, const int incY); +void cblas_daxpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY); + +void cblas_cswap(const int N, void *X, const int incX, + void *Y, const int incY); +void cblas_ccopy(const int N, const void *X, const int incX, + void *Y, const int incY); +void cblas_caxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); + +void cblas_zswap(const int N, void *X, const int incX, + void *Y, const int incY); +void cblas_zcopy(const int N, const void *X, const int incX, + void *Y, const int incY); +void cblas_zaxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); + + +/* + * Routines with S and D prefix only + */ +void cblas_srotg(float *a, float *b, float *c, float *s); +void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); +void cblas_srot(const int N, float *X, const int incX, + float *Y, const int incY, const float c, const float s); +void cblas_srotm(const int N, float *X, const int incX, + float *Y, const int incY, const float *P); + +void cblas_drotg(double *a, double *b, double *c, double *s); +void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); +void cblas_drot(const int N, double *X, const int incX, + double *Y, const int incY, const double c, const double s); +void cblas_drotm(const int N, double *X, const int incX, + double *Y, const int incY, const double *P); + + +/* + * Routines with S D C Z CS and ZD prefixes + */ +void cblas_sscal(const int N, const float alpha, float *X, const int incX); +void cblas_dscal(const int N, const double alpha, double *X, const int incX); +void cblas_cscal(const int N, const void *alpha, void *X, const int incX); +void cblas_zscal(const int N, const void *alpha, void *X, const int incX); +void cblas_csscal(const int N, const float alpha, void *X, const int incX); +void cblas_zdscal(const int N, const double alpha, void *X, const int incX); + +/* + * =========================================================================== + * Prototypes for level 2 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void cblas_sgemv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY); +void cblas_sgbmv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const float alpha, + const float *A, const int lda, const float *X, + const int incX, const float beta, float *Y, const int incY); +void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, + float *X, const int incX); +void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX); +void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX); +void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, float *X, + const int incX); +void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX); +void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX); + +void cblas_dgemv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY); +void cblas_dgbmv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const double alpha, + const double *A, const int lda, const double *X, + const int incX, const double beta, double *Y, const int incY); +void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, + double *X, const int incX); +void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX); +void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX); +void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, double *X, + const int incX); +void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX); +void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX); + +void cblas_cgemv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); +void cblas_cgbmv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const void *alpha, + const void *A, const int lda, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX); +void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); +void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX); +void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); + +void cblas_zgemv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); +void cblas_zgbmv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const void *alpha, + const void *A, const int lda, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX); +void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); +void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX); +void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); + + +/* + * Routines with S and D prefixes only + */ +void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const float alpha, const float *A, + const int lda, const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *Ap, + const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N, + const float alpha, const float *X, const int incX, + const float *Y, const int incY, float *A, const int lda); +void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *A, const int lda); +void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *Ap); +void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A, + const int lda); +void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A); + +void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const double alpha, const double *A, + const int lda, const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *Ap, + const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N, + const double alpha, const double *X, const int incX, + const double *Y, const int incY, double *A, const int lda); +void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *A, const int lda); +void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *Ap); +void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A, + const int lda); +void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A); + + +/* + * Routines with C and Z prefixes only + */ +void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *Ap, + const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, const int incX, + void *A, const int lda); +void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, + const int incX, void *A); +void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *Ap); + +void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *Ap, + const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, const int incX, + void *A, const int lda); +void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, + const int incX, void *A); +void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *Ap); + +/* + * =========================================================================== + * Prototypes for level 3 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const float alpha, const float *A, + const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc); +void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc); +void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float beta, float *C, const int ldc); +void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc); +void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb); +void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb); + +void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const double alpha, const double *A, + const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); +void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc); +void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double beta, double *C, const int ldc); +void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc); +void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb); +void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb); + +void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); +void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc); +void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); +void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); + +void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); +void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc); +void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); +void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); + + +/* + * Routines with prefixes C and Z only + */ +void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const void *A, const int lda, + const float beta, void *C, const int ldc); +void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const float beta, + void *C, const int ldc); + +void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const void *A, const int lda, + const double beta, void *C, const int ldc); +void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const double beta, + void *C, const int ldc); + +void cblas_xerbla(int p, const char *rout, const char *form, ...); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/ml/dlib/dlib/external/cblas/cblas_caxpy.c b/ml/dlib/dlib/external/cblas/cblas_caxpy.c new file mode 100644 index 000000000..7579aa707 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_caxpy.c @@ -0,0 +1,22 @@ +/* + * cblas_caxpy.c + * + * The program is a C interface to caxpy. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_caxpy( const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_caxpy( &F77_N, alpha, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ccopy.c b/ml/dlib/dlib/external/cblas/cblas_ccopy.c new file mode 100644 index 000000000..b7bc42847 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ccopy.c @@ -0,0 +1,22 @@ +/* + * cblas_ccopy.c + * + * The program is a C interface to ccopy. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ccopy( const int N, const void *X, + const int incX, void *Y, const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_ccopy( &F77_N, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cdotc_sub.c b/ml/dlib/dlib/external/cblas/cblas_cdotc_sub.c new file mode 100644 index 000000000..d6086814e --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cdotc_sub.c @@ -0,0 +1,23 @@ +/* + * cblas_cdotc_sub.c + * + * The program is a C interface to cdotc. + * It calls the fortran wrapper before calling cdotc. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cdotc_sub( const int N, const void *X, const int incX, + const void *Y, const int incY,void *dotc) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_cdotc_sub( &F77_N, X, &F77_incX, Y, &F77_incY, dotc); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cdotu_sub.c b/ml/dlib/dlib/external/cblas/cblas_cdotu_sub.c new file mode 100644 index 000000000..d06e4e5fa --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cdotu_sub.c @@ -0,0 +1,23 @@ +/* + * cblas_cdotu_sub.f + * + * The program is a C interface to cdotu. + * It calls the forteran wrapper before calling cdotu. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cdotu_sub( const int N, const void *X, + const int incX, const void *Y, const int incY,void *dotu) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_cdotu_sub( &F77_N, X, &F77_incX, Y, &F77_incY, dotu); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cgbmv.c b/ml/dlib/dlib/external/cblas/cblas_cgbmv.c new file mode 100644 index 000000000..94cc175f3 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cgbmv.c @@ -0,0 +1,154 @@ +/* + * cblas_cgbmv.c + * The program is a C interface of cgbmv + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgbmv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char TA; +#ifdef F77_CHAR + F77_CHAR F77_TA; +#else + #define F77_TA &TA +#endif +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; + F77_INT F77_KL=KL,F77_KU=KU; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_KL KL + #define F77_KU KU + #define F77_incX incx + #define F77_incY incY +#endif + int n=0, i=0, incx=incX; + const float *xx= (float *)X, *alp= (float *)alpha, *bet = (float *)beta; + float ALPHA[2],BETA[2]; + int tincY, tincx; + float *x=(float *)X, *y=(float *)Y, *st=0, *tx=0; + + if (order == CblasColMajor) + { + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(2, "cblas_cgbmv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_cgbmv(F77_TA, &F77_M, &F77_N, &F77_KL, &F77_KU, alpha, + A, &F77_lda, X, &F77_incX, beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + BETA[0]= *bet; + BETA[1]= -bet[1]; + TA = 'N'; + if (M > 0) + { + n = M << 1; + x = malloc(n*sizeof(float)); + tx = x; + + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + + if( incY > 0 ) + tincY = incY; + else + tincY = -incY; + + y++; + + if (N > 0) + { + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } + } + else x = (float *) X; + + + } + else + { + cblas_xerbla(2, "cblas_cgbmv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + if (TransA == CblasConjTrans) + F77_cgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, ALPHA, + A ,&F77_lda, x,&F77_incX, BETA, Y, &F77_incY); + else + F77_cgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, alpha, + A ,&F77_lda, x,&F77_incX, beta, Y, &F77_incY); + if (TransA == CblasConjTrans) + { + if (x != X) free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + } + else cblas_xerbla(1, "cblas_cgbmv", "Illegal Order setting, %d\n", order); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cgemm.c b/ml/dlib/dlib/external/cblas/cblas_cgemm.c new file mode 100644 index 000000000..c11641023 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cgemm.c @@ -0,0 +1,94 @@ +/* + * + * cblas_cgemm.c + * This program is a C interface to cgemm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc) +{ + char TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB; +#else + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_cgemm", "Illegal TransA setting, %d\n", TransA); + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_cgemm", "Illegal TransB setting, %d\n", TransB); + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_cgemm(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, alpha, A, + &F77_lda, B, &F77_ldb, beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(2, "cblas_cgemm", "Illegal TransA setting, %d\n", TransA); + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_cgemm", "Illegal TransB setting, %d\n", TransB); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_cgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, alpha, B, + &F77_ldb, A, &F77_lda, beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_cgemm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cgemv.c b/ml/dlib/dlib/external/cblas/cblas_cgemv.c new file mode 100644 index 000000000..a1cbb94ee --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cgemv.c @@ -0,0 +1,151 @@ +/* + * cblas_cgemv.c + * The program is a C interface of cgemv + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgemv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char TA; +#ifdef F77_CHAR + F77_CHAR F77_TA; +#else + #define F77_TA &TA +#endif +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_incX incx + #define F77_incY incY +#endif + + int n=0, i=0, incx=incX; + const float *xx= (const float *)X; + float ALPHA[2],BETA[2]; + int tincY, tincx; + float *x=(float *)X, *y=(float *)Y, *st=0, *tx=0; + const float *stx = x; + + + if (order == CblasColMajor) + { + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(2, "cblas_cgemv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_cgemv(F77_TA, &F77_M, &F77_N, alpha, A, &F77_lda, X, &F77_incX, + beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + ALPHA[0]= *( (const float *) alpha ); + ALPHA[1]= -( *( (const float *) alpha+1) ); + BETA[0]= *( (const float *) beta ); + BETA[1]= -( *( (const float *) beta+1 ) ); + TA = 'N'; + if (M > 0) + { + n = M << 1; + x = malloc(n*sizeof(float)); + tx = x; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + F77_incX = 1; + + if(incY > 0) + tincY = incY; + else + tincY = -incY; + + y++; + + if (N > 0) + { + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } + stx = x; + } + else stx = (const float *)X; + } + else + { + cblas_xerbla(2, "cblas_cgemv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + if (TransA == CblasConjTrans) + F77_cgemv(F77_TA, &F77_N, &F77_M, ALPHA, A, &F77_lda, stx, + &F77_incX, BETA, Y, &F77_incY); + else + F77_cgemv(F77_TA, &F77_N, &F77_M, alpha, A, &F77_lda, x, + &F77_incX, beta, Y, &F77_incY); + + if (TransA == CblasConjTrans) + { + if (x != (const float *)X) free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + } + else cblas_xerbla(1, "cblas_cgemv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cgerc.c b/ml/dlib/dlib/external/cblas/cblas_cgerc.c new file mode 100644 index 000000000..e843f099b --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cgerc.c @@ -0,0 +1,77 @@ +/* + * cblas_cgerc.c + * The program is a C interface to cgerc. + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda) +{ +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_incX incX + #define F77_incY incy + #define F77_lda lda +#endif + + int n, i, tincy, incy=incY; + float *y=(float *)Y, *yy=(float *)Y, *ty, *st; + + + if (order == CblasColMajor) + { + F77_cgerc( &F77_M, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + } else if (order == CblasRowMajor) + { + if (N > 0) + { + n = N << 1; + y = malloc(n*sizeof(float)); + + ty = y; + if( incY > 0 ) { + i = incY << 1; + tincy = 2; + st= y+n; + } else { + i = incY *(-2); + tincy = -2; + st = y-2; + y +=(n-2); + } + do + { + *y = *yy; + y[1] = -yy[1]; + y += tincy ; + yy += i; + } + while (y != st); + y = ty; + + #ifdef F77_INT + F77_incY = 1; + #else + incy = 1; + #endif + } + else y = (float *) Y; + + F77_cgeru( &F77_N, &F77_M, alpha, y, &F77_incY, X, &F77_incX, A, + &F77_lda); + if(Y!=y) + free(y); + + } else cblas_xerbla(1, "cblas_cgerc", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cgeru.c b/ml/dlib/dlib/external/cblas/cblas_cgeru.c new file mode 100644 index 000000000..4471d1f80 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cgeru.c @@ -0,0 +1,38 @@ +/* + * cblas_cgeru.c + * The program is a C interface to cgeru. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda) +{ +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_incX incX + #define F77_incY incY + #define F77_lda lda +#endif + + + + if (order == CblasColMajor) + { + F77_cgeru( &F77_M, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + } + else if (order == CblasRowMajor) + { + F77_cgeru( &F77_N, &F77_M, alpha, Y, &F77_incY, X, &F77_incX, A, + &F77_lda); + } + else cblas_xerbla(1, "cblas_cgeru","Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_chbmv.c b/ml/dlib/dlib/external/cblas/cblas_chbmv.c new file mode 100644 index 000000000..8681fa345 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_chbmv.c @@ -0,0 +1,145 @@ +/* + * cblas_chbmv.c + * The program is a C interface to chbmv + * + * Keita Teranishi 5/18/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +#include +#include +void cblas_chbmv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo,const int N,const int K, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incx + #define F77_incY incY +#endif + int n, i=0, incx=incX; + const float *xx= (float *)X, *alp= (float *)alpha, *bet = (float *)beta; + float ALPHA[2],BETA[2]; + int tincY, tincx; + float *x=(float *)X, *y=(float *)Y, *st=0, *tx; + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_chbmv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_chbmv(F77_UL, &F77_N, &F77_K, alpha, A, &F77_lda, X, + &F77_incX, beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + BETA[0]= *bet; + BETA[1]= -bet[1]; + + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(float)); + + tx = x; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + + if(incY > 0) + tincY = incY; + else + tincY = -incY; + y++; + + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } else + x = (float *) X; + + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_chbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_chbmv(F77_UL, &F77_N, &F77_K, ALPHA, + A ,&F77_lda, x,&F77_incX, BETA, Y, &F77_incY); + } + else + { + cblas_xerbla(1, "cblas_chbmv","Illegal Order setting, %d\n", order); + return; + } + if ( order == CblasRowMajor ) + { + if(X!=x) + free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_chemm.c b/ml/dlib/dlib/external/cblas/cblas_chemm.c new file mode 100644 index 000000000..e64f0d0f3 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_chemm.c @@ -0,0 +1,91 @@ +/* + * + * cblas_chemm.c + * This program is a C interface to chemm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc) +{ + char SD, UL; +#ifdef F77_CHAR + F77_CHAR F77_SD, F77_UL; +#else + #define F77_SD &SD + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_chemm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_chemm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_chemm(F77_SD, F77_UL, &F77_M, &F77_N, alpha, A, &F77_lda, + B, &F77_ldb, beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_chemm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_chemm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_chemm(F77_SD, F77_UL, &F77_N, &F77_M, alpha, A, + &F77_lda, B, &F77_ldb, beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_chemm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_chemv.c b/ml/dlib/dlib/external/cblas/cblas_chemv.c new file mode 100644 index 000000000..bbaaefdc2 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_chemv.c @@ -0,0 +1,146 @@ +/* + * cblas_chemv.c + * The program is a C interface to chemv + * + * Keita Teranishi 5/18/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_chemv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incx + #define F77_incY incY +#endif + int n=0, i=0, incx=incX; + const float *xx= (float *)X, *alp= (float *)alpha, *bet = (float *)beta; + float ALPHA[2],BETA[2]; + int tincY, tincx; + float *x=(float *)X, *y=(float *)Y, *st=0, *tx; + + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_chemv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_chemv(F77_UL, &F77_N, alpha, A, &F77_lda, X, &F77_incX, + beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + BETA[0]= *bet; + BETA[1]= -bet[1]; + + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(float)); + + tx = x; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + + if(incY > 0) + tincY = incY; + else + tincY = -incY; + y++; + + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } else + x = (float *) X; + + + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_chemv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_chemv(F77_UL, &F77_N, ALPHA, A, &F77_lda, x, &F77_incX, + BETA, Y, &F77_incY); + } + else + { + cblas_xerbla(1, "cblas_chemv","Illegal Order setting, %d\n", order); + return; + } + if ( order == CblasRowMajor ) + { + if ( X != x ) + free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cher.c b/ml/dlib/dlib/external/cblas/cblas_cher.c new file mode 100644 index 000000000..580413b02 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cher.c @@ -0,0 +1,103 @@ +/* + * cblas_cher.c + * The program is a C interface to cher. + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, const int incX + ,void *A, const int lda) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incx +#endif + int n, i, tincx, incx=incX; + float *x=(float *)X, *xx=(float *)X, *tx, *st; + + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_cher","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_cher(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_cher","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(float)); + tx = x; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + } + else x = (float *) X; + F77_cher(F77_UL, &F77_N, &alpha, x, &F77_incX, A, &F77_lda); + } else + { + cblas_xerbla(1, "cblas_cher","Illegal Order setting, %d\n", order); + return; + } + if(X!=x) + free(x); + + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cher2.c b/ml/dlib/dlib/external/cblas/cblas_cher2.c new file mode 100644 index 000000000..89d36a0ee --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cher2.c @@ -0,0 +1,139 @@ +/* + * cblas_cher2.c + * The program is a C interface to cher2. + * + * Keita Teranishi 3/23/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incx + #define F77_incY incy +#endif + int n, i, j, tincx, tincy, incx=incX, incy=incY; + float *x=(float *)X, *xx=(float *)X, *y=(float *)Y, + *yy=(float *)Y, *tx, *ty, *stx, *sty; + + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_cher2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_cher2(F77_UL, &F77_N, alpha, X, &F77_incX, + Y, &F77_incY, A, &F77_lda); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_cher2","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(float)); + y = malloc(n*sizeof(float)); + tx = x; + ty = y; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + stx= x+n; + } else { + i = incX *(-2); + tincx = -2; + stx = x-2; + x +=(n-2); + } + + if( incY > 0 ) { + j = incY << 1; + tincy = 2; + sty= y+n; + } else { + j = incY *(-2); + tincy = -2; + sty = y-2; + y +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != stx); + + do + { + *y = *yy; + y[1] = -yy[1]; + y += tincy ; + yy += j; + } + while (y != sty); + + x=tx; + y=ty; + + #ifdef F77_INT + F77_incX = 1; + F77_incY = 1; + #else + incx = 1; + incy = 1; + #endif + } else + { + x = (float *) X; + y = (float *) Y; + } + F77_cher2(F77_UL, &F77_N, alpha, y, &F77_incY, x, + &F77_incX, A, &F77_lda); + } else + { + cblas_xerbla(1, "cblas_cher2","Illegal Order setting, %d\n", order); + return; + } + if(X!=x) + free(x); + if(Y!=y) + free(y); + + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cher2k.c b/ml/dlib/dlib/external/cblas/cblas_cher2k.c new file mode 100644 index 000000000..cad1c432e --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cher2k.c @@ -0,0 +1,96 @@ +/* + * + * cblas_cher2k.c + * This program is a C interface to cher2k. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const float beta, + void *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + float ALPHA[2]; + const float *alp=(float *)alpha; + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_cher2k", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_cher2k", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_cher2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(2, "cblas_cher2k", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='C'; + else + { + cblas_xerbla(3, "cblas_cher2k", "Illegal Trans setting, %d\n", Trans); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + F77_cher2k(F77_UL,F77_TR, &F77_N, &F77_K, ALPHA, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_cher2k", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cherk.c b/ml/dlib/dlib/external/cblas/cblas_cherk.c new file mode 100644 index 000000000..0b6362db7 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cherk.c @@ -0,0 +1,90 @@ +/* + * + * cblas_cherk.c + * This program is a C interface to cherk. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const void *A, const int lda, + const float beta, void *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_cherk", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_cherk", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_cherk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, + &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_cherk", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='C'; + else + { + cblas_xerbla(3, "cblas_cherk", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_cherk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, + &beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_cherk", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_chpmv.c b/ml/dlib/dlib/external/cblas/cblas_chpmv.c new file mode 100644 index 000000000..048734760 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_chpmv.c @@ -0,0 +1,146 @@ +/* + * cblas_chpmv.c + * The program is a C interface of chpmv + * + * Keita Teranishi 5/18/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_chpmv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo,const int N, + const void *alpha, const void *AP, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incx + #define F77_incY incY +#endif + int n, i=0, incx=incX; + const float *xx= (float *)X, *alp= (float *)alpha, *bet = (float *)beta; + float ALPHA[2],BETA[2]; + int tincY, tincx; + float *x=(float *)X, *y=(float *)Y, *st=0, *tx; + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_chpmv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_chpmv(F77_UL, &F77_N, alpha, AP, X, + &F77_incX, beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + BETA[0]= *bet; + BETA[1]= -bet[1]; + + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(float)); + + tx = x; + if( incX > 0 ) { + i = incX << 1; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + + if(incY > 0) + tincY = incY; + else + tincY = -incY; + y++; + + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } else + x = (float *) X; + + + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_chpmv","Illegal Uplo setting, %d\n", Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_chpmv(F77_UL, &F77_N, ALPHA, + AP, x, &F77_incX, BETA, Y, &F77_incY); + } + else + { + cblas_xerbla(1, "cblas_chpmv","Illegal Order setting, %d\n", order); + return; + } + if ( order == CblasRowMajor ) + { + if(X!=x) + free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_chpr.c b/ml/dlib/dlib/external/cblas/cblas_chpr.c new file mode 100644 index 000000000..72796b158 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_chpr.c @@ -0,0 +1,102 @@ +/* + * cblas_chpr.c + * The program is a C interface to chpr. + * + * Keita Teranishi 3/23/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, + const int incX, void *A) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incx +#endif + int n, i, tincx, incx=incX; + float *x=(float *)X, *xx=(float *)X, *tx, *st; + + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_chpr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_chpr(F77_UL, &F77_N, &alpha, X, &F77_incX, A); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_chpr","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(float)); + tx = x; + if( incX > 0 ) { + i = incX << 1; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + } + else x = (float *) X; + + F77_chpr(F77_UL, &F77_N, &alpha, x, &F77_incX, A); + + } else + { + cblas_xerbla(1, "cblas_chpr","Illegal Order setting, %d\n", order); + return; + } + if(X!=x) + free(x); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_chpr2.c b/ml/dlib/dlib/external/cblas/cblas_chpr2.c new file mode 100644 index 000000000..f80d087aa --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_chpr2.c @@ -0,0 +1,136 @@ +/* + * cblas_chpr2.c + * The program is a C interface to chpr2. + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N,const void *alpha, const void *X, + const int incX,const void *Y, const int incY, void *Ap) + +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incx + #define F77_incY incy +#endif + int n, i, j, tincx, tincy, incx=incX, incy=incY; + float *x=(float *)X, *xx=(float *)X, *y=(float *)Y, + *yy=(float *)Y, *tx, *ty, *stx, *sty; + + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_chpr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_chpr2(F77_UL, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, Ap); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_chpr2","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(float)); + y = malloc(n*sizeof(float)); + tx = x; + ty = y; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + stx= x+n; + } else { + i = incX *(-2); + tincx = -2; + stx = x-2; + x +=(n-2); + } + + if( incY > 0 ) { + j = incY << 1; + tincy = 2; + sty= y+n; + } else { + j = incY *(-2); + tincy = -2; + sty = y-2; + y +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != stx); + do + { + *y = *yy; + y[1] = -yy[1]; + y += tincy ; + yy += j; + } + while (y != sty); + + x=tx; + y=ty; + + #ifdef F77_INT + F77_incX = 1; + F77_incY = 1; + #else + incx = 1; + incy = 1; + #endif + + } else + { + x = (float *) X; + y = (void *) Y; + } + F77_chpr2(F77_UL, &F77_N, alpha, y, &F77_incY, x, &F77_incX, Ap); + } else + { + cblas_xerbla(1, "cblas_chpr2","Illegal Order setting, %d\n", order); + return; + } + if(X!=x) + free(x); + if(Y!=y) + free(y); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cscal.c b/ml/dlib/dlib/external/cblas/cblas_cscal.c new file mode 100644 index 000000000..a23e6ee57 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cscal.c @@ -0,0 +1,21 @@ +/* + * cblas_cscal.c + * + * The program is a C interface to cscal.f. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cscal( const int N, const void *alpha, void *X, + const int incX) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_cscal( &F77_N, alpha, X, &F77_incX); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_csscal.c b/ml/dlib/dlib/external/cblas/cblas_csscal.c new file mode 100644 index 000000000..39983fe07 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_csscal.c @@ -0,0 +1,21 @@ +/* + * cblas_csscal.c + * + * The program is a C interface to csscal. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_csscal( const int N, const float alpha, void *X, + const int incX) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_csscal( &F77_N, &alpha, X, &F77_incX); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_cswap.c b/ml/dlib/dlib/external/cblas/cblas_cswap.c new file mode 100644 index 000000000..127282072 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_cswap.c @@ -0,0 +1,22 @@ +/* + * cblas_cswap.c + * + * The program is a C interface to cswap. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cswap( const int N, void *X, const int incX, void *Y, + const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_cswap( &F77_N, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_csymm.c b/ml/dlib/dlib/external/cblas/cblas_csymm.c new file mode 100644 index 000000000..a462b5ebd --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_csymm.c @@ -0,0 +1,91 @@ +/* + * + * cblas_csymm.c + * This program is a C interface to csymm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc) +{ + char SD, UL; +#ifdef F77_CHAR + F77_CHAR F77_SD, F77_UL; +#else + #define F77_SD &SD + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_csymm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_csymm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_csymm(F77_SD, F77_UL, &F77_M, &F77_N, alpha, A, &F77_lda, + B, &F77_ldb, beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_csymm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_csymm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_csymm(F77_SD, F77_UL, &F77_N, &F77_M, alpha, A, &F77_lda, + B, &F77_ldb, beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_csymm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_csyr2k.c b/ml/dlib/dlib/external/cblas/cblas_csyr2k.c new file mode 100644 index 000000000..c93facb2b --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_csyr2k.c @@ -0,0 +1,93 @@ +/* + * + * cblas_csyr2k.c + * This program is a C interface to csyr2k. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_csyr2k", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_csyr2k", "Illegal Trans setting, %d\n", Trans); + return; + } + + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_csyr2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, + B, &F77_ldb, beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_csyr2k", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='T'; + else + { + cblas_xerbla(3, "cblas_csyr2k", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_csyr2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, B, &F77_ldb, beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_csyr2k", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_csyrk.c b/ml/dlib/dlib/external/cblas/cblas_csyrk.c new file mode 100644 index 000000000..4ff0bd535 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_csyrk.c @@ -0,0 +1,93 @@ +/* + * + * cblas_csyrk.c + * This program is a C interface to csyrk. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_csyrk", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_csyrk", "Illegal Trans setting, %d\n", Trans); + return; + } + + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_csyrk(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, + beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_csyrk", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='T'; + else + { + cblas_xerbla(3, "cblas_csyrk", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_csyrk(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, + beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_csyrk", "Illegal Order setting, %d\n", Order); + return; +} + diff --git a/ml/dlib/dlib/external/cblas/cblas_ctbmv.c b/ml/dlib/dlib/external/cblas/cblas_ctbmv.c new file mode 100644 index 000000000..0b313d858 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ctbmv.c @@ -0,0 +1,139 @@ +/* + * cblas_ctbmv.c + * The program is a C interface to ctbmv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX +#endif + int n, i=0, tincX; + float *st=0, *x=(float *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ctbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ctbmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctbmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ctbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ctbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if(incX > 0) + tincX = incX; + else + tincX = -incX; + i = tincX << 1; + n = i * N; + x++; + st = x + n; + do + { + *x = -(*x); + x+= i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ctbmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ctbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ctbmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ctbsv.c b/ml/dlib/dlib/external/cblas/cblas_ctbsv.c new file mode 100644 index 000000000..31f3f5bb0 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ctbsv.c @@ -0,0 +1,143 @@ +/* + * cblas_ctbsv.c + * The program is a C interface to ctbsv. + * + * Keita Teranishi 3/23/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX +#endif + int n, i=0, tincX; + float *st=0,*x=(float *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ctbsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ctbsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctbsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ctbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ctbsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if ( incX > 0 ) + tincX = incX; + else + tincX = -incX; + + n = N*2*(tincX); + + x++; + + st=x+n; + + i = tincX << 1; + do + { + *x = -(*x); + x+=i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ctbsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctbsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ctbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x+= i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ctbsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ctpmv.c b/ml/dlib/dlib/external/cblas/cblas_ctpmv.c new file mode 100644 index 000000000..03dad131e --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ctpmv.c @@ -0,0 +1,133 @@ +/* + * cblas_ctpmv.c + * The program is a C interface to ctpmv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + int n, i=0, tincX; + float *st=0,*x=(float *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ctpmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ctpmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctpmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ctpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ctpmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if(incX > 0) + tincX = incX; + else + tincX = -incX; + i = tincX << 1; + n = i * N; + x++; + st = x + n; + do + { + *x = -(*x); + x += i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ctpmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctpmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ctpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX); + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ctpmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ctpsv.c b/ml/dlib/dlib/external/cblas/cblas_ctpsv.c new file mode 100644 index 000000000..3023306bd --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ctpsv.c @@ -0,0 +1,138 @@ +/* + * cblas_ctpsv.c + * The program is a C interface to ctpsv. + * + * Keita Teranishi 3/23/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + int n, i=0, tincX; + float *st=0, *x=(float*)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ctpsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ctpsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctpsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ctpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ctpsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if ( incX > 0 ) + tincX = incX; + else + tincX = -incX; + + n = N*2*(tincX); + + x++; + + st=x+n; + + i = tincX << 1; + do + { + *x = -(*x); + x+=i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ctpsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctpsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ctpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX); + + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ctpsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ctrmm.c b/ml/dlib/dlib/external/cblas/cblas_ctrmm.c new file mode 100644 index 000000000..ceeb2a5ff --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ctrmm.c @@ -0,0 +1,123 @@ +/* + * + * cblas_ctrmm.c + * This program is a C interface to ctrmm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb) +{ + char UL, TA, SD, DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_SD &SD + #define F77_DI &DI +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight ) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_ctrmm", "Illegal Side setting, %d\n", Side); + return; + } + if( Uplo == CblasUpper ) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_ctrmm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans ) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_ctrmm", "Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else cblas_xerbla(5, "cblas_ctrmm", + "Illegal Diag setting, %d\n", Diag); + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ctrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, alpha, A, &F77_lda, B, &F77_ldb); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight ) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_ctrmm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper ) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_ctrmm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans ) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_ctrmm", "Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_ctrmm", "Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ctrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, alpha, A, &F77_lda, B, &F77_ldb); + } + else cblas_xerbla(1, "cblas_ctrmm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ctrmv.c b/ml/dlib/dlib/external/cblas/cblas_ctrmv.c new file mode 100644 index 000000000..dc590fe76 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ctrmv.c @@ -0,0 +1,136 @@ +/* + * cblas_ctrmv.c + * The program is a C interface to ctrmv. + * + * Keita Teranishi 3/23/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX) + +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX +#endif + int n, i=0, tincX; + float *st=0,*x=(float *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ctrmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ctrmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctrmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ctrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ctrmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if(incX > 0) + tincX = incX; + else + tincX = -incX; + i = tincX << 1; + n = i * N; + st = x + n; + do + { + x[1] = -x[1]; + x+= i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ctrmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctrmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ctrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + x[1] = -x[1]; + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ctrmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ctrsm.c b/ml/dlib/dlib/external/cblas/cblas_ctrsm.c new file mode 100644 index 000000000..ae2d8796e --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ctrsm.c @@ -0,0 +1,132 @@ +/* + * + * cblas_ctrsm.c + * This program is a C interface to ctrsm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb) +{ + char UL, TA, SD, DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_SD &SD + #define F77_DI &DI +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb +#endif + + + if( Order == CblasColMajor ) + { + + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_ctrsm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_ctrsm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_ctrsm", "Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_ctrsm", "Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ctrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, alpha, A, + &F77_lda, B, &F77_ldb); + } else if (Order == CblasRowMajor) + { + + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_ctrsm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_ctrsm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_ctrsm", "Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_ctrsm", "Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + + F77_ctrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, alpha, A, + &F77_lda, B, &F77_ldb); + } + else cblas_xerbla(1, "cblas_ctrsm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ctrsv.c b/ml/dlib/dlib/external/cblas/cblas_ctrsv.c new file mode 100644 index 000000000..8bfacd913 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ctrsv.c @@ -0,0 +1,137 @@ +/* + * cblas_ctrsv.c + * The program is a C interface to ctrsv. + * + * Keita Teranishi 3/23/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX +#endif + int n, i=0, tincX; + float *st=0,*x=(float *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ctrsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ctrsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctrsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ctrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ctrsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if ( incX > 0 ) + tincX = incX; + else + tincX = -incX; + + n = N*2*(tincX); + x++; + st=x+n; + i = tincX << 1; + do + { + *x = -(*x); + x+=i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ctrsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ctrsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ctrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ctrsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dasum.c b/ml/dlib/dlib/external/cblas/cblas_dasum.c new file mode 100644 index 000000000..1a3667f2d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dasum.c @@ -0,0 +1,23 @@ +/* + * cblas_dasum.c + * + * The program is a C interface to dasum. + * It calls the fortran wrapper before calling dasum. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +double cblas_dasum( const int N, const double *X, const int incX) +{ + double asum; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_dasum_sub( &F77_N, X, &F77_incX, &asum); + return asum; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_daxpy.c b/ml/dlib/dlib/external/cblas/cblas_daxpy.c new file mode 100644 index 000000000..3678137fb --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_daxpy.c @@ -0,0 +1,22 @@ +/* + * cblas_daxpy.c + * + * The program is a C interface to daxpy. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_daxpy( const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_daxpy( &F77_N, &alpha, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dcopy.c b/ml/dlib/dlib/external/cblas/cblas_dcopy.c new file mode 100644 index 000000000..422a55e51 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dcopy.c @@ -0,0 +1,22 @@ +/* + * cblas_dcopy.c + * + * The program is a C interface to dcopy. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dcopy( const int N, const double *X, + const int incX, double *Y, const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_dcopy( &F77_N, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ddot.c b/ml/dlib/dlib/external/cblas/cblas_ddot.c new file mode 100644 index 000000000..d77343403 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ddot.c @@ -0,0 +1,25 @@ +/* + * cblas_ddot.c + * + * The program is a C interface to ddot. + * It calls the fortran wrapper before calling ddot. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +double cblas_ddot( const int N, const double *X, + const int incX, const double *Y, const int incY) +{ + double dot; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_ddot_sub( &F77_N, X, &F77_incX, Y, &F77_incY, &dot); + return dot; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dgbmv.c b/ml/dlib/dlib/external/cblas/cblas_dgbmv.c new file mode 100644 index 000000000..886dab740 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dgbmv.c @@ -0,0 +1,70 @@ +/* + * + * cblas_dgbmv.c + * This program is a C interface to dgbmv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dgbmv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY) +{ + char TA; +#ifdef F77_CHAR + F77_CHAR F77_TA; +#else + #define F77_TA &TA +#endif +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; + F77_INT F77_KL=KL,F77_KU=KU; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_KL KL + #define F77_KU KU + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(2, "cblas_dgbmv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_dgbmv(F77_TA, &F77_M, &F77_N, &F77_KL, &F77_KU, &alpha, + A, &F77_lda, X, &F77_incX, &beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(2, "cblas_dgbmv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_dgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, &alpha, + A ,&F77_lda, X,&F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_dgbmv", "Illegal Order setting, %d\n", order); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dgemm.c b/ml/dlib/dlib/external/cblas/cblas_dgemm.c new file mode 100644 index 000000000..4fa9d8603 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dgemm.c @@ -0,0 +1,94 @@ +/* + * + * cblas_dgemm.c + * This program is a C interface to dgemm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const double alpha, const double *A, + const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc) +{ + char TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB; +#else + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_dgemm","Illegal TransA setting, %d\n", TransA); + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_dgemm","Illegal TransB setting, %d\n", TransB); + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_dgemm(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, &alpha, A, + &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(2, "cblas_dgemm","Illegal TransA setting, %d\n", TransA); + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_dgemm","Illegal TransB setting, %d\n", TransB); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_dgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, &alpha, B, + &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_dgemm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dgemv.c b/ml/dlib/dlib/external/cblas/cblas_dgemv.c new file mode 100644 index 000000000..23a0f51e7 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dgemv.c @@ -0,0 +1,67 @@ +/* + * + * cblas_dgemv.c + * This program is a C interface to dgemv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dgemv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY) +{ + char TA; +#ifdef F77_CHAR + F77_CHAR F77_TA; +#else + #define F77_TA &TA +#endif +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(2, "cblas_dgemv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_dgemv(F77_TA, &F77_M, &F77_N, &alpha, A, &F77_lda, X, &F77_incX, + &beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(2, "cblas_dgemv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_dgemv(F77_TA, &F77_N, &F77_M, &alpha, A, &F77_lda, X, + &F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_dgemv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dger.c b/ml/dlib/dlib/external/cblas/cblas_dger.c new file mode 100644 index 000000000..d021cc401 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dger.c @@ -0,0 +1,40 @@ +/* + * + * cblas_dger.c + * This program is a C interface to dger. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N, + const double alpha, const double *X, const int incX, + const double *Y, const int incY, double *A, const int lda) +{ +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_incX incX + #define F77_incY incY + #define F77_lda lda +#endif + + + if (order == CblasColMajor) + { + F77_dger( &F77_M, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + } + else if (order == CblasRowMajor) + { + F77_dger( &F77_N, &F77_M ,&alpha, Y, &F77_incY, X, &F77_incX, A, + &F77_lda); + + } + else cblas_xerbla(1, "cblas_dger", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dnrm2.c b/ml/dlib/dlib/external/cblas/cblas_dnrm2.c new file mode 100644 index 000000000..fe46ad484 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dnrm2.c @@ -0,0 +1,23 @@ +/* + * cblas_dnrm2.c + * + * The program is a C interface to dnrm2. + * It calls the fortranwrapper before calling dnrm2. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +double cblas_dnrm2( const int N, const double *X, const int incX) +{ + double nrm2; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_dnrm2_sub( &F77_N, X, &F77_incX, &nrm2); + return nrm2; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_drot.c b/ml/dlib/dlib/external/cblas/cblas_drot.c new file mode 100644 index 000000000..51dc4ad5e --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_drot.c @@ -0,0 +1,23 @@ +/* + * cblas_drot.c + * + * The program is a C interface to drot. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_drot(const int N, double *X, const int incX, + double *Y, const int incY, const double c, const double s) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_drot(&F77_N, X, &F77_incX, Y, &F77_incY, &c, &s); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_drotg.c b/ml/dlib/dlib/external/cblas/cblas_drotg.c new file mode 100644 index 000000000..0cbbd8bc0 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_drotg.c @@ -0,0 +1,14 @@ +/* + * cblas_drotg.c + * + * The program is a C interface to drotg. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_drotg( double *a, double *b, double *c, double *s) +{ + F77_drotg(a,b,c,s); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_drotm.c b/ml/dlib/dlib/external/cblas/cblas_drotm.c new file mode 100644 index 000000000..ebe20ad62 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_drotm.c @@ -0,0 +1,14 @@ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_drotm( const int N, double *X, const int incX, double *Y, + const int incY, const double *P) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_drotm( &F77_N, X, &F77_incX, Y, &F77_incY, P); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_drotmg.c b/ml/dlib/dlib/external/cblas/cblas_drotmg.c new file mode 100644 index 000000000..13a2208e5 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_drotmg.c @@ -0,0 +1,15 @@ +/* + * cblas_drotmg.c + * + * The program is a C interface to drotmg. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_drotmg( double *d1, double *d2, double *b1, + const double b2, double *p) +{ + F77_drotmg(d1,d2,b1,&b2,p); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dsbmv.c b/ml/dlib/dlib/external/cblas/cblas_dsbmv.c new file mode 100644 index 000000000..c2f1a71c3 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dsbmv.c @@ -0,0 +1,66 @@ +/* + * + * cblas_dsbmv.c + * This program is a C interface to dsbmv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dsbmv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo, const int N, const int K, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dsbmv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dsbmv(F77_UL, &F77_N, &F77_K, &alpha, A, &F77_lda, X, + &F77_incX, &beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dsbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dsbmv(F77_UL, &F77_N, &F77_K, &alpha, + A ,&F77_lda, X,&F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_dsbmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dscal.c b/ml/dlib/dlib/external/cblas/cblas_dscal.c new file mode 100644 index 000000000..bd04de77d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dscal.c @@ -0,0 +1,21 @@ +/* + * cblas_dscal.c + * + * The program is a C interface to dscal. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dscal( const int N, const double alpha, double *X, + const int incX) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_dscal( &F77_N, &alpha, X, &F77_incX); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dsdot.c b/ml/dlib/dlib/external/cblas/cblas_dsdot.c new file mode 100644 index 000000000..52cd877a2 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dsdot.c @@ -0,0 +1,25 @@ +/* + * cblas_dsdot.c + * + * The program is a C interface to dsdot. + * It calls fthe fortran wrapper before calling dsdot. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +double cblas_dsdot( const int N, const float *X, + const int incX, const float *Y, const int incY) +{ + double dot; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_dsdot_sub( &F77_N, X, &F77_incX, Y, &F77_incY, &dot); + return dot; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dspmv.c b/ml/dlib/dlib/external/cblas/cblas_dspmv.c new file mode 100644 index 000000000..ecdf081cb --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dspmv.c @@ -0,0 +1,65 @@ +/* + * + * cblas_dspmv.c + * This program is a C interface to dspmv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dspmv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo, const int N, + const double alpha, const double *AP, + const double *X, const int incX, const double beta, + double *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dspmv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dspmv(F77_UL, &F77_N, &alpha, AP, X, + &F77_incX, &beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dspmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dspmv(F77_UL, &F77_N, &alpha, + AP, X,&F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_dspmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dspr.c b/ml/dlib/dlib/external/cblas/cblas_dspr.c new file mode 100644 index 000000000..9e40cc11a --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dspr.c @@ -0,0 +1,59 @@ +/* + * + * cblas_dspr.c + * This program is a C interface to dspr. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *Ap) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dspr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_dspr(F77_UL, &F77_N, &alpha, X, &F77_incX, Ap); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasLower) UL = 'U'; + else if (Uplo == CblasUpper) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dspr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dspr(F77_UL, &F77_N, &alpha, X, &F77_incX, Ap); + } else cblas_xerbla(1, "cblas_dspr", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dspr2.c b/ml/dlib/dlib/external/cblas/cblas_dspr2.c new file mode 100644 index 000000000..4ebbbbd52 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dspr2.c @@ -0,0 +1,59 @@ +/* + * cblas_dspr2.c + * The program is a C interface to dspr2. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dspr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_dspr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasLower) UL = 'U'; + else if (Uplo == CblasUpper) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dspr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dspr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A); + } else cblas_xerbla(1, "cblas_dspr2", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dswap.c b/ml/dlib/dlib/external/cblas/cblas_dswap.c new file mode 100644 index 000000000..9ae5bb93c --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dswap.c @@ -0,0 +1,22 @@ +/* + * cblas_dswap.c + * + * The program is a C interface to dswap. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dswap( const int N, double *X, const int incX, double *Y, + const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_dswap( &F77_N, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dsymm.c b/ml/dlib/dlib/external/cblas/cblas_dsymm.c new file mode 100644 index 000000000..99b3858a1 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dsymm.c @@ -0,0 +1,91 @@ +/* + * + * cblas_dsymm.c + * This program is a C interface to dsymm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc) +{ + char SD, UL; +#ifdef F77_CHAR + F77_CHAR F77_SD, F77_UL; +#else + #define F77_SD &SD + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_dsymm","Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_dsymm","Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_dsymm(F77_SD, F77_UL, &F77_M, &F77_N, &alpha, A, &F77_lda, + B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_dsymm","Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_dsymm","Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_dsymm(F77_SD, F77_UL, &F77_N, &F77_M, &alpha, A, &F77_lda, B, + &F77_ldb, &beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_dsymm","Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dsymv.c b/ml/dlib/dlib/external/cblas/cblas_dsymv.c new file mode 100644 index 000000000..f0d32398a --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dsymv.c @@ -0,0 +1,65 @@ +/* + * + * cblas_dsymv.c + * This program is a C interface to dsymv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dsymv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dsymv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dsymv(F77_UL, &F77_N, &alpha, A, &F77_lda, X, + &F77_incX, &beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dsymv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dsymv(F77_UL, &F77_N, &alpha, + A ,&F77_lda, X,&F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_dsymv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dsyr.c b/ml/dlib/dlib/external/cblas/cblas_dsyr.c new file mode 100644 index 000000000..d21b846e3 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dsyr.c @@ -0,0 +1,60 @@ +/* + * + * cblas_dsyr.c + * This program is a C interface to dsyr. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *A, const int lda) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_lda=lda; +#else + #define F77_N N + #define F77_incX incX + #define F77_lda lda +#endif + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dsyr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_dsyr(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasLower) UL = 'U'; + else if (Uplo == CblasUpper) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dsyr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dsyr(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda); + } else cblas_xerbla(1, "cblas_dsyr", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dsyr2.c b/ml/dlib/dlib/external/cblas/cblas_dsyr2.c new file mode 100644 index 000000000..7ce59657d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dsyr2.c @@ -0,0 +1,65 @@ +/* + * + * cblas_dsyr2.c + * This program is a C interface to dsyr2. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A, + const int lda) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY, F77_lda=lda; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY + #define F77_lda lda +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dsyr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_dsyr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasLower) UL = 'U'; + else if (Uplo == CblasUpper) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dsyr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_dsyr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + } else cblas_xerbla(1, "cblas_dsyr2", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dsyr2k.c b/ml/dlib/dlib/external/cblas/cblas_dsyr2k.c new file mode 100644 index 000000000..dc11e9549 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dsyr2k.c @@ -0,0 +1,94 @@ +/* + * + * cblas_dsyr2k.c + * This program is a C interface to dsyr2k. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_dsyr2k","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_dsyr2k","Illegal Trans setting, %d\n", Trans); + return; + } + + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_dsyr2k(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, + B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_dsyr2k","Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='T'; + else + { + cblas_xerbla(3, "cblas_dsyr2k","Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_dsyr2k(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, B, + &F77_ldb, &beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_dsyr2k","Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dsyrk.c b/ml/dlib/dlib/external/cblas/cblas_dsyrk.c new file mode 100644 index 000000000..7ee834ea6 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dsyrk.c @@ -0,0 +1,93 @@ +/* + * + * cblas_dsyrk.c + * This program is a C interface to dsyrk. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double beta, double *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_dsyrk","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_dsyrk","Illegal Trans setting, %d\n", Trans); + return; + } + + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_dsyrk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, + &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_dsyrk","Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='T'; + else + { + cblas_xerbla(3, "cblas_dsyrk","Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_dsyrk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, + &beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_dsyrk","Illegal Order setting, %d\n", Order); + return; +} + diff --git a/ml/dlib/dlib/external/cblas/cblas_dtbmv.c b/ml/dlib/dlib/external/cblas/cblas_dtbmv.c new file mode 100644 index 000000000..1a06d1886 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dtbmv.c @@ -0,0 +1,103 @@ +/* + * cblas_dtbmv.c + * The program is a C interface to dtbmv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dtbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_dtbmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtbmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_dtbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dtbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_dtbmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_dtbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + + } + else cblas_xerbla(1, "cblas_dtbmv", "Illegal Order setting, %d\n", order); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dtbsv.c b/ml/dlib/dlib/external/cblas/cblas_dtbsv.c new file mode 100644 index 000000000..aaf4a8d4b --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dtbsv.c @@ -0,0 +1,103 @@ +/* + * cblas_dtbsv.c + * The program is a C interface to dtbsv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dtbsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_dtbsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtbsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_dtbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dtbsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_dtbsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtbsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_dtbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else cblas_xerbla(1, "cblas_dtbsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dtpmv.c b/ml/dlib/dlib/external/cblas/cblas_dtpmv.c new file mode 100644 index 000000000..565f97a68 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dtpmv.c @@ -0,0 +1,98 @@ +/* + * cblas_dtpmv.c + * The program is a C interface to dtpmv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dtpmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_dtpmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtpmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_dtpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dtpmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_dtpmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtpmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_dtpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX); + } + else cblas_xerbla(1, "cblas_dtpmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dtpsv.c b/ml/dlib/dlib/external/cblas/cblas_dtpsv.c new file mode 100644 index 000000000..4f51ccc56 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dtpsv.c @@ -0,0 +1,99 @@ +/* + * cblas_dtpsv.c + * The program is a C interface to dtpsv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dtpsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_dtpsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtpsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_dtpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dtpsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_dtpsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtpsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_dtpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX); + + } + else cblas_xerbla(1, "cblas_dtpsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dtrmm.c b/ml/dlib/dlib/external/cblas/cblas_dtrmm.c new file mode 100644 index 000000000..0f4c0d161 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dtrmm.c @@ -0,0 +1,125 @@ +/* + * + * cblas_dtrmm.c + * This program is a C interface to dtrmm. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb) +{ + char UL, TA, SD, DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_SD &SD + #define F77_DI &DI +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_dtrmm","Illegal Side setting, %d\n", Side); + return; + } + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_dtrmm","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_dtrmm","Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_dtrmm","Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_dtrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, &alpha, A, &F77_lda, B, &F77_ldb); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_dtrmm","Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_dtrmm","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_dtrmm","Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_dtrmm","Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + F77_dtrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, &alpha, A, &F77_lda, B, &F77_ldb); + } + else cblas_xerbla(1, "cblas_dtrmm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dtrmv.c b/ml/dlib/dlib/external/cblas/cblas_dtrmv.c new file mode 100644 index 000000000..c20ea0626 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dtrmv.c @@ -0,0 +1,103 @@ +/* + * + * cblas_dtrmv.c + * This program is a C interface to sgemv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, + double *X, const int incX) + +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dtrmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_dtrmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtrmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_dtrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dtrmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_dtrmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtrmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_dtrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } else cblas_xerbla(1, "cblas_dtrmv", "Illegal order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dtrsm.c b/ml/dlib/dlib/external/cblas/cblas_dtrsm.c new file mode 100644 index 000000000..986425f60 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dtrsm.c @@ -0,0 +1,130 @@ +/* + * + * cblas_dtrsm.c + * This program is a C interface to dtrsm. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb) + +{ + char UL, TA, SD, DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_SD &SD + #define F77_DI &DI +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb +#endif + + + if( Order == CblasColMajor ) + { + if ( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_dtrsm","Illegal Side setting, %d\n", Side); + return; + } + if ( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower) UL='L'; + else + { + cblas_xerbla(3, "cblas_dtrsm","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if ( TransA == CblasTrans ) TA='T'; + else if ( TransA == CblasConjTrans) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_dtrsm","Illegal Trans setting, %d\n", TransA); + return; + } + + if ( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit) DI='N'; + else + { + cblas_xerbla(5, "cblas_dtrsm","Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_dtrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, &alpha, + A, &F77_lda, B, &F77_ldb); + } + else if (Order == CblasRowMajor) + { + if ( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_dtrsm","Illegal Side setting, %d\n", Side); + return; + } + + if ( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower) UL='U'; + else + { + cblas_xerbla(3, "cblas_dtrsm","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if ( TransA == CblasTrans ) TA='T'; + else if ( TransA == CblasConjTrans) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_dtrsm","Illegal Trans setting, %d\n", TransA); + return; + } + + if ( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit) DI='N'; + else + { + cblas_xerbla(5, "cblas_dtrsm","Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_dtrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, &alpha, A, + &F77_lda, B, &F77_ldb); + } + else cblas_xerbla(1, "cblas_dtrsm","Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dtrsv.c b/ml/dlib/dlib/external/cblas/cblas_dtrsv.c new file mode 100644 index 000000000..5c4ed5637 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dtrsv.c @@ -0,0 +1,102 @@ +/* + * cblas_dtrsv.c + * The program is a C interface to dtrsv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, double *X, + const int incX) + +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dtrsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_dtrsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtrsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_dtrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dtrsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_dtrsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_dtrsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_dtrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else cblas_xerbla(1, "cblas_dtrsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dzasum.c b/ml/dlib/dlib/external/cblas/cblas_dzasum.c new file mode 100644 index 000000000..b32f573e5 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dzasum.c @@ -0,0 +1,23 @@ +/* + * cblas_dzasum.c + * + * The program is a C interface to dzasum. + * It calls the fortran wrapper before calling dzasum. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +double cblas_dzasum( const int N, const void *X, const int incX) +{ + double asum; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_dzasum_sub( &F77_N, X, &F77_incX, &asum); + return asum; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_dznrm2.c b/ml/dlib/dlib/external/cblas/cblas_dznrm2.c new file mode 100644 index 000000000..dfa2bfc83 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_dznrm2.c @@ -0,0 +1,23 @@ +/* + * cblas_dznrm2.c + * + * The program is a C interface to dznrm2. + * It calls the fortran wrapper before calling dznrm2. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +double cblas_dznrm2( const int N, const void *X, const int incX) +{ + double nrm2; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_dznrm2_sub( &F77_N, X, &F77_incX, &nrm2); + return nrm2; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_f77.h b/ml/dlib/dlib/external/cblas/cblas_f77.h new file mode 100644 index 000000000..18435cd30 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_f77.h @@ -0,0 +1,701 @@ +/* + * cblas_f77.h + * Written by Keita Teranishi + * + * Updated by Jeff Horner + * Merged cblas_f77.h and cblas_fortran_header.h + */ + +#ifndef CBLAS_F77_H +#define CBLAS_f77_H + +#ifdef CRAY + #include + #define F77_CHAR _fcd + #define C2F_CHAR(a) ( _cptofcd( (a), 1 ) ) + #define C2F_STR(a, i) ( _cptofcd( (a), (i) ) ) + #define F77_STRLEN(a) (_fcdlen) +#endif + +#ifdef WeirdNEC + #define F77_INT long +#endif + +#ifdef F77_CHAR + #define FCHAR F77_CHAR +#else + #define FCHAR char * +#endif + +#ifdef F77_INT + #define FINT const F77_INT * + #define FINT2 F77_INT * +#else + #define FINT const int * + #define FINT2 int * +#endif + +#if defined(ADD_) +/* + * Level 1 BLAS + */ +#define F77_xerbla xerbla_ + #define F77_srotg srotg_ + #define F77_srotmg srotmg_ + #define F77_srot srot_ + #define F77_srotm srotm_ + #define F77_drotg drotg_ + #define F77_drotmg drotmg_ + #define F77_drot drot_ + #define F77_drotm drotm_ + #define F77_sswap sswap_ + #define F77_scopy scopy_ + #define F77_saxpy saxpy_ + #define F77_isamax_sub isamaxsub_ + #define F77_dswap dswap_ + #define F77_dcopy dcopy_ + #define F77_daxpy daxpy_ + #define F77_idamax_sub idamaxsub_ + #define F77_cswap cswap_ + #define F77_ccopy ccopy_ + #define F77_caxpy caxpy_ + #define F77_icamax_sub icamaxsub_ + #define F77_zswap zswap_ + #define F77_zcopy zcopy_ + #define F77_zaxpy zaxpy_ + #define F77_izamax_sub izamaxsub_ + #define F77_sdot_sub sdotsub_ + #define F77_ddot_sub ddotsub_ + #define F77_dsdot_sub dsdotsub_ + #define F77_sscal sscal_ + #define F77_dscal dscal_ + #define F77_cscal cscal_ + #define F77_zscal zscal_ + #define F77_csscal csscal_ + #define F77_zdscal zdscal_ + #define F77_cdotu_sub cdotusub_ + #define F77_cdotc_sub cdotcsub_ + #define F77_zdotu_sub zdotusub_ + #define F77_zdotc_sub zdotcsub_ + #define F77_snrm2_sub snrm2sub_ + #define F77_sasum_sub sasumsub_ + #define F77_dnrm2_sub dnrm2sub_ + #define F77_dasum_sub dasumsub_ + #define F77_scnrm2_sub scnrm2sub_ + #define F77_scasum_sub scasumsub_ + #define F77_dznrm2_sub dznrm2sub_ + #define F77_dzasum_sub dzasumsub_ + #define F77_sdsdot_sub sdsdotsub_ +/* + * Level 2 BLAS + */ + #define F77_ssymv ssymv_ + #define F77_ssbmv ssbmv_ + #define F77_sspmv sspmv_ + #define F77_sger sger_ + #define F77_ssyr ssyr_ + #define F77_sspr sspr_ + #define F77_ssyr2 ssyr2_ + #define F77_sspr2 sspr2_ + #define F77_dsymv dsymv_ + #define F77_dsbmv dsbmv_ + #define F77_dspmv dspmv_ + #define F77_dger dger_ + #define F77_dsyr dsyr_ + #define F77_dspr dspr_ + #define F77_dsyr2 dsyr2_ + #define F77_dspr2 dspr2_ + #define F77_chemv chemv_ + #define F77_chbmv chbmv_ + #define F77_chpmv chpmv_ + #define F77_cgeru cgeru_ + #define F77_cgerc cgerc_ + #define F77_cher cher_ + #define F77_chpr chpr_ + #define F77_cher2 cher2_ + #define F77_chpr2 chpr2_ + #define F77_zhemv zhemv_ + #define F77_zhbmv zhbmv_ + #define F77_zhpmv zhpmv_ + #define F77_zgeru zgeru_ + #define F77_zgerc zgerc_ + #define F77_zher zher_ + #define F77_zhpr zhpr_ + #define F77_zher2 zher2_ + #define F77_zhpr2 zhpr2_ + #define F77_sgemv sgemv_ + #define F77_sgbmv sgbmv_ + #define F77_strmv strmv_ + #define F77_stbmv stbmv_ + #define F77_stpmv stpmv_ + #define F77_strsv strsv_ + #define F77_stbsv stbsv_ + #define F77_stpsv stpsv_ + #define F77_dgemv dgemv_ + #define F77_dgbmv dgbmv_ + #define F77_dtrmv dtrmv_ + #define F77_dtbmv dtbmv_ + #define F77_dtpmv dtpmv_ + #define F77_dtrsv dtrsv_ + #define F77_dtbsv dtbsv_ + #define F77_dtpsv dtpsv_ + #define F77_cgemv cgemv_ + #define F77_cgbmv cgbmv_ + #define F77_ctrmv ctrmv_ + #define F77_ctbmv ctbmv_ + #define F77_ctpmv ctpmv_ + #define F77_ctrsv ctrsv_ + #define F77_ctbsv ctbsv_ + #define F77_ctpsv ctpsv_ + #define F77_zgemv zgemv_ + #define F77_zgbmv zgbmv_ + #define F77_ztrmv ztrmv_ + #define F77_ztbmv ztbmv_ + #define F77_ztpmv ztpmv_ + #define F77_ztrsv ztrsv_ + #define F77_ztbsv ztbsv_ + #define F77_ztpsv ztpsv_ +/* + * Level 3 BLAS + */ + #define F77_chemm chemm_ + #define F77_cherk cherk_ + #define F77_cher2k cher2k_ + #define F77_zhemm zhemm_ + #define F77_zherk zherk_ + #define F77_zher2k zher2k_ + #define F77_sgemm sgemm_ + #define F77_ssymm ssymm_ + #define F77_ssyrk ssyrk_ + #define F77_ssyr2k ssyr2k_ + #define F77_strmm strmm_ + #define F77_strsm strsm_ + #define F77_dgemm dgemm_ + #define F77_dsymm dsymm_ + #define F77_dsyrk dsyrk_ + #define F77_dsyr2k dsyr2k_ + #define F77_dtrmm dtrmm_ + #define F77_dtrsm dtrsm_ + #define F77_cgemm cgemm_ + #define F77_csymm csymm_ + #define F77_csyrk csyrk_ + #define F77_csyr2k csyr2k_ + #define F77_ctrmm ctrmm_ + #define F77_ctrsm ctrsm_ + #define F77_zgemm zgemm_ + #define F77_zsymm zsymm_ + #define F77_zsyrk zsyrk_ + #define F77_zsyr2k zsyr2k_ + #define F77_ztrmm ztrmm_ + #define F77_ztrsm ztrsm_ +#elif defined(UPCASE) +/* + * Level 1 BLAS + */ +#define F77_xerbla XERBLA + #define F77_srotg SROTG + #define F77_srotmg SROTMG + #define F77_srot SROT + #define F77_srotm SROTM + #define F77_drotg DROTG + #define F77_drotmg DROTMG + #define F77_drot DROT + #define F77_drotm DROTM + #define F77_sswap SSWAP + #define F77_scopy SCOPY + #define F77_saxpy SAXPY + #define F77_isamax_sub ISAMAXSUB + #define F77_dswap DSWAP + #define F77_dcopy DCOPY + #define F77_daxpy DAXPY + #define F77_idamax_sub IDAMAXSUB + #define F77_cswap CSWAP + #define F77_ccopy CCOPY + #define F77_caxpy CAXPY + #define F77_icamax_sub ICAMAXSUB + #define F77_zswap ZSWAP + #define F77_zcopy ZCOPY + #define F77_zaxpy ZAXPY + #define F77_izamax_sub IZAMAXSUB + #define F77_sdot_sub SDOTSUB + #define F77_ddot_sub DDOTSUB + #define F77_dsdot_sub DSDOTSUB + #define F77_sscal SSCAL + #define F77_dscal DSCAL + #define F77_cscal CSCAL + #define F77_zscal ZSCAL + #define F77_csscal CSSCAL + #define F77_zdscal ZDSCAL + #define F77_cdotu_sub CDOTUSUB + #define F77_cdotc_sub CDOTCSUB + #define F77_zdotu_sub ZDOTUSUB + #define F77_zdotc_sub ZDOTCSUB + #define F77_snrm2_sub SNRM2SUB + #define F77_sasum_sub SASUMSUB + #define F77_dnrm2_sub DNRM2SUB + #define F77_dasum_sub DASUMSUB + #define F77_scnrm2_sub SCNRM2SUB + #define F77_scasum_sub SCASUMSUB + #define F77_dznrm2_sub DZNRM2SUB + #define F77_dzasum_sub DZASUMSUB + #define F77_sdsdot_sub SDSDOTSUB +/* + * Level 2 BLAS + */ + #define F77_ssymv SSYMV + #define F77_ssbmv SSBMV + #define F77_sspmv SSPMV + #define F77_sger SGER + #define F77_ssyr SSYR + #define F77_sspr SSPR + #define F77_ssyr2 SSYR2 + #define F77_sspr2 SSPR2 + #define F77_dsymv DSYMV + #define F77_dsbmv DSBMV + #define F77_dspmv DSPMV + #define F77_dger DGER + #define F77_dsyr DSYR + #define F77_dspr DSPR + #define F77_dsyr2 DSYR2 + #define F77_dspr2 DSPR2 + #define F77_chemv CHEMV + #define F77_chbmv CHBMV + #define F77_chpmv CHPMV + #define F77_cgeru CGERU + #define F77_cgerc CGERC + #define F77_cher CHER + #define F77_chpr CHPR + #define F77_cher2 CHER2 + #define F77_chpr2 CHPR2 + #define F77_zhemv ZHEMV + #define F77_zhbmv ZHBMV + #define F77_zhpmv ZHPMV + #define F77_zgeru ZGERU + #define F77_zgerc ZGERC + #define F77_zher ZHER + #define F77_zhpr ZHPR + #define F77_zher2 ZHER2 + #define F77_zhpr2 ZHPR2 + #define F77_sgemv SGEMV + #define F77_sgbmv SGBMV + #define F77_strmv STRMV + #define F77_stbmv STBMV + #define F77_stpmv STPMV + #define F77_strsv STRSV + #define F77_stbsv STBSV + #define F77_stpsv STPSV + #define F77_dgemv DGEMV + #define F77_dgbmv DGBMV + #define F77_dtrmv DTRMV + #define F77_dtbmv DTBMV + #define F77_dtpmv DTPMV + #define F77_dtrsv DTRSV + #define F77_dtbsv DTBSV + #define F77_dtpsv DTPSV + #define F77_cgemv CGEMV + #define F77_cgbmv CGBMV + #define F77_ctrmv CTRMV + #define F77_ctbmv CTBMV + #define F77_ctpmv CTPMV + #define F77_ctrsv CTRSV + #define F77_ctbsv CTBSV + #define F77_ctpsv CTPSV + #define F77_zgemv ZGEMV + #define F77_zgbmv ZGBMV + #define F77_ztrmv ZTRMV + #define F77_ztbmv ZTBMV + #define F77_ztpmv ZTPMV + #define F77_ztrsv ZTRSV + #define F77_ztbsv ZTBSV + #define F77_ztpsv ZTPSV +/* + * Level 3 BLAS + */ + #define F77_chemm CHEMM + #define F77_cherk CHERK + #define F77_cher2k CHER2K + #define F77_zhemm ZHEMM + #define F77_zherk ZHERK + #define F77_zher2k ZHER2K + #define F77_sgemm SGEMM + #define F77_ssymm SSYMM + #define F77_ssyrk SSYRK + #define F77_ssyr2k SSYR2K + #define F77_strmm STRMM + #define F77_strsm STRSM + #define F77_dgemm DGEMM + #define F77_dsymm DSYMM + #define F77_dsyrk DSYRK + #define F77_dsyr2k DSYR2K + #define F77_dtrmm DTRMM + #define F77_dtrsm DTRSM + #define F77_cgemm CGEMM + #define F77_csymm CSYMM + #define F77_csyrk CSYRK + #define F77_csyr2k CSYR2K + #define F77_ctrmm CTRMM + #define F77_ctrsm CTRSM + #define F77_zgemm ZGEMM + #define F77_zsymm ZSYMM + #define F77_zsyrk ZSYRK + #define F77_zsyr2k ZSYR2K + #define F77_ztrmm ZTRMM + #define F77_ztrsm ZTRSM +#elif defined(NOCHANGE) +/* + * Level 1 BLAS + */ +#define F77_xerbla xerbla + #define F77_srotg srotg + #define F77_srotmg srotmg + #define F77_srot srot + #define F77_srotm srotm + #define F77_drotg drotg + #define F77_drotmg drotmg + #define F77_drot drot + #define F77_drotm drotm + #define F77_sswap sswap + #define F77_scopy scopy + #define F77_saxpy saxpy + #define F77_isamax_sub isamaxsub + #define F77_dswap dswap + #define F77_dcopy dcopy + #define F77_daxpy daxpy + #define F77_idamax_sub idamaxsub + #define F77_cswap cswap + #define F77_ccopy ccopy + #define F77_caxpy caxpy + #define F77_icamax_sub icamaxsub + #define F77_zswap zswap + #define F77_zcopy zcopy + #define F77_zaxpy zaxpy + #define F77_izamax_sub izamaxsub + #define F77_sdot_sub sdotsub + #define F77_ddot_sub ddotsub + #define F77_dsdot_sub dsdotsub + #define F77_sscal sscal + #define F77_dscal dscal + #define F77_cscal cscal + #define F77_zscal zscal + #define F77_csscal csscal + #define F77_zdscal zdscal + #define F77_cdotu_sub cdotusub + #define F77_cdotc_sub cdotcsub + #define F77_zdotu_sub zdotusub + #define F77_zdotc_sub zdotcsub + #define F77_snrm2_sub snrm2sub + #define F77_sasum_sub sasumsub + #define F77_dnrm2_sub dnrm2sub + #define F77_dasum_sub dasumsub + #define F77_scnrm2_sub scnrm2sub + #define F77_scasum_sub scasumsub + #define F77_dznrm2_sub dznrm2sub + #define F77_dzasum_sub dzasumsub + #define F77_sdsdot_sub sdsdotsub +/* + * Level 2 BLAS + */ + #define F77_ssymv ssymv + #define F77_ssbmv ssbmv + #define F77_sspmv sspmv + #define F77_sger sger + #define F77_ssyr ssyr + #define F77_sspr sspr + #define F77_ssyr2 ssyr2 + #define F77_sspr2 sspr2 + #define F77_dsymv dsymv + #define F77_dsbmv dsbmv + #define F77_dspmv dspmv + #define F77_dger dger + #define F77_dsyr dsyr + #define F77_dspr dspr + #define F77_dsyr2 dsyr2 + #define F77_dspr2 dspr2 + #define F77_chemv chemv + #define F77_chbmv chbmv + #define F77_chpmv chpmv + #define F77_cgeru cgeru + #define F77_cgerc cgerc + #define F77_cher cher + #define F77_chpr chpr + #define F77_cher2 cher2 + #define F77_chpr2 chpr2 + #define F77_zhemv zhemv + #define F77_zhbmv zhbmv + #define F77_zhpmv zhpmv + #define F77_zgeru zgeru + #define F77_zgerc zgerc + #define F77_zher zher + #define F77_zhpr zhpr + #define F77_zher2 zher2 + #define F77_zhpr2 zhpr2 + #define F77_sgemv sgemv + #define F77_sgbmv sgbmv + #define F77_strmv strmv + #define F77_stbmv stbmv + #define F77_stpmv stpmv + #define F77_strsv strsv + #define F77_stbsv stbsv + #define F77_stpsv stpsv + #define F77_dgemv dgemv + #define F77_dgbmv dgbmv + #define F77_dtrmv dtrmv + #define F77_dtbmv dtbmv + #define F77_dtpmv dtpmv + #define F77_dtrsv dtrsv + #define F77_dtbsv dtbsv + #define F77_dtpsv dtpsv + #define F77_cgemv cgemv + #define F77_cgbmv cgbmv + #define F77_ctrmv ctrmv + #define F77_ctbmv ctbmv + #define F77_ctpmv ctpmv + #define F77_ctrsv ctrsv + #define F77_ctbsv ctbsv + #define F77_ctpsv ctpsv + #define F77_zgemv zgemv + #define F77_zgbmv zgbmv + #define F77_ztrmv ztrmv + #define F77_ztbmv ztbmv + #define F77_ztpmv ztpmv + #define F77_ztrsv ztrsv + #define F77_ztbsv ztbsv + #define F77_ztpsv ztpsv +/* + * Level 3 BLAS + */ + #define F77_chemm chemm + #define F77_cherk cherk + #define F77_cher2k cher2k + #define F77_zhemm zhemm + #define F77_zherk zherk + #define F77_zher2k zher2k + #define F77_sgemm sgemm + #define F77_ssymm ssymm + #define F77_ssyrk ssyrk + #define F77_ssyr2k ssyr2k + #define F77_strmm strmm + #define F77_strsm strsm + #define F77_dgemm dgemm + #define F77_dsymm dsymm + #define F77_dsyrk dsyrk + #define F77_dsyr2k dsyr2k + #define F77_dtrmm dtrmm + #define F77_dtrsm dtrsm + #define F77_cgemm cgemm + #define F77_csymm csymm + #define F77_csyrk csyrk + #define F77_csyr2k csyr2k + #define F77_ctrmm ctrmm + #define F77_ctrsm ctrsm + #define F77_zgemm zgemm + #define F77_zsymm zsymm + #define F77_zsyrk zsyrk + #define F77_zsyr2k zsyr2k + #define F77_ztrmm ztrmm + #define F77_ztrsm ztrsm +#endif + +#ifdef __cplusplus +extern "C" { +#endif + + void F77_xerbla(FCHAR, void *); +/* + * Level 1 Fortran Prototypes + */ + +/* Single Precision */ + + void F77_srot(FINT, float *, FINT, float *, FINT, const float *, const float *); + void F77_srotg(float *,float *,float *,float *); + void F77_srotm( FINT, float *, FINT, float *, FINT, const float *); + void F77_srotmg(float *,float *,float *,const float *, float *); + void F77_sswap( FINT, float *, FINT, float *, FINT); + void F77_scopy( FINT, const float *, FINT, float *, FINT); + void F77_saxpy( FINT, const float *, const float *, FINT, float *, FINT); + void F77_sdot_sub(FINT, const float *, FINT, const float *, FINT, float *); + void F77_sdsdot_sub( FINT, const float *, const float *, FINT, const float *, FINT, float *); + void F77_sscal( FINT, const float *, float *, FINT); + void F77_snrm2_sub( FINT, const float *, FINT, float *); + void F77_sasum_sub( FINT, const float *, FINT, float *); + void F77_isamax_sub( FINT, const float * , FINT, FINT2); + +/* Double Precision */ + + void F77_drot(FINT, double *, FINT, double *, FINT, const double *, const double *); + void F77_drotg(double *,double *,double *,double *); + void F77_drotm( FINT, double *, FINT, double *, FINT, const double *); + void F77_drotmg(double *,double *,double *,const double *, double *); + void F77_dswap( FINT, double *, FINT, double *, FINT); + void F77_dcopy( FINT, const double *, FINT, double *, FINT); + void F77_daxpy( FINT, const double *, const double *, FINT, double *, FINT); + void F77_dswap( FINT, double *, FINT, double *, FINT); + void F77_dsdot_sub(FINT, const float *, FINT, const float *, FINT, double *); + void F77_ddot_sub( FINT, const double *, FINT, const double *, FINT, double *); + void F77_dscal( FINT, const double *, double *, FINT); + void F77_dnrm2_sub( FINT, const double *, FINT, double *); + void F77_dasum_sub( FINT, const double *, FINT, double *); + void F77_idamax_sub( FINT, const double * , FINT, FINT2); + +/* Single Complex Precision */ + + void F77_cswap( FINT, void *, FINT, void *, FINT); + void F77_ccopy( FINT, const void *, FINT, void *, FINT); + void F77_caxpy( FINT, const void *, const void *, FINT, void *, FINT); + void F77_cswap( FINT, void *, FINT, void *, FINT); + void F77_cdotc_sub( FINT, const void *, FINT, const void *, FINT, void *); + void F77_cdotu_sub( FINT, const void *, FINT, const void *, FINT, void *); + void F77_cscal( FINT, const void *, void *, FINT); + void F77_icamax_sub( FINT, const void *, FINT, FINT2); + void F77_csscal( FINT, const float *, void *, FINT); + void F77_scnrm2_sub( FINT, const void *, FINT, float *); + void F77_scasum_sub( FINT, const void *, FINT, float *); + +/* Double Complex Precision */ + + void F77_zswap( FINT, void *, FINT, void *, FINT); + void F77_zcopy( FINT, const void *, FINT, void *, FINT); + void F77_zaxpy( FINT, const void *, const void *, FINT, void *, FINT); + void F77_zswap( FINT, void *, FINT, void *, FINT); + void F77_zdotc_sub( FINT, const void *, FINT, const void *, FINT, void *); + void F77_zdotu_sub( FINT, const void *, FINT, const void *, FINT, void *); + void F77_zdscal( FINT, const double *, void *, FINT); + void F77_zscal( FINT, const void *, void *, FINT); + void F77_dznrm2_sub( FINT, const void *, FINT, double *); + void F77_dzasum_sub( FINT, const void *, FINT, double *); + void F77_izamax_sub( FINT, const void *, FINT, FINT2); + +/* + * Level 2 Fortran Prototypes + */ + +/* Single Precision */ + + void F77_sgemv(FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_sgbmv(FCHAR, FINT, FINT, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_ssymv(FCHAR, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_ssbmv(FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_sspmv(FCHAR, FINT, const float *, const float *, const float *, FINT, const float *, float *, FINT); + void F77_strmv( FCHAR, FCHAR, FCHAR, FINT, const float *, FINT, float *, FINT); + void F77_stbmv( FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, FINT, float *, FINT); + void F77_strsv( FCHAR, FCHAR, FCHAR, FINT, const float *, FINT, float *, FINT); + void F77_stbsv( FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, FINT, float *, FINT); + void F77_stpmv( FCHAR, FCHAR, FCHAR, FINT, const float *, float *, FINT); + void F77_stpsv( FCHAR, FCHAR, FCHAR, FINT, const float *, float *, FINT); + void F77_sger( FINT, FINT, const float *, const float *, FINT, const float *, FINT, float *, FINT); + void F77_ssyr(FCHAR, FINT, const float *, const float *, FINT, float *, FINT); + void F77_sspr(FCHAR, FINT, const float *, const float *, FINT, float *); + void F77_sspr2(FCHAR, FINT, const float *, const float *, FINT, const float *, FINT, float *); + void F77_ssyr2(FCHAR, FINT, const float *, const float *, FINT, const float *, FINT, float *, FINT); + +/* Double Precision */ + + void F77_dgemv(FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_dgbmv(FCHAR, FINT, FINT, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_dsymv(FCHAR, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_dsbmv(FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_dspmv(FCHAR, FINT, const double *, const double *, const double *, FINT, const double *, double *, FINT); + void F77_dtrmv( FCHAR, FCHAR, FCHAR, FINT, const double *, FINT, double *, FINT); + void F77_dtbmv( FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, FINT, double *, FINT); + void F77_dtrsv( FCHAR, FCHAR, FCHAR, FINT, const double *, FINT, double *, FINT); + void F77_dtbsv( FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, FINT, double *, FINT); + void F77_dtpmv( FCHAR, FCHAR, FCHAR, FINT, const double *, double *, FINT); + void F77_dtpsv( FCHAR, FCHAR, FCHAR, FINT, const double *, double *, FINT); + void F77_dger( FINT, FINT, const double *, const double *, FINT, const double *, FINT, double *, FINT); + void F77_dsyr(FCHAR, FINT, const double *, const double *, FINT, double *, FINT); + void F77_dspr(FCHAR, FINT, const double *, const double *, FINT, double *); + void F77_dspr2(FCHAR, FINT, const double *, const double *, FINT, const double *, FINT, double *); + void F77_dsyr2(FCHAR, FINT, const double *, const double *, FINT, const double *, FINT, double *, FINT); + +/* Single Complex Precision */ + + void F77_cgemv(FCHAR, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT); + void F77_cgbmv(FCHAR, FINT, FINT, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT); + void F77_chemv(FCHAR, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT); + void F77_chbmv(FCHAR, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT); + void F77_chpmv(FCHAR, FINT, const void *, const void *, const void *, FINT, const void *, void *, FINT); + void F77_ctrmv( FCHAR, FCHAR, FCHAR, FINT, const void *, FINT, void *, FINT); + void F77_ctbmv( FCHAR, FCHAR, FCHAR, FINT, FINT, const void *, FINT, void *, FINT); + void F77_ctpmv( FCHAR, FCHAR, FCHAR, FINT, const void *, void *, FINT); + void F77_ctrsv( FCHAR, FCHAR, FCHAR, FINT, const void *, FINT, void *, FINT); + void F77_ctbsv( FCHAR, FCHAR, FCHAR, FINT, FINT, const void *, FINT, void *, FINT); + void F77_ctpsv( FCHAR, FCHAR, FCHAR, FINT, const void *, void *,FINT); + void F77_cgerc( FINT, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT); + void F77_cgeru( FINT, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT); + void F77_cher(FCHAR, FINT, const float *, const void *, FINT, void *, FINT); + void F77_cher2(FCHAR, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT); + void F77_chpr(FCHAR, FINT, const float *, const void *, FINT, void *); + void F77_chpr2(FCHAR, FINT, const float *, const void *, FINT, const void *, FINT, void *); + +/* Double Complex Precision */ + + void F77_zgemv(FCHAR, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT); + void F77_zgbmv(FCHAR, FINT, FINT, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT); + void F77_zhemv(FCHAR, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT); + void F77_zhbmv(FCHAR, FINT, FINT, const void *, const void *, FINT, const void *, FINT, const void *, void *, FINT); + void F77_zhpmv(FCHAR, FINT, const void *, const void *, const void *, FINT, const void *, void *, FINT); + void F77_ztrmv( FCHAR, FCHAR, FCHAR, FINT, const void *, FINT, void *, FINT); + void F77_ztbmv( FCHAR, FCHAR, FCHAR, FINT, FINT, const void *, FINT, void *, FINT); + void F77_ztpmv( FCHAR, FCHAR, FCHAR, FINT, const void *, void *, FINT); + void F77_ztrsv( FCHAR, FCHAR, FCHAR, FINT, const void *, FINT, void *, FINT); + void F77_ztbsv( FCHAR, FCHAR, FCHAR, FINT, FINT, const void *, FINT, void *, FINT); + void F77_ztpsv( FCHAR, FCHAR, FCHAR, FINT, const void *, void *,FINT); + void F77_zgerc( FINT, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT); + void F77_zgeru( FINT, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT); + void F77_zher(FCHAR, FINT, const double *, const void *, FINT, void *, FINT); + void F77_zher2(FCHAR, FINT, const void *, const void *, FINT, const void *, FINT, void *, FINT); + void F77_zhpr(FCHAR, FINT, const double *, const void *, FINT, void *); + void F77_zhpr2(FCHAR, FINT, const double *, const void *, FINT, const void *, FINT, void *); + +/* + * Level 3 Fortran Prototypes + */ + +/* Single Precision */ + + void F77_sgemm(FCHAR, FCHAR, FINT, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_ssymm(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_ssyrk(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, float *, FINT); + void F77_ssyr2k(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_strmm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, float *, FINT); + void F77_strsm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, float *, FINT); + +/* Double Precision */ + + void F77_dgemm(FCHAR, FCHAR, FINT, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_dsymm(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_dsyrk(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, double *, FINT); + void F77_dsyr2k(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_dtrmm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, double *, FINT); + void F77_dtrsm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, double *, FINT); + +/* Single Complex Precision */ + + void F77_cgemm(FCHAR, FCHAR, FINT, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_csymm(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_chemm(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_csyrk(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, float *, FINT); + void F77_cherk(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, float *, FINT); + void F77_csyr2k(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_cher2k(FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, const float *, FINT, const float *, float *, FINT); + void F77_ctrmm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, float *, FINT); + void F77_ctrsm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const float *, const float *, FINT, float *, FINT); + +/* Double Complex Precision */ + + void F77_zgemm(FCHAR, FCHAR, FINT, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_zsymm(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_zhemm(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_zsyrk(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, double *, FINT); + void F77_zherk(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, double *, FINT); + void F77_zsyr2k(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_zher2k(FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, const double *, FINT, const double *, double *, FINT); + void F77_ztrmm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, double *, FINT); + void F77_ztrsm(FCHAR, FCHAR, FCHAR, FCHAR, FINT, FINT, const double *, const double *, FINT, double *, FINT); + +#ifdef __cplusplus +} +#endif + +#endif /* CBLAS_F77_H */ diff --git a/ml/dlib/dlib/external/cblas/cblas_icamax.c b/ml/dlib/dlib/external/cblas/cblas_icamax.c new file mode 100644 index 000000000..b3ffe6eec --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_icamax.c @@ -0,0 +1,23 @@ +/* + * cblas_icamax.c + * + * The program is a C interface to icamax. + * It calls the fortran wrapper before calling icamax. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +CBLAS_INDEX cblas_icamax( const int N, const void *X, const int incX) +{ + F77_INT iamax; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_icamax_sub( &F77_N, X, &F77_incX, &iamax); + return iamax ? iamax-1 : 0; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_idamax.c b/ml/dlib/dlib/external/cblas/cblas_idamax.c new file mode 100644 index 000000000..e42e459ea --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_idamax.c @@ -0,0 +1,23 @@ +/* + * cblas_idamax.c + * + * The program is a C interface to idamax. + * It calls the fortran wrapper before calling idamax. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +CBLAS_INDEX cblas_idamax( const int N, const double *X, const int incX) +{ + F77_INT iamax; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_idamax_sub( &F77_N, X, &F77_incX, &iamax); + return iamax ? iamax-1 : 0; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_isamax.c b/ml/dlib/dlib/external/cblas/cblas_isamax.c new file mode 100644 index 000000000..63d639c7f --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_isamax.c @@ -0,0 +1,23 @@ +/* + * cblas_isamax.c + * + * The program is a C interface to isamax. + * It calls the fortran wrapper before calling isamax. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +CBLAS_INDEX cblas_isamax( const int N, const float *X, const int incX) +{ + F77_INT iamax; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_isamax_sub( &F77_N, X, &F77_incX, &iamax); + return iamax ? iamax-1 : 0; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_izamax.c b/ml/dlib/dlib/external/cblas/cblas_izamax.c new file mode 100644 index 000000000..78eda4042 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_izamax.c @@ -0,0 +1,23 @@ +/* + * cblas_izamax.c + * + * The program is a C interface to izamax. + * It calls the fortran wrapper before calling izamax. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +CBLAS_INDEX cblas_izamax( const int N, const void *X, const int incX) +{ + F77_INT iamax; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_izamax_sub( &F77_N, X, &F77_incX, &iamax); + return (iamax ? iamax-1 : 0); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sasum.c b/ml/dlib/dlib/external/cblas/cblas_sasum.c new file mode 100644 index 000000000..7d4c32cf9 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sasum.c @@ -0,0 +1,23 @@ +/* + * cblas_sasum.c + * + * The program is a C interface to sasum. + * It calls the fortran wrapper before calling sasum. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +float cblas_sasum( const int N, const float *X, const int incX) +{ + float asum; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_sasum_sub( &F77_N, X, &F77_incX, &asum); + return asum; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_saxpy.c b/ml/dlib/dlib/external/cblas/cblas_saxpy.c new file mode 100644 index 000000000..2eee8e06e --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_saxpy.c @@ -0,0 +1,23 @@ +/* + * cblas_saxpy.c + * + * The program is a C interface to saxpy. + * It calls the fortran wrapper before calling saxpy. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_saxpy( const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_saxpy( &F77_N, &alpha, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_scasum.c b/ml/dlib/dlib/external/cblas/cblas_scasum.c new file mode 100644 index 000000000..e1fa53090 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_scasum.c @@ -0,0 +1,23 @@ +/* + * cblas_scasum.c + * + * The program is a C interface to scasum. + * It calls the fortran wrapper before calling scasum. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +float cblas_scasum( const int N, const void *X, const int incX) +{ + float asum; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_scasum_sub( &F77_N, X, &F77_incX, &asum); + return asum; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_scnrm2.c b/ml/dlib/dlib/external/cblas/cblas_scnrm2.c new file mode 100644 index 000000000..fa48454ed --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_scnrm2.c @@ -0,0 +1,23 @@ +/* + * cblas_scnrm2.c + * + * The program is a C interface to scnrm2. + * It calls the fortran wrapper before calling scnrm2. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +float cblas_scnrm2( const int N, const void *X, const int incX) +{ + float nrm2; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_scnrm2_sub( &F77_N, X, &F77_incX, &nrm2); + return nrm2; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_scopy.c b/ml/dlib/dlib/external/cblas/cblas_scopy.c new file mode 100644 index 000000000..7796959f3 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_scopy.c @@ -0,0 +1,22 @@ +/* + * cblas_scopy.c + * + * The program is a C interface to scopy. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_scopy( const int N, const float *X, + const int incX, float *Y, const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_scopy( &F77_N, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sdot.c b/ml/dlib/dlib/external/cblas/cblas_sdot.c new file mode 100644 index 000000000..baf859272 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sdot.c @@ -0,0 +1,25 @@ +/* + * cblas_sdot.c + * + * The program is a C interface to sdot. + * It calls the fortran wrapper before calling sdot. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +float cblas_sdot( const int N, const float *X, + const int incX, const float *Y, const int incY) +{ + float dot; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_sdot_sub( &F77_N, X, &F77_incX, Y, &F77_incY, &dot); + return dot; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sdsdot.c b/ml/dlib/dlib/external/cblas/cblas_sdsdot.c new file mode 100644 index 000000000..b824849b9 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sdsdot.c @@ -0,0 +1,25 @@ +/* + * cblas_sdsdot.c + * + * The program is a C interface to sdsdot. + * It calls the fortran wrapper before calling sdsdot. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +float cblas_sdsdot( const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY) +{ + float dot; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_sdsdot_sub( &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, &dot); + return dot; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sgbmv.c b/ml/dlib/dlib/external/cblas/cblas_sgbmv.c new file mode 100644 index 000000000..b6de24977 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sgbmv.c @@ -0,0 +1,72 @@ +/* + * + * cblas_sgbmv.c + * This program is a C interface to sgbmv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sgbmv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY) +{ + char TA; +#ifdef F77_CHAR + F77_CHAR F77_TA; +#else + #define F77_TA &TA +#endif +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; + F77_INT F77_KL=KL,F77_KU=KU; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_KL KL + #define F77_KU KU + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(2, "cblas_sgbmv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_sgbmv(F77_TA, &F77_M, &F77_N, &F77_KL, &F77_KU, &alpha, + A, &F77_lda, X, &F77_incX, &beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(2, "cblas_sgbmv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_sgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, &alpha, + A ,&F77_lda, X, &F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_sgbmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sgemm.c b/ml/dlib/dlib/external/cblas/cblas_sgemm.c new file mode 100644 index 000000000..f8adeda49 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sgemm.c @@ -0,0 +1,95 @@ +/* + * + * cblas_sgemm.c + * This program is a C interface to sgemm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const float alpha, const float *A, + const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc) +{ + char TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB; +#else + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + if( Order == CblasColMajor ) + { + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_sgemm", + "Illegal TransA setting, %d\n", TransA); + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_sgemm", + "Illegal TransB setting, %d\n", TransB); + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_sgemm(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(2, "cblas_sgemm", + "Illegal TransA setting, %d\n", TransA); + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_sgemm", + "Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_sgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, &alpha, B, &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc); + } else + cblas_xerbla(1, "cblas_sgemm", + "Illegal Order setting, %d\n", Order); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sgemv.c b/ml/dlib/dlib/external/cblas/cblas_sgemv.c new file mode 100644 index 000000000..d47f3be55 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sgemv.c @@ -0,0 +1,67 @@ +/* + * + * cblas_sgemv.c + * This program is a C interface to sgemv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sgemv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY) +{ + char TA; +#ifdef F77_CHAR + F77_CHAR F77_TA; +#else + #define F77_TA &TA +#endif +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_incX incX + #define F77_incY incY +#endif + + + if (order == CblasColMajor) + { + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(2, "cblas_sgemv","Illegal TransA setting, %d\n", TransA); + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_sgemv(F77_TA, &F77_M, &F77_N, &alpha, A, &F77_lda, X, &F77_incX, + &beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(2, "cblas_sgemv", "Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_sgemv(F77_TA, &F77_N, &F77_M, &alpha, A, &F77_lda, X, + &F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_sgemv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sger.c b/ml/dlib/dlib/external/cblas/cblas_sger.c new file mode 100644 index 000000000..0313590c7 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sger.c @@ -0,0 +1,39 @@ +/* + * + * cblas_sger.c + * This program is a C interface to sger. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N, + const float alpha, const float *X, const int incX, + const float *Y, const int incY, float *A, const int lda) +{ +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_incX incX + #define F77_incY incY + #define F77_lda lda +#endif + + + if (order == CblasColMajor) + { + F77_sger( &F77_M, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + } + else if (order == CblasRowMajor) + { + F77_sger( &F77_N, &F77_M, &alpha, Y, &F77_incY, X, &F77_incX, A, + &F77_lda); + } + else cblas_xerbla(1, "cblas_sger", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_snrm2.c b/ml/dlib/dlib/external/cblas/cblas_snrm2.c new file mode 100644 index 000000000..18161b4fa --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_snrm2.c @@ -0,0 +1,23 @@ +/* + * cblas_snrm2.c + * + * The program is a C interface to snrm2. + * It calls the fortran wrapper before calling snrm2. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +float cblas_snrm2( const int N, const float *X, const int incX) +{ + float nrm2; +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_snrm2_sub( &F77_N, X, &F77_incX, &nrm2); + return nrm2; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_srot.c b/ml/dlib/dlib/external/cblas/cblas_srot.c new file mode 100644 index 000000000..cbd1c8c90 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_srot.c @@ -0,0 +1,22 @@ +/* + * cblas_srot.c + * + * The program is a C interface to srot. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_srot( const int N, float *X, const int incX, float *Y, + const int incY, const float c, const float s) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_srot(&F77_N, X, &F77_incX, Y, &F77_incY, &c, &s); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_srotg.c b/ml/dlib/dlib/external/cblas/cblas_srotg.c new file mode 100644 index 000000000..f6460048d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_srotg.c @@ -0,0 +1,14 @@ +/* + * cblas_srotg.c + * + * The program is a C interface to srotg. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_srotg( float *a, float *b, float *c, float *s) +{ + F77_srotg(a,b,c,s); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_srotm.c b/ml/dlib/dlib/external/cblas/cblas_srotm.c new file mode 100644 index 000000000..496746454 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_srotm.c @@ -0,0 +1,22 @@ +/* + * cblas_srotm.c + * + * The program is a C interface to srotm. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_srotm( const int N, float *X, const int incX, float *Y, + const int incY, const float *P) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_srotm( &F77_N, X, &F77_incX, Y, &F77_incY, P); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_srotmg.c b/ml/dlib/dlib/external/cblas/cblas_srotmg.c new file mode 100644 index 000000000..04f978b40 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_srotmg.c @@ -0,0 +1,15 @@ +/* + * cblas_srotmg.c + * + * The program is a C interface to srotmg. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_srotmg( float *d1, float *d2, float *b1, + const float b2, float *p) +{ + F77_srotmg(d1,d2,b1,&b2,p); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ssbmv.c b/ml/dlib/dlib/external/cblas/cblas_ssbmv.c new file mode 100644 index 000000000..69663b619 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ssbmv.c @@ -0,0 +1,65 @@ +/* + * + * cblas_ssbmv.c + * This program is a C interface to ssbmv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const float alpha, const float *A, + const int lda, const float *X, const int incX, + const float beta, float *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ssbmv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_ssbmv(F77_UL, &F77_N, &F77_K, &alpha, A, &F77_lda, X, + &F77_incX, &beta, Y, &F77_incY); + }else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ssbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_ssbmv(F77_UL, &F77_N, &F77_K, &alpha, A, &F77_lda, X, + &F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_ssbmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sscal.c b/ml/dlib/dlib/external/cblas/cblas_sscal.c new file mode 100644 index 000000000..1f09abe7a --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sscal.c @@ -0,0 +1,21 @@ +/* + * cblas_sscal.c + * + * The program is a C interface to sscal. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sscal( const int N, const float alpha, float *X, + const int incX) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_sscal( &F77_N, &alpha, X, &F77_incX); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sspmv.c b/ml/dlib/dlib/external/cblas/cblas_sspmv.c new file mode 100644 index 000000000..e3485e742 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sspmv.c @@ -0,0 +1,62 @@ +/* + * + * cblas_sspmv.c + * This program is a C interface to sspmv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sspmv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo, const int N, + const float alpha, const float *AP, + const float *X, const int incX, const float beta, + float *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_sspmv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_sspmv(F77_UL, &F77_N, &alpha, AP, X, + &F77_incX, &beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_sspmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_sspmv(F77_UL, &F77_N, &alpha, + AP, X,&F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_sspmv", "Illegal Order setting, %d\n", order); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sspr.c b/ml/dlib/dlib/external/cblas/cblas_sspr.c new file mode 100644 index 000000000..75d669b30 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sspr.c @@ -0,0 +1,61 @@ +/* + * + * cblas_sspr.c + * This program is a C interface to sspr. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *Ap) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_sspr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_sspr(F77_UL, &F77_N, &alpha, X, &F77_incX, Ap); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasLower) UL = 'U'; + else if (Uplo == CblasUpper) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_sspr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_sspr(F77_UL, &F77_N, &alpha, X, &F77_incX, Ap); + } else cblas_xerbla(1, "cblas_sspr", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sspr2.c b/ml/dlib/dlib/external/cblas/cblas_sspr2.c new file mode 100644 index 000000000..b6ff9bfa2 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sspr2.c @@ -0,0 +1,60 @@ +/* + * + * cblas_sspr2.c + * This program is a C interface to sspr2. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_sspr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_sspr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasLower) UL = 'U'; + else if (Uplo == CblasUpper) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_sspr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_sspr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A); + } else cblas_xerbla(1, "cblas_sspr2", "Illegal Order setting, %d\n", order); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_sswap.c b/ml/dlib/dlib/external/cblas/cblas_sswap.c new file mode 100644 index 000000000..b74d8469c --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_sswap.c @@ -0,0 +1,22 @@ +/* + * cblas_sswap.c + * + * The program is a C interface to sswap. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sswap( const int N, float *X, const int incX, float *Y, + const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_sswap( &F77_N, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ssymm.c b/ml/dlib/dlib/external/cblas/cblas_ssymm.c new file mode 100644 index 000000000..55c413b48 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ssymm.c @@ -0,0 +1,93 @@ +/* + * + * cblas_ssymm.c + * This program is a C interface to ssymm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc) +{ + char SD, UL; +#ifdef F77_CHAR + F77_CHAR F77_SD, F77_UL; +#else + #define F77_SD &SD + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_ssymm", + "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_ssymm", + "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_ssymm(F77_SD, F77_UL, &F77_M, &F77_N, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_ssymm", + "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_ssymm", + "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_ssymm(F77_SD, F77_UL, &F77_N, &F77_M, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else cblas_xerbla(1, "cblas_ssymm", + "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ssymv.c b/ml/dlib/dlib/external/cblas/cblas_ssymv.c new file mode 100644 index 000000000..c56d4d457 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ssymv.c @@ -0,0 +1,65 @@ +/* + * + * cblas_ssymv.c + * This program is a C interface to ssymv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ssymv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo, const int N, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX + #define F77_incY incY +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ssymv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_ssymv(F77_UL, &F77_N, &alpha, A, &F77_lda, X, + &F77_incX, &beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ssymv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_ssymv(F77_UL, &F77_N, &alpha, + A ,&F77_lda, X,&F77_incX, &beta, Y, &F77_incY); + } + else cblas_xerbla(1, "cblas_ssymv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ssyr.c b/ml/dlib/dlib/external/cblas/cblas_ssyr.c new file mode 100644 index 000000000..4215b9671 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ssyr.c @@ -0,0 +1,59 @@ +/* + * + * cblas_ssyr.c + * This program is a C interface to ssyr. + * Written by Keita Teranishi + * 4/6/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *A, const int lda) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_lda=lda; +#else + #define F77_N N + #define F77_incX incX + #define F77_lda lda +#endif + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ssyr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_ssyr(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasLower) UL = 'U'; + else if (Uplo == CblasUpper) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ssyr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_ssyr(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda); + } else cblas_xerbla(1, "cblas_ssyr", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ssyr2.c b/ml/dlib/dlib/external/cblas/cblas_ssyr2.c new file mode 100644 index 000000000..9cdaa412d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ssyr2.c @@ -0,0 +1,65 @@ +/* + * + * cblas_ssyr2.c + * This program is a C interface to ssyr2. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A, + const int lda) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY, F77_lda=lda; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY + #define F77_lda lda +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ssyr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_ssyr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasLower) UL = 'U'; + else if (Uplo == CblasUpper) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ssyr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_ssyr2(F77_UL, &F77_N, &alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + } else cblas_xerbla(1, "cblas_ssyr2", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ssyr2k.c b/ml/dlib/dlib/external/cblas/cblas_ssyr2k.c new file mode 100644 index 000000000..9e9f538df --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ssyr2k.c @@ -0,0 +1,96 @@ +/* + * + * cblas_ssyr2k.c + * This program is a C interface to ssyr2k. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_ssyr2k", + "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_ssyr2k", + "Illegal Trans setting, %d\n", Trans); + return; + } + + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_ssyr2k(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_ssyr2k", + "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='T'; + else + { + cblas_xerbla(3, "cblas_ssyr2k", + "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_ssyr2k(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else cblas_xerbla(1, "cblas_ssyr2k", + "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ssyrk.c b/ml/dlib/dlib/external/cblas/cblas_ssyrk.c new file mode 100644 index 000000000..55ceb7e3a --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ssyrk.c @@ -0,0 +1,95 @@ +/* + * + * cblas_ssyrk.c + * This program is a C interface to ssyrk. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float beta, float *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_ssyrk", + "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_ssyrk", + "Illegal Trans setting, %d\n", Trans); + return; + } + + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_ssyrk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_ssyrk", + "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='T'; + else + { + cblas_xerbla(3, "cblas_ssyrk", + "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_ssyrk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, &beta, C, &F77_ldc); + } else cblas_xerbla(1, "cblas_ssyrk", + "Illegal Order setting, %d\n", Order); + return; +} + diff --git a/ml/dlib/dlib/external/cblas/cblas_stbmv.c b/ml/dlib/dlib/external/cblas/cblas_stbmv.c new file mode 100644 index 000000000..71ef469a7 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_stbmv.c @@ -0,0 +1,103 @@ +/* + * cblas_stbmv.c + * This program is a C interface to stbmv. + * Written by Keita Teranishi + * 3/3/1998 + */ +#include "cblas.h" +#include "cblas_f77.h" + +void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_stbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_stbmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_stbmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_stbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_stbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_stbmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_stbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_stbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else cblas_xerbla(1, "cblas_stbmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_stbsv.c b/ml/dlib/dlib/external/cblas/cblas_stbsv.c new file mode 100644 index 000000000..96df7c0c1 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_stbsv.c @@ -0,0 +1,103 @@ +/* + * cblas_stbsv.c + * The program is a C interface to stbsv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_stbsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_stbsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_stbsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_stbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_stbsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_stbsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_stbsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_stbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else cblas_xerbla(1, "cblas_stbsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_stpmv.c b/ml/dlib/dlib/external/cblas/cblas_stpmv.c new file mode 100644 index 000000000..5cb5cd29d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_stpmv.c @@ -0,0 +1,99 @@ +/* + * + * cblas_stpmv.c + * This program is a C interface to stpmv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_stpmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_stpmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_stpmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_stpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_stpmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_stpmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_stpmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_stpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX); + } + else cblas_xerbla(1, "cblas_stpmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_stpsv.c b/ml/dlib/dlib/external/cblas/cblas_stpsv.c new file mode 100644 index 000000000..2f0d29c0d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_stpsv.c @@ -0,0 +1,99 @@ +/* + * cblas_stpsv.c + * The program is a C interface to stpsv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_stpsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_stpsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_stpsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_stpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_stpsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_stpsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_stpsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_stpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX); + + } + else cblas_xerbla(1, "cblas_stpsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_strmm.c b/ml/dlib/dlib/external/cblas/cblas_strmm.c new file mode 100644 index 000000000..40d6e23dd --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_strmm.c @@ -0,0 +1,125 @@ +/* + * + * cblas_strmm.c + * This program is a C interface to strmm. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb) +{ + char UL, TA, SD, DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_SD &SD + #define F77_DI &DI +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_strmm","Illegal Side setting, %d\n", Side); + return; + } + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_strmm","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_strmm","Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_strmm", "Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_strmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, &alpha, A, &F77_lda, B, &F77_ldb); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_strmm","Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_strmm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_strmm", "Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_strmm","Illegal Diag setting, %d\n", Diag); + return; + } +#ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); +#endif + F77_strmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, &alpha, A, + &F77_lda, B, &F77_ldb); + } + else cblas_xerbla(1, "cblas_strmm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_strmv.c b/ml/dlib/dlib/external/cblas/cblas_strmv.c new file mode 100644 index 000000000..4c2f7b6a6 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_strmv.c @@ -0,0 +1,103 @@ +/* + * + * cblas_strmv.c + * This program is a C interface to strmv. + * Written by Keita Teranishi + * 4/6/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, + float *X, const int incX) + +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_strmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_strmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_strmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_strmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_strmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_strmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_strmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_strmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else cblas_xerbla(1, "cblas_strmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_strsm.c b/ml/dlib/dlib/external/cblas/cblas_strsm.c new file mode 100644 index 000000000..178cf7d06 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_strsm.c @@ -0,0 +1,120 @@ +/* + * + * cblas_strsm.c + * This program is a C interface to strsm. + * Written by Keita Teranishi + * 4/6/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb) + +{ + char UL, TA, SD, DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_SD &SD + #define F77_DI &DI +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_strsm", "Illegal Side setting, %d\n", Side); + return; + } + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_strsm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_strsm", "Illegal Trans setting, %d\n", TransA); + return; + } + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_strsm", "Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_strsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, &alpha, A, &F77_lda, B, &F77_ldb); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_strsm", "Illegal Side setting, %d\n", Side); + return; + } + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_strsm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_strsm", "Illegal Trans setting, %d\n", TransA); + return; + } + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_strsm", "Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_strsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, &alpha, A, &F77_lda, B, &F77_ldb); + } + else cblas_xerbla(1, "cblas_strsm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_strsv.c b/ml/dlib/dlib/external/cblas/cblas_strsv.c new file mode 100644 index 000000000..7c3811974 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_strsv.c @@ -0,0 +1,102 @@ +/* + * cblas_strsv.c + * The program is a C interface to strsv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, float *X, + const int incX) + +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX +#endif + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_strsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_strsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_strsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_strsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_strsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) TA = 'N'; + else + { + cblas_xerbla(3, "cblas_strsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_strsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_strsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else cblas_xerbla(1, "cblas_strsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_xerbla.c b/ml/dlib/dlib/external/cblas/cblas_xerbla.c new file mode 100644 index 000000000..0b5b39f53 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_xerbla.c @@ -0,0 +1,66 @@ +#include +#include +#include +#include +#include "cblas.h" +#include "cblas_f77.h" + +void cblas_xerbla(int info, const char *rout, const char *form, ...) +{ + char empty[1] = ""; + va_list argptr; + + va_start(argptr, form); + + { + if (strstr(rout,"gemm") != 0) + { + if (info == 5 ) info = 4; + else if (info == 4 ) info = 5; + else if (info == 11) info = 9; + else if (info == 9 ) info = 11; + } + else if (strstr(rout,"symm") != 0 || strstr(rout,"hemm") != 0) + { + if (info == 5 ) info = 4; + else if (info == 4 ) info = 5; + } + else if (strstr(rout,"trmm") != 0 || strstr(rout,"trsm") != 0) + { + if (info == 7 ) info = 6; + else if (info == 6 ) info = 7; + } + else if (strstr(rout,"gemv") != 0) + { + if (info == 4) info = 3; + else if (info == 3) info = 4; + } + else if (strstr(rout,"gbmv") != 0) + { + if (info == 4) info = 3; + else if (info == 3) info = 4; + else if (info == 6) info = 5; + else if (info == 5) info = 6; + } + else if (strstr(rout,"ger") != 0) + { + if (info == 3) info = 2; + else if (info == 2) info = 3; + else if (info == 8) info = 6; + else if (info == 6) info = 8; + } + else if ( (strstr(rout,"her2") != 0 || strstr(rout,"hpr2") != 0) + && strstr(rout,"her2k") == 0 ) + { + if (info == 8) info = 6; + else if (info == 6) info = 8; + } + } + if (info) + fprintf(stderr, "Parameter %d to routine %s was incorrect\n", info, rout); + vfprintf(stderr, form, argptr); + va_end(argptr); + if (info && !info) + F77_xerbla(empty, &info); /* Force link of our F77 error handler */ + exit(-1); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zaxpy.c b/ml/dlib/dlib/external/cblas/cblas_zaxpy.c new file mode 100644 index 000000000..f63c4c39b --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zaxpy.c @@ -0,0 +1,22 @@ +/* + * cblas_zaxpy.c + * + * The program is a C interface to zaxpy. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zaxpy( const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_zaxpy( &F77_N, alpha, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zcopy.c b/ml/dlib/dlib/external/cblas/cblas_zcopy.c new file mode 100644 index 000000000..a16be28e7 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zcopy.c @@ -0,0 +1,22 @@ +/* + * cblas_zcopy.c + * + * The program is a C interface to zcopy. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zcopy( const int N, const void *X, + const int incX, void *Y, const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_zcopy( &F77_N, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zdotc_sub.c b/ml/dlib/dlib/external/cblas/cblas_zdotc_sub.c new file mode 100644 index 000000000..29dec6c57 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zdotc_sub.c @@ -0,0 +1,24 @@ +/* + * cblas_zdotc_sub.c + * + * The program is a C interface to zdotc. + * It calls the fortran wrapper before calling zdotc. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zdotc_sub( const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_zdotc_sub( &F77_N, X, &F77_incX, Y, &F77_incY, dotc); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zdotu_sub.c b/ml/dlib/dlib/external/cblas/cblas_zdotu_sub.c new file mode 100644 index 000000000..48a14bf3d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zdotu_sub.c @@ -0,0 +1,24 @@ +/* + * cblas_zdotu_sub.c + * + * The program is a C interface to zdotu. + * It calls the fortran wrapper before calling zdotu. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zdotu_sub( const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_zdotu_sub( &F77_N, X, &F77_incX, Y, &F77_incY, dotu); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zdscal.c b/ml/dlib/dlib/external/cblas/cblas_zdscal.c new file mode 100644 index 000000000..788365bef --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zdscal.c @@ -0,0 +1,21 @@ +/* + * cblas_zdscal.c + * + * The program is a C interface to zdscal. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zdscal( const int N, const double alpha, void *X, + const int incX) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_zdscal( &F77_N, &alpha, X, &F77_incX); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zgbmv.c b/ml/dlib/dlib/external/cblas/cblas_zgbmv.c new file mode 100644 index 000000000..e5b922429 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zgbmv.c @@ -0,0 +1,155 @@ +/* + * cblas_zgbmv.c + * The program is a C interface of zgbmv + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgbmv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char TA; +#ifdef F77_CHAR + F77_CHAR F77_TA; +#else + #define F77_TA &TA +#endif +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; + F77_INT F77_KL=KL,F77_KU=KU; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_KL KL + #define F77_KU KU + #define F77_incX incx + #define F77_incY incY +#endif + int n, i=0, incx=incX; + const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta; + double ALPHA[2],BETA[2]; + int tincY, tincx; + double *x=(double *)X, *y=(double *)Y, *st=0, *tx; + + if (order == CblasColMajor) + { + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(2, "cblas_zgbmv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_zgbmv(F77_TA, &F77_M, &F77_N, &F77_KL, &F77_KU, alpha, + A, &F77_lda, X, &F77_incX, beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + BETA[0]= *bet; + BETA[1]= -bet[1]; + TA = 'N'; + if (M > 0) + { + n = M << 1; + x = malloc(n*sizeof(double)); + tx = x; + + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + + if( incY > 0 ) + tincY = incY; + else + tincY = -incY; + + y++; + + if (N > 0) + { + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } + } + else x = (double *) X; + + + } + else + { + cblas_xerbla(2, "cblas_zgbmv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + if (TransA == CblasConjTrans) + F77_zgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, ALPHA, + A ,&F77_lda, x,&F77_incX, BETA, Y, &F77_incY); + else + F77_zgbmv(F77_TA, &F77_N, &F77_M, &F77_KU, &F77_KL, alpha, + A ,&F77_lda, x,&F77_incX, beta, Y, &F77_incY); + if (TransA == CblasConjTrans) + { + if (x != X) free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + } + else cblas_xerbla(1, "cblas_zgbmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zgemm.c b/ml/dlib/dlib/external/cblas/cblas_zgemm.c new file mode 100644 index 000000000..c348afa2d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zgemm.c @@ -0,0 +1,94 @@ +/* + * + * cblas_zgemm.c + * This program is a C interface to zgemm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc) +{ + char TA, TB; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB; +#else + #define F77_TA &TA + #define F77_TB &TB +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_zgemm","Illegal TransA setting, %d\n", TransA); + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_zgemm","Illegal TransB setting, %d\n", TransB); + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_zgemm(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, alpha, A, + &F77_lda, B, &F77_ldb, beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(2, "cblas_zgemm","Illegal TransA setting, %d\n", TransA); + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(2, "cblas_zgemm","Illegal TransB setting, %d\n", TransB); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_zgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, alpha, B, + &F77_ldb, A, &F77_lda, beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_zgemm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zgemv.c b/ml/dlib/dlib/external/cblas/cblas_zgemv.c new file mode 100644 index 000000000..6d5cd0cb2 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zgemv.c @@ -0,0 +1,153 @@ +/* + * cblas_zgemv.c + * The program is a C interface of zgemv + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgemv(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char TA; +#ifdef F77_CHAR + F77_CHAR F77_TA; +#else + #define F77_TA &TA +#endif +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_incX incx + #define F77_incY incY +#endif + + int n, i=0, incx=incX; + const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta; + double ALPHA[2],BETA[2]; + int tincY, tincx; + double *x=(double *)X, *y=(double *)Y, *st=0, *tx; + + + if (order == CblasColMajor) + { + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(2, "cblas_zgemv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + F77_zgemv(F77_TA, &F77_M, &F77_N, alpha, A, &F77_lda, X, &F77_incX, + beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + BETA[0]= *bet; + BETA[1]= -bet[1]; + TA = 'N'; + if (M > 0) + { + n = M << 1; + x = malloc(n*sizeof(double)); + tx = x; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + + if(incY > 0) + tincY = incY; + else + tincY = -incY; + + y++; + + if (N > 0) + { + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } + } + else x = (double *) X; + } + else + { + cblas_xerbla(2, "cblas_zgemv","Illegal TransA setting, %d\n", TransA); + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + #endif + if (TransA == CblasConjTrans) + F77_zgemv(F77_TA, &F77_N, &F77_M, ALPHA, A, &F77_lda, x, + &F77_incX, BETA, Y, &F77_incY); + else + F77_zgemv(F77_TA, &F77_N, &F77_M, alpha, A, &F77_lda, x, + &F77_incX, beta, Y, &F77_incY); + + if (TransA == CblasConjTrans) + { + if (x != (double *)X) free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + } + else cblas_xerbla(1, "cblas_zgemv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zgerc.c b/ml/dlib/dlib/external/cblas/cblas_zgerc.c new file mode 100644 index 000000000..2fbbcb028 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zgerc.c @@ -0,0 +1,77 @@ +/* + * cblas_zgerc.c + * The program is a C interface to zgerc. + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda) +{ +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_incX incX + #define F77_incY incy + #define F77_lda lda +#endif + + int n, i, tincy, incy=incY; + double *y=(double *)Y, *yy=(double *)Y, *ty, *st; + + + if (order == CblasColMajor) + { + F77_zgerc( &F77_M, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + } else if (order == CblasRowMajor) + { + if (N > 0) + { + n = N << 1; + y = malloc(n*sizeof(double)); + + ty = y; + if( incY > 0 ) { + i = incY << 1; + tincy = 2; + st= y+n; + } else { + i = incY *(-2); + tincy = -2; + st = y-2; + y +=(n-2); + } + do + { + *y = *yy; + y[1] = -yy[1]; + y += tincy ; + yy += i; + } + while (y != st); + y = ty; + + #ifdef F77_INT + F77_incY = 1; + #else + incy = 1; + #endif + } + else y = (double *) Y; + + F77_zgeru( &F77_N, &F77_M, alpha, y, &F77_incY, X, &F77_incX, A, + &F77_lda); + if(Y!=y) + free(y); + + } else cblas_xerbla(1, "cblas_zgerc", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zgeru.c b/ml/dlib/dlib/external/cblas/cblas_zgeru.c new file mode 100644 index 000000000..56c3ded68 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zgeru.c @@ -0,0 +1,37 @@ +/* + * cblas_zgeru.c + * The program is a C interface to zgeru. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda) +{ +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_M M + #define F77_N N + #define F77_incX incX + #define F77_incY incY + #define F77_lda lda +#endif + + + if (order == CblasColMajor) + { + F77_zgeru( &F77_M, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, A, + &F77_lda); + } + else if (order == CblasRowMajor) + { + F77_zgeru( &F77_N, &F77_M, alpha, Y, &F77_incY, X, &F77_incX, A, + &F77_lda); + } + else cblas_xerbla(1, "cblas_zgeru", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zhbmv.c b/ml/dlib/dlib/external/cblas/cblas_zhbmv.c new file mode 100644 index 000000000..491207d3f --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zhbmv.c @@ -0,0 +1,145 @@ +/* + * cblas_zhbmv.c + * The program is a C interface to zhbmv + * + * Keita Teranishi 5/18/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +#include +#include +void cblas_zhbmv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo,const int N,const int K, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incx + #define F77_incY incY +#endif + int n, i=0, incx=incX; + const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta; + double ALPHA[2],BETA[2]; + int tincY, tincx; + double *x=(double *)X, *y=(double *)Y, *st=0, *tx; + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zhbmv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_zhbmv(F77_UL, &F77_N, &F77_K, alpha, A, &F77_lda, X, + &F77_incX, beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + BETA[0]= *bet; + BETA[1]= -bet[1]; + + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(double)); + + tx = x; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + + if(incY > 0) + tincY = incY; + else + tincY = -incY; + y++; + + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } else + x = (double *) X; + + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zhbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_zhbmv(F77_UL, &F77_N, &F77_K, ALPHA, + A ,&F77_lda, x,&F77_incX, BETA, Y, &F77_incY); + } + else + { + cblas_xerbla(1, "cblas_zhbmv","Illegal Order setting, %d\n", order); + return; + } + if ( order == CblasRowMajor ) + { + if(X!=x) + free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zhemm.c b/ml/dlib/dlib/external/cblas/cblas_zhemm.c new file mode 100644 index 000000000..a31e9ae87 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zhemm.c @@ -0,0 +1,91 @@ +/* + * + * cblas_zhemm.c + * This program is a C interface to zhemm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc) +{ + char SD, UL; +#ifdef F77_CHAR + F77_CHAR F77_SD, F77_UL; +#else + #define F77_SD &SD + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_zhemm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_zhemm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_zhemm(F77_SD, F77_UL, &F77_M, &F77_N, alpha, A, &F77_lda, + B, &F77_ldb, beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_zhemm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_zhemm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_zhemm(F77_SD, F77_UL, &F77_N, &F77_M, alpha, A, + &F77_lda, B, &F77_ldb, beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_zhemm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zhemv.c b/ml/dlib/dlib/external/cblas/cblas_zhemv.c new file mode 100644 index 000000000..c3cdb5958 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zhemv.c @@ -0,0 +1,146 @@ +/* + * cblas_zhemv.c + * The program is a C interface to zhemv + * + * Keita Teranishi 5/18/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zhemv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incx + #define F77_incY incY +#endif + int n, i=0, incx=incX; + const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta; + double ALPHA[2],BETA[2]; + int tincY, tincx; + double *x=(double *)X, *y=(double *)Y, *st=0, *tx; + + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_zhemv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_zhemv(F77_UL, &F77_N, alpha, A, &F77_lda, X, &F77_incX, + beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + BETA[0]= *bet; + BETA[1]= -bet[1]; + + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(double)); + + tx = x; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + + if(incY > 0) + tincY = incY; + else + tincY = -incY; + y++; + + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } else + x = (double *) X; + + + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zhemv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_zhemv(F77_UL, &F77_N, ALPHA, A, &F77_lda, x, &F77_incX, + BETA, Y, &F77_incY); + } + else + { + cblas_xerbla(1, "cblas_zhemv","Illegal Order setting, %d\n", order); + return; + } + if ( order == CblasRowMajor ) + { + if ( X != x ) + free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zher.c b/ml/dlib/dlib/external/cblas/cblas_zher.c new file mode 100644 index 000000000..30453737c --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zher.c @@ -0,0 +1,99 @@ +/* + * cblas_zher.c + * The program is a C interface to zher. + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, const int incX + ,void *A, const int lda) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incx +#endif + int n, i, tincx, incx=incX; + double *x=(double *)X, *xx=(double *)X, *tx, *st; + + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zher","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_zher(F77_UL, &F77_N, &alpha, X, &F77_incX, A, &F77_lda); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zher","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(double)); + tx = x; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + } + else x = (double *) X; + F77_zher(F77_UL, &F77_N, &alpha, x, &F77_incX, A, &F77_lda); + } else cblas_xerbla(1, "cblas_zher", "Illegal Order setting, %d\n", order); + if(X!=x) + free(x); + + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zher2.c b/ml/dlib/dlib/external/cblas/cblas_zher2.c new file mode 100644 index 000000000..8bf0bd733 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zher2.c @@ -0,0 +1,140 @@ +/* + * cblas_zher2.c + * The program is a C interface to zher2. + * + * Keita Teranishi 3/23/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incx + #define F77_incY incy +#endif + int n, i, j, tincx, tincy, incx=incX, incy=incY; + double *x=(double *)X, *xx=(double *)X, *y=(double *)Y, + *yy=(double *)Y, *tx, *ty, *stx, *sty; + + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zher2", "Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_zher2(F77_UL, &F77_N, alpha, X, &F77_incX, + Y, &F77_incY, A, &F77_lda); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zher2", "Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(double)); + y = malloc(n*sizeof(double)); + tx = x; + ty = y; + if( incX > 0 ) { + i = incX << 1 ; + tincx = 2; + stx= x+n; + } else { + i = incX *(-2); + tincx = -2; + stx = x-2; + x +=(n-2); + } + + if( incY > 0 ) { + j = incY << 1; + tincy = 2; + sty= y+n; + } else { + j = incY *(-2); + tincy = -2; + sty = y-2; + y +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != stx); + + do + { + *y = *yy; + y[1] = -yy[1]; + y += tincy ; + yy += j; + } + while (y != sty); + + x=tx; + y=ty; + + #ifdef F77_INT + F77_incX = 1; + F77_incY = 1; + #else + incx = 1; + incy = 1; + #endif + } else + { + x = (double *) X; + y = (double *) Y; + } + F77_zher2(F77_UL, &F77_N, alpha, y, &F77_incY, x, + &F77_incX, A, &F77_lda); + } + else + { + cblas_xerbla(1, "cblas_zher2", "Illegal Order setting, %d\n", order); + return; + } + if(X!=x) + free(x); + if(Y!=y) + free(y); + + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zher2k.c b/ml/dlib/dlib/external/cblas/cblas_zher2k.c new file mode 100644 index 000000000..96bcfe2a5 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zher2k.c @@ -0,0 +1,95 @@ +/* + * + * cblas_zher2k.c + * This program is a C interface to zher2k. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const double beta, + void *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + double ALPHA[2]; + const double *alp=(double *)alpha; + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_zher2k", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_zher2k", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_zher2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(2, "cblas_zher2k", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='C'; + else + { + cblas_xerbla(3, "cblas_zher2k", "Illegal Trans setting, %d\n", Trans); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + F77_zher2k(F77_UL,F77_TR, &F77_N, &F77_K, ALPHA, A, &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else cblas_xerbla(1, "cblas_zher2k", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zherk.c b/ml/dlib/dlib/external/cblas/cblas_zherk.c new file mode 100644 index 000000000..bddef491b --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zherk.c @@ -0,0 +1,90 @@ +/* + * + * cblas_zherk.c + * This program is a C interface to zherk. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const void *A, const int lda, + const double beta, void *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_zherk", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_zherk", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_zherk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, + &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_zherk", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='C'; + else + { + cblas_xerbla(3, "cblas_zherk", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_zherk(F77_UL, F77_TR, &F77_N, &F77_K, &alpha, A, &F77_lda, + &beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_zherk", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zhpmv.c b/ml/dlib/dlib/external/cblas/cblas_zhpmv.c new file mode 100644 index 000000000..1812884fc --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zhpmv.c @@ -0,0 +1,146 @@ +/* + * cblas_zhpmv.c + * The program is a C interface of zhpmv + * + * Keita Teranishi 5/18/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zhpmv(const enum CBLAS_ORDER order, + const enum CBLAS_UPLO Uplo,const int N, + const void *alpha, const void *AP, + const void *X, const int incX, const void *beta, + void *Y, const int incY) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incx + #define F77_incY incY +#endif + int n, i=0, incx=incX; + const double *xx= (double *)X, *alp= (double *)alpha, *bet = (double *)beta; + double ALPHA[2],BETA[2]; + int tincY, tincx; + double *x=(double *)X, *y=(double *)Y, *st=0, *tx; + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zhpmv","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + F77_zhpmv(F77_UL, &F77_N, alpha, AP, X, + &F77_incX, beta, Y, &F77_incY); + } + else if (order == CblasRowMajor) + { + ALPHA[0]= *alp; + ALPHA[1]= -alp[1]; + BETA[0]= *bet; + BETA[1]= -bet[1]; + + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(double)); + + tx = x; + if( incX > 0 ) { + i = incX << 1; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + + + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + + if(incY > 0) + tincY = incY; + else + tincY = -incY; + y++; + + i = tincY << 1; + n = i * N ; + st = y + n; + do { + *y = -(*y); + y += i; + } while(y != st); + y -= n; + } else + x = (double *) X; + + + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zhpmv","Illegal Uplo setting, %d\n", Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_zhpmv(F77_UL, &F77_N, ALPHA, + AP, x, &F77_incX, BETA, Y, &F77_incY); + } + else + { + cblas_xerbla(1, "cblas_zhpmv","Illegal Order setting, %d\n", order); + return; + } + if ( order == CblasRowMajor ) + { + if(X!=x) + free(x); + if (N > 0) + { + do + { + *y = -(*y); + y += i; + } + while (y != st); + } + } + + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zhpr.c b/ml/dlib/dlib/external/cblas/cblas_zhpr.c new file mode 100644 index 000000000..3ed2a8f61 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zhpr.c @@ -0,0 +1,102 @@ +/* + * cblas_zhpr.c + * The program is a C interface to zhpr. + * + * Keita Teranishi 3/23/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, + const int incX, void *A) +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incx +#endif + int n, i, tincx, incx=incX; + double *x=(double *)X, *xx=(double *)X, *tx, *st; + + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zhpr","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_zhpr(F77_UL, &F77_N, &alpha, X, &F77_incX, A); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zhpr","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(double)); + tx = x; + if( incX > 0 ) { + i = incX << 1; + tincx = 2; + st= x+n; + } else { + i = incX *(-2); + tincx = -2; + st = x-2; + x +=(n-2); + } + do + { + *x = *xx; + x[1] = -xx[1]; + x += tincx ; + xx += i; + } + while (x != st); + x=tx; + #ifdef F77_INT + F77_incX = 1; + #else + incx = 1; + #endif + } + else x = (double *) X; + + F77_zhpr(F77_UL, &F77_N, &alpha, x, &F77_incX, A); + + } else + { + cblas_xerbla(1, "cblas_zhpr","Illegal Order setting, %d\n", order); + return; + } + if(X!=x) + free(x); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zhpr2.c b/ml/dlib/dlib/external/cblas/cblas_zhpr2.c new file mode 100644 index 000000000..0793a298a --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zhpr2.c @@ -0,0 +1,137 @@ +/* + * cblas_zhpr2.c + * The program is a C interface to zhpr2. + * + * Keita Teranishi 5/20/98 + * + */ +#include +#include +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const int N,const void *alpha, const void *X, + const int incX,const void *Y, const int incY, void *Ap) + +{ + char UL; +#ifdef F77_CHAR + F77_CHAR F77_UL; +#else + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incx + #define F77_incY incy +#endif + int n, i, j, incx=incX, incy=incY; + double *x=(double *)X, *xx=(double *)X, *y=(double *)Y, + *yy=(double *)Y, *stx, *sty; + + + if (order == CblasColMajor) + { + if (Uplo == CblasLower) UL = 'L'; + else if (Uplo == CblasUpper) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zhpr2","Illegal Uplo setting, %d\n",Uplo ); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + + F77_zhpr2(F77_UL, &F77_N, alpha, X, &F77_incX, Y, &F77_incY, Ap); + + } else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_zhpr2","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + #endif + if (N > 0) + { + n = N << 1; + x = malloc(n*sizeof(double)); + y = malloc(n*sizeof(double)); + stx = x + n; + sty = y + n; + if( incX > 0 ) + i = incX << 1; + else + i = incX *(-2); + + if( incY > 0 ) + j = incY << 1; + else + j = incY *(-2); + do + { + *x = *xx; + x[1] = -xx[1]; + x += 2; + xx += i; + } while (x != stx); + do + { + *y = *yy; + y[1] = -yy[1]; + y += 2; + yy += j; + } + while (y != sty); + x -= n; + y -= n; + + #ifdef F77_INT + if(incX > 0 ) + F77_incX = 1; + else + F77_incX = -1; + + if(incY > 0 ) + F77_incY = 1; + else + F77_incY = -1; + + #else + if(incX > 0 ) + incx = 1; + else + incx = -1; + + if(incY > 0 ) + incy = 1; + else + incy = -1; + #endif + + } else + { + x = (double *) X; + y = (void *) Y; + } + F77_zhpr2(F77_UL, &F77_N, alpha, y, &F77_incY, x, &F77_incX, Ap); + } + else + { + cblas_xerbla(1, "cblas_zhpr2","Illegal Order setting, %d\n", order); + return; + } + if(X!=x) + free(x); + if(Y!=y) + free(y); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zscal.c b/ml/dlib/dlib/external/cblas/cblas_zscal.c new file mode 100644 index 000000000..37b319f38 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zscal.c @@ -0,0 +1,21 @@ +/* + * cblas_zscal.c + * + * The program is a C interface to zscal. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zscal( const int N, const void *alpha, void *X, + const int incX) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + F77_zscal( &F77_N, alpha, X, &F77_incX); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zswap.c b/ml/dlib/dlib/external/cblas/cblas_zswap.c new file mode 100644 index 000000000..dfde2cbd0 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zswap.c @@ -0,0 +1,22 @@ +/* + * cblas_zswap.c + * + * The program is a C interface to zswap. + * + * Written by Keita Teranishi. 2/11/1998 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zswap( const int N, void *X, const int incX, void *Y, + const int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_zswap( &F77_N, X, &F77_incX, Y, &F77_incY); +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zsymm.c b/ml/dlib/dlib/external/cblas/cblas_zsymm.c new file mode 100644 index 000000000..85d5e3f49 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zsymm.c @@ -0,0 +1,91 @@ +/* + * + * cblas_zsymm.c + * This program is a C interface to zsymm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc) +{ + char SD, UL; +#ifdef F77_CHAR + F77_CHAR F77_SD, F77_UL; +#else + #define F77_SD &SD + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_zsymm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_zsymm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_zsymm(F77_SD, F77_UL, &F77_M, &F77_N, alpha, A, &F77_lda, + B, &F77_ldb, beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_zsymm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_zsymm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_SD = C2F_CHAR(&SD); + #endif + + F77_zsymm(F77_SD, F77_UL, &F77_N, &F77_M, alpha, A, &F77_lda, + B, &F77_ldb, beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_zsymm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zsyr2k.c b/ml/dlib/dlib/external/cblas/cblas_zsyr2k.c new file mode 100644 index 000000000..ffac33462 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zsyr2k.c @@ -0,0 +1,93 @@ +/* + * + * cblas_zsyr2k.c + * This program is a C interface to zsyr2k. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_zsyr2k", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_zsyr2k", "Illegal Trans setting, %d\n", Trans); + return; + } + + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_zsyr2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, + B, &F77_ldb, beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_zsyr2k", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='T'; + else + { + cblas_xerbla(3, "cblas_zsyr2k", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_zsyr2k(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, B, &F77_ldb, beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_zsyr2k", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_zsyrk.c b/ml/dlib/dlib/external/cblas/cblas_zsyrk.c new file mode 100644 index 000000000..45796074f --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_zsyrk.c @@ -0,0 +1,92 @@ +/* + * + * cblas_zsyrk.c + * This program is a C interface to zsyrk. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc) +{ + char UL, TR; +#ifdef F77_CHAR + F77_CHAR F77_TR, F77_UL; +#else + #define F77_TR &TR + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldc ldc +#endif + + + if( Order == CblasColMajor ) + { + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_zsyrk", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( Trans == CblasTrans) TR ='T'; + else if ( Trans == CblasConjTrans ) TR='C'; + else if ( Trans == CblasNoTrans ) TR='N'; + else + { + cblas_xerbla(3, "cblas_zsyrk", "Illegal Trans setting, %d\n", Trans); + return; + } + + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_zsyrk(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, + beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_zsyrk", "Illegal Uplo setting, %d\n", Uplo); + return; + } + if( Trans == CblasTrans) TR ='N'; + else if ( Trans == CblasConjTrans ) TR='N'; + else if ( Trans == CblasNoTrans ) TR='T'; + else + { + cblas_xerbla(3, "cblas_zsyrk", "Illegal Trans setting, %d\n", Trans); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TR = C2F_CHAR(&TR); + #endif + + F77_zsyrk(F77_UL, F77_TR, &F77_N, &F77_K, alpha, A, &F77_lda, + beta, C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_zsyrk", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ztbmv.c b/ml/dlib/dlib/external/cblas/cblas_ztbmv.c new file mode 100644 index 000000000..916774d27 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ztbmv.c @@ -0,0 +1,139 @@ +/* + * cblas_ztbmv.c + * The program is a C interface to ztbmv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX +#endif + int n, i=0, tincX; + double *st=0, *x=(double *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ztbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ztbmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztbmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ztbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ztbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if(incX > 0) + tincX = incX; + else + tincX = -incX; + i = tincX << 1; + n = i * N; + x++; + st = x + n; + do + { + *x = -(*x); + x+= i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ztbmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztbmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ztbmv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ztbmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ztbsv.c b/ml/dlib/dlib/external/cblas/cblas_ztbsv.c new file mode 100644 index 000000000..cc5d3f73f --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ztbsv.c @@ -0,0 +1,143 @@ +/* + * cblas_ztbsv.c + * The program is a C interface to ztbsv. + * + * Keita Teranishi 3/23/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_K=K, F77_incX=incX; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_incX incX +#endif + int n, i=0, tincX; + double *st=0,*x=(double *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ztbsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ztbsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztbsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ztbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ztbsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if ( incX > 0 ) + tincX = incX; + else + tincX = -incX; + + n = N*2*(tincX); + + x++; + + st=x+n; + + i = tincX << 1; + do + { + *x = -(*x); + x+=i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ztbsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztbsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ztbsv( F77_UL, F77_TA, F77_DI, &F77_N, &F77_K, A, &F77_lda, X, + &F77_incX); + + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x+= i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ztbsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ztpmv.c b/ml/dlib/dlib/external/cblas/cblas_ztpmv.c new file mode 100644 index 000000000..2e7949a25 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ztpmv.c @@ -0,0 +1,133 @@ +/* + * cblas_ztpmv.c + * The program is a C interface to ztpmv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + int n, i=0, tincX; + double *st=0,*x=(double *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ztpmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ztpmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztpmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ztpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ztpmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if(incX > 0) + tincX = incX; + else + tincX = -incX; + i = tincX << 1; + n = i * N; + x++; + st = x + n; + do + { + *x = -(*x); + x += i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ztpmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztpmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ztpmv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX); + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ztpmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ztpsv.c b/ml/dlib/dlib/external/cblas/cblas_ztpsv.c new file mode 100644 index 000000000..c41d02016 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ztpsv.c @@ -0,0 +1,138 @@ +/* + * cblas_ztpsv.c + * The program is a C interface to ztpsv. + * + * Keita Teranishi 3/23/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX; +#else + #define F77_N N + #define F77_incX incX +#endif + int n, i=0, tincX; + double *st=0, *x=(double*)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ztpsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ztpsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztpsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ztpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X, &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ztpsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if ( incX > 0 ) + tincX = incX; + else + tincX = -incX; + + n = N*2*(tincX); + + x++; + + st=x+n; + + i = tincX << 1; + do + { + *x = -(*x); + x+=i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ztpsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztpsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ztpsv( F77_UL, F77_TA, F77_DI, &F77_N, Ap, X,&F77_incX); + + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ztpsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ztrmm.c b/ml/dlib/dlib/external/cblas/cblas_ztrmm.c new file mode 100644 index 000000000..4e76377d9 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ztrmm.c @@ -0,0 +1,126 @@ +/* + * + * cblas_ztrmm.c + * This program is a C interface to ztrmm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb) +{ + char UL, TA, SD, DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_SD &SD + #define F77_DI &DI +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb +#endif + + + if( Order == CblasColMajor ) + { + if( Side == CblasRight ) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_ztrmm", "Illegal Side setting, %d\n", Side); + return; + } + if( Uplo == CblasUpper ) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_ztrmm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans ) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_ztrmm", "Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_ztrmm", "Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ztrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, alpha, A, &F77_lda, B, &F77_ldb); + } else if (Order == CblasRowMajor) + { + if( Side == CblasRight ) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_ztrmm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper ) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_ztrmm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans ) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_ztrmm", "Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_ztrmm", "Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ztrmm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, alpha, A, &F77_lda, B, &F77_ldb); + } + else cblas_xerbla(1, "cblas_ztrmm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ztrmv.c b/ml/dlib/dlib/external/cblas/cblas_ztrmv.c new file mode 100644 index 000000000..65e760529 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ztrmv.c @@ -0,0 +1,137 @@ +/* + * cblas_ztrmv.c + * The program is a C interface to ztrmv. + * + * Keita Teranishi 5/20/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX) + +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX +#endif + int n, i=0, tincX; + double *st=0,*x=(double *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ztrmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ztrmv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztrmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ztrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ztrmv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if(incX > 0) + tincX = incX; + else + tincX = -incX; + i = tincX << 1; + n = i * N; + x++; + st = x + n; + do + { + *x = -(*x); + x += i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ztrmv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztrmv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ztrmv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ztrmv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ztrsm.c b/ml/dlib/dlib/external/cblas/cblas_ztrsm.c new file mode 100644 index 000000000..7540147c8 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ztrsm.c @@ -0,0 +1,132 @@ +/* + * + * cblas_ztrsm.c + * This program is a C interface to ztrsm. + * Written by Keita Teranishi + * 4/8/1998 + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb) +{ + char UL, TA, SD, DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_SD, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_SD &SD + #define F77_DI &DI +#endif + +#ifdef F77_INT + F77_INT F77_M=M, F77_N=N, F77_lda=lda, F77_ldb=ldb; +#else + #define F77_M M + #define F77_N N + #define F77_lda lda + #define F77_ldb ldb +#endif + + + if( Order == CblasColMajor ) + { + + if( Side == CblasRight) SD='R'; + else if ( Side == CblasLeft ) SD='L'; + else + { + cblas_xerbla(2, "cblas_ztrsm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(3, "cblas_ztrsm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_ztrsm", "Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_ztrsm", "Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + F77_ztrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_M, &F77_N, alpha, A, + &F77_lda, B, &F77_ldb); + } else if (Order == CblasRowMajor) + { + + if( Side == CblasRight) SD='L'; + else if ( Side == CblasLeft ) SD='R'; + else + { + cblas_xerbla(2, "cblas_ztrsm", "Illegal Side setting, %d\n", Side); + return; + } + + if( Uplo == CblasUpper) UL='L'; + else if ( Uplo == CblasLower ) UL='U'; + else + { + cblas_xerbla(3, "cblas_ztrsm", "Illegal Uplo setting, %d\n", Uplo); + return; + } + + if( TransA == CblasTrans) TA ='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_ztrsm", "Illegal Trans setting, %d\n", TransA); + return; + } + + if( Diag == CblasUnit ) DI='U'; + else if ( Diag == CblasNonUnit ) DI='N'; + else + { + cblas_xerbla(5, "cblas_ztrsm", "Illegal Diag setting, %d\n", Diag); + return; + } + + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_SD = C2F_CHAR(&SD); + F77_DI = C2F_CHAR(&DI); + #endif + + + F77_ztrsm(F77_SD, F77_UL, F77_TA, F77_DI, &F77_N, &F77_M, alpha, A, + &F77_lda, B, &F77_ldb); + } + else cblas_xerbla(1, "cblas_ztrsm", "Illegal Order setting, %d\n", Order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cblas_ztrsv.c b/ml/dlib/dlib/external/cblas/cblas_ztrsv.c new file mode 100644 index 000000000..07e2653b4 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cblas_ztrsv.c @@ -0,0 +1,137 @@ +/* + * cblas_ztrsv.c + * The program is a C interface to ztrsv. + * + * Keita Teranishi 3/23/98 + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX) +{ + char TA; + char UL; + char DI; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_UL, F77_DI; +#else + #define F77_TA &TA + #define F77_UL &UL + #define F77_DI &DI +#endif +#ifdef F77_INT + F77_INT F77_N=N, F77_lda=lda, F77_incX=incX; +#else + #define F77_N N + #define F77_lda lda + #define F77_incX incX +#endif + int n, i=0, tincX; + double *st=0,*x=(double *)X; + + if (order == CblasColMajor) + { + if (Uplo == CblasUpper) UL = 'U'; + else if (Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_ztrsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + if (TransA == CblasNoTrans) TA = 'N'; + else if (TransA == CblasTrans) TA = 'T'; + else if (TransA == CblasConjTrans) TA = 'C'; + else + { + cblas_xerbla(3, "cblas_ztrsv","Illegal TransA setting, %d\n", TransA); + return; + } + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztrsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ztrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + } + else if (order == CblasRowMajor) + { + if (Uplo == CblasUpper) UL = 'L'; + else if (Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_ztrsv","Illegal Uplo setting, %d\n", Uplo); + return; + } + + if (TransA == CblasNoTrans) TA = 'T'; + else if (TransA == CblasTrans) TA = 'N'; + else if (TransA == CblasConjTrans) + { + TA = 'N'; + if ( N > 0) + { + if ( incX > 0 ) + tincX = incX; + else + tincX = -incX; + + n = N*2*(tincX); + x++; + st=x+n; + i = tincX << 1; + do + { + *x = -(*x); + x+=i; + } + while (x != st); + x -= n; + } + } + else + { + cblas_xerbla(3, "cblas_ztrsv","Illegal TransA setting, %d\n", TransA); + return; + } + + if (Diag == CblasUnit) DI = 'U'; + else if (Diag == CblasNonUnit) DI = 'N'; + else + { + cblas_xerbla(4, "cblas_ztrsv","Illegal Diag setting, %d\n", Diag); + return; + } + #ifdef F77_CHAR + F77_UL = C2F_CHAR(&UL); + F77_TA = C2F_CHAR(&TA); + F77_DI = C2F_CHAR(&DI); + #endif + F77_ztrsv( F77_UL, F77_TA, F77_DI, &F77_N, A, &F77_lda, X, + &F77_incX); + if (TransA == CblasConjTrans) + { + if (N > 0) + { + do + { + *x = -(*x); + x += i; + } + while (x != st); + } + } + } + else cblas_xerbla(1, "cblas_ztrsv", "Illegal Order setting, %d\n", order); + return; +} diff --git a/ml/dlib/dlib/external/cblas/cdotcsub.f b/ml/dlib/dlib/external/cblas/cdotcsub.f new file mode 100644 index 000000000..f97d7159e --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cdotcsub.f @@ -0,0 +1,15 @@ +c cdotcsub.f +c +c The program is a fortran wrapper for cdotc. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine cdotcsub(n,x,incx,y,incy,dotc) +c + external cdotc + complex cdotc,dotc + integer n,incx,incy + complex x(*),y(*) +c + dotc=cdotc(n,x,incx,y,incy) + return + end diff --git a/ml/dlib/dlib/external/cblas/cdotusub.f b/ml/dlib/dlib/external/cblas/cdotusub.f new file mode 100644 index 000000000..5107c0402 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/cdotusub.f @@ -0,0 +1,15 @@ +c cdotusub.f +c +c The program is a fortran wrapper for cdotu. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine cdotusub(n,x,incx,y,incy,dotu) +c + external cdotu + complex cdotu,dotu + integer n,incx,incy + complex x(*),y(*) +c + dotu=cdotu(n,x,incx,y,incy) + return + end diff --git a/ml/dlib/dlib/external/cblas/dasumsub.f b/ml/dlib/dlib/external/cblas/dasumsub.f new file mode 100644 index 000000000..3d64d17e6 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/dasumsub.f @@ -0,0 +1,15 @@ +c dasumsun.f +c +c The program is a fortran wrapper for dasum.. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine dasumsub(n,x,incx,asum) +c + external dasum + double precision dasum,asum + integer n,incx + double precision x(*) +c + asum=dasum(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/ddotsub.f b/ml/dlib/dlib/external/cblas/ddotsub.f new file mode 100644 index 000000000..205f3b46f --- /dev/null +++ b/ml/dlib/dlib/external/cblas/ddotsub.f @@ -0,0 +1,15 @@ +c ddotsub.f +c +c The program is a fortran wrapper for ddot. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine ddotsub(n,x,incx,y,incy,dot) +c + external ddot + double precision ddot + integer n,incx,incy + double precision x(*),y(*),dot +c + dot=ddot(n,x,incx,y,incy) + return + end diff --git a/ml/dlib/dlib/external/cblas/dnrm2sub.f b/ml/dlib/dlib/external/cblas/dnrm2sub.f new file mode 100644 index 000000000..88f17db8b --- /dev/null +++ b/ml/dlib/dlib/external/cblas/dnrm2sub.f @@ -0,0 +1,15 @@ +c dnrm2sub.f +c +c The program is a fortran wrapper for dnrm2. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine dnrm2sub(n,x,incx,nrm2) +c + external dnrm2 + double precision dnrm2,nrm2 + integer n,incx + double precision x(*) +c + nrm2=dnrm2(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/dsdotsub.f b/ml/dlib/dlib/external/cblas/dsdotsub.f new file mode 100644 index 000000000..e7e872c9e --- /dev/null +++ b/ml/dlib/dlib/external/cblas/dsdotsub.f @@ -0,0 +1,15 @@ +c dsdotsub.f +c +c The program is a fortran wrapper for dsdot. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine dsdotsub(n,x,incx,y,incy,dot) +c + external dsdot + double precision dsdot,dot + integer n,incx,incy + real x(*),y(*) +c + dot=dsdot(n,x,incx,y,incy) + return + end diff --git a/ml/dlib/dlib/external/cblas/dzasumsub.f b/ml/dlib/dlib/external/cblas/dzasumsub.f new file mode 100644 index 000000000..9aaf16387 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/dzasumsub.f @@ -0,0 +1,15 @@ +c dzasumsub.f +c +c The program is a fortran wrapper for dzasum. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine dzasumsub(n,x,incx,asum) +c + external dzasum + double precision dzasum,asum + integer n,incx + double complex x(*) +c + asum=dzasum(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/dznrm2sub.f b/ml/dlib/dlib/external/cblas/dznrm2sub.f new file mode 100644 index 000000000..45dc599f8 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/dznrm2sub.f @@ -0,0 +1,15 @@ +c dznrm2sub.f +c +c The program is a fortran wrapper for dznrm2. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine dznrm2sub(n,x,incx,nrm2) +c + external dznrm2 + double precision dznrm2,nrm2 + integer n,incx + double complex x(*) +c + nrm2=dznrm2(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/icamaxsub.f b/ml/dlib/dlib/external/cblas/icamaxsub.f new file mode 100644 index 000000000..3f47071eb --- /dev/null +++ b/ml/dlib/dlib/external/cblas/icamaxsub.f @@ -0,0 +1,15 @@ +c icamaxsub.f +c +c The program is a fortran wrapper for icamax. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine icamaxsub(n,x,incx,iamax) +c + external icamax + integer icamax,iamax + integer n,incx + complex x(*) +c + iamax=icamax(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/idamaxsub.f b/ml/dlib/dlib/external/cblas/idamaxsub.f new file mode 100644 index 000000000..3c1ee5c32 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/idamaxsub.f @@ -0,0 +1,15 @@ +c icamaxsub.f +c +c The program is a fortran wrapper for idamax. +c Witten by Keita Teranishi. 2/22/1998 +c + subroutine idamaxsub(n,x,incx,iamax) +c + external idamax + integer idamax,iamax + integer n,incx + double precision x(*) +c + iamax=idamax(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/isamaxsub.f b/ml/dlib/dlib/external/cblas/isamaxsub.f new file mode 100644 index 000000000..0faf42fde --- /dev/null +++ b/ml/dlib/dlib/external/cblas/isamaxsub.f @@ -0,0 +1,15 @@ +c isamaxsub.f +c +c The program is a fortran wrapper for isamax. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine isamaxsub(n,x,incx,iamax) +c + external isamax + integer isamax,iamax + integer n,incx + real x(*) +c + iamax=isamax(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/izamaxsub.f b/ml/dlib/dlib/external/cblas/izamaxsub.f new file mode 100644 index 000000000..5b15855a7 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/izamaxsub.f @@ -0,0 +1,15 @@ +c izamaxsub.f +c +c The program is a fortran wrapper for izamax. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine izamaxsub(n,x,incx,iamax) +c + external izamax + integer izamax,iamax + integer n,incx + double complex x(*) +c + iamax=izamax(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/sasumsub.f b/ml/dlib/dlib/external/cblas/sasumsub.f new file mode 100644 index 000000000..955f11e8d --- /dev/null +++ b/ml/dlib/dlib/external/cblas/sasumsub.f @@ -0,0 +1,15 @@ +c sasumsub.f +c +c The program is a fortran wrapper for sasum. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine sasumsub(n,x,incx,asum) +c + external sasum + real sasum,asum + integer n,incx + real x(*) +c + asum=sasum(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/scasumsub.f b/ml/dlib/dlib/external/cblas/scasumsub.f new file mode 100644 index 000000000..077ace670 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/scasumsub.f @@ -0,0 +1,15 @@ +c scasumsub.f +c +c The program is a fortran wrapper for scasum. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine scasumsub(n,x,incx,asum) +c + external scasum + real scasum,asum + integer n,incx + complex x(*) +c + asum=scasum(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/scnrm2sub.f b/ml/dlib/dlib/external/cblas/scnrm2sub.f new file mode 100644 index 000000000..7242c9742 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/scnrm2sub.f @@ -0,0 +1,15 @@ +c scnrm2sub.f +c +c The program is a fortran wrapper for scnrm2. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine scnrm2sub(n,x,incx,nrm2) +c + external scnrm2 + real scnrm2,nrm2 + integer n,incx + complex x(*) +c + nrm2=scnrm2(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/sdotsub.f b/ml/dlib/dlib/external/cblas/sdotsub.f new file mode 100644 index 000000000..e1af3c97b --- /dev/null +++ b/ml/dlib/dlib/external/cblas/sdotsub.f @@ -0,0 +1,15 @@ +c sdotsub.f +c +c The program is a fortran wrapper for sdot. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine sdotsub(n,x,incx,y,incy,dot) +c + external sdot + real sdot + integer n,incx,incy + real x(*),y(*),dot +c + dot=sdot(n,x,incx,y,incy) + return + end diff --git a/ml/dlib/dlib/external/cblas/sdsdotsub.f b/ml/dlib/dlib/external/cblas/sdsdotsub.f new file mode 100644 index 000000000..80008e9ce --- /dev/null +++ b/ml/dlib/dlib/external/cblas/sdsdotsub.f @@ -0,0 +1,15 @@ +c sdsdotsub.f +c +c The program is a fortran wrapper for sdsdot. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine sdsdotsub(n,x,incx,y,incy,dot) +c + external sdsdot + real sdsdot,dot + integer n,incx,incy + real x(*),y(*) +c + dot=sdsdot(n,x,incx,y,incy) + return + end diff --git a/ml/dlib/dlib/external/cblas/snrm2sub.f b/ml/dlib/dlib/external/cblas/snrm2sub.f new file mode 100644 index 000000000..871a6e49f --- /dev/null +++ b/ml/dlib/dlib/external/cblas/snrm2sub.f @@ -0,0 +1,15 @@ +c snrm2sub.f +c +c The program is a fortran wrapper for snrm2. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine snrm2sub(n,x,incx,nrm2) +c + external snrm2 + real snrm2,nrm2 + integer n,incx + real x(*) +c + nrm2=snrm2(n,x,incx) + return + end diff --git a/ml/dlib/dlib/external/cblas/zdotcsub.f b/ml/dlib/dlib/external/cblas/zdotcsub.f new file mode 100644 index 000000000..8d483c895 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/zdotcsub.f @@ -0,0 +1,15 @@ +c zdotcsub.f +c +c The program is a fortran wrapper for zdotc. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine zdotcsub(n,x,incx,y,incy,dotc) +c + external zdotc + double complex zdotc,dotc + integer n,incx,incy + double complex x(*),y(*) +c + dotc=zdotc(n,x,incx,y,incy) + return + end diff --git a/ml/dlib/dlib/external/cblas/zdotusub.f b/ml/dlib/dlib/external/cblas/zdotusub.f new file mode 100644 index 000000000..23f32dec3 --- /dev/null +++ b/ml/dlib/dlib/external/cblas/zdotusub.f @@ -0,0 +1,15 @@ +c zdotusub.f +c +c The program is a fortran wrapper for zdotu. +c Witten by Keita Teranishi. 2/11/1998 +c + subroutine zdotusub(n,x,incx,y,incy,dotu) +c + external zdotu + double complex zdotu,dotu + integer n,incx,incy + double complex x(*),y(*) +c + dotu=zdotu(n,x,incx,y,incy) + return + end diff --git a/ml/dlib/dlib/external/libjpeg/README b/ml/dlib/dlib/external/libjpeg/README new file mode 100644 index 000000000..86cc20669 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/README @@ -0,0 +1,385 @@ +The Independent JPEG Group's JPEG software +========================================== + +README for release 6b of 27-Mar-1998 +==================================== + +This distribution contains the sixth public release of the Independent JPEG +Group's free JPEG software. You are welcome to redistribute this software and +to use it for any purpose, subject to the conditions under LEGAL ISSUES, below. + +Serious users of this software (particularly those incorporating it into +larger programs) should contact IJG at jpeg-info@uunet.uu.net to be added to +our electronic mailing list. Mailing list members are notified of updates +and have a chance to participate in technical discussions, etc. + +This software is the work of Tom Lane, Philip Gladstone, Jim Boucher, +Lee Crocker, Julian Minguillon, Luis Ortiz, George Phillips, Davide Rossi, +Guido Vollbeding, Ge' Weijers, and other members of the Independent JPEG +Group. + +IJG is not affiliated with the official ISO JPEG standards committee. + + +DOCUMENTATION ROADMAP +===================== + +This file contains the following sections: + +OVERVIEW General description of JPEG and the IJG software. +LEGAL ISSUES Copyright, lack of warranty, terms of distribution. +REFERENCES Where to learn more about JPEG. +ARCHIVE LOCATIONS Where to find newer versions of this software. +RELATED SOFTWARE Other stuff you should get. +FILE FORMAT WARS Software *not* to get. +TO DO Plans for future IJG releases. + +Other documentation files in the distribution are: + +User documentation: + install.doc How to configure and install the IJG software. + usage.doc Usage instructions for cjpeg, djpeg, jpegtran, + rdjpgcom, and wrjpgcom. + *.1 Unix-style man pages for programs (same info as usage.doc). + wizard.doc Advanced usage instructions for JPEG wizards only. + change.log Version-to-version change highlights. +Programmer and internal documentation: + libjpeg.doc How to use the JPEG library in your own programs. + example.c Sample code for calling the JPEG library. + structure.doc Overview of the JPEG library's internal structure. + filelist.doc Road map of IJG files. + coderules.doc Coding style rules --- please read if you contribute code. + +Please read at least the files install.doc and usage.doc. Useful information +can also be found in the JPEG FAQ (Frequently Asked Questions) article. See +ARCHIVE LOCATIONS below to find out where to obtain the FAQ article. + +If you want to understand how the JPEG code works, we suggest reading one or +more of the REFERENCES, then looking at the documentation files (in roughly +the order listed) before diving into the code. + + +OVERVIEW +======== + +This package contains C software to implement JPEG image compression and +decompression. JPEG (pronounced "jay-peg") is a standardized compression +method for full-color and gray-scale images. JPEG is intended for compressing +"real-world" scenes; line drawings, cartoons and other non-realistic images +are not its strong suit. JPEG is lossy, meaning that the output image is not +exactly identical to the input image. Hence you must not use JPEG if you +have to have identical output bits. However, on typical photographic images, +very good compression levels can be obtained with no visible change, and +remarkably high compression levels are possible if you can tolerate a +low-quality image. For more details, see the references, or just experiment +with various compression settings. + +This software implements JPEG baseline, extended-sequential, and progressive +compression processes. Provision is made for supporting all variants of these +processes, although some uncommon parameter settings aren't implemented yet. +For legal reasons, we are not distributing code for the arithmetic-coding +variants of JPEG; see LEGAL ISSUES. We have made no provision for supporting +the hierarchical or lossless processes defined in the standard. + +We provide a set of library routines for reading and writing JPEG image files, +plus two sample applications "cjpeg" and "djpeg", which use the library to +perform conversion between JPEG and some other popular image file formats. +The library is intended to be reused in other applications. + +In order to support file conversion and viewing software, we have included +considerable functionality beyond the bare JPEG coding/decoding capability; +for example, the color quantization modules are not strictly part of JPEG +decoding, but they are essential for output to colormapped file formats or +colormapped displays. These extra functions can be compiled out of the +library if not required for a particular application. We have also included +"jpegtran", a utility for lossless transcoding between different JPEG +processes, and "rdjpgcom" and "wrjpgcom", two simple applications for +inserting and extracting textual comments in JFIF files. + +The emphasis in designing this software has been on achieving portability and +flexibility, while also making it fast enough to be useful. In particular, +the software is not intended to be read as a tutorial on JPEG. (See the +REFERENCES section for introductory material.) Rather, it is intended to +be reliable, portable, industrial-strength code. We do not claim to have +achieved that goal in every aspect of the software, but we strive for it. + +We welcome the use of this software as a component of commercial products. +No royalty is required, but we do ask for an acknowledgement in product +documentation, as described under LEGAL ISSUES. + + +LEGAL ISSUES +============ + +In plain English: + +1. We don't promise that this software works. (But if you find any bugs, + please let us know!) +2. You can use this software for whatever you want. You don't have to pay us. +3. You may not pretend that you wrote this software. If you use it in a + program, you must acknowledge somewhere in your documentation that + you've used the IJG code. + +In legalese: + +The authors make NO WARRANTY or representation, either express or implied, +with respect to this software, its quality, accuracy, merchantability, or +fitness for a particular purpose. This software is provided "AS IS", and you, +its user, assume the entire risk as to its quality and accuracy. + +This software is copyright (C) 1991-1998, Thomas G. Lane. +All Rights Reserved except as specified below. + +Permission is hereby granted to use, copy, modify, and distribute this +software (or portions thereof) for any purpose, without fee, subject to these +conditions: +(1) If any part of the source code for this software is distributed, then this +README file must be included, with this copyright and no-warranty notice +unaltered; and any additions, deletions, or changes to the original files +must be clearly indicated in accompanying documentation. +(2) If only executable code is distributed, then the accompanying +documentation must state that "this software is based in part on the work of +the Independent JPEG Group". +(3) Permission for use of this software is granted only if the user accepts +full responsibility for any undesirable consequences; the authors accept +NO LIABILITY for damages of any kind. + +These conditions apply to any software derived from or based on the IJG code, +not just to the unmodified library. If you use our work, you ought to +acknowledge us. + +Permission is NOT granted for the use of any IJG author's name or company name +in advertising or publicity relating to this software or products derived from +it. This software may be referred to only as "the Independent JPEG Group's +software". + +We specifically permit and encourage the use of this software as the basis of +commercial products, provided that all warranty or liability claims are +assumed by the product vendor. + + +ansi2knr.c is included in this distribution by permission of L. Peter Deutsch, +sole proprietor of its copyright holder, Aladdin Enterprises of Menlo Park, CA. +ansi2knr.c is NOT covered by the above copyright and conditions, but instead +by the usual distribution terms of the Free Software Foundation; principally, +that you must include source code if you redistribute it. (See the file +ansi2knr.c for full details.) However, since ansi2knr.c is not needed as part +of any program generated from the IJG code, this does not limit you more than +the foregoing paragraphs do. + +The Unix configuration script "configure" was produced with GNU Autoconf. +It is copyright by the Free Software Foundation but is freely distributable. +The same holds for its supporting scripts (config.guess, config.sub, +ltconfig, ltmain.sh). Another support script, install-sh, is copyright +by M.I.T. but is also freely distributable. + +It appears that the arithmetic coding option of the JPEG spec is covered by +patents owned by IBM, AT&T, and Mitsubishi. Hence arithmetic coding cannot +legally be used without obtaining one or more licenses. For this reason, +support for arithmetic coding has been removed from the free JPEG software. +(Since arithmetic coding provides only a marginal gain over the unpatented +Huffman mode, it is unlikely that very many implementations will support it.) +So far as we are aware, there are no patent restrictions on the remaining +code. + +The IJG distribution formerly included code to read and write GIF files. +To avoid entanglement with the Unisys LZW patent, GIF reading support has +been removed altogether, and the GIF writer has been simplified to produce +"uncompressed GIFs". This technique does not use the LZW algorithm; the +resulting GIF files are larger than usual, but are readable by all standard +GIF decoders. + +We are required to state that + "The Graphics Interchange Format(c) is the Copyright property of + CompuServe Incorporated. GIF(sm) is a Service Mark property of + CompuServe Incorporated." + + +REFERENCES +========== + +We highly recommend reading one or more of these references before trying to +understand the innards of the JPEG software. + +The best short technical introduction to the JPEG compression algorithm is + Wallace, Gregory K. "The JPEG Still Picture Compression Standard", + Communications of the ACM, April 1991 (vol. 34 no. 4), pp. 30-44. +(Adjacent articles in that issue discuss MPEG motion picture compression, +applications of JPEG, and related topics.) If you don't have the CACM issue +handy, a PostScript file containing a revised version of Wallace's article is +available at ftp://ftp.uu.net/graphics/jpeg/wallace.ps.gz. The file (actually +a preprint for an article that appeared in IEEE Trans. Consumer Electronics) +omits the sample images that appeared in CACM, but it includes corrections +and some added material. Note: the Wallace article is copyright ACM and IEEE, +and it may not be used for commercial purposes. + +A somewhat less technical, more leisurely introduction to JPEG can be found in +"The Data Compression Book" by Mark Nelson and Jean-loup Gailly, published by +M&T Books (New York), 2nd ed. 1996, ISBN 1-55851-434-1. This book provides +good explanations and example C code for a multitude of compression methods +including JPEG. It is an excellent source if you are comfortable reading C +code but don't know much about data compression in general. The book's JPEG +sample code is far from industrial-strength, but when you are ready to look +at a full implementation, you've got one here... + +The best full description of JPEG is the textbook "JPEG Still Image Data +Compression Standard" by William B. Pennebaker and Joan L. Mitchell, published +by Van Nostrand Reinhold, 1993, ISBN 0-442-01272-1. Price US$59.95, 638 pp. +The book includes the complete text of the ISO JPEG standards (DIS 10918-1 +and draft DIS 10918-2). This is by far the most complete exposition of JPEG +in existence, and we highly recommend it. + +The JPEG standard itself is not available electronically; you must order a +paper copy through ISO or ITU. (Unless you feel a need to own a certified +official copy, we recommend buying the Pennebaker and Mitchell book instead; +it's much cheaper and includes a great deal of useful explanatory material.) +In the USA, copies of the standard may be ordered from ANSI Sales at (212) +642-4900, or from Global Engineering Documents at (800) 854-7179. (ANSI +doesn't take credit card orders, but Global does.) It's not cheap: as of +1992, ANSI was charging $95 for Part 1 and $47 for Part 2, plus 7% +shipping/handling. The standard is divided into two parts, Part 1 being the +actual specification, while Part 2 covers compliance testing methods. Part 1 +is titled "Digital Compression and Coding of Continuous-tone Still Images, +Part 1: Requirements and guidelines" and has document numbers ISO/IEC IS +10918-1, ITU-T T.81. Part 2 is titled "Digital Compression and Coding of +Continuous-tone Still Images, Part 2: Compliance testing" and has document +numbers ISO/IEC IS 10918-2, ITU-T T.83. + +Some extensions to the original JPEG standard are defined in JPEG Part 3, +a newer ISO standard numbered ISO/IEC IS 10918-3 and ITU-T T.84. IJG +currently does not support any Part 3 extensions. + +The JPEG standard does not specify all details of an interchangeable file +format. For the omitted details we follow the "JFIF" conventions, revision +1.02. A copy of the JFIF spec is available from: + Literature Department + C-Cube Microsystems, Inc. + 1778 McCarthy Blvd. + Milpitas, CA 95035 + phone (408) 944-6300, fax (408) 944-6314 +A PostScript version of this document is available by FTP at +ftp://ftp.uu.net/graphics/jpeg/jfif.ps.gz. There is also a plain text +version at ftp://ftp.uu.net/graphics/jpeg/jfif.txt.gz, but it is missing +the figures. + +The TIFF 6.0 file format specification can be obtained by FTP from +ftp://ftp.sgi.com/graphics/tiff/TIFF6.ps.gz. The JPEG incorporation scheme +found in the TIFF 6.0 spec of 3-June-92 has a number of serious problems. +IJG does not recommend use of the TIFF 6.0 design (TIFF Compression tag 6). +Instead, we recommend the JPEG design proposed by TIFF Technical Note #2 +(Compression tag 7). Copies of this Note can be obtained from ftp.sgi.com or +from ftp://ftp.uu.net/graphics/jpeg/. It is expected that the next revision +of the TIFF spec will replace the 6.0 JPEG design with the Note's design. +Although IJG's own code does not support TIFF/JPEG, the free libtiff library +uses our library to implement TIFF/JPEG per the Note. libtiff is available +from ftp://ftp.sgi.com/graphics/tiff/. + + +ARCHIVE LOCATIONS +================= + +The "official" archive site for this software is ftp.uu.net (Internet +address 192.48.96.9). The most recent released version can always be found +there in directory graphics/jpeg. This particular version will be archived +as ftp://ftp.uu.net/graphics/jpeg/jpegsrc.v6b.tar.gz. If you don't have +direct Internet access, UUNET's archives are also available via UUCP; contact +help@uunet.uu.net for information on retrieving files that way. + +Numerous Internet sites maintain copies of the UUNET files. However, only +ftp.uu.net is guaranteed to have the latest official version. + +You can also obtain this software in DOS-compatible "zip" archive format from +the SimTel archives (ftp://ftp.simtel.net/pub/simtelnet/msdos/graphics/), or +on CompuServe in the Graphics Support forum (GO CIS:GRAPHSUP), library 12 +"JPEG Tools". Again, these versions may sometimes lag behind the ftp.uu.net +release. + +The JPEG FAQ (Frequently Asked Questions) article is a useful source of +general information about JPEG. It is updated constantly and therefore is +not included in this distribution. The FAQ is posted every two weeks to +Usenet newsgroups comp.graphics.misc, news.answers, and other groups. +It is available on the World Wide Web at http://www.faqs.org/faqs/jpeg-faq/ +and other news.answers archive sites, including the official news.answers +archive at rtfm.mit.edu: ftp://rtfm.mit.edu/pub/usenet/news.answers/jpeg-faq/. +If you don't have Web or FTP access, send e-mail to mail-server@rtfm.mit.edu +with body + send usenet/news.answers/jpeg-faq/part1 + send usenet/news.answers/jpeg-faq/part2 + + +RELATED SOFTWARE +================ + +Numerous viewing and image manipulation programs now support JPEG. (Quite a +few of them use this library to do so.) The JPEG FAQ described above lists +some of the more popular free and shareware viewers, and tells where to +obtain them on Internet. + +If you are on a Unix machine, we highly recommend Jef Poskanzer's free +PBMPLUS software, which provides many useful operations on PPM-format image +files. In particular, it can convert PPM images to and from a wide range of +other formats, thus making cjpeg/djpeg considerably more useful. The latest +version is distributed by the NetPBM group, and is available from numerous +sites, notably ftp://wuarchive.wustl.edu/graphics/graphics/packages/NetPBM/. +Unfortunately PBMPLUS/NETPBM is not nearly as portable as the IJG software is; +you are likely to have difficulty making it work on any non-Unix machine. + +A different free JPEG implementation, written by the PVRG group at Stanford, +is available from ftp://havefun.stanford.edu/pub/jpeg/. This program +is designed for research and experimentation rather than production use; +it is slower, harder to use, and less portable than the IJG code, but it +is easier to read and modify. Also, the PVRG code supports lossless JPEG, +which we do not. (On the other hand, it doesn't do progressive JPEG.) + + +FILE FORMAT WARS +================ + +Some JPEG programs produce files that are not compatible with our library. +The root of the problem is that the ISO JPEG committee failed to specify a +concrete file format. Some vendors "filled in the blanks" on their own, +creating proprietary formats that no one else could read. (For example, none +of the early commercial JPEG implementations for the Macintosh were able to +exchange compressed files.) + +The file format we have adopted is called JFIF (see REFERENCES). This format +has been agreed to by a number of major commercial JPEG vendors, and it has +become the de facto standard. JFIF is a minimal or "low end" representation. +We recommend the use of TIFF/JPEG (TIFF revision 6.0 as modified by TIFF +Technical Note #2) for "high end" applications that need to record a lot of +additional data about an image. TIFF/JPEG is fairly new and not yet widely +supported, unfortunately. + +The upcoming JPEG Part 3 standard defines a file format called SPIFF. +SPIFF is interoperable with JFIF, in the sense that most JFIF decoders should +be able to read the most common variant of SPIFF. SPIFF has some technical +advantages over JFIF, but its major claim to fame is simply that it is an +official standard rather than an informal one. At this point it is unclear +whether SPIFF will supersede JFIF or whether JFIF will remain the de-facto +standard. IJG intends to support SPIFF once the standard is frozen, but we +have not decided whether it should become our default output format or not. +(In any case, our decoder will remain capable of reading JFIF indefinitely.) + +Various proprietary file formats incorporating JPEG compression also exist. +We have little or no sympathy for the existence of these formats. Indeed, +one of the original reasons for developing this free software was to help +force convergence on common, open format standards for JPEG files. Don't +use a proprietary file format! + + +TO DO +===== + +The major thrust for v7 will probably be improvement of visual quality. +The current method for scaling the quantization tables is known not to be +very good at low Q values. We also intend to investigate block boundary +smoothing, "poor man's variable quantization", and other means of improving +quality-vs-file-size performance without sacrificing compatibility. + +In future versions, we are considering supporting some of the upcoming JPEG +Part 3 extensions --- principally, variable quantization and the SPIFF file +format. + +As always, speeding things up is of great interest. + +Please send bug reports, offers of help, etc. to jpeg-info@uunet.uu.net. diff --git a/ml/dlib/dlib/external/libjpeg/jcapimin.cpp b/ml/dlib/dlib/external/libjpeg/jcapimin.cpp new file mode 100644 index 000000000..bc0ceac5b --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcapimin.cpp @@ -0,0 +1,280 @@ +/* + * jcapimin.c + * + * Copyright (C) 1994-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains application interface code for the compression half + * of the JPEG library. These are the "minimum" API routines that may be + * needed in either the normal full-compression case or the transcoding-only + * case. + * + * Most of the routines intended to be called directly by an application + * are in this file or in jcapistd.c. But also see jcparam.c for + * parameter-setup helper routines, jcomapi.c for routines shared by + * compression and decompression, and jctrans.c for the transcoding case. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* + * Initialization of a JPEG compression object. + * The error manager must already be set up (in case memory manager fails). + */ + +GLOBAL(void) +jpeg_CreateCompress (j_compress_ptr cinfo, int version, size_t structsize) +{ + int i; + + /* Guard against version mismatches between library and caller. */ + cinfo->mem = NULL; /* so jpeg_destroy knows mem mgr not called */ + if (version != JPEG_LIB_VERSION) + ERREXIT2(cinfo, JERR_BAD_LIB_VERSION, JPEG_LIB_VERSION, version); + if (structsize != SIZEOF(struct jpeg_compress_struct)) + ERREXIT2(cinfo, JERR_BAD_STRUCT_SIZE, + (int) SIZEOF(struct jpeg_compress_struct), (int) structsize); + + /* For debugging purposes, we zero the whole master structure. + * But the application has already set the err pointer, and may have set + * client_data, so we have to save and restore those fields. + * Note: if application hasn't set client_data, tools like Purify may + * complain here. + */ + { + struct jpeg_error_mgr * err = cinfo->err; + void * client_data = cinfo->client_data; /* ignore Purify complaint here */ + MEMZERO(cinfo, SIZEOF(struct jpeg_compress_struct)); + cinfo->err = err; + cinfo->client_data = client_data; + } + cinfo->is_decompressor = FALSE; + + /* Initialize a memory manager instance for this object */ + jinit_memory_mgr((j_common_ptr) cinfo); + + /* Zero out pointers to permanent structures. */ + cinfo->progress = NULL; + cinfo->dest = NULL; + + cinfo->comp_info = NULL; + + for (i = 0; i < NUM_QUANT_TBLS; i++) + cinfo->quant_tbl_ptrs[i] = NULL; + + for (i = 0; i < NUM_HUFF_TBLS; i++) { + cinfo->dc_huff_tbl_ptrs[i] = NULL; + cinfo->ac_huff_tbl_ptrs[i] = NULL; + } + + cinfo->script_space = NULL; + + cinfo->input_gamma = 1.0; /* in case application forgets */ + + /* OK, I'm ready */ + cinfo->global_state = CSTATE_START; +} + + +/* + * Destruction of a JPEG compression object + */ + +GLOBAL(void) +jpeg_destroy_compress (j_compress_ptr cinfo) +{ + jpeg_destroy((j_common_ptr) cinfo); /* use common routine */ +} + + +/* + * Abort processing of a JPEG compression operation, + * but don't destroy the object itself. + */ + +GLOBAL(void) +jpeg_abort_compress (j_compress_ptr cinfo) +{ + jpeg_abort((j_common_ptr) cinfo); /* use common routine */ +} + + +/* + * Forcibly suppress or un-suppress all quantization and Huffman tables. + * Marks all currently defined tables as already written (if suppress) + * or not written (if !suppress). This will control whether they get emitted + * by a subsequent jpeg_start_compress call. + * + * This routine is exported for use by applications that want to produce + * abbreviated JPEG datastreams. It logically belongs in jcparam.c, but + * since it is called by jpeg_start_compress, we put it here --- otherwise + * jcparam.o would be linked whether the application used it or not. + */ + +GLOBAL(void) +jpeg_suppress_tables (j_compress_ptr cinfo, int suppress) +{ + int i; + JQUANT_TBL * qtbl; + JHUFF_TBL * htbl; + + for (i = 0; i < NUM_QUANT_TBLS; i++) { + if ((qtbl = cinfo->quant_tbl_ptrs[i]) != NULL) + qtbl->sent_table = suppress; + } + + for (i = 0; i < NUM_HUFF_TBLS; i++) { + if ((htbl = cinfo->dc_huff_tbl_ptrs[i]) != NULL) + htbl->sent_table = suppress; + if ((htbl = cinfo->ac_huff_tbl_ptrs[i]) != NULL) + htbl->sent_table = suppress; + } +} + + +/* + * Finish JPEG compression. + * + * If a multipass operating mode was selected, this may do a great deal of + * work including most of the actual output. + */ + +GLOBAL(void) +jpeg_finish_compress (j_compress_ptr cinfo) +{ + JDIMENSION iMCU_row; + + if (cinfo->global_state == CSTATE_SCANNING || + cinfo->global_state == CSTATE_RAW_OK) { + /* Terminate first pass */ + if (cinfo->next_scanline < cinfo->image_height) + ERREXIT(cinfo, JERR_TOO_LITTLE_DATA); + (*cinfo->master->finish_pass) (cinfo); + } else if (cinfo->global_state != CSTATE_WRCOEFS) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + /* Perform any remaining passes */ + while (! cinfo->master->is_last_pass) { + (*cinfo->master->prepare_for_pass) (cinfo); + for (iMCU_row = 0; iMCU_row < cinfo->total_iMCU_rows; iMCU_row++) { + if (cinfo->progress != NULL) { + cinfo->progress->pass_counter = (long) iMCU_row; + cinfo->progress->pass_limit = (long) cinfo->total_iMCU_rows; + (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo); + } + /* We bypass the main controller and invoke coef controller directly; + * all work is being done from the coefficient buffer. + */ + if (! (*cinfo->coef->compress_data) (cinfo, (JSAMPIMAGE) NULL)) + ERREXIT(cinfo, JERR_CANT_SUSPEND); + } + (*cinfo->master->finish_pass) (cinfo); + } + /* Write EOI, do final cleanup */ + (*cinfo->marker->write_file_trailer) (cinfo); + (*cinfo->dest->term_destination) (cinfo); + /* We can use jpeg_abort to release memory and reset global_state */ + jpeg_abort((j_common_ptr) cinfo); +} + + +/* + * Write a special marker. + * This is only recommended for writing COM or APPn markers. + * Must be called after jpeg_start_compress() and before + * first call to jpeg_write_scanlines() or jpeg_write_raw_data(). + */ + +GLOBAL(void) +jpeg_write_marker (j_compress_ptr cinfo, int marker, + const JOCTET *dataptr, unsigned int datalen) +{ + JMETHOD(void, write_marker_byte, (j_compress_ptr info, int val)); + + if (cinfo->next_scanline != 0 || + (cinfo->global_state != CSTATE_SCANNING && + cinfo->global_state != CSTATE_RAW_OK && + cinfo->global_state != CSTATE_WRCOEFS)) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + (*cinfo->marker->write_marker_header) (cinfo, marker, datalen); + write_marker_byte = cinfo->marker->write_marker_byte; /* copy for speed */ + while (datalen--) { + (*write_marker_byte) (cinfo, *dataptr); + dataptr++; + } +} + +/* Same, but piecemeal. */ + +GLOBAL(void) +jpeg_write_m_header (j_compress_ptr cinfo, int marker, unsigned int datalen) +{ + if (cinfo->next_scanline != 0 || + (cinfo->global_state != CSTATE_SCANNING && + cinfo->global_state != CSTATE_RAW_OK && + cinfo->global_state != CSTATE_WRCOEFS)) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + (*cinfo->marker->write_marker_header) (cinfo, marker, datalen); +} + +GLOBAL(void) +jpeg_write_m_byte (j_compress_ptr cinfo, int val) +{ + (*cinfo->marker->write_marker_byte) (cinfo, val); +} + + +/* + * Alternate compression function: just write an abbreviated table file. + * Before calling this, all parameters and a data destination must be set up. + * + * To produce a pair of files containing abbreviated tables and abbreviated + * image data, one would proceed as follows: + * + * initialize JPEG object + * set JPEG parameters + * set destination to table file + * jpeg_write_tables(cinfo); + * set destination to image file + * jpeg_start_compress(cinfo, FALSE); + * write data... + * jpeg_finish_compress(cinfo); + * + * jpeg_write_tables has the side effect of marking all tables written + * (same as jpeg_suppress_tables(..., TRUE)). Thus a subsequent start_compress + * will not re-emit the tables unless it is passed write_all_tables=TRUE. + */ + +GLOBAL(void) +jpeg_write_tables (j_compress_ptr cinfo) +{ + if (cinfo->global_state != CSTATE_START) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + /* (Re)initialize error mgr and destination modules */ + (*cinfo->err->reset_error_mgr) ((j_common_ptr) cinfo); + (*cinfo->dest->init_destination) (cinfo); + /* Initialize the marker writer ... bit of a crock to do it here. */ + jinit_marker_writer(cinfo); + /* Write them tables! */ + (*cinfo->marker->write_tables_only) (cinfo); + /* And clean up. */ + (*cinfo->dest->term_destination) (cinfo); + /* + * In library releases up through v6a, we called jpeg_abort() here to free + * any working memory allocated by the destination manager and marker + * writer. Some applications had a problem with that: they allocated space + * of their own from the library memory manager, and didn't want it to go + * away during write_tables. So now we do nothing. This will cause a + * memory leak if an app calls write_tables repeatedly without doing a full + * compression cycle or otherwise resetting the JPEG object. However, that + * seems less bad than unexpectedly freeing memory in the normal case. + * An app that prefers the old behavior can call jpeg_abort for itself after + * each call to jpeg_write_tables(). + */ +} diff --git a/ml/dlib/dlib/external/libjpeg/jcapistd.cpp b/ml/dlib/dlib/external/libjpeg/jcapistd.cpp new file mode 100644 index 000000000..3f4e08063 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcapistd.cpp @@ -0,0 +1,161 @@ +/* + * jcapistd.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains application interface code for the compression half + * of the JPEG library. These are the "standard" API routines that are + * used in the normal full-compression case. They are not used by a + * transcoding-only application. Note that if an application links in + * jpeg_start_compress, it will end up linking in the entire compressor. + * We thus must separate this file from jcapimin.c to avoid linking the + * whole compression library into a transcoder. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* + * Compression initialization. + * Before calling this, all parameters and a data destination must be set up. + * + * We require a write_all_tables parameter as a failsafe check when writing + * multiple datastreams from the same compression object. Since prior runs + * will have left all the tables marked sent_table=TRUE, a subsequent run + * would emit an abbreviated stream (no tables) by default. This may be what + * is wanted, but for safety's sake it should not be the default behavior: + * programmers should have to make a deliberate choice to emit abbreviated + * images. Therefore the documentation and examples should encourage people + * to pass write_all_tables=TRUE; then it will take active thought to do the + * wrong thing. + */ + +GLOBAL(void) +jpeg_start_compress (j_compress_ptr cinfo, int write_all_tables) +{ + if (cinfo->global_state != CSTATE_START) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + if (write_all_tables) + jpeg_suppress_tables(cinfo, FALSE); /* mark all tables to be written */ + + /* (Re)initialize error mgr and destination modules */ + (*cinfo->err->reset_error_mgr) ((j_common_ptr) cinfo); + (*cinfo->dest->init_destination) (cinfo); + /* Perform master selection of active modules */ + jinit_compress_master(cinfo); + /* Set up for the first pass */ + (*cinfo->master->prepare_for_pass) (cinfo); + /* Ready for application to drive first pass through jpeg_write_scanlines + * or jpeg_write_raw_data. + */ + cinfo->next_scanline = 0; + cinfo->global_state = (cinfo->raw_data_in ? CSTATE_RAW_OK : CSTATE_SCANNING); +} + + +/* + * Write some scanlines of data to the JPEG compressor. + * + * The return value will be the number of lines actually written. + * This should be less than the supplied num_lines only in case that + * the data destination module has requested suspension of the compressor, + * or if more than image_height scanlines are passed in. + * + * Note: we warn about excess calls to jpeg_write_scanlines() since + * this likely signals an application programmer error. However, + * excess scanlines passed in the last valid call are *silently* ignored, + * so that the application need not adjust num_lines for end-of-image + * when using a multiple-scanline buffer. + */ + +GLOBAL(JDIMENSION) +jpeg_write_scanlines (j_compress_ptr cinfo, JSAMPARRAY scanlines, + JDIMENSION num_lines) +{ + JDIMENSION row_ctr, rows_left; + + if (cinfo->global_state != CSTATE_SCANNING) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + if (cinfo->next_scanline >= cinfo->image_height) + WARNMS(cinfo, JWRN_TOO_MUCH_DATA); + + /* Call progress monitor hook if present */ + if (cinfo->progress != NULL) { + cinfo->progress->pass_counter = (long) cinfo->next_scanline; + cinfo->progress->pass_limit = (long) cinfo->image_height; + (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo); + } + + /* Give master control module another chance if this is first call to + * jpeg_write_scanlines. This lets output of the frame/scan headers be + * delayed so that application can write COM, etc, markers between + * jpeg_start_compress and jpeg_write_scanlines. + */ + if (cinfo->master->call_pass_startup) + (*cinfo->master->pass_startup) (cinfo); + + /* Ignore any extra scanlines at bottom of image. */ + rows_left = cinfo->image_height - cinfo->next_scanline; + if (num_lines > rows_left) + num_lines = rows_left; + + row_ctr = 0; + (*cinfo->main->process_data) (cinfo, scanlines, &row_ctr, num_lines); + cinfo->next_scanline += row_ctr; + return row_ctr; +} + + +/* + * Alternate entry point to write raw data. + * Processes exactly one iMCU row per call, unless suspended. + */ + +GLOBAL(JDIMENSION) +jpeg_write_raw_data (j_compress_ptr cinfo, JSAMPIMAGE data, + JDIMENSION num_lines) +{ + JDIMENSION lines_per_iMCU_row; + + if (cinfo->global_state != CSTATE_RAW_OK) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + if (cinfo->next_scanline >= cinfo->image_height) { + WARNMS(cinfo, JWRN_TOO_MUCH_DATA); + return 0; + } + + /* Call progress monitor hook if present */ + if (cinfo->progress != NULL) { + cinfo->progress->pass_counter = (long) cinfo->next_scanline; + cinfo->progress->pass_limit = (long) cinfo->image_height; + (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo); + } + + /* Give master control module another chance if this is first call to + * jpeg_write_raw_data. This lets output of the frame/scan headers be + * delayed so that application can write COM, etc, markers between + * jpeg_start_compress and jpeg_write_raw_data. + */ + if (cinfo->master->call_pass_startup) + (*cinfo->master->pass_startup) (cinfo); + + /* Verify that at least one iMCU row has been passed. */ + lines_per_iMCU_row = cinfo->max_v_samp_factor * DCTSIZE; + if (num_lines < lines_per_iMCU_row) + ERREXIT(cinfo, JERR_BUFFER_SIZE); + + /* Directly compress the row. */ + if (! (*cinfo->coef->compress_data) (cinfo, data)) { + /* If compressor did not consume the whole row, suspend processing. */ + return 0; + } + + /* OK, we processed one iMCU row. */ + cinfo->next_scanline += lines_per_iMCU_row; + return lines_per_iMCU_row; +} diff --git a/ml/dlib/dlib/external/libjpeg/jccoefct.cpp b/ml/dlib/dlib/external/libjpeg/jccoefct.cpp new file mode 100644 index 000000000..175d7ecd9 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jccoefct.cpp @@ -0,0 +1,449 @@ +/* + * jccoefct.c + * + * Copyright (C) 1994-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains the coefficient buffer controller for compression. + * This controller is the top level of the JPEG compressor proper. + * The coefficient buffer lies between forward-DCT and entropy encoding steps. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* We use a full-image coefficient buffer when doing Huffman optimization, + * and also for writing multiple-scan JPEG files. In all cases, the DCT + * step is run during the first pass, and subsequent passes need only read + * the buffered coefficients. + */ +#ifdef ENTROPY_OPT_SUPPORTED +#define FULL_COEF_BUFFER_SUPPORTED +#else +#ifdef C_MULTISCAN_FILES_SUPPORTED +#define FULL_COEF_BUFFER_SUPPORTED +#endif +#endif + + +/* Private buffer controller object */ + +typedef struct { + struct jpeg_c_coef_controller pub; /* public fields */ + + JDIMENSION iMCU_row_num; /* iMCU row # within image */ + JDIMENSION mcu_ctr; /* counts MCUs processed in current row */ + int MCU_vert_offset; /* counts MCU rows within iMCU row */ + int MCU_rows_per_iMCU_row; /* number of such rows needed */ + + /* For single-pass compression, it's sufficient to buffer just one MCU + * (although this may prove a bit slow in practice). We allocate a + * workspace of C_MAX_BLOCKS_IN_MCU coefficient blocks, and reuse it for each + * MCU constructed and sent. (On 80x86, the workspace is FAR even though + * it's not really very big; this is to keep the module interfaces unchanged + * when a large coefficient buffer is necessary.) + * In multi-pass modes, this array points to the current MCU's blocks + * within the virtual arrays. + */ + JBLOCKROW MCU_buffer[C_MAX_BLOCKS_IN_MCU]; + + /* In multi-pass modes, we need a virtual block array for each component. */ + jvirt_barray_ptr whole_image[MAX_COMPONENTS]; +} my_coef_controller; + +typedef my_coef_controller * my_coef_ptr; + + +/* Forward declarations */ +METHODDEF(int) compress_data + JPP((j_compress_ptr cinfo, JSAMPIMAGE input_buf)); +#ifdef FULL_COEF_BUFFER_SUPPORTED +METHODDEF(int) compress_first_pass + JPP((j_compress_ptr cinfo, JSAMPIMAGE input_buf)); +METHODDEF(int) compress_output + JPP((j_compress_ptr cinfo, JSAMPIMAGE input_buf)); +#endif + + +LOCAL(void) +start_iMCU_row (j_compress_ptr cinfo) +/* Reset within-iMCU-row counters for a new row */ +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + + /* In an interleaved scan, an MCU row is the same as an iMCU row. + * In a noninterleaved scan, an iMCU row has v_samp_factor MCU rows. + * But at the bottom of the image, process only what's left. + */ + if (cinfo->comps_in_scan > 1) { + coef->MCU_rows_per_iMCU_row = 1; + } else { + if (coef->iMCU_row_num < (cinfo->total_iMCU_rows-1)) + coef->MCU_rows_per_iMCU_row = cinfo->cur_comp_info[0]->v_samp_factor; + else + coef->MCU_rows_per_iMCU_row = cinfo->cur_comp_info[0]->last_row_height; + } + + coef->mcu_ctr = 0; + coef->MCU_vert_offset = 0; +} + + +/* + * Initialize for a processing pass. + */ + +METHODDEF(void) +start_pass_coef (j_compress_ptr cinfo, J_BUF_MODE pass_mode) +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + + coef->iMCU_row_num = 0; + start_iMCU_row(cinfo); + + switch (pass_mode) { + case JBUF_PASS_THRU: + if (coef->whole_image[0] != NULL) + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + coef->pub.compress_data = compress_data; + break; +#ifdef FULL_COEF_BUFFER_SUPPORTED + case JBUF_SAVE_AND_PASS: + if (coef->whole_image[0] == NULL) + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + coef->pub.compress_data = compress_first_pass; + break; + case JBUF_CRANK_DEST: + if (coef->whole_image[0] == NULL) + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + coef->pub.compress_data = compress_output; + break; +#endif + default: + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + break; + } +} + + +/* + * Process some data in the single-pass case. + * We process the equivalent of one fully interleaved MCU row ("iMCU" row) + * per call, ie, v_samp_factor block rows for each component in the image. + * Returns TRUE if the iMCU row is completed, FALSE if suspended. + * + * NB: input_buf contains a plane for each component in image, + * which we index according to the component's SOF position. + */ + +METHODDEF(int) +compress_data (j_compress_ptr cinfo, JSAMPIMAGE input_buf) +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + JDIMENSION MCU_col_num; /* index of current MCU within row */ + JDIMENSION last_MCU_col = cinfo->MCUs_per_row - 1; + JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1; + int blkn, bi, ci, yindex, yoffset, blockcnt; + JDIMENSION ypos, xpos; + jpeg_component_info *compptr; + + /* Loop to write as much as one whole iMCU row */ + for (yoffset = coef->MCU_vert_offset; yoffset < coef->MCU_rows_per_iMCU_row; + yoffset++) { + for (MCU_col_num = coef->mcu_ctr; MCU_col_num <= last_MCU_col; + MCU_col_num++) { + /* Determine where data comes from in input_buf and do the DCT thing. + * Each call on forward_DCT processes a horizontal row of DCT blocks + * as wide as an MCU; we rely on having allocated the MCU_buffer[] blocks + * sequentially. Dummy blocks at the right or bottom edge are filled in + * specially. The data in them does not matter for image reconstruction, + * so we fill them with values that will encode to the smallest amount of + * data, viz: all zeroes in the AC entries, DC entries equal to previous + * block's DC value. (Thanks to Thomas Kinsman for this idea.) + */ + blkn = 0; + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + blockcnt = (MCU_col_num < last_MCU_col) ? compptr->MCU_width + : compptr->last_col_width; + xpos = MCU_col_num * compptr->MCU_sample_width; + ypos = yoffset * DCTSIZE; /* ypos == (yoffset+yindex) * DCTSIZE */ + for (yindex = 0; yindex < compptr->MCU_height; yindex++) { + if (coef->iMCU_row_num < last_iMCU_row || + yoffset+yindex < compptr->last_row_height) { + (*cinfo->fdct->forward_DCT) (cinfo, compptr, + input_buf[compptr->component_index], + coef->MCU_buffer[blkn], + ypos, xpos, (JDIMENSION) blockcnt); + if (blockcnt < compptr->MCU_width) { + /* Create some dummy blocks at the right edge of the image. */ + jzero_far((void FAR *) coef->MCU_buffer[blkn + blockcnt], + (compptr->MCU_width - blockcnt) * SIZEOF(JBLOCK)); + for (bi = blockcnt; bi < compptr->MCU_width; bi++) { + coef->MCU_buffer[blkn+bi][0][0] = coef->MCU_buffer[blkn+bi-1][0][0]; + } + } + } else { + /* Create a row of dummy blocks at the bottom of the image. */ + jzero_far((void FAR *) coef->MCU_buffer[blkn], + compptr->MCU_width * SIZEOF(JBLOCK)); + for (bi = 0; bi < compptr->MCU_width; bi++) { + coef->MCU_buffer[blkn+bi][0][0] = coef->MCU_buffer[blkn-1][0][0]; + } + } + blkn += compptr->MCU_width; + ypos += DCTSIZE; + } + } + /* Try to write the MCU. In event of a suspension failure, we will + * re-DCT the MCU on restart (a bit inefficient, could be fixed...) + */ + if (! (*cinfo->entropy->encode_mcu) (cinfo, coef->MCU_buffer)) { + /* Suspension forced; update state counters and exit */ + coef->MCU_vert_offset = yoffset; + coef->mcu_ctr = MCU_col_num; + return FALSE; + } + } + /* Completed an MCU row, but perhaps not an iMCU row */ + coef->mcu_ctr = 0; + } + /* Completed the iMCU row, advance counters for next one */ + coef->iMCU_row_num++; + start_iMCU_row(cinfo); + return TRUE; +} + + +#ifdef FULL_COEF_BUFFER_SUPPORTED + +/* + * Process some data in the first pass of a multi-pass case. + * We process the equivalent of one fully interleaved MCU row ("iMCU" row) + * per call, ie, v_samp_factor block rows for each component in the image. + * This amount of data is read from the source buffer, DCT'd and quantized, + * and saved into the virtual arrays. We also generate suitable dummy blocks + * as needed at the right and lower edges. (The dummy blocks are constructed + * in the virtual arrays, which have been padded appropriately.) This makes + * it possible for subsequent passes not to worry about real vs. dummy blocks. + * + * We must also emit the data to the entropy encoder. This is conveniently + * done by calling compress_output() after we've loaded the current strip + * of the virtual arrays. + * + * NB: input_buf contains a plane for each component in image. All + * components are DCT'd and loaded into the virtual arrays in this pass. + * However, it may be that only a subset of the components are emitted to + * the entropy encoder during this first pass; be careful about looking + * at the scan-dependent variables (MCU dimensions, etc). + */ + +METHODDEF(int) +compress_first_pass (j_compress_ptr cinfo, JSAMPIMAGE input_buf) +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1; + JDIMENSION blocks_across, MCUs_across, MCUindex; + int bi, ci, h_samp_factor, block_row, block_rows, ndummy; + JCOEF lastDC; + jpeg_component_info *compptr; + JBLOCKARRAY buffer; + JBLOCKROW thisblockrow, lastblockrow; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Align the virtual buffer for this component. */ + buffer = (*cinfo->mem->access_virt_barray) + ((j_common_ptr) cinfo, coef->whole_image[ci], + coef->iMCU_row_num * compptr->v_samp_factor, + (JDIMENSION) compptr->v_samp_factor, TRUE); + /* Count non-dummy DCT block rows in this iMCU row. */ + if (coef->iMCU_row_num < last_iMCU_row) + block_rows = compptr->v_samp_factor; + else { + /* NB: can't use last_row_height here, since may not be set! */ + block_rows = (int) (compptr->height_in_blocks % compptr->v_samp_factor); + if (block_rows == 0) block_rows = compptr->v_samp_factor; + } + blocks_across = compptr->width_in_blocks; + h_samp_factor = compptr->h_samp_factor; + /* Count number of dummy blocks to be added at the right margin. */ + ndummy = (int) (blocks_across % h_samp_factor); + if (ndummy > 0) + ndummy = h_samp_factor - ndummy; + /* Perform DCT for all non-dummy blocks in this iMCU row. Each call + * on forward_DCT processes a complete horizontal row of DCT blocks. + */ + for (block_row = 0; block_row < block_rows; block_row++) { + thisblockrow = buffer[block_row]; + (*cinfo->fdct->forward_DCT) (cinfo, compptr, + input_buf[ci], thisblockrow, + (JDIMENSION) (block_row * DCTSIZE), + (JDIMENSION) 0, blocks_across); + if (ndummy > 0) { + /* Create dummy blocks at the right edge of the image. */ + thisblockrow += blocks_across; /* => first dummy block */ + jzero_far((void FAR *) thisblockrow, ndummy * SIZEOF(JBLOCK)); + lastDC = thisblockrow[-1][0]; + for (bi = 0; bi < ndummy; bi++) { + thisblockrow[bi][0] = lastDC; + } + } + } + /* If at end of image, create dummy block rows as needed. + * The tricky part here is that within each MCU, we want the DC values + * of the dummy blocks to match the last real block's DC value. + * This squeezes a few more bytes out of the resulting file... + */ + if (coef->iMCU_row_num == last_iMCU_row) { + blocks_across += ndummy; /* include lower right corner */ + MCUs_across = blocks_across / h_samp_factor; + for (block_row = block_rows; block_row < compptr->v_samp_factor; + block_row++) { + thisblockrow = buffer[block_row]; + lastblockrow = buffer[block_row-1]; + jzero_far((void FAR *) thisblockrow, + (size_t) (blocks_across * SIZEOF(JBLOCK))); + for (MCUindex = 0; MCUindex < MCUs_across; MCUindex++) { + lastDC = lastblockrow[h_samp_factor-1][0]; + for (bi = 0; bi < h_samp_factor; bi++) { + thisblockrow[bi][0] = lastDC; + } + thisblockrow += h_samp_factor; /* advance to next MCU in row */ + lastblockrow += h_samp_factor; + } + } + } + } + /* NB: compress_output will increment iMCU_row_num if successful. + * A suspension return will result in redoing all the work above next time. + */ + + /* Emit data to the entropy encoder, sharing code with subsequent passes */ + return compress_output(cinfo, input_buf); +} + + +/* + * Process some data in subsequent passes of a multi-pass case. + * We process the equivalent of one fully interleaved MCU row ("iMCU" row) + * per call, ie, v_samp_factor block rows for each component in the scan. + * The data is obtained from the virtual arrays and fed to the entropy coder. + * Returns TRUE if the iMCU row is completed, FALSE if suspended. + * + * NB: input_buf is ignored; it is likely to be a NULL pointer. + */ + +METHODDEF(int) +compress_output (j_compress_ptr cinfo, JSAMPIMAGE )//input_buf) +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + JDIMENSION MCU_col_num; /* index of current MCU within row */ + int blkn, ci, xindex, yindex, yoffset; + JDIMENSION start_col; + JBLOCKARRAY buffer[MAX_COMPS_IN_SCAN]; + JBLOCKROW buffer_ptr; + jpeg_component_info *compptr; + + /* Align the virtual buffers for the components used in this scan. + * NB: during first pass, this is safe only because the buffers will + * already be aligned properly, so jmemmgr.c won't need to do any I/O. + */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + buffer[ci] = (*cinfo->mem->access_virt_barray) + ((j_common_ptr) cinfo, coef->whole_image[compptr->component_index], + coef->iMCU_row_num * compptr->v_samp_factor, + (JDIMENSION) compptr->v_samp_factor, FALSE); + } + + /* Loop to process one whole iMCU row */ + for (yoffset = coef->MCU_vert_offset; yoffset < coef->MCU_rows_per_iMCU_row; + yoffset++) { + for (MCU_col_num = coef->mcu_ctr; MCU_col_num < cinfo->MCUs_per_row; + MCU_col_num++) { + /* Construct list of pointers to DCT blocks belonging to this MCU */ + blkn = 0; /* index of current DCT block within MCU */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + start_col = MCU_col_num * compptr->MCU_width; + for (yindex = 0; yindex < compptr->MCU_height; yindex++) { + buffer_ptr = buffer[ci][yindex+yoffset] + start_col; + for (xindex = 0; xindex < compptr->MCU_width; xindex++) { + coef->MCU_buffer[blkn++] = buffer_ptr++; + } + } + } + /* Try to write the MCU. */ + if (! (*cinfo->entropy->encode_mcu) (cinfo, coef->MCU_buffer)) { + /* Suspension forced; update state counters and exit */ + coef->MCU_vert_offset = yoffset; + coef->mcu_ctr = MCU_col_num; + return FALSE; + } + } + /* Completed an MCU row, but perhaps not an iMCU row */ + coef->mcu_ctr = 0; + } + /* Completed the iMCU row, advance counters for next one */ + coef->iMCU_row_num++; + start_iMCU_row(cinfo); + return TRUE; +} + +#endif /* FULL_COEF_BUFFER_SUPPORTED */ + + +/* + * Initialize coefficient buffer controller. + */ + +GLOBAL(void) +jinit_c_coef_controller (j_compress_ptr cinfo, int need_full_buffer) +{ + my_coef_ptr coef; + + coef = (my_coef_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_coef_controller)); + cinfo->coef = (struct jpeg_c_coef_controller *) coef; + coef->pub.start_pass = start_pass_coef; + + /* Create the coefficient buffer. */ + if (need_full_buffer) { +#ifdef FULL_COEF_BUFFER_SUPPORTED + /* Allocate a full-image virtual array for each component, */ + /* padded to a multiple of samp_factor DCT blocks in each direction. */ + int ci; + jpeg_component_info *compptr; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + coef->whole_image[ci] = (*cinfo->mem->request_virt_barray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, FALSE, + (JDIMENSION) jround_up((long) compptr->width_in_blocks, + (long) compptr->h_samp_factor), + (JDIMENSION) jround_up((long) compptr->height_in_blocks, + (long) compptr->v_samp_factor), + (JDIMENSION) compptr->v_samp_factor); + } +#else + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); +#endif + } else { + /* We only need a single-MCU buffer. */ + JBLOCKROW buffer; + int i; + + buffer = (JBLOCKROW) + (*cinfo->mem->alloc_large) ((j_common_ptr) cinfo, JPOOL_IMAGE, + C_MAX_BLOCKS_IN_MCU * SIZEOF(JBLOCK)); + for (i = 0; i < C_MAX_BLOCKS_IN_MCU; i++) { + coef->MCU_buffer[i] = buffer + i; + } + coef->whole_image[0] = NULL; /* flag for no virtual arrays */ + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jccolor.cpp b/ml/dlib/dlib/external/libjpeg/jccolor.cpp new file mode 100644 index 000000000..c5cfeded5 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jccolor.cpp @@ -0,0 +1,459 @@ +/* + * jccolor.c + * + * Copyright (C) 1991-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains input colorspace conversion routines. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Private subobject */ + +typedef struct { + struct jpeg_color_converter pub; /* public fields */ + + /* Private state for RGB->YCC conversion */ + long * rgb_ycc_tab; /* => table for RGB to YCbCr conversion */ +} my_color_converter; + +typedef my_color_converter * my_cconvert_ptr; + + +/**************** RGB -> YCbCr conversion: most common case **************/ + +/* + * YCbCr is defined per CCIR 601-1, except that Cb and Cr are + * normalized to the range 0..MAXJSAMPLE rather than -0.5 .. 0.5. + * The conversion equations to be implemented are therefore + * Y = 0.29900 * R + 0.58700 * G + 0.11400 * B + * Cb = -0.16874 * R - 0.33126 * G + 0.50000 * B + CENTERJSAMPLE + * Cr = 0.50000 * R - 0.41869 * G - 0.08131 * B + CENTERJSAMPLE + * (These numbers are derived from TIFF 6.0 section 21, dated 3-June-92.) + * Note: older versions of the IJG code used a zero offset of MAXJSAMPLE/2, + * rather than CENTERJSAMPLE, for Cb and Cr. This gave equal positive and + * negative swings for Cb/Cr, but meant that grayscale values (Cb=Cr=0) + * were not represented exactly. Now we sacrifice exact representation of + * maximum red and maximum blue in order to get exact grayscales. + * + * To avoid floating-point arithmetic, we represent the fractional constants + * as integers scaled up by 2^16 (about 4 digits precision); we have to divide + * the products by 2^16, with appropriate rounding, to get the correct answer. + * + * For even more speed, we avoid doing any multiplications in the inner loop + * by precalculating the constants times R,G,B for all possible values. + * For 8-bit JSAMPLEs this is very reasonable (only 256 entries per table); + * for 12-bit samples it is still acceptable. It's not very reasonable for + * 16-bit samples, but if you want lossless storage you shouldn't be changing + * colorspace anyway. + * The CENTERJSAMPLE offsets and the rounding fudge-factor of 0.5 are included + * in the tables to save adding them separately in the inner loop. + */ + +#define SCALEBITS 16 /* speediest right-shift on some machines */ +#define CBCR_OFFSET ((long) CENTERJSAMPLE << SCALEBITS) +#define ONE_HALF ((long) 1 << (SCALEBITS-1)) +#define FIX(x) ((long) ((x) * (1L< Y section */ +#define G_Y_OFF (1*(MAXJSAMPLE+1)) /* offset to G => Y section */ +#define B_Y_OFF (2*(MAXJSAMPLE+1)) /* etc. */ +#define R_CB_OFF (3*(MAXJSAMPLE+1)) +#define G_CB_OFF (4*(MAXJSAMPLE+1)) +#define B_CB_OFF (5*(MAXJSAMPLE+1)) +#define R_CR_OFF B_CB_OFF /* B=>Cb, R=>Cr are the same */ +#define G_CR_OFF (6*(MAXJSAMPLE+1)) +#define B_CR_OFF (7*(MAXJSAMPLE+1)) +#define TABLE_SIZE (8*(MAXJSAMPLE+1)) + + +/* + * Initialize for RGB->YCC colorspace conversion. + */ + +METHODDEF(void) +rgb_ycc_start (j_compress_ptr cinfo) +{ + my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert; + long * rgb_ycc_tab; + long i; + + /* Allocate and fill in the conversion tables. */ + cconvert->rgb_ycc_tab = rgb_ycc_tab = (long *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (TABLE_SIZE * SIZEOF(long))); + + for (i = 0; i <= MAXJSAMPLE; i++) { + rgb_ycc_tab[i+R_Y_OFF] = FIX(0.29900) * i; + rgb_ycc_tab[i+G_Y_OFF] = FIX(0.58700) * i; + rgb_ycc_tab[i+B_Y_OFF] = FIX(0.11400) * i + ONE_HALF; + rgb_ycc_tab[i+R_CB_OFF] = (-FIX(0.16874)) * i; + rgb_ycc_tab[i+G_CB_OFF] = (-FIX(0.33126)) * i; + /* We use a rounding fudge-factor of 0.5-epsilon for Cb and Cr. + * This ensures that the maximum output will round to MAXJSAMPLE + * not MAXJSAMPLE+1, and thus that we don't have to range-limit. + */ + rgb_ycc_tab[i+B_CB_OFF] = FIX(0.50000) * i + CBCR_OFFSET + ONE_HALF-1; +/* B=>Cb and R=>Cr tables are the same + rgb_ycc_tab[i+R_CR_OFF] = FIX(0.50000) * i + CBCR_OFFSET + ONE_HALF-1; +*/ + rgb_ycc_tab[i+G_CR_OFF] = (-FIX(0.41869)) * i; + rgb_ycc_tab[i+B_CR_OFF] = (-FIX(0.08131)) * i; + } +} + + +/* + * Convert some rows of samples to the JPEG colorspace. + * + * Note that we change from the application's interleaved-pixel format + * to our internal noninterleaved, one-plane-per-component format. + * The input buffer is therefore three times as wide as the output buffer. + * + * A starting row offset is provided only for the output buffer. The caller + * can easily adjust the passed input_buf value to accommodate any row + * offset required on that side. + */ + +METHODDEF(void) +rgb_ycc_convert (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JSAMPIMAGE output_buf, + JDIMENSION output_row, int num_rows) +{ + my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert; + int r, g, b; + long * ctab = cconvert->rgb_ycc_tab; + JSAMPROW inptr; + JSAMPROW outptr0, outptr1, outptr2; + JDIMENSION col; + JDIMENSION num_cols = cinfo->image_width; + + while (--num_rows >= 0) { + inptr = *input_buf++; + outptr0 = output_buf[0][output_row]; + outptr1 = output_buf[1][output_row]; + outptr2 = output_buf[2][output_row]; + output_row++; + for (col = 0; col < num_cols; col++) { + r = GETJSAMPLE(inptr[RGB_RED]); + g = GETJSAMPLE(inptr[RGB_GREEN]); + b = GETJSAMPLE(inptr[RGB_BLUE]); + inptr += RGB_PIXELSIZE; + /* If the inputs are 0..MAXJSAMPLE, the outputs of these equations + * must be too; we do not need an explicit range-limiting operation. + * Hence the value being shifted is never negative, and we don't + * need the general RIGHT_SHIFT macro. + */ + /* Y */ + outptr0[col] = (JSAMPLE) + ((ctab[r+R_Y_OFF] + ctab[g+G_Y_OFF] + ctab[b+B_Y_OFF]) + >> SCALEBITS); + /* Cb */ + outptr1[col] = (JSAMPLE) + ((ctab[r+R_CB_OFF] + ctab[g+G_CB_OFF] + ctab[b+B_CB_OFF]) + >> SCALEBITS); + /* Cr */ + outptr2[col] = (JSAMPLE) + ((ctab[r+R_CR_OFF] + ctab[g+G_CR_OFF] + ctab[b+B_CR_OFF]) + >> SCALEBITS); + } + } +} + + +/**************** Cases other than RGB -> YCbCr **************/ + + +/* + * Convert some rows of samples to the JPEG colorspace. + * This version handles RGB->grayscale conversion, which is the same + * as the RGB->Y portion of RGB->YCbCr. + * We assume rgb_ycc_start has been called (we only use the Y tables). + */ + +METHODDEF(void) +rgb_gray_convert (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JSAMPIMAGE output_buf, + JDIMENSION output_row, int num_rows) +{ + my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert; + int r, g, b; + long * ctab = cconvert->rgb_ycc_tab; + JSAMPROW inptr; + JSAMPROW outptr; + JDIMENSION col; + JDIMENSION num_cols = cinfo->image_width; + + while (--num_rows >= 0) { + inptr = *input_buf++; + outptr = output_buf[0][output_row]; + output_row++; + for (col = 0; col < num_cols; col++) { + r = GETJSAMPLE(inptr[RGB_RED]); + g = GETJSAMPLE(inptr[RGB_GREEN]); + b = GETJSAMPLE(inptr[RGB_BLUE]); + inptr += RGB_PIXELSIZE; + /* Y */ + outptr[col] = (JSAMPLE) + ((ctab[r+R_Y_OFF] + ctab[g+G_Y_OFF] + ctab[b+B_Y_OFF]) + >> SCALEBITS); + } + } +} + + +/* + * Convert some rows of samples to the JPEG colorspace. + * This version handles Adobe-style CMYK->YCCK conversion, + * where we convert R=1-C, G=1-M, and B=1-Y to YCbCr using the same + * conversion as above, while passing K (black) unchanged. + * We assume rgb_ycc_start has been called. + */ + +METHODDEF(void) +cmyk_ycck_convert (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JSAMPIMAGE output_buf, + JDIMENSION output_row, int num_rows) +{ + my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert; + int r, g, b; + long * ctab = cconvert->rgb_ycc_tab; + JSAMPROW inptr; + JSAMPROW outptr0, outptr1, outptr2, outptr3; + JDIMENSION col; + JDIMENSION num_cols = cinfo->image_width; + + while (--num_rows >= 0) { + inptr = *input_buf++; + outptr0 = output_buf[0][output_row]; + outptr1 = output_buf[1][output_row]; + outptr2 = output_buf[2][output_row]; + outptr3 = output_buf[3][output_row]; + output_row++; + for (col = 0; col < num_cols; col++) { + r = MAXJSAMPLE - GETJSAMPLE(inptr[0]); + g = MAXJSAMPLE - GETJSAMPLE(inptr[1]); + b = MAXJSAMPLE - GETJSAMPLE(inptr[2]); + /* K passes through as-is */ + outptr3[col] = inptr[3]; /* don't need GETJSAMPLE here */ + inptr += 4; + /* If the inputs are 0..MAXJSAMPLE, the outputs of these equations + * must be too; we do not need an explicit range-limiting operation. + * Hence the value being shifted is never negative, and we don't + * need the general RIGHT_SHIFT macro. + */ + /* Y */ + outptr0[col] = (JSAMPLE) + ((ctab[r+R_Y_OFF] + ctab[g+G_Y_OFF] + ctab[b+B_Y_OFF]) + >> SCALEBITS); + /* Cb */ + outptr1[col] = (JSAMPLE) + ((ctab[r+R_CB_OFF] + ctab[g+G_CB_OFF] + ctab[b+B_CB_OFF]) + >> SCALEBITS); + /* Cr */ + outptr2[col] = (JSAMPLE) + ((ctab[r+R_CR_OFF] + ctab[g+G_CR_OFF] + ctab[b+B_CR_OFF]) + >> SCALEBITS); + } + } +} + + +/* + * Convert some rows of samples to the JPEG colorspace. + * This version handles grayscale output with no conversion. + * The source can be either plain grayscale or YCbCr (since Y == gray). + */ + +METHODDEF(void) +grayscale_convert (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JSAMPIMAGE output_buf, + JDIMENSION output_row, int num_rows) +{ + JSAMPROW inptr; + JSAMPROW outptr; + JDIMENSION col; + JDIMENSION num_cols = cinfo->image_width; + int instride = cinfo->input_components; + + while (--num_rows >= 0) { + inptr = *input_buf++; + outptr = output_buf[0][output_row]; + output_row++; + for (col = 0; col < num_cols; col++) { + outptr[col] = inptr[0]; /* don't need GETJSAMPLE() here */ + inptr += instride; + } + } +} + + +/* + * Convert some rows of samples to the JPEG colorspace. + * This version handles multi-component colorspaces without conversion. + * We assume input_components == num_components. + */ + +METHODDEF(void) +null_convert (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JSAMPIMAGE output_buf, + JDIMENSION output_row, int num_rows) +{ + JSAMPROW inptr; + JSAMPROW outptr; + JDIMENSION col; + int ci; + int nc = cinfo->num_components; + JDIMENSION num_cols = cinfo->image_width; + + while (--num_rows >= 0) { + /* It seems fastest to make a separate pass for each component. */ + for (ci = 0; ci < nc; ci++) { + inptr = *input_buf; + outptr = output_buf[ci][output_row]; + for (col = 0; col < num_cols; col++) { + outptr[col] = inptr[ci]; /* don't need GETJSAMPLE() here */ + inptr += nc; + } + } + input_buf++; + output_row++; + } +} + + +/* + * Empty method for start_pass. + */ + +METHODDEF(void) +null_method (j_compress_ptr )//cinfo) +{ + /* no work needed */ +} + + +/* + * Module initialization routine for input colorspace conversion. + */ + +GLOBAL(void) +jinit_color_converter (j_compress_ptr cinfo) +{ + my_cconvert_ptr cconvert; + + cconvert = (my_cconvert_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_color_converter)); + cinfo->cconvert = (struct jpeg_color_converter *) cconvert; + /* set start_pass to null method until we find out differently */ + cconvert->pub.start_pass = null_method; + + /* Make sure input_components agrees with in_color_space */ + switch (cinfo->in_color_space) { + case JCS_GRAYSCALE: + if (cinfo->input_components != 1) + ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE); + break; + + case JCS_RGB: +#if RGB_PIXELSIZE != 3 + if (cinfo->input_components != RGB_PIXELSIZE) + ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE); + break; +#endif /* else share code with YCbCr */ + + case JCS_YCbCr: + if (cinfo->input_components != 3) + ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE); + break; + + case JCS_CMYK: + case JCS_YCCK: + if (cinfo->input_components != 4) + ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE); + break; + + default: /* JCS_UNKNOWN can be anything */ + if (cinfo->input_components < 1) + ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE); + break; + } + + /* Check num_components, set conversion method based on requested space */ + switch (cinfo->jpeg_color_space) { + case JCS_GRAYSCALE: + if (cinfo->num_components != 1) + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + if (cinfo->in_color_space == JCS_GRAYSCALE) + cconvert->pub.color_convert = grayscale_convert; + else if (cinfo->in_color_space == JCS_RGB) { + cconvert->pub.start_pass = rgb_ycc_start; + cconvert->pub.color_convert = rgb_gray_convert; + } else if (cinfo->in_color_space == JCS_YCbCr) + cconvert->pub.color_convert = grayscale_convert; + else + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + break; + + case JCS_RGB: + if (cinfo->num_components != 3) + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + if (cinfo->in_color_space == JCS_RGB && RGB_PIXELSIZE == 3) + cconvert->pub.color_convert = null_convert; + else + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + break; + + case JCS_YCbCr: + if (cinfo->num_components != 3) + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + if (cinfo->in_color_space == JCS_RGB) { + cconvert->pub.start_pass = rgb_ycc_start; + cconvert->pub.color_convert = rgb_ycc_convert; + } else if (cinfo->in_color_space == JCS_YCbCr) + cconvert->pub.color_convert = null_convert; + else + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + break; + + case JCS_CMYK: + if (cinfo->num_components != 4) + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + if (cinfo->in_color_space == JCS_CMYK) + cconvert->pub.color_convert = null_convert; + else + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + break; + + case JCS_YCCK: + if (cinfo->num_components != 4) + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + if (cinfo->in_color_space == JCS_CMYK) { + cconvert->pub.start_pass = rgb_ycc_start; + cconvert->pub.color_convert = cmyk_ycck_convert; + } else if (cinfo->in_color_space == JCS_YCCK) + cconvert->pub.color_convert = null_convert; + else + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + break; + + default: /* allow null conversion of JCS_UNKNOWN */ + if (cinfo->jpeg_color_space != cinfo->in_color_space || + cinfo->num_components != cinfo->input_components) + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + cconvert->pub.color_convert = null_convert; + break; + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jcdctmgr.cpp b/ml/dlib/dlib/external/libjpeg/jcdctmgr.cpp new file mode 100644 index 000000000..cbfc1a857 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcdctmgr.cpp @@ -0,0 +1,387 @@ +/* + * jcdctmgr.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains the forward-DCT management logic. + * This code selects a particular DCT implementation to be used, + * and it performs related housekeeping chores including coefficient + * quantization. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdct.h" /* Private declarations for DCT subsystem */ + + +/* Private subobject for this module */ + +typedef struct { + struct jpeg_forward_dct pub; /* public fields */ + + /* Pointer to the DCT routine actually in use */ + forward_DCT_method_ptr do_dct; + + /* The actual post-DCT divisors --- not identical to the quant table + * entries, because of scaling (especially for an unnormalized DCT). + * Each table is given in normal array order. + */ + DCTELEM * divisors[NUM_QUANT_TBLS]; + +#ifdef DCT_FLOAT_SUPPORTED + /* Same as above for the floating-point case. */ + float_DCT_method_ptr do_float_dct; + FAST_FLOAT * float_divisors[NUM_QUANT_TBLS]; +#endif +} my_fdct_controller; + +typedef my_fdct_controller * my_fdct_ptr; + + +/* + * Initialize for a processing pass. + * Verify that all referenced Q-tables are present, and set up + * the divisor table for each one. + * In the current implementation, DCT of all components is done during + * the first pass, even if only some components will be output in the + * first scan. Hence all components should be examined here. + */ + +METHODDEF(void) +start_pass_fdctmgr (j_compress_ptr cinfo) +{ + my_fdct_ptr fdct = (my_fdct_ptr) cinfo->fdct; + int ci, qtblno, i; + jpeg_component_info *compptr; + JQUANT_TBL * qtbl; + DCTELEM * dtbl; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + qtblno = compptr->quant_tbl_no; + /* Make sure specified quantization table is present */ + if (qtblno < 0 || qtblno >= NUM_QUANT_TBLS || + cinfo->quant_tbl_ptrs[qtblno] == NULL) + ERREXIT1(cinfo, JERR_NO_QUANT_TABLE, qtblno); + qtbl = cinfo->quant_tbl_ptrs[qtblno]; + /* Compute divisors for this quant table */ + /* We may do this more than once for same table, but it's not a big deal */ + switch (cinfo->dct_method) { +#ifdef DCT_ISLOW_SUPPORTED + case JDCT_ISLOW: + /* For LL&M IDCT method, divisors are equal to raw quantization + * coefficients multiplied by 8 (to counteract scaling). + */ + if (fdct->divisors[qtblno] == NULL) { + fdct->divisors[qtblno] = (DCTELEM *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + DCTSIZE2 * SIZEOF(DCTELEM)); + } + dtbl = fdct->divisors[qtblno]; + for (i = 0; i < DCTSIZE2; i++) { + dtbl[i] = ((DCTELEM) qtbl->quantval[i]) << 3; + } + break; +#endif +#ifdef DCT_IFAST_SUPPORTED + case JDCT_IFAST: + { + /* For AA&N IDCT method, divisors are equal to quantization + * coefficients scaled by scalefactor[row]*scalefactor[col], where + * scalefactor[0] = 1 + * scalefactor[k] = cos(k*PI/16) * sqrt(2) for k=1..7 + * We apply a further scale factor of 8. + */ +#define CONST_BITS 14 + static const short aanscales[DCTSIZE2] = { + /* precomputed values scaled up by 14 bits */ + 16384, 22725, 21407, 19266, 16384, 12873, 8867, 4520, + 22725, 31521, 29692, 26722, 22725, 17855, 12299, 6270, + 21407, 29692, 27969, 25172, 21407, 16819, 11585, 5906, + 19266, 26722, 25172, 22654, 19266, 15137, 10426, 5315, + 16384, 22725, 21407, 19266, 16384, 12873, 8867, 4520, + 12873, 17855, 16819, 15137, 12873, 10114, 6967, 3552, + 8867, 12299, 11585, 10426, 8867, 6967, 4799, 2446, + 4520, 6270, 5906, 5315, 4520, 3552, 2446, 1247 + }; + SHIFT_TEMPS + + if (fdct->divisors[qtblno] == NULL) { + fdct->divisors[qtblno] = (DCTELEM *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + DCTSIZE2 * SIZEOF(DCTELEM)); + } + dtbl = fdct->divisors[qtblno]; + for (i = 0; i < DCTSIZE2; i++) { + dtbl[i] = (DCTELEM) + DESCALE(MULTIPLY16V16((long) qtbl->quantval[i], + (long) aanscales[i]), + CONST_BITS-3); + } + } + break; +#endif +#ifdef DCT_FLOAT_SUPPORTED + case JDCT_FLOAT: + { + /* For float AA&N IDCT method, divisors are equal to quantization + * coefficients scaled by scalefactor[row]*scalefactor[col], where + * scalefactor[0] = 1 + * scalefactor[k] = cos(k*PI/16) * sqrt(2) for k=1..7 + * We apply a further scale factor of 8. + * What's actually stored is 1/divisor so that the inner loop can + * use a multiplication rather than a division. + */ + FAST_FLOAT * fdtbl; + int row, col; + static const double aanscalefactor[DCTSIZE] = { + 1.0, 1.387039845, 1.306562965, 1.175875602, + 1.0, 0.785694958, 0.541196100, 0.275899379 + }; + + if (fdct->float_divisors[qtblno] == NULL) { + fdct->float_divisors[qtblno] = (FAST_FLOAT *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + DCTSIZE2 * SIZEOF(FAST_FLOAT)); + } + fdtbl = fdct->float_divisors[qtblno]; + i = 0; + for (row = 0; row < DCTSIZE; row++) { + for (col = 0; col < DCTSIZE; col++) { + fdtbl[i] = (FAST_FLOAT) + (1.0 / (((double) qtbl->quantval[i] * + aanscalefactor[row] * aanscalefactor[col] * 8.0))); + i++; + } + } + } + break; +#endif + default: + ERREXIT(cinfo, JERR_NOT_COMPILED); + break; + } + } +} + + +/* + * Perform forward DCT on one or more blocks of a component. + * + * The input samples are taken from the sample_data[] array starting at + * position start_row/start_col, and moving to the right for any additional + * blocks. The quantized coefficients are returned in coef_blocks[]. + */ + +METHODDEF(void) +forward_DCT (j_compress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY sample_data, JBLOCKROW coef_blocks, + JDIMENSION start_row, JDIMENSION start_col, + JDIMENSION num_blocks) +/* This version is used for integer DCT implementations. */ +{ + /* This routine is heavily used, so it's worth coding it tightly. */ + my_fdct_ptr fdct = (my_fdct_ptr) cinfo->fdct; + forward_DCT_method_ptr do_dct = fdct->do_dct; + DCTELEM * divisors = fdct->divisors[compptr->quant_tbl_no]; + DCTELEM workspace[DCTSIZE2]; /* work area for FDCT subroutine */ + JDIMENSION bi; + + sample_data += start_row; /* fold in the vertical offset once */ + + for (bi = 0; bi < num_blocks; bi++, start_col += DCTSIZE) { + /* Load data into workspace, applying unsigned->signed conversion */ + { DCTELEM *workspaceptr; + JSAMPROW elemptr; + int elemr; + + workspaceptr = workspace; + for (elemr = 0; elemr < DCTSIZE; elemr++) { + elemptr = sample_data[elemr] + start_col; +#if DCTSIZE == 8 /* unroll the inner loop */ + *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE; + *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE; + *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE; + *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE; + *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE; + *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE; + *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE; + *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE; +#else + { int elemc; + for (elemc = DCTSIZE; elemc > 0; elemc--) { + *workspaceptr++ = GETJSAMPLE(*elemptr++) - CENTERJSAMPLE; + } + } +#endif + } + } + + /* Perform the DCT */ + (*do_dct) (workspace); + + /* Quantize/descale the coefficients, and store into coef_blocks[] */ + { DCTELEM temp, qval; + int i; + JCOEFPTR output_ptr = coef_blocks[bi]; + + for (i = 0; i < DCTSIZE2; i++) { + qval = divisors[i]; + temp = workspace[i]; + /* Divide the coefficient value by qval, ensuring proper rounding. + * Since C does not specify the direction of rounding for negative + * quotients, we have to force the dividend positive for portability. + * + * In most files, at least half of the output values will be zero + * (at default quantization settings, more like three-quarters...) + * so we should ensure that this case is fast. On many machines, + * a comparison is enough cheaper than a divide to make a special test + * a win. Since both inputs will be nonnegative, we need only test + * for a < b to discover whether a/b is 0. + * If your machine's division is fast enough, define FAST_DIVIDE. + */ +#ifdef FAST_DIVIDE +#define DIVIDE_BY(a,b) a /= b +#else +#define DIVIDE_BY(a,b) if (a >= b) a /= b; else a = 0 +#endif + if (temp < 0) { + temp = -temp; + temp += qval>>1; /* for rounding */ + DIVIDE_BY(temp, qval); + temp = -temp; + } else { + temp += qval>>1; /* for rounding */ + DIVIDE_BY(temp, qval); + } + output_ptr[i] = (JCOEF) temp; + } + } + } +} + + +#ifdef DCT_FLOAT_SUPPORTED + +METHODDEF(void) +forward_DCT_float (j_compress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY sample_data, JBLOCKROW coef_blocks, + JDIMENSION start_row, JDIMENSION start_col, + JDIMENSION num_blocks) +/* This version is used for floating-point DCT implementations. */ +{ + /* This routine is heavily used, so it's worth coding it tightly. */ + my_fdct_ptr fdct = (my_fdct_ptr) cinfo->fdct; + float_DCT_method_ptr do_dct = fdct->do_float_dct; + FAST_FLOAT * divisors = fdct->float_divisors[compptr->quant_tbl_no]; + FAST_FLOAT workspace[DCTSIZE2]; /* work area for FDCT subroutine */ + JDIMENSION bi; + + sample_data += start_row; /* fold in the vertical offset once */ + + for (bi = 0; bi < num_blocks; bi++, start_col += DCTSIZE) { + /* Load data into workspace, applying unsigned->signed conversion */ + { FAST_FLOAT *workspaceptr; + JSAMPROW elemptr; + int elemr; + + workspaceptr = workspace; + for (elemr = 0; elemr < DCTSIZE; elemr++) { + elemptr = sample_data[elemr] + start_col; +#if DCTSIZE == 8 /* unroll the inner loop */ + *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE); + *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE); + *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE); + *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE); + *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE); + *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE); + *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE); + *workspaceptr++ = (FAST_FLOAT)(GETJSAMPLE(*elemptr++) - CENTERJSAMPLE); +#else + { int elemc; + for (elemc = DCTSIZE; elemc > 0; elemc--) { + *workspaceptr++ = (FAST_FLOAT) + (GETJSAMPLE(*elemptr++) - CENTERJSAMPLE); + } + } +#endif + } + } + + /* Perform the DCT */ + (*do_dct) (workspace); + + /* Quantize/descale the coefficients, and store into coef_blocks[] */ + { FAST_FLOAT temp; + int i; + JCOEFPTR output_ptr = coef_blocks[bi]; + + for (i = 0; i < DCTSIZE2; i++) { + /* Apply the quantization and scaling factor */ + temp = workspace[i] * divisors[i]; + /* Round to nearest integer. + * Since C does not specify the direction of rounding for negative + * quotients, we have to force the dividend positive for portability. + * The maximum coefficient size is +-16K (for 12-bit data), so this + * code should work for either 16-bit or 32-bit ints. + */ + output_ptr[i] = (JCOEF) ((int) (temp + (FAST_FLOAT) 16384.5) - 16384); + } + } + } +} + +#endif /* DCT_FLOAT_SUPPORTED */ + + +/* + * Initialize FDCT manager. + */ + +GLOBAL(void) +jinit_forward_dct (j_compress_ptr cinfo) +{ + my_fdct_ptr fdct; + int i; + + fdct = (my_fdct_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_fdct_controller)); + cinfo->fdct = (struct jpeg_forward_dct *) fdct; + fdct->pub.start_pass = start_pass_fdctmgr; + + switch (cinfo->dct_method) { +#ifdef DCT_ISLOW_SUPPORTED + case JDCT_ISLOW: + fdct->pub.forward_DCT = forward_DCT; + fdct->do_dct = jpeg_fdct_islow; + break; +#endif +#ifdef DCT_IFAST_SUPPORTED + case JDCT_IFAST: + fdct->pub.forward_DCT = forward_DCT; + fdct->do_dct = jpeg_fdct_ifast; + break; +#endif +#ifdef DCT_FLOAT_SUPPORTED + case JDCT_FLOAT: + fdct->pub.forward_DCT = forward_DCT_float; + fdct->do_float_dct = jpeg_fdct_float; + break; +#endif + default: + ERREXIT(cinfo, JERR_NOT_COMPILED); + break; + } + + /* Mark divisor tables unallocated */ + for (i = 0; i < NUM_QUANT_TBLS; i++) { + fdct->divisors[i] = NULL; +#ifdef DCT_FLOAT_SUPPORTED + fdct->float_divisors[i] = NULL; +#endif + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jchuff.cpp b/ml/dlib/dlib/external/libjpeg/jchuff.cpp new file mode 100644 index 000000000..d543319a6 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jchuff.cpp @@ -0,0 +1,909 @@ +/* + * jchuff.c + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains Huffman entropy encoding routines. + * + * Much of the complexity here has to do with supporting output suspension. + * If the data destination module demands suspension, we want to be able to + * back up to the start of the current MCU. To do this, we copy state + * variables into local working storage, and update them back to the + * permanent JPEG objects only upon successful completion of an MCU. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jchuff.h" /* Declarations shared with jcphuff.c */ + + +/* Expanded entropy encoder object for Huffman encoding. + * + * The savable_state subrecord contains fields that change within an MCU, + * but must not be updated permanently until we complete the MCU. + */ + +typedef struct { + long put_buffer; /* current bit-accumulation buffer */ + int put_bits; /* # of bits now in it */ + int last_dc_val[MAX_COMPS_IN_SCAN]; /* last DC coef for each component */ +} savable_state; + +/* This macro is to work around compilers with missing or broken + * structure assignment. You'll need to fix this code if you have + * such a compiler and you change MAX_COMPS_IN_SCAN. + */ + +#ifndef NO_STRUCT_ASSIGN +#define ASSIGN_STATE(dest,src) ((dest) = (src)) +#else +#if MAX_COMPS_IN_SCAN == 4 +#define ASSIGN_STATE(dest,src) \ + ((dest).put_buffer = (src).put_buffer, \ + (dest).put_bits = (src).put_bits, \ + (dest).last_dc_val[0] = (src).last_dc_val[0], \ + (dest).last_dc_val[1] = (src).last_dc_val[1], \ + (dest).last_dc_val[2] = (src).last_dc_val[2], \ + (dest).last_dc_val[3] = (src).last_dc_val[3]) +#endif +#endif + + +typedef struct { + struct jpeg_entropy_encoder pub; /* public fields */ + + savable_state saved; /* Bit buffer & DC state at start of MCU */ + + /* These fields are NOT loaded into local working state. */ + unsigned int restarts_to_go; /* MCUs left in this restart interval */ + int next_restart_num; /* next restart number to write (0-7) */ + + /* Pointers to derived tables (these workspaces have image lifespan) */ + c_derived_tbl * dc_derived_tbls[NUM_HUFF_TBLS]; + c_derived_tbl * ac_derived_tbls[NUM_HUFF_TBLS]; + +#ifdef ENTROPY_OPT_SUPPORTED /* Statistics tables for optimization */ + long * dc_count_ptrs[NUM_HUFF_TBLS]; + long * ac_count_ptrs[NUM_HUFF_TBLS]; +#endif +} huff_entropy_encoder; + +typedef huff_entropy_encoder * huff_entropy_ptr; + +/* Working state while writing an MCU. + * This struct contains all the fields that are needed by subroutines. + */ + +typedef struct { + JOCTET * next_output_byte; /* => next byte to write in buffer */ + size_t free_in_buffer; /* # of byte spaces remaining in buffer */ + savable_state cur; /* Current bit buffer & DC state */ + j_compress_ptr cinfo; /* dump_buffer needs access to this */ +} working_state; + + +/* Forward declarations */ +METHODDEF(int) encode_mcu_huff JPP((j_compress_ptr cinfo, + JBLOCKROW *MCU_data)); +METHODDEF(void) finish_pass_huff JPP((j_compress_ptr cinfo)); +#ifdef ENTROPY_OPT_SUPPORTED +METHODDEF(int) encode_mcu_gather JPP((j_compress_ptr cinfo, + JBLOCKROW *MCU_data)); +METHODDEF(void) finish_pass_gather JPP((j_compress_ptr cinfo)); +#endif + + +/* + * Initialize for a Huffman-compressed scan. + * If gather_statistics is TRUE, we do not output anything during the scan, + * just count the Huffman symbols used and generate Huffman code tables. + */ + +METHODDEF(void) +start_pass_huff (j_compress_ptr cinfo, int gather_statistics) +{ + huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy; + int ci, dctbl, actbl; + jpeg_component_info * compptr; + + if (gather_statistics) { +#ifdef ENTROPY_OPT_SUPPORTED + entropy->pub.encode_mcu = encode_mcu_gather; + entropy->pub.finish_pass = finish_pass_gather; +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } else { + entropy->pub.encode_mcu = encode_mcu_huff; + entropy->pub.finish_pass = finish_pass_huff; + } + + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + dctbl = compptr->dc_tbl_no; + actbl = compptr->ac_tbl_no; + if (gather_statistics) { +#ifdef ENTROPY_OPT_SUPPORTED + /* Check for invalid table indexes */ + /* (make_c_derived_tbl does this in the other path) */ + if (dctbl < 0 || dctbl >= NUM_HUFF_TBLS) + ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, dctbl); + if (actbl < 0 || actbl >= NUM_HUFF_TBLS) + ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, actbl); + /* Allocate and zero the statistics tables */ + /* Note that jpeg_gen_optimal_table expects 257 entries in each table! */ + if (entropy->dc_count_ptrs[dctbl] == NULL) + entropy->dc_count_ptrs[dctbl] = (long *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + 257 * SIZEOF(long)); + MEMZERO(entropy->dc_count_ptrs[dctbl], 257 * SIZEOF(long)); + if (entropy->ac_count_ptrs[actbl] == NULL) + entropy->ac_count_ptrs[actbl] = (long *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + 257 * SIZEOF(long)); + MEMZERO(entropy->ac_count_ptrs[actbl], 257 * SIZEOF(long)); +#endif + } else { + /* Compute derived values for Huffman tables */ + /* We may do this more than once for a table, but it's not expensive */ + jpeg_make_c_derived_tbl(cinfo, TRUE, dctbl, + & entropy->dc_derived_tbls[dctbl]); + jpeg_make_c_derived_tbl(cinfo, FALSE, actbl, + & entropy->ac_derived_tbls[actbl]); + } + /* Initialize DC predictions to 0 */ + entropy->saved.last_dc_val[ci] = 0; + } + + /* Initialize bit buffer to empty */ + entropy->saved.put_buffer = 0; + entropy->saved.put_bits = 0; + + /* Initialize restart stuff */ + entropy->restarts_to_go = cinfo->restart_interval; + entropy->next_restart_num = 0; +} + + +/* + * Compute the derived values for a Huffman table. + * This routine also performs some validation checks on the table. + * + * Note this is also used by jcphuff.c. + */ + +GLOBAL(void) +jpeg_make_c_derived_tbl (j_compress_ptr cinfo, int isDC, int tblno, + c_derived_tbl ** pdtbl) +{ + JHUFF_TBL *htbl; + c_derived_tbl *dtbl; + int p, i, l, lastp, si, maxsymbol; + char huffsize[257]; + unsigned int huffcode[257]; + unsigned int code; + + /* Note that huffsize[] and huffcode[] are filled in code-length order, + * paralleling the order of the symbols themselves in htbl->huffval[]. + */ + + /* Find the input Huffman table */ + if (tblno < 0 || tblno >= NUM_HUFF_TBLS) + ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tblno); + htbl = + isDC ? cinfo->dc_huff_tbl_ptrs[tblno] : cinfo->ac_huff_tbl_ptrs[tblno]; + if (htbl == NULL) + ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tblno); + + /* Allocate a workspace if we haven't already done so. */ + if (*pdtbl == NULL) + *pdtbl = (c_derived_tbl *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(c_derived_tbl)); + dtbl = *pdtbl; + + /* Figure C.1: make table of Huffman code length for each symbol */ + + p = 0; + for (l = 1; l <= 16; l++) { + i = (int) htbl->bits[l]; + if (i < 0 || p + i > 256) /* protect against table overrun */ + ERREXIT(cinfo, JERR_BAD_HUFF_TABLE); + while (i--) + huffsize[p++] = (char) l; + } + huffsize[p] = 0; + lastp = p; + + /* Figure C.2: generate the codes themselves */ + /* We also validate that the counts represent a legal Huffman code tree. */ + + code = 0; + si = huffsize[0]; + p = 0; + while (huffsize[p]) { + while (((int) huffsize[p]) == si) { + huffcode[p++] = code; + code++; + } + /* code is now 1 more than the last code used for codelength si; but + * it must still fit in si bits, since no code is allowed to be all ones. + */ + if (((long) code) >= (((long) 1) << si)) + ERREXIT(cinfo, JERR_BAD_HUFF_TABLE); + code <<= 1; + si++; + } + + /* Figure C.3: generate encoding tables */ + /* These are code and size indexed by symbol value */ + + /* Set all codeless symbols to have code length 0; + * this lets us detect duplicate VAL entries here, and later + * allows emit_bits to detect any attempt to emit such symbols. + */ + MEMZERO(dtbl->ehufsi, SIZEOF(dtbl->ehufsi)); + + /* This is also a convenient place to check for out-of-range + * and duplicated VAL entries. We allow 0..255 for AC symbols + * but only 0..15 for DC. (We could constrain them further + * based on data depth and mode, but this seems enough.) + */ + maxsymbol = isDC ? 15 : 255; + + for (p = 0; p < lastp; p++) { + i = htbl->huffval[p]; + if (i < 0 || i > maxsymbol || dtbl->ehufsi[i]) + ERREXIT(cinfo, JERR_BAD_HUFF_TABLE); + dtbl->ehufco[i] = huffcode[p]; + dtbl->ehufsi[i] = huffsize[p]; + } +} + + +/* Outputting bytes to the file */ + +/* Emit a byte, taking 'action' if must suspend. */ +#define emit_byte(state,val,action) \ + { *(state)->next_output_byte++ = (JOCTET) (val); \ + if (--(state)->free_in_buffer == 0) \ + if (! dump_buffer(state)) \ + { action; } } + + +LOCAL(int) +dump_buffer (working_state * state) +/* Empty the output buffer; return TRUE if successful, FALSE if must suspend */ +{ + struct jpeg_destination_mgr * dest = state->cinfo->dest; + + if (! (*dest->empty_output_buffer) (state->cinfo)) + return FALSE; + /* After a successful buffer dump, must reset buffer pointers */ + state->next_output_byte = dest->next_output_byte; + state->free_in_buffer = dest->free_in_buffer; + return TRUE; +} + + +/* Outputting bits to the file */ + +/* Only the right 24 bits of put_buffer are used; the valid bits are + * left-justified in this part. At most 16 bits can be passed to emit_bits + * in one call, and we never retain more than 7 bits in put_buffer + * between calls, so 24 bits are sufficient. + */ + +inline +LOCAL(int) +emit_bits (working_state * state, unsigned int code, int size) +/* Emit some bits; return TRUE if successful, FALSE if must suspend */ +{ + /* This routine is heavily used, so it's worth coding tightly. */ + long put_buffer = (long) code; + int put_bits = state->cur.put_bits; + + /* if size is 0, caller used an invalid Huffman table entry */ + if (size == 0) + ERREXIT(state->cinfo, JERR_HUFF_MISSING_CODE); + + put_buffer &= (((long) 1)<cur.put_buffer; /* and merge with old buffer contents */ + + while (put_bits >= 8) { + int c = (int) ((put_buffer >> 16) & 0xFF); + + emit_byte(state, c, return FALSE); + if (c == 0xFF) { /* need to stuff a zero byte? */ + emit_byte(state, 0, return FALSE); + } + put_buffer <<= 8; + put_bits -= 8; + } + + state->cur.put_buffer = put_buffer; /* update state variables */ + state->cur.put_bits = put_bits; + + return TRUE; +} + + +LOCAL(int) +flush_bits (working_state * state) +{ + if (! emit_bits(state, 0x7F, 7)) /* fill any partial byte with ones */ + return FALSE; + state->cur.put_buffer = 0; /* and reset bit-buffer to empty */ + state->cur.put_bits = 0; + return TRUE; +} + + +/* Encode a single block's worth of coefficients */ + +LOCAL(int) +encode_one_block (working_state * state, JCOEFPTR block, int last_dc_val, + c_derived_tbl *dctbl, c_derived_tbl *actbl) +{ + int temp, temp2; + int nbits; + int k, r, i; + + /* Encode the DC coefficient difference per section F.1.2.1 */ + + temp = temp2 = block[0] - last_dc_val; + + if (temp < 0) { + temp = -temp; /* temp is abs value of input */ + /* For a negative input, want temp2 = bitwise complement of abs(input) */ + /* This code assumes we are on a two's complement machine */ + temp2--; + } + + /* Find the number of bits needed for the magnitude of the coefficient */ + nbits = 0; + while (temp) { + nbits++; + temp >>= 1; + } + /* Check for out-of-range coefficient values. + * Since we're encoding a difference, the range limit is twice as much. + */ + if (nbits > MAX_COEF_BITS+1) + ERREXIT(state->cinfo, JERR_BAD_DCT_COEF); + + /* Emit the Huffman-coded symbol for the number of bits */ + if (! emit_bits(state, dctbl->ehufco[nbits], dctbl->ehufsi[nbits])) + return FALSE; + + /* Emit that number of bits of the value, if positive, */ + /* or the complement of its magnitude, if negative. */ + if (nbits) /* emit_bits rejects calls with size 0 */ + if (! emit_bits(state, (unsigned int) temp2, nbits)) + return FALSE; + + /* Encode the AC coefficients per section F.1.2.2 */ + + r = 0; /* r = run length of zeros */ + + for (k = 1; k < DCTSIZE2; k++) { + if ((temp = block[jpeg_natural_order[k]]) == 0) { + r++; + } else { + /* if run length > 15, must emit special run-length-16 codes (0xF0) */ + while (r > 15) { + if (! emit_bits(state, actbl->ehufco[0xF0], actbl->ehufsi[0xF0])) + return FALSE; + r -= 16; + } + + temp2 = temp; + if (temp < 0) { + temp = -temp; /* temp is abs value of input */ + /* This code assumes we are on a two's complement machine */ + temp2--; + } + + /* Find the number of bits needed for the magnitude of the coefficient */ + nbits = 1; /* there must be at least one 1 bit */ + while ((temp >>= 1)) + nbits++; + /* Check for out-of-range coefficient values */ + if (nbits > MAX_COEF_BITS) + ERREXIT(state->cinfo, JERR_BAD_DCT_COEF); + + /* Emit Huffman symbol for run length / number of bits */ + i = (r << 4) + nbits; + if (! emit_bits(state, actbl->ehufco[i], actbl->ehufsi[i])) + return FALSE; + + /* Emit that number of bits of the value, if positive, */ + /* or the complement of its magnitude, if negative. */ + if (! emit_bits(state, (unsigned int) temp2, nbits)) + return FALSE; + + r = 0; + } + } + + /* If the last coef(s) were zero, emit an end-of-block code */ + if (r > 0) + if (! emit_bits(state, actbl->ehufco[0], actbl->ehufsi[0])) + return FALSE; + + return TRUE; +} + + +/* + * Emit a restart marker & resynchronize predictions. + */ + +LOCAL(int) +emit_restart (working_state * state, int restart_num) +{ + int ci; + + if (! flush_bits(state)) + return FALSE; + + emit_byte(state, 0xFF, return FALSE); + emit_byte(state, JPEG_RST0 + restart_num, return FALSE); + + /* Re-initialize DC predictions to 0 */ + for (ci = 0; ci < state->cinfo->comps_in_scan; ci++) + state->cur.last_dc_val[ci] = 0; + + /* The restart counter is not updated until we successfully write the MCU. */ + + return TRUE; +} + + +/* + * Encode and output one MCU's worth of Huffman-compressed coefficients. + */ + +METHODDEF(int) +encode_mcu_huff (j_compress_ptr cinfo, JBLOCKROW *MCU_data) +{ + huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy; + working_state state; + int blkn, ci; + jpeg_component_info * compptr; + + /* Load up working state */ + state.next_output_byte = cinfo->dest->next_output_byte; + state.free_in_buffer = cinfo->dest->free_in_buffer; + ASSIGN_STATE(state.cur, entropy->saved); + state.cinfo = cinfo; + + /* Emit restart marker if needed */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) + if (! emit_restart(&state, entropy->next_restart_num)) + return FALSE; + } + + /* Encode the MCU data blocks */ + for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) { + ci = cinfo->MCU_membership[blkn]; + compptr = cinfo->cur_comp_info[ci]; + if (! encode_one_block(&state, + MCU_data[blkn][0], state.cur.last_dc_val[ci], + entropy->dc_derived_tbls[compptr->dc_tbl_no], + entropy->ac_derived_tbls[compptr->ac_tbl_no])) + return FALSE; + /* Update last_dc_val */ + state.cur.last_dc_val[ci] = MCU_data[blkn][0][0]; + } + + /* Completed MCU, so update state */ + cinfo->dest->next_output_byte = state.next_output_byte; + cinfo->dest->free_in_buffer = state.free_in_buffer; + ASSIGN_STATE(entropy->saved, state.cur); + + /* Update restart-interval state too */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) { + entropy->restarts_to_go = cinfo->restart_interval; + entropy->next_restart_num++; + entropy->next_restart_num &= 7; + } + entropy->restarts_to_go--; + } + + return TRUE; +} + + +/* + * Finish up at the end of a Huffman-compressed scan. + */ + +METHODDEF(void) +finish_pass_huff (j_compress_ptr cinfo) +{ + huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy; + working_state state; + + /* Load up working state ... flush_bits needs it */ + state.next_output_byte = cinfo->dest->next_output_byte; + state.free_in_buffer = cinfo->dest->free_in_buffer; + ASSIGN_STATE(state.cur, entropy->saved); + state.cinfo = cinfo; + + /* Flush out the last data */ + if (! flush_bits(&state)) + ERREXIT(cinfo, JERR_CANT_SUSPEND); + + /* Update state */ + cinfo->dest->next_output_byte = state.next_output_byte; + cinfo->dest->free_in_buffer = state.free_in_buffer; + ASSIGN_STATE(entropy->saved, state.cur); +} + + +/* + * Huffman coding optimization. + * + * We first scan the supplied data and count the number of uses of each symbol + * that is to be Huffman-coded. (This process MUST agree with the code above.) + * Then we build a Huffman coding tree for the observed counts. + * Symbols which are not needed at all for the particular image are not + * assigned any code, which saves space in the DHT marker as well as in + * the compressed data. + */ + +#ifdef ENTROPY_OPT_SUPPORTED + + +/* Process a single block's worth of coefficients */ + +LOCAL(void) +htest_one_block (j_compress_ptr cinfo, JCOEFPTR block, int last_dc_val, + long dc_counts[], long ac_counts[]) +{ + int temp; + int nbits; + int k, r; + + /* Encode the DC coefficient difference per section F.1.2.1 */ + + temp = block[0] - last_dc_val; + if (temp < 0) + temp = -temp; + + /* Find the number of bits needed for the magnitude of the coefficient */ + nbits = 0; + while (temp) { + nbits++; + temp >>= 1; + } + /* Check for out-of-range coefficient values. + * Since we're encoding a difference, the range limit is twice as much. + */ + if (nbits > MAX_COEF_BITS+1) + ERREXIT(cinfo, JERR_BAD_DCT_COEF); + + /* Count the Huffman symbol for the number of bits */ + dc_counts[nbits]++; + + /* Encode the AC coefficients per section F.1.2.2 */ + + r = 0; /* r = run length of zeros */ + + for (k = 1; k < DCTSIZE2; k++) { + if ((temp = block[jpeg_natural_order[k]]) == 0) { + r++; + } else { + /* if run length > 15, must emit special run-length-16 codes (0xF0) */ + while (r > 15) { + ac_counts[0xF0]++; + r -= 16; + } + + /* Find the number of bits needed for the magnitude of the coefficient */ + if (temp < 0) + temp = -temp; + + /* Find the number of bits needed for the magnitude of the coefficient */ + nbits = 1; /* there must be at least one 1 bit */ + while ((temp >>= 1)) + nbits++; + /* Check for out-of-range coefficient values */ + if (nbits > MAX_COEF_BITS) + ERREXIT(cinfo, JERR_BAD_DCT_COEF); + + /* Count Huffman symbol for run length / number of bits */ + ac_counts[(r << 4) + nbits]++; + + r = 0; + } + } + + /* If the last coef(s) were zero, emit an end-of-block code */ + if (r > 0) + ac_counts[0]++; +} + + +/* + * Trial-encode one MCU's worth of Huffman-compressed coefficients. + * No data is actually output, so no suspension return is possible. + */ + +METHODDEF(int) +encode_mcu_gather (j_compress_ptr cinfo, JBLOCKROW *MCU_data) +{ + huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy; + int blkn, ci; + jpeg_component_info * compptr; + + /* Take care of restart intervals if needed */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) { + /* Re-initialize DC predictions to 0 */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) + entropy->saved.last_dc_val[ci] = 0; + /* Update restart state */ + entropy->restarts_to_go = cinfo->restart_interval; + } + entropy->restarts_to_go--; + } + + for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) { + ci = cinfo->MCU_membership[blkn]; + compptr = cinfo->cur_comp_info[ci]; + htest_one_block(cinfo, MCU_data[blkn][0], entropy->saved.last_dc_val[ci], + entropy->dc_count_ptrs[compptr->dc_tbl_no], + entropy->ac_count_ptrs[compptr->ac_tbl_no]); + entropy->saved.last_dc_val[ci] = MCU_data[blkn][0][0]; + } + + return TRUE; +} + + +/* + * Generate the best Huffman code table for the given counts, fill htbl. + * Note this is also used by jcphuff.c. + * + * The JPEG standard requires that no symbol be assigned a codeword of all + * one bits (so that padding bits added at the end of a compressed segment + * can't look like a valid code). Because of the canonical ordering of + * codewords, this just means that there must be an unused slot in the + * longest codeword length category. Section K.2 of the JPEG spec suggests + * reserving such a slot by pretending that symbol 256 is a valid symbol + * with count 1. In theory that's not optimal; giving it count zero but + * including it in the symbol set anyway should give a better Huffman code. + * But the theoretically better code actually seems to come out worse in + * practice, because it produces more all-ones bytes (which incur stuffed + * zero bytes in the final file). In any case the difference is tiny. + * + * The JPEG standard requires Huffman codes to be no more than 16 bits long. + * If some symbols have a very small but nonzero probability, the Huffman tree + * must be adjusted to meet the code length restriction. We currently use + * the adjustment method suggested in JPEG section K.2. This method is *not* + * optimal; it may not choose the best possible limited-length code. But + * typically only very-low-frequency symbols will be given less-than-optimal + * lengths, so the code is almost optimal. Experimental comparisons against + * an optimal limited-length-code algorithm indicate that the difference is + * microscopic --- usually less than a hundredth of a percent of total size. + * So the extra complexity of an optimal algorithm doesn't seem worthwhile. + */ + +GLOBAL(void) +jpeg_gen_optimal_table (j_compress_ptr cinfo, JHUFF_TBL * htbl, long freq[]) +{ +#define MAX_CLEN 32 /* assumed maximum initial code length */ + unsigned short bits[MAX_CLEN+1]; /* bits[k] = # of symbols with code length k */ + int codesize[257]; /* codesize[k] = code length of symbol k */ + int others[257]; /* next symbol in current branch of tree */ + int c1, c2; + int p, i, j; + long v; + + /* This algorithm is explained in section K.2 of the JPEG standard */ + + MEMZERO(bits, SIZEOF(bits)); + MEMZERO(codesize, SIZEOF(codesize)); + for (i = 0; i < 257; i++) + others[i] = -1; /* init links to empty */ + + freq[256] = 1; /* make sure 256 has a nonzero count */ + /* Including the pseudo-symbol 256 in the Huffman procedure guarantees + * that no real symbol is given code-value of all ones, because 256 + * will be placed last in the largest codeword category. + */ + + /* Huffman's basic algorithm to assign optimal code lengths to symbols */ + + for (;;) { + /* Find the smallest nonzero frequency, set c1 = its symbol */ + /* In case of ties, take the larger symbol number */ + c1 = -1; + v = 1000000000L; + for (i = 0; i <= 256; i++) { + if (freq[i] && freq[i] <= v) { + v = freq[i]; + c1 = i; + } + } + + /* Find the next smallest nonzero frequency, set c2 = its symbol */ + /* In case of ties, take the larger symbol number */ + c2 = -1; + v = 1000000000L; + for (i = 0; i <= 256; i++) { + if (freq[i] && freq[i] <= v && i != c1) { + v = freq[i]; + c2 = i; + } + } + + /* Done if we've merged everything into one frequency */ + if (c2 < 0) + break; + + /* Else merge the two counts/trees */ + freq[c1] += freq[c2]; + freq[c2] = 0; + + /* Increment the codesize of everything in c1's tree branch */ + codesize[c1]++; + while (others[c1] >= 0) { + c1 = others[c1]; + codesize[c1]++; + } + + others[c1] = c2; /* chain c2 onto c1's tree branch */ + + /* Increment the codesize of everything in c2's tree branch */ + codesize[c2]++; + while (others[c2] >= 0) { + c2 = others[c2]; + codesize[c2]++; + } + } + + /* Now count the number of symbols of each code length */ + for (i = 0; i <= 256; i++) { + if (codesize[i]) { + /* The JPEG standard seems to think that this can't happen, */ + /* but I'm paranoid... */ + if (codesize[i] > MAX_CLEN) + ERREXIT(cinfo, JERR_HUFF_CLEN_OVERFLOW); + + bits[codesize[i]]++; + } + } + + /* JPEG doesn't allow symbols with code lengths over 16 bits, so if the pure + * Huffman procedure assigned any such lengths, we must adjust the coding. + * Here is what the JPEG spec says about how this next bit works: + * Since symbols are paired for the longest Huffman code, the symbols are + * removed from this length category two at a time. The prefix for the pair + * (which is one bit shorter) is allocated to one of the pair; then, + * skipping the BITS entry for that prefix length, a code word from the next + * shortest nonzero BITS entry is converted into a prefix for two code words + * one bit longer. + */ + + for (i = MAX_CLEN; i > 16; i--) { + while (bits[i] > 0) { + j = i - 2; /* find length of new prefix to be used */ + while (bits[j] == 0) + j--; + + bits[i] -= 2; /* remove two symbols */ + bits[i-1]++; /* one goes in this length */ + bits[j+1] += 2; /* two new symbols in this length */ + bits[j]--; /* symbol of this length is now a prefix */ + } + } + + /* Remove the count for the pseudo-symbol 256 from the largest codelength */ + while (bits[i] == 0) /* find largest codelength still in use */ + i--; + bits[i]--; + + /* Return final symbol counts (only for lengths 0..16) */ + MEMCOPY(htbl->bits, bits, SIZEOF(htbl->bits)); + + /* Return a list of the symbols sorted by code length */ + /* It's not real clear to me why we don't need to consider the codelength + * changes made above, but the JPEG spec seems to think this works. + */ + p = 0; + for (i = 1; i <= MAX_CLEN; i++) { + for (j = 0; j <= 255; j++) { + if (codesize[j] == i) { + htbl->huffval[p] = (unsigned char) j; + p++; + } + } + } + + /* Set sent_table FALSE so updated table will be written to JPEG file. */ + htbl->sent_table = FALSE; +} + + +/* + * Finish up a statistics-gathering pass and create the new Huffman tables. + */ + +METHODDEF(void) +finish_pass_gather (j_compress_ptr cinfo) +{ + huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy; + int ci, dctbl, actbl; + jpeg_component_info * compptr; + JHUFF_TBL **htblptr; + int did_dc[NUM_HUFF_TBLS]; + int did_ac[NUM_HUFF_TBLS]; + + /* It's important not to apply jpeg_gen_optimal_table more than once + * per table, because it clobbers the input frequency counts! + */ + MEMZERO(did_dc, SIZEOF(did_dc)); + MEMZERO(did_ac, SIZEOF(did_ac)); + + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + dctbl = compptr->dc_tbl_no; + actbl = compptr->ac_tbl_no; + if (! did_dc[dctbl]) { + htblptr = & cinfo->dc_huff_tbl_ptrs[dctbl]; + if (*htblptr == NULL) + *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo); + jpeg_gen_optimal_table(cinfo, *htblptr, entropy->dc_count_ptrs[dctbl]); + did_dc[dctbl] = TRUE; + } + if (! did_ac[actbl]) { + htblptr = & cinfo->ac_huff_tbl_ptrs[actbl]; + if (*htblptr == NULL) + *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo); + jpeg_gen_optimal_table(cinfo, *htblptr, entropy->ac_count_ptrs[actbl]); + did_ac[actbl] = TRUE; + } + } +} + + +#endif /* ENTROPY_OPT_SUPPORTED */ + + +/* + * Module initialization routine for Huffman entropy encoding. + */ + +GLOBAL(void) +jinit_huff_encoder (j_compress_ptr cinfo) +{ + huff_entropy_ptr entropy; + int i; + + entropy = (huff_entropy_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(huff_entropy_encoder)); + cinfo->entropy = (struct jpeg_entropy_encoder *) entropy; + entropy->pub.start_pass = start_pass_huff; + + /* Mark tables unallocated */ + for (i = 0; i < NUM_HUFF_TBLS; i++) { + entropy->dc_derived_tbls[i] = entropy->ac_derived_tbls[i] = NULL; +#ifdef ENTROPY_OPT_SUPPORTED + entropy->dc_count_ptrs[i] = entropy->ac_count_ptrs[i] = NULL; +#endif + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jchuff.h b/ml/dlib/dlib/external/libjpeg/jchuff.h new file mode 100644 index 000000000..2d184ec55 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jchuff.h @@ -0,0 +1,47 @@ +/* + * jchuff.h + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains declarations for Huffman entropy encoding routines + * that are shared between the sequential encoder (jchuff.c) and the + * progressive encoder (jcphuff.c). No other modules need to see these. + */ + +/* The legal range of a DCT coefficient is + * -1024 .. +1023 for 8-bit data; + * -16384 .. +16383 for 12-bit data. + * Hence the magnitude should always fit in 10 or 14 bits respectively. + */ + +#if BITS_IN_JSAMPLE == 8 +#define MAX_COEF_BITS 10 +#else +#define MAX_COEF_BITS 14 +#endif + +/* Derived data constructed for each Huffman table */ + +typedef struct { + unsigned int ehufco[256]; /* code for each symbol */ + char ehufsi[256]; /* length of code for each symbol */ + /* If no code has been allocated for a symbol S, ehufsi[S] contains 0 */ +} c_derived_tbl; + +/* Short forms of external names for systems with brain-damaged linkers. */ + +#ifdef NEED_SHORT_EXTERNAL_NAMES +#define jpeg_make_c_derived_tbl jMkCDerived +#define jpeg_gen_optimal_table jGenOptTbl +#endif /* NEED_SHORT_EXTERNAL_NAMES */ + +/* Expand a Huffman table definition into the derived format */ +EXTERN(void) jpeg_make_c_derived_tbl + JPP((j_compress_ptr cinfo, int isDC, int tblno, + c_derived_tbl ** pdtbl)); + +/* Generate an optimal table definition given the specified counts */ +EXTERN(void) jpeg_gen_optimal_table + JPP((j_compress_ptr cinfo, JHUFF_TBL * htbl, long freq[])); diff --git a/ml/dlib/dlib/external/libjpeg/jcinit.cpp b/ml/dlib/dlib/external/libjpeg/jcinit.cpp new file mode 100644 index 000000000..2ca809cfd --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcinit.cpp @@ -0,0 +1,72 @@ +/* + * jcinit.c + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains initialization logic for the JPEG compressor. + * This routine is in charge of selecting the modules to be executed and + * making an initialization call to each one. + * + * Logically, this code belongs in jcmaster.c. It's split out because + * linking this routine implies linking the entire compression library. + * For a transcoding-only application, we want to be able to use jcmaster.c + * without linking in the whole library. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* + * Master selection of compression modules. + * This is done once at the start of processing an image. We determine + * which modules will be used and give them appropriate initialization calls. + */ + +GLOBAL(void) +jinit_compress_master (j_compress_ptr cinfo) +{ + /* Initialize master control (includes parameter checking/processing) */ + jinit_c_master_control(cinfo, FALSE /* full compression */); + + /* Preprocessing */ + if (! cinfo->raw_data_in) { + jinit_color_converter(cinfo); + jinit_downsampler(cinfo); + jinit_c_prep_controller(cinfo, FALSE /* never need full buffer here */); + } + /* Forward DCT */ + jinit_forward_dct(cinfo); + /* Entropy encoding: either Huffman or arithmetic coding. */ + if (cinfo->arith_code) { + ERREXIT(cinfo, JERR_ARITH_NOTIMPL); + } else { + if (cinfo->progressive_mode) { +#ifdef C_PROGRESSIVE_SUPPORTED + jinit_phuff_encoder(cinfo); +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } else + jinit_huff_encoder(cinfo); + } + + /* Need a full-image coefficient buffer in any multi-pass mode. */ + jinit_c_coef_controller(cinfo, + (int) (cinfo->num_scans > 1 || cinfo->optimize_coding)); + jinit_c_main_controller(cinfo, FALSE /* never need full buffer here */); + + jinit_marker_writer(cinfo); + + /* We can now tell the memory manager to allocate virtual arrays. */ + (*cinfo->mem->realize_virt_arrays) ((j_common_ptr) cinfo); + + /* Write the datastream header (SOI) immediately. + * Frame and scan headers are postponed till later. + * This lets application insert special markers after the SOI. + */ + (*cinfo->marker->write_file_header) (cinfo); +} diff --git a/ml/dlib/dlib/external/libjpeg/jcmainct.cpp b/ml/dlib/dlib/external/libjpeg/jcmainct.cpp new file mode 100644 index 000000000..1e5f97a20 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcmainct.cpp @@ -0,0 +1,293 @@ +/* + * jcmainct.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains the main buffer controller for compression. + * The main buffer lies between the pre-processor and the JPEG + * compressor proper; it holds downsampled data in the JPEG colorspace. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Note: currently, there is no operating mode in which a full-image buffer + * is needed at this step. If there were, that mode could not be used with + * "raw data" input, since this module is bypassed in that case. However, + * we've left the code here for possible use in special applications. + */ +#undef FULL_MAIN_BUFFER_SUPPORTED + + +/* Private buffer controller object */ + +typedef struct { + struct jpeg_c_main_controller pub; /* public fields */ + + JDIMENSION cur_iMCU_row; /* number of current iMCU row */ + JDIMENSION rowgroup_ctr; /* counts row groups received in iMCU row */ + int suspended; /* remember if we suspended output */ + J_BUF_MODE pass_mode; /* current operating mode */ + + /* If using just a strip buffer, this points to the entire set of buffers + * (we allocate one for each component). In the full-image case, this + * points to the currently accessible strips of the virtual arrays. + */ + JSAMPARRAY buffer[MAX_COMPONENTS]; + +#ifdef FULL_MAIN_BUFFER_SUPPORTED + /* If using full-image storage, this array holds pointers to virtual-array + * control blocks for each component. Unused if not full-image storage. + */ + jvirt_sarray_ptr whole_image[MAX_COMPONENTS]; +#endif +} my_main_controller; + +typedef my_main_controller * my_main_ptr; + + +/* Forward declarations */ +METHODDEF(void) process_data_simple_main + JPP((j_compress_ptr cinfo, JSAMPARRAY input_buf, + JDIMENSION *in_row_ctr, JDIMENSION in_rows_avail)); +#ifdef FULL_MAIN_BUFFER_SUPPORTED +METHODDEF(void) process_data_buffer_main + JPP((j_compress_ptr cinfo, JSAMPARRAY input_buf, + JDIMENSION *in_row_ctr, JDIMENSION in_rows_avail)); +#endif + + +/* + * Initialize for a processing pass. + */ + +METHODDEF(void) +start_pass_main (j_compress_ptr cinfo, J_BUF_MODE pass_mode) +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + + /* Do nothing in raw-data mode. */ + if (cinfo->raw_data_in) + return; + + main->cur_iMCU_row = 0; /* initialize counters */ + main->rowgroup_ctr = 0; + main->suspended = FALSE; + main->pass_mode = pass_mode; /* save mode for use by process_data */ + + switch (pass_mode) { + case JBUF_PASS_THRU: +#ifdef FULL_MAIN_BUFFER_SUPPORTED + if (main->whole_image[0] != NULL) + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); +#endif + main->pub.process_data = process_data_simple_main; + break; +#ifdef FULL_MAIN_BUFFER_SUPPORTED + case JBUF_SAVE_SOURCE: + case JBUF_CRANK_DEST: + case JBUF_SAVE_AND_PASS: + if (main->whole_image[0] == NULL) + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + main->pub.process_data = process_data_buffer_main; + break; +#endif + default: + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + break; + } +} + + +/* + * Process some data. + * This routine handles the simple pass-through mode, + * where we have only a strip buffer. + */ + +METHODDEF(void) +process_data_simple_main (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JDIMENSION *in_row_ctr, + JDIMENSION in_rows_avail) +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + + while (main->cur_iMCU_row < cinfo->total_iMCU_rows) { + /* Read input data if we haven't filled the main buffer yet */ + if (main->rowgroup_ctr < DCTSIZE) + (*cinfo->prep->pre_process_data) (cinfo, + input_buf, in_row_ctr, in_rows_avail, + main->buffer, &main->rowgroup_ctr, + (JDIMENSION) DCTSIZE); + + /* If we don't have a full iMCU row buffered, return to application for + * more data. Note that preprocessor will always pad to fill the iMCU row + * at the bottom of the image. + */ + if (main->rowgroup_ctr != DCTSIZE) + return; + + /* Send the completed row to the compressor */ + if (! (*cinfo->coef->compress_data) (cinfo, main->buffer)) { + /* If compressor did not consume the whole row, then we must need to + * suspend processing and return to the application. In this situation + * we pretend we didn't yet consume the last input row; otherwise, if + * it happened to be the last row of the image, the application would + * think we were done. + */ + if (! main->suspended) { + (*in_row_ctr)--; + main->suspended = TRUE; + } + return; + } + /* We did finish the row. Undo our little suspension hack if a previous + * call suspended; then mark the main buffer empty. + */ + if (main->suspended) { + (*in_row_ctr)++; + main->suspended = FALSE; + } + main->rowgroup_ctr = 0; + main->cur_iMCU_row++; + } +} + + +#ifdef FULL_MAIN_BUFFER_SUPPORTED + +/* + * Process some data. + * This routine handles all of the modes that use a full-size buffer. + */ + +METHODDEF(void) +process_data_buffer_main (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JDIMENSION *in_row_ctr, + JDIMENSION in_rows_avail) +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + int ci; + jpeg_component_info *compptr; + int writing = (main->pass_mode != JBUF_CRANK_DEST); + + while (main->cur_iMCU_row < cinfo->total_iMCU_rows) { + /* Realign the virtual buffers if at the start of an iMCU row. */ + if (main->rowgroup_ctr == 0) { + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + main->buffer[ci] = (*cinfo->mem->access_virt_sarray) + ((j_common_ptr) cinfo, main->whole_image[ci], + main->cur_iMCU_row * (compptr->v_samp_factor * DCTSIZE), + (JDIMENSION) (compptr->v_samp_factor * DCTSIZE), writing); + } + /* In a read pass, pretend we just read some source data. */ + if (! writing) { + *in_row_ctr += cinfo->max_v_samp_factor * DCTSIZE; + main->rowgroup_ctr = DCTSIZE; + } + } + + /* If a write pass, read input data until the current iMCU row is full. */ + /* Note: preprocessor will pad if necessary to fill the last iMCU row. */ + if (writing) { + (*cinfo->prep->pre_process_data) (cinfo, + input_buf, in_row_ctr, in_rows_avail, + main->buffer, &main->rowgroup_ctr, + (JDIMENSION) DCTSIZE); + /* Return to application if we need more data to fill the iMCU row. */ + if (main->rowgroup_ctr < DCTSIZE) + return; + } + + /* Emit data, unless this is a sink-only pass. */ + if (main->pass_mode != JBUF_SAVE_SOURCE) { + if (! (*cinfo->coef->compress_data) (cinfo, main->buffer)) { + /* If compressor did not consume the whole row, then we must need to + * suspend processing and return to the application. In this situation + * we pretend we didn't yet consume the last input row; otherwise, if + * it happened to be the last row of the image, the application would + * think we were done. + */ + if (! main->suspended) { + (*in_row_ctr)--; + main->suspended = TRUE; + } + return; + } + /* We did finish the row. Undo our little suspension hack if a previous + * call suspended; then mark the main buffer empty. + */ + if (main->suspended) { + (*in_row_ctr)++; + main->suspended = FALSE; + } + } + + /* If get here, we are done with this iMCU row. Mark buffer empty. */ + main->rowgroup_ctr = 0; + main->cur_iMCU_row++; + } +} + +#endif /* FULL_MAIN_BUFFER_SUPPORTED */ + + +/* + * Initialize main buffer controller. + */ + +GLOBAL(void) +jinit_c_main_controller (j_compress_ptr cinfo, int need_full_buffer) +{ + my_main_ptr main; + int ci; + jpeg_component_info *compptr; + + main = (my_main_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_main_controller)); + cinfo->main = (struct jpeg_c_main_controller *) main; + main->pub.start_pass = start_pass_main; + + /* We don't need to create a buffer in raw-data mode. */ + if (cinfo->raw_data_in) + return; + + /* Create the buffer. It holds downsampled data, so each component + * may be of a different size. + */ + if (need_full_buffer) { +#ifdef FULL_MAIN_BUFFER_SUPPORTED + /* Allocate a full-image virtual array for each component */ + /* Note we pad the bottom to a multiple of the iMCU height */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + main->whole_image[ci] = (*cinfo->mem->request_virt_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, FALSE, + compptr->width_in_blocks * DCTSIZE, + (JDIMENSION) jround_up((long) compptr->height_in_blocks, + (long) compptr->v_samp_factor) * DCTSIZE, + (JDIMENSION) (compptr->v_samp_factor * DCTSIZE)); + } +#else + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); +#endif + } else { +#ifdef FULL_MAIN_BUFFER_SUPPORTED + main->whole_image[0] = NULL; /* flag for no virtual arrays */ +#endif + /* Allocate a strip buffer for each component */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + main->buffer[ci] = (*cinfo->mem->alloc_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + compptr->width_in_blocks * DCTSIZE, + (JDIMENSION) (compptr->v_samp_factor * DCTSIZE)); + } + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jcmarker.cpp b/ml/dlib/dlib/external/libjpeg/jcmarker.cpp new file mode 100644 index 000000000..1bfd15c55 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcmarker.cpp @@ -0,0 +1,664 @@ +/* + * jcmarker.c + * + * Copyright (C) 1991-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains routines to write JPEG datastream markers. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +typedef enum { /* JPEG marker codes */ + M_SOF0 = 0xc0, + M_SOF1 = 0xc1, + M_SOF2 = 0xc2, + M_SOF3 = 0xc3, + + M_SOF5 = 0xc5, + M_SOF6 = 0xc6, + M_SOF7 = 0xc7, + + M_JPG = 0xc8, + M_SOF9 = 0xc9, + M_SOF10 = 0xca, + M_SOF11 = 0xcb, + + M_SOF13 = 0xcd, + M_SOF14 = 0xce, + M_SOF15 = 0xcf, + + M_DHT = 0xc4, + + M_DAC = 0xcc, + + M_RST0 = 0xd0, + M_RST1 = 0xd1, + M_RST2 = 0xd2, + M_RST3 = 0xd3, + M_RST4 = 0xd4, + M_RST5 = 0xd5, + M_RST6 = 0xd6, + M_RST7 = 0xd7, + + M_SOI = 0xd8, + M_EOI = 0xd9, + M_SOS = 0xda, + M_DQT = 0xdb, + M_DNL = 0xdc, + M_DRI = 0xdd, + M_DHP = 0xde, + M_EXP = 0xdf, + + M_APP0 = 0xe0, + M_APP1 = 0xe1, + M_APP2 = 0xe2, + M_APP3 = 0xe3, + M_APP4 = 0xe4, + M_APP5 = 0xe5, + M_APP6 = 0xe6, + M_APP7 = 0xe7, + M_APP8 = 0xe8, + M_APP9 = 0xe9, + M_APP10 = 0xea, + M_APP11 = 0xeb, + M_APP12 = 0xec, + M_APP13 = 0xed, + M_APP14 = 0xee, + M_APP15 = 0xef, + + M_JPG0 = 0xf0, + M_JPG13 = 0xfd, + M_COM = 0xfe, + + M_TEM = 0x01, + + M_ERROR = 0x100 +} JPEG_MARKER; + + +/* Private state */ + +typedef struct { + struct jpeg_marker_writer pub; /* public fields */ + + unsigned int last_restart_interval; /* last DRI value emitted; 0 after SOI */ +} my_marker_writer; + +typedef my_marker_writer * my_marker_ptr; + + +/* + * Basic output routines. + * + * Note that we do not support suspension while writing a marker. + * Therefore, an application using suspension must ensure that there is + * enough buffer space for the initial markers (typ. 600-700 bytes) before + * calling jpeg_start_compress, and enough space to write the trailing EOI + * (a few bytes) before calling jpeg_finish_compress. Multipass compression + * modes are not supported at all with suspension, so those two are the only + * points where markers will be written. + */ + +LOCAL(void) +emit_byte (j_compress_ptr cinfo, int val) +/* Emit a byte */ +{ + struct jpeg_destination_mgr * dest = cinfo->dest; + + *(dest->next_output_byte)++ = (JOCTET) val; + if (--dest->free_in_buffer == 0) { + if (! (*dest->empty_output_buffer) (cinfo)) + ERREXIT(cinfo, JERR_CANT_SUSPEND); + } +} + + +LOCAL(void) +emit_marker (j_compress_ptr cinfo, JPEG_MARKER mark) +/* Emit a marker code */ +{ + emit_byte(cinfo, 0xFF); + emit_byte(cinfo, (int) mark); +} + + +LOCAL(void) +emit_2bytes (j_compress_ptr cinfo, int value) +/* Emit a 2-byte integer; these are always MSB first in JPEG files */ +{ + emit_byte(cinfo, (value >> 8) & 0xFF); + emit_byte(cinfo, value & 0xFF); +} + + +/* + * Routines to write specific marker types. + */ + +LOCAL(int) +emit_dqt (j_compress_ptr cinfo, int index) +/* Emit a DQT marker */ +/* Returns the precision used (0 = 8bits, 1 = 16bits) for baseline checking */ +{ + JQUANT_TBL * qtbl = cinfo->quant_tbl_ptrs[index]; + int prec; + int i; + + if (qtbl == NULL) + ERREXIT1(cinfo, JERR_NO_QUANT_TABLE, index); + + prec = 0; + for (i = 0; i < DCTSIZE2; i++) { + if (qtbl->quantval[i] > 255) + prec = 1; + } + + if (! qtbl->sent_table) { + emit_marker(cinfo, M_DQT); + + emit_2bytes(cinfo, prec ? DCTSIZE2*2 + 1 + 2 : DCTSIZE2 + 1 + 2); + + emit_byte(cinfo, index + (prec<<4)); + + for (i = 0; i < DCTSIZE2; i++) { + /* The table entries must be emitted in zigzag order. */ + unsigned int qval = qtbl->quantval[jpeg_natural_order[i]]; + if (prec) + emit_byte(cinfo, (int) (qval >> 8)); + emit_byte(cinfo, (int) (qval & 0xFF)); + } + + qtbl->sent_table = TRUE; + } + + return prec; +} + + +LOCAL(void) +emit_dht (j_compress_ptr cinfo, int index, int is_ac) +/* Emit a DHT marker */ +{ + JHUFF_TBL * htbl; + int length, i; + + if (is_ac) { + htbl = cinfo->ac_huff_tbl_ptrs[index]; + index += 0x10; /* output index has AC bit set */ + } else { + htbl = cinfo->dc_huff_tbl_ptrs[index]; + } + + if (htbl == NULL) + ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, index); + + if (! htbl->sent_table) { + emit_marker(cinfo, M_DHT); + + length = 0; + for (i = 1; i <= 16; i++) + length += htbl->bits[i]; + + emit_2bytes(cinfo, length + 2 + 1 + 16); + emit_byte(cinfo, index); + + for (i = 1; i <= 16; i++) + emit_byte(cinfo, htbl->bits[i]); + + for (i = 0; i < length; i++) + emit_byte(cinfo, htbl->huffval[i]); + + htbl->sent_table = TRUE; + } +} + + +LOCAL(void) +emit_dac (j_compress_ptr )//cinfo) +/* Emit a DAC marker */ +/* Since the useful info is so small, we want to emit all the tables in */ +/* one DAC marker. Therefore this routine does its own scan of the table. */ +{ +#ifdef C_ARITH_CODING_SUPPORTED + char dc_in_use[NUM_ARITH_TBLS]; + char ac_in_use[NUM_ARITH_TBLS]; + int length, i; + jpeg_component_info *compptr; + + for (i = 0; i < NUM_ARITH_TBLS; i++) + dc_in_use[i] = ac_in_use[i] = 0; + + for (i = 0; i < cinfo->comps_in_scan; i++) { + compptr = cinfo->cur_comp_info[i]; + dc_in_use[compptr->dc_tbl_no] = 1; + ac_in_use[compptr->ac_tbl_no] = 1; + } + + length = 0; + for (i = 0; i < NUM_ARITH_TBLS; i++) + length += dc_in_use[i] + ac_in_use[i]; + + emit_marker(cinfo, M_DAC); + + emit_2bytes(cinfo, length*2 + 2); + + for (i = 0; i < NUM_ARITH_TBLS; i++) { + if (dc_in_use[i]) { + emit_byte(cinfo, i); + emit_byte(cinfo, cinfo->arith_dc_L[i] + (cinfo->arith_dc_U[i]<<4)); + } + if (ac_in_use[i]) { + emit_byte(cinfo, i + 0x10); + emit_byte(cinfo, cinfo->arith_ac_K[i]); + } + } +#endif /* C_ARITH_CODING_SUPPORTED */ +} + + +LOCAL(void) +emit_dri (j_compress_ptr cinfo) +/* Emit a DRI marker */ +{ + emit_marker(cinfo, M_DRI); + + emit_2bytes(cinfo, 4); /* fixed length */ + + emit_2bytes(cinfo, (int) cinfo->restart_interval); +} + + +LOCAL(void) +emit_sof (j_compress_ptr cinfo, JPEG_MARKER code) +/* Emit a SOF marker */ +{ + int ci; + jpeg_component_info *compptr; + + emit_marker(cinfo, code); + + emit_2bytes(cinfo, 3 * cinfo->num_components + 2 + 5 + 1); /* length */ + + /* Make sure image isn't bigger than SOF field can handle */ + if ((long) cinfo->image_height > 65535L || + (long) cinfo->image_width > 65535L) + ERREXIT1(cinfo, JERR_IMAGE_TOO_BIG, (unsigned int) 65535); + + emit_byte(cinfo, cinfo->data_precision); + emit_2bytes(cinfo, (int) cinfo->image_height); + emit_2bytes(cinfo, (int) cinfo->image_width); + + emit_byte(cinfo, cinfo->num_components); + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + emit_byte(cinfo, compptr->component_id); + emit_byte(cinfo, (compptr->h_samp_factor << 4) + compptr->v_samp_factor); + emit_byte(cinfo, compptr->quant_tbl_no); + } +} + + +LOCAL(void) +emit_sos (j_compress_ptr cinfo) +/* Emit a SOS marker */ +{ + int i, td, ta; + jpeg_component_info *compptr; + + emit_marker(cinfo, M_SOS); + + emit_2bytes(cinfo, 2 * cinfo->comps_in_scan + 2 + 1 + 3); /* length */ + + emit_byte(cinfo, cinfo->comps_in_scan); + + for (i = 0; i < cinfo->comps_in_scan; i++) { + compptr = cinfo->cur_comp_info[i]; + emit_byte(cinfo, compptr->component_id); + td = compptr->dc_tbl_no; + ta = compptr->ac_tbl_no; + if (cinfo->progressive_mode) { + /* Progressive mode: only DC or only AC tables are used in one scan; + * furthermore, Huffman coding of DC refinement uses no table at all. + * We emit 0 for unused field(s); this is recommended by the P&M text + * but does not seem to be specified in the standard. + */ + if (cinfo->Ss == 0) { + ta = 0; /* DC scan */ + if (cinfo->Ah != 0 && !cinfo->arith_code) + td = 0; /* no DC table either */ + } else { + td = 0; /* AC scan */ + } + } + emit_byte(cinfo, (td << 4) + ta); + } + + emit_byte(cinfo, cinfo->Ss); + emit_byte(cinfo, cinfo->Se); + emit_byte(cinfo, (cinfo->Ah << 4) + cinfo->Al); +} + + +LOCAL(void) +emit_jfif_app0 (j_compress_ptr cinfo) +/* Emit a JFIF-compliant APP0 marker */ +{ + /* + * Length of APP0 block (2 bytes) + * Block ID (4 bytes - ASCII "JFIF") + * Zero byte (1 byte to terminate the ID string) + * Version Major, Minor (2 bytes - major first) + * Units (1 byte - 0x00 = none, 0x01 = inch, 0x02 = cm) + * Xdpu (2 bytes - dots per unit horizontal) + * Ydpu (2 bytes - dots per unit vertical) + * Thumbnail X size (1 byte) + * Thumbnail Y size (1 byte) + */ + + emit_marker(cinfo, M_APP0); + + emit_2bytes(cinfo, 2 + 4 + 1 + 2 + 1 + 2 + 2 + 1 + 1); /* length */ + + emit_byte(cinfo, 0x4A); /* Identifier: ASCII "JFIF" */ + emit_byte(cinfo, 0x46); + emit_byte(cinfo, 0x49); + emit_byte(cinfo, 0x46); + emit_byte(cinfo, 0); + emit_byte(cinfo, cinfo->JFIF_major_version); /* Version fields */ + emit_byte(cinfo, cinfo->JFIF_minor_version); + emit_byte(cinfo, cinfo->density_unit); /* Pixel size information */ + emit_2bytes(cinfo, (int) cinfo->X_density); + emit_2bytes(cinfo, (int) cinfo->Y_density); + emit_byte(cinfo, 0); /* No thumbnail image */ + emit_byte(cinfo, 0); +} + + +LOCAL(void) +emit_adobe_app14 (j_compress_ptr cinfo) +/* Emit an Adobe APP14 marker */ +{ + /* + * Length of APP14 block (2 bytes) + * Block ID (5 bytes - ASCII "Adobe") + * Version Number (2 bytes - currently 100) + * Flags0 (2 bytes - currently 0) + * Flags1 (2 bytes - currently 0) + * Color transform (1 byte) + * + * Although Adobe TN 5116 mentions Version = 101, all the Adobe files + * now in circulation seem to use Version = 100, so that's what we write. + * + * We write the color transform byte as 1 if the JPEG color space is + * YCbCr, 2 if it's YCCK, 0 otherwise. Adobe's definition has to do with + * whether the encoder performed a transformation, which is pretty useless. + */ + + emit_marker(cinfo, M_APP14); + + emit_2bytes(cinfo, 2 + 5 + 2 + 2 + 2 + 1); /* length */ + + emit_byte(cinfo, 0x41); /* Identifier: ASCII "Adobe" */ + emit_byte(cinfo, 0x64); + emit_byte(cinfo, 0x6F); + emit_byte(cinfo, 0x62); + emit_byte(cinfo, 0x65); + emit_2bytes(cinfo, 100); /* Version */ + emit_2bytes(cinfo, 0); /* Flags0 */ + emit_2bytes(cinfo, 0); /* Flags1 */ + switch (cinfo->jpeg_color_space) { + case JCS_YCbCr: + emit_byte(cinfo, 1); /* Color transform = 1 */ + break; + case JCS_YCCK: + emit_byte(cinfo, 2); /* Color transform = 2 */ + break; + default: + emit_byte(cinfo, 0); /* Color transform = 0 */ + break; + } +} + + +/* + * These routines allow writing an arbitrary marker with parameters. + * The only intended use is to emit COM or APPn markers after calling + * write_file_header and before calling write_frame_header. + * Other uses are not guaranteed to produce desirable results. + * Counting the parameter bytes properly is the caller's responsibility. + */ + +METHODDEF(void) +write_marker_header (j_compress_ptr cinfo, int marker, unsigned int datalen) +/* Emit an arbitrary marker header */ +{ + if (datalen > (unsigned int) 65533) /* safety check */ + ERREXIT(cinfo, JERR_BAD_LENGTH); + + emit_marker(cinfo, (JPEG_MARKER) marker); + + emit_2bytes(cinfo, (int) (datalen + 2)); /* total length */ +} + +METHODDEF(void) +write_marker_byte (j_compress_ptr cinfo, int val) +/* Emit one byte of marker parameters following write_marker_header */ +{ + emit_byte(cinfo, val); +} + + +/* + * Write datastream header. + * This consists of an SOI and optional APPn markers. + * We recommend use of the JFIF marker, but not the Adobe marker, + * when using YCbCr or grayscale data. The JFIF marker should NOT + * be used for any other JPEG colorspace. The Adobe marker is helpful + * to distinguish RGB, CMYK, and YCCK colorspaces. + * Note that an application can write additional header markers after + * jpeg_start_compress returns. + */ + +METHODDEF(void) +write_file_header (j_compress_ptr cinfo) +{ + my_marker_ptr marker = (my_marker_ptr) cinfo->marker; + + emit_marker(cinfo, M_SOI); /* first the SOI */ + + /* SOI is defined to reset restart interval to 0 */ + marker->last_restart_interval = 0; + + if (cinfo->write_JFIF_header) /* next an optional JFIF APP0 */ + emit_jfif_app0(cinfo); + if (cinfo->write_Adobe_marker) /* next an optional Adobe APP14 */ + emit_adobe_app14(cinfo); +} + + +/* + * Write frame header. + * This consists of DQT and SOFn markers. + * Note that we do not emit the SOF until we have emitted the DQT(s). + * This avoids compatibility problems with incorrect implementations that + * try to error-check the quant table numbers as soon as they see the SOF. + */ + +METHODDEF(void) +write_frame_header (j_compress_ptr cinfo) +{ + int ci, prec; + int is_baseline; + jpeg_component_info *compptr; + + /* Emit DQT for each quantization table. + * Note that emit_dqt() suppresses any duplicate tables. + */ + prec = 0; + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + prec += emit_dqt(cinfo, compptr->quant_tbl_no); + } + /* now prec is nonzero iff there are any 16-bit quant tables. */ + + /* Check for a non-baseline specification. + * Note we assume that Huffman table numbers won't be changed later. + */ + if (cinfo->arith_code || cinfo->progressive_mode || + cinfo->data_precision != 8) { + is_baseline = FALSE; + } else { + is_baseline = TRUE; + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + if (compptr->dc_tbl_no > 1 || compptr->ac_tbl_no > 1) + is_baseline = FALSE; + } + if (prec && is_baseline) { + is_baseline = FALSE; + /* If it's baseline except for quantizer size, warn the user */ + TRACEMS(cinfo, 0, JTRC_16BIT_TABLES); + } + } + + /* Emit the proper SOF marker */ + if (cinfo->arith_code) { + emit_sof(cinfo, M_SOF9); /* SOF code for arithmetic coding */ + } else { + if (cinfo->progressive_mode) + emit_sof(cinfo, M_SOF2); /* SOF code for progressive Huffman */ + else if (is_baseline) + emit_sof(cinfo, M_SOF0); /* SOF code for baseline implementation */ + else + emit_sof(cinfo, M_SOF1); /* SOF code for non-baseline Huffman file */ + } +} + + +/* + * Write scan header. + * This consists of DHT or DAC markers, optional DRI, and SOS. + * Compressed data will be written following the SOS. + */ + +METHODDEF(void) +write_scan_header (j_compress_ptr cinfo) +{ + my_marker_ptr marker = (my_marker_ptr) cinfo->marker; + int i; + jpeg_component_info *compptr; + + if (cinfo->arith_code) { + /* Emit arith conditioning info. We may have some duplication + * if the file has multiple scans, but it's so small it's hardly + * worth worrying about. + */ + emit_dac(cinfo); + } else { + /* Emit Huffman tables. + * Note that emit_dht() suppresses any duplicate tables. + */ + for (i = 0; i < cinfo->comps_in_scan; i++) { + compptr = cinfo->cur_comp_info[i]; + if (cinfo->progressive_mode) { + /* Progressive mode: only DC or only AC tables are used in one scan */ + if (cinfo->Ss == 0) { + if (cinfo->Ah == 0) /* DC needs no table for refinement scan */ + emit_dht(cinfo, compptr->dc_tbl_no, FALSE); + } else { + emit_dht(cinfo, compptr->ac_tbl_no, TRUE); + } + } else { + /* Sequential mode: need both DC and AC tables */ + emit_dht(cinfo, compptr->dc_tbl_no, FALSE); + emit_dht(cinfo, compptr->ac_tbl_no, TRUE); + } + } + } + + /* Emit DRI if required --- note that DRI value could change for each scan. + * We avoid wasting space with unnecessary DRIs, however. + */ + if (cinfo->restart_interval != marker->last_restart_interval) { + emit_dri(cinfo); + marker->last_restart_interval = cinfo->restart_interval; + } + + emit_sos(cinfo); +} + + +/* + * Write datastream trailer. + */ + +METHODDEF(void) +write_file_trailer (j_compress_ptr cinfo) +{ + emit_marker(cinfo, M_EOI); +} + + +/* + * Write an abbreviated table-specification datastream. + * This consists of SOI, DQT and DHT tables, and EOI. + * Any table that is defined and not marked sent_table = TRUE will be + * emitted. Note that all tables will be marked sent_table = TRUE at exit. + */ + +METHODDEF(void) +write_tables_only (j_compress_ptr cinfo) +{ + int i; + + emit_marker(cinfo, M_SOI); + + for (i = 0; i < NUM_QUANT_TBLS; i++) { + if (cinfo->quant_tbl_ptrs[i] != NULL) + (void) emit_dqt(cinfo, i); + } + + if (! cinfo->arith_code) { + for (i = 0; i < NUM_HUFF_TBLS; i++) { + if (cinfo->dc_huff_tbl_ptrs[i] != NULL) + emit_dht(cinfo, i, FALSE); + if (cinfo->ac_huff_tbl_ptrs[i] != NULL) + emit_dht(cinfo, i, TRUE); + } + } + + emit_marker(cinfo, M_EOI); +} + + +/* + * Initialize the marker writer module. + */ + +GLOBAL(void) +jinit_marker_writer (j_compress_ptr cinfo) +{ + my_marker_ptr marker; + + /* Create the subobject */ + marker = (my_marker_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_marker_writer)); + cinfo->marker = (struct jpeg_marker_writer *) marker; + /* Initialize method pointers */ + marker->pub.write_file_header = write_file_header; + marker->pub.write_frame_header = write_frame_header; + marker->pub.write_scan_header = write_scan_header; + marker->pub.write_file_trailer = write_file_trailer; + marker->pub.write_tables_only = write_tables_only; + marker->pub.write_marker_header = write_marker_header; + marker->pub.write_marker_byte = write_marker_byte; + /* Initialize private state */ + marker->last_restart_interval = 0; +} diff --git a/ml/dlib/dlib/external/libjpeg/jcmaster.cpp b/ml/dlib/dlib/external/libjpeg/jcmaster.cpp new file mode 100644 index 000000000..3e1b1d711 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcmaster.cpp @@ -0,0 +1,590 @@ +/* + * jcmaster.c + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains master control logic for the JPEG compressor. + * These routines are concerned with parameter validation, initial setup, + * and inter-pass control (determining the number of passes and the work + * to be done in each pass). + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Private state */ + +typedef enum { + main_pass, /* input data, also do first output step */ + huff_opt_pass, /* Huffman code optimization pass */ + output_pass /* data output pass */ +} c_pass_type; + +typedef struct { + struct jpeg_comp_master pub; /* public fields */ + + c_pass_type pass_type; /* the type of the current pass */ + + int pass_number; /* # of passes completed */ + int total_passes; /* total # of passes needed */ + + int scan_number; /* current index in scan_info[] */ +} my_comp_master; + +typedef my_comp_master * my_master_ptr; + + +/* + * Support routines that do various essential calculations. + */ + +LOCAL(void) +initial_setup (j_compress_ptr cinfo) +/* Do computations that are needed before master selection phase */ +{ + int ci; + jpeg_component_info *compptr; + long samplesperrow; + JDIMENSION jd_samplesperrow; + + /* Sanity check on image dimensions */ + if (cinfo->image_height <= 0 || cinfo->image_width <= 0 + || cinfo->num_components <= 0 || cinfo->input_components <= 0) + ERREXIT(cinfo, JERR_EMPTY_IMAGE); + + /* Make sure image isn't bigger than I can handle */ + if ((long) cinfo->image_height > (long) JPEG_MAX_DIMENSION || + (long) cinfo->image_width > (long) JPEG_MAX_DIMENSION) + ERREXIT1(cinfo, JERR_IMAGE_TOO_BIG, (unsigned int) JPEG_MAX_DIMENSION); + + /* Width of an input scanline must be representable as JDIMENSION. */ + samplesperrow = (long) cinfo->image_width * (long) cinfo->input_components; + jd_samplesperrow = (JDIMENSION) samplesperrow; + if ((long) jd_samplesperrow != samplesperrow) + ERREXIT(cinfo, JERR_WIDTH_OVERFLOW); + + /* For now, precision must match compiled-in value... */ + if (cinfo->data_precision != BITS_IN_JSAMPLE) + ERREXIT1(cinfo, JERR_BAD_PRECISION, cinfo->data_precision); + + /* Check that number of components won't exceed internal array sizes */ + if (cinfo->num_components > MAX_COMPONENTS) + ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->num_components, + MAX_COMPONENTS); + + /* Compute maximum sampling factors; check factor validity */ + cinfo->max_h_samp_factor = 1; + cinfo->max_v_samp_factor = 1; + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + if (compptr->h_samp_factor<=0 || compptr->h_samp_factor>MAX_SAMP_FACTOR || + compptr->v_samp_factor<=0 || compptr->v_samp_factor>MAX_SAMP_FACTOR) + ERREXIT(cinfo, JERR_BAD_SAMPLING); + cinfo->max_h_samp_factor = MAX(cinfo->max_h_samp_factor, + compptr->h_samp_factor); + cinfo->max_v_samp_factor = MAX(cinfo->max_v_samp_factor, + compptr->v_samp_factor); + } + + /* Compute dimensions of components */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Fill in the correct component_index value; don't rely on application */ + compptr->component_index = ci; + /* For compression, we never do DCT scaling. */ + compptr->DCT_scaled_size = DCTSIZE; + /* Size in DCT blocks */ + compptr->width_in_blocks = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width * (long) compptr->h_samp_factor, + (long) (cinfo->max_h_samp_factor * DCTSIZE)); + compptr->height_in_blocks = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height * (long) compptr->v_samp_factor, + (long) (cinfo->max_v_samp_factor * DCTSIZE)); + /* Size in samples */ + compptr->downsampled_width = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width * (long) compptr->h_samp_factor, + (long) cinfo->max_h_samp_factor); + compptr->downsampled_height = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height * (long) compptr->v_samp_factor, + (long) cinfo->max_v_samp_factor); + /* Mark component needed (this flag isn't actually used for compression) */ + compptr->component_needed = TRUE; + } + + /* Compute number of fully interleaved MCU rows (number of times that + * main controller will call coefficient controller). + */ + cinfo->total_iMCU_rows = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height, + (long) (cinfo->max_v_samp_factor*DCTSIZE)); +} + + +#ifdef C_MULTISCAN_FILES_SUPPORTED + +LOCAL(void) +validate_script (j_compress_ptr cinfo) +/* Verify that the scan script in cinfo->scan_info[] is valid; also + * determine whether it uses progressive JPEG, and set cinfo->progressive_mode. + */ +{ + const jpeg_scan_info * scanptr; + int scanno, ncomps, ci, coefi, thisi; + int Ss, Se, Ah, Al; + int component_sent[MAX_COMPONENTS]; +#ifdef C_PROGRESSIVE_SUPPORTED + int * last_bitpos_ptr; + int last_bitpos[MAX_COMPONENTS][DCTSIZE2]; + /* -1 until that coefficient has been seen; then last Al for it */ +#endif + + if (cinfo->num_scans <= 0) + ERREXIT1(cinfo, JERR_BAD_SCAN_SCRIPT, 0); + + /* For sequential JPEG, all scans must have Ss=0, Se=DCTSIZE2-1; + * for progressive JPEG, no scan can have this. + */ + scanptr = cinfo->scan_info; + if (scanptr->Ss != 0 || scanptr->Se != DCTSIZE2-1) { +#ifdef C_PROGRESSIVE_SUPPORTED + cinfo->progressive_mode = TRUE; + last_bitpos_ptr = & last_bitpos[0][0]; + for (ci = 0; ci < cinfo->num_components; ci++) + for (coefi = 0; coefi < DCTSIZE2; coefi++) + *last_bitpos_ptr++ = -1; +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } else { + cinfo->progressive_mode = FALSE; + for (ci = 0; ci < cinfo->num_components; ci++) + component_sent[ci] = FALSE; + } + + for (scanno = 1; scanno <= cinfo->num_scans; scanptr++, scanno++) { + /* Validate component indexes */ + ncomps = scanptr->comps_in_scan; + if (ncomps <= 0 || ncomps > MAX_COMPS_IN_SCAN) + ERREXIT2(cinfo, JERR_COMPONENT_COUNT, ncomps, MAX_COMPS_IN_SCAN); + for (ci = 0; ci < ncomps; ci++) { + thisi = scanptr->component_index[ci]; + if (thisi < 0 || thisi >= cinfo->num_components) + ERREXIT1(cinfo, JERR_BAD_SCAN_SCRIPT, scanno); + /* Components must appear in SOF order within each scan */ + if (ci > 0 && thisi <= scanptr->component_index[ci-1]) + ERREXIT1(cinfo, JERR_BAD_SCAN_SCRIPT, scanno); + } + /* Validate progression parameters */ + Ss = scanptr->Ss; + Se = scanptr->Se; + Ah = scanptr->Ah; + Al = scanptr->Al; + if (cinfo->progressive_mode) { +#ifdef C_PROGRESSIVE_SUPPORTED + /* The JPEG spec simply gives the ranges 0..13 for Ah and Al, but that + * seems wrong: the upper bound ought to depend on data precision. + * Perhaps they really meant 0..N+1 for N-bit precision. + * Here we allow 0..10 for 8-bit data; Al larger than 10 results in + * out-of-range reconstructed DC values during the first DC scan, + * which might cause problems for some decoders. + */ +#if BITS_IN_JSAMPLE == 8 +#define MAX_AH_AL 10 +#else +#define MAX_AH_AL 13 +#endif + if (Ss < 0 || Ss >= DCTSIZE2 || Se < Ss || Se >= DCTSIZE2 || + Ah < 0 || Ah > MAX_AH_AL || Al < 0 || Al > MAX_AH_AL) + ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno); + if (Ss == 0) { + if (Se != 0) /* DC and AC together not OK */ + ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno); + } else { + if (ncomps != 1) /* AC scans must be for only one component */ + ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno); + } + for (ci = 0; ci < ncomps; ci++) { + last_bitpos_ptr = & last_bitpos[scanptr->component_index[ci]][0]; + if (Ss != 0 && last_bitpos_ptr[0] < 0) /* AC without prior DC scan */ + ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno); + for (coefi = Ss; coefi <= Se; coefi++) { + if (last_bitpos_ptr[coefi] < 0) { + /* first scan of this coefficient */ + if (Ah != 0) + ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno); + } else { + /* not first scan */ + if (Ah != last_bitpos_ptr[coefi] || Al != Ah-1) + ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno); + } + last_bitpos_ptr[coefi] = Al; + } + } +#endif + } else { + /* For sequential JPEG, all progression parameters must be these: */ + if (Ss != 0 || Se != DCTSIZE2-1 || Ah != 0 || Al != 0) + ERREXIT1(cinfo, JERR_BAD_PROG_SCRIPT, scanno); + /* Make sure components are not sent twice */ + for (ci = 0; ci < ncomps; ci++) { + thisi = scanptr->component_index[ci]; + if (component_sent[thisi]) + ERREXIT1(cinfo, JERR_BAD_SCAN_SCRIPT, scanno); + component_sent[thisi] = TRUE; + } + } + } + + /* Now verify that everything got sent. */ + if (cinfo->progressive_mode) { +#ifdef C_PROGRESSIVE_SUPPORTED + /* For progressive mode, we only check that at least some DC data + * got sent for each component; the spec does not require that all bits + * of all coefficients be transmitted. Would it be wiser to enforce + * transmission of all coefficient bits?? + */ + for (ci = 0; ci < cinfo->num_components; ci++) { + if (last_bitpos[ci][0] < 0) + ERREXIT(cinfo, JERR_MISSING_DATA); + } +#endif + } else { + for (ci = 0; ci < cinfo->num_components; ci++) { + if (! component_sent[ci]) + ERREXIT(cinfo, JERR_MISSING_DATA); + } + } +} + +#endif /* C_MULTISCAN_FILES_SUPPORTED */ + + +LOCAL(void) +select_scan_parameters (j_compress_ptr cinfo) +/* Set up the scan parameters for the current scan */ +{ + int ci; + +#ifdef C_MULTISCAN_FILES_SUPPORTED + if (cinfo->scan_info != NULL) { + /* Prepare for current scan --- the script is already validated */ + my_master_ptr master = (my_master_ptr) cinfo->master; + const jpeg_scan_info * scanptr = cinfo->scan_info + master->scan_number; + + cinfo->comps_in_scan = scanptr->comps_in_scan; + for (ci = 0; ci < scanptr->comps_in_scan; ci++) { + cinfo->cur_comp_info[ci] = + &cinfo->comp_info[scanptr->component_index[ci]]; + } + cinfo->Ss = scanptr->Ss; + cinfo->Se = scanptr->Se; + cinfo->Ah = scanptr->Ah; + cinfo->Al = scanptr->Al; + } + else +#endif + { + /* Prepare for single sequential-JPEG scan containing all components */ + if (cinfo->num_components > MAX_COMPS_IN_SCAN) + ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->num_components, + MAX_COMPS_IN_SCAN); + cinfo->comps_in_scan = cinfo->num_components; + for (ci = 0; ci < cinfo->num_components; ci++) { + cinfo->cur_comp_info[ci] = &cinfo->comp_info[ci]; + } + cinfo->Ss = 0; + cinfo->Se = DCTSIZE2-1; + cinfo->Ah = 0; + cinfo->Al = 0; + } +} + + +LOCAL(void) +per_scan_setup (j_compress_ptr cinfo) +/* Do computations that are needed before processing a JPEG scan */ +/* cinfo->comps_in_scan and cinfo->cur_comp_info[] are already set */ +{ + int ci, mcublks, tmp; + jpeg_component_info *compptr; + + if (cinfo->comps_in_scan == 1) { + + /* Noninterleaved (single-component) scan */ + compptr = cinfo->cur_comp_info[0]; + + /* Overall image size in MCUs */ + cinfo->MCUs_per_row = compptr->width_in_blocks; + cinfo->MCU_rows_in_scan = compptr->height_in_blocks; + + /* For noninterleaved scan, always one block per MCU */ + compptr->MCU_width = 1; + compptr->MCU_height = 1; + compptr->MCU_blocks = 1; + compptr->MCU_sample_width = DCTSIZE; + compptr->last_col_width = 1; + /* For noninterleaved scans, it is convenient to define last_row_height + * as the number of block rows present in the last iMCU row. + */ + tmp = (int) (compptr->height_in_blocks % compptr->v_samp_factor); + if (tmp == 0) tmp = compptr->v_samp_factor; + compptr->last_row_height = tmp; + + /* Prepare array describing MCU composition */ + cinfo->blocks_in_MCU = 1; + cinfo->MCU_membership[0] = 0; + + } else { + + /* Interleaved (multi-component) scan */ + if (cinfo->comps_in_scan <= 0 || cinfo->comps_in_scan > MAX_COMPS_IN_SCAN) + ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->comps_in_scan, + MAX_COMPS_IN_SCAN); + + /* Overall image size in MCUs */ + cinfo->MCUs_per_row = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width, + (long) (cinfo->max_h_samp_factor*DCTSIZE)); + cinfo->MCU_rows_in_scan = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height, + (long) (cinfo->max_v_samp_factor*DCTSIZE)); + + cinfo->blocks_in_MCU = 0; + + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + /* Sampling factors give # of blocks of component in each MCU */ + compptr->MCU_width = compptr->h_samp_factor; + compptr->MCU_height = compptr->v_samp_factor; + compptr->MCU_blocks = compptr->MCU_width * compptr->MCU_height; + compptr->MCU_sample_width = compptr->MCU_width * DCTSIZE; + /* Figure number of non-dummy blocks in last MCU column & row */ + tmp = (int) (compptr->width_in_blocks % compptr->MCU_width); + if (tmp == 0) tmp = compptr->MCU_width; + compptr->last_col_width = tmp; + tmp = (int) (compptr->height_in_blocks % compptr->MCU_height); + if (tmp == 0) tmp = compptr->MCU_height; + compptr->last_row_height = tmp; + /* Prepare array describing MCU composition */ + mcublks = compptr->MCU_blocks; + if (cinfo->blocks_in_MCU + mcublks > C_MAX_BLOCKS_IN_MCU) + ERREXIT(cinfo, JERR_BAD_MCU_SIZE); + while (mcublks-- > 0) { + cinfo->MCU_membership[cinfo->blocks_in_MCU++] = ci; + } + } + + } + + /* Convert restart specified in rows to actual MCU count. */ + /* Note that count must fit in 16 bits, so we provide limiting. */ + if (cinfo->restart_in_rows > 0) { + long nominal = (long) cinfo->restart_in_rows * (long) cinfo->MCUs_per_row; + cinfo->restart_interval = (unsigned int) MIN(nominal, 65535L); + } +} + + +/* + * Per-pass setup. + * This is called at the beginning of each pass. We determine which modules + * will be active during this pass and give them appropriate start_pass calls. + * We also set is_last_pass to indicate whether any more passes will be + * required. + */ + +METHODDEF(void) +prepare_for_pass (j_compress_ptr cinfo) +{ + my_master_ptr master = (my_master_ptr) cinfo->master; + + switch (master->pass_type) { + case main_pass: + /* Initial pass: will collect input data, and do either Huffman + * optimization or data output for the first scan. + */ + select_scan_parameters(cinfo); + per_scan_setup(cinfo); + if (! cinfo->raw_data_in) { + (*cinfo->cconvert->start_pass) (cinfo); + (*cinfo->downsample->start_pass) (cinfo); + (*cinfo->prep->start_pass) (cinfo, JBUF_PASS_THRU); + } + (*cinfo->fdct->start_pass) (cinfo); + (*cinfo->entropy->start_pass) (cinfo, cinfo->optimize_coding); + (*cinfo->coef->start_pass) (cinfo, + (master->total_passes > 1 ? + JBUF_SAVE_AND_PASS : JBUF_PASS_THRU)); + (*cinfo->main->start_pass) (cinfo, JBUF_PASS_THRU); + if (cinfo->optimize_coding) { + /* No immediate data output; postpone writing frame/scan headers */ + master->pub.call_pass_startup = FALSE; + } else { + /* Will write frame/scan headers at first jpeg_write_scanlines call */ + master->pub.call_pass_startup = TRUE; + } + break; +#ifdef ENTROPY_OPT_SUPPORTED + case huff_opt_pass: + /* Do Huffman optimization for a scan after the first one. */ + select_scan_parameters(cinfo); + per_scan_setup(cinfo); + if (cinfo->Ss != 0 || cinfo->Ah == 0 || cinfo->arith_code) { + (*cinfo->entropy->start_pass) (cinfo, TRUE); + (*cinfo->coef->start_pass) (cinfo, JBUF_CRANK_DEST); + master->pub.call_pass_startup = FALSE; + break; + } + /* Special case: Huffman DC refinement scans need no Huffman table + * and therefore we can skip the optimization pass for them. + */ + master->pass_type = output_pass; + master->pass_number++; + /*FALLTHROUGH*/ +#endif + case output_pass: + /* Do a data-output pass. */ + /* We need not repeat per-scan setup if prior optimization pass did it. */ + if (! cinfo->optimize_coding) { + select_scan_parameters(cinfo); + per_scan_setup(cinfo); + } + (*cinfo->entropy->start_pass) (cinfo, FALSE); + (*cinfo->coef->start_pass) (cinfo, JBUF_CRANK_DEST); + /* We emit frame/scan headers now */ + if (master->scan_number == 0) + (*cinfo->marker->write_frame_header) (cinfo); + (*cinfo->marker->write_scan_header) (cinfo); + master->pub.call_pass_startup = FALSE; + break; + default: + ERREXIT(cinfo, JERR_NOT_COMPILED); + } + + master->pub.is_last_pass = (master->pass_number == master->total_passes-1); + + /* Set up progress monitor's pass info if present */ + if (cinfo->progress != NULL) { + cinfo->progress->completed_passes = master->pass_number; + cinfo->progress->total_passes = master->total_passes; + } +} + + +/* + * Special start-of-pass hook. + * This is called by jpeg_write_scanlines if call_pass_startup is TRUE. + * In single-pass processing, we need this hook because we don't want to + * write frame/scan headers during jpeg_start_compress; we want to let the + * application write COM markers etc. between jpeg_start_compress and the + * jpeg_write_scanlines loop. + * In multi-pass processing, this routine is not used. + */ + +METHODDEF(void) +pass_startup (j_compress_ptr cinfo) +{ + cinfo->master->call_pass_startup = FALSE; /* reset flag so call only once */ + + (*cinfo->marker->write_frame_header) (cinfo); + (*cinfo->marker->write_scan_header) (cinfo); +} + + +/* + * Finish up at end of pass. + */ + +METHODDEF(void) +finish_pass_master (j_compress_ptr cinfo) +{ + my_master_ptr master = (my_master_ptr) cinfo->master; + + /* The entropy coder always needs an end-of-pass call, + * either to analyze statistics or to flush its output buffer. + */ + (*cinfo->entropy->finish_pass) (cinfo); + + /* Update state for next pass */ + switch (master->pass_type) { + case main_pass: + /* next pass is either output of scan 0 (after optimization) + * or output of scan 1 (if no optimization). + */ + master->pass_type = output_pass; + if (! cinfo->optimize_coding) + master->scan_number++; + break; + case huff_opt_pass: + /* next pass is always output of current scan */ + master->pass_type = output_pass; + break; + case output_pass: + /* next pass is either optimization or output of next scan */ + if (cinfo->optimize_coding) + master->pass_type = huff_opt_pass; + master->scan_number++; + break; + } + + master->pass_number++; +} + + +/* + * Initialize master compression control. + */ + +GLOBAL(void) +jinit_c_master_control (j_compress_ptr cinfo, int transcode_only) +{ + my_master_ptr master; + + master = (my_master_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_comp_master)); + cinfo->master = (struct jpeg_comp_master *) master; + master->pub.prepare_for_pass = prepare_for_pass; + master->pub.pass_startup = pass_startup; + master->pub.finish_pass = finish_pass_master; + master->pub.is_last_pass = FALSE; + + /* Validate parameters, determine derived values */ + initial_setup(cinfo); + + if (cinfo->scan_info != NULL) { +#ifdef C_MULTISCAN_FILES_SUPPORTED + validate_script(cinfo); +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } else { + cinfo->progressive_mode = FALSE; + cinfo->num_scans = 1; + } + + if (cinfo->progressive_mode) /* TEMPORARY HACK ??? */ + cinfo->optimize_coding = TRUE; /* assume default tables no good for progressive mode */ + + /* Initialize my private state */ + if (transcode_only) { + /* no main pass in transcoding */ + if (cinfo->optimize_coding) + master->pass_type = huff_opt_pass; + else + master->pass_type = output_pass; + } else { + /* for normal compression, first pass is always this type: */ + master->pass_type = main_pass; + } + master->scan_number = 0; + master->pass_number = 0; + if (cinfo->optimize_coding) + master->total_passes = cinfo->num_scans * 2; + else + master->total_passes = cinfo->num_scans; +} diff --git a/ml/dlib/dlib/external/libjpeg/jcomapi.cpp b/ml/dlib/dlib/external/libjpeg/jcomapi.cpp new file mode 100644 index 000000000..9b1fa7568 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcomapi.cpp @@ -0,0 +1,106 @@ +/* + * jcomapi.c + * + * Copyright (C) 1994-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains application interface routines that are used for both + * compression and decompression. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* + * Abort processing of a JPEG compression or decompression operation, + * but don't destroy the object itself. + * + * For this, we merely clean up all the nonpermanent memory pools. + * Note that temp files (virtual arrays) are not allowed to belong to + * the permanent pool, so we will be able to close all temp files here. + * Closing a data source or destination, if necessary, is the application's + * responsibility. + */ + +GLOBAL(void) +jpeg_abort (j_common_ptr cinfo) +{ + int pool; + + /* Do nothing if called on a not-initialized or destroyed JPEG object. */ + if (cinfo->mem == NULL) + return; + + /* Releasing pools in reverse order might help avoid fragmentation + * with some (brain-damaged) malloc libraries. + */ + for (pool = JPOOL_NUMPOOLS-1; pool > JPOOL_PERMANENT; pool--) { + (*cinfo->mem->free_pool) (cinfo, pool); + } + + /* Reset overall state for possible reuse of object */ + if (cinfo->is_decompressor) { + cinfo->global_state = DSTATE_START; + /* Try to keep application from accessing now-deleted marker list. + * A bit kludgy to do it here, but this is the most central place. + */ + ((j_decompress_ptr) cinfo)->marker_list = NULL; + } else { + cinfo->global_state = CSTATE_START; + } +} + + +/* + * Destruction of a JPEG object. + * + * Everything gets deallocated except the master jpeg_compress_struct itself + * and the error manager struct. Both of these are supplied by the application + * and must be freed, if necessary, by the application. (Often they are on + * the stack and so don't need to be freed anyway.) + * Closing a data source or destination, if necessary, is the application's + * responsibility. + */ + +GLOBAL(void) +jpeg_destroy (j_common_ptr cinfo) +{ + /* We need only tell the memory manager to release everything. */ + /* NB: mem pointer is NULL if memory mgr failed to initialize. */ + if (cinfo->mem != NULL) + (*cinfo->mem->self_destruct) (cinfo); + cinfo->mem = NULL; /* be safe if jpeg_destroy is called twice */ + cinfo->global_state = 0; /* mark it destroyed */ +} + + +/* + * Convenience routines for allocating quantization and Huffman tables. + * (Would jutils.c be a more reasonable place to put these?) + */ + +GLOBAL(JQUANT_TBL *) +jpeg_alloc_quant_table (j_common_ptr cinfo) +{ + JQUANT_TBL *tbl; + + tbl = (JQUANT_TBL *) + (*cinfo->mem->alloc_small) (cinfo, JPOOL_PERMANENT, SIZEOF(JQUANT_TBL)); + tbl->sent_table = FALSE; /* make sure this is false in any new table */ + return tbl; +} + + +GLOBAL(JHUFF_TBL *) +jpeg_alloc_huff_table (j_common_ptr cinfo) +{ + JHUFF_TBL *tbl; + + tbl = (JHUFF_TBL *) + (*cinfo->mem->alloc_small) (cinfo, JPOOL_PERMANENT, SIZEOF(JHUFF_TBL)); + tbl->sent_table = FALSE; /* make sure this is false in any new table */ + return tbl; +} diff --git a/ml/dlib/dlib/external/libjpeg/jconfig.h b/ml/dlib/dlib/external/libjpeg/jconfig.h new file mode 100644 index 000000000..9594ec56b --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jconfig.h @@ -0,0 +1,45 @@ +/* jconfig.h. Generated automatically by configure. */ +/* jconfig.cfg --- source file edited by configure script */ +/* see jconfig.doc for explanations */ + +#define HAVE_PROTOTYPES +#define HAVE_UNSIGNED_CHAR +#define HAVE_UNSIGNED_SHORT +#undef void +#undef const +#undef CHAR_IS_UNSIGNED +#define HAVE_STDDEF_H +#define HAVE_STDLIB_H +#undef NEED_BSD_STRINGS +#undef NEED_SYS_TYPES_H +#undef NEED_FAR_POINTERS +#undef NEED_SHORT_EXTERNAL_NAMES +/* Define this if you get warnings about undefined structures. */ +#undef INCOMPLETE_TYPES_BROKEN + +#ifdef JPEG_INTERNALS + +#undef RIGHT_SHIFT_IS_UNSIGNED +#define INLINE __inline__ +/* These are for configuring the JPEG memory manager. */ +#undef DEFAULT_MAX_MEM +#undef NO_MKTEMP + +#endif /* JPEG_INTERNALS */ + +#ifdef JPEG_CJPEG_DJPEG + +#define BMP_SUPPORTED /* BMP image file format */ +#define GIF_SUPPORTED /* GIF image file format */ +#define PPM_SUPPORTED /* PBMPLUS PPM/PGM image file format */ +#undef RLE_SUPPORTED /* Utah RLE image file format */ +#define TARGA_SUPPORTED /* Targa image file format */ + +#undef TWO_FILE_COMMANDLINE +#undef NEED_SIGNAL_CATCHER +#undef DONT_USE_B_MODE + +/* Define this if you want percent-done progress reports from cjpeg/djpeg. */ +#undef PROGRESS_REPORT + +#endif /* JPEG_CJPEG_DJPEG */ diff --git a/ml/dlib/dlib/external/libjpeg/jcparam.cpp b/ml/dlib/dlib/external/libjpeg/jcparam.cpp new file mode 100644 index 000000000..a87ef2079 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcparam.cpp @@ -0,0 +1,610 @@ +/* + * jcparam.c + * + * Copyright (C) 1991-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains optional default-setting code for the JPEG compressor. + * Applications do not have to use this file, but those that don't use it + * must know a lot more about the innards of the JPEG code. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* + * Quantization table setup routines + */ + +GLOBAL(void) +jpeg_add_quant_table (j_compress_ptr cinfo, int which_tbl, + const unsigned int *basic_table, + int scale_factor, int force_baseline) +/* Define a quantization table equal to the basic_table times + * a scale factor (given as a percentage). + * If force_baseline is TRUE, the computed quantization table entries + * are limited to 1..255 for JPEG baseline compatibility. + */ +{ + JQUANT_TBL ** qtblptr; + int i; + long temp; + + /* Safety check to ensure start_compress not called yet. */ + if (cinfo->global_state != CSTATE_START) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + if (which_tbl < 0 || which_tbl >= NUM_QUANT_TBLS) + ERREXIT1(cinfo, JERR_DQT_INDEX, which_tbl); + + qtblptr = & cinfo->quant_tbl_ptrs[which_tbl]; + + if (*qtblptr == NULL) + *qtblptr = jpeg_alloc_quant_table((j_common_ptr) cinfo); + + for (i = 0; i < DCTSIZE2; i++) { + temp = ((long) basic_table[i] * scale_factor + 50L) / 100L; + /* limit the values to the valid range */ + if (temp <= 0L) temp = 1L; + if (temp > 32767L) temp = 32767L; /* max quantizer needed for 12 bits */ + if (force_baseline && temp > 255L) + temp = 255L; /* limit to baseline range if requested */ + (*qtblptr)->quantval[i] = (unsigned short) temp; + } + + /* Initialize sent_table FALSE so table will be written to JPEG file. */ + (*qtblptr)->sent_table = FALSE; +} + + +GLOBAL(void) +jpeg_set_linear_quality (j_compress_ptr cinfo, int scale_factor, + int force_baseline) +/* Set or change the 'quality' (quantization) setting, using default tables + * and a straight percentage-scaling quality scale. In most cases it's better + * to use jpeg_set_quality (below); this entry point is provided for + * applications that insist on a linear percentage scaling. + */ +{ + /* These are the sample quantization tables given in JPEG spec section K.1. + * The spec says that the values given produce "good" quality, and + * when divided by 2, "very good" quality. + */ + static const unsigned int std_luminance_quant_tbl[DCTSIZE2] = { + 16, 11, 10, 16, 24, 40, 51, 61, + 12, 12, 14, 19, 26, 58, 60, 55, + 14, 13, 16, 24, 40, 57, 69, 56, + 14, 17, 22, 29, 51, 87, 80, 62, + 18, 22, 37, 56, 68, 109, 103, 77, + 24, 35, 55, 64, 81, 104, 113, 92, + 49, 64, 78, 87, 103, 121, 120, 101, + 72, 92, 95, 98, 112, 100, 103, 99 + }; + static const unsigned int std_chrominance_quant_tbl[DCTSIZE2] = { + 17, 18, 24, 47, 99, 99, 99, 99, + 18, 21, 26, 66, 99, 99, 99, 99, + 24, 26, 56, 99, 99, 99, 99, 99, + 47, 66, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99, + 99, 99, 99, 99, 99, 99, 99, 99 + }; + + /* Set up two quantization tables using the specified scaling */ + jpeg_add_quant_table(cinfo, 0, std_luminance_quant_tbl, + scale_factor, force_baseline); + jpeg_add_quant_table(cinfo, 1, std_chrominance_quant_tbl, + scale_factor, force_baseline); +} + + +GLOBAL(int) +jpeg_quality_scaling (int quality) +/* Convert a user-specified quality rating to a percentage scaling factor + * for an underlying quantization table, using our recommended scaling curve. + * The input 'quality' factor should be 0 (terrible) to 100 (very good). + */ +{ + /* Safety limit on quality factor. Convert 0 to 1 to avoid zero divide. */ + if (quality <= 0) quality = 1; + if (quality > 100) quality = 100; + + /* The basic table is used as-is (scaling 100) for a quality of 50. + * Qualities 50..100 are converted to scaling percentage 200 - 2*Q; + * note that at Q=100 the scaling is 0, which will cause jpeg_add_quant_table + * to make all the table entries 1 (hence, minimum quantization loss). + * Qualities 1..50 are converted to scaling percentage 5000/Q. + */ + if (quality < 50) + quality = 5000 / quality; + else + quality = 200 - quality*2; + + return quality; +} + + +GLOBAL(void) +jpeg_set_quality (j_compress_ptr cinfo, int quality, int force_baseline) +/* Set or change the 'quality' (quantization) setting, using default tables. + * This is the standard quality-adjusting entry point for typical user + * interfaces; only those who want detailed control over quantization tables + * would use the preceding three routines directly. + */ +{ + /* Convert user 0-100 rating to percentage scaling */ + quality = jpeg_quality_scaling(quality); + + /* Set up standard quality tables */ + jpeg_set_linear_quality(cinfo, quality, force_baseline); +} + + +/* + * Huffman table setup routines + */ + +LOCAL(void) +add_huff_table (j_compress_ptr cinfo, + JHUFF_TBL **htblptr, const unsigned char *bits, const unsigned char *val) +/* Define a Huffman table */ +{ + int nsymbols, len; + + if (*htblptr == NULL) + *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo); + + /* Copy the number-of-symbols-of-each-code-length counts */ + MEMCOPY((*htblptr)->bits, bits, SIZEOF((*htblptr)->bits)); + + /* Validate the counts. We do this here mainly so we can copy the right + * number of symbols from the val[] array, without risking marching off + * the end of memory. jchuff.c will do a more thorough test later. + */ + nsymbols = 0; + for (len = 1; len <= 16; len++) + nsymbols += bits[len]; + if (nsymbols < 1 || nsymbols > 256) + ERREXIT(cinfo, JERR_BAD_HUFF_TABLE); + + MEMCOPY((*htblptr)->huffval, val, nsymbols * SIZEOF(unsigned char)); + + /* Initialize sent_table FALSE so table will be written to JPEG file. */ + (*htblptr)->sent_table = FALSE; +} + + +LOCAL(void) +std_huff_tables (j_compress_ptr cinfo) +/* Set up the standard Huffman tables (cf. JPEG standard section K.3) */ +/* IMPORTANT: these are only valid for 8-bit data precision! */ +{ + static const unsigned char bits_dc_luminance[17] = + { /* 0-base */ 0, 0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0 }; + static const unsigned char val_dc_luminance[] = + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 }; + + static const unsigned char bits_dc_chrominance[17] = + { /* 0-base */ 0, 0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0 }; + static const unsigned char val_dc_chrominance[] = + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 }; + + static const unsigned char bits_ac_luminance[17] = + { /* 0-base */ 0, 0, 2, 1, 3, 3, 2, 4, 3, 5, 5, 4, 4, 0, 0, 1, 0x7d }; + static const unsigned char val_ac_luminance[] = + { 0x01, 0x02, 0x03, 0x00, 0x04, 0x11, 0x05, 0x12, + 0x21, 0x31, 0x41, 0x06, 0x13, 0x51, 0x61, 0x07, + 0x22, 0x71, 0x14, 0x32, 0x81, 0x91, 0xa1, 0x08, + 0x23, 0x42, 0xb1, 0xc1, 0x15, 0x52, 0xd1, 0xf0, + 0x24, 0x33, 0x62, 0x72, 0x82, 0x09, 0x0a, 0x16, + 0x17, 0x18, 0x19, 0x1a, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2a, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, + 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, + 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, + 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, + 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, + 0x7a, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, + 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, + 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, + 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, + 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, 0xc4, 0xc5, + 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, 0xd3, 0xd4, + 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xe1, 0xe2, + 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, + 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, + 0xf9, 0xfa }; + + static const unsigned char bits_ac_chrominance[17] = + { /* 0-base */ 0, 0, 2, 1, 2, 4, 4, 3, 4, 7, 5, 4, 4, 0, 1, 2, 0x77 }; + static const unsigned char val_ac_chrominance[] = + { 0x00, 0x01, 0x02, 0x03, 0x11, 0x04, 0x05, 0x21, + 0x31, 0x06, 0x12, 0x41, 0x51, 0x07, 0x61, 0x71, + 0x13, 0x22, 0x32, 0x81, 0x08, 0x14, 0x42, 0x91, + 0xa1, 0xb1, 0xc1, 0x09, 0x23, 0x33, 0x52, 0xf0, + 0x15, 0x62, 0x72, 0xd1, 0x0a, 0x16, 0x24, 0x34, + 0xe1, 0x25, 0xf1, 0x17, 0x18, 0x19, 0x1a, 0x26, + 0x27, 0x28, 0x29, 0x2a, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, + 0x49, 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, + 0x59, 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, + 0x69, 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, + 0x79, 0x7a, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, + 0x88, 0x89, 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, + 0x97, 0x98, 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, + 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, + 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, + 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, + 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, + 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, + 0xea, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, + 0xf9, 0xfa }; + + add_huff_table(cinfo, &cinfo->dc_huff_tbl_ptrs[0], + bits_dc_luminance, val_dc_luminance); + add_huff_table(cinfo, &cinfo->ac_huff_tbl_ptrs[0], + bits_ac_luminance, val_ac_luminance); + add_huff_table(cinfo, &cinfo->dc_huff_tbl_ptrs[1], + bits_dc_chrominance, val_dc_chrominance); + add_huff_table(cinfo, &cinfo->ac_huff_tbl_ptrs[1], + bits_ac_chrominance, val_ac_chrominance); +} + + +/* + * Default parameter setup for compression. + * + * Applications that don't choose to use this routine must do their + * own setup of all these parameters. Alternately, you can call this + * to establish defaults and then alter parameters selectively. This + * is the recommended approach since, if we add any new parameters, + * your code will still work (they'll be set to reasonable defaults). + */ + +GLOBAL(void) +jpeg_set_defaults (j_compress_ptr cinfo) +{ + int i; + + /* Safety check to ensure start_compress not called yet. */ + if (cinfo->global_state != CSTATE_START) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + /* Allocate comp_info array large enough for maximum component count. + * Array is made permanent in case application wants to compress + * multiple images at same param settings. + */ + if (cinfo->comp_info == NULL) + cinfo->comp_info = (jpeg_component_info *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT, + MAX_COMPONENTS * SIZEOF(jpeg_component_info)); + + /* Initialize everything not dependent on the color space */ + + cinfo->data_precision = BITS_IN_JSAMPLE; + /* Set up two quantization tables using default quality of 75 */ + jpeg_set_quality(cinfo, 75, TRUE); + /* Set up two Huffman tables */ + std_huff_tables(cinfo); + + /* Initialize default arithmetic coding conditioning */ + for (i = 0; i < NUM_ARITH_TBLS; i++) { + cinfo->arith_dc_L[i] = 0; + cinfo->arith_dc_U[i] = 1; + cinfo->arith_ac_K[i] = 5; + } + + /* Default is no multiple-scan output */ + cinfo->scan_info = NULL; + cinfo->num_scans = 0; + + /* Expect normal source image, not raw downsampled data */ + cinfo->raw_data_in = FALSE; + + /* Use Huffman coding, not arithmetic coding, by default */ + cinfo->arith_code = FALSE; + + /* By default, don't do extra passes to optimize entropy coding */ + cinfo->optimize_coding = FALSE; + /* The standard Huffman tables are only valid for 8-bit data precision. + * If the precision is higher, force optimization on so that usable + * tables will be computed. This test can be removed if default tables + * are supplied that are valid for the desired precision. + */ + if (cinfo->data_precision > 8) + cinfo->optimize_coding = TRUE; + + /* By default, use the simpler non-cosited sampling alignment */ + cinfo->CCIR601_sampling = FALSE; + + /* No input smoothing */ + cinfo->smoothing_factor = 0; + + /* DCT algorithm preference */ + cinfo->dct_method = JDCT_DEFAULT; + + /* No restart markers */ + cinfo->restart_interval = 0; + cinfo->restart_in_rows = 0; + + /* Fill in default JFIF marker parameters. Note that whether the marker + * will actually be written is determined by jpeg_set_colorspace. + * + * By default, the library emits JFIF version code 1.01. + * An application that wants to emit JFIF 1.02 extension markers should set + * JFIF_minor_version to 2. We could probably get away with just defaulting + * to 1.02, but there may still be some decoders in use that will complain + * about that; saying 1.01 should minimize compatibility problems. + */ + cinfo->JFIF_major_version = 1; /* Default JFIF version = 1.01 */ + cinfo->JFIF_minor_version = 1; + cinfo->density_unit = 0; /* Pixel size is unknown by default */ + cinfo->X_density = 1; /* Pixel aspect ratio is square by default */ + cinfo->Y_density = 1; + + /* Choose JPEG colorspace based on input space, set defaults accordingly */ + + jpeg_default_colorspace(cinfo); +} + + +/* + * Select an appropriate JPEG colorspace for in_color_space. + */ + +GLOBAL(void) +jpeg_default_colorspace (j_compress_ptr cinfo) +{ + switch (cinfo->in_color_space) { + case JCS_GRAYSCALE: + jpeg_set_colorspace(cinfo, JCS_GRAYSCALE); + break; + case JCS_RGB: + jpeg_set_colorspace(cinfo, JCS_YCbCr); + break; + case JCS_YCbCr: + jpeg_set_colorspace(cinfo, JCS_YCbCr); + break; + case JCS_CMYK: + jpeg_set_colorspace(cinfo, JCS_CMYK); /* By default, no translation */ + break; + case JCS_YCCK: + jpeg_set_colorspace(cinfo, JCS_YCCK); + break; + case JCS_UNKNOWN: + jpeg_set_colorspace(cinfo, JCS_UNKNOWN); + break; + default: + ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE); + } +} + + +/* + * Set the JPEG colorspace, and choose colorspace-dependent default values. + */ + +GLOBAL(void) +jpeg_set_colorspace (j_compress_ptr cinfo, J_COLOR_SPACE colorspace) +{ + jpeg_component_info * compptr; + int ci; + +#define SET_COMP(index,id,hsamp,vsamp,quant,dctbl,actbl) \ + (compptr = &cinfo->comp_info[index], \ + compptr->component_id = (id), \ + compptr->h_samp_factor = (hsamp), \ + compptr->v_samp_factor = (vsamp), \ + compptr->quant_tbl_no = (quant), \ + compptr->dc_tbl_no = (dctbl), \ + compptr->ac_tbl_no = (actbl) ) + + /* Safety check to ensure start_compress not called yet. */ + if (cinfo->global_state != CSTATE_START) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + /* For all colorspaces, we use Q and Huff tables 0 for luminance components, + * tables 1 for chrominance components. + */ + + cinfo->jpeg_color_space = colorspace; + + cinfo->write_JFIF_header = FALSE; /* No marker for non-JFIF colorspaces */ + cinfo->write_Adobe_marker = FALSE; /* write no Adobe marker by default */ + + switch (colorspace) { + case JCS_GRAYSCALE: + cinfo->write_JFIF_header = TRUE; /* Write a JFIF marker */ + cinfo->num_components = 1; + /* JFIF specifies component ID 1 */ + SET_COMP(0, 1, 1,1, 0, 0,0); + break; + case JCS_RGB: + cinfo->write_Adobe_marker = TRUE; /* write Adobe marker to flag RGB */ + cinfo->num_components = 3; + SET_COMP(0, 0x52 /* 'R' */, 1,1, 0, 0,0); + SET_COMP(1, 0x47 /* 'G' */, 1,1, 0, 0,0); + SET_COMP(2, 0x42 /* 'B' */, 1,1, 0, 0,0); + break; + case JCS_YCbCr: + cinfo->write_JFIF_header = TRUE; /* Write a JFIF marker */ + cinfo->num_components = 3; + /* JFIF specifies component IDs 1,2,3 */ + /* We default to 2x2 subsamples of chrominance */ + SET_COMP(0, 1, 2,2, 0, 0,0); + SET_COMP(1, 2, 1,1, 1, 1,1); + SET_COMP(2, 3, 1,1, 1, 1,1); + break; + case JCS_CMYK: + cinfo->write_Adobe_marker = TRUE; /* write Adobe marker to flag CMYK */ + cinfo->num_components = 4; + SET_COMP(0, 0x43 /* 'C' */, 1,1, 0, 0,0); + SET_COMP(1, 0x4D /* 'M' */, 1,1, 0, 0,0); + SET_COMP(2, 0x59 /* 'Y' */, 1,1, 0, 0,0); + SET_COMP(3, 0x4B /* 'K' */, 1,1, 0, 0,0); + break; + case JCS_YCCK: + cinfo->write_Adobe_marker = TRUE; /* write Adobe marker to flag YCCK */ + cinfo->num_components = 4; + SET_COMP(0, 1, 2,2, 0, 0,0); + SET_COMP(1, 2, 1,1, 1, 1,1); + SET_COMP(2, 3, 1,1, 1, 1,1); + SET_COMP(3, 4, 2,2, 0, 0,0); + break; + case JCS_UNKNOWN: + cinfo->num_components = cinfo->input_components; + if (cinfo->num_components < 1 || cinfo->num_components > MAX_COMPONENTS) + ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->num_components, + MAX_COMPONENTS); + for (ci = 0; ci < cinfo->num_components; ci++) { + SET_COMP(ci, ci, 1,1, 0, 0,0); + } + break; + default: + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + } +} + + +#ifdef C_PROGRESSIVE_SUPPORTED + +LOCAL(jpeg_scan_info *) +fill_a_scan (jpeg_scan_info * scanptr, int ci, + int Ss, int Se, int Ah, int Al) +/* Support routine: generate one scan for specified component */ +{ + scanptr->comps_in_scan = 1; + scanptr->component_index[0] = ci; + scanptr->Ss = Ss; + scanptr->Se = Se; + scanptr->Ah = Ah; + scanptr->Al = Al; + scanptr++; + return scanptr; +} + +LOCAL(jpeg_scan_info *) +fill_scans (jpeg_scan_info * scanptr, int ncomps, + int Ss, int Se, int Ah, int Al) +/* Support routine: generate one scan for each component */ +{ + int ci; + + for (ci = 0; ci < ncomps; ci++) { + scanptr->comps_in_scan = 1; + scanptr->component_index[0] = ci; + scanptr->Ss = Ss; + scanptr->Se = Se; + scanptr->Ah = Ah; + scanptr->Al = Al; + scanptr++; + } + return scanptr; +} + +LOCAL(jpeg_scan_info *) +fill_dc_scans (jpeg_scan_info * scanptr, int ncomps, int Ah, int Al) +/* Support routine: generate interleaved DC scan if possible, else N scans */ +{ + int ci; + + if (ncomps <= MAX_COMPS_IN_SCAN) { + /* Single interleaved DC scan */ + scanptr->comps_in_scan = ncomps; + for (ci = 0; ci < ncomps; ci++) + scanptr->component_index[ci] = ci; + scanptr->Ss = scanptr->Se = 0; + scanptr->Ah = Ah; + scanptr->Al = Al; + scanptr++; + } else { + /* Noninterleaved DC scan for each component */ + scanptr = fill_scans(scanptr, ncomps, 0, 0, Ah, Al); + } + return scanptr; +} + + +/* + * Create a recommended progressive-JPEG script. + * cinfo->num_components and cinfo->jpeg_color_space must be correct. + */ + +GLOBAL(void) +jpeg_simple_progression (j_compress_ptr cinfo) +{ + int ncomps = cinfo->num_components; + int nscans; + jpeg_scan_info * scanptr; + + /* Safety check to ensure start_compress not called yet. */ + if (cinfo->global_state != CSTATE_START) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + /* Figure space needed for script. Calculation must match code below! */ + if (ncomps == 3 && cinfo->jpeg_color_space == JCS_YCbCr) { + /* Custom script for YCbCr color images. */ + nscans = 10; + } else { + /* All-purpose script for other color spaces. */ + if (ncomps > MAX_COMPS_IN_SCAN) + nscans = 6 * ncomps; /* 2 DC + 4 AC scans per component */ + else + nscans = 2 + 4 * ncomps; /* 2 DC scans; 4 AC scans per component */ + } + + /* Allocate space for script. + * We need to put it in the permanent pool in case the application performs + * multiple compressions without changing the settings. To avoid a memory + * leak if jpeg_simple_progression is called repeatedly for the same JPEG + * object, we try to re-use previously allocated space, and we allocate + * enough space to handle YCbCr even if initially asked for grayscale. + */ + if (cinfo->script_space == NULL || cinfo->script_space_size < nscans) { + cinfo->script_space_size = MAX(nscans, 10); + cinfo->script_space = (jpeg_scan_info *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT, + cinfo->script_space_size * SIZEOF(jpeg_scan_info)); + } + scanptr = cinfo->script_space; + cinfo->scan_info = scanptr; + cinfo->num_scans = nscans; + + if (ncomps == 3 && cinfo->jpeg_color_space == JCS_YCbCr) { + /* Custom script for YCbCr color images. */ + /* Initial DC scan */ + scanptr = fill_dc_scans(scanptr, ncomps, 0, 1); + /* Initial AC scan: get some luma data out in a hurry */ + scanptr = fill_a_scan(scanptr, 0, 1, 5, 0, 2); + /* Chroma data is too small to be worth expending many scans on */ + scanptr = fill_a_scan(scanptr, 2, 1, 63, 0, 1); + scanptr = fill_a_scan(scanptr, 1, 1, 63, 0, 1); + /* Complete spectral selection for luma AC */ + scanptr = fill_a_scan(scanptr, 0, 6, 63, 0, 2); + /* Refine next bit of luma AC */ + scanptr = fill_a_scan(scanptr, 0, 1, 63, 2, 1); + /* Finish DC successive approximation */ + scanptr = fill_dc_scans(scanptr, ncomps, 1, 0); + /* Finish AC successive approximation */ + scanptr = fill_a_scan(scanptr, 2, 1, 63, 1, 0); + scanptr = fill_a_scan(scanptr, 1, 1, 63, 1, 0); + /* Luma bottom bit comes last since it's usually largest scan */ + scanptr = fill_a_scan(scanptr, 0, 1, 63, 1, 0); + } else { + /* All-purpose script for other color spaces. */ + /* Successive approximation first pass */ + scanptr = fill_dc_scans(scanptr, ncomps, 0, 1); + scanptr = fill_scans(scanptr, ncomps, 1, 5, 0, 2); + scanptr = fill_scans(scanptr, ncomps, 6, 63, 0, 2); + /* Successive approximation second pass */ + scanptr = fill_scans(scanptr, ncomps, 1, 63, 2, 1); + /* Successive approximation final pass */ + scanptr = fill_dc_scans(scanptr, ncomps, 1, 0); + scanptr = fill_scans(scanptr, ncomps, 1, 63, 1, 0); + } +} + +#endif /* C_PROGRESSIVE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jcphuff.cpp b/ml/dlib/dlib/external/libjpeg/jcphuff.cpp new file mode 100644 index 000000000..66a85b8c7 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcphuff.cpp @@ -0,0 +1,833 @@ +/* + * jcphuff.c + * + * Copyright (C) 1995-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains Huffman entropy encoding routines for progressive JPEG. + * + * We do not support output suspension in this module, since the library + * currently does not allow multiple-scan files to be written with output + * suspension. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jchuff.h" /* Declarations shared with jchuff.c */ + +#ifdef C_PROGRESSIVE_SUPPORTED + +/* Expanded entropy encoder object for progressive Huffman encoding. */ + +typedef struct { + struct jpeg_entropy_encoder pub; /* public fields */ + + /* Mode flag: TRUE for optimization, FALSE for actual data output */ + int gather_statistics; + + /* Bit-level coding status. + * next_output_byte/free_in_buffer are local copies of cinfo->dest fields. + */ + JOCTET * next_output_byte; /* => next byte to write in buffer */ + size_t free_in_buffer; /* # of byte spaces remaining in buffer */ + long put_buffer; /* current bit-accumulation buffer */ + int put_bits; /* # of bits now in it */ + j_compress_ptr cinfo; /* link to cinfo (needed for dump_buffer) */ + + /* Coding status for DC components */ + int last_dc_val[MAX_COMPS_IN_SCAN]; /* last DC coef for each component */ + + /* Coding status for AC components */ + int ac_tbl_no; /* the table number of the single component */ + unsigned int EOBRUN; /* run length of EOBs */ + unsigned int BE; /* # of buffered correction bits before MCU */ + char * bit_buffer; /* buffer for correction bits (1 per char) */ + /* packing correction bits tightly would save some space but cost time... */ + + unsigned int restarts_to_go; /* MCUs left in this restart interval */ + int next_restart_num; /* next restart number to write (0-7) */ + + /* Pointers to derived tables (these workspaces have image lifespan). + * Since any one scan codes only DC or only AC, we only need one set + * of tables, not one for DC and one for AC. + */ + c_derived_tbl * derived_tbls[NUM_HUFF_TBLS]; + + /* Statistics tables for optimization; again, one set is enough */ + long * count_ptrs[NUM_HUFF_TBLS]; +} phuff_entropy_encoder; + +typedef phuff_entropy_encoder * phuff_entropy_ptr; + +/* MAX_CORR_BITS is the number of bits the AC refinement correction-bit + * buffer can hold. Larger sizes may slightly improve compression, but + * 1000 is already well into the realm of overkill. + * The minimum safe size is 64 bits. + */ + +#define MAX_CORR_BITS 1000 /* Max # of correction bits I can buffer */ + +/* IRIGHT_SHIFT is like RIGHT_SHIFT, but works on int rather than long. + * We assume that int right shift is unsigned if long right shift is, + * which should be safe. + */ + +#ifdef RIGHT_SHIFT_IS_UNSIGNED +#define ISHIFT_TEMPS int ishift_temp; +#define IRIGHT_SHIFT(x,shft) \ + ((ishift_temp = (x)) < 0 ? \ + (ishift_temp >> (shft)) | ((~0) << (16-(shft))) : \ + (ishift_temp >> (shft))) +#else +#define ISHIFT_TEMPS +#define IRIGHT_SHIFT(x,shft) ((x) >> (shft)) +#endif + +/* Forward declarations */ +METHODDEF(int) encode_mcu_DC_first JPP((j_compress_ptr cinfo, + JBLOCKROW *MCU_data)); +METHODDEF(int) encode_mcu_AC_first JPP((j_compress_ptr cinfo, + JBLOCKROW *MCU_data)); +METHODDEF(int) encode_mcu_DC_refine JPP((j_compress_ptr cinfo, + JBLOCKROW *MCU_data)); +METHODDEF(int) encode_mcu_AC_refine JPP((j_compress_ptr cinfo, + JBLOCKROW *MCU_data)); +METHODDEF(void) finish_pass_phuff JPP((j_compress_ptr cinfo)); +METHODDEF(void) finish_pass_gather_phuff JPP((j_compress_ptr cinfo)); + + +/* + * Initialize for a Huffman-compressed scan using progressive JPEG. + */ + +METHODDEF(void) +start_pass_phuff (j_compress_ptr cinfo, int gather_statistics) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int is_DC_band; + int ci, tbl; + jpeg_component_info * compptr; + + entropy->cinfo = cinfo; + entropy->gather_statistics = gather_statistics; + + is_DC_band = (cinfo->Ss == 0); + + /* We assume jcmaster.c already validated the scan parameters. */ + + /* Select execution routines */ + if (cinfo->Ah == 0) { + if (is_DC_band) + entropy->pub.encode_mcu = encode_mcu_DC_first; + else + entropy->pub.encode_mcu = encode_mcu_AC_first; + } else { + if (is_DC_band) + entropy->pub.encode_mcu = encode_mcu_DC_refine; + else { + entropy->pub.encode_mcu = encode_mcu_AC_refine; + /* AC refinement needs a correction bit buffer */ + if (entropy->bit_buffer == NULL) + entropy->bit_buffer = (char *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + MAX_CORR_BITS * SIZEOF(char)); + } + } + if (gather_statistics) + entropy->pub.finish_pass = finish_pass_gather_phuff; + else + entropy->pub.finish_pass = finish_pass_phuff; + + /* Only DC coefficients may be interleaved, so cinfo->comps_in_scan = 1 + * for AC coefficients. + */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + /* Initialize DC predictions to 0 */ + entropy->last_dc_val[ci] = 0; + /* Get table index */ + if (is_DC_band) { + if (cinfo->Ah != 0) /* DC refinement needs no table */ + continue; + tbl = compptr->dc_tbl_no; + } else { + entropy->ac_tbl_no = tbl = compptr->ac_tbl_no; + } + if (gather_statistics) { + /* Check for invalid table index */ + /* (make_c_derived_tbl does this in the other path) */ + if (tbl < 0 || tbl >= NUM_HUFF_TBLS) + ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tbl); + /* Allocate and zero the statistics tables */ + /* Note that jpeg_gen_optimal_table expects 257 entries in each table! */ + if (entropy->count_ptrs[tbl] == NULL) + entropy->count_ptrs[tbl] = (long *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + 257 * SIZEOF(long)); + MEMZERO(entropy->count_ptrs[tbl], 257 * SIZEOF(long)); + } else { + /* Compute derived values for Huffman table */ + /* We may do this more than once for a table, but it's not expensive */ + jpeg_make_c_derived_tbl(cinfo, is_DC_band, tbl, + & entropy->derived_tbls[tbl]); + } + } + + /* Initialize AC stuff */ + entropy->EOBRUN = 0; + entropy->BE = 0; + + /* Initialize bit buffer to empty */ + entropy->put_buffer = 0; + entropy->put_bits = 0; + + /* Initialize restart stuff */ + entropy->restarts_to_go = cinfo->restart_interval; + entropy->next_restart_num = 0; +} + + +/* Outputting bytes to the file. + * NB: these must be called only when actually outputting, + * that is, entropy->gather_statistics == FALSE. + */ + +/* Emit a byte */ +#define emit_byte(entropy,val) \ + { *(entropy)->next_output_byte++ = (JOCTET) (val); \ + if (--(entropy)->free_in_buffer == 0) \ + dump_buffer(entropy); } + + +LOCAL(void) +dump_buffer (phuff_entropy_ptr entropy) +/* Empty the output buffer; we do not support suspension in this module. */ +{ + struct jpeg_destination_mgr * dest = entropy->cinfo->dest; + + if (! (*dest->empty_output_buffer) (entropy->cinfo)) + ERREXIT(entropy->cinfo, JERR_CANT_SUSPEND); + /* After a successful buffer dump, must reset buffer pointers */ + entropy->next_output_byte = dest->next_output_byte; + entropy->free_in_buffer = dest->free_in_buffer; +} + + +/* Outputting bits to the file */ + +/* Only the right 24 bits of put_buffer are used; the valid bits are + * left-justified in this part. At most 16 bits can be passed to emit_bits + * in one call, and we never retain more than 7 bits in put_buffer + * between calls, so 24 bits are sufficient. + */ + +inline +LOCAL(void) +emit_bits (phuff_entropy_ptr entropy, unsigned int code, int size) +/* Emit some bits, unless we are in gather mode */ +{ + /* This routine is heavily used, so it's worth coding tightly. */ + long put_buffer = (long) code; + int put_bits = entropy->put_bits; + + /* if size is 0, caller used an invalid Huffman table entry */ + if (size == 0) + ERREXIT(entropy->cinfo, JERR_HUFF_MISSING_CODE); + + if (entropy->gather_statistics) + return; /* do nothing if we're only getting stats */ + + put_buffer &= (((long) 1)<put_buffer; /* and merge with old buffer contents */ + + while (put_bits >= 8) { + int c = (int) ((put_buffer >> 16) & 0xFF); + + emit_byte(entropy, c); + if (c == 0xFF) { /* need to stuff a zero byte? */ + emit_byte(entropy, 0); + } + put_buffer <<= 8; + put_bits -= 8; + } + + entropy->put_buffer = put_buffer; /* update variables */ + entropy->put_bits = put_bits; +} + + +LOCAL(void) +flush_bits (phuff_entropy_ptr entropy) +{ + emit_bits(entropy, 0x7F, 7); /* fill any partial byte with ones */ + entropy->put_buffer = 0; /* and reset bit-buffer to empty */ + entropy->put_bits = 0; +} + + +/* + * Emit (or just count) a Huffman symbol. + */ + +inline +LOCAL(void) +emit_symbol (phuff_entropy_ptr entropy, int tbl_no, int symbol) +{ + if (entropy->gather_statistics) + entropy->count_ptrs[tbl_no][symbol]++; + else { + c_derived_tbl * tbl = entropy->derived_tbls[tbl_no]; + emit_bits(entropy, tbl->ehufco[symbol], tbl->ehufsi[symbol]); + } +} + + +/* + * Emit bits from a correction bit buffer. + */ + +LOCAL(void) +emit_buffered_bits (phuff_entropy_ptr entropy, char * bufstart, + unsigned int nbits) +{ + if (entropy->gather_statistics) + return; /* no real work */ + + while (nbits > 0) { + emit_bits(entropy, (unsigned int) (*bufstart), 1); + bufstart++; + nbits--; + } +} + + +/* + * Emit any pending EOBRUN symbol. + */ + +LOCAL(void) +emit_eobrun (phuff_entropy_ptr entropy) +{ + int temp, nbits; + + if (entropy->EOBRUN > 0) { /* if there is any pending EOBRUN */ + temp = entropy->EOBRUN; + nbits = 0; + while ((temp >>= 1)) + nbits++; + /* safety check: shouldn't happen given limited correction-bit buffer */ + if (nbits > 14) + ERREXIT(entropy->cinfo, JERR_HUFF_MISSING_CODE); + + emit_symbol(entropy, entropy->ac_tbl_no, nbits << 4); + if (nbits) + emit_bits(entropy, entropy->EOBRUN, nbits); + + entropy->EOBRUN = 0; + + /* Emit any buffered correction bits */ + emit_buffered_bits(entropy, entropy->bit_buffer, entropy->BE); + entropy->BE = 0; + } +} + + +/* + * Emit a restart marker & resynchronize predictions. + */ + +LOCAL(void) +emit_restart (phuff_entropy_ptr entropy, int restart_num) +{ + int ci; + + emit_eobrun(entropy); + + if (! entropy->gather_statistics) { + flush_bits(entropy); + emit_byte(entropy, 0xFF); + emit_byte(entropy, JPEG_RST0 + restart_num); + } + + if (entropy->cinfo->Ss == 0) { + /* Re-initialize DC predictions to 0 */ + for (ci = 0; ci < entropy->cinfo->comps_in_scan; ci++) + entropy->last_dc_val[ci] = 0; + } else { + /* Re-initialize all AC-related fields to 0 */ + entropy->EOBRUN = 0; + entropy->BE = 0; + } +} + + +/* + * MCU encoding for DC initial scan (either spectral selection, + * or first pass of successive approximation). + */ + +METHODDEF(int) +encode_mcu_DC_first (j_compress_ptr cinfo, JBLOCKROW *MCU_data) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int temp, temp2; + int nbits; + int blkn, ci; + int Al = cinfo->Al; + JBLOCKROW block; + jpeg_component_info * compptr; + ISHIFT_TEMPS + + entropy->next_output_byte = cinfo->dest->next_output_byte; + entropy->free_in_buffer = cinfo->dest->free_in_buffer; + + /* Emit restart marker if needed */ + if (cinfo->restart_interval) + if (entropy->restarts_to_go == 0) + emit_restart(entropy, entropy->next_restart_num); + + /* Encode the MCU data blocks */ + for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) { + block = MCU_data[blkn]; + ci = cinfo->MCU_membership[blkn]; + compptr = cinfo->cur_comp_info[ci]; + + /* Compute the DC value after the required point transform by Al. + * This is simply an arithmetic right shift. + */ + temp2 = IRIGHT_SHIFT((int) ((*block)[0]), Al); + + /* DC differences are figured on the point-transformed values. */ + temp = temp2 - entropy->last_dc_val[ci]; + entropy->last_dc_val[ci] = temp2; + + /* Encode the DC coefficient difference per section G.1.2.1 */ + temp2 = temp; + if (temp < 0) { + temp = -temp; /* temp is abs value of input */ + /* For a negative input, want temp2 = bitwise complement of abs(input) */ + /* This code assumes we are on a two's complement machine */ + temp2--; + } + + /* Find the number of bits needed for the magnitude of the coefficient */ + nbits = 0; + while (temp) { + nbits++; + temp >>= 1; + } + /* Check for out-of-range coefficient values. + * Since we're encoding a difference, the range limit is twice as much. + */ + if (nbits > MAX_COEF_BITS+1) + ERREXIT(cinfo, JERR_BAD_DCT_COEF); + + /* Count/emit the Huffman-coded symbol for the number of bits */ + emit_symbol(entropy, compptr->dc_tbl_no, nbits); + + /* Emit that number of bits of the value, if positive, */ + /* or the complement of its magnitude, if negative. */ + if (nbits) /* emit_bits rejects calls with size 0 */ + emit_bits(entropy, (unsigned int) temp2, nbits); + } + + cinfo->dest->next_output_byte = entropy->next_output_byte; + cinfo->dest->free_in_buffer = entropy->free_in_buffer; + + /* Update restart-interval state too */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) { + entropy->restarts_to_go = cinfo->restart_interval; + entropy->next_restart_num++; + entropy->next_restart_num &= 7; + } + entropy->restarts_to_go--; + } + + return TRUE; +} + + +/* + * MCU encoding for AC initial scan (either spectral selection, + * or first pass of successive approximation). + */ + +METHODDEF(int) +encode_mcu_AC_first (j_compress_ptr cinfo, JBLOCKROW *MCU_data) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int temp, temp2; + int nbits; + int r, k; + int Se = cinfo->Se; + int Al = cinfo->Al; + JBLOCKROW block; + + entropy->next_output_byte = cinfo->dest->next_output_byte; + entropy->free_in_buffer = cinfo->dest->free_in_buffer; + + /* Emit restart marker if needed */ + if (cinfo->restart_interval) + if (entropy->restarts_to_go == 0) + emit_restart(entropy, entropy->next_restart_num); + + /* Encode the MCU data block */ + block = MCU_data[0]; + + /* Encode the AC coefficients per section G.1.2.2, fig. G.3 */ + + r = 0; /* r = run length of zeros */ + + for (k = cinfo->Ss; k <= Se; k++) { + if ((temp = (*block)[jpeg_natural_order[k]]) == 0) { + r++; + continue; + } + /* We must apply the point transform by Al. For AC coefficients this + * is an integer division with rounding towards 0. To do this portably + * in C, we shift after obtaining the absolute value; so the code is + * interwoven with finding the abs value (temp) and output bits (temp2). + */ + if (temp < 0) { + temp = -temp; /* temp is abs value of input */ + temp >>= Al; /* apply the point transform */ + /* For a negative coef, want temp2 = bitwise complement of abs(coef) */ + temp2 = ~temp; + } else { + temp >>= Al; /* apply the point transform */ + temp2 = temp; + } + /* Watch out for case that nonzero coef is zero after point transform */ + if (temp == 0) { + r++; + continue; + } + + /* Emit any pending EOBRUN */ + if (entropy->EOBRUN > 0) + emit_eobrun(entropy); + /* if run length > 15, must emit special run-length-16 codes (0xF0) */ + while (r > 15) { + emit_symbol(entropy, entropy->ac_tbl_no, 0xF0); + r -= 16; + } + + /* Find the number of bits needed for the magnitude of the coefficient */ + nbits = 1; /* there must be at least one 1 bit */ + while ((temp >>= 1)) + nbits++; + /* Check for out-of-range coefficient values */ + if (nbits > MAX_COEF_BITS) + ERREXIT(cinfo, JERR_BAD_DCT_COEF); + + /* Count/emit Huffman symbol for run length / number of bits */ + emit_symbol(entropy, entropy->ac_tbl_no, (r << 4) + nbits); + + /* Emit that number of bits of the value, if positive, */ + /* or the complement of its magnitude, if negative. */ + emit_bits(entropy, (unsigned int) temp2, nbits); + + r = 0; /* reset zero run length */ + } + + if (r > 0) { /* If there are trailing zeroes, */ + entropy->EOBRUN++; /* count an EOB */ + if (entropy->EOBRUN == 0x7FFF) + emit_eobrun(entropy); /* force it out to avoid overflow */ + } + + cinfo->dest->next_output_byte = entropy->next_output_byte; + cinfo->dest->free_in_buffer = entropy->free_in_buffer; + + /* Update restart-interval state too */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) { + entropy->restarts_to_go = cinfo->restart_interval; + entropy->next_restart_num++; + entropy->next_restart_num &= 7; + } + entropy->restarts_to_go--; + } + + return TRUE; +} + + +/* + * MCU encoding for DC successive approximation refinement scan. + * Note: we assume such scans can be multi-component, although the spec + * is not very clear on the point. + */ + +METHODDEF(int) +encode_mcu_DC_refine (j_compress_ptr cinfo, JBLOCKROW *MCU_data) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int temp; + int blkn; + int Al = cinfo->Al; + JBLOCKROW block; + + entropy->next_output_byte = cinfo->dest->next_output_byte; + entropy->free_in_buffer = cinfo->dest->free_in_buffer; + + /* Emit restart marker if needed */ + if (cinfo->restart_interval) + if (entropy->restarts_to_go == 0) + emit_restart(entropy, entropy->next_restart_num); + + /* Encode the MCU data blocks */ + for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) { + block = MCU_data[blkn]; + + /* We simply emit the Al'th bit of the DC coefficient value. */ + temp = (*block)[0]; + emit_bits(entropy, (unsigned int) (temp >> Al), 1); + } + + cinfo->dest->next_output_byte = entropy->next_output_byte; + cinfo->dest->free_in_buffer = entropy->free_in_buffer; + + /* Update restart-interval state too */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) { + entropy->restarts_to_go = cinfo->restart_interval; + entropy->next_restart_num++; + entropy->next_restart_num &= 7; + } + entropy->restarts_to_go--; + } + + return TRUE; +} + + +/* + * MCU encoding for AC successive approximation refinement scan. + */ + +METHODDEF(int) +encode_mcu_AC_refine (j_compress_ptr cinfo, JBLOCKROW *MCU_data) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int temp; + int r, k; + int EOB; + char *BR_buffer; + unsigned int BR; + int Se = cinfo->Se; + int Al = cinfo->Al; + JBLOCKROW block; + int absvalues[DCTSIZE2]; + + entropy->next_output_byte = cinfo->dest->next_output_byte; + entropy->free_in_buffer = cinfo->dest->free_in_buffer; + + /* Emit restart marker if needed */ + if (cinfo->restart_interval) + if (entropy->restarts_to_go == 0) + emit_restart(entropy, entropy->next_restart_num); + + /* Encode the MCU data block */ + block = MCU_data[0]; + + /* It is convenient to make a pre-pass to determine the transformed + * coefficients' absolute values and the EOB position. + */ + EOB = 0; + for (k = cinfo->Ss; k <= Se; k++) { + temp = (*block)[jpeg_natural_order[k]]; + /* We must apply the point transform by Al. For AC coefficients this + * is an integer division with rounding towards 0. To do this portably + * in C, we shift after obtaining the absolute value. + */ + if (temp < 0) + temp = -temp; /* temp is abs value of input */ + temp >>= Al; /* apply the point transform */ + absvalues[k] = temp; /* save abs value for main pass */ + if (temp == 1) + EOB = k; /* EOB = index of last newly-nonzero coef */ + } + + /* Encode the AC coefficients per section G.1.2.3, fig. G.7 */ + + r = 0; /* r = run length of zeros */ + BR = 0; /* BR = count of buffered bits added now */ + BR_buffer = entropy->bit_buffer + entropy->BE; /* Append bits to buffer */ + + for (k = cinfo->Ss; k <= Se; k++) { + if ((temp = absvalues[k]) == 0) { + r++; + continue; + } + + /* Emit any required ZRLs, but not if they can be folded into EOB */ + while (r > 15 && k <= EOB) { + /* emit any pending EOBRUN and the BE correction bits */ + emit_eobrun(entropy); + /* Emit ZRL */ + emit_symbol(entropy, entropy->ac_tbl_no, 0xF0); + r -= 16; + /* Emit buffered correction bits that must be associated with ZRL */ + emit_buffered_bits(entropy, BR_buffer, BR); + BR_buffer = entropy->bit_buffer; /* BE bits are gone now */ + BR = 0; + } + + /* If the coef was previously nonzero, it only needs a correction bit. + * NOTE: a straight translation of the spec's figure G.7 would suggest + * that we also need to test r > 15. But if r > 15, we can only get here + * if k > EOB, which implies that this coefficient is not 1. + */ + if (temp > 1) { + /* The correction bit is the next bit of the absolute value. */ + BR_buffer[BR++] = (char) (temp & 1); + continue; + } + + /* Emit any pending EOBRUN and the BE correction bits */ + emit_eobrun(entropy); + + /* Count/emit Huffman symbol for run length / number of bits */ + emit_symbol(entropy, entropy->ac_tbl_no, (r << 4) + 1); + + /* Emit output bit for newly-nonzero coef */ + temp = ((*block)[jpeg_natural_order[k]] < 0) ? 0 : 1; + emit_bits(entropy, (unsigned int) temp, 1); + + /* Emit buffered correction bits that must be associated with this code */ + emit_buffered_bits(entropy, BR_buffer, BR); + BR_buffer = entropy->bit_buffer; /* BE bits are gone now */ + BR = 0; + r = 0; /* reset zero run length */ + } + + if (r > 0 || BR > 0) { /* If there are trailing zeroes, */ + entropy->EOBRUN++; /* count an EOB */ + entropy->BE += BR; /* concat my correction bits to older ones */ + /* We force out the EOB if we risk either: + * 1. overflow of the EOB counter; + * 2. overflow of the correction bit buffer during the next MCU. + */ + if (entropy->EOBRUN == 0x7FFF || entropy->BE > (MAX_CORR_BITS-DCTSIZE2+1)) + emit_eobrun(entropy); + } + + cinfo->dest->next_output_byte = entropy->next_output_byte; + cinfo->dest->free_in_buffer = entropy->free_in_buffer; + + /* Update restart-interval state too */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) { + entropy->restarts_to_go = cinfo->restart_interval; + entropy->next_restart_num++; + entropy->next_restart_num &= 7; + } + entropy->restarts_to_go--; + } + + return TRUE; +} + + +/* + * Finish up at the end of a Huffman-compressed progressive scan. + */ + +METHODDEF(void) +finish_pass_phuff (j_compress_ptr cinfo) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + + entropy->next_output_byte = cinfo->dest->next_output_byte; + entropy->free_in_buffer = cinfo->dest->free_in_buffer; + + /* Flush out any buffered data */ + emit_eobrun(entropy); + flush_bits(entropy); + + cinfo->dest->next_output_byte = entropy->next_output_byte; + cinfo->dest->free_in_buffer = entropy->free_in_buffer; +} + + +/* + * Finish up a statistics-gathering pass and create the new Huffman tables. + */ + +METHODDEF(void) +finish_pass_gather_phuff (j_compress_ptr cinfo) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int is_DC_band; + int ci, tbl; + jpeg_component_info * compptr; + JHUFF_TBL **htblptr; + int did[NUM_HUFF_TBLS]; + + /* Flush out buffered data (all we care about is counting the EOB symbol) */ + emit_eobrun(entropy); + + is_DC_band = (cinfo->Ss == 0); + + /* It's important not to apply jpeg_gen_optimal_table more than once + * per table, because it clobbers the input frequency counts! + */ + MEMZERO(did, SIZEOF(did)); + + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + if (is_DC_band) { + if (cinfo->Ah != 0) /* DC refinement needs no table */ + continue; + tbl = compptr->dc_tbl_no; + } else { + tbl = compptr->ac_tbl_no; + } + if (! did[tbl]) { + if (is_DC_band) + htblptr = & cinfo->dc_huff_tbl_ptrs[tbl]; + else + htblptr = & cinfo->ac_huff_tbl_ptrs[tbl]; + if (*htblptr == NULL) + *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo); + jpeg_gen_optimal_table(cinfo, *htblptr, entropy->count_ptrs[tbl]); + did[tbl] = TRUE; + } + } +} + + +/* + * Module initialization routine for progressive Huffman entropy encoding. + */ + +GLOBAL(void) +jinit_phuff_encoder (j_compress_ptr cinfo) +{ + phuff_entropy_ptr entropy; + int i; + + entropy = (phuff_entropy_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(phuff_entropy_encoder)); + cinfo->entropy = (struct jpeg_entropy_encoder *) entropy; + entropy->pub.start_pass = start_pass_phuff; + + /* Mark tables unallocated */ + for (i = 0; i < NUM_HUFF_TBLS; i++) { + entropy->derived_tbls[i] = NULL; + entropy->count_ptrs[i] = NULL; + } + entropy->bit_buffer = NULL; /* needed only in AC refinement scan */ +} + +#endif /* C_PROGRESSIVE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jcprepct.cpp b/ml/dlib/dlib/external/libjpeg/jcprepct.cpp new file mode 100644 index 000000000..d1532c273 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcprepct.cpp @@ -0,0 +1,354 @@ +/* + * jcprepct.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains the compression preprocessing controller. + * This controller manages the color conversion, downsampling, + * and edge expansion steps. + * + * Most of the complexity here is associated with buffering input rows + * as required by the downsampler. See the comments at the head of + * jcsample.c for the downsampler's needs. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* At present, jcsample.c can request context rows only for smoothing. + * In the future, we might also need context rows for CCIR601 sampling + * or other more-complex downsampling procedures. The code to support + * context rows should be compiled only if needed. + */ +#ifdef INPUT_SMOOTHING_SUPPORTED +#define CONTEXT_ROWS_SUPPORTED +#endif + + +/* + * For the simple (no-context-row) case, we just need to buffer one + * row group's worth of pixels for the downsampling step. At the bottom of + * the image, we pad to a full row group by replicating the last pixel row. + * The downsampler's last output row is then replicated if needed to pad + * out to a full iMCU row. + * + * When providing context rows, we must buffer three row groups' worth of + * pixels. Three row groups are physically allocated, but the row pointer + * arrays are made five row groups high, with the extra pointers above and + * below "wrapping around" to point to the last and first real row groups. + * This allows the downsampler to access the proper context rows. + * At the top and bottom of the image, we create dummy context rows by + * copying the first or last real pixel row. This copying could be avoided + * by pointer hacking as is done in jdmainct.c, but it doesn't seem worth the + * trouble on the compression side. + */ + + +/* Private buffer controller object */ + +typedef struct { + struct jpeg_c_prep_controller pub; /* public fields */ + + /* Downsampling input buffer. This buffer holds color-converted data + * until we have enough to do a downsample step. + */ + JSAMPARRAY color_buf[MAX_COMPONENTS]; + + JDIMENSION rows_to_go; /* counts rows remaining in source image */ + int next_buf_row; /* index of next row to store in color_buf */ + +#ifdef CONTEXT_ROWS_SUPPORTED /* only needed for context case */ + int this_row_group; /* starting row index of group to process */ + int next_buf_stop; /* downsample when we reach this index */ +#endif +} my_prep_controller; + +typedef my_prep_controller * my_prep_ptr; + + +/* + * Initialize for a processing pass. + */ + +METHODDEF(void) +start_pass_prep (j_compress_ptr cinfo, J_BUF_MODE pass_mode) +{ + my_prep_ptr prep = (my_prep_ptr) cinfo->prep; + + if (pass_mode != JBUF_PASS_THRU) + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + + /* Initialize total-height counter for detecting bottom of image */ + prep->rows_to_go = cinfo->image_height; + /* Mark the conversion buffer empty */ + prep->next_buf_row = 0; +#ifdef CONTEXT_ROWS_SUPPORTED + /* Preset additional state variables for context mode. + * These aren't used in non-context mode, so we needn't test which mode. + */ + prep->this_row_group = 0; + /* Set next_buf_stop to stop after two row groups have been read in. */ + prep->next_buf_stop = 2 * cinfo->max_v_samp_factor; +#endif +} + + +/* + * Expand an image vertically from height input_rows to height output_rows, + * by duplicating the bottom row. + */ + +LOCAL(void) +expand_bottom_edge (JSAMPARRAY image_data, JDIMENSION num_cols, + int input_rows, int output_rows) +{ + int row; + + for (row = input_rows; row < output_rows; row++) { + jcopy_sample_rows(image_data, input_rows-1, image_data, row, + 1, num_cols); + } +} + + +/* + * Process some data in the simple no-context case. + * + * Preprocessor output data is counted in "row groups". A row group + * is defined to be v_samp_factor sample rows of each component. + * Downsampling will produce this much data from each max_v_samp_factor + * input rows. + */ + +METHODDEF(void) +pre_process_data (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JDIMENSION *in_row_ctr, + JDIMENSION in_rows_avail, + JSAMPIMAGE output_buf, JDIMENSION *out_row_group_ctr, + JDIMENSION out_row_groups_avail) +{ + my_prep_ptr prep = (my_prep_ptr) cinfo->prep; + int numrows, ci; + JDIMENSION inrows; + jpeg_component_info * compptr; + + while (*in_row_ctr < in_rows_avail && + *out_row_group_ctr < out_row_groups_avail) { + /* Do color conversion to fill the conversion buffer. */ + inrows = in_rows_avail - *in_row_ctr; + numrows = cinfo->max_v_samp_factor - prep->next_buf_row; + numrows = (int) MIN((JDIMENSION) numrows, inrows); + (*cinfo->cconvert->color_convert) (cinfo, input_buf + *in_row_ctr, + prep->color_buf, + (JDIMENSION) prep->next_buf_row, + numrows); + *in_row_ctr += numrows; + prep->next_buf_row += numrows; + prep->rows_to_go -= numrows; + /* If at bottom of image, pad to fill the conversion buffer. */ + if (prep->rows_to_go == 0 && + prep->next_buf_row < cinfo->max_v_samp_factor) { + for (ci = 0; ci < cinfo->num_components; ci++) { + expand_bottom_edge(prep->color_buf[ci], cinfo->image_width, + prep->next_buf_row, cinfo->max_v_samp_factor); + } + prep->next_buf_row = cinfo->max_v_samp_factor; + } + /* If we've filled the conversion buffer, empty it. */ + if (prep->next_buf_row == cinfo->max_v_samp_factor) { + (*cinfo->downsample->downsample) (cinfo, + prep->color_buf, (JDIMENSION) 0, + output_buf, *out_row_group_ctr); + prep->next_buf_row = 0; + (*out_row_group_ctr)++; + } + /* If at bottom of image, pad the output to a full iMCU height. + * Note we assume the caller is providing a one-iMCU-height output buffer! + */ + if (prep->rows_to_go == 0 && + *out_row_group_ctr < out_row_groups_avail) { + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + expand_bottom_edge(output_buf[ci], + compptr->width_in_blocks * DCTSIZE, + (int) (*out_row_group_ctr * compptr->v_samp_factor), + (int) (out_row_groups_avail * compptr->v_samp_factor)); + } + *out_row_group_ctr = out_row_groups_avail; + break; /* can exit outer loop without test */ + } + } +} + + +#ifdef CONTEXT_ROWS_SUPPORTED + +/* + * Process some data in the context case. + */ + +METHODDEF(void) +pre_process_context (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JDIMENSION *in_row_ctr, + JDIMENSION in_rows_avail, + JSAMPIMAGE output_buf, JDIMENSION *out_row_group_ctr, + JDIMENSION out_row_groups_avail) +{ + my_prep_ptr prep = (my_prep_ptr) cinfo->prep; + int numrows, ci; + int buf_height = cinfo->max_v_samp_factor * 3; + JDIMENSION inrows; + + while (*out_row_group_ctr < out_row_groups_avail) { + if (*in_row_ctr < in_rows_avail) { + /* Do color conversion to fill the conversion buffer. */ + inrows = in_rows_avail - *in_row_ctr; + numrows = prep->next_buf_stop - prep->next_buf_row; + numrows = (int) MIN((JDIMENSION) numrows, inrows); + (*cinfo->cconvert->color_convert) (cinfo, input_buf + *in_row_ctr, + prep->color_buf, + (JDIMENSION) prep->next_buf_row, + numrows); + /* Pad at top of image, if first time through */ + if (prep->rows_to_go == cinfo->image_height) { + for (ci = 0; ci < cinfo->num_components; ci++) { + int row; + for (row = 1; row <= cinfo->max_v_samp_factor; row++) { + jcopy_sample_rows(prep->color_buf[ci], 0, + prep->color_buf[ci], -row, + 1, cinfo->image_width); + } + } + } + *in_row_ctr += numrows; + prep->next_buf_row += numrows; + prep->rows_to_go -= numrows; + } else { + /* Return for more data, unless we are at the bottom of the image. */ + if (prep->rows_to_go != 0) + break; + /* When at bottom of image, pad to fill the conversion buffer. */ + if (prep->next_buf_row < prep->next_buf_stop) { + for (ci = 0; ci < cinfo->num_components; ci++) { + expand_bottom_edge(prep->color_buf[ci], cinfo->image_width, + prep->next_buf_row, prep->next_buf_stop); + } + prep->next_buf_row = prep->next_buf_stop; + } + } + /* If we've gotten enough data, downsample a row group. */ + if (prep->next_buf_row == prep->next_buf_stop) { + (*cinfo->downsample->downsample) (cinfo, + prep->color_buf, + (JDIMENSION) prep->this_row_group, + output_buf, *out_row_group_ctr); + (*out_row_group_ctr)++; + /* Advance pointers with wraparound as necessary. */ + prep->this_row_group += cinfo->max_v_samp_factor; + if (prep->this_row_group >= buf_height) + prep->this_row_group = 0; + if (prep->next_buf_row >= buf_height) + prep->next_buf_row = 0; + prep->next_buf_stop = prep->next_buf_row + cinfo->max_v_samp_factor; + } + } +} + + +/* + * Create the wrapped-around downsampling input buffer needed for context mode. + */ + +LOCAL(void) +create_context_buffer (j_compress_ptr cinfo) +{ + my_prep_ptr prep = (my_prep_ptr) cinfo->prep; + int rgroup_height = cinfo->max_v_samp_factor; + int ci, i; + jpeg_component_info * compptr; + JSAMPARRAY true_buffer, fake_buffer; + + /* Grab enough space for fake row pointers for all the components; + * we need five row groups' worth of pointers for each component. + */ + fake_buffer = (JSAMPARRAY) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (cinfo->num_components * 5 * rgroup_height) * + SIZEOF(JSAMPROW)); + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Allocate the actual buffer space (3 row groups) for this component. + * We make the buffer wide enough to allow the downsampler to edge-expand + * horizontally within the buffer, if it so chooses. + */ + true_buffer = (*cinfo->mem->alloc_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + (JDIMENSION) (((long) compptr->width_in_blocks * DCTSIZE * + cinfo->max_h_samp_factor) / compptr->h_samp_factor), + (JDIMENSION) (3 * rgroup_height)); + /* Copy true buffer row pointers into the middle of the fake row array */ + MEMCOPY(fake_buffer + rgroup_height, true_buffer, + 3 * rgroup_height * SIZEOF(JSAMPROW)); + /* Fill in the above and below wraparound pointers */ + for (i = 0; i < rgroup_height; i++) { + fake_buffer[i] = true_buffer[2 * rgroup_height + i]; + fake_buffer[4 * rgroup_height + i] = true_buffer[i]; + } + prep->color_buf[ci] = fake_buffer + rgroup_height; + fake_buffer += 5 * rgroup_height; /* point to space for next component */ + } +} + +#endif /* CONTEXT_ROWS_SUPPORTED */ + + +/* + * Initialize preprocessing controller. + */ + +GLOBAL(void) +jinit_c_prep_controller (j_compress_ptr cinfo, int need_full_buffer) +{ + my_prep_ptr prep; + int ci; + jpeg_component_info * compptr; + + if (need_full_buffer) /* safety check */ + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + + prep = (my_prep_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_prep_controller)); + cinfo->prep = (struct jpeg_c_prep_controller *) prep; + prep->pub.start_pass = start_pass_prep; + + /* Allocate the color conversion buffer. + * We make the buffer wide enough to allow the downsampler to edge-expand + * horizontally within the buffer, if it so chooses. + */ + if (cinfo->downsample->need_context_rows) { + /* Set up to provide context rows */ +#ifdef CONTEXT_ROWS_SUPPORTED + prep->pub.pre_process_data = pre_process_context; + create_context_buffer(cinfo); +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } else { + /* No context, just make it tall enough for one row group */ + prep->pub.pre_process_data = pre_process_data; + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + prep->color_buf[ci] = (*cinfo->mem->alloc_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + (JDIMENSION) (((long) compptr->width_in_blocks * DCTSIZE * + cinfo->max_h_samp_factor) / compptr->h_samp_factor), + (JDIMENSION) cinfo->max_v_samp_factor); + } + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jcsample.cpp b/ml/dlib/dlib/external/libjpeg/jcsample.cpp new file mode 100644 index 000000000..b73270120 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jcsample.cpp @@ -0,0 +1,519 @@ +/* + * jcsample.c + * + * Copyright (C) 1991-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains downsampling routines. + * + * Downsampling input data is counted in "row groups". A row group + * is defined to be max_v_samp_factor pixel rows of each component, + * from which the downsampler produces v_samp_factor sample rows. + * A single row group is processed in each call to the downsampler module. + * + * The downsampler is responsible for edge-expansion of its output data + * to fill an integral number of DCT blocks horizontally. The source buffer + * may be modified if it is helpful for this purpose (the source buffer is + * allocated wide enough to correspond to the desired output width). + * The caller (the prep controller) is responsible for vertical padding. + * + * The downsampler may request "context rows" by setting need_context_rows + * during startup. In this case, the input arrays will contain at least + * one row group's worth of pixels above and below the passed-in data; + * the caller will create dummy rows at image top and bottom by replicating + * the first or last real pixel row. + * + * An excellent reference for image resampling is + * Digital Image Warping, George Wolberg, 1990. + * Pub. by IEEE Computer Society Press, Los Alamitos, CA. ISBN 0-8186-8944-7. + * + * The downsampling algorithm used here is a simple average of the source + * pixels covered by the output pixel. The hi-falutin sampling literature + * refers to this as a "box filter". In general the characteristics of a box + * filter are not very good, but for the specific cases we normally use (1:1 + * and 2:1 ratios) the box is equivalent to a "triangle filter" which is not + * nearly so bad. If you intend to use other sampling ratios, you'd be well + * advised to improve this code. + * + * A simple input-smoothing capability is provided. This is mainly intended + * for cleaning up color-dithered GIF input files (if you find it inadequate, + * we suggest using an external filtering program such as pnmconvol). When + * enabled, each input pixel P is replaced by a weighted sum of itself and its + * eight neighbors. P's weight is 1-8*SF and each neighbor's weight is SF, + * where SF = (smoothing_factor / 1024). + * Currently, smoothing is only supported for 2h2v sampling factors. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Pointer to routine to downsample a single component */ +typedef JMETHOD(void, downsample1_ptr, + (j_compress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY output_data)); + +/* Private subobject */ + +typedef struct { + struct jpeg_downsampler pub; /* public fields */ + + /* Downsampling method pointers, one per component */ + downsample1_ptr methods[MAX_COMPONENTS]; +} my_downsampler; + +typedef my_downsampler * my_downsample_ptr; + + +/* + * Initialize for a downsampling pass. + */ + +METHODDEF(void) +start_pass_downsample (j_compress_ptr )//cinfo) +{ + /* no work for now */ +} + + +/* + * Expand a component horizontally from width input_cols to width output_cols, + * by duplicating the rightmost samples. + */ + +LOCAL(void) +expand_right_edge (JSAMPARRAY image_data, int num_rows, + JDIMENSION input_cols, JDIMENSION output_cols) +{ + JSAMPROW ptr; + JSAMPLE pixval; + int count; + int row; + int numcols = (int) (output_cols - input_cols); + + if (numcols > 0) { + for (row = 0; row < num_rows; row++) { + ptr = image_data[row] + input_cols; + pixval = ptr[-1]; /* don't need GETJSAMPLE() here */ + for (count = numcols; count > 0; count--) + *ptr++ = pixval; + } + } +} + + +/* + * Do downsampling for a whole row group (all components). + * + * In this version we simply downsample each component independently. + */ + +METHODDEF(void) +sep_downsample (j_compress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION in_row_index, + JSAMPIMAGE output_buf, JDIMENSION out_row_group_index) +{ + my_downsample_ptr downsample = (my_downsample_ptr) cinfo->downsample; + int ci; + jpeg_component_info * compptr; + JSAMPARRAY in_ptr, out_ptr; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + in_ptr = input_buf[ci] + in_row_index; + out_ptr = output_buf[ci] + (out_row_group_index * compptr->v_samp_factor); + (*downsample->methods[ci]) (cinfo, compptr, in_ptr, out_ptr); + } +} + + +/* + * Downsample pixel values of a single component. + * One row group is processed per call. + * This version handles arbitrary integral sampling ratios, without smoothing. + * Note that this version is not actually used for customary sampling ratios. + */ + +METHODDEF(void) +int_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY output_data) +{ + int inrow, outrow, h_expand, v_expand, numpix, numpix2, h, v; + JDIMENSION outcol, outcol_h; /* outcol_h == outcol*h_expand */ + JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE; + JSAMPROW inptr, outptr; + long outvalue; + + h_expand = cinfo->max_h_samp_factor / compptr->h_samp_factor; + v_expand = cinfo->max_v_samp_factor / compptr->v_samp_factor; + numpix = h_expand * v_expand; + numpix2 = numpix/2; + + /* Expand input data enough to let all the output samples be generated + * by the standard loop. Special-casing padded output would be more + * efficient. + */ + expand_right_edge(input_data, cinfo->max_v_samp_factor, + cinfo->image_width, output_cols * h_expand); + + inrow = 0; + for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) { + outptr = output_data[outrow]; + for (outcol = 0, outcol_h = 0; outcol < output_cols; + outcol++, outcol_h += h_expand) { + outvalue = 0; + for (v = 0; v < v_expand; v++) { + inptr = input_data[inrow+v] + outcol_h; + for (h = 0; h < h_expand; h++) { + outvalue += (long) GETJSAMPLE(*inptr++); + } + } + *outptr++ = (JSAMPLE) ((outvalue + numpix2) / numpix); + } + inrow += v_expand; + } +} + + +/* + * Downsample pixel values of a single component. + * This version handles the special case of a full-size component, + * without smoothing. + */ + +METHODDEF(void) +fullsize_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY output_data) +{ + /* Copy the data */ + jcopy_sample_rows(input_data, 0, output_data, 0, + cinfo->max_v_samp_factor, cinfo->image_width); + /* Edge-expand */ + expand_right_edge(output_data, cinfo->max_v_samp_factor, + cinfo->image_width, compptr->width_in_blocks * DCTSIZE); +} + + +/* + * Downsample pixel values of a single component. + * This version handles the common case of 2:1 horizontal and 1:1 vertical, + * without smoothing. + * + * A note about the "bias" calculations: when rounding fractional values to + * integer, we do not want to always round 0.5 up to the next integer. + * If we did that, we'd introduce a noticeable bias towards larger values. + * Instead, this code is arranged so that 0.5 will be rounded up or down at + * alternate pixel locations (a simple ordered dither pattern). + */ + +METHODDEF(void) +h2v1_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY output_data) +{ + int outrow; + JDIMENSION outcol; + JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE; + JSAMPROW inptr, outptr; + int bias; + + /* Expand input data enough to let all the output samples be generated + * by the standard loop. Special-casing padded output would be more + * efficient. + */ + expand_right_edge(input_data, cinfo->max_v_samp_factor, + cinfo->image_width, output_cols * 2); + + for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) { + outptr = output_data[outrow]; + inptr = input_data[outrow]; + bias = 0; /* bias = 0,1,0,1,... for successive samples */ + for (outcol = 0; outcol < output_cols; outcol++) { + *outptr++ = (JSAMPLE) ((GETJSAMPLE(*inptr) + GETJSAMPLE(inptr[1]) + + bias) >> 1); + bias ^= 1; /* 0=>1, 1=>0 */ + inptr += 2; + } + } +} + + +/* + * Downsample pixel values of a single component. + * This version handles the standard case of 2:1 horizontal and 2:1 vertical, + * without smoothing. + */ + +METHODDEF(void) +h2v2_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY output_data) +{ + int inrow, outrow; + JDIMENSION outcol; + JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE; + JSAMPROW inptr0, inptr1, outptr; + int bias; + + /* Expand input data enough to let all the output samples be generated + * by the standard loop. Special-casing padded output would be more + * efficient. + */ + expand_right_edge(input_data, cinfo->max_v_samp_factor, + cinfo->image_width, output_cols * 2); + + inrow = 0; + for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) { + outptr = output_data[outrow]; + inptr0 = input_data[inrow]; + inptr1 = input_data[inrow+1]; + bias = 1; /* bias = 1,2,1,2,... for successive samples */ + for (outcol = 0; outcol < output_cols; outcol++) { + *outptr++ = (JSAMPLE) ((GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[1]) + + GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[1]) + + bias) >> 2); + bias ^= 3; /* 1=>2, 2=>1 */ + inptr0 += 2; inptr1 += 2; + } + inrow += 2; + } +} + + +#ifdef INPUT_SMOOTHING_SUPPORTED + +/* + * Downsample pixel values of a single component. + * This version handles the standard case of 2:1 horizontal and 2:1 vertical, + * with smoothing. One row of context is required. + */ + +METHODDEF(void) +h2v2_smooth_downsample (j_compress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY output_data) +{ + int inrow, outrow; + JDIMENSION colctr; + JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE; + JSAMPROW inptr0, inptr1, above_ptr, below_ptr, outptr; + long membersum, neighsum, memberscale, neighscale; + + /* Expand input data enough to let all the output samples be generated + * by the standard loop. Special-casing padded output would be more + * efficient. + */ + expand_right_edge(input_data - 1, cinfo->max_v_samp_factor + 2, + cinfo->image_width, output_cols * 2); + + /* We don't bother to form the individual "smoothed" input pixel values; + * we can directly compute the output which is the average of the four + * smoothed values. Each of the four member pixels contributes a fraction + * (1-8*SF) to its own smoothed image and a fraction SF to each of the three + * other smoothed pixels, therefore a total fraction (1-5*SF)/4 to the final + * output. The four corner-adjacent neighbor pixels contribute a fraction + * SF to just one smoothed pixel, or SF/4 to the final output; while the + * eight edge-adjacent neighbors contribute SF to each of two smoothed + * pixels, or SF/2 overall. In order to use integer arithmetic, these + * factors are scaled by 2^16 = 65536. + * Also recall that SF = smoothing_factor / 1024. + */ + + memberscale = 16384 - cinfo->smoothing_factor * 80; /* scaled (1-5*SF)/4 */ + neighscale = cinfo->smoothing_factor * 16; /* scaled SF/4 */ + + inrow = 0; + for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) { + outptr = output_data[outrow]; + inptr0 = input_data[inrow]; + inptr1 = input_data[inrow+1]; + above_ptr = input_data[inrow-1]; + below_ptr = input_data[inrow+2]; + + /* Special case for first column: pretend column -1 is same as column 0 */ + membersum = GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[1]) + + GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[1]); + neighsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(above_ptr[1]) + + GETJSAMPLE(*below_ptr) + GETJSAMPLE(below_ptr[1]) + + GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[2]) + + GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[2]); + neighsum += neighsum; + neighsum += GETJSAMPLE(*above_ptr) + GETJSAMPLE(above_ptr[2]) + + GETJSAMPLE(*below_ptr) + GETJSAMPLE(below_ptr[2]); + membersum = membersum * memberscale + neighsum * neighscale; + *outptr++ = (JSAMPLE) ((membersum + 32768) >> 16); + inptr0 += 2; inptr1 += 2; above_ptr += 2; below_ptr += 2; + + for (colctr = output_cols - 2; colctr > 0; colctr--) { + /* sum of pixels directly mapped to this output element */ + membersum = GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[1]) + + GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[1]); + /* sum of edge-neighbor pixels */ + neighsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(above_ptr[1]) + + GETJSAMPLE(*below_ptr) + GETJSAMPLE(below_ptr[1]) + + GETJSAMPLE(inptr0[-1]) + GETJSAMPLE(inptr0[2]) + + GETJSAMPLE(inptr1[-1]) + GETJSAMPLE(inptr1[2]); + /* The edge-neighbors count twice as much as corner-neighbors */ + neighsum += neighsum; + /* Add in the corner-neighbors */ + neighsum += GETJSAMPLE(above_ptr[-1]) + GETJSAMPLE(above_ptr[2]) + + GETJSAMPLE(below_ptr[-1]) + GETJSAMPLE(below_ptr[2]); + /* form final output scaled up by 2^16 */ + membersum = membersum * memberscale + neighsum * neighscale; + /* round, descale and output it */ + *outptr++ = (JSAMPLE) ((membersum + 32768) >> 16); + inptr0 += 2; inptr1 += 2; above_ptr += 2; below_ptr += 2; + } + + /* Special case for last column */ + membersum = GETJSAMPLE(*inptr0) + GETJSAMPLE(inptr0[1]) + + GETJSAMPLE(*inptr1) + GETJSAMPLE(inptr1[1]); + neighsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(above_ptr[1]) + + GETJSAMPLE(*below_ptr) + GETJSAMPLE(below_ptr[1]) + + GETJSAMPLE(inptr0[-1]) + GETJSAMPLE(inptr0[1]) + + GETJSAMPLE(inptr1[-1]) + GETJSAMPLE(inptr1[1]); + neighsum += neighsum; + neighsum += GETJSAMPLE(above_ptr[-1]) + GETJSAMPLE(above_ptr[1]) + + GETJSAMPLE(below_ptr[-1]) + GETJSAMPLE(below_ptr[1]); + membersum = membersum * memberscale + neighsum * neighscale; + *outptr = (JSAMPLE) ((membersum + 32768) >> 16); + + inrow += 2; + } +} + + +/* + * Downsample pixel values of a single component. + * This version handles the special case of a full-size component, + * with smoothing. One row of context is required. + */ + +METHODDEF(void) +fullsize_smooth_downsample (j_compress_ptr cinfo, jpeg_component_info *compptr, + JSAMPARRAY input_data, JSAMPARRAY output_data) +{ + int outrow; + JDIMENSION colctr; + JDIMENSION output_cols = compptr->width_in_blocks * DCTSIZE; + JSAMPROW inptr, above_ptr, below_ptr, outptr; + long membersum, neighsum, memberscale, neighscale; + int colsum, lastcolsum, nextcolsum; + + /* Expand input data enough to let all the output samples be generated + * by the standard loop. Special-casing padded output would be more + * efficient. + */ + expand_right_edge(input_data - 1, cinfo->max_v_samp_factor + 2, + cinfo->image_width, output_cols); + + /* Each of the eight neighbor pixels contributes a fraction SF to the + * smoothed pixel, while the main pixel contributes (1-8*SF). In order + * to use integer arithmetic, these factors are multiplied by 2^16 = 65536. + * Also recall that SF = smoothing_factor / 1024. + */ + + memberscale = 65536L - cinfo->smoothing_factor * 512L; /* scaled 1-8*SF */ + neighscale = cinfo->smoothing_factor * 64; /* scaled SF */ + + for (outrow = 0; outrow < compptr->v_samp_factor; outrow++) { + outptr = output_data[outrow]; + inptr = input_data[outrow]; + above_ptr = input_data[outrow-1]; + below_ptr = input_data[outrow+1]; + + /* Special case for first column */ + colsum = GETJSAMPLE(*above_ptr++) + GETJSAMPLE(*below_ptr++) + + GETJSAMPLE(*inptr); + membersum = GETJSAMPLE(*inptr++); + nextcolsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(*below_ptr) + + GETJSAMPLE(*inptr); + neighsum = colsum + (colsum - membersum) + nextcolsum; + membersum = membersum * memberscale + neighsum * neighscale; + *outptr++ = (JSAMPLE) ((membersum + 32768) >> 16); + lastcolsum = colsum; colsum = nextcolsum; + + for (colctr = output_cols - 2; colctr > 0; colctr--) { + membersum = GETJSAMPLE(*inptr++); + above_ptr++; below_ptr++; + nextcolsum = GETJSAMPLE(*above_ptr) + GETJSAMPLE(*below_ptr) + + GETJSAMPLE(*inptr); + neighsum = lastcolsum + (colsum - membersum) + nextcolsum; + membersum = membersum * memberscale + neighsum * neighscale; + *outptr++ = (JSAMPLE) ((membersum + 32768) >> 16); + lastcolsum = colsum; colsum = nextcolsum; + } + + /* Special case for last column */ + membersum = GETJSAMPLE(*inptr); + neighsum = lastcolsum + (colsum - membersum) + colsum; + membersum = membersum * memberscale + neighsum * neighscale; + *outptr = (JSAMPLE) ((membersum + 32768) >> 16); + + } +} + +#endif /* INPUT_SMOOTHING_SUPPORTED */ + + +/* + * Module initialization routine for downsampling. + * Note that we must select a routine for each component. + */ + +GLOBAL(void) +jinit_downsampler (j_compress_ptr cinfo) +{ + my_downsample_ptr downsample; + int ci; + jpeg_component_info * compptr; + int smoothok = TRUE; + + downsample = (my_downsample_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_downsampler)); + cinfo->downsample = (struct jpeg_downsampler *) downsample; + downsample->pub.start_pass = start_pass_downsample; + downsample->pub.downsample = sep_downsample; + downsample->pub.need_context_rows = FALSE; + + if (cinfo->CCIR601_sampling) + ERREXIT(cinfo, JERR_CCIR601_NOTIMPL); + + /* Verify we can handle the sampling factors, and set up method pointers */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + if (compptr->h_samp_factor == cinfo->max_h_samp_factor && + compptr->v_samp_factor == cinfo->max_v_samp_factor) { +#ifdef INPUT_SMOOTHING_SUPPORTED + if (cinfo->smoothing_factor) { + downsample->methods[ci] = fullsize_smooth_downsample; + downsample->pub.need_context_rows = TRUE; + } else +#endif + downsample->methods[ci] = fullsize_downsample; + } else if (compptr->h_samp_factor * 2 == cinfo->max_h_samp_factor && + compptr->v_samp_factor == cinfo->max_v_samp_factor) { + smoothok = FALSE; + downsample->methods[ci] = h2v1_downsample; + } else if (compptr->h_samp_factor * 2 == cinfo->max_h_samp_factor && + compptr->v_samp_factor * 2 == cinfo->max_v_samp_factor) { +#ifdef INPUT_SMOOTHING_SUPPORTED + if (cinfo->smoothing_factor) { + downsample->methods[ci] = h2v2_smooth_downsample; + downsample->pub.need_context_rows = TRUE; + } else +#endif + downsample->methods[ci] = h2v2_downsample; + } else if ((cinfo->max_h_samp_factor % compptr->h_samp_factor) == 0 && + (cinfo->max_v_samp_factor % compptr->v_samp_factor) == 0) { + smoothok = FALSE; + downsample->methods[ci] = int_downsample; + } else + ERREXIT(cinfo, JERR_FRACT_SAMPLE_NOTIMPL); + } + +#ifdef INPUT_SMOOTHING_SUPPORTED + if (cinfo->smoothing_factor && !smoothok) + TRACEMS(cinfo, 0, JTRC_SMOOTH_NOTIMPL); +#endif +} diff --git a/ml/dlib/dlib/external/libjpeg/jdapimin.cpp b/ml/dlib/dlib/external/libjpeg/jdapimin.cpp new file mode 100644 index 000000000..3ea5bf161 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdapimin.cpp @@ -0,0 +1,395 @@ +/* + * jdapimin.c + * + * Copyright (C) 1994-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains application interface code for the decompression half + * of the JPEG library. These are the "minimum" API routines that may be + * needed in either the normal full-decompression case or the + * transcoding-only case. + * + * Most of the routines intended to be called directly by an application + * are in this file or in jdapistd.c. But also see jcomapi.c for routines + * shared by compression and decompression, and jdtrans.c for the transcoding + * case. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* + * Initialization of a JPEG decompression object. + * The error manager must already be set up (in case memory manager fails). + */ + +GLOBAL(void) +jpeg_CreateDecompress (j_decompress_ptr cinfo, int version, size_t structsize) +{ + int i; + + /* Guard against version mismatches between library and caller. */ + cinfo->mem = NULL; /* so jpeg_destroy knows mem mgr not called */ + if (version != JPEG_LIB_VERSION) + ERREXIT2(cinfo, JERR_BAD_LIB_VERSION, JPEG_LIB_VERSION, version); + if (structsize != SIZEOF(struct jpeg_decompress_struct)) + ERREXIT2(cinfo, JERR_BAD_STRUCT_SIZE, + (int) SIZEOF(struct jpeg_decompress_struct), (int) structsize); + + /* For debugging purposes, we zero the whole master structure. + * But the application has already set the err pointer, and may have set + * client_data, so we have to save and restore those fields. + * Note: if application hasn't set client_data, tools like Purify may + * complain here. + */ + { + struct jpeg_error_mgr * err = cinfo->err; + void * client_data = cinfo->client_data; /* ignore Purify complaint here */ + MEMZERO(cinfo, SIZEOF(struct jpeg_decompress_struct)); + cinfo->err = err; + cinfo->client_data = client_data; + } + cinfo->is_decompressor = TRUE; + + /* Initialize a memory manager instance for this object */ + jinit_memory_mgr((j_common_ptr) cinfo); + + /* Zero out pointers to permanent structures. */ + cinfo->progress = NULL; + cinfo->src = NULL; + + for (i = 0; i < NUM_QUANT_TBLS; i++) + cinfo->quant_tbl_ptrs[i] = NULL; + + for (i = 0; i < NUM_HUFF_TBLS; i++) { + cinfo->dc_huff_tbl_ptrs[i] = NULL; + cinfo->ac_huff_tbl_ptrs[i] = NULL; + } + + /* Initialize marker processor so application can override methods + * for COM, APPn markers before calling jpeg_read_header. + */ + cinfo->marker_list = NULL; + jinit_marker_reader(cinfo); + + /* And initialize the overall input controller. */ + jinit_input_controller(cinfo); + + /* OK, I'm ready */ + cinfo->global_state = DSTATE_START; +} + + +/* + * Destruction of a JPEG decompression object + */ + +GLOBAL(void) +jpeg_destroy_decompress (j_decompress_ptr cinfo) +{ + jpeg_destroy((j_common_ptr) cinfo); /* use common routine */ +} + + +/* + * Abort processing of a JPEG decompression operation, + * but don't destroy the object itself. + */ + +GLOBAL(void) +jpeg_abort_decompress (j_decompress_ptr cinfo) +{ + jpeg_abort((j_common_ptr) cinfo); /* use common routine */ +} + + +/* + * Set default decompression parameters. + */ + +LOCAL(void) +default_decompress_parms (j_decompress_ptr cinfo) +{ + /* Guess the input colorspace, and set output colorspace accordingly. */ + /* (Wish JPEG committee had provided a real way to specify this...) */ + /* Note application may override our guesses. */ + switch (cinfo->num_components) { + case 1: + cinfo->jpeg_color_space = JCS_GRAYSCALE; + cinfo->out_color_space = JCS_GRAYSCALE; + break; + + case 3: + if (cinfo->saw_JFIF_marker) { + cinfo->jpeg_color_space = JCS_YCbCr; /* JFIF implies YCbCr */ + } else if (cinfo->saw_Adobe_marker) { + switch (cinfo->Adobe_transform) { + case 0: + cinfo->jpeg_color_space = JCS_RGB; + break; + case 1: + cinfo->jpeg_color_space = JCS_YCbCr; + break; + default: + WARNMS1(cinfo, JWRN_ADOBE_XFORM, cinfo->Adobe_transform); + cinfo->jpeg_color_space = JCS_YCbCr; /* assume it's YCbCr */ + break; + } + } else { + /* Saw no special markers, try to guess from the component IDs */ + int cid0 = cinfo->comp_info[0].component_id; + int cid1 = cinfo->comp_info[1].component_id; + int cid2 = cinfo->comp_info[2].component_id; + + if (cid0 == 1 && cid1 == 2 && cid2 == 3) + cinfo->jpeg_color_space = JCS_YCbCr; /* assume JFIF w/out marker */ + else if (cid0 == 82 && cid1 == 71 && cid2 == 66) + cinfo->jpeg_color_space = JCS_RGB; /* ASCII 'R', 'G', 'B' */ + else { + TRACEMS3(cinfo, 1, JTRC_UNKNOWN_IDS, cid0, cid1, cid2); + cinfo->jpeg_color_space = JCS_YCbCr; /* assume it's YCbCr */ + } + } + /* Always guess RGB is proper output colorspace. */ + cinfo->out_color_space = JCS_RGB; + break; + + case 4: + if (cinfo->saw_Adobe_marker) { + switch (cinfo->Adobe_transform) { + case 0: + cinfo->jpeg_color_space = JCS_CMYK; + break; + case 2: + cinfo->jpeg_color_space = JCS_YCCK; + break; + default: + WARNMS1(cinfo, JWRN_ADOBE_XFORM, cinfo->Adobe_transform); + cinfo->jpeg_color_space = JCS_YCCK; /* assume it's YCCK */ + break; + } + } else { + /* No special markers, assume straight CMYK. */ + cinfo->jpeg_color_space = JCS_CMYK; + } + cinfo->out_color_space = JCS_CMYK; + break; + + default: + cinfo->jpeg_color_space = JCS_UNKNOWN; + cinfo->out_color_space = JCS_UNKNOWN; + break; + } + + /* Set defaults for other decompression parameters. */ + cinfo->scale_num = 1; /* 1:1 scaling */ + cinfo->scale_denom = 1; + cinfo->output_gamma = 1.0; + cinfo->buffered_image = FALSE; + cinfo->raw_data_out = FALSE; + cinfo->dct_method = JDCT_DEFAULT; + cinfo->do_fancy_upsampling = TRUE; + cinfo->do_block_smoothing = TRUE; + cinfo->quantize_colors = FALSE; + /* We set these in case application only sets quantize_colors. */ + cinfo->dither_mode = JDITHER_FS; +#ifdef QUANT_2PASS_SUPPORTED + cinfo->two_pass_quantize = TRUE; +#else + cinfo->two_pass_quantize = FALSE; +#endif + cinfo->desired_number_of_colors = 256; + cinfo->colormap = NULL; + /* Initialize for no mode change in buffered-image mode. */ + cinfo->enable_1pass_quant = FALSE; + cinfo->enable_external_quant = FALSE; + cinfo->enable_2pass_quant = FALSE; +} + + +/* + * Decompression startup: read start of JPEG datastream to see what's there. + * Need only initialize JPEG object and supply a data source before calling. + * + * This routine will read as far as the first SOS marker (ie, actual start of + * compressed data), and will save all tables and parameters in the JPEG + * object. It will also initialize the decompression parameters to default + * values, and finally return JPEG_HEADER_OK. On return, the application may + * adjust the decompression parameters and then call jpeg_start_decompress. + * (Or, if the application only wanted to determine the image parameters, + * the data need not be decompressed. In that case, call jpeg_abort or + * jpeg_destroy to release any temporary space.) + * If an abbreviated (tables only) datastream is presented, the routine will + * return JPEG_HEADER_TABLES_ONLY upon reaching EOI. The application may then + * re-use the JPEG object to read the abbreviated image datastream(s). + * It is unnecessary (but OK) to call jpeg_abort in this case. + * The JPEG_SUSPENDED return code only occurs if the data source module + * requests suspension of the decompressor. In this case the application + * should load more source data and then re-call jpeg_read_header to resume + * processing. + * If a non-suspending data source is used and require_image is TRUE, then the + * return code need not be inspected since only JPEG_HEADER_OK is possible. + * + * This routine is now just a front end to jpeg_consume_input, with some + * extra error checking. + */ + +GLOBAL(int) +jpeg_read_header (j_decompress_ptr cinfo, int require_image) +{ + int retcode; + + if (cinfo->global_state != DSTATE_START && + cinfo->global_state != DSTATE_INHEADER) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + retcode = jpeg_consume_input(cinfo); + + switch (retcode) { + case JPEG_REACHED_SOS: + retcode = JPEG_HEADER_OK; + break; + case JPEG_REACHED_EOI: + if (require_image) /* Complain if application wanted an image */ + ERREXIT(cinfo, JERR_NO_IMAGE); + /* Reset to start state; it would be safer to require the application to + * call jpeg_abort, but we can't change it now for compatibility reasons. + * A side effect is to free any temporary memory (there shouldn't be any). + */ + jpeg_abort((j_common_ptr) cinfo); /* sets state = DSTATE_START */ + retcode = JPEG_HEADER_TABLES_ONLY; + break; + case JPEG_SUSPENDED: + /* no work */ + break; + } + + return retcode; +} + + +/* + * Consume data in advance of what the decompressor requires. + * This can be called at any time once the decompressor object has + * been created and a data source has been set up. + * + * This routine is essentially a state machine that handles a couple + * of critical state-transition actions, namely initial setup and + * transition from header scanning to ready-for-start_decompress. + * All the actual input is done via the input controller's consume_input + * method. + */ + +GLOBAL(int) +jpeg_consume_input (j_decompress_ptr cinfo) +{ + int retcode = JPEG_SUSPENDED; + + /* NB: every possible DSTATE value should be listed in this switch */ + switch (cinfo->global_state) { + case DSTATE_START: + /* Start-of-datastream actions: reset appropriate modules */ + (*cinfo->inputctl->reset_input_controller) (cinfo); + /* Initialize application's data source module */ + (*cinfo->src->init_source) (cinfo); + cinfo->global_state = DSTATE_INHEADER; + /*FALLTHROUGH*/ + case DSTATE_INHEADER: + retcode = (*cinfo->inputctl->consume_input) (cinfo); + if (retcode == JPEG_REACHED_SOS) { /* Found SOS, prepare to decompress */ + /* Set up default parameters based on header data */ + default_decompress_parms(cinfo); + /* Set global state: ready for start_decompress */ + cinfo->global_state = DSTATE_READY; + } + break; + case DSTATE_READY: + /* Can't advance past first SOS until start_decompress is called */ + retcode = JPEG_REACHED_SOS; + break; + case DSTATE_PRELOAD: + case DSTATE_PRESCAN: + case DSTATE_SCANNING: + case DSTATE_RAW_OK: + case DSTATE_BUFIMAGE: + case DSTATE_BUFPOST: + case DSTATE_STOPPING: + retcode = (*cinfo->inputctl->consume_input) (cinfo); + break; + default: + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + } + return retcode; +} + + +/* + * Have we finished reading the input file? + */ + +GLOBAL(int) +jpeg_input_complete (j_decompress_ptr cinfo) +{ + /* Check for valid jpeg object */ + if (cinfo->global_state < DSTATE_START || + cinfo->global_state > DSTATE_STOPPING) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + return cinfo->inputctl->eoi_reached; +} + + +/* + * Is there more than one scan? + */ + +GLOBAL(int) +jpeg_has_multiple_scans (j_decompress_ptr cinfo) +{ + /* Only valid after jpeg_read_header completes */ + if (cinfo->global_state < DSTATE_READY || + cinfo->global_state > DSTATE_STOPPING) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + return cinfo->inputctl->has_multiple_scans; +} + + +/* + * Finish JPEG decompression. + * + * This will normally just verify the file trailer and release temp storage. + * + * Returns FALSE if suspended. The return value need be inspected only if + * a suspending data source is used. + */ + +GLOBAL(int) +jpeg_finish_decompress (j_decompress_ptr cinfo) +{ + if ((cinfo->global_state == DSTATE_SCANNING || + cinfo->global_state == DSTATE_RAW_OK) && ! cinfo->buffered_image) { + /* Terminate final pass of non-buffered mode */ + if (cinfo->output_scanline < cinfo->output_height) + ERREXIT(cinfo, JERR_TOO_LITTLE_DATA); + (*cinfo->master->finish_output_pass) (cinfo); + cinfo->global_state = DSTATE_STOPPING; + } else if (cinfo->global_state == DSTATE_BUFIMAGE) { + /* Finishing after a buffered-image operation */ + cinfo->global_state = DSTATE_STOPPING; + } else if (cinfo->global_state != DSTATE_STOPPING) { + /* STOPPING = repeat call after a suspension, anything else is error */ + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + } + /* Read until EOI */ + while (! cinfo->inputctl->eoi_reached) { + if ((*cinfo->inputctl->consume_input) (cinfo) == JPEG_SUSPENDED) + return FALSE; /* Suspend, come back later */ + } + /* Do final cleanup */ + (*cinfo->src->term_source) (cinfo); + /* We can use jpeg_abort to release memory and reset global_state */ + jpeg_abort((j_common_ptr) cinfo); + return TRUE; +} diff --git a/ml/dlib/dlib/external/libjpeg/jdapistd.cpp b/ml/dlib/dlib/external/libjpeg/jdapistd.cpp new file mode 100644 index 000000000..03d909e00 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdapistd.cpp @@ -0,0 +1,275 @@ +/* + * jdapistd.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains application interface code for the decompression half + * of the JPEG library. These are the "standard" API routines that are + * used in the normal full-decompression case. They are not used by a + * transcoding-only application. Note that if an application links in + * jpeg_start_decompress, it will end up linking in the entire decompressor. + * We thus must separate this file from jdapimin.c to avoid linking the + * whole decompression library into a transcoder. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Forward declarations */ +LOCAL(int) output_pass_setup JPP((j_decompress_ptr cinfo)); + + +/* + * Decompression initialization. + * jpeg_read_header must be completed before calling this. + * + * If a multipass operating mode was selected, this will do all but the + * last pass, and thus may take a great deal of time. + * + * Returns FALSE if suspended. The return value need be inspected only if + * a suspending data source is used. + */ + +GLOBAL(int) +jpeg_start_decompress (j_decompress_ptr cinfo) +{ + if (cinfo->global_state == DSTATE_READY) { + /* First call: initialize master control, select active modules */ + jinit_master_decompress(cinfo); + if (cinfo->buffered_image) { + /* No more work here; expecting jpeg_start_output next */ + cinfo->global_state = DSTATE_BUFIMAGE; + return TRUE; + } + cinfo->global_state = DSTATE_PRELOAD; + } + if (cinfo->global_state == DSTATE_PRELOAD) { + /* If file has multiple scans, absorb them all into the coef buffer */ + if (cinfo->inputctl->has_multiple_scans) { +#ifdef D_MULTISCAN_FILES_SUPPORTED + for (;;) { + int retcode; + /* Call progress monitor hook if present */ + if (cinfo->progress != NULL) + (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo); + /* Absorb some more input */ + retcode = (*cinfo->inputctl->consume_input) (cinfo); + if (retcode == JPEG_SUSPENDED) + return FALSE; + if (retcode == JPEG_REACHED_EOI) + break; + /* Advance progress counter if appropriate */ + if (cinfo->progress != NULL && + (retcode == JPEG_ROW_COMPLETED || retcode == JPEG_REACHED_SOS)) { + if (++cinfo->progress->pass_counter >= cinfo->progress->pass_limit) { + /* jdmaster underestimated number of scans; ratchet up one scan */ + cinfo->progress->pass_limit += (long) cinfo->total_iMCU_rows; + } + } + } +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif /* D_MULTISCAN_FILES_SUPPORTED */ + } + cinfo->output_scan_number = cinfo->input_scan_number; + } else if (cinfo->global_state != DSTATE_PRESCAN) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + /* Perform any dummy output passes, and set up for the final pass */ + return output_pass_setup(cinfo); +} + + +/* + * Set up for an output pass, and perform any dummy pass(es) needed. + * Common subroutine for jpeg_start_decompress and jpeg_start_output. + * Entry: global_state = DSTATE_PRESCAN only if previously suspended. + * Exit: If done, returns TRUE and sets global_state for proper output mode. + * If suspended, returns FALSE and sets global_state = DSTATE_PRESCAN. + */ + +LOCAL(int) +output_pass_setup (j_decompress_ptr cinfo) +{ + if (cinfo->global_state != DSTATE_PRESCAN) { + /* First call: do pass setup */ + (*cinfo->master->prepare_for_output_pass) (cinfo); + cinfo->output_scanline = 0; + cinfo->global_state = DSTATE_PRESCAN; + } + /* Loop over any required dummy passes */ + while (cinfo->master->is_dummy_pass) { +#ifdef QUANT_2PASS_SUPPORTED + /* Crank through the dummy pass */ + while (cinfo->output_scanline < cinfo->output_height) { + JDIMENSION last_scanline; + /* Call progress monitor hook if present */ + if (cinfo->progress != NULL) { + cinfo->progress->pass_counter = (long) cinfo->output_scanline; + cinfo->progress->pass_limit = (long) cinfo->output_height; + (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo); + } + /* Process some data */ + last_scanline = cinfo->output_scanline; + (*cinfo->main->process_data) (cinfo, (JSAMPARRAY) NULL, + &cinfo->output_scanline, (JDIMENSION) 0); + if (cinfo->output_scanline == last_scanline) + return FALSE; /* No progress made, must suspend */ + } + /* Finish up dummy pass, and set up for another one */ + (*cinfo->master->finish_output_pass) (cinfo); + (*cinfo->master->prepare_for_output_pass) (cinfo); + cinfo->output_scanline = 0; +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif /* QUANT_2PASS_SUPPORTED */ + } + /* Ready for application to drive output pass through + * jpeg_read_scanlines or jpeg_read_raw_data. + */ + cinfo->global_state = cinfo->raw_data_out ? DSTATE_RAW_OK : DSTATE_SCANNING; + return TRUE; +} + + +/* + * Read some scanlines of data from the JPEG decompressor. + * + * The return value will be the number of lines actually read. + * This may be less than the number requested in several cases, + * including bottom of image, data source suspension, and operating + * modes that emit multiple scanlines at a time. + * + * Note: we warn about excess calls to jpeg_read_scanlines() since + * this likely signals an application programmer error. However, + * an oversize buffer (max_lines > scanlines remaining) is not an error. + */ + +GLOBAL(JDIMENSION) +jpeg_read_scanlines (j_decompress_ptr cinfo, JSAMPARRAY scanlines, + JDIMENSION max_lines) +{ + JDIMENSION row_ctr; + + if (cinfo->global_state != DSTATE_SCANNING) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + if (cinfo->output_scanline >= cinfo->output_height) { + WARNMS(cinfo, JWRN_TOO_MUCH_DATA); + return 0; + } + + /* Call progress monitor hook if present */ + if (cinfo->progress != NULL) { + cinfo->progress->pass_counter = (long) cinfo->output_scanline; + cinfo->progress->pass_limit = (long) cinfo->output_height; + (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo); + } + + /* Process some data */ + row_ctr = 0; + (*cinfo->main->process_data) (cinfo, scanlines, &row_ctr, max_lines); + cinfo->output_scanline += row_ctr; + return row_ctr; +} + + +/* + * Alternate entry point to read raw data. + * Processes exactly one iMCU row per call, unless suspended. + */ + +GLOBAL(JDIMENSION) +jpeg_read_raw_data (j_decompress_ptr cinfo, JSAMPIMAGE data, + JDIMENSION max_lines) +{ + JDIMENSION lines_per_iMCU_row; + + if (cinfo->global_state != DSTATE_RAW_OK) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + if (cinfo->output_scanline >= cinfo->output_height) { + WARNMS(cinfo, JWRN_TOO_MUCH_DATA); + return 0; + } + + /* Call progress monitor hook if present */ + if (cinfo->progress != NULL) { + cinfo->progress->pass_counter = (long) cinfo->output_scanline; + cinfo->progress->pass_limit = (long) cinfo->output_height; + (*cinfo->progress->progress_monitor) ((j_common_ptr) cinfo); + } + + /* Verify that at least one iMCU row can be returned. */ + lines_per_iMCU_row = cinfo->max_v_samp_factor * cinfo->min_DCT_scaled_size; + if (max_lines < lines_per_iMCU_row) + ERREXIT(cinfo, JERR_BUFFER_SIZE); + + /* Decompress directly into user's buffer. */ + if (! (*cinfo->coef->decompress_data) (cinfo, data)) + return 0; /* suspension forced, can do nothing more */ + + /* OK, we processed one iMCU row. */ + cinfo->output_scanline += lines_per_iMCU_row; + return lines_per_iMCU_row; +} + + +/* Additional entry points for buffered-image mode. */ + +#ifdef D_MULTISCAN_FILES_SUPPORTED + +/* + * Initialize for an output pass in buffered-image mode. + */ + +GLOBAL(int) +jpeg_start_output (j_decompress_ptr cinfo, int scan_number) +{ + if (cinfo->global_state != DSTATE_BUFIMAGE && + cinfo->global_state != DSTATE_PRESCAN) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + /* Limit scan number to valid range */ + if (scan_number <= 0) + scan_number = 1; + if (cinfo->inputctl->eoi_reached && + scan_number > cinfo->input_scan_number) + scan_number = cinfo->input_scan_number; + cinfo->output_scan_number = scan_number; + /* Perform any dummy output passes, and set up for the real pass */ + return output_pass_setup(cinfo); +} + + +/* + * Finish up after an output pass in buffered-image mode. + * + * Returns FALSE if suspended. The return value need be inspected only if + * a suspending data source is used. + */ + +GLOBAL(int) +jpeg_finish_output (j_decompress_ptr cinfo) +{ + if ((cinfo->global_state == DSTATE_SCANNING || + cinfo->global_state == DSTATE_RAW_OK) && cinfo->buffered_image) { + /* Terminate this pass. */ + /* We do not require the whole pass to have been completed. */ + (*cinfo->master->finish_output_pass) (cinfo); + cinfo->global_state = DSTATE_BUFPOST; + } else if (cinfo->global_state != DSTATE_BUFPOST) { + /* BUFPOST = repeat call after a suspension, anything else is error */ + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + } + /* Read markers looking for SOS or EOI */ + while (cinfo->input_scan_number <= cinfo->output_scan_number && + ! cinfo->inputctl->eoi_reached) { + if ((*cinfo->inputctl->consume_input) (cinfo) == JPEG_SUSPENDED) + return FALSE; /* Suspend, come back later */ + } + cinfo->global_state = DSTATE_BUFIMAGE; + return TRUE; +} + +#endif /* D_MULTISCAN_FILES_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jdatadst.cpp b/ml/dlib/dlib/external/libjpeg/jdatadst.cpp new file mode 100644 index 000000000..afa3c83c6 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdatadst.cpp @@ -0,0 +1,151 @@ +/* + * jdatadst.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains compression data destination routines for the case of + * emitting JPEG data to a file (or any stdio stream). While these routines + * are sufficient for most applications, some will want to use a different + * destination manager. + * IMPORTANT: we assume that fwrite() will correctly transcribe an array of + * JOCTETs into 8-bit-wide elements on external storage. If char is wider + * than 8 bits on your machine, you may need to do some tweaking. + */ + +/* this is not a core library module, so it doesn't define JPEG_INTERNALS */ +#include "jinclude.h" +#include "jpeglib.h" +#include "jerror.h" + + +/* Expanded data destination object for stdio output */ + +typedef struct { + struct jpeg_destination_mgr pub; /* public fields */ + + FILE * outfile; /* target stream */ + JOCTET * buffer; /* start of buffer */ +} my_destination_mgr; + +typedef my_destination_mgr * my_dest_ptr; + +#define OUTPUT_BUF_SIZE 4096 /* choose an efficiently fwrite'able size */ + + +/* + * Initialize destination --- called by jpeg_start_compress + * before any data is actually written. + */ + +METHODDEF(void) +init_destination (j_compress_ptr cinfo) +{ + my_dest_ptr dest = (my_dest_ptr) cinfo->dest; + + /* Allocate the output buffer --- it will be released when done with image */ + dest->buffer = (JOCTET *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + OUTPUT_BUF_SIZE * SIZEOF(JOCTET)); + + dest->pub.next_output_byte = dest->buffer; + dest->pub.free_in_buffer = OUTPUT_BUF_SIZE; +} + + +/* + * Empty the output buffer --- called whenever buffer fills up. + * + * In typical applications, this should write the entire output buffer + * (ignoring the current state of next_output_byte & free_in_buffer), + * reset the pointer & count to the start of the buffer, and return TRUE + * indicating that the buffer has been dumped. + * + * In applications that need to be able to suspend compression due to output + * overrun, a FALSE return indicates that the buffer cannot be emptied now. + * In this situation, the compressor will return to its caller (possibly with + * an indication that it has not accepted all the supplied scanlines). The + * application should resume compression after it has made more room in the + * output buffer. Note that there are substantial restrictions on the use of + * suspension --- see the documentation. + * + * When suspending, the compressor will back up to a convenient restart point + * (typically the start of the current MCU). next_output_byte & free_in_buffer + * indicate where the restart point will be if the current call returns FALSE. + * Data beyond this point will be regenerated after resumption, so do not + * write it out when emptying the buffer externally. + */ + +METHODDEF(int) +empty_output_buffer (j_compress_ptr cinfo) +{ + my_dest_ptr dest = (my_dest_ptr) cinfo->dest; + + if (JFWRITE(dest->outfile, dest->buffer, OUTPUT_BUF_SIZE) != + (size_t) OUTPUT_BUF_SIZE) + ERREXIT(cinfo, JERR_FILE_WRITE); + + dest->pub.next_output_byte = dest->buffer; + dest->pub.free_in_buffer = OUTPUT_BUF_SIZE; + + return TRUE; +} + + +/* + * Terminate destination --- called by jpeg_finish_compress + * after all data has been written. Usually needs to flush buffer. + * + * NB: *not* called by jpeg_abort or jpeg_destroy; surrounding + * application must deal with any cleanup that should happen even + * for error exit. + */ + +METHODDEF(void) +term_destination (j_compress_ptr cinfo) +{ + my_dest_ptr dest = (my_dest_ptr) cinfo->dest; + size_t datacount = OUTPUT_BUF_SIZE - dest->pub.free_in_buffer; + + /* Write any data remaining in the buffer */ + if (datacount > 0) { + if (JFWRITE(dest->outfile, dest->buffer, datacount) != datacount) + ERREXIT(cinfo, JERR_FILE_WRITE); + } + fflush(dest->outfile); + /* Make sure we wrote the output file OK */ + if (ferror(dest->outfile)) + ERREXIT(cinfo, JERR_FILE_WRITE); +} + + +/* + * Prepare for output to a stdio stream. + * The caller must have already opened the stream, and is responsible + * for closing it after finishing compression. + */ + +GLOBAL(void) +jpeg_stdio_dest (j_compress_ptr cinfo, FILE * outfile) +{ + my_dest_ptr dest; + + /* The destination object is made permanent so that multiple JPEG images + * can be written to the same file without re-executing jpeg_stdio_dest. + * This makes it dangerous to use this manager and a different destination + * manager serially with the same JPEG object, because their private object + * sizes may be different. Caveat programmer. + */ + if (cinfo->dest == NULL) { /* first time for this JPEG object? */ + cinfo->dest = (struct jpeg_destination_mgr *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT, + SIZEOF(my_destination_mgr)); + } + + dest = (my_dest_ptr) cinfo->dest; + dest->pub.init_destination = init_destination; + dest->pub.empty_output_buffer = empty_output_buffer; + dest->pub.term_destination = term_destination; + dest->outfile = outfile; +} diff --git a/ml/dlib/dlib/external/libjpeg/jdatasrc.cpp b/ml/dlib/dlib/external/libjpeg/jdatasrc.cpp new file mode 100644 index 000000000..7af097f02 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdatasrc.cpp @@ -0,0 +1,212 @@ +/* + * jdatasrc.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains decompression data source routines for the case of + * reading JPEG data from a file (or any stdio stream). While these routines + * are sufficient for most applications, some will want to use a different + * source manager. + * IMPORTANT: we assume that fread() will correctly transcribe an array of + * JOCTETs from 8-bit-wide elements on external storage. If char is wider + * than 8 bits on your machine, you may need to do some tweaking. + */ + +/* this is not a core library module, so it doesn't define JPEG_INTERNALS */ +#include "jinclude.h" +#include "jpeglib.h" +#include "jerror.h" + + +/* Expanded data source object for stdio input */ + +typedef struct { + struct jpeg_source_mgr pub; /* public fields */ + + FILE * infile; /* source stream */ + JOCTET * buffer; /* start of buffer */ + int start_of_file; /* have we gotten any data yet? */ +} my_source_mgr; + +typedef my_source_mgr * my_src_ptr; + +#define INPUT_BUF_SIZE 4096 /* choose an efficiently fread'able size */ + + +/* + * Initialize source --- called by jpeg_read_header + * before any data is actually read. + */ + +METHODDEF(void) +init_source (j_decompress_ptr cinfo) +{ + my_src_ptr src = (my_src_ptr) cinfo->src; + + /* We reset the empty-input-file flag for each image, + * but we don't clear the input buffer. + * This is correct behavior for reading a series of images from one source. + */ + src->start_of_file = TRUE; +} + + +/* + * Fill the input buffer --- called whenever buffer is emptied. + * + * In typical applications, this should read fresh data into the buffer + * (ignoring the current state of next_input_byte & bytes_in_buffer), + * reset the pointer & count to the start of the buffer, and return TRUE + * indicating that the buffer has been reloaded. It is not necessary to + * fill the buffer entirely, only to obtain at least one more byte. + * + * There is no such thing as an EOF return. If the end of the file has been + * reached, the routine has a choice of ERREXIT() or inserting fake data into + * the buffer. In most cases, generating a warning message and inserting a + * fake EOI marker is the best course of action --- this will allow the + * decompressor to output however much of the image is there. However, + * the resulting error message is misleading if the real problem is an empty + * input file, so we handle that case specially. + * + * In applications that need to be able to suspend compression due to input + * not being available yet, a FALSE return indicates that no more data can be + * obtained right now, but more may be forthcoming later. In this situation, + * the decompressor will return to its caller (with an indication of the + * number of scanlines it has read, if any). The application should resume + * decompression after it has loaded more data into the input buffer. Note + * that there are substantial restrictions on the use of suspension --- see + * the documentation. + * + * When suspending, the decompressor will back up to a convenient restart point + * (typically the start of the current MCU). next_input_byte & bytes_in_buffer + * indicate where the restart point will be if the current call returns FALSE. + * Data beyond this point must be rescanned after resumption, so move it to + * the front of the buffer rather than discarding it. + */ + +METHODDEF(int) +fill_input_buffer (j_decompress_ptr cinfo) +{ + my_src_ptr src = (my_src_ptr) cinfo->src; + size_t nbytes; + + nbytes = JFREAD(src->infile, src->buffer, INPUT_BUF_SIZE); + + if (nbytes <= 0) { + if (src->start_of_file) /* Treat empty input file as fatal error */ + ERREXIT(cinfo, JERR_INPUT_EMPTY); + WARNMS(cinfo, JWRN_JPEG_EOF); + /* Insert a fake EOI marker */ + src->buffer[0] = (JOCTET) 0xFF; + src->buffer[1] = (JOCTET) JPEG_EOI; + nbytes = 2; + } + + src->pub.next_input_byte = src->buffer; + src->pub.bytes_in_buffer = nbytes; + src->start_of_file = FALSE; + + return TRUE; +} + + +/* + * Skip data --- used to skip over a potentially large amount of + * uninteresting data (such as an APPn marker). + * + * Writers of suspendable-input applications must note that skip_input_data + * is not granted the right to give a suspension return. If the skip extends + * beyond the data currently in the buffer, the buffer can be marked empty so + * that the next read will cause a fill_input_buffer call that can suspend. + * Arranging for additional bytes to be discarded before reloading the input + * buffer is the application writer's problem. + */ + +METHODDEF(void) +skip_input_data (j_decompress_ptr cinfo, long num_bytes) +{ + my_src_ptr src = (my_src_ptr) cinfo->src; + + /* Just a dumb implementation for now. Could use fseek() except + * it doesn't work on pipes. Not clear that being smart is worth + * any trouble anyway --- large skips are infrequent. + */ + if (num_bytes > 0) { + while (num_bytes > (long) src->pub.bytes_in_buffer) { + num_bytes -= (long) src->pub.bytes_in_buffer; + (void) fill_input_buffer(cinfo); + /* note we assume that fill_input_buffer will never return FALSE, + * so suspension need not be handled. + */ + } + src->pub.next_input_byte += (size_t) num_bytes; + src->pub.bytes_in_buffer -= (size_t) num_bytes; + } +} + + +/* + * An additional method that can be provided by data source modules is the + * resync_to_restart method for error recovery in the presence of RST markers. + * For the moment, this source module just uses the default resync method + * provided by the JPEG library. That method assumes that no backtracking + * is possible. + */ + + +/* + * Terminate source --- called by jpeg_finish_decompress + * after all data has been read. Often a no-op. + * + * NB: *not* called by jpeg_abort or jpeg_destroy; surrounding + * application must deal with any cleanup that should happen even + * for error exit. + */ + +METHODDEF(void) +term_source (j_decompress_ptr ) +{ + /* no work necessary here */ +} + + +/* + * Prepare for input from a stdio stream. + * The caller must have already opened the stream, and is responsible + * for closing it after finishing decompression. + */ + +GLOBAL(void) +jpeg_stdio_src (j_decompress_ptr cinfo, FILE * infile) +{ + my_src_ptr src; + + /* The source object and input buffer are made permanent so that a series + * of JPEG images can be read from the same file by calling jpeg_stdio_src + * only before the first one. (If we discarded the buffer at the end of + * one image, we'd likely lose the start of the next one.) + * This makes it unsafe to use this manager and a different source + * manager serially with the same JPEG object. Caveat programmer. + */ + if (cinfo->src == NULL) { /* first time for this JPEG object? */ + cinfo->src = (struct jpeg_source_mgr *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT, + SIZEOF(my_source_mgr)); + src = (my_src_ptr) cinfo->src; + src->buffer = (JOCTET *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT, + INPUT_BUF_SIZE * SIZEOF(JOCTET)); + } + + src = (my_src_ptr) cinfo->src; + src->pub.init_source = init_source; + src->pub.fill_input_buffer = fill_input_buffer; + src->pub.skip_input_data = skip_input_data; + src->pub.resync_to_restart = jpeg_resync_to_restart; /* use default method */ + src->pub.term_source = term_source; + src->infile = infile; + src->pub.bytes_in_buffer = 0; /* forces fill_input_buffer on first read */ + src->pub.next_input_byte = NULL; /* until buffer loaded */ +} diff --git a/ml/dlib/dlib/external/libjpeg/jdcoefct.cpp b/ml/dlib/dlib/external/libjpeg/jdcoefct.cpp new file mode 100644 index 000000000..11b618920 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdcoefct.cpp @@ -0,0 +1,736 @@ +/* + * jdcoefct.c + * + * Copyright (C) 1994-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains the coefficient buffer controller for decompression. + * This controller is the top level of the JPEG decompressor proper. + * The coefficient buffer lies between entropy decoding and inverse-DCT steps. + * + * In buffered-image mode, this controller is the interface between + * input-oriented processing and output-oriented processing. + * Also, the input side (only) is used when reading a file for transcoding. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + +/* Block smoothing is only applicable for progressive JPEG, so: */ +#ifndef D_PROGRESSIVE_SUPPORTED +#undef BLOCK_SMOOTHING_SUPPORTED +#endif + +/* Private buffer controller object */ + +typedef struct { + struct jpeg_d_coef_controller pub; /* public fields */ + + /* These variables keep track of the current location of the input side. */ + /* cinfo->input_iMCU_row is also used for this. */ + JDIMENSION MCU_ctr; /* counts MCUs processed in current row */ + int MCU_vert_offset; /* counts MCU rows within iMCU row */ + int MCU_rows_per_iMCU_row; /* number of such rows needed */ + + /* The output side's location is represented by cinfo->output_iMCU_row. */ + + /* In single-pass modes, it's sufficient to buffer just one MCU. + * We allocate a workspace of D_MAX_BLOCKS_IN_MCU coefficient blocks, + * and let the entropy decoder write into that workspace each time. + * (On 80x86, the workspace is FAR even though it's not really very big; + * this is to keep the module interfaces unchanged when a large coefficient + * buffer is necessary.) + * In multi-pass modes, this array points to the current MCU's blocks + * within the virtual arrays; it is used only by the input side. + */ + JBLOCKROW MCU_buffer[D_MAX_BLOCKS_IN_MCU]; + +#ifdef D_MULTISCAN_FILES_SUPPORTED + /* In multi-pass modes, we need a virtual block array for each component. */ + jvirt_barray_ptr whole_image[MAX_COMPONENTS]; +#endif + +#ifdef BLOCK_SMOOTHING_SUPPORTED + /* When doing block smoothing, we latch coefficient Al values here */ + int * coef_bits_latch; +#define SAVED_COEFS 6 /* we save coef_bits[0..5] */ +#endif +} my_coef_controller; + +typedef my_coef_controller * my_coef_ptr; + +/* Forward declarations */ +METHODDEF(int) decompress_onepass + JPP((j_decompress_ptr cinfo, JSAMPIMAGE output_buf)); +#ifdef D_MULTISCAN_FILES_SUPPORTED +METHODDEF(int) decompress_data + JPP((j_decompress_ptr cinfo, JSAMPIMAGE output_buf)); +#endif +#ifdef BLOCK_SMOOTHING_SUPPORTED +LOCAL(int) smoothing_ok JPP((j_decompress_ptr cinfo)); +METHODDEF(int) decompress_smooth_data + JPP((j_decompress_ptr cinfo, JSAMPIMAGE output_buf)); +#endif + + +LOCAL(void) +start_iMCU_row (j_decompress_ptr cinfo) +/* Reset within-iMCU-row counters for a new row (input side) */ +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + + /* In an interleaved scan, an MCU row is the same as an iMCU row. + * In a noninterleaved scan, an iMCU row has v_samp_factor MCU rows. + * But at the bottom of the image, process only what's left. + */ + if (cinfo->comps_in_scan > 1) { + coef->MCU_rows_per_iMCU_row = 1; + } else { + if (cinfo->input_iMCU_row < (cinfo->total_iMCU_rows-1)) + coef->MCU_rows_per_iMCU_row = cinfo->cur_comp_info[0]->v_samp_factor; + else + coef->MCU_rows_per_iMCU_row = cinfo->cur_comp_info[0]->last_row_height; + } + + coef->MCU_ctr = 0; + coef->MCU_vert_offset = 0; +} + + +/* + * Initialize for an input processing pass. + */ + +METHODDEF(void) +start_input_pass (j_decompress_ptr cinfo) +{ + cinfo->input_iMCU_row = 0; + start_iMCU_row(cinfo); +} + + +/* + * Initialize for an output processing pass. + */ + +METHODDEF(void) +start_output_pass (j_decompress_ptr cinfo) +{ +#ifdef BLOCK_SMOOTHING_SUPPORTED + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + + /* If multipass, check to see whether to use block smoothing on this pass */ + if (coef->pub.coef_arrays != NULL) { + if (cinfo->do_block_smoothing && smoothing_ok(cinfo)) + coef->pub.decompress_data = decompress_smooth_data; + else + coef->pub.decompress_data = decompress_data; + } +#endif + cinfo->output_iMCU_row = 0; +} + + +/* + * Decompress and return some data in the single-pass case. + * Always attempts to emit one fully interleaved MCU row ("iMCU" row). + * Input and output must run in lockstep since we have only a one-MCU buffer. + * Return value is JPEG_ROW_COMPLETED, JPEG_SCAN_COMPLETED, or JPEG_SUSPENDED. + * + * NB: output_buf contains a plane for each component in image, + * which we index according to the component's SOF position. + */ + +METHODDEF(int) +decompress_onepass (j_decompress_ptr cinfo, JSAMPIMAGE output_buf) +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + JDIMENSION MCU_col_num; /* index of current MCU within row */ + JDIMENSION last_MCU_col = cinfo->MCUs_per_row - 1; + JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1; + int blkn, ci, xindex, yindex, yoffset, useful_width; + JSAMPARRAY output_ptr; + JDIMENSION start_col, output_col; + jpeg_component_info *compptr; + inverse_DCT_method_ptr inverse_DCT; + + /* Loop to process as much as one whole iMCU row */ + for (yoffset = coef->MCU_vert_offset; yoffset < coef->MCU_rows_per_iMCU_row; + yoffset++) { + for (MCU_col_num = coef->MCU_ctr; MCU_col_num <= last_MCU_col; + MCU_col_num++) { + /* Try to fetch an MCU. Entropy decoder expects buffer to be zeroed. */ + jzero_far((void FAR *) coef->MCU_buffer[0], + (size_t) (cinfo->blocks_in_MCU * SIZEOF(JBLOCK))); + if (! (*cinfo->entropy->decode_mcu) (cinfo, coef->MCU_buffer)) { + /* Suspension forced; update state counters and exit */ + coef->MCU_vert_offset = yoffset; + coef->MCU_ctr = MCU_col_num; + return JPEG_SUSPENDED; + } + /* Determine where data should go in output_buf and do the IDCT thing. + * We skip dummy blocks at the right and bottom edges (but blkn gets + * incremented past them!). Note the inner loop relies on having + * allocated the MCU_buffer[] blocks sequentially. + */ + blkn = 0; /* index of current DCT block within MCU */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + /* Don't bother to IDCT an uninteresting component. */ + if (! compptr->component_needed) { + blkn += compptr->MCU_blocks; + continue; + } + inverse_DCT = cinfo->idct->inverse_DCT[compptr->component_index]; + useful_width = (MCU_col_num < last_MCU_col) ? compptr->MCU_width + : compptr->last_col_width; + output_ptr = output_buf[compptr->component_index] + + yoffset * compptr->DCT_scaled_size; + start_col = MCU_col_num * compptr->MCU_sample_width; + for (yindex = 0; yindex < compptr->MCU_height; yindex++) { + if (cinfo->input_iMCU_row < last_iMCU_row || + yoffset+yindex < compptr->last_row_height) { + output_col = start_col; + for (xindex = 0; xindex < useful_width; xindex++) { + (*inverse_DCT) (cinfo, compptr, + (JCOEFPTR) coef->MCU_buffer[blkn+xindex], + output_ptr, output_col); + output_col += compptr->DCT_scaled_size; + } + } + blkn += compptr->MCU_width; + output_ptr += compptr->DCT_scaled_size; + } + } + } + /* Completed an MCU row, but perhaps not an iMCU row */ + coef->MCU_ctr = 0; + } + /* Completed the iMCU row, advance counters for next one */ + cinfo->output_iMCU_row++; + if (++(cinfo->input_iMCU_row) < cinfo->total_iMCU_rows) { + start_iMCU_row(cinfo); + return JPEG_ROW_COMPLETED; + } + /* Completed the scan */ + (*cinfo->inputctl->finish_input_pass) (cinfo); + return JPEG_SCAN_COMPLETED; +} + + +/* + * Dummy consume-input routine for single-pass operation. + */ + +METHODDEF(int) +dummy_consume_data (j_decompress_ptr ) +{ + return JPEG_SUSPENDED; /* Always indicate nothing was done */ +} + + +#ifdef D_MULTISCAN_FILES_SUPPORTED + +/* + * Consume input data and store it in the full-image coefficient buffer. + * We read as much as one fully interleaved MCU row ("iMCU" row) per call, + * ie, v_samp_factor block rows for each component in the scan. + * Return value is JPEG_ROW_COMPLETED, JPEG_SCAN_COMPLETED, or JPEG_SUSPENDED. + */ + +METHODDEF(int) +consume_data (j_decompress_ptr cinfo) +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + JDIMENSION MCU_col_num; /* index of current MCU within row */ + int blkn, ci, xindex, yindex, yoffset; + JDIMENSION start_col; + JBLOCKARRAY buffer[MAX_COMPS_IN_SCAN]; + JBLOCKROW buffer_ptr; + jpeg_component_info *compptr; + + /* Align the virtual buffers for the components used in this scan. */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + buffer[ci] = (*cinfo->mem->access_virt_barray) + ((j_common_ptr) cinfo, coef->whole_image[compptr->component_index], + cinfo->input_iMCU_row * compptr->v_samp_factor, + (JDIMENSION) compptr->v_samp_factor, TRUE); + /* Note: entropy decoder expects buffer to be zeroed, + * but this is handled automatically by the memory manager + * because we requested a pre-zeroed array. + */ + } + + /* Loop to process one whole iMCU row */ + for (yoffset = coef->MCU_vert_offset; yoffset < coef->MCU_rows_per_iMCU_row; + yoffset++) { + for (MCU_col_num = coef->MCU_ctr; MCU_col_num < cinfo->MCUs_per_row; + MCU_col_num++) { + /* Construct list of pointers to DCT blocks belonging to this MCU */ + blkn = 0; /* index of current DCT block within MCU */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + start_col = MCU_col_num * compptr->MCU_width; + for (yindex = 0; yindex < compptr->MCU_height; yindex++) { + buffer_ptr = buffer[ci][yindex+yoffset] + start_col; + for (xindex = 0; xindex < compptr->MCU_width; xindex++) { + coef->MCU_buffer[blkn++] = buffer_ptr++; + } + } + } + /* Try to fetch the MCU. */ + if (! (*cinfo->entropy->decode_mcu) (cinfo, coef->MCU_buffer)) { + /* Suspension forced; update state counters and exit */ + coef->MCU_vert_offset = yoffset; + coef->MCU_ctr = MCU_col_num; + return JPEG_SUSPENDED; + } + } + /* Completed an MCU row, but perhaps not an iMCU row */ + coef->MCU_ctr = 0; + } + /* Completed the iMCU row, advance counters for next one */ + if (++(cinfo->input_iMCU_row) < cinfo->total_iMCU_rows) { + start_iMCU_row(cinfo); + return JPEG_ROW_COMPLETED; + } + /* Completed the scan */ + (*cinfo->inputctl->finish_input_pass) (cinfo); + return JPEG_SCAN_COMPLETED; +} + + +/* + * Decompress and return some data in the multi-pass case. + * Always attempts to emit one fully interleaved MCU row ("iMCU" row). + * Return value is JPEG_ROW_COMPLETED, JPEG_SCAN_COMPLETED, or JPEG_SUSPENDED. + * + * NB: output_buf contains a plane for each component in image. + */ + +METHODDEF(int) +decompress_data (j_decompress_ptr cinfo, JSAMPIMAGE output_buf) +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1; + JDIMENSION block_num; + int ci, block_row, block_rows; + JBLOCKARRAY buffer; + JBLOCKROW buffer_ptr; + JSAMPARRAY output_ptr; + JDIMENSION output_col; + jpeg_component_info *compptr; + inverse_DCT_method_ptr inverse_DCT; + + /* Force some input to be done if we are getting ahead of the input. */ + while (cinfo->input_scan_number < cinfo->output_scan_number || + (cinfo->input_scan_number == cinfo->output_scan_number && + cinfo->input_iMCU_row <= cinfo->output_iMCU_row)) { + if ((*cinfo->inputctl->consume_input)(cinfo) == JPEG_SUSPENDED) + return JPEG_SUSPENDED; + } + + /* OK, output from the virtual arrays. */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Don't bother to IDCT an uninteresting component. */ + if (! compptr->component_needed) + continue; + /* Align the virtual buffer for this component. */ + buffer = (*cinfo->mem->access_virt_barray) + ((j_common_ptr) cinfo, coef->whole_image[ci], + cinfo->output_iMCU_row * compptr->v_samp_factor, + (JDIMENSION) compptr->v_samp_factor, FALSE); + /* Count non-dummy DCT block rows in this iMCU row. */ + if (cinfo->output_iMCU_row < last_iMCU_row) + block_rows = compptr->v_samp_factor; + else { + /* NB: can't use last_row_height here; it is input-side-dependent! */ + block_rows = (int) (compptr->height_in_blocks % compptr->v_samp_factor); + if (block_rows == 0) block_rows = compptr->v_samp_factor; + } + inverse_DCT = cinfo->idct->inverse_DCT[ci]; + output_ptr = output_buf[ci]; + /* Loop over all DCT blocks to be processed. */ + for (block_row = 0; block_row < block_rows; block_row++) { + buffer_ptr = buffer[block_row]; + output_col = 0; + for (block_num = 0; block_num < compptr->width_in_blocks; block_num++) { + (*inverse_DCT) (cinfo, compptr, (JCOEFPTR) buffer_ptr, + output_ptr, output_col); + buffer_ptr++; + output_col += compptr->DCT_scaled_size; + } + output_ptr += compptr->DCT_scaled_size; + } + } + + if (++(cinfo->output_iMCU_row) < cinfo->total_iMCU_rows) + return JPEG_ROW_COMPLETED; + return JPEG_SCAN_COMPLETED; +} + +#endif /* D_MULTISCAN_FILES_SUPPORTED */ + + +#ifdef BLOCK_SMOOTHING_SUPPORTED + +/* + * This code applies interblock smoothing as described by section K.8 + * of the JPEG standard: the first 5 AC coefficients are estimated from + * the DC values of a DCT block and its 8 neighboring blocks. + * We apply smoothing only for progressive JPEG decoding, and only if + * the coefficients it can estimate are not yet known to full precision. + */ + +/* Natural-order array positions of the first 5 zigzag-order coefficients */ +#define Q01_POS 1 +#define Q10_POS 8 +#define Q20_POS 16 +#define Q11_POS 9 +#define Q02_POS 2 + +/* + * Determine whether block smoothing is applicable and safe. + * We also latch the current states of the coef_bits[] entries for the + * AC coefficients; otherwise, if the input side of the decompressor + * advances into a new scan, we might think the coefficients are known + * more accurately than they really are. + */ + +LOCAL(int) +smoothing_ok (j_decompress_ptr cinfo) +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + int smoothing_useful = FALSE; + int ci, coefi; + jpeg_component_info *compptr; + JQUANT_TBL * qtable; + int * coef_bits; + int * coef_bits_latch; + + if (! cinfo->progressive_mode || cinfo->coef_bits == NULL) + return FALSE; + + /* Allocate latch area if not already done */ + if (coef->coef_bits_latch == NULL) + coef->coef_bits_latch = (int *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + cinfo->num_components * + (SAVED_COEFS * SIZEOF(int))); + coef_bits_latch = coef->coef_bits_latch; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* All components' quantization values must already be latched. */ + if ((qtable = compptr->quant_table) == NULL) + return FALSE; + /* Verify DC & first 5 AC quantizers are nonzero to avoid zero-divide. */ + if (qtable->quantval[0] == 0 || + qtable->quantval[Q01_POS] == 0 || + qtable->quantval[Q10_POS] == 0 || + qtable->quantval[Q20_POS] == 0 || + qtable->quantval[Q11_POS] == 0 || + qtable->quantval[Q02_POS] == 0) + return FALSE; + /* DC values must be at least partly known for all components. */ + coef_bits = cinfo->coef_bits[ci]; + if (coef_bits[0] < 0) + return FALSE; + /* Block smoothing is helpful if some AC coefficients remain inaccurate. */ + for (coefi = 1; coefi <= 5; coefi++) { + coef_bits_latch[coefi] = coef_bits[coefi]; + if (coef_bits[coefi] != 0) + smoothing_useful = TRUE; + } + coef_bits_latch += SAVED_COEFS; + } + + return smoothing_useful; +} + + +/* + * Variant of decompress_data for use when doing block smoothing. + */ + +METHODDEF(int) +decompress_smooth_data (j_decompress_ptr cinfo, JSAMPIMAGE output_buf) +{ + my_coef_ptr coef = (my_coef_ptr) cinfo->coef; + JDIMENSION last_iMCU_row = cinfo->total_iMCU_rows - 1; + JDIMENSION block_num, last_block_column; + int ci, block_row, block_rows, access_rows; + JBLOCKARRAY buffer; + JBLOCKROW buffer_ptr, prev_block_row, next_block_row; + JSAMPARRAY output_ptr; + JDIMENSION output_col; + jpeg_component_info *compptr; + inverse_DCT_method_ptr inverse_DCT; + int first_row, last_row; + JBLOCK workspace; + int *coef_bits; + JQUANT_TBL *quanttbl; + long Q00,Q01,Q02,Q10,Q11,Q20, num; + int DC1,DC2,DC3,DC4,DC5,DC6,DC7,DC8,DC9; + int Al, pred; + + /* Force some input to be done if we are getting ahead of the input. */ + while (cinfo->input_scan_number <= cinfo->output_scan_number && + ! cinfo->inputctl->eoi_reached) { + if (cinfo->input_scan_number == cinfo->output_scan_number) { + /* If input is working on current scan, we ordinarily want it to + * have completed the current row. But if input scan is DC, + * we want it to keep one row ahead so that next block row's DC + * values are up to date. + */ + JDIMENSION delta = (cinfo->Ss == 0) ? 1 : 0; + if (cinfo->input_iMCU_row > cinfo->output_iMCU_row+delta) + break; + } + if ((*cinfo->inputctl->consume_input)(cinfo) == JPEG_SUSPENDED) + return JPEG_SUSPENDED; + } + + /* OK, output from the virtual arrays. */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Don't bother to IDCT an uninteresting component. */ + if (! compptr->component_needed) + continue; + /* Count non-dummy DCT block rows in this iMCU row. */ + if (cinfo->output_iMCU_row < last_iMCU_row) { + block_rows = compptr->v_samp_factor; + access_rows = block_rows * 2; /* this and next iMCU row */ + last_row = FALSE; + } else { + /* NB: can't use last_row_height here; it is input-side-dependent! */ + block_rows = (int) (compptr->height_in_blocks % compptr->v_samp_factor); + if (block_rows == 0) block_rows = compptr->v_samp_factor; + access_rows = block_rows; /* this iMCU row only */ + last_row = TRUE; + } + /* Align the virtual buffer for this component. */ + if (cinfo->output_iMCU_row > 0) { + access_rows += compptr->v_samp_factor; /* prior iMCU row too */ + buffer = (*cinfo->mem->access_virt_barray) + ((j_common_ptr) cinfo, coef->whole_image[ci], + (cinfo->output_iMCU_row - 1) * compptr->v_samp_factor, + (JDIMENSION) access_rows, FALSE); + buffer += compptr->v_samp_factor; /* point to current iMCU row */ + first_row = FALSE; + } else { + buffer = (*cinfo->mem->access_virt_barray) + ((j_common_ptr) cinfo, coef->whole_image[ci], + (JDIMENSION) 0, (JDIMENSION) access_rows, FALSE); + first_row = TRUE; + } + /* Fetch component-dependent info */ + coef_bits = coef->coef_bits_latch + (ci * SAVED_COEFS); + quanttbl = compptr->quant_table; + Q00 = quanttbl->quantval[0]; + Q01 = quanttbl->quantval[Q01_POS]; + Q10 = quanttbl->quantval[Q10_POS]; + Q20 = quanttbl->quantval[Q20_POS]; + Q11 = quanttbl->quantval[Q11_POS]; + Q02 = quanttbl->quantval[Q02_POS]; + inverse_DCT = cinfo->idct->inverse_DCT[ci]; + output_ptr = output_buf[ci]; + /* Loop over all DCT blocks to be processed. */ + for (block_row = 0; block_row < block_rows; block_row++) { + buffer_ptr = buffer[block_row]; + if (first_row && block_row == 0) + prev_block_row = buffer_ptr; + else + prev_block_row = buffer[block_row-1]; + if (last_row && block_row == block_rows-1) + next_block_row = buffer_ptr; + else + next_block_row = buffer[block_row+1]; + /* We fetch the surrounding DC values using a sliding-register approach. + * Initialize all nine here so as to do the right thing on narrow pics. + */ + DC1 = DC2 = DC3 = (int) prev_block_row[0][0]; + DC4 = DC5 = DC6 = (int) buffer_ptr[0][0]; + DC7 = DC8 = DC9 = (int) next_block_row[0][0]; + output_col = 0; + last_block_column = compptr->width_in_blocks - 1; + for (block_num = 0; block_num <= last_block_column; block_num++) { + /* Fetch current DCT block into workspace so we can modify it. */ + jcopy_block_row(buffer_ptr, (JBLOCKROW) workspace, (JDIMENSION) 1); + /* Update DC values */ + if (block_num < last_block_column) { + DC3 = (int) prev_block_row[1][0]; + DC6 = (int) buffer_ptr[1][0]; + DC9 = (int) next_block_row[1][0]; + } + /* Compute coefficient estimates per K.8. + * An estimate is applied only if coefficient is still zero, + * and is not known to be fully accurate. + */ + /* AC01 */ + if ((Al=coef_bits[1]) != 0 && workspace[1] == 0) { + num = 36 * Q00 * (DC4 - DC6); + if (num >= 0) { + pred = (int) (((Q01<<7) + num) / (Q01<<8)); + if (Al > 0 && pred >= (1< 0 && pred >= (1<= 0) { + pred = (int) (((Q10<<7) + num) / (Q10<<8)); + if (Al > 0 && pred >= (1< 0 && pred >= (1<= 0) { + pred = (int) (((Q20<<7) + num) / (Q20<<8)); + if (Al > 0 && pred >= (1< 0 && pred >= (1<= 0) { + pred = (int) (((Q11<<7) + num) / (Q11<<8)); + if (Al > 0 && pred >= (1< 0 && pred >= (1<= 0) { + pred = (int) (((Q02<<7) + num) / (Q02<<8)); + if (Al > 0 && pred >= (1< 0 && pred >= (1<DCT_scaled_size; + } + output_ptr += compptr->DCT_scaled_size; + } + } + + if (++(cinfo->output_iMCU_row) < cinfo->total_iMCU_rows) + return JPEG_ROW_COMPLETED; + return JPEG_SCAN_COMPLETED; +} + +#endif /* BLOCK_SMOOTHING_SUPPORTED */ + + +/* + * Initialize coefficient buffer controller. + */ + +GLOBAL(void) +jinit_d_coef_controller (j_decompress_ptr cinfo, int need_full_buffer) +{ + my_coef_ptr coef; + + coef = (my_coef_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_coef_controller)); + cinfo->coef = (struct jpeg_d_coef_controller *) coef; + coef->pub.start_input_pass = start_input_pass; + coef->pub.start_output_pass = start_output_pass; +#ifdef BLOCK_SMOOTHING_SUPPORTED + coef->coef_bits_latch = NULL; +#endif + + /* Create the coefficient buffer. */ + if (need_full_buffer) { +#ifdef D_MULTISCAN_FILES_SUPPORTED + /* Allocate a full-image virtual array for each component, */ + /* padded to a multiple of samp_factor DCT blocks in each direction. */ + /* Note we ask for a pre-zeroed array. */ + int ci, access_rows; + jpeg_component_info *compptr; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + access_rows = compptr->v_samp_factor; +#ifdef BLOCK_SMOOTHING_SUPPORTED + /* If block smoothing could be used, need a bigger window */ + if (cinfo->progressive_mode) + access_rows *= 3; +#endif + coef->whole_image[ci] = (*cinfo->mem->request_virt_barray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, TRUE, + (JDIMENSION) jround_up((long) compptr->width_in_blocks, + (long) compptr->h_samp_factor), + (JDIMENSION) jround_up((long) compptr->height_in_blocks, + (long) compptr->v_samp_factor), + (JDIMENSION) access_rows); + } + coef->pub.consume_data = consume_data; + coef->pub.decompress_data = decompress_data; + coef->pub.coef_arrays = coef->whole_image; /* link to virtual arrays */ +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } else { + /* We only need a single-MCU buffer. */ + JBLOCKROW buffer; + int i; + + buffer = (JBLOCKROW) + (*cinfo->mem->alloc_large) ((j_common_ptr) cinfo, JPOOL_IMAGE, + D_MAX_BLOCKS_IN_MCU * SIZEOF(JBLOCK)); + for (i = 0; i < D_MAX_BLOCKS_IN_MCU; i++) { + coef->MCU_buffer[i] = buffer + i; + } + coef->pub.consume_data = dummy_consume_data; + coef->pub.decompress_data = decompress_onepass; + coef->pub.coef_arrays = NULL; /* flag for no virtual arrays */ + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jdcolor.cpp b/ml/dlib/dlib/external/libjpeg/jdcolor.cpp new file mode 100644 index 000000000..8dd88bfd3 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdcolor.cpp @@ -0,0 +1,396 @@ +/* + * jdcolor.c + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains output colorspace conversion routines. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Private subobject */ + +typedef struct { + struct jpeg_color_deconverter pub; /* public fields */ + + /* Private state for YCC->RGB conversion */ + int * Cr_r_tab; /* => table for Cr to R conversion */ + int * Cb_b_tab; /* => table for Cb to B conversion */ + long * Cr_g_tab; /* => table for Cr to G conversion */ + long * Cb_g_tab; /* => table for Cb to G conversion */ +} my_color_deconverter; + +typedef my_color_deconverter * my_cconvert_ptr; + + +/**************** YCbCr -> RGB conversion: most common case **************/ + +/* + * YCbCr is defined per CCIR 601-1, except that Cb and Cr are + * normalized to the range 0..MAXJSAMPLE rather than -0.5 .. 0.5. + * The conversion equations to be implemented are therefore + * R = Y + 1.40200 * Cr + * G = Y - 0.34414 * Cb - 0.71414 * Cr + * B = Y + 1.77200 * Cb + * where Cb and Cr represent the incoming values less CENTERJSAMPLE. + * (These numbers are derived from TIFF 6.0 section 21, dated 3-June-92.) + * + * To avoid floating-point arithmetic, we represent the fractional constants + * as integers scaled up by 2^16 (about 4 digits precision); we have to divide + * the products by 2^16, with appropriate rounding, to get the correct answer. + * Notice that Y, being an integral input, does not contribute any fraction + * so it need not participate in the rounding. + * + * For even more speed, we avoid doing any multiplications in the inner loop + * by precalculating the constants times Cb and Cr for all possible values. + * For 8-bit JSAMPLEs this is very reasonable (only 256 entries per table); + * for 12-bit samples it is still acceptable. It's not very reasonable for + * 16-bit samples, but if you want lossless storage you shouldn't be changing + * colorspace anyway. + * The Cr=>R and Cb=>B values can be rounded to integers in advance; the + * values for the G calculation are left scaled up, since we must add them + * together before rounding. + */ + +#define SCALEBITS 16 /* speediest right-shift on some machines */ +#define ONE_HALF ((long) 1 << (SCALEBITS-1)) +#define FIX(x) ((long) ((x) * (1L<RGB colorspace conversion. + */ + +LOCAL(void) +build_ycc_rgb_table (j_decompress_ptr cinfo) +{ + my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert; + int i; + long x; + SHIFT_TEMPS + + cconvert->Cr_r_tab = (int *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (MAXJSAMPLE+1) * SIZEOF(int)); + cconvert->Cb_b_tab = (int *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (MAXJSAMPLE+1) * SIZEOF(int)); + cconvert->Cr_g_tab = (long *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (MAXJSAMPLE+1) * SIZEOF(long)); + cconvert->Cb_g_tab = (long *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (MAXJSAMPLE+1) * SIZEOF(long)); + + for (i = 0, x = -CENTERJSAMPLE; i <= MAXJSAMPLE; i++, x++) { + /* i is the actual input pixel value, in the range 0..MAXJSAMPLE */ + /* The Cb or Cr value we are thinking of is x = i - CENTERJSAMPLE */ + /* Cr=>R value is nearest int to 1.40200 * x */ + cconvert->Cr_r_tab[i] = (int) + RIGHT_SHIFT(FIX(1.40200) * x + ONE_HALF, SCALEBITS); + /* Cb=>B value is nearest int to 1.77200 * x */ + cconvert->Cb_b_tab[i] = (int) + RIGHT_SHIFT(FIX(1.77200) * x + ONE_HALF, SCALEBITS); + /* Cr=>G value is scaled-up -0.71414 * x */ + cconvert->Cr_g_tab[i] = (- FIX(0.71414)) * x; + /* Cb=>G value is scaled-up -0.34414 * x */ + /* We also add in ONE_HALF so that need not do it in inner loop */ + cconvert->Cb_g_tab[i] = (- FIX(0.34414)) * x + ONE_HALF; + } +} + + +/* + * Convert some rows of samples to the output colorspace. + * + * Note that we change from noninterleaved, one-plane-per-component format + * to interleaved-pixel format. The output buffer is therefore three times + * as wide as the input buffer. + * A starting row offset is provided only for the input buffer. The caller + * can easily adjust the passed output_buf value to accommodate any row + * offset required on that side. + */ + +METHODDEF(void) +ycc_rgb_convert (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION input_row, + JSAMPARRAY output_buf, int num_rows) +{ + my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert; + int y, cb, cr; + JSAMPROW outptr; + JSAMPROW inptr0, inptr1, inptr2; + JDIMENSION col; + JDIMENSION num_cols = cinfo->output_width; + /* copy these pointers into registers if possible */ + JSAMPLE * range_limit = cinfo->sample_range_limit; + int * Crrtab = cconvert->Cr_r_tab; + int * Cbbtab = cconvert->Cb_b_tab; + long * Crgtab = cconvert->Cr_g_tab; + long * Cbgtab = cconvert->Cb_g_tab; + SHIFT_TEMPS + + while (--num_rows >= 0) { + inptr0 = input_buf[0][input_row]; + inptr1 = input_buf[1][input_row]; + inptr2 = input_buf[2][input_row]; + input_row++; + outptr = *output_buf++; + for (col = 0; col < num_cols; col++) { + y = GETJSAMPLE(inptr0[col]); + cb = GETJSAMPLE(inptr1[col]); + cr = GETJSAMPLE(inptr2[col]); + /* Range-limiting is essential due to noise introduced by DCT losses. */ + outptr[RGB_RED] = range_limit[y + Crrtab[cr]]; + outptr[RGB_GREEN] = range_limit[y + + ((int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], + SCALEBITS))]; + outptr[RGB_BLUE] = range_limit[y + Cbbtab[cb]]; + outptr += RGB_PIXELSIZE; + } + } +} + + +/**************** Cases other than YCbCr -> RGB **************/ + + +/* + * Color conversion for no colorspace change: just copy the data, + * converting from separate-planes to interleaved representation. + */ + +METHODDEF(void) +null_convert (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION input_row, + JSAMPARRAY output_buf, int num_rows) +{ + JSAMPROW inptr, outptr; + JDIMENSION count; + int num_components = cinfo->num_components; + JDIMENSION num_cols = cinfo->output_width; + int ci; + + while (--num_rows >= 0) { + for (ci = 0; ci < num_components; ci++) { + inptr = input_buf[ci][input_row]; + outptr = output_buf[0] + ci; + for (count = num_cols; count > 0; count--) { + *outptr = *inptr++; /* needn't bother with GETJSAMPLE() here */ + outptr += num_components; + } + } + input_row++; + output_buf++; + } +} + + +/* + * Color conversion for grayscale: just copy the data. + * This also works for YCbCr -> grayscale conversion, in which + * we just copy the Y (luminance) component and ignore chrominance. + */ + +METHODDEF(void) +grayscale_convert (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION input_row, + JSAMPARRAY output_buf, int num_rows) +{ + jcopy_sample_rows(input_buf[0], (int) input_row, output_buf, 0, + num_rows, cinfo->output_width); +} + + +/* + * Convert grayscale to RGB: just duplicate the graylevel three times. + * This is provided to support applications that don't want to cope + * with grayscale as a separate case. + */ + +METHODDEF(void) +gray_rgb_convert (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION input_row, + JSAMPARRAY output_buf, int num_rows) +{ + JSAMPROW inptr, outptr; + JDIMENSION col; + JDIMENSION num_cols = cinfo->output_width; + + while (--num_rows >= 0) { + inptr = input_buf[0][input_row++]; + outptr = *output_buf++; + for (col = 0; col < num_cols; col++) { + /* We can dispense with GETJSAMPLE() here */ + outptr[RGB_RED] = outptr[RGB_GREEN] = outptr[RGB_BLUE] = inptr[col]; + outptr += RGB_PIXELSIZE; + } + } +} + + +/* + * Adobe-style YCCK->CMYK conversion. + * We convert YCbCr to R=1-C, G=1-M, and B=1-Y using the same + * conversion as above, while passing K (black) unchanged. + * We assume build_ycc_rgb_table has been called. + */ + +METHODDEF(void) +ycck_cmyk_convert (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION input_row, + JSAMPARRAY output_buf, int num_rows) +{ + my_cconvert_ptr cconvert = (my_cconvert_ptr) cinfo->cconvert; + int y, cb, cr; + JSAMPROW outptr; + JSAMPROW inptr0, inptr1, inptr2, inptr3; + JDIMENSION col; + JDIMENSION num_cols = cinfo->output_width; + /* copy these pointers into registers if possible */ + JSAMPLE * range_limit = cinfo->sample_range_limit; + int * Crrtab = cconvert->Cr_r_tab; + int * Cbbtab = cconvert->Cb_b_tab; + long * Crgtab = cconvert->Cr_g_tab; + long * Cbgtab = cconvert->Cb_g_tab; + SHIFT_TEMPS + + while (--num_rows >= 0) { + inptr0 = input_buf[0][input_row]; + inptr1 = input_buf[1][input_row]; + inptr2 = input_buf[2][input_row]; + inptr3 = input_buf[3][input_row]; + input_row++; + outptr = *output_buf++; + for (col = 0; col < num_cols; col++) { + y = GETJSAMPLE(inptr0[col]); + cb = GETJSAMPLE(inptr1[col]); + cr = GETJSAMPLE(inptr2[col]); + /* Range-limiting is essential due to noise introduced by DCT losses. */ + outptr[0] = range_limit[MAXJSAMPLE - (y + Crrtab[cr])]; /* red */ + outptr[1] = range_limit[MAXJSAMPLE - (y + /* green */ + ((int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], + SCALEBITS)))]; + outptr[2] = range_limit[MAXJSAMPLE - (y + Cbbtab[cb])]; /* blue */ + /* K passes through unchanged */ + outptr[3] = inptr3[col]; /* don't need GETJSAMPLE here */ + outptr += 4; + } + } +} + + +/* + * Empty method for start_pass. + */ + +METHODDEF(void) +start_pass_dcolor (j_decompress_ptr ) +{ + /* no work needed */ +} + + +/* + * Module initialization routine for output colorspace conversion. + */ + +GLOBAL(void) +jinit_color_deconverter (j_decompress_ptr cinfo) +{ + my_cconvert_ptr cconvert; + int ci; + + cconvert = (my_cconvert_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_color_deconverter)); + cinfo->cconvert = (struct jpeg_color_deconverter *) cconvert; + cconvert->pub.start_pass = start_pass_dcolor; + + /* Make sure num_components agrees with jpeg_color_space */ + switch (cinfo->jpeg_color_space) { + case JCS_GRAYSCALE: + if (cinfo->num_components != 1) + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + break; + + case JCS_RGB: + case JCS_YCbCr: + if (cinfo->num_components != 3) + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + break; + + case JCS_CMYK: + case JCS_YCCK: + if (cinfo->num_components != 4) + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + break; + + default: /* JCS_UNKNOWN can be anything */ + if (cinfo->num_components < 1) + ERREXIT(cinfo, JERR_BAD_J_COLORSPACE); + break; + } + + /* Set out_color_components and conversion method based on requested space. + * Also clear the component_needed flags for any unused components, + * so that earlier pipeline stages can avoid useless computation. + */ + + switch (cinfo->out_color_space) { + case JCS_GRAYSCALE: + cinfo->out_color_components = 1; + if (cinfo->jpeg_color_space == JCS_GRAYSCALE || + cinfo->jpeg_color_space == JCS_YCbCr) { + cconvert->pub.color_convert = grayscale_convert; + /* For color->grayscale conversion, only the Y (0) component is needed */ + for (ci = 1; ci < cinfo->num_components; ci++) + cinfo->comp_info[ci].component_needed = FALSE; + } else + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + break; + + case JCS_RGB: + cinfo->out_color_components = RGB_PIXELSIZE; + if (cinfo->jpeg_color_space == JCS_YCbCr) { + cconvert->pub.color_convert = ycc_rgb_convert; + build_ycc_rgb_table(cinfo); + } else if (cinfo->jpeg_color_space == JCS_GRAYSCALE) { + cconvert->pub.color_convert = gray_rgb_convert; + } else if (cinfo->jpeg_color_space == JCS_RGB && RGB_PIXELSIZE == 3) { + cconvert->pub.color_convert = null_convert; + } else + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + break; + + case JCS_CMYK: + cinfo->out_color_components = 4; + if (cinfo->jpeg_color_space == JCS_YCCK) { + cconvert->pub.color_convert = ycck_cmyk_convert; + build_ycc_rgb_table(cinfo); + } else if (cinfo->jpeg_color_space == JCS_CMYK) { + cconvert->pub.color_convert = null_convert; + } else + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + break; + + default: + /* Permit null conversion to same output space */ + if (cinfo->out_color_space == cinfo->jpeg_color_space) { + cinfo->out_color_components = cinfo->num_components; + cconvert->pub.color_convert = null_convert; + } else /* unsupported non-null conversion */ + ERREXIT(cinfo, JERR_CONVERSION_NOTIMPL); + break; + } + + if (cinfo->quantize_colors) + cinfo->output_components = 1; /* single colormapped output component */ + else + cinfo->output_components = cinfo->out_color_components; +} diff --git a/ml/dlib/dlib/external/libjpeg/jdct.h b/ml/dlib/dlib/external/libjpeg/jdct.h new file mode 100644 index 000000000..a89c9550b --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdct.h @@ -0,0 +1,176 @@ +/* + * jdct.h + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This include file contains common declarations for the forward and + * inverse DCT modules. These declarations are private to the DCT managers + * (jcdctmgr.c, jddctmgr.c) and the individual DCT algorithms. + * The individual DCT algorithms are kept in separate files to ease + * machine-dependent tuning (e.g., assembly coding). + */ + + +/* + * A forward DCT routine is given a pointer to a work area of type DCTELEM[]; + * the DCT is to be performed in-place in that buffer. Type DCTELEM is int + * for 8-bit samples, long for 12-bit samples. (NOTE: Floating-point DCT + * implementations use an array of type FAST_FLOAT, instead.) + * The DCT inputs are expected to be signed (range +-CENTERJSAMPLE). + * The DCT outputs are returned scaled up by a factor of 8; they therefore + * have a range of +-8K for 8-bit data, +-128K for 12-bit data. This + * convention improves accuracy in integer implementations and saves some + * work in floating-point ones. + * Quantization of the output coefficients is done by jcdctmgr.c. + */ + +#if BITS_IN_JSAMPLE == 8 +typedef int DCTELEM; /* 16 or 32 bits is fine */ +#else +typedef long DCTELEM; /* must have 32 bits */ +#endif + +typedef JMETHOD(void, forward_DCT_method_ptr, (DCTELEM * data)); +typedef JMETHOD(void, float_DCT_method_ptr, (FAST_FLOAT * data)); + + +/* + * An inverse DCT routine is given a pointer to the input JBLOCK and a pointer + * to an output sample array. The routine must dequantize the input data as + * well as perform the IDCT; for dequantization, it uses the multiplier table + * pointed to by compptr->dct_table. The output data is to be placed into the + * sample array starting at a specified column. (Any row offset needed will + * be applied to the array pointer before it is passed to the IDCT code.) + * Note that the number of samples emitted by the IDCT routine is + * DCT_scaled_size * DCT_scaled_size. + */ + +/* typedef inverse_DCT_method_ptr is declared in jpegint.h */ + +/* + * Each IDCT routine has its own ideas about the best dct_table element type. + */ + +typedef MULTIPLIER ISLOW_MULT_TYPE; /* short or int, whichever is faster */ +#if BITS_IN_JSAMPLE == 8 +typedef MULTIPLIER IFAST_MULT_TYPE; /* 16 bits is OK, use short if faster */ +#define IFAST_SCALE_BITS 2 /* fractional bits in scale factors */ +#else +typedef long IFAST_MULT_TYPE; /* need 32 bits for scaled quantizers */ +#define IFAST_SCALE_BITS 13 /* fractional bits in scale factors */ +#endif +typedef FAST_FLOAT FLOAT_MULT_TYPE; /* preferred floating type */ + + +/* + * Each IDCT routine is responsible for range-limiting its results and + * converting them to unsigned form (0..MAXJSAMPLE). The raw outputs could + * be quite far out of range if the input data is corrupt, so a bulletproof + * range-limiting step is required. We use a mask-and-table-lookup method + * to do the combined operations quickly. See the comments with + * prepare_range_limit_table (in jdmaster.c) for more info. + */ + +#define IDCT_range_limit(cinfo) ((cinfo)->sample_range_limit + CENTERJSAMPLE) + +#define RANGE_MASK (MAXJSAMPLE * 4 + 3) /* 2 bits wider than legal samples */ + + +/* Short forms of external names for systems with brain-damaged linkers. */ + +#ifdef NEED_SHORT_EXTERNAL_NAMES +#define jpeg_fdct_islow jFDislow +#define jpeg_fdct_ifast jFDifast +#define jpeg_fdct_float jFDfloat +#define jpeg_idct_islow jRDislow +#define jpeg_idct_ifast jRDifast +#define jpeg_idct_float jRDfloat +#define jpeg_idct_4x4 jRD4x4 +#define jpeg_idct_2x2 jRD2x2 +#define jpeg_idct_1x1 jRD1x1 +#endif /* NEED_SHORT_EXTERNAL_NAMES */ + +/* Extern declarations for the forward and inverse DCT routines. */ + +EXTERN(void) jpeg_fdct_islow JPP((DCTELEM * data)); +EXTERN(void) jpeg_fdct_ifast JPP((DCTELEM * data)); +EXTERN(void) jpeg_fdct_float JPP((FAST_FLOAT * data)); + +EXTERN(void) jpeg_idct_islow + JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col)); +EXTERN(void) jpeg_idct_ifast + JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col)); +EXTERN(void) jpeg_idct_float + JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col)); +EXTERN(void) jpeg_idct_4x4 + JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col)); +EXTERN(void) jpeg_idct_2x2 + JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col)); +EXTERN(void) jpeg_idct_1x1 + JPP((j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, JSAMPARRAY output_buf, JDIMENSION output_col)); + + +/* + * Macros for handling fixed-point arithmetic; these are used by many + * but not all of the DCT/IDCT modules. + * + * All values are expected to be of type long. + * Fractional constants are scaled left by CONST_BITS bits. + * CONST_BITS is defined within each module using these macros, + * and may differ from one module to the next. + */ + +#define ONE ((long) 1) +#define CONST_SCALE (ONE << CONST_BITS) + +/* Convert a positive real constant to an integer scaled by CONST_SCALE. + * Caution: some C compilers fail to reduce "FIX(constant)" at compile time, + * thus causing a lot of useless floating-point operations at run time. + */ + +#define FIX(x) ((long) ((x) * CONST_SCALE + 0.5)) + +/* Descale and correctly round an long value that's scaled by N bits. + * We assume RIGHT_SHIFT rounds towards minus infinity, so adding + * the fudge factor is correct for either sign of X. + */ + +#define DESCALE(x,n) RIGHT_SHIFT((x) + (ONE << ((n)-1)), n) + +/* Multiply an long variable by an long constant to yield an long result. + * This macro is used only when the two inputs will actually be no more than + * 16 bits wide, so that a 16x16->32 bit multiply can be used instead of a + * full 32x32 multiply. This provides a useful speedup on many machines. + * Unfortunately there is no way to specify a 16x16->32 multiply portably + * in C, but some C compilers will do the right thing if you provide the + * correct combination of casts. + */ + +#ifdef SHORTxSHORT_32 /* may work if 'int' is 32 bits */ +#define MULTIPLY16C16(var,const) (((short) (var)) * ((short) (const))) +#endif +#ifdef SHORTxLCONST_32 /* known to work with Microsoft C 6.0 */ +#define MULTIPLY16C16(var,const) (((short) (var)) * ((long) (const))) +#endif + +#ifndef MULTIPLY16C16 /* default definition */ +#define MULTIPLY16C16(var,const) ((var) * (const)) +#endif + +/* Same except both inputs are variables. */ + +#ifdef SHORTxSHORT_32 /* may work if 'int' is 32 bits */ +#define MULTIPLY16V16(var1,var2) (((short) (var1)) * ((short) (var2))) +#endif + +#ifndef MULTIPLY16V16 /* default definition */ +#define MULTIPLY16V16(var1,var2) ((var1) * (var2)) +#endif diff --git a/ml/dlib/dlib/external/libjpeg/jddctmgr.cpp b/ml/dlib/dlib/external/libjpeg/jddctmgr.cpp new file mode 100644 index 000000000..620da686d --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jddctmgr.cpp @@ -0,0 +1,269 @@ +/* + * jddctmgr.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains the inverse-DCT management logic. + * This code selects a particular IDCT implementation to be used, + * and it performs related housekeeping chores. No code in this file + * is executed per IDCT step, only during output pass setup. + * + * Note that the IDCT routines are responsible for performing coefficient + * dequantization as well as the IDCT proper. This module sets up the + * dequantization multiplier table needed by the IDCT routine. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdct.h" /* Private declarations for DCT subsystem */ + + +/* + * The decompressor input side (jdinput.c) saves away the appropriate + * quantization table for each component at the start of the first scan + * involving that component. (This is necessary in order to correctly + * decode files that reuse Q-table slots.) + * When we are ready to make an output pass, the saved Q-table is converted + * to a multiplier table that will actually be used by the IDCT routine. + * The multiplier table contents are IDCT-method-dependent. To support + * application changes in IDCT method between scans, we can remake the + * multiplier tables if necessary. + * In buffered-image mode, the first output pass may occur before any data + * has been seen for some components, and thus before their Q-tables have + * been saved away. To handle this case, multiplier tables are preset + * to zeroes; the result of the IDCT will be a neutral gray level. + */ + + +/* Private subobject for this module */ + +typedef struct { + struct jpeg_inverse_dct pub; /* public fields */ + + /* This array contains the IDCT method code that each multiplier table + * is currently set up for, or -1 if it's not yet set up. + * The actual multiplier tables are pointed to by dct_table in the + * per-component comp_info structures. + */ + int cur_method[MAX_COMPONENTS]; +} my_idct_controller; + +typedef my_idct_controller * my_idct_ptr; + + +/* Allocated multiplier tables: big enough for any supported variant */ + +typedef union { + ISLOW_MULT_TYPE islow_array[DCTSIZE2]; +#ifdef DCT_IFAST_SUPPORTED + IFAST_MULT_TYPE ifast_array[DCTSIZE2]; +#endif +#ifdef DCT_FLOAT_SUPPORTED + FLOAT_MULT_TYPE float_array[DCTSIZE2]; +#endif +} multiplier_table; + + +/* The current scaled-IDCT routines require ISLOW-style multiplier tables, + * so be sure to compile that code if either ISLOW or SCALING is requested. + */ +#ifdef DCT_ISLOW_SUPPORTED +#define PROVIDE_ISLOW_TABLES +#else +#ifdef IDCT_SCALING_SUPPORTED +#define PROVIDE_ISLOW_TABLES +#endif +#endif + + +/* + * Prepare for an output pass. + * Here we select the proper IDCT routine for each component and build + * a matching multiplier table. + */ + +METHODDEF(void) +start_pass (j_decompress_ptr cinfo) +{ + my_idct_ptr idct = (my_idct_ptr) cinfo->idct; + int ci, i; + jpeg_component_info *compptr; + int method = 0; + inverse_DCT_method_ptr method_ptr = NULL; + JQUANT_TBL * qtbl; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Select the proper IDCT routine for this component's scaling */ + switch (compptr->DCT_scaled_size) { +#ifdef IDCT_SCALING_SUPPORTED + case 1: + method_ptr = jpeg_idct_1x1; + method = JDCT_ISLOW; /* jidctred uses islow-style table */ + break; + case 2: + method_ptr = jpeg_idct_2x2; + method = JDCT_ISLOW; /* jidctred uses islow-style table */ + break; + case 4: + method_ptr = jpeg_idct_4x4; + method = JDCT_ISLOW; /* jidctred uses islow-style table */ + break; +#endif + case DCTSIZE: + switch (cinfo->dct_method) { +#ifdef DCT_ISLOW_SUPPORTED + case JDCT_ISLOW: + method_ptr = jpeg_idct_islow; + method = JDCT_ISLOW; + break; +#endif +#ifdef DCT_IFAST_SUPPORTED + case JDCT_IFAST: + method_ptr = jpeg_idct_ifast; + method = JDCT_IFAST; + break; +#endif +#ifdef DCT_FLOAT_SUPPORTED + case JDCT_FLOAT: + method_ptr = jpeg_idct_float; + method = JDCT_FLOAT; + break; +#endif + default: + ERREXIT(cinfo, JERR_NOT_COMPILED); + break; + } + break; + default: + ERREXIT1(cinfo, JERR_BAD_DCTSIZE, compptr->DCT_scaled_size); + break; + } + idct->pub.inverse_DCT[ci] = method_ptr; + /* Create multiplier table from quant table. + * However, we can skip this if the component is uninteresting + * or if we already built the table. Also, if no quant table + * has yet been saved for the component, we leave the + * multiplier table all-zero; we'll be reading zeroes from the + * coefficient controller's buffer anyway. + */ + if (! compptr->component_needed || idct->cur_method[ci] == method) + continue; + qtbl = compptr->quant_table; + if (qtbl == NULL) /* happens if no data yet for component */ + continue; + idct->cur_method[ci] = method; + switch (method) { +#ifdef PROVIDE_ISLOW_TABLES + case JDCT_ISLOW: + { + /* For LL&M IDCT method, multipliers are equal to raw quantization + * coefficients, but are stored as ints to ensure access efficiency. + */ + ISLOW_MULT_TYPE * ismtbl = (ISLOW_MULT_TYPE *) compptr->dct_table; + for (i = 0; i < DCTSIZE2; i++) { + ismtbl[i] = (ISLOW_MULT_TYPE) qtbl->quantval[i]; + } + } + break; +#endif +#ifdef DCT_IFAST_SUPPORTED + case JDCT_IFAST: + { + /* For AA&N IDCT method, multipliers are equal to quantization + * coefficients scaled by scalefactor[row]*scalefactor[col], where + * scalefactor[0] = 1 + * scalefactor[k] = cos(k*PI/16) * sqrt(2) for k=1..7 + * For integer operation, the multiplier table is to be scaled by + * IFAST_SCALE_BITS. + */ + IFAST_MULT_TYPE * ifmtbl = (IFAST_MULT_TYPE *) compptr->dct_table; +#define CONST_BITS 14 + static const short aanscales[DCTSIZE2] = { + /* precomputed values scaled up by 14 bits */ + 16384, 22725, 21407, 19266, 16384, 12873, 8867, 4520, + 22725, 31521, 29692, 26722, 22725, 17855, 12299, 6270, + 21407, 29692, 27969, 25172, 21407, 16819, 11585, 5906, + 19266, 26722, 25172, 22654, 19266, 15137, 10426, 5315, + 16384, 22725, 21407, 19266, 16384, 12873, 8867, 4520, + 12873, 17855, 16819, 15137, 12873, 10114, 6967, 3552, + 8867, 12299, 11585, 10426, 8867, 6967, 4799, 2446, + 4520, 6270, 5906, 5315, 4520, 3552, 2446, 1247 + }; + SHIFT_TEMPS + + for (i = 0; i < DCTSIZE2; i++) { + ifmtbl[i] = (IFAST_MULT_TYPE) + DESCALE(MULTIPLY16V16((long) qtbl->quantval[i], + (long) aanscales[i]), + CONST_BITS-IFAST_SCALE_BITS); + } + } + break; +#endif +#ifdef DCT_FLOAT_SUPPORTED + case JDCT_FLOAT: + { + /* For float AA&N IDCT method, multipliers are equal to quantization + * coefficients scaled by scalefactor[row]*scalefactor[col], where + * scalefactor[0] = 1 + * scalefactor[k] = cos(k*PI/16) * sqrt(2) for k=1..7 + */ + FLOAT_MULT_TYPE * fmtbl = (FLOAT_MULT_TYPE *) compptr->dct_table; + int row, col; + static const double aanscalefactor[DCTSIZE] = { + 1.0, 1.387039845, 1.306562965, 1.175875602, + 1.0, 0.785694958, 0.541196100, 0.275899379 + }; + + i = 0; + for (row = 0; row < DCTSIZE; row++) { + for (col = 0; col < DCTSIZE; col++) { + fmtbl[i] = (FLOAT_MULT_TYPE) + ((double) qtbl->quantval[i] * + aanscalefactor[row] * aanscalefactor[col]); + i++; + } + } + } + break; +#endif + default: + ERREXIT(cinfo, JERR_NOT_COMPILED); + break; + } + } +} + + +/* + * Initialize IDCT manager. + */ + +GLOBAL(void) +jinit_inverse_dct (j_decompress_ptr cinfo) +{ + my_idct_ptr idct; + int ci; + jpeg_component_info *compptr; + + idct = (my_idct_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_idct_controller)); + cinfo->idct = (struct jpeg_inverse_dct *) idct; + idct->pub.start_pass = start_pass; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Allocate and pre-zero a multiplier table for each component */ + compptr->dct_table = + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(multiplier_table)); + MEMZERO(compptr->dct_table, SIZEOF(multiplier_table)); + /* Mark multiplier table not yet set up for any method */ + idct->cur_method[ci] = -1; + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jdhuff.cpp b/ml/dlib/dlib/external/libjpeg/jdhuff.cpp new file mode 100644 index 000000000..26a2a36f2 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdhuff.cpp @@ -0,0 +1,654 @@ +/* + * jdhuff.c + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains Huffman entropy decoding routines. + * + * Much of the complexity here has to do with supporting input suspension. + * If the data source module demands suspension, we want to be able to back + * up to the start of the current MCU. To do this, we copy state variables + * into local working storage, and update them back to the permanent + * storage only upon successful completion of an MCU. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdhuff.h" /* Declarations shared with jdphuff.c */ + +#ifdef __GNUC__ +#pragma GCC diagnostic ignored "-Wshift-negative-value" +#endif + +/* + * Expanded entropy decoder object for Huffman decoding. + * + * The savable_state subrecord contains fields that change within an MCU, + * but must not be updated permanently until we complete the MCU. + */ + +typedef struct { + int last_dc_val[MAX_COMPS_IN_SCAN]; /* last DC coef for each component */ +} savable_state; + +/* This macro is to work around compilers with missing or broken + * structure assignment. You'll need to fix this code if you have + * such a compiler and you change MAX_COMPS_IN_SCAN. + */ + +#ifndef NO_STRUCT_ASSIGN +#define ASSIGN_STATE(dest,src) ((dest) = (src)) +#else +#if MAX_COMPS_IN_SCAN == 4 +#define ASSIGN_STATE(dest,src) \ + ((dest).last_dc_val[0] = (src).last_dc_val[0], \ + (dest).last_dc_val[1] = (src).last_dc_val[1], \ + (dest).last_dc_val[2] = (src).last_dc_val[2], \ + (dest).last_dc_val[3] = (src).last_dc_val[3]) +#endif +#endif + + +typedef struct { + struct jpeg_entropy_decoder pub; /* public fields */ + + /* These fields are loaded into local variables at start of each MCU. + * In case of suspension, we exit WITHOUT updating them. + */ + bitread_perm_state bitstate; /* Bit buffer at start of MCU */ + savable_state saved; /* Other state at start of MCU */ + + /* These fields are NOT loaded into local working state. */ + unsigned int restarts_to_go; /* MCUs left in this restart interval */ + + /* Pointers to derived tables (these workspaces have image lifespan) */ + d_derived_tbl * dc_derived_tbls[NUM_HUFF_TBLS]; + d_derived_tbl * ac_derived_tbls[NUM_HUFF_TBLS]; + + /* Precalculated info set up by start_pass for use in decode_mcu: */ + + /* Pointers to derived tables to be used for each block within an MCU */ + d_derived_tbl * dc_cur_tbls[D_MAX_BLOCKS_IN_MCU]; + d_derived_tbl * ac_cur_tbls[D_MAX_BLOCKS_IN_MCU]; + /* Whether we care about the DC and AC coefficient values for each block */ + int dc_needed[D_MAX_BLOCKS_IN_MCU]; + int ac_needed[D_MAX_BLOCKS_IN_MCU]; +} huff_entropy_decoder; + +typedef huff_entropy_decoder * huff_entropy_ptr; + + +/* + * Initialize for a Huffman-compressed scan. + */ + +METHODDEF(void) +start_pass_huff_decoder (j_decompress_ptr cinfo) +{ + huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy; + int ci, blkn, dctbl, actbl; + jpeg_component_info * compptr; + + /* Check that the scan parameters Ss, Se, Ah/Al are OK for sequential JPEG. + * This ought to be an error condition, but we make it a warning because + * there are some baseline files out there with all zeroes in these bytes. + */ + if (cinfo->Ss != 0 || cinfo->Se != DCTSIZE2-1 || + cinfo->Ah != 0 || cinfo->Al != 0) + WARNMS(cinfo, JWRN_NOT_SEQUENTIAL); + + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + dctbl = compptr->dc_tbl_no; + actbl = compptr->ac_tbl_no; + /* Compute derived values for Huffman tables */ + /* We may do this more than once for a table, but it's not expensive */ + jpeg_make_d_derived_tbl(cinfo, TRUE, dctbl, + & entropy->dc_derived_tbls[dctbl]); + jpeg_make_d_derived_tbl(cinfo, FALSE, actbl, + & entropy->ac_derived_tbls[actbl]); + /* Initialize DC predictions to 0 */ + entropy->saved.last_dc_val[ci] = 0; + } + + /* Precalculate decoding info for each block in an MCU of this scan */ + for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) { + ci = cinfo->MCU_membership[blkn]; + compptr = cinfo->cur_comp_info[ci]; + /* Precalculate which table to use for each block */ + entropy->dc_cur_tbls[blkn] = entropy->dc_derived_tbls[compptr->dc_tbl_no]; + entropy->ac_cur_tbls[blkn] = entropy->ac_derived_tbls[compptr->ac_tbl_no]; + /* Decide whether we really care about the coefficient values */ + if (compptr->component_needed) { + entropy->dc_needed[blkn] = TRUE; + /* we don't need the ACs if producing a 1/8th-size image */ + entropy->ac_needed[blkn] = (compptr->DCT_scaled_size > 1); + } else { + entropy->dc_needed[blkn] = entropy->ac_needed[blkn] = FALSE; + } + } + + /* Initialize bitread state variables */ + entropy->bitstate.bits_left = 0; + entropy->bitstate.get_buffer = 0; /* unnecessary, but keeps Purify quiet */ + entropy->pub.insufficient_data = FALSE; + + /* Initialize restart counter */ + entropy->restarts_to_go = cinfo->restart_interval; +} + + +/* + * Compute the derived values for a Huffman table. + * This routine also performs some validation checks on the table. + * + * Note this is also used by jdphuff.c. + */ + +GLOBAL(void) +jpeg_make_d_derived_tbl (j_decompress_ptr cinfo, int isDC, int tblno, + d_derived_tbl ** pdtbl) +{ + JHUFF_TBL *htbl; + d_derived_tbl *dtbl; + int p, i, l, si, numsymbols; + int lookbits, ctr; + char huffsize[257]; + unsigned int huffcode[257]; + unsigned int code; + + /* Note that huffsize[] and huffcode[] are filled in code-length order, + * paralleling the order of the symbols themselves in htbl->huffval[]. + */ + + /* Find the input Huffman table */ + if (tblno < 0 || tblno >= NUM_HUFF_TBLS) + ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tblno); + htbl = + isDC ? cinfo->dc_huff_tbl_ptrs[tblno] : cinfo->ac_huff_tbl_ptrs[tblno]; + if (htbl == NULL) + ERREXIT1(cinfo, JERR_NO_HUFF_TABLE, tblno); + + /* Allocate a workspace if we haven't already done so. */ + if (*pdtbl == NULL) + *pdtbl = (d_derived_tbl *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(d_derived_tbl)); + dtbl = *pdtbl; + dtbl->pub = htbl; /* fill in back link */ + + /* Figure C.1: make table of Huffman code length for each symbol */ + + p = 0; + for (l = 1; l <= 16; l++) { + i = (int) htbl->bits[l]; + if (i < 0 || p + i > 256) /* protect against table overrun */ + ERREXIT(cinfo, JERR_BAD_HUFF_TABLE); + while (i--) + huffsize[p++] = (char) l; + } + huffsize[p] = 0; + numsymbols = p; + + /* Figure C.2: generate the codes themselves */ + /* We also validate that the counts represent a legal Huffman code tree. */ + + code = 0; + si = huffsize[0]; + p = 0; + while (huffsize[p]) { + while (((int) huffsize[p]) == si) { + huffcode[p++] = code; + code++; + } + /* code is now 1 more than the last code used for codelength si; but + * it must still fit in si bits, since no code is allowed to be all ones. + */ + if (((long) code) >= (((long) 1) << si)) + ERREXIT(cinfo, JERR_BAD_HUFF_TABLE); + code <<= 1; + si++; + } + + /* Figure F.15: generate decoding tables for bit-sequential decoding */ + + p = 0; + for (l = 1; l <= 16; l++) { + if (htbl->bits[l]) { + /* valoffset[l] = huffval[] index of 1st symbol of code length l, + * minus the minimum code of length l + */ + dtbl->valoffset[l] = (long) p - (long) huffcode[p]; + p += htbl->bits[l]; + dtbl->maxcode[l] = huffcode[p-1]; /* maximum code of length l */ + } else { + dtbl->maxcode[l] = -1; /* -1 if no codes of this length */ + } + } + dtbl->maxcode[17] = 0xFFFFFL; /* ensures jpeg_huff_decode terminates */ + + /* Compute lookahead tables to speed up decoding. + * First we set all the table entries to 0, indicating "too long"; + * then we iterate through the Huffman codes that are short enough and + * fill in all the entries that correspond to bit sequences starting + * with that code. + */ + + MEMZERO(dtbl->look_nbits, SIZEOF(dtbl->look_nbits)); + + p = 0; + for (l = 1; l <= HUFF_LOOKAHEAD; l++) { + for (i = 1; i <= (int) htbl->bits[l]; i++, p++) { + /* l = current code's length, p = its index in huffcode[] & huffval[]. */ + /* Generate left-justified code followed by all possible bit sequences */ + lookbits = huffcode[p] << (HUFF_LOOKAHEAD-l); + for (ctr = 1 << (HUFF_LOOKAHEAD-l); ctr > 0; ctr--) { + dtbl->look_nbits[lookbits] = l; + dtbl->look_sym[lookbits] = htbl->huffval[p]; + lookbits++; + } + } + } + + /* Validate symbols as being reasonable. + * For AC tables, we make no check, but accept all byte values 0..255. + * For DC tables, we require the symbols to be in range 0..15. + * (Tighter bounds could be applied depending on the data depth and mode, + * but this is sufficient to ensure safe decoding.) + */ + if (isDC) { + for (i = 0; i < numsymbols; i++) { + int sym = htbl->huffval[i]; + if (sym < 0 || sym > 15) + ERREXIT(cinfo, JERR_BAD_HUFF_TABLE); + } + } +} + + +/* + * Out-of-line code for bit fetching (shared with jdphuff.c). + * See jdhuff.h for info about usage. + * Note: current values of get_buffer and bits_left are passed as parameters, + * but are returned in the corresponding fields of the state struct. + * + * On most machines MIN_GET_BITS should be 25 to allow the full 32-bit width + * of get_buffer to be used. (On machines with wider words, an even larger + * buffer could be used.) However, on some machines 32-bit shifts are + * quite slow and take time proportional to the number of places shifted. + * (This is true with most PC compilers, for instance.) In this case it may + * be a win to set MIN_GET_BITS to the minimum value of 15. This reduces the + * average shift distance at the cost of more calls to jpeg_fill_bit_buffer. + */ + +#ifdef SLOW_SHIFT_32 +#define MIN_GET_BITS 15 /* minimum allowable value */ +#else +#define MIN_GET_BITS (BIT_BUF_SIZE-7) +#endif + + +GLOBAL(int) +jpeg_fill_bit_buffer (bitread_working_state * state, + bit_buf_type get_buffer, register int bits_left, + int nbits) +/* Load up the bit buffer to a depth of at least nbits */ +{ + /* Copy heavily used state fields into locals (hopefully registers) */ + const JOCTET * next_input_byte = state->next_input_byte; + size_t bytes_in_buffer = state->bytes_in_buffer; + j_decompress_ptr cinfo = state->cinfo; + + /* Attempt to load at least MIN_GET_BITS bits into get_buffer. */ + /* (It is assumed that no request will be for more than that many bits.) */ + /* We fail to do so only if we hit a marker or are forced to suspend. */ + + if (cinfo->unread_marker == 0) { /* cannot advance past a marker */ + while (bits_left < MIN_GET_BITS) { + int c; + + /* Attempt to read a byte */ + if (bytes_in_buffer == 0) { + if (! (*cinfo->src->fill_input_buffer) (cinfo)) + return FALSE; + next_input_byte = cinfo->src->next_input_byte; + bytes_in_buffer = cinfo->src->bytes_in_buffer; + } + bytes_in_buffer--; + c = GETJOCTET(*next_input_byte++); + + /* If it's 0xFF, check and discard stuffed zero byte */ + if (c == 0xFF) { + /* Loop here to discard any padding FF's on terminating marker, + * so that we can save a valid unread_marker value. NOTE: we will + * accept multiple FF's followed by a 0 as meaning a single FF data + * byte. This data pattern is not valid according to the standard. + */ + do { + if (bytes_in_buffer == 0) { + if (! (*cinfo->src->fill_input_buffer) (cinfo)) + return FALSE; + next_input_byte = cinfo->src->next_input_byte; + bytes_in_buffer = cinfo->src->bytes_in_buffer; + } + bytes_in_buffer--; + c = GETJOCTET(*next_input_byte++); + } while (c == 0xFF); + + if (c == 0) { + /* Found FF/00, which represents an FF data byte */ + c = 0xFF; + } else { + /* Oops, it's actually a marker indicating end of compressed data. + * Save the marker code for later use. + * Fine point: it might appear that we should save the marker into + * bitread working state, not straight into permanent state. But + * once we have hit a marker, we cannot need to suspend within the + * current MCU, because we will read no more bytes from the data + * source. So it is OK to update permanent state right away. + */ + cinfo->unread_marker = c; + /* See if we need to insert some fake zero bits. */ + goto no_more_bytes; + } + } + + /* OK, load c into get_buffer */ + get_buffer = (get_buffer << 8) | c; + bits_left += 8; + } /* end while */ + } else { + no_more_bytes: + /* We get here if we've read the marker that terminates the compressed + * data segment. There should be enough bits in the buffer register + * to satisfy the request; if so, no problem. + */ + if (nbits > bits_left) { + /* Uh-oh. Report corrupted data to user and stuff zeroes into + * the data stream, so that we can produce some kind of image. + * We use a nonvolatile flag to ensure that only one warning message + * appears per data segment. + */ + if (! cinfo->entropy->insufficient_data) { + WARNMS(cinfo, JWRN_HIT_MARKER); + cinfo->entropy->insufficient_data = TRUE; + } + /* Fill the buffer with zero bits */ + get_buffer <<= MIN_GET_BITS - bits_left; + bits_left = MIN_GET_BITS; + } + } + + /* Unload the local registers */ + state->next_input_byte = next_input_byte; + state->bytes_in_buffer = bytes_in_buffer; + state->get_buffer = get_buffer; + state->bits_left = bits_left; + + return TRUE; +} + + +/* + * Out-of-line code for Huffman code decoding. + * See jdhuff.h for info about usage. + */ + +GLOBAL(int) +jpeg_huff_decode (bitread_working_state * state, + bit_buf_type get_buffer, register int bits_left, + d_derived_tbl * htbl, int min_bits) +{ + int l = min_bits; + long code; + + /* HUFF_DECODE has determined that the code is at least min_bits */ + /* bits long, so fetch that many bits in one swoop. */ + + CHECK_BIT_BUFFER(*state, l, return -1); + code = GET_BITS(l); + + /* Collect the rest of the Huffman code one bit at a time. */ + /* This is per Figure F.16 in the JPEG spec. */ + + while (code > htbl->maxcode[l]) { + code <<= 1; + CHECK_BIT_BUFFER(*state, 1, return -1); + code |= GET_BITS(1); + l++; + } + + /* Unload the local registers */ + state->get_buffer = get_buffer; + state->bits_left = bits_left; + + /* With garbage input we may reach the sentinel value l = 17. */ + + if (l > 16) { + WARNMS(state->cinfo, JWRN_HUFF_BAD_CODE); + return 0; /* fake a zero as the safest result */ + } + + return htbl->pub->huffval[ (int) (code + htbl->valoffset[l]) ]; +} + + +/* + * Figure F.12: extend sign bit. + * On some machines, a shift and add will be faster than a table lookup. + */ + +#ifdef AVOID_TABLES + +#define HUFF_EXTEND(x,s) ((x) < (1<<((s)-1)) ? (x) + (((-1)<<(s)) + 1) : (x)) + +#else + +#define HUFF_EXTEND(x,s) ((x) < extend_test[s] ? (x) + extend_offset[s] : (x)) + +static const int extend_test[16] = /* entry n is 2**(n-1) */ + { 0, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, + 0x0100, 0x0200, 0x0400, 0x0800, 0x1000, 0x2000, 0x4000 }; + +static const int extend_offset[16] = /* entry n is (-1 << n) + 1 */ + { 0, ((-1)<<1) + 1, ((-1)<<2) + 1, ((-1)<<3) + 1, ((-1)<<4) + 1, + ((-1)<<5) + 1, ((-1)<<6) + 1, ((-1)<<7) + 1, ((-1)<<8) + 1, + ((-1)<<9) + 1, ((-1)<<10) + 1, ((-1)<<11) + 1, ((-1)<<12) + 1, + ((-1)<<13) + 1, ((-1)<<14) + 1, ((-1)<<15) + 1 }; + +#endif /* AVOID_TABLES */ + + +/* + * Check for a restart marker & resynchronize decoder. + * Returns FALSE if must suspend. + */ + +LOCAL(int) +process_restart (j_decompress_ptr cinfo) +{ + huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy; + int ci; + + /* Throw away any unused bits remaining in bit buffer; */ + /* include any full bytes in next_marker's count of discarded bytes */ + cinfo->marker->discarded_bytes += entropy->bitstate.bits_left / 8; + entropy->bitstate.bits_left = 0; + + /* Advance past the RSTn marker */ + if (! (*cinfo->marker->read_restart_marker) (cinfo)) + return FALSE; + + /* Re-initialize DC predictions to 0 */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) + entropy->saved.last_dc_val[ci] = 0; + + /* Reset restart counter */ + entropy->restarts_to_go = cinfo->restart_interval; + + /* Reset out-of-data flag, unless read_restart_marker left us smack up + * against a marker. In that case we will end up treating the next data + * segment as empty, and we can avoid producing bogus output pixels by + * leaving the flag set. + */ + if (cinfo->unread_marker == 0) + entropy->pub.insufficient_data = FALSE; + + return TRUE; +} + + +/* + * Decode and return one MCU's worth of Huffman-compressed coefficients. + * The coefficients are reordered from zigzag order into natural array order, + * but are not dequantized. + * + * The i'th block of the MCU is stored into the block pointed to by + * MCU_data[i]. WE ASSUME THIS AREA HAS BEEN ZEROED BY THE CALLER. + * (Wholesale zeroing is usually a little faster than retail...) + * + * Returns FALSE if data source requested suspension. In that case no + * changes have been made to permanent state. (Exception: some output + * coefficients may already have been assigned. This is harmless for + * this module, since we'll just re-assign them on the next call.) + */ + +METHODDEF(int) +decode_mcu (j_decompress_ptr cinfo, JBLOCKROW *MCU_data) +{ + huff_entropy_ptr entropy = (huff_entropy_ptr) cinfo->entropy; + int blkn; + BITREAD_STATE_VARS; + savable_state state; + + /* Process restart marker if needed; may have to suspend */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) + if (! process_restart(cinfo)) + return FALSE; + } + + /* If we've run out of data, just leave the MCU set to zeroes. + * This way, we return uniform gray for the remainder of the segment. + */ + if (! entropy->pub.insufficient_data) { + + /* Load up working state */ + BITREAD_LOAD_STATE(cinfo,entropy->bitstate); + ASSIGN_STATE(state, entropy->saved); + + /* Outer loop handles each block in the MCU */ + + for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) { + JBLOCKROW block = MCU_data[blkn]; + d_derived_tbl * dctbl = entropy->dc_cur_tbls[blkn]; + d_derived_tbl * actbl = entropy->ac_cur_tbls[blkn]; + int s, k, r; + + /* Decode a single block's worth of coefficients */ + + /* Section F.2.2.1: decode the DC coefficient difference */ + HUFF_DECODE(s, br_state, dctbl, return FALSE, label1); + if (s) { + CHECK_BIT_BUFFER(br_state, s, return FALSE); + r = GET_BITS(s); + s = HUFF_EXTEND(r, s); + } + + if (entropy->dc_needed[blkn]) { + /* Convert DC difference to actual value, update last_dc_val */ + int ci = cinfo->MCU_membership[blkn]; + s += state.last_dc_val[ci]; + state.last_dc_val[ci] = s; + /* Output the DC coefficient (assumes jpeg_natural_order[0] = 0) */ + (*block)[0] = (JCOEF) s; + } + + if (entropy->ac_needed[blkn]) { + + /* Section F.2.2.2: decode the AC coefficients */ + /* Since zeroes are skipped, output area must be cleared beforehand */ + for (k = 1; k < DCTSIZE2; k++) { + HUFF_DECODE(s, br_state, actbl, return FALSE, label2); + + r = s >> 4; + s &= 15; + + if (s) { + k += r; + CHECK_BIT_BUFFER(br_state, s, return FALSE); + r = GET_BITS(s); + s = HUFF_EXTEND(r, s); + /* Output coefficient in natural (dezigzagged) order. + * Note: the extra entries in jpeg_natural_order[] will save us + * if k >= DCTSIZE2, which could happen if the data is corrupted. + */ + (*block)[jpeg_natural_order[k]] = (JCOEF) s; + } else { + if (r != 15) + break; + k += 15; + } + } + + } else { + + /* Section F.2.2.2: decode the AC coefficients */ + /* In this path we just discard the values */ + for (k = 1; k < DCTSIZE2; k++) { + HUFF_DECODE(s, br_state, actbl, return FALSE, label3); + + r = s >> 4; + s &= 15; + + if (s) { + k += r; + CHECK_BIT_BUFFER(br_state, s, return FALSE); + DROP_BITS(s); + } else { + if (r != 15) + break; + k += 15; + } + } + + } + } + + /* Completed MCU, so update state */ + BITREAD_SAVE_STATE(cinfo,entropy->bitstate); + ASSIGN_STATE(entropy->saved, state); + } + + /* Account for restart interval (no-op if not using restarts) */ + entropy->restarts_to_go--; + + return TRUE; +} + + +/* + * Module initialization routine for Huffman entropy decoding. + */ + +GLOBAL(void) +jinit_huff_decoder (j_decompress_ptr cinfo) +{ + huff_entropy_ptr entropy; + int i; + + entropy = (huff_entropy_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(huff_entropy_decoder)); + cinfo->entropy = (struct jpeg_entropy_decoder *) entropy; + entropy->pub.start_pass = start_pass_huff_decoder; + entropy->pub.decode_mcu = decode_mcu; + + /* Mark tables unallocated */ + for (i = 0; i < NUM_HUFF_TBLS; i++) { + entropy->dc_derived_tbls[i] = entropy->ac_derived_tbls[i] = NULL; + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jdhuff.h b/ml/dlib/dlib/external/libjpeg/jdhuff.h new file mode 100644 index 000000000..6a0e939af --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdhuff.h @@ -0,0 +1,201 @@ +/* + * jdhuff.h + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains declarations for Huffman entropy decoding routines + * that are shared between the sequential decoder (jdhuff.c) and the + * progressive decoder (jdphuff.c). No other modules need to see these. + */ + +/* Short forms of external names for systems with brain-damaged linkers. */ + +#ifdef NEED_SHORT_EXTERNAL_NAMES +#define jpeg_make_d_derived_tbl jMkDDerived +#define jpeg_fill_bit_buffer jFilBitBuf +#define jpeg_huff_decode jHufDecode +#endif /* NEED_SHORT_EXTERNAL_NAMES */ + + +/* Derived data constructed for each Huffman table */ + +#define HUFF_LOOKAHEAD 8 /* # of bits of lookahead */ + +typedef struct { + /* Basic tables: (element [0] of each array is unused) */ + long maxcode[18]; /* largest code of length k (-1 if none) */ + /* (maxcode[17] is a sentinel to ensure jpeg_huff_decode terminates) */ + long valoffset[17]; /* huffval[] offset for codes of length k */ + /* valoffset[k] = huffval[] index of 1st symbol of code length k, less + * the smallest code of length k; so given a code of length k, the + * corresponding symbol is huffval[code + valoffset[k]] + */ + + /* Link to public Huffman table (needed only in jpeg_huff_decode) */ + JHUFF_TBL *pub; + + /* Lookahead tables: indexed by the next HUFF_LOOKAHEAD bits of + * the input data stream. If the next Huffman code is no more + * than HUFF_LOOKAHEAD bits long, we can obtain its length and + * the corresponding symbol directly from these tables. + */ + int look_nbits[1< 32 bits on your machine, and shifting/masking longs is + * reasonably fast, making bit_buf_type be long and setting BIT_BUF_SIZE + * appropriately should be a win. Unfortunately we can't define the size + * with something like #define BIT_BUF_SIZE (sizeof(bit_buf_type)*8) + * because not all machines measure sizeof in 8-bit bytes. + */ + +typedef struct { /* Bitreading state saved across MCUs */ + bit_buf_type get_buffer; /* current bit-extraction buffer */ + int bits_left; /* # of unused bits in it */ +} bitread_perm_state; + +typedef struct { /* Bitreading working state within an MCU */ + /* Current data source location */ + /* We need a copy, rather than munging the original, in case of suspension */ + const JOCTET * next_input_byte; /* => next byte to read from source */ + size_t bytes_in_buffer; /* # of bytes remaining in source buffer */ + /* Bit input buffer --- note these values are kept in variables, + * not in this struct, inside the inner loops. + */ + bit_buf_type get_buffer; /* current bit-extraction buffer */ + int bits_left; /* # of unused bits in it */ + /* Pointer needed by jpeg_fill_bit_buffer. */ + j_decompress_ptr cinfo; /* back link to decompress master record */ +} bitread_working_state; + +/* Macros to declare and load/save bitread local variables. */ +#define BITREAD_STATE_VARS \ + bit_buf_type get_buffer; \ + int bits_left; \ + bitread_working_state br_state + +#define BITREAD_LOAD_STATE(cinfop,permstate) \ + br_state.cinfo = cinfop; \ + br_state.next_input_byte = cinfop->src->next_input_byte; \ + br_state.bytes_in_buffer = cinfop->src->bytes_in_buffer; \ + get_buffer = permstate.get_buffer; \ + bits_left = permstate.bits_left; + +#define BITREAD_SAVE_STATE(cinfop,permstate) \ + cinfop->src->next_input_byte = br_state.next_input_byte; \ + cinfop->src->bytes_in_buffer = br_state.bytes_in_buffer; \ + permstate.get_buffer = get_buffer; \ + permstate.bits_left = bits_left + +/* + * These macros provide the in-line portion of bit fetching. + * Use CHECK_BIT_BUFFER to ensure there are N bits in get_buffer + * before using GET_BITS, PEEK_BITS, or DROP_BITS. + * The variables get_buffer and bits_left are assumed to be locals, + * but the state struct might not be (jpeg_huff_decode needs this). + * CHECK_BIT_BUFFER(state,n,action); + * Ensure there are N bits in get_buffer; if suspend, take action. + * val = GET_BITS(n); + * Fetch next N bits. + * val = PEEK_BITS(n); + * Fetch next N bits without removing them from the buffer. + * DROP_BITS(n); + * Discard next N bits. + * The value N should be a simple variable, not an expression, because it + * is evaluated multiple times. + */ + +#define CHECK_BIT_BUFFER(state,nbits,action) \ + { if (bits_left < (nbits)) { \ + if (! jpeg_fill_bit_buffer(&(state),get_buffer,bits_left,nbits)) \ + { action; } \ + get_buffer = (state).get_buffer; bits_left = (state).bits_left; } } + +#define GET_BITS(nbits) \ + (((int) (get_buffer >> (bits_left -= (nbits)))) & ((1<<(nbits))-1)) + +#define PEEK_BITS(nbits) \ + (((int) (get_buffer >> (bits_left - (nbits)))) & ((1<<(nbits))-1)) + +#define DROP_BITS(nbits) \ + (bits_left -= (nbits)) + +/* Load up the bit buffer to a depth of at least nbits */ +EXTERN(int) jpeg_fill_bit_buffer + JPP((bitread_working_state * state, bit_buf_type get_buffer, + int bits_left, int nbits)); + + +/* + * Code for extracting next Huffman-coded symbol from input bit stream. + * Again, this is time-critical and we make the main paths be macros. + * + * We use a lookahead table to process codes of up to HUFF_LOOKAHEAD bits + * without looping. Usually, more than 95% of the Huffman codes will be 8 + * or fewer bits long. The few overlength codes are handled with a loop, + * which need not be inline code. + * + * Notes about the HUFF_DECODE macro: + * 1. Near the end of the data segment, we may fail to get enough bits + * for a lookahead. In that case, we do it the hard way. + * 2. If the lookahead table contains no entry, the next code must be + * more than HUFF_LOOKAHEAD bits long. + * 3. jpeg_huff_decode returns -1 if forced to suspend. + */ + +#define HUFF_DECODE(result,state,htbl,failaction,slowlabel) \ +{ int nb, look; \ + if (bits_left < HUFF_LOOKAHEAD) { \ + if (! jpeg_fill_bit_buffer(&state,get_buffer,bits_left, 0)) {failaction;} \ + get_buffer = state.get_buffer; bits_left = state.bits_left; \ + if (bits_left < HUFF_LOOKAHEAD) { \ + nb = 1; goto slowlabel; \ + } \ + } \ + look = PEEK_BITS(HUFF_LOOKAHEAD); \ + if ((nb = htbl->look_nbits[look]) != 0) { \ + DROP_BITS(nb); \ + result = htbl->look_sym[look]; \ + } else { \ + nb = HUFF_LOOKAHEAD+1; \ +slowlabel: \ + if ((result=jpeg_huff_decode(&state,get_buffer,bits_left,htbl,nb)) < 0) \ + { failaction; } \ + get_buffer = state.get_buffer; bits_left = state.bits_left; \ + } \ +} + +/* Out-of-line case for Huffman code fetching */ +EXTERN(int) jpeg_huff_decode + JPP((bitread_working_state * state, bit_buf_type get_buffer, + int bits_left, d_derived_tbl * htbl, int min_bits)); diff --git a/ml/dlib/dlib/external/libjpeg/jdinput.cpp b/ml/dlib/dlib/external/libjpeg/jdinput.cpp new file mode 100644 index 000000000..42f79977d --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdinput.cpp @@ -0,0 +1,381 @@ +/* + * jdinput.c + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains input control logic for the JPEG decompressor. + * These routines are concerned with controlling the decompressor's input + * processing (marker reading and coefficient decoding). The actual input + * reading is done in jdmarker.c, jdhuff.c, and jdphuff.c. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Private state */ + +typedef struct { + struct jpeg_input_controller pub; /* public fields */ + + int inheaders; /* TRUE until first SOS is reached */ +} my_input_controller; + +typedef my_input_controller * my_inputctl_ptr; + + +/* Forward declarations */ +METHODDEF(int) consume_markers JPP((j_decompress_ptr cinfo)); + + +/* + * Routines to calculate various quantities related to the size of the image. + */ + +LOCAL(void) +initial_setup (j_decompress_ptr cinfo) +/* Called once, when first SOS marker is reached */ +{ + int ci; + jpeg_component_info *compptr; + + /* Make sure image isn't bigger than I can handle */ + if ((long) cinfo->image_height > (long) JPEG_MAX_DIMENSION || + (long) cinfo->image_width > (long) JPEG_MAX_DIMENSION) + ERREXIT1(cinfo, JERR_IMAGE_TOO_BIG, (unsigned int) JPEG_MAX_DIMENSION); + + /* For now, precision must match compiled-in value... */ + if (cinfo->data_precision != BITS_IN_JSAMPLE) + ERREXIT1(cinfo, JERR_BAD_PRECISION, cinfo->data_precision); + + /* Check that number of components won't exceed internal array sizes */ + if (cinfo->num_components > MAX_COMPONENTS) + ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->num_components, + MAX_COMPONENTS); + + /* Compute maximum sampling factors; check factor validity */ + cinfo->max_h_samp_factor = 1; + cinfo->max_v_samp_factor = 1; + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + if (compptr->h_samp_factor<=0 || compptr->h_samp_factor>MAX_SAMP_FACTOR || + compptr->v_samp_factor<=0 || compptr->v_samp_factor>MAX_SAMP_FACTOR) + ERREXIT(cinfo, JERR_BAD_SAMPLING); + cinfo->max_h_samp_factor = MAX(cinfo->max_h_samp_factor, + compptr->h_samp_factor); + cinfo->max_v_samp_factor = MAX(cinfo->max_v_samp_factor, + compptr->v_samp_factor); + } + + /* We initialize DCT_scaled_size and min_DCT_scaled_size to DCTSIZE. + * In the full decompressor, this will be overridden by jdmaster.c; + * but in the transcoder, jdmaster.c is not used, so we must do it here. + */ + cinfo->min_DCT_scaled_size = DCTSIZE; + + /* Compute dimensions of components */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + compptr->DCT_scaled_size = DCTSIZE; + /* Size in DCT blocks */ + compptr->width_in_blocks = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width * (long) compptr->h_samp_factor, + (long) (cinfo->max_h_samp_factor * DCTSIZE)); + compptr->height_in_blocks = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height * (long) compptr->v_samp_factor, + (long) (cinfo->max_v_samp_factor * DCTSIZE)); + /* downsampled_width and downsampled_height will also be overridden by + * jdmaster.c if we are doing full decompression. The transcoder library + * doesn't use these values, but the calling application might. + */ + /* Size in samples */ + compptr->downsampled_width = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width * (long) compptr->h_samp_factor, + (long) cinfo->max_h_samp_factor); + compptr->downsampled_height = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height * (long) compptr->v_samp_factor, + (long) cinfo->max_v_samp_factor); + /* Mark component needed, until color conversion says otherwise */ + compptr->component_needed = TRUE; + /* Mark no quantization table yet saved for component */ + compptr->quant_table = NULL; + } + + /* Compute number of fully interleaved MCU rows. */ + cinfo->total_iMCU_rows = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height, + (long) (cinfo->max_v_samp_factor*DCTSIZE)); + + /* Decide whether file contains multiple scans */ + if (cinfo->comps_in_scan < cinfo->num_components || cinfo->progressive_mode) + cinfo->inputctl->has_multiple_scans = TRUE; + else + cinfo->inputctl->has_multiple_scans = FALSE; +} + + +LOCAL(void) +per_scan_setup (j_decompress_ptr cinfo) +/* Do computations that are needed before processing a JPEG scan */ +/* cinfo->comps_in_scan and cinfo->cur_comp_info[] were set from SOS marker */ +{ + int ci, mcublks, tmp; + jpeg_component_info *compptr; + + if (cinfo->comps_in_scan == 1) { + + /* Noninterleaved (single-component) scan */ + compptr = cinfo->cur_comp_info[0]; + + /* Overall image size in MCUs */ + cinfo->MCUs_per_row = compptr->width_in_blocks; + cinfo->MCU_rows_in_scan = compptr->height_in_blocks; + + /* For noninterleaved scan, always one block per MCU */ + compptr->MCU_width = 1; + compptr->MCU_height = 1; + compptr->MCU_blocks = 1; + compptr->MCU_sample_width = compptr->DCT_scaled_size; + compptr->last_col_width = 1; + /* For noninterleaved scans, it is convenient to define last_row_height + * as the number of block rows present in the last iMCU row. + */ + tmp = (int) (compptr->height_in_blocks % compptr->v_samp_factor); + if (tmp == 0) tmp = compptr->v_samp_factor; + compptr->last_row_height = tmp; + + /* Prepare array describing MCU composition */ + cinfo->blocks_in_MCU = 1; + cinfo->MCU_membership[0] = 0; + + } else { + + /* Interleaved (multi-component) scan */ + if (cinfo->comps_in_scan <= 0 || cinfo->comps_in_scan > MAX_COMPS_IN_SCAN) + ERREXIT2(cinfo, JERR_COMPONENT_COUNT, cinfo->comps_in_scan, + MAX_COMPS_IN_SCAN); + + /* Overall image size in MCUs */ + cinfo->MCUs_per_row = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width, + (long) (cinfo->max_h_samp_factor*DCTSIZE)); + cinfo->MCU_rows_in_scan = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height, + (long) (cinfo->max_v_samp_factor*DCTSIZE)); + + cinfo->blocks_in_MCU = 0; + + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + /* Sampling factors give # of blocks of component in each MCU */ + compptr->MCU_width = compptr->h_samp_factor; + compptr->MCU_height = compptr->v_samp_factor; + compptr->MCU_blocks = compptr->MCU_width * compptr->MCU_height; + compptr->MCU_sample_width = compptr->MCU_width * compptr->DCT_scaled_size; + /* Figure number of non-dummy blocks in last MCU column & row */ + tmp = (int) (compptr->width_in_blocks % compptr->MCU_width); + if (tmp == 0) tmp = compptr->MCU_width; + compptr->last_col_width = tmp; + tmp = (int) (compptr->height_in_blocks % compptr->MCU_height); + if (tmp == 0) tmp = compptr->MCU_height; + compptr->last_row_height = tmp; + /* Prepare array describing MCU composition */ + mcublks = compptr->MCU_blocks; + if (cinfo->blocks_in_MCU + mcublks > D_MAX_BLOCKS_IN_MCU) + ERREXIT(cinfo, JERR_BAD_MCU_SIZE); + while (mcublks-- > 0) { + cinfo->MCU_membership[cinfo->blocks_in_MCU++] = ci; + } + } + + } +} + + +/* + * Save away a copy of the Q-table referenced by each component present + * in the current scan, unless already saved during a prior scan. + * + * In a multiple-scan JPEG file, the encoder could assign different components + * the same Q-table slot number, but change table definitions between scans + * so that each component uses a different Q-table. (The IJG encoder is not + * currently capable of doing this, but other encoders might.) Since we want + * to be able to dequantize all the components at the end of the file, this + * means that we have to save away the table actually used for each component. + * We do this by copying the table at the start of the first scan containing + * the component. + * The JPEG spec prohibits the encoder from changing the contents of a Q-table + * slot between scans of a component using that slot. If the encoder does so + * anyway, this decoder will simply use the Q-table values that were current + * at the start of the first scan for the component. + * + * The decompressor output side looks only at the saved quant tables, + * not at the current Q-table slots. + */ + +LOCAL(void) +latch_quant_tables (j_decompress_ptr cinfo) +{ + int ci, qtblno; + jpeg_component_info *compptr; + JQUANT_TBL * qtbl; + + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + /* No work if we already saved Q-table for this component */ + if (compptr->quant_table != NULL) + continue; + /* Make sure specified quantization table is present */ + qtblno = compptr->quant_tbl_no; + if (qtblno < 0 || qtblno >= NUM_QUANT_TBLS || + cinfo->quant_tbl_ptrs[qtblno] == NULL) + ERREXIT1(cinfo, JERR_NO_QUANT_TABLE, qtblno); + /* OK, save away the quantization table */ + qtbl = (JQUANT_TBL *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(JQUANT_TBL)); + MEMCOPY(qtbl, cinfo->quant_tbl_ptrs[qtblno], SIZEOF(JQUANT_TBL)); + compptr->quant_table = qtbl; + } +} + + +/* + * Initialize the input modules to read a scan of compressed data. + * The first call to this is done by jdmaster.c after initializing + * the entire decompressor (during jpeg_start_decompress). + * Subsequent calls come from consume_markers, below. + */ + +METHODDEF(void) +start_input_pass (j_decompress_ptr cinfo) +{ + per_scan_setup(cinfo); + latch_quant_tables(cinfo); + (*cinfo->entropy->start_pass) (cinfo); + (*cinfo->coef->start_input_pass) (cinfo); + cinfo->inputctl->consume_input = cinfo->coef->consume_data; +} + + +/* + * Finish up after inputting a compressed-data scan. + * This is called by the coefficient controller after it's read all + * the expected data of the scan. + */ + +METHODDEF(void) +finish_input_pass (j_decompress_ptr cinfo) +{ + cinfo->inputctl->consume_input = consume_markers; +} + + +/* + * Read JPEG markers before, between, or after compressed-data scans. + * Change state as necessary when a new scan is reached. + * Return value is JPEG_SUSPENDED, JPEG_REACHED_SOS, or JPEG_REACHED_EOI. + * + * The consume_input method pointer points either here or to the + * coefficient controller's consume_data routine, depending on whether + * we are reading a compressed data segment or inter-segment markers. + */ + +METHODDEF(int) +consume_markers (j_decompress_ptr cinfo) +{ + my_inputctl_ptr inputctl = (my_inputctl_ptr) cinfo->inputctl; + int val; + + if (inputctl->pub.eoi_reached) /* After hitting EOI, read no further */ + return JPEG_REACHED_EOI; + + val = (*cinfo->marker->read_markers) (cinfo); + + switch (val) { + case JPEG_REACHED_SOS: /* Found SOS */ + if (inputctl->inheaders) { /* 1st SOS */ + initial_setup(cinfo); + inputctl->inheaders = FALSE; + /* Note: start_input_pass must be called by jdmaster.c + * before any more input can be consumed. jdapimin.c is + * responsible for enforcing this sequencing. + */ + } else { /* 2nd or later SOS marker */ + if (! inputctl->pub.has_multiple_scans) + ERREXIT(cinfo, JERR_EOI_EXPECTED); /* Oops, I wasn't expecting this! */ + start_input_pass(cinfo); + } + break; + case JPEG_REACHED_EOI: /* Found EOI */ + inputctl->pub.eoi_reached = TRUE; + if (inputctl->inheaders) { /* Tables-only datastream, apparently */ + if (cinfo->marker->saw_SOF) + ERREXIT(cinfo, JERR_SOF_NO_SOS); + } else { + /* Prevent infinite loop in coef ctlr's decompress_data routine + * if user set output_scan_number larger than number of scans. + */ + if (cinfo->output_scan_number > cinfo->input_scan_number) + cinfo->output_scan_number = cinfo->input_scan_number; + } + break; + case JPEG_SUSPENDED: + break; + } + + return val; +} + + +/* + * Reset state to begin a fresh datastream. + */ + +METHODDEF(void) +reset_input_controller (j_decompress_ptr cinfo) +{ + my_inputctl_ptr inputctl = (my_inputctl_ptr) cinfo->inputctl; + + inputctl->pub.consume_input = consume_markers; + inputctl->pub.has_multiple_scans = FALSE; /* "unknown" would be better */ + inputctl->pub.eoi_reached = FALSE; + inputctl->inheaders = TRUE; + /* Reset other modules */ + (*cinfo->err->reset_error_mgr) ((j_common_ptr) cinfo); + (*cinfo->marker->reset_marker_reader) (cinfo); + /* Reset progression state -- would be cleaner if entropy decoder did this */ + cinfo->coef_bits = NULL; +} + + +/* + * Initialize the input controller module. + * This is called only once, when the decompression object is created. + */ + +GLOBAL(void) +jinit_input_controller (j_decompress_ptr cinfo) +{ + my_inputctl_ptr inputctl; + + /* Create subobject in permanent pool */ + inputctl = (my_inputctl_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT, + SIZEOF(my_input_controller)); + cinfo->inputctl = (struct jpeg_input_controller *) inputctl; + /* Initialize method pointers */ + inputctl->pub.consume_input = consume_markers; + inputctl->pub.reset_input_controller = reset_input_controller; + inputctl->pub.start_input_pass = start_input_pass; + inputctl->pub.finish_input_pass = finish_input_pass; + /* Initialize state: can't use reset_input_controller since we don't + * want to try to reset other modules yet. + */ + inputctl->pub.has_multiple_scans = FALSE; /* "unknown" would be better */ + inputctl->pub.eoi_reached = FALSE; + inputctl->inheaders = TRUE; +} diff --git a/ml/dlib/dlib/external/libjpeg/jdmainct.cpp b/ml/dlib/dlib/external/libjpeg/jdmainct.cpp new file mode 100644 index 000000000..bc2c378aa --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdmainct.cpp @@ -0,0 +1,512 @@ +/* + * jdmainct.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains the main buffer controller for decompression. + * The main buffer lies between the JPEG decompressor proper and the + * post-processor; it holds downsampled data in the JPEG colorspace. + * + * Note that this code is bypassed in raw-data mode, since the application + * supplies the equivalent of the main buffer in that case. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* + * In the current system design, the main buffer need never be a full-image + * buffer; any full-height buffers will be found inside the coefficient or + * postprocessing controllers. Nonetheless, the main controller is not + * trivial. Its responsibility is to provide context rows for upsampling/ + * rescaling, and doing this in an efficient fashion is a bit tricky. + * + * Postprocessor input data is counted in "row groups". A row group + * is defined to be (v_samp_factor * DCT_scaled_size / min_DCT_scaled_size) + * sample rows of each component. (We require DCT_scaled_size values to be + * chosen such that these numbers are integers. In practice DCT_scaled_size + * values will likely be powers of two, so we actually have the stronger + * condition that DCT_scaled_size / min_DCT_scaled_size is an integer.) + * Upsampling will typically produce max_v_samp_factor pixel rows from each + * row group (times any additional scale factor that the upsampler is + * applying). + * + * The coefficient controller will deliver data to us one iMCU row at a time; + * each iMCU row contains v_samp_factor * DCT_scaled_size sample rows, or + * exactly min_DCT_scaled_size row groups. (This amount of data corresponds + * to one row of MCUs when the image is fully interleaved.) Note that the + * number of sample rows varies across components, but the number of row + * groups does not. Some garbage sample rows may be included in the last iMCU + * row at the bottom of the image. + * + * Depending on the vertical scaling algorithm used, the upsampler may need + * access to the sample row(s) above and below its current input row group. + * The upsampler is required to set need_context_rows TRUE at global selection + * time if so. When need_context_rows is FALSE, this controller can simply + * obtain one iMCU row at a time from the coefficient controller and dole it + * out as row groups to the postprocessor. + * + * When need_context_rows is TRUE, this controller guarantees that the buffer + * passed to postprocessing contains at least one row group's worth of samples + * above and below the row group(s) being processed. Note that the context + * rows "above" the first passed row group appear at negative row offsets in + * the passed buffer. At the top and bottom of the image, the required + * context rows are manufactured by duplicating the first or last real sample + * row; this avoids having special cases in the upsampling inner loops. + * + * The amount of context is fixed at one row group just because that's a + * convenient number for this controller to work with. The existing + * upsamplers really only need one sample row of context. An upsampler + * supporting arbitrary output rescaling might wish for more than one row + * group of context when shrinking the image; tough, we don't handle that. + * (This is justified by the assumption that downsizing will be handled mostly + * by adjusting the DCT_scaled_size values, so that the actual scale factor at + * the upsample step needn't be much less than one.) + * + * To provide the desired context, we have to retain the last two row groups + * of one iMCU row while reading in the next iMCU row. (The last row group + * can't be processed until we have another row group for its below-context, + * and so we have to save the next-to-last group too for its above-context.) + * We could do this most simply by copying data around in our buffer, but + * that'd be very slow. We can avoid copying any data by creating a rather + * strange pointer structure. Here's how it works. We allocate a workspace + * consisting of M+2 row groups (where M = min_DCT_scaled_size is the number + * of row groups per iMCU row). We create two sets of redundant pointers to + * the workspace. Labeling the physical row groups 0 to M+1, the synthesized + * pointer lists look like this: + * M+1 M-1 + * master pointer --> 0 master pointer --> 0 + * 1 1 + * ... ... + * M-3 M-3 + * M-2 M + * M-1 M+1 + * M M-2 + * M+1 M-1 + * 0 0 + * We read alternate iMCU rows using each master pointer; thus the last two + * row groups of the previous iMCU row remain un-overwritten in the workspace. + * The pointer lists are set up so that the required context rows appear to + * be adjacent to the proper places when we pass the pointer lists to the + * upsampler. + * + * The above pictures describe the normal state of the pointer lists. + * At top and bottom of the image, we diddle the pointer lists to duplicate + * the first or last sample row as necessary (this is cheaper than copying + * sample rows around). + * + * This scheme breaks down if M < 2, ie, min_DCT_scaled_size is 1. In that + * situation each iMCU row provides only one row group so the buffering logic + * must be different (eg, we must read two iMCU rows before we can emit the + * first row group). For now, we simply do not support providing context + * rows when min_DCT_scaled_size is 1. That combination seems unlikely to + * be worth providing --- if someone wants a 1/8th-size preview, they probably + * want it quick and dirty, so a context-free upsampler is sufficient. + */ + + +/* Private buffer controller object */ + +typedef struct { + struct jpeg_d_main_controller pub; /* public fields */ + + /* Pointer to allocated workspace (M or M+2 row groups). */ + JSAMPARRAY buffer[MAX_COMPONENTS]; + + int buffer_full; /* Have we gotten an iMCU row from decoder? */ + JDIMENSION rowgroup_ctr; /* counts row groups output to postprocessor */ + + /* Remaining fields are only used in the context case. */ + + /* These are the master pointers to the funny-order pointer lists. */ + JSAMPIMAGE xbuffer[2]; /* pointers to weird pointer lists */ + + int whichptr; /* indicates which pointer set is now in use */ + int context_state; /* process_data state machine status */ + JDIMENSION rowgroups_avail; /* row groups available to postprocessor */ + JDIMENSION iMCU_row_ctr; /* counts iMCU rows to detect image top/bot */ +} my_main_controller; + +typedef my_main_controller * my_main_ptr; + +/* context_state values: */ +#define CTX_PREPARE_FOR_IMCU 0 /* need to prepare for MCU row */ +#define CTX_PROCESS_IMCU 1 /* feeding iMCU to postprocessor */ +#define CTX_POSTPONED_ROW 2 /* feeding postponed row group */ + + +/* Forward declarations */ +METHODDEF(void) process_data_simple_main + JPP((j_decompress_ptr cinfo, JSAMPARRAY output_buf, + JDIMENSION *out_row_ctr, JDIMENSION out_rows_avail)); +METHODDEF(void) process_data_context_main + JPP((j_decompress_ptr cinfo, JSAMPARRAY output_buf, + JDIMENSION *out_row_ctr, JDIMENSION out_rows_avail)); +#ifdef QUANT_2PASS_SUPPORTED +METHODDEF(void) process_data_crank_post + JPP((j_decompress_ptr cinfo, JSAMPARRAY output_buf, + JDIMENSION *out_row_ctr, JDIMENSION out_rows_avail)); +#endif + + +LOCAL(void) +alloc_funny_pointers (j_decompress_ptr cinfo) +/* Allocate space for the funny pointer lists. + * This is done only once, not once per pass. + */ +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + int ci, rgroup; + int M = cinfo->min_DCT_scaled_size; + jpeg_component_info *compptr; + JSAMPARRAY xbuf; + + /* Get top-level space for component array pointers. + * We alloc both arrays with one call to save a few cycles. + */ + main->xbuffer[0] = (JSAMPIMAGE) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + cinfo->num_components * 2 * SIZEOF(JSAMPARRAY)); + main->xbuffer[1] = main->xbuffer[0] + cinfo->num_components; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + rgroup = (compptr->v_samp_factor * compptr->DCT_scaled_size) / + cinfo->min_DCT_scaled_size; /* height of a row group of component */ + /* Get space for pointer lists --- M+4 row groups in each list. + * We alloc both pointer lists with one call to save a few cycles. + */ + xbuf = (JSAMPARRAY) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + 2 * (rgroup * (M + 4)) * SIZEOF(JSAMPROW)); + xbuf += rgroup; /* want one row group at negative offsets */ + main->xbuffer[0][ci] = xbuf; + xbuf += rgroup * (M + 4); + main->xbuffer[1][ci] = xbuf; + } +} + + +LOCAL(void) +make_funny_pointers (j_decompress_ptr cinfo) +/* Create the funny pointer lists discussed in the comments above. + * The actual workspace is already allocated (in main->buffer), + * and the space for the pointer lists is allocated too. + * This routine just fills in the curiously ordered lists. + * This will be repeated at the beginning of each pass. + */ +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + int ci, i, rgroup; + int M = cinfo->min_DCT_scaled_size; + jpeg_component_info *compptr; + JSAMPARRAY buf, xbuf0, xbuf1; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + rgroup = (compptr->v_samp_factor * compptr->DCT_scaled_size) / + cinfo->min_DCT_scaled_size; /* height of a row group of component */ + xbuf0 = main->xbuffer[0][ci]; + xbuf1 = main->xbuffer[1][ci]; + /* First copy the workspace pointers as-is */ + buf = main->buffer[ci]; + for (i = 0; i < rgroup * (M + 2); i++) { + xbuf0[i] = xbuf1[i] = buf[i]; + } + /* In the second list, put the last four row groups in swapped order */ + for (i = 0; i < rgroup * 2; i++) { + xbuf1[rgroup*(M-2) + i] = buf[rgroup*M + i]; + xbuf1[rgroup*M + i] = buf[rgroup*(M-2) + i]; + } + /* The wraparound pointers at top and bottom will be filled later + * (see set_wraparound_pointers, below). Initially we want the "above" + * pointers to duplicate the first actual data line. This only needs + * to happen in xbuffer[0]. + */ + for (i = 0; i < rgroup; i++) { + xbuf0[i - rgroup] = xbuf0[0]; + } + } +} + + +LOCAL(void) +set_wraparound_pointers (j_decompress_ptr cinfo) +/* Set up the "wraparound" pointers at top and bottom of the pointer lists. + * This changes the pointer list state from top-of-image to the normal state. + */ +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + int ci, i, rgroup; + int M = cinfo->min_DCT_scaled_size; + jpeg_component_info *compptr; + JSAMPARRAY xbuf0, xbuf1; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + rgroup = (compptr->v_samp_factor * compptr->DCT_scaled_size) / + cinfo->min_DCT_scaled_size; /* height of a row group of component */ + xbuf0 = main->xbuffer[0][ci]; + xbuf1 = main->xbuffer[1][ci]; + for (i = 0; i < rgroup; i++) { + xbuf0[i - rgroup] = xbuf0[rgroup*(M+1) + i]; + xbuf1[i - rgroup] = xbuf1[rgroup*(M+1) + i]; + xbuf0[rgroup*(M+2) + i] = xbuf0[i]; + xbuf1[rgroup*(M+2) + i] = xbuf1[i]; + } + } +} + + +LOCAL(void) +set_bottom_pointers (j_decompress_ptr cinfo) +/* Change the pointer lists to duplicate the last sample row at the bottom + * of the image. whichptr indicates which xbuffer holds the final iMCU row. + * Also sets rowgroups_avail to indicate number of nondummy row groups in row. + */ +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + int ci, i, rgroup, iMCUheight, rows_left; + jpeg_component_info *compptr; + JSAMPARRAY xbuf; + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Count sample rows in one iMCU row and in one row group */ + iMCUheight = compptr->v_samp_factor * compptr->DCT_scaled_size; + rgroup = iMCUheight / cinfo->min_DCT_scaled_size; + /* Count nondummy sample rows remaining for this component */ + rows_left = (int) (compptr->downsampled_height % (JDIMENSION) iMCUheight); + if (rows_left == 0) rows_left = iMCUheight; + /* Count nondummy row groups. Should get same answer for each component, + * so we need only do it once. + */ + if (ci == 0) { + main->rowgroups_avail = (JDIMENSION) ((rows_left-1) / rgroup + 1); + } + /* Duplicate the last real sample row rgroup*2 times; this pads out the + * last partial rowgroup and ensures at least one full rowgroup of context. + */ + xbuf = main->xbuffer[main->whichptr][ci]; + for (i = 0; i < rgroup * 2; i++) { + xbuf[rows_left + i] = xbuf[rows_left-1]; + } + } +} + + +/* + * Initialize for a processing pass. + */ + +METHODDEF(void) +start_pass_main (j_decompress_ptr cinfo, J_BUF_MODE pass_mode) +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + + switch (pass_mode) { + case JBUF_PASS_THRU: + if (cinfo->upsample->need_context_rows) { + main->pub.process_data = process_data_context_main; + make_funny_pointers(cinfo); /* Create the xbuffer[] lists */ + main->whichptr = 0; /* Read first iMCU row into xbuffer[0] */ + main->context_state = CTX_PREPARE_FOR_IMCU; + main->iMCU_row_ctr = 0; + } else { + /* Simple case with no context needed */ + main->pub.process_data = process_data_simple_main; + } + main->buffer_full = FALSE; /* Mark buffer empty */ + main->rowgroup_ctr = 0; + break; +#ifdef QUANT_2PASS_SUPPORTED + case JBUF_CRANK_DEST: + /* For last pass of 2-pass quantization, just crank the postprocessor */ + main->pub.process_data = process_data_crank_post; + break; +#endif + default: + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + break; + } +} + + +/* + * Process some data. + * This handles the simple case where no context is required. + */ + +METHODDEF(void) +process_data_simple_main (j_decompress_ptr cinfo, + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail) +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + JDIMENSION rowgroups_avail; + + /* Read input data if we haven't filled the main buffer yet */ + if (! main->buffer_full) { + if (! (*cinfo->coef->decompress_data) (cinfo, main->buffer)) + return; /* suspension forced, can do nothing more */ + main->buffer_full = TRUE; /* OK, we have an iMCU row to work with */ + } + + /* There are always min_DCT_scaled_size row groups in an iMCU row. */ + rowgroups_avail = (JDIMENSION) cinfo->min_DCT_scaled_size; + /* Note: at the bottom of the image, we may pass extra garbage row groups + * to the postprocessor. The postprocessor has to check for bottom + * of image anyway (at row resolution), so no point in us doing it too. + */ + + /* Feed the postprocessor */ + (*cinfo->post->post_process_data) (cinfo, main->buffer, + &main->rowgroup_ctr, rowgroups_avail, + output_buf, out_row_ctr, out_rows_avail); + + /* Has postprocessor consumed all the data yet? If so, mark buffer empty */ + if (main->rowgroup_ctr >= rowgroups_avail) { + main->buffer_full = FALSE; + main->rowgroup_ctr = 0; + } +} + + +/* + * Process some data. + * This handles the case where context rows must be provided. + */ + +METHODDEF(void) +process_data_context_main (j_decompress_ptr cinfo, + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail) +{ + my_main_ptr main = (my_main_ptr) cinfo->main; + + /* Read input data if we haven't filled the main buffer yet */ + if (! main->buffer_full) { + if (! (*cinfo->coef->decompress_data) (cinfo, + main->xbuffer[main->whichptr])) + return; /* suspension forced, can do nothing more */ + main->buffer_full = TRUE; /* OK, we have an iMCU row to work with */ + main->iMCU_row_ctr++; /* count rows received */ + } + + /* Postprocessor typically will not swallow all the input data it is handed + * in one call (due to filling the output buffer first). Must be prepared + * to exit and restart. This switch lets us keep track of how far we got. + * Note that each case falls through to the next on successful completion. + */ + switch (main->context_state) { + case CTX_POSTPONED_ROW: + /* Call postprocessor using previously set pointers for postponed row */ + (*cinfo->post->post_process_data) (cinfo, main->xbuffer[main->whichptr], + &main->rowgroup_ctr, main->rowgroups_avail, + output_buf, out_row_ctr, out_rows_avail); + if (main->rowgroup_ctr < main->rowgroups_avail) + return; /* Need to suspend */ + main->context_state = CTX_PREPARE_FOR_IMCU; + if (*out_row_ctr >= out_rows_avail) + return; /* Postprocessor exactly filled output buf */ + /*FALLTHROUGH*/ + case CTX_PREPARE_FOR_IMCU: + /* Prepare to process first M-1 row groups of this iMCU row */ + main->rowgroup_ctr = 0; + main->rowgroups_avail = (JDIMENSION) (cinfo->min_DCT_scaled_size - 1); + /* Check for bottom of image: if so, tweak pointers to "duplicate" + * the last sample row, and adjust rowgroups_avail to ignore padding rows. + */ + if (main->iMCU_row_ctr == cinfo->total_iMCU_rows) + set_bottom_pointers(cinfo); + main->context_state = CTX_PROCESS_IMCU; + /*FALLTHROUGH*/ + case CTX_PROCESS_IMCU: + /* Call postprocessor using previously set pointers */ + (*cinfo->post->post_process_data) (cinfo, main->xbuffer[main->whichptr], + &main->rowgroup_ctr, main->rowgroups_avail, + output_buf, out_row_ctr, out_rows_avail); + if (main->rowgroup_ctr < main->rowgroups_avail) + return; /* Need to suspend */ + /* After the first iMCU, change wraparound pointers to normal state */ + if (main->iMCU_row_ctr == 1) + set_wraparound_pointers(cinfo); + /* Prepare to load new iMCU row using other xbuffer list */ + main->whichptr ^= 1; /* 0=>1 or 1=>0 */ + main->buffer_full = FALSE; + /* Still need to process last row group of this iMCU row, */ + /* which is saved at index M+1 of the other xbuffer */ + main->rowgroup_ctr = (JDIMENSION) (cinfo->min_DCT_scaled_size + 1); + main->rowgroups_avail = (JDIMENSION) (cinfo->min_DCT_scaled_size + 2); + main->context_state = CTX_POSTPONED_ROW; + } +} + + +/* + * Process some data. + * Final pass of two-pass quantization: just call the postprocessor. + * Source data will be the postprocessor controller's internal buffer. + */ + +#ifdef QUANT_2PASS_SUPPORTED + +METHODDEF(void) +process_data_crank_post (j_decompress_ptr cinfo, + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail) +{ + (*cinfo->post->post_process_data) (cinfo, (JSAMPIMAGE) NULL, + (JDIMENSION *) NULL, (JDIMENSION) 0, + output_buf, out_row_ctr, out_rows_avail); +} + +#endif /* QUANT_2PASS_SUPPORTED */ + + +/* + * Initialize main buffer controller. + */ + +GLOBAL(void) +jinit_d_main_controller (j_decompress_ptr cinfo, int need_full_buffer) +{ + my_main_ptr main; + int ci, rgroup, ngroups; + jpeg_component_info *compptr; + + main = (my_main_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_main_controller)); + cinfo->main = (struct jpeg_d_main_controller *) main; + main->pub.start_pass = start_pass_main; + + if (need_full_buffer) /* shouldn't happen */ + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + + /* Allocate the workspace. + * ngroups is the number of row groups we need. + */ + if (cinfo->upsample->need_context_rows) { + if (cinfo->min_DCT_scaled_size < 2) /* unsupported, see comments above */ + ERREXIT(cinfo, JERR_NOTIMPL); + alloc_funny_pointers(cinfo); /* Alloc space for xbuffer[] lists */ + ngroups = cinfo->min_DCT_scaled_size + 2; + } else { + ngroups = cinfo->min_DCT_scaled_size; + } + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + rgroup = (compptr->v_samp_factor * compptr->DCT_scaled_size) / + cinfo->min_DCT_scaled_size; /* height of a row group of component */ + main->buffer[ci] = (*cinfo->mem->alloc_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + compptr->width_in_blocks * compptr->DCT_scaled_size, + (JDIMENSION) (rgroup * ngroups)); + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jdmarker.cpp b/ml/dlib/dlib/external/libjpeg/jdmarker.cpp new file mode 100644 index 000000000..c8c9f8e1c --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdmarker.cpp @@ -0,0 +1,1360 @@ +/* + * jdmarker.c + * + * Copyright (C) 1991-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains routines to decode JPEG datastream markers. + * Most of the complexity arises from our desire to support input + * suspension: if not all of the data for a marker is available, + * we must exit back to the application. On resumption, we reprocess + * the marker. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +typedef enum { /* JPEG marker codes */ + M_SOF0 = 0xc0, + M_SOF1 = 0xc1, + M_SOF2 = 0xc2, + M_SOF3 = 0xc3, + + M_SOF5 = 0xc5, + M_SOF6 = 0xc6, + M_SOF7 = 0xc7, + + M_JPG = 0xc8, + M_SOF9 = 0xc9, + M_SOF10 = 0xca, + M_SOF11 = 0xcb, + + M_SOF13 = 0xcd, + M_SOF14 = 0xce, + M_SOF15 = 0xcf, + + M_DHT = 0xc4, + + M_DAC = 0xcc, + + M_RST0 = 0xd0, + M_RST1 = 0xd1, + M_RST2 = 0xd2, + M_RST3 = 0xd3, + M_RST4 = 0xd4, + M_RST5 = 0xd5, + M_RST6 = 0xd6, + M_RST7 = 0xd7, + + M_SOI = 0xd8, + M_EOI = 0xd9, + M_SOS = 0xda, + M_DQT = 0xdb, + M_DNL = 0xdc, + M_DRI = 0xdd, + M_DHP = 0xde, + M_EXP = 0xdf, + + M_APP0 = 0xe0, + M_APP1 = 0xe1, + M_APP2 = 0xe2, + M_APP3 = 0xe3, + M_APP4 = 0xe4, + M_APP5 = 0xe5, + M_APP6 = 0xe6, + M_APP7 = 0xe7, + M_APP8 = 0xe8, + M_APP9 = 0xe9, + M_APP10 = 0xea, + M_APP11 = 0xeb, + M_APP12 = 0xec, + M_APP13 = 0xed, + M_APP14 = 0xee, + M_APP15 = 0xef, + + M_JPG0 = 0xf0, + M_JPG13 = 0xfd, + M_COM = 0xfe, + + M_TEM = 0x01, + + M_ERROR = 0x100 +} JPEG_MARKER; + + +/* Private state */ + +typedef struct { + struct jpeg_marker_reader pub; /* public fields */ + + /* Application-overridable marker processing methods */ + jpeg_marker_parser_method process_COM; + jpeg_marker_parser_method process_APPn[16]; + + /* Limit on marker data length to save for each marker type */ + unsigned int length_limit_COM; + unsigned int length_limit_APPn[16]; + + /* Status of COM/APPn marker saving */ + jpeg_saved_marker_ptr cur_marker; /* NULL if not processing a marker */ + unsigned int bytes_read; /* data bytes read so far in marker */ + /* Note: cur_marker is not linked into marker_list until it's all read. */ +} my_marker_reader; + +typedef my_marker_reader * my_marker_ptr; + + +/* + * Macros for fetching data from the data source module. + * + * At all times, cinfo->src->next_input_byte and ->bytes_in_buffer reflect + * the current restart point; we update them only when we have reached a + * suitable place to restart if a suspension occurs. + */ + +/* Declare and initialize local copies of input pointer/count */ +#define INPUT_VARS(cinfo) \ + struct jpeg_source_mgr * datasrc = (cinfo)->src; \ + const JOCTET * next_input_byte = datasrc->next_input_byte; \ + size_t bytes_in_buffer = datasrc->bytes_in_buffer + +/* Unload the local copies --- do this only at a restart boundary */ +#define INPUT_SYNC(cinfo) \ + ( datasrc->next_input_byte = next_input_byte, \ + datasrc->bytes_in_buffer = bytes_in_buffer ) + +/* Reload the local copies --- used only in MAKE_BYTE_AVAIL */ +#define INPUT_RELOAD(cinfo) \ + ( next_input_byte = datasrc->next_input_byte, \ + bytes_in_buffer = datasrc->bytes_in_buffer ) + +/* Internal macro for INPUT_BYTE and INPUT_2BYTES: make a byte available. + * Note we do *not* do INPUT_SYNC before calling fill_input_buffer, + * but we must reload the local copies after a successful fill. + */ +#define MAKE_BYTE_AVAIL(cinfo,action) \ + if (bytes_in_buffer == 0) { \ + if (! (*datasrc->fill_input_buffer) (cinfo)) \ + { action; } \ + INPUT_RELOAD(cinfo); \ + } + +/* Read a byte into variable V. + * If must suspend, take the specified action (typically "return FALSE"). + */ +#define INPUT_BYTE(cinfo,V,action) \ + MAKESTMT( MAKE_BYTE_AVAIL(cinfo,action); \ + bytes_in_buffer--; \ + V = GETJOCTET(*next_input_byte++); ) + +/* As above, but read two bytes interpreted as an unsigned 16-bit integer. + * V should be declared unsigned int or perhaps long. + */ +#define INPUT_2BYTES(cinfo,V,action) \ + MAKESTMT( MAKE_BYTE_AVAIL(cinfo,action); \ + bytes_in_buffer--; \ + V = ((unsigned int) GETJOCTET(*next_input_byte++)) << 8; \ + MAKE_BYTE_AVAIL(cinfo,action); \ + bytes_in_buffer--; \ + V += GETJOCTET(*next_input_byte++); ) + + +/* + * Routines to process JPEG markers. + * + * Entry condition: JPEG marker itself has been read and its code saved + * in cinfo->unread_marker; input restart point is just after the marker. + * + * Exit: if return TRUE, have read and processed any parameters, and have + * updated the restart point to point after the parameters. + * If return FALSE, was forced to suspend before reaching end of + * marker parameters; restart point has not been moved. Same routine + * will be called again after application supplies more input data. + * + * This approach to suspension assumes that all of a marker's parameters + * can fit into a single input bufferload. This should hold for "normal" + * markers. Some COM/APPn markers might have large parameter segments + * that might not fit. If we are simply dropping such a marker, we use + * skip_input_data to get past it, and thereby put the problem on the + * source manager's shoulders. If we are saving the marker's contents + * into memory, we use a slightly different convention: when forced to + * suspend, the marker processor updates the restart point to the end of + * what it's consumed (ie, the end of the buffer) before returning FALSE. + * On resumption, cinfo->unread_marker still contains the marker code, + * but the data source will point to the next chunk of marker data. + * The marker processor must retain internal state to deal with this. + * + * Note that we don't bother to avoid duplicate trace messages if a + * suspension occurs within marker parameters. Other side effects + * require more care. + */ + + +LOCAL(int) +get_soi (j_decompress_ptr cinfo) +/* Process an SOI marker */ +{ + int i; + + TRACEMS(cinfo, 1, JTRC_SOI); + + if (cinfo->marker->saw_SOI) + ERREXIT(cinfo, JERR_SOI_DUPLICATE); + + /* Reset all parameters that are defined to be reset by SOI */ + + for (i = 0; i < NUM_ARITH_TBLS; i++) { + cinfo->arith_dc_L[i] = 0; + cinfo->arith_dc_U[i] = 1; + cinfo->arith_ac_K[i] = 5; + } + cinfo->restart_interval = 0; + + /* Set initial assumptions for colorspace etc */ + + cinfo->jpeg_color_space = JCS_UNKNOWN; + cinfo->CCIR601_sampling = FALSE; /* Assume non-CCIR sampling??? */ + + cinfo->saw_JFIF_marker = FALSE; + cinfo->JFIF_major_version = 1; /* set default JFIF APP0 values */ + cinfo->JFIF_minor_version = 1; + cinfo->density_unit = 0; + cinfo->X_density = 1; + cinfo->Y_density = 1; + cinfo->saw_Adobe_marker = FALSE; + cinfo->Adobe_transform = 0; + + cinfo->marker->saw_SOI = TRUE; + + return TRUE; +} + + +LOCAL(int) +get_sof (j_decompress_ptr cinfo, int is_prog, int is_arith) +/* Process a SOFn marker */ +{ + long length; + int c, ci; + jpeg_component_info * compptr; + INPUT_VARS(cinfo); + + cinfo->progressive_mode = is_prog; + cinfo->arith_code = is_arith; + + INPUT_2BYTES(cinfo, length, return FALSE); + + INPUT_BYTE(cinfo, cinfo->data_precision, return FALSE); + INPUT_2BYTES(cinfo, cinfo->image_height, return FALSE); + INPUT_2BYTES(cinfo, cinfo->image_width, return FALSE); + INPUT_BYTE(cinfo, cinfo->num_components, return FALSE); + + length -= 8; + + TRACEMS4(cinfo, 1, JTRC_SOF, cinfo->unread_marker, + (int) cinfo->image_width, (int) cinfo->image_height, + cinfo->num_components); + + if (cinfo->marker->saw_SOF) + ERREXIT(cinfo, JERR_SOF_DUPLICATE); + + /* We don't support files in which the image height is initially specified */ + /* as 0 and is later redefined by DNL. As long as we have to check that, */ + /* might as well have a general sanity check. */ + if (cinfo->image_height <= 0 || cinfo->image_width <= 0 + || cinfo->num_components <= 0) + ERREXIT(cinfo, JERR_EMPTY_IMAGE); + + if (length != (cinfo->num_components * 3)) + ERREXIT(cinfo, JERR_BAD_LENGTH); + + if (cinfo->comp_info == NULL) /* do only once, even if suspend */ + cinfo->comp_info = (jpeg_component_info *) (*cinfo->mem->alloc_small) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + cinfo->num_components * SIZEOF(jpeg_component_info)); + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + compptr->component_index = ci; + INPUT_BYTE(cinfo, compptr->component_id, return FALSE); + INPUT_BYTE(cinfo, c, return FALSE); + compptr->h_samp_factor = (c >> 4) & 15; + compptr->v_samp_factor = (c ) & 15; + INPUT_BYTE(cinfo, compptr->quant_tbl_no, return FALSE); + + TRACEMS4(cinfo, 1, JTRC_SOF_COMPONENT, + compptr->component_id, compptr->h_samp_factor, + compptr->v_samp_factor, compptr->quant_tbl_no); + } + + cinfo->marker->saw_SOF = TRUE; + + INPUT_SYNC(cinfo); + return TRUE; +} + + +LOCAL(int) +get_sos (j_decompress_ptr cinfo) +/* Process a SOS marker */ +{ + long length; + int i, ci, n, c, cc; + jpeg_component_info * compptr; + INPUT_VARS(cinfo); + + if (! cinfo->marker->saw_SOF) + ERREXIT(cinfo, JERR_SOS_NO_SOF); + + INPUT_2BYTES(cinfo, length, return FALSE); + + INPUT_BYTE(cinfo, n, return FALSE); /* Number of components */ + + TRACEMS1(cinfo, 1, JTRC_SOS, n); + + if (length != (n * 2 + 6) || n < 1 || n > MAX_COMPS_IN_SCAN) + ERREXIT(cinfo, JERR_BAD_LENGTH); + + cinfo->comps_in_scan = n; + + /* Collect the component-spec parameters */ + + for (i = 0; i < n; i++) { + INPUT_BYTE(cinfo, cc, return FALSE); + INPUT_BYTE(cinfo, c, return FALSE); + + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + if (cc == compptr->component_id) + goto id_found; + } + + ERREXIT1(cinfo, JERR_BAD_COMPONENT_ID, cc); + + id_found: + + cinfo->cur_comp_info[i] = compptr; + compptr->dc_tbl_no = (c >> 4) & 15; + compptr->ac_tbl_no = (c ) & 15; + + TRACEMS3(cinfo, 1, JTRC_SOS_COMPONENT, cc, + compptr->dc_tbl_no, compptr->ac_tbl_no); + } + + /* Collect the additional scan parameters Ss, Se, Ah/Al. */ + INPUT_BYTE(cinfo, c, return FALSE); + cinfo->Ss = c; + INPUT_BYTE(cinfo, c, return FALSE); + cinfo->Se = c; + INPUT_BYTE(cinfo, c, return FALSE); + cinfo->Ah = (c >> 4) & 15; + cinfo->Al = (c ) & 15; + + TRACEMS4(cinfo, 1, JTRC_SOS_PARAMS, cinfo->Ss, cinfo->Se, + cinfo->Ah, cinfo->Al); + + /* Prepare to scan data & restart markers */ + cinfo->marker->next_restart_num = 0; + + /* Count another SOS marker */ + cinfo->input_scan_number++; + + INPUT_SYNC(cinfo); + return TRUE; +} + + +#ifdef D_ARITH_CODING_SUPPORTED + +LOCAL(int) +get_dac (j_decompress_ptr cinfo) +/* Process a DAC marker */ +{ + long length; + int index, val; + INPUT_VARS(cinfo); + + INPUT_2BYTES(cinfo, length, return FALSE); + length -= 2; + + while (length > 0) { + INPUT_BYTE(cinfo, index, return FALSE); + INPUT_BYTE(cinfo, val, return FALSE); + + length -= 2; + + TRACEMS2(cinfo, 1, JTRC_DAC, index, val); + + if (index < 0 || index >= (2*NUM_ARITH_TBLS)) + ERREXIT1(cinfo, JERR_DAC_INDEX, index); + + if (index >= NUM_ARITH_TBLS) { /* define AC table */ + cinfo->arith_ac_K[index-NUM_ARITH_TBLS] = (unsigned char) val; + } else { /* define DC table */ + cinfo->arith_dc_L[index] = (unsigned char) (val & 0x0F); + cinfo->arith_dc_U[index] = (unsigned char) (val >> 4); + if (cinfo->arith_dc_L[index] > cinfo->arith_dc_U[index]) + ERREXIT1(cinfo, JERR_DAC_VALUE, val); + } + } + + if (length != 0) + ERREXIT(cinfo, JERR_BAD_LENGTH); + + INPUT_SYNC(cinfo); + return TRUE; +} + +#else /* ! D_ARITH_CODING_SUPPORTED */ + +#define get_dac(cinfo) skip_variable(cinfo) + +#endif /* D_ARITH_CODING_SUPPORTED */ + + +LOCAL(int) +get_dht (j_decompress_ptr cinfo) +/* Process a DHT marker */ +{ + long length; + unsigned char bits[17]; + unsigned char huffval[256]; + int i, index, count; + JHUFF_TBL **htblptr; + INPUT_VARS(cinfo); + + INPUT_2BYTES(cinfo, length, return FALSE); + length -= 2; + + while (length > 16) { + INPUT_BYTE(cinfo, index, return FALSE); + + TRACEMS1(cinfo, 1, JTRC_DHT, index); + + bits[0] = 0; + count = 0; + for (i = 1; i <= 16; i++) { + INPUT_BYTE(cinfo, bits[i], return FALSE); + count += bits[i]; + } + + length -= 1 + 16; + + TRACEMS8(cinfo, 2, JTRC_HUFFBITS, + bits[1], bits[2], bits[3], bits[4], + bits[5], bits[6], bits[7], bits[8]); + TRACEMS8(cinfo, 2, JTRC_HUFFBITS, + bits[9], bits[10], bits[11], bits[12], + bits[13], bits[14], bits[15], bits[16]); + + /* Here we just do minimal validation of the counts to avoid walking + * off the end of our table space. jdhuff.c will check more carefully. + */ + if (count > 256 || ((long) count) > length) + ERREXIT(cinfo, JERR_BAD_HUFF_TABLE); + + for (i = 0; i < count; i++) + INPUT_BYTE(cinfo, huffval[i], return FALSE); + + length -= count; + + if (index & 0x10) { /* AC table definition */ + index -= 0x10; + htblptr = &cinfo->ac_huff_tbl_ptrs[index]; + } else { /* DC table definition */ + htblptr = &cinfo->dc_huff_tbl_ptrs[index]; + } + + if (index < 0 || index >= NUM_HUFF_TBLS) + ERREXIT1(cinfo, JERR_DHT_INDEX, index); + + if (*htblptr == NULL) + *htblptr = jpeg_alloc_huff_table((j_common_ptr) cinfo); + + MEMCOPY((*htblptr)->bits, bits, SIZEOF((*htblptr)->bits)); + MEMCOPY((*htblptr)->huffval, huffval, SIZEOF((*htblptr)->huffval)); + } + + if (length != 0) + ERREXIT(cinfo, JERR_BAD_LENGTH); + + INPUT_SYNC(cinfo); + return TRUE; +} + + +LOCAL(int) +get_dqt (j_decompress_ptr cinfo) +/* Process a DQT marker */ +{ + long length; + int n, i, prec; + unsigned int tmp; + JQUANT_TBL *quant_ptr; + INPUT_VARS(cinfo); + + INPUT_2BYTES(cinfo, length, return FALSE); + length -= 2; + + while (length > 0) { + INPUT_BYTE(cinfo, n, return FALSE); + prec = n >> 4; + n &= 0x0F; + + TRACEMS2(cinfo, 1, JTRC_DQT, n, prec); + + if (n >= NUM_QUANT_TBLS) + ERREXIT1(cinfo, JERR_DQT_INDEX, n); + + if (cinfo->quant_tbl_ptrs[n] == NULL) + cinfo->quant_tbl_ptrs[n] = jpeg_alloc_quant_table((j_common_ptr) cinfo); + quant_ptr = cinfo->quant_tbl_ptrs[n]; + + for (i = 0; i < DCTSIZE2; i++) { + if (prec) + INPUT_2BYTES(cinfo, tmp, return FALSE); + else + INPUT_BYTE(cinfo, tmp, return FALSE); + /* We convert the zigzag-order table to natural array order. */ + quant_ptr->quantval[jpeg_natural_order[i]] = (unsigned short) tmp; + } + + if (cinfo->err->trace_level >= 2) { + for (i = 0; i < DCTSIZE2; i += 8) { + TRACEMS8(cinfo, 2, JTRC_QUANTVALS, + quant_ptr->quantval[i], quant_ptr->quantval[i+1], + quant_ptr->quantval[i+2], quant_ptr->quantval[i+3], + quant_ptr->quantval[i+4], quant_ptr->quantval[i+5], + quant_ptr->quantval[i+6], quant_ptr->quantval[i+7]); + } + } + + length -= DCTSIZE2+1; + if (prec) length -= DCTSIZE2; + } + + if (length != 0) + ERREXIT(cinfo, JERR_BAD_LENGTH); + + INPUT_SYNC(cinfo); + return TRUE; +} + + +LOCAL(int) +get_dri (j_decompress_ptr cinfo) +/* Process a DRI marker */ +{ + long length; + unsigned int tmp; + INPUT_VARS(cinfo); + + INPUT_2BYTES(cinfo, length, return FALSE); + + if (length != 4) + ERREXIT(cinfo, JERR_BAD_LENGTH); + + INPUT_2BYTES(cinfo, tmp, return FALSE); + + TRACEMS1(cinfo, 1, JTRC_DRI, tmp); + + cinfo->restart_interval = tmp; + + INPUT_SYNC(cinfo); + return TRUE; +} + + +/* + * Routines for processing APPn and COM markers. + * These are either saved in memory or discarded, per application request. + * APP0 and APP14 are specially checked to see if they are + * JFIF and Adobe markers, respectively. + */ + +#define APP0_DATA_LEN 14 /* Length of interesting data in APP0 */ +#define APP14_DATA_LEN 12 /* Length of interesting data in APP14 */ +#define APPN_DATA_LEN 14 /* Must be the largest of the above!! */ + + +LOCAL(void) +examine_app0 (j_decompress_ptr cinfo, JOCTET FAR * data, + unsigned int datalen, long remaining) +/* Examine first few bytes from an APP0. + * Take appropriate action if it is a JFIF marker. + * datalen is # of bytes at data[], remaining is length of rest of marker data. + */ +{ + long totallen = (long) datalen + remaining; + + if (datalen >= APP0_DATA_LEN && + GETJOCTET(data[0]) == 0x4A && + GETJOCTET(data[1]) == 0x46 && + GETJOCTET(data[2]) == 0x49 && + GETJOCTET(data[3]) == 0x46 && + GETJOCTET(data[4]) == 0) { + /* Found JFIF APP0 marker: save info */ + cinfo->saw_JFIF_marker = TRUE; + cinfo->JFIF_major_version = GETJOCTET(data[5]); + cinfo->JFIF_minor_version = GETJOCTET(data[6]); + cinfo->density_unit = GETJOCTET(data[7]); + cinfo->X_density = (GETJOCTET(data[8]) << 8) + GETJOCTET(data[9]); + cinfo->Y_density = (GETJOCTET(data[10]) << 8) + GETJOCTET(data[11]); + /* Check version. + * Major version must be 1, anything else signals an incompatible change. + * (We used to treat this as an error, but now it's a nonfatal warning, + * because some bozo at Hijaak couldn't read the spec.) + * Minor version should be 0..2, but process anyway if newer. + */ + if (cinfo->JFIF_major_version != 1) + WARNMS2(cinfo, JWRN_JFIF_MAJOR, + cinfo->JFIF_major_version, cinfo->JFIF_minor_version); + /* Generate trace messages */ + TRACEMS5(cinfo, 1, JTRC_JFIF, + cinfo->JFIF_major_version, cinfo->JFIF_minor_version, + cinfo->X_density, cinfo->Y_density, cinfo->density_unit); + /* Validate thumbnail dimensions and issue appropriate messages */ + if (GETJOCTET(data[12]) | GETJOCTET(data[13])) + TRACEMS2(cinfo, 1, JTRC_JFIF_THUMBNAIL, + GETJOCTET(data[12]), GETJOCTET(data[13])); + totallen -= APP0_DATA_LEN; + if (totallen != + ((long)GETJOCTET(data[12]) * (long)GETJOCTET(data[13]) * (long) 3)) + TRACEMS1(cinfo, 1, JTRC_JFIF_BADTHUMBNAILSIZE, (int) totallen); + } else if (datalen >= 6 && + GETJOCTET(data[0]) == 0x4A && + GETJOCTET(data[1]) == 0x46 && + GETJOCTET(data[2]) == 0x58 && + GETJOCTET(data[3]) == 0x58 && + GETJOCTET(data[4]) == 0) { + /* Found JFIF "JFXX" extension APP0 marker */ + /* The library doesn't actually do anything with these, + * but we try to produce a helpful trace message. + */ + switch (GETJOCTET(data[5])) { + case 0x10: + TRACEMS1(cinfo, 1, JTRC_THUMB_JPEG, (int) totallen); + break; + case 0x11: + TRACEMS1(cinfo, 1, JTRC_THUMB_PALETTE, (int) totallen); + break; + case 0x13: + TRACEMS1(cinfo, 1, JTRC_THUMB_RGB, (int) totallen); + break; + default: + TRACEMS2(cinfo, 1, JTRC_JFIF_EXTENSION, + GETJOCTET(data[5]), (int) totallen); + break; + } + } else { + /* Start of APP0 does not match "JFIF" or "JFXX", or too short */ + TRACEMS1(cinfo, 1, JTRC_APP0, (int) totallen); + } +} + + +LOCAL(void) +examine_app14 (j_decompress_ptr cinfo, JOCTET FAR * data, + unsigned int datalen, long remaining) +/* Examine first few bytes from an APP14. + * Take appropriate action if it is an Adobe marker. + * datalen is # of bytes at data[], remaining is length of rest of marker data. + */ +{ + unsigned int version, flags0, flags1, transform; + + if (datalen >= APP14_DATA_LEN && + GETJOCTET(data[0]) == 0x41 && + GETJOCTET(data[1]) == 0x64 && + GETJOCTET(data[2]) == 0x6F && + GETJOCTET(data[3]) == 0x62 && + GETJOCTET(data[4]) == 0x65) { + /* Found Adobe APP14 marker */ + version = (GETJOCTET(data[5]) << 8) + GETJOCTET(data[6]); + flags0 = (GETJOCTET(data[7]) << 8) + GETJOCTET(data[8]); + flags1 = (GETJOCTET(data[9]) << 8) + GETJOCTET(data[10]); + transform = GETJOCTET(data[11]); + TRACEMS4(cinfo, 1, JTRC_ADOBE, version, flags0, flags1, transform); + cinfo->saw_Adobe_marker = TRUE; + cinfo->Adobe_transform = (unsigned char) transform; + } else { + /* Start of APP14 does not match "Adobe", or too short */ + TRACEMS1(cinfo, 1, JTRC_APP14, (int) (datalen + remaining)); + } +} + + +METHODDEF(int) +get_interesting_appn (j_decompress_ptr cinfo) +/* Process an APP0 or APP14 marker without saving it */ +{ + long length; + JOCTET b[APPN_DATA_LEN]; + unsigned int i, numtoread; + INPUT_VARS(cinfo); + + INPUT_2BYTES(cinfo, length, return FALSE); + length -= 2; + + /* get the interesting part of the marker data */ + if (length >= APPN_DATA_LEN) + numtoread = APPN_DATA_LEN; + else if (length > 0) + numtoread = (unsigned int) length; + else + numtoread = 0; + for (i = 0; i < numtoread; i++) + INPUT_BYTE(cinfo, b[i], return FALSE); + length -= numtoread; + + /* process it */ + switch (cinfo->unread_marker) { + case M_APP0: + examine_app0(cinfo, (JOCTET FAR *) b, numtoread, length); + break; + case M_APP14: + examine_app14(cinfo, (JOCTET FAR *) b, numtoread, length); + break; + default: + /* can't get here unless jpeg_save_markers chooses wrong processor */ + ERREXIT1(cinfo, JERR_UNKNOWN_MARKER, cinfo->unread_marker); + break; + } + + /* skip any remaining data -- could be lots */ + INPUT_SYNC(cinfo); + if (length > 0) + (*cinfo->src->skip_input_data) (cinfo, (long) length); + + return TRUE; +} + + +#ifdef SAVE_MARKERS_SUPPORTED + +METHODDEF(int) +save_marker (j_decompress_ptr cinfo) +/* Save an APPn or COM marker into the marker list */ +{ + my_marker_ptr marker = (my_marker_ptr) cinfo->marker; + jpeg_saved_marker_ptr cur_marker = marker->cur_marker; + unsigned int bytes_read, data_length; + JOCTET FAR * data; + long length = 0; + INPUT_VARS(cinfo); + + if (cur_marker == NULL) { + /* begin reading a marker */ + INPUT_2BYTES(cinfo, length, return FALSE); + length -= 2; + if (length >= 0) { /* watch out for bogus length word */ + /* figure out how much we want to save */ + unsigned int limit; + if (cinfo->unread_marker == (int) M_COM) + limit = marker->length_limit_COM; + else + limit = marker->length_limit_APPn[cinfo->unread_marker - (int) M_APP0]; + if ((unsigned int) length < limit) + limit = (unsigned int) length; + /* allocate and initialize the marker item */ + cur_marker = (jpeg_saved_marker_ptr) + (*cinfo->mem->alloc_large) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(struct jpeg_marker_struct) + limit); + cur_marker->next = NULL; + cur_marker->marker = (unsigned char) cinfo->unread_marker; + cur_marker->original_length = (unsigned int) length; + cur_marker->data_length = limit; + /* data area is just beyond the jpeg_marker_struct */ + data = cur_marker->data = (JOCTET FAR *) (cur_marker + 1); + marker->cur_marker = cur_marker; + marker->bytes_read = 0; + bytes_read = 0; + data_length = limit; + } else { + /* deal with bogus length word */ + bytes_read = data_length = 0; + data = NULL; + } + } else { + /* resume reading a marker */ + bytes_read = marker->bytes_read; + data_length = cur_marker->data_length; + data = cur_marker->data + bytes_read; + } + + while (bytes_read < data_length) { + INPUT_SYNC(cinfo); /* move the restart point to here */ + marker->bytes_read = bytes_read; + /* If there's not at least one byte in buffer, suspend */ + MAKE_BYTE_AVAIL(cinfo, return FALSE); + /* Copy bytes with reasonable rapidity */ + while (bytes_read < data_length && bytes_in_buffer > 0) { + *data++ = *next_input_byte++; + bytes_in_buffer--; + bytes_read++; + } + } + + /* Done reading what we want to read */ + if (cur_marker != NULL) { /* will be NULL if bogus length word */ + /* Add new marker to end of list */ + if (cinfo->marker_list == NULL) { + cinfo->marker_list = cur_marker; + } else { + jpeg_saved_marker_ptr prev = cinfo->marker_list; + while (prev->next != NULL) + prev = prev->next; + prev->next = cur_marker; + } + /* Reset pointer & calc remaining data length */ + data = cur_marker->data; + length = cur_marker->original_length - data_length; + } + /* Reset to initial state for next marker */ + marker->cur_marker = NULL; + + /* Process the marker if interesting; else just make a generic trace msg */ + switch (cinfo->unread_marker) { + case M_APP0: + examine_app0(cinfo, data, data_length, length); + break; + case M_APP14: + examine_app14(cinfo, data, data_length, length); + break; + default: + TRACEMS2(cinfo, 1, JTRC_MISC_MARKER, cinfo->unread_marker, + (int) (data_length + length)); + break; + } + + /* skip any remaining data -- could be lots */ + INPUT_SYNC(cinfo); /* do before skip_input_data */ + if (length > 0) + (*cinfo->src->skip_input_data) (cinfo, (long) length); + + return TRUE; +} + +#endif /* SAVE_MARKERS_SUPPORTED */ + + +METHODDEF(int) +skip_variable (j_decompress_ptr cinfo) +/* Skip over an unknown or uninteresting variable-length marker */ +{ + long length; + INPUT_VARS(cinfo); + + INPUT_2BYTES(cinfo, length, return FALSE); + length -= 2; + + TRACEMS2(cinfo, 1, JTRC_MISC_MARKER, cinfo->unread_marker, (int) length); + + INPUT_SYNC(cinfo); /* do before skip_input_data */ + if (length > 0) + (*cinfo->src->skip_input_data) (cinfo, (long) length); + + return TRUE; +} + + +/* + * Find the next JPEG marker, save it in cinfo->unread_marker. + * Returns FALSE if had to suspend before reaching a marker; + * in that case cinfo->unread_marker is unchanged. + * + * Note that the result might not be a valid marker code, + * but it will never be 0 or FF. + */ + +LOCAL(int) +next_marker (j_decompress_ptr cinfo) +{ + int c; + INPUT_VARS(cinfo); + + for (;;) { + INPUT_BYTE(cinfo, c, return FALSE); + /* Skip any non-FF bytes. + * This may look a bit inefficient, but it will not occur in a valid file. + * We sync after each discarded byte so that a suspending data source + * can discard the byte from its buffer. + */ + while (c != 0xFF) { + cinfo->marker->discarded_bytes++; + INPUT_SYNC(cinfo); + INPUT_BYTE(cinfo, c, return FALSE); + } + /* This loop swallows any duplicate FF bytes. Extra FFs are legal as + * pad bytes, so don't count them in discarded_bytes. We assume there + * will not be so many consecutive FF bytes as to overflow a suspending + * data source's input buffer. + */ + do { + INPUT_BYTE(cinfo, c, return FALSE); + } while (c == 0xFF); + if (c != 0) + break; /* found a valid marker, exit loop */ + /* Reach here if we found a stuffed-zero data sequence (FF/00). + * Discard it and loop back to try again. + */ + cinfo->marker->discarded_bytes += 2; + INPUT_SYNC(cinfo); + } + + if (cinfo->marker->discarded_bytes != 0) { + WARNMS2(cinfo, JWRN_EXTRANEOUS_DATA, cinfo->marker->discarded_bytes, c); + cinfo->marker->discarded_bytes = 0; + } + + cinfo->unread_marker = c; + + INPUT_SYNC(cinfo); + return TRUE; +} + + +LOCAL(int) +first_marker (j_decompress_ptr cinfo) +/* Like next_marker, but used to obtain the initial SOI marker. */ +/* For this marker, we do not allow preceding garbage or fill; otherwise, + * we might well scan an entire input file before realizing it ain't JPEG. + * If an application wants to process non-JFIF files, it must seek to the + * SOI before calling the JPEG library. + */ +{ + int c, c2; + INPUT_VARS(cinfo); + + INPUT_BYTE(cinfo, c, return FALSE); + INPUT_BYTE(cinfo, c2, return FALSE); + if (c != 0xFF || c2 != (int) M_SOI) + ERREXIT2(cinfo, JERR_NO_SOI, c, c2); + + cinfo->unread_marker = c2; + + INPUT_SYNC(cinfo); + return TRUE; +} + + +/* + * Read markers until SOS or EOI. + * + * Returns same codes as are defined for jpeg_consume_input: + * JPEG_SUSPENDED, JPEG_REACHED_SOS, or JPEG_REACHED_EOI. + */ + +METHODDEF(int) +read_markers (j_decompress_ptr cinfo) +{ + /* Outer loop repeats once for each marker. */ + for (;;) { + /* Collect the marker proper, unless we already did. */ + /* NB: first_marker() enforces the requirement that SOI appear first. */ + if (cinfo->unread_marker == 0) { + if (! cinfo->marker->saw_SOI) { + if (! first_marker(cinfo)) + return JPEG_SUSPENDED; + } else { + if (! next_marker(cinfo)) + return JPEG_SUSPENDED; + } + } + /* At this point cinfo->unread_marker contains the marker code and the + * input point is just past the marker proper, but before any parameters. + * A suspension will cause us to return with this state still true. + */ + switch (cinfo->unread_marker) { + case M_SOI: + if (! get_soi(cinfo)) + return JPEG_SUSPENDED; + break; + + case M_SOF0: /* Baseline */ + case M_SOF1: /* Extended sequential, Huffman */ + if (! get_sof(cinfo, FALSE, FALSE)) + return JPEG_SUSPENDED; + break; + + case M_SOF2: /* Progressive, Huffman */ + if (! get_sof(cinfo, TRUE, FALSE)) + return JPEG_SUSPENDED; + break; + + case M_SOF9: /* Extended sequential, arithmetic */ + if (! get_sof(cinfo, FALSE, TRUE)) + return JPEG_SUSPENDED; + break; + + case M_SOF10: /* Progressive, arithmetic */ + if (! get_sof(cinfo, TRUE, TRUE)) + return JPEG_SUSPENDED; + break; + + /* Currently unsupported SOFn types */ + case M_SOF3: /* Lossless, Huffman */ + case M_SOF5: /* Differential sequential, Huffman */ + case M_SOF6: /* Differential progressive, Huffman */ + case M_SOF7: /* Differential lossless, Huffman */ + case M_JPG: /* Reserved for JPEG extensions */ + case M_SOF11: /* Lossless, arithmetic */ + case M_SOF13: /* Differential sequential, arithmetic */ + case M_SOF14: /* Differential progressive, arithmetic */ + case M_SOF15: /* Differential lossless, arithmetic */ + ERREXIT1(cinfo, JERR_SOF_UNSUPPORTED, cinfo->unread_marker); + break; + + case M_SOS: + if (! get_sos(cinfo)) + return JPEG_SUSPENDED; + cinfo->unread_marker = 0; /* processed the marker */ + return JPEG_REACHED_SOS; + + case M_EOI: + TRACEMS(cinfo, 1, JTRC_EOI); + cinfo->unread_marker = 0; /* processed the marker */ + return JPEG_REACHED_EOI; + + case M_DAC: + if (! get_dac(cinfo)) + return JPEG_SUSPENDED; + break; + + case M_DHT: + if (! get_dht(cinfo)) + return JPEG_SUSPENDED; + break; + + case M_DQT: + if (! get_dqt(cinfo)) + return JPEG_SUSPENDED; + break; + + case M_DRI: + if (! get_dri(cinfo)) + return JPEG_SUSPENDED; + break; + + case M_APP0: + case M_APP1: + case M_APP2: + case M_APP3: + case M_APP4: + case M_APP5: + case M_APP6: + case M_APP7: + case M_APP8: + case M_APP9: + case M_APP10: + case M_APP11: + case M_APP12: + case M_APP13: + case M_APP14: + case M_APP15: + if (! (*((my_marker_ptr) cinfo->marker)->process_APPn[ + cinfo->unread_marker - (int) M_APP0]) (cinfo)) + return JPEG_SUSPENDED; + break; + + case M_COM: + if (! (*((my_marker_ptr) cinfo->marker)->process_COM) (cinfo)) + return JPEG_SUSPENDED; + break; + + case M_RST0: /* these are all parameterless */ + case M_RST1: + case M_RST2: + case M_RST3: + case M_RST4: + case M_RST5: + case M_RST6: + case M_RST7: + case M_TEM: + TRACEMS1(cinfo, 1, JTRC_PARMLESS_MARKER, cinfo->unread_marker); + break; + + case M_DNL: /* Ignore DNL ... perhaps the wrong thing */ + if (! skip_variable(cinfo)) + return JPEG_SUSPENDED; + break; + + default: /* must be DHP, EXP, JPGn, or RESn */ + /* For now, we treat the reserved markers as fatal errors since they are + * likely to be used to signal incompatible JPEG Part 3 extensions. + * Once the JPEG 3 version-number marker is well defined, this code + * ought to change! + */ + ERREXIT1(cinfo, JERR_UNKNOWN_MARKER, cinfo->unread_marker); + break; + } + /* Successfully processed marker, so reset state variable */ + cinfo->unread_marker = 0; + } /* end loop */ +} + + +/* + * Read a restart marker, which is expected to appear next in the datastream; + * if the marker is not there, take appropriate recovery action. + * Returns FALSE if suspension is required. + * + * This is called by the entropy decoder after it has read an appropriate + * number of MCUs. cinfo->unread_marker may be nonzero if the entropy decoder + * has already read a marker from the data source. Under normal conditions + * cinfo->unread_marker will be reset to 0 before returning; if not reset, + * it holds a marker which the decoder will be unable to read past. + */ + +METHODDEF(int) +read_restart_marker (j_decompress_ptr cinfo) +{ + /* Obtain a marker unless we already did. */ + /* Note that next_marker will complain if it skips any data. */ + if (cinfo->unread_marker == 0) { + if (! next_marker(cinfo)) + return FALSE; + } + + if (cinfo->unread_marker == + ((int) M_RST0 + cinfo->marker->next_restart_num)) { + /* Normal case --- swallow the marker and let entropy decoder continue */ + TRACEMS1(cinfo, 3, JTRC_RST, cinfo->marker->next_restart_num); + cinfo->unread_marker = 0; + } else { + /* Uh-oh, the restart markers have been messed up. */ + /* Let the data source manager determine how to resync. */ + if (! (*cinfo->src->resync_to_restart) (cinfo, + cinfo->marker->next_restart_num)) + return FALSE; + } + + /* Update next-restart state */ + cinfo->marker->next_restart_num = (cinfo->marker->next_restart_num + 1) & 7; + + return TRUE; +} + + +/* + * This is the default resync_to_restart method for data source managers + * to use if they don't have any better approach. Some data source managers + * may be able to back up, or may have additional knowledge about the data + * which permits a more intelligent recovery strategy; such managers would + * presumably supply their own resync method. + * + * read_restart_marker calls resync_to_restart if it finds a marker other than + * the restart marker it was expecting. (This code is *not* used unless + * a nonzero restart interval has been declared.) cinfo->unread_marker is + * the marker code actually found (might be anything, except 0 or FF). + * The desired restart marker number (0..7) is passed as a parameter. + * This routine is supposed to apply whatever error recovery strategy seems + * appropriate in order to position the input stream to the next data segment. + * Note that cinfo->unread_marker is treated as a marker appearing before + * the current data-source input point; usually it should be reset to zero + * before returning. + * Returns FALSE if suspension is required. + * + * This implementation is substantially constrained by wanting to treat the + * input as a data stream; this means we can't back up. Therefore, we have + * only the following actions to work with: + * 1. Simply discard the marker and let the entropy decoder resume at next + * byte of file. + * 2. Read forward until we find another marker, discarding intervening + * data. (In theory we could look ahead within the current bufferload, + * without having to discard data if we don't find the desired marker. + * This idea is not implemented here, in part because it makes behavior + * dependent on buffer size and chance buffer-boundary positions.) + * 3. Leave the marker unread (by failing to zero cinfo->unread_marker). + * This will cause the entropy decoder to process an empty data segment, + * inserting dummy zeroes, and then we will reprocess the marker. + * + * #2 is appropriate if we think the desired marker lies ahead, while #3 is + * appropriate if the found marker is a future restart marker (indicating + * that we have missed the desired restart marker, probably because it got + * corrupted). + * We apply #2 or #3 if the found marker is a restart marker no more than + * two counts behind or ahead of the expected one. We also apply #2 if the + * found marker is not a legal JPEG marker code (it's certainly bogus data). + * If the found marker is a restart marker more than 2 counts away, we do #1 + * (too much risk that the marker is erroneous; with luck we will be able to + * resync at some future point). + * For any valid non-restart JPEG marker, we apply #3. This keeps us from + * overrunning the end of a scan. An implementation limited to single-scan + * files might find it better to apply #2 for markers other than EOI, since + * any other marker would have to be bogus data in that case. + */ + +GLOBAL(int) +jpeg_resync_to_restart (j_decompress_ptr cinfo, int desired) +{ + int marker = cinfo->unread_marker; + int action = 1; + + /* Always put up a warning. */ + WARNMS2(cinfo, JWRN_MUST_RESYNC, marker, desired); + + /* Outer loop handles repeated decision after scanning forward. */ + for (;;) { + if (marker < (int) M_SOF0) + action = 2; /* invalid marker */ + else if (marker < (int) M_RST0 || marker > (int) M_RST7) + action = 3; /* valid non-restart marker */ + else { + if (marker == ((int) M_RST0 + ((desired+1) & 7)) || + marker == ((int) M_RST0 + ((desired+2) & 7))) + action = 3; /* one of the next two expected restarts */ + else if (marker == ((int) M_RST0 + ((desired-1) & 7)) || + marker == ((int) M_RST0 + ((desired-2) & 7))) + action = 2; /* a prior restart, so advance */ + else + action = 1; /* desired restart or too far away */ + } + TRACEMS2(cinfo, 4, JTRC_RECOVERY_ACTION, marker, action); + switch (action) { + case 1: + /* Discard marker and let entropy decoder resume processing. */ + cinfo->unread_marker = 0; + return TRUE; + case 2: + /* Scan to the next marker, and repeat the decision loop. */ + if (! next_marker(cinfo)) + return FALSE; + marker = cinfo->unread_marker; + break; + case 3: + /* Return without advancing past this marker. */ + /* Entropy decoder will be forced to process an empty segment. */ + return TRUE; + } + } /* end loop */ +} + + +/* + * Reset marker processing state to begin a fresh datastream. + */ + +METHODDEF(void) +reset_marker_reader (j_decompress_ptr cinfo) +{ + my_marker_ptr marker = (my_marker_ptr) cinfo->marker; + + cinfo->comp_info = NULL; /* until allocated by get_sof */ + cinfo->input_scan_number = 0; /* no SOS seen yet */ + cinfo->unread_marker = 0; /* no pending marker */ + marker->pub.saw_SOI = FALSE; /* set internal state too */ + marker->pub.saw_SOF = FALSE; + marker->pub.discarded_bytes = 0; + marker->cur_marker = NULL; +} + + +/* + * Initialize the marker reader module. + * This is called only once, when the decompression object is created. + */ + +GLOBAL(void) +jinit_marker_reader (j_decompress_ptr cinfo) +{ + my_marker_ptr marker; + int i; + + /* Create subobject in permanent pool */ + marker = (my_marker_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_PERMANENT, + SIZEOF(my_marker_reader)); + cinfo->marker = (struct jpeg_marker_reader *) marker; + /* Initialize public method pointers */ + marker->pub.reset_marker_reader = reset_marker_reader; + marker->pub.read_markers = read_markers; + marker->pub.read_restart_marker = read_restart_marker; + /* Initialize COM/APPn processing. + * By default, we examine and then discard APP0 and APP14, + * but simply discard COM and all other APPn. + */ + marker->process_COM = skip_variable; + marker->length_limit_COM = 0; + for (i = 0; i < 16; i++) { + marker->process_APPn[i] = skip_variable; + marker->length_limit_APPn[i] = 0; + } + marker->process_APPn[0] = get_interesting_appn; + marker->process_APPn[14] = get_interesting_appn; + /* Reset marker processing state */ + reset_marker_reader(cinfo); +} + + +/* + * Control saving of COM and APPn markers into marker_list. + */ + +#ifdef SAVE_MARKERS_SUPPORTED + +GLOBAL(void) +jpeg_save_markers (j_decompress_ptr cinfo, int marker_code, + unsigned int length_limit) +{ + my_marker_ptr marker = (my_marker_ptr) cinfo->marker; + long maxlength; + jpeg_marker_parser_method processor; + + /* Length limit mustn't be larger than what we can allocate + * (should only be a concern in a 16-bit environment). + */ + maxlength = cinfo->mem->max_alloc_chunk - SIZEOF(struct jpeg_marker_struct); + if (((long) length_limit) > maxlength) + length_limit = (unsigned int) maxlength; + + /* Choose processor routine to use. + * APP0/APP14 have special requirements. + */ + if (length_limit) { + processor = save_marker; + /* If saving APP0/APP14, save at least enough for our internal use. */ + if (marker_code == (int) M_APP0 && length_limit < APP0_DATA_LEN) + length_limit = APP0_DATA_LEN; + else if (marker_code == (int) M_APP14 && length_limit < APP14_DATA_LEN) + length_limit = APP14_DATA_LEN; + } else { + processor = skip_variable; + /* If discarding APP0/APP14, use our regular on-the-fly processor. */ + if (marker_code == (int) M_APP0 || marker_code == (int) M_APP14) + processor = get_interesting_appn; + } + + if (marker_code == (int) M_COM) { + marker->process_COM = processor; + marker->length_limit_COM = length_limit; + } else if (marker_code >= (int) M_APP0 && marker_code <= (int) M_APP15) { + marker->process_APPn[marker_code - (int) M_APP0] = processor; + marker->length_limit_APPn[marker_code - (int) M_APP0] = length_limit; + } else + ERREXIT1(cinfo, JERR_UNKNOWN_MARKER, marker_code); +} + +#endif /* SAVE_MARKERS_SUPPORTED */ + + +/* + * Install a special processing method for COM or APPn markers. + */ + +GLOBAL(void) +jpeg_set_marker_processor (j_decompress_ptr cinfo, int marker_code, + jpeg_marker_parser_method routine) +{ + my_marker_ptr marker = (my_marker_ptr) cinfo->marker; + + if (marker_code == (int) M_COM) + marker->process_COM = routine; + else if (marker_code >= (int) M_APP0 && marker_code <= (int) M_APP15) + marker->process_APPn[marker_code - (int) M_APP0] = routine; + else + ERREXIT1(cinfo, JERR_UNKNOWN_MARKER, marker_code); +} diff --git a/ml/dlib/dlib/external/libjpeg/jdmaster.cpp b/ml/dlib/dlib/external/libjpeg/jdmaster.cpp new file mode 100644 index 000000000..8aea1c8ea --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdmaster.cpp @@ -0,0 +1,557 @@ +/* + * jdmaster.c + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains master control logic for the JPEG decompressor. + * These routines are concerned with selecting the modules to be executed + * and with determining the number of passes and the work to be done in each + * pass. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Private state */ + +typedef struct { + struct jpeg_decomp_master pub; /* public fields */ + + int pass_number; /* # of passes completed */ + + int using_merged_upsample; /* TRUE if using merged upsample/cconvert */ + + /* Saved references to initialized quantizer modules, + * in case we need to switch modes. + */ + struct jpeg_color_quantizer * quantizer_1pass; + struct jpeg_color_quantizer * quantizer_2pass; +} my_decomp_master; + +typedef my_decomp_master * my_master_ptr; + + +/* + * Determine whether merged upsample/color conversion should be used. + * CRUCIAL: this must match the actual capabilities of jdmerge.c! + */ + +LOCAL(int) +use_merged_upsample (j_decompress_ptr cinfo) +{ +#ifdef UPSAMPLE_MERGING_SUPPORTED + /* Merging is the equivalent of plain box-filter upsampling */ + if (cinfo->do_fancy_upsampling || cinfo->CCIR601_sampling) + return FALSE; + /* jdmerge.c only supports YCC=>RGB color conversion */ + if (cinfo->jpeg_color_space != JCS_YCbCr || cinfo->num_components != 3 || + cinfo->out_color_space != JCS_RGB || + cinfo->out_color_components != RGB_PIXELSIZE) + return FALSE; + /* and it only handles 2h1v or 2h2v sampling ratios */ + if (cinfo->comp_info[0].h_samp_factor != 2 || + cinfo->comp_info[1].h_samp_factor != 1 || + cinfo->comp_info[2].h_samp_factor != 1 || + cinfo->comp_info[0].v_samp_factor > 2 || + cinfo->comp_info[1].v_samp_factor != 1 || + cinfo->comp_info[2].v_samp_factor != 1) + return FALSE; + /* furthermore, it doesn't work if we've scaled the IDCTs differently */ + if (cinfo->comp_info[0].DCT_scaled_size != cinfo->min_DCT_scaled_size || + cinfo->comp_info[1].DCT_scaled_size != cinfo->min_DCT_scaled_size || + cinfo->comp_info[2].DCT_scaled_size != cinfo->min_DCT_scaled_size) + return FALSE; + /* ??? also need to test for upsample-time rescaling, when & if supported */ + return TRUE; /* by golly, it'll work... */ +#else + return FALSE; +#endif +} + + +/* + * Compute output image dimensions and related values. + * NOTE: this is exported for possible use by application. + * Hence it mustn't do anything that can't be done twice. + * Also note that it may be called before the master module is initialized! + */ + +GLOBAL(void) +jpeg_calc_output_dimensions (j_decompress_ptr cinfo) +/* Do computations that are needed before master selection phase */ +{ +#ifdef IDCT_SCALING_SUPPORTED + int ci; + jpeg_component_info *compptr; +#endif + + /* Prevent application from calling me at wrong times */ + if (cinfo->global_state != DSTATE_READY) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + +#ifdef IDCT_SCALING_SUPPORTED + + /* Compute actual output image dimensions and DCT scaling choices. */ + if (cinfo->scale_num * 8 <= cinfo->scale_denom) { + /* Provide 1/8 scaling */ + cinfo->output_width = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width, 8L); + cinfo->output_height = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height, 8L); + cinfo->min_DCT_scaled_size = 1; + } else if (cinfo->scale_num * 4 <= cinfo->scale_denom) { + /* Provide 1/4 scaling */ + cinfo->output_width = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width, 4L); + cinfo->output_height = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height, 4L); + cinfo->min_DCT_scaled_size = 2; + } else if (cinfo->scale_num * 2 <= cinfo->scale_denom) { + /* Provide 1/2 scaling */ + cinfo->output_width = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width, 2L); + cinfo->output_height = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height, 2L); + cinfo->min_DCT_scaled_size = 4; + } else { + /* Provide 1/1 scaling */ + cinfo->output_width = cinfo->image_width; + cinfo->output_height = cinfo->image_height; + cinfo->min_DCT_scaled_size = DCTSIZE; + } + /* In selecting the actual DCT scaling for each component, we try to + * scale up the chroma components via IDCT scaling rather than upsampling. + * This saves time if the upsampler gets to use 1:1 scaling. + * Note this code assumes that the supported DCT scalings are powers of 2. + */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + int ssize = cinfo->min_DCT_scaled_size; + while (ssize < DCTSIZE && + (compptr->h_samp_factor * ssize * 2 <= + cinfo->max_h_samp_factor * cinfo->min_DCT_scaled_size) && + (compptr->v_samp_factor * ssize * 2 <= + cinfo->max_v_samp_factor * cinfo->min_DCT_scaled_size)) { + ssize = ssize * 2; + } + compptr->DCT_scaled_size = ssize; + } + + /* Recompute downsampled dimensions of components; + * application needs to know these if using raw downsampled data. + */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Size in samples, after IDCT scaling */ + compptr->downsampled_width = (JDIMENSION) + jdiv_round_up((long) cinfo->image_width * + (long) (compptr->h_samp_factor * compptr->DCT_scaled_size), + (long) (cinfo->max_h_samp_factor * DCTSIZE)); + compptr->downsampled_height = (JDIMENSION) + jdiv_round_up((long) cinfo->image_height * + (long) (compptr->v_samp_factor * compptr->DCT_scaled_size), + (long) (cinfo->max_v_samp_factor * DCTSIZE)); + } + +#else /* !IDCT_SCALING_SUPPORTED */ + + /* Hardwire it to "no scaling" */ + cinfo->output_width = cinfo->image_width; + cinfo->output_height = cinfo->image_height; + /* jdinput.c has already initialized DCT_scaled_size to DCTSIZE, + * and has computed unscaled downsampled_width and downsampled_height. + */ + +#endif /* IDCT_SCALING_SUPPORTED */ + + /* Report number of components in selected colorspace. */ + /* Probably this should be in the color conversion module... */ + switch (cinfo->out_color_space) { + case JCS_GRAYSCALE: + cinfo->out_color_components = 1; + break; + case JCS_RGB: +#if RGB_PIXELSIZE != 3 + cinfo->out_color_components = RGB_PIXELSIZE; + break; +#endif /* else share code with YCbCr */ + case JCS_YCbCr: + cinfo->out_color_components = 3; + break; + case JCS_CMYK: + case JCS_YCCK: + cinfo->out_color_components = 4; + break; + default: /* else must be same colorspace as in file */ + cinfo->out_color_components = cinfo->num_components; + break; + } + cinfo->output_components = (cinfo->quantize_colors ? 1 : + cinfo->out_color_components); + + /* See if upsampler will want to emit more than one row at a time */ + if (use_merged_upsample(cinfo)) + cinfo->rec_outbuf_height = cinfo->max_v_samp_factor; + else + cinfo->rec_outbuf_height = 1; +} + + +/* + * Several decompression processes need to range-limit values to the range + * 0..MAXJSAMPLE; the input value may fall somewhat outside this range + * due to noise introduced by quantization, roundoff error, etc. These + * processes are inner loops and need to be as fast as possible. On most + * machines, particularly CPUs with pipelines or instruction prefetch, + * a (subscript-check-less) C table lookup + * x = sample_range_limit[x]; + * is faster than explicit tests + * if (x < 0) x = 0; + * else if (x > MAXJSAMPLE) x = MAXJSAMPLE; + * These processes all use a common table prepared by the routine below. + * + * For most steps we can mathematically guarantee that the initial value + * of x is within MAXJSAMPLE+1 of the legal range, so a table running from + * -(MAXJSAMPLE+1) to 2*MAXJSAMPLE+1 is sufficient. But for the initial + * limiting step (just after the IDCT), a wildly out-of-range value is + * possible if the input data is corrupt. To avoid any chance of indexing + * off the end of memory and getting a bad-pointer trap, we perform the + * post-IDCT limiting thus: + * x = range_limit[x & MASK]; + * where MASK is 2 bits wider than legal sample data, ie 10 bits for 8-bit + * samples. Under normal circumstances this is more than enough range and + * a correct output will be generated; with bogus input data the mask will + * cause wraparound, and we will safely generate a bogus-but-in-range output. + * For the post-IDCT step, we want to convert the data from signed to unsigned + * representation by adding CENTERJSAMPLE at the same time that we limit it. + * So the post-IDCT limiting table ends up looking like this: + * CENTERJSAMPLE,CENTERJSAMPLE+1,...,MAXJSAMPLE, + * MAXJSAMPLE (repeat 2*(MAXJSAMPLE+1)-CENTERJSAMPLE times), + * 0 (repeat 2*(MAXJSAMPLE+1)-CENTERJSAMPLE times), + * 0,1,...,CENTERJSAMPLE-1 + * Negative inputs select values from the upper half of the table after + * masking. + * + * We can save some space by overlapping the start of the post-IDCT table + * with the simpler range limiting table. The post-IDCT table begins at + * sample_range_limit + CENTERJSAMPLE. + * + * Note that the table is allocated in near data space on PCs; it's small + * enough and used often enough to justify this. + */ + +LOCAL(void) +prepare_range_limit_table (j_decompress_ptr cinfo) +/* Allocate and fill in the sample_range_limit table */ +{ + JSAMPLE * table; + int i; + + table = (JSAMPLE *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (5 * (MAXJSAMPLE+1) + CENTERJSAMPLE) * SIZEOF(JSAMPLE)); + table += (MAXJSAMPLE+1); /* allow negative subscripts of simple table */ + cinfo->sample_range_limit = table; + /* First segment of "simple" table: limit[x] = 0 for x < 0 */ + MEMZERO(table - (MAXJSAMPLE+1), (MAXJSAMPLE+1) * SIZEOF(JSAMPLE)); + /* Main part of "simple" table: limit[x] = x */ + for (i = 0; i <= MAXJSAMPLE; i++) + table[i] = (JSAMPLE) i; + table += CENTERJSAMPLE; /* Point to where post-IDCT table starts */ + /* End of simple table, rest of first half of post-IDCT table */ + for (i = CENTERJSAMPLE; i < 2*(MAXJSAMPLE+1); i++) + table[i] = MAXJSAMPLE; + /* Second half of post-IDCT table */ + MEMZERO(table + (2 * (MAXJSAMPLE+1)), + (2 * (MAXJSAMPLE+1) - CENTERJSAMPLE) * SIZEOF(JSAMPLE)); + MEMCOPY(table + (4 * (MAXJSAMPLE+1) - CENTERJSAMPLE), + cinfo->sample_range_limit, CENTERJSAMPLE * SIZEOF(JSAMPLE)); +} + + +/* + * Master selection of decompression modules. + * This is done once at jpeg_start_decompress time. We determine + * which modules will be used and give them appropriate initialization calls. + * We also initialize the decompressor input side to begin consuming data. + * + * Since jpeg_read_header has finished, we know what is in the SOF + * and (first) SOS markers. We also have all the application parameter + * settings. + */ + +LOCAL(void) +master_selection (j_decompress_ptr cinfo) +{ + my_master_ptr master = (my_master_ptr) cinfo->master; + int use_c_buffer; + long samplesperrow; + JDIMENSION jd_samplesperrow; + + /* Initialize dimensions and other stuff */ + jpeg_calc_output_dimensions(cinfo); + prepare_range_limit_table(cinfo); + + /* Width of an output scanline must be representable as JDIMENSION. */ + samplesperrow = (long) cinfo->output_width * (long) cinfo->out_color_components; + jd_samplesperrow = (JDIMENSION) samplesperrow; + if ((long) jd_samplesperrow != samplesperrow) + ERREXIT(cinfo, JERR_WIDTH_OVERFLOW); + + /* Initialize my private state */ + master->pass_number = 0; + master->using_merged_upsample = use_merged_upsample(cinfo); + + /* Color quantizer selection */ + master->quantizer_1pass = NULL; + master->quantizer_2pass = NULL; + /* No mode changes if not using buffered-image mode. */ + if (! cinfo->quantize_colors || ! cinfo->buffered_image) { + cinfo->enable_1pass_quant = FALSE; + cinfo->enable_external_quant = FALSE; + cinfo->enable_2pass_quant = FALSE; + } + if (cinfo->quantize_colors) { + if (cinfo->raw_data_out) + ERREXIT(cinfo, JERR_NOTIMPL); + /* 2-pass quantizer only works in 3-component color space. */ + if (cinfo->out_color_components != 3) { + cinfo->enable_1pass_quant = TRUE; + cinfo->enable_external_quant = FALSE; + cinfo->enable_2pass_quant = FALSE; + cinfo->colormap = NULL; + } else if (cinfo->colormap != NULL) { + cinfo->enable_external_quant = TRUE; + } else if (cinfo->two_pass_quantize) { + cinfo->enable_2pass_quant = TRUE; + } else { + cinfo->enable_1pass_quant = TRUE; + } + + if (cinfo->enable_1pass_quant) { +#ifdef QUANT_1PASS_SUPPORTED + jinit_1pass_quantizer(cinfo); + master->quantizer_1pass = cinfo->cquantize; +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } + + /* We use the 2-pass code to map to external colormaps. */ + if (cinfo->enable_2pass_quant || cinfo->enable_external_quant) { +#ifdef QUANT_2PASS_SUPPORTED + jinit_2pass_quantizer(cinfo); + master->quantizer_2pass = cinfo->cquantize; +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } + /* If both quantizers are initialized, the 2-pass one is left active; + * this is necessary for starting with quantization to an external map. + */ + } + + /* Post-processing: in particular, color conversion first */ + if (! cinfo->raw_data_out) { + if (master->using_merged_upsample) { +#ifdef UPSAMPLE_MERGING_SUPPORTED + jinit_merged_upsampler(cinfo); /* does color conversion too */ +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } else { + jinit_color_deconverter(cinfo); + jinit_upsampler(cinfo); + } + jinit_d_post_controller(cinfo, cinfo->enable_2pass_quant); + } + /* Inverse DCT */ + jinit_inverse_dct(cinfo); + /* Entropy decoding: either Huffman or arithmetic coding. */ + if (cinfo->arith_code) { + ERREXIT(cinfo, JERR_ARITH_NOTIMPL); + } else { + if (cinfo->progressive_mode) { +#ifdef D_PROGRESSIVE_SUPPORTED + jinit_phuff_decoder(cinfo); +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif + } else + jinit_huff_decoder(cinfo); + } + + /* Initialize principal buffer controllers. */ + use_c_buffer = cinfo->inputctl->has_multiple_scans || cinfo->buffered_image; + jinit_d_coef_controller(cinfo, use_c_buffer); + + if (! cinfo->raw_data_out) + jinit_d_main_controller(cinfo, FALSE /* never need full buffer here */); + + /* We can now tell the memory manager to allocate virtual arrays. */ + (*cinfo->mem->realize_virt_arrays) ((j_common_ptr) cinfo); + + /* Initialize input side of decompressor to consume first scan. */ + (*cinfo->inputctl->start_input_pass) (cinfo); + +#ifdef D_MULTISCAN_FILES_SUPPORTED + /* If jpeg_start_decompress will read the whole file, initialize + * progress monitoring appropriately. The input step is counted + * as one pass. + */ + if (cinfo->progress != NULL && ! cinfo->buffered_image && + cinfo->inputctl->has_multiple_scans) { + int nscans; + /* Estimate number of scans to set pass_limit. */ + if (cinfo->progressive_mode) { + /* Arbitrarily estimate 2 interleaved DC scans + 3 AC scans/component. */ + nscans = 2 + 3 * cinfo->num_components; + } else { + /* For a nonprogressive multiscan file, estimate 1 scan per component. */ + nscans = cinfo->num_components; + } + cinfo->progress->pass_counter = 0L; + cinfo->progress->pass_limit = (long) cinfo->total_iMCU_rows * nscans; + cinfo->progress->completed_passes = 0; + cinfo->progress->total_passes = (cinfo->enable_2pass_quant ? 3 : 2); + /* Count the input pass as done */ + master->pass_number++; + } +#endif /* D_MULTISCAN_FILES_SUPPORTED */ +} + + +/* + * Per-pass setup. + * This is called at the beginning of each output pass. We determine which + * modules will be active during this pass and give them appropriate + * start_pass calls. We also set is_dummy_pass to indicate whether this + * is a "real" output pass or a dummy pass for color quantization. + * (In the latter case, jdapistd.c will crank the pass to completion.) + */ + +METHODDEF(void) +prepare_for_output_pass (j_decompress_ptr cinfo) +{ + my_master_ptr master = (my_master_ptr) cinfo->master; + + if (master->pub.is_dummy_pass) { +#ifdef QUANT_2PASS_SUPPORTED + /* Final pass of 2-pass quantization */ + master->pub.is_dummy_pass = FALSE; + (*cinfo->cquantize->start_pass) (cinfo, FALSE); + (*cinfo->post->start_pass) (cinfo, JBUF_CRANK_DEST); + (*cinfo->main->start_pass) (cinfo, JBUF_CRANK_DEST); +#else + ERREXIT(cinfo, JERR_NOT_COMPILED); +#endif /* QUANT_2PASS_SUPPORTED */ + } else { + if (cinfo->quantize_colors && cinfo->colormap == NULL) { + /* Select new quantization method */ + if (cinfo->two_pass_quantize && cinfo->enable_2pass_quant) { + cinfo->cquantize = master->quantizer_2pass; + master->pub.is_dummy_pass = TRUE; + } else if (cinfo->enable_1pass_quant) { + cinfo->cquantize = master->quantizer_1pass; + } else { + ERREXIT(cinfo, JERR_MODE_CHANGE); + } + } + (*cinfo->idct->start_pass) (cinfo); + (*cinfo->coef->start_output_pass) (cinfo); + if (! cinfo->raw_data_out) { + if (! master->using_merged_upsample) + (*cinfo->cconvert->start_pass) (cinfo); + (*cinfo->upsample->start_pass) (cinfo); + if (cinfo->quantize_colors) + (*cinfo->cquantize->start_pass) (cinfo, master->pub.is_dummy_pass); + (*cinfo->post->start_pass) (cinfo, + (master->pub.is_dummy_pass ? JBUF_SAVE_AND_PASS : JBUF_PASS_THRU)); + (*cinfo->main->start_pass) (cinfo, JBUF_PASS_THRU); + } + } + + /* Set up progress monitor's pass info if present */ + if (cinfo->progress != NULL) { + cinfo->progress->completed_passes = master->pass_number; + cinfo->progress->total_passes = master->pass_number + + (master->pub.is_dummy_pass ? 2 : 1); + /* In buffered-image mode, we assume one more output pass if EOI not + * yet reached, but no more passes if EOI has been reached. + */ + if (cinfo->buffered_image && ! cinfo->inputctl->eoi_reached) { + cinfo->progress->total_passes += (cinfo->enable_2pass_quant ? 2 : 1); + } + } +} + + +/* + * Finish up at end of an output pass. + */ + +METHODDEF(void) +finish_output_pass (j_decompress_ptr cinfo) +{ + my_master_ptr master = (my_master_ptr) cinfo->master; + + if (cinfo->quantize_colors) + (*cinfo->cquantize->finish_pass) (cinfo); + master->pass_number++; +} + + +#ifdef D_MULTISCAN_FILES_SUPPORTED + +/* + * Switch to a new external colormap between output passes. + */ + +GLOBAL(void) +jpeg_new_colormap (j_decompress_ptr cinfo) +{ + my_master_ptr master = (my_master_ptr) cinfo->master; + + /* Prevent application from calling me at wrong times */ + if (cinfo->global_state != DSTATE_BUFIMAGE) + ERREXIT1(cinfo, JERR_BAD_STATE, cinfo->global_state); + + if (cinfo->quantize_colors && cinfo->enable_external_quant && + cinfo->colormap != NULL) { + /* Select 2-pass quantizer for external colormap use */ + cinfo->cquantize = master->quantizer_2pass; + /* Notify quantizer of colormap change */ + (*cinfo->cquantize->new_color_map) (cinfo); + master->pub.is_dummy_pass = FALSE; /* just in case */ + } else + ERREXIT(cinfo, JERR_MODE_CHANGE); +} + +#endif /* D_MULTISCAN_FILES_SUPPORTED */ + + +/* + * Initialize master decompression control and select active modules. + * This is performed at the start of jpeg_start_decompress. + */ + +GLOBAL(void) +jinit_master_decompress (j_decompress_ptr cinfo) +{ + my_master_ptr master; + + master = (my_master_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_decomp_master)); + cinfo->master = (struct jpeg_decomp_master *) master; + master->pub.prepare_for_output_pass = prepare_for_output_pass; + master->pub.finish_output_pass = finish_output_pass; + + master->pub.is_dummy_pass = FALSE; + + master_selection(cinfo); +} diff --git a/ml/dlib/dlib/external/libjpeg/jdmerge.cpp b/ml/dlib/dlib/external/libjpeg/jdmerge.cpp new file mode 100644 index 000000000..38a692ff8 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdmerge.cpp @@ -0,0 +1,400 @@ +/* + * jdmerge.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains code for merged upsampling/color conversion. + * + * This file combines functions from jdsample.c and jdcolor.c; + * read those files first to understand what's going on. + * + * When the chroma components are to be upsampled by simple replication + * (ie, box filtering), we can save some work in color conversion by + * calculating all the output pixels corresponding to a pair of chroma + * samples at one time. In the conversion equations + * R = Y + K1 * Cr + * G = Y + K2 * Cb + K3 * Cr + * B = Y + K4 * Cb + * only the Y term varies among the group of pixels corresponding to a pair + * of chroma samples, so the rest of the terms can be calculated just once. + * At typical sampling ratios, this eliminates half or three-quarters of the + * multiplications needed for color conversion. + * + * This file currently provides implementations for the following cases: + * YCbCr => RGB color conversion only. + * Sampling ratios of 2h1v or 2h2v. + * No scaling needed at upsample time. + * Corner-aligned (non-CCIR601) sampling alignment. + * Other special cases could be added, but in most applications these are + * the only common cases. (For uncommon cases we fall back on the more + * general code in jdsample.c and jdcolor.c.) + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + +#ifdef UPSAMPLE_MERGING_SUPPORTED + + +/* Private subobject */ + +typedef struct { + struct jpeg_upsampler pub; /* public fields */ + + /* Pointer to routine to do actual upsampling/conversion of one row group */ + JMETHOD(void, upmethod, (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION in_row_group_ctr, + JSAMPARRAY output_buf)); + + /* Private state for YCC->RGB conversion */ + int * Cr_r_tab; /* => table for Cr to R conversion */ + int * Cb_b_tab; /* => table for Cb to B conversion */ + long * Cr_g_tab; /* => table for Cr to G conversion */ + long * Cb_g_tab; /* => table for Cb to G conversion */ + + /* For 2:1 vertical sampling, we produce two output rows at a time. + * We need a "spare" row buffer to hold the second output row if the + * application provides just a one-row buffer; we also use the spare + * to discard the dummy last row if the image height is odd. + */ + JSAMPROW spare_row; + int spare_full; /* T if spare buffer is occupied */ + + JDIMENSION out_row_width; /* samples per output row */ + JDIMENSION rows_to_go; /* counts rows remaining in image */ +} my_upsampler; + +typedef my_upsampler * my_upsample_ptr; + +#define SCALEBITS 16 /* speediest right-shift on some machines */ +#define ONE_HALF ((long) 1 << (SCALEBITS-1)) +#define FIX(x) ((long) ((x) * (1L<RGB colorspace conversion. + * This is taken directly from jdcolor.c; see that file for more info. + */ + +LOCAL(void) +build_ycc_rgb_table (j_decompress_ptr cinfo) +{ + my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample; + int i; + long x; + SHIFT_TEMPS + + upsample->Cr_r_tab = (int *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (MAXJSAMPLE+1) * SIZEOF(int)); + upsample->Cb_b_tab = (int *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (MAXJSAMPLE+1) * SIZEOF(int)); + upsample->Cr_g_tab = (long *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (MAXJSAMPLE+1) * SIZEOF(long)); + upsample->Cb_g_tab = (long *) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (MAXJSAMPLE+1) * SIZEOF(long)); + + for (i = 0, x = -CENTERJSAMPLE; i <= MAXJSAMPLE; i++, x++) { + /* i is the actual input pixel value, in the range 0..MAXJSAMPLE */ + /* The Cb or Cr value we are thinking of is x = i - CENTERJSAMPLE */ + /* Cr=>R value is nearest int to 1.40200 * x */ + upsample->Cr_r_tab[i] = (int) + RIGHT_SHIFT(FIX(1.40200) * x + ONE_HALF, SCALEBITS); + /* Cb=>B value is nearest int to 1.77200 * x */ + upsample->Cb_b_tab[i] = (int) + RIGHT_SHIFT(FIX(1.77200) * x + ONE_HALF, SCALEBITS); + /* Cr=>G value is scaled-up -0.71414 * x */ + upsample->Cr_g_tab[i] = (- FIX(0.71414)) * x; + /* Cb=>G value is scaled-up -0.34414 * x */ + /* We also add in ONE_HALF so that need not do it in inner loop */ + upsample->Cb_g_tab[i] = (- FIX(0.34414)) * x + ONE_HALF; + } +} + + +/* + * Initialize for an upsampling pass. + */ + +METHODDEF(void) +start_pass_merged_upsample (j_decompress_ptr cinfo) +{ + my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample; + + /* Mark the spare buffer empty */ + upsample->spare_full = FALSE; + /* Initialize total-height counter for detecting bottom of image */ + upsample->rows_to_go = cinfo->output_height; +} + + +/* + * Control routine to do upsampling (and color conversion). + * + * The control routine just handles the row buffering considerations. + */ + +METHODDEF(void) +merged_2v_upsample (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr, + JDIMENSION , + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail) +/* 2:1 vertical sampling case: may need a spare row. */ +{ + my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample; + JSAMPROW work_ptrs[2]; + JDIMENSION num_rows; /* number of rows returned to caller */ + + if (upsample->spare_full) { + /* If we have a spare row saved from a previous cycle, just return it. */ + jcopy_sample_rows(& upsample->spare_row, 0, output_buf + *out_row_ctr, 0, + 1, upsample->out_row_width); + num_rows = 1; + upsample->spare_full = FALSE; + } else { + /* Figure number of rows to return to caller. */ + num_rows = 2; + /* Not more than the distance to the end of the image. */ + if (num_rows > upsample->rows_to_go) + num_rows = upsample->rows_to_go; + /* And not more than what the client can accept: */ + out_rows_avail -= *out_row_ctr; + if (num_rows > out_rows_avail) + num_rows = out_rows_avail; + /* Create output pointer array for upsampler. */ + work_ptrs[0] = output_buf[*out_row_ctr]; + if (num_rows > 1) { + work_ptrs[1] = output_buf[*out_row_ctr + 1]; + } else { + work_ptrs[1] = upsample->spare_row; + upsample->spare_full = TRUE; + } + /* Now do the upsampling. */ + (*upsample->upmethod) (cinfo, input_buf, *in_row_group_ctr, work_ptrs); + } + + /* Adjust counts */ + *out_row_ctr += num_rows; + upsample->rows_to_go -= num_rows; + /* When the buffer is emptied, declare this input row group consumed */ + if (! upsample->spare_full) + (*in_row_group_ctr)++; +} + + +METHODDEF(void) +merged_1v_upsample (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr, + JDIMENSION , + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION ) +/* 1:1 vertical sampling case: much easier, never need a spare row. */ +{ + my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample; + + /* Just do the upsampling. */ + (*upsample->upmethod) (cinfo, input_buf, *in_row_group_ctr, + output_buf + *out_row_ctr); + /* Adjust counts */ + (*out_row_ctr)++; + (*in_row_group_ctr)++; +} + + +/* + * These are the routines invoked by the control routines to do + * the actual upsampling/conversion. One row group is processed per call. + * + * Note: since we may be writing directly into application-supplied buffers, + * we have to be honest about the output width; we can't assume the buffer + * has been rounded up to an even width. + */ + + +/* + * Upsample and color convert for the case of 2:1 horizontal and 1:1 vertical. + */ + +METHODDEF(void) +h2v1_merged_upsample (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION in_row_group_ctr, + JSAMPARRAY output_buf) +{ + my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample; + int y, cred, cgreen, cblue; + int cb, cr; + JSAMPROW outptr; + JSAMPROW inptr0, inptr1, inptr2; + JDIMENSION col; + /* copy these pointers into registers if possible */ + JSAMPLE * range_limit = cinfo->sample_range_limit; + int * Crrtab = upsample->Cr_r_tab; + int * Cbbtab = upsample->Cb_b_tab; + long * Crgtab = upsample->Cr_g_tab; + long * Cbgtab = upsample->Cb_g_tab; + SHIFT_TEMPS + + inptr0 = input_buf[0][in_row_group_ctr]; + inptr1 = input_buf[1][in_row_group_ctr]; + inptr2 = input_buf[2][in_row_group_ctr]; + outptr = output_buf[0]; + /* Loop for each pair of output pixels */ + for (col = cinfo->output_width >> 1; col > 0; col--) { + /* Do the chroma part of the calculation */ + cb = GETJSAMPLE(*inptr1++); + cr = GETJSAMPLE(*inptr2++); + cred = Crrtab[cr]; + cgreen = (int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], SCALEBITS); + cblue = Cbbtab[cb]; + /* Fetch 2 Y values and emit 2 pixels */ + y = GETJSAMPLE(*inptr0++); + outptr[RGB_RED] = range_limit[y + cred]; + outptr[RGB_GREEN] = range_limit[y + cgreen]; + outptr[RGB_BLUE] = range_limit[y + cblue]; + outptr += RGB_PIXELSIZE; + y = GETJSAMPLE(*inptr0++); + outptr[RGB_RED] = range_limit[y + cred]; + outptr[RGB_GREEN] = range_limit[y + cgreen]; + outptr[RGB_BLUE] = range_limit[y + cblue]; + outptr += RGB_PIXELSIZE; + } + /* If image width is odd, do the last output column separately */ + if (cinfo->output_width & 1) { + cb = GETJSAMPLE(*inptr1); + cr = GETJSAMPLE(*inptr2); + cred = Crrtab[cr]; + cgreen = (int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], SCALEBITS); + cblue = Cbbtab[cb]; + y = GETJSAMPLE(*inptr0); + outptr[RGB_RED] = range_limit[y + cred]; + outptr[RGB_GREEN] = range_limit[y + cgreen]; + outptr[RGB_BLUE] = range_limit[y + cblue]; + } +} + + +/* + * Upsample and color convert for the case of 2:1 horizontal and 2:1 vertical. + */ + +METHODDEF(void) +h2v2_merged_upsample (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION in_row_group_ctr, + JSAMPARRAY output_buf) +{ + my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample; + int y, cred, cgreen, cblue; + int cb, cr; + JSAMPROW outptr0, outptr1; + JSAMPROW inptr00, inptr01, inptr1, inptr2; + JDIMENSION col; + /* copy these pointers into registers if possible */ + JSAMPLE * range_limit = cinfo->sample_range_limit; + int * Crrtab = upsample->Cr_r_tab; + int * Cbbtab = upsample->Cb_b_tab; + long * Crgtab = upsample->Cr_g_tab; + long * Cbgtab = upsample->Cb_g_tab; + SHIFT_TEMPS + + inptr00 = input_buf[0][in_row_group_ctr*2]; + inptr01 = input_buf[0][in_row_group_ctr*2 + 1]; + inptr1 = input_buf[1][in_row_group_ctr]; + inptr2 = input_buf[2][in_row_group_ctr]; + outptr0 = output_buf[0]; + outptr1 = output_buf[1]; + /* Loop for each group of output pixels */ + for (col = cinfo->output_width >> 1; col > 0; col--) { + /* Do the chroma part of the calculation */ + cb = GETJSAMPLE(*inptr1++); + cr = GETJSAMPLE(*inptr2++); + cred = Crrtab[cr]; + cgreen = (int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], SCALEBITS); + cblue = Cbbtab[cb]; + /* Fetch 4 Y values and emit 4 pixels */ + y = GETJSAMPLE(*inptr00++); + outptr0[RGB_RED] = range_limit[y + cred]; + outptr0[RGB_GREEN] = range_limit[y + cgreen]; + outptr0[RGB_BLUE] = range_limit[y + cblue]; + outptr0 += RGB_PIXELSIZE; + y = GETJSAMPLE(*inptr00++); + outptr0[RGB_RED] = range_limit[y + cred]; + outptr0[RGB_GREEN] = range_limit[y + cgreen]; + outptr0[RGB_BLUE] = range_limit[y + cblue]; + outptr0 += RGB_PIXELSIZE; + y = GETJSAMPLE(*inptr01++); + outptr1[RGB_RED] = range_limit[y + cred]; + outptr1[RGB_GREEN] = range_limit[y + cgreen]; + outptr1[RGB_BLUE] = range_limit[y + cblue]; + outptr1 += RGB_PIXELSIZE; + y = GETJSAMPLE(*inptr01++); + outptr1[RGB_RED] = range_limit[y + cred]; + outptr1[RGB_GREEN] = range_limit[y + cgreen]; + outptr1[RGB_BLUE] = range_limit[y + cblue]; + outptr1 += RGB_PIXELSIZE; + } + /* If image width is odd, do the last output column separately */ + if (cinfo->output_width & 1) { + cb = GETJSAMPLE(*inptr1); + cr = GETJSAMPLE(*inptr2); + cred = Crrtab[cr]; + cgreen = (int) RIGHT_SHIFT(Cbgtab[cb] + Crgtab[cr], SCALEBITS); + cblue = Cbbtab[cb]; + y = GETJSAMPLE(*inptr00); + outptr0[RGB_RED] = range_limit[y + cred]; + outptr0[RGB_GREEN] = range_limit[y + cgreen]; + outptr0[RGB_BLUE] = range_limit[y + cblue]; + y = GETJSAMPLE(*inptr01); + outptr1[RGB_RED] = range_limit[y + cred]; + outptr1[RGB_GREEN] = range_limit[y + cgreen]; + outptr1[RGB_BLUE] = range_limit[y + cblue]; + } +} + + +/* + * Module initialization routine for merged upsampling/color conversion. + * + * NB: this is called under the conditions determined by use_merged_upsample() + * in jdmaster.c. That routine MUST correspond to the actual capabilities + * of this module; no safety checks are made here. + */ + +GLOBAL(void) +jinit_merged_upsampler (j_decompress_ptr cinfo) +{ + my_upsample_ptr upsample; + + upsample = (my_upsample_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_upsampler)); + cinfo->upsample = (struct jpeg_upsampler *) upsample; + upsample->pub.start_pass = start_pass_merged_upsample; + upsample->pub.need_context_rows = FALSE; + + upsample->out_row_width = cinfo->output_width * cinfo->out_color_components; + + if (cinfo->max_v_samp_factor == 2) { + upsample->pub.upsample = merged_2v_upsample; + upsample->upmethod = h2v2_merged_upsample; + /* Allocate a spare row buffer */ + upsample->spare_row = (JSAMPROW) + (*cinfo->mem->alloc_large) ((j_common_ptr) cinfo, JPOOL_IMAGE, + (size_t) (upsample->out_row_width * SIZEOF(JSAMPLE))); + } else { + upsample->pub.upsample = merged_1v_upsample; + upsample->upmethod = h2v1_merged_upsample; + /* No spare row needed */ + upsample->spare_row = NULL; + } + + build_ycc_rgb_table(cinfo); +} + +#endif /* UPSAMPLE_MERGING_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jdphuff.cpp b/ml/dlib/dlib/external/libjpeg/jdphuff.cpp new file mode 100644 index 000000000..d9c02fe0b --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdphuff.cpp @@ -0,0 +1,671 @@ +/* + * jdphuff.c + * + * Copyright (C) 1995-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains Huffman entropy decoding routines for progressive JPEG. + * + * Much of the complexity here has to do with supporting input suspension. + * If the data source module demands suspension, we want to be able to back + * up to the start of the current MCU. To do this, we copy state variables + * into local working storage, and update them back to the permanent + * storage only upon successful completion of an MCU. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdhuff.h" /* Declarations shared with jdhuff.c */ + +#ifdef __GNUC__ +#pragma GCC diagnostic ignored "-Wshift-negative-value" +#endif + +#ifdef D_PROGRESSIVE_SUPPORTED + +/* + * Expanded entropy decoder object for progressive Huffman decoding. + * + * The savable_state subrecord contains fields that change within an MCU, + * but must not be updated permanently until we complete the MCU. + */ + +typedef struct { + unsigned int EOBRUN; /* remaining EOBs in EOBRUN */ + int last_dc_val[MAX_COMPS_IN_SCAN]; /* last DC coef for each component */ +} savable_state; + +/* This macro is to work around compilers with missing or broken + * structure assignment. You'll need to fix this code if you have + * such a compiler and you change MAX_COMPS_IN_SCAN. + */ + +#ifndef NO_STRUCT_ASSIGN +#define ASSIGN_STATE(dest,src) ((dest) = (src)) +#else +#if MAX_COMPS_IN_SCAN == 4 +#define ASSIGN_STATE(dest,src) \ + ((dest).EOBRUN = (src).EOBRUN, \ + (dest).last_dc_val[0] = (src).last_dc_val[0], \ + (dest).last_dc_val[1] = (src).last_dc_val[1], \ + (dest).last_dc_val[2] = (src).last_dc_val[2], \ + (dest).last_dc_val[3] = (src).last_dc_val[3]) +#endif +#endif + + +typedef struct { + struct jpeg_entropy_decoder pub; /* public fields */ + + /* These fields are loaded into local variables at start of each MCU. + * In case of suspension, we exit WITHOUT updating them. + */ + bitread_perm_state bitstate; /* Bit buffer at start of MCU */ + savable_state saved; /* Other state at start of MCU */ + + /* These fields are NOT loaded into local working state. */ + unsigned int restarts_to_go; /* MCUs left in this restart interval */ + + /* Pointers to derived tables (these workspaces have image lifespan) */ + d_derived_tbl * derived_tbls[NUM_HUFF_TBLS]; + + d_derived_tbl * ac_derived_tbl; /* active table during an AC scan */ +} phuff_entropy_decoder; + +typedef phuff_entropy_decoder * phuff_entropy_ptr; + +/* Forward declarations */ +METHODDEF(int) decode_mcu_DC_first JPP((j_decompress_ptr cinfo, + JBLOCKROW *MCU_data)); +METHODDEF(int) decode_mcu_AC_first JPP((j_decompress_ptr cinfo, + JBLOCKROW *MCU_data)); +METHODDEF(int) decode_mcu_DC_refine JPP((j_decompress_ptr cinfo, + JBLOCKROW *MCU_data)); +METHODDEF(int) decode_mcu_AC_refine JPP((j_decompress_ptr cinfo, + JBLOCKROW *MCU_data)); + + +/* + * Initialize for a Huffman-compressed scan. + */ + +METHODDEF(void) +start_pass_phuff_decoder (j_decompress_ptr cinfo) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int is_DC_band, bad; + int ci, coefi, tbl; + int *coef_bit_ptr; + jpeg_component_info * compptr; + + is_DC_band = (cinfo->Ss == 0); + + /* Validate scan parameters */ + bad = FALSE; + if (is_DC_band) { + if (cinfo->Se != 0) + bad = TRUE; + } else { + /* need not check Ss/Se < 0 since they came from unsigned bytes */ + if (cinfo->Ss > cinfo->Se || cinfo->Se >= DCTSIZE2) + bad = TRUE; + /* AC scans may have only one component */ + if (cinfo->comps_in_scan != 1) + bad = TRUE; + } + if (cinfo->Ah != 0) { + /* Successive approximation refinement scan: must have Al = Ah-1. */ + if (cinfo->Al != cinfo->Ah-1) + bad = TRUE; + } + if (cinfo->Al > 13) /* need not check for < 0 */ + bad = TRUE; + /* Arguably the maximum Al value should be less than 13 for 8-bit precision, + * but the spec doesn't say so, and we try to be liberal about what we + * accept. Note: large Al values could result in out-of-range DC + * coefficients during early scans, leading to bizarre displays due to + * overflows in the IDCT math. But we won't crash. + */ + if (bad) + ERREXIT4(cinfo, JERR_BAD_PROGRESSION, + cinfo->Ss, cinfo->Se, cinfo->Ah, cinfo->Al); + /* Update progression status, and verify that scan order is legal. + * Note that inter-scan inconsistencies are treated as warnings + * not fatal errors ... not clear if this is right way to behave. + */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + int cindex = cinfo->cur_comp_info[ci]->component_index; + coef_bit_ptr = & cinfo->coef_bits[cindex][0]; + if (!is_DC_band && coef_bit_ptr[0] < 0) /* AC without prior DC scan */ + WARNMS2(cinfo, JWRN_BOGUS_PROGRESSION, cindex, 0); + for (coefi = cinfo->Ss; coefi <= cinfo->Se; coefi++) { + int expected = (coef_bit_ptr[coefi] < 0) ? 0 : coef_bit_ptr[coefi]; + if (cinfo->Ah != expected) + WARNMS2(cinfo, JWRN_BOGUS_PROGRESSION, cindex, coefi); + coef_bit_ptr[coefi] = cinfo->Al; + } + } + + /* Select MCU decoding routine */ + if (cinfo->Ah == 0) { + if (is_DC_band) + entropy->pub.decode_mcu = decode_mcu_DC_first; + else + entropy->pub.decode_mcu = decode_mcu_AC_first; + } else { + if (is_DC_band) + entropy->pub.decode_mcu = decode_mcu_DC_refine; + else + entropy->pub.decode_mcu = decode_mcu_AC_refine; + } + + for (ci = 0; ci < cinfo->comps_in_scan; ci++) { + compptr = cinfo->cur_comp_info[ci]; + /* Make sure requested tables are present, and compute derived tables. + * We may build same derived table more than once, but it's not expensive. + */ + if (is_DC_band) { + if (cinfo->Ah == 0) { /* DC refinement needs no table */ + tbl = compptr->dc_tbl_no; + jpeg_make_d_derived_tbl(cinfo, TRUE, tbl, + & entropy->derived_tbls[tbl]); + } + } else { + tbl = compptr->ac_tbl_no; + jpeg_make_d_derived_tbl(cinfo, FALSE, tbl, + & entropy->derived_tbls[tbl]); + /* remember the single active table */ + entropy->ac_derived_tbl = entropy->derived_tbls[tbl]; + } + /* Initialize DC predictions to 0 */ + entropy->saved.last_dc_val[ci] = 0; + } + + /* Initialize bitread state variables */ + entropy->bitstate.bits_left = 0; + entropy->bitstate.get_buffer = 0; /* unnecessary, but keeps Purify quiet */ + entropy->pub.insufficient_data = FALSE; + + /* Initialize private state variables */ + entropy->saved.EOBRUN = 0; + + /* Initialize restart counter */ + entropy->restarts_to_go = cinfo->restart_interval; +} + + +/* + * Figure F.12: extend sign bit. + * On some machines, a shift and add will be faster than a table lookup. + */ + +#ifdef AVOID_TABLES + +#define HUFF_EXTEND(x,s) ((x) < (1<<((s)-1)) ? (x) + (((-1)<<(s)) + 1) : (x)) + +#else + +#define HUFF_EXTEND(x,s) ((x) < extend_test[s] ? (x) + extend_offset[s] : (x)) + +static const int extend_test[16] = /* entry n is 2**(n-1) */ + { 0, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, + 0x0100, 0x0200, 0x0400, 0x0800, 0x1000, 0x2000, 0x4000 }; + +static const int extend_offset[16] = /* entry n is (-1 << n) + 1 */ + { 0, ((-1)<<1) + 1, ((-1)<<2) + 1, ((-1)<<3) + 1, ((-1)<<4) + 1, + ((-1)<<5) + 1, ((-1)<<6) + 1, ((-1)<<7) + 1, ((-1)<<8) + 1, + ((-1)<<9) + 1, ((-1)<<10) + 1, ((-1)<<11) + 1, ((-1)<<12) + 1, + ((-1)<<13) + 1, ((-1)<<14) + 1, ((-1)<<15) + 1 }; + +#endif /* AVOID_TABLES */ + + +/* + * Check for a restart marker & resynchronize decoder. + * Returns FALSE if must suspend. + */ + +LOCAL(int) +process_restart (j_decompress_ptr cinfo) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int ci; + + /* Throw away any unused bits remaining in bit buffer; */ + /* include any full bytes in next_marker's count of discarded bytes */ + cinfo->marker->discarded_bytes += entropy->bitstate.bits_left / 8; + entropy->bitstate.bits_left = 0; + + /* Advance past the RSTn marker */ + if (! (*cinfo->marker->read_restart_marker) (cinfo)) + return FALSE; + + /* Re-initialize DC predictions to 0 */ + for (ci = 0; ci < cinfo->comps_in_scan; ci++) + entropy->saved.last_dc_val[ci] = 0; + /* Re-init EOB run count, too */ + entropy->saved.EOBRUN = 0; + + /* Reset restart counter */ + entropy->restarts_to_go = cinfo->restart_interval; + + /* Reset out-of-data flag, unless read_restart_marker left us smack up + * against a marker. In that case we will end up treating the next data + * segment as empty, and we can avoid producing bogus output pixels by + * leaving the flag set. + */ + if (cinfo->unread_marker == 0) + entropy->pub.insufficient_data = FALSE; + + return TRUE; +} + + +/* + * Huffman MCU decoding. + * Each of these routines decodes and returns one MCU's worth of + * Huffman-compressed coefficients. + * The coefficients are reordered from zigzag order into natural array order, + * but are not dequantized. + * + * The i'th block of the MCU is stored into the block pointed to by + * MCU_data[i]. WE ASSUME THIS AREA IS INITIALLY ZEROED BY THE CALLER. + * + * We return FALSE if data source requested suspension. In that case no + * changes have been made to permanent state. (Exception: some output + * coefficients may already have been assigned. This is harmless for + * spectral selection, since we'll just re-assign them on the next call. + * Successive approximation AC refinement has to be more careful, however.) + */ + +/* + * MCU decoding for DC initial scan (either spectral selection, + * or first pass of successive approximation). + */ + +METHODDEF(int) +decode_mcu_DC_first (j_decompress_ptr cinfo, JBLOCKROW *MCU_data) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int Al = cinfo->Al; + int s, r; + int blkn, ci; + JBLOCKROW block; + BITREAD_STATE_VARS; + savable_state state; + d_derived_tbl * tbl; + jpeg_component_info * compptr; + + /* Process restart marker if needed; may have to suspend */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) + if (! process_restart(cinfo)) + return FALSE; + } + + /* If we've run out of data, just leave the MCU set to zeroes. + * This way, we return uniform gray for the remainder of the segment. + */ + if (! entropy->pub.insufficient_data) { + + /* Load up working state */ + BITREAD_LOAD_STATE(cinfo,entropy->bitstate); + ASSIGN_STATE(state, entropy->saved); + + /* Outer loop handles each block in the MCU */ + + for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) { + block = MCU_data[blkn]; + ci = cinfo->MCU_membership[blkn]; + compptr = cinfo->cur_comp_info[ci]; + tbl = entropy->derived_tbls[compptr->dc_tbl_no]; + + /* Decode a single block's worth of coefficients */ + + /* Section F.2.2.1: decode the DC coefficient difference */ + HUFF_DECODE(s, br_state, tbl, return FALSE, label1); + if (s) { + CHECK_BIT_BUFFER(br_state, s, return FALSE); + r = GET_BITS(s); + s = HUFF_EXTEND(r, s); + } + + /* Convert DC difference to actual value, update last_dc_val */ + s += state.last_dc_val[ci]; + state.last_dc_val[ci] = s; + /* Scale and output the coefficient (assumes jpeg_natural_order[0]=0) */ + (*block)[0] = (JCOEF) (s << Al); + } + + /* Completed MCU, so update state */ + BITREAD_SAVE_STATE(cinfo,entropy->bitstate); + ASSIGN_STATE(entropy->saved, state); + } + + /* Account for restart interval (no-op if not using restarts) */ + entropy->restarts_to_go--; + + return TRUE; +} + + +/* + * MCU decoding for AC initial scan (either spectral selection, + * or first pass of successive approximation). + */ + +METHODDEF(int) +decode_mcu_AC_first (j_decompress_ptr cinfo, JBLOCKROW *MCU_data) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int Se = cinfo->Se; + int Al = cinfo->Al; + int s, k, r; + unsigned int EOBRUN; + JBLOCKROW block; + BITREAD_STATE_VARS; + d_derived_tbl * tbl; + + /* Process restart marker if needed; may have to suspend */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) + if (! process_restart(cinfo)) + return FALSE; + } + + /* If we've run out of data, just leave the MCU set to zeroes. + * This way, we return uniform gray for the remainder of the segment. + */ + if (! entropy->pub.insufficient_data) { + + /* Load up working state. + * We can avoid loading/saving bitread state if in an EOB run. + */ + EOBRUN = entropy->saved.EOBRUN; /* only part of saved state we need */ + + /* There is always only one block per MCU */ + + if (EOBRUN > 0) /* if it's a band of zeroes... */ + EOBRUN--; /* ...process it now (we do nothing) */ + else { + BITREAD_LOAD_STATE(cinfo,entropy->bitstate); + block = MCU_data[0]; + tbl = entropy->ac_derived_tbl; + + for (k = cinfo->Ss; k <= Se; k++) { + HUFF_DECODE(s, br_state, tbl, return FALSE, label2); + r = s >> 4; + s &= 15; + if (s) { + k += r; + CHECK_BIT_BUFFER(br_state, s, return FALSE); + r = GET_BITS(s); + s = HUFF_EXTEND(r, s); + /* Scale and output coefficient in natural (dezigzagged) order */ + (*block)[jpeg_natural_order[k]] = (JCOEF) (s << Al); + } else { + if (r == 15) { /* ZRL */ + k += 15; /* skip 15 zeroes in band */ + } else { /* EOBr, run length is 2^r + appended bits */ + EOBRUN = 1 << r; + if (r) { /* EOBr, r > 0 */ + CHECK_BIT_BUFFER(br_state, r, return FALSE); + r = GET_BITS(r); + EOBRUN += r; + } + EOBRUN--; /* this band is processed at this moment */ + break; /* force end-of-band */ + } + } + } + + BITREAD_SAVE_STATE(cinfo,entropy->bitstate); + } + + /* Completed MCU, so update state */ + entropy->saved.EOBRUN = EOBRUN; /* only part of saved state we need */ + } + + /* Account for restart interval (no-op if not using restarts) */ + entropy->restarts_to_go--; + + return TRUE; +} + + +/* + * MCU decoding for DC successive approximation refinement scan. + * Note: we assume such scans can be multi-component, although the spec + * is not very clear on the point. + */ + +METHODDEF(int) +decode_mcu_DC_refine (j_decompress_ptr cinfo, JBLOCKROW *MCU_data) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int p1 = 1 << cinfo->Al; /* 1 in the bit position being coded */ + int blkn; + JBLOCKROW block; + BITREAD_STATE_VARS; + + /* Process restart marker if needed; may have to suspend */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) + if (! process_restart(cinfo)) + return FALSE; + } + + /* Not worth the cycles to check insufficient_data here, + * since we will not change the data anyway if we read zeroes. + */ + + /* Load up working state */ + BITREAD_LOAD_STATE(cinfo,entropy->bitstate); + + /* Outer loop handles each block in the MCU */ + + for (blkn = 0; blkn < cinfo->blocks_in_MCU; blkn++) { + block = MCU_data[blkn]; + + /* Encoded data is simply the next bit of the two's-complement DC value */ + CHECK_BIT_BUFFER(br_state, 1, return FALSE); + if (GET_BITS(1)) + (*block)[0] |= p1; + /* Note: since we use |=, repeating the assignment later is safe */ + } + + /* Completed MCU, so update state */ + BITREAD_SAVE_STATE(cinfo,entropy->bitstate); + + /* Account for restart interval (no-op if not using restarts) */ + entropy->restarts_to_go--; + + return TRUE; +} + + +/* + * MCU decoding for AC successive approximation refinement scan. + */ + +METHODDEF(int) +decode_mcu_AC_refine (j_decompress_ptr cinfo, JBLOCKROW *MCU_data) +{ + phuff_entropy_ptr entropy = (phuff_entropy_ptr) cinfo->entropy; + int Se = cinfo->Se; + int p1 = 1 << cinfo->Al; /* 1 in the bit position being coded */ + int m1 = (-1) << cinfo->Al; /* -1 in the bit position being coded */ + int s, k, r; + unsigned int EOBRUN; + JBLOCKROW block; + JCOEFPTR thiscoef; + BITREAD_STATE_VARS; + d_derived_tbl * tbl; + int num_newnz; + int newnz_pos[DCTSIZE2]; + + /* Process restart marker if needed; may have to suspend */ + if (cinfo->restart_interval) { + if (entropy->restarts_to_go == 0) + if (! process_restart(cinfo)) + return FALSE; + } + + /* If we've run out of data, don't modify the MCU. + */ + if (! entropy->pub.insufficient_data) { + + /* Load up working state */ + BITREAD_LOAD_STATE(cinfo,entropy->bitstate); + EOBRUN = entropy->saved.EOBRUN; /* only part of saved state we need */ + + /* There is always only one block per MCU */ + block = MCU_data[0]; + tbl = entropy->ac_derived_tbl; + + /* If we are forced to suspend, we must undo the assignments to any newly + * nonzero coefficients in the block, because otherwise we'd get confused + * next time about which coefficients were already nonzero. + * But we need not undo addition of bits to already-nonzero coefficients; + * instead, we can test the current bit to see if we already did it. + */ + num_newnz = 0; + + /* initialize coefficient loop counter to start of band */ + k = cinfo->Ss; + + if (EOBRUN == 0) { + for (; k <= Se; k++) { + HUFF_DECODE(s, br_state, tbl, goto undoit, label3); + r = s >> 4; + s &= 15; + if (s) { + if (s != 1) /* size of new coef should always be 1 */ + WARNMS(cinfo, JWRN_HUFF_BAD_CODE); + CHECK_BIT_BUFFER(br_state, 1, goto undoit); + if (GET_BITS(1)) + s = p1; /* newly nonzero coef is positive */ + else + s = m1; /* newly nonzero coef is negative */ + } else { + if (r != 15) { + EOBRUN = 1 << r; /* EOBr, run length is 2^r + appended bits */ + if (r) { + CHECK_BIT_BUFFER(br_state, r, goto undoit); + r = GET_BITS(r); + EOBRUN += r; + } + break; /* rest of block is handled by EOB logic */ + } + /* note s = 0 for processing ZRL */ + } + /* Advance over already-nonzero coefs and r still-zero coefs, + * appending correction bits to the nonzeroes. A correction bit is 1 + * if the absolute value of the coefficient must be increased. + */ + do { + thiscoef = *block + jpeg_natural_order[k]; + if (*thiscoef != 0) { + CHECK_BIT_BUFFER(br_state, 1, goto undoit); + if (GET_BITS(1)) { + if ((*thiscoef & p1) == 0) { /* do nothing if already set it */ + if (*thiscoef >= 0) + *thiscoef += p1; + else + *thiscoef += m1; + } + } + } else { + if (--r < 0) + break; /* reached target zero coefficient */ + } + k++; + } while (k <= Se); + if (s) { + int pos = jpeg_natural_order[k]; + /* Output newly nonzero coefficient */ + (*block)[pos] = (JCOEF) s; + /* Remember its position in case we have to suspend */ + newnz_pos[num_newnz++] = pos; + } + } + } + + if (EOBRUN > 0) { + /* Scan any remaining coefficient positions after the end-of-band + * (the last newly nonzero coefficient, if any). Append a correction + * bit to each already-nonzero coefficient. A correction bit is 1 + * if the absolute value of the coefficient must be increased. + */ + for (; k <= Se; k++) { + thiscoef = *block + jpeg_natural_order[k]; + if (*thiscoef != 0) { + CHECK_BIT_BUFFER(br_state, 1, goto undoit); + if (GET_BITS(1)) { + if ((*thiscoef & p1) == 0) { /* do nothing if already changed it */ + if (*thiscoef >= 0) + *thiscoef += p1; + else + *thiscoef += m1; + } + } + } + } + /* Count one block completed in EOB run */ + EOBRUN--; + } + + /* Completed MCU, so update state */ + BITREAD_SAVE_STATE(cinfo,entropy->bitstate); + entropy->saved.EOBRUN = EOBRUN; /* only part of saved state we need */ + } + + /* Account for restart interval (no-op if not using restarts) */ + entropy->restarts_to_go--; + + return TRUE; + +undoit: + /* Re-zero any output coefficients that we made newly nonzero */ + while (num_newnz > 0) + (*block)[newnz_pos[--num_newnz]] = 0; + + return FALSE; +} + + +/* + * Module initialization routine for progressive Huffman entropy decoding. + */ + +GLOBAL(void) +jinit_phuff_decoder (j_decompress_ptr cinfo) +{ + phuff_entropy_ptr entropy; + int *coef_bit_ptr; + int ci, i; + + entropy = (phuff_entropy_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(phuff_entropy_decoder)); + cinfo->entropy = (struct jpeg_entropy_decoder *) entropy; + entropy->pub.start_pass = start_pass_phuff_decoder; + + /* Mark derived tables unallocated */ + for (i = 0; i < NUM_HUFF_TBLS; i++) { + entropy->derived_tbls[i] = NULL; + } + + /* Create progression status table */ + cinfo->coef_bits = (int (*)[DCTSIZE2]) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + cinfo->num_components*DCTSIZE2*SIZEOF(int)); + coef_bit_ptr = & cinfo->coef_bits[0][0]; + for (ci = 0; ci < cinfo->num_components; ci++) + for (i = 0; i < DCTSIZE2; i++) + *coef_bit_ptr++ = -1; +} + +#endif /* D_PROGRESSIVE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jdpostct.cpp b/ml/dlib/dlib/external/libjpeg/jdpostct.cpp new file mode 100644 index 000000000..63e10ec17 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdpostct.cpp @@ -0,0 +1,290 @@ +/* + * jdpostct.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains the decompression postprocessing controller. + * This controller manages the upsampling, color conversion, and color + * quantization/reduction steps; specifically, it controls the buffering + * between upsample/color conversion and color quantization/reduction. + * + * If no color quantization/reduction is required, then this module has no + * work to do, and it just hands off to the upsample/color conversion code. + * An integrated upsample/convert/quantize process would replace this module + * entirely. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Private buffer controller object */ + +typedef struct { + struct jpeg_d_post_controller pub; /* public fields */ + + /* Color quantization source buffer: this holds output data from + * the upsample/color conversion step to be passed to the quantizer. + * For two-pass color quantization, we need a full-image buffer; + * for one-pass operation, a strip buffer is sufficient. + */ + jvirt_sarray_ptr whole_image; /* virtual array, or NULL if one-pass */ + JSAMPARRAY buffer; /* strip buffer, or current strip of virtual */ + JDIMENSION strip_height; /* buffer size in rows */ + /* for two-pass mode only: */ + JDIMENSION starting_row; /* row # of first row in current strip */ + JDIMENSION next_row; /* index of next row to fill/empty in strip */ +} my_post_controller; + +typedef my_post_controller * my_post_ptr; + + +/* Forward declarations */ +METHODDEF(void) post_process_1pass + JPP((j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr, + JDIMENSION in_row_groups_avail, + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail)); +#ifdef QUANT_2PASS_SUPPORTED +METHODDEF(void) post_process_prepass + JPP((j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr, + JDIMENSION in_row_groups_avail, + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail)); +METHODDEF(void) post_process_2pass + JPP((j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr, + JDIMENSION in_row_groups_avail, + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail)); +#endif + + +/* + * Initialize for a processing pass. + */ + +METHODDEF(void) +start_pass_dpost (j_decompress_ptr cinfo, J_BUF_MODE pass_mode) +{ + my_post_ptr post = (my_post_ptr) cinfo->post; + + switch (pass_mode) { + case JBUF_PASS_THRU: + if (cinfo->quantize_colors) { + /* Single-pass processing with color quantization. */ + post->pub.post_process_data = post_process_1pass; + /* We could be doing buffered-image output before starting a 2-pass + * color quantization; in that case, jinit_d_post_controller did not + * allocate a strip buffer. Use the virtual-array buffer as workspace. + */ + if (post->buffer == NULL) { + post->buffer = (*cinfo->mem->access_virt_sarray) + ((j_common_ptr) cinfo, post->whole_image, + (JDIMENSION) 0, post->strip_height, TRUE); + } + } else { + /* For single-pass processing without color quantization, + * I have no work to do; just call the upsampler directly. + */ + post->pub.post_process_data = cinfo->upsample->upsample; + } + break; +#ifdef QUANT_2PASS_SUPPORTED + case JBUF_SAVE_AND_PASS: + /* First pass of 2-pass quantization */ + if (post->whole_image == NULL) + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + post->pub.post_process_data = post_process_prepass; + break; + case JBUF_CRANK_DEST: + /* Second pass of 2-pass quantization */ + if (post->whole_image == NULL) + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + post->pub.post_process_data = post_process_2pass; + break; +#endif /* QUANT_2PASS_SUPPORTED */ + default: + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); + break; + } + post->starting_row = post->next_row = 0; +} + + +/* + * Process some data in the one-pass (strip buffer) case. + * This is used for color precision reduction as well as one-pass quantization. + */ + +METHODDEF(void) +post_process_1pass (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr, + JDIMENSION in_row_groups_avail, + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail) +{ + my_post_ptr post = (my_post_ptr) cinfo->post; + JDIMENSION num_rows, max_rows; + + /* Fill the buffer, but not more than what we can dump out in one go. */ + /* Note we rely on the upsampler to detect bottom of image. */ + max_rows = out_rows_avail - *out_row_ctr; + if (max_rows > post->strip_height) + max_rows = post->strip_height; + num_rows = 0; + (*cinfo->upsample->upsample) (cinfo, + input_buf, in_row_group_ctr, in_row_groups_avail, + post->buffer, &num_rows, max_rows); + /* Quantize and emit data. */ + (*cinfo->cquantize->color_quantize) (cinfo, + post->buffer, output_buf + *out_row_ctr, (int) num_rows); + *out_row_ctr += num_rows; +} + + +#ifdef QUANT_2PASS_SUPPORTED + +/* + * Process some data in the first pass of 2-pass quantization. + */ + +METHODDEF(void) +post_process_prepass (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr, + JDIMENSION in_row_groups_avail, + JSAMPARRAY , JDIMENSION *out_row_ctr, + JDIMENSION ) +{ + my_post_ptr post = (my_post_ptr) cinfo->post; + JDIMENSION old_next_row, num_rows; + + /* Reposition virtual buffer if at start of strip. */ + if (post->next_row == 0) { + post->buffer = (*cinfo->mem->access_virt_sarray) + ((j_common_ptr) cinfo, post->whole_image, + post->starting_row, post->strip_height, TRUE); + } + + /* Upsample some data (up to a strip height's worth). */ + old_next_row = post->next_row; + (*cinfo->upsample->upsample) (cinfo, + input_buf, in_row_group_ctr, in_row_groups_avail, + post->buffer, &post->next_row, post->strip_height); + + /* Allow quantizer to scan new data. No data is emitted, */ + /* but we advance out_row_ctr so outer loop can tell when we're done. */ + if (post->next_row > old_next_row) { + num_rows = post->next_row - old_next_row; + (*cinfo->cquantize->color_quantize) (cinfo, post->buffer + old_next_row, + (JSAMPARRAY) NULL, (int) num_rows); + *out_row_ctr += num_rows; + } + + /* Advance if we filled the strip. */ + if (post->next_row >= post->strip_height) { + post->starting_row += post->strip_height; + post->next_row = 0; + } +} + + +/* + * Process some data in the second pass of 2-pass quantization. + */ + +METHODDEF(void) +post_process_2pass (j_decompress_ptr cinfo, + JSAMPIMAGE , JDIMENSION *, + JDIMENSION , + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail) +{ + my_post_ptr post = (my_post_ptr) cinfo->post; + JDIMENSION num_rows, max_rows; + + /* Reposition virtual buffer if at start of strip. */ + if (post->next_row == 0) { + post->buffer = (*cinfo->mem->access_virt_sarray) + ((j_common_ptr) cinfo, post->whole_image, + post->starting_row, post->strip_height, FALSE); + } + + /* Determine number of rows to emit. */ + num_rows = post->strip_height - post->next_row; /* available in strip */ + max_rows = out_rows_avail - *out_row_ctr; /* available in output area */ + if (num_rows > max_rows) + num_rows = max_rows; + /* We have to check bottom of image here, can't depend on upsampler. */ + max_rows = cinfo->output_height - post->starting_row; + if (num_rows > max_rows) + num_rows = max_rows; + + /* Quantize and emit data. */ + (*cinfo->cquantize->color_quantize) (cinfo, + post->buffer + post->next_row, output_buf + *out_row_ctr, + (int) num_rows); + *out_row_ctr += num_rows; + + /* Advance if we filled the strip. */ + post->next_row += num_rows; + if (post->next_row >= post->strip_height) { + post->starting_row += post->strip_height; + post->next_row = 0; + } +} + +#endif /* QUANT_2PASS_SUPPORTED */ + + +/* + * Initialize postprocessing controller. + */ + +GLOBAL(void) +jinit_d_post_controller (j_decompress_ptr cinfo, int need_full_buffer) +{ + my_post_ptr post; + + post = (my_post_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_post_controller)); + cinfo->post = (struct jpeg_d_post_controller *) post; + post->pub.start_pass = start_pass_dpost; + post->whole_image = NULL; /* flag for no virtual arrays */ + post->buffer = NULL; /* flag for no strip buffer */ + + /* Create the quantization buffer, if needed */ + if (cinfo->quantize_colors) { + /* The buffer strip height is max_v_samp_factor, which is typically + * an efficient number of rows for upsampling to return. + * (In the presence of output rescaling, we might want to be smarter?) + */ + post->strip_height = (JDIMENSION) cinfo->max_v_samp_factor; + if (need_full_buffer) { + /* Two-pass color quantization: need full-image storage. */ + /* We round up the number of rows to a multiple of the strip height. */ +#ifdef QUANT_2PASS_SUPPORTED + post->whole_image = (*cinfo->mem->request_virt_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, FALSE, + cinfo->output_width * cinfo->out_color_components, + (JDIMENSION) jround_up((long) cinfo->output_height, + (long) post->strip_height), + post->strip_height); +#else + ERREXIT(cinfo, JERR_BAD_BUFFER_MODE); +#endif /* QUANT_2PASS_SUPPORTED */ + } else { + /* One-pass color quantization: just make a strip buffer. */ + post->buffer = (*cinfo->mem->alloc_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + cinfo->output_width * cinfo->out_color_components, + post->strip_height); + } + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jdsample.cpp b/ml/dlib/dlib/external/libjpeg/jdsample.cpp new file mode 100644 index 000000000..647c8bf45 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jdsample.cpp @@ -0,0 +1,478 @@ +/* + * jdsample.c + * + * Copyright (C) 1991-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains upsampling routines. + * + * Upsampling input data is counted in "row groups". A row group + * is defined to be (v_samp_factor * DCT_scaled_size / min_DCT_scaled_size) + * sample rows of each component. Upsampling will normally produce + * max_v_samp_factor pixel rows from each row group (but this could vary + * if the upsampler is applying a scale factor of its own). + * + * An excellent reference for image resampling is + * Digital Image Warping, George Wolberg, 1990. + * Pub. by IEEE Computer Society Press, Los Alamitos, CA. ISBN 0-8186-8944-7. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* Pointer to routine to upsample a single component */ +typedef JMETHOD(void, upsample1_ptr, + (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr)); + +/* Private subobject */ + +typedef struct { + struct jpeg_upsampler pub; /* public fields */ + + /* Color conversion buffer. When using separate upsampling and color + * conversion steps, this buffer holds one upsampled row group until it + * has been color converted and output. + * Note: we do not allocate any storage for component(s) which are full-size, + * ie do not need rescaling. The corresponding entry of color_buf[] is + * simply set to point to the input data array, thereby avoiding copying. + */ + JSAMPARRAY color_buf[MAX_COMPONENTS]; + + /* Per-component upsampling method pointers */ + upsample1_ptr methods[MAX_COMPONENTS]; + + int next_row_out; /* counts rows emitted from color_buf */ + JDIMENSION rows_to_go; /* counts rows remaining in image */ + + /* Height of an input row group for each component. */ + int rowgroup_height[MAX_COMPONENTS]; + + /* These arrays save pixel expansion factors so that int_expand need not + * recompute them each time. They are unused for other upsampling methods. + */ + unsigned char h_expand[MAX_COMPONENTS]; + unsigned char v_expand[MAX_COMPONENTS]; +} my_upsampler; + +typedef my_upsampler * my_upsample_ptr; + + +/* + * Initialize for an upsampling pass. + */ + +METHODDEF(void) +start_pass_upsample (j_decompress_ptr cinfo) +{ + my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample; + + /* Mark the conversion buffer empty */ + upsample->next_row_out = cinfo->max_v_samp_factor; + /* Initialize total-height counter for detecting bottom of image */ + upsample->rows_to_go = cinfo->output_height; +} + + +/* + * Control routine to do upsampling (and color conversion). + * + * In this version we upsample each component independently. + * We upsample one row group into the conversion buffer, then apply + * color conversion a row at a time. + */ + +METHODDEF(void) +sep_upsample (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION *in_row_group_ctr, + JDIMENSION , + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail) +{ + my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample; + int ci; + jpeg_component_info * compptr; + JDIMENSION num_rows; + + /* Fill the conversion buffer, if it's empty */ + if (upsample->next_row_out >= cinfo->max_v_samp_factor) { + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Invoke per-component upsample method. Notice we pass a POINTER + * to color_buf[ci], so that fullsize_upsample can change it. + */ + (*upsample->methods[ci]) (cinfo, compptr, + input_buf[ci] + (*in_row_group_ctr * upsample->rowgroup_height[ci]), + upsample->color_buf + ci); + } + upsample->next_row_out = 0; + } + + /* Color-convert and emit rows */ + + /* How many we have in the buffer: */ + num_rows = (JDIMENSION) (cinfo->max_v_samp_factor - upsample->next_row_out); + /* Not more than the distance to the end of the image. Need this test + * in case the image height is not a multiple of max_v_samp_factor: + */ + if (num_rows > upsample->rows_to_go) + num_rows = upsample->rows_to_go; + /* And not more than what the client can accept: */ + out_rows_avail -= *out_row_ctr; + if (num_rows > out_rows_avail) + num_rows = out_rows_avail; + + (*cinfo->cconvert->color_convert) (cinfo, upsample->color_buf, + (JDIMENSION) upsample->next_row_out, + output_buf + *out_row_ctr, + (int) num_rows); + + /* Adjust counts */ + *out_row_ctr += num_rows; + upsample->rows_to_go -= num_rows; + upsample->next_row_out += num_rows; + /* When the buffer is emptied, declare this input row group consumed */ + if (upsample->next_row_out >= cinfo->max_v_samp_factor) + (*in_row_group_ctr)++; +} + + +/* + * These are the routines invoked by sep_upsample to upsample pixel values + * of a single component. One row group is processed per call. + */ + + +/* + * For full-size components, we just make color_buf[ci] point at the + * input buffer, and thus avoid copying any data. Note that this is + * safe only because sep_upsample doesn't declare the input row group + * "consumed" until we are done color converting and emitting it. + */ + +METHODDEF(void) +fullsize_upsample (j_decompress_ptr , jpeg_component_info * , + JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr) +{ + *output_data_ptr = input_data; +} + + +/* + * This is a no-op version used for "uninteresting" components. + * These components will not be referenced by color conversion. + */ + +METHODDEF(void) +noop_upsample (j_decompress_ptr , jpeg_component_info * , + JSAMPARRAY , JSAMPARRAY * output_data_ptr) +{ + *output_data_ptr = NULL; /* safety check */ +} + + +/* + * This version handles any integral sampling ratios. + * This is not used for typical JPEG files, so it need not be fast. + * Nor, for that matter, is it particularly accurate: the algorithm is + * simple replication of the input pixel onto the corresponding output + * pixels. The hi-falutin sampling literature refers to this as a + * "box filter". A box filter tends to introduce visible artifacts, + * so if you are actually going to use 3:1 or 4:1 sampling ratios + * you would be well advised to improve this code. + */ + +METHODDEF(void) +int_upsample (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr) +{ + my_upsample_ptr upsample = (my_upsample_ptr) cinfo->upsample; + JSAMPARRAY output_data = *output_data_ptr; + JSAMPROW inptr, outptr; + JSAMPLE invalue; + int h; + JSAMPROW outend; + int h_expand, v_expand; + int inrow, outrow; + + h_expand = upsample->h_expand[compptr->component_index]; + v_expand = upsample->v_expand[compptr->component_index]; + + inrow = outrow = 0; + while (outrow < cinfo->max_v_samp_factor) { + /* Generate one output row with proper horizontal expansion */ + inptr = input_data[inrow]; + outptr = output_data[outrow]; + outend = outptr + cinfo->output_width; + while (outptr < outend) { + invalue = *inptr++; /* don't need GETJSAMPLE() here */ + for (h = h_expand; h > 0; h--) { + *outptr++ = invalue; + } + } + /* Generate any additional output rows by duplicating the first one */ + if (v_expand > 1) { + jcopy_sample_rows(output_data, outrow, output_data, outrow+1, + v_expand-1, cinfo->output_width); + } + inrow++; + outrow += v_expand; + } +} + + +/* + * Fast processing for the common case of 2:1 horizontal and 1:1 vertical. + * It's still a box filter. + */ + +METHODDEF(void) +h2v1_upsample (j_decompress_ptr cinfo, jpeg_component_info * , + JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr) +{ + JSAMPARRAY output_data = *output_data_ptr; + JSAMPROW inptr, outptr; + JSAMPLE invalue; + JSAMPROW outend; + int inrow; + + for (inrow = 0; inrow < cinfo->max_v_samp_factor; inrow++) { + inptr = input_data[inrow]; + outptr = output_data[inrow]; + outend = outptr + cinfo->output_width; + while (outptr < outend) { + invalue = *inptr++; /* don't need GETJSAMPLE() here */ + *outptr++ = invalue; + *outptr++ = invalue; + } + } +} + + +/* + * Fast processing for the common case of 2:1 horizontal and 2:1 vertical. + * It's still a box filter. + */ + +METHODDEF(void) +h2v2_upsample (j_decompress_ptr cinfo, jpeg_component_info * , + JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr) +{ + JSAMPARRAY output_data = *output_data_ptr; + JSAMPROW inptr, outptr; + JSAMPLE invalue; + JSAMPROW outend; + int inrow, outrow; + + inrow = outrow = 0; + while (outrow < cinfo->max_v_samp_factor) { + inptr = input_data[inrow]; + outptr = output_data[outrow]; + outend = outptr + cinfo->output_width; + while (outptr < outend) { + invalue = *inptr++; /* don't need GETJSAMPLE() here */ + *outptr++ = invalue; + *outptr++ = invalue; + } + jcopy_sample_rows(output_data, outrow, output_data, outrow+1, + 1, cinfo->output_width); + inrow++; + outrow += 2; + } +} + + +/* + * Fancy processing for the common case of 2:1 horizontal and 1:1 vertical. + * + * The upsampling algorithm is linear interpolation between pixel centers, + * also known as a "triangle filter". This is a good compromise between + * speed and visual quality. The centers of the output pixels are 1/4 and 3/4 + * of the way between input pixel centers. + * + * A note about the "bias" calculations: when rounding fractional values to + * integer, we do not want to always round 0.5 up to the next integer. + * If we did that, we'd introduce a noticeable bias towards larger values. + * Instead, this code is arranged so that 0.5 will be rounded up or down at + * alternate pixel locations (a simple ordered dither pattern). + */ + +METHODDEF(void) +h2v1_fancy_upsample (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr) +{ + JSAMPARRAY output_data = *output_data_ptr; + JSAMPROW inptr, outptr; + int invalue; + JDIMENSION colctr; + int inrow; + + for (inrow = 0; inrow < cinfo->max_v_samp_factor; inrow++) { + inptr = input_data[inrow]; + outptr = output_data[inrow]; + /* Special case for first column */ + invalue = GETJSAMPLE(*inptr++); + *outptr++ = (JSAMPLE) invalue; + *outptr++ = (JSAMPLE) ((invalue * 3 + GETJSAMPLE(*inptr) + 2) >> 2); + + for (colctr = compptr->downsampled_width - 2; colctr > 0; colctr--) { + /* General case: 3/4 * nearer pixel + 1/4 * further pixel */ + invalue = GETJSAMPLE(*inptr++) * 3; + *outptr++ = (JSAMPLE) ((invalue + GETJSAMPLE(inptr[-2]) + 1) >> 2); + *outptr++ = (JSAMPLE) ((invalue + GETJSAMPLE(*inptr) + 2) >> 2); + } + + /* Special case for last column */ + invalue = GETJSAMPLE(*inptr); + *outptr++ = (JSAMPLE) ((invalue * 3 + GETJSAMPLE(inptr[-1]) + 1) >> 2); + *outptr++ = (JSAMPLE) invalue; + } +} + + +/* + * Fancy processing for the common case of 2:1 horizontal and 2:1 vertical. + * Again a triangle filter; see comments for h2v1 case, above. + * + * It is OK for us to reference the adjacent input rows because we demanded + * context from the main buffer controller (see initialization code). + */ + +METHODDEF(void) +h2v2_fancy_upsample (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JSAMPARRAY input_data, JSAMPARRAY * output_data_ptr) +{ + JSAMPARRAY output_data = *output_data_ptr; + JSAMPROW inptr0, inptr1, outptr; +#if BITS_IN_JSAMPLE == 8 + int thiscolsum, lastcolsum, nextcolsum; +#else + long thiscolsum, lastcolsum, nextcolsum; +#endif + JDIMENSION colctr; + int inrow, outrow, v; + + inrow = outrow = 0; + while (outrow < cinfo->max_v_samp_factor) { + for (v = 0; v < 2; v++) { + /* inptr0 points to nearest input row, inptr1 points to next nearest */ + inptr0 = input_data[inrow]; + if (v == 0) /* next nearest is row above */ + inptr1 = input_data[inrow-1]; + else /* next nearest is row below */ + inptr1 = input_data[inrow+1]; + outptr = output_data[outrow++]; + + /* Special case for first column */ + thiscolsum = GETJSAMPLE(*inptr0++) * 3 + GETJSAMPLE(*inptr1++); + nextcolsum = GETJSAMPLE(*inptr0++) * 3 + GETJSAMPLE(*inptr1++); + *outptr++ = (JSAMPLE) ((thiscolsum * 4 + 8) >> 4); + *outptr++ = (JSAMPLE) ((thiscolsum * 3 + nextcolsum + 7) >> 4); + lastcolsum = thiscolsum; thiscolsum = nextcolsum; + + for (colctr = compptr->downsampled_width - 2; colctr > 0; colctr--) { + /* General case: 3/4 * nearer pixel + 1/4 * further pixel in each */ + /* dimension, thus 9/16, 3/16, 3/16, 1/16 overall */ + nextcolsum = GETJSAMPLE(*inptr0++) * 3 + GETJSAMPLE(*inptr1++); + *outptr++ = (JSAMPLE) ((thiscolsum * 3 + lastcolsum + 8) >> 4); + *outptr++ = (JSAMPLE) ((thiscolsum * 3 + nextcolsum + 7) >> 4); + lastcolsum = thiscolsum; thiscolsum = nextcolsum; + } + + /* Special case for last column */ + *outptr++ = (JSAMPLE) ((thiscolsum * 3 + lastcolsum + 8) >> 4); + *outptr++ = (JSAMPLE) ((thiscolsum * 4 + 7) >> 4); + } + inrow++; + } +} + + +/* + * Module initialization routine for upsampling. + */ + +GLOBAL(void) +jinit_upsampler (j_decompress_ptr cinfo) +{ + my_upsample_ptr upsample; + int ci; + jpeg_component_info * compptr; + int need_buffer, do_fancy; + int h_in_group, v_in_group, h_out_group, v_out_group; + + upsample = (my_upsample_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_upsampler)); + cinfo->upsample = (struct jpeg_upsampler *) upsample; + upsample->pub.start_pass = start_pass_upsample; + upsample->pub.upsample = sep_upsample; + upsample->pub.need_context_rows = FALSE; /* until we find out differently */ + + if (cinfo->CCIR601_sampling) /* this isn't supported */ + ERREXIT(cinfo, JERR_CCIR601_NOTIMPL); + + /* jdmainct.c doesn't support context rows when min_DCT_scaled_size = 1, + * so don't ask for it. + */ + do_fancy = cinfo->do_fancy_upsampling && cinfo->min_DCT_scaled_size > 1; + + /* Verify we can handle the sampling factors, select per-component methods, + * and create storage as needed. + */ + for (ci = 0, compptr = cinfo->comp_info; ci < cinfo->num_components; + ci++, compptr++) { + /* Compute size of an "input group" after IDCT scaling. This many samples + * are to be converted to max_h_samp_factor * max_v_samp_factor pixels. + */ + h_in_group = (compptr->h_samp_factor * compptr->DCT_scaled_size) / + cinfo->min_DCT_scaled_size; + v_in_group = (compptr->v_samp_factor * compptr->DCT_scaled_size) / + cinfo->min_DCT_scaled_size; + h_out_group = cinfo->max_h_samp_factor; + v_out_group = cinfo->max_v_samp_factor; + upsample->rowgroup_height[ci] = v_in_group; /* save for use later */ + need_buffer = TRUE; + if (! compptr->component_needed) { + /* Don't bother to upsample an uninteresting component. */ + upsample->methods[ci] = noop_upsample; + need_buffer = FALSE; + } else if (h_in_group == h_out_group && v_in_group == v_out_group) { + /* Fullsize components can be processed without any work. */ + upsample->methods[ci] = fullsize_upsample; + need_buffer = FALSE; + } else if (h_in_group * 2 == h_out_group && + v_in_group == v_out_group) { + /* Special cases for 2h1v upsampling */ + if (do_fancy && compptr->downsampled_width > 2) + upsample->methods[ci] = h2v1_fancy_upsample; + else + upsample->methods[ci] = h2v1_upsample; + } else if (h_in_group * 2 == h_out_group && + v_in_group * 2 == v_out_group) { + /* Special cases for 2h2v upsampling */ + if (do_fancy && compptr->downsampled_width > 2) { + upsample->methods[ci] = h2v2_fancy_upsample; + upsample->pub.need_context_rows = TRUE; + } else + upsample->methods[ci] = h2v2_upsample; + } else if ((h_out_group % h_in_group) == 0 && + (v_out_group % v_in_group) == 0) { + /* Generic integral-factors upsampling method */ + upsample->methods[ci] = int_upsample; + upsample->h_expand[ci] = (unsigned char) (h_out_group / h_in_group); + upsample->v_expand[ci] = (unsigned char) (v_out_group / v_in_group); + } else + ERREXIT(cinfo, JERR_FRACT_SAMPLE_NOTIMPL); + if (need_buffer) { + upsample->color_buf[ci] = (*cinfo->mem->alloc_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + (JDIMENSION) jround_up((long) cinfo->output_width, + (long) cinfo->max_h_samp_factor), + (JDIMENSION) cinfo->max_v_samp_factor); + } + } +} diff --git a/ml/dlib/dlib/external/libjpeg/jerror.cpp b/ml/dlib/dlib/external/libjpeg/jerror.cpp new file mode 100644 index 000000000..117bc4829 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jerror.cpp @@ -0,0 +1,252 @@ +/* + * jerror.c + * + * Copyright (C) 1991-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains simple error-reporting and trace-message routines. + * These are suitable for Unix-like systems and others where writing to + * stderr is the right thing to do. Many applications will want to replace + * some or all of these routines. + * + * If you define USE_WINDOWS_MESSAGEBOX in jconfig.h or in the makefile, + * you get a Windows-specific hack to display error messages in a dialog box. + * It ain't much, but it beats dropping error messages into the bit bucket, + * which is what happens to output to stderr under most Windows C compilers. + * + * These routines are used by both the compression and decompression code. + */ + +/* this is not a core library module, so it doesn't define JPEG_INTERNALS */ +#include "jinclude.h" +#include "jpeglib.h" +#include "jversion.h" +#include "jerror.h" + +#ifdef USE_WINDOWS_MESSAGEBOX +#include +#endif + +#ifndef EXIT_FAILURE /* define exit() codes if not provided */ +#define EXIT_FAILURE 1 +#endif + + +/* + * Create the message string table. + * We do this from the master message list in jerror.h by re-reading + * jerror.h with a suitable definition for macro JMESSAGE. + * The message table is made an external symbol just in case any applications + * want to refer to it directly. + */ + +#ifdef NEED_SHORT_EXTERNAL_NAMES +#define jpeg_std_message_table jMsgTable +#endif + +#define JMESSAGE(code,string) string , + +const char * const jpeg_std_message_table[] = { +#include "jerror.h" + NULL +}; + + +/* + * Error exit handler: must not return to caller. + * + * Applications may override this if they want to get control back after + * an error. Typically one would longjmp somewhere instead of exiting. + * The setjmp buffer can be made a private field within an expanded error + * handler object. Note that the info needed to generate an error message + * is stored in the error object, so you can generate the message now or + * later, at your convenience. + * You should make sure that the JPEG object is cleaned up (with jpeg_abort + * or jpeg_destroy) at some point. + */ + +METHODDEF(void) +error_exit (j_common_ptr cinfo) +{ + /* Always display the message */ + (*cinfo->err->output_message) (cinfo); + + /* Let the memory manager delete any temp files before we die */ + jpeg_destroy(cinfo); + + exit(EXIT_FAILURE); +} + + +/* + * Actual output of an error or trace message. + * Applications may override this method to send JPEG messages somewhere + * other than stderr. + * + * On Windows, printing to stderr is generally completely useless, + * so we provide optional code to produce an error-dialog popup. + * Most Windows applications will still prefer to override this routine, + * but if they don't, it'll do something at least marginally useful. + * + * NOTE: to use the library in an environment that doesn't support the + * C stdio library, you may have to delete the call to fprintf() entirely, + * not just not use this routine. + */ + +METHODDEF(void) +output_message (j_common_ptr cinfo) +{ + char buffer[JMSG_LENGTH_MAX]; + + /* Create the message */ + (*cinfo->err->format_message) (cinfo, buffer); + +#ifdef USE_WINDOWS_MESSAGEBOX + /* Display it in a message dialog box */ + MessageBox(GetActiveWindow(), buffer, "JPEG Library Error", + MB_OK | MB_ICONERROR); +#else + /* Send it to stderr, adding a newline */ + fprintf(stderr, "%s\n", buffer); +#endif +} + + +/* + * Decide whether to emit a trace or warning message. + * msg_level is one of: + * -1: recoverable corrupt-data warning, may want to abort. + * 0: important advisory messages (always display to user). + * 1: first level of tracing detail. + * 2,3,...: successively more detailed tracing messages. + * An application might override this method if it wanted to abort on warnings + * or change the policy about which messages to display. + */ + +METHODDEF(void) +emit_message (j_common_ptr cinfo, int msg_level) +{ + struct jpeg_error_mgr * err = cinfo->err; + + if (msg_level < 0) { + /* It's a warning message. Since corrupt files may generate many warnings, + * the policy implemented here is to show only the first warning, + * unless trace_level >= 3. + */ + if (err->num_warnings == 0 || err->trace_level >= 3) + (*err->output_message) (cinfo); + /* Always count warnings in num_warnings. */ + err->num_warnings++; + } else { + /* It's a trace message. Show it if trace_level >= msg_level. */ + if (err->trace_level >= msg_level) + (*err->output_message) (cinfo); + } +} + + +/* + * Format a message string for the most recent JPEG error or message. + * The message is stored into buffer, which should be at least JMSG_LENGTH_MAX + * characters. Note that no '\n' character is added to the string. + * Few applications should need to override this method. + */ + +METHODDEF(void) +format_message (j_common_ptr cinfo, char * buffer) +{ + struct jpeg_error_mgr * err = cinfo->err; + int msg_code = err->msg_code; + const char * msgtext = NULL; + const char * msgptr; + char ch; + int isstring; + + /* Look up message string in proper table */ + if (msg_code > 0 && msg_code <= err->last_jpeg_message) { + msgtext = err->jpeg_message_table[msg_code]; + } else if (err->addon_message_table != NULL && + msg_code >= err->first_addon_message && + msg_code <= err->last_addon_message) { + msgtext = err->addon_message_table[msg_code - err->first_addon_message]; + } + + /* Defend against bogus message number */ + if (msgtext == NULL) { + err->msg_parm.i[0] = msg_code; + msgtext = err->jpeg_message_table[0]; + } + + /* Check for string parameter, as indicated by %s in the message text */ + isstring = FALSE; + msgptr = msgtext; + while ((ch = *msgptr++) != '\0') { + if (ch == '%') { + if (*msgptr == 's') isstring = TRUE; + break; + } + } + + /* Format the message into the passed buffer */ + if (isstring) + sprintf(buffer, msgtext, err->msg_parm.s); + else + sprintf(buffer, msgtext, + err->msg_parm.i[0], err->msg_parm.i[1], + err->msg_parm.i[2], err->msg_parm.i[3], + err->msg_parm.i[4], err->msg_parm.i[5], + err->msg_parm.i[6], err->msg_parm.i[7]); +} + + +/* + * Reset error state variables at start of a new image. + * This is called during compression startup to reset trace/error + * processing to default state, without losing any application-specific + * method pointers. An application might possibly want to override + * this method if it has additional error processing state. + */ + +METHODDEF(void) +reset_error_mgr (j_common_ptr cinfo) +{ + cinfo->err->num_warnings = 0; + /* trace_level is not reset since it is an application-supplied parameter */ + cinfo->err->msg_code = 0; /* may be useful as a flag for "no error" */ +} + + +/* + * Fill in the standard error-handling methods in a jpeg_error_mgr object. + * Typical call is: + * struct jpeg_compress_struct cinfo; + * struct jpeg_error_mgr err; + * + * cinfo.err = jpeg_std_error(&err); + * after which the application may override some of the methods. + */ + +GLOBAL(struct jpeg_error_mgr *) +jpeg_std_error (struct jpeg_error_mgr * err) +{ + err->error_exit = error_exit; + err->emit_message = emit_message; + err->output_message = output_message; + err->format_message = format_message; + err->reset_error_mgr = reset_error_mgr; + + err->trace_level = 0; /* default = no tracing */ + err->num_warnings = 0; /* no warnings emitted yet */ + err->msg_code = 0; /* may be useful as a flag for "no error" */ + + /* Initialize message table pointers */ + err->jpeg_message_table = jpeg_std_message_table; + err->last_jpeg_message = (int) JMSG_LASTMSGCODE - 1; + + err->addon_message_table = NULL; + err->first_addon_message = 0; /* for safety */ + err->last_addon_message = 0; + + return err; +} diff --git a/ml/dlib/dlib/external/libjpeg/jerror.h b/ml/dlib/dlib/external/libjpeg/jerror.h new file mode 100644 index 000000000..fc2fffeac --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jerror.h @@ -0,0 +1,291 @@ +/* + * jerror.h + * + * Copyright (C) 1994-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file defines the error and message codes for the JPEG library. + * Edit this file to add new codes, or to translate the message strings to + * some other language. + * A set of error-reporting macros are defined too. Some applications using + * the JPEG library may wish to include this file to get the error codes + * and/or the macros. + */ + +/* + * To define the enum list of message codes, include this file without + * defining macro JMESSAGE. To create a message string table, include it + * again with a suitable JMESSAGE definition (see jerror.c for an example). + */ +#ifndef JMESSAGE +#ifndef JERROR_H +/* First time through, define the enum list */ +#define JMAKE_ENUM_LIST +#else +/* Repeated inclusions of this file are no-ops unless JMESSAGE is defined */ +#define JMESSAGE(code,string) +#endif /* JERROR_H */ +#endif /* JMESSAGE */ + +#ifdef JMAKE_ENUM_LIST + +typedef enum { + +#define JMESSAGE(code,string) code , + +#endif /* JMAKE_ENUM_LIST */ + +JMESSAGE(JMSG_NOMESSAGE, "Bogus message code %d") /* Must be first entry! */ + +/* For maintenance convenience, list is alphabetical by message code name */ +JMESSAGE(JERR_ARITH_NOTIMPL, + "Sorry, there are legal restrictions on arithmetic coding") +JMESSAGE(JERR_BAD_ALIGN_TYPE, "ALIGN_TYPE is wrong, please fix") +JMESSAGE(JERR_BAD_ALLOC_CHUNK, "MAX_ALLOC_CHUNK is wrong, please fix") +JMESSAGE(JERR_BAD_BUFFER_MODE, "Bogus buffer control mode") +JMESSAGE(JERR_BAD_COMPONENT_ID, "Invalid component ID %d in SOS") +JMESSAGE(JERR_BAD_DCT_COEF, "DCT coefficient out of range") +JMESSAGE(JERR_BAD_DCTSIZE, "IDCT output block size %d not supported") +JMESSAGE(JERR_BAD_HUFF_TABLE, "Bogus Huffman table definition") +JMESSAGE(JERR_BAD_IN_COLORSPACE, "Bogus input colorspace") +JMESSAGE(JERR_BAD_J_COLORSPACE, "Bogus JPEG colorspace") +JMESSAGE(JERR_BAD_LENGTH, "Bogus marker length") +JMESSAGE(JERR_BAD_LIB_VERSION, + "Wrong JPEG library version: library is %d, caller expects %d") +JMESSAGE(JERR_BAD_MCU_SIZE, "Sampling factors too large for interleaved scan") +JMESSAGE(JERR_BAD_POOL_ID, "Invalid memory pool code %d") +JMESSAGE(JERR_BAD_PRECISION, "Unsupported JPEG data precision %d") +JMESSAGE(JERR_BAD_PROGRESSION, + "Invalid progressive parameters Ss=%d Se=%d Ah=%d Al=%d") +JMESSAGE(JERR_BAD_PROG_SCRIPT, + "Invalid progressive parameters at scan script entry %d") +JMESSAGE(JERR_BAD_SAMPLING, "Bogus sampling factors") +JMESSAGE(JERR_BAD_SCAN_SCRIPT, "Invalid scan script at entry %d") +JMESSAGE(JERR_BAD_STATE, "Improper call to JPEG library in state %d") +JMESSAGE(JERR_BAD_STRUCT_SIZE, + "JPEG parameter struct mismatch: library thinks size is %u, caller expects %u") +JMESSAGE(JERR_BAD_VIRTUAL_ACCESS, "Bogus virtual array access") +JMESSAGE(JERR_BUFFER_SIZE, "Buffer passed to JPEG library is too small") +JMESSAGE(JERR_CANT_SUSPEND, "Suspension not allowed here") +JMESSAGE(JERR_CCIR601_NOTIMPL, "CCIR601 sampling not implemented yet") +JMESSAGE(JERR_COMPONENT_COUNT, "Too many color components: %d, max %d") +JMESSAGE(JERR_CONVERSION_NOTIMPL, "Unsupported color conversion request") +JMESSAGE(JERR_DAC_INDEX, "Bogus DAC index %d") +JMESSAGE(JERR_DAC_VALUE, "Bogus DAC value 0x%x") +JMESSAGE(JERR_DHT_INDEX, "Bogus DHT index %d") +JMESSAGE(JERR_DQT_INDEX, "Bogus DQT index %d") +JMESSAGE(JERR_EMPTY_IMAGE, "Empty JPEG image (DNL not supported)") +JMESSAGE(JERR_EMS_READ, "Read from EMS failed") +JMESSAGE(JERR_EMS_WRITE, "Write to EMS failed") +JMESSAGE(JERR_EOI_EXPECTED, "Didn't expect more than one scan") +JMESSAGE(JERR_FILE_READ, "Input file read error") +JMESSAGE(JERR_FILE_WRITE, "Output file write error --- out of disk space?") +JMESSAGE(JERR_FRACT_SAMPLE_NOTIMPL, "Fractional sampling not implemented yet") +JMESSAGE(JERR_HUFF_CLEN_OVERFLOW, "Huffman code size table overflow") +JMESSAGE(JERR_HUFF_MISSING_CODE, "Missing Huffman code table entry") +JMESSAGE(JERR_IMAGE_TOO_BIG, "Maximum supported image dimension is %u pixels") +JMESSAGE(JERR_INPUT_EMPTY, "Empty input file") +JMESSAGE(JERR_INPUT_EOF, "Premature end of input file") +JMESSAGE(JERR_MISMATCHED_QUANT_TABLE, + "Cannot transcode due to multiple use of quantization table %d") +JMESSAGE(JERR_MISSING_DATA, "Scan script does not transmit all data") +JMESSAGE(JERR_MODE_CHANGE, "Invalid color quantization mode change") +JMESSAGE(JERR_NOTIMPL, "Not implemented yet") +JMESSAGE(JERR_NOT_COMPILED, "Requested feature was omitted at compile time") +JMESSAGE(JERR_NO_BACKING_STORE, "Backing store not supported") +JMESSAGE(JERR_NO_HUFF_TABLE, "Huffman table 0x%02x was not defined") +JMESSAGE(JERR_NO_IMAGE, "JPEG datastream contains no image") +JMESSAGE(JERR_NO_QUANT_TABLE, "Quantization table 0x%02x was not defined") +JMESSAGE(JERR_NO_SOI, "Not a JPEG file: starts with 0x%02x 0x%02x") +JMESSAGE(JERR_OUT_OF_MEMORY, "Insufficient memory (case %d)") +JMESSAGE(JERR_QUANT_COMPONENTS, + "Cannot quantize more than %d color components") +JMESSAGE(JERR_QUANT_FEW_COLORS, "Cannot quantize to fewer than %d colors") +JMESSAGE(JERR_QUANT_MANY_COLORS, "Cannot quantize to more than %d colors") +JMESSAGE(JERR_SOF_DUPLICATE, "Invalid JPEG file structure: two SOF markers") +JMESSAGE(JERR_SOF_NO_SOS, "Invalid JPEG file structure: missing SOS marker") +JMESSAGE(JERR_SOF_UNSUPPORTED, "Unsupported JPEG process: SOF type 0x%02x") +JMESSAGE(JERR_SOI_DUPLICATE, "Invalid JPEG file structure: two SOI markers") +JMESSAGE(JERR_SOS_NO_SOF, "Invalid JPEG file structure: SOS before SOF") +JMESSAGE(JERR_TFILE_CREATE, "Failed to create temporary file %s") +JMESSAGE(JERR_TFILE_READ, "Read failed on temporary file") +JMESSAGE(JERR_TFILE_SEEK, "Seek failed on temporary file") +JMESSAGE(JERR_TFILE_WRITE, + "Write failed on temporary file --- out of disk space?") +JMESSAGE(JERR_TOO_LITTLE_DATA, "Application transferred too few scanlines") +JMESSAGE(JERR_UNKNOWN_MARKER, "Unsupported marker type 0x%02x") +JMESSAGE(JERR_VIRTUAL_BUG, "Virtual array controller messed up") +JMESSAGE(JERR_WIDTH_OVERFLOW, "Image too wide for this implementation") +JMESSAGE(JERR_XMS_READ, "Read from XMS failed") +JMESSAGE(JERR_XMS_WRITE, "Write to XMS failed") +JMESSAGE(JMSG_COPYRIGHT, JCOPYRIGHT) +JMESSAGE(JMSG_VERSION, JVERSION) +JMESSAGE(JTRC_16BIT_TABLES, + "Caution: quantization tables are too coarse for baseline JPEG") +JMESSAGE(JTRC_ADOBE, + "Adobe APP14 marker: version %d, flags 0x%04x 0x%04x, transform %d") +JMESSAGE(JTRC_APP0, "Unknown APP0 marker (not JFIF), length %u") +JMESSAGE(JTRC_APP14, "Unknown APP14 marker (not Adobe), length %u") +JMESSAGE(JTRC_DAC, "Define Arithmetic Table 0x%02x: 0x%02x") +JMESSAGE(JTRC_DHT, "Define Huffman Table 0x%02x") +JMESSAGE(JTRC_DQT, "Define Quantization Table %d precision %d") +JMESSAGE(JTRC_DRI, "Define Restart Interval %u") +JMESSAGE(JTRC_EMS_CLOSE, "Freed EMS handle %u") +JMESSAGE(JTRC_EMS_OPEN, "Obtained EMS handle %u") +JMESSAGE(JTRC_EOI, "End Of Image") +JMESSAGE(JTRC_HUFFBITS, " %3d %3d %3d %3d %3d %3d %3d %3d") +JMESSAGE(JTRC_JFIF, "JFIF APP0 marker: version %d.%02d, density %dx%d %d") +JMESSAGE(JTRC_JFIF_BADTHUMBNAILSIZE, + "Warning: thumbnail image size does not match data length %u") +JMESSAGE(JTRC_JFIF_EXTENSION, + "JFIF extension marker: type 0x%02x, length %u") +JMESSAGE(JTRC_JFIF_THUMBNAIL, " with %d x %d thumbnail image") +JMESSAGE(JTRC_MISC_MARKER, "Miscellaneous marker 0x%02x, length %u") +JMESSAGE(JTRC_PARMLESS_MARKER, "Unexpected marker 0x%02x") +JMESSAGE(JTRC_QUANTVALS, " %4u %4u %4u %4u %4u %4u %4u %4u") +JMESSAGE(JTRC_QUANT_3_NCOLORS, "Quantizing to %d = %d*%d*%d colors") +JMESSAGE(JTRC_QUANT_NCOLORS, "Quantizing to %d colors") +JMESSAGE(JTRC_QUANT_SELECTED, "Selected %d colors for quantization") +JMESSAGE(JTRC_RECOVERY_ACTION, "At marker 0x%02x, recovery action %d") +JMESSAGE(JTRC_RST, "RST%d") +JMESSAGE(JTRC_SMOOTH_NOTIMPL, + "Smoothing not supported with nonstandard sampling ratios") +JMESSAGE(JTRC_SOF, "Start Of Frame 0x%02x: width=%u, height=%u, components=%d") +JMESSAGE(JTRC_SOF_COMPONENT, " Component %d: %dhx%dv q=%d") +JMESSAGE(JTRC_SOI, "Start of Image") +JMESSAGE(JTRC_SOS, "Start Of Scan: %d components") +JMESSAGE(JTRC_SOS_COMPONENT, " Component %d: dc=%d ac=%d") +JMESSAGE(JTRC_SOS_PARAMS, " Ss=%d, Se=%d, Ah=%d, Al=%d") +JMESSAGE(JTRC_TFILE_CLOSE, "Closed temporary file %s") +JMESSAGE(JTRC_TFILE_OPEN, "Opened temporary file %s") +JMESSAGE(JTRC_THUMB_JPEG, + "JFIF extension marker: JPEG-compressed thumbnail image, length %u") +JMESSAGE(JTRC_THUMB_PALETTE, + "JFIF extension marker: palette thumbnail image, length %u") +JMESSAGE(JTRC_THUMB_RGB, + "JFIF extension marker: RGB thumbnail image, length %u") +JMESSAGE(JTRC_UNKNOWN_IDS, + "Unrecognized component IDs %d %d %d, assuming YCbCr") +JMESSAGE(JTRC_XMS_CLOSE, "Freed XMS handle %u") +JMESSAGE(JTRC_XMS_OPEN, "Obtained XMS handle %u") +JMESSAGE(JWRN_ADOBE_XFORM, "Unknown Adobe color transform code %d") +JMESSAGE(JWRN_BOGUS_PROGRESSION, + "Inconsistent progression sequence for component %d coefficient %d") +JMESSAGE(JWRN_EXTRANEOUS_DATA, + "Corrupt JPEG data: %u extraneous bytes before marker 0x%02x") +JMESSAGE(JWRN_HIT_MARKER, "Corrupt JPEG data: premature end of data segment") +JMESSAGE(JWRN_HUFF_BAD_CODE, "Corrupt JPEG data: bad Huffman code") +JMESSAGE(JWRN_JFIF_MAJOR, "Warning: unknown JFIF revision number %d.%02d") +JMESSAGE(JWRN_JPEG_EOF, "Premature end of JPEG file") +JMESSAGE(JWRN_MUST_RESYNC, + "Corrupt JPEG data: found marker 0x%02x instead of RST%d") +JMESSAGE(JWRN_NOT_SEQUENTIAL, "Invalid SOS parameters for sequential JPEG") +JMESSAGE(JWRN_TOO_MUCH_DATA, "Application transferred too many scanlines") + +#ifdef JMAKE_ENUM_LIST + + JMSG_LASTMSGCODE +} J_MESSAGE_CODE; + +#undef JMAKE_ENUM_LIST +#endif /* JMAKE_ENUM_LIST */ + +/* Zap JMESSAGE macro so that future re-inclusions do nothing by default */ +#undef JMESSAGE + + +#ifndef JERROR_H +#define JERROR_H + +/* Macros to simplify using the error and trace message stuff */ +/* The first parameter is either type of cinfo pointer */ + +/* Fatal errors (print message and exit) */ +#define ERREXIT(cinfo,code) \ + ((cinfo)->err->msg_code = (code), \ + (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo))) +#define ERREXIT1(cinfo,code,p1) \ + ((cinfo)->err->msg_code = (code), \ + (cinfo)->err->msg_parm.i[0] = (p1), \ + (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo))) +#define ERREXIT2(cinfo,code,p1,p2) \ + ((cinfo)->err->msg_code = (code), \ + (cinfo)->err->msg_parm.i[0] = (p1), \ + (cinfo)->err->msg_parm.i[1] = (p2), \ + (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo))) +#define ERREXIT3(cinfo,code,p1,p2,p3) \ + ((cinfo)->err->msg_code = (code), \ + (cinfo)->err->msg_parm.i[0] = (p1), \ + (cinfo)->err->msg_parm.i[1] = (p2), \ + (cinfo)->err->msg_parm.i[2] = (p3), \ + (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo))) +#define ERREXIT4(cinfo,code,p1,p2,p3,p4) \ + ((cinfo)->err->msg_code = (code), \ + (cinfo)->err->msg_parm.i[0] = (p1), \ + (cinfo)->err->msg_parm.i[1] = (p2), \ + (cinfo)->err->msg_parm.i[2] = (p3), \ + (cinfo)->err->msg_parm.i[3] = (p4), \ + (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo))) +#define ERREXITS(cinfo,code,str) \ + ((cinfo)->err->msg_code = (code), \ + strncpy((cinfo)->err->msg_parm.s, (str), JMSG_STR_PARM_MAX), \ + (*(cinfo)->err->error_exit) ((j_common_ptr) (cinfo))) + +#define MAKESTMT(stuff) do { stuff } while (0) + +/* Nonfatal errors (we can keep going, but the data is probably corrupt) */ +#define WARNMS(cinfo,code) \ + ((cinfo)->err->msg_code = (code), \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), -1)) +#define WARNMS1(cinfo,code,p1) \ + ((cinfo)->err->msg_code = (code), \ + (cinfo)->err->msg_parm.i[0] = (p1), \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), -1)) +#define WARNMS2(cinfo,code,p1,p2) \ + ((cinfo)->err->msg_code = (code), \ + (cinfo)->err->msg_parm.i[0] = (p1), \ + (cinfo)->err->msg_parm.i[1] = (p2), \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), -1)) + +/* Informational/debugging messages */ +#define TRACEMS(cinfo,lvl,code) \ + ((cinfo)->err->msg_code = (code), \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl))) +#define TRACEMS1(cinfo,lvl,code,p1) \ + ((cinfo)->err->msg_code = (code), \ + (cinfo)->err->msg_parm.i[0] = (p1), \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl))) +#define TRACEMS2(cinfo,lvl,code,p1,p2) \ + ((cinfo)->err->msg_code = (code), \ + (cinfo)->err->msg_parm.i[0] = (p1), \ + (cinfo)->err->msg_parm.i[1] = (p2), \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl))) +#define TRACEMS3(cinfo,lvl,code,p1,p2,p3) \ + MAKESTMT(int * _mp = (cinfo)->err->msg_parm.i; \ + _mp[0] = (p1); _mp[1] = (p2); _mp[2] = (p3); \ + (cinfo)->err->msg_code = (code); \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)); ) +#define TRACEMS4(cinfo,lvl,code,p1,p2,p3,p4) \ + MAKESTMT(int * _mp = (cinfo)->err->msg_parm.i; \ + _mp[0] = (p1); _mp[1] = (p2); _mp[2] = (p3); _mp[3] = (p4); \ + (cinfo)->err->msg_code = (code); \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)); ) +#define TRACEMS5(cinfo,lvl,code,p1,p2,p3,p4,p5) \ + MAKESTMT(int * _mp = (cinfo)->err->msg_parm.i; \ + _mp[0] = (p1); _mp[1] = (p2); _mp[2] = (p3); _mp[3] = (p4); \ + _mp[4] = (p5); \ + (cinfo)->err->msg_code = (code); \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)); ) +#define TRACEMS8(cinfo,lvl,code,p1,p2,p3,p4,p5,p6,p7,p8) \ + MAKESTMT(int * _mp = (cinfo)->err->msg_parm.i; \ + _mp[0] = (p1); _mp[1] = (p2); _mp[2] = (p3); _mp[3] = (p4); \ + _mp[4] = (p5); _mp[5] = (p6); _mp[6] = (p7); _mp[7] = (p8); \ + (cinfo)->err->msg_code = (code); \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl)); ) +#define TRACEMSS(cinfo,lvl,code,str) \ + ((cinfo)->err->msg_code = (code), \ + strncpy((cinfo)->err->msg_parm.s, (str), JMSG_STR_PARM_MAX), \ + (*(cinfo)->err->emit_message) ((j_common_ptr) (cinfo), (lvl))) + +#endif /* JERROR_H */ diff --git a/ml/dlib/dlib/external/libjpeg/jfdctflt.cpp b/ml/dlib/dlib/external/libjpeg/jfdctflt.cpp new file mode 100644 index 000000000..79d7a0078 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jfdctflt.cpp @@ -0,0 +1,168 @@ +/* + * jfdctflt.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains a floating-point implementation of the + * forward DCT (Discrete Cosine Transform). + * + * This implementation should be more accurate than either of the integer + * DCT implementations. However, it may not give the same results on all + * machines because of differences in roundoff behavior. Speed will depend + * on the hardware's floating point capacity. + * + * A 2-D DCT can be done by 1-D DCT on each row followed by 1-D DCT + * on each column. Direct algorithms are also available, but they are + * much more complex and seem not to be any faster when reduced to code. + * + * This implementation is based on Arai, Agui, and Nakajima's algorithm for + * scaled DCT. Their original paper (Trans. IEICE E-71(11):1095) is in + * Japanese, but the algorithm is described in the Pennebaker & Mitchell + * JPEG textbook (see REFERENCES section in file README). The following code + * is based directly on figure 4-8 in P&M. + * While an 8-point DCT cannot be done in less than 11 multiplies, it is + * possible to arrange the computation so that many of the multiplies are + * simple scalings of the final outputs. These multiplies can then be + * folded into the multiplications or divisions by the JPEG quantization + * table entries. The AA&N method leaves only 5 multiplies and 29 adds + * to be done in the DCT itself. + * The primary disadvantage of this method is that with a fixed-point + * implementation, accuracy is lost due to imprecise representation of the + * scaled quantization values. However, that problem does not arise if + * we use floating point arithmetic. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdct.h" /* Private declarations for DCT subsystem */ + +#ifdef DCT_FLOAT_SUPPORTED + + +/* + * This module is specialized to the case DCTSIZE = 8. + */ + +#if DCTSIZE != 8 + Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */ +#endif + + +/* + * Perform the forward DCT on one block of samples. + */ + +GLOBAL(void) +jpeg_fdct_float (FAST_FLOAT * data) +{ + FAST_FLOAT tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7; + FAST_FLOAT tmp10, tmp11, tmp12, tmp13; + FAST_FLOAT z1, z2, z3, z4, z5, z11, z13; + FAST_FLOAT *dataptr; + int ctr; + + /* Pass 1: process rows. */ + + dataptr = data; + for (ctr = DCTSIZE-1; ctr >= 0; ctr--) { + tmp0 = dataptr[0] + dataptr[7]; + tmp7 = dataptr[0] - dataptr[7]; + tmp1 = dataptr[1] + dataptr[6]; + tmp6 = dataptr[1] - dataptr[6]; + tmp2 = dataptr[2] + dataptr[5]; + tmp5 = dataptr[2] - dataptr[5]; + tmp3 = dataptr[3] + dataptr[4]; + tmp4 = dataptr[3] - dataptr[4]; + + /* Even part */ + + tmp10 = tmp0 + tmp3; /* phase 2 */ + tmp13 = tmp0 - tmp3; + tmp11 = tmp1 + tmp2; + tmp12 = tmp1 - tmp2; + + dataptr[0] = tmp10 + tmp11; /* phase 3 */ + dataptr[4] = tmp10 - tmp11; + + z1 = (tmp12 + tmp13) * ((FAST_FLOAT) 0.707106781); /* c4 */ + dataptr[2] = tmp13 + z1; /* phase 5 */ + dataptr[6] = tmp13 - z1; + + /* Odd part */ + + tmp10 = tmp4 + tmp5; /* phase 2 */ + tmp11 = tmp5 + tmp6; + tmp12 = tmp6 + tmp7; + + /* The rotator is modified from fig 4-8 to avoid extra negations. */ + z5 = (tmp10 - tmp12) * ((FAST_FLOAT) 0.382683433); /* c6 */ + z2 = ((FAST_FLOAT) 0.541196100) * tmp10 + z5; /* c2-c6 */ + z4 = ((FAST_FLOAT) 1.306562965) * tmp12 + z5; /* c2+c6 */ + z3 = tmp11 * ((FAST_FLOAT) 0.707106781); /* c4 */ + + z11 = tmp7 + z3; /* phase 5 */ + z13 = tmp7 - z3; + + dataptr[5] = z13 + z2; /* phase 6 */ + dataptr[3] = z13 - z2; + dataptr[1] = z11 + z4; + dataptr[7] = z11 - z4; + + dataptr += DCTSIZE; /* advance pointer to next row */ + } + + /* Pass 2: process columns. */ + + dataptr = data; + for (ctr = DCTSIZE-1; ctr >= 0; ctr--) { + tmp0 = dataptr[DCTSIZE*0] + dataptr[DCTSIZE*7]; + tmp7 = dataptr[DCTSIZE*0] - dataptr[DCTSIZE*7]; + tmp1 = dataptr[DCTSIZE*1] + dataptr[DCTSIZE*6]; + tmp6 = dataptr[DCTSIZE*1] - dataptr[DCTSIZE*6]; + tmp2 = dataptr[DCTSIZE*2] + dataptr[DCTSIZE*5]; + tmp5 = dataptr[DCTSIZE*2] - dataptr[DCTSIZE*5]; + tmp3 = dataptr[DCTSIZE*3] + dataptr[DCTSIZE*4]; + tmp4 = dataptr[DCTSIZE*3] - dataptr[DCTSIZE*4]; + + /* Even part */ + + tmp10 = tmp0 + tmp3; /* phase 2 */ + tmp13 = tmp0 - tmp3; + tmp11 = tmp1 + tmp2; + tmp12 = tmp1 - tmp2; + + dataptr[DCTSIZE*0] = tmp10 + tmp11; /* phase 3 */ + dataptr[DCTSIZE*4] = tmp10 - tmp11; + + z1 = (tmp12 + tmp13) * ((FAST_FLOAT) 0.707106781); /* c4 */ + dataptr[DCTSIZE*2] = tmp13 + z1; /* phase 5 */ + dataptr[DCTSIZE*6] = tmp13 - z1; + + /* Odd part */ + + tmp10 = tmp4 + tmp5; /* phase 2 */ + tmp11 = tmp5 + tmp6; + tmp12 = tmp6 + tmp7; + + /* The rotator is modified from fig 4-8 to avoid extra negations. */ + z5 = (tmp10 - tmp12) * ((FAST_FLOAT) 0.382683433); /* c6 */ + z2 = ((FAST_FLOAT) 0.541196100) * tmp10 + z5; /* c2-c6 */ + z4 = ((FAST_FLOAT) 1.306562965) * tmp12 + z5; /* c2+c6 */ + z3 = tmp11 * ((FAST_FLOAT) 0.707106781); /* c4 */ + + z11 = tmp7 + z3; /* phase 5 */ + z13 = tmp7 - z3; + + dataptr[DCTSIZE*5] = z13 + z2; /* phase 6 */ + dataptr[DCTSIZE*3] = z13 - z2; + dataptr[DCTSIZE*1] = z11 + z4; + dataptr[DCTSIZE*7] = z11 - z4; + + dataptr++; /* advance pointer to next column */ + } +} + +#endif /* DCT_FLOAT_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jfdctfst.cpp b/ml/dlib/dlib/external/libjpeg/jfdctfst.cpp new file mode 100644 index 000000000..4c96d4f3d --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jfdctfst.cpp @@ -0,0 +1,224 @@ +/* + * jfdctfst.c + * + * Copyright (C) 1994-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains a fast, not so accurate integer implementation of the + * forward DCT (Discrete Cosine Transform). + * + * A 2-D DCT can be done by 1-D DCT on each row followed by 1-D DCT + * on each column. Direct algorithms are also available, but they are + * much more complex and seem not to be any faster when reduced to code. + * + * This implementation is based on Arai, Agui, and Nakajima's algorithm for + * scaled DCT. Their original paper (Trans. IEICE E-71(11):1095) is in + * Japanese, but the algorithm is described in the Pennebaker & Mitchell + * JPEG textbook (see REFERENCES section in file README). The following code + * is based directly on figure 4-8 in P&M. + * While an 8-point DCT cannot be done in less than 11 multiplies, it is + * possible to arrange the computation so that many of the multiplies are + * simple scalings of the final outputs. These multiplies can then be + * folded into the multiplications or divisions by the JPEG quantization + * table entries. The AA&N method leaves only 5 multiplies and 29 adds + * to be done in the DCT itself. + * The primary disadvantage of this method is that with fixed-point math, + * accuracy is lost due to imprecise representation of the scaled + * quantization values. The smaller the quantization table entry, the less + * precise the scaled value, so this implementation does worse with high- + * quality-setting files than with low-quality ones. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdct.h" /* Private declarations for DCT subsystem */ + +#ifdef DCT_IFAST_SUPPORTED + + +/* + * This module is specialized to the case DCTSIZE = 8. + */ + +#if DCTSIZE != 8 + Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */ +#endif + + +/* Scaling decisions are generally the same as in the LL&M algorithm; + * see jfdctint.c for more details. However, we choose to descale + * (right shift) multiplication products as soon as they are formed, + * rather than carrying additional fractional bits into subsequent additions. + * This compromises accuracy slightly, but it lets us save a few shifts. + * More importantly, 16-bit arithmetic is then adequate (for 8-bit samples) + * everywhere except in the multiplications proper; this saves a good deal + * of work on 16-bit-int machines. + * + * Again to save a few shifts, the intermediate results between pass 1 and + * pass 2 are not upscaled, but are represented only to integral precision. + * + * A final compromise is to represent the multiplicative constants to only + * 8 fractional bits, rather than 13. This saves some shifting work on some + * machines, and may also reduce the cost of multiplication (since there + * are fewer one-bits in the constants). + */ + +#define CONST_BITS 8 + + +/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus + * causing a lot of useless floating-point operations at run time. + * To get around this we use the following pre-calculated constants. + * If you change CONST_BITS you may want to add appropriate values. + * (With a reasonable C compiler, you can just rely on the FIX() macro...) + */ + +#if CONST_BITS == 8 +#define FIX_0_382683433 ((long) 98) /* FIX(0.382683433) */ +#define FIX_0_541196100 ((long) 139) /* FIX(0.541196100) */ +#define FIX_0_707106781 ((long) 181) /* FIX(0.707106781) */ +#define FIX_1_306562965 ((long) 334) /* FIX(1.306562965) */ +#else +#define FIX_0_382683433 FIX(0.382683433) +#define FIX_0_541196100 FIX(0.541196100) +#define FIX_0_707106781 FIX(0.707106781) +#define FIX_1_306562965 FIX(1.306562965) +#endif + + +/* We can gain a little more speed, with a further compromise in accuracy, + * by omitting the addition in a descaling shift. This yields an incorrectly + * rounded result half the time... + */ + +#ifndef USE_ACCURATE_ROUNDING +#undef DESCALE +#define DESCALE(x,n) RIGHT_SHIFT(x, n) +#endif + + +/* Multiply a DCTELEM variable by an long constant, and immediately + * descale to yield a DCTELEM result. + */ + +#define MULTIPLY(var,const) ((DCTELEM) DESCALE((var) * (const), CONST_BITS)) + + +/* + * Perform the forward DCT on one block of samples. + */ + +GLOBAL(void) +jpeg_fdct_ifast (DCTELEM * data) +{ + DCTELEM tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7; + DCTELEM tmp10, tmp11, tmp12, tmp13; + DCTELEM z1, z2, z3, z4, z5, z11, z13; + DCTELEM *dataptr; + int ctr; + SHIFT_TEMPS + + /* Pass 1: process rows. */ + + dataptr = data; + for (ctr = DCTSIZE-1; ctr >= 0; ctr--) { + tmp0 = dataptr[0] + dataptr[7]; + tmp7 = dataptr[0] - dataptr[7]; + tmp1 = dataptr[1] + dataptr[6]; + tmp6 = dataptr[1] - dataptr[6]; + tmp2 = dataptr[2] + dataptr[5]; + tmp5 = dataptr[2] - dataptr[5]; + tmp3 = dataptr[3] + dataptr[4]; + tmp4 = dataptr[3] - dataptr[4]; + + /* Even part */ + + tmp10 = tmp0 + tmp3; /* phase 2 */ + tmp13 = tmp0 - tmp3; + tmp11 = tmp1 + tmp2; + tmp12 = tmp1 - tmp2; + + dataptr[0] = tmp10 + tmp11; /* phase 3 */ + dataptr[4] = tmp10 - tmp11; + + z1 = MULTIPLY(tmp12 + tmp13, FIX_0_707106781); /* c4 */ + dataptr[2] = tmp13 + z1; /* phase 5 */ + dataptr[6] = tmp13 - z1; + + /* Odd part */ + + tmp10 = tmp4 + tmp5; /* phase 2 */ + tmp11 = tmp5 + tmp6; + tmp12 = tmp6 + tmp7; + + /* The rotator is modified from fig 4-8 to avoid extra negations. */ + z5 = MULTIPLY(tmp10 - tmp12, FIX_0_382683433); /* c6 */ + z2 = MULTIPLY(tmp10, FIX_0_541196100) + z5; /* c2-c6 */ + z4 = MULTIPLY(tmp12, FIX_1_306562965) + z5; /* c2+c6 */ + z3 = MULTIPLY(tmp11, FIX_0_707106781); /* c4 */ + + z11 = tmp7 + z3; /* phase 5 */ + z13 = tmp7 - z3; + + dataptr[5] = z13 + z2; /* phase 6 */ + dataptr[3] = z13 - z2; + dataptr[1] = z11 + z4; + dataptr[7] = z11 - z4; + + dataptr += DCTSIZE; /* advance pointer to next row */ + } + + /* Pass 2: process columns. */ + + dataptr = data; + for (ctr = DCTSIZE-1; ctr >= 0; ctr--) { + tmp0 = dataptr[DCTSIZE*0] + dataptr[DCTSIZE*7]; + tmp7 = dataptr[DCTSIZE*0] - dataptr[DCTSIZE*7]; + tmp1 = dataptr[DCTSIZE*1] + dataptr[DCTSIZE*6]; + tmp6 = dataptr[DCTSIZE*1] - dataptr[DCTSIZE*6]; + tmp2 = dataptr[DCTSIZE*2] + dataptr[DCTSIZE*5]; + tmp5 = dataptr[DCTSIZE*2] - dataptr[DCTSIZE*5]; + tmp3 = dataptr[DCTSIZE*3] + dataptr[DCTSIZE*4]; + tmp4 = dataptr[DCTSIZE*3] - dataptr[DCTSIZE*4]; + + /* Even part */ + + tmp10 = tmp0 + tmp3; /* phase 2 */ + tmp13 = tmp0 - tmp3; + tmp11 = tmp1 + tmp2; + tmp12 = tmp1 - tmp2; + + dataptr[DCTSIZE*0] = tmp10 + tmp11; /* phase 3 */ + dataptr[DCTSIZE*4] = tmp10 - tmp11; + + z1 = MULTIPLY(tmp12 + tmp13, FIX_0_707106781); /* c4 */ + dataptr[DCTSIZE*2] = tmp13 + z1; /* phase 5 */ + dataptr[DCTSIZE*6] = tmp13 - z1; + + /* Odd part */ + + tmp10 = tmp4 + tmp5; /* phase 2 */ + tmp11 = tmp5 + tmp6; + tmp12 = tmp6 + tmp7; + + /* The rotator is modified from fig 4-8 to avoid extra negations. */ + z5 = MULTIPLY(tmp10 - tmp12, FIX_0_382683433); /* c6 */ + z2 = MULTIPLY(tmp10, FIX_0_541196100) + z5; /* c2-c6 */ + z4 = MULTIPLY(tmp12, FIX_1_306562965) + z5; /* c2+c6 */ + z3 = MULTIPLY(tmp11, FIX_0_707106781); /* c4 */ + + z11 = tmp7 + z3; /* phase 5 */ + z13 = tmp7 - z3; + + dataptr[DCTSIZE*5] = z13 + z2; /* phase 6 */ + dataptr[DCTSIZE*3] = z13 - z2; + dataptr[DCTSIZE*1] = z11 + z4; + dataptr[DCTSIZE*7] = z11 - z4; + + dataptr++; /* advance pointer to next column */ + } +} + +#endif /* DCT_IFAST_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jfdctint.cpp b/ml/dlib/dlib/external/libjpeg/jfdctint.cpp new file mode 100644 index 000000000..b6046a26b --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jfdctint.cpp @@ -0,0 +1,283 @@ +/* + * jfdctint.c + * + * Copyright (C) 1991-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains a slow-but-accurate integer implementation of the + * forward DCT (Discrete Cosine Transform). + * + * A 2-D DCT can be done by 1-D DCT on each row followed by 1-D DCT + * on each column. Direct algorithms are also available, but they are + * much more complex and seem not to be any faster when reduced to code. + * + * This implementation is based on an algorithm described in + * C. Loeffler, A. Ligtenberg and G. Moschytz, "Practical Fast 1-D DCT + * Algorithms with 11 Multiplications", Proc. Int'l. Conf. on Acoustics, + * Speech, and Signal Processing 1989 (ICASSP '89), pp. 988-991. + * The primary algorithm described there uses 11 multiplies and 29 adds. + * We use their alternate method with 12 multiplies and 32 adds. + * The advantage of this method is that no data path contains more than one + * multiplication; this allows a very simple and accurate implementation in + * scaled fixed-point arithmetic, with a minimal number of shifts. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdct.h" /* Private declarations for DCT subsystem */ + +#ifdef DCT_ISLOW_SUPPORTED + + +/* + * This module is specialized to the case DCTSIZE = 8. + */ + +#if DCTSIZE != 8 + Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */ +#endif + + +/* + * The poop on this scaling stuff is as follows: + * + * Each 1-D DCT step produces outputs which are a factor of sqrt(N) + * larger than the true DCT outputs. The final outputs are therefore + * a factor of N larger than desired; since N=8 this can be cured by + * a simple right shift at the end of the algorithm. The advantage of + * this arrangement is that we save two multiplications per 1-D DCT, + * because the y0 and y4 outputs need not be divided by sqrt(N). + * In the IJG code, this factor of 8 is removed by the quantization step + * (in jcdctmgr.c), NOT in this module. + * + * We have to do addition and subtraction of the integer inputs, which + * is no problem, and multiplication by fractional constants, which is + * a problem to do in integer arithmetic. We multiply all the constants + * by CONST_SCALE and convert them to integer constants (thus retaining + * CONST_BITS bits of precision in the constants). After doing a + * multiplication we have to divide the product by CONST_SCALE, with proper + * rounding, to produce the correct output. This division can be done + * cheaply as a right shift of CONST_BITS bits. We postpone shifting + * as long as possible so that partial sums can be added together with + * full fractional precision. + * + * The outputs of the first pass are scaled up by PASS1_BITS bits so that + * they are represented to better-than-integral precision. These outputs + * require BITS_IN_JSAMPLE + PASS1_BITS + 3 bits; this fits in a 16-bit word + * with the recommended scaling. (For 12-bit sample data, the intermediate + * array is long anyway.) + * + * To avoid overflow of the 32-bit intermediate results in pass 2, we must + * have BITS_IN_JSAMPLE + CONST_BITS + PASS1_BITS <= 26. Error analysis + * shows that the values given below are the most effective. + */ + +#if BITS_IN_JSAMPLE == 8 +#define CONST_BITS 13 +#define PASS1_BITS 2 +#else +#define CONST_BITS 13 +#define PASS1_BITS 1 /* lose a little precision to avoid overflow */ +#endif + +/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus + * causing a lot of useless floating-point operations at run time. + * To get around this we use the following pre-calculated constants. + * If you change CONST_BITS you may want to add appropriate values. + * (With a reasonable C compiler, you can just rely on the FIX() macro...) + */ + +#if CONST_BITS == 13 +#define FIX_0_298631336 ((long) 2446) /* FIX(0.298631336) */ +#define FIX_0_390180644 ((long) 3196) /* FIX(0.390180644) */ +#define FIX_0_541196100 ((long) 4433) /* FIX(0.541196100) */ +#define FIX_0_765366865 ((long) 6270) /* FIX(0.765366865) */ +#define FIX_0_899976223 ((long) 7373) /* FIX(0.899976223) */ +#define FIX_1_175875602 ((long) 9633) /* FIX(1.175875602) */ +#define FIX_1_501321110 ((long) 12299) /* FIX(1.501321110) */ +#define FIX_1_847759065 ((long) 15137) /* FIX(1.847759065) */ +#define FIX_1_961570560 ((long) 16069) /* FIX(1.961570560) */ +#define FIX_2_053119869 ((long) 16819) /* FIX(2.053119869) */ +#define FIX_2_562915447 ((long) 20995) /* FIX(2.562915447) */ +#define FIX_3_072711026 ((long) 25172) /* FIX(3.072711026) */ +#else +#define FIX_0_298631336 FIX(0.298631336) +#define FIX_0_390180644 FIX(0.390180644) +#define FIX_0_541196100 FIX(0.541196100) +#define FIX_0_765366865 FIX(0.765366865) +#define FIX_0_899976223 FIX(0.899976223) +#define FIX_1_175875602 FIX(1.175875602) +#define FIX_1_501321110 FIX(1.501321110) +#define FIX_1_847759065 FIX(1.847759065) +#define FIX_1_961570560 FIX(1.961570560) +#define FIX_2_053119869 FIX(2.053119869) +#define FIX_2_562915447 FIX(2.562915447) +#define FIX_3_072711026 FIX(3.072711026) +#endif + + +/* Multiply an long variable by an long constant to yield an long result. + * For 8-bit samples with the recommended scaling, all the variable + * and constant values involved are no more than 16 bits wide, so a + * 16x16->32 bit multiply can be used instead of a full 32x32 multiply. + * For 12-bit samples, a full 32-bit multiplication will be needed. + */ + +#if BITS_IN_JSAMPLE == 8 +#define MULTIPLY(var,const) MULTIPLY16C16(var,const) +#else +#define MULTIPLY(var,const) ((var) * (const)) +#endif + + +/* + * Perform the forward DCT on one block of samples. + */ + +GLOBAL(void) +jpeg_fdct_islow (DCTELEM * data) +{ + long tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7; + long tmp10, tmp11, tmp12, tmp13; + long z1, z2, z3, z4, z5; + DCTELEM *dataptr; + int ctr; + SHIFT_TEMPS + + /* Pass 1: process rows. */ + /* Note results are scaled up by sqrt(8) compared to a true DCT; */ + /* furthermore, we scale the results by 2**PASS1_BITS. */ + + dataptr = data; + for (ctr = DCTSIZE-1; ctr >= 0; ctr--) { + tmp0 = dataptr[0] + dataptr[7]; + tmp7 = dataptr[0] - dataptr[7]; + tmp1 = dataptr[1] + dataptr[6]; + tmp6 = dataptr[1] - dataptr[6]; + tmp2 = dataptr[2] + dataptr[5]; + tmp5 = dataptr[2] - dataptr[5]; + tmp3 = dataptr[3] + dataptr[4]; + tmp4 = dataptr[3] - dataptr[4]; + + /* Even part per LL&M figure 1 --- note that published figure is faulty; + * rotator "sqrt(2)*c1" should be "sqrt(2)*c6". + */ + + tmp10 = tmp0 + tmp3; + tmp13 = tmp0 - tmp3; + tmp11 = tmp1 + tmp2; + tmp12 = tmp1 - tmp2; + + dataptr[0] = (DCTELEM) ((tmp10 + tmp11) << PASS1_BITS); + dataptr[4] = (DCTELEM) ((tmp10 - tmp11) << PASS1_BITS); + + z1 = MULTIPLY(tmp12 + tmp13, FIX_0_541196100); + dataptr[2] = (DCTELEM) DESCALE(z1 + MULTIPLY(tmp13, FIX_0_765366865), + CONST_BITS-PASS1_BITS); + dataptr[6] = (DCTELEM) DESCALE(z1 + MULTIPLY(tmp12, - FIX_1_847759065), + CONST_BITS-PASS1_BITS); + + /* Odd part per figure 8 --- note paper omits factor of sqrt(2). + * cK represents cos(K*pi/16). + * i0..i3 in the paper are tmp4..tmp7 here. + */ + + z1 = tmp4 + tmp7; + z2 = tmp5 + tmp6; + z3 = tmp4 + tmp6; + z4 = tmp5 + tmp7; + z5 = MULTIPLY(z3 + z4, FIX_1_175875602); /* sqrt(2) * c3 */ + + tmp4 = MULTIPLY(tmp4, FIX_0_298631336); /* sqrt(2) * (-c1+c3+c5-c7) */ + tmp5 = MULTIPLY(tmp5, FIX_2_053119869); /* sqrt(2) * ( c1+c3-c5+c7) */ + tmp6 = MULTIPLY(tmp6, FIX_3_072711026); /* sqrt(2) * ( c1+c3+c5-c7) */ + tmp7 = MULTIPLY(tmp7, FIX_1_501321110); /* sqrt(2) * ( c1+c3-c5-c7) */ + z1 = MULTIPLY(z1, - FIX_0_899976223); /* sqrt(2) * (c7-c3) */ + z2 = MULTIPLY(z2, - FIX_2_562915447); /* sqrt(2) * (-c1-c3) */ + z3 = MULTIPLY(z3, - FIX_1_961570560); /* sqrt(2) * (-c3-c5) */ + z4 = MULTIPLY(z4, - FIX_0_390180644); /* sqrt(2) * (c5-c3) */ + + z3 += z5; + z4 += z5; + + dataptr[7] = (DCTELEM) DESCALE(tmp4 + z1 + z3, CONST_BITS-PASS1_BITS); + dataptr[5] = (DCTELEM) DESCALE(tmp5 + z2 + z4, CONST_BITS-PASS1_BITS); + dataptr[3] = (DCTELEM) DESCALE(tmp6 + z2 + z3, CONST_BITS-PASS1_BITS); + dataptr[1] = (DCTELEM) DESCALE(tmp7 + z1 + z4, CONST_BITS-PASS1_BITS); + + dataptr += DCTSIZE; /* advance pointer to next row */ + } + + /* Pass 2: process columns. + * We remove the PASS1_BITS scaling, but leave the results scaled up + * by an overall factor of 8. + */ + + dataptr = data; + for (ctr = DCTSIZE-1; ctr >= 0; ctr--) { + tmp0 = dataptr[DCTSIZE*0] + dataptr[DCTSIZE*7]; + tmp7 = dataptr[DCTSIZE*0] - dataptr[DCTSIZE*7]; + tmp1 = dataptr[DCTSIZE*1] + dataptr[DCTSIZE*6]; + tmp6 = dataptr[DCTSIZE*1] - dataptr[DCTSIZE*6]; + tmp2 = dataptr[DCTSIZE*2] + dataptr[DCTSIZE*5]; + tmp5 = dataptr[DCTSIZE*2] - dataptr[DCTSIZE*5]; + tmp3 = dataptr[DCTSIZE*3] + dataptr[DCTSIZE*4]; + tmp4 = dataptr[DCTSIZE*3] - dataptr[DCTSIZE*4]; + + /* Even part per LL&M figure 1 --- note that published figure is faulty; + * rotator "sqrt(2)*c1" should be "sqrt(2)*c6". + */ + + tmp10 = tmp0 + tmp3; + tmp13 = tmp0 - tmp3; + tmp11 = tmp1 + tmp2; + tmp12 = tmp1 - tmp2; + + dataptr[DCTSIZE*0] = (DCTELEM) DESCALE(tmp10 + tmp11, PASS1_BITS); + dataptr[DCTSIZE*4] = (DCTELEM) DESCALE(tmp10 - tmp11, PASS1_BITS); + + z1 = MULTIPLY(tmp12 + tmp13, FIX_0_541196100); + dataptr[DCTSIZE*2] = (DCTELEM) DESCALE(z1 + MULTIPLY(tmp13, FIX_0_765366865), + CONST_BITS+PASS1_BITS); + dataptr[DCTSIZE*6] = (DCTELEM) DESCALE(z1 + MULTIPLY(tmp12, - FIX_1_847759065), + CONST_BITS+PASS1_BITS); + + /* Odd part per figure 8 --- note paper omits factor of sqrt(2). + * cK represents cos(K*pi/16). + * i0..i3 in the paper are tmp4..tmp7 here. + */ + + z1 = tmp4 + tmp7; + z2 = tmp5 + tmp6; + z3 = tmp4 + tmp6; + z4 = tmp5 + tmp7; + z5 = MULTIPLY(z3 + z4, FIX_1_175875602); /* sqrt(2) * c3 */ + + tmp4 = MULTIPLY(tmp4, FIX_0_298631336); /* sqrt(2) * (-c1+c3+c5-c7) */ + tmp5 = MULTIPLY(tmp5, FIX_2_053119869); /* sqrt(2) * ( c1+c3-c5+c7) */ + tmp6 = MULTIPLY(tmp6, FIX_3_072711026); /* sqrt(2) * ( c1+c3+c5-c7) */ + tmp7 = MULTIPLY(tmp7, FIX_1_501321110); /* sqrt(2) * ( c1+c3-c5-c7) */ + z1 = MULTIPLY(z1, - FIX_0_899976223); /* sqrt(2) * (c7-c3) */ + z2 = MULTIPLY(z2, - FIX_2_562915447); /* sqrt(2) * (-c1-c3) */ + z3 = MULTIPLY(z3, - FIX_1_961570560); /* sqrt(2) * (-c3-c5) */ + z4 = MULTIPLY(z4, - FIX_0_390180644); /* sqrt(2) * (c5-c3) */ + + z3 += z5; + z4 += z5; + + dataptr[DCTSIZE*7] = (DCTELEM) DESCALE(tmp4 + z1 + z3, + CONST_BITS+PASS1_BITS); + dataptr[DCTSIZE*5] = (DCTELEM) DESCALE(tmp5 + z2 + z4, + CONST_BITS+PASS1_BITS); + dataptr[DCTSIZE*3] = (DCTELEM) DESCALE(tmp6 + z2 + z3, + CONST_BITS+PASS1_BITS); + dataptr[DCTSIZE*1] = (DCTELEM) DESCALE(tmp7 + z1 + z4, + CONST_BITS+PASS1_BITS); + + dataptr++; /* advance pointer to next column */ + } +} + +#endif /* DCT_ISLOW_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jidctflt.cpp b/ml/dlib/dlib/external/libjpeg/jidctflt.cpp new file mode 100644 index 000000000..3e9b54579 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jidctflt.cpp @@ -0,0 +1,242 @@ +/* + * jidctflt.c + * + * Copyright (C) 1994-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains a floating-point implementation of the + * inverse DCT (Discrete Cosine Transform). In the IJG code, this routine + * must also perform dequantization of the input coefficients. + * + * This implementation should be more accurate than either of the integer + * IDCT implementations. However, it may not give the same results on all + * machines because of differences in roundoff behavior. Speed will depend + * on the hardware's floating point capacity. + * + * A 2-D IDCT can be done by 1-D IDCT on each column followed by 1-D IDCT + * on each row (or vice versa, but it's more convenient to emit a row at + * a time). Direct algorithms are also available, but they are much more + * complex and seem not to be any faster when reduced to code. + * + * This implementation is based on Arai, Agui, and Nakajima's algorithm for + * scaled DCT. Their original paper (Trans. IEICE E-71(11):1095) is in + * Japanese, but the algorithm is described in the Pennebaker & Mitchell + * JPEG textbook (see REFERENCES section in file README). The following code + * is based directly on figure 4-8 in P&M. + * While an 8-point DCT cannot be done in less than 11 multiplies, it is + * possible to arrange the computation so that many of the multiplies are + * simple scalings of the final outputs. These multiplies can then be + * folded into the multiplications or divisions by the JPEG quantization + * table entries. The AA&N method leaves only 5 multiplies and 29 adds + * to be done in the DCT itself. + * The primary disadvantage of this method is that with a fixed-point + * implementation, accuracy is lost due to imprecise representation of the + * scaled quantization values. However, that problem does not arise if + * we use floating point arithmetic. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdct.h" /* Private declarations for DCT subsystem */ + +#ifdef DCT_FLOAT_SUPPORTED + + +/* + * This module is specialized to the case DCTSIZE = 8. + */ + +#if DCTSIZE != 8 + Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */ +#endif + + +/* Dequantize a coefficient by multiplying it by the multiplier-table + * entry; produce a float result. + */ + +#define DEQUANTIZE(coef,quantval) (((FAST_FLOAT) (coef)) * (quantval)) + + +/* + * Perform dequantization and inverse DCT on one block of coefficients. + */ + +GLOBAL(void) +jpeg_idct_float (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, + JSAMPARRAY output_buf, JDIMENSION output_col) +{ + FAST_FLOAT tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7; + FAST_FLOAT tmp10, tmp11, tmp12, tmp13; + FAST_FLOAT z5, z10, z11, z12, z13; + JCOEFPTR inptr; + FLOAT_MULT_TYPE * quantptr; + FAST_FLOAT * wsptr; + JSAMPROW outptr; + JSAMPLE *range_limit = IDCT_range_limit(cinfo); + int ctr; + FAST_FLOAT workspace[DCTSIZE2]; /* buffers data between passes */ + SHIFT_TEMPS + + /* Pass 1: process columns from input, store into work array. */ + + inptr = coef_block; + quantptr = (FLOAT_MULT_TYPE *) compptr->dct_table; + wsptr = workspace; + for (ctr = DCTSIZE; ctr > 0; ctr--) { + /* Due to quantization, we will usually find that many of the input + * coefficients are zero, especially the AC terms. We can exploit this + * by short-circuiting the IDCT calculation for any column in which all + * the AC terms are zero. In that case each output is equal to the + * DC coefficient (with scale factor as needed). + * With typical images and quantization tables, half or more of the + * column DCT calculations can be simplified this way. + */ + + if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*2] == 0 && + inptr[DCTSIZE*3] == 0 && inptr[DCTSIZE*4] == 0 && + inptr[DCTSIZE*5] == 0 && inptr[DCTSIZE*6] == 0 && + inptr[DCTSIZE*7] == 0) { + /* AC terms all zero */ + FAST_FLOAT dcval = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]); + + wsptr[DCTSIZE*0] = dcval; + wsptr[DCTSIZE*1] = dcval; + wsptr[DCTSIZE*2] = dcval; + wsptr[DCTSIZE*3] = dcval; + wsptr[DCTSIZE*4] = dcval; + wsptr[DCTSIZE*5] = dcval; + wsptr[DCTSIZE*6] = dcval; + wsptr[DCTSIZE*7] = dcval; + + inptr++; /* advance pointers to next column */ + quantptr++; + wsptr++; + continue; + } + + /* Even part */ + + tmp0 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]); + tmp1 = DEQUANTIZE(inptr[DCTSIZE*2], quantptr[DCTSIZE*2]); + tmp2 = DEQUANTIZE(inptr[DCTSIZE*4], quantptr[DCTSIZE*4]); + tmp3 = DEQUANTIZE(inptr[DCTSIZE*6], quantptr[DCTSIZE*6]); + + tmp10 = tmp0 + tmp2; /* phase 3 */ + tmp11 = tmp0 - tmp2; + + tmp13 = tmp1 + tmp3; /* phases 5-3 */ + tmp12 = (tmp1 - tmp3) * ((FAST_FLOAT) 1.414213562) - tmp13; /* 2*c4 */ + + tmp0 = tmp10 + tmp13; /* phase 2 */ + tmp3 = tmp10 - tmp13; + tmp1 = tmp11 + tmp12; + tmp2 = tmp11 - tmp12; + + /* Odd part */ + + tmp4 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]); + tmp5 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]); + tmp6 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]); + tmp7 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]); + + z13 = tmp6 + tmp5; /* phase 6 */ + z10 = tmp6 - tmp5; + z11 = tmp4 + tmp7; + z12 = tmp4 - tmp7; + + tmp7 = z11 + z13; /* phase 5 */ + tmp11 = (z11 - z13) * ((FAST_FLOAT) 1.414213562); /* 2*c4 */ + + z5 = (z10 + z12) * ((FAST_FLOAT) 1.847759065); /* 2*c2 */ + tmp10 = ((FAST_FLOAT) 1.082392200) * z12 - z5; /* 2*(c2-c6) */ + tmp12 = ((FAST_FLOAT) -2.613125930) * z10 + z5; /* -2*(c2+c6) */ + + tmp6 = tmp12 - tmp7; /* phase 2 */ + tmp5 = tmp11 - tmp6; + tmp4 = tmp10 + tmp5; + + wsptr[DCTSIZE*0] = tmp0 + tmp7; + wsptr[DCTSIZE*7] = tmp0 - tmp7; + wsptr[DCTSIZE*1] = tmp1 + tmp6; + wsptr[DCTSIZE*6] = tmp1 - tmp6; + wsptr[DCTSIZE*2] = tmp2 + tmp5; + wsptr[DCTSIZE*5] = tmp2 - tmp5; + wsptr[DCTSIZE*4] = tmp3 + tmp4; + wsptr[DCTSIZE*3] = tmp3 - tmp4; + + inptr++; /* advance pointers to next column */ + quantptr++; + wsptr++; + } + + /* Pass 2: process rows from work array, store into output array. */ + /* Note that we must descale the results by a factor of 8 == 2**3. */ + + wsptr = workspace; + for (ctr = 0; ctr < DCTSIZE; ctr++) { + outptr = output_buf[ctr] + output_col; + /* Rows of zeroes can be exploited in the same way as we did with columns. + * However, the column calculation has created many nonzero AC terms, so + * the simplification applies less often (typically 5% to 10% of the time). + * And testing floats for zero is relatively expensive, so we don't bother. + */ + + /* Even part */ + + tmp10 = wsptr[0] + wsptr[4]; + tmp11 = wsptr[0] - wsptr[4]; + + tmp13 = wsptr[2] + wsptr[6]; + tmp12 = (wsptr[2] - wsptr[6]) * ((FAST_FLOAT) 1.414213562) - tmp13; + + tmp0 = tmp10 + tmp13; + tmp3 = tmp10 - tmp13; + tmp1 = tmp11 + tmp12; + tmp2 = tmp11 - tmp12; + + /* Odd part */ + + z13 = wsptr[5] + wsptr[3]; + z10 = wsptr[5] - wsptr[3]; + z11 = wsptr[1] + wsptr[7]; + z12 = wsptr[1] - wsptr[7]; + + tmp7 = z11 + z13; + tmp11 = (z11 - z13) * ((FAST_FLOAT) 1.414213562); + + z5 = (z10 + z12) * ((FAST_FLOAT) 1.847759065); /* 2*c2 */ + tmp10 = ((FAST_FLOAT) 1.082392200) * z12 - z5; /* 2*(c2-c6) */ + tmp12 = ((FAST_FLOAT) -2.613125930) * z10 + z5; /* -2*(c2+c6) */ + + tmp6 = tmp12 - tmp7; + tmp5 = tmp11 - tmp6; + tmp4 = tmp10 + tmp5; + + /* Final output stage: scale down by a factor of 8 and range-limit */ + + outptr[0] = range_limit[(int) DESCALE((long) (tmp0 + tmp7), 3) + & RANGE_MASK]; + outptr[7] = range_limit[(int) DESCALE((long) (tmp0 - tmp7), 3) + & RANGE_MASK]; + outptr[1] = range_limit[(int) DESCALE((long) (tmp1 + tmp6), 3) + & RANGE_MASK]; + outptr[6] = range_limit[(int) DESCALE((long) (tmp1 - tmp6), 3) + & RANGE_MASK]; + outptr[2] = range_limit[(int) DESCALE((long) (tmp2 + tmp5), 3) + & RANGE_MASK]; + outptr[5] = range_limit[(int) DESCALE((long) (tmp2 - tmp5), 3) + & RANGE_MASK]; + outptr[4] = range_limit[(int) DESCALE((long) (tmp3 + tmp4), 3) + & RANGE_MASK]; + outptr[3] = range_limit[(int) DESCALE((long) (tmp3 - tmp4), 3) + & RANGE_MASK]; + + wsptr += DCTSIZE; /* advance pointer to next row */ + } +} + +#endif /* DCT_FLOAT_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jidctfst.cpp b/ml/dlib/dlib/external/libjpeg/jidctfst.cpp new file mode 100644 index 000000000..e08835bc5 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jidctfst.cpp @@ -0,0 +1,368 @@ +/* + * jidctfst.c + * + * Copyright (C) 1994-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains a fast, not so accurate integer implementation of the + * inverse DCT (Discrete Cosine Transform). In the IJG code, this routine + * must also perform dequantization of the input coefficients. + * + * A 2-D IDCT can be done by 1-D IDCT on each column followed by 1-D IDCT + * on each row (or vice versa, but it's more convenient to emit a row at + * a time). Direct algorithms are also available, but they are much more + * complex and seem not to be any faster when reduced to code. + * + * This implementation is based on Arai, Agui, and Nakajima's algorithm for + * scaled DCT. Their original paper (Trans. IEICE E-71(11):1095) is in + * Japanese, but the algorithm is described in the Pennebaker & Mitchell + * JPEG textbook (see REFERENCES section in file README). The following code + * is based directly on figure 4-8 in P&M. + * While an 8-point DCT cannot be done in less than 11 multiplies, it is + * possible to arrange the computation so that many of the multiplies are + * simple scalings of the final outputs. These multiplies can then be + * folded into the multiplications or divisions by the JPEG quantization + * table entries. The AA&N method leaves only 5 multiplies and 29 adds + * to be done in the DCT itself. + * The primary disadvantage of this method is that with fixed-point math, + * accuracy is lost due to imprecise representation of the scaled + * quantization values. The smaller the quantization table entry, the less + * precise the scaled value, so this implementation does worse with high- + * quality-setting files than with low-quality ones. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdct.h" /* Private declarations for DCT subsystem */ + +#ifdef DCT_IFAST_SUPPORTED + + +/* + * This module is specialized to the case DCTSIZE = 8. + */ + +#if DCTSIZE != 8 + Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */ +#endif + + +/* Scaling decisions are generally the same as in the LL&M algorithm; + * see jidctint.c for more details. However, we choose to descale + * (right shift) multiplication products as soon as they are formed, + * rather than carrying additional fractional bits into subsequent additions. + * This compromises accuracy slightly, but it lets us save a few shifts. + * More importantly, 16-bit arithmetic is then adequate (for 8-bit samples) + * everywhere except in the multiplications proper; this saves a good deal + * of work on 16-bit-int machines. + * + * The dequantized coefficients are not integers because the AA&N scaling + * factors have been incorporated. We represent them scaled up by PASS1_BITS, + * so that the first and second IDCT rounds have the same input scaling. + * For 8-bit JSAMPLEs, we choose IFAST_SCALE_BITS = PASS1_BITS so as to + * avoid a descaling shift; this compromises accuracy rather drastically + * for small quantization table entries, but it saves a lot of shifts. + * For 12-bit JSAMPLEs, there's no hope of using 16x16 multiplies anyway, + * so we use a much larger scaling factor to preserve accuracy. + * + * A final compromise is to represent the multiplicative constants to only + * 8 fractional bits, rather than 13. This saves some shifting work on some + * machines, and may also reduce the cost of multiplication (since there + * are fewer one-bits in the constants). + */ + +#if BITS_IN_JSAMPLE == 8 +#define CONST_BITS 8 +#define PASS1_BITS 2 +#else +#define CONST_BITS 8 +#define PASS1_BITS 1 /* lose a little precision to avoid overflow */ +#endif + +/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus + * causing a lot of useless floating-point operations at run time. + * To get around this we use the following pre-calculated constants. + * If you change CONST_BITS you may want to add appropriate values. + * (With a reasonable C compiler, you can just rely on the FIX() macro...) + */ + +#if CONST_BITS == 8 +#define FIX_1_082392200 ((long) 277) /* FIX(1.082392200) */ +#define FIX_1_414213562 ((long) 362) /* FIX(1.414213562) */ +#define FIX_1_847759065 ((long) 473) /* FIX(1.847759065) */ +#define FIX_2_613125930 ((long) 669) /* FIX(2.613125930) */ +#else +#define FIX_1_082392200 FIX(1.082392200) +#define FIX_1_414213562 FIX(1.414213562) +#define FIX_1_847759065 FIX(1.847759065) +#define FIX_2_613125930 FIX(2.613125930) +#endif + + +/* We can gain a little more speed, with a further compromise in accuracy, + * by omitting the addition in a descaling shift. This yields an incorrectly + * rounded result half the time... + */ + +#ifndef USE_ACCURATE_ROUNDING +#undef DESCALE +#define DESCALE(x,n) RIGHT_SHIFT(x, n) +#endif + + +/* Multiply a DCTELEM variable by an long constant, and immediately + * descale to yield a DCTELEM result. + */ + +#define MULTIPLY(var,const) ((DCTELEM) DESCALE((var) * (const), CONST_BITS)) + + +/* Dequantize a coefficient by multiplying it by the multiplier-table + * entry; produce a DCTELEM result. For 8-bit data a 16x16->16 + * multiplication will do. For 12-bit data, the multiplier table is + * declared long, so a 32-bit multiply will be used. + */ + +#if BITS_IN_JSAMPLE == 8 +#define DEQUANTIZE(coef,quantval) (((IFAST_MULT_TYPE) (coef)) * (quantval)) +#else +#define DEQUANTIZE(coef,quantval) \ + DESCALE((coef)*(quantval), IFAST_SCALE_BITS-PASS1_BITS) +#endif + + +/* Like DESCALE, but applies to a DCTELEM and produces an int. + * We assume that int right shift is unsigned if long right shift is. + */ + +#ifdef RIGHT_SHIFT_IS_UNSIGNED +#define ISHIFT_TEMPS DCTELEM ishift_temp; +#if BITS_IN_JSAMPLE == 8 +#define DCTELEMBITS 16 /* DCTELEM may be 16 or 32 bits */ +#else +#define DCTELEMBITS 32 /* DCTELEM must be 32 bits */ +#endif +#define IRIGHT_SHIFT(x,shft) \ + ((ishift_temp = (x)) < 0 ? \ + (ishift_temp >> (shft)) | ((~((DCTELEM) 0)) << (DCTELEMBITS-(shft))) : \ + (ishift_temp >> (shft))) +#else +#define ISHIFT_TEMPS +#define IRIGHT_SHIFT(x,shft) ((x) >> (shft)) +#endif + +#ifdef USE_ACCURATE_ROUNDING +#define IDESCALE(x,n) ((int) IRIGHT_SHIFT((x) + (1 << ((n)-1)), n)) +#else +#define IDESCALE(x,n) ((int) IRIGHT_SHIFT(x, n)) +#endif + + +/* + * Perform dequantization and inverse DCT on one block of coefficients. + */ + +GLOBAL(void) +jpeg_idct_ifast (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, + JSAMPARRAY output_buf, JDIMENSION output_col) +{ + DCTELEM tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7; + DCTELEM tmp10, tmp11, tmp12, tmp13; + DCTELEM z5, z10, z11, z12, z13; + JCOEFPTR inptr; + IFAST_MULT_TYPE * quantptr; + int * wsptr; + JSAMPROW outptr; + JSAMPLE *range_limit = IDCT_range_limit(cinfo); + int ctr; + int workspace[DCTSIZE2]; /* buffers data between passes */ + SHIFT_TEMPS /* for DESCALE */ + ISHIFT_TEMPS /* for IDESCALE */ + + /* Pass 1: process columns from input, store into work array. */ + + inptr = coef_block; + quantptr = (IFAST_MULT_TYPE *) compptr->dct_table; + wsptr = workspace; + for (ctr = DCTSIZE; ctr > 0; ctr--) { + /* Due to quantization, we will usually find that many of the input + * coefficients are zero, especially the AC terms. We can exploit this + * by short-circuiting the IDCT calculation for any column in which all + * the AC terms are zero. In that case each output is equal to the + * DC coefficient (with scale factor as needed). + * With typical images and quantization tables, half or more of the + * column DCT calculations can be simplified this way. + */ + + if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*2] == 0 && + inptr[DCTSIZE*3] == 0 && inptr[DCTSIZE*4] == 0 && + inptr[DCTSIZE*5] == 0 && inptr[DCTSIZE*6] == 0 && + inptr[DCTSIZE*7] == 0) { + /* AC terms all zero */ + int dcval = (int) DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]); + + wsptr[DCTSIZE*0] = dcval; + wsptr[DCTSIZE*1] = dcval; + wsptr[DCTSIZE*2] = dcval; + wsptr[DCTSIZE*3] = dcval; + wsptr[DCTSIZE*4] = dcval; + wsptr[DCTSIZE*5] = dcval; + wsptr[DCTSIZE*6] = dcval; + wsptr[DCTSIZE*7] = dcval; + + inptr++; /* advance pointers to next column */ + quantptr++; + wsptr++; + continue; + } + + /* Even part */ + + tmp0 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]); + tmp1 = DEQUANTIZE(inptr[DCTSIZE*2], quantptr[DCTSIZE*2]); + tmp2 = DEQUANTIZE(inptr[DCTSIZE*4], quantptr[DCTSIZE*4]); + tmp3 = DEQUANTIZE(inptr[DCTSIZE*6], quantptr[DCTSIZE*6]); + + tmp10 = tmp0 + tmp2; /* phase 3 */ + tmp11 = tmp0 - tmp2; + + tmp13 = tmp1 + tmp3; /* phases 5-3 */ + tmp12 = MULTIPLY(tmp1 - tmp3, FIX_1_414213562) - tmp13; /* 2*c4 */ + + tmp0 = tmp10 + tmp13; /* phase 2 */ + tmp3 = tmp10 - tmp13; + tmp1 = tmp11 + tmp12; + tmp2 = tmp11 - tmp12; + + /* Odd part */ + + tmp4 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]); + tmp5 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]); + tmp6 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]); + tmp7 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]); + + z13 = tmp6 + tmp5; /* phase 6 */ + z10 = tmp6 - tmp5; + z11 = tmp4 + tmp7; + z12 = tmp4 - tmp7; + + tmp7 = z11 + z13; /* phase 5 */ + tmp11 = MULTIPLY(z11 - z13, FIX_1_414213562); /* 2*c4 */ + + z5 = MULTIPLY(z10 + z12, FIX_1_847759065); /* 2*c2 */ + tmp10 = MULTIPLY(z12, FIX_1_082392200) - z5; /* 2*(c2-c6) */ + tmp12 = MULTIPLY(z10, - FIX_2_613125930) + z5; /* -2*(c2+c6) */ + + tmp6 = tmp12 - tmp7; /* phase 2 */ + tmp5 = tmp11 - tmp6; + tmp4 = tmp10 + tmp5; + + wsptr[DCTSIZE*0] = (int) (tmp0 + tmp7); + wsptr[DCTSIZE*7] = (int) (tmp0 - tmp7); + wsptr[DCTSIZE*1] = (int) (tmp1 + tmp6); + wsptr[DCTSIZE*6] = (int) (tmp1 - tmp6); + wsptr[DCTSIZE*2] = (int) (tmp2 + tmp5); + wsptr[DCTSIZE*5] = (int) (tmp2 - tmp5); + wsptr[DCTSIZE*4] = (int) (tmp3 + tmp4); + wsptr[DCTSIZE*3] = (int) (tmp3 - tmp4); + + inptr++; /* advance pointers to next column */ + quantptr++; + wsptr++; + } + + /* Pass 2: process rows from work array, store into output array. */ + /* Note that we must descale the results by a factor of 8 == 2**3, */ + /* and also undo the PASS1_BITS scaling. */ + + wsptr = workspace; + for (ctr = 0; ctr < DCTSIZE; ctr++) { + outptr = output_buf[ctr] + output_col; + /* Rows of zeroes can be exploited in the same way as we did with columns. + * However, the column calculation has created many nonzero AC terms, so + * the simplification applies less often (typically 5% to 10% of the time). + * On machines with very fast multiplication, it's possible that the + * test takes more time than it's worth. In that case this section + * may be commented out. + */ + +#ifndef NO_ZERO_ROW_TEST + if (wsptr[1] == 0 && wsptr[2] == 0 && wsptr[3] == 0 && wsptr[4] == 0 && + wsptr[5] == 0 && wsptr[6] == 0 && wsptr[7] == 0) { + /* AC terms all zero */ + JSAMPLE dcval = range_limit[IDESCALE(wsptr[0], PASS1_BITS+3) + & RANGE_MASK]; + + outptr[0] = dcval; + outptr[1] = dcval; + outptr[2] = dcval; + outptr[3] = dcval; + outptr[4] = dcval; + outptr[5] = dcval; + outptr[6] = dcval; + outptr[7] = dcval; + + wsptr += DCTSIZE; /* advance pointer to next row */ + continue; + } +#endif + + /* Even part */ + + tmp10 = ((DCTELEM) wsptr[0] + (DCTELEM) wsptr[4]); + tmp11 = ((DCTELEM) wsptr[0] - (DCTELEM) wsptr[4]); + + tmp13 = ((DCTELEM) wsptr[2] + (DCTELEM) wsptr[6]); + tmp12 = MULTIPLY((DCTELEM) wsptr[2] - (DCTELEM) wsptr[6], FIX_1_414213562) + - tmp13; + + tmp0 = tmp10 + tmp13; + tmp3 = tmp10 - tmp13; + tmp1 = tmp11 + tmp12; + tmp2 = tmp11 - tmp12; + + /* Odd part */ + + z13 = (DCTELEM) wsptr[5] + (DCTELEM) wsptr[3]; + z10 = (DCTELEM) wsptr[5] - (DCTELEM) wsptr[3]; + z11 = (DCTELEM) wsptr[1] + (DCTELEM) wsptr[7]; + z12 = (DCTELEM) wsptr[1] - (DCTELEM) wsptr[7]; + + tmp7 = z11 + z13; /* phase 5 */ + tmp11 = MULTIPLY(z11 - z13, FIX_1_414213562); /* 2*c4 */ + + z5 = MULTIPLY(z10 + z12, FIX_1_847759065); /* 2*c2 */ + tmp10 = MULTIPLY(z12, FIX_1_082392200) - z5; /* 2*(c2-c6) */ + tmp12 = MULTIPLY(z10, - FIX_2_613125930) + z5; /* -2*(c2+c6) */ + + tmp6 = tmp12 - tmp7; /* phase 2 */ + tmp5 = tmp11 - tmp6; + tmp4 = tmp10 + tmp5; + + /* Final output stage: scale down by a factor of 8 and range-limit */ + + outptr[0] = range_limit[IDESCALE(tmp0 + tmp7, PASS1_BITS+3) + & RANGE_MASK]; + outptr[7] = range_limit[IDESCALE(tmp0 - tmp7, PASS1_BITS+3) + & RANGE_MASK]; + outptr[1] = range_limit[IDESCALE(tmp1 + tmp6, PASS1_BITS+3) + & RANGE_MASK]; + outptr[6] = range_limit[IDESCALE(tmp1 - tmp6, PASS1_BITS+3) + & RANGE_MASK]; + outptr[2] = range_limit[IDESCALE(tmp2 + tmp5, PASS1_BITS+3) + & RANGE_MASK]; + outptr[5] = range_limit[IDESCALE(tmp2 - tmp5, PASS1_BITS+3) + & RANGE_MASK]; + outptr[4] = range_limit[IDESCALE(tmp3 + tmp4, PASS1_BITS+3) + & RANGE_MASK]; + outptr[3] = range_limit[IDESCALE(tmp3 - tmp4, PASS1_BITS+3) + & RANGE_MASK]; + + wsptr += DCTSIZE; /* advance pointer to next row */ + } +} + +#endif /* DCT_IFAST_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jidctint.cpp b/ml/dlib/dlib/external/libjpeg/jidctint.cpp new file mode 100644 index 000000000..630b4fb89 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jidctint.cpp @@ -0,0 +1,389 @@ +/* + * jidctint.c + * + * Copyright (C) 1991-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains a slow-but-accurate integer implementation of the + * inverse DCT (Discrete Cosine Transform). In the IJG code, this routine + * must also perform dequantization of the input coefficients. + * + * A 2-D IDCT can be done by 1-D IDCT on each column followed by 1-D IDCT + * on each row (or vice versa, but it's more convenient to emit a row at + * a time). Direct algorithms are also available, but they are much more + * complex and seem not to be any faster when reduced to code. + * + * This implementation is based on an algorithm described in + * C. Loeffler, A. Ligtenberg and G. Moschytz, "Practical Fast 1-D DCT + * Algorithms with 11 Multiplications", Proc. Int'l. Conf. on Acoustics, + * Speech, and Signal Processing 1989 (ICASSP '89), pp. 988-991. + * The primary algorithm described there uses 11 multiplies and 29 adds. + * We use their alternate method with 12 multiplies and 32 adds. + * The advantage of this method is that no data path contains more than one + * multiplication; this allows a very simple and accurate implementation in + * scaled fixed-point arithmetic, with a minimal number of shifts. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdct.h" /* Private declarations for DCT subsystem */ + +#ifdef DCT_ISLOW_SUPPORTED + + +/* + * This module is specialized to the case DCTSIZE = 8. + */ + +#if DCTSIZE != 8 + Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */ +#endif + + +/* + * The poop on this scaling stuff is as follows: + * + * Each 1-D IDCT step produces outputs which are a factor of sqrt(N) + * larger than the true IDCT outputs. The final outputs are therefore + * a factor of N larger than desired; since N=8 this can be cured by + * a simple right shift at the end of the algorithm. The advantage of + * this arrangement is that we save two multiplications per 1-D IDCT, + * because the y0 and y4 inputs need not be divided by sqrt(N). + * + * We have to do addition and subtraction of the integer inputs, which + * is no problem, and multiplication by fractional constants, which is + * a problem to do in integer arithmetic. We multiply all the constants + * by CONST_SCALE and convert them to integer constants (thus retaining + * CONST_BITS bits of precision in the constants). After doing a + * multiplication we have to divide the product by CONST_SCALE, with proper + * rounding, to produce the correct output. This division can be done + * cheaply as a right shift of CONST_BITS bits. We postpone shifting + * as long as possible so that partial sums can be added together with + * full fractional precision. + * + * The outputs of the first pass are scaled up by PASS1_BITS bits so that + * they are represented to better-than-integral precision. These outputs + * require BITS_IN_JSAMPLE + PASS1_BITS + 3 bits; this fits in a 16-bit word + * with the recommended scaling. (To scale up 12-bit sample data further, an + * intermediate long array would be needed.) + * + * To avoid overflow of the 32-bit intermediate results in pass 2, we must + * have BITS_IN_JSAMPLE + CONST_BITS + PASS1_BITS <= 26. Error analysis + * shows that the values given below are the most effective. + */ + +#if BITS_IN_JSAMPLE == 8 +#define CONST_BITS 13 +#define PASS1_BITS 2 +#else +#define CONST_BITS 13 +#define PASS1_BITS 1 /* lose a little precision to avoid overflow */ +#endif + +/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus + * causing a lot of useless floating-point operations at run time. + * To get around this we use the following pre-calculated constants. + * If you change CONST_BITS you may want to add appropriate values. + * (With a reasonable C compiler, you can just rely on the FIX() macro...) + */ + +#if CONST_BITS == 13 +#define FIX_0_298631336 ((long) 2446) /* FIX(0.298631336) */ +#define FIX_0_390180644 ((long) 3196) /* FIX(0.390180644) */ +#define FIX_0_541196100 ((long) 4433) /* FIX(0.541196100) */ +#define FIX_0_765366865 ((long) 6270) /* FIX(0.765366865) */ +#define FIX_0_899976223 ((long) 7373) /* FIX(0.899976223) */ +#define FIX_1_175875602 ((long) 9633) /* FIX(1.175875602) */ +#define FIX_1_501321110 ((long) 12299) /* FIX(1.501321110) */ +#define FIX_1_847759065 ((long) 15137) /* FIX(1.847759065) */ +#define FIX_1_961570560 ((long) 16069) /* FIX(1.961570560) */ +#define FIX_2_053119869 ((long) 16819) /* FIX(2.053119869) */ +#define FIX_2_562915447 ((long) 20995) /* FIX(2.562915447) */ +#define FIX_3_072711026 ((long) 25172) /* FIX(3.072711026) */ +#else +#define FIX_0_298631336 FIX(0.298631336) +#define FIX_0_390180644 FIX(0.390180644) +#define FIX_0_541196100 FIX(0.541196100) +#define FIX_0_765366865 FIX(0.765366865) +#define FIX_0_899976223 FIX(0.899976223) +#define FIX_1_175875602 FIX(1.175875602) +#define FIX_1_501321110 FIX(1.501321110) +#define FIX_1_847759065 FIX(1.847759065) +#define FIX_1_961570560 FIX(1.961570560) +#define FIX_2_053119869 FIX(2.053119869) +#define FIX_2_562915447 FIX(2.562915447) +#define FIX_3_072711026 FIX(3.072711026) +#endif + + +/* Multiply an long variable by an long constant to yield an long result. + * For 8-bit samples with the recommended scaling, all the variable + * and constant values involved are no more than 16 bits wide, so a + * 16x16->32 bit multiply can be used instead of a full 32x32 multiply. + * For 12-bit samples, a full 32-bit multiplication will be needed. + */ + +#if BITS_IN_JSAMPLE == 8 +#define MULTIPLY(var,const) MULTIPLY16C16(var,const) +#else +#define MULTIPLY(var,const) ((var) * (const)) +#endif + + +/* Dequantize a coefficient by multiplying it by the multiplier-table + * entry; produce an int result. In this module, both inputs and result + * are 16 bits or less, so either int or short multiply will work. + */ + +#define DEQUANTIZE(coef,quantval) (((ISLOW_MULT_TYPE) (coef)) * (quantval)) + + +/* + * Perform dequantization and inverse DCT on one block of coefficients. + */ + +GLOBAL(void) +jpeg_idct_islow (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, + JSAMPARRAY output_buf, JDIMENSION output_col) +{ + long tmp0, tmp1, tmp2, tmp3; + long tmp10, tmp11, tmp12, tmp13; + long z1, z2, z3, z4, z5; + JCOEFPTR inptr; + ISLOW_MULT_TYPE * quantptr; + int * wsptr; + JSAMPROW outptr; + JSAMPLE *range_limit = IDCT_range_limit(cinfo); + int ctr; + int workspace[DCTSIZE2]; /* buffers data between passes */ + SHIFT_TEMPS + + /* Pass 1: process columns from input, store into work array. */ + /* Note results are scaled up by sqrt(8) compared to a true IDCT; */ + /* furthermore, we scale the results by 2**PASS1_BITS. */ + + inptr = coef_block; + quantptr = (ISLOW_MULT_TYPE *) compptr->dct_table; + wsptr = workspace; + for (ctr = DCTSIZE; ctr > 0; ctr--) { + /* Due to quantization, we will usually find that many of the input + * coefficients are zero, especially the AC terms. We can exploit this + * by short-circuiting the IDCT calculation for any column in which all + * the AC terms are zero. In that case each output is equal to the + * DC coefficient (with scale factor as needed). + * With typical images and quantization tables, half or more of the + * column DCT calculations can be simplified this way. + */ + + if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*2] == 0 && + inptr[DCTSIZE*3] == 0 && inptr[DCTSIZE*4] == 0 && + inptr[DCTSIZE*5] == 0 && inptr[DCTSIZE*6] == 0 && + inptr[DCTSIZE*7] == 0) { + /* AC terms all zero */ + int dcval = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]) << PASS1_BITS; + + wsptr[DCTSIZE*0] = dcval; + wsptr[DCTSIZE*1] = dcval; + wsptr[DCTSIZE*2] = dcval; + wsptr[DCTSIZE*3] = dcval; + wsptr[DCTSIZE*4] = dcval; + wsptr[DCTSIZE*5] = dcval; + wsptr[DCTSIZE*6] = dcval; + wsptr[DCTSIZE*7] = dcval; + + inptr++; /* advance pointers to next column */ + quantptr++; + wsptr++; + continue; + } + + /* Even part: reverse the even part of the forward DCT. */ + /* The rotator is sqrt(2)*c(-6). */ + + z2 = DEQUANTIZE(inptr[DCTSIZE*2], quantptr[DCTSIZE*2]); + z3 = DEQUANTIZE(inptr[DCTSIZE*6], quantptr[DCTSIZE*6]); + + z1 = MULTIPLY(z2 + z3, FIX_0_541196100); + tmp2 = z1 + MULTIPLY(z3, - FIX_1_847759065); + tmp3 = z1 + MULTIPLY(z2, FIX_0_765366865); + + z2 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]); + z3 = DEQUANTIZE(inptr[DCTSIZE*4], quantptr[DCTSIZE*4]); + + tmp0 = (z2 + z3) << CONST_BITS; + tmp1 = (z2 - z3) << CONST_BITS; + + tmp10 = tmp0 + tmp3; + tmp13 = tmp0 - tmp3; + tmp11 = tmp1 + tmp2; + tmp12 = tmp1 - tmp2; + + /* Odd part per figure 8; the matrix is unitary and hence its + * transpose is its inverse. i0..i3 are y7,y5,y3,y1 respectively. + */ + + tmp0 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]); + tmp1 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]); + tmp2 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]); + tmp3 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]); + + z1 = tmp0 + tmp3; + z2 = tmp1 + tmp2; + z3 = tmp0 + tmp2; + z4 = tmp1 + tmp3; + z5 = MULTIPLY(z3 + z4, FIX_1_175875602); /* sqrt(2) * c3 */ + + tmp0 = MULTIPLY(tmp0, FIX_0_298631336); /* sqrt(2) * (-c1+c3+c5-c7) */ + tmp1 = MULTIPLY(tmp1, FIX_2_053119869); /* sqrt(2) * ( c1+c3-c5+c7) */ + tmp2 = MULTIPLY(tmp2, FIX_3_072711026); /* sqrt(2) * ( c1+c3+c5-c7) */ + tmp3 = MULTIPLY(tmp3, FIX_1_501321110); /* sqrt(2) * ( c1+c3-c5-c7) */ + z1 = MULTIPLY(z1, - FIX_0_899976223); /* sqrt(2) * (c7-c3) */ + z2 = MULTIPLY(z2, - FIX_2_562915447); /* sqrt(2) * (-c1-c3) */ + z3 = MULTIPLY(z3, - FIX_1_961570560); /* sqrt(2) * (-c3-c5) */ + z4 = MULTIPLY(z4, - FIX_0_390180644); /* sqrt(2) * (c5-c3) */ + + z3 += z5; + z4 += z5; + + tmp0 += z1 + z3; + tmp1 += z2 + z4; + tmp2 += z2 + z3; + tmp3 += z1 + z4; + + /* Final output stage: inputs are tmp10..tmp13, tmp0..tmp3 */ + + wsptr[DCTSIZE*0] = (int) DESCALE(tmp10 + tmp3, CONST_BITS-PASS1_BITS); + wsptr[DCTSIZE*7] = (int) DESCALE(tmp10 - tmp3, CONST_BITS-PASS1_BITS); + wsptr[DCTSIZE*1] = (int) DESCALE(tmp11 + tmp2, CONST_BITS-PASS1_BITS); + wsptr[DCTSIZE*6] = (int) DESCALE(tmp11 - tmp2, CONST_BITS-PASS1_BITS); + wsptr[DCTSIZE*2] = (int) DESCALE(tmp12 + tmp1, CONST_BITS-PASS1_BITS); + wsptr[DCTSIZE*5] = (int) DESCALE(tmp12 - tmp1, CONST_BITS-PASS1_BITS); + wsptr[DCTSIZE*3] = (int) DESCALE(tmp13 + tmp0, CONST_BITS-PASS1_BITS); + wsptr[DCTSIZE*4] = (int) DESCALE(tmp13 - tmp0, CONST_BITS-PASS1_BITS); + + inptr++; /* advance pointers to next column */ + quantptr++; + wsptr++; + } + + /* Pass 2: process rows from work array, store into output array. */ + /* Note that we must descale the results by a factor of 8 == 2**3, */ + /* and also undo the PASS1_BITS scaling. */ + + wsptr = workspace; + for (ctr = 0; ctr < DCTSIZE; ctr++) { + outptr = output_buf[ctr] + output_col; + /* Rows of zeroes can be exploited in the same way as we did with columns. + * However, the column calculation has created many nonzero AC terms, so + * the simplification applies less often (typically 5% to 10% of the time). + * On machines with very fast multiplication, it's possible that the + * test takes more time than it's worth. In that case this section + * may be commented out. + */ + +#ifndef NO_ZERO_ROW_TEST + if (wsptr[1] == 0 && wsptr[2] == 0 && wsptr[3] == 0 && wsptr[4] == 0 && + wsptr[5] == 0 && wsptr[6] == 0 && wsptr[7] == 0) { + /* AC terms all zero */ + JSAMPLE dcval = range_limit[(int) DESCALE((long) wsptr[0], PASS1_BITS+3) + & RANGE_MASK]; + + outptr[0] = dcval; + outptr[1] = dcval; + outptr[2] = dcval; + outptr[3] = dcval; + outptr[4] = dcval; + outptr[5] = dcval; + outptr[6] = dcval; + outptr[7] = dcval; + + wsptr += DCTSIZE; /* advance pointer to next row */ + continue; + } +#endif + + /* Even part: reverse the even part of the forward DCT. */ + /* The rotator is sqrt(2)*c(-6). */ + + z2 = (long) wsptr[2]; + z3 = (long) wsptr[6]; + + z1 = MULTIPLY(z2 + z3, FIX_0_541196100); + tmp2 = z1 + MULTIPLY(z3, - FIX_1_847759065); + tmp3 = z1 + MULTIPLY(z2, FIX_0_765366865); + + tmp0 = ((long) wsptr[0] + (long) wsptr[4]) << CONST_BITS; + tmp1 = ((long) wsptr[0] - (long) wsptr[4]) << CONST_BITS; + + tmp10 = tmp0 + tmp3; + tmp13 = tmp0 - tmp3; + tmp11 = tmp1 + tmp2; + tmp12 = tmp1 - tmp2; + + /* Odd part per figure 8; the matrix is unitary and hence its + * transpose is its inverse. i0..i3 are y7,y5,y3,y1 respectively. + */ + + tmp0 = (long) wsptr[7]; + tmp1 = (long) wsptr[5]; + tmp2 = (long) wsptr[3]; + tmp3 = (long) wsptr[1]; + + z1 = tmp0 + tmp3; + z2 = tmp1 + tmp2; + z3 = tmp0 + tmp2; + z4 = tmp1 + tmp3; + z5 = MULTIPLY(z3 + z4, FIX_1_175875602); /* sqrt(2) * c3 */ + + tmp0 = MULTIPLY(tmp0, FIX_0_298631336); /* sqrt(2) * (-c1+c3+c5-c7) */ + tmp1 = MULTIPLY(tmp1, FIX_2_053119869); /* sqrt(2) * ( c1+c3-c5+c7) */ + tmp2 = MULTIPLY(tmp2, FIX_3_072711026); /* sqrt(2) * ( c1+c3+c5-c7) */ + tmp3 = MULTIPLY(tmp3, FIX_1_501321110); /* sqrt(2) * ( c1+c3-c5-c7) */ + z1 = MULTIPLY(z1, - FIX_0_899976223); /* sqrt(2) * (c7-c3) */ + z2 = MULTIPLY(z2, - FIX_2_562915447); /* sqrt(2) * (-c1-c3) */ + z3 = MULTIPLY(z3, - FIX_1_961570560); /* sqrt(2) * (-c3-c5) */ + z4 = MULTIPLY(z4, - FIX_0_390180644); /* sqrt(2) * (c5-c3) */ + + z3 += z5; + z4 += z5; + + tmp0 += z1 + z3; + tmp1 += z2 + z4; + tmp2 += z2 + z3; + tmp3 += z1 + z4; + + /* Final output stage: inputs are tmp10..tmp13, tmp0..tmp3 */ + + outptr[0] = range_limit[(int) DESCALE(tmp10 + tmp3, + CONST_BITS+PASS1_BITS+3) + & RANGE_MASK]; + outptr[7] = range_limit[(int) DESCALE(tmp10 - tmp3, + CONST_BITS+PASS1_BITS+3) + & RANGE_MASK]; + outptr[1] = range_limit[(int) DESCALE(tmp11 + tmp2, + CONST_BITS+PASS1_BITS+3) + & RANGE_MASK]; + outptr[6] = range_limit[(int) DESCALE(tmp11 - tmp2, + CONST_BITS+PASS1_BITS+3) + & RANGE_MASK]; + outptr[2] = range_limit[(int) DESCALE(tmp12 + tmp1, + CONST_BITS+PASS1_BITS+3) + & RANGE_MASK]; + outptr[5] = range_limit[(int) DESCALE(tmp12 - tmp1, + CONST_BITS+PASS1_BITS+3) + & RANGE_MASK]; + outptr[3] = range_limit[(int) DESCALE(tmp13 + tmp0, + CONST_BITS+PASS1_BITS+3) + & RANGE_MASK]; + outptr[4] = range_limit[(int) DESCALE(tmp13 - tmp0, + CONST_BITS+PASS1_BITS+3) + & RANGE_MASK]; + + wsptr += DCTSIZE; /* advance pointer to next row */ + } +} + +#endif /* DCT_ISLOW_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jidctred.cpp b/ml/dlib/dlib/external/libjpeg/jidctred.cpp new file mode 100644 index 000000000..fa442ac9a --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jidctred.cpp @@ -0,0 +1,398 @@ +/* + * jidctred.c + * + * Copyright (C) 1994-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains inverse-DCT routines that produce reduced-size output: + * either 4x4, 2x2, or 1x1 pixels from an 8x8 DCT block. + * + * The implementation is based on the Loeffler, Ligtenberg and Moschytz (LL&M) + * algorithm used in jidctint.c. We simply replace each 8-to-8 1-D IDCT step + * with an 8-to-4 step that produces the four averages of two adjacent outputs + * (or an 8-to-2 step producing two averages of four outputs, for 2x2 output). + * These steps were derived by computing the corresponding values at the end + * of the normal LL&M code, then simplifying as much as possible. + * + * 1x1 is trivial: just take the DC coefficient divided by 8. + * + * See jidctint.c for additional comments. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jdct.h" /* Private declarations for DCT subsystem */ + +#ifdef IDCT_SCALING_SUPPORTED + + +/* + * This module is specialized to the case DCTSIZE = 8. + */ + +#if DCTSIZE != 8 + Sorry, this code only copes with 8x8 DCTs. /* deliberate syntax err */ +#endif + + +/* Scaling is the same as in jidctint.c. */ + +#if BITS_IN_JSAMPLE == 8 +#define CONST_BITS 13 +#define PASS1_BITS 2 +#else +#define CONST_BITS 13 +#define PASS1_BITS 1 /* lose a little precision to avoid overflow */ +#endif + +/* Some C compilers fail to reduce "FIX(constant)" at compile time, thus + * causing a lot of useless floating-point operations at run time. + * To get around this we use the following pre-calculated constants. + * If you change CONST_BITS you may want to add appropriate values. + * (With a reasonable C compiler, you can just rely on the FIX() macro...) + */ + +#if CONST_BITS == 13 +#define FIX_0_211164243 ((long) 1730) /* FIX(0.211164243) */ +#define FIX_0_509795579 ((long) 4176) /* FIX(0.509795579) */ +#define FIX_0_601344887 ((long) 4926) /* FIX(0.601344887) */ +#define FIX_0_720959822 ((long) 5906) /* FIX(0.720959822) */ +#define FIX_0_765366865 ((long) 6270) /* FIX(0.765366865) */ +#define FIX_0_850430095 ((long) 6967) /* FIX(0.850430095) */ +#define FIX_0_899976223 ((long) 7373) /* FIX(0.899976223) */ +#define FIX_1_061594337 ((long) 8697) /* FIX(1.061594337) */ +#define FIX_1_272758580 ((long) 10426) /* FIX(1.272758580) */ +#define FIX_1_451774981 ((long) 11893) /* FIX(1.451774981) */ +#define FIX_1_847759065 ((long) 15137) /* FIX(1.847759065) */ +#define FIX_2_172734803 ((long) 17799) /* FIX(2.172734803) */ +#define FIX_2_562915447 ((long) 20995) /* FIX(2.562915447) */ +#define FIX_3_624509785 ((long) 29692) /* FIX(3.624509785) */ +#else +#define FIX_0_211164243 FIX(0.211164243) +#define FIX_0_509795579 FIX(0.509795579) +#define FIX_0_601344887 FIX(0.601344887) +#define FIX_0_720959822 FIX(0.720959822) +#define FIX_0_765366865 FIX(0.765366865) +#define FIX_0_850430095 FIX(0.850430095) +#define FIX_0_899976223 FIX(0.899976223) +#define FIX_1_061594337 FIX(1.061594337) +#define FIX_1_272758580 FIX(1.272758580) +#define FIX_1_451774981 FIX(1.451774981) +#define FIX_1_847759065 FIX(1.847759065) +#define FIX_2_172734803 FIX(2.172734803) +#define FIX_2_562915447 FIX(2.562915447) +#define FIX_3_624509785 FIX(3.624509785) +#endif + + +/* Multiply an long variable by an long constant to yield an long result. + * For 8-bit samples with the recommended scaling, all the variable + * and constant values involved are no more than 16 bits wide, so a + * 16x16->32 bit multiply can be used instead of a full 32x32 multiply. + * For 12-bit samples, a full 32-bit multiplication will be needed. + */ + +#if BITS_IN_JSAMPLE == 8 +#define MULTIPLY(var,const) MULTIPLY16C16(var,const) +#else +#define MULTIPLY(var,const) ((var) * (const)) +#endif + + +/* Dequantize a coefficient by multiplying it by the multiplier-table + * entry; produce an int result. In this module, both inputs and result + * are 16 bits or less, so either int or short multiply will work. + */ + +#define DEQUANTIZE(coef,quantval) (((ISLOW_MULT_TYPE) (coef)) * (quantval)) + + +/* + * Perform dequantization and inverse DCT on one block of coefficients, + * producing a reduced-size 4x4 output block. + */ + +GLOBAL(void) +jpeg_idct_4x4 (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, + JSAMPARRAY output_buf, JDIMENSION output_col) +{ + long tmp0, tmp2, tmp10, tmp12; + long z1, z2, z3, z4; + JCOEFPTR inptr; + ISLOW_MULT_TYPE * quantptr; + int * wsptr; + JSAMPROW outptr; + JSAMPLE *range_limit = IDCT_range_limit(cinfo); + int ctr; + int workspace[DCTSIZE*4]; /* buffers data between passes */ + SHIFT_TEMPS + + /* Pass 1: process columns from input, store into work array. */ + + inptr = coef_block; + quantptr = (ISLOW_MULT_TYPE *) compptr->dct_table; + wsptr = workspace; + for (ctr = DCTSIZE; ctr > 0; inptr++, quantptr++, wsptr++, ctr--) { + /* Don't bother to process column 4, because second pass won't use it */ + if (ctr == DCTSIZE-4) + continue; + if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*2] == 0 && + inptr[DCTSIZE*3] == 0 && inptr[DCTSIZE*5] == 0 && + inptr[DCTSIZE*6] == 0 && inptr[DCTSIZE*7] == 0) { + /* AC terms all zero; we need not examine term 4 for 4x4 output */ + int dcval = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]) << PASS1_BITS; + + wsptr[DCTSIZE*0] = dcval; + wsptr[DCTSIZE*1] = dcval; + wsptr[DCTSIZE*2] = dcval; + wsptr[DCTSIZE*3] = dcval; + + continue; + } + + /* Even part */ + + tmp0 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]); + tmp0 <<= (CONST_BITS+1); + + z2 = DEQUANTIZE(inptr[DCTSIZE*2], quantptr[DCTSIZE*2]); + z3 = DEQUANTIZE(inptr[DCTSIZE*6], quantptr[DCTSIZE*6]); + + tmp2 = MULTIPLY(z2, FIX_1_847759065) + MULTIPLY(z3, - FIX_0_765366865); + + tmp10 = tmp0 + tmp2; + tmp12 = tmp0 - tmp2; + + /* Odd part */ + + z1 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]); + z2 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]); + z3 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]); + z4 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]); + + tmp0 = MULTIPLY(z1, - FIX_0_211164243) /* sqrt(2) * (c3-c1) */ + + MULTIPLY(z2, FIX_1_451774981) /* sqrt(2) * (c3+c7) */ + + MULTIPLY(z3, - FIX_2_172734803) /* sqrt(2) * (-c1-c5) */ + + MULTIPLY(z4, FIX_1_061594337); /* sqrt(2) * (c5+c7) */ + + tmp2 = MULTIPLY(z1, - FIX_0_509795579) /* sqrt(2) * (c7-c5) */ + + MULTIPLY(z2, - FIX_0_601344887) /* sqrt(2) * (c5-c1) */ + + MULTIPLY(z3, FIX_0_899976223) /* sqrt(2) * (c3-c7) */ + + MULTIPLY(z4, FIX_2_562915447); /* sqrt(2) * (c1+c3) */ + + /* Final output stage */ + + wsptr[DCTSIZE*0] = (int) DESCALE(tmp10 + tmp2, CONST_BITS-PASS1_BITS+1); + wsptr[DCTSIZE*3] = (int) DESCALE(tmp10 - tmp2, CONST_BITS-PASS1_BITS+1); + wsptr[DCTSIZE*1] = (int) DESCALE(tmp12 + tmp0, CONST_BITS-PASS1_BITS+1); + wsptr[DCTSIZE*2] = (int) DESCALE(tmp12 - tmp0, CONST_BITS-PASS1_BITS+1); + } + + /* Pass 2: process 4 rows from work array, store into output array. */ + + wsptr = workspace; + for (ctr = 0; ctr < 4; ctr++) { + outptr = output_buf[ctr] + output_col; + /* It's not clear whether a zero row test is worthwhile here ... */ + +#ifndef NO_ZERO_ROW_TEST + if (wsptr[1] == 0 && wsptr[2] == 0 && wsptr[3] == 0 && + wsptr[5] == 0 && wsptr[6] == 0 && wsptr[7] == 0) { + /* AC terms all zero */ + JSAMPLE dcval = range_limit[(int) DESCALE((long) wsptr[0], PASS1_BITS+3) + & RANGE_MASK]; + + outptr[0] = dcval; + outptr[1] = dcval; + outptr[2] = dcval; + outptr[3] = dcval; + + wsptr += DCTSIZE; /* advance pointer to next row */ + continue; + } +#endif + + /* Even part */ + + tmp0 = ((long) wsptr[0]) << (CONST_BITS+1); + + tmp2 = MULTIPLY((long) wsptr[2], FIX_1_847759065) + + MULTIPLY((long) wsptr[6], - FIX_0_765366865); + + tmp10 = tmp0 + tmp2; + tmp12 = tmp0 - tmp2; + + /* Odd part */ + + z1 = (long) wsptr[7]; + z2 = (long) wsptr[5]; + z3 = (long) wsptr[3]; + z4 = (long) wsptr[1]; + + tmp0 = MULTIPLY(z1, - FIX_0_211164243) /* sqrt(2) * (c3-c1) */ + + MULTIPLY(z2, FIX_1_451774981) /* sqrt(2) * (c3+c7) */ + + MULTIPLY(z3, - FIX_2_172734803) /* sqrt(2) * (-c1-c5) */ + + MULTIPLY(z4, FIX_1_061594337); /* sqrt(2) * (c5+c7) */ + + tmp2 = MULTIPLY(z1, - FIX_0_509795579) /* sqrt(2) * (c7-c5) */ + + MULTIPLY(z2, - FIX_0_601344887) /* sqrt(2) * (c5-c1) */ + + MULTIPLY(z3, FIX_0_899976223) /* sqrt(2) * (c3-c7) */ + + MULTIPLY(z4, FIX_2_562915447); /* sqrt(2) * (c1+c3) */ + + /* Final output stage */ + + outptr[0] = range_limit[(int) DESCALE(tmp10 + tmp2, + CONST_BITS+PASS1_BITS+3+1) + & RANGE_MASK]; + outptr[3] = range_limit[(int) DESCALE(tmp10 - tmp2, + CONST_BITS+PASS1_BITS+3+1) + & RANGE_MASK]; + outptr[1] = range_limit[(int) DESCALE(tmp12 + tmp0, + CONST_BITS+PASS1_BITS+3+1) + & RANGE_MASK]; + outptr[2] = range_limit[(int) DESCALE(tmp12 - tmp0, + CONST_BITS+PASS1_BITS+3+1) + & RANGE_MASK]; + + wsptr += DCTSIZE; /* advance pointer to next row */ + } +} + + +/* + * Perform dequantization and inverse DCT on one block of coefficients, + * producing a reduced-size 2x2 output block. + */ + +GLOBAL(void) +jpeg_idct_2x2 (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, + JSAMPARRAY output_buf, JDIMENSION output_col) +{ + long tmp0, tmp10, z1; + JCOEFPTR inptr; + ISLOW_MULT_TYPE * quantptr; + int * wsptr; + JSAMPROW outptr; + JSAMPLE *range_limit = IDCT_range_limit(cinfo); + int ctr; + int workspace[DCTSIZE*2]; /* buffers data between passes */ + SHIFT_TEMPS + + /* Pass 1: process columns from input, store into work array. */ + + inptr = coef_block; + quantptr = (ISLOW_MULT_TYPE *) compptr->dct_table; + wsptr = workspace; + for (ctr = DCTSIZE; ctr > 0; inptr++, quantptr++, wsptr++, ctr--) { + /* Don't bother to process columns 2,4,6 */ + if (ctr == DCTSIZE-2 || ctr == DCTSIZE-4 || ctr == DCTSIZE-6) + continue; + if (inptr[DCTSIZE*1] == 0 && inptr[DCTSIZE*3] == 0 && + inptr[DCTSIZE*5] == 0 && inptr[DCTSIZE*7] == 0) { + /* AC terms all zero; we need not examine terms 2,4,6 for 2x2 output */ + int dcval = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]) << PASS1_BITS; + + wsptr[DCTSIZE*0] = dcval; + wsptr[DCTSIZE*1] = dcval; + + continue; + } + + /* Even part */ + + z1 = DEQUANTIZE(inptr[DCTSIZE*0], quantptr[DCTSIZE*0]); + tmp10 = z1 << (CONST_BITS+2); + + /* Odd part */ + + z1 = DEQUANTIZE(inptr[DCTSIZE*7], quantptr[DCTSIZE*7]); + tmp0 = MULTIPLY(z1, - FIX_0_720959822); /* sqrt(2) * (c7-c5+c3-c1) */ + z1 = DEQUANTIZE(inptr[DCTSIZE*5], quantptr[DCTSIZE*5]); + tmp0 += MULTIPLY(z1, FIX_0_850430095); /* sqrt(2) * (-c1+c3+c5+c7) */ + z1 = DEQUANTIZE(inptr[DCTSIZE*3], quantptr[DCTSIZE*3]); + tmp0 += MULTIPLY(z1, - FIX_1_272758580); /* sqrt(2) * (-c1+c3-c5-c7) */ + z1 = DEQUANTIZE(inptr[DCTSIZE*1], quantptr[DCTSIZE*1]); + tmp0 += MULTIPLY(z1, FIX_3_624509785); /* sqrt(2) * (c1+c3+c5+c7) */ + + /* Final output stage */ + + wsptr[DCTSIZE*0] = (int) DESCALE(tmp10 + tmp0, CONST_BITS-PASS1_BITS+2); + wsptr[DCTSIZE*1] = (int) DESCALE(tmp10 - tmp0, CONST_BITS-PASS1_BITS+2); + } + + /* Pass 2: process 2 rows from work array, store into output array. */ + + wsptr = workspace; + for (ctr = 0; ctr < 2; ctr++) { + outptr = output_buf[ctr] + output_col; + /* It's not clear whether a zero row test is worthwhile here ... */ + +#ifndef NO_ZERO_ROW_TEST + if (wsptr[1] == 0 && wsptr[3] == 0 && wsptr[5] == 0 && wsptr[7] == 0) { + /* AC terms all zero */ + JSAMPLE dcval = range_limit[(int) DESCALE((long) wsptr[0], PASS1_BITS+3) + & RANGE_MASK]; + + outptr[0] = dcval; + outptr[1] = dcval; + + wsptr += DCTSIZE; /* advance pointer to next row */ + continue; + } +#endif + + /* Even part */ + + tmp10 = ((long) wsptr[0]) << (CONST_BITS+2); + + /* Odd part */ + + tmp0 = MULTIPLY((long) wsptr[7], - FIX_0_720959822) /* sqrt(2) * (c7-c5+c3-c1) */ + + MULTIPLY((long) wsptr[5], FIX_0_850430095) /* sqrt(2) * (-c1+c3+c5+c7) */ + + MULTIPLY((long) wsptr[3], - FIX_1_272758580) /* sqrt(2) * (-c1+c3-c5-c7) */ + + MULTIPLY((long) wsptr[1], FIX_3_624509785); /* sqrt(2) * (c1+c3+c5+c7) */ + + /* Final output stage */ + + outptr[0] = range_limit[(int) DESCALE(tmp10 + tmp0, + CONST_BITS+PASS1_BITS+3+2) + & RANGE_MASK]; + outptr[1] = range_limit[(int) DESCALE(tmp10 - tmp0, + CONST_BITS+PASS1_BITS+3+2) + & RANGE_MASK]; + + wsptr += DCTSIZE; /* advance pointer to next row */ + } +} + + +/* + * Perform dequantization and inverse DCT on one block of coefficients, + * producing a reduced-size 1x1 output block. + */ + +GLOBAL(void) +jpeg_idct_1x1 (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, + JSAMPARRAY output_buf, JDIMENSION output_col) +{ + int dcval; + ISLOW_MULT_TYPE * quantptr; + JSAMPLE *range_limit = IDCT_range_limit(cinfo); + SHIFT_TEMPS + + /* We hardly need an inverse DCT routine for this: just take the + * average pixel value, which is one-eighth of the DC coefficient. + */ + quantptr = (ISLOW_MULT_TYPE *) compptr->dct_table; + dcval = DEQUANTIZE(coef_block[0], quantptr[0]); + dcval = (int) DESCALE((long) dcval, 3); + + output_buf[0][output_col] = range_limit[dcval & RANGE_MASK]; +} + +#endif /* IDCT_SCALING_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jinclude.h b/ml/dlib/dlib/external/libjpeg/jinclude.h new file mode 100644 index 000000000..0a4f15146 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jinclude.h @@ -0,0 +1,91 @@ +/* + * jinclude.h + * + * Copyright (C) 1991-1994, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file exists to provide a single place to fix any problems with + * including the wrong system include files. (Common problems are taken + * care of by the standard jconfig symbols, but on really weird systems + * you may have to edit this file.) + * + * NOTE: this file is NOT intended to be included by applications using the + * JPEG library. Most applications need only include jpeglib.h. + */ + + +/* Include auto-config file to find out which system include files we need. */ + +#include "jconfig.h" /* auto configuration options */ +#define JCONFIG_INCLUDED /* so that jpeglib.h doesn't do it again */ + +/* + * We need the NULL macro and size_t typedef. + * On an ANSI-conforming system it is sufficient to include . + * Otherwise, we get them from or ; we may have to + * pull in as well. + * Note that the core JPEG library does not require ; + * only the default error handler and data source/destination modules do. + * But we must pull it in because of the references to FILE in jpeglib.h. + * You can remove those references if you want to compile without . + */ + +#ifdef HAVE_STDDEF_H +#include +#endif + +#ifdef HAVE_STDLIB_H +#include +#endif + +#ifdef NEED_SYS_TYPES_H +#include +#endif + +#include + +/* + * We need memory copying and zeroing functions, plus strncpy(). + * ANSI and System V implementations declare these in . + * BSD doesn't have the mem() functions, but it does have bcopy()/bzero(). + * Some systems may declare memset and memcpy in . + * + * NOTE: we assume the size parameters to these functions are of type size_t. + * Change the casts in these macros if not! + */ + +#ifdef NEED_BSD_STRINGS + +#include +#define MEMZERO(target,size) bzero((void *)(target), (size_t)(size)) +#define MEMCOPY(dest,src,size) bcopy((const void *)(src), (void *)(dest), (size_t)(size)) + +#else /* not BSD, assume ANSI/SysV string lib */ + +#include +#define MEMZERO(target,size) memset((void *)(target), 0, (size_t)(size)) +#define MEMCOPY(dest,src,size) memcpy((void *)(dest), (const void *)(src), (size_t)(size)) + +#endif + +/* + * In ANSI C, and indeed any rational implementation, size_t is also the + * type returned by sizeof(). However, it seems there are some irrational + * implementations out there, in which sizeof() returns an int even though + * size_t is defined as long or unsigned long. To ensure consistent results + * we always use this SIZEOF() macro in place of using sizeof() directly. + */ + +#define SIZEOF(object) ((size_t) sizeof(object)) + +/* + * The modules that use fread() and fwrite() always invoke them through + * these macros. On some systems you may need to twiddle the argument casts. + * CAUTION: argument order is different from underlying functions! + */ + +#define JFREAD(file,buf,sizeofbuf) \ + ((size_t) fread((void *) (buf), (size_t) 1, (size_t) (sizeofbuf), (file))) +#define JFWRITE(file,buf,sizeofbuf) \ + ((size_t) fwrite((const void *) (buf), (size_t) 1, (size_t) (sizeofbuf), (file))) diff --git a/ml/dlib/dlib/external/libjpeg/jmemmgr.cpp b/ml/dlib/dlib/external/libjpeg/jmemmgr.cpp new file mode 100644 index 000000000..3a2e61955 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jmemmgr.cpp @@ -0,0 +1,1118 @@ +/* + * jmemmgr.c + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains the JPEG system-independent memory management + * routines. This code is usable across a wide variety of machines; most + * of the system dependencies have been isolated in a separate file. + * The major functions provided here are: + * * pool-based allocation and freeing of memory; + * * policy decisions about how to divide available memory among the + * virtual arrays; + * * control logic for swapping virtual arrays between main memory and + * backing storage. + * The separate system-dependent file provides the actual backing-storage + * access code, and it contains the policy decision about how much total + * main memory to use. + * This file is system-dependent in the sense that some of its functions + * are unnecessary in some systems. For example, if there is enough virtual + * memory so that backing storage will never be used, much of the virtual + * array control logic could be removed. (Of course, if you have that much + * memory then you shouldn't care about a little bit of unused code...) + */ + +#define JPEG_INTERNALS +#define AM_MEMORY_MANAGER /* we define jvirt_Xarray_control structs */ +#include "jinclude.h" +#include "jpeglib.h" +#include "jmemsys.h" /* import the system-dependent declarations */ + +#ifndef NO_GETENV +#ifndef HAVE_STDLIB_H /* should declare getenv() */ +extern char * getenv JPP((const char * name)); +#endif +#endif + + +/* + * Some important notes: + * The allocation routines provided here must never return NULL. + * They should exit to error_exit if unsuccessful. + * + * It's not a good idea to try to merge the sarray and barray routines, + * even though they are textually almost the same, because samples are + * usually stored as bytes while coefficients are shorts or ints. Thus, + * in machines where byte pointers have a different representation from + * word pointers, the resulting machine code could not be the same. + */ + + +/* + * Many machines require storage alignment: longs must start on 4-byte + * boundaries, doubles on 8-byte boundaries, etc. On such machines, malloc() + * always returns pointers that are multiples of the worst-case alignment + * requirement, and we had better do so too. + * There isn't any really portable way to determine the worst-case alignment + * requirement. This module assumes that the alignment requirement is + * multiples of sizeof(ALIGN_TYPE). + * By default, we define ALIGN_TYPE as double. This is necessary on some + * workstations (where doubles really do need 8-byte alignment) and will work + * fine on nearly everything. If your machine has lesser alignment needs, + * you can save a few bytes by making ALIGN_TYPE smaller. + * The only place I know of where this will NOT work is certain Macintosh + * 680x0 compilers that define double as a 10-byte IEEE extended float. + * Doing 10-byte alignment is counterproductive because longwords won't be + * aligned well. Put "#define ALIGN_TYPE long" in jconfig.h if you have + * such a compiler. + */ + +#ifndef ALIGN_TYPE /* so can override from jconfig.h */ +#define ALIGN_TYPE double +#endif + + +/* + * We allocate objects from "pools", where each pool is gotten with a single + * request to jpeg_get_small() or jpeg_get_large(). There is no per-object + * overhead within a pool, except for alignment padding. Each pool has a + * header with a link to the next pool of the same class. + * Small and large pool headers are identical except that the latter's + * link pointer must be FAR on 80x86 machines. + * Notice that the "real" header fields are union'ed with a dummy ALIGN_TYPE + * field. This forces the compiler to make SIZEOF(small_pool_hdr) a multiple + * of the alignment requirement of ALIGN_TYPE. + */ + +typedef union small_pool_struct * small_pool_ptr; + +typedef union small_pool_struct { + struct { + small_pool_ptr next; /* next in list of pools */ + size_t bytes_used; /* how many bytes already used within pool */ + size_t bytes_left; /* bytes still available in this pool */ + } hdr; + ALIGN_TYPE dummy; /* included in union to ensure alignment */ +} small_pool_hdr; + +typedef union large_pool_struct FAR * large_pool_ptr; + +typedef union large_pool_struct { + struct { + large_pool_ptr next; /* next in list of pools */ + size_t bytes_used; /* how many bytes already used within pool */ + size_t bytes_left; /* bytes still available in this pool */ + } hdr; + ALIGN_TYPE dummy; /* included in union to ensure alignment */ +} large_pool_hdr; + + +/* + * Here is the full definition of a memory manager object. + */ + +typedef struct { + struct jpeg_memory_mgr pub; /* public fields */ + + /* Each pool identifier (lifetime class) names a linked list of pools. */ + small_pool_ptr small_list[JPOOL_NUMPOOLS]; + large_pool_ptr large_list[JPOOL_NUMPOOLS]; + + /* Since we only have one lifetime class of virtual arrays, only one + * linked list is necessary (for each datatype). Note that the virtual + * array control blocks being linked together are actually stored somewhere + * in the small-pool list. + */ + jvirt_sarray_ptr virt_sarray_list; + jvirt_barray_ptr virt_barray_list; + + /* This counts total space obtained from jpeg_get_small/large */ + long total_space_allocated; + + /* alloc_sarray and alloc_barray set this value for use by virtual + * array routines. + */ + JDIMENSION last_rowsperchunk; /* from most recent alloc_sarray/barray */ +} my_memory_mgr; + +typedef my_memory_mgr * my_mem_ptr; + + +/* + * The control blocks for virtual arrays. + * Note that these blocks are allocated in the "small" pool area. + * System-dependent info for the associated backing store (if any) is hidden + * inside the backing_store_info struct. + */ + +struct jvirt_sarray_control { + JSAMPARRAY mem_buffer; /* => the in-memory buffer */ + JDIMENSION rows_in_array; /* total virtual array height */ + JDIMENSION samplesperrow; /* width of array (and of memory buffer) */ + JDIMENSION maxaccess; /* max rows accessed by access_virt_sarray */ + JDIMENSION rows_in_mem; /* height of memory buffer */ + JDIMENSION rowsperchunk; /* allocation chunk size in mem_buffer */ + JDIMENSION cur_start_row; /* first logical row # in the buffer */ + JDIMENSION first_undef_row; /* row # of first uninitialized row */ + int pre_zero; /* pre-zero mode requested? */ + int dirty; /* do current buffer contents need written? */ + int b_s_open; /* is backing-store data valid? */ + jvirt_sarray_ptr next; /* link to next virtual sarray control block */ + backing_store_info b_s_info; /* System-dependent control info */ +}; + +struct jvirt_barray_control { + JBLOCKARRAY mem_buffer; /* => the in-memory buffer */ + JDIMENSION rows_in_array; /* total virtual array height */ + JDIMENSION blocksperrow; /* width of array (and of memory buffer) */ + JDIMENSION maxaccess; /* max rows accessed by access_virt_barray */ + JDIMENSION rows_in_mem; /* height of memory buffer */ + JDIMENSION rowsperchunk; /* allocation chunk size in mem_buffer */ + JDIMENSION cur_start_row; /* first logical row # in the buffer */ + JDIMENSION first_undef_row; /* row # of first uninitialized row */ + int pre_zero; /* pre-zero mode requested? */ + int dirty; /* do current buffer contents need written? */ + int b_s_open; /* is backing-store data valid? */ + jvirt_barray_ptr next; /* link to next virtual barray control block */ + backing_store_info b_s_info; /* System-dependent control info */ +}; + + +#ifdef MEM_STATS /* optional extra stuff for statistics */ + +LOCAL(void) +print_mem_stats (j_common_ptr cinfo, int pool_id) +{ + my_mem_ptr mem = (my_mem_ptr) cinfo->mem; + small_pool_ptr shdr_ptr; + large_pool_ptr lhdr_ptr; + + /* Since this is only a debugging stub, we can cheat a little by using + * fprintf directly rather than going through the trace message code. + * This is helpful because message parm array can't handle longs. + */ + fprintf(stderr, "Freeing pool %d, total space = %ld\n", + pool_id, mem->total_space_allocated); + + for (lhdr_ptr = mem->large_list[pool_id]; lhdr_ptr != NULL; + lhdr_ptr = lhdr_ptr->hdr.next) { + fprintf(stderr, " Large chunk used %ld\n", + (long) lhdr_ptr->hdr.bytes_used); + } + + for (shdr_ptr = mem->small_list[pool_id]; shdr_ptr != NULL; + shdr_ptr = shdr_ptr->hdr.next) { + fprintf(stderr, " Small chunk used %ld free %ld\n", + (long) shdr_ptr->hdr.bytes_used, + (long) shdr_ptr->hdr.bytes_left); + } +} + +#endif /* MEM_STATS */ + + +LOCAL(void) +out_of_memory (j_common_ptr cinfo, int which) +/* Report an out-of-memory error and stop execution */ +/* If we compiled MEM_STATS support, report alloc requests before dying */ +{ +#ifdef MEM_STATS + cinfo->err->trace_level = 2; /* force self_destruct to report stats */ +#endif + ERREXIT1(cinfo, JERR_OUT_OF_MEMORY, which); +} + + +/* + * Allocation of "small" objects. + * + * For these, we use pooled storage. When a new pool must be created, + * we try to get enough space for the current request plus a "slop" factor, + * where the slop will be the amount of leftover space in the new pool. + * The speed vs. space tradeoff is largely determined by the slop values. + * A different slop value is provided for each pool class (lifetime), + * and we also distinguish the first pool of a class from later ones. + * NOTE: the values given work fairly well on both 16- and 32-bit-int + * machines, but may be too small if longs are 64 bits or more. + */ + +static const size_t first_pool_slop[JPOOL_NUMPOOLS] = +{ + 1600, /* first PERMANENT pool */ + 16000 /* first IMAGE pool */ +}; + +static const size_t extra_pool_slop[JPOOL_NUMPOOLS] = +{ + 0, /* additional PERMANENT pools */ + 5000 /* additional IMAGE pools */ +}; + +#define MIN_SLOP 50 /* greater than 0 to avoid futile looping */ + + +METHODDEF(void *) +alloc_small (j_common_ptr cinfo, int pool_id, size_t sizeofobject) +/* Allocate a "small" object */ +{ + my_mem_ptr mem = (my_mem_ptr) cinfo->mem; + small_pool_ptr hdr_ptr, prev_hdr_ptr; + char * data_ptr; + size_t odd_bytes, min_request, slop; + + /* Check for unsatisfiable request (do now to ensure no overflow below) */ + if (sizeofobject > (size_t) (MAX_ALLOC_CHUNK-SIZEOF(small_pool_hdr))) + out_of_memory(cinfo, 1); /* request exceeds malloc's ability */ + + /* Round up the requested size to a multiple of SIZEOF(ALIGN_TYPE) */ + odd_bytes = sizeofobject % SIZEOF(ALIGN_TYPE); + if (odd_bytes > 0) + sizeofobject += SIZEOF(ALIGN_TYPE) - odd_bytes; + + /* See if space is available in any existing pool */ + if (pool_id < 0 || pool_id >= JPOOL_NUMPOOLS) + ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */ + prev_hdr_ptr = NULL; + hdr_ptr = mem->small_list[pool_id]; + while (hdr_ptr != NULL) { + if (hdr_ptr->hdr.bytes_left >= sizeofobject) + break; /* found pool with enough space */ + prev_hdr_ptr = hdr_ptr; + hdr_ptr = hdr_ptr->hdr.next; + } + + /* Time to make a new pool? */ + if (hdr_ptr == NULL) { + /* min_request is what we need now, slop is what will be leftover */ + min_request = sizeofobject + SIZEOF(small_pool_hdr); + if (prev_hdr_ptr == NULL) /* first pool in class? */ + slop = first_pool_slop[pool_id]; + else + slop = extra_pool_slop[pool_id]; + /* Don't ask for more than MAX_ALLOC_CHUNK */ + if (slop > (size_t) (MAX_ALLOC_CHUNK-min_request)) + slop = (size_t) (MAX_ALLOC_CHUNK-min_request); + /* Try to get space, if fail reduce slop and try again */ + for (;;) { + hdr_ptr = (small_pool_ptr) jpeg_get_small(cinfo, min_request + slop); + if (hdr_ptr != NULL) + break; + slop /= 2; + if (slop < MIN_SLOP) /* give up when it gets real small */ + out_of_memory(cinfo, 2); /* jpeg_get_small failed */ + } + mem->total_space_allocated += (long)(min_request + slop); + /* Success, initialize the new pool header and add to end of list */ + hdr_ptr->hdr.next = NULL; + hdr_ptr->hdr.bytes_used = 0; + hdr_ptr->hdr.bytes_left = sizeofobject + slop; + if (prev_hdr_ptr == NULL) /* first pool in class? */ + mem->small_list[pool_id] = hdr_ptr; + else + prev_hdr_ptr->hdr.next = hdr_ptr; + } + + /* OK, allocate the object from the current pool */ + data_ptr = (char *) (hdr_ptr + 1); /* point to first data byte in pool */ + data_ptr += hdr_ptr->hdr.bytes_used; /* point to place for object */ + hdr_ptr->hdr.bytes_used += sizeofobject; + hdr_ptr->hdr.bytes_left -= sizeofobject; + + return (void *) data_ptr; +} + + +/* + * Allocation of "large" objects. + * + * The external semantics of these are the same as "small" objects, + * except that FAR pointers are used on 80x86. However the pool + * management heuristics are quite different. We assume that each + * request is large enough that it may as well be passed directly to + * jpeg_get_large; the pool management just links everything together + * so that we can free it all on demand. + * Note: the major use of "large" objects is in JSAMPARRAY and JBLOCKARRAY + * structures. The routines that create these structures (see below) + * deliberately bunch rows together to ensure a large request size. + */ + +METHODDEF(void FAR *) +alloc_large (j_common_ptr cinfo, int pool_id, size_t sizeofobject) +/* Allocate a "large" object */ +{ + my_mem_ptr mem = (my_mem_ptr) cinfo->mem; + large_pool_ptr hdr_ptr; + size_t odd_bytes; + + /* Check for unsatisfiable request (do now to ensure no overflow below) */ + if (sizeofobject > (size_t) (MAX_ALLOC_CHUNK-SIZEOF(large_pool_hdr))) + out_of_memory(cinfo, 3); /* request exceeds malloc's ability */ + + /* Round up the requested size to a multiple of SIZEOF(ALIGN_TYPE) */ + odd_bytes = sizeofobject % SIZEOF(ALIGN_TYPE); + if (odd_bytes > 0) + sizeofobject += SIZEOF(ALIGN_TYPE) - odd_bytes; + + /* Always make a new pool */ + if (pool_id < 0 || pool_id >= JPOOL_NUMPOOLS) + ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */ + + hdr_ptr = (large_pool_ptr) jpeg_get_large(cinfo, sizeofobject + + SIZEOF(large_pool_hdr)); + if (hdr_ptr == NULL) + out_of_memory(cinfo, 4); /* jpeg_get_large failed */ + mem->total_space_allocated += (long)(sizeofobject + SIZEOF(large_pool_hdr)); + + /* Success, initialize the new pool header and add to list */ + hdr_ptr->hdr.next = mem->large_list[pool_id]; + /* We maintain space counts in each pool header for statistical purposes, + * even though they are not needed for allocation. + */ + hdr_ptr->hdr.bytes_used = sizeofobject; + hdr_ptr->hdr.bytes_left = 0; + mem->large_list[pool_id] = hdr_ptr; + + return (void FAR *) (hdr_ptr + 1); /* point to first data byte in pool */ +} + + +/* + * Creation of 2-D sample arrays. + * The pointers are in near heap, the samples themselves in FAR heap. + * + * To minimize allocation overhead and to allow I/O of large contiguous + * blocks, we allocate the sample rows in groups of as many rows as possible + * without exceeding MAX_ALLOC_CHUNK total bytes per allocation request. + * NB: the virtual array control routines, later in this file, know about + * this chunking of rows. The rowsperchunk value is left in the mem manager + * object so that it can be saved away if this sarray is the workspace for + * a virtual array. + */ + +METHODDEF(JSAMPARRAY) +alloc_sarray (j_common_ptr cinfo, int pool_id, + JDIMENSION samplesperrow, JDIMENSION numrows) +/* Allocate a 2-D sample array */ +{ + my_mem_ptr mem = (my_mem_ptr) cinfo->mem; + JSAMPARRAY result; + JSAMPROW workspace; + JDIMENSION rowsperchunk, currow, i; + long ltemp; + + /* Calculate max # of rows allowed in one allocation chunk */ + ltemp = (MAX_ALLOC_CHUNK-SIZEOF(large_pool_hdr)) / + ((long) samplesperrow * SIZEOF(JSAMPLE)); + if (ltemp <= 0) + ERREXIT(cinfo, JERR_WIDTH_OVERFLOW); + if (ltemp < (long) numrows) + rowsperchunk = (JDIMENSION) ltemp; + else + rowsperchunk = numrows; + mem->last_rowsperchunk = rowsperchunk; + + /* Get space for row pointers (small object) */ + result = (JSAMPARRAY) alloc_small(cinfo, pool_id, + (size_t) (numrows * SIZEOF(JSAMPROW))); + + /* Get the rows themselves (large objects) */ + currow = 0; + while (currow < numrows) { + rowsperchunk = MIN(rowsperchunk, numrows - currow); + workspace = (JSAMPROW) alloc_large(cinfo, pool_id, + (size_t) ((size_t) rowsperchunk * (size_t) samplesperrow + * SIZEOF(JSAMPLE))); + for (i = rowsperchunk; i > 0; i--) { + result[currow++] = workspace; + workspace += samplesperrow; + } + } + + return result; +} + + +/* + * Creation of 2-D coefficient-block arrays. + * This is essentially the same as the code for sample arrays, above. + */ + +METHODDEF(JBLOCKARRAY) +alloc_barray (j_common_ptr cinfo, int pool_id, + JDIMENSION blocksperrow, JDIMENSION numrows) +/* Allocate a 2-D coefficient-block array */ +{ + my_mem_ptr mem = (my_mem_ptr) cinfo->mem; + JBLOCKARRAY result; + JBLOCKROW workspace; + JDIMENSION rowsperchunk, currow, i; + long ltemp; + + /* Calculate max # of rows allowed in one allocation chunk */ + ltemp = (MAX_ALLOC_CHUNK-SIZEOF(large_pool_hdr)) / + ((long) blocksperrow * SIZEOF(JBLOCK)); + if (ltemp <= 0) + ERREXIT(cinfo, JERR_WIDTH_OVERFLOW); + if (ltemp < (long) numrows) + rowsperchunk = (JDIMENSION) ltemp; + else + rowsperchunk = numrows; + mem->last_rowsperchunk = rowsperchunk; + + /* Get space for row pointers (small object) */ + result = (JBLOCKARRAY) alloc_small(cinfo, pool_id, + (size_t) (numrows * SIZEOF(JBLOCKROW))); + + /* Get the rows themselves (large objects) */ + currow = 0; + while (currow < numrows) { + rowsperchunk = MIN(rowsperchunk, numrows - currow); + workspace = (JBLOCKROW) alloc_large(cinfo, pool_id, + (size_t) ((size_t) rowsperchunk * (size_t) blocksperrow + * SIZEOF(JBLOCK))); + for (i = rowsperchunk; i > 0; i--) { + result[currow++] = workspace; + workspace += blocksperrow; + } + } + + return result; +} + + +/* + * About virtual array management: + * + * The above "normal" array routines are only used to allocate strip buffers + * (as wide as the image, but just a few rows high). Full-image-sized buffers + * are handled as "virtual" arrays. The array is still accessed a strip at a + * time, but the memory manager must save the whole array for repeated + * accesses. The intended implementation is that there is a strip buffer in + * memory (as high as is possible given the desired memory limit), plus a + * backing file that holds the rest of the array. + * + * The request_virt_array routines are told the total size of the image and + * the maximum number of rows that will be accessed at once. The in-memory + * buffer must be at least as large as the maxaccess value. + * + * The request routines create control blocks but not the in-memory buffers. + * That is postponed until realize_virt_arrays is called. At that time the + * total amount of space needed is known (approximately, anyway), so free + * memory can be divided up fairly. + * + * The access_virt_array routines are responsible for making a specific strip + * area accessible (after reading or writing the backing file, if necessary). + * Note that the access routines are told whether the caller intends to modify + * the accessed strip; during a read-only pass this saves having to rewrite + * data to disk. The access routines are also responsible for pre-zeroing + * any newly accessed rows, if pre-zeroing was requested. + * + * In current usage, the access requests are usually for nonoverlapping + * strips; that is, successive access start_row numbers differ by exactly + * num_rows = maxaccess. This means we can get good performance with simple + * buffer dump/reload logic, by making the in-memory buffer be a multiple + * of the access height; then there will never be accesses across bufferload + * boundaries. The code will still work with overlapping access requests, + * but it doesn't handle bufferload overlaps very efficiently. + */ + + +METHODDEF(jvirt_sarray_ptr) +request_virt_sarray (j_common_ptr cinfo, int pool_id, int pre_zero, + JDIMENSION samplesperrow, JDIMENSION numrows, + JDIMENSION maxaccess) +/* Request a virtual 2-D sample array */ +{ + my_mem_ptr mem = (my_mem_ptr) cinfo->mem; + jvirt_sarray_ptr result; + + /* Only IMAGE-lifetime virtual arrays are currently supported */ + if (pool_id != JPOOL_IMAGE) + ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */ + + /* get control block */ + result = (jvirt_sarray_ptr) alloc_small(cinfo, pool_id, + SIZEOF(struct jvirt_sarray_control)); + + result->mem_buffer = NULL; /* marks array not yet realized */ + result->rows_in_array = numrows; + result->samplesperrow = samplesperrow; + result->maxaccess = maxaccess; + result->pre_zero = pre_zero; + result->b_s_open = FALSE; /* no associated backing-store object */ + result->next = mem->virt_sarray_list; /* add to list of virtual arrays */ + mem->virt_sarray_list = result; + + return result; +} + + +METHODDEF(jvirt_barray_ptr) +request_virt_barray (j_common_ptr cinfo, int pool_id, int pre_zero, + JDIMENSION blocksperrow, JDIMENSION numrows, + JDIMENSION maxaccess) +/* Request a virtual 2-D coefficient-block array */ +{ + my_mem_ptr mem = (my_mem_ptr) cinfo->mem; + jvirt_barray_ptr result; + + /* Only IMAGE-lifetime virtual arrays are currently supported */ + if (pool_id != JPOOL_IMAGE) + ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */ + + /* get control block */ + result = (jvirt_barray_ptr) alloc_small(cinfo, pool_id, + SIZEOF(struct jvirt_barray_control)); + + result->mem_buffer = NULL; /* marks array not yet realized */ + result->rows_in_array = numrows; + result->blocksperrow = blocksperrow; + result->maxaccess = maxaccess; + result->pre_zero = pre_zero; + result->b_s_open = FALSE; /* no associated backing-store object */ + result->next = mem->virt_barray_list; /* add to list of virtual arrays */ + mem->virt_barray_list = result; + + return result; +} + + +METHODDEF(void) +realize_virt_arrays (j_common_ptr cinfo) +/* Allocate the in-memory buffers for any unrealized virtual arrays */ +{ + my_mem_ptr mem = (my_mem_ptr) cinfo->mem; + long space_per_minheight, maximum_space, avail_mem; + long minheights, max_minheights; + jvirt_sarray_ptr sptr; + jvirt_barray_ptr bptr; + + /* Compute the minimum space needed (maxaccess rows in each buffer) + * and the maximum space needed (full image height in each buffer). + * These may be of use to the system-dependent jpeg_mem_available routine. + */ + space_per_minheight = 0; + maximum_space = 0; + for (sptr = mem->virt_sarray_list; sptr != NULL; sptr = sptr->next) { + if (sptr->mem_buffer == NULL) { /* if not realized yet */ + space_per_minheight += (long) sptr->maxaccess * + (long) sptr->samplesperrow * SIZEOF(JSAMPLE); + maximum_space += (long) sptr->rows_in_array * + (long) sptr->samplesperrow * SIZEOF(JSAMPLE); + } + } + for (bptr = mem->virt_barray_list; bptr != NULL; bptr = bptr->next) { + if (bptr->mem_buffer == NULL) { /* if not realized yet */ + space_per_minheight += (long) bptr->maxaccess * + (long) bptr->blocksperrow * SIZEOF(JBLOCK); + maximum_space += (long) bptr->rows_in_array * + (long) bptr->blocksperrow * SIZEOF(JBLOCK); + } + } + + if (space_per_minheight <= 0) + return; /* no unrealized arrays, no work */ + + /* Determine amount of memory to actually use; this is system-dependent. */ + avail_mem = jpeg_mem_available(cinfo, space_per_minheight, maximum_space, + mem->total_space_allocated); + + /* If the maximum space needed is available, make all the buffers full + * height; otherwise parcel it out with the same number of minheights + * in each buffer. + */ + if (avail_mem >= maximum_space) + max_minheights = 1000000000L; + else { + max_minheights = avail_mem / space_per_minheight; + /* If there doesn't seem to be enough space, try to get the minimum + * anyway. This allows a "stub" implementation of jpeg_mem_available(). + */ + if (max_minheights <= 0) + max_minheights = 1; + } + + /* Allocate the in-memory buffers and initialize backing store as needed. */ + + for (sptr = mem->virt_sarray_list; sptr != NULL; sptr = sptr->next) { + if (sptr->mem_buffer == NULL) { /* if not realized yet */ + minheights = ((long) sptr->rows_in_array - 1L) / sptr->maxaccess + 1L; + if (minheights <= max_minheights) { + /* This buffer fits in memory */ + sptr->rows_in_mem = sptr->rows_in_array; + } else { + /* It doesn't fit in memory, create backing store. */ + sptr->rows_in_mem = (JDIMENSION) (max_minheights * sptr->maxaccess); + jpeg_open_backing_store(cinfo, & sptr->b_s_info, + (long) sptr->rows_in_array * + (long) sptr->samplesperrow * + (long) SIZEOF(JSAMPLE)); + sptr->b_s_open = TRUE; + } + sptr->mem_buffer = alloc_sarray(cinfo, JPOOL_IMAGE, + sptr->samplesperrow, sptr->rows_in_mem); + sptr->rowsperchunk = mem->last_rowsperchunk; + sptr->cur_start_row = 0; + sptr->first_undef_row = 0; + sptr->dirty = FALSE; + } + } + + for (bptr = mem->virt_barray_list; bptr != NULL; bptr = bptr->next) { + if (bptr->mem_buffer == NULL) { /* if not realized yet */ + minheights = ((long) bptr->rows_in_array - 1L) / bptr->maxaccess + 1L; + if (minheights <= max_minheights) { + /* This buffer fits in memory */ + bptr->rows_in_mem = bptr->rows_in_array; + } else { + /* It doesn't fit in memory, create backing store. */ + bptr->rows_in_mem = (JDIMENSION) (max_minheights * bptr->maxaccess); + jpeg_open_backing_store(cinfo, & bptr->b_s_info, + (long) bptr->rows_in_array * + (long) bptr->blocksperrow * + (long) SIZEOF(JBLOCK)); + bptr->b_s_open = TRUE; + } + bptr->mem_buffer = alloc_barray(cinfo, JPOOL_IMAGE, + bptr->blocksperrow, bptr->rows_in_mem); + bptr->rowsperchunk = mem->last_rowsperchunk; + bptr->cur_start_row = 0; + bptr->first_undef_row = 0; + bptr->dirty = FALSE; + } + } +} + + +LOCAL(void) +do_sarray_io (j_common_ptr cinfo, jvirt_sarray_ptr ptr, int writing) +/* Do backing store read or write of a virtual sample array */ +{ + long bytesperrow, file_offset, byte_count, rows, thisrow, i; + + bytesperrow = (long) ptr->samplesperrow * SIZEOF(JSAMPLE); + file_offset = ptr->cur_start_row * bytesperrow; + /* Loop to read or write each allocation chunk in mem_buffer */ + for (i = 0; i < (long) ptr->rows_in_mem; i += ptr->rowsperchunk) { + /* One chunk, but check for short chunk at end of buffer */ + rows = MIN((long) ptr->rowsperchunk, (long) ptr->rows_in_mem - i); + /* Transfer no more than is currently defined */ + thisrow = (long) ptr->cur_start_row + i; + rows = MIN(rows, (long) ptr->first_undef_row - thisrow); + /* Transfer no more than fits in file */ + rows = MIN(rows, (long) ptr->rows_in_array - thisrow); + if (rows <= 0) /* this chunk might be past end of file! */ + break; + byte_count = rows * bytesperrow; + if (writing) + (*ptr->b_s_info.write_backing_store) (cinfo, & ptr->b_s_info, + (void FAR *) ptr->mem_buffer[i], + file_offset, byte_count); + else + (*ptr->b_s_info.read_backing_store) (cinfo, & ptr->b_s_info, + (void FAR *) ptr->mem_buffer[i], + file_offset, byte_count); + file_offset += byte_count; + } +} + + +LOCAL(void) +do_barray_io (j_common_ptr cinfo, jvirt_barray_ptr ptr, int writing) +/* Do backing store read or write of a virtual coefficient-block array */ +{ + long bytesperrow, file_offset, byte_count, rows, thisrow, i; + + bytesperrow = (long) ptr->blocksperrow * SIZEOF(JBLOCK); + file_offset = ptr->cur_start_row * bytesperrow; + /* Loop to read or write each allocation chunk in mem_buffer */ + for (i = 0; i < (long) ptr->rows_in_mem; i += ptr->rowsperchunk) { + /* One chunk, but check for short chunk at end of buffer */ + rows = MIN((long) ptr->rowsperchunk, (long) ptr->rows_in_mem - i); + /* Transfer no more than is currently defined */ + thisrow = (long) ptr->cur_start_row + i; + rows = MIN(rows, (long) ptr->first_undef_row - thisrow); + /* Transfer no more than fits in file */ + rows = MIN(rows, (long) ptr->rows_in_array - thisrow); + if (rows <= 0) /* this chunk might be past end of file! */ + break; + byte_count = rows * bytesperrow; + if (writing) + (*ptr->b_s_info.write_backing_store) (cinfo, & ptr->b_s_info, + (void FAR *) ptr->mem_buffer[i], + file_offset, byte_count); + else + (*ptr->b_s_info.read_backing_store) (cinfo, & ptr->b_s_info, + (void FAR *) ptr->mem_buffer[i], + file_offset, byte_count); + file_offset += byte_count; + } +} + + +METHODDEF(JSAMPARRAY) +access_virt_sarray (j_common_ptr cinfo, jvirt_sarray_ptr ptr, + JDIMENSION start_row, JDIMENSION num_rows, + int writable) +/* Access the part of a virtual sample array starting at start_row */ +/* and extending for num_rows rows. writable is true if */ +/* caller intends to modify the accessed area. */ +{ + JDIMENSION end_row = start_row + num_rows; + JDIMENSION undef_row; + + /* debugging check */ + if (end_row > ptr->rows_in_array || num_rows > ptr->maxaccess || + ptr->mem_buffer == NULL) + ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS); + + /* Make the desired part of the virtual array accessible */ + if (start_row < ptr->cur_start_row || + end_row > ptr->cur_start_row+ptr->rows_in_mem) { + if (! ptr->b_s_open) + ERREXIT(cinfo, JERR_VIRTUAL_BUG); + /* Flush old buffer contents if necessary */ + if (ptr->dirty) { + do_sarray_io(cinfo, ptr, TRUE); + ptr->dirty = FALSE; + } + /* Decide what part of virtual array to access. + * Algorithm: if target address > current window, assume forward scan, + * load starting at target address. If target address < current window, + * assume backward scan, load so that target area is top of window. + * Note that when switching from forward write to forward read, will have + * start_row = 0, so the limiting case applies and we load from 0 anyway. + */ + if (start_row > ptr->cur_start_row) { + ptr->cur_start_row = start_row; + } else { + /* use long arithmetic here to avoid overflow & unsigned problems */ + long ltemp; + + ltemp = (long) end_row - (long) ptr->rows_in_mem; + if (ltemp < 0) + ltemp = 0; /* don't fall off front end of file */ + ptr->cur_start_row = (JDIMENSION) ltemp; + } + /* Read in the selected part of the array. + * During the initial write pass, we will do no actual read + * because the selected part is all undefined. + */ + do_sarray_io(cinfo, ptr, FALSE); + } + /* Ensure the accessed part of the array is defined; prezero if needed. + * To improve locality of access, we only prezero the part of the array + * that the caller is about to access, not the entire in-memory array. + */ + if (ptr->first_undef_row < end_row) { + if (ptr->first_undef_row < start_row) { + if (writable) /* writer skipped over a section of array */ + ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS); + undef_row = start_row; /* but reader is allowed to read ahead */ + } else { + undef_row = ptr->first_undef_row; + } + if (writable) + ptr->first_undef_row = end_row; + if (ptr->pre_zero) { + size_t bytesperrow = (size_t) ptr->samplesperrow * SIZEOF(JSAMPLE); + undef_row -= ptr->cur_start_row; /* make indexes relative to buffer */ + end_row -= ptr->cur_start_row; + while (undef_row < end_row) { + jzero_far((void FAR *) ptr->mem_buffer[undef_row], bytesperrow); + undef_row++; + } + } else { + if (! writable) /* reader looking at undefined data */ + ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS); + } + } + /* Flag the buffer dirty if caller will write in it */ + if (writable) + ptr->dirty = TRUE; + /* Return address of proper part of the buffer */ + return ptr->mem_buffer + (start_row - ptr->cur_start_row); +} + + +METHODDEF(JBLOCKARRAY) +access_virt_barray (j_common_ptr cinfo, jvirt_barray_ptr ptr, + JDIMENSION start_row, JDIMENSION num_rows, + int writable) +/* Access the part of a virtual block array starting at start_row */ +/* and extending for num_rows rows. writable is true if */ +/* caller intends to modify the accessed area. */ +{ + JDIMENSION end_row = start_row + num_rows; + JDIMENSION undef_row; + + /* debugging check */ + if (end_row > ptr->rows_in_array || num_rows > ptr->maxaccess || + ptr->mem_buffer == NULL) + ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS); + + /* Make the desired part of the virtual array accessible */ + if (start_row < ptr->cur_start_row || + end_row > ptr->cur_start_row+ptr->rows_in_mem) { + if (! ptr->b_s_open) + ERREXIT(cinfo, JERR_VIRTUAL_BUG); + /* Flush old buffer contents if necessary */ + if (ptr->dirty) { + do_barray_io(cinfo, ptr, TRUE); + ptr->dirty = FALSE; + } + /* Decide what part of virtual array to access. + * Algorithm: if target address > current window, assume forward scan, + * load starting at target address. If target address < current window, + * assume backward scan, load so that target area is top of window. + * Note that when switching from forward write to forward read, will have + * start_row = 0, so the limiting case applies and we load from 0 anyway. + */ + if (start_row > ptr->cur_start_row) { + ptr->cur_start_row = start_row; + } else { + /* use long arithmetic here to avoid overflow & unsigned problems */ + long ltemp; + + ltemp = (long) end_row - (long) ptr->rows_in_mem; + if (ltemp < 0) + ltemp = 0; /* don't fall off front end of file */ + ptr->cur_start_row = (JDIMENSION) ltemp; + } + /* Read in the selected part of the array. + * During the initial write pass, we will do no actual read + * because the selected part is all undefined. + */ + do_barray_io(cinfo, ptr, FALSE); + } + /* Ensure the accessed part of the array is defined; prezero if needed. + * To improve locality of access, we only prezero the part of the array + * that the caller is about to access, not the entire in-memory array. + */ + if (ptr->first_undef_row < end_row) { + if (ptr->first_undef_row < start_row) { + if (writable) /* writer skipped over a section of array */ + ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS); + undef_row = start_row; /* but reader is allowed to read ahead */ + } else { + undef_row = ptr->first_undef_row; + } + if (writable) + ptr->first_undef_row = end_row; + if (ptr->pre_zero) { + size_t bytesperrow = (size_t) ptr->blocksperrow * SIZEOF(JBLOCK); + undef_row -= ptr->cur_start_row; /* make indexes relative to buffer */ + end_row -= ptr->cur_start_row; + while (undef_row < end_row) { + jzero_far((void FAR *) ptr->mem_buffer[undef_row], bytesperrow); + undef_row++; + } + } else { + if (! writable) /* reader looking at undefined data */ + ERREXIT(cinfo, JERR_BAD_VIRTUAL_ACCESS); + } + } + /* Flag the buffer dirty if caller will write in it */ + if (writable) + ptr->dirty = TRUE; + /* Return address of proper part of the buffer */ + return ptr->mem_buffer + (start_row - ptr->cur_start_row); +} + + +/* + * Release all objects belonging to a specified pool. + */ + +METHODDEF(void) +free_pool (j_common_ptr cinfo, int pool_id) +{ + my_mem_ptr mem = (my_mem_ptr) cinfo->mem; + small_pool_ptr shdr_ptr; + large_pool_ptr lhdr_ptr; + size_t space_freed; + + if (pool_id < 0 || pool_id >= JPOOL_NUMPOOLS) + ERREXIT1(cinfo, JERR_BAD_POOL_ID, pool_id); /* safety check */ + +#ifdef MEM_STATS + if (cinfo->err->trace_level > 1) + print_mem_stats(cinfo, pool_id); /* print pool's memory usage statistics */ +#endif + + /* If freeing IMAGE pool, close any virtual arrays first */ + if (pool_id == JPOOL_IMAGE) { + jvirt_sarray_ptr sptr; + jvirt_barray_ptr bptr; + + for (sptr = mem->virt_sarray_list; sptr != NULL; sptr = sptr->next) { + if (sptr->b_s_open) { /* there may be no backing store */ + sptr->b_s_open = FALSE; /* prevent recursive close if error */ + (*sptr->b_s_info.close_backing_store) (cinfo, & sptr->b_s_info); + } + } + mem->virt_sarray_list = NULL; + for (bptr = mem->virt_barray_list; bptr != NULL; bptr = bptr->next) { + if (bptr->b_s_open) { /* there may be no backing store */ + bptr->b_s_open = FALSE; /* prevent recursive close if error */ + (*bptr->b_s_info.close_backing_store) (cinfo, & bptr->b_s_info); + } + } + mem->virt_barray_list = NULL; + } + + /* Release large objects */ + lhdr_ptr = mem->large_list[pool_id]; + mem->large_list[pool_id] = NULL; + + while (lhdr_ptr != NULL) { + large_pool_ptr next_lhdr_ptr = lhdr_ptr->hdr.next; + space_freed = lhdr_ptr->hdr.bytes_used + + lhdr_ptr->hdr.bytes_left + + SIZEOF(large_pool_hdr); + jpeg_free_large(cinfo, (void FAR *) lhdr_ptr, space_freed); + mem->total_space_allocated -= (long)space_freed; + lhdr_ptr = next_lhdr_ptr; + } + + /* Release small objects */ + shdr_ptr = mem->small_list[pool_id]; + mem->small_list[pool_id] = NULL; + + while (shdr_ptr != NULL) { + small_pool_ptr next_shdr_ptr = shdr_ptr->hdr.next; + space_freed = shdr_ptr->hdr.bytes_used + + shdr_ptr->hdr.bytes_left + + SIZEOF(small_pool_hdr); + jpeg_free_small(cinfo, (void *) shdr_ptr, space_freed); + mem->total_space_allocated -= (long)space_freed; + shdr_ptr = next_shdr_ptr; + } +} + + +/* + * Close up shop entirely. + * Note that this cannot be called unless cinfo->mem is non-NULL. + */ + +METHODDEF(void) +self_destruct (j_common_ptr cinfo) +{ + int pool; + + /* Close all backing store, release all memory. + * Releasing pools in reverse order might help avoid fragmentation + * with some (brain-damaged) malloc libraries. + */ + for (pool = JPOOL_NUMPOOLS-1; pool >= JPOOL_PERMANENT; pool--) { + free_pool(cinfo, pool); + } + + /* Release the memory manager control block too. */ + jpeg_free_small(cinfo, (void *) cinfo->mem, SIZEOF(my_memory_mgr)); + cinfo->mem = NULL; /* ensures I will be called only once */ + + jpeg_mem_term(cinfo); /* system-dependent cleanup */ +} + + +/* + * Memory manager initialization. + * When this is called, only the error manager pointer is valid in cinfo! + */ + +GLOBAL(void) +jinit_memory_mgr (j_common_ptr cinfo) +{ + my_mem_ptr mem; + long max_to_use; + int pool; + size_t test_mac; + + cinfo->mem = NULL; /* for safety if init fails */ + + /* Check for configuration errors. + * SIZEOF(ALIGN_TYPE) should be a power of 2; otherwise, it probably + * doesn't reflect any real hardware alignment requirement. + * The test is a little tricky: for X>0, X and X-1 have no one-bits + * in common if and only if X is a power of 2, ie has only one one-bit. + * Some compilers may give an "unreachable code" warning here; ignore it. + */ + if ((SIZEOF(ALIGN_TYPE) & (SIZEOF(ALIGN_TYPE)-1)) != 0) + ERREXIT(cinfo, JERR_BAD_ALIGN_TYPE); + /* MAX_ALLOC_CHUNK must be representable as type size_t, and must be + * a multiple of SIZEOF(ALIGN_TYPE). + * Again, an "unreachable code" warning may be ignored here. + * But a "constant too large" warning means you need to fix MAX_ALLOC_CHUNK. + */ + test_mac = (size_t) MAX_ALLOC_CHUNK; + if ((long) test_mac != MAX_ALLOC_CHUNK || + (MAX_ALLOC_CHUNK % SIZEOF(ALIGN_TYPE)) != 0) + ERREXIT(cinfo, JERR_BAD_ALLOC_CHUNK); + + max_to_use = jpeg_mem_init(cinfo); /* system-dependent initialization */ + + /* Attempt to allocate memory manager's control block */ + mem = (my_mem_ptr) jpeg_get_small(cinfo, SIZEOF(my_memory_mgr)); + + if (mem == NULL) { + jpeg_mem_term(cinfo); /* system-dependent cleanup */ + ERREXIT1(cinfo, JERR_OUT_OF_MEMORY, 0); + } + + /* OK, fill in the method pointers */ + mem->pub.alloc_small = alloc_small; + mem->pub.alloc_large = alloc_large; + mem->pub.alloc_sarray = alloc_sarray; + mem->pub.alloc_barray = alloc_barray; + mem->pub.request_virt_sarray = request_virt_sarray; + mem->pub.request_virt_barray = request_virt_barray; + mem->pub.realize_virt_arrays = realize_virt_arrays; + mem->pub.access_virt_sarray = access_virt_sarray; + mem->pub.access_virt_barray = access_virt_barray; + mem->pub.free_pool = free_pool; + mem->pub.self_destruct = self_destruct; + + /* Make MAX_ALLOC_CHUNK accessible to other modules */ + mem->pub.max_alloc_chunk = MAX_ALLOC_CHUNK; + + /* Initialize working state */ + mem->pub.max_memory_to_use = max_to_use; + + for (pool = JPOOL_NUMPOOLS-1; pool >= JPOOL_PERMANENT; pool--) { + mem->small_list[pool] = NULL; + mem->large_list[pool] = NULL; + } + mem->virt_sarray_list = NULL; + mem->virt_barray_list = NULL; + + mem->total_space_allocated = SIZEOF(my_memory_mgr); + + /* Declare ourselves open for business */ + cinfo->mem = & mem->pub; + + /* Check for an environment variable JPEGMEM; if found, override the + * default max_memory setting from jpeg_mem_init. Note that the + * surrounding application may again override this value. + * If your system doesn't support getenv(), define NO_GETENV to disable + * this feature. + */ +#ifndef NO_GETENV + { char * memenv; + + if ((memenv = getenv("JPEGMEM")) != NULL) { + char ch = 'x'; + + if (sscanf(memenv, "%ld%c", &max_to_use, &ch) > 0) { + if (ch == 'm' || ch == 'M') + max_to_use *= 1000L; + mem->pub.max_memory_to_use = max_to_use * 1000L; + } + } + } +#endif + +} diff --git a/ml/dlib/dlib/external/libjpeg/jmemnobs.cpp b/ml/dlib/dlib/external/libjpeg/jmemnobs.cpp new file mode 100644 index 000000000..27fe6c457 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jmemnobs.cpp @@ -0,0 +1,109 @@ +/* + * jmemnobs.c + * + * Copyright (C) 1992-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file provides a really simple implementation of the system- + * dependent portion of the JPEG memory manager. This implementation + * assumes that no backing-store files are needed: all required space + * can be obtained from malloc(). + * This is very portable in the sense that it'll compile on almost anything, + * but you'd better have lots of main memory (or virtual memory) if you want + * to process big images. + * Note that the max_memory_to_use option is ignored by this implementation. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" +#include "jmemsys.h" /* import the system-dependent declarations */ + +#ifndef HAVE_STDLIB_H /* should declare malloc(),free() */ +extern void * malloc JPP((size_t size)); +extern void free JPP((void *ptr)); +#endif + + +/* + * Memory allocation and freeing are controlled by the regular library + * routines malloc() and free(). + */ + +GLOBAL(void *) +jpeg_get_small (j_common_ptr , size_t sizeofobject) +{ + return (void *) malloc(sizeofobject); +} + +GLOBAL(void) +jpeg_free_small (j_common_ptr , void * object, size_t ) +{ + free(object); +} + + +/* + * "Large" objects are treated the same as "small" ones. + * NB: although we include FAR keywords in the routine declarations, + * this file won't actually work in 80x86 small/medium model; at least, + * you probably won't be able to process useful-size images in only 64KB. + */ + +GLOBAL(void FAR *) +jpeg_get_large (j_common_ptr , size_t sizeofobject) +{ + return (void FAR *) malloc(sizeofobject); +} + +GLOBAL(void) +jpeg_free_large (j_common_ptr , void FAR * object, size_t ) +{ + free(object); +} + + +/* + * This routine computes the total memory space available for allocation. + * Here we always say, "we got all you want bud!" + */ + +GLOBAL(long) +jpeg_mem_available (j_common_ptr , long , + long max_bytes_needed, long ) +{ + return max_bytes_needed; +} + + +/* + * Backing store (temporary file) management. + * Since jpeg_mem_available always promised the moon, + * this should never be called and we can just error out. + */ + +GLOBAL(void) +jpeg_open_backing_store (j_common_ptr cinfo, backing_store_ptr , + long ) +{ + ERREXIT(cinfo, JERR_NO_BACKING_STORE); +} + + +/* + * These routines take care of any system-dependent initialization and + * cleanup required. Here, there isn't any. + */ + +GLOBAL(long) +jpeg_mem_init (j_common_ptr ) +{ + return 0; /* just set max_memory_to_use to 0 */ +} + +GLOBAL(void) +jpeg_mem_term (j_common_ptr ) +{ + /* no work */ +} diff --git a/ml/dlib/dlib/external/libjpeg/jmemsys.h b/ml/dlib/dlib/external/libjpeg/jmemsys.h new file mode 100644 index 000000000..6c3c6d348 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jmemsys.h @@ -0,0 +1,198 @@ +/* + * jmemsys.h + * + * Copyright (C) 1992-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This include file defines the interface between the system-independent + * and system-dependent portions of the JPEG memory manager. No other + * modules need include it. (The system-independent portion is jmemmgr.c; + * there are several different versions of the system-dependent portion.) + * + * This file works as-is for the system-dependent memory managers supplied + * in the IJG distribution. You may need to modify it if you write a + * custom memory manager. If system-dependent changes are needed in + * this file, the best method is to #ifdef them based on a configuration + * symbol supplied in jconfig.h, as we have done with USE_MSDOS_MEMMGR + * and USE_MAC_MEMMGR. + */ + + +/* Short forms of external names for systems with brain-damaged linkers. */ + +#ifdef NEED_SHORT_EXTERNAL_NAMES +#define jpeg_get_small jGetSmall +#define jpeg_free_small jFreeSmall +#define jpeg_get_large jGetLarge +#define jpeg_free_large jFreeLarge +#define jpeg_mem_available jMemAvail +#define jpeg_open_backing_store jOpenBackStore +#define jpeg_mem_init jMemInit +#define jpeg_mem_term jMemTerm +#endif /* NEED_SHORT_EXTERNAL_NAMES */ + + +/* + * These two functions are used to allocate and release small chunks of + * memory. (Typically the total amount requested through jpeg_get_small is + * no more than 20K or so; this will be requested in chunks of a few K each.) + * Behavior should be the same as for the standard library functions malloc + * and free; in particular, jpeg_get_small must return NULL on failure. + * On most systems, these ARE malloc and free. jpeg_free_small is passed the + * size of the object being freed, just in case it's needed. + * On an 80x86 machine using small-data memory model, these manage near heap. + */ + +EXTERN(void *) jpeg_get_small JPP((j_common_ptr cinfo, size_t sizeofobject)); +EXTERN(void) jpeg_free_small JPP((j_common_ptr cinfo, void * object, + size_t sizeofobject)); + +/* + * These two functions are used to allocate and release large chunks of + * memory (up to the total free space designated by jpeg_mem_available). + * The interface is the same as above, except that on an 80x86 machine, + * far pointers are used. On most other machines these are identical to + * the jpeg_get/free_small routines; but we keep them separate anyway, + * in case a different allocation strategy is desirable for large chunks. + */ + +EXTERN(void FAR *) jpeg_get_large JPP((j_common_ptr cinfo, + size_t sizeofobject)); +EXTERN(void) jpeg_free_large JPP((j_common_ptr cinfo, void FAR * object, + size_t sizeofobject)); + +/* + * The macro MAX_ALLOC_CHUNK designates the maximum number of bytes that may + * be requested in a single call to jpeg_get_large (and jpeg_get_small for that + * matter, but that case should never come into play). This macro is needed + * to model the 64Kb-segment-size limit of far addressing on 80x86 machines. + * On those machines, we expect that jconfig.h will provide a proper value. + * On machines with 32-bit flat address spaces, any large constant may be used. + * + * NB: jmemmgr.c expects that MAX_ALLOC_CHUNK will be representable as type + * size_t and will be a multiple of sizeof(align_type). + */ + +#ifndef MAX_ALLOC_CHUNK /* may be overridden in jconfig.h */ +#define MAX_ALLOC_CHUNK 1000000000L +#endif + +/* + * This routine computes the total space still available for allocation by + * jpeg_get_large. If more space than this is needed, backing store will be + * used. NOTE: any memory already allocated must not be counted. + * + * There is a minimum space requirement, corresponding to the minimum + * feasible buffer sizes; jmemmgr.c will request that much space even if + * jpeg_mem_available returns zero. The maximum space needed, enough to hold + * all working storage in memory, is also passed in case it is useful. + * Finally, the total space already allocated is passed. If no better + * method is available, cinfo->mem->max_memory_to_use - already_allocated + * is often a suitable calculation. + * + * It is OK for jpeg_mem_available to underestimate the space available + * (that'll just lead to more backing-store access than is really necessary). + * However, an overestimate will lead to failure. Hence it's wise to subtract + * a slop factor from the true available space. 5% should be enough. + * + * On machines with lots of virtual memory, any large constant may be returned. + * Conversely, zero may be returned to always use the minimum amount of memory. + */ + +EXTERN(long) jpeg_mem_available JPP((j_common_ptr cinfo, + long min_bytes_needed, + long max_bytes_needed, + long already_allocated)); + + +/* + * This structure holds whatever state is needed to access a single + * backing-store object. The read/write/close method pointers are called + * by jmemmgr.c to manipulate the backing-store object; all other fields + * are private to the system-dependent backing store routines. + */ + +#define TEMP_NAME_LENGTH 64 /* max length of a temporary file's name */ + + +#ifdef USE_MSDOS_MEMMGR /* DOS-specific junk */ + +typedef unsigned short XMSH; /* type of extended-memory handles */ +typedef unsigned short EMSH; /* type of expanded-memory handles */ + +typedef union { + short file_handle; /* DOS file handle if it's a temp file */ + XMSH xms_handle; /* handle if it's a chunk of XMS */ + EMSH ems_handle; /* handle if it's a chunk of EMS */ +} handle_union; + +#endif /* USE_MSDOS_MEMMGR */ + +#ifdef USE_MAC_MEMMGR /* Mac-specific junk */ +#include +#endif /* USE_MAC_MEMMGR */ + + +typedef struct backing_store_struct * backing_store_ptr; + +typedef struct backing_store_struct { + /* Methods for reading/writing/closing this backing-store object */ + JMETHOD(void, read_backing_store, (j_common_ptr cinfo, + backing_store_ptr info, + void FAR * buffer_address, + long file_offset, long byte_count)); + JMETHOD(void, write_backing_store, (j_common_ptr cinfo, + backing_store_ptr info, + void FAR * buffer_address, + long file_offset, long byte_count)); + JMETHOD(void, close_backing_store, (j_common_ptr cinfo, + backing_store_ptr info)); + + /* Private fields for system-dependent backing-store management */ +#ifdef USE_MSDOS_MEMMGR + /* For the MS-DOS manager (jmemdos.c), we need: */ + handle_union handle; /* reference to backing-store storage object */ + char temp_name[TEMP_NAME_LENGTH]; /* name if it's a file */ +#else +#ifdef USE_MAC_MEMMGR + /* For the Mac manager (jmemmac.c), we need: */ + short temp_file; /* file reference number to temp file */ + FSSpec tempSpec; /* the FSSpec for the temp file */ + char temp_name[TEMP_NAME_LENGTH]; /* name if it's a file */ +#else + /* For a typical implementation with temp files, we need: */ + FILE * temp_file; /* stdio reference to temp file */ + char temp_name[TEMP_NAME_LENGTH]; /* name of temp file */ +#endif +#endif +} backing_store_info; + + +/* + * Initial opening of a backing-store object. This must fill in the + * read/write/close pointers in the object. The read/write routines + * may take an error exit if the specified maximum file size is exceeded. + * (If jpeg_mem_available always returns a large value, this routine can + * just take an error exit.) + */ + +EXTERN(void) jpeg_open_backing_store JPP((j_common_ptr cinfo, + backing_store_ptr info, + long total_bytes_needed)); + + +/* + * These routines take care of any system-dependent initialization and + * cleanup required. jpeg_mem_init will be called before anything is + * allocated (and, therefore, nothing in cinfo is of use except the error + * manager pointer). It should return a suitable default value for + * max_memory_to_use; this may subsequently be overridden by the surrounding + * application. (Note that max_memory_to_use is only important if + * jpeg_mem_available chooses to consult it ... no one else will.) + * jpeg_mem_term may assume that all requested memory has been freed and that + * all opened backing-store objects have been closed. + */ + +EXTERN(long) jpeg_mem_init JPP((j_common_ptr cinfo)); +EXTERN(void) jpeg_mem_term JPP((j_common_ptr cinfo)); diff --git a/ml/dlib/dlib/external/libjpeg/jmorecfg.h b/ml/dlib/dlib/external/libjpeg/jmorecfg.h new file mode 100644 index 000000000..6082f069a --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jmorecfg.h @@ -0,0 +1,356 @@ +/* + * jmorecfg.h + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains additional configuration options that customize the + * JPEG software for special applications or support machine-dependent + * optimizations. Most users will not need to touch this file. + */ + + +/* + * Define BITS_IN_JSAMPLE as either + * 8 for 8-bit sample values (the usual setting) + * 12 for 12-bit sample values + * Only 8 and 12 are legal data precisions for lossy JPEG according to the + * JPEG standard, and the IJG code does not support anything else! + * We do not support run-time selection of data precision, sorry. + */ + +#define BITS_IN_JSAMPLE 8 /* use 8 or 12 */ + + +/* + * Maximum number of components (color channels) allowed in JPEG image. + * To meet the letter of the JPEG spec, set this to 255. However, darn + * few applications need more than 4 channels (maybe 5 for CMYK + alpha + * mask). We recommend 10 as a reasonable compromise; use 4 if you are + * really short on memory. (Each allowed component costs a hundred or so + * bytes of storage, whether actually used in an image or not.) + */ + +#define MAX_COMPONENTS 10 /* maximum number of image components */ + + +/* + * Basic data types. + * You may need to change these if you have a machine with unusual data + * type sizes; for example, "char" not 8 bits, "short" not 16 bits, + * or "long" not 32 bits. We don't care whether "int" is 16 or 32 bits, + * but it had better be at least 16. + */ + +/* Representation of a single sample (pixel element value). + * We frequently allocate large arrays of these, so it's important to keep + * them small. But if you have memory to burn and access to char or short + * arrays is very slow on your hardware, you might want to change these. + */ + + +#ifdef _MSC_VER +// Disable the following warnings for Visual Studio +// This is a warning you get from visual studio 2005 about things in the standard C++ +// library being "deprecated." I checked the C++ standard and it doesn't say jack +// about any of them (I checked the searchable PDF). So this warning is total Bunk. +#pragma warning(disable : 4996) +#endif + + +#if BITS_IN_JSAMPLE == 8 +/* JSAMPLE should be the smallest type that will hold the values 0..255. + * You can use a signed char by having GETJSAMPLE mask it with 0xFF. + */ + +#ifdef HAVE_UNSIGNED_CHAR + +typedef unsigned char JSAMPLE; +#define GETJSAMPLE(value) ((int) (value)) + +#else /* not HAVE_UNSIGNED_CHAR */ + +typedef char JSAMPLE; +#ifdef CHAR_IS_UNSIGNED +#define GETJSAMPLE(value) ((int) (value)) +#else +#define GETJSAMPLE(value) ((int) (value) & 0xFF) +#endif /* CHAR_IS_UNSIGNED */ + +#endif /* HAVE_UNSIGNED_CHAR */ + +#define MAXJSAMPLE 255 +#define CENTERJSAMPLE 128 + +#endif /* BITS_IN_JSAMPLE == 8 */ + + +#if BITS_IN_JSAMPLE == 12 +/* JSAMPLE should be the smallest type that will hold the values 0..4095. + * On nearly all machines "short" will do nicely. + */ + +typedef short JSAMPLE; +#define GETJSAMPLE(value) ((int) (value)) + +#define MAXJSAMPLE 4095 +#define CENTERJSAMPLE 2048 + +#endif /* BITS_IN_JSAMPLE == 12 */ + + +/* Representation of a DCT frequency coefficient. + * This should be a signed value of at least 16 bits; "short" is usually OK. + * Again, we allocate large arrays of these, but you can change to int + * if you have memory to burn and "short" is really slow. + */ + +typedef short JCOEF; + + +/* Compressed datastreams are represented as arrays of JOCTET. + * These must be EXACTLY 8 bits wide, at least once they are written to + * external storage. Note that when using the stdio data source/destination + * managers, this is also the data type passed to fread/fwrite. + */ + +#ifdef HAVE_UNSIGNED_CHAR + +typedef unsigned char JOCTET; +#define GETJOCTET(value) (value) + +#else /* not HAVE_UNSIGNED_CHAR */ + +typedef char JOCTET; +#ifdef CHAR_IS_UNSIGNED +#define GETJOCTET(value) (value) +#else +#define GETJOCTET(value) ((value) & 0xFF) +#endif /* CHAR_IS_UNSIGNED */ + +#endif /* HAVE_UNSIGNED_CHAR */ + + +/* These typedefs are used for various table entries and so forth. + * They must be at least as wide as specified; but making them too big + * won't cost a huge amount of memory, so we don't provide special + * extraction code like we did for JSAMPLE. (In other words, these + * typedefs live at a different point on the speed/space tradeoff curve.) + */ + +/* unsigned char must hold at least the values 0..255. */ + + +/* unsigned short must hold at least the values 0..65535. */ + + + +/* Datatype used for image dimensions. The JPEG standard only supports + * images up to 64K*64K due to 16-bit fields in SOF markers. Therefore + * "unsigned int" is sufficient on all machines. However, if you need to + * handle larger images and you don't mind deviating from the spec, you + * can change this datatype. + */ + +typedef unsigned int JDIMENSION; + +#define JPEG_MAX_DIMENSION 65500L /* a tad under 64K to prevent overflows */ + + +/* These macros are used in all function definitions and extern declarations. + * You could modify them if you need to change function linkage conventions; + * in particular, you'll need to do that to make the library a Windows DLL. + * Another application is to make all functions global for use with debuggers + * or code profilers that require it. + */ + +/* a function called through method pointers: */ +#define METHODDEF(type) static type +/* a function used only in its module: */ +#define LOCAL(type) static type +/* a function referenced thru EXTERNs: */ +#define GLOBAL(type) type +/* + Use C linking unless we are supposed to be compiling our own copy of + libjpeg. Then let it use C++ linking so that we are less likely to get + linker name conflicts with other libraries that happen to statically include + libjpeg as well. +*/ +#if defined(__cplusplus) && !defined(DLIB_JPEG_STATIC) +#define EXTERN(type) extern "C" type +#else +#define EXTERN(type) extern type +#endif + + +/* This macro is used to declare a "method", that is, a function pointer. + * We want to supply prototype parameters if the compiler can cope. + * Note that the arglist parameter must be parenthesized! + * Again, you can customize this if you need special linkage keywords. + */ + +#ifdef HAVE_PROTOTYPES +#define JMETHOD(type,methodname,arglist) type (*methodname) arglist +#else +#define JMETHOD(type,methodname,arglist) type (*methodname) () +#endif + + +/* Here is the pseudo-keyword for declaring pointers that must be "far" + * on 80x86 machines. Most of the specialized coding for 80x86 is handled + * by just saying "FAR *" where such a pointer is needed. In a few places + * explicit coding is needed; see uses of the NEED_FAR_POINTERS symbol. + */ + +#ifdef NEED_FAR_POINTERS +#define FAR far +#else +#ifndef FAR + #define FAR +#endif +#endif + + +/* + * On a few systems, type boolean and/or its values FALSE, TRUE may appear + * in standard header files. Or you may have conflicts with application- + * specific header files that you want to include together with these files. + * Defining HAVE_BOOLEAN before including jpeglib.h should make it work. + */ + +#ifndef FALSE /* in case these macros already exist */ +#define FALSE 0 /* values of boolean */ +#endif +#ifndef TRUE +#define TRUE 1 +#endif + + +/* + * The remaining options affect code selection within the JPEG library, + * but they don't need to be visible to most applications using the library. + * To minimize application namespace pollution, the symbols won't be + * defined unless JPEG_INTERNALS or JPEG_INTERNAL_OPTIONS has been defined. + */ + +#ifdef JPEG_INTERNALS +#define JPEG_INTERNAL_OPTIONS +#endif + +#ifdef JPEG_INTERNAL_OPTIONS + + +/* + * These defines indicate whether to include various optional functions. + * Undefining some of these symbols will produce a smaller but less capable + * library. Note that you can leave certain source files out of the + * compilation/linking process if you've #undef'd the corresponding symbols. + * (You may HAVE to do that if your compiler doesn't like null source files.) + */ + +/* Arithmetic coding is unsupported for legal reasons. Complaints to IBM. */ + +/* Capability options common to encoder and decoder: */ + +#define DCT_ISLOW_SUPPORTED /* slow but accurate integer algorithm */ +#define DCT_IFAST_SUPPORTED /* faster, less accurate integer method */ +#define DCT_FLOAT_SUPPORTED /* floating-point: accurate, fast on fast HW */ + +/* Encoder capability options: */ + +#undef C_ARITH_CODING_SUPPORTED /* Arithmetic coding back end? */ +#define C_MULTISCAN_FILES_SUPPORTED /* Multiple-scan JPEG files? */ +#define C_PROGRESSIVE_SUPPORTED /* Progressive JPEG? (Requires MULTISCAN)*/ +#define ENTROPY_OPT_SUPPORTED /* Optimization of entropy coding parms? */ +/* Note: if you selected 12-bit data precision, it is dangerous to turn off + * ENTROPY_OPT_SUPPORTED. The standard Huffman tables are only good for 8-bit + * precision, so jchuff.c normally uses entropy optimization to compute + * usable tables for higher precision. If you don't want to do optimization, + * you'll have to supply different default Huffman tables. + * The exact same statements apply for progressive JPEG: the default tables + * don't work for progressive mode. (This may get fixed, however.) + */ +#define INPUT_SMOOTHING_SUPPORTED /* Input image smoothing option? */ + +/* Decoder capability options: */ + +#undef D_ARITH_CODING_SUPPORTED /* Arithmetic coding back end? */ +#define D_MULTISCAN_FILES_SUPPORTED /* Multiple-scan JPEG files? */ +#define D_PROGRESSIVE_SUPPORTED /* Progressive JPEG? (Requires MULTISCAN)*/ +#define SAVE_MARKERS_SUPPORTED /* jpeg_save_markers() needed? */ +#define BLOCK_SMOOTHING_SUPPORTED /* Block smoothing? (Progressive only) */ +#define IDCT_SCALING_SUPPORTED /* Output rescaling via IDCT? */ +#undef UPSAMPLE_SCALING_SUPPORTED /* Output rescaling at upsample stage? */ +#define UPSAMPLE_MERGING_SUPPORTED /* Fast path for sloppy upsampling? */ +#define QUANT_1PASS_SUPPORTED /* 1-pass color quantization? */ +#define QUANT_2PASS_SUPPORTED /* 2-pass color quantization? */ + +/* more capability options later, no doubt */ + + +/* + * Ordering of RGB data in scanlines passed to or from the application. + * If your application wants to deal with data in the order B,G,R, just + * change these macros. You can also deal with formats such as R,G,B,X + * (one extra byte per pixel) by changing RGB_PIXELSIZE. Note that changing + * the offsets will also change the order in which colormap data is organized. + * RESTRICTIONS: + * 1. The sample applications cjpeg,djpeg do NOT support modified RGB formats. + * 2. These macros only affect RGB<=>YCbCr color conversion, so they are not + * useful if you are using JPEG color spaces other than YCbCr or grayscale. + * 3. The color quantizer modules will not behave desirably if RGB_PIXELSIZE + * is not 3 (they don't understand about dummy color components!). So you + * can't use color quantization if you change that value. + */ + +#define RGB_RED 0 /* Offset of Red in an RGB scanline element */ +#define RGB_GREEN 1 /* Offset of Green */ +#define RGB_BLUE 2 /* Offset of Blue */ +#define RGB_PIXELSIZE 3 /* JSAMPLEs per RGB scanline element */ + + +/* Definitions for speed-related optimizations. */ + + +/* If your compiler supports inline functions, define INLINE + * as the inline keyword; otherwise define it as empty. + */ + +#ifndef INLINE +#ifdef __GNUC__ /* for instance, GNU C knows about inline */ +#define INLINE __inline__ +#endif +#ifndef INLINE +#define INLINE /* default is to define it as empty */ +#endif +#endif + + +/* On some machines (notably 68000 series) "int" is 32 bits, but multiplying + * two 16-bit shorts is faster than multiplying two ints. Define MULTIPLIER + * as short on such a machine. MULTIPLIER must be at least 16 bits wide. + */ + +#ifndef MULTIPLIER +#define MULTIPLIER int /* type for fastest integer multiply */ +#endif + + +/* FAST_FLOAT should be either float or double, whichever is done faster + * by your compiler. (Note that this type is only used in the floating point + * DCT routines, so it only matters if you've defined DCT_FLOAT_SUPPORTED.) + * Typically, float is faster in ANSI C compilers, while double is faster in + * pre-ANSI compilers (because they insist on converting to double anyway). + * The code below therefore chooses float if we have ANSI-style prototypes. + */ + +#ifndef FAST_FLOAT +#ifdef HAVE_PROTOTYPES +#define FAST_FLOAT float +#else +#define FAST_FLOAT double +#endif +#endif + +#endif /* JPEG_INTERNAL_OPTIONS */ diff --git a/ml/dlib/dlib/external/libjpeg/jpegint.h b/ml/dlib/dlib/external/libjpeg/jpegint.h new file mode 100644 index 000000000..654067965 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jpegint.h @@ -0,0 +1,392 @@ +/* + * jpegint.h + * + * Copyright (C) 1991-1997, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file provides common declarations for the various JPEG modules. + * These declarations are considered internal to the JPEG library; most + * applications using the library shouldn't need to include this file. + */ + + +/* Declarations for both compression & decompression */ + +typedef enum { /* Operating modes for buffer controllers */ + JBUF_PASS_THRU, /* Plain stripwise operation */ + /* Remaining modes require a full-image buffer to have been created */ + JBUF_SAVE_SOURCE, /* Run source subobject only, save output */ + JBUF_CRANK_DEST, /* Run dest subobject only, using saved data */ + JBUF_SAVE_AND_PASS /* Run both subobjects, save output */ +} J_BUF_MODE; + +/* Values of global_state field (jdapi.c has some dependencies on ordering!) */ +#define CSTATE_START 100 /* after create_compress */ +#define CSTATE_SCANNING 101 /* start_compress done, write_scanlines OK */ +#define CSTATE_RAW_OK 102 /* start_compress done, write_raw_data OK */ +#define CSTATE_WRCOEFS 103 /* jpeg_write_coefficients done */ +#define DSTATE_START 200 /* after create_decompress */ +#define DSTATE_INHEADER 201 /* reading header markers, no SOS yet */ +#define DSTATE_READY 202 /* found SOS, ready for start_decompress */ +#define DSTATE_PRELOAD 203 /* reading multiscan file in start_decompress*/ +#define DSTATE_PRESCAN 204 /* performing dummy pass for 2-pass quant */ +#define DSTATE_SCANNING 205 /* start_decompress done, read_scanlines OK */ +#define DSTATE_RAW_OK 206 /* start_decompress done, read_raw_data OK */ +#define DSTATE_BUFIMAGE 207 /* expecting jpeg_start_output */ +#define DSTATE_BUFPOST 208 /* looking for SOS/EOI in jpeg_finish_output */ +#define DSTATE_RDCOEFS 209 /* reading file in jpeg_read_coefficients */ +#define DSTATE_STOPPING 210 /* looking for EOI in jpeg_finish_decompress */ + + +/* Declarations for compression modules */ + +/* Master control module */ +struct jpeg_comp_master { + JMETHOD(void, prepare_for_pass, (j_compress_ptr cinfo)); + JMETHOD(void, pass_startup, (j_compress_ptr cinfo)); + JMETHOD(void, finish_pass, (j_compress_ptr cinfo)); + + /* State variables made visible to other modules */ + int call_pass_startup; /* True if pass_startup must be called */ + int is_last_pass; /* True during last pass */ +}; + +/* Main buffer control (downsampled-data buffer) */ +struct jpeg_c_main_controller { + JMETHOD(void, start_pass, (j_compress_ptr cinfo, J_BUF_MODE pass_mode)); + JMETHOD(void, process_data, (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JDIMENSION *in_row_ctr, + JDIMENSION in_rows_avail)); +}; + +/* Compression preprocessing (downsampling input buffer control) */ +struct jpeg_c_prep_controller { + JMETHOD(void, start_pass, (j_compress_ptr cinfo, J_BUF_MODE pass_mode)); + JMETHOD(void, pre_process_data, (j_compress_ptr cinfo, + JSAMPARRAY input_buf, + JDIMENSION *in_row_ctr, + JDIMENSION in_rows_avail, + JSAMPIMAGE output_buf, + JDIMENSION *out_row_group_ctr, + JDIMENSION out_row_groups_avail)); +}; + +/* Coefficient buffer control */ +struct jpeg_c_coef_controller { + JMETHOD(void, start_pass, (j_compress_ptr cinfo, J_BUF_MODE pass_mode)); + JMETHOD(int, compress_data, (j_compress_ptr cinfo, + JSAMPIMAGE input_buf)); +}; + +/* Colorspace conversion */ +struct jpeg_color_converter { + JMETHOD(void, start_pass, (j_compress_ptr cinfo)); + JMETHOD(void, color_convert, (j_compress_ptr cinfo, + JSAMPARRAY input_buf, JSAMPIMAGE output_buf, + JDIMENSION output_row, int num_rows)); +}; + +/* Downsampling */ +struct jpeg_downsampler { + JMETHOD(void, start_pass, (j_compress_ptr cinfo)); + JMETHOD(void, downsample, (j_compress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION in_row_index, + JSAMPIMAGE output_buf, + JDIMENSION out_row_group_index)); + + int need_context_rows; /* TRUE if need rows above & below */ +}; + +/* Forward DCT (also controls coefficient quantization) */ +struct jpeg_forward_dct { + JMETHOD(void, start_pass, (j_compress_ptr cinfo)); + /* perhaps this should be an array??? */ + JMETHOD(void, forward_DCT, (j_compress_ptr cinfo, + jpeg_component_info * compptr, + JSAMPARRAY sample_data, JBLOCKROW coef_blocks, + JDIMENSION start_row, JDIMENSION start_col, + JDIMENSION num_blocks)); +}; + +/* Entropy encoding */ +struct jpeg_entropy_encoder { + JMETHOD(void, start_pass, (j_compress_ptr cinfo, int gather_statistics)); + JMETHOD(int, encode_mcu, (j_compress_ptr cinfo, JBLOCKROW *MCU_data)); + JMETHOD(void, finish_pass, (j_compress_ptr cinfo)); +}; + +/* Marker writing */ +struct jpeg_marker_writer { + JMETHOD(void, write_file_header, (j_compress_ptr cinfo)); + JMETHOD(void, write_frame_header, (j_compress_ptr cinfo)); + JMETHOD(void, write_scan_header, (j_compress_ptr cinfo)); + JMETHOD(void, write_file_trailer, (j_compress_ptr cinfo)); + JMETHOD(void, write_tables_only, (j_compress_ptr cinfo)); + /* These routines are exported to allow insertion of extra markers */ + /* Probably only COM and APPn markers should be written this way */ + JMETHOD(void, write_marker_header, (j_compress_ptr cinfo, int marker, + unsigned int datalen)); + JMETHOD(void, write_marker_byte, (j_compress_ptr cinfo, int val)); +}; + + +/* Declarations for decompression modules */ + +/* Master control module */ +struct jpeg_decomp_master { + JMETHOD(void, prepare_for_output_pass, (j_decompress_ptr cinfo)); + JMETHOD(void, finish_output_pass, (j_decompress_ptr cinfo)); + + /* State variables made visible to other modules */ + int is_dummy_pass; /* True during 1st pass for 2-pass quant */ +}; + +/* Input control module */ +struct jpeg_input_controller { + JMETHOD(int, consume_input, (j_decompress_ptr cinfo)); + JMETHOD(void, reset_input_controller, (j_decompress_ptr cinfo)); + JMETHOD(void, start_input_pass, (j_decompress_ptr cinfo)); + JMETHOD(void, finish_input_pass, (j_decompress_ptr cinfo)); + + /* State variables made visible to other modules */ + int has_multiple_scans; /* True if file has multiple scans */ + int eoi_reached; /* True when EOI has been consumed */ +}; + +/* Main buffer control (downsampled-data buffer) */ +struct jpeg_d_main_controller { + JMETHOD(void, start_pass, (j_decompress_ptr cinfo, J_BUF_MODE pass_mode)); + JMETHOD(void, process_data, (j_decompress_ptr cinfo, + JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail)); +}; + +/* Coefficient buffer control */ +struct jpeg_d_coef_controller { + JMETHOD(void, start_input_pass, (j_decompress_ptr cinfo)); + JMETHOD(int, consume_data, (j_decompress_ptr cinfo)); + JMETHOD(void, start_output_pass, (j_decompress_ptr cinfo)); + JMETHOD(int, decompress_data, (j_decompress_ptr cinfo, + JSAMPIMAGE output_buf)); + /* Pointer to array of coefficient virtual arrays, or NULL if none */ + jvirt_barray_ptr *coef_arrays; +}; + +/* Decompression postprocessing (color quantization buffer control) */ +struct jpeg_d_post_controller { + JMETHOD(void, start_pass, (j_decompress_ptr cinfo, J_BUF_MODE pass_mode)); + JMETHOD(void, post_process_data, (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, + JDIMENSION *in_row_group_ctr, + JDIMENSION in_row_groups_avail, + JSAMPARRAY output_buf, + JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail)); +}; + +/* Marker reading & parsing */ +struct jpeg_marker_reader { + JMETHOD(void, reset_marker_reader, (j_decompress_ptr cinfo)); + /* Read markers until SOS or EOI. + * Returns same codes as are defined for jpeg_consume_input: + * JPEG_SUSPENDED, JPEG_REACHED_SOS, or JPEG_REACHED_EOI. + */ + JMETHOD(int, read_markers, (j_decompress_ptr cinfo)); + /* Read a restart marker --- exported for use by entropy decoder only */ + jpeg_marker_parser_method read_restart_marker; + + /* State of marker reader --- nominally internal, but applications + * supplying COM or APPn handlers might like to know the state. + */ + int saw_SOI; /* found SOI? */ + int saw_SOF; /* found SOF? */ + int next_restart_num; /* next restart number expected (0-7) */ + unsigned int discarded_bytes; /* # of bytes skipped looking for a marker */ +}; + +/* Entropy decoding */ +struct jpeg_entropy_decoder { + JMETHOD(void, start_pass, (j_decompress_ptr cinfo)); + JMETHOD(int, decode_mcu, (j_decompress_ptr cinfo, + JBLOCKROW *MCU_data)); + + /* This is here to share code between baseline and progressive decoders; */ + /* other modules probably should not use it */ + int insufficient_data; /* set TRUE after emitting warning */ +}; + +/* Inverse DCT (also performs dequantization) */ +typedef JMETHOD(void, inverse_DCT_method_ptr, + (j_decompress_ptr cinfo, jpeg_component_info * compptr, + JCOEFPTR coef_block, + JSAMPARRAY output_buf, JDIMENSION output_col)); + +struct jpeg_inverse_dct { + JMETHOD(void, start_pass, (j_decompress_ptr cinfo)); + /* It is useful to allow each component to have a separate IDCT method. */ + inverse_DCT_method_ptr inverse_DCT[MAX_COMPONENTS]; +}; + +/* Upsampling (note that upsampler must also call color converter) */ +struct jpeg_upsampler { + JMETHOD(void, start_pass, (j_decompress_ptr cinfo)); + JMETHOD(void, upsample, (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, + JDIMENSION *in_row_group_ctr, + JDIMENSION in_row_groups_avail, + JSAMPARRAY output_buf, + JDIMENSION *out_row_ctr, + JDIMENSION out_rows_avail)); + + int need_context_rows; /* TRUE if need rows above & below */ +}; + +/* Colorspace conversion */ +struct jpeg_color_deconverter { + JMETHOD(void, start_pass, (j_decompress_ptr cinfo)); + JMETHOD(void, color_convert, (j_decompress_ptr cinfo, + JSAMPIMAGE input_buf, JDIMENSION input_row, + JSAMPARRAY output_buf, int num_rows)); +}; + +/* Color quantization or color precision reduction */ +struct jpeg_color_quantizer { + JMETHOD(void, start_pass, (j_decompress_ptr cinfo, int is_pre_scan)); + JMETHOD(void, color_quantize, (j_decompress_ptr cinfo, + JSAMPARRAY input_buf, JSAMPARRAY output_buf, + int num_rows)); + JMETHOD(void, finish_pass, (j_decompress_ptr cinfo)); + JMETHOD(void, new_color_map, (j_decompress_ptr cinfo)); +}; + + +/* Miscellaneous useful macros */ + +#undef MAX +#define MAX(a,b) ((a) > (b) ? (a) : (b)) +#undef MIN +#define MIN(a,b) ((a) < (b) ? (a) : (b)) + + +/* We assume that right shift corresponds to signed division by 2 with + * rounding towards minus infinity. This is correct for typical "arithmetic + * shift" instructions that shift in copies of the sign bit. But some + * C compilers implement >> with an unsigned shift. For these machines you + * must define RIGHT_SHIFT_IS_UNSIGNED. + * RIGHT_SHIFT provides a proper signed right shift of an long quantity. + * It is only applied with constant shift counts. SHIFT_TEMPS must be + * included in the variables of any routine using RIGHT_SHIFT. + */ + +#ifdef RIGHT_SHIFT_IS_UNSIGNED +#define SHIFT_TEMPS long shift_temp; +#define RIGHT_SHIFT(x,shft) \ + ((shift_temp = (x)) < 0 ? \ + (shift_temp >> (shft)) | ((~((long) 0)) << (32-(shft))) : \ + (shift_temp >> (shft))) +#else +#define SHIFT_TEMPS +#define RIGHT_SHIFT(x,shft) ((x) >> (shft)) +#endif + + +/* Short forms of external names for systems with brain-damaged linkers. */ + +#ifdef NEED_SHORT_EXTERNAL_NAMES +#define jinit_compress_master jICompress +#define jinit_c_master_control jICMaster +#define jinit_c_main_controller jICMainC +#define jinit_c_prep_controller jICPrepC +#define jinit_c_coef_controller jICCoefC +#define jinit_color_converter jICColor +#define jinit_downsampler jIDownsampler +#define jinit_forward_dct jIFDCT +#define jinit_huff_encoder jIHEncoder +#define jinit_phuff_encoder jIPHEncoder +#define jinit_marker_writer jIMWriter +#define jinit_master_decompress jIDMaster +#define jinit_d_main_controller jIDMainC +#define jinit_d_coef_controller jIDCoefC +#define jinit_d_post_controller jIDPostC +#define jinit_input_controller jIInCtlr +#define jinit_marker_reader jIMReader +#define jinit_huff_decoder jIHDecoder +#define jinit_phuff_decoder jIPHDecoder +#define jinit_inverse_dct jIIDCT +#define jinit_upsampler jIUpsampler +#define jinit_color_deconverter jIDColor +#define jinit_1pass_quantizer jI1Quant +#define jinit_2pass_quantizer jI2Quant +#define jinit_merged_upsampler jIMUpsampler +#define jinit_memory_mgr jIMemMgr +#define jdiv_round_up jDivRound +#define jround_up jRound +#define jcopy_sample_rows jCopySamples +#define jcopy_block_row jCopyBlocks +#define jzero_far jZeroFar +#define jpeg_zigzag_order jZIGTable +#define jpeg_natural_order jZAGTable +#endif /* NEED_SHORT_EXTERNAL_NAMES */ + + +/* Compression module initialization routines */ +EXTERN(void) jinit_compress_master JPP((j_compress_ptr cinfo)); +EXTERN(void) jinit_c_master_control JPP((j_compress_ptr cinfo, + int transcode_only)); +EXTERN(void) jinit_c_main_controller JPP((j_compress_ptr cinfo, + int need_full_buffer)); +EXTERN(void) jinit_c_prep_controller JPP((j_compress_ptr cinfo, + int need_full_buffer)); +EXTERN(void) jinit_c_coef_controller JPP((j_compress_ptr cinfo, + int need_full_buffer)); +EXTERN(void) jinit_color_converter JPP((j_compress_ptr cinfo)); +EXTERN(void) jinit_downsampler JPP((j_compress_ptr cinfo)); +EXTERN(void) jinit_forward_dct JPP((j_compress_ptr cinfo)); +EXTERN(void) jinit_huff_encoder JPP((j_compress_ptr cinfo)); +EXTERN(void) jinit_phuff_encoder JPP((j_compress_ptr cinfo)); +EXTERN(void) jinit_marker_writer JPP((j_compress_ptr cinfo)); +/* Decompression module initialization routines */ +EXTERN(void) jinit_master_decompress JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_d_main_controller JPP((j_decompress_ptr cinfo, + int need_full_buffer)); +EXTERN(void) jinit_d_coef_controller JPP((j_decompress_ptr cinfo, + int need_full_buffer)); +EXTERN(void) jinit_d_post_controller JPP((j_decompress_ptr cinfo, + int need_full_buffer)); +EXTERN(void) jinit_input_controller JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_marker_reader JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_huff_decoder JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_phuff_decoder JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_inverse_dct JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_upsampler JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_color_deconverter JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_1pass_quantizer JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_2pass_quantizer JPP((j_decompress_ptr cinfo)); +EXTERN(void) jinit_merged_upsampler JPP((j_decompress_ptr cinfo)); +/* Memory manager initialization */ +EXTERN(void) jinit_memory_mgr JPP((j_common_ptr cinfo)); + +/* Utility routines in jutils.c */ +EXTERN(long) jdiv_round_up JPP((long a, long b)); +EXTERN(long) jround_up JPP((long a, long b)); +EXTERN(void) jcopy_sample_rows JPP((JSAMPARRAY input_array, int source_row, + JSAMPARRAY output_array, int dest_row, + int num_rows, JDIMENSION num_cols)); +EXTERN(void) jcopy_block_row JPP((JBLOCKROW input_row, JBLOCKROW output_row, + JDIMENSION num_blocks)); +EXTERN(void) jzero_far JPP((void FAR * target, size_t bytestozero)); +/* Constant tables in jutils.c */ +#if 0 /* This table is not actually needed in v6a */ +extern const int jpeg_zigzag_order[]; /* natural coef order to zigzag order */ +#endif +extern const int jpeg_natural_order[]; /* zigzag coef order to natural order */ + +/* Suppress undefined-structure complaints if necessary. */ + +#ifdef INCOMPLETE_TYPES_BROKEN +#ifndef AM_MEMORY_MANAGER /* only jmemmgr.c defines these */ +struct jvirt_sarray_control { long dummy; }; +struct jvirt_barray_control { long dummy; }; +#endif +#endif /* INCOMPLETE_TYPES_BROKEN */ diff --git a/ml/dlib/dlib/external/libjpeg/jpeglib.h b/ml/dlib/dlib/external/libjpeg/jpeglib.h new file mode 100644 index 000000000..e611602d2 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jpeglib.h @@ -0,0 +1,1096 @@ +/* + * jpeglib.h + * + * Copyright (C) 1991-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file defines the application interface for the JPEG library. + * Most applications using the library need only include this file, + * and perhaps jerror.h if they want to know the exact error codes. + */ + +#ifndef JPEGLIB_H +#define JPEGLIB_H + +/* + * First we include the configuration files that record how this + * installation of the JPEG library is set up. jconfig.h can be + * generated automatically for many systems. jmorecfg.h contains + * manual configuration options that most people need not worry about. + */ + +#ifndef JCONFIG_INCLUDED /* in case jinclude.h already did */ +#include "jconfig.h" /* widely used configuration options */ +#endif +#include "jmorecfg.h" /* seldom changed options */ + + +/* Version ID for the JPEG library. + * Might be useful for tests like "#if JPEG_LIB_VERSION >= 60". + */ + +#define JPEG_LIB_VERSION 62 /* Version 6b */ + + +/* Various constants determining the sizes of things. + * All of these are specified by the JPEG standard, so don't change them + * if you want to be compatible. + */ + +#define DCTSIZE 8 /* The basic DCT block is 8x8 samples */ +#define DCTSIZE2 64 /* DCTSIZE squared; # of elements in a block */ +#define NUM_QUANT_TBLS 4 /* Quantization tables are numbered 0..3 */ +#define NUM_HUFF_TBLS 4 /* Huffman tables are numbered 0..3 */ +#define NUM_ARITH_TBLS 16 /* Arith-coding tables are numbered 0..15 */ +#define MAX_COMPS_IN_SCAN 4 /* JPEG limit on # of components in one scan */ +#define MAX_SAMP_FACTOR 4 /* JPEG limit on sampling factors */ +/* Unfortunately, some bozo at Adobe saw no reason to be bound by the standard; + * the PostScript DCT filter can emit files with many more than 10 blocks/MCU. + * If you happen to run across such a file, you can up D_MAX_BLOCKS_IN_MCU + * to handle it. We even let you do this from the jconfig.h file. However, + * we strongly discourage changing C_MAX_BLOCKS_IN_MCU; just because Adobe + * sometimes emits noncompliant files doesn't mean you should too. + */ +#define C_MAX_BLOCKS_IN_MCU 10 /* compressor's limit on blocks per MCU */ +#ifndef D_MAX_BLOCKS_IN_MCU +#define D_MAX_BLOCKS_IN_MCU 10 /* decompressor's limit on blocks per MCU */ +#endif + + +/* Data structures for images (arrays of samples and of DCT coefficients). + * On 80x86 machines, the image arrays are too big for near pointers, + * but the pointer arrays can fit in near memory. + */ + +typedef JSAMPLE FAR *JSAMPROW; /* ptr to one image row of pixel samples. */ +typedef JSAMPROW *JSAMPARRAY; /* ptr to some rows (a 2-D sample array) */ +typedef JSAMPARRAY *JSAMPIMAGE; /* a 3-D sample array: top index is color */ + +typedef JCOEF JBLOCK[DCTSIZE2]; /* one block of coefficients */ +typedef JBLOCK FAR *JBLOCKROW; /* pointer to one row of coefficient blocks */ +typedef JBLOCKROW *JBLOCKARRAY; /* a 2-D array of coefficient blocks */ +typedef JBLOCKARRAY *JBLOCKIMAGE; /* a 3-D array of coefficient blocks */ + +typedef JCOEF FAR *JCOEFPTR; /* useful in a couple of places */ + + +/* Types for JPEG compression parameters and working tables. */ + + +/* DCT coefficient quantization tables. */ + +typedef struct { + /* This array gives the coefficient quantizers in natural array order + * (not the zigzag order in which they are stored in a JPEG DQT marker). + * CAUTION: IJG versions prior to v6a kept this array in zigzag order. + */ + unsigned short quantval[DCTSIZE2]; /* quantization step for each coefficient */ + /* This field is used only during compression. It's initialized FALSE when + * the table is created, and set TRUE when it's been output to the file. + * You could suppress output of a table by setting this to TRUE. + * (See jpeg_suppress_tables for an example.) + */ + int sent_table; /* TRUE when table has been output */ +} JQUANT_TBL; + + +/* Huffman coding tables. */ + +typedef struct { + /* These two fields directly represent the contents of a JPEG DHT marker */ + unsigned char bits[17]; /* bits[k] = # of symbols with codes of */ + /* length k bits; bits[0] is unused */ + unsigned char huffval[256]; /* The symbols, in order of incr code length */ + /* This field is used only during compression. It's initialized FALSE when + * the table is created, and set TRUE when it's been output to the file. + * You could suppress output of a table by setting this to TRUE. + * (See jpeg_suppress_tables for an example.) + */ + int sent_table; /* TRUE when table has been output */ +} JHUFF_TBL; + + +/* Basic info about one component (color channel). */ + +typedef struct { + /* These values are fixed over the whole image. */ + /* For compression, they must be supplied by parameter setup; */ + /* for decompression, they are read from the SOF marker. */ + int component_id; /* identifier for this component (0..255) */ + int component_index; /* its index in SOF or cinfo->comp_info[] */ + int h_samp_factor; /* horizontal sampling factor (1..4) */ + int v_samp_factor; /* vertical sampling factor (1..4) */ + int quant_tbl_no; /* quantization table selector (0..3) */ + /* These values may vary between scans. */ + /* For compression, they must be supplied by parameter setup; */ + /* for decompression, they are read from the SOS marker. */ + /* The decompressor output side may not use these variables. */ + int dc_tbl_no; /* DC entropy table selector (0..3) */ + int ac_tbl_no; /* AC entropy table selector (0..3) */ + + /* Remaining fields should be treated as private by applications. */ + + /* These values are computed during compression or decompression startup: */ + /* Component's size in DCT blocks. + * Any dummy blocks added to complete an MCU are not counted; therefore + * these values do not depend on whether a scan is interleaved or not. + */ + JDIMENSION width_in_blocks; + JDIMENSION height_in_blocks; + /* Size of a DCT block in samples. Always DCTSIZE for compression. + * For decompression this is the size of the output from one DCT block, + * reflecting any scaling we choose to apply during the IDCT step. + * Values of 1,2,4,8 are likely to be supported. Note that different + * components may receive different IDCT scalings. + */ + int DCT_scaled_size; + /* The downsampled dimensions are the component's actual, unpadded number + * of samples at the main buffer (preprocessing/compression interface), thus + * downsampled_width = ceil(image_width * Hi/Hmax) + * and similarly for height. For decompression, IDCT scaling is included, so + * downsampled_width = ceil(image_width * Hi/Hmax * DCT_scaled_size/DCTSIZE) + */ + JDIMENSION downsampled_width; /* actual width in samples */ + JDIMENSION downsampled_height; /* actual height in samples */ + /* This flag is used only for decompression. In cases where some of the + * components will be ignored (eg grayscale output from YCbCr image), + * we can skip most computations for the unused components. + */ + int component_needed; /* do we need the value of this component? */ + + /* These values are computed before starting a scan of the component. */ + /* The decompressor output side may not use these variables. */ + int MCU_width; /* number of blocks per MCU, horizontally */ + int MCU_height; /* number of blocks per MCU, vertically */ + int MCU_blocks; /* MCU_width * MCU_height */ + int MCU_sample_width; /* MCU width in samples, MCU_width*DCT_scaled_size */ + int last_col_width; /* # of non-dummy blocks across in last MCU */ + int last_row_height; /* # of non-dummy blocks down in last MCU */ + + /* Saved quantization table for component; NULL if none yet saved. + * See jdinput.c comments about the need for this information. + * This field is currently used only for decompression. + */ + JQUANT_TBL * quant_table; + + /* Private per-component storage for DCT or IDCT subsystem. */ + void * dct_table; +} jpeg_component_info; + + +/* The script for encoding a multiple-scan file is an array of these: */ + +typedef struct { + int comps_in_scan; /* number of components encoded in this scan */ + int component_index[MAX_COMPS_IN_SCAN]; /* their SOF/comp_info[] indexes */ + int Ss, Se; /* progressive JPEG spectral selection parms */ + int Ah, Al; /* progressive JPEG successive approx. parms */ +} jpeg_scan_info; + +/* The decompressor can save APPn and COM markers in a list of these: */ + +typedef struct jpeg_marker_struct FAR * jpeg_saved_marker_ptr; + +struct jpeg_marker_struct { + jpeg_saved_marker_ptr next; /* next in list, or NULL */ + unsigned char marker; /* marker code: JPEG_COM, or JPEG_APP0+n */ + unsigned int original_length; /* # bytes of data in the file */ + unsigned int data_length; /* # bytes of data saved at data[] */ + JOCTET FAR * data; /* the data contained in the marker */ + /* the marker length word is not counted in data_length or original_length */ +}; + +/* Known color spaces. */ + +typedef enum { + JCS_UNKNOWN, /* error/unspecified */ + JCS_GRAYSCALE, /* monochrome */ + JCS_RGB, /* red/green/blue */ + JCS_YCbCr, /* Y/Cb/Cr (also known as YUV) */ + JCS_CMYK, /* C/M/Y/K */ + JCS_YCCK /* Y/Cb/Cr/K */ +} J_COLOR_SPACE; + +/* DCT/IDCT algorithm options. */ + +typedef enum { + JDCT_ISLOW, /* slow but accurate integer algorithm */ + JDCT_IFAST, /* faster, less accurate integer method */ + JDCT_FLOAT /* floating-point: accurate, fast on fast HW */ +} J_DCT_METHOD; + +#ifndef JDCT_DEFAULT /* may be overridden in jconfig.h */ +#define JDCT_DEFAULT JDCT_ISLOW +#endif +#ifndef JDCT_FASTEST /* may be overridden in jconfig.h */ +#define JDCT_FASTEST JDCT_IFAST +#endif + +/* Dithering options for decompression. */ + +typedef enum { + JDITHER_NONE, /* no dithering */ + JDITHER_ORDERED, /* simple ordered dither */ + JDITHER_FS /* Floyd-Steinberg error diffusion dither */ +} J_DITHER_MODE; + + +/* Common fields between JPEG compression and decompression master structs. */ + +#define jpeg_common_fields \ + struct jpeg_error_mgr * err; /* Error handler module */\ + struct jpeg_memory_mgr * mem; /* Memory manager module */\ + struct jpeg_progress_mgr * progress; /* Progress monitor, or NULL if none */\ + void * client_data; /* Available for use by application */\ + int is_decompressor; /* So common code can tell which is which */\ + int global_state /* For checking call sequence validity */ + +/* Routines that are to be used by both halves of the library are declared + * to receive a pointer to this structure. There are no actual instances of + * jpeg_common_struct, only of jpeg_compress_struct and jpeg_decompress_struct. + */ +struct jpeg_common_struct { + jpeg_common_fields; /* Fields common to both master struct types */ + /* Additional fields follow in an actual jpeg_compress_struct or + * jpeg_decompress_struct. All three structs must agree on these + * initial fields! (This would be a lot cleaner in C++.) + */ +}; + +typedef struct jpeg_common_struct * j_common_ptr; +typedef struct jpeg_compress_struct * j_compress_ptr; +typedef struct jpeg_decompress_struct * j_decompress_ptr; + + +/* Master record for a compression instance */ + +struct jpeg_compress_struct { + jpeg_common_fields; /* Fields shared with jpeg_decompress_struct */ + + /* Destination for compressed data */ + struct jpeg_destination_mgr * dest; + + /* Description of source image --- these fields must be filled in by + * outer application before starting compression. in_color_space must + * be correct before you can even call jpeg_set_defaults(). + */ + + JDIMENSION image_width; /* input image width */ + JDIMENSION image_height; /* input image height */ + int input_components; /* # of color components in input image */ + J_COLOR_SPACE in_color_space; /* colorspace of input image */ + + double input_gamma; /* image gamma of input image */ + + /* Compression parameters --- these fields must be set before calling + * jpeg_start_compress(). We recommend calling jpeg_set_defaults() to + * initialize everything to reasonable defaults, then changing anything + * the application specifically wants to change. That way you won't get + * burnt when new parameters are added. Also note that there are several + * helper routines to simplify changing parameters. + */ + + int data_precision; /* bits of precision in image data */ + + int num_components; /* # of color components in JPEG image */ + J_COLOR_SPACE jpeg_color_space; /* colorspace of JPEG image */ + + jpeg_component_info * comp_info; + /* comp_info[i] describes component that appears i'th in SOF */ + + JQUANT_TBL * quant_tbl_ptrs[NUM_QUANT_TBLS]; + /* ptrs to coefficient quantization tables, or NULL if not defined */ + + JHUFF_TBL * dc_huff_tbl_ptrs[NUM_HUFF_TBLS]; + JHUFF_TBL * ac_huff_tbl_ptrs[NUM_HUFF_TBLS]; + /* ptrs to Huffman coding tables, or NULL if not defined */ + + unsigned char arith_dc_L[NUM_ARITH_TBLS]; /* L values for DC arith-coding tables */ + unsigned char arith_dc_U[NUM_ARITH_TBLS]; /* U values for DC arith-coding tables */ + unsigned char arith_ac_K[NUM_ARITH_TBLS]; /* Kx values for AC arith-coding tables */ + + int num_scans; /* # of entries in scan_info array */ + const jpeg_scan_info * scan_info; /* script for multi-scan file, or NULL */ + /* The default value of scan_info is NULL, which causes a single-scan + * sequential JPEG file to be emitted. To create a multi-scan file, + * set num_scans and scan_info to point to an array of scan definitions. + */ + + int raw_data_in; /* TRUE=caller supplies downsampled data */ + int arith_code; /* TRUE=arithmetic coding, FALSE=Huffman */ + int optimize_coding; /* TRUE=optimize entropy encoding parms */ + int CCIR601_sampling; /* TRUE=first samples are cosited */ + int smoothing_factor; /* 1..100, or 0 for no input smoothing */ + J_DCT_METHOD dct_method; /* DCT algorithm selector */ + + /* The restart interval can be specified in absolute MCUs by setting + * restart_interval, or in MCU rows by setting restart_in_rows + * (in which case the correct restart_interval will be figured + * for each scan). + */ + unsigned int restart_interval; /* MCUs per restart, or 0 for no restart */ + int restart_in_rows; /* if > 0, MCU rows per restart interval */ + + /* Parameters controlling emission of special markers. */ + + int write_JFIF_header; /* should a JFIF marker be written? */ + unsigned char JFIF_major_version; /* What to write for the JFIF version number */ + unsigned char JFIF_minor_version; + /* These three values are not used by the JPEG code, merely copied */ + /* into the JFIF APP0 marker. density_unit can be 0 for unknown, */ + /* 1 for dots/inch, or 2 for dots/cm. Note that the pixel aspect */ + /* ratio is defined by X_density/Y_density even when density_unit=0. */ + unsigned char density_unit; /* JFIF code for pixel size units */ + unsigned short X_density; /* Horizontal pixel density */ + unsigned short Y_density; /* Vertical pixel density */ + int write_Adobe_marker; /* should an Adobe marker be written? */ + + /* State variable: index of next scanline to be written to + * jpeg_write_scanlines(). Application may use this to control its + * processing loop, e.g., "while (next_scanline < image_height)". + */ + + JDIMENSION next_scanline; /* 0 .. image_height-1 */ + + /* Remaining fields are known throughout compressor, but generally + * should not be touched by a surrounding application. + */ + + /* + * These fields are computed during compression startup + */ + int progressive_mode; /* TRUE if scan script uses progressive mode */ + int max_h_samp_factor; /* largest h_samp_factor */ + int max_v_samp_factor; /* largest v_samp_factor */ + + JDIMENSION total_iMCU_rows; /* # of iMCU rows to be input to coef ctlr */ + /* The coefficient controller receives data in units of MCU rows as defined + * for fully interleaved scans (whether the JPEG file is interleaved or not). + * There are v_samp_factor * DCTSIZE sample rows of each component in an + * "iMCU" (interleaved MCU) row. + */ + + /* + * These fields are valid during any one scan. + * They describe the components and MCUs actually appearing in the scan. + */ + int comps_in_scan; /* # of JPEG components in this scan */ + jpeg_component_info * cur_comp_info[MAX_COMPS_IN_SCAN]; + /* *cur_comp_info[i] describes component that appears i'th in SOS */ + + JDIMENSION MCUs_per_row; /* # of MCUs across the image */ + JDIMENSION MCU_rows_in_scan; /* # of MCU rows in the image */ + + int blocks_in_MCU; /* # of DCT blocks per MCU */ + int MCU_membership[C_MAX_BLOCKS_IN_MCU]; + /* MCU_membership[i] is index in cur_comp_info of component owning */ + /* i'th block in an MCU */ + + int Ss, Se, Ah, Al; /* progressive JPEG parameters for scan */ + + /* + * Links to compression subobjects (methods and private variables of modules) + */ + struct jpeg_comp_master * master; + struct jpeg_c_main_controller * main; + struct jpeg_c_prep_controller * prep; + struct jpeg_c_coef_controller * coef; + struct jpeg_marker_writer * marker; + struct jpeg_color_converter * cconvert; + struct jpeg_downsampler * downsample; + struct jpeg_forward_dct * fdct; + struct jpeg_entropy_encoder * entropy; + jpeg_scan_info * script_space; /* workspace for jpeg_simple_progression */ + int script_space_size; +}; + + +/* Master record for a decompression instance */ + +struct jpeg_decompress_struct { + jpeg_common_fields; /* Fields shared with jpeg_compress_struct */ + + /* Source of compressed data */ + struct jpeg_source_mgr * src; + + /* Basic description of image --- filled in by jpeg_read_header(). */ + /* Application may inspect these values to decide how to process image. */ + + JDIMENSION image_width; /* nominal image width (from SOF marker) */ + JDIMENSION image_height; /* nominal image height */ + int num_components; /* # of color components in JPEG image */ + J_COLOR_SPACE jpeg_color_space; /* colorspace of JPEG image */ + + /* Decompression processing parameters --- these fields must be set before + * calling jpeg_start_decompress(). Note that jpeg_read_header() initializes + * them to default values. + */ + + J_COLOR_SPACE out_color_space; /* colorspace for output */ + + unsigned int scale_num, scale_denom; /* fraction by which to scale image */ + + double output_gamma; /* image gamma wanted in output */ + + int buffered_image; /* TRUE=multiple output passes */ + int raw_data_out; /* TRUE=downsampled data wanted */ + + J_DCT_METHOD dct_method; /* IDCT algorithm selector */ + int do_fancy_upsampling; /* TRUE=apply fancy upsampling */ + int do_block_smoothing; /* TRUE=apply interblock smoothing */ + + int quantize_colors; /* TRUE=colormapped output wanted */ + /* the following are ignored if not quantize_colors: */ + J_DITHER_MODE dither_mode; /* type of color dithering to use */ + int two_pass_quantize; /* TRUE=use two-pass color quantization */ + int desired_number_of_colors; /* max # colors to use in created colormap */ + /* these are significant only in buffered-image mode: */ + int enable_1pass_quant; /* enable future use of 1-pass quantizer */ + int enable_external_quant;/* enable future use of external colormap */ + int enable_2pass_quant; /* enable future use of 2-pass quantizer */ + + /* Description of actual output image that will be returned to application. + * These fields are computed by jpeg_start_decompress(). + * You can also use jpeg_calc_output_dimensions() to determine these values + * in advance of calling jpeg_start_decompress(). + */ + + JDIMENSION output_width; /* scaled image width */ + JDIMENSION output_height; /* scaled image height */ + int out_color_components; /* # of color components in out_color_space */ + int output_components; /* # of color components returned */ + /* output_components is 1 (a colormap index) when quantizing colors; + * otherwise it equals out_color_components. + */ + int rec_outbuf_height; /* min recommended height of scanline buffer */ + /* If the buffer passed to jpeg_read_scanlines() is less than this many rows + * high, space and time will be wasted due to unnecessary data copying. + * Usually rec_outbuf_height will be 1 or 2, at most 4. + */ + + /* When quantizing colors, the output colormap is described by these fields. + * The application can supply a colormap by setting colormap non-NULL before + * calling jpeg_start_decompress; otherwise a colormap is created during + * jpeg_start_decompress or jpeg_start_output. + * The map has out_color_components rows and actual_number_of_colors columns. + */ + int actual_number_of_colors; /* number of entries in use */ + JSAMPARRAY colormap; /* The color map as a 2-D pixel array */ + + /* State variables: these variables indicate the progress of decompression. + * The application may examine these but must not modify them. + */ + + /* Row index of next scanline to be read from jpeg_read_scanlines(). + * Application may use this to control its processing loop, e.g., + * "while (output_scanline < output_height)". + */ + JDIMENSION output_scanline; /* 0 .. output_height-1 */ + + /* Current input scan number and number of iMCU rows completed in scan. + * These indicate the progress of the decompressor input side. + */ + int input_scan_number; /* Number of SOS markers seen so far */ + JDIMENSION input_iMCU_row; /* Number of iMCU rows completed */ + + /* The "output scan number" is the notional scan being displayed by the + * output side. The decompressor will not allow output scan/row number + * to get ahead of input scan/row, but it can fall arbitrarily far behind. + */ + int output_scan_number; /* Nominal scan number being displayed */ + JDIMENSION output_iMCU_row; /* Number of iMCU rows read */ + + /* Current progression status. coef_bits[c][i] indicates the precision + * with which component c's DCT coefficient i (in zigzag order) is known. + * It is -1 when no data has yet been received, otherwise it is the point + * transform (shift) value for the most recent scan of the coefficient + * (thus, 0 at completion of the progression). + * This pointer is NULL when reading a non-progressive file. + */ + int (*coef_bits)[DCTSIZE2]; /* -1 or current Al value for each coef */ + + /* Internal JPEG parameters --- the application usually need not look at + * these fields. Note that the decompressor output side may not use + * any parameters that can change between scans. + */ + + /* Quantization and Huffman tables are carried forward across input + * datastreams when processing abbreviated JPEG datastreams. + */ + + JQUANT_TBL * quant_tbl_ptrs[NUM_QUANT_TBLS]; + /* ptrs to coefficient quantization tables, or NULL if not defined */ + + JHUFF_TBL * dc_huff_tbl_ptrs[NUM_HUFF_TBLS]; + JHUFF_TBL * ac_huff_tbl_ptrs[NUM_HUFF_TBLS]; + /* ptrs to Huffman coding tables, or NULL if not defined */ + + /* These parameters are never carried across datastreams, since they + * are given in SOF/SOS markers or defined to be reset by SOI. + */ + + int data_precision; /* bits of precision in image data */ + + jpeg_component_info * comp_info; + /* comp_info[i] describes component that appears i'th in SOF */ + + int progressive_mode; /* TRUE if SOFn specifies progressive mode */ + int arith_code; /* TRUE=arithmetic coding, FALSE=Huffman */ + + unsigned char arith_dc_L[NUM_ARITH_TBLS]; /* L values for DC arith-coding tables */ + unsigned char arith_dc_U[NUM_ARITH_TBLS]; /* U values for DC arith-coding tables */ + unsigned char arith_ac_K[NUM_ARITH_TBLS]; /* Kx values for AC arith-coding tables */ + + unsigned int restart_interval; /* MCUs per restart interval, or 0 for no restart */ + + /* These fields record data obtained from optional markers recognized by + * the JPEG library. + */ + int saw_JFIF_marker; /* TRUE iff a JFIF APP0 marker was found */ + /* Data copied from JFIF marker; only valid if saw_JFIF_marker is TRUE: */ + unsigned char JFIF_major_version; /* JFIF version number */ + unsigned char JFIF_minor_version; + unsigned char density_unit; /* JFIF code for pixel size units */ + unsigned short X_density; /* Horizontal pixel density */ + unsigned short Y_density; /* Vertical pixel density */ + int saw_Adobe_marker; /* TRUE iff an Adobe APP14 marker was found */ + unsigned char Adobe_transform; /* Color transform code from Adobe marker */ + + int CCIR601_sampling; /* TRUE=first samples are cosited */ + + /* Aside from the specific data retained from APPn markers known to the + * library, the uninterpreted contents of any or all APPn and COM markers + * can be saved in a list for examination by the application. + */ + jpeg_saved_marker_ptr marker_list; /* Head of list of saved markers */ + + /* Remaining fields are known throughout decompressor, but generally + * should not be touched by a surrounding application. + */ + + /* + * These fields are computed during decompression startup + */ + int max_h_samp_factor; /* largest h_samp_factor */ + int max_v_samp_factor; /* largest v_samp_factor */ + + int min_DCT_scaled_size; /* smallest DCT_scaled_size of any component */ + + JDIMENSION total_iMCU_rows; /* # of iMCU rows in image */ + /* The coefficient controller's input and output progress is measured in + * units of "iMCU" (interleaved MCU) rows. These are the same as MCU rows + * in fully interleaved JPEG scans, but are used whether the scan is + * interleaved or not. We define an iMCU row as v_samp_factor DCT block + * rows of each component. Therefore, the IDCT output contains + * v_samp_factor*DCT_scaled_size sample rows of a component per iMCU row. + */ + + JSAMPLE * sample_range_limit; /* table for fast range-limiting */ + + /* + * These fields are valid during any one scan. + * They describe the components and MCUs actually appearing in the scan. + * Note that the decompressor output side must not use these fields. + */ + int comps_in_scan; /* # of JPEG components in this scan */ + jpeg_component_info * cur_comp_info[MAX_COMPS_IN_SCAN]; + /* *cur_comp_info[i] describes component that appears i'th in SOS */ + + JDIMENSION MCUs_per_row; /* # of MCUs across the image */ + JDIMENSION MCU_rows_in_scan; /* # of MCU rows in the image */ + + int blocks_in_MCU; /* # of DCT blocks per MCU */ + int MCU_membership[D_MAX_BLOCKS_IN_MCU]; + /* MCU_membership[i] is index in cur_comp_info of component owning */ + /* i'th block in an MCU */ + + int Ss, Se, Ah, Al; /* progressive JPEG parameters for scan */ + + /* This field is shared between entropy decoder and marker parser. + * It is either zero or the code of a JPEG marker that has been + * read from the data source, but has not yet been processed. + */ + int unread_marker; + + /* + * Links to decompression subobjects (methods, private variables of modules) + */ + struct jpeg_decomp_master * master; + struct jpeg_d_main_controller * main; + struct jpeg_d_coef_controller * coef; + struct jpeg_d_post_controller * post; + struct jpeg_input_controller * inputctl; + struct jpeg_marker_reader * marker; + struct jpeg_entropy_decoder * entropy; + struct jpeg_inverse_dct * idct; + struct jpeg_upsampler * upsample; + struct jpeg_color_deconverter * cconvert; + struct jpeg_color_quantizer * cquantize; +}; + + +/* "Object" declarations for JPEG modules that may be supplied or called + * directly by the surrounding application. + * As with all objects in the JPEG library, these structs only define the + * publicly visible methods and state variables of a module. Additional + * private fields may exist after the public ones. + */ + + +/* Error handler object */ + +struct jpeg_error_mgr { + /* Error exit handler: does not return to caller */ + JMETHOD(void, error_exit, (j_common_ptr cinfo)); + /* Conditionally emit a trace or warning message */ + JMETHOD(void, emit_message, (j_common_ptr cinfo, int msg_level)); + /* Routine that actually outputs a trace or error message */ + JMETHOD(void, output_message, (j_common_ptr cinfo)); + /* Format a message string for the most recent JPEG error or message */ + JMETHOD(void, format_message, (j_common_ptr cinfo, char * buffer)); +#define JMSG_LENGTH_MAX 200 /* recommended size of format_message buffer */ + /* Reset error state variables at start of a new image */ + JMETHOD(void, reset_error_mgr, (j_common_ptr cinfo)); + + /* The message ID code and any parameters are saved here. + * A message can have one string parameter or up to 8 int parameters. + */ + int msg_code; +#define JMSG_STR_PARM_MAX 80 + union { + int i[8]; + char s[JMSG_STR_PARM_MAX]; + } msg_parm; + + /* Standard state variables for error facility */ + + int trace_level; /* max msg_level that will be displayed */ + + /* For recoverable corrupt-data errors, we emit a warning message, + * but keep going unless emit_message chooses to abort. emit_message + * should count warnings in num_warnings. The surrounding application + * can check for bad data by seeing if num_warnings is nonzero at the + * end of processing. + */ + long num_warnings; /* number of corrupt-data warnings */ + + /* These fields point to the table(s) of error message strings. + * An application can change the table pointer to switch to a different + * message list (typically, to change the language in which errors are + * reported). Some applications may wish to add additional error codes + * that will be handled by the JPEG library error mechanism; the second + * table pointer is used for this purpose. + * + * First table includes all errors generated by JPEG library itself. + * Error code 0 is reserved for a "no such error string" message. + */ + const char * const * jpeg_message_table; /* Library errors */ + int last_jpeg_message; /* Table contains strings 0..last_jpeg_message */ + /* Second table can be added by application (see cjpeg/djpeg for example). + * It contains strings numbered first_addon_message..last_addon_message. + */ + const char * const * addon_message_table; /* Non-library errors */ + int first_addon_message; /* code for first string in addon table */ + int last_addon_message; /* code for last string in addon table */ +}; + + +/* Progress monitor object */ + +struct jpeg_progress_mgr { + JMETHOD(void, progress_monitor, (j_common_ptr cinfo)); + + long pass_counter; /* work units completed in this pass */ + long pass_limit; /* total number of work units in this pass */ + int completed_passes; /* passes completed so far */ + int total_passes; /* total number of passes expected */ +}; + + +/* Data destination object for compression */ + +struct jpeg_destination_mgr { + JOCTET * next_output_byte; /* => next byte to write in buffer */ + size_t free_in_buffer; /* # of byte spaces remaining in buffer */ + + JMETHOD(void, init_destination, (j_compress_ptr cinfo)); + JMETHOD(int, empty_output_buffer, (j_compress_ptr cinfo)); + JMETHOD(void, term_destination, (j_compress_ptr cinfo)); +}; + + +/* Data source object for decompression */ + +struct jpeg_source_mgr { + const JOCTET * next_input_byte; /* => next byte to read from buffer */ + size_t bytes_in_buffer; /* # of bytes remaining in buffer */ + + JMETHOD(void, init_source, (j_decompress_ptr cinfo)); + JMETHOD(int, fill_input_buffer, (j_decompress_ptr cinfo)); + JMETHOD(void, skip_input_data, (j_decompress_ptr cinfo, long num_bytes)); + JMETHOD(int, resync_to_restart, (j_decompress_ptr cinfo, int desired)); + JMETHOD(void, term_source, (j_decompress_ptr cinfo)); +}; + + +/* Memory manager object. + * Allocates "small" objects (a few K total), "large" objects (tens of K), + * and "really big" objects (virtual arrays with backing store if needed). + * The memory manager does not allow individual objects to be freed; rather, + * each created object is assigned to a pool, and whole pools can be freed + * at once. This is faster and more convenient than remembering exactly what + * to free, especially where malloc()/free() are not too speedy. + * NB: alloc routines never return NULL. They exit to error_exit if not + * successful. + */ + +#define JPOOL_PERMANENT 0 /* lasts until master record is destroyed */ +#define JPOOL_IMAGE 1 /* lasts until done with image/datastream */ +#define JPOOL_NUMPOOLS 2 + +typedef struct jvirt_sarray_control * jvirt_sarray_ptr; +typedef struct jvirt_barray_control * jvirt_barray_ptr; + + +struct jpeg_memory_mgr { + /* Method pointers */ + JMETHOD(void *, alloc_small, (j_common_ptr cinfo, int pool_id, + size_t sizeofobject)); + JMETHOD(void FAR *, alloc_large, (j_common_ptr cinfo, int pool_id, + size_t sizeofobject)); + JMETHOD(JSAMPARRAY, alloc_sarray, (j_common_ptr cinfo, int pool_id, + JDIMENSION samplesperrow, + JDIMENSION numrows)); + JMETHOD(JBLOCKARRAY, alloc_barray, (j_common_ptr cinfo, int pool_id, + JDIMENSION blocksperrow, + JDIMENSION numrows)); + JMETHOD(jvirt_sarray_ptr, request_virt_sarray, (j_common_ptr cinfo, + int pool_id, + int pre_zero, + JDIMENSION samplesperrow, + JDIMENSION numrows, + JDIMENSION maxaccess)); + JMETHOD(jvirt_barray_ptr, request_virt_barray, (j_common_ptr cinfo, + int pool_id, + int pre_zero, + JDIMENSION blocksperrow, + JDIMENSION numrows, + JDIMENSION maxaccess)); + JMETHOD(void, realize_virt_arrays, (j_common_ptr cinfo)); + JMETHOD(JSAMPARRAY, access_virt_sarray, (j_common_ptr cinfo, + jvirt_sarray_ptr ptr, + JDIMENSION start_row, + JDIMENSION num_rows, + int writable)); + JMETHOD(JBLOCKARRAY, access_virt_barray, (j_common_ptr cinfo, + jvirt_barray_ptr ptr, + JDIMENSION start_row, + JDIMENSION num_rows, + int writable)); + JMETHOD(void, free_pool, (j_common_ptr cinfo, int pool_id)); + JMETHOD(void, self_destruct, (j_common_ptr cinfo)); + + /* Limit on memory allocation for this JPEG object. (Note that this is + * merely advisory, not a guaranteed maximum; it only affects the space + * used for virtual-array buffers.) May be changed by outer application + * after creating the JPEG object. + */ + long max_memory_to_use; + + /* Maximum allocation request accepted by alloc_large. */ + long max_alloc_chunk; +}; + + +/* Routine signature for application-supplied marker processing methods. + * Need not pass marker code since it is stored in cinfo->unread_marker. + */ +typedef JMETHOD(int, jpeg_marker_parser_method, (j_decompress_ptr cinfo)); + + +/* Declarations for routines called by application. + * The JPP macro hides prototype parameters from compilers that can't cope. + * Note JPP requires double parentheses. + */ + +#ifdef HAVE_PROTOTYPES +#define JPP(arglist) arglist +#else +#define JPP(arglist) () +#endif + + +/* Short forms of external names for systems with brain-damaged linkers. + * We shorten external names to be unique in the first six letters, which + * is good enough for all known systems. + * (If your compiler itself needs names to be unique in less than 15 + * characters, you are out of luck. Get a better compiler.) + */ + +#ifdef NEED_SHORT_EXTERNAL_NAMES +#define jpeg_std_error jStdError +#define jpeg_CreateCompress jCreaCompress +#define jpeg_CreateDecompress jCreaDecompress +#define jpeg_destroy_compress jDestCompress +#define jpeg_destroy_decompress jDestDecompress +#define jpeg_stdio_dest jStdDest +#define jpeg_stdio_src jStdSrc +#define jpeg_set_defaults jSetDefaults +#define jpeg_set_colorspace jSetColorspace +#define jpeg_default_colorspace jDefColorspace +#define jpeg_set_quality jSetQuality +#define jpeg_set_linear_quality jSetLQuality +#define jpeg_add_quant_table jAddQuantTable +#define jpeg_quality_scaling jQualityScaling +#define jpeg_simple_progression jSimProgress +#define jpeg_suppress_tables jSuppressTables +#define jpeg_alloc_quant_table jAlcQTable +#define jpeg_alloc_huff_table jAlcHTable +#define jpeg_start_compress jStrtCompress +#define jpeg_write_scanlines jWrtScanlines +#define jpeg_finish_compress jFinCompress +#define jpeg_write_raw_data jWrtRawData +#define jpeg_write_marker jWrtMarker +#define jpeg_write_m_header jWrtMHeader +#define jpeg_write_m_byte jWrtMByte +#define jpeg_write_tables jWrtTables +#define jpeg_read_header jReadHeader +#define jpeg_start_decompress jStrtDecompress +#define jpeg_read_scanlines jReadScanlines +#define jpeg_finish_decompress jFinDecompress +#define jpeg_read_raw_data jReadRawData +#define jpeg_has_multiple_scans jHasMultScn +#define jpeg_start_output jStrtOutput +#define jpeg_finish_output jFinOutput +#define jpeg_input_complete jInComplete +#define jpeg_new_colormap jNewCMap +#define jpeg_consume_input jConsumeInput +#define jpeg_calc_output_dimensions jCalcDimensions +#define jpeg_save_markers jSaveMarkers +#define jpeg_set_marker_processor jSetMarker +#define jpeg_read_coefficients jReadCoefs +#define jpeg_write_coefficients jWrtCoefs +#define jpeg_copy_critical_parameters jCopyCrit +#define jpeg_abort_compress jAbrtCompress +#define jpeg_abort_decompress jAbrtDecompress +#define jpeg_abort jAbort +#define jpeg_destroy jDestroy +#define jpeg_resync_to_restart jResyncRestart +#endif /* NEED_SHORT_EXTERNAL_NAMES */ + + +/* Default error-management setup */ +EXTERN(struct jpeg_error_mgr *) jpeg_std_error + JPP((struct jpeg_error_mgr * err)); + +/* Initialization of JPEG compression objects. + * jpeg_create_compress() and jpeg_create_decompress() are the exported + * names that applications should call. These expand to calls on + * jpeg_CreateCompress and jpeg_CreateDecompress with additional information + * passed for version mismatch checking. + * NB: you must set up the error-manager BEFORE calling jpeg_create_xxx. + */ +#define jpeg_create_compress(cinfo) \ + jpeg_CreateCompress((cinfo), JPEG_LIB_VERSION, \ + (size_t) sizeof(struct jpeg_compress_struct)) +#define jpeg_create_decompress(cinfo) \ + jpeg_CreateDecompress((cinfo), JPEG_LIB_VERSION, \ + (size_t) sizeof(struct jpeg_decompress_struct)) +EXTERN(void) jpeg_CreateCompress JPP((j_compress_ptr cinfo, + int version, size_t structsize)); +EXTERN(void) jpeg_CreateDecompress JPP((j_decompress_ptr cinfo, + int version, size_t structsize)); +/* Destruction of JPEG compression objects */ +EXTERN(void) jpeg_destroy_compress JPP((j_compress_ptr cinfo)); +EXTERN(void) jpeg_destroy_decompress JPP((j_decompress_ptr cinfo)); + +/* Standard data source and destination managers: stdio streams. */ +/* Caller is responsible for opening the file before and closing after. */ +EXTERN(void) jpeg_stdio_dest JPP((j_compress_ptr cinfo, FILE * outfile)); +EXTERN(void) jpeg_stdio_src JPP((j_decompress_ptr cinfo, FILE * infile)); + +/* Default parameter setup for compression */ +EXTERN(void) jpeg_set_defaults JPP((j_compress_ptr cinfo)); +/* Compression parameter setup aids */ +EXTERN(void) jpeg_set_colorspace JPP((j_compress_ptr cinfo, + J_COLOR_SPACE colorspace)); +EXTERN(void) jpeg_default_colorspace JPP((j_compress_ptr cinfo)); +EXTERN(void) jpeg_set_quality JPP((j_compress_ptr cinfo, int quality, + int force_baseline)); +EXTERN(void) jpeg_set_linear_quality JPP((j_compress_ptr cinfo, + int scale_factor, + int force_baseline)); +EXTERN(void) jpeg_add_quant_table JPP((j_compress_ptr cinfo, int which_tbl, + const unsigned int *basic_table, + int scale_factor, + int force_baseline)); +EXTERN(int) jpeg_quality_scaling JPP((int quality)); +EXTERN(void) jpeg_simple_progression JPP((j_compress_ptr cinfo)); +EXTERN(void) jpeg_suppress_tables JPP((j_compress_ptr cinfo, + int suppress)); +EXTERN(JQUANT_TBL *) jpeg_alloc_quant_table JPP((j_common_ptr cinfo)); +EXTERN(JHUFF_TBL *) jpeg_alloc_huff_table JPP((j_common_ptr cinfo)); + +/* Main entry points for compression */ +EXTERN(void) jpeg_start_compress JPP((j_compress_ptr cinfo, + int write_all_tables)); +EXTERN(JDIMENSION) jpeg_write_scanlines JPP((j_compress_ptr cinfo, + JSAMPARRAY scanlines, + JDIMENSION num_lines)); +EXTERN(void) jpeg_finish_compress JPP((j_compress_ptr cinfo)); + +/* Replaces jpeg_write_scanlines when writing raw downsampled data. */ +EXTERN(JDIMENSION) jpeg_write_raw_data JPP((j_compress_ptr cinfo, + JSAMPIMAGE data, + JDIMENSION num_lines)); + +/* Write a special marker. See libjpeg.doc concerning safe usage. */ +EXTERN(void) jpeg_write_marker + JPP((j_compress_ptr cinfo, int marker, + const JOCTET * dataptr, unsigned int datalen)); +/* Same, but piecemeal. */ +EXTERN(void) jpeg_write_m_header + JPP((j_compress_ptr cinfo, int marker, unsigned int datalen)); +EXTERN(void) jpeg_write_m_byte + JPP((j_compress_ptr cinfo, int val)); + +/* Alternate compression function: just write an abbreviated table file */ +EXTERN(void) jpeg_write_tables JPP((j_compress_ptr cinfo)); + +/* Decompression startup: read start of JPEG datastream to see what's there */ +EXTERN(int) jpeg_read_header JPP((j_decompress_ptr cinfo, + int require_image)); +/* Return value is one of: */ +#define JPEG_SUSPENDED 0 /* Suspended due to lack of input data */ +#define JPEG_HEADER_OK 1 /* Found valid image datastream */ +#define JPEG_HEADER_TABLES_ONLY 2 /* Found valid table-specs-only datastream */ +/* If you pass require_image = TRUE (normal case), you need not check for + * a TABLES_ONLY return code; an abbreviated file will cause an error exit. + * JPEG_SUSPENDED is only possible if you use a data source module that can + * give a suspension return (the stdio source module doesn't). + */ + +/* Main entry points for decompression */ +EXTERN(int) jpeg_start_decompress JPP((j_decompress_ptr cinfo)); +EXTERN(JDIMENSION) jpeg_read_scanlines JPP((j_decompress_ptr cinfo, + JSAMPARRAY scanlines, + JDIMENSION max_lines)); +EXTERN(int) jpeg_finish_decompress JPP((j_decompress_ptr cinfo)); + +/* Replaces jpeg_read_scanlines when reading raw downsampled data. */ +EXTERN(JDIMENSION) jpeg_read_raw_data JPP((j_decompress_ptr cinfo, + JSAMPIMAGE data, + JDIMENSION max_lines)); + +/* Additional entry points for buffered-image mode. */ +EXTERN(int) jpeg_has_multiple_scans JPP((j_decompress_ptr cinfo)); +EXTERN(int) jpeg_start_output JPP((j_decompress_ptr cinfo, + int scan_number)); +EXTERN(int) jpeg_finish_output JPP((j_decompress_ptr cinfo)); +EXTERN(int) jpeg_input_complete JPP((j_decompress_ptr cinfo)); +EXTERN(void) jpeg_new_colormap JPP((j_decompress_ptr cinfo)); +EXTERN(int) jpeg_consume_input JPP((j_decompress_ptr cinfo)); +/* Return value is one of: */ +/* #define JPEG_SUSPENDED 0 Suspended due to lack of input data */ +#define JPEG_REACHED_SOS 1 /* Reached start of new scan */ +#define JPEG_REACHED_EOI 2 /* Reached end of image */ +#define JPEG_ROW_COMPLETED 3 /* Completed one iMCU row */ +#define JPEG_SCAN_COMPLETED 4 /* Completed last iMCU row of a scan */ + +/* Precalculate output dimensions for current decompression parameters. */ +EXTERN(void) jpeg_calc_output_dimensions JPP((j_decompress_ptr cinfo)); + +/* Control saving of COM and APPn markers into marker_list. */ +EXTERN(void) jpeg_save_markers + JPP((j_decompress_ptr cinfo, int marker_code, + unsigned int length_limit)); + +/* Install a special processing method for COM or APPn markers. */ +EXTERN(void) jpeg_set_marker_processor + JPP((j_decompress_ptr cinfo, int marker_code, + jpeg_marker_parser_method routine)); + +/* Read or write raw DCT coefficients --- useful for lossless transcoding. */ +EXTERN(jvirt_barray_ptr *) jpeg_read_coefficients JPP((j_decompress_ptr cinfo)); +EXTERN(void) jpeg_write_coefficients JPP((j_compress_ptr cinfo, + jvirt_barray_ptr * coef_arrays)); +EXTERN(void) jpeg_copy_critical_parameters JPP((j_decompress_ptr srcinfo, + j_compress_ptr dstinfo)); + +/* If you choose to abort compression or decompression before completing + * jpeg_finish_(de)compress, then you need to clean up to release memory, + * temporary files, etc. You can just call jpeg_destroy_(de)compress + * if you're done with the JPEG object, but if you want to clean it up and + * reuse it, call this: + */ +EXTERN(void) jpeg_abort_compress JPP((j_compress_ptr cinfo)); +EXTERN(void) jpeg_abort_decompress JPP((j_decompress_ptr cinfo)); + +/* Generic versions of jpeg_abort and jpeg_destroy that work on either + * flavor of JPEG object. These may be more convenient in some places. + */ +EXTERN(void) jpeg_abort JPP((j_common_ptr cinfo)); +EXTERN(void) jpeg_destroy JPP((j_common_ptr cinfo)); + +/* Default restart-marker-resync procedure for use by data source modules */ +EXTERN(int) jpeg_resync_to_restart JPP((j_decompress_ptr cinfo, + int desired)); + + +/* These marker codes are exported since applications and data source modules + * are likely to want to use them. + */ + +#define JPEG_RST0 0xD0 /* RST0 marker code */ +#define JPEG_EOI 0xD9 /* EOI marker code */ +#define JPEG_APP0 0xE0 /* APP0 marker code */ +#define JPEG_COM 0xFE /* COM marker code */ + + +/* If we have a brain-damaged compiler that emits warnings (or worse, errors) + * for structure definitions that are never filled in, keep it quiet by + * supplying dummy definitions for the various substructures. + */ + +#ifdef INCOMPLETE_TYPES_BROKEN +#ifndef JPEG_INTERNALS /* will be defined in jpegint.h */ +struct jvirt_sarray_control { long dummy; }; +struct jvirt_barray_control { long dummy; }; +struct jpeg_comp_master { long dummy; }; +struct jpeg_c_main_controller { long dummy; }; +struct jpeg_c_prep_controller { long dummy; }; +struct jpeg_c_coef_controller { long dummy; }; +struct jpeg_marker_writer { long dummy; }; +struct jpeg_color_converter { long dummy; }; +struct jpeg_downsampler { long dummy; }; +struct jpeg_forward_dct { long dummy; }; +struct jpeg_entropy_encoder { long dummy; }; +struct jpeg_decomp_master { long dummy; }; +struct jpeg_d_main_controller { long dummy; }; +struct jpeg_d_coef_controller { long dummy; }; +struct jpeg_d_post_controller { long dummy; }; +struct jpeg_input_controller { long dummy; }; +struct jpeg_marker_reader { long dummy; }; +struct jpeg_entropy_decoder { long dummy; }; +struct jpeg_inverse_dct { long dummy; }; +struct jpeg_upsampler { long dummy; }; +struct jpeg_color_deconverter { long dummy; }; +struct jpeg_color_quantizer { long dummy; }; +#endif /* JPEG_INTERNALS */ +#endif /* INCOMPLETE_TYPES_BROKEN */ + + +/* + * The JPEG library modules define JPEG_INTERNALS before including this file. + * The internal structure declarations are read only when that is true. + * Applications using the library should not include jpegint.h, but may wish + * to include jerror.h. + */ + +#ifdef JPEG_INTERNALS +#include "jpegint.h" /* fetch private declarations */ +#include "jerror.h" /* fetch error codes too */ +#endif + +#endif /* JPEGLIB_H */ diff --git a/ml/dlib/dlib/external/libjpeg/jquant1.cpp b/ml/dlib/dlib/external/libjpeg/jquant1.cpp new file mode 100644 index 000000000..7582015ad --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jquant1.cpp @@ -0,0 +1,856 @@ +/* + * jquant1.c + * + * Copyright (C) 1991-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains 1-pass color quantization (color mapping) routines. + * These routines provide mapping to a fixed color map using equally spaced + * color values. Optional Floyd-Steinberg or ordered dithering is available. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + +#ifdef QUANT_1PASS_SUPPORTED + + +/* + * The main purpose of 1-pass quantization is to provide a fast, if not very + * high quality, colormapped output capability. A 2-pass quantizer usually + * gives better visual quality; however, for quantized grayscale output this + * quantizer is perfectly adequate. Dithering is highly recommended with this + * quantizer, though you can turn it off if you really want to. + * + * In 1-pass quantization the colormap must be chosen in advance of seeing the + * image. We use a map consisting of all combinations of Ncolors[i] color + * values for the i'th component. The Ncolors[] values are chosen so that + * their product, the total number of colors, is no more than that requested. + * (In most cases, the product will be somewhat less.) + * + * Since the colormap is orthogonal, the representative value for each color + * component can be determined without considering the other components; + * then these indexes can be combined into a colormap index by a standard + * N-dimensional-array-subscript calculation. Most of the arithmetic involved + * can be precalculated and stored in the lookup table colorindex[]. + * colorindex[i][j] maps pixel value j in component i to the nearest + * representative value (grid plane) for that component; this index is + * multiplied by the array stride for component i, so that the + * index of the colormap entry closest to a given pixel value is just + * sum( colorindex[component-number][pixel-component-value] ) + * Aside from being fast, this scheme allows for variable spacing between + * representative values with no additional lookup cost. + * + * If gamma correction has been applied in color conversion, it might be wise + * to adjust the color grid spacing so that the representative colors are + * equidistant in linear space. At this writing, gamma correction is not + * implemented by jdcolor, so nothing is done here. + */ + + +/* Declarations for ordered dithering. + * + * We use a standard 16x16 ordered dither array. The basic concept of ordered + * dithering is described in many references, for instance Dale Schumacher's + * chapter II.2 of Graphics Gems II (James Arvo, ed. Academic Press, 1991). + * In place of Schumacher's comparisons against a "threshold" value, we add a + * "dither" value to the input pixel and then round the result to the nearest + * output value. The dither value is equivalent to (0.5 - threshold) times + * the distance between output values. For ordered dithering, we assume that + * the output colors are equally spaced; if not, results will probably be + * worse, since the dither may be too much or too little at a given point. + * + * The normal calculation would be to form pixel value + dither, range-limit + * this to 0..MAXJSAMPLE, and then index into the colorindex table as usual. + * We can skip the separate range-limiting step by extending the colorindex + * table in both directions. + */ + +#define ODITHER_SIZE 16 /* dimension of dither matrix */ +/* NB: if ODITHER_SIZE is not a power of 2, ODITHER_MASK uses will break */ +#define ODITHER_CELLS (ODITHER_SIZE*ODITHER_SIZE) /* # cells in matrix */ +#define ODITHER_MASK (ODITHER_SIZE-1) /* mask for wrapping around counters */ + +typedef int ODITHER_MATRIX[ODITHER_SIZE][ODITHER_SIZE]; +typedef int (*ODITHER_MATRIX_PTR)[ODITHER_SIZE]; + +static const unsigned char base_dither_matrix[ODITHER_SIZE][ODITHER_SIZE] = { + /* Bayer's order-4 dither array. Generated by the code given in + * Stephen Hawley's article "Ordered Dithering" in Graphics Gems I. + * The values in this array must range from 0 to ODITHER_CELLS-1. + */ + { 0,192, 48,240, 12,204, 60,252, 3,195, 51,243, 15,207, 63,255 }, + { 128, 64,176,112,140, 76,188,124,131, 67,179,115,143, 79,191,127 }, + { 32,224, 16,208, 44,236, 28,220, 35,227, 19,211, 47,239, 31,223 }, + { 160, 96,144, 80,172,108,156, 92,163, 99,147, 83,175,111,159, 95 }, + { 8,200, 56,248, 4,196, 52,244, 11,203, 59,251, 7,199, 55,247 }, + { 136, 72,184,120,132, 68,180,116,139, 75,187,123,135, 71,183,119 }, + { 40,232, 24,216, 36,228, 20,212, 43,235, 27,219, 39,231, 23,215 }, + { 168,104,152, 88,164,100,148, 84,171,107,155, 91,167,103,151, 87 }, + { 2,194, 50,242, 14,206, 62,254, 1,193, 49,241, 13,205, 61,253 }, + { 130, 66,178,114,142, 78,190,126,129, 65,177,113,141, 77,189,125 }, + { 34,226, 18,210, 46,238, 30,222, 33,225, 17,209, 45,237, 29,221 }, + { 162, 98,146, 82,174,110,158, 94,161, 97,145, 81,173,109,157, 93 }, + { 10,202, 58,250, 6,198, 54,246, 9,201, 57,249, 5,197, 53,245 }, + { 138, 74,186,122,134, 70,182,118,137, 73,185,121,133, 69,181,117 }, + { 42,234, 26,218, 38,230, 22,214, 41,233, 25,217, 37,229, 21,213 }, + { 170,106,154, 90,166,102,150, 86,169,105,153, 89,165,101,149, 85 } +}; + + +/* Declarations for Floyd-Steinberg dithering. + * + * Errors are accumulated into the array fserrors[], at a resolution of + * 1/16th of a pixel count. The error at a given pixel is propagated + * to its not-yet-processed neighbors using the standard F-S fractions, + * ... (here) 7/16 + * 3/16 5/16 1/16 + * We work left-to-right on even rows, right-to-left on odd rows. + * + * We can get away with a single array (holding one row's worth of errors) + * by using it to store the current row's errors at pixel columns not yet + * processed, but the next row's errors at columns already processed. We + * need only a few extra variables to hold the errors immediately around the + * current column. (If we are lucky, those variables are in registers, but + * even if not, they're probably cheaper to access than array elements are.) + * + * The fserrors[] array is indexed [component#][position]. + * We provide (#columns + 2) entries per component; the extra entry at each + * end saves us from special-casing the first and last pixels. + * + * Note: on a wide image, we might not have enough room in a PC's near data + * segment to hold the error array; so it is allocated with alloc_large. + */ + +#if BITS_IN_JSAMPLE == 8 +typedef short FSERROR; /* 16 bits should be enough */ +typedef int LOCFSERROR; /* use 'int' for calculation temps */ +#else +typedef long FSERROR; /* may need more than 16 bits */ +typedef long LOCFSERROR; /* be sure calculation temps are big enough */ +#endif + +typedef FSERROR FAR *FSERRPTR; /* pointer to error array (in FAR storage!) */ + + +/* Private subobject */ + +#define MAX_Q_COMPS 4 /* max components I can handle */ + +typedef struct { + struct jpeg_color_quantizer pub; /* public fields */ + + /* Initially allocated colormap is saved here */ + JSAMPARRAY sv_colormap; /* The color map as a 2-D pixel array */ + int sv_actual; /* number of entries in use */ + + JSAMPARRAY colorindex; /* Precomputed mapping for speed */ + /* colorindex[i][j] = index of color closest to pixel value j in component i, + * premultiplied as described above. Since colormap indexes must fit into + * JSAMPLEs, the entries of this array will too. + */ + int is_padded; /* is the colorindex padded for odither? */ + + int Ncolors[MAX_Q_COMPS]; /* # of values alloced to each component */ + + /* Variables for ordered dithering */ + int row_index; /* cur row's vertical index in dither matrix */ + ODITHER_MATRIX_PTR odither[MAX_Q_COMPS]; /* one dither array per component */ + + /* Variables for Floyd-Steinberg dithering */ + FSERRPTR fserrors[MAX_Q_COMPS]; /* accumulated errors */ + int on_odd_row; /* flag to remember which row we are on */ +} my_cquantizer; + +typedef my_cquantizer * my_cquantize_ptr; + + +/* + * Policy-making subroutines for create_colormap and create_colorindex. + * These routines determine the colormap to be used. The rest of the module + * only assumes that the colormap is orthogonal. + * + * * select_ncolors decides how to divvy up the available colors + * among the components. + * * output_value defines the set of representative values for a component. + * * largest_input_value defines the mapping from input values to + * representative values for a component. + * Note that the latter two routines may impose different policies for + * different components, though this is not currently done. + */ + + +LOCAL(int) +select_ncolors (j_decompress_ptr cinfo, int Ncolors[]) +/* Determine allocation of desired colors to components, */ +/* and fill in Ncolors[] array to indicate choice. */ +/* Return value is total number of colors (product of Ncolors[] values). */ +{ + int nc = cinfo->out_color_components; /* number of color components */ + int max_colors = cinfo->desired_number_of_colors; + int total_colors, iroot, i, j; + int changed; + long temp; + static const int RGB_order[3] = { RGB_GREEN, RGB_RED, RGB_BLUE }; + + /* We can allocate at least the nc'th root of max_colors per component. */ + /* Compute floor(nc'th root of max_colors). */ + iroot = 1; + do { + iroot++; + temp = iroot; /* set temp = iroot ** nc */ + for (i = 1; i < nc; i++) + temp *= iroot; + } while (temp <= (long) max_colors); /* repeat till iroot exceeds root */ + iroot--; /* now iroot = floor(root) */ + + /* Must have at least 2 color values per component */ + if (iroot < 2) + ERREXIT1(cinfo, JERR_QUANT_FEW_COLORS, (int) temp); + + /* Initialize to iroot color values for each component */ + total_colors = 1; + for (i = 0; i < nc; i++) { + Ncolors[i] = iroot; + total_colors *= iroot; + } + /* We may be able to increment the count for one or more components without + * exceeding max_colors, though we know not all can be incremented. + * Sometimes, the first component can be incremented more than once! + * (Example: for 16 colors, we start at 2*2*2, go to 3*2*2, then 4*2*2.) + * In RGB colorspace, try to increment G first, then R, then B. + */ + do { + changed = FALSE; + for (i = 0; i < nc; i++) { + j = (cinfo->out_color_space == JCS_RGB ? RGB_order[i] : i); + /* calculate new total_colors if Ncolors[j] is incremented */ + temp = total_colors / Ncolors[j]; + temp *= Ncolors[j]+1; /* done in long arith to avoid oflo */ + if (temp > (long) max_colors) + break; /* won't fit, done with this pass */ + Ncolors[j]++; /* OK, apply the increment */ + total_colors = (int) temp; + changed = TRUE; + } + } while (changed); + + return total_colors; +} + + +LOCAL(int) +output_value (j_decompress_ptr , int , int j, int maxj) +/* Return j'th output value, where j will range from 0 to maxj */ +/* The output values must fall in 0..MAXJSAMPLE in increasing order */ +{ + /* We always provide values 0 and MAXJSAMPLE for each component; + * any additional values are equally spaced between these limits. + * (Forcing the upper and lower values to the limits ensures that + * dithering can't produce a color outside the selected gamut.) + */ + return (int) (((long) j * MAXJSAMPLE + maxj/2) / maxj); +} + + +LOCAL(int) +largest_input_value (j_decompress_ptr , int , int j, int maxj) +/* Return largest input value that should map to j'th output value */ +/* Must have largest(j=0) >= 0, and largest(j=maxj) >= MAXJSAMPLE */ +{ + /* Breakpoints are halfway between values returned by output_value */ + return (int) (((long) (2*j + 1) * MAXJSAMPLE + maxj) / (2*maxj)); +} + + +/* + * Create the colormap. + */ + +LOCAL(void) +create_colormap (j_decompress_ptr cinfo) +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + JSAMPARRAY colormap; /* Created colormap */ + int total_colors; /* Number of distinct output colors */ + int i,j,k, nci, blksize, blkdist, ptr, val; + + /* Select number of colors for each component */ + total_colors = select_ncolors(cinfo, cquantize->Ncolors); + + /* Report selected color counts */ + if (cinfo->out_color_components == 3) + TRACEMS4(cinfo, 1, JTRC_QUANT_3_NCOLORS, + total_colors, cquantize->Ncolors[0], + cquantize->Ncolors[1], cquantize->Ncolors[2]); + else + TRACEMS1(cinfo, 1, JTRC_QUANT_NCOLORS, total_colors); + + /* Allocate and fill in the colormap. */ + /* The colors are ordered in the map in standard row-major order, */ + /* i.e. rightmost (highest-indexed) color changes most rapidly. */ + + colormap = (*cinfo->mem->alloc_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + (JDIMENSION) total_colors, (JDIMENSION) cinfo->out_color_components); + + /* blksize is number of adjacent repeated entries for a component */ + /* blkdist is distance between groups of identical entries for a component */ + blkdist = total_colors; + + for (i = 0; i < cinfo->out_color_components; i++) { + /* fill in colormap entries for i'th color component */ + nci = cquantize->Ncolors[i]; /* # of distinct values for this color */ + blksize = blkdist / nci; + for (j = 0; j < nci; j++) { + /* Compute j'th output value (out of nci) for component */ + val = output_value(cinfo, i, j, nci-1); + /* Fill in all colormap entries that have this value of this component */ + for (ptr = j * blksize; ptr < total_colors; ptr += blkdist) { + /* fill in blksize entries beginning at ptr */ + for (k = 0; k < blksize; k++) + colormap[i][ptr+k] = (JSAMPLE) val; + } + } + blkdist = blksize; /* blksize of this color is blkdist of next */ + } + + /* Save the colormap in private storage, + * where it will survive color quantization mode changes. + */ + cquantize->sv_colormap = colormap; + cquantize->sv_actual = total_colors; +} + + +/* + * Create the color index table. + */ + +LOCAL(void) +create_colorindex (j_decompress_ptr cinfo) +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + JSAMPROW indexptr; + int i,j,k, nci, blksize, val, pad; + + /* For ordered dither, we pad the color index tables by MAXJSAMPLE in + * each direction (input index values can be -MAXJSAMPLE .. 2*MAXJSAMPLE). + * This is not necessary in the other dithering modes. However, we + * flag whether it was done in case user changes dithering mode. + */ + if (cinfo->dither_mode == JDITHER_ORDERED) { + pad = MAXJSAMPLE*2; + cquantize->is_padded = TRUE; + } else { + pad = 0; + cquantize->is_padded = FALSE; + } + + cquantize->colorindex = (*cinfo->mem->alloc_sarray) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + (JDIMENSION) (MAXJSAMPLE+1 + pad), + (JDIMENSION) cinfo->out_color_components); + + /* blksize is number of adjacent repeated entries for a component */ + blksize = cquantize->sv_actual; + + for (i = 0; i < cinfo->out_color_components; i++) { + /* fill in colorindex entries for i'th color component */ + nci = cquantize->Ncolors[i]; /* # of distinct values for this color */ + blksize = blksize / nci; + + /* adjust colorindex pointers to provide padding at negative indexes. */ + if (pad) + cquantize->colorindex[i] += MAXJSAMPLE; + + /* in loop, val = index of current output value, */ + /* and k = largest j that maps to current val */ + indexptr = cquantize->colorindex[i]; + val = 0; + k = largest_input_value(cinfo, i, 0, nci-1); + for (j = 0; j <= MAXJSAMPLE; j++) { + while (j > k) /* advance val if past boundary */ + k = largest_input_value(cinfo, i, ++val, nci-1); + /* premultiply so that no multiplication needed in main processing */ + indexptr[j] = (JSAMPLE) (val * blksize); + } + /* Pad at both ends if necessary */ + if (pad) + for (j = 1; j <= MAXJSAMPLE; j++) { + indexptr[-j] = indexptr[0]; + indexptr[MAXJSAMPLE+j] = indexptr[MAXJSAMPLE]; + } + } +} + + +/* + * Create an ordered-dither array for a component having ncolors + * distinct output values. + */ + +LOCAL(ODITHER_MATRIX_PTR) +make_odither_array (j_decompress_ptr cinfo, int ncolors) +{ + ODITHER_MATRIX_PTR odither; + int j,k; + long num,den; + + odither = (ODITHER_MATRIX_PTR) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(ODITHER_MATRIX)); + /* The inter-value distance for this color is MAXJSAMPLE/(ncolors-1). + * Hence the dither value for the matrix cell with fill order f + * (f=0..N-1) should be (N-1-2*f)/(2*N) * MAXJSAMPLE/(ncolors-1). + * On 16-bit-int machine, be careful to avoid overflow. + */ + den = 2 * ODITHER_CELLS * ((long) (ncolors - 1)); + for (j = 0; j < ODITHER_SIZE; j++) { + for (k = 0; k < ODITHER_SIZE; k++) { + num = ((long) (ODITHER_CELLS-1 - 2*((int)base_dither_matrix[j][k]))) + * MAXJSAMPLE; + /* Ensure round towards zero despite C's lack of consistency + * about rounding negative values in integer division... + */ + odither[j][k] = (int) (num<0 ? -((-num)/den) : num/den); + } + } + return odither; +} + + +/* + * Create the ordered-dither tables. + * Components having the same number of representative colors may + * share a dither table. + */ + +LOCAL(void) +create_odither_tables (j_decompress_ptr cinfo) +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + ODITHER_MATRIX_PTR odither; + int i, j, nci; + + for (i = 0; i < cinfo->out_color_components; i++) { + nci = cquantize->Ncolors[i]; /* # of distinct values for this color */ + odither = NULL; /* search for matching prior component */ + for (j = 0; j < i; j++) { + if (nci == cquantize->Ncolors[j]) { + odither = cquantize->odither[j]; + break; + } + } + if (odither == NULL) /* need a new table? */ + odither = make_odither_array(cinfo, nci); + cquantize->odither[i] = odither; + } +} + + +/* + * Map some rows of pixels to the output colormapped representation. + */ + +METHODDEF(void) +color_quantize (j_decompress_ptr cinfo, JSAMPARRAY input_buf, + JSAMPARRAY output_buf, int num_rows) +/* General case, no dithering */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + JSAMPARRAY colorindex = cquantize->colorindex; + int pixcode, ci; + JSAMPROW ptrin, ptrout; + int row; + JDIMENSION col; + JDIMENSION width = cinfo->output_width; + int nc = cinfo->out_color_components; + + for (row = 0; row < num_rows; row++) { + ptrin = input_buf[row]; + ptrout = output_buf[row]; + for (col = width; col > 0; col--) { + pixcode = 0; + for (ci = 0; ci < nc; ci++) { + pixcode += GETJSAMPLE(colorindex[ci][GETJSAMPLE(*ptrin++)]); + } + *ptrout++ = (JSAMPLE) pixcode; + } + } +} + + +METHODDEF(void) +color_quantize3 (j_decompress_ptr cinfo, JSAMPARRAY input_buf, + JSAMPARRAY output_buf, int num_rows) +/* Fast path for out_color_components==3, no dithering */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + int pixcode; + JSAMPROW ptrin, ptrout; + JSAMPROW colorindex0 = cquantize->colorindex[0]; + JSAMPROW colorindex1 = cquantize->colorindex[1]; + JSAMPROW colorindex2 = cquantize->colorindex[2]; + int row; + JDIMENSION col; + JDIMENSION width = cinfo->output_width; + + for (row = 0; row < num_rows; row++) { + ptrin = input_buf[row]; + ptrout = output_buf[row]; + for (col = width; col > 0; col--) { + pixcode = GETJSAMPLE(colorindex0[GETJSAMPLE(*ptrin++)]); + pixcode += GETJSAMPLE(colorindex1[GETJSAMPLE(*ptrin++)]); + pixcode += GETJSAMPLE(colorindex2[GETJSAMPLE(*ptrin++)]); + *ptrout++ = (JSAMPLE) pixcode; + } + } +} + + +METHODDEF(void) +quantize_ord_dither (j_decompress_ptr cinfo, JSAMPARRAY input_buf, + JSAMPARRAY output_buf, int num_rows) +/* General case, with ordered dithering */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + JSAMPROW input_ptr; + JSAMPROW output_ptr; + JSAMPROW colorindex_ci; + int * dither; /* points to active row of dither matrix */ + int row_index, col_index; /* current indexes into dither matrix */ + int nc = cinfo->out_color_components; + int ci; + int row; + JDIMENSION col; + JDIMENSION width = cinfo->output_width; + + for (row = 0; row < num_rows; row++) { + /* Initialize output values to 0 so can process components separately */ + jzero_far((void FAR *) output_buf[row], + (size_t) (width * SIZEOF(JSAMPLE))); + row_index = cquantize->row_index; + for (ci = 0; ci < nc; ci++) { + input_ptr = input_buf[row] + ci; + output_ptr = output_buf[row]; + colorindex_ci = cquantize->colorindex[ci]; + dither = cquantize->odither[ci][row_index]; + col_index = 0; + + for (col = width; col > 0; col--) { + /* Form pixel value + dither, range-limit to 0..MAXJSAMPLE, + * select output value, accumulate into output code for this pixel. + * Range-limiting need not be done explicitly, as we have extended + * the colorindex table to produce the right answers for out-of-range + * inputs. The maximum dither is +- MAXJSAMPLE; this sets the + * required amount of padding. + */ + *output_ptr += colorindex_ci[GETJSAMPLE(*input_ptr)+dither[col_index]]; + input_ptr += nc; + output_ptr++; + col_index = (col_index + 1) & ODITHER_MASK; + } + } + /* Advance row index for next row */ + row_index = (row_index + 1) & ODITHER_MASK; + cquantize->row_index = row_index; + } +} + + +METHODDEF(void) +quantize3_ord_dither (j_decompress_ptr cinfo, JSAMPARRAY input_buf, + JSAMPARRAY output_buf, int num_rows) +/* Fast path for out_color_components==3, with ordered dithering */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + int pixcode; + JSAMPROW input_ptr; + JSAMPROW output_ptr; + JSAMPROW colorindex0 = cquantize->colorindex[0]; + JSAMPROW colorindex1 = cquantize->colorindex[1]; + JSAMPROW colorindex2 = cquantize->colorindex[2]; + int * dither0; /* points to active row of dither matrix */ + int * dither1; + int * dither2; + int row_index, col_index; /* current indexes into dither matrix */ + int row; + JDIMENSION col; + JDIMENSION width = cinfo->output_width; + + for (row = 0; row < num_rows; row++) { + row_index = cquantize->row_index; + input_ptr = input_buf[row]; + output_ptr = output_buf[row]; + dither0 = cquantize->odither[0][row_index]; + dither1 = cquantize->odither[1][row_index]; + dither2 = cquantize->odither[2][row_index]; + col_index = 0; + + for (col = width; col > 0; col--) { + pixcode = GETJSAMPLE(colorindex0[GETJSAMPLE(*input_ptr++) + + dither0[col_index]]); + pixcode += GETJSAMPLE(colorindex1[GETJSAMPLE(*input_ptr++) + + dither1[col_index]]); + pixcode += GETJSAMPLE(colorindex2[GETJSAMPLE(*input_ptr++) + + dither2[col_index]]); + *output_ptr++ = (JSAMPLE) pixcode; + col_index = (col_index + 1) & ODITHER_MASK; + } + row_index = (row_index + 1) & ODITHER_MASK; + cquantize->row_index = row_index; + } +} + + +METHODDEF(void) +quantize_fs_dither (j_decompress_ptr cinfo, JSAMPARRAY input_buf, + JSAMPARRAY output_buf, int num_rows) +/* General case, with Floyd-Steinberg dithering */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + LOCFSERROR cur; /* current error or pixel value */ + LOCFSERROR belowerr; /* error for pixel below cur */ + LOCFSERROR bpreverr; /* error for below/prev col */ + LOCFSERROR bnexterr; /* error for below/next col */ + LOCFSERROR delta; + FSERRPTR errorptr; /* => fserrors[] at column before current */ + JSAMPROW input_ptr; + JSAMPROW output_ptr; + JSAMPROW colorindex_ci; + JSAMPROW colormap_ci; + int pixcode; + int nc = cinfo->out_color_components; + int dir; /* 1 for left-to-right, -1 for right-to-left */ + int dirnc; /* dir * nc */ + int ci; + int row; + JDIMENSION col; + JDIMENSION width = cinfo->output_width; + JSAMPLE *range_limit = cinfo->sample_range_limit; + SHIFT_TEMPS + + for (row = 0; row < num_rows; row++) { + /* Initialize output values to 0 so can process components separately */ + jzero_far((void FAR *) output_buf[row], + (size_t) (width * SIZEOF(JSAMPLE))); + for (ci = 0; ci < nc; ci++) { + input_ptr = input_buf[row] + ci; + output_ptr = output_buf[row]; + if (cquantize->on_odd_row) { + /* work right to left in this row */ + input_ptr += (width-1) * nc; /* so point to rightmost pixel */ + output_ptr += width-1; + dir = -1; + dirnc = -nc; + errorptr = cquantize->fserrors[ci] + (width+1); /* => entry after last column */ + } else { + /* work left to right in this row */ + dir = 1; + dirnc = nc; + errorptr = cquantize->fserrors[ci]; /* => entry before first column */ + } + colorindex_ci = cquantize->colorindex[ci]; + colormap_ci = cquantize->sv_colormap[ci]; + /* Preset error values: no error propagated to first pixel from left */ + cur = 0; + /* and no error propagated to row below yet */ + belowerr = bpreverr = 0; + + for (col = width; col > 0; col--) { + /* cur holds the error propagated from the previous pixel on the + * current line. Add the error propagated from the previous line + * to form the complete error correction term for this pixel, and + * round the error term (which is expressed * 16) to an integer. + * RIGHT_SHIFT rounds towards minus infinity, so adding 8 is correct + * for either sign of the error value. + * Note: errorptr points to *previous* column's array entry. + */ + cur = RIGHT_SHIFT(cur + errorptr[dir] + 8, 4); + /* Form pixel value + error, and range-limit to 0..MAXJSAMPLE. + * The maximum error is +- MAXJSAMPLE; this sets the required size + * of the range_limit array. + */ + cur += GETJSAMPLE(*input_ptr); + cur = GETJSAMPLE(range_limit[cur]); + /* Select output value, accumulate into output code for this pixel */ + pixcode = GETJSAMPLE(colorindex_ci[cur]); + *output_ptr += (JSAMPLE) pixcode; + /* Compute actual representation error at this pixel */ + /* Note: we can do this even though we don't have the final */ + /* pixel code, because the colormap is orthogonal. */ + cur -= GETJSAMPLE(colormap_ci[pixcode]); + /* Compute error fractions to be propagated to adjacent pixels. + * Add these into the running sums, and simultaneously shift the + * next-line error sums left by 1 column. + */ + bnexterr = cur; + delta = cur * 2; + cur += delta; /* form error * 3 */ + errorptr[0] = (FSERROR) (bpreverr + cur); + cur += delta; /* form error * 5 */ + bpreverr = belowerr + cur; + belowerr = bnexterr; + cur += delta; /* form error * 7 */ + /* At this point cur contains the 7/16 error value to be propagated + * to the next pixel on the current line, and all the errors for the + * next line have been shifted over. We are therefore ready to move on. + */ + input_ptr += dirnc; /* advance input ptr to next column */ + output_ptr += dir; /* advance output ptr to next column */ + errorptr += dir; /* advance errorptr to current column */ + } + /* Post-loop cleanup: we must unload the final error value into the + * final fserrors[] entry. Note we need not unload belowerr because + * it is for the dummy column before or after the actual array. + */ + errorptr[0] = (FSERROR) bpreverr; /* unload prev err into array */ + } + cquantize->on_odd_row = (cquantize->on_odd_row ? FALSE : TRUE); + } +} + + +/* + * Allocate workspace for Floyd-Steinberg errors. + */ + +LOCAL(void) +alloc_fs_workspace (j_decompress_ptr cinfo) +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + size_t arraysize; + int i; + + arraysize = (size_t) ((cinfo->output_width + 2) * SIZEOF(FSERROR)); + for (i = 0; i < cinfo->out_color_components; i++) { + cquantize->fserrors[i] = (FSERRPTR) + (*cinfo->mem->alloc_large)((j_common_ptr) cinfo, JPOOL_IMAGE, arraysize); + } +} + + +/* + * Initialize for one-pass color quantization. + */ + +METHODDEF(void) +start_pass_1_quant (j_decompress_ptr cinfo, int ) +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + size_t arraysize; + int i; + + /* Install my colormap. */ + cinfo->colormap = cquantize->sv_colormap; + cinfo->actual_number_of_colors = cquantize->sv_actual; + + /* Initialize for desired dithering mode. */ + switch (cinfo->dither_mode) { + case JDITHER_NONE: + if (cinfo->out_color_components == 3) + cquantize->pub.color_quantize = color_quantize3; + else + cquantize->pub.color_quantize = color_quantize; + break; + case JDITHER_ORDERED: + if (cinfo->out_color_components == 3) + cquantize->pub.color_quantize = quantize3_ord_dither; + else + cquantize->pub.color_quantize = quantize_ord_dither; + cquantize->row_index = 0; /* initialize state for ordered dither */ + /* If user changed to ordered dither from another mode, + * we must recreate the color index table with padding. + * This will cost extra space, but probably isn't very likely. + */ + if (! cquantize->is_padded) + create_colorindex(cinfo); + /* Create ordered-dither tables if we didn't already. */ + if (cquantize->odither[0] == NULL) + create_odither_tables(cinfo); + break; + case JDITHER_FS: + cquantize->pub.color_quantize = quantize_fs_dither; + cquantize->on_odd_row = FALSE; /* initialize state for F-S dither */ + /* Allocate Floyd-Steinberg workspace if didn't already. */ + if (cquantize->fserrors[0] == NULL) + alloc_fs_workspace(cinfo); + /* Initialize the propagated errors to zero. */ + arraysize = (size_t) ((cinfo->output_width + 2) * SIZEOF(FSERROR)); + for (i = 0; i < cinfo->out_color_components; i++) + jzero_far((void FAR *) cquantize->fserrors[i], arraysize); + break; + default: + ERREXIT(cinfo, JERR_NOT_COMPILED); + break; + } +} + + +/* + * Finish up at the end of the pass. + */ + +METHODDEF(void) +finish_pass_1_quant (j_decompress_ptr ) +{ + /* no work in 1-pass case */ +} + + +/* + * Switch to a new external colormap between output passes. + * Shouldn't get to this module! + */ + +METHODDEF(void) +new_color_map_1_quant (j_decompress_ptr cinfo) +{ + ERREXIT(cinfo, JERR_MODE_CHANGE); +} + + +/* + * Module initialization routine for 1-pass color quantization. + */ + +GLOBAL(void) +jinit_1pass_quantizer (j_decompress_ptr cinfo) +{ + my_cquantize_ptr cquantize; + + cquantize = (my_cquantize_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_cquantizer)); + cinfo->cquantize = (struct jpeg_color_quantizer *) cquantize; + cquantize->pub.start_pass = start_pass_1_quant; + cquantize->pub.finish_pass = finish_pass_1_quant; + cquantize->pub.new_color_map = new_color_map_1_quant; + cquantize->fserrors[0] = NULL; /* Flag FS workspace not allocated */ + cquantize->odither[0] = NULL; /* Also flag odither arrays not allocated */ + + /* Make sure my internal arrays won't overflow */ + if (cinfo->out_color_components > MAX_Q_COMPS) + ERREXIT1(cinfo, JERR_QUANT_COMPONENTS, MAX_Q_COMPS); + /* Make sure colormap indexes can be represented by JSAMPLEs */ + if (cinfo->desired_number_of_colors > (MAXJSAMPLE+1)) + ERREXIT1(cinfo, JERR_QUANT_MANY_COLORS, MAXJSAMPLE+1); + + /* Create the colormap and color index table. */ + create_colormap(cinfo); + create_colorindex(cinfo); + + /* Allocate Floyd-Steinberg workspace now if requested. + * We do this now since it is FAR storage and may affect the memory + * manager's space calculations. If the user changes to FS dither + * mode in a later pass, we will allocate the space then, and will + * possibly overrun the max_memory_to_use setting. + */ + if (cinfo->dither_mode == JDITHER_FS) + alloc_fs_workspace(cinfo); +} + +#endif /* QUANT_1PASS_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jquant2.cpp b/ml/dlib/dlib/external/libjpeg/jquant2.cpp new file mode 100644 index 000000000..0d7b5969a --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jquant2.cpp @@ -0,0 +1,1310 @@ +/* + * jquant2.c + * + * Copyright (C) 1991-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains 2-pass color quantization (color mapping) routines. + * These routines provide selection of a custom color map for an image, + * followed by mapping of the image to that color map, with optional + * Floyd-Steinberg dithering. + * It is also possible to use just the second pass to map to an arbitrary + * externally-given color map. + * + * Note: ordered dithering is not supported, since there isn't any fast + * way to compute intercolor distances; it's unclear that ordered dither's + * fundamental assumptions even hold with an irregularly spaced color map. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + +#ifdef QUANT_2PASS_SUPPORTED + + +/* + * This module implements the well-known Heckbert paradigm for color + * quantization. Most of the ideas used here can be traced back to + * Heckbert's seminal paper + * Heckbert, Paul. "Color Image Quantization for Frame Buffer Display", + * Proc. SIGGRAPH '82, Computer Graphics v.16 #3 (July 1982), pp 297-304. + * + * In the first pass over the image, we accumulate a histogram showing the + * usage count of each possible color. To keep the histogram to a reasonable + * size, we reduce the precision of the input; typical practice is to retain + * 5 or 6 bits per color, so that 8 or 4 different input values are counted + * in the same histogram cell. + * + * Next, the color-selection step begins with a box representing the whole + * color space, and repeatedly splits the "largest" remaining box until we + * have as many boxes as desired colors. Then the mean color in each + * remaining box becomes one of the possible output colors. + * + * The second pass over the image maps each input pixel to the closest output + * color (optionally after applying a Floyd-Steinberg dithering correction). + * This mapping is logically trivial, but making it go fast enough requires + * considerable care. + * + * Heckbert-style quantizers vary a good deal in their policies for choosing + * the "largest" box and deciding where to cut it. The particular policies + * used here have proved out well in experimental comparisons, but better ones + * may yet be found. + * + * In earlier versions of the IJG code, this module quantized in YCbCr color + * space, processing the raw upsampled data without a color conversion step. + * This allowed the color conversion math to be done only once per colormap + * entry, not once per pixel. However, that optimization precluded other + * useful optimizations (such as merging color conversion with upsampling) + * and it also interfered with desired capabilities such as quantizing to an + * externally-supplied colormap. We have therefore abandoned that approach. + * The present code works in the post-conversion color space, typically RGB. + * + * To improve the visual quality of the results, we actually work in scaled + * RGB space, giving G distances more weight than R, and R in turn more than + * B. To do everything in integer math, we must use integer scale factors. + * The 2/3/1 scale factors used here correspond loosely to the relative + * weights of the colors in the NTSC grayscale equation. + * If you want to use this code to quantize a non-RGB color space, you'll + * probably need to change these scale factors. + */ + +#define R_SCALE 2 /* scale R distances by this much */ +#define G_SCALE 3 /* scale G distances by this much */ +#define B_SCALE 1 /* and B by this much */ + +/* Relabel R/G/B as components 0/1/2, respecting the RGB ordering defined + * in jmorecfg.h. As the code stands, it will do the right thing for R,G,B + * and B,G,R orders. If you define some other weird order in jmorecfg.h, + * you'll get compile errors until you extend this logic. In that case + * you'll probably want to tweak the histogram sizes too. + */ + +#if RGB_RED == 0 +#define C0_SCALE R_SCALE +#endif +#if RGB_BLUE == 0 +#define C0_SCALE B_SCALE +#endif +#if RGB_GREEN == 1 +#define C1_SCALE G_SCALE +#endif +#if RGB_RED == 2 +#define C2_SCALE R_SCALE +#endif +#if RGB_BLUE == 2 +#define C2_SCALE B_SCALE +#endif + + +/* + * First we have the histogram data structure and routines for creating it. + * + * The number of bits of precision can be adjusted by changing these symbols. + * We recommend keeping 6 bits for G and 5 each for R and B. + * If you have plenty of memory and cycles, 6 bits all around gives marginally + * better results; if you are short of memory, 5 bits all around will save + * some space but degrade the results. + * To maintain a fully accurate histogram, we'd need to allocate a "long" + * (preferably unsigned long) for each cell. In practice this is overkill; + * we can get by with 16 bits per cell. Few of the cell counts will overflow, + * and clamping those that do overflow to the maximum value will give close- + * enough results. This reduces the recommended histogram size from 256Kb + * to 128Kb, which is a useful savings on PC-class machines. + * (In the second pass the histogram space is re-used for pixel mapping data; + * in that capacity, each cell must be able to store zero to the number of + * desired colors. 16 bits/cell is plenty for that too.) + * Since the JPEG code is intended to run in small memory model on 80x86 + * machines, we can't just allocate the histogram in one chunk. Instead + * of a true 3-D array, we use a row of pointers to 2-D arrays. Each + * pointer corresponds to a C0 value (typically 2^5 = 32 pointers) and + * each 2-D array has 2^6*2^5 = 2048 or 2^6*2^6 = 4096 entries. Note that + * on 80x86 machines, the pointer row is in near memory but the actual + * arrays are in far memory (same arrangement as we use for image arrays). + */ + +#define MAXNUMCOLORS (MAXJSAMPLE+1) /* maximum size of colormap */ + +/* These will do the right thing for either R,G,B or B,G,R color order, + * but you may not like the results for other color orders. + */ +#define HIST_C0_BITS 5 /* bits of precision in R/B histogram */ +#define HIST_C1_BITS 6 /* bits of precision in G histogram */ +#define HIST_C2_BITS 5 /* bits of precision in B/R histogram */ + +/* Number of elements along histogram axes. */ +#define HIST_C0_ELEMS (1<cquantize; + JSAMPROW ptr; + histptr histp; + hist3d histogram = cquantize->histogram; + int row; + JDIMENSION col; + JDIMENSION width = cinfo->output_width; + + for (row = 0; row < num_rows; row++) { + ptr = input_buf[row]; + for (col = width; col > 0; col--) { + /* get pixel value and index into the histogram */ + histp = & histogram[GETJSAMPLE(ptr[0]) >> C0_SHIFT] + [GETJSAMPLE(ptr[1]) >> C1_SHIFT] + [GETJSAMPLE(ptr[2]) >> C2_SHIFT]; + /* increment, check for overflow and undo increment if so. */ + if (++(*histp) <= 0) + (*histp)--; + ptr += 3; + } + } +} + + +/* + * Next we have the really interesting routines: selection of a colormap + * given the completed histogram. + * These routines work with a list of "boxes", each representing a rectangular + * subset of the input color space (to histogram precision). + */ + +typedef struct { + /* The bounds of the box (inclusive); expressed as histogram indexes */ + int c0min, c0max; + int c1min, c1max; + int c2min, c2max; + /* The volume (actually 2-norm) of the box */ + long volume; + /* The number of nonzero histogram cells within this box */ + long colorcount; +} box; + +typedef box * boxptr; + + +LOCAL(boxptr) +find_biggest_color_pop (boxptr boxlist, int numboxes) +/* Find the splittable box with the largest color population */ +/* Returns NULL if no splittable boxes remain */ +{ + boxptr boxp; + int i; + long maxc = 0; + boxptr which = NULL; + + for (i = 0, boxp = boxlist; i < numboxes; i++, boxp++) { + if (boxp->colorcount > maxc && boxp->volume > 0) { + which = boxp; + maxc = boxp->colorcount; + } + } + return which; +} + + +LOCAL(boxptr) +find_biggest_volume (boxptr boxlist, int numboxes) +/* Find the splittable box with the largest (scaled) volume */ +/* Returns NULL if no splittable boxes remain */ +{ + boxptr boxp; + int i; + long maxv = 0; + boxptr which = NULL; + + for (i = 0, boxp = boxlist; i < numboxes; i++, boxp++) { + if (boxp->volume > maxv) { + which = boxp; + maxv = boxp->volume; + } + } + return which; +} + + +LOCAL(void) +update_box (j_decompress_ptr cinfo, boxptr boxp) +/* Shrink the min/max bounds of a box to enclose only nonzero elements, */ +/* and recompute its volume and population */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + hist3d histogram = cquantize->histogram; + histptr histp; + int c0,c1,c2; + int c0min,c0max,c1min,c1max,c2min,c2max; + long dist0,dist1,dist2; + long ccount; + + c0min = boxp->c0min; c0max = boxp->c0max; + c1min = boxp->c1min; c1max = boxp->c1max; + c2min = boxp->c2min; c2max = boxp->c2max; + + if (c0max > c0min) + for (c0 = c0min; c0 <= c0max; c0++) + for (c1 = c1min; c1 <= c1max; c1++) { + histp = & histogram[c0][c1][c2min]; + for (c2 = c2min; c2 <= c2max; c2++) + if (*histp++ != 0) { + boxp->c0min = c0min = c0; + goto have_c0min; + } + } + have_c0min: + if (c0max > c0min) + for (c0 = c0max; c0 >= c0min; c0--) + for (c1 = c1min; c1 <= c1max; c1++) { + histp = & histogram[c0][c1][c2min]; + for (c2 = c2min; c2 <= c2max; c2++) + if (*histp++ != 0) { + boxp->c0max = c0max = c0; + goto have_c0max; + } + } + have_c0max: + if (c1max > c1min) + for (c1 = c1min; c1 <= c1max; c1++) + for (c0 = c0min; c0 <= c0max; c0++) { + histp = & histogram[c0][c1][c2min]; + for (c2 = c2min; c2 <= c2max; c2++) + if (*histp++ != 0) { + boxp->c1min = c1min = c1; + goto have_c1min; + } + } + have_c1min: + if (c1max > c1min) + for (c1 = c1max; c1 >= c1min; c1--) + for (c0 = c0min; c0 <= c0max; c0++) { + histp = & histogram[c0][c1][c2min]; + for (c2 = c2min; c2 <= c2max; c2++) + if (*histp++ != 0) { + boxp->c1max = c1max = c1; + goto have_c1max; + } + } + have_c1max: + if (c2max > c2min) + for (c2 = c2min; c2 <= c2max; c2++) + for (c0 = c0min; c0 <= c0max; c0++) { + histp = & histogram[c0][c1min][c2]; + for (c1 = c1min; c1 <= c1max; c1++, histp += HIST_C2_ELEMS) + if (*histp != 0) { + boxp->c2min = c2min = c2; + goto have_c2min; + } + } + have_c2min: + if (c2max > c2min) + for (c2 = c2max; c2 >= c2min; c2--) + for (c0 = c0min; c0 <= c0max; c0++) { + histp = & histogram[c0][c1min][c2]; + for (c1 = c1min; c1 <= c1max; c1++, histp += HIST_C2_ELEMS) + if (*histp != 0) { + boxp->c2max = c2max = c2; + goto have_c2max; + } + } + have_c2max: + + /* Update box volume. + * We use 2-norm rather than real volume here; this biases the method + * against making long narrow boxes, and it has the side benefit that + * a box is splittable iff norm > 0. + * Since the differences are expressed in histogram-cell units, + * we have to shift back to JSAMPLE units to get consistent distances; + * after which, we scale according to the selected distance scale factors. + */ + dist0 = ((c0max - c0min) << C0_SHIFT) * C0_SCALE; + dist1 = ((c1max - c1min) << C1_SHIFT) * C1_SCALE; + dist2 = ((c2max - c2min) << C2_SHIFT) * C2_SCALE; + boxp->volume = dist0*dist0 + dist1*dist1 + dist2*dist2; + + /* Now scan remaining volume of box and compute population */ + ccount = 0; + for (c0 = c0min; c0 <= c0max; c0++) + for (c1 = c1min; c1 <= c1max; c1++) { + histp = & histogram[c0][c1][c2min]; + for (c2 = c2min; c2 <= c2max; c2++, histp++) + if (*histp != 0) { + ccount++; + } + } + boxp->colorcount = ccount; +} + + +LOCAL(int) +median_cut (j_decompress_ptr cinfo, boxptr boxlist, int numboxes, + int desired_colors) +/* Repeatedly select and split the largest box until we have enough boxes */ +{ + int n,lb; + int c0,c1,c2,cmax; + boxptr b1,b2; + + while (numboxes < desired_colors) { + /* Select box to split. + * Current algorithm: by population for first half, then by volume. + */ + if (numboxes*2 <= desired_colors) { + b1 = find_biggest_color_pop(boxlist, numboxes); + } else { + b1 = find_biggest_volume(boxlist, numboxes); + } + if (b1 == NULL) /* no splittable boxes left! */ + break; + b2 = &boxlist[numboxes]; /* where new box will go */ + /* Copy the color bounds to the new box. */ + b2->c0max = b1->c0max; b2->c1max = b1->c1max; b2->c2max = b1->c2max; + b2->c0min = b1->c0min; b2->c1min = b1->c1min; b2->c2min = b1->c2min; + /* Choose which axis to split the box on. + * Current algorithm: longest scaled axis. + * See notes in update_box about scaling distances. + */ + c0 = ((b1->c0max - b1->c0min) << C0_SHIFT) * C0_SCALE; + c1 = ((b1->c1max - b1->c1min) << C1_SHIFT) * C1_SCALE; + c2 = ((b1->c2max - b1->c2min) << C2_SHIFT) * C2_SCALE; + /* We want to break any ties in favor of green, then red, blue last. + * This code does the right thing for R,G,B or B,G,R color orders only. + */ +#if RGB_RED == 0 + cmax = c1; n = 1; + if (c0 > cmax) { cmax = c0; n = 0; } + if (c2 > cmax) { n = 2; } +#else + cmax = c1; n = 1; + if (c2 > cmax) { cmax = c2; n = 2; } + if (c0 > cmax) { n = 0; } +#endif + /* Choose split point along selected axis, and update box bounds. + * Current algorithm: split at halfway point. + * (Since the box has been shrunk to minimum volume, + * any split will produce two nonempty subboxes.) + * Note that lb value is max for lower box, so must be < old max. + */ + switch (n) { + case 0: + lb = (b1->c0max + b1->c0min) / 2; + b1->c0max = lb; + b2->c0min = lb+1; + break; + case 1: + lb = (b1->c1max + b1->c1min) / 2; + b1->c1max = lb; + b2->c1min = lb+1; + break; + case 2: + lb = (b1->c2max + b1->c2min) / 2; + b1->c2max = lb; + b2->c2min = lb+1; + break; + } + /* Update stats for boxes */ + update_box(cinfo, b1); + update_box(cinfo, b2); + numboxes++; + } + return numboxes; +} + + +LOCAL(void) +compute_color (j_decompress_ptr cinfo, boxptr boxp, int icolor) +/* Compute representative color for a box, put it in colormap[icolor] */ +{ + /* Current algorithm: mean weighted by pixels (not colors) */ + /* Note it is important to get the rounding correct! */ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + hist3d histogram = cquantize->histogram; + histptr histp; + int c0,c1,c2; + int c0min,c0max,c1min,c1max,c2min,c2max; + long count; + long total = 0; + long c0total = 0; + long c1total = 0; + long c2total = 0; + + c0min = boxp->c0min; c0max = boxp->c0max; + c1min = boxp->c1min; c1max = boxp->c1max; + c2min = boxp->c2min; c2max = boxp->c2max; + + for (c0 = c0min; c0 <= c0max; c0++) + for (c1 = c1min; c1 <= c1max; c1++) { + histp = & histogram[c0][c1][c2min]; + for (c2 = c2min; c2 <= c2max; c2++) { + if ((count = *histp++) != 0) { + total += count; + c0total += ((c0 << C0_SHIFT) + ((1<>1)) * count; + c1total += ((c1 << C1_SHIFT) + ((1<>1)) * count; + c2total += ((c2 << C2_SHIFT) + ((1<>1)) * count; + } + } + } + + cinfo->colormap[0][icolor] = (JSAMPLE) ((c0total + (total>>1)) / total); + cinfo->colormap[1][icolor] = (JSAMPLE) ((c1total + (total>>1)) / total); + cinfo->colormap[2][icolor] = (JSAMPLE) ((c2total + (total>>1)) / total); +} + + +LOCAL(void) +select_colors (j_decompress_ptr cinfo, int desired_colors) +/* Master routine for color selection */ +{ + boxptr boxlist; + int numboxes; + int i; + + /* Allocate workspace for box list */ + boxlist = (boxptr) (*cinfo->mem->alloc_small) + ((j_common_ptr) cinfo, JPOOL_IMAGE, desired_colors * SIZEOF(box)); + /* Initialize one box containing whole space */ + numboxes = 1; + boxlist[0].c0min = 0; + boxlist[0].c0max = MAXJSAMPLE >> C0_SHIFT; + boxlist[0].c1min = 0; + boxlist[0].c1max = MAXJSAMPLE >> C1_SHIFT; + boxlist[0].c2min = 0; + boxlist[0].c2max = MAXJSAMPLE >> C2_SHIFT; + /* Shrink it to actually-used volume and set its statistics */ + update_box(cinfo, & boxlist[0]); + /* Perform median-cut to produce final box list */ + numboxes = median_cut(cinfo, boxlist, numboxes, desired_colors); + /* Compute the representative color for each box, fill colormap */ + for (i = 0; i < numboxes; i++) + compute_color(cinfo, & boxlist[i], i); + cinfo->actual_number_of_colors = numboxes; + TRACEMS1(cinfo, 1, JTRC_QUANT_SELECTED, numboxes); +} + + +/* + * These routines are concerned with the time-critical task of mapping input + * colors to the nearest color in the selected colormap. + * + * We re-use the histogram space as an "inverse color map", essentially a + * cache for the results of nearest-color searches. All colors within a + * histogram cell will be mapped to the same colormap entry, namely the one + * closest to the cell's center. This may not be quite the closest entry to + * the actual input color, but it's almost as good. A zero in the cache + * indicates we haven't found the nearest color for that cell yet; the array + * is cleared to zeroes before starting the mapping pass. When we find the + * nearest color for a cell, its colormap index plus one is recorded in the + * cache for future use. The pass2 scanning routines call fill_inverse_cmap + * when they need to use an unfilled entry in the cache. + * + * Our method of efficiently finding nearest colors is based on the "locally + * sorted search" idea described by Heckbert and on the incremental distance + * calculation described by Spencer W. Thomas in chapter III.1 of Graphics + * Gems II (James Arvo, ed. Academic Press, 1991). Thomas points out that + * the distances from a given colormap entry to each cell of the histogram can + * be computed quickly using an incremental method: the differences between + * distances to adjacent cells themselves differ by a constant. This allows a + * fairly fast implementation of the "brute force" approach of computing the + * distance from every colormap entry to every histogram cell. Unfortunately, + * it needs a work array to hold the best-distance-so-far for each histogram + * cell (because the inner loop has to be over cells, not colormap entries). + * The work array elements have to be INT32s, so the work array would need + * 256Kb at our recommended precision. This is not feasible in DOS machines. + * + * To get around these problems, we apply Thomas' method to compute the + * nearest colors for only the cells within a small subbox of the histogram. + * The work array need be only as big as the subbox, so the memory usage + * problem is solved. Furthermore, we need not fill subboxes that are never + * referenced in pass2; many images use only part of the color gamut, so a + * fair amount of work is saved. An additional advantage of this + * approach is that we can apply Heckbert's locality criterion to quickly + * eliminate colormap entries that are far away from the subbox; typically + * three-fourths of the colormap entries are rejected by Heckbert's criterion, + * and we need not compute their distances to individual cells in the subbox. + * The speed of this approach is heavily influenced by the subbox size: too + * small means too much overhead, too big loses because Heckbert's criterion + * can't eliminate as many colormap entries. Empirically the best subbox + * size seems to be about 1/512th of the histogram (1/8th in each direction). + * + * Thomas' article also describes a refined method which is asymptotically + * faster than the brute-force method, but it is also far more complex and + * cannot efficiently be applied to small subboxes. It is therefore not + * useful for programs intended to be portable to DOS machines. On machines + * with plenty of memory, filling the whole histogram in one shot with Thomas' + * refined method might be faster than the present code --- but then again, + * it might not be any faster, and it's certainly more complicated. + */ + + +/* log2(histogram cells in update box) for each axis; this can be adjusted */ +#define BOX_C0_LOG (HIST_C0_BITS-3) +#define BOX_C1_LOG (HIST_C1_BITS-3) +#define BOX_C2_LOG (HIST_C2_BITS-3) + +#define BOX_C0_ELEMS (1<actual_number_of_colors; + int maxc0, maxc1, maxc2; + int centerc0, centerc1, centerc2; + int i, x, ncolors; + long minmaxdist, min_dist, max_dist, tdist; + long mindist[MAXNUMCOLORS]; /* min distance to colormap entry i */ + + /* Compute true coordinates of update box's upper corner and center. + * Actually we compute the coordinates of the center of the upper-corner + * histogram cell, which are the upper bounds of the volume we care about. + * Note that since ">>" rounds down, the "center" values may be closer to + * min than to max; hence comparisons to them must be "<=", not "<". + */ + maxc0 = minc0 + ((1 << BOX_C0_SHIFT) - (1 << C0_SHIFT)); + centerc0 = (minc0 + maxc0) >> 1; + maxc1 = minc1 + ((1 << BOX_C1_SHIFT) - (1 << C1_SHIFT)); + centerc1 = (minc1 + maxc1) >> 1; + maxc2 = minc2 + ((1 << BOX_C2_SHIFT) - (1 << C2_SHIFT)); + centerc2 = (minc2 + maxc2) >> 1; + + /* For each color in colormap, find: + * 1. its minimum squared-distance to any point in the update box + * (zero if color is within update box); + * 2. its maximum squared-distance to any point in the update box. + * Both of these can be found by considering only the corners of the box. + * We save the minimum distance for each color in mindist[]; + * only the smallest maximum distance is of interest. + */ + minmaxdist = 0x7FFFFFFFL; + + for (i = 0; i < numcolors; i++) { + /* We compute the squared-c0-distance term, then add in the other two. */ + x = GETJSAMPLE(cinfo->colormap[0][i]); + if (x < minc0) { + tdist = (x - minc0) * C0_SCALE; + min_dist = tdist*tdist; + tdist = (x - maxc0) * C0_SCALE; + max_dist = tdist*tdist; + } else if (x > maxc0) { + tdist = (x - maxc0) * C0_SCALE; + min_dist = tdist*tdist; + tdist = (x - minc0) * C0_SCALE; + max_dist = tdist*tdist; + } else { + /* within cell range so no contribution to min_dist */ + min_dist = 0; + if (x <= centerc0) { + tdist = (x - maxc0) * C0_SCALE; + max_dist = tdist*tdist; + } else { + tdist = (x - minc0) * C0_SCALE; + max_dist = tdist*tdist; + } + } + + x = GETJSAMPLE(cinfo->colormap[1][i]); + if (x < minc1) { + tdist = (x - minc1) * C1_SCALE; + min_dist += tdist*tdist; + tdist = (x - maxc1) * C1_SCALE; + max_dist += tdist*tdist; + } else if (x > maxc1) { + tdist = (x - maxc1) * C1_SCALE; + min_dist += tdist*tdist; + tdist = (x - minc1) * C1_SCALE; + max_dist += tdist*tdist; + } else { + /* within cell range so no contribution to min_dist */ + if (x <= centerc1) { + tdist = (x - maxc1) * C1_SCALE; + max_dist += tdist*tdist; + } else { + tdist = (x - minc1) * C1_SCALE; + max_dist += tdist*tdist; + } + } + + x = GETJSAMPLE(cinfo->colormap[2][i]); + if (x < minc2) { + tdist = (x - minc2) * C2_SCALE; + min_dist += tdist*tdist; + tdist = (x - maxc2) * C2_SCALE; + max_dist += tdist*tdist; + } else if (x > maxc2) { + tdist = (x - maxc2) * C2_SCALE; + min_dist += tdist*tdist; + tdist = (x - minc2) * C2_SCALE; + max_dist += tdist*tdist; + } else { + /* within cell range so no contribution to min_dist */ + if (x <= centerc2) { + tdist = (x - maxc2) * C2_SCALE; + max_dist += tdist*tdist; + } else { + tdist = (x - minc2) * C2_SCALE; + max_dist += tdist*tdist; + } + } + + mindist[i] = min_dist; /* save away the results */ + if (max_dist < minmaxdist) + minmaxdist = max_dist; + } + + /* Now we know that no cell in the update box is more than minmaxdist + * away from some colormap entry. Therefore, only colors that are + * within minmaxdist of some part of the box need be considered. + */ + ncolors = 0; + for (i = 0; i < numcolors; i++) { + if (mindist[i] <= minmaxdist) + colorlist[ncolors++] = (JSAMPLE) i; + } + return ncolors; +} + + +LOCAL(void) +find_best_colors (j_decompress_ptr cinfo, int minc0, int minc1, int minc2, + int numcolors, JSAMPLE colorlist[], JSAMPLE bestcolor[]) +/* Find the closest colormap entry for each cell in the update box, + * given the list of candidate colors prepared by find_nearby_colors. + * Return the indexes of the closest entries in the bestcolor[] array. + * This routine uses Thomas' incremental distance calculation method to + * find the distance from a colormap entry to successive cells in the box. + */ +{ + int ic0, ic1, ic2; + int i, icolor; + long * bptr; /* pointer into bestdist[] array */ + JSAMPLE * cptr; /* pointer into bestcolor[] array */ + long dist0, dist1; /* initial distance values */ + long dist2; /* current distance in inner loop */ + long xx0, xx1; /* distance increments */ + long xx2; + long inc0, inc1, inc2; /* initial values for increments */ + /* This array holds the distance to the nearest-so-far color for each cell */ + long bestdist[BOX_C0_ELEMS * BOX_C1_ELEMS * BOX_C2_ELEMS]; + + /* Initialize best-distance for each cell of the update box */ + bptr = bestdist; + for (i = BOX_C0_ELEMS*BOX_C1_ELEMS*BOX_C2_ELEMS-1; i >= 0; i--) + *bptr++ = 0x7FFFFFFFL; + + /* For each color selected by find_nearby_colors, + * compute its distance to the center of each cell in the box. + * If that's less than best-so-far, update best distance and color number. + */ + + /* Nominal steps between cell centers ("x" in Thomas article) */ +#define STEP_C0 ((1 << C0_SHIFT) * C0_SCALE) +#define STEP_C1 ((1 << C1_SHIFT) * C1_SCALE) +#define STEP_C2 ((1 << C2_SHIFT) * C2_SCALE) + + for (i = 0; i < numcolors; i++) { + icolor = GETJSAMPLE(colorlist[i]); + /* Compute (square of) distance from minc0/c1/c2 to this color */ + inc0 = (minc0 - GETJSAMPLE(cinfo->colormap[0][icolor])) * C0_SCALE; + dist0 = inc0*inc0; + inc1 = (minc1 - GETJSAMPLE(cinfo->colormap[1][icolor])) * C1_SCALE; + dist0 += inc1*inc1; + inc2 = (minc2 - GETJSAMPLE(cinfo->colormap[2][icolor])) * C2_SCALE; + dist0 += inc2*inc2; + /* Form the initial difference increments */ + inc0 = inc0 * (2 * STEP_C0) + STEP_C0 * STEP_C0; + inc1 = inc1 * (2 * STEP_C1) + STEP_C1 * STEP_C1; + inc2 = inc2 * (2 * STEP_C2) + STEP_C2 * STEP_C2; + /* Now loop over all cells in box, updating distance per Thomas method */ + bptr = bestdist; + cptr = bestcolor; + xx0 = inc0; + for (ic0 = BOX_C0_ELEMS-1; ic0 >= 0; ic0--) { + dist1 = dist0; + xx1 = inc1; + for (ic1 = BOX_C1_ELEMS-1; ic1 >= 0; ic1--) { + dist2 = dist1; + xx2 = inc2; + for (ic2 = BOX_C2_ELEMS-1; ic2 >= 0; ic2--) { + if (dist2 < *bptr) { + *bptr = dist2; + *cptr = (JSAMPLE) icolor; + } + dist2 += xx2; + xx2 += 2 * STEP_C2 * STEP_C2; + bptr++; + cptr++; + } + dist1 += xx1; + xx1 += 2 * STEP_C1 * STEP_C1; + } + dist0 += xx0; + xx0 += 2 * STEP_C0 * STEP_C0; + } + } +} + + +LOCAL(void) +fill_inverse_cmap (j_decompress_ptr cinfo, int c0, int c1, int c2) +/* Fill the inverse-colormap entries in the update box that contains */ +/* histogram cell c0/c1/c2. (Only that one cell MUST be filled, but */ +/* we can fill as many others as we wish.) */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + hist3d histogram = cquantize->histogram; + int minc0, minc1, minc2; /* lower left corner of update box */ + int ic0, ic1, ic2; + JSAMPLE * cptr; /* pointer into bestcolor[] array */ + histptr cachep; /* pointer into main cache array */ + /* This array lists the candidate colormap indexes. */ + JSAMPLE colorlist[MAXNUMCOLORS]; + int numcolors; /* number of candidate colors */ + /* This array holds the actually closest colormap index for each cell. */ + JSAMPLE bestcolor[BOX_C0_ELEMS * BOX_C1_ELEMS * BOX_C2_ELEMS]; + + /* Convert cell coordinates to update box ID */ + c0 >>= BOX_C0_LOG; + c1 >>= BOX_C1_LOG; + c2 >>= BOX_C2_LOG; + + /* Compute true coordinates of update box's origin corner. + * Actually we compute the coordinates of the center of the corner + * histogram cell, which are the lower bounds of the volume we care about. + */ + minc0 = (c0 << BOX_C0_SHIFT) + ((1 << C0_SHIFT) >> 1); + minc1 = (c1 << BOX_C1_SHIFT) + ((1 << C1_SHIFT) >> 1); + minc2 = (c2 << BOX_C2_SHIFT) + ((1 << C2_SHIFT) >> 1); + + /* Determine which colormap entries are close enough to be candidates + * for the nearest entry to some cell in the update box. + */ + numcolors = find_nearby_colors(cinfo, minc0, minc1, minc2, colorlist); + + /* Determine the actually nearest colors. */ + find_best_colors(cinfo, minc0, minc1, minc2, numcolors, colorlist, + bestcolor); + + /* Save the best color numbers (plus 1) in the main cache array */ + c0 <<= BOX_C0_LOG; /* convert ID back to base cell indexes */ + c1 <<= BOX_C1_LOG; + c2 <<= BOX_C2_LOG; + cptr = bestcolor; + for (ic0 = 0; ic0 < BOX_C0_ELEMS; ic0++) { + for (ic1 = 0; ic1 < BOX_C1_ELEMS; ic1++) { + cachep = & histogram[c0+ic0][c1+ic1][c2]; + for (ic2 = 0; ic2 < BOX_C2_ELEMS; ic2++) { + *cachep++ = (histcell) (GETJSAMPLE(*cptr++) + 1); + } + } + } +} + + +/* + * Map some rows of pixels to the output colormapped representation. + */ + +METHODDEF(void) +pass2_no_dither (j_decompress_ptr cinfo, + JSAMPARRAY input_buf, JSAMPARRAY output_buf, int num_rows) +/* This version performs no dithering */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + hist3d histogram = cquantize->histogram; + JSAMPROW inptr, outptr; + histptr cachep; + int c0, c1, c2; + int row; + JDIMENSION col; + JDIMENSION width = cinfo->output_width; + + for (row = 0; row < num_rows; row++) { + inptr = input_buf[row]; + outptr = output_buf[row]; + for (col = width; col > 0; col--) { + /* get pixel value and index into the cache */ + c0 = GETJSAMPLE(*inptr++) >> C0_SHIFT; + c1 = GETJSAMPLE(*inptr++) >> C1_SHIFT; + c2 = GETJSAMPLE(*inptr++) >> C2_SHIFT; + cachep = & histogram[c0][c1][c2]; + /* If we have not seen this color before, find nearest colormap entry */ + /* and update the cache */ + if (*cachep == 0) + fill_inverse_cmap(cinfo, c0,c1,c2); + /* Now emit the colormap index for this cell */ + *outptr++ = (JSAMPLE) (*cachep - 1); + } + } +} + + +METHODDEF(void) +pass2_fs_dither (j_decompress_ptr cinfo, + JSAMPARRAY input_buf, JSAMPARRAY output_buf, int num_rows) +/* This version performs Floyd-Steinberg dithering */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + hist3d histogram = cquantize->histogram; + LOCFSERROR cur0, cur1, cur2; /* current error or pixel value */ + LOCFSERROR belowerr0, belowerr1, belowerr2; /* error for pixel below cur */ + LOCFSERROR bpreverr0, bpreverr1, bpreverr2; /* error for below/prev col */ + FSERRPTR errorptr; /* => fserrors[] at column before current */ + JSAMPROW inptr; /* => current input pixel */ + JSAMPROW outptr; /* => current output pixel */ + histptr cachep; + int dir; /* +1 or -1 depending on direction */ + int dir3; /* 3*dir, for advancing inptr & errorptr */ + int row; + JDIMENSION col; + JDIMENSION width = cinfo->output_width; + JSAMPLE *range_limit = cinfo->sample_range_limit; + int *error_limit = cquantize->error_limiter; + JSAMPROW colormap0 = cinfo->colormap[0]; + JSAMPROW colormap1 = cinfo->colormap[1]; + JSAMPROW colormap2 = cinfo->colormap[2]; + SHIFT_TEMPS + + for (row = 0; row < num_rows; row++) { + inptr = input_buf[row]; + outptr = output_buf[row]; + if (cquantize->on_odd_row) { + /* work right to left in this row */ + inptr += (width-1) * 3; /* so point to rightmost pixel */ + outptr += width-1; + dir = -1; + dir3 = -3; + errorptr = cquantize->fserrors + (width+1)*3; /* => entry after last column */ + cquantize->on_odd_row = FALSE; /* flip for next time */ + } else { + /* work left to right in this row */ + dir = 1; + dir3 = 3; + errorptr = cquantize->fserrors; /* => entry before first real column */ + cquantize->on_odd_row = TRUE; /* flip for next time */ + } + /* Preset error values: no error propagated to first pixel from left */ + cur0 = cur1 = cur2 = 0; + /* and no error propagated to row below yet */ + belowerr0 = belowerr1 = belowerr2 = 0; + bpreverr0 = bpreverr1 = bpreverr2 = 0; + + for (col = width; col > 0; col--) { + /* curN holds the error propagated from the previous pixel on the + * current line. Add the error propagated from the previous line + * to form the complete error correction term for this pixel, and + * round the error term (which is expressed * 16) to an integer. + * RIGHT_SHIFT rounds towards minus infinity, so adding 8 is correct + * for either sign of the error value. + * Note: errorptr points to *previous* column's array entry. + */ + cur0 = RIGHT_SHIFT(cur0 + errorptr[dir3+0] + 8, 4); + cur1 = RIGHT_SHIFT(cur1 + errorptr[dir3+1] + 8, 4); + cur2 = RIGHT_SHIFT(cur2 + errorptr[dir3+2] + 8, 4); + /* Limit the error using transfer function set by init_error_limit. + * See comments with init_error_limit for rationale. + */ + cur0 = error_limit[cur0]; + cur1 = error_limit[cur1]; + cur2 = error_limit[cur2]; + /* Form pixel value + error, and range-limit to 0..MAXJSAMPLE. + * The maximum error is +- MAXJSAMPLE (or less with error limiting); + * this sets the required size of the range_limit array. + */ + cur0 += GETJSAMPLE(inptr[0]); + cur1 += GETJSAMPLE(inptr[1]); + cur2 += GETJSAMPLE(inptr[2]); + cur0 = GETJSAMPLE(range_limit[cur0]); + cur1 = GETJSAMPLE(range_limit[cur1]); + cur2 = GETJSAMPLE(range_limit[cur2]); + /* Index into the cache with adjusted pixel value */ + cachep = & histogram[cur0>>C0_SHIFT][cur1>>C1_SHIFT][cur2>>C2_SHIFT]; + /* If we have not seen this color before, find nearest colormap */ + /* entry and update the cache */ + if (*cachep == 0) + fill_inverse_cmap(cinfo, cur0>>C0_SHIFT,cur1>>C1_SHIFT,cur2>>C2_SHIFT); + /* Now emit the colormap index for this cell */ + { int pixcode = *cachep - 1; + *outptr = (JSAMPLE) pixcode; + /* Compute representation error for this pixel */ + cur0 -= GETJSAMPLE(colormap0[pixcode]); + cur1 -= GETJSAMPLE(colormap1[pixcode]); + cur2 -= GETJSAMPLE(colormap2[pixcode]); + } + /* Compute error fractions to be propagated to adjacent pixels. + * Add these into the running sums, and simultaneously shift the + * next-line error sums left by 1 column. + */ + { LOCFSERROR bnexterr, delta; + + bnexterr = cur0; /* Process component 0 */ + delta = cur0 * 2; + cur0 += delta; /* form error * 3 */ + errorptr[0] = (FSERROR) (bpreverr0 + cur0); + cur0 += delta; /* form error * 5 */ + bpreverr0 = belowerr0 + cur0; + belowerr0 = bnexterr; + cur0 += delta; /* form error * 7 */ + bnexterr = cur1; /* Process component 1 */ + delta = cur1 * 2; + cur1 += delta; /* form error * 3 */ + errorptr[1] = (FSERROR) (bpreverr1 + cur1); + cur1 += delta; /* form error * 5 */ + bpreverr1 = belowerr1 + cur1; + belowerr1 = bnexterr; + cur1 += delta; /* form error * 7 */ + bnexterr = cur2; /* Process component 2 */ + delta = cur2 * 2; + cur2 += delta; /* form error * 3 */ + errorptr[2] = (FSERROR) (bpreverr2 + cur2); + cur2 += delta; /* form error * 5 */ + bpreverr2 = belowerr2 + cur2; + belowerr2 = bnexterr; + cur2 += delta; /* form error * 7 */ + } + /* At this point curN contains the 7/16 error value to be propagated + * to the next pixel on the current line, and all the errors for the + * next line have been shifted over. We are therefore ready to move on. + */ + inptr += dir3; /* Advance pixel pointers to next column */ + outptr += dir; + errorptr += dir3; /* advance errorptr to current column */ + } + /* Post-loop cleanup: we must unload the final error values into the + * final fserrors[] entry. Note we need not unload belowerrN because + * it is for the dummy column before or after the actual array. + */ + errorptr[0] = (FSERROR) bpreverr0; /* unload prev errs into array */ + errorptr[1] = (FSERROR) bpreverr1; + errorptr[2] = (FSERROR) bpreverr2; + } +} + + +/* + * Initialize the error-limiting transfer function (lookup table). + * The raw F-S error computation can potentially compute error values of up to + * +- MAXJSAMPLE. But we want the maximum correction applied to a pixel to be + * much less, otherwise obviously wrong pixels will be created. (Typical + * effects include weird fringes at color-area boundaries, isolated bright + * pixels in a dark area, etc.) The standard advice for avoiding this problem + * is to ensure that the "corners" of the color cube are allocated as output + * colors; then repeated errors in the same direction cannot cause cascading + * error buildup. However, that only prevents the error from getting + * completely out of hand; Aaron Giles reports that error limiting improves + * the results even with corner colors allocated. + * A simple clamping of the error values to about +- MAXJSAMPLE/8 works pretty + * well, but the smoother transfer function used below is even better. Thanks + * to Aaron Giles for this idea. + */ + +LOCAL(void) +init_error_limit (j_decompress_ptr cinfo) +/* Allocate and fill in the error_limiter table */ +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + int * table; + int in, out; + + table = (int *) (*cinfo->mem->alloc_small) + ((j_common_ptr) cinfo, JPOOL_IMAGE, (MAXJSAMPLE*2+1) * SIZEOF(int)); + table += MAXJSAMPLE; /* so can index -MAXJSAMPLE .. +MAXJSAMPLE */ + cquantize->error_limiter = table; + +#define STEPSIZE ((MAXJSAMPLE+1)/16) + /* Map errors 1:1 up to +- MAXJSAMPLE/16 */ + out = 0; + for (in = 0; in < STEPSIZE; in++, out++) { + table[in] = out; table[-in] = -out; + } + /* Map errors 1:2 up to +- 3*MAXJSAMPLE/16 */ + for (; in < STEPSIZE*3; in++, out += (in&1) ? 0 : 1) { + table[in] = out; table[-in] = -out; + } + /* Clamp the rest to final out value (which is (MAXJSAMPLE+1)/8) */ + for (; in <= MAXJSAMPLE; in++) { + table[in] = out; table[-in] = -out; + } +#undef STEPSIZE +} + + +/* + * Finish up at the end of each pass. + */ + +METHODDEF(void) +finish_pass1 (j_decompress_ptr cinfo) +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + + /* Select the representative colors and fill in cinfo->colormap */ + cinfo->colormap = cquantize->sv_colormap; + select_colors(cinfo, cquantize->desired); + /* Force next pass to zero the color index table */ + cquantize->needs_zeroed = TRUE; +} + + +METHODDEF(void) +finish_pass2 (j_decompress_ptr ) +{ + /* no work */ +} + + +/* + * Initialize for each processing pass. + */ + +METHODDEF(void) +start_pass_2_quant (j_decompress_ptr cinfo, int is_pre_scan) +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + hist3d histogram = cquantize->histogram; + int i; + + /* Only F-S dithering or no dithering is supported. */ + /* If user asks for ordered dither, give him F-S. */ + if (cinfo->dither_mode != JDITHER_NONE) + cinfo->dither_mode = JDITHER_FS; + + if (is_pre_scan) { + /* Set up method pointers */ + cquantize->pub.color_quantize = prescan_quantize; + cquantize->pub.finish_pass = finish_pass1; + cquantize->needs_zeroed = TRUE; /* Always zero histogram */ + } else { + /* Set up method pointers */ + if (cinfo->dither_mode == JDITHER_FS) + cquantize->pub.color_quantize = pass2_fs_dither; + else + cquantize->pub.color_quantize = pass2_no_dither; + cquantize->pub.finish_pass = finish_pass2; + + /* Make sure color count is acceptable */ + i = cinfo->actual_number_of_colors; + if (i < 1) + ERREXIT1(cinfo, JERR_QUANT_FEW_COLORS, 1); + if (i > MAXNUMCOLORS) + ERREXIT1(cinfo, JERR_QUANT_MANY_COLORS, MAXNUMCOLORS); + + if (cinfo->dither_mode == JDITHER_FS) { + size_t arraysize = (size_t) ((cinfo->output_width + 2) * + (3 * SIZEOF(FSERROR))); + /* Allocate Floyd-Steinberg workspace if we didn't already. */ + if (cquantize->fserrors == NULL) + cquantize->fserrors = (FSERRPTR) (*cinfo->mem->alloc_large) + ((j_common_ptr) cinfo, JPOOL_IMAGE, arraysize); + /* Initialize the propagated errors to zero. */ + jzero_far((void FAR *) cquantize->fserrors, arraysize); + /* Make the error-limit table if we didn't already. */ + if (cquantize->error_limiter == NULL) + init_error_limit(cinfo); + cquantize->on_odd_row = FALSE; + } + + } + /* Zero the histogram or inverse color map, if necessary */ + if (cquantize->needs_zeroed) { + for (i = 0; i < HIST_C0_ELEMS; i++) { + jzero_far((void FAR *) histogram[i], + HIST_C1_ELEMS*HIST_C2_ELEMS * SIZEOF(histcell)); + } + cquantize->needs_zeroed = FALSE; + } +} + + +/* + * Switch to a new external colormap between output passes. + */ + +METHODDEF(void) +new_color_map_2_quant (j_decompress_ptr cinfo) +{ + my_cquantize_ptr cquantize = (my_cquantize_ptr) cinfo->cquantize; + + /* Reset the inverse color map */ + cquantize->needs_zeroed = TRUE; +} + + +/* + * Module initialization routine for 2-pass color quantization. + */ + +GLOBAL(void) +jinit_2pass_quantizer (j_decompress_ptr cinfo) +{ + my_cquantize_ptr cquantize; + int i; + + cquantize = (my_cquantize_ptr) + (*cinfo->mem->alloc_small) ((j_common_ptr) cinfo, JPOOL_IMAGE, + SIZEOF(my_cquantizer)); + cinfo->cquantize = (struct jpeg_color_quantizer *) cquantize; + cquantize->pub.start_pass = start_pass_2_quant; + cquantize->pub.new_color_map = new_color_map_2_quant; + cquantize->fserrors = NULL; /* flag optional arrays not allocated */ + cquantize->error_limiter = NULL; + + /* Make sure jdmaster didn't give me a case I can't handle */ + if (cinfo->out_color_components != 3) + ERREXIT(cinfo, JERR_NOTIMPL); + + /* Allocate the histogram/inverse colormap storage */ + cquantize->histogram = (hist3d) (*cinfo->mem->alloc_small) + ((j_common_ptr) cinfo, JPOOL_IMAGE, HIST_C0_ELEMS * SIZEOF(hist2d)); + for (i = 0; i < HIST_C0_ELEMS; i++) { + cquantize->histogram[i] = (hist2d) (*cinfo->mem->alloc_large) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + HIST_C1_ELEMS*HIST_C2_ELEMS * SIZEOF(histcell)); + } + cquantize->needs_zeroed = TRUE; /* histogram is garbage now */ + + /* Allocate storage for the completed colormap, if required. + * We do this now since it is FAR storage and may affect + * the memory manager's space calculations. + */ + if (cinfo->enable_2pass_quant) { + /* Make sure color count is acceptable */ + int desired = cinfo->desired_number_of_colors; + /* Lower bound on # of colors ... somewhat arbitrary as long as > 0 */ + if (desired < 8) + ERREXIT1(cinfo, JERR_QUANT_FEW_COLORS, 8); + /* Make sure colormap indexes can be represented by JSAMPLEs */ + if (desired > MAXNUMCOLORS) + ERREXIT1(cinfo, JERR_QUANT_MANY_COLORS, MAXNUMCOLORS); + cquantize->sv_colormap = (*cinfo->mem->alloc_sarray) + ((j_common_ptr) cinfo,JPOOL_IMAGE, (JDIMENSION) desired, (JDIMENSION) 3); + cquantize->desired = desired; + } else + cquantize->sv_colormap = NULL; + + /* Only F-S dithering or no dithering is supported. */ + /* If user asks for ordered dither, give him F-S. */ + if (cinfo->dither_mode != JDITHER_NONE) + cinfo->dither_mode = JDITHER_FS; + + /* Allocate Floyd-Steinberg workspace if necessary. + * This isn't really needed until pass 2, but again it is FAR storage. + * Although we will cope with a later change in dither_mode, + * we do not promise to honor max_memory_to_use if dither_mode changes. + */ + if (cinfo->dither_mode == JDITHER_FS) { + cquantize->fserrors = (FSERRPTR) (*cinfo->mem->alloc_large) + ((j_common_ptr) cinfo, JPOOL_IMAGE, + (size_t) ((cinfo->output_width + 2) * (3 * SIZEOF(FSERROR)))); + /* Might as well create the error-limiting table too. */ + init_error_limit(cinfo); + } +} + +#endif /* QUANT_2PASS_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libjpeg/jutils.cpp b/ml/dlib/dlib/external/libjpeg/jutils.cpp new file mode 100644 index 000000000..fd8906c83 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jutils.cpp @@ -0,0 +1,179 @@ +/* + * jutils.c + * + * Copyright (C) 1991-1996, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains tables and miscellaneous utility routines needed + * for both compression and decompression. + * Note we prefix all global names with "j" to minimize conflicts with + * a surrounding application. + */ + +#define JPEG_INTERNALS +#include "jinclude.h" +#include "jpeglib.h" + + +/* + * jpeg_zigzag_order[i] is the zigzag-order position of the i'th element + * of a DCT block read in natural order (left to right, top to bottom). + */ + +#if 0 /* This table is not actually needed in v6a */ + +const int jpeg_zigzag_order[DCTSIZE2] = { + 0, 1, 5, 6, 14, 15, 27, 28, + 2, 4, 7, 13, 16, 26, 29, 42, + 3, 8, 12, 17, 25, 30, 41, 43, + 9, 11, 18, 24, 31, 40, 44, 53, + 10, 19, 23, 32, 39, 45, 52, 54, + 20, 22, 33, 38, 46, 51, 55, 60, + 21, 34, 37, 47, 50, 56, 59, 61, + 35, 36, 48, 49, 57, 58, 62, 63 +}; + +#endif + +/* + * jpeg_natural_order[i] is the natural-order position of the i'th element + * of zigzag order. + * + * When reading corrupted data, the Huffman decoders could attempt + * to reference an entry beyond the end of this array (if the decoded + * zero run length reaches past the end of the block). To prevent + * wild stores without adding an inner-loop test, we put some extra + * "63"s after the real entries. This will cause the extra coefficient + * to be stored in location 63 of the block, not somewhere random. + * The worst case would be a run-length of 15, which means we need 16 + * fake entries. + */ + +const int jpeg_natural_order[DCTSIZE2+16] = { + 0, 1, 8, 16, 9, 2, 3, 10, + 17, 24, 32, 25, 18, 11, 4, 5, + 12, 19, 26, 33, 40, 48, 41, 34, + 27, 20, 13, 6, 7, 14, 21, 28, + 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, + 58, 59, 52, 45, 38, 31, 39, 46, + 53, 60, 61, 54, 47, 55, 62, 63, + 63, 63, 63, 63, 63, 63, 63, 63, /* extra entries for safety in decoder */ + 63, 63, 63, 63, 63, 63, 63, 63 +}; + + +/* + * Arithmetic utilities + */ + +GLOBAL(long) +jdiv_round_up (long a, long b) +/* Compute a/b rounded up to next integer, ie, ceil(a/b) */ +/* Assumes a >= 0, b > 0 */ +{ + return (a + b - 1L) / b; +} + + +GLOBAL(long) +jround_up (long a, long b) +/* Compute a rounded up to next multiple of b, ie, ceil(a/b)*b */ +/* Assumes a >= 0, b > 0 */ +{ + a += b - 1L; + return a - (a % b); +} + + +/* On normal machines we can apply MEMCOPY() and MEMZERO() to sample arrays + * and coefficient-block arrays. This won't work on 80x86 because the arrays + * are FAR and we're assuming a small-pointer memory model. However, some + * DOS compilers provide far-pointer versions of memcpy() and memset() even + * in the small-model libraries. These will be used if USE_FMEM is defined. + * Otherwise, the routines below do it the hard way. (The performance cost + * is not all that great, because these routines aren't very heavily used.) + */ + +#ifndef NEED_FAR_POINTERS /* normal case, same as regular macros */ +#define FMEMCOPY(dest,src,size) MEMCOPY(dest,src,size) +#define FMEMZERO(target,size) MEMZERO(target,size) +#else /* 80x86 case, define if we can */ +#ifdef USE_FMEM +#define FMEMCOPY(dest,src,size) _fmemcpy((void FAR *)(dest), (const void FAR *)(src), (size_t)(size)) +#define FMEMZERO(target,size) _fmemset((void FAR *)(target), 0, (size_t)(size)) +#endif +#endif + + +GLOBAL(void) +jcopy_sample_rows (JSAMPARRAY input_array, int source_row, + JSAMPARRAY output_array, int dest_row, + int num_rows, JDIMENSION num_cols) +/* Copy some rows of samples from one place to another. + * num_rows rows are copied from input_array[source_row++] + * to output_array[dest_row++]; these areas may overlap for duplication. + * The source and destination arrays must be at least as wide as num_cols. + */ +{ + JSAMPROW inptr, outptr; +#ifdef FMEMCOPY + size_t count = (size_t) (num_cols * SIZEOF(JSAMPLE)); +#else + JDIMENSION count; +#endif + int row; + + input_array += source_row; + output_array += dest_row; + + for (row = num_rows; row > 0; row--) { + inptr = *input_array++; + outptr = *output_array++; +#ifdef FMEMCOPY + FMEMCOPY(outptr, inptr, count); +#else + for (count = num_cols; count > 0; count--) + *outptr++ = *inptr++; /* needn't bother with GETJSAMPLE() here */ +#endif + } +} + + +GLOBAL(void) +jcopy_block_row (JBLOCKROW input_row, JBLOCKROW output_row, + JDIMENSION num_blocks) +/* Copy a row of coefficient blocks from one place to another. */ +{ +#ifdef FMEMCOPY + FMEMCOPY(output_row, input_row, num_blocks * (DCTSIZE2 * SIZEOF(JCOEF))); +#else + JCOEFPTR inptr, outptr; + long count; + + inptr = (JCOEFPTR) input_row; + outptr = (JCOEFPTR) output_row; + for (count = (long) num_blocks * DCTSIZE2; count > 0; count--) { + *outptr++ = *inptr++; + } +#endif +} + + +GLOBAL(void) +jzero_far (void FAR * target, size_t bytestozero) +/* Zero out a chunk of FAR memory. */ +/* This might be sample-array data, block-array data, or alloc_large data. */ +{ +#ifdef FMEMZERO + FMEMZERO(target, bytestozero); +#else + char FAR * ptr = (char FAR *) target; + size_t count; + + for (count = bytestozero; count > 0; count--) { + *ptr++ = 0; + } +#endif +} diff --git a/ml/dlib/dlib/external/libjpeg/jversion.h b/ml/dlib/dlib/external/libjpeg/jversion.h new file mode 100644 index 000000000..6472c58d3 --- /dev/null +++ b/ml/dlib/dlib/external/libjpeg/jversion.h @@ -0,0 +1,14 @@ +/* + * jversion.h + * + * Copyright (C) 1991-1998, Thomas G. Lane. + * This file is part of the Independent JPEG Group's software. + * For conditions of distribution and use, see the accompanying README file. + * + * This file contains software version identification. + */ + + +#define JVERSION "6b 27-Mar-1998" + +#define JCOPYRIGHT "Copyright (C) 1998, Thomas G. Lane" diff --git a/ml/dlib/dlib/external/libpng/LICENSE b/ml/dlib/dlib/external/libpng/LICENSE new file mode 100644 index 000000000..b1b97ea57 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/LICENSE @@ -0,0 +1,111 @@ + +This copy of the libpng notices is provided for your convenience. In case of +any discrepancy between this copy and the notices in the file png.h that is +included in the libpng distribution, the latter shall prevail. + +COPYRIGHT NOTICE, DISCLAIMER, and LICENSE: + +If you modify libpng you may insert additional notices immediately following +this sentence. + +This code is released under the libpng license. + +libpng versions 1.2.6, August 15, 2004, through 1.6.7, November 14, 2013, are +Copyright (c) 2004, 2006-2013 Glenn Randers-Pehrson, and are +distributed according to the same disclaimer and license as libpng-1.2.5 +with the following individual added to the list of Contributing Authors + + Cosmin Truta + +libpng versions 1.0.7, July 1, 2000, through 1.2.5 - October 3, 2002, are +Copyright (c) 2000-2002 Glenn Randers-Pehrson, and are +distributed according to the same disclaimer and license as libpng-1.0.6 +with the following individuals added to the list of Contributing Authors + + Simon-Pierre Cadieux + Eric S. Raymond + Gilles Vollant + +and with the following additions to the disclaimer: + + There is no warranty against interference with your enjoyment of the + library or against infringement. There is no warranty that our + efforts or the library will fulfill any of your particular purposes + or needs. This library is provided with all faults, and the entire + risk of satisfactory quality, performance, accuracy, and effort is with + the user. + +libpng versions 0.97, January 1998, through 1.0.6, March 20, 2000, are +Copyright (c) 1998, 1999 Glenn Randers-Pehrson, and are +distributed according to the same disclaimer and license as libpng-0.96, +with the following individuals added to the list of Contributing Authors: + + Tom Lane + Glenn Randers-Pehrson + Willem van Schaik + +libpng versions 0.89, June 1996, through 0.96, May 1997, are +Copyright (c) 1996, 1997 Andreas Dilger +Distributed according to the same disclaimer and license as libpng-0.88, +with the following individuals added to the list of Contributing Authors: + + John Bowler + Kevin Bracey + Sam Bushell + Magnus Holmgren + Greg Roelofs + Tom Tanner + +libpng versions 0.5, May 1995, through 0.88, January 1996, are +Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc. + +For the purposes of this copyright and license, "Contributing Authors" +is defined as the following set of individuals: + + Andreas Dilger + Dave Martindale + Guy Eric Schalnat + Paul Schmidt + Tim Wegner + +The PNG Reference Library is supplied "AS IS". The Contributing Authors +and Group 42, Inc. disclaim all warranties, expressed or implied, +including, without limitation, the warranties of merchantability and of +fitness for any purpose. The Contributing Authors and Group 42, Inc. +assume no liability for direct, indirect, incidental, special, exemplary, +or consequential damages, which may result from the use of the PNG +Reference Library, even if advised of the possibility of such damage. + +Permission is hereby granted to use, copy, modify, and distribute this +source code, or portions hereof, for any purpose, without fee, subject +to the following restrictions: + +1. The origin of this source code must not be misrepresented. + +2. Altered versions must be plainly marked as such and must not + be misrepresented as being the original source. + +3. This Copyright notice may not be removed or altered from any + source or altered source distribution. + +The Contributing Authors and Group 42, Inc. specifically permit, without +fee, and encourage the use of this source code as a component to +supporting the PNG file format in commercial products. If you use this +source code in a product, acknowledgment is not required but would be +appreciated. + + +A "png_get_copyright" function is available, for convenient use in "about" +boxes and the like: + + printf("%s",png_get_copyright(NULL)); + +Also, the PNG logo (in PNG format, of course) is supplied in the +files "pngbar.png" and "pngbar.jpg (88x31) and "pngnow.png" (98x31). + +Libpng is OSI Certified Open Source Software. OSI Certified Open Source is a +certification mark of the Open Source Initiative. + +Glenn Randers-Pehrson +glennrp at users.sourceforge.net +November 14, 2013 diff --git a/ml/dlib/dlib/external/libpng/README b/ml/dlib/dlib/external/libpng/README new file mode 100644 index 000000000..80fc574ad --- /dev/null +++ b/ml/dlib/dlib/external/libpng/README @@ -0,0 +1,202 @@ +README for libpng version 1.6.7 - November 14, 2013 (shared library 16.0) +See the note about version numbers near the top of png.h + +See INSTALL for instructions on how to install libpng. + +Libpng comes in several distribution formats. Get libpng-*.tar.gz or +libpng-*.tar.xz or if you want UNIX-style line endings in the text files, +or lpng*.7z or lpng*.zip if you want DOS-style line endings. + +Version 0.89 was the first official release of libpng. Don't let the +fact that it's the first release fool you. The libpng library has been in +extensive use and testing since mid-1995. By late 1997 it had +finally gotten to the stage where there hadn't been significant +changes to the API in some time, and people have a bad feeling about +libraries with versions < 1.0. Version 1.0.0 was released in +March 1998. + +**** +Note that some of the changes to the png_info structure render this +version of the library binary incompatible with libpng-0.89 or +earlier versions if you are using a shared library. The type of the +"filler" parameter for png_set_filler() has changed from png_byte to +png_uint_32, which will affect shared-library applications that use +this function. + +To avoid problems with changes to the internals of png_info_struct, +new APIs have been made available in 0.95 to avoid direct application +access to info_ptr. These functions are the png_set_ and +png_get_ functions. These functions should be used when +accessing/storing the info_struct data, rather than manipulating it +directly, to avoid such problems in the future. + +It is important to note that the APIs do not make current programs +that access the info struct directly incompatible with the new +library. However, it is strongly suggested that new programs use +the new APIs (as shown in example.c and pngtest.c), and older programs +be converted to the new format, to facilitate upgrades in the future. +**** + +Additions since 0.90 include the ability to compile libpng as a +Windows DLL, and new APIs for accessing data in the info struct. +Experimental functions include the ability to set weighting and cost +factors for row filter selection, direct reads of integers from buffers +on big-endian processors that support misaligned data access, faster +methods of doing alpha composition, and more accurate 16->8 bit color +conversion. + +The additions since 0.89 include the ability to read from a PNG stream +which has had some (or all) of the signature bytes read by the calling +application. This also allows the reading of embedded PNG streams that +do not have the PNG file signature. As well, it is now possible to set +the library action on the detection of chunk CRC errors. It is possible +to set different actions based on whether the CRC error occurred in a +critical or an ancillary chunk. + +The changes made to the library, and bugs fixed are based on discussions +on the PNG-implement mailing list and not on material submitted +privately to Guy, Andreas, or Glenn. They will forward any good +suggestions to the list. + +For a detailed description on using libpng, read libpng-manual.txt. For +examples of libpng in a program, see example.c and pngtest.c. For usage +information and restrictions (what little they are) on libpng, see +png.h. For a description on using zlib (the compression library used by +libpng) and zlib's restrictions, see zlib.h + +I have included a general makefile, as well as several machine and +compiler specific ones, but you may have to modify one for your own needs. + +You should use zlib 1.0.4 or later to run this, but it MAY work with +versions as old as zlib 0.95. Even so, there are bugs in older zlib +versions which can cause the output of invalid compression streams for +some images. You will definitely need zlib 1.0.4 or later if you are +taking advantage of the MS-DOS "far" structure allocation for the small +and medium memory models. You should also note that zlib is a +compression library that is useful for more things than just PNG files. +You can use zlib as a drop-in replacement for fread() and fwrite() if +you are so inclined. + +zlib should be available at the same place that libpng is, or at zlib.net. + +You may also want a copy of the PNG specification. It is available +as an RFC, a W3C Recommendation, and an ISO/IEC Standard. You can find +these at http://www.libpng.org/pub/png/documents/ + +This code is currently being archived at libpng.sf.net in the +[DOWNLOAD] area, and at ftp://ftp.simplesystems.org. If you can't find it +in any of those places, e-mail me, and I'll help you find it. + +If you have any code changes, requests, problems, etc., please e-mail +them to me. Also, I'd appreciate any make files or project files, +and any modifications you needed to make to get libpng to compile, +along with a #define variable to tell what compiler/system you are on. +If you needed to add transformations to libpng, or wish libpng would +provide the image in a different way, drop me a note (and code, if +possible), so I can consider supporting the transformation. +Finally, if you get any warning messages when compiling libpng +(note: not zlib), and they are easy to fix, I'd appreciate the +fix. Please mention "libpng" somewhere in the subject line. Thanks. + +This release was created and will be supported by myself (of course +based in a large way on Guy's and Andreas' earlier work), and the PNG +development group. + +Send comments/corrections/commendations to png-mng-implement at +lists.sourceforge.net (subscription required; visit +https://lists.sourceforge.net/lists/listinfo/png-mng-implement +to subscribe) or to glennrp at users.sourceforge.net + +You can't reach Guy, the original libpng author, at the addresses +given in previous versions of this document. He and Andreas will +read mail addressed to the png-implement list, however. + +Please do not send general questions about PNG. Send them to +png-mng-misc at lists.sf.net (subscription required; visit +https://lists.sourceforge.net/lists/listinfo/png-mng-misc to +subscribe). If you have a question about something +in the PNG specification that is related to using libpng, send it +to me. Send me any questions that start with "I was using libpng, +and ...". If in doubt, send questions to me. I'll bounce them +to others, if necessary. + +Please do not send suggestions on how to change PNG. We have +been discussing PNG for eighteen years now, and it is official and +finished. If you have suggestions for libpng, however, I'll +gladly listen. Even if your suggestion is not used immediately, +it may be used later. + +Files in this distribution: + + ANNOUNCE => Announcement of this version, with recent changes + CHANGES => Description of changes between libpng versions + KNOWNBUG => List of known bugs and deficiencies + LICENSE => License to use and redistribute libpng + README => This file + TODO => Things not implemented in the current library + Y2KINFO => Statement of Y2K compliance + example.c => Example code for using libpng functions + libpng.3 => manual page for libpng (includes libpng-manual.txt) + libpng-manual.txt => Description of libpng and its functions + libpngpf.3 => manual page for libpng's private functions + png.5 => manual page for the PNG format + png.c => Basic interface functions common to library + png.h => Library function and interface declarations (public) + pngpriv.h => Library function and interface declarations (private) + pngconf.h => System specific library configuration (public) + pngstruct.h => png_struct declaration (private) + pnginfo.h => png_info struct declaration (private) + pngdebug.h => debugging macros (private) + pngerror.c => Error/warning message I/O functions + pngget.c => Functions for retrieving info from struct + pngmem.c => Memory handling functions + pngbar.png => PNG logo, 88x31 + pngnow.png => PNG logo, 98x31 + pngpread.c => Progressive reading functions + pngread.c => Read data/helper high-level functions + pngrio.c => Lowest-level data read I/O functions + pngrtran.c => Read data transformation functions + pngrutil.c => Read data utility functions + pngset.c => Functions for storing data into the info_struct + pngtest.c => Library test program + pngtest.png => Library test sample image + pngtrans.c => Common data transformation functions + pngwio.c => Lowest-level write I/O functions + pngwrite.c => High-level write functions + pngwtran.c => Write data transformations + pngwutil.c => Write utility functions + arm => Contains optimized code for the ARM platform + contrib => Contributions + examples => Example programs + gregbook => source code for PNG reading and writing, from + Greg Roelofs' "PNG: The Definitive Guide", + O'Reilly, 1999 + libtests => Test programs + pngminim => Minimal decoder, encoder, and progressive decoder + programs demonstrating use of pngusr.dfa + pngminus => Simple pnm2png and png2pnm programs + pngsuite => Test images + tools => Various tools + visupng => Contains a MSVC workspace for VisualPng + projects => Contains project files and workspaces for + building a DLL + owatcom => Contains a WATCOM project for building libpng + visualc71 => Contains a Microsoft Visual C++ (MSVC) + workspace for building libpng and zlib + vstudio => Contains a Microsoft Visual C++ (MSVC) + workspace for building libpng and zlib + scripts => Directory containing scripts for building libpng: + (see scripts/README.txt for the list of scripts) + +Good luck, and happy coding. + +-Glenn Randers-Pehrson (current maintainer, since 1998) + Internet: glennrp at users.sourceforge.net + +-Andreas Eric Dilger (former maintainer, 1996-1997) + Internet: adilger at enel.ucalgary.ca + Web: http://www-mddsp.enel.ucalgary.ca/People/adilger/ + +-Guy Eric Schalnat (original author and former maintainer, 1995-1996) + (formerly of Group 42, Inc) + Internet: gschal at infinet.com diff --git a/ml/dlib/dlib/external/libpng/arm/arm_init.c b/ml/dlib/dlib/external/libpng/arm/arm_init.c new file mode 100644 index 000000000..098771781 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/arm/arm_init.c @@ -0,0 +1,232 @@ + +/* arm_init.c - NEON optimised filter functions + * + * Copyright (c) 2013 Glenn Randers-Pehrson + * Written by Mans Rullgard, 2011. + * Last changed in libpng 1.6.6 [September 16, 2013] + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ +/* Below, after checking __linux__, various non-C90 POSIX 1003.1 functions are + * called. + */ +#define _POSIX_SOURCE 1 + +#include "../pngpriv.h" + +#ifdef PNG_READ_SUPPORTED +#if PNG_ARM_NEON_OPT > 0 +#ifdef PNG_ARM_NEON_CHECK_SUPPORTED /* Do run-time checks */ +#include /* for sig_atomic_t */ + +#ifdef __ANDROID__ +/* Linux provides access to information about CPU capabilites via + * /proc/self/auxv, however Android blocks this while still claiming to be + * Linux. The Andoid NDK, however, provides appropriate support. + * + * Documentation: http://www.kandroid.org/ndk/docs/CPU-ARM-NEON.html + */ +#include + +static int +png_have_neon(png_structp png_ptr) +{ + /* This is a whole lot easier than the mess below, however it is probably + * implemented as below, therefore it is better to cache the result (these + * function calls may be slow!) + */ + PNG_UNUSED(png_ptr) + return android_getCpuFamily() == ANDROID_CPU_FAMILY_ARM && + (android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_NEON) != 0; +} +#elif defined(__linux__) +/* The generic __linux__ implementation requires reading /proc/self/auxv and + * looking at each element for one that records NEON capabilities. + */ +#include /* for POSIX 1003.1 */ +#include /* for EINTR */ + +#include +#include +#include +#include +#include + +/* A read call may be interrupted, in which case it returns -1 and sets errno to + * EINTR if nothing was done, otherwise (if something was done) a partial read + * may result. + */ +static size_t +safe_read(png_structp png_ptr, int fd, void *buffer_in, size_t nbytes) +{ + size_t ntotal = 0; + char *buffer = png_voidcast(char*, buffer_in); + + while (nbytes > 0) + { + unsigned int nread; + int iread; + + /* Passing nread > INT_MAX to read is implementation defined in POSIX + * 1003.1, therefore despite the unsigned argument portable code must + * limit the value to INT_MAX! + */ + if (nbytes > INT_MAX) + nread = INT_MAX; + + else + nread = (unsigned int)/*SAFE*/nbytes; + + iread = read(fd, buffer, nread); + + if (iread == -1) + { + /* This is the devil in the details, a read can terminate early with 0 + * bytes read because of EINTR, yet it still returns -1 otherwise end + * of file cannot be distinguished. + */ + if (errno != EINTR) + { + png_warning(png_ptr, "/proc read failed"); + return 0; /* I.e. a permanent failure */ + } + } + + else if (iread < 0) + { + /* Not a valid 'read' result: */ + png_warning(png_ptr, "OS /proc read bug"); + return 0; + } + + else if (iread > 0) + { + /* Continue reading until a permanent failure, or EOF */ + buffer += iread; + nbytes -= (unsigned int)/*SAFE*/iread; + ntotal += (unsigned int)/*SAFE*/iread; + } + + else + return ntotal; + } + + return ntotal; /* nbytes == 0 */ +} + +static int +png_have_neon(png_structp png_ptr) +{ + int fd = open("/proc/self/auxv", O_RDONLY); + Elf32_auxv_t aux; + + /* Failsafe: failure to open means no NEON */ + if (fd == -1) + { + png_warning(png_ptr, "/proc/self/auxv open failed"); + return 0; + } + + while (safe_read(png_ptr, fd, &aux, sizeof aux) == sizeof aux) + { + if (aux.a_type == AT_HWCAP && (aux.a_un.a_val & HWCAP_NEON) != 0) + { + close(fd); + return 1; + } + } + + close(fd); + return 0; +} +#else + /* We don't know how to do a run-time check on this system */ +# error "no support for run-time ARM NEON checks" +#endif /* OS checks */ +#endif /* PNG_ARM_NEON_CHECK_SUPPORTED */ + +#ifndef PNG_ALIGNED_MEMORY_SUPPORTED +# error "ALIGNED_MEMORY is required; set: -DPNG_ALIGNED_MEMORY_SUPPORTED" +#endif + +void +png_init_filter_functions_neon(png_structp pp, unsigned int bpp) +{ + /* The switch statement is compiled in for ARM_NEON_API, the call to + * png_have_neon is compiled in for ARM_NEON_CHECK. If both are defined + * the check is only performed if the API has not set the NEON option on + * or off explicitly. In this case the check controls what happens. + * + * If the CHECK is not compiled in and the option is UNSET the behavior prior + * to 1.6.7 was to use the NEON code - this was a bug caused by having the + * wrong order of the 'ON' and 'default' cases. UNSET now defaults to OFF, + * as documented in png.h + */ +#ifdef PNG_ARM_NEON_API_SUPPORTED + switch ((pp->options >> PNG_ARM_NEON) & 3) + { + case PNG_OPTION_UNSET: + /* Allow the run-time check to execute if it has been enabled - + * thus both API and CHECK can be turned on. If it isn't supported + * this case will fall through to the 'default' below, which just + * returns. + */ +#endif /* PNG_ARM_NEON_API_SUPPORTED */ +#ifdef PNG_ARM_NEON_CHECK_SUPPORTED + { + static volatile sig_atomic_t no_neon = -1; /* not checked */ + + if (no_neon < 0) + no_neon = !png_have_neon(pp); + + if (no_neon) + return; + } +#ifdef PNG_ARM_NEON_API_SUPPORTED + break; +#endif +#endif /* PNG_ARM_NEON_CHECK_SUPPORTED */ + +#ifdef PNG_ARM_NEON_API_SUPPORTED + default: /* OFF or INVALID */ + return; + + case PNG_OPTION_ON: + /* Option turned on */ + break; + } +#endif + + /* IMPORTANT: any new external functions used here must be declared using + * PNG_INTERNAL_FUNCTION in ../pngpriv.h. This is required so that the + * 'prefix' option to configure works: + * + * ./configure --with-libpng-prefix=foobar_ + * + * Verify you have got this right by running the above command, doing a build + * and examining pngprefix.h; it must contain a #define for every external + * function you add. (Notice that this happens automatically for the + * initialization function.) + */ + pp->read_filter[PNG_FILTER_VALUE_UP-1] = png_read_filter_row_up_neon; + + if (bpp == 3) + { + pp->read_filter[PNG_FILTER_VALUE_SUB-1] = png_read_filter_row_sub3_neon; + pp->read_filter[PNG_FILTER_VALUE_AVG-1] = png_read_filter_row_avg3_neon; + pp->read_filter[PNG_FILTER_VALUE_PAETH-1] = + png_read_filter_row_paeth3_neon; + } + + else if (bpp == 4) + { + pp->read_filter[PNG_FILTER_VALUE_SUB-1] = png_read_filter_row_sub4_neon; + pp->read_filter[PNG_FILTER_VALUE_AVG-1] = png_read_filter_row_avg4_neon; + pp->read_filter[PNG_FILTER_VALUE_PAETH-1] = + png_read_filter_row_paeth4_neon; + } +} +#endif /* PNG_ARM_NEON_OPT > 0 */ +#endif /* PNG_READ_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/arm/filter_neon.S b/ml/dlib/dlib/external/libpng/arm/filter_neon.S new file mode 100644 index 000000000..3d1ccf505 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/arm/filter_neon.S @@ -0,0 +1,245 @@ + +/* filter_neon.S - NEON optimised filter functions + * + * Copyright (c) 2013 Glenn Randers-Pehrson + * Written by Mans Rullgard, 2011. + * Last changed in libpng 1.6.7 [November 14, 2013] + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +/* This is required to get the symbol renames, which are #defines, and also + * includes the definition (or not) of PNG_ARM_NEON_OPT. + */ +#define PNG_VERSION_INFO_ONLY +#include "../pngpriv.h" + +#if defined(__linux__) && defined(__ELF__) +.section .note.GNU-stack,"",%progbits /* mark stack as non-executable */ +#endif + +/* Assembler NEON support - only works for 32-bit ARM (i.e. it does not work for + * ARM64). The code in arm/filter_neon_intrinsics.c supports ARM64, however it + * only works if -mfpu=neon is specified on the GCC command line. See pngpriv.h + * for the logic which sets PNG_USE_ARM_NEON_ASM: + */ +#if PNG_ARM_NEON_IMPLEMENTATION == 2 /* hand-coded assembler */ + +#ifdef PNG_READ_SUPPORTED +#if PNG_ARM_NEON_OPT > 0 + +#ifdef __ELF__ +# define ELF +#else +# define ELF @ +#endif + + .arch armv7-a + .fpu neon + +.macro func name, export=0 + .macro endfunc +ELF .size \name, . - \name + .endfunc + .purgem endfunc + .endm + .text + .if \export + .global \name + .endif +ELF .type \name, STT_FUNC + .func \name +\name: +.endm + +func png_read_filter_row_sub4_neon, export=1 + ldr r3, [r0, #4] @ rowbytes + vmov.i8 d3, #0 +1: + vld4.32 {d4[],d5[],d6[],d7[]}, [r1,:128] + vadd.u8 d0, d3, d4 + vadd.u8 d1, d0, d5 + vadd.u8 d2, d1, d6 + vadd.u8 d3, d2, d7 + vst4.32 {d0[0],d1[0],d2[0],d3[0]},[r1,:128]! + subs r3, r3, #16 + bgt 1b + + bx lr +endfunc + +func png_read_filter_row_sub3_neon, export=1 + ldr r3, [r0, #4] @ rowbytes + vmov.i8 d3, #0 + mov r0, r1 + mov r2, #3 + mov r12, #12 + vld1.8 {q11}, [r0], r12 +1: + vext.8 d5, d22, d23, #3 + vadd.u8 d0, d3, d22 + vext.8 d6, d22, d23, #6 + vadd.u8 d1, d0, d5 + vext.8 d7, d23, d23, #1 + vld1.8 {q11}, [r0], r12 + vst1.32 {d0[0]}, [r1,:32], r2 + vadd.u8 d2, d1, d6 + vst1.32 {d1[0]}, [r1], r2 + vadd.u8 d3, d2, d7 + vst1.32 {d2[0]}, [r1], r2 + vst1.32 {d3[0]}, [r1], r2 + subs r3, r3, #12 + bgt 1b + + bx lr +endfunc + +func png_read_filter_row_up_neon, export=1 + ldr r3, [r0, #4] @ rowbytes +1: + vld1.8 {q0}, [r1,:128] + vld1.8 {q1}, [r2,:128]! + vadd.u8 q0, q0, q1 + vst1.8 {q0}, [r1,:128]! + subs r3, r3, #16 + bgt 1b + + bx lr +endfunc + +func png_read_filter_row_avg4_neon, export=1 + ldr r12, [r0, #4] @ rowbytes + vmov.i8 d3, #0 +1: + vld4.32 {d4[],d5[],d6[],d7[]}, [r1,:128] + vld4.32 {d16[],d17[],d18[],d19[]},[r2,:128]! + vhadd.u8 d0, d3, d16 + vadd.u8 d0, d0, d4 + vhadd.u8 d1, d0, d17 + vadd.u8 d1, d1, d5 + vhadd.u8 d2, d1, d18 + vadd.u8 d2, d2, d6 + vhadd.u8 d3, d2, d19 + vadd.u8 d3, d3, d7 + vst4.32 {d0[0],d1[0],d2[0],d3[0]},[r1,:128]! + subs r12, r12, #16 + bgt 1b + + bx lr +endfunc + +func png_read_filter_row_avg3_neon, export=1 + push {r4,lr} + ldr r12, [r0, #4] @ rowbytes + vmov.i8 d3, #0 + mov r0, r1 + mov r4, #3 + mov lr, #12 + vld1.8 {q11}, [r0], lr +1: + vld1.8 {q10}, [r2], lr + vext.8 d5, d22, d23, #3 + vhadd.u8 d0, d3, d20 + vext.8 d17, d20, d21, #3 + vadd.u8 d0, d0, d22 + vext.8 d6, d22, d23, #6 + vhadd.u8 d1, d0, d17 + vext.8 d18, d20, d21, #6 + vadd.u8 d1, d1, d5 + vext.8 d7, d23, d23, #1 + vld1.8 {q11}, [r0], lr + vst1.32 {d0[0]}, [r1,:32], r4 + vhadd.u8 d2, d1, d18 + vst1.32 {d1[0]}, [r1], r4 + vext.8 d19, d21, d21, #1 + vadd.u8 d2, d2, d6 + vhadd.u8 d3, d2, d19 + vst1.32 {d2[0]}, [r1], r4 + vadd.u8 d3, d3, d7 + vst1.32 {d3[0]}, [r1], r4 + subs r12, r12, #12 + bgt 1b + + pop {r4,pc} +endfunc + +.macro paeth rx, ra, rb, rc + vaddl.u8 q12, \ra, \rb @ a + b + vaddl.u8 q15, \rc, \rc @ 2*c + vabdl.u8 q13, \rb, \rc @ pa + vabdl.u8 q14, \ra, \rc @ pb + vabd.u16 q15, q12, q15 @ pc + vcle.u16 q12, q13, q14 @ pa <= pb + vcle.u16 q13, q13, q15 @ pa <= pc + vcle.u16 q14, q14, q15 @ pb <= pc + vand q12, q12, q13 @ pa <= pb && pa <= pc + vmovn.u16 d28, q14 + vmovn.u16 \rx, q12 + vbsl d28, \rb, \rc + vbsl \rx, \ra, d28 +.endm + +func png_read_filter_row_paeth4_neon, export=1 + ldr r12, [r0, #4] @ rowbytes + vmov.i8 d3, #0 + vmov.i8 d20, #0 +1: + vld4.32 {d4[],d5[],d6[],d7[]}, [r1,:128] + vld4.32 {d16[],d17[],d18[],d19[]},[r2,:128]! + paeth d0, d3, d16, d20 + vadd.u8 d0, d0, d4 + paeth d1, d0, d17, d16 + vadd.u8 d1, d1, d5 + paeth d2, d1, d18, d17 + vadd.u8 d2, d2, d6 + paeth d3, d2, d19, d18 + vmov d20, d19 + vadd.u8 d3, d3, d7 + vst4.32 {d0[0],d1[0],d2[0],d3[0]},[r1,:128]! + subs r12, r12, #16 + bgt 1b + + bx lr +endfunc + +func png_read_filter_row_paeth3_neon, export=1 + push {r4,lr} + ldr r12, [r0, #4] @ rowbytes + vmov.i8 d3, #0 + vmov.i8 d4, #0 + mov r0, r1 + mov r4, #3 + mov lr, #12 + vld1.8 {q11}, [r0], lr +1: + vld1.8 {q10}, [r2], lr + paeth d0, d3, d20, d4 + vext.8 d5, d22, d23, #3 + vadd.u8 d0, d0, d22 + vext.8 d17, d20, d21, #3 + paeth d1, d0, d17, d20 + vst1.32 {d0[0]}, [r1,:32], r4 + vext.8 d6, d22, d23, #6 + vadd.u8 d1, d1, d5 + vext.8 d18, d20, d21, #6 + paeth d2, d1, d18, d17 + vext.8 d7, d23, d23, #1 + vld1.8 {q11}, [r0], lr + vst1.32 {d1[0]}, [r1], r4 + vadd.u8 d2, d2, d6 + vext.8 d19, d21, d21, #1 + paeth d3, d2, d19, d18 + vst1.32 {d2[0]}, [r1], r4 + vmov d4, d19 + vadd.u8 d3, d3, d7 + vst1.32 {d3[0]}, [r1], r4 + subs r12, r12, #12 + bgt 1b + + pop {r4,pc} +endfunc +#endif /* PNG_ARM_NEON_OPT > 0 */ +#endif /* PNG_READ_SUPPORTED */ +#endif /* PNG_ARM_NEON_IMPLEMENTATION == 2 (assembler) */ diff --git a/ml/dlib/dlib/external/libpng/arm/filter_neon_intrinsics.c b/ml/dlib/dlib/external/libpng/arm/filter_neon_intrinsics.c new file mode 100644 index 000000000..e6a0217ab --- /dev/null +++ b/ml/dlib/dlib/external/libpng/arm/filter_neon_intrinsics.c @@ -0,0 +1,372 @@ + +/* filter_neon_intrinsics.c - NEON optimised filter functions + * + * Copyright (c) 2013 Glenn Randers-Pehrson + * Written by James Yu , October 2013. + * Based on filter_neon.S, written by Mans Rullgard, 2011. + * + * Last changed in libpng 1.6.7 [November 14, 2013] + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +#include "../pngpriv.h" + +/* This code requires -mfpu=neon on the command line: */ +#if PNG_ARM_NEON_IMPLEMENTATION == 1 /* intrinsics code */ + +#include + +/* libpng row pointers are not necessarily aligned to any particular boundary, + * however this code will only work with appropriate alignment. arm/arm_init.c + * checks for this (and will not compile unless it is done), this code uses + * variants of png_aligncast to avoid compiler warnings. + */ +#define png_ptr(type,pointer) png_aligncast(type *,pointer) +#define png_ptrc(type,pointer) png_aligncastconst(const type *,pointer) + +/* The following relies on a variable 'temp_pointer' being declared with type + * 'type'. This is written this way just to hide the GCC strict aliasing + * warning; note that the code is safe because there never is an alias between + * the input and output pointers. + */ +#define png_ldr(type,pointer)\ + (temp_pointer = png_ptr(type,pointer), *temp_pointer) + +#ifdef PNG_READ_SUPPORTED +#if PNG_ARM_NEON_OPT > 0 + +void +png_read_filter_row_up_neon(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_bytep rp = row; + png_bytep rp_stop = row + row_info->rowbytes; + png_const_bytep pp = prev_row; + + for (; rp < rp_stop; rp += 16, pp += 16) + { + uint8x16_t qrp, qpp; + + qrp = vld1q_u8(rp); + qpp = vld1q_u8(pp); + qrp = vaddq_u8(qrp, qpp); + vst1q_u8(rp, qrp); + } +} + +void +png_read_filter_row_sub3_neon(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_bytep rp = row; + png_bytep rp_stop = row + row_info->rowbytes; + + uint8x16_t vtmp = vld1q_u8(rp); + uint8x8x2_t *vrpt = png_ptr(uint8x8x2_t, &vtmp); + uint8x8x2_t vrp = *vrpt; + + uint8x8x4_t vdest; + vdest.val[3] = vdup_n_u8(0); + + for (; rp < rp_stop;) + { + uint8x8_t vtmp1, vtmp2; + uint32x2_t *temp_pointer; + + vtmp1 = vext_u8(vrp.val[0], vrp.val[1], 3); + vdest.val[0] = vadd_u8(vdest.val[3], vrp.val[0]); + vtmp2 = vext_u8(vrp.val[0], vrp.val[1], 6); + vdest.val[1] = vadd_u8(vdest.val[0], vtmp1); + + vtmp1 = vext_u8(vrp.val[1], vrp.val[1], 1); + vdest.val[2] = vadd_u8(vdest.val[1], vtmp2); + vdest.val[3] = vadd_u8(vdest.val[2], vtmp1); + + vtmp = vld1q_u8(rp + 12); + vrpt = png_ptr(uint8x8x2_t, &vtmp); + vrp = *vrpt; + + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[0]), 0); + rp += 3; + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[1]), 0); + rp += 3; + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[2]), 0); + rp += 3; + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[3]), 0); + rp += 3; + } + + PNG_UNUSED(prev_row) +} + +void +png_read_filter_row_sub4_neon(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_bytep rp = row; + png_bytep rp_stop = row + row_info->rowbytes; + + uint8x8x4_t vdest; + vdest.val[3] = vdup_n_u8(0); + + for (; rp < rp_stop; rp += 16) + { + uint32x2x4_t vtmp = vld4_u32(png_ptr(uint32_t,rp)); + uint8x8x4_t *vrpt = png_ptr(uint8x8x4_t,&vtmp); + uint8x8x4_t vrp = *vrpt; + uint32x2x4_t *temp_pointer; + + vdest.val[0] = vadd_u8(vdest.val[3], vrp.val[0]); + vdest.val[1] = vadd_u8(vdest.val[0], vrp.val[1]); + vdest.val[2] = vadd_u8(vdest.val[1], vrp.val[2]); + vdest.val[3] = vadd_u8(vdest.val[2], vrp.val[3]); + vst4_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2x4_t,&vdest), 0); + } + + PNG_UNUSED(prev_row) +} + +void +png_read_filter_row_avg3_neon(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_bytep rp = row; + png_const_bytep pp = prev_row; + png_bytep rp_stop = row + row_info->rowbytes; + + uint8x16_t vtmp; + uint8x8x2_t *vrpt; + uint8x8x2_t vrp; + uint8x8x4_t vdest; + vdest.val[3] = vdup_n_u8(0); + + vtmp = vld1q_u8(rp); + vrpt = png_ptr(uint8x8x2_t,&vtmp); + vrp = *vrpt; + + for (; rp < rp_stop; pp += 12) + { + uint8x8_t vtmp1, vtmp2, vtmp3; + + uint8x8x2_t *vppt; + uint8x8x2_t vpp; + + uint32x2_t *temp_pointer; + + vtmp = vld1q_u8(pp); + vppt = png_ptr(uint8x8x2_t,&vtmp); + vpp = *vppt; + + vtmp1 = vext_u8(vrp.val[0], vrp.val[1], 3); + vdest.val[0] = vhadd_u8(vdest.val[3], vpp.val[0]); + vdest.val[0] = vadd_u8(vdest.val[0], vrp.val[0]); + + vtmp2 = vext_u8(vpp.val[0], vpp.val[1], 3); + vtmp3 = vext_u8(vrp.val[0], vrp.val[1], 6); + vdest.val[1] = vhadd_u8(vdest.val[0], vtmp2); + vdest.val[1] = vadd_u8(vdest.val[1], vtmp1); + + vtmp2 = vext_u8(vpp.val[0], vpp.val[1], 6); + vtmp1 = vext_u8(vrp.val[1], vrp.val[1], 1); + + vtmp = vld1q_u8(rp + 12); + vrpt = png_ptr(uint8x8x2_t,&vtmp); + vrp = *vrpt; + + vdest.val[2] = vhadd_u8(vdest.val[1], vtmp2); + vdest.val[2] = vadd_u8(vdest.val[2], vtmp3); + + vtmp2 = vext_u8(vpp.val[1], vpp.val[1], 1); + + vdest.val[3] = vhadd_u8(vdest.val[2], vtmp2); + vdest.val[3] = vadd_u8(vdest.val[3], vtmp1); + + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[0]), 0); + rp += 3; + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[1]), 0); + rp += 3; + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[2]), 0); + rp += 3; + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[3]), 0); + rp += 3; + } +} + +void +png_read_filter_row_avg4_neon(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_bytep rp = row; + png_bytep rp_stop = row + row_info->rowbytes; + png_const_bytep pp = prev_row; + + uint8x8x4_t vdest; + vdest.val[3] = vdup_n_u8(0); + + for (; rp < rp_stop; rp += 16, pp += 16) + { + uint32x2x4_t vtmp; + uint8x8x4_t *vrpt, *vppt; + uint8x8x4_t vrp, vpp; + uint32x2x4_t *temp_pointer; + + vtmp = vld4_u32(png_ptr(uint32_t,rp)); + vrpt = png_ptr(uint8x8x4_t,&vtmp); + vrp = *vrpt; + vtmp = vld4_u32(png_ptrc(uint32_t,pp)); + vppt = png_ptr(uint8x8x4_t,&vtmp); + vpp = *vppt; + + vdest.val[0] = vhadd_u8(vdest.val[3], vpp.val[0]); + vdest.val[0] = vadd_u8(vdest.val[0], vrp.val[0]); + vdest.val[1] = vhadd_u8(vdest.val[0], vpp.val[1]); + vdest.val[1] = vadd_u8(vdest.val[1], vrp.val[1]); + vdest.val[2] = vhadd_u8(vdest.val[1], vpp.val[2]); + vdest.val[2] = vadd_u8(vdest.val[2], vrp.val[2]); + vdest.val[3] = vhadd_u8(vdest.val[2], vpp.val[3]); + vdest.val[3] = vadd_u8(vdest.val[3], vrp.val[3]); + + vst4_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2x4_t,&vdest), 0); + } +} + +static uint8x8_t +paeth(uint8x8_t a, uint8x8_t b, uint8x8_t c) +{ + uint8x8_t d, e; + uint16x8_t p1, pa, pb, pc; + + p1 = vaddl_u8(a, b); /* a + b */ + pc = vaddl_u8(c, c); /* c * 2 */ + pa = vabdl_u8(b, c); /* pa */ + pb = vabdl_u8(a, c); /* pb */ + pc = vabdq_u16(p1, pc); /* pc */ + + p1 = vcleq_u16(pa, pb); /* pa <= pb */ + pa = vcleq_u16(pa, pc); /* pa <= pc */ + pb = vcleq_u16(pb, pc); /* pb <= pc */ + + p1 = vandq_u16(p1, pa); /* pa <= pb && pa <= pc */ + + d = vmovn_u16(pb); + e = vmovn_u16(p1); + + d = vbsl_u8(d, b, c); + e = vbsl_u8(e, a, d); + + return e; +} + +void +png_read_filter_row_paeth3_neon(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_bytep rp = row; + png_const_bytep pp = prev_row; + png_bytep rp_stop = row + row_info->rowbytes; + + uint8x16_t vtmp; + uint8x8x2_t *vrpt; + uint8x8x2_t vrp; + uint8x8_t vlast = vdup_n_u8(0); + uint8x8x4_t vdest; + vdest.val[3] = vdup_n_u8(0); + + vtmp = vld1q_u8(rp); + vrpt = png_ptr(uint8x8x2_t,&vtmp); + vrp = *vrpt; + + for (; rp < rp_stop; pp += 12) + { + uint8x8x2_t *vppt; + uint8x8x2_t vpp; + uint8x8_t vtmp1, vtmp2, vtmp3; + uint32x2_t *temp_pointer; + + vtmp = vld1q_u8(pp); + vppt = png_ptr(uint8x8x2_t,&vtmp); + vpp = *vppt; + + vdest.val[0] = paeth(vdest.val[3], vpp.val[0], vlast); + vdest.val[0] = vadd_u8(vdest.val[0], vrp.val[0]); + + vtmp1 = vext_u8(vrp.val[0], vrp.val[1], 3); + vtmp2 = vext_u8(vpp.val[0], vpp.val[1], 3); + vdest.val[1] = paeth(vdest.val[0], vtmp2, vpp.val[0]); + vdest.val[1] = vadd_u8(vdest.val[1], vtmp1); + + vtmp1 = vext_u8(vrp.val[0], vrp.val[1], 6); + vtmp3 = vext_u8(vpp.val[0], vpp.val[1], 6); + vdest.val[2] = paeth(vdest.val[1], vtmp3, vtmp2); + vdest.val[2] = vadd_u8(vdest.val[2], vtmp1); + + vtmp1 = vext_u8(vrp.val[1], vrp.val[1], 1); + vtmp2 = vext_u8(vpp.val[1], vpp.val[1], 1); + + vtmp = vld1q_u8(rp + 12); + vrpt = png_ptr(uint8x8x2_t,&vtmp); + vrp = *vrpt; + + vdest.val[3] = paeth(vdest.val[2], vtmp2, vtmp3); + vdest.val[3] = vadd_u8(vdest.val[3], vtmp1); + + vlast = vtmp2; + + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[0]), 0); + rp += 3; + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[1]), 0); + rp += 3; + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[2]), 0); + rp += 3; + vst1_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2_t,&vdest.val[3]), 0); + rp += 3; + } +} + +void +png_read_filter_row_paeth4_neon(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_bytep rp = row; + png_bytep rp_stop = row + row_info->rowbytes; + png_const_bytep pp = prev_row; + + uint8x8_t vlast = vdup_n_u8(0); + uint8x8x4_t vdest; + vdest.val[3] = vdup_n_u8(0); + + for (; rp < rp_stop; rp += 16, pp += 16) + { + uint32x2x4_t vtmp; + uint8x8x4_t *vrpt, *vppt; + uint8x8x4_t vrp, vpp; + uint32x2x4_t *temp_pointer; + + vtmp = vld4_u32(png_ptr(uint32_t,rp)); + vrpt = png_ptr(uint8x8x4_t,&vtmp); + vrp = *vrpt; + vtmp = vld4_u32(png_ptrc(uint32_t,pp)); + vppt = png_ptr(uint8x8x4_t,&vtmp); + vpp = *vppt; + + vdest.val[0] = paeth(vdest.val[3], vpp.val[0], vlast); + vdest.val[0] = vadd_u8(vdest.val[0], vrp.val[0]); + vdest.val[1] = paeth(vdest.val[0], vpp.val[1], vpp.val[0]); + vdest.val[1] = vadd_u8(vdest.val[1], vrp.val[1]); + vdest.val[2] = paeth(vdest.val[1], vpp.val[2], vpp.val[1]); + vdest.val[2] = vadd_u8(vdest.val[2], vrp.val[2]); + vdest.val[3] = paeth(vdest.val[2], vpp.val[3], vpp.val[2]); + vdest.val[3] = vadd_u8(vdest.val[3], vrp.val[3]); + + vlast = vpp.val[3]; + + vst4_lane_u32(png_ptr(uint32_t,rp), png_ldr(uint32x2x4_t,&vdest), 0); + } +} + +#endif /* PNG_ARM_NEON_OPT > 0 */ +#endif /* PNG_READ_SUPPORTED */ +#endif /* PNG_ARM_NEON_IMPLEMENTATION == 1 (intrinsics) */ diff --git a/ml/dlib/dlib/external/libpng/png.c b/ml/dlib/dlib/external/libpng/png.c new file mode 100644 index 000000000..efcc6eead --- /dev/null +++ b/ml/dlib/dlib/external/libpng/png.c @@ -0,0 +1,4299 @@ + +/* png.c - location for general purpose libpng functions + * + * Last changed in libpng 1.6.2 [April 25, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +#include "pngpriv.h" + +/* Generate a compiler error if there is an old png.h in the search path. */ +typedef png_libpng_version_1_6_7 Your_png_h_is_not_version_1_6_7; + +/* Tells libpng that we have already handled the first "num_bytes" bytes + * of the PNG file signature. If the PNG data is embedded into another + * stream we can set num_bytes = 8 so that libpng will not attempt to read + * or write any of the magic bytes before it starts on the IHDR. + */ + +#ifdef PNG_READ_SUPPORTED +void PNGAPI +png_set_sig_bytes(png_structrp png_ptr, int num_bytes) +{ + png_debug(1, "in png_set_sig_bytes"); + + if (png_ptr == NULL) + return; + + if (num_bytes > 8) + png_error(png_ptr, "Too many bytes for PNG signature"); + + png_ptr->sig_bytes = (png_byte)(num_bytes < 0 ? 0 : num_bytes); +} + +/* Checks whether the supplied bytes match the PNG signature. We allow + * checking less than the full 8-byte signature so that those apps that + * already read the first few bytes of a file to determine the file type + * can simply check the remaining bytes for extra assurance. Returns + * an integer less than, equal to, or greater than zero if sig is found, + * respectively, to be less than, to match, or be greater than the correct + * PNG signature (this is the same behavior as strcmp, memcmp, etc). + */ +int PNGAPI +png_sig_cmp(png_const_bytep sig, png_size_t start, png_size_t num_to_check) +{ + png_byte png_signature[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + + if (num_to_check > 8) + num_to_check = 8; + + else if (num_to_check < 1) + return (-1); + + if (start > 7) + return (-1); + + if (start + num_to_check > 8) + num_to_check = 8 - start; + + return ((int)(memcmp(&sig[start], &png_signature[start], num_to_check))); +} + +#endif /* PNG_READ_SUPPORTED */ + +#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) +/* Function to allocate memory for zlib */ +PNG_FUNCTION(voidpf /* PRIVATE */, +png_zalloc,(voidpf png_ptr, uInt items, uInt size),PNG_ALLOCATED) +{ + png_alloc_size_t num_bytes = size; + + if (png_ptr == NULL) + return NULL; + + if (items >= (~(png_alloc_size_t)0)/size) + { + png_warning (png_voidcast(png_structrp, png_ptr), + "Potential overflow in png_zalloc()"); + return NULL; + } + + num_bytes *= items; + return png_malloc_warn(png_voidcast(png_structrp, png_ptr), num_bytes); +} + +/* Function to free memory for zlib */ +void /* PRIVATE */ +png_zfree(voidpf png_ptr, voidpf ptr) +{ + png_free(png_voidcast(png_const_structrp,png_ptr), ptr); +} + +/* Reset the CRC variable to 32 bits of 1's. Care must be taken + * in case CRC is > 32 bits to leave the top bits 0. + */ +void /* PRIVATE */ +png_reset_crc(png_structrp png_ptr) +{ + /* The cast is safe because the crc is a 32 bit value. */ + png_ptr->crc = (png_uint_32)crc32(0, Z_NULL, 0); +} + +/* Calculate the CRC over a section of data. We can only pass as + * much data to this routine as the largest single buffer size. We + * also check that this data will actually be used before going to the + * trouble of calculating it. + */ +void /* PRIVATE */ +png_calculate_crc(png_structrp png_ptr, png_const_bytep ptr, png_size_t length) +{ + int need_crc = 1; + + if (PNG_CHUNK_ANCILLARY(png_ptr->chunk_name)) + { + if ((png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_MASK) == + (PNG_FLAG_CRC_ANCILLARY_USE | PNG_FLAG_CRC_ANCILLARY_NOWARN)) + need_crc = 0; + } + + else /* critical */ + { + if (png_ptr->flags & PNG_FLAG_CRC_CRITICAL_IGNORE) + need_crc = 0; + } + + /* 'uLong' is defined in zlib.h as unsigned long; this means that on some + * systems it is a 64 bit value. crc32, however, returns 32 bits so the + * following cast is safe. 'uInt' may be no more than 16 bits, so it is + * necessary to perform a loop here. + */ + if (need_crc && length > 0) + { + uLong crc = png_ptr->crc; /* Should never issue a warning */ + + do + { + uInt safe_length = (uInt)length; + if (safe_length == 0) + safe_length = (uInt)-1; /* evil, but safe */ + + crc = crc32(crc, ptr, safe_length); + + /* The following should never issue compiler warnings; if they do the + * target system has characteristics that will probably violate other + * assumptions within the libpng code. + */ + ptr += safe_length; + length -= safe_length; + } + while (length > 0); + + /* And the following is always safe because the crc is only 32 bits. */ + png_ptr->crc = (png_uint_32)crc; + } +} + +/* Check a user supplied version number, called from both read and write + * functions that create a png_struct. + */ +int +png_user_version_check(png_structrp png_ptr, png_const_charp user_png_ver) +{ + if (user_png_ver) + { + int i = 0; + + do + { + if (user_png_ver[i] != png_libpng_ver[i]) + png_ptr->flags |= PNG_FLAG_LIBRARY_MISMATCH; + } while (png_libpng_ver[i++]); + } + + else + png_ptr->flags |= PNG_FLAG_LIBRARY_MISMATCH; + + if (png_ptr->flags & PNG_FLAG_LIBRARY_MISMATCH) + { + /* Libpng 0.90 and later are binary incompatible with libpng 0.89, so + * we must recompile any applications that use any older library version. + * For versions after libpng 1.0, we will be compatible, so we need + * only check the first and third digits (note that when we reach version + * 1.10 we will need to check the fourth symbol, namely user_png_ver[3]). + */ + if (user_png_ver == NULL || user_png_ver[0] != png_libpng_ver[0] || + (user_png_ver[0] == '1' && (user_png_ver[2] != png_libpng_ver[2] || + user_png_ver[3] != png_libpng_ver[3])) || + (user_png_ver[0] == '0' && user_png_ver[2] < '9')) + { +#ifdef PNG_WARNINGS_SUPPORTED + size_t pos = 0; + char m[128]; + + pos = png_safecat(m, (sizeof m), pos, + "Application built with libpng-"); + pos = png_safecat(m, (sizeof m), pos, user_png_ver); + pos = png_safecat(m, (sizeof m), pos, " but running with "); + pos = png_safecat(m, (sizeof m), pos, png_libpng_ver); + + png_warning(png_ptr, m); +#endif + +#ifdef PNG_ERROR_NUMBERS_SUPPORTED + png_ptr->flags = 0; +#endif + + return 0; + } + } + + /* Success return. */ + return 1; +} + +/* Generic function to create a png_struct for either read or write - this + * contains the common initialization. + */ +PNG_FUNCTION(png_structp /* PRIVATE */, +png_create_png_struct,(png_const_charp user_png_ver, png_voidp error_ptr, + png_error_ptr error_fn, png_error_ptr warn_fn, png_voidp mem_ptr, + png_malloc_ptr malloc_fn, png_free_ptr free_fn),PNG_ALLOCATED) +{ + png_struct create_struct; +# ifdef PNG_SETJMP_SUPPORTED + jmp_buf create_jmp_buf; +# endif + + /* This temporary stack-allocated structure is used to provide a place to + * build enough context to allow the user provided memory allocator (if any) + * to be called. + */ + memset(&create_struct, 0, (sizeof create_struct)); + + /* Added at libpng-1.2.6 */ +# ifdef PNG_USER_LIMITS_SUPPORTED + create_struct.user_width_max = PNG_USER_WIDTH_MAX; + create_struct.user_height_max = PNG_USER_HEIGHT_MAX; + +# ifdef PNG_USER_CHUNK_CACHE_MAX + /* Added at libpng-1.2.43 and 1.4.0 */ + create_struct.user_chunk_cache_max = PNG_USER_CHUNK_CACHE_MAX; +# endif + +# ifdef PNG_USER_CHUNK_MALLOC_MAX + /* Added at libpng-1.2.43 and 1.4.1, required only for read but exists + * in png_struct regardless. + */ + create_struct.user_chunk_malloc_max = PNG_USER_CHUNK_MALLOC_MAX; +# endif +# endif + + /* The following two API calls simply set fields in png_struct, so it is safe + * to do them now even though error handling is not yet set up. + */ +# ifdef PNG_USER_MEM_SUPPORTED + png_set_mem_fn(&create_struct, mem_ptr, malloc_fn, free_fn); +# endif + + /* (*error_fn) can return control to the caller after the error_ptr is set, + * this will result in a memory leak unless the error_fn does something + * extremely sophisticated. The design lacks merit but is implicit in the + * API. + */ + png_set_error_fn(&create_struct, error_ptr, error_fn, warn_fn); + +# ifdef PNG_SETJMP_SUPPORTED + if (!setjmp(create_jmp_buf)) + { + /* Temporarily fake out the longjmp information until we have + * successfully completed this function. This only works if we have + * setjmp() support compiled in, but it is safe - this stuff should + * never happen. + */ + create_struct.jmp_buf_ptr = &create_jmp_buf; + create_struct.jmp_buf_size = 0; /*stack allocation*/ + create_struct.longjmp_fn = longjmp; +# else + { +# endif + /* Call the general version checker (shared with read and write code): + */ + if (png_user_version_check(&create_struct, user_png_ver)) + { + png_structrp png_ptr = png_voidcast(png_structrp, + png_malloc_warn(&create_struct, (sizeof *png_ptr))); + + if (png_ptr != NULL) + { + /* png_ptr->zstream holds a back-pointer to the png_struct, so + * this can only be done now: + */ + create_struct.zstream.zalloc = png_zalloc; + create_struct.zstream.zfree = png_zfree; + create_struct.zstream.opaque = png_ptr; + +# ifdef PNG_SETJMP_SUPPORTED + /* Eliminate the local error handling: */ + create_struct.jmp_buf_ptr = NULL; + create_struct.jmp_buf_size = 0; + create_struct.longjmp_fn = 0; +# endif + + *png_ptr = create_struct; + + /* This is the successful return point */ + return png_ptr; + } + } + } + + /* A longjmp because of a bug in the application storage allocator or a + * simple failure to allocate the png_struct. + */ + return NULL; +} + +/* Allocate the memory for an info_struct for the application. */ +PNG_FUNCTION(png_infop,PNGAPI +png_create_info_struct,(png_const_structrp png_ptr),PNG_ALLOCATED) +{ + png_inforp info_ptr; + + png_debug(1, "in png_create_info_struct"); + + if (png_ptr == NULL) + return NULL; + + /* Use the internal API that does not (or at least should not) error out, so + * that this call always returns ok. The application typically sets up the + * error handling *after* creating the info_struct because this is the way it + * has always been done in 'example.c'. + */ + info_ptr = png_voidcast(png_inforp, png_malloc_base(png_ptr, + (sizeof *info_ptr))); + + if (info_ptr != NULL) + memset(info_ptr, 0, (sizeof *info_ptr)); + + return info_ptr; +} + +/* This function frees the memory associated with a single info struct. + * Normally, one would use either png_destroy_read_struct() or + * png_destroy_write_struct() to free an info struct, but this may be + * useful for some applications. From libpng 1.6.0 this function is also used + * internally to implement the png_info release part of the 'struct' destroy + * APIs. This ensures that all possible approaches free the same data (all of + * it). + */ +void PNGAPI +png_destroy_info_struct(png_const_structrp png_ptr, png_infopp info_ptr_ptr) +{ + png_inforp info_ptr = NULL; + + png_debug(1, "in png_destroy_info_struct"); + + if (png_ptr == NULL) + return; + + if (info_ptr_ptr != NULL) + info_ptr = *info_ptr_ptr; + + if (info_ptr != NULL) + { + /* Do this first in case of an error below; if the app implements its own + * memory management this can lead to png_free calling png_error, which + * will abort this routine and return control to the app error handler. + * An infinite loop may result if it then tries to free the same info + * ptr. + */ + *info_ptr_ptr = NULL; + + png_free_data(png_ptr, info_ptr, PNG_FREE_ALL, -1); + memset(info_ptr, 0, (sizeof *info_ptr)); + png_free(png_ptr, info_ptr); + } +} + +/* Initialize the info structure. This is now an internal function (0.89) + * and applications using it are urged to use png_create_info_struct() + * instead. Use deprecated in 1.6.0, internal use removed (used internally it + * is just a memset). + * + * NOTE: it is almost inconceivable that this API is used because it bypasses + * the user-memory mechanism and the user error handling/warning mechanisms in + * those cases where it does anything other than a memset. + */ +PNG_FUNCTION(void,PNGAPI +png_info_init_3,(png_infopp ptr_ptr, png_size_t png_info_struct_size), + PNG_DEPRECATED) +{ + png_inforp info_ptr = *ptr_ptr; + + png_debug(1, "in png_info_init_3"); + + if (info_ptr == NULL) + return; + + if ((sizeof (png_info)) > png_info_struct_size) + { + *ptr_ptr = NULL; + /* The following line is why this API should not be used: */ + free(info_ptr); + info_ptr = png_voidcast(png_inforp, png_malloc_base(NULL, + (sizeof *info_ptr))); + *ptr_ptr = info_ptr; + } + + /* Set everything to 0 */ + memset(info_ptr, 0, (sizeof *info_ptr)); +} + +/* The following API is not called internally */ +void PNGAPI +png_data_freer(png_const_structrp png_ptr, png_inforp info_ptr, + int freer, png_uint_32 mask) +{ + png_debug(1, "in png_data_freer"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + if (freer == PNG_DESTROY_WILL_FREE_DATA) + info_ptr->free_me |= mask; + + else if (freer == PNG_USER_WILL_FREE_DATA) + info_ptr->free_me &= ~mask; + + else + png_error(png_ptr, "Unknown freer parameter in png_data_freer"); +} + +void PNGAPI +png_free_data(png_const_structrp png_ptr, png_inforp info_ptr, png_uint_32 mask, + int num) +{ + png_debug(1, "in png_free_data"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + +#ifdef PNG_TEXT_SUPPORTED + /* Free text item num or (if num == -1) all text items */ + if ((mask & PNG_FREE_TEXT) & info_ptr->free_me) + { + if (num != -1) + { + if (info_ptr->text && info_ptr->text[num].key) + { + png_free(png_ptr, info_ptr->text[num].key); + info_ptr->text[num].key = NULL; + } + } + + else + { + int i; + for (i = 0; i < info_ptr->num_text; i++) + png_free_data(png_ptr, info_ptr, PNG_FREE_TEXT, i); + png_free(png_ptr, info_ptr->text); + info_ptr->text = NULL; + info_ptr->num_text=0; + } + } +#endif + +#ifdef PNG_tRNS_SUPPORTED + /* Free any tRNS entry */ + if ((mask & PNG_FREE_TRNS) & info_ptr->free_me) + { + png_free(png_ptr, info_ptr->trans_alpha); + info_ptr->trans_alpha = NULL; + info_ptr->valid &= ~PNG_INFO_tRNS; + } +#endif + +#ifdef PNG_sCAL_SUPPORTED + /* Free any sCAL entry */ + if ((mask & PNG_FREE_SCAL) & info_ptr->free_me) + { + png_free(png_ptr, info_ptr->scal_s_width); + png_free(png_ptr, info_ptr->scal_s_height); + info_ptr->scal_s_width = NULL; + info_ptr->scal_s_height = NULL; + info_ptr->valid &= ~PNG_INFO_sCAL; + } +#endif + +#ifdef PNG_pCAL_SUPPORTED + /* Free any pCAL entry */ + if ((mask & PNG_FREE_PCAL) & info_ptr->free_me) + { + png_free(png_ptr, info_ptr->pcal_purpose); + png_free(png_ptr, info_ptr->pcal_units); + info_ptr->pcal_purpose = NULL; + info_ptr->pcal_units = NULL; + if (info_ptr->pcal_params != NULL) + { + unsigned int i; + for (i = 0; i < info_ptr->pcal_nparams; i++) + { + png_free(png_ptr, info_ptr->pcal_params[i]); + info_ptr->pcal_params[i] = NULL; + } + png_free(png_ptr, info_ptr->pcal_params); + info_ptr->pcal_params = NULL; + } + info_ptr->valid &= ~PNG_INFO_pCAL; + } +#endif + +#ifdef PNG_iCCP_SUPPORTED + /* Free any profile entry */ + if ((mask & PNG_FREE_ICCP) & info_ptr->free_me) + { + png_free(png_ptr, info_ptr->iccp_name); + png_free(png_ptr, info_ptr->iccp_profile); + info_ptr->iccp_name = NULL; + info_ptr->iccp_profile = NULL; + info_ptr->valid &= ~PNG_INFO_iCCP; + } +#endif + +#ifdef PNG_sPLT_SUPPORTED + /* Free a given sPLT entry, or (if num == -1) all sPLT entries */ + if ((mask & PNG_FREE_SPLT) & info_ptr->free_me) + { + if (num != -1) + { + if (info_ptr->splt_palettes) + { + png_free(png_ptr, info_ptr->splt_palettes[num].name); + png_free(png_ptr, info_ptr->splt_palettes[num].entries); + info_ptr->splt_palettes[num].name = NULL; + info_ptr->splt_palettes[num].entries = NULL; + } + } + + else + { + if (info_ptr->splt_palettes_num) + { + int i; + for (i = 0; i < info_ptr->splt_palettes_num; i++) + png_free_data(png_ptr, info_ptr, PNG_FREE_SPLT, (int)i); + + png_free(png_ptr, info_ptr->splt_palettes); + info_ptr->splt_palettes = NULL; + info_ptr->splt_palettes_num = 0; + } + info_ptr->valid &= ~PNG_INFO_sPLT; + } + } +#endif + +#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED + if ((mask & PNG_FREE_UNKN) & info_ptr->free_me) + { + if (num != -1) + { + if (info_ptr->unknown_chunks) + { + png_free(png_ptr, info_ptr->unknown_chunks[num].data); + info_ptr->unknown_chunks[num].data = NULL; + } + } + + else + { + int i; + + if (info_ptr->unknown_chunks_num) + { + for (i = 0; i < info_ptr->unknown_chunks_num; i++) + png_free_data(png_ptr, info_ptr, PNG_FREE_UNKN, (int)i); + + png_free(png_ptr, info_ptr->unknown_chunks); + info_ptr->unknown_chunks = NULL; + info_ptr->unknown_chunks_num = 0; + } + } + } +#endif + +#ifdef PNG_hIST_SUPPORTED + /* Free any hIST entry */ + if ((mask & PNG_FREE_HIST) & info_ptr->free_me) + { + png_free(png_ptr, info_ptr->hist); + info_ptr->hist = NULL; + info_ptr->valid &= ~PNG_INFO_hIST; + } +#endif + + /* Free any PLTE entry that was internally allocated */ + if ((mask & PNG_FREE_PLTE) & info_ptr->free_me) + { + png_free(png_ptr, info_ptr->palette); + info_ptr->palette = NULL; + info_ptr->valid &= ~PNG_INFO_PLTE; + info_ptr->num_palette = 0; + } + +#ifdef PNG_INFO_IMAGE_SUPPORTED + /* Free any image bits attached to the info structure */ + if ((mask & PNG_FREE_ROWS) & info_ptr->free_me) + { + if (info_ptr->row_pointers) + { + png_uint_32 row; + for (row = 0; row < info_ptr->height; row++) + { + png_free(png_ptr, info_ptr->row_pointers[row]); + info_ptr->row_pointers[row] = NULL; + } + png_free(png_ptr, info_ptr->row_pointers); + info_ptr->row_pointers = NULL; + } + info_ptr->valid &= ~PNG_INFO_IDAT; + } +#endif + + if (num != -1) + mask &= ~PNG_FREE_MUL; + + info_ptr->free_me &= ~mask; +} +#endif /* defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) */ + +/* This function returns a pointer to the io_ptr associated with the user + * functions. The application should free any memory associated with this + * pointer before png_write_destroy() or png_read_destroy() are called. + */ +png_voidp PNGAPI +png_get_io_ptr(png_const_structrp png_ptr) +{ + if (png_ptr == NULL) + return (NULL); + + return (png_ptr->io_ptr); +} + +#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) +# ifdef PNG_STDIO_SUPPORTED +/* Initialize the default input/output functions for the PNG file. If you + * use your own read or write routines, you can call either png_set_read_fn() + * or png_set_write_fn() instead of png_init_io(). If you have defined + * PNG_NO_STDIO or otherwise disabled PNG_STDIO_SUPPORTED, you must use a + * function of your own because "FILE *" isn't necessarily available. + */ +void PNGAPI +png_init_io(png_structrp png_ptr, png_FILE_p fp) +{ + png_debug(1, "in png_init_io"); + + if (png_ptr == NULL) + return; + + png_ptr->io_ptr = (png_voidp)fp; +} +# endif + +#ifdef PNG_SAVE_INT_32_SUPPORTED +/* The png_save_int_32 function assumes integers are stored in two's + * complement format. If this isn't the case, then this routine needs to + * be modified to write data in two's complement format. Note that, + * the following works correctly even if png_int_32 has more than 32 bits + * (compare the more complex code required on read for sign extension.) + */ +void PNGAPI +png_save_int_32(png_bytep buf, png_int_32 i) +{ + buf[0] = (png_byte)((i >> 24) & 0xff); + buf[1] = (png_byte)((i >> 16) & 0xff); + buf[2] = (png_byte)((i >> 8) & 0xff); + buf[3] = (png_byte)(i & 0xff); +} +#endif + +# ifdef PNG_TIME_RFC1123_SUPPORTED +/* Convert the supplied time into an RFC 1123 string suitable for use in + * a "Creation Time" or other text-based time string. + */ +int PNGAPI +png_convert_to_rfc1123_buffer(char out[29], png_const_timep ptime) +{ + static PNG_CONST char short_months[12][4] = + {"Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}; + + if (out == NULL) + return 0; + + if (ptime->year > 9999 /* RFC1123 limitation */ || + ptime->month == 0 || ptime->month > 12 || + ptime->day == 0 || ptime->day > 31 || + ptime->hour > 23 || ptime->minute > 59 || + ptime->second > 60) + return 0; + + { + size_t pos = 0; + char number_buf[5]; /* enough for a four-digit year */ + +# define APPEND_STRING(string) pos = png_safecat(out, 29, pos, (string)) +# define APPEND_NUMBER(format, value)\ + APPEND_STRING(PNG_FORMAT_NUMBER(number_buf, format, (value))) +# define APPEND(ch) if (pos < 28) out[pos++] = (ch) + + APPEND_NUMBER(PNG_NUMBER_FORMAT_u, (unsigned)ptime->day); + APPEND(' '); + APPEND_STRING(short_months[(ptime->month - 1)]); + APPEND(' '); + APPEND_NUMBER(PNG_NUMBER_FORMAT_u, ptime->year); + APPEND(' '); + APPEND_NUMBER(PNG_NUMBER_FORMAT_02u, (unsigned)ptime->hour); + APPEND(':'); + APPEND_NUMBER(PNG_NUMBER_FORMAT_02u, (unsigned)ptime->minute); + APPEND(':'); + APPEND_NUMBER(PNG_NUMBER_FORMAT_02u, (unsigned)ptime->second); + APPEND_STRING(" +0000"); /* This reliably terminates the buffer */ + +# undef APPEND +# undef APPEND_NUMBER +# undef APPEND_STRING + } + + return 1; +} + +# if PNG_LIBPNG_VER < 10700 +/* To do: remove the following from libpng-1.7 */ +/* Original API that uses a private buffer in png_struct. + * Deprecated because it causes png_struct to carry a spurious temporary + * buffer (png_struct::time_buffer), better to have the caller pass this in. + */ +png_const_charp PNGAPI +png_convert_to_rfc1123(png_structrp png_ptr, png_const_timep ptime) +{ + if (png_ptr != NULL) + { + /* The only failure above if png_ptr != NULL is from an invalid ptime */ + if (!png_convert_to_rfc1123_buffer(png_ptr->time_buffer, ptime)) + png_warning(png_ptr, "Ignoring invalid time value"); + + else + return png_ptr->time_buffer; + } + + return NULL; +} +# endif +# endif /* PNG_TIME_RFC1123_SUPPORTED */ + +#endif /* defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) */ + +png_const_charp PNGAPI +png_get_copyright(png_const_structrp png_ptr) +{ + PNG_UNUSED(png_ptr) /* Silence compiler warning about unused png_ptr */ +#ifdef PNG_STRING_COPYRIGHT + return PNG_STRING_COPYRIGHT +#else +# ifdef __STDC__ + return PNG_STRING_NEWLINE \ + "libpng version 1.6.7 - November 14, 2013" PNG_STRING_NEWLINE \ + "Copyright (c) 1998-2013 Glenn Randers-Pehrson" PNG_STRING_NEWLINE \ + "Copyright (c) 1996-1997 Andreas Dilger" PNG_STRING_NEWLINE \ + "Copyright (c) 1995-1996 Guy Eric Schalnat, Group 42, Inc." \ + PNG_STRING_NEWLINE; +# else + return "libpng version 1.6.7 - November 14, 2013\ + Copyright (c) 1998-2013 Glenn Randers-Pehrson\ + Copyright (c) 1996-1997 Andreas Dilger\ + Copyright (c) 1995-1996 Guy Eric Schalnat, Group 42, Inc."; +# endif +#endif +} + +/* The following return the library version as a short string in the + * format 1.0.0 through 99.99.99zz. To get the version of *.h files + * used with your application, print out PNG_LIBPNG_VER_STRING, which + * is defined in png.h. + * Note: now there is no difference between png_get_libpng_ver() and + * png_get_header_ver(). Due to the version_nn_nn_nn typedef guard, + * it is guaranteed that png.c uses the correct version of png.h. + */ +png_const_charp PNGAPI +png_get_libpng_ver(png_const_structrp png_ptr) +{ + /* Version of *.c files used when building libpng */ + return png_get_header_ver(png_ptr); +} + +png_const_charp PNGAPI +png_get_header_ver(png_const_structrp png_ptr) +{ + /* Version of *.h files used when building libpng */ + PNG_UNUSED(png_ptr) /* Silence compiler warning about unused png_ptr */ + return PNG_LIBPNG_VER_STRING; +} + +png_const_charp PNGAPI +png_get_header_version(png_const_structrp png_ptr) +{ + /* Returns longer string containing both version and date */ + PNG_UNUSED(png_ptr) /* Silence compiler warning about unused png_ptr */ +#ifdef __STDC__ + return PNG_HEADER_VERSION_STRING +# ifndef PNG_READ_SUPPORTED + " (NO READ SUPPORT)" +# endif + PNG_STRING_NEWLINE; +#else + return PNG_HEADER_VERSION_STRING; +#endif +} + +#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED +int PNGAPI +png_handle_as_unknown(png_const_structrp png_ptr, png_const_bytep chunk_name) +{ + /* Check chunk_name and return "keep" value if it's on the list, else 0 */ + png_const_bytep p, p_end; + + if (png_ptr == NULL || chunk_name == NULL || png_ptr->num_chunk_list == 0) + return PNG_HANDLE_CHUNK_AS_DEFAULT; + + p_end = png_ptr->chunk_list; + p = p_end + png_ptr->num_chunk_list*5; /* beyond end */ + + /* The code is the fifth byte after each four byte string. Historically this + * code was always searched from the end of the list, this is no longer + * necessary because the 'set' routine handles duplicate entries correcty. + */ + do /* num_chunk_list > 0, so at least one */ + { + p -= 5; + + if (!memcmp(chunk_name, p, 4)) + return p[4]; + } + while (p > p_end); + + /* This means that known chunks should be processed and unknown chunks should + * be handled according to the value of png_ptr->unknown_default; this can be + * confusing because, as a result, there are two levels of defaulting for + * unknown chunks. + */ + return PNG_HANDLE_CHUNK_AS_DEFAULT; +} + +#if defined(PNG_READ_UNKNOWN_CHUNKS_SUPPORTED) ||\ + defined(PNG_HANDLE_AS_UNKNOWN_SUPPORTED) +int /* PRIVATE */ +png_chunk_unknown_handling(png_const_structrp png_ptr, png_uint_32 chunk_name) +{ + png_byte chunk_string[5]; + + PNG_CSTRING_FROM_CHUNK(chunk_string, chunk_name); + return png_handle_as_unknown(png_ptr, chunk_string); +} +#endif /* READ_UNKNOWN_CHUNKS || HANDLE_AS_UNKNOWN */ +#endif /* SET_UNKNOWN_CHUNKS */ + +#ifdef PNG_READ_SUPPORTED +/* This function, added to libpng-1.0.6g, is untested. */ +int PNGAPI +png_reset_zstream(png_structrp png_ptr) +{ + if (png_ptr == NULL) + return Z_STREAM_ERROR; + + /* WARNING: this resets the window bits to the maximum! */ + return (inflateReset(&png_ptr->zstream)); +} +#endif /* PNG_READ_SUPPORTED */ + +/* This function was added to libpng-1.0.7 */ +png_uint_32 PNGAPI +png_access_version_number(void) +{ + /* Version of *.c files used when building libpng */ + return((png_uint_32)PNG_LIBPNG_VER); +} + + + +#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) +/* Ensure that png_ptr->zstream.msg holds some appropriate error message string. + * If it doesn't 'ret' is used to set it to something appropriate, even in cases + * like Z_OK or Z_STREAM_END where the error code is apparently a success code. + */ +void /* PRIVATE */ +png_zstream_error(png_structrp png_ptr, int ret) +{ + /* Translate 'ret' into an appropriate error string, priority is given to the + * one in zstream if set. This always returns a string, even in cases like + * Z_OK or Z_STREAM_END where the error code is a success code. + */ + if (png_ptr->zstream.msg == NULL) switch (ret) + { + default: + case Z_OK: + png_ptr->zstream.msg = PNGZ_MSG_CAST("unexpected zlib return code"); + break; + + case Z_STREAM_END: + /* Normal exit */ + png_ptr->zstream.msg = PNGZ_MSG_CAST("unexpected end of LZ stream"); + break; + + case Z_NEED_DICT: + /* This means the deflate stream did not have a dictionary; this + * indicates a bogus PNG. + */ + png_ptr->zstream.msg = PNGZ_MSG_CAST("missing LZ dictionary"); + break; + + case Z_ERRNO: + /* gz APIs only: should not happen */ + png_ptr->zstream.msg = PNGZ_MSG_CAST("zlib IO error"); + break; + + case Z_STREAM_ERROR: + /* internal libpng error */ + png_ptr->zstream.msg = PNGZ_MSG_CAST("bad parameters to zlib"); + break; + + case Z_DATA_ERROR: + png_ptr->zstream.msg = PNGZ_MSG_CAST("damaged LZ stream"); + break; + + case Z_MEM_ERROR: + png_ptr->zstream.msg = PNGZ_MSG_CAST("insufficient memory"); + break; + + case Z_BUF_ERROR: + /* End of input or output; not a problem if the caller is doing + * incremental read or write. + */ + png_ptr->zstream.msg = PNGZ_MSG_CAST("truncated"); + break; + + case Z_VERSION_ERROR: + png_ptr->zstream.msg = PNGZ_MSG_CAST("unsupported zlib version"); + break; + + case PNG_UNEXPECTED_ZLIB_RETURN: + /* Compile errors here mean that zlib now uses the value co-opted in + * pngpriv.h for PNG_UNEXPECTED_ZLIB_RETURN; update the switch above + * and change pngpriv.h. Note that this message is "... return", + * whereas the default/Z_OK one is "... return code". + */ + png_ptr->zstream.msg = PNGZ_MSG_CAST("unexpected zlib return"); + break; + } +} + +/* png_convert_size: a PNGAPI but no longer in png.h, so deleted + * at libpng 1.5.5! + */ + +/* Added at libpng version 1.2.34 and 1.4.0 (moved from pngset.c) */ +#ifdef PNG_GAMMA_SUPPORTED /* always set if COLORSPACE */ +static int +png_colorspace_check_gamma(png_const_structrp png_ptr, + png_colorspacerp colorspace, png_fixed_point gAMA, int from) + /* This is called to check a new gamma value against an existing one. The + * routine returns false if the new gamma value should not be written. + * + * 'from' says where the new gamma value comes from: + * + * 0: the new gamma value is the libpng estimate for an ICC profile + * 1: the new gamma value comes from a gAMA chunk + * 2: the new gamma value comes from an sRGB chunk + */ +{ + png_fixed_point gtest; + + if ((colorspace->flags & PNG_COLORSPACE_HAVE_GAMMA) != 0 && + (!png_muldiv(>est, colorspace->gamma, PNG_FP_1, gAMA) || + png_gamma_significant(gtest))) + { + /* Either this is an sRGB image, in which case the calculated gamma + * approximation should match, or this is an image with a profile and the + * value libpng calculates for the gamma of the profile does not match the + * value recorded in the file. The former, sRGB, case is an error, the + * latter is just a warning. + */ + if ((colorspace->flags & PNG_COLORSPACE_FROM_sRGB) != 0 || from == 2) + { + png_chunk_report(png_ptr, "gamma value does not match sRGB", + PNG_CHUNK_ERROR); + /* Do not overwrite an sRGB value */ + return from == 2; + } + + else /* sRGB tag not involved */ + { + png_chunk_report(png_ptr, "gamma value does not match libpng estimate", + PNG_CHUNK_WARNING); + return from == 1; + } + } + + return 1; +} + +void /* PRIVATE */ +png_colorspace_set_gamma(png_const_structrp png_ptr, + png_colorspacerp colorspace, png_fixed_point gAMA) +{ + /* Changed in libpng-1.5.4 to limit the values to ensure overflow can't + * occur. Since the fixed point representation is assymetrical it is + * possible for 1/gamma to overflow the limit of 21474 and this means the + * gamma value must be at least 5/100000 and hence at most 20000.0. For + * safety the limits here are a little narrower. The values are 0.00016 to + * 6250.0, which are truly ridiculous gamma values (and will produce + * displays that are all black or all white.) + * + * In 1.6.0 this test replaces the ones in pngrutil.c, in the gAMA chunk + * handling code, which only required the value to be >0. + */ + png_const_charp errmsg; + + if (gAMA < 16 || gAMA > 625000000) + errmsg = "gamma value out of range"; + +# ifdef PNG_READ_gAMA_SUPPORTED + /* Allow the application to set the gamma value more than once */ + else if ((png_ptr->mode & PNG_IS_READ_STRUCT) != 0 && + (colorspace->flags & PNG_COLORSPACE_FROM_gAMA) != 0) + errmsg = "duplicate"; +# endif + + /* Do nothing if the colorspace is already invalid */ + else if (colorspace->flags & PNG_COLORSPACE_INVALID) + return; + + else + { + if (png_colorspace_check_gamma(png_ptr, colorspace, gAMA, 1/*from gAMA*/)) + { + /* Store this gamma value. */ + colorspace->gamma = gAMA; + colorspace->flags |= + (PNG_COLORSPACE_HAVE_GAMMA | PNG_COLORSPACE_FROM_gAMA); + } + + /* At present if the check_gamma test fails the gamma of the colorspace is + * not updated however the colorspace is not invalidated. This + * corresponds to the case where the existing gamma comes from an sRGB + * chunk or profile. An error message has already been output. + */ + return; + } + + /* Error exit - errmsg has been set. */ + colorspace->flags |= PNG_COLORSPACE_INVALID; + png_chunk_report(png_ptr, errmsg, PNG_CHUNK_WRITE_ERROR); +} + +void /* PRIVATE */ +png_colorspace_sync_info(png_const_structrp png_ptr, png_inforp info_ptr) +{ + if (info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) + { + /* Everything is invalid */ + info_ptr->valid &= ~(PNG_INFO_gAMA|PNG_INFO_cHRM|PNG_INFO_sRGB| + PNG_INFO_iCCP); + +# ifdef PNG_COLORSPACE_SUPPORTED + /* Clean up the iCCP profile now if it won't be used. */ + png_free_data(png_ptr, info_ptr, PNG_FREE_ICCP, -1/*not used*/); +# else + PNG_UNUSED(png_ptr) +# endif + } + + else + { +# ifdef PNG_COLORSPACE_SUPPORTED + /* Leave the INFO_iCCP flag set if the pngset.c code has already set + * it; this allows a PNG to contain a profile which matches sRGB and + * yet still have that profile retrievable by the application. + */ + if (info_ptr->colorspace.flags & PNG_COLORSPACE_MATCHES_sRGB) + info_ptr->valid |= PNG_INFO_sRGB; + + else + info_ptr->valid &= ~PNG_INFO_sRGB; + + if (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS) + info_ptr->valid |= PNG_INFO_cHRM; + + else + info_ptr->valid &= ~PNG_INFO_cHRM; +# endif + + if (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_GAMMA) + info_ptr->valid |= PNG_INFO_gAMA; + + else + info_ptr->valid &= ~PNG_INFO_gAMA; + } +} + +#ifdef PNG_READ_SUPPORTED +void /* PRIVATE */ +png_colorspace_sync(png_const_structrp png_ptr, png_inforp info_ptr) +{ + if (info_ptr == NULL) /* reduce code size; check here not in the caller */ + return; + + info_ptr->colorspace = png_ptr->colorspace; + png_colorspace_sync_info(png_ptr, info_ptr); +} +#endif +#endif + +#ifdef PNG_COLORSPACE_SUPPORTED +/* Added at libpng-1.5.5 to support read and write of true CIEXYZ values for + * cHRM, as opposed to using chromaticities. These internal APIs return + * non-zero on a parameter error. The X, Y and Z values are required to be + * positive and less than 1.0. + */ +static int +png_xy_from_XYZ(png_xy *xy, const png_XYZ *XYZ) +{ + png_int_32 d, dwhite, whiteX, whiteY; + + d = XYZ->red_X + XYZ->red_Y + XYZ->red_Z; + if (!png_muldiv(&xy->redx, XYZ->red_X, PNG_FP_1, d)) return 1; + if (!png_muldiv(&xy->redy, XYZ->red_Y, PNG_FP_1, d)) return 1; + dwhite = d; + whiteX = XYZ->red_X; + whiteY = XYZ->red_Y; + + d = XYZ->green_X + XYZ->green_Y + XYZ->green_Z; + if (!png_muldiv(&xy->greenx, XYZ->green_X, PNG_FP_1, d)) return 1; + if (!png_muldiv(&xy->greeny, XYZ->green_Y, PNG_FP_1, d)) return 1; + dwhite += d; + whiteX += XYZ->green_X; + whiteY += XYZ->green_Y; + + d = XYZ->blue_X + XYZ->blue_Y + XYZ->blue_Z; + if (!png_muldiv(&xy->bluex, XYZ->blue_X, PNG_FP_1, d)) return 1; + if (!png_muldiv(&xy->bluey, XYZ->blue_Y, PNG_FP_1, d)) return 1; + dwhite += d; + whiteX += XYZ->blue_X; + whiteY += XYZ->blue_Y; + + /* The reference white is simply the sum of the end-point (X,Y,Z) vectors, + * thus: + */ + if (!png_muldiv(&xy->whitex, whiteX, PNG_FP_1, dwhite)) return 1; + if (!png_muldiv(&xy->whitey, whiteY, PNG_FP_1, dwhite)) return 1; + + return 0; +} + +static int +png_XYZ_from_xy(png_XYZ *XYZ, const png_xy *xy) +{ + png_fixed_point red_inverse, green_inverse, blue_scale; + png_fixed_point left, right, denominator; + + /* Check xy and, implicitly, z. Note that wide gamut color spaces typically + * have end points with 0 tristimulus values (these are impossible end + * points, but they are used to cover the possible colors.) + */ + if (xy->redx < 0 || xy->redx > PNG_FP_1) return 1; + if (xy->redy < 0 || xy->redy > PNG_FP_1-xy->redx) return 1; + if (xy->greenx < 0 || xy->greenx > PNG_FP_1) return 1; + if (xy->greeny < 0 || xy->greeny > PNG_FP_1-xy->greenx) return 1; + if (xy->bluex < 0 || xy->bluex > PNG_FP_1) return 1; + if (xy->bluey < 0 || xy->bluey > PNG_FP_1-xy->bluex) return 1; + if (xy->whitex < 0 || xy->whitex > PNG_FP_1) return 1; + if (xy->whitey < 0 || xy->whitey > PNG_FP_1-xy->whitex) return 1; + + /* The reverse calculation is more difficult because the original tristimulus + * value had 9 independent values (red,green,blue)x(X,Y,Z) however only 8 + * derived values were recorded in the cHRM chunk; + * (red,green,blue,white)x(x,y). This loses one degree of freedom and + * therefore an arbitrary ninth value has to be introduced to undo the + * original transformations. + * + * Think of the original end-points as points in (X,Y,Z) space. The + * chromaticity values (c) have the property: + * + * C + * c = --------- + * X + Y + Z + * + * For each c (x,y,z) from the corresponding original C (X,Y,Z). Thus the + * three chromaticity values (x,y,z) for each end-point obey the + * relationship: + * + * x + y + z = 1 + * + * This describes the plane in (X,Y,Z) space that intersects each axis at the + * value 1.0; call this the chromaticity plane. Thus the chromaticity + * calculation has scaled each end-point so that it is on the x+y+z=1 plane + * and chromaticity is the intersection of the vector from the origin to the + * (X,Y,Z) value with the chromaticity plane. + * + * To fully invert the chromaticity calculation we would need the three + * end-point scale factors, (red-scale, green-scale, blue-scale), but these + * were not recorded. Instead we calculated the reference white (X,Y,Z) and + * recorded the chromaticity of this. The reference white (X,Y,Z) would have + * given all three of the scale factors since: + * + * color-C = color-c * color-scale + * white-C = red-C + green-C + blue-C + * = red-c*red-scale + green-c*green-scale + blue-c*blue-scale + * + * But cHRM records only white-x and white-y, so we have lost the white scale + * factor: + * + * white-C = white-c*white-scale + * + * To handle this the inverse transformation makes an arbitrary assumption + * about white-scale: + * + * Assume: white-Y = 1.0 + * Hence: white-scale = 1/white-y + * Or: red-Y + green-Y + blue-Y = 1.0 + * + * Notice the last statement of the assumption gives an equation in three of + * the nine values we want to calculate. 8 more equations come from the + * above routine as summarised at the top above (the chromaticity + * calculation): + * + * Given: color-x = color-X / (color-X + color-Y + color-Z) + * Hence: (color-x - 1)*color-X + color.x*color-Y + color.x*color-Z = 0 + * + * This is 9 simultaneous equations in the 9 variables "color-C" and can be + * solved by Cramer's rule. Cramer's rule requires calculating 10 9x9 matrix + * determinants, however this is not as bad as it seems because only 28 of + * the total of 90 terms in the various matrices are non-zero. Nevertheless + * Cramer's rule is notoriously numerically unstable because the determinant + * calculation involves the difference of large, but similar, numbers. It is + * difficult to be sure that the calculation is stable for real world values + * and it is certain that it becomes unstable where the end points are close + * together. + * + * So this code uses the perhaps slightly less optimal but more + * understandable and totally obvious approach of calculating color-scale. + * + * This algorithm depends on the precision in white-scale and that is + * (1/white-y), so we can immediately see that as white-y approaches 0 the + * accuracy inherent in the cHRM chunk drops off substantially. + * + * libpng arithmetic: a simple invertion of the above equations + * ------------------------------------------------------------ + * + * white_scale = 1/white-y + * white-X = white-x * white-scale + * white-Y = 1.0 + * white-Z = (1 - white-x - white-y) * white_scale + * + * white-C = red-C + green-C + blue-C + * = red-c*red-scale + green-c*green-scale + blue-c*blue-scale + * + * This gives us three equations in (red-scale,green-scale,blue-scale) where + * all the coefficients are now known: + * + * red-x*red-scale + green-x*green-scale + blue-x*blue-scale + * = white-x/white-y + * red-y*red-scale + green-y*green-scale + blue-y*blue-scale = 1 + * red-z*red-scale + green-z*green-scale + blue-z*blue-scale + * = (1 - white-x - white-y)/white-y + * + * In the last equation color-z is (1 - color-x - color-y) so we can add all + * three equations together to get an alternative third: + * + * red-scale + green-scale + blue-scale = 1/white-y = white-scale + * + * So now we have a Cramer's rule solution where the determinants are just + * 3x3 - far more tractible. Unfortunately 3x3 determinants still involve + * multiplication of three coefficients so we can't guarantee to avoid + * overflow in the libpng fixed point representation. Using Cramer's rule in + * floating point is probably a good choice here, but it's not an option for + * fixed point. Instead proceed to simplify the first two equations by + * eliminating what is likely to be the largest value, blue-scale: + * + * blue-scale = white-scale - red-scale - green-scale + * + * Hence: + * + * (red-x - blue-x)*red-scale + (green-x - blue-x)*green-scale = + * (white-x - blue-x)*white-scale + * + * (red-y - blue-y)*red-scale + (green-y - blue-y)*green-scale = + * 1 - blue-y*white-scale + * + * And now we can trivially solve for (red-scale,green-scale): + * + * green-scale = + * (white-x - blue-x)*white-scale - (red-x - blue-x)*red-scale + * ----------------------------------------------------------- + * green-x - blue-x + * + * red-scale = + * 1 - blue-y*white-scale - (green-y - blue-y) * green-scale + * --------------------------------------------------------- + * red-y - blue-y + * + * Hence: + * + * red-scale = + * ( (green-x - blue-x) * (white-y - blue-y) - + * (green-y - blue-y) * (white-x - blue-x) ) / white-y + * ------------------------------------------------------------------------- + * (green-x - blue-x)*(red-y - blue-y)-(green-y - blue-y)*(red-x - blue-x) + * + * green-scale = + * ( (red-y - blue-y) * (white-x - blue-x) - + * (red-x - blue-x) * (white-y - blue-y) ) / white-y + * ------------------------------------------------------------------------- + * (green-x - blue-x)*(red-y - blue-y)-(green-y - blue-y)*(red-x - blue-x) + * + * Accuracy: + * The input values have 5 decimal digits of accuracy. The values are all in + * the range 0 < value < 1, so simple products are in the same range but may + * need up to 10 decimal digits to preserve the original precision and avoid + * underflow. Because we are using a 32-bit signed representation we cannot + * match this; the best is a little over 9 decimal digits, less than 10. + * + * The approach used here is to preserve the maximum precision within the + * signed representation. Because the red-scale calculation above uses the + * difference between two products of values that must be in the range -1..+1 + * it is sufficient to divide the product by 7; ceil(100,000/32767*2). The + * factor is irrelevant in the calculation because it is applied to both + * numerator and denominator. + * + * Note that the values of the differences of the products of the + * chromaticities in the above equations tend to be small, for example for + * the sRGB chromaticities they are: + * + * red numerator: -0.04751 + * green numerator: -0.08788 + * denominator: -0.2241 (without white-y multiplication) + * + * The resultant Y coefficients from the chromaticities of some widely used + * color space definitions are (to 15 decimal places): + * + * sRGB + * 0.212639005871510 0.715168678767756 0.072192315360734 + * Kodak ProPhoto + * 0.288071128229293 0.711843217810102 0.000085653960605 + * Adobe RGB + * 0.297344975250536 0.627363566255466 0.075291458493998 + * Adobe Wide Gamut RGB + * 0.258728243040113 0.724682314948566 0.016589442011321 + */ + /* By the argument, above overflow should be impossible here. The return + * value of 2 indicates an internal error to the caller. + */ + if (!png_muldiv(&left, xy->greenx-xy->bluex, xy->redy - xy->bluey, 7)) + return 2; + if (!png_muldiv(&right, xy->greeny-xy->bluey, xy->redx - xy->bluex, 7)) + return 2; + denominator = left - right; + + /* Now find the red numerator. */ + if (!png_muldiv(&left, xy->greenx-xy->bluex, xy->whitey-xy->bluey, 7)) + return 2; + if (!png_muldiv(&right, xy->greeny-xy->bluey, xy->whitex-xy->bluex, 7)) + return 2; + + /* Overflow is possible here and it indicates an extreme set of PNG cHRM + * chunk values. This calculation actually returns the reciprocal of the + * scale value because this allows us to delay the multiplication of white-y + * into the denominator, which tends to produce a small number. + */ + if (!png_muldiv(&red_inverse, xy->whitey, denominator, left-right) || + red_inverse <= xy->whitey /* r+g+b scales = white scale */) + return 1; + + /* Similarly for green_inverse: */ + if (!png_muldiv(&left, xy->redy-xy->bluey, xy->whitex-xy->bluex, 7)) + return 2; + if (!png_muldiv(&right, xy->redx-xy->bluex, xy->whitey-xy->bluey, 7)) + return 2; + if (!png_muldiv(&green_inverse, xy->whitey, denominator, left-right) || + green_inverse <= xy->whitey) + return 1; + + /* And the blue scale, the checks above guarantee this can't overflow but it + * can still produce 0 for extreme cHRM values. + */ + blue_scale = png_reciprocal(xy->whitey) - png_reciprocal(red_inverse) - + png_reciprocal(green_inverse); + if (blue_scale <= 0) return 1; + + + /* And fill in the png_XYZ: */ + if (!png_muldiv(&XYZ->red_X, xy->redx, PNG_FP_1, red_inverse)) return 1; + if (!png_muldiv(&XYZ->red_Y, xy->redy, PNG_FP_1, red_inverse)) return 1; + if (!png_muldiv(&XYZ->red_Z, PNG_FP_1 - xy->redx - xy->redy, PNG_FP_1, + red_inverse)) + return 1; + + if (!png_muldiv(&XYZ->green_X, xy->greenx, PNG_FP_1, green_inverse)) + return 1; + if (!png_muldiv(&XYZ->green_Y, xy->greeny, PNG_FP_1, green_inverse)) + return 1; + if (!png_muldiv(&XYZ->green_Z, PNG_FP_1 - xy->greenx - xy->greeny, PNG_FP_1, + green_inverse)) + return 1; + + if (!png_muldiv(&XYZ->blue_X, xy->bluex, blue_scale, PNG_FP_1)) return 1; + if (!png_muldiv(&XYZ->blue_Y, xy->bluey, blue_scale, PNG_FP_1)) return 1; + if (!png_muldiv(&XYZ->blue_Z, PNG_FP_1 - xy->bluex - xy->bluey, blue_scale, + PNG_FP_1)) + return 1; + + return 0; /*success*/ +} + +static int +png_XYZ_normalize(png_XYZ *XYZ) +{ + png_int_32 Y; + + if (XYZ->red_Y < 0 || XYZ->green_Y < 0 || XYZ->blue_Y < 0 || + XYZ->red_X < 0 || XYZ->green_X < 0 || XYZ->blue_X < 0 || + XYZ->red_Z < 0 || XYZ->green_Z < 0 || XYZ->blue_Z < 0) + return 1; + + /* Normalize by scaling so the sum of the end-point Y values is PNG_FP_1. + * IMPLEMENTATION NOTE: ANSI requires signed overflow not to occur, therefore + * relying on addition of two positive values producing a negative one is not + * safe. + */ + Y = XYZ->red_Y; + if (0x7fffffff - Y < XYZ->green_X) return 1; + Y += XYZ->green_Y; + if (0x7fffffff - Y < XYZ->blue_X) return 1; + Y += XYZ->blue_Y; + + if (Y != PNG_FP_1) + { + if (!png_muldiv(&XYZ->red_X, XYZ->red_X, PNG_FP_1, Y)) return 1; + if (!png_muldiv(&XYZ->red_Y, XYZ->red_Y, PNG_FP_1, Y)) return 1; + if (!png_muldiv(&XYZ->red_Z, XYZ->red_Z, PNG_FP_1, Y)) return 1; + + if (!png_muldiv(&XYZ->green_X, XYZ->green_X, PNG_FP_1, Y)) return 1; + if (!png_muldiv(&XYZ->green_Y, XYZ->green_Y, PNG_FP_1, Y)) return 1; + if (!png_muldiv(&XYZ->green_Z, XYZ->green_Z, PNG_FP_1, Y)) return 1; + + if (!png_muldiv(&XYZ->blue_X, XYZ->blue_X, PNG_FP_1, Y)) return 1; + if (!png_muldiv(&XYZ->blue_Y, XYZ->blue_Y, PNG_FP_1, Y)) return 1; + if (!png_muldiv(&XYZ->blue_Z, XYZ->blue_Z, PNG_FP_1, Y)) return 1; + } + + return 0; +} + +static int +png_colorspace_endpoints_match(const png_xy *xy1, const png_xy *xy2, int delta) +{ + /* Allow an error of +/-0.01 (absolute value) on each chromaticity */ + return !(PNG_OUT_OF_RANGE(xy1->whitex, xy2->whitex,delta) || + PNG_OUT_OF_RANGE(xy1->whitey, xy2->whitey,delta) || + PNG_OUT_OF_RANGE(xy1->redx, xy2->redx, delta) || + PNG_OUT_OF_RANGE(xy1->redy, xy2->redy, delta) || + PNG_OUT_OF_RANGE(xy1->greenx, xy2->greenx,delta) || + PNG_OUT_OF_RANGE(xy1->greeny, xy2->greeny,delta) || + PNG_OUT_OF_RANGE(xy1->bluex, xy2->bluex, delta) || + PNG_OUT_OF_RANGE(xy1->bluey, xy2->bluey, delta)); +} + +/* Added in libpng-1.6.0, a different check for the validity of a set of cHRM + * chunk chromaticities. Earlier checks used to simply look for the overflow + * condition (where the determinant of the matrix to solve for XYZ ends up zero + * because the chromaticity values are not all distinct.) Despite this it is + * theoretically possible to produce chromaticities that are apparently valid + * but that rapidly degrade to invalid, potentially crashing, sets because of + * arithmetic inaccuracies when calculations are performed on them. The new + * check is to round-trip xy -> XYZ -> xy and then check that the result is + * within a small percentage of the original. + */ +static int +png_colorspace_check_xy(png_XYZ *XYZ, const png_xy *xy) +{ + int result; + png_xy xy_test; + + /* As a side-effect this routine also returns the XYZ endpoints. */ + result = png_XYZ_from_xy(XYZ, xy); + if (result) return result; + + result = png_xy_from_XYZ(&xy_test, XYZ); + if (result) return result; + + if (png_colorspace_endpoints_match(xy, &xy_test, + 5/*actually, the math is pretty accurate*/)) + return 0; + + /* Too much slip */ + return 1; +} + +/* This is the check going the other way. The XYZ is modified to normalize it + * (another side-effect) and the xy chromaticities are returned. + */ +static int +png_colorspace_check_XYZ(png_xy *xy, png_XYZ *XYZ) +{ + int result; + png_XYZ XYZtemp; + + result = png_XYZ_normalize(XYZ); + if (result) return result; + + result = png_xy_from_XYZ(xy, XYZ); + if (result) return result; + + XYZtemp = *XYZ; + return png_colorspace_check_xy(&XYZtemp, xy); +} + +/* Used to check for an endpoint match against sRGB */ +static const png_xy sRGB_xy = /* From ITU-R BT.709-3 */ +{ + /* color x y */ + /* red */ 64000, 33000, + /* green */ 30000, 60000, + /* blue */ 15000, 6000, + /* white */ 31270, 32900 +}; + +static int +png_colorspace_set_xy_and_XYZ(png_const_structrp png_ptr, + png_colorspacerp colorspace, const png_xy *xy, const png_XYZ *XYZ, + int preferred) +{ + if (colorspace->flags & PNG_COLORSPACE_INVALID) + return 0; + + /* The consistency check is performed on the chromaticities; this factors out + * variations because of the normalization (or not) of the end point Y + * values. + */ + if (preferred < 2 && (colorspace->flags & PNG_COLORSPACE_HAVE_ENDPOINTS)) + { + /* The end points must be reasonably close to any we already have. The + * following allows an error of up to +/-.001 + */ + if (!png_colorspace_endpoints_match(xy, &colorspace->end_points_xy, 100)) + { + colorspace->flags |= PNG_COLORSPACE_INVALID; + png_benign_error(png_ptr, "inconsistent chromaticities"); + return 0; /* failed */ + } + + /* Only overwrite with preferred values */ + if (!preferred) + return 1; /* ok, but no change */ + } + + colorspace->end_points_xy = *xy; + colorspace->end_points_XYZ = *XYZ; + colorspace->flags |= PNG_COLORSPACE_HAVE_ENDPOINTS; + + /* The end points are normally quoted to two decimal digits, so allow +/-0.01 + * on this test. + */ + if (png_colorspace_endpoints_match(xy, &sRGB_xy, 1000)) + colorspace->flags |= PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB; + + else + colorspace->flags &= PNG_COLORSPACE_CANCEL( + PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB); + + return 2; /* ok and changed */ +} + +int /* PRIVATE */ +png_colorspace_set_chromaticities(png_const_structrp png_ptr, + png_colorspacerp colorspace, const png_xy *xy, int preferred) +{ + /* We must check the end points to ensure they are reasonable - in the past + * color management systems have crashed as a result of getting bogus + * colorant values, while this isn't the fault of libpng it is the + * responsibility of libpng because PNG carries the bomb and libpng is in a + * position to protect against it. + */ + png_XYZ XYZ; + + switch (png_colorspace_check_xy(&XYZ, xy)) + { + case 0: /* success */ + return png_colorspace_set_xy_and_XYZ(png_ptr, colorspace, xy, &XYZ, + preferred); + + case 1: + /* We can't invert the chromaticities so we can't produce value XYZ + * values. Likely as not a color management system will fail too. + */ + colorspace->flags |= PNG_COLORSPACE_INVALID; + png_benign_error(png_ptr, "invalid chromaticities"); + break; + + default: + /* libpng is broken; this should be a warning but if it happens we + * want error reports so for the moment it is an error. + */ + colorspace->flags |= PNG_COLORSPACE_INVALID; + png_error(png_ptr, "internal error checking chromaticities"); + break; + } + + return 0; /* failed */ +} + +int /* PRIVATE */ +png_colorspace_set_endpoints(png_const_structrp png_ptr, + png_colorspacerp colorspace, const png_XYZ *XYZ_in, int preferred) +{ + png_XYZ XYZ = *XYZ_in; + png_xy xy; + + switch (png_colorspace_check_XYZ(&xy, &XYZ)) + { + case 0: + return png_colorspace_set_xy_and_XYZ(png_ptr, colorspace, &xy, &XYZ, + preferred); + + case 1: + /* End points are invalid. */ + colorspace->flags |= PNG_COLORSPACE_INVALID; + png_benign_error(png_ptr, "invalid end points"); + break; + + default: + colorspace->flags |= PNG_COLORSPACE_INVALID; + png_error(png_ptr, "internal error checking chromaticities"); + break; + } + + return 0; /* failed */ +} + +#if defined(PNG_sRGB_SUPPORTED) || defined(PNG_iCCP_SUPPORTED) +/* Error message generation */ +static char +png_icc_tag_char(png_uint_32 byte) +{ + byte &= 0xff; + if (byte >= 32 && byte <= 126) + return (char)byte; + else + return '?'; +} + +static void +png_icc_tag_name(char *name, png_uint_32 tag) +{ + name[0] = '\''; + name[1] = png_icc_tag_char(tag >> 24); + name[2] = png_icc_tag_char(tag >> 16); + name[3] = png_icc_tag_char(tag >> 8); + name[4] = png_icc_tag_char(tag ); + name[5] = '\''; +} + +static int +is_ICC_signature_char(png_alloc_size_t it) +{ + return it == 32 || (it >= 48 && it <= 57) || (it >= 65 && it <= 90) || + (it >= 97 && it <= 122); +} + +static int is_ICC_signature(png_alloc_size_t it) +{ + return is_ICC_signature_char(it >> 24) /* checks all the top bits */ && + is_ICC_signature_char((it >> 16) & 0xff) && + is_ICC_signature_char((it >> 8) & 0xff) && + is_ICC_signature_char(it & 0xff); +} + +static int +png_icc_profile_error(png_const_structrp png_ptr, png_colorspacerp colorspace, + png_const_charp name, png_alloc_size_t value, png_const_charp reason) +{ + size_t pos; + char message[196]; /* see below for calculation */ + + if (colorspace != NULL) + colorspace->flags |= PNG_COLORSPACE_INVALID; + + pos = png_safecat(message, (sizeof message), 0, "profile '"); /* 9 chars */ + pos = png_safecat(message, pos+79, pos, name); /* Truncate to 79 chars */ + pos = png_safecat(message, (sizeof message), pos, "': "); /* +2 = 90 */ + if (is_ICC_signature(value)) + { + /* So 'value' is at most 4 bytes and the following cast is safe */ + png_icc_tag_name(message+pos, (png_uint_32)value); + pos += 6; /* total +8; less than the else clause */ + message[pos++] = ':'; + message[pos++] = ' '; + } +# ifdef PNG_WARNINGS_SUPPORTED + else + { + char number[PNG_NUMBER_BUFFER_SIZE]; /* +24 = 114*/ + + pos = png_safecat(message, (sizeof message), pos, + png_format_number(number, number+(sizeof number), + PNG_NUMBER_FORMAT_x, value)); + pos = png_safecat(message, (sizeof message), pos, "h: "); /*+2 = 116*/ + } +# endif + /* The 'reason' is an arbitrary message, allow +79 maximum 195 */ + pos = png_safecat(message, (sizeof message), pos, reason); + + /* This is recoverable, but make it unconditionally an app_error on write to + * avoid writing invalid ICC profiles into PNG files. (I.e. we handle them + * on read, with a warning, but on write unless the app turns off + * application errors the PNG won't be written.) + */ + png_chunk_report(png_ptr, message, + (colorspace != NULL) ? PNG_CHUNK_ERROR : PNG_CHUNK_WRITE_ERROR); + + return 0; +} +#endif /* sRGB || iCCP */ + +#ifdef PNG_sRGB_SUPPORTED +int /* PRIVATE */ +png_colorspace_set_sRGB(png_const_structrp png_ptr, png_colorspacerp colorspace, + int intent) +{ + /* sRGB sets known gamma, end points and (from the chunk) intent. */ + /* IMPORTANT: these are not necessarily the values found in an ICC profile + * because ICC profiles store values adapted to a D50 environment; it is + * expected that the ICC profile mediaWhitePointTag will be D50, see the + * checks and code elsewhere to understand this better. + * + * These XYZ values, which are accurate to 5dp, produce rgb to gray + * coefficients of (6968,23435,2366), which are reduced (because they add up + * to 32769 not 32768) to (6968,23434,2366). These are the values that + * libpng has traditionally used (and are the best values given the 15bit + * algorithm used by the rgb to gray code.) + */ + static const png_XYZ sRGB_XYZ = /* D65 XYZ (*not* the D50 adapted values!) */ + { + /* color X Y Z */ + /* red */ 41239, 21264, 1933, + /* green */ 35758, 71517, 11919, + /* blue */ 18048, 7219, 95053 + }; + + /* Do nothing if the colorspace is already invalidated. */ + if (colorspace->flags & PNG_COLORSPACE_INVALID) + return 0; + + /* Check the intent, then check for existing settings. It is valid for the + * PNG file to have cHRM or gAMA chunks along with sRGB, but the values must + * be consistent with the correct values. If, however, this function is + * called below because an iCCP chunk matches sRGB then it is quite + * conceivable that an older app recorded incorrect gAMA and cHRM because of + * an incorrect calculation based on the values in the profile - this does + * *not* invalidate the profile (though it still produces an error, which can + * be ignored.) + */ + if (intent < 0 || intent >= PNG_sRGB_INTENT_LAST) + return png_icc_profile_error(png_ptr, colorspace, "sRGB", + (unsigned)intent, "invalid sRGB rendering intent"); + + if ((colorspace->flags & PNG_COLORSPACE_HAVE_INTENT) != 0 && + colorspace->rendering_intent != intent) + return png_icc_profile_error(png_ptr, colorspace, "sRGB", + (unsigned)intent, "inconsistent rendering intents"); + + if ((colorspace->flags & PNG_COLORSPACE_FROM_sRGB) != 0) + { + png_benign_error(png_ptr, "duplicate sRGB information ignored"); + return 0; + } + + /* If the standard sRGB cHRM chunk does not match the one from the PNG file + * warn but overwrite the value with the correct one. + */ + if ((colorspace->flags & PNG_COLORSPACE_HAVE_ENDPOINTS) != 0 && + !png_colorspace_endpoints_match(&sRGB_xy, &colorspace->end_points_xy, + 100)) + png_chunk_report(png_ptr, "cHRM chunk does not match sRGB", + PNG_CHUNK_ERROR); + + /* This check is just done for the error reporting - the routine always + * returns true when the 'from' argument corresponds to sRGB (2). + */ + (void)png_colorspace_check_gamma(png_ptr, colorspace, PNG_GAMMA_sRGB_INVERSE, + 2/*from sRGB*/); + + /* intent: bugs in GCC force 'int' to be used as the parameter type. */ + colorspace->rendering_intent = (png_uint_16)intent; + colorspace->flags |= PNG_COLORSPACE_HAVE_INTENT; + + /* endpoints */ + colorspace->end_points_xy = sRGB_xy; + colorspace->end_points_XYZ = sRGB_XYZ; + colorspace->flags |= + (PNG_COLORSPACE_HAVE_ENDPOINTS|PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB); + + /* gamma */ + colorspace->gamma = PNG_GAMMA_sRGB_INVERSE; + colorspace->flags |= PNG_COLORSPACE_HAVE_GAMMA; + + /* Finally record that we have an sRGB profile */ + colorspace->flags |= + (PNG_COLORSPACE_MATCHES_sRGB|PNG_COLORSPACE_FROM_sRGB); + + return 1; /* set */ +} +#endif /* sRGB */ + +#ifdef PNG_iCCP_SUPPORTED +/* Encoded value of D50 as an ICC XYZNumber. From the ICC 2010 spec the value + * is XYZ(0.9642,1.0,0.8249), which scales to: + * + * (63189.8112, 65536, 54060.6464) + */ +static const png_byte D50_nCIEXYZ[12] = + { 0x00, 0x00, 0xf6, 0xd6, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xd3, 0x2d }; + +int /* PRIVATE */ +png_icc_check_length(png_const_structrp png_ptr, png_colorspacerp colorspace, + png_const_charp name, png_uint_32 profile_length) +{ + if (profile_length < 132) + return png_icc_profile_error(png_ptr, colorspace, name, profile_length, + "too short"); + + if (profile_length & 3) + return png_icc_profile_error(png_ptr, colorspace, name, profile_length, + "invalid length"); + + return 1; +} + +int /* PRIVATE */ +png_icc_check_header(png_const_structrp png_ptr, png_colorspacerp colorspace, + png_const_charp name, png_uint_32 profile_length, + png_const_bytep profile/* first 132 bytes only */, int color_type) +{ + png_uint_32 temp; + + /* Length check; this cannot be ignored in this code because profile_length + * is used later to check the tag table, so even if the profile seems over + * long profile_length from the caller must be correct. The caller can fix + * this up on read or write by just passing in the profile header length. + */ + temp = png_get_uint_32(profile); + if (temp != profile_length) + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "length does not match profile"); + + temp = png_get_uint_32(profile+128); /* tag count: 12 bytes/tag */ + if (temp > 357913930 || /* (2^32-4-132)/12: maximum possible tag count */ + profile_length < 132+12*temp) /* truncated tag table */ + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "tag count too large"); + + /* The 'intent' must be valid or we can't store it, ICC limits the intent to + * 16 bits. + */ + temp = png_get_uint_32(profile+64); + if (temp >= 0xffff) /* The ICC limit */ + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "invalid rendering intent"); + + /* This is just a warning because the profile may be valid in future + * versions. + */ + if (temp >= PNG_sRGB_INTENT_LAST) + (void)png_icc_profile_error(png_ptr, NULL, name, temp, + "intent outside defined range"); + + /* At this point the tag table can't be checked because it hasn't necessarily + * been loaded; however, various header fields can be checked. These checks + * are for values permitted by the PNG spec in an ICC profile; the PNG spec + * restricts the profiles that can be passed in an iCCP chunk (they must be + * appropriate to processing PNG data!) + */ + + /* Data checks (could be skipped). These checks must be independent of the + * version number; however, the version number doesn't accomodate changes in + * the header fields (just the known tags and the interpretation of the + * data.) + */ + temp = png_get_uint_32(profile+36); /* signature 'ascp' */ + if (temp != 0x61637370) + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "invalid signature"); + + /* Currently the PCS illuminant/adopted white point (the computational + * white point) are required to be D50, + * however the profile contains a record of the illuminant so perhaps ICC + * expects to be able to change this in the future (despite the rationale in + * the introduction for using a fixed PCS adopted white.) Consequently the + * following is just a warning. + */ + if (memcmp(profile+68, D50_nCIEXYZ, 12) != 0) + (void)png_icc_profile_error(png_ptr, NULL, name, 0/*no tag value*/, + "PCS illuminant is not D50"); + + /* The PNG spec requires this: + * "If the iCCP chunk is present, the image samples conform to the colour + * space represented by the embedded ICC profile as defined by the + * International Color Consortium [ICC]. The colour space of the ICC profile + * shall be an RGB colour space for colour images (PNG colour types 2, 3, and + * 6), or a greyscale colour space for greyscale images (PNG colour types 0 + * and 4)." + * + * This checking code ensures the embedded profile (on either read or write) + * conforms to the specification requirements. Notice that an ICC 'gray' + * color-space profile contains the information to transform the monochrome + * data to XYZ or L*a*b (according to which PCS the profile uses) and this + * should be used in preference to the standard libpng K channel replication + * into R, G and B channels. + * + * Previously it was suggested that an RGB profile on grayscale data could be + * handled. However it it is clear that using an RGB profile in this context + * must be an error - there is no specification of what it means. Thus it is + * almost certainly more correct to ignore the profile. + */ + temp = png_get_uint_32(profile+16); /* data colour space field */ + switch (temp) + { + case 0x52474220: /* 'RGB ' */ + if (!(color_type & PNG_COLOR_MASK_COLOR)) + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "RGB color space not permitted on grayscale PNG"); + break; + + case 0x47524159: /* 'GRAY' */ + if (color_type & PNG_COLOR_MASK_COLOR) + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "Gray color space not permitted on RGB PNG"); + break; + + default: + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "invalid ICC profile color space"); + } + + /* It is up to the application to check that the profile class matches the + * application requirements; the spec provides no guidance, but it's pretty + * weird if the profile is not scanner ('scnr'), monitor ('mntr'), printer + * ('prtr') or 'spac' (for generic color spaces). Issue a warning in these + * cases. Issue an error for device link or abstract profiles - these don't + * contain the records necessary to transform the color-space to anything + * other than the target device (and not even that for an abstract profile). + * Profiles of these classes may not be embedded in images. + */ + temp = png_get_uint_32(profile+12); /* profile/device class */ + switch (temp) + { + case 0x73636E72: /* 'scnr' */ + case 0x6D6E7472: /* 'mntr' */ + case 0x70727472: /* 'prtr' */ + case 0x73706163: /* 'spac' */ + /* All supported */ + break; + + case 0x61627374: /* 'abst' */ + /* May not be embedded in an image */ + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "invalid embedded Abstract ICC profile"); + + case 0x6C696E6B: /* 'link' */ + /* DeviceLink profiles cannnot be interpreted in a non-device specific + * fashion, if an app uses the AToB0Tag in the profile the results are + * undefined unless the result is sent to the intended device, + * therefore a DeviceLink profile should not be found embedded in a + * PNG. + */ + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "unexpected DeviceLink ICC profile class"); + + case 0x6E6D636C: /* 'nmcl' */ + /* A NamedColor profile is also device specific, however it doesn't + * contain an AToB0 tag that is open to misintrepretation. Almost + * certainly it will fail the tests below. + */ + (void)png_icc_profile_error(png_ptr, NULL, name, temp, + "unexpected NamedColor ICC profile class"); + break; + + default: + /* To allow for future enhancements to the profile accept unrecognized + * profile classes with a warning, these then hit the test below on the + * tag content to ensure they are backward compatible with one of the + * understood profiles. + */ + (void)png_icc_profile_error(png_ptr, NULL, name, temp, + "unrecognized ICC profile class"); + break; + } + + /* For any profile other than a device link one the PCS must be encoded + * either in XYZ or Lab. + */ + temp = png_get_uint_32(profile+20); + switch (temp) + { + case 0x58595A20: /* 'XYZ ' */ + case 0x4C616220: /* 'Lab ' */ + break; + + default: + return png_icc_profile_error(png_ptr, colorspace, name, temp, + "unexpected ICC PCS encoding"); + } + + return 1; +} + +int /* PRIVATE */ +png_icc_check_tag_table(png_const_structrp png_ptr, png_colorspacerp colorspace, + png_const_charp name, png_uint_32 profile_length, + png_const_bytep profile /* header plus whole tag table */) +{ + png_uint_32 tag_count = png_get_uint_32(profile+128); + png_uint_32 itag; + png_const_bytep tag = profile+132; /* The first tag */ + + /* First scan all the tags in the table and add bits to the icc_info value + * (temporarily in 'tags'). + */ + for (itag=0; itag < tag_count; ++itag, tag += 12) + { + png_uint_32 tag_id = png_get_uint_32(tag+0); + png_uint_32 tag_start = png_get_uint_32(tag+4); /* must be aligned */ + png_uint_32 tag_length = png_get_uint_32(tag+8);/* not padded */ + + /* The ICC specification does not exclude zero length tags, therefore the + * start might actually be anywhere if there is no data, but this would be + * a clear abuse of the intent of the standard so the start is checked for + * being in range. All defined tag types have an 8 byte header - a 4 byte + * type signature then 0. + */ + if ((tag_start & 3) != 0) + { + /* CNHP730S.icc shipped with Microsoft Windows 64 violates this, it is + * only a warning here because libpng does not care about the + * alignment. + */ + (void)png_icc_profile_error(png_ptr, NULL, name, tag_id, + "ICC profile tag start not a multiple of 4"); + } + + /* This is a hard error; potentially it can cause read outside the + * profile. + */ + if (tag_start > profile_length || tag_length > profile_length - tag_start) + return png_icc_profile_error(png_ptr, colorspace, name, tag_id, + "ICC profile tag outside profile"); + } + + return 1; /* success, maybe with warnings */ +} + +#ifdef PNG_sRGB_SUPPORTED +/* Information about the known ICC sRGB profiles */ +static const struct +{ + png_uint_32 adler, crc, length; + png_uint_32 md5[4]; + png_byte have_md5; + png_byte is_broken; + png_uint_16 intent; + +# define PNG_MD5(a,b,c,d) { a, b, c, d }, (a!=0)||(b!=0)||(c!=0)||(d!=0) +# define PNG_ICC_CHECKSUM(adler, crc, md5, intent, broke, date, length, fname)\ + { adler, crc, length, md5, broke, intent }, + +} png_sRGB_checks[] = +{ + /* This data comes from contrib/tools/checksum-icc run on downloads of + * all four ICC sRGB profiles from www.color.org. + */ + /* adler32, crc32, MD5[4], intent, date, length, file-name */ + PNG_ICC_CHECKSUM(0x0a3fd9f6, 0x3b8772b9, + PNG_MD5(0x29f83dde, 0xaff255ae, 0x7842fae4, 0xca83390d), 0, 0, + "2009/03/27 21:36:31", 3048, "sRGB_IEC61966-2-1_black_scaled.icc") + + /* ICC sRGB v2 perceptual no black-compensation: */ + PNG_ICC_CHECKSUM(0x4909e5e1, 0x427ebb21, + PNG_MD5(0xc95bd637, 0xe95d8a3b, 0x0df38f99, 0xc1320389), 1, 0, + "2009/03/27 21:37:45", 3052, "sRGB_IEC61966-2-1_no_black_scaling.icc") + + PNG_ICC_CHECKSUM(0xfd2144a1, 0x306fd8ae, + PNG_MD5(0xfc663378, 0x37e2886b, 0xfd72e983, 0x8228f1b8), 0, 0, + "2009/08/10 17:28:01", 60988, "sRGB_v4_ICC_preference_displayclass.icc") + + /* ICC sRGB v4 perceptual */ + PNG_ICC_CHECKSUM(0x209c35d2, 0xbbef7812, + PNG_MD5(0x34562abf, 0x994ccd06, 0x6d2c5721, 0xd0d68c5d), 0, 0, + "2007/07/25 00:05:37", 60960, "sRGB_v4_ICC_preference.icc") + + /* The following profiles have no known MD5 checksum. If there is a match + * on the (empty) MD5 the other fields are used to attempt a match and + * a warning is produced. The first two of these profiles have a 'cprt' tag + * which suggests that they were also made by Hewlett Packard. + */ + PNG_ICC_CHECKSUM(0xa054d762, 0x5d5129ce, + PNG_MD5(0x00000000, 0x00000000, 0x00000000, 0x00000000), 1, 0, + "2004/07/21 18:57:42", 3024, "sRGB_IEC61966-2-1_noBPC.icc") + + /* This is a 'mntr' (display) profile with a mediaWhitePointTag that does not + * match the D50 PCS illuminant in the header (it is in fact the D65 values, + * so the white point is recorded as the un-adapted value.) The profiles + * below only differ in one byte - the intent - and are basically the same as + * the previous profile except for the mediaWhitePointTag error and a missing + * chromaticAdaptationTag. + */ + PNG_ICC_CHECKSUM(0xf784f3fb, 0x182ea552, + PNG_MD5(0x00000000, 0x00000000, 0x00000000, 0x00000000), 0, 1/*broken*/, + "1998/02/09 06:49:00", 3144, "HP-Microsoft sRGB v2 perceptual") + + PNG_ICC_CHECKSUM(0x0398f3fc, 0xf29e526d, + PNG_MD5(0x00000000, 0x00000000, 0x00000000, 0x00000000), 1, 1/*broken*/, + "1998/02/09 06:49:00", 3144, "HP-Microsoft sRGB v2 media-relative") +}; + +static int +png_compare_ICC_profile_with_sRGB(png_const_structrp png_ptr, + png_const_bytep profile, uLong adler) +{ + /* The quick check is to verify just the MD5 signature and trust the + * rest of the data. Because the profile has already been verified for + * correctness this is safe. png_colorspace_set_sRGB will check the 'intent' + * field too, so if the profile has been edited with an intent not defined + * by sRGB (but maybe defined by a later ICC specification) the read of + * the profile will fail at that point. + */ + png_uint_32 length = 0; + png_uint_32 intent = 0x10000; /* invalid */ +#if PNG_sRGB_PROFILE_CHECKS > 1 + uLong crc = 0; /* the value for 0 length data */ +#endif + unsigned int i; + + for (i=0; i < (sizeof png_sRGB_checks) / (sizeof png_sRGB_checks[0]); ++i) + { + if (png_get_uint_32(profile+84) == png_sRGB_checks[i].md5[0] && + png_get_uint_32(profile+88) == png_sRGB_checks[i].md5[1] && + png_get_uint_32(profile+92) == png_sRGB_checks[i].md5[2] && + png_get_uint_32(profile+96) == png_sRGB_checks[i].md5[3]) + { + /* This may be one of the old HP profiles without an MD5, in that + * case we can only use the length and Adler32 (note that these + * are not used by default if there is an MD5!) + */ +# if PNG_sRGB_PROFILE_CHECKS == 0 + if (png_sRGB_checks[i].have_md5) + return 1+png_sRGB_checks[i].is_broken; +# endif + + /* Profile is unsigned or more checks have been configured in. */ + if (length == 0) + { + length = png_get_uint_32(profile); + intent = png_get_uint_32(profile+64); + } + + /* Length *and* intent must match */ + if (length == png_sRGB_checks[i].length && + intent == png_sRGB_checks[i].intent) + { + /* Now calculate the adler32 if not done already. */ + if (adler == 0) + { + adler = adler32(0, NULL, 0); + adler = adler32(adler, profile, length); + } + + if (adler == png_sRGB_checks[i].adler) + { + /* These basic checks suggest that the data has not been + * modified, but if the check level is more than 1 perform + * our own crc32 checksum on the data. + */ +# if PNG_sRGB_PROFILE_CHECKS > 1 + if (crc == 0) + { + crc = crc32(0, NULL, 0); + crc = crc32(crc, profile, length); + } + + /* So this check must pass for the 'return' below to happen. + */ + if (crc == png_sRGB_checks[i].crc) +# endif + { + if (png_sRGB_checks[i].is_broken) + { + /* These profiles are known to have bad data that may cause + * problems if they are used, therefore attempt to + * discourage their use, skip the 'have_md5' warning below, + * which is made irrelevant by this error. + */ + png_chunk_report(png_ptr, "known incorrect sRGB profile", + PNG_CHUNK_ERROR); + } + + /* Warn that this being done; this isn't even an error since + * the profile is perfectly valid, but it would be nice if + * people used the up-to-date ones. + */ + else if (!png_sRGB_checks[i].have_md5) + { + png_chunk_report(png_ptr, + "out-of-date sRGB profile with no signature", + PNG_CHUNK_WARNING); + } + + return 1+png_sRGB_checks[i].is_broken; + } + } + } + +# if PNG_sRGB_PROFILE_CHECKS > 0 + /* The signature matched, but the profile had been changed in some + * way. This is an apparent violation of the ICC terms of use and, + * anyway, probably indicates a data error or uninformed hacking. + */ + if (png_sRGB_checks[i].have_md5) + png_benign_error(png_ptr, + "copyright violation: edited ICC profile ignored"); +# endif + } + } + + return 0; /* no match */ +} +#endif + +#ifdef PNG_sRGB_SUPPORTED +void /* PRIVATE */ +png_icc_set_sRGB(png_const_structrp png_ptr, + png_colorspacerp colorspace, png_const_bytep profile, uLong adler) +{ + /* Is this profile one of the known ICC sRGB profiles? If it is, just set + * the sRGB information. + */ + if (png_compare_ICC_profile_with_sRGB(png_ptr, profile, adler)) + (void)png_colorspace_set_sRGB(png_ptr, colorspace, + (int)/*already checked*/png_get_uint_32(profile+64)); +} +#endif /* PNG_READ_sRGB_SUPPORTED */ + +int /* PRIVATE */ +png_colorspace_set_ICC(png_const_structrp png_ptr, png_colorspacerp colorspace, + png_const_charp name, png_uint_32 profile_length, png_const_bytep profile, + int color_type) +{ + if (colorspace->flags & PNG_COLORSPACE_INVALID) + return 0; + + if (png_icc_check_length(png_ptr, colorspace, name, profile_length) && + png_icc_check_header(png_ptr, colorspace, name, profile_length, profile, + color_type) && + png_icc_check_tag_table(png_ptr, colorspace, name, profile_length, + profile)) + { +# ifdef PNG_sRGB_SUPPORTED + /* If no sRGB support, don't try storing sRGB information */ + png_icc_set_sRGB(png_ptr, colorspace, profile, 0); +# endif + return 1; + } + + /* Failure case */ + return 0; +} +#endif /* iCCP */ + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED +void /* PRIVATE */ +png_colorspace_set_rgb_coefficients(png_structrp png_ptr) +{ + /* Set the rgb_to_gray coefficients from the colorspace. */ + if (!png_ptr->rgb_to_gray_coefficients_set && + (png_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS) != 0) + { + /* png_set_background has not been called, get the coefficients from the Y + * values of the colorspace colorants. + */ + png_fixed_point r = png_ptr->colorspace.end_points_XYZ.red_Y; + png_fixed_point g = png_ptr->colorspace.end_points_XYZ.green_Y; + png_fixed_point b = png_ptr->colorspace.end_points_XYZ.blue_Y; + png_fixed_point total = r+g+b; + + if (total > 0 && + r >= 0 && png_muldiv(&r, r, 32768, total) && r >= 0 && r <= 32768 && + g >= 0 && png_muldiv(&g, g, 32768, total) && g >= 0 && g <= 32768 && + b >= 0 && png_muldiv(&b, b, 32768, total) && b >= 0 && b <= 32768 && + r+g+b <= 32769) + { + /* We allow 0 coefficients here. r+g+b may be 32769 if two or + * all of the coefficients were rounded up. Handle this by + * reducing the *largest* coefficient by 1; this matches the + * approach used for the default coefficients in pngrtran.c + */ + int add = 0; + + if (r+g+b > 32768) + add = -1; + else if (r+g+b < 32768) + add = 1; + + if (add != 0) + { + if (g >= r && g >= b) + g += add; + else if (r >= g && r >= b) + r += add; + else + b += add; + } + + /* Check for an internal error. */ + if (r+g+b != 32768) + png_error(png_ptr, + "internal error handling cHRM coefficients"); + + else + { + png_ptr->rgb_to_gray_red_coeff = (png_uint_16)r; + png_ptr->rgb_to_gray_green_coeff = (png_uint_16)g; + } + } + + /* This is a png_error at present even though it could be ignored - + * it should never happen, but it is important that if it does, the + * bug is fixed. + */ + else + png_error(png_ptr, "internal error handling cHRM->XYZ"); + } +} +#endif + +#endif /* COLORSPACE */ + +void /* PRIVATE */ +png_check_IHDR(png_const_structrp png_ptr, + png_uint_32 width, png_uint_32 height, int bit_depth, + int color_type, int interlace_type, int compression_type, + int filter_type) +{ + int error = 0; + + /* Check for width and height valid values */ + if (width == 0) + { + png_warning(png_ptr, "Image width is zero in IHDR"); + error = 1; + } + + if (height == 0) + { + png_warning(png_ptr, "Image height is zero in IHDR"); + error = 1; + } + +# ifdef PNG_SET_USER_LIMITS_SUPPORTED + if (width > png_ptr->user_width_max) + +# else + if (width > PNG_USER_WIDTH_MAX) +# endif + { + png_warning(png_ptr, "Image width exceeds user limit in IHDR"); + error = 1; + } + +# ifdef PNG_SET_USER_LIMITS_SUPPORTED + if (height > png_ptr->user_height_max) +# else + if (height > PNG_USER_HEIGHT_MAX) +# endif + { + png_warning(png_ptr, "Image height exceeds user limit in IHDR"); + error = 1; + } + + if (width > PNG_UINT_31_MAX) + { + png_warning(png_ptr, "Invalid image width in IHDR"); + error = 1; + } + + if (height > PNG_UINT_31_MAX) + { + png_warning(png_ptr, "Invalid image height in IHDR"); + error = 1; + } + + if (width > (PNG_UINT_32_MAX + >> 3) /* 8-byte RGBA pixels */ + - 48 /* bigrowbuf hack */ + - 1 /* filter byte */ + - 7*8 /* rounding of width to multiple of 8 pixels */ + - 8) /* extra max_pixel_depth pad */ + png_warning(png_ptr, "Width is too large for libpng to process pixels"); + + /* Check other values */ + if (bit_depth != 1 && bit_depth != 2 && bit_depth != 4 && + bit_depth != 8 && bit_depth != 16) + { + png_warning(png_ptr, "Invalid bit depth in IHDR"); + error = 1; + } + + if (color_type < 0 || color_type == 1 || + color_type == 5 || color_type > 6) + { + png_warning(png_ptr, "Invalid color type in IHDR"); + error = 1; + } + + if (((color_type == PNG_COLOR_TYPE_PALETTE) && bit_depth > 8) || + ((color_type == PNG_COLOR_TYPE_RGB || + color_type == PNG_COLOR_TYPE_GRAY_ALPHA || + color_type == PNG_COLOR_TYPE_RGB_ALPHA) && bit_depth < 8)) + { + png_warning(png_ptr, "Invalid color type/bit depth combination in IHDR"); + error = 1; + } + + if (interlace_type >= PNG_INTERLACE_LAST) + { + png_warning(png_ptr, "Unknown interlace method in IHDR"); + error = 1; + } + + if (compression_type != PNG_COMPRESSION_TYPE_BASE) + { + png_warning(png_ptr, "Unknown compression method in IHDR"); + error = 1; + } + +# ifdef PNG_MNG_FEATURES_SUPPORTED + /* Accept filter_method 64 (intrapixel differencing) only if + * 1. Libpng was compiled with PNG_MNG_FEATURES_SUPPORTED and + * 2. Libpng did not read a PNG signature (this filter_method is only + * used in PNG datastreams that are embedded in MNG datastreams) and + * 3. The application called png_permit_mng_features with a mask that + * included PNG_FLAG_MNG_FILTER_64 and + * 4. The filter_method is 64 and + * 5. The color_type is RGB or RGBA + */ + if ((png_ptr->mode & PNG_HAVE_PNG_SIGNATURE) && + png_ptr->mng_features_permitted) + png_warning(png_ptr, "MNG features are not allowed in a PNG datastream"); + + if (filter_type != PNG_FILTER_TYPE_BASE) + { + if (!((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) && + (filter_type == PNG_INTRAPIXEL_DIFFERENCING) && + ((png_ptr->mode & PNG_HAVE_PNG_SIGNATURE) == 0) && + (color_type == PNG_COLOR_TYPE_RGB || + color_type == PNG_COLOR_TYPE_RGB_ALPHA))) + { + png_warning(png_ptr, "Unknown filter method in IHDR"); + error = 1; + } + + if (png_ptr->mode & PNG_HAVE_PNG_SIGNATURE) + { + png_warning(png_ptr, "Invalid filter method in IHDR"); + error = 1; + } + } + +# else + if (filter_type != PNG_FILTER_TYPE_BASE) + { + png_warning(png_ptr, "Unknown filter method in IHDR"); + error = 1; + } +# endif + + if (error == 1) + png_error(png_ptr, "Invalid IHDR data"); +} + +#if defined(PNG_sCAL_SUPPORTED) || defined(PNG_pCAL_SUPPORTED) +/* ASCII to fp functions */ +/* Check an ASCII formated floating point value, see the more detailed + * comments in pngpriv.h + */ +/* The following is used internally to preserve the sticky flags */ +#define png_fp_add(state, flags) ((state) |= (flags)) +#define png_fp_set(state, value) ((state) = (value) | ((state) & PNG_FP_STICKY)) + +int /* PRIVATE */ +png_check_fp_number(png_const_charp string, png_size_t size, int *statep, + png_size_tp whereami) +{ + int state = *statep; + png_size_t i = *whereami; + + while (i < size) + { + int type; + /* First find the type of the next character */ + switch (string[i]) + { + case 43: type = PNG_FP_SAW_SIGN; break; + case 45: type = PNG_FP_SAW_SIGN + PNG_FP_NEGATIVE; break; + case 46: type = PNG_FP_SAW_DOT; break; + case 48: type = PNG_FP_SAW_DIGIT; break; + case 49: case 50: case 51: case 52: + case 53: case 54: case 55: case 56: + case 57: type = PNG_FP_SAW_DIGIT + PNG_FP_NONZERO; break; + case 69: + case 101: type = PNG_FP_SAW_E; break; + default: goto PNG_FP_End; + } + + /* Now deal with this type according to the current + * state, the type is arranged to not overlap the + * bits of the PNG_FP_STATE. + */ + switch ((state & PNG_FP_STATE) + (type & PNG_FP_SAW_ANY)) + { + case PNG_FP_INTEGER + PNG_FP_SAW_SIGN: + if (state & PNG_FP_SAW_ANY) + goto PNG_FP_End; /* not a part of the number */ + + png_fp_add(state, type); + break; + + case PNG_FP_INTEGER + PNG_FP_SAW_DOT: + /* Ok as trailer, ok as lead of fraction. */ + if (state & PNG_FP_SAW_DOT) /* two dots */ + goto PNG_FP_End; + + else if (state & PNG_FP_SAW_DIGIT) /* trailing dot? */ + png_fp_add(state, type); + + else + png_fp_set(state, PNG_FP_FRACTION | type); + + break; + + case PNG_FP_INTEGER + PNG_FP_SAW_DIGIT: + if (state & PNG_FP_SAW_DOT) /* delayed fraction */ + png_fp_set(state, PNG_FP_FRACTION | PNG_FP_SAW_DOT); + + png_fp_add(state, type | PNG_FP_WAS_VALID); + + break; + + case PNG_FP_INTEGER + PNG_FP_SAW_E: + if ((state & PNG_FP_SAW_DIGIT) == 0) + goto PNG_FP_End; + + png_fp_set(state, PNG_FP_EXPONENT); + + break; + + /* case PNG_FP_FRACTION + PNG_FP_SAW_SIGN: + goto PNG_FP_End; ** no sign in fraction */ + + /* case PNG_FP_FRACTION + PNG_FP_SAW_DOT: + goto PNG_FP_End; ** Because SAW_DOT is always set */ + + case PNG_FP_FRACTION + PNG_FP_SAW_DIGIT: + png_fp_add(state, type | PNG_FP_WAS_VALID); + break; + + case PNG_FP_FRACTION + PNG_FP_SAW_E: + /* This is correct because the trailing '.' on an + * integer is handled above - so we can only get here + * with the sequence ".E" (with no preceding digits). + */ + if ((state & PNG_FP_SAW_DIGIT) == 0) + goto PNG_FP_End; + + png_fp_set(state, PNG_FP_EXPONENT); + + break; + + case PNG_FP_EXPONENT + PNG_FP_SAW_SIGN: + if (state & PNG_FP_SAW_ANY) + goto PNG_FP_End; /* not a part of the number */ + + png_fp_add(state, PNG_FP_SAW_SIGN); + + break; + + /* case PNG_FP_EXPONENT + PNG_FP_SAW_DOT: + goto PNG_FP_End; */ + + case PNG_FP_EXPONENT + PNG_FP_SAW_DIGIT: + png_fp_add(state, PNG_FP_SAW_DIGIT | PNG_FP_WAS_VALID); + + break; + + /* case PNG_FP_EXPONEXT + PNG_FP_SAW_E: + goto PNG_FP_End; */ + + default: goto PNG_FP_End; /* I.e. break 2 */ + } + + /* The character seems ok, continue. */ + ++i; + } + +PNG_FP_End: + /* Here at the end, update the state and return the correct + * return code. + */ + *statep = state; + *whereami = i; + + return (state & PNG_FP_SAW_DIGIT) != 0; +} + + +/* The same but for a complete string. */ +int +png_check_fp_string(png_const_charp string, png_size_t size) +{ + int state=0; + png_size_t char_index=0; + + if (png_check_fp_number(string, size, &state, &char_index) && + (char_index == size || string[char_index] == 0)) + return state /* must be non-zero - see above */; + + return 0; /* i.e. fail */ +} +#endif /* pCAL or sCAL */ + +#ifdef PNG_sCAL_SUPPORTED +# ifdef PNG_FLOATING_POINT_SUPPORTED +/* Utility used below - a simple accurate power of ten from an integral + * exponent. + */ +static double +png_pow10(int power) +{ + int recip = 0; + double d = 1; + + /* Handle negative exponent with a reciprocal at the end because + * 10 is exact whereas .1 is inexact in base 2 + */ + if (power < 0) + { + if (power < DBL_MIN_10_EXP) return 0; + recip = 1, power = -power; + } + + if (power > 0) + { + /* Decompose power bitwise. */ + double mult = 10; + do + { + if (power & 1) d *= mult; + mult *= mult; + power >>= 1; + } + while (power > 0); + + if (recip) d = 1/d; + } + /* else power is 0 and d is 1 */ + + return d; +} + +/* Function to format a floating point value in ASCII with a given + * precision. + */ +void /* PRIVATE */ +png_ascii_from_fp(png_const_structrp png_ptr, png_charp ascii, png_size_t size, + double fp, unsigned int precision) +{ + /* We use standard functions from math.h, but not printf because + * that would require stdio. The caller must supply a buffer of + * sufficient size or we will png_error. The tests on size and + * the space in ascii[] consumed are indicated below. + */ + if (precision < 1) + precision = DBL_DIG; + + /* Enforce the limit of the implementation precision too. */ + if (precision > DBL_DIG+1) + precision = DBL_DIG+1; + + /* Basic sanity checks */ + if (size >= precision+5) /* See the requirements below. */ + { + if (fp < 0) + { + fp = -fp; + *ascii++ = 45; /* '-' PLUS 1 TOTAL 1 */ + --size; + } + + if (fp >= DBL_MIN && fp <= DBL_MAX) + { + int exp_b10; /* A base 10 exponent */ + double base; /* 10^exp_b10 */ + + /* First extract a base 10 exponent of the number, + * the calculation below rounds down when converting + * from base 2 to base 10 (multiply by log10(2) - + * 0.3010, but 77/256 is 0.3008, so exp_b10 needs to + * be increased. Note that the arithmetic shift + * performs a floor() unlike C arithmetic - using a + * C multiply would break the following for negative + * exponents. + */ + (void)frexp(fp, &exp_b10); /* exponent to base 2 */ + + exp_b10 = (exp_b10 * 77) >> 8; /* <= exponent to base 10 */ + + /* Avoid underflow here. */ + base = png_pow10(exp_b10); /* May underflow */ + + while (base < DBL_MIN || base < fp) + { + /* And this may overflow. */ + double test = png_pow10(exp_b10+1); + + if (test <= DBL_MAX) + ++exp_b10, base = test; + + else + break; + } + + /* Normalize fp and correct exp_b10, after this fp is in the + * range [.1,1) and exp_b10 is both the exponent and the digit + * *before* which the decimal point should be inserted + * (starting with 0 for the first digit). Note that this + * works even if 10^exp_b10 is out of range because of the + * test on DBL_MAX above. + */ + fp /= base; + while (fp >= 1) fp /= 10, ++exp_b10; + + /* Because of the code above fp may, at this point, be + * less than .1, this is ok because the code below can + * handle the leading zeros this generates, so no attempt + * is made to correct that here. + */ + + { + int czero, clead, cdigits; + char exponent[10]; + + /* Allow up to two leading zeros - this will not lengthen + * the number compared to using E-n. + */ + if (exp_b10 < 0 && exp_b10 > -3) /* PLUS 3 TOTAL 4 */ + { + czero = -exp_b10; /* PLUS 2 digits: TOTAL 3 */ + exp_b10 = 0; /* Dot added below before first output. */ + } + else + czero = 0; /* No zeros to add */ + + /* Generate the digit list, stripping trailing zeros and + * inserting a '.' before a digit if the exponent is 0. + */ + clead = czero; /* Count of leading zeros */ + cdigits = 0; /* Count of digits in list. */ + + do + { + double d; + + fp *= 10; + /* Use modf here, not floor and subtract, so that + * the separation is done in one step. At the end + * of the loop don't break the number into parts so + * that the final digit is rounded. + */ + if (cdigits+czero-clead+1 < (int)precision) + fp = modf(fp, &d); + + else + { + d = floor(fp + .5); + + if (d > 9) + { + /* Rounding up to 10, handle that here. */ + if (czero > 0) + { + --czero, d = 1; + if (cdigits == 0) --clead; + } + else + { + while (cdigits > 0 && d > 9) + { + int ch = *--ascii; + + if (exp_b10 != (-1)) + ++exp_b10; + + else if (ch == 46) + { + ch = *--ascii, ++size; + /* Advance exp_b10 to '1', so that the + * decimal point happens after the + * previous digit. + */ + exp_b10 = 1; + } + + --cdigits; + d = ch - 47; /* I.e. 1+(ch-48) */ + } + + /* Did we reach the beginning? If so adjust the + * exponent but take into account the leading + * decimal point. + */ + if (d > 9) /* cdigits == 0 */ + { + if (exp_b10 == (-1)) + { + /* Leading decimal point (plus zeros?), if + * we lose the decimal point here it must + * be reentered below. + */ + int ch = *--ascii; + + if (ch == 46) + ++size, exp_b10 = 1; + + /* Else lost a leading zero, so 'exp_b10' is + * still ok at (-1) + */ + } + else + ++exp_b10; + + /* In all cases we output a '1' */ + d = 1; + } + } + } + fp = 0; /* Guarantees termination below. */ + } + + if (d == 0) + { + ++czero; + if (cdigits == 0) ++clead; + } + else + { + /* Included embedded zeros in the digit count. */ + cdigits += czero - clead; + clead = 0; + + while (czero > 0) + { + /* exp_b10 == (-1) means we just output the decimal + * place - after the DP don't adjust 'exp_b10' any + * more! + */ + if (exp_b10 != (-1)) + { + if (exp_b10 == 0) *ascii++ = 46, --size; + /* PLUS 1: TOTAL 4 */ + --exp_b10; + } + *ascii++ = 48, --czero; + } + + if (exp_b10 != (-1)) + { + if (exp_b10 == 0) *ascii++ = 46, --size; /* counted + above */ + --exp_b10; + } + *ascii++ = (char)(48 + (int)d), ++cdigits; + } + } + while (cdigits+czero-clead < (int)precision && fp > DBL_MIN); + + /* The total output count (max) is now 4+precision */ + + /* Check for an exponent, if we don't need one we are + * done and just need to terminate the string. At + * this point exp_b10==(-1) is effectively if flag - it got + * to '-1' because of the decrement after outputing + * the decimal point above (the exponent required is + * *not* -1!) + */ + if (exp_b10 >= (-1) && exp_b10 <= 2) + { + /* The following only happens if we didn't output the + * leading zeros above for negative exponent, so this + * doest add to the digit requirement. Note that the + * two zeros here can only be output if the two leading + * zeros were *not* output, so this doesn't increase + * the output count. + */ + while (--exp_b10 >= 0) *ascii++ = 48; + + *ascii = 0; + + /* Total buffer requirement (including the '\0') is + * 5+precision - see check at the start. + */ + return; + } + + /* Here if an exponent is required, adjust size for + * the digits we output but did not count. The total + * digit output here so far is at most 1+precision - no + * decimal point and no leading or trailing zeros have + * been output. + */ + size -= cdigits; + + *ascii++ = 69, --size; /* 'E': PLUS 1 TOTAL 2+precision */ + + /* The following use of an unsigned temporary avoids ambiguities in + * the signed arithmetic on exp_b10 and permits GCC at least to do + * better optimization. + */ + { + unsigned int uexp_b10; + + if (exp_b10 < 0) + { + *ascii++ = 45, --size; /* '-': PLUS 1 TOTAL 3+precision */ + uexp_b10 = -exp_b10; + } + + else + uexp_b10 = exp_b10; + + cdigits = 0; + + while (uexp_b10 > 0) + { + exponent[cdigits++] = (char)(48 + uexp_b10 % 10); + uexp_b10 /= 10; + } + } + + /* Need another size check here for the exponent digits, so + * this need not be considered above. + */ + if ((int)size > cdigits) + { + while (cdigits > 0) *ascii++ = exponent[--cdigits]; + + *ascii = 0; + + return; + } + } + } + else if (!(fp >= DBL_MIN)) + { + *ascii++ = 48; /* '0' */ + *ascii = 0; + return; + } + else + { + *ascii++ = 105; /* 'i' */ + *ascii++ = 110; /* 'n' */ + *ascii++ = 102; /* 'f' */ + *ascii = 0; + return; + } + } + + /* Here on buffer too small. */ + png_error(png_ptr, "ASCII conversion buffer too small"); +} + +# endif /* FLOATING_POINT */ + +# ifdef PNG_FIXED_POINT_SUPPORTED +/* Function to format a fixed point value in ASCII. + */ +void /* PRIVATE */ +png_ascii_from_fixed(png_const_structrp png_ptr, png_charp ascii, + png_size_t size, png_fixed_point fp) +{ + /* Require space for 10 decimal digits, a decimal point, a minus sign and a + * trailing \0, 13 characters: + */ + if (size > 12) + { + png_uint_32 num; + + /* Avoid overflow here on the minimum integer. */ + if (fp < 0) + *ascii++ = 45, --size, num = -fp; + else + num = fp; + + if (num <= 0x80000000) /* else overflowed */ + { + unsigned int ndigits = 0, first = 16 /* flag value */; + char digits[10]; + + while (num) + { + /* Split the low digit off num: */ + unsigned int tmp = num/10; + num -= tmp*10; + digits[ndigits++] = (char)(48 + num); + /* Record the first non-zero digit, note that this is a number + * starting at 1, it's not actually the array index. + */ + if (first == 16 && num > 0) + first = ndigits; + num = tmp; + } + + if (ndigits > 0) + { + while (ndigits > 5) *ascii++ = digits[--ndigits]; + /* The remaining digits are fractional digits, ndigits is '5' or + * smaller at this point. It is certainly not zero. Check for a + * non-zero fractional digit: + */ + if (first <= 5) + { + unsigned int i; + *ascii++ = 46; /* decimal point */ + /* ndigits may be <5 for small numbers, output leading zeros + * then ndigits digits to first: + */ + i = 5; + while (ndigits < i) *ascii++ = 48, --i; + while (ndigits >= first) *ascii++ = digits[--ndigits]; + /* Don't output the trailing zeros! */ + } + } + else + *ascii++ = 48; + + /* And null terminate the string: */ + *ascii = 0; + return; + } + } + + /* Here on buffer too small. */ + png_error(png_ptr, "ASCII conversion buffer too small"); +} +# endif /* FIXED_POINT */ +#endif /* READ_SCAL */ + +#if defined(PNG_FLOATING_POINT_SUPPORTED) && \ + !defined(PNG_FIXED_POINT_MACRO_SUPPORTED) && \ + (defined(PNG_gAMA_SUPPORTED) || defined(PNG_cHRM_SUPPORTED) || \ + defined(PNG_sCAL_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) || \ + defined(PNG_READ_RGB_TO_GRAY_SUPPORTED)) || \ + (defined(PNG_sCAL_SUPPORTED) && \ + defined(PNG_FLOATING_ARITHMETIC_SUPPORTED)) +png_fixed_point +png_fixed(png_const_structrp png_ptr, double fp, png_const_charp text) +{ + double r = floor(100000 * fp + .5); + + if (r > 2147483647. || r < -2147483648.) + png_fixed_error(png_ptr, text); + + return (png_fixed_point)r; +} +#endif + +#if defined(PNG_READ_GAMMA_SUPPORTED) || \ + defined(PNG_INCH_CONVERSIONS_SUPPORTED) || defined(PNG_READ_pHYs_SUPPORTED) +/* muldiv functions */ +/* This API takes signed arguments and rounds the result to the nearest + * integer (or, for a fixed point number - the standard argument - to + * the nearest .00001). Overflow and divide by zero are signalled in + * the result, a boolean - true on success, false on overflow. + */ +int +png_muldiv(png_fixed_point_p res, png_fixed_point a, png_int_32 times, + png_int_32 divisor) +{ + /* Return a * times / divisor, rounded. */ + if (divisor != 0) + { + if (a == 0 || times == 0) + { + *res = 0; + return 1; + } + else + { +#ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED + double r = a; + r *= times; + r /= divisor; + r = floor(r+.5); + + /* A png_fixed_point is a 32-bit integer. */ + if (r <= 2147483647. && r >= -2147483648.) + { + *res = (png_fixed_point)r; + return 1; + } +#else + int negative = 0; + png_uint_32 A, T, D; + png_uint_32 s16, s32, s00; + + if (a < 0) + negative = 1, A = -a; + else + A = a; + + if (times < 0) + negative = !negative, T = -times; + else + T = times; + + if (divisor < 0) + negative = !negative, D = -divisor; + else + D = divisor; + + /* Following can't overflow because the arguments only + * have 31 bits each, however the result may be 32 bits. + */ + s16 = (A >> 16) * (T & 0xffff) + + (A & 0xffff) * (T >> 16); + /* Can't overflow because the a*times bit is only 30 + * bits at most. + */ + s32 = (A >> 16) * (T >> 16) + (s16 >> 16); + s00 = (A & 0xffff) * (T & 0xffff); + + s16 = (s16 & 0xffff) << 16; + s00 += s16; + + if (s00 < s16) + ++s32; /* carry */ + + if (s32 < D) /* else overflow */ + { + /* s32.s00 is now the 64-bit product, do a standard + * division, we know that s32 < D, so the maximum + * required shift is 31. + */ + int bitshift = 32; + png_fixed_point result = 0; /* NOTE: signed */ + + while (--bitshift >= 0) + { + png_uint_32 d32, d00; + + if (bitshift > 0) + d32 = D >> (32-bitshift), d00 = D << bitshift; + + else + d32 = 0, d00 = D; + + if (s32 > d32) + { + if (s00 < d00) --s32; /* carry */ + s32 -= d32, s00 -= d00, result += 1<= d00) + s32 = 0, s00 -= d00, result += 1<= (D >> 1)) + ++result; + + if (negative) + result = -result; + + /* Check for overflow. */ + if ((negative && result <= 0) || (!negative && result >= 0)) + { + *res = result; + return 1; + } + } +#endif + } + } + + return 0; +} +#endif /* READ_GAMMA || INCH_CONVERSIONS */ + +#if defined(PNG_READ_GAMMA_SUPPORTED) || defined(PNG_INCH_CONVERSIONS_SUPPORTED) +/* The following is for when the caller doesn't much care about the + * result. + */ +png_fixed_point +png_muldiv_warn(png_const_structrp png_ptr, png_fixed_point a, png_int_32 times, + png_int_32 divisor) +{ + png_fixed_point result; + + if (png_muldiv(&result, a, times, divisor)) + return result; + + png_warning(png_ptr, "fixed point overflow ignored"); + return 0; +} +#endif + +#ifdef PNG_GAMMA_SUPPORTED /* more fixed point functions for gamma */ +/* Calculate a reciprocal, return 0 on div-by-zero or overflow. */ +png_fixed_point +png_reciprocal(png_fixed_point a) +{ +#ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED + double r = floor(1E10/a+.5); + + if (r <= 2147483647. && r >= -2147483648.) + return (png_fixed_point)r; +#else + png_fixed_point res; + + if (png_muldiv(&res, 100000, 100000, a)) + return res; +#endif + + return 0; /* error/overflow */ +} + +/* This is the shared test on whether a gamma value is 'significant' - whether + * it is worth doing gamma correction. + */ +int /* PRIVATE */ +png_gamma_significant(png_fixed_point gamma_val) +{ + return gamma_val < PNG_FP_1 - PNG_GAMMA_THRESHOLD_FIXED || + gamma_val > PNG_FP_1 + PNG_GAMMA_THRESHOLD_FIXED; +} +#endif + +#ifdef PNG_READ_GAMMA_SUPPORTED +/* A local convenience routine. */ +static png_fixed_point +png_product2(png_fixed_point a, png_fixed_point b) +{ + /* The required result is 1/a * 1/b; the following preserves accuracy. */ +#ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED + double r = a * 1E-5; + r *= b; + r = floor(r+.5); + + if (r <= 2147483647. && r >= -2147483648.) + return (png_fixed_point)r; +#else + png_fixed_point res; + + if (png_muldiv(&res, a, b, 100000)) + return res; +#endif + + return 0; /* overflow */ +} + +/* The inverse of the above. */ +png_fixed_point +png_reciprocal2(png_fixed_point a, png_fixed_point b) +{ + /* The required result is 1/a * 1/b; the following preserves accuracy. */ +#ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED + double r = 1E15/a; + r /= b; + r = floor(r+.5); + + if (r <= 2147483647. && r >= -2147483648.) + return (png_fixed_point)r; +#else + /* This may overflow because the range of png_fixed_point isn't symmetric, + * but this API is only used for the product of file and screen gamma so it + * doesn't matter that the smallest number it can produce is 1/21474, not + * 1/100000 + */ + png_fixed_point res = png_product2(a, b); + + if (res != 0) + return png_reciprocal(res); +#endif + + return 0; /* overflow */ +} +#endif /* READ_GAMMA */ + +#ifdef PNG_READ_GAMMA_SUPPORTED /* gamma table code */ +#ifndef PNG_FLOATING_ARITHMETIC_SUPPORTED +/* Fixed point gamma. + * + * The code to calculate the tables used below can be found in the shell script + * contrib/tools/intgamma.sh + * + * To calculate gamma this code implements fast log() and exp() calls using only + * fixed point arithmetic. This code has sufficient precision for either 8-bit + * or 16-bit sample values. + * + * The tables used here were calculated using simple 'bc' programs, but C double + * precision floating point arithmetic would work fine. + * + * 8-bit log table + * This is a table of -log(value/255)/log(2) for 'value' in the range 128 to + * 255, so it's the base 2 logarithm of a normalized 8-bit floating point + * mantissa. The numbers are 32-bit fractions. + */ +static const png_uint_32 +png_8bit_l2[128] = +{ + 4270715492U, 4222494797U, 4174646467U, 4127164793U, 4080044201U, 4033279239U, + 3986864580U, 3940795015U, 3895065449U, 3849670902U, 3804606499U, 3759867474U, + 3715449162U, 3671346997U, 3627556511U, 3584073329U, 3540893168U, 3498011834U, + 3455425220U, 3413129301U, 3371120137U, 3329393864U, 3287946700U, 3246774933U, + 3205874930U, 3165243125U, 3124876025U, 3084770202U, 3044922296U, 3005329011U, + 2965987113U, 2926893432U, 2888044853U, 2849438323U, 2811070844U, 2772939474U, + 2735041326U, 2697373562U, 2659933400U, 2622718104U, 2585724991U, 2548951424U, + 2512394810U, 2476052606U, 2439922311U, 2404001468U, 2368287663U, 2332778523U, + 2297471715U, 2262364947U, 2227455964U, 2192742551U, 2158222529U, 2123893754U, + 2089754119U, 2055801552U, 2022034013U, 1988449497U, 1955046031U, 1921821672U, + 1888774511U, 1855902668U, 1823204291U, 1790677560U, 1758320682U, 1726131893U, + 1694109454U, 1662251657U, 1630556815U, 1599023271U, 1567649391U, 1536433567U, + 1505374214U, 1474469770U, 1443718700U, 1413119487U, 1382670639U, 1352370686U, + 1322218179U, 1292211689U, 1262349810U, 1232631153U, 1203054352U, 1173618059U, + 1144320946U, 1115161701U, 1086139034U, 1057251672U, 1028498358U, 999877854U, + 971388940U, 943030410U, 914801076U, 886699767U, 858725327U, 830876614U, + 803152505U, 775551890U, 748073672U, 720716771U, 693480120U, 666362667U, + 639363374U, 612481215U, 585715177U, 559064263U, 532527486U, 506103872U, + 479792461U, 453592303U, 427502463U, 401522014U, 375650043U, 349885648U, + 324227938U, 298676034U, 273229066U, 247886176U, 222646516U, 197509248U, + 172473545U, 147538590U, 122703574U, 97967701U, 73330182U, 48790236U, + 24347096U, 0U + +#if 0 + /* The following are the values for 16-bit tables - these work fine for the + * 8-bit conversions but produce very slightly larger errors in the 16-bit + * log (about 1.2 as opposed to 0.7 absolute error in the final value). To + * use these all the shifts below must be adjusted appropriately. + */ + 65166, 64430, 63700, 62976, 62257, 61543, 60835, 60132, 59434, 58741, 58054, + 57371, 56693, 56020, 55352, 54689, 54030, 53375, 52726, 52080, 51439, 50803, + 50170, 49542, 48918, 48298, 47682, 47070, 46462, 45858, 45257, 44661, 44068, + 43479, 42894, 42312, 41733, 41159, 40587, 40020, 39455, 38894, 38336, 37782, + 37230, 36682, 36137, 35595, 35057, 34521, 33988, 33459, 32932, 32408, 31887, + 31369, 30854, 30341, 29832, 29325, 28820, 28319, 27820, 27324, 26830, 26339, + 25850, 25364, 24880, 24399, 23920, 23444, 22970, 22499, 22029, 21562, 21098, + 20636, 20175, 19718, 19262, 18808, 18357, 17908, 17461, 17016, 16573, 16132, + 15694, 15257, 14822, 14390, 13959, 13530, 13103, 12678, 12255, 11834, 11415, + 10997, 10582, 10168, 9756, 9346, 8937, 8531, 8126, 7723, 7321, 6921, 6523, + 6127, 5732, 5339, 4947, 4557, 4169, 3782, 3397, 3014, 2632, 2251, 1872, 1495, + 1119, 744, 372 +#endif +}; + +static png_int_32 +png_log8bit(unsigned int x) +{ + unsigned int lg2 = 0; + /* Each time 'x' is multiplied by 2, 1 must be subtracted off the final log, + * because the log is actually negate that means adding 1. The final + * returned value thus has the range 0 (for 255 input) to 7.994 (for 1 + * input), return -1 for the overflow (log 0) case, - so the result is + * always at most 19 bits. + */ + if ((x &= 0xff) == 0) + return -1; + + if ((x & 0xf0) == 0) + lg2 = 4, x <<= 4; + + if ((x & 0xc0) == 0) + lg2 += 2, x <<= 2; + + if ((x & 0x80) == 0) + lg2 += 1, x <<= 1; + + /* result is at most 19 bits, so this cast is safe: */ + return (png_int_32)((lg2 << 16) + ((png_8bit_l2[x-128]+32768)>>16)); +} + +/* The above gives exact (to 16 binary places) log2 values for 8-bit images, + * for 16-bit images we use the most significant 8 bits of the 16-bit value to + * get an approximation then multiply the approximation by a correction factor + * determined by the remaining up to 8 bits. This requires an additional step + * in the 16-bit case. + * + * We want log2(value/65535), we have log2(v'/255), where: + * + * value = v' * 256 + v'' + * = v' * f + * + * So f is value/v', which is equal to (256+v''/v') since v' is in the range 128 + * to 255 and v'' is in the range 0 to 255 f will be in the range 256 to less + * than 258. The final factor also needs to correct for the fact that our 8-bit + * value is scaled by 255, whereas the 16-bit values must be scaled by 65535. + * + * This gives a final formula using a calculated value 'x' which is value/v' and + * scaling by 65536 to match the above table: + * + * log2(x/257) * 65536 + * + * Since these numbers are so close to '1' we can use simple linear + * interpolation between the two end values 256/257 (result -368.61) and 258/257 + * (result 367.179). The values used below are scaled by a further 64 to give + * 16-bit precision in the interpolation: + * + * Start (256): -23591 + * Zero (257): 0 + * End (258): 23499 + */ +static png_int_32 +png_log16bit(png_uint_32 x) +{ + unsigned int lg2 = 0; + + /* As above, but now the input has 16 bits. */ + if ((x &= 0xffff) == 0) + return -1; + + if ((x & 0xff00) == 0) + lg2 = 8, x <<= 8; + + if ((x & 0xf000) == 0) + lg2 += 4, x <<= 4; + + if ((x & 0xc000) == 0) + lg2 += 2, x <<= 2; + + if ((x & 0x8000) == 0) + lg2 += 1, x <<= 1; + + /* Calculate the base logarithm from the top 8 bits as a 28-bit fractional + * value. + */ + lg2 <<= 28; + lg2 += (png_8bit_l2[(x>>8)-128]+8) >> 4; + + /* Now we need to interpolate the factor, this requires a division by the top + * 8 bits. Do this with maximum precision. + */ + x = ((x << 16) + (x >> 9)) / (x >> 8); + + /* Since we divided by the top 8 bits of 'x' there will be a '1' at 1<<24, + * the value at 1<<16 (ignoring this) will be 0 or 1; this gives us exactly + * 16 bits to interpolate to get the low bits of the result. Round the + * answer. Note that the end point values are scaled by 64 to retain overall + * precision and that 'lg2' is current scaled by an extra 12 bits, so adjust + * the overall scaling by 6-12. Round at every step. + */ + x -= 1U << 24; + + if (x <= 65536U) /* <= '257' */ + lg2 += ((23591U * (65536U-x)) + (1U << (16+6-12-1))) >> (16+6-12); + + else + lg2 -= ((23499U * (x-65536U)) + (1U << (16+6-12-1))) >> (16+6-12); + + /* Safe, because the result can't have more than 20 bits: */ + return (png_int_32)((lg2 + 2048) >> 12); +} + +/* The 'exp()' case must invert the above, taking a 20-bit fixed point + * logarithmic value and returning a 16 or 8-bit number as appropriate. In + * each case only the low 16 bits are relevant - the fraction - since the + * integer bits (the top 4) simply determine a shift. + * + * The worst case is the 16-bit distinction between 65535 and 65534, this + * requires perhaps spurious accuracty in the decoding of the logarithm to + * distinguish log2(65535/65534.5) - 10^-5 or 17 bits. There is little chance + * of getting this accuracy in practice. + * + * To deal with this the following exp() function works out the exponent of the + * frational part of the logarithm by using an accurate 32-bit value from the + * top four fractional bits then multiplying in the remaining bits. + */ +static const png_uint_32 +png_32bit_exp[16] = +{ + /* NOTE: the first entry is deliberately set to the maximum 32-bit value. */ + 4294967295U, 4112874773U, 3938502376U, 3771522796U, 3611622603U, 3458501653U, + 3311872529U, 3171459999U, 3037000500U, 2908241642U, 2784941738U, 2666869345U, + 2553802834U, 2445529972U, 2341847524U, 2242560872U +}; + +/* Adjustment table; provided to explain the numbers in the code below. */ +#if 0 +for (i=11;i>=0;--i){ print i, " ", (1 - e(-(2^i)/65536*l(2))) * 2^(32-i), "\n"} + 11 44937.64284865548751208448 + 10 45180.98734845585101160448 + 9 45303.31936980687359311872 + 8 45364.65110595323018870784 + 7 45395.35850361789624614912 + 6 45410.72259715102037508096 + 5 45418.40724413220722311168 + 4 45422.25021786898173001728 + 3 45424.17186732298419044352 + 2 45425.13273269940811464704 + 1 45425.61317555035558641664 + 0 45425.85339951654943850496 +#endif + +static png_uint_32 +png_exp(png_fixed_point x) +{ + if (x > 0 && x <= 0xfffff) /* Else overflow or zero (underflow) */ + { + /* Obtain a 4-bit approximation */ + png_uint_32 e = png_32bit_exp[(x >> 12) & 0xf]; + + /* Incorporate the low 12 bits - these decrease the returned value by + * multiplying by a number less than 1 if the bit is set. The multiplier + * is determined by the above table and the shift. Notice that the values + * converge on 45426 and this is used to allow linear interpolation of the + * low bits. + */ + if (x & 0x800) + e -= (((e >> 16) * 44938U) + 16U) >> 5; + + if (x & 0x400) + e -= (((e >> 16) * 45181U) + 32U) >> 6; + + if (x & 0x200) + e -= (((e >> 16) * 45303U) + 64U) >> 7; + + if (x & 0x100) + e -= (((e >> 16) * 45365U) + 128U) >> 8; + + if (x & 0x080) + e -= (((e >> 16) * 45395U) + 256U) >> 9; + + if (x & 0x040) + e -= (((e >> 16) * 45410U) + 512U) >> 10; + + /* And handle the low 6 bits in a single block. */ + e -= (((e >> 16) * 355U * (x & 0x3fU)) + 256U) >> 9; + + /* Handle the upper bits of x. */ + e >>= x >> 16; + return e; + } + + /* Check for overflow */ + if (x <= 0) + return png_32bit_exp[0]; + + /* Else underflow */ + return 0; +} + +static png_byte +png_exp8bit(png_fixed_point lg2) +{ + /* Get a 32-bit value: */ + png_uint_32 x = png_exp(lg2); + + /* Convert the 32-bit value to 0..255 by multiplying by 256-1, note that the + * second, rounding, step can't overflow because of the first, subtraction, + * step. + */ + x -= x >> 8; + return (png_byte)((x + 0x7fffffU) >> 24); +} + +static png_uint_16 +png_exp16bit(png_fixed_point lg2) +{ + /* Get a 32-bit value: */ + png_uint_32 x = png_exp(lg2); + + /* Convert the 32-bit value to 0..65535 by multiplying by 65536-1: */ + x -= x >> 16; + return (png_uint_16)((x + 32767U) >> 16); +} +#endif /* FLOATING_ARITHMETIC */ + +png_byte +png_gamma_8bit_correct(unsigned int value, png_fixed_point gamma_val) +{ + if (value > 0 && value < 255) + { +# ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED + double r = floor(255*pow(value/255.,gamma_val*.00001)+.5); + return (png_byte)r; +# else + png_int_32 lg2 = png_log8bit(value); + png_fixed_point res; + + if (png_muldiv(&res, gamma_val, lg2, PNG_FP_1)) + return png_exp8bit(res); + + /* Overflow. */ + value = 0; +# endif + } + + return (png_byte)value; +} + +png_uint_16 +png_gamma_16bit_correct(unsigned int value, png_fixed_point gamma_val) +{ + if (value > 0 && value < 65535) + { +# ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED + double r = floor(65535*pow(value/65535.,gamma_val*.00001)+.5); + return (png_uint_16)r; +# else + png_int_32 lg2 = png_log16bit(value); + png_fixed_point res; + + if (png_muldiv(&res, gamma_val, lg2, PNG_FP_1)) + return png_exp16bit(res); + + /* Overflow. */ + value = 0; +# endif + } + + return (png_uint_16)value; +} + +/* This does the right thing based on the bit_depth field of the + * png_struct, interpreting values as 8-bit or 16-bit. While the result + * is nominally a 16-bit value if bit depth is 8 then the result is + * 8-bit (as are the arguments.) + */ +png_uint_16 /* PRIVATE */ +png_gamma_correct(png_structrp png_ptr, unsigned int value, + png_fixed_point gamma_val) +{ + if (png_ptr->bit_depth == 8) + return png_gamma_8bit_correct(value, gamma_val); + + else + return png_gamma_16bit_correct(value, gamma_val); +} + +/* Internal function to build a single 16-bit table - the table consists of + * 'num' 256 entry subtables, where 'num' is determined by 'shift' - the amount + * to shift the input values right (or 16-number_of_signifiant_bits). + * + * The caller is responsible for ensuring that the table gets cleaned up on + * png_error (i.e. if one of the mallocs below fails) - i.e. the *table argument + * should be somewhere that will be cleaned. + */ +static void +png_build_16bit_table(png_structrp png_ptr, png_uint_16pp *ptable, + PNG_CONST unsigned int shift, PNG_CONST png_fixed_point gamma_val) +{ + /* Various values derived from 'shift': */ + PNG_CONST unsigned int num = 1U << (8U - shift); + PNG_CONST unsigned int max = (1U << (16U - shift))-1U; + PNG_CONST unsigned int max_by_2 = 1U << (15U-shift); + unsigned int i; + + png_uint_16pp table = *ptable = + (png_uint_16pp)png_calloc(png_ptr, num * (sizeof (png_uint_16p))); + + for (i = 0; i < num; i++) + { + png_uint_16p sub_table = table[i] = + (png_uint_16p)png_malloc(png_ptr, 256 * (sizeof (png_uint_16))); + + /* The 'threshold' test is repeated here because it can arise for one of + * the 16-bit tables even if the others don't hit it. + */ + if (png_gamma_significant(gamma_val)) + { + /* The old code would overflow at the end and this would cause the + * 'pow' function to return a result >1, resulting in an + * arithmetic error. This code follows the spec exactly; ig is + * the recovered input sample, it always has 8-16 bits. + * + * We want input * 65535/max, rounded, the arithmetic fits in 32 + * bits (unsigned) so long as max <= 32767. + */ + unsigned int j; + for (j = 0; j < 256; j++) + { + png_uint_32 ig = (j << (8-shift)) + i; +# ifdef PNG_FLOATING_ARITHMETIC_SUPPORTED + /* Inline the 'max' scaling operation: */ + double d = floor(65535*pow(ig/(double)max, gamma_val*.00001)+.5); + sub_table[j] = (png_uint_16)d; +# else + if (shift) + ig = (ig * 65535U + max_by_2)/max; + + sub_table[j] = png_gamma_16bit_correct(ig, gamma_val); +# endif + } + } + else + { + /* We must still build a table, but do it the fast way. */ + unsigned int j; + + for (j = 0; j < 256; j++) + { + png_uint_32 ig = (j << (8-shift)) + i; + + if (shift) + ig = (ig * 65535U + max_by_2)/max; + + sub_table[j] = (png_uint_16)ig; + } + } + } +} + +/* NOTE: this function expects the *inverse* of the overall gamma transformation + * required. + */ +static void +png_build_16to8_table(png_structrp png_ptr, png_uint_16pp *ptable, + PNG_CONST unsigned int shift, PNG_CONST png_fixed_point gamma_val) +{ + PNG_CONST unsigned int num = 1U << (8U - shift); + PNG_CONST unsigned int max = (1U << (16U - shift))-1U; + unsigned int i; + png_uint_32 last; + + png_uint_16pp table = *ptable = + (png_uint_16pp)png_calloc(png_ptr, num * (sizeof (png_uint_16p))); + + /* 'num' is the number of tables and also the number of low bits of low + * bits of the input 16-bit value used to select a table. Each table is + * itself index by the high 8 bits of the value. + */ + for (i = 0; i < num; i++) + table[i] = (png_uint_16p)png_malloc(png_ptr, + 256 * (sizeof (png_uint_16))); + + /* 'gamma_val' is set to the reciprocal of the value calculated above, so + * pow(out,g) is an *input* value. 'last' is the last input value set. + * + * In the loop 'i' is used to find output values. Since the output is + * 8-bit there are only 256 possible values. The tables are set up to + * select the closest possible output value for each input by finding + * the input value at the boundary between each pair of output values + * and filling the table up to that boundary with the lower output + * value. + * + * The boundary values are 0.5,1.5..253.5,254.5. Since these are 9-bit + * values the code below uses a 16-bit value in i; the values start at + * 128.5 (for 0.5) and step by 257, for a total of 254 values (the last + * entries are filled with 255). Start i at 128 and fill all 'last' + * table entries <= 'max' + */ + last = 0; + for (i = 0; i < 255; ++i) /* 8-bit output value */ + { + /* Find the corresponding maximum input value */ + png_uint_16 out = (png_uint_16)(i * 257U); /* 16-bit output value */ + + /* Find the boundary value in 16 bits: */ + png_uint_32 bound = png_gamma_16bit_correct(out+128U, gamma_val); + + /* Adjust (round) to (16-shift) bits: */ + bound = (bound * max + 32768U)/65535U + 1U; + + while (last < bound) + { + table[last & (0xffU >> shift)][last >> (8U - shift)] = out; + last++; + } + } + + /* And fill in the final entries. */ + while (last < (num << 8)) + { + table[last & (0xff >> shift)][last >> (8U - shift)] = 65535U; + last++; + } +} + +/* Build a single 8-bit table: same as the 16-bit case but much simpler (and + * typically much faster). Note that libpng currently does no sBIT processing + * (apparently contrary to the spec) so a 256 entry table is always generated. + */ +static void +png_build_8bit_table(png_structrp png_ptr, png_bytepp ptable, + PNG_CONST png_fixed_point gamma_val) +{ + unsigned int i; + png_bytep table = *ptable = (png_bytep)png_malloc(png_ptr, 256); + + if (png_gamma_significant(gamma_val)) for (i=0; i<256; i++) + table[i] = png_gamma_8bit_correct(i, gamma_val); + + else for (i=0; i<256; ++i) + table[i] = (png_byte)i; +} + +/* Used from png_read_destroy and below to release the memory used by the gamma + * tables. + */ +void /* PRIVATE */ +png_destroy_gamma_table(png_structrp png_ptr) +{ + png_free(png_ptr, png_ptr->gamma_table); + png_ptr->gamma_table = NULL; + + if (png_ptr->gamma_16_table != NULL) + { + int i; + int istop = (1 << (8 - png_ptr->gamma_shift)); + for (i = 0; i < istop; i++) + { + png_free(png_ptr, png_ptr->gamma_16_table[i]); + } + png_free(png_ptr, png_ptr->gamma_16_table); + png_ptr->gamma_16_table = NULL; + } + +#if defined(PNG_READ_BACKGROUND_SUPPORTED) || \ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) || \ + defined(PNG_READ_RGB_TO_GRAY_SUPPORTED) + png_free(png_ptr, png_ptr->gamma_from_1); + png_ptr->gamma_from_1 = NULL; + png_free(png_ptr, png_ptr->gamma_to_1); + png_ptr->gamma_to_1 = NULL; + + if (png_ptr->gamma_16_from_1 != NULL) + { + int i; + int istop = (1 << (8 - png_ptr->gamma_shift)); + for (i = 0; i < istop; i++) + { + png_free(png_ptr, png_ptr->gamma_16_from_1[i]); + } + png_free(png_ptr, png_ptr->gamma_16_from_1); + png_ptr->gamma_16_from_1 = NULL; + } + if (png_ptr->gamma_16_to_1 != NULL) + { + int i; + int istop = (1 << (8 - png_ptr->gamma_shift)); + for (i = 0; i < istop; i++) + { + png_free(png_ptr, png_ptr->gamma_16_to_1[i]); + } + png_free(png_ptr, png_ptr->gamma_16_to_1); + png_ptr->gamma_16_to_1 = NULL; + } +#endif /* READ_BACKGROUND || READ_ALPHA_MODE || RGB_TO_GRAY */ +} + +/* We build the 8- or 16-bit gamma tables here. Note that for 16-bit + * tables, we don't make a full table if we are reducing to 8-bit in + * the future. Note also how the gamma_16 tables are segmented so that + * we don't need to allocate > 64K chunks for a full 16-bit table. + */ +void /* PRIVATE */ +png_build_gamma_table(png_structrp png_ptr, int bit_depth) +{ + png_debug(1, "in png_build_gamma_table"); + + /* Remove any existing table; this copes with multiple calls to + * png_read_update_info. The warning is because building the gamma tables + * multiple times is a performance hit - it's harmless but the ability to call + * png_read_update_info() multiple times is new in 1.5.6 so it seems sensible + * to warn if the app introduces such a hit. + */ + if (png_ptr->gamma_table != NULL || png_ptr->gamma_16_table != NULL) + { + png_warning(png_ptr, "gamma table being rebuilt"); + png_destroy_gamma_table(png_ptr); + } + + if (bit_depth <= 8) + { + png_build_8bit_table(png_ptr, &png_ptr->gamma_table, + png_ptr->screen_gamma > 0 ? png_reciprocal2(png_ptr->colorspace.gamma, + png_ptr->screen_gamma) : PNG_FP_1); + +#if defined(PNG_READ_BACKGROUND_SUPPORTED) || \ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) || \ + defined(PNG_READ_RGB_TO_GRAY_SUPPORTED) + if (png_ptr->transformations & (PNG_COMPOSE | PNG_RGB_TO_GRAY)) + { + png_build_8bit_table(png_ptr, &png_ptr->gamma_to_1, + png_reciprocal(png_ptr->colorspace.gamma)); + + png_build_8bit_table(png_ptr, &png_ptr->gamma_from_1, + png_ptr->screen_gamma > 0 ? png_reciprocal(png_ptr->screen_gamma) : + png_ptr->colorspace.gamma/* Probably doing rgb_to_gray */); + } +#endif /* READ_BACKGROUND || READ_ALPHA_MODE || RGB_TO_GRAY */ + } + else + { + png_byte shift, sig_bit; + + if (png_ptr->color_type & PNG_COLOR_MASK_COLOR) + { + sig_bit = png_ptr->sig_bit.red; + + if (png_ptr->sig_bit.green > sig_bit) + sig_bit = png_ptr->sig_bit.green; + + if (png_ptr->sig_bit.blue > sig_bit) + sig_bit = png_ptr->sig_bit.blue; + } + else + sig_bit = png_ptr->sig_bit.gray; + + /* 16-bit gamma code uses this equation: + * + * ov = table[(iv & 0xff) >> gamma_shift][iv >> 8] + * + * Where 'iv' is the input color value and 'ov' is the output value - + * pow(iv, gamma). + * + * Thus the gamma table consists of up to 256 256 entry tables. The table + * is selected by the (8-gamma_shift) most significant of the low 8 bits of + * the color value then indexed by the upper 8 bits: + * + * table[low bits][high 8 bits] + * + * So the table 'n' corresponds to all those 'iv' of: + * + * ..<(n+1 << gamma_shift)-1> + * + */ + if (sig_bit > 0 && sig_bit < 16U) + shift = (png_byte)(16U - sig_bit); /* shift == insignificant bits */ + + else + shift = 0; /* keep all 16 bits */ + + if (png_ptr->transformations & (PNG_16_TO_8 | PNG_SCALE_16_TO_8)) + { + /* PNG_MAX_GAMMA_8 is the number of bits to keep - effectively + * the significant bits in the *input* when the output will + * eventually be 8 bits. By default it is 11. + */ + if (shift < (16U - PNG_MAX_GAMMA_8)) + shift = (16U - PNG_MAX_GAMMA_8); + } + + if (shift > 8U) + shift = 8U; /* Guarantees at least one table! */ + + png_ptr->gamma_shift = shift; + +#ifdef PNG_16BIT_SUPPORTED + /* NOTE: prior to 1.5.4 this test used to include PNG_BACKGROUND (now + * PNG_COMPOSE). This effectively smashed the background calculation for + * 16-bit output because the 8-bit table assumes the result will be reduced + * to 8 bits. + */ + if (png_ptr->transformations & (PNG_16_TO_8 | PNG_SCALE_16_TO_8)) +#endif + png_build_16to8_table(png_ptr, &png_ptr->gamma_16_table, shift, + png_ptr->screen_gamma > 0 ? png_product2(png_ptr->colorspace.gamma, + png_ptr->screen_gamma) : PNG_FP_1); + +#ifdef PNG_16BIT_SUPPORTED + else + png_build_16bit_table(png_ptr, &png_ptr->gamma_16_table, shift, + png_ptr->screen_gamma > 0 ? png_reciprocal2(png_ptr->colorspace.gamma, + png_ptr->screen_gamma) : PNG_FP_1); +#endif + +#if defined(PNG_READ_BACKGROUND_SUPPORTED) || \ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) || \ + defined(PNG_READ_RGB_TO_GRAY_SUPPORTED) + if (png_ptr->transformations & (PNG_COMPOSE | PNG_RGB_TO_GRAY)) + { + png_build_16bit_table(png_ptr, &png_ptr->gamma_16_to_1, shift, + png_reciprocal(png_ptr->colorspace.gamma)); + + /* Notice that the '16 from 1' table should be full precision, however + * the lookup on this table still uses gamma_shift, so it can't be. + * TODO: fix this. + */ + png_build_16bit_table(png_ptr, &png_ptr->gamma_16_from_1, shift, + png_ptr->screen_gamma > 0 ? png_reciprocal(png_ptr->screen_gamma) : + png_ptr->colorspace.gamma/* Probably doing rgb_to_gray */); + } +#endif /* READ_BACKGROUND || READ_ALPHA_MODE || RGB_TO_GRAY */ + } +} +#endif /* READ_GAMMA */ + +/* HARDWARE OPTION SUPPORT */ +#ifdef PNG_SET_OPTION_SUPPORTED +int PNGAPI +png_set_option(png_structrp png_ptr, int option, int onoff) +{ + if (png_ptr != NULL && option >= 0 && option < PNG_OPTION_NEXT && + (option & 1) == 0) + { + int mask = 3 << option; + int setting = (2 + (onoff != 0)) << option; + int current = png_ptr->options; + + png_ptr->options = (png_byte)((current & ~mask) | setting); + + return (current & mask) >> option; + } + + return PNG_OPTION_INVALID; +} +#endif + +/* sRGB support */ +#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\ + defined(PNG_SIMPLIFIED_WRITE_SUPPORTED) +/* sRGB conversion tables; these are machine generated with the code in + * contrib/tools/makesRGB.c. The actual sRGB transfer curve defined in the + * specification (see the article at http://en.wikipedia.org/wiki/SRGB) + * is used, not the gamma=1/2.2 approximation use elsewhere in libpng. + * The sRGB to linear table is exact (to the nearest 16 bit linear fraction). + * The inverse (linear to sRGB) table has accuracies as follows: + * + * For all possible (255*65535+1) input values: + * + * error: -0.515566 - 0.625971, 79441 (0.475369%) of readings inexact + * + * For the input values corresponding to the 65536 16-bit values: + * + * error: -0.513727 - 0.607759, 308 (0.469978%) of readings inexact + * + * In all cases the inexact readings are off by one. + */ + +#ifdef PNG_SIMPLIFIED_READ_SUPPORTED +/* The convert-to-sRGB table is only currently required for read. */ +const png_uint_16 png_sRGB_table[256] = +{ + 0,20,40,60,80,99,119,139, + 159,179,199,219,241,264,288,313, + 340,367,396,427,458,491,526,562, + 599,637,677,718,761,805,851,898, + 947,997,1048,1101,1156,1212,1270,1330, + 1391,1453,1517,1583,1651,1720,1790,1863, + 1937,2013,2090,2170,2250,2333,2418,2504, + 2592,2681,2773,2866,2961,3058,3157,3258, + 3360,3464,3570,3678,3788,3900,4014,4129, + 4247,4366,4488,4611,4736,4864,4993,5124, + 5257,5392,5530,5669,5810,5953,6099,6246, + 6395,6547,6700,6856,7014,7174,7335,7500, + 7666,7834,8004,8177,8352,8528,8708,8889, + 9072,9258,9445,9635,9828,10022,10219,10417, + 10619,10822,11028,11235,11446,11658,11873,12090, + 12309,12530,12754,12980,13209,13440,13673,13909, + 14146,14387,14629,14874,15122,15371,15623,15878, + 16135,16394,16656,16920,17187,17456,17727,18001, + 18277,18556,18837,19121,19407,19696,19987,20281, + 20577,20876,21177,21481,21787,22096,22407,22721, + 23038,23357,23678,24002,24329,24658,24990,25325, + 25662,26001,26344,26688,27036,27386,27739,28094, + 28452,28813,29176,29542,29911,30282,30656,31033, + 31412,31794,32179,32567,32957,33350,33745,34143, + 34544,34948,35355,35764,36176,36591,37008,37429, + 37852,38278,38706,39138,39572,40009,40449,40891, + 41337,41785,42236,42690,43147,43606,44069,44534, + 45002,45473,45947,46423,46903,47385,47871,48359, + 48850,49344,49841,50341,50844,51349,51858,52369, + 52884,53401,53921,54445,54971,55500,56032,56567, + 57105,57646,58190,58737,59287,59840,60396,60955, + 61517,62082,62650,63221,63795,64372,64952,65535 +}; + +#endif /* simplified read only */ + +/* The base/delta tables are required for both read and write (but currently + * only the simplified versions.) + */ +const png_uint_16 png_sRGB_base[512] = +{ + 128,1782,3383,4644,5675,6564,7357,8074, + 8732,9346,9921,10463,10977,11466,11935,12384, + 12816,13233,13634,14024,14402,14769,15125,15473, + 15812,16142,16466,16781,17090,17393,17690,17981, + 18266,18546,18822,19093,19359,19621,19879,20133, + 20383,20630,20873,21113,21349,21583,21813,22041, + 22265,22487,22707,22923,23138,23350,23559,23767, + 23972,24175,24376,24575,24772,24967,25160,25352, + 25542,25730,25916,26101,26284,26465,26645,26823, + 27000,27176,27350,27523,27695,27865,28034,28201, + 28368,28533,28697,28860,29021,29182,29341,29500, + 29657,29813,29969,30123,30276,30429,30580,30730, + 30880,31028,31176,31323,31469,31614,31758,31902, + 32045,32186,32327,32468,32607,32746,32884,33021, + 33158,33294,33429,33564,33697,33831,33963,34095, + 34226,34357,34486,34616,34744,34873,35000,35127, + 35253,35379,35504,35629,35753,35876,35999,36122, + 36244,36365,36486,36606,36726,36845,36964,37083, + 37201,37318,37435,37551,37668,37783,37898,38013, + 38127,38241,38354,38467,38580,38692,38803,38915, + 39026,39136,39246,39356,39465,39574,39682,39790, + 39898,40005,40112,40219,40325,40431,40537,40642, + 40747,40851,40955,41059,41163,41266,41369,41471, + 41573,41675,41777,41878,41979,42079,42179,42279, + 42379,42478,42577,42676,42775,42873,42971,43068, + 43165,43262,43359,43456,43552,43648,43743,43839, + 43934,44028,44123,44217,44311,44405,44499,44592, + 44685,44778,44870,44962,45054,45146,45238,45329, + 45420,45511,45601,45692,45782,45872,45961,46051, + 46140,46229,46318,46406,46494,46583,46670,46758, + 46846,46933,47020,47107,47193,47280,47366,47452, + 47538,47623,47709,47794,47879,47964,48048,48133, + 48217,48301,48385,48468,48552,48635,48718,48801, + 48884,48966,49048,49131,49213,49294,49376,49458, + 49539,49620,49701,49782,49862,49943,50023,50103, + 50183,50263,50342,50422,50501,50580,50659,50738, + 50816,50895,50973,51051,51129,51207,51285,51362, + 51439,51517,51594,51671,51747,51824,51900,51977, + 52053,52129,52205,52280,52356,52432,52507,52582, + 52657,52732,52807,52881,52956,53030,53104,53178, + 53252,53326,53400,53473,53546,53620,53693,53766, + 53839,53911,53984,54056,54129,54201,54273,54345, + 54417,54489,54560,54632,54703,54774,54845,54916, + 54987,55058,55129,55199,55269,55340,55410,55480, + 55550,55620,55689,55759,55828,55898,55967,56036, + 56105,56174,56243,56311,56380,56448,56517,56585, + 56653,56721,56789,56857,56924,56992,57059,57127, + 57194,57261,57328,57395,57462,57529,57595,57662, + 57728,57795,57861,57927,57993,58059,58125,58191, + 58256,58322,58387,58453,58518,58583,58648,58713, + 58778,58843,58908,58972,59037,59101,59165,59230, + 59294,59358,59422,59486,59549,59613,59677,59740, + 59804,59867,59930,59993,60056,60119,60182,60245, + 60308,60370,60433,60495,60558,60620,60682,60744, + 60806,60868,60930,60992,61054,61115,61177,61238, + 61300,61361,61422,61483,61544,61605,61666,61727, + 61788,61848,61909,61969,62030,62090,62150,62211, + 62271,62331,62391,62450,62510,62570,62630,62689, + 62749,62808,62867,62927,62986,63045,63104,63163, + 63222,63281,63340,63398,63457,63515,63574,63632, + 63691,63749,63807,63865,63923,63981,64039,64097, + 64155,64212,64270,64328,64385,64443,64500,64557, + 64614,64672,64729,64786,64843,64900,64956,65013, + 65070,65126,65183,65239,65296,65352,65409,65465 +}; + +const png_byte png_sRGB_delta[512] = +{ + 207,201,158,129,113,100,90,82,77,72,68,64,61,59,56,54, + 52,50,49,47,46,45,43,42,41,40,39,39,38,37,36,36, + 35,34,34,33,33,32,32,31,31,30,30,30,29,29,28,28, + 28,27,27,27,27,26,26,26,25,25,25,25,24,24,24,24, + 23,23,23,23,23,22,22,22,22,22,22,21,21,21,21,21, + 21,20,20,20,20,20,20,20,20,19,19,19,19,19,19,19, + 19,18,18,18,18,18,18,18,18,18,18,17,17,17,17,17, + 17,17,17,17,17,17,16,16,16,16,16,16,16,16,16,16, + 16,16,16,16,15,15,15,15,15,15,15,15,15,15,15,15, + 15,15,15,15,14,14,14,14,14,14,14,14,14,14,14,14, + 14,14,14,14,14,14,14,13,13,13,13,13,13,13,13,13, + 13,13,13,13,13,13,13,13,13,13,13,13,13,13,12,12, + 12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,12, + 12,12,12,12,12,12,12,12,12,12,12,12,11,11,11,11, + 11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,11, + 11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,11, + 11,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10, + 10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10, + 10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10, + 10,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,7,7,7,7,7,7,7, + 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, + 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, + 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7 +}; +#endif /* SIMPLIFIED READ/WRITE sRGB support */ + +/* SIMPLIFIED READ/WRITE SUPPORT */ +#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\ + defined(PNG_SIMPLIFIED_WRITE_SUPPORTED) +static int +png_image_free_function(png_voidp argument) +{ + png_imagep image = png_voidcast(png_imagep, argument); + png_controlp cp = image->opaque; + png_control c; + + /* Double check that we have a png_ptr - it should be impossible to get here + * without one. + */ + if (cp->png_ptr == NULL) + return 0; + + /* First free any data held in the control structure. */ +# ifdef PNG_STDIO_SUPPORTED + if (cp->owned_file) + { + FILE *fp = png_voidcast(FILE*, cp->png_ptr->io_ptr); + cp->owned_file = 0; + + /* Ignore errors here. */ + if (fp != NULL) + { + cp->png_ptr->io_ptr = NULL; + (void)fclose(fp); + } + } +# endif + + /* Copy the control structure so that the original, allocated, version can be + * safely freed. Notice that a png_error here stops the remainder of the + * cleanup, but this is probably fine because that would indicate bad memory + * problems anyway. + */ + c = *cp; + image->opaque = &c; + png_free(c.png_ptr, cp); + + /* Then the structures, calling the correct API. */ + if (c.for_write) + { +# ifdef PNG_SIMPLIFIED_WRITE_SUPPORTED + png_destroy_write_struct(&c.png_ptr, &c.info_ptr); +# else + png_error(c.png_ptr, "simplified write not supported"); +# endif + } + else + { +# ifdef PNG_SIMPLIFIED_READ_SUPPORTED + png_destroy_read_struct(&c.png_ptr, &c.info_ptr, NULL); +# else + png_error(c.png_ptr, "simplified read not supported"); +# endif + } + + /* Success. */ + return 1; +} + +void PNGAPI +png_image_free(png_imagep image) +{ + /* Safely call the real function, but only if doing so is safe at this point + * (if not inside an error handling context). Otherwise assume + * png_safe_execute will call this API after the return. + */ + if (image != NULL && image->opaque != NULL && + image->opaque->error_buf == NULL) + { + /* Ignore errors here: */ + (void)png_safe_execute(image, png_image_free_function, image); + image->opaque = NULL; + } +} + +int /* PRIVATE */ +png_image_error(png_imagep image, png_const_charp error_message) +{ + /* Utility to log an error. */ + png_safecat(image->message, (sizeof image->message), 0, error_message); + image->warning_or_error |= PNG_IMAGE_ERROR; + png_image_free(image); + return 0; +} + +#endif /* SIMPLIFIED READ/WRITE */ +#endif /* defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) */ diff --git a/ml/dlib/dlib/external/libpng/png.h b/ml/dlib/dlib/external/libpng/png.h new file mode 100644 index 000000000..527392738 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/png.h @@ -0,0 +1,3319 @@ + +/* png.h - header file for PNG reference library + * + * libpng version 1.6.7 - November 14, 2013 + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license (See LICENSE, below) + * + * Authors and maintainers: + * libpng versions 0.71, May 1995, through 0.88, January 1996: Guy Schalnat + * libpng versions 0.89c, June 1996, through 0.96, May 1997: Andreas Dilger + * libpng versions 0.97, January 1998, through 1.6.7 - November 14, 2013: Glenn + * See also "Contributing Authors", below. + * + * Note about libpng version numbers: + * + * Due to various miscommunications, unforeseen code incompatibilities + * and occasional factors outside the authors' control, version numbering + * on the library has not always been consistent and straightforward. + * The following table summarizes matters since version 0.89c, which was + * the first widely used release: + * + * source png.h png.h shared-lib + * version string int version + * ------- ------ ----- ---------- + * 0.89c "1.0 beta 3" 0.89 89 1.0.89 + * 0.90 "1.0 beta 4" 0.90 90 0.90 [should have been 2.0.90] + * 0.95 "1.0 beta 5" 0.95 95 0.95 [should have been 2.0.95] + * 0.96 "1.0 beta 6" 0.96 96 0.96 [should have been 2.0.96] + * 0.97b "1.00.97 beta 7" 1.00.97 97 1.0.1 [should have been 2.0.97] + * 0.97c 0.97 97 2.0.97 + * 0.98 0.98 98 2.0.98 + * 0.99 0.99 98 2.0.99 + * 0.99a-m 0.99 99 2.0.99 + * 1.00 1.00 100 2.1.0 [100 should be 10000] + * 1.0.0 (from here on, the 100 2.1.0 [100 should be 10000] + * 1.0.1 png.h string is 10001 2.1.0 + * 1.0.1a-e identical to the 10002 from here on, the shared library + * 1.0.2 source version) 10002 is 2.V where V is the source code + * 1.0.2a-b 10003 version, except as noted. + * 1.0.3 10003 + * 1.0.3a-d 10004 + * 1.0.4 10004 + * 1.0.4a-f 10005 + * 1.0.5 (+ 2 patches) 10005 + * 1.0.5a-d 10006 + * 1.0.5e-r 10100 (not source compatible) + * 1.0.5s-v 10006 (not binary compatible) + * 1.0.6 (+ 3 patches) 10006 (still binary incompatible) + * 1.0.6d-f 10007 (still binary incompatible) + * 1.0.6g 10007 + * 1.0.6h 10007 10.6h (testing xy.z so-numbering) + * 1.0.6i 10007 10.6i + * 1.0.6j 10007 2.1.0.6j (incompatible with 1.0.0) + * 1.0.7beta11-14 DLLNUM 10007 2.1.0.7beta11-14 (binary compatible) + * 1.0.7beta15-18 1 10007 2.1.0.7beta15-18 (binary compatible) + * 1.0.7rc1-2 1 10007 2.1.0.7rc1-2 (binary compatible) + * 1.0.7 1 10007 (still compatible) + * 1.0.8beta1-4 1 10008 2.1.0.8beta1-4 + * 1.0.8rc1 1 10008 2.1.0.8rc1 + * 1.0.8 1 10008 2.1.0.8 + * 1.0.9beta1-6 1 10009 2.1.0.9beta1-6 + * 1.0.9rc1 1 10009 2.1.0.9rc1 + * 1.0.9beta7-10 1 10009 2.1.0.9beta7-10 + * 1.0.9rc2 1 10009 2.1.0.9rc2 + * 1.0.9 1 10009 2.1.0.9 + * 1.0.10beta1 1 10010 2.1.0.10beta1 + * 1.0.10rc1 1 10010 2.1.0.10rc1 + * 1.0.10 1 10010 2.1.0.10 + * 1.0.11beta1-3 1 10011 2.1.0.11beta1-3 + * 1.0.11rc1 1 10011 2.1.0.11rc1 + * 1.0.11 1 10011 2.1.0.11 + * 1.0.12beta1-2 2 10012 2.1.0.12beta1-2 + * 1.0.12rc1 2 10012 2.1.0.12rc1 + * 1.0.12 2 10012 2.1.0.12 + * 1.1.0a-f - 10100 2.1.1.0a-f (branch abandoned) + * 1.2.0beta1-2 2 10200 2.1.2.0beta1-2 + * 1.2.0beta3-5 3 10200 3.1.2.0beta3-5 + * 1.2.0rc1 3 10200 3.1.2.0rc1 + * 1.2.0 3 10200 3.1.2.0 + * 1.2.1beta1-4 3 10201 3.1.2.1beta1-4 + * 1.2.1rc1-2 3 10201 3.1.2.1rc1-2 + * 1.2.1 3 10201 3.1.2.1 + * 1.2.2beta1-6 12 10202 12.so.0.1.2.2beta1-6 + * 1.0.13beta1 10 10013 10.so.0.1.0.13beta1 + * 1.0.13rc1 10 10013 10.so.0.1.0.13rc1 + * 1.2.2rc1 12 10202 12.so.0.1.2.2rc1 + * 1.0.13 10 10013 10.so.0.1.0.13 + * 1.2.2 12 10202 12.so.0.1.2.2 + * 1.2.3rc1-6 12 10203 12.so.0.1.2.3rc1-6 + * 1.2.3 12 10203 12.so.0.1.2.3 + * 1.2.4beta1-3 13 10204 12.so.0.1.2.4beta1-3 + * 1.0.14rc1 13 10014 10.so.0.1.0.14rc1 + * 1.2.4rc1 13 10204 12.so.0.1.2.4rc1 + * 1.0.14 10 10014 10.so.0.1.0.14 + * 1.2.4 13 10204 12.so.0.1.2.4 + * 1.2.5beta1-2 13 10205 12.so.0.1.2.5beta1-2 + * 1.0.15rc1-3 10 10015 10.so.0.1.0.15rc1-3 + * 1.2.5rc1-3 13 10205 12.so.0.1.2.5rc1-3 + * 1.0.15 10 10015 10.so.0.1.0.15 + * 1.2.5 13 10205 12.so.0.1.2.5 + * 1.2.6beta1-4 13 10206 12.so.0.1.2.6beta1-4 + * 1.0.16 10 10016 10.so.0.1.0.16 + * 1.2.6 13 10206 12.so.0.1.2.6 + * 1.2.7beta1-2 13 10207 12.so.0.1.2.7beta1-2 + * 1.0.17rc1 10 10017 12.so.0.1.0.17rc1 + * 1.2.7rc1 13 10207 12.so.0.1.2.7rc1 + * 1.0.17 10 10017 12.so.0.1.0.17 + * 1.2.7 13 10207 12.so.0.1.2.7 + * 1.2.8beta1-5 13 10208 12.so.0.1.2.8beta1-5 + * 1.0.18rc1-5 10 10018 12.so.0.1.0.18rc1-5 + * 1.2.8rc1-5 13 10208 12.so.0.1.2.8rc1-5 + * 1.0.18 10 10018 12.so.0.1.0.18 + * 1.2.8 13 10208 12.so.0.1.2.8 + * 1.2.9beta1-3 13 10209 12.so.0.1.2.9beta1-3 + * 1.2.9beta4-11 13 10209 12.so.0.9[.0] + * 1.2.9rc1 13 10209 12.so.0.9[.0] + * 1.2.9 13 10209 12.so.0.9[.0] + * 1.2.10beta1-7 13 10210 12.so.0.10[.0] + * 1.2.10rc1-2 13 10210 12.so.0.10[.0] + * 1.2.10 13 10210 12.so.0.10[.0] + * 1.4.0beta1-5 14 10400 14.so.0.0[.0] + * 1.2.11beta1-4 13 10211 12.so.0.11[.0] + * 1.4.0beta7-8 14 10400 14.so.0.0[.0] + * 1.2.11 13 10211 12.so.0.11[.0] + * 1.2.12 13 10212 12.so.0.12[.0] + * 1.4.0beta9-14 14 10400 14.so.0.0[.0] + * 1.2.13 13 10213 12.so.0.13[.0] + * 1.4.0beta15-36 14 10400 14.so.0.0[.0] + * 1.4.0beta37-87 14 10400 14.so.14.0[.0] + * 1.4.0rc01 14 10400 14.so.14.0[.0] + * 1.4.0beta88-109 14 10400 14.so.14.0[.0] + * 1.4.0rc02-08 14 10400 14.so.14.0[.0] + * 1.4.0 14 10400 14.so.14.0[.0] + * 1.4.1beta01-03 14 10401 14.so.14.1[.0] + * 1.4.1rc01 14 10401 14.so.14.1[.0] + * 1.4.1beta04-12 14 10401 14.so.14.1[.0] + * 1.4.1 14 10401 14.so.14.1[.0] + * 1.4.2 14 10402 14.so.14.2[.0] + * 1.4.3 14 10403 14.so.14.3[.0] + * 1.4.4 14 10404 14.so.14.4[.0] + * 1.5.0beta01-58 15 10500 15.so.15.0[.0] + * 1.5.0rc01-07 15 10500 15.so.15.0[.0] + * 1.5.0 15 10500 15.so.15.0[.0] + * 1.5.1beta01-11 15 10501 15.so.15.1[.0] + * 1.5.1rc01-02 15 10501 15.so.15.1[.0] + * 1.5.1 15 10501 15.so.15.1[.0] + * 1.5.2beta01-03 15 10502 15.so.15.2[.0] + * 1.5.2rc01-03 15 10502 15.so.15.2[.0] + * 1.5.2 15 10502 15.so.15.2[.0] + * 1.5.3beta01-10 15 10503 15.so.15.3[.0] + * 1.5.3rc01-02 15 10503 15.so.15.3[.0] + * 1.5.3beta11 15 10503 15.so.15.3[.0] + * 1.5.3 [omitted] + * 1.5.4beta01-08 15 10504 15.so.15.4[.0] + * 1.5.4rc01 15 10504 15.so.15.4[.0] + * 1.5.4 15 10504 15.so.15.4[.0] + * 1.5.5beta01-08 15 10505 15.so.15.5[.0] + * 1.5.5rc01 15 10505 15.so.15.5[.0] + * 1.5.5 15 10505 15.so.15.5[.0] + * 1.5.6beta01-07 15 10506 15.so.15.6[.0] + * 1.5.6rc01-03 15 10506 15.so.15.6[.0] + * 1.5.6 15 10506 15.so.15.6[.0] + * 1.5.7beta01-05 15 10507 15.so.15.7[.0] + * 1.5.7rc01-03 15 10507 15.so.15.7[.0] + * 1.5.7 15 10507 15.so.15.7[.0] + * 1.6.0beta01-40 16 10600 16.so.16.0[.0] + * 1.6.0rc01-08 16 10600 16.so.16.0[.0] + * 1.6.0 16 10600 16.so.16.0[.0] + * 1.6.1beta01-09 16 10601 16.so.16.1[.0] + * 1.6.1rc01 16 10601 16.so.16.1[.0] + * 1.6.1 16 10601 16.so.16.1[.0] + * 1.6.2beta01 16 10602 16.so.16.2[.0] + * 1.6.2rc01-06 16 10602 16.so.16.2[.0] + * 1.6.2 16 10602 16.so.16.2[.0] + * 1.6.3beta01-11 16 10603 16.so.16.3[.0] + * 1.6.3rc01 16 10603 16.so.16.3[.0] + * 1.6.3 16 10603 16.so.16.3[.0] + * 1.6.4beta01-02 16 10604 16.so.16.4[.0] + * 1.6.4rc01 16 10604 16.so.16.4[.0] + * 1.6.4 16 10604 16.so.16.4[.0] + * 1.6.5 16 10605 16.so.16.5[.0] + * 1.6.6 16 10606 16.so.16.6[.0] + * 1.6.7beta01-04 16 10607 16.so.16.7[.0] + * 1.6.7rc01-02 16 10607 16.so.16.7[.0] + * 1.6.7 16 10607 16.so.16.7[.0] + * + * Henceforth the source version will match the shared-library major + * and minor numbers; the shared-library major version number will be + * used for changes in backward compatibility, as it is intended. The + * PNG_LIBPNG_VER macro, which is not used within libpng but is available + * for applications, is an unsigned integer of the form xyyzz corresponding + * to the source version x.y.z (leading zeros in y and z). Beta versions + * were given the previous public release number plus a letter, until + * version 1.0.6j; from then on they were given the upcoming public + * release number plus "betaNN" or "rcNN". + * + * Binary incompatibility exists only when applications make direct access + * to the info_ptr or png_ptr members through png.h, and the compiled + * application is loaded with a different version of the library. + * + * DLLNUM will change each time there are forward or backward changes + * in binary compatibility (e.g., when a new feature is added). + * + * See libpng-manual.txt or libpng.3 for more information. The PNG + * specification is available as a W3C Recommendation and as an ISO + * Specification, defines should NOT be changed. + */ +#define PNG_INFO_gAMA 0x0001 +#define PNG_INFO_sBIT 0x0002 +#define PNG_INFO_cHRM 0x0004 +#define PNG_INFO_PLTE 0x0008 +#define PNG_INFO_tRNS 0x0010 +#define PNG_INFO_bKGD 0x0020 +#define PNG_INFO_hIST 0x0040 +#define PNG_INFO_pHYs 0x0080 +#define PNG_INFO_oFFs 0x0100 +#define PNG_INFO_tIME 0x0200 +#define PNG_INFO_pCAL 0x0400 +#define PNG_INFO_sRGB 0x0800 /* GR-P, 0.96a */ +#define PNG_INFO_iCCP 0x1000 /* ESR, 1.0.6 */ +#define PNG_INFO_sPLT 0x2000 /* ESR, 1.0.6 */ +#define PNG_INFO_sCAL 0x4000 /* ESR, 1.0.6 */ +#define PNG_INFO_IDAT 0x8000 /* ESR, 1.0.6 */ + +/* This is used for the transformation routines, as some of them + * change these values for the row. It also should enable using + * the routines for other purposes. + */ +typedef struct png_row_info_struct +{ + png_uint_32 width; /* width of row */ + png_size_t rowbytes; /* number of bytes in row */ + png_byte color_type; /* color type of row */ + png_byte bit_depth; /* bit depth of row */ + png_byte channels; /* number of channels (1, 2, 3, or 4) */ + png_byte pixel_depth; /* bits per pixel (depth * channels) */ +} png_row_info; + +typedef png_row_info * png_row_infop; +typedef png_row_info * * png_row_infopp; + +/* These are the function types for the I/O functions and for the functions + * that allow the user to override the default I/O functions with his or her + * own. The png_error_ptr type should match that of user-supplied warning + * and error functions, while the png_rw_ptr type should match that of the + * user read/write data functions. Note that the 'write' function must not + * modify the buffer it is passed. The 'read' function, on the other hand, is + * expected to return the read data in the buffer. + */ +typedef PNG_CALLBACK(void, *png_error_ptr, (png_structp, png_const_charp)); +typedef PNG_CALLBACK(void, *png_rw_ptr, (png_structp, png_bytep, png_size_t)); +typedef PNG_CALLBACK(void, *png_flush_ptr, (png_structp)); +typedef PNG_CALLBACK(void, *png_read_status_ptr, (png_structp, png_uint_32, + int)); +typedef PNG_CALLBACK(void, *png_write_status_ptr, (png_structp, png_uint_32, + int)); + +#ifdef PNG_PROGRESSIVE_READ_SUPPORTED +typedef PNG_CALLBACK(void, *png_progressive_info_ptr, (png_structp, png_infop)); +typedef PNG_CALLBACK(void, *png_progressive_end_ptr, (png_structp, png_infop)); + +/* The following callback receives png_uint_32 row_number, int pass for the + * png_bytep data of the row. When transforming an interlaced image the + * row number is the row number within the sub-image of the interlace pass, so + * the value will increase to the height of the sub-image (not the full image) + * then reset to 0 for the next pass. + * + * Use PNG_ROW_FROM_PASS_ROW(row, pass) and PNG_COL_FROM_PASS_COL(col, pass) to + * find the output pixel (x,y) given an interlaced sub-image pixel + * (row,col,pass). (See below for these macros.) + */ +typedef PNG_CALLBACK(void, *png_progressive_row_ptr, (png_structp, png_bytep, + png_uint_32, int)); +#endif + +#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) || \ + defined(PNG_WRITE_USER_TRANSFORM_SUPPORTED) +typedef PNG_CALLBACK(void, *png_user_transform_ptr, (png_structp, png_row_infop, + png_bytep)); +#endif + +#ifdef PNG_USER_CHUNKS_SUPPORTED +typedef PNG_CALLBACK(int, *png_user_chunk_ptr, (png_structp, + png_unknown_chunkp)); +#endif +#ifdef PNG_UNKNOWN_CHUNKS_SUPPORTED +/* not used anywhere */ +/* typedef PNG_CALLBACK(void, *png_unknown_chunk_ptr, (png_structp)); */ +#endif + +#ifdef PNG_SETJMP_SUPPORTED +/* This must match the function definition in , and the application + * must include this before png.h to obtain the definition of jmp_buf. The + * function is required to be PNG_NORETURN, but this is not checked. If the + * function does return the application will crash via an abort() or similar + * system level call. + * + * If you get a warning here while building the library you may need to make + * changes to ensure that pnglibconf.h records the calling convention used by + * your compiler. This may be very difficult - try using a different compiler + * to build the library! + */ +PNG_FUNCTION(void, (PNGCAPI *png_longjmp_ptr), PNGARG((jmp_buf, int)), typedef); +#endif + +/* Transform masks for the high-level interface */ +#define PNG_TRANSFORM_IDENTITY 0x0000 /* read and write */ +#define PNG_TRANSFORM_STRIP_16 0x0001 /* read only */ +#define PNG_TRANSFORM_STRIP_ALPHA 0x0002 /* read only */ +#define PNG_TRANSFORM_PACKING 0x0004 /* read and write */ +#define PNG_TRANSFORM_PACKSWAP 0x0008 /* read and write */ +#define PNG_TRANSFORM_EXPAND 0x0010 /* read only */ +#define PNG_TRANSFORM_INVERT_MONO 0x0020 /* read and write */ +#define PNG_TRANSFORM_SHIFT 0x0040 /* read and write */ +#define PNG_TRANSFORM_BGR 0x0080 /* read and write */ +#define PNG_TRANSFORM_SWAP_ALPHA 0x0100 /* read and write */ +#define PNG_TRANSFORM_SWAP_ENDIAN 0x0200 /* read and write */ +#define PNG_TRANSFORM_INVERT_ALPHA 0x0400 /* read and write */ +#define PNG_TRANSFORM_STRIP_FILLER 0x0800 /* write only */ +/* Added to libpng-1.2.34 */ +#define PNG_TRANSFORM_STRIP_FILLER_BEFORE PNG_TRANSFORM_STRIP_FILLER +#define PNG_TRANSFORM_STRIP_FILLER_AFTER 0x1000 /* write only */ +/* Added to libpng-1.4.0 */ +#define PNG_TRANSFORM_GRAY_TO_RGB 0x2000 /* read only */ +/* Added to libpng-1.5.4 */ +#define PNG_TRANSFORM_EXPAND_16 0x4000 /* read only */ +#define PNG_TRANSFORM_SCALE_16 0x8000 /* read only */ + +/* Flags for MNG supported features */ +#define PNG_FLAG_MNG_EMPTY_PLTE 0x01 +#define PNG_FLAG_MNG_FILTER_64 0x04 +#define PNG_ALL_MNG_FEATURES 0x05 + +/* NOTE: prior to 1.5 these functions had no 'API' style declaration, + * this allowed the zlib default functions to be used on Windows + * platforms. In 1.5 the zlib default malloc (which just calls malloc and + * ignores the first argument) should be completely compatible with the + * following. + */ +typedef PNG_CALLBACK(png_voidp, *png_malloc_ptr, (png_structp, + png_alloc_size_t)); +typedef PNG_CALLBACK(void, *png_free_ptr, (png_structp, png_voidp)); + +/* Section 3: exported functions + * Here are the function definitions most commonly used. This is not + * the place to find out how to use libpng. See libpng-manual.txt for the + * full explanation, see example.c for the summary. This just provides + * a simple one line description of the use of each function. + * + * The PNG_EXPORT() and PNG_EXPORTA() macros used below are defined in + * pngconf.h and in the *.dfn files in the scripts directory. + * + * PNG_EXPORT(ordinal, type, name, (args)); + * + * ordinal: ordinal that is used while building + * *.def files. The ordinal value is only + * relevant when preprocessing png.h with + * the *.dfn files for building symbol table + * entries, and are removed by pngconf.h. + * type: return type of the function + * name: function name + * args: function arguments, with types + * + * When we wish to append attributes to a function prototype we use + * the PNG_EXPORTA() macro instead. + * + * PNG_EXPORTA(ordinal, type, name, (args), attributes); + * + * ordinal, type, name, and args: same as in PNG_EXPORT(). + * attributes: function attributes + */ + +/* Returns the version number of the library */ +PNG_EXPORT(1, png_uint_32, png_access_version_number, (void)); + +/* Tell lib we have already handled the first magic bytes. + * Handling more than 8 bytes from the beginning of the file is an error. + */ +PNG_EXPORT(2, void, png_set_sig_bytes, (png_structrp png_ptr, int num_bytes)); + +/* Check sig[start] through sig[start + num_to_check - 1] to see if it's a + * PNG file. Returns zero if the supplied bytes match the 8-byte PNG + * signature, and non-zero otherwise. Having num_to_check == 0 or + * start > 7 will always fail (ie return non-zero). + */ +PNG_EXPORT(3, int, png_sig_cmp, (png_const_bytep sig, png_size_t start, + png_size_t num_to_check)); + +/* Simple signature checking function. This is the same as calling + * png_check_sig(sig, n) := !png_sig_cmp(sig, 0, n). + */ +#define png_check_sig(sig, n) !png_sig_cmp((sig), 0, (n)) + +/* Allocate and initialize png_ptr struct for reading, and any other memory. */ +PNG_EXPORTA(4, png_structp, png_create_read_struct, + (png_const_charp user_png_ver, png_voidp error_ptr, + png_error_ptr error_fn, png_error_ptr warn_fn), + PNG_ALLOCATED); + +/* Allocate and initialize png_ptr struct for writing, and any other memory */ +PNG_EXPORTA(5, png_structp, png_create_write_struct, + (png_const_charp user_png_ver, png_voidp error_ptr, png_error_ptr error_fn, + png_error_ptr warn_fn), + PNG_ALLOCATED); + +PNG_EXPORT(6, png_size_t, png_get_compression_buffer_size, + (png_const_structrp png_ptr)); + +PNG_EXPORT(7, void, png_set_compression_buffer_size, (png_structrp png_ptr, + png_size_t size)); + +/* Moved from pngconf.h in 1.4.0 and modified to ensure setjmp/longjmp + * match up. + */ +#ifdef PNG_SETJMP_SUPPORTED +/* This function returns the jmp_buf built in to *png_ptr. It must be + * supplied with an appropriate 'longjmp' function to use on that jmp_buf + * unless the default error function is overridden in which case NULL is + * acceptable. The size of the jmp_buf is checked against the actual size + * allocated by the library - the call will return NULL on a mismatch + * indicating an ABI mismatch. + */ +PNG_EXPORT(8, jmp_buf*, png_set_longjmp_fn, (png_structrp png_ptr, + png_longjmp_ptr longjmp_fn, size_t jmp_buf_size)); +# define png_jmpbuf(png_ptr) \ + (*png_set_longjmp_fn((png_ptr), longjmp, (sizeof (jmp_buf)))) +#else +# define png_jmpbuf(png_ptr) \ + (LIBPNG_WAS_COMPILED_WITH__PNG_NO_SETJMP) +#endif +/* This function should be used by libpng applications in place of + * longjmp(png_ptr->jmpbuf, val). If longjmp_fn() has been set, it + * will use it; otherwise it will call PNG_ABORT(). This function was + * added in libpng-1.5.0. + */ +PNG_EXPORTA(9, void, png_longjmp, (png_const_structrp png_ptr, int val), + PNG_NORETURN); + +#ifdef PNG_READ_SUPPORTED +/* Reset the compression stream */ +PNG_EXPORTA(10, int, png_reset_zstream, (png_structrp png_ptr), PNG_DEPRECATED); +#endif + +/* New functions added in libpng-1.0.2 (not enabled by default until 1.2.0) */ +#ifdef PNG_USER_MEM_SUPPORTED +PNG_EXPORTA(11, png_structp, png_create_read_struct_2, + (png_const_charp user_png_ver, png_voidp error_ptr, png_error_ptr error_fn, + png_error_ptr warn_fn, + png_voidp mem_ptr, png_malloc_ptr malloc_fn, png_free_ptr free_fn), + PNG_ALLOCATED); +PNG_EXPORTA(12, png_structp, png_create_write_struct_2, + (png_const_charp user_png_ver, png_voidp error_ptr, png_error_ptr error_fn, + png_error_ptr warn_fn, + png_voidp mem_ptr, png_malloc_ptr malloc_fn, png_free_ptr free_fn), + PNG_ALLOCATED); +#endif + +/* Write the PNG file signature. */ +PNG_EXPORT(13, void, png_write_sig, (png_structrp png_ptr)); + +/* Write a PNG chunk - size, type, (optional) data, CRC. */ +PNG_EXPORT(14, void, png_write_chunk, (png_structrp png_ptr, png_const_bytep + chunk_name, png_const_bytep data, png_size_t length)); + +/* Write the start of a PNG chunk - length and chunk name. */ +PNG_EXPORT(15, void, png_write_chunk_start, (png_structrp png_ptr, + png_const_bytep chunk_name, png_uint_32 length)); + +/* Write the data of a PNG chunk started with png_write_chunk_start(). */ +PNG_EXPORT(16, void, png_write_chunk_data, (png_structrp png_ptr, + png_const_bytep data, png_size_t length)); + +/* Finish a chunk started with png_write_chunk_start() (includes CRC). */ +PNG_EXPORT(17, void, png_write_chunk_end, (png_structrp png_ptr)); + +/* Allocate and initialize the info structure */ +PNG_EXPORTA(18, png_infop, png_create_info_struct, (png_const_structrp png_ptr), + PNG_ALLOCATED); + +/* DEPRECATED: this function allowed init structures to be created using the + * default allocation method (typically malloc). Use is deprecated in 1.6.0 and + * the API will be removed in the future. + */ +PNG_EXPORTA(19, void, png_info_init_3, (png_infopp info_ptr, + png_size_t png_info_struct_size), PNG_DEPRECATED); + +/* Writes all the PNG information before the image. */ +PNG_EXPORT(20, void, png_write_info_before_PLTE, + (png_structrp png_ptr, png_const_inforp info_ptr)); +PNG_EXPORT(21, void, png_write_info, + (png_structrp png_ptr, png_const_inforp info_ptr)); + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Read the information before the actual image data. */ +PNG_EXPORT(22, void, png_read_info, + (png_structrp png_ptr, png_inforp info_ptr)); +#endif + +#ifdef PNG_TIME_RFC1123_SUPPORTED + /* Convert to a US string format: there is no localization support in this + * routine. The original implementation used a 29 character buffer in + * png_struct, this will be removed in future versions. + */ +#if PNG_LIBPNG_VER < 10700 +/* To do: remove this from libpng17 (and from libpng17/png.c and pngstruct.h) */ +PNG_EXPORTA(23, png_const_charp, png_convert_to_rfc1123, (png_structrp png_ptr, + png_const_timep ptime),PNG_DEPRECATED); +#endif +PNG_EXPORT(241, int, png_convert_to_rfc1123_buffer, (char out[29], + png_const_timep ptime)); +#endif + +#ifdef PNG_CONVERT_tIME_SUPPORTED +/* Convert from a struct tm to png_time */ +PNG_EXPORT(24, void, png_convert_from_struct_tm, (png_timep ptime, + const struct tm * ttime)); + +/* Convert from time_t to png_time. Uses gmtime() */ +PNG_EXPORT(25, void, png_convert_from_time_t, (png_timep ptime, time_t ttime)); +#endif /* PNG_CONVERT_tIME_SUPPORTED */ + +#ifdef PNG_READ_EXPAND_SUPPORTED +/* Expand data to 24-bit RGB, or 8-bit grayscale, with alpha if available. */ +PNG_EXPORT(26, void, png_set_expand, (png_structrp png_ptr)); +PNG_EXPORT(27, void, png_set_expand_gray_1_2_4_to_8, (png_structrp png_ptr)); +PNG_EXPORT(28, void, png_set_palette_to_rgb, (png_structrp png_ptr)); +PNG_EXPORT(29, void, png_set_tRNS_to_alpha, (png_structrp png_ptr)); +#endif + +#ifdef PNG_READ_EXPAND_16_SUPPORTED +/* Expand to 16-bit channels, forces conversion of palette to RGB and expansion + * of a tRNS chunk if present. + */ +PNG_EXPORT(221, void, png_set_expand_16, (png_structrp png_ptr)); +#endif + +#if defined(PNG_READ_BGR_SUPPORTED) || defined(PNG_WRITE_BGR_SUPPORTED) +/* Use blue, green, red order for pixels. */ +PNG_EXPORT(30, void, png_set_bgr, (png_structrp png_ptr)); +#endif + +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED +/* Expand the grayscale to 24-bit RGB if necessary. */ +PNG_EXPORT(31, void, png_set_gray_to_rgb, (png_structrp png_ptr)); +#endif + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED +/* Reduce RGB to grayscale. */ +#define PNG_ERROR_ACTION_NONE 1 +#define PNG_ERROR_ACTION_WARN 2 +#define PNG_ERROR_ACTION_ERROR 3 +#define PNG_RGB_TO_GRAY_DEFAULT (-1)/*for red/green coefficients*/ + +PNG_FP_EXPORT(32, void, png_set_rgb_to_gray, (png_structrp png_ptr, + int error_action, double red, double green)) +PNG_FIXED_EXPORT(33, void, png_set_rgb_to_gray_fixed, (png_structrp png_ptr, + int error_action, png_fixed_point red, png_fixed_point green)) + +PNG_EXPORT(34, png_byte, png_get_rgb_to_gray_status, (png_const_structrp + png_ptr)); +#endif + +#ifdef PNG_BUILD_GRAYSCALE_PALETTE_SUPPORTED +PNG_EXPORT(35, void, png_build_grayscale_palette, (int bit_depth, + png_colorp palette)); +#endif + +#ifdef PNG_READ_ALPHA_MODE_SUPPORTED +/* How the alpha channel is interpreted - this affects how the color channels of + * a PNG file are returned when an alpha channel, or tRNS chunk in a palette + * file, is present. + * + * This has no effect on the way pixels are written into a PNG output + * datastream. The color samples in a PNG datastream are never premultiplied + * with the alpha samples. + * + * The default is to return data according to the PNG specification: the alpha + * channel is a linear measure of the contribution of the pixel to the + * corresponding composited pixel. The gamma encoded color channels must be + * scaled according to the contribution and to do this it is necessary to undo + * the encoding, scale the color values, perform the composition and reencode + * the values. This is the 'PNG' mode. + * + * The alternative is to 'associate' the alpha with the color information by + * storing color channel values that have been scaled by the alpha. The + * advantage is that the color channels can be resampled (the image can be + * scaled) in this form. The disadvantage is that normal practice is to store + * linear, not (gamma) encoded, values and this requires 16-bit channels for + * still images rather than the 8-bit channels that are just about sufficient if + * gamma encoding is used. In addition all non-transparent pixel values, + * including completely opaque ones, must be gamma encoded to produce the final + * image. This is the 'STANDARD', 'ASSOCIATED' or 'PREMULTIPLIED' mode (the + * latter being the two common names for associated alpha color channels.) + * + * Since it is not necessary to perform arithmetic on opaque color values so + * long as they are not to be resampled and are in the final color space it is + * possible to optimize the handling of alpha by storing the opaque pixels in + * the PNG format (adjusted for the output color space) while storing partially + * opaque pixels in the standard, linear, format. The accuracy required for + * standard alpha composition is relatively low, because the pixels are + * isolated, therefore typically the accuracy loss in storing 8-bit linear + * values is acceptable. (This is not true if the alpha channel is used to + * simulate transparency over large areas - use 16 bits or the PNG mode in + * this case!) This is the 'OPTIMIZED' mode. For this mode a pixel is + * treated as opaque only if the alpha value is equal to the maximum value. + * + * The final choice is to gamma encode the alpha channel as well. This is + * broken because, in practice, no implementation that uses this choice + * correctly undoes the encoding before handling alpha composition. Use this + * choice only if other serious errors in the software or hardware you use + * mandate it; the typical serious error is for dark halos to appear around + * opaque areas of the composited PNG image because of arithmetic overflow. + * + * The API function png_set_alpha_mode specifies which of these choices to use + * with an enumerated 'mode' value and the gamma of the required output: + */ +#define PNG_ALPHA_PNG 0 /* according to the PNG standard */ +#define PNG_ALPHA_STANDARD 1 /* according to Porter/Duff */ +#define PNG_ALPHA_ASSOCIATED 1 /* as above; this is the normal practice */ +#define PNG_ALPHA_PREMULTIPLIED 1 /* as above */ +#define PNG_ALPHA_OPTIMIZED 2 /* 'PNG' for opaque pixels, else 'STANDARD' */ +#define PNG_ALPHA_BROKEN 3 /* the alpha channel is gamma encoded */ + +PNG_FP_EXPORT(227, void, png_set_alpha_mode, (png_structrp png_ptr, int mode, + double output_gamma)) +PNG_FIXED_EXPORT(228, void, png_set_alpha_mode_fixed, (png_structrp png_ptr, + int mode, png_fixed_point output_gamma)) +#endif + +#if defined(PNG_GAMMA_SUPPORTED) || defined(PNG_READ_ALPHA_MODE_SUPPORTED) +/* The output_gamma value is a screen gamma in libpng terminology: it expresses + * how to decode the output values, not how they are encoded. The values used + * correspond to the normal numbers used to describe the overall gamma of a + * computer display system; for example 2.2 for an sRGB conformant system. The + * values are scaled by 100000 in the _fixed version of the API (so 220000 for + * sRGB.) + * + * The inverse of the value is always used to provide a default for the PNG file + * encoding if it has no gAMA chunk and if png_set_gamma() has not been called + * to override the PNG gamma information. + * + * When the ALPHA_OPTIMIZED mode is selected the output gamma is used to encode + * opaque pixels however pixels with lower alpha values are not encoded, + * regardless of the output gamma setting. + * + * When the standard Porter Duff handling is requested with mode 1 the output + * encoding is set to be linear and the output_gamma value is only relevant + * as a default for input data that has no gamma information. The linear output + * encoding will be overridden if png_set_gamma() is called - the results may be + * highly unexpected! + * + * The following numbers are derived from the sRGB standard and the research + * behind it. sRGB is defined to be approximated by a PNG gAMA chunk value of + * 0.45455 (1/2.2) for PNG. The value implicitly includes any viewing + * correction required to take account of any differences in the color + * environment of the original scene and the intended display environment; the + * value expresses how to *decode* the image for display, not how the original + * data was *encoded*. + * + * sRGB provides a peg for the PNG standard by defining a viewing environment. + * sRGB itself, and earlier TV standards, actually use a more complex transform + * (a linear portion then a gamma 2.4 power law) than PNG can express. (PNG is + * limited to simple power laws.) By saying that an image for direct display on + * an sRGB conformant system should be stored with a gAMA chunk value of 45455 + * (11.3.3.2 and 11.3.3.5 of the ISO PNG specification) the PNG specification + * makes it possible to derive values for other display systems and + * environments. + * + * The Mac value is deduced from the sRGB based on an assumption that the actual + * extra viewing correction used in early Mac display systems was implemented as + * a power 1.45 lookup table. + * + * Any system where a programmable lookup table is used or where the behavior of + * the final display device characteristics can be changed requires system + * specific code to obtain the current characteristic. However this can be + * difficult and most PNG gamma correction only requires an approximate value. + * + * By default, if png_set_alpha_mode() is not called, libpng assumes that all + * values are unencoded, linear, values and that the output device also has a + * linear characteristic. This is only very rarely correct - it is invariably + * better to call png_set_alpha_mode() with PNG_DEFAULT_sRGB than rely on the + * default if you don't know what the right answer is! + * + * The special value PNG_GAMMA_MAC_18 indicates an older Mac system (pre Mac OS + * 10.6) which used a correction table to implement a somewhat lower gamma on an + * otherwise sRGB system. + * + * Both these values are reserved (not simple gamma values) in order to allow + * more precise correction internally in the future. + * + * NOTE: the following values can be passed to either the fixed or floating + * point APIs, but the floating point API will also accept floating point + * values. + */ +#define PNG_DEFAULT_sRGB -1 /* sRGB gamma and color space */ +#define PNG_GAMMA_MAC_18 -2 /* Old Mac '1.8' gamma and color space */ +#define PNG_GAMMA_sRGB 220000 /* Television standards--matches sRGB gamma */ +#define PNG_GAMMA_LINEAR PNG_FP_1 /* Linear */ +#endif + +/* The following are examples of calls to png_set_alpha_mode to achieve the + * required overall gamma correction and, where necessary, alpha + * premultiplication. + * + * png_set_alpha_mode(pp, PNG_ALPHA_PNG, PNG_DEFAULT_sRGB); + * This is the default libpng handling of the alpha channel - it is not + * pre-multiplied into the color components. In addition the call states + * that the output is for a sRGB system and causes all PNG files without gAMA + * chunks to be assumed to be encoded using sRGB. + * + * png_set_alpha_mode(pp, PNG_ALPHA_PNG, PNG_GAMMA_MAC); + * In this case the output is assumed to be something like an sRGB conformant + * display preceeded by a power-law lookup table of power 1.45. This is how + * early Mac systems behaved. + * + * png_set_alpha_mode(pp, PNG_ALPHA_STANDARD, PNG_GAMMA_LINEAR); + * This is the classic Jim Blinn approach and will work in academic + * environments where everything is done by the book. It has the shortcoming + * of assuming that input PNG data with no gamma information is linear - this + * is unlikely to be correct unless the PNG files where generated locally. + * Most of the time the output precision will be so low as to show + * significant banding in dark areas of the image. + * + * png_set_expand_16(pp); + * png_set_alpha_mode(pp, PNG_ALPHA_STANDARD, PNG_DEFAULT_sRGB); + * This is a somewhat more realistic Jim Blinn inspired approach. PNG files + * are assumed to have the sRGB encoding if not marked with a gamma value and + * the output is always 16 bits per component. This permits accurate scaling + * and processing of the data. If you know that your input PNG files were + * generated locally you might need to replace PNG_DEFAULT_sRGB with the + * correct value for your system. + * + * png_set_alpha_mode(pp, PNG_ALPHA_OPTIMIZED, PNG_DEFAULT_sRGB); + * If you just need to composite the PNG image onto an existing background + * and if you control the code that does this you can use the optimization + * setting. In this case you just copy completely opaque pixels to the + * output. For pixels that are not completely transparent (you just skip + * those) you do the composition math using png_composite or png_composite_16 + * below then encode the resultant 8-bit or 16-bit values to match the output + * encoding. + * + * Other cases + * If neither the PNG nor the standard linear encoding work for you because + * of the software or hardware you use then you have a big problem. The PNG + * case will probably result in halos around the image. The linear encoding + * will probably result in a washed out, too bright, image (it's actually too + * contrasty.) Try the ALPHA_OPTIMIZED mode above - this will probably + * substantially reduce the halos. Alternatively try: + * + * png_set_alpha_mode(pp, PNG_ALPHA_BROKEN, PNG_DEFAULT_sRGB); + * This option will also reduce the halos, but there will be slight dark + * halos round the opaque parts of the image where the background is light. + * In the OPTIMIZED mode the halos will be light halos where the background + * is dark. Take your pick - the halos are unavoidable unless you can get + * your hardware/software fixed! (The OPTIMIZED approach is slightly + * faster.) + * + * When the default gamma of PNG files doesn't match the output gamma. + * If you have PNG files with no gamma information png_set_alpha_mode allows + * you to provide a default gamma, but it also sets the ouput gamma to the + * matching value. If you know your PNG files have a gamma that doesn't + * match the output you can take advantage of the fact that + * png_set_alpha_mode always sets the output gamma but only sets the PNG + * default if it is not already set: + * + * png_set_alpha_mode(pp, PNG_ALPHA_PNG, PNG_DEFAULT_sRGB); + * png_set_alpha_mode(pp, PNG_ALPHA_PNG, PNG_GAMMA_MAC); + * The first call sets both the default and the output gamma values, the + * second call overrides the output gamma without changing the default. This + * is easier than achieving the same effect with png_set_gamma. You must use + * PNG_ALPHA_PNG for the first call - internal checking in png_set_alpha will + * fire if more than one call to png_set_alpha_mode and png_set_background is + * made in the same read operation, however multiple calls with PNG_ALPHA_PNG + * are ignored. + */ + +#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED +PNG_EXPORT(36, void, png_set_strip_alpha, (png_structrp png_ptr)); +#endif + +#if defined(PNG_READ_SWAP_ALPHA_SUPPORTED) || \ + defined(PNG_WRITE_SWAP_ALPHA_SUPPORTED) +PNG_EXPORT(37, void, png_set_swap_alpha, (png_structrp png_ptr)); +#endif + +#if defined(PNG_READ_INVERT_ALPHA_SUPPORTED) || \ + defined(PNG_WRITE_INVERT_ALPHA_SUPPORTED) +PNG_EXPORT(38, void, png_set_invert_alpha, (png_structrp png_ptr)); +#endif + +#if defined(PNG_READ_FILLER_SUPPORTED) || defined(PNG_WRITE_FILLER_SUPPORTED) +/* Add a filler byte to 8-bit Gray or 24-bit RGB images. */ +PNG_EXPORT(39, void, png_set_filler, (png_structrp png_ptr, png_uint_32 filler, + int flags)); +/* The values of the PNG_FILLER_ defines should NOT be changed */ +# define PNG_FILLER_BEFORE 0 +# define PNG_FILLER_AFTER 1 +/* Add an alpha byte to 8-bit Gray or 24-bit RGB images. */ +PNG_EXPORT(40, void, png_set_add_alpha, (png_structrp png_ptr, + png_uint_32 filler, int flags)); +#endif /* PNG_READ_FILLER_SUPPORTED || PNG_WRITE_FILLER_SUPPORTED */ + +#if defined(PNG_READ_SWAP_SUPPORTED) || defined(PNG_WRITE_SWAP_SUPPORTED) +/* Swap bytes in 16-bit depth files. */ +PNG_EXPORT(41, void, png_set_swap, (png_structrp png_ptr)); +#endif + +#if defined(PNG_READ_PACK_SUPPORTED) || defined(PNG_WRITE_PACK_SUPPORTED) +/* Use 1 byte per pixel in 1, 2, or 4-bit depth files. */ +PNG_EXPORT(42, void, png_set_packing, (png_structrp png_ptr)); +#endif + +#if defined(PNG_READ_PACKSWAP_SUPPORTED) || \ + defined(PNG_WRITE_PACKSWAP_SUPPORTED) +/* Swap packing order of pixels in bytes. */ +PNG_EXPORT(43, void, png_set_packswap, (png_structrp png_ptr)); +#endif + +#if defined(PNG_READ_SHIFT_SUPPORTED) || defined(PNG_WRITE_SHIFT_SUPPORTED) +/* Converts files to legal bit depths. */ +PNG_EXPORT(44, void, png_set_shift, (png_structrp png_ptr, png_const_color_8p + true_bits)); +#endif + +#if defined(PNG_READ_INTERLACING_SUPPORTED) || \ + defined(PNG_WRITE_INTERLACING_SUPPORTED) +/* Have the code handle the interlacing. Returns the number of passes. + * MUST be called before png_read_update_info or png_start_read_image, + * otherwise it will not have the desired effect. Note that it is still + * necessary to call png_read_row or png_read_rows png_get_image_height + * times for each pass. +*/ +PNG_EXPORT(45, int, png_set_interlace_handling, (png_structrp png_ptr)); +#endif + +#if defined(PNG_READ_INVERT_SUPPORTED) || defined(PNG_WRITE_INVERT_SUPPORTED) +/* Invert monochrome files */ +PNG_EXPORT(46, void, png_set_invert_mono, (png_structrp png_ptr)); +#endif + +#ifdef PNG_READ_BACKGROUND_SUPPORTED +/* Handle alpha and tRNS by replacing with a background color. Prior to + * libpng-1.5.4 this API must not be called before the PNG file header has been + * read. Doing so will result in unexpected behavior and possible warnings or + * errors if the PNG file contains a bKGD chunk. + */ +PNG_FP_EXPORT(47, void, png_set_background, (png_structrp png_ptr, + png_const_color_16p background_color, int background_gamma_code, + int need_expand, double background_gamma)) +PNG_FIXED_EXPORT(215, void, png_set_background_fixed, (png_structrp png_ptr, + png_const_color_16p background_color, int background_gamma_code, + int need_expand, png_fixed_point background_gamma)) +#endif +#ifdef PNG_READ_BACKGROUND_SUPPORTED +# define PNG_BACKGROUND_GAMMA_UNKNOWN 0 +# define PNG_BACKGROUND_GAMMA_SCREEN 1 +# define PNG_BACKGROUND_GAMMA_FILE 2 +# define PNG_BACKGROUND_GAMMA_UNIQUE 3 +#endif + +#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED +/* Scale a 16-bit depth file down to 8-bit, accurately. */ +PNG_EXPORT(229, void, png_set_scale_16, (png_structrp png_ptr)); +#endif + +#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED +#define PNG_READ_16_TO_8 SUPPORTED /* Name prior to 1.5.4 */ +/* Strip the second byte of information from a 16-bit depth file. */ +PNG_EXPORT(48, void, png_set_strip_16, (png_structrp png_ptr)); +#endif + +#ifdef PNG_READ_QUANTIZE_SUPPORTED +/* Turn on quantizing, and reduce the palette to the number of colors + * available. + */ +PNG_EXPORT(49, void, png_set_quantize, (png_structrp png_ptr, + png_colorp palette, int num_palette, int maximum_colors, + png_const_uint_16p histogram, int full_quantize)); +#endif + +#ifdef PNG_READ_GAMMA_SUPPORTED +/* The threshold on gamma processing is configurable but hard-wired into the + * library. The following is the floating point variant. + */ +#define PNG_GAMMA_THRESHOLD (PNG_GAMMA_THRESHOLD_FIXED*.00001) + +/* Handle gamma correction. Screen_gamma=(display_exponent). + * NOTE: this API simply sets the screen and file gamma values. It will + * therefore override the value for gamma in a PNG file if it is called after + * the file header has been read - use with care - call before reading the PNG + * file for best results! + * + * These routines accept the same gamma values as png_set_alpha_mode (described + * above). The PNG_GAMMA_ defines and PNG_DEFAULT_sRGB can be passed to either + * API (floating point or fixed.) Notice, however, that the 'file_gamma' value + * is the inverse of a 'screen gamma' value. + */ +PNG_FP_EXPORT(50, void, png_set_gamma, (png_structrp png_ptr, + double screen_gamma, double override_file_gamma)) +PNG_FIXED_EXPORT(208, void, png_set_gamma_fixed, (png_structrp png_ptr, + png_fixed_point screen_gamma, png_fixed_point override_file_gamma)) +#endif + +#ifdef PNG_WRITE_FLUSH_SUPPORTED +/* Set how many lines between output flushes - 0 for no flushing */ +PNG_EXPORT(51, void, png_set_flush, (png_structrp png_ptr, int nrows)); +/* Flush the current PNG output buffer */ +PNG_EXPORT(52, void, png_write_flush, (png_structrp png_ptr)); +#endif + +/* Optional update palette with requested transformations */ +PNG_EXPORT(53, void, png_start_read_image, (png_structrp png_ptr)); + +/* Optional call to update the users info structure */ +PNG_EXPORT(54, void, png_read_update_info, (png_structrp png_ptr, + png_inforp info_ptr)); + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Read one or more rows of image data. */ +PNG_EXPORT(55, void, png_read_rows, (png_structrp png_ptr, png_bytepp row, + png_bytepp display_row, png_uint_32 num_rows)); +#endif + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Read a row of data. */ +PNG_EXPORT(56, void, png_read_row, (png_structrp png_ptr, png_bytep row, + png_bytep display_row)); +#endif + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Read the whole image into memory at once. */ +PNG_EXPORT(57, void, png_read_image, (png_structrp png_ptr, png_bytepp image)); +#endif + +/* Write a row of image data */ +PNG_EXPORT(58, void, png_write_row, (png_structrp png_ptr, + png_const_bytep row)); + +/* Write a few rows of image data: (*row) is not written; however, the type + * is declared as writeable to maintain compatibility with previous versions + * of libpng and to allow the 'display_row' array from read_rows to be passed + * unchanged to write_rows. + */ +PNG_EXPORT(59, void, png_write_rows, (png_structrp png_ptr, png_bytepp row, + png_uint_32 num_rows)); + +/* Write the image data */ +PNG_EXPORT(60, void, png_write_image, (png_structrp png_ptr, png_bytepp image)); + +/* Write the end of the PNG file. */ +PNG_EXPORT(61, void, png_write_end, (png_structrp png_ptr, + png_inforp info_ptr)); + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Read the end of the PNG file. */ +PNG_EXPORT(62, void, png_read_end, (png_structrp png_ptr, png_inforp info_ptr)); +#endif + +/* Free any memory associated with the png_info_struct */ +PNG_EXPORT(63, void, png_destroy_info_struct, (png_const_structrp png_ptr, + png_infopp info_ptr_ptr)); + +/* Free any memory associated with the png_struct and the png_info_structs */ +PNG_EXPORT(64, void, png_destroy_read_struct, (png_structpp png_ptr_ptr, + png_infopp info_ptr_ptr, png_infopp end_info_ptr_ptr)); + +/* Free any memory associated with the png_struct and the png_info_structs */ +PNG_EXPORT(65, void, png_destroy_write_struct, (png_structpp png_ptr_ptr, + png_infopp info_ptr_ptr)); + +/* Set the libpng method of handling chunk CRC errors */ +PNG_EXPORT(66, void, png_set_crc_action, (png_structrp png_ptr, int crit_action, + int ancil_action)); + +/* Values for png_set_crc_action() say how to handle CRC errors in + * ancillary and critical chunks, and whether to use the data contained + * therein. Note that it is impossible to "discard" data in a critical + * chunk. For versions prior to 0.90, the action was always error/quit, + * whereas in version 0.90 and later, the action for CRC errors in ancillary + * chunks is warn/discard. These values should NOT be changed. + * + * value action:critical action:ancillary + */ +#define PNG_CRC_DEFAULT 0 /* error/quit warn/discard data */ +#define PNG_CRC_ERROR_QUIT 1 /* error/quit error/quit */ +#define PNG_CRC_WARN_DISCARD 2 /* (INVALID) warn/discard data */ +#define PNG_CRC_WARN_USE 3 /* warn/use data warn/use data */ +#define PNG_CRC_QUIET_USE 4 /* quiet/use data quiet/use data */ +#define PNG_CRC_NO_CHANGE 5 /* use current value use current value */ + +/* These functions give the user control over the scan-line filtering in + * libpng and the compression methods used by zlib. These functions are + * mainly useful for testing, as the defaults should work with most users. + * Those users who are tight on memory or want faster performance at the + * expense of compression can modify them. See the compression library + * header file (zlib.h) for an explination of the compression functions. + */ + +/* Set the filtering method(s) used by libpng. Currently, the only valid + * value for "method" is 0. + */ +PNG_EXPORT(67, void, png_set_filter, (png_structrp png_ptr, int method, + int filters)); + +/* Flags for png_set_filter() to say which filters to use. The flags + * are chosen so that they don't conflict with real filter types + * below, in case they are supplied instead of the #defined constants. + * These values should NOT be changed. + */ +#define PNG_NO_FILTERS 0x00 +#define PNG_FILTER_NONE 0x08 +#define PNG_FILTER_SUB 0x10 +#define PNG_FILTER_UP 0x20 +#define PNG_FILTER_AVG 0x40 +#define PNG_FILTER_PAETH 0x80 +#define PNG_ALL_FILTERS (PNG_FILTER_NONE | PNG_FILTER_SUB | PNG_FILTER_UP | \ + PNG_FILTER_AVG | PNG_FILTER_PAETH) + +/* Filter values (not flags) - used in pngwrite.c, pngwutil.c for now. + * These defines should NOT be changed. + */ +#define PNG_FILTER_VALUE_NONE 0 +#define PNG_FILTER_VALUE_SUB 1 +#define PNG_FILTER_VALUE_UP 2 +#define PNG_FILTER_VALUE_AVG 3 +#define PNG_FILTER_VALUE_PAETH 4 +#define PNG_FILTER_VALUE_LAST 5 + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED /* EXPERIMENTAL */ +/* The "heuristic_method" is given by one of the PNG_FILTER_HEURISTIC_ + * defines, either the default (minimum-sum-of-absolute-differences), or + * the experimental method (weighted-minimum-sum-of-absolute-differences). + * + * Weights are factors >= 1.0, indicating how important it is to keep the + * filter type consistent between rows. Larger numbers mean the current + * filter is that many times as likely to be the same as the "num_weights" + * previous filters. This is cumulative for each previous row with a weight. + * There needs to be "num_weights" values in "filter_weights", or it can be + * NULL if the weights aren't being specified. Weights have no influence on + * the selection of the first row filter. Well chosen weights can (in theory) + * improve the compression for a given image. + * + * Costs are factors >= 1.0 indicating the relative decoding costs of a + * filter type. Higher costs indicate more decoding expense, and are + * therefore less likely to be selected over a filter with lower computational + * costs. There needs to be a value in "filter_costs" for each valid filter + * type (given by PNG_FILTER_VALUE_LAST), or it can be NULL if you aren't + * setting the costs. Costs try to improve the speed of decompression without + * unduly increasing the compressed image size. + * + * A negative weight or cost indicates the default value is to be used, and + * values in the range [0.0, 1.0) indicate the value is to remain unchanged. + * The default values for both weights and costs are currently 1.0, but may + * change if good general weighting/cost heuristics can be found. If both + * the weights and costs are set to 1.0, this degenerates the WEIGHTED method + * to the UNWEIGHTED method, but with added encoding time/computation. + */ +PNG_FP_EXPORT(68, void, png_set_filter_heuristics, (png_structrp png_ptr, + int heuristic_method, int num_weights, png_const_doublep filter_weights, + png_const_doublep filter_costs)) +PNG_FIXED_EXPORT(209, void, png_set_filter_heuristics_fixed, + (png_structrp png_ptr, int heuristic_method, int num_weights, + png_const_fixed_point_p filter_weights, + png_const_fixed_point_p filter_costs)) +#endif /* PNG_WRITE_WEIGHTED_FILTER_SUPPORTED */ + +/* Heuristic used for row filter selection. These defines should NOT be + * changed. + */ +#define PNG_FILTER_HEURISTIC_DEFAULT 0 /* Currently "UNWEIGHTED" */ +#define PNG_FILTER_HEURISTIC_UNWEIGHTED 1 /* Used by libpng < 0.95 */ +#define PNG_FILTER_HEURISTIC_WEIGHTED 2 /* Experimental feature */ +#define PNG_FILTER_HEURISTIC_LAST 3 /* Not a valid value */ + +#ifdef PNG_WRITE_SUPPORTED +/* Set the library compression level. Currently, valid values range from + * 0 - 9, corresponding directly to the zlib compression levels 0 - 9 + * (0 - no compression, 9 - "maximal" compression). Note that tests have + * shown that zlib compression levels 3-6 usually perform as well as level 9 + * for PNG images, and do considerably fewer caclulations. In the future, + * these values may not correspond directly to the zlib compression levels. + */ +PNG_EXPORT(69, void, png_set_compression_level, (png_structrp png_ptr, + int level)); + +PNG_EXPORT(70, void, png_set_compression_mem_level, (png_structrp png_ptr, + int mem_level)); + +PNG_EXPORT(71, void, png_set_compression_strategy, (png_structrp png_ptr, + int strategy)); + +/* If PNG_WRITE_OPTIMIZE_CMF_SUPPORTED is defined, libpng will use a + * smaller value of window_bits if it can do so safely. + */ +PNG_EXPORT(72, void, png_set_compression_window_bits, (png_structrp png_ptr, + int window_bits)); + +PNG_EXPORT(73, void, png_set_compression_method, (png_structrp png_ptr, + int method)); +#endif + +#ifdef PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED +/* Also set zlib parameters for compressing non-IDAT chunks */ +PNG_EXPORT(222, void, png_set_text_compression_level, (png_structrp png_ptr, + int level)); + +PNG_EXPORT(223, void, png_set_text_compression_mem_level, (png_structrp png_ptr, + int mem_level)); + +PNG_EXPORT(224, void, png_set_text_compression_strategy, (png_structrp png_ptr, + int strategy)); + +/* If PNG_WRITE_OPTIMIZE_CMF_SUPPORTED is defined, libpng will use a + * smaller value of window_bits if it can do so safely. + */ +PNG_EXPORT(225, void, png_set_text_compression_window_bits, + (png_structrp png_ptr, int window_bits)); + +PNG_EXPORT(226, void, png_set_text_compression_method, (png_structrp png_ptr, + int method)); +#endif /* PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED */ + +/* These next functions are called for input/output, memory, and error + * handling. They are in the file pngrio.c, pngwio.c, and pngerror.c, + * and call standard C I/O routines such as fread(), fwrite(), and + * fprintf(). These functions can be made to use other I/O routines + * at run time for those applications that need to handle I/O in a + * different manner by calling png_set_???_fn(). See libpng-manual.txt for + * more information. + */ + +#ifdef PNG_STDIO_SUPPORTED +/* Initialize the input/output for the PNG file to the default functions. */ +PNG_EXPORT(74, void, png_init_io, (png_structrp png_ptr, png_FILE_p fp)); +#endif + +/* Replace the (error and abort), and warning functions with user + * supplied functions. If no messages are to be printed you must still + * write and use replacement functions. The replacement error_fn should + * still do a longjmp to the last setjmp location if you are using this + * method of error handling. If error_fn or warning_fn is NULL, the + * default function will be used. + */ + +PNG_EXPORT(75, void, png_set_error_fn, (png_structrp png_ptr, + png_voidp error_ptr, png_error_ptr error_fn, png_error_ptr warning_fn)); + +/* Return the user pointer associated with the error functions */ +PNG_EXPORT(76, png_voidp, png_get_error_ptr, (png_const_structrp png_ptr)); + +/* Replace the default data output functions with a user supplied one(s). + * If buffered output is not used, then output_flush_fn can be set to NULL. + * If PNG_WRITE_FLUSH_SUPPORTED is not defined at libpng compile time + * output_flush_fn will be ignored (and thus can be NULL). + * It is probably a mistake to use NULL for output_flush_fn if + * write_data_fn is not also NULL unless you have built libpng with + * PNG_WRITE_FLUSH_SUPPORTED undefined, because in this case libpng's + * default flush function, which uses the standard *FILE structure, will + * be used. + */ +PNG_EXPORT(77, void, png_set_write_fn, (png_structrp png_ptr, png_voidp io_ptr, + png_rw_ptr write_data_fn, png_flush_ptr output_flush_fn)); + +/* Replace the default data input function with a user supplied one. */ +PNG_EXPORT(78, void, png_set_read_fn, (png_structrp png_ptr, png_voidp io_ptr, + png_rw_ptr read_data_fn)); + +/* Return the user pointer associated with the I/O functions */ +PNG_EXPORT(79, png_voidp, png_get_io_ptr, (png_const_structrp png_ptr)); + +PNG_EXPORT(80, void, png_set_read_status_fn, (png_structrp png_ptr, + png_read_status_ptr read_row_fn)); + +PNG_EXPORT(81, void, png_set_write_status_fn, (png_structrp png_ptr, + png_write_status_ptr write_row_fn)); + +#ifdef PNG_USER_MEM_SUPPORTED +/* Replace the default memory allocation functions with user supplied one(s). */ +PNG_EXPORT(82, void, png_set_mem_fn, (png_structrp png_ptr, png_voidp mem_ptr, + png_malloc_ptr malloc_fn, png_free_ptr free_fn)); +/* Return the user pointer associated with the memory functions */ +PNG_EXPORT(83, png_voidp, png_get_mem_ptr, (png_const_structrp png_ptr)); +#endif + +#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED +PNG_EXPORT(84, void, png_set_read_user_transform_fn, (png_structrp png_ptr, + png_user_transform_ptr read_user_transform_fn)); +#endif + +#ifdef PNG_WRITE_USER_TRANSFORM_SUPPORTED +PNG_EXPORT(85, void, png_set_write_user_transform_fn, (png_structrp png_ptr, + png_user_transform_ptr write_user_transform_fn)); +#endif + +#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED +PNG_EXPORT(86, void, png_set_user_transform_info, (png_structrp png_ptr, + png_voidp user_transform_ptr, int user_transform_depth, + int user_transform_channels)); +/* Return the user pointer associated with the user transform functions */ +PNG_EXPORT(87, png_voidp, png_get_user_transform_ptr, + (png_const_structrp png_ptr)); +#endif + +#ifdef PNG_USER_TRANSFORM_INFO_SUPPORTED +/* Return information about the row currently being processed. Note that these + * APIs do not fail but will return unexpected results if called outside a user + * transform callback. Also note that when transforming an interlaced image the + * row number is the row number within the sub-image of the interlace pass, so + * the value will increase to the height of the sub-image (not the full image) + * then reset to 0 for the next pass. + * + * Use PNG_ROW_FROM_PASS_ROW(row, pass) and PNG_COL_FROM_PASS_COL(col, pass) to + * find the output pixel (x,y) given an interlaced sub-image pixel + * (row,col,pass). (See below for these macros.) + */ +PNG_EXPORT(217, png_uint_32, png_get_current_row_number, (png_const_structrp)); +PNG_EXPORT(218, png_byte, png_get_current_pass_number, (png_const_structrp)); +#endif + +#ifdef PNG_READ_USER_CHUNKS_SUPPORTED +/* This callback is called only for *unknown* chunks. If + * PNG_HANDLE_AS_UNKNOWN_SUPPORTED is set then it is possible to set known + * chunks to be treated as unknown, however in this case the callback must do + * any processing required by the chunk (e.g. by calling the appropriate + * png_set_ APIs.) + * + * There is no write support - on write, by default, all the chunks in the + * 'unknown' list are written in the specified position. + * + * The integer return from the callback function is interpreted thus: + * + * negative: An error occured, png_chunk_error will be called. + * zero: The chunk was not handled, the chunk will be saved. A critical + * chunk will cause an error at this point unless it is to be saved. + * positive: The chunk was handled, libpng will ignore/discard it. + * + * See "INTERACTION WTIH USER CHUNK CALLBACKS" below for important notes about + * how this behavior will change in libpng 1.7 + */ +PNG_EXPORT(88, void, png_set_read_user_chunk_fn, (png_structrp png_ptr, + png_voidp user_chunk_ptr, png_user_chunk_ptr read_user_chunk_fn)); +#endif + +#ifdef PNG_USER_CHUNKS_SUPPORTED +PNG_EXPORT(89, png_voidp, png_get_user_chunk_ptr, (png_const_structrp png_ptr)); +#endif + +#ifdef PNG_PROGRESSIVE_READ_SUPPORTED +/* Sets the function callbacks for the push reader, and a pointer to a + * user-defined structure available to the callback functions. + */ +PNG_EXPORT(90, void, png_set_progressive_read_fn, (png_structrp png_ptr, + png_voidp progressive_ptr, png_progressive_info_ptr info_fn, + png_progressive_row_ptr row_fn, png_progressive_end_ptr end_fn)); + +/* Returns the user pointer associated with the push read functions */ +PNG_EXPORT(91, png_voidp, png_get_progressive_ptr, + (png_const_structrp png_ptr)); + +/* Function to be called when data becomes available */ +PNG_EXPORT(92, void, png_process_data, (png_structrp png_ptr, + png_inforp info_ptr, png_bytep buffer, png_size_t buffer_size)); + +/* A function which may be called *only* within png_process_data to stop the + * processing of any more data. The function returns the number of bytes + * remaining, excluding any that libpng has cached internally. A subsequent + * call to png_process_data must supply these bytes again. If the argument + * 'save' is set to true the routine will first save all the pending data and + * will always return 0. + */ +PNG_EXPORT(219, png_size_t, png_process_data_pause, (png_structrp, int save)); + +/* A function which may be called *only* outside (after) a call to + * png_process_data. It returns the number of bytes of data to skip in the + * input. Normally it will return 0, but if it returns a non-zero value the + * application must skip than number of bytes of input data and pass the + * following data to the next call to png_process_data. + */ +PNG_EXPORT(220, png_uint_32, png_process_data_skip, (png_structrp)); + +#ifdef PNG_READ_INTERLACING_SUPPORTED +/* Function that combines rows. 'new_row' is a flag that should come from + * the callback and be non-NULL if anything needs to be done; the library + * stores its own version of the new data internally and ignores the passed + * in value. + */ +PNG_EXPORT(93, void, png_progressive_combine_row, (png_const_structrp png_ptr, + png_bytep old_row, png_const_bytep new_row)); +#endif /* PNG_READ_INTERLACING_SUPPORTED */ +#endif /* PNG_PROGRESSIVE_READ_SUPPORTED */ + +PNG_EXPORTA(94, png_voidp, png_malloc, (png_const_structrp png_ptr, + png_alloc_size_t size), PNG_ALLOCATED); +/* Added at libpng version 1.4.0 */ +PNG_EXPORTA(95, png_voidp, png_calloc, (png_const_structrp png_ptr, + png_alloc_size_t size), PNG_ALLOCATED); + +/* Added at libpng version 1.2.4 */ +PNG_EXPORTA(96, png_voidp, png_malloc_warn, (png_const_structrp png_ptr, + png_alloc_size_t size), PNG_ALLOCATED); + +/* Frees a pointer allocated by png_malloc() */ +PNG_EXPORT(97, void, png_free, (png_const_structrp png_ptr, png_voidp ptr)); + +/* Free data that was allocated internally */ +PNG_EXPORT(98, void, png_free_data, (png_const_structrp png_ptr, + png_inforp info_ptr, png_uint_32 free_me, int num)); + +/* Reassign responsibility for freeing existing data, whether allocated + * by libpng or by the application; this works on the png_info structure passed + * in, it does not change the state for other png_info structures. + * + * It is unlikely that this function works correctly as of 1.6.0 and using it + * may result either in memory leaks or double free of allocated data. + */ +PNG_EXPORTA(99, void, png_data_freer, (png_const_structrp png_ptr, + png_inforp info_ptr, int freer, png_uint_32 mask), PNG_DEPRECATED); + +/* Assignments for png_data_freer */ +#define PNG_DESTROY_WILL_FREE_DATA 1 +#define PNG_SET_WILL_FREE_DATA 1 +#define PNG_USER_WILL_FREE_DATA 2 +/* Flags for png_ptr->free_me and info_ptr->free_me */ +#define PNG_FREE_HIST 0x0008 +#define PNG_FREE_ICCP 0x0010 +#define PNG_FREE_SPLT 0x0020 +#define PNG_FREE_ROWS 0x0040 +#define PNG_FREE_PCAL 0x0080 +#define PNG_FREE_SCAL 0x0100 +#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED +# define PNG_FREE_UNKN 0x0200 +#endif +/* PNG_FREE_LIST 0x0400 removed in 1.6.0 because it is ignored */ +#define PNG_FREE_PLTE 0x1000 +#define PNG_FREE_TRNS 0x2000 +#define PNG_FREE_TEXT 0x4000 +#define PNG_FREE_ALL 0x7fff +#define PNG_FREE_MUL 0x4220 /* PNG_FREE_SPLT|PNG_FREE_TEXT|PNG_FREE_UNKN */ + +#ifdef PNG_USER_MEM_SUPPORTED +PNG_EXPORTA(100, png_voidp, png_malloc_default, (png_const_structrp png_ptr, + png_alloc_size_t size), PNG_ALLOCATED PNG_DEPRECATED); +PNG_EXPORTA(101, void, png_free_default, (png_const_structrp png_ptr, + png_voidp ptr), PNG_DEPRECATED); +#endif + +#ifdef PNG_ERROR_TEXT_SUPPORTED +/* Fatal error in PNG image of libpng - can't continue */ +PNG_EXPORTA(102, void, png_error, (png_const_structrp png_ptr, + png_const_charp error_message), PNG_NORETURN); + +/* The same, but the chunk name is prepended to the error string. */ +PNG_EXPORTA(103, void, png_chunk_error, (png_const_structrp png_ptr, + png_const_charp error_message), PNG_NORETURN); + +#else +/* Fatal error in PNG image of libpng - can't continue */ +PNG_EXPORTA(104, void, png_err, (png_const_structrp png_ptr), PNG_NORETURN); +#endif + +#ifdef PNG_WARNINGS_SUPPORTED +/* Non-fatal error in libpng. Can continue, but may have a problem. */ +PNG_EXPORT(105, void, png_warning, (png_const_structrp png_ptr, + png_const_charp warning_message)); + +/* Non-fatal error in libpng, chunk name is prepended to message. */ +PNG_EXPORT(106, void, png_chunk_warning, (png_const_structrp png_ptr, + png_const_charp warning_message)); +#endif + +#ifdef PNG_BENIGN_ERRORS_SUPPORTED +/* Benign error in libpng. Can continue, but may have a problem. + * User can choose whether to handle as a fatal error or as a warning. */ +PNG_EXPORT(107, void, png_benign_error, (png_const_structrp png_ptr, + png_const_charp warning_message)); + +#ifdef PNG_READ_SUPPORTED +/* Same, chunk name is prepended to message (only during read) */ +PNG_EXPORT(108, void, png_chunk_benign_error, (png_const_structrp png_ptr, + png_const_charp warning_message)); +#endif + +PNG_EXPORT(109, void, png_set_benign_errors, + (png_structrp png_ptr, int allowed)); +#else +# ifdef PNG_ALLOW_BENIGN_ERRORS +# define png_benign_error png_warning +# define png_chunk_benign_error png_chunk_warning +# else +# define png_benign_error png_error +# define png_chunk_benign_error png_chunk_error +# endif +#endif + +/* The png_set_ functions are for storing values in the png_info_struct. + * Similarly, the png_get_ calls are used to read values from the + * png_info_struct, either storing the parameters in the passed variables, or + * setting pointers into the png_info_struct where the data is stored. The + * png_get_ functions return a non-zero value if the data was available + * in info_ptr, or return zero and do not change any of the parameters if the + * data was not available. + * + * These functions should be used instead of directly accessing png_info + * to avoid problems with future changes in the size and internal layout of + * png_info_struct. + */ +/* Returns "flag" if chunk data is valid in info_ptr. */ +PNG_EXPORT(110, png_uint_32, png_get_valid, (png_const_structrp png_ptr, + png_const_inforp info_ptr, png_uint_32 flag)); + +/* Returns number of bytes needed to hold a transformed row. */ +PNG_EXPORT(111, png_size_t, png_get_rowbytes, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +#ifdef PNG_INFO_IMAGE_SUPPORTED +/* Returns row_pointers, which is an array of pointers to scanlines that was + * returned from png_read_png(). + */ +PNG_EXPORT(112, png_bytepp, png_get_rows, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +/* Set row_pointers, which is an array of pointers to scanlines for use + * by png_write_png(). + */ +PNG_EXPORT(113, void, png_set_rows, (png_const_structrp png_ptr, + png_inforp info_ptr, png_bytepp row_pointers)); +#endif + +/* Returns number of color channels in image. */ +PNG_EXPORT(114, png_byte, png_get_channels, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +#ifdef PNG_EASY_ACCESS_SUPPORTED +/* Returns image width in pixels. */ +PNG_EXPORT(115, png_uint_32, png_get_image_width, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +/* Returns image height in pixels. */ +PNG_EXPORT(116, png_uint_32, png_get_image_height, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +/* Returns image bit_depth. */ +PNG_EXPORT(117, png_byte, png_get_bit_depth, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +/* Returns image color_type. */ +PNG_EXPORT(118, png_byte, png_get_color_type, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +/* Returns image filter_type. */ +PNG_EXPORT(119, png_byte, png_get_filter_type, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +/* Returns image interlace_type. */ +PNG_EXPORT(120, png_byte, png_get_interlace_type, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +/* Returns image compression_type. */ +PNG_EXPORT(121, png_byte, png_get_compression_type, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); + +/* Returns image resolution in pixels per meter, from pHYs chunk data. */ +PNG_EXPORT(122, png_uint_32, png_get_pixels_per_meter, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); +PNG_EXPORT(123, png_uint_32, png_get_x_pixels_per_meter, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); +PNG_EXPORT(124, png_uint_32, png_get_y_pixels_per_meter, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); + +/* Returns pixel aspect ratio, computed from pHYs chunk data. */ +PNG_FP_EXPORT(125, float, png_get_pixel_aspect_ratio, + (png_const_structrp png_ptr, png_const_inforp info_ptr)) +PNG_FIXED_EXPORT(210, png_fixed_point, png_get_pixel_aspect_ratio_fixed, + (png_const_structrp png_ptr, png_const_inforp info_ptr)) + +/* Returns image x, y offset in pixels or microns, from oFFs chunk data. */ +PNG_EXPORT(126, png_int_32, png_get_x_offset_pixels, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); +PNG_EXPORT(127, png_int_32, png_get_y_offset_pixels, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); +PNG_EXPORT(128, png_int_32, png_get_x_offset_microns, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); +PNG_EXPORT(129, png_int_32, png_get_y_offset_microns, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); + +#endif /* PNG_EASY_ACCESS_SUPPORTED */ + +#ifdef PNG_READ_SUPPORTED +/* Returns pointer to signature string read from PNG header */ +PNG_EXPORT(130, png_const_bytep, png_get_signature, (png_const_structrp png_ptr, + png_const_inforp info_ptr)); +#endif + +#ifdef PNG_bKGD_SUPPORTED +PNG_EXPORT(131, png_uint_32, png_get_bKGD, (png_const_structrp png_ptr, + png_inforp info_ptr, png_color_16p *background)); +#endif + +#ifdef PNG_bKGD_SUPPORTED +PNG_EXPORT(132, void, png_set_bKGD, (png_const_structrp png_ptr, + png_inforp info_ptr, png_const_color_16p background)); +#endif + +#ifdef PNG_cHRM_SUPPORTED +PNG_FP_EXPORT(133, png_uint_32, png_get_cHRM, (png_const_structrp png_ptr, + png_const_inforp info_ptr, double *white_x, double *white_y, double *red_x, + double *red_y, double *green_x, double *green_y, double *blue_x, + double *blue_y)) +PNG_FP_EXPORT(230, png_uint_32, png_get_cHRM_XYZ, (png_const_structrp png_ptr, + png_const_inforp info_ptr, double *red_X, double *red_Y, double *red_Z, + double *green_X, double *green_Y, double *green_Z, double *blue_X, + double *blue_Y, double *blue_Z)) +PNG_FIXED_EXPORT(134, png_uint_32, png_get_cHRM_fixed, + (png_const_structrp png_ptr, png_const_inforp info_ptr, + png_fixed_point *int_white_x, png_fixed_point *int_white_y, + png_fixed_point *int_red_x, png_fixed_point *int_red_y, + png_fixed_point *int_green_x, png_fixed_point *int_green_y, + png_fixed_point *int_blue_x, png_fixed_point *int_blue_y)) +PNG_FIXED_EXPORT(231, png_uint_32, png_get_cHRM_XYZ_fixed, + (png_const_structrp png_ptr, png_const_inforp info_ptr, + png_fixed_point *int_red_X, png_fixed_point *int_red_Y, + png_fixed_point *int_red_Z, png_fixed_point *int_green_X, + png_fixed_point *int_green_Y, png_fixed_point *int_green_Z, + png_fixed_point *int_blue_X, png_fixed_point *int_blue_Y, + png_fixed_point *int_blue_Z)) +#endif + +#ifdef PNG_cHRM_SUPPORTED +PNG_FP_EXPORT(135, void, png_set_cHRM, (png_const_structrp png_ptr, + png_inforp info_ptr, + double white_x, double white_y, double red_x, double red_y, double green_x, + double green_y, double blue_x, double blue_y)) +PNG_FP_EXPORT(232, void, png_set_cHRM_XYZ, (png_const_structrp png_ptr, + png_inforp info_ptr, double red_X, double red_Y, double red_Z, + double green_X, double green_Y, double green_Z, double blue_X, + double blue_Y, double blue_Z)) +PNG_FIXED_EXPORT(136, void, png_set_cHRM_fixed, (png_const_structrp png_ptr, + png_inforp info_ptr, png_fixed_point int_white_x, + png_fixed_point int_white_y, png_fixed_point int_red_x, + png_fixed_point int_red_y, png_fixed_point int_green_x, + png_fixed_point int_green_y, png_fixed_point int_blue_x, + png_fixed_point int_blue_y)) +PNG_FIXED_EXPORT(233, void, png_set_cHRM_XYZ_fixed, (png_const_structrp png_ptr, + png_inforp info_ptr, png_fixed_point int_red_X, png_fixed_point int_red_Y, + png_fixed_point int_red_Z, png_fixed_point int_green_X, + png_fixed_point int_green_Y, png_fixed_point int_green_Z, + png_fixed_point int_blue_X, png_fixed_point int_blue_Y, + png_fixed_point int_blue_Z)) +#endif + +#ifdef PNG_gAMA_SUPPORTED +PNG_FP_EXPORT(137, png_uint_32, png_get_gAMA, (png_const_structrp png_ptr, + png_const_inforp info_ptr, double *file_gamma)) +PNG_FIXED_EXPORT(138, png_uint_32, png_get_gAMA_fixed, + (png_const_structrp png_ptr, png_const_inforp info_ptr, + png_fixed_point *int_file_gamma)) +#endif + +#ifdef PNG_gAMA_SUPPORTED +PNG_FP_EXPORT(139, void, png_set_gAMA, (png_const_structrp png_ptr, + png_inforp info_ptr, double file_gamma)) +PNG_FIXED_EXPORT(140, void, png_set_gAMA_fixed, (png_const_structrp png_ptr, + png_inforp info_ptr, png_fixed_point int_file_gamma)) +#endif + +#ifdef PNG_hIST_SUPPORTED +PNG_EXPORT(141, png_uint_32, png_get_hIST, (png_const_structrp png_ptr, + png_inforp info_ptr, png_uint_16p *hist)); +#endif + +#ifdef PNG_hIST_SUPPORTED +PNG_EXPORT(142, void, png_set_hIST, (png_const_structrp png_ptr, + png_inforp info_ptr, png_const_uint_16p hist)); +#endif + +PNG_EXPORT(143, png_uint_32, png_get_IHDR, (png_const_structrp png_ptr, + png_const_inforp info_ptr, png_uint_32 *width, png_uint_32 *height, + int *bit_depth, int *color_type, int *interlace_method, + int *compression_method, int *filter_method)); + +PNG_EXPORT(144, void, png_set_IHDR, (png_const_structrp png_ptr, + png_inforp info_ptr, png_uint_32 width, png_uint_32 height, int bit_depth, + int color_type, int interlace_method, int compression_method, + int filter_method)); + +#ifdef PNG_oFFs_SUPPORTED +PNG_EXPORT(145, png_uint_32, png_get_oFFs, (png_const_structrp png_ptr, + png_const_inforp info_ptr, png_int_32 *offset_x, png_int_32 *offset_y, + int *unit_type)); +#endif + +#ifdef PNG_oFFs_SUPPORTED +PNG_EXPORT(146, void, png_set_oFFs, (png_const_structrp png_ptr, + png_inforp info_ptr, png_int_32 offset_x, png_int_32 offset_y, + int unit_type)); +#endif + +#ifdef PNG_pCAL_SUPPORTED +PNG_EXPORT(147, png_uint_32, png_get_pCAL, (png_const_structrp png_ptr, + png_inforp info_ptr, png_charp *purpose, png_int_32 *X0, + png_int_32 *X1, int *type, int *nparams, png_charp *units, + png_charpp *params)); +#endif + +#ifdef PNG_pCAL_SUPPORTED +PNG_EXPORT(148, void, png_set_pCAL, (png_const_structrp png_ptr, + png_inforp info_ptr, png_const_charp purpose, png_int_32 X0, png_int_32 X1, + int type, int nparams, png_const_charp units, png_charpp params)); +#endif + +#ifdef PNG_pHYs_SUPPORTED +PNG_EXPORT(149, png_uint_32, png_get_pHYs, (png_const_structrp png_ptr, + png_const_inforp info_ptr, png_uint_32 *res_x, png_uint_32 *res_y, + int *unit_type)); +#endif + +#ifdef PNG_pHYs_SUPPORTED +PNG_EXPORT(150, void, png_set_pHYs, (png_const_structrp png_ptr, + png_inforp info_ptr, png_uint_32 res_x, png_uint_32 res_y, int unit_type)); +#endif + +PNG_EXPORT(151, png_uint_32, png_get_PLTE, (png_const_structrp png_ptr, + png_inforp info_ptr, png_colorp *palette, int *num_palette)); + +PNG_EXPORT(152, void, png_set_PLTE, (png_structrp png_ptr, + png_inforp info_ptr, png_const_colorp palette, int num_palette)); + +#ifdef PNG_sBIT_SUPPORTED +PNG_EXPORT(153, png_uint_32, png_get_sBIT, (png_const_structrp png_ptr, + png_inforp info_ptr, png_color_8p *sig_bit)); +#endif + +#ifdef PNG_sBIT_SUPPORTED +PNG_EXPORT(154, void, png_set_sBIT, (png_const_structrp png_ptr, + png_inforp info_ptr, png_const_color_8p sig_bit)); +#endif + +#ifdef PNG_sRGB_SUPPORTED +PNG_EXPORT(155, png_uint_32, png_get_sRGB, (png_const_structrp png_ptr, + png_const_inforp info_ptr, int *file_srgb_intent)); +#endif + +#ifdef PNG_sRGB_SUPPORTED +PNG_EXPORT(156, void, png_set_sRGB, (png_const_structrp png_ptr, + png_inforp info_ptr, int srgb_intent)); +PNG_EXPORT(157, void, png_set_sRGB_gAMA_and_cHRM, (png_const_structrp png_ptr, + png_inforp info_ptr, int srgb_intent)); +#endif + +#ifdef PNG_iCCP_SUPPORTED +PNG_EXPORT(158, png_uint_32, png_get_iCCP, (png_const_structrp png_ptr, + png_inforp info_ptr, png_charpp name, int *compression_type, + png_bytepp profile, png_uint_32 *proflen)); +#endif + +#ifdef PNG_iCCP_SUPPORTED +PNG_EXPORT(159, void, png_set_iCCP, (png_const_structrp png_ptr, + png_inforp info_ptr, png_const_charp name, int compression_type, + png_const_bytep profile, png_uint_32 proflen)); +#endif + +#ifdef PNG_sPLT_SUPPORTED +PNG_EXPORT(160, int, png_get_sPLT, (png_const_structrp png_ptr, + png_inforp info_ptr, png_sPLT_tpp entries)); +#endif + +#ifdef PNG_sPLT_SUPPORTED +PNG_EXPORT(161, void, png_set_sPLT, (png_const_structrp png_ptr, + png_inforp info_ptr, png_const_sPLT_tp entries, int nentries)); +#endif + +#ifdef PNG_TEXT_SUPPORTED +/* png_get_text also returns the number of text chunks in *num_text */ +PNG_EXPORT(162, int, png_get_text, (png_const_structrp png_ptr, + png_inforp info_ptr, png_textp *text_ptr, int *num_text)); +#endif + +/* Note while png_set_text() will accept a structure whose text, + * language, and translated keywords are NULL pointers, the structure + * returned by png_get_text will always contain regular + * zero-terminated C strings. They might be empty strings but + * they will never be NULL pointers. + */ + +#ifdef PNG_TEXT_SUPPORTED +PNG_EXPORT(163, void, png_set_text, (png_const_structrp png_ptr, + png_inforp info_ptr, png_const_textp text_ptr, int num_text)); +#endif + +#ifdef PNG_tIME_SUPPORTED +PNG_EXPORT(164, png_uint_32, png_get_tIME, (png_const_structrp png_ptr, + png_inforp info_ptr, png_timep *mod_time)); +#endif + +#ifdef PNG_tIME_SUPPORTED +PNG_EXPORT(165, void, png_set_tIME, (png_const_structrp png_ptr, + png_inforp info_ptr, png_const_timep mod_time)); +#endif + +#ifdef PNG_tRNS_SUPPORTED +PNG_EXPORT(166, png_uint_32, png_get_tRNS, (png_const_structrp png_ptr, + png_inforp info_ptr, png_bytep *trans_alpha, int *num_trans, + png_color_16p *trans_color)); +#endif + +#ifdef PNG_tRNS_SUPPORTED +PNG_EXPORT(167, void, png_set_tRNS, (png_structrp png_ptr, + png_inforp info_ptr, png_const_bytep trans_alpha, int num_trans, + png_const_color_16p trans_color)); +#endif + +#ifdef PNG_sCAL_SUPPORTED +PNG_FP_EXPORT(168, png_uint_32, png_get_sCAL, (png_const_structrp png_ptr, + png_const_inforp info_ptr, int *unit, double *width, double *height)) +#if defined(PNG_FLOATING_ARITHMETIC_SUPPORTED) || \ + defined(PNG_FLOATING_POINT_SUPPORTED) +/* NOTE: this API is currently implemented using floating point arithmetic, + * consequently it can only be used on systems with floating point support. + * In any case the range of values supported by png_fixed_point is small and it + * is highly recommended that png_get_sCAL_s be used instead. + */ +PNG_FIXED_EXPORT(214, png_uint_32, png_get_sCAL_fixed, + (png_const_structrp png_ptr, png_const_inforp info_ptr, int *unit, + png_fixed_point *width, png_fixed_point *height)) +#endif +PNG_EXPORT(169, png_uint_32, png_get_sCAL_s, + (png_const_structrp png_ptr, png_const_inforp info_ptr, int *unit, + png_charpp swidth, png_charpp sheight)); + +PNG_FP_EXPORT(170, void, png_set_sCAL, (png_const_structrp png_ptr, + png_inforp info_ptr, int unit, double width, double height)) +PNG_FIXED_EXPORT(213, void, png_set_sCAL_fixed, (png_const_structrp png_ptr, + png_inforp info_ptr, int unit, png_fixed_point width, + png_fixed_point height)) +PNG_EXPORT(171, void, png_set_sCAL_s, (png_const_structrp png_ptr, + png_inforp info_ptr, int unit, + png_const_charp swidth, png_const_charp sheight)); +#endif /* PNG_sCAL_SUPPORTED */ + +#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED +/* Provide the default handling for all unknown chunks or, optionally, for + * specific unknown chunks. + * + * NOTE: prior to 1.6.0 the handling specified for particular chunks on read was + * ignored and the default was used, the per-chunk setting only had an effect on + * write. If you wish to have chunk-specific handling on read in code that must + * work on earlier versions you must use a user chunk callback to specify the + * desired handling (keep or discard.) + * + * The 'keep' parameter is a PNG_HANDLE_CHUNK_ value as listed below. The + * parameter is interpreted as follows: + * + * READ: + * PNG_HANDLE_CHUNK_AS_DEFAULT: + * Known chunks: do normal libpng processing, do not keep the chunk (but + * see the comments below about PNG_HANDLE_AS_UNKNOWN_SUPPORTED) + * Unknown chunks: for a specific chunk use the global default, when used + * as the default discard the chunk data. + * PNG_HANDLE_CHUNK_NEVER: + * Discard the chunk data. + * PNG_HANDLE_CHUNK_IF_SAFE: + * Keep the chunk data if the chunk is not critical else raise a chunk + * error. + * PNG_HANDLE_CHUNK_ALWAYS: + * Keep the chunk data. + * + * If the chunk data is saved it can be retrieved using png_get_unknown_chunks, + * below. Notice that specifying "AS_DEFAULT" as a global default is equivalent + * to specifying "NEVER", however when "AS_DEFAULT" is used for specific chunks + * it simply resets the behavior to the libpng default. + * + * INTERACTION WTIH USER CHUNK CALLBACKS: + * The per-chunk handling is always used when there is a png_user_chunk_ptr + * callback and the callback returns 0; the chunk is then always stored *unless* + * it is critical and the per-chunk setting is other than ALWAYS. Notice that + * the global default is *not* used in this case. (In effect the per-chunk + * value is incremented to at least IF_SAFE.) + * + * IMPORTANT NOTE: this behavior will change in libpng 1.7 - the global and + * per-chunk defaults will be honored. If you want to preserve the current + * behavior when your callback returns 0 you must set PNG_HANDLE_CHUNK_IF_SAFE + * as the default - if you don't do this libpng 1.6 will issue a warning. + * + * If you want unhandled unknown chunks to be discarded in libpng 1.6 and + * earlier simply return '1' (handled). + * + * PNG_HANDLE_AS_UNKNOWN_SUPPORTED: + * If this is *not* set known chunks will always be handled by libpng and + * will never be stored in the unknown chunk list. Known chunks listed to + * png_set_keep_unknown_chunks will have no effect. If it is set then known + * chunks listed with a keep other than AS_DEFAULT will *never* be processed + * by libpng, in addition critical chunks must either be processed by the + * callback or saved. + * + * The IHDR and IEND chunks must not be listed. Because this turns off the + * default handling for chunks that would otherwise be recognized the + * behavior of libpng transformations may well become incorrect! + * + * WRITE: + * When writing chunks the options only apply to the chunks specified by + * png_set_unknown_chunks (below), libpng will *always* write known chunks + * required by png_set_ calls and will always write the core critical chunks + * (as required for PLTE). + * + * Each chunk in the png_set_unknown_chunks list is looked up in the + * png_set_keep_unknown_chunks list to find the keep setting, this is then + * interpreted as follows: + * + * PNG_HANDLE_CHUNK_AS_DEFAULT: + * Write safe-to-copy chunks and write other chunks if the global + * default is set to _ALWAYS, otherwise don't write this chunk. + * PNG_HANDLE_CHUNK_NEVER: + * Do not write the chunk. + * PNG_HANDLE_CHUNK_IF_SAFE: + * Write the chunk if it is safe-to-copy, otherwise do not write it. + * PNG_HANDLE_CHUNK_ALWAYS: + * Write the chunk. + * + * Note that the default behavior is effectively the opposite of the read case - + * in read unknown chunks are not stored by default, in write they are written + * by default. Also the behavior of PNG_HANDLE_CHUNK_IF_SAFE is very different + * - on write the safe-to-copy bit is checked, on read the critical bit is + * checked and on read if the chunk is critical an error will be raised. + * + * num_chunks: + * =========== + * If num_chunks is positive, then the "keep" parameter specifies the manner + * for handling only those chunks appearing in the chunk_list array, + * otherwise the chunk list array is ignored. + * + * If num_chunks is 0 the "keep" parameter specifies the default behavior for + * unknown chunks, as described above. + * + * If num_chunks is negative, then the "keep" parameter specifies the manner + * for handling all unknown chunks plus all chunks recognized by libpng + * except for the IHDR, PLTE, tRNS, IDAT, and IEND chunks (which continue to + * be processed by libpng. + */ +PNG_EXPORT(172, void, png_set_keep_unknown_chunks, (png_structrp png_ptr, + int keep, png_const_bytep chunk_list, int num_chunks)); + +/* The "keep" PNG_HANDLE_CHUNK_ parameter for the specified chunk is returned; + * the result is therefore true (non-zero) if special handling is required, + * false for the default handling. + */ +PNG_EXPORT(173, int, png_handle_as_unknown, (png_const_structrp png_ptr, + png_const_bytep chunk_name)); +#endif + +#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED +PNG_EXPORT(174, void, png_set_unknown_chunks, (png_const_structrp png_ptr, + png_inforp info_ptr, png_const_unknown_chunkp unknowns, + int num_unknowns)); + /* NOTE: prior to 1.6.0 this routine set the 'location' field of the added + * unknowns to the location currently stored in the png_struct. This is + * invariably the wrong value on write. To fix this call the following API + * for each chunk in the list with the correct location. If you know your + * code won't be compiled on earlier versions you can rely on + * png_set_unknown_chunks(write-ptr, png_get_unknown_chunks(read-ptr)) doing + * the correct thing. + */ + +PNG_EXPORT(175, void, png_set_unknown_chunk_location, + (png_const_structrp png_ptr, png_inforp info_ptr, int chunk, int location)); + +PNG_EXPORT(176, int, png_get_unknown_chunks, (png_const_structrp png_ptr, + png_inforp info_ptr, png_unknown_chunkpp entries)); +#endif + +/* Png_free_data() will turn off the "valid" flag for anything it frees. + * If you need to turn it off for a chunk that your application has freed, + * you can use png_set_invalid(png_ptr, info_ptr, PNG_INFO_CHNK); + */ +PNG_EXPORT(177, void, png_set_invalid, (png_const_structrp png_ptr, + png_inforp info_ptr, int mask)); + +#ifdef PNG_INFO_IMAGE_SUPPORTED +/* The "params" pointer is currently not used and is for future expansion. */ +PNG_EXPORT(178, void, png_read_png, (png_structrp png_ptr, png_inforp info_ptr, + int transforms, png_voidp params)); +PNG_EXPORT(179, void, png_write_png, (png_structrp png_ptr, png_inforp info_ptr, + int transforms, png_voidp params)); +#endif + +PNG_EXPORT(180, png_const_charp, png_get_copyright, + (png_const_structrp png_ptr)); +PNG_EXPORT(181, png_const_charp, png_get_header_ver, + (png_const_structrp png_ptr)); +PNG_EXPORT(182, png_const_charp, png_get_header_version, + (png_const_structrp png_ptr)); +PNG_EXPORT(183, png_const_charp, png_get_libpng_ver, + (png_const_structrp png_ptr)); + +#ifdef PNG_MNG_FEATURES_SUPPORTED +PNG_EXPORT(184, png_uint_32, png_permit_mng_features, (png_structrp png_ptr, + png_uint_32 mng_features_permitted)); +#endif + +/* For use in png_set_keep_unknown, added to version 1.2.6 */ +#define PNG_HANDLE_CHUNK_AS_DEFAULT 0 +#define PNG_HANDLE_CHUNK_NEVER 1 +#define PNG_HANDLE_CHUNK_IF_SAFE 2 +#define PNG_HANDLE_CHUNK_ALWAYS 3 +#define PNG_HANDLE_CHUNK_LAST 4 + +/* Strip the prepended error numbers ("#nnn ") from error and warning + * messages before passing them to the error or warning handler. + */ +#ifdef PNG_ERROR_NUMBERS_SUPPORTED +PNG_EXPORT(185, void, png_set_strip_error_numbers, (png_structrp png_ptr, + png_uint_32 strip_mode)); +#endif + +/* Added in libpng-1.2.6 */ +#ifdef PNG_SET_USER_LIMITS_SUPPORTED +PNG_EXPORT(186, void, png_set_user_limits, (png_structrp png_ptr, + png_uint_32 user_width_max, png_uint_32 user_height_max)); +PNG_EXPORT(187, png_uint_32, png_get_user_width_max, + (png_const_structrp png_ptr)); +PNG_EXPORT(188, png_uint_32, png_get_user_height_max, + (png_const_structrp png_ptr)); +/* Added in libpng-1.4.0 */ +PNG_EXPORT(189, void, png_set_chunk_cache_max, (png_structrp png_ptr, + png_uint_32 user_chunk_cache_max)); +PNG_EXPORT(190, png_uint_32, png_get_chunk_cache_max, + (png_const_structrp png_ptr)); +/* Added in libpng-1.4.1 */ +PNG_EXPORT(191, void, png_set_chunk_malloc_max, (png_structrp png_ptr, + png_alloc_size_t user_chunk_cache_max)); +PNG_EXPORT(192, png_alloc_size_t, png_get_chunk_malloc_max, + (png_const_structrp png_ptr)); +#endif + +#if defined(PNG_INCH_CONVERSIONS_SUPPORTED) +PNG_EXPORT(193, png_uint_32, png_get_pixels_per_inch, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); + +PNG_EXPORT(194, png_uint_32, png_get_x_pixels_per_inch, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); + +PNG_EXPORT(195, png_uint_32, png_get_y_pixels_per_inch, + (png_const_structrp png_ptr, png_const_inforp info_ptr)); + +PNG_FP_EXPORT(196, float, png_get_x_offset_inches, + (png_const_structrp png_ptr, png_const_inforp info_ptr)) +#ifdef PNG_FIXED_POINT_SUPPORTED /* otherwise not implemented. */ +PNG_FIXED_EXPORT(211, png_fixed_point, png_get_x_offset_inches_fixed, + (png_const_structrp png_ptr, png_const_inforp info_ptr)) +#endif + +PNG_FP_EXPORT(197, float, png_get_y_offset_inches, (png_const_structrp png_ptr, + png_const_inforp info_ptr)) +#ifdef PNG_FIXED_POINT_SUPPORTED /* otherwise not implemented. */ +PNG_FIXED_EXPORT(212, png_fixed_point, png_get_y_offset_inches_fixed, + (png_const_structrp png_ptr, png_const_inforp info_ptr)) +#endif + +# ifdef PNG_pHYs_SUPPORTED +PNG_EXPORT(198, png_uint_32, png_get_pHYs_dpi, (png_const_structrp png_ptr, + png_const_inforp info_ptr, png_uint_32 *res_x, png_uint_32 *res_y, + int *unit_type)); +# endif /* PNG_pHYs_SUPPORTED */ +#endif /* PNG_INCH_CONVERSIONS_SUPPORTED */ + +/* Added in libpng-1.4.0 */ +#ifdef PNG_IO_STATE_SUPPORTED +PNG_EXPORT(199, png_uint_32, png_get_io_state, (png_const_structrp png_ptr)); + +/* Removed from libpng 1.6; use png_get_io_chunk_type. */ +PNG_REMOVED(200, png_const_bytep, png_get_io_chunk_name, (png_structrp png_ptr), + PNG_DEPRECATED) + +PNG_EXPORT(216, png_uint_32, png_get_io_chunk_type, + (png_const_structrp png_ptr)); + +/* The flags returned by png_get_io_state() are the following: */ +# define PNG_IO_NONE 0x0000 /* no I/O at this moment */ +# define PNG_IO_READING 0x0001 /* currently reading */ +# define PNG_IO_WRITING 0x0002 /* currently writing */ +# define PNG_IO_SIGNATURE 0x0010 /* currently at the file signature */ +# define PNG_IO_CHUNK_HDR 0x0020 /* currently at the chunk header */ +# define PNG_IO_CHUNK_DATA 0x0040 /* currently at the chunk data */ +# define PNG_IO_CHUNK_CRC 0x0080 /* currently at the chunk crc */ +# define PNG_IO_MASK_OP 0x000f /* current operation: reading/writing */ +# define PNG_IO_MASK_LOC 0x00f0 /* current location: sig/hdr/data/crc */ +#endif /* ?PNG_IO_STATE_SUPPORTED */ + +/* Interlace support. The following macros are always defined so that if + * libpng interlace handling is turned off the macros may be used to handle + * interlaced images within the application. + */ +#define PNG_INTERLACE_ADAM7_PASSES 7 + +/* Two macros to return the first row and first column of the original, + * full, image which appears in a given pass. 'pass' is in the range 0 + * to 6 and the result is in the range 0 to 7. + */ +#define PNG_PASS_START_ROW(pass) (((1&~(pass))<<(3-((pass)>>1)))&7) +#define PNG_PASS_START_COL(pass) (((1& (pass))<<(3-(((pass)+1)>>1)))&7) + +/* A macro to return the offset between pixels in the output row for a pair of + * pixels in the input - effectively the inverse of the 'COL_SHIFT' macro that + * follows. Note that ROW_OFFSET is the offset from one row to the next whereas + * COL_OFFSET is from one column to the next, within a row. + */ +#define PNG_PASS_ROW_OFFSET(pass) ((pass)>2?(8>>(((pass)-1)>>1)):8) +#define PNG_PASS_COL_OFFSET(pass) (1<<((7-(pass))>>1)) + +/* Two macros to help evaluate the number of rows or columns in each + * pass. This is expressed as a shift - effectively log2 of the number or + * rows or columns in each 8x8 tile of the original image. + */ +#define PNG_PASS_ROW_SHIFT(pass) ((pass)>2?(8-(pass))>>1:3) +#define PNG_PASS_COL_SHIFT(pass) ((pass)>1?(7-(pass))>>1:3) + +/* Hence two macros to determine the number of rows or columns in a given + * pass of an image given its height or width. In fact these macros may + * return non-zero even though the sub-image is empty, because the other + * dimension may be empty for a small image. + */ +#define PNG_PASS_ROWS(height, pass) (((height)+(((1<>PNG_PASS_ROW_SHIFT(pass)) +#define PNG_PASS_COLS(width, pass) (((width)+(((1<>PNG_PASS_COL_SHIFT(pass)) + +/* For the reader row callbacks (both progressive and sequential) it is + * necessary to find the row in the output image given a row in an interlaced + * image, so two more macros: + */ +#define PNG_ROW_FROM_PASS_ROW(y_in, pass) \ + (((y_in)<>(((7-(off))-(pass))<<2)) & 0xF) | \ + ((0x01145AF0>>(((7-(off))-(pass))<<2)) & 0xF0)) + +#define PNG_ROW_IN_INTERLACE_PASS(y, pass) \ + ((PNG_PASS_MASK(pass,0) >> ((y)&7)) & 1) +#define PNG_COL_IN_INTERLACE_PASS(x, pass) \ + ((PNG_PASS_MASK(pass,1) >> ((x)&7)) & 1) + +#ifdef PNG_READ_COMPOSITE_NODIV_SUPPORTED +/* With these routines we avoid an integer divide, which will be slower on + * most machines. However, it does take more operations than the corresponding + * divide method, so it may be slower on a few RISC systems. There are two + * shifts (by 8 or 16 bits) and an addition, versus a single integer divide. + * + * Note that the rounding factors are NOT supposed to be the same! 128 and + * 32768 are correct for the NODIV code; 127 and 32767 are correct for the + * standard method. + * + * [Optimized code by Greg Roelofs and Mark Adler...blame us for bugs. :-) ] + */ + + /* fg and bg should be in `gamma 1.0' space; alpha is the opacity */ + +# define png_composite(composite, fg, alpha, bg) \ + { png_uint_16 temp = (png_uint_16)((png_uint_16)(fg) \ + * (png_uint_16)(alpha) \ + + (png_uint_16)(bg)*(png_uint_16)(255 \ + - (png_uint_16)(alpha)) + 128); \ + (composite) = (png_byte)((temp + (temp >> 8)) >> 8); } + +# define png_composite_16(composite, fg, alpha, bg) \ + { png_uint_32 temp = (png_uint_32)((png_uint_32)(fg) \ + * (png_uint_32)(alpha) \ + + (png_uint_32)(bg)*(65535 \ + - (png_uint_32)(alpha)) + 32768); \ + (composite) = (png_uint_16)((temp + (temp >> 16)) >> 16); } + +#else /* Standard method using integer division */ + +# define png_composite(composite, fg, alpha, bg) \ + (composite) = (png_byte)(((png_uint_16)(fg) * (png_uint_16)(alpha) + \ + (png_uint_16)(bg) * (png_uint_16)(255 - (png_uint_16)(alpha)) + \ + 127) / 255) + +# define png_composite_16(composite, fg, alpha, bg) \ + (composite) = (png_uint_16)(((png_uint_32)(fg) * (png_uint_32)(alpha) + \ + (png_uint_32)(bg)*(png_uint_32)(65535 - (png_uint_32)(alpha)) + \ + 32767) / 65535) +#endif /* PNG_READ_COMPOSITE_NODIV_SUPPORTED */ + +#ifdef PNG_READ_INT_FUNCTIONS_SUPPORTED +PNG_EXPORT(201, png_uint_32, png_get_uint_32, (png_const_bytep buf)); +PNG_EXPORT(202, png_uint_16, png_get_uint_16, (png_const_bytep buf)); +PNG_EXPORT(203, png_int_32, png_get_int_32, (png_const_bytep buf)); +#endif + +PNG_EXPORT(204, png_uint_32, png_get_uint_31, (png_const_structrp png_ptr, + png_const_bytep buf)); +/* No png_get_int_16 -- may be added if there's a real need for it. */ + +/* Place a 32-bit number into a buffer in PNG byte order (big-endian). */ +#ifdef PNG_WRITE_INT_FUNCTIONS_SUPPORTED +PNG_EXPORT(205, void, png_save_uint_32, (png_bytep buf, png_uint_32 i)); +#endif +#ifdef PNG_SAVE_INT_32_SUPPORTED +PNG_EXPORT(206, void, png_save_int_32, (png_bytep buf, png_int_32 i)); +#endif + +/* Place a 16-bit number into a buffer in PNG byte order. + * The parameter is declared unsigned int, not png_uint_16, + * just to avoid potential problems on pre-ANSI C compilers. + */ +#ifdef PNG_WRITE_INT_FUNCTIONS_SUPPORTED +PNG_EXPORT(207, void, png_save_uint_16, (png_bytep buf, unsigned int i)); +/* No png_save_int_16 -- may be added if there's a real need for it. */ +#endif + +#ifdef PNG_USE_READ_MACROS +/* Inline macros to do direct reads of bytes from the input buffer. + * The png_get_int_32() routine assumes we are using two's complement + * format for negative values, which is almost certainly true. + */ +# define PNG_get_uint_32(buf) \ + (((png_uint_32)(*(buf)) << 24) + \ + ((png_uint_32)(*((buf) + 1)) << 16) + \ + ((png_uint_32)(*((buf) + 2)) << 8) + \ + ((png_uint_32)(*((buf) + 3)))) + + /* From libpng-1.4.0 until 1.4.4, the png_get_uint_16 macro (but not the + * function) incorrectly returned a value of type png_uint_32. + */ +# define PNG_get_uint_16(buf) \ + ((png_uint_16) \ + (((unsigned int)(*(buf)) << 8) + \ + ((unsigned int)(*((buf) + 1))))) + +# define PNG_get_int_32(buf) \ + ((png_int_32)((*(buf) & 0x80) \ + ? -((png_int_32)((png_get_uint_32(buf) ^ 0xffffffffL) + 1)) \ + : (png_int_32)png_get_uint_32(buf))) + + /* If PNG_PREFIX is defined the same thing as below happens in pnglibconf.h, + * but defining a macro name prefixed with PNG_PREFIX. + */ +# ifndef PNG_PREFIX +# define png_get_uint_32(buf) PNG_get_uint_32(buf) +# define png_get_uint_16(buf) PNG_get_uint_16(buf) +# define png_get_int_32(buf) PNG_get_int_32(buf) +# endif +#else +# ifdef PNG_PREFIX + /* No macros; revert to the (redefined) function */ +# define PNG_get_uint_32 (png_get_uint_32) +# define PNG_get_uint_16 (png_get_uint_16) +# define PNG_get_int_32 (png_get_int_32) +# endif +#endif + +/******************************************************************************* + * SIMPLIFIED API + ******************************************************************************* + * + * Please read the documentation in libpng-manual.txt (TODO: write said + * documentation) if you don't understand what follows. + * + * The simplified API hides the details of both libpng and the PNG file format + * itself. It allows PNG files to be read into a very limited number of + * in-memory bitmap formats or to be written from the same formats. If these + * formats do not accomodate your needs then you can, and should, use the more + * sophisticated APIs above - these support a wide variety of in-memory formats + * and a wide variety of sophisticated transformations to those formats as well + * as a wide variety of APIs to manipulate ancillary information. + * + * To read a PNG file using the simplified API: + * + * 1) Declare a 'png_image' structure (see below) on the stack and set the + * version field to PNG_IMAGE_VERSION. + * 2) Call the appropriate png_image_begin_read... function. + * 3) Set the png_image 'format' member to the required sample format. + * 4) Allocate a buffer for the image and, if required, the color-map. + * 5) Call png_image_finish_read to read the image and, if required, the + * color-map into your buffers. + * + * There are no restrictions on the format of the PNG input itself; all valid + * color types, bit depths, and interlace methods are acceptable, and the + * input image is transformed as necessary to the requested in-memory format + * during the png_image_finish_read() step. The only caveat is that if you + * request a color-mapped image from a PNG that is full-color or makes + * complex use of an alpha channel the transformation is extremely lossy and the + * result may look terrible. + * + * To write a PNG file using the simplified API: + * + * 1) Declare a 'png_image' structure on the stack and memset() it to all zero. + * 2) Initialize the members of the structure that describe the image, setting + * the 'format' member to the format of the image samples. + * 3) Call the appropriate png_image_write... function with a pointer to the + * image and, if necessary, the color-map to write the PNG data. + * + * png_image is a structure that describes the in-memory format of an image + * when it is being read or defines the in-memory format of an image that you + * need to write: + */ +#define PNG_IMAGE_VERSION 1 + +typedef struct png_control *png_controlp; +typedef struct +{ + png_controlp opaque; /* Initialize to NULL, free with png_image_free */ + png_uint_32 version; /* Set to PNG_IMAGE_VERSION */ + png_uint_32 width; /* Image width in pixels (columns) */ + png_uint_32 height; /* Image height in pixels (rows) */ + png_uint_32 format; /* Image format as defined below */ + png_uint_32 flags; /* A bit mask containing informational flags */ + png_uint_32 colormap_entries; + /* Number of entries in the color-map */ + + /* In the event of an error or warning the following field will be set to a + * non-zero value and the 'message' field will contain a '\0' terminated + * string with the libpng error or warning message. If both warnings and + * an error were encountered, only the error is recorded. If there + * are multiple warnings, only the first one is recorded. + * + * The upper 30 bits of this value are reserved, the low two bits contain + * a value as follows: + */ +# define PNG_IMAGE_WARNING 1 +# define PNG_IMAGE_ERROR 2 + /* + * The result is a two bit code such that a value more than 1 indicates + * a failure in the API just called: + * + * 0 - no warning or error + * 1 - warning + * 2 - error + * 3 - error preceded by warning + */ +# define PNG_IMAGE_FAILED(png_cntrl) ((((png_cntrl).warning_or_error)&0x03)>1) + + png_uint_32 warning_or_error; + + char message[64]; +} png_image, *png_imagep; + +/* The samples of the image have one to four channels whose components have + * original values in the range 0 to 1.0: + * + * 1: A single gray or luminance channel (G). + * 2: A gray/luminance channel and an alpha channel (GA). + * 3: Three red, green, blue color channels (RGB). + * 4: Three color channels and an alpha channel (RGBA). + * + * The components are encoded in one of two ways: + * + * a) As a small integer, value 0..255, contained in a single byte. For the + * alpha channel the original value is simply value/255. For the color or + * luminance channels the value is encoded according to the sRGB specification + * and matches the 8-bit format expected by typical display devices. + * + * The color/gray channels are not scaled (pre-multiplied) by the alpha + * channel and are suitable for passing to color management software. + * + * b) As a value in the range 0..65535, contained in a 2-byte integer. All + * channels can be converted to the original value by dividing by 65535; all + * channels are linear. Color channels use the RGB encoding (RGB end-points) of + * the sRGB specification. This encoding is identified by the + * PNG_FORMAT_FLAG_LINEAR flag below. + * + * When the simplified API needs to convert between sRGB and linear colorspaces, + * the actual sRGB transfer curve defined in the sRGB specification (see the + * article at http://en.wikipedia.org/wiki/SRGB) is used, not the gamma=1/2.2 + * approximation used elsewhere in libpng. + * + * When an alpha channel is present it is expected to denote pixel coverage + * of the color or luminance channels and is returned as an associated alpha + * channel: the color/gray channels are scaled (pre-multiplied) by the alpha + * value. + * + * The samples are either contained directly in the image data, between 1 and 8 + * bytes per pixel according to the encoding, or are held in a color-map indexed + * by bytes in the image data. In the case of a color-map the color-map entries + * are individual samples, encoded as above, and the image data has one byte per + * pixel to select the relevant sample from the color-map. + */ + +/* PNG_FORMAT_* + * + * #defines to be used in png_image::format. Each #define identifies a + * particular layout of sample data and, if present, alpha values. There are + * separate defines for each of the two component encodings. + * + * A format is built up using single bit flag values. All combinations are + * valid. Formats can be built up from the flag values or you can use one of + * the predefined values below. When testing formats always use the FORMAT_FLAG + * macros to test for individual features - future versions of the library may + * add new flags. + * + * When reading or writing color-mapped images the format should be set to the + * format of the entries in the color-map then png_image_{read,write}_colormap + * called to read or write the color-map and set the format correctly for the + * image data. Do not set the PNG_FORMAT_FLAG_COLORMAP bit directly! + * + * NOTE: libpng can be built with particular features disabled, if you see + * compiler errors because the definition of one of the following flags has been + * compiled out it is because libpng does not have the required support. It is + * possible, however, for the libpng configuration to enable the format on just + * read or just write; in that case you may see an error at run time. You can + * guard against this by checking for the definition of the appropriate + * "_SUPPORTED" macro, one of: + * + * PNG_SIMPLIFIED_{READ,WRITE}_{BGR,AFIRST}_SUPPORTED + */ +#define PNG_FORMAT_FLAG_ALPHA 0x01U /* format with an alpha channel */ +#define PNG_FORMAT_FLAG_COLOR 0x02U /* color format: otherwise grayscale */ +#define PNG_FORMAT_FLAG_LINEAR 0x04U /* 2 byte channels else 1 byte */ +#define PNG_FORMAT_FLAG_COLORMAP 0x08U /* image data is color-mapped */ + +#ifdef PNG_FORMAT_BGR_SUPPORTED +# define PNG_FORMAT_FLAG_BGR 0x10U /* BGR colors, else order is RGB */ +#endif + +#ifdef PNG_FORMAT_AFIRST_SUPPORTED +# define PNG_FORMAT_FLAG_AFIRST 0x20U /* alpha channel comes first */ +#endif + +/* Commonly used formats have predefined macros. + * + * First the single byte (sRGB) formats: + */ +#define PNG_FORMAT_GRAY 0 +#define PNG_FORMAT_GA PNG_FORMAT_FLAG_ALPHA +#define PNG_FORMAT_AG (PNG_FORMAT_GA|PNG_FORMAT_FLAG_AFIRST) +#define PNG_FORMAT_RGB PNG_FORMAT_FLAG_COLOR +#define PNG_FORMAT_BGR (PNG_FORMAT_FLAG_COLOR|PNG_FORMAT_FLAG_BGR) +#define PNG_FORMAT_RGBA (PNG_FORMAT_RGB|PNG_FORMAT_FLAG_ALPHA) +#define PNG_FORMAT_ARGB (PNG_FORMAT_RGBA|PNG_FORMAT_FLAG_AFIRST) +#define PNG_FORMAT_BGRA (PNG_FORMAT_BGR|PNG_FORMAT_FLAG_ALPHA) +#define PNG_FORMAT_ABGR (PNG_FORMAT_BGRA|PNG_FORMAT_FLAG_AFIRST) + +/* Then the linear 2-byte formats. When naming these "Y" is used to + * indicate a luminance (gray) channel. + */ +#define PNG_FORMAT_LINEAR_Y PNG_FORMAT_FLAG_LINEAR +#define PNG_FORMAT_LINEAR_Y_ALPHA (PNG_FORMAT_FLAG_LINEAR|PNG_FORMAT_FLAG_ALPHA) +#define PNG_FORMAT_LINEAR_RGB (PNG_FORMAT_FLAG_LINEAR|PNG_FORMAT_FLAG_COLOR) +#define PNG_FORMAT_LINEAR_RGB_ALPHA \ + (PNG_FORMAT_FLAG_LINEAR|PNG_FORMAT_FLAG_COLOR|PNG_FORMAT_FLAG_ALPHA) + +/* With color-mapped formats the image data is one byte for each pixel, the byte + * is an index into the color-map which is formatted as above. To obtain a + * color-mapped format it is sufficient just to add the PNG_FOMAT_FLAG_COLORMAP + * to one of the above definitions, or you can use one of the definitions below. + */ +#define PNG_FORMAT_RGB_COLORMAP (PNG_FORMAT_RGB|PNG_FORMAT_FLAG_COLORMAP) +#define PNG_FORMAT_BGR_COLORMAP (PNG_FORMAT_BGR|PNG_FORMAT_FLAG_COLORMAP) +#define PNG_FORMAT_RGBA_COLORMAP (PNG_FORMAT_RGBA|PNG_FORMAT_FLAG_COLORMAP) +#define PNG_FORMAT_ARGB_COLORMAP (PNG_FORMAT_ARGB|PNG_FORMAT_FLAG_COLORMAP) +#define PNG_FORMAT_BGRA_COLORMAP (PNG_FORMAT_BGRA|PNG_FORMAT_FLAG_COLORMAP) +#define PNG_FORMAT_ABGR_COLORMAP (PNG_FORMAT_ABGR|PNG_FORMAT_FLAG_COLORMAP) + +/* PNG_IMAGE macros + * + * These are convenience macros to derive information from a png_image + * structure. The PNG_IMAGE_SAMPLE_ macros return values appropriate to the + * actual image sample values - either the entries in the color-map or the + * pixels in the image. The PNG_IMAGE_PIXEL_ macros return corresponding values + * for the pixels and will always return 1 for color-mapped formats. The + * remaining macros return information about the rows in the image and the + * complete image. + * + * NOTE: All the macros that take a png_image::format parameter are compile time + * constants if the format parameter is, itself, a constant. Therefore these + * macros can be used in array declarations and case labels where required. + * Similarly the macros are also pre-processor constants (sizeof is not used) so + * they can be used in #if tests. + * + * First the information about the samples. + */ +#define PNG_IMAGE_SAMPLE_CHANNELS(fmt)\ + (((fmt)&(PNG_FORMAT_FLAG_COLOR|PNG_FORMAT_FLAG_ALPHA))+1) + /* Return the total number of channels in a given format: 1..4 */ + +#define PNG_IMAGE_SAMPLE_COMPONENT_SIZE(fmt)\ + ((((fmt) & PNG_FORMAT_FLAG_LINEAR) >> 2)+1) + /* Return the size in bytes of a single component of a pixel or color-map + * entry (as appropriate) in the image: 1 or 2. + */ + +#define PNG_IMAGE_SAMPLE_SIZE(fmt)\ + (PNG_IMAGE_SAMPLE_CHANNELS(fmt) * PNG_IMAGE_SAMPLE_COMPONENT_SIZE(fmt)) + /* This is the size of the sample data for one sample. If the image is + * color-mapped it is the size of one color-map entry (and image pixels are + * one byte in size), otherwise it is the size of one image pixel. + */ + +#define PNG_IMAGE_MAXIMUM_COLORMAP_COMPONENTS(fmt)\ + (PNG_IMAGE_SAMPLE_CHANNELS(fmt) * 256) + /* The maximum size of the color-map required by the format expressed in a + * count of components. This can be used to compile-time allocate a + * color-map: + * + * png_uint_16 colormap[PNG_IMAGE_MAXIMUM_COLORMAP_COMPONENTS(linear_fmt)]; + * + * png_byte colormap[PNG_IMAGE_MAXIMUM_COLORMAP_COMPONENTS(sRGB_fmt)]; + * + * Alternatively use the PNG_IMAGE_COLORMAP_SIZE macro below to use the + * information from one of the png_image_begin_read_ APIs and dynamically + * allocate the required memory. + */ + +/* Corresponding information about the pixels */ +#define PNG_IMAGE_PIXEL_(test,fmt)\ + (((fmt)&PNG_FORMAT_FLAG_COLORMAP)?1:test(fmt)) + +#define PNG_IMAGE_PIXEL_CHANNELS(fmt)\ + PNG_IMAGE_PIXEL_(PNG_IMAGE_SAMPLE_CHANNELS,fmt) + /* The number of separate channels (components) in a pixel; 1 for a + * color-mapped image. + */ + +#define PNG_IMAGE_PIXEL_COMPONENT_SIZE(fmt)\ + PNG_IMAGE_PIXEL_(PNG_IMAGE_SAMPLE_COMPONENT_SIZE,fmt) + /* The size, in bytes, of each component in a pixel; 1 for a color-mapped + * image. + */ + +#define PNG_IMAGE_PIXEL_SIZE(fmt) PNG_IMAGE_PIXEL_(PNG_IMAGE_SAMPLE_SIZE,fmt) + /* The size, in bytes, of a complete pixel; 1 for a color-mapped image. */ + +/* Information about the whole row, or whole image */ +#define PNG_IMAGE_ROW_STRIDE(image)\ + (PNG_IMAGE_PIXEL_CHANNELS((image).format) * (image).width) + /* Return the total number of components in a single row of the image; this + * is the minimum 'row stride', the minimum count of components between each + * row. For a color-mapped image this is the minimum number of bytes in a + * row. + */ + +#define PNG_IMAGE_BUFFER_SIZE(image, row_stride)\ + (PNG_IMAGE_PIXEL_COMPONENT_SIZE((image).format)*(image).height*(row_stride)) + /* Return the size, in bytes, of an image buffer given a png_image and a row + * stride - the number of components to leave space for in each row. + */ + +#define PNG_IMAGE_SIZE(image)\ + PNG_IMAGE_BUFFER_SIZE(image, PNG_IMAGE_ROW_STRIDE(image)) + /* Return the size, in bytes, of the image in memory given just a png_image; + * the row stride is the minimum stride required for the image. + */ + +#define PNG_IMAGE_COLORMAP_SIZE(image)\ + (PNG_IMAGE_SAMPLE_SIZE((image).format) * (image).colormap_entries) + /* Return the size, in bytes, of the color-map of this image. If the image + * format is not a color-map format this will return a size sufficient for + * 256 entries in the given format; check PNG_FORMAT_FLAG_COLORMAP if + * you don't want to allocate a color-map in this case. + */ + +/* PNG_IMAGE_FLAG_* + * + * Flags containing additional information about the image are held in the + * 'flags' field of png_image. + */ +#define PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB 0x01 + /* This indicates the the RGB values of the in-memory bitmap do not + * correspond to the red, green and blue end-points defined by sRGB. + */ + +#define PNG_IMAGE_FLAG_FAST 0x02 + /* On write emphasise speed over compression; the resultant PNG file will be + * larger but will be produced significantly faster, particular for large + * images. Do not use this option for images which will be distributed, only + * used it when producing intermediate files that will be read back in + * repeatedly. For a typical 24-bit image the option will double the read + * speed at the cost of increasing the image size by 25%, however for many + * more compressible images the PNG file can be 10 times larger with only a + * slight speed gain. + */ + +#define PNG_IMAGE_FLAG_16BIT_sRGB 0x04 + /* On read if the image is a 16-bit per component image and there is no gAMA + * or sRGB chunk assume that the components are sRGB encoded. Notice that + * images output by the simplified API always have gamma information; setting + * this flag only affects the interpretation of 16-bit images from an + * external source. It is recommended that the application expose this flag + * to the user; the user can normally easily recognize the difference between + * linear and sRGB encoding. This flag has no effect on write - the data + * passed to the write APIs must have the correct encoding (as defined + * above.) + * + * If the flag is not set (the default) input 16-bit per component data is + * assumed to be linear. + * + * NOTE: the flag can only be set after the png_image_begin_read_ call, + * because that call initializes the 'flags' field. + */ + +#ifdef PNG_SIMPLIFIED_READ_SUPPORTED +/* READ APIs + * --------- + * + * The png_image passed to the read APIs must have been initialized by setting + * the png_controlp field 'opaque' to NULL (or, safer, memset the whole thing.) + */ +#ifdef PNG_STDIO_SUPPORTED +PNG_EXPORT(234, int, png_image_begin_read_from_file, (png_imagep image, + const char *file_name)); + /* The named file is opened for read and the image header is filled in + * from the PNG header in the file. + */ + +PNG_EXPORT(235, int, png_image_begin_read_from_stdio, (png_imagep image, + FILE* file)); + /* The PNG header is read from the stdio FILE object. */ +#endif /* PNG_STDIO_SUPPORTED */ + +PNG_EXPORT(236, int, png_image_begin_read_from_memory, (png_imagep image, + png_const_voidp memory, png_size_t size)); + /* The PNG header is read from the given memory buffer. */ + +PNG_EXPORT(237, int, png_image_finish_read, (png_imagep image, + png_const_colorp background, void *buffer, png_int_32 row_stride, + void *colormap)); + /* Finish reading the image into the supplied buffer and clean up the + * png_image structure. + * + * row_stride is the step, in byte or 2-byte units as appropriate, + * between adjacent rows. A positive stride indicates that the top-most row + * is first in the buffer - the normal top-down arrangement. A negative + * stride indicates that the bottom-most row is first in the buffer. + * + * background need only be supplied if an alpha channel must be removed from + * a png_byte format and the removal is to be done by compositing on a solid + * color; otherwise it may be NULL and any composition will be done directly + * onto the buffer. The value is an sRGB color to use for the background, + * for grayscale output the green channel is used. + * + * background must be supplied when an alpha channel must be removed from a + * single byte color-mapped output format, in other words if: + * + * 1) The original format from png_image_begin_read_from_* had + * PNG_FORMAT_FLAG_ALPHA set. + * 2) The format set by the application does not. + * 3) The format set by the application has PNG_FORMAT_FLAG_COLORMAP set and + * PNG_FORMAT_FLAG_LINEAR *not* set. + * + * For linear output removing the alpha channel is always done by compositing + * on black and background is ignored. + * + * colormap must be supplied when PNG_FORMAT_FLAG_COLORMAP is set. It must + * be at least the size (in bytes) returned by PNG_IMAGE_COLORMAP_SIZE. + * image->colormap_entries will be updated to the actual number of entries + * written to the colormap; this may be less than the original value. + */ + +PNG_EXPORT(238, void, png_image_free, (png_imagep image)); + /* Free any data allocated by libpng in image->opaque, setting the pointer to + * NULL. May be called at any time after the structure is initialized. + */ +#endif /* PNG_SIMPLIFIED_READ_SUPPORTED */ + +#ifdef PNG_SIMPLIFIED_WRITE_SUPPORTED +#ifdef PNG_STDIO_SUPPORTED +/* WRITE APIS + * ---------- + * For write you must initialize a png_image structure to describe the image to + * be written. To do this use memset to set the whole structure to 0 then + * initialize fields describing your image. + * + * version: must be set to PNG_IMAGE_VERSION + * opaque: must be initialized to NULL + * width: image width in pixels + * height: image height in rows + * format: the format of the data (image and color-map) you wish to write + * flags: set to 0 unless one of the defined flags applies; set + * PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB for color format images where the RGB + * values do not correspond to the colors in sRGB. + * colormap_entries: set to the number of entries in the color-map (0 to 256) + */ +PNG_EXPORT(239, int, png_image_write_to_file, (png_imagep image, + const char *file, int convert_to_8bit, const void *buffer, + png_int_32 row_stride, const void *colormap)); + /* Write the image to the named file. */ + +PNG_EXPORT(240, int, png_image_write_to_stdio, (png_imagep image, FILE *file, + int convert_to_8_bit, const void *buffer, png_int_32 row_stride, + const void *colormap)); + /* Write the image to the given (FILE*). */ + +/* With both write APIs if image is in one of the linear formats with 16-bit + * data then setting convert_to_8_bit will cause the output to be an 8-bit PNG + * gamma encoded according to the sRGB specification, otherwise a 16-bit linear + * encoded PNG file is written. + * + * With color-mapped data formats the colormap parameter point to a color-map + * with at least image->colormap_entries encoded in the specified format. If + * the format is linear the written PNG color-map will be converted to sRGB + * regardless of the convert_to_8_bit flag. + * + * With all APIs row_stride is handled as in the read APIs - it is the spacing + * from one row to the next in component sized units (1 or 2 bytes) and if + * negative indicates a bottom-up row layout in the buffer. + * + * Note that the write API does not support interlacing or sub-8-bit pixels. + */ +#endif /* PNG_STDIO_SUPPORTED */ +#endif /* PNG_SIMPLIFIED_WRITE_SUPPORTED */ +/******************************************************************************* + * END OF SIMPLIFIED API + ******************************************************************************/ + +#ifdef PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED +PNG_EXPORT(242, void, png_set_check_for_invalid_index, + (png_structrp png_ptr, int allowed)); +# ifdef PNG_GET_PALETTE_MAX_SUPPORTED +PNG_EXPORT(243, int, png_get_palette_max, (png_const_structp png_ptr, + png_const_infop info_ptr)); +# endif +#endif /* CHECK_FOR_INVALID_INDEX */ + +/******************************************************************************* + * IMPLEMENTATION OPTIONS + ******************************************************************************* + * + * Support for arbitrary implementation-specific optimizations. The API allows + * particular options to be turned on or off. 'Option' is the number of the + * option and 'onoff' is 0 (off) or non-0 (on). The value returned is given + * by the PNG_OPTION_ defines below. + * + * HARDWARE: normally hardware capabilites, such as the Intel SSE instructions, + * are detected at run time, however sometimes it may be impossible + * to do this in user mode, in which case it is necessary to discover + * the capabilities in an OS specific way. Such capabilities are + * listed here when libpng has support for them and must be turned + * ON by the application if present. + * + * SOFTWARE: sometimes software optimizations actually result in performance + * decrease on some architectures or systems, or with some sets of + * PNG images. 'Software' options allow such optimizations to be + * selected at run time. + */ +#ifdef PNG_SET_OPTION_SUPPORTED +#ifdef PNG_ARM_NEON_API_SUPPORTED +# define PNG_ARM_NEON 0 /* HARDWARE: ARM Neon SIMD instructions supported */ +#endif +#define PNG_MAXIMUM_INFLATE_WINDOW 2 /* SOFTWARE: force maximum window */ +#define PNG_OPTION_NEXT 4 /* Next option - numbers must be even */ + +/* Return values: NOTE: there are four values and 'off' is *not* zero */ +#define PNG_OPTION_UNSET 0 /* Unset - defaults to off */ +#define PNG_OPTION_INVALID 1 /* Option number out of range */ +#define PNG_OPTION_OFF 2 +#define PNG_OPTION_ON 3 + +PNG_EXPORT(244, int, png_set_option, (png_structrp png_ptr, int option, + int onoff)); +#endif + +/******************************************************************************* + * END OF HARDWARE OPTIONS + ******************************************************************************/ + +/* Maintainer: Put new public prototypes here ^, in libpng.3, and project + * defs, scripts/pnglibconf.h, and scripts/pnglibconf.h.prebuilt + */ + +/* The last ordinal number (this is the *last* one already used; the next + * one to use is one more than this.) Maintainer, remember to add an entry to + * scripts/symbols.def as well. + */ +#ifdef PNG_EXPORT_LAST_ORDINAL + PNG_EXPORT_LAST_ORDINAL(244); +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* PNG_VERSION_INFO_ONLY */ +/* Do not put anything past this line */ +#endif /* PNG_H */ diff --git a/ml/dlib/dlib/external/libpng/pngconf.h b/ml/dlib/dlib/external/libpng/pngconf.h new file mode 100644 index 000000000..8c5347224 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngconf.h @@ -0,0 +1,626 @@ + +/* pngconf.h - machine configurable file for libpng + * + * libpng version 1.6.7 - November 14, 2013 + * + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + */ + +/* Any machine specific code is near the front of this file, so if you + * are configuring libpng for a machine, you may want to read the section + * starting here down to where it starts to typedef png_color, png_text, + * and png_info. + */ + +#ifdef _MSC_VER +// Disable the following warnings for Visual Studio +// This is a warning you get from visual studio 2005 about things in the standard C++ +// library being "deprecated." I checked the C++ standard and it doesn't say jack +// about any of them (I checked the searchable PDF). So this warning is total Bunk. +#pragma warning(disable : 4996) +#endif + + +#ifndef PNGCONF_H +#define PNGCONF_H + +/* To do: Do all of this in scripts/pnglibconf.dfa */ +#ifdef PNG_SAFE_LIMITS_SUPPORTED +# ifdef PNG_USER_WIDTH_MAX +# undef PNG_USER_WIDTH_MAX +# define PNG_USER_WIDTH_MAX 1000000L +# endif +# ifdef PNG_USER_HEIGHT_MAX +# undef PNG_USER_HEIGHT_MAX +# define PNG_USER_HEIGHT_MAX 1000000L +# endif +# ifdef PNG_USER_CHUNK_MALLOC_MAX +# undef PNG_USER_CHUNK_MALLOC_MAX +# define PNG_USER_CHUNK_MALLOC_MAX 4000000L +# endif +# ifdef PNG_USER_CHUNK_CACHE_MAX +# undef PNG_USER_CHUNK_CACHE_MAX +# define PNG_USER_CHUNK_CACHE_MAX 128 +# endif +#endif + +#ifndef PNG_BUILDING_SYMBOL_TABLE /* else includes may cause problems */ + +/* From libpng 1.6.0 libpng requires an ANSI X3.159-1989 ("ISOC90") compliant C + * compiler for correct compilation. The following header files are required by + * the standard. If your compiler doesn't provide these header files, or they + * do not match the standard, you will need to provide/improve them. + */ +#include +#include + +/* Library header files. These header files are all defined by ISOC90; libpng + * expects conformant implementations, however, an ISOC90 conformant system need + * not provide these header files if the functionality cannot be implemented. + * In this case it will be necessary to disable the relevant parts of libpng in + * the build of pnglibconf.h. + * + * Prior to 1.6.0 string.h was included here; the API changes in 1.6.0 to not + * include this unnecessary header file. + */ + +#ifdef PNG_STDIO_SUPPORTED + /* Required for the definition of FILE: */ +# include +#endif + +#ifdef PNG_SETJMP_SUPPORTED + /* Required for the definition of jmp_buf and the declaration of longjmp: */ +# include +#endif + +#ifdef PNG_CONVERT_tIME_SUPPORTED + /* Required for struct tm: */ +# include +#endif + +#endif /* PNG_BUILDING_SYMBOL_TABLE */ + +/* Prior to 1.6.0 it was possible to turn off 'const' in declarations using + * PNG_NO_CONST; this is no longer supported except for data declarations which + * apparently still cause problems in 2011 on some compilers. + */ +#define PNG_CONST const /* backward compatibility only */ + +/* This controls optimization of the reading of 16 and 32 bit values + * from PNG files. It can be set on a per-app-file basis - it + * just changes whether a macro is used when the function is called. + * The library builder sets the default; if read functions are not + * built into the library the macro implementation is forced on. + */ +#ifndef PNG_READ_INT_FUNCTIONS_SUPPORTED +# define PNG_USE_READ_MACROS +#endif +#if !defined(PNG_NO_USE_READ_MACROS) && !defined(PNG_USE_READ_MACROS) +# if PNG_DEFAULT_READ_MACROS +# define PNG_USE_READ_MACROS +# endif +#endif + +/* COMPILER SPECIFIC OPTIONS. + * + * These options are provided so that a variety of difficult compilers + * can be used. Some are fixed at build time (e.g. PNG_API_RULE + * below) but still have compiler specific implementations, others + * may be changed on a per-file basis when compiling against libpng. + */ + +/* The PNGARG macro was used in versions of libpng prior to 1.6.0 to protect + * against legacy (pre ISOC90) compilers that did not understand function + * prototypes. It is not required for modern C compilers. + */ +#ifndef PNGARG +# define PNGARG(arglist) arglist +#endif + +/* Function calling conventions. + * ============================= + * Normally it is not necessary to specify to the compiler how to call + * a function - it just does it - however on x86 systems derived from + * Microsoft and Borland C compilers ('IBM PC', 'DOS', 'Windows' systems + * and some others) there are multiple ways to call a function and the + * default can be changed on the compiler command line. For this reason + * libpng specifies the calling convention of every exported function and + * every function called via a user supplied function pointer. This is + * done in this file by defining the following macros: + * + * PNGAPI Calling convention for exported functions. + * PNGCBAPI Calling convention for user provided (callback) functions. + * PNGCAPI Calling convention used by the ANSI-C library (required + * for longjmp callbacks and sometimes used internally to + * specify the calling convention for zlib). + * + * These macros should never be overridden. If it is necessary to + * change calling convention in a private build this can be done + * by setting PNG_API_RULE (which defaults to 0) to one of the values + * below to select the correct 'API' variants. + * + * PNG_API_RULE=0 Use PNGCAPI - the 'C' calling convention - throughout. + * This is correct in every known environment. + * PNG_API_RULE=1 Use the operating system convention for PNGAPI and + * the 'C' calling convention (from PNGCAPI) for + * callbacks (PNGCBAPI). This is no longer required + * in any known environment - if it has to be used + * please post an explanation of the problem to the + * libpng mailing list. + * + * These cases only differ if the operating system does not use the C + * calling convention, at present this just means the above cases + * (x86 DOS/Windows sytems) and, even then, this does not apply to + * Cygwin running on those systems. + * + * Note that the value must be defined in pnglibconf.h so that what + * the application uses to call the library matches the conventions + * set when building the library. + */ + +/* Symbol export + * ============= + * When building a shared library it is almost always necessary to tell + * the compiler which symbols to export. The png.h macro 'PNG_EXPORT' + * is used to mark the symbols. On some systems these symbols can be + * extracted at link time and need no special processing by the compiler, + * on other systems the symbols are flagged by the compiler and just + * the declaration requires a special tag applied (unfortunately) in a + * compiler dependent way. Some systems can do either. + * + * A small number of older systems also require a symbol from a DLL to + * be flagged to the program that calls it. This is a problem because + * we do not know in the header file included by application code that + * the symbol will come from a shared library, as opposed to a statically + * linked one. For this reason the application must tell us by setting + * the magic flag PNG_USE_DLL to turn on the special processing before + * it includes png.h. + * + * Four additional macros are used to make this happen: + * + * PNG_IMPEXP The magic (if any) to cause a symbol to be exported from + * the build or imported if PNG_USE_DLL is set - compiler + * and system specific. + * + * PNG_EXPORT_TYPE(type) A macro that pre or appends PNG_IMPEXP to + * 'type', compiler specific. + * + * PNG_DLL_EXPORT Set to the magic to use during a libpng build to + * make a symbol exported from the DLL. Not used in the + * public header files; see pngpriv.h for how it is used + * in the libpng build. + * + * PNG_DLL_IMPORT Set to the magic to force the libpng symbols to come + * from a DLL - used to define PNG_IMPEXP when + * PNG_USE_DLL is set. + */ + +/* System specific discovery. + * ========================== + * This code is used at build time to find PNG_IMPEXP, the API settings + * and PNG_EXPORT_TYPE(), it may also set a macro to indicate the DLL + * import processing is possible. On Windows systems it also sets + * compiler-specific macros to the values required to change the calling + * conventions of the various functions. + */ +#if defined(_Windows) || defined(_WINDOWS) || defined(WIN32) ||\ + defined(_WIN32) || defined(__WIN32__) || defined(__CYGWIN__) + /* Windows system (DOS doesn't support DLLs). Includes builds under Cygwin or + * MinGW on any architecture currently supported by Windows. Also includes + * Watcom builds but these need special treatment because they are not + * compatible with GCC or Visual C because of different calling conventions. + */ +# if PNG_API_RULE == 2 + /* If this line results in an error, either because __watcall is not + * understood or because of a redefine just below you cannot use *this* + * build of the library with the compiler you are using. *This* build was + * build using Watcom and applications must also be built using Watcom! + */ +# define PNGCAPI __watcall +# endif + +# if defined(__GNUC__) || (defined(_MSC_VER) && (_MSC_VER >= 800)) +# define PNGCAPI __cdecl +# if PNG_API_RULE == 1 + /* If this line results in an error __stdcall is not understood and + * PNG_API_RULE should not have been set to '1'. + */ +# define PNGAPI __stdcall +# endif +# else + /* An older compiler, or one not detected (erroneously) above, + * if necessary override on the command line to get the correct + * variants for the compiler. + */ +# ifndef PNGCAPI +# define PNGCAPI _cdecl +# endif +# if PNG_API_RULE == 1 && !defined(PNGAPI) +# define PNGAPI _stdcall +# endif +# endif /* compiler/api */ + + /* NOTE: PNGCBAPI always defaults to PNGCAPI. */ + +# if defined(PNGAPI) && !defined(PNG_USER_PRIVATEBUILD) +# error "PNG_USER_PRIVATEBUILD must be defined if PNGAPI is changed" +# endif + +# if (defined(_MSC_VER) && _MSC_VER < 800) ||\ + (defined(__BORLANDC__) && __BORLANDC__ < 0x500) + /* older Borland and MSC + * compilers used '__export' and required this to be after + * the type. + */ +# ifndef PNG_EXPORT_TYPE +# define PNG_EXPORT_TYPE(type) type PNG_IMPEXP +# endif +# define PNG_DLL_EXPORT __export +# else /* newer compiler */ +# define PNG_DLL_EXPORT __declspec(dllexport) +# ifndef PNG_DLL_IMPORT +# define PNG_DLL_IMPORT __declspec(dllimport) +# endif +# endif /* compiler */ + +#else /* !Windows */ +# if (defined(__IBMC__) || defined(__IBMCPP__)) && defined(__OS2__) +# define PNGAPI _System +# else /* !Windows/x86 && !OS/2 */ + /* Use the defaults, or define PNG*API on the command line (but + * this will have to be done for every compile!) + */ +# endif /* other system, !OS/2 */ +#endif /* !Windows/x86 */ + +/* Now do all the defaulting . */ +#ifndef PNGCAPI +# define PNGCAPI +#endif +#ifndef PNGCBAPI +# define PNGCBAPI PNGCAPI +#endif +#ifndef PNGAPI +# define PNGAPI PNGCAPI +#endif + +/* PNG_IMPEXP may be set on the compilation system command line or (if not set) + * then in an internal header file when building the library, otherwise (when + * using the library) it is set here. + */ +#ifndef PNG_IMPEXP +# if defined(PNG_USE_DLL) && defined(PNG_DLL_IMPORT) + /* This forces use of a DLL, disallowing static linking */ +# define PNG_IMPEXP PNG_DLL_IMPORT +# endif + +# ifndef PNG_IMPEXP +# define PNG_IMPEXP +# endif +#endif + +/* In 1.5.2 the definition of PNG_FUNCTION has been changed to always treat + * 'attributes' as a storage class - the attributes go at the start of the + * function definition, and attributes are always appended regardless of the + * compiler. This considerably simplifies these macros but may cause problems + * if any compilers both need function attributes and fail to handle them as + * a storage class (this is unlikely.) + */ +#ifndef PNG_FUNCTION +# define PNG_FUNCTION(type, name, args, attributes) attributes type name args +#endif + +#ifndef PNG_EXPORT_TYPE +# define PNG_EXPORT_TYPE(type) PNG_IMPEXP type +#endif + + /* The ordinal value is only relevant when preprocessing png.h for symbol + * table entries, so we discard it here. See the .dfn files in the + * scripts directory. + */ +#ifndef PNG_EXPORTA + +# define PNG_EXPORTA(ordinal, type, name, args, attributes)\ + PNG_FUNCTION(PNG_EXPORT_TYPE(type),(PNGAPI name),PNGARG(args), \ + extern attributes) +#endif + +/* ANSI-C (C90) does not permit a macro to be invoked with an empty argument, + * so make something non-empty to satisfy the requirement: + */ +#define PNG_EMPTY /*empty list*/ + +#define PNG_EXPORT(ordinal, type, name, args)\ + PNG_EXPORTA(ordinal, type, name, args, PNG_EMPTY) + +/* Use PNG_REMOVED to comment out a removed interface. */ +#ifndef PNG_REMOVED +# define PNG_REMOVED(ordinal, type, name, args, attributes) +#endif + +#ifndef PNG_CALLBACK +# define PNG_CALLBACK(type, name, args) type (PNGCBAPI name) PNGARG(args) +#endif + +/* Support for compiler specific function attributes. These are used + * so that where compiler support is available incorrect use of API + * functions in png.h will generate compiler warnings. + * + * Added at libpng-1.2.41. + */ + +#ifndef PNG_NO_PEDANTIC_WARNINGS +# ifndef PNG_PEDANTIC_WARNINGS_SUPPORTED +# define PNG_PEDANTIC_WARNINGS_SUPPORTED +# endif +#endif + +#ifdef PNG_PEDANTIC_WARNINGS_SUPPORTED + /* Support for compiler specific function attributes. These are used + * so that where compiler support is available, incorrect use of API + * functions in png.h will generate compiler warnings. Added at libpng + * version 1.2.41. Disabling these removes the warnings but may also produce + * less efficient code. + */ +# if defined(__GNUC__) +# ifndef PNG_USE_RESULT +# define PNG_USE_RESULT __attribute__((__warn_unused_result__)) +# endif +# ifndef PNG_NORETURN +# define PNG_NORETURN __attribute__((__noreturn__)) +# endif +# if __GNUC__ >= 3 +# ifndef PNG_ALLOCATED +# define PNG_ALLOCATED __attribute__((__malloc__)) +# endif +# ifndef PNG_DEPRECATED +# define PNG_DEPRECATED __attribute__((__deprecated__)) +# endif +# ifndef PNG_PRIVATE +# if 0 /* Doesn't work so we use deprecated instead*/ +# define PNG_PRIVATE \ + __attribute__((warning("This function is not exported by libpng."))) +# else +# define PNG_PRIVATE \ + __attribute__((__deprecated__)) +# endif +# endif +# if ((__GNUC__ != 3) || !defined(__GNUC_MINOR__) || (__GNUC_MINOR__ >= 1)) +# ifndef PNG_RESTRICT +# define PNG_RESTRICT __restrict +# endif +# endif /* __GNUC__ == 3.0 */ +# endif /* __GNUC__ >= 3 */ + +# elif defined(_MSC_VER) && (_MSC_VER >= 1300) +# ifndef PNG_USE_RESULT +# define PNG_USE_RESULT /* not supported */ +# endif +# ifndef PNG_NORETURN +# define PNG_NORETURN __declspec(noreturn) +# endif +# ifndef PNG_ALLOCATED +# if (_MSC_VER >= 1400) +# define PNG_ALLOCATED __declspec(restrict) +# endif +# endif +# ifndef PNG_DEPRECATED +# define PNG_DEPRECATED __declspec(deprecated) +# endif +# ifndef PNG_PRIVATE +# define PNG_PRIVATE __declspec(deprecated) +# endif +# ifndef PNG_RESTRICT +# if (_MSC_VER >= 1400) +# define PNG_RESTRICT __restrict +# endif +# endif + +# elif defined(__WATCOMC__) +# ifndef PNG_RESTRICT +# define PNG_RESTRICT __restrict +# endif +# endif /* _MSC_VER */ +#endif /* PNG_PEDANTIC_WARNINGS */ + +#ifndef PNG_DEPRECATED +# define PNG_DEPRECATED /* Use of this function is deprecated */ +#endif +#ifndef PNG_USE_RESULT +# define PNG_USE_RESULT /* The result of this function must be checked */ +#endif +#ifndef PNG_NORETURN +# define PNG_NORETURN /* This function does not return */ +#endif +#ifndef PNG_ALLOCATED +# define PNG_ALLOCATED /* The result of the function is new memory */ +#endif +#ifndef PNG_PRIVATE +# define PNG_PRIVATE /* This is a private libpng function */ +#endif +#ifndef PNG_RESTRICT +# define PNG_RESTRICT /* The C99 "restrict" feature */ +#endif +#ifndef PNG_FP_EXPORT /* A floating point API. */ +# ifdef PNG_FLOATING_POINT_SUPPORTED +# define PNG_FP_EXPORT(ordinal, type, name, args)\ + PNG_EXPORT(ordinal, type, name, args); +# else /* No floating point APIs */ +# define PNG_FP_EXPORT(ordinal, type, name, args) +# endif +#endif +#ifndef PNG_FIXED_EXPORT /* A fixed point API. */ +# ifdef PNG_FIXED_POINT_SUPPORTED +# define PNG_FIXED_EXPORT(ordinal, type, name, args)\ + PNG_EXPORT(ordinal, type, name, args); +# else /* No fixed point APIs */ +# define PNG_FIXED_EXPORT(ordinal, type, name, args) +# endif +#endif + +#ifndef PNG_BUILDING_SYMBOL_TABLE +/* Some typedefs to get us started. These should be safe on most of the common + * platforms. + * + * png_uint_32 and png_int_32 may, currently, be larger than required to hold a + * 32-bit value however this is not normally advisable. + * + * png_uint_16 and png_int_16 should always be two bytes in size - this is + * verified at library build time. + * + * png_byte must always be one byte in size. + * + * The checks below use constants from limits.h, as defined by the ISOC90 + * standard. + */ +#if CHAR_BIT == 8 && UCHAR_MAX == 255 + typedef unsigned char png_byte; +#else +# error "libpng requires 8 bit bytes" +#endif + +#if INT_MIN == -32768 && INT_MAX == 32767 + typedef int png_int_16; +#elif SHRT_MIN == -32768 && SHRT_MAX == 32767 + typedef short png_int_16; +#else +# error "libpng requires a signed 16 bit type" +#endif + +#if UINT_MAX == 65535 + typedef unsigned int png_uint_16; +#elif USHRT_MAX == 65535 + typedef unsigned short png_uint_16; +#else +# error "libpng requires an unsigned 16 bit type" +#endif + +#if INT_MIN < -2147483646 && INT_MAX > 2147483646 + typedef int png_int_32; +#elif LONG_MIN < -2147483646 && LONG_MAX > 2147483646 + typedef long int png_int_32; +#else +# error "libpng requires a signed 32 bit (or more) type" +#endif + +#if UINT_MAX > 4294967294 + typedef unsigned int png_uint_32; +#elif ULONG_MAX > 4294967294 + typedef unsigned long int png_uint_32; +#else +# error "libpng requires an unsigned 32 bit (or more) type" +#endif + +/* Prior to 1.6.0 it was possible to disable the use of size_t, 1.6.0, however, + * requires an ISOC90 compiler and relies on consistent behavior of sizeof. + */ +typedef size_t png_size_t; +typedef ptrdiff_t png_ptrdiff_t; + +/* libpng needs to know the maximum value of 'size_t' and this controls the + * definition of png_alloc_size_t, below. This maximum value of size_t limits + * but does not control the maximum allocations the library makes - there is + * direct application control of this through png_set_user_limits(). + */ +#ifndef PNG_SMALL_SIZE_T + /* Compiler specific tests for systems where size_t is known to be less than + * 32 bits (some of these systems may no longer work because of the lack of + * 'far' support; see above.) + */ +# if (defined(__TURBOC__) && !defined(__FLAT__)) ||\ + (defined(_MSC_VER) && defined(MAXSEG_64K)) +# define PNG_SMALL_SIZE_T +# endif +#endif + +/* png_alloc_size_t is guaranteed to be no smaller than png_size_t, and no + * smaller than png_uint_32. Casts from png_size_t or png_uint_32 to + * png_alloc_size_t are not necessary; in fact, it is recommended not to use + * them at all so that the compiler can complain when something turns out to be + * problematic. + * + * Casts in the other direction (from png_alloc_size_t to png_size_t or + * png_uint_32) should be explicitly applied; however, we do not expect to + * encounter practical situations that require such conversions. + * + * PNG_SMALL_SIZE_T must be defined if the maximum value of size_t is less than + * 4294967295 - i.e. less than the maximum value of png_uint_32. + */ +#ifdef PNG_SMALL_SIZE_T + typedef png_uint_32 png_alloc_size_t; +#else + typedef png_size_t png_alloc_size_t; +#endif + +/* Prior to 1.6.0 libpng offered limited support for Microsoft C compiler + * implementations of Intel CPU specific support of user-mode segmented address + * spaces, where 16-bit pointers address more than 65536 bytes of memory using + * separate 'segment' registers. The implementation requires two different + * types of pointer (only one of which includes the segment value.) + * + * If required this support is available in version 1.2 of libpng and may be + * available in versions through 1.5, although the correctness of the code has + * not been verified recently. + */ + +/* Typedef for floating-point numbers that are converted to fixed-point with a + * multiple of 100,000, e.g., gamma + */ +typedef png_int_32 png_fixed_point; + +/* Add typedefs for pointers */ +typedef void * png_voidp; +typedef const void * png_const_voidp; +typedef png_byte * png_bytep; +typedef const png_byte * png_const_bytep; +typedef png_uint_32 * png_uint_32p; +typedef const png_uint_32 * png_const_uint_32p; +typedef png_int_32 * png_int_32p; +typedef const png_int_32 * png_const_int_32p; +typedef png_uint_16 * png_uint_16p; +typedef const png_uint_16 * png_const_uint_16p; +typedef png_int_16 * png_int_16p; +typedef const png_int_16 * png_const_int_16p; +typedef char * png_charp; +typedef const char * png_const_charp; +typedef png_fixed_point * png_fixed_point_p; +typedef const png_fixed_point * png_const_fixed_point_p; +typedef png_size_t * png_size_tp; +typedef const png_size_t * png_const_size_tp; + +#ifdef PNG_STDIO_SUPPORTED +typedef FILE * png_FILE_p; +#endif + +#ifdef PNG_FLOATING_POINT_SUPPORTED +typedef double * png_doublep; +typedef const double * png_const_doublep; +#endif + +/* Pointers to pointers; i.e. arrays */ +typedef png_byte * * png_bytepp; +typedef png_uint_32 * * png_uint_32pp; +typedef png_int_32 * * png_int_32pp; +typedef png_uint_16 * * png_uint_16pp; +typedef png_int_16 * * png_int_16pp; +typedef const char * * png_const_charpp; +typedef char * * png_charpp; +typedef png_fixed_point * * png_fixed_point_pp; +#ifdef PNG_FLOATING_POINT_SUPPORTED +typedef double * * png_doublepp; +#endif + +/* Pointers to pointers to pointers; i.e., pointer to array */ +typedef char * * * png_charppp; + +#endif /* PNG_BUILDING_SYMBOL_TABLE */ + +#endif /* PNGCONF_H */ diff --git a/ml/dlib/dlib/external/libpng/pngdebug.h b/ml/dlib/dlib/external/libpng/pngdebug.h new file mode 100644 index 000000000..16f81fdd1 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngdebug.h @@ -0,0 +1,157 @@ + +/* pngdebug.h - Debugging macros for libpng, also used in pngtest.c + * + * Copyright (c) 1998-2011 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * Last changed in libpng 1.5.0 [January 6, 2011] + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +/* Define PNG_DEBUG at compile time for debugging information. Higher + * numbers for PNG_DEBUG mean more debugging information. This has + * only been added since version 0.95 so it is not implemented throughout + * libpng yet, but more support will be added as needed. + * + * png_debug[1-2]?(level, message ,arg{0-2}) + * Expands to a statement (either a simple expression or a compound + * do..while(0) statement) that outputs a message with parameter + * substitution if PNG_DEBUG is defined to 2 or more. If PNG_DEBUG + * is undefined, 0 or 1 every png_debug expands to a simple expression + * (actually ((void)0)). + * + * level: level of detail of message, starting at 0. A level 'n' + * message is preceded by 'n' tab characters (not implemented + * on Microsoft compilers unless PNG_DEBUG_FILE is also + * defined, to allow debug DLL compilation with no standard IO). + * message: a printf(3) style text string. A trailing '\n' is added + * to the message. + * arg: 0 to 2 arguments for printf(3) style substitution in message. + */ +#ifndef PNGDEBUG_H +#define PNGDEBUG_H +/* These settings control the formatting of messages in png.c and pngerror.c */ +/* Moved to pngdebug.h at 1.5.0 */ +# ifndef PNG_LITERAL_SHARP +# define PNG_LITERAL_SHARP 0x23 +# endif +# ifndef PNG_LITERAL_LEFT_SQUARE_BRACKET +# define PNG_LITERAL_LEFT_SQUARE_BRACKET 0x5b +# endif +# ifndef PNG_LITERAL_RIGHT_SQUARE_BRACKET +# define PNG_LITERAL_RIGHT_SQUARE_BRACKET 0x5d +# endif +# ifndef PNG_STRING_NEWLINE +# define PNG_STRING_NEWLINE "\n" +# endif + +#ifdef PNG_DEBUG +# if (PNG_DEBUG > 0) +# if !defined(PNG_DEBUG_FILE) && defined(_MSC_VER) +# include +# if (PNG_DEBUG > 1) +# ifndef _DEBUG +# define _DEBUG +# endif +# ifndef png_debug +# define png_debug(l,m) _RPT0(_CRT_WARN,m PNG_STRING_NEWLINE) +# endif +# ifndef png_debug1 +# define png_debug1(l,m,p1) _RPT1(_CRT_WARN,m PNG_STRING_NEWLINE,p1) +# endif +# ifndef png_debug2 +# define png_debug2(l,m,p1,p2) \ + _RPT2(_CRT_WARN,m PNG_STRING_NEWLINE,p1,p2) +# endif +# endif +# else /* PNG_DEBUG_FILE || !_MSC_VER */ +# ifndef PNG_STDIO_SUPPORTED +# include /* not included yet */ +# endif +# ifndef PNG_DEBUG_FILE +# define PNG_DEBUG_FILE stderr +# endif /* PNG_DEBUG_FILE */ + +# if (PNG_DEBUG > 1) +/* Note: ["%s"m PNG_STRING_NEWLINE] probably does not work on + * non-ISO compilers + */ +# ifdef __STDC__ +# ifndef png_debug +# define png_debug(l,m) \ + do { \ + int num_tabs=l; \ + fprintf(PNG_DEBUG_FILE,"%s"m PNG_STRING_NEWLINE,(num_tabs==1 ? "\t" : \ + (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":"")))); \ + } while (0) +# endif +# ifndef png_debug1 +# define png_debug1(l,m,p1) \ + do { \ + int num_tabs=l; \ + fprintf(PNG_DEBUG_FILE,"%s"m PNG_STRING_NEWLINE,(num_tabs==1 ? "\t" : \ + (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))),p1); \ + } while (0) +# endif +# ifndef png_debug2 +# define png_debug2(l,m,p1,p2) \ + do { \ + int num_tabs=l; \ + fprintf(PNG_DEBUG_FILE,"%s"m PNG_STRING_NEWLINE,(num_tabs==1 ? "\t" : \ + (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))),p1,p2); \ + } while (0) +# endif +# else /* __STDC __ */ +# ifndef png_debug +# define png_debug(l,m) \ + do { \ + int num_tabs=l; \ + char format[256]; \ + snprintf(format,256,"%s%s%s",(num_tabs==1 ? "\t" : \ + (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))), \ + m,PNG_STRING_NEWLINE); \ + fprintf(PNG_DEBUG_FILE,format); \ + } while (0) +# endif +# ifndef png_debug1 +# define png_debug1(l,m,p1) \ + do { \ + int num_tabs=l; \ + char format[256]; \ + snprintf(format,256,"%s%s%s",(num_tabs==1 ? "\t" : \ + (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))), \ + m,PNG_STRING_NEWLINE); \ + fprintf(PNG_DEBUG_FILE,format,p1); \ + } while (0) +# endif +# ifndef png_debug2 +# define png_debug2(l,m,p1,p2) \ + do { \ + int num_tabs=l; \ + char format[256]; \ + snprintf(format,256,"%s%s%s",(num_tabs==1 ? "\t" : \ + (num_tabs==2 ? "\t\t":(num_tabs>2 ? "\t\t\t":""))), \ + m,PNG_STRING_NEWLINE); \ + fprintf(PNG_DEBUG_FILE,format,p1,p2); \ + } while (0) +# endif +# endif /* __STDC __ */ +# endif /* (PNG_DEBUG > 1) */ + +# endif /* _MSC_VER */ +# endif /* (PNG_DEBUG > 0) */ +#endif /* PNG_DEBUG */ +#ifndef png_debug +# define png_debug(l, m) ((void)0) +#endif +#ifndef png_debug1 +# define png_debug1(l, m, p1) ((void)0) +#endif +#ifndef png_debug2 +# define png_debug2(l, m, p1, p2) ((void)0) +#endif +#endif /* PNGDEBUG_H */ diff --git a/ml/dlib/dlib/external/libpng/pngerror.c b/ml/dlib/dlib/external/libpng/pngerror.c new file mode 100644 index 000000000..f469206ee --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngerror.c @@ -0,0 +1,932 @@ + +/* pngerror.c - stub functions for i/o and memory allocation + * + * Last changed in libpng 1.6.1 [March 28, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + * This file provides a location for all error handling. Users who + * need special error handling are expected to write replacement functions + * and use png_set_error_fn() to use those functions. See the instructions + * at each function. + */ + +#include "pngpriv.h" + +#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) + +static PNG_FUNCTION(void, png_default_error,PNGARG((png_const_structrp png_ptr, + png_const_charp error_message)),PNG_NORETURN); + +#ifdef PNG_WARNINGS_SUPPORTED +static void /* PRIVATE */ +png_default_warning PNGARG((png_const_structrp png_ptr, + png_const_charp warning_message)); +#endif /* PNG_WARNINGS_SUPPORTED */ + +/* This function is called whenever there is a fatal error. This function + * should not be changed. If there is a need to handle errors differently, + * you should supply a replacement error function and use png_set_error_fn() + * to replace the error function at run-time. + */ +#ifdef PNG_ERROR_TEXT_SUPPORTED +PNG_FUNCTION(void,PNGAPI +png_error,(png_const_structrp png_ptr, png_const_charp error_message), + PNG_NORETURN) +{ +#ifdef PNG_ERROR_NUMBERS_SUPPORTED + char msg[16]; + if (png_ptr != NULL) + { + if (png_ptr->flags& + (PNG_FLAG_STRIP_ERROR_NUMBERS|PNG_FLAG_STRIP_ERROR_TEXT)) + { + if (*error_message == PNG_LITERAL_SHARP) + { + /* Strip "#nnnn " from beginning of error message. */ + int offset; + for (offset = 1; offset<15; offset++) + if (error_message[offset] == ' ') + break; + + if (png_ptr->flags&PNG_FLAG_STRIP_ERROR_TEXT) + { + int i; + for (i = 0; i < offset - 1; i++) + msg[i] = error_message[i + 1]; + msg[i - 1] = '\0'; + error_message = msg; + } + + else + error_message += offset; + } + + else + { + if (png_ptr->flags&PNG_FLAG_STRIP_ERROR_TEXT) + { + msg[0] = '0'; + msg[1] = '\0'; + error_message = msg; + } + } + } + } +#endif + if (png_ptr != NULL && png_ptr->error_fn != NULL) + (*(png_ptr->error_fn))(png_constcast(png_structrp,png_ptr), + error_message); + + /* If the custom handler doesn't exist, or if it returns, + use the default handler, which will not return. */ + png_default_error(png_ptr, error_message); +} +#else +PNG_FUNCTION(void,PNGAPI +png_err,(png_const_structrp png_ptr),PNG_NORETURN) +{ + /* Prior to 1.5.2 the error_fn received a NULL pointer, expressed + * erroneously as '\0', instead of the empty string "". This was + * apparently an error, introduced in libpng-1.2.20, and png_default_error + * will crash in this case. + */ + if (png_ptr != NULL && png_ptr->error_fn != NULL) + (*(png_ptr->error_fn))(png_constcast(png_structrp,png_ptr), ""); + + /* If the custom handler doesn't exist, or if it returns, + use the default handler, which will not return. */ + png_default_error(png_ptr, ""); +} +#endif /* PNG_ERROR_TEXT_SUPPORTED */ + +/* Utility to safely appends strings to a buffer. This never errors out so + * error checking is not required in the caller. + */ +size_t +png_safecat(png_charp buffer, size_t bufsize, size_t pos, + png_const_charp string) +{ + if (buffer != NULL && pos < bufsize) + { + if (string != NULL) + while (*string != '\0' && pos < bufsize-1) + buffer[pos++] = *string++; + + buffer[pos] = '\0'; + } + + return pos; +} + +#if defined(PNG_WARNINGS_SUPPORTED) || defined(PNG_TIME_RFC1123_SUPPORTED) +/* Utility to dump an unsigned value into a buffer, given a start pointer and + * and end pointer (which should point just *beyond* the end of the buffer!) + * Returns the pointer to the start of the formatted string. + */ +png_charp +png_format_number(png_const_charp start, png_charp end, int format, + png_alloc_size_t number) +{ + int count = 0; /* number of digits output */ + int mincount = 1; /* minimum number required */ + int output = 0; /* digit output (for the fixed point format) */ + + *--end = '\0'; + + /* This is written so that the loop always runs at least once, even with + * number zero. + */ + while (end > start && (number != 0 || count < mincount)) + { + + static const char digits[] = "0123456789ABCDEF"; + + switch (format) + { + case PNG_NUMBER_FORMAT_fixed: + /* Needs five digits (the fraction) */ + mincount = 5; + if (output || number % 10 != 0) + { + *--end = digits[number % 10]; + output = 1; + } + number /= 10; + break; + + case PNG_NUMBER_FORMAT_02u: + /* Expects at least 2 digits. */ + mincount = 2; + /* FALL THROUGH */ + + case PNG_NUMBER_FORMAT_u: + *--end = digits[number % 10]; + number /= 10; + break; + + case PNG_NUMBER_FORMAT_02x: + /* This format expects at least two digits */ + mincount = 2; + /* FALL THROUGH */ + + case PNG_NUMBER_FORMAT_x: + *--end = digits[number & 0xf]; + number >>= 4; + break; + + default: /* an error */ + number = 0; + break; + } + + /* Keep track of the number of digits added */ + ++count; + + /* Float a fixed number here: */ + if (format == PNG_NUMBER_FORMAT_fixed) if (count == 5) if (end > start) + { + /* End of the fraction, but maybe nothing was output? In that case + * drop the decimal point. If the number is a true zero handle that + * here. + */ + if (output) + *--end = '.'; + else if (number == 0) /* and !output */ + *--end = '0'; + } + } + + return end; +} +#endif + +#ifdef PNG_WARNINGS_SUPPORTED +/* This function is called whenever there is a non-fatal error. This function + * should not be changed. If there is a need to handle warnings differently, + * you should supply a replacement warning function and use + * png_set_error_fn() to replace the warning function at run-time. + */ +void PNGAPI +png_warning(png_const_structrp png_ptr, png_const_charp warning_message) +{ + int offset = 0; + if (png_ptr != NULL) + { +#ifdef PNG_ERROR_NUMBERS_SUPPORTED + if (png_ptr->flags& + (PNG_FLAG_STRIP_ERROR_NUMBERS|PNG_FLAG_STRIP_ERROR_TEXT)) +#endif + { + if (*warning_message == PNG_LITERAL_SHARP) + { + for (offset = 1; offset < 15; offset++) + if (warning_message[offset] == ' ') + break; + } + } + } + if (png_ptr != NULL && png_ptr->warning_fn != NULL) + (*(png_ptr->warning_fn))(png_constcast(png_structrp,png_ptr), + warning_message + offset); + else + png_default_warning(png_ptr, warning_message + offset); +} + +/* These functions support 'formatted' warning messages with up to + * PNG_WARNING_PARAMETER_COUNT parameters. In the format string the parameter + * is introduced by @, where 'number' starts at 1. This follows the + * standard established by X/Open for internationalizable error messages. + */ +void +png_warning_parameter(png_warning_parameters p, int number, + png_const_charp string) +{ + if (number > 0 && number <= PNG_WARNING_PARAMETER_COUNT) + (void)png_safecat(p[number-1], (sizeof p[number-1]), 0, string); +} + +void +png_warning_parameter_unsigned(png_warning_parameters p, int number, int format, + png_alloc_size_t value) +{ + char buffer[PNG_NUMBER_BUFFER_SIZE]; + png_warning_parameter(p, number, PNG_FORMAT_NUMBER(buffer, format, value)); +} + +void +png_warning_parameter_signed(png_warning_parameters p, int number, int format, + png_int_32 value) +{ + png_alloc_size_t u; + png_charp str; + char buffer[PNG_NUMBER_BUFFER_SIZE]; + + /* Avoid overflow by doing the negate in a png_alloc_size_t: */ + u = (png_alloc_size_t)value; + if (value < 0) + u = ~u + 1; + + str = PNG_FORMAT_NUMBER(buffer, format, u); + + if (value < 0 && str > buffer) + *--str = '-'; + + png_warning_parameter(p, number, str); +} + +void +png_formatted_warning(png_const_structrp png_ptr, png_warning_parameters p, + png_const_charp message) +{ + /* The internal buffer is just 192 bytes - enough for all our messages, + * overflow doesn't happen because this code checks! If someone figures + * out how to send us a message longer than 192 bytes, all that will + * happen is that the message will be truncated appropriately. + */ + size_t i = 0; /* Index in the msg[] buffer: */ + char msg[192]; + + /* Each iteration through the following loop writes at most one character + * to msg[i++] then returns here to validate that there is still space for + * the trailing '\0'. It may (in the case of a parameter) read more than + * one character from message[]; it must check for '\0' and continue to the + * test if it finds the end of string. + */ + while (i<(sizeof msg)-1 && *message != '\0') + { + /* '@' at end of string is now just printed (previously it was skipped); + * it is an error in the calling code to terminate the string with @. + */ + if (p != NULL && *message == '@' && message[1] != '\0') + { + int parameter_char = *++message; /* Consume the '@' */ + static const char valid_parameters[] = "123456789"; + int parameter = 0; + + /* Search for the parameter digit, the index in the string is the + * parameter to use. + */ + while (valid_parameters[parameter] != parameter_char && + valid_parameters[parameter] != '\0') + ++parameter; + + /* If the parameter digit is out of range it will just get printed. */ + if (parameter < PNG_WARNING_PARAMETER_COUNT) + { + /* Append this parameter */ + png_const_charp parm = p[parameter]; + png_const_charp pend = p[parameter] + (sizeof p[parameter]); + + /* No need to copy the trailing '\0' here, but there is no guarantee + * that parm[] has been initialized, so there is no guarantee of a + * trailing '\0': + */ + while (i<(sizeof msg)-1 && *parm != '\0' && parm < pend) + msg[i++] = *parm++; + + /* Consume the parameter digit too: */ + ++message; + continue; + } + + /* else not a parameter and there is a character after the @ sign; just + * copy that. This is known not to be '\0' because of the test above. + */ + } + + /* At this point *message can't be '\0', even in the bad parameter case + * above where there is a lone '@' at the end of the message string. + */ + msg[i++] = *message++; + } + + /* i is always less than (sizeof msg), so: */ + msg[i] = '\0'; + + /* And this is the formatted message. It may be larger than + * PNG_MAX_ERROR_TEXT, but that is only used for 'chunk' errors and these + * are not (currently) formatted. + */ + png_warning(png_ptr, msg); +} +#endif /* PNG_WARNINGS_SUPPORTED */ + +#ifdef PNG_BENIGN_ERRORS_SUPPORTED +void PNGAPI +png_benign_error(png_const_structrp png_ptr, png_const_charp error_message) +{ + if (png_ptr->flags & PNG_FLAG_BENIGN_ERRORS_WARN) + { +# ifdef PNG_READ_SUPPORTED + if ((png_ptr->mode & PNG_IS_READ_STRUCT) != 0 && + png_ptr->chunk_name != 0) + png_chunk_warning(png_ptr, error_message); + else +# endif + png_warning(png_ptr, error_message); + } + + else + { +# ifdef PNG_READ_SUPPORTED + if ((png_ptr->mode & PNG_IS_READ_STRUCT) != 0 && + png_ptr->chunk_name != 0) + png_chunk_error(png_ptr, error_message); + else +# endif + png_error(png_ptr, error_message); + } +} + +void /* PRIVATE */ +png_app_warning(png_const_structrp png_ptr, png_const_charp error_message) +{ + if (png_ptr->flags & PNG_FLAG_APP_WARNINGS_WARN) + png_warning(png_ptr, error_message); + else + png_error(png_ptr, error_message); +} + +void /* PRIVATE */ +png_app_error(png_const_structrp png_ptr, png_const_charp error_message) +{ + if (png_ptr->flags & PNG_FLAG_APP_ERRORS_WARN) + png_warning(png_ptr, error_message); + else + png_error(png_ptr, error_message); +} +#endif /* BENIGN_ERRORS */ + +/* These utilities are used internally to build an error message that relates + * to the current chunk. The chunk name comes from png_ptr->chunk_name, + * this is used to prefix the message. The message is limited in length + * to 63 bytes, the name characters are output as hex digits wrapped in [] + * if the character is invalid. + */ +#define isnonalpha(c) ((c) < 65 || (c) > 122 || ((c) > 90 && (c) < 97)) +static PNG_CONST char png_digit[16] = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + 'A', 'B', 'C', 'D', 'E', 'F' +}; + +#define PNG_MAX_ERROR_TEXT 196 /* Currently limited be profile_error in png.c */ +#if defined(PNG_WARNINGS_SUPPORTED) || defined(PNG_ERROR_TEXT_SUPPORTED) +static void /* PRIVATE */ +png_format_buffer(png_const_structrp png_ptr, png_charp buffer, png_const_charp + error_message) +{ + png_uint_32 chunk_name = png_ptr->chunk_name; + int iout = 0, ishift = 24; + + while (ishift >= 0) + { + int c = (int)(chunk_name >> ishift) & 0xff; + + ishift -= 8; + if (isnonalpha(c)) + { + buffer[iout++] = PNG_LITERAL_LEFT_SQUARE_BRACKET; + buffer[iout++] = png_digit[(c & 0xf0) >> 4]; + buffer[iout++] = png_digit[c & 0x0f]; + buffer[iout++] = PNG_LITERAL_RIGHT_SQUARE_BRACKET; + } + + else + { + buffer[iout++] = (char)c; + } + } + + if (error_message == NULL) + buffer[iout] = '\0'; + + else + { + int iin = 0; + + buffer[iout++] = ':'; + buffer[iout++] = ' '; + + while (iin < PNG_MAX_ERROR_TEXT-1 && error_message[iin] != '\0') + buffer[iout++] = error_message[iin++]; + + /* iin < PNG_MAX_ERROR_TEXT, so the following is safe: */ + buffer[iout] = '\0'; + } +} +#endif /* PNG_WARNINGS_SUPPORTED || PNG_ERROR_TEXT_SUPPORTED */ + +#if defined(PNG_READ_SUPPORTED) && defined(PNG_ERROR_TEXT_SUPPORTED) +PNG_FUNCTION(void,PNGAPI +png_chunk_error,(png_const_structrp png_ptr, png_const_charp error_message), + PNG_NORETURN) +{ + char msg[18+PNG_MAX_ERROR_TEXT]; + if (png_ptr == NULL) + png_error(png_ptr, error_message); + + else + { + png_format_buffer(png_ptr, msg, error_message); + png_error(png_ptr, msg); + } +} +#endif /* PNG_READ_SUPPORTED && PNG_ERROR_TEXT_SUPPORTED */ + +#ifdef PNG_WARNINGS_SUPPORTED +void PNGAPI +png_chunk_warning(png_const_structrp png_ptr, png_const_charp warning_message) +{ + char msg[18+PNG_MAX_ERROR_TEXT]; + if (png_ptr == NULL) + png_warning(png_ptr, warning_message); + + else + { + png_format_buffer(png_ptr, msg, warning_message); + png_warning(png_ptr, msg); + } +} +#endif /* PNG_WARNINGS_SUPPORTED */ + +#ifdef PNG_READ_SUPPORTED +#ifdef PNG_BENIGN_ERRORS_SUPPORTED +void PNGAPI +png_chunk_benign_error(png_const_structrp png_ptr, png_const_charp + error_message) +{ + if (png_ptr->flags & PNG_FLAG_BENIGN_ERRORS_WARN) + png_chunk_warning(png_ptr, error_message); + + else + png_chunk_error(png_ptr, error_message); +} +#endif +#endif /* PNG_READ_SUPPORTED */ + +void /* PRIVATE */ +png_chunk_report(png_const_structrp png_ptr, png_const_charp message, int error) +{ + /* This is always supported, but for just read or just write it + * unconditionally does the right thing. + */ +# if defined(PNG_READ_SUPPORTED) && defined(PNG_WRITE_SUPPORTED) + if (png_ptr->mode & PNG_IS_READ_STRUCT) +# endif + +# ifdef PNG_READ_SUPPORTED + { + if (error < PNG_CHUNK_ERROR) + png_chunk_warning(png_ptr, message); + + else + png_chunk_benign_error(png_ptr, message); + } +# endif + +# if defined(PNG_READ_SUPPORTED) && defined(PNG_WRITE_SUPPORTED) + else if (!(png_ptr->mode & PNG_IS_READ_STRUCT)) +# endif + +# ifdef PNG_WRITE_SUPPORTED + { + if (error < PNG_CHUNK_WRITE_ERROR) + png_app_warning(png_ptr, message); + + else + png_app_error(png_ptr, message); + } +# endif +} + +#ifdef PNG_ERROR_TEXT_SUPPORTED +#ifdef PNG_FLOATING_POINT_SUPPORTED +PNG_FUNCTION(void, +png_fixed_error,(png_const_structrp png_ptr, png_const_charp name),PNG_NORETURN) +{ +# define fixed_message "fixed point overflow in " +# define fixed_message_ln ((sizeof fixed_message)-1) + int iin; + char msg[fixed_message_ln+PNG_MAX_ERROR_TEXT]; + memcpy(msg, fixed_message, fixed_message_ln); + iin = 0; + if (name != NULL) while (iin < (PNG_MAX_ERROR_TEXT-1) && name[iin] != 0) + { + msg[fixed_message_ln + iin] = name[iin]; + ++iin; + } + msg[fixed_message_ln + iin] = 0; + png_error(png_ptr, msg); +} +#endif +#endif + +#ifdef PNG_SETJMP_SUPPORTED +/* This API only exists if ANSI-C style error handling is used, + * otherwise it is necessary for png_default_error to be overridden. + */ +jmp_buf* PNGAPI +png_set_longjmp_fn(png_structrp png_ptr, png_longjmp_ptr longjmp_fn, + size_t jmp_buf_size) +{ + /* From libpng 1.6.0 the app gets one chance to set a 'jmpbuf_size' value + * and it must not change after that. Libpng doesn't care how big the + * buffer is, just that it doesn't change. + * + * If the buffer size is no *larger* than the size of jmp_buf when libpng is + * compiled a built in jmp_buf is returned; this preserves the pre-1.6.0 + * semantics that this call will not fail. If the size is larger, however, + * the buffer is allocated and this may fail, causing the function to return + * NULL. + */ + if (png_ptr == NULL) + return NULL; + + if (png_ptr->jmp_buf_ptr == NULL) + { + png_ptr->jmp_buf_size = 0; /* not allocated */ + + if (jmp_buf_size <= (sizeof png_ptr->jmp_buf_local)) + png_ptr->jmp_buf_ptr = &png_ptr->jmp_buf_local; + + else + { + png_ptr->jmp_buf_ptr = png_voidcast(jmp_buf *, + png_malloc_warn(png_ptr, jmp_buf_size)); + + if (png_ptr->jmp_buf_ptr == NULL) + return NULL; /* new NULL return on OOM */ + + png_ptr->jmp_buf_size = jmp_buf_size; + } + } + + else /* Already allocated: check the size */ + { + size_t size = png_ptr->jmp_buf_size; + + if (size == 0) + { + size = (sizeof png_ptr->jmp_buf_local); + if (png_ptr->jmp_buf_ptr != &png_ptr->jmp_buf_local) + { + /* This is an internal error in libpng: somehow we have been left + * with a stack allocated jmp_buf when the application regained + * control. It's always possible to fix this up, but for the moment + * this is a png_error because that makes it easy to detect. + */ + png_error(png_ptr, "Libpng jmp_buf still allocated"); + /* png_ptr->jmp_buf_ptr = &png_ptr->jmp_buf_local; */ + } + } + + if (size != jmp_buf_size) + { + png_warning(png_ptr, "Application jmp_buf size changed"); + return NULL; /* caller will probably crash: no choice here */ + } + } + + /* Finally fill in the function, now we have a satisfactory buffer. It is + * valid to change the function on every call. + */ + png_ptr->longjmp_fn = longjmp_fn; + return png_ptr->jmp_buf_ptr; +} + +void /* PRIVATE */ +png_free_jmpbuf(png_structrp png_ptr) +{ + if (png_ptr != NULL) + { + jmp_buf *jb = png_ptr->jmp_buf_ptr; + + /* A size of 0 is used to indicate a local, stack, allocation of the + * pointer; used here and in png.c + */ + if (jb != NULL && png_ptr->jmp_buf_size > 0) + { + + /* This stuff is so that a failure to free the error control structure + * does not leave libpng in a state with no valid error handling: the + * free always succeeds, if there is an error it gets ignored. + */ + if (jb != &png_ptr->jmp_buf_local) + { + /* Make an internal, libpng, jmp_buf to return here */ + jmp_buf free_jmp_buf; + + if (!setjmp(free_jmp_buf)) + { + png_ptr->jmp_buf_ptr = &free_jmp_buf; /* come back here */ + png_ptr->jmp_buf_size = 0; /* stack allocation */ + png_ptr->longjmp_fn = longjmp; + png_free(png_ptr, jb); /* Return to setjmp on error */ + } + } + } + + /* *Always* cancel everything out: */ + png_ptr->jmp_buf_size = 0; + png_ptr->jmp_buf_ptr = NULL; + png_ptr->longjmp_fn = 0; + } +} +#endif + +/* This is the default error handling function. Note that replacements for + * this function MUST NOT RETURN, or the program will likely crash. This + * function is used by default, or if the program supplies NULL for the + * error function pointer in png_set_error_fn(). + */ +static PNG_FUNCTION(void /* PRIVATE */, +png_default_error,(png_const_structrp png_ptr, png_const_charp error_message), + PNG_NORETURN) +{ +#ifdef PNG_CONSOLE_IO_SUPPORTED +#ifdef PNG_ERROR_NUMBERS_SUPPORTED + /* Check on NULL only added in 1.5.4 */ + if (error_message != NULL && *error_message == PNG_LITERAL_SHARP) + { + /* Strip "#nnnn " from beginning of error message. */ + int offset; + char error_number[16]; + for (offset = 0; offset<15; offset++) + { + error_number[offset] = error_message[offset + 1]; + if (error_message[offset] == ' ') + break; + } + + if ((offset > 1) && (offset < 15)) + { + error_number[offset - 1] = '\0'; + fprintf(stderr, "libpng error no. %s: %s", + error_number, error_message + offset + 1); + fprintf(stderr, PNG_STRING_NEWLINE); + } + + else + { + fprintf(stderr, "libpng error: %s, offset=%d", + error_message, offset); + fprintf(stderr, PNG_STRING_NEWLINE); + } + } + else +#endif + { + fprintf(stderr, "libpng error: %s", error_message ? error_message : + "undefined"); + fprintf(stderr, PNG_STRING_NEWLINE); + } +#else + PNG_UNUSED(error_message) /* Make compiler happy */ +#endif + png_longjmp(png_ptr, 1); +} + +PNG_FUNCTION(void,PNGAPI +png_longjmp,(png_const_structrp png_ptr, int val),PNG_NORETURN) +{ +#ifdef PNG_SETJMP_SUPPORTED + if (png_ptr && png_ptr->longjmp_fn && png_ptr->jmp_buf_ptr) + png_ptr->longjmp_fn(*png_ptr->jmp_buf_ptr, val); +#endif + + /* Here if not setjmp support or if png_ptr is null. */ + PNG_ABORT(); +} + +#ifdef PNG_WARNINGS_SUPPORTED +/* This function is called when there is a warning, but the library thinks + * it can continue anyway. Replacement functions don't have to do anything + * here if you don't want them to. In the default configuration, png_ptr is + * not used, but it is passed in case it may be useful. + */ +static void /* PRIVATE */ +png_default_warning(png_const_structrp png_ptr, png_const_charp warning_message) +{ +#ifdef PNG_CONSOLE_IO_SUPPORTED +# ifdef PNG_ERROR_NUMBERS_SUPPORTED + if (*warning_message == PNG_LITERAL_SHARP) + { + int offset; + char warning_number[16]; + for (offset = 0; offset < 15; offset++) + { + warning_number[offset] = warning_message[offset + 1]; + if (warning_message[offset] == ' ') + break; + } + + if ((offset > 1) && (offset < 15)) + { + warning_number[offset + 1] = '\0'; + fprintf(stderr, "libpng warning no. %s: %s", + warning_number, warning_message + offset); + fprintf(stderr, PNG_STRING_NEWLINE); + } + + else + { + fprintf(stderr, "libpng warning: %s", + warning_message); + fprintf(stderr, PNG_STRING_NEWLINE); + } + } + else +# endif + + { + fprintf(stderr, "libpng warning: %s", warning_message); + fprintf(stderr, PNG_STRING_NEWLINE); + } +#else + PNG_UNUSED(warning_message) /* Make compiler happy */ +#endif + PNG_UNUSED(png_ptr) /* Make compiler happy */ +} +#endif /* PNG_WARNINGS_SUPPORTED */ + +/* This function is called when the application wants to use another method + * of handling errors and warnings. Note that the error function MUST NOT + * return to the calling routine or serious problems will occur. The return + * method used in the default routine calls longjmp(png_ptr->jmp_buf_ptr, 1) + */ +void PNGAPI +png_set_error_fn(png_structrp png_ptr, png_voidp error_ptr, + png_error_ptr error_fn, png_error_ptr warning_fn) +{ + if (png_ptr == NULL) + return; + + png_ptr->error_ptr = error_ptr; + png_ptr->error_fn = error_fn; +#ifdef PNG_WARNINGS_SUPPORTED + png_ptr->warning_fn = warning_fn; +#else + PNG_UNUSED(warning_fn) +#endif +} + + +/* This function returns a pointer to the error_ptr associated with the user + * functions. The application should free any memory associated with this + * pointer before png_write_destroy and png_read_destroy are called. + */ +png_voidp PNGAPI +png_get_error_ptr(png_const_structrp png_ptr) +{ + if (png_ptr == NULL) + return NULL; + + return ((png_voidp)png_ptr->error_ptr); +} + + +#ifdef PNG_ERROR_NUMBERS_SUPPORTED +void PNGAPI +png_set_strip_error_numbers(png_structrp png_ptr, png_uint_32 strip_mode) +{ + if (png_ptr != NULL) + { + png_ptr->flags &= + ((~(PNG_FLAG_STRIP_ERROR_NUMBERS | + PNG_FLAG_STRIP_ERROR_TEXT))&strip_mode); + } +} +#endif + +#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\ + defined(PNG_SIMPLIFIED_WRITE_SUPPORTED) + /* Currently the above both depend on SETJMP_SUPPORTED, however it would be + * possible to implement without setjmp support just so long as there is some + * way to handle the error return here: + */ +PNG_FUNCTION(void /* PRIVATE */, +png_safe_error,(png_structp png_nonconst_ptr, png_const_charp error_message), + PNG_NORETURN) +{ + const png_const_structrp png_ptr = png_nonconst_ptr; + png_imagep image = png_voidcast(png_imagep, png_ptr->error_ptr); + + /* An error is always logged here, overwriting anything (typically a warning) + * that is already there: + */ + if (image != NULL) + { + png_safecat(image->message, (sizeof image->message), 0, error_message); + image->warning_or_error |= PNG_IMAGE_ERROR; + + /* Retrieve the jmp_buf from within the png_control, making this work for + * C++ compilation too is pretty tricky: C++ wants a pointer to the first + * element of a jmp_buf, but C doesn't tell us the type of that. + */ + if (image->opaque != NULL && image->opaque->error_buf != NULL) + longjmp(png_control_jmp_buf(image->opaque), 1); + + /* Missing longjmp buffer, the following is to help debugging: */ + { + size_t pos = png_safecat(image->message, (sizeof image->message), 0, + "bad longjmp: "); + png_safecat(image->message, (sizeof image->message), pos, + error_message); + } + } + + /* Here on an internal programming error. */ + abort(); +} + +#ifdef PNG_WARNINGS_SUPPORTED +void /* PRIVATE */ +png_safe_warning(png_structp png_nonconst_ptr, png_const_charp warning_message) +{ + const png_const_structrp png_ptr = png_nonconst_ptr; + png_imagep image = png_voidcast(png_imagep, png_ptr->error_ptr); + + /* A warning is only logged if there is no prior warning or error. */ + if (image->warning_or_error == 0) + { + png_safecat(image->message, (sizeof image->message), 0, warning_message); + image->warning_or_error |= PNG_IMAGE_WARNING; + } +} +#endif + +int /* PRIVATE */ +png_safe_execute(png_imagep image_in, int (*function)(png_voidp), png_voidp arg) +{ + volatile png_imagep image = image_in; + volatile int result; + volatile png_voidp saved_error_buf; + jmp_buf safe_jmpbuf; + + /* Safely execute function(arg) with png_error returning to this function. */ + saved_error_buf = image->opaque->error_buf; + result = setjmp(safe_jmpbuf) == 0; + + if (result) + { + + image->opaque->error_buf = safe_jmpbuf; + result = function(arg); + } + + image->opaque->error_buf = saved_error_buf; + + /* And do the cleanup prior to any failure return. */ + if (!result) + png_image_free(image); + + return result; +} +#endif /* SIMPLIFIED READ/WRITE */ +#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngget.c b/ml/dlib/dlib/external/libpng/pngget.c new file mode 100644 index 000000000..aca63a958 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngget.c @@ -0,0 +1,1177 @@ + +/* pngget.c - retrieval of values from info struct + * + * Last changed in libpng 1.6.1 [March 28, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + */ + +#include "pngpriv.h" + +#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) + +png_uint_32 PNGAPI +png_get_valid(png_const_structrp png_ptr, png_const_inforp info_ptr, + png_uint_32 flag) +{ + if (png_ptr != NULL && info_ptr != NULL) + return(info_ptr->valid & flag); + + return(0); +} + +png_size_t PNGAPI +png_get_rowbytes(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return(info_ptr->rowbytes); + + return(0); +} + +#ifdef PNG_INFO_IMAGE_SUPPORTED +png_bytepp PNGAPI +png_get_rows(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return(info_ptr->row_pointers); + + return(0); +} +#endif + +#ifdef PNG_EASY_ACCESS_SUPPORTED +/* Easy access to info, added in libpng-0.99 */ +png_uint_32 PNGAPI +png_get_image_width(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return info_ptr->width; + + return (0); +} + +png_uint_32 PNGAPI +png_get_image_height(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return info_ptr->height; + + return (0); +} + +png_byte PNGAPI +png_get_bit_depth(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return info_ptr->bit_depth; + + return (0); +} + +png_byte PNGAPI +png_get_color_type(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return info_ptr->color_type; + + return (0); +} + +png_byte PNGAPI +png_get_filter_type(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return info_ptr->filter_type; + + return (0); +} + +png_byte PNGAPI +png_get_interlace_type(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return info_ptr->interlace_type; + + return (0); +} + +png_byte PNGAPI +png_get_compression_type(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return info_ptr->compression_type; + + return (0); +} + +png_uint_32 PNGAPI +png_get_x_pixels_per_meter(png_const_structrp png_ptr, png_const_inforp + info_ptr) +{ +#ifdef PNG_pHYs_SUPPORTED + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs)) + { + png_debug1(1, "in %s retrieval function", + "png_get_x_pixels_per_meter"); + + if (info_ptr->phys_unit_type == PNG_RESOLUTION_METER) + return (info_ptr->x_pixels_per_unit); + } +#endif + + return (0); +} + +png_uint_32 PNGAPI +png_get_y_pixels_per_meter(png_const_structrp png_ptr, png_const_inforp + info_ptr) +{ +#ifdef PNG_pHYs_SUPPORTED + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs)) + { + png_debug1(1, "in %s retrieval function", + "png_get_y_pixels_per_meter"); + + if (info_ptr->phys_unit_type == PNG_RESOLUTION_METER) + return (info_ptr->y_pixels_per_unit); + } +#endif + + return (0); +} + +png_uint_32 PNGAPI +png_get_pixels_per_meter(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ +#ifdef PNG_pHYs_SUPPORTED + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs)) + { + png_debug1(1, "in %s retrieval function", "png_get_pixels_per_meter"); + + if (info_ptr->phys_unit_type == PNG_RESOLUTION_METER && + info_ptr->x_pixels_per_unit == info_ptr->y_pixels_per_unit) + return (info_ptr->x_pixels_per_unit); + } +#endif + + return (0); +} + +#ifdef PNG_FLOATING_POINT_SUPPORTED +float PNGAPI +png_get_pixel_aspect_ratio(png_const_structrp png_ptr, png_const_inforp + info_ptr) +{ +#ifdef PNG_READ_pHYs_SUPPORTED + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs)) + { + png_debug1(1, "in %s retrieval function", "png_get_aspect_ratio"); + + if (info_ptr->x_pixels_per_unit != 0) + return ((float)((float)info_ptr->y_pixels_per_unit + /(float)info_ptr->x_pixels_per_unit)); + } +#else + PNG_UNUSED(png_ptr) + PNG_UNUSED(info_ptr) +#endif + + return ((float)0.0); +} +#endif + +#ifdef PNG_FIXED_POINT_SUPPORTED +png_fixed_point PNGAPI +png_get_pixel_aspect_ratio_fixed(png_const_structrp png_ptr, + png_const_inforp info_ptr) +{ +#ifdef PNG_READ_pHYs_SUPPORTED + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs) + && info_ptr->x_pixels_per_unit > 0 && info_ptr->y_pixels_per_unit > 0 + && info_ptr->x_pixels_per_unit <= PNG_UINT_31_MAX + && info_ptr->y_pixels_per_unit <= PNG_UINT_31_MAX) + { + png_fixed_point res; + + png_debug1(1, "in %s retrieval function", "png_get_aspect_ratio_fixed"); + + /* The following casts work because a PNG 4 byte integer only has a valid + * range of 0..2^31-1; otherwise the cast might overflow. + */ + if (png_muldiv(&res, (png_int_32)info_ptr->y_pixels_per_unit, PNG_FP_1, + (png_int_32)info_ptr->x_pixels_per_unit)) + return res; + } +#else + PNG_UNUSED(png_ptr) + PNG_UNUSED(info_ptr) +#endif + + return 0; +} +#endif + +png_int_32 PNGAPI +png_get_x_offset_microns(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ +#ifdef PNG_oFFs_SUPPORTED + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs)) + { + png_debug1(1, "in %s retrieval function", "png_get_x_offset_microns"); + + if (info_ptr->offset_unit_type == PNG_OFFSET_MICROMETER) + return (info_ptr->x_offset); + } +#endif + + return (0); +} + +png_int_32 PNGAPI +png_get_y_offset_microns(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ +#ifdef PNG_oFFs_SUPPORTED + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs)) + { + png_debug1(1, "in %s retrieval function", "png_get_y_offset_microns"); + + if (info_ptr->offset_unit_type == PNG_OFFSET_MICROMETER) + return (info_ptr->y_offset); + } +#endif + + return (0); +} + +png_int_32 PNGAPI +png_get_x_offset_pixels(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ +#ifdef PNG_oFFs_SUPPORTED + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs)) + { + png_debug1(1, "in %s retrieval function", "png_get_x_offset_pixels"); + + if (info_ptr->offset_unit_type == PNG_OFFSET_PIXEL) + return (info_ptr->x_offset); + } +#endif + + return (0); +} + +png_int_32 PNGAPI +png_get_y_offset_pixels(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ +#ifdef PNG_oFFs_SUPPORTED + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs)) + { + png_debug1(1, "in %s retrieval function", "png_get_y_offset_pixels"); + + if (info_ptr->offset_unit_type == PNG_OFFSET_PIXEL) + return (info_ptr->y_offset); + } +#endif + + return (0); +} + +#ifdef PNG_INCH_CONVERSIONS_SUPPORTED +static png_uint_32 +ppi_from_ppm(png_uint_32 ppm) +{ +#if 0 + /* The conversion is *(2.54/100), in binary (32 digits): + * .00000110100000001001110101001001 + */ + png_uint_32 t1001, t1101; + ppm >>= 1; /* .1 */ + t1001 = ppm + (ppm >> 3); /* .1001 */ + t1101 = t1001 + (ppm >> 1); /* .1101 */ + ppm >>= 20; /* .000000000000000000001 */ + t1101 += t1101 >> 15; /* .1101000000000001101 */ + t1001 >>= 11; /* .000000000001001 */ + t1001 += t1001 >> 12; /* .000000000001001000000001001 */ + ppm += t1001; /* .000000000001001000001001001 */ + ppm += t1101; /* .110100000001001110101001001 */ + return (ppm + 16) >> 5;/* .00000110100000001001110101001001 */ +#else + /* The argument is a PNG unsigned integer, so it is not permitted + * to be bigger than 2^31. + */ + png_fixed_point result; + if (ppm <= PNG_UINT_31_MAX && png_muldiv(&result, (png_int_32)ppm, 127, + 5000)) + return result; + + /* Overflow. */ + return 0; +#endif +} + +png_uint_32 PNGAPI +png_get_pixels_per_inch(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + return ppi_from_ppm(png_get_pixels_per_meter(png_ptr, info_ptr)); +} + +png_uint_32 PNGAPI +png_get_x_pixels_per_inch(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + return ppi_from_ppm(png_get_x_pixels_per_meter(png_ptr, info_ptr)); +} + +png_uint_32 PNGAPI +png_get_y_pixels_per_inch(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + return ppi_from_ppm(png_get_y_pixels_per_meter(png_ptr, info_ptr)); +} + +#ifdef PNG_FIXED_POINT_SUPPORTED +static png_fixed_point +png_fixed_inches_from_microns(png_const_structrp png_ptr, png_int_32 microns) +{ + /* Convert from metres * 1,000,000 to inches * 100,000, meters to + * inches is simply *(100/2.54), so we want *(10/2.54) == 500/127. + * Notice that this can overflow - a warning is output and 0 is + * returned. + */ + return png_muldiv_warn(png_ptr, microns, 500, 127); +} + +png_fixed_point PNGAPI +png_get_x_offset_inches_fixed(png_const_structrp png_ptr, + png_const_inforp info_ptr) +{ + return png_fixed_inches_from_microns(png_ptr, + png_get_x_offset_microns(png_ptr, info_ptr)); +} +#endif + +#ifdef PNG_FIXED_POINT_SUPPORTED +png_fixed_point PNGAPI +png_get_y_offset_inches_fixed(png_const_structrp png_ptr, + png_const_inforp info_ptr) +{ + return png_fixed_inches_from_microns(png_ptr, + png_get_y_offset_microns(png_ptr, info_ptr)); +} +#endif + +#ifdef PNG_FLOATING_POINT_SUPPORTED +float PNGAPI +png_get_x_offset_inches(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + /* To avoid the overflow do the conversion directly in floating + * point. + */ + return (float)(png_get_x_offset_microns(png_ptr, info_ptr) * .00003937); +} +#endif + +#ifdef PNG_FLOATING_POINT_SUPPORTED +float PNGAPI +png_get_y_offset_inches(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + /* To avoid the overflow do the conversion directly in floating + * point. + */ + return (float)(png_get_y_offset_microns(png_ptr, info_ptr) * .00003937); +} +#endif + +#ifdef PNG_pHYs_SUPPORTED +png_uint_32 PNGAPI +png_get_pHYs_dpi(png_const_structrp png_ptr, png_const_inforp info_ptr, + png_uint_32 *res_x, png_uint_32 *res_y, int *unit_type) +{ + png_uint_32 retval = 0; + + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs)) + { + png_debug1(1, "in %s retrieval function", "pHYs"); + + if (res_x != NULL) + { + *res_x = info_ptr->x_pixels_per_unit; + retval |= PNG_INFO_pHYs; + } + + if (res_y != NULL) + { + *res_y = info_ptr->y_pixels_per_unit; + retval |= PNG_INFO_pHYs; + } + + if (unit_type != NULL) + { + *unit_type = (int)info_ptr->phys_unit_type; + retval |= PNG_INFO_pHYs; + + if (*unit_type == 1) + { + if (res_x != NULL) *res_x = (png_uint_32)(*res_x * .0254 + .50); + if (res_y != NULL) *res_y = (png_uint_32)(*res_y * .0254 + .50); + } + } + } + + return (retval); +} +#endif /* PNG_pHYs_SUPPORTED */ +#endif /* PNG_INCH_CONVERSIONS_SUPPORTED */ + +/* png_get_channels really belongs in here, too, but it's been around longer */ + +#endif /* PNG_EASY_ACCESS_SUPPORTED */ + + +png_byte PNGAPI +png_get_channels(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return(info_ptr->channels); + + return (0); +} + +#ifdef PNG_READ_SUPPORTED +png_const_bytep PNGAPI +png_get_signature(png_const_structrp png_ptr, png_const_inforp info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return(info_ptr->signature); + + return (NULL); +} +#endif + +#ifdef PNG_bKGD_SUPPORTED +png_uint_32 PNGAPI +png_get_bKGD(png_const_structrp png_ptr, png_inforp info_ptr, + png_color_16p *background) +{ + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_bKGD) + && background != NULL) + { + png_debug1(1, "in %s retrieval function", "bKGD"); + + *background = &(info_ptr->background); + return (PNG_INFO_bKGD); + } + + return (0); +} +#endif + +#ifdef PNG_cHRM_SUPPORTED +/* The XYZ APIs were added in 1.5.5 to take advantage of the code added at the + * same time to correct the rgb grayscale coefficient defaults obtained from the + * cHRM chunk in 1.5.4 + */ +# ifdef PNG_FLOATING_POINT_SUPPORTED +png_uint_32 PNGAPI +png_get_cHRM(png_const_structrp png_ptr, png_const_inforp info_ptr, + double *white_x, double *white_y, double *red_x, double *red_y, + double *green_x, double *green_y, double *blue_x, double *blue_y) +{ + /* Quiet API change: this code used to only return the end points if a cHRM + * chunk was present, but the end points can also come from iCCP or sRGB + * chunks, so in 1.6.0 the png_get_ APIs return the end points regardless and + * the png_set_ APIs merely check that set end points are mutually + * consistent. + */ + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS)) + { + png_debug1(1, "in %s retrieval function", "cHRM"); + + if (white_x != NULL) + *white_x = png_float(png_ptr, + info_ptr->colorspace.end_points_xy.whitex, "cHRM white X"); + if (white_y != NULL) + *white_y = png_float(png_ptr, + info_ptr->colorspace.end_points_xy.whitey, "cHRM white Y"); + if (red_x != NULL) + *red_x = png_float(png_ptr, info_ptr->colorspace.end_points_xy.redx, + "cHRM red X"); + if (red_y != NULL) + *red_y = png_float(png_ptr, info_ptr->colorspace.end_points_xy.redy, + "cHRM red Y"); + if (green_x != NULL) + *green_x = png_float(png_ptr, + info_ptr->colorspace.end_points_xy.greenx, "cHRM green X"); + if (green_y != NULL) + *green_y = png_float(png_ptr, + info_ptr->colorspace.end_points_xy.greeny, "cHRM green Y"); + if (blue_x != NULL) + *blue_x = png_float(png_ptr, info_ptr->colorspace.end_points_xy.bluex, + "cHRM blue X"); + if (blue_y != NULL) + *blue_y = png_float(png_ptr, info_ptr->colorspace.end_points_xy.bluey, + "cHRM blue Y"); + return (PNG_INFO_cHRM); + } + + return (0); +} + +png_uint_32 PNGAPI +png_get_cHRM_XYZ(png_const_structrp png_ptr, png_const_inforp info_ptr, + double *red_X, double *red_Y, double *red_Z, double *green_X, + double *green_Y, double *green_Z, double *blue_X, double *blue_Y, + double *blue_Z) +{ + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS)) + { + png_debug1(1, "in %s retrieval function", "cHRM_XYZ(float)"); + + if (red_X != NULL) + *red_X = png_float(png_ptr, info_ptr->colorspace.end_points_XYZ.red_X, + "cHRM red X"); + if (red_Y != NULL) + *red_Y = png_float(png_ptr, info_ptr->colorspace.end_points_XYZ.red_Y, + "cHRM red Y"); + if (red_Z != NULL) + *red_Z = png_float(png_ptr, info_ptr->colorspace.end_points_XYZ.red_Z, + "cHRM red Z"); + if (green_X != NULL) + *green_X = png_float(png_ptr, + info_ptr->colorspace.end_points_XYZ.green_X, "cHRM green X"); + if (green_Y != NULL) + *green_Y = png_float(png_ptr, + info_ptr->colorspace.end_points_XYZ.green_Y, "cHRM green Y"); + if (green_Z != NULL) + *green_Z = png_float(png_ptr, + info_ptr->colorspace.end_points_XYZ.green_Z, "cHRM green Z"); + if (blue_X != NULL) + *blue_X = png_float(png_ptr, + info_ptr->colorspace.end_points_XYZ.blue_X, "cHRM blue X"); + if (blue_Y != NULL) + *blue_Y = png_float(png_ptr, + info_ptr->colorspace.end_points_XYZ.blue_Y, "cHRM blue Y"); + if (blue_Z != NULL) + *blue_Z = png_float(png_ptr, + info_ptr->colorspace.end_points_XYZ.blue_Z, "cHRM blue Z"); + return (PNG_INFO_cHRM); + } + + return (0); +} +# endif + +# ifdef PNG_FIXED_POINT_SUPPORTED +png_uint_32 PNGAPI +png_get_cHRM_XYZ_fixed(png_const_structrp png_ptr, png_const_inforp info_ptr, + png_fixed_point *int_red_X, png_fixed_point *int_red_Y, + png_fixed_point *int_red_Z, png_fixed_point *int_green_X, + png_fixed_point *int_green_Y, png_fixed_point *int_green_Z, + png_fixed_point *int_blue_X, png_fixed_point *int_blue_Y, + png_fixed_point *int_blue_Z) +{ + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS)) + { + png_debug1(1, "in %s retrieval function", "cHRM_XYZ"); + + if (int_red_X != NULL) + *int_red_X = info_ptr->colorspace.end_points_XYZ.red_X; + if (int_red_Y != NULL) + *int_red_Y = info_ptr->colorspace.end_points_XYZ.red_Y; + if (int_red_Z != NULL) + *int_red_Z = info_ptr->colorspace.end_points_XYZ.red_Z; + if (int_green_X != NULL) + *int_green_X = info_ptr->colorspace.end_points_XYZ.green_X; + if (int_green_Y != NULL) + *int_green_Y = info_ptr->colorspace.end_points_XYZ.green_Y; + if (int_green_Z != NULL) + *int_green_Z = info_ptr->colorspace.end_points_XYZ.green_Z; + if (int_blue_X != NULL) + *int_blue_X = info_ptr->colorspace.end_points_XYZ.blue_X; + if (int_blue_Y != NULL) + *int_blue_Y = info_ptr->colorspace.end_points_XYZ.blue_Y; + if (int_blue_Z != NULL) + *int_blue_Z = info_ptr->colorspace.end_points_XYZ.blue_Z; + return (PNG_INFO_cHRM); + } + + return (0); +} + +png_uint_32 PNGAPI +png_get_cHRM_fixed(png_const_structrp png_ptr, png_const_inforp info_ptr, + png_fixed_point *white_x, png_fixed_point *white_y, png_fixed_point *red_x, + png_fixed_point *red_y, png_fixed_point *green_x, png_fixed_point *green_y, + png_fixed_point *blue_x, png_fixed_point *blue_y) +{ + png_debug1(1, "in %s retrieval function", "cHRM"); + + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_ENDPOINTS)) + { + if (white_x != NULL) + *white_x = info_ptr->colorspace.end_points_xy.whitex; + if (white_y != NULL) + *white_y = info_ptr->colorspace.end_points_xy.whitey; + if (red_x != NULL) + *red_x = info_ptr->colorspace.end_points_xy.redx; + if (red_y != NULL) + *red_y = info_ptr->colorspace.end_points_xy.redy; + if (green_x != NULL) + *green_x = info_ptr->colorspace.end_points_xy.greenx; + if (green_y != NULL) + *green_y = info_ptr->colorspace.end_points_xy.greeny; + if (blue_x != NULL) + *blue_x = info_ptr->colorspace.end_points_xy.bluex; + if (blue_y != NULL) + *blue_y = info_ptr->colorspace.end_points_xy.bluey; + return (PNG_INFO_cHRM); + } + + return (0); +} +# endif +#endif + +#ifdef PNG_gAMA_SUPPORTED +# ifdef PNG_FIXED_POINT_SUPPORTED +png_uint_32 PNGAPI +png_get_gAMA_fixed(png_const_structrp png_ptr, png_const_inforp info_ptr, + png_fixed_point *file_gamma) +{ + png_debug1(1, "in %s retrieval function", "gAMA"); + + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_GAMMA) && + file_gamma != NULL) + { + *file_gamma = info_ptr->colorspace.gamma; + return (PNG_INFO_gAMA); + } + + return (0); +} +# endif + +# ifdef PNG_FLOATING_POINT_SUPPORTED +png_uint_32 PNGAPI +png_get_gAMA(png_const_structrp png_ptr, png_const_inforp info_ptr, + double *file_gamma) +{ + png_debug1(1, "in %s retrieval function", "gAMA(float)"); + + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_GAMMA) && + file_gamma != NULL) + { + *file_gamma = png_float(png_ptr, info_ptr->colorspace.gamma, + "png_get_gAMA"); + return (PNG_INFO_gAMA); + } + + return (0); +} +# endif +#endif + +#ifdef PNG_sRGB_SUPPORTED +png_uint_32 PNGAPI +png_get_sRGB(png_const_structrp png_ptr, png_const_inforp info_ptr, + int *file_srgb_intent) +{ + png_debug1(1, "in %s retrieval function", "sRGB"); + + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_sRGB) + && file_srgb_intent != NULL) + { + *file_srgb_intent = info_ptr->colorspace.rendering_intent; + return (PNG_INFO_sRGB); + } + + return (0); +} +#endif + +#ifdef PNG_iCCP_SUPPORTED +png_uint_32 PNGAPI +png_get_iCCP(png_const_structrp png_ptr, png_inforp info_ptr, + png_charpp name, int *compression_type, + png_bytepp profile, png_uint_32 *proflen) +{ + png_debug1(1, "in %s retrieval function", "iCCP"); + + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_iCCP) + && name != NULL && compression_type != NULL && profile != NULL && + proflen != NULL) + { + *name = info_ptr->iccp_name; + *profile = info_ptr->iccp_profile; + *proflen = png_get_uint_32(info_ptr->iccp_profile); + /* This is somewhat irrelevant since the profile data returned has + * actually been uncompressed. + */ + *compression_type = PNG_COMPRESSION_TYPE_BASE; + return (PNG_INFO_iCCP); + } + + return (0); +} +#endif + +#ifdef PNG_sPLT_SUPPORTED +int PNGAPI +png_get_sPLT(png_const_structrp png_ptr, png_inforp info_ptr, + png_sPLT_tpp spalettes) +{ + if (png_ptr != NULL && info_ptr != NULL && spalettes != NULL) + { + *spalettes = info_ptr->splt_palettes; + return info_ptr->splt_palettes_num; + } + + return (0); +} +#endif + +#ifdef PNG_hIST_SUPPORTED +png_uint_32 PNGAPI +png_get_hIST(png_const_structrp png_ptr, png_inforp info_ptr, + png_uint_16p *hist) +{ + png_debug1(1, "in %s retrieval function", "hIST"); + + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_hIST) + && hist != NULL) + { + *hist = info_ptr->hist; + return (PNG_INFO_hIST); + } + + return (0); +} +#endif + +png_uint_32 PNGAPI +png_get_IHDR(png_const_structrp png_ptr, png_const_inforp info_ptr, + png_uint_32 *width, png_uint_32 *height, int *bit_depth, + int *color_type, int *interlace_type, int *compression_type, + int *filter_type) +{ + png_debug1(1, "in %s retrieval function", "IHDR"); + + if (png_ptr == NULL || info_ptr == NULL || width == NULL || + height == NULL || bit_depth == NULL || color_type == NULL) + return (0); + + *width = info_ptr->width; + *height = info_ptr->height; + *bit_depth = info_ptr->bit_depth; + *color_type = info_ptr->color_type; + + if (compression_type != NULL) + *compression_type = info_ptr->compression_type; + + if (filter_type != NULL) + *filter_type = info_ptr->filter_type; + + if (interlace_type != NULL) + *interlace_type = info_ptr->interlace_type; + + /* This is redundant if we can be sure that the info_ptr values were all + * assigned in png_set_IHDR(). We do the check anyhow in case an + * application has ignored our advice not to mess with the members + * of info_ptr directly. + */ + png_check_IHDR(png_ptr, info_ptr->width, info_ptr->height, + info_ptr->bit_depth, info_ptr->color_type, info_ptr->interlace_type, + info_ptr->compression_type, info_ptr->filter_type); + + return (1); +} + +#ifdef PNG_oFFs_SUPPORTED +png_uint_32 PNGAPI +png_get_oFFs(png_const_structrp png_ptr, png_const_inforp info_ptr, + png_int_32 *offset_x, png_int_32 *offset_y, int *unit_type) +{ + png_debug1(1, "in %s retrieval function", "oFFs"); + + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs) + && offset_x != NULL && offset_y != NULL && unit_type != NULL) + { + *offset_x = info_ptr->x_offset; + *offset_y = info_ptr->y_offset; + *unit_type = (int)info_ptr->offset_unit_type; + return (PNG_INFO_oFFs); + } + + return (0); +} +#endif + +#ifdef PNG_pCAL_SUPPORTED +png_uint_32 PNGAPI +png_get_pCAL(png_const_structrp png_ptr, png_inforp info_ptr, + png_charp *purpose, png_int_32 *X0, png_int_32 *X1, int *type, int *nparams, + png_charp *units, png_charpp *params) +{ + png_debug1(1, "in %s retrieval function", "pCAL"); + + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_pCAL) + && purpose != NULL && X0 != NULL && X1 != NULL && type != NULL && + nparams != NULL && units != NULL && params != NULL) + { + *purpose = info_ptr->pcal_purpose; + *X0 = info_ptr->pcal_X0; + *X1 = info_ptr->pcal_X1; + *type = (int)info_ptr->pcal_type; + *nparams = (int)info_ptr->pcal_nparams; + *units = info_ptr->pcal_units; + *params = info_ptr->pcal_params; + return (PNG_INFO_pCAL); + } + + return (0); +} +#endif + +#ifdef PNG_sCAL_SUPPORTED +# ifdef PNG_FIXED_POINT_SUPPORTED +# if defined(PNG_FLOATING_ARITHMETIC_SUPPORTED) || \ + defined(PNG_FLOATING_POINT_SUPPORTED) +png_uint_32 PNGAPI +png_get_sCAL_fixed(png_const_structrp png_ptr, png_const_inforp info_ptr, + int *unit, png_fixed_point *width, png_fixed_point *height) +{ + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->valid & PNG_INFO_sCAL)) + { + *unit = info_ptr->scal_unit; + /*TODO: make this work without FP support; the API is currently eliminated + * if neither floating point APIs nor internal floating point arithmetic + * are enabled. + */ + *width = png_fixed(png_ptr, atof(info_ptr->scal_s_width), "sCAL width"); + *height = png_fixed(png_ptr, atof(info_ptr->scal_s_height), + "sCAL height"); + return (PNG_INFO_sCAL); + } + + return(0); +} +# endif /* FLOATING_ARITHMETIC */ +# endif /* FIXED_POINT */ +# ifdef PNG_FLOATING_POINT_SUPPORTED +png_uint_32 PNGAPI +png_get_sCAL(png_const_structrp png_ptr, png_const_inforp info_ptr, + int *unit, double *width, double *height) +{ + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->valid & PNG_INFO_sCAL)) + { + *unit = info_ptr->scal_unit; + *width = atof(info_ptr->scal_s_width); + *height = atof(info_ptr->scal_s_height); + return (PNG_INFO_sCAL); + } + + return(0); +} +# endif /* FLOATING POINT */ +png_uint_32 PNGAPI +png_get_sCAL_s(png_const_structrp png_ptr, png_const_inforp info_ptr, + int *unit, png_charpp width, png_charpp height) +{ + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->valid & PNG_INFO_sCAL)) + { + *unit = info_ptr->scal_unit; + *width = info_ptr->scal_s_width; + *height = info_ptr->scal_s_height; + return (PNG_INFO_sCAL); + } + + return(0); +} +#endif /* sCAL */ + +#ifdef PNG_pHYs_SUPPORTED +png_uint_32 PNGAPI +png_get_pHYs(png_const_structrp png_ptr, png_const_inforp info_ptr, + png_uint_32 *res_x, png_uint_32 *res_y, int *unit_type) +{ + png_uint_32 retval = 0; + + png_debug1(1, "in %s retrieval function", "pHYs"); + + if (png_ptr != NULL && info_ptr != NULL && + (info_ptr->valid & PNG_INFO_pHYs)) + { + if (res_x != NULL) + { + *res_x = info_ptr->x_pixels_per_unit; + retval |= PNG_INFO_pHYs; + } + + if (res_y != NULL) + { + *res_y = info_ptr->y_pixels_per_unit; + retval |= PNG_INFO_pHYs; + } + + if (unit_type != NULL) + { + *unit_type = (int)info_ptr->phys_unit_type; + retval |= PNG_INFO_pHYs; + } + } + + return (retval); +} +#endif /* pHYs */ + +png_uint_32 PNGAPI +png_get_PLTE(png_const_structrp png_ptr, png_inforp info_ptr, + png_colorp *palette, int *num_palette) +{ + png_debug1(1, "in %s retrieval function", "PLTE"); + + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_PLTE) + && palette != NULL) + { + *palette = info_ptr->palette; + *num_palette = info_ptr->num_palette; + png_debug1(3, "num_palette = %d", *num_palette); + return (PNG_INFO_PLTE); + } + + return (0); +} + +#ifdef PNG_sBIT_SUPPORTED +png_uint_32 PNGAPI +png_get_sBIT(png_const_structrp png_ptr, png_inforp info_ptr, + png_color_8p *sig_bit) +{ + png_debug1(1, "in %s retrieval function", "sBIT"); + + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_sBIT) + && sig_bit != NULL) + { + *sig_bit = &(info_ptr->sig_bit); + return (PNG_INFO_sBIT); + } + + return (0); +} +#endif + +#ifdef PNG_TEXT_SUPPORTED +int PNGAPI +png_get_text(png_const_structrp png_ptr, png_inforp info_ptr, + png_textp *text_ptr, int *num_text) +{ + if (png_ptr != NULL && info_ptr != NULL && info_ptr->num_text > 0) + { + png_debug1(1, "in 0x%lx retrieval function", + (unsigned long)png_ptr->chunk_name); + + if (text_ptr != NULL) + *text_ptr = info_ptr->text; + + if (num_text != NULL) + *num_text = info_ptr->num_text; + + return info_ptr->num_text; + } + + if (num_text != NULL) + *num_text = 0; + + return(0); +} +#endif + +#ifdef PNG_tIME_SUPPORTED +png_uint_32 PNGAPI +png_get_tIME(png_const_structrp png_ptr, png_inforp info_ptr, + png_timep *mod_time) +{ + png_debug1(1, "in %s retrieval function", "tIME"); + + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_tIME) + && mod_time != NULL) + { + *mod_time = &(info_ptr->mod_time); + return (PNG_INFO_tIME); + } + + return (0); +} +#endif + +#ifdef PNG_tRNS_SUPPORTED +png_uint_32 PNGAPI +png_get_tRNS(png_const_structrp png_ptr, png_inforp info_ptr, + png_bytep *trans_alpha, int *num_trans, png_color_16p *trans_color) +{ + png_uint_32 retval = 0; + if (png_ptr != NULL && info_ptr != NULL && (info_ptr->valid & PNG_INFO_tRNS)) + { + png_debug1(1, "in %s retrieval function", "tRNS"); + + if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + { + if (trans_alpha != NULL) + { + *trans_alpha = info_ptr->trans_alpha; + retval |= PNG_INFO_tRNS; + } + + if (trans_color != NULL) + *trans_color = &(info_ptr->trans_color); + } + + else /* if (info_ptr->color_type != PNG_COLOR_TYPE_PALETTE) */ + { + if (trans_color != NULL) + { + *trans_color = &(info_ptr->trans_color); + retval |= PNG_INFO_tRNS; + } + + if (trans_alpha != NULL) + *trans_alpha = NULL; + } + + if (num_trans != NULL) + { + *num_trans = info_ptr->num_trans; + retval |= PNG_INFO_tRNS; + } + } + + return (retval); +} +#endif + +#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED +int PNGAPI +png_get_unknown_chunks(png_const_structrp png_ptr, png_inforp info_ptr, + png_unknown_chunkpp unknowns) +{ + if (png_ptr != NULL && info_ptr != NULL && unknowns != NULL) + { + *unknowns = info_ptr->unknown_chunks; + return info_ptr->unknown_chunks_num; + } + + return (0); +} +#endif + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED +png_byte PNGAPI +png_get_rgb_to_gray_status (png_const_structrp png_ptr) +{ + return (png_byte)(png_ptr ? png_ptr->rgb_to_gray_status : 0); +} +#endif + +#ifdef PNG_USER_CHUNKS_SUPPORTED +png_voidp PNGAPI +png_get_user_chunk_ptr(png_const_structrp png_ptr) +{ + return (png_ptr ? png_ptr->user_chunk_ptr : NULL); +} +#endif + +png_size_t PNGAPI +png_get_compression_buffer_size(png_const_structrp png_ptr) +{ + if (png_ptr == NULL) + return 0; + +# ifdef PNG_WRITE_SUPPORTED + if (png_ptr->mode & PNG_IS_READ_STRUCT) +# endif + { +# ifdef PNG_SEQUENTIAL_READ_SUPPORTED + return png_ptr->IDAT_read_size; +# else + return PNG_IDAT_READ_SIZE; +# endif + } + +# ifdef PNG_WRITE_SUPPORTED + else + return png_ptr->zbuffer_size; +# endif +} + +#ifdef PNG_SET_USER_LIMITS_SUPPORTED +/* These functions were added to libpng 1.2.6 and were enabled + * by default in libpng-1.4.0 */ +png_uint_32 PNGAPI +png_get_user_width_max (png_const_structrp png_ptr) +{ + return (png_ptr ? png_ptr->user_width_max : 0); +} + +png_uint_32 PNGAPI +png_get_user_height_max (png_const_structrp png_ptr) +{ + return (png_ptr ? png_ptr->user_height_max : 0); +} + +/* This function was added to libpng 1.4.0 */ +png_uint_32 PNGAPI +png_get_chunk_cache_max (png_const_structrp png_ptr) +{ + return (png_ptr ? png_ptr->user_chunk_cache_max : 0); +} + +/* This function was added to libpng 1.4.1 */ +png_alloc_size_t PNGAPI +png_get_chunk_malloc_max (png_const_structrp png_ptr) +{ + return (png_ptr ? png_ptr->user_chunk_malloc_max : 0); +} +#endif /* ?PNG_SET_USER_LIMITS_SUPPORTED */ + +/* These functions were added to libpng 1.4.0 */ +#ifdef PNG_IO_STATE_SUPPORTED +png_uint_32 PNGAPI +png_get_io_state (png_const_structrp png_ptr) +{ + return png_ptr->io_state; +} + +png_uint_32 PNGAPI +png_get_io_chunk_type (png_const_structrp png_ptr) +{ + return png_ptr->chunk_name; +} +#endif /* ?PNG_IO_STATE_SUPPORTED */ + +#ifdef PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED +# ifdef PNG_GET_PALETTE_MAX_SUPPORTED +int PNGAPI +png_get_palette_max(png_const_structp png_ptr, png_const_infop info_ptr) +{ + if (png_ptr != NULL && info_ptr != NULL) + return png_ptr->num_palette_max; + + return (-1); +} +# endif +#endif + +#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pnginfo.h b/ml/dlib/dlib/external/libpng/pnginfo.h new file mode 100644 index 000000000..26bf26502 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pnginfo.h @@ -0,0 +1,260 @@ + +/* pnginfo.h - header file for PNG reference library + * + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * Last changed in libpng 1.6.1 [March 28, 2013] + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + + /* png_info is a structure that holds the information in a PNG file so + * that the application can find out the characteristics of the image. + * If you are reading the file, this structure will tell you what is + * in the PNG file. If you are writing the file, fill in the information + * you want to put into the PNG file, using png_set_*() functions, then + * call png_write_info(). + * + * The names chosen should be very close to the PNG specification, so + * consult that document for information about the meaning of each field. + * + * With libpng < 0.95, it was only possible to directly set and read the + * the values in the png_info_struct, which meant that the contents and + * order of the values had to remain fixed. With libpng 0.95 and later, + * however, there are now functions that abstract the contents of + * png_info_struct from the application, so this makes it easier to use + * libpng with dynamic libraries, and even makes it possible to use + * libraries that don't have all of the libpng ancillary chunk-handing + * functionality. In libpng-1.5.0 this was moved into a separate private + * file that is not visible to applications. + * + * The following members may have allocated storage attached that should be + * cleaned up before the structure is discarded: palette, trans, text, + * pcal_purpose, pcal_units, pcal_params, hist, iccp_name, iccp_profile, + * splt_palettes, scal_unit, row_pointers, and unknowns. By default, these + * are automatically freed when the info structure is deallocated, if they were + * allocated internally by libpng. This behavior can be changed by means + * of the png_data_freer() function. + * + * More allocation details: all the chunk-reading functions that + * change these members go through the corresponding png_set_* + * functions. A function to clear these members is available: see + * png_free_data(). The png_set_* functions do not depend on being + * able to point info structure members to any of the storage they are + * passed (they make their own copies), EXCEPT that the png_set_text + * functions use the same storage passed to them in the text_ptr or + * itxt_ptr structure argument, and the png_set_rows and png_set_unknowns + * functions do not make their own copies. + */ +#ifndef PNGINFO_H +#define PNGINFO_H + +struct png_info_def +{ + /* The following are necessary for every PNG file */ + png_uint_32 width; /* width of image in pixels (from IHDR) */ + png_uint_32 height; /* height of image in pixels (from IHDR) */ + png_uint_32 valid; /* valid chunk data (see PNG_INFO_ below) */ + png_size_t rowbytes; /* bytes needed to hold an untransformed row */ + png_colorp palette; /* array of color values (valid & PNG_INFO_PLTE) */ + png_uint_16 num_palette; /* number of color entries in "palette" (PLTE) */ + png_uint_16 num_trans; /* number of transparent palette color (tRNS) */ + png_byte bit_depth; /* 1, 2, 4, 8, or 16 bits/channel (from IHDR) */ + png_byte color_type; /* see PNG_COLOR_TYPE_ below (from IHDR) */ + /* The following three should have been named *_method not *_type */ + png_byte compression_type; /* must be PNG_COMPRESSION_TYPE_BASE (IHDR) */ + png_byte filter_type; /* must be PNG_FILTER_TYPE_BASE (from IHDR) */ + png_byte interlace_type; /* One of PNG_INTERLACE_NONE, PNG_INTERLACE_ADAM7 */ + + /* The following are set by png_set_IHDR, called from the application on + * write, but the are never actually used by the write code. + */ + png_byte channels; /* number of data channels per pixel (1, 2, 3, 4) */ + png_byte pixel_depth; /* number of bits per pixel */ + png_byte spare_byte; /* to align the data, and for future use */ + +#ifdef PNG_READ_SUPPORTED + /* This is never set during write */ + png_byte signature[8]; /* magic bytes read by libpng from start of file */ +#endif + + /* The rest of the data is optional. If you are reading, check the + * valid field to see if the information in these are valid. If you + * are writing, set the valid field to those chunks you want written, + * and initialize the appropriate fields below. + */ + +#if defined(PNG_COLORSPACE_SUPPORTED) || defined(PNG_GAMMA_SUPPORTED) + /* png_colorspace only contains 'flags' if neither GAMMA or COLORSPACE are + * defined. When COLORSPACE is switched on all the colorspace-defining + * chunks should be enabled, when GAMMA is switched on all the gamma-defining + * chunks should be enabled. If this is not done it becomes possible to read + * inconsistent PNG files and assign a probably incorrect interpretation to + * the information. (In other words, by carefully choosing which chunks to + * recognize the system configuration can select an interpretation for PNG + * files containing ambiguous data and this will result in inconsistent + * behavior between different libpng builds!) + */ + png_colorspace colorspace; +#endif + +#ifdef PNG_iCCP_SUPPORTED + /* iCCP chunk data. */ + png_charp iccp_name; /* profile name */ + png_bytep iccp_profile; /* International Color Consortium profile data */ + png_uint_32 iccp_proflen; /* ICC profile data length */ +#endif + +#ifdef PNG_TEXT_SUPPORTED + /* The tEXt, and zTXt chunks contain human-readable textual data in + * uncompressed, compressed, and optionally compressed forms, respectively. + * The data in "text" is an array of pointers to uncompressed, + * null-terminated C strings. Each chunk has a keyword that describes the + * textual data contained in that chunk. Keywords are not required to be + * unique, and the text string may be empty. Any number of text chunks may + * be in an image. + */ + int num_text; /* number of comments read or comments to write */ + int max_text; /* current size of text array */ + png_textp text; /* array of comments read or comments to write */ +#endif /* PNG_TEXT_SUPPORTED */ + +#ifdef PNG_tIME_SUPPORTED + /* The tIME chunk holds the last time the displayed image data was + * modified. See the png_time struct for the contents of this struct. + */ + png_time mod_time; +#endif + +#ifdef PNG_sBIT_SUPPORTED + /* The sBIT chunk specifies the number of significant high-order bits + * in the pixel data. Values are in the range [1, bit_depth], and are + * only specified for the channels in the pixel data. The contents of + * the low-order bits is not specified. Data is valid if + * (valid & PNG_INFO_sBIT) is non-zero. + */ + png_color_8 sig_bit; /* significant bits in color channels */ +#endif + +#if defined(PNG_tRNS_SUPPORTED) || defined(PNG_READ_EXPAND_SUPPORTED) || \ +defined(PNG_READ_BACKGROUND_SUPPORTED) + /* The tRNS chunk supplies transparency data for paletted images and + * other image types that don't need a full alpha channel. There are + * "num_trans" transparency values for a paletted image, stored in the + * same order as the palette colors, starting from index 0. Values + * for the data are in the range [0, 255], ranging from fully transparent + * to fully opaque, respectively. For non-paletted images, there is a + * single color specified that should be treated as fully transparent. + * Data is valid if (valid & PNG_INFO_tRNS) is non-zero. + */ + png_bytep trans_alpha; /* alpha values for paletted image */ + png_color_16 trans_color; /* transparent color for non-palette image */ +#endif + +#if defined(PNG_bKGD_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) + /* The bKGD chunk gives the suggested image background color if the + * display program does not have its own background color and the image + * is needs to composited onto a background before display. The colors + * in "background" are normally in the same color space/depth as the + * pixel data. Data is valid if (valid & PNG_INFO_bKGD) is non-zero. + */ + png_color_16 background; +#endif + +#ifdef PNG_oFFs_SUPPORTED + /* The oFFs chunk gives the offset in "offset_unit_type" units rightwards + * and downwards from the top-left corner of the display, page, or other + * application-specific co-ordinate space. See the PNG_OFFSET_ defines + * below for the unit types. Valid if (valid & PNG_INFO_oFFs) non-zero. + */ + png_int_32 x_offset; /* x offset on page */ + png_int_32 y_offset; /* y offset on page */ + png_byte offset_unit_type; /* offset units type */ +#endif + +#ifdef PNG_pHYs_SUPPORTED + /* The pHYs chunk gives the physical pixel density of the image for + * display or printing in "phys_unit_type" units (see PNG_RESOLUTION_ + * defines below). Data is valid if (valid & PNG_INFO_pHYs) is non-zero. + */ + png_uint_32 x_pixels_per_unit; /* horizontal pixel density */ + png_uint_32 y_pixels_per_unit; /* vertical pixel density */ + png_byte phys_unit_type; /* resolution type (see PNG_RESOLUTION_ below) */ +#endif + +#ifdef PNG_hIST_SUPPORTED + /* The hIST chunk contains the relative frequency or importance of the + * various palette entries, so that a viewer can intelligently select a + * reduced-color palette, if required. Data is an array of "num_palette" + * values in the range [0,65535]. Data valid if (valid & PNG_INFO_hIST) + * is non-zero. + */ + png_uint_16p hist; +#endif + +#ifdef PNG_pCAL_SUPPORTED + /* The pCAL chunk describes a transformation between the stored pixel + * values and original physical data values used to create the image. + * The integer range [0, 2^bit_depth - 1] maps to the floating-point + * range given by [pcal_X0, pcal_X1], and are further transformed by a + * (possibly non-linear) transformation function given by "pcal_type" + * and "pcal_params" into "pcal_units". Please see the PNG_EQUATION_ + * defines below, and the PNG-Group's PNG extensions document for a + * complete description of the transformations and how they should be + * implemented, and for a description of the ASCII parameter strings. + * Data values are valid if (valid & PNG_INFO_pCAL) non-zero. + */ + png_charp pcal_purpose; /* pCAL chunk description string */ + png_int_32 pcal_X0; /* minimum value */ + png_int_32 pcal_X1; /* maximum value */ + png_charp pcal_units; /* Latin-1 string giving physical units */ + png_charpp pcal_params; /* ASCII strings containing parameter values */ + png_byte pcal_type; /* equation type (see PNG_EQUATION_ below) */ + png_byte pcal_nparams; /* number of parameters given in pcal_params */ +#endif + +/* New members added in libpng-1.0.6 */ + png_uint_32 free_me; /* flags items libpng is responsible for freeing */ + +#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED + /* Storage for unknown chunks that the library doesn't recognize. */ + png_unknown_chunkp unknown_chunks; + + /* The type of this field is limited by the type of + * png_struct::user_chunk_cache_max, else overflow can occur. + */ + int unknown_chunks_num; +#endif + +#ifdef PNG_sPLT_SUPPORTED + /* Data on sPLT chunks (there may be more than one). */ + png_sPLT_tp splt_palettes; + int splt_palettes_num; /* Match type returned by png_get API */ +#endif + +#ifdef PNG_sCAL_SUPPORTED + /* The sCAL chunk describes the actual physical dimensions of the + * subject matter of the graphic. The chunk contains a unit specification + * a byte value, and two ASCII strings representing floating-point + * values. The values are width and height corresponsing to one pixel + * in the image. Data values are valid if (valid & PNG_INFO_sCAL) is + * non-zero. + */ + png_byte scal_unit; /* unit of physical scale */ + png_charp scal_s_width; /* string containing height */ + png_charp scal_s_height; /* string containing width */ +#endif + +#ifdef PNG_INFO_IMAGE_SUPPORTED + /* Memory has been allocated if (valid & PNG_ALLOCATED_INFO_ROWS) + non-zero */ + /* Data valid if (valid & PNG_INFO_IDAT) non-zero */ + png_bytepp row_pointers; /* the image bits */ +#endif + +}; +#endif /* PNGINFO_H */ diff --git a/ml/dlib/dlib/external/libpng/pnglibconf.h b/ml/dlib/dlib/external/libpng/pnglibconf.h new file mode 100644 index 000000000..6e9a41152 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pnglibconf.h @@ -0,0 +1,211 @@ +/* libpng 1.6.7 STANDARD API DEFINITION */ + +/* pnglibconf.h - library build configuration */ + +/* Libpng version 1.6.7 - November 14, 2013 */ + +/* Copyright (c) 1998-2013 Glenn Randers-Pehrson */ + +/* This code is released under the libpng license. */ +/* For conditions of distribution and use, see the disclaimer */ +/* and license in png.h */ + +/* pnglibconf.h */ +/* Machine generated file: DO NOT EDIT */ +/* Derived from: scripts/pnglibconf.dfa */ +#ifndef PNGLCONF_H +#define PNGLCONF_H +/* options */ +#define PNG_16BIT_SUPPORTED +#define PNG_ALIGNED_MEMORY_SUPPORTED +/*#undef PNG_ARM_NEON_API_SUPPORTED*/ +/*#undef PNG_ARM_NEON_CHECK_SUPPORTED*/ +#define PNG_BENIGN_ERRORS_SUPPORTED +#define PNG_BENIGN_READ_ERRORS_SUPPORTED +/*#undef PNG_BENIGN_WRITE_ERRORS_SUPPORTED*/ +#define PNG_BUILD_GRAYSCALE_PALETTE_SUPPORTED +#define PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED +#define PNG_COLORSPACE_SUPPORTED +#define PNG_CONSOLE_IO_SUPPORTED +#define PNG_CONVERT_tIME_SUPPORTED +#define PNG_EASY_ACCESS_SUPPORTED +/*#undef PNG_ERROR_NUMBERS_SUPPORTED*/ +#define PNG_ERROR_TEXT_SUPPORTED +#define PNG_FIXED_POINT_SUPPORTED +#define PNG_FLOATING_ARITHMETIC_SUPPORTED +#define PNG_FLOATING_POINT_SUPPORTED +#define PNG_FORMAT_AFIRST_SUPPORTED +#define PNG_FORMAT_BGR_SUPPORTED +#define PNG_GAMMA_SUPPORTED +#define PNG_GET_PALETTE_MAX_SUPPORTED +#define PNG_HANDLE_AS_UNKNOWN_SUPPORTED +#define PNG_INCH_CONVERSIONS_SUPPORTED +#define PNG_INFO_IMAGE_SUPPORTED +#define PNG_IO_STATE_SUPPORTED +#define PNG_MNG_FEATURES_SUPPORTED +#define PNG_POINTER_INDEXING_SUPPORTED +#define PNG_PROGRESSIVE_READ_SUPPORTED +#define PNG_READ_16BIT_SUPPORTED +#define PNG_READ_ALPHA_MODE_SUPPORTED +#define PNG_READ_ANCILLARY_CHUNKS_SUPPORTED +#define PNG_READ_BACKGROUND_SUPPORTED +#define PNG_READ_BGR_SUPPORTED +#define PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED +#define PNG_READ_COMPOSITE_NODIV_SUPPORTED +#define PNG_READ_COMPRESSED_TEXT_SUPPORTED +#define PNG_READ_EXPAND_16_SUPPORTED +#define PNG_READ_EXPAND_SUPPORTED +#define PNG_READ_FILLER_SUPPORTED +#define PNG_READ_GAMMA_SUPPORTED +#define PNG_READ_GET_PALETTE_MAX_SUPPORTED +#define PNG_READ_GRAY_TO_RGB_SUPPORTED +#define PNG_READ_INTERLACING_SUPPORTED +#define PNG_READ_INT_FUNCTIONS_SUPPORTED +#define PNG_READ_INVERT_ALPHA_SUPPORTED +#define PNG_READ_INVERT_SUPPORTED +#define PNG_READ_OPT_PLTE_SUPPORTED +#define PNG_READ_PACKSWAP_SUPPORTED +#define PNG_READ_PACK_SUPPORTED +#define PNG_READ_QUANTIZE_SUPPORTED +#define PNG_READ_RGB_TO_GRAY_SUPPORTED +#define PNG_READ_SCALE_16_TO_8_SUPPORTED +#define PNG_READ_SHIFT_SUPPORTED +#define PNG_READ_STRIP_16_TO_8_SUPPORTED +#define PNG_READ_STRIP_ALPHA_SUPPORTED +#define PNG_READ_SUPPORTED +#define PNG_READ_SWAP_ALPHA_SUPPORTED +#define PNG_READ_SWAP_SUPPORTED +#define PNG_READ_TEXT_SUPPORTED +#define PNG_READ_TRANSFORMS_SUPPORTED +#define PNG_READ_UNKNOWN_CHUNKS_SUPPORTED +#define PNG_READ_USER_CHUNKS_SUPPORTED +#define PNG_READ_USER_TRANSFORM_SUPPORTED +#define PNG_READ_bKGD_SUPPORTED +#define PNG_READ_cHRM_SUPPORTED +#define PNG_READ_gAMA_SUPPORTED +#define PNG_READ_hIST_SUPPORTED +#define PNG_READ_iCCP_SUPPORTED +#define PNG_READ_iTXt_SUPPORTED +#define PNG_READ_oFFs_SUPPORTED +#define PNG_READ_pCAL_SUPPORTED +#define PNG_READ_pHYs_SUPPORTED +#define PNG_READ_sBIT_SUPPORTED +#define PNG_READ_sCAL_SUPPORTED +#define PNG_READ_sPLT_SUPPORTED +#define PNG_READ_sRGB_SUPPORTED +#define PNG_READ_tEXt_SUPPORTED +#define PNG_READ_tIME_SUPPORTED +#define PNG_READ_tRNS_SUPPORTED +#define PNG_READ_zTXt_SUPPORTED +/*#undef PNG_SAFE_LIMITS_SUPPORTED*/ +#define PNG_SAVE_INT_32_SUPPORTED +#define PNG_SAVE_UNKNOWN_CHUNKS_SUPPORTED +#define PNG_SEQUENTIAL_READ_SUPPORTED +#define PNG_SETJMP_SUPPORTED +#define PNG_SET_CHUNK_CACHE_LIMIT_SUPPORTED +#define PNG_SET_CHUNK_MALLOC_LIMIT_SUPPORTED +#define PNG_SET_OPTION_SUPPORTED +#define PNG_SET_UNKNOWN_CHUNKS_SUPPORTED +#define PNG_SET_USER_LIMITS_SUPPORTED +#define PNG_SIMPLIFIED_READ_AFIRST_SUPPORTED +#define PNG_SIMPLIFIED_READ_BGR_SUPPORTED +#define PNG_SIMPLIFIED_READ_SUPPORTED +#define PNG_SIMPLIFIED_WRITE_AFIRST_SUPPORTED +#define PNG_SIMPLIFIED_WRITE_BGR_SUPPORTED +#define PNG_SIMPLIFIED_WRITE_SUPPORTED +#define PNG_STDIO_SUPPORTED +#define PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED +#define PNG_TEXT_SUPPORTED +#define PNG_TIME_RFC1123_SUPPORTED +#define PNG_UNKNOWN_CHUNKS_SUPPORTED +#define PNG_USER_CHUNKS_SUPPORTED +#define PNG_USER_LIMITS_SUPPORTED +#define PNG_USER_MEM_SUPPORTED +#define PNG_USER_TRANSFORM_INFO_SUPPORTED +#define PNG_USER_TRANSFORM_PTR_SUPPORTED +#define PNG_WARNINGS_SUPPORTED +#define PNG_WRITE_16BIT_SUPPORTED +#define PNG_WRITE_ANCILLARY_CHUNKS_SUPPORTED +#define PNG_WRITE_BGR_SUPPORTED +#define PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED +#define PNG_WRITE_COMPRESSED_TEXT_SUPPORTED +#define PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED +#define PNG_WRITE_FILLER_SUPPORTED +#define PNG_WRITE_FILTER_SUPPORTED +#define PNG_WRITE_FLUSH_SUPPORTED +#define PNG_WRITE_GET_PALETTE_MAX_SUPPORTED +#define PNG_WRITE_INTERLACING_SUPPORTED +#define PNG_WRITE_INT_FUNCTIONS_SUPPORTED +#define PNG_WRITE_INVERT_ALPHA_SUPPORTED +#define PNG_WRITE_INVERT_SUPPORTED +#define PNG_WRITE_OPTIMIZE_CMF_SUPPORTED +#define PNG_WRITE_PACKSWAP_SUPPORTED +#define PNG_WRITE_PACK_SUPPORTED +#define PNG_WRITE_SHIFT_SUPPORTED +#define PNG_WRITE_SUPPORTED +#define PNG_WRITE_SWAP_ALPHA_SUPPORTED +#define PNG_WRITE_SWAP_SUPPORTED +#define PNG_WRITE_TEXT_SUPPORTED +#define PNG_WRITE_TRANSFORMS_SUPPORTED +#define PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED +#define PNG_WRITE_USER_TRANSFORM_SUPPORTED +#define PNG_WRITE_WEIGHTED_FILTER_SUPPORTED +#define PNG_WRITE_bKGD_SUPPORTED +#define PNG_WRITE_cHRM_SUPPORTED +#define PNG_WRITE_gAMA_SUPPORTED +#define PNG_WRITE_hIST_SUPPORTED +#define PNG_WRITE_iCCP_SUPPORTED +#define PNG_WRITE_iTXt_SUPPORTED +#define PNG_WRITE_oFFs_SUPPORTED +#define PNG_WRITE_pCAL_SUPPORTED +#define PNG_WRITE_pHYs_SUPPORTED +#define PNG_WRITE_sBIT_SUPPORTED +#define PNG_WRITE_sCAL_SUPPORTED +#define PNG_WRITE_sPLT_SUPPORTED +#define PNG_WRITE_sRGB_SUPPORTED +#define PNG_WRITE_tEXt_SUPPORTED +#define PNG_WRITE_tIME_SUPPORTED +#define PNG_WRITE_tRNS_SUPPORTED +#define PNG_WRITE_zTXt_SUPPORTED +#define PNG_bKGD_SUPPORTED +#define PNG_cHRM_SUPPORTED +#define PNG_gAMA_SUPPORTED +#define PNG_hIST_SUPPORTED +#define PNG_iCCP_SUPPORTED +#define PNG_iTXt_SUPPORTED +#define PNG_oFFs_SUPPORTED +#define PNG_pCAL_SUPPORTED +#define PNG_pHYs_SUPPORTED +#define PNG_sBIT_SUPPORTED +#define PNG_sCAL_SUPPORTED +#define PNG_sPLT_SUPPORTED +#define PNG_sRGB_SUPPORTED +#define PNG_tEXt_SUPPORTED +#define PNG_tIME_SUPPORTED +#define PNG_tRNS_SUPPORTED +#define PNG_zTXt_SUPPORTED +/* end of options */ +/* settings */ +#define PNG_API_RULE 0 +#define PNG_CALLOC_SUPPORTED +#define PNG_COST_SHIFT 3 +#define PNG_DEFAULT_READ_MACROS 1 +#define PNG_GAMMA_THRESHOLD_FIXED 5000 +#define PNG_IDAT_READ_SIZE PNG_ZBUF_SIZE +#define PNG_INFLATE_BUF_SIZE 1024 +#define PNG_MAX_GAMMA_8 11 +#define PNG_QUANTIZE_BLUE_BITS 5 +#define PNG_QUANTIZE_GREEN_BITS 5 +#define PNG_QUANTIZE_RED_BITS 5 +#define PNG_TEXT_Z_DEFAULT_COMPRESSION (-1) +#define PNG_TEXT_Z_DEFAULT_STRATEGY 0 +#define PNG_WEIGHT_SHIFT 8 +#define PNG_ZBUF_SIZE 8192 +#define PNG_ZLIB_VERNUM 0 /* unknown */ +#define PNG_Z_DEFAULT_COMPRESSION (-1) +#define PNG_Z_DEFAULT_NOFILTER_STRATEGY 0 +#define PNG_Z_DEFAULT_STRATEGY 1 +#define PNG_sCAL_PRECISION 5 +#define PNG_sRGB_PROFILE_CHECKS 2 +/* end of settings */ +#endif /* PNGLCONF_H */ diff --git a/ml/dlib/dlib/external/libpng/pngmem.c b/ml/dlib/dlib/external/libpng/pngmem.c new file mode 100644 index 000000000..b9b3efb44 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngmem.c @@ -0,0 +1,277 @@ + +/* pngmem.c - stub functions for memory allocation + * + * Last changed in libpng 1.6.0 [February 14, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + * This file provides a location for all memory allocation. Users who + * need special memory handling are expected to supply replacement + * functions for png_malloc() and png_free(), and to use + * png_create_read_struct_2() and png_create_write_struct_2() to + * identify the replacement functions. + */ + +#include "pngpriv.h" + +#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) +/* Free a png_struct */ +void /* PRIVATE */ +png_destroy_png_struct(png_structrp png_ptr) +{ + if (png_ptr != NULL) + { + /* png_free might call png_error and may certainly call + * png_get_mem_ptr, so fake a temporary png_struct to support this. + */ + png_struct dummy_struct = *png_ptr; + memset(png_ptr, 0, (sizeof *png_ptr)); + png_free(&dummy_struct, png_ptr); + +# ifdef PNG_SETJMP_SUPPORTED + /* We may have a jmp_buf left to deallocate. */ + png_free_jmpbuf(&dummy_struct); +# endif + } +} + +/* Allocate memory. For reasonable files, size should never exceed + * 64K. However, zlib may allocate more then 64K if you don't tell + * it not to. See zconf.h and png.h for more information. zlib does + * need to allocate exactly 64K, so whatever you call here must + * have the ability to do that. + */ +PNG_FUNCTION(png_voidp,PNGAPI +png_calloc,(png_const_structrp png_ptr, png_alloc_size_t size),PNG_ALLOCATED) +{ + png_voidp ret; + + ret = png_malloc(png_ptr, size); + + if (ret != NULL) + memset(ret, 0, size); + + return ret; +} + +/* png_malloc_base, an internal function added at libpng 1.6.0, does the work of + * allocating memory, taking into account limits and PNG_USER_MEM_SUPPORTED. + * Checking and error handling must happen outside this routine; it returns NULL + * if the allocation cannot be done (for any reason.) + */ +PNG_FUNCTION(png_voidp /* PRIVATE */, +png_malloc_base,(png_const_structrp png_ptr, png_alloc_size_t size), + PNG_ALLOCATED) +{ + /* Moved to png_malloc_base from png_malloc_default in 1.6.0; the DOS + * allocators have also been removed in 1.6.0, so any 16-bit system now has + * to implement a user memory handler. This checks to be sure it isn't + * called with big numbers. + */ +#ifdef PNG_USER_MEM_SUPPORTED + PNG_UNUSED(png_ptr) +#endif + if (size > 0 && size <= PNG_SIZE_MAX +# ifdef PNG_MAX_MALLOC_64K + && size <= 65536U +# endif + ) + { +#ifdef PNG_USER_MEM_SUPPORTED + if (png_ptr != NULL && png_ptr->malloc_fn != NULL) + return png_ptr->malloc_fn(png_constcast(png_structrp,png_ptr), size); + + else +#endif + return malloc((size_t)size); /* checked for truncation above */ + } + + else + return NULL; +} + +/* This is really here only to work round a spurious warning in GCC 4.6 and 4.7 + * that arises because of the checks in png_realloc_array that are repeated in + * png_malloc_array. + */ +static png_voidp +png_malloc_array_checked(png_const_structrp png_ptr, int nelements, + size_t element_size) +{ + png_alloc_size_t req = nelements; /* known to be > 0 */ + + if (req <= PNG_SIZE_MAX/element_size) + return png_malloc_base(png_ptr, req * element_size); + + /* The failure case when the request is too large */ + return NULL; +} + +PNG_FUNCTION(png_voidp /* PRIVATE */, +png_malloc_array,(png_const_structrp png_ptr, int nelements, + size_t element_size),PNG_ALLOCATED) +{ + if (nelements <= 0 || element_size == 0) + png_error(png_ptr, "internal error: array alloc"); + + return png_malloc_array_checked(png_ptr, nelements, element_size); +} + +PNG_FUNCTION(png_voidp /* PRIVATE */, +png_realloc_array,(png_const_structrp png_ptr, png_const_voidp old_array, + int old_elements, int add_elements, size_t element_size),PNG_ALLOCATED) +{ + /* These are internal errors: */ + if (add_elements <= 0 || element_size == 0 || old_elements < 0 || + (old_array == NULL && old_elements > 0)) + png_error(png_ptr, "internal error: array realloc"); + + /* Check for overflow on the elements count (so the caller does not have to + * check.) + */ + if (add_elements <= INT_MAX - old_elements) + { + png_voidp new_array = png_malloc_array_checked(png_ptr, + old_elements+add_elements, element_size); + + if (new_array != NULL) + { + /* Because png_malloc_array worked the size calculations below cannot + * overflow. + */ + if (old_elements > 0) + memcpy(new_array, old_array, element_size*(unsigned)old_elements); + + memset((char*)new_array + element_size*(unsigned)old_elements, 0, + element_size*(unsigned)add_elements); + + return new_array; + } + } + + return NULL; /* error */ +} + +/* Various functions that have different error handling are derived from this. + * png_malloc always exists, but if PNG_USER_MEM_SUPPORTED is defined a separate + * function png_malloc_default is also provided. + */ +PNG_FUNCTION(png_voidp,PNGAPI +png_malloc,(png_const_structrp png_ptr, png_alloc_size_t size),PNG_ALLOCATED) +{ + png_voidp ret; + + if (png_ptr == NULL) + return NULL; + + ret = png_malloc_base(png_ptr, size); + + if (ret == NULL) + png_error(png_ptr, "Out of memory"); /* 'm' means png_malloc */ + + return ret; +} + +#ifdef PNG_USER_MEM_SUPPORTED +PNG_FUNCTION(png_voidp,PNGAPI +png_malloc_default,(png_const_structrp png_ptr, png_alloc_size_t size), + PNG_ALLOCATED PNG_DEPRECATED) +{ + png_voidp ret; + + if (png_ptr == NULL) + return NULL; + + /* Passing 'NULL' here bypasses the application provided memory handler. */ + ret = png_malloc_base(NULL/*use malloc*/, size); + + if (ret == NULL) + png_error(png_ptr, "Out of Memory"); /* 'M' means png_malloc_default */ + + return ret; +} +#endif /* PNG_USER_MEM_SUPPORTED */ + +/* This function was added at libpng version 1.2.3. The png_malloc_warn() + * function will issue a png_warning and return NULL instead of issuing a + * png_error, if it fails to allocate the requested memory. + */ +PNG_FUNCTION(png_voidp,PNGAPI +png_malloc_warn,(png_const_structrp png_ptr, png_alloc_size_t size), + PNG_ALLOCATED) +{ + if (png_ptr != NULL) + { + png_voidp ret = png_malloc_base(png_ptr, size); + + if (ret != NULL) + return ret; + + png_warning(png_ptr, "Out of memory"); + } + + return NULL; +} + +/* Free a pointer allocated by png_malloc(). If ptr is NULL, return + * without taking any action. + */ +void PNGAPI +png_free(png_const_structrp png_ptr, png_voidp ptr) +{ + if (png_ptr == NULL || ptr == NULL) + return; + +#ifdef PNG_USER_MEM_SUPPORTED + if (png_ptr->free_fn != NULL) + png_ptr->free_fn(png_constcast(png_structrp,png_ptr), ptr); + + else + png_free_default(png_ptr, ptr); +} + +PNG_FUNCTION(void,PNGAPI +png_free_default,(png_const_structrp png_ptr, png_voidp ptr),PNG_DEPRECATED) +{ + if (png_ptr == NULL || ptr == NULL) + return; +#endif /* PNG_USER_MEM_SUPPORTED */ + + free(ptr); +} + +#ifdef PNG_USER_MEM_SUPPORTED +/* This function is called when the application wants to use another method + * of allocating and freeing memory. + */ +void PNGAPI +png_set_mem_fn(png_structrp png_ptr, png_voidp mem_ptr, png_malloc_ptr + malloc_fn, png_free_ptr free_fn) +{ + if (png_ptr != NULL) + { + png_ptr->mem_ptr = mem_ptr; + png_ptr->malloc_fn = malloc_fn; + png_ptr->free_fn = free_fn; + } +} + +/* This function returns a pointer to the mem_ptr associated with the user + * functions. The application should free any memory associated with this + * pointer before png_write_destroy and png_read_destroy are called. + */ +png_voidp PNGAPI +png_get_mem_ptr(png_const_structrp png_ptr) +{ + if (png_ptr == NULL) + return NULL; + + return png_ptr->mem_ptr; +} +#endif /* PNG_USER_MEM_SUPPORTED */ +#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngpread.c b/ml/dlib/dlib/external/libpng/pngpread.c new file mode 100644 index 000000000..0169ecb2c --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngpread.c @@ -0,0 +1,1291 @@ + +/* pngpread.c - read a png file in push mode + * + * Last changed in libpng 1.6.0 [February 14, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +#include "pngpriv.h" + +#ifdef PNG_PROGRESSIVE_READ_SUPPORTED + +/* Push model modes */ +#define PNG_READ_SIG_MODE 0 +#define PNG_READ_CHUNK_MODE 1 +#define PNG_READ_IDAT_MODE 2 +#define PNG_SKIP_MODE 3 +#define PNG_READ_tEXt_MODE 4 +#define PNG_READ_zTXt_MODE 5 +#define PNG_READ_DONE_MODE 6 +#define PNG_READ_iTXt_MODE 7 +#define PNG_ERROR_MODE 8 + +void PNGAPI +png_process_data(png_structrp png_ptr, png_inforp info_ptr, + png_bytep buffer, png_size_t buffer_size) +{ + if (png_ptr == NULL || info_ptr == NULL) + return; + + png_push_restore_buffer(png_ptr, buffer, buffer_size); + + while (png_ptr->buffer_size) + { + png_process_some_data(png_ptr, info_ptr); + } +} + +png_size_t PNGAPI +png_process_data_pause(png_structrp png_ptr, int save) +{ + if (png_ptr != NULL) + { + /* It's easiest for the caller if we do the save, then the caller doesn't + * have to supply the same data again: + */ + if (save) + png_push_save_buffer(png_ptr); + else + { + /* This includes any pending saved bytes: */ + png_size_t remaining = png_ptr->buffer_size; + png_ptr->buffer_size = 0; + + /* So subtract the saved buffer size, unless all the data + * is actually 'saved', in which case we just return 0 + */ + if (png_ptr->save_buffer_size < remaining) + return remaining - png_ptr->save_buffer_size; + } + } + + return 0; +} + +png_uint_32 PNGAPI +png_process_data_skip(png_structrp png_ptr) +{ + png_uint_32 remaining = 0; + + if (png_ptr != NULL && png_ptr->process_mode == PNG_SKIP_MODE && + png_ptr->skip_length > 0) + { + /* At the end of png_process_data the buffer size must be 0 (see the loop + * above) so we can detect a broken call here: + */ + if (png_ptr->buffer_size != 0) + png_error(png_ptr, + "png_process_data_skip called inside png_process_data"); + + /* If is impossible for there to be a saved buffer at this point - + * otherwise we could not be in SKIP mode. This will also happen if + * png_process_skip is called inside png_process_data (but only very + * rarely.) + */ + if (png_ptr->save_buffer_size != 0) + png_error(png_ptr, "png_process_data_skip called with saved data"); + + remaining = png_ptr->skip_length; + png_ptr->skip_length = 0; + png_ptr->process_mode = PNG_READ_CHUNK_MODE; + } + + return remaining; +} + +/* What we do with the incoming data depends on what we were previously + * doing before we ran out of data... + */ +void /* PRIVATE */ +png_process_some_data(png_structrp png_ptr, png_inforp info_ptr) +{ + if (png_ptr == NULL) + return; + + switch (png_ptr->process_mode) + { + case PNG_READ_SIG_MODE: + { + png_push_read_sig(png_ptr, info_ptr); + break; + } + + case PNG_READ_CHUNK_MODE: + { + png_push_read_chunk(png_ptr, info_ptr); + break; + } + + case PNG_READ_IDAT_MODE: + { + png_push_read_IDAT(png_ptr); + break; + } + + case PNG_SKIP_MODE: + { + png_push_crc_finish(png_ptr); + break; + } + + default: + { + png_ptr->buffer_size = 0; + break; + } + } +} + +/* Read any remaining signature bytes from the stream and compare them with + * the correct PNG signature. It is possible that this routine is called + * with bytes already read from the signature, either because they have been + * checked by the calling application, or because of multiple calls to this + * routine. + */ +void /* PRIVATE */ +png_push_read_sig(png_structrp png_ptr, png_inforp info_ptr) +{ + png_size_t num_checked = png_ptr->sig_bytes, /* SAFE, does not exceed 8 */ + num_to_check = 8 - num_checked; + + if (png_ptr->buffer_size < num_to_check) + { + num_to_check = png_ptr->buffer_size; + } + + png_push_fill_buffer(png_ptr, &(info_ptr->signature[num_checked]), + num_to_check); + png_ptr->sig_bytes = (png_byte)(png_ptr->sig_bytes + num_to_check); + + if (png_sig_cmp(info_ptr->signature, num_checked, num_to_check)) + { + if (num_checked < 4 && + png_sig_cmp(info_ptr->signature, num_checked, num_to_check - 4)) + png_error(png_ptr, "Not a PNG file"); + + else + png_error(png_ptr, "PNG file corrupted by ASCII conversion"); + } + else + { + if (png_ptr->sig_bytes >= 8) + { + png_ptr->process_mode = PNG_READ_CHUNK_MODE; + } + } +} + +void /* PRIVATE */ +png_push_read_chunk(png_structrp png_ptr, png_inforp info_ptr) +{ + png_uint_32 chunk_name; +#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED + int keep; /* unknown handling method */ +#endif + + /* First we make sure we have enough data for the 4 byte chunk name + * and the 4 byte chunk length before proceeding with decoding the + * chunk data. To fully decode each of these chunks, we also make + * sure we have enough data in the buffer for the 4 byte CRC at the + * end of every chunk (except IDAT, which is handled separately). + */ + if (!(png_ptr->mode & PNG_HAVE_CHUNK_HEADER)) + { + png_byte chunk_length[4]; + png_byte chunk_tag[4]; + + if (png_ptr->buffer_size < 8) + { + png_push_save_buffer(png_ptr); + return; + } + + png_push_fill_buffer(png_ptr, chunk_length, 4); + png_ptr->push_length = png_get_uint_31(png_ptr, chunk_length); + png_reset_crc(png_ptr); + png_crc_read(png_ptr, chunk_tag, 4); + png_ptr->chunk_name = PNG_CHUNK_FROM_STRING(chunk_tag); + png_check_chunk_name(png_ptr, png_ptr->chunk_name); + png_ptr->mode |= PNG_HAVE_CHUNK_HEADER; + } + + chunk_name = png_ptr->chunk_name; + + if (chunk_name == png_IDAT) + { + if (png_ptr->mode & PNG_AFTER_IDAT) + png_ptr->mode |= PNG_HAVE_CHUNK_AFTER_IDAT; + + /* If we reach an IDAT chunk, this means we have read all of the + * header chunks, and we can start reading the image (or if this + * is called after the image has been read - we have an error). + */ + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_error(png_ptr, "Missing IHDR before IDAT"); + + else if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE && + !(png_ptr->mode & PNG_HAVE_PLTE)) + png_error(png_ptr, "Missing PLTE before IDAT"); + + png_ptr->mode |= PNG_HAVE_IDAT; + + if (!(png_ptr->mode & PNG_HAVE_CHUNK_AFTER_IDAT)) + if (png_ptr->push_length == 0) + return; + + if (png_ptr->mode & PNG_AFTER_IDAT) + png_benign_error(png_ptr, "Too many IDATs found"); + } + + if (chunk_name == png_IHDR) + { + if (png_ptr->push_length != 13) + png_error(png_ptr, "Invalid IHDR length"); + + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_IHDR(png_ptr, info_ptr, png_ptr->push_length); + } + + else if (chunk_name == png_IEND) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_IEND(png_ptr, info_ptr, png_ptr->push_length); + + png_ptr->process_mode = PNG_READ_DONE_MODE; + png_push_have_end(png_ptr, info_ptr); + } + +#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED + else if ((keep = png_chunk_unknown_handling(png_ptr, chunk_name)) != 0) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_unknown(png_ptr, info_ptr, png_ptr->push_length, keep); + + if (chunk_name == png_PLTE) + png_ptr->mode |= PNG_HAVE_PLTE; + } + +#endif + else if (chunk_name == png_PLTE) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + png_handle_PLTE(png_ptr, info_ptr, png_ptr->push_length); + } + + else if (chunk_name == png_IDAT) + { + png_ptr->idat_size = png_ptr->push_length; + png_ptr->process_mode = PNG_READ_IDAT_MODE; + png_push_have_info(png_ptr, info_ptr); + png_ptr->zstream.avail_out = + (uInt) PNG_ROWBYTES(png_ptr->pixel_depth, + png_ptr->iwidth) + 1; + png_ptr->zstream.next_out = png_ptr->row_buf; + return; + } + +#ifdef PNG_READ_gAMA_SUPPORTED + else if (png_ptr->chunk_name == png_gAMA) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_gAMA(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_sBIT_SUPPORTED + else if (png_ptr->chunk_name == png_sBIT) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_sBIT(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_cHRM_SUPPORTED + else if (png_ptr->chunk_name == png_cHRM) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_cHRM(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_sRGB_SUPPORTED + else if (chunk_name == png_sRGB) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_sRGB(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_iCCP_SUPPORTED + else if (png_ptr->chunk_name == png_iCCP) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_iCCP(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_sPLT_SUPPORTED + else if (chunk_name == png_sPLT) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_sPLT(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_tRNS_SUPPORTED + else if (chunk_name == png_tRNS) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_tRNS(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_bKGD_SUPPORTED + else if (chunk_name == png_bKGD) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_bKGD(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_hIST_SUPPORTED + else if (chunk_name == png_hIST) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_hIST(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_pHYs_SUPPORTED + else if (chunk_name == png_pHYs) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_pHYs(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_oFFs_SUPPORTED + else if (chunk_name == png_oFFs) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_oFFs(png_ptr, info_ptr, png_ptr->push_length); + } +#endif + +#ifdef PNG_READ_pCAL_SUPPORTED + else if (chunk_name == png_pCAL) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_pCAL(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_sCAL_SUPPORTED + else if (chunk_name == png_sCAL) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_sCAL(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_tIME_SUPPORTED + else if (chunk_name == png_tIME) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_tIME(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_tEXt_SUPPORTED + else if (chunk_name == png_tEXt) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_tEXt(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_zTXt_SUPPORTED + else if (chunk_name == png_zTXt) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_zTXt(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif +#ifdef PNG_READ_iTXt_SUPPORTED + else if (chunk_name == png_iTXt) + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + + png_handle_iTXt(png_ptr, info_ptr, png_ptr->push_length); + } + +#endif + else + { + if (png_ptr->push_length + 4 > png_ptr->buffer_size) + { + png_push_save_buffer(png_ptr); + return; + } + png_handle_unknown(png_ptr, info_ptr, png_ptr->push_length, + PNG_HANDLE_CHUNK_AS_DEFAULT); + } + + png_ptr->mode &= ~PNG_HAVE_CHUNK_HEADER; +} + +void /* PRIVATE */ +png_push_crc_skip(png_structrp png_ptr, png_uint_32 skip) +{ + png_ptr->process_mode = PNG_SKIP_MODE; + png_ptr->skip_length = skip; +} + +void /* PRIVATE */ +png_push_crc_finish(png_structrp png_ptr) +{ + if (png_ptr->skip_length && png_ptr->save_buffer_size) + { + png_size_t save_size = png_ptr->save_buffer_size; + png_uint_32 skip_length = png_ptr->skip_length; + + /* We want the smaller of 'skip_length' and 'save_buffer_size', but + * they are of different types and we don't know which variable has the + * fewest bits. Carefully select the smaller and cast it to the type of + * the larger - this cannot overflow. Do not cast in the following test + * - it will break on either 16 or 64 bit platforms. + */ + if (skip_length < save_size) + save_size = (png_size_t)skip_length; + + else + skip_length = (png_uint_32)save_size; + + png_calculate_crc(png_ptr, png_ptr->save_buffer_ptr, save_size); + + png_ptr->skip_length -= skip_length; + png_ptr->buffer_size -= save_size; + png_ptr->save_buffer_size -= save_size; + png_ptr->save_buffer_ptr += save_size; + } + if (png_ptr->skip_length && png_ptr->current_buffer_size) + { + png_size_t save_size = png_ptr->current_buffer_size; + png_uint_32 skip_length = png_ptr->skip_length; + + /* We want the smaller of 'skip_length' and 'current_buffer_size', here, + * the same problem exists as above and the same solution. + */ + if (skip_length < save_size) + save_size = (png_size_t)skip_length; + + else + skip_length = (png_uint_32)save_size; + + png_calculate_crc(png_ptr, png_ptr->current_buffer_ptr, save_size); + + png_ptr->skip_length -= skip_length; + png_ptr->buffer_size -= save_size; + png_ptr->current_buffer_size -= save_size; + png_ptr->current_buffer_ptr += save_size; + } + if (!png_ptr->skip_length) + { + if (png_ptr->buffer_size < 4) + { + png_push_save_buffer(png_ptr); + return; + } + + png_crc_finish(png_ptr, 0); + png_ptr->process_mode = PNG_READ_CHUNK_MODE; + } +} + +void PNGCBAPI +png_push_fill_buffer(png_structp png_ptr, png_bytep buffer, png_size_t length) +{ + png_bytep ptr; + + if (png_ptr == NULL) + return; + + ptr = buffer; + if (png_ptr->save_buffer_size) + { + png_size_t save_size; + + if (length < png_ptr->save_buffer_size) + save_size = length; + + else + save_size = png_ptr->save_buffer_size; + + memcpy(ptr, png_ptr->save_buffer_ptr, save_size); + length -= save_size; + ptr += save_size; + png_ptr->buffer_size -= save_size; + png_ptr->save_buffer_size -= save_size; + png_ptr->save_buffer_ptr += save_size; + } + if (length && png_ptr->current_buffer_size) + { + png_size_t save_size; + + if (length < png_ptr->current_buffer_size) + save_size = length; + + else + save_size = png_ptr->current_buffer_size; + + memcpy(ptr, png_ptr->current_buffer_ptr, save_size); + png_ptr->buffer_size -= save_size; + png_ptr->current_buffer_size -= save_size; + png_ptr->current_buffer_ptr += save_size; + } +} + +void /* PRIVATE */ +png_push_save_buffer(png_structrp png_ptr) +{ + if (png_ptr->save_buffer_size) + { + if (png_ptr->save_buffer_ptr != png_ptr->save_buffer) + { + png_size_t i, istop; + png_bytep sp; + png_bytep dp; + + istop = png_ptr->save_buffer_size; + for (i = 0, sp = png_ptr->save_buffer_ptr, dp = png_ptr->save_buffer; + i < istop; i++, sp++, dp++) + { + *dp = *sp; + } + } + } + if (png_ptr->save_buffer_size + png_ptr->current_buffer_size > + png_ptr->save_buffer_max) + { + png_size_t new_max; + png_bytep old_buffer; + + if (png_ptr->save_buffer_size > PNG_SIZE_MAX - + (png_ptr->current_buffer_size + 256)) + { + png_error(png_ptr, "Potential overflow of save_buffer"); + } + + new_max = png_ptr->save_buffer_size + png_ptr->current_buffer_size + 256; + old_buffer = png_ptr->save_buffer; + png_ptr->save_buffer = (png_bytep)png_malloc_warn(png_ptr, + (png_size_t)new_max); + + if (png_ptr->save_buffer == NULL) + { + png_free(png_ptr, old_buffer); + png_error(png_ptr, "Insufficient memory for save_buffer"); + } + + memcpy(png_ptr->save_buffer, old_buffer, png_ptr->save_buffer_size); + png_free(png_ptr, old_buffer); + png_ptr->save_buffer_max = new_max; + } + if (png_ptr->current_buffer_size) + { + memcpy(png_ptr->save_buffer + png_ptr->save_buffer_size, + png_ptr->current_buffer_ptr, png_ptr->current_buffer_size); + png_ptr->save_buffer_size += png_ptr->current_buffer_size; + png_ptr->current_buffer_size = 0; + } + png_ptr->save_buffer_ptr = png_ptr->save_buffer; + png_ptr->buffer_size = 0; +} + +void /* PRIVATE */ +png_push_restore_buffer(png_structrp png_ptr, png_bytep buffer, + png_size_t buffer_length) +{ + png_ptr->current_buffer = buffer; + png_ptr->current_buffer_size = buffer_length; + png_ptr->buffer_size = buffer_length + png_ptr->save_buffer_size; + png_ptr->current_buffer_ptr = png_ptr->current_buffer; +} + +void /* PRIVATE */ +png_push_read_IDAT(png_structrp png_ptr) +{ + if (!(png_ptr->mode & PNG_HAVE_CHUNK_HEADER)) + { + png_byte chunk_length[4]; + png_byte chunk_tag[4]; + + /* TODO: this code can be commoned up with the same code in push_read */ + if (png_ptr->buffer_size < 8) + { + png_push_save_buffer(png_ptr); + return; + } + + png_push_fill_buffer(png_ptr, chunk_length, 4); + png_ptr->push_length = png_get_uint_31(png_ptr, chunk_length); + png_reset_crc(png_ptr); + png_crc_read(png_ptr, chunk_tag, 4); + png_ptr->chunk_name = PNG_CHUNK_FROM_STRING(chunk_tag); + png_ptr->mode |= PNG_HAVE_CHUNK_HEADER; + + if (png_ptr->chunk_name != png_IDAT) + { + png_ptr->process_mode = PNG_READ_CHUNK_MODE; + + if (!(png_ptr->flags & PNG_FLAG_ZSTREAM_ENDED)) + png_error(png_ptr, "Not enough compressed data"); + + return; + } + + png_ptr->idat_size = png_ptr->push_length; + } + + if (png_ptr->idat_size && png_ptr->save_buffer_size) + { + png_size_t save_size = png_ptr->save_buffer_size; + png_uint_32 idat_size = png_ptr->idat_size; + + /* We want the smaller of 'idat_size' and 'current_buffer_size', but they + * are of different types and we don't know which variable has the fewest + * bits. Carefully select the smaller and cast it to the type of the + * larger - this cannot overflow. Do not cast in the following test - it + * will break on either 16 or 64 bit platforms. + */ + if (idat_size < save_size) + save_size = (png_size_t)idat_size; + + else + idat_size = (png_uint_32)save_size; + + png_calculate_crc(png_ptr, png_ptr->save_buffer_ptr, save_size); + + png_process_IDAT_data(png_ptr, png_ptr->save_buffer_ptr, save_size); + + png_ptr->idat_size -= idat_size; + png_ptr->buffer_size -= save_size; + png_ptr->save_buffer_size -= save_size; + png_ptr->save_buffer_ptr += save_size; + } + + if (png_ptr->idat_size && png_ptr->current_buffer_size) + { + png_size_t save_size = png_ptr->current_buffer_size; + png_uint_32 idat_size = png_ptr->idat_size; + + /* We want the smaller of 'idat_size' and 'current_buffer_size', but they + * are of different types and we don't know which variable has the fewest + * bits. Carefully select the smaller and cast it to the type of the + * larger - this cannot overflow. + */ + if (idat_size < save_size) + save_size = (png_size_t)idat_size; + + else + idat_size = (png_uint_32)save_size; + + png_calculate_crc(png_ptr, png_ptr->current_buffer_ptr, save_size); + + png_process_IDAT_data(png_ptr, png_ptr->current_buffer_ptr, save_size); + + png_ptr->idat_size -= idat_size; + png_ptr->buffer_size -= save_size; + png_ptr->current_buffer_size -= save_size; + png_ptr->current_buffer_ptr += save_size; + } + if (!png_ptr->idat_size) + { + if (png_ptr->buffer_size < 4) + { + png_push_save_buffer(png_ptr); + return; + } + + png_crc_finish(png_ptr, 0); + png_ptr->mode &= ~PNG_HAVE_CHUNK_HEADER; + png_ptr->mode |= PNG_AFTER_IDAT; + png_ptr->zowner = 0; + } +} + +void /* PRIVATE */ +png_process_IDAT_data(png_structrp png_ptr, png_bytep buffer, + png_size_t buffer_length) +{ + /* The caller checks for a non-zero buffer length. */ + if (!(buffer_length > 0) || buffer == NULL) + png_error(png_ptr, "No IDAT data (internal error)"); + + /* This routine must process all the data it has been given + * before returning, calling the row callback as required to + * handle the uncompressed results. + */ + png_ptr->zstream.next_in = buffer; + /* TODO: WARNING: TRUNCATION ERROR: DANGER WILL ROBINSON: */ + png_ptr->zstream.avail_in = (uInt)buffer_length; + + /* Keep going until the decompressed data is all processed + * or the stream marked as finished. + */ + while (png_ptr->zstream.avail_in > 0 && + !(png_ptr->flags & PNG_FLAG_ZSTREAM_ENDED)) + { + int ret; + + /* We have data for zlib, but we must check that zlib + * has someplace to put the results. It doesn't matter + * if we don't expect any results -- it may be the input + * data is just the LZ end code. + */ + if (!(png_ptr->zstream.avail_out > 0)) + { + /* TODO: WARNING: TRUNCATION ERROR: DANGER WILL ROBINSON: */ + png_ptr->zstream.avail_out = (uInt)(PNG_ROWBYTES(png_ptr->pixel_depth, + png_ptr->iwidth) + 1); + + png_ptr->zstream.next_out = png_ptr->row_buf; + } + + /* Using Z_SYNC_FLUSH here means that an unterminated + * LZ stream (a stream with a missing end code) can still + * be handled, otherwise (Z_NO_FLUSH) a future zlib + * implementation might defer output and therefore + * change the current behavior (see comments in inflate.c + * for why this doesn't happen at present with zlib 1.2.5). + */ + ret = inflate(&png_ptr->zstream, Z_SYNC_FLUSH); + + /* Check for any failure before proceeding. */ + if (ret != Z_OK && ret != Z_STREAM_END) + { + /* Terminate the decompression. */ + png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED; + png_ptr->zowner = 0; + + /* This may be a truncated stream (missing or + * damaged end code). Treat that as a warning. + */ + if (png_ptr->row_number >= png_ptr->num_rows || + png_ptr->pass > 6) + png_warning(png_ptr, "Truncated compressed data in IDAT"); + + else + png_error(png_ptr, "Decompression error in IDAT"); + + /* Skip the check on unprocessed input */ + return; + } + + /* Did inflate output any data? */ + if (png_ptr->zstream.next_out != png_ptr->row_buf) + { + /* Is this unexpected data after the last row? + * If it is, artificially terminate the LZ output + * here. + */ + if (png_ptr->row_number >= png_ptr->num_rows || + png_ptr->pass > 6) + { + /* Extra data. */ + png_warning(png_ptr, "Extra compressed data in IDAT"); + png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED; + png_ptr->zowner = 0; + + /* Do no more processing; skip the unprocessed + * input check below. + */ + return; + } + + /* Do we have a complete row? */ + if (png_ptr->zstream.avail_out == 0) + png_push_process_row(png_ptr); + } + + /* And check for the end of the stream. */ + if (ret == Z_STREAM_END) + png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED; + } + + /* All the data should have been processed, if anything + * is left at this point we have bytes of IDAT data + * after the zlib end code. + */ + if (png_ptr->zstream.avail_in > 0) + png_warning(png_ptr, "Extra compression data in IDAT"); +} + +void /* PRIVATE */ +png_push_process_row(png_structrp png_ptr) +{ + /* 1.5.6: row_info moved out of png_struct to a local here. */ + png_row_info row_info; + + row_info.width = png_ptr->iwidth; /* NOTE: width of current interlaced row */ + row_info.color_type = png_ptr->color_type; + row_info.bit_depth = png_ptr->bit_depth; + row_info.channels = png_ptr->channels; + row_info.pixel_depth = png_ptr->pixel_depth; + row_info.rowbytes = PNG_ROWBYTES(row_info.pixel_depth, row_info.width); + + if (png_ptr->row_buf[0] > PNG_FILTER_VALUE_NONE) + { + if (png_ptr->row_buf[0] < PNG_FILTER_VALUE_LAST) + png_read_filter_row(png_ptr, &row_info, png_ptr->row_buf + 1, + png_ptr->prev_row + 1, png_ptr->row_buf[0]); + else + png_error(png_ptr, "bad adaptive filter value"); + } + + /* libpng 1.5.6: the following line was copying png_ptr->rowbytes before + * 1.5.6, while the buffer really is this big in current versions of libpng + * it may not be in the future, so this was changed just to copy the + * interlaced row count: + */ + memcpy(png_ptr->prev_row, png_ptr->row_buf, row_info.rowbytes + 1); + +#ifdef PNG_READ_TRANSFORMS_SUPPORTED + if (png_ptr->transformations) + png_do_read_transformations(png_ptr, &row_info); +#endif + + /* The transformed pixel depth should match the depth now in row_info. */ + if (png_ptr->transformed_pixel_depth == 0) + { + png_ptr->transformed_pixel_depth = row_info.pixel_depth; + if (row_info.pixel_depth > png_ptr->maximum_pixel_depth) + png_error(png_ptr, "progressive row overflow"); + } + + else if (png_ptr->transformed_pixel_depth != row_info.pixel_depth) + png_error(png_ptr, "internal progressive row size calculation error"); + + +#ifdef PNG_READ_INTERLACING_SUPPORTED + /* Blow up interlaced rows to full size */ + if (png_ptr->interlaced && (png_ptr->transformations & PNG_INTERLACE)) + { + if (png_ptr->pass < 6) + png_do_read_interlace(&row_info, png_ptr->row_buf + 1, png_ptr->pass, + png_ptr->transformations); + + switch (png_ptr->pass) + { + case 0: + { + int i; + for (i = 0; i < 8 && png_ptr->pass == 0; i++) + { + png_push_have_row(png_ptr, png_ptr->row_buf + 1); + png_read_push_finish_row(png_ptr); /* Updates png_ptr->pass */ + } + + if (png_ptr->pass == 2) /* Pass 1 might be empty */ + { + for (i = 0; i < 4 && png_ptr->pass == 2; i++) + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + } + + if (png_ptr->pass == 4 && png_ptr->height <= 4) + { + for (i = 0; i < 2 && png_ptr->pass == 4; i++) + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + } + + if (png_ptr->pass == 6 && png_ptr->height <= 4) + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + + break; + } + + case 1: + { + int i; + for (i = 0; i < 8 && png_ptr->pass == 1; i++) + { + png_push_have_row(png_ptr, png_ptr->row_buf + 1); + png_read_push_finish_row(png_ptr); + } + + if (png_ptr->pass == 2) /* Skip top 4 generated rows */ + { + for (i = 0; i < 4 && png_ptr->pass == 2; i++) + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + } + + break; + } + + case 2: + { + int i; + + for (i = 0; i < 4 && png_ptr->pass == 2; i++) + { + png_push_have_row(png_ptr, png_ptr->row_buf + 1); + png_read_push_finish_row(png_ptr); + } + + for (i = 0; i < 4 && png_ptr->pass == 2; i++) + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + + if (png_ptr->pass == 4) /* Pass 3 might be empty */ + { + for (i = 0; i < 2 && png_ptr->pass == 4; i++) + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + } + + break; + } + + case 3: + { + int i; + + for (i = 0; i < 4 && png_ptr->pass == 3; i++) + { + png_push_have_row(png_ptr, png_ptr->row_buf + 1); + png_read_push_finish_row(png_ptr); + } + + if (png_ptr->pass == 4) /* Skip top two generated rows */ + { + for (i = 0; i < 2 && png_ptr->pass == 4; i++) + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + } + + break; + } + + case 4: + { + int i; + + for (i = 0; i < 2 && png_ptr->pass == 4; i++) + { + png_push_have_row(png_ptr, png_ptr->row_buf + 1); + png_read_push_finish_row(png_ptr); + } + + for (i = 0; i < 2 && png_ptr->pass == 4; i++) + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + + if (png_ptr->pass == 6) /* Pass 5 might be empty */ + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + + break; + } + + case 5: + { + int i; + + for (i = 0; i < 2 && png_ptr->pass == 5; i++) + { + png_push_have_row(png_ptr, png_ptr->row_buf + 1); + png_read_push_finish_row(png_ptr); + } + + if (png_ptr->pass == 6) /* Skip top generated row */ + { + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + + break; + } + + default: + case 6: + { + png_push_have_row(png_ptr, png_ptr->row_buf + 1); + png_read_push_finish_row(png_ptr); + + if (png_ptr->pass != 6) + break; + + png_push_have_row(png_ptr, NULL); + png_read_push_finish_row(png_ptr); + } + } + } + else +#endif + { + png_push_have_row(png_ptr, png_ptr->row_buf + 1); + png_read_push_finish_row(png_ptr); + } +} + +void /* PRIVATE */ +png_read_push_finish_row(png_structrp png_ptr) +{ +#ifdef PNG_READ_INTERLACING_SUPPORTED + /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */ + + /* Start of interlace block */ + static PNG_CONST png_byte png_pass_start[] = {0, 4, 0, 2, 0, 1, 0}; + + /* Offset to next interlace block */ + static PNG_CONST png_byte png_pass_inc[] = {8, 8, 4, 4, 2, 2, 1}; + + /* Start of interlace block in the y direction */ + static PNG_CONST png_byte png_pass_ystart[] = {0, 0, 4, 0, 2, 0, 1}; + + /* Offset to next interlace block in the y direction */ + static PNG_CONST png_byte png_pass_yinc[] = {8, 8, 8, 4, 4, 2, 2}; + + /* Height of interlace block. This is not currently used - if you need + * it, uncomment it here and in png.h + static PNG_CONST png_byte png_pass_height[] = {8, 8, 4, 4, 2, 2, 1}; + */ +#endif + + png_ptr->row_number++; + if (png_ptr->row_number < png_ptr->num_rows) + return; + +#ifdef PNG_READ_INTERLACING_SUPPORTED + if (png_ptr->interlaced) + { + png_ptr->row_number = 0; + memset(png_ptr->prev_row, 0, png_ptr->rowbytes + 1); + + do + { + png_ptr->pass++; + if ((png_ptr->pass == 1 && png_ptr->width < 5) || + (png_ptr->pass == 3 && png_ptr->width < 3) || + (png_ptr->pass == 5 && png_ptr->width < 2)) + png_ptr->pass++; + + if (png_ptr->pass > 7) + png_ptr->pass--; + + if (png_ptr->pass >= 7) + break; + + png_ptr->iwidth = (png_ptr->width + + png_pass_inc[png_ptr->pass] - 1 - + png_pass_start[png_ptr->pass]) / + png_pass_inc[png_ptr->pass]; + + if (png_ptr->transformations & PNG_INTERLACE) + break; + + png_ptr->num_rows = (png_ptr->height + + png_pass_yinc[png_ptr->pass] - 1 - + png_pass_ystart[png_ptr->pass]) / + png_pass_yinc[png_ptr->pass]; + + } while (png_ptr->iwidth == 0 || png_ptr->num_rows == 0); + } +#endif /* PNG_READ_INTERLACING_SUPPORTED */ +} + +void /* PRIVATE */ +png_push_have_info(png_structrp png_ptr, png_inforp info_ptr) +{ + if (png_ptr->info_fn != NULL) + (*(png_ptr->info_fn))(png_ptr, info_ptr); +} + +void /* PRIVATE */ +png_push_have_end(png_structrp png_ptr, png_inforp info_ptr) +{ + if (png_ptr->end_fn != NULL) + (*(png_ptr->end_fn))(png_ptr, info_ptr); +} + +void /* PRIVATE */ +png_push_have_row(png_structrp png_ptr, png_bytep row) +{ + if (png_ptr->row_fn != NULL) + (*(png_ptr->row_fn))(png_ptr, row, png_ptr->row_number, + (int)png_ptr->pass); +} + +#ifdef PNG_READ_INTERLACING_SUPPORTED +void PNGAPI +png_progressive_combine_row(png_const_structrp png_ptr, png_bytep old_row, + png_const_bytep new_row) +{ + if (png_ptr == NULL) + return; + + /* new_row is a flag here - if it is NULL then the app callback was called + * from an empty row (see the calls to png_struct::row_fn below), otherwise + * it must be png_ptr->row_buf+1 + */ + if (new_row != NULL) + png_combine_row(png_ptr, old_row, 1/*display*/); +} +#endif /* PNG_READ_INTERLACING_SUPPORTED */ + +void PNGAPI +png_set_progressive_read_fn(png_structrp png_ptr, png_voidp progressive_ptr, + png_progressive_info_ptr info_fn, png_progressive_row_ptr row_fn, + png_progressive_end_ptr end_fn) +{ + if (png_ptr == NULL) + return; + + png_ptr->info_fn = info_fn; + png_ptr->row_fn = row_fn; + png_ptr->end_fn = end_fn; + + png_set_read_fn(png_ptr, progressive_ptr, png_push_fill_buffer); +} + +png_voidp PNGAPI +png_get_progressive_ptr(png_const_structrp png_ptr) +{ + if (png_ptr == NULL) + return (NULL); + + return png_ptr->io_ptr; +} +#endif /* PNG_PROGRESSIVE_READ_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngpriv.h b/ml/dlib/dlib/external/libpng/pngpriv.h new file mode 100644 index 000000000..43b2ef4c5 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngpriv.h @@ -0,0 +1,2047 @@ + +/* pngpriv.h - private declarations for use inside libpng + * + * For conditions of distribution and use, see copyright notice in png.h + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * Last changed in libpng 1.6.7 [November 14, 2013] + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +/* The symbols declared in this file (including the functions declared + * as extern) are PRIVATE. They are not part of the libpng public + * interface, and are not recommended for use by regular applications. + * Some of them may become public in the future; others may stay private, + * change in an incompatible way, or even disappear. + * Although the libpng users are not forbidden to include this header, + * they should be well aware of the issues that may arise from doing so. + */ + +#ifndef PNGPRIV_H +#define PNGPRIV_H + +/* Feature Test Macros. The following are defined here to ensure that correctly + * implemented libraries reveal the APIs libpng needs to build and hide those + * that are not needed and potentially damaging to the compilation. + * + * Feature Test Macros must be defined before any system header is included (see + * POSIX 1003.1 2.8.2 "POSIX Symbols." + * + * These macros only have an effect if the operating system supports either + * POSIX 1003.1 or C99, or both. On other operating systems (particularly + * Windows/Visual Studio) there is no effect; the OS specific tests below are + * still required (as of 2011-05-02.) + */ +#define _POSIX_SOURCE 1 /* Just the POSIX 1003.1 and C89 APIs */ + +#ifndef PNG_VERSION_INFO_ONLY +/* Standard library headers not required by png.h: */ +# include +# include +#endif + +#define PNGLIB_BUILD /*libpng is being built, not used*/ + +/* If HAVE_CONFIG_H is defined during the build then the build system must + * provide an appropriate "config.h" file on the include path. The header file + * must provide definitions as required below (search for "HAVE_CONFIG_H"); + * see configure.ac for more details of the requirements. The macro + * "PNG_NO_CONFIG_H" is provided for maintainers to test for dependencies on + * 'configure'; define this macro to prevent the configure build including the + * configure generated config.h. Libpng is expected to compile without *any* + * special build system support on a reasonably ANSI-C compliant system. + */ +#if defined(HAVE_CONFIG_H) && !defined(PNG_NO_CONFIG_H) +# include + + /* Pick up the definition of 'restrict' from config.h if it was read: */ +# define PNG_RESTRICT restrict +#endif + +/* To support symbol prefixing it is necessary to know *before* including png.h + * whether the fixed point (and maybe other) APIs are exported, because if they + * are not internal definitions may be required. This is handled below just + * before png.h is included, but load the configuration now if it is available. + */ +#ifndef PNGLCONF_H +# include "pnglibconf.h" +#endif + +/* Local renames may change non-exported API functions from png.h */ +#if defined(PNG_PREFIX) && !defined(PNGPREFIX_H) +# include "pngprefix.h" +#endif + +#ifdef PNG_USER_CONFIG +# include "pngusr.h" + /* These should have been defined in pngusr.h */ +# ifndef PNG_USER_PRIVATEBUILD +# define PNG_USER_PRIVATEBUILD "Custom libpng build" +# endif +# ifndef PNG_USER_DLLFNAME_POSTFIX +# define PNG_USER_DLLFNAME_POSTFIX "Cb" +# endif +#endif + +/* Compile time options. + * ===================== + * In a multi-arch build the compiler may compile the code several times for the + * same object module, producing different binaries for different architectures. + * When this happens configure-time setting of the target host options cannot be + * done and this interferes with the handling of the ARM NEON optimizations, and + * possibly other similar optimizations. Put additional tests here; in general + * this is needed when the same option can be changed at both compile time and + * run time depending on the target OS (i.e. iOS vs Android.) + * + * NOTE: symbol prefixing does not pass $(CFLAGS) to the preprocessor, because + * this is not possible with certain compilers (Oracle SUN OS CC), as a result + * it is necessary to ensure that all extern functions that *might* be used + * regardless of $(CFLAGS) get declared in this file. The test on __ARM_NEON__ + * below is one example of this behavior because it is controlled by the + * presence or not of -mfpu=neon on the GCC command line, it is possible to do + * this in $(CC), e.g. "CC=gcc -mfpu=neon", but people who build libpng rarely + * do this. + */ +#ifndef PNG_ARM_NEON_OPT + /* ARM NEON optimizations are being controlled by the compiler settings, + * typically the target FPU. If the FPU has been set to NEON (-mfpu=neon + * with GCC) then the compiler will define __ARM_NEON__ and we can rely + * unconditionally on NEON instructions not crashing, otherwise we must + * disable use of NEON instructions: + */ +# ifdef __ARM_NEON__ +# define PNG_ARM_NEON_OPT 2 +# else +# define PNG_ARM_NEON_OPT 0 +# endif +#endif + +#if PNG_ARM_NEON_OPT > 0 + /* NEON optimizations are to be at least considered by libpng, so enable the + * callbacks to do this. + */ +# define PNG_FILTER_OPTIMIZATIONS png_init_filter_functions_neon + + /* By default the 'intrinsics' code in arm/filter_neon_intrinsics.c is used + * if possible - if __ARM_NEON__ is set and the compiler version is not known + * to be broken. This is control by PNG_ARM_NEON_IMPLEMENTATION which can + * be: + * + * 1 The intrinsics code (the default with __ARM_NEON__) + * 2 The hand coded assembler (the default without __ARM_NEON__) + * + * It is possible to set PNG_ARM_NEON_IMPLEMENTATION in CPPFLAGS, however + * this is *NOT* supported and may cease to work even after a minor revision + * to libpng. It *is* valid to do this for testing purposes, e.g. speed + * testing or a new compiler, but the results should be communicated to the + * libpng implementation list for incorporation in the next minor release. + */ +# ifndef PNG_ARM_NEON_IMPLEMENTATION +# ifdef __ARM_NEON__ +# if defined(__clang__) + /* At present it is unknown by the libpng developers which versions + * of clang support the intrinsics, however some or perhaps all + * versions do not work with the assembler so this may be + * irrelevant, so just use the default (do nothing here.) + */ +# elif defined(__GNUC__) + /* GCC 4.5.4 NEON support is known to be broken. 4.6.3 is known to + * work, so if this *is* GCC, or G++, look for a version >4.5 + */ +# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 6) +# define PNG_ARM_NEON_IMPLEMENTATION 2 +# endif /* no GNUC support */ +# endif /* __GNUC__ */ +# else /* !defined __ARM_NEON__ */ + /* The 'intrinsics' code simply won't compile without this -mfpu=neon: + */ +# define PNG_ARM_NEON_IMPLEMENTATION 2 +# endif /* __ARM_NEON__ */ +# endif /* !defined PNG_ARM_NEON_IMPLEMENTATION */ + +# ifndef PNG_ARM_NEON_IMPLEMENTATION + /* Use the intrinsics code by default. */ +# define PNG_ARM_NEON_IMPLEMENTATION 1 +# endif +#endif /* PNG_ARM_NEON_OPT > 0 */ + +/* Is this a build of a DLL where compilation of the object modules requires + * different preprocessor settings to those required for a simple library? If + * so PNG_BUILD_DLL must be set. + * + * If libpng is used inside a DLL but that DLL does not export the libpng APIs + * PNG_BUILD_DLL must not be set. To avoid the code below kicking in build a + * static library of libpng then link the DLL against that. + */ +#ifndef PNG_BUILD_DLL +# ifdef DLL_EXPORT + /* This is set by libtool when files are compiled for a DLL; libtool + * always compiles twice, even on systems where it isn't necessary. Set + * PNG_BUILD_DLL in case it is necessary: + */ +# define PNG_BUILD_DLL +# else +# ifdef _WINDLL + /* This is set by the Microsoft Visual Studio IDE in projects that + * build a DLL. It can't easily be removed from those projects (it + * isn't visible in the Visual Studio UI) so it is a fairly reliable + * indication that PNG_IMPEXP needs to be set to the DLL export + * attributes. + */ +# define PNG_BUILD_DLL +# else +# ifdef __DLL__ + /* This is set by the Borland C system when compiling for a DLL + * (as above.) + */ +# define PNG_BUILD_DLL +# else + /* Add additional compiler cases here. */ +# endif +# endif +# endif +#endif /* Setting PNG_BUILD_DLL if required */ + +/* See pngconf.h for more details: the builder of the library may set this on + * the command line to the right thing for the specific compilation system or it + * may be automagically set above (at present we know of no system where it does + * need to be set on the command line.) + * + * PNG_IMPEXP must be set here when building the library to prevent pngconf.h + * setting it to the "import" setting for a DLL build. + */ +#ifndef PNG_IMPEXP +# ifdef PNG_BUILD_DLL +# define PNG_IMPEXP PNG_DLL_EXPORT +# else + /* Not building a DLL, or the DLL doesn't require specific export + * definitions. + */ +# define PNG_IMPEXP +# endif +#endif + +/* No warnings for private or deprecated functions in the build: */ +#ifndef PNG_DEPRECATED +# define PNG_DEPRECATED +#endif +#ifndef PNG_PRIVATE +# define PNG_PRIVATE +#endif + +/* Symbol preprocessing support. + * + * To enable listing global, but internal, symbols the following macros should + * always be used to declare an extern data or function object in this file. + */ +#ifndef PNG_INTERNAL_DATA +# define PNG_INTERNAL_DATA(type, name, array) extern type name array +#endif + +#ifndef PNG_INTERNAL_FUNCTION +# define PNG_INTERNAL_FUNCTION(type, name, args, attributes)\ + extern PNG_FUNCTION(type, name, args, PNG_EMPTY attributes) +#endif + +/* If floating or fixed point APIs are disabled they may still be compiled + * internally. To handle this make sure they are declared as the appropriate + * internal extern function (otherwise the symbol prefixing stuff won't work and + * the functions will be used without definitions.) + * + * NOTE: although all the API functions are declared here they are not all + * actually built! Because the declarations are still made it is necessary to + * fake out types that they depend on. + */ +#ifndef PNG_FP_EXPORT +# ifndef PNG_FLOATING_POINT_SUPPORTED +# define PNG_FP_EXPORT(ordinal, type, name, args)\ + PNG_INTERNAL_FUNCTION(type, name, args, PNG_EMPTY); +# ifndef PNG_VERSION_INFO_ONLY + typedef struct png_incomplete png_double; + typedef png_double* png_doublep; + typedef const png_double* png_const_doublep; + typedef png_double** png_doublepp; +# endif +# endif +#endif +#ifndef PNG_FIXED_EXPORT +# ifndef PNG_FIXED_POINT_SUPPORTED +# define PNG_FIXED_EXPORT(ordinal, type, name, args)\ + PNG_INTERNAL_FUNCTION(type, name, args, PNG_EMPTY); +# endif +#endif + +#include "png.h" + +/* pngconf.h does not set PNG_DLL_EXPORT unless it is required, so: */ +#ifndef PNG_DLL_EXPORT +# define PNG_DLL_EXPORT +#endif + +/* SECURITY and SAFETY: + * + * By default libpng is built without any internal limits on image size, + * individual heap (png_malloc) allocations or the total amount of memory used. + * If PNG_SAFE_LIMITS_SUPPORTED is defined, however, the limits below are used + * (unless individually overridden). These limits are believed to be fairly + * safe, but builders of secure systems should verify the values against the + * real system capabilities. + */ +#ifdef PNG_SAFE_LIMITS_SUPPORTED + /* 'safe' limits */ +# ifndef PNG_USER_WIDTH_MAX +# define PNG_USER_WIDTH_MAX 1000000 +# endif +# ifndef PNG_USER_HEIGHT_MAX +# define PNG_USER_HEIGHT_MAX 1000000 +# endif +# ifndef PNG_USER_CHUNK_CACHE_MAX +# define PNG_USER_CHUNK_CACHE_MAX 128 +# endif +# ifndef PNG_USER_CHUNK_MALLOC_MAX +# define PNG_USER_CHUNK_MALLOC_MAX 8000000 +# endif +#else + /* values for no limits */ +# ifndef PNG_USER_WIDTH_MAX +# define PNG_USER_WIDTH_MAX 0x7fffffff +# endif +# ifndef PNG_USER_HEIGHT_MAX +# define PNG_USER_HEIGHT_MAX 0x7fffffff +# endif +# ifndef PNG_USER_CHUNK_CACHE_MAX +# define PNG_USER_CHUNK_CACHE_MAX 0 +# endif +# ifndef PNG_USER_CHUNK_MALLOC_MAX +# define PNG_USER_CHUNK_MALLOC_MAX 0 +# endif +#endif + +/* Moved to pngpriv.h at libpng-1.5.0 */ +/* NOTE: some of these may have been used in external applications as + * these definitions were exposed in pngconf.h prior to 1.5. + */ + +/* If you are running on a machine where you cannot allocate more + * than 64K of memory at once, uncomment this. While libpng will not + * normally need that much memory in a chunk (unless you load up a very + * large file), zlib needs to know how big of a chunk it can use, and + * libpng thus makes sure to check any memory allocation to verify it + * will fit into memory. + * + * zlib provides 'MAXSEG_64K' which, if defined, indicates the + * same limit and pngconf.h (already included) sets the limit + * if certain operating systems are detected. + */ +#if defined(MAXSEG_64K) && !defined(PNG_MAX_MALLOC_64K) +# define PNG_MAX_MALLOC_64K +#endif + +#ifndef PNG_UNUSED +/* Unused formal parameter warnings are silenced using the following macro + * which is expected to have no bad effects on performance (optimizing + * compilers will probably remove it entirely). Note that if you replace + * it with something other than whitespace, you must include the terminating + * semicolon. + */ +# define PNG_UNUSED(param) (void)param; +#endif + +/* Just a little check that someone hasn't tried to define something + * contradictory. + */ +#if (PNG_ZBUF_SIZE > 65536L) && defined(PNG_MAX_MALLOC_64K) +# undef PNG_ZBUF_SIZE +# define PNG_ZBUF_SIZE 65536L +#endif + +/* If warnings or errors are turned off the code is disabled or redirected here. + * From 1.5.4 functions have been added to allow very limited formatting of + * error and warning messages - this code will also be disabled here. + */ +#ifdef PNG_WARNINGS_SUPPORTED +# define PNG_WARNING_PARAMETERS(p) png_warning_parameters p; +#else +# define png_warning(s1,s2) ((void)(s1)) +# define png_chunk_warning(s1,s2) ((void)(s1)) +# define png_warning_parameter(p,number,string) ((void)0) +# define png_warning_parameter_unsigned(p,number,format,value) ((void)0) +# define png_warning_parameter_signed(p,number,format,value) ((void)0) +# define png_formatted_warning(pp,p,message) ((void)(pp)) +# define PNG_WARNING_PARAMETERS(p) +#endif +#ifndef PNG_ERROR_TEXT_SUPPORTED +# define png_error(s1,s2) png_err(s1) +# define png_chunk_error(s1,s2) png_err(s1) +# define png_fixed_error(s1,s2) png_err(s1) +#endif + +/* C allows up-casts from (void*) to any pointer and (const void*) to any + * pointer to a const object. C++ regards this as a type error and requires an + * explicit, static, cast and provides the static_cast<> rune to ensure that + * const is not cast away. + */ +#ifdef __cplusplus +# define png_voidcast(type, value) static_cast(value) +# define png_constcast(type, value) const_cast(value) +# define png_aligncast(type, value) \ + static_cast(static_cast(value)) +# define png_aligncastconst(type, value) \ + static_cast(static_cast(value)) +#else +# define png_voidcast(type, value) (value) +# define png_constcast(type, value) ((type)(value)) +# define png_aligncast(type, value) ((void*)(value)) +# define png_aligncastconst(type, value) ((const void*)(value)) +#endif /* __cplusplus */ + +/* Some fixed point APIs are still required even if not exported because + * they get used by the corresponding floating point APIs. This magic + * deals with this: + */ +#ifdef PNG_FIXED_POINT_SUPPORTED +# define PNGFAPI PNGAPI +#else +# define PNGFAPI /* PRIVATE */ +#endif + +#ifndef PNG_VERSION_INFO_ONLY +/* Other defines specific to compilers can go here. Try to keep + * them inside an appropriate ifdef/endif pair for portability. + */ +#if defined(PNG_FLOATING_POINT_SUPPORTED) ||\ + defined(PNG_FLOATING_ARITHMETIC_SUPPORTED) + /* png.c requires the following ANSI-C constants if the conversion of + * floating point to ASCII is implemented therein: + * + * DBL_DIG Maximum number of decimal digits (can be set to any constant) + * DBL_MIN Smallest normalized fp number (can be set to an arbitrary value) + * DBL_MAX Maximum floating point number (can be set to an arbitrary value) + */ +# include + +# if (defined(__MWERKS__) && defined(macintosh)) || defined(applec) || \ + defined(THINK_C) || defined(__SC__) || defined(TARGET_OS_MAC) + /* We need to check that hasn't already been included earlier + * as it seems it doesn't agree with , yet we should really use + * if possible. + */ +# if !defined(__MATH_H__) && !defined(__MATH_H) && !defined(__cmath__) +# include +# endif +# else +# include +# endif +# if defined(_AMIGA) && defined(__SASC) && defined(_M68881) + /* Amiga SAS/C: We must include builtin FPU functions when compiling using + * MATH=68881 + */ +# include +# endif +#endif + +/* This provides the non-ANSI (far) memory allocation routines. */ +#if defined(__TURBOC__) && defined(__MSDOS__) +# include +# include +#endif + +#if defined(WIN32) || defined(_Windows) || defined(_WINDOWS) || \ + defined(_WIN32) || defined(__WIN32__) +# include /* defines _WINDOWS_ macro */ +#endif +#endif /* PNG_VERSION_INFO_ONLY */ + +/* Moved here around 1.5.0beta36 from pngconf.h */ +/* Users may want to use these so they are not private. Any library + * functions that are passed far data must be model-independent. + */ + +/* Memory model/platform independent fns */ +#ifndef PNG_ABORT +# ifdef _WINDOWS_ +# define PNG_ABORT() ExitProcess(0) +# else +# define PNG_ABORT() abort() +# endif +#endif + +/* These macros may need to be architecture dependent. */ +#define PNG_ALIGN_NONE 0 /* do not use data alignment */ +#define PNG_ALIGN_ALWAYS 1 /* assume unaligned accesses are OK */ +#ifdef offsetof +# define PNG_ALIGN_OFFSET 2 /* use offsetof to determine alignment */ +#else +# define PNG_ALIGN_OFFSET -1 /* prevent the use of this */ +#endif +#define PNG_ALIGN_SIZE 3 /* use sizeof to determine alignment */ + +#ifndef PNG_ALIGN_TYPE + /* Default to using aligned access optimizations and requiring alignment to a + * multiple of the data type size. Override in a compiler specific fashion + * if necessary by inserting tests here: + */ +# define PNG_ALIGN_TYPE PNG_ALIGN_SIZE +#endif + +#if PNG_ALIGN_TYPE == PNG_ALIGN_SIZE + /* This is used because in some compiler implementations non-aligned + * structure members are supported, so the offsetof approach below fails. + * Set PNG_ALIGN_SIZE=0 for compiler combinations where unaligned access + * is good for performance. Do not do this unless you have tested the result + * and understand it. + */ +# define png_alignof(type) (sizeof (type)) +#else +# if PNG_ALIGN_TYPE == PNG_ALIGN_OFFSET +# define png_alignof(type) offsetof(struct{char c; type t;}, t) +# else +# if PNG_ALIGN_TYPE == PNG_ALIGN_ALWAYS +# define png_alignof(type) (1) +# endif + /* Else leave png_alignof undefined to prevent use thereof */ +# endif +#endif + +/* This implicitly assumes alignment is always to a power of 2. */ +#ifdef png_alignof +# define png_isaligned(ptr, type)\ + ((((const char*)ptr-(const char*)0) & (png_alignof(type)-1)) == 0) +#else +# define png_isaligned(ptr, type) 0 +#endif + +/* End of memory model/platform independent support */ +/* End of 1.5.0beta36 move from pngconf.h */ + +/* CONSTANTS and UTILITY MACROS + * These are used internally by libpng and not exposed in the API + */ + +/* Various modes of operation. Note that after an init, mode is set to + * zero automatically when the structure is created. Three of these + * are defined in png.h because they need to be visible to applications + * that call png_set_unknown_chunk(). + */ +/* #define PNG_HAVE_IHDR 0x01 (defined in png.h) */ +/* #define PNG_HAVE_PLTE 0x02 (defined in png.h) */ +#define PNG_HAVE_IDAT 0x04 +/* #define PNG_AFTER_IDAT 0x08 (defined in png.h) */ +#define PNG_HAVE_IEND 0x10 + /* 0x20 (unused) */ + /* 0x40 (unused) */ + /* 0x80 (unused) */ +#define PNG_HAVE_CHUNK_HEADER 0x100 +#define PNG_WROTE_tIME 0x200 +#define PNG_WROTE_INFO_BEFORE_PLTE 0x400 +#define PNG_BACKGROUND_IS_GRAY 0x800 +#define PNG_HAVE_PNG_SIGNATURE 0x1000 +#define PNG_HAVE_CHUNK_AFTER_IDAT 0x2000 /* Have another chunk after IDAT */ + /* 0x4000 (unused) */ +#define PNG_IS_READ_STRUCT 0x8000 /* Else is a write struct */ + +/* Flags for the transformations the PNG library does on the image data */ +#define PNG_BGR 0x0001 +#define PNG_INTERLACE 0x0002 +#define PNG_PACK 0x0004 +#define PNG_SHIFT 0x0008 +#define PNG_SWAP_BYTES 0x0010 +#define PNG_INVERT_MONO 0x0020 +#define PNG_QUANTIZE 0x0040 +#define PNG_COMPOSE 0x0080 /* Was PNG_BACKGROUND */ +#define PNG_BACKGROUND_EXPAND 0x0100 +#define PNG_EXPAND_16 0x0200 /* Added to libpng 1.5.2 */ +#define PNG_16_TO_8 0x0400 /* Becomes 'chop' in 1.5.4 */ +#define PNG_RGBA 0x0800 +#define PNG_EXPAND 0x1000 +#define PNG_GAMMA 0x2000 +#define PNG_GRAY_TO_RGB 0x4000 +#define PNG_FILLER 0x8000 +#define PNG_PACKSWAP 0x10000 +#define PNG_SWAP_ALPHA 0x20000 +#define PNG_STRIP_ALPHA 0x40000 +#define PNG_INVERT_ALPHA 0x80000 +#define PNG_USER_TRANSFORM 0x100000 +#define PNG_RGB_TO_GRAY_ERR 0x200000 +#define PNG_RGB_TO_GRAY_WARN 0x400000 +#define PNG_RGB_TO_GRAY 0x600000 /* two bits, RGB_TO_GRAY_ERR|WARN */ +#define PNG_ENCODE_ALPHA 0x800000 /* Added to libpng-1.5.4 */ +#define PNG_ADD_ALPHA 0x1000000 /* Added to libpng-1.2.7 */ +#define PNG_EXPAND_tRNS 0x2000000 /* Added to libpng-1.2.9 */ +#define PNG_SCALE_16_TO_8 0x4000000 /* Added to libpng-1.5.4 */ + /* 0x8000000 unused */ + /* 0x10000000 unused */ + /* 0x20000000 unused */ + /* 0x40000000 unused */ +/* Flags for png_create_struct */ +#define PNG_STRUCT_PNG 0x0001 +#define PNG_STRUCT_INFO 0x0002 + +/* Scaling factor for filter heuristic weighting calculations */ +#define PNG_WEIGHT_FACTOR (1<<(PNG_WEIGHT_SHIFT)) +#define PNG_COST_FACTOR (1<<(PNG_COST_SHIFT)) + +/* Flags for the png_ptr->flags rather than declaring a byte for each one */ +#define PNG_FLAG_ZLIB_CUSTOM_STRATEGY 0x0001 +#define PNG_FLAG_ZSTREAM_INITIALIZED 0x0002 /* Added to libpng-1.6.0 */ + /* 0x0004 unused */ +#define PNG_FLAG_ZSTREAM_ENDED 0x0008 /* Added to libpng-1.6.0 */ + /* 0x0010 unused */ + /* 0x0020 unused */ +#define PNG_FLAG_ROW_INIT 0x0040 +#define PNG_FLAG_FILLER_AFTER 0x0080 +#define PNG_FLAG_CRC_ANCILLARY_USE 0x0100 +#define PNG_FLAG_CRC_ANCILLARY_NOWARN 0x0200 +#define PNG_FLAG_CRC_CRITICAL_USE 0x0400 +#define PNG_FLAG_CRC_CRITICAL_IGNORE 0x0800 +#define PNG_FLAG_ASSUME_sRGB 0x1000 /* Added to libpng-1.5.4 */ +#define PNG_FLAG_OPTIMIZE_ALPHA 0x2000 /* Added to libpng-1.5.4 */ +#define PNG_FLAG_DETECT_UNINITIALIZED 0x4000 /* Added to libpng-1.5.4 */ +/* #define PNG_FLAG_KEEP_UNKNOWN_CHUNKS 0x8000 */ +/* #define PNG_FLAG_KEEP_UNSAFE_CHUNKS 0x10000 */ +#define PNG_FLAG_LIBRARY_MISMATCH 0x20000 +#define PNG_FLAG_STRIP_ERROR_NUMBERS 0x40000 +#define PNG_FLAG_STRIP_ERROR_TEXT 0x80000 +#define PNG_FLAG_BENIGN_ERRORS_WARN 0x100000 /* Added to libpng-1.4.0 */ +#define PNG_FLAG_APP_WARNINGS_WARN 0x200000 /* Added to libpng-1.6.0 */ +#define PNG_FLAG_APP_ERRORS_WARN 0x400000 /* Added to libpng-1.6.0 */ + /* 0x800000 unused */ + /* 0x1000000 unused */ + /* 0x2000000 unused */ + /* 0x4000000 unused */ + /* 0x8000000 unused */ + /* 0x10000000 unused */ + /* 0x20000000 unused */ + /* 0x40000000 unused */ + +#define PNG_FLAG_CRC_ANCILLARY_MASK (PNG_FLAG_CRC_ANCILLARY_USE | \ + PNG_FLAG_CRC_ANCILLARY_NOWARN) + +#define PNG_FLAG_CRC_CRITICAL_MASK (PNG_FLAG_CRC_CRITICAL_USE | \ + PNG_FLAG_CRC_CRITICAL_IGNORE) + +#define PNG_FLAG_CRC_MASK (PNG_FLAG_CRC_ANCILLARY_MASK | \ + PNG_FLAG_CRC_CRITICAL_MASK) + +/* Save typing and make code easier to understand */ + +#define PNG_COLOR_DIST(c1, c2) (abs((int)((c1).red) - (int)((c2).red)) + \ + abs((int)((c1).green) - (int)((c2).green)) + \ + abs((int)((c1).blue) - (int)((c2).blue))) + +/* Added to libpng-1.6.0: scale a 16-bit value in the range 0..65535 to 0..255 + * by dividing by 257 *with rounding*. This macro is exact for the given range. + * See the discourse in pngrtran.c png_do_scale_16_to_8. The values in the + * macro were established by experiment (modifying the added value). The macro + * has a second variant that takes a value already scaled by 255 and divides by + * 65535 - this has a maximum error of .502. Over the range 0..65535*65535 it + * only gives off-by-one errors and only for 0.5% (1 in 200) of the values. + */ +#define PNG_DIV65535(v24) (((v24) + 32895) >> 16) +#define PNG_DIV257(v16) PNG_DIV65535((png_uint_32)(v16) * 255) + +/* Added to libpng-1.2.6 JB */ +#define PNG_ROWBYTES(pixel_bits, width) \ + ((pixel_bits) >= 8 ? \ + ((png_size_t)(width) * (((png_size_t)(pixel_bits)) >> 3)) : \ + (( ((png_size_t)(width) * ((png_size_t)(pixel_bits))) + 7) >> 3) ) + +/* PNG_OUT_OF_RANGE returns true if value is outside the range + * ideal-delta..ideal+delta. Each argument is evaluated twice. + * "ideal" and "delta" should be constants, normally simple + * integers, "value" a variable. Added to libpng-1.2.6 JB + */ +#define PNG_OUT_OF_RANGE(value, ideal, delta) \ + ( (value) < (ideal)-(delta) || (value) > (ideal)+(delta) ) + +/* Conversions between fixed and floating point, only defined if + * required (to make sure the code doesn't accidentally use float + * when it is supposedly disabled.) + */ +#ifdef PNG_FLOATING_POINT_SUPPORTED +/* The floating point conversion can't overflow, though it can and + * does lose accuracy relative to the original fixed point value. + * In practice this doesn't matter because png_fixed_point only + * stores numbers with very low precision. The png_ptr and s + * arguments are unused by default but are there in case error + * checking becomes a requirement. + */ +#define png_float(png_ptr, fixed, s) (.00001 * (fixed)) + +/* The fixed point conversion performs range checking and evaluates + * its argument multiple times, so must be used with care. The + * range checking uses the PNG specification values for a signed + * 32 bit fixed point value except that the values are deliberately + * rounded-to-zero to an integral value - 21474 (21474.83 is roughly + * (2^31-1) * 100000). 's' is a string that describes the value being + * converted. + * + * NOTE: this macro will raise a png_error if the range check fails, + * therefore it is normally only appropriate to use this on values + * that come from API calls or other sources where an out of range + * error indicates a programming error, not a data error! + * + * NOTE: by default this is off - the macro is not used - because the + * function call saves a lot of code. + */ +#ifdef PNG_FIXED_POINT_MACRO_SUPPORTED +#define png_fixed(png_ptr, fp, s) ((fp) <= 21474 && (fp) >= -21474 ?\ + ((png_fixed_point)(100000 * (fp))) : (png_fixed_error(png_ptr, s),0)) +#endif +/* else the corresponding function is defined below, inside the scope of the + * cplusplus test. + */ +#endif + +/* Constants for known chunk types. If you need to add a chunk, define the name + * here. For historical reasons these constants have the form png_; i.e. + * the prefix is lower case. Please use decimal values as the parameters to + * match the ISO PNG specification and to avoid relying on the C locale + * interpretation of character values. + * + * Prior to 1.5.6 these constants were strings, as of 1.5.6 png_uint_32 values + * are computed and a new macro (PNG_STRING_FROM_CHUNK) added to allow a string + * to be generated if required. + * + * PNG_32b correctly produces a value shifted by up to 24 bits, even on + * architectures where (int) is only 16 bits. + */ +#define PNG_32b(b,s) ((png_uint_32)(b) << (s)) +#define PNG_U32(b1,b2,b3,b4) \ + (PNG_32b(b1,24) | PNG_32b(b2,16) | PNG_32b(b3,8) | PNG_32b(b4,0)) + +/* Constants for known chunk types. + * + * MAINTAINERS: If you need to add a chunk, define the name here. + * For historical reasons these constants have the form png_; i.e. + * the prefix is lower case. Please use decimal values as the parameters to + * match the ISO PNG specification and to avoid relying on the C locale + * interpretation of character values. Please keep the list sorted. + * + * Notice that PNG_U32 is used to define a 32-bit value for the 4 byte chunk + * type. In fact the specification does not express chunk types this way, + * however using a 32-bit value means that the chunk type can be read from the + * stream using exactly the same code as used for a 32-bit unsigned value and + * can be examined far more efficiently (using one arithmetic compare). + * + * Prior to 1.5.6 the chunk type constants were expressed as C strings. The + * libpng API still uses strings for 'unknown' chunks and a macro, + * PNG_STRING_FROM_CHUNK, allows a string to be generated if required. Notice + * that for portable code numeric values must still be used; the string "IHDR" + * is not portable and neither is PNG_U32('I', 'H', 'D', 'R'). + * + * In 1.7.0 the definitions will be made public in png.h to avoid having to + * duplicate the same definitions in application code. + */ +#define png_IDAT PNG_U32( 73, 68, 65, 84) +#define png_IEND PNG_U32( 73, 69, 78, 68) +#define png_IHDR PNG_U32( 73, 72, 68, 82) +#define png_PLTE PNG_U32( 80, 76, 84, 69) +#define png_bKGD PNG_U32( 98, 75, 71, 68) +#define png_cHRM PNG_U32( 99, 72, 82, 77) +#define png_fRAc PNG_U32(102, 82, 65, 99) /* registered, not defined */ +#define png_gAMA PNG_U32(103, 65, 77, 65) +#define png_gIFg PNG_U32(103, 73, 70, 103) +#define png_gIFt PNG_U32(103, 73, 70, 116) /* deprecated */ +#define png_gIFx PNG_U32(103, 73, 70, 120) +#define png_hIST PNG_U32(104, 73, 83, 84) +#define png_iCCP PNG_U32(105, 67, 67, 80) +#define png_iTXt PNG_U32(105, 84, 88, 116) +#define png_oFFs PNG_U32(111, 70, 70, 115) +#define png_pCAL PNG_U32(112, 67, 65, 76) +#define png_pHYs PNG_U32(112, 72, 89, 115) +#define png_sBIT PNG_U32(115, 66, 73, 84) +#define png_sCAL PNG_U32(115, 67, 65, 76) +#define png_sPLT PNG_U32(115, 80, 76, 84) +#define png_sRGB PNG_U32(115, 82, 71, 66) +#define png_sTER PNG_U32(115, 84, 69, 82) +#define png_tEXt PNG_U32(116, 69, 88, 116) +#define png_tIME PNG_U32(116, 73, 77, 69) +#define png_tRNS PNG_U32(116, 82, 78, 83) +#define png_zTXt PNG_U32(122, 84, 88, 116) + +/* The following will work on (signed char*) strings, whereas the get_uint_32 + * macro will fail on top-bit-set values because of the sign extension. + */ +#define PNG_CHUNK_FROM_STRING(s)\ + PNG_U32(0xff&(s)[0], 0xff&(s)[1], 0xff&(s)[2], 0xff&(s)[3]) + +/* This uses (char), not (png_byte) to avoid warnings on systems where (char) is + * signed and the argument is a (char[]) This macro will fail miserably on + * systems where (char) is more than 8 bits. + */ +#define PNG_STRING_FROM_CHUNK(s,c)\ + (void)(((char*)(s))[0]=(char)((c)>>24), ((char*)(s))[1]=(char)((c)>>16),\ + ((char*)(s))[2]=(char)((c)>>8), ((char*)(s))[3]=(char)((c))) + +/* Do the same but terminate with a null character. */ +#define PNG_CSTRING_FROM_CHUNK(s,c)\ + (void)(PNG_STRING_FROM_CHUNK(s,c), ((char*)(s))[4] = 0) + +/* Test on flag values as defined in the spec (section 5.4): */ +#define PNG_CHUNK_ANCILLARY(c) (1 & ((c) >> 29)) +#define PNG_CHUNK_CRITICAL(c) (!PNG_CHUNK_ANCILLARY(c)) +#define PNG_CHUNK_PRIVATE(c) (1 & ((c) >> 21)) +#define PNG_CHUNK_RESERVED(c) (1 & ((c) >> 13)) +#define PNG_CHUNK_SAFE_TO_COPY(c) (1 & ((c) >> 5)) + +/* Gamma values (new at libpng-1.5.4): */ +#define PNG_GAMMA_MAC_OLD 151724 /* Assume '1.8' is really 2.2/1.45! */ +#define PNG_GAMMA_MAC_INVERSE 65909 +#define PNG_GAMMA_sRGB_INVERSE 45455 + +/* Almost everything below is C specific; the #defines above can be used in + * non-C code (so long as it is C-preprocessed) the rest of this stuff cannot. + */ +#ifndef PNG_VERSION_INFO_ONLY + +#include "pngstruct.h" +#include "pnginfo.h" + +/* Validate the include paths - the include path used to generate pnglibconf.h + * must match that used in the build, or we must be using pnglibconf.h.prebuilt: + */ +#if PNG_ZLIB_VERNUM != 0 && PNG_ZLIB_VERNUM != ZLIB_VERNUM +# error ZLIB_VERNUM != PNG_ZLIB_VERNUM \ + "-I (include path) error: see the notes in pngpriv.h" + /* This means that when pnglibconf.h was built the copy of zlib.h that it + * used is not the same as the one being used here. Because the build of + * libpng makes decisions to use inflateInit2 and inflateReset2 based on the + * zlib version number and because this affects handling of certain broken + * PNG files the -I directives must match. + * + * The most likely explanation is that you passed a -I in CFLAGS, this will + * not work; all the preprocessor directories and in particular all the -I + * directives must be in CPPFLAGS. + */ +#endif + +/* This is used for 16 bit gamma tables -- only the top level pointers are + * const; this could be changed: + */ +typedef const png_uint_16p * png_const_uint_16pp; + +/* Added to libpng-1.5.7: sRGB conversion tables */ +#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\ + defined(PNG_SIMPLIFIED_WRITE_SUPPORTED) +#ifdef PNG_SIMPLIFIED_READ_SUPPORTED +PNG_INTERNAL_DATA(const png_uint_16, png_sRGB_table, [256]); + /* Convert from an sRGB encoded value 0..255 to a 16-bit linear value, + * 0..65535. This table gives the closest 16-bit answers (no errors). + */ +#endif + +PNG_INTERNAL_DATA(const png_uint_16, png_sRGB_base, [512]); +PNG_INTERNAL_DATA(const png_byte, png_sRGB_delta, [512]); + +#define PNG_sRGB_FROM_LINEAR(linear) ((png_byte)((png_sRGB_base[(linear)>>15] +\ + ((((linear)&0x7fff)*png_sRGB_delta[(linear)>>15])>>12)) >> 8)) + /* Given a value 'linear' in the range 0..255*65535 calculate the 8-bit sRGB + * encoded value with maximum error 0.646365. Note that the input is not a + * 16-bit value; it has been multiplied by 255! */ +#endif /* PNG_SIMPLIFIED_READ/WRITE */ + + +/* Inhibit C++ name-mangling for libpng functions but not for system calls. */ +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/* Internal functions; these are not exported from a DLL however because they + * are used within several of the C source files they have to be C extern. + * + * All of these functions must be declared with PNG_INTERNAL_FUNCTION. + */ + +/* Zlib support */ +#define PNG_UNEXPECTED_ZLIB_RETURN (-7) +PNG_INTERNAL_FUNCTION(void, png_zstream_error,(png_structrp png_ptr, int ret), + PNG_EMPTY); + /* Used by the zlib handling functions to ensure that z_stream::msg is always + * set before they return. + */ + +#ifdef PNG_WRITE_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_free_buffer_list,(png_structrp png_ptr, + png_compression_bufferp *list),PNG_EMPTY); + /* Free the buffer list used by the compressed write code. */ +#endif + +#if defined(PNG_FLOATING_POINT_SUPPORTED) && \ + !defined(PNG_FIXED_POINT_MACRO_SUPPORTED) && \ + (defined(PNG_gAMA_SUPPORTED) || defined(PNG_cHRM_SUPPORTED) || \ + defined(PNG_sCAL_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) || \ + defined(PNG_READ_RGB_TO_GRAY_SUPPORTED)) || \ + (defined(PNG_sCAL_SUPPORTED) && \ + defined(PNG_FLOATING_ARITHMETIC_SUPPORTED)) +PNG_INTERNAL_FUNCTION(png_fixed_point,png_fixed,(png_const_structrp png_ptr, + double fp, png_const_charp text),PNG_EMPTY); +#endif + +/* Check the user version string for compatibility, returns false if the version + * numbers aren't compatible. + */ +PNG_INTERNAL_FUNCTION(int,png_user_version_check,(png_structrp png_ptr, + png_const_charp user_png_ver),PNG_EMPTY); + +/* Internal base allocator - no messages, NULL on failure to allocate. This + * does, however, call the application provided allocator and that could call + * png_error (although that would be a bug in the application implementation.) + */ +PNG_INTERNAL_FUNCTION(png_voidp,png_malloc_base,(png_const_structrp png_ptr, + png_alloc_size_t size),PNG_ALLOCATED); + +#if defined(PNG_TEXT_SUPPORTED) || defined(PNG_sPLT_SUPPORTED) ||\ + defined(PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED) +/* Internal array allocator, outputs no error or warning messages on failure, + * just returns NULL. + */ +PNG_INTERNAL_FUNCTION(png_voidp,png_malloc_array,(png_const_structrp png_ptr, + int nelements, size_t element_size),PNG_ALLOCATED); + +/* The same but an existing array is extended by add_elements. This function + * also memsets the new elements to 0 and copies the old elements. The old + * array is not freed or altered. + */ +PNG_INTERNAL_FUNCTION(png_voidp,png_realloc_array,(png_const_structrp png_ptr, + png_const_voidp array, int old_elements, int add_elements, + size_t element_size),PNG_ALLOCATED); +#endif /* text, sPLT or unknown chunks */ + +/* Magic to create a struct when there is no struct to call the user supplied + * memory allocators. Because error handling has not been set up the memory + * handlers can't safely call png_error, but this is an obscure and undocumented + * restriction so libpng has to assume that the 'free' handler, at least, might + * call png_error. + */ +PNG_INTERNAL_FUNCTION(png_structp,png_create_png_struct, + (png_const_charp user_png_ver, png_voidp error_ptr, png_error_ptr error_fn, + png_error_ptr warn_fn, png_voidp mem_ptr, png_malloc_ptr malloc_fn, + png_free_ptr free_fn),PNG_ALLOCATED); + +/* Free memory from internal libpng struct */ +PNG_INTERNAL_FUNCTION(void,png_destroy_png_struct,(png_structrp png_ptr), + PNG_EMPTY); + +/* Free an allocated jmp_buf (always succeeds) */ +PNG_INTERNAL_FUNCTION(void,png_free_jmpbuf,(png_structrp png_ptr),PNG_EMPTY); + +/* Function to allocate memory for zlib. PNGAPI is disallowed. */ +PNG_INTERNAL_FUNCTION(voidpf,png_zalloc,(voidpf png_ptr, uInt items, uInt size), + PNG_ALLOCATED); + +/* Function to free memory for zlib. PNGAPI is disallowed. */ +PNG_INTERNAL_FUNCTION(void,png_zfree,(voidpf png_ptr, voidpf ptr),PNG_EMPTY); + +/* Next four functions are used internally as callbacks. PNGCBAPI is required + * but not PNG_EXPORT. PNGAPI added at libpng version 1.2.3, changed to + * PNGCBAPI at 1.5.0 + */ + +PNG_INTERNAL_FUNCTION(void PNGCBAPI,png_default_read_data,(png_structp png_ptr, + png_bytep data, png_size_t length),PNG_EMPTY); + +#ifdef PNG_PROGRESSIVE_READ_SUPPORTED +PNG_INTERNAL_FUNCTION(void PNGCBAPI,png_push_fill_buffer,(png_structp png_ptr, + png_bytep buffer, png_size_t length),PNG_EMPTY); +#endif + +PNG_INTERNAL_FUNCTION(void PNGCBAPI,png_default_write_data,(png_structp png_ptr, + png_bytep data, png_size_t length),PNG_EMPTY); + +#ifdef PNG_WRITE_FLUSH_SUPPORTED +# ifdef PNG_STDIO_SUPPORTED +PNG_INTERNAL_FUNCTION(void PNGCBAPI,png_default_flush,(png_structp png_ptr), + PNG_EMPTY); +# endif +#endif + +/* Reset the CRC variable */ +PNG_INTERNAL_FUNCTION(void,png_reset_crc,(png_structrp png_ptr),PNG_EMPTY); + +/* Write the "data" buffer to whatever output you are using */ +PNG_INTERNAL_FUNCTION(void,png_write_data,(png_structrp png_ptr, + png_const_bytep data, png_size_t length),PNG_EMPTY); + +/* Read and check the PNG file signature */ +PNG_INTERNAL_FUNCTION(void,png_read_sig,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); + +/* Read the chunk header (length + type name) */ +PNG_INTERNAL_FUNCTION(png_uint_32,png_read_chunk_header,(png_structrp png_ptr), + PNG_EMPTY); + +/* Read data from whatever input you are using into the "data" buffer */ +PNG_INTERNAL_FUNCTION(void,png_read_data,(png_structrp png_ptr, png_bytep data, + png_size_t length),PNG_EMPTY); + +/* Read bytes into buf, and update png_ptr->crc */ +PNG_INTERNAL_FUNCTION(void,png_crc_read,(png_structrp png_ptr, png_bytep buf, + png_uint_32 length),PNG_EMPTY); + +/* Read "skip" bytes, read the file crc, and (optionally) verify png_ptr->crc */ +PNG_INTERNAL_FUNCTION(int,png_crc_finish,(png_structrp png_ptr, + png_uint_32 skip),PNG_EMPTY); + +/* Read the CRC from the file and compare it to the libpng calculated CRC */ +PNG_INTERNAL_FUNCTION(int,png_crc_error,(png_structrp png_ptr),PNG_EMPTY); + +/* Calculate the CRC over a section of data. Note that we are only + * passing a maximum of 64K on systems that have this as a memory limit, + * since this is the maximum buffer size we can specify. + */ +PNG_INTERNAL_FUNCTION(void,png_calculate_crc,(png_structrp png_ptr, + png_const_bytep ptr, png_size_t length),PNG_EMPTY); + +#ifdef PNG_WRITE_FLUSH_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_flush,(png_structrp png_ptr),PNG_EMPTY); +#endif + +/* Write various chunks */ + +/* Write the IHDR chunk, and update the png_struct with the necessary + * information. + */ +PNG_INTERNAL_FUNCTION(void,png_write_IHDR,(png_structrp png_ptr, + png_uint_32 width, png_uint_32 height, int bit_depth, int color_type, + int compression_method, int filter_method, int interlace_method),PNG_EMPTY); + +PNG_INTERNAL_FUNCTION(void,png_write_PLTE,(png_structrp png_ptr, + png_const_colorp palette, png_uint_32 num_pal),PNG_EMPTY); + +PNG_INTERNAL_FUNCTION(void,png_compress_IDAT,(png_structrp png_ptr, + png_const_bytep row_data, png_alloc_size_t row_data_length, int flush), + PNG_EMPTY); + +PNG_INTERNAL_FUNCTION(void,png_write_IEND,(png_structrp png_ptr),PNG_EMPTY); + +#ifdef PNG_WRITE_gAMA_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_gAMA_fixed,(png_structrp png_ptr, + png_fixed_point file_gamma),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_sBIT_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_sBIT,(png_structrp png_ptr, + png_const_color_8p sbit, int color_type),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_cHRM_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_cHRM_fixed,(png_structrp png_ptr, + const png_xy *xy), PNG_EMPTY); + /* The xy value must have been previously validated */ +#endif + +#ifdef PNG_WRITE_sRGB_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_sRGB,(png_structrp png_ptr, + int intent),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_iCCP_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_iCCP,(png_structrp png_ptr, + png_const_charp name, png_const_bytep profile), PNG_EMPTY); + /* The profile must have been previously validated for correctness, the + * length comes from the first four bytes. Only the base, deflate, + * compression is supported. + */ +#endif + +#ifdef PNG_WRITE_sPLT_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_sPLT,(png_structrp png_ptr, + png_const_sPLT_tp palette),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_tRNS_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_tRNS,(png_structrp png_ptr, + png_const_bytep trans, png_const_color_16p values, int number, + int color_type),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_bKGD_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_bKGD,(png_structrp png_ptr, + png_const_color_16p values, int color_type),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_hIST_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_hIST,(png_structrp png_ptr, + png_const_uint_16p hist, int num_hist),PNG_EMPTY); +#endif + +/* Chunks that have keywords */ +#ifdef PNG_WRITE_tEXt_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_tEXt,(png_structrp png_ptr, + png_const_charp key, png_const_charp text, png_size_t text_len),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_zTXt_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_zTXt,(png_structrp png_ptr, png_const_charp + key, png_const_charp text, png_size_t text_len, int compression),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_iTXt_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_iTXt,(png_structrp png_ptr, + int compression, png_const_charp key, png_const_charp lang, + png_const_charp lang_key, png_const_charp text),PNG_EMPTY); +#endif + +#ifdef PNG_TEXT_SUPPORTED /* Added at version 1.0.14 and 1.2.4 */ +PNG_INTERNAL_FUNCTION(int,png_set_text_2,(png_const_structrp png_ptr, + png_inforp info_ptr, png_const_textp text_ptr, int num_text),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_oFFs_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_oFFs,(png_structrp png_ptr, + png_int_32 x_offset, png_int_32 y_offset, int unit_type),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_pCAL_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_pCAL,(png_structrp png_ptr, + png_charp purpose, png_int_32 X0, png_int_32 X1, int type, int nparams, + png_const_charp units, png_charpp params),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_pHYs_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_pHYs,(png_structrp png_ptr, + png_uint_32 x_pixels_per_unit, png_uint_32 y_pixels_per_unit, + int unit_type),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_tIME_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_tIME,(png_structrp png_ptr, + png_const_timep mod_time),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_sCAL_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_write_sCAL_s,(png_structrp png_ptr, + int unit, png_const_charp width, png_const_charp height),PNG_EMPTY); +#endif + +/* Called when finished processing a row of data */ +PNG_INTERNAL_FUNCTION(void,png_write_finish_row,(png_structrp png_ptr), + PNG_EMPTY); + +/* Internal use only. Called before first row of data */ +PNG_INTERNAL_FUNCTION(void,png_write_start_row,(png_structrp png_ptr), + PNG_EMPTY); + +/* Combine a row of data, dealing with alpha, etc. if requested. 'row' is an + * array of png_ptr->width pixels. If the image is not interlaced or this + * is the final pass this just does a memcpy, otherwise the "display" flag + * is used to determine whether to copy pixels that are not in the current pass. + * + * Because 'png_do_read_interlace' (below) replicates pixels this allows this + * function to achieve the documented 'blocky' appearance during interlaced read + * if display is 1 and the 'sparkle' appearance, where existing pixels in 'row' + * are not changed if they are not in the current pass, when display is 0. + * + * 'display' must be 0 or 1, otherwise the memcpy will be done regardless. + * + * The API always reads from the png_struct row buffer and always assumes that + * it is full width (png_do_read_interlace has already been called.) + * + * This function is only ever used to write to row buffers provided by the + * caller of the relevant libpng API and the row must have already been + * transformed by the read transformations. + * + * The PNG_USE_COMPILE_TIME_MASKS option causes generation of pre-computed + * bitmasks for use within the code, otherwise runtime generated masks are used. + * The default is compile time masks. + */ +#ifndef PNG_USE_COMPILE_TIME_MASKS +# define PNG_USE_COMPILE_TIME_MASKS 1 +#endif +PNG_INTERNAL_FUNCTION(void,png_combine_row,(png_const_structrp png_ptr, + png_bytep row, int display),PNG_EMPTY); + +#ifdef PNG_READ_INTERLACING_SUPPORTED +/* Expand an interlaced row: the 'row_info' describes the pass data that has + * been read in and must correspond to the pixels in 'row', the pixels are + * expanded (moved apart) in 'row' to match the final layout, when doing this + * the pixels are *replicated* to the intervening space. This is essential for + * the correct operation of png_combine_row, above. + */ +PNG_INTERNAL_FUNCTION(void,png_do_read_interlace,(png_row_infop row_info, + png_bytep row, int pass, png_uint_32 transformations),PNG_EMPTY); +#endif + +/* GRR TO DO (2.0 or whenever): simplify other internal calling interfaces */ + +#ifdef PNG_WRITE_INTERLACING_SUPPORTED +/* Grab pixels out of a row for an interlaced pass */ +PNG_INTERNAL_FUNCTION(void,png_do_write_interlace,(png_row_infop row_info, + png_bytep row, int pass),PNG_EMPTY); +#endif + +/* Unfilter a row: check the filter value before calling this, there is no point + * calling it for PNG_FILTER_VALUE_NONE. + */ +PNG_INTERNAL_FUNCTION(void,png_read_filter_row,(png_structrp pp, png_row_infop + row_info, png_bytep row, png_const_bytep prev_row, int filter),PNG_EMPTY); + +PNG_INTERNAL_FUNCTION(void,png_read_filter_row_up_neon,(png_row_infop row_info, + png_bytep row, png_const_bytep prev_row),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_read_filter_row_sub3_neon,(png_row_infop + row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_read_filter_row_sub4_neon,(png_row_infop + row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_read_filter_row_avg3_neon,(png_row_infop + row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_read_filter_row_avg4_neon,(png_row_infop + row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_read_filter_row_paeth3_neon,(png_row_infop + row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_read_filter_row_paeth4_neon,(png_row_infop + row_info, png_bytep row, png_const_bytep prev_row),PNG_EMPTY); + +/* Choose the best filter to use and filter the row data */ +PNG_INTERNAL_FUNCTION(void,png_write_find_filter,(png_structrp png_ptr, + png_row_infop row_info),PNG_EMPTY); + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_read_IDAT_data,(png_structrp png_ptr, + png_bytep output, png_alloc_size_t avail_out),PNG_EMPTY); + /* Read 'avail_out' bytes of data from the IDAT stream. If the output buffer + * is NULL the function checks, instead, for the end of the stream. In this + * case a benign error will be issued if the stream end is not found or if + * extra data has to be consumed. + */ +PNG_INTERNAL_FUNCTION(void,png_read_finish_IDAT,(png_structrp png_ptr), + PNG_EMPTY); + /* This cleans up when the IDAT LZ stream does not end when the last image + * byte is read; there is still some pending input. + */ + +PNG_INTERNAL_FUNCTION(void,png_read_finish_row,(png_structrp png_ptr), + PNG_EMPTY); + /* Finish a row while reading, dealing with interlacing passes, etc. */ +#endif + +/* Initialize the row buffers, etc. */ +PNG_INTERNAL_FUNCTION(void,png_read_start_row,(png_structrp png_ptr),PNG_EMPTY); + +#ifdef PNG_READ_TRANSFORMS_SUPPORTED +/* Optional call to update the users info structure */ +PNG_INTERNAL_FUNCTION(void,png_read_transform_info,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +#endif + +/* These are the functions that do the transformations */ +#ifdef PNG_READ_FILLER_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_read_filler,(png_row_infop row_info, + png_bytep row, png_uint_32 filler, png_uint_32 flags),PNG_EMPTY); +#endif + +#ifdef PNG_READ_SWAP_ALPHA_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_read_swap_alpha,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_SWAP_ALPHA_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_write_swap_alpha,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_read_invert_alpha,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_write_invert_alpha,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#if defined(PNG_WRITE_FILLER_SUPPORTED) || \ + defined(PNG_READ_STRIP_ALPHA_SUPPORTED) +PNG_INTERNAL_FUNCTION(void,png_do_strip_channel,(png_row_infop row_info, + png_bytep row, int at_start),PNG_EMPTY); +#endif + +#ifdef PNG_16BIT_SUPPORTED +#if defined(PNG_READ_SWAP_SUPPORTED) || defined(PNG_WRITE_SWAP_SUPPORTED) +PNG_INTERNAL_FUNCTION(void,png_do_swap,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif +#endif + +#if defined(PNG_READ_PACKSWAP_SUPPORTED) || \ + defined(PNG_WRITE_PACKSWAP_SUPPORTED) +PNG_INTERNAL_FUNCTION(void,png_do_packswap,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED +PNG_INTERNAL_FUNCTION(int,png_do_rgb_to_gray,(png_structrp png_ptr, + png_row_infop row_info, png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_gray_to_rgb,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_READ_PACK_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_unpack,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_READ_SHIFT_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_unshift,(png_row_infop row_info, + png_bytep row, png_const_color_8p sig_bits),PNG_EMPTY); +#endif + +#if defined(PNG_READ_INVERT_SUPPORTED) || defined(PNG_WRITE_INVERT_SUPPORTED) +PNG_INTERNAL_FUNCTION(void,png_do_invert,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_scale_16_to_8,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_chop,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_READ_QUANTIZE_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_quantize,(png_row_infop row_info, + png_bytep row, png_const_bytep palette_lookup, + png_const_bytep quantize_lookup),PNG_EMPTY); + +# ifdef PNG_CORRECT_PALETTE_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_correct_palette,(png_structrp png_ptr, + png_colorp palette, int num_palette),PNG_EMPTY); +# endif +#endif + +#if defined(PNG_READ_BGR_SUPPORTED) || defined(PNG_WRITE_BGR_SUPPORTED) +PNG_INTERNAL_FUNCTION(void,png_do_bgr,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_PACK_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_pack,(png_row_infop row_info, + png_bytep row, png_uint_32 bit_depth),PNG_EMPTY); +#endif + +#ifdef PNG_WRITE_SHIFT_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_shift,(png_row_infop row_info, + png_bytep row, png_const_color_8p bit_depth),PNG_EMPTY); +#endif + +#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) +PNG_INTERNAL_FUNCTION(void,png_do_compose,(png_row_infop row_info, + png_bytep row, png_structrp png_ptr),PNG_EMPTY); +#endif + +#ifdef PNG_READ_GAMMA_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_gamma,(png_row_infop row_info, + png_bytep row, png_structrp png_ptr),PNG_EMPTY); +#endif + +#ifdef PNG_READ_ALPHA_MODE_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_encode_alpha,(png_row_infop row_info, + png_bytep row, png_structrp png_ptr),PNG_EMPTY); +#endif + +#ifdef PNG_READ_EXPAND_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_expand_palette,(png_row_infop row_info, + png_bytep row, png_const_colorp palette, png_const_bytep trans, + int num_trans),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_do_expand,(png_row_infop row_info, + png_bytep row, png_const_color_16p trans_color),PNG_EMPTY); +#endif + +#ifdef PNG_READ_EXPAND_16_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_expand_16,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +/* The following decodes the appropriate chunks, and does error correction, + * then calls the appropriate callback for the chunk if it is valid. + */ + +/* Decode the IHDR chunk */ +PNG_INTERNAL_FUNCTION(void,png_handle_IHDR,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_handle_PLTE,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_handle_IEND,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); + +#ifdef PNG_READ_bKGD_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_bKGD,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_cHRM_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_cHRM,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_gAMA_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_gAMA,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_hIST_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_hIST,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_iCCP_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_iCCP,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif /* PNG_READ_iCCP_SUPPORTED */ + +#ifdef PNG_READ_iTXt_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_iTXt,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_oFFs_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_oFFs,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_pCAL_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_pCAL,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_pHYs_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_pHYs,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_sBIT_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_sBIT,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_sCAL_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_sCAL,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_sPLT_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_sPLT,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif /* PNG_READ_sPLT_SUPPORTED */ + +#ifdef PNG_READ_sRGB_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_sRGB,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_tEXt_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_tEXt,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_tIME_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_tIME,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_tRNS_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_tRNS,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +#ifdef PNG_READ_zTXt_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_zTXt,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +#endif + +PNG_INTERNAL_FUNCTION(void,png_check_chunk_name,(png_structrp png_ptr, + png_uint_32 chunk_name),PNG_EMPTY); + +#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_handle_unknown,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length, int keep),PNG_EMPTY); + /* This is the function that gets called for unknown chunks. The 'keep' + * argument is either non-zero for a known chunk that has been set to be + * handled as unknown or zero for an unknown chunk. By default the function + * just skips the chunk or errors out if it is critical. + */ + +#if defined(PNG_READ_UNKNOWN_CHUNKS_SUPPORTED) ||\ + defined(PNG_HANDLE_AS_UNKNOWN_SUPPORTED) +PNG_INTERNAL_FUNCTION(int,png_chunk_unknown_handling, + (png_const_structrp png_ptr, png_uint_32 chunk_name),PNG_EMPTY); + /* Exactly as the API png_handle_as_unknown() except that the argument is a + * 32-bit chunk name, not a string. + */ +#endif /* READ_UNKNOWN_CHUNKS || HANDLE_AS_UNKNOWN */ +#endif /* PNG_SET_UNKNOWN_CHUNKS_SUPPORTED */ + +/* Handle the transformations for reading and writing */ +#ifdef PNG_READ_TRANSFORMS_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_read_transformations,(png_structrp png_ptr, + png_row_infop row_info),PNG_EMPTY); +#endif +#ifdef PNG_WRITE_TRANSFORMS_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_write_transformations,(png_structrp png_ptr, + png_row_infop row_info),PNG_EMPTY); +#endif + +#ifdef PNG_READ_TRANSFORMS_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_init_read_transformations,(png_structrp png_ptr), + PNG_EMPTY); +#endif + +#ifdef PNG_PROGRESSIVE_READ_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_push_read_chunk,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_read_sig,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_check_crc,(png_structrp png_ptr),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_crc_skip,(png_structrp png_ptr, + png_uint_32 length),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_crc_finish,(png_structrp png_ptr), + PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_save_buffer,(png_structrp png_ptr), + PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_restore_buffer,(png_structrp png_ptr, + png_bytep buffer, png_size_t buffer_length),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_read_IDAT,(png_structrp png_ptr),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_process_IDAT_data,(png_structrp png_ptr, + png_bytep buffer, png_size_t buffer_length),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_process_row,(png_structrp png_ptr), + PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_handle_unknown,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_have_info,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_have_end,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_have_row,(png_structrp png_ptr, + png_bytep row),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_read_end,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_process_some_data,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_read_push_finish_row,(png_structrp png_ptr), + PNG_EMPTY); +# ifdef PNG_READ_tEXt_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_push_handle_tEXt,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_read_tEXt,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +# endif +# ifdef PNG_READ_zTXt_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_push_handle_zTXt,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_read_zTXt,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +# endif +# ifdef PNG_READ_iTXt_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_push_handle_iTXt,(png_structrp png_ptr, + png_inforp info_ptr, png_uint_32 length),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_push_read_iTXt,(png_structrp png_ptr, + png_inforp info_ptr),PNG_EMPTY); +# endif + +#endif /* PNG_PROGRESSIVE_READ_SUPPORTED */ + +#ifdef PNG_MNG_FEATURES_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_do_read_intrapixel,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_do_write_intrapixel,(png_row_infop row_info, + png_bytep row),PNG_EMPTY); +#endif + +/* Added at libpng version 1.6.0 */ +#ifdef PNG_GAMMA_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_colorspace_set_gamma,(png_const_structrp png_ptr, + png_colorspacerp colorspace, png_fixed_point gAMA), PNG_EMPTY); + /* Set the colorspace gamma with a value provided by the application or by + * the gAMA chunk on read. The value will override anything set by an ICC + * profile. + */ + +PNG_INTERNAL_FUNCTION(void,png_colorspace_sync_info,(png_const_structrp png_ptr, + png_inforp info_ptr), PNG_EMPTY); + /* Synchronize the info 'valid' flags with the colorspace */ + +PNG_INTERNAL_FUNCTION(void,png_colorspace_sync,(png_const_structrp png_ptr, + png_inforp info_ptr), PNG_EMPTY); + /* Copy the png_struct colorspace to the info_struct and call the above to + * synchronize the flags. Checks for NULL info_ptr and does nothing. + */ +#endif + +/* Added at libpng version 1.4.0 */ +#ifdef PNG_COLORSPACE_SUPPORTED +/* These internal functions are for maintaining the colorspace structure within + * a png_info or png_struct (or, indeed, both). + */ +PNG_INTERNAL_FUNCTION(int,png_colorspace_set_chromaticities, + (png_const_structrp png_ptr, png_colorspacerp colorspace, const png_xy *xy, + int preferred), PNG_EMPTY); + +PNG_INTERNAL_FUNCTION(int,png_colorspace_set_endpoints, + (png_const_structrp png_ptr, png_colorspacerp colorspace, const png_XYZ *XYZ, + int preferred), PNG_EMPTY); + +#ifdef PNG_sRGB_SUPPORTED +PNG_INTERNAL_FUNCTION(int,png_colorspace_set_sRGB,(png_const_structrp png_ptr, + png_colorspacerp colorspace, int intent), PNG_EMPTY); + /* This does set the colorspace gAMA and cHRM values too, but doesn't set the + * flags to write them, if it returns false there was a problem and an error + * message has already been output (but the colorspace may still need to be + * synced to record the invalid flag). + */ +#endif /* sRGB */ + +#ifdef PNG_iCCP_SUPPORTED +PNG_INTERNAL_FUNCTION(int,png_colorspace_set_ICC,(png_const_structrp png_ptr, + png_colorspacerp colorspace, png_const_charp name, + png_uint_32 profile_length, png_const_bytep profile, int color_type), + PNG_EMPTY); + /* The 'name' is used for information only */ + +/* Routines for checking parts of an ICC profile. */ +PNG_INTERNAL_FUNCTION(int,png_icc_check_length,(png_const_structrp png_ptr, + png_colorspacerp colorspace, png_const_charp name, + png_uint_32 profile_length), PNG_EMPTY); +PNG_INTERNAL_FUNCTION(int,png_icc_check_header,(png_const_structrp png_ptr, + png_colorspacerp colorspace, png_const_charp name, + png_uint_32 profile_length, + png_const_bytep profile /* first 132 bytes only */, int color_type), + PNG_EMPTY); +PNG_INTERNAL_FUNCTION(int,png_icc_check_tag_table,(png_const_structrp png_ptr, + png_colorspacerp colorspace, png_const_charp name, + png_uint_32 profile_length, + png_const_bytep profile /* header plus whole tag table */), PNG_EMPTY); +#ifdef PNG_sRGB_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_icc_set_sRGB,( + png_const_structrp png_ptr, png_colorspacerp colorspace, + png_const_bytep profile, uLong adler), PNG_EMPTY); + /* 'adler' is the Adler32 checksum of the uncompressed profile data. It may + * be zero to indicate that it is not available. It is used, if provided, + * as a fast check on the profile when checking to see if it is sRGB. + */ +#endif +#endif /* iCCP */ + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_colorspace_set_rgb_coefficients, + (png_structrp png_ptr), PNG_EMPTY); + /* Set the rgb_to_gray coefficients from the colorspace Y values */ +#endif /* READ_RGB_TO_GRAY */ +#endif /* COLORSPACE */ + +/* Added at libpng version 1.4.0 */ +PNG_INTERNAL_FUNCTION(void,png_check_IHDR,(png_const_structrp png_ptr, + png_uint_32 width, png_uint_32 height, int bit_depth, + int color_type, int interlace_type, int compression_type, + int filter_type),PNG_EMPTY); + +/* Added at libpng version 1.5.10 */ +#if defined(PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED) || \ + defined(PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED) +PNG_INTERNAL_FUNCTION(void,png_do_check_palette_indexes, + (png_structrp png_ptr, png_row_infop row_info),PNG_EMPTY); +#endif + +#if defined(PNG_FLOATING_POINT_SUPPORTED) && defined(PNG_ERROR_TEXT_SUPPORTED) +PNG_INTERNAL_FUNCTION(void,png_fixed_error,(png_const_structrp png_ptr, + png_const_charp name),PNG_NORETURN); +#endif + +/* Puts 'string' into 'buffer' at buffer[pos], taking care never to overwrite + * the end. Always leaves the buffer nul terminated. Never errors out (and + * there is no error code.) + */ +PNG_INTERNAL_FUNCTION(size_t,png_safecat,(png_charp buffer, size_t bufsize, + size_t pos, png_const_charp string),PNG_EMPTY); + +/* Various internal functions to handle formatted warning messages, currently + * only implemented for warnings. + */ +#if defined(PNG_WARNINGS_SUPPORTED) || defined(PNG_TIME_RFC1123_SUPPORTED) +/* Utility to dump an unsigned value into a buffer, given a start pointer and + * and end pointer (which should point just *beyond* the end of the buffer!) + * Returns the pointer to the start of the formatted string. This utility only + * does unsigned values. + */ +PNG_INTERNAL_FUNCTION(png_charp,png_format_number,(png_const_charp start, + png_charp end, int format, png_alloc_size_t number),PNG_EMPTY); + +/* Convenience macro that takes an array: */ +#define PNG_FORMAT_NUMBER(buffer,format,number) \ + png_format_number(buffer, buffer + (sizeof buffer), format, number) + +/* Suggested size for a number buffer (enough for 64 bits and a sign!) */ +#define PNG_NUMBER_BUFFER_SIZE 24 + +/* These are the integer formats currently supported, the name is formed from + * the standard printf(3) format string. + */ +#define PNG_NUMBER_FORMAT_u 1 /* chose unsigned API! */ +#define PNG_NUMBER_FORMAT_02u 2 +#define PNG_NUMBER_FORMAT_d 1 /* chose signed API! */ +#define PNG_NUMBER_FORMAT_02d 2 +#define PNG_NUMBER_FORMAT_x 3 +#define PNG_NUMBER_FORMAT_02x 4 +#define PNG_NUMBER_FORMAT_fixed 5 /* choose the signed API */ +#endif + +#ifdef PNG_WARNINGS_SUPPORTED +/* New defines and members adding in libpng-1.5.4 */ +# define PNG_WARNING_PARAMETER_SIZE 32 +# define PNG_WARNING_PARAMETER_COUNT 8 /* Maximum 9; see pngerror.c */ + +/* An l-value of this type has to be passed to the APIs below to cache the + * values of the parameters to a formatted warning message. + */ +typedef char png_warning_parameters[PNG_WARNING_PARAMETER_COUNT][ + PNG_WARNING_PARAMETER_SIZE]; + +PNG_INTERNAL_FUNCTION(void,png_warning_parameter,(png_warning_parameters p, + int number, png_const_charp string),PNG_EMPTY); + /* Parameters are limited in size to PNG_WARNING_PARAMETER_SIZE characters, + * including the trailing '\0'. + */ +PNG_INTERNAL_FUNCTION(void,png_warning_parameter_unsigned, + (png_warning_parameters p, int number, int format, png_alloc_size_t value), + PNG_EMPTY); + /* Use png_alloc_size_t because it is an unsigned type as big as any we + * need to output. Use the following for a signed value. + */ +PNG_INTERNAL_FUNCTION(void,png_warning_parameter_signed, + (png_warning_parameters p, int number, int format, png_int_32 value), + PNG_EMPTY); + +PNG_INTERNAL_FUNCTION(void,png_formatted_warning,(png_const_structrp png_ptr, + png_warning_parameters p, png_const_charp message),PNG_EMPTY); + /* 'message' follows the X/Open approach of using @1, @2 to insert + * parameters previously supplied using the above functions. Errors in + * specifying the parameters will simply result in garbage substitutions. + */ +#endif + +#ifdef PNG_BENIGN_ERRORS_SUPPORTED +/* Application errors (new in 1.6); use these functions (declared below) for + * errors in the parameters or order of API function calls on read. The + * 'warning' should be used for an error that can be handled completely; the + * 'error' for one which can be handled safely but which may lose application + * information or settings. + * + * By default these both result in a png_error call prior to release, while in a + * released version the 'warning' is just a warning. However if the application + * explicitly disables benign errors (explicitly permitting the code to lose + * information) they both turn into warnings. + * + * If benign errors aren't supported they end up as the corresponding base call + * (png_warning or png_error.) + */ +PNG_INTERNAL_FUNCTION(void,png_app_warning,(png_const_structrp png_ptr, + png_const_charp message),PNG_EMPTY); + /* The application provided invalid parameters to an API function or called + * an API function at the wrong time, libpng can completely recover. + */ + +PNG_INTERNAL_FUNCTION(void,png_app_error,(png_const_structrp png_ptr, + png_const_charp message),PNG_EMPTY); + /* As above but libpng will ignore the call, or attempt some other partial + * recovery from the error. + */ +#else +# define png_app_warning(pp,s) png_warning(pp,s) +# define png_app_error(pp,s) png_error(pp,s) +#endif + +PNG_INTERNAL_FUNCTION(void,png_chunk_report,(png_const_structrp png_ptr, + png_const_charp message, int error),PNG_EMPTY); + /* Report a recoverable issue in chunk data. On read this is used to report + * a problem found while reading a particular chunk and the + * png_chunk_benign_error or png_chunk_warning function is used as + * appropriate. On write this is used to report an error that comes from + * data set via an application call to a png_set_ API and png_app_error or + * png_app_warning is used as appropriate. + * + * The 'error' parameter must have one of the following values: + */ +#define PNG_CHUNK_WARNING 0 /* never an error */ +#define PNG_CHUNK_WRITE_ERROR 1 /* an error only on write */ +#define PNG_CHUNK_ERROR 2 /* always an error */ + +/* ASCII to FP interfaces, currently only implemented if sCAL + * support is required. + */ +#if defined(PNG_sCAL_SUPPORTED) +/* MAX_DIGITS is actually the maximum number of characters in an sCAL + * width or height, derived from the precision (number of significant + * digits - a build time settable option) and assumptions about the + * maximum ridiculous exponent. + */ +#define PNG_sCAL_MAX_DIGITS (PNG_sCAL_PRECISION+1/*.*/+1/*E*/+10/*exponent*/) + +#ifdef PNG_FLOATING_POINT_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_ascii_from_fp,(png_const_structrp png_ptr, + png_charp ascii, png_size_t size, double fp, unsigned int precision), + PNG_EMPTY); +#endif /* FLOATING_POINT */ + +#ifdef PNG_FIXED_POINT_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_ascii_from_fixed,(png_const_structrp png_ptr, + png_charp ascii, png_size_t size, png_fixed_point fp),PNG_EMPTY); +#endif /* FIXED_POINT */ +#endif /* sCAL */ + +#if defined(PNG_sCAL_SUPPORTED) || defined(PNG_pCAL_SUPPORTED) +/* An internal API to validate the format of a floating point number. + * The result is the index of the next character. If the number is + * not valid it will be the index of a character in the supposed number. + * + * The format of a number is defined in the PNG extensions specification + * and this API is strictly conformant to that spec, not anyone elses! + * + * The format as a regular expression is: + * + * [+-]?[0-9]+.?([Ee][+-]?[0-9]+)? + * + * or: + * + * [+-]?.[0-9]+(.[0-9]+)?([Ee][+-]?[0-9]+)? + * + * The complexity is that either integer or fraction must be present and the + * fraction is permitted to have no digits only if the integer is present. + * + * NOTE: The dangling E problem. + * There is a PNG valid floating point number in the following: + * + * PNG floating point numbers are not greedy. + * + * Working this out requires *TWO* character lookahead (because of the + * sign), the parser does not do this - it will fail at the 'r' - this + * doesn't matter for PNG sCAL chunk values, but it requires more care + * if the value were ever to be embedded in something more complex. Use + * ANSI-C strtod if you need the lookahead. + */ +/* State table for the parser. */ +#define PNG_FP_INTEGER 0 /* before or in integer */ +#define PNG_FP_FRACTION 1 /* before or in fraction */ +#define PNG_FP_EXPONENT 2 /* before or in exponent */ +#define PNG_FP_STATE 3 /* mask for the above */ +#define PNG_FP_SAW_SIGN 4 /* Saw +/- in current state */ +#define PNG_FP_SAW_DIGIT 8 /* Saw a digit in current state */ +#define PNG_FP_SAW_DOT 16 /* Saw a dot in current state */ +#define PNG_FP_SAW_E 32 /* Saw an E (or e) in current state */ +#define PNG_FP_SAW_ANY 60 /* Saw any of the above 4 */ + +/* These three values don't affect the parser. They are set but not used. + */ +#define PNG_FP_WAS_VALID 64 /* Preceding substring is a valid fp number */ +#define PNG_FP_NEGATIVE 128 /* A negative number, including "-0" */ +#define PNG_FP_NONZERO 256 /* A non-zero value */ +#define PNG_FP_STICKY 448 /* The above three flags */ + +/* This is available for the caller to store in 'state' if required. Do not + * call the parser after setting it (the parser sometimes clears it.) + */ +#define PNG_FP_INVALID 512 /* Available for callers as a distinct value */ + +/* Result codes for the parser (boolean - true meants ok, false means + * not ok yet.) + */ +#define PNG_FP_MAYBE 0 /* The number may be valid in the future */ +#define PNG_FP_OK 1 /* The number is valid */ + +/* Tests on the sticky non-zero and negative flags. To pass these checks + * the state must also indicate that the whole number is valid - this is + * achieved by testing PNG_FP_SAW_DIGIT (see the implementation for why this + * is equivalent to PNG_FP_OK above.) + */ +#define PNG_FP_NZ_MASK (PNG_FP_SAW_DIGIT | PNG_FP_NEGATIVE | PNG_FP_NONZERO) + /* NZ_MASK: the string is valid and a non-zero negative value */ +#define PNG_FP_Z_MASK (PNG_FP_SAW_DIGIT | PNG_FP_NONZERO) + /* Z MASK: the string is valid and a non-zero value. */ + /* PNG_FP_SAW_DIGIT: the string is valid. */ +#define PNG_FP_IS_ZERO(state) (((state) & PNG_FP_Z_MASK) == PNG_FP_SAW_DIGIT) +#define PNG_FP_IS_POSITIVE(state) (((state) & PNG_FP_NZ_MASK) == PNG_FP_Z_MASK) +#define PNG_FP_IS_NEGATIVE(state) (((state) & PNG_FP_NZ_MASK) == PNG_FP_NZ_MASK) + +/* The actual parser. This can be called repeatedly. It updates + * the index into the string and the state variable (which must + * be initialized to 0). It returns a result code, as above. There + * is no point calling the parser any more if it fails to advance to + * the end of the string - it is stuck on an invalid character (or + * terminated by '\0'). + * + * Note that the pointer will consume an E or even an E+ and then leave + * a 'maybe' state even though a preceding integer.fraction is valid. + * The PNG_FP_WAS_VALID flag indicates that a preceding substring was + * a valid number. It's possible to recover from this by calling + * the parser again (from the start, with state 0) but with a string + * that omits the last character (i.e. set the size to the index of + * the problem character.) This has not been tested within libpng. + */ +PNG_INTERNAL_FUNCTION(int,png_check_fp_number,(png_const_charp string, + png_size_t size, int *statep, png_size_tp whereami),PNG_EMPTY); + +/* This is the same but it checks a complete string and returns true + * only if it just contains a floating point number. As of 1.5.4 this + * function also returns the state at the end of parsing the number if + * it was valid (otherwise it returns 0.) This can be used for testing + * for negative or zero values using the sticky flag. + */ +PNG_INTERNAL_FUNCTION(int,png_check_fp_string,(png_const_charp string, + png_size_t size),PNG_EMPTY); +#endif /* pCAL || sCAL */ + +#if defined(PNG_READ_GAMMA_SUPPORTED) ||\ + defined(PNG_INCH_CONVERSIONS_SUPPORTED) || defined(PNG_READ_pHYs_SUPPORTED) +/* Added at libpng version 1.5.0 */ +/* This is a utility to provide a*times/div (rounded) and indicate + * if there is an overflow. The result is a boolean - false (0) + * for overflow, true (1) if no overflow, in which case *res + * holds the result. + */ +PNG_INTERNAL_FUNCTION(int,png_muldiv,(png_fixed_point_p res, png_fixed_point a, + png_int_32 multiplied_by, png_int_32 divided_by),PNG_EMPTY); +#endif + +#if defined(PNG_READ_GAMMA_SUPPORTED) || defined(PNG_INCH_CONVERSIONS_SUPPORTED) +/* Same deal, but issue a warning on overflow and return 0. */ +PNG_INTERNAL_FUNCTION(png_fixed_point,png_muldiv_warn, + (png_const_structrp png_ptr, png_fixed_point a, png_int_32 multiplied_by, + png_int_32 divided_by),PNG_EMPTY); +#endif + +#ifdef PNG_GAMMA_SUPPORTED +/* Calculate a reciprocal - used for gamma values. This returns + * 0 if the argument is 0 in order to maintain an undefined value; + * there are no warnings. + */ +PNG_INTERNAL_FUNCTION(png_fixed_point,png_reciprocal,(png_fixed_point a), + PNG_EMPTY); + +#ifdef PNG_READ_GAMMA_SUPPORTED +/* The same but gives a reciprocal of the product of two fixed point + * values. Accuracy is suitable for gamma calculations but this is + * not exact - use png_muldiv for that. Only required at present on read. + */ +PNG_INTERNAL_FUNCTION(png_fixed_point,png_reciprocal2,(png_fixed_point a, + png_fixed_point b),PNG_EMPTY); +#endif + +/* Return true if the gamma value is significantly different from 1.0 */ +PNG_INTERNAL_FUNCTION(int,png_gamma_significant,(png_fixed_point gamma_value), + PNG_EMPTY); +#endif + +#ifdef PNG_READ_GAMMA_SUPPORTED +/* Internal fixed point gamma correction. These APIs are called as + * required to convert single values - they don't need to be fast, + * they are not used when processing image pixel values. + * + * While the input is an 'unsigned' value it must actually be the + * correct bit value - 0..255 or 0..65535 as required. + */ +PNG_INTERNAL_FUNCTION(png_uint_16,png_gamma_correct,(png_structrp png_ptr, + unsigned int value, png_fixed_point gamma_value),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(png_uint_16,png_gamma_16bit_correct,(unsigned int value, + png_fixed_point gamma_value),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(png_byte,png_gamma_8bit_correct,(unsigned int value, + png_fixed_point gamma_value),PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_destroy_gamma_table,(png_structrp png_ptr), + PNG_EMPTY); +PNG_INTERNAL_FUNCTION(void,png_build_gamma_table,(png_structrp png_ptr, + int bit_depth),PNG_EMPTY); +#endif + +/* SIMPLIFIED READ/WRITE SUPPORT */ +#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) ||\ + defined(PNG_SIMPLIFIED_WRITE_SUPPORTED) +/* The internal structure that png_image::opaque points to. */ +typedef struct png_control +{ + png_structp png_ptr; + png_infop info_ptr; + png_voidp error_buf; /* Always a jmp_buf at present. */ + + png_const_bytep memory; /* Memory buffer. */ + png_size_t size; /* Size of the memory buffer. */ + + unsigned int for_write :1; /* Otherwise it is a read structure */ + unsigned int owned_file :1; /* We own the file in io_ptr */ +} png_control; + +/* Return the pointer to the jmp_buf from a png_control: necessary because C + * does not reveal the type of the elements of jmp_buf. + */ +#ifdef __cplusplus +# define png_control_jmp_buf(pc) (((jmp_buf*)((pc)->error_buf))[0]) +#else +# define png_control_jmp_buf(pc) ((pc)->error_buf) +#endif + +/* Utility to safely execute a piece of libpng code catching and logging any + * errors that might occur. Returns true on success, false on failure (either + * of the function or as a result of a png_error.) + */ +PNG_INTERNAL_FUNCTION(void,png_safe_error,(png_structp png_ptr, + png_const_charp error_message),PNG_NORETURN); + +#ifdef PNG_WARNINGS_SUPPORTED +PNG_INTERNAL_FUNCTION(void,png_safe_warning,(png_structp png_ptr, + png_const_charp warning_message),PNG_EMPTY); +#else +# define png_safe_warning 0/*dummy argument*/ +#endif + +PNG_INTERNAL_FUNCTION(int,png_safe_execute,(png_imagep image, + int (*function)(png_voidp), png_voidp arg),PNG_EMPTY); + +/* Utility to log an error; this also cleans up the png_image; the function + * always returns 0 (false). + */ +PNG_INTERNAL_FUNCTION(int,png_image_error,(png_imagep image, + png_const_charp error_message),PNG_EMPTY); + +#ifndef PNG_SIMPLIFIED_READ_SUPPORTED +/* png_image_free is used by the write code but not exported */ +PNG_INTERNAL_FUNCTION(void, png_image_free, (png_imagep image), PNG_EMPTY); +#endif /* !SIMPLIFIED_READ */ + +#endif /* SIMPLIFIED READ/WRITE */ + +/* These are initialization functions for hardware specific PNG filter + * optimizations; list these here then select the appropriate one at compile + * time using the macro PNG_FILTER_OPTIMIZATIONS. If the macro is not defined + * the generic code is used. + */ +#ifdef PNG_FILTER_OPTIMIZATIONS +PNG_INTERNAL_FUNCTION(void, PNG_FILTER_OPTIMIZATIONS, (png_structp png_ptr, + unsigned int bpp), PNG_EMPTY); + /* Just declare the optimization that will be used */ +#else + /* List *all* the possible optimizations here - this branch is required if + * the builder of libpng passes the definition of PNG_FILTER_OPTIMIZATIONS in + * CFLAGS in place of CPPFLAGS *and* uses symbol prefixing. + */ +PNG_INTERNAL_FUNCTION(void, png_init_filter_functions_neon, + (png_structp png_ptr, unsigned int bpp), PNG_EMPTY); +#endif + +/* Maintainer: Put new private prototypes here ^ */ + +#include "pngdebug.h" + +#ifdef __cplusplus +} +#endif + +#endif /* PNG_VERSION_INFO_ONLY */ +#endif /* PNGPRIV_H */ diff --git a/ml/dlib/dlib/external/libpng/pngread.c b/ml/dlib/dlib/external/libpng/pngread.c new file mode 100644 index 000000000..8f96ca23e --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngread.c @@ -0,0 +1,4000 @@ + +/* pngread.c - read a PNG file + * + * Last changed in libpng 1.6.1 [March 28, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + * This file contains routines that an application calls directly to + * read a PNG file or stream. + */ + +#include "pngpriv.h" +#if defined(PNG_SIMPLIFIED_READ_SUPPORTED) && defined(PNG_STDIO_SUPPORTED) +# include +#endif + +#ifdef PNG_READ_SUPPORTED + +/* Create a PNG structure for reading, and allocate any memory needed. */ +PNG_FUNCTION(png_structp,PNGAPI +png_create_read_struct,(png_const_charp user_png_ver, png_voidp error_ptr, + png_error_ptr error_fn, png_error_ptr warn_fn),PNG_ALLOCATED) +{ +#ifndef PNG_USER_MEM_SUPPORTED + png_structp png_ptr = png_create_png_struct(user_png_ver, error_ptr, + error_fn, warn_fn, NULL, NULL, NULL); +#else + return png_create_read_struct_2(user_png_ver, error_ptr, error_fn, + warn_fn, NULL, NULL, NULL); +} + +/* Alternate create PNG structure for reading, and allocate any memory + * needed. + */ +PNG_FUNCTION(png_structp,PNGAPI +png_create_read_struct_2,(png_const_charp user_png_ver, png_voidp error_ptr, + png_error_ptr error_fn, png_error_ptr warn_fn, png_voidp mem_ptr, + png_malloc_ptr malloc_fn, png_free_ptr free_fn),PNG_ALLOCATED) +{ + png_structp png_ptr = png_create_png_struct(user_png_ver, error_ptr, + error_fn, warn_fn, mem_ptr, malloc_fn, free_fn); +#endif /* PNG_USER_MEM_SUPPORTED */ + + if (png_ptr != NULL) + { + png_ptr->mode = PNG_IS_READ_STRUCT; + + /* Added in libpng-1.6.0; this can be used to detect a read structure if + * required (it will be zero in a write structure.) + */ +# ifdef PNG_SEQUENTIAL_READ_SUPPORTED + png_ptr->IDAT_read_size = PNG_IDAT_READ_SIZE; +# endif + +# ifdef PNG_BENIGN_READ_ERRORS_SUPPORTED + png_ptr->flags |= PNG_FLAG_BENIGN_ERRORS_WARN; + + /* In stable builds only warn if an application error can be completely + * handled. + */ +# if PNG_LIBPNG_BUILD_BASE_TYPE >= PNG_LIBPNG_BUILD_RC + png_ptr->flags |= PNG_FLAG_APP_WARNINGS_WARN; +# endif +# endif + + /* TODO: delay this, it can be done in png_init_io (if the app doesn't + * do it itself) avoiding setting the default function if it is not + * required. + */ + png_set_read_fn(png_ptr, NULL, NULL); + } + + return png_ptr; +} + + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Read the information before the actual image data. This has been + * changed in v0.90 to allow reading a file that already has the magic + * bytes read from the stream. You can tell libpng how many bytes have + * been read from the beginning of the stream (up to the maximum of 8) + * via png_set_sig_bytes(), and we will only check the remaining bytes + * here. The application can then have access to the signature bytes we + * read if it is determined that this isn't a valid PNG file. + */ +void PNGAPI +png_read_info(png_structrp png_ptr, png_inforp info_ptr) +{ +#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED + int keep; +#endif + + png_debug(1, "in png_read_info"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + /* Read and check the PNG file signature. */ + png_read_sig(png_ptr, info_ptr); + + for (;;) + { + png_uint_32 length = png_read_chunk_header(png_ptr); + png_uint_32 chunk_name = png_ptr->chunk_name; + + /* IDAT logic needs to happen here to simplify getting the two flags + * right. + */ + if (chunk_name == png_IDAT) + { + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "Missing IHDR before IDAT"); + + else if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE && + !(png_ptr->mode & PNG_HAVE_PLTE)) + png_chunk_error(png_ptr, "Missing PLTE before IDAT"); + + else if (png_ptr->mode & PNG_AFTER_IDAT) + png_chunk_benign_error(png_ptr, "Too many IDATs found"); + + png_ptr->mode |= PNG_HAVE_IDAT; + } + + else if (png_ptr->mode & PNG_HAVE_IDAT) + png_ptr->mode |= PNG_AFTER_IDAT; + + /* This should be a binary subdivision search or a hash for + * matching the chunk name rather than a linear search. + */ + if (chunk_name == png_IHDR) + png_handle_IHDR(png_ptr, info_ptr, length); + + else if (chunk_name == png_IEND) + png_handle_IEND(png_ptr, info_ptr, length); + +#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED + else if ((keep = png_chunk_unknown_handling(png_ptr, chunk_name)) != 0) + { + png_handle_unknown(png_ptr, info_ptr, length, keep); + + if (chunk_name == png_PLTE) + png_ptr->mode |= PNG_HAVE_PLTE; + + else if (chunk_name == png_IDAT) + { + png_ptr->idat_size = 0; /* It has been consumed */ + break; + } + } +#endif + else if (chunk_name == png_PLTE) + png_handle_PLTE(png_ptr, info_ptr, length); + + else if (chunk_name == png_IDAT) + { + png_ptr->idat_size = length; + break; + } + +#ifdef PNG_READ_bKGD_SUPPORTED + else if (chunk_name == png_bKGD) + png_handle_bKGD(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_cHRM_SUPPORTED + else if (chunk_name == png_cHRM) + png_handle_cHRM(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_gAMA_SUPPORTED + else if (chunk_name == png_gAMA) + png_handle_gAMA(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_hIST_SUPPORTED + else if (chunk_name == png_hIST) + png_handle_hIST(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_oFFs_SUPPORTED + else if (chunk_name == png_oFFs) + png_handle_oFFs(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_pCAL_SUPPORTED + else if (chunk_name == png_pCAL) + png_handle_pCAL(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_sCAL_SUPPORTED + else if (chunk_name == png_sCAL) + png_handle_sCAL(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_pHYs_SUPPORTED + else if (chunk_name == png_pHYs) + png_handle_pHYs(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_sBIT_SUPPORTED + else if (chunk_name == png_sBIT) + png_handle_sBIT(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_sRGB_SUPPORTED + else if (chunk_name == png_sRGB) + png_handle_sRGB(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_iCCP_SUPPORTED + else if (chunk_name == png_iCCP) + png_handle_iCCP(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_sPLT_SUPPORTED + else if (chunk_name == png_sPLT) + png_handle_sPLT(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_tEXt_SUPPORTED + else if (chunk_name == png_tEXt) + png_handle_tEXt(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_tIME_SUPPORTED + else if (chunk_name == png_tIME) + png_handle_tIME(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_tRNS_SUPPORTED + else if (chunk_name == png_tRNS) + png_handle_tRNS(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_zTXt_SUPPORTED + else if (chunk_name == png_zTXt) + png_handle_zTXt(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_iTXt_SUPPORTED + else if (chunk_name == png_iTXt) + png_handle_iTXt(png_ptr, info_ptr, length); +#endif + + else + png_handle_unknown(png_ptr, info_ptr, length, + PNG_HANDLE_CHUNK_AS_DEFAULT); + } +} +#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */ + +/* Optional call to update the users info_ptr structure */ +void PNGAPI +png_read_update_info(png_structrp png_ptr, png_inforp info_ptr) +{ + png_debug(1, "in png_read_update_info"); + + if (png_ptr != NULL) + { + if ((png_ptr->flags & PNG_FLAG_ROW_INIT) == 0) + { + png_read_start_row(png_ptr); + +# ifdef PNG_READ_TRANSFORMS_SUPPORTED + png_read_transform_info(png_ptr, info_ptr); +# else + PNG_UNUSED(info_ptr) +# endif + } + + /* New in 1.6.0 this avoids the bug of doing the initializations twice */ + else + png_app_error(png_ptr, + "png_read_update_info/png_start_read_image: duplicate call"); + } +} + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Initialize palette, background, etc, after transformations + * are set, but before any reading takes place. This allows + * the user to obtain a gamma-corrected palette, for example. + * If the user doesn't call this, we will do it ourselves. + */ +void PNGAPI +png_start_read_image(png_structrp png_ptr) +{ + png_debug(1, "in png_start_read_image"); + + if (png_ptr != NULL) + { + if ((png_ptr->flags & PNG_FLAG_ROW_INIT) == 0) + png_read_start_row(png_ptr); + + /* New in 1.6.0 this avoids the bug of doing the initializations twice */ + else + png_app_error(png_ptr, + "png_start_read_image/png_read_update_info: duplicate call"); + } +} +#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */ + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +void PNGAPI +png_read_row(png_structrp png_ptr, png_bytep row, png_bytep dsp_row) +{ + png_row_info row_info; + + if (png_ptr == NULL) + return; + + png_debug2(1, "in png_read_row (row %lu, pass %d)", + (unsigned long)png_ptr->row_number, png_ptr->pass); + + /* png_read_start_row sets the information (in particular iwidth) for this + * interlace pass. + */ + if (!(png_ptr->flags & PNG_FLAG_ROW_INIT)) + png_read_start_row(png_ptr); + + /* 1.5.6: row_info moved out of png_struct to a local here. */ + row_info.width = png_ptr->iwidth; /* NOTE: width of current interlaced row */ + row_info.color_type = png_ptr->color_type; + row_info.bit_depth = png_ptr->bit_depth; + row_info.channels = png_ptr->channels; + row_info.pixel_depth = png_ptr->pixel_depth; + row_info.rowbytes = PNG_ROWBYTES(row_info.pixel_depth, row_info.width); + + if (png_ptr->row_number == 0 && png_ptr->pass == 0) + { + /* Check for transforms that have been set but were defined out */ +#if defined(PNG_WRITE_INVERT_SUPPORTED) && !defined(PNG_READ_INVERT_SUPPORTED) + if (png_ptr->transformations & PNG_INVERT_MONO) + png_warning(png_ptr, "PNG_READ_INVERT_SUPPORTED is not defined"); +#endif + +#if defined(PNG_WRITE_FILLER_SUPPORTED) && !defined(PNG_READ_FILLER_SUPPORTED) + if (png_ptr->transformations & PNG_FILLER) + png_warning(png_ptr, "PNG_READ_FILLER_SUPPORTED is not defined"); +#endif + +#if defined(PNG_WRITE_PACKSWAP_SUPPORTED) && \ + !defined(PNG_READ_PACKSWAP_SUPPORTED) + if (png_ptr->transformations & PNG_PACKSWAP) + png_warning(png_ptr, "PNG_READ_PACKSWAP_SUPPORTED is not defined"); +#endif + +#if defined(PNG_WRITE_PACK_SUPPORTED) && !defined(PNG_READ_PACK_SUPPORTED) + if (png_ptr->transformations & PNG_PACK) + png_warning(png_ptr, "PNG_READ_PACK_SUPPORTED is not defined"); +#endif + +#if defined(PNG_WRITE_SHIFT_SUPPORTED) && !defined(PNG_READ_SHIFT_SUPPORTED) + if (png_ptr->transformations & PNG_SHIFT) + png_warning(png_ptr, "PNG_READ_SHIFT_SUPPORTED is not defined"); +#endif + +#if defined(PNG_WRITE_BGR_SUPPORTED) && !defined(PNG_READ_BGR_SUPPORTED) + if (png_ptr->transformations & PNG_BGR) + png_warning(png_ptr, "PNG_READ_BGR_SUPPORTED is not defined"); +#endif + +#if defined(PNG_WRITE_SWAP_SUPPORTED) && !defined(PNG_READ_SWAP_SUPPORTED) + if (png_ptr->transformations & PNG_SWAP_BYTES) + png_warning(png_ptr, "PNG_READ_SWAP_SUPPORTED is not defined"); +#endif + } + +#ifdef PNG_READ_INTERLACING_SUPPORTED + /* If interlaced and we do not need a new row, combine row and return. + * Notice that the pixels we have from previous rows have been transformed + * already; we can only combine like with like (transformed or + * untransformed) and, because of the libpng API for interlaced images, this + * means we must transform before de-interlacing. + */ + if (png_ptr->interlaced && (png_ptr->transformations & PNG_INTERLACE)) + { + switch (png_ptr->pass) + { + case 0: + if (png_ptr->row_number & 0x07) + { + if (dsp_row != NULL) + png_combine_row(png_ptr, dsp_row, 1/*display*/); + png_read_finish_row(png_ptr); + return; + } + break; + + case 1: + if ((png_ptr->row_number & 0x07) || png_ptr->width < 5) + { + if (dsp_row != NULL) + png_combine_row(png_ptr, dsp_row, 1/*display*/); + + png_read_finish_row(png_ptr); + return; + } + break; + + case 2: + if ((png_ptr->row_number & 0x07) != 4) + { + if (dsp_row != NULL && (png_ptr->row_number & 4)) + png_combine_row(png_ptr, dsp_row, 1/*display*/); + + png_read_finish_row(png_ptr); + return; + } + break; + + case 3: + if ((png_ptr->row_number & 3) || png_ptr->width < 3) + { + if (dsp_row != NULL) + png_combine_row(png_ptr, dsp_row, 1/*display*/); + + png_read_finish_row(png_ptr); + return; + } + break; + + case 4: + if ((png_ptr->row_number & 3) != 2) + { + if (dsp_row != NULL && (png_ptr->row_number & 2)) + png_combine_row(png_ptr, dsp_row, 1/*display*/); + + png_read_finish_row(png_ptr); + return; + } + break; + + case 5: + if ((png_ptr->row_number & 1) || png_ptr->width < 2) + { + if (dsp_row != NULL) + png_combine_row(png_ptr, dsp_row, 1/*display*/); + + png_read_finish_row(png_ptr); + return; + } + break; + + default: + case 6: + if (!(png_ptr->row_number & 1)) + { + png_read_finish_row(png_ptr); + return; + } + break; + } + } +#endif + + if (!(png_ptr->mode & PNG_HAVE_IDAT)) + png_error(png_ptr, "Invalid attempt to read row data"); + + /* Fill the row with IDAT data: */ + png_read_IDAT_data(png_ptr, png_ptr->row_buf, row_info.rowbytes + 1); + + if (png_ptr->row_buf[0] > PNG_FILTER_VALUE_NONE) + { + if (png_ptr->row_buf[0] < PNG_FILTER_VALUE_LAST) + png_read_filter_row(png_ptr, &row_info, png_ptr->row_buf + 1, + png_ptr->prev_row + 1, png_ptr->row_buf[0]); + else + png_error(png_ptr, "bad adaptive filter value"); + } + + /* libpng 1.5.6: the following line was copying png_ptr->rowbytes before + * 1.5.6, while the buffer really is this big in current versions of libpng + * it may not be in the future, so this was changed just to copy the + * interlaced count: + */ + memcpy(png_ptr->prev_row, png_ptr->row_buf, row_info.rowbytes + 1); + +#ifdef PNG_MNG_FEATURES_SUPPORTED + if ((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) && + (png_ptr->filter_type == PNG_INTRAPIXEL_DIFFERENCING)) + { + /* Intrapixel differencing */ + png_do_read_intrapixel(&row_info, png_ptr->row_buf + 1); + } +#endif + + +#ifdef PNG_READ_TRANSFORMS_SUPPORTED + if (png_ptr->transformations) + png_do_read_transformations(png_ptr, &row_info); +#endif + + /* The transformed pixel depth should match the depth now in row_info. */ + if (png_ptr->transformed_pixel_depth == 0) + { + png_ptr->transformed_pixel_depth = row_info.pixel_depth; + if (row_info.pixel_depth > png_ptr->maximum_pixel_depth) + png_error(png_ptr, "sequential row overflow"); + } + + else if (png_ptr->transformed_pixel_depth != row_info.pixel_depth) + png_error(png_ptr, "internal sequential row size calculation error"); + +#ifdef PNG_READ_INTERLACING_SUPPORTED + /* Blow up interlaced rows to full size */ + if (png_ptr->interlaced && + (png_ptr->transformations & PNG_INTERLACE)) + { + if (png_ptr->pass < 6) + png_do_read_interlace(&row_info, png_ptr->row_buf + 1, png_ptr->pass, + png_ptr->transformations); + + if (dsp_row != NULL) + png_combine_row(png_ptr, dsp_row, 1/*display*/); + + if (row != NULL) + png_combine_row(png_ptr, row, 0/*row*/); + } + + else +#endif + { + if (row != NULL) + png_combine_row(png_ptr, row, -1/*ignored*/); + + if (dsp_row != NULL) + png_combine_row(png_ptr, dsp_row, -1/*ignored*/); + } + png_read_finish_row(png_ptr); + + if (png_ptr->read_row_fn != NULL) + (*(png_ptr->read_row_fn))(png_ptr, png_ptr->row_number, png_ptr->pass); + +} +#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */ + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Read one or more rows of image data. If the image is interlaced, + * and png_set_interlace_handling() has been called, the rows need to + * contain the contents of the rows from the previous pass. If the + * image has alpha or transparency, and png_handle_alpha()[*] has been + * called, the rows contents must be initialized to the contents of the + * screen. + * + * "row" holds the actual image, and pixels are placed in it + * as they arrive. If the image is displayed after each pass, it will + * appear to "sparkle" in. "display_row" can be used to display a + * "chunky" progressive image, with finer detail added as it becomes + * available. If you do not want this "chunky" display, you may pass + * NULL for display_row. If you do not want the sparkle display, and + * you have not called png_handle_alpha(), you may pass NULL for rows. + * If you have called png_handle_alpha(), and the image has either an + * alpha channel or a transparency chunk, you must provide a buffer for + * rows. In this case, you do not have to provide a display_row buffer + * also, but you may. If the image is not interlaced, or if you have + * not called png_set_interlace_handling(), the display_row buffer will + * be ignored, so pass NULL to it. + * + * [*] png_handle_alpha() does not exist yet, as of this version of libpng + */ + +void PNGAPI +png_read_rows(png_structrp png_ptr, png_bytepp row, + png_bytepp display_row, png_uint_32 num_rows) +{ + png_uint_32 i; + png_bytepp rp; + png_bytepp dp; + + png_debug(1, "in png_read_rows"); + + if (png_ptr == NULL) + return; + + rp = row; + dp = display_row; + if (rp != NULL && dp != NULL) + for (i = 0; i < num_rows; i++) + { + png_bytep rptr = *rp++; + png_bytep dptr = *dp++; + + png_read_row(png_ptr, rptr, dptr); + } + + else if (rp != NULL) + for (i = 0; i < num_rows; i++) + { + png_bytep rptr = *rp; + png_read_row(png_ptr, rptr, NULL); + rp++; + } + + else if (dp != NULL) + for (i = 0; i < num_rows; i++) + { + png_bytep dptr = *dp; + png_read_row(png_ptr, NULL, dptr); + dp++; + } +} +#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */ + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Read the entire image. If the image has an alpha channel or a tRNS + * chunk, and you have called png_handle_alpha()[*], you will need to + * initialize the image to the current image that PNG will be overlaying. + * We set the num_rows again here, in case it was incorrectly set in + * png_read_start_row() by a call to png_read_update_info() or + * png_start_read_image() if png_set_interlace_handling() wasn't called + * prior to either of these functions like it should have been. You can + * only call this function once. If you desire to have an image for + * each pass of a interlaced image, use png_read_rows() instead. + * + * [*] png_handle_alpha() does not exist yet, as of this version of libpng + */ +void PNGAPI +png_read_image(png_structrp png_ptr, png_bytepp image) +{ + png_uint_32 i, image_height; + int pass, j; + png_bytepp rp; + + png_debug(1, "in png_read_image"); + + if (png_ptr == NULL) + return; + +#ifdef PNG_READ_INTERLACING_SUPPORTED + if (!(png_ptr->flags & PNG_FLAG_ROW_INIT)) + { + pass = png_set_interlace_handling(png_ptr); + /* And make sure transforms are initialized. */ + png_start_read_image(png_ptr); + } + else + { + if (png_ptr->interlaced && !(png_ptr->transformations & PNG_INTERLACE)) + { + /* Caller called png_start_read_image or png_read_update_info without + * first turning on the PNG_INTERLACE transform. We can fix this here, + * but the caller should do it! + */ + png_warning(png_ptr, "Interlace handling should be turned on when " + "using png_read_image"); + /* Make sure this is set correctly */ + png_ptr->num_rows = png_ptr->height; + } + + /* Obtain the pass number, which also turns on the PNG_INTERLACE flag in + * the above error case. + */ + pass = png_set_interlace_handling(png_ptr); + } +#else + if (png_ptr->interlaced) + png_error(png_ptr, + "Cannot read interlaced image -- interlace handler disabled"); + + pass = 1; +#endif + + image_height=png_ptr->height; + + for (j = 0; j < pass; j++) + { + rp = image; + for (i = 0; i < image_height; i++) + { + png_read_row(png_ptr, *rp, NULL); + rp++; + } + } +} +#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */ + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +/* Read the end of the PNG file. Will not read past the end of the + * file, will verify the end is accurate, and will read any comments + * or time information at the end of the file, if info is not NULL. + */ +void PNGAPI +png_read_end(png_structrp png_ptr, png_inforp info_ptr) +{ +#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED + int keep; +#endif + + png_debug(1, "in png_read_end"); + + if (png_ptr == NULL) + return; + + /* If png_read_end is called in the middle of reading the rows there may + * still be pending IDAT data and an owned zstream. Deal with this here. + */ +#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED + if (!png_chunk_unknown_handling(png_ptr, png_IDAT)) +#endif + png_read_finish_IDAT(png_ptr); + +#ifdef PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED + /* Report invalid palette index; added at libng-1.5.10 */ + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE && + png_ptr->num_palette_max > png_ptr->num_palette) + png_benign_error(png_ptr, "Read palette index exceeding num_palette"); +#endif + + do + { + png_uint_32 length = png_read_chunk_header(png_ptr); + png_uint_32 chunk_name = png_ptr->chunk_name; + + if (chunk_name == png_IHDR) + png_handle_IHDR(png_ptr, info_ptr, length); + + else if (chunk_name == png_IEND) + png_handle_IEND(png_ptr, info_ptr, length); + +#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED + else if ((keep = png_chunk_unknown_handling(png_ptr, chunk_name)) != 0) + { + if (chunk_name == png_IDAT) + { + if ((length > 0) || (png_ptr->mode & PNG_HAVE_CHUNK_AFTER_IDAT)) + png_benign_error(png_ptr, "Too many IDATs found"); + } + png_handle_unknown(png_ptr, info_ptr, length, keep); + if (chunk_name == png_PLTE) + png_ptr->mode |= PNG_HAVE_PLTE; + } +#endif + + else if (chunk_name == png_IDAT) + { + /* Zero length IDATs are legal after the last IDAT has been + * read, but not after other chunks have been read. + */ + if ((length > 0) || (png_ptr->mode & PNG_HAVE_CHUNK_AFTER_IDAT)) + png_benign_error(png_ptr, "Too many IDATs found"); + + png_crc_finish(png_ptr, length); + } + else if (chunk_name == png_PLTE) + png_handle_PLTE(png_ptr, info_ptr, length); + +#ifdef PNG_READ_bKGD_SUPPORTED + else if (chunk_name == png_bKGD) + png_handle_bKGD(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_cHRM_SUPPORTED + else if (chunk_name == png_cHRM) + png_handle_cHRM(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_gAMA_SUPPORTED + else if (chunk_name == png_gAMA) + png_handle_gAMA(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_hIST_SUPPORTED + else if (chunk_name == png_hIST) + png_handle_hIST(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_oFFs_SUPPORTED + else if (chunk_name == png_oFFs) + png_handle_oFFs(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_pCAL_SUPPORTED + else if (chunk_name == png_pCAL) + png_handle_pCAL(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_sCAL_SUPPORTED + else if (chunk_name == png_sCAL) + png_handle_sCAL(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_pHYs_SUPPORTED + else if (chunk_name == png_pHYs) + png_handle_pHYs(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_sBIT_SUPPORTED + else if (chunk_name == png_sBIT) + png_handle_sBIT(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_sRGB_SUPPORTED + else if (chunk_name == png_sRGB) + png_handle_sRGB(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_iCCP_SUPPORTED + else if (chunk_name == png_iCCP) + png_handle_iCCP(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_sPLT_SUPPORTED + else if (chunk_name == png_sPLT) + png_handle_sPLT(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_tEXt_SUPPORTED + else if (chunk_name == png_tEXt) + png_handle_tEXt(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_tIME_SUPPORTED + else if (chunk_name == png_tIME) + png_handle_tIME(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_tRNS_SUPPORTED + else if (chunk_name == png_tRNS) + png_handle_tRNS(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_zTXt_SUPPORTED + else if (chunk_name == png_zTXt) + png_handle_zTXt(png_ptr, info_ptr, length); +#endif + +#ifdef PNG_READ_iTXt_SUPPORTED + else if (chunk_name == png_iTXt) + png_handle_iTXt(png_ptr, info_ptr, length); +#endif + + else + png_handle_unknown(png_ptr, info_ptr, length, + PNG_HANDLE_CHUNK_AS_DEFAULT); + } while (!(png_ptr->mode & PNG_HAVE_IEND)); +} +#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */ + +/* Free all memory used in the read struct */ +static void +png_read_destroy(png_structrp png_ptr) +{ + png_debug(1, "in png_read_destroy"); + +#ifdef PNG_READ_GAMMA_SUPPORTED + png_destroy_gamma_table(png_ptr); +#endif + + png_free(png_ptr, png_ptr->big_row_buf); + png_free(png_ptr, png_ptr->big_prev_row); + png_free(png_ptr, png_ptr->read_buffer); + +#ifdef PNG_READ_QUANTIZE_SUPPORTED + png_free(png_ptr, png_ptr->palette_lookup); + png_free(png_ptr, png_ptr->quantize_index); +#endif + + if (png_ptr->free_me & PNG_FREE_PLTE) + png_zfree(png_ptr, png_ptr->palette); + png_ptr->free_me &= ~PNG_FREE_PLTE; + +#if defined(PNG_tRNS_SUPPORTED) || \ + defined(PNG_READ_EXPAND_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) + if (png_ptr->free_me & PNG_FREE_TRNS) + png_free(png_ptr, png_ptr->trans_alpha); + png_ptr->free_me &= ~PNG_FREE_TRNS; +#endif + + inflateEnd(&png_ptr->zstream); + +#ifdef PNG_PROGRESSIVE_READ_SUPPORTED + png_free(png_ptr, png_ptr->save_buffer); +#endif + +#if defined(PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED) &&\ + defined(PNG_READ_UNKNOWN_CHUNKS_SUPPORTED) + png_free(png_ptr, png_ptr->unknown_chunk.data); +#endif + +#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED + png_free(png_ptr, png_ptr->chunk_list); +#endif + + /* NOTE: the 'setjmp' buffer may still be allocated and the memory and error + * callbacks are still set at this point. They are required to complete the + * destruction of the png_struct itself. + */ +} + +/* Free all memory used by the read */ +void PNGAPI +png_destroy_read_struct(png_structpp png_ptr_ptr, png_infopp info_ptr_ptr, + png_infopp end_info_ptr_ptr) +{ + png_structrp png_ptr = NULL; + + png_debug(1, "in png_destroy_read_struct"); + + if (png_ptr_ptr != NULL) + png_ptr = *png_ptr_ptr; + + if (png_ptr == NULL) + return; + + /* libpng 1.6.0: use the API to destroy info structs to ensure consistent + * behavior. Prior to 1.6.0 libpng did extra 'info' destruction in this API. + * The extra was, apparently, unnecessary yet this hides memory leak bugs. + */ + png_destroy_info_struct(png_ptr, end_info_ptr_ptr); + png_destroy_info_struct(png_ptr, info_ptr_ptr); + + *png_ptr_ptr = NULL; + png_read_destroy(png_ptr); + png_destroy_png_struct(png_ptr); +} + +void PNGAPI +png_set_read_status_fn(png_structrp png_ptr, png_read_status_ptr read_row_fn) +{ + if (png_ptr == NULL) + return; + + png_ptr->read_row_fn = read_row_fn; +} + + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +#ifdef PNG_INFO_IMAGE_SUPPORTED +void PNGAPI +png_read_png(png_structrp png_ptr, png_inforp info_ptr, + int transforms, + voidp params) +{ + int row; + + if (png_ptr == NULL || info_ptr == NULL) + return; + + /* png_read_info() gives us all of the information from the + * PNG file before the first IDAT (image data chunk). + */ + png_read_info(png_ptr, info_ptr); + if (info_ptr->height > PNG_UINT_32_MAX/(sizeof (png_bytep))) + png_error(png_ptr, "Image is too high to process with png_read_png()"); + + /* -------------- image transformations start here ------------------- */ + +#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED + /* Tell libpng to strip 16-bit/color files down to 8 bits per color. + */ + if (transforms & PNG_TRANSFORM_SCALE_16) + { + /* Added at libpng-1.5.4. "strip_16" produces the same result that it + * did in earlier versions, while "scale_16" is now more accurate. + */ + png_set_scale_16(png_ptr); + } +#endif + +#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED + /* If both SCALE and STRIP are required pngrtran will effectively cancel the + * latter by doing SCALE first. This is ok and allows apps not to check for + * which is supported to get the right answer. + */ + if (transforms & PNG_TRANSFORM_STRIP_16) + png_set_strip_16(png_ptr); +#endif + +#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED + /* Strip alpha bytes from the input data without combining with + * the background (not recommended). + */ + if (transforms & PNG_TRANSFORM_STRIP_ALPHA) + png_set_strip_alpha(png_ptr); +#endif + +#if defined(PNG_READ_PACK_SUPPORTED) && !defined(PNG_READ_EXPAND_SUPPORTED) + /* Extract multiple pixels with bit depths of 1, 2, or 4 from a single + * byte into separate bytes (useful for paletted and grayscale images). + */ + if (transforms & PNG_TRANSFORM_PACKING) + png_set_packing(png_ptr); +#endif + +#ifdef PNG_READ_PACKSWAP_SUPPORTED + /* Change the order of packed pixels to least significant bit first + * (not useful if you are using png_set_packing). + */ + if (transforms & PNG_TRANSFORM_PACKSWAP) + png_set_packswap(png_ptr); +#endif + +#ifdef PNG_READ_EXPAND_SUPPORTED + /* Expand paletted colors into true RGB triplets + * Expand grayscale images to full 8 bits from 1, 2, or 4 bits/pixel + * Expand paletted or RGB images with transparency to full alpha + * channels so the data will be available as RGBA quartets. + */ + if (transforms & PNG_TRANSFORM_EXPAND) + if ((png_ptr->bit_depth < 8) || + (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) || + (png_get_valid(png_ptr, info_ptr, PNG_INFO_tRNS))) + png_set_expand(png_ptr); +#endif + + /* We don't handle background color or gamma transformation or quantizing. + */ + +#ifdef PNG_READ_INVERT_SUPPORTED + /* Invert monochrome files to have 0 as white and 1 as black + */ + if (transforms & PNG_TRANSFORM_INVERT_MONO) + png_set_invert_mono(png_ptr); +#endif + +#ifdef PNG_READ_SHIFT_SUPPORTED + /* If you want to shift the pixel values from the range [0,255] or + * [0,65535] to the original [0,7] or [0,31], or whatever range the + * colors were originally in: + */ + if ((transforms & PNG_TRANSFORM_SHIFT) + && png_get_valid(png_ptr, info_ptr, PNG_INFO_sBIT)) + { + png_color_8p sig_bit; + + png_get_sBIT(png_ptr, info_ptr, &sig_bit); + png_set_shift(png_ptr, sig_bit); + } +#endif + +#ifdef PNG_READ_BGR_SUPPORTED + /* Flip the RGB pixels to BGR (or RGBA to BGRA) */ + if (transforms & PNG_TRANSFORM_BGR) + png_set_bgr(png_ptr); +#endif + +#ifdef PNG_READ_SWAP_ALPHA_SUPPORTED + /* Swap the RGBA or GA data to ARGB or AG (or BGRA to ABGR) */ + if (transforms & PNG_TRANSFORM_SWAP_ALPHA) + png_set_swap_alpha(png_ptr); +#endif + +#ifdef PNG_READ_SWAP_SUPPORTED + /* Swap bytes of 16-bit files to least significant byte first */ + if (transforms & PNG_TRANSFORM_SWAP_ENDIAN) + png_set_swap(png_ptr); +#endif + +/* Added at libpng-1.2.41 */ +#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED + /* Invert the alpha channel from opacity to transparency */ + if (transforms & PNG_TRANSFORM_INVERT_ALPHA) + png_set_invert_alpha(png_ptr); +#endif + +/* Added at libpng-1.2.41 */ +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED + /* Expand grayscale image to RGB */ + if (transforms & PNG_TRANSFORM_GRAY_TO_RGB) + png_set_gray_to_rgb(png_ptr); +#endif + +/* Added at libpng-1.5.4 */ +#ifdef PNG_READ_EXPAND_16_SUPPORTED + if (transforms & PNG_TRANSFORM_EXPAND_16) + png_set_expand_16(png_ptr); +#endif + + /* We don't handle adding filler bytes */ + + /* We use png_read_image and rely on that for interlace handling, but we also + * call png_read_update_info therefore must turn on interlace handling now: + */ + (void)png_set_interlace_handling(png_ptr); + + /* Optional call to gamma correct and add the background to the palette + * and update info structure. REQUIRED if you are expecting libpng to + * update the palette for you (i.e., you selected such a transform above). + */ + png_read_update_info(png_ptr, info_ptr); + + /* -------------- image transformations end here ------------------- */ + + png_free_data(png_ptr, info_ptr, PNG_FREE_ROWS, 0); + if (info_ptr->row_pointers == NULL) + { + png_uint_32 iptr; + + info_ptr->row_pointers = (png_bytepp)png_malloc(png_ptr, + info_ptr->height * (sizeof (png_bytep))); + for (iptr=0; iptrheight; iptr++) + info_ptr->row_pointers[iptr] = NULL; + + info_ptr->free_me |= PNG_FREE_ROWS; + + for (row = 0; row < (int)info_ptr->height; row++) + info_ptr->row_pointers[row] = (png_bytep)png_malloc(png_ptr, + png_get_rowbytes(png_ptr, info_ptr)); + } + + png_read_image(png_ptr, info_ptr->row_pointers); + info_ptr->valid |= PNG_INFO_IDAT; + + /* Read rest of file, and get additional chunks in info_ptr - REQUIRED */ + png_read_end(png_ptr, info_ptr); + + PNG_UNUSED(transforms) /* Quiet compiler warnings */ + PNG_UNUSED(params) + +} +#endif /* PNG_INFO_IMAGE_SUPPORTED */ +#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */ + +#ifdef PNG_SIMPLIFIED_READ_SUPPORTED +/* SIMPLIFIED READ + * + * This code currently relies on the sequential reader, though it could easily + * be made to work with the progressive one. + */ +/* Arguments to png_image_finish_read: */ + +/* Encoding of PNG data (used by the color-map code) */ +/* TODO: change these, dang, ANSI-C reserves the 'E' namespace. */ +# define E_NOTSET 0 /* File encoding not yet known */ +# define E_sRGB 1 /* 8-bit encoded to sRGB gamma */ +# define E_LINEAR 2 /* 16-bit linear: not encoded, NOT pre-multiplied! */ +# define E_FILE 3 /* 8-bit encoded to file gamma, not sRGB or linear */ +# define E_LINEAR8 4 /* 8-bit linear: only from a file value */ + +/* Color-map processing: after libpng has run on the PNG image further + * processing may be needed to conver the data to color-map indicies. + */ +#define PNG_CMAP_NONE 0 +#define PNG_CMAP_GA 1 /* Process GA data to a color-map with alpha */ +#define PNG_CMAP_TRANS 2 /* Process GA data to a background index */ +#define PNG_CMAP_RGB 3 /* Process RGB data */ +#define PNG_CMAP_RGB_ALPHA 4 /* Process RGBA data */ + +/* The following document where the background is for each processing case. */ +#define PNG_CMAP_NONE_BACKGROUND 256 +#define PNG_CMAP_GA_BACKGROUND 231 +#define PNG_CMAP_TRANS_BACKGROUND 254 +#define PNG_CMAP_RGB_BACKGROUND 256 +#define PNG_CMAP_RGB_ALPHA_BACKGROUND 216 + +typedef struct +{ + /* Arguments: */ + png_imagep image; + png_voidp buffer; + png_int_32 row_stride; + png_voidp colormap; + png_const_colorp background; + /* Local variables: */ + png_voidp local_row; + png_voidp first_row; + ptrdiff_t row_bytes; /* step between rows */ + int file_encoding; /* E_ values above */ + png_fixed_point gamma_to_linear; /* For E_FILE, reciprocal of gamma */ + int colormap_processing; /* PNG_CMAP_ values above */ +} png_image_read_control; + +/* Do all the *safe* initialization - 'safe' means that png_error won't be + * called, so setting up the jmp_buf is not required. This means that anything + * called from here must *not* call png_malloc - it has to call png_malloc_warn + * instead so that control is returned safely back to this routine. + */ +static int +png_image_read_init(png_imagep image) +{ + if (image->opaque == NULL) + { + png_structp png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, image, + png_safe_error, png_safe_warning); + + /* And set the rest of the structure to NULL to ensure that the various + * fields are consistent. + */ + memset(image, 0, (sizeof *image)); + image->version = PNG_IMAGE_VERSION; + + if (png_ptr != NULL) + { + png_infop info_ptr = png_create_info_struct(png_ptr); + + if (info_ptr != NULL) + { + png_controlp control = png_voidcast(png_controlp, + png_malloc_warn(png_ptr, (sizeof *control))); + + if (control != NULL) + { + memset(control, 0, (sizeof *control)); + + control->png_ptr = png_ptr; + control->info_ptr = info_ptr; + control->for_write = 0; + + image->opaque = control; + return 1; + } + + /* Error clean up */ + png_destroy_info_struct(png_ptr, &info_ptr); + } + + png_destroy_read_struct(&png_ptr, NULL, NULL); + } + + return png_image_error(image, "png_image_read: out of memory"); + } + + return png_image_error(image, "png_image_read: opaque pointer not NULL"); +} + +/* Utility to find the base format of a PNG file from a png_struct. */ +static png_uint_32 +png_image_format(png_structrp png_ptr) +{ + png_uint_32 format = 0; + + if (png_ptr->color_type & PNG_COLOR_MASK_COLOR) + format |= PNG_FORMAT_FLAG_COLOR; + + if (png_ptr->color_type & PNG_COLOR_MASK_ALPHA) + format |= PNG_FORMAT_FLAG_ALPHA; + + /* Use png_ptr here, not info_ptr, because by examination png_handle_tRNS + * sets the png_struct fields; that's all we are interested in here. The + * precise interaction with an app call to png_set_tRNS and PNG file reading + * is unclear. + */ + else if (png_ptr->num_trans > 0) + format |= PNG_FORMAT_FLAG_ALPHA; + + if (png_ptr->bit_depth == 16) + format |= PNG_FORMAT_FLAG_LINEAR; + + if (png_ptr->color_type & PNG_COLOR_MASK_PALETTE) + format |= PNG_FORMAT_FLAG_COLORMAP; + + return format; +} + +/* Is the given gamma significantly different from sRGB? The test is the same + * one used in pngrtran.c when deciding whether to do gamma correction. The + * arithmetic optimizes the division by using the fact that the inverse of the + * file sRGB gamma is 2.2 + */ +static int +png_gamma_not_sRGB(png_fixed_point g) +{ + if (g < PNG_FP_1) + { + /* An uninitialized gamma is assumed to be sRGB for the simplified API. */ + if (g == 0) + return 0; + + return png_gamma_significant((g * 11 + 2)/5 /* i.e. *2.2, rounded */); + } + + return 1; +} + +/* Do the main body of a 'png_image_begin_read' function; read the PNG file + * header and fill in all the information. This is executed in a safe context, + * unlike the init routine above. + */ +static int +png_image_read_header(png_voidp argument) +{ + png_imagep image = png_voidcast(png_imagep, argument); + png_structrp png_ptr = image->opaque->png_ptr; + png_inforp info_ptr = image->opaque->info_ptr; + + png_set_benign_errors(png_ptr, 1/*warn*/); + png_read_info(png_ptr, info_ptr); + + /* Do this the fast way; just read directly out of png_struct. */ + image->width = png_ptr->width; + image->height = png_ptr->height; + + { + png_uint_32 format = png_image_format(png_ptr); + + image->format = format; + +#ifdef PNG_COLORSPACE_SUPPORTED + /* Does the colorspace match sRGB? If there is no color endpoint + * (colorant) information assume yes, otherwise require the + * 'ENDPOINTS_MATCHE_sRGB' colorspace flag to have been set. If the + * colorspace has been determined to be invalid ignore it. + */ + if ((format & PNG_FORMAT_FLAG_COLOR) != 0 && ((png_ptr->colorspace.flags + & (PNG_COLORSPACE_HAVE_ENDPOINTS|PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB| + PNG_COLORSPACE_INVALID)) == PNG_COLORSPACE_HAVE_ENDPOINTS)) + image->flags |= PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB; +#endif + } + + /* We need the maximum number of entries regardless of the format the + * application sets here. + */ + { + png_uint_32 cmap_entries; + + switch (png_ptr->color_type) + { + case PNG_COLOR_TYPE_GRAY: + cmap_entries = 1U << png_ptr->bit_depth; + break; + + case PNG_COLOR_TYPE_PALETTE: + cmap_entries = png_ptr->num_palette; + break; + + default: + cmap_entries = 256; + break; + } + + if (cmap_entries > 256) + cmap_entries = 256; + + image->colormap_entries = cmap_entries; + } + + return 1; +} + +#ifdef PNG_STDIO_SUPPORTED +int PNGAPI +png_image_begin_read_from_stdio(png_imagep image, FILE* file) +{ + if (image != NULL && image->version == PNG_IMAGE_VERSION) + { + if (file != NULL) + { + if (png_image_read_init(image)) + { + /* This is slightly evil, but png_init_io doesn't do anything other + * than this and we haven't changed the standard IO functions so + * this saves a 'safe' function. + */ + image->opaque->png_ptr->io_ptr = file; + return png_safe_execute(image, png_image_read_header, image); + } + } + + else + return png_image_error(image, + "png_image_begin_read_from_stdio: invalid argument"); + } + + else if (image != NULL) + return png_image_error(image, + "png_image_begin_read_from_stdio: incorrect PNG_IMAGE_VERSION"); + + return 0; +} + +int PNGAPI +png_image_begin_read_from_file(png_imagep image, const char *file_name) +{ + if (image != NULL && image->version == PNG_IMAGE_VERSION) + { + if (file_name != NULL) + { + FILE *fp = fopen(file_name, "rb"); + + if (fp != NULL) + { + if (png_image_read_init(image)) + { + image->opaque->png_ptr->io_ptr = fp; + image->opaque->owned_file = 1; + return png_safe_execute(image, png_image_read_header, image); + } + + /* Clean up: just the opened file. */ + (void)fclose(fp); + } + + else + return png_image_error(image, strerror(errno)); + } + + else + return png_image_error(image, + "png_image_begin_read_from_file: invalid argument"); + } + + else if (image != NULL) + return png_image_error(image, + "png_image_begin_read_from_file: incorrect PNG_IMAGE_VERSION"); + + return 0; +} +#endif /* PNG_STDIO_SUPPORTED */ + +static void PNGCBAPI +png_image_memory_read(png_structp png_ptr, png_bytep out, png_size_t need) +{ + if (png_ptr != NULL) + { + png_imagep image = png_voidcast(png_imagep, png_ptr->io_ptr); + if (image != NULL) + { + png_controlp cp = image->opaque; + if (cp != NULL) + { + png_const_bytep memory = cp->memory; + png_size_t size = cp->size; + + if (memory != NULL && size >= need) + { + memcpy(out, memory, need); + cp->memory = memory + need; + cp->size = size - need; + return; + } + + png_error(png_ptr, "read beyond end of data"); + } + } + + png_error(png_ptr, "invalid memory read"); + } +} + +int PNGAPI png_image_begin_read_from_memory(png_imagep image, + png_const_voidp memory, png_size_t size) +{ + if (image != NULL && image->version == PNG_IMAGE_VERSION) + { + if (memory != NULL && size > 0) + { + if (png_image_read_init(image)) + { + /* Now set the IO functions to read from the memory buffer and + * store it into io_ptr. Again do this in-place to avoid calling a + * libpng function that requires error handling. + */ + image->opaque->memory = png_voidcast(png_const_bytep, memory); + image->opaque->size = size; + image->opaque->png_ptr->io_ptr = image; + image->opaque->png_ptr->read_data_fn = png_image_memory_read; + + return png_safe_execute(image, png_image_read_header, image); + } + } + + else + return png_image_error(image, + "png_image_begin_read_from_memory: invalid argument"); + } + + else if (image != NULL) + return png_image_error(image, + "png_image_begin_read_from_memory: incorrect PNG_IMAGE_VERSION"); + + return 0; +} + +/* Utility function to skip chunks that are not used by the simplified image + * read functions and an appropriate macro to call it. + */ +#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED +static void +png_image_skip_unused_chunks(png_structrp png_ptr) +{ + /* Prepare the reader to ignore all recognized chunks whose data will not + * be used, i.e., all chunks recognized by libpng except for those + * involved in basic image reading: + * + * IHDR, PLTE, IDAT, IEND + * + * Or image data handling: + * + * tRNS, bKGD, gAMA, cHRM, sRGB, iCCP and sBIT. + * + * This provides a small performance improvement and eliminates any + * potential vulnerability to security problems in the unused chunks. + */ + { + static PNG_CONST png_byte chunks_to_process[] = { + 98, 75, 71, 68, '\0', /* bKGD */ + 99, 72, 82, 77, '\0', /* cHRM */ + 103, 65, 77, 65, '\0', /* gAMA */ + 105, 67, 67, 80, '\0', /* iCCP */ + 115, 66, 73, 84, '\0', /* sBIT */ + 115, 82, 71, 66, '\0', /* sRGB */ + }; + + /* Ignore unknown chunks and all other chunks except for the + * IHDR, PLTE, tRNS, IDAT, and IEND chunks. + */ + png_set_keep_unknown_chunks(png_ptr, PNG_HANDLE_CHUNK_NEVER, + NULL, -1); + + /* But do not ignore image data handling chunks */ + png_set_keep_unknown_chunks(png_ptr, PNG_HANDLE_CHUNK_AS_DEFAULT, + chunks_to_process, (sizeof chunks_to_process)/5); + } +} + +# define PNG_SKIP_CHUNKS(p) png_image_skip_unused_chunks(p) +#else +# define PNG_SKIP_CHUNKS(p) ((void)0) +#endif /* PNG_HANDLE_AS_UNKNOWN_SUPPORTED */ + +/* The following macro gives the exact rounded answer for all values in the + * range 0..255 (it actually divides by 51.2, but the rounding still generates + * the correct numbers 0..5 + */ +#define PNG_DIV51(v8) (((v8) * 5 + 130) >> 8) + +/* Utility functions to make particular color-maps */ +static void +set_file_encoding(png_image_read_control *display) +{ + png_fixed_point g = display->image->opaque->png_ptr->colorspace.gamma; + if (png_gamma_significant(g)) + { + if (png_gamma_not_sRGB(g)) + { + display->file_encoding = E_FILE; + display->gamma_to_linear = png_reciprocal(g); + } + + else + display->file_encoding = E_sRGB; + } + + else + display->file_encoding = E_LINEAR8; +} + +static unsigned int +decode_gamma(png_image_read_control *display, png_uint_32 value, int encoding) +{ + if (encoding == E_FILE) /* double check */ + encoding = display->file_encoding; + + if (encoding == E_NOTSET) /* must be the file encoding */ + { + set_file_encoding(display); + encoding = display->file_encoding; + } + + switch (encoding) + { + case E_FILE: + value = png_gamma_16bit_correct(value*257, display->gamma_to_linear); + break; + + case E_sRGB: + value = png_sRGB_table[value]; + break; + + case E_LINEAR: + break; + + case E_LINEAR8: + value *= 257; + break; + + default: + png_error(display->image->opaque->png_ptr, + "unexpected encoding (internal error)"); + break; + } + + return value; +} + +static png_uint_32 +png_colormap_compose(png_image_read_control *display, + png_uint_32 foreground, int foreground_encoding, png_uint_32 alpha, + png_uint_32 background, int encoding) +{ + /* The file value is composed on the background, the background has the given + * encoding and so does the result, the file is encoded with E_FILE and the + * file and alpha are 8-bit values. The (output) encoding will always be + * E_LINEAR or E_sRGB. + */ + png_uint_32 f = decode_gamma(display, foreground, foreground_encoding); + png_uint_32 b = decode_gamma(display, background, encoding); + + /* The alpha is always an 8-bit value (it comes from the palette), the value + * scaled by 255 is what PNG_sRGB_FROM_LINEAR requires. + */ + f = f * alpha + b * (255-alpha); + + if (encoding == E_LINEAR) + { + /* Scale to 65535; divide by 255, approximately (in fact this is extremely + * accurate, it divides by 255.00000005937181414556, with no overflow.) + */ + f *= 257; /* Now scaled by 65535 */ + f += f >> 16; + f = (f+32768) >> 16; + } + + else /* E_sRGB */ + f = PNG_sRGB_FROM_LINEAR(f); + + return f; +} + +/* NOTE: E_LINEAR values to this routine must be 16-bit, but E_FILE values must + * be 8-bit. + */ +static void +png_create_colormap_entry(png_image_read_control *display, + png_uint_32 ip, png_uint_32 red, png_uint_32 green, png_uint_32 blue, + png_uint_32 alpha, int encoding) +{ + png_imagep image = display->image; + const int output_encoding = (image->format & PNG_FORMAT_FLAG_LINEAR) ? + E_LINEAR : E_sRGB; + const int convert_to_Y = (image->format & PNG_FORMAT_FLAG_COLOR) == 0 && + (red != green || green != blue); + + if (ip > 255) + png_error(image->opaque->png_ptr, "color-map index out of range"); + + /* Update the cache with whether the file gamma is significantly different + * from sRGB. + */ + if (encoding == E_FILE) + { + if (display->file_encoding == E_NOTSET) + set_file_encoding(display); + + /* Note that the cached value may be E_FILE too, but if it is then the + * gamma_to_linear member has been set. + */ + encoding = display->file_encoding; + } + + if (encoding == E_FILE) + { + png_fixed_point g = display->gamma_to_linear; + + red = png_gamma_16bit_correct(red*257, g); + green = png_gamma_16bit_correct(green*257, g); + blue = png_gamma_16bit_correct(blue*257, g); + + if (convert_to_Y || output_encoding == E_LINEAR) + { + alpha *= 257; + encoding = E_LINEAR; + } + + else + { + red = PNG_sRGB_FROM_LINEAR(red * 255); + green = PNG_sRGB_FROM_LINEAR(green * 255); + blue = PNG_sRGB_FROM_LINEAR(blue * 255); + encoding = E_sRGB; + } + } + + else if (encoding == E_LINEAR8) + { + /* This encoding occurs quite frequently in test cases because PngSuite + * includes a gAMA 1.0 chunk with most images. + */ + red *= 257; + green *= 257; + blue *= 257; + alpha *= 257; + encoding = E_LINEAR; + } + + else if (encoding == E_sRGB && (convert_to_Y || output_encoding == E_LINEAR)) + { + /* The values are 8-bit sRGB values, but must be converted to 16-bit + * linear. + */ + red = png_sRGB_table[red]; + green = png_sRGB_table[green]; + blue = png_sRGB_table[blue]; + alpha *= 257; + encoding = E_LINEAR; + } + + /* This is set if the color isn't gray but the output is. */ + if (encoding == E_LINEAR) + { + if (convert_to_Y) + { + /* NOTE: these values are copied from png_do_rgb_to_gray */ + png_uint_32 y = (png_uint_32)6968 * red + (png_uint_32)23434 * green + + (png_uint_32)2366 * blue; + + if (output_encoding == E_LINEAR) + y = (y + 16384) >> 15; + + else + { + /* y is scaled by 32768, we need it scaled by 255: */ + y = (y + 128) >> 8; + y *= 255; + y = PNG_sRGB_FROM_LINEAR((y + 64) >> 7); + encoding = E_sRGB; + } + + blue = red = green = y; + } + + else if (output_encoding == E_sRGB) + { + red = PNG_sRGB_FROM_LINEAR(red * 255); + green = PNG_sRGB_FROM_LINEAR(green * 255); + blue = PNG_sRGB_FROM_LINEAR(blue * 255); + alpha = PNG_DIV257(alpha); + encoding = E_sRGB; + } + } + + if (encoding != output_encoding) + png_error(image->opaque->png_ptr, "bad encoding (internal error)"); + + /* Store the value. */ + { +# ifdef PNG_FORMAT_BGR_SUPPORTED + const int afirst = (image->format & PNG_FORMAT_FLAG_AFIRST) != 0 && + (image->format & PNG_FORMAT_FLAG_ALPHA) != 0; +# else +# define afirst 0 +# endif +# ifdef PNG_FORMAT_BGR_SUPPORTED + const int bgr = (image->format & PNG_FORMAT_FLAG_BGR) ? 2 : 0; +# else +# define bgr 0 +# endif + + if (output_encoding == E_LINEAR) + { + png_uint_16p entry = png_voidcast(png_uint_16p, display->colormap); + + entry += ip * PNG_IMAGE_SAMPLE_CHANNELS(image->format); + + /* The linear 16-bit values must be pre-multiplied by the alpha channel + * value, if less than 65535 (this is, effectively, composite on black + * if the alpha channel is removed.) + */ + switch (PNG_IMAGE_SAMPLE_CHANNELS(image->format)) + { + case 4: + entry[afirst ? 0 : 3] = (png_uint_16)alpha; + /* FALL THROUGH */ + + case 3: + if (alpha < 65535) + { + if (alpha > 0) + { + blue = (blue * alpha + 32767U)/65535U; + green = (green * alpha + 32767U)/65535U; + red = (red * alpha + 32767U)/65535U; + } + + else + red = green = blue = 0; + } + entry[afirst + (2 ^ bgr)] = (png_uint_16)blue; + entry[afirst + 1] = (png_uint_16)green; + entry[afirst + bgr] = (png_uint_16)red; + break; + + case 2: + entry[1 ^ afirst] = (png_uint_16)alpha; + /* FALL THROUGH */ + + case 1: + if (alpha < 65535) + { + if (alpha > 0) + green = (green * alpha + 32767U)/65535U; + + else + green = 0; + } + entry[afirst] = (png_uint_16)green; + break; + + default: + break; + } + } + + else /* output encoding is E_sRGB */ + { + png_bytep entry = png_voidcast(png_bytep, display->colormap); + + entry += ip * PNG_IMAGE_SAMPLE_CHANNELS(image->format); + + switch (PNG_IMAGE_SAMPLE_CHANNELS(image->format)) + { + case 4: + entry[afirst ? 0 : 3] = (png_byte)alpha; + case 3: + entry[afirst + (2 ^ bgr)] = (png_byte)blue; + entry[afirst + 1] = (png_byte)green; + entry[afirst + bgr] = (png_byte)red; + break; + + case 2: + entry[1 ^ afirst] = (png_byte)alpha; + case 1: + entry[afirst] = (png_byte)green; + break; + + default: + break; + } + } + +# ifdef afirst +# undef afirst +# endif +# ifdef bgr +# undef bgr +# endif + } +} + +static int +make_gray_file_colormap(png_image_read_control *display) +{ + unsigned int i; + + for (i=0; i<256; ++i) + png_create_colormap_entry(display, i, i, i, i, 255, E_FILE); + + return i; +} + +static int +make_gray_colormap(png_image_read_control *display) +{ + unsigned int i; + + for (i=0; i<256; ++i) + png_create_colormap_entry(display, i, i, i, i, 255, E_sRGB); + + return i; +} +#define PNG_GRAY_COLORMAP_ENTRIES 256 + +static int +make_ga_colormap(png_image_read_control *display) +{ + unsigned int i, a; + + /* Alpha is retained, the output will be a color-map with entries + * selected by six levels of alpha. One transparent entry, 6 gray + * levels for all the intermediate alpha values, leaving 230 entries + * for the opaque grays. The color-map entries are the six values + * [0..5]*51, the GA processing uses PNG_DIV51(value) to find the + * relevant entry. + * + * if (alpha > 229) // opaque + * { + * // The 231 entries are selected to make the math below work: + * base = 0; + * entry = (231 * gray + 128) >> 8; + * } + * else if (alpha < 26) // transparent + * { + * base = 231; + * entry = 0; + * } + * else // partially opaque + * { + * base = 226 + 6 * PNG_DIV51(alpha); + * entry = PNG_DIV51(gray); + * } + */ + i = 0; + while (i < 231) + { + unsigned int gray = (i * 256 + 115) / 231; + png_create_colormap_entry(display, i++, gray, gray, gray, 255, E_sRGB); + } + + /* 255 is used here for the component values for consistency with the code + * that undoes premultiplication in pngwrite.c. + */ + png_create_colormap_entry(display, i++, 255, 255, 255, 0, E_sRGB); + + for (a=1; a<5; ++a) + { + unsigned int g; + + for (g=0; g<6; ++g) + png_create_colormap_entry(display, i++, g*51, g*51, g*51, a*51, + E_sRGB); + } + + return i; +} + +#define PNG_GA_COLORMAP_ENTRIES 256 + +static int +make_rgb_colormap(png_image_read_control *display) +{ + unsigned int i, r; + + /* Build a 6x6x6 opaque RGB cube */ + for (i=r=0; r<6; ++r) + { + unsigned int g; + + for (g=0; g<6; ++g) + { + unsigned int b; + + for (b=0; b<6; ++b) + png_create_colormap_entry(display, i++, r*51, g*51, b*51, 255, + E_sRGB); + } + } + + return i; +} + +#define PNG_RGB_COLORMAP_ENTRIES 216 + +/* Return a palette index to the above palette given three 8-bit sRGB values. */ +#define PNG_RGB_INDEX(r,g,b) \ + ((png_byte)(6 * (6 * PNG_DIV51(r) + PNG_DIV51(g)) + PNG_DIV51(b))) + +static int +png_image_read_colormap(png_voidp argument) +{ + png_image_read_control *display = + png_voidcast(png_image_read_control*, argument); + const png_imagep image = display->image; + + const png_structrp png_ptr = image->opaque->png_ptr; + const png_uint_32 output_format = image->format; + const int output_encoding = (output_format & PNG_FORMAT_FLAG_LINEAR) ? + E_LINEAR : E_sRGB; + + unsigned int cmap_entries; + unsigned int output_processing; /* Output processing option */ + unsigned int data_encoding = E_NOTSET; /* Encoding libpng must produce */ + + /* Background information; the background color and the index of this color + * in the color-map if it exists (else 256). + */ + unsigned int background_index = 256; + png_uint_32 back_r, back_g, back_b; + + /* Flags to accumulate things that need to be done to the input. */ + int expand_tRNS = 0; + + /* Exclude the NYI feature of compositing onto a color-mapped buffer; it is + * very difficult to do, the results look awful, and it is difficult to see + * what possible use it is because the application can't control the + * color-map. + */ + if (((png_ptr->color_type & PNG_COLOR_MASK_ALPHA) != 0 || + png_ptr->num_trans > 0) /* alpha in input */ && + ((output_format & PNG_FORMAT_FLAG_ALPHA) == 0) /* no alpha in output */) + { + if (output_encoding == E_LINEAR) /* compose on black */ + back_b = back_g = back_r = 0; + + else if (display->background == NULL /* no way to remove it */) + png_error(png_ptr, + "a background color must be supplied to remove alpha/transparency"); + + /* Get a copy of the background color (this avoids repeating the checks + * below.) The encoding is 8-bit sRGB or 16-bit linear, depending on the + * output format. + */ + else + { + back_g = display->background->green; + if (output_format & PNG_FORMAT_FLAG_COLOR) + { + back_r = display->background->red; + back_b = display->background->blue; + } + else + back_b = back_r = back_g; + } + } + + else if (output_encoding == E_LINEAR) + back_b = back_r = back_g = 65535; + + else + back_b = back_r = back_g = 255; + + /* Default the input file gamma if required - this is necessary because + * libpng assumes that if no gamma information is present the data is in the + * output format, but the simplified API deduces the gamma from the input + * format. + */ + if ((png_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_GAMMA) == 0) + { + /* Do this directly, not using the png_colorspace functions, to ensure + * that it happens even if the colorspace is invalid (though probably if + * it is the setting will be ignored) Note that the same thing can be + * achieved at the application interface with png_set_gAMA. + */ + if (png_ptr->bit_depth == 16 && + (image->flags & PNG_IMAGE_FLAG_16BIT_sRGB) == 0) + png_ptr->colorspace.gamma = PNG_GAMMA_LINEAR; + + else + png_ptr->colorspace.gamma = PNG_GAMMA_sRGB_INVERSE; + + png_ptr->colorspace.flags |= PNG_COLORSPACE_HAVE_GAMMA; + } + + /* Decide what to do based on the PNG color type of the input data. The + * utility function png_create_colormap_entry deals with most aspects of the + * output transformations; this code works out how to produce bytes of + * color-map entries from the original format. + */ + switch (png_ptr->color_type) + { + case PNG_COLOR_TYPE_GRAY: + if (png_ptr->bit_depth <= 8) + { + /* There at most 256 colors in the output, regardless of + * transparency. + */ + unsigned int step, i, val, trans = 256/*ignore*/, back_alpha = 0; + + cmap_entries = 1U << png_ptr->bit_depth; + if (cmap_entries > image->colormap_entries) + png_error(png_ptr, "gray[8] color-map: too few entries"); + + step = 255 / (cmap_entries - 1); + output_processing = PNG_CMAP_NONE; + + /* If there is a tRNS chunk then this either selects a transparent + * value or, if the output has no alpha, the background color. + */ + if (png_ptr->num_trans > 0) + { + trans = png_ptr->trans_color.gray; + + if ((output_format & PNG_FORMAT_FLAG_ALPHA) == 0) + back_alpha = output_encoding == E_LINEAR ? 65535 : 255; + } + + /* png_create_colormap_entry just takes an RGBA and writes the + * corresponding color-map entry using the format from 'image', + * including the required conversion to sRGB or linear as + * appropriate. The input values are always either sRGB (if the + * gamma correction flag is 0) or 0..255 scaled file encoded values + * (if the function must gamma correct them). + */ + for (i=val=0; ibit_depth < 8) + png_set_packing(png_ptr); + } + + else /* bit depth is 16 */ + { + /* The 16-bit input values can be converted directly to 8-bit gamma + * encoded values; however, if a tRNS chunk is present 257 color-map + * entries are required. This means that the extra entry requires + * special processing; add an alpha channel, sacrifice gray level + * 254 and convert transparent (alpha==0) entries to that. + * + * Use libpng to chop the data to 8 bits. Convert it to sRGB at the + * same time to minimize quality loss. If a tRNS chunk is present + * this means libpng must handle it too; otherwise it is impossible + * to do the exact match on the 16-bit value. + * + * If the output has no alpha channel *and* the background color is + * gray then it is possible to let libpng handle the substitution by + * ensuring that the corresponding gray level matches the background + * color exactly. + */ + data_encoding = E_sRGB; + + if (PNG_GRAY_COLORMAP_ENTRIES > image->colormap_entries) + png_error(png_ptr, "gray[16] color-map: too few entries"); + + cmap_entries = make_gray_colormap(display); + + if (png_ptr->num_trans > 0) + { + unsigned int back_alpha; + + if (output_format & PNG_FORMAT_FLAG_ALPHA) + back_alpha = 0; + + else + { + if (back_r == back_g && back_g == back_b) + { + /* Background is gray; no special processing will be + * required. + */ + png_color_16 c; + png_uint_32 gray = back_g; + + if (output_encoding == E_LINEAR) + { + gray = PNG_sRGB_FROM_LINEAR(gray * 255); + + /* And make sure the corresponding palette entry + * matches. + */ + png_create_colormap_entry(display, gray, back_g, back_g, + back_g, 65535, E_LINEAR); + } + + /* The background passed to libpng, however, must be the + * sRGB value. + */ + c.index = 0; /*unused*/ + c.gray = c.red = c.green = c.blue = (png_uint_16)gray; + + /* NOTE: does this work without expanding tRNS to alpha? + * It should be the color->gray case below apparently + * doesn't. + */ + png_set_background_fixed(png_ptr, &c, + PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/, + 0/*gamma: not used*/); + + output_processing = PNG_CMAP_NONE; + break; + } + + back_alpha = output_encoding == E_LINEAR ? 65535 : 255; + } + + /* output_processing means that the libpng-processed row will be + * 8-bit GA and it has to be processing to single byte color-map + * values. Entry 254 is replaced by either a completely + * transparent entry or by the background color at full + * precision (and the background color is not a simple gray leve + * in this case.) + */ + expand_tRNS = 1; + output_processing = PNG_CMAP_TRANS; + background_index = 254; + + /* And set (overwrite) color-map entry 254 to the actual + * background color at full precision. + */ + png_create_colormap_entry(display, 254, back_r, back_g, back_b, + back_alpha, output_encoding); + } + + else + output_processing = PNG_CMAP_NONE; + } + break; + + case PNG_COLOR_TYPE_GRAY_ALPHA: + /* 8-bit or 16-bit PNG with two channels - gray and alpha. A minimum + * of 65536 combinations. If, however, the alpha channel is to be + * removed there are only 256 possibilities if the background is gray. + * (Otherwise there is a subset of the 65536 possibilities defined by + * the triangle between black, white and the background color.) + * + * Reduce 16-bit files to 8-bit and sRGB encode the result. No need to + * worry about tRNS matching - tRNS is ignored if there is an alpha + * channel. + */ + data_encoding = E_sRGB; + + if (output_format & PNG_FORMAT_FLAG_ALPHA) + { + if (PNG_GA_COLORMAP_ENTRIES > image->colormap_entries) + png_error(png_ptr, "gray+alpha color-map: too few entries"); + + cmap_entries = make_ga_colormap(display); + + background_index = PNG_CMAP_GA_BACKGROUND; + output_processing = PNG_CMAP_GA; + } + + else /* alpha is removed */ + { + /* Alpha must be removed as the PNG data is processed when the + * background is a color because the G and A channels are + * independent and the vector addition (non-parallel vectors) is a + * 2-D problem. + * + * This can be reduced to the same algorithm as above by making a + * colormap containing gray levels (for the opaque grays), a + * background entry (for a transparent pixel) and a set of four six + * level color values, one set for each intermediate alpha value. + * See the comments in make_ga_colormap for how this works in the + * per-pixel processing. + * + * If the background is gray, however, we only need a 256 entry gray + * level color map. It is sufficient to make the entry generated + * for the background color be exactly the color specified. + */ + if ((output_format & PNG_FORMAT_FLAG_COLOR) == 0 || + (back_r == back_g && back_g == back_b)) + { + /* Background is gray; no special processing will be required. */ + png_color_16 c; + png_uint_32 gray = back_g; + + if (PNG_GRAY_COLORMAP_ENTRIES > image->colormap_entries) + png_error(png_ptr, "gray-alpha color-map: too few entries"); + + cmap_entries = make_gray_colormap(display); + + if (output_encoding == E_LINEAR) + { + gray = PNG_sRGB_FROM_LINEAR(gray * 255); + + /* And make sure the corresponding palette entry matches. */ + png_create_colormap_entry(display, gray, back_g, back_g, + back_g, 65535, E_LINEAR); + } + + /* The background passed to libpng, however, must be the sRGB + * value. + */ + c.index = 0; /*unused*/ + c.gray = c.red = c.green = c.blue = (png_uint_16)gray; + + png_set_background_fixed(png_ptr, &c, + PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/, + 0/*gamma: not used*/); + + output_processing = PNG_CMAP_NONE; + } + + else + { + png_uint_32 i, a; + + /* This is the same as png_make_ga_colormap, above, except that + * the entries are all opaque. + */ + if (PNG_GA_COLORMAP_ENTRIES > image->colormap_entries) + png_error(png_ptr, "ga-alpha color-map: too few entries"); + + i = 0; + while (i < 231) + { + png_uint_32 gray = (i * 256 + 115) / 231; + png_create_colormap_entry(display, i++, gray, gray, gray, + 255, E_sRGB); + } + + /* NOTE: this preserves the full precision of the application + * background color. + */ + background_index = i; + png_create_colormap_entry(display, i++, back_r, back_g, back_b, + output_encoding == E_LINEAR ? 65535U : 255U, output_encoding); + + /* For non-opaque input composite on the sRGB background - this + * requires inverting the encoding for each component. The input + * is still converted to the sRGB encoding because this is a + * reasonable approximate to the logarithmic curve of human + * visual sensitivity, at least over the narrow range which PNG + * represents. Consequently 'G' is always sRGB encoded, while + * 'A' is linear. We need the linear background colors. + */ + if (output_encoding == E_sRGB) /* else already linear */ + { + /* This may produce a value not exactly matching the + * background, but that's ok because these numbers are only + * used when alpha != 0 + */ + back_r = png_sRGB_table[back_r]; + back_g = png_sRGB_table[back_g]; + back_b = png_sRGB_table[back_b]; + } + + for (a=1; a<5; ++a) + { + unsigned int g; + + /* PNG_sRGB_FROM_LINEAR expects a 16-bit linear value scaled + * by an 8-bit alpha value (0..255). + */ + png_uint_32 alpha = 51 * a; + png_uint_32 back_rx = (255-alpha) * back_r; + png_uint_32 back_gx = (255-alpha) * back_g; + png_uint_32 back_bx = (255-alpha) * back_b; + + for (g=0; g<6; ++g) + { + png_uint_32 gray = png_sRGB_table[g*51] * alpha; + + png_create_colormap_entry(display, i++, + PNG_sRGB_FROM_LINEAR(gray + back_rx), + PNG_sRGB_FROM_LINEAR(gray + back_gx), + PNG_sRGB_FROM_LINEAR(gray + back_bx), 255, E_sRGB); + } + } + + cmap_entries = i; + output_processing = PNG_CMAP_GA; + } + } + break; + + case PNG_COLOR_TYPE_RGB: + case PNG_COLOR_TYPE_RGB_ALPHA: + /* Exclude the case where the output is gray; we can always handle this + * with the cases above. + */ + if ((output_format & PNG_FORMAT_FLAG_COLOR) == 0) + { + /* The color-map will be grayscale, so we may as well convert the + * input RGB values to a simple grayscale and use the grayscale + * code above. + * + * NOTE: calling this apparently damages the recognition of the + * transparent color in background color handling; call + * png_set_tRNS_to_alpha before png_set_background_fixed. + */ + png_set_rgb_to_gray_fixed(png_ptr, PNG_ERROR_ACTION_NONE, -1, + -1); + data_encoding = E_sRGB; + + /* The output will now be one or two 8-bit gray or gray+alpha + * channels. The more complex case arises when the input has alpha. + */ + if ((png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA || + png_ptr->num_trans > 0) && + (output_format & PNG_FORMAT_FLAG_ALPHA) != 0) + { + /* Both input and output have an alpha channel, so no background + * processing is required; just map the GA bytes to the right + * color-map entry. + */ + expand_tRNS = 1; + + if (PNG_GA_COLORMAP_ENTRIES > image->colormap_entries) + png_error(png_ptr, "rgb[ga] color-map: too few entries"); + + cmap_entries = make_ga_colormap(display); + background_index = PNG_CMAP_GA_BACKGROUND; + output_processing = PNG_CMAP_GA; + } + + else + { + /* Either the input or the output has no alpha channel, so there + * will be no non-opaque pixels in the color-map; it will just be + * grayscale. + */ + if (PNG_GRAY_COLORMAP_ENTRIES > image->colormap_entries) + png_error(png_ptr, "rgb[gray] color-map: too few entries"); + + /* Ideally this code would use libpng to do the gamma correction, + * but if an input alpha channel is to be removed we will hit the + * libpng bug in gamma+compose+rgb-to-gray (the double gamma + * correction bug). Fix this by dropping the gamma correction in + * this case and doing it in the palette; this will result in + * duplicate palette entries, but that's better than the + * alternative of double gamma correction. + */ + if ((png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA || + png_ptr->num_trans > 0) && + png_gamma_not_sRGB(png_ptr->colorspace.gamma)) + { + cmap_entries = make_gray_file_colormap(display); + data_encoding = E_FILE; + } + + else + cmap_entries = make_gray_colormap(display); + + /* But if the input has alpha or transparency it must be removed + */ + if (png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA || + png_ptr->num_trans > 0) + { + png_color_16 c; + png_uint_32 gray = back_g; + + /* We need to ensure that the application background exists in + * the colormap and that completely transparent pixels map to + * it. Achieve this simply by ensuring that the entry + * selected for the background really is the background color. + */ + if (data_encoding == E_FILE) /* from the fixup above */ + { + /* The app supplied a gray which is in output_encoding, we + * need to convert it to a value of the input (E_FILE) + * encoding then set this palette entry to the required + * output encoding. + */ + if (output_encoding == E_sRGB) + gray = png_sRGB_table[gray]; /* now E_LINEAR */ + + gray = PNG_DIV257(png_gamma_16bit_correct(gray, + png_ptr->colorspace.gamma)); /* now E_FILE */ + + /* And make sure the corresponding palette entry contains + * exactly the required sRGB value. + */ + png_create_colormap_entry(display, gray, back_g, back_g, + back_g, 0/*unused*/, output_encoding); + } + + else if (output_encoding == E_LINEAR) + { + gray = PNG_sRGB_FROM_LINEAR(gray * 255); + + /* And make sure the corresponding palette entry matches. + */ + png_create_colormap_entry(display, gray, back_g, back_g, + back_g, 0/*unused*/, E_LINEAR); + } + + /* The background passed to libpng, however, must be the + * output (normally sRGB) value. + */ + c.index = 0; /*unused*/ + c.gray = c.red = c.green = c.blue = (png_uint_16)gray; + + /* NOTE: the following is apparently a bug in libpng. Without + * it the transparent color recognition in + * png_set_background_fixed seems to go wrong. + */ + expand_tRNS = 1; + png_set_background_fixed(png_ptr, &c, + PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/, + 0/*gamma: not used*/); + } + + output_processing = PNG_CMAP_NONE; + } + } + + else /* output is color */ + { + /* We could use png_quantize here so long as there is no transparent + * color or alpha; png_quantize ignores alpha. Easier overall just + * to do it once and using PNG_DIV51 on the 6x6x6 reduced RGB cube. + * Consequently we always want libpng to produce sRGB data. + */ + data_encoding = E_sRGB; + + /* Is there any transparency or alpha? */ + if (png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA || + png_ptr->num_trans > 0) + { + /* Is there alpha in the output too? If so all four channels are + * processed into a special RGB cube with alpha support. + */ + if (output_format & PNG_FORMAT_FLAG_ALPHA) + { + png_uint_32 r; + + if (PNG_RGB_COLORMAP_ENTRIES+1+27 > image->colormap_entries) + png_error(png_ptr, "rgb+alpha color-map: too few entries"); + + cmap_entries = make_rgb_colormap(display); + + /* Add a transparent entry. */ + png_create_colormap_entry(display, cmap_entries, 255, 255, + 255, 0, E_sRGB); + + /* This is stored as the background index for the processing + * algorithm. + */ + background_index = cmap_entries++; + + /* Add 27 r,g,b entries each with alpha 0.5. */ + for (r=0; r<256; r = (r << 1) | 0x7f) + { + png_uint_32 g; + + for (g=0; g<256; g = (g << 1) | 0x7f) + { + png_uint_32 b; + + /* This generates components with the values 0, 127 and + * 255 + */ + for (b=0; b<256; b = (b << 1) | 0x7f) + png_create_colormap_entry(display, cmap_entries++, + r, g, b, 128, E_sRGB); + } + } + + expand_tRNS = 1; + output_processing = PNG_CMAP_RGB_ALPHA; + } + + else + { + /* Alpha/transparency must be removed. The background must + * exist in the color map (achieved by setting adding it after + * the 666 color-map). If the standard processing code will + * pick up this entry automatically that's all that is + * required; libpng can be called to do the background + * processing. + */ + unsigned int sample_size = + PNG_IMAGE_SAMPLE_SIZE(output_format); + png_uint_32 r, g, b; /* sRGB background */ + + if (PNG_RGB_COLORMAP_ENTRIES+1+27 > image->colormap_entries) + png_error(png_ptr, "rgb-alpha color-map: too few entries"); + + cmap_entries = make_rgb_colormap(display); + + png_create_colormap_entry(display, cmap_entries, back_r, + back_g, back_b, 0/*unused*/, output_encoding); + + if (output_encoding == E_LINEAR) + { + r = PNG_sRGB_FROM_LINEAR(back_r * 255); + g = PNG_sRGB_FROM_LINEAR(back_g * 255); + b = PNG_sRGB_FROM_LINEAR(back_b * 255); + } + + else + { + r = back_r; + g = back_g; + b = back_g; + } + + /* Compare the newly-created color-map entry with the one the + * PNG_CMAP_RGB algorithm will use. If the two entries don't + * match, add the new one and set this as the background + * index. + */ + if (memcmp((png_const_bytep)display->colormap + + sample_size * cmap_entries, + (png_const_bytep)display->colormap + + sample_size * PNG_RGB_INDEX(r,g,b), + sample_size) != 0) + { + /* The background color must be added. */ + background_index = cmap_entries++; + + /* Add 27 r,g,b entries each with created by composing with + * the background at alpha 0.5. + */ + for (r=0; r<256; r = (r << 1) | 0x7f) + { + for (g=0; g<256; g = (g << 1) | 0x7f) + { + /* This generates components with the values 0, 127 + * and 255 + */ + for (b=0; b<256; b = (b << 1) | 0x7f) + png_create_colormap_entry(display, cmap_entries++, + png_colormap_compose(display, r, E_sRGB, 128, + back_r, output_encoding), + png_colormap_compose(display, g, E_sRGB, 128, + back_g, output_encoding), + png_colormap_compose(display, b, E_sRGB, 128, + back_b, output_encoding), + 0/*unused*/, output_encoding); + } + } + + expand_tRNS = 1; + output_processing = PNG_CMAP_RGB_ALPHA; + } + + else /* background color is in the standard color-map */ + { + png_color_16 c; + + c.index = 0; /*unused*/ + c.red = (png_uint_16)back_r; + c.gray = c.green = (png_uint_16)back_g; + c.blue = (png_uint_16)back_b; + + png_set_background_fixed(png_ptr, &c, + PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/, + 0/*gamma: not used*/); + + output_processing = PNG_CMAP_RGB; + } + } + } + + else /* no alpha or transparency in the input */ + { + /* Alpha in the output is irrelevant, simply map the opaque input + * pixels to the 6x6x6 color-map. + */ + if (PNG_RGB_COLORMAP_ENTRIES > image->colormap_entries) + png_error(png_ptr, "rgb color-map: too few entries"); + + cmap_entries = make_rgb_colormap(display); + output_processing = PNG_CMAP_RGB; + } + } + break; + + case PNG_COLOR_TYPE_PALETTE: + /* It's already got a color-map. It may be necessary to eliminate the + * tRNS entries though. + */ + { + unsigned int num_trans = png_ptr->num_trans; + png_const_bytep trans = num_trans > 0 ? png_ptr->trans_alpha : NULL; + png_const_colorp colormap = png_ptr->palette; + const int do_background = trans != NULL && + (output_format & PNG_FORMAT_FLAG_ALPHA) == 0; + unsigned int i; + + /* Just in case: */ + if (trans == NULL) + num_trans = 0; + + output_processing = PNG_CMAP_NONE; + data_encoding = E_FILE; /* Don't change from color-map indicies */ + cmap_entries = png_ptr->num_palette; + if (cmap_entries > 256) + cmap_entries = 256; + + if (cmap_entries > image->colormap_entries) + png_error(png_ptr, "palette color-map: too few entries"); + + for (i=0; i < cmap_entries; ++i) + { + if (do_background && i < num_trans && trans[i] < 255) + { + if (trans[i] == 0) + png_create_colormap_entry(display, i, back_r, back_g, + back_b, 0, output_encoding); + + else + { + /* Must compose the PNG file color in the color-map entry + * on the sRGB color in 'back'. + */ + png_create_colormap_entry(display, i, + png_colormap_compose(display, colormap[i].red, E_FILE, + trans[i], back_r, output_encoding), + png_colormap_compose(display, colormap[i].green, E_FILE, + trans[i], back_g, output_encoding), + png_colormap_compose(display, colormap[i].blue, E_FILE, + trans[i], back_b, output_encoding), + output_encoding == E_LINEAR ? trans[i] * 257U : + trans[i], + output_encoding); + } + } + + else + png_create_colormap_entry(display, i, colormap[i].red, + colormap[i].green, colormap[i].blue, + i < num_trans ? trans[i] : 255U, E_FILE/*8-bit*/); + } + + /* The PNG data may have indicies packed in fewer than 8 bits, it + * must be expanded if so. + */ + if (png_ptr->bit_depth < 8) + png_set_packing(png_ptr); + } + break; + + default: + png_error(png_ptr, "invalid PNG color type"); + /*NOT REACHED*/ + break; + } + + /* Now deal with the output processing */ + if (expand_tRNS && png_ptr->num_trans > 0 && + (png_ptr->color_type & PNG_COLOR_MASK_ALPHA) == 0) + png_set_tRNS_to_alpha(png_ptr); + + switch (data_encoding) + { + default: + png_error(png_ptr, "bad data option (internal error)"); + break; + + case E_sRGB: + /* Change to 8-bit sRGB */ + png_set_alpha_mode_fixed(png_ptr, PNG_ALPHA_PNG, PNG_GAMMA_sRGB); + /* FALL THROUGH */ + + case E_FILE: + if (png_ptr->bit_depth > 8) + png_set_scale_16(png_ptr); + break; + } + + if (cmap_entries > 256 || cmap_entries > image->colormap_entries) + png_error(png_ptr, "color map overflow (BAD internal error)"); + + image->colormap_entries = cmap_entries; + + /* Double check using the recorded background index */ + switch (output_processing) + { + case PNG_CMAP_NONE: + if (background_index != PNG_CMAP_NONE_BACKGROUND) + goto bad_background; + break; + + case PNG_CMAP_GA: + if (background_index != PNG_CMAP_GA_BACKGROUND) + goto bad_background; + break; + + case PNG_CMAP_TRANS: + if (background_index >= cmap_entries || + background_index != PNG_CMAP_TRANS_BACKGROUND) + goto bad_background; + break; + + case PNG_CMAP_RGB: + if (background_index != PNG_CMAP_RGB_BACKGROUND) + goto bad_background; + break; + + case PNG_CMAP_RGB_ALPHA: + if (background_index != PNG_CMAP_RGB_ALPHA_BACKGROUND) + goto bad_background; + break; + + default: + png_error(png_ptr, "bad processing option (internal error)"); + + bad_background: + png_error(png_ptr, "bad background index (internal error)"); + } + + display->colormap_processing = output_processing; + + return 1/*ok*/; +} + +/* The final part of the color-map read called from png_image_finish_read. */ +static int +png_image_read_and_map(png_voidp argument) +{ + png_image_read_control *display = png_voidcast(png_image_read_control*, + argument); + png_imagep image = display->image; + png_structrp png_ptr = image->opaque->png_ptr; + int passes; + + /* Called when the libpng data must be transformed into the color-mapped + * form. There is a local row buffer in display->local and this routine must + * do the interlace handling. + */ + switch (png_ptr->interlaced) + { + case PNG_INTERLACE_NONE: + passes = 1; + break; + + case PNG_INTERLACE_ADAM7: + passes = PNG_INTERLACE_ADAM7_PASSES; + break; + + default: + passes = 0; + png_error(png_ptr, "unknown interlace type"); + } + + { + png_uint_32 height = image->height; + png_uint_32 width = image->width; + int proc = display->colormap_processing; + png_bytep first_row = png_voidcast(png_bytep, display->first_row); + ptrdiff_t step_row = display->row_bytes; + int pass; + + for (pass = 0; pass < passes; ++pass) + { + unsigned int startx, stepx, stepy; + png_uint_32 y; + + if (png_ptr->interlaced == PNG_INTERLACE_ADAM7) + { + /* The row may be empty for a short image: */ + if (PNG_PASS_COLS(width, pass) == 0) + continue; + + startx = PNG_PASS_START_COL(pass); + stepx = PNG_PASS_COL_OFFSET(pass); + y = PNG_PASS_START_ROW(pass); + stepy = PNG_PASS_ROW_OFFSET(pass); + } + + else + { + y = 0; + startx = 0; + stepx = stepy = 1; + } + + for (; ylocal_row); + png_bytep outrow = first_row + y * step_row; + png_const_bytep end_row = outrow + width; + + /* Read read the libpng data into the temporary buffer. */ + png_read_row(png_ptr, inrow, NULL); + + /* Now process the row according to the processing option, note + * that the caller verifies that the format of the libpng output + * data is as required. + */ + outrow += startx; + switch (proc) + { + case PNG_CMAP_GA: + for (; outrow < end_row; outrow += stepx) + { + /* The data is always in the PNG order */ + unsigned int gray = *inrow++; + unsigned int alpha = *inrow++; + unsigned int entry; + + /* NOTE: this code is copied as a comment in + * make_ga_colormap above. Please update the + * comment if you change this code! + */ + if (alpha > 229) /* opaque */ + { + entry = (231 * gray + 128) >> 8; + } + else if (alpha < 26) /* transparent */ + { + entry = 231; + } + else /* partially opaque */ + { + entry = 226 + 6 * PNG_DIV51(alpha) + PNG_DIV51(gray); + } + + *outrow = (png_byte)entry; + } + break; + + case PNG_CMAP_TRANS: + for (; outrow < end_row; outrow += stepx) + { + png_byte gray = *inrow++; + png_byte alpha = *inrow++; + + if (alpha == 0) + *outrow = PNG_CMAP_TRANS_BACKGROUND; + + else if (gray != PNG_CMAP_TRANS_BACKGROUND) + *outrow = gray; + + else + *outrow = (png_byte)(PNG_CMAP_TRANS_BACKGROUND+1); + } + break; + + case PNG_CMAP_RGB: + for (; outrow < end_row; outrow += stepx) + { + *outrow = PNG_RGB_INDEX(inrow[0], inrow[1], inrow[2]); + inrow += 3; + } + break; + + case PNG_CMAP_RGB_ALPHA: + for (; outrow < end_row; outrow += stepx) + { + unsigned int alpha = inrow[3]; + + /* Because the alpha entries only hold alpha==0.5 values + * split the processing at alpha==0.25 (64) and 0.75 + * (196). + */ + + if (alpha >= 196) + *outrow = PNG_RGB_INDEX(inrow[0], inrow[1], + inrow[2]); + + else if (alpha < 64) + *outrow = PNG_CMAP_RGB_ALPHA_BACKGROUND; + + else + { + /* Likewise there are three entries for each of r, g + * and b. We could select the entry by popcount on + * the top two bits on those architectures that + * support it, this is what the code below does, + * crudely. + */ + unsigned int back_i = PNG_CMAP_RGB_ALPHA_BACKGROUND+1; + + /* Here are how the values map: + * + * 0x00 .. 0x3f -> 0 + * 0x40 .. 0xbf -> 1 + * 0xc0 .. 0xff -> 2 + * + * So, as above with the explicit alpha checks, the + * breakpoints are at 64 and 196. + */ + if (inrow[0] & 0x80) back_i += 9; /* red */ + if (inrow[0] & 0x40) back_i += 9; + if (inrow[0] & 0x80) back_i += 3; /* green */ + if (inrow[0] & 0x40) back_i += 3; + if (inrow[0] & 0x80) back_i += 1; /* blue */ + if (inrow[0] & 0x40) back_i += 1; + + *outrow = (png_byte)back_i; + } + + inrow += 4; + } + break; + + default: + break; + } + } + } + } + + return 1; +} + +static int +png_image_read_colormapped(png_voidp argument) +{ + png_image_read_control *display = png_voidcast(png_image_read_control*, + argument); + png_imagep image = display->image; + png_controlp control = image->opaque; + png_structrp png_ptr = control->png_ptr; + png_inforp info_ptr = control->info_ptr; + + int passes = 0; /* As a flag */ + + PNG_SKIP_CHUNKS(png_ptr); + + /* Update the 'info' structure and make sure the result is as required; first + * make sure to turn on the interlace handling if it will be required + * (because it can't be turned on *after* the call to png_read_update_info!) + */ + if (display->colormap_processing == PNG_CMAP_NONE) + passes = png_set_interlace_handling(png_ptr); + + png_read_update_info(png_ptr, info_ptr); + + /* The expected output can be deduced from the colormap_processing option. */ + switch (display->colormap_processing) + { + case PNG_CMAP_NONE: + /* Output must be one channel and one byte per pixel, the output + * encoding can be anything. + */ + if ((info_ptr->color_type == PNG_COLOR_TYPE_PALETTE || + info_ptr->color_type == PNG_COLOR_TYPE_GRAY) && + info_ptr->bit_depth == 8) + break; + + goto bad_output; + + case PNG_CMAP_TRANS: + case PNG_CMAP_GA: + /* Output must be two channels and the 'G' one must be sRGB, the latter + * can be checked with an exact number because it should have been set + * to this number above! + */ + if (info_ptr->color_type == PNG_COLOR_TYPE_GRAY_ALPHA && + info_ptr->bit_depth == 8 && + png_ptr->screen_gamma == PNG_GAMMA_sRGB && + image->colormap_entries == 256) + break; + + goto bad_output; + + case PNG_CMAP_RGB: + /* Output must be 8-bit sRGB encoded RGB */ + if (info_ptr->color_type == PNG_COLOR_TYPE_RGB && + info_ptr->bit_depth == 8 && + png_ptr->screen_gamma == PNG_GAMMA_sRGB && + image->colormap_entries == 216) + break; + + goto bad_output; + + case PNG_CMAP_RGB_ALPHA: + /* Output must be 8-bit sRGB encoded RGBA */ + if (info_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA && + info_ptr->bit_depth == 8 && + png_ptr->screen_gamma == PNG_GAMMA_sRGB && + image->colormap_entries == 244 /* 216 + 1 + 27 */) + break; + + /* goto bad_output; */ + /* FALL THROUGH */ + + default: + bad_output: + png_error(png_ptr, "bad color-map processing (internal error)"); + } + + /* Now read the rows. Do this here if it is possible to read directly into + * the output buffer, otherwise allocate a local row buffer of the maximum + * size libpng requires and call the relevant processing routine safely. + */ + { + png_voidp first_row = display->buffer; + ptrdiff_t row_bytes = display->row_stride; + + /* The following expression is designed to work correctly whether it gives + * a signed or an unsigned result. + */ + if (row_bytes < 0) + { + char *ptr = png_voidcast(char*, first_row); + ptr += (image->height-1) * (-row_bytes); + first_row = png_voidcast(png_voidp, ptr); + } + + display->first_row = first_row; + display->row_bytes = row_bytes; + } + + if (passes == 0) + { + int result; + png_voidp row = png_malloc(png_ptr, png_get_rowbytes(png_ptr, info_ptr)); + + display->local_row = row; + result = png_safe_execute(image, png_image_read_and_map, display); + display->local_row = NULL; + png_free(png_ptr, row); + + return result; + } + + else + { + png_alloc_size_t row_bytes = display->row_bytes; + + while (--passes >= 0) + { + png_uint_32 y = image->height; + png_bytep row = png_voidcast(png_bytep, display->first_row); + + while (y-- > 0) + { + png_read_row(png_ptr, row, NULL); + row += row_bytes; + } + } + + return 1; + } +} + +/* Just the row reading part of png_image_read. */ +static int +png_image_read_composite(png_voidp argument) +{ + png_image_read_control *display = png_voidcast(png_image_read_control*, + argument); + png_imagep image = display->image; + png_structrp png_ptr = image->opaque->png_ptr; + int passes; + + switch (png_ptr->interlaced) + { + case PNG_INTERLACE_NONE: + passes = 1; + break; + + case PNG_INTERLACE_ADAM7: + passes = PNG_INTERLACE_ADAM7_PASSES; + break; + + default: + passes = 0; + png_error(png_ptr, "unknown interlace type"); + } + + { + png_uint_32 height = image->height; + png_uint_32 width = image->width; + ptrdiff_t step_row = display->row_bytes; + unsigned int channels = (image->format & PNG_FORMAT_FLAG_COLOR) ? 3 : 1; + int pass; + + for (pass = 0; pass < passes; ++pass) + { + unsigned int startx, stepx, stepy; + png_uint_32 y; + + if (png_ptr->interlaced == PNG_INTERLACE_ADAM7) + { + /* The row may be empty for a short image: */ + if (PNG_PASS_COLS(width, pass) == 0) + continue; + + startx = PNG_PASS_START_COL(pass) * channels; + stepx = PNG_PASS_COL_OFFSET(pass) * channels; + y = PNG_PASS_START_ROW(pass); + stepy = PNG_PASS_ROW_OFFSET(pass); + } + + else + { + y = 0; + startx = 0; + stepx = channels; + stepy = 1; + } + + for (; ylocal_row); + png_bytep outrow; + png_const_bytep end_row; + + /* Read the row, which is packed: */ + png_read_row(png_ptr, inrow, NULL); + + outrow = png_voidcast(png_bytep, display->first_row); + outrow += y * step_row; + end_row = outrow + width * channels; + + /* Now do the composition on each pixel in this row. */ + outrow += startx; + for (; outrow < end_row; outrow += stepx) + { + png_byte alpha = inrow[channels]; + + if (alpha > 0) /* else no change to the output */ + { + unsigned int c; + + for (c=0; cimage; + png_structrp png_ptr = image->opaque->png_ptr; + png_inforp info_ptr = image->opaque->info_ptr; + png_uint_32 height = image->height; + png_uint_32 width = image->width; + int pass, passes; + + /* Double check the convoluted logic below. We expect to get here with + * libpng doing rgb to gray and gamma correction but background processing + * left to the png_image_read_background function. The rows libpng produce + * might be 8 or 16-bit but should always have two channels; gray plus alpha. + */ + if ((png_ptr->transformations & PNG_RGB_TO_GRAY) == 0) + png_error(png_ptr, "lost rgb to gray"); + + if ((png_ptr->transformations & PNG_COMPOSE) != 0) + png_error(png_ptr, "unexpected compose"); + + if (png_get_channels(png_ptr, info_ptr) != 2) + png_error(png_ptr, "lost/gained channels"); + + /* Expect the 8-bit case to always remove the alpha channel */ + if ((image->format & PNG_FORMAT_FLAG_LINEAR) == 0 && + (image->format & PNG_FORMAT_FLAG_ALPHA) != 0) + png_error(png_ptr, "unexpected 8-bit transformation"); + + switch (png_ptr->interlaced) + { + case PNG_INTERLACE_NONE: + passes = 1; + break; + + case PNG_INTERLACE_ADAM7: + passes = PNG_INTERLACE_ADAM7_PASSES; + break; + + default: + passes = 0; + png_error(png_ptr, "unknown interlace type"); + } + + switch (png_get_bit_depth(png_ptr, info_ptr)) + { + default: + png_error(png_ptr, "unexpected bit depth"); + break; + + case 8: + /* 8-bit sRGB gray values with an alpha channel; the alpha channel is + * to be removed by composing on a background: either the row if + * display->background is NULL or display->background->green if not. + * Unlike the code above ALPHA_OPTIMIZED has *not* been done. + */ + { + png_bytep first_row = png_voidcast(png_bytep, display->first_row); + ptrdiff_t step_row = display->row_bytes; + + for (pass = 0; pass < passes; ++pass) + { + png_bytep row = png_voidcast(png_bytep, + display->first_row); + unsigned int startx, stepx, stepy; + png_uint_32 y; + + if (png_ptr->interlaced == PNG_INTERLACE_ADAM7) + { + /* The row may be empty for a short image: */ + if (PNG_PASS_COLS(width, pass) == 0) + continue; + + startx = PNG_PASS_START_COL(pass); + stepx = PNG_PASS_COL_OFFSET(pass); + y = PNG_PASS_START_ROW(pass); + stepy = PNG_PASS_ROW_OFFSET(pass); + } + + else + { + y = 0; + startx = 0; + stepx = stepy = 1; + } + + if (display->background == NULL) + { + for (; ylocal_row); + png_bytep outrow = first_row + y * step_row; + png_const_bytep end_row = outrow + width; + + /* Read the row, which is packed: */ + png_read_row(png_ptr, inrow, NULL); + + /* Now do the composition on each pixel in this row. */ + outrow += startx; + for (; outrow < end_row; outrow += stepx) + { + png_byte alpha = inrow[1]; + + if (alpha > 0) /* else no change to the output */ + { + png_uint_32 component = inrow[0]; + + if (alpha < 255) /* else just use component */ + { + /* Since PNG_OPTIMIZED_ALPHA was not set it is + * necessary to invert the sRGB transfer + * function and multiply the alpha out. + */ + component = png_sRGB_table[component] * alpha; + component += png_sRGB_table[outrow[0]] * + (255-alpha); + component = PNG_sRGB_FROM_LINEAR(component); + } + + outrow[0] = (png_byte)component; + } + + inrow += 2; /* gray and alpha channel */ + } + } + } + + else /* constant background value */ + { + png_byte background8 = display->background->green; + png_uint_16 background = png_sRGB_table[background8]; + + for (; ylocal_row); + png_bytep outrow = first_row + y * step_row; + png_const_bytep end_row = outrow + width; + + /* Read the row, which is packed: */ + png_read_row(png_ptr, inrow, NULL); + + /* Now do the composition on each pixel in this row. */ + outrow += startx; + for (; outrow < end_row; outrow += stepx) + { + png_byte alpha = inrow[1]; + + if (alpha > 0) /* else use background */ + { + png_uint_32 component = inrow[0]; + + if (alpha < 255) /* else just use component */ + { + component = png_sRGB_table[component] * alpha; + component += background * (255-alpha); + component = PNG_sRGB_FROM_LINEAR(component); + } + + outrow[0] = (png_byte)component; + } + + else + outrow[0] = background8; + + inrow += 2; /* gray and alpha channel */ + } + + row += display->row_bytes; + } + } + } + } + break; + + case 16: + /* 16-bit linear with pre-multiplied alpha; the pre-multiplication must + * still be done and, maybe, the alpha channel removed. This code also + * handles the alpha-first option. + */ + { + png_uint_16p first_row = png_voidcast(png_uint_16p, + display->first_row); + /* The division by two is safe because the caller passed in a + * stride which was multiplied by 2 (below) to get row_bytes. + */ + ptrdiff_t step_row = display->row_bytes / 2; + int preserve_alpha = (image->format & PNG_FORMAT_FLAG_ALPHA) != 0; + unsigned int outchannels = 1+preserve_alpha; + int swap_alpha = 0; + + if (preserve_alpha && (image->format & PNG_FORMAT_FLAG_AFIRST)) + swap_alpha = 1; + + for (pass = 0; pass < passes; ++pass) + { + unsigned int startx, stepx, stepy; + png_uint_32 y; + + /* The 'x' start and step are adjusted to output components here. + */ + if (png_ptr->interlaced == PNG_INTERLACE_ADAM7) + { + /* The row may be empty for a short image: */ + if (PNG_PASS_COLS(width, pass) == 0) + continue; + + startx = PNG_PASS_START_COL(pass) * outchannels; + stepx = PNG_PASS_COL_OFFSET(pass) * outchannels; + y = PNG_PASS_START_ROW(pass); + stepy = PNG_PASS_ROW_OFFSET(pass); + } + + else + { + y = 0; + startx = 0; + stepx = outchannels; + stepy = 1; + } + + for (; ylocal_row), NULL); + inrow = png_voidcast(png_const_uint_16p, display->local_row); + + /* Now do the pre-multiplication on each pixel in this row. + */ + outrow += startx; + for (; outrow < end_row; outrow += stepx) + { + png_uint_32 component = inrow[0]; + png_uint_16 alpha = inrow[1]; + + if (alpha > 0) /* else 0 */ + { + if (alpha < 65535) /* else just use component */ + { + component *= alpha; + component += 32767; + component /= 65535; + } + } + + else + component = 0; + + outrow[swap_alpha] = (png_uint_16)component; + if (preserve_alpha) + outrow[1 ^ swap_alpha] = alpha; + + inrow += 2; /* components and alpha channel */ + } + } + } + } + break; + } + + return 1; +} + +/* The guts of png_image_finish_read as a png_safe_execute callback. */ +static int +png_image_read_direct(png_voidp argument) +{ + png_image_read_control *display = png_voidcast(png_image_read_control*, + argument); + png_imagep image = display->image; + png_structrp png_ptr = image->opaque->png_ptr; + png_inforp info_ptr = image->opaque->info_ptr; + + png_uint_32 format = image->format; + int linear = (format & PNG_FORMAT_FLAG_LINEAR) != 0; + int do_local_compose = 0; + int do_local_background = 0; /* to avoid double gamma correction bug */ + int passes = 0; + + /* Add transforms to ensure the correct output format is produced then check + * that the required implementation support is there. Always expand; always + * need 8 bits minimum, no palette and expanded tRNS. + */ + png_set_expand(png_ptr); + + /* Now check the format to see if it was modified. */ + { + png_uint_32 base_format = png_image_format(png_ptr) & + ~PNG_FORMAT_FLAG_COLORMAP /* removed by png_set_expand */; + png_uint_32 change = format ^ base_format; + png_fixed_point output_gamma; + int mode; /* alpha mode */ + + /* Do this first so that we have a record if rgb to gray is happening. */ + if (change & PNG_FORMAT_FLAG_COLOR) + { + /* gray<->color transformation required. */ + if (format & PNG_FORMAT_FLAG_COLOR) + png_set_gray_to_rgb(png_ptr); + + else + { + /* libpng can't do both rgb to gray and + * background/pre-multiplication if there is also significant gamma + * correction, because both operations require linear colors and + * the code only supports one transform doing the gamma correction. + * Handle this by doing the pre-multiplication or background + * operation in this code, if necessary. + * + * TODO: fix this by rewriting pngrtran.c (!) + * + * For the moment (given that fixing this in pngrtran.c is an + * enormous change) 'do_local_background' is used to indicate that + * the problem exists. + */ + if (base_format & PNG_FORMAT_FLAG_ALPHA) + do_local_background = 1/*maybe*/; + + png_set_rgb_to_gray_fixed(png_ptr, PNG_ERROR_ACTION_NONE, + PNG_RGB_TO_GRAY_DEFAULT, PNG_RGB_TO_GRAY_DEFAULT); + } + + change &= ~PNG_FORMAT_FLAG_COLOR; + } + + /* Set the gamma appropriately, linear for 16-bit input, sRGB otherwise. + */ + { + png_fixed_point input_gamma_default; + + if ((base_format & PNG_FORMAT_FLAG_LINEAR) && + (image->flags & PNG_IMAGE_FLAG_16BIT_sRGB) == 0) + input_gamma_default = PNG_GAMMA_LINEAR; + else + input_gamma_default = PNG_DEFAULT_sRGB; + + /* Call png_set_alpha_mode to set the default for the input gamma; the + * output gamma is set by a second call below. + */ + png_set_alpha_mode_fixed(png_ptr, PNG_ALPHA_PNG, input_gamma_default); + } + + if (linear) + { + /* If there *is* an alpha channel in the input it must be multiplied + * out; use PNG_ALPHA_STANDARD, otherwise just use PNG_ALPHA_PNG. + */ + if (base_format & PNG_FORMAT_FLAG_ALPHA) + mode = PNG_ALPHA_STANDARD; /* associated alpha */ + + else + mode = PNG_ALPHA_PNG; + + output_gamma = PNG_GAMMA_LINEAR; + } + + else + { + mode = PNG_ALPHA_PNG; + output_gamma = PNG_DEFAULT_sRGB; + } + + /* If 'do_local_background' is set check for the presence of gamma + * correction; this is part of the work-round for the libpng bug + * described above. + * + * TODO: fix libpng and remove this. + */ + if (do_local_background) + { + png_fixed_point gtest; + + /* This is 'png_gamma_threshold' from pngrtran.c; the test used for + * gamma correction, the screen gamma hasn't been set on png_struct + * yet; it's set below. png_struct::gamma, however, is set to the + * final value. + */ + if (png_muldiv(>est, output_gamma, png_ptr->colorspace.gamma, + PNG_FP_1) && !png_gamma_significant(gtest)) + do_local_background = 0; + + else if (mode == PNG_ALPHA_STANDARD) + { + do_local_background = 2/*required*/; + mode = PNG_ALPHA_PNG; /* prevent libpng doing it */ + } + + /* else leave as 1 for the checks below */ + } + + /* If the bit-depth changes then handle that here. */ + if (change & PNG_FORMAT_FLAG_LINEAR) + { + if (linear /*16-bit output*/) + png_set_expand_16(png_ptr); + + else /* 8-bit output */ + png_set_scale_16(png_ptr); + + change &= ~PNG_FORMAT_FLAG_LINEAR; + } + + /* Now the background/alpha channel changes. */ + if (change & PNG_FORMAT_FLAG_ALPHA) + { + /* Removing an alpha channel requires composition for the 8-bit + * formats; for the 16-bit it is already done, above, by the + * pre-multiplication and the channel just needs to be stripped. + */ + if (base_format & PNG_FORMAT_FLAG_ALPHA) + { + /* If RGB->gray is happening the alpha channel must be left and the + * operation completed locally. + * + * TODO: fix libpng and remove this. + */ + if (do_local_background) + do_local_background = 2/*required*/; + + /* 16-bit output: just remove the channel */ + else if (linear) /* compose on black (well, pre-multiply) */ + png_set_strip_alpha(png_ptr); + + /* 8-bit output: do an appropriate compose */ + else if (display->background != NULL) + { + png_color_16 c; + + c.index = 0; /*unused*/ + c.red = display->background->red; + c.green = display->background->green; + c.blue = display->background->blue; + c.gray = display->background->green; + + /* This is always an 8-bit sRGB value, using the 'green' channel + * for gray is much better than calculating the luminance here; + * we can get off-by-one errors in that calculation relative to + * the app expectations and that will show up in transparent + * pixels. + */ + png_set_background_fixed(png_ptr, &c, + PNG_BACKGROUND_GAMMA_SCREEN, 0/*need_expand*/, + 0/*gamma: not used*/); + } + + else /* compose on row: implemented below. */ + { + do_local_compose = 1; + /* This leaves the alpha channel in the output, so it has to be + * removed by the code below. Set the encoding to the 'OPTIMIZE' + * one so the code only has to hack on the pixels that require + * composition. + */ + mode = PNG_ALPHA_OPTIMIZED; + } + } + + else /* output needs an alpha channel */ + { + /* This is tricky because it happens before the swap operation has + * been accomplished; however, the swap does *not* swap the added + * alpha channel (weird API), so it must be added in the correct + * place. + */ + png_uint_32 filler; /* opaque filler */ + int where; + + if (linear) + filler = 65535; + + else + filler = 255; + +# ifdef PNG_FORMAT_AFIRST_SUPPORTED + if (format & PNG_FORMAT_FLAG_AFIRST) + { + where = PNG_FILLER_BEFORE; + change &= ~PNG_FORMAT_FLAG_AFIRST; + } + + else +# endif + where = PNG_FILLER_AFTER; + + png_set_add_alpha(png_ptr, filler, where); + } + + /* This stops the (irrelevant) call to swap_alpha below. */ + change &= ~PNG_FORMAT_FLAG_ALPHA; + } + + /* Now set the alpha mode correctly; this is always done, even if there is + * no alpha channel in either the input or the output because it correctly + * sets the output gamma. + */ + png_set_alpha_mode_fixed(png_ptr, mode, output_gamma); + +# ifdef PNG_FORMAT_BGR_SUPPORTED + if (change & PNG_FORMAT_FLAG_BGR) + { + /* Check only the output format; PNG is never BGR; don't do this if + * the output is gray, but fix up the 'format' value in that case. + */ + if (format & PNG_FORMAT_FLAG_COLOR) + png_set_bgr(png_ptr); + + else + format &= ~PNG_FORMAT_FLAG_BGR; + + change &= ~PNG_FORMAT_FLAG_BGR; + } +# endif + +# ifdef PNG_FORMAT_AFIRST_SUPPORTED + if (change & PNG_FORMAT_FLAG_AFIRST) + { + /* Only relevant if there is an alpha channel - it's particularly + * important to handle this correctly because do_local_compose may + * be set above and then libpng will keep the alpha channel for this + * code to remove. + */ + if (format & PNG_FORMAT_FLAG_ALPHA) + { + /* Disable this if doing a local background, + * TODO: remove this when local background is no longer required. + */ + if (do_local_background != 2) + png_set_swap_alpha(png_ptr); + } + + else + format &= ~PNG_FORMAT_FLAG_AFIRST; + + change &= ~PNG_FORMAT_FLAG_AFIRST; + } +# endif + + /* If the *output* is 16-bit then we need to check for a byte-swap on this + * architecture. + */ + if (linear) + { + PNG_CONST png_uint_16 le = 0x0001; + + if (*(png_const_bytep)&le) + png_set_swap(png_ptr); + } + + /* If change is not now 0 some transformation is missing - error out. */ + if (change) + png_error(png_ptr, "png_read_image: unsupported transformation"); + } + + PNG_SKIP_CHUNKS(png_ptr); + + /* Update the 'info' structure and make sure the result is as required; first + * make sure to turn on the interlace handling if it will be required + * (because it can't be turned on *after* the call to png_read_update_info!) + * + * TODO: remove the do_local_background fixup below. + */ + if (!do_local_compose && do_local_background != 2) + passes = png_set_interlace_handling(png_ptr); + + png_read_update_info(png_ptr, info_ptr); + + { + png_uint_32 info_format = 0; + + if (info_ptr->color_type & PNG_COLOR_MASK_COLOR) + info_format |= PNG_FORMAT_FLAG_COLOR; + + if (info_ptr->color_type & PNG_COLOR_MASK_ALPHA) + { + /* do_local_compose removes this channel below. */ + if (!do_local_compose) + { + /* do_local_background does the same if required. */ + if (do_local_background != 2 || + (format & PNG_FORMAT_FLAG_ALPHA) != 0) + info_format |= PNG_FORMAT_FLAG_ALPHA; + } + } + + else if (do_local_compose) /* internal error */ + png_error(png_ptr, "png_image_read: alpha channel lost"); + + if (info_ptr->bit_depth == 16) + info_format |= PNG_FORMAT_FLAG_LINEAR; + +# ifdef PNG_FORMAT_BGR_SUPPORTED + if (png_ptr->transformations & PNG_BGR) + info_format |= PNG_FORMAT_FLAG_BGR; +# endif + +# ifdef PNG_FORMAT_AFIRST_SUPPORTED + if (do_local_background == 2) + { + if (format & PNG_FORMAT_FLAG_AFIRST) + info_format |= PNG_FORMAT_FLAG_AFIRST; + } + + if ((png_ptr->transformations & PNG_SWAP_ALPHA) != 0 || + ((png_ptr->transformations & PNG_ADD_ALPHA) != 0 && + (png_ptr->flags & PNG_FLAG_FILLER_AFTER) == 0)) + { + if (do_local_background == 2) + png_error(png_ptr, "unexpected alpha swap transformation"); + + info_format |= PNG_FORMAT_FLAG_AFIRST; + } +# endif + + /* This is actually an internal error. */ + if (info_format != format) + png_error(png_ptr, "png_read_image: invalid transformations"); + } + + /* Now read the rows. If do_local_compose is set then it is necessary to use + * a local row buffer. The output will be GA, RGBA or BGRA and must be + * converted to G, RGB or BGR as appropriate. The 'local_row' member of the + * display acts as a flag. + */ + { + png_voidp first_row = display->buffer; + ptrdiff_t row_bytes = display->row_stride; + + if (linear) + row_bytes *= 2; + + /* The following expression is designed to work correctly whether it gives + * a signed or an unsigned result. + */ + if (row_bytes < 0) + { + char *ptr = png_voidcast(char*, first_row); + ptr += (image->height-1) * (-row_bytes); + first_row = png_voidcast(png_voidp, ptr); + } + + display->first_row = first_row; + display->row_bytes = row_bytes; + } + + if (do_local_compose) + { + int result; + png_voidp row = png_malloc(png_ptr, png_get_rowbytes(png_ptr, info_ptr)); + + display->local_row = row; + result = png_safe_execute(image, png_image_read_composite, display); + display->local_row = NULL; + png_free(png_ptr, row); + + return result; + } + + else if (do_local_background == 2) + { + int result; + png_voidp row = png_malloc(png_ptr, png_get_rowbytes(png_ptr, info_ptr)); + + display->local_row = row; + result = png_safe_execute(image, png_image_read_background, display); + display->local_row = NULL; + png_free(png_ptr, row); + + return result; + } + + else + { + png_alloc_size_t row_bytes = display->row_bytes; + + while (--passes >= 0) + { + png_uint_32 y = image->height; + png_bytep row = png_voidcast(png_bytep, display->first_row); + + while (y-- > 0) + { + png_read_row(png_ptr, row, NULL); + row += row_bytes; + } + } + + return 1; + } +} + +int PNGAPI +png_image_finish_read(png_imagep image, png_const_colorp background, + void *buffer, png_int_32 row_stride, void *colormap) +{ + if (image != NULL && image->version == PNG_IMAGE_VERSION) + { + png_uint_32 check; + + if (row_stride == 0) + row_stride = PNG_IMAGE_ROW_STRIDE(*image); + + if (row_stride < 0) + check = -row_stride; + + else + check = row_stride; + + if (image->opaque != NULL && buffer != NULL && + check >= PNG_IMAGE_ROW_STRIDE(*image)) + { + if ((image->format & PNG_FORMAT_FLAG_COLORMAP) == 0 || + (image->colormap_entries > 0 && colormap != NULL)) + { + int result; + png_image_read_control display; + + memset(&display, 0, (sizeof display)); + display.image = image; + display.buffer = buffer; + display.row_stride = row_stride; + display.colormap = colormap; + display.background = background; + display.local_row = NULL; + + /* Choose the correct 'end' routine; for the color-map case all the + * setup has already been done. + */ + if (image->format & PNG_FORMAT_FLAG_COLORMAP) + result = + png_safe_execute(image, png_image_read_colormap, &display) && + png_safe_execute(image, png_image_read_colormapped, &display); + + else + result = + png_safe_execute(image, png_image_read_direct, &display); + + png_image_free(image); + return result; + } + + else + return png_image_error(image, + "png_image_finish_read[color-map]: no color-map"); + } + + else + return png_image_error(image, + "png_image_finish_read: invalid argument"); + } + + else if (image != NULL) + return png_image_error(image, + "png_image_finish_read: damaged PNG_IMAGE_VERSION"); + + return 0; +} + +#endif /* PNG_SIMPLIFIED_READ_SUPPORTED */ +#endif /* PNG_READ_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngrio.c b/ml/dlib/dlib/external/libpng/pngrio.c new file mode 100644 index 000000000..d7864407b --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngrio.c @@ -0,0 +1,118 @@ + +/* pngrio.c - functions for data input + * + * Last changed in libpng 1.6.0 [February 14, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + * This file provides a location for all input. Users who need + * special handling are expected to write a function that has the same + * arguments as this and performs a similar function, but that possibly + * has a different input method. Note that you shouldn't change this + * function, but rather write a replacement function and then make + * libpng use it at run time with png_set_read_fn(...). + */ + +#include "pngpriv.h" + +#ifdef PNG_READ_SUPPORTED + +/* Read the data from whatever input you are using. The default routine + * reads from a file pointer. Note that this routine sometimes gets called + * with very small lengths, so you should implement some kind of simple + * buffering if you are using unbuffered reads. This should never be asked + * to read more then 64K on a 16 bit machine. + */ +void /* PRIVATE */ +png_read_data(png_structrp png_ptr, png_bytep data, png_size_t length) +{ + png_debug1(4, "reading %d bytes", (int)length); + + if (png_ptr->read_data_fn != NULL) + (*(png_ptr->read_data_fn))(png_ptr, data, length); + + else + png_error(png_ptr, "Call to NULL read function"); +} + +#ifdef PNG_STDIO_SUPPORTED +/* This is the function that does the actual reading of data. If you are + * not reading from a standard C stream, you should create a replacement + * read_data function and use it at run time with png_set_read_fn(), rather + * than changing the library. + */ +void PNGCBAPI +png_default_read_data(png_structp png_ptr, png_bytep data, png_size_t length) +{ + png_size_t check; + + if (png_ptr == NULL) + return; + + /* fread() returns 0 on error, so it is OK to store this in a png_size_t + * instead of an int, which is what fread() actually returns. + */ + check = fread(data, 1, length, png_voidcast(png_FILE_p, png_ptr->io_ptr)); + + if (check != length) + png_error(png_ptr, "Read Error"); +} +#endif + +/* This function allows the application to supply a new input function + * for libpng if standard C streams aren't being used. + * + * This function takes as its arguments: + * + * png_ptr - pointer to a png input data structure + * + * io_ptr - pointer to user supplied structure containing info about + * the input functions. May be NULL. + * + * read_data_fn - pointer to a new input function that takes as its + * arguments a pointer to a png_struct, a pointer to + * a location where input data can be stored, and a 32-bit + * unsigned int that is the number of bytes to be read. + * To exit and output any fatal error messages the new write + * function should call png_error(png_ptr, "Error msg"). + * May be NULL, in which case libpng's default function will + * be used. + */ +void PNGAPI +png_set_read_fn(png_structrp png_ptr, png_voidp io_ptr, + png_rw_ptr read_data_fn) +{ + if (png_ptr == NULL) + return; + + png_ptr->io_ptr = io_ptr; + +#ifdef PNG_STDIO_SUPPORTED + if (read_data_fn != NULL) + png_ptr->read_data_fn = read_data_fn; + + else + png_ptr->read_data_fn = png_default_read_data; +#else + png_ptr->read_data_fn = read_data_fn; +#endif + + /* It is an error to write to a read device */ + if (png_ptr->write_data_fn != NULL) + { + png_ptr->write_data_fn = NULL; + png_warning(png_ptr, + "Can't set both read_data_fn and write_data_fn in the" + " same structure"); + } + +#ifdef PNG_WRITE_FLUSH_SUPPORTED + png_ptr->output_flush_fn = NULL; +#endif +} +#endif /* PNG_READ_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngrtran.c b/ml/dlib/dlib/external/libpng/pngrtran.c new file mode 100644 index 000000000..3b7d484fc --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngrtran.c @@ -0,0 +1,5110 @@ + +/* pngrtran.c - transforms the data in a row for PNG readers + * + * Last changed in libpng 1.6.4 [August 21, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + * This file contains functions optionally called by an application + * in order to tell libpng how to handle data when reading a PNG. + * Transformations that are used in both reading and writing are + * in pngtrans.c. + */ + +#include "pngpriv.h" + +#ifdef PNG_READ_SUPPORTED + +/* Set the action on getting a CRC error for an ancillary or critical chunk. */ +void PNGAPI +png_set_crc_action(png_structrp png_ptr, int crit_action, int ancil_action) +{ + png_debug(1, "in png_set_crc_action"); + + if (png_ptr == NULL) + return; + + /* Tell libpng how we react to CRC errors in critical chunks */ + switch (crit_action) + { + case PNG_CRC_NO_CHANGE: /* Leave setting as is */ + break; + + case PNG_CRC_WARN_USE: /* Warn/use data */ + png_ptr->flags &= ~PNG_FLAG_CRC_CRITICAL_MASK; + png_ptr->flags |= PNG_FLAG_CRC_CRITICAL_USE; + break; + + case PNG_CRC_QUIET_USE: /* Quiet/use data */ + png_ptr->flags &= ~PNG_FLAG_CRC_CRITICAL_MASK; + png_ptr->flags |= PNG_FLAG_CRC_CRITICAL_USE | + PNG_FLAG_CRC_CRITICAL_IGNORE; + break; + + case PNG_CRC_WARN_DISCARD: /* Not a valid action for critical data */ + png_warning(png_ptr, + "Can't discard critical data on CRC error"); + case PNG_CRC_ERROR_QUIT: /* Error/quit */ + + case PNG_CRC_DEFAULT: + default: + png_ptr->flags &= ~PNG_FLAG_CRC_CRITICAL_MASK; + break; + } + + /* Tell libpng how we react to CRC errors in ancillary chunks */ + switch (ancil_action) + { + case PNG_CRC_NO_CHANGE: /* Leave setting as is */ + break; + + case PNG_CRC_WARN_USE: /* Warn/use data */ + png_ptr->flags &= ~PNG_FLAG_CRC_ANCILLARY_MASK; + png_ptr->flags |= PNG_FLAG_CRC_ANCILLARY_USE; + break; + + case PNG_CRC_QUIET_USE: /* Quiet/use data */ + png_ptr->flags &= ~PNG_FLAG_CRC_ANCILLARY_MASK; + png_ptr->flags |= PNG_FLAG_CRC_ANCILLARY_USE | + PNG_FLAG_CRC_ANCILLARY_NOWARN; + break; + + case PNG_CRC_ERROR_QUIT: /* Error/quit */ + png_ptr->flags &= ~PNG_FLAG_CRC_ANCILLARY_MASK; + png_ptr->flags |= PNG_FLAG_CRC_ANCILLARY_NOWARN; + break; + + case PNG_CRC_WARN_DISCARD: /* Warn/discard data */ + + case PNG_CRC_DEFAULT: + default: + png_ptr->flags &= ~PNG_FLAG_CRC_ANCILLARY_MASK; + break; + } +} + +#ifdef PNG_READ_TRANSFORMS_SUPPORTED +/* Is it OK to set a transformation now? Only if png_start_read_image or + * png_read_update_info have not been called. It is not necessary for the IHDR + * to have been read in all cases, the parameter allows for this check too. + */ +static int +png_rtran_ok(png_structrp png_ptr, int need_IHDR) +{ + if (png_ptr != NULL) + { + if (png_ptr->flags & PNG_FLAG_ROW_INIT) + png_app_error(png_ptr, + "invalid after png_start_read_image or png_read_update_info"); + + else if (need_IHDR && (png_ptr->mode & PNG_HAVE_IHDR) == 0) + png_app_error(png_ptr, "invalid before the PNG header has been read"); + + else + { + /* Turn on failure to initialize correctly for all transforms. */ + png_ptr->flags |= PNG_FLAG_DETECT_UNINITIALIZED; + + return 1; /* Ok */ + } + } + + return 0; /* no png_error possible! */ +} +#endif + +#ifdef PNG_READ_BACKGROUND_SUPPORTED +/* Handle alpha and tRNS via a background color */ +void PNGFAPI +png_set_background_fixed(png_structrp png_ptr, + png_const_color_16p background_color, int background_gamma_code, + int need_expand, png_fixed_point background_gamma) +{ + png_debug(1, "in png_set_background_fixed"); + + if (!png_rtran_ok(png_ptr, 0) || background_color == NULL) + return; + + if (background_gamma_code == PNG_BACKGROUND_GAMMA_UNKNOWN) + { + png_warning(png_ptr, "Application must supply a known background gamma"); + return; + } + + png_ptr->transformations |= PNG_COMPOSE | PNG_STRIP_ALPHA; + png_ptr->transformations &= ~PNG_ENCODE_ALPHA; + png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA; + + png_ptr->background = *background_color; + png_ptr->background_gamma = background_gamma; + png_ptr->background_gamma_type = (png_byte)(background_gamma_code); + if (need_expand) + png_ptr->transformations |= PNG_BACKGROUND_EXPAND; + else + png_ptr->transformations &= ~PNG_BACKGROUND_EXPAND; +} + +# ifdef PNG_FLOATING_POINT_SUPPORTED +void PNGAPI +png_set_background(png_structrp png_ptr, + png_const_color_16p background_color, int background_gamma_code, + int need_expand, double background_gamma) +{ + png_set_background_fixed(png_ptr, background_color, background_gamma_code, + need_expand, png_fixed(png_ptr, background_gamma, "png_set_background")); +} +# endif /* FLOATING_POINT */ +#endif /* READ_BACKGROUND */ + +/* Scale 16-bit depth files to 8-bit depth. If both of these are set then the + * one that pngrtran does first (scale) happens. This is necessary to allow the + * TRANSFORM and API behavior to be somewhat consistent, and it's simpler. + */ +#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED +void PNGAPI +png_set_scale_16(png_structrp png_ptr) +{ + png_debug(1, "in png_set_scale_16"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + png_ptr->transformations |= PNG_SCALE_16_TO_8; +} +#endif + +#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED +/* Chop 16-bit depth files to 8-bit depth */ +void PNGAPI +png_set_strip_16(png_structrp png_ptr) +{ + png_debug(1, "in png_set_strip_16"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + png_ptr->transformations |= PNG_16_TO_8; +} +#endif + +#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED +void PNGAPI +png_set_strip_alpha(png_structrp png_ptr) +{ + png_debug(1, "in png_set_strip_alpha"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + png_ptr->transformations |= PNG_STRIP_ALPHA; +} +#endif + +#if defined(PNG_READ_ALPHA_MODE_SUPPORTED) || defined(PNG_READ_GAMMA_SUPPORTED) +static png_fixed_point +translate_gamma_flags(png_structrp png_ptr, png_fixed_point output_gamma, + int is_screen) +{ + /* Check for flag values. The main reason for having the old Mac value as a + * flag is that it is pretty near impossible to work out what the correct + * value is from Apple documentation - a working Mac system is needed to + * discover the value! + */ + if (output_gamma == PNG_DEFAULT_sRGB || + output_gamma == PNG_FP_1 / PNG_DEFAULT_sRGB) + { + /* If there is no sRGB support this just sets the gamma to the standard + * sRGB value. (This is a side effect of using this function!) + */ +# ifdef PNG_READ_sRGB_SUPPORTED + png_ptr->flags |= PNG_FLAG_ASSUME_sRGB; +# else + PNG_UNUSED(png_ptr) +# endif + if (is_screen) + output_gamma = PNG_GAMMA_sRGB; + else + output_gamma = PNG_GAMMA_sRGB_INVERSE; + } + + else if (output_gamma == PNG_GAMMA_MAC_18 || + output_gamma == PNG_FP_1 / PNG_GAMMA_MAC_18) + { + if (is_screen) + output_gamma = PNG_GAMMA_MAC_OLD; + else + output_gamma = PNG_GAMMA_MAC_INVERSE; + } + + return output_gamma; +} + +# ifdef PNG_FLOATING_POINT_SUPPORTED +static png_fixed_point +convert_gamma_value(png_structrp png_ptr, double output_gamma) +{ + /* The following silently ignores cases where fixed point (times 100,000) + * gamma values are passed to the floating point API. This is safe and it + * means the fixed point constants work just fine with the floating point + * API. The alternative would just lead to undetected errors and spurious + * bug reports. Negative values fail inside the _fixed API unless they + * correspond to the flag values. + */ + if (output_gamma > 0 && output_gamma < 128) + output_gamma *= PNG_FP_1; + + /* This preserves -1 and -2 exactly: */ + output_gamma = floor(output_gamma + .5); + + if (output_gamma > PNG_FP_MAX || output_gamma < PNG_FP_MIN) + png_fixed_error(png_ptr, "gamma value"); + + return (png_fixed_point)output_gamma; +} +# endif +#endif /* READ_ALPHA_MODE || READ_GAMMA */ + +#ifdef PNG_READ_ALPHA_MODE_SUPPORTED +void PNGFAPI +png_set_alpha_mode_fixed(png_structrp png_ptr, int mode, + png_fixed_point output_gamma) +{ + int compose = 0; + png_fixed_point file_gamma; + + png_debug(1, "in png_set_alpha_mode"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + output_gamma = translate_gamma_flags(png_ptr, output_gamma, 1/*screen*/); + + /* Validate the value to ensure it is in a reasonable range. The value + * is expected to be 1 or greater, but this range test allows for some + * viewing correction values. The intent is to weed out users of this API + * who use the inverse of the gamma value accidentally! Since some of these + * values are reasonable this may have to be changed. + */ + if (output_gamma < 70000 || output_gamma > 300000) + png_error(png_ptr, "output gamma out of expected range"); + + /* The default file gamma is the inverse of the output gamma; the output + * gamma may be changed below so get the file value first: + */ + file_gamma = png_reciprocal(output_gamma); + + /* There are really 8 possibilities here, composed of any combination + * of: + * + * premultiply the color channels + * do not encode non-opaque pixels + * encode the alpha as well as the color channels + * + * The differences disappear if the input/output ('screen') gamma is 1.0, + * because then the encoding is a no-op and there is only the choice of + * premultiplying the color channels or not. + * + * png_set_alpha_mode and png_set_background interact because both use + * png_compose to do the work. Calling both is only useful when + * png_set_alpha_mode is used to set the default mode - PNG_ALPHA_PNG - along + * with a default gamma value. Otherwise PNG_COMPOSE must not be set. + */ + switch (mode) + { + case PNG_ALPHA_PNG: /* default: png standard */ + /* No compose, but it may be set by png_set_background! */ + png_ptr->transformations &= ~PNG_ENCODE_ALPHA; + png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA; + break; + + case PNG_ALPHA_ASSOCIATED: /* color channels premultiplied */ + compose = 1; + png_ptr->transformations &= ~PNG_ENCODE_ALPHA; + png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA; + /* The output is linear: */ + output_gamma = PNG_FP_1; + break; + + case PNG_ALPHA_OPTIMIZED: /* associated, non-opaque pixels linear */ + compose = 1; + png_ptr->transformations &= ~PNG_ENCODE_ALPHA; + png_ptr->flags |= PNG_FLAG_OPTIMIZE_ALPHA; + /* output_gamma records the encoding of opaque pixels! */ + break; + + case PNG_ALPHA_BROKEN: /* associated, non-linear, alpha encoded */ + compose = 1; + png_ptr->transformations |= PNG_ENCODE_ALPHA; + png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA; + break; + + default: + png_error(png_ptr, "invalid alpha mode"); + } + + /* Only set the default gamma if the file gamma has not been set (this has + * the side effect that the gamma in a second call to png_set_alpha_mode will + * be ignored.) + */ + if (png_ptr->colorspace.gamma == 0) + { + png_ptr->colorspace.gamma = file_gamma; + png_ptr->colorspace.flags |= PNG_COLORSPACE_HAVE_GAMMA; + } + + /* But always set the output gamma: */ + png_ptr->screen_gamma = output_gamma; + + /* Finally, if pre-multiplying, set the background fields to achieve the + * desired result. + */ + if (compose) + { + /* And obtain alpha pre-multiplication by composing on black: */ + memset(&png_ptr->background, 0, (sizeof png_ptr->background)); + png_ptr->background_gamma = png_ptr->colorspace.gamma; /* just in case */ + png_ptr->background_gamma_type = PNG_BACKGROUND_GAMMA_FILE; + png_ptr->transformations &= ~PNG_BACKGROUND_EXPAND; + + if (png_ptr->transformations & PNG_COMPOSE) + png_error(png_ptr, + "conflicting calls to set alpha mode and background"); + + png_ptr->transformations |= PNG_COMPOSE; + } +} + +# ifdef PNG_FLOATING_POINT_SUPPORTED +void PNGAPI +png_set_alpha_mode(png_structrp png_ptr, int mode, double output_gamma) +{ + png_set_alpha_mode_fixed(png_ptr, mode, convert_gamma_value(png_ptr, + output_gamma)); +} +# endif +#endif + +#ifdef PNG_READ_QUANTIZE_SUPPORTED +/* Dither file to 8-bit. Supply a palette, the current number + * of elements in the palette, the maximum number of elements + * allowed, and a histogram if possible. If the current number + * of colors is greater then the maximum number, the palette will be + * modified to fit in the maximum number. "full_quantize" indicates + * whether we need a quantizing cube set up for RGB images, or if we + * simply are reducing the number of colors in a paletted image. + */ + +typedef struct png_dsort_struct +{ + struct png_dsort_struct * next; + png_byte left; + png_byte right; +} png_dsort; +typedef png_dsort * png_dsortp; +typedef png_dsort * * png_dsortpp; + +void PNGAPI +png_set_quantize(png_structrp png_ptr, png_colorp palette, + int num_palette, int maximum_colors, png_const_uint_16p histogram, + int full_quantize) +{ + png_debug(1, "in png_set_quantize"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + png_ptr->transformations |= PNG_QUANTIZE; + + if (!full_quantize) + { + int i; + + png_ptr->quantize_index = (png_bytep)png_malloc(png_ptr, + (png_uint_32)(num_palette * (sizeof (png_byte)))); + for (i = 0; i < num_palette; i++) + png_ptr->quantize_index[i] = (png_byte)i; + } + + if (num_palette > maximum_colors) + { + if (histogram != NULL) + { + /* This is easy enough, just throw out the least used colors. + * Perhaps not the best solution, but good enough. + */ + + int i; + + /* Initialize an array to sort colors */ + png_ptr->quantize_sort = (png_bytep)png_malloc(png_ptr, + (png_uint_32)(num_palette * (sizeof (png_byte)))); + + /* Initialize the quantize_sort array */ + for (i = 0; i < num_palette; i++) + png_ptr->quantize_sort[i] = (png_byte)i; + + /* Find the least used palette entries by starting a + * bubble sort, and running it until we have sorted + * out enough colors. Note that we don't care about + * sorting all the colors, just finding which are + * least used. + */ + + for (i = num_palette - 1; i >= maximum_colors; i--) + { + int done; /* To stop early if the list is pre-sorted */ + int j; + + done = 1; + for (j = 0; j < i; j++) + { + if (histogram[png_ptr->quantize_sort[j]] + < histogram[png_ptr->quantize_sort[j + 1]]) + { + png_byte t; + + t = png_ptr->quantize_sort[j]; + png_ptr->quantize_sort[j] = png_ptr->quantize_sort[j + 1]; + png_ptr->quantize_sort[j + 1] = t; + done = 0; + } + } + + if (done) + break; + } + + /* Swap the palette around, and set up a table, if necessary */ + if (full_quantize) + { + int j = num_palette; + + /* Put all the useful colors within the max, but don't + * move the others. + */ + for (i = 0; i < maximum_colors; i++) + { + if ((int)png_ptr->quantize_sort[i] >= maximum_colors) + { + do + j--; + while ((int)png_ptr->quantize_sort[j] >= maximum_colors); + + palette[i] = palette[j]; + } + } + } + else + { + int j = num_palette; + + /* Move all the used colors inside the max limit, and + * develop a translation table. + */ + for (i = 0; i < maximum_colors; i++) + { + /* Only move the colors we need to */ + if ((int)png_ptr->quantize_sort[i] >= maximum_colors) + { + png_color tmp_color; + + do + j--; + while ((int)png_ptr->quantize_sort[j] >= maximum_colors); + + tmp_color = palette[j]; + palette[j] = palette[i]; + palette[i] = tmp_color; + /* Indicate where the color went */ + png_ptr->quantize_index[j] = (png_byte)i; + png_ptr->quantize_index[i] = (png_byte)j; + } + } + + /* Find closest color for those colors we are not using */ + for (i = 0; i < num_palette; i++) + { + if ((int)png_ptr->quantize_index[i] >= maximum_colors) + { + int min_d, k, min_k, d_index; + + /* Find the closest color to one we threw out */ + d_index = png_ptr->quantize_index[i]; + min_d = PNG_COLOR_DIST(palette[d_index], palette[0]); + for (k = 1, min_k = 0; k < maximum_colors; k++) + { + int d; + + d = PNG_COLOR_DIST(palette[d_index], palette[k]); + + if (d < min_d) + { + min_d = d; + min_k = k; + } + } + /* Point to closest color */ + png_ptr->quantize_index[i] = (png_byte)min_k; + } + } + } + png_free(png_ptr, png_ptr->quantize_sort); + png_ptr->quantize_sort = NULL; + } + else + { + /* This is much harder to do simply (and quickly). Perhaps + * we need to go through a median cut routine, but those + * don't always behave themselves with only a few colors + * as input. So we will just find the closest two colors, + * and throw out one of them (chosen somewhat randomly). + * [We don't understand this at all, so if someone wants to + * work on improving it, be our guest - AED, GRP] + */ + int i; + int max_d; + int num_new_palette; + png_dsortp t; + png_dsortpp hash; + + t = NULL; + + /* Initialize palette index arrays */ + png_ptr->index_to_palette = (png_bytep)png_malloc(png_ptr, + (png_uint_32)(num_palette * (sizeof (png_byte)))); + png_ptr->palette_to_index = (png_bytep)png_malloc(png_ptr, + (png_uint_32)(num_palette * (sizeof (png_byte)))); + + /* Initialize the sort array */ + for (i = 0; i < num_palette; i++) + { + png_ptr->index_to_palette[i] = (png_byte)i; + png_ptr->palette_to_index[i] = (png_byte)i; + } + + hash = (png_dsortpp)png_calloc(png_ptr, (png_uint_32)(769 * + (sizeof (png_dsortp)))); + + num_new_palette = num_palette; + + /* Initial wild guess at how far apart the farthest pixel + * pair we will be eliminating will be. Larger + * numbers mean more areas will be allocated, Smaller + * numbers run the risk of not saving enough data, and + * having to do this all over again. + * + * I have not done extensive checking on this number. + */ + max_d = 96; + + while (num_new_palette > maximum_colors) + { + for (i = 0; i < num_new_palette - 1; i++) + { + int j; + + for (j = i + 1; j < num_new_palette; j++) + { + int d; + + d = PNG_COLOR_DIST(palette[i], palette[j]); + + if (d <= max_d) + { + + t = (png_dsortp)png_malloc_warn(png_ptr, + (png_uint_32)(sizeof (png_dsort))); + + if (t == NULL) + break; + + t->next = hash[d]; + t->left = (png_byte)i; + t->right = (png_byte)j; + hash[d] = t; + } + } + if (t == NULL) + break; + } + + if (t != NULL) + for (i = 0; i <= max_d; i++) + { + if (hash[i] != NULL) + { + png_dsortp p; + + for (p = hash[i]; p; p = p->next) + { + if ((int)png_ptr->index_to_palette[p->left] + < num_new_palette && + (int)png_ptr->index_to_palette[p->right] + < num_new_palette) + { + int j, next_j; + + if (num_new_palette & 0x01) + { + j = p->left; + next_j = p->right; + } + else + { + j = p->right; + next_j = p->left; + } + + num_new_palette--; + palette[png_ptr->index_to_palette[j]] + = palette[num_new_palette]; + if (!full_quantize) + { + int k; + + for (k = 0; k < num_palette; k++) + { + if (png_ptr->quantize_index[k] == + png_ptr->index_to_palette[j]) + png_ptr->quantize_index[k] = + png_ptr->index_to_palette[next_j]; + + if ((int)png_ptr->quantize_index[k] == + num_new_palette) + png_ptr->quantize_index[k] = + png_ptr->index_to_palette[j]; + } + } + + png_ptr->index_to_palette[png_ptr->palette_to_index + [num_new_palette]] = png_ptr->index_to_palette[j]; + + png_ptr->palette_to_index[png_ptr->index_to_palette[j]] + = png_ptr->palette_to_index[num_new_palette]; + + png_ptr->index_to_palette[j] = + (png_byte)num_new_palette; + + png_ptr->palette_to_index[num_new_palette] = + (png_byte)j; + } + if (num_new_palette <= maximum_colors) + break; + } + if (num_new_palette <= maximum_colors) + break; + } + } + + for (i = 0; i < 769; i++) + { + if (hash[i] != NULL) + { + png_dsortp p = hash[i]; + while (p) + { + t = p->next; + png_free(png_ptr, p); + p = t; + } + } + hash[i] = 0; + } + max_d += 96; + } + png_free(png_ptr, hash); + png_free(png_ptr, png_ptr->palette_to_index); + png_free(png_ptr, png_ptr->index_to_palette); + png_ptr->palette_to_index = NULL; + png_ptr->index_to_palette = NULL; + } + num_palette = maximum_colors; + } + if (png_ptr->palette == NULL) + { + png_ptr->palette = palette; + } + png_ptr->num_palette = (png_uint_16)num_palette; + + if (full_quantize) + { + int i; + png_bytep distance; + int total_bits = PNG_QUANTIZE_RED_BITS + PNG_QUANTIZE_GREEN_BITS + + PNG_QUANTIZE_BLUE_BITS; + int num_red = (1 << PNG_QUANTIZE_RED_BITS); + int num_green = (1 << PNG_QUANTIZE_GREEN_BITS); + int num_blue = (1 << PNG_QUANTIZE_BLUE_BITS); + png_size_t num_entries = ((png_size_t)1 << total_bits); + + png_ptr->palette_lookup = (png_bytep)png_calloc(png_ptr, + (png_uint_32)(num_entries * (sizeof (png_byte)))); + + distance = (png_bytep)png_malloc(png_ptr, (png_uint_32)(num_entries * + (sizeof (png_byte)))); + + memset(distance, 0xff, num_entries * (sizeof (png_byte))); + + for (i = 0; i < num_palette; i++) + { + int ir, ig, ib; + int r = (palette[i].red >> (8 - PNG_QUANTIZE_RED_BITS)); + int g = (palette[i].green >> (8 - PNG_QUANTIZE_GREEN_BITS)); + int b = (palette[i].blue >> (8 - PNG_QUANTIZE_BLUE_BITS)); + + for (ir = 0; ir < num_red; ir++) + { + /* int dr = abs(ir - r); */ + int dr = ((ir > r) ? ir - r : r - ir); + int index_r = (ir << (PNG_QUANTIZE_BLUE_BITS + + PNG_QUANTIZE_GREEN_BITS)); + + for (ig = 0; ig < num_green; ig++) + { + /* int dg = abs(ig - g); */ + int dg = ((ig > g) ? ig - g : g - ig); + int dt = dr + dg; + int dm = ((dr > dg) ? dr : dg); + int index_g = index_r | (ig << PNG_QUANTIZE_BLUE_BITS); + + for (ib = 0; ib < num_blue; ib++) + { + int d_index = index_g | ib; + /* int db = abs(ib - b); */ + int db = ((ib > b) ? ib - b : b - ib); + int dmax = ((dm > db) ? dm : db); + int d = dmax + dt + db; + + if (d < (int)distance[d_index]) + { + distance[d_index] = (png_byte)d; + png_ptr->palette_lookup[d_index] = (png_byte)i; + } + } + } + } + } + + png_free(png_ptr, distance); + } +} +#endif /* PNG_READ_QUANTIZE_SUPPORTED */ + +#ifdef PNG_READ_GAMMA_SUPPORTED +void PNGFAPI +png_set_gamma_fixed(png_structrp png_ptr, png_fixed_point scrn_gamma, + png_fixed_point file_gamma) +{ + png_debug(1, "in png_set_gamma_fixed"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + /* New in libpng-1.5.4 - reserve particular negative values as flags. */ + scrn_gamma = translate_gamma_flags(png_ptr, scrn_gamma, 1/*screen*/); + file_gamma = translate_gamma_flags(png_ptr, file_gamma, 0/*file*/); + + /* Checking the gamma values for being >0 was added in 1.5.4 along with the + * premultiplied alpha support; this actually hides an undocumented feature + * of the previous implementation which allowed gamma processing to be + * disabled in background handling. There is no evidence (so far) that this + * was being used; however, png_set_background itself accepted and must still + * accept '0' for the gamma value it takes, because it isn't always used. + * + * Since this is an API change (albeit a very minor one that removes an + * undocumented API feature) the following checks were only enabled in + * libpng-1.6.0. + */ + if (file_gamma <= 0) + png_error(png_ptr, "invalid file gamma in png_set_gamma"); + + if (scrn_gamma <= 0) + png_error(png_ptr, "invalid screen gamma in png_set_gamma"); + + /* Set the gamma values unconditionally - this overrides the value in the PNG + * file if a gAMA chunk was present. png_set_alpha_mode provides a + * different, easier, way to default the file gamma. + */ + png_ptr->colorspace.gamma = file_gamma; + png_ptr->colorspace.flags |= PNG_COLORSPACE_HAVE_GAMMA; + png_ptr->screen_gamma = scrn_gamma; +} + +# ifdef PNG_FLOATING_POINT_SUPPORTED +void PNGAPI +png_set_gamma(png_structrp png_ptr, double scrn_gamma, double file_gamma) +{ + png_set_gamma_fixed(png_ptr, convert_gamma_value(png_ptr, scrn_gamma), + convert_gamma_value(png_ptr, file_gamma)); +} +# endif /* FLOATING_POINT_SUPPORTED */ +#endif /* READ_GAMMA */ + +#ifdef PNG_READ_EXPAND_SUPPORTED +/* Expand paletted images to RGB, expand grayscale images of + * less than 8-bit depth to 8-bit depth, and expand tRNS chunks + * to alpha channels. + */ +void PNGAPI +png_set_expand(png_structrp png_ptr) +{ + png_debug(1, "in png_set_expand"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + png_ptr->transformations |= (PNG_EXPAND | PNG_EXPAND_tRNS); +} + +/* GRR 19990627: the following three functions currently are identical + * to png_set_expand(). However, it is entirely reasonable that someone + * might wish to expand an indexed image to RGB but *not* expand a single, + * fully transparent palette entry to a full alpha channel--perhaps instead + * convert tRNS to the grayscale/RGB format (16-bit RGB value), or replace + * the transparent color with a particular RGB value, or drop tRNS entirely. + * IOW, a future version of the library may make the transformations flag + * a bit more fine-grained, with separate bits for each of these three + * functions. + * + * More to the point, these functions make it obvious what libpng will be + * doing, whereas "expand" can (and does) mean any number of things. + * + * GRP 20060307: In libpng-1.2.9, png_set_gray_1_2_4_to_8() was modified + * to expand only the sample depth but not to expand the tRNS to alpha + * and its name was changed to png_set_expand_gray_1_2_4_to_8(). + */ + +/* Expand paletted images to RGB. */ +void PNGAPI +png_set_palette_to_rgb(png_structrp png_ptr) +{ + png_debug(1, "in png_set_palette_to_rgb"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + png_ptr->transformations |= (PNG_EXPAND | PNG_EXPAND_tRNS); +} + +/* Expand grayscale images of less than 8-bit depth to 8 bits. */ +void PNGAPI +png_set_expand_gray_1_2_4_to_8(png_structrp png_ptr) +{ + png_debug(1, "in png_set_expand_gray_1_2_4_to_8"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + png_ptr->transformations |= PNG_EXPAND; +} + +/* Expand tRNS chunks to alpha channels. */ +void PNGAPI +png_set_tRNS_to_alpha(png_structrp png_ptr) +{ + png_debug(1, "in png_set_tRNS_to_alpha"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + png_ptr->transformations |= (PNG_EXPAND | PNG_EXPAND_tRNS); +} +#endif /* defined(PNG_READ_EXPAND_SUPPORTED) */ + +#ifdef PNG_READ_EXPAND_16_SUPPORTED +/* Expand to 16-bit channels, expand the tRNS chunk too (because otherwise + * it may not work correctly.) + */ +void PNGAPI +png_set_expand_16(png_structrp png_ptr) +{ + png_debug(1, "in png_set_expand_16"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + png_ptr->transformations |= (PNG_EXPAND_16 | PNG_EXPAND | PNG_EXPAND_tRNS); +} +#endif + +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED +void PNGAPI +png_set_gray_to_rgb(png_structrp png_ptr) +{ + png_debug(1, "in png_set_gray_to_rgb"); + + if (!png_rtran_ok(png_ptr, 0)) + return; + + /* Because rgb must be 8 bits or more: */ + png_set_expand_gray_1_2_4_to_8(png_ptr); + png_ptr->transformations |= PNG_GRAY_TO_RGB; +} +#endif + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED +void PNGFAPI +png_set_rgb_to_gray_fixed(png_structrp png_ptr, int error_action, + png_fixed_point red, png_fixed_point green) +{ + png_debug(1, "in png_set_rgb_to_gray"); + + /* Need the IHDR here because of the check on color_type below. */ + /* TODO: fix this */ + if (!png_rtran_ok(png_ptr, 1)) + return; + + switch(error_action) + { + case PNG_ERROR_ACTION_NONE: + png_ptr->transformations |= PNG_RGB_TO_GRAY; + break; + + case PNG_ERROR_ACTION_WARN: + png_ptr->transformations |= PNG_RGB_TO_GRAY_WARN; + break; + + case PNG_ERROR_ACTION_ERROR: + png_ptr->transformations |= PNG_RGB_TO_GRAY_ERR; + break; + + default: + png_error(png_ptr, "invalid error action to rgb_to_gray"); + break; + } + + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) +#ifdef PNG_READ_EXPAND_SUPPORTED + png_ptr->transformations |= PNG_EXPAND; +#else + { + /* Make this an error in 1.6 because otherwise the application may assume + * that it just worked and get a memory overwrite. + */ + png_error(png_ptr, + "Cannot do RGB_TO_GRAY without EXPAND_SUPPORTED"); + + /* png_ptr->transformations &= ~PNG_RGB_TO_GRAY; */ + } +#endif + { + if (red >= 0 && green >= 0 && red + green <= PNG_FP_1) + { + png_uint_16 red_int, green_int; + + /* NOTE: this calculation does not round, but this behavior is retained + * for consistency, the inaccuracy is very small. The code here always + * overwrites the coefficients, regardless of whether they have been + * defaulted or set already. + */ + red_int = (png_uint_16)(((png_uint_32)red*32768)/100000); + green_int = (png_uint_16)(((png_uint_32)green*32768)/100000); + + png_ptr->rgb_to_gray_red_coeff = red_int; + png_ptr->rgb_to_gray_green_coeff = green_int; + png_ptr->rgb_to_gray_coefficients_set = 1; + } + + else + { + if (red >= 0 && green >= 0) + png_app_warning(png_ptr, + "ignoring out of range rgb_to_gray coefficients"); + + /* Use the defaults, from the cHRM chunk if set, else the historical + * values which are close to the sRGB/HDTV/ITU-Rec 709 values. See + * png_do_rgb_to_gray for more discussion of the values. In this case + * the coefficients are not marked as 'set' and are not overwritten if + * something has already provided a default. + */ + if (png_ptr->rgb_to_gray_red_coeff == 0 && + png_ptr->rgb_to_gray_green_coeff == 0) + { + png_ptr->rgb_to_gray_red_coeff = 6968; + png_ptr->rgb_to_gray_green_coeff = 23434; + /* png_ptr->rgb_to_gray_blue_coeff = 2366; */ + } + } + } +} + +#ifdef PNG_FLOATING_POINT_SUPPORTED +/* Convert a RGB image to a grayscale of the same width. This allows us, + * for example, to convert a 24 bpp RGB image into an 8 bpp grayscale image. + */ + +void PNGAPI +png_set_rgb_to_gray(png_structrp png_ptr, int error_action, double red, + double green) +{ + png_set_rgb_to_gray_fixed(png_ptr, error_action, + png_fixed(png_ptr, red, "rgb to gray red coefficient"), + png_fixed(png_ptr, green, "rgb to gray green coefficient")); +} +#endif /* FLOATING POINT */ + +#endif /* RGB_TO_GRAY */ + +#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) || \ + defined(PNG_WRITE_USER_TRANSFORM_SUPPORTED) +void PNGAPI +png_set_read_user_transform_fn(png_structrp png_ptr, png_user_transform_ptr + read_user_transform_fn) +{ + png_debug(1, "in png_set_read_user_transform_fn"); + +#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED + png_ptr->transformations |= PNG_USER_TRANSFORM; + png_ptr->read_user_transform_fn = read_user_transform_fn; +#endif +} +#endif + +#ifdef PNG_READ_TRANSFORMS_SUPPORTED +#ifdef PNG_READ_GAMMA_SUPPORTED +/* In the case of gamma transformations only do transformations on images where + * the [file] gamma and screen_gamma are not close reciprocals, otherwise it + * slows things down slightly, and also needlessly introduces small errors. + */ +static int /* PRIVATE */ +png_gamma_threshold(png_fixed_point screen_gamma, png_fixed_point file_gamma) +{ + /* PNG_GAMMA_THRESHOLD is the threshold for performing gamma + * correction as a difference of the overall transform from 1.0 + * + * We want to compare the threshold with s*f - 1, if we get + * overflow here it is because of wacky gamma values so we + * turn on processing anyway. + */ + png_fixed_point gtest; + return !png_muldiv(>est, screen_gamma, file_gamma, PNG_FP_1) || + png_gamma_significant(gtest); +} +#endif + +/* Initialize everything needed for the read. This includes modifying + * the palette. + */ + +/*For the moment 'png_init_palette_transformations' and + * 'png_init_rgb_transformations' only do some flag canceling optimizations. + * The intent is that these two routines should have palette or rgb operations + * extracted from 'png_init_read_transformations'. + */ +static void /* PRIVATE */ +png_init_palette_transformations(png_structrp png_ptr) +{ + /* Called to handle the (input) palette case. In png_do_read_transformations + * the first step is to expand the palette if requested, so this code must + * take care to only make changes that are invariant with respect to the + * palette expansion, or only do them if there is no expansion. + * + * STRIP_ALPHA has already been handled in the caller (by setting num_trans + * to 0.) + */ + int input_has_alpha = 0; + int input_has_transparency = 0; + + if (png_ptr->num_trans > 0) + { + int i; + + /* Ignore if all the entries are opaque (unlikely!) */ + for (i=0; inum_trans; ++i) + { + if (png_ptr->trans_alpha[i] == 255) + continue; + else if (png_ptr->trans_alpha[i] == 0) + input_has_transparency = 1; + else + { + input_has_transparency = 1; + input_has_alpha = 1; + break; + } + } + } + + /* If no alpha we can optimize. */ + if (!input_has_alpha) + { + /* Any alpha means background and associative alpha processing is + * required, however if the alpha is 0 or 1 throughout OPTIIMIZE_ALPHA + * and ENCODE_ALPHA are irrelevant. + */ + png_ptr->transformations &= ~PNG_ENCODE_ALPHA; + png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA; + + if (!input_has_transparency) + png_ptr->transformations &= ~(PNG_COMPOSE | PNG_BACKGROUND_EXPAND); + } + +#if defined(PNG_READ_EXPAND_SUPPORTED) && defined(PNG_READ_BACKGROUND_SUPPORTED) + /* png_set_background handling - deals with the complexity of whether the + * background color is in the file format or the screen format in the case + * where an 'expand' will happen. + */ + + /* The following code cannot be entered in the alpha pre-multiplication case + * because PNG_BACKGROUND_EXPAND is cancelled below. + */ + if ((png_ptr->transformations & PNG_BACKGROUND_EXPAND) && + (png_ptr->transformations & PNG_EXPAND)) + { + { + png_ptr->background.red = + png_ptr->palette[png_ptr->background.index].red; + png_ptr->background.green = + png_ptr->palette[png_ptr->background.index].green; + png_ptr->background.blue = + png_ptr->palette[png_ptr->background.index].blue; + +#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED + if (png_ptr->transformations & PNG_INVERT_ALPHA) + { + if (!(png_ptr->transformations & PNG_EXPAND_tRNS)) + { + /* Invert the alpha channel (in tRNS) unless the pixels are + * going to be expanded, in which case leave it for later + */ + int i, istop = png_ptr->num_trans; + + for (i=0; itrans_alpha[i] = (png_byte)(255 - + png_ptr->trans_alpha[i]); + } + } +#endif /* PNG_READ_INVERT_ALPHA_SUPPORTED */ + } + } /* background expand and (therefore) no alpha association. */ +#endif /* PNG_READ_EXPAND_SUPPORTED && PNG_READ_BACKGROUND_SUPPORTED */ +} + +static void /* PRIVATE */ +png_init_rgb_transformations(png_structrp png_ptr) +{ + /* Added to libpng-1.5.4: check the color type to determine whether there + * is any alpha or transparency in the image and simply cancel the + * background and alpha mode stuff if there isn't. + */ + int input_has_alpha = (png_ptr->color_type & PNG_COLOR_MASK_ALPHA) != 0; + int input_has_transparency = png_ptr->num_trans > 0; + + /* If no alpha we can optimize. */ + if (!input_has_alpha) + { + /* Any alpha means background and associative alpha processing is + * required, however if the alpha is 0 or 1 throughout OPTIIMIZE_ALPHA + * and ENCODE_ALPHA are irrelevant. + */ +# ifdef PNG_READ_ALPHA_MODE_SUPPORTED + png_ptr->transformations &= ~PNG_ENCODE_ALPHA; + png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA; +# endif + + if (!input_has_transparency) + png_ptr->transformations &= ~(PNG_COMPOSE | PNG_BACKGROUND_EXPAND); + } + +#if defined(PNG_READ_EXPAND_SUPPORTED) && defined(PNG_READ_BACKGROUND_SUPPORTED) + /* png_set_background handling - deals with the complexity of whether the + * background color is in the file format or the screen format in the case + * where an 'expand' will happen. + */ + + /* The following code cannot be entered in the alpha pre-multiplication case + * because PNG_BACKGROUND_EXPAND is cancelled below. + */ + if ((png_ptr->transformations & PNG_BACKGROUND_EXPAND) && + (png_ptr->transformations & PNG_EXPAND) && + !(png_ptr->color_type & PNG_COLOR_MASK_COLOR)) + /* i.e., GRAY or GRAY_ALPHA */ + { + { + /* Expand background and tRNS chunks */ + int gray = png_ptr->background.gray; + int trans_gray = png_ptr->trans_color.gray; + + switch (png_ptr->bit_depth) + { + case 1: + gray *= 0xff; + trans_gray *= 0xff; + break; + + case 2: + gray *= 0x55; + trans_gray *= 0x55; + break; + + case 4: + gray *= 0x11; + trans_gray *= 0x11; + break; + + default: + + case 8: + /* FALL THROUGH (Already 8 bits) */ + + case 16: + /* Already a full 16 bits */ + break; + } + + png_ptr->background.red = png_ptr->background.green = + png_ptr->background.blue = (png_uint_16)gray; + + if (!(png_ptr->transformations & PNG_EXPAND_tRNS)) + { + png_ptr->trans_color.red = png_ptr->trans_color.green = + png_ptr->trans_color.blue = (png_uint_16)trans_gray; + } + } + } /* background expand and (therefore) no alpha association. */ +#endif /* PNG_READ_EXPAND_SUPPORTED && PNG_READ_BACKGROUND_SUPPORTED */ +} + +void /* PRIVATE */ +png_init_read_transformations(png_structrp png_ptr) +{ + png_debug(1, "in png_init_read_transformations"); + + /* This internal function is called from png_read_start_row in pngrutil.c + * and it is called before the 'rowbytes' calculation is done, so the code + * in here can change or update the transformations flags. + * + * First do updates that do not depend on the details of the PNG image data + * being processed. + */ + +#ifdef PNG_READ_GAMMA_SUPPORTED + /* Prior to 1.5.4 these tests were performed from png_set_gamma, 1.5.4 adds + * png_set_alpha_mode and this is another source for a default file gamma so + * the test needs to be performed later - here. In addition prior to 1.5.4 + * the tests were repeated for the PALETTE color type here - this is no + * longer necessary (and doesn't seem to have been necessary before.) + */ + { + /* The following temporary indicates if overall gamma correction is + * required. + */ + int gamma_correction = 0; + + if (png_ptr->colorspace.gamma != 0) /* has been set */ + { + if (png_ptr->screen_gamma != 0) /* screen set too */ + gamma_correction = png_gamma_threshold(png_ptr->colorspace.gamma, + png_ptr->screen_gamma); + + else + /* Assume the output matches the input; a long time default behavior + * of libpng, although the standard has nothing to say about this. + */ + png_ptr->screen_gamma = png_reciprocal(png_ptr->colorspace.gamma); + } + + else if (png_ptr->screen_gamma != 0) + /* The converse - assume the file matches the screen, note that this + * perhaps undesireable default can (from 1.5.4) be changed by calling + * png_set_alpha_mode (even if the alpha handling mode isn't required + * or isn't changed from the default.) + */ + png_ptr->colorspace.gamma = png_reciprocal(png_ptr->screen_gamma); + + else /* neither are set */ + /* Just in case the following prevents any processing - file and screen + * are both assumed to be linear and there is no way to introduce a + * third gamma value other than png_set_background with 'UNIQUE', and, + * prior to 1.5.4 + */ + png_ptr->screen_gamma = png_ptr->colorspace.gamma = PNG_FP_1; + + /* We have a gamma value now. */ + png_ptr->colorspace.flags |= PNG_COLORSPACE_HAVE_GAMMA; + + /* Now turn the gamma transformation on or off as appropriate. Notice + * that PNG_GAMMA just refers to the file->screen correction. Alpha + * composition may independently cause gamma correction because it needs + * linear data (e.g. if the file has a gAMA chunk but the screen gamma + * hasn't been specified.) In any case this flag may get turned off in + * the code immediately below if the transform can be handled outside the + * row loop. + */ + if (gamma_correction) + png_ptr->transformations |= PNG_GAMMA; + + else + png_ptr->transformations &= ~PNG_GAMMA; + } +#endif + + /* Certain transformations have the effect of preventing other + * transformations that happen afterward in png_do_read_transformations, + * resolve the interdependencies here. From the code of + * png_do_read_transformations the order is: + * + * 1) PNG_EXPAND (including PNG_EXPAND_tRNS) + * 2) PNG_STRIP_ALPHA (if no compose) + * 3) PNG_RGB_TO_GRAY + * 4) PNG_GRAY_TO_RGB iff !PNG_BACKGROUND_IS_GRAY + * 5) PNG_COMPOSE + * 6) PNG_GAMMA + * 7) PNG_STRIP_ALPHA (if compose) + * 8) PNG_ENCODE_ALPHA + * 9) PNG_SCALE_16_TO_8 + * 10) PNG_16_TO_8 + * 11) PNG_QUANTIZE (converts to palette) + * 12) PNG_EXPAND_16 + * 13) PNG_GRAY_TO_RGB iff PNG_BACKGROUND_IS_GRAY + * 14) PNG_INVERT_MONO + * 15) PNG_SHIFT + * 16) PNG_PACK + * 17) PNG_BGR + * 18) PNG_PACKSWAP + * 19) PNG_FILLER (includes PNG_ADD_ALPHA) + * 20) PNG_INVERT_ALPHA + * 21) PNG_SWAP_ALPHA + * 22) PNG_SWAP_BYTES + * 23) PNG_USER_TRANSFORM [must be last] + */ +#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED + if ((png_ptr->transformations & PNG_STRIP_ALPHA) && + !(png_ptr->transformations & PNG_COMPOSE)) + { + /* Stripping the alpha channel happens immediately after the 'expand' + * transformations, before all other transformation, so it cancels out + * the alpha handling. It has the side effect negating the effect of + * PNG_EXPAND_tRNS too: + */ + png_ptr->transformations &= ~(PNG_BACKGROUND_EXPAND | PNG_ENCODE_ALPHA | + PNG_EXPAND_tRNS); + png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA; + + /* Kill the tRNS chunk itself too. Prior to 1.5.4 this did not happen + * so transparency information would remain just so long as it wasn't + * expanded. This produces unexpected API changes if the set of things + * that do PNG_EXPAND_tRNS changes (perfectly possible given the + * documentation - which says ask for what you want, accept what you + * get.) This makes the behavior consistent from 1.5.4: + */ + png_ptr->num_trans = 0; + } +#endif /* STRIP_ALPHA supported, no COMPOSE */ + +#ifdef PNG_READ_ALPHA_MODE_SUPPORTED + /* If the screen gamma is about 1.0 then the OPTIMIZE_ALPHA and ENCODE_ALPHA + * settings will have no effect. + */ + if (!png_gamma_significant(png_ptr->screen_gamma)) + { + png_ptr->transformations &= ~PNG_ENCODE_ALPHA; + png_ptr->flags &= ~PNG_FLAG_OPTIMIZE_ALPHA; + } +#endif + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED + /* Make sure the coefficients for the rgb to gray conversion are set + * appropriately. + */ + if (png_ptr->transformations & PNG_RGB_TO_GRAY) + png_colorspace_set_rgb_coefficients(png_ptr); +#endif + +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED +#if defined(PNG_READ_EXPAND_SUPPORTED) && defined(PNG_READ_BACKGROUND_SUPPORTED) + /* Detect gray background and attempt to enable optimization for + * gray --> RGB case. + * + * Note: if PNG_BACKGROUND_EXPAND is set and color_type is either RGB or + * RGB_ALPHA (in which case need_expand is superfluous anyway), the + * background color might actually be gray yet not be flagged as such. + * This is not a problem for the current code, which uses + * PNG_BACKGROUND_IS_GRAY only to decide when to do the + * png_do_gray_to_rgb() transformation. + * + * TODO: this code needs to be revised to avoid the complexity and + * interdependencies. The color type of the background should be recorded in + * png_set_background, along with the bit depth, then the code has a record + * of exactly what color space the background is currently in. + */ + if (png_ptr->transformations & PNG_BACKGROUND_EXPAND) + { + /* PNG_BACKGROUND_EXPAND: the background is in the file color space, so if + * the file was grayscale the background value is gray. + */ + if (!(png_ptr->color_type & PNG_COLOR_MASK_COLOR)) + png_ptr->mode |= PNG_BACKGROUND_IS_GRAY; + } + + else if (png_ptr->transformations & PNG_COMPOSE) + { + /* PNG_COMPOSE: png_set_background was called with need_expand false, + * so the color is in the color space of the output or png_set_alpha_mode + * was called and the color is black. Ignore RGB_TO_GRAY because that + * happens before GRAY_TO_RGB. + */ + if (png_ptr->transformations & PNG_GRAY_TO_RGB) + { + if (png_ptr->background.red == png_ptr->background.green && + png_ptr->background.red == png_ptr->background.blue) + { + png_ptr->mode |= PNG_BACKGROUND_IS_GRAY; + png_ptr->background.gray = png_ptr->background.red; + } + } + } +#endif /* PNG_READ_EXPAND_SUPPORTED && PNG_READ_BACKGROUND_SUPPORTED */ +#endif /* PNG_READ_GRAY_TO_RGB_SUPPORTED */ + + /* For indexed PNG data (PNG_COLOR_TYPE_PALETTE) many of the transformations + * can be performed directly on the palette, and some (such as rgb to gray) + * can be optimized inside the palette. This is particularly true of the + * composite (background and alpha) stuff, which can be pretty much all done + * in the palette even if the result is expanded to RGB or gray afterward. + * + * NOTE: this is Not Yet Implemented, the code behaves as in 1.5.1 and + * earlier and the palette stuff is actually handled on the first row. This + * leads to the reported bug that the palette returned by png_get_PLTE is not + * updated. + */ + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + png_init_palette_transformations(png_ptr); + + else + png_init_rgb_transformations(png_ptr); + +#if defined(PNG_READ_BACKGROUND_SUPPORTED) && \ + defined(PNG_READ_EXPAND_16_SUPPORTED) + if ((png_ptr->transformations & PNG_EXPAND_16) && + (png_ptr->transformations & PNG_COMPOSE) && + !(png_ptr->transformations & PNG_BACKGROUND_EXPAND) && + png_ptr->bit_depth != 16) + { + /* TODO: fix this. Because the expand_16 operation is after the compose + * handling the background color must be 8, not 16, bits deep, but the + * application will supply a 16-bit value so reduce it here. + * + * The PNG_BACKGROUND_EXPAND code above does not expand to 16 bits at + * present, so that case is ok (until do_expand_16 is moved.) + * + * NOTE: this discards the low 16 bits of the user supplied background + * color, but until expand_16 works properly there is no choice! + */ +# define CHOP(x) (x)=((png_uint_16)PNG_DIV257(x)) + CHOP(png_ptr->background.red); + CHOP(png_ptr->background.green); + CHOP(png_ptr->background.blue); + CHOP(png_ptr->background.gray); +# undef CHOP + } +#endif /* PNG_READ_BACKGROUND_SUPPORTED && PNG_READ_EXPAND_16_SUPPORTED */ + +#if defined(PNG_READ_BACKGROUND_SUPPORTED) && \ + (defined(PNG_READ_SCALE_16_TO_8_SUPPORTED) || \ + defined(PNG_READ_STRIP_16_TO_8_SUPPORTED)) + if ((png_ptr->transformations & (PNG_16_TO_8|PNG_SCALE_16_TO_8)) && + (png_ptr->transformations & PNG_COMPOSE) && + !(png_ptr->transformations & PNG_BACKGROUND_EXPAND) && + png_ptr->bit_depth == 16) + { + /* On the other hand, if a 16-bit file is to be reduced to 8-bits per + * component this will also happen after PNG_COMPOSE and so the background + * color must be pre-expanded here. + * + * TODO: fix this too. + */ + png_ptr->background.red = (png_uint_16)(png_ptr->background.red * 257); + png_ptr->background.green = + (png_uint_16)(png_ptr->background.green * 257); + png_ptr->background.blue = (png_uint_16)(png_ptr->background.blue * 257); + png_ptr->background.gray = (png_uint_16)(png_ptr->background.gray * 257); + } +#endif + + /* NOTE: below 'PNG_READ_ALPHA_MODE_SUPPORTED' is presumed to also enable the + * background support (see the comments in scripts/pnglibconf.dfa), this + * allows pre-multiplication of the alpha channel to be implemented as + * compositing on black. This is probably sub-optimal and has been done in + * 1.5.4 betas simply to enable external critique and testing (i.e. to + * implement the new API quickly, without lots of internal changes.) + */ + +#ifdef PNG_READ_GAMMA_SUPPORTED +# ifdef PNG_READ_BACKGROUND_SUPPORTED + /* Includes ALPHA_MODE */ + png_ptr->background_1 = png_ptr->background; +# endif + + /* This needs to change - in the palette image case a whole set of tables are + * built when it would be quicker to just calculate the correct value for + * each palette entry directly. Also, the test is too tricky - why check + * PNG_RGB_TO_GRAY if PNG_GAMMA is not set? The answer seems to be that + * PNG_GAMMA is cancelled even if the gamma is known? The test excludes the + * PNG_COMPOSE case, so apparently if there is no *overall* gamma correction + * the gamma tables will not be built even if composition is required on a + * gamma encoded value. + * + * In 1.5.4 this is addressed below by an additional check on the individual + * file gamma - if it is not 1.0 both RGB_TO_GRAY and COMPOSE need the + * tables. + */ + if ((png_ptr->transformations & PNG_GAMMA) + || ((png_ptr->transformations & PNG_RGB_TO_GRAY) + && (png_gamma_significant(png_ptr->colorspace.gamma) || + png_gamma_significant(png_ptr->screen_gamma))) + || ((png_ptr->transformations & PNG_COMPOSE) + && (png_gamma_significant(png_ptr->colorspace.gamma) + || png_gamma_significant(png_ptr->screen_gamma) +# ifdef PNG_READ_BACKGROUND_SUPPORTED + || (png_ptr->background_gamma_type == PNG_BACKGROUND_GAMMA_UNIQUE + && png_gamma_significant(png_ptr->background_gamma)) +# endif + )) || ((png_ptr->transformations & PNG_ENCODE_ALPHA) + && png_gamma_significant(png_ptr->screen_gamma)) + ) + { + png_build_gamma_table(png_ptr, png_ptr->bit_depth); + +#ifdef PNG_READ_BACKGROUND_SUPPORTED + if (png_ptr->transformations & PNG_COMPOSE) + { + /* Issue a warning about this combination: because RGB_TO_GRAY is + * optimized to do the gamma transform if present yet do_background has + * to do the same thing if both options are set a + * double-gamma-correction happens. This is true in all versions of + * libpng to date. + */ + if (png_ptr->transformations & PNG_RGB_TO_GRAY) + png_warning(png_ptr, + "libpng does not support gamma+background+rgb_to_gray"); + + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + { + /* We don't get to here unless there is a tRNS chunk with non-opaque + * entries - see the checking code at the start of this function. + */ + png_color back, back_1; + png_colorp palette = png_ptr->palette; + int num_palette = png_ptr->num_palette; + int i; + if (png_ptr->background_gamma_type == PNG_BACKGROUND_GAMMA_FILE) + { + + back.red = png_ptr->gamma_table[png_ptr->background.red]; + back.green = png_ptr->gamma_table[png_ptr->background.green]; + back.blue = png_ptr->gamma_table[png_ptr->background.blue]; + + back_1.red = png_ptr->gamma_to_1[png_ptr->background.red]; + back_1.green = png_ptr->gamma_to_1[png_ptr->background.green]; + back_1.blue = png_ptr->gamma_to_1[png_ptr->background.blue]; + } + else + { + png_fixed_point g, gs; + + switch (png_ptr->background_gamma_type) + { + case PNG_BACKGROUND_GAMMA_SCREEN: + g = (png_ptr->screen_gamma); + gs = PNG_FP_1; + break; + + case PNG_BACKGROUND_GAMMA_FILE: + g = png_reciprocal(png_ptr->colorspace.gamma); + gs = png_reciprocal2(png_ptr->colorspace.gamma, + png_ptr->screen_gamma); + break; + + case PNG_BACKGROUND_GAMMA_UNIQUE: + g = png_reciprocal(png_ptr->background_gamma); + gs = png_reciprocal2(png_ptr->background_gamma, + png_ptr->screen_gamma); + break; + default: + g = PNG_FP_1; /* back_1 */ + gs = PNG_FP_1; /* back */ + break; + } + + if (png_gamma_significant(gs)) + { + back.red = png_gamma_8bit_correct(png_ptr->background.red, + gs); + back.green = png_gamma_8bit_correct(png_ptr->background.green, + gs); + back.blue = png_gamma_8bit_correct(png_ptr->background.blue, + gs); + } + + else + { + back.red = (png_byte)png_ptr->background.red; + back.green = (png_byte)png_ptr->background.green; + back.blue = (png_byte)png_ptr->background.blue; + } + + if (png_gamma_significant(g)) + { + back_1.red = png_gamma_8bit_correct(png_ptr->background.red, + g); + back_1.green = png_gamma_8bit_correct( + png_ptr->background.green, g); + back_1.blue = png_gamma_8bit_correct(png_ptr->background.blue, + g); + } + + else + { + back_1.red = (png_byte)png_ptr->background.red; + back_1.green = (png_byte)png_ptr->background.green; + back_1.blue = (png_byte)png_ptr->background.blue; + } + } + + for (i = 0; i < num_palette; i++) + { + if (i < (int)png_ptr->num_trans && + png_ptr->trans_alpha[i] != 0xff) + { + if (png_ptr->trans_alpha[i] == 0) + { + palette[i] = back; + } + else /* if (png_ptr->trans_alpha[i] != 0xff) */ + { + png_byte v, w; + + v = png_ptr->gamma_to_1[palette[i].red]; + png_composite(w, v, png_ptr->trans_alpha[i], back_1.red); + palette[i].red = png_ptr->gamma_from_1[w]; + + v = png_ptr->gamma_to_1[palette[i].green]; + png_composite(w, v, png_ptr->trans_alpha[i], back_1.green); + palette[i].green = png_ptr->gamma_from_1[w]; + + v = png_ptr->gamma_to_1[palette[i].blue]; + png_composite(w, v, png_ptr->trans_alpha[i], back_1.blue); + palette[i].blue = png_ptr->gamma_from_1[w]; + } + } + else + { + palette[i].red = png_ptr->gamma_table[palette[i].red]; + palette[i].green = png_ptr->gamma_table[palette[i].green]; + palette[i].blue = png_ptr->gamma_table[palette[i].blue]; + } + } + + /* Prevent the transformations being done again. + * + * NOTE: this is highly dubious; it removes the transformations in + * place. This seems inconsistent with the general treatment of the + * transformations elsewhere. + */ + png_ptr->transformations &= ~(PNG_COMPOSE | PNG_GAMMA); + } /* color_type == PNG_COLOR_TYPE_PALETTE */ + + /* if (png_ptr->background_gamma_type!=PNG_BACKGROUND_GAMMA_UNKNOWN) */ + else /* color_type != PNG_COLOR_TYPE_PALETTE */ + { + int gs_sig, g_sig; + png_fixed_point g = PNG_FP_1; /* Correction to linear */ + png_fixed_point gs = PNG_FP_1; /* Correction to screen */ + + switch (png_ptr->background_gamma_type) + { + case PNG_BACKGROUND_GAMMA_SCREEN: + g = png_ptr->screen_gamma; + /* gs = PNG_FP_1; */ + break; + + case PNG_BACKGROUND_GAMMA_FILE: + g = png_reciprocal(png_ptr->colorspace.gamma); + gs = png_reciprocal2(png_ptr->colorspace.gamma, + png_ptr->screen_gamma); + break; + + case PNG_BACKGROUND_GAMMA_UNIQUE: + g = png_reciprocal(png_ptr->background_gamma); + gs = png_reciprocal2(png_ptr->background_gamma, + png_ptr->screen_gamma); + break; + + default: + png_error(png_ptr, "invalid background gamma type"); + } + + g_sig = png_gamma_significant(g); + gs_sig = png_gamma_significant(gs); + + if (g_sig) + png_ptr->background_1.gray = png_gamma_correct(png_ptr, + png_ptr->background.gray, g); + + if (gs_sig) + png_ptr->background.gray = png_gamma_correct(png_ptr, + png_ptr->background.gray, gs); + + if ((png_ptr->background.red != png_ptr->background.green) || + (png_ptr->background.red != png_ptr->background.blue) || + (png_ptr->background.red != png_ptr->background.gray)) + { + /* RGB or RGBA with color background */ + if (g_sig) + { + png_ptr->background_1.red = png_gamma_correct(png_ptr, + png_ptr->background.red, g); + + png_ptr->background_1.green = png_gamma_correct(png_ptr, + png_ptr->background.green, g); + + png_ptr->background_1.blue = png_gamma_correct(png_ptr, + png_ptr->background.blue, g); + } + + if (gs_sig) + { + png_ptr->background.red = png_gamma_correct(png_ptr, + png_ptr->background.red, gs); + + png_ptr->background.green = png_gamma_correct(png_ptr, + png_ptr->background.green, gs); + + png_ptr->background.blue = png_gamma_correct(png_ptr, + png_ptr->background.blue, gs); + } + } + + else + { + /* GRAY, GRAY ALPHA, RGB, or RGBA with gray background */ + png_ptr->background_1.red = png_ptr->background_1.green + = png_ptr->background_1.blue = png_ptr->background_1.gray; + + png_ptr->background.red = png_ptr->background.green + = png_ptr->background.blue = png_ptr->background.gray; + } + + /* The background is now in screen gamma: */ + png_ptr->background_gamma_type = PNG_BACKGROUND_GAMMA_SCREEN; + } /* color_type != PNG_COLOR_TYPE_PALETTE */ + }/* png_ptr->transformations & PNG_BACKGROUND */ + + else + /* Transformation does not include PNG_BACKGROUND */ +#endif /* PNG_READ_BACKGROUND_SUPPORTED */ + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED + /* RGB_TO_GRAY needs to have non-gamma-corrected values! */ + && ((png_ptr->transformations & PNG_EXPAND) == 0 || + (png_ptr->transformations & PNG_RGB_TO_GRAY) == 0) +#endif + ) + { + png_colorp palette = png_ptr->palette; + int num_palette = png_ptr->num_palette; + int i; + + /* NOTE: there are other transformations that should probably be in + * here too. + */ + for (i = 0; i < num_palette; i++) + { + palette[i].red = png_ptr->gamma_table[palette[i].red]; + palette[i].green = png_ptr->gamma_table[palette[i].green]; + palette[i].blue = png_ptr->gamma_table[palette[i].blue]; + } + + /* Done the gamma correction. */ + png_ptr->transformations &= ~PNG_GAMMA; + } /* color_type == PALETTE && !PNG_BACKGROUND transformation */ + } +#ifdef PNG_READ_BACKGROUND_SUPPORTED + else +#endif +#endif /* PNG_READ_GAMMA_SUPPORTED */ + +#ifdef PNG_READ_BACKGROUND_SUPPORTED + /* No GAMMA transformation (see the hanging else 4 lines above) */ + if ((png_ptr->transformations & PNG_COMPOSE) && + (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)) + { + int i; + int istop = (int)png_ptr->num_trans; + png_color back; + png_colorp palette = png_ptr->palette; + + back.red = (png_byte)png_ptr->background.red; + back.green = (png_byte)png_ptr->background.green; + back.blue = (png_byte)png_ptr->background.blue; + + for (i = 0; i < istop; i++) + { + if (png_ptr->trans_alpha[i] == 0) + { + palette[i] = back; + } + + else if (png_ptr->trans_alpha[i] != 0xff) + { + /* The png_composite() macro is defined in png.h */ + png_composite(palette[i].red, palette[i].red, + png_ptr->trans_alpha[i], back.red); + + png_composite(palette[i].green, palette[i].green, + png_ptr->trans_alpha[i], back.green); + + png_composite(palette[i].blue, palette[i].blue, + png_ptr->trans_alpha[i], back.blue); + } + } + + png_ptr->transformations &= ~PNG_COMPOSE; + } +#endif /* PNG_READ_BACKGROUND_SUPPORTED */ + +#ifdef PNG_READ_SHIFT_SUPPORTED + if ((png_ptr->transformations & PNG_SHIFT) && + !(png_ptr->transformations & PNG_EXPAND) && + (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE)) + { + int i; + int istop = png_ptr->num_palette; + int shift = 8 - png_ptr->sig_bit.red; + + png_ptr->transformations &= ~PNG_SHIFT; + + /* significant bits can be in the range 1 to 7 for a meaninful result, if + * the number of significant bits is 0 then no shift is done (this is an + * error condition which is silently ignored.) + */ + if (shift > 0 && shift < 8) + for (i=0; ipalette[i].red; + + component >>= shift; + png_ptr->palette[i].red = (png_byte)component; + } + + shift = 8 - png_ptr->sig_bit.green; + if (shift > 0 && shift < 8) + for (i=0; ipalette[i].green; + + component >>= shift; + png_ptr->palette[i].green = (png_byte)component; + } + + shift = 8 - png_ptr->sig_bit.blue; + if (shift > 0 && shift < 8) + for (i=0; ipalette[i].blue; + + component >>= shift; + png_ptr->palette[i].blue = (png_byte)component; + } + } +#endif /* PNG_READ_SHIFT_SUPPORTED */ +} + +/* Modify the info structure to reflect the transformations. The + * info should be updated so a PNG file could be written with it, + * assuming the transformations result in valid PNG data. + */ +void /* PRIVATE */ +png_read_transform_info(png_structrp png_ptr, png_inforp info_ptr) +{ + png_debug(1, "in png_read_transform_info"); + +#ifdef PNG_READ_EXPAND_SUPPORTED + if (png_ptr->transformations & PNG_EXPAND) + { + if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + { + /* This check must match what actually happens in + * png_do_expand_palette; if it ever checks the tRNS chunk to see if + * it is all opaque we must do the same (at present it does not.) + */ + if (png_ptr->num_trans > 0) + info_ptr->color_type = PNG_COLOR_TYPE_RGB_ALPHA; + + else + info_ptr->color_type = PNG_COLOR_TYPE_RGB; + + info_ptr->bit_depth = 8; + info_ptr->num_trans = 0; + } + else + { + if (png_ptr->num_trans) + { + if (png_ptr->transformations & PNG_EXPAND_tRNS) + info_ptr->color_type |= PNG_COLOR_MASK_ALPHA; + } + if (info_ptr->bit_depth < 8) + info_ptr->bit_depth = 8; + + info_ptr->num_trans = 0; + } + } +#endif + +#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) + /* The following is almost certainly wrong unless the background value is in + * the screen space! + */ + if (png_ptr->transformations & PNG_COMPOSE) + info_ptr->background = png_ptr->background; +#endif + +#ifdef PNG_READ_GAMMA_SUPPORTED + /* The following used to be conditional on PNG_GAMMA (prior to 1.5.4), + * however it seems that the code in png_init_read_transformations, which has + * been called before this from png_read_update_info->png_read_start_row + * sometimes does the gamma transform and cancels the flag. + * + * TODO: this looks wrong; the info_ptr should end up with a gamma equal to + * the screen_gamma value. The following probably results in weirdness if + * the info_ptr is used by the app after the rows have been read. + */ + info_ptr->colorspace.gamma = png_ptr->colorspace.gamma; +#endif + + if (info_ptr->bit_depth == 16) + { +# ifdef PNG_READ_16BIT_SUPPORTED +# ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED + if (png_ptr->transformations & PNG_SCALE_16_TO_8) + info_ptr->bit_depth = 8; +# endif + +# ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED + if (png_ptr->transformations & PNG_16_TO_8) + info_ptr->bit_depth = 8; +# endif + +# else + /* No 16 bit support: force chopping 16-bit input down to 8, in this case + * the app program can chose if both APIs are available by setting the + * correct scaling to use. + */ +# ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED + /* For compatibility with previous versions use the strip method by + * default. This code works because if PNG_SCALE_16_TO_8 is already + * set the code below will do that in preference to the chop. + */ + png_ptr->transformations |= PNG_16_TO_8; + info_ptr->bit_depth = 8; +# else + +# ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED + png_ptr->transformations |= PNG_SCALE_16_TO_8; + info_ptr->bit_depth = 8; +# else + + CONFIGURATION ERROR: you must enable at least one 16 to 8 method +# endif +# endif +#endif /* !READ_16BIT_SUPPORTED */ + } + +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED + if (png_ptr->transformations & PNG_GRAY_TO_RGB) + info_ptr->color_type = (png_byte)(info_ptr->color_type | + PNG_COLOR_MASK_COLOR); +#endif + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED + if (png_ptr->transformations & PNG_RGB_TO_GRAY) + info_ptr->color_type = (png_byte)(info_ptr->color_type & + ~PNG_COLOR_MASK_COLOR); +#endif + +#ifdef PNG_READ_QUANTIZE_SUPPORTED + if (png_ptr->transformations & PNG_QUANTIZE) + { + if (((info_ptr->color_type == PNG_COLOR_TYPE_RGB) || + (info_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA)) && + png_ptr->palette_lookup && info_ptr->bit_depth == 8) + { + info_ptr->color_type = PNG_COLOR_TYPE_PALETTE; + } + } +#endif + +#ifdef PNG_READ_EXPAND_16_SUPPORTED + if (png_ptr->transformations & PNG_EXPAND_16 && info_ptr->bit_depth == 8 && + info_ptr->color_type != PNG_COLOR_TYPE_PALETTE) + { + info_ptr->bit_depth = 16; + } +#endif + +#ifdef PNG_READ_PACK_SUPPORTED + if ((png_ptr->transformations & PNG_PACK) && (info_ptr->bit_depth < 8)) + info_ptr->bit_depth = 8; +#endif + + if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + info_ptr->channels = 1; + + else if (info_ptr->color_type & PNG_COLOR_MASK_COLOR) + info_ptr->channels = 3; + + else + info_ptr->channels = 1; + +#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED + if (png_ptr->transformations & PNG_STRIP_ALPHA) + { + info_ptr->color_type = (png_byte)(info_ptr->color_type & + ~PNG_COLOR_MASK_ALPHA); + info_ptr->num_trans = 0; + } +#endif + + if (info_ptr->color_type & PNG_COLOR_MASK_ALPHA) + info_ptr->channels++; + +#ifdef PNG_READ_FILLER_SUPPORTED + /* STRIP_ALPHA and FILLER allowed: MASK_ALPHA bit stripped above */ + if ((png_ptr->transformations & PNG_FILLER) && + ((info_ptr->color_type == PNG_COLOR_TYPE_RGB) || + (info_ptr->color_type == PNG_COLOR_TYPE_GRAY))) + { + info_ptr->channels++; + /* If adding a true alpha channel not just filler */ + if (png_ptr->transformations & PNG_ADD_ALPHA) + info_ptr->color_type |= PNG_COLOR_MASK_ALPHA; + } +#endif + +#if defined(PNG_USER_TRANSFORM_PTR_SUPPORTED) && \ +defined(PNG_READ_USER_TRANSFORM_SUPPORTED) + if (png_ptr->transformations & PNG_USER_TRANSFORM) + { + if (info_ptr->bit_depth < png_ptr->user_transform_depth) + info_ptr->bit_depth = png_ptr->user_transform_depth; + + if (info_ptr->channels < png_ptr->user_transform_channels) + info_ptr->channels = png_ptr->user_transform_channels; + } +#endif + + info_ptr->pixel_depth = (png_byte)(info_ptr->channels * + info_ptr->bit_depth); + + info_ptr->rowbytes = PNG_ROWBYTES(info_ptr->pixel_depth, info_ptr->width); + + /* Adding in 1.5.4: cache the above value in png_struct so that we can later + * check in png_rowbytes that the user buffer won't get overwritten. Note + * that the field is not always set - if png_read_update_info isn't called + * the application has to either not do any transforms or get the calculation + * right itself. + */ + png_ptr->info_rowbytes = info_ptr->rowbytes; + +#ifndef PNG_READ_EXPAND_SUPPORTED + if (png_ptr) + return; +#endif +} + +/* Transform the row. The order of transformations is significant, + * and is very touchy. If you add a transformation, take care to + * decide how it fits in with the other transformations here. + */ +void /* PRIVATE */ +png_do_read_transformations(png_structrp png_ptr, png_row_infop row_info) +{ + png_debug(1, "in png_do_read_transformations"); + + if (png_ptr->row_buf == NULL) + { + /* Prior to 1.5.4 this output row/pass where the NULL pointer is, but this + * error is incredibly rare and incredibly easy to debug without this + * information. + */ + png_error(png_ptr, "NULL row buffer"); + } + + /* The following is debugging; prior to 1.5.4 the code was never compiled in; + * in 1.5.4 PNG_FLAG_DETECT_UNINITIALIZED was added and the macro + * PNG_WARN_UNINITIALIZED_ROW removed. In 1.6 the new flag is set only for + * all transformations, however in practice the ROW_INIT always gets done on + * demand, if necessary. + */ + if ((png_ptr->flags & PNG_FLAG_DETECT_UNINITIALIZED) != 0 && + !(png_ptr->flags & PNG_FLAG_ROW_INIT)) + { + /* Application has failed to call either png_read_start_image() or + * png_read_update_info() after setting transforms that expand pixels. + * This check added to libpng-1.2.19 (but not enabled until 1.5.4). + */ + png_error(png_ptr, "Uninitialized row"); + } + +#ifdef PNG_READ_EXPAND_SUPPORTED + if (png_ptr->transformations & PNG_EXPAND) + { + if (row_info->color_type == PNG_COLOR_TYPE_PALETTE) + { + png_do_expand_palette(row_info, png_ptr->row_buf + 1, + png_ptr->palette, png_ptr->trans_alpha, png_ptr->num_trans); + } + + else + { + if (png_ptr->num_trans && + (png_ptr->transformations & PNG_EXPAND_tRNS)) + png_do_expand(row_info, png_ptr->row_buf + 1, + &(png_ptr->trans_color)); + + else + png_do_expand(row_info, png_ptr->row_buf + 1, + NULL); + } + } +#endif + +#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED + if ((png_ptr->transformations & PNG_STRIP_ALPHA) && + !(png_ptr->transformations & PNG_COMPOSE) && + (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA || + row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)) + png_do_strip_channel(row_info, png_ptr->row_buf + 1, + 0 /* at_start == false, because SWAP_ALPHA happens later */); +#endif + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED + if (png_ptr->transformations & PNG_RGB_TO_GRAY) + { + int rgb_error = + png_do_rgb_to_gray(png_ptr, row_info, + png_ptr->row_buf + 1); + + if (rgb_error) + { + png_ptr->rgb_to_gray_status=1; + if ((png_ptr->transformations & PNG_RGB_TO_GRAY) == + PNG_RGB_TO_GRAY_WARN) + png_warning(png_ptr, "png_do_rgb_to_gray found nongray pixel"); + + if ((png_ptr->transformations & PNG_RGB_TO_GRAY) == + PNG_RGB_TO_GRAY_ERR) + png_error(png_ptr, "png_do_rgb_to_gray found nongray pixel"); + } + } +#endif + +/* From Andreas Dilger e-mail to png-implement, 26 March 1998: + * + * In most cases, the "simple transparency" should be done prior to doing + * gray-to-RGB, or you will have to test 3x as many bytes to check if a + * pixel is transparent. You would also need to make sure that the + * transparency information is upgraded to RGB. + * + * To summarize, the current flow is: + * - Gray + simple transparency -> compare 1 or 2 gray bytes and composite + * with background "in place" if transparent, + * convert to RGB if necessary + * - Gray + alpha -> composite with gray background and remove alpha bytes, + * convert to RGB if necessary + * + * To support RGB backgrounds for gray images we need: + * - Gray + simple transparency -> convert to RGB + simple transparency, + * compare 3 or 6 bytes and composite with + * background "in place" if transparent + * (3x compare/pixel compared to doing + * composite with gray bkgrnd) + * - Gray + alpha -> convert to RGB + alpha, composite with background and + * remove alpha bytes (3x float + * operations/pixel compared with composite + * on gray background) + * + * Greg's change will do this. The reason it wasn't done before is for + * performance, as this increases the per-pixel operations. If we would check + * in advance if the background was gray or RGB, and position the gray-to-RGB + * transform appropriately, then it would save a lot of work/time. + */ + +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED + /* If gray -> RGB, do so now only if background is non-gray; else do later + * for performance reasons + */ + if ((png_ptr->transformations & PNG_GRAY_TO_RGB) && + !(png_ptr->mode & PNG_BACKGROUND_IS_GRAY)) + png_do_gray_to_rgb(row_info, png_ptr->row_buf + 1); +#endif + +#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) + if (png_ptr->transformations & PNG_COMPOSE) + png_do_compose(row_info, png_ptr->row_buf + 1, png_ptr); +#endif + +#ifdef PNG_READ_GAMMA_SUPPORTED + if ((png_ptr->transformations & PNG_GAMMA) && +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED + /* Because RGB_TO_GRAY does the gamma transform. */ + !(png_ptr->transformations & PNG_RGB_TO_GRAY) && +#endif +#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) + /* Because PNG_COMPOSE does the gamma transform if there is something to + * do (if there is an alpha channel or transparency.) + */ + !((png_ptr->transformations & PNG_COMPOSE) && + ((png_ptr->num_trans != 0) || + (png_ptr->color_type & PNG_COLOR_MASK_ALPHA))) && +#endif + /* Because png_init_read_transformations transforms the palette, unless + * RGB_TO_GRAY will do the transform. + */ + (png_ptr->color_type != PNG_COLOR_TYPE_PALETTE)) + png_do_gamma(row_info, png_ptr->row_buf + 1, png_ptr); +#endif + +#ifdef PNG_READ_STRIP_ALPHA_SUPPORTED + if ((png_ptr->transformations & PNG_STRIP_ALPHA) && + (png_ptr->transformations & PNG_COMPOSE) && + (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA || + row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)) + png_do_strip_channel(row_info, png_ptr->row_buf + 1, + 0 /* at_start == false, because SWAP_ALPHA happens later */); +#endif + +#ifdef PNG_READ_ALPHA_MODE_SUPPORTED + if ((png_ptr->transformations & PNG_ENCODE_ALPHA) && + (row_info->color_type & PNG_COLOR_MASK_ALPHA)) + png_do_encode_alpha(row_info, png_ptr->row_buf + 1, png_ptr); +#endif + +#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED + if (png_ptr->transformations & PNG_SCALE_16_TO_8) + png_do_scale_16_to_8(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED + /* There is no harm in doing both of these because only one has any effect, + * by putting the 'scale' option first if the app asks for scale (either by + * calling the API or in a TRANSFORM flag) this is what happens. + */ + if (png_ptr->transformations & PNG_16_TO_8) + png_do_chop(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_QUANTIZE_SUPPORTED + if (png_ptr->transformations & PNG_QUANTIZE) + { + png_do_quantize(row_info, png_ptr->row_buf + 1, + png_ptr->palette_lookup, png_ptr->quantize_index); + + if (row_info->rowbytes == 0) + png_error(png_ptr, "png_do_quantize returned rowbytes=0"); + } +#endif /* PNG_READ_QUANTIZE_SUPPORTED */ + +#ifdef PNG_READ_EXPAND_16_SUPPORTED + /* Do the expansion now, after all the arithmetic has been done. Notice + * that previous transformations can handle the PNG_EXPAND_16 flag if this + * is efficient (particularly true in the case of gamma correction, where + * better accuracy results faster!) + */ + if (png_ptr->transformations & PNG_EXPAND_16) + png_do_expand_16(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED + /* NOTE: moved here in 1.5.4 (from much later in this list.) */ + if ((png_ptr->transformations & PNG_GRAY_TO_RGB) && + (png_ptr->mode & PNG_BACKGROUND_IS_GRAY)) + png_do_gray_to_rgb(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_INVERT_SUPPORTED + if (png_ptr->transformations & PNG_INVERT_MONO) + png_do_invert(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_SHIFT_SUPPORTED + if (png_ptr->transformations & PNG_SHIFT) + png_do_unshift(row_info, png_ptr->row_buf + 1, + &(png_ptr->shift)); +#endif + +#ifdef PNG_READ_PACK_SUPPORTED + if (png_ptr->transformations & PNG_PACK) + png_do_unpack(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED + /* Added at libpng-1.5.10 */ + if (row_info->color_type == PNG_COLOR_TYPE_PALETTE && + png_ptr->num_palette_max >= 0) + png_do_check_palette_indexes(png_ptr, row_info); +#endif + +#ifdef PNG_READ_BGR_SUPPORTED + if (png_ptr->transformations & PNG_BGR) + png_do_bgr(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_PACKSWAP_SUPPORTED + if (png_ptr->transformations & PNG_PACKSWAP) + png_do_packswap(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_FILLER_SUPPORTED + if (png_ptr->transformations & PNG_FILLER) + png_do_read_filler(row_info, png_ptr->row_buf + 1, + (png_uint_32)png_ptr->filler, png_ptr->flags); +#endif + +#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED + if (png_ptr->transformations & PNG_INVERT_ALPHA) + png_do_read_invert_alpha(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_SWAP_ALPHA_SUPPORTED + if (png_ptr->transformations & PNG_SWAP_ALPHA) + png_do_read_swap_alpha(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_READ_16BIT_SUPPORTED +#ifdef PNG_READ_SWAP_SUPPORTED + if (png_ptr->transformations & PNG_SWAP_BYTES) + png_do_swap(row_info, png_ptr->row_buf + 1); +#endif +#endif + +#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED + if (png_ptr->transformations & PNG_USER_TRANSFORM) + { + if (png_ptr->read_user_transform_fn != NULL) + (*(png_ptr->read_user_transform_fn)) /* User read transform function */ + (png_ptr, /* png_ptr */ + row_info, /* row_info: */ + /* png_uint_32 width; width of row */ + /* png_size_t rowbytes; number of bytes in row */ + /* png_byte color_type; color type of pixels */ + /* png_byte bit_depth; bit depth of samples */ + /* png_byte channels; number of channels (1-4) */ + /* png_byte pixel_depth; bits per pixel (depth*channels) */ + png_ptr->row_buf + 1); /* start of pixel data for row */ +#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED + if (png_ptr->user_transform_depth) + row_info->bit_depth = png_ptr->user_transform_depth; + + if (png_ptr->user_transform_channels) + row_info->channels = png_ptr->user_transform_channels; +#endif + row_info->pixel_depth = (png_byte)(row_info->bit_depth * + row_info->channels); + + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_info->width); + } +#endif +} + +#ifdef PNG_READ_PACK_SUPPORTED +/* Unpack pixels of 1, 2, or 4 bits per pixel into 1 byte per pixel, + * without changing the actual values. Thus, if you had a row with + * a bit depth of 1, you would end up with bytes that only contained + * the numbers 0 or 1. If you would rather they contain 0 and 255, use + * png_do_shift() after this. + */ +void /* PRIVATE */ +png_do_unpack(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_unpack"); + + if (row_info->bit_depth < 8) + { + png_uint_32 i; + png_uint_32 row_width=row_info->width; + + switch (row_info->bit_depth) + { + case 1: + { + png_bytep sp = row + (png_size_t)((row_width - 1) >> 3); + png_bytep dp = row + (png_size_t)row_width - 1; + png_uint_32 shift = 7 - (int)((row_width + 7) & 0x07); + for (i = 0; i < row_width; i++) + { + *dp = (png_byte)((*sp >> shift) & 0x01); + + if (shift == 7) + { + shift = 0; + sp--; + } + + else + shift++; + + dp--; + } + break; + } + + case 2: + { + + png_bytep sp = row + (png_size_t)((row_width - 1) >> 2); + png_bytep dp = row + (png_size_t)row_width - 1; + png_uint_32 shift = (int)((3 - ((row_width + 3) & 0x03)) << 1); + for (i = 0; i < row_width; i++) + { + *dp = (png_byte)((*sp >> shift) & 0x03); + + if (shift == 6) + { + shift = 0; + sp--; + } + + else + shift += 2; + + dp--; + } + break; + } + + case 4: + { + png_bytep sp = row + (png_size_t)((row_width - 1) >> 1); + png_bytep dp = row + (png_size_t)row_width - 1; + png_uint_32 shift = (int)((1 - ((row_width + 1) & 0x01)) << 2); + for (i = 0; i < row_width; i++) + { + *dp = (png_byte)((*sp >> shift) & 0x0f); + + if (shift == 4) + { + shift = 0; + sp--; + } + + else + shift = 4; + + dp--; + } + break; + } + + default: + break; + } + row_info->bit_depth = 8; + row_info->pixel_depth = (png_byte)(8 * row_info->channels); + row_info->rowbytes = row_width * row_info->channels; + } +} +#endif + +#ifdef PNG_READ_SHIFT_SUPPORTED +/* Reverse the effects of png_do_shift. This routine merely shifts the + * pixels back to their significant bits values. Thus, if you have + * a row of bit depth 8, but only 5 are significant, this will shift + * the values back to 0 through 31. + */ +void /* PRIVATE */ +png_do_unshift(png_row_infop row_info, png_bytep row, + png_const_color_8p sig_bits) +{ + int color_type; + + png_debug(1, "in png_do_unshift"); + + /* The palette case has already been handled in the _init routine. */ + color_type = row_info->color_type; + + if (color_type != PNG_COLOR_TYPE_PALETTE) + { + int shift[4]; + int channels = 0; + int bit_depth = row_info->bit_depth; + + if (color_type & PNG_COLOR_MASK_COLOR) + { + shift[channels++] = bit_depth - sig_bits->red; + shift[channels++] = bit_depth - sig_bits->green; + shift[channels++] = bit_depth - sig_bits->blue; + } + + else + { + shift[channels++] = bit_depth - sig_bits->gray; + } + + if (color_type & PNG_COLOR_MASK_ALPHA) + { + shift[channels++] = bit_depth - sig_bits->alpha; + } + + { + int c, have_shift; + + for (c = have_shift = 0; c < channels; ++c) + { + /* A shift of more than the bit depth is an error condition but it + * gets ignored here. + */ + if (shift[c] <= 0 || shift[c] >= bit_depth) + shift[c] = 0; + + else + have_shift = 1; + } + + if (!have_shift) + return; + } + + switch (bit_depth) + { + default: + /* Must be 1bpp gray: should not be here! */ + /* NOTREACHED */ + break; + + case 2: + /* Must be 2bpp gray */ + /* assert(channels == 1 && shift[0] == 1) */ + { + png_bytep bp = row; + png_bytep bp_end = bp + row_info->rowbytes; + + while (bp < bp_end) + { + int b = (*bp >> 1) & 0x55; + *bp++ = (png_byte)b; + } + break; + } + + case 4: + /* Must be 4bpp gray */ + /* assert(channels == 1) */ + { + png_bytep bp = row; + png_bytep bp_end = bp + row_info->rowbytes; + int gray_shift = shift[0]; + int mask = 0xf >> gray_shift; + + mask |= mask << 4; + + while (bp < bp_end) + { + int b = (*bp >> gray_shift) & mask; + *bp++ = (png_byte)b; + } + break; + } + + case 8: + /* Single byte components, G, GA, RGB, RGBA */ + { + png_bytep bp = row; + png_bytep bp_end = bp + row_info->rowbytes; + int channel = 0; + + while (bp < bp_end) + { + int b = *bp >> shift[channel]; + if (++channel >= channels) + channel = 0; + *bp++ = (png_byte)b; + } + break; + } + +#ifdef PNG_READ_16BIT_SUPPORTED + case 16: + /* Double byte components, G, GA, RGB, RGBA */ + { + png_bytep bp = row; + png_bytep bp_end = bp + row_info->rowbytes; + int channel = 0; + + while (bp < bp_end) + { + int value = (bp[0] << 8) + bp[1]; + + value >>= shift[channel]; + if (++channel >= channels) + channel = 0; + *bp++ = (png_byte)(value >> 8); + *bp++ = (png_byte)(value & 0xff); + } + break; + } +#endif + } + } +} +#endif + +#ifdef PNG_READ_SCALE_16_TO_8_SUPPORTED +/* Scale rows of bit depth 16 down to 8 accurately */ +void /* PRIVATE */ +png_do_scale_16_to_8(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_scale_16_to_8"); + + if (row_info->bit_depth == 16) + { + png_bytep sp = row; /* source */ + png_bytep dp = row; /* destination */ + png_bytep ep = sp + row_info->rowbytes; /* end+1 */ + + while (sp < ep) + { + /* The input is an array of 16 bit components, these must be scaled to + * 8 bits each. For a 16 bit value V the required value (from the PNG + * specification) is: + * + * (V * 255) / 65535 + * + * This reduces to round(V / 257), or floor((V + 128.5)/257) + * + * Represent V as the two byte value vhi.vlo. Make a guess that the + * result is the top byte of V, vhi, then the correction to this value + * is: + * + * error = floor(((V-vhi.vhi) + 128.5) / 257) + * = floor(((vlo-vhi) + 128.5) / 257) + * + * This can be approximated using integer arithmetic (and a signed + * shift): + * + * error = (vlo-vhi+128) >> 8; + * + * The approximate differs from the exact answer only when (vlo-vhi) is + * 128; it then gives a correction of +1 when the exact correction is + * 0. This gives 128 errors. The exact answer (correct for all 16 bit + * input values) is: + * + * error = (vlo-vhi+128)*65535 >> 24; + * + * An alternative arithmetic calculation which also gives no errors is: + * + * (V * 255 + 32895) >> 16 + */ + + png_int_32 tmp = *sp++; /* must be signed! */ + tmp += (((int)*sp++ - tmp + 128) * 65535) >> 24; + *dp++ = (png_byte)tmp; + } + + row_info->bit_depth = 8; + row_info->pixel_depth = (png_byte)(8 * row_info->channels); + row_info->rowbytes = row_info->width * row_info->channels; + } +} +#endif + +#ifdef PNG_READ_STRIP_16_TO_8_SUPPORTED +void /* PRIVATE */ +/* Simply discard the low byte. This was the default behavior prior + * to libpng-1.5.4. + */ +png_do_chop(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_chop"); + + if (row_info->bit_depth == 16) + { + png_bytep sp = row; /* source */ + png_bytep dp = row; /* destination */ + png_bytep ep = sp + row_info->rowbytes; /* end+1 */ + + while (sp < ep) + { + *dp++ = *sp; + sp += 2; /* skip low byte */ + } + + row_info->bit_depth = 8; + row_info->pixel_depth = (png_byte)(8 * row_info->channels); + row_info->rowbytes = row_info->width * row_info->channels; + } +} +#endif + +#ifdef PNG_READ_SWAP_ALPHA_SUPPORTED +void /* PRIVATE */ +png_do_read_swap_alpha(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_read_swap_alpha"); + + { + png_uint_32 row_width = row_info->width; + if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + { + /* This converts from RGBA to ARGB */ + if (row_info->bit_depth == 8) + { + png_bytep sp = row + row_info->rowbytes; + png_bytep dp = sp; + png_byte save; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + save = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = save; + } + } + +#ifdef PNG_READ_16BIT_SUPPORTED + /* This converts from RRGGBBAA to AARRGGBB */ + else + { + png_bytep sp = row + row_info->rowbytes; + png_bytep dp = sp; + png_byte save[2]; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + save[0] = *(--sp); + save[1] = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = save[0]; + *(--dp) = save[1]; + } + } +#endif + } + + else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA) + { + /* This converts from GA to AG */ + if (row_info->bit_depth == 8) + { + png_bytep sp = row + row_info->rowbytes; + png_bytep dp = sp; + png_byte save; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + save = *(--sp); + *(--dp) = *(--sp); + *(--dp) = save; + } + } + +#ifdef PNG_READ_16BIT_SUPPORTED + /* This converts from GGAA to AAGG */ + else + { + png_bytep sp = row + row_info->rowbytes; + png_bytep dp = sp; + png_byte save[2]; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + save[0] = *(--sp); + save[1] = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = save[0]; + *(--dp) = save[1]; + } + } +#endif + } + } +} +#endif + +#ifdef PNG_READ_INVERT_ALPHA_SUPPORTED +void /* PRIVATE */ +png_do_read_invert_alpha(png_row_infop row_info, png_bytep row) +{ + png_uint_32 row_width; + png_debug(1, "in png_do_read_invert_alpha"); + + row_width = row_info->width; + if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + { + if (row_info->bit_depth == 8) + { + /* This inverts the alpha channel in RGBA */ + png_bytep sp = row + row_info->rowbytes; + png_bytep dp = sp; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + *(--dp) = (png_byte)(255 - *(--sp)); + +/* This does nothing: + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + We can replace it with: +*/ + sp-=3; + dp=sp; + } + } + +#ifdef PNG_READ_16BIT_SUPPORTED + /* This inverts the alpha channel in RRGGBBAA */ + else + { + png_bytep sp = row + row_info->rowbytes; + png_bytep dp = sp; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + *(--dp) = (png_byte)(255 - *(--sp)); + *(--dp) = (png_byte)(255 - *(--sp)); + +/* This does nothing: + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + We can replace it with: +*/ + sp-=6; + dp=sp; + } + } +#endif + } + else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA) + { + if (row_info->bit_depth == 8) + { + /* This inverts the alpha channel in GA */ + png_bytep sp = row + row_info->rowbytes; + png_bytep dp = sp; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + *(--dp) = (png_byte)(255 - *(--sp)); + *(--dp) = *(--sp); + } + } + +#ifdef PNG_READ_16BIT_SUPPORTED + else + { + /* This inverts the alpha channel in GGAA */ + png_bytep sp = row + row_info->rowbytes; + png_bytep dp = sp; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + *(--dp) = (png_byte)(255 - *(--sp)); + *(--dp) = (png_byte)(255 - *(--sp)); +/* + *(--dp) = *(--sp); + *(--dp) = *(--sp); +*/ + sp-=2; + dp=sp; + } + } +#endif + } +} +#endif + +#ifdef PNG_READ_FILLER_SUPPORTED +/* Add filler channel if we have RGB color */ +void /* PRIVATE */ +png_do_read_filler(png_row_infop row_info, png_bytep row, + png_uint_32 filler, png_uint_32 flags) +{ + png_uint_32 i; + png_uint_32 row_width = row_info->width; + +#ifdef PNG_READ_16BIT_SUPPORTED + png_byte hi_filler = (png_byte)((filler>>8) & 0xff); +#endif + png_byte lo_filler = (png_byte)(filler & 0xff); + + png_debug(1, "in png_do_read_filler"); + + if ( + row_info->color_type == PNG_COLOR_TYPE_GRAY) + { + if (row_info->bit_depth == 8) + { + if (flags & PNG_FLAG_FILLER_AFTER) + { + /* This changes the data from G to GX */ + png_bytep sp = row + (png_size_t)row_width; + png_bytep dp = sp + (png_size_t)row_width; + for (i = 1; i < row_width; i++) + { + *(--dp) = lo_filler; + *(--dp) = *(--sp); + } + *(--dp) = lo_filler; + row_info->channels = 2; + row_info->pixel_depth = 16; + row_info->rowbytes = row_width * 2; + } + + else + { + /* This changes the data from G to XG */ + png_bytep sp = row + (png_size_t)row_width; + png_bytep dp = sp + (png_size_t)row_width; + for (i = 0; i < row_width; i++) + { + *(--dp) = *(--sp); + *(--dp) = lo_filler; + } + row_info->channels = 2; + row_info->pixel_depth = 16; + row_info->rowbytes = row_width * 2; + } + } + +#ifdef PNG_READ_16BIT_SUPPORTED + else if (row_info->bit_depth == 16) + { + if (flags & PNG_FLAG_FILLER_AFTER) + { + /* This changes the data from GG to GGXX */ + png_bytep sp = row + (png_size_t)row_width * 2; + png_bytep dp = sp + (png_size_t)row_width * 2; + for (i = 1; i < row_width; i++) + { + *(--dp) = hi_filler; + *(--dp) = lo_filler; + *(--dp) = *(--sp); + *(--dp) = *(--sp); + } + *(--dp) = hi_filler; + *(--dp) = lo_filler; + row_info->channels = 2; + row_info->pixel_depth = 32; + row_info->rowbytes = row_width * 4; + } + + else + { + /* This changes the data from GG to XXGG */ + png_bytep sp = row + (png_size_t)row_width * 2; + png_bytep dp = sp + (png_size_t)row_width * 2; + for (i = 0; i < row_width; i++) + { + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = hi_filler; + *(--dp) = lo_filler; + } + row_info->channels = 2; + row_info->pixel_depth = 32; + row_info->rowbytes = row_width * 4; + } + } +#endif + } /* COLOR_TYPE == GRAY */ + else if (row_info->color_type == PNG_COLOR_TYPE_RGB) + { + if (row_info->bit_depth == 8) + { + if (flags & PNG_FLAG_FILLER_AFTER) + { + /* This changes the data from RGB to RGBX */ + png_bytep sp = row + (png_size_t)row_width * 3; + png_bytep dp = sp + (png_size_t)row_width; + for (i = 1; i < row_width; i++) + { + *(--dp) = lo_filler; + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + } + *(--dp) = lo_filler; + row_info->channels = 4; + row_info->pixel_depth = 32; + row_info->rowbytes = row_width * 4; + } + + else + { + /* This changes the data from RGB to XRGB */ + png_bytep sp = row + (png_size_t)row_width * 3; + png_bytep dp = sp + (png_size_t)row_width; + for (i = 0; i < row_width; i++) + { + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = lo_filler; + } + row_info->channels = 4; + row_info->pixel_depth = 32; + row_info->rowbytes = row_width * 4; + } + } + +#ifdef PNG_READ_16BIT_SUPPORTED + else if (row_info->bit_depth == 16) + { + if (flags & PNG_FLAG_FILLER_AFTER) + { + /* This changes the data from RRGGBB to RRGGBBXX */ + png_bytep sp = row + (png_size_t)row_width * 6; + png_bytep dp = sp + (png_size_t)row_width * 2; + for (i = 1; i < row_width; i++) + { + *(--dp) = hi_filler; + *(--dp) = lo_filler; + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + } + *(--dp) = hi_filler; + *(--dp) = lo_filler; + row_info->channels = 4; + row_info->pixel_depth = 64; + row_info->rowbytes = row_width * 8; + } + + else + { + /* This changes the data from RRGGBB to XXRRGGBB */ + png_bytep sp = row + (png_size_t)row_width * 6; + png_bytep dp = sp + (png_size_t)row_width * 2; + for (i = 0; i < row_width; i++) + { + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = *(--sp); + *(--dp) = hi_filler; + *(--dp) = lo_filler; + } + + row_info->channels = 4; + row_info->pixel_depth = 64; + row_info->rowbytes = row_width * 8; + } + } +#endif + } /* COLOR_TYPE == RGB */ +} +#endif + +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED +/* Expand grayscale files to RGB, with or without alpha */ +void /* PRIVATE */ +png_do_gray_to_rgb(png_row_infop row_info, png_bytep row) +{ + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + png_debug(1, "in png_do_gray_to_rgb"); + + if (row_info->bit_depth >= 8 && + !(row_info->color_type & PNG_COLOR_MASK_COLOR)) + { + if (row_info->color_type == PNG_COLOR_TYPE_GRAY) + { + if (row_info->bit_depth == 8) + { + /* This changes G to RGB */ + png_bytep sp = row + (png_size_t)row_width - 1; + png_bytep dp = sp + (png_size_t)row_width * 2; + for (i = 0; i < row_width; i++) + { + *(dp--) = *sp; + *(dp--) = *sp; + *(dp--) = *(sp--); + } + } + + else + { + /* This changes GG to RRGGBB */ + png_bytep sp = row + (png_size_t)row_width * 2 - 1; + png_bytep dp = sp + (png_size_t)row_width * 4; + for (i = 0; i < row_width; i++) + { + *(dp--) = *sp; + *(dp--) = *(sp - 1); + *(dp--) = *sp; + *(dp--) = *(sp - 1); + *(dp--) = *(sp--); + *(dp--) = *(sp--); + } + } + } + + else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA) + { + if (row_info->bit_depth == 8) + { + /* This changes GA to RGBA */ + png_bytep sp = row + (png_size_t)row_width * 2 - 1; + png_bytep dp = sp + (png_size_t)row_width * 2; + for (i = 0; i < row_width; i++) + { + *(dp--) = *(sp--); + *(dp--) = *sp; + *(dp--) = *sp; + *(dp--) = *(sp--); + } + } + + else + { + /* This changes GGAA to RRGGBBAA */ + png_bytep sp = row + (png_size_t)row_width * 4 - 1; + png_bytep dp = sp + (png_size_t)row_width * 4; + for (i = 0; i < row_width; i++) + { + *(dp--) = *(sp--); + *(dp--) = *(sp--); + *(dp--) = *sp; + *(dp--) = *(sp - 1); + *(dp--) = *sp; + *(dp--) = *(sp - 1); + *(dp--) = *(sp--); + *(dp--) = *(sp--); + } + } + } + row_info->channels = (png_byte)(row_info->channels + 2); + row_info->color_type |= PNG_COLOR_MASK_COLOR; + row_info->pixel_depth = (png_byte)(row_info->channels * + row_info->bit_depth); + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width); + } +} +#endif + +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED +/* Reduce RGB files to grayscale, with or without alpha + * using the equation given in Poynton's ColorFAQ of 1998-01-04 at + * (THIS LINK IS DEAD June 2008 but + * versions dated 1998 through November 2002 have been archived at + * http://web.archive.org/web/20000816232553/http://www.inforamp.net/ + * ~poynton/notes/colour_and_gamma/ColorFAQ.txt ) + * Charles Poynton poynton at poynton.com + * + * Y = 0.212671 * R + 0.715160 * G + 0.072169 * B + * + * which can be expressed with integers as + * + * Y = (6969 * R + 23434 * G + 2365 * B)/32768 + * + * Poynton's current link (as of January 2003 through July 2011): + * + * has changed the numbers slightly: + * + * Y = 0.2126*R + 0.7152*G + 0.0722*B + * + * which can be expressed with integers as + * + * Y = (6966 * R + 23436 * G + 2366 * B)/32768 + * + * Historically, however, libpng uses numbers derived from the ITU-R Rec 709 + * end point chromaticities and the D65 white point. Depending on the + * precision used for the D65 white point this produces a variety of different + * numbers, however if the four decimal place value used in ITU-R Rec 709 is + * used (0.3127,0.3290) the Y calculation would be: + * + * Y = (6968 * R + 23435 * G + 2366 * B)/32768 + * + * While this is correct the rounding results in an overflow for white, because + * the sum of the rounded coefficients is 32769, not 32768. Consequently + * libpng uses, instead, the closest non-overflowing approximation: + * + * Y = (6968 * R + 23434 * G + 2366 * B)/32768 + * + * Starting with libpng-1.5.5, if the image being converted has a cHRM chunk + * (including an sRGB chunk) then the chromaticities are used to calculate the + * coefficients. See the chunk handling in pngrutil.c for more information. + * + * In all cases the calculation is to be done in a linear colorspace. If no + * gamma information is available to correct the encoding of the original RGB + * values this results in an implicit assumption that the original PNG RGB + * values were linear. + * + * Other integer coefficents can be used via png_set_rgb_to_gray(). Because + * the API takes just red and green coefficients the blue coefficient is + * calculated to make the sum 32768. This will result in different rounding + * to that used above. + */ +int /* PRIVATE */ +png_do_rgb_to_gray(png_structrp png_ptr, png_row_infop row_info, png_bytep row) + +{ + int rgb_error = 0; + + png_debug(1, "in png_do_rgb_to_gray"); + + if (!(row_info->color_type & PNG_COLOR_MASK_PALETTE) && + (row_info->color_type & PNG_COLOR_MASK_COLOR)) + { + PNG_CONST png_uint_32 rc = png_ptr->rgb_to_gray_red_coeff; + PNG_CONST png_uint_32 gc = png_ptr->rgb_to_gray_green_coeff; + PNG_CONST png_uint_32 bc = 32768 - rc - gc; + PNG_CONST png_uint_32 row_width = row_info->width; + PNG_CONST int have_alpha = + (row_info->color_type & PNG_COLOR_MASK_ALPHA) != 0; + + if (row_info->bit_depth == 8) + { +#ifdef PNG_READ_GAMMA_SUPPORTED + /* Notice that gamma to/from 1 are not necessarily inverses (if + * there is an overall gamma correction). Prior to 1.5.5 this code + * checked the linearized values for equality; this doesn't match + * the documentation, the original values must be checked. + */ + if (png_ptr->gamma_from_1 != NULL && png_ptr->gamma_to_1 != NULL) + { + png_bytep sp = row; + png_bytep dp = row; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + png_byte red = *(sp++); + png_byte green = *(sp++); + png_byte blue = *(sp++); + + if (red != green || red != blue) + { + red = png_ptr->gamma_to_1[red]; + green = png_ptr->gamma_to_1[green]; + blue = png_ptr->gamma_to_1[blue]; + + rgb_error |= 1; + *(dp++) = png_ptr->gamma_from_1[ + (rc*red + gc*green + bc*blue + 16384)>>15]; + } + + else + { + /* If there is no overall correction the table will not be + * set. + */ + if (png_ptr->gamma_table != NULL) + red = png_ptr->gamma_table[red]; + + *(dp++) = red; + } + + if (have_alpha) + *(dp++) = *(sp++); + } + } + else +#endif + { + png_bytep sp = row; + png_bytep dp = row; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + png_byte red = *(sp++); + png_byte green = *(sp++); + png_byte blue = *(sp++); + + if (red != green || red != blue) + { + rgb_error |= 1; + /* NOTE: this is the historical approach which simply + * truncates the results. + */ + *(dp++) = (png_byte)((rc*red + gc*green + bc*blue)>>15); + } + + else + *(dp++) = red; + + if (have_alpha) + *(dp++) = *(sp++); + } + } + } + + else /* RGB bit_depth == 16 */ + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (png_ptr->gamma_16_to_1 != NULL && png_ptr->gamma_16_from_1 != NULL) + { + png_bytep sp = row; + png_bytep dp = row; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + png_uint_16 red, green, blue, w; + + red = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2; + green = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2; + blue = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2; + + if (red == green && red == blue) + { + if (png_ptr->gamma_16_table != NULL) + w = png_ptr->gamma_16_table[(red&0xff) + >> png_ptr->gamma_shift][red>>8]; + + else + w = red; + } + + else + { + png_uint_16 red_1 = png_ptr->gamma_16_to_1[(red&0xff) + >> png_ptr->gamma_shift][red>>8]; + png_uint_16 green_1 = + png_ptr->gamma_16_to_1[(green&0xff) >> + png_ptr->gamma_shift][green>>8]; + png_uint_16 blue_1 = png_ptr->gamma_16_to_1[(blue&0xff) + >> png_ptr->gamma_shift][blue>>8]; + png_uint_16 gray16 = (png_uint_16)((rc*red_1 + gc*green_1 + + bc*blue_1 + 16384)>>15); + w = png_ptr->gamma_16_from_1[(gray16&0xff) >> + png_ptr->gamma_shift][gray16 >> 8]; + rgb_error |= 1; + } + + *(dp++) = (png_byte)((w>>8) & 0xff); + *(dp++) = (png_byte)(w & 0xff); + + if (have_alpha) + { + *(dp++) = *(sp++); + *(dp++) = *(sp++); + } + } + } + else +#endif + { + png_bytep sp = row; + png_bytep dp = row; + png_uint_32 i; + + for (i = 0; i < row_width; i++) + { + png_uint_16 red, green, blue, gray16; + + red = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2; + green = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2; + blue = (png_uint_16)(((*(sp))<<8) | *(sp + 1)); sp += 2; + + if (red != green || red != blue) + rgb_error |= 1; + + /* From 1.5.5 in the 16 bit case do the accurate conversion even + * in the 'fast' case - this is because this is where the code + * ends up when handling linear 16 bit data. + */ + gray16 = (png_uint_16)((rc*red + gc*green + bc*blue + 16384) >> + 15); + *(dp++) = (png_byte)((gray16>>8) & 0xff); + *(dp++) = (png_byte)(gray16 & 0xff); + + if (have_alpha) + { + *(dp++) = *(sp++); + *(dp++) = *(sp++); + } + } + } + } + + row_info->channels = (png_byte)(row_info->channels - 2); + row_info->color_type = (png_byte)(row_info->color_type & + ~PNG_COLOR_MASK_COLOR); + row_info->pixel_depth = (png_byte)(row_info->channels * + row_info->bit_depth); + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width); + } + return rgb_error; +} +#endif +#endif /* PNG_READ_TRANSFORMS_SUPPORTED */ + +#ifdef PNG_BUILD_GRAYSCALE_PALETTE_SUPPORTED +/* Build a grayscale palette. Palette is assumed to be 1 << bit_depth + * large of png_color. This lets grayscale images be treated as + * paletted. Most useful for gamma correction and simplification + * of code. This API is not used internally. + */ +void PNGAPI +png_build_grayscale_palette(int bit_depth, png_colorp palette) +{ + int num_palette; + int color_inc; + int i; + int v; + + png_debug(1, "in png_do_build_grayscale_palette"); + + if (palette == NULL) + return; + + switch (bit_depth) + { + case 1: + num_palette = 2; + color_inc = 0xff; + break; + + case 2: + num_palette = 4; + color_inc = 0x55; + break; + + case 4: + num_palette = 16; + color_inc = 0x11; + break; + + case 8: + num_palette = 256; + color_inc = 1; + break; + + default: + num_palette = 0; + color_inc = 0; + break; + } + + for (i = 0, v = 0; i < num_palette; i++, v += color_inc) + { + palette[i].red = (png_byte)v; + palette[i].green = (png_byte)v; + palette[i].blue = (png_byte)v; + } +} +#endif + + +#ifdef PNG_READ_TRANSFORMS_SUPPORTED +#if defined(PNG_READ_BACKGROUND_SUPPORTED) ||\ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) +/* Replace any alpha or transparency with the supplied background color. + * "background" is already in the screen gamma, while "background_1" is + * at a gamma of 1.0. Paletted files have already been taken care of. + */ +void /* PRIVATE */ +png_do_compose(png_row_infop row_info, png_bytep row, png_structrp png_ptr) +{ +#ifdef PNG_READ_GAMMA_SUPPORTED + png_const_bytep gamma_table = png_ptr->gamma_table; + png_const_bytep gamma_from_1 = png_ptr->gamma_from_1; + png_const_bytep gamma_to_1 = png_ptr->gamma_to_1; + png_const_uint_16pp gamma_16 = png_ptr->gamma_16_table; + png_const_uint_16pp gamma_16_from_1 = png_ptr->gamma_16_from_1; + png_const_uint_16pp gamma_16_to_1 = png_ptr->gamma_16_to_1; + int gamma_shift = png_ptr->gamma_shift; + int optimize = (png_ptr->flags & PNG_FLAG_OPTIMIZE_ALPHA) != 0; +#endif + + png_bytep sp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + int shift; + + png_debug(1, "in png_do_compose"); + + { + switch (row_info->color_type) + { + case PNG_COLOR_TYPE_GRAY: + { + switch (row_info->bit_depth) + { + case 1: + { + sp = row; + shift = 7; + for (i = 0; i < row_width; i++) + { + if ((png_uint_16)((*sp >> shift) & 0x01) + == png_ptr->trans_color.gray) + { + unsigned int tmp = *sp & (0x7f7f >> (7 - shift)); + tmp |= png_ptr->background.gray << shift; + *sp = (png_byte)(tmp & 0xff); + } + + if (!shift) + { + shift = 7; + sp++; + } + + else + shift--; + } + break; + } + + case 2: + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_table != NULL) + { + sp = row; + shift = 6; + for (i = 0; i < row_width; i++) + { + if ((png_uint_16)((*sp >> shift) & 0x03) + == png_ptr->trans_color.gray) + { + unsigned int tmp = *sp & (0x3f3f >> (6 - shift)); + tmp |= png_ptr->background.gray << shift; + *sp = (png_byte)(tmp & 0xff); + } + + else + { + unsigned int p = (*sp >> shift) & 0x03; + unsigned int g = (gamma_table [p | (p << 2) | + (p << 4) | (p << 6)] >> 6) & 0x03; + unsigned int tmp = *sp & (0x3f3f >> (6 - shift)); + tmp |= g << shift; + *sp = (png_byte)(tmp & 0xff); + } + + if (!shift) + { + shift = 6; + sp++; + } + + else + shift -= 2; + } + } + + else +#endif + { + sp = row; + shift = 6; + for (i = 0; i < row_width; i++) + { + if ((png_uint_16)((*sp >> shift) & 0x03) + == png_ptr->trans_color.gray) + { + unsigned int tmp = *sp & (0x3f3f >> (6 - shift)); + tmp |= png_ptr->background.gray << shift; + *sp = (png_byte)(tmp & 0xff); + } + + if (!shift) + { + shift = 6; + sp++; + } + + else + shift -= 2; + } + } + break; + } + + case 4: + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_table != NULL) + { + sp = row; + shift = 4; + for (i = 0; i < row_width; i++) + { + if ((png_uint_16)((*sp >> shift) & 0x0f) + == png_ptr->trans_color.gray) + { + unsigned int tmp = *sp & (0xf0f >> (4 - shift)); + tmp |= png_ptr->background.gray << shift; + *sp = (png_byte)(tmp & 0xff); + } + + else + { + unsigned int p = (*sp >> shift) & 0x0f; + unsigned int g = (gamma_table[p | (p << 4)] >> 4) & + 0x0f; + unsigned int tmp = *sp & (0xf0f >> (4 - shift)); + tmp |= g << shift; + *sp = (png_byte)(tmp & 0xff); + } + + if (!shift) + { + shift = 4; + sp++; + } + + else + shift -= 4; + } + } + + else +#endif + { + sp = row; + shift = 4; + for (i = 0; i < row_width; i++) + { + if ((png_uint_16)((*sp >> shift) & 0x0f) + == png_ptr->trans_color.gray) + { + unsigned int tmp = *sp & (0xf0f >> (4 - shift)); + tmp |= png_ptr->background.gray << shift; + *sp = (png_byte)(tmp & 0xff); + } + + if (!shift) + { + shift = 4; + sp++; + } + + else + shift -= 4; + } + } + break; + } + + case 8: + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_table != NULL) + { + sp = row; + for (i = 0; i < row_width; i++, sp++) + { + if (*sp == png_ptr->trans_color.gray) + *sp = (png_byte)png_ptr->background.gray; + + else + *sp = gamma_table[*sp]; + } + } + else +#endif + { + sp = row; + for (i = 0; i < row_width; i++, sp++) + { + if (*sp == png_ptr->trans_color.gray) + *sp = (png_byte)png_ptr->background.gray; + } + } + break; + } + + case 16: + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_16 != NULL) + { + sp = row; + for (i = 0; i < row_width; i++, sp += 2) + { + png_uint_16 v; + + v = (png_uint_16)(((*sp) << 8) + *(sp + 1)); + + if (v == png_ptr->trans_color.gray) + { + /* Background is already in screen gamma */ + *sp = (png_byte)((png_ptr->background.gray >> 8) + & 0xff); + *(sp + 1) = (png_byte)(png_ptr->background.gray + & 0xff); + } + + else + { + v = gamma_16[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + } + } + } + else +#endif + { + sp = row; + for (i = 0; i < row_width; i++, sp += 2) + { + png_uint_16 v; + + v = (png_uint_16)(((*sp) << 8) + *(sp + 1)); + + if (v == png_ptr->trans_color.gray) + { + *sp = (png_byte)((png_ptr->background.gray >> 8) + & 0xff); + *(sp + 1) = (png_byte)(png_ptr->background.gray + & 0xff); + } + } + } + break; + } + + default: + break; + } + break; + } + + case PNG_COLOR_TYPE_RGB: + { + if (row_info->bit_depth == 8) + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_table != NULL) + { + sp = row; + for (i = 0; i < row_width; i++, sp += 3) + { + if (*sp == png_ptr->trans_color.red && + *(sp + 1) == png_ptr->trans_color.green && + *(sp + 2) == png_ptr->trans_color.blue) + { + *sp = (png_byte)png_ptr->background.red; + *(sp + 1) = (png_byte)png_ptr->background.green; + *(sp + 2) = (png_byte)png_ptr->background.blue; + } + + else + { + *sp = gamma_table[*sp]; + *(sp + 1) = gamma_table[*(sp + 1)]; + *(sp + 2) = gamma_table[*(sp + 2)]; + } + } + } + else +#endif + { + sp = row; + for (i = 0; i < row_width; i++, sp += 3) + { + if (*sp == png_ptr->trans_color.red && + *(sp + 1) == png_ptr->trans_color.green && + *(sp + 2) == png_ptr->trans_color.blue) + { + *sp = (png_byte)png_ptr->background.red; + *(sp + 1) = (png_byte)png_ptr->background.green; + *(sp + 2) = (png_byte)png_ptr->background.blue; + } + } + } + } + else /* if (row_info->bit_depth == 16) */ + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_16 != NULL) + { + sp = row; + for (i = 0; i < row_width; i++, sp += 6) + { + png_uint_16 r = (png_uint_16)(((*sp) << 8) + *(sp + 1)); + + png_uint_16 g = (png_uint_16)(((*(sp + 2)) << 8) + + *(sp + 3)); + + png_uint_16 b = (png_uint_16)(((*(sp + 4)) << 8) + + *(sp + 5)); + + if (r == png_ptr->trans_color.red && + g == png_ptr->trans_color.green && + b == png_ptr->trans_color.blue) + { + /* Background is already in screen gamma */ + *sp = (png_byte)((png_ptr->background.red >> 8) & 0xff); + *(sp + 1) = (png_byte)(png_ptr->background.red & 0xff); + *(sp + 2) = (png_byte)((png_ptr->background.green >> 8) + & 0xff); + *(sp + 3) = (png_byte)(png_ptr->background.green + & 0xff); + *(sp + 4) = (png_byte)((png_ptr->background.blue >> 8) + & 0xff); + *(sp + 5) = (png_byte)(png_ptr->background.blue & 0xff); + } + + else + { + png_uint_16 v = gamma_16[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + + v = gamma_16[*(sp + 3) >> gamma_shift][*(sp + 2)]; + *(sp + 2) = (png_byte)((v >> 8) & 0xff); + *(sp + 3) = (png_byte)(v & 0xff); + + v = gamma_16[*(sp + 5) >> gamma_shift][*(sp + 4)]; + *(sp + 4) = (png_byte)((v >> 8) & 0xff); + *(sp + 5) = (png_byte)(v & 0xff); + } + } + } + + else +#endif + { + sp = row; + for (i = 0; i < row_width; i++, sp += 6) + { + png_uint_16 r = (png_uint_16)(((*sp) << 8) + *(sp + 1)); + + png_uint_16 g = (png_uint_16)(((*(sp + 2)) << 8) + + *(sp + 3)); + + png_uint_16 b = (png_uint_16)(((*(sp + 4)) << 8) + + *(sp + 5)); + + if (r == png_ptr->trans_color.red && + g == png_ptr->trans_color.green && + b == png_ptr->trans_color.blue) + { + *sp = (png_byte)((png_ptr->background.red >> 8) & 0xff); + *(sp + 1) = (png_byte)(png_ptr->background.red & 0xff); + *(sp + 2) = (png_byte)((png_ptr->background.green >> 8) + & 0xff); + *(sp + 3) = (png_byte)(png_ptr->background.green + & 0xff); + *(sp + 4) = (png_byte)((png_ptr->background.blue >> 8) + & 0xff); + *(sp + 5) = (png_byte)(png_ptr->background.blue & 0xff); + } + } + } + } + break; + } + + case PNG_COLOR_TYPE_GRAY_ALPHA: + { + if (row_info->bit_depth == 8) + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_to_1 != NULL && gamma_from_1 != NULL && + gamma_table != NULL) + { + sp = row; + for (i = 0; i < row_width; i++, sp += 2) + { + png_uint_16 a = *(sp + 1); + + if (a == 0xff) + *sp = gamma_table[*sp]; + + else if (a == 0) + { + /* Background is already in screen gamma */ + *sp = (png_byte)png_ptr->background.gray; + } + + else + { + png_byte v, w; + + v = gamma_to_1[*sp]; + png_composite(w, v, a, png_ptr->background_1.gray); + if (!optimize) + w = gamma_from_1[w]; + *sp = w; + } + } + } + else +#endif + { + sp = row; + for (i = 0; i < row_width; i++, sp += 2) + { + png_byte a = *(sp + 1); + + if (a == 0) + *sp = (png_byte)png_ptr->background.gray; + + else if (a < 0xff) + png_composite(*sp, *sp, a, png_ptr->background.gray); + } + } + } + else /* if (png_ptr->bit_depth == 16) */ + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_16 != NULL && gamma_16_from_1 != NULL && + gamma_16_to_1 != NULL) + { + sp = row; + for (i = 0; i < row_width; i++, sp += 4) + { + png_uint_16 a = (png_uint_16)(((*(sp + 2)) << 8) + + *(sp + 3)); + + if (a == (png_uint_16)0xffff) + { + png_uint_16 v; + + v = gamma_16[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + } + + else if (a == 0) + { + /* Background is already in screen gamma */ + *sp = (png_byte)((png_ptr->background.gray >> 8) + & 0xff); + *(sp + 1) = (png_byte)(png_ptr->background.gray & 0xff); + } + + else + { + png_uint_16 g, v, w; + + g = gamma_16_to_1[*(sp + 1) >> gamma_shift][*sp]; + png_composite_16(v, g, a, png_ptr->background_1.gray); + if (optimize) + w = v; + else + w = gamma_16_from_1[(v&0xff) >> gamma_shift][v >> 8]; + *sp = (png_byte)((w >> 8) & 0xff); + *(sp + 1) = (png_byte)(w & 0xff); + } + } + } + else +#endif + { + sp = row; + for (i = 0; i < row_width; i++, sp += 4) + { + png_uint_16 a = (png_uint_16)(((*(sp + 2)) << 8) + + *(sp + 3)); + + if (a == 0) + { + *sp = (png_byte)((png_ptr->background.gray >> 8) + & 0xff); + *(sp + 1) = (png_byte)(png_ptr->background.gray & 0xff); + } + + else if (a < 0xffff) + { + png_uint_16 g, v; + + g = (png_uint_16)(((*sp) << 8) + *(sp + 1)); + png_composite_16(v, g, a, png_ptr->background.gray); + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + } + } + } + } + break; + } + + case PNG_COLOR_TYPE_RGB_ALPHA: + { + if (row_info->bit_depth == 8) + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_to_1 != NULL && gamma_from_1 != NULL && + gamma_table != NULL) + { + sp = row; + for (i = 0; i < row_width; i++, sp += 4) + { + png_byte a = *(sp + 3); + + if (a == 0xff) + { + *sp = gamma_table[*sp]; + *(sp + 1) = gamma_table[*(sp + 1)]; + *(sp + 2) = gamma_table[*(sp + 2)]; + } + + else if (a == 0) + { + /* Background is already in screen gamma */ + *sp = (png_byte)png_ptr->background.red; + *(sp + 1) = (png_byte)png_ptr->background.green; + *(sp + 2) = (png_byte)png_ptr->background.blue; + } + + else + { + png_byte v, w; + + v = gamma_to_1[*sp]; + png_composite(w, v, a, png_ptr->background_1.red); + if (!optimize) w = gamma_from_1[w]; + *sp = w; + + v = gamma_to_1[*(sp + 1)]; + png_composite(w, v, a, png_ptr->background_1.green); + if (!optimize) w = gamma_from_1[w]; + *(sp + 1) = w; + + v = gamma_to_1[*(sp + 2)]; + png_composite(w, v, a, png_ptr->background_1.blue); + if (!optimize) w = gamma_from_1[w]; + *(sp + 2) = w; + } + } + } + else +#endif + { + sp = row; + for (i = 0; i < row_width; i++, sp += 4) + { + png_byte a = *(sp + 3); + + if (a == 0) + { + *sp = (png_byte)png_ptr->background.red; + *(sp + 1) = (png_byte)png_ptr->background.green; + *(sp + 2) = (png_byte)png_ptr->background.blue; + } + + else if (a < 0xff) + { + png_composite(*sp, *sp, a, png_ptr->background.red); + + png_composite(*(sp + 1), *(sp + 1), a, + png_ptr->background.green); + + png_composite(*(sp + 2), *(sp + 2), a, + png_ptr->background.blue); + } + } + } + } + else /* if (row_info->bit_depth == 16) */ + { +#ifdef PNG_READ_GAMMA_SUPPORTED + if (gamma_16 != NULL && gamma_16_from_1 != NULL && + gamma_16_to_1 != NULL) + { + sp = row; + for (i = 0; i < row_width; i++, sp += 8) + { + png_uint_16 a = (png_uint_16)(((png_uint_16)(*(sp + 6)) + << 8) + (png_uint_16)(*(sp + 7))); + + if (a == (png_uint_16)0xffff) + { + png_uint_16 v; + + v = gamma_16[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + + v = gamma_16[*(sp + 3) >> gamma_shift][*(sp + 2)]; + *(sp + 2) = (png_byte)((v >> 8) & 0xff); + *(sp + 3) = (png_byte)(v & 0xff); + + v = gamma_16[*(sp + 5) >> gamma_shift][*(sp + 4)]; + *(sp + 4) = (png_byte)((v >> 8) & 0xff); + *(sp + 5) = (png_byte)(v & 0xff); + } + + else if (a == 0) + { + /* Background is already in screen gamma */ + *sp = (png_byte)((png_ptr->background.red >> 8) & 0xff); + *(sp + 1) = (png_byte)(png_ptr->background.red & 0xff); + *(sp + 2) = (png_byte)((png_ptr->background.green >> 8) + & 0xff); + *(sp + 3) = (png_byte)(png_ptr->background.green + & 0xff); + *(sp + 4) = (png_byte)((png_ptr->background.blue >> 8) + & 0xff); + *(sp + 5) = (png_byte)(png_ptr->background.blue & 0xff); + } + + else + { + png_uint_16 v, w; + + v = gamma_16_to_1[*(sp + 1) >> gamma_shift][*sp]; + png_composite_16(w, v, a, png_ptr->background_1.red); + if (!optimize) + w = gamma_16_from_1[((w&0xff) >> gamma_shift)][w >> + 8]; + *sp = (png_byte)((w >> 8) & 0xff); + *(sp + 1) = (png_byte)(w & 0xff); + + v = gamma_16_to_1[*(sp + 3) >> gamma_shift][*(sp + 2)]; + png_composite_16(w, v, a, png_ptr->background_1.green); + if (!optimize) + w = gamma_16_from_1[((w&0xff) >> gamma_shift)][w >> + 8]; + + *(sp + 2) = (png_byte)((w >> 8) & 0xff); + *(sp + 3) = (png_byte)(w & 0xff); + + v = gamma_16_to_1[*(sp + 5) >> gamma_shift][*(sp + 4)]; + png_composite_16(w, v, a, png_ptr->background_1.blue); + if (!optimize) + w = gamma_16_from_1[((w&0xff) >> gamma_shift)][w >> + 8]; + + *(sp + 4) = (png_byte)((w >> 8) & 0xff); + *(sp + 5) = (png_byte)(w & 0xff); + } + } + } + + else +#endif + { + sp = row; + for (i = 0; i < row_width; i++, sp += 8) + { + png_uint_16 a = (png_uint_16)(((png_uint_16)(*(sp + 6)) + << 8) + (png_uint_16)(*(sp + 7))); + + if (a == 0) + { + *sp = (png_byte)((png_ptr->background.red >> 8) & 0xff); + *(sp + 1) = (png_byte)(png_ptr->background.red & 0xff); + *(sp + 2) = (png_byte)((png_ptr->background.green >> 8) + & 0xff); + *(sp + 3) = (png_byte)(png_ptr->background.green + & 0xff); + *(sp + 4) = (png_byte)((png_ptr->background.blue >> 8) + & 0xff); + *(sp + 5) = (png_byte)(png_ptr->background.blue & 0xff); + } + + else if (a < 0xffff) + { + png_uint_16 v; + + png_uint_16 r = (png_uint_16)(((*sp) << 8) + *(sp + 1)); + png_uint_16 g = (png_uint_16)(((*(sp + 2)) << 8) + + *(sp + 3)); + png_uint_16 b = (png_uint_16)(((*(sp + 4)) << 8) + + *(sp + 5)); + + png_composite_16(v, r, a, png_ptr->background.red); + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + + png_composite_16(v, g, a, png_ptr->background.green); + *(sp + 2) = (png_byte)((v >> 8) & 0xff); + *(sp + 3) = (png_byte)(v & 0xff); + + png_composite_16(v, b, a, png_ptr->background.blue); + *(sp + 4) = (png_byte)((v >> 8) & 0xff); + *(sp + 5) = (png_byte)(v & 0xff); + } + } + } + } + break; + } + + default: + break; + } + } +} +#endif /* PNG_READ_BACKGROUND_SUPPORTED || PNG_READ_ALPHA_MODE_SUPPORTED */ + +#ifdef PNG_READ_GAMMA_SUPPORTED +/* Gamma correct the image, avoiding the alpha channel. Make sure + * you do this after you deal with the transparency issue on grayscale + * or RGB images. If your bit depth is 8, use gamma_table, if it + * is 16, use gamma_16_table and gamma_shift. Build these with + * build_gamma_table(). + */ +void /* PRIVATE */ +png_do_gamma(png_row_infop row_info, png_bytep row, png_structrp png_ptr) +{ + png_const_bytep gamma_table = png_ptr->gamma_table; + png_const_uint_16pp gamma_16_table = png_ptr->gamma_16_table; + int gamma_shift = png_ptr->gamma_shift; + + png_bytep sp; + png_uint_32 i; + png_uint_32 row_width=row_info->width; + + png_debug(1, "in png_do_gamma"); + + if (((row_info->bit_depth <= 8 && gamma_table != NULL) || + (row_info->bit_depth == 16 && gamma_16_table != NULL))) + { + switch (row_info->color_type) + { + case PNG_COLOR_TYPE_RGB: + { + if (row_info->bit_depth == 8) + { + sp = row; + for (i = 0; i < row_width; i++) + { + *sp = gamma_table[*sp]; + sp++; + *sp = gamma_table[*sp]; + sp++; + *sp = gamma_table[*sp]; + sp++; + } + } + + else /* if (row_info->bit_depth == 16) */ + { + sp = row; + for (i = 0; i < row_width; i++) + { + png_uint_16 v; + + v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + sp += 2; + + v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + sp += 2; + + v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + sp += 2; + } + } + break; + } + + case PNG_COLOR_TYPE_RGB_ALPHA: + { + if (row_info->bit_depth == 8) + { + sp = row; + for (i = 0; i < row_width; i++) + { + *sp = gamma_table[*sp]; + sp++; + + *sp = gamma_table[*sp]; + sp++; + + *sp = gamma_table[*sp]; + sp++; + + sp++; + } + } + + else /* if (row_info->bit_depth == 16) */ + { + sp = row; + for (i = 0; i < row_width; i++) + { + png_uint_16 v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + sp += 2; + + v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + sp += 2; + + v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + sp += 4; + } + } + break; + } + + case PNG_COLOR_TYPE_GRAY_ALPHA: + { + if (row_info->bit_depth == 8) + { + sp = row; + for (i = 0; i < row_width; i++) + { + *sp = gamma_table[*sp]; + sp += 2; + } + } + + else /* if (row_info->bit_depth == 16) */ + { + sp = row; + for (i = 0; i < row_width; i++) + { + png_uint_16 v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + sp += 4; + } + } + break; + } + + case PNG_COLOR_TYPE_GRAY: + { + if (row_info->bit_depth == 2) + { + sp = row; + for (i = 0; i < row_width; i += 4) + { + int a = *sp & 0xc0; + int b = *sp & 0x30; + int c = *sp & 0x0c; + int d = *sp & 0x03; + + *sp = (png_byte)( + ((((int)gamma_table[a|(a>>2)|(a>>4)|(a>>6)]) ) & 0xc0)| + ((((int)gamma_table[(b<<2)|b|(b>>2)|(b>>4)])>>2) & 0x30)| + ((((int)gamma_table[(c<<4)|(c<<2)|c|(c>>2)])>>4) & 0x0c)| + ((((int)gamma_table[(d<<6)|(d<<4)|(d<<2)|d])>>6) )); + sp++; + } + } + + if (row_info->bit_depth == 4) + { + sp = row; + for (i = 0; i < row_width; i += 2) + { + int msb = *sp & 0xf0; + int lsb = *sp & 0x0f; + + *sp = (png_byte)((((int)gamma_table[msb | (msb >> 4)]) & 0xf0) + | (((int)gamma_table[(lsb << 4) | lsb]) >> 4)); + sp++; + } + } + + else if (row_info->bit_depth == 8) + { + sp = row; + for (i = 0; i < row_width; i++) + { + *sp = gamma_table[*sp]; + sp++; + } + } + + else if (row_info->bit_depth == 16) + { + sp = row; + for (i = 0; i < row_width; i++) + { + png_uint_16 v = gamma_16_table[*(sp + 1) >> gamma_shift][*sp]; + *sp = (png_byte)((v >> 8) & 0xff); + *(sp + 1) = (png_byte)(v & 0xff); + sp += 2; + } + } + break; + } + + default: + break; + } + } +} +#endif + +#ifdef PNG_READ_ALPHA_MODE_SUPPORTED +/* Encode the alpha channel to the output gamma (the input channel is always + * linear.) Called only with color types that have an alpha channel. Needs the + * from_1 tables. + */ +void /* PRIVATE */ +png_do_encode_alpha(png_row_infop row_info, png_bytep row, png_structrp png_ptr) +{ + png_uint_32 row_width = row_info->width; + + png_debug(1, "in png_do_encode_alpha"); + + if (row_info->color_type & PNG_COLOR_MASK_ALPHA) + { + if (row_info->bit_depth == 8) + { + PNG_CONST png_bytep table = png_ptr->gamma_from_1; + + if (table != NULL) + { + PNG_CONST int step = + (row_info->color_type & PNG_COLOR_MASK_COLOR) ? 4 : 2; + + /* The alpha channel is the last component: */ + row += step - 1; + + for (; row_width > 0; --row_width, row += step) + *row = table[*row]; + + return; + } + } + + else if (row_info->bit_depth == 16) + { + PNG_CONST png_uint_16pp table = png_ptr->gamma_16_from_1; + PNG_CONST int gamma_shift = png_ptr->gamma_shift; + + if (table != NULL) + { + PNG_CONST int step = + (row_info->color_type & PNG_COLOR_MASK_COLOR) ? 8 : 4; + + /* The alpha channel is the last component: */ + row += step - 2; + + for (; row_width > 0; --row_width, row += step) + { + png_uint_16 v; + + v = table[*(row + 1) >> gamma_shift][*row]; + *row = (png_byte)((v >> 8) & 0xff); + *(row + 1) = (png_byte)(v & 0xff); + } + + return; + } + } + } + + /* Only get to here if called with a weird row_info; no harm has been done, + * so just issue a warning. + */ + png_warning(png_ptr, "png_do_encode_alpha: unexpected call"); +} +#endif + +#ifdef PNG_READ_EXPAND_SUPPORTED +/* Expands a palette row to an RGB or RGBA row depending + * upon whether you supply trans and num_trans. + */ +void /* PRIVATE */ +png_do_expand_palette(png_row_infop row_info, png_bytep row, + png_const_colorp palette, png_const_bytep trans_alpha, int num_trans) +{ + int shift, value; + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width=row_info->width; + + png_debug(1, "in png_do_expand_palette"); + + if (row_info->color_type == PNG_COLOR_TYPE_PALETTE) + { + if (row_info->bit_depth < 8) + { + switch (row_info->bit_depth) + { + case 1: + { + sp = row + (png_size_t)((row_width - 1) >> 3); + dp = row + (png_size_t)row_width - 1; + shift = 7 - (int)((row_width + 7) & 0x07); + for (i = 0; i < row_width; i++) + { + if ((*sp >> shift) & 0x01) + *dp = 1; + + else + *dp = 0; + + if (shift == 7) + { + shift = 0; + sp--; + } + + else + shift++; + + dp--; + } + break; + } + + case 2: + { + sp = row + (png_size_t)((row_width - 1) >> 2); + dp = row + (png_size_t)row_width - 1; + shift = (int)((3 - ((row_width + 3) & 0x03)) << 1); + for (i = 0; i < row_width; i++) + { + value = (*sp >> shift) & 0x03; + *dp = (png_byte)value; + if (shift == 6) + { + shift = 0; + sp--; + } + + else + shift += 2; + + dp--; + } + break; + } + + case 4: + { + sp = row + (png_size_t)((row_width - 1) >> 1); + dp = row + (png_size_t)row_width - 1; + shift = (int)((row_width & 0x01) << 2); + for (i = 0; i < row_width; i++) + { + value = (*sp >> shift) & 0x0f; + *dp = (png_byte)value; + if (shift == 4) + { + shift = 0; + sp--; + } + + else + shift += 4; + + dp--; + } + break; + } + + default: + break; + } + row_info->bit_depth = 8; + row_info->pixel_depth = 8; + row_info->rowbytes = row_width; + } + + if (row_info->bit_depth == 8) + { + { + if (num_trans > 0) + { + sp = row + (png_size_t)row_width - 1; + dp = row + (png_size_t)(row_width << 2) - 1; + + for (i = 0; i < row_width; i++) + { + if ((int)(*sp) >= num_trans) + *dp-- = 0xff; + + else + *dp-- = trans_alpha[*sp]; + + *dp-- = palette[*sp].blue; + *dp-- = palette[*sp].green; + *dp-- = palette[*sp].red; + sp--; + } + row_info->bit_depth = 8; + row_info->pixel_depth = 32; + row_info->rowbytes = row_width * 4; + row_info->color_type = 6; + row_info->channels = 4; + } + + else + { + sp = row + (png_size_t)row_width - 1; + dp = row + (png_size_t)(row_width * 3) - 1; + + for (i = 0; i < row_width; i++) + { + *dp-- = palette[*sp].blue; + *dp-- = palette[*sp].green; + *dp-- = palette[*sp].red; + sp--; + } + + row_info->bit_depth = 8; + row_info->pixel_depth = 24; + row_info->rowbytes = row_width * 3; + row_info->color_type = 2; + row_info->channels = 3; + } + } + } + } +} + +/* If the bit depth < 8, it is expanded to 8. Also, if the already + * expanded transparency value is supplied, an alpha channel is built. + */ +void /* PRIVATE */ +png_do_expand(png_row_infop row_info, png_bytep row, + png_const_color_16p trans_color) +{ + int shift, value; + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width=row_info->width; + + png_debug(1, "in png_do_expand"); + + { + if (row_info->color_type == PNG_COLOR_TYPE_GRAY) + { + unsigned int gray = trans_color ? trans_color->gray : 0; + + if (row_info->bit_depth < 8) + { + switch (row_info->bit_depth) + { + case 1: + { + gray = (gray & 0x01) * 0xff; + sp = row + (png_size_t)((row_width - 1) >> 3); + dp = row + (png_size_t)row_width - 1; + shift = 7 - (int)((row_width + 7) & 0x07); + for (i = 0; i < row_width; i++) + { + if ((*sp >> shift) & 0x01) + *dp = 0xff; + + else + *dp = 0; + + if (shift == 7) + { + shift = 0; + sp--; + } + + else + shift++; + + dp--; + } + break; + } + + case 2: + { + gray = (gray & 0x03) * 0x55; + sp = row + (png_size_t)((row_width - 1) >> 2); + dp = row + (png_size_t)row_width - 1; + shift = (int)((3 - ((row_width + 3) & 0x03)) << 1); + for (i = 0; i < row_width; i++) + { + value = (*sp >> shift) & 0x03; + *dp = (png_byte)(value | (value << 2) | (value << 4) | + (value << 6)); + if (shift == 6) + { + shift = 0; + sp--; + } + + else + shift += 2; + + dp--; + } + break; + } + + case 4: + { + gray = (gray & 0x0f) * 0x11; + sp = row + (png_size_t)((row_width - 1) >> 1); + dp = row + (png_size_t)row_width - 1; + shift = (int)((1 - ((row_width + 1) & 0x01)) << 2); + for (i = 0; i < row_width; i++) + { + value = (*sp >> shift) & 0x0f; + *dp = (png_byte)(value | (value << 4)); + if (shift == 4) + { + shift = 0; + sp--; + } + + else + shift = 4; + + dp--; + } + break; + } + + default: + break; + } + + row_info->bit_depth = 8; + row_info->pixel_depth = 8; + row_info->rowbytes = row_width; + } + + if (trans_color != NULL) + { + if (row_info->bit_depth == 8) + { + gray = gray & 0xff; + sp = row + (png_size_t)row_width - 1; + dp = row + (png_size_t)(row_width << 1) - 1; + + for (i = 0; i < row_width; i++) + { + if (*sp == gray) + *dp-- = 0; + + else + *dp-- = 0xff; + + *dp-- = *sp--; + } + } + + else if (row_info->bit_depth == 16) + { + unsigned int gray_high = (gray >> 8) & 0xff; + unsigned int gray_low = gray & 0xff; + sp = row + row_info->rowbytes - 1; + dp = row + (row_info->rowbytes << 1) - 1; + for (i = 0; i < row_width; i++) + { + if (*(sp - 1) == gray_high && *(sp) == gray_low) + { + *dp-- = 0; + *dp-- = 0; + } + + else + { + *dp-- = 0xff; + *dp-- = 0xff; + } + + *dp-- = *sp--; + *dp-- = *sp--; + } + } + + row_info->color_type = PNG_COLOR_TYPE_GRAY_ALPHA; + row_info->channels = 2; + row_info->pixel_depth = (png_byte)(row_info->bit_depth << 1); + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, + row_width); + } + } + else if (row_info->color_type == PNG_COLOR_TYPE_RGB && trans_color) + { + if (row_info->bit_depth == 8) + { + png_byte red = (png_byte)(trans_color->red & 0xff); + png_byte green = (png_byte)(trans_color->green & 0xff); + png_byte blue = (png_byte)(trans_color->blue & 0xff); + sp = row + (png_size_t)row_info->rowbytes - 1; + dp = row + (png_size_t)(row_width << 2) - 1; + for (i = 0; i < row_width; i++) + { + if (*(sp - 2) == red && *(sp - 1) == green && *(sp) == blue) + *dp-- = 0; + + else + *dp-- = 0xff; + + *dp-- = *sp--; + *dp-- = *sp--; + *dp-- = *sp--; + } + } + else if (row_info->bit_depth == 16) + { + png_byte red_high = (png_byte)((trans_color->red >> 8) & 0xff); + png_byte green_high = (png_byte)((trans_color->green >> 8) & 0xff); + png_byte blue_high = (png_byte)((trans_color->blue >> 8) & 0xff); + png_byte red_low = (png_byte)(trans_color->red & 0xff); + png_byte green_low = (png_byte)(trans_color->green & 0xff); + png_byte blue_low = (png_byte)(trans_color->blue & 0xff); + sp = row + row_info->rowbytes - 1; + dp = row + (png_size_t)(row_width << 3) - 1; + for (i = 0; i < row_width; i++) + { + if (*(sp - 5) == red_high && + *(sp - 4) == red_low && + *(sp - 3) == green_high && + *(sp - 2) == green_low && + *(sp - 1) == blue_high && + *(sp ) == blue_low) + { + *dp-- = 0; + *dp-- = 0; + } + + else + { + *dp-- = 0xff; + *dp-- = 0xff; + } + + *dp-- = *sp--; + *dp-- = *sp--; + *dp-- = *sp--; + *dp-- = *sp--; + *dp-- = *sp--; + *dp-- = *sp--; + } + } + row_info->color_type = PNG_COLOR_TYPE_RGB_ALPHA; + row_info->channels = 4; + row_info->pixel_depth = (png_byte)(row_info->bit_depth << 2); + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width); + } + } +} +#endif + +#ifdef PNG_READ_EXPAND_16_SUPPORTED +/* If the bit depth is 8 and the color type is not a palette type expand the + * whole row to 16 bits. Has no effect otherwise. + */ +void /* PRIVATE */ +png_do_expand_16(png_row_infop row_info, png_bytep row) +{ + if (row_info->bit_depth == 8 && + row_info->color_type != PNG_COLOR_TYPE_PALETTE) + { + /* The row have a sequence of bytes containing [0..255] and we need + * to turn it into another row containing [0..65535], to do this we + * calculate: + * + * (input / 255) * 65535 + * + * Which happens to be exactly input * 257 and this can be achieved + * simply by byte replication in place (copying backwards). + */ + png_byte *sp = row + row_info->rowbytes; /* source, last byte + 1 */ + png_byte *dp = sp + row_info->rowbytes; /* destination, end + 1 */ + while (dp > sp) + dp[-2] = dp[-1] = *--sp, dp -= 2; + + row_info->rowbytes *= 2; + row_info->bit_depth = 16; + row_info->pixel_depth = (png_byte)(row_info->channels * 16); + } +} +#endif + +#ifdef PNG_READ_QUANTIZE_SUPPORTED +void /* PRIVATE */ +png_do_quantize(png_row_infop row_info, png_bytep row, + png_const_bytep palette_lookup, png_const_bytep quantize_lookup) +{ + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width=row_info->width; + + png_debug(1, "in png_do_quantize"); + + if (row_info->bit_depth == 8) + { + if (row_info->color_type == PNG_COLOR_TYPE_RGB && palette_lookup) + { + int r, g, b, p; + sp = row; + dp = row; + for (i = 0; i < row_width; i++) + { + r = *sp++; + g = *sp++; + b = *sp++; + + /* This looks real messy, but the compiler will reduce + * it down to a reasonable formula. For example, with + * 5 bits per color, we get: + * p = (((r >> 3) & 0x1f) << 10) | + * (((g >> 3) & 0x1f) << 5) | + * ((b >> 3) & 0x1f); + */ + p = (((r >> (8 - PNG_QUANTIZE_RED_BITS)) & + ((1 << PNG_QUANTIZE_RED_BITS) - 1)) << + (PNG_QUANTIZE_GREEN_BITS + PNG_QUANTIZE_BLUE_BITS)) | + (((g >> (8 - PNG_QUANTIZE_GREEN_BITS)) & + ((1 << PNG_QUANTIZE_GREEN_BITS) - 1)) << + (PNG_QUANTIZE_BLUE_BITS)) | + ((b >> (8 - PNG_QUANTIZE_BLUE_BITS)) & + ((1 << PNG_QUANTIZE_BLUE_BITS) - 1)); + + *dp++ = palette_lookup[p]; + } + + row_info->color_type = PNG_COLOR_TYPE_PALETTE; + row_info->channels = 1; + row_info->pixel_depth = row_info->bit_depth; + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width); + } + + else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA && + palette_lookup != NULL) + { + int r, g, b, p; + sp = row; + dp = row; + for (i = 0; i < row_width; i++) + { + r = *sp++; + g = *sp++; + b = *sp++; + sp++; + + p = (((r >> (8 - PNG_QUANTIZE_RED_BITS)) & + ((1 << PNG_QUANTIZE_RED_BITS) - 1)) << + (PNG_QUANTIZE_GREEN_BITS + PNG_QUANTIZE_BLUE_BITS)) | + (((g >> (8 - PNG_QUANTIZE_GREEN_BITS)) & + ((1 << PNG_QUANTIZE_GREEN_BITS) - 1)) << + (PNG_QUANTIZE_BLUE_BITS)) | + ((b >> (8 - PNG_QUANTIZE_BLUE_BITS)) & + ((1 << PNG_QUANTIZE_BLUE_BITS) - 1)); + + *dp++ = palette_lookup[p]; + } + + row_info->color_type = PNG_COLOR_TYPE_PALETTE; + row_info->channels = 1; + row_info->pixel_depth = row_info->bit_depth; + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, row_width); + } + + else if (row_info->color_type == PNG_COLOR_TYPE_PALETTE && + quantize_lookup) + { + sp = row; + + for (i = 0; i < row_width; i++, sp++) + { + *sp = quantize_lookup[*sp]; + } + } + } +} +#endif /* PNG_READ_QUANTIZE_SUPPORTED */ +#endif /* PNG_READ_TRANSFORMS_SUPPORTED */ + +#ifdef PNG_MNG_FEATURES_SUPPORTED +/* Undoes intrapixel differencing */ +void /* PRIVATE */ +png_do_read_intrapixel(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_read_intrapixel"); + + if ( + (row_info->color_type & PNG_COLOR_MASK_COLOR)) + { + int bytes_per_pixel; + png_uint_32 row_width = row_info->width; + + if (row_info->bit_depth == 8) + { + png_bytep rp; + png_uint_32 i; + + if (row_info->color_type == PNG_COLOR_TYPE_RGB) + bytes_per_pixel = 3; + + else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + bytes_per_pixel = 4; + + else + return; + + for (i = 0, rp = row; i < row_width; i++, rp += bytes_per_pixel) + { + *(rp) = (png_byte)((256 + *rp + *(rp + 1)) & 0xff); + *(rp+2) = (png_byte)((256 + *(rp + 2) + *(rp + 1)) & 0xff); + } + } + else if (row_info->bit_depth == 16) + { + png_bytep rp; + png_uint_32 i; + + if (row_info->color_type == PNG_COLOR_TYPE_RGB) + bytes_per_pixel = 6; + + else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + bytes_per_pixel = 8; + + else + return; + + for (i = 0, rp = row; i < row_width; i++, rp += bytes_per_pixel) + { + png_uint_32 s0 = (*(rp ) << 8) | *(rp + 1); + png_uint_32 s1 = (*(rp + 2) << 8) | *(rp + 3); + png_uint_32 s2 = (*(rp + 4) << 8) | *(rp + 5); + png_uint_32 red = (s0 + s1 + 65536) & 0xffff; + png_uint_32 blue = (s2 + s1 + 65536) & 0xffff; + *(rp ) = (png_byte)((red >> 8) & 0xff); + *(rp + 1) = (png_byte)(red & 0xff); + *(rp + 4) = (png_byte)((blue >> 8) & 0xff); + *(rp + 5) = (png_byte)(blue & 0xff); + } + } + } +} +#endif /* PNG_MNG_FEATURES_SUPPORTED */ +#endif /* PNG_READ_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngrutil.c b/ml/dlib/dlib/external/libpng/pngrutil.c new file mode 100644 index 000000000..2438384dd --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngrutil.c @@ -0,0 +1,4475 @@ + +/* pngrutil.c - utilities to read a PNG file + * + * Last changed in libpng 1.6.7 [November 14, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + * This file contains routines that are only called from within + * libpng itself during the course of reading an image. + */ + +#include "pngpriv.h" + +#ifdef PNG_READ_SUPPORTED + +png_uint_32 PNGAPI +png_get_uint_31(png_const_structrp png_ptr, png_const_bytep buf) +{ + png_uint_32 uval = png_get_uint_32(buf); + + if (uval > PNG_UINT_31_MAX) + png_error(png_ptr, "PNG unsigned integer out of range"); + + return (uval); +} + +#if defined(PNG_READ_gAMA_SUPPORTED) || defined(PNG_READ_cHRM_SUPPORTED) +/* The following is a variation on the above for use with the fixed + * point values used for gAMA and cHRM. Instead of png_error it + * issues a warning and returns (-1) - an invalid value because both + * gAMA and cHRM use *unsigned* integers for fixed point values. + */ +#define PNG_FIXED_ERROR (-1) + +static png_fixed_point /* PRIVATE */ +png_get_fixed_point(png_structrp png_ptr, png_const_bytep buf) +{ + png_uint_32 uval = png_get_uint_32(buf); + + if (uval <= PNG_UINT_31_MAX) + return (png_fixed_point)uval; /* known to be in range */ + + /* The caller can turn off the warning by passing NULL. */ + if (png_ptr != NULL) + png_warning(png_ptr, "PNG fixed point integer out of range"); + + return PNG_FIXED_ERROR; +} +#endif + +#ifdef PNG_READ_INT_FUNCTIONS_SUPPORTED +/* NOTE: the read macros will obscure these definitions, so that if + * PNG_USE_READ_MACROS is set the library will not use them internally, + * but the APIs will still be available externally. + * + * The parentheses around "PNGAPI function_name" in the following three + * functions are necessary because they allow the macros to co-exist with + * these (unused but exported) functions. + */ + +/* Grab an unsigned 32-bit integer from a buffer in big-endian format. */ +png_uint_32 (PNGAPI +png_get_uint_32)(png_const_bytep buf) +{ + png_uint_32 uval = + ((png_uint_32)(*(buf )) << 24) + + ((png_uint_32)(*(buf + 1)) << 16) + + ((png_uint_32)(*(buf + 2)) << 8) + + ((png_uint_32)(*(buf + 3)) ) ; + + return uval; +} + +/* Grab a signed 32-bit integer from a buffer in big-endian format. The + * data is stored in the PNG file in two's complement format and there + * is no guarantee that a 'png_int_32' is exactly 32 bits, therefore + * the following code does a two's complement to native conversion. + */ +png_int_32 (PNGAPI +png_get_int_32)(png_const_bytep buf) +{ + png_uint_32 uval = png_get_uint_32(buf); + if ((uval & 0x80000000) == 0) /* non-negative */ + return uval; + + uval = (uval ^ 0xffffffff) + 1; /* 2's complement: -x = ~x+1 */ + return -(png_int_32)uval; +} + +/* Grab an unsigned 16-bit integer from a buffer in big-endian format. */ +png_uint_16 (PNGAPI +png_get_uint_16)(png_const_bytep buf) +{ + /* ANSI-C requires an int value to accomodate at least 16 bits so this + * works and allows the compiler not to worry about possible narrowing + * on 32 bit systems. (Pre-ANSI systems did not make integers smaller + * than 16 bits either.) + */ + unsigned int val = + ((unsigned int)(*buf) << 8) + + ((unsigned int)(*(buf + 1))); + + return (png_uint_16)val; +} + +#endif /* PNG_READ_INT_FUNCTIONS_SUPPORTED */ + +/* Read and check the PNG file signature */ +void /* PRIVATE */ +png_read_sig(png_structrp png_ptr, png_inforp info_ptr) +{ + png_size_t num_checked, num_to_check; + + /* Exit if the user application does not expect a signature. */ + if (png_ptr->sig_bytes >= 8) + return; + + num_checked = png_ptr->sig_bytes; + num_to_check = 8 - num_checked; + +#ifdef PNG_IO_STATE_SUPPORTED + png_ptr->io_state = PNG_IO_READING | PNG_IO_SIGNATURE; +#endif + + /* The signature must be serialized in a single I/O call. */ + png_read_data(png_ptr, &(info_ptr->signature[num_checked]), num_to_check); + png_ptr->sig_bytes = 8; + + if (png_sig_cmp(info_ptr->signature, num_checked, num_to_check)) + { + if (num_checked < 4 && + png_sig_cmp(info_ptr->signature, num_checked, num_to_check - 4)) + png_error(png_ptr, "Not a PNG file"); + else + png_error(png_ptr, "PNG file corrupted by ASCII conversion"); + } + if (num_checked < 3) + png_ptr->mode |= PNG_HAVE_PNG_SIGNATURE; +} + +/* Read the chunk header (length + type name). + * Put the type name into png_ptr->chunk_name, and return the length. + */ +png_uint_32 /* PRIVATE */ +png_read_chunk_header(png_structrp png_ptr) +{ + png_byte buf[8]; + png_uint_32 length; + +#ifdef PNG_IO_STATE_SUPPORTED + png_ptr->io_state = PNG_IO_READING | PNG_IO_CHUNK_HDR; +#endif + + /* Read the length and the chunk name. + * This must be performed in a single I/O call. + */ + png_read_data(png_ptr, buf, 8); + length = png_get_uint_31(png_ptr, buf); + + /* Put the chunk name into png_ptr->chunk_name. */ + png_ptr->chunk_name = PNG_CHUNK_FROM_STRING(buf+4); + + png_debug2(0, "Reading %lx chunk, length = %lu", + (unsigned long)png_ptr->chunk_name, (unsigned long)length); + + /* Reset the crc and run it over the chunk name. */ + png_reset_crc(png_ptr); + png_calculate_crc(png_ptr, buf + 4, 4); + + /* Check to see if chunk name is valid. */ + png_check_chunk_name(png_ptr, png_ptr->chunk_name); + +#ifdef PNG_IO_STATE_SUPPORTED + png_ptr->io_state = PNG_IO_READING | PNG_IO_CHUNK_DATA; +#endif + + return length; +} + +/* Read data, and (optionally) run it through the CRC. */ +void /* PRIVATE */ +png_crc_read(png_structrp png_ptr, png_bytep buf, png_uint_32 length) +{ + if (png_ptr == NULL) + return; + + png_read_data(png_ptr, buf, length); + png_calculate_crc(png_ptr, buf, length); +} + +/* Optionally skip data and then check the CRC. Depending on whether we + * are reading an ancillary or critical chunk, and how the program has set + * things up, we may calculate the CRC on the data and print a message. + * Returns '1' if there was a CRC error, '0' otherwise. + */ +int /* PRIVATE */ +png_crc_finish(png_structrp png_ptr, png_uint_32 skip) +{ + /* The size of the local buffer for inflate is a good guess as to a + * reasonable size to use for buffering reads from the application. + */ + while (skip > 0) + { + png_uint_32 len; + png_byte tmpbuf[PNG_INFLATE_BUF_SIZE]; + + len = (sizeof tmpbuf); + if (len > skip) + len = skip; + skip -= len; + + png_crc_read(png_ptr, tmpbuf, len); + } + + if (png_crc_error(png_ptr)) + { + if (PNG_CHUNK_ANCILLARY(png_ptr->chunk_name) ? + !(png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_NOWARN) : + (png_ptr->flags & PNG_FLAG_CRC_CRITICAL_USE)) + { + png_chunk_warning(png_ptr, "CRC error"); + } + + else + { + png_chunk_benign_error(png_ptr, "CRC error"); + return (0); + } + + return (1); + } + + return (0); +} + +/* Compare the CRC stored in the PNG file with that calculated by libpng from + * the data it has read thus far. + */ +int /* PRIVATE */ +png_crc_error(png_structrp png_ptr) +{ + png_byte crc_bytes[4]; + png_uint_32 crc; + int need_crc = 1; + + if (PNG_CHUNK_ANCILLARY(png_ptr->chunk_name)) + { + if ((png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_MASK) == + (PNG_FLAG_CRC_ANCILLARY_USE | PNG_FLAG_CRC_ANCILLARY_NOWARN)) + need_crc = 0; + } + + else /* critical */ + { + if (png_ptr->flags & PNG_FLAG_CRC_CRITICAL_IGNORE) + need_crc = 0; + } + +#ifdef PNG_IO_STATE_SUPPORTED + png_ptr->io_state = PNG_IO_READING | PNG_IO_CHUNK_CRC; +#endif + + /* The chunk CRC must be serialized in a single I/O call. */ + png_read_data(png_ptr, crc_bytes, 4); + + if (need_crc) + { + crc = png_get_uint_32(crc_bytes); + return ((int)(crc != png_ptr->crc)); + } + + else + return (0); +} + +/* Manage the read buffer; this simply reallocates the buffer if it is not small + * enough (or if it is not allocated). The routine returns a pointer to the + * buffer; if an error occurs and 'warn' is set the routine returns NULL, else + * it will call png_error (via png_malloc) on failure. (warn == 2 means + * 'silent'). + */ +static png_bytep +png_read_buffer(png_structrp png_ptr, png_alloc_size_t new_size, int warn) +{ + png_bytep buffer = png_ptr->read_buffer; + + if (buffer != NULL && new_size > png_ptr->read_buffer_size) + { + png_ptr->read_buffer = NULL; + png_ptr->read_buffer = NULL; + png_ptr->read_buffer_size = 0; + png_free(png_ptr, buffer); + buffer = NULL; + } + + if (buffer == NULL) + { + buffer = png_voidcast(png_bytep, png_malloc_base(png_ptr, new_size)); + + if (buffer != NULL) + { + png_ptr->read_buffer = buffer; + png_ptr->read_buffer_size = new_size; + } + + else if (warn < 2) /* else silent */ + { +#ifdef PNG_WARNINGS_SUPPORTED + if (warn) + png_chunk_warning(png_ptr, "insufficient memory to read chunk"); + else +#endif + { +#ifdef PNG_ERROR_TEXT_SUPPORTED + png_chunk_error(png_ptr, "insufficient memory to read chunk"); +#endif + } + } + } + + return buffer; +} + +/* png_inflate_claim: claim the zstream for some nefarious purpose that involves + * decompression. Returns Z_OK on success, else a zlib error code. It checks + * the owner but, in final release builds, just issues a warning if some other + * chunk apparently owns the stream. Prior to release it does a png_error. + */ +static int +png_inflate_claim(png_structrp png_ptr, png_uint_32 owner) +{ + if (png_ptr->zowner != 0) + { + char msg[64]; + + PNG_STRING_FROM_CHUNK(msg, png_ptr->zowner); + /* So the message that results is " using zstream"; this is an + * internal error, but is very useful for debugging. i18n requirements + * are minimal. + */ + (void)png_safecat(msg, (sizeof msg), 4, " using zstream"); +# if PNG_LIBPNG_BUILD_BASE_TYPE >= PNG_LIBPNG_BUILD_RC + png_chunk_warning(png_ptr, msg); + png_ptr->zowner = 0; +# else + png_chunk_error(png_ptr, msg); +# endif + } + + /* Implementation note: unlike 'png_deflate_claim' this internal function + * does not take the size of the data as an argument. Some efficiency could + * be gained by using this when it is known *if* the zlib stream itself does + * not record the number; however, this is an illusion: the original writer + * of the PNG may have selected a lower window size, and we really must + * follow that because, for systems with with limited capabilities, we + * would otherwise reject the application's attempts to use a smaller window + * size (zlib doesn't have an interface to say "this or lower"!). + * + * inflateReset2 was added to zlib 1.2.4; before this the window could not be + * reset, therefore it is necessary to always allocate the maximum window + * size with earlier zlibs just in case later compressed chunks need it. + */ + { + int ret; /* zlib return code */ +# if PNG_ZLIB_VERNUM >= 0x1240 + +# if defined(PNG_SET_OPTION_SUPPORTED) && \ + defined(PNG_MAXIMUM_INFLATE_WINDOW) + int window_bits; + + if (((png_ptr->options >> PNG_MAXIMUM_INFLATE_WINDOW) & 3) == + PNG_OPTION_ON) + window_bits = 15; + + else + window_bits = 0; +# else +# define window_bits 0 +# endif +# endif + + /* Set this for safety, just in case the previous owner left pointers to + * memory allocations. + */ + png_ptr->zstream.next_in = NULL; + png_ptr->zstream.avail_in = 0; + png_ptr->zstream.next_out = NULL; + png_ptr->zstream.avail_out = 0; + + if (png_ptr->flags & PNG_FLAG_ZSTREAM_INITIALIZED) + { +# if PNG_ZLIB_VERNUM < 0x1240 + ret = inflateReset(&png_ptr->zstream); +# else + ret = inflateReset2(&png_ptr->zstream, window_bits); +# endif + } + + else + { +# if PNG_ZLIB_VERNUM < 0x1240 + ret = inflateInit(&png_ptr->zstream); +# else + ret = inflateInit2(&png_ptr->zstream, window_bits); +# endif + + if (ret == Z_OK) + png_ptr->flags |= PNG_FLAG_ZSTREAM_INITIALIZED; + } + + if (ret == Z_OK) + png_ptr->zowner = owner; + + else + png_zstream_error(png_ptr, ret); + + return ret; + } + +# ifdef window_bits +# undef window_bits +# endif +} + +#ifdef PNG_READ_COMPRESSED_TEXT_SUPPORTED +/* png_inflate now returns zlib error codes including Z_OK and Z_STREAM_END to + * allow the caller to do multiple calls if required. If the 'finish' flag is + * set Z_FINISH will be passed to the final inflate() call and Z_STREAM_END must + * be returned or there has been a problem, otherwise Z_SYNC_FLUSH is used and + * Z_OK or Z_STREAM_END will be returned on success. + * + * The input and output sizes are updated to the actual amounts of data consumed + * or written, not the amount available (as in a z_stream). The data pointers + * are not changed, so the next input is (data+input_size) and the next + * available output is (output+output_size). + */ +static int +png_inflate(png_structrp png_ptr, png_uint_32 owner, int finish, + /* INPUT: */ png_const_bytep input, png_uint_32p input_size_ptr, + /* OUTPUT: */ png_bytep output, png_alloc_size_t *output_size_ptr) +{ + if (png_ptr->zowner == owner) /* Else not claimed */ + { + int ret; + png_alloc_size_t avail_out = *output_size_ptr; + png_uint_32 avail_in = *input_size_ptr; + + /* zlib can't necessarily handle more than 65535 bytes at once (i.e. it + * can't even necessarily handle 65536 bytes) because the type uInt is + * "16 bits or more". Consequently it is necessary to chunk the input to + * zlib. This code uses ZLIB_IO_MAX, from pngpriv.h, as the maximum (the + * maximum value that can be stored in a uInt.) It is possible to set + * ZLIB_IO_MAX to a lower value in pngpriv.h and this may sometimes have + * a performance advantage, because it reduces the amount of data accessed + * at each step and that may give the OS more time to page it in. + */ + png_ptr->zstream.next_in = PNGZ_INPUT_CAST(input); + /* avail_in and avail_out are set below from 'size' */ + png_ptr->zstream.avail_in = 0; + png_ptr->zstream.avail_out = 0; + + /* Read directly into the output if it is available (this is set to + * a local buffer below if output is NULL). + */ + if (output != NULL) + png_ptr->zstream.next_out = output; + + do + { + uInt avail; + Byte local_buffer[PNG_INFLATE_BUF_SIZE]; + + /* zlib INPUT BUFFER */ + /* The setting of 'avail_in' used to be outside the loop; by setting it + * inside it is possible to chunk the input to zlib and simply rely on + * zlib to advance the 'next_in' pointer. This allows arbitrary + * amounts of data to be passed through zlib at the unavoidable cost of + * requiring a window save (memcpy of up to 32768 output bytes) + * every ZLIB_IO_MAX input bytes. + */ + avail_in += png_ptr->zstream.avail_in; /* not consumed last time */ + + avail = ZLIB_IO_MAX; + + if (avail_in < avail) + avail = (uInt)avail_in; /* safe: < than ZLIB_IO_MAX */ + + avail_in -= avail; + png_ptr->zstream.avail_in = avail; + + /* zlib OUTPUT BUFFER */ + avail_out += png_ptr->zstream.avail_out; /* not written last time */ + + avail = ZLIB_IO_MAX; /* maximum zlib can process */ + + if (output == NULL) + { + /* Reset the output buffer each time round if output is NULL and + * make available the full buffer, up to 'remaining_space' + */ + png_ptr->zstream.next_out = local_buffer; + if ((sizeof local_buffer) < avail) + avail = (sizeof local_buffer); + } + + if (avail_out < avail) + avail = (uInt)avail_out; /* safe: < ZLIB_IO_MAX */ + + png_ptr->zstream.avail_out = avail; + avail_out -= avail; + + /* zlib inflate call */ + /* In fact 'avail_out' may be 0 at this point, that happens at the end + * of the read when the final LZ end code was not passed at the end of + * the previous chunk of input data. Tell zlib if we have reached the + * end of the output buffer. + */ + ret = inflate(&png_ptr->zstream, avail_out > 0 ? Z_NO_FLUSH : + (finish ? Z_FINISH : Z_SYNC_FLUSH)); + } while (ret == Z_OK); + + /* For safety kill the local buffer pointer now */ + if (output == NULL) + png_ptr->zstream.next_out = NULL; + + /* Claw back the 'size' and 'remaining_space' byte counts. */ + avail_in += png_ptr->zstream.avail_in; + avail_out += png_ptr->zstream.avail_out; + + /* Update the input and output sizes; the updated values are the amount + * consumed or written, effectively the inverse of what zlib uses. + */ + if (avail_out > 0) + *output_size_ptr -= avail_out; + + if (avail_in > 0) + *input_size_ptr -= avail_in; + + /* Ensure png_ptr->zstream.msg is set (even in the success case!) */ + png_zstream_error(png_ptr, ret); + return ret; + } + + else + { + /* This is a bad internal error. The recovery assigns to the zstream msg + * pointer, which is not owned by the caller, but this is safe; it's only + * used on errors! + */ + png_ptr->zstream.msg = PNGZ_MSG_CAST("zstream unclaimed"); + return Z_STREAM_ERROR; + } +} + +/* + * Decompress trailing data in a chunk. The assumption is that read_buffer + * points at an allocated area holding the contents of a chunk with a + * trailing compressed part. What we get back is an allocated area + * holding the original prefix part and an uncompressed version of the + * trailing part (the malloc area passed in is freed). + */ +static int +png_decompress_chunk(png_structrp png_ptr, + png_uint_32 chunklength, png_uint_32 prefix_size, + png_alloc_size_t *newlength /* must be initialized to the maximum! */, + int terminate /*add a '\0' to the end of the uncompressed data*/) +{ + /* TODO: implement different limits for different types of chunk. + * + * The caller supplies *newlength set to the maximum length of the + * uncompressed data, but this routine allocates space for the prefix and + * maybe a '\0' terminator too. We have to assume that 'prefix_size' is + * limited only by the maximum chunk size. + */ + png_alloc_size_t limit = PNG_SIZE_MAX; + +# ifdef PNG_SET_CHUNK_MALLOC_LIMIT_SUPPORTED + if (png_ptr->user_chunk_malloc_max > 0 && + png_ptr->user_chunk_malloc_max < limit) + limit = png_ptr->user_chunk_malloc_max; +# elif PNG_USER_CHUNK_MALLOC_MAX > 0 + if (PNG_USER_CHUNK_MALLOC_MAX < limit) + limit = PNG_USER_CHUNK_MALLOC_MAX; +# endif + + if (limit >= prefix_size + (terminate != 0)) + { + int ret; + + limit -= prefix_size + (terminate != 0); + + if (limit < *newlength) + *newlength = limit; + + /* Now try to claim the stream. */ + ret = png_inflate_claim(png_ptr, png_ptr->chunk_name); + + if (ret == Z_OK) + { + png_uint_32 lzsize = chunklength - prefix_size; + + ret = png_inflate(png_ptr, png_ptr->chunk_name, 1/*finish*/, + /* input: */ png_ptr->read_buffer + prefix_size, &lzsize, + /* output: */ NULL, newlength); + + if (ret == Z_STREAM_END) + { + /* Use 'inflateReset' here, not 'inflateReset2' because this + * preserves the previously decided window size (otherwise it would + * be necessary to store the previous window size.) In practice + * this doesn't matter anyway, because png_inflate will call inflate + * with Z_FINISH in almost all cases, so the window will not be + * maintained. + */ + if (inflateReset(&png_ptr->zstream) == Z_OK) + { + /* Because of the limit checks above we know that the new, + * expanded, size will fit in a size_t (let alone an + * png_alloc_size_t). Use png_malloc_base here to avoid an + * extra OOM message. + */ + png_alloc_size_t new_size = *newlength; + png_alloc_size_t buffer_size = prefix_size + new_size + + (terminate != 0); + png_bytep text = png_voidcast(png_bytep, png_malloc_base(png_ptr, + buffer_size)); + + if (text != NULL) + { + ret = png_inflate(png_ptr, png_ptr->chunk_name, 1/*finish*/, + png_ptr->read_buffer + prefix_size, &lzsize, + text + prefix_size, newlength); + + if (ret == Z_STREAM_END) + { + if (new_size == *newlength) + { + if (terminate) + text[prefix_size + *newlength] = 0; + + if (prefix_size > 0) + memcpy(text, png_ptr->read_buffer, prefix_size); + + { + png_bytep old_ptr = png_ptr->read_buffer; + + png_ptr->read_buffer = text; + png_ptr->read_buffer_size = buffer_size; + text = old_ptr; /* freed below */ + } + } + + else + { + /* The size changed on the second read, there can be no + * guarantee that anything is correct at this point. + * The 'msg' pointer has been set to "unexpected end of + * LZ stream", which is fine, but return an error code + * that the caller won't accept. + */ + ret = PNG_UNEXPECTED_ZLIB_RETURN; + } + } + + else if (ret == Z_OK) + ret = PNG_UNEXPECTED_ZLIB_RETURN; /* for safety */ + + /* Free the text pointer (this is the old read_buffer on + * success) + */ + png_free(png_ptr, text); + + /* This really is very benign, but it's still an error because + * the extra space may otherwise be used as a Trojan Horse. + */ + if (ret == Z_STREAM_END && + chunklength - prefix_size != lzsize) + png_chunk_benign_error(png_ptr, "extra compressed data"); + } + + else + { + /* Out of memory allocating the buffer */ + ret = Z_MEM_ERROR; + png_zstream_error(png_ptr, Z_MEM_ERROR); + } + } + + else + { + /* inflateReset failed, store the error message */ + png_zstream_error(png_ptr, ret); + + if (ret == Z_STREAM_END) + ret = PNG_UNEXPECTED_ZLIB_RETURN; + } + } + + else if (ret == Z_OK) + ret = PNG_UNEXPECTED_ZLIB_RETURN; + + /* Release the claimed stream */ + png_ptr->zowner = 0; + } + + else /* the claim failed */ if (ret == Z_STREAM_END) /* impossible! */ + ret = PNG_UNEXPECTED_ZLIB_RETURN; + + return ret; + } + + else + { + /* Application/configuration limits exceeded */ + png_zstream_error(png_ptr, Z_MEM_ERROR); + return Z_MEM_ERROR; + } +} +#endif /* PNG_READ_COMPRESSED_TEXT_SUPPORTED */ + +#ifdef PNG_READ_iCCP_SUPPORTED +/* Perform a partial read and decompress, producing 'avail_out' bytes and + * reading from the current chunk as required. + */ +static int +png_inflate_read(png_structrp png_ptr, png_bytep read_buffer, uInt read_size, + png_uint_32p chunk_bytes, png_bytep next_out, png_alloc_size_t *out_size, + int finish) +{ + if (png_ptr->zowner == png_ptr->chunk_name) + { + int ret; + + /* next_in and avail_in must have been initialized by the caller. */ + png_ptr->zstream.next_out = next_out; + png_ptr->zstream.avail_out = 0; /* set in the loop */ + + do + { + if (png_ptr->zstream.avail_in == 0) + { + if (read_size > *chunk_bytes) + read_size = (uInt)*chunk_bytes; + *chunk_bytes -= read_size; + + if (read_size > 0) + png_crc_read(png_ptr, read_buffer, read_size); + + png_ptr->zstream.next_in = read_buffer; + png_ptr->zstream.avail_in = read_size; + } + + if (png_ptr->zstream.avail_out == 0) + { + uInt avail = ZLIB_IO_MAX; + if (avail > *out_size) + avail = (uInt)*out_size; + *out_size -= avail; + + png_ptr->zstream.avail_out = avail; + } + + /* Use Z_SYNC_FLUSH when there is no more chunk data to ensure that all + * the available output is produced; this allows reading of truncated + * streams. + */ + ret = inflate(&png_ptr->zstream, + *chunk_bytes > 0 ? Z_NO_FLUSH : (finish ? Z_FINISH : Z_SYNC_FLUSH)); + } + while (ret == Z_OK && (*out_size > 0 || png_ptr->zstream.avail_out > 0)); + + *out_size += png_ptr->zstream.avail_out; + png_ptr->zstream.avail_out = 0; /* Should not be required, but is safe */ + + /* Ensure the error message pointer is always set: */ + png_zstream_error(png_ptr, ret); + return ret; + } + + else + { + png_ptr->zstream.msg = PNGZ_MSG_CAST("zstream unclaimed"); + return Z_STREAM_ERROR; + } +} +#endif + +/* Read and check the IDHR chunk */ +void /* PRIVATE */ +png_handle_IHDR(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_byte buf[13]; + png_uint_32 width, height; + int bit_depth, color_type, compression_type, filter_type; + int interlace_type; + + png_debug(1, "in png_handle_IHDR"); + + if (png_ptr->mode & PNG_HAVE_IHDR) + png_chunk_error(png_ptr, "out of place"); + + /* Check the length */ + if (length != 13) + png_chunk_error(png_ptr, "invalid"); + + png_ptr->mode |= PNG_HAVE_IHDR; + + png_crc_read(png_ptr, buf, 13); + png_crc_finish(png_ptr, 0); + + width = png_get_uint_31(png_ptr, buf); + height = png_get_uint_31(png_ptr, buf + 4); + bit_depth = buf[8]; + color_type = buf[9]; + compression_type = buf[10]; + filter_type = buf[11]; + interlace_type = buf[12]; + + /* Set internal variables */ + png_ptr->width = width; + png_ptr->height = height; + png_ptr->bit_depth = (png_byte)bit_depth; + png_ptr->interlaced = (png_byte)interlace_type; + png_ptr->color_type = (png_byte)color_type; +#ifdef PNG_MNG_FEATURES_SUPPORTED + png_ptr->filter_type = (png_byte)filter_type; +#endif + png_ptr->compression_type = (png_byte)compression_type; + + /* Find number of channels */ + switch (png_ptr->color_type) + { + default: /* invalid, png_set_IHDR calls png_error */ + case PNG_COLOR_TYPE_GRAY: + case PNG_COLOR_TYPE_PALETTE: + png_ptr->channels = 1; + break; + + case PNG_COLOR_TYPE_RGB: + png_ptr->channels = 3; + break; + + case PNG_COLOR_TYPE_GRAY_ALPHA: + png_ptr->channels = 2; + break; + + case PNG_COLOR_TYPE_RGB_ALPHA: + png_ptr->channels = 4; + break; + } + + /* Set up other useful info */ + png_ptr->pixel_depth = (png_byte)(png_ptr->bit_depth * + png_ptr->channels); + png_ptr->rowbytes = PNG_ROWBYTES(png_ptr->pixel_depth, png_ptr->width); + png_debug1(3, "bit_depth = %d", png_ptr->bit_depth); + png_debug1(3, "channels = %d", png_ptr->channels); + png_debug1(3, "rowbytes = %lu", (unsigned long)png_ptr->rowbytes); + png_set_IHDR(png_ptr, info_ptr, width, height, bit_depth, + color_type, interlace_type, compression_type, filter_type); +} + +/* Read and check the palette */ +void /* PRIVATE */ +png_handle_PLTE(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_color palette[PNG_MAX_PALETTE_LENGTH]; + int num, i; +#ifdef PNG_POINTER_INDEXING_SUPPORTED + png_colorp pal_ptr; +#endif + + png_debug(1, "in png_handle_PLTE"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + /* Moved to before the 'after IDAT' check below because otherwise duplicate + * PLTE chunks are potentially ignored (the spec says there shall not be more + * than one PLTE, the error is not treated as benign, so this check trumps + * the requirement that PLTE appears before IDAT.) + */ + else if (png_ptr->mode & PNG_HAVE_PLTE) + png_chunk_error(png_ptr, "duplicate"); + + else if (png_ptr->mode & PNG_HAVE_IDAT) + { + /* This is benign because the non-benign error happened before, when an + * IDAT was encountered in a color-mapped image with no PLTE. + */ + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + png_ptr->mode |= PNG_HAVE_PLTE; + + if (!(png_ptr->color_type & PNG_COLOR_MASK_COLOR)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "ignored in grayscale PNG"); + return; + } + +#ifndef PNG_READ_OPT_PLTE_SUPPORTED + if (png_ptr->color_type != PNG_COLOR_TYPE_PALETTE) + { + png_crc_finish(png_ptr, length); + return; + } +#endif + + if (length > 3*PNG_MAX_PALETTE_LENGTH || length % 3) + { + png_crc_finish(png_ptr, length); + + if (png_ptr->color_type != PNG_COLOR_TYPE_PALETTE) + png_chunk_benign_error(png_ptr, "invalid"); + + else + png_chunk_error(png_ptr, "invalid"); + + return; + } + + /* The cast is safe because 'length' is less than 3*PNG_MAX_PALETTE_LENGTH */ + num = (int)length / 3; + +#ifdef PNG_POINTER_INDEXING_SUPPORTED + for (i = 0, pal_ptr = palette; i < num; i++, pal_ptr++) + { + png_byte buf[3]; + + png_crc_read(png_ptr, buf, 3); + pal_ptr->red = buf[0]; + pal_ptr->green = buf[1]; + pal_ptr->blue = buf[2]; + } +#else + for (i = 0; i < num; i++) + { + png_byte buf[3]; + + png_crc_read(png_ptr, buf, 3); + /* Don't depend upon png_color being any order */ + palette[i].red = buf[0]; + palette[i].green = buf[1]; + palette[i].blue = buf[2]; + } +#endif + + /* If we actually need the PLTE chunk (ie for a paletted image), we do + * whatever the normal CRC configuration tells us. However, if we + * have an RGB image, the PLTE can be considered ancillary, so + * we will act as though it is. + */ +#ifndef PNG_READ_OPT_PLTE_SUPPORTED + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) +#endif + { + png_crc_finish(png_ptr, 0); + } + +#ifndef PNG_READ_OPT_PLTE_SUPPORTED + else if (png_crc_error(png_ptr)) /* Only if we have a CRC error */ + { + /* If we don't want to use the data from an ancillary chunk, + * we have two options: an error abort, or a warning and we + * ignore the data in this chunk (which should be OK, since + * it's considered ancillary for a RGB or RGBA image). + * + * IMPLEMENTATION NOTE: this is only here because png_crc_finish uses the + * chunk type to determine whether to check the ancillary or the critical + * flags. + */ + if (!(png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_USE)) + { + if (png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_NOWARN) + { + png_chunk_benign_error(png_ptr, "CRC error"); + } + + else + { + png_chunk_warning(png_ptr, "CRC error"); + return; + } + } + + /* Otherwise, we (optionally) emit a warning and use the chunk. */ + else if (!(png_ptr->flags & PNG_FLAG_CRC_ANCILLARY_NOWARN)) + { + png_chunk_warning(png_ptr, "CRC error"); + } + } +#endif + + /* TODO: png_set_PLTE has the side effect of setting png_ptr->palette to its + * own copy of the palette. This has the side effect that when png_start_row + * is called (this happens after any call to png_read_update_info) the + * info_ptr palette gets changed. This is extremely unexpected and + * confusing. + * + * Fix this by not sharing the palette in this way. + */ + png_set_PLTE(png_ptr, info_ptr, palette, num); + + /* The three chunks, bKGD, hIST and tRNS *must* appear after PLTE and before + * IDAT. Prior to 1.6.0 this was not checked; instead the code merely + * checked the apparent validity of a tRNS chunk inserted before PLTE on a + * palette PNG. 1.6.0 attempts to rigorously follow the standard and + * therefore does a benign error if the erroneous condition is detected *and* + * cancels the tRNS if the benign error returns. The alternative is to + * amend the standard since it would be rather hypocritical of the standards + * maintainers to ignore it. + */ +#ifdef PNG_READ_tRNS_SUPPORTED + if (png_ptr->num_trans > 0 || + (info_ptr != NULL && (info_ptr->valid & PNG_INFO_tRNS) != 0)) + { + /* Cancel this because otherwise it would be used if the transforms + * require it. Don't cancel the 'valid' flag because this would prevent + * detection of duplicate chunks. + */ + png_ptr->num_trans = 0; + + if (info_ptr != NULL) + info_ptr->num_trans = 0; + + png_chunk_benign_error(png_ptr, "tRNS must be after"); + } +#endif + +#ifdef PNG_READ_hIST_SUPPORTED + if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_hIST) != 0) + png_chunk_benign_error(png_ptr, "hIST must be after"); +#endif + +#ifdef PNG_READ_bKGD_SUPPORTED + if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_bKGD) != 0) + png_chunk_benign_error(png_ptr, "bKGD must be after"); +#endif +} + +void /* PRIVATE */ +png_handle_IEND(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_debug(1, "in png_handle_IEND"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR) || !(png_ptr->mode & PNG_HAVE_IDAT)) + png_chunk_error(png_ptr, "out of place"); + + png_ptr->mode |= (PNG_AFTER_IDAT | PNG_HAVE_IEND); + + png_crc_finish(png_ptr, length); + + if (length != 0) + png_chunk_benign_error(png_ptr, "invalid"); + + PNG_UNUSED(info_ptr) +} + +#ifdef PNG_READ_gAMA_SUPPORTED +void /* PRIVATE */ +png_handle_gAMA(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_fixed_point igamma; + png_byte buf[4]; + + png_debug(1, "in png_handle_gAMA"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + if (length != 4) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, buf, 4); + + if (png_crc_finish(png_ptr, 0)) + return; + + igamma = png_get_fixed_point(NULL, buf); + + png_colorspace_set_gamma(png_ptr, &png_ptr->colorspace, igamma); + png_colorspace_sync(png_ptr, info_ptr); +} +#endif + +#ifdef PNG_READ_sBIT_SUPPORTED +void /* PRIVATE */ +png_handle_sBIT(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + unsigned int truelen; + png_byte buf[4]; + + png_debug(1, "in png_handle_sBIT"); + + buf[0] = buf[1] = buf[2] = buf[3] = 0; + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_sBIT)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + truelen = 3; + + else + truelen = png_ptr->channels; + + if (length != truelen || length > 4) + { + png_chunk_benign_error(png_ptr, "invalid"); + png_crc_finish(png_ptr, length); + return; + } + + png_crc_read(png_ptr, buf, truelen); + + if (png_crc_finish(png_ptr, 0)) + return; + + if (png_ptr->color_type & PNG_COLOR_MASK_COLOR) + { + png_ptr->sig_bit.red = buf[0]; + png_ptr->sig_bit.green = buf[1]; + png_ptr->sig_bit.blue = buf[2]; + png_ptr->sig_bit.alpha = buf[3]; + } + + else + { + png_ptr->sig_bit.gray = buf[0]; + png_ptr->sig_bit.red = buf[0]; + png_ptr->sig_bit.green = buf[0]; + png_ptr->sig_bit.blue = buf[0]; + png_ptr->sig_bit.alpha = buf[1]; + } + + png_set_sBIT(png_ptr, info_ptr, &(png_ptr->sig_bit)); +} +#endif + +#ifdef PNG_READ_cHRM_SUPPORTED +void /* PRIVATE */ +png_handle_cHRM(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_byte buf[32]; + png_xy xy; + + png_debug(1, "in png_handle_cHRM"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + if (length != 32) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, buf, 32); + + if (png_crc_finish(png_ptr, 0)) + return; + + xy.whitex = png_get_fixed_point(NULL, buf); + xy.whitey = png_get_fixed_point(NULL, buf + 4); + xy.redx = png_get_fixed_point(NULL, buf + 8); + xy.redy = png_get_fixed_point(NULL, buf + 12); + xy.greenx = png_get_fixed_point(NULL, buf + 16); + xy.greeny = png_get_fixed_point(NULL, buf + 20); + xy.bluex = png_get_fixed_point(NULL, buf + 24); + xy.bluey = png_get_fixed_point(NULL, buf + 28); + + if (xy.whitex == PNG_FIXED_ERROR || + xy.whitey == PNG_FIXED_ERROR || + xy.redx == PNG_FIXED_ERROR || + xy.redy == PNG_FIXED_ERROR || + xy.greenx == PNG_FIXED_ERROR || + xy.greeny == PNG_FIXED_ERROR || + xy.bluex == PNG_FIXED_ERROR || + xy.bluey == PNG_FIXED_ERROR) + { + png_chunk_benign_error(png_ptr, "invalid values"); + return; + } + + /* If a colorspace error has already been output skip this chunk */ + if (png_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) + return; + + if (png_ptr->colorspace.flags & PNG_COLORSPACE_FROM_cHRM) + { + png_ptr->colorspace.flags |= PNG_COLORSPACE_INVALID; + png_colorspace_sync(png_ptr, info_ptr); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + png_ptr->colorspace.flags |= PNG_COLORSPACE_FROM_cHRM; + (void)png_colorspace_set_chromaticities(png_ptr, &png_ptr->colorspace, &xy, + 1/*prefer cHRM values*/); + png_colorspace_sync(png_ptr, info_ptr); +} +#endif + +#ifdef PNG_READ_sRGB_SUPPORTED +void /* PRIVATE */ +png_handle_sRGB(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_byte intent; + + png_debug(1, "in png_handle_sRGB"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + if (length != 1) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, &intent, 1); + + if (png_crc_finish(png_ptr, 0)) + return; + + /* If a colorspace error has already been output skip this chunk */ + if (png_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) + return; + + /* Only one sRGB or iCCP chunk is allowed, use the HAVE_INTENT flag to detect + * this. + */ + if (png_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_INTENT) + { + png_ptr->colorspace.flags |= PNG_COLORSPACE_INVALID; + png_colorspace_sync(png_ptr, info_ptr); + png_chunk_benign_error(png_ptr, "too many profiles"); + return; + } + + (void)png_colorspace_set_sRGB(png_ptr, &png_ptr->colorspace, intent); + png_colorspace_sync(png_ptr, info_ptr); +} +#endif /* PNG_READ_sRGB_SUPPORTED */ + +#ifdef PNG_READ_iCCP_SUPPORTED +void /* PRIVATE */ +png_handle_iCCP(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +/* Note: this does not properly handle profiles that are > 64K under DOS */ +{ + png_const_charp errmsg = NULL; /* error message output, or no error */ + int finished = 0; /* crc checked */ + + png_debug(1, "in png_handle_iCCP"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & (PNG_HAVE_IDAT|PNG_HAVE_PLTE)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + /* Consistent with all the above colorspace handling an obviously *invalid* + * chunk is just ignored, so does not invalidate the color space. An + * alternative is to set the 'invalid' flags at the start of this routine + * and only clear them in they were not set before and all the tests pass. + * The minimum 'deflate' stream is assumed to be just the 2 byte header and 4 + * byte checksum. The keyword must be one character and there is a + * terminator (0) byte and the compression method. + */ + if (length < 9) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "too short"); + return; + } + + /* If a colorspace error has already been output skip this chunk */ + if (png_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) + { + png_crc_finish(png_ptr, length); + return; + } + + /* Only one sRGB or iCCP chunk is allowed, use the HAVE_INTENT flag to detect + * this. + */ + if ((png_ptr->colorspace.flags & PNG_COLORSPACE_HAVE_INTENT) == 0) + { + uInt read_length, keyword_length; + char keyword[81]; + + /* Find the keyword; the keyword plus separator and compression method + * bytes can be at most 81 characters long. + */ + read_length = 81; /* maximum */ + if (read_length > length) + read_length = (uInt)length; + + png_crc_read(png_ptr, (png_bytep)keyword, read_length); + length -= read_length; + + keyword_length = 0; + while (keyword_length < 80 && keyword_length < read_length && + keyword[keyword_length] != 0) + ++keyword_length; + + /* TODO: make the keyword checking common */ + if (keyword_length >= 1 && keyword_length <= 79) + { + /* We only understand '0' compression - deflate - so if we get a + * different value we can't safely decode the chunk. + */ + if (keyword_length+1 < read_length && + keyword[keyword_length+1] == PNG_COMPRESSION_TYPE_BASE) + { + read_length -= keyword_length+2; + + if (png_inflate_claim(png_ptr, png_iCCP) == Z_OK) + { + Byte profile_header[132]; + Byte local_buffer[PNG_INFLATE_BUF_SIZE]; + png_alloc_size_t size = (sizeof profile_header); + + png_ptr->zstream.next_in = (Bytef*)keyword + (keyword_length+2); + png_ptr->zstream.avail_in = read_length; + (void)png_inflate_read(png_ptr, local_buffer, + (sizeof local_buffer), &length, profile_header, &size, + 0/*finish: don't, because the output is too small*/); + + if (size == 0) + { + /* We have the ICC profile header; do the basic header checks. + */ + const png_uint_32 profile_length = + png_get_uint_32(profile_header); + + if (png_icc_check_length(png_ptr, &png_ptr->colorspace, + keyword, profile_length)) + { + /* The length is apparently ok, so we can check the 132 + * byte header. + */ + if (png_icc_check_header(png_ptr, &png_ptr->colorspace, + keyword, profile_length, profile_header, + png_ptr->color_type)) + { + /* Now read the tag table; a variable size buffer is + * needed at this point, allocate one for the whole + * profile. The header check has already validated + * that none of these stuff will overflow. + */ + const png_uint_32 tag_count = png_get_uint_32( + profile_header+128); + png_bytep profile = png_read_buffer(png_ptr, + profile_length, 2/*silent*/); + + if (profile != NULL) + { + memcpy(profile, profile_header, + (sizeof profile_header)); + + size = 12 * tag_count; + + (void)png_inflate_read(png_ptr, local_buffer, + (sizeof local_buffer), &length, + profile + (sizeof profile_header), &size, 0); + + /* Still expect a a buffer error because we expect + * there to be some tag data! + */ + if (size == 0) + { + if (png_icc_check_tag_table(png_ptr, + &png_ptr->colorspace, keyword, profile_length, + profile)) + { + /* The profile has been validated for basic + * security issues, so read the whole thing in. + */ + size = profile_length - (sizeof profile_header) + - 12 * tag_count; + + (void)png_inflate_read(png_ptr, local_buffer, + (sizeof local_buffer), &length, + profile + (sizeof profile_header) + + 12 * tag_count, &size, 1/*finish*/); + + if (length > 0 && !(png_ptr->flags & + PNG_FLAG_BENIGN_ERRORS_WARN)) + errmsg = "extra compressed data"; + + /* But otherwise allow extra data: */ + else if (size == 0) + { + if (length > 0) + { + /* This can be handled completely, so + * keep going. + */ + png_chunk_warning(png_ptr, + "extra compressed data"); + } + + png_crc_finish(png_ptr, length); + finished = 1; + +# ifdef PNG_sRGB_SUPPORTED + /* Check for a match against sRGB */ + png_icc_set_sRGB(png_ptr, + &png_ptr->colorspace, profile, + png_ptr->zstream.adler); +# endif + + /* Steal the profile for info_ptr. */ + if (info_ptr != NULL) + { + png_free_data(png_ptr, info_ptr, + PNG_FREE_ICCP, 0); + + info_ptr->iccp_name = png_voidcast(char*, + png_malloc_base(png_ptr, + keyword_length+1)); + if (info_ptr->iccp_name != NULL) + { + memcpy(info_ptr->iccp_name, keyword, + keyword_length+1); + info_ptr->iccp_proflen = + profile_length; + info_ptr->iccp_profile = profile; + png_ptr->read_buffer = NULL; /*steal*/ + info_ptr->free_me |= PNG_FREE_ICCP; + info_ptr->valid |= PNG_INFO_iCCP; + } + + else + { + png_ptr->colorspace.flags |= + PNG_COLORSPACE_INVALID; + errmsg = "out of memory"; + } + } + + /* else the profile remains in the read + * buffer which gets reused for subsequent + * chunks. + */ + + if (info_ptr != NULL) + png_colorspace_sync(png_ptr, info_ptr); + + if (errmsg == NULL) + { + png_ptr->zowner = 0; + return; + } + } + + else if (size > 0) + errmsg = "truncated"; + + else + errmsg = png_ptr->zstream.msg; + } + + /* else png_icc_check_tag_table output an error */ + } + + else /* profile truncated */ + errmsg = png_ptr->zstream.msg; + } + + else + errmsg = "out of memory"; + } + + /* else png_icc_check_header output an error */ + } + + /* else png_icc_check_length output an error */ + } + + else /* profile truncated */ + errmsg = png_ptr->zstream.msg; + + /* Release the stream */ + png_ptr->zowner = 0; + } + + else /* png_inflate_claim failed */ + errmsg = png_ptr->zstream.msg; + } + + else + errmsg = "bad compression method"; /* or missing */ + } + + else + errmsg = "bad keyword"; + } + + else + errmsg = "too many profiles"; + + /* Failure: the reason is in 'errmsg' */ + if (!finished) + png_crc_finish(png_ptr, length); + + png_ptr->colorspace.flags |= PNG_COLORSPACE_INVALID; + png_colorspace_sync(png_ptr, info_ptr); + if (errmsg != NULL) /* else already output */ + png_chunk_benign_error(png_ptr, errmsg); +} +#endif /* PNG_READ_iCCP_SUPPORTED */ + +#ifdef PNG_READ_sPLT_SUPPORTED +void /* PRIVATE */ +png_handle_sPLT(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +/* Note: this does not properly handle chunks that are > 64K under DOS */ +{ + png_bytep entry_start, buffer; + png_sPLT_t new_palette; + png_sPLT_entryp pp; + png_uint_32 data_length; + int entry_size, i; + png_uint_32 skip = 0; + png_uint_32 dl; + png_size_t max_dl; + + png_debug(1, "in png_handle_sPLT"); + +#ifdef PNG_USER_LIMITS_SUPPORTED + if (png_ptr->user_chunk_cache_max != 0) + { + if (png_ptr->user_chunk_cache_max == 1) + { + png_crc_finish(png_ptr, length); + return; + } + + if (--png_ptr->user_chunk_cache_max == 1) + { + png_warning(png_ptr, "No space in chunk cache for sPLT"); + png_crc_finish(png_ptr, length); + return; + } + } +#endif + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & PNG_HAVE_IDAT) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + +#ifdef PNG_MAX_MALLOC_64K + if (length > 65535U) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "too large to fit in memory"); + return; + } +#endif + + buffer = png_read_buffer(png_ptr, length+1, 2/*silent*/); + if (buffer == NULL) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of memory"); + return; + } + + + /* WARNING: this may break if size_t is less than 32 bits; it is assumed + * that the PNG_MAX_MALLOC_64K test is enabled in this case, but this is a + * potential breakage point if the types in pngconf.h aren't exactly right. + */ + png_crc_read(png_ptr, buffer, length); + + if (png_crc_finish(png_ptr, skip)) + return; + + buffer[length] = 0; + + for (entry_start = buffer; *entry_start; entry_start++) + /* Empty loop to find end of name */ ; + + ++entry_start; + + /* A sample depth should follow the separator, and we should be on it */ + if (entry_start > buffer + length - 2) + { + png_warning(png_ptr, "malformed sPLT chunk"); + return; + } + + new_palette.depth = *entry_start++; + entry_size = (new_palette.depth == 8 ? 6 : 10); + /* This must fit in a png_uint_32 because it is derived from the original + * chunk data length. + */ + data_length = length - (png_uint_32)(entry_start - buffer); + + /* Integrity-check the data length */ + if (data_length % entry_size) + { + png_warning(png_ptr, "sPLT chunk has bad length"); + return; + } + + dl = (png_int_32)(data_length / entry_size); + max_dl = PNG_SIZE_MAX / (sizeof (png_sPLT_entry)); + + if (dl > max_dl) + { + png_warning(png_ptr, "sPLT chunk too long"); + return; + } + + new_palette.nentries = (png_int_32)(data_length / entry_size); + + new_palette.entries = (png_sPLT_entryp)png_malloc_warn( + png_ptr, new_palette.nentries * (sizeof (png_sPLT_entry))); + + if (new_palette.entries == NULL) + { + png_warning(png_ptr, "sPLT chunk requires too much memory"); + return; + } + +#ifdef PNG_POINTER_INDEXING_SUPPORTED + for (i = 0; i < new_palette.nentries; i++) + { + pp = new_palette.entries + i; + + if (new_palette.depth == 8) + { + pp->red = *entry_start++; + pp->green = *entry_start++; + pp->blue = *entry_start++; + pp->alpha = *entry_start++; + } + + else + { + pp->red = png_get_uint_16(entry_start); entry_start += 2; + pp->green = png_get_uint_16(entry_start); entry_start += 2; + pp->blue = png_get_uint_16(entry_start); entry_start += 2; + pp->alpha = png_get_uint_16(entry_start); entry_start += 2; + } + + pp->frequency = png_get_uint_16(entry_start); entry_start += 2; + } +#else + pp = new_palette.entries; + + for (i = 0; i < new_palette.nentries; i++) + { + + if (new_palette.depth == 8) + { + pp[i].red = *entry_start++; + pp[i].green = *entry_start++; + pp[i].blue = *entry_start++; + pp[i].alpha = *entry_start++; + } + + else + { + pp[i].red = png_get_uint_16(entry_start); entry_start += 2; + pp[i].green = png_get_uint_16(entry_start); entry_start += 2; + pp[i].blue = png_get_uint_16(entry_start); entry_start += 2; + pp[i].alpha = png_get_uint_16(entry_start); entry_start += 2; + } + + pp[i].frequency = png_get_uint_16(entry_start); entry_start += 2; + } +#endif + + /* Discard all chunk data except the name and stash that */ + new_palette.name = (png_charp)buffer; + + png_set_sPLT(png_ptr, info_ptr, &new_palette, 1); + + png_free(png_ptr, new_palette.entries); +} +#endif /* PNG_READ_sPLT_SUPPORTED */ + +#ifdef PNG_READ_tRNS_SUPPORTED +void /* PRIVATE */ +png_handle_tRNS(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_byte readbuf[PNG_MAX_PALETTE_LENGTH]; + + png_debug(1, "in png_handle_tRNS"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & PNG_HAVE_IDAT) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_tRNS)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + if (png_ptr->color_type == PNG_COLOR_TYPE_GRAY) + { + png_byte buf[2]; + + if (length != 2) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, buf, 2); + png_ptr->num_trans = 1; + png_ptr->trans_color.gray = png_get_uint_16(buf); + } + + else if (png_ptr->color_type == PNG_COLOR_TYPE_RGB) + { + png_byte buf[6]; + + if (length != 6) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, buf, length); + png_ptr->num_trans = 1; + png_ptr->trans_color.red = png_get_uint_16(buf); + png_ptr->trans_color.green = png_get_uint_16(buf + 2); + png_ptr->trans_color.blue = png_get_uint_16(buf + 4); + } + + else if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + { + if (!(png_ptr->mode & PNG_HAVE_PLTE)) + { + /* TODO: is this actually an error in the ISO spec? */ + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + if (length > png_ptr->num_palette || length > PNG_MAX_PALETTE_LENGTH || + length == 0) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, readbuf, length); + png_ptr->num_trans = (png_uint_16)length; + } + + else + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid with alpha channel"); + return; + } + + if (png_crc_finish(png_ptr, 0)) + { + png_ptr->num_trans = 0; + return; + } + + /* TODO: this is a horrible side effect in the palette case because the + * png_struct ends up with a pointer to the tRNS buffer owned by the + * png_info. Fix this. + */ + png_set_tRNS(png_ptr, info_ptr, readbuf, png_ptr->num_trans, + &(png_ptr->trans_color)); +} +#endif + +#ifdef PNG_READ_bKGD_SUPPORTED +void /* PRIVATE */ +png_handle_bKGD(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + unsigned int truelen; + png_byte buf[6]; + png_color_16 background; + + png_debug(1, "in png_handle_bKGD"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if ((png_ptr->mode & PNG_HAVE_IDAT) || + (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE && + !(png_ptr->mode & PNG_HAVE_PLTE))) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_bKGD)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + truelen = 1; + + else if (png_ptr->color_type & PNG_COLOR_MASK_COLOR) + truelen = 6; + + else + truelen = 2; + + if (length != truelen) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, buf, truelen); + + if (png_crc_finish(png_ptr, 0)) + return; + + /* We convert the index value into RGB components so that we can allow + * arbitrary RGB values for background when we have transparency, and + * so it is easy to determine the RGB values of the background color + * from the info_ptr struct. + */ + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + { + background.index = buf[0]; + + if (info_ptr && info_ptr->num_palette) + { + if (buf[0] >= info_ptr->num_palette) + { + png_chunk_benign_error(png_ptr, "invalid index"); + return; + } + + background.red = (png_uint_16)png_ptr->palette[buf[0]].red; + background.green = (png_uint_16)png_ptr->palette[buf[0]].green; + background.blue = (png_uint_16)png_ptr->palette[buf[0]].blue; + } + + else + background.red = background.green = background.blue = 0; + + background.gray = 0; + } + + else if (!(png_ptr->color_type & PNG_COLOR_MASK_COLOR)) /* GRAY */ + { + background.index = 0; + background.red = + background.green = + background.blue = + background.gray = png_get_uint_16(buf); + } + + else + { + background.index = 0; + background.red = png_get_uint_16(buf); + background.green = png_get_uint_16(buf + 2); + background.blue = png_get_uint_16(buf + 4); + background.gray = 0; + } + + png_set_bKGD(png_ptr, info_ptr, &background); +} +#endif + +#ifdef PNG_READ_hIST_SUPPORTED +void /* PRIVATE */ +png_handle_hIST(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + unsigned int num, i; + png_uint_16 readbuf[PNG_MAX_PALETTE_LENGTH]; + + png_debug(1, "in png_handle_hIST"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if ((png_ptr->mode & PNG_HAVE_IDAT) || !(png_ptr->mode & PNG_HAVE_PLTE)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_hIST)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + num = length / 2 ; + + if (num != png_ptr->num_palette || num > PNG_MAX_PALETTE_LENGTH) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + for (i = 0; i < num; i++) + { + png_byte buf[2]; + + png_crc_read(png_ptr, buf, 2); + readbuf[i] = png_get_uint_16(buf); + } + + if (png_crc_finish(png_ptr, 0)) + return; + + png_set_hIST(png_ptr, info_ptr, readbuf); +} +#endif + +#ifdef PNG_READ_pHYs_SUPPORTED +void /* PRIVATE */ +png_handle_pHYs(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_byte buf[9]; + png_uint_32 res_x, res_y; + int unit_type; + + png_debug(1, "in png_handle_pHYs"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & PNG_HAVE_IDAT) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_pHYs)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + if (length != 9) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, buf, 9); + + if (png_crc_finish(png_ptr, 0)) + return; + + res_x = png_get_uint_32(buf); + res_y = png_get_uint_32(buf + 4); + unit_type = buf[8]; + png_set_pHYs(png_ptr, info_ptr, res_x, res_y, unit_type); +} +#endif + +#ifdef PNG_READ_oFFs_SUPPORTED +void /* PRIVATE */ +png_handle_oFFs(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_byte buf[9]; + png_int_32 offset_x, offset_y; + int unit_type; + + png_debug(1, "in png_handle_oFFs"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & PNG_HAVE_IDAT) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_oFFs)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + if (length != 9) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, buf, 9); + + if (png_crc_finish(png_ptr, 0)) + return; + + offset_x = png_get_int_32(buf); + offset_y = png_get_int_32(buf + 4); + unit_type = buf[8]; + png_set_oFFs(png_ptr, info_ptr, offset_x, offset_y, unit_type); +} +#endif + +#ifdef PNG_READ_pCAL_SUPPORTED +/* Read the pCAL chunk (described in the PNG Extensions document) */ +void /* PRIVATE */ +png_handle_pCAL(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_int_32 X0, X1; + png_byte type, nparams; + png_bytep buffer, buf, units, endptr; + png_charpp params; + int i; + + png_debug(1, "in png_handle_pCAL"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & PNG_HAVE_IDAT) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_pCAL)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + png_debug1(2, "Allocating and reading pCAL chunk data (%u bytes)", + length + 1); + + buffer = png_read_buffer(png_ptr, length+1, 2/*silent*/); + + if (buffer == NULL) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of memory"); + return; + } + + png_crc_read(png_ptr, buffer, length); + + if (png_crc_finish(png_ptr, 0)) + return; + + buffer[length] = 0; /* Null terminate the last string */ + + png_debug(3, "Finding end of pCAL purpose string"); + for (buf = buffer; *buf; buf++) + /* Empty loop */ ; + + endptr = buffer + length; + + /* We need to have at least 12 bytes after the purpose string + * in order to get the parameter information. + */ + if (endptr <= buf + 12) + { + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_debug(3, "Reading pCAL X0, X1, type, nparams, and units"); + X0 = png_get_int_32((png_bytep)buf+1); + X1 = png_get_int_32((png_bytep)buf+5); + type = buf[9]; + nparams = buf[10]; + units = buf + 11; + + png_debug(3, "Checking pCAL equation type and number of parameters"); + /* Check that we have the right number of parameters for known + * equation types. + */ + if ((type == PNG_EQUATION_LINEAR && nparams != 2) || + (type == PNG_EQUATION_BASE_E && nparams != 3) || + (type == PNG_EQUATION_ARBITRARY && nparams != 3) || + (type == PNG_EQUATION_HYPERBOLIC && nparams != 4)) + { + png_chunk_benign_error(png_ptr, "invalid parameter count"); + return; + } + + else if (type >= PNG_EQUATION_LAST) + { + png_chunk_benign_error(png_ptr, "unrecognized equation type"); + } + + for (buf = units; *buf; buf++) + /* Empty loop to move past the units string. */ ; + + png_debug(3, "Allocating pCAL parameters array"); + + params = png_voidcast(png_charpp, png_malloc_warn(png_ptr, + nparams * (sizeof (png_charp)))); + + if (params == NULL) + { + png_chunk_benign_error(png_ptr, "out of memory"); + return; + } + + /* Get pointers to the start of each parameter string. */ + for (i = 0; i < nparams; i++) + { + buf++; /* Skip the null string terminator from previous parameter. */ + + png_debug1(3, "Reading pCAL parameter %d", i); + + for (params[i] = (png_charp)buf; buf <= endptr && *buf != 0; buf++) + /* Empty loop to move past each parameter string */ ; + + /* Make sure we haven't run out of data yet */ + if (buf > endptr) + { + png_free(png_ptr, params); + png_chunk_benign_error(png_ptr, "invalid data"); + return; + } + } + + png_set_pCAL(png_ptr, info_ptr, (png_charp)buffer, X0, X1, type, nparams, + (png_charp)units, params); + + png_free(png_ptr, params); +} +#endif + +#ifdef PNG_READ_sCAL_SUPPORTED +/* Read the sCAL chunk */ +void /* PRIVATE */ +png_handle_sCAL(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_bytep buffer; + png_size_t i; + int state; + + png_debug(1, "in png_handle_sCAL"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (png_ptr->mode & PNG_HAVE_IDAT) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of place"); + return; + } + + else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_sCAL)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + /* Need unit type, width, \0, height: minimum 4 bytes */ + else if (length < 4) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_debug1(2, "Allocating and reading sCAL chunk data (%u bytes)", + length + 1); + + buffer = png_read_buffer(png_ptr, length+1, 2/*silent*/); + + if (buffer == NULL) + { + png_chunk_benign_error(png_ptr, "out of memory"); + png_crc_finish(png_ptr, length); + return; + } + + png_crc_read(png_ptr, buffer, length); + buffer[length] = 0; /* Null terminate the last string */ + + if (png_crc_finish(png_ptr, 0)) + return; + + /* Validate the unit. */ + if (buffer[0] != 1 && buffer[0] != 2) + { + png_chunk_benign_error(png_ptr, "invalid unit"); + return; + } + + /* Validate the ASCII numbers, need two ASCII numbers separated by + * a '\0' and they need to fit exactly in the chunk data. + */ + i = 1; + state = 0; + + if (!png_check_fp_number((png_const_charp)buffer, length, &state, &i) || + i >= length || buffer[i++] != 0) + png_chunk_benign_error(png_ptr, "bad width format"); + + else if (!PNG_FP_IS_POSITIVE(state)) + png_chunk_benign_error(png_ptr, "non-positive width"); + + else + { + png_size_t heighti = i; + + state = 0; + if (!png_check_fp_number((png_const_charp)buffer, length, &state, &i) || + i != length) + png_chunk_benign_error(png_ptr, "bad height format"); + + else if (!PNG_FP_IS_POSITIVE(state)) + png_chunk_benign_error(png_ptr, "non-positive height"); + + else + /* This is the (only) success case. */ + png_set_sCAL_s(png_ptr, info_ptr, buffer[0], + (png_charp)buffer+1, (png_charp)buffer+heighti); + } +} +#endif + +#ifdef PNG_READ_tIME_SUPPORTED +void /* PRIVATE */ +png_handle_tIME(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_byte buf[7]; + png_time mod_time; + + png_debug(1, "in png_handle_tIME"); + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_tIME)) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "duplicate"); + return; + } + + if (png_ptr->mode & PNG_HAVE_IDAT) + png_ptr->mode |= PNG_AFTER_IDAT; + + if (length != 7) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "invalid"); + return; + } + + png_crc_read(png_ptr, buf, 7); + + if (png_crc_finish(png_ptr, 0)) + return; + + mod_time.second = buf[6]; + mod_time.minute = buf[5]; + mod_time.hour = buf[4]; + mod_time.day = buf[3]; + mod_time.month = buf[2]; + mod_time.year = png_get_uint_16(buf); + + png_set_tIME(png_ptr, info_ptr, &mod_time); +} +#endif + +#ifdef PNG_READ_tEXt_SUPPORTED +/* Note: this does not properly handle chunks that are > 64K under DOS */ +void /* PRIVATE */ +png_handle_tEXt(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_text text_info; + png_bytep buffer; + png_charp key; + png_charp text; + png_uint_32 skip = 0; + + png_debug(1, "in png_handle_tEXt"); + +#ifdef PNG_USER_LIMITS_SUPPORTED + if (png_ptr->user_chunk_cache_max != 0) + { + if (png_ptr->user_chunk_cache_max == 1) + { + png_crc_finish(png_ptr, length); + return; + } + + if (--png_ptr->user_chunk_cache_max == 1) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "no space in chunk cache"); + return; + } + } +#endif + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + if (png_ptr->mode & PNG_HAVE_IDAT) + png_ptr->mode |= PNG_AFTER_IDAT; + +#ifdef PNG_MAX_MALLOC_64K + if (length > 65535U) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "too large to fit in memory"); + return; + } +#endif + + buffer = png_read_buffer(png_ptr, length+1, 1/*warn*/); + + if (buffer == NULL) + { + png_chunk_benign_error(png_ptr, "out of memory"); + return; + } + + png_crc_read(png_ptr, buffer, length); + + if (png_crc_finish(png_ptr, skip)) + return; + + key = (png_charp)buffer; + key[length] = 0; + + for (text = key; *text; text++) + /* Empty loop to find end of key */ ; + + if (text != key + length) + text++; + + text_info.compression = PNG_TEXT_COMPRESSION_NONE; + text_info.key = key; + text_info.lang = NULL; + text_info.lang_key = NULL; + text_info.itxt_length = 0; + text_info.text = text; + text_info.text_length = strlen(text); + + if (png_set_text_2(png_ptr, info_ptr, &text_info, 1)) + png_warning(png_ptr, "Insufficient memory to process text chunk"); +} +#endif + +#ifdef PNG_READ_zTXt_SUPPORTED +/* Note: this does not correctly handle chunks that are > 64K under DOS */ +void /* PRIVATE */ +png_handle_zTXt(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_const_charp errmsg = NULL; + png_bytep buffer; + png_uint_32 keyword_length; + + png_debug(1, "in png_handle_zTXt"); + +#ifdef PNG_USER_LIMITS_SUPPORTED + if (png_ptr->user_chunk_cache_max != 0) + { + if (png_ptr->user_chunk_cache_max == 1) + { + png_crc_finish(png_ptr, length); + return; + } + + if (--png_ptr->user_chunk_cache_max == 1) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "no space in chunk cache"); + return; + } + } +#endif + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + if (png_ptr->mode & PNG_HAVE_IDAT) + png_ptr->mode |= PNG_AFTER_IDAT; + + buffer = png_read_buffer(png_ptr, length, 2/*silent*/); + + if (buffer == NULL) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of memory"); + return; + } + + png_crc_read(png_ptr, buffer, length); + + if (png_crc_finish(png_ptr, 0)) + return; + + /* TODO: also check that the keyword contents match the spec! */ + for (keyword_length = 0; + keyword_length < length && buffer[keyword_length] != 0; + ++keyword_length) + /* Empty loop to find end of name */ ; + + if (keyword_length > 79 || keyword_length < 1) + errmsg = "bad keyword"; + + /* zTXt must have some LZ data after the keyword, although it may expand to + * zero bytes; we need a '\0' at the end of the keyword, the compression type + * then the LZ data: + */ + else if (keyword_length + 3 > length) + errmsg = "truncated"; + + else if (buffer[keyword_length+1] != PNG_COMPRESSION_TYPE_BASE) + errmsg = "unknown compression type"; + + else + { + png_alloc_size_t uncompressed_length = PNG_SIZE_MAX; + + /* TODO: at present png_decompress_chunk imposes a single application + * level memory limit, this should be split to different values for iCCP + * and text chunks. + */ + if (png_decompress_chunk(png_ptr, length, keyword_length+2, + &uncompressed_length, 1/*terminate*/) == Z_STREAM_END) + { + png_text text; + + /* It worked; png_ptr->read_buffer now looks like a tEXt chunk except + * for the extra compression type byte and the fact that it isn't + * necessarily '\0' terminated. + */ + buffer = png_ptr->read_buffer; + buffer[uncompressed_length+(keyword_length+2)] = 0; + + text.compression = PNG_TEXT_COMPRESSION_zTXt; + text.key = (png_charp)buffer; + text.text = (png_charp)(buffer + keyword_length+2); + text.text_length = uncompressed_length; + text.itxt_length = 0; + text.lang = NULL; + text.lang_key = NULL; + + if (png_set_text_2(png_ptr, info_ptr, &text, 1)) + errmsg = "insufficient memory"; + } + + else + errmsg = png_ptr->zstream.msg; + } + + if (errmsg != NULL) + png_chunk_benign_error(png_ptr, errmsg); +} +#endif + +#ifdef PNG_READ_iTXt_SUPPORTED +/* Note: this does not correctly handle chunks that are > 64K under DOS */ +void /* PRIVATE */ +png_handle_iTXt(png_structrp png_ptr, png_inforp info_ptr, png_uint_32 length) +{ + png_const_charp errmsg = NULL; + png_bytep buffer; + png_uint_32 prefix_length; + + png_debug(1, "in png_handle_iTXt"); + +#ifdef PNG_USER_LIMITS_SUPPORTED + if (png_ptr->user_chunk_cache_max != 0) + { + if (png_ptr->user_chunk_cache_max == 1) + { + png_crc_finish(png_ptr, length); + return; + } + + if (--png_ptr->user_chunk_cache_max == 1) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "no space in chunk cache"); + return; + } + } +#endif + + if (!(png_ptr->mode & PNG_HAVE_IHDR)) + png_chunk_error(png_ptr, "missing IHDR"); + + if (png_ptr->mode & PNG_HAVE_IDAT) + png_ptr->mode |= PNG_AFTER_IDAT; + + buffer = png_read_buffer(png_ptr, length+1, 1/*warn*/); + + if (buffer == NULL) + { + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "out of memory"); + return; + } + + png_crc_read(png_ptr, buffer, length); + + if (png_crc_finish(png_ptr, 0)) + return; + + /* First the keyword. */ + for (prefix_length=0; + prefix_length < length && buffer[prefix_length] != 0; + ++prefix_length) + /* Empty loop */ ; + + /* Perform a basic check on the keyword length here. */ + if (prefix_length > 79 || prefix_length < 1) + errmsg = "bad keyword"; + + /* Expect keyword, compression flag, compression type, language, translated + * keyword (both may be empty but are 0 terminated) then the text, which may + * be empty. + */ + else if (prefix_length + 5 > length) + errmsg = "truncated"; + + else if (buffer[prefix_length+1] == 0 || + (buffer[prefix_length+1] == 1 && + buffer[prefix_length+2] == PNG_COMPRESSION_TYPE_BASE)) + { + int compressed = buffer[prefix_length+1] != 0; + png_uint_32 language_offset, translated_keyword_offset; + png_alloc_size_t uncompressed_length = 0; + + /* Now the language tag */ + prefix_length += 3; + language_offset = prefix_length; + + for (; prefix_length < length && buffer[prefix_length] != 0; + ++prefix_length) + /* Empty loop */ ; + + /* WARNING: the length may be invalid here, this is checked below. */ + translated_keyword_offset = ++prefix_length; + + for (; prefix_length < length && buffer[prefix_length] != 0; + ++prefix_length) + /* Empty loop */ ; + + /* prefix_length should now be at the trailing '\0' of the translated + * keyword, but it may already be over the end. None of this arithmetic + * can overflow because chunks are at most 2^31 bytes long, but on 16-bit + * systems the available allocaton may overflow. + */ + ++prefix_length; + + if (!compressed && prefix_length <= length) + uncompressed_length = length - prefix_length; + + else if (compressed && prefix_length < length) + { + uncompressed_length = PNG_SIZE_MAX; + + /* TODO: at present png_decompress_chunk imposes a single application + * level memory limit, this should be split to different values for + * iCCP and text chunks. + */ + if (png_decompress_chunk(png_ptr, length, prefix_length, + &uncompressed_length, 1/*terminate*/) == Z_STREAM_END) + buffer = png_ptr->read_buffer; + + else + errmsg = png_ptr->zstream.msg; + } + + else + errmsg = "truncated"; + + if (errmsg == NULL) + { + png_text text; + + buffer[uncompressed_length+prefix_length] = 0; + + if (compressed) + text.compression = PNG_ITXT_COMPRESSION_NONE; + + else + text.compression = PNG_ITXT_COMPRESSION_zTXt; + + text.key = (png_charp)buffer; + text.lang = (png_charp)buffer + language_offset; + text.lang_key = (png_charp)buffer + translated_keyword_offset; + text.text = (png_charp)buffer + prefix_length; + text.text_length = 0; + text.itxt_length = uncompressed_length; + + if (png_set_text_2(png_ptr, info_ptr, &text, 1)) + errmsg = "insufficient memory"; + } + } + + else + errmsg = "bad compression info"; + + if (errmsg != NULL) + png_chunk_benign_error(png_ptr, errmsg); +} +#endif + +#ifdef PNG_READ_UNKNOWN_CHUNKS_SUPPORTED +/* Utility function for png_handle_unknown; set up png_ptr::unknown_chunk */ +static int +png_cache_unknown_chunk(png_structrp png_ptr, png_uint_32 length) +{ + png_alloc_size_t limit = PNG_SIZE_MAX; + + if (png_ptr->unknown_chunk.data != NULL) + { + png_free(png_ptr, png_ptr->unknown_chunk.data); + png_ptr->unknown_chunk.data = NULL; + } + +# ifdef PNG_SET_CHUNK_MALLOC_LIMIT_SUPPORTED + if (png_ptr->user_chunk_malloc_max > 0 && + png_ptr->user_chunk_malloc_max < limit) + limit = png_ptr->user_chunk_malloc_max; + +# elif PNG_USER_CHUNK_MALLOC_MAX > 0 + if (PNG_USER_CHUNK_MALLOC_MAX < limit) + limit = PNG_USER_CHUNK_MALLOC_MAX; +# endif + + if (length <= limit) + { + PNG_CSTRING_FROM_CHUNK(png_ptr->unknown_chunk.name, png_ptr->chunk_name); + /* The following is safe because of the PNG_SIZE_MAX init above */ + png_ptr->unknown_chunk.size = (png_size_t)length/*SAFE*/; + /* 'mode' is a flag array, only the bottom four bits matter here */ + png_ptr->unknown_chunk.location = (png_byte)png_ptr->mode/*SAFE*/; + + if (length == 0) + png_ptr->unknown_chunk.data = NULL; + + else + { + /* Do a 'warn' here - it is handled below. */ + png_ptr->unknown_chunk.data = png_voidcast(png_bytep, + png_malloc_warn(png_ptr, length)); + } + } + + if (png_ptr->unknown_chunk.data == NULL && length > 0) + { + /* This is benign because we clean up correctly */ + png_crc_finish(png_ptr, length); + png_chunk_benign_error(png_ptr, "unknown chunk exceeds memory limits"); + return 0; + } + + else + { + if (length > 0) + png_crc_read(png_ptr, png_ptr->unknown_chunk.data, length); + png_crc_finish(png_ptr, 0); + return 1; + } +} +#endif /* PNG_READ_UNKNOWN_CHUNKS_SUPPORTED */ + +/* Handle an unknown, or known but disabled, chunk */ +void /* PRIVATE */ +png_handle_unknown(png_structrp png_ptr, png_inforp info_ptr, + png_uint_32 length, int keep) +{ + int handled = 0; /* the chunk was handled */ + + png_debug(1, "in png_handle_unknown"); + +#ifdef PNG_READ_UNKNOWN_CHUNKS_SUPPORTED + /* NOTE: this code is based on the code in libpng-1.4.12 except for fixing + * the bug which meant that setting a non-default behavior for a specific + * chunk would be ignored (the default was always used unless a user + * callback was installed). + * + * 'keep' is the value from the png_chunk_unknown_handling, the setting for + * this specific chunk_name, if PNG_HANDLE_AS_UNKNOWN_SUPPORTED, if not it + * will always be PNG_HANDLE_CHUNK_AS_DEFAULT and it needs to be set here. + * This is just an optimization to avoid multiple calls to the lookup + * function. + */ +# ifndef PNG_HANDLE_AS_UNKNOWN_SUPPORTED +# ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED + keep = png_chunk_unknown_handling(png_ptr, png_ptr->chunk_name); +# endif +# endif + + /* One of the following methods will read the chunk or skip it (at least one + * of these is always defined because this is the only way to switch on + * PNG_READ_UNKNOWN_CHUNKS_SUPPORTED) + */ +# ifdef PNG_READ_USER_CHUNKS_SUPPORTED + /* The user callback takes precedence over the chunk keep value, but the + * keep value is still required to validate a save of a critical chunk. + */ + if (png_ptr->read_user_chunk_fn != NULL) + { + if (png_cache_unknown_chunk(png_ptr, length)) + { + /* Callback to user unknown chunk handler */ + int ret = (*(png_ptr->read_user_chunk_fn))(png_ptr, + &png_ptr->unknown_chunk); + + /* ret is: + * negative: An error occured, png_chunk_error will be called. + * zero: The chunk was not handled, the chunk will be discarded + * unless png_set_keep_unknown_chunks has been used to set + * a 'keep' behavior for this particular chunk, in which + * case that will be used. A critical chunk will cause an + * error at this point unless it is to be saved. + * positive: The chunk was handled, libpng will ignore/discard it. + */ + if (ret < 0) + png_chunk_error(png_ptr, "error in user chunk"); + + else if (ret == 0) + { + /* If the keep value is 'default' or 'never' override it, but + * still error out on critical chunks unless the keep value is + * 'always' While this is weird it is the behavior in 1.4.12. + * A possible improvement would be to obey the value set for the + * chunk, but this would be an API change that would probably + * damage some applications. + * + * The png_app_warning below catches the case that matters, where + * the application has not set specific save or ignore for this + * chunk or global save or ignore. + */ + if (keep < PNG_HANDLE_CHUNK_IF_SAFE) + { +# ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED + if (png_ptr->unknown_default < PNG_HANDLE_CHUNK_IF_SAFE) + { + png_chunk_warning(png_ptr, "Saving unknown chunk:"); + png_app_warning(png_ptr, + "forcing save of an unhandled chunk;" + " please call png_set_keep_unknown_chunks"); + /* with keep = PNG_HANDLE_CHUNK_IF_SAFE */ + } +# endif + keep = PNG_HANDLE_CHUNK_IF_SAFE; + } + } + + else /* chunk was handled */ + { + handled = 1; + /* Critical chunks can be safely discarded at this point. */ + keep = PNG_HANDLE_CHUNK_NEVER; + } + } + + else + keep = PNG_HANDLE_CHUNK_NEVER; /* insufficient memory */ + } + + else + /* Use the SAVE_UNKNOWN_CHUNKS code or skip the chunk */ +# endif /* PNG_READ_USER_CHUNKS_SUPPORTED */ + +# ifdef PNG_SAVE_UNKNOWN_CHUNKS_SUPPORTED + { + /* keep is currently just the per-chunk setting, if there was no + * setting change it to the global default now (not that this may + * still be AS_DEFAULT) then obtain the cache of the chunk if required, + * if not simply skip the chunk. + */ + if (keep == PNG_HANDLE_CHUNK_AS_DEFAULT) + keep = png_ptr->unknown_default; + + if (keep == PNG_HANDLE_CHUNK_ALWAYS || + (keep == PNG_HANDLE_CHUNK_IF_SAFE && + PNG_CHUNK_ANCILLARY(png_ptr->chunk_name))) + { + if (!png_cache_unknown_chunk(png_ptr, length)) + keep = PNG_HANDLE_CHUNK_NEVER; + } + + else + png_crc_finish(png_ptr, length); + } +# else +# ifndef PNG_READ_USER_CHUNKS_SUPPORTED +# error no method to support READ_UNKNOWN_CHUNKS +# endif + + { + /* If here there is no read callback pointer set and no support is + * compiled in to just save the unknown chunks, so simply skip this + * chunk. If 'keep' is something other than AS_DEFAULT or NEVER then + * the app has erroneously asked for unknown chunk saving when there + * is no support. + */ + if (keep > PNG_HANDLE_CHUNK_NEVER) + png_app_error(png_ptr, "no unknown chunk support available"); + + png_crc_finish(png_ptr, length); + } +# endif + +# ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED + /* Now store the chunk in the chunk list if appropriate, and if the limits + * permit it. + */ + if (keep == PNG_HANDLE_CHUNK_ALWAYS || + (keep == PNG_HANDLE_CHUNK_IF_SAFE && + PNG_CHUNK_ANCILLARY(png_ptr->chunk_name))) + { +# ifdef PNG_USER_LIMITS_SUPPORTED + switch (png_ptr->user_chunk_cache_max) + { + case 2: + png_ptr->user_chunk_cache_max = 1; + png_chunk_benign_error(png_ptr, "no space in chunk cache"); + /* FALL THROUGH */ + case 1: + /* NOTE: prior to 1.6.0 this case resulted in an unknown critical + * chunk being skipped, now there will be a hard error below. + */ + break; + + default: /* not at limit */ + --(png_ptr->user_chunk_cache_max); + /* FALL THROUGH */ + case 0: /* no limit */ +# endif /* PNG_USER_LIMITS_SUPPORTED */ + /* Here when the limit isn't reached or when limits are compiled + * out; store the chunk. + */ + png_set_unknown_chunks(png_ptr, info_ptr, + &png_ptr->unknown_chunk, 1); + handled = 1; +# ifdef PNG_USER_LIMITS_SUPPORTED + break; + } +# endif + } +# else /* no store support: the chunk must be handled by the user callback */ + PNG_UNUSED(info_ptr) +# endif + + /* Regardless of the error handling below the cached data (if any) can be + * freed now. Notice that the data is not freed if there is a png_error, but + * it will be freed by destroy_read_struct. + */ + if (png_ptr->unknown_chunk.data != NULL) + png_free(png_ptr, png_ptr->unknown_chunk.data); + png_ptr->unknown_chunk.data = NULL; + +#else /* !PNG_READ_UNKNOWN_CHUNKS_SUPPORTED */ + /* There is no support to read an unknown chunk, so just skip it. */ + png_crc_finish(png_ptr, length); + PNG_UNUSED(info_ptr) + PNG_UNUSED(keep) +#endif /* !PNG_READ_UNKNOWN_CHUNKS_SUPPORTED */ + + /* Check for unhandled critical chunks */ + if (!handled && PNG_CHUNK_CRITICAL(png_ptr->chunk_name)) + png_chunk_error(png_ptr, "unhandled critical chunk"); +} + +/* This function is called to verify that a chunk name is valid. + * This function can't have the "critical chunk check" incorporated + * into it, since in the future we will need to be able to call user + * functions to handle unknown critical chunks after we check that + * the chunk name itself is valid. + */ + +/* Bit hacking: the test for an invalid byte in the 4 byte chunk name is: + * + * ((c) < 65 || (c) > 122 || ((c) > 90 && (c) < 97)) + */ + +void /* PRIVATE */ +png_check_chunk_name(png_structrp png_ptr, png_uint_32 chunk_name) +{ + int i; + + png_debug(1, "in png_check_chunk_name"); + + for (i=1; i<=4; ++i) + { + int c = chunk_name & 0xff; + + if (c < 65 || c > 122 || (c > 90 && c < 97)) + png_chunk_error(png_ptr, "invalid chunk type"); + + chunk_name >>= 8; + } +} + +/* Combines the row recently read in with the existing pixels in the row. This + * routine takes care of alpha and transparency if requested. This routine also + * handles the two methods of progressive display of interlaced images, + * depending on the 'display' value; if 'display' is true then the whole row + * (dp) is filled from the start by replicating the available pixels. If + * 'display' is false only those pixels present in the pass are filled in. + */ +void /* PRIVATE */ +png_combine_row(png_const_structrp png_ptr, png_bytep dp, int display) +{ + unsigned int pixel_depth = png_ptr->transformed_pixel_depth; + png_const_bytep sp = png_ptr->row_buf + 1; + png_uint_32 row_width = png_ptr->width; + unsigned int pass = png_ptr->pass; + png_bytep end_ptr = 0; + png_byte end_byte = 0; + unsigned int end_mask; + + png_debug(1, "in png_combine_row"); + + /* Added in 1.5.6: it should not be possible to enter this routine until at + * least one row has been read from the PNG data and transformed. + */ + if (pixel_depth == 0) + png_error(png_ptr, "internal row logic error"); + + /* Added in 1.5.4: the pixel depth should match the information returned by + * any call to png_read_update_info at this point. Do not continue if we got + * this wrong. + */ + if (png_ptr->info_rowbytes != 0 && png_ptr->info_rowbytes != + PNG_ROWBYTES(pixel_depth, row_width)) + png_error(png_ptr, "internal row size calculation error"); + + /* Don't expect this to ever happen: */ + if (row_width == 0) + png_error(png_ptr, "internal row width error"); + + /* Preserve the last byte in cases where only part of it will be overwritten, + * the multiply below may overflow, we don't care because ANSI-C guarantees + * we get the low bits. + */ + end_mask = (pixel_depth * row_width) & 7; + if (end_mask != 0) + { + /* end_ptr == NULL is a flag to say do nothing */ + end_ptr = dp + PNG_ROWBYTES(pixel_depth, row_width) - 1; + end_byte = *end_ptr; +# ifdef PNG_READ_PACKSWAP_SUPPORTED + if (png_ptr->transformations & PNG_PACKSWAP) /* little-endian byte */ + end_mask = 0xff << end_mask; + + else /* big-endian byte */ +# endif + end_mask = 0xff >> end_mask; + /* end_mask is now the bits to *keep* from the destination row */ + } + + /* For non-interlaced images this reduces to a memcpy(). A memcpy() + * will also happen if interlacing isn't supported or if the application + * does not call png_set_interlace_handling(). In the latter cases the + * caller just gets a sequence of the unexpanded rows from each interlace + * pass. + */ +#ifdef PNG_READ_INTERLACING_SUPPORTED + if (png_ptr->interlaced && (png_ptr->transformations & PNG_INTERLACE) && + pass < 6 && (display == 0 || + /* The following copies everything for 'display' on passes 0, 2 and 4. */ + (display == 1 && (pass & 1) != 0))) + { + /* Narrow images may have no bits in a pass; the caller should handle + * this, but this test is cheap: + */ + if (row_width <= PNG_PASS_START_COL(pass)) + return; + + if (pixel_depth < 8) + { + /* For pixel depths up to 4 bpp the 8-pixel mask can be expanded to fit + * into 32 bits, then a single loop over the bytes using the four byte + * values in the 32-bit mask can be used. For the 'display' option the + * expanded mask may also not require any masking within a byte. To + * make this work the PACKSWAP option must be taken into account - it + * simply requires the pixels to be reversed in each byte. + * + * The 'regular' case requires a mask for each of the first 6 passes, + * the 'display' case does a copy for the even passes in the range + * 0..6. This has already been handled in the test above. + * + * The masks are arranged as four bytes with the first byte to use in + * the lowest bits (little-endian) regardless of the order (PACKSWAP or + * not) of the pixels in each byte. + * + * NOTE: the whole of this logic depends on the caller of this function + * only calling it on rows appropriate to the pass. This function only + * understands the 'x' logic; the 'y' logic is handled by the caller. + * + * The following defines allow generation of compile time constant bit + * masks for each pixel depth and each possibility of swapped or not + * swapped bytes. Pass 'p' is in the range 0..6; 'x', a pixel index, + * is in the range 0..7; and the result is 1 if the pixel is to be + * copied in the pass, 0 if not. 'S' is for the sparkle method, 'B' + * for the block method. + * + * With some compilers a compile time expression of the general form: + * + * (shift >= 32) ? (a >> (shift-32)) : (b >> shift) + * + * Produces warnings with values of 'shift' in the range 33 to 63 + * because the right hand side of the ?: expression is evaluated by + * the compiler even though it isn't used. Microsoft Visual C (various + * versions) and the Intel C compiler are known to do this. To avoid + * this the following macros are used in 1.5.6. This is a temporary + * solution to avoid destabilizing the code during the release process. + */ +# if PNG_USE_COMPILE_TIME_MASKS +# define PNG_LSR(x,s) ((x)>>((s) & 0x1f)) +# define PNG_LSL(x,s) ((x)<<((s) & 0x1f)) +# else +# define PNG_LSR(x,s) ((x)>>(s)) +# define PNG_LSL(x,s) ((x)<<(s)) +# endif +# define S_COPY(p,x) (((p)<4 ? PNG_LSR(0x80088822,(3-(p))*8+(7-(x))) :\ + PNG_LSR(0xaa55ff00,(7-(p))*8+(7-(x)))) & 1) +# define B_COPY(p,x) (((p)<4 ? PNG_LSR(0xff0fff33,(3-(p))*8+(7-(x))) :\ + PNG_LSR(0xff55ff00,(7-(p))*8+(7-(x)))) & 1) + + /* Return a mask for pass 'p' pixel 'x' at depth 'd'. The mask is + * little endian - the first pixel is at bit 0 - however the extra + * parameter 's' can be set to cause the mask position to be swapped + * within each byte, to match the PNG format. This is done by XOR of + * the shift with 7, 6 or 4 for bit depths 1, 2 and 4. + */ +# define PIXEL_MASK(p,x,d,s) \ + (PNG_LSL(((PNG_LSL(1U,(d)))-1),(((x)*(d))^((s)?8-(d):0)))) + + /* Hence generate the appropriate 'block' or 'sparkle' pixel copy mask. + */ +# define S_MASKx(p,x,d,s) (S_COPY(p,x)?PIXEL_MASK(p,x,d,s):0) +# define B_MASKx(p,x,d,s) (B_COPY(p,x)?PIXEL_MASK(p,x,d,s):0) + + /* Combine 8 of these to get the full mask. For the 1-bpp and 2-bpp + * cases the result needs replicating, for the 4-bpp case the above + * generates a full 32 bits. + */ +# define MASK_EXPAND(m,d) ((m)*((d)==1?0x01010101:((d)==2?0x00010001:1))) + +# define S_MASK(p,d,s) MASK_EXPAND(S_MASKx(p,0,d,s) + S_MASKx(p,1,d,s) +\ + S_MASKx(p,2,d,s) + S_MASKx(p,3,d,s) + S_MASKx(p,4,d,s) +\ + S_MASKx(p,5,d,s) + S_MASKx(p,6,d,s) + S_MASKx(p,7,d,s), d) + +# define B_MASK(p,d,s) MASK_EXPAND(B_MASKx(p,0,d,s) + B_MASKx(p,1,d,s) +\ + B_MASKx(p,2,d,s) + B_MASKx(p,3,d,s) + B_MASKx(p,4,d,s) +\ + B_MASKx(p,5,d,s) + B_MASKx(p,6,d,s) + B_MASKx(p,7,d,s), d) + +#if PNG_USE_COMPILE_TIME_MASKS + /* Utility macros to construct all the masks for a depth/swap + * combination. The 's' parameter says whether the format is PNG + * (big endian bytes) or not. Only the three odd-numbered passes are + * required for the display/block algorithm. + */ +# define S_MASKS(d,s) { S_MASK(0,d,s), S_MASK(1,d,s), S_MASK(2,d,s),\ + S_MASK(3,d,s), S_MASK(4,d,s), S_MASK(5,d,s) } + +# define B_MASKS(d,s) { B_MASK(1,d,s), S_MASK(3,d,s), S_MASK(5,d,s) } + +# define DEPTH_INDEX(d) ((d)==1?0:((d)==2?1:2)) + + /* Hence the pre-compiled masks indexed by PACKSWAP (or not), depth and + * then pass: + */ + static PNG_CONST png_uint_32 row_mask[2/*PACKSWAP*/][3/*depth*/][6] = + { + /* Little-endian byte masks for PACKSWAP */ + { S_MASKS(1,0), S_MASKS(2,0), S_MASKS(4,0) }, + /* Normal (big-endian byte) masks - PNG format */ + { S_MASKS(1,1), S_MASKS(2,1), S_MASKS(4,1) } + }; + + /* display_mask has only three entries for the odd passes, so index by + * pass>>1. + */ + static PNG_CONST png_uint_32 display_mask[2][3][3] = + { + /* Little-endian byte masks for PACKSWAP */ + { B_MASKS(1,0), B_MASKS(2,0), B_MASKS(4,0) }, + /* Normal (big-endian byte) masks - PNG format */ + { B_MASKS(1,1), B_MASKS(2,1), B_MASKS(4,1) } + }; + +# define MASK(pass,depth,display,png)\ + ((display)?display_mask[png][DEPTH_INDEX(depth)][pass>>1]:\ + row_mask[png][DEPTH_INDEX(depth)][pass]) + +#else /* !PNG_USE_COMPILE_TIME_MASKS */ + /* This is the runtime alternative: it seems unlikely that this will + * ever be either smaller or faster than the compile time approach. + */ +# define MASK(pass,depth,display,png)\ + ((display)?B_MASK(pass,depth,png):S_MASK(pass,depth,png)) +#endif /* !PNG_USE_COMPILE_TIME_MASKS */ + + /* Use the appropriate mask to copy the required bits. In some cases + * the byte mask will be 0 or 0xff, optimize these cases. row_width is + * the number of pixels, but the code copies bytes, so it is necessary + * to special case the end. + */ + png_uint_32 pixels_per_byte = 8 / pixel_depth; + png_uint_32 mask; + +# ifdef PNG_READ_PACKSWAP_SUPPORTED + if (png_ptr->transformations & PNG_PACKSWAP) + mask = MASK(pass, pixel_depth, display, 0); + + else +# endif + mask = MASK(pass, pixel_depth, display, 1); + + for (;;) + { + png_uint_32 m; + + /* It doesn't matter in the following if png_uint_32 has more than + * 32 bits because the high bits always match those in m<<24; it is, + * however, essential to use OR here, not +, because of this. + */ + m = mask; + mask = (m >> 8) | (m << 24); /* rotate right to good compilers */ + m &= 0xff; + + if (m != 0) /* something to copy */ + { + if (m != 0xff) + *dp = (png_byte)((*dp & ~m) | (*sp & m)); + else + *dp = *sp; + } + + /* NOTE: this may overwrite the last byte with garbage if the image + * is not an exact number of bytes wide; libpng has always done + * this. + */ + if (row_width <= pixels_per_byte) + break; /* May need to restore part of the last byte */ + + row_width -= pixels_per_byte; + ++dp; + ++sp; + } + } + + else /* pixel_depth >= 8 */ + { + unsigned int bytes_to_copy, bytes_to_jump; + + /* Validate the depth - it must be a multiple of 8 */ + if (pixel_depth & 7) + png_error(png_ptr, "invalid user transform pixel depth"); + + pixel_depth >>= 3; /* now in bytes */ + row_width *= pixel_depth; + + /* Regardless of pass number the Adam 7 interlace always results in a + * fixed number of pixels to copy then to skip. There may be a + * different number of pixels to skip at the start though. + */ + { + unsigned int offset = PNG_PASS_START_COL(pass) * pixel_depth; + + row_width -= offset; + dp += offset; + sp += offset; + } + + /* Work out the bytes to copy. */ + if (display) + { + /* When doing the 'block' algorithm the pixel in the pass gets + * replicated to adjacent pixels. This is why the even (0,2,4,6) + * passes are skipped above - the entire expanded row is copied. + */ + bytes_to_copy = (1<<((6-pass)>>1)) * pixel_depth; + + /* But don't allow this number to exceed the actual row width. */ + if (bytes_to_copy > row_width) + bytes_to_copy = row_width; + } + + else /* normal row; Adam7 only ever gives us one pixel to copy. */ + bytes_to_copy = pixel_depth; + + /* In Adam7 there is a constant offset between where the pixels go. */ + bytes_to_jump = PNG_PASS_COL_OFFSET(pass) * pixel_depth; + + /* And simply copy these bytes. Some optimization is possible here, + * depending on the value of 'bytes_to_copy'. Special case the low + * byte counts, which we know to be frequent. + * + * Notice that these cases all 'return' rather than 'break' - this + * avoids an unnecessary test on whether to restore the last byte + * below. + */ + switch (bytes_to_copy) + { + case 1: + for (;;) + { + *dp = *sp; + + if (row_width <= bytes_to_jump) + return; + + dp += bytes_to_jump; + sp += bytes_to_jump; + row_width -= bytes_to_jump; + } + + case 2: + /* There is a possibility of a partial copy at the end here; this + * slows the code down somewhat. + */ + do + { + dp[0] = sp[0], dp[1] = sp[1]; + + if (row_width <= bytes_to_jump) + return; + + sp += bytes_to_jump; + dp += bytes_to_jump; + row_width -= bytes_to_jump; + } + while (row_width > 1); + + /* And there can only be one byte left at this point: */ + *dp = *sp; + return; + + case 3: + /* This can only be the RGB case, so each copy is exactly one + * pixel and it is not necessary to check for a partial copy. + */ + for(;;) + { + dp[0] = sp[0], dp[1] = sp[1], dp[2] = sp[2]; + + if (row_width <= bytes_to_jump) + return; + + sp += bytes_to_jump; + dp += bytes_to_jump; + row_width -= bytes_to_jump; + } + + default: +#if PNG_ALIGN_TYPE != PNG_ALIGN_NONE + /* Check for double byte alignment and, if possible, use a + * 16-bit copy. Don't attempt this for narrow images - ones that + * are less than an interlace panel wide. Don't attempt it for + * wide bytes_to_copy either - use the memcpy there. + */ + if (bytes_to_copy < 16 /*else use memcpy*/ && + png_isaligned(dp, png_uint_16) && + png_isaligned(sp, png_uint_16) && + bytes_to_copy % (sizeof (png_uint_16)) == 0 && + bytes_to_jump % (sizeof (png_uint_16)) == 0) + { + /* Everything is aligned for png_uint_16 copies, but try for + * png_uint_32 first. + */ + if (png_isaligned(dp, png_uint_32) && + png_isaligned(sp, png_uint_32) && + bytes_to_copy % (sizeof (png_uint_32)) == 0 && + bytes_to_jump % (sizeof (png_uint_32)) == 0) + { + png_uint_32p dp32 = png_aligncast(png_uint_32p,dp); + png_const_uint_32p sp32 = png_aligncastconst( + png_const_uint_32p, sp); + size_t skip = (bytes_to_jump-bytes_to_copy) / + (sizeof (png_uint_32)); + + do + { + size_t c = bytes_to_copy; + do + { + *dp32++ = *sp32++; + c -= (sizeof (png_uint_32)); + } + while (c > 0); + + if (row_width <= bytes_to_jump) + return; + + dp32 += skip; + sp32 += skip; + row_width -= bytes_to_jump; + } + while (bytes_to_copy <= row_width); + + /* Get to here when the row_width truncates the final copy. + * There will be 1-3 bytes left to copy, so don't try the + * 16-bit loop below. + */ + dp = (png_bytep)dp32; + sp = (png_const_bytep)sp32; + do + *dp++ = *sp++; + while (--row_width > 0); + return; + } + + /* Else do it in 16-bit quantities, but only if the size is + * not too large. + */ + else + { + png_uint_16p dp16 = png_aligncast(png_uint_16p, dp); + png_const_uint_16p sp16 = png_aligncastconst( + png_const_uint_16p, sp); + size_t skip = (bytes_to_jump-bytes_to_copy) / + (sizeof (png_uint_16)); + + do + { + size_t c = bytes_to_copy; + do + { + *dp16++ = *sp16++; + c -= (sizeof (png_uint_16)); + } + while (c > 0); + + if (row_width <= bytes_to_jump) + return; + + dp16 += skip; + sp16 += skip; + row_width -= bytes_to_jump; + } + while (bytes_to_copy <= row_width); + + /* End of row - 1 byte left, bytes_to_copy > row_width: */ + dp = (png_bytep)dp16; + sp = (png_const_bytep)sp16; + do + *dp++ = *sp++; + while (--row_width > 0); + return; + } + } +#endif /* PNG_ALIGN_ code */ + + /* The true default - use a memcpy: */ + for (;;) + { + memcpy(dp, sp, bytes_to_copy); + + if (row_width <= bytes_to_jump) + return; + + sp += bytes_to_jump; + dp += bytes_to_jump; + row_width -= bytes_to_jump; + if (bytes_to_copy > row_width) + bytes_to_copy = row_width; + } + } + + /* NOT REACHED*/ + } /* pixel_depth >= 8 */ + + /* Here if pixel_depth < 8 to check 'end_ptr' below. */ + } + else +#endif + + /* If here then the switch above wasn't used so just memcpy the whole row + * from the temporary row buffer (notice that this overwrites the end of the + * destination row if it is a partial byte.) + */ + memcpy(dp, sp, PNG_ROWBYTES(pixel_depth, row_width)); + + /* Restore the overwritten bits from the last byte if necessary. */ + if (end_ptr != NULL) + *end_ptr = (png_byte)((end_byte & end_mask) | (*end_ptr & ~end_mask)); +} + +#ifdef PNG_READ_INTERLACING_SUPPORTED +void /* PRIVATE */ +png_do_read_interlace(png_row_infop row_info, png_bytep row, int pass, + png_uint_32 transformations /* Because these may affect the byte layout */) +{ + /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */ + /* Offset to next interlace block */ + static PNG_CONST int png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1}; + + png_debug(1, "in png_do_read_interlace"); + if (row != NULL && row_info != NULL) + { + png_uint_32 final_width; + + final_width = row_info->width * png_pass_inc[pass]; + + switch (row_info->pixel_depth) + { + case 1: + { + png_bytep sp = row + (png_size_t)((row_info->width - 1) >> 3); + png_bytep dp = row + (png_size_t)((final_width - 1) >> 3); + int sshift, dshift; + int s_start, s_end, s_inc; + int jstop = png_pass_inc[pass]; + png_byte v; + png_uint_32 i; + int j; + +#ifdef PNG_READ_PACKSWAP_SUPPORTED + if (transformations & PNG_PACKSWAP) + { + sshift = (int)((row_info->width + 7) & 0x07); + dshift = (int)((final_width + 7) & 0x07); + s_start = 7; + s_end = 0; + s_inc = -1; + } + + else +#endif + { + sshift = 7 - (int)((row_info->width + 7) & 0x07); + dshift = 7 - (int)((final_width + 7) & 0x07); + s_start = 0; + s_end = 7; + s_inc = 1; + } + + for (i = 0; i < row_info->width; i++) + { + v = (png_byte)((*sp >> sshift) & 0x01); + for (j = 0; j < jstop; j++) + { + unsigned int tmp = *dp & (0x7f7f >> (7 - dshift)); + tmp |= v << dshift; + *dp = (png_byte)(tmp & 0xff); + + if (dshift == s_end) + { + dshift = s_start; + dp--; + } + + else + dshift += s_inc; + } + + if (sshift == s_end) + { + sshift = s_start; + sp--; + } + + else + sshift += s_inc; + } + break; + } + + case 2: + { + png_bytep sp = row + (png_uint_32)((row_info->width - 1) >> 2); + png_bytep dp = row + (png_uint_32)((final_width - 1) >> 2); + int sshift, dshift; + int s_start, s_end, s_inc; + int jstop = png_pass_inc[pass]; + png_uint_32 i; + +#ifdef PNG_READ_PACKSWAP_SUPPORTED + if (transformations & PNG_PACKSWAP) + { + sshift = (int)(((row_info->width + 3) & 0x03) << 1); + dshift = (int)(((final_width + 3) & 0x03) << 1); + s_start = 6; + s_end = 0; + s_inc = -2; + } + + else +#endif + { + sshift = (int)((3 - ((row_info->width + 3) & 0x03)) << 1); + dshift = (int)((3 - ((final_width + 3) & 0x03)) << 1); + s_start = 0; + s_end = 6; + s_inc = 2; + } + + for (i = 0; i < row_info->width; i++) + { + png_byte v; + int j; + + v = (png_byte)((*sp >> sshift) & 0x03); + for (j = 0; j < jstop; j++) + { + unsigned int tmp = *dp & (0x3f3f >> (6 - dshift)); + tmp |= v << dshift; + *dp = (png_byte)(tmp & 0xff); + + if (dshift == s_end) + { + dshift = s_start; + dp--; + } + + else + dshift += s_inc; + } + + if (sshift == s_end) + { + sshift = s_start; + sp--; + } + + else + sshift += s_inc; + } + break; + } + + case 4: + { + png_bytep sp = row + (png_size_t)((row_info->width - 1) >> 1); + png_bytep dp = row + (png_size_t)((final_width - 1) >> 1); + int sshift, dshift; + int s_start, s_end, s_inc; + png_uint_32 i; + int jstop = png_pass_inc[pass]; + +#ifdef PNG_READ_PACKSWAP_SUPPORTED + if (transformations & PNG_PACKSWAP) + { + sshift = (int)(((row_info->width + 1) & 0x01) << 2); + dshift = (int)(((final_width + 1) & 0x01) << 2); + s_start = 4; + s_end = 0; + s_inc = -4; + } + + else +#endif + { + sshift = (int)((1 - ((row_info->width + 1) & 0x01)) << 2); + dshift = (int)((1 - ((final_width + 1) & 0x01)) << 2); + s_start = 0; + s_end = 4; + s_inc = 4; + } + + for (i = 0; i < row_info->width; i++) + { + png_byte v = (png_byte)((*sp >> sshift) & 0x0f); + int j; + + for (j = 0; j < jstop; j++) + { + unsigned int tmp = *dp & (0xf0f >> (4 - dshift)); + tmp |= v << dshift; + *dp = (png_byte)(tmp & 0xff); + + if (dshift == s_end) + { + dshift = s_start; + dp--; + } + + else + dshift += s_inc; + } + + if (sshift == s_end) + { + sshift = s_start; + sp--; + } + + else + sshift += s_inc; + } + break; + } + + default: + { + png_size_t pixel_bytes = (row_info->pixel_depth >> 3); + + png_bytep sp = row + (png_size_t)(row_info->width - 1) + * pixel_bytes; + + png_bytep dp = row + (png_size_t)(final_width - 1) * pixel_bytes; + + int jstop = png_pass_inc[pass]; + png_uint_32 i; + + for (i = 0; i < row_info->width; i++) + { + png_byte v[8]; /* SAFE; pixel_depth does not exceed 64 */ + int j; + + memcpy(v, sp, pixel_bytes); + + for (j = 0; j < jstop; j++) + { + memcpy(dp, v, pixel_bytes); + dp -= pixel_bytes; + } + + sp -= pixel_bytes; + } + break; + } + } + + row_info->width = final_width; + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, final_width); + } +#ifndef PNG_READ_PACKSWAP_SUPPORTED + PNG_UNUSED(transformations) /* Silence compiler warning */ +#endif +} +#endif /* PNG_READ_INTERLACING_SUPPORTED */ + +static void +png_read_filter_row_sub(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_size_t i; + png_size_t istop = row_info->rowbytes; + unsigned int bpp = (row_info->pixel_depth + 7) >> 3; + png_bytep rp = row + bpp; + + PNG_UNUSED(prev_row) + + for (i = bpp; i < istop; i++) + { + *rp = (png_byte)(((int)(*rp) + (int)(*(rp-bpp))) & 0xff); + rp++; + } +} + +static void +png_read_filter_row_up(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_size_t i; + png_size_t istop = row_info->rowbytes; + png_bytep rp = row; + png_const_bytep pp = prev_row; + + for (i = 0; i < istop; i++) + { + *rp = (png_byte)(((int)(*rp) + (int)(*pp++)) & 0xff); + rp++; + } +} + +static void +png_read_filter_row_avg(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_size_t i; + png_bytep rp = row; + png_const_bytep pp = prev_row; + unsigned int bpp = (row_info->pixel_depth + 7) >> 3; + png_size_t istop = row_info->rowbytes - bpp; + + for (i = 0; i < bpp; i++) + { + *rp = (png_byte)(((int)(*rp) + + ((int)(*pp++) / 2 )) & 0xff); + + rp++; + } + + for (i = 0; i < istop; i++) + { + *rp = (png_byte)(((int)(*rp) + + (int)(*pp++ + *(rp-bpp)) / 2 ) & 0xff); + + rp++; + } +} + +static void +png_read_filter_row_paeth_1byte_pixel(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + png_bytep rp_end = row + row_info->rowbytes; + int a, c; + + /* First pixel/byte */ + c = *prev_row++; + a = *row + c; + *row++ = (png_byte)a; + + /* Remainder */ + while (row < rp_end) + { + int b, pa, pb, pc, p; + + a &= 0xff; /* From previous iteration or start */ + b = *prev_row++; + + p = b - c; + pc = a - c; + +# ifdef PNG_USE_ABS + pa = abs(p); + pb = abs(pc); + pc = abs(p + pc); +# else + pa = p < 0 ? -p : p; + pb = pc < 0 ? -pc : pc; + pc = (p + pc) < 0 ? -(p + pc) : p + pc; +# endif + + /* Find the best predictor, the least of pa, pb, pc favoring the earlier + * ones in the case of a tie. + */ + if (pb < pa) pa = pb, a = b; + if (pc < pa) a = c; + + /* Calculate the current pixel in a, and move the previous row pixel to c + * for the next time round the loop + */ + c = b; + a += *row; + *row++ = (png_byte)a; + } +} + +static void +png_read_filter_row_paeth_multibyte_pixel(png_row_infop row_info, png_bytep row, + png_const_bytep prev_row) +{ + int bpp = (row_info->pixel_depth + 7) >> 3; + png_bytep rp_end = row + bpp; + + /* Process the first pixel in the row completely (this is the same as 'up' + * because there is only one candidate predictor for the first row). + */ + while (row < rp_end) + { + int a = *row + *prev_row++; + *row++ = (png_byte)a; + } + + /* Remainder */ + rp_end += row_info->rowbytes - bpp; + + while (row < rp_end) + { + int a, b, c, pa, pb, pc, p; + + c = *(prev_row - bpp); + a = *(row - bpp); + b = *prev_row++; + + p = b - c; + pc = a - c; + +# ifdef PNG_USE_ABS + pa = abs(p); + pb = abs(pc); + pc = abs(p + pc); +# else + pa = p < 0 ? -p : p; + pb = pc < 0 ? -pc : pc; + pc = (p + pc) < 0 ? -(p + pc) : p + pc; +# endif + + if (pb < pa) pa = pb, a = b; + if (pc < pa) a = c; + + c = b; + a += *row; + *row++ = (png_byte)a; + } +} + +static void +png_init_filter_functions(png_structrp pp) + /* This function is called once for every PNG image (except for PNG images + * that only use PNG_FILTER_VALUE_NONE for all rows) to set the + * implementations required to reverse the filtering of PNG rows. Reversing + * the filter is the first transformation performed on the row data. It is + * performed in place, therefore an implementation can be selected based on + * the image pixel format. If the implementation depends on image width then + * take care to ensure that it works correctly if the image is interlaced - + * interlacing causes the actual row width to vary. + */ +{ + unsigned int bpp = (pp->pixel_depth + 7) >> 3; + + pp->read_filter[PNG_FILTER_VALUE_SUB-1] = png_read_filter_row_sub; + pp->read_filter[PNG_FILTER_VALUE_UP-1] = png_read_filter_row_up; + pp->read_filter[PNG_FILTER_VALUE_AVG-1] = png_read_filter_row_avg; + if (bpp == 1) + pp->read_filter[PNG_FILTER_VALUE_PAETH-1] = + png_read_filter_row_paeth_1byte_pixel; + else + pp->read_filter[PNG_FILTER_VALUE_PAETH-1] = + png_read_filter_row_paeth_multibyte_pixel; + +#ifdef PNG_FILTER_OPTIMIZATIONS + /* To use this define PNG_FILTER_OPTIMIZATIONS as the name of a function to + * call to install hardware optimizations for the above functions; simply + * replace whatever elements of the pp->read_filter[] array with a hardware + * specific (or, for that matter, generic) optimization. + * + * To see an example of this examine what configure.ac does when + * --enable-arm-neon is specified on the command line. + */ + PNG_FILTER_OPTIMIZATIONS(pp, bpp); +#endif +} + +void /* PRIVATE */ +png_read_filter_row(png_structrp pp, png_row_infop row_info, png_bytep row, + png_const_bytep prev_row, int filter) +{ + /* OPTIMIZATION: DO NOT MODIFY THIS FUNCTION, instead #define + * PNG_FILTER_OPTIMIZATIONS to a function that overrides the generic + * implementations. See png_init_filter_functions above. + */ + if (filter > PNG_FILTER_VALUE_NONE && filter < PNG_FILTER_VALUE_LAST) + { + if (pp->read_filter[0] == NULL) + png_init_filter_functions(pp); + + pp->read_filter[filter-1](row_info, row, prev_row); + } +} + +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED +void /* PRIVATE */ +png_read_IDAT_data(png_structrp png_ptr, png_bytep output, + png_alloc_size_t avail_out) +{ + /* Loop reading IDATs and decompressing the result into output[avail_out] */ + png_ptr->zstream.next_out = output; + png_ptr->zstream.avail_out = 0; /* safety: set below */ + + if (output == NULL) + avail_out = 0; + + do + { + int ret; + png_byte tmpbuf[PNG_INFLATE_BUF_SIZE]; + + if (png_ptr->zstream.avail_in == 0) + { + uInt avail_in; + png_bytep buffer; + + while (png_ptr->idat_size == 0) + { + png_crc_finish(png_ptr, 0); + + png_ptr->idat_size = png_read_chunk_header(png_ptr); + /* This is an error even in the 'check' case because the code just + * consumed a non-IDAT header. + */ + if (png_ptr->chunk_name != png_IDAT) + png_error(png_ptr, "Not enough image data"); + } + + avail_in = png_ptr->IDAT_read_size; + + if (avail_in > png_ptr->idat_size) + avail_in = (uInt)png_ptr->idat_size; + + /* A PNG with a gradually increasing IDAT size will defeat this attempt + * to minimize memory usage by causing lots of re-allocs, but + * realistically doing IDAT_read_size re-allocs is not likely to be a + * big problem. + */ + buffer = png_read_buffer(png_ptr, avail_in, 0/*error*/); + + png_crc_read(png_ptr, buffer, avail_in); + png_ptr->idat_size -= avail_in; + + png_ptr->zstream.next_in = buffer; + png_ptr->zstream.avail_in = avail_in; + } + + /* And set up the output side. */ + if (output != NULL) /* standard read */ + { + uInt out = ZLIB_IO_MAX; + + if (out > avail_out) + out = (uInt)avail_out; + + avail_out -= out; + png_ptr->zstream.avail_out = out; + } + + else /* after last row, checking for end */ + { + png_ptr->zstream.next_out = tmpbuf; + png_ptr->zstream.avail_out = (sizeof tmpbuf); + } + + /* Use NO_FLUSH; this gives zlib the maximum opportunity to optimize the + * process. If the LZ stream is truncated the sequential reader will + * terminally damage the stream, above, by reading the chunk header of the + * following chunk (it then exits with png_error). + * + * TODO: deal more elegantly with truncated IDAT lists. + */ + ret = inflate(&png_ptr->zstream, Z_NO_FLUSH); + + /* Take the unconsumed output back. */ + if (output != NULL) + avail_out += png_ptr->zstream.avail_out; + + else /* avail_out counts the extra bytes */ + avail_out += (sizeof tmpbuf) - png_ptr->zstream.avail_out; + + png_ptr->zstream.avail_out = 0; + + if (ret == Z_STREAM_END) + { + /* Do this for safety; we won't read any more into this row. */ + png_ptr->zstream.next_out = NULL; + + png_ptr->mode |= PNG_AFTER_IDAT; + png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED; + + if (png_ptr->zstream.avail_in > 0 || png_ptr->idat_size > 0) + png_chunk_benign_error(png_ptr, "Extra compressed data"); + break; + } + + if (ret != Z_OK) + { + png_zstream_error(png_ptr, ret); + + if (output != NULL) + png_chunk_error(png_ptr, png_ptr->zstream.msg); + + else /* checking */ + { + png_chunk_benign_error(png_ptr, png_ptr->zstream.msg); + return; + } + } + } while (avail_out > 0); + + if (avail_out > 0) + { + /* The stream ended before the image; this is the same as too few IDATs so + * should be handled the same way. + */ + if (output != NULL) + png_error(png_ptr, "Not enough image data"); + + else /* the deflate stream contained extra data */ + png_chunk_benign_error(png_ptr, "Too much image data"); + } +} + +void /* PRIVATE */ +png_read_finish_IDAT(png_structrp png_ptr) +{ + /* We don't need any more data and the stream should have ended, however the + * LZ end code may actually not have been processed. In this case we must + * read it otherwise stray unread IDAT data or, more likely, an IDAT chunk + * may still remain to be consumed. + */ + if (!(png_ptr->flags & PNG_FLAG_ZSTREAM_ENDED)) + { + /* The NULL causes png_read_IDAT_data to swallow any remaining bytes in + * the compressed stream, but the stream may be damaged too, so even after + * this call we may need to terminate the zstream ownership. + */ + png_read_IDAT_data(png_ptr, NULL, 0); + png_ptr->zstream.next_out = NULL; /* safety */ + + /* Now clear everything out for safety; the following may not have been + * done. + */ + if (!(png_ptr->flags & PNG_FLAG_ZSTREAM_ENDED)) + { + png_ptr->mode |= PNG_AFTER_IDAT; + png_ptr->flags |= PNG_FLAG_ZSTREAM_ENDED; + } + } + + /* If the zstream has not been released do it now *and* terminate the reading + * of the final IDAT chunk. + */ + if (png_ptr->zowner == png_IDAT) + { + /* Always do this; the pointers otherwise point into the read buffer. */ + png_ptr->zstream.next_in = NULL; + png_ptr->zstream.avail_in = 0; + + /* Now we no longer own the zstream. */ + png_ptr->zowner = 0; + + /* The slightly weird semantics of the sequential IDAT reading is that we + * are always in or at the end of an IDAT chunk, so we always need to do a + * crc_finish here. If idat_size is non-zero we also need to read the + * spurious bytes at the end of the chunk now. + */ + (void)png_crc_finish(png_ptr, png_ptr->idat_size); + } +} + +void /* PRIVATE */ +png_read_finish_row(png_structrp png_ptr) +{ +#ifdef PNG_READ_INTERLACING_SUPPORTED + /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */ + + /* Start of interlace block */ + static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0}; + + /* Offset to next interlace block */ + static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1}; + + /* Start of interlace block in the y direction */ + static PNG_CONST png_byte png_pass_ystart[7] = {0, 0, 4, 0, 2, 0, 1}; + + /* Offset to next interlace block in the y direction */ + static PNG_CONST png_byte png_pass_yinc[7] = {8, 8, 8, 4, 4, 2, 2}; +#endif /* PNG_READ_INTERLACING_SUPPORTED */ + + png_debug(1, "in png_read_finish_row"); + png_ptr->row_number++; + if (png_ptr->row_number < png_ptr->num_rows) + return; + +#ifdef PNG_READ_INTERLACING_SUPPORTED + if (png_ptr->interlaced) + { + png_ptr->row_number = 0; + + /* TO DO: don't do this if prev_row isn't needed (requires + * read-ahead of the next row's filter byte. + */ + memset(png_ptr->prev_row, 0, png_ptr->rowbytes + 1); + + do + { + png_ptr->pass++; + + if (png_ptr->pass >= 7) + break; + + png_ptr->iwidth = (png_ptr->width + + png_pass_inc[png_ptr->pass] - 1 - + png_pass_start[png_ptr->pass]) / + png_pass_inc[png_ptr->pass]; + + if (!(png_ptr->transformations & PNG_INTERLACE)) + { + png_ptr->num_rows = (png_ptr->height + + png_pass_yinc[png_ptr->pass] - 1 - + png_pass_ystart[png_ptr->pass]) / + png_pass_yinc[png_ptr->pass]; + } + + else /* if (png_ptr->transformations & PNG_INTERLACE) */ + break; /* libpng deinterlacing sees every row */ + + } while (png_ptr->num_rows == 0 || png_ptr->iwidth == 0); + + if (png_ptr->pass < 7) + return; + } +#endif /* PNG_READ_INTERLACING_SUPPORTED */ + + /* Here after at the end of the last row of the last pass. */ + png_read_finish_IDAT(png_ptr); +} +#endif /* PNG_SEQUENTIAL_READ_SUPPORTED */ + +void /* PRIVATE */ +png_read_start_row(png_structrp png_ptr) +{ +#ifdef PNG_READ_INTERLACING_SUPPORTED + /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */ + + /* Start of interlace block */ + static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0}; + + /* Offset to next interlace block */ + static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1}; + + /* Start of interlace block in the y direction */ + static PNG_CONST png_byte png_pass_ystart[7] = {0, 0, 4, 0, 2, 0, 1}; + + /* Offset to next interlace block in the y direction */ + static PNG_CONST png_byte png_pass_yinc[7] = {8, 8, 8, 4, 4, 2, 2}; +#endif + + int max_pixel_depth; + png_size_t row_bytes; + + png_debug(1, "in png_read_start_row"); + +#ifdef PNG_READ_TRANSFORMS_SUPPORTED + png_init_read_transformations(png_ptr); +#endif +#ifdef PNG_READ_INTERLACING_SUPPORTED + if (png_ptr->interlaced) + { + if (!(png_ptr->transformations & PNG_INTERLACE)) + png_ptr->num_rows = (png_ptr->height + png_pass_yinc[0] - 1 - + png_pass_ystart[0]) / png_pass_yinc[0]; + + else + png_ptr->num_rows = png_ptr->height; + + png_ptr->iwidth = (png_ptr->width + + png_pass_inc[png_ptr->pass] - 1 - + png_pass_start[png_ptr->pass]) / + png_pass_inc[png_ptr->pass]; + } + + else +#endif /* PNG_READ_INTERLACING_SUPPORTED */ + { + png_ptr->num_rows = png_ptr->height; + png_ptr->iwidth = png_ptr->width; + } + + max_pixel_depth = png_ptr->pixel_depth; + + /* WARNING: * png_read_transform_info (pngrtran.c) performs a simpliar set of + * calculations to calculate the final pixel depth, then + * png_do_read_transforms actually does the transforms. This means that the + * code which effectively calculates this value is actually repeated in three + * separate places. They must all match. Innocent changes to the order of + * transformations can and will break libpng in a way that causes memory + * overwrites. + * + * TODO: fix this. + */ +#ifdef PNG_READ_PACK_SUPPORTED + if ((png_ptr->transformations & PNG_PACK) && png_ptr->bit_depth < 8) + max_pixel_depth = 8; +#endif + +#ifdef PNG_READ_EXPAND_SUPPORTED + if (png_ptr->transformations & PNG_EXPAND) + { + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + { + if (png_ptr->num_trans) + max_pixel_depth = 32; + + else + max_pixel_depth = 24; + } + + else if (png_ptr->color_type == PNG_COLOR_TYPE_GRAY) + { + if (max_pixel_depth < 8) + max_pixel_depth = 8; + + if (png_ptr->num_trans) + max_pixel_depth *= 2; + } + + else if (png_ptr->color_type == PNG_COLOR_TYPE_RGB) + { + if (png_ptr->num_trans) + { + max_pixel_depth *= 4; + max_pixel_depth /= 3; + } + } + } +#endif + +#ifdef PNG_READ_EXPAND_16_SUPPORTED + if (png_ptr->transformations & PNG_EXPAND_16) + { +# ifdef PNG_READ_EXPAND_SUPPORTED + /* In fact it is an error if it isn't supported, but checking is + * the safe way. + */ + if (png_ptr->transformations & PNG_EXPAND) + { + if (png_ptr->bit_depth < 16) + max_pixel_depth *= 2; + } + else +# endif + png_ptr->transformations &= ~PNG_EXPAND_16; + } +#endif + +#ifdef PNG_READ_FILLER_SUPPORTED + if (png_ptr->transformations & (PNG_FILLER)) + { + if (png_ptr->color_type == PNG_COLOR_TYPE_GRAY) + { + if (max_pixel_depth <= 8) + max_pixel_depth = 16; + + else + max_pixel_depth = 32; + } + + else if (png_ptr->color_type == PNG_COLOR_TYPE_RGB || + png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + { + if (max_pixel_depth <= 32) + max_pixel_depth = 32; + + else + max_pixel_depth = 64; + } + } +#endif + +#ifdef PNG_READ_GRAY_TO_RGB_SUPPORTED + if (png_ptr->transformations & PNG_GRAY_TO_RGB) + { + if ( +#ifdef PNG_READ_EXPAND_SUPPORTED + (png_ptr->num_trans && (png_ptr->transformations & PNG_EXPAND)) || +#endif +#ifdef PNG_READ_FILLER_SUPPORTED + (png_ptr->transformations & (PNG_FILLER)) || +#endif + png_ptr->color_type == PNG_COLOR_TYPE_GRAY_ALPHA) + { + if (max_pixel_depth <= 16) + max_pixel_depth = 32; + + else + max_pixel_depth = 64; + } + + else + { + if (max_pixel_depth <= 8) + { + if (png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + max_pixel_depth = 32; + + else + max_pixel_depth = 24; + } + + else if (png_ptr->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + max_pixel_depth = 64; + + else + max_pixel_depth = 48; + } + } +#endif + +#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) && \ +defined(PNG_USER_TRANSFORM_PTR_SUPPORTED) + if (png_ptr->transformations & PNG_USER_TRANSFORM) + { + int user_pixel_depth = png_ptr->user_transform_depth * + png_ptr->user_transform_channels; + + if (user_pixel_depth > max_pixel_depth) + max_pixel_depth = user_pixel_depth; + } +#endif + + /* This value is stored in png_struct and double checked in the row read + * code. + */ + png_ptr->maximum_pixel_depth = (png_byte)max_pixel_depth; + png_ptr->transformed_pixel_depth = 0; /* calculated on demand */ + + /* Align the width on the next larger 8 pixels. Mainly used + * for interlacing + */ + row_bytes = ((png_ptr->width + 7) & ~((png_uint_32)7)); + /* Calculate the maximum bytes needed, adding a byte and a pixel + * for safety's sake + */ + row_bytes = PNG_ROWBYTES(max_pixel_depth, row_bytes) + + 1 + ((max_pixel_depth + 7) >> 3); + +#ifdef PNG_MAX_MALLOC_64K + if (row_bytes > (png_uint_32)65536L) + png_error(png_ptr, "This image requires a row greater than 64KB"); +#endif + + if (row_bytes + 48 > png_ptr->old_big_row_buf_size) + { + png_free(png_ptr, png_ptr->big_row_buf); + png_free(png_ptr, png_ptr->big_prev_row); + + if (png_ptr->interlaced) + png_ptr->big_row_buf = (png_bytep)png_calloc(png_ptr, + row_bytes + 48); + + else + png_ptr->big_row_buf = (png_bytep)png_malloc(png_ptr, row_bytes + 48); + + png_ptr->big_prev_row = (png_bytep)png_malloc(png_ptr, row_bytes + 48); + +#ifdef PNG_ALIGNED_MEMORY_SUPPORTED + /* Use 16-byte aligned memory for row_buf with at least 16 bytes + * of padding before and after row_buf; treat prev_row similarly. + * NOTE: the alignment is to the start of the pixels, one beyond the start + * of the buffer, because of the filter byte. Prior to libpng 1.5.6 this + * was incorrect; the filter byte was aligned, which had the exact + * opposite effect of that intended. + */ + { + png_bytep temp = png_ptr->big_row_buf + 32; + int extra = (int)((temp - (png_bytep)0) & 0x0f); + png_ptr->row_buf = temp - extra - 1/*filter byte*/; + + temp = png_ptr->big_prev_row + 32; + extra = (int)((temp - (png_bytep)0) & 0x0f); + png_ptr->prev_row = temp - extra - 1/*filter byte*/; + } + +#else + /* Use 31 bytes of padding before and 17 bytes after row_buf. */ + png_ptr->row_buf = png_ptr->big_row_buf + 31; + png_ptr->prev_row = png_ptr->big_prev_row + 31; +#endif + png_ptr->old_big_row_buf_size = row_bytes + 48; + } + +#ifdef PNG_MAX_MALLOC_64K + if (png_ptr->rowbytes > 65535) + png_error(png_ptr, "This image requires a row greater than 64KB"); + +#endif + if (png_ptr->rowbytes > (PNG_SIZE_MAX - 1)) + png_error(png_ptr, "Row has too many bytes to allocate in memory"); + + memset(png_ptr->prev_row, 0, png_ptr->rowbytes + 1); + + png_debug1(3, "width = %u,", png_ptr->width); + png_debug1(3, "height = %u,", png_ptr->height); + png_debug1(3, "iwidth = %u,", png_ptr->iwidth); + png_debug1(3, "num_rows = %u,", png_ptr->num_rows); + png_debug1(3, "rowbytes = %lu,", (unsigned long)png_ptr->rowbytes); + png_debug1(3, "irowbytes = %lu", + (unsigned long)PNG_ROWBYTES(png_ptr->pixel_depth, png_ptr->iwidth) + 1); + + /* The sequential reader needs a buffer for IDAT, but the progressive reader + * does not, so free the read buffer now regardless; the sequential reader + * reallocates it on demand. + */ + if (png_ptr->read_buffer) + { + png_bytep buffer = png_ptr->read_buffer; + + png_ptr->read_buffer_size = 0; + png_ptr->read_buffer = NULL; + png_free(png_ptr, buffer); + } + + /* Finally claim the zstream for the inflate of the IDAT data, use the bits + * value from the stream (note that this will result in a fatal error if the + * IDAT stream has a bogus deflate header window_bits value, but this should + * not be happening any longer!) + */ + if (png_inflate_claim(png_ptr, png_IDAT) != Z_OK) + png_error(png_ptr, png_ptr->zstream.msg); + + png_ptr->flags |= PNG_FLAG_ROW_INIT; +} +#endif /* PNG_READ_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngset.c b/ml/dlib/dlib/external/libpng/pngset.c new file mode 100644 index 000000000..7e355d1f4 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngset.c @@ -0,0 +1,1597 @@ + +/* pngset.c - storage of image information into info struct + * + * Last changed in libpng 1.6.3 [July 18, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + * The functions here are used during reads to store data from the file + * into the info struct, and during writes to store application data + * into the info struct for writing into the file. This abstracts the + * info struct and allows us to change the structure in the future. + */ + +#include "pngpriv.h" + +#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) + +#ifdef PNG_bKGD_SUPPORTED +void PNGAPI +png_set_bKGD(png_const_structrp png_ptr, png_inforp info_ptr, + png_const_color_16p background) +{ + png_debug1(1, "in %s storage function", "bKGD"); + + if (png_ptr == NULL || info_ptr == NULL || background == NULL) + return; + + info_ptr->background = *background; + info_ptr->valid |= PNG_INFO_bKGD; +} +#endif + +#ifdef PNG_cHRM_SUPPORTED +void PNGFAPI +png_set_cHRM_fixed(png_const_structrp png_ptr, png_inforp info_ptr, + png_fixed_point white_x, png_fixed_point white_y, png_fixed_point red_x, + png_fixed_point red_y, png_fixed_point green_x, png_fixed_point green_y, + png_fixed_point blue_x, png_fixed_point blue_y) +{ + png_xy xy; + + png_debug1(1, "in %s storage function", "cHRM fixed"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + xy.redx = red_x; + xy.redy = red_y; + xy.greenx = green_x; + xy.greeny = green_y; + xy.bluex = blue_x; + xy.bluey = blue_y; + xy.whitex = white_x; + xy.whitey = white_y; + + if (png_colorspace_set_chromaticities(png_ptr, &info_ptr->colorspace, &xy, + 2/* override with app values*/)) + info_ptr->colorspace.flags |= PNG_COLORSPACE_FROM_cHRM; + + png_colorspace_sync_info(png_ptr, info_ptr); +} + +void PNGFAPI +png_set_cHRM_XYZ_fixed(png_const_structrp png_ptr, png_inforp info_ptr, + png_fixed_point int_red_X, png_fixed_point int_red_Y, + png_fixed_point int_red_Z, png_fixed_point int_green_X, + png_fixed_point int_green_Y, png_fixed_point int_green_Z, + png_fixed_point int_blue_X, png_fixed_point int_blue_Y, + png_fixed_point int_blue_Z) +{ + png_XYZ XYZ; + + png_debug1(1, "in %s storage function", "cHRM XYZ fixed"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + XYZ.red_X = int_red_X; + XYZ.red_Y = int_red_Y; + XYZ.red_Z = int_red_Z; + XYZ.green_X = int_green_X; + XYZ.green_Y = int_green_Y; + XYZ.green_Z = int_green_Z; + XYZ.blue_X = int_blue_X; + XYZ.blue_Y = int_blue_Y; + XYZ.blue_Z = int_blue_Z; + + if (png_colorspace_set_endpoints(png_ptr, &info_ptr->colorspace, &XYZ, 2)) + info_ptr->colorspace.flags |= PNG_COLORSPACE_FROM_cHRM; + + png_colorspace_sync_info(png_ptr, info_ptr); +} + +# ifdef PNG_FLOATING_POINT_SUPPORTED +void PNGAPI +png_set_cHRM(png_const_structrp png_ptr, png_inforp info_ptr, + double white_x, double white_y, double red_x, double red_y, + double green_x, double green_y, double blue_x, double blue_y) +{ + png_set_cHRM_fixed(png_ptr, info_ptr, + png_fixed(png_ptr, white_x, "cHRM White X"), + png_fixed(png_ptr, white_y, "cHRM White Y"), + png_fixed(png_ptr, red_x, "cHRM Red X"), + png_fixed(png_ptr, red_y, "cHRM Red Y"), + png_fixed(png_ptr, green_x, "cHRM Green X"), + png_fixed(png_ptr, green_y, "cHRM Green Y"), + png_fixed(png_ptr, blue_x, "cHRM Blue X"), + png_fixed(png_ptr, blue_y, "cHRM Blue Y")); +} + +void PNGAPI +png_set_cHRM_XYZ(png_const_structrp png_ptr, png_inforp info_ptr, double red_X, + double red_Y, double red_Z, double green_X, double green_Y, double green_Z, + double blue_X, double blue_Y, double blue_Z) +{ + png_set_cHRM_XYZ_fixed(png_ptr, info_ptr, + png_fixed(png_ptr, red_X, "cHRM Red X"), + png_fixed(png_ptr, red_Y, "cHRM Red Y"), + png_fixed(png_ptr, red_Z, "cHRM Red Z"), + png_fixed(png_ptr, green_X, "cHRM Red X"), + png_fixed(png_ptr, green_Y, "cHRM Red Y"), + png_fixed(png_ptr, green_Z, "cHRM Red Z"), + png_fixed(png_ptr, blue_X, "cHRM Red X"), + png_fixed(png_ptr, blue_Y, "cHRM Red Y"), + png_fixed(png_ptr, blue_Z, "cHRM Red Z")); +} +# endif /* PNG_FLOATING_POINT_SUPPORTED */ + +#endif /* PNG_cHRM_SUPPORTED */ + +#ifdef PNG_gAMA_SUPPORTED +void PNGFAPI +png_set_gAMA_fixed(png_const_structrp png_ptr, png_inforp info_ptr, + png_fixed_point file_gamma) +{ + png_debug1(1, "in %s storage function", "gAMA"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + png_colorspace_set_gamma(png_ptr, &info_ptr->colorspace, file_gamma); + png_colorspace_sync_info(png_ptr, info_ptr); +} + +# ifdef PNG_FLOATING_POINT_SUPPORTED +void PNGAPI +png_set_gAMA(png_const_structrp png_ptr, png_inforp info_ptr, double file_gamma) +{ + png_set_gAMA_fixed(png_ptr, info_ptr, png_fixed(png_ptr, file_gamma, + "png_set_gAMA")); +} +# endif +#endif + +#ifdef PNG_hIST_SUPPORTED +void PNGAPI +png_set_hIST(png_const_structrp png_ptr, png_inforp info_ptr, + png_const_uint_16p hist) +{ + int i; + + png_debug1(1, "in %s storage function", "hIST"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + if (info_ptr->num_palette == 0 || info_ptr->num_palette + > PNG_MAX_PALETTE_LENGTH) + { + png_warning(png_ptr, + "Invalid palette size, hIST allocation skipped"); + + return; + } + + png_free_data(png_ptr, info_ptr, PNG_FREE_HIST, 0); + + /* Changed from info->num_palette to PNG_MAX_PALETTE_LENGTH in + * version 1.2.1 + */ + info_ptr->hist = png_voidcast(png_uint_16p, png_malloc_warn(png_ptr, + PNG_MAX_PALETTE_LENGTH * (sizeof (png_uint_16)))); + + if (info_ptr->hist == NULL) + { + png_warning(png_ptr, "Insufficient memory for hIST chunk data"); + return; + } + + info_ptr->free_me |= PNG_FREE_HIST; + + for (i = 0; i < info_ptr->num_palette; i++) + info_ptr->hist[i] = hist[i]; + + info_ptr->valid |= PNG_INFO_hIST; +} +#endif + +void PNGAPI +png_set_IHDR(png_const_structrp png_ptr, png_inforp info_ptr, + png_uint_32 width, png_uint_32 height, int bit_depth, + int color_type, int interlace_type, int compression_type, + int filter_type) +{ + png_debug1(1, "in %s storage function", "IHDR"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + info_ptr->width = width; + info_ptr->height = height; + info_ptr->bit_depth = (png_byte)bit_depth; + info_ptr->color_type = (png_byte)color_type; + info_ptr->compression_type = (png_byte)compression_type; + info_ptr->filter_type = (png_byte)filter_type; + info_ptr->interlace_type = (png_byte)interlace_type; + + png_check_IHDR (png_ptr, info_ptr->width, info_ptr->height, + info_ptr->bit_depth, info_ptr->color_type, info_ptr->interlace_type, + info_ptr->compression_type, info_ptr->filter_type); + + if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + info_ptr->channels = 1; + + else if (info_ptr->color_type & PNG_COLOR_MASK_COLOR) + info_ptr->channels = 3; + + else + info_ptr->channels = 1; + + if (info_ptr->color_type & PNG_COLOR_MASK_ALPHA) + info_ptr->channels++; + + info_ptr->pixel_depth = (png_byte)(info_ptr->channels * info_ptr->bit_depth); + + info_ptr->rowbytes = PNG_ROWBYTES(info_ptr->pixel_depth, width); +} + +#ifdef PNG_oFFs_SUPPORTED +void PNGAPI +png_set_oFFs(png_const_structrp png_ptr, png_inforp info_ptr, + png_int_32 offset_x, png_int_32 offset_y, int unit_type) +{ + png_debug1(1, "in %s storage function", "oFFs"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + info_ptr->x_offset = offset_x; + info_ptr->y_offset = offset_y; + info_ptr->offset_unit_type = (png_byte)unit_type; + info_ptr->valid |= PNG_INFO_oFFs; +} +#endif + +#ifdef PNG_pCAL_SUPPORTED +void PNGAPI +png_set_pCAL(png_const_structrp png_ptr, png_inforp info_ptr, + png_const_charp purpose, png_int_32 X0, png_int_32 X1, int type, + int nparams, png_const_charp units, png_charpp params) +{ + png_size_t length; + int i; + + png_debug1(1, "in %s storage function", "pCAL"); + + if (png_ptr == NULL || info_ptr == NULL || purpose == NULL || units == NULL + || (nparams > 0 && params == NULL)) + return; + + length = strlen(purpose) + 1; + png_debug1(3, "allocating purpose for info (%lu bytes)", + (unsigned long)length); + + /* TODO: validate format of calibration name and unit name */ + + /* Check that the type matches the specification. */ + if (type < 0 || type > 3) + png_error(png_ptr, "Invalid pCAL equation type"); + + if (nparams < 0 || nparams > 255) + png_error(png_ptr, "Invalid pCAL parameter count"); + + /* Validate params[nparams] */ + for (i=0; ipcal_purpose = png_voidcast(png_charp, + png_malloc_warn(png_ptr, length)); + + if (info_ptr->pcal_purpose == NULL) + { + png_warning(png_ptr, "Insufficient memory for pCAL purpose"); + return; + } + + memcpy(info_ptr->pcal_purpose, purpose, length); + + png_debug(3, "storing X0, X1, type, and nparams in info"); + info_ptr->pcal_X0 = X0; + info_ptr->pcal_X1 = X1; + info_ptr->pcal_type = (png_byte)type; + info_ptr->pcal_nparams = (png_byte)nparams; + + length = strlen(units) + 1; + png_debug1(3, "allocating units for info (%lu bytes)", + (unsigned long)length); + + info_ptr->pcal_units = png_voidcast(png_charp, + png_malloc_warn(png_ptr, length)); + + if (info_ptr->pcal_units == NULL) + { + png_warning(png_ptr, "Insufficient memory for pCAL units"); + return; + } + + memcpy(info_ptr->pcal_units, units, length); + + info_ptr->pcal_params = png_voidcast(png_charpp, png_malloc_warn(png_ptr, + (png_size_t)((nparams + 1) * (sizeof (png_charp))))); + + if (info_ptr->pcal_params == NULL) + { + png_warning(png_ptr, "Insufficient memory for pCAL params"); + return; + } + + memset(info_ptr->pcal_params, 0, (nparams + 1) * (sizeof (png_charp))); + + for (i = 0; i < nparams; i++) + { + length = strlen(params[i]) + 1; + png_debug2(3, "allocating parameter %d for info (%lu bytes)", i, + (unsigned long)length); + + info_ptr->pcal_params[i] = (png_charp)png_malloc_warn(png_ptr, length); + + if (info_ptr->pcal_params[i] == NULL) + { + png_warning(png_ptr, "Insufficient memory for pCAL parameter"); + return; + } + + memcpy(info_ptr->pcal_params[i], params[i], length); + } + + info_ptr->valid |= PNG_INFO_pCAL; + info_ptr->free_me |= PNG_FREE_PCAL; +} +#endif + +#ifdef PNG_sCAL_SUPPORTED +void PNGAPI +png_set_sCAL_s(png_const_structrp png_ptr, png_inforp info_ptr, + int unit, png_const_charp swidth, png_const_charp sheight) +{ + png_size_t lengthw = 0, lengthh = 0; + + png_debug1(1, "in %s storage function", "sCAL"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + /* Double check the unit (should never get here with an invalid + * unit unless this is an API call.) + */ + if (unit != 1 && unit != 2) + png_error(png_ptr, "Invalid sCAL unit"); + + if (swidth == NULL || (lengthw = strlen(swidth)) == 0 || + swidth[0] == 45 /* '-' */ || !png_check_fp_string(swidth, lengthw)) + png_error(png_ptr, "Invalid sCAL width"); + + if (sheight == NULL || (lengthh = strlen(sheight)) == 0 || + sheight[0] == 45 /* '-' */ || !png_check_fp_string(sheight, lengthh)) + png_error(png_ptr, "Invalid sCAL height"); + + info_ptr->scal_unit = (png_byte)unit; + + ++lengthw; + + png_debug1(3, "allocating unit for info (%u bytes)", (unsigned int)lengthw); + + info_ptr->scal_s_width = png_voidcast(png_charp, + png_malloc_warn(png_ptr, lengthw)); + + if (info_ptr->scal_s_width == NULL) + { + png_warning(png_ptr, "Memory allocation failed while processing sCAL"); + return; + } + + memcpy(info_ptr->scal_s_width, swidth, lengthw); + + ++lengthh; + + png_debug1(3, "allocating unit for info (%u bytes)", (unsigned int)lengthh); + + info_ptr->scal_s_height = png_voidcast(png_charp, + png_malloc_warn(png_ptr, lengthh)); + + if (info_ptr->scal_s_height == NULL) + { + png_free (png_ptr, info_ptr->scal_s_width); + info_ptr->scal_s_width = NULL; + + png_warning(png_ptr, "Memory allocation failed while processing sCAL"); + return; + } + + memcpy(info_ptr->scal_s_height, sheight, lengthh); + + info_ptr->valid |= PNG_INFO_sCAL; + info_ptr->free_me |= PNG_FREE_SCAL; +} + +# ifdef PNG_FLOATING_POINT_SUPPORTED +void PNGAPI +png_set_sCAL(png_const_structrp png_ptr, png_inforp info_ptr, int unit, + double width, double height) +{ + png_debug1(1, "in %s storage function", "sCAL"); + + /* Check the arguments. */ + if (width <= 0) + png_warning(png_ptr, "Invalid sCAL width ignored"); + + else if (height <= 0) + png_warning(png_ptr, "Invalid sCAL height ignored"); + + else + { + /* Convert 'width' and 'height' to ASCII. */ + char swidth[PNG_sCAL_MAX_DIGITS+1]; + char sheight[PNG_sCAL_MAX_DIGITS+1]; + + png_ascii_from_fp(png_ptr, swidth, (sizeof swidth), width, + PNG_sCAL_PRECISION); + png_ascii_from_fp(png_ptr, sheight, (sizeof sheight), height, + PNG_sCAL_PRECISION); + + png_set_sCAL_s(png_ptr, info_ptr, unit, swidth, sheight); + } +} +# endif + +# ifdef PNG_FIXED_POINT_SUPPORTED +void PNGAPI +png_set_sCAL_fixed(png_const_structrp png_ptr, png_inforp info_ptr, int unit, + png_fixed_point width, png_fixed_point height) +{ + png_debug1(1, "in %s storage function", "sCAL"); + + /* Check the arguments. */ + if (width <= 0) + png_warning(png_ptr, "Invalid sCAL width ignored"); + + else if (height <= 0) + png_warning(png_ptr, "Invalid sCAL height ignored"); + + else + { + /* Convert 'width' and 'height' to ASCII. */ + char swidth[PNG_sCAL_MAX_DIGITS+1]; + char sheight[PNG_sCAL_MAX_DIGITS+1]; + + png_ascii_from_fixed(png_ptr, swidth, (sizeof swidth), width); + png_ascii_from_fixed(png_ptr, sheight, (sizeof sheight), height); + + png_set_sCAL_s(png_ptr, info_ptr, unit, swidth, sheight); + } +} +# endif +#endif + +#ifdef PNG_pHYs_SUPPORTED +void PNGAPI +png_set_pHYs(png_const_structrp png_ptr, png_inforp info_ptr, + png_uint_32 res_x, png_uint_32 res_y, int unit_type) +{ + png_debug1(1, "in %s storage function", "pHYs"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + info_ptr->x_pixels_per_unit = res_x; + info_ptr->y_pixels_per_unit = res_y; + info_ptr->phys_unit_type = (png_byte)unit_type; + info_ptr->valid |= PNG_INFO_pHYs; +} +#endif + +void PNGAPI +png_set_PLTE(png_structrp png_ptr, png_inforp info_ptr, + png_const_colorp palette, int num_palette) +{ + + png_debug1(1, "in %s storage function", "PLTE"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + if (num_palette < 0 || num_palette > PNG_MAX_PALETTE_LENGTH) + { + if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + png_error(png_ptr, "Invalid palette length"); + + else + { + png_warning(png_ptr, "Invalid palette length"); + return; + } + } + + if ((num_palette > 0 && palette == NULL) || + (num_palette == 0 +# ifdef PNG_MNG_FEATURES_SUPPORTED + && (png_ptr->mng_features_permitted & PNG_FLAG_MNG_EMPTY_PLTE) == 0 +# endif + )) + { + png_chunk_report(png_ptr, "Invalid palette", PNG_CHUNK_ERROR); + return; + } + + /* It may not actually be necessary to set png_ptr->palette here; + * we do it for backward compatibility with the way the png_handle_tRNS + * function used to do the allocation. + * + * 1.6.0: the above statement appears to be incorrect; something has to set + * the palette inside png_struct on read. + */ + png_free_data(png_ptr, info_ptr, PNG_FREE_PLTE, 0); + + /* Changed in libpng-1.2.1 to allocate PNG_MAX_PALETTE_LENGTH instead + * of num_palette entries, in case of an invalid PNG file that has + * too-large sample values. + */ + png_ptr->palette = png_voidcast(png_colorp, png_calloc(png_ptr, + PNG_MAX_PALETTE_LENGTH * (sizeof (png_color)))); + + if (num_palette > 0) + memcpy(png_ptr->palette, palette, num_palette * (sizeof (png_color))); + info_ptr->palette = png_ptr->palette; + info_ptr->num_palette = png_ptr->num_palette = (png_uint_16)num_palette; + + info_ptr->free_me |= PNG_FREE_PLTE; + + info_ptr->valid |= PNG_INFO_PLTE; +} + +#ifdef PNG_sBIT_SUPPORTED +void PNGAPI +png_set_sBIT(png_const_structrp png_ptr, png_inforp info_ptr, + png_const_color_8p sig_bit) +{ + png_debug1(1, "in %s storage function", "sBIT"); + + if (png_ptr == NULL || info_ptr == NULL || sig_bit == NULL) + return; + + info_ptr->sig_bit = *sig_bit; + info_ptr->valid |= PNG_INFO_sBIT; +} +#endif + +#ifdef PNG_sRGB_SUPPORTED +void PNGAPI +png_set_sRGB(png_const_structrp png_ptr, png_inforp info_ptr, int srgb_intent) +{ + png_debug1(1, "in %s storage function", "sRGB"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + (void)png_colorspace_set_sRGB(png_ptr, &info_ptr->colorspace, srgb_intent); + png_colorspace_sync_info(png_ptr, info_ptr); +} + +void PNGAPI +png_set_sRGB_gAMA_and_cHRM(png_const_structrp png_ptr, png_inforp info_ptr, + int srgb_intent) +{ + png_debug1(1, "in %s storage function", "sRGB_gAMA_and_cHRM"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + if (png_colorspace_set_sRGB(png_ptr, &info_ptr->colorspace, srgb_intent)) + { + /* This causes the gAMA and cHRM to be written too */ + info_ptr->colorspace.flags |= + PNG_COLORSPACE_FROM_gAMA|PNG_COLORSPACE_FROM_cHRM; + } + + png_colorspace_sync_info(png_ptr, info_ptr); +} +#endif /* sRGB */ + + +#ifdef PNG_iCCP_SUPPORTED +void PNGAPI +png_set_iCCP(png_const_structrp png_ptr, png_inforp info_ptr, + png_const_charp name, int compression_type, + png_const_bytep profile, png_uint_32 proflen) +{ + png_charp new_iccp_name; + png_bytep new_iccp_profile; + png_size_t length; + + png_debug1(1, "in %s storage function", "iCCP"); + + if (png_ptr == NULL || info_ptr == NULL || name == NULL || profile == NULL) + return; + + if (compression_type != PNG_COMPRESSION_TYPE_BASE) + png_app_error(png_ptr, "Invalid iCCP compression method"); + + /* Set the colorspace first because this validates the profile; do not + * override previously set app cHRM or gAMA here (because likely as not the + * application knows better than libpng what the correct values are.) Pass + * the info_ptr color_type field to png_colorspace_set_ICC because in the + * write case it has not yet been stored in png_ptr. + */ + { + int result = png_colorspace_set_ICC(png_ptr, &info_ptr->colorspace, name, + proflen, profile, info_ptr->color_type); + + png_colorspace_sync_info(png_ptr, info_ptr); + + /* Don't do any of the copying if the profile was bad, or inconsistent. */ + if (!result) + return; + + /* But do write the gAMA and cHRM chunks from the profile. */ + info_ptr->colorspace.flags |= + PNG_COLORSPACE_FROM_gAMA|PNG_COLORSPACE_FROM_cHRM; + } + + length = strlen(name)+1; + new_iccp_name = png_voidcast(png_charp, png_malloc_warn(png_ptr, length)); + + if (new_iccp_name == NULL) + { + png_benign_error(png_ptr, "Insufficient memory to process iCCP chunk"); + return; + } + + memcpy(new_iccp_name, name, length); + new_iccp_profile = png_voidcast(png_bytep, + png_malloc_warn(png_ptr, proflen)); + + if (new_iccp_profile == NULL) + { + png_free(png_ptr, new_iccp_name); + png_benign_error(png_ptr, + "Insufficient memory to process iCCP profile"); + return; + } + + memcpy(new_iccp_profile, profile, proflen); + + png_free_data(png_ptr, info_ptr, PNG_FREE_ICCP, 0); + + info_ptr->iccp_proflen = proflen; + info_ptr->iccp_name = new_iccp_name; + info_ptr->iccp_profile = new_iccp_profile; + info_ptr->free_me |= PNG_FREE_ICCP; + info_ptr->valid |= PNG_INFO_iCCP; +} +#endif + +#ifdef PNG_TEXT_SUPPORTED +void PNGAPI +png_set_text(png_const_structrp png_ptr, png_inforp info_ptr, + png_const_textp text_ptr, int num_text) +{ + int ret; + ret = png_set_text_2(png_ptr, info_ptr, text_ptr, num_text); + + if (ret) + png_error(png_ptr, "Insufficient memory to store text"); +} + +int /* PRIVATE */ +png_set_text_2(png_const_structrp png_ptr, png_inforp info_ptr, + png_const_textp text_ptr, int num_text) +{ + int i; + + png_debug1(1, "in %lx storage function", png_ptr == NULL ? "unexpected" : + (unsigned long)png_ptr->chunk_name); + + if (png_ptr == NULL || info_ptr == NULL || num_text <= 0 || text_ptr == NULL) + return(0); + + /* Make sure we have enough space in the "text" array in info_struct + * to hold all of the incoming text_ptr objects. This compare can't overflow + * because max_text >= num_text (anyway, subtract of two positive integers + * can't overflow in any case.) + */ + if (num_text > info_ptr->max_text - info_ptr->num_text) + { + int old_num_text = info_ptr->num_text; + int max_text; + png_textp new_text = NULL; + + /* Calculate an appropriate max_text, checking for overflow. */ + max_text = old_num_text; + if (num_text <= INT_MAX - max_text) + { + max_text += num_text; + + /* Round up to a multiple of 8 */ + if (max_text < INT_MAX-8) + max_text = (max_text + 8) & ~0x7; + + else + max_text = INT_MAX; + + /* Now allocate a new array and copy the old members in, this does all + * the overflow checks. + */ + new_text = png_voidcast(png_textp,png_realloc_array(png_ptr, + info_ptr->text, old_num_text, max_text-old_num_text, + sizeof *new_text)); + } + + if (new_text == NULL) + { + png_chunk_report(png_ptr, "too many text chunks", + PNG_CHUNK_WRITE_ERROR); + return 1; + } + + png_free(png_ptr, info_ptr->text); + + info_ptr->text = new_text; + info_ptr->free_me |= PNG_FREE_TEXT; + info_ptr->max_text = max_text; + /* num_text is adjusted below as the entries are copied in */ + + png_debug1(3, "allocated %d entries for info_ptr->text", max_text); + } + + for (i = 0; i < num_text; i++) + { + size_t text_length, key_len; + size_t lang_len, lang_key_len; + png_textp textp = &(info_ptr->text[info_ptr->num_text]); + + if (text_ptr[i].key == NULL) + continue; + + if (text_ptr[i].compression < PNG_TEXT_COMPRESSION_NONE || + text_ptr[i].compression >= PNG_TEXT_COMPRESSION_LAST) + { + png_chunk_report(png_ptr, "text compression mode is out of range", + PNG_CHUNK_WRITE_ERROR); + continue; + } + + key_len = strlen(text_ptr[i].key); + + if (text_ptr[i].compression <= 0) + { + lang_len = 0; + lang_key_len = 0; + } + + else +# ifdef PNG_iTXt_SUPPORTED + { + /* Set iTXt data */ + + if (text_ptr[i].lang != NULL) + lang_len = strlen(text_ptr[i].lang); + + else + lang_len = 0; + + if (text_ptr[i].lang_key != NULL) + lang_key_len = strlen(text_ptr[i].lang_key); + + else + lang_key_len = 0; + } +# else /* PNG_iTXt_SUPPORTED */ + { + png_chunk_report(png_ptr, "iTXt chunk not supported", + PNG_CHUNK_WRITE_ERROR); + continue; + } +# endif + + if (text_ptr[i].text == NULL || text_ptr[i].text[0] == '\0') + { + text_length = 0; +# ifdef PNG_iTXt_SUPPORTED + if (text_ptr[i].compression > 0) + textp->compression = PNG_ITXT_COMPRESSION_NONE; + + else +# endif + textp->compression = PNG_TEXT_COMPRESSION_NONE; + } + + else + { + text_length = strlen(text_ptr[i].text); + textp->compression = text_ptr[i].compression; + } + + textp->key = png_voidcast(png_charp,png_malloc_base(png_ptr, + key_len + text_length + lang_len + lang_key_len + 4)); + + if (textp->key == NULL) + { + png_chunk_report(png_ptr, "text chunk: out of memory", + PNG_CHUNK_WRITE_ERROR); + return 1; + } + + png_debug2(2, "Allocated %lu bytes at %p in png_set_text", + (unsigned long)(png_uint_32) + (key_len + lang_len + lang_key_len + text_length + 4), + textp->key); + + memcpy(textp->key, text_ptr[i].key, key_len); + *(textp->key + key_len) = '\0'; + + if (text_ptr[i].compression > 0) + { + textp->lang = textp->key + key_len + 1; + memcpy(textp->lang, text_ptr[i].lang, lang_len); + *(textp->lang + lang_len) = '\0'; + textp->lang_key = textp->lang + lang_len + 1; + memcpy(textp->lang_key, text_ptr[i].lang_key, lang_key_len); + *(textp->lang_key + lang_key_len) = '\0'; + textp->text = textp->lang_key + lang_key_len + 1; + } + + else + { + textp->lang=NULL; + textp->lang_key=NULL; + textp->text = textp->key + key_len + 1; + } + + if (text_length) + memcpy(textp->text, text_ptr[i].text, text_length); + + *(textp->text + text_length) = '\0'; + +# ifdef PNG_iTXt_SUPPORTED + if (textp->compression > 0) + { + textp->text_length = 0; + textp->itxt_length = text_length; + } + + else +# endif + { + textp->text_length = text_length; + textp->itxt_length = 0; + } + + info_ptr->num_text++; + png_debug1(3, "transferred text chunk %d", info_ptr->num_text); + } + + return(0); +} +#endif + +#ifdef PNG_tIME_SUPPORTED +void PNGAPI +png_set_tIME(png_const_structrp png_ptr, png_inforp info_ptr, + png_const_timep mod_time) +{ + png_debug1(1, "in %s storage function", "tIME"); + + if (png_ptr == NULL || info_ptr == NULL || mod_time == NULL || + (png_ptr->mode & PNG_WROTE_tIME)) + return; + + if (mod_time->month == 0 || mod_time->month > 12 || + mod_time->day == 0 || mod_time->day > 31 || + mod_time->hour > 23 || mod_time->minute > 59 || + mod_time->second > 60) + { + png_warning(png_ptr, "Ignoring invalid time value"); + return; + } + + info_ptr->mod_time = *mod_time; + info_ptr->valid |= PNG_INFO_tIME; +} +#endif + +#ifdef PNG_tRNS_SUPPORTED +void PNGAPI +png_set_tRNS(png_structrp png_ptr, png_inforp info_ptr, + png_const_bytep trans_alpha, int num_trans, png_const_color_16p trans_color) +{ + png_debug1(1, "in %s storage function", "tRNS"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + if (trans_alpha != NULL) + { + /* It may not actually be necessary to set png_ptr->trans_alpha here; + * we do it for backward compatibility with the way the png_handle_tRNS + * function used to do the allocation. + * + * 1.6.0: The above statement is incorrect; png_handle_tRNS effectively + * relies on png_set_tRNS storing the information in png_struct + * (otherwise it won't be there for the code in pngrtran.c). + */ + + png_free_data(png_ptr, info_ptr, PNG_FREE_TRNS, 0); + + /* Changed from num_trans to PNG_MAX_PALETTE_LENGTH in version 1.2.1 */ + png_ptr->trans_alpha = info_ptr->trans_alpha = png_voidcast(png_bytep, + png_malloc(png_ptr, PNG_MAX_PALETTE_LENGTH)); + + if (num_trans > 0 && num_trans <= PNG_MAX_PALETTE_LENGTH) + memcpy(info_ptr->trans_alpha, trans_alpha, (png_size_t)num_trans); + } + + if (trans_color != NULL) + { + int sample_max = (1 << info_ptr->bit_depth); + + if ((info_ptr->color_type == PNG_COLOR_TYPE_GRAY && + trans_color->gray > sample_max) || + (info_ptr->color_type == PNG_COLOR_TYPE_RGB && + (trans_color->red > sample_max || + trans_color->green > sample_max || + trans_color->blue > sample_max))) + png_warning(png_ptr, + "tRNS chunk has out-of-range samples for bit_depth"); + + info_ptr->trans_color = *trans_color; + + if (num_trans == 0) + num_trans = 1; + } + + info_ptr->num_trans = (png_uint_16)num_trans; + + if (num_trans != 0) + { + info_ptr->valid |= PNG_INFO_tRNS; + info_ptr->free_me |= PNG_FREE_TRNS; + } +} +#endif + +#ifdef PNG_sPLT_SUPPORTED +void PNGAPI +png_set_sPLT(png_const_structrp png_ptr, + png_inforp info_ptr, png_const_sPLT_tp entries, int nentries) +/* + * entries - array of png_sPLT_t structures + * to be added to the list of palettes + * in the info structure. + * + * nentries - number of palette structures to be + * added. + */ +{ + png_sPLT_tp np; + + if (png_ptr == NULL || info_ptr == NULL || nentries <= 0 || entries == NULL) + return; + + /* Use the internal realloc function, which checks for all the possible + * overflows. Notice that the parameters are (int) and (size_t) + */ + np = png_voidcast(png_sPLT_tp,png_realloc_array(png_ptr, + info_ptr->splt_palettes, info_ptr->splt_palettes_num, nentries, + sizeof *np)); + + if (np == NULL) + { + /* Out of memory or too many chunks */ + png_chunk_report(png_ptr, "too many sPLT chunks", PNG_CHUNK_WRITE_ERROR); + return; + } + + png_free(png_ptr, info_ptr->splt_palettes); + info_ptr->splt_palettes = np; + info_ptr->free_me |= PNG_FREE_SPLT; + + np += info_ptr->splt_palettes_num; + + do + { + png_size_t length; + + /* Skip invalid input entries */ + if (entries->name == NULL || entries->entries == NULL) + { + /* png_handle_sPLT doesn't do this, so this is an app error */ + png_app_error(png_ptr, "png_set_sPLT: invalid sPLT"); + /* Just skip the invalid entry */ + continue; + } + + np->depth = entries->depth; + + /* In the even of out-of-memory just return - there's no point keeping on + * trying to add sPLT chunks. + */ + length = strlen(entries->name) + 1; + np->name = png_voidcast(png_charp, png_malloc_base(png_ptr, length)); + + if (np->name == NULL) + break; + + memcpy(np->name, entries->name, length); + + /* IMPORTANT: we have memory now that won't get freed if something else + * goes wrong, this code must free it. png_malloc_array produces no + * warnings, use a png_chunk_report (below) if there is an error. + */ + np->entries = png_voidcast(png_sPLT_entryp, png_malloc_array(png_ptr, + entries->nentries, sizeof (png_sPLT_entry))); + + if (np->entries == NULL) + { + png_free(png_ptr, np->name); + break; + } + + np->nentries = entries->nentries; + /* This multiply can't overflow because png_malloc_array has already + * checked it when doing the allocation. + */ + memcpy(np->entries, entries->entries, + entries->nentries * sizeof (png_sPLT_entry)); + + /* Note that 'continue' skips the advance of the out pointer and out + * count, so an invalid entry is not added. + */ + info_ptr->valid |= PNG_INFO_sPLT; + ++(info_ptr->splt_palettes_num); + ++np; + } + while (++entries, --nentries); + + if (nentries > 0) + png_chunk_report(png_ptr, "sPLT out of memory", PNG_CHUNK_WRITE_ERROR); +} +#endif /* PNG_sPLT_SUPPORTED */ + +#ifdef PNG_STORE_UNKNOWN_CHUNKS_SUPPORTED +static png_byte +check_location(png_const_structrp png_ptr, int location) +{ + location &= (PNG_HAVE_IHDR|PNG_HAVE_PLTE|PNG_AFTER_IDAT); + + /* New in 1.6.0; copy the location and check it. This is an API + * change, previously the app had to use the + * png_set_unknown_chunk_location API below for each chunk. + */ + if (location == 0 && !(png_ptr->mode & PNG_IS_READ_STRUCT)) + { + /* Write struct, so unknown chunks come from the app */ + png_app_warning(png_ptr, + "png_set_unknown_chunks now expects a valid location"); + /* Use the old behavior */ + location = (png_byte)(png_ptr->mode & + (PNG_HAVE_IHDR|PNG_HAVE_PLTE|PNG_AFTER_IDAT)); + } + + /* This need not be an internal error - if the app calls + * png_set_unknown_chunks on a read pointer it must get the location right. + */ + if (location == 0) + png_error(png_ptr, "invalid location in png_set_unknown_chunks"); + + /* Now reduce the location to the top-most set bit by removing each least + * significant bit in turn. + */ + while (location != (location & -location)) + location &= ~(location & -location); + + /* The cast is safe because 'location' is a bit mask and only the low four + * bits are significant. + */ + return (png_byte)location; +} + +void PNGAPI +png_set_unknown_chunks(png_const_structrp png_ptr, + png_inforp info_ptr, png_const_unknown_chunkp unknowns, int num_unknowns) +{ + png_unknown_chunkp np; + + if (png_ptr == NULL || info_ptr == NULL || num_unknowns <= 0 || + unknowns == NULL) + return; + + /* Check for the failure cases where support has been disabled at compile + * time. This code is hardly ever compiled - it's here because + * STORE_UNKNOWN_CHUNKS is set by both read and write code (compiling in this + * code) but may be meaningless if the read or write handling of unknown + * chunks is not compiled in. + */ +# if !defined(PNG_READ_UNKNOWN_CHUNKS_SUPPORTED) && \ + defined(PNG_READ_SUPPORTED) + if (png_ptr->mode & PNG_IS_READ_STRUCT) + { + png_app_error(png_ptr, "no unknown chunk support on read"); + return; + } +# endif +# if !defined(PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED) && \ + defined(PNG_WRITE_SUPPORTED) + if (!(png_ptr->mode & PNG_IS_READ_STRUCT)) + { + png_app_error(png_ptr, "no unknown chunk support on write"); + return; + } +# endif + + /* Prior to 1.6.0 this code used png_malloc_warn; however, this meant that + * unknown critical chunks could be lost with just a warning resulting in + * undefined behavior. Now png_chunk_report is used to provide behavior + * appropriate to read or write. + */ + np = png_voidcast(png_unknown_chunkp, png_realloc_array(png_ptr, + info_ptr->unknown_chunks, info_ptr->unknown_chunks_num, num_unknowns, + sizeof *np)); + + if (np == NULL) + { + png_chunk_report(png_ptr, "too many unknown chunks", + PNG_CHUNK_WRITE_ERROR); + return; + } + + png_free(png_ptr, info_ptr->unknown_chunks); + info_ptr->unknown_chunks = np; /* safe because it is initialized */ + info_ptr->free_me |= PNG_FREE_UNKN; + + np += info_ptr->unknown_chunks_num; + + /* Increment unknown_chunks_num each time round the loop to protect the + * just-allocated chunk data. + */ + for (; num_unknowns > 0; --num_unknowns, ++unknowns) + { + memcpy(np->name, unknowns->name, (sizeof np->name)); + np->name[(sizeof np->name)-1] = '\0'; + np->location = check_location(png_ptr, unknowns->location); + + if (unknowns->size == 0) + { + np->data = NULL; + np->size = 0; + } + + else + { + np->data = png_voidcast(png_bytep, + png_malloc_base(png_ptr, unknowns->size)); + + if (np->data == NULL) + { + png_chunk_report(png_ptr, "unknown chunk: out of memory", + PNG_CHUNK_WRITE_ERROR); + /* But just skip storing the unknown chunk */ + continue; + } + + memcpy(np->data, unknowns->data, unknowns->size); + np->size = unknowns->size; + } + + /* These increments are skipped on out-of-memory for the data - the + * unknown chunk entry gets overwritten if the png_chunk_report returns. + * This is correct in the read case (the chunk is just dropped.) + */ + ++np; + ++(info_ptr->unknown_chunks_num); + } +} + +void PNGAPI +png_set_unknown_chunk_location(png_const_structrp png_ptr, png_inforp info_ptr, + int chunk, int location) +{ + /* This API is pretty pointless in 1.6.0 because the location can be set + * before the call to png_set_unknown_chunks. + * + * TODO: add a png_app_warning in 1.7 + */ + if (png_ptr != NULL && info_ptr != NULL && chunk >= 0 && + chunk < info_ptr->unknown_chunks_num) + { + if ((location & (PNG_HAVE_IHDR|PNG_HAVE_PLTE|PNG_AFTER_IDAT)) == 0) + { + png_app_error(png_ptr, "invalid unknown chunk location"); + /* Fake out the pre 1.6.0 behavior: */ + if ((location & PNG_HAVE_IDAT)) /* undocumented! */ + location = PNG_AFTER_IDAT; + + else + location = PNG_HAVE_IHDR; /* also undocumented */ + } + + info_ptr->unknown_chunks[chunk].location = + check_location(png_ptr, location); + } +} +#endif + + +#ifdef PNG_MNG_FEATURES_SUPPORTED +png_uint_32 PNGAPI +png_permit_mng_features (png_structrp png_ptr, png_uint_32 mng_features) +{ + png_debug(1, "in png_permit_mng_features"); + + if (png_ptr == NULL) + return 0; + + png_ptr->mng_features_permitted = mng_features & PNG_ALL_MNG_FEATURES; + + return png_ptr->mng_features_permitted; +} +#endif + +#ifdef PNG_HANDLE_AS_UNKNOWN_SUPPORTED +static unsigned int +add_one_chunk(png_bytep list, unsigned int count, png_const_bytep add, int keep) +{ + unsigned int i; + + /* Utility function: update the 'keep' state of a chunk if it is already in + * the list, otherwise add it to the list. + */ + for (i=0; i= PNG_HANDLE_CHUNK_LAST) + { + png_app_error(png_ptr, "png_set_keep_unknown_chunks: invalid keep"); + return; + } + + if (num_chunks_in <= 0) + { + png_ptr->unknown_default = keep; + + /* '0' means just set the flags, so stop here */ + if (num_chunks_in == 0) + return; + } + + if (num_chunks_in < 0) + { + /* Ignore all unknown chunks and all chunks recognized by + * libpng except for IHDR, PLTE, tRNS, IDAT, and IEND + */ + static PNG_CONST png_byte chunks_to_ignore[] = { + 98, 75, 71, 68, '\0', /* bKGD */ + 99, 72, 82, 77, '\0', /* cHRM */ + 103, 65, 77, 65, '\0', /* gAMA */ + 104, 73, 83, 84, '\0', /* hIST */ + 105, 67, 67, 80, '\0', /* iCCP */ + 105, 84, 88, 116, '\0', /* iTXt */ + 111, 70, 70, 115, '\0', /* oFFs */ + 112, 67, 65, 76, '\0', /* pCAL */ + 112, 72, 89, 115, '\0', /* pHYs */ + 115, 66, 73, 84, '\0', /* sBIT */ + 115, 67, 65, 76, '\0', /* sCAL */ + 115, 80, 76, 84, '\0', /* sPLT */ + 115, 84, 69, 82, '\0', /* sTER */ + 115, 82, 71, 66, '\0', /* sRGB */ + 116, 69, 88, 116, '\0', /* tEXt */ + 116, 73, 77, 69, '\0', /* tIME */ + 122, 84, 88, 116, '\0' /* zTXt */ + }; + + chunk_list = chunks_to_ignore; + num_chunks = (sizeof chunks_to_ignore)/5; + } + + else /* num_chunks_in > 0 */ + { + if (chunk_list == NULL) + { + /* Prior to 1.6.0 this was silently ignored, now it is an app_error + * which can be switched off. + */ + png_app_error(png_ptr, "png_set_keep_unknown_chunks: no chunk list"); + return; + } + + num_chunks = num_chunks_in; + } + + old_num_chunks = png_ptr->num_chunk_list; + if (png_ptr->chunk_list == NULL) + old_num_chunks = 0; + + /* Since num_chunks is always restricted to UINT_MAX/5 this can't overflow. + */ + if (num_chunks + old_num_chunks > UINT_MAX/5) + { + png_app_error(png_ptr, "png_set_keep_unknown_chunks: too many chunks"); + return; + } + + /* If these chunks are being reset to the default then no more memory is + * required because add_one_chunk above doesn't extend the list if the 'keep' + * parameter is the default. + */ + if (keep) + { + new_list = png_voidcast(png_bytep, png_malloc(png_ptr, + 5 * (num_chunks + old_num_chunks))); + + if (old_num_chunks > 0) + memcpy(new_list, png_ptr->chunk_list, 5*old_num_chunks); + } + + else if (old_num_chunks > 0) + new_list = png_ptr->chunk_list; + + else + new_list = NULL; + + /* Add the new chunks together with each one's handling code. If the chunk + * already exists the code is updated, otherwise the chunk is added to the + * end. (In libpng 1.6.0 order no longer matters because this code enforces + * the earlier convention that the last setting is the one that is used.) + */ + if (new_list != NULL) + { + png_const_bytep inlist; + png_bytep outlist; + unsigned int i; + + for (i=0; ichunk_list != new_list) + png_free(png_ptr, new_list); + + new_list = NULL; + } + } + + else + num_chunks = 0; + + png_ptr->num_chunk_list = num_chunks; + + if (png_ptr->chunk_list != new_list) + { + if (png_ptr->chunk_list != NULL) + png_free(png_ptr, png_ptr->chunk_list); + + png_ptr->chunk_list = new_list; + } +} +#endif + +#ifdef PNG_READ_USER_CHUNKS_SUPPORTED +void PNGAPI +png_set_read_user_chunk_fn(png_structrp png_ptr, png_voidp user_chunk_ptr, + png_user_chunk_ptr read_user_chunk_fn) +{ + png_debug(1, "in png_set_read_user_chunk_fn"); + + if (png_ptr == NULL) + return; + + png_ptr->read_user_chunk_fn = read_user_chunk_fn; + png_ptr->user_chunk_ptr = user_chunk_ptr; +} +#endif + +#ifdef PNG_INFO_IMAGE_SUPPORTED +void PNGAPI +png_set_rows(png_const_structrp png_ptr, png_inforp info_ptr, + png_bytepp row_pointers) +{ + png_debug1(1, "in %s storage function", "rows"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + if (info_ptr->row_pointers && (info_ptr->row_pointers != row_pointers)) + png_free_data(png_ptr, info_ptr, PNG_FREE_ROWS, 0); + + info_ptr->row_pointers = row_pointers; + + if (row_pointers) + info_ptr->valid |= PNG_INFO_IDAT; +} +#endif + +void PNGAPI +png_set_compression_buffer_size(png_structrp png_ptr, png_size_t size) +{ + if (png_ptr == NULL) + return; + + if (size == 0 || size > PNG_UINT_31_MAX) + png_error(png_ptr, "invalid compression buffer size"); + +# ifdef PNG_SEQUENTIAL_READ_SUPPORTED + if (png_ptr->mode & PNG_IS_READ_STRUCT) + { + png_ptr->IDAT_read_size = (png_uint_32)size; /* checked above */ + return; + } +# endif + +# ifdef PNG_WRITE_SUPPORTED + if (!(png_ptr->mode & PNG_IS_READ_STRUCT)) + { + if (png_ptr->zowner != 0) + { + png_warning(png_ptr, + "Compression buffer size cannot be changed because it is in use"); + return; + } + + if (size > ZLIB_IO_MAX) + { + png_warning(png_ptr, + "Compression buffer size limited to system maximum"); + size = ZLIB_IO_MAX; /* must fit */ + } + + else if (size < 6) + { + /* Deflate will potentially go into an infinite loop on a SYNC_FLUSH + * if this is permitted. + */ + png_warning(png_ptr, + "Compression buffer size cannot be reduced below 6"); + return; + } + + if (png_ptr->zbuffer_size != size) + { + png_free_buffer_list(png_ptr, &png_ptr->zbuffer_list); + png_ptr->zbuffer_size = (uInt)size; + } + } +# endif +} + +void PNGAPI +png_set_invalid(png_const_structrp png_ptr, png_inforp info_ptr, int mask) +{ + if (png_ptr && info_ptr) + info_ptr->valid &= ~mask; +} + + +#ifdef PNG_SET_USER_LIMITS_SUPPORTED +/* This function was added to libpng 1.2.6 */ +void PNGAPI +png_set_user_limits (png_structrp png_ptr, png_uint_32 user_width_max, + png_uint_32 user_height_max) +{ + /* Images with dimensions larger than these limits will be + * rejected by png_set_IHDR(). To accept any PNG datastream + * regardless of dimensions, set both limits to 0x7ffffffL. + */ + if (png_ptr == NULL) + return; + + png_ptr->user_width_max = user_width_max; + png_ptr->user_height_max = user_height_max; +} + +/* This function was added to libpng 1.4.0 */ +void PNGAPI +png_set_chunk_cache_max (png_structrp png_ptr, png_uint_32 user_chunk_cache_max) +{ + if (png_ptr) + png_ptr->user_chunk_cache_max = user_chunk_cache_max; +} + +/* This function was added to libpng 1.4.1 */ +void PNGAPI +png_set_chunk_malloc_max (png_structrp png_ptr, + png_alloc_size_t user_chunk_malloc_max) +{ + if (png_ptr) + png_ptr->user_chunk_malloc_max = user_chunk_malloc_max; +} +#endif /* ?PNG_SET_USER_LIMITS_SUPPORTED */ + + +#ifdef PNG_BENIGN_ERRORS_SUPPORTED +void PNGAPI +png_set_benign_errors(png_structrp png_ptr, int allowed) +{ + png_debug(1, "in png_set_benign_errors"); + + /* If allowed is 1, png_benign_error() is treated as a warning. + * + * If allowed is 0, png_benign_error() is treated as an error (which + * is the default behavior if png_set_benign_errors() is not called). + */ + + if (allowed) + png_ptr->flags |= PNG_FLAG_BENIGN_ERRORS_WARN | + PNG_FLAG_APP_WARNINGS_WARN | PNG_FLAG_APP_ERRORS_WARN; + + else + png_ptr->flags &= ~(PNG_FLAG_BENIGN_ERRORS_WARN | + PNG_FLAG_APP_WARNINGS_WARN | PNG_FLAG_APP_ERRORS_WARN); +} +#endif /* PNG_BENIGN_ERRORS_SUPPORTED */ + +#ifdef PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED + /* Whether to report invalid palette index; added at libng-1.5.10. + * It is possible for an indexed (color-type==3) PNG file to contain + * pixels with invalid (out-of-range) indexes if the PLTE chunk has + * fewer entries than the image's bit-depth would allow. We recover + * from this gracefully by filling any incomplete palette with zeroes + * (opaque black). By default, when this occurs libpng will issue + * a benign error. This API can be used to override that behavior. + */ +void PNGAPI +png_set_check_for_invalid_index(png_structrp png_ptr, int allowed) +{ + png_debug(1, "in png_set_check_for_invalid_index"); + + if (allowed > 0) + png_ptr->num_palette_max = 0; + + else + png_ptr->num_palette_max = -1; +} +#endif +#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngstruct.h b/ml/dlib/dlib/external/libpng/pngstruct.h new file mode 100644 index 000000000..d58c02884 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngstruct.h @@ -0,0 +1,489 @@ + +/* pngstruct.h - header file for PNG reference library + * + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * Last changed in libpng 1.6.1 [March 28, 2013] + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +/* The structure that holds the information to read and write PNG files. + * The only people who need to care about what is inside of this are the + * people who will be modifying the library for their own special needs. + * It should NOT be accessed directly by an application. + */ + +#ifndef PNGSTRUCT_H +#define PNGSTRUCT_H +/* zlib.h defines the structure z_stream, an instance of which is included + * in this structure and is required for decompressing the LZ compressed + * data in PNG files. + */ +#ifndef ZLIB_CONST + /* We must ensure that zlib uses 'const' in declarations. */ +# define ZLIB_CONST +#endif +#include "zlib.h" +#ifdef const + /* zlib.h sometimes #defines const to nothing, undo this. */ +# undef const +#endif + +/* zlib.h has mediocre z_const use before 1.2.6, this stuff is for compatibility + * with older builds. + */ +#if ZLIB_VERNUM < 0x1260 +# define PNGZ_MSG_CAST(s) png_constcast(char*,s) +# define PNGZ_INPUT_CAST(b) png_constcast(png_bytep,b) +#else +# define PNGZ_MSG_CAST(s) (s) +# define PNGZ_INPUT_CAST(b) (b) +#endif + +/* zlib.h declares a magic type 'uInt' that limits the amount of data that zlib + * can handle at once. This type need be no larger than 16 bits (so maximum of + * 65535), this define allows us to discover how big it is, but limited by the + * maximuum for png_size_t. The value can be overriden in a library build + * (pngusr.h, or set it in CPPFLAGS) and it works to set it to a considerably + * lower value (e.g. 255 works). A lower value may help memory usage (slightly) + * and may even improve performance on some systems (and degrade it on others.) + */ +#ifndef ZLIB_IO_MAX +# define ZLIB_IO_MAX ((uInt)-1) +#endif + +#ifdef PNG_WRITE_SUPPORTED +/* The type of a compression buffer list used by the write code. */ +typedef struct png_compression_buffer +{ + struct png_compression_buffer *next; + png_byte output[1]; /* actually zbuf_size */ +} png_compression_buffer, *png_compression_bufferp; + +#define PNG_COMPRESSION_BUFFER_SIZE(pp)\ + (offsetof(png_compression_buffer, output) + (pp)->zbuffer_size) +#endif + +/* Colorspace support; structures used in png_struct, png_info and in internal + * functions to hold and communicate information about the color space. + * + * PNG_COLORSPACE_SUPPORTED is only required if the application will perform + * colorspace corrections, otherwise all the colorspace information can be + * skipped and the size of libpng can be reduced (significantly) by compiling + * out the colorspace support. + */ +#ifdef PNG_COLORSPACE_SUPPORTED +/* The chromaticities of the red, green and blue colorants and the chromaticity + * of the corresponding white point (i.e. of rgb(1.0,1.0,1.0)). + */ +typedef struct png_xy +{ + png_fixed_point redx, redy; + png_fixed_point greenx, greeny; + png_fixed_point bluex, bluey; + png_fixed_point whitex, whitey; +} png_xy; + +/* The same data as above but encoded as CIE XYZ values. When this data comes + * from chromaticities the sum of the Y values is assumed to be 1.0 + */ +typedef struct png_XYZ +{ + png_fixed_point red_X, red_Y, red_Z; + png_fixed_point green_X, green_Y, green_Z; + png_fixed_point blue_X, blue_Y, blue_Z; +} png_XYZ; +#endif /* COLORSPACE */ + +#if defined(PNG_COLORSPACE_SUPPORTED) || defined(PNG_GAMMA_SUPPORTED) +/* A colorspace is all the above plus, potentially, profile information, + * however at present libpng does not use the profile internally so it is only + * stored in the png_info struct (if iCCP is supported.) The rendering intent + * is retained here and is checked. + * + * The file gamma encoding information is also stored here and gamma correction + * is done by libpng, whereas color correction must currently be done by the + * application. + */ +typedef struct png_colorspace +{ +#ifdef PNG_GAMMA_SUPPORTED + png_fixed_point gamma; /* File gamma */ +#endif + +#ifdef PNG_COLORSPACE_SUPPORTED + png_xy end_points_xy; /* End points as chromaticities */ + png_XYZ end_points_XYZ; /* End points as CIE XYZ colorant values */ + png_uint_16 rendering_intent; /* Rendering intent of a profile */ +#endif + + /* Flags are always defined to simplify the code. */ + png_uint_16 flags; /* As defined below */ +} png_colorspace, * PNG_RESTRICT png_colorspacerp; + +typedef const png_colorspace * PNG_RESTRICT png_const_colorspacerp; + +/* General flags for the 'flags' field */ +#define PNG_COLORSPACE_HAVE_GAMMA 0x0001 +#define PNG_COLORSPACE_HAVE_ENDPOINTS 0x0002 +#define PNG_COLORSPACE_HAVE_INTENT 0x0004 +#define PNG_COLORSPACE_FROM_gAMA 0x0008 +#define PNG_COLORSPACE_FROM_cHRM 0x0010 +#define PNG_COLORSPACE_FROM_sRGB 0x0020 +#define PNG_COLORSPACE_ENDPOINTS_MATCH_sRGB 0x0040 +#define PNG_COLORSPACE_MATCHES_sRGB 0x0080 /* exact match on profile */ +#define PNG_COLORSPACE_INVALID 0x8000 +#define PNG_COLORSPACE_CANCEL(flags) (0xffff ^ (flags)) +#endif /* COLORSPACE || GAMMA */ + +struct png_struct_def +{ +#ifdef PNG_SETJMP_SUPPORTED + jmp_buf jmp_buf_local; /* New name in 1.6.0 for jmp_buf in png_struct */ + png_longjmp_ptr longjmp_fn;/* setjmp non-local goto function. */ + jmp_buf *jmp_buf_ptr; /* passed to longjmp_fn */ + size_t jmp_buf_size; /* size of the above, if allocated */ +#endif + png_error_ptr error_fn; /* function for printing errors and aborting */ +#ifdef PNG_WARNINGS_SUPPORTED + png_error_ptr warning_fn; /* function for printing warnings */ +#endif + png_voidp error_ptr; /* user supplied struct for error functions */ + png_rw_ptr write_data_fn; /* function for writing output data */ + png_rw_ptr read_data_fn; /* function for reading input data */ + png_voidp io_ptr; /* ptr to application struct for I/O functions */ + +#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED + png_user_transform_ptr read_user_transform_fn; /* user read transform */ +#endif + +#ifdef PNG_WRITE_USER_TRANSFORM_SUPPORTED + png_user_transform_ptr write_user_transform_fn; /* user write transform */ +#endif + +/* These were added in libpng-1.0.2 */ +#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED +#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) || \ + defined(PNG_WRITE_USER_TRANSFORM_SUPPORTED) + png_voidp user_transform_ptr; /* user supplied struct for user transform */ + png_byte user_transform_depth; /* bit depth of user transformed pixels */ + png_byte user_transform_channels; /* channels in user transformed pixels */ +#endif +#endif + + png_uint_32 mode; /* tells us where we are in the PNG file */ + png_uint_32 flags; /* flags indicating various things to libpng */ + png_uint_32 transformations; /* which transformations to perform */ + + png_uint_32 zowner; /* ID (chunk type) of zstream owner, 0 if none */ + z_stream zstream; /* decompression structure */ + +#ifdef PNG_WRITE_SUPPORTED + png_compression_bufferp zbuffer_list; /* Created on demand during write */ + uInt zbuffer_size; /* size of the actual buffer */ + + int zlib_level; /* holds zlib compression level */ + int zlib_method; /* holds zlib compression method */ + int zlib_window_bits; /* holds zlib compression window bits */ + int zlib_mem_level; /* holds zlib compression memory level */ + int zlib_strategy; /* holds zlib compression strategy */ +#endif +/* Added at libpng 1.5.4 */ +#ifdef PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED + int zlib_text_level; /* holds zlib compression level */ + int zlib_text_method; /* holds zlib compression method */ + int zlib_text_window_bits; /* holds zlib compression window bits */ + int zlib_text_mem_level; /* holds zlib compression memory level */ + int zlib_text_strategy; /* holds zlib compression strategy */ +#endif +/* End of material added at libpng 1.5.4 */ +/* Added at libpng 1.6.0 */ +#ifdef PNG_WRITE_SUPPORTED + int zlib_set_level; /* Actual values set into the zstream on write */ + int zlib_set_method; + int zlib_set_window_bits; + int zlib_set_mem_level; + int zlib_set_strategy; +#endif + + png_uint_32 width; /* width of image in pixels */ + png_uint_32 height; /* height of image in pixels */ + png_uint_32 num_rows; /* number of rows in current pass */ + png_uint_32 usr_width; /* width of row at start of write */ + png_size_t rowbytes; /* size of row in bytes */ + png_uint_32 iwidth; /* width of current interlaced row in pixels */ + png_uint_32 row_number; /* current row in interlace pass */ + png_uint_32 chunk_name; /* PNG_CHUNK() id of current chunk */ + png_bytep prev_row; /* buffer to save previous (unfiltered) row. + * This is a pointer into big_prev_row + */ + png_bytep row_buf; /* buffer to save current (unfiltered) row. + * This is a pointer into big_row_buf + */ +#ifdef PNG_WRITE_SUPPORTED + png_bytep sub_row; /* buffer to save "sub" row when filtering */ + png_bytep up_row; /* buffer to save "up" row when filtering */ + png_bytep avg_row; /* buffer to save "avg" row when filtering */ + png_bytep paeth_row; /* buffer to save "Paeth" row when filtering */ +#endif + png_size_t info_rowbytes; /* Added in 1.5.4: cache of updated row bytes */ + + png_uint_32 idat_size; /* current IDAT size for read */ + png_uint_32 crc; /* current chunk CRC value */ + png_colorp palette; /* palette from the input file */ + png_uint_16 num_palette; /* number of color entries in palette */ + +/* Added at libpng-1.5.10 */ +#ifdef PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED + int num_palette_max; /* maximum palette index found in IDAT */ +#endif + + png_uint_16 num_trans; /* number of transparency values */ + png_byte compression; /* file compression type (always 0) */ + png_byte filter; /* file filter type (always 0) */ + png_byte interlaced; /* PNG_INTERLACE_NONE, PNG_INTERLACE_ADAM7 */ + png_byte pass; /* current interlace pass (0 - 6) */ + png_byte do_filter; /* row filter flags (see PNG_FILTER_ below ) */ + png_byte color_type; /* color type of file */ + png_byte bit_depth; /* bit depth of file */ + png_byte usr_bit_depth; /* bit depth of users row: write only */ + png_byte pixel_depth; /* number of bits per pixel */ + png_byte channels; /* number of channels in file */ +#ifdef PNG_WRITE_SUPPORTED + png_byte usr_channels; /* channels at start of write: write only */ +#endif + png_byte sig_bytes; /* magic bytes read/written from start of file */ + png_byte maximum_pixel_depth; + /* pixel depth used for the row buffers */ + png_byte transformed_pixel_depth; + /* pixel depth after read/write transforms */ +#if defined(PNG_READ_FILLER_SUPPORTED) || defined(PNG_WRITE_FILLER_SUPPORTED) + png_uint_16 filler; /* filler bytes for pixel expansion */ +#endif + +#if defined(PNG_bKGD_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) ||\ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) + png_byte background_gamma_type; + png_fixed_point background_gamma; + png_color_16 background; /* background color in screen gamma space */ +#ifdef PNG_READ_GAMMA_SUPPORTED + png_color_16 background_1; /* background normalized to gamma 1.0 */ +#endif +#endif /* PNG_bKGD_SUPPORTED */ + +#ifdef PNG_WRITE_FLUSH_SUPPORTED + png_flush_ptr output_flush_fn; /* Function for flushing output */ + png_uint_32 flush_dist; /* how many rows apart to flush, 0 - no flush */ + png_uint_32 flush_rows; /* number of rows written since last flush */ +#endif + +#ifdef PNG_READ_GAMMA_SUPPORTED + int gamma_shift; /* number of "insignificant" bits in 16-bit gamma */ + png_fixed_point screen_gamma; /* screen gamma value (display_exponent) */ + + png_bytep gamma_table; /* gamma table for 8-bit depth files */ + png_uint_16pp gamma_16_table; /* gamma table for 16-bit depth files */ +#if defined(PNG_READ_BACKGROUND_SUPPORTED) || \ + defined(PNG_READ_ALPHA_MODE_SUPPORTED) || \ + defined(PNG_READ_RGB_TO_GRAY_SUPPORTED) + png_bytep gamma_from_1; /* converts from 1.0 to screen */ + png_bytep gamma_to_1; /* converts from file to 1.0 */ + png_uint_16pp gamma_16_from_1; /* converts from 1.0 to screen */ + png_uint_16pp gamma_16_to_1; /* converts from file to 1.0 */ +#endif /* READ_BACKGROUND || READ_ALPHA_MODE || RGB_TO_GRAY */ +#endif + +#if defined(PNG_READ_GAMMA_SUPPORTED) || defined(PNG_sBIT_SUPPORTED) + png_color_8 sig_bit; /* significant bits in each available channel */ +#endif + +#if defined(PNG_READ_SHIFT_SUPPORTED) || defined(PNG_WRITE_SHIFT_SUPPORTED) + png_color_8 shift; /* shift for significant bit tranformation */ +#endif + +#if defined(PNG_tRNS_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) \ + || defined(PNG_READ_EXPAND_SUPPORTED) || defined(PNG_READ_BACKGROUND_SUPPORTED) + png_bytep trans_alpha; /* alpha values for paletted files */ + png_color_16 trans_color; /* transparent color for non-paletted files */ +#endif + + png_read_status_ptr read_row_fn; /* called after each row is decoded */ + png_write_status_ptr write_row_fn; /* called after each row is encoded */ +#ifdef PNG_PROGRESSIVE_READ_SUPPORTED + png_progressive_info_ptr info_fn; /* called after header data fully read */ + png_progressive_row_ptr row_fn; /* called after a prog. row is decoded */ + png_progressive_end_ptr end_fn; /* called after image is complete */ + png_bytep save_buffer_ptr; /* current location in save_buffer */ + png_bytep save_buffer; /* buffer for previously read data */ + png_bytep current_buffer_ptr; /* current location in current_buffer */ + png_bytep current_buffer; /* buffer for recently used data */ + png_uint_32 push_length; /* size of current input chunk */ + png_uint_32 skip_length; /* bytes to skip in input data */ + png_size_t save_buffer_size; /* amount of data now in save_buffer */ + png_size_t save_buffer_max; /* total size of save_buffer */ + png_size_t buffer_size; /* total amount of available input data */ + png_size_t current_buffer_size; /* amount of data now in current_buffer */ + int process_mode; /* what push library is currently doing */ + int cur_palette; /* current push library palette index */ + +#endif /* PNG_PROGRESSIVE_READ_SUPPORTED */ + +#if defined(__TURBOC__) && !defined(_Windows) && !defined(__FLAT__) +/* For the Borland special 64K segment handler */ + png_bytepp offset_table_ptr; + png_bytep offset_table; + png_uint_16 offset_table_number; + png_uint_16 offset_table_count; + png_uint_16 offset_table_count_free; +#endif + +#ifdef PNG_READ_QUANTIZE_SUPPORTED + png_bytep palette_lookup; /* lookup table for quantizing */ + png_bytep quantize_index; /* index translation for palette files */ +#endif + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + png_byte heuristic_method; /* heuristic for row filter selection */ + png_byte num_prev_filters; /* number of weights for previous rows */ + png_bytep prev_filters; /* filter type(s) of previous row(s) */ + png_uint_16p filter_weights; /* weight(s) for previous line(s) */ + png_uint_16p inv_filter_weights; /* 1/weight(s) for previous line(s) */ + png_uint_16p filter_costs; /* relative filter calculation cost */ + png_uint_16p inv_filter_costs; /* 1/relative filter calculation cost */ +#endif + + /* Options */ +#ifdef PNG_SET_OPTION_SUPPORTED + png_byte options; /* On/off state (up to 4 options) */ +#endif + +#if PNG_LIBPNG_VER < 10700 +/* To do: remove this from libpng-1.7 */ +#ifdef PNG_TIME_RFC1123_SUPPORTED + char time_buffer[29]; /* String to hold RFC 1123 time text */ +#endif +#endif + +/* New members added in libpng-1.0.6 */ + + png_uint_32 free_me; /* flags items libpng is responsible for freeing */ + +#ifdef PNG_USER_CHUNKS_SUPPORTED + png_voidp user_chunk_ptr; +#ifdef PNG_READ_USER_CHUNKS_SUPPORTED + png_user_chunk_ptr read_user_chunk_fn; /* user read chunk handler */ +#endif +#endif + +#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED + int unknown_default; /* As PNG_HANDLE_* */ + unsigned int num_chunk_list; /* Number of entries in the list */ + png_bytep chunk_list; /* List of png_byte[5]; the textual chunk name + * followed by a PNG_HANDLE_* byte */ +#endif + +/* New members added in libpng-1.0.3 */ +#ifdef PNG_READ_RGB_TO_GRAY_SUPPORTED + png_byte rgb_to_gray_status; + /* Added in libpng 1.5.5 to record setting of coefficients: */ + png_byte rgb_to_gray_coefficients_set; + /* These were changed from png_byte in libpng-1.0.6 */ + png_uint_16 rgb_to_gray_red_coeff; + png_uint_16 rgb_to_gray_green_coeff; + /* deleted in 1.5.5: rgb_to_gray_blue_coeff; */ +#endif + +/* New member added in libpng-1.0.4 (renamed in 1.0.9) */ +#if defined(PNG_MNG_FEATURES_SUPPORTED) +/* Changed from png_byte to png_uint_32 at version 1.2.0 */ + png_uint_32 mng_features_permitted; +#endif + +/* New member added in libpng-1.0.9, ifdef'ed out in 1.0.12, enabled in 1.2.0 */ +#ifdef PNG_MNG_FEATURES_SUPPORTED + png_byte filter_type; +#endif + +/* New members added in libpng-1.2.0 */ + +/* New members added in libpng-1.0.2 but first enabled by default in 1.2.0 */ +#ifdef PNG_USER_MEM_SUPPORTED + png_voidp mem_ptr; /* user supplied struct for mem functions */ + png_malloc_ptr malloc_fn; /* function for allocating memory */ + png_free_ptr free_fn; /* function for freeing memory */ +#endif + +/* New member added in libpng-1.0.13 and 1.2.0 */ + png_bytep big_row_buf; /* buffer to save current (unfiltered) row */ + +#ifdef PNG_READ_QUANTIZE_SUPPORTED +/* The following three members were added at version 1.0.14 and 1.2.4 */ + png_bytep quantize_sort; /* working sort array */ + png_bytep index_to_palette; /* where the original index currently is + in the palette */ + png_bytep palette_to_index; /* which original index points to this + palette color */ +#endif + +/* New members added in libpng-1.0.16 and 1.2.6 */ + png_byte compression_type; + +#ifdef PNG_USER_LIMITS_SUPPORTED + png_uint_32 user_width_max; + png_uint_32 user_height_max; + + /* Added in libpng-1.4.0: Total number of sPLT, text, and unknown + * chunks that can be stored (0 means unlimited). + */ + png_uint_32 user_chunk_cache_max; + + /* Total memory that a zTXt, sPLT, iTXt, iCCP, or unknown chunk + * can occupy when decompressed. 0 means unlimited. + */ + png_alloc_size_t user_chunk_malloc_max; +#endif + +/* New member added in libpng-1.0.25 and 1.2.17 */ +#ifdef PNG_READ_UNKNOWN_CHUNKS_SUPPORTED + /* Temporary storage for unknown chunk that the library doesn't recognize, + * used while reading the chunk. + */ + png_unknown_chunk unknown_chunk; +#endif + +/* New member added in libpng-1.2.26 */ + png_size_t old_big_row_buf_size; + +#ifdef PNG_READ_SUPPORTED +/* New member added in libpng-1.2.30 */ + png_bytep read_buffer; /* buffer for reading chunk data */ + png_alloc_size_t read_buffer_size; /* current size of the buffer */ +#endif +#ifdef PNG_SEQUENTIAL_READ_SUPPORTED + uInt IDAT_read_size; /* limit on read buffer size for IDAT */ +#endif + +#ifdef PNG_IO_STATE_SUPPORTED +/* New member added in libpng-1.4.0 */ + png_uint_32 io_state; +#endif + +/* New member added in libpng-1.5.6 */ + png_bytep big_prev_row; + +/* New member added in libpng-1.5.7 */ + void (*read_filter[PNG_FILTER_VALUE_LAST-1])(png_row_infop row_info, + png_bytep row, png_const_bytep prev_row); + +#ifdef PNG_READ_SUPPORTED +#if defined(PNG_COLORSPACE_SUPPORTED) || defined(PNG_GAMMA_SUPPORTED) + png_colorspace colorspace; +#endif +#endif +}; +#endif /* PNGSTRUCT_H */ diff --git a/ml/dlib/dlib/external/libpng/pngtrans.c b/ml/dlib/dlib/external/libpng/pngtrans.c new file mode 100644 index 000000000..8f8bc5d9e --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngtrans.c @@ -0,0 +1,841 @@ + +/* pngtrans.c - transforms the data in a row (used by both readers and writers) + * + * Last changed in libpng 1.6.2 [April 25, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +#include "pngpriv.h" + +#if defined(PNG_READ_SUPPORTED) || defined(PNG_WRITE_SUPPORTED) + +#if defined(PNG_READ_BGR_SUPPORTED) || defined(PNG_WRITE_BGR_SUPPORTED) +/* Turn on BGR-to-RGB mapping */ +void PNGAPI +png_set_bgr(png_structrp png_ptr) +{ + png_debug(1, "in png_set_bgr"); + + if (png_ptr == NULL) + return; + + png_ptr->transformations |= PNG_BGR; +} +#endif + +#if defined(PNG_READ_SWAP_SUPPORTED) || defined(PNG_WRITE_SWAP_SUPPORTED) +/* Turn on 16 bit byte swapping */ +void PNGAPI +png_set_swap(png_structrp png_ptr) +{ + png_debug(1, "in png_set_swap"); + + if (png_ptr == NULL) + return; + + if (png_ptr->bit_depth == 16) + png_ptr->transformations |= PNG_SWAP_BYTES; +} +#endif + +#if defined(PNG_READ_PACK_SUPPORTED) || defined(PNG_WRITE_PACK_SUPPORTED) +/* Turn on pixel packing */ +void PNGAPI +png_set_packing(png_structrp png_ptr) +{ + png_debug(1, "in png_set_packing"); + + if (png_ptr == NULL) + return; + + if (png_ptr->bit_depth < 8) + { + png_ptr->transformations |= PNG_PACK; + png_ptr->usr_bit_depth = 8; + } +} +#endif + +#if defined(PNG_READ_PACKSWAP_SUPPORTED)||defined(PNG_WRITE_PACKSWAP_SUPPORTED) +/* Turn on packed pixel swapping */ +void PNGAPI +png_set_packswap(png_structrp png_ptr) +{ + png_debug(1, "in png_set_packswap"); + + if (png_ptr == NULL) + return; + + if (png_ptr->bit_depth < 8) + png_ptr->transformations |= PNG_PACKSWAP; +} +#endif + +#if defined(PNG_READ_SHIFT_SUPPORTED) || defined(PNG_WRITE_SHIFT_SUPPORTED) +void PNGAPI +png_set_shift(png_structrp png_ptr, png_const_color_8p true_bits) +{ + png_debug(1, "in png_set_shift"); + + if (png_ptr == NULL) + return; + + png_ptr->transformations |= PNG_SHIFT; + png_ptr->shift = *true_bits; +} +#endif + +#if defined(PNG_READ_INTERLACING_SUPPORTED) || \ + defined(PNG_WRITE_INTERLACING_SUPPORTED) +int PNGAPI +png_set_interlace_handling(png_structrp png_ptr) +{ + png_debug(1, "in png_set_interlace handling"); + + if (png_ptr && png_ptr->interlaced) + { + png_ptr->transformations |= PNG_INTERLACE; + return (7); + } + + return (1); +} +#endif + +#if defined(PNG_READ_FILLER_SUPPORTED) || defined(PNG_WRITE_FILLER_SUPPORTED) +/* Add a filler byte on read, or remove a filler or alpha byte on write. + * The filler type has changed in v0.95 to allow future 2-byte fillers + * for 48-bit input data, as well as to avoid problems with some compilers + * that don't like bytes as parameters. + */ +void PNGAPI +png_set_filler(png_structrp png_ptr, png_uint_32 filler, int filler_loc) +{ + png_debug(1, "in png_set_filler"); + + if (png_ptr == NULL) + return; + + /* In libpng 1.6 it is possible to determine whether this is a read or write + * operation and therefore to do more checking here for a valid call. + */ + if (png_ptr->mode & PNG_IS_READ_STRUCT) + { +# ifdef PNG_READ_FILLER_SUPPORTED + /* On read png_set_filler is always valid, regardless of the base PNG + * format, because other transformations can give a format where the + * filler code can execute (basically an 8 or 16-bit component RGB or G + * format.) + * + * NOTE: usr_channels is not used by the read code! (This has led to + * confusion in the past.) The filler is only used in the read code. + */ + png_ptr->filler = (png_uint_16)filler; +# else + png_app_error(png_ptr, "png_set_filler not supported on read"); + PNG_UNUSED(filler) /* not used in the write case */ + return; +# endif + } + + else /* write */ + { +# ifdef PNG_WRITE_FILLER_SUPPORTED + /* On write the usr_channels parameter must be set correctly at the + * start to record the number of channels in the app-supplied data. + */ + switch (png_ptr->color_type) + { + case PNG_COLOR_TYPE_RGB: + png_ptr->usr_channels = 4; + break; + + case PNG_COLOR_TYPE_GRAY: + if (png_ptr->bit_depth >= 8) + { + png_ptr->usr_channels = 2; + break; + } + + else + { + /* There simply isn't any code in libpng to strip out bits + * from bytes when the components are less than a byte in + * size! + */ + png_app_error(png_ptr, + "png_set_filler is invalid for low bit depth gray output"); + return; + } + + default: + png_app_error(png_ptr, + "png_set_filler: inappropriate color type"); + return; + } +# else + png_app_error(png_ptr, "png_set_filler not supported on write"); + return; +# endif + } + + /* Here on success - libpng supports the operation, set the transformation + * and the flag to say where the filler channel is. + */ + png_ptr->transformations |= PNG_FILLER; + + if (filler_loc == PNG_FILLER_AFTER) + png_ptr->flags |= PNG_FLAG_FILLER_AFTER; + + else + png_ptr->flags &= ~PNG_FLAG_FILLER_AFTER; +} + +/* Added to libpng-1.2.7 */ +void PNGAPI +png_set_add_alpha(png_structrp png_ptr, png_uint_32 filler, int filler_loc) +{ + png_debug(1, "in png_set_add_alpha"); + + if (png_ptr == NULL) + return; + + png_set_filler(png_ptr, filler, filler_loc); + /* The above may fail to do anything. */ + if (png_ptr->transformations & PNG_FILLER) + png_ptr->transformations |= PNG_ADD_ALPHA; +} + +#endif + +#if defined(PNG_READ_SWAP_ALPHA_SUPPORTED) || \ + defined(PNG_WRITE_SWAP_ALPHA_SUPPORTED) +void PNGAPI +png_set_swap_alpha(png_structrp png_ptr) +{ + png_debug(1, "in png_set_swap_alpha"); + + if (png_ptr == NULL) + return; + + png_ptr->transformations |= PNG_SWAP_ALPHA; +} +#endif + +#if defined(PNG_READ_INVERT_ALPHA_SUPPORTED) || \ + defined(PNG_WRITE_INVERT_ALPHA_SUPPORTED) +void PNGAPI +png_set_invert_alpha(png_structrp png_ptr) +{ + png_debug(1, "in png_set_invert_alpha"); + + if (png_ptr == NULL) + return; + + png_ptr->transformations |= PNG_INVERT_ALPHA; +} +#endif + +#if defined(PNG_READ_INVERT_SUPPORTED) || defined(PNG_WRITE_INVERT_SUPPORTED) +void PNGAPI +png_set_invert_mono(png_structrp png_ptr) +{ + png_debug(1, "in png_set_invert_mono"); + + if (png_ptr == NULL) + return; + + png_ptr->transformations |= PNG_INVERT_MONO; +} + +/* Invert monochrome grayscale data */ +void /* PRIVATE */ +png_do_invert(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_invert"); + + /* This test removed from libpng version 1.0.13 and 1.2.0: + * if (row_info->bit_depth == 1 && + */ + if (row_info->color_type == PNG_COLOR_TYPE_GRAY) + { + png_bytep rp = row; + png_size_t i; + png_size_t istop = row_info->rowbytes; + + for (i = 0; i < istop; i++) + { + *rp = (png_byte)(~(*rp)); + rp++; + } + } + + else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA && + row_info->bit_depth == 8) + { + png_bytep rp = row; + png_size_t i; + png_size_t istop = row_info->rowbytes; + + for (i = 0; i < istop; i += 2) + { + *rp = (png_byte)(~(*rp)); + rp += 2; + } + } + +#ifdef PNG_16BIT_SUPPORTED + else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA && + row_info->bit_depth == 16) + { + png_bytep rp = row; + png_size_t i; + png_size_t istop = row_info->rowbytes; + + for (i = 0; i < istop; i += 4) + { + *rp = (png_byte)(~(*rp)); + *(rp + 1) = (png_byte)(~(*(rp + 1))); + rp += 4; + } + } +#endif +} +#endif + +#ifdef PNG_16BIT_SUPPORTED +#if defined(PNG_READ_SWAP_SUPPORTED) || defined(PNG_WRITE_SWAP_SUPPORTED) +/* Swaps byte order on 16 bit depth images */ +void /* PRIVATE */ +png_do_swap(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_swap"); + + if (row_info->bit_depth == 16) + { + png_bytep rp = row; + png_uint_32 i; + png_uint_32 istop= row_info->width * row_info->channels; + + for (i = 0; i < istop; i++, rp += 2) + { + png_byte t = *rp; + *rp = *(rp + 1); + *(rp + 1) = t; + } + } +} +#endif +#endif + +#if defined(PNG_READ_PACKSWAP_SUPPORTED)||defined(PNG_WRITE_PACKSWAP_SUPPORTED) +static PNG_CONST png_byte onebppswaptable[256] = { + 0x00, 0x80, 0x40, 0xC0, 0x20, 0xA0, 0x60, 0xE0, + 0x10, 0x90, 0x50, 0xD0, 0x30, 0xB0, 0x70, 0xF0, + 0x08, 0x88, 0x48, 0xC8, 0x28, 0xA8, 0x68, 0xE8, + 0x18, 0x98, 0x58, 0xD8, 0x38, 0xB8, 0x78, 0xF8, + 0x04, 0x84, 0x44, 0xC4, 0x24, 0xA4, 0x64, 0xE4, + 0x14, 0x94, 0x54, 0xD4, 0x34, 0xB4, 0x74, 0xF4, + 0x0C, 0x8C, 0x4C, 0xCC, 0x2C, 0xAC, 0x6C, 0xEC, + 0x1C, 0x9C, 0x5C, 0xDC, 0x3C, 0xBC, 0x7C, 0xFC, + 0x02, 0x82, 0x42, 0xC2, 0x22, 0xA2, 0x62, 0xE2, + 0x12, 0x92, 0x52, 0xD2, 0x32, 0xB2, 0x72, 0xF2, + 0x0A, 0x8A, 0x4A, 0xCA, 0x2A, 0xAA, 0x6A, 0xEA, + 0x1A, 0x9A, 0x5A, 0xDA, 0x3A, 0xBA, 0x7A, 0xFA, + 0x06, 0x86, 0x46, 0xC6, 0x26, 0xA6, 0x66, 0xE6, + 0x16, 0x96, 0x56, 0xD6, 0x36, 0xB6, 0x76, 0xF6, + 0x0E, 0x8E, 0x4E, 0xCE, 0x2E, 0xAE, 0x6E, 0xEE, + 0x1E, 0x9E, 0x5E, 0xDE, 0x3E, 0xBE, 0x7E, 0xFE, + 0x01, 0x81, 0x41, 0xC1, 0x21, 0xA1, 0x61, 0xE1, + 0x11, 0x91, 0x51, 0xD1, 0x31, 0xB1, 0x71, 0xF1, + 0x09, 0x89, 0x49, 0xC9, 0x29, 0xA9, 0x69, 0xE9, + 0x19, 0x99, 0x59, 0xD9, 0x39, 0xB9, 0x79, 0xF9, + 0x05, 0x85, 0x45, 0xC5, 0x25, 0xA5, 0x65, 0xE5, + 0x15, 0x95, 0x55, 0xD5, 0x35, 0xB5, 0x75, 0xF5, + 0x0D, 0x8D, 0x4D, 0xCD, 0x2D, 0xAD, 0x6D, 0xED, + 0x1D, 0x9D, 0x5D, 0xDD, 0x3D, 0xBD, 0x7D, 0xFD, + 0x03, 0x83, 0x43, 0xC3, 0x23, 0xA3, 0x63, 0xE3, + 0x13, 0x93, 0x53, 0xD3, 0x33, 0xB3, 0x73, 0xF3, + 0x0B, 0x8B, 0x4B, 0xCB, 0x2B, 0xAB, 0x6B, 0xEB, + 0x1B, 0x9B, 0x5B, 0xDB, 0x3B, 0xBB, 0x7B, 0xFB, + 0x07, 0x87, 0x47, 0xC7, 0x27, 0xA7, 0x67, 0xE7, + 0x17, 0x97, 0x57, 0xD7, 0x37, 0xB7, 0x77, 0xF7, + 0x0F, 0x8F, 0x4F, 0xCF, 0x2F, 0xAF, 0x6F, 0xEF, + 0x1F, 0x9F, 0x5F, 0xDF, 0x3F, 0xBF, 0x7F, 0xFF +}; + +static PNG_CONST png_byte twobppswaptable[256] = { + 0x00, 0x40, 0x80, 0xC0, 0x10, 0x50, 0x90, 0xD0, + 0x20, 0x60, 0xA0, 0xE0, 0x30, 0x70, 0xB0, 0xF0, + 0x04, 0x44, 0x84, 0xC4, 0x14, 0x54, 0x94, 0xD4, + 0x24, 0x64, 0xA4, 0xE4, 0x34, 0x74, 0xB4, 0xF4, + 0x08, 0x48, 0x88, 0xC8, 0x18, 0x58, 0x98, 0xD8, + 0x28, 0x68, 0xA8, 0xE8, 0x38, 0x78, 0xB8, 0xF8, + 0x0C, 0x4C, 0x8C, 0xCC, 0x1C, 0x5C, 0x9C, 0xDC, + 0x2C, 0x6C, 0xAC, 0xEC, 0x3C, 0x7C, 0xBC, 0xFC, + 0x01, 0x41, 0x81, 0xC1, 0x11, 0x51, 0x91, 0xD1, + 0x21, 0x61, 0xA1, 0xE1, 0x31, 0x71, 0xB1, 0xF1, + 0x05, 0x45, 0x85, 0xC5, 0x15, 0x55, 0x95, 0xD5, + 0x25, 0x65, 0xA5, 0xE5, 0x35, 0x75, 0xB5, 0xF5, + 0x09, 0x49, 0x89, 0xC9, 0x19, 0x59, 0x99, 0xD9, + 0x29, 0x69, 0xA9, 0xE9, 0x39, 0x79, 0xB9, 0xF9, + 0x0D, 0x4D, 0x8D, 0xCD, 0x1D, 0x5D, 0x9D, 0xDD, + 0x2D, 0x6D, 0xAD, 0xED, 0x3D, 0x7D, 0xBD, 0xFD, + 0x02, 0x42, 0x82, 0xC2, 0x12, 0x52, 0x92, 0xD2, + 0x22, 0x62, 0xA2, 0xE2, 0x32, 0x72, 0xB2, 0xF2, + 0x06, 0x46, 0x86, 0xC6, 0x16, 0x56, 0x96, 0xD6, + 0x26, 0x66, 0xA6, 0xE6, 0x36, 0x76, 0xB6, 0xF6, + 0x0A, 0x4A, 0x8A, 0xCA, 0x1A, 0x5A, 0x9A, 0xDA, + 0x2A, 0x6A, 0xAA, 0xEA, 0x3A, 0x7A, 0xBA, 0xFA, + 0x0E, 0x4E, 0x8E, 0xCE, 0x1E, 0x5E, 0x9E, 0xDE, + 0x2E, 0x6E, 0xAE, 0xEE, 0x3E, 0x7E, 0xBE, 0xFE, + 0x03, 0x43, 0x83, 0xC3, 0x13, 0x53, 0x93, 0xD3, + 0x23, 0x63, 0xA3, 0xE3, 0x33, 0x73, 0xB3, 0xF3, + 0x07, 0x47, 0x87, 0xC7, 0x17, 0x57, 0x97, 0xD7, + 0x27, 0x67, 0xA7, 0xE7, 0x37, 0x77, 0xB7, 0xF7, + 0x0B, 0x4B, 0x8B, 0xCB, 0x1B, 0x5B, 0x9B, 0xDB, + 0x2B, 0x6B, 0xAB, 0xEB, 0x3B, 0x7B, 0xBB, 0xFB, + 0x0F, 0x4F, 0x8F, 0xCF, 0x1F, 0x5F, 0x9F, 0xDF, + 0x2F, 0x6F, 0xAF, 0xEF, 0x3F, 0x7F, 0xBF, 0xFF +}; + +static PNG_CONST png_byte fourbppswaptable[256] = { + 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, + 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0, + 0x01, 0x11, 0x21, 0x31, 0x41, 0x51, 0x61, 0x71, + 0x81, 0x91, 0xA1, 0xB1, 0xC1, 0xD1, 0xE1, 0xF1, + 0x02, 0x12, 0x22, 0x32, 0x42, 0x52, 0x62, 0x72, + 0x82, 0x92, 0xA2, 0xB2, 0xC2, 0xD2, 0xE2, 0xF2, + 0x03, 0x13, 0x23, 0x33, 0x43, 0x53, 0x63, 0x73, + 0x83, 0x93, 0xA3, 0xB3, 0xC3, 0xD3, 0xE3, 0xF3, + 0x04, 0x14, 0x24, 0x34, 0x44, 0x54, 0x64, 0x74, + 0x84, 0x94, 0xA4, 0xB4, 0xC4, 0xD4, 0xE4, 0xF4, + 0x05, 0x15, 0x25, 0x35, 0x45, 0x55, 0x65, 0x75, + 0x85, 0x95, 0xA5, 0xB5, 0xC5, 0xD5, 0xE5, 0xF5, + 0x06, 0x16, 0x26, 0x36, 0x46, 0x56, 0x66, 0x76, + 0x86, 0x96, 0xA6, 0xB6, 0xC6, 0xD6, 0xE6, 0xF6, + 0x07, 0x17, 0x27, 0x37, 0x47, 0x57, 0x67, 0x77, + 0x87, 0x97, 0xA7, 0xB7, 0xC7, 0xD7, 0xE7, 0xF7, + 0x08, 0x18, 0x28, 0x38, 0x48, 0x58, 0x68, 0x78, + 0x88, 0x98, 0xA8, 0xB8, 0xC8, 0xD8, 0xE8, 0xF8, + 0x09, 0x19, 0x29, 0x39, 0x49, 0x59, 0x69, 0x79, + 0x89, 0x99, 0xA9, 0xB9, 0xC9, 0xD9, 0xE9, 0xF9, + 0x0A, 0x1A, 0x2A, 0x3A, 0x4A, 0x5A, 0x6A, 0x7A, + 0x8A, 0x9A, 0xAA, 0xBA, 0xCA, 0xDA, 0xEA, 0xFA, + 0x0B, 0x1B, 0x2B, 0x3B, 0x4B, 0x5B, 0x6B, 0x7B, + 0x8B, 0x9B, 0xAB, 0xBB, 0xCB, 0xDB, 0xEB, 0xFB, + 0x0C, 0x1C, 0x2C, 0x3C, 0x4C, 0x5C, 0x6C, 0x7C, + 0x8C, 0x9C, 0xAC, 0xBC, 0xCC, 0xDC, 0xEC, 0xFC, + 0x0D, 0x1D, 0x2D, 0x3D, 0x4D, 0x5D, 0x6D, 0x7D, + 0x8D, 0x9D, 0xAD, 0xBD, 0xCD, 0xDD, 0xED, 0xFD, + 0x0E, 0x1E, 0x2E, 0x3E, 0x4E, 0x5E, 0x6E, 0x7E, + 0x8E, 0x9E, 0xAE, 0xBE, 0xCE, 0xDE, 0xEE, 0xFE, + 0x0F, 0x1F, 0x2F, 0x3F, 0x4F, 0x5F, 0x6F, 0x7F, + 0x8F, 0x9F, 0xAF, 0xBF, 0xCF, 0xDF, 0xEF, 0xFF +}; + +/* Swaps pixel packing order within bytes */ +void /* PRIVATE */ +png_do_packswap(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_packswap"); + + if (row_info->bit_depth < 8) + { + png_bytep rp; + png_const_bytep end, table; + + end = row + row_info->rowbytes; + + if (row_info->bit_depth == 1) + table = onebppswaptable; + + else if (row_info->bit_depth == 2) + table = twobppswaptable; + + else if (row_info->bit_depth == 4) + table = fourbppswaptable; + + else + return; + + for (rp = row; rp < end; rp++) + *rp = table[*rp]; + } +} +#endif /* PNG_READ_PACKSWAP_SUPPORTED or PNG_WRITE_PACKSWAP_SUPPORTED */ + +#if defined(PNG_WRITE_FILLER_SUPPORTED) || \ + defined(PNG_READ_STRIP_ALPHA_SUPPORTED) +/* Remove a channel - this used to be 'png_do_strip_filler' but it used a + * somewhat weird combination of flags to determine what to do. All the calls + * to png_do_strip_filler are changed in 1.5.2 to call this instead with the + * correct arguments. + * + * The routine isn't general - the channel must be the channel at the start or + * end (not in the middle) of each pixel. + */ +void /* PRIVATE */ +png_do_strip_channel(png_row_infop row_info, png_bytep row, int at_start) +{ + png_bytep sp = row; /* source pointer */ + png_bytep dp = row; /* destination pointer */ + png_bytep ep = row + row_info->rowbytes; /* One beyond end of row */ + + /* At the start sp will point to the first byte to copy and dp to where + * it is copied to. ep always points just beyond the end of the row, so + * the loop simply copies (channels-1) channels until sp reaches ep. + * + * at_start: 0 -- convert AG, XG, ARGB, XRGB, AAGG, XXGG, etc. + * nonzero -- convert GA, GX, RGBA, RGBX, GGAA, RRGGBBXX, etc. + */ + + /* GA, GX, XG cases */ + if (row_info->channels == 2) + { + if (row_info->bit_depth == 8) + { + if (at_start) /* Skip initial filler */ + ++sp; + else /* Skip initial channel and, for sp, the filler */ + sp += 2, ++dp; + + /* For a 1 pixel wide image there is nothing to do */ + while (sp < ep) + *dp++ = *sp, sp += 2; + + row_info->pixel_depth = 8; + } + + else if (row_info->bit_depth == 16) + { + if (at_start) /* Skip initial filler */ + sp += 2; + else /* Skip initial channel and, for sp, the filler */ + sp += 4, dp += 2; + + while (sp < ep) + *dp++ = *sp++, *dp++ = *sp, sp += 3; + + row_info->pixel_depth = 16; + } + + else + return; /* bad bit depth */ + + row_info->channels = 1; + + /* Finally fix the color type if it records an alpha channel */ + if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA) + row_info->color_type = PNG_COLOR_TYPE_GRAY; + } + + /* RGBA, RGBX, XRGB cases */ + else if (row_info->channels == 4) + { + if (row_info->bit_depth == 8) + { + if (at_start) /* Skip initial filler */ + ++sp; + else /* Skip initial channels and, for sp, the filler */ + sp += 4, dp += 3; + + /* Note that the loop adds 3 to dp and 4 to sp each time. */ + while (sp < ep) + *dp++ = *sp++, *dp++ = *sp++, *dp++ = *sp, sp += 2; + + row_info->pixel_depth = 24; + } + + else if (row_info->bit_depth == 16) + { + if (at_start) /* Skip initial filler */ + sp += 2; + else /* Skip initial channels and, for sp, the filler */ + sp += 8, dp += 6; + + while (sp < ep) + { + /* Copy 6 bytes, skip 2 */ + *dp++ = *sp++, *dp++ = *sp++; + *dp++ = *sp++, *dp++ = *sp++; + *dp++ = *sp++, *dp++ = *sp, sp += 3; + } + + row_info->pixel_depth = 48; + } + + else + return; /* bad bit depth */ + + row_info->channels = 3; + + /* Finally fix the color type if it records an alpha channel */ + if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + row_info->color_type = PNG_COLOR_TYPE_RGB; + } + + else + return; /* The filler channel has gone already */ + + /* Fix the rowbytes value. */ + row_info->rowbytes = dp-row; +} +#endif + +#if defined(PNG_READ_BGR_SUPPORTED) || defined(PNG_WRITE_BGR_SUPPORTED) +/* Swaps red and blue bytes within a pixel */ +void /* PRIVATE */ +png_do_bgr(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_bgr"); + + if ((row_info->color_type & PNG_COLOR_MASK_COLOR)) + { + png_uint_32 row_width = row_info->width; + if (row_info->bit_depth == 8) + { + if (row_info->color_type == PNG_COLOR_TYPE_RGB) + { + png_bytep rp; + png_uint_32 i; + + for (i = 0, rp = row; i < row_width; i++, rp += 3) + { + png_byte save = *rp; + *rp = *(rp + 2); + *(rp + 2) = save; + } + } + + else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + { + png_bytep rp; + png_uint_32 i; + + for (i = 0, rp = row; i < row_width; i++, rp += 4) + { + png_byte save = *rp; + *rp = *(rp + 2); + *(rp + 2) = save; + } + } + } + +#ifdef PNG_16BIT_SUPPORTED + else if (row_info->bit_depth == 16) + { + if (row_info->color_type == PNG_COLOR_TYPE_RGB) + { + png_bytep rp; + png_uint_32 i; + + for (i = 0, rp = row; i < row_width; i++, rp += 6) + { + png_byte save = *rp; + *rp = *(rp + 4); + *(rp + 4) = save; + save = *(rp + 1); + *(rp + 1) = *(rp + 5); + *(rp + 5) = save; + } + } + + else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + { + png_bytep rp; + png_uint_32 i; + + for (i = 0, rp = row; i < row_width; i++, rp += 8) + { + png_byte save = *rp; + *rp = *(rp + 4); + *(rp + 4) = save; + save = *(rp + 1); + *(rp + 1) = *(rp + 5); + *(rp + 5) = save; + } + } + } +#endif + } +} +#endif /* PNG_READ_BGR_SUPPORTED or PNG_WRITE_BGR_SUPPORTED */ + +#if defined(PNG_READ_CHECK_FOR_INVALID_INDEX_SUPPORTED) || \ + defined(PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED) +/* Added at libpng-1.5.10 */ +void /* PRIVATE */ +png_do_check_palette_indexes(png_structrp png_ptr, png_row_infop row_info) +{ + if (png_ptr->num_palette < (1 << row_info->bit_depth) && + png_ptr->num_palette > 0) /* num_palette can be 0 in MNG files */ + { + /* Calculations moved outside switch in an attempt to stop different + * compiler warnings. 'padding' is in *bits* within the last byte, it is + * an 'int' because pixel_depth becomes an 'int' in the expression below, + * and this calculation is used because it avoids warnings that other + * forms produced on either GCC or MSVC. + */ + int padding = (-row_info->pixel_depth * row_info->width) & 7; + png_bytep rp = png_ptr->row_buf + row_info->rowbytes; + + switch (row_info->bit_depth) + { + case 1: + { + /* in this case, all bytes must be 0 so we don't need + * to unpack the pixels except for the rightmost one. + */ + for (; rp > png_ptr->row_buf; rp--) + { + if (*rp >> padding != 0) + png_ptr->num_palette_max = 1; + padding = 0; + } + + break; + } + + case 2: + { + for (; rp > png_ptr->row_buf; rp--) + { + int i = ((*rp >> padding) & 0x03); + + if (i > png_ptr->num_palette_max) + png_ptr->num_palette_max = i; + + i = (((*rp >> padding) >> 2) & 0x03); + + if (i > png_ptr->num_palette_max) + png_ptr->num_palette_max = i; + + i = (((*rp >> padding) >> 4) & 0x03); + + if (i > png_ptr->num_palette_max) + png_ptr->num_palette_max = i; + + i = (((*rp >> padding) >> 6) & 0x03); + + if (i > png_ptr->num_palette_max) + png_ptr->num_palette_max = i; + + padding = 0; + } + + break; + } + + case 4: + { + for (; rp > png_ptr->row_buf; rp--) + { + int i = ((*rp >> padding) & 0x0f); + + if (i > png_ptr->num_palette_max) + png_ptr->num_palette_max = i; + + i = (((*rp >> padding) >> 4) & 0x0f); + + if (i > png_ptr->num_palette_max) + png_ptr->num_palette_max = i; + + padding = 0; + } + + break; + } + + case 8: + { + for (; rp > png_ptr->row_buf; rp--) + { + if (*rp > png_ptr->num_palette_max) + png_ptr->num_palette_max = (int) *rp; + } + + break; + } + + default: + break; + } + } +} +#endif /* PNG_CHECK_FOR_INVALID_INDEX_SUPPORTED */ + +#if defined(PNG_READ_USER_TRANSFORM_SUPPORTED) || \ + defined(PNG_WRITE_USER_TRANSFORM_SUPPORTED) +#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED +void PNGAPI +png_set_user_transform_info(png_structrp png_ptr, png_voidp + user_transform_ptr, int user_transform_depth, int user_transform_channels) +{ + png_debug(1, "in png_set_user_transform_info"); + + if (png_ptr == NULL) + return; + +#ifdef PNG_READ_USER_TRANSFORM_SUPPORTED + if ((png_ptr->mode & PNG_IS_READ_STRUCT) != 0 && + (png_ptr->flags & PNG_FLAG_ROW_INIT) != 0) + { + png_app_error(png_ptr, + "info change after png_start_read_image or png_read_update_info"); + return; + } +#endif + + png_ptr->user_transform_ptr = user_transform_ptr; + png_ptr->user_transform_depth = (png_byte)user_transform_depth; + png_ptr->user_transform_channels = (png_byte)user_transform_channels; +} +#endif + +/* This function returns a pointer to the user_transform_ptr associated with + * the user transform functions. The application should free any memory + * associated with this pointer before png_write_destroy and png_read_destroy + * are called. + */ +#ifdef PNG_USER_TRANSFORM_PTR_SUPPORTED +png_voidp PNGAPI +png_get_user_transform_ptr(png_const_structrp png_ptr) +{ + if (png_ptr == NULL) + return (NULL); + + return png_ptr->user_transform_ptr; +} +#endif + +#ifdef PNG_USER_TRANSFORM_INFO_SUPPORTED +png_uint_32 PNGAPI +png_get_current_row_number(png_const_structrp png_ptr) +{ + /* See the comments in png.h - this is the sub-image row when reading and + * interlaced image. + */ + if (png_ptr != NULL) + return png_ptr->row_number; + + return PNG_UINT_32_MAX; /* help the app not to fail silently */ +} + +png_byte PNGAPI +png_get_current_pass_number(png_const_structrp png_ptr) +{ + if (png_ptr != NULL) + return png_ptr->pass; + return 8; /* invalid */ +} +#endif /* PNG_USER_TRANSFORM_INFO_SUPPORTED */ +#endif /* PNG_READ_USER_TRANSFORM_SUPPORTED || + PNG_WRITE_USER_TRANSFORM_SUPPORTED */ +#endif /* PNG_READ_SUPPORTED || PNG_WRITE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngwio.c b/ml/dlib/dlib/external/libpng/pngwio.c new file mode 100644 index 000000000..e3289dfe4 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngwio.c @@ -0,0 +1,164 @@ + +/* pngwio.c - functions for data output + * + * Last changed in libpng 1.6.0 [February 14, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + * + * This file provides a location for all output. Users who need + * special handling are expected to write functions that have the same + * arguments as these and perform similar functions, but that possibly + * use different output methods. Note that you shouldn't change these + * functions, but rather write replacement functions and then change + * them at run time with png_set_write_fn(...). + */ + +#include "pngpriv.h" + +#ifdef PNG_WRITE_SUPPORTED + +/* Write the data to whatever output you are using. The default routine + * writes to a file pointer. Note that this routine sometimes gets called + * with very small lengths, so you should implement some kind of simple + * buffering if you are using unbuffered writes. This should never be asked + * to write more than 64K on a 16 bit machine. + */ + +void /* PRIVATE */ +png_write_data(png_structrp png_ptr, png_const_bytep data, png_size_t length) +{ + /* NOTE: write_data_fn must not change the buffer! */ + if (png_ptr->write_data_fn != NULL ) + (*(png_ptr->write_data_fn))(png_ptr, png_constcast(png_bytep,data), + length); + + else + png_error(png_ptr, "Call to NULL write function"); +} + +#ifdef PNG_STDIO_SUPPORTED +/* This is the function that does the actual writing of data. If you are + * not writing to a standard C stream, you should create a replacement + * write_data function and use it at run time with png_set_write_fn(), rather + * than changing the library. + */ +void PNGCBAPI +png_default_write_data(png_structp png_ptr, png_bytep data, png_size_t length) +{ + png_size_t check; + + if (png_ptr == NULL) + return; + + check = fwrite(data, 1, length, (png_FILE_p)(png_ptr->io_ptr)); + + if (check != length) + png_error(png_ptr, "Write Error"); +} +#endif + +/* This function is called to output any data pending writing (normally + * to disk). After png_flush is called, there should be no data pending + * writing in any buffers. + */ +#ifdef PNG_WRITE_FLUSH_SUPPORTED +void /* PRIVATE */ +png_flush(png_structrp png_ptr) +{ + if (png_ptr->output_flush_fn != NULL) + (*(png_ptr->output_flush_fn))(png_ptr); +} + +# ifdef PNG_STDIO_SUPPORTED +void PNGCBAPI +png_default_flush(png_structp png_ptr) +{ + png_FILE_p io_ptr; + + if (png_ptr == NULL) + return; + + io_ptr = png_voidcast(png_FILE_p, (png_ptr->io_ptr)); + fflush(io_ptr); +} +# endif +#endif + +/* This function allows the application to supply new output functions for + * libpng if standard C streams aren't being used. + * + * This function takes as its arguments: + * png_ptr - pointer to a png output data structure + * io_ptr - pointer to user supplied structure containing info about + * the output functions. May be NULL. + * write_data_fn - pointer to a new output function that takes as its + * arguments a pointer to a png_struct, a pointer to + * data to be written, and a 32-bit unsigned int that is + * the number of bytes to be written. The new write + * function should call png_error(png_ptr, "Error msg") + * to exit and output any fatal error messages. May be + * NULL, in which case libpng's default function will + * be used. + * flush_data_fn - pointer to a new flush function that takes as its + * arguments a pointer to a png_struct. After a call to + * the flush function, there should be no data in any buffers + * or pending transmission. If the output method doesn't do + * any buffering of output, a function prototype must still be + * supplied although it doesn't have to do anything. If + * PNG_WRITE_FLUSH_SUPPORTED is not defined at libpng compile + * time, output_flush_fn will be ignored, although it must be + * supplied for compatibility. May be NULL, in which case + * libpng's default function will be used, if + * PNG_WRITE_FLUSH_SUPPORTED is defined. This is not + * a good idea if io_ptr does not point to a standard + * *FILE structure. + */ +void PNGAPI +png_set_write_fn(png_structrp png_ptr, png_voidp io_ptr, + png_rw_ptr write_data_fn, png_flush_ptr output_flush_fn) +{ + if (png_ptr == NULL) + return; + + png_ptr->io_ptr = io_ptr; + +#ifdef PNG_STDIO_SUPPORTED + if (write_data_fn != NULL) + png_ptr->write_data_fn = write_data_fn; + + else + png_ptr->write_data_fn = png_default_write_data; +#else + png_ptr->write_data_fn = write_data_fn; +#endif + +#ifdef PNG_WRITE_FLUSH_SUPPORTED +# ifdef PNG_STDIO_SUPPORTED + + if (output_flush_fn != NULL) + png_ptr->output_flush_fn = output_flush_fn; + + else + png_ptr->output_flush_fn = png_default_flush; + +# else + png_ptr->output_flush_fn = output_flush_fn; +# endif +#endif /* PNG_WRITE_FLUSH_SUPPORTED */ + + /* It is an error to read while writing a png file */ + if (png_ptr->read_data_fn != NULL) + { + png_ptr->read_data_fn = NULL; + + png_warning(png_ptr, + "Can't set both read_data_fn and write_data_fn in the" + " same structure"); + } +} +#endif /* PNG_WRITE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngwrite.c b/ml/dlib/dlib/external/libpng/pngwrite.c new file mode 100644 index 000000000..b71a3d345 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngwrite.c @@ -0,0 +1,2330 @@ + +/* pngwrite.c - general routines to write a PNG file + * + * Last changed in libpng 1.6.2 [April 25, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +#include "pngpriv.h" +#if defined(PNG_SIMPLIFIED_WRITE_SUPPORTED) && defined(PNG_STDIO_SUPPORTED) +# include +#endif + +#ifdef PNG_WRITE_SUPPORTED + +#ifdef PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED +/* Write out all the unknown chunks for the current given location */ +static void +write_unknown_chunks(png_structrp png_ptr, png_const_inforp info_ptr, + unsigned int where) +{ + if (info_ptr->unknown_chunks_num) + { + png_const_unknown_chunkp up; + + png_debug(5, "writing extra chunks"); + + for (up = info_ptr->unknown_chunks; + up < info_ptr->unknown_chunks + info_ptr->unknown_chunks_num; + ++up) + if (up->location & where) + { + /* If per-chunk unknown chunk handling is enabled use it, otherwise + * just write the chunks the application has set. + */ +#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED + int keep = png_handle_as_unknown(png_ptr, up->name); + + /* NOTE: this code is radically different from the read side in the + * matter of handling an ancillary unknown chunk. In the read side + * the default behavior is to discard it, in the code below the default + * behavior is to write it. Critical chunks are, however, only + * written if explicitly listed or if the default is set to write all + * unknown chunks. + * + * The default handling is also slightly weird - it is not possible to + * stop the writing of all unsafe-to-copy chunks! + * + * TODO: REVIEW: this would seem to be a bug. + */ + if (keep != PNG_HANDLE_CHUNK_NEVER && + ((up->name[3] & 0x20) /* safe-to-copy overrides everything */ || + keep == PNG_HANDLE_CHUNK_ALWAYS || + (keep == PNG_HANDLE_CHUNK_AS_DEFAULT && + png_ptr->unknown_default == PNG_HANDLE_CHUNK_ALWAYS))) +#endif + { + /* TODO: review, what is wrong with a zero length unknown chunk? */ + if (up->size == 0) + png_warning(png_ptr, "Writing zero-length unknown chunk"); + + png_write_chunk(png_ptr, up->name, up->data, up->size); + } + } + } +} +#endif /* PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED */ + +/* Writes all the PNG information. This is the suggested way to use the + * library. If you have a new chunk to add, make a function to write it, + * and put it in the correct location here. If you want the chunk written + * after the image data, put it in png_write_end(). I strongly encourage + * you to supply a PNG_INFO_ flag, and check info_ptr->valid before writing + * the chunk, as that will keep the code from breaking if you want to just + * write a plain PNG file. If you have long comments, I suggest writing + * them in png_write_end(), and compressing them. + */ +void PNGAPI +png_write_info_before_PLTE(png_structrp png_ptr, png_const_inforp info_ptr) +{ + png_debug(1, "in png_write_info_before_PLTE"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + if (!(png_ptr->mode & PNG_WROTE_INFO_BEFORE_PLTE)) + { + /* Write PNG signature */ + png_write_sig(png_ptr); + +#ifdef PNG_MNG_FEATURES_SUPPORTED + if ((png_ptr->mode&PNG_HAVE_PNG_SIGNATURE) && \ + (png_ptr->mng_features_permitted)) + { + png_warning(png_ptr, "MNG features are not allowed in a PNG datastream"); + png_ptr->mng_features_permitted = 0; + } +#endif + + /* Write IHDR information. */ + png_write_IHDR(png_ptr, info_ptr->width, info_ptr->height, + info_ptr->bit_depth, info_ptr->color_type, info_ptr->compression_type, + info_ptr->filter_type, +#ifdef PNG_WRITE_INTERLACING_SUPPORTED + info_ptr->interlace_type +#else + 0 +#endif + ); + + /* The rest of these check to see if the valid field has the appropriate + * flag set, and if it does, writes the chunk. + * + * 1.6.0: COLORSPACE support controls the writing of these chunks too, and + * the chunks will be written if the WRITE routine is there and information + * is available in the COLORSPACE. (See png_colorspace_sync_info in png.c + * for where the valid flags get set.) + * + * Under certain circumstances the colorspace can be invalidated without + * syncing the info_struct 'valid' flags; this happens if libpng detects and + * error and calls png_error while the color space is being set, yet the + * application continues writing the PNG. So check the 'invalid' flag here + * too. + */ +#ifdef PNG_GAMMA_SUPPORTED +# ifdef PNG_WRITE_gAMA_SUPPORTED + if (!(info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) && + (info_ptr->colorspace.flags & PNG_COLORSPACE_FROM_gAMA) && + (info_ptr->valid & PNG_INFO_gAMA)) + png_write_gAMA_fixed(png_ptr, info_ptr->colorspace.gamma); +# endif +#endif + +#ifdef PNG_COLORSPACE_SUPPORTED + /* Write only one of sRGB or an ICC profile. If a profile was supplied + * and it matches one of the known sRGB ones issue a warning. + */ +# ifdef PNG_WRITE_iCCP_SUPPORTED + if (!(info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) && + (info_ptr->valid & PNG_INFO_iCCP)) + { +# ifdef PNG_WRITE_sRGB_SUPPORTED + if (info_ptr->valid & PNG_INFO_sRGB) + png_app_warning(png_ptr, + "profile matches sRGB but writing iCCP instead"); +# endif + + png_write_iCCP(png_ptr, info_ptr->iccp_name, + info_ptr->iccp_profile); + } +# ifdef PNG_WRITE_sRGB_SUPPORTED + else +# endif +# endif + +# ifdef PNG_WRITE_sRGB_SUPPORTED + if (!(info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) && + (info_ptr->valid & PNG_INFO_sRGB)) + png_write_sRGB(png_ptr, info_ptr->colorspace.rendering_intent); +# endif /* WRITE_sRGB */ +#endif /* COLORSPACE */ + +#ifdef PNG_WRITE_sBIT_SUPPORTED + if (info_ptr->valid & PNG_INFO_sBIT) + png_write_sBIT(png_ptr, &(info_ptr->sig_bit), info_ptr->color_type); +#endif + +#ifdef PNG_COLORSPACE_SUPPORTED +# ifdef PNG_WRITE_cHRM_SUPPORTED + if (!(info_ptr->colorspace.flags & PNG_COLORSPACE_INVALID) && + (info_ptr->colorspace.flags & PNG_COLORSPACE_FROM_cHRM) && + (info_ptr->valid & PNG_INFO_cHRM)) + png_write_cHRM_fixed(png_ptr, &info_ptr->colorspace.end_points_xy); +# endif +#endif + +#ifdef PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED + write_unknown_chunks(png_ptr, info_ptr, PNG_HAVE_IHDR); +#endif + + png_ptr->mode |= PNG_WROTE_INFO_BEFORE_PLTE; + } +} + +void PNGAPI +png_write_info(png_structrp png_ptr, png_const_inforp info_ptr) +{ +#if defined(PNG_WRITE_TEXT_SUPPORTED) || defined(PNG_WRITE_sPLT_SUPPORTED) + int i; +#endif + + png_debug(1, "in png_write_info"); + + if (png_ptr == NULL || info_ptr == NULL) + return; + + png_write_info_before_PLTE(png_ptr, info_ptr); + + if (info_ptr->valid & PNG_INFO_PLTE) + png_write_PLTE(png_ptr, info_ptr->palette, + (png_uint_32)info_ptr->num_palette); + + else if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + png_error(png_ptr, "Valid palette required for paletted images"); + +#ifdef PNG_WRITE_tRNS_SUPPORTED + if (info_ptr->valid & PNG_INFO_tRNS) + { +#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED + /* Invert the alpha channel (in tRNS) */ + if ((png_ptr->transformations & PNG_INVERT_ALPHA) && + info_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + { + int j; + for (j = 0; j<(int)info_ptr->num_trans; j++) + info_ptr->trans_alpha[j] = + (png_byte)(255 - info_ptr->trans_alpha[j]); + } +#endif + png_write_tRNS(png_ptr, info_ptr->trans_alpha, &(info_ptr->trans_color), + info_ptr->num_trans, info_ptr->color_type); + } +#endif +#ifdef PNG_WRITE_bKGD_SUPPORTED + if (info_ptr->valid & PNG_INFO_bKGD) + png_write_bKGD(png_ptr, &(info_ptr->background), info_ptr->color_type); +#endif + +#ifdef PNG_WRITE_hIST_SUPPORTED + if (info_ptr->valid & PNG_INFO_hIST) + png_write_hIST(png_ptr, info_ptr->hist, info_ptr->num_palette); +#endif + +#ifdef PNG_WRITE_oFFs_SUPPORTED + if (info_ptr->valid & PNG_INFO_oFFs) + png_write_oFFs(png_ptr, info_ptr->x_offset, info_ptr->y_offset, + info_ptr->offset_unit_type); +#endif + +#ifdef PNG_WRITE_pCAL_SUPPORTED + if (info_ptr->valid & PNG_INFO_pCAL) + png_write_pCAL(png_ptr, info_ptr->pcal_purpose, info_ptr->pcal_X0, + info_ptr->pcal_X1, info_ptr->pcal_type, info_ptr->pcal_nparams, + info_ptr->pcal_units, info_ptr->pcal_params); +#endif + +#ifdef PNG_WRITE_sCAL_SUPPORTED + if (info_ptr->valid & PNG_INFO_sCAL) + png_write_sCAL_s(png_ptr, (int)info_ptr->scal_unit, + info_ptr->scal_s_width, info_ptr->scal_s_height); +#endif /* sCAL */ + +#ifdef PNG_WRITE_pHYs_SUPPORTED + if (info_ptr->valid & PNG_INFO_pHYs) + png_write_pHYs(png_ptr, info_ptr->x_pixels_per_unit, + info_ptr->y_pixels_per_unit, info_ptr->phys_unit_type); +#endif /* pHYs */ + +#ifdef PNG_WRITE_tIME_SUPPORTED + if (info_ptr->valid & PNG_INFO_tIME) + { + png_write_tIME(png_ptr, &(info_ptr->mod_time)); + png_ptr->mode |= PNG_WROTE_tIME; + } +#endif /* tIME */ + +#ifdef PNG_WRITE_sPLT_SUPPORTED + if (info_ptr->valid & PNG_INFO_sPLT) + for (i = 0; i < (int)info_ptr->splt_palettes_num; i++) + png_write_sPLT(png_ptr, info_ptr->splt_palettes + i); +#endif /* sPLT */ + +#ifdef PNG_WRITE_TEXT_SUPPORTED + /* Check to see if we need to write text chunks */ + for (i = 0; i < info_ptr->num_text; i++) + { + png_debug2(2, "Writing header text chunk %d, type %d", i, + info_ptr->text[i].compression); + /* An internationalized chunk? */ + if (info_ptr->text[i].compression > 0) + { +#ifdef PNG_WRITE_iTXt_SUPPORTED + /* Write international chunk */ + png_write_iTXt(png_ptr, + info_ptr->text[i].compression, + info_ptr->text[i].key, + info_ptr->text[i].lang, + info_ptr->text[i].lang_key, + info_ptr->text[i].text); +#else + png_warning(png_ptr, "Unable to write international text"); +#endif + /* Mark this chunk as written */ + info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_NONE_WR; + } + + /* If we want a compressed text chunk */ + else if (info_ptr->text[i].compression == PNG_TEXT_COMPRESSION_zTXt) + { +#ifdef PNG_WRITE_zTXt_SUPPORTED + /* Write compressed chunk */ + png_write_zTXt(png_ptr, info_ptr->text[i].key, + info_ptr->text[i].text, 0, + info_ptr->text[i].compression); +#else + png_warning(png_ptr, "Unable to write compressed text"); +#endif + /* Mark this chunk as written */ + info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_zTXt_WR; + } + + else if (info_ptr->text[i].compression == PNG_TEXT_COMPRESSION_NONE) + { +#ifdef PNG_WRITE_tEXt_SUPPORTED + /* Write uncompressed chunk */ + png_write_tEXt(png_ptr, info_ptr->text[i].key, + info_ptr->text[i].text, + 0); + /* Mark this chunk as written */ + info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_NONE_WR; +#else + /* Can't get here */ + png_warning(png_ptr, "Unable to write uncompressed text"); +#endif + } + } +#endif /* tEXt */ + +#ifdef PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED + write_unknown_chunks(png_ptr, info_ptr, PNG_HAVE_PLTE); +#endif +} + +/* Writes the end of the PNG file. If you don't want to write comments or + * time information, you can pass NULL for info. If you already wrote these + * in png_write_info(), do not write them again here. If you have long + * comments, I suggest writing them here, and compressing them. + */ +void PNGAPI +png_write_end(png_structrp png_ptr, png_inforp info_ptr) +{ + png_debug(1, "in png_write_end"); + + if (png_ptr == NULL) + return; + + if (!(png_ptr->mode & PNG_HAVE_IDAT)) + png_error(png_ptr, "No IDATs written into file"); + +#ifdef PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED + if (png_ptr->num_palette_max > png_ptr->num_palette) + png_benign_error(png_ptr, "Wrote palette index exceeding num_palette"); +#endif + + /* See if user wants us to write information chunks */ + if (info_ptr != NULL) + { +#ifdef PNG_WRITE_TEXT_SUPPORTED + int i; /* local index variable */ +#endif +#ifdef PNG_WRITE_tIME_SUPPORTED + /* Check to see if user has supplied a time chunk */ + if ((info_ptr->valid & PNG_INFO_tIME) && + !(png_ptr->mode & PNG_WROTE_tIME)) + png_write_tIME(png_ptr, &(info_ptr->mod_time)); + +#endif +#ifdef PNG_WRITE_TEXT_SUPPORTED + /* Loop through comment chunks */ + for (i = 0; i < info_ptr->num_text; i++) + { + png_debug2(2, "Writing trailer text chunk %d, type %d", i, + info_ptr->text[i].compression); + /* An internationalized chunk? */ + if (info_ptr->text[i].compression > 0) + { +#ifdef PNG_WRITE_iTXt_SUPPORTED + /* Write international chunk */ + png_write_iTXt(png_ptr, + info_ptr->text[i].compression, + info_ptr->text[i].key, + info_ptr->text[i].lang, + info_ptr->text[i].lang_key, + info_ptr->text[i].text); +#else + png_warning(png_ptr, "Unable to write international text"); +#endif + /* Mark this chunk as written */ + info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_NONE_WR; + } + + else if (info_ptr->text[i].compression >= PNG_TEXT_COMPRESSION_zTXt) + { +#ifdef PNG_WRITE_zTXt_SUPPORTED + /* Write compressed chunk */ + png_write_zTXt(png_ptr, info_ptr->text[i].key, + info_ptr->text[i].text, 0, + info_ptr->text[i].compression); +#else + png_warning(png_ptr, "Unable to write compressed text"); +#endif + /* Mark this chunk as written */ + info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_zTXt_WR; + } + + else if (info_ptr->text[i].compression == PNG_TEXT_COMPRESSION_NONE) + { +#ifdef PNG_WRITE_tEXt_SUPPORTED + /* Write uncompressed chunk */ + png_write_tEXt(png_ptr, info_ptr->text[i].key, + info_ptr->text[i].text, 0); +#else + png_warning(png_ptr, "Unable to write uncompressed text"); +#endif + + /* Mark this chunk as written */ + info_ptr->text[i].compression = PNG_TEXT_COMPRESSION_NONE_WR; + } + } +#endif +#ifdef PNG_WRITE_UNKNOWN_CHUNKS_SUPPORTED + write_unknown_chunks(png_ptr, info_ptr, PNG_AFTER_IDAT); +#endif + } + + png_ptr->mode |= PNG_AFTER_IDAT; + + /* Write end of PNG file */ + png_write_IEND(png_ptr); + /* This flush, added in libpng-1.0.8, removed from libpng-1.0.9beta03, + * and restored again in libpng-1.2.30, may cause some applications that + * do not set png_ptr->output_flush_fn to crash. If your application + * experiences a problem, please try building libpng with + * PNG_WRITE_FLUSH_AFTER_IEND_SUPPORTED defined, and report the event to + * png-mng-implement at lists.sf.net . + */ +#ifdef PNG_WRITE_FLUSH_SUPPORTED +# ifdef PNG_WRITE_FLUSH_AFTER_IEND_SUPPORTED + png_flush(png_ptr); +# endif +#endif +} + +#ifdef PNG_CONVERT_tIME_SUPPORTED +void PNGAPI +png_convert_from_struct_tm(png_timep ptime, PNG_CONST struct tm * ttime) +{ + png_debug(1, "in png_convert_from_struct_tm"); + + ptime->year = (png_uint_16)(1900 + ttime->tm_year); + ptime->month = (png_byte)(ttime->tm_mon + 1); + ptime->day = (png_byte)ttime->tm_mday; + ptime->hour = (png_byte)ttime->tm_hour; + ptime->minute = (png_byte)ttime->tm_min; + ptime->second = (png_byte)ttime->tm_sec; +} + +void PNGAPI +png_convert_from_time_t(png_timep ptime, time_t ttime) +{ + struct tm *tbuf; + + png_debug(1, "in png_convert_from_time_t"); + + tbuf = gmtime(&ttime); + png_convert_from_struct_tm(ptime, tbuf); +} +#endif + +/* Initialize png_ptr structure, and allocate any memory needed */ +PNG_FUNCTION(png_structp,PNGAPI +png_create_write_struct,(png_const_charp user_png_ver, png_voidp error_ptr, + png_error_ptr error_fn, png_error_ptr warn_fn),PNG_ALLOCATED) +{ +#ifndef PNG_USER_MEM_SUPPORTED + png_structrp png_ptr = png_create_png_struct(user_png_ver, error_ptr, + error_fn, warn_fn, NULL, NULL, NULL); +#else + return png_create_write_struct_2(user_png_ver, error_ptr, error_fn, + warn_fn, NULL, NULL, NULL); +} + +/* Alternate initialize png_ptr structure, and allocate any memory needed */ +PNG_FUNCTION(png_structp,PNGAPI +png_create_write_struct_2,(png_const_charp user_png_ver, png_voidp error_ptr, + png_error_ptr error_fn, png_error_ptr warn_fn, png_voidp mem_ptr, + png_malloc_ptr malloc_fn, png_free_ptr free_fn),PNG_ALLOCATED) +{ + png_structrp png_ptr = png_create_png_struct(user_png_ver, error_ptr, + error_fn, warn_fn, mem_ptr, malloc_fn, free_fn); +#endif /* PNG_USER_MEM_SUPPORTED */ + if (png_ptr != NULL) + { + /* Set the zlib control values to defaults; they can be overridden by the + * application after the struct has been created. + */ + png_ptr->zbuffer_size = PNG_ZBUF_SIZE; + + /* The 'zlib_strategy' setting is irrelevant because png_default_claim in + * pngwutil.c defaults it according to whether or not filters will be + * used, and ignores this setting. + */ + png_ptr->zlib_strategy = PNG_Z_DEFAULT_STRATEGY; + png_ptr->zlib_level = PNG_Z_DEFAULT_COMPRESSION; + png_ptr->zlib_mem_level = 8; + png_ptr->zlib_window_bits = 15; + png_ptr->zlib_method = 8; + +#ifdef PNG_WRITE_COMPRESSED_TEXT_SUPPORTED + png_ptr->zlib_text_strategy = PNG_TEXT_Z_DEFAULT_STRATEGY; + png_ptr->zlib_text_level = PNG_TEXT_Z_DEFAULT_COMPRESSION; + png_ptr->zlib_text_mem_level = 8; + png_ptr->zlib_text_window_bits = 15; + png_ptr->zlib_text_method = 8; +#endif /* PNG_WRITE_COMPRESSED_TEXT_SUPPORTED */ + + /* This is a highly dubious configuration option; by default it is off, + * but it may be appropriate for private builds that are testing + * extensions not conformant to the current specification, or of + * applications that must not fail to write at all costs! + */ +#ifdef PNG_BENIGN_WRITE_ERRORS_SUPPORTED + png_ptr->flags |= PNG_FLAG_BENIGN_ERRORS_WARN; + /* In stable builds only warn if an application error can be completely + * handled. + */ +#endif + + /* App warnings are warnings in release (or release candidate) builds but + * are errors during development. + */ +#if PNG_LIBPNG_BUILD_BASE_TYPE >= PNG_LIBPNG_BUILD_RC + png_ptr->flags |= PNG_FLAG_APP_WARNINGS_WARN; +#endif + + /* TODO: delay this, it can be done in png_init_io() (if the app doesn't + * do it itself) avoiding setting the default function if it is not + * required. + */ + png_set_write_fn(png_ptr, NULL, NULL, NULL); + } + + return png_ptr; +} + + +/* Write a few rows of image data. If the image is interlaced, + * either you will have to write the 7 sub images, or, if you + * have called png_set_interlace_handling(), you will have to + * "write" the image seven times. + */ +void PNGAPI +png_write_rows(png_structrp png_ptr, png_bytepp row, + png_uint_32 num_rows) +{ + png_uint_32 i; /* row counter */ + png_bytepp rp; /* row pointer */ + + png_debug(1, "in png_write_rows"); + + if (png_ptr == NULL) + return; + + /* Loop through the rows */ + for (i = 0, rp = row; i < num_rows; i++, rp++) + { + png_write_row(png_ptr, *rp); + } +} + +/* Write the image. You only need to call this function once, even + * if you are writing an interlaced image. + */ +void PNGAPI +png_write_image(png_structrp png_ptr, png_bytepp image) +{ + png_uint_32 i; /* row index */ + int pass, num_pass; /* pass variables */ + png_bytepp rp; /* points to current row */ + + if (png_ptr == NULL) + return; + + png_debug(1, "in png_write_image"); + +#ifdef PNG_WRITE_INTERLACING_SUPPORTED + /* Initialize interlace handling. If image is not interlaced, + * this will set pass to 1 + */ + num_pass = png_set_interlace_handling(png_ptr); +#else + num_pass = 1; +#endif + /* Loop through passes */ + for (pass = 0; pass < num_pass; pass++) + { + /* Loop through image */ + for (i = 0, rp = image; i < png_ptr->height; i++, rp++) + { + png_write_row(png_ptr, *rp); + } + } +} + +/* Called by user to write a row of image data */ +void PNGAPI +png_write_row(png_structrp png_ptr, png_const_bytep row) +{ + /* 1.5.6: moved from png_struct to be a local structure: */ + png_row_info row_info; + + if (png_ptr == NULL) + return; + + png_debug2(1, "in png_write_row (row %u, pass %d)", + png_ptr->row_number, png_ptr->pass); + + /* Initialize transformations and other stuff if first time */ + if (png_ptr->row_number == 0 && png_ptr->pass == 0) + { + /* Make sure we wrote the header info */ + if (!(png_ptr->mode & PNG_WROTE_INFO_BEFORE_PLTE)) + png_error(png_ptr, + "png_write_info was never called before png_write_row"); + + /* Check for transforms that have been set but were defined out */ +#if !defined(PNG_WRITE_INVERT_SUPPORTED) && defined(PNG_READ_INVERT_SUPPORTED) + if (png_ptr->transformations & PNG_INVERT_MONO) + png_warning(png_ptr, "PNG_WRITE_INVERT_SUPPORTED is not defined"); +#endif + +#if !defined(PNG_WRITE_FILLER_SUPPORTED) && defined(PNG_READ_FILLER_SUPPORTED) + if (png_ptr->transformations & PNG_FILLER) + png_warning(png_ptr, "PNG_WRITE_FILLER_SUPPORTED is not defined"); +#endif +#if !defined(PNG_WRITE_PACKSWAP_SUPPORTED) && \ + defined(PNG_READ_PACKSWAP_SUPPORTED) + if (png_ptr->transformations & PNG_PACKSWAP) + png_warning(png_ptr, + "PNG_WRITE_PACKSWAP_SUPPORTED is not defined"); +#endif + +#if !defined(PNG_WRITE_PACK_SUPPORTED) && defined(PNG_READ_PACK_SUPPORTED) + if (png_ptr->transformations & PNG_PACK) + png_warning(png_ptr, "PNG_WRITE_PACK_SUPPORTED is not defined"); +#endif + +#if !defined(PNG_WRITE_SHIFT_SUPPORTED) && defined(PNG_READ_SHIFT_SUPPORTED) + if (png_ptr->transformations & PNG_SHIFT) + png_warning(png_ptr, "PNG_WRITE_SHIFT_SUPPORTED is not defined"); +#endif + +#if !defined(PNG_WRITE_BGR_SUPPORTED) && defined(PNG_READ_BGR_SUPPORTED) + if (png_ptr->transformations & PNG_BGR) + png_warning(png_ptr, "PNG_WRITE_BGR_SUPPORTED is not defined"); +#endif + +#if !defined(PNG_WRITE_SWAP_SUPPORTED) && defined(PNG_READ_SWAP_SUPPORTED) + if (png_ptr->transformations & PNG_SWAP_BYTES) + png_warning(png_ptr, "PNG_WRITE_SWAP_SUPPORTED is not defined"); +#endif + + png_write_start_row(png_ptr); + } + +#ifdef PNG_WRITE_INTERLACING_SUPPORTED + /* If interlaced and not interested in row, return */ + if (png_ptr->interlaced && (png_ptr->transformations & PNG_INTERLACE)) + { + switch (png_ptr->pass) + { + case 0: + if (png_ptr->row_number & 0x07) + { + png_write_finish_row(png_ptr); + return; + } + break; + + case 1: + if ((png_ptr->row_number & 0x07) || png_ptr->width < 5) + { + png_write_finish_row(png_ptr); + return; + } + break; + + case 2: + if ((png_ptr->row_number & 0x07) != 4) + { + png_write_finish_row(png_ptr); + return; + } + break; + + case 3: + if ((png_ptr->row_number & 0x03) || png_ptr->width < 3) + { + png_write_finish_row(png_ptr); + return; + } + break; + + case 4: + if ((png_ptr->row_number & 0x03) != 2) + { + png_write_finish_row(png_ptr); + return; + } + break; + + case 5: + if ((png_ptr->row_number & 0x01) || png_ptr->width < 2) + { + png_write_finish_row(png_ptr); + return; + } + break; + + case 6: + if (!(png_ptr->row_number & 0x01)) + { + png_write_finish_row(png_ptr); + return; + } + break; + + default: /* error: ignore it */ + break; + } + } +#endif + + /* Set up row info for transformations */ + row_info.color_type = png_ptr->color_type; + row_info.width = png_ptr->usr_width; + row_info.channels = png_ptr->usr_channels; + row_info.bit_depth = png_ptr->usr_bit_depth; + row_info.pixel_depth = (png_byte)(row_info.bit_depth * row_info.channels); + row_info.rowbytes = PNG_ROWBYTES(row_info.pixel_depth, row_info.width); + + png_debug1(3, "row_info->color_type = %d", row_info.color_type); + png_debug1(3, "row_info->width = %u", row_info.width); + png_debug1(3, "row_info->channels = %d", row_info.channels); + png_debug1(3, "row_info->bit_depth = %d", row_info.bit_depth); + png_debug1(3, "row_info->pixel_depth = %d", row_info.pixel_depth); + png_debug1(3, "row_info->rowbytes = %lu", (unsigned long)row_info.rowbytes); + + /* Copy user's row into buffer, leaving room for filter byte. */ + memcpy(png_ptr->row_buf + 1, row, row_info.rowbytes); + +#ifdef PNG_WRITE_INTERLACING_SUPPORTED + /* Handle interlacing */ + if (png_ptr->interlaced && png_ptr->pass < 6 && + (png_ptr->transformations & PNG_INTERLACE)) + { + png_do_write_interlace(&row_info, png_ptr->row_buf + 1, png_ptr->pass); + /* This should always get caught above, but still ... */ + if (!(row_info.width)) + { + png_write_finish_row(png_ptr); + return; + } + } +#endif + +#ifdef PNG_WRITE_TRANSFORMS_SUPPORTED + /* Handle other transformations */ + if (png_ptr->transformations) + png_do_write_transformations(png_ptr, &row_info); +#endif + + /* At this point the row_info pixel depth must match the 'transformed' depth, + * which is also the output depth. + */ + if (row_info.pixel_depth != png_ptr->pixel_depth || + row_info.pixel_depth != png_ptr->transformed_pixel_depth) + png_error(png_ptr, "internal write transform logic error"); + +#ifdef PNG_MNG_FEATURES_SUPPORTED + /* Write filter_method 64 (intrapixel differencing) only if + * 1. Libpng was compiled with PNG_MNG_FEATURES_SUPPORTED and + * 2. Libpng did not write a PNG signature (this filter_method is only + * used in PNG datastreams that are embedded in MNG datastreams) and + * 3. The application called png_permit_mng_features with a mask that + * included PNG_FLAG_MNG_FILTER_64 and + * 4. The filter_method is 64 and + * 5. The color_type is RGB or RGBA + */ + if ((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) && + (png_ptr->filter_type == PNG_INTRAPIXEL_DIFFERENCING)) + { + /* Intrapixel differencing */ + png_do_write_intrapixel(&row_info, png_ptr->row_buf + 1); + } +#endif + +/* Added at libpng-1.5.10 */ +#ifdef PNG_WRITE_CHECK_FOR_INVALID_INDEX_SUPPORTED + /* Check for out-of-range palette index */ + if (row_info.color_type == PNG_COLOR_TYPE_PALETTE && + png_ptr->num_palette_max >= 0) + png_do_check_palette_indexes(png_ptr, &row_info); +#endif + + /* Find a filter if necessary, filter the row and write it out. */ + png_write_find_filter(png_ptr, &row_info); + + if (png_ptr->write_row_fn != NULL) + (*(png_ptr->write_row_fn))(png_ptr, png_ptr->row_number, png_ptr->pass); +} + +#ifdef PNG_WRITE_FLUSH_SUPPORTED +/* Set the automatic flush interval or 0 to turn flushing off */ +void PNGAPI +png_set_flush(png_structrp png_ptr, int nrows) +{ + png_debug(1, "in png_set_flush"); + + if (png_ptr == NULL) + return; + + png_ptr->flush_dist = (nrows < 0 ? 0 : nrows); +} + +/* Flush the current output buffers now */ +void PNGAPI +png_write_flush(png_structrp png_ptr) +{ + png_debug(1, "in png_write_flush"); + + if (png_ptr == NULL) + return; + + /* We have already written out all of the data */ + if (png_ptr->row_number >= png_ptr->num_rows) + return; + + png_compress_IDAT(png_ptr, NULL, 0, Z_SYNC_FLUSH); + png_ptr->flush_rows = 0; + png_flush(png_ptr); +} +#endif /* PNG_WRITE_FLUSH_SUPPORTED */ + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED +static void png_reset_filter_heuristics(png_structrp png_ptr);/* forward decl */ +#endif + +/* Free any memory used in png_ptr struct without freeing the struct itself. */ +static void +png_write_destroy(png_structrp png_ptr) +{ + png_debug(1, "in png_write_destroy"); + + /* Free any memory zlib uses */ + if (png_ptr->flags & PNG_FLAG_ZSTREAM_INITIALIZED) + deflateEnd(&png_ptr->zstream); + + /* Free our memory. png_free checks NULL for us. */ + png_free_buffer_list(png_ptr, &png_ptr->zbuffer_list); + png_free(png_ptr, png_ptr->row_buf); +#ifdef PNG_WRITE_FILTER_SUPPORTED + png_free(png_ptr, png_ptr->prev_row); + png_free(png_ptr, png_ptr->sub_row); + png_free(png_ptr, png_ptr->up_row); + png_free(png_ptr, png_ptr->avg_row); + png_free(png_ptr, png_ptr->paeth_row); +#endif + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + /* Use this to save a little code space, it doesn't free the filter_costs */ + png_reset_filter_heuristics(png_ptr); + png_free(png_ptr, png_ptr->filter_costs); + png_free(png_ptr, png_ptr->inv_filter_costs); +#endif + +#ifdef PNG_SET_UNKNOWN_CHUNKS_SUPPORTED + png_free(png_ptr, png_ptr->chunk_list); +#endif + + /* The error handling and memory handling information is left intact at this + * point: the jmp_buf may still have to be freed. See png_destroy_png_struct + * for how this happens. + */ +} + +/* Free all memory used by the write. + * In libpng 1.6.0 this API changed quietly to no longer accept a NULL value for + * *png_ptr_ptr. Prior to 1.6.0 it would accept such a value and it would free + * the passed in info_structs but it would quietly fail to free any of the data + * inside them. In 1.6.0 it quietly does nothing (it has to be quiet because it + * has no png_ptr.) + */ +void PNGAPI +png_destroy_write_struct(png_structpp png_ptr_ptr, png_infopp info_ptr_ptr) +{ + png_debug(1, "in png_destroy_write_struct"); + + if (png_ptr_ptr != NULL) + { + png_structrp png_ptr = *png_ptr_ptr; + + if (png_ptr != NULL) /* added in libpng 1.6.0 */ + { + png_destroy_info_struct(png_ptr, info_ptr_ptr); + + *png_ptr_ptr = NULL; + png_write_destroy(png_ptr); + png_destroy_png_struct(png_ptr); + } + } +} + +/* Allow the application to select one or more row filters to use. */ +void PNGAPI +png_set_filter(png_structrp png_ptr, int method, int filters) +{ + png_debug(1, "in png_set_filter"); + + if (png_ptr == NULL) + return; + +#ifdef PNG_MNG_FEATURES_SUPPORTED + if ((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) && + (method == PNG_INTRAPIXEL_DIFFERENCING)) + method = PNG_FILTER_TYPE_BASE; + +#endif + if (method == PNG_FILTER_TYPE_BASE) + { + switch (filters & (PNG_ALL_FILTERS | 0x07)) + { +#ifdef PNG_WRITE_FILTER_SUPPORTED + case 5: + case 6: + case 7: png_app_error(png_ptr, "Unknown row filter for method 0"); + /* FALL THROUGH */ +#endif /* PNG_WRITE_FILTER_SUPPORTED */ + case PNG_FILTER_VALUE_NONE: + png_ptr->do_filter = PNG_FILTER_NONE; break; + +#ifdef PNG_WRITE_FILTER_SUPPORTED + case PNG_FILTER_VALUE_SUB: + png_ptr->do_filter = PNG_FILTER_SUB; break; + + case PNG_FILTER_VALUE_UP: + png_ptr->do_filter = PNG_FILTER_UP; break; + + case PNG_FILTER_VALUE_AVG: + png_ptr->do_filter = PNG_FILTER_AVG; break; + + case PNG_FILTER_VALUE_PAETH: + png_ptr->do_filter = PNG_FILTER_PAETH; break; + + default: + png_ptr->do_filter = (png_byte)filters; break; +#else + default: + png_app_error(png_ptr, "Unknown row filter for method 0"); +#endif /* PNG_WRITE_FILTER_SUPPORTED */ + } + + /* If we have allocated the row_buf, this means we have already started + * with the image and we should have allocated all of the filter buffers + * that have been selected. If prev_row isn't already allocated, then + * it is too late to start using the filters that need it, since we + * will be missing the data in the previous row. If an application + * wants to start and stop using particular filters during compression, + * it should start out with all of the filters, and then add and + * remove them after the start of compression. + */ + if (png_ptr->row_buf != NULL) + { +#ifdef PNG_WRITE_FILTER_SUPPORTED + if ((png_ptr->do_filter & PNG_FILTER_SUB) && png_ptr->sub_row == NULL) + { + png_ptr->sub_row = (png_bytep)png_malloc(png_ptr, + (png_ptr->rowbytes + 1)); + png_ptr->sub_row[0] = PNG_FILTER_VALUE_SUB; + } + + if ((png_ptr->do_filter & PNG_FILTER_UP) && png_ptr->up_row == NULL) + { + if (png_ptr->prev_row == NULL) + { + png_warning(png_ptr, "Can't add Up filter after starting"); + png_ptr->do_filter = (png_byte)(png_ptr->do_filter & + ~PNG_FILTER_UP); + } + + else + { + png_ptr->up_row = (png_bytep)png_malloc(png_ptr, + (png_ptr->rowbytes + 1)); + png_ptr->up_row[0] = PNG_FILTER_VALUE_UP; + } + } + + if ((png_ptr->do_filter & PNG_FILTER_AVG) && png_ptr->avg_row == NULL) + { + if (png_ptr->prev_row == NULL) + { + png_warning(png_ptr, "Can't add Average filter after starting"); + png_ptr->do_filter = (png_byte)(png_ptr->do_filter & + ~PNG_FILTER_AVG); + } + + else + { + png_ptr->avg_row = (png_bytep)png_malloc(png_ptr, + (png_ptr->rowbytes + 1)); + png_ptr->avg_row[0] = PNG_FILTER_VALUE_AVG; + } + } + + if ((png_ptr->do_filter & PNG_FILTER_PAETH) && + png_ptr->paeth_row == NULL) + { + if (png_ptr->prev_row == NULL) + { + png_warning(png_ptr, "Can't add Paeth filter after starting"); + png_ptr->do_filter &= (png_byte)(~PNG_FILTER_PAETH); + } + + else + { + png_ptr->paeth_row = (png_bytep)png_malloc(png_ptr, + (png_ptr->rowbytes + 1)); + png_ptr->paeth_row[0] = PNG_FILTER_VALUE_PAETH; + } + } + + if (png_ptr->do_filter == PNG_NO_FILTERS) +#endif /* PNG_WRITE_FILTER_SUPPORTED */ + png_ptr->do_filter = PNG_FILTER_NONE; + } + } + else + png_error(png_ptr, "Unknown custom filter method"); +} + +/* This allows us to influence the way in which libpng chooses the "best" + * filter for the current scanline. While the "minimum-sum-of-absolute- + * differences metric is relatively fast and effective, there is some + * question as to whether it can be improved upon by trying to keep the + * filtered data going to zlib more consistent, hopefully resulting in + * better compression. + */ +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED /* GRR 970116 */ +/* Convenience reset API. */ +static void +png_reset_filter_heuristics(png_structrp png_ptr) +{ + /* Clear out any old values in the 'weights' - this must be done because if + * the app calls set_filter_heuristics multiple times with different + * 'num_weights' values we would otherwise potentially have wrong sized + * arrays. + */ + png_ptr->num_prev_filters = 0; + png_ptr->heuristic_method = PNG_FILTER_HEURISTIC_UNWEIGHTED; + if (png_ptr->prev_filters != NULL) + { + png_bytep old = png_ptr->prev_filters; + png_ptr->prev_filters = NULL; + png_free(png_ptr, old); + } + if (png_ptr->filter_weights != NULL) + { + png_uint_16p old = png_ptr->filter_weights; + png_ptr->filter_weights = NULL; + png_free(png_ptr, old); + } + + if (png_ptr->inv_filter_weights != NULL) + { + png_uint_16p old = png_ptr->inv_filter_weights; + png_ptr->inv_filter_weights = NULL; + png_free(png_ptr, old); + } + + /* Leave the filter_costs - this array is fixed size. */ +} + +static int +png_init_filter_heuristics(png_structrp png_ptr, int heuristic_method, + int num_weights) +{ + if (png_ptr == NULL) + return 0; + + /* Clear out the arrays */ + png_reset_filter_heuristics(png_ptr); + + /* Check arguments; the 'reset' function makes the correct settings for the + * unweighted case, but we must handle the weight case by initializing the + * arrays for the caller. + */ + if (heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int i; + + if (num_weights > 0) + { + png_ptr->prev_filters = (png_bytep)png_malloc(png_ptr, + (png_uint_32)((sizeof (png_byte)) * num_weights)); + + /* To make sure that the weighting starts out fairly */ + for (i = 0; i < num_weights; i++) + { + png_ptr->prev_filters[i] = 255; + } + + png_ptr->filter_weights = (png_uint_16p)png_malloc(png_ptr, + (png_uint_32)((sizeof (png_uint_16)) * num_weights)); + + png_ptr->inv_filter_weights = (png_uint_16p)png_malloc(png_ptr, + (png_uint_32)((sizeof (png_uint_16)) * num_weights)); + + for (i = 0; i < num_weights; i++) + { + png_ptr->inv_filter_weights[i] = + png_ptr->filter_weights[i] = PNG_WEIGHT_FACTOR; + } + + /* Safe to set this now */ + png_ptr->num_prev_filters = (png_byte)num_weights; + } + + /* If, in the future, there are other filter methods, this would + * need to be based on png_ptr->filter. + */ + if (png_ptr->filter_costs == NULL) + { + png_ptr->filter_costs = (png_uint_16p)png_malloc(png_ptr, + (png_uint_32)((sizeof (png_uint_16)) * PNG_FILTER_VALUE_LAST)); + + png_ptr->inv_filter_costs = (png_uint_16p)png_malloc(png_ptr, + (png_uint_32)((sizeof (png_uint_16)) * PNG_FILTER_VALUE_LAST)); + } + + for (i = 0; i < PNG_FILTER_VALUE_LAST; i++) + { + png_ptr->inv_filter_costs[i] = + png_ptr->filter_costs[i] = PNG_COST_FACTOR; + } + + /* All the arrays are inited, safe to set this: */ + png_ptr->heuristic_method = PNG_FILTER_HEURISTIC_WEIGHTED; + + /* Return the 'ok' code. */ + return 1; + } + else if (heuristic_method == PNG_FILTER_HEURISTIC_DEFAULT || + heuristic_method == PNG_FILTER_HEURISTIC_UNWEIGHTED) + { + return 1; + } + else + { + png_warning(png_ptr, "Unknown filter heuristic method"); + return 0; + } +} + +/* Provide floating and fixed point APIs */ +#ifdef PNG_FLOATING_POINT_SUPPORTED +void PNGAPI +png_set_filter_heuristics(png_structrp png_ptr, int heuristic_method, + int num_weights, png_const_doublep filter_weights, + png_const_doublep filter_costs) +{ + png_debug(1, "in png_set_filter_heuristics"); + + /* The internal API allocates all the arrays and ensures that the elements of + * those arrays are set to the default value. + */ + if (!png_init_filter_heuristics(png_ptr, heuristic_method, num_weights)) + return; + + /* If using the weighted method copy in the weights. */ + if (heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int i; + for (i = 0; i < num_weights; i++) + { + if (filter_weights[i] <= 0.0) + { + png_ptr->inv_filter_weights[i] = + png_ptr->filter_weights[i] = PNG_WEIGHT_FACTOR; + } + + else + { + png_ptr->inv_filter_weights[i] = + (png_uint_16)(PNG_WEIGHT_FACTOR*filter_weights[i]+.5); + + png_ptr->filter_weights[i] = + (png_uint_16)(PNG_WEIGHT_FACTOR/filter_weights[i]+.5); + } + } + + /* Here is where we set the relative costs of the different filters. We + * should take the desired compression level into account when setting + * the costs, so that Paeth, for instance, has a high relative cost at low + * compression levels, while it has a lower relative cost at higher + * compression settings. The filter types are in order of increasing + * relative cost, so it would be possible to do this with an algorithm. + */ + for (i = 0; i < PNG_FILTER_VALUE_LAST; i++) if (filter_costs[i] >= 1.0) + { + png_ptr->inv_filter_costs[i] = + (png_uint_16)(PNG_COST_FACTOR / filter_costs[i] + .5); + + png_ptr->filter_costs[i] = + (png_uint_16)(PNG_COST_FACTOR * filter_costs[i] + .5); + } + } +} +#endif /* FLOATING_POINT */ + +#ifdef PNG_FIXED_POINT_SUPPORTED +void PNGAPI +png_set_filter_heuristics_fixed(png_structrp png_ptr, int heuristic_method, + int num_weights, png_const_fixed_point_p filter_weights, + png_const_fixed_point_p filter_costs) +{ + png_debug(1, "in png_set_filter_heuristics_fixed"); + + /* The internal API allocates all the arrays and ensures that the elements of + * those arrays are set to the default value. + */ + if (!png_init_filter_heuristics(png_ptr, heuristic_method, num_weights)) + return; + + /* If using the weighted method copy in the weights. */ + if (heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int i; + for (i = 0; i < num_weights; i++) + { + if (filter_weights[i] <= 0) + { + png_ptr->inv_filter_weights[i] = + png_ptr->filter_weights[i] = PNG_WEIGHT_FACTOR; + } + + else + { + png_ptr->inv_filter_weights[i] = (png_uint_16) + ((PNG_WEIGHT_FACTOR*filter_weights[i]+PNG_FP_HALF)/PNG_FP_1); + + png_ptr->filter_weights[i] = (png_uint_16)((PNG_WEIGHT_FACTOR* + PNG_FP_1+(filter_weights[i]/2))/filter_weights[i]); + } + } + + /* Here is where we set the relative costs of the different filters. We + * should take the desired compression level into account when setting + * the costs, so that Paeth, for instance, has a high relative cost at low + * compression levels, while it has a lower relative cost at higher + * compression settings. The filter types are in order of increasing + * relative cost, so it would be possible to do this with an algorithm. + */ + for (i = 0; i < PNG_FILTER_VALUE_LAST; i++) + if (filter_costs[i] >= PNG_FP_1) + { + png_uint_32 tmp; + + /* Use a 32 bit unsigned temporary here because otherwise the + * intermediate value will be a 32 bit *signed* integer (ANSI rules) + * and this will get the wrong answer on division. + */ + tmp = PNG_COST_FACTOR*PNG_FP_1 + (filter_costs[i]/2); + tmp /= filter_costs[i]; + + png_ptr->inv_filter_costs[i] = (png_uint_16)tmp; + + tmp = PNG_COST_FACTOR * filter_costs[i] + PNG_FP_HALF; + tmp /= PNG_FP_1; + + png_ptr->filter_costs[i] = (png_uint_16)tmp; + } + } +} +#endif /* FIXED_POINT */ +#endif /* PNG_WRITE_WEIGHTED_FILTER_SUPPORTED */ + +void PNGAPI +png_set_compression_level(png_structrp png_ptr, int level) +{ + png_debug(1, "in png_set_compression_level"); + + if (png_ptr == NULL) + return; + + png_ptr->zlib_level = level; +} + +void PNGAPI +png_set_compression_mem_level(png_structrp png_ptr, int mem_level) +{ + png_debug(1, "in png_set_compression_mem_level"); + + if (png_ptr == NULL) + return; + + png_ptr->zlib_mem_level = mem_level; +} + +void PNGAPI +png_set_compression_strategy(png_structrp png_ptr, int strategy) +{ + png_debug(1, "in png_set_compression_strategy"); + + if (png_ptr == NULL) + return; + + /* The flag setting here prevents the libpng dynamic selection of strategy. + */ + png_ptr->flags |= PNG_FLAG_ZLIB_CUSTOM_STRATEGY; + png_ptr->zlib_strategy = strategy; +} + +/* If PNG_WRITE_OPTIMIZE_CMF_SUPPORTED is defined, libpng will use a + * smaller value of window_bits if it can do so safely. + */ +void PNGAPI +png_set_compression_window_bits(png_structrp png_ptr, int window_bits) +{ + if (png_ptr == NULL) + return; + + /* Prior to 1.6.0 this would warn but then set the window_bits value, this + * meant that negative window bits values could be selected which would cause + * libpng to write a non-standard PNG file with raw deflate or gzip + * compressed IDAT or ancillary chunks. Such files can be read and there is + * no warning on read, so this seems like a very bad idea. + */ + if (window_bits > 15) + { + png_warning(png_ptr, "Only compression windows <= 32k supported by PNG"); + window_bits = 15; + } + + else if (window_bits < 8) + { + png_warning(png_ptr, "Only compression windows >= 256 supported by PNG"); + window_bits = 8; + } + + png_ptr->zlib_window_bits = window_bits; +} + +void PNGAPI +png_set_compression_method(png_structrp png_ptr, int method) +{ + png_debug(1, "in png_set_compression_method"); + + if (png_ptr == NULL) + return; + + /* This would produce an invalid PNG file if it worked, but it doesn't and + * deflate will fault it, so it is harmless to just warn here. + */ + if (method != 8) + png_warning(png_ptr, "Only compression method 8 is supported by PNG"); + + png_ptr->zlib_method = method; +} + +/* The following were added to libpng-1.5.4 */ +#ifdef PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED +void PNGAPI +png_set_text_compression_level(png_structrp png_ptr, int level) +{ + png_debug(1, "in png_set_text_compression_level"); + + if (png_ptr == NULL) + return; + + png_ptr->zlib_text_level = level; +} + +void PNGAPI +png_set_text_compression_mem_level(png_structrp png_ptr, int mem_level) +{ + png_debug(1, "in png_set_text_compression_mem_level"); + + if (png_ptr == NULL) + return; + + png_ptr->zlib_text_mem_level = mem_level; +} + +void PNGAPI +png_set_text_compression_strategy(png_structrp png_ptr, int strategy) +{ + png_debug(1, "in png_set_text_compression_strategy"); + + if (png_ptr == NULL) + return; + + png_ptr->zlib_text_strategy = strategy; +} + +/* If PNG_WRITE_OPTIMIZE_CMF_SUPPORTED is defined, libpng will use a + * smaller value of window_bits if it can do so safely. + */ +void PNGAPI +png_set_text_compression_window_bits(png_structrp png_ptr, int window_bits) +{ + if (png_ptr == NULL) + return; + + if (window_bits > 15) + { + png_warning(png_ptr, "Only compression windows <= 32k supported by PNG"); + window_bits = 15; + } + + else if (window_bits < 8) + { + png_warning(png_ptr, "Only compression windows >= 256 supported by PNG"); + window_bits = 8; + } + + png_ptr->zlib_text_window_bits = window_bits; +} + +void PNGAPI +png_set_text_compression_method(png_structrp png_ptr, int method) +{ + png_debug(1, "in png_set_text_compression_method"); + + if (png_ptr == NULL) + return; + + if (method != 8) + png_warning(png_ptr, "Only compression method 8 is supported by PNG"); + + png_ptr->zlib_text_method = method; +} +#endif /* PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED */ +/* end of API added to libpng-1.5.4 */ + +void PNGAPI +png_set_write_status_fn(png_structrp png_ptr, png_write_status_ptr write_row_fn) +{ + if (png_ptr == NULL) + return; + + png_ptr->write_row_fn = write_row_fn; +} + +#ifdef PNG_WRITE_USER_TRANSFORM_SUPPORTED +void PNGAPI +png_set_write_user_transform_fn(png_structrp png_ptr, png_user_transform_ptr + write_user_transform_fn) +{ + png_debug(1, "in png_set_write_user_transform_fn"); + + if (png_ptr == NULL) + return; + + png_ptr->transformations |= PNG_USER_TRANSFORM; + png_ptr->write_user_transform_fn = write_user_transform_fn; +} +#endif + + +#ifdef PNG_INFO_IMAGE_SUPPORTED +void PNGAPI +png_write_png(png_structrp png_ptr, png_inforp info_ptr, + int transforms, voidp params) +{ + if (png_ptr == NULL || info_ptr == NULL) + return; + + /* Write the file header information. */ + png_write_info(png_ptr, info_ptr); + + /* ------ these transformations don't touch the info structure ------- */ + +#ifdef PNG_WRITE_INVERT_SUPPORTED + /* Invert monochrome pixels */ + if (transforms & PNG_TRANSFORM_INVERT_MONO) + png_set_invert_mono(png_ptr); +#endif + +#ifdef PNG_WRITE_SHIFT_SUPPORTED + /* Shift the pixels up to a legal bit depth and fill in + * as appropriate to correctly scale the image. + */ + if ((transforms & PNG_TRANSFORM_SHIFT) + && (info_ptr->valid & PNG_INFO_sBIT)) + png_set_shift(png_ptr, &info_ptr->sig_bit); +#endif + +#ifdef PNG_WRITE_PACK_SUPPORTED + /* Pack pixels into bytes */ + if (transforms & PNG_TRANSFORM_PACKING) + png_set_packing(png_ptr); +#endif + +#ifdef PNG_WRITE_SWAP_ALPHA_SUPPORTED + /* Swap location of alpha bytes from ARGB to RGBA */ + if (transforms & PNG_TRANSFORM_SWAP_ALPHA) + png_set_swap_alpha(png_ptr); +#endif + +#ifdef PNG_WRITE_FILLER_SUPPORTED + /* Pack XRGB/RGBX/ARGB/RGBA into RGB (4 channels -> 3 channels) */ + if (transforms & PNG_TRANSFORM_STRIP_FILLER_AFTER) + png_set_filler(png_ptr, 0, PNG_FILLER_AFTER); + + else if (transforms & PNG_TRANSFORM_STRIP_FILLER_BEFORE) + png_set_filler(png_ptr, 0, PNG_FILLER_BEFORE); +#endif + +#ifdef PNG_WRITE_BGR_SUPPORTED + /* Flip BGR pixels to RGB */ + if (transforms & PNG_TRANSFORM_BGR) + png_set_bgr(png_ptr); +#endif + +#ifdef PNG_WRITE_SWAP_SUPPORTED + /* Swap bytes of 16-bit files to most significant byte first */ + if (transforms & PNG_TRANSFORM_SWAP_ENDIAN) + png_set_swap(png_ptr); +#endif + +#ifdef PNG_WRITE_PACKSWAP_SUPPORTED + /* Swap bits of 1, 2, 4 bit packed pixel formats */ + if (transforms & PNG_TRANSFORM_PACKSWAP) + png_set_packswap(png_ptr); +#endif + +#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED + /* Invert the alpha channel from opacity to transparency */ + if (transforms & PNG_TRANSFORM_INVERT_ALPHA) + png_set_invert_alpha(png_ptr); +#endif + + /* ----------------------- end of transformations ------------------- */ + + /* Write the bits */ + if (info_ptr->valid & PNG_INFO_IDAT) + png_write_image(png_ptr, info_ptr->row_pointers); + + /* It is REQUIRED to call this to finish writing the rest of the file */ + png_write_end(png_ptr, info_ptr); + + PNG_UNUSED(transforms) /* Quiet compiler warnings */ + PNG_UNUSED(params) +} +#endif + + +#ifdef PNG_SIMPLIFIED_WRITE_SUPPORTED +#ifdef PNG_STDIO_SUPPORTED /* currently required for png_image_write_* */ +/* Initialize the write structure - general purpose utility. */ +static int +png_image_write_init(png_imagep image) +{ + png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, image, + png_safe_error, png_safe_warning); + + if (png_ptr != NULL) + { + png_infop info_ptr = png_create_info_struct(png_ptr); + + if (info_ptr != NULL) + { + png_controlp control = png_voidcast(png_controlp, + png_malloc_warn(png_ptr, (sizeof *control))); + + if (control != NULL) + { + memset(control, 0, (sizeof *control)); + + control->png_ptr = png_ptr; + control->info_ptr = info_ptr; + control->for_write = 1; + + image->opaque = control; + return 1; + } + + /* Error clean up */ + png_destroy_info_struct(png_ptr, &info_ptr); + } + + png_destroy_write_struct(&png_ptr, NULL); + } + + return png_image_error(image, "png_image_write_: out of memory"); +} + +/* Arguments to png_image_write_main: */ +typedef struct +{ + /* Arguments: */ + png_imagep image; + png_const_voidp buffer; + png_int_32 row_stride; + png_const_voidp colormap; + int convert_to_8bit; + /* Local variables: */ + png_const_voidp first_row; + ptrdiff_t row_bytes; + png_voidp local_row; +} png_image_write_control; + +/* Write png_uint_16 input to a 16-bit PNG; the png_ptr has already been set to + * do any necessary byte swapping. The component order is defined by the + * png_image format value. + */ +static int +png_write_image_16bit(png_voidp argument) +{ + png_image_write_control *display = png_voidcast(png_image_write_control*, + argument); + png_imagep image = display->image; + png_structrp png_ptr = image->opaque->png_ptr; + + png_const_uint_16p input_row = png_voidcast(png_const_uint_16p, + display->first_row); + png_uint_16p output_row = png_voidcast(png_uint_16p, display->local_row); + png_uint_16p row_end; + const int channels = (image->format & PNG_FORMAT_FLAG_COLOR) ? 3 : 1; + int aindex = 0; + png_uint_32 y = image->height; + + if (image->format & PNG_FORMAT_FLAG_ALPHA) + { + if (image->format & PNG_FORMAT_FLAG_AFIRST) + { + aindex = -1; + ++input_row; /* To point to the first component */ + ++output_row; + } + + else + aindex = channels; + } + + else + png_error(png_ptr, "png_write_image: internal call error"); + + /* Work out the output row end and count over this, note that the increment + * above to 'row' means that row_end can actually be beyond the end of the + * row; this is correct. + */ + row_end = output_row + image->width * (channels+1); + + while (y-- > 0) + { + png_const_uint_16p in_ptr = input_row; + png_uint_16p out_ptr = output_row; + + while (out_ptr < row_end) + { + const png_uint_16 alpha = in_ptr[aindex]; + png_uint_32 reciprocal = 0; + int c; + + out_ptr[aindex] = alpha; + + /* Calculate a reciprocal. The correct calculation is simply + * component/alpha*65535 << 15. (I.e. 15 bits of precision); this + * allows correct rounding by adding .5 before the shift. 'reciprocal' + * is only initialized when required. + */ + if (alpha > 0 && alpha < 65535) + reciprocal = ((0xffff<<15)+(alpha>>1))/alpha; + + c = channels; + do /* always at least one channel */ + { + png_uint_16 component = *in_ptr++; + + /* The following gives 65535 for an alpha of 0, which is fine, + * otherwise if 0/0 is represented as some other value there is more + * likely to be a discontinuity which will probably damage + * compression when moving from a fully transparent area to a + * nearly transparent one. (The assumption here is that opaque + * areas tend not to be 0 intensity.) + */ + if (component >= alpha) + component = 65535; + + /* component 0 && alpha < 65535) + { + png_uint_32 calc = component * reciprocal; + calc += 16384; /* round to nearest */ + component = (png_uint_16)(calc >> 15); + } + + *out_ptr++ = component; + } + while (--c > 0); + + /* Skip to next component (skip the intervening alpha channel) */ + ++in_ptr; + ++out_ptr; + } + + png_write_row(png_ptr, png_voidcast(png_const_bytep, display->local_row)); + input_row += display->row_bytes/(sizeof (png_uint_16)); + } + + return 1; +} + +/* Given 16-bit input (1 to 4 channels) write 8-bit output. If an alpha channel + * is present it must be removed from the components, the components are then + * written in sRGB encoding. No components are added or removed. + * + * Calculate an alpha reciprocal to reverse pre-multiplication. As above the + * calculation can be done to 15 bits of accuracy; however, the output needs to + * be scaled in the range 0..255*65535, so include that scaling here. + */ +#define UNP_RECIPROCAL(alpha) ((((0xffff*0xff)<<7)+(alpha>>1))/alpha) + +static png_byte +png_unpremultiply(png_uint_32 component, png_uint_32 alpha, + png_uint_32 reciprocal/*from the above macro*/) +{ + /* The following gives 1.0 for an alpha of 0, which is fine, otherwise if 0/0 + * is represented as some other value there is more likely to be a + * discontinuity which will probably damage compression when moving from a + * fully transparent area to a nearly transparent one. (The assumption here + * is that opaque areas tend not to be 0 intensity.) + * + * There is a rounding problem here; if alpha is less than 128 it will end up + * as 0 when scaled to 8 bits. To avoid introducing spurious colors into the + * output change for this too. + */ + if (component >= alpha || alpha < 128) + return 255; + + /* component 0) + { + /* The test is that alpha/257 (rounded) is less than 255, the first value + * that becomes 255 is 65407. + * NOTE: this must agree with the PNG_DIV257 macro (which must, therefore, + * be exact!) [Could also test reciprocal != 0] + */ + if (alpha < 65407) + { + component *= reciprocal; + component += 64; /* round to nearest */ + component >>= 7; + } + + else + component *= 255; + + /* Convert the component to sRGB. */ + return (png_byte)PNG_sRGB_FROM_LINEAR(component); + } + + else + return 0; +} + +static int +png_write_image_8bit(png_voidp argument) +{ + png_image_write_control *display = png_voidcast(png_image_write_control*, + argument); + png_imagep image = display->image; + png_structrp png_ptr = image->opaque->png_ptr; + + png_const_uint_16p input_row = png_voidcast(png_const_uint_16p, + display->first_row); + png_bytep output_row = png_voidcast(png_bytep, display->local_row); + png_uint_32 y = image->height; + const int channels = (image->format & PNG_FORMAT_FLAG_COLOR) ? 3 : 1; + + if (image->format & PNG_FORMAT_FLAG_ALPHA) + { + png_bytep row_end; + int aindex; + + if (image->format & PNG_FORMAT_FLAG_AFIRST) + { + aindex = -1; + ++input_row; /* To point to the first component */ + ++output_row; + } + + else + aindex = channels; + + /* Use row_end in place of a loop counter: */ + row_end = output_row + image->width * (channels+1); + + while (y-- > 0) + { + png_const_uint_16p in_ptr = input_row; + png_bytep out_ptr = output_row; + + while (out_ptr < row_end) + { + png_uint_16 alpha = in_ptr[aindex]; + png_byte alphabyte = (png_byte)PNG_DIV257(alpha); + png_uint_32 reciprocal = 0; + int c; + + /* Scale and write the alpha channel. */ + out_ptr[aindex] = alphabyte; + + if (alphabyte > 0 && alphabyte < 255) + reciprocal = UNP_RECIPROCAL(alpha); + + c = channels; + do /* always at least one channel */ + *out_ptr++ = png_unpremultiply(*in_ptr++, alpha, reciprocal); + while (--c > 0); + + /* Skip to next component (skip the intervening alpha channel) */ + ++in_ptr; + ++out_ptr; + } /* while out_ptr < row_end */ + + png_write_row(png_ptr, png_voidcast(png_const_bytep, + display->local_row)); + input_row += display->row_bytes/(sizeof (png_uint_16)); + } /* while y */ + } + + else + { + /* No alpha channel, so the row_end really is the end of the row and it + * is sufficient to loop over the components one by one. + */ + png_bytep row_end = output_row + image->width * channels; + + while (y-- > 0) + { + png_const_uint_16p in_ptr = input_row; + png_bytep out_ptr = output_row; + + while (out_ptr < row_end) + { + png_uint_32 component = *in_ptr++; + + component *= 255; + *out_ptr++ = (png_byte)PNG_sRGB_FROM_LINEAR(component); + } + + png_write_row(png_ptr, output_row); + input_row += display->row_bytes/(sizeof (png_uint_16)); + } + } + + return 1; +} + +static void +png_image_set_PLTE(png_image_write_control *display) +{ + const png_imagep image = display->image; + const void *cmap = display->colormap; + const int entries = image->colormap_entries > 256 ? 256 : + (int)image->colormap_entries; + + /* NOTE: the caller must check for cmap != NULL and entries != 0 */ + const png_uint_32 format = image->format; + const int channels = PNG_IMAGE_SAMPLE_CHANNELS(format); + +# ifdef PNG_FORMAT_BGR_SUPPORTED + const int afirst = (format & PNG_FORMAT_FLAG_AFIRST) != 0 && + (format & PNG_FORMAT_FLAG_ALPHA) != 0; +# else +# define afirst 0 +# endif + +# ifdef PNG_FORMAT_BGR_SUPPORTED + const int bgr = (format & PNG_FORMAT_FLAG_BGR) ? 2 : 0; +# else +# define bgr 0 +# endif + + int i, num_trans; + png_color palette[256]; + png_byte tRNS[256]; + + memset(tRNS, 255, (sizeof tRNS)); + memset(palette, 0, (sizeof palette)); + + for (i=num_trans=0; i= 3) /* RGB */ + { + palette[i].blue = (png_byte)PNG_sRGB_FROM_LINEAR(255 * + entry[(2 ^ bgr)]); + palette[i].green = (png_byte)PNG_sRGB_FROM_LINEAR(255 * + entry[1]); + palette[i].red = (png_byte)PNG_sRGB_FROM_LINEAR(255 * + entry[bgr]); + } + + else /* Gray */ + palette[i].blue = palette[i].red = palette[i].green = + (png_byte)PNG_sRGB_FROM_LINEAR(255 * *entry); + } + + else /* alpha */ + { + png_uint_16 alpha = entry[afirst ? 0 : channels-1]; + png_byte alphabyte = (png_byte)PNG_DIV257(alpha); + png_uint_32 reciprocal = 0; + + /* Calculate a reciprocal, as in the png_write_image_8bit code above + * this is designed to produce a value scaled to 255*65535 when + * divided by 128 (i.e. asr 7). + */ + if (alphabyte > 0 && alphabyte < 255) + reciprocal = (((0xffff*0xff)<<7)+(alpha>>1))/alpha; + + tRNS[i] = alphabyte; + if (alphabyte < 255) + num_trans = i+1; + + if (channels >= 3) /* RGB */ + { + palette[i].blue = png_unpremultiply(entry[afirst + (2 ^ bgr)], + alpha, reciprocal); + palette[i].green = png_unpremultiply(entry[afirst + 1], alpha, + reciprocal); + palette[i].red = png_unpremultiply(entry[afirst + bgr], alpha, + reciprocal); + } + + else /* gray */ + palette[i].blue = palette[i].red = palette[i].green = + png_unpremultiply(entry[afirst], alpha, reciprocal); + } + } + + else /* Color-map has sRGB values */ + { + png_const_bytep entry = png_voidcast(png_const_bytep, cmap); + + entry += i * channels; + + switch (channels) + { + case 4: + tRNS[i] = entry[afirst ? 0 : 3]; + if (tRNS[i] < 255) + num_trans = i+1; + /* FALL THROUGH */ + case 3: + palette[i].blue = entry[afirst + (2 ^ bgr)]; + palette[i].green = entry[afirst + 1]; + palette[i].red = entry[afirst + bgr]; + break; + + case 2: + tRNS[i] = entry[1 ^ afirst]; + if (tRNS[i] < 255) + num_trans = i+1; + /* FALL THROUGH */ + case 1: + palette[i].blue = palette[i].red = palette[i].green = + entry[afirst]; + break; + + default: + break; + } + } + } + +# ifdef afirst +# undef afirst +# endif +# ifdef bgr +# undef bgr +# endif + + png_set_PLTE(image->opaque->png_ptr, image->opaque->info_ptr, palette, + entries); + + if (num_trans > 0) + png_set_tRNS(image->opaque->png_ptr, image->opaque->info_ptr, tRNS, + num_trans, NULL); + + image->colormap_entries = entries; +} + +static int +png_image_write_main(png_voidp argument) +{ + png_image_write_control *display = png_voidcast(png_image_write_control*, + argument); + png_imagep image = display->image; + png_structrp png_ptr = image->opaque->png_ptr; + png_inforp info_ptr = image->opaque->info_ptr; + png_uint_32 format = image->format; + + int colormap = (format & PNG_FORMAT_FLAG_COLORMAP) != 0; + int linear = !colormap && (format & PNG_FORMAT_FLAG_LINEAR) != 0; /* input */ + int alpha = !colormap && (format & PNG_FORMAT_FLAG_ALPHA) != 0; + int write_16bit = linear && !colormap && !display->convert_to_8bit; + +# ifdef PNG_BENIGN_ERRORS_SUPPORTED + /* Make sure we error out on any bad situation */ + png_set_benign_errors(png_ptr, 0/*error*/); +# endif + + /* Default the 'row_stride' parameter if required. */ + if (display->row_stride == 0) + display->row_stride = PNG_IMAGE_ROW_STRIDE(*image); + + /* Set the required transforms then write the rows in the correct order. */ + if (format & PNG_FORMAT_FLAG_COLORMAP) + { + if (display->colormap != NULL && image->colormap_entries > 0) + { + png_uint_32 entries = image->colormap_entries; + + png_set_IHDR(png_ptr, info_ptr, image->width, image->height, + entries > 16 ? 8 : (entries > 4 ? 4 : (entries > 2 ? 2 : 1)), + PNG_COLOR_TYPE_PALETTE, PNG_INTERLACE_NONE, + PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE); + + png_image_set_PLTE(display); + } + + else + png_error(image->opaque->png_ptr, + "no color-map for color-mapped image"); + } + + else + png_set_IHDR(png_ptr, info_ptr, image->width, image->height, + write_16bit ? 16 : 8, + ((format & PNG_FORMAT_FLAG_COLOR) ? PNG_COLOR_MASK_COLOR : 0) + + ((format & PNG_FORMAT_FLAG_ALPHA) ? PNG_COLOR_MASK_ALPHA : 0), + PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE); + + /* Counter-intuitively the data transformations must be called *after* + * png_write_info, not before as in the read code, but the 'set' functions + * must still be called before. Just set the color space information, never + * write an interlaced image. + */ + + if (write_16bit) + { + /* The gamma here is 1.0 (linear) and the cHRM chunk matches sRGB. */ + png_set_gAMA_fixed(png_ptr, info_ptr, PNG_GAMMA_LINEAR); + + if (!(image->flags & PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB)) + png_set_cHRM_fixed(png_ptr, info_ptr, + /* color x y */ + /* white */ 31270, 32900, + /* red */ 64000, 33000, + /* green */ 30000, 60000, + /* blue */ 15000, 6000 + ); + } + + else if (!(image->flags & PNG_IMAGE_FLAG_COLORSPACE_NOT_sRGB)) + png_set_sRGB(png_ptr, info_ptr, PNG_sRGB_INTENT_PERCEPTUAL); + + /* Else writing an 8-bit file and the *colors* aren't sRGB, but the 8-bit + * space must still be gamma encoded. + */ + else + png_set_gAMA_fixed(png_ptr, info_ptr, PNG_GAMMA_sRGB_INVERSE); + + /* Write the file header. */ + png_write_info(png_ptr, info_ptr); + + /* Now set up the data transformations (*after* the header is written), + * remove the handled transformations from the 'format' flags for checking. + * + * First check for a little endian system if writing 16 bit files. + */ + if (write_16bit) + { + PNG_CONST png_uint_16 le = 0x0001; + + if (*(png_const_bytep)&le) + png_set_swap(png_ptr); + } + +# ifdef PNG_SIMPLIFIED_WRITE_BGR_SUPPORTED + if (format & PNG_FORMAT_FLAG_BGR) + { + if (!colormap && (format & PNG_FORMAT_FLAG_COLOR) != 0) + png_set_bgr(png_ptr); + format &= ~PNG_FORMAT_FLAG_BGR; + } +# endif + +# ifdef PNG_SIMPLIFIED_WRITE_AFIRST_SUPPORTED + if (format & PNG_FORMAT_FLAG_AFIRST) + { + if (!colormap && (format & PNG_FORMAT_FLAG_ALPHA) != 0) + png_set_swap_alpha(png_ptr); + format &= ~PNG_FORMAT_FLAG_AFIRST; + } +# endif + + /* If there are 16 or fewer color-map entries we wrote a lower bit depth + * above, but the application data is still byte packed. + */ + if (colormap && image->colormap_entries <= 16) + png_set_packing(png_ptr); + + /* That should have handled all (both) the transforms. */ + if ((format & ~(png_uint_32)(PNG_FORMAT_FLAG_COLOR | PNG_FORMAT_FLAG_LINEAR | + PNG_FORMAT_FLAG_ALPHA | PNG_FORMAT_FLAG_COLORMAP)) != 0) + png_error(png_ptr, "png_write_image: unsupported transformation"); + + { + png_const_bytep row = png_voidcast(png_const_bytep, display->buffer); + ptrdiff_t row_bytes = display->row_stride; + + if (linear) + row_bytes *= (sizeof (png_uint_16)); + + if (row_bytes < 0) + row += (image->height-1) * (-row_bytes); + + display->first_row = row; + display->row_bytes = row_bytes; + } + + /* Apply 'fast' options if the flag is set. */ + if ((image->flags & PNG_IMAGE_FLAG_FAST) != 0) + { + png_set_filter(png_ptr, PNG_FILTER_TYPE_BASE, PNG_NO_FILTERS); + /* NOTE: determined by experiment using pngstest, this reflects some + * balance between the time to write the image once and the time to read + * it about 50 times. The speed-up in pngstest was about 10-20% of the + * total (user) time on a heavily loaded system. + */ + png_set_compression_level(png_ptr, 3); + } + + /* Check for the cases that currently require a pre-transform on the row + * before it is written. This only applies when the input is 16-bit and + * either there is an alpha channel or it is converted to 8-bit. + */ + if ((linear && alpha) || (!colormap && display->convert_to_8bit)) + { + png_bytep row = png_voidcast(png_bytep, png_malloc(png_ptr, + png_get_rowbytes(png_ptr, info_ptr))); + int result; + + display->local_row = row; + if (write_16bit) + result = png_safe_execute(image, png_write_image_16bit, display); + else + result = png_safe_execute(image, png_write_image_8bit, display); + display->local_row = NULL; + + png_free(png_ptr, row); + + /* Skip the 'write_end' on error: */ + if (!result) + return 0; + } + + /* Otherwise this is the case where the input is in a format currently + * supported by the rest of the libpng write code; call it directly. + */ + else + { + png_const_bytep row = png_voidcast(png_const_bytep, display->first_row); + ptrdiff_t row_bytes = display->row_bytes; + png_uint_32 y = image->height; + + while (y-- > 0) + { + png_write_row(png_ptr, row); + row += row_bytes; + } + } + + png_write_end(png_ptr, info_ptr); + return 1; +} + +int PNGAPI +png_image_write_to_stdio(png_imagep image, FILE *file, int convert_to_8bit, + const void *buffer, png_int_32 row_stride, const void *colormap) +{ + /* Write the image to the given (FILE*). */ + if (image != NULL && image->version == PNG_IMAGE_VERSION) + { + if (file != NULL) + { + if (png_image_write_init(image)) + { + png_image_write_control display; + int result; + + /* This is slightly evil, but png_init_io doesn't do anything other + * than this and we haven't changed the standard IO functions so + * this saves a 'safe' function. + */ + image->opaque->png_ptr->io_ptr = file; + + memset(&display, 0, (sizeof display)); + display.image = image; + display.buffer = buffer; + display.row_stride = row_stride; + display.colormap = colormap; + display.convert_to_8bit = convert_to_8bit; + + result = png_safe_execute(image, png_image_write_main, &display); + png_image_free(image); + return result; + } + + else + return 0; + } + + else + return png_image_error(image, + "png_image_write_to_stdio: invalid argument"); + } + + else if (image != NULL) + return png_image_error(image, + "png_image_write_to_stdio: incorrect PNG_IMAGE_VERSION"); + + else + return 0; +} + +int PNGAPI +png_image_write_to_file(png_imagep image, const char *file_name, + int convert_to_8bit, const void *buffer, png_int_32 row_stride, + const void *colormap) +{ + /* Write the image to the named file. */ + if (image != NULL && image->version == PNG_IMAGE_VERSION) + { + if (file_name != NULL) + { + FILE *fp = fopen(file_name, "wb"); + + if (fp != NULL) + { + if (png_image_write_to_stdio(image, fp, convert_to_8bit, buffer, + row_stride, colormap)) + { + int error; /* from fflush/fclose */ + + /* Make sure the file is flushed correctly. */ + if (fflush(fp) == 0 && ferror(fp) == 0) + { + if (fclose(fp) == 0) + return 1; + + error = errno; /* from fclose */ + } + + else + { + error = errno; /* from fflush or ferror */ + (void)fclose(fp); + } + + (void)remove(file_name); + /* The image has already been cleaned up; this is just used to + * set the error (because the original write succeeded). + */ + return png_image_error(image, strerror(error)); + } + + else + { + /* Clean up: just the opened file. */ + (void)fclose(fp); + (void)remove(file_name); + return 0; + } + } + + else + return png_image_error(image, strerror(errno)); + } + + else + return png_image_error(image, + "png_image_write_to_file: invalid argument"); + } + + else if (image != NULL) + return png_image_error(image, + "png_image_write_to_file: incorrect PNG_IMAGE_VERSION"); + + else + return 0; +} +#endif /* PNG_STDIO_SUPPORTED */ +#endif /* SIMPLIFIED_WRITE */ +#endif /* PNG_WRITE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngwtran.c b/ml/dlib/dlib/external/libpng/pngwtran.c new file mode 100644 index 000000000..98703f8c8 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngwtran.c @@ -0,0 +1,637 @@ + +/* pngwtran.c - transforms the data in a row for PNG writers + * + * Last changed in libpng 1.6.0 [February 14, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +#include "pngpriv.h" + +#ifdef PNG_WRITE_SUPPORTED + +#ifdef PNG_WRITE_TRANSFORMS_SUPPORTED +/* Transform the data according to the user's wishes. The order of + * transformations is significant. + */ +void /* PRIVATE */ +png_do_write_transformations(png_structrp png_ptr, png_row_infop row_info) +{ + png_debug(1, "in png_do_write_transformations"); + + if (png_ptr == NULL) + return; + +#ifdef PNG_WRITE_USER_TRANSFORM_SUPPORTED + if (png_ptr->transformations & PNG_USER_TRANSFORM) + if (png_ptr->write_user_transform_fn != NULL) + (*(png_ptr->write_user_transform_fn)) /* User write transform + function */ + (png_ptr, /* png_ptr */ + row_info, /* row_info: */ + /* png_uint_32 width; width of row */ + /* png_size_t rowbytes; number of bytes in row */ + /* png_byte color_type; color type of pixels */ + /* png_byte bit_depth; bit depth of samples */ + /* png_byte channels; number of channels (1-4) */ + /* png_byte pixel_depth; bits per pixel (depth*channels) */ + png_ptr->row_buf + 1); /* start of pixel data for row */ +#endif + +#ifdef PNG_WRITE_FILLER_SUPPORTED + if (png_ptr->transformations & PNG_FILLER) + png_do_strip_channel(row_info, png_ptr->row_buf + 1, + !(png_ptr->flags & PNG_FLAG_FILLER_AFTER)); +#endif + +#ifdef PNG_WRITE_PACKSWAP_SUPPORTED + if (png_ptr->transformations & PNG_PACKSWAP) + png_do_packswap(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_WRITE_PACK_SUPPORTED + if (png_ptr->transformations & PNG_PACK) + png_do_pack(row_info, png_ptr->row_buf + 1, + (png_uint_32)png_ptr->bit_depth); +#endif + +#ifdef PNG_WRITE_SWAP_SUPPORTED + if (png_ptr->transformations & PNG_SWAP_BYTES) + png_do_swap(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_WRITE_SHIFT_SUPPORTED + if (png_ptr->transformations & PNG_SHIFT) + png_do_shift(row_info, png_ptr->row_buf + 1, + &(png_ptr->shift)); +#endif + +#ifdef PNG_WRITE_SWAP_ALPHA_SUPPORTED + if (png_ptr->transformations & PNG_SWAP_ALPHA) + png_do_write_swap_alpha(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED + if (png_ptr->transformations & PNG_INVERT_ALPHA) + png_do_write_invert_alpha(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_WRITE_BGR_SUPPORTED + if (png_ptr->transformations & PNG_BGR) + png_do_bgr(row_info, png_ptr->row_buf + 1); +#endif + +#ifdef PNG_WRITE_INVERT_SUPPORTED + if (png_ptr->transformations & PNG_INVERT_MONO) + png_do_invert(row_info, png_ptr->row_buf + 1); +#endif +} + +#ifdef PNG_WRITE_PACK_SUPPORTED +/* Pack pixels into bytes. Pass the true bit depth in bit_depth. The + * row_info bit depth should be 8 (one pixel per byte). The channels + * should be 1 (this only happens on grayscale and paletted images). + */ +void /* PRIVATE */ +png_do_pack(png_row_infop row_info, png_bytep row, png_uint_32 bit_depth) +{ + png_debug(1, "in png_do_pack"); + + if (row_info->bit_depth == 8 && + row_info->channels == 1) + { + switch ((int)bit_depth) + { + case 1: + { + png_bytep sp, dp; + int mask, v; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + sp = row; + dp = row; + mask = 0x80; + v = 0; + + for (i = 0; i < row_width; i++) + { + if (*sp != 0) + v |= mask; + + sp++; + + if (mask > 1) + mask >>= 1; + + else + { + mask = 0x80; + *dp = (png_byte)v; + dp++; + v = 0; + } + } + + if (mask != 0x80) + *dp = (png_byte)v; + + break; + } + + case 2: + { + png_bytep sp, dp; + int shift, v; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + sp = row; + dp = row; + shift = 6; + v = 0; + + for (i = 0; i < row_width; i++) + { + png_byte value; + + value = (png_byte)(*sp & 0x03); + v |= (value << shift); + + if (shift == 0) + { + shift = 6; + *dp = (png_byte)v; + dp++; + v = 0; + } + + else + shift -= 2; + + sp++; + } + + if (shift != 6) + *dp = (png_byte)v; + + break; + } + + case 4: + { + png_bytep sp, dp; + int shift, v; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + sp = row; + dp = row; + shift = 4; + v = 0; + + for (i = 0; i < row_width; i++) + { + png_byte value; + + value = (png_byte)(*sp & 0x0f); + v |= (value << shift); + + if (shift == 0) + { + shift = 4; + *dp = (png_byte)v; + dp++; + v = 0; + } + + else + shift -= 4; + + sp++; + } + + if (shift != 4) + *dp = (png_byte)v; + + break; + } + + default: + break; + } + + row_info->bit_depth = (png_byte)bit_depth; + row_info->pixel_depth = (png_byte)(bit_depth * row_info->channels); + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, + row_info->width); + } +} +#endif + +#ifdef PNG_WRITE_SHIFT_SUPPORTED +/* Shift pixel values to take advantage of whole range. Pass the + * true number of bits in bit_depth. The row should be packed + * according to row_info->bit_depth. Thus, if you had a row of + * bit depth 4, but the pixels only had values from 0 to 7, you + * would pass 3 as bit_depth, and this routine would translate the + * data to 0 to 15. + */ +void /* PRIVATE */ +png_do_shift(png_row_infop row_info, png_bytep row, + png_const_color_8p bit_depth) +{ + png_debug(1, "in png_do_shift"); + + if (row_info->color_type != PNG_COLOR_TYPE_PALETTE) + { + int shift_start[4], shift_dec[4]; + int channels = 0; + + if (row_info->color_type & PNG_COLOR_MASK_COLOR) + { + shift_start[channels] = row_info->bit_depth - bit_depth->red; + shift_dec[channels] = bit_depth->red; + channels++; + + shift_start[channels] = row_info->bit_depth - bit_depth->green; + shift_dec[channels] = bit_depth->green; + channels++; + + shift_start[channels] = row_info->bit_depth - bit_depth->blue; + shift_dec[channels] = bit_depth->blue; + channels++; + } + + else + { + shift_start[channels] = row_info->bit_depth - bit_depth->gray; + shift_dec[channels] = bit_depth->gray; + channels++; + } + + if (row_info->color_type & PNG_COLOR_MASK_ALPHA) + { + shift_start[channels] = row_info->bit_depth - bit_depth->alpha; + shift_dec[channels] = bit_depth->alpha; + channels++; + } + + /* With low row depths, could only be grayscale, so one channel */ + if (row_info->bit_depth < 8) + { + png_bytep bp = row; + png_size_t i; + unsigned int mask; + png_size_t row_bytes = row_info->rowbytes; + + if (bit_depth->gray == 1 && row_info->bit_depth == 2) + mask = 0x55; + + else if (row_info->bit_depth == 4 && bit_depth->gray == 3) + mask = 0x11; + + else + mask = 0xff; + + for (i = 0; i < row_bytes; i++, bp++) + { + int j; + unsigned int v, out; + + v = *bp; + out = 0; + + for (j = shift_start[0]; j > -shift_dec[0]; j -= shift_dec[0]) + { + if (j > 0) + out |= v << j; + + else + out |= (v >> (-j)) & mask; + } + + *bp = (png_byte)(out & 0xff); + } + } + + else if (row_info->bit_depth == 8) + { + png_bytep bp = row; + png_uint_32 i; + png_uint_32 istop = channels * row_info->width; + + for (i = 0; i < istop; i++, bp++) + { + + const unsigned int c = i%channels; + int j; + unsigned int v, out; + + v = *bp; + out = 0; + + for (j = shift_start[c]; j > -shift_dec[c]; j -= shift_dec[c]) + { + if (j > 0) + out |= v << j; + + else + out |= v >> (-j); + } + + *bp = (png_byte)(out & 0xff); + } + } + + else + { + png_bytep bp; + png_uint_32 i; + png_uint_32 istop = channels * row_info->width; + + for (bp = row, i = 0; i < istop; i++) + { + const unsigned int c = i%channels; + int j; + unsigned int value, v; + + v = png_get_uint_16(bp); + value = 0; + + for (j = shift_start[c]; j > -shift_dec[c]; j -= shift_dec[c]) + { + if (j > 0) + value |= v << j; + + else + value |= v >> (-j); + } + *bp++ = (png_byte)((value >> 8) & 0xff); + *bp++ = (png_byte)(value & 0xff); + } + } + } +} +#endif + +#ifdef PNG_WRITE_SWAP_ALPHA_SUPPORTED +void /* PRIVATE */ +png_do_write_swap_alpha(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_write_swap_alpha"); + + { + if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + { + if (row_info->bit_depth == 8) + { + /* This converts from ARGB to RGBA */ + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + for (i = 0, sp = dp = row; i < row_width; i++) + { + png_byte save = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = save; + } + } + +#ifdef PNG_WRITE_16BIT_SUPPORTED + else + { + /* This converts from AARRGGBB to RRGGBBAA */ + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + for (i = 0, sp = dp = row; i < row_width; i++) + { + png_byte save[2]; + save[0] = *(sp++); + save[1] = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = save[0]; + *(dp++) = save[1]; + } + } +#endif /* PNG_WRITE_16BIT_SUPPORTED */ + } + + else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA) + { + if (row_info->bit_depth == 8) + { + /* This converts from AG to GA */ + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + for (i = 0, sp = dp = row; i < row_width; i++) + { + png_byte save = *(sp++); + *(dp++) = *(sp++); + *(dp++) = save; + } + } + +#ifdef PNG_WRITE_16BIT_SUPPORTED + else + { + /* This converts from AAGG to GGAA */ + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + for (i = 0, sp = dp = row; i < row_width; i++) + { + png_byte save[2]; + save[0] = *(sp++); + save[1] = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = save[0]; + *(dp++) = save[1]; + } + } +#endif /* PNG_WRITE_16BIT_SUPPORTED */ + } + } +} +#endif + +#ifdef PNG_WRITE_INVERT_ALPHA_SUPPORTED +void /* PRIVATE */ +png_do_write_invert_alpha(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_write_invert_alpha"); + + { + if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + { + if (row_info->bit_depth == 8) + { + /* This inverts the alpha channel in RGBA */ + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + for (i = 0, sp = dp = row; i < row_width; i++) + { + /* Does nothing + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + */ + sp+=3; dp = sp; + *(dp++) = (png_byte)(255 - *(sp++)); + } + } + +#ifdef PNG_WRITE_16BIT_SUPPORTED + else + { + /* This inverts the alpha channel in RRGGBBAA */ + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + for (i = 0, sp = dp = row; i < row_width; i++) + { + /* Does nothing + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + *(dp++) = *(sp++); + */ + sp+=6; dp = sp; + *(dp++) = (png_byte)(255 - *(sp++)); + *(dp++) = (png_byte)(255 - *(sp++)); + } + } +#endif /* PNG_WRITE_16BIT_SUPPORTED */ + } + + else if (row_info->color_type == PNG_COLOR_TYPE_GRAY_ALPHA) + { + if (row_info->bit_depth == 8) + { + /* This inverts the alpha channel in GA */ + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + for (i = 0, sp = dp = row; i < row_width; i++) + { + *(dp++) = *(sp++); + *(dp++) = (png_byte)(255 - *(sp++)); + } + } + +#ifdef PNG_WRITE_16BIT_SUPPORTED + else + { + /* This inverts the alpha channel in GGAA */ + png_bytep sp, dp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + for (i = 0, sp = dp = row; i < row_width; i++) + { + /* Does nothing + *(dp++) = *(sp++); + *(dp++) = *(sp++); + */ + sp+=2; dp = sp; + *(dp++) = (png_byte)(255 - *(sp++)); + *(dp++) = (png_byte)(255 - *(sp++)); + } + } +#endif /* PNG_WRITE_16BIT_SUPPORTED */ + } + } +} +#endif +#endif /* PNG_WRITE_TRANSFORMS_SUPPORTED */ + +#ifdef PNG_MNG_FEATURES_SUPPORTED +/* Undoes intrapixel differencing */ +void /* PRIVATE */ +png_do_write_intrapixel(png_row_infop row_info, png_bytep row) +{ + png_debug(1, "in png_do_write_intrapixel"); + + if ((row_info->color_type & PNG_COLOR_MASK_COLOR)) + { + int bytes_per_pixel; + png_uint_32 row_width = row_info->width; + if (row_info->bit_depth == 8) + { + png_bytep rp; + png_uint_32 i; + + if (row_info->color_type == PNG_COLOR_TYPE_RGB) + bytes_per_pixel = 3; + + else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + bytes_per_pixel = 4; + + else + return; + + for (i = 0, rp = row; i < row_width; i++, rp += bytes_per_pixel) + { + *(rp) = (png_byte)((*rp - *(rp + 1)) & 0xff); + *(rp + 2) = (png_byte)((*(rp + 2) - *(rp + 1)) & 0xff); + } + } + +#ifdef PNG_WRITE_16BIT_SUPPORTED + else if (row_info->bit_depth == 16) + { + png_bytep rp; + png_uint_32 i; + + if (row_info->color_type == PNG_COLOR_TYPE_RGB) + bytes_per_pixel = 6; + + else if (row_info->color_type == PNG_COLOR_TYPE_RGB_ALPHA) + bytes_per_pixel = 8; + + else + return; + + for (i = 0, rp = row; i < row_width; i++, rp += bytes_per_pixel) + { + png_uint_32 s0 = (*(rp ) << 8) | *(rp + 1); + png_uint_32 s1 = (*(rp + 2) << 8) | *(rp + 3); + png_uint_32 s2 = (*(rp + 4) << 8) | *(rp + 5); + png_uint_32 red = (png_uint_32)((s0 - s1) & 0xffffL); + png_uint_32 blue = (png_uint_32)((s2 - s1) & 0xffffL); + *(rp ) = (png_byte)((red >> 8) & 0xff); + *(rp + 1) = (png_byte)(red & 0xff); + *(rp + 4) = (png_byte)((blue >> 8) & 0xff); + *(rp + 5) = (png_byte)(blue & 0xff); + } + } +#endif /* PNG_WRITE_16BIT_SUPPORTED */ + } +} +#endif /* PNG_MNG_FEATURES_SUPPORTED */ +#endif /* PNG_WRITE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/libpng/pngwutil.c b/ml/dlib/dlib/external/libpng/pngwutil.c new file mode 100644 index 000000000..49e6a2d21 --- /dev/null +++ b/ml/dlib/dlib/external/libpng/pngwutil.c @@ -0,0 +1,3023 @@ + +/* pngwutil.c - utilities to write a PNG file + * + * Last changed in libpng 1.6.2 [April 25, 2013] + * Copyright (c) 1998-2013 Glenn Randers-Pehrson + * (Version 0.96 Copyright (c) 1996, 1997 Andreas Dilger) + * (Version 0.88 Copyright (c) 1995, 1996 Guy Eric Schalnat, Group 42, Inc.) + * + * This code is released under the libpng license. + * For conditions of distribution and use, see the disclaimer + * and license in png.h + */ + +#include "pngpriv.h" + +#ifdef PNG_WRITE_SUPPORTED + +#ifdef PNG_WRITE_INT_FUNCTIONS_SUPPORTED +/* Place a 32-bit number into a buffer in PNG byte order. We work + * with unsigned numbers for convenience, although one supported + * ancillary chunk uses signed (two's complement) numbers. + */ +void PNGAPI +png_save_uint_32(png_bytep buf, png_uint_32 i) +{ + buf[0] = (png_byte)((i >> 24) & 0xff); + buf[1] = (png_byte)((i >> 16) & 0xff); + buf[2] = (png_byte)((i >> 8) & 0xff); + buf[3] = (png_byte)(i & 0xff); +} + +/* Place a 16-bit number into a buffer in PNG byte order. + * The parameter is declared unsigned int, not png_uint_16, + * just to avoid potential problems on pre-ANSI C compilers. + */ +void PNGAPI +png_save_uint_16(png_bytep buf, unsigned int i) +{ + buf[0] = (png_byte)((i >> 8) & 0xff); + buf[1] = (png_byte)(i & 0xff); +} +#endif + +/* Simple function to write the signature. If we have already written + * the magic bytes of the signature, or more likely, the PNG stream is + * being embedded into another stream and doesn't need its own signature, + * we should call png_set_sig_bytes() to tell libpng how many of the + * bytes have already been written. + */ +void PNGAPI +png_write_sig(png_structrp png_ptr) +{ + png_byte png_signature[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + +#ifdef PNG_IO_STATE_SUPPORTED + /* Inform the I/O callback that the signature is being written */ + png_ptr->io_state = PNG_IO_WRITING | PNG_IO_SIGNATURE; +#endif + + /* Write the rest of the 8 byte signature */ + png_write_data(png_ptr, &png_signature[png_ptr->sig_bytes], + (png_size_t)(8 - png_ptr->sig_bytes)); + + if (png_ptr->sig_bytes < 3) + png_ptr->mode |= PNG_HAVE_PNG_SIGNATURE; +} + +/* Write the start of a PNG chunk. The type is the chunk type. + * The total_length is the sum of the lengths of all the data you will be + * passing in png_write_chunk_data(). + */ +static void +png_write_chunk_header(png_structrp png_ptr, png_uint_32 chunk_name, + png_uint_32 length) +{ + png_byte buf[8]; + +#if defined(PNG_DEBUG) && (PNG_DEBUG > 0) + PNG_CSTRING_FROM_CHUNK(buf, chunk_name); + png_debug2(0, "Writing %s chunk, length = %lu", buf, (unsigned long)length); +#endif + + if (png_ptr == NULL) + return; + +#ifdef PNG_IO_STATE_SUPPORTED + /* Inform the I/O callback that the chunk header is being written. + * PNG_IO_CHUNK_HDR requires a single I/O call. + */ + png_ptr->io_state = PNG_IO_WRITING | PNG_IO_CHUNK_HDR; +#endif + + /* Write the length and the chunk name */ + png_save_uint_32(buf, length); + png_save_uint_32(buf + 4, chunk_name); + png_write_data(png_ptr, buf, 8); + + /* Put the chunk name into png_ptr->chunk_name */ + png_ptr->chunk_name = chunk_name; + + /* Reset the crc and run it over the chunk name */ + png_reset_crc(png_ptr); + + png_calculate_crc(png_ptr, buf + 4, 4); + +#ifdef PNG_IO_STATE_SUPPORTED + /* Inform the I/O callback that chunk data will (possibly) be written. + * PNG_IO_CHUNK_DATA does NOT require a specific number of I/O calls. + */ + png_ptr->io_state = PNG_IO_WRITING | PNG_IO_CHUNK_DATA; +#endif +} + +void PNGAPI +png_write_chunk_start(png_structrp png_ptr, png_const_bytep chunk_string, + png_uint_32 length) +{ + png_write_chunk_header(png_ptr, PNG_CHUNK_FROM_STRING(chunk_string), length); +} + +/* Write the data of a PNG chunk started with png_write_chunk_header(). + * Note that multiple calls to this function are allowed, and that the + * sum of the lengths from these calls *must* add up to the total_length + * given to png_write_chunk_header(). + */ +void PNGAPI +png_write_chunk_data(png_structrp png_ptr, png_const_bytep data, + png_size_t length) +{ + /* Write the data, and run the CRC over it */ + if (png_ptr == NULL) + return; + + if (data != NULL && length > 0) + { + png_write_data(png_ptr, data, length); + + /* Update the CRC after writing the data, + * in case that the user I/O routine alters it. + */ + png_calculate_crc(png_ptr, data, length); + } +} + +/* Finish a chunk started with png_write_chunk_header(). */ +void PNGAPI +png_write_chunk_end(png_structrp png_ptr) +{ + png_byte buf[4]; + + if (png_ptr == NULL) return; + +#ifdef PNG_IO_STATE_SUPPORTED + /* Inform the I/O callback that the chunk CRC is being written. + * PNG_IO_CHUNK_CRC requires a single I/O function call. + */ + png_ptr->io_state = PNG_IO_WRITING | PNG_IO_CHUNK_CRC; +#endif + + /* Write the crc in a single operation */ + png_save_uint_32(buf, png_ptr->crc); + + png_write_data(png_ptr, buf, (png_size_t)4); +} + +/* Write a PNG chunk all at once. The type is an array of ASCII characters + * representing the chunk name. The array must be at least 4 bytes in + * length, and does not need to be null terminated. To be safe, pass the + * pre-defined chunk names here, and if you need a new one, define it + * where the others are defined. The length is the length of the data. + * All the data must be present. If that is not possible, use the + * png_write_chunk_start(), png_write_chunk_data(), and png_write_chunk_end() + * functions instead. + */ +static void +png_write_complete_chunk(png_structrp png_ptr, png_uint_32 chunk_name, + png_const_bytep data, png_size_t length) +{ + if (png_ptr == NULL) + return; + + /* On 64 bit architectures 'length' may not fit in a png_uint_32. */ + if (length > PNG_UINT_31_MAX) + png_error(png_ptr, "length exceeds PNG maxima"); + + png_write_chunk_header(png_ptr, chunk_name, (png_uint_32)length); + png_write_chunk_data(png_ptr, data, length); + png_write_chunk_end(png_ptr); +} + +/* This is the API that calls the internal function above. */ +void PNGAPI +png_write_chunk(png_structrp png_ptr, png_const_bytep chunk_string, + png_const_bytep data, png_size_t length) +{ + png_write_complete_chunk(png_ptr, PNG_CHUNK_FROM_STRING(chunk_string), data, + length); +} + +/* This is used below to find the size of an image to pass to png_deflate_claim, + * so it only needs to be accurate if the size is less than 16384 bytes (the + * point at which a lower LZ window size can be used.) + */ +static png_alloc_size_t +png_image_size(png_structrp png_ptr) +{ + /* Only return sizes up to the maximum of a png_uint_32, do this by limiting + * the width and height used to 15 bits. + */ + png_uint_32 h = png_ptr->height; + + if (png_ptr->rowbytes < 32768 && h < 32768) + { + if (png_ptr->interlaced) + { + /* Interlacing makes the image larger because of the replication of + * both the filter byte and the padding to a byte boundary. + */ + png_uint_32 w = png_ptr->width; + unsigned int pd = png_ptr->pixel_depth; + png_alloc_size_t cb_base; + int pass; + + for (cb_base=0, pass=0; pass<=6; ++pass) + { + png_uint_32 pw = PNG_PASS_COLS(w, pass); + + if (pw > 0) + cb_base += (PNG_ROWBYTES(pd, pw)+1) * PNG_PASS_ROWS(h, pass); + } + + return cb_base; + } + + else + return (png_ptr->rowbytes+1) * h; + } + + else + return 0xffffffffU; +} + +#ifdef PNG_WRITE_OPTIMIZE_CMF_SUPPORTED + /* This is the code to hack the first two bytes of the deflate stream (the + * deflate header) to correct the windowBits value to match the actual data + * size. Note that the second argument is the *uncompressed* size but the + * first argument is the *compressed* data (and it must be deflate + * compressed.) + */ +static void +optimize_cmf(png_bytep data, png_alloc_size_t data_size) +{ + /* Optimize the CMF field in the zlib stream. The resultant zlib stream is + * still compliant to the stream specification. + */ + if (data_size <= 16384) /* else windowBits must be 15 */ + { + unsigned int z_cmf = data[0]; /* zlib compression method and flags */ + + if ((z_cmf & 0x0f) == 8 && (z_cmf & 0xf0) <= 0x70) + { + unsigned int z_cinfo; + unsigned int half_z_window_size; + + z_cinfo = z_cmf >> 4; + half_z_window_size = 1U << (z_cinfo + 7); + + if (data_size <= half_z_window_size) /* else no change */ + { + unsigned int tmp; + + do + { + half_z_window_size >>= 1; + --z_cinfo; + } + while (z_cinfo > 0 && data_size <= half_z_window_size); + + z_cmf = (z_cmf & 0x0f) | (z_cinfo << 4); + + data[0] = (png_byte)z_cmf; + tmp = data[1] & 0xe0; + tmp += 0x1f - ((z_cmf << 8) + tmp) % 0x1f; + data[1] = (png_byte)tmp; + } + } + } +} +#else +# define optimize_cmf(dp,dl) ((void)0) +#endif /* PNG_WRITE_OPTIMIZE_CMF_SUPPORTED */ + +/* Initialize the compressor for the appropriate type of compression. */ +static int +png_deflate_claim(png_structrp png_ptr, png_uint_32 owner, + png_alloc_size_t data_size) +{ + if (png_ptr->zowner != 0) + { + char msg[64]; + + PNG_STRING_FROM_CHUNK(msg, owner); + msg[4] = ':'; + msg[5] = ' '; + PNG_STRING_FROM_CHUNK(msg+6, png_ptr->zowner); + /* So the message that results is " using zstream"; this is an + * internal error, but is very useful for debugging. i18n requirements + * are minimal. + */ + (void)png_safecat(msg, (sizeof msg), 10, " using zstream"); +# if PNG_LIBPNG_BUILD_BASE_TYPE >= PNG_LIBPNG_BUILD_RC + png_warning(png_ptr, msg); + + /* Attempt sane error recovery */ + if (png_ptr->zowner == png_IDAT) /* don't steal from IDAT */ + { + png_ptr->zstream.msg = PNGZ_MSG_CAST("in use by IDAT"); + return Z_STREAM_ERROR; + } + + png_ptr->zowner = 0; +# else + png_error(png_ptr, msg); +# endif + } + + { + int level = png_ptr->zlib_level; + int method = png_ptr->zlib_method; + int windowBits = png_ptr->zlib_window_bits; + int memLevel = png_ptr->zlib_mem_level; + int strategy; /* set below */ + int ret; /* zlib return code */ + + if (owner == png_IDAT) + { + if (png_ptr->flags & PNG_FLAG_ZLIB_CUSTOM_STRATEGY) + strategy = png_ptr->zlib_strategy; + + else if (png_ptr->do_filter != PNG_FILTER_NONE) + strategy = PNG_Z_DEFAULT_STRATEGY; + + else + strategy = PNG_Z_DEFAULT_NOFILTER_STRATEGY; + } + + else + { +# ifdef PNG_WRITE_CUSTOMIZE_ZTXT_COMPRESSION_SUPPORTED + level = png_ptr->zlib_text_level; + method = png_ptr->zlib_text_method; + windowBits = png_ptr->zlib_text_window_bits; + memLevel = png_ptr->zlib_text_mem_level; + strategy = png_ptr->zlib_text_strategy; +# else + /* If customization is not supported the values all come from the + * IDAT values except for the strategy, which is fixed to the + * default. (This is the pre-1.6.0 behavior too, although it was + * implemented in a very different way.) + */ + strategy = Z_DEFAULT_STRATEGY; +# endif + } + + /* Adjust 'windowBits' down if larger than 'data_size'; to stop this + * happening just pass 32768 as the data_size parameter. Notice that zlib + * requires an extra 262 bytes in the window in addition to the data to be + * able to see the whole of the data, so if data_size+262 takes us to the + * next windowBits size we need to fix up the value later. (Because even + * though deflate needs the extra window, inflate does not!) + */ + if (data_size <= 16384) + { + /* IMPLEMENTATION NOTE: this 'half_window_size' stuff is only here to + * work round a Microsoft Visual C misbehavior which, contrary to C-90, + * widens the result of the following shift to 64-bits if (and, + * apparently, only if) it is used in a test. + */ + unsigned int half_window_size = 1U << (windowBits-1); + + while (data_size + 262 <= half_window_size) + { + half_window_size >>= 1; + --windowBits; + } + } + + /* Check against the previous initialized values, if any. */ + if ((png_ptr->flags & PNG_FLAG_ZSTREAM_INITIALIZED) && + (png_ptr->zlib_set_level != level || + png_ptr->zlib_set_method != method || + png_ptr->zlib_set_window_bits != windowBits || + png_ptr->zlib_set_mem_level != memLevel || + png_ptr->zlib_set_strategy != strategy)) + { + if (deflateEnd(&png_ptr->zstream) != Z_OK) + png_warning(png_ptr, "deflateEnd failed (ignored)"); + + png_ptr->flags &= ~PNG_FLAG_ZSTREAM_INITIALIZED; + } + + /* For safety clear out the input and output pointers (currently zlib + * doesn't use them on Init, but it might in the future). + */ + png_ptr->zstream.next_in = NULL; + png_ptr->zstream.avail_in = 0; + png_ptr->zstream.next_out = NULL; + png_ptr->zstream.avail_out = 0; + + /* Now initialize if required, setting the new parameters, otherwise just + * to a simple reset to the previous parameters. + */ + if (png_ptr->flags & PNG_FLAG_ZSTREAM_INITIALIZED) + ret = deflateReset(&png_ptr->zstream); + + else + { + ret = deflateInit2(&png_ptr->zstream, level, method, windowBits, + memLevel, strategy); + + if (ret == Z_OK) + png_ptr->flags |= PNG_FLAG_ZSTREAM_INITIALIZED; + } + + /* The return code is from either deflateReset or deflateInit2; they have + * pretty much the same set of error codes. + */ + if (ret == Z_OK) + png_ptr->zowner = owner; + + else + png_zstream_error(png_ptr, ret); + + return ret; + } +} + +/* Clean up (or trim) a linked list of compression buffers. */ +void /* PRIVATE */ +png_free_buffer_list(png_structrp png_ptr, png_compression_bufferp *listp) +{ + png_compression_bufferp list = *listp; + + if (list != NULL) + { + *listp = NULL; + + do + { + png_compression_bufferp next = list->next; + + png_free(png_ptr, list); + list = next; + } + while (list != NULL); + } +} + +#ifdef PNG_WRITE_COMPRESSED_TEXT_SUPPORTED +/* This pair of functions encapsulates the operation of (a) compressing a + * text string, and (b) issuing it later as a series of chunk data writes. + * The compression_state structure is shared context for these functions + * set up by the caller to allow access to the relevant local variables. + * + * compression_buffer (new in 1.6.0) is just a linked list of zbuffer_size + * temporary buffers. From 1.6.0 it is retained in png_struct so that it will + * be correctly freed in the event of a write error (previous implementations + * just leaked memory.) + */ +typedef struct +{ + png_const_bytep input; /* The uncompressed input data */ + png_alloc_size_t input_len; /* Its length */ + png_uint_32 output_len; /* Final compressed length */ + png_byte output[1024]; /* First block of output */ +} compression_state; + +static void +png_text_compress_init(compression_state *comp, png_const_bytep input, + png_alloc_size_t input_len) +{ + comp->input = input; + comp->input_len = input_len; + comp->output_len = 0; +} + +/* Compress the data in the compression state input */ +static int +png_text_compress(png_structrp png_ptr, png_uint_32 chunk_name, + compression_state *comp, png_uint_32 prefix_len) +{ + int ret; + + /* To find the length of the output it is necessary to first compress the + * input, the result is buffered rather than using the two-pass algorithm + * that is used on the inflate side; deflate is assumed to be slower and a + * PNG writer is assumed to have more memory available than a PNG reader. + * + * IMPLEMENTATION NOTE: the zlib API deflateBound() can be used to find an + * upper limit on the output size, but it is always bigger than the input + * size so it is likely to be more efficient to use this linked-list + * approach. + */ + ret = png_deflate_claim(png_ptr, chunk_name, comp->input_len); + + if (ret != Z_OK) + return ret; + + /* Set up the compression buffers, we need a loop here to avoid overflowing a + * uInt. Use ZLIB_IO_MAX to limit the input. The output is always limited + * by the output buffer size, so there is no need to check that. Since this + * is ANSI-C we know that an 'int', hence a uInt, is always at least 16 bits + * in size. + */ + { + png_compression_bufferp *end = &png_ptr->zbuffer_list; + png_alloc_size_t input_len = comp->input_len; /* may be zero! */ + png_uint_32 output_len; + + /* zlib updates these for us: */ + png_ptr->zstream.next_in = PNGZ_INPUT_CAST(comp->input); + png_ptr->zstream.avail_in = 0; /* Set below */ + png_ptr->zstream.next_out = comp->output; + png_ptr->zstream.avail_out = (sizeof comp->output); + + output_len = png_ptr->zstream.avail_out; + + do + { + uInt avail_in = ZLIB_IO_MAX; + + if (avail_in > input_len) + avail_in = (uInt)input_len; + + input_len -= avail_in; + + png_ptr->zstream.avail_in = avail_in; + + if (png_ptr->zstream.avail_out == 0) + { + png_compression_buffer *next; + + /* Chunk data is limited to 2^31 bytes in length, so the prefix + * length must be counted here. + */ + if (output_len + prefix_len > PNG_UINT_31_MAX) + { + ret = Z_MEM_ERROR; + break; + } + + /* Need a new (malloc'ed) buffer, but there may be one present + * already. + */ + next = *end; + if (next == NULL) + { + next = png_voidcast(png_compression_bufferp, png_malloc_base + (png_ptr, PNG_COMPRESSION_BUFFER_SIZE(png_ptr))); + + if (next == NULL) + { + ret = Z_MEM_ERROR; + break; + } + + /* Link in this buffer (so that it will be freed later) */ + next->next = NULL; + *end = next; + } + + png_ptr->zstream.next_out = next->output; + png_ptr->zstream.avail_out = png_ptr->zbuffer_size; + output_len += png_ptr->zstream.avail_out; + + /* Move 'end' to the next buffer pointer. */ + end = &next->next; + } + + /* Compress the data */ + ret = deflate(&png_ptr->zstream, + input_len > 0 ? Z_NO_FLUSH : Z_FINISH); + + /* Claw back input data that was not consumed (because avail_in is + * reset above every time round the loop). + */ + input_len += png_ptr->zstream.avail_in; + png_ptr->zstream.avail_in = 0; /* safety */ + } + while (ret == Z_OK); + + /* There may be some space left in the last output buffer, this needs to + * be subtracted from output_len. + */ + output_len -= png_ptr->zstream.avail_out; + png_ptr->zstream.avail_out = 0; /* safety */ + comp->output_len = output_len; + + /* Now double check the output length, put in a custom message if it is + * too long. Otherwise ensure the z_stream::msg pointer is set to + * something. + */ + if (output_len + prefix_len >= PNG_UINT_31_MAX) + { + png_ptr->zstream.msg = PNGZ_MSG_CAST("compressed data too long"); + ret = Z_MEM_ERROR; + } + + else + png_zstream_error(png_ptr, ret); + + /* Reset zlib for another zTXt/iTXt or image data */ + png_ptr->zowner = 0; + + /* The only success case is Z_STREAM_END, input_len must be 0, if not this + * is an internal error. + */ + if (ret == Z_STREAM_END && input_len == 0) + { + /* Fix up the deflate header, if required */ + optimize_cmf(comp->output, comp->input_len); + + /* But Z_OK is returned, not Z_STREAM_END; this allows the claim + * function above to return Z_STREAM_END on an error (though it never + * does in the current versions of zlib.) + */ + return Z_OK; + } + + else + return ret; + } +} + +/* Ship the compressed text out via chunk writes */ +static void +png_write_compressed_data_out(png_structrp png_ptr, compression_state *comp) +{ + png_uint_32 output_len = comp->output_len; + png_const_bytep output = comp->output; + png_uint_32 avail = (sizeof comp->output); + png_compression_buffer *next = png_ptr->zbuffer_list; + + for (;;) + { + if (avail > output_len) + avail = output_len; + + png_write_chunk_data(png_ptr, output, avail); + + output_len -= avail; + + if (output_len == 0 || next == NULL) + break; + + avail = png_ptr->zbuffer_size; + output = next->output; + next = next->next; + } + + /* This is an internal error; 'next' must have been NULL! */ + if (output_len > 0) + png_error(png_ptr, "error writing ancillary chunked compressed data"); +} +#endif /* PNG_WRITE_COMPRESSED_TEXT_SUPPORTED */ + +#if defined(PNG_WRITE_TEXT_SUPPORTED) || defined(PNG_WRITE_pCAL_SUPPORTED) || \ + defined(PNG_WRITE_iCCP_SUPPORTED) || defined(PNG_WRITE_sPLT_SUPPORTED) +/* Check that the tEXt or zTXt keyword is valid per PNG 1.0 specification, + * and if invalid, correct the keyword rather than discarding the entire + * chunk. The PNG 1.0 specification requires keywords 1-79 characters in + * length, forbids leading or trailing whitespace, multiple internal spaces, + * and the non-break space (0x80) from ISO 8859-1. Returns keyword length. + * + * The 'new_key' buffer must be 80 characters in size (for the keyword plus a + * trailing '\0'). If this routine returns 0 then there was no keyword, or a + * valid one could not be generated, and the caller must png_error. + */ +static png_uint_32 +png_check_keyword(png_structrp png_ptr, png_const_charp key, png_bytep new_key) +{ + png_const_charp orig_key = key; + png_uint_32 key_len = 0; + int bad_character = 0; + int space = 1; + + png_debug(1, "in png_check_keyword"); + + if (key == NULL) + { + *new_key = 0; + return 0; + } + + while (*key && key_len < 79) + { + png_byte ch = (png_byte)(0xff & *key++); + + if ((ch > 32 && ch <= 126) || (ch >= 161 /*&& ch <= 255*/)) + *new_key++ = ch, ++key_len, space = 0; + + else if (!space) + { + /* A space or an invalid character when one wasn't seen immediately + * before; output just a space. + */ + *new_key++ = 32, ++key_len, space = 1; + + /* If the character was not a space then it is invalid. */ + if (ch != 32) + bad_character = ch; + } + + else if (!bad_character) + bad_character = ch; /* just skip it, record the first error */ + } + + if (key_len > 0 && space) /* trailing space */ + { + --key_len, --new_key; + if (!bad_character) + bad_character = 32; + } + + /* Terminate the keyword */ + *new_key = 0; + + if (key_len == 0) + return 0; + + /* Try to only output one warning per keyword: */ + if (*key) /* keyword too long */ + png_warning(png_ptr, "keyword truncated"); + + else if (bad_character) + { + PNG_WARNING_PARAMETERS(p) + + png_warning_parameter(p, 1, orig_key); + png_warning_parameter_signed(p, 2, PNG_NUMBER_FORMAT_02x, bad_character); + + png_formatted_warning(png_ptr, p, "keyword \"@1\": bad character '0x@2'"); + } + + return key_len; +} +#endif + +/* Write the IHDR chunk, and update the png_struct with the necessary + * information. Note that the rest of this code depends upon this + * information being correct. + */ +void /* PRIVATE */ +png_write_IHDR(png_structrp png_ptr, png_uint_32 width, png_uint_32 height, + int bit_depth, int color_type, int compression_type, int filter_type, + int interlace_type) +{ + png_byte buf[13]; /* Buffer to store the IHDR info */ + + png_debug(1, "in png_write_IHDR"); + + /* Check that we have valid input data from the application info */ + switch (color_type) + { + case PNG_COLOR_TYPE_GRAY: + switch (bit_depth) + { + case 1: + case 2: + case 4: + case 8: +#ifdef PNG_WRITE_16BIT_SUPPORTED + case 16: +#endif + png_ptr->channels = 1; break; + + default: + png_error(png_ptr, + "Invalid bit depth for grayscale image"); + } + break; + + case PNG_COLOR_TYPE_RGB: +#ifdef PNG_WRITE_16BIT_SUPPORTED + if (bit_depth != 8 && bit_depth != 16) +#else + if (bit_depth != 8) +#endif + png_error(png_ptr, "Invalid bit depth for RGB image"); + + png_ptr->channels = 3; + break; + + case PNG_COLOR_TYPE_PALETTE: + switch (bit_depth) + { + case 1: + case 2: + case 4: + case 8: + png_ptr->channels = 1; + break; + + default: + png_error(png_ptr, "Invalid bit depth for paletted image"); + } + break; + + case PNG_COLOR_TYPE_GRAY_ALPHA: + if (bit_depth != 8 && bit_depth != 16) + png_error(png_ptr, "Invalid bit depth for grayscale+alpha image"); + + png_ptr->channels = 2; + break; + + case PNG_COLOR_TYPE_RGB_ALPHA: +#ifdef PNG_WRITE_16BIT_SUPPORTED + if (bit_depth != 8 && bit_depth != 16) +#else + if (bit_depth != 8) +#endif + png_error(png_ptr, "Invalid bit depth for RGBA image"); + + png_ptr->channels = 4; + break; + + default: + png_error(png_ptr, "Invalid image color type specified"); + } + + if (compression_type != PNG_COMPRESSION_TYPE_BASE) + { + png_warning(png_ptr, "Invalid compression type specified"); + compression_type = PNG_COMPRESSION_TYPE_BASE; + } + + /* Write filter_method 64 (intrapixel differencing) only if + * 1. Libpng was compiled with PNG_MNG_FEATURES_SUPPORTED and + * 2. Libpng did not write a PNG signature (this filter_method is only + * used in PNG datastreams that are embedded in MNG datastreams) and + * 3. The application called png_permit_mng_features with a mask that + * included PNG_FLAG_MNG_FILTER_64 and + * 4. The filter_method is 64 and + * 5. The color_type is RGB or RGBA + */ + if ( +#ifdef PNG_MNG_FEATURES_SUPPORTED + !((png_ptr->mng_features_permitted & PNG_FLAG_MNG_FILTER_64) && + ((png_ptr->mode&PNG_HAVE_PNG_SIGNATURE) == 0) && + (color_type == PNG_COLOR_TYPE_RGB || + color_type == PNG_COLOR_TYPE_RGB_ALPHA) && + (filter_type == PNG_INTRAPIXEL_DIFFERENCING)) && +#endif + filter_type != PNG_FILTER_TYPE_BASE) + { + png_warning(png_ptr, "Invalid filter type specified"); + filter_type = PNG_FILTER_TYPE_BASE; + } + +#ifdef PNG_WRITE_INTERLACING_SUPPORTED + if (interlace_type != PNG_INTERLACE_NONE && + interlace_type != PNG_INTERLACE_ADAM7) + { + png_warning(png_ptr, "Invalid interlace type specified"); + interlace_type = PNG_INTERLACE_ADAM7; + } +#else + interlace_type=PNG_INTERLACE_NONE; +#endif + + /* Save the relevent information */ + png_ptr->bit_depth = (png_byte)bit_depth; + png_ptr->color_type = (png_byte)color_type; + png_ptr->interlaced = (png_byte)interlace_type; +#ifdef PNG_MNG_FEATURES_SUPPORTED + png_ptr->filter_type = (png_byte)filter_type; +#endif + png_ptr->compression_type = (png_byte)compression_type; + png_ptr->width = width; + png_ptr->height = height; + + png_ptr->pixel_depth = (png_byte)(bit_depth * png_ptr->channels); + png_ptr->rowbytes = PNG_ROWBYTES(png_ptr->pixel_depth, width); + /* Set the usr info, so any transformations can modify it */ + png_ptr->usr_width = png_ptr->width; + png_ptr->usr_bit_depth = png_ptr->bit_depth; + png_ptr->usr_channels = png_ptr->channels; + + /* Pack the header information into the buffer */ + png_save_uint_32(buf, width); + png_save_uint_32(buf + 4, height); + buf[8] = (png_byte)bit_depth; + buf[9] = (png_byte)color_type; + buf[10] = (png_byte)compression_type; + buf[11] = (png_byte)filter_type; + buf[12] = (png_byte)interlace_type; + + /* Write the chunk */ + png_write_complete_chunk(png_ptr, png_IHDR, buf, (png_size_t)13); + + if (!(png_ptr->do_filter)) + { + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE || + png_ptr->bit_depth < 8) + png_ptr->do_filter = PNG_FILTER_NONE; + + else + png_ptr->do_filter = PNG_ALL_FILTERS; + } + + png_ptr->mode = PNG_HAVE_IHDR; /* not READY_FOR_ZTXT */ +} + +/* Write the palette. We are careful not to trust png_color to be in the + * correct order for PNG, so people can redefine it to any convenient + * structure. + */ +void /* PRIVATE */ +png_write_PLTE(png_structrp png_ptr, png_const_colorp palette, + png_uint_32 num_pal) +{ + png_uint_32 i; + png_const_colorp pal_ptr; + png_byte buf[3]; + + png_debug(1, "in png_write_PLTE"); + + if (( +#ifdef PNG_MNG_FEATURES_SUPPORTED + !(png_ptr->mng_features_permitted & PNG_FLAG_MNG_EMPTY_PLTE) && +#endif + num_pal == 0) || num_pal > 256) + { + if (png_ptr->color_type == PNG_COLOR_TYPE_PALETTE) + { + png_error(png_ptr, "Invalid number of colors in palette"); + } + + else + { + png_warning(png_ptr, "Invalid number of colors in palette"); + return; + } + } + + if (!(png_ptr->color_type&PNG_COLOR_MASK_COLOR)) + { + png_warning(png_ptr, + "Ignoring request to write a PLTE chunk in grayscale PNG"); + + return; + } + + png_ptr->num_palette = (png_uint_16)num_pal; + png_debug1(3, "num_palette = %d", png_ptr->num_palette); + + png_write_chunk_header(png_ptr, png_PLTE, (png_uint_32)(num_pal * 3)); +#ifdef PNG_POINTER_INDEXING_SUPPORTED + + for (i = 0, pal_ptr = palette; i < num_pal; i++, pal_ptr++) + { + buf[0] = pal_ptr->red; + buf[1] = pal_ptr->green; + buf[2] = pal_ptr->blue; + png_write_chunk_data(png_ptr, buf, (png_size_t)3); + } + +#else + /* This is a little slower but some buggy compilers need to do this + * instead + */ + pal_ptr=palette; + + for (i = 0; i < num_pal; i++) + { + buf[0] = pal_ptr[i].red; + buf[1] = pal_ptr[i].green; + buf[2] = pal_ptr[i].blue; + png_write_chunk_data(png_ptr, buf, (png_size_t)3); + } + +#endif + png_write_chunk_end(png_ptr); + png_ptr->mode |= PNG_HAVE_PLTE; +} + +/* This is similar to png_text_compress, above, except that it does not require + * all of the data at once and, instead of buffering the compressed result, + * writes it as IDAT chunks. Unlike png_text_compress it *can* png_error out + * because it calls the write interface. As a result it does its own error + * reporting and does not return an error code. In the event of error it will + * just call png_error. The input data length may exceed 32-bits. The 'flush' + * parameter is exactly the same as that to deflate, with the following + * meanings: + * + * Z_NO_FLUSH: normal incremental output of compressed data + * Z_SYNC_FLUSH: do a SYNC_FLUSH, used by png_write_flush + * Z_FINISH: this is the end of the input, do a Z_FINISH and clean up + * + * The routine manages the acquire and release of the png_ptr->zstream by + * checking and (at the end) clearing png_ptr->zowner, it does some sanity + * checks on the 'mode' flags while doing this. + */ +void /* PRIVATE */ +png_compress_IDAT(png_structrp png_ptr, png_const_bytep input, + png_alloc_size_t input_len, int flush) +{ + if (png_ptr->zowner != png_IDAT) + { + /* First time. Ensure we have a temporary buffer for compression and + * trim the buffer list if it has more than one entry to free memory. + * If 'WRITE_COMPRESSED_TEXT' is not set the list will never have been + * created at this point, but the check here is quick and safe. + */ + if (png_ptr->zbuffer_list == NULL) + { + png_ptr->zbuffer_list = png_voidcast(png_compression_bufferp, + png_malloc(png_ptr, PNG_COMPRESSION_BUFFER_SIZE(png_ptr))); + png_ptr->zbuffer_list->next = NULL; + } + + else + png_free_buffer_list(png_ptr, &png_ptr->zbuffer_list->next); + + /* It is a terminal error if we can't claim the zstream. */ + if (png_deflate_claim(png_ptr, png_IDAT, png_image_size(png_ptr)) != Z_OK) + png_error(png_ptr, png_ptr->zstream.msg); + + /* The output state is maintained in png_ptr->zstream, so it must be + * initialized here after the claim. + */ + png_ptr->zstream.next_out = png_ptr->zbuffer_list->output; + png_ptr->zstream.avail_out = png_ptr->zbuffer_size; + } + + /* Now loop reading and writing until all the input is consumed or an error + * terminates the operation. The _out values are maintained across calls to + * this function, but the input must be reset each time. + */ + png_ptr->zstream.next_in = PNGZ_INPUT_CAST(input); + png_ptr->zstream.avail_in = 0; /* set below */ + for (;;) + { + int ret; + + /* INPUT: from the row data */ + uInt avail = ZLIB_IO_MAX; + + if (avail > input_len) + avail = (uInt)input_len; /* safe because of the check */ + + png_ptr->zstream.avail_in = avail; + input_len -= avail; + + ret = deflate(&png_ptr->zstream, input_len > 0 ? Z_NO_FLUSH : flush); + + /* Include as-yet unconsumed input */ + input_len += png_ptr->zstream.avail_in; + png_ptr->zstream.avail_in = 0; + + /* OUTPUT: write complete IDAT chunks when avail_out drops to zero, note + * that these two zstream fields are preserved across the calls, therefore + * there is no need to set these up on entry to the loop. + */ + if (png_ptr->zstream.avail_out == 0) + { + png_bytep data = png_ptr->zbuffer_list->output; + uInt size = png_ptr->zbuffer_size; + + /* Write an IDAT containing the data then reset the buffer. The + * first IDAT may need deflate header optimization. + */ +# ifdef PNG_WRITE_OPTIMIZE_CMF_SUPPORTED + if (!(png_ptr->mode & PNG_HAVE_IDAT) && + png_ptr->compression_type == PNG_COMPRESSION_TYPE_BASE) + optimize_cmf(data, png_image_size(png_ptr)); +# endif + + png_write_complete_chunk(png_ptr, png_IDAT, data, size); + png_ptr->mode |= PNG_HAVE_IDAT; + + png_ptr->zstream.next_out = data; + png_ptr->zstream.avail_out = size; + + /* For SYNC_FLUSH or FINISH it is essential to keep calling zlib with + * the same flush parameter until it has finished output, for NO_FLUSH + * it doesn't matter. + */ + if (ret == Z_OK && flush != Z_NO_FLUSH) + continue; + } + + /* The order of these checks doesn't matter much; it just effect which + * possible error might be detected if multiple things go wrong at once. + */ + if (ret == Z_OK) /* most likely return code! */ + { + /* If all the input has been consumed then just return. If Z_FINISH + * was used as the flush parameter something has gone wrong if we get + * here. + */ + if (input_len == 0) + { + if (flush == Z_FINISH) + png_error(png_ptr, "Z_OK on Z_FINISH with output space"); + + return; + } + } + + else if (ret == Z_STREAM_END && flush == Z_FINISH) + { + /* This is the end of the IDAT data; any pending output must be + * flushed. For small PNG files we may still be at the beginning. + */ + png_bytep data = png_ptr->zbuffer_list->output; + uInt size = png_ptr->zbuffer_size - png_ptr->zstream.avail_out; + +# ifdef PNG_WRITE_OPTIMIZE_CMF_SUPPORTED + if (!(png_ptr->mode & PNG_HAVE_IDAT) && + png_ptr->compression_type == PNG_COMPRESSION_TYPE_BASE) + optimize_cmf(data, png_image_size(png_ptr)); +# endif + + png_write_complete_chunk(png_ptr, png_IDAT, data, size); + png_ptr->zstream.avail_out = 0; + png_ptr->zstream.next_out = NULL; + png_ptr->mode |= PNG_HAVE_IDAT | PNG_AFTER_IDAT; + + png_ptr->zowner = 0; /* Release the stream */ + return; + } + + else + { + /* This is an error condition. */ + png_zstream_error(png_ptr, ret); + png_error(png_ptr, png_ptr->zstream.msg); + } + } +} + +/* Write an IEND chunk */ +void /* PRIVATE */ +png_write_IEND(png_structrp png_ptr) +{ + png_debug(1, "in png_write_IEND"); + + png_write_complete_chunk(png_ptr, png_IEND, NULL, (png_size_t)0); + png_ptr->mode |= PNG_HAVE_IEND; +} + +#ifdef PNG_WRITE_gAMA_SUPPORTED +/* Write a gAMA chunk */ +void /* PRIVATE */ +png_write_gAMA_fixed(png_structrp png_ptr, png_fixed_point file_gamma) +{ + png_byte buf[4]; + + png_debug(1, "in png_write_gAMA"); + + /* file_gamma is saved in 1/100,000ths */ + png_save_uint_32(buf, (png_uint_32)file_gamma); + png_write_complete_chunk(png_ptr, png_gAMA, buf, (png_size_t)4); +} +#endif + +#ifdef PNG_WRITE_sRGB_SUPPORTED +/* Write a sRGB chunk */ +void /* PRIVATE */ +png_write_sRGB(png_structrp png_ptr, int srgb_intent) +{ + png_byte buf[1]; + + png_debug(1, "in png_write_sRGB"); + + if (srgb_intent >= PNG_sRGB_INTENT_LAST) + png_warning(png_ptr, + "Invalid sRGB rendering intent specified"); + + buf[0]=(png_byte)srgb_intent; + png_write_complete_chunk(png_ptr, png_sRGB, buf, (png_size_t)1); +} +#endif + +#ifdef PNG_WRITE_iCCP_SUPPORTED +/* Write an iCCP chunk */ +void /* PRIVATE */ +png_write_iCCP(png_structrp png_ptr, png_const_charp name, + png_const_bytep profile) +{ + png_uint_32 name_len; + png_uint_32 profile_len; + png_byte new_name[81]; /* 1 byte for the compression byte */ + compression_state comp; + + png_debug(1, "in png_write_iCCP"); + + /* These are all internal problems: the profile should have been checked + * before when it was stored. + */ + if (profile == NULL) + png_error(png_ptr, "No profile for iCCP chunk"); /* internal error */ + + profile_len = png_get_uint_32(profile); + + if (profile_len < 132) + png_error(png_ptr, "ICC profile too short"); + + if (profile_len & 0x03) + png_error(png_ptr, "ICC profile length invalid (not a multiple of 4)"); + + { + png_uint_32 embedded_profile_len = png_get_uint_32(profile); + + if (profile_len != embedded_profile_len) + png_error(png_ptr, "Profile length does not match profile"); + } + + name_len = png_check_keyword(png_ptr, name, new_name); + + if (name_len == 0) + png_error(png_ptr, "iCCP: invalid keyword"); + + new_name[++name_len] = PNG_COMPRESSION_TYPE_BASE; + + /* Make sure we include the NULL after the name and the compression type */ + ++name_len; + + png_text_compress_init(&comp, profile, profile_len); + + /* Allow for keyword terminator and compression byte */ + if (png_text_compress(png_ptr, png_iCCP, &comp, name_len) != Z_OK) + png_error(png_ptr, png_ptr->zstream.msg); + + png_write_chunk_header(png_ptr, png_iCCP, name_len + comp.output_len); + + png_write_chunk_data(png_ptr, new_name, name_len); + + png_write_compressed_data_out(png_ptr, &comp); + + png_write_chunk_end(png_ptr); +} +#endif + +#ifdef PNG_WRITE_sPLT_SUPPORTED +/* Write a sPLT chunk */ +void /* PRIVATE */ +png_write_sPLT(png_structrp png_ptr, png_const_sPLT_tp spalette) +{ + png_uint_32 name_len; + png_byte new_name[80]; + png_byte entrybuf[10]; + png_size_t entry_size = (spalette->depth == 8 ? 6 : 10); + png_size_t palette_size = entry_size * spalette->nentries; + png_sPLT_entryp ep; +#ifndef PNG_POINTER_INDEXING_SUPPORTED + int i; +#endif + + png_debug(1, "in png_write_sPLT"); + + name_len = png_check_keyword(png_ptr, spalette->name, new_name); + + if (name_len == 0) + png_error(png_ptr, "sPLT: invalid keyword"); + + /* Make sure we include the NULL after the name */ + png_write_chunk_header(png_ptr, png_sPLT, + (png_uint_32)(name_len + 2 + palette_size)); + + png_write_chunk_data(png_ptr, (png_bytep)new_name, + (png_size_t)(name_len + 1)); + + png_write_chunk_data(png_ptr, &spalette->depth, (png_size_t)1); + + /* Loop through each palette entry, writing appropriately */ +#ifdef PNG_POINTER_INDEXING_SUPPORTED + for (ep = spalette->entries; epentries + spalette->nentries; ep++) + { + if (spalette->depth == 8) + { + entrybuf[0] = (png_byte)ep->red; + entrybuf[1] = (png_byte)ep->green; + entrybuf[2] = (png_byte)ep->blue; + entrybuf[3] = (png_byte)ep->alpha; + png_save_uint_16(entrybuf + 4, ep->frequency); + } + + else + { + png_save_uint_16(entrybuf + 0, ep->red); + png_save_uint_16(entrybuf + 2, ep->green); + png_save_uint_16(entrybuf + 4, ep->blue); + png_save_uint_16(entrybuf + 6, ep->alpha); + png_save_uint_16(entrybuf + 8, ep->frequency); + } + + png_write_chunk_data(png_ptr, entrybuf, entry_size); + } +#else + ep=spalette->entries; + for (i = 0; i>spalette->nentries; i++) + { + if (spalette->depth == 8) + { + entrybuf[0] = (png_byte)ep[i].red; + entrybuf[1] = (png_byte)ep[i].green; + entrybuf[2] = (png_byte)ep[i].blue; + entrybuf[3] = (png_byte)ep[i].alpha; + png_save_uint_16(entrybuf + 4, ep[i].frequency); + } + + else + { + png_save_uint_16(entrybuf + 0, ep[i].red); + png_save_uint_16(entrybuf + 2, ep[i].green); + png_save_uint_16(entrybuf + 4, ep[i].blue); + png_save_uint_16(entrybuf + 6, ep[i].alpha); + png_save_uint_16(entrybuf + 8, ep[i].frequency); + } + + png_write_chunk_data(png_ptr, entrybuf, entry_size); + } +#endif + + png_write_chunk_end(png_ptr); +} +#endif + +#ifdef PNG_WRITE_sBIT_SUPPORTED +/* Write the sBIT chunk */ +void /* PRIVATE */ +png_write_sBIT(png_structrp png_ptr, png_const_color_8p sbit, int color_type) +{ + png_byte buf[4]; + png_size_t size; + + png_debug(1, "in png_write_sBIT"); + + /* Make sure we don't depend upon the order of PNG_COLOR_8 */ + if (color_type & PNG_COLOR_MASK_COLOR) + { + png_byte maxbits; + + maxbits = (png_byte)(color_type==PNG_COLOR_TYPE_PALETTE ? 8 : + png_ptr->usr_bit_depth); + + if (sbit->red == 0 || sbit->red > maxbits || + sbit->green == 0 || sbit->green > maxbits || + sbit->blue == 0 || sbit->blue > maxbits) + { + png_warning(png_ptr, "Invalid sBIT depth specified"); + return; + } + + buf[0] = sbit->red; + buf[1] = sbit->green; + buf[2] = sbit->blue; + size = 3; + } + + else + { + if (sbit->gray == 0 || sbit->gray > png_ptr->usr_bit_depth) + { + png_warning(png_ptr, "Invalid sBIT depth specified"); + return; + } + + buf[0] = sbit->gray; + size = 1; + } + + if (color_type & PNG_COLOR_MASK_ALPHA) + { + if (sbit->alpha == 0 || sbit->alpha > png_ptr->usr_bit_depth) + { + png_warning(png_ptr, "Invalid sBIT depth specified"); + return; + } + + buf[size++] = sbit->alpha; + } + + png_write_complete_chunk(png_ptr, png_sBIT, buf, size); +} +#endif + +#ifdef PNG_WRITE_cHRM_SUPPORTED +/* Write the cHRM chunk */ +void /* PRIVATE */ +png_write_cHRM_fixed(png_structrp png_ptr, const png_xy *xy) +{ + png_byte buf[32]; + + png_debug(1, "in png_write_cHRM"); + + /* Each value is saved in 1/100,000ths */ + png_save_int_32(buf, xy->whitex); + png_save_int_32(buf + 4, xy->whitey); + + png_save_int_32(buf + 8, xy->redx); + png_save_int_32(buf + 12, xy->redy); + + png_save_int_32(buf + 16, xy->greenx); + png_save_int_32(buf + 20, xy->greeny); + + png_save_int_32(buf + 24, xy->bluex); + png_save_int_32(buf + 28, xy->bluey); + + png_write_complete_chunk(png_ptr, png_cHRM, buf, 32); +} +#endif + +#ifdef PNG_WRITE_tRNS_SUPPORTED +/* Write the tRNS chunk */ +void /* PRIVATE */ +png_write_tRNS(png_structrp png_ptr, png_const_bytep trans_alpha, + png_const_color_16p tran, int num_trans, int color_type) +{ + png_byte buf[6]; + + png_debug(1, "in png_write_tRNS"); + + if (color_type == PNG_COLOR_TYPE_PALETTE) + { + if (num_trans <= 0 || num_trans > (int)png_ptr->num_palette) + { + png_app_warning(png_ptr, + "Invalid number of transparent colors specified"); + return; + } + + /* Write the chunk out as it is */ + png_write_complete_chunk(png_ptr, png_tRNS, trans_alpha, + (png_size_t)num_trans); + } + + else if (color_type == PNG_COLOR_TYPE_GRAY) + { + /* One 16 bit value */ + if (tran->gray >= (1 << png_ptr->bit_depth)) + { + png_app_warning(png_ptr, + "Ignoring attempt to write tRNS chunk out-of-range for bit_depth"); + + return; + } + + png_save_uint_16(buf, tran->gray); + png_write_complete_chunk(png_ptr, png_tRNS, buf, (png_size_t)2); + } + + else if (color_type == PNG_COLOR_TYPE_RGB) + { + /* Three 16 bit values */ + png_save_uint_16(buf, tran->red); + png_save_uint_16(buf + 2, tran->green); + png_save_uint_16(buf + 4, tran->blue); +#ifdef PNG_WRITE_16BIT_SUPPORTED + if (png_ptr->bit_depth == 8 && (buf[0] | buf[2] | buf[4])) +#else + if (buf[0] | buf[2] | buf[4]) +#endif + { + png_app_warning(png_ptr, + "Ignoring attempt to write 16-bit tRNS chunk when bit_depth is 8"); + return; + } + + png_write_complete_chunk(png_ptr, png_tRNS, buf, (png_size_t)6); + } + + else + { + png_app_warning(png_ptr, "Can't write tRNS with an alpha channel"); + } +} +#endif + +#ifdef PNG_WRITE_bKGD_SUPPORTED +/* Write the background chunk */ +void /* PRIVATE */ +png_write_bKGD(png_structrp png_ptr, png_const_color_16p back, int color_type) +{ + png_byte buf[6]; + + png_debug(1, "in png_write_bKGD"); + + if (color_type == PNG_COLOR_TYPE_PALETTE) + { + if ( +#ifdef PNG_MNG_FEATURES_SUPPORTED + (png_ptr->num_palette || + (!(png_ptr->mng_features_permitted & PNG_FLAG_MNG_EMPTY_PLTE))) && +#endif + back->index >= png_ptr->num_palette) + { + png_warning(png_ptr, "Invalid background palette index"); + return; + } + + buf[0] = back->index; + png_write_complete_chunk(png_ptr, png_bKGD, buf, (png_size_t)1); + } + + else if (color_type & PNG_COLOR_MASK_COLOR) + { + png_save_uint_16(buf, back->red); + png_save_uint_16(buf + 2, back->green); + png_save_uint_16(buf + 4, back->blue); +#ifdef PNG_WRITE_16BIT_SUPPORTED + if (png_ptr->bit_depth == 8 && (buf[0] | buf[2] | buf[4])) +#else + if (buf[0] | buf[2] | buf[4]) +#endif + { + png_warning(png_ptr, + "Ignoring attempt to write 16-bit bKGD chunk when bit_depth is 8"); + + return; + } + + png_write_complete_chunk(png_ptr, png_bKGD, buf, (png_size_t)6); + } + + else + { + if (back->gray >= (1 << png_ptr->bit_depth)) + { + png_warning(png_ptr, + "Ignoring attempt to write bKGD chunk out-of-range for bit_depth"); + + return; + } + + png_save_uint_16(buf, back->gray); + png_write_complete_chunk(png_ptr, png_bKGD, buf, (png_size_t)2); + } +} +#endif + +#ifdef PNG_WRITE_hIST_SUPPORTED +/* Write the histogram */ +void /* PRIVATE */ +png_write_hIST(png_structrp png_ptr, png_const_uint_16p hist, int num_hist) +{ + int i; + png_byte buf[3]; + + png_debug(1, "in png_write_hIST"); + + if (num_hist > (int)png_ptr->num_palette) + { + png_debug2(3, "num_hist = %d, num_palette = %d", num_hist, + png_ptr->num_palette); + + png_warning(png_ptr, "Invalid number of histogram entries specified"); + return; + } + + png_write_chunk_header(png_ptr, png_hIST, (png_uint_32)(num_hist * 2)); + + for (i = 0; i < num_hist; i++) + { + png_save_uint_16(buf, hist[i]); + png_write_chunk_data(png_ptr, buf, (png_size_t)2); + } + + png_write_chunk_end(png_ptr); +} +#endif + +#ifdef PNG_WRITE_tEXt_SUPPORTED +/* Write a tEXt chunk */ +void /* PRIVATE */ +png_write_tEXt(png_structrp png_ptr, png_const_charp key, png_const_charp text, + png_size_t text_len) +{ + png_uint_32 key_len; + png_byte new_key[80]; + + png_debug(1, "in png_write_tEXt"); + + key_len = png_check_keyword(png_ptr, key, new_key); + + if (key_len == 0) + png_error(png_ptr, "tEXt: invalid keyword"); + + if (text == NULL || *text == '\0') + text_len = 0; + + else + text_len = strlen(text); + + if (text_len > PNG_UINT_31_MAX - (key_len+1)) + png_error(png_ptr, "tEXt: text too long"); + + /* Make sure we include the 0 after the key */ + png_write_chunk_header(png_ptr, png_tEXt, + (png_uint_32)/*checked above*/(key_len + text_len + 1)); + /* + * We leave it to the application to meet PNG-1.0 requirements on the + * contents of the text. PNG-1.0 through PNG-1.2 discourage the use of + * any non-Latin-1 characters except for NEWLINE. ISO PNG will forbid them. + * The NUL character is forbidden by PNG-1.0 through PNG-1.2 and ISO PNG. + */ + png_write_chunk_data(png_ptr, new_key, key_len + 1); + + if (text_len) + png_write_chunk_data(png_ptr, (png_const_bytep)text, text_len); + + png_write_chunk_end(png_ptr); +} +#endif + +#ifdef PNG_WRITE_zTXt_SUPPORTED +/* Write a compressed text chunk */ +void /* PRIVATE */ +png_write_zTXt(png_structrp png_ptr, png_const_charp key, png_const_charp text, + png_size_t text_len, int compression) +{ + png_uint_32 key_len; + png_byte new_key[81]; + compression_state comp; + + png_debug(1, "in png_write_zTXt"); + PNG_UNUSED(text_len) /* Always use strlen */ + + if (compression == PNG_TEXT_COMPRESSION_NONE) + { + png_write_tEXt(png_ptr, key, text, 0); + return; + } + + if (compression != PNG_TEXT_COMPRESSION_zTXt) + png_error(png_ptr, "zTXt: invalid compression type"); + + key_len = png_check_keyword(png_ptr, key, new_key); + + if (key_len == 0) + png_error(png_ptr, "zTXt: invalid keyword"); + + /* Add the compression method and 1 for the keyword separator. */ + new_key[++key_len] = PNG_COMPRESSION_TYPE_BASE; + ++key_len; + + /* Compute the compressed data; do it now for the length */ + png_text_compress_init(&comp, (png_const_bytep)text, + text == NULL ? 0 : strlen(text)); + + if (png_text_compress(png_ptr, png_zTXt, &comp, key_len) != Z_OK) + png_error(png_ptr, png_ptr->zstream.msg); + + /* Write start of chunk */ + png_write_chunk_header(png_ptr, png_zTXt, key_len + comp.output_len); + + /* Write key */ + png_write_chunk_data(png_ptr, new_key, key_len); + + /* Write the compressed data */ + png_write_compressed_data_out(png_ptr, &comp); + + /* Close the chunk */ + png_write_chunk_end(png_ptr); +} +#endif + +#ifdef PNG_WRITE_iTXt_SUPPORTED +/* Write an iTXt chunk */ +void /* PRIVATE */ +png_write_iTXt(png_structrp png_ptr, int compression, png_const_charp key, + png_const_charp lang, png_const_charp lang_key, png_const_charp text) +{ + png_uint_32 key_len, prefix_len; + png_size_t lang_len, lang_key_len; + png_byte new_key[82]; + compression_state comp; + + png_debug(1, "in png_write_iTXt"); + + key_len = png_check_keyword(png_ptr, key, new_key); + + if (key_len == 0) + png_error(png_ptr, "iTXt: invalid keyword"); + + /* Set the compression flag */ + switch (compression) + { + case PNG_ITXT_COMPRESSION_NONE: + case PNG_TEXT_COMPRESSION_NONE: + compression = new_key[++key_len] = 0; /* no compression */ + break; + + case PNG_TEXT_COMPRESSION_zTXt: + case PNG_ITXT_COMPRESSION_zTXt: + compression = new_key[++key_len] = 1; /* compressed */ + break; + + default: + png_error(png_ptr, "iTXt: invalid compression"); + } + + new_key[++key_len] = PNG_COMPRESSION_TYPE_BASE; + ++key_len; /* for the keywod separator */ + + /* We leave it to the application to meet PNG-1.0 requirements on the + * contents of the text. PNG-1.0 through PNG-1.2 discourage the use of + * any non-Latin-1 characters except for NEWLINE. ISO PNG, however, + * specifies that the text is UTF-8 and this really doesn't require any + * checking. + * + * The NUL character is forbidden by PNG-1.0 through PNG-1.2 and ISO PNG. + * + * TODO: validate the language tag correctly (see the spec.) + */ + if (lang == NULL) lang = ""; /* empty language is valid */ + lang_len = strlen(lang)+1; + if (lang_key == NULL) lang_key = ""; /* may be empty */ + lang_key_len = strlen(lang_key)+1; + if (text == NULL) text = ""; /* may be empty */ + + prefix_len = key_len; + if (lang_len > PNG_UINT_31_MAX-prefix_len) + prefix_len = PNG_UINT_31_MAX; + else + prefix_len = (png_uint_32)(prefix_len + lang_len); + + if (lang_key_len > PNG_UINT_31_MAX-prefix_len) + prefix_len = PNG_UINT_31_MAX; + else + prefix_len = (png_uint_32)(prefix_len + lang_key_len); + + png_text_compress_init(&comp, (png_const_bytep)text, strlen(text)); + + if (compression) + { + if (png_text_compress(png_ptr, png_iTXt, &comp, prefix_len) != Z_OK) + png_error(png_ptr, png_ptr->zstream.msg); + } + + else + { + if (comp.input_len > PNG_UINT_31_MAX-prefix_len) + png_error(png_ptr, "iTXt: uncompressed text too long"); + + /* So the string will fit in a chunk: */ + comp.output_len = (png_uint_32)/*SAFE*/comp.input_len; + } + + png_write_chunk_header(png_ptr, png_iTXt, comp.output_len + prefix_len); + + png_write_chunk_data(png_ptr, new_key, key_len); + + png_write_chunk_data(png_ptr, (png_const_bytep)lang, lang_len); + + png_write_chunk_data(png_ptr, (png_const_bytep)lang_key, lang_key_len); + + if (compression) + png_write_compressed_data_out(png_ptr, &comp); + + else + png_write_chunk_data(png_ptr, (png_const_bytep)text, comp.input_len); + + png_write_chunk_end(png_ptr); +} +#endif + +#ifdef PNG_WRITE_oFFs_SUPPORTED +/* Write the oFFs chunk */ +void /* PRIVATE */ +png_write_oFFs(png_structrp png_ptr, png_int_32 x_offset, png_int_32 y_offset, + int unit_type) +{ + png_byte buf[9]; + + png_debug(1, "in png_write_oFFs"); + + if (unit_type >= PNG_OFFSET_LAST) + png_warning(png_ptr, "Unrecognized unit type for oFFs chunk"); + + png_save_int_32(buf, x_offset); + png_save_int_32(buf + 4, y_offset); + buf[8] = (png_byte)unit_type; + + png_write_complete_chunk(png_ptr, png_oFFs, buf, (png_size_t)9); +} +#endif +#ifdef PNG_WRITE_pCAL_SUPPORTED +/* Write the pCAL chunk (described in the PNG extensions document) */ +void /* PRIVATE */ +png_write_pCAL(png_structrp png_ptr, png_charp purpose, png_int_32 X0, + png_int_32 X1, int type, int nparams, png_const_charp units, + png_charpp params) +{ + png_uint_32 purpose_len; + png_size_t units_len, total_len; + png_size_tp params_len; + png_byte buf[10]; + png_byte new_purpose[80]; + int i; + + png_debug1(1, "in png_write_pCAL (%d parameters)", nparams); + + if (type >= PNG_EQUATION_LAST) + png_error(png_ptr, "Unrecognized equation type for pCAL chunk"); + + purpose_len = png_check_keyword(png_ptr, purpose, new_purpose); + + if (purpose_len == 0) + png_error(png_ptr, "pCAL: invalid keyword"); + + ++purpose_len; /* terminator */ + + png_debug1(3, "pCAL purpose length = %d", (int)purpose_len); + units_len = strlen(units) + (nparams == 0 ? 0 : 1); + png_debug1(3, "pCAL units length = %d", (int)units_len); + total_len = purpose_len + units_len + 10; + + params_len = (png_size_tp)png_malloc(png_ptr, + (png_alloc_size_t)(nparams * (sizeof (png_size_t)))); + + /* Find the length of each parameter, making sure we don't count the + * null terminator for the last parameter. + */ + for (i = 0; i < nparams; i++) + { + params_len[i] = strlen(params[i]) + (i == nparams - 1 ? 0 : 1); + png_debug2(3, "pCAL parameter %d length = %lu", i, + (unsigned long)params_len[i]); + total_len += params_len[i]; + } + + png_debug1(3, "pCAL total length = %d", (int)total_len); + png_write_chunk_header(png_ptr, png_pCAL, (png_uint_32)total_len); + png_write_chunk_data(png_ptr, new_purpose, purpose_len); + png_save_int_32(buf, X0); + png_save_int_32(buf + 4, X1); + buf[8] = (png_byte)type; + buf[9] = (png_byte)nparams; + png_write_chunk_data(png_ptr, buf, (png_size_t)10); + png_write_chunk_data(png_ptr, (png_const_bytep)units, (png_size_t)units_len); + + for (i = 0; i < nparams; i++) + { + png_write_chunk_data(png_ptr, (png_const_bytep)params[i], params_len[i]); + } + + png_free(png_ptr, params_len); + png_write_chunk_end(png_ptr); +} +#endif + +#ifdef PNG_WRITE_sCAL_SUPPORTED +/* Write the sCAL chunk */ +void /* PRIVATE */ +png_write_sCAL_s(png_structrp png_ptr, int unit, png_const_charp width, + png_const_charp height) +{ + png_byte buf[64]; + png_size_t wlen, hlen, total_len; + + png_debug(1, "in png_write_sCAL_s"); + + wlen = strlen(width); + hlen = strlen(height); + total_len = wlen + hlen + 2; + + if (total_len > 64) + { + png_warning(png_ptr, "Can't write sCAL (buffer too small)"); + return; + } + + buf[0] = (png_byte)unit; + memcpy(buf + 1, width, wlen + 1); /* Append the '\0' here */ + memcpy(buf + wlen + 2, height, hlen); /* Do NOT append the '\0' here */ + + png_debug1(3, "sCAL total length = %u", (unsigned int)total_len); + png_write_complete_chunk(png_ptr, png_sCAL, buf, total_len); +} +#endif + +#ifdef PNG_WRITE_pHYs_SUPPORTED +/* Write the pHYs chunk */ +void /* PRIVATE */ +png_write_pHYs(png_structrp png_ptr, png_uint_32 x_pixels_per_unit, + png_uint_32 y_pixels_per_unit, + int unit_type) +{ + png_byte buf[9]; + + png_debug(1, "in png_write_pHYs"); + + if (unit_type >= PNG_RESOLUTION_LAST) + png_warning(png_ptr, "Unrecognized unit type for pHYs chunk"); + + png_save_uint_32(buf, x_pixels_per_unit); + png_save_uint_32(buf + 4, y_pixels_per_unit); + buf[8] = (png_byte)unit_type; + + png_write_complete_chunk(png_ptr, png_pHYs, buf, (png_size_t)9); +} +#endif + +#ifdef PNG_WRITE_tIME_SUPPORTED +/* Write the tIME chunk. Use either png_convert_from_struct_tm() + * or png_convert_from_time_t(), or fill in the structure yourself. + */ +void /* PRIVATE */ +png_write_tIME(png_structrp png_ptr, png_const_timep mod_time) +{ + png_byte buf[7]; + + png_debug(1, "in png_write_tIME"); + + if (mod_time->month > 12 || mod_time->month < 1 || + mod_time->day > 31 || mod_time->day < 1 || + mod_time->hour > 23 || mod_time->second > 60) + { + png_warning(png_ptr, "Invalid time specified for tIME chunk"); + return; + } + + png_save_uint_16(buf, mod_time->year); + buf[2] = mod_time->month; + buf[3] = mod_time->day; + buf[4] = mod_time->hour; + buf[5] = mod_time->minute; + buf[6] = mod_time->second; + + png_write_complete_chunk(png_ptr, png_tIME, buf, (png_size_t)7); +} +#endif + +/* Initializes the row writing capability of libpng */ +void /* PRIVATE */ +png_write_start_row(png_structrp png_ptr) +{ +#ifdef PNG_WRITE_INTERLACING_SUPPORTED + /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */ + + /* Start of interlace block */ + static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0}; + + /* Offset to next interlace block */ + static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1}; + + /* Start of interlace block in the y direction */ + static PNG_CONST png_byte png_pass_ystart[7] = {0, 0, 4, 0, 2, 0, 1}; + + /* Offset to next interlace block in the y direction */ + static PNG_CONST png_byte png_pass_yinc[7] = {8, 8, 8, 4, 4, 2, 2}; +#endif + + png_alloc_size_t buf_size; + int usr_pixel_depth; + + png_debug(1, "in png_write_start_row"); + + usr_pixel_depth = png_ptr->usr_channels * png_ptr->usr_bit_depth; + buf_size = PNG_ROWBYTES(usr_pixel_depth, png_ptr->width) + 1; + + /* 1.5.6: added to allow checking in the row write code. */ + png_ptr->transformed_pixel_depth = png_ptr->pixel_depth; + png_ptr->maximum_pixel_depth = (png_byte)usr_pixel_depth; + + /* Set up row buffer */ + png_ptr->row_buf = (png_bytep)png_malloc(png_ptr, buf_size); + + png_ptr->row_buf[0] = PNG_FILTER_VALUE_NONE; + +#ifdef PNG_WRITE_FILTER_SUPPORTED + /* Set up filtering buffer, if using this filter */ + if (png_ptr->do_filter & PNG_FILTER_SUB) + { + png_ptr->sub_row = (png_bytep)png_malloc(png_ptr, png_ptr->rowbytes + 1); + + png_ptr->sub_row[0] = PNG_FILTER_VALUE_SUB; + } + + /* We only need to keep the previous row if we are using one of these. */ + if (png_ptr->do_filter & (PNG_FILTER_AVG | PNG_FILTER_UP | PNG_FILTER_PAETH)) + { + /* Set up previous row buffer */ + png_ptr->prev_row = (png_bytep)png_calloc(png_ptr, buf_size); + + if (png_ptr->do_filter & PNG_FILTER_UP) + { + png_ptr->up_row = (png_bytep)png_malloc(png_ptr, + png_ptr->rowbytes + 1); + + png_ptr->up_row[0] = PNG_FILTER_VALUE_UP; + } + + if (png_ptr->do_filter & PNG_FILTER_AVG) + { + png_ptr->avg_row = (png_bytep)png_malloc(png_ptr, + png_ptr->rowbytes + 1); + + png_ptr->avg_row[0] = PNG_FILTER_VALUE_AVG; + } + + if (png_ptr->do_filter & PNG_FILTER_PAETH) + { + png_ptr->paeth_row = (png_bytep)png_malloc(png_ptr, + png_ptr->rowbytes + 1); + + png_ptr->paeth_row[0] = PNG_FILTER_VALUE_PAETH; + } + } +#endif /* PNG_WRITE_FILTER_SUPPORTED */ + +#ifdef PNG_WRITE_INTERLACING_SUPPORTED + /* If interlaced, we need to set up width and height of pass */ + if (png_ptr->interlaced) + { + if (!(png_ptr->transformations & PNG_INTERLACE)) + { + png_ptr->num_rows = (png_ptr->height + png_pass_yinc[0] - 1 - + png_pass_ystart[0]) / png_pass_yinc[0]; + + png_ptr->usr_width = (png_ptr->width + png_pass_inc[0] - 1 - + png_pass_start[0]) / png_pass_inc[0]; + } + + else + { + png_ptr->num_rows = png_ptr->height; + png_ptr->usr_width = png_ptr->width; + } + } + + else +#endif + { + png_ptr->num_rows = png_ptr->height; + png_ptr->usr_width = png_ptr->width; + } +} + +/* Internal use only. Called when finished processing a row of data. */ +void /* PRIVATE */ +png_write_finish_row(png_structrp png_ptr) +{ +#ifdef PNG_WRITE_INTERLACING_SUPPORTED + /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */ + + /* Start of interlace block */ + static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0}; + + /* Offset to next interlace block */ + static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1}; + + /* Start of interlace block in the y direction */ + static PNG_CONST png_byte png_pass_ystart[7] = {0, 0, 4, 0, 2, 0, 1}; + + /* Offset to next interlace block in the y direction */ + static PNG_CONST png_byte png_pass_yinc[7] = {8, 8, 8, 4, 4, 2, 2}; +#endif + + png_debug(1, "in png_write_finish_row"); + + /* Next row */ + png_ptr->row_number++; + + /* See if we are done */ + if (png_ptr->row_number < png_ptr->num_rows) + return; + +#ifdef PNG_WRITE_INTERLACING_SUPPORTED + /* If interlaced, go to next pass */ + if (png_ptr->interlaced) + { + png_ptr->row_number = 0; + if (png_ptr->transformations & PNG_INTERLACE) + { + png_ptr->pass++; + } + + else + { + /* Loop until we find a non-zero width or height pass */ + do + { + png_ptr->pass++; + + if (png_ptr->pass >= 7) + break; + + png_ptr->usr_width = (png_ptr->width + + png_pass_inc[png_ptr->pass] - 1 - + png_pass_start[png_ptr->pass]) / + png_pass_inc[png_ptr->pass]; + + png_ptr->num_rows = (png_ptr->height + + png_pass_yinc[png_ptr->pass] - 1 - + png_pass_ystart[png_ptr->pass]) / + png_pass_yinc[png_ptr->pass]; + + if (png_ptr->transformations & PNG_INTERLACE) + break; + + } while (png_ptr->usr_width == 0 || png_ptr->num_rows == 0); + + } + + /* Reset the row above the image for the next pass */ + if (png_ptr->pass < 7) + { + if (png_ptr->prev_row != NULL) + memset(png_ptr->prev_row, 0, + (png_size_t)(PNG_ROWBYTES(png_ptr->usr_channels* + png_ptr->usr_bit_depth, png_ptr->width)) + 1); + + return; + } + } +#endif + + /* If we get here, we've just written the last row, so we need + to flush the compressor */ + png_compress_IDAT(png_ptr, NULL, 0, Z_FINISH); +} + +#ifdef PNG_WRITE_INTERLACING_SUPPORTED +/* Pick out the correct pixels for the interlace pass. + * The basic idea here is to go through the row with a source + * pointer and a destination pointer (sp and dp), and copy the + * correct pixels for the pass. As the row gets compacted, + * sp will always be >= dp, so we should never overwrite anything. + * See the default: case for the easiest code to understand. + */ +void /* PRIVATE */ +png_do_write_interlace(png_row_infop row_info, png_bytep row, int pass) +{ + /* Arrays to facilitate easy interlacing - use pass (0 - 6) as index */ + + /* Start of interlace block */ + static PNG_CONST png_byte png_pass_start[7] = {0, 4, 0, 2, 0, 1, 0}; + + /* Offset to next interlace block */ + static PNG_CONST png_byte png_pass_inc[7] = {8, 8, 4, 4, 2, 2, 1}; + + png_debug(1, "in png_do_write_interlace"); + + /* We don't have to do anything on the last pass (6) */ + if (pass < 6) + { + /* Each pixel depth is handled separately */ + switch (row_info->pixel_depth) + { + case 1: + { + png_bytep sp; + png_bytep dp; + int shift; + int d; + int value; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + dp = row; + d = 0; + shift = 7; + + for (i = png_pass_start[pass]; i < row_width; + i += png_pass_inc[pass]) + { + sp = row + (png_size_t)(i >> 3); + value = (int)(*sp >> (7 - (int)(i & 0x07))) & 0x01; + d |= (value << shift); + + if (shift == 0) + { + shift = 7; + *dp++ = (png_byte)d; + d = 0; + } + + else + shift--; + + } + if (shift != 7) + *dp = (png_byte)d; + + break; + } + + case 2: + { + png_bytep sp; + png_bytep dp; + int shift; + int d; + int value; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + dp = row; + shift = 6; + d = 0; + + for (i = png_pass_start[pass]; i < row_width; + i += png_pass_inc[pass]) + { + sp = row + (png_size_t)(i >> 2); + value = (*sp >> ((3 - (int)(i & 0x03)) << 1)) & 0x03; + d |= (value << shift); + + if (shift == 0) + { + shift = 6; + *dp++ = (png_byte)d; + d = 0; + } + + else + shift -= 2; + } + if (shift != 6) + *dp = (png_byte)d; + + break; + } + + case 4: + { + png_bytep sp; + png_bytep dp; + int shift; + int d; + int value; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + + dp = row; + shift = 4; + d = 0; + for (i = png_pass_start[pass]; i < row_width; + i += png_pass_inc[pass]) + { + sp = row + (png_size_t)(i >> 1); + value = (*sp >> ((1 - (int)(i & 0x01)) << 2)) & 0x0f; + d |= (value << shift); + + if (shift == 0) + { + shift = 4; + *dp++ = (png_byte)d; + d = 0; + } + + else + shift -= 4; + } + if (shift != 4) + *dp = (png_byte)d; + + break; + } + + default: + { + png_bytep sp; + png_bytep dp; + png_uint_32 i; + png_uint_32 row_width = row_info->width; + png_size_t pixel_bytes; + + /* Start at the beginning */ + dp = row; + + /* Find out how many bytes each pixel takes up */ + pixel_bytes = (row_info->pixel_depth >> 3); + + /* Loop through the row, only looking at the pixels that matter */ + for (i = png_pass_start[pass]; i < row_width; + i += png_pass_inc[pass]) + { + /* Find out where the original pixel is */ + sp = row + (png_size_t)i * pixel_bytes; + + /* Move the pixel */ + if (dp != sp) + memcpy(dp, sp, pixel_bytes); + + /* Next pixel */ + dp += pixel_bytes; + } + break; + } + } + /* Set new row width */ + row_info->width = (row_info->width + + png_pass_inc[pass] - 1 - + png_pass_start[pass]) / + png_pass_inc[pass]; + + row_info->rowbytes = PNG_ROWBYTES(row_info->pixel_depth, + row_info->width); + } +} +#endif + +/* This filters the row, chooses which filter to use, if it has not already + * been specified by the application, and then writes the row out with the + * chosen filter. + */ +static void png_write_filtered_row(png_structrp png_ptr, png_bytep filtered_row, + png_size_t row_bytes); + +#define PNG_MAXSUM (((png_uint_32)(-1)) >> 1) +#define PNG_HISHIFT 10 +#define PNG_LOMASK ((png_uint_32)0xffffL) +#define PNG_HIMASK ((png_uint_32)(~PNG_LOMASK >> PNG_HISHIFT)) +void /* PRIVATE */ +png_write_find_filter(png_structrp png_ptr, png_row_infop row_info) +{ + png_bytep best_row; +#ifdef PNG_WRITE_FILTER_SUPPORTED + png_bytep prev_row, row_buf; + png_uint_32 mins, bpp; + png_byte filter_to_do = png_ptr->do_filter; + png_size_t row_bytes = row_info->rowbytes; +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + int num_p_filters = png_ptr->num_prev_filters; +#endif + + png_debug(1, "in png_write_find_filter"); + +#ifndef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + if (png_ptr->row_number == 0 && filter_to_do == PNG_ALL_FILTERS) + { + /* These will never be selected so we need not test them. */ + filter_to_do &= ~(PNG_FILTER_UP | PNG_FILTER_PAETH); + } +#endif + + /* Find out how many bytes offset each pixel is */ + bpp = (row_info->pixel_depth + 7) >> 3; + + prev_row = png_ptr->prev_row; +#endif + best_row = png_ptr->row_buf; +#ifdef PNG_WRITE_FILTER_SUPPORTED + row_buf = best_row; + mins = PNG_MAXSUM; + + /* The prediction method we use is to find which method provides the + * smallest value when summing the absolute values of the distances + * from zero, using anything >= 128 as negative numbers. This is known + * as the "minimum sum of absolute differences" heuristic. Other + * heuristics are the "weighted minimum sum of absolute differences" + * (experimental and can in theory improve compression), and the "zlib + * predictive" method (not implemented yet), which does test compressions + * of lines using different filter methods, and then chooses the + * (series of) filter(s) that give minimum compressed data size (VERY + * computationally expensive). + * + * GRR 980525: consider also + * + * (1) minimum sum of absolute differences from running average (i.e., + * keep running sum of non-absolute differences & count of bytes) + * [track dispersion, too? restart average if dispersion too large?] + * + * (1b) minimum sum of absolute differences from sliding average, probably + * with window size <= deflate window (usually 32K) + * + * (2) minimum sum of squared differences from zero or running average + * (i.e., ~ root-mean-square approach) + */ + + + /* We don't need to test the 'no filter' case if this is the only filter + * that has been chosen, as it doesn't actually do anything to the data. + */ + if ((filter_to_do & PNG_FILTER_NONE) && filter_to_do != PNG_FILTER_NONE) + { + png_bytep rp; + png_uint_32 sum = 0; + png_size_t i; + int v; + + for (i = 0, rp = row_buf + 1; i < row_bytes; i++, rp++) + { + v = *rp; + sum += (v < 128) ? v : 256 - v; + } + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + png_uint_32 sumhi, sumlo; + int j; + sumlo = sum & PNG_LOMASK; + sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK; /* Gives us some footroom */ + + /* Reduce the sum if we match any of the previous rows */ + for (j = 0; j < num_p_filters; j++) + { + if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_NONE) + { + sumlo = (sumlo * png_ptr->filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + + sumhi = (sumhi * png_ptr->filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + } + } + + /* Factor in the cost of this filter (this is here for completeness, + * but it makes no sense to have a "cost" for the NONE filter, as + * it has the minimum possible computational cost - none). + */ + sumlo = (sumlo * png_ptr->filter_costs[PNG_FILTER_VALUE_NONE]) >> + PNG_COST_SHIFT; + + sumhi = (sumhi * png_ptr->filter_costs[PNG_FILTER_VALUE_NONE]) >> + PNG_COST_SHIFT; + + if (sumhi > PNG_HIMASK) + sum = PNG_MAXSUM; + + else + sum = (sumhi << PNG_HISHIFT) + sumlo; + } +#endif + mins = sum; + } + + /* Sub filter */ + if (filter_to_do == PNG_FILTER_SUB) + /* It's the only filter so no testing is needed */ + { + png_bytep rp, lp, dp; + png_size_t i; + + for (i = 0, rp = row_buf + 1, dp = png_ptr->sub_row + 1; i < bpp; + i++, rp++, dp++) + { + *dp = *rp; + } + + for (lp = row_buf + 1; i < row_bytes; + i++, rp++, lp++, dp++) + { + *dp = (png_byte)(((int)*rp - (int)*lp) & 0xff); + } + + best_row = png_ptr->sub_row; + } + + else if (filter_to_do & PNG_FILTER_SUB) + { + png_bytep rp, dp, lp; + png_uint_32 sum = 0, lmins = mins; + png_size_t i; + int v; + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + /* We temporarily increase the "minimum sum" by the factor we + * would reduce the sum of this filter, so that we can do the + * early exit comparison without scaling the sum each time. + */ + if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int j; + png_uint_32 lmhi, lmlo; + lmlo = lmins & PNG_LOMASK; + lmhi = (lmins >> PNG_HISHIFT) & PNG_HIMASK; + + for (j = 0; j < num_p_filters; j++) + { + if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_SUB) + { + lmlo = (lmlo * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + + lmhi = (lmhi * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + } + } + + lmlo = (lmlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_SUB]) >> + PNG_COST_SHIFT; + + lmhi = (lmhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_SUB]) >> + PNG_COST_SHIFT; + + if (lmhi > PNG_HIMASK) + lmins = PNG_MAXSUM; + + else + lmins = (lmhi << PNG_HISHIFT) + lmlo; + } +#endif + + for (i = 0, rp = row_buf + 1, dp = png_ptr->sub_row + 1; i < bpp; + i++, rp++, dp++) + { + v = *dp = *rp; + + sum += (v < 128) ? v : 256 - v; + } + + for (lp = row_buf + 1; i < row_bytes; + i++, rp++, lp++, dp++) + { + v = *dp = (png_byte)(((int)*rp - (int)*lp) & 0xff); + + sum += (v < 128) ? v : 256 - v; + + if (sum > lmins) /* We are already worse, don't continue. */ + break; + } + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int j; + png_uint_32 sumhi, sumlo; + sumlo = sum & PNG_LOMASK; + sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK; + + for (j = 0; j < num_p_filters; j++) + { + if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_SUB) + { + sumlo = (sumlo * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + + sumhi = (sumhi * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + } + } + + sumlo = (sumlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_SUB]) >> + PNG_COST_SHIFT; + + sumhi = (sumhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_SUB]) >> + PNG_COST_SHIFT; + + if (sumhi > PNG_HIMASK) + sum = PNG_MAXSUM; + + else + sum = (sumhi << PNG_HISHIFT) + sumlo; + } +#endif + + if (sum < mins) + { + mins = sum; + best_row = png_ptr->sub_row; + } + } + + /* Up filter */ + if (filter_to_do == PNG_FILTER_UP) + { + png_bytep rp, dp, pp; + png_size_t i; + + for (i = 0, rp = row_buf + 1, dp = png_ptr->up_row + 1, + pp = prev_row + 1; i < row_bytes; + i++, rp++, pp++, dp++) + { + *dp = (png_byte)(((int)*rp - (int)*pp) & 0xff); + } + + best_row = png_ptr->up_row; + } + + else if (filter_to_do & PNG_FILTER_UP) + { + png_bytep rp, dp, pp; + png_uint_32 sum = 0, lmins = mins; + png_size_t i; + int v; + + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int j; + png_uint_32 lmhi, lmlo; + lmlo = lmins & PNG_LOMASK; + lmhi = (lmins >> PNG_HISHIFT) & PNG_HIMASK; + + for (j = 0; j < num_p_filters; j++) + { + if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_UP) + { + lmlo = (lmlo * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + + lmhi = (lmhi * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + } + } + + lmlo = (lmlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_UP]) >> + PNG_COST_SHIFT; + + lmhi = (lmhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_UP]) >> + PNG_COST_SHIFT; + + if (lmhi > PNG_HIMASK) + lmins = PNG_MAXSUM; + + else + lmins = (lmhi << PNG_HISHIFT) + lmlo; + } +#endif + + for (i = 0, rp = row_buf + 1, dp = png_ptr->up_row + 1, + pp = prev_row + 1; i < row_bytes; i++) + { + v = *dp++ = (png_byte)(((int)*rp++ - (int)*pp++) & 0xff); + + sum += (v < 128) ? v : 256 - v; + + if (sum > lmins) /* We are already worse, don't continue. */ + break; + } + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int j; + png_uint_32 sumhi, sumlo; + sumlo = sum & PNG_LOMASK; + sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK; + + for (j = 0; j < num_p_filters; j++) + { + if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_UP) + { + sumlo = (sumlo * png_ptr->filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + + sumhi = (sumhi * png_ptr->filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + } + } + + sumlo = (sumlo * png_ptr->filter_costs[PNG_FILTER_VALUE_UP]) >> + PNG_COST_SHIFT; + + sumhi = (sumhi * png_ptr->filter_costs[PNG_FILTER_VALUE_UP]) >> + PNG_COST_SHIFT; + + if (sumhi > PNG_HIMASK) + sum = PNG_MAXSUM; + + else + sum = (sumhi << PNG_HISHIFT) + sumlo; + } +#endif + + if (sum < mins) + { + mins = sum; + best_row = png_ptr->up_row; + } + } + + /* Avg filter */ + if (filter_to_do == PNG_FILTER_AVG) + { + png_bytep rp, dp, pp, lp; + png_uint_32 i; + + for (i = 0, rp = row_buf + 1, dp = png_ptr->avg_row + 1, + pp = prev_row + 1; i < bpp; i++) + { + *dp++ = (png_byte)(((int)*rp++ - ((int)*pp++ / 2)) & 0xff); + } + + for (lp = row_buf + 1; i < row_bytes; i++) + { + *dp++ = (png_byte)(((int)*rp++ - (((int)*pp++ + (int)*lp++) / 2)) + & 0xff); + } + best_row = png_ptr->avg_row; + } + + else if (filter_to_do & PNG_FILTER_AVG) + { + png_bytep rp, dp, pp, lp; + png_uint_32 sum = 0, lmins = mins; + png_size_t i; + int v; + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int j; + png_uint_32 lmhi, lmlo; + lmlo = lmins & PNG_LOMASK; + lmhi = (lmins >> PNG_HISHIFT) & PNG_HIMASK; + + for (j = 0; j < num_p_filters; j++) + { + if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_AVG) + { + lmlo = (lmlo * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + + lmhi = (lmhi * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + } + } + + lmlo = (lmlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_AVG]) >> + PNG_COST_SHIFT; + + lmhi = (lmhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_AVG]) >> + PNG_COST_SHIFT; + + if (lmhi > PNG_HIMASK) + lmins = PNG_MAXSUM; + + else + lmins = (lmhi << PNG_HISHIFT) + lmlo; + } +#endif + + for (i = 0, rp = row_buf + 1, dp = png_ptr->avg_row + 1, + pp = prev_row + 1; i < bpp; i++) + { + v = *dp++ = (png_byte)(((int)*rp++ - ((int)*pp++ / 2)) & 0xff); + + sum += (v < 128) ? v : 256 - v; + } + + for (lp = row_buf + 1; i < row_bytes; i++) + { + v = *dp++ = + (png_byte)(((int)*rp++ - (((int)*pp++ + (int)*lp++) / 2)) & 0xff); + + sum += (v < 128) ? v : 256 - v; + + if (sum > lmins) /* We are already worse, don't continue. */ + break; + } + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int j; + png_uint_32 sumhi, sumlo; + sumlo = sum & PNG_LOMASK; + sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK; + + for (j = 0; j < num_p_filters; j++) + { + if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_NONE) + { + sumlo = (sumlo * png_ptr->filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + + sumhi = (sumhi * png_ptr->filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + } + } + + sumlo = (sumlo * png_ptr->filter_costs[PNG_FILTER_VALUE_AVG]) >> + PNG_COST_SHIFT; + + sumhi = (sumhi * png_ptr->filter_costs[PNG_FILTER_VALUE_AVG]) >> + PNG_COST_SHIFT; + + if (sumhi > PNG_HIMASK) + sum = PNG_MAXSUM; + + else + sum = (sumhi << PNG_HISHIFT) + sumlo; + } +#endif + + if (sum < mins) + { + mins = sum; + best_row = png_ptr->avg_row; + } + } + + /* Paeth filter */ + if (filter_to_do == PNG_FILTER_PAETH) + { + png_bytep rp, dp, pp, cp, lp; + png_size_t i; + + for (i = 0, rp = row_buf + 1, dp = png_ptr->paeth_row + 1, + pp = prev_row + 1; i < bpp; i++) + { + *dp++ = (png_byte)(((int)*rp++ - (int)*pp++) & 0xff); + } + + for (lp = row_buf + 1, cp = prev_row + 1; i < row_bytes; i++) + { + int a, b, c, pa, pb, pc, p; + + b = *pp++; + c = *cp++; + a = *lp++; + + p = b - c; + pc = a - c; + +#ifdef PNG_USE_ABS + pa = abs(p); + pb = abs(pc); + pc = abs(p + pc); +#else + pa = p < 0 ? -p : p; + pb = pc < 0 ? -pc : pc; + pc = (p + pc) < 0 ? -(p + pc) : p + pc; +#endif + + p = (pa <= pb && pa <=pc) ? a : (pb <= pc) ? b : c; + + *dp++ = (png_byte)(((int)*rp++ - p) & 0xff); + } + best_row = png_ptr->paeth_row; + } + + else if (filter_to_do & PNG_FILTER_PAETH) + { + png_bytep rp, dp, pp, cp, lp; + png_uint_32 sum = 0, lmins = mins; + png_size_t i; + int v; + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int j; + png_uint_32 lmhi, lmlo; + lmlo = lmins & PNG_LOMASK; + lmhi = (lmins >> PNG_HISHIFT) & PNG_HIMASK; + + for (j = 0; j < num_p_filters; j++) + { + if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_PAETH) + { + lmlo = (lmlo * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + + lmhi = (lmhi * png_ptr->inv_filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + } + } + + lmlo = (lmlo * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_PAETH]) >> + PNG_COST_SHIFT; + + lmhi = (lmhi * png_ptr->inv_filter_costs[PNG_FILTER_VALUE_PAETH]) >> + PNG_COST_SHIFT; + + if (lmhi > PNG_HIMASK) + lmins = PNG_MAXSUM; + + else + lmins = (lmhi << PNG_HISHIFT) + lmlo; + } +#endif + + for (i = 0, rp = row_buf + 1, dp = png_ptr->paeth_row + 1, + pp = prev_row + 1; i < bpp; i++) + { + v = *dp++ = (png_byte)(((int)*rp++ - (int)*pp++) & 0xff); + + sum += (v < 128) ? v : 256 - v; + } + + for (lp = row_buf + 1, cp = prev_row + 1; i < row_bytes; i++) + { + int a, b, c, pa, pb, pc, p; + + b = *pp++; + c = *cp++; + a = *lp++; + +#ifndef PNG_SLOW_PAETH + p = b - c; + pc = a - c; +#ifdef PNG_USE_ABS + pa = abs(p); + pb = abs(pc); + pc = abs(p + pc); +#else + pa = p < 0 ? -p : p; + pb = pc < 0 ? -pc : pc; + pc = (p + pc) < 0 ? -(p + pc) : p + pc; +#endif + p = (pa <= pb && pa <=pc) ? a : (pb <= pc) ? b : c; +#else /* PNG_SLOW_PAETH */ + p = a + b - c; + pa = abs(p - a); + pb = abs(p - b); + pc = abs(p - c); + + if (pa <= pb && pa <= pc) + p = a; + + else if (pb <= pc) + p = b; + + else + p = c; +#endif /* PNG_SLOW_PAETH */ + + v = *dp++ = (png_byte)(((int)*rp++ - p) & 0xff); + + sum += (v < 128) ? v : 256 - v; + + if (sum > lmins) /* We are already worse, don't continue. */ + break; + } + +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + if (png_ptr->heuristic_method == PNG_FILTER_HEURISTIC_WEIGHTED) + { + int j; + png_uint_32 sumhi, sumlo; + sumlo = sum & PNG_LOMASK; + sumhi = (sum >> PNG_HISHIFT) & PNG_HIMASK; + + for (j = 0; j < num_p_filters; j++) + { + if (png_ptr->prev_filters[j] == PNG_FILTER_VALUE_PAETH) + { + sumlo = (sumlo * png_ptr->filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + + sumhi = (sumhi * png_ptr->filter_weights[j]) >> + PNG_WEIGHT_SHIFT; + } + } + + sumlo = (sumlo * png_ptr->filter_costs[PNG_FILTER_VALUE_PAETH]) >> + PNG_COST_SHIFT; + + sumhi = (sumhi * png_ptr->filter_costs[PNG_FILTER_VALUE_PAETH]) >> + PNG_COST_SHIFT; + + if (sumhi > PNG_HIMASK) + sum = PNG_MAXSUM; + + else + sum = (sumhi << PNG_HISHIFT) + sumlo; + } +#endif + + if (sum < mins) + { + best_row = png_ptr->paeth_row; + } + } +#endif /* PNG_WRITE_FILTER_SUPPORTED */ + + /* Do the actual writing of the filtered row data from the chosen filter. */ + png_write_filtered_row(png_ptr, best_row, row_info->rowbytes+1); + +#ifdef PNG_WRITE_FILTER_SUPPORTED +#ifdef PNG_WRITE_WEIGHTED_FILTER_SUPPORTED + /* Save the type of filter we picked this time for future calculations */ + if (png_ptr->num_prev_filters > 0) + { + int j; + + for (j = 1; j < num_p_filters; j++) + { + png_ptr->prev_filters[j] = png_ptr->prev_filters[j - 1]; + } + + png_ptr->prev_filters[j] = best_row[0]; + } +#endif +#endif /* PNG_WRITE_FILTER_SUPPORTED */ +} + + +/* Do the actual writing of a previously filtered row. */ +static void +png_write_filtered_row(png_structrp png_ptr, png_bytep filtered_row, + png_size_t full_row_length/*includes filter byte*/) +{ + png_debug(1, "in png_write_filtered_row"); + + png_debug1(2, "filter = %d", filtered_row[0]); + + png_compress_IDAT(png_ptr, filtered_row, full_row_length, Z_NO_FLUSH); + + /* Swap the current and previous rows */ + if (png_ptr->prev_row != NULL) + { + png_bytep tptr; + + tptr = png_ptr->prev_row; + png_ptr->prev_row = png_ptr->row_buf; + png_ptr->row_buf = tptr; + } + + /* Finish row - updates counters and flushes zlib if last row */ + png_write_finish_row(png_ptr); + +#ifdef PNG_WRITE_FLUSH_SUPPORTED + png_ptr->flush_rows++; + + if (png_ptr->flush_dist > 0 && + png_ptr->flush_rows >= png_ptr->flush_dist) + { + png_write_flush(png_ptr); + } +#endif +} +#endif /* PNG_WRITE_SUPPORTED */ diff --git a/ml/dlib/dlib/external/pybind11/CMakeLists.txt b/ml/dlib/dlib/external/pybind11/CMakeLists.txt new file mode 100644 index 000000000..4280ba742 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/CMakeLists.txt @@ -0,0 +1,155 @@ +# CMakeLists.txt -- Build system for the pybind11 modules +# +# Copyright (c) 2015 Wenzel Jakob +# +# All rights reserved. Use of this source code is governed by a +# BSD-style license that can be found in the LICENSE file. + +cmake_minimum_required(VERSION 2.8.12) + +if (POLICY CMP0048) + # cmake warns if loaded from a min-3.0-required parent dir, so silence the warning: + cmake_policy(SET CMP0048 NEW) +endif() + +# CMake versions < 3.4.0 do not support try_compile/pthread checks without C as active language. +if(CMAKE_VERSION VERSION_LESS 3.4.0) + project(pybind11) +else() + project(pybind11 CXX) +endif() + +# Check if pybind11 is being used directly or via add_subdirectory +set(PYBIND11_MASTER_PROJECT OFF) +if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) + set(PYBIND11_MASTER_PROJECT ON) +endif() + +option(PYBIND11_INSTALL "Install pybind11 header files?" ${PYBIND11_MASTER_PROJECT}) +option(PYBIND11_TEST "Build pybind11 test suite?" ${PYBIND11_MASTER_PROJECT}) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/tools") + +include(pybind11Tools) + +# Cache variables so pybind11_add_module can be used in parent projects +set(PYBIND11_INCLUDE_DIR "${CMAKE_CURRENT_LIST_DIR}/include" CACHE INTERNAL "") +set(PYTHON_INCLUDE_DIRS ${PYTHON_INCLUDE_DIRS} CACHE INTERNAL "") +set(PYTHON_LIBRARIES ${PYTHON_LIBRARIES} CACHE INTERNAL "") +set(PYTHON_MODULE_PREFIX ${PYTHON_MODULE_PREFIX} CACHE INTERNAL "") +set(PYTHON_MODULE_EXTENSION ${PYTHON_MODULE_EXTENSION} CACHE INTERNAL "") + +# NB: when adding a header don't forget to also add it to setup.py +set(PYBIND11_HEADERS + include/pybind11/detail/class.h + include/pybind11/detail/common.h + include/pybind11/detail/descr.h + include/pybind11/detail/init.h + include/pybind11/detail/internals.h + include/pybind11/detail/typeid.h + include/pybind11/attr.h + include/pybind11/buffer_info.h + include/pybind11/cast.h + include/pybind11/chrono.h + include/pybind11/common.h + include/pybind11/complex.h + include/pybind11/options.h + include/pybind11/eigen.h + include/pybind11/embed.h + include/pybind11/eval.h + include/pybind11/functional.h + include/pybind11/numpy.h + include/pybind11/operators.h + include/pybind11/pybind11.h + include/pybind11/pytypes.h + include/pybind11/stl.h + include/pybind11/stl_bind.h +) +string(REPLACE "include/" "${CMAKE_CURRENT_SOURCE_DIR}/include/" + PYBIND11_HEADERS "${PYBIND11_HEADERS}") + +if (PYBIND11_TEST) + add_subdirectory(tests) +endif() + +include(GNUInstallDirs) +include(CMakePackageConfigHelpers) + +# extract project version from source +file(STRINGS "${PYBIND11_INCLUDE_DIR}/pybind11/detail/common.h" pybind11_version_defines + REGEX "#define PYBIND11_VERSION_(MAJOR|MINOR|PATCH) ") +foreach(ver ${pybind11_version_defines}) + if (ver MATCHES "#define PYBIND11_VERSION_(MAJOR|MINOR|PATCH) +([^ ]+)$") + set(PYBIND11_VERSION_${CMAKE_MATCH_1} "${CMAKE_MATCH_2}" CACHE INTERNAL "") + endif() +endforeach() +set(${PROJECT_NAME}_VERSION ${PYBIND11_VERSION_MAJOR}.${PYBIND11_VERSION_MINOR}.${PYBIND11_VERSION_PATCH}) +message(STATUS "pybind11 v${${PROJECT_NAME}_VERSION}") + +option (USE_PYTHON_INCLUDE_DIR "Install pybind11 headers in Python include directory instead of default installation prefix" OFF) +if (USE_PYTHON_INCLUDE_DIR) + file(RELATIVE_PATH CMAKE_INSTALL_INCLUDEDIR ${CMAKE_INSTALL_PREFIX} ${PYTHON_INCLUDE_DIRS}) +endif() + +if(NOT (CMAKE_VERSION VERSION_LESS 3.0)) # CMake >= 3.0 + # Build an interface library target: + add_library(pybind11 INTERFACE) + add_library(pybind11::pybind11 ALIAS pybind11) # to match exported target + target_include_directories(pybind11 INTERFACE $ + $ + $) + target_compile_options(pybind11 INTERFACE $) + + add_library(module INTERFACE) + add_library(pybind11::module ALIAS module) + if(NOT MSVC) + target_compile_options(module INTERFACE -fvisibility=hidden) + endif() + target_link_libraries(module INTERFACE pybind11::pybind11) + if(WIN32 OR CYGWIN) + target_link_libraries(module INTERFACE $) + elseif(APPLE) + target_link_libraries(module INTERFACE "-undefined dynamic_lookup") + endif() + + add_library(embed INTERFACE) + add_library(pybind11::embed ALIAS embed) + target_link_libraries(embed INTERFACE pybind11::pybind11 $) +endif() + +if (PYBIND11_INSTALL) + install(DIRECTORY ${PYBIND11_INCLUDE_DIR}/pybind11 DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + # GNUInstallDirs "DATADIR" wrong here; CMake search path wants "share". + set(PYBIND11_CMAKECONFIG_INSTALL_DIR "share/cmake/${PROJECT_NAME}" CACHE STRING "install path for pybind11Config.cmake") + + configure_package_config_file(tools/${PROJECT_NAME}Config.cmake.in + "${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config.cmake" + INSTALL_DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR}) + # Remove CMAKE_SIZEOF_VOID_P from ConfigVersion.cmake since the library does + # not depend on architecture specific settings or libraries. + set(_PYBIND11_CMAKE_SIZEOF_VOID_P ${CMAKE_SIZEOF_VOID_P}) + unset(CMAKE_SIZEOF_VOID_P) + write_basic_package_version_file(${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake + VERSION ${${PROJECT_NAME}_VERSION} + COMPATIBILITY AnyNewerVersion) + set(CMAKE_SIZEOF_VOID_P ${_PYBIND11_CMAKE_SIZEOF_VOID_P}) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake + tools/FindPythonLibsNew.cmake + tools/pybind11Tools.cmake + DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR}) + + if(NOT (CMAKE_VERSION VERSION_LESS 3.0)) + if(NOT PYBIND11_EXPORT_NAME) + set(PYBIND11_EXPORT_NAME "${PROJECT_NAME}Targets") + endif() + + install(TARGETS pybind11 module embed + EXPORT "${PYBIND11_EXPORT_NAME}") + if(PYBIND11_MASTER_PROJECT) + install(EXPORT "${PYBIND11_EXPORT_NAME}" + NAMESPACE "${PROJECT_NAME}::" + DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR}) + endif() + endif() +endif() diff --git a/ml/dlib/dlib/external/pybind11/CONTRIBUTING.md b/ml/dlib/dlib/external/pybind11/CONTRIBUTING.md new file mode 100644 index 000000000..375735f6c --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/CONTRIBUTING.md @@ -0,0 +1,47 @@ +Thank you for your interest in this project! Please refer to the following +sections on how to contribute code and bug reports. + +### Reporting bugs + +At the moment, this project is run in the spare time of a single person +([Wenzel Jakob](http://rgl.epfl.ch/people/wjakob)) with very limited resources +for issue tracker tickets. Thus, before submitting a question or bug report, +please take a moment of your time and ensure that your issue isn't already +discussed in the project documentation provided at +[http://pybind11.readthedocs.org/en/latest](http://pybind11.readthedocs.org/en/latest). + +Assuming that you have identified a previously unknown problem or an important +question, it's essential that you submit a self-contained and minimal piece of +code that reproduces the problem. In other words: no external dependencies, +isolate the function(s) that cause breakage, submit matched and complete C++ +and Python snippets that can be easily compiled and run on my end. + +## Pull requests +Contributions are submitted, reviewed, and accepted using Github pull requests. +Please refer to [this +article](https://help.github.com/articles/using-pull-requests) for details and +adhere to the following rules to make the process as smooth as possible: + +* Make a new branch for every feature you're working on. +* Make small and clean pull requests that are easy to review but make sure they + do add value by themselves. +* Add tests for any new functionality and run the test suite (``make pytest``) + to ensure that no existing features break. +* This project has a strong focus on providing general solutions using a + minimal amount of code, thus small pull requests are greatly preferred. + +### Licensing of contributions + +pybind11 is provided under a BSD-style license that can be found in the +``LICENSE`` file. By using, distributing, or contributing to this project, you +agree to the terms and conditions of this license. + +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +("Enhancements") to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to the author of this software, without +imposing a separate written license agreement for such Enhancements, then you +hereby grant the following license: a non-exclusive, royalty-free perpetual +license to install, use, modify, prepare derivative works, incorporate into +other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. diff --git a/ml/dlib/dlib/external/pybind11/LICENSE b/ml/dlib/dlib/external/pybind11/LICENSE new file mode 100644 index 000000000..6f15578cc --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/LICENSE @@ -0,0 +1,29 @@ +Copyright (c) 2016 Wenzel Jakob , All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Please also refer to the file CONTRIBUTING.md, which clarifies licensing of +external contributions to this project including patches, pull requests, etc. diff --git a/ml/dlib/dlib/external/pybind11/README.md b/ml/dlib/dlib/external/pybind11/README.md new file mode 100644 index 000000000..447788240 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/README.md @@ -0,0 +1,129 @@ +![pybind11 logo](https://github.com/pybind/pybind11/raw/master/docs/pybind11-logo.png) + +# pybind11 — Seamless operability between C++11 and Python + +[![Documentation Status](https://readthedocs.org/projects/pybind11/badge/?version=master)](http://pybind11.readthedocs.org/en/master/?badge=master) +[![Documentation Status](https://readthedocs.org/projects/pybind11/badge/?version=stable)](http://pybind11.readthedocs.org/en/stable/?badge=stable) +[![Gitter chat](https://img.shields.io/gitter/room/gitterHQ/gitter.svg)](https://gitter.im/pybind/Lobby) +[![Build Status](https://travis-ci.org/pybind/pybind11.svg?branch=master)](https://travis-ci.org/pybind/pybind11) +[![Build status](https://ci.appveyor.com/api/projects/status/riaj54pn4h08xy40?svg=true)](https://ci.appveyor.com/project/wjakob/pybind11) + +**pybind11** is a lightweight header-only library that exposes C++ types in Python +and vice versa, mainly to create Python bindings of existing C++ code. Its +goals and syntax are similar to the excellent +[Boost.Python](http://www.boost.org/doc/libs/1_58_0/libs/python/doc/) library +by David Abrahams: to minimize boilerplate code in traditional extension +modules by inferring type information using compile-time introspection. + +The main issue with Boost.Python—and the reason for creating such a similar +project—is Boost. Boost is an enormously large and complex suite of utility +libraries that works with almost every C++ compiler in existence. This +compatibility has its cost: arcane template tricks and workarounds are +necessary to support the oldest and buggiest of compiler specimens. Now that +C++11-compatible compilers are widely available, this heavy machinery has +become an excessively large and unnecessary dependency. + +Think of this library as a tiny self-contained version of Boost.Python with +everything stripped away that isn't relevant for binding generation. Without +comments, the core header files only require ~4K lines of code and depend on +Python (2.7 or 3.x, or PyPy2.7 >= 5.7) and the C++ standard library. This +compact implementation was possible thanks to some of the new C++11 language +features (specifically: tuples, lambda functions and variadic templates). Since +its creation, this library has grown beyond Boost.Python in many ways, leading +to dramatically simpler binding code in many common situations. + +Tutorial and reference documentation is provided at +[http://pybind11.readthedocs.org/en/master](http://pybind11.readthedocs.org/en/master). +A PDF version of the manual is available +[here](https://media.readthedocs.org/pdf/pybind11/master/pybind11.pdf). + +## Core features +pybind11 can map the following core C++ features to Python + +- Functions accepting and returning custom data structures per value, reference, or pointer +- Instance methods and static methods +- Overloaded functions +- Instance attributes and static attributes +- Arbitrary exception types +- Enumerations +- Callbacks +- Iterators and ranges +- Custom operators +- Single and multiple inheritance +- STL data structures +- Iterators and ranges +- Smart pointers with reference counting like ``std::shared_ptr`` +- Internal references with correct reference counting +- C++ classes with virtual (and pure virtual) methods can be extended in Python + +## Goodies +In addition to the core functionality, pybind11 provides some extra goodies: + +- Python 2.7, 3.x, and PyPy (PyPy2.7 >= 5.7) are supported with an + implementation-agnostic interface. + +- It is possible to bind C++11 lambda functions with captured variables. The + lambda capture data is stored inside the resulting Python function object. + +- pybind11 uses C++11 move constructors and move assignment operators whenever + possible to efficiently transfer custom data types. + +- It's easy to expose the internal storage of custom data types through + Pythons' buffer protocols. This is handy e.g. for fast conversion between + C++ matrix classes like Eigen and NumPy without expensive copy operations. + +- pybind11 can automatically vectorize functions so that they are transparently + applied to all entries of one or more NumPy array arguments. + +- Python's slice-based access and assignment operations can be supported with + just a few lines of code. + +- Everything is contained in just a few header files; there is no need to link + against any additional libraries. + +- Binaries are generally smaller by a factor of at least 2 compared to + equivalent bindings generated by Boost.Python. A recent pybind11 conversion + of PyRosetta, an enormous Boost.Python binding project, + [reported](http://graylab.jhu.edu/RosettaCon2016/PyRosetta-4.pdf) a binary + size reduction of **5.4x** and compile time reduction by **5.8x**. + +- When supported by the compiler, two new C++14 features (relaxed constexpr and + return value deduction) are used to precompute function signatures at compile + time, leading to smaller binaries. + +- With little extra effort, C++ types can be pickled and unpickled similar to + regular Python objects. + +## Supported compilers + +1. Clang/LLVM 3.3 or newer (for Apple Xcode's clang, this is 5.0.0 or newer) +2. GCC 4.8 or newer +3. Microsoft Visual Studio 2015 Update 3 or newer +4. Intel C++ compiler 16 or newer (15 with a [workaround](https://github.com/pybind/pybind11/issues/276)) +5. Cygwin/GCC (tested on 2.5.1) + +## About + +This project was created by [Wenzel Jakob](http://rgl.epfl.ch/people/wjakob). +Significant features and/or improvements to the code were contributed by +Jonas Adler, +Sylvain Corlay, +Trent Houliston, +Axel Huebl, +@hulucc, +Sergey Lyskov +Johan Mabille, +Tomasz Miąsko, +Dean Moldovan, +Ben Pritchard, +Jason Rhinelander, +Boris Schäling, +Pim Schellart, +Ivan Smirnov, and +Patrick Stewart. + +### License + +pybind11 is provided under a BSD-style license that can be found in the +``LICENSE`` file. By using, distributing, or contributing to this project, +you agree to the terms and conditions of this license. diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/attr.h b/ml/dlib/dlib/external/pybind11/include/pybind11/attr.h new file mode 100644 index 000000000..dce875a6b --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/attr.h @@ -0,0 +1,489 @@ +/* + pybind11/attr.h: Infrastructure for processing custom + type and function attributes + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "cast.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +/// \addtogroup annotations +/// @{ + +/// Annotation for methods +struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; + +/// Annotation for operators +struct is_operator { }; + +/// Annotation for parent scope +struct scope { handle value; scope(const handle &s) : value(s) { } }; + +/// Annotation for documentation +struct doc { const char *value; doc(const char *value) : value(value) { } }; + +/// Annotation for function names +struct name { const char *value; name(const char *value) : value(value) { } }; + +/// Annotation indicating that a function is an overload associated with a given "sibling" +struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } }; + +/// Annotation indicating that a class derives from another given type +template struct base { + PYBIND11_DEPRECATED("base() was deprecated in favor of specifying 'T' as a template argument to class_") + base() { } +}; + +/// Keep patient alive while nurse lives +template struct keep_alive { }; + +/// Annotation indicating that a class is involved in a multiple inheritance relationship +struct multiple_inheritance { }; + +/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class +struct dynamic_attr { }; + +/// Annotation which enables the buffer protocol for a type +struct buffer_protocol { }; + +/// Annotation which requests that a special metaclass is created for a type +struct metaclass { + handle value; + + PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.") + metaclass() {} + + /// Override pybind11's default metaclass + explicit metaclass(handle value) : value(value) { } +}; + +/// Annotation that marks a class as local to the module: +struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } }; + +/// Annotation to mark enums as an arithmetic type +struct arithmetic { }; + +/** \rst + A call policy which places one or more guard variables (``Ts...``) around the function call. + + For example, this definition: + + .. code-block:: cpp + + m.def("foo", foo, py::call_guard()); + + is equivalent to the following pseudocode: + + .. code-block:: cpp + + m.def("foo", [](args...) { + T scope_guard; + return foo(args...); // forwarded arguments + }); + \endrst */ +template struct call_guard; + +template <> struct call_guard<> { using type = detail::void_type; }; + +template +struct call_guard { + static_assert(std::is_default_constructible::value, + "The guard type must be default constructible"); + + using type = T; +}; + +template +struct call_guard { + struct type { + T guard{}; // Compose multiple guard types with left-to-right default-constructor order + typename call_guard::type next{}; + }; +}; + +/// @} annotations + +NAMESPACE_BEGIN(detail) +/* Forward declarations */ +enum op_id : int; +enum op_type : int; +struct undefined_t; +template struct op_; +inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); + +/// Internal data structure which holds metadata about a keyword argument +struct argument_record { + const char *name; ///< Argument name + const char *descr; ///< Human-readable version of the argument value + handle value; ///< Associated Python object + bool convert : 1; ///< True if the argument is allowed to convert when loading + bool none : 1; ///< True if None is allowed when loading + + argument_record(const char *name, const char *descr, handle value, bool convert, bool none) + : name(name), descr(descr), value(value), convert(convert), none(none) { } +}; + +/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) +struct function_record { + function_record() + : is_constructor(false), is_new_style_constructor(false), is_stateless(false), + is_operator(false), has_args(false), has_kwargs(false), is_method(false) { } + + /// Function name + char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ + + // User-specified documentation string + char *doc = nullptr; + + /// Human-readable version of the function signature + char *signature = nullptr; + + /// List of registered keyword arguments + std::vector args; + + /// Pointer to lambda function which converts arguments and performs the actual call + handle (*impl) (function_call &) = nullptr; + + /// Storage for the wrapped function pointer and captured data, if any + void *data[3] = { }; + + /// Pointer to custom destructor for 'data' (if needed) + void (*free_data) (function_record *ptr) = nullptr; + + /// Return value policy associated with this function + return_value_policy policy = return_value_policy::automatic; + + /// True if name == '__init__' + bool is_constructor : 1; + + /// True if this is a new-style `__init__` defined in `detail/init.h` + bool is_new_style_constructor : 1; + + /// True if this is a stateless function pointer + bool is_stateless : 1; + + /// True if this is an operator (__add__), etc. + bool is_operator : 1; + + /// True if the function has a '*args' argument + bool has_args : 1; + + /// True if the function has a '**kwargs' argument + bool has_kwargs : 1; + + /// True if this is a method + bool is_method : 1; + + /// Number of arguments (including py::args and/or py::kwargs, if present) + std::uint16_t nargs; + + /// Python method object + PyMethodDef *def = nullptr; + + /// Python handle to the parent scope (a class or a module) + handle scope; + + /// Python handle to the sibling function representing an overload chain + handle sibling; + + /// Pointer to next overload + function_record *next = nullptr; +}; + +/// Special data structure which (temporarily) holds metadata about a bound class +struct type_record { + PYBIND11_NOINLINE type_record() + : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { } + + /// Handle to the parent scope + handle scope; + + /// Name of the class + const char *name = nullptr; + + // Pointer to RTTI type_info data structure + const std::type_info *type = nullptr; + + /// How large is the underlying C++ type? + size_t type_size = 0; + + /// How large is the type's holder? + size_t holder_size = 0; + + /// The global operator new can be overridden with a class-specific variant + void *(*operator_new)(size_t) = ::operator new; + + /// Function pointer to class_<..>::init_instance + void (*init_instance)(instance *, const void *) = nullptr; + + /// Function pointer to class_<..>::dealloc + void (*dealloc)(detail::value_and_holder &) = nullptr; + + /// List of base classes of the newly created type + list bases; + + /// Optional docstring + const char *doc = nullptr; + + /// Custom metaclass (optional) + handle metaclass; + + /// Multiple inheritance marker + bool multiple_inheritance : 1; + + /// Does the class manage a __dict__? + bool dynamic_attr : 1; + + /// Does the class implement the buffer protocol? + bool buffer_protocol : 1; + + /// Is the default (unique_ptr) holder type used? + bool default_holder : 1; + + /// Is the class definition local to the module shared object? + bool module_local : 1; + + PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) { + auto base_info = detail::get_type_info(base, false); + if (!base_info) { + std::string tname(base.name()); + detail::clean_type_id(tname); + pybind11_fail("generic_type: type \"" + std::string(name) + + "\" referenced unknown base type \"" + tname + "\""); + } + + if (default_holder != base_info->default_holder) { + std::string tname(base.name()); + detail::clean_type_id(tname); + pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + + (default_holder ? "does not have" : "has") + + " a non-default holder type while its base \"" + tname + "\" " + + (base_info->default_holder ? "does not" : "does")); + } + + bases.append((PyObject *) base_info->type); + + if (base_info->type->tp_dictoffset != 0) + dynamic_attr = true; + + if (caster) + base_info->implicit_casts.emplace_back(type, caster); + } +}; + +inline function_call::function_call(function_record &f, handle p) : + func(f), parent(p) { + args.reserve(f.nargs); + args_convert.reserve(f.nargs); +} + +/// Tag for a new-style `__init__` defined in `detail/init.h` +struct is_new_style_constructor { }; + +/** + * Partial template specializations to process custom attributes provided to + * cpp_function_ and class_. These are either used to initialize the respective + * fields in the type_record and function_record data structures or executed at + * runtime to deal with custom call policies (e.g. keep_alive). + */ +template struct process_attribute; + +template struct process_attribute_default { + /// Default implementation: do nothing + static void init(const T &, function_record *) { } + static void init(const T &, type_record *) { } + static void precall(function_call &) { } + static void postcall(function_call &, handle) { } +}; + +/// Process an attribute specifying the function's name +template <> struct process_attribute : process_attribute_default { + static void init(const name &n, function_record *r) { r->name = const_cast(n.value); } +}; + +/// Process an attribute specifying the function's docstring +template <> struct process_attribute : process_attribute_default { + static void init(const doc &n, function_record *r) { r->doc = const_cast(n.value); } +}; + +/// Process an attribute specifying the function's docstring (provided as a C-style string) +template <> struct process_attribute : process_attribute_default { + static void init(const char *d, function_record *r) { r->doc = const_cast(d); } + static void init(const char *d, type_record *r) { r->doc = const_cast(d); } +}; +template <> struct process_attribute : process_attribute { }; + +/// Process an attribute indicating the function's return value policy +template <> struct process_attribute : process_attribute_default { + static void init(const return_value_policy &p, function_record *r) { r->policy = p; } +}; + +/// Process an attribute which indicates that this is an overloaded function associated with a given sibling +template <> struct process_attribute : process_attribute_default { + static void init(const sibling &s, function_record *r) { r->sibling = s.value; } +}; + +/// Process an attribute which indicates that this function is a method +template <> struct process_attribute : process_attribute_default { + static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; } +}; + +/// Process an attribute which indicates the parent scope of a method +template <> struct process_attribute : process_attribute_default { + static void init(const scope &s, function_record *r) { r->scope = s.value; } +}; + +/// Process an attribute which indicates that this function is an operator +template <> struct process_attribute : process_attribute_default { + static void init(const is_operator &, function_record *r) { r->is_operator = true; } +}; + +template <> struct process_attribute : process_attribute_default { + static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; } +}; + +/// Process a keyword argument attribute (*without* a default value) +template <> struct process_attribute : process_attribute_default { + static void init(const arg &a, function_record *r) { + if (r->is_method && r->args.empty()) + r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/); + r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); + } +}; + +/// Process a keyword argument attribute (*with* a default value) +template <> struct process_attribute : process_attribute_default { + static void init(const arg_v &a, function_record *r) { + if (r->is_method && r->args.empty()) + r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/); + + if (!a.value) { +#if !defined(NDEBUG) + std::string descr("'"); + if (a.name) descr += std::string(a.name) + ": "; + descr += a.type + "'"; + if (r->is_method) { + if (r->name) + descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'"; + else + descr += " in method of '" + (std::string) str(r->scope) + "'"; + } else if (r->name) { + descr += " in function '" + (std::string) r->name + "'"; + } + pybind11_fail("arg(): could not convert default argument " + + descr + " into a Python object (type not registered yet?)"); +#else + pybind11_fail("arg(): could not convert default argument " + "into a Python object (type not registered yet?). " + "Compile in debug mode for more information."); +#endif + } + r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); + } +}; + +/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that) +template +struct process_attribute::value>> : process_attribute_default { + static void init(const handle &h, type_record *r) { r->bases.append(h); } +}; + +/// Process a parent class attribute (deprecated, does not support multiple inheritance) +template +struct process_attribute> : process_attribute_default> { + static void init(const base &, type_record *r) { r->add_base(typeid(T), nullptr); } +}; + +/// Process a multiple inheritance attribute +template <> +struct process_attribute : process_attribute_default { + static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const module_local &l, type_record *r) { r->module_local = l.value; } +}; + +/// Process an 'arithmetic' attribute for enums (does nothing here) +template <> +struct process_attribute : process_attribute_default {}; + +template +struct process_attribute> : process_attribute_default> { }; + +/** + * Process a keep_alive call policy -- invokes keep_alive_impl during the + * pre-call handler if both Nurse, Patient != 0 and use the post-call handler + * otherwise + */ +template struct process_attribute> : public process_attribute_default> { + template = 0> + static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); } + template = 0> + static void postcall(function_call &, handle) { } + template = 0> + static void precall(function_call &) { } + template = 0> + static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); } +}; + +/// Recursively iterate over variadic template arguments +template struct process_attributes { + static void init(const Args&... args, function_record *r) { + int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; + ignore_unused(unused); + } + static void init(const Args&... args, type_record *r) { + int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; + ignore_unused(unused); + } + static void precall(function_call &call) { + int unused[] = { 0, (process_attribute::type>::precall(call), 0) ... }; + ignore_unused(unused); + } + static void postcall(function_call &call, handle fn_ret) { + int unused[] = { 0, (process_attribute::type>::postcall(call, fn_ret), 0) ... }; + ignore_unused(unused); + } +}; + +template +using is_call_guard = is_instantiation; + +/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found) +template +using extract_guard_t = typename exactly_one_t, Extra...>::type; + +/// Check the number of named arguments at compile time +template ::value...), + size_t self = constexpr_sum(std::is_same::value...)> +constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) { + return named == 0 || (self + named + has_args + has_kwargs) == nargs; +} + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/buffer_info.h b/ml/dlib/dlib/external/pybind11/include/pybind11/buffer_info.h new file mode 100644 index 000000000..9f072fa73 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/buffer_info.h @@ -0,0 +1,108 @@ +/* + pybind11/buffer_info.h: Python buffer object interface + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "detail/common.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +/// Information record describing a Python buffer object +struct buffer_info { + void *ptr = nullptr; // Pointer to the underlying storage + ssize_t itemsize = 0; // Size of individual items in bytes + ssize_t size = 0; // Total number of entries + std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() + ssize_t ndim = 0; // Number of dimensions + std::vector shape; // Shape of the tensor (1 entry per dimension) + std::vector strides; // Number of entries between adjacent entries (for each per dimension) + + buffer_info() { } + + buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, + detail::any_container shape_in, detail::any_container strides_in) + : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), + shape(std::move(shape_in)), strides(std::move(strides_in)) { + if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) + pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); + for (size_t i = 0; i < (size_t) ndim; ++i) + size *= shape[i]; + } + + template + buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) + : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } + + buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) + : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } + + template + buffer_info(T *ptr, ssize_t size) + : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } + + explicit buffer_info(Py_buffer *view, bool ownview = true) + : buffer_info(view->buf, view->itemsize, view->format, view->ndim, + {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { + this->view = view; + this->ownview = ownview; + } + + buffer_info(const buffer_info &) = delete; + buffer_info& operator=(const buffer_info &) = delete; + + buffer_info(buffer_info &&other) { + (*this) = std::move(other); + } + + buffer_info& operator=(buffer_info &&rhs) { + ptr = rhs.ptr; + itemsize = rhs.itemsize; + size = rhs.size; + format = std::move(rhs.format); + ndim = rhs.ndim; + shape = std::move(rhs.shape); + strides = std::move(rhs.strides); + std::swap(view, rhs.view); + std::swap(ownview, rhs.ownview); + return *this; + } + + ~buffer_info() { + if (view && ownview) { PyBuffer_Release(view); delete view; } + } + +private: + struct private_ctr_tag { }; + + buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, + detail::any_container &&shape_in, detail::any_container &&strides_in) + : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } + + Py_buffer *view = nullptr; + bool ownview = false; +}; + +NAMESPACE_BEGIN(detail) + +template struct compare_buffer_info { + static bool compare(const buffer_info& b) { + return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); + } +}; + +template struct compare_buffer_info::value>> { + static bool compare(const buffer_info& b) { + return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor::value || + ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || + ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); + } +}; + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/cast.h b/ml/dlib/dlib/external/pybind11/include/pybind11/cast.h new file mode 100644 index 000000000..a722a9e81 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/cast.h @@ -0,0 +1,2063 @@ +/* + pybind11/cast.h: Partial template specializations to cast between + C++ and Python types + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pytypes.h" +#include "detail/typeid.h" +#include "detail/descr.h" +#include "detail/internals.h" +#include +#include +#include + +#if defined(PYBIND11_CPP17) +# if defined(__has_include) +# if __has_include() +# define PYBIND11_HAS_STRING_VIEW +# endif +# elif defined(_MSC_VER) +# define PYBIND11_HAS_STRING_VIEW +# endif +#endif +#ifdef PYBIND11_HAS_STRING_VIEW +#include +#endif + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +/// A life support system for temporary objects created by `type_caster::load()`. +/// Adding a patient will keep it alive up until the enclosing function returns. +class loader_life_support { +public: + /// A new patient frame is created when a function is entered + loader_life_support() { + get_internals().loader_patient_stack.push_back(nullptr); + } + + /// ... and destroyed after it returns + ~loader_life_support() { + auto &stack = get_internals().loader_patient_stack; + if (stack.empty()) + pybind11_fail("loader_life_support: internal error"); + + auto ptr = stack.back(); + stack.pop_back(); + Py_CLEAR(ptr); + + // A heuristic to reduce the stack's capacity (e.g. after long recursive calls) + if (stack.capacity() > 16 && stack.size() != 0 && stack.capacity() / stack.size() > 2) + stack.shrink_to_fit(); + } + + /// This can only be used inside a pybind11-bound function, either by `argument_loader` + /// at argument preparation time or by `py::cast()` at execution time. + PYBIND11_NOINLINE static void add_patient(handle h) { + auto &stack = get_internals().loader_patient_stack; + if (stack.empty()) + throw cast_error("When called outside a bound function, py::cast() cannot " + "do Python -> C++ conversions which require the creation " + "of temporary values"); + + auto &list_ptr = stack.back(); + if (list_ptr == nullptr) { + list_ptr = PyList_New(1); + if (!list_ptr) + pybind11_fail("loader_life_support: error allocating list"); + PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr()); + } else { + auto result = PyList_Append(list_ptr, h.ptr()); + if (result == -1) + pybind11_fail("loader_life_support: error adding patient"); + } + } +}; + +// Gets the cache entry for the given type, creating it if necessary. The return value is the pair +// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was +// just created. +inline std::pair all_type_info_get_cache(PyTypeObject *type); + +// Populates a just-created cache entry. +PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vector &bases) { + std::vector check; + for (handle parent : reinterpret_borrow(t->tp_bases)) + check.push_back((PyTypeObject *) parent.ptr()); + + auto const &type_dict = get_internals().registered_types_py; + for (size_t i = 0; i < check.size(); i++) { + auto type = check[i]; + // Ignore Python2 old-style class super type: + if (!PyType_Check((PyObject *) type)) continue; + + // Check `type` in the current set of registered python types: + auto it = type_dict.find(type); + if (it != type_dict.end()) { + // We found a cache entry for it, so it's either pybind-registered or has pre-computed + // pybind bases, but we have to make sure we haven't already seen the type(s) before: we + // want to follow Python/virtual C++ rules that there should only be one instance of a + // common base. + for (auto *tinfo : it->second) { + // NB: Could use a second set here, rather than doing a linear search, but since + // having a large number of immediate pybind11-registered types seems fairly + // unlikely, that probably isn't worthwhile. + bool found = false; + for (auto *known : bases) { + if (known == tinfo) { found = true; break; } + } + if (!found) bases.push_back(tinfo); + } + } + else if (type->tp_bases) { + // It's some python type, so keep follow its bases classes to look for one or more + // registered types + if (i + 1 == check.size()) { + // When we're at the end, we can pop off the current element to avoid growing + // `check` when adding just one base (which is typical--i.e. when there is no + // multiple inheritance) + check.pop_back(); + i--; + } + for (handle parent : reinterpret_borrow(type->tp_bases)) + check.push_back((PyTypeObject *) parent.ptr()); + } + } +} + +/** + * Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will + * be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side + * derived class that uses single inheritance. Will contain as many types as required for a Python + * class that uses multiple inheritance to inherit (directly or indirectly) from multiple + * pybind-registered classes. Will be empty if neither the type nor any base classes are + * pybind-registered. + * + * The value is cached for the lifetime of the Python type. + */ +inline const std::vector &all_type_info(PyTypeObject *type) { + auto ins = all_type_info_get_cache(type); + if (ins.second) + // New cache entry: populate it + all_type_info_populate(type, ins.first->second); + + return ins.first->second; +} + +/** + * Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any + * ancestors are pybind11-registered. Throws an exception if there are multiple bases--use + * `all_type_info` instead if you want to support multiple bases. + */ +PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) { + auto &bases = all_type_info(type); + if (bases.size() == 0) + return nullptr; + if (bases.size() > 1) + pybind11_fail("pybind11::detail::get_type_info: type has multiple pybind11-registered bases"); + return bases.front(); +} + +inline detail::type_info *get_local_type_info(const std::type_index &tp) { + auto &locals = registered_local_types_cpp(); + auto it = locals.find(tp); + if (it != locals.end()) + return it->second; + return nullptr; +} + +inline detail::type_info *get_global_type_info(const std::type_index &tp) { + auto &types = get_internals().registered_types_cpp; + auto it = types.find(tp); + if (it != types.end()) + return it->second; + return nullptr; +} + +/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr. +PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp, + bool throw_if_missing = false) { + if (auto ltype = get_local_type_info(tp)) + return ltype; + if (auto gtype = get_global_type_info(tp)) + return gtype; + + if (throw_if_missing) { + std::string tname = tp.name(); + detail::clean_type_id(tname); + pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \"" + tname + "\""); + } + return nullptr; +} + +PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool throw_if_missing) { + detail::type_info *type_info = get_type_info(tp, throw_if_missing); + return handle(type_info ? ((PyObject *) type_info->type) : nullptr); +} + +struct value_and_holder { + instance *inst; + size_t index; + const detail::type_info *type; + void **vh; + + // Main constructor for a found value/holder: + value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) : + inst{i}, index{index}, type{type}, + vh{inst->simple_layout ? inst->simple_value_holder : &inst->nonsimple.values_and_holders[vpos]} + {} + + // Default constructor (used to signal a value-and-holder not found by get_value_and_holder()) + value_and_holder() : inst{nullptr} {} + + // Used for past-the-end iterator + value_and_holder(size_t index) : index{index} {} + + template V *&value_ptr() const { + return reinterpret_cast(vh[0]); + } + // True if this `value_and_holder` has a non-null value pointer + explicit operator bool() const { return value_ptr(); } + + template H &holder() const { + return reinterpret_cast(vh[1]); + } + bool holder_constructed() const { + return inst->simple_layout + ? inst->simple_holder_constructed + : inst->nonsimple.status[index] & instance::status_holder_constructed; + } + void set_holder_constructed(bool v = true) { + if (inst->simple_layout) + inst->simple_holder_constructed = v; + else if (v) + inst->nonsimple.status[index] |= instance::status_holder_constructed; + else + inst->nonsimple.status[index] &= (uint8_t) ~instance::status_holder_constructed; + } + bool instance_registered() const { + return inst->simple_layout + ? inst->simple_instance_registered + : inst->nonsimple.status[index] & instance::status_instance_registered; + } + void set_instance_registered(bool v = true) { + if (inst->simple_layout) + inst->simple_instance_registered = v; + else if (v) + inst->nonsimple.status[index] |= instance::status_instance_registered; + else + inst->nonsimple.status[index] &= (uint8_t) ~instance::status_instance_registered; + } +}; + +// Container for accessing and iterating over an instance's values/holders +struct values_and_holders { +private: + instance *inst; + using type_vec = std::vector; + const type_vec &tinfo; + +public: + values_and_holders(instance *inst) : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {} + + struct iterator { + private: + instance *inst; + const type_vec *types; + value_and_holder curr; + friend struct values_and_holders; + iterator(instance *inst, const type_vec *tinfo) + : inst{inst}, types{tinfo}, + curr(inst /* instance */, + types->empty() ? nullptr : (*types)[0] /* type info */, + 0, /* vpos: (non-simple types only): the first vptr comes first */ + 0 /* index */) + {} + // Past-the-end iterator: + iterator(size_t end) : curr(end) {} + public: + bool operator==(const iterator &other) { return curr.index == other.curr.index; } + bool operator!=(const iterator &other) { return curr.index != other.curr.index; } + iterator &operator++() { + if (!inst->simple_layout) + curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs; + ++curr.index; + curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr; + return *this; + } + value_and_holder &operator*() { return curr; } + value_and_holder *operator->() { return &curr; } + }; + + iterator begin() { return iterator(inst, &tinfo); } + iterator end() { return iterator(tinfo.size()); } + + iterator find(const type_info *find_type) { + auto it = begin(), endit = end(); + while (it != endit && it->type != find_type) ++it; + return it; + } + + size_t size() { return tinfo.size(); } +}; + +/** + * Extracts C++ value and holder pointer references from an instance (which may contain multiple + * values/holders for python-side multiple inheritance) that match the given type. Throws an error + * if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If + * `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned, + * regardless of type (and the resulting .type will be nullptr). + * + * The returned object should be short-lived: in particular, it must not outlive the called-upon + * instance. + */ +PYBIND11_NOINLINE inline value_and_holder instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/, bool throw_if_missing /*= true in common.h*/) { + // Optimize common case: + if (!find_type || Py_TYPE(this) == find_type->type) + return value_and_holder(this, find_type, 0, 0); + + detail::values_and_holders vhs(this); + auto it = vhs.find(find_type); + if (it != vhs.end()) + return *it; + + if (!throw_if_missing) + return value_and_holder(); + +#if defined(NDEBUG) + pybind11_fail("pybind11::detail::instance::get_value_and_holder: " + "type is not a pybind11 base of the given instance " + "(compile in debug mode for type details)"); +#else + pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" + + std::string(find_type->type->tp_name) + "' is not a pybind11 base of the given `" + + std::string(Py_TYPE(this)->tp_name) + "' instance"); +#endif +} + +PYBIND11_NOINLINE inline void instance::allocate_layout() { + auto &tinfo = all_type_info(Py_TYPE(this)); + + const size_t n_types = tinfo.size(); + + if (n_types == 0) + pybind11_fail("instance allocation failed: new instance has no pybind11-registered base types"); + + simple_layout = + n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs(); + + // Simple path: no python-side multiple inheritance, and a small-enough holder + if (simple_layout) { + simple_value_holder[0] = nullptr; + simple_holder_constructed = false; + simple_instance_registered = false; + } + else { // multiple base types or a too-large holder + // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer, + // [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool + // values that tracks whether each associated holder has been initialized. Each [block] is + // padded, if necessary, to an integer multiple of sizeof(void *). + size_t space = 0; + for (auto t : tinfo) { + space += 1; // value pointer + space += t->holder_size_in_ptrs; // holder instance + } + size_t flags_at = space; + space += size_in_ptrs(n_types); // status bytes (holder_constructed and instance_registered) + + // Allocate space for flags, values, and holders, and initialize it to 0 (flags and values, + // in particular, need to be 0). Use Python's memory allocation functions: in Python 3.6 + // they default to using pymalloc, which is designed to be efficient for small allocations + // like the one we're doing here; in earlier versions (and for larger allocations) they are + // just wrappers around malloc. +#if PY_VERSION_HEX >= 0x03050000 + nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *)); + if (!nonsimple.values_and_holders) throw std::bad_alloc(); +#else + nonsimple.values_and_holders = (void **) PyMem_New(void *, space); + if (!nonsimple.values_and_holders) throw std::bad_alloc(); + std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *)); +#endif + nonsimple.status = reinterpret_cast(&nonsimple.values_and_holders[flags_at]); + } + owned = true; +} + +PYBIND11_NOINLINE inline void instance::deallocate_layout() { + if (!simple_layout) + PyMem_Free(nonsimple.values_and_holders); +} + +PYBIND11_NOINLINE inline bool isinstance_generic(handle obj, const std::type_info &tp) { + handle type = detail::get_type_handle(tp, false); + if (!type) + return false; + return isinstance(obj, type); +} + +PYBIND11_NOINLINE inline std::string error_string() { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred"); + return "Unknown internal error occurred"; + } + + error_scope scope; // Preserve error state + + std::string errorString; + if (scope.type) { + errorString += handle(scope.type).attr("__name__").cast(); + errorString += ": "; + } + if (scope.value) + errorString += (std::string) str(scope.value); + + PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace); + +#if PY_MAJOR_VERSION >= 3 + if (scope.trace != nullptr) + PyException_SetTraceback(scope.value, scope.trace); +#endif + +#if !defined(PYPY_VERSION) + if (scope.trace) { + PyTracebackObject *trace = (PyTracebackObject *) scope.trace; + + /* Get the deepest trace possible */ + while (trace->tb_next) + trace = trace->tb_next; + + PyFrameObject *frame = trace->tb_frame; + errorString += "\n\nAt:\n"; + while (frame) { + int lineno = PyFrame_GetLineNumber(frame); + errorString += + " " + handle(frame->f_code->co_filename).cast() + + "(" + std::to_string(lineno) + "): " + + handle(frame->f_code->co_name).cast() + "\n"; + frame = frame->f_back; + } + } +#endif + + return errorString; +} + +PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr, const detail::type_info *type ) { + auto &instances = get_internals().registered_instances; + auto range = instances.equal_range(ptr); + for (auto it = range.first; it != range.second; ++it) { + for (auto vh : values_and_holders(it->second)) { + if (vh.type == type) + return handle((PyObject *) it->second); + } + } + return handle(); +} + +inline PyThreadState *get_thread_state_unchecked() { +#if defined(PYPY_VERSION) + return PyThreadState_GET(); +#elif PY_VERSION_HEX < 0x03000000 + return _PyThreadState_Current; +#elif PY_VERSION_HEX < 0x03050000 + return (PyThreadState*) _Py_atomic_load_relaxed(&_PyThreadState_Current); +#elif PY_VERSION_HEX < 0x03050200 + return (PyThreadState*) _PyThreadState_Current.value; +#else + return _PyThreadState_UncheckedGet(); +#endif +} + +// Forward declarations +inline void keep_alive_impl(handle nurse, handle patient); +inline PyObject *make_new_instance(PyTypeObject *type); + +class type_caster_generic { +public: + PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info) + : typeinfo(get_type_info(type_info)), cpptype(&type_info) { } + + type_caster_generic(const type_info *typeinfo) + : typeinfo(typeinfo), cpptype(typeinfo ? typeinfo->cpptype : nullptr) { } + + bool load(handle src, bool convert) { + return load_impl(src, convert); + } + + PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent, + const detail::type_info *tinfo, + void *(*copy_constructor)(const void *), + void *(*move_constructor)(const void *), + const void *existing_holder = nullptr) { + if (!tinfo) // no type info: error will be set already + return handle(); + + void *src = const_cast(_src); + if (src == nullptr) + return none().release(); + + auto it_instances = get_internals().registered_instances.equal_range(src); + for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) { + for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) { + if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) + return handle((PyObject *) it_i->second).inc_ref(); + } + } + + auto inst = reinterpret_steal(make_new_instance(tinfo->type)); + auto wrapper = reinterpret_cast(inst.ptr()); + wrapper->owned = false; + void *&valueptr = values_and_holders(wrapper).begin()->value_ptr(); + + switch (policy) { + case return_value_policy::automatic: + case return_value_policy::take_ownership: + valueptr = src; + wrapper->owned = true; + break; + + case return_value_policy::automatic_reference: + case return_value_policy::reference: + valueptr = src; + wrapper->owned = false; + break; + + case return_value_policy::copy: + if (copy_constructor) + valueptr = copy_constructor(src); + else + throw cast_error("return_value_policy = copy, but the " + "object is non-copyable!"); + wrapper->owned = true; + break; + + case return_value_policy::move: + if (move_constructor) + valueptr = move_constructor(src); + else if (copy_constructor) + valueptr = copy_constructor(src); + else + throw cast_error("return_value_policy = move, but the " + "object is neither movable nor copyable!"); + wrapper->owned = true; + break; + + case return_value_policy::reference_internal: + valueptr = src; + wrapper->owned = false; + keep_alive_impl(inst, parent); + break; + + default: + throw cast_error("unhandled return_value_policy: should not happen!"); + } + + tinfo->init_instance(wrapper, existing_holder); + + return inst.release(); + } + + // Base methods for generic caster; there are overridden in copyable_holder_caster + void load_value(value_and_holder &&v_h) { + auto *&vptr = v_h.value_ptr(); + // Lazy allocation for unallocated values: + if (vptr == nullptr) { + auto *type = v_h.type ? v_h.type : typeinfo; + vptr = type->operator_new(type->type_size); + } + value = vptr; + } + bool try_implicit_casts(handle src, bool convert) { + for (auto &cast : typeinfo->implicit_casts) { + type_caster_generic sub_caster(*cast.first); + if (sub_caster.load(src, convert)) { + value = cast.second(sub_caster.value); + return true; + } + } + return false; + } + bool try_direct_conversions(handle src) { + for (auto &converter : *typeinfo->direct_conversions) { + if (converter(src.ptr(), value)) + return true; + } + return false; + } + void check_holder_compat() {} + + PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) { + auto caster = type_caster_generic(ti); + if (caster.load(src, false)) + return caster.value; + return nullptr; + } + + /// Try to load with foreign typeinfo, if available. Used when there is no + /// native typeinfo, or when the native one wasn't able to produce a value. + PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) { + constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID; + const auto pytype = src.get_type(); + if (!hasattr(pytype, local_key)) + return false; + + type_info *foreign_typeinfo = reinterpret_borrow(getattr(pytype, local_key)); + // Only consider this foreign loader if actually foreign and is a loader of the correct cpp type + if (foreign_typeinfo->module_local_load == &local_load + || (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype))) + return false; + + if (auto result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) { + value = result; + return true; + } + return false; + } + + // Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant + // bits of code between here and copyable_holder_caster where the two classes need different + // logic (without having to resort to virtual inheritance). + template + PYBIND11_NOINLINE bool load_impl(handle src, bool convert) { + if (!src) return false; + if (!typeinfo) return try_load_foreign_module_local(src); + if (src.is_none()) { + // Defer accepting None to other overloads (if we aren't in convert mode): + if (!convert) return false; + value = nullptr; + return true; + } + + auto &this_ = static_cast(*this); + this_.check_holder_compat(); + + PyTypeObject *srctype = Py_TYPE(src.ptr()); + + // Case 1: If src is an exact type match for the target type then we can reinterpret_cast + // the instance's value pointer to the target type: + if (srctype == typeinfo->type) { + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); + return true; + } + // Case 2: We have a derived class + else if (PyType_IsSubtype(srctype, typeinfo->type)) { + auto &bases = all_type_info(srctype); + bool no_cpp_mi = typeinfo->simple_type; + + // Case 2a: the python type is a Python-inherited derived class that inherits from just + // one simple (no MI) pybind11 class, or is an exact match, so the C++ instance is of + // the right type and we can use reinterpret_cast. + // (This is essentially the same as case 2b, but because not using multiple inheritance + // is extremely common, we handle it specially to avoid the loop iterator and type + // pointer lookup overhead) + if (bases.size() == 1 && (no_cpp_mi || bases.front()->type == typeinfo->type)) { + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); + return true; + } + // Case 2b: the python type inherits from multiple C++ bases. Check the bases to see if + // we can find an exact match (or, for a simple C++ type, an inherited match); if so, we + // can safely reinterpret_cast to the relevant pointer. + else if (bases.size() > 1) { + for (auto base : bases) { + if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) : base->type == typeinfo->type) { + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder(base)); + return true; + } + } + } + + // Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type match + // in the registered bases, above, so try implicit casting (needed for proper C++ casting + // when MI is involved). + if (this_.try_implicit_casts(src, convert)) + return true; + } + + // Perform an implicit conversion + if (convert) { + for (auto &converter : typeinfo->implicit_conversions) { + auto temp = reinterpret_steal(converter(src.ptr(), typeinfo->type)); + if (load_impl(temp, false)) { + loader_life_support::add_patient(temp); + return true; + } + } + if (this_.try_direct_conversions(src)) + return true; + } + + // Failed to match local typeinfo. Try again with global. + if (typeinfo->module_local) { + if (auto gtype = get_global_type_info(*typeinfo->cpptype)) { + typeinfo = gtype; + return load(src, false); + } + } + + // Global typeinfo has precedence over foreign module_local + return try_load_foreign_module_local(src); + } + + + // Called to do type lookup and wrap the pointer and type in a pair when a dynamic_cast + // isn't needed or can't be used. If the type is unknown, sets the error and returns a pair + // with .second = nullptr. (p.first = nullptr is not an error: it becomes None). + PYBIND11_NOINLINE static std::pair src_and_type( + const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) { + if (auto *tpi = get_type_info(cast_type)) + return {src, const_cast(tpi)}; + + // Not found, set error: + std::string tname = rtti_type ? rtti_type->name() : cast_type.name(); + detail::clean_type_id(tname); + std::string msg = "Unregistered type : " + tname; + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return {nullptr, nullptr}; + } + + const type_info *typeinfo = nullptr; + const std::type_info *cpptype = nullptr; + void *value = nullptr; +}; + +/** + * Determine suitable casting operator for pointer-or-lvalue-casting type casters. The type caster + * needs to provide `operator T*()` and `operator T&()` operators. + * + * If the type supports moving the value away via an `operator T&&() &&` method, it should use + * `movable_cast_op_type` instead. + */ +template +using cast_op_type = + conditional_t>::value, + typename std::add_pointer>::type, + typename std::add_lvalue_reference>::type>; + +/** + * Determine suitable casting operator for a type caster with a movable value. Such a type caster + * needs to provide `operator T*()`, `operator T&()`, and `operator T&&() &&`. The latter will be + * called in appropriate contexts where the value can be moved rather than copied. + * + * These operator are automatically provided when using the PYBIND11_TYPE_CASTER macro. + */ +template +using movable_cast_op_type = + conditional_t::type>::value, + typename std::add_pointer>::type, + conditional_t::value, + typename std::add_rvalue_reference>::type, + typename std::add_lvalue_reference>::type>>; + +// std::is_copy_constructible isn't quite enough: it lets std::vector (and similar) through when +// T is non-copyable, but code containing such a copy constructor fails to actually compile. +template struct is_copy_constructible : std::is_copy_constructible {}; + +// Specialization for types that appear to be copy constructible but also look like stl containers +// (we specifically check for: has `value_type` and `reference` with `reference = value_type&`): if +// so, copy constructability depends on whether the value_type is copy constructible. +template struct is_copy_constructible, + std::is_same + >::value>> : is_copy_constructible {}; + +#if !defined(PYBIND11_CPP17) +// Likewise for std::pair before C++17 (which mandates that the copy constructor not exist when the +// two types aren't themselves copy constructible). +template struct is_copy_constructible> + : all_of, is_copy_constructible> {}; +#endif + +/// Generic type caster for objects stored on the heap +template class type_caster_base : public type_caster_generic { + using itype = intrinsic_t; +public: + static PYBIND11_DESCR name() { return type_descr(_()); } + + type_caster_base() : type_caster_base(typeid(type)) { } + explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { } + + static handle cast(const itype &src, return_value_policy policy, handle parent) { + if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) + policy = return_value_policy::copy; + return cast(&src, policy, parent); + } + + static handle cast(itype &&src, return_value_policy, handle parent) { + return cast(&src, return_value_policy::move, parent); + } + + // Returns a (pointer, type_info) pair taking care of necessary RTTI type lookup for a + // polymorphic type. If the instance isn't derived, returns the non-RTTI base version. + template ::value, int> = 0> + static std::pair src_and_type(const itype *src) { + const void *vsrc = src; + auto &cast_type = typeid(itype); + const std::type_info *instance_type = nullptr; + if (vsrc) { + instance_type = &typeid(*src); + if (!same_type(cast_type, *instance_type)) { + // This is a base pointer to a derived type; if it is a pybind11-registered type, we + // can get the correct derived pointer (which may be != base pointer) by a + // dynamic_cast to most derived type: + if (auto *tpi = get_type_info(*instance_type)) + return {dynamic_cast(src), const_cast(tpi)}; + } + } + // Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so + // don't do a cast + return type_caster_generic::src_and_type(vsrc, cast_type, instance_type); + } + + // Non-polymorphic type, so no dynamic casting; just call the generic version directly + template ::value, int> = 0> + static std::pair src_and_type(const itype *src) { + return type_caster_generic::src_and_type(src, typeid(itype)); + } + + static handle cast(const itype *src, return_value_policy policy, handle parent) { + auto st = src_and_type(src); + return type_caster_generic::cast( + st.first, policy, parent, st.second, + make_copy_constructor(src), make_move_constructor(src)); + } + + static handle cast_holder(const itype *src, const void *holder) { + auto st = src_and_type(src); + return type_caster_generic::cast( + st.first, return_value_policy::take_ownership, {}, st.second, + nullptr, nullptr, holder); + } + + template using cast_op_type = cast_op_type; + + operator itype*() { return (type *) value; } + operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); } + +protected: + using Constructor = void *(*)(const void *); + + /* Only enabled when the types are {copy,move}-constructible *and* when the type + does not have a private operator new implementation. */ + template ::value>> + static auto make_copy_constructor(const T *x) -> decltype(new T(*x), Constructor{}) { + return [](const void *arg) -> void * { + return new T(*reinterpret_cast(arg)); + }; + } + + template ::value>> + static auto make_move_constructor(const T *x) -> decltype(new T(std::move(*const_cast(x))), Constructor{}) { + return [](const void *arg) -> void * { + return new T(std::move(*const_cast(reinterpret_cast(arg)))); + }; + } + + static Constructor make_copy_constructor(...) { return nullptr; } + static Constructor make_move_constructor(...) { return nullptr; } +}; + +template class type_caster : public type_caster_base { }; +template using make_caster = type_caster>; + +// Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T +template typename make_caster::template cast_op_type cast_op(make_caster &caster) { + return caster.operator typename make_caster::template cast_op_type(); +} +template typename make_caster::template cast_op_type::type> +cast_op(make_caster &&caster) { + return std::move(caster).operator + typename make_caster::template cast_op_type::type>(); +} + +template class type_caster> { +private: + using caster_t = make_caster; + caster_t subcaster; + using subcaster_cast_op_type = typename caster_t::template cast_op_type; + static_assert(std::is_same::type &, subcaster_cast_op_type>::value, + "std::reference_wrapper caster requires T to have a caster with an `T &` operator"); +public: + bool load(handle src, bool convert) { return subcaster.load(src, convert); } + static PYBIND11_DESCR name() { return caster_t::name(); } + static handle cast(const std::reference_wrapper &src, return_value_policy policy, handle parent) { + // It is definitely wrong to take ownership of this pointer, so mask that rvp + if (policy == return_value_policy::take_ownership || policy == return_value_policy::automatic) + policy = return_value_policy::automatic_reference; + return caster_t::cast(&src.get(), policy, parent); + } + template using cast_op_type = std::reference_wrapper; + operator std::reference_wrapper() { return subcaster.operator subcaster_cast_op_type&(); } +}; + +#define PYBIND11_TYPE_CASTER(type, py_name) \ + protected: \ + type value; \ + public: \ + static PYBIND11_DESCR name() { return type_descr(py_name); } \ + template >::value, int> = 0> \ + static handle cast(T_ *src, return_value_policy policy, handle parent) { \ + if (!src) return none().release(); \ + if (policy == return_value_policy::take_ownership) { \ + auto h = cast(std::move(*src), policy, parent); delete src; return h; \ + } else { \ + return cast(*src, policy, parent); \ + } \ + } \ + operator type*() { return &value; } \ + operator type&() { return value; } \ + operator type&&() && { return std::move(value); } \ + template using cast_op_type = pybind11::detail::movable_cast_op_type + + +template using is_std_char_type = any_of< + std::is_same, /* std::string */ + std::is_same, /* std::u16string */ + std::is_same, /* std::u32string */ + std::is_same /* std::wstring */ +>; + +template +struct type_caster::value && !is_std_char_type::value>> { + using _py_type_0 = conditional_t; + using _py_type_1 = conditional_t::value, _py_type_0, typename std::make_unsigned<_py_type_0>::type>; + using py_type = conditional_t::value, double, _py_type_1>; +public: + + bool load(handle src, bool convert) { + py_type py_value; + + if (!src) + return false; + + if (std::is_floating_point::value) { + if (convert || PyFloat_Check(src.ptr())) + py_value = (py_type) PyFloat_AsDouble(src.ptr()); + else + return false; + } else if (PyFloat_Check(src.ptr())) { + return false; + } else if (std::is_unsigned::value) { + py_value = as_unsigned(src.ptr()); + } else { // signed integer: + py_value = sizeof(T) <= sizeof(long) + ? (py_type) PyLong_AsLong(src.ptr()) + : (py_type) PYBIND11_LONG_AS_LONGLONG(src.ptr()); + } + + bool py_err = py_value == (py_type) -1 && PyErr_Occurred(); + if (py_err || (std::is_integral::value && sizeof(py_type) != sizeof(T) && + (py_value < (py_type) std::numeric_limits::min() || + py_value > (py_type) std::numeric_limits::max()))) { + bool type_error = py_err && PyErr_ExceptionMatches( +#if PY_VERSION_HEX < 0x03000000 && !defined(PYPY_VERSION) + PyExc_SystemError +#else + PyExc_TypeError +#endif + ); + PyErr_Clear(); + if (type_error && convert && PyNumber_Check(src.ptr())) { + auto tmp = reinterpret_steal(std::is_floating_point::value + ? PyNumber_Float(src.ptr()) + : PyNumber_Long(src.ptr())); + PyErr_Clear(); + return load(tmp, false); + } + return false; + } + + value = (T) py_value; + return true; + } + + static handle cast(T src, return_value_policy /* policy */, handle /* parent */) { + if (std::is_floating_point::value) { + return PyFloat_FromDouble((double) src); + } else if (sizeof(T) <= sizeof(long)) { + if (std::is_signed::value) + return PyLong_FromLong((long) src); + else + return PyLong_FromUnsignedLong((unsigned long) src); + } else { + if (std::is_signed::value) + return PyLong_FromLongLong((long long) src); + else + return PyLong_FromUnsignedLongLong((unsigned long long) src); + } + } + + PYBIND11_TYPE_CASTER(T, _::value>("int", "float")); +}; + +template struct void_caster { +public: + bool load(handle src, bool) { + if (src && src.is_none()) + return true; + return false; + } + static handle cast(T, return_value_policy /* policy */, handle /* parent */) { + return none().inc_ref(); + } + PYBIND11_TYPE_CASTER(T, _("None")); +}; + +template <> class type_caster : public void_caster {}; + +template <> class type_caster : public type_caster { +public: + using type_caster::cast; + + bool load(handle h, bool) { + if (!h) { + return false; + } else if (h.is_none()) { + value = nullptr; + return true; + } + + /* Check if this is a capsule */ + if (isinstance(h)) { + value = reinterpret_borrow(h); + return true; + } + + /* Check if this is a C++ type */ + auto &bases = all_type_info((PyTypeObject *) h.get_type().ptr()); + if (bases.size() == 1) { // Only allowing loading from a single-value type + value = values_and_holders(reinterpret_cast(h.ptr())).begin()->value_ptr(); + return true; + } + + /* Fail */ + return false; + } + + static handle cast(const void *ptr, return_value_policy /* policy */, handle /* parent */) { + if (ptr) + return capsule(ptr).release(); + else + return none().inc_ref(); + } + + template using cast_op_type = void*&; + operator void *&() { return value; } + static PYBIND11_DESCR name() { return type_descr(_("capsule")); } +private: + void *value = nullptr; +}; + +template <> class type_caster : public void_caster { }; + +template <> class type_caster { +public: + bool load(handle src, bool convert) { + if (!src) return false; + else if (src.ptr() == Py_True) { value = true; return true; } + else if (src.ptr() == Py_False) { value = false; return true; } + else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) { + // (allow non-implicit conversion for numpy booleans) + + Py_ssize_t res = -1; + if (src.is_none()) { + res = 0; // None is implicitly converted to False + } + #if defined(PYPY_VERSION) + // On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists + else if (hasattr(src, PYBIND11_BOOL_ATTR)) { + res = PyObject_IsTrue(src.ptr()); + } + #else + // Alternate approach for CPython: this does the same as the above, but optimized + // using the CPython API so as to avoid an unneeded attribute lookup. + else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) { + if (PYBIND11_NB_BOOL(tp_as_number)) { + res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr()); + } + } + #endif + if (res == 0 || res == 1) { + value = (bool) res; + return true; + } + } + return false; + } + static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) { + return handle(src ? Py_True : Py_False).inc_ref(); + } + PYBIND11_TYPE_CASTER(bool, _("bool")); +}; + +// Helper class for UTF-{8,16,32} C++ stl strings: +template struct string_caster { + using CharT = typename StringType::value_type; + + // Simplify life by being able to assume standard char sizes (the standard only guarantees + // minimums, but Python requires exact sizes) + static_assert(!std::is_same::value || sizeof(CharT) == 1, "Unsupported char size != 1"); + static_assert(!std::is_same::value || sizeof(CharT) == 2, "Unsupported char16_t size != 2"); + static_assert(!std::is_same::value || sizeof(CharT) == 4, "Unsupported char32_t size != 4"); + // wchar_t can be either 16 bits (Windows) or 32 (everywhere else) + static_assert(!std::is_same::value || sizeof(CharT) == 2 || sizeof(CharT) == 4, + "Unsupported wchar_t size != 2/4"); + static constexpr size_t UTF_N = 8 * sizeof(CharT); + + bool load(handle src, bool) { +#if PY_MAJOR_VERSION < 3 + object temp; +#endif + handle load_src = src; + if (!src) { + return false; + } else if (!PyUnicode_Check(load_src.ptr())) { +#if PY_MAJOR_VERSION >= 3 + return load_bytes(load_src); +#else + if (sizeof(CharT) == 1) { + return load_bytes(load_src); + } + + // The below is a guaranteed failure in Python 3 when PyUnicode_Check returns false + if (!PYBIND11_BYTES_CHECK(load_src.ptr())) + return false; + + temp = reinterpret_steal(PyUnicode_FromObject(load_src.ptr())); + if (!temp) { PyErr_Clear(); return false; } + load_src = temp; +#endif + } + + object utfNbytes = reinterpret_steal(PyUnicode_AsEncodedString( + load_src.ptr(), UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr)); + if (!utfNbytes) { PyErr_Clear(); return false; } + + const CharT *buffer = reinterpret_cast(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr())); + size_t length = (size_t) PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT); + if (UTF_N > 8) { buffer++; length--; } // Skip BOM for UTF-16/32 + value = StringType(buffer, length); + + // If we're loading a string_view we need to keep the encoded Python object alive: + if (IsView) + loader_life_support::add_patient(utfNbytes); + + return true; + } + + static handle cast(const StringType &src, return_value_policy /* policy */, handle /* parent */) { + const char *buffer = reinterpret_cast(src.data()); + ssize_t nbytes = ssize_t(src.size() * sizeof(CharT)); + handle s = decode_utfN(buffer, nbytes); + if (!s) throw error_already_set(); + return s; + } + + PYBIND11_TYPE_CASTER(StringType, _(PYBIND11_STRING_NAME)); + +private: + static handle decode_utfN(const char *buffer, ssize_t nbytes) { +#if !defined(PYPY_VERSION) + return + UTF_N == 8 ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr) : + UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr) : + PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr); +#else + // PyPy seems to have multiple problems related to PyUnicode_UTF*: the UTF8 version + // sometimes segfaults for unknown reasons, while the UTF16 and 32 versions require a + // non-const char * arguments, which is also a nuissance, so bypass the whole thing by just + // passing the encoding as a string value, which works properly: + return PyUnicode_Decode(buffer, nbytes, UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr); +#endif + } + + // When loading into a std::string or char*, accept a bytes object as-is (i.e. + // without any encoding/decoding attempt). For other C++ char sizes this is a no-op. + // which supports loading a unicode from a str, doesn't take this path. + template + bool load_bytes(enable_if_t src) { + if (PYBIND11_BYTES_CHECK(src.ptr())) { + // We were passed a Python 3 raw bytes; accept it into a std::string or char* + // without any encoding attempt. + const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr()); + if (bytes) { + value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr())); + return true; + } + } + + return false; + } + + template + bool load_bytes(enable_if_t) { return false; } +}; + +template +struct type_caster, enable_if_t::value>> + : string_caster> {}; + +#ifdef PYBIND11_HAS_STRING_VIEW +template +struct type_caster, enable_if_t::value>> + : string_caster, true> {}; +#endif + +// Type caster for C-style strings. We basically use a std::string type caster, but also add the +// ability to use None as a nullptr char* (which the string caster doesn't allow). +template struct type_caster::value>> { + using StringType = std::basic_string; + using StringCaster = type_caster; + StringCaster str_caster; + bool none = false; + CharT one_char = 0; +public: + bool load(handle src, bool convert) { + if (!src) return false; + if (src.is_none()) { + // Defer accepting None to other overloads (if we aren't in convert mode): + if (!convert) return false; + none = true; + return true; + } + return str_caster.load(src, convert); + } + + static handle cast(const CharT *src, return_value_policy policy, handle parent) { + if (src == nullptr) return pybind11::none().inc_ref(); + return StringCaster::cast(StringType(src), policy, parent); + } + + static handle cast(CharT src, return_value_policy policy, handle parent) { + if (std::is_same::value) { + handle s = PyUnicode_DecodeLatin1((const char *) &src, 1, nullptr); + if (!s) throw error_already_set(); + return s; + } + return StringCaster::cast(StringType(1, src), policy, parent); + } + + operator CharT*() { return none ? nullptr : const_cast(static_cast(str_caster).c_str()); } + operator CharT&() { + if (none) + throw value_error("Cannot convert None to a character"); + + auto &value = static_cast(str_caster); + size_t str_len = value.size(); + if (str_len == 0) + throw value_error("Cannot convert empty string to a character"); + + // If we're in UTF-8 mode, we have two possible failures: one for a unicode character that + // is too high, and one for multiple unicode characters (caught later), so we need to figure + // out how long the first encoded character is in bytes to distinguish between these two + // errors. We also allow want to allow unicode characters U+0080 through U+00FF, as those + // can fit into a single char value. + if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) { + unsigned char v0 = static_cast(value[0]); + size_t char0_bytes = !(v0 & 0x80) ? 1 : // low bits only: 0-127 + (v0 & 0xE0) == 0xC0 ? 2 : // 0b110xxxxx - start of 2-byte sequence + (v0 & 0xF0) == 0xE0 ? 3 : // 0b1110xxxx - start of 3-byte sequence + 4; // 0b11110xxx - start of 4-byte sequence + + if (char0_bytes == str_len) { + // If we have a 128-255 value, we can decode it into a single char: + if (char0_bytes == 2 && (v0 & 0xFC) == 0xC0) { // 0x110000xx 0x10xxxxxx + one_char = static_cast(((v0 & 3) << 6) + (static_cast(value[1]) & 0x3F)); + return one_char; + } + // Otherwise we have a single character, but it's > U+00FF + throw value_error("Character code point not in range(0x100)"); + } + } + + // UTF-16 is much easier: we can only have a surrogate pair for values above U+FFFF, thus a + // surrogate pair with total length 2 instantly indicates a range error (but not a "your + // string was too long" error). + else if (StringCaster::UTF_N == 16 && str_len == 2) { + one_char = static_cast(value[0]); + if (one_char >= 0xD800 && one_char < 0xE000) + throw value_error("Character code point not in range(0x10000)"); + } + + if (str_len != 1) + throw value_error("Expected a character, but multi-character string found"); + + one_char = value[0]; + return one_char; + } + + static PYBIND11_DESCR name() { return type_descr(_(PYBIND11_STRING_NAME)); } + template using cast_op_type = pybind11::detail::cast_op_type<_T>; +}; + +// Base implementation for std::tuple and std::pair +template class Tuple, typename... Ts> class tuple_caster { + using type = Tuple; + static constexpr auto size = sizeof...(Ts); + using indices = make_index_sequence; +public: + + bool load(handle src, bool convert) { + if (!isinstance(src)) + return false; + const auto seq = reinterpret_borrow(src); + if (seq.size() != size) + return false; + return load_impl(seq, convert, indices{}); + } + + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + return cast_impl(std::forward(src), policy, parent, indices{}); + } + + static PYBIND11_DESCR name() { + return type_descr(_("Tuple[") + detail::concat(make_caster::name()...) + _("]")); + } + + template using cast_op_type = type; + + operator type() & { return implicit_cast(indices{}); } + operator type() && { return std::move(*this).implicit_cast(indices{}); } + +protected: + template + type implicit_cast(index_sequence) & { return type(cast_op(std::get(subcasters))...); } + template + type implicit_cast(index_sequence) && { return type(cast_op(std::move(std::get(subcasters)))...); } + + static constexpr bool load_impl(const sequence &, bool, index_sequence<>) { return true; } + + template + bool load_impl(const sequence &seq, bool convert, index_sequence) { + for (bool r : {std::get(subcasters).load(seq[Is], convert)...}) + if (!r) + return false; + return true; + } + + /* Implementation: Convert a C++ tuple into a Python tuple */ + template + static handle cast_impl(T &&src, return_value_policy policy, handle parent, index_sequence) { + std::array entries{{ + reinterpret_steal(make_caster::cast(std::get(std::forward(src)), policy, parent))... + }}; + for (const auto &entry: entries) + if (!entry) + return handle(); + tuple result(size); + int counter = 0; + for (auto & entry: entries) + PyTuple_SET_ITEM(result.ptr(), counter++, entry.release().ptr()); + return result.release(); + } + + Tuple...> subcasters; +}; + +template class type_caster> + : public tuple_caster {}; + +template class type_caster> + : public tuple_caster {}; + +/// Helper class which abstracts away certain actions. Users can provide specializations for +/// custom holders, but it's only necessary if the type has a non-standard interface. +template +struct holder_helper { + static auto get(const T &p) -> decltype(p.get()) { return p.get(); } +}; + +/// Type caster for holder types like std::shared_ptr, etc. +template +struct copyable_holder_caster : public type_caster_base { +public: + using base = type_caster_base; + static_assert(std::is_base_of>::value, + "Holder classes are only supported for custom types"); + using base::base; + using base::cast; + using base::typeinfo; + using base::value; + + bool load(handle src, bool convert) { + return base::template load_impl>(src, convert); + } + + explicit operator type*() { return this->value; } + explicit operator type&() { return *(this->value); } + explicit operator holder_type*() { return &holder; } + + // Workaround for Intel compiler bug + // see pybind11 issue 94 + #if defined(__ICC) || defined(__INTEL_COMPILER) + operator holder_type&() { return holder; } + #else + explicit operator holder_type&() { return holder; } + #endif + + static handle cast(const holder_type &src, return_value_policy, handle) { + const auto *ptr = holder_helper::get(src); + return type_caster_base::cast_holder(ptr, &src); + } + +protected: + friend class type_caster_generic; + void check_holder_compat() { + if (typeinfo->default_holder) + throw cast_error("Unable to load a custom holder type from a default-holder instance"); + } + + bool load_value(value_and_holder &&v_h) { + if (v_h.holder_constructed()) { + value = v_h.value_ptr(); + holder = v_h.template holder(); + return true; + } else { + throw cast_error("Unable to cast from non-held to held instance (T& to Holder) " +#if defined(NDEBUG) + "(compile in debug mode for type information)"); +#else + "of type '" + type_id() + "''"); +#endif + } + } + + template ::value, int> = 0> + bool try_implicit_casts(handle, bool) { return false; } + + template ::value, int> = 0> + bool try_implicit_casts(handle src, bool convert) { + for (auto &cast : typeinfo->implicit_casts) { + copyable_holder_caster sub_caster(*cast.first); + if (sub_caster.load(src, convert)) { + value = cast.second(sub_caster.value); + holder = holder_type(sub_caster.holder, (type *) value); + return true; + } + } + return false; + } + + static bool try_direct_conversions(handle) { return false; } + + + holder_type holder; +}; + +/// Specialize for the common std::shared_ptr, so users don't need to +template +class type_caster> : public copyable_holder_caster> { }; + +template +struct move_only_holder_caster { + static_assert(std::is_base_of, type_caster>::value, + "Holder classes are only supported for custom types"); + + static handle cast(holder_type &&src, return_value_policy, handle) { + auto *ptr = holder_helper::get(src); + return type_caster_base::cast_holder(ptr, &src); + } + static PYBIND11_DESCR name() { return type_caster_base::name(); } +}; + +template +class type_caster> + : public move_only_holder_caster> { }; + +template +using type_caster_holder = conditional_t::value, + copyable_holder_caster, + move_only_holder_caster>; + +template struct always_construct_holder { static constexpr bool value = Value; }; + +/// Create a specialization for custom holder types (silently ignores std::shared_ptr) +#define PYBIND11_DECLARE_HOLDER_TYPE(type, holder_type, ...) \ + namespace pybind11 { namespace detail { \ + template \ + struct always_construct_holder : always_construct_holder { }; \ + template \ + class type_caster::value>> \ + : public type_caster_holder { }; \ + }} + +// PYBIND11_DECLARE_HOLDER_TYPE holder types: +template struct is_holder_type : + std::is_base_of, detail::type_caster> {}; +// Specialization for always-supported unique_ptr holders: +template struct is_holder_type> : + std::true_type {}; + +template struct handle_type_name { static PYBIND11_DESCR name() { return _(); } }; +template <> struct handle_type_name { static PYBIND11_DESCR name() { return _(PYBIND11_BYTES_NAME); } }; +template <> struct handle_type_name { static PYBIND11_DESCR name() { return _("*args"); } }; +template <> struct handle_type_name { static PYBIND11_DESCR name() { return _("**kwargs"); } }; + +template +struct pyobject_caster { + template ::value, int> = 0> + bool load(handle src, bool /* convert */) { value = src; return static_cast(value); } + + template ::value, int> = 0> + bool load(handle src, bool /* convert */) { + if (!isinstance(src)) + return false; + value = reinterpret_borrow(src); + return true; + } + + static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) { + return src.inc_ref(); + } + PYBIND11_TYPE_CASTER(type, handle_type_name::name()); +}; + +template +class type_caster::value>> : public pyobject_caster { }; + +// Our conditions for enabling moving are quite restrictive: +// At compile time: +// - T needs to be a non-const, non-pointer, non-reference type +// - type_caster::operator T&() must exist +// - the type must be move constructible (obviously) +// At run-time: +// - if the type is non-copy-constructible, the object must be the sole owner of the type (i.e. it +// must have ref_count() == 1)h +// If any of the above are not satisfied, we fall back to copying. +template using move_is_plain_type = satisfies_none_of; +template struct move_always : std::false_type {}; +template struct move_always, + negation>, + std::is_move_constructible, + std::is_same>().operator T&()), T&> +>::value>> : std::true_type {}; +template struct move_if_unreferenced : std::false_type {}; +template struct move_if_unreferenced, + negation>, + std::is_move_constructible, + std::is_same>().operator T&()), T&> +>::value>> : std::true_type {}; +template using move_never = none_of, move_if_unreferenced>; + +// Detect whether returning a `type` from a cast on type's type_caster is going to result in a +// reference or pointer to a local variable of the type_caster. Basically, only +// non-reference/pointer `type`s and reference/pointers from a type_caster_generic are safe; +// everything else returns a reference/pointer to a local variable. +template using cast_is_temporary_value_reference = bool_constant< + (std::is_reference::value || std::is_pointer::value) && + !std::is_base_of>::value +>; + +// When a value returned from a C++ function is being cast back to Python, we almost always want to +// force `policy = move`, regardless of the return value policy the function/method was declared +// with. Some classes (most notably Eigen::Ref and related) need to avoid this, and so can do so by +// specializing this struct. +template struct return_value_policy_override { + static return_value_policy policy(return_value_policy p) { + return !std::is_lvalue_reference::value && !std::is_pointer::value + ? return_value_policy::move : p; + } +}; + +// Basic python -> C++ casting; throws if casting fails +template type_caster &load_type(type_caster &conv, const handle &handle) { + if (!conv.load(handle, true)) { +#if defined(NDEBUG) + throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)"); +#else + throw cast_error("Unable to cast Python instance of type " + + (std::string) str(handle.get_type()) + " to C++ type '" + type_id() + "'"); +#endif + } + return conv; +} +// Wrapper around the above that also constructs and returns a type_caster +template make_caster load_type(const handle &handle) { + make_caster conv; + load_type(conv, handle); + return conv; +} + +NAMESPACE_END(detail) + +// pytype -> C++ type +template ::value, int> = 0> +T cast(const handle &handle) { + using namespace detail; + static_assert(!cast_is_temporary_value_reference::value, + "Unable to cast type to reference: value is local to type caster"); + return cast_op(load_type(handle)); +} + +// pytype -> pytype (calls converting constructor) +template ::value, int> = 0> +T cast(const handle &handle) { return T(reinterpret_borrow(handle)); } + +// C++ type -> py::object +template ::value, int> = 0> +object cast(const T &value, return_value_policy policy = return_value_policy::automatic_reference, + handle parent = handle()) { + if (policy == return_value_policy::automatic) + policy = std::is_pointer::value ? return_value_policy::take_ownership : return_value_policy::copy; + else if (policy == return_value_policy::automatic_reference) + policy = std::is_pointer::value ? return_value_policy::reference : return_value_policy::copy; + return reinterpret_steal(detail::make_caster::cast(value, policy, parent)); +} + +template T handle::cast() const { return pybind11::cast(*this); } +template <> inline void handle::cast() const { return; } + +template +detail::enable_if_t::value, T> move(object &&obj) { + if (obj.ref_count() > 1) +#if defined(NDEBUG) + throw cast_error("Unable to cast Python instance to C++ rvalue: instance has multiple references" + " (compile in debug mode for details)"); +#else + throw cast_error("Unable to move from Python " + (std::string) str(obj.get_type()) + + " instance to C++ " + type_id() + " instance: instance has multiple references"); +#endif + + // Move into a temporary and return that, because the reference may be a local value of `conv` + T ret = std::move(detail::load_type(obj).operator T&()); + return ret; +} + +// Calling cast() on an rvalue calls pybind::cast with the object rvalue, which does: +// - If we have to move (because T has no copy constructor), do it. This will fail if the moved +// object has multiple references, but trying to copy will fail to compile. +// - If both movable and copyable, check ref count: if 1, move; otherwise copy +// - Otherwise (not movable), copy. +template detail::enable_if_t::value, T> cast(object &&object) { + return move(std::move(object)); +} +template detail::enable_if_t::value, T> cast(object &&object) { + if (object.ref_count() > 1) + return cast(object); + else + return move(std::move(object)); +} +template detail::enable_if_t::value, T> cast(object &&object) { + return cast(object); +} + +template T object::cast() const & { return pybind11::cast(*this); } +template T object::cast() && { return pybind11::cast(std::move(*this)); } +template <> inline void object::cast() const & { return; } +template <> inline void object::cast() && { return; } + +NAMESPACE_BEGIN(detail) + +// Declared in pytypes.h: +template ::value, int>> +object object_or_cast(T &&o) { return pybind11::cast(std::forward(o)); } + +struct overload_unused {}; // Placeholder type for the unneeded (and dead code) static variable in the OVERLOAD_INT macro +template using overload_caster_t = conditional_t< + cast_is_temporary_value_reference::value, make_caster, overload_unused>; + +// Trampoline use: for reference/pointer types to value-converted values, we do a value cast, then +// store the result in the given variable. For other types, this is a no-op. +template enable_if_t::value, T> cast_ref(object &&o, make_caster &caster) { + return cast_op(load_type(caster, o)); +} +template enable_if_t::value, T> cast_ref(object &&, overload_unused &) { + pybind11_fail("Internal error: cast_ref fallback invoked"); } + +// Trampoline use: Having a pybind11::cast with an invalid reference type is going to static_assert, even +// though if it's in dead code, so we provide a "trampoline" to pybind11::cast that only does anything in +// cases where pybind11::cast is valid. +template enable_if_t::value, T> cast_safe(object &&o) { + return pybind11::cast(std::move(o)); } +template enable_if_t::value, T> cast_safe(object &&) { + pybind11_fail("Internal error: cast_safe fallback invoked"); } +template <> inline void cast_safe(object &&) {} + +NAMESPACE_END(detail) + +template +tuple make_tuple() { return tuple(0); } + +template tuple make_tuple(Args&&... args_) { + constexpr size_t size = sizeof...(Args); + std::array args { + { reinterpret_steal(detail::make_caster::cast( + std::forward(args_), policy, nullptr))... } + }; + for (size_t i = 0; i < args.size(); i++) { + if (!args[i]) { +#if defined(NDEBUG) + throw cast_error("make_tuple(): unable to convert arguments to Python object (compile in debug mode for details)"); +#else + std::array argtypes { {type_id()...} }; + throw cast_error("make_tuple(): unable to convert argument of type '" + + argtypes[i] + "' to Python object"); +#endif + } + } + tuple result(size); + int counter = 0; + for (auto &arg_value : args) + PyTuple_SET_ITEM(result.ptr(), counter++, arg_value.release().ptr()); + return result; +} + +/// \ingroup annotations +/// Annotation for arguments +struct arg { + /// Constructs an argument with the name of the argument; if null or omitted, this is a positional argument. + constexpr explicit arg(const char *name = nullptr) : name(name), flag_noconvert(false), flag_none(true) { } + /// Assign a value to this argument + template arg_v operator=(T &&value) const; + /// Indicate that the type should not be converted in the type caster + arg &noconvert(bool flag = true) { flag_noconvert = flag; return *this; } + /// Indicates that the argument should/shouldn't allow None (e.g. for nullable pointer args) + arg &none(bool flag = true) { flag_none = flag; return *this; } + + const char *name; ///< If non-null, this is a named kwargs argument + bool flag_noconvert : 1; ///< If set, do not allow conversion (requires a supporting type caster!) + bool flag_none : 1; ///< If set (the default), allow None to be passed to this argument +}; + +/// \ingroup annotations +/// Annotation for arguments with values +struct arg_v : arg { +private: + template + arg_v(arg &&base, T &&x, const char *descr = nullptr) + : arg(base), + value(reinterpret_steal( + detail::make_caster::cast(x, return_value_policy::automatic, {}) + )), + descr(descr) +#if !defined(NDEBUG) + , type(type_id()) +#endif + { } + +public: + /// Direct construction with name, default, and description + template + arg_v(const char *name, T &&x, const char *descr = nullptr) + : arg_v(arg(name), std::forward(x), descr) { } + + /// Called internally when invoking `py::arg("a") = value` + template + arg_v(const arg &base, T &&x, const char *descr = nullptr) + : arg_v(arg(base), std::forward(x), descr) { } + + /// Same as `arg::noconvert()`, but returns *this as arg_v&, not arg& + arg_v &noconvert(bool flag = true) { arg::noconvert(flag); return *this; } + + /// Same as `arg::nonone()`, but returns *this as arg_v&, not arg& + arg_v &none(bool flag = true) { arg::none(flag); return *this; } + + /// The default value + object value; + /// The (optional) description of the default value + const char *descr; +#if !defined(NDEBUG) + /// The C++ type name of the default value (only available when compiled in debug mode) + std::string type; +#endif +}; + +template +arg_v arg::operator=(T &&value) const { return {std::move(*this), std::forward(value)}; } + +/// Alias for backward compatibility -- to be removed in version 2.0 +template using arg_t = arg_v; + +inline namespace literals { +/** \rst + String literal version of `arg` + \endrst */ +constexpr arg operator"" _a(const char *name, size_t) { return arg(name); } +} + +NAMESPACE_BEGIN(detail) + +// forward declaration (definition in attr.h) +struct function_record; + +/// Internal data associated with a single function call +struct function_call { + function_call(function_record &f, handle p); // Implementation in attr.h + + /// The function data: + const function_record &func; + + /// Arguments passed to the function: + std::vector args; + + /// The `convert` value the arguments should be loaded with + std::vector args_convert; + + /// Extra references for the optional `py::args` and/or `py::kwargs` arguments (which, if + /// present, are also in `args` but without a reference). + object args_ref, kwargs_ref; + + /// The parent, if any + handle parent; + + /// If this is a call to an initializer, this argument contains `self` + handle init_self; +}; + + +/// Helper class which loads arguments for C++ functions called from Python +template +class argument_loader { + using indices = make_index_sequence; + + template using argument_is_args = std::is_same, args>; + template using argument_is_kwargs = std::is_same, kwargs>; + // Get args/kwargs argument positions relative to the end of the argument list: + static constexpr auto args_pos = constexpr_first() - (int) sizeof...(Args), + kwargs_pos = constexpr_first() - (int) sizeof...(Args); + + static constexpr bool args_kwargs_are_last = kwargs_pos >= - 1 && args_pos >= kwargs_pos - 1; + + static_assert(args_kwargs_are_last, "py::args/py::kwargs are only permitted as the last argument(s) of a function"); + +public: + static constexpr bool has_kwargs = kwargs_pos < 0; + static constexpr bool has_args = args_pos < 0; + + static PYBIND11_DESCR arg_names() { return detail::concat(make_caster::name()...); } + + bool load_args(function_call &call) { + return load_impl_sequence(call, indices{}); + } + + template + enable_if_t::value, Return> call(Func &&f) && { + return std::move(*this).template call_impl(std::forward(f), indices{}, Guard{}); + } + + template + enable_if_t::value, void_type> call(Func &&f) && { + std::move(*this).template call_impl(std::forward(f), indices{}, Guard{}); + return void_type(); + } + +private: + + static bool load_impl_sequence(function_call &, index_sequence<>) { return true; } + + template + bool load_impl_sequence(function_call &call, index_sequence) { + for (bool r : {std::get(argcasters).load(call.args[Is], call.args_convert[Is])...}) + if (!r) + return false; + return true; + } + + template + Return call_impl(Func &&f, index_sequence, Guard &&) { + return std::forward(f)(cast_op(std::move(std::get(argcasters)))...); + } + + std::tuple...> argcasters; +}; + +/// Helper class which collects only positional arguments for a Python function call. +/// A fancier version below can collect any argument, but this one is optimal for simple calls. +template +class simple_collector { +public: + template + explicit simple_collector(Ts &&...values) + : m_args(pybind11::make_tuple(std::forward(values)...)) { } + + const tuple &args() const & { return m_args; } + dict kwargs() const { return {}; } + + tuple args() && { return std::move(m_args); } + + /// Call a Python function and pass the collected arguments + object call(PyObject *ptr) const { + PyObject *result = PyObject_CallObject(ptr, m_args.ptr()); + if (!result) + throw error_already_set(); + return reinterpret_steal(result); + } + +private: + tuple m_args; +}; + +/// Helper class which collects positional, keyword, * and ** arguments for a Python function call +template +class unpacking_collector { +public: + template + explicit unpacking_collector(Ts &&...values) { + // Tuples aren't (easily) resizable so a list is needed for collection, + // but the actual function call strictly requires a tuple. + auto args_list = list(); + int _[] = { 0, (process(args_list, std::forward(values)), 0)... }; + ignore_unused(_); + + m_args = std::move(args_list); + } + + const tuple &args() const & { return m_args; } + const dict &kwargs() const & { return m_kwargs; } + + tuple args() && { return std::move(m_args); } + dict kwargs() && { return std::move(m_kwargs); } + + /// Call a Python function and pass the collected arguments + object call(PyObject *ptr) const { + PyObject *result = PyObject_Call(ptr, m_args.ptr(), m_kwargs.ptr()); + if (!result) + throw error_already_set(); + return reinterpret_steal(result); + } + +private: + template + void process(list &args_list, T &&x) { + auto o = reinterpret_steal(detail::make_caster::cast(std::forward(x), policy, {})); + if (!o) { +#if defined(NDEBUG) + argument_cast_error(); +#else + argument_cast_error(std::to_string(args_list.size()), type_id()); +#endif + } + args_list.append(o); + } + + void process(list &args_list, detail::args_proxy ap) { + for (const auto &a : ap) + args_list.append(a); + } + + void process(list &/*args_list*/, arg_v a) { + if (!a.name) +#if defined(NDEBUG) + nameless_argument_error(); +#else + nameless_argument_error(a.type); +#endif + + if (m_kwargs.contains(a.name)) { +#if defined(NDEBUG) + multiple_values_error(); +#else + multiple_values_error(a.name); +#endif + } + if (!a.value) { +#if defined(NDEBUG) + argument_cast_error(); +#else + argument_cast_error(a.name, a.type); +#endif + } + m_kwargs[a.name] = a.value; + } + + void process(list &/*args_list*/, detail::kwargs_proxy kp) { + if (!kp) + return; + for (const auto &k : reinterpret_borrow(kp)) { + if (m_kwargs.contains(k.first)) { +#if defined(NDEBUG) + multiple_values_error(); +#else + multiple_values_error(str(k.first)); +#endif + } + m_kwargs[k.first] = k.second; + } + } + + [[noreturn]] static void nameless_argument_error() { + throw type_error("Got kwargs without a name; only named arguments " + "may be passed via py::arg() to a python function call. " + "(compile in debug mode for details)"); + } + [[noreturn]] static void nameless_argument_error(std::string type) { + throw type_error("Got kwargs without a name of type '" + type + "'; only named " + "arguments may be passed via py::arg() to a python function call. "); + } + [[noreturn]] static void multiple_values_error() { + throw type_error("Got multiple values for keyword argument " + "(compile in debug mode for details)"); + } + + [[noreturn]] static void multiple_values_error(std::string name) { + throw type_error("Got multiple values for keyword argument '" + name + "'"); + } + + [[noreturn]] static void argument_cast_error() { + throw cast_error("Unable to convert call argument to Python object " + "(compile in debug mode for details)"); + } + + [[noreturn]] static void argument_cast_error(std::string name, std::string type) { + throw cast_error("Unable to convert call argument '" + name + + "' of type '" + type + "' to Python object"); + } + +private: + tuple m_args; + dict m_kwargs; +}; + +/// Collect only positional arguments for a Python function call +template ...>::value>> +simple_collector collect_arguments(Args &&...args) { + return simple_collector(std::forward(args)...); +} + +/// Collect all arguments, including keywords and unpacking (only instantiated when needed) +template ...>::value>> +unpacking_collector collect_arguments(Args &&...args) { + // Following argument order rules for generalized unpacking according to PEP 448 + static_assert( + constexpr_last() < constexpr_first() + && constexpr_last() < constexpr_first(), + "Invalid function call: positional args must precede keywords and ** unpacking; " + "* unpacking must precede ** unpacking" + ); + return unpacking_collector(std::forward(args)...); +} + +template +template +object object_api::operator()(Args &&...args) const { + return detail::collect_arguments(std::forward(args)...).call(derived().ptr()); +} + +template +template +object object_api::call(Args &&...args) const { + return operator()(std::forward(args)...); +} + +NAMESPACE_END(detail) + +#define PYBIND11_MAKE_OPAQUE(Type) \ + namespace pybind11 { namespace detail { \ + template<> class type_caster : public type_caster_base { }; \ + }} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/chrono.h b/ml/dlib/dlib/external/pybind11/include/pybind11/chrono.h new file mode 100644 index 000000000..95ada76e0 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/chrono.h @@ -0,0 +1,162 @@ +/* + pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime + + Copyright (c) 2016 Trent Houliston and + Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include +#include +#include +#include + +// Backport the PyDateTime_DELTA functions from Python3.3 if required +#ifndef PyDateTime_DELTA_GET_DAYS +#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) +#endif +#ifndef PyDateTime_DELTA_GET_SECONDS +#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) +#endif +#ifndef PyDateTime_DELTA_GET_MICROSECONDS +#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) +#endif + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +template class duration_caster { +public: + typedef typename type::rep rep; + typedef typename type::period period; + + typedef std::chrono::duration> days; + + bool load(handle src, bool) { + using namespace std::chrono; + + // Lazy initialise the PyDateTime import + if (!PyDateTimeAPI) { PyDateTime_IMPORT; } + + if (!src) return false; + // If invoked with datetime.delta object + if (PyDelta_Check(src.ptr())) { + value = type(duration_cast>( + days(PyDateTime_DELTA_GET_DAYS(src.ptr())) + + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) + + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr())))); + return true; + } + // If invoked with a float we assume it is seconds and convert + else if (PyFloat_Check(src.ptr())) { + value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); + return true; + } + else return false; + } + + // If this is a duration just return it back + static const std::chrono::duration& get_duration(const std::chrono::duration &src) { + return src; + } + + // If this is a time_point get the time_since_epoch + template static std::chrono::duration get_duration(const std::chrono::time_point> &src) { + return src.time_since_epoch(); + } + + static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) { + using namespace std::chrono; + + // Use overloaded function to get our duration from our source + // Works out if it is a duration or time_point and get the duration + auto d = get_duration(src); + + // Lazy initialise the PyDateTime import + if (!PyDateTimeAPI) { PyDateTime_IMPORT; } + + // Declare these special duration types so the conversions happen with the correct primitive types (int) + using dd_t = duration>; + using ss_t = duration>; + using us_t = duration; + + auto dd = duration_cast(d); + auto subd = d - dd; + auto ss = duration_cast(subd); + auto us = duration_cast(subd - ss); + return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); + } + + PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); +}; + +// This is for casting times on the system clock into datetime.datetime instances +template class type_caster> { +public: + typedef std::chrono::time_point type; + bool load(handle src, bool) { + using namespace std::chrono; + + // Lazy initialise the PyDateTime import + if (!PyDateTimeAPI) { PyDateTime_IMPORT; } + + if (!src) return false; + if (PyDateTime_Check(src.ptr())) { + std::tm cal; + cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr()); + cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr()); + cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr()); + cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); + cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; + cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; + cal.tm_isdst = -1; + + value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr())); + return true; + } + else return false; + } + + static handle cast(const std::chrono::time_point &src, return_value_policy /* policy */, handle /* parent */) { + using namespace std::chrono; + + // Lazy initialise the PyDateTime import + if (!PyDateTimeAPI) { PyDateTime_IMPORT; } + + std::time_t tt = system_clock::to_time_t(src); + // this function uses static memory so it's best to copy it out asap just in case + // otherwise other code that is using localtime may break this (not just python code) + std::tm localtime = *std::localtime(&tt); + + // Declare these special duration types so the conversions happen with the correct primitive types (int) + using us_t = duration; + + return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, + localtime.tm_mon + 1, + localtime.tm_mday, + localtime.tm_hour, + localtime.tm_min, + localtime.tm_sec, + (duration_cast(src.time_since_epoch() % seconds(1))).count()); + } + PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); +}; + +// Other clocks that are not the system clock are not measured as datetime.datetime objects +// since they are not measured on calendar time. So instead we just make them timedeltas +// Or if they have passed us a time as a float we convert that +template class type_caster> +: public duration_caster> { +}; + +template class type_caster> +: public duration_caster> { +}; + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/common.h b/ml/dlib/dlib/external/pybind11/include/pybind11/common.h new file mode 100644 index 000000000..6c8a4f1e8 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/common.h @@ -0,0 +1,2 @@ +#include "detail/common.h" +#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/complex.h b/ml/dlib/dlib/external/pybind11/include/pybind11/complex.h new file mode 100644 index 000000000..5dac27cc4 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/complex.h @@ -0,0 +1,61 @@ +/* + pybind11/complex.h: Complex number support + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include + +/// glibc defines I as a macro which breaks things, e.g., boost template names +#ifdef I +# undef I +#endif + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +template struct format_descriptor, detail::enable_if_t::value>> { + static constexpr const char c = format_descriptor::c; + static constexpr const char value[3] = { 'Z', c, '\0' }; + static std::string format() { return std::string(value); } +}; + +template constexpr const char format_descriptor< + std::complex, detail::enable_if_t::value>>::value[3]; + +NAMESPACE_BEGIN(detail) + +template struct is_fmt_numeric, detail::enable_if_t::value>> { + static constexpr bool value = true; + static constexpr int index = is_fmt_numeric::index + 3; +}; + +template class type_caster> { +public: + bool load(handle src, bool convert) { + if (!src) + return false; + if (!convert && !PyComplex_Check(src.ptr())) + return false; + Py_complex result = PyComplex_AsCComplex(src.ptr()); + if (result.real == -1.0 && PyErr_Occurred()) { + PyErr_Clear(); + return false; + } + value = std::complex((T) result.real, (T) result.imag); + return true; + } + + static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { + return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); + } + + PYBIND11_TYPE_CASTER(std::complex, _("complex")); +}; +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/class.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/class.h new file mode 100644 index 000000000..ff06370fa --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/class.h @@ -0,0 +1,626 @@ +/* + pybind11/detail/class.h: Python C API implementation details for py::class_ + + Copyright (c) 2017 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "../attr.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +#if PY_VERSION_HEX >= 0x03030000 +# define PYBIND11_BUILTIN_QUALNAME +# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) +#else +// In pre-3.3 Python, we still set __qualname__ so that we can produce reliable function type +// signatures; in 3.3+ this macro expands to nothing: +# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) setattr((PyObject *) obj, "__qualname__", nameobj) +#endif + +inline PyTypeObject *type_incref(PyTypeObject *type) { + Py_INCREF(type); + return type; +} + +#if !defined(PYPY_VERSION) + +/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance. +extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) { + return PyProperty_Type.tp_descr_get(self, cls, cls); +} + +/// `pybind11_static_property.__set__()`: Just like the above `__get__()`. +extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) { + PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj); + return PyProperty_Type.tp_descr_set(self, cls, value); +} + +/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()` + methods are modified to always use the object type instead of a concrete instance. + Return value: New reference. */ +inline PyTypeObject *make_static_property_type() { + constexpr auto *name = "pybind11_static_property"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); + if (!heap_type) + pybind11_fail("make_static_property_type(): error allocating type!"); + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyProperty_Type); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + type->tp_descr_get = pybind11_static_get; + type->tp_descr_set = pybind11_static_set; + + if (PyType_Ready(type) < 0) + pybind11_fail("make_static_property_type(): failure in PyType_Ready()!"); + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + return type; +} + +#else // PYPY + +/** PyPy has some issues with the above C API, so we evaluate Python code instead. + This function will only be called once so performance isn't really a concern. + Return value: New reference. */ +inline PyTypeObject *make_static_property_type() { + auto d = dict(); + PyObject *result = PyRun_String(R"(\ + class pybind11_static_property(property): + def __get__(self, obj, cls): + return property.__get__(self, cls, cls) + + def __set__(self, obj, value): + cls = obj if isinstance(obj, type) else type(obj) + property.__set__(self, cls, value) + )", Py_file_input, d.ptr(), d.ptr() + ); + if (result == nullptr) + throw error_already_set(); + Py_DECREF(result); + return (PyTypeObject *) d["pybind11_static_property"].cast().release().ptr(); +} + +#endif // PYPY + +/** Types with static properties need to handle `Type.static_prop = x` in a specific way. + By default, Python replaces the `static_property` itself, but for wrapped C++ types + we need to call `static_property.__set__()` in order to propagate the new value to + the underlying C++ data structure. */ +extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) { + // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw + // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`). + PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); + + // The following assignment combinations are possible: + // 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)` + // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop` + // 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment + const auto static_prop = (PyObject *) get_internals().static_property_type; + const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop) + && !PyObject_IsInstance(value, static_prop); + if (call_descr_set) { + // Call `static_property.__set__()` instead of replacing the `static_property`. +#if !defined(PYPY_VERSION) + return Py_TYPE(descr)->tp_descr_set(descr, obj, value); +#else + if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) { + Py_DECREF(result); + return 0; + } else { + return -1; + } +#endif + } else { + // Replace existing attribute. + return PyType_Type.tp_setattro(obj, name, value); + } +} + +#if PY_MAJOR_VERSION >= 3 +/** + * Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing + * methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function, + * when called on a class, or a PyMethod, when called on an instance. Override that behaviour here + * to do a special case bypass for PyInstanceMethod_Types. + */ +extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) { + PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); + if (descr && PyInstanceMethod_Check(descr)) { + Py_INCREF(descr); + return descr; + } + else { + return PyType_Type.tp_getattro(obj, name); + } +} +#endif + +/** This metaclass is assigned by default to all pybind11 types and is required in order + for static properties to function correctly. Users may override this using `py::metaclass`. + Return value: New reference. */ +inline PyTypeObject* make_default_metaclass() { + constexpr auto *name = "pybind11_type"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); + if (!heap_type) + pybind11_fail("make_default_metaclass(): error allocating metaclass!"); + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyType_Type); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + type->tp_setattro = pybind11_meta_setattro; +#if PY_MAJOR_VERSION >= 3 + type->tp_getattro = pybind11_meta_getattro; +#endif + + if (PyType_Ready(type) < 0) + pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!"); + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + return type; +} + +/// For multiple inheritance types we need to recursively register/deregister base pointers for any +/// base classes with pointers that are difference from the instance value pointer so that we can +/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs. +inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self, + bool (*f)(void * /*parentptr*/, instance * /*self*/)) { + for (handle h : reinterpret_borrow(tinfo->type->tp_bases)) { + if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) { + for (auto &c : parent_tinfo->implicit_casts) { + if (c.first == tinfo->cpptype) { + auto *parentptr = c.second(valueptr); + if (parentptr != valueptr) + f(parentptr, self); + traverse_offset_bases(parentptr, parent_tinfo, self, f); + break; + } + } + } + } +} + +inline bool register_instance_impl(void *ptr, instance *self) { + get_internals().registered_instances.emplace(ptr, self); + return true; // unused, but gives the same signature as the deregister func +} +inline bool deregister_instance_impl(void *ptr, instance *self) { + auto ®istered_instances = get_internals().registered_instances; + auto range = registered_instances.equal_range(ptr); + for (auto it = range.first; it != range.second; ++it) { + if (Py_TYPE(self) == Py_TYPE(it->second)) { + registered_instances.erase(it); + return true; + } + } + return false; +} + +inline void register_instance(instance *self, void *valptr, const type_info *tinfo) { + register_instance_impl(valptr, self); + if (!tinfo->simple_ancestors) + traverse_offset_bases(valptr, tinfo, self, register_instance_impl); +} + +inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) { + bool ret = deregister_instance_impl(valptr, self); + if (!tinfo->simple_ancestors) + traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl); + return ret; +} + +/// Instance creation function for all pybind11 types. It allocates the internal instance layout for +/// holding C++ objects and holders. Allocation is done lazily (the first time the instance is cast +/// to a reference or pointer), and initialization is done by an `__init__` function. +inline PyObject *make_new_instance(PyTypeObject *type) { +#if defined(PYPY_VERSION) + // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited + // object is a a plain Python type (i.e. not derived from an extension type). Fix it. + ssize_t instance_size = static_cast(sizeof(instance)); + if (type->tp_basicsize < instance_size) { + type->tp_basicsize = instance_size; + } +#endif + PyObject *self = type->tp_alloc(type, 0); + auto inst = reinterpret_cast(self); + // Allocate the value/holder internals: + inst->allocate_layout(); + + inst->owned = true; + + return self; +} + +/// Instance creation function for all pybind11 types. It only allocates space for the +/// C++ object, but doesn't call the constructor -- an `__init__` function must do that. +extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) { + return make_new_instance(type); +} + +/// An `__init__` function constructs the C++ object. Users should provide at least one +/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the +/// following default function will be used which simply throws an exception. +extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) { + PyTypeObject *type = Py_TYPE(self); + std::string msg; +#if defined(PYPY_VERSION) + msg += handle((PyObject *) type).attr("__module__").cast() + "."; +#endif + msg += type->tp_name; + msg += ": No constructor defined!"; + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return -1; +} + +inline void add_patient(PyObject *nurse, PyObject *patient) { + auto &internals = get_internals(); + auto instance = reinterpret_cast(nurse); + auto ¤t_patients = internals.patients[nurse]; + instance->has_patients = true; + for (auto &p : current_patients) + if (p == patient) + return; + Py_INCREF(patient); + current_patients.push_back(patient); +} + +inline void clear_patients(PyObject *self) { + auto instance = reinterpret_cast(self); + auto &internals = get_internals(); + auto pos = internals.patients.find(self); + assert(pos != internals.patients.end()); + // Clearing the patients can cause more Python code to run, which + // can invalidate the iterator. Extract the vector of patients + // from the unordered_map first. + auto patients = std::move(pos->second); + internals.patients.erase(pos); + instance->has_patients = false; + for (PyObject *&patient : patients) + Py_CLEAR(patient); +} + +/// Clears all internal data from the instance and removes it from registered instances in +/// preparation for deallocation. +inline void clear_instance(PyObject *self) { + auto instance = reinterpret_cast(self); + + // Deallocate any values/holders, if present: + for (auto &v_h : values_and_holders(instance)) { + if (v_h) { + + // We have to deregister before we call dealloc because, for virtual MI types, we still + // need to be able to get the parent pointers. + if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type)) + pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!"); + + if (instance->owned || v_h.holder_constructed()) + v_h.type->dealloc(v_h); + } + } + // Deallocate the value/holder layout internals: + instance->deallocate_layout(); + + if (instance->weakrefs) + PyObject_ClearWeakRefs(self); + + PyObject **dict_ptr = _PyObject_GetDictPtr(self); + if (dict_ptr) + Py_CLEAR(*dict_ptr); + + if (instance->has_patients) + clear_patients(self); +} + +/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc` +/// to destroy the C++ object itself, while the rest is Python bookkeeping. +extern "C" inline void pybind11_object_dealloc(PyObject *self) { + clear_instance(self); + + auto type = Py_TYPE(self); + type->tp_free(self); + + // `type->tp_dealloc != pybind11_object_dealloc` means that we're being called + // as part of a derived type's dealloc, in which case we're not allowed to decref + // the type here. For cross-module compatibility, we shouldn't compare directly + // with `pybind11_object_dealloc`, but with the common one stashed in internals. + auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base; + if (type->tp_dealloc == pybind11_object_type->tp_dealloc) + Py_DECREF(type); +} + +/** Create the type which can be used as a common base for all classes. This is + needed in order to satisfy Python's requirements for multiple inheritance. + Return value: New reference. */ +inline PyObject *make_object_base_type(PyTypeObject *metaclass) { + constexpr auto *name = "pybind11_object"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) + pybind11_fail("make_object_base_type(): error allocating type!"); + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyBaseObject_Type); + type->tp_basicsize = static_cast(sizeof(instance)); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + type->tp_new = pybind11_object_new; + type->tp_init = pybind11_object_init; + type->tp_dealloc = pybind11_object_dealloc; + + /* Support weak references (needed for the keep_alive feature) */ + type->tp_weaklistoffset = offsetof(instance, weakrefs); + + if (PyType_Ready(type) < 0) + pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string()); + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + return (PyObject *) heap_type; +} + +/// dynamic_attr: Support for `d = instance.__dict__`. +extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) { + PyObject *&dict = *_PyObject_GetDictPtr(self); + if (!dict) + dict = PyDict_New(); + Py_XINCREF(dict); + return dict; +} + +/// dynamic_attr: Support for `instance.__dict__ = dict()`. +extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) { + if (!PyDict_Check(new_dict)) { + PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'", + Py_TYPE(new_dict)->tp_name); + return -1; + } + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_INCREF(new_dict); + Py_CLEAR(dict); + dict = new_dict; + return 0; +} + +/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`. +extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) { + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); + return 0; +} + +/// dynamic_attr: Allow the GC to clear the dictionary. +extern "C" inline int pybind11_clear(PyObject *self) { + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); + return 0; +} + +/// Give instances of this type a `__dict__` and opt into garbage collection. +inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { + auto type = &heap_type->ht_type; +#if defined(PYPY_VERSION) + pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are " + "currently not supported in " + "conjunction with PyPy!"); +#endif + type->tp_flags |= Py_TPFLAGS_HAVE_GC; + type->tp_dictoffset = type->tp_basicsize; // place dict at the end + type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it + type->tp_traverse = pybind11_traverse; + type->tp_clear = pybind11_clear; + + static PyGetSetDef getset[] = { + {const_cast("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr} + }; + type->tp_getset = getset; +} + +/// buffer_protocol: Fill in the view as specified by flags. +extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) { + // Look for a `get_buffer` implementation in this type's info or any bases (following MRO). + type_info *tinfo = nullptr; + for (auto type : reinterpret_borrow(Py_TYPE(obj)->tp_mro)) { + tinfo = get_type_info((PyTypeObject *) type.ptr()); + if (tinfo && tinfo->get_buffer) + break; + } + if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) { + if (view) + view->obj = nullptr; + PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error"); + return -1; + } + std::memset(view, 0, sizeof(Py_buffer)); + buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data); + view->obj = obj; + view->ndim = 1; + view->internal = info; + view->buf = info->ptr; + view->itemsize = info->itemsize; + view->len = view->itemsize; + for (auto s : info->shape) + view->len *= s; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) + view->format = const_cast(info->format.c_str()); + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + view->ndim = (int) info->ndim; + view->strides = &info->strides[0]; + view->shape = &info->shape[0]; + } + Py_INCREF(view->obj); + return 0; +} + +/// buffer_protocol: Release the resources of the buffer. +extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) { + delete (buffer_info *) view->internal; +} + +/// Give this type a buffer interface. +inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) { + heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer; +#if PY_MAJOR_VERSION < 3 + heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER; +#endif + + heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer; + heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer; +} + +/** Create a brand new Python type according to the `type_record` specification. + Return value: New reference. */ +inline PyObject* make_new_python_type(const type_record &rec) { + auto name = reinterpret_steal(PYBIND11_FROM_STRING(rec.name)); + + auto qualname = name; + if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) { +#if PY_MAJOR_VERSION >= 3 + qualname = reinterpret_steal( + PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr())); +#else + qualname = str(rec.scope.attr("__qualname__").cast() + "." + rec.name); +#endif + } + + object module; + if (rec.scope) { + if (hasattr(rec.scope, "__module__")) + module = rec.scope.attr("__module__"); + else if (hasattr(rec.scope, "__name__")) + module = rec.scope.attr("__name__"); + } + + auto full_name = c_str( +#if !defined(PYPY_VERSION) + module ? str(module).cast() + "." + rec.name : +#endif + rec.name); + + char *tp_doc = nullptr; + if (rec.doc && options::show_user_defined_docstrings()) { + /* Allocate memory for docstring (using PyObject_MALLOC, since + Python will free this later on) */ + size_t size = strlen(rec.doc) + 1; + tp_doc = (char *) PyObject_MALLOC(size); + memcpy((void *) tp_doc, rec.doc, size); + } + + auto &internals = get_internals(); + auto bases = tuple(rec.bases); + auto base = (bases.size() == 0) ? internals.instance_base + : bases[0].ptr(); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr() + : internals.default_metaclass; + + auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) + pybind11_fail(std::string(rec.name) + ": Unable to create type object!"); + + heap_type->ht_name = name.release().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = qualname.inc_ref().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = full_name; + type->tp_doc = tp_doc; + type->tp_base = type_incref((PyTypeObject *)base); + type->tp_basicsize = static_cast(sizeof(instance)); + if (bases.size() > 0) + type->tp_bases = bases.release().ptr(); + + /* Don't inherit base __init__ */ + type->tp_init = pybind11_object_init; + + /* Supported protocols */ + type->tp_as_number = &heap_type->as_number; + type->tp_as_sequence = &heap_type->as_sequence; + type->tp_as_mapping = &heap_type->as_mapping; + + /* Flags */ + type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; +#if PY_MAJOR_VERSION < 3 + type->tp_flags |= Py_TPFLAGS_CHECKTYPES; +#endif + + if (rec.dynamic_attr) + enable_dynamic_attributes(heap_type); + + if (rec.buffer_protocol) + enable_buffer_protocol(heap_type); + + if (PyType_Ready(type) < 0) + pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!"); + + assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) + : !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + + /* Register type with the parent scope */ + if (rec.scope) + setattr(rec.scope, rec.name, (PyObject *) type); + else + Py_INCREF(type); // Keep it alive forever (reference leak) + + if (module) // Needed by pydoc + setattr((PyObject *) type, "__module__", module); + + PYBIND11_SET_OLDPY_QUALNAME(type, qualname); + + return (PyObject *) type; +} + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/common.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/common.h new file mode 100644 index 000000000..7d41cd63b --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/common.h @@ -0,0 +1,802 @@ +/* + pybind11/detail/common.h -- Basic macros + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#if !defined(NAMESPACE_BEGIN) +# define NAMESPACE_BEGIN(name) namespace name { +#endif +#if !defined(NAMESPACE_END) +# define NAMESPACE_END(name) } +#endif + +// Robust support for some features and loading modules compiled against different pybind versions +// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute on +// the main `pybind11` namespace. +#if !defined(PYBIND11_NAMESPACE) +# ifdef __GNUG__ +# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden"))) +# else +# define PYBIND11_NAMESPACE pybind11 +# endif +#endif + +#if !defined(_MSC_VER) && !defined(__INTEL_COMPILER) +# if __cplusplus >= 201402L +# define PYBIND11_CPP14 +# if __cplusplus > 201402L /* Temporary: should be updated to >= the final C++17 value once known */ +# define PYBIND11_CPP17 +# endif +# endif +#elif defined(_MSC_VER) +// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented) +# if _MSVC_LANG >= 201402L +# define PYBIND11_CPP14 +# if _MSVC_LANG > 201402L && _MSC_VER >= 1910 +# define PYBIND11_CPP17 +# endif +# endif +#endif + +// Compiler version assertions +#if defined(__INTEL_COMPILER) +# if __INTEL_COMPILER < 1500 +# error pybind11 requires Intel C++ compiler v15 or newer +# endif +#elif defined(__clang__) && !defined(__apple_build_version__) +# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3) +# error pybind11 requires clang 3.3 or newer +# endif +#elif defined(__clang__) +// Apple changes clang version macros to its Xcode version; the first Xcode release based on +// (upstream) clang 3.3 was Xcode 5: +# if __clang_major__ < 5 +# error pybind11 requires Xcode/clang 5.0 or newer +# endif +#elif defined(__GNUG__) +# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8) +# error pybind11 requires gcc 4.8 or newer +# endif +#elif defined(_MSC_VER) +// Pybind hits various compiler bugs in 2015u2 and earlier, and also makes use of some stl features +// (e.g. std::negation) added in 2015u3: +# if _MSC_FULL_VER < 190024210 +# error pybind11 requires MSVC 2015 update 3 or newer +# endif +#endif + +#if !defined(PYBIND11_EXPORT) +# if defined(WIN32) || defined(_WIN32) +# define PYBIND11_EXPORT __declspec(dllexport) +# else +# define PYBIND11_EXPORT __attribute__ ((visibility("default"))) +# endif +#endif + +#if defined(_MSC_VER) +# define PYBIND11_NOINLINE __declspec(noinline) +#else +# define PYBIND11_NOINLINE __attribute__ ((noinline)) +#endif + +#if defined(PYBIND11_CPP14) +# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]] +#else +# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason))) +#endif + +#define PYBIND11_VERSION_MAJOR 2 +#define PYBIND11_VERSION_MINOR 2 +#define PYBIND11_VERSION_PATCH 2 + +/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode +#if defined(_MSC_VER) +# if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 4) +# define HAVE_ROUND 1 +# endif +# pragma warning(push) +# pragma warning(disable: 4510 4610 4512 4005) +# if defined(_DEBUG) +# define PYBIND11_DEBUG_MARKER +# undef _DEBUG +# endif +#endif + +#include +#include +#include + +#if defined(_WIN32) && (defined(min) || defined(max)) +# error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows +#endif + +#if defined(isalnum) +# undef isalnum +# undef isalpha +# undef islower +# undef isspace +# undef isupper +# undef tolower +# undef toupper +#endif + +#if defined(_MSC_VER) +# if defined(PYBIND11_DEBUG_MARKER) +# define _DEBUG +# undef PYBIND11_DEBUG_MARKER +# endif +# pragma warning(pop) +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions +#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr) +#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check +#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION +#define PYBIND11_BYTES_CHECK PyBytes_Check +#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString +#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize +#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize +#define PYBIND11_BYTES_AS_STRING PyBytes_AsString +#define PYBIND11_BYTES_SIZE PyBytes_Size +#define PYBIND11_LONG_CHECK(o) PyLong_Check(o) +#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o) +#define PYBIND11_BYTES_NAME "bytes" +#define PYBIND11_STRING_NAME "str" +#define PYBIND11_SLICE_OBJECT PyObject +#define PYBIND11_FROM_STRING PyUnicode_FromString +#define PYBIND11_STR_TYPE ::pybind11::str +#define PYBIND11_BOOL_ATTR "__bool__" +#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool) +#define PYBIND11_PLUGIN_IMPL(name) \ + extern "C" PYBIND11_EXPORT PyObject *PyInit_##name() + +#else +#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_) +#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check +#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyMethod_GET_FUNCTION +#define PYBIND11_BYTES_CHECK PyString_Check +#define PYBIND11_BYTES_FROM_STRING PyString_FromString +#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyString_FromStringAndSize +#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyString_AsStringAndSize +#define PYBIND11_BYTES_AS_STRING PyString_AsString +#define PYBIND11_BYTES_SIZE PyString_Size +#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o)) +#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o)) +#define PYBIND11_BYTES_NAME "str" +#define PYBIND11_STRING_NAME "unicode" +#define PYBIND11_SLICE_OBJECT PySliceObject +#define PYBIND11_FROM_STRING PyString_FromString +#define PYBIND11_STR_TYPE ::pybind11::bytes +#define PYBIND11_BOOL_ATTR "__nonzero__" +#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero) +#define PYBIND11_PLUGIN_IMPL(name) \ + static PyObject *pybind11_init_wrapper(); \ + extern "C" PYBIND11_EXPORT void init##name() { \ + (void)pybind11_init_wrapper(); \ + } \ + PyObject *pybind11_init_wrapper() +#endif + +#if PY_VERSION_HEX >= 0x03050000 && PY_VERSION_HEX < 0x03050200 +extern "C" { + struct _Py_atomic_address { void *value; }; + PyAPI_DATA(_Py_atomic_address) _PyThreadState_Current; +} +#endif + +#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code +#define PYBIND11_STRINGIFY(x) #x +#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x) +#define PYBIND11_CONCAT(first, second) first##second + +/** \rst + ***Deprecated in favor of PYBIND11_MODULE*** + + This macro creates the entry point that will be invoked when the Python interpreter + imports a plugin library. Please create a `module` in the function body and return + the pointer to its underlying Python object at the end. + + .. code-block:: cpp + + PYBIND11_PLUGIN(example) { + pybind11::module m("example", "pybind11 example plugin"); + /// Set up bindings here + return m.ptr(); + } +\endrst */ +#define PYBIND11_PLUGIN(name) \ + PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \ + static PyObject *pybind11_init(); \ + PYBIND11_PLUGIN_IMPL(name) { \ + int major, minor; \ + if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) { \ + PyErr_SetString(PyExc_ImportError, "Can't parse Python version."); \ + return nullptr; \ + } else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) { \ + PyErr_Format(PyExc_ImportError, \ + "Python version mismatch: module was compiled for " \ + "version %i.%i, while the interpreter is running " \ + "version %i.%i.", PY_MAJOR_VERSION, PY_MINOR_VERSION, \ + major, minor); \ + return nullptr; \ + } \ + try { \ + return pybind11_init(); \ + } catch (pybind11::error_already_set &e) { \ + PyErr_SetString(PyExc_ImportError, e.what()); \ + return nullptr; \ + } catch (const std::exception &e) { \ + PyErr_SetString(PyExc_ImportError, e.what()); \ + return nullptr; \ + } \ + } \ + PyObject *pybind11_init() + +/** \rst + This macro creates the entry point that will be invoked when the Python interpreter + imports an extension module. The module name is given as the fist argument and it + should not be in quotes. The second macro argument defines a variable of type + `py::module` which can be used to initialize the module. + + .. code-block:: cpp + + PYBIND11_MODULE(example, m) { + m.doc() = "pybind11 example module"; + + // Add bindings here + m.def("foo", []() { + return "Hello, World!"; + }); + } +\endrst */ +#define PYBIND11_MODULE(name, variable) \ + static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ + PYBIND11_PLUGIN_IMPL(name) { \ + int major, minor; \ + if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) { \ + PyErr_SetString(PyExc_ImportError, "Can't parse Python version."); \ + return nullptr; \ + } else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) { \ + PyErr_Format(PyExc_ImportError, \ + "Python version mismatch: module was compiled for " \ + "version %i.%i, while the interpreter is running " \ + "version %i.%i.", PY_MAJOR_VERSION, PY_MINOR_VERSION, \ + major, minor); \ + return nullptr; \ + } \ + auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ + try { \ + PYBIND11_CONCAT(pybind11_init_, name)(m); \ + return m.ptr(); \ + } catch (pybind11::error_already_set &e) { \ + PyErr_SetString(PyExc_ImportError, e.what()); \ + return nullptr; \ + } catch (const std::exception &e) { \ + PyErr_SetString(PyExc_ImportError, e.what()); \ + return nullptr; \ + } \ + } \ + void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) + + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +using ssize_t = Py_ssize_t; +using size_t = std::size_t; + +/// Approach used to cast a previously unknown C++ instance into a Python object +enum class return_value_policy : uint8_t { + /** This is the default return value policy, which falls back to the policy + return_value_policy::take_ownership when the return value is a pointer. + Otherwise, it uses return_value::move or return_value::copy for rvalue + and lvalue references, respectively. See below for a description of what + all of these different policies do. */ + automatic = 0, + + /** As above, but use policy return_value_policy::reference when the return + value is a pointer. This is the default conversion policy for function + arguments when calling Python functions manually from C++ code (i.e. via + handle::operator()). You probably won't need to use this. */ + automatic_reference, + + /** Reference an existing object (i.e. do not create a new copy) and take + ownership. Python will call the destructor and delete operator when the + object’s reference count reaches zero. Undefined behavior ensues when + the C++ side does the same.. */ + take_ownership, + + /** Create a new copy of the returned object, which will be owned by + Python. This policy is comparably safe because the lifetimes of the two + instances are decoupled. */ + copy, + + /** Use std::move to move the return value contents into a new instance + that will be owned by Python. This policy is comparably safe because the + lifetimes of the two instances (move source and destination) are + decoupled. */ + move, + + /** Reference an existing object, but do not take ownership. The C++ side + is responsible for managing the object’s lifetime and deallocating it + when it is no longer used. Warning: undefined behavior will ensue when + the C++ side deletes an object that is still referenced and used by + Python. */ + reference, + + /** This policy only applies to methods and properties. It references the + object without taking ownership similar to the above + return_value_policy::reference policy. In contrast to that policy, the + function or property’s implicit this argument (called the parent) is + considered to be the the owner of the return value (the child). + pybind11 then couples the lifetime of the parent to the child via a + reference relationship that ensures that the parent cannot be garbage + collected while Python is still using the child. More advanced + variations of this scheme are also possible using combinations of + return_value_policy::reference and the keep_alive call policy */ + reference_internal +}; + +NAMESPACE_BEGIN(detail) + +inline static constexpr int log2(size_t n, int k = 0) { return (n <= 1) ? k : log2(n >> 1, k + 1); } + +// Returns the size as a multiple of sizeof(void *), rounded up. +inline static constexpr size_t size_in_ptrs(size_t s) { return 1 + ((s - 1) >> log2(sizeof(void *))); } + +/** + * The space to allocate for simple layout instance holders (see below) in multiple of the size of + * a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required + * to holder either a std::unique_ptr or std::shared_ptr (which is almost always + * sizeof(std::shared_ptr)). + */ +constexpr size_t instance_simple_holder_in_ptrs() { + static_assert(sizeof(std::shared_ptr) >= sizeof(std::unique_ptr), + "pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs"); + return size_in_ptrs(sizeof(std::shared_ptr)); +} + +// Forward declarations +struct type_info; +struct value_and_holder; + +struct nonsimple_values_and_holders { + void **values_and_holders; + uint8_t *status; +}; + +/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof') +struct instance { + PyObject_HEAD + /// Storage for pointers and holder; see simple_layout, below, for a description + union { + void *simple_value_holder[1 + instance_simple_holder_in_ptrs()]; + nonsimple_values_and_holders nonsimple; + }; + /// Weak references (needed for keep alive): + PyObject *weakrefs; + /// If true, the pointer is owned which means we're free to manage it with a holder. + bool owned : 1; + /** + * An instance has two possible value/holder layouts. + * + * Simple layout (when this flag is true), means the `simple_value_holder` is set with a pointer + * and the holder object governing that pointer, i.e. [val1*][holder]. This layout is applied + * whenever there is no python-side multiple inheritance of bound C++ types *and* the type's + * holder will fit in the default space (which is large enough to hold either a std::unique_ptr + * or std::shared_ptr). + * + * Non-simple layout applies when using custom holders that require more space than `shared_ptr` + * (which is typically the size of two pointers), or when multiple inheritance is used on the + * python side. Non-simple layout allocates the required amount of memory to have multiple + * bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a + * pointer to allocated space of the required space to hold a a sequence of value pointers and + * holders followed `status`, a set of bit flags (1 byte each), i.e. + * [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of + * `sizeof(void *)`. `nonsimple.holder_constructed` is, for convenience, a pointer to the + * beginning of the [bb...] block (but not independently allocated). + * + * Status bits indicate whether the associated holder is constructed (& + * status_holder_constructed) and whether the value pointer is registered (& + * status_instance_registered) in `registered_instances`. + */ + bool simple_layout : 1; + /// For simple layout, tracks whether the holder has been constructed + bool simple_holder_constructed : 1; + /// For simple layout, tracks whether the instance is registered in `registered_instances` + bool simple_instance_registered : 1; + /// If true, get_internals().patients has an entry for this object + bool has_patients : 1; + + /// Initializes all of the above type/values/holders data (but not the instance values themselves) + void allocate_layout(); + + /// Destroys/deallocates all of the above + void deallocate_layout(); + + /// Returns the value_and_holder wrapper for the given type (or the first, if `find_type` + /// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if + /// `throw_if_missing` is false. + value_and_holder get_value_and_holder(const type_info *find_type = nullptr, bool throw_if_missing = true); + + /// Bit values for the non-simple status flags + static constexpr uint8_t status_holder_constructed = 1; + static constexpr uint8_t status_instance_registered = 2; +}; + +static_assert(std::is_standard_layout::value, "Internal error: `pybind11::detail::instance` is not standard layout!"); + +/// from __cpp_future__ import (convenient aliases from C++14/17) +#if defined(PYBIND11_CPP14) && (!defined(_MSC_VER) || _MSC_VER >= 1910) +using std::enable_if_t; +using std::conditional_t; +using std::remove_cv_t; +using std::remove_reference_t; +#else +template using enable_if_t = typename std::enable_if::type; +template using conditional_t = typename std::conditional::type; +template using remove_cv_t = typename std::remove_cv::type; +template using remove_reference_t = typename std::remove_reference::type; +#endif + +/// Index sequences +#if defined(PYBIND11_CPP14) +using std::index_sequence; +using std::make_index_sequence; +#else +template struct index_sequence { }; +template struct make_index_sequence_impl : make_index_sequence_impl { }; +template struct make_index_sequence_impl <0, S...> { typedef index_sequence type; }; +template using make_index_sequence = typename make_index_sequence_impl::type; +#endif + +/// Make an index sequence of the indices of true arguments +template struct select_indices_impl { using type = ISeq; }; +template struct select_indices_impl, I, B, Bs...> + : select_indices_impl, index_sequence>, I + 1, Bs...> {}; +template using select_indices = typename select_indices_impl, 0, Bs...>::type; + +/// Backports of std::bool_constant and std::negation to accommodate older compilers +template using bool_constant = std::integral_constant; +template struct negation : bool_constant { }; + +template struct void_t_impl { using type = void; }; +template using void_t = typename void_t_impl::type; + +/// Compile-time all/any/none of that check the boolean value of all template types +#ifdef __cpp_fold_expressions +template using all_of = bool_constant<(Ts::value && ...)>; +template using any_of = bool_constant<(Ts::value || ...)>; +#elif !defined(_MSC_VER) +template struct bools {}; +template using all_of = std::is_same< + bools, + bools>; +template using any_of = negation...>>; +#else +// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit +// at a slight loss of compilation efficiency). +template using all_of = std::conjunction; +template using any_of = std::disjunction; +#endif +template using none_of = negation>; + +template class... Predicates> using satisfies_all_of = all_of...>; +template class... Predicates> using satisfies_any_of = any_of...>; +template class... Predicates> using satisfies_none_of = none_of...>; + +/// Strip the class from a method type +template struct remove_class { }; +template struct remove_class { typedef R type(A...); }; +template struct remove_class { typedef R type(A...); }; + +/// Helper template to strip away type modifiers +template struct intrinsic_type { typedef T type; }; +template struct intrinsic_type { typedef typename intrinsic_type::type type; }; +template struct intrinsic_type { typedef typename intrinsic_type::type type; }; +template struct intrinsic_type { typedef typename intrinsic_type::type type; }; +template struct intrinsic_type { typedef typename intrinsic_type::type type; }; +template struct intrinsic_type { typedef typename intrinsic_type::type type; }; +template struct intrinsic_type { typedef typename intrinsic_type::type type; }; +template using intrinsic_t = typename intrinsic_type::type; + +/// Helper type to replace 'void' in some expressions +struct void_type { }; + +/// Helper template which holds a list of types +template struct type_list { }; + +/// Compile-time integer sum +#ifdef __cpp_fold_expressions +template constexpr size_t constexpr_sum(Ts... ns) { return (0 + ... + size_t{ns}); } +#else +constexpr size_t constexpr_sum() { return 0; } +template +constexpr size_t constexpr_sum(T n, Ts... ns) { return size_t{n} + constexpr_sum(ns...); } +#endif + +NAMESPACE_BEGIN(constexpr_impl) +/// Implementation details for constexpr functions +constexpr int first(int i) { return i; } +template +constexpr int first(int i, T v, Ts... vs) { return v ? i : first(i + 1, vs...); } + +constexpr int last(int /*i*/, int result) { return result; } +template +constexpr int last(int i, int result, T v, Ts... vs) { return last(i + 1, v ? i : result, vs...); } +NAMESPACE_END(constexpr_impl) + +/// Return the index of the first type in Ts which satisfies Predicate. Returns sizeof...(Ts) if +/// none match. +template class Predicate, typename... Ts> +constexpr int constexpr_first() { return constexpr_impl::first(0, Predicate::value...); } + +/// Return the index of the last type in Ts which satisfies Predicate, or -1 if none match. +template class Predicate, typename... Ts> +constexpr int constexpr_last() { return constexpr_impl::last(0, -1, Predicate::value...); } + +/// Return the Nth element from the parameter pack +template +struct pack_element { using type = typename pack_element::type; }; +template +struct pack_element<0, T, Ts...> { using type = T; }; + +/// Return the one and only type which matches the predicate, or Default if none match. +/// If more than one type matches the predicate, fail at compile-time. +template class Predicate, typename Default, typename... Ts> +struct exactly_one { + static constexpr auto found = constexpr_sum(Predicate::value...); + static_assert(found <= 1, "Found more than one type matching the predicate"); + + static constexpr auto index = found ? constexpr_first() : 0; + using type = conditional_t::type, Default>; +}; +template class P, typename Default> +struct exactly_one { using type = Default; }; + +template class Predicate, typename Default, typename... Ts> +using exactly_one_t = typename exactly_one::type; + +/// Defer the evaluation of type T until types Us are instantiated +template struct deferred_type { using type = T; }; +template using deferred_t = typename deferred_type::type; + +/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of::value == false`, +/// unlike `std::is_base_of`) +template using is_strict_base_of = bool_constant< + std::is_base_of::value && !std::is_same::value>; + +template class Base> +struct is_template_base_of_impl { + template static std::true_type check(Base *); + static std::false_type check(...); +}; + +/// Check if a template is the base of a type. For example: +/// `is_template_base_of` is true if `struct T : Base {}` where U can be anything +template class Base, typename T> +#if !defined(_MSC_VER) +using is_template_base_of = decltype(is_template_base_of_impl::check((intrinsic_t*)nullptr)); +#else // MSVC2015 has trouble with decltype in template aliases +struct is_template_base_of : decltype(is_template_base_of_impl::check((intrinsic_t*)nullptr)) { }; +#endif + +/// Check if T is an instantiation of the template `Class`. For example: +/// `is_instantiation` is true if `T == shared_ptr` where U can be anything. +template class Class, typename T> +struct is_instantiation : std::false_type { }; +template class Class, typename... Us> +struct is_instantiation> : std::true_type { }; + +/// Check if T is std::shared_ptr where U can be anything +template using is_shared_ptr = is_instantiation; + +/// Check if T looks like an input iterator +template struct is_input_iterator : std::false_type {}; +template +struct is_input_iterator()), decltype(++std::declval())>> + : std::true_type {}; + +template using is_function_pointer = bool_constant< + std::is_pointer::value && std::is_function::type>::value>; + +template struct strip_function_object { + using type = typename remove_class::type; +}; + +// Extracts the function signature from a function, function pointer or lambda. +template > +using function_signature_t = conditional_t< + std::is_function::value, + F, + typename conditional_t< + std::is_pointer::value || std::is_member_pointer::value, + std::remove_pointer, + strip_function_object + >::type +>; + +/// Returns true if the type looks like a lambda: that is, isn't a function, pointer or member +/// pointer. Note that this can catch all sorts of other things, too; this is intended to be used +/// in a place where passing a lambda makes sense. +template using is_lambda = satisfies_none_of, + std::is_function, std::is_pointer, std::is_member_pointer>; + +/// Ignore that a variable is unused in compiler warnings +inline void ignore_unused(const int *) { } + +/// Apply a function over each element of a parameter pack +#ifdef __cpp_fold_expressions +#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...) +#else +using expand_side_effects = bool[]; +#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) pybind11::detail::expand_side_effects{ ((PATTERN), void(), false)..., false } +#endif + +NAMESPACE_END(detail) + +/// C++ bindings of builtin Python exceptions +class builtin_exception : public std::runtime_error { +public: + using std::runtime_error::runtime_error; + /// Set the error using the Python C API + virtual void set_error() const = 0; +}; + +#define PYBIND11_RUNTIME_EXCEPTION(name, type) \ + class name : public builtin_exception { public: \ + using builtin_exception::builtin_exception; \ + name() : name("") { } \ + void set_error() const override { PyErr_SetString(type, what()); } \ + }; + +PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration) +PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError) +PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError) +PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError) +PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError) +PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error +PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally + +[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); } +[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); } + +template struct format_descriptor { }; + +NAMESPACE_BEGIN(detail) +// Returns the index of the given type in the type char array below, and in the list in numpy.h +// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double; +// complex float,double,long double. Note that the long double types only participate when long +// double is actually longer than double (it isn't under MSVC). +// NB: not only the string below but also complex.h and numpy.h rely on this order. +template struct is_fmt_numeric { static constexpr bool value = false; }; +template struct is_fmt_numeric::value>> { + static constexpr bool value = true; + static constexpr int index = std::is_same::value ? 0 : 1 + ( + std::is_integral::value ? detail::log2(sizeof(T))*2 + std::is_unsigned::value : 8 + ( + std::is_same::value ? 1 : std::is_same::value ? 2 : 0)); +}; +NAMESPACE_END(detail) + +template struct format_descriptor::value>> { + static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric::index]; + static constexpr const char value[2] = { c, '\0' }; + static std::string format() { return std::string(1, c); } +}; + +template constexpr const char format_descriptor< + T, detail::enable_if_t::value>>::value[2]; + +/// RAII wrapper that temporarily clears any Python error state +struct error_scope { + PyObject *type, *value, *trace; + error_scope() { PyErr_Fetch(&type, &value, &trace); } + ~error_scope() { PyErr_Restore(type, value, trace); } +}; + +/// Dummy destructor wrapper that can be used to expose classes with a private destructor +struct nodelete { template void operator()(T*) { } }; + +// overload_cast requires variable templates: C++14 +#if defined(PYBIND11_CPP14) +#define PYBIND11_OVERLOAD_CAST 1 + +NAMESPACE_BEGIN(detail) +template +struct overload_cast_impl { + constexpr overload_cast_impl() {} // MSVC 2015 needs this + + template + constexpr auto operator()(Return (*pf)(Args...)) const noexcept + -> decltype(pf) { return pf; } + + template + constexpr auto operator()(Return (Class::*pmf)(Args...), std::false_type = {}) const noexcept + -> decltype(pmf) { return pmf; } + + template + constexpr auto operator()(Return (Class::*pmf)(Args...) const, std::true_type) const noexcept + -> decltype(pmf) { return pmf; } +}; +NAMESPACE_END(detail) + +/// Syntax sugar for resolving overloaded function pointers: +/// - regular: static_cast(&Class::func) +/// - sweet: overload_cast(&Class::func) +template +static constexpr detail::overload_cast_impl overload_cast = {}; +// MSVC 2015 only accepts this particular initialization syntax for this variable template. + +/// Const member function selector for overload_cast +/// - regular: static_cast(&Class::func) +/// - sweet: overload_cast(&Class::func, const_) +static constexpr auto const_ = std::true_type{}; + +#else // no overload_cast: providing something that static_assert-fails: +template struct overload_cast { + static_assert(detail::deferred_t::value, + "pybind11::overload_cast<...> requires compiling in C++14 mode"); +}; +#endif // overload_cast + +NAMESPACE_BEGIN(detail) + +// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from +// any standard container (or C-style array) supporting std::begin/std::end, any singleton +// arithmetic type (if T is arithmetic), or explicitly constructible from an iterator pair. +template +class any_container { + std::vector v; +public: + any_container() = default; + + // Can construct from a pair of iterators + template ::value>> + any_container(It first, It last) : v(first, last) { } + + // Implicit conversion constructor from any arbitrary container type with values convertible to T + template ())), T>::value>> + any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { } + + // initializer_list's aren't deducible, so don't get matched by the above template; we need this + // to explicitly allow implicit conversion from one: + template ::value>> + any_container(const std::initializer_list &c) : any_container(c.begin(), c.end()) { } + + // Avoid copying if given an rvalue vector of the correct type. + any_container(std::vector &&v) : v(std::move(v)) { } + + // Moves the vector out of an rvalue any_container + operator std::vector &&() && { return std::move(v); } + + // Dereferencing obtains a reference to the underlying vector + std::vector &operator*() { return v; } + const std::vector &operator*() const { return v; } + + // -> lets you call methods on the underlying vector + std::vector *operator->() { return &v; } + const std::vector *operator->() const { return &v; } +}; + +NAMESPACE_END(detail) + + + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/descr.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/descr.h new file mode 100644 index 000000000..e3bf2ba97 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/descr.h @@ -0,0 +1,185 @@ +/* + pybind11/detail/descr.h: Helper type for concatenating type signatures + either at runtime (C++11) or compile time (C++14) + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "common.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +/* Concatenate type signatures at compile time using C++14 */ +#if defined(PYBIND11_CPP14) && !defined(_MSC_VER) +#define PYBIND11_CONSTEXPR_DESCR + +template class descr { + template friend class descr; +public: + constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1]) + : descr(text, types, + make_index_sequence(), + make_index_sequence()) { } + + constexpr const char *text() const { return m_text; } + constexpr const std::type_info * const * types() const { return m_types; } + + template + constexpr descr operator+(const descr &other) const { + return concat(other, + make_index_sequence(), + make_index_sequence(), + make_index_sequence(), + make_index_sequence()); + } + +protected: + template + constexpr descr( + char const (&text) [Size1+1], + const std::type_info * const (&types) [Size2+1], + index_sequence, index_sequence) + : m_text{text[Indices1]..., '\0'}, + m_types{types[Indices2]..., nullptr } {} + + template + constexpr descr + concat(const descr &other, + index_sequence, index_sequence, + index_sequence, index_sequence) const { + return descr( + { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' }, + { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr } + ); + } + +protected: + char m_text[Size1 + 1]; + const std::type_info * m_types[Size2 + 1]; +}; + +template constexpr descr _(char const(&text)[Size]) { + return descr(text, { nullptr }); +} + +template struct int_to_str : int_to_str { }; +template struct int_to_str<0, Digits...> { + static constexpr auto digits = descr({ ('0' + Digits)..., '\0' }, { nullptr }); +}; + +// Ternary description (like std::conditional) +template +constexpr enable_if_t> _(char const(&text1)[Size1], char const(&)[Size2]) { + return _(text1); +} +template +constexpr enable_if_t> _(char const(&)[Size1], char const(&text2)[Size2]) { + return _(text2); +} +template +constexpr enable_if_t> _(descr d, descr) { return d; } +template +constexpr enable_if_t> _(descr, descr d) { return d; } + +template auto constexpr _() -> decltype(int_to_str::digits) { + return int_to_str::digits; +} + +template constexpr descr<1, 1> _() { + return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr }); +} + +inline constexpr descr<0, 0> concat() { return _(""); } +template auto constexpr concat(descr descr) { return descr; } +template auto constexpr concat(descr descr, Args&&... args) { return descr + _(", ") + concat(args...); } +template auto constexpr type_descr(descr descr) { return _("{") + descr + _("}"); } + +#define PYBIND11_DESCR constexpr auto + +#else /* Simpler C++11 implementation based on run-time memory allocation and copying */ + +class descr { +public: + PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) { + size_t nChars = len(text), nTypes = len(types); + m_text = new char[nChars]; + m_types = new const std::type_info *[nTypes]; + memcpy(m_text, text, nChars * sizeof(char)); + memcpy(m_types, types, nTypes * sizeof(const std::type_info *)); + } + + PYBIND11_NOINLINE descr operator+(descr &&d2) && { + descr r; + + size_t nChars1 = len(m_text), nTypes1 = len(m_types); + size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types); + + r.m_text = new char[nChars1 + nChars2 - 1]; + r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1]; + memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char)); + memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char)); + memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *)); + memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *)); + + delete[] m_text; delete[] m_types; + delete[] d2.m_text; delete[] d2.m_types; + + return r; + } + + char *text() { return m_text; } + const std::type_info * * types() { return m_types; } + +protected: + PYBIND11_NOINLINE descr() { } + + template static size_t len(const T *ptr) { // return length including null termination + const T *it = ptr; + while (*it++ != (T) 0) + ; + return static_cast(it - ptr); + } + + const std::type_info **m_types = nullptr; + char *m_text = nullptr; +}; + +/* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */ + +PYBIND11_NOINLINE inline descr _(const char *text) { + const std::type_info *types[1] = { nullptr }; + return descr(text, types); +} + +template PYBIND11_NOINLINE enable_if_t _(const char *text1, const char *) { return _(text1); } +template PYBIND11_NOINLINE enable_if_t _(char const *, const char *text2) { return _(text2); } +template PYBIND11_NOINLINE enable_if_t _(descr d, descr) { return d; } +template PYBIND11_NOINLINE enable_if_t _(descr, descr d) { return d; } + +template PYBIND11_NOINLINE descr _() { + const std::type_info *types[2] = { &typeid(Type), nullptr }; + return descr("%", types); +} + +template PYBIND11_NOINLINE descr _() { + const std::type_info *types[1] = { nullptr }; + return descr(std::to_string(Size).c_str(), types); +} + +PYBIND11_NOINLINE inline descr concat() { return _(""); } +PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; } +template PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward(args)...); } +PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); } + +#define PYBIND11_DESCR ::pybind11::detail::descr +#endif + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/init.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/init.h new file mode 100644 index 000000000..82f740760 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/init.h @@ -0,0 +1,335 @@ +/* + pybind11/detail/init.h: init factory function implementation and support code. + + Copyright (c) 2017 Jason Rhinelander + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "class.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +template <> +class type_caster { +public: + bool load(handle h, bool) { + value = reinterpret_cast(h.ptr()); + return true; + } + + template using cast_op_type = value_and_holder &; + operator value_and_holder &() { return *value; } + static PYBIND11_DESCR name() { return type_descr(_()); } + +private: + value_and_holder *value = nullptr; +}; + +NAMESPACE_BEGIN(initimpl) + +inline void no_nullptr(void *ptr) { + if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr"); +} + +// Implementing functions for all forms of py::init<...> and py::init(...) +template using Cpp = typename Class::type; +template using Alias = typename Class::type_alias; +template using Holder = typename Class::holder_type; + +template using is_alias_constructible = std::is_constructible, Cpp &&>; + +// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance. +template = 0> +bool is_alias(Cpp *ptr) { + return dynamic_cast *>(ptr) != nullptr; +} +// Failing fallback version of the above for a no-alias class (always returns false) +template +constexpr bool is_alias(void *) { return false; } + +// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall +// back to brace aggregate initiailization so that for aggregate initialization can be used with +// py::init, e.g. `py::init` to initialize a `struct T { int a; int b; }`. For +// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually +// works, but will not do the expected thing when `T` has an `initializer_list` constructor). +template ::value, int> = 0> +inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward(args)...); } +template ::value, int> = 0> +inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward(args)...}; } + +// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with +// an alias to provide only a single Cpp factory function as long as the Alias can be +// constructed from an rvalue reference of the base Cpp type. This means that Alias classes +// can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to +// inherit all the base class constructors. +template +void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/, + value_and_holder &v_h, Cpp &&base) { + v_h.value_ptr() = new Alias(std::move(base)); +} +template +[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/, + value_and_holder &, Cpp &&) { + throw type_error("pybind11::init(): unable to convert returned instance to required " + "alias class: no `Alias(Class &&)` constructor available"); +} + +// Error-generating fallback for factories that don't match one of the below construction +// mechanisms. +template +void construct(...) { + static_assert(!std::is_same::value /* always false */, + "pybind11::init(): init function must return a compatible pointer, " + "holder, or value"); +} + +// Pointer return v1: the factory function returns a class pointer for a registered class. +// If we don't need an alias (because this class doesn't have one, or because the final type is +// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to +// construct an Alias from the returned base instance. +template +void construct(value_and_holder &v_h, Cpp *ptr, bool need_alias) { + no_nullptr(ptr); + if (Class::has_alias && need_alias && !is_alias(ptr)) { + // We're going to try to construct an alias by moving the cpp type. Whether or not + // that succeeds, we still need to destroy the original cpp pointer (either the + // moved away leftover, if the alias construction works, or the value itself if we + // throw an error), but we can't just call `delete ptr`: it might have a special + // deleter, or might be shared_from_this. So we construct a holder around it as if + // it was a normal instance, then steal the holder away into a local variable; thus + // the holder and destruction happens when we leave the C++ scope, and the holder + // class gets to handle the destruction however it likes. + v_h.value_ptr() = ptr; + v_h.set_instance_registered(true); // To prevent init_instance from registering it + v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder + Holder temp_holder(std::move(v_h.holder>())); // Steal the holder + v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null + v_h.set_instance_registered(false); + + construct_alias_from_cpp(is_alias_constructible{}, v_h, std::move(*ptr)); + } else { + // Otherwise the type isn't inherited, so we don't need an Alias + v_h.value_ptr() = ptr; + } +} + +// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over +// ownership of the pointer. +template = 0> +void construct(value_and_holder &v_h, Alias *alias_ptr, bool) { + no_nullptr(alias_ptr); + v_h.value_ptr() = static_cast *>(alias_ptr); +} + +// Holder return: copy its pointer, and move or copy the returned holder into the new instance's +// holder. This also handles types like std::shared_ptr and std::unique_ptr where T is a +// derived type (through those holder's implicit conversion from derived class holder constructors). +template +void construct(value_and_holder &v_h, Holder holder, bool need_alias) { + auto *ptr = holder_helper>::get(holder); + // If we need an alias, check that the held pointer is actually an alias instance + if (Class::has_alias && need_alias && !is_alias(ptr)) + throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance " + "is not an alias instance"); + + v_h.value_ptr() = ptr; + v_h.type->init_instance(v_h.inst, &holder); +} + +// return-by-value version 1: returning a cpp class by value. If the class has an alias and an +// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct +// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't +// need it, we simply move-construct the cpp value into a new instance. +template +void construct(value_and_holder &v_h, Cpp &&result, bool need_alias) { + static_assert(std::is_move_constructible>::value, + "pybind11::init() return-by-value factory function requires a movable class"); + if (Class::has_alias && need_alias) + construct_alias_from_cpp(is_alias_constructible{}, v_h, std::move(result)); + else + v_h.value_ptr() = new Cpp(std::move(result)); +} + +// return-by-value version 2: returning a value of the alias type itself. We move-construct an +// Alias instance (even if no the python-side inheritance is involved). The is intended for +// cases where Alias initialization is always desired. +template +void construct(value_and_holder &v_h, Alias &&result, bool) { + static_assert(std::is_move_constructible>::value, + "pybind11::init() return-by-alias-value factory function requires a movable alias class"); + v_h.value_ptr() = new Alias(std::move(result)); +} + +// Implementing class for py::init<...>() +template +struct constructor { + template = 0> + static void execute(Class &cl, const Extra&... extra) { + cl.def("__init__", [](value_and_holder &v_h, Args... args) { + v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); + }, is_new_style_constructor(), extra...); + } + + template , Args...>::value, int> = 0> + static void execute(Class &cl, const Extra&... extra) { + cl.def("__init__", [](value_and_holder &v_h, Args... args) { + if (Py_TYPE(v_h.inst) == v_h.type->type) + v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); + else + v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); + }, is_new_style_constructor(), extra...); + } + + template , Args...>::value, int> = 0> + static void execute(Class &cl, const Extra&... extra) { + cl.def("__init__", [](value_and_holder &v_h, Args... args) { + v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); + }, is_new_style_constructor(), extra...); + } +}; + +// Implementing class for py::init_alias<...>() +template struct alias_constructor { + template , Args...>::value, int> = 0> + static void execute(Class &cl, const Extra&... extra) { + cl.def("__init__", [](value_and_holder &v_h, Args... args) { + v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); + }, is_new_style_constructor(), extra...); + } +}; + +// Implementation class for py::init(Func) and py::init(Func, AliasFunc) +template , typename = function_signature_t> +struct factory; + +// Specialization for py::init(Func) +template +struct factory { + remove_reference_t class_factory; + + factory(Func &&f) : class_factory(std::forward(f)) { } + + // The given class either has no alias or has no separate alias factory; + // this always constructs the class itself. If the class is registered with an alias + // type and an alias instance is needed (i.e. because the final type is a Python class + // inheriting from the C++ type) the returned value needs to either already be an alias + // instance, or the alias needs to be constructible from a `Class &&` argument. + template + void execute(Class &cl, const Extra &...extra) && { + #if defined(PYBIND11_CPP14) + cl.def("__init__", [func = std::move(class_factory)] + #else + auto &func = class_factory; + cl.def("__init__", [func] + #endif + (value_and_holder &v_h, Args... args) { + construct(v_h, func(std::forward(args)...), + Py_TYPE(v_h.inst) != v_h.type->type); + }, is_new_style_constructor(), extra...); + } +}; + +// Specialization for py::init(Func, AliasFunc) +template +struct factory { + static_assert(sizeof...(CArgs) == sizeof...(AArgs), + "pybind11::init(class_factory, alias_factory): class and alias factories " + "must have identical argument signatures"); + static_assert(all_of...>::value, + "pybind11::init(class_factory, alias_factory): class and alias factories " + "must have identical argument signatures"); + + remove_reference_t class_factory; + remove_reference_t alias_factory; + + factory(CFunc &&c, AFunc &&a) + : class_factory(std::forward(c)), alias_factory(std::forward(a)) { } + + // The class factory is called when the `self` type passed to `__init__` is the direct + // class (i.e. not inherited), the alias factory when `self` is a Python-side subtype. + template + void execute(Class &cl, const Extra&... extra) && { + static_assert(Class::has_alias, "The two-argument version of `py::init()` can " + "only be used if the class has an alias"); + #if defined(PYBIND11_CPP14) + cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)] + #else + auto &class_func = class_factory; + auto &alias_func = alias_factory; + cl.def("__init__", [class_func, alias_func] + #endif + (value_and_holder &v_h, CArgs... args) { + if (Py_TYPE(v_h.inst) == v_h.type->type) + // If the instance type equals the registered type we don't have inheritance, so + // don't need the alias and can construct using the class function: + construct(v_h, class_func(std::forward(args)...), false); + else + construct(v_h, alias_func(std::forward(args)...), true); + }, is_new_style_constructor(), extra...); + } +}; + +/// Set just the C++ state. Same as `__init__`. +template +void setstate(value_and_holder &v_h, T &&result, bool need_alias) { + construct(v_h, std::forward(result), need_alias); +} + +/// Set both the C++ and Python states +template ::value, int> = 0> +void setstate(value_and_holder &v_h, std::pair &&result, bool need_alias) { + construct(v_h, std::move(result.first), need_alias); + setattr((PyObject *) v_h.inst, "__dict__", result.second); +} + +/// Implementation for py::pickle(GetState, SetState) +template , typename = function_signature_t> +struct pickle_factory; + +template +struct pickle_factory { + static_assert(std::is_same, intrinsic_t>::value, + "The type returned by `__getstate__` must be the same " + "as the argument accepted by `__setstate__`"); + + remove_reference_t get; + remove_reference_t set; + + pickle_factory(Get get, Set set) + : get(std::forward(get)), set(std::forward(set)) { } + + template + void execute(Class &cl, const Extra &...extra) && { + cl.def("__getstate__", std::move(get)); + +#if defined(PYBIND11_CPP14) + cl.def("__setstate__", [func = std::move(set)] +#else + auto &func = set; + cl.def("__setstate__", [func] +#endif + (value_and_holder &v_h, ArgState state) { + setstate(v_h, func(std::forward(state)), + Py_TYPE(v_h.inst) != v_h.type->type); + }, is_new_style_constructor(), extra...); + } +}; + +NAMESPACE_END(initimpl) +NAMESPACE_END(detail) +NAMESPACE_END(pybind11) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/internals.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/internals.h new file mode 100644 index 000000000..e39f38695 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/internals.h @@ -0,0 +1,249 @@ +/* + pybind11/detail/internals.h: Internal data structure and related functions + + Copyright (c) 2017 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "../pytypes.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) +// Forward declarations +inline PyTypeObject *make_static_property_type(); +inline PyTypeObject *make_default_metaclass(); +inline PyObject *make_object_base_type(PyTypeObject *metaclass); + +// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly +// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module +// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under +// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name, +// which works. If not under a known-good stl, provide our own name-based hash and equality +// functions that use the type name. +#if defined(__GLIBCXX__) +inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; } +using type_hash = std::hash; +using type_equal_to = std::equal_to; +#else +inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { + return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; +} + +struct type_hash { + size_t operator()(const std::type_index &t) const { + size_t hash = 5381; + const char *ptr = t.name(); + while (auto c = static_cast(*ptr++)) + hash = (hash * 33) ^ c; + return hash; + } +}; + +struct type_equal_to { + bool operator()(const std::type_index &lhs, const std::type_index &rhs) const { + return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; + } +}; +#endif + +template +using type_map = std::unordered_map; + +struct overload_hash { + inline size_t operator()(const std::pair& v) const { + size_t value = std::hash()(v.first); + value ^= std::hash()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2); + return value; + } +}; + +/// Internal data structure used to track registered instances and types. +/// Whenever binary incompatible changes are made to this structure, +/// `PYBIND11_INTERNALS_VERSION` must be incremented. +struct internals { + type_map registered_types_cpp; // std::type_index -> pybind11's type information + std::unordered_map> registered_types_py; // PyTypeObject* -> base type_info(s) + std::unordered_multimap registered_instances; // void * -> instance* + std::unordered_set, overload_hash> inactive_overload_cache; + type_map> direct_conversions; + std::unordered_map> patients; + std::forward_list registered_exception_translators; + std::unordered_map shared_data; // Custom data to be shared across extensions + std::vector loader_patient_stack; // Used by `loader_life_support` + std::forward_list static_strings; // Stores the std::strings backing detail::c_str() + PyTypeObject *static_property_type; + PyTypeObject *default_metaclass; + PyObject *instance_base; +#if defined(WITH_THREAD) + decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x + PyInterpreterState *istate = nullptr; +#endif +}; + +/// Additional type information which does not fit into the PyTypeObject. +/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`. +struct type_info { + PyTypeObject *type; + const std::type_info *cpptype; + size_t type_size, holder_size_in_ptrs; + void *(*operator_new)(size_t); + void (*init_instance)(instance *, const void *); + void (*dealloc)(value_and_holder &v_h); + std::vector implicit_conversions; + std::vector> implicit_casts; + std::vector *direct_conversions; + buffer_info *(*get_buffer)(PyObject *, void *) = nullptr; + void *get_buffer_data = nullptr; + void *(*module_local_load)(PyObject *, const type_info *) = nullptr; + /* A simple type never occurs as a (direct or indirect) parent + * of a class that makes use of multiple inheritance */ + bool simple_type : 1; + /* True if there is no multiple inheritance in this type's inheritance tree */ + bool simple_ancestors : 1; + /* for base vs derived holder_type checks */ + bool default_holder : 1; + /* true if this is a type registered with py::module_local */ + bool module_local : 1; +}; + +/// Tracks the `internals` and `type_info` ABI version independent of the main library version +#define PYBIND11_INTERNALS_VERSION 1 + +#if defined(WITH_THREAD) +# define PYBIND11_INTERNALS_KIND "" +#else +# define PYBIND11_INTERNALS_KIND "_without_thread" +#endif + +#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \ + PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND "__" + +#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \ + PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND "__" + +/// Each module locally stores a pointer to the `internals` data. The data +/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`. +inline internals **&get_internals_pp() { + static internals **internals_pp = nullptr; + return internals_pp; +} + +/// Return a reference to the current `internals` data +PYBIND11_NOINLINE inline internals &get_internals() { + auto **&internals_pp = get_internals_pp(); + if (internals_pp && *internals_pp) + return **internals_pp; + + constexpr auto *id = PYBIND11_INTERNALS_ID; + auto builtins = handle(PyEval_GetBuiltins()); + if (builtins.contains(id) && isinstance(builtins[id])) { + internals_pp = static_cast(capsule(builtins[id])); + + // We loaded builtins through python's builtins, which means that our `error_already_set` + // and `builtin_exception` may be different local classes than the ones set up in the + // initial exception translator, below, so add another for our local exception classes. + // + // libstdc++ doesn't require this (types there are identified only by name) +#if !defined(__GLIBCXX__) + (*internals_pp)->registered_exception_translators.push_front( + [](std::exception_ptr p) -> void { + try { + if (p) std::rethrow_exception(p); + } catch (error_already_set &e) { e.restore(); return; + } catch (const builtin_exception &e) { e.set_error(); return; + } + } + ); +#endif + } else { + if (!internals_pp) internals_pp = new internals*(); + auto *&internals_ptr = *internals_pp; + internals_ptr = new internals(); +#if defined(WITH_THREAD) + PyEval_InitThreads(); + PyThreadState *tstate = PyThreadState_Get(); + internals_ptr->tstate = PyThread_create_key(); + PyThread_set_key_value(internals_ptr->tstate, tstate); + internals_ptr->istate = tstate->interp; +#endif + builtins[id] = capsule(internals_pp); + internals_ptr->registered_exception_translators.push_front( + [](std::exception_ptr p) -> void { + try { + if (p) std::rethrow_exception(p); + } catch (error_already_set &e) { e.restore(); return; + } catch (const builtin_exception &e) { e.set_error(); return; + } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return; + } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; + } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; + } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; + } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return; + } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; + } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return; + } catch (...) { + PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!"); + return; + } + } + ); + internals_ptr->static_property_type = make_static_property_type(); + internals_ptr->default_metaclass = make_default_metaclass(); + internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass); + } + return **internals_pp; +} + +/// Works like `internals.registered_types_cpp`, but for module-local registered types: +inline type_map ®istered_local_types_cpp() { + static type_map locals{}; + return locals; +} + +/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its +/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only +/// cleared when the program exits or after interpreter shutdown (when embedding), and so are +/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name). +template +const char *c_str(Args &&...args) { + auto &strings = get_internals().static_strings; + strings.emplace_front(std::forward(args)...); + return strings.front().c_str(); +} + +NAMESPACE_END(detail) + +/// Returns a named pointer that is shared among all extension modules (using the same +/// pybind11 version) running in the current interpreter. Names starting with underscores +/// are reserved for internal usage. Returns `nullptr` if no matching entry was found. +inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) { + auto &internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + return it != internals.shared_data.end() ? it->second : nullptr; +} + +/// Set the shared data that can be later recovered by `get_shared_data()`. +inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) { + detail::get_internals().shared_data[name] = data; + return data; +} + +/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if +/// such entry exists. Otherwise, a new object of default-constructible type `T` is +/// added to the shared data under the given name and a reference to it is returned. +template +T &get_or_create_shared_data(const std::string &name) { + auto &internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr); + if (!ptr) { + ptr = new T(); + internals.shared_data[name] = ptr; + } + return *ptr; +} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/detail/typeid.h b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/typeid.h new file mode 100644 index 000000000..6f36aab75 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/detail/typeid.h @@ -0,0 +1,53 @@ +/* + pybind11/detail/typeid.h: Compiler-independent access to type identifiers + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include +#include + +#if defined(__GNUG__) +#include +#endif + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) +/// Erase all occurrences of a substring +inline void erase_all(std::string &string, const std::string &search) { + for (size_t pos = 0;;) { + pos = string.find(search, pos); + if (pos == std::string::npos) break; + string.erase(pos, search.length()); + } +} + +PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { +#if defined(__GNUG__) + int status = 0; + std::unique_ptr res { + abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; + if (status == 0) + name = res.get(); +#else + detail::erase_all(name, "class "); + detail::erase_all(name, "struct "); + detail::erase_all(name, "enum "); +#endif + detail::erase_all(name, "pybind11::"); +} +NAMESPACE_END(detail) + +/// Return a string representation of a C++ type +template static std::string type_id() { + std::string name(typeid(T).name()); + detail::clean_type_id(name); + return name; +} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/eigen.h b/ml/dlib/dlib/external/pybind11/include/pybind11/eigen.h new file mode 100644 index 000000000..693a484dc --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/eigen.h @@ -0,0 +1,612 @@ +/* + pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "numpy.h" + +#if defined(__INTEL_COMPILER) +# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) +#elif defined(__GNUG__) || defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wconversion" +# pragma GCC diagnostic ignored "-Wdeprecated-declarations" +# if __GNUC__ >= 7 +# pragma GCC diagnostic ignored "-Wint-in-bool-context" +# endif +#endif + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +# pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17 +#endif + +#include +#include + +// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit +// move constructors that break things. We could detect this an explicitly copy, but an extra copy +// of matrices seems highly undesirable. +static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7"); + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides: +using EigenDStride = Eigen::Stride; +template using EigenDRef = Eigen::Ref; +template using EigenDMap = Eigen::Map; + +NAMESPACE_BEGIN(detail) + +#if EIGEN_VERSION_AT_LEAST(3,3,0) +using EigenIndex = Eigen::Index; +#else +using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE; +#endif + +// Matches Eigen::Map, Eigen::Ref, blocks, etc: +template using is_eigen_dense_map = all_of, std::is_base_of, T>>; +template using is_eigen_mutable_map = std::is_base_of, T>; +template using is_eigen_dense_plain = all_of>, is_template_base_of>; +template using is_eigen_sparse = is_template_base_of; +// Test for objects inheriting from EigenBase that aren't captured by the above. This +// basically covers anything that can be assigned to a dense matrix but that don't have a typical +// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and +// SelfAdjointView fall into this category. +template using is_eigen_other = all_of< + is_template_base_of, + negation, is_eigen_dense_plain, is_eigen_sparse>> +>; + +// Captures numpy/eigen conformability status (returned by EigenProps::conformable()): +template struct EigenConformable { + bool conformable = false; + EigenIndex rows = 0, cols = 0; + EigenDStride stride{0, 0}; // Only valid if negativestrides is false! + bool negativestrides = false; // If true, do not use stride! + + EigenConformable(bool fits = false) : conformable{fits} {} + // Matrix type: + EigenConformable(EigenIndex r, EigenIndex c, + EigenIndex rstride, EigenIndex cstride) : + conformable{true}, rows{r}, cols{c} { + // TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747 + if (rstride < 0 || cstride < 0) { + negativestrides = true; + } else { + stride = {EigenRowMajor ? rstride : cstride /* outer stride */, + EigenRowMajor ? cstride : rstride /* inner stride */ }; + } + } + // Vector type: + EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride) + : EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {} + + template bool stride_compatible() const { + // To have compatible strides, we need (on both dimensions) one of fully dynamic strides, + // matching strides, or a dimension size of 1 (in which case the stride value is irrelevant) + return + !negativestrides && + (props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() || + (EigenRowMajor ? cols : rows) == 1) && + (props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() || + (EigenRowMajor ? rows : cols) == 1); + } + operator bool() const { return conformable; } +}; + +template struct eigen_extract_stride { using type = Type; }; +template +struct eigen_extract_stride> { using type = StrideType; }; +template +struct eigen_extract_stride> { using type = StrideType; }; + +// Helper struct for extracting information from an Eigen type +template struct EigenProps { + using Type = Type_; + using Scalar = typename Type::Scalar; + using StrideType = typename eigen_extract_stride::type; + static constexpr EigenIndex + rows = Type::RowsAtCompileTime, + cols = Type::ColsAtCompileTime, + size = Type::SizeAtCompileTime; + static constexpr bool + row_major = Type::IsRowMajor, + vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1 + fixed_rows = rows != Eigen::Dynamic, + fixed_cols = cols != Eigen::Dynamic, + fixed = size != Eigen::Dynamic, // Fully-fixed size + dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size + + template using if_zero = std::integral_constant; + static constexpr EigenIndex inner_stride = if_zero::value, + outer_stride = if_zero::value; + static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic; + static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1; + static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1; + + // Takes an input array and determines whether we can make it fit into the Eigen type. If + // the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector + // (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type). + static EigenConformable conformable(const array &a) { + const auto dims = a.ndim(); + if (dims < 1 || dims > 2) + return false; + + if (dims == 2) { // Matrix type: require exact match (or dynamic) + + EigenIndex + np_rows = a.shape(0), + np_cols = a.shape(1), + np_rstride = a.strides(0) / static_cast(sizeof(Scalar)), + np_cstride = a.strides(1) / static_cast(sizeof(Scalar)); + if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols)) + return false; + + return {np_rows, np_cols, np_rstride, np_cstride}; + } + + // Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever + // is used, we want the (single) numpy stride value. + const EigenIndex n = a.shape(0), + stride = a.strides(0) / static_cast(sizeof(Scalar)); + + if (vector) { // Eigen type is a compile-time vector + if (fixed && size != n) + return false; // Vector size mismatch + return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride}; + } + else if (fixed) { + // The type has a fixed size, but is not a vector: abort + return false; + } + else if (fixed_cols) { + // Since this isn't a vector, cols must be != 1. We allow this only if it exactly + // equals the number of elements (rows is Dynamic, and so 1 row is allowed). + if (cols != n) return false; + return {1, n, stride}; + } + else { + // Otherwise it's either fully dynamic, or column dynamic; both become a column vector + if (fixed_rows && rows != n) return false; + return {n, 1, stride}; + } + } + + static PYBIND11_DESCR descriptor() { + constexpr bool show_writeable = is_eigen_dense_map::value && is_eigen_mutable_map::value; + constexpr bool show_order = is_eigen_dense_map::value; + constexpr bool show_c_contiguous = show_order && requires_row_major; + constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major; + + return type_descr(_("numpy.ndarray[") + npy_format_descriptor::name() + + _("[") + _(_<(size_t) rows>(), _("m")) + + _(", ") + _(_<(size_t) cols>(), _("n")) + + _("]") + + // For a reference type (e.g. Ref) we have other constraints that might need to be + // satisfied: writeable=True (for a mutable reference), and, depending on the map's stride + // options, possibly f_contiguous or c_contiguous. We include them in the descriptor output + // to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to + // see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you + // *gave* a numpy.ndarray of the right type and dimensions. + _(", flags.writeable", "") + + _(", flags.c_contiguous", "") + + _(", flags.f_contiguous", "") + + _("]") + ); + } +}; + +// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data, +// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array. +template handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) { + constexpr ssize_t elem_size = sizeof(typename props::Scalar); + array a; + if (props::vector) + a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base); + else + a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() }, + src.data(), base); + + if (!writeable) + array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_; + + return a.release(); +} + +// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that +// reference the Eigen object's data with `base` as the python-registered base class (if omitted, +// the base will be set to None, and lifetime management is up to the caller). The numpy array is +// non-writeable if the given type is const. +template +handle eigen_ref_array(Type &src, handle parent = none()) { + // none here is to get past array's should-we-copy detection, which currently always + // copies when there is no base. Setting the base to None should be harmless. + return eigen_array_cast(src, parent, !std::is_const::value); +} + +// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy +// array that references the encapsulated data with a python-side reference to the capsule to tie +// its destruction to that of any dependent python objects. Const-ness is determined by whether or +// not the Type of the pointer given is const. +template ::value>> +handle eigen_encapsulate(Type *src) { + capsule base(src, [](void *o) { delete static_cast(o); }); + return eigen_ref_array(*src, base); +} + +// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense +// types. +template +struct type_caster::value>> { + using Scalar = typename Type::Scalar; + using props = EigenProps; + + bool load(handle src, bool convert) { + // If we're in no-convert mode, only load if given an array of the correct type + if (!convert && !isinstance>(src)) + return false; + + // Coerce into an array, but don't do type conversion yet; the copy below handles it. + auto buf = array::ensure(src); + + if (!buf) + return false; + + auto dims = buf.ndim(); + if (dims < 1 || dims > 2) + return false; + + auto fits = props::conformable(buf); + if (!fits) + return false; + + // Allocate the new type, then build a numpy reference into it + value = Type(fits.rows, fits.cols); + auto ref = reinterpret_steal(eigen_ref_array(value)); + if (dims == 1) ref = ref.squeeze(); + else if (ref.ndim() == 1) buf = buf.squeeze(); + + int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr()); + + if (result < 0) { // Copy failed! + PyErr_Clear(); + return false; + } + + return true; + } + +private: + + // Cast implementation + template + static handle cast_impl(CType *src, return_value_policy policy, handle parent) { + switch (policy) { + case return_value_policy::take_ownership: + case return_value_policy::automatic: + return eigen_encapsulate(src); + case return_value_policy::move: + return eigen_encapsulate(new CType(std::move(*src))); + case return_value_policy::copy: + return eigen_array_cast(*src); + case return_value_policy::reference: + case return_value_policy::automatic_reference: + return eigen_ref_array(*src); + case return_value_policy::reference_internal: + return eigen_ref_array(*src, parent); + default: + throw cast_error("unhandled return_value_policy: should not happen!"); + }; + } + +public: + + // Normal returned non-reference, non-const value: + static handle cast(Type &&src, return_value_policy /* policy */, handle parent) { + return cast_impl(&src, return_value_policy::move, parent); + } + // If you return a non-reference const, we mark the numpy array readonly: + static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) { + return cast_impl(&src, return_value_policy::move, parent); + } + // lvalue reference return; default (automatic) becomes copy + static handle cast(Type &src, return_value_policy policy, handle parent) { + if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) + policy = return_value_policy::copy; + return cast_impl(&src, policy, parent); + } + // const lvalue reference return; default (automatic) becomes copy + static handle cast(const Type &src, return_value_policy policy, handle parent) { + if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) + policy = return_value_policy::copy; + return cast(&src, policy, parent); + } + // non-const pointer return + static handle cast(Type *src, return_value_policy policy, handle parent) { + return cast_impl(src, policy, parent); + } + // const pointer return + static handle cast(const Type *src, return_value_policy policy, handle parent) { + return cast_impl(src, policy, parent); + } + + static PYBIND11_DESCR name() { return props::descriptor(); } + + operator Type*() { return &value; } + operator Type&() { return value; } + operator Type&&() && { return std::move(value); } + template using cast_op_type = movable_cast_op_type; + +private: + Type value; +}; + +// Eigen Ref/Map classes have slightly different policy requirements, meaning we don't want to force +// `move` when a Ref/Map rvalue is returned; we treat Ref<> sort of like a pointer (we care about +// the underlying data, not the outer shell). +template +struct return_value_policy_override::value>> { + static return_value_policy policy(return_value_policy p) { return p; } +}; + +// Base class for casting reference/map/block/etc. objects back to python. +template struct eigen_map_caster { +private: + using props = EigenProps; + +public: + + // Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has + // to stay around), but we'll allow it under the assumption that you know what you're doing (and + // have an appropriate keep_alive in place). We return a numpy array pointing directly at the + // ref's data (The numpy array ends up read-only if the ref was to a const matrix type.) Note + // that this means you need to ensure you don't destroy the object in some other way (e.g. with + // an appropriate keep_alive, or with a reference to a statically allocated matrix). + static handle cast(const MapType &src, return_value_policy policy, handle parent) { + switch (policy) { + case return_value_policy::copy: + return eigen_array_cast(src); + case return_value_policy::reference_internal: + return eigen_array_cast(src, parent, is_eigen_mutable_map::value); + case return_value_policy::reference: + case return_value_policy::automatic: + case return_value_policy::automatic_reference: + return eigen_array_cast(src, none(), is_eigen_mutable_map::value); + default: + // move, take_ownership don't make any sense for a ref/map: + pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type"); + } + } + + static PYBIND11_DESCR name() { return props::descriptor(); } + + // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return + // types but not bound arguments). We still provide them (with an explicitly delete) so that + // you end up here if you try anyway. + bool load(handle, bool) = delete; + operator MapType() = delete; + template using cast_op_type = MapType; +}; + +// We can return any map-like object (but can only load Refs, specialized next): +template struct type_caster::value>> + : eigen_map_caster {}; + +// Loader for Ref<...> arguments. See the documentation for info on how to make this work without +// copying (it requires some extra effort in many cases). +template +struct type_caster< + Eigen::Ref, + enable_if_t>::value> +> : public eigen_map_caster> { +private: + using Type = Eigen::Ref; + using props = EigenProps; + using Scalar = typename props::Scalar; + using MapType = Eigen::Map; + using Array = array_t; + static constexpr bool need_writeable = is_eigen_mutable_map::value; + // Delay construction (these have no default constructor) + std::unique_ptr map; + std::unique_ptr ref; + // Our array. When possible, this is just a numpy array pointing to the source data, but + // sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible + // layout, or is an array of a type that needs to be converted). Using a numpy temporary + // (rather than an Eigen temporary) saves an extra copy when we need both type conversion and + // storage order conversion. (Note that we refuse to use this temporary copy when loading an + // argument for a Ref with M non-const, i.e. a read-write reference). + Array copy_or_ref; +public: + bool load(handle src, bool convert) { + // First check whether what we have is already an array of the right type. If not, we can't + // avoid a copy (because the copy is also going to do type conversion). + bool need_copy = !isinstance(src); + + EigenConformable fits; + if (!need_copy) { + // We don't need a converting copy, but we also need to check whether the strides are + // compatible with the Ref's stride requirements + Array aref = reinterpret_borrow(src); + + if (aref && (!need_writeable || aref.writeable())) { + fits = props::conformable(aref); + if (!fits) return false; // Incompatible dimensions + if (!fits.template stride_compatible()) + need_copy = true; + else + copy_or_ref = std::move(aref); + } + else { + need_copy = true; + } + } + + if (need_copy) { + // We need to copy: If we need a mutable reference, or we're not supposed to convert + // (either because we're in the no-convert overload pass, or because we're explicitly + // instructed not to copy (via `py::arg().noconvert()`) we have to fail loading. + if (!convert || need_writeable) return false; + + Array copy = Array::ensure(src); + if (!copy) return false; + fits = props::conformable(copy); + if (!fits || !fits.template stride_compatible()) + return false; + copy_or_ref = std::move(copy); + loader_life_support::add_patient(copy_or_ref); + } + + ref.reset(); + map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner()))); + ref.reset(new Type(*map)); + + return true; + } + + operator Type*() { return ref.get(); } + operator Type&() { return *ref; } + template using cast_op_type = pybind11::detail::cast_op_type<_T>; + +private: + template ::value, int> = 0> + Scalar *data(Array &a) { return a.mutable_data(); } + + template ::value, int> = 0> + const Scalar *data(Array &a) { return a.data(); } + + // Attempt to figure out a constructor of `Stride` that will work. + // If both strides are fixed, use a default constructor: + template using stride_ctor_default = bool_constant< + S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic && + std::is_default_constructible::value>; + // Otherwise, if there is a two-index constructor, assume it is (outer,inner) like + // Eigen::Stride, and use it: + template using stride_ctor_dual = bool_constant< + !stride_ctor_default::value && std::is_constructible::value>; + // Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use + // it (passing whichever stride is dynamic). + template using stride_ctor_outer = bool_constant< + !any_of, stride_ctor_dual>::value && + S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic && + std::is_constructible::value>; + template using stride_ctor_inner = bool_constant< + !any_of, stride_ctor_dual>::value && + S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic && + std::is_constructible::value>; + + template ::value, int> = 0> + static S make_stride(EigenIndex, EigenIndex) { return S(); } + template ::value, int> = 0> + static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); } + template ::value, int> = 0> + static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); } + template ::value, int> = 0> + static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); } + +}; + +// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not +// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout). +// load() is not supported, but we can cast them into the python domain by first copying to a +// regular Eigen::Matrix, then casting that. +template +struct type_caster::value>> { +protected: + using Matrix = Eigen::Matrix; + using props = EigenProps; +public: + static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { + handle h = eigen_encapsulate(new Matrix(src)); + return h; + } + static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); } + + static PYBIND11_DESCR name() { return props::descriptor(); } + + // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return + // types but not bound arguments). We still provide them (with an explicitly delete) so that + // you end up here if you try anyway. + bool load(handle, bool) = delete; + operator Type() = delete; + template using cast_op_type = Type; +}; + +template +struct type_caster::value>> { + typedef typename Type::Scalar Scalar; + typedef remove_reference_t().outerIndexPtr())> StorageIndex; + typedef typename Type::Index Index; + static constexpr bool rowMajor = Type::IsRowMajor; + + bool load(handle src, bool) { + if (!src) + return false; + + auto obj = reinterpret_borrow(src); + object sparse_module = module::import("scipy.sparse"); + object matrix_type = sparse_module.attr( + rowMajor ? "csr_matrix" : "csc_matrix"); + + if (!obj.get_type().is(matrix_type)) { + try { + obj = matrix_type(obj); + } catch (const error_already_set &) { + return false; + } + } + + auto values = array_t((object) obj.attr("data")); + auto innerIndices = array_t((object) obj.attr("indices")); + auto outerIndices = array_t((object) obj.attr("indptr")); + auto shape = pybind11::tuple((pybind11::object) obj.attr("shape")); + auto nnz = obj.attr("nnz").cast(); + + if (!values || !innerIndices || !outerIndices) + return false; + + value = Eigen::MappedSparseMatrix( + shape[0].cast(), shape[1].cast(), nnz, + outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data()); + + return true; + } + + static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { + const_cast(src).makeCompressed(); + + object matrix_type = module::import("scipy.sparse").attr( + rowMajor ? "csr_matrix" : "csc_matrix"); + + array data(src.nonZeros(), src.valuePtr()); + array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr()); + array innerIndices(src.nonZeros(), src.innerIndexPtr()); + + return matrix_type( + std::make_tuple(data, innerIndices, outerIndices), + std::make_pair(src.rows(), src.cols()) + ).release(); + } + + PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") + + npy_format_descriptor::name() + _("]")); +}; + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(__GNUG__) || defined(__clang__) +# pragma GCC diagnostic pop +#elif defined(_MSC_VER) +# pragma warning(pop) +#endif diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/embed.h b/ml/dlib/dlib/external/pybind11/include/pybind11/embed.h new file mode 100644 index 000000000..9abc61c34 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/embed.h @@ -0,0 +1,194 @@ +/* + pybind11/embed.h: Support for embedding the interpreter + + Copyright (c) 2017 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include "eval.h" + +#if defined(PYPY_VERSION) +# error Embedding the interpreter is not supported with PyPy +#endif + +#if PY_MAJOR_VERSION >= 3 +# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ + extern "C" PyObject *pybind11_init_impl_##name() { \ + return pybind11_init_wrapper_##name(); \ + } +#else +# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ + extern "C" void pybind11_init_impl_##name() { \ + pybind11_init_wrapper_##name(); \ + } +#endif + +/** \rst + Add a new module to the table of builtins for the interpreter. Must be + defined in global scope. The first macro parameter is the name of the + module (without quotes). The second parameter is the variable which will + be used as the interface to add functions and classes to the module. + + .. code-block:: cpp + + PYBIND11_EMBEDDED_MODULE(example, m) { + // ... initialize functions and classes here + m.def("foo", []() { + return "Hello, World!"; + }); + } + \endrst */ +#define PYBIND11_EMBEDDED_MODULE(name, variable) \ + static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ + static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \ + auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ + try { \ + PYBIND11_CONCAT(pybind11_init_, name)(m); \ + return m.ptr(); \ + } catch (pybind11::error_already_set &e) { \ + PyErr_SetString(PyExc_ImportError, e.what()); \ + return nullptr; \ + } catch (const std::exception &e) { \ + PyErr_SetString(PyExc_ImportError, e.what()); \ + return nullptr; \ + } \ + } \ + PYBIND11_EMBEDDED_MODULE_IMPL(name) \ + pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \ + PYBIND11_CONCAT(pybind11_init_impl_, name)); \ + void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) + + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks. +struct embedded_module { +#if PY_MAJOR_VERSION >= 3 + using init_t = PyObject *(*)(); +#else + using init_t = void (*)(); +#endif + embedded_module(const char *name, init_t init) { + if (Py_IsInitialized()) + pybind11_fail("Can't add new modules after the interpreter has been initialized"); + + auto result = PyImport_AppendInittab(name, init); + if (result == -1) + pybind11_fail("Insufficient memory to add a new module"); + } +}; + +NAMESPACE_END(detail) + +/** \rst + Initialize the Python interpreter. No other pybind11 or CPython API functions can be + called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The + optional parameter can be used to skip the registration of signal handlers (see the + Python documentation for details). Calling this function again after the interpreter + has already been initialized is a fatal error. + \endrst */ +inline void initialize_interpreter(bool init_signal_handlers = true) { + if (Py_IsInitialized()) + pybind11_fail("The interpreter is already running"); + + Py_InitializeEx(init_signal_handlers ? 1 : 0); + + // Make .py files in the working directory available by default + module::import("sys").attr("path").cast().append("."); +} + +/** \rst + Shut down the Python interpreter. No pybind11 or CPython API functions can be called + after this. In addition, pybind11 objects must not outlive the interpreter: + + .. code-block:: cpp + + { // BAD + py::initialize_interpreter(); + auto hello = py::str("Hello, World!"); + py::finalize_interpreter(); + } // <-- BOOM, hello's destructor is called after interpreter shutdown + + { // GOOD + py::initialize_interpreter(); + { // scoped + auto hello = py::str("Hello, World!"); + } // <-- OK, hello is cleaned up properly + py::finalize_interpreter(); + } + + { // BETTER + py::scoped_interpreter guard{}; + auto hello = py::str("Hello, World!"); + } + + .. warning:: + + The interpreter can be restarted by calling `initialize_interpreter` again. + Modules created using pybind11 can be safely re-initialized. However, Python + itself cannot completely unload binary extension modules and there are several + caveats with regard to interpreter restarting. All the details can be found + in the CPython documentation. In short, not all interpreter memory may be + freed, either due to reference cycles or user-created global data. + + \endrst */ +inline void finalize_interpreter() { + handle builtins(PyEval_GetBuiltins()); + const char *id = PYBIND11_INTERNALS_ID; + + // Get the internals pointer (without creating it if it doesn't exist). It's possible for the + // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()` + // during destruction), so we get the pointer-pointer here and check it after Py_Finalize(). + detail::internals **internals_ptr_ptr = detail::get_internals_pp(); + // It could also be stashed in builtins, so look there too: + if (builtins.contains(id) && isinstance(builtins[id])) + internals_ptr_ptr = capsule(builtins[id]); + + Py_Finalize(); + + if (internals_ptr_ptr) { + delete *internals_ptr_ptr; + *internals_ptr_ptr = nullptr; + } +} + +/** \rst + Scope guard version of `initialize_interpreter` and `finalize_interpreter`. + This a move-only guard and only a single instance can exist. + + .. code-block:: cpp + + #include + + int main() { + py::scoped_interpreter guard{}; + py::print(Hello, World!); + } // <-- interpreter shutdown + \endrst */ +class scoped_interpreter { +public: + scoped_interpreter(bool init_signal_handlers = true) { + initialize_interpreter(init_signal_handlers); + } + + scoped_interpreter(const scoped_interpreter &) = delete; + scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; } + scoped_interpreter &operator=(const scoped_interpreter &) = delete; + scoped_interpreter &operator=(scoped_interpreter &&) = delete; + + ~scoped_interpreter() { + if (is_valid) + finalize_interpreter(); + } + +private: + bool is_valid = true; +}; + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/eval.h b/ml/dlib/dlib/external/pybind11/include/pybind11/eval.h new file mode 100644 index 000000000..ea85ba1db --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/eval.h @@ -0,0 +1,117 @@ +/* + pybind11/exec.h: Support for evaluating Python expressions and statements + from strings and files + + Copyright (c) 2016 Klemens Morgenstern and + Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +enum eval_mode { + /// Evaluate a string containing an isolated expression + eval_expr, + + /// Evaluate a string containing a single statement. Returns \c none + eval_single_statement, + + /// Evaluate a string containing a sequence of statement. Returns \c none + eval_statements +}; + +template +object eval(str expr, object global = globals(), object local = object()) { + if (!local) + local = global; + + /* PyRun_String does not accept a PyObject / encoding specifier, + this seems to be the only alternative */ + std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; + + int start; + switch (mode) { + case eval_expr: start = Py_eval_input; break; + case eval_single_statement: start = Py_single_input; break; + case eval_statements: start = Py_file_input; break; + default: pybind11_fail("invalid evaluation mode"); + } + + PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); + if (!result) + throw error_already_set(); + return reinterpret_steal(result); +} + +template +object eval(const char (&s)[N], object global = globals(), object local = object()) { + /* Support raw string literals by removing common leading whitespace */ + auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s)) + : str(s); + return eval(expr, global, local); +} + +inline void exec(str expr, object global = globals(), object local = object()) { + eval(expr, global, local); +} + +template +void exec(const char (&s)[N], object global = globals(), object local = object()) { + eval(s, global, local); +} + +template +object eval_file(str fname, object global = globals(), object local = object()) { + if (!local) + local = global; + + int start; + switch (mode) { + case eval_expr: start = Py_eval_input; break; + case eval_single_statement: start = Py_single_input; break; + case eval_statements: start = Py_file_input; break; + default: pybind11_fail("invalid evaluation mode"); + } + + int closeFile = 1; + std::string fname_str = (std::string) fname; +#if PY_VERSION_HEX >= 0x03040000 + FILE *f = _Py_fopen_obj(fname.ptr(), "r"); +#elif PY_VERSION_HEX >= 0x03000000 + FILE *f = _Py_fopen(fname.ptr(), "r"); +#else + /* No unicode support in open() :( */ + auto fobj = reinterpret_steal(PyFile_FromString( + const_cast(fname_str.c_str()), + const_cast("r"))); + FILE *f = nullptr; + if (fobj) + f = PyFile_AsFile(fobj.ptr()); + closeFile = 0; +#endif + if (!f) { + PyErr_Clear(); + pybind11_fail("File \"" + fname_str + "\" could not be opened!"); + } + +#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) + PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), + local.ptr()); + (void) closeFile; +#else + PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), + local.ptr(), closeFile); +#endif + + if (!result) + throw error_already_set(); + return reinterpret_steal(result); +} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/functional.h b/ml/dlib/dlib/external/pybind11/include/pybind11/functional.h new file mode 100644 index 000000000..eda14ba58 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/functional.h @@ -0,0 +1,85 @@ +/* + pybind11/functional.h: std::function<> support + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +template +struct type_caster> { + using type = std::function; + using retval_type = conditional_t::value, void_type, Return>; + using function_type = Return (*) (Args...); + +public: + bool load(handle src, bool convert) { + if (src.is_none()) { + // Defer accepting None to other overloads (if we aren't in convert mode): + if (!convert) return false; + return true; + } + + if (!isinstance(src)) + return false; + + auto func = reinterpret_borrow(src); + + /* + When passing a C++ function as an argument to another C++ + function via Python, every function call would normally involve + a full C++ -> Python -> C++ roundtrip, which can be prohibitive. + Here, we try to at least detect the case where the function is + stateless (i.e. function pointer or lambda function without + captured variables), in which case the roundtrip can be avoided. + */ + if (auto cfunc = func.cpp_function()) { + auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); + auto rec = (function_record *) c; + + if (rec && rec->is_stateless && + same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { + struct capture { function_type f; }; + value = ((capture *) &rec->data)->f; + return true; + } + } + + value = [func](Args... args) -> Return { + gil_scoped_acquire acq; + object retval(func(std::forward(args)...)); + /* Visual studio 2015 parser issue: need parentheses around this expression */ + return (retval.template cast()); + }; + return true; + } + + template + static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { + if (!f_) + return none().inc_ref(); + + auto result = f_.template target(); + if (result) + return cpp_function(*result, policy).release(); + else + return cpp_function(std::forward(f_), policy).release(); + } + + PYBIND11_TYPE_CASTER(type, _("Callable[[") + + argument_loader::arg_names() + _("], ") + + make_caster::name() + + _("]")); +}; + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/iostream.h b/ml/dlib/dlib/external/pybind11/include/pybind11/iostream.h new file mode 100644 index 000000000..a9c27aac1 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/iostream.h @@ -0,0 +1,200 @@ +/* + pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python + + Copyright (c) 2017 Henry F. Schreiner + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" + +#include +#include +#include +#include +#include + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +// Buffer that writes to Python instead of C++ +class pythonbuf : public std::streambuf { +private: + using traits_type = std::streambuf::traits_type; + + char d_buffer[1024]; + object pywrite; + object pyflush; + + int overflow(int c) { + if (!traits_type::eq_int_type(c, traits_type::eof())) { + *pptr() = traits_type::to_char_type(c); + pbump(1); + } + return sync() ? traits_type::not_eof(c) : traits_type::eof(); + } + + int sync() { + if (pbase() != pptr()) { + // This subtraction cannot be negative, so dropping the sign + str line(pbase(), static_cast(pptr() - pbase())); + + pywrite(line); + pyflush(); + + setp(pbase(), epptr()); + } + return 0; + } + +public: + pythonbuf(object pyostream) + : pywrite(pyostream.attr("write")), + pyflush(pyostream.attr("flush")) { + setp(d_buffer, d_buffer + sizeof(d_buffer) - 1); + } + + /// Sync before destroy + ~pythonbuf() { + sync(); + } +}; + +NAMESPACE_END(detail) + + +/** \rst + This a move-only guard that redirects output. + + .. code-block:: cpp + + #include + + ... + + { + py::scoped_ostream_redirect output; + std::cout << "Hello, World!"; // Python stdout + } // <-- return std::cout to normal + + You can explicitly pass the c++ stream and the python object, + for example to guard stderr instead. + + .. code-block:: cpp + + { + py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")}; + std::cerr << "Hello, World!"; + } + \endrst */ +class scoped_ostream_redirect { +protected: + std::streambuf *old; + std::ostream &costream; + detail::pythonbuf buffer; + +public: + scoped_ostream_redirect( + std::ostream &costream = std::cout, + object pyostream = module::import("sys").attr("stdout")) + : costream(costream), buffer(pyostream) { + old = costream.rdbuf(&buffer); + } + + ~scoped_ostream_redirect() { + costream.rdbuf(old); + } + + scoped_ostream_redirect(const scoped_ostream_redirect &) = delete; + scoped_ostream_redirect(scoped_ostream_redirect &&other) = default; + scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete; + scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete; +}; + + +/** \rst + Like `scoped_ostream_redirect`, but redirects cerr by default. This class + is provided primary to make ``py::call_guard`` easier to make. + + .. code-block:: cpp + + m.def("noisy_func", &noisy_func, + py::call_guard()); + +\endrst */ +class scoped_estream_redirect : public scoped_ostream_redirect { +public: + scoped_estream_redirect( + std::ostream &costream = std::cerr, + object pyostream = module::import("sys").attr("stderr")) + : scoped_ostream_redirect(costream,pyostream) {} +}; + + +NAMESPACE_BEGIN(detail) + +// Class to redirect output as a context manager. C++ backend. +class OstreamRedirect { + bool do_stdout_; + bool do_stderr_; + std::unique_ptr redirect_stdout; + std::unique_ptr redirect_stderr; + +public: + OstreamRedirect(bool do_stdout = true, bool do_stderr = true) + : do_stdout_(do_stdout), do_stderr_(do_stderr) {} + + void enter() { + if (do_stdout_) + redirect_stdout.reset(new scoped_ostream_redirect()); + if (do_stderr_) + redirect_stderr.reset(new scoped_estream_redirect()); + } + + void exit() { + redirect_stdout.reset(); + redirect_stderr.reset(); + } +}; + +NAMESPACE_END(detail) + +/** \rst + This is a helper function to add a C++ redirect context manager to Python + instead of using a C++ guard. To use it, add the following to your binding code: + + .. code-block:: cpp + + #include + + ... + + py::add_ostream_redirect(m, "ostream_redirect"); + + You now have a Python context manager that redirects your output: + + .. code-block:: python + + with m.ostream_redirect(): + m.print_to_cout_function() + + This manager can optionally be told which streams to operate on: + + .. code-block:: python + + with m.ostream_redirect(stdout=true, stderr=true): + m.noisy_function_with_error_printing() + + \endrst */ +inline class_ add_ostream_redirect(module m, std::string name = "ostream_redirect") { + return class_(m, name.c_str(), module_local()) + .def(init(), arg("stdout")=true, arg("stderr")=true) + .def("__enter__", &detail::OstreamRedirect::enter) + .def("__exit__", [](detail::OstreamRedirect &self, args) { self.exit(); }); +} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/numpy.h b/ml/dlib/dlib/external/pybind11/include/pybind11/numpy.h new file mode 100644 index 000000000..b1600dc2e --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/numpy.h @@ -0,0 +1,1600 @@ +/* + pybind11/numpy.h: Basic NumPy support, vectorize() wrapper + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include "complex.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +/* This will be true on all flat address space platforms and allows us to reduce the + whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size + and dimension types (e.g. shape, strides, indexing), instead of inflicting this + upon the library user. */ +static_assert(sizeof(ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t"); + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +class array; // Forward declaration + +NAMESPACE_BEGIN(detail) +template struct npy_format_descriptor; + +struct PyArrayDescr_Proxy { + PyObject_HEAD + PyObject *typeobj; + char kind; + char type; + char byteorder; + char flags; + int type_num; + int elsize; + int alignment; + char *subarray; + PyObject *fields; + PyObject *names; +}; + +struct PyArray_Proxy { + PyObject_HEAD + char *data; + int nd; + ssize_t *dimensions; + ssize_t *strides; + PyObject *base; + PyObject *descr; + int flags; +}; + +struct PyVoidScalarObject_Proxy { + PyObject_VAR_HEAD + char *obval; + PyArrayDescr_Proxy *descr; + int flags; + PyObject *base; +}; + +struct numpy_type_info { + PyObject* dtype_ptr; + std::string format_str; +}; + +struct numpy_internals { + std::unordered_map registered_dtypes; + + numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) { + auto it = registered_dtypes.find(std::type_index(tinfo)); + if (it != registered_dtypes.end()) + return &(it->second); + if (throw_if_missing) + pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name()); + return nullptr; + } + + template numpy_type_info *get_type_info(bool throw_if_missing = true) { + return get_type_info(typeid(typename std::remove_cv::type), throw_if_missing); + } +}; + +inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { + ptr = &get_or_create_shared_data("_numpy_internals"); +} + +inline numpy_internals& get_numpy_internals() { + static numpy_internals* ptr = nullptr; + if (!ptr) + load_numpy_internals(ptr); + return *ptr; +} + +struct npy_api { + enum constants { + NPY_ARRAY_C_CONTIGUOUS_ = 0x0001, + NPY_ARRAY_F_CONTIGUOUS_ = 0x0002, + NPY_ARRAY_OWNDATA_ = 0x0004, + NPY_ARRAY_FORCECAST_ = 0x0010, + NPY_ARRAY_ENSUREARRAY_ = 0x0040, + NPY_ARRAY_ALIGNED_ = 0x0100, + NPY_ARRAY_WRITEABLE_ = 0x0400, + NPY_BOOL_ = 0, + NPY_BYTE_, NPY_UBYTE_, + NPY_SHORT_, NPY_USHORT_, + NPY_INT_, NPY_UINT_, + NPY_LONG_, NPY_ULONG_, + NPY_LONGLONG_, NPY_ULONGLONG_, + NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_, + NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_, + NPY_OBJECT_ = 17, + NPY_STRING_, NPY_UNICODE_, NPY_VOID_ + }; + + typedef struct { + Py_intptr_t *ptr; + int len; + } PyArray_Dims; + + static npy_api& get() { + static npy_api api = lookup(); + return api; + } + + bool PyArray_Check_(PyObject *obj) const { + return (bool) PyObject_TypeCheck(obj, PyArray_Type_); + } + bool PyArrayDescr_Check_(PyObject *obj) const { + return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_); + } + + unsigned int (*PyArray_GetNDArrayCFeatureVersion_)(); + PyObject *(*PyArray_DescrFromType_)(int); + PyObject *(*PyArray_NewFromDescr_) + (PyTypeObject *, PyObject *, int, Py_intptr_t *, + Py_intptr_t *, void *, int, PyObject *); + PyObject *(*PyArray_DescrNewFromType_)(int); + int (*PyArray_CopyInto_)(PyObject *, PyObject *); + PyObject *(*PyArray_NewCopy_)(PyObject *, int); + PyTypeObject *PyArray_Type_; + PyTypeObject *PyVoidArrType_Type_; + PyTypeObject *PyArrayDescr_Type_; + PyObject *(*PyArray_DescrFromScalar_)(PyObject *); + PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); + int (*PyArray_DescrConverter_) (PyObject *, PyObject **); + bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); + int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, + Py_ssize_t *, PyObject **, PyObject *); + PyObject *(*PyArray_Squeeze_)(PyObject *); + int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); + PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int); +private: + enum functions { + API_PyArray_GetNDArrayCFeatureVersion = 211, + API_PyArray_Type = 2, + API_PyArrayDescr_Type = 3, + API_PyVoidArrType_Type = 39, + API_PyArray_DescrFromType = 45, + API_PyArray_DescrFromScalar = 57, + API_PyArray_FromAny = 69, + API_PyArray_Resize = 80, + API_PyArray_CopyInto = 82, + API_PyArray_NewCopy = 85, + API_PyArray_NewFromDescr = 94, + API_PyArray_DescrNewFromType = 9, + API_PyArray_DescrConverter = 174, + API_PyArray_EquivTypes = 182, + API_PyArray_GetArrayParamsFromObject = 278, + API_PyArray_Squeeze = 136, + API_PyArray_SetBaseObject = 282 + }; + + static npy_api lookup() { + module m = module::import("numpy.core.multiarray"); + auto c = m.attr("_ARRAY_API"); +#if PY_MAJOR_VERSION >= 3 + void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL); +#else + void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr()); +#endif + npy_api api; +#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func]; + DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion); + if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7) + pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0"); + DECL_NPY_API(PyArray_Type); + DECL_NPY_API(PyVoidArrType_Type); + DECL_NPY_API(PyArrayDescr_Type); + DECL_NPY_API(PyArray_DescrFromType); + DECL_NPY_API(PyArray_DescrFromScalar); + DECL_NPY_API(PyArray_FromAny); + DECL_NPY_API(PyArray_Resize); + DECL_NPY_API(PyArray_CopyInto); + DECL_NPY_API(PyArray_NewCopy); + DECL_NPY_API(PyArray_NewFromDescr); + DECL_NPY_API(PyArray_DescrNewFromType); + DECL_NPY_API(PyArray_DescrConverter); + DECL_NPY_API(PyArray_EquivTypes); + DECL_NPY_API(PyArray_GetArrayParamsFromObject); + DECL_NPY_API(PyArray_Squeeze); + DECL_NPY_API(PyArray_SetBaseObject); +#undef DECL_NPY_API + return api; + } +}; + +inline PyArray_Proxy* array_proxy(void* ptr) { + return reinterpret_cast(ptr); +} + +inline const PyArray_Proxy* array_proxy(const void* ptr) { + return reinterpret_cast(ptr); +} + +inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) { + return reinterpret_cast(ptr); +} + +inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) { + return reinterpret_cast(ptr); +} + +inline bool check_flags(const void* ptr, int flag) { + return (flag == (array_proxy(ptr)->flags & flag)); +} + +template struct is_std_array : std::false_type { }; +template struct is_std_array> : std::true_type { }; +template struct is_complex : std::false_type { }; +template struct is_complex> : std::true_type { }; + +template struct array_info_scalar { + typedef T type; + static constexpr bool is_array = false; + static constexpr bool is_empty = false; + static PYBIND11_DESCR extents() { return _(""); } + static void append_extents(list& /* shape */) { } +}; +// Computes underlying type and a comma-separated list of extents for array +// types (any mix of std::array and built-in arrays). An array of char is +// treated as scalar because it gets special handling. +template struct array_info : array_info_scalar { }; +template struct array_info> { + using type = typename array_info::type; + static constexpr bool is_array = true; + static constexpr bool is_empty = (N == 0) || array_info::is_empty; + static constexpr size_t extent = N; + + // appends the extents to shape + static void append_extents(list& shape) { + shape.append(N); + array_info::append_extents(shape); + } + + template::is_array, int> = 0> + static PYBIND11_DESCR extents() { + return _(); + } + + template::is_array, int> = 0> + static PYBIND11_DESCR extents() { + return concat(_(), array_info::extents()); + } +}; +// For numpy we have special handling for arrays of characters, so we don't include +// the size in the array extents. +template struct array_info : array_info_scalar { }; +template struct array_info> : array_info_scalar> { }; +template struct array_info : array_info> { }; +template using remove_all_extents_t = typename array_info::type; + +template using is_pod_struct = all_of< + std::is_standard_layout, // since we're accessing directly in memory we need a standard layout type +#if !defined(__GNUG__) || defined(_LIBCPP_VERSION) || defined(_GLIBCXX_USE_CXX11_ABI) + // _GLIBCXX_USE_CXX11_ABI indicates that we're using libstdc++ from GCC 5 or newer, independent + // of the actual compiler (Clang can also use libstdc++, but it always defines __GNUC__ == 4). + std::is_trivially_copyable, +#else + // GCC 4 doesn't implement is_trivially_copyable, so approximate it + std::is_trivially_destructible, + satisfies_any_of, +#endif + satisfies_none_of +>; + +template ssize_t byte_offset_unsafe(const Strides &) { return 0; } +template +ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) { + return i * strides[Dim] + byte_offset_unsafe(strides, index...); +} + +/** + * Proxy class providing unsafe, unchecked const access to array data. This is constructed through + * the `unchecked()` method of `array` or the `unchecked()` method of `array_t`. `Dims` + * will be -1 for dimensions determined at runtime. + */ +template +class unchecked_reference { +protected: + static constexpr bool Dynamic = Dims < 0; + const unsigned char *data_; + // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to + // make large performance gains on big, nested loops, but requires compile-time dimensions + conditional_t> + shape_, strides_; + const ssize_t dims_; + + friend class pybind11::array; + // Constructor for compile-time dimensions: + template + unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t) + : data_{reinterpret_cast(data)}, dims_{Dims} { + for (size_t i = 0; i < (size_t) dims_; i++) { + shape_[i] = shape[i]; + strides_[i] = strides[i]; + } + } + // Constructor for runtime dimensions: + template + unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t dims) + : data_{reinterpret_cast(data)}, shape_{shape}, strides_{strides}, dims_{dims} {} + +public: + /** + * Unchecked const reference access to data at the given indices. For a compile-time known + * number of dimensions, this requires the correct number of arguments; for run-time + * dimensionality, this is not checked (and so is up to the caller to use safely). + */ + template const T &operator()(Ix... index) const { + static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic, + "Invalid number of indices for unchecked array reference"); + return *reinterpret_cast(data_ + byte_offset_unsafe(strides_, ssize_t(index)...)); + } + /** + * Unchecked const reference access to data; this operator only participates if the reference + * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`. + */ + template > + const T &operator[](ssize_t index) const { return operator()(index); } + + /// Pointer access to the data at the given indices. + template const T *data(Ix... ix) const { return &operator()(ssize_t(ix)...); } + + /// Returns the item size, i.e. sizeof(T) + constexpr static ssize_t itemsize() { return sizeof(T); } + + /// Returns the shape (i.e. size) of dimension `dim` + ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; } + + /// Returns the number of dimensions of the array + ssize_t ndim() const { return dims_; } + + /// Returns the total number of elements in the referenced array, i.e. the product of the shapes + template + enable_if_t size() const { + return std::accumulate(shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies()); + } + template + enable_if_t size() const { + return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies()); + } + + /// Returns the total number of bytes used by the referenced data. Note that the actual span in + /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice). + ssize_t nbytes() const { + return size() * itemsize(); + } +}; + +template +class unchecked_mutable_reference : public unchecked_reference { + friend class pybind11::array; + using ConstBase = unchecked_reference; + using ConstBase::ConstBase; + using ConstBase::Dynamic; +public: + /// Mutable, unchecked access to data at the given indices. + template T& operator()(Ix... index) { + static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic, + "Invalid number of indices for unchecked array reference"); + return const_cast(ConstBase::operator()(index...)); + } + /** + * Mutable, unchecked access data at the given index; this operator only participates if the + * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is + * exactly equivalent to `obj(index)`. + */ + template > + T &operator[](ssize_t index) { return operator()(index); } + + /// Mutable pointer access to the data at the given indices. + template T *mutable_data(Ix... ix) { return &operator()(ssize_t(ix)...); } +}; + +template +struct type_caster> { + static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable"); +}; +template +struct type_caster> : type_caster> {}; + +NAMESPACE_END(detail) + +class dtype : public object { +public: + PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_); + + explicit dtype(const buffer_info &info) { + dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format))); + // If info.itemsize == 0, use the value calculated from the format string + m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr(); + } + + explicit dtype(const std::string &format) { + m_ptr = from_args(pybind11::str(format)).release().ptr(); + } + + dtype(const char *format) : dtype(std::string(format)) { } + + dtype(list names, list formats, list offsets, ssize_t itemsize) { + dict args; + args["names"] = names; + args["formats"] = formats; + args["offsets"] = offsets; + args["itemsize"] = pybind11::int_(itemsize); + m_ptr = from_args(args).release().ptr(); + } + + /// This is essentially the same as calling numpy.dtype(args) in Python. + static dtype from_args(object args) { + PyObject *ptr = nullptr; + if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr) + throw error_already_set(); + return reinterpret_steal(ptr); + } + + /// Return dtype associated with a C++ type. + template static dtype of() { + return detail::npy_format_descriptor::type>::dtype(); + } + + /// Size of the data type in bytes. + ssize_t itemsize() const { + return detail::array_descriptor_proxy(m_ptr)->elsize; + } + + /// Returns true for structured data types. + bool has_fields() const { + return detail::array_descriptor_proxy(m_ptr)->names != nullptr; + } + + /// Single-character type code. + char kind() const { + return detail::array_descriptor_proxy(m_ptr)->kind; + } + +private: + static object _dtype_from_pep3118() { + static PyObject *obj = module::import("numpy.core._internal") + .attr("_dtype_from_pep3118").cast().release().ptr(); + return reinterpret_borrow(obj); + } + + dtype strip_padding(ssize_t itemsize) { + // Recursively strip all void fields with empty names that are generated for + // padding fields (as of NumPy v1.11). + if (!has_fields()) + return *this; + + struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; }; + std::vector field_descriptors; + + for (auto field : attr("fields").attr("items")()) { + auto spec = field.cast(); + auto name = spec[0].cast(); + auto format = spec[1].cast()[0].cast(); + auto offset = spec[1].cast()[1].cast(); + if (!len(name) && format.kind() == 'V') + continue; + field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset}); + } + + std::sort(field_descriptors.begin(), field_descriptors.end(), + [](const field_descr& a, const field_descr& b) { + return a.offset.cast() < b.offset.cast(); + }); + + list names, formats, offsets; + for (auto& descr : field_descriptors) { + names.append(descr.name); + formats.append(descr.format); + offsets.append(descr.offset); + } + return dtype(names, formats, offsets, itemsize); + } +}; + +class array : public buffer { +public: + PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array) + + enum { + c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_, + f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_, + forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ + }; + + array() : array({{0}}, static_cast(nullptr)) {} + + using ShapeContainer = detail::any_container; + using StridesContainer = detail::any_container; + + // Constructs an array taking shape/strides from arbitrary container types + array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides, + const void *ptr = nullptr, handle base = handle()) { + + if (strides->empty()) + *strides = c_strides(*shape, dt.itemsize()); + + auto ndim = shape->size(); + if (ndim != strides->size()) + pybind11_fail("NumPy: shape ndim doesn't match strides ndim"); + auto descr = dt; + + int flags = 0; + if (base && ptr) { + if (isinstance(base)) + /* Copy flags from base (except ownership bit) */ + flags = reinterpret_borrow(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_; + else + /* Writable by default, easy to downgrade later on if needed */ + flags = detail::npy_api::NPY_ARRAY_WRITEABLE_; + } + + auto &api = detail::npy_api::get(); + auto tmp = reinterpret_steal(api.PyArray_NewFromDescr_( + api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(), + const_cast(ptr), flags, nullptr)); + if (!tmp) + throw error_already_set(); + if (ptr) { + if (base) { + api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr()); + } else { + tmp = reinterpret_steal(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */)); + } + } + m_ptr = tmp.release().ptr(); + } + + array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle()) + : array(dt, std::move(shape), {}, ptr, base) { } + + template ::value && !std::is_same::value>> + array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle()) + : array(dt, {{count}}, ptr, base) { } + + template + array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle()) + : array(pybind11::dtype::of(), std::move(shape), std::move(strides), ptr, base) { } + + template + array(ShapeContainer shape, const T *ptr, handle base = handle()) + : array(std::move(shape), {}, ptr, base) { } + + template + explicit array(ssize_t count, const T *ptr, handle base = handle()) : array({count}, {}, ptr, base) { } + + explicit array(const buffer_info &info) + : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } + + /// Array descriptor (dtype) + pybind11::dtype dtype() const { + return reinterpret_borrow(detail::array_proxy(m_ptr)->descr); + } + + /// Total number of elements + ssize_t size() const { + return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies()); + } + + /// Byte size of a single element + ssize_t itemsize() const { + return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize; + } + + /// Total number of bytes + ssize_t nbytes() const { + return size() * itemsize(); + } + + /// Number of dimensions + ssize_t ndim() const { + return detail::array_proxy(m_ptr)->nd; + } + + /// Base object + object base() const { + return reinterpret_borrow(detail::array_proxy(m_ptr)->base); + } + + /// Dimensions of the array + const ssize_t* shape() const { + return detail::array_proxy(m_ptr)->dimensions; + } + + /// Dimension along a given axis + ssize_t shape(ssize_t dim) const { + if (dim >= ndim()) + fail_dim_check(dim, "invalid axis"); + return shape()[dim]; + } + + /// Strides of the array + const ssize_t* strides() const { + return detail::array_proxy(m_ptr)->strides; + } + + /// Stride along a given axis + ssize_t strides(ssize_t dim) const { + if (dim >= ndim()) + fail_dim_check(dim, "invalid axis"); + return strides()[dim]; + } + + /// Return the NumPy array flags + int flags() const { + return detail::array_proxy(m_ptr)->flags; + } + + /// If set, the array is writeable (otherwise the buffer is read-only) + bool writeable() const { + return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); + } + + /// If set, the array owns the data (will be freed when the array is deleted) + bool owndata() const { + return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); + } + + /// Pointer to the contained data. If index is not provided, points to the + /// beginning of the buffer. May throw if the index would lead to out of bounds access. + template const void* data(Ix... index) const { + return static_cast(detail::array_proxy(m_ptr)->data + offset_at(index...)); + } + + /// Mutable pointer to the contained data. If index is not provided, points to the + /// beginning of the buffer. May throw if the index would lead to out of bounds access. + /// May throw if the array is not writeable. + template void* mutable_data(Ix... index) { + check_writeable(); + return static_cast(detail::array_proxy(m_ptr)->data + offset_at(index...)); + } + + /// Byte offset from beginning of the array to a given index (full or partial). + /// May throw if the index would lead to out of bounds access. + template ssize_t offset_at(Ix... index) const { + if ((ssize_t) sizeof...(index) > ndim()) + fail_dim_check(sizeof...(index), "too many indices for an array"); + return byte_offset(ssize_t(index)...); + } + + ssize_t offset_at() const { return 0; } + + /// Item count from beginning of the array to a given index (full or partial). + /// May throw if the index would lead to out of bounds access. + template ssize_t index_at(Ix... index) const { + return offset_at(index...) / itemsize(); + } + + /** + * Returns a proxy object that provides access to the array's data without bounds or + * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with + * care: the array must not be destroyed or reshaped for the duration of the returned object, + * and the caller must take care not to access invalid dimensions or dimension indices. + */ + template detail::unchecked_mutable_reference mutable_unchecked() & { + if (Dims >= 0 && ndim() != Dims) + throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + + "; expected " + std::to_string(Dims)); + return detail::unchecked_mutable_reference(mutable_data(), shape(), strides(), ndim()); + } + + /** + * Returns a proxy object that provides const access to the array's data without bounds or + * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the + * underlying array have the `writable` flag. Use with care: the array must not be destroyed or + * reshaped for the duration of the returned object, and the caller must take care not to access + * invalid dimensions or dimension indices. + */ + template detail::unchecked_reference unchecked() const & { + if (Dims >= 0 && ndim() != Dims) + throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + + "; expected " + std::to_string(Dims)); + return detail::unchecked_reference(data(), shape(), strides(), ndim()); + } + + /// Return a new view with all of the dimensions of length 1 removed + array squeeze() { + auto& api = detail::npy_api::get(); + return reinterpret_steal(api.PyArray_Squeeze_(m_ptr)); + } + + /// Resize array to given shape + /// If refcheck is true and more that one reference exist to this array + /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change + void resize(ShapeContainer new_shape, bool refcheck = true) { + detail::npy_api::PyArray_Dims d = { + new_shape->data(), int(new_shape->size()) + }; + // try to resize, set ordering param to -1 cause it's not used anyway + object new_array = reinterpret_steal( + detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1) + ); + if (!new_array) throw error_already_set(); + if (isinstance(new_array)) { *this = std::move(new_array); } + } + + /// Ensure that the argument is a NumPy array + /// In case of an error, nullptr is returned and the Python error is cleared. + static array ensure(handle h, int ExtraFlags = 0) { + auto result = reinterpret_steal(raw_array(h.ptr(), ExtraFlags)); + if (!result) + PyErr_Clear(); + return result; + } + +protected: + template friend struct detail::npy_format_descriptor; + + void fail_dim_check(ssize_t dim, const std::string& msg) const { + throw index_error(msg + ": " + std::to_string(dim) + + " (ndim = " + std::to_string(ndim()) + ")"); + } + + template ssize_t byte_offset(Ix... index) const { + check_dimensions(index...); + return detail::byte_offset_unsafe(strides(), ssize_t(index)...); + } + + void check_writeable() const { + if (!writeable()) + throw std::domain_error("array is not writeable"); + } + + // Default, C-style strides + static std::vector c_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + for (size_t i = ndim - 1; i > 0; --i) + strides[i - 1] = strides[i] * shape[i]; + return strides; + } + + // F-style strides; default when constructing an array_t with `ExtraFlags & f_style` + static std::vector f_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + for (size_t i = 1; i < ndim; ++i) + strides[i] = strides[i - 1] * shape[i - 1]; + return strides; + } + + template void check_dimensions(Ix... index) const { + check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...); + } + + void check_dimensions_impl(ssize_t, const ssize_t*) const { } + + template void check_dimensions_impl(ssize_t axis, const ssize_t* shape, ssize_t i, Ix... index) const { + if (i >= *shape) { + throw index_error(std::string("index ") + std::to_string(i) + + " is out of bounds for axis " + std::to_string(axis) + + " with size " + std::to_string(*shape)); + } + check_dimensions_impl(axis + 1, shape + 1, index...); + } + + /// Create array from any object -- always returns a new reference + static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) { + if (ptr == nullptr) { + PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr"); + return nullptr; + } + return detail::npy_api::get().PyArray_FromAny_( + ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); + } +}; + +template class array_t : public array { +private: + struct private_ctor {}; + // Delegating constructor needed when both moving and accessing in the same constructor + array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base) + : array(std::move(shape), std::move(strides), ptr, base) {} +public: + static_assert(!detail::array_info::is_array, "Array types cannot be used with array_t"); + + using value_type = T; + + array_t() : array(0, static_cast(nullptr)) {} + array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { } + array_t(handle h, stolen_t) : array(h, stolen_t{}) { } + + PYBIND11_DEPRECATED("Use array_t::ensure() instead") + array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) { + if (!m_ptr) PyErr_Clear(); + if (!is_borrowed) Py_XDECREF(h.ptr()); + } + + array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) { + if (!m_ptr) throw error_already_set(); + } + + explicit array_t(const buffer_info& info) : array(info) { } + + array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle()) + : array(std::move(shape), std::move(strides), ptr, base) { } + + explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle()) + : array_t(private_ctor{}, std::move(shape), + ExtraFlags & f_style ? f_strides(*shape, itemsize()) : c_strides(*shape, itemsize()), + ptr, base) { } + + explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle()) + : array({count}, {}, ptr, base) { } + + constexpr ssize_t itemsize() const { + return sizeof(T); + } + + template ssize_t index_at(Ix... index) const { + return offset_at(index...) / itemsize(); + } + + template const T* data(Ix... index) const { + return static_cast(array::data(index...)); + } + + template T* mutable_data(Ix... index) { + return static_cast(array::mutable_data(index...)); + } + + // Reference to element at a given index + template const T& at(Ix... index) const { + if (sizeof...(index) != ndim()) + fail_dim_check(sizeof...(index), "index dimension mismatch"); + return *(static_cast(array::data()) + byte_offset(ssize_t(index)...) / itemsize()); + } + + // Mutable reference to element at a given index + template T& mutable_at(Ix... index) { + if (sizeof...(index) != ndim()) + fail_dim_check(sizeof...(index), "index dimension mismatch"); + return *(static_cast(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize()); + } + + /** + * Returns a proxy object that provides access to the array's data without bounds or + * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with + * care: the array must not be destroyed or reshaped for the duration of the returned object, + * and the caller must take care not to access invalid dimensions or dimension indices. + */ + template detail::unchecked_mutable_reference mutable_unchecked() & { + return array::mutable_unchecked(); + } + + /** + * Returns a proxy object that provides const access to the array's data without bounds or + * dimensionality checking. Unlike `unchecked()`, this does not require that the underlying + * array have the `writable` flag. Use with care: the array must not be destroyed or reshaped + * for the duration of the returned object, and the caller must take care not to access invalid + * dimensions or dimension indices. + */ + template detail::unchecked_reference unchecked() const & { + return array::unchecked(); + } + + /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert + /// it). In case of an error, nullptr is returned and the Python error is cleared. + static array_t ensure(handle h) { + auto result = reinterpret_steal(raw_array_t(h.ptr())); + if (!result) + PyErr_Clear(); + return result; + } + + static bool check_(handle h) { + const auto &api = detail::npy_api::get(); + return api.PyArray_Check_(h.ptr()) + && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of().ptr()); + } + +protected: + /// Create array from any object -- always returns a new reference + static PyObject *raw_array_t(PyObject *ptr) { + if (ptr == nullptr) { + PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr"); + return nullptr; + } + return detail::npy_api::get().PyArray_FromAny_( + ptr, dtype::of().release().ptr(), 0, 0, + detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); + } +}; + +template +struct format_descriptor::value>> { + static std::string format() { + return detail::npy_format_descriptor::type>::format(); + } +}; + +template struct format_descriptor { + static std::string format() { return std::to_string(N) + "s"; } +}; +template struct format_descriptor> { + static std::string format() { return std::to_string(N) + "s"; } +}; + +template +struct format_descriptor::value>> { + static std::string format() { + return format_descriptor< + typename std::remove_cv::type>::type>::format(); + } +}; + +template +struct format_descriptor::is_array>> { + static std::string format() { + using namespace detail; + PYBIND11_DESCR extents = _("(") + array_info::extents() + _(")"); + return extents.text() + format_descriptor>::format(); + } +}; + +NAMESPACE_BEGIN(detail) +template +struct pyobject_caster> { + using type = array_t; + + bool load(handle src, bool convert) { + if (!convert && !type::check_(src)) + return false; + value = type::ensure(src); + return static_cast(value); + } + + static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) { + return src.inc_ref(); + } + PYBIND11_TYPE_CASTER(type, handle_type_name::name()); +}; + +template +struct compare_buffer_info::value>> { + static bool compare(const buffer_info& b) { + return npy_api::get().PyArray_EquivTypes_(dtype::of().ptr(), dtype(b).ptr()); + } +}; + +template struct npy_format_descriptor::value>> { +private: + // NB: the order here must match the one in common.h + constexpr static const int values[15] = { + npy_api::NPY_BOOL_, + npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_, + npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_, + npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_, + npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_ + }; + +public: + static constexpr int value = values[detail::is_fmt_numeric::index]; + + static pybind11::dtype dtype() { + if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) + return reinterpret_borrow(ptr); + pybind11_fail("Unsupported buffer format!"); + } + template ::value, int> = 0> + static PYBIND11_DESCR name() { + return _::value>(_("bool"), + _::value>("int", "uint") + _()); + } + template ::value, int> = 0> + static PYBIND11_DESCR name() { + return _::value || std::is_same::value>( + _("float") + _(), _("longdouble")); + } + template ::value, int> = 0> + static PYBIND11_DESCR name() { + return _::value || std::is_same::value>( + _("complex") + _(), _("longcomplex")); + } +}; + +#define PYBIND11_DECL_CHAR_FMT \ + static PYBIND11_DESCR name() { return _("S") + _(); } \ + static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); } +template struct npy_format_descriptor { PYBIND11_DECL_CHAR_FMT }; +template struct npy_format_descriptor> { PYBIND11_DECL_CHAR_FMT }; +#undef PYBIND11_DECL_CHAR_FMT + +template struct npy_format_descriptor::is_array>> { +private: + using base_descr = npy_format_descriptor::type>; +public: + static_assert(!array_info::is_empty, "Zero-sized arrays are not supported"); + + static PYBIND11_DESCR name() { return _("(") + array_info::extents() + _(")") + base_descr::name(); } + static pybind11::dtype dtype() { + list shape; + array_info::append_extents(shape); + return pybind11::dtype::from_args(pybind11::make_tuple(base_descr::dtype(), shape)); + } +}; + +template struct npy_format_descriptor::value>> { +private: + using base_descr = npy_format_descriptor::type>; +public: + static PYBIND11_DESCR name() { return base_descr::name(); } + static pybind11::dtype dtype() { return base_descr::dtype(); } +}; + +struct field_descriptor { + const char *name; + ssize_t offset; + ssize_t size; + std::string format; + dtype descr; +}; + +inline PYBIND11_NOINLINE void register_structured_dtype( + const std::initializer_list& fields, + const std::type_info& tinfo, ssize_t itemsize, + bool (*direct_converter)(PyObject *, void *&)) { + + auto& numpy_internals = get_numpy_internals(); + if (numpy_internals.get_type_info(tinfo, false)) + pybind11_fail("NumPy: dtype is already registered"); + + list names, formats, offsets; + for (auto field : fields) { + if (!field.descr) + pybind11_fail(std::string("NumPy: unsupported field dtype: `") + + field.name + "` @ " + tinfo.name()); + names.append(PYBIND11_STR_TYPE(field.name)); + formats.append(field.descr); + offsets.append(pybind11::int_(field.offset)); + } + auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr(); + + // There is an existing bug in NumPy (as of v1.11): trailing bytes are + // not encoded explicitly into the format string. This will supposedly + // get fixed in v1.12; for further details, see these: + // - https://github.com/numpy/numpy/issues/7797 + // - https://github.com/numpy/numpy/pull/7798 + // Because of this, we won't use numpy's logic to generate buffer format + // strings and will just do it ourselves. + std::vector ordered_fields(fields); + std::sort(ordered_fields.begin(), ordered_fields.end(), + [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); + ssize_t offset = 0; + std::ostringstream oss; + // mark the structure as unaligned with '^', because numpy and C++ don't + // always agree about alignment (particularly for complex), and we're + // explicitly listing all our padding. This depends on none of the fields + // overriding the endianness. Putting the ^ in front of individual fields + // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049 + oss << "^T{"; + for (auto& field : ordered_fields) { + if (field.offset > offset) + oss << (field.offset - offset) << 'x'; + oss << field.format << ':' << field.name << ':'; + offset = field.offset + field.size; + } + if (itemsize > offset) + oss << (itemsize - offset) << 'x'; + oss << '}'; + auto format_str = oss.str(); + + // Sanity check: verify that NumPy properly parses our buffer format string + auto& api = npy_api::get(); + auto arr = array(buffer_info(nullptr, itemsize, format_str, 1)); + if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) + pybind11_fail("NumPy: invalid buffer descriptor!"); + + auto tindex = std::type_index(tinfo); + numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; + get_internals().direct_conversions[tindex].push_back(direct_converter); +} + +template struct npy_format_descriptor { + static_assert(is_pod_struct::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype"); + + static PYBIND11_DESCR name() { return make_caster::name(); } + + static pybind11::dtype dtype() { + return reinterpret_borrow(dtype_ptr()); + } + + static std::string format() { + static auto format_str = get_numpy_internals().get_type_info(true)->format_str; + return format_str; + } + + static void register_dtype(const std::initializer_list& fields) { + register_structured_dtype(fields, typeid(typename std::remove_cv::type), + sizeof(T), &direct_converter); + } + +private: + static PyObject* dtype_ptr() { + static PyObject* ptr = get_numpy_internals().get_type_info(true)->dtype_ptr; + return ptr; + } + + static bool direct_converter(PyObject *obj, void*& value) { + auto& api = npy_api::get(); + if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) + return false; + if (auto descr = reinterpret_steal(api.PyArray_DescrFromScalar_(obj))) { + if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) { + value = ((PyVoidScalarObject_Proxy *) obj)->obval; + return true; + } + } + return false; + } +}; + +#ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code) +# define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0) +# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0) +#else + +#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \ + ::pybind11::detail::field_descriptor { \ + Name, offsetof(T, Field), sizeof(decltype(std::declval().Field)), \ + ::pybind11::format_descriptor().Field)>::format(), \ + ::pybind11::detail::npy_format_descriptor().Field)>::dtype() \ + } + +// Extract name, offset and format descriptor for a struct field +#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field) + +// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro +// (C) William Swanson, Paul Fultz +#define PYBIND11_EVAL0(...) __VA_ARGS__ +#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__))) +#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__))) +#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__))) +#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__))) +#define PYBIND11_EVAL(...) PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__))) +#define PYBIND11_MAP_END(...) +#define PYBIND11_MAP_OUT +#define PYBIND11_MAP_COMMA , +#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END +#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT +#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0) +#define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next) +#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround +#define PYBIND11_MAP_LIST_NEXT1(test, next) \ + PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)) +#else +#define PYBIND11_MAP_LIST_NEXT1(test, next) \ + PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0) +#endif +#define PYBIND11_MAP_LIST_NEXT(test, next) \ + PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next) +#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \ + f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__) +#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \ + f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__) +// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ... +#define PYBIND11_MAP_LIST(f, t, ...) \ + PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0)) + +#define PYBIND11_NUMPY_DTYPE(Type, ...) \ + ::pybind11::detail::npy_format_descriptor::register_dtype \ + ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)}) + +#ifdef _MSC_VER +#define PYBIND11_MAP2_LIST_NEXT1(test, next) \ + PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)) +#else +#define PYBIND11_MAP2_LIST_NEXT1(test, next) \ + PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0) +#endif +#define PYBIND11_MAP2_LIST_NEXT(test, next) \ + PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next) +#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \ + f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__) +#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \ + f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__) +// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ... +#define PYBIND11_MAP2_LIST(f, t, ...) \ + PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0)) + +#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \ + ::pybind11::detail::npy_format_descriptor::register_dtype \ + ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)}) + +#endif // __CLION_IDE__ + +template +using array_iterator = typename std::add_pointer::type; + +template +array_iterator array_begin(const buffer_info& buffer) { + return array_iterator(reinterpret_cast(buffer.ptr)); +} + +template +array_iterator array_end(const buffer_info& buffer) { + return array_iterator(reinterpret_cast(buffer.ptr) + buffer.size); +} + +class common_iterator { +public: + using container_type = std::vector; + using value_type = container_type::value_type; + using size_type = container_type::size_type; + + common_iterator() : p_ptr(0), m_strides() {} + + common_iterator(void* ptr, const container_type& strides, const container_type& shape) + : p_ptr(reinterpret_cast(ptr)), m_strides(strides.size()) { + m_strides.back() = static_cast(strides.back()); + for (size_type i = m_strides.size() - 1; i != 0; --i) { + size_type j = i - 1; + value_type s = static_cast(shape[i]); + m_strides[j] = strides[j] + m_strides[i] - strides[i] * s; + } + } + + void increment(size_type dim) { + p_ptr += m_strides[dim]; + } + + void* data() const { + return p_ptr; + } + +private: + char* p_ptr; + container_type m_strides; +}; + +template class multi_array_iterator { +public: + using container_type = std::vector; + + multi_array_iterator(const std::array &buffers, + const container_type &shape) + : m_shape(shape.size()), m_index(shape.size(), 0), + m_common_iterator() { + + // Manual copy to avoid conversion warning if using std::copy + for (size_t i = 0; i < shape.size(); ++i) + m_shape[i] = shape[i]; + + container_type strides(shape.size()); + for (size_t i = 0; i < N; ++i) + init_common_iterator(buffers[i], shape, m_common_iterator[i], strides); + } + + multi_array_iterator& operator++() { + for (size_t j = m_index.size(); j != 0; --j) { + size_t i = j - 1; + if (++m_index[i] != m_shape[i]) { + increment_common_iterator(i); + break; + } else { + m_index[i] = 0; + } + } + return *this; + } + + template T* data() const { + return reinterpret_cast(m_common_iterator[K].data()); + } + +private: + + using common_iter = common_iterator; + + void init_common_iterator(const buffer_info &buffer, + const container_type &shape, + common_iter &iterator, + container_type &strides) { + auto buffer_shape_iter = buffer.shape.rbegin(); + auto buffer_strides_iter = buffer.strides.rbegin(); + auto shape_iter = shape.rbegin(); + auto strides_iter = strides.rbegin(); + + while (buffer_shape_iter != buffer.shape.rend()) { + if (*shape_iter == *buffer_shape_iter) + *strides_iter = *buffer_strides_iter; + else + *strides_iter = 0; + + ++buffer_shape_iter; + ++buffer_strides_iter; + ++shape_iter; + ++strides_iter; + } + + std::fill(strides_iter, strides.rend(), 0); + iterator = common_iter(buffer.ptr, strides, shape); + } + + void increment_common_iterator(size_t dim) { + for (auto &iter : m_common_iterator) + iter.increment(dim); + } + + container_type m_shape; + container_type m_index; + std::array m_common_iterator; +}; + +enum class broadcast_trivial { non_trivial, c_trivial, f_trivial }; + +// Populates the shape and number of dimensions for the set of buffers. Returns a broadcast_trivial +// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a +// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage +// buffer; returns `non_trivial` otherwise. +template +broadcast_trivial broadcast(const std::array &buffers, ssize_t &ndim, std::vector &shape) { + ndim = std::accumulate(buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) { + return std::max(res, buf.ndim); + }); + + shape.clear(); + shape.resize((size_t) ndim, 1); + + // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or + // the full size). + for (size_t i = 0; i < N; ++i) { + auto res_iter = shape.rbegin(); + auto end = buffers[i].shape.rend(); + for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) { + const auto &dim_size_in = *shape_iter; + auto &dim_size_out = *res_iter; + + // Each input dimension can either be 1 or `n`, but `n` values must match across buffers + if (dim_size_out == 1) + dim_size_out = dim_size_in; + else if (dim_size_in != 1 && dim_size_in != dim_size_out) + pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!"); + } + } + + bool trivial_broadcast_c = true; + bool trivial_broadcast_f = true; + for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) { + if (buffers[i].size == 1) + continue; + + // Require the same number of dimensions: + if (buffers[i].ndim != ndim) + return broadcast_trivial::non_trivial; + + // Require all dimensions be full-size: + if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) + return broadcast_trivial::non_trivial; + + // Check for C contiguity (but only if previous inputs were also C contiguous) + if (trivial_broadcast_c) { + ssize_t expect_stride = buffers[i].itemsize; + auto end = buffers[i].shape.crend(); + for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin(); + trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) { + if (expect_stride == *stride_iter) + expect_stride *= *shape_iter; + else + trivial_broadcast_c = false; + } + } + + // Check for Fortran contiguity (if previous inputs were also F contiguous) + if (trivial_broadcast_f) { + ssize_t expect_stride = buffers[i].itemsize; + auto end = buffers[i].shape.cend(); + for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin(); + trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) { + if (expect_stride == *stride_iter) + expect_stride *= *shape_iter; + else + trivial_broadcast_f = false; + } + } + } + + return + trivial_broadcast_c ? broadcast_trivial::c_trivial : + trivial_broadcast_f ? broadcast_trivial::f_trivial : + broadcast_trivial::non_trivial; +} + +template +struct vectorize_arg { + static_assert(!std::is_rvalue_reference::value, "Functions with rvalue reference arguments cannot be vectorized"); + // The wrapped function gets called with this type: + using call_type = remove_reference_t; + // Is this a vectorized argument? + static constexpr bool vectorize = + satisfies_any_of::value && + satisfies_none_of::value && + (!std::is_reference::value || + (std::is_lvalue_reference::value && std::is_const::value)); + // Accept this type: an array for vectorized types, otherwise the type as-is: + using type = conditional_t, array::forcecast>, T>; +}; + +template +struct vectorize_helper { +private: + static constexpr size_t N = sizeof...(Args); + static constexpr size_t NVectorized = constexpr_sum(vectorize_arg::vectorize...); + static_assert(NVectorized >= 1, + "pybind11::vectorize(...) requires a function with at least one vectorizable argument"); + +public: + template + explicit vectorize_helper(T &&f) : f(std::forward(f)) { } + + object operator()(typename vectorize_arg::type... args) { + return run(args..., + make_index_sequence(), + select_indices::vectorize...>(), + make_index_sequence()); + } + +private: + remove_reference_t f; + + template using param_n_t = typename pack_element::call_type...>::type; + + // Runs a vectorized function given arguments tuple and three index sequences: + // - Index is the full set of 0 ... (N-1) argument indices; + // - VIndex is the subset of argument indices with vectorized parameters, letting us access + // vectorized arguments (anything not in this sequence is passed through) + // - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that + // we can store vectorized buffer_infos in an array (argument VIndex has its buffer at + // index BIndex in the array). + template object run( + typename vectorize_arg::type &...args, + index_sequence i_seq, index_sequence vi_seq, index_sequence bi_seq) { + + // Pointers to values the function was called with; the vectorized ones set here will start + // out as array_t pointers, but they will be changed them to T pointers before we make + // call the wrapped function. Non-vectorized pointers are left as-is. + std::array params{{ &args... }}; + + // The array of `buffer_info`s of vectorized arguments: + std::array buffers{{ reinterpret_cast(params[VIndex])->request()... }}; + + /* Determine dimensions parameters of output array */ + ssize_t nd = 0; + std::vector shape(0); + auto trivial = broadcast(buffers, nd, shape); + size_t ndim = (size_t) nd; + + size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies()); + + // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e. + // not wrapped in an array). + if (size == 1 && ndim == 0) { + PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr); + return cast(f(*reinterpret_cast *>(params[Index])...)); + } + + array_t result; + if (trivial == broadcast_trivial::f_trivial) result = array_t(shape); + else result = array_t(shape); + + if (size == 0) return result; + + /* Call the function */ + if (trivial == broadcast_trivial::non_trivial) + apply_broadcast(buffers, params, result, i_seq, vi_seq, bi_seq); + else + apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq); + + return result; + } + + template + void apply_trivial(std::array &buffers, + std::array ¶ms, + Return *out, + size_t size, + index_sequence, index_sequence, index_sequence) { + + // Initialize an array of mutable byte references and sizes with references set to the + // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size + // (except for singletons, which get an increment of 0). + std::array, NVectorized> vecparams{{ + std::pair( + reinterpret_cast(params[VIndex] = buffers[BIndex].ptr), + buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t) + )... + }}; + + for (size_t i = 0; i < size; ++i) { + out[i] = f(*reinterpret_cast *>(params[Index])...); + for (auto &x : vecparams) x.first += x.second; + } + } + + template + void apply_broadcast(std::array &buffers, + std::array ¶ms, + array_t &output_array, + index_sequence, index_sequence, index_sequence) { + + buffer_info output = output_array.request(); + multi_array_iterator input_iter(buffers, output.shape); + + for (array_iterator iter = array_begin(output), end = array_end(output); + iter != end; + ++iter, ++input_iter) { + PYBIND11_EXPAND_SIDE_EFFECTS(( + params[VIndex] = input_iter.template data() + )); + *iter = f(*reinterpret_cast *>(std::get(params))...); + } + } +}; + +template +vectorize_helper +vectorize_extractor(const Func &f, Return (*) (Args ...)) { + return detail::vectorize_helper(f); +} + +template struct handle_type_name> { + static PYBIND11_DESCR name() { + return _("numpy.ndarray[") + npy_format_descriptor::name() + _("]"); + } +}; + +NAMESPACE_END(detail) + +// Vanilla pointer vectorizer: +template +detail::vectorize_helper +vectorize(Return (*f) (Args ...)) { + return detail::vectorize_helper(f); +} + +// lambda vectorizer: +template ::value, int> = 0> +auto vectorize(Func &&f) -> decltype( + detail::vectorize_extractor(std::forward(f), (detail::function_signature_t *) nullptr)) { + return detail::vectorize_extractor(std::forward(f), (detail::function_signature_t *) nullptr); +} + +// Vectorize a class method (non-const): +template ())), Return, Class *, Args...>> +Helper vectorize(Return (Class::*f)(Args...)) { + return Helper(std::mem_fn(f)); +} + +// Vectorize a class method (non-const): +template ())), Return, const Class *, Args...>> +Helper vectorize(Return (Class::*f)(Args...) const) { + return Helper(std::mem_fn(f)); +} + +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/operators.h b/ml/dlib/dlib/external/pybind11/include/pybind11/operators.h new file mode 100644 index 000000000..b3dd62c3b --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/operators.h @@ -0,0 +1,168 @@ +/* + pybind11/operator.h: Metatemplates for operator overloading + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" + +#if defined(__clang__) && !defined(__INTEL_COMPILER) +# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) +#elif defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +/// Enumeration with all supported operator types +enum op_id : int { + op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift, + op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert, + op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, + op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, + op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, + op_repr, op_truediv, op_itruediv, op_hash +}; + +enum op_type : int { + op_l, /* base type on left */ + op_r, /* base type on right */ + op_u /* unary operator */ +}; + +struct self_t { }; +static const self_t self = self_t(); + +/// Type for an unused type slot +struct undefined_t { }; + +/// Don't warn about an unused variable +inline self_t __self() { return self; } + +/// base template of operator implementations +template struct op_impl { }; + +/// Operator implementation generator +template struct op_ { + template void execute(Class &cl, const Extra&... extra) const { + using Base = typename Class::type; + using L_type = conditional_t::value, Base, L>; + using R_type = conditional_t::value, Base, R>; + using op = op_impl; + cl.def(op::name(), &op::execute, is_operator(), extra...); + #if PY_MAJOR_VERSION < 3 + if (id == op_truediv || id == op_itruediv) + cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", + &op::execute, is_operator(), extra...); + #endif + } + template void execute_cast(Class &cl, const Extra&... extra) const { + using Base = typename Class::type; + using L_type = conditional_t::value, Base, L>; + using R_type = conditional_t::value, Base, R>; + using op = op_impl; + cl.def(op::name(), &op::execute_cast, is_operator(), extra...); + #if PY_MAJOR_VERSION < 3 + if (id == op_truediv || id == op_itruediv) + cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", + &op::execute, is_operator(), extra...); + #endif + } +}; + +#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \ +template struct op_impl { \ + static char const* name() { return "__" #id "__"; } \ + static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \ + static B execute_cast(const L &l, const R &r) { return B(expr); } \ +}; \ +template struct op_impl { \ + static char const* name() { return "__" #rid "__"; } \ + static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \ + static B execute_cast(const R &r, const L &l) { return B(expr); } \ +}; \ +inline op_ op(const self_t &, const self_t &) { \ + return op_(); \ +} \ +template op_ op(const self_t &, const T &) { \ + return op_(); \ +} \ +template op_ op(const T &, const self_t &) { \ + return op_(); \ +} + +#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \ +template struct op_impl { \ + static char const* name() { return "__" #id "__"; } \ + static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \ + static B execute_cast(L &l, const R &r) { return B(expr); } \ +}; \ +template op_ op(const self_t &, const T &) { \ + return op_(); \ +} + +#define PYBIND11_UNARY_OPERATOR(id, op, expr) \ +template struct op_impl { \ + static char const* name() { return "__" #id "__"; } \ + static auto execute(const L &l) -> decltype(expr) { return expr; } \ + static B execute_cast(const L &l) { return B(expr); } \ +}; \ +inline op_ op(const self_t &) { \ + return op_(); \ +} + +PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r) +PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r) +PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r) +PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r) +PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r) +PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r) +PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r) +PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r) +PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r) +PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r) +PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r) +PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r) +PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r) +PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r) +PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r) +PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) +//PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r)) +PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) +PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) +PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) +PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) +PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) +PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) +PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) +PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r) +PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) +PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) +PYBIND11_UNARY_OPERATOR(neg, operator-, -l) +PYBIND11_UNARY_OPERATOR(pos, operator+, +l) +PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) +PYBIND11_UNARY_OPERATOR(hash, hash, std::hash()(l)) +PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) +PYBIND11_UNARY_OPERATOR(bool, operator!, !!l) +PYBIND11_UNARY_OPERATOR(int, int_, (int) l) +PYBIND11_UNARY_OPERATOR(float, float_, (double) l) + +#undef PYBIND11_BINARY_OPERATOR +#undef PYBIND11_INPLACE_OPERATOR +#undef PYBIND11_UNARY_OPERATOR +NAMESPACE_END(detail) + +using detail::self; + +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(_MSC_VER) +# pragma warning(pop) +#endif diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/options.h b/ml/dlib/dlib/external/pybind11/include/pybind11/options.h new file mode 100644 index 000000000..cc1e1f6f0 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/options.h @@ -0,0 +1,65 @@ +/* + pybind11/options.h: global settings that are configurable at runtime. + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "detail/common.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +class options { +public: + + // Default RAII constructor, which leaves settings as they currently are. + options() : previous_state(global_state()) {} + + // Class is non-copyable. + options(const options&) = delete; + options& operator=(const options&) = delete; + + // Destructor, which restores settings that were in effect before. + ~options() { + global_state() = previous_state; + } + + // Setter methods (affect the global state): + + options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } + + options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } + + options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } + + options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } + + // Getter methods (return the global state): + + static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } + + static bool show_function_signatures() { return global_state().show_function_signatures; } + + // This type is not meant to be allocated on the heap. + void* operator new(size_t) = delete; + +private: + + struct state { + bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. + bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. + }; + + static state &global_state() { + static state instance; + return instance; + } + + state previous_state; +}; + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/pybind11.h b/ml/dlib/dlib/external/pybind11/include/pybind11/pybind11.h new file mode 100644 index 000000000..7723d2a8e --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/pybind11.h @@ -0,0 +1,1963 @@ +/* + pybind11/pybind11.h: Main header file of the C++11 python + binding generator library + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4100) // warning C4100: Unreferenced formal parameter +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +# pragma warning(disable: 4512) // warning C4512: Assignment operator was implicitly defined as deleted +# pragma warning(disable: 4800) // warning C4800: 'int': forcing value to bool 'true' or 'false' (performance warning) +# pragma warning(disable: 4996) // warning C4996: The POSIX name for this item is deprecated. Instead, use the ISO C and C++ conformant name +# pragma warning(disable: 4702) // warning C4702: unreachable code +# pragma warning(disable: 4522) // warning C4522: multiple assignment operators specified +#elif defined(__INTEL_COMPILER) +# pragma warning(push) +# pragma warning(disable: 68) // integer conversion resulted in a change of sign +# pragma warning(disable: 186) // pointless comparison of unsigned integer with zero +# pragma warning(disable: 878) // incompatible exception specifications +# pragma warning(disable: 1334) // the "template" keyword used for syntactic disambiguation may only be used within a template +# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) +# pragma warning(disable: 1875) // offsetof applied to non-POD (Plain Old Data) types is nonstandard +# pragma warning(disable: 2196) // warning #2196: routine is both "inline" and "noinline" +#elif defined(__GNUG__) && !defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wunused-but-set-parameter" +# pragma GCC diagnostic ignored "-Wunused-but-set-variable" +# pragma GCC diagnostic ignored "-Wmissing-field-initializers" +# pragma GCC diagnostic ignored "-Wstrict-aliasing" +# pragma GCC diagnostic ignored "-Wattributes" +# if __GNUC__ >= 7 +# pragma GCC diagnostic ignored "-Wnoexcept-type" +# endif +#endif + +#include "attr.h" +#include "options.h" +#include "detail/class.h" +#include "detail/init.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +/// Wraps an arbitrary C++ function/method/lambda function/.. into a callable Python object +class cpp_function : public function { +public: + cpp_function() { } + + /// Construct a cpp_function from a vanilla function pointer + template + cpp_function(Return (*f)(Args...), const Extra&... extra) { + initialize(f, f, extra...); + } + + /// Construct a cpp_function from a lambda function (possibly with internal state) + template ::value>> + cpp_function(Func &&f, const Extra&... extra) { + initialize(std::forward(f), + (detail::function_signature_t *) nullptr, extra...); + } + + /// Construct a cpp_function from a class method (non-const) + template + cpp_function(Return (Class::*f)(Arg...), const Extra&... extra) { + initialize([f](Class *c, Arg... args) -> Return { return (c->*f)(args...); }, + (Return (*) (Class *, Arg...)) nullptr, extra...); + } + + /// Construct a cpp_function from a class method (const) + template + cpp_function(Return (Class::*f)(Arg...) const, const Extra&... extra) { + initialize([f](const Class *c, Arg... args) -> Return { return (c->*f)(args...); }, + (Return (*)(const Class *, Arg ...)) nullptr, extra...); + } + + /// Return the function name + object name() const { return attr("__name__"); } + +protected: + /// Space optimization: don't inline this frequently instantiated fragment + PYBIND11_NOINLINE detail::function_record *make_function_record() { + return new detail::function_record(); + } + + /// Special internal constructor for functors, lambda functions, etc. + template + void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) { + using namespace detail; + + struct capture { remove_reference_t f; }; + + /* Store the function including any extra state it might have (e.g. a lambda capture object) */ + auto rec = make_function_record(); + + /* Store the capture object directly in the function record if there is enough space */ + if (sizeof(capture) <= sizeof(rec->data)) { + /* Without these pragmas, GCC warns that there might not be + enough space to use the placement new operator. However, the + 'if' statement above ensures that this is the case. */ +#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6 +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wplacement-new" +#endif + new ((capture *) &rec->data) capture { std::forward(f) }; +#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6 +# pragma GCC diagnostic pop +#endif + if (!std::is_trivially_destructible::value) + rec->free_data = [](function_record *r) { ((capture *) &r->data)->~capture(); }; + } else { + rec->data[0] = new capture { std::forward(f) }; + rec->free_data = [](function_record *r) { delete ((capture *) r->data[0]); }; + } + + /* Type casters for the function arguments and return value */ + using cast_in = argument_loader; + using cast_out = make_caster< + conditional_t::value, void_type, Return> + >; + + static_assert(expected_num_args(sizeof...(Args), cast_in::has_args, cast_in::has_kwargs), + "The number of argument annotations does not match the number of function arguments"); + + /* Dispatch code which converts function arguments and performs the actual function call */ + rec->impl = [](function_call &call) -> handle { + cast_in args_converter; + + /* Try to cast the function arguments into the C++ domain */ + if (!args_converter.load_args(call)) + return PYBIND11_TRY_NEXT_OVERLOAD; + + /* Invoke call policy pre-call hook */ + process_attributes::precall(call); + + /* Get a pointer to the capture object */ + auto data = (sizeof(capture) <= sizeof(call.func.data) + ? &call.func.data : call.func.data[0]); + capture *cap = const_cast(reinterpret_cast(data)); + + /* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */ + const auto policy = return_value_policy_override::policy(call.func.policy); + + /* Function scope guard -- defaults to the compile-to-nothing `void_type` */ + using Guard = extract_guard_t; + + /* Perform the function call */ + handle result = cast_out::cast( + std::move(args_converter).template call(cap->f), policy, call.parent); + + /* Invoke call policy post-call hook */ + process_attributes::postcall(call, result); + + return result; + }; + + /* Process any user-provided function attributes */ + process_attributes::init(extra..., rec); + + /* Generate a readable signature describing the function's arguments and return value types */ + PYBIND11_DESCR signature = _("(") + cast_in::arg_names() + _(") -> ") + cast_out::name(); + + /* Register the function with Python from generic (non-templated) code */ + initialize_generic(rec, signature.text(), signature.types(), sizeof...(Args)); + + if (cast_in::has_args) rec->has_args = true; + if (cast_in::has_kwargs) rec->has_kwargs = true; + + /* Stash some additional information used by an important optimization in 'functional.h' */ + using FunctionType = Return (*)(Args...); + constexpr bool is_function_ptr = + std::is_convertible::value && + sizeof(capture) == sizeof(void *); + if (is_function_ptr) { + rec->is_stateless = true; + rec->data[1] = const_cast(reinterpret_cast(&typeid(FunctionType))); + } + } + + /// Register a function call with Python (generic non-templated code goes here) + void initialize_generic(detail::function_record *rec, const char *text, + const std::type_info *const *types, size_t args) { + + /* Create copies of all referenced C-style strings */ + rec->name = strdup(rec->name ? rec->name : ""); + if (rec->doc) rec->doc = strdup(rec->doc); + for (auto &a: rec->args) { + if (a.name) + a.name = strdup(a.name); + if (a.descr) + a.descr = strdup(a.descr); + else if (a.value) + a.descr = strdup(a.value.attr("__repr__")().cast().c_str()); + } + + rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__"); + +#if !defined(NDEBUG) && !defined(PYBIND11_DISABLE_NEW_STYLE_INIT_WARNING) + if (rec->is_constructor && !rec->is_new_style_constructor) { + const auto class_name = std::string(((PyTypeObject *) rec->scope.ptr())->tp_name); + const auto func_name = std::string(rec->name); + PyErr_WarnEx( + PyExc_FutureWarning, + ("pybind11-bound class '" + class_name + "' is using an old-style " + "placement-new '" + func_name + "' which has been deprecated. See " + "the upgrade guide in pybind11's docs. This message is only visible " + "when compiled in debug mode.").c_str(), 0 + ); + } +#endif + + /* Generate a proper function signature */ + std::string signature; + size_t type_depth = 0, char_index = 0, type_index = 0, arg_index = 0; + while (true) { + char c = text[char_index++]; + if (c == '\0') + break; + + if (c == '{') { + // Write arg name for everything except *args, **kwargs and return type. + if (type_depth == 0 && text[char_index] != '*' && arg_index < args) { + if (!rec->args.empty() && rec->args[arg_index].name) { + signature += rec->args[arg_index].name; + } else if (arg_index == 0 && rec->is_method) { + signature += "self"; + } else { + signature += "arg" + std::to_string(arg_index - (rec->is_method ? 1 : 0)); + } + signature += ": "; + } + ++type_depth; + } else if (c == '}') { + --type_depth; + if (type_depth == 0) { + if (arg_index < rec->args.size() && rec->args[arg_index].descr) { + signature += "="; + signature += rec->args[arg_index].descr; + } + arg_index++; + } + } else if (c == '%') { + const std::type_info *t = types[type_index++]; + if (!t) + pybind11_fail("Internal error while parsing type signature (1)"); + if (auto tinfo = detail::get_type_info(*t)) { + handle th((PyObject *) tinfo->type); + signature += + th.attr("__module__").cast() + "." + + th.attr("__qualname__").cast(); // Python 3.3+, but we backport it to earlier versions + } else if (rec->is_new_style_constructor && arg_index == 0) { + // A new-style `__init__` takes `self` as `value_and_holder`. + // Rewrite it to the proper class type. + signature += + rec->scope.attr("__module__").cast() + "." + + rec->scope.attr("__qualname__").cast(); + } else { + std::string tname(t->name()); + detail::clean_type_id(tname); + signature += tname; + } + } else { + signature += c; + } + } + if (type_depth != 0 || types[type_index] != nullptr) + pybind11_fail("Internal error while parsing type signature (2)"); + + #if !defined(PYBIND11_CONSTEXPR_DESCR) + delete[] types; + delete[] text; + #endif + +#if PY_MAJOR_VERSION < 3 + if (strcmp(rec->name, "__next__") == 0) { + std::free(rec->name); + rec->name = strdup("next"); + } else if (strcmp(rec->name, "__bool__") == 0) { + std::free(rec->name); + rec->name = strdup("__nonzero__"); + } +#endif + rec->signature = strdup(signature.c_str()); + rec->args.shrink_to_fit(); + rec->nargs = (std::uint16_t) args; + + if (rec->sibling && PYBIND11_INSTANCE_METHOD_CHECK(rec->sibling.ptr())) + rec->sibling = PYBIND11_INSTANCE_METHOD_GET_FUNCTION(rec->sibling.ptr()); + + detail::function_record *chain = nullptr, *chain_start = rec; + if (rec->sibling) { + if (PyCFunction_Check(rec->sibling.ptr())) { + auto rec_capsule = reinterpret_borrow(PyCFunction_GET_SELF(rec->sibling.ptr())); + chain = (detail::function_record *) rec_capsule; + /* Never append a method to an overload chain of a parent class; + instead, hide the parent's overloads in this case */ + if (!chain->scope.is(rec->scope)) + chain = nullptr; + } + // Don't trigger for things like the default __init__, which are wrapper_descriptors that we are intentionally replacing + else if (!rec->sibling.is_none() && rec->name[0] != '_') + pybind11_fail("Cannot overload existing non-function object \"" + std::string(rec->name) + + "\" with a function of the same name"); + } + + if (!chain) { + /* No existing overload was found, create a new function object */ + rec->def = new PyMethodDef(); + std::memset(rec->def, 0, sizeof(PyMethodDef)); + rec->def->ml_name = rec->name; + rec->def->ml_meth = reinterpret_cast(*dispatcher); + rec->def->ml_flags = METH_VARARGS | METH_KEYWORDS; + + capsule rec_capsule(rec, [](void *ptr) { + destruct((detail::function_record *) ptr); + }); + + object scope_module; + if (rec->scope) { + if (hasattr(rec->scope, "__module__")) { + scope_module = rec->scope.attr("__module__"); + } else if (hasattr(rec->scope, "__name__")) { + scope_module = rec->scope.attr("__name__"); + } + } + + m_ptr = PyCFunction_NewEx(rec->def, rec_capsule.ptr(), scope_module.ptr()); + if (!m_ptr) + pybind11_fail("cpp_function::cpp_function(): Could not allocate function object"); + } else { + /* Append at the end of the overload chain */ + m_ptr = rec->sibling.ptr(); + inc_ref(); + chain_start = chain; + if (chain->is_method != rec->is_method) + pybind11_fail("overloading a method with both static and instance methods is not supported; " + #if defined(NDEBUG) + "compile in debug mode for more details" + #else + "error while attempting to bind " + std::string(rec->is_method ? "instance" : "static") + " method " + + std::string(pybind11::str(rec->scope.attr("__name__"))) + "." + std::string(rec->name) + signature + #endif + ); + while (chain->next) + chain = chain->next; + chain->next = rec; + } + + std::string signatures; + int index = 0; + /* Create a nice pydoc rec including all signatures and + docstrings of the functions in the overload chain */ + if (chain && options::show_function_signatures()) { + // First a generic signature + signatures += rec->name; + signatures += "(*args, **kwargs)\n"; + signatures += "Overloaded function.\n\n"; + } + // Then specific overload signatures + bool first_user_def = true; + for (auto it = chain_start; it != nullptr; it = it->next) { + if (options::show_function_signatures()) { + if (index > 0) signatures += "\n"; + if (chain) + signatures += std::to_string(++index) + ". "; + signatures += rec->name; + signatures += it->signature; + signatures += "\n"; + } + if (it->doc && strlen(it->doc) > 0 && options::show_user_defined_docstrings()) { + // If we're appending another docstring, and aren't printing function signatures, we + // need to append a newline first: + if (!options::show_function_signatures()) { + if (first_user_def) first_user_def = false; + else signatures += "\n"; + } + if (options::show_function_signatures()) signatures += "\n"; + signatures += it->doc; + if (options::show_function_signatures()) signatures += "\n"; + } + } + + /* Install docstring */ + PyCFunctionObject *func = (PyCFunctionObject *) m_ptr; + if (func->m_ml->ml_doc) + std::free(const_cast(func->m_ml->ml_doc)); + func->m_ml->ml_doc = strdup(signatures.c_str()); + + if (rec->is_method) { + m_ptr = PYBIND11_INSTANCE_METHOD_NEW(m_ptr, rec->scope.ptr()); + if (!m_ptr) + pybind11_fail("cpp_function::cpp_function(): Could not allocate instance method object"); + Py_DECREF(func); + } + } + + /// When a cpp_function is GCed, release any memory allocated by pybind11 + static void destruct(detail::function_record *rec) { + while (rec) { + detail::function_record *next = rec->next; + if (rec->free_data) + rec->free_data(rec); + std::free((char *) rec->name); + std::free((char *) rec->doc); + std::free((char *) rec->signature); + for (auto &arg: rec->args) { + std::free(const_cast(arg.name)); + std::free(const_cast(arg.descr)); + arg.value.dec_ref(); + } + if (rec->def) { + std::free(const_cast(rec->def->ml_doc)); + delete rec->def; + } + delete rec; + rec = next; + } + } + + /// Main dispatch logic for calls to functions bound using pybind11 + static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) { + using namespace detail; + + /* Iterator over the list of potentially admissible overloads */ + function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr), + *it = overloads; + + /* Need to know how many arguments + keyword arguments there are to pick the right overload */ + const size_t n_args_in = (size_t) PyTuple_GET_SIZE(args_in); + + handle parent = n_args_in > 0 ? PyTuple_GET_ITEM(args_in, 0) : nullptr, + result = PYBIND11_TRY_NEXT_OVERLOAD; + + auto self_value_and_holder = value_and_holder(); + if (overloads->is_constructor) { + const auto tinfo = get_type_info((PyTypeObject *) overloads->scope.ptr()); + const auto pi = reinterpret_cast(parent.ptr()); + self_value_and_holder = pi->get_value_and_holder(tinfo, false); + + if (!self_value_and_holder.type || !self_value_and_holder.inst) { + PyErr_SetString(PyExc_TypeError, "__init__(self, ...) called with invalid `self` argument"); + return nullptr; + } + + // If this value is already registered it must mean __init__ is invoked multiple times; + // we really can't support that in C++, so just ignore the second __init__. + if (self_value_and_holder.instance_registered()) + return none().release().ptr(); + } + + try { + // We do this in two passes: in the first pass, we load arguments with `convert=false`; + // in the second, we allow conversion (except for arguments with an explicit + // py::arg().noconvert()). This lets us prefer calls without conversion, with + // conversion as a fallback. + std::vector second_pass; + + // However, if there are no overloads, we can just skip the no-convert pass entirely + const bool overloaded = it != nullptr && it->next != nullptr; + + for (; it != nullptr; it = it->next) { + + /* For each overload: + 1. Copy all positional arguments we were given, also checking to make sure that + named positional arguments weren't *also* specified via kwarg. + 2. If we weren't given enough, try to make up the omitted ones by checking + whether they were provided by a kwarg matching the `py::arg("name")` name. If + so, use it (and remove it from kwargs; if not, see if the function binding + provided a default that we can use. + 3. Ensure that either all keyword arguments were "consumed", or that the function + takes a kwargs argument to accept unconsumed kwargs. + 4. Any positional arguments still left get put into a tuple (for args), and any + leftover kwargs get put into a dict. + 5. Pack everything into a vector; if we have py::args or py::kwargs, they are an + extra tuple or dict at the end of the positional arguments. + 6. Call the function call dispatcher (function_record::impl) + + If one of these fail, move on to the next overload and keep trying until we get a + result other than PYBIND11_TRY_NEXT_OVERLOAD. + */ + + function_record &func = *it; + size_t pos_args = func.nargs; // Number of positional arguments that we need + if (func.has_args) --pos_args; // (but don't count py::args + if (func.has_kwargs) --pos_args; // or py::kwargs) + + if (!func.has_args && n_args_in > pos_args) + continue; // Too many arguments for this overload + + if (n_args_in < pos_args && func.args.size() < pos_args) + continue; // Not enough arguments given, and not enough defaults to fill in the blanks + + function_call call(func, parent); + + size_t args_to_copy = std::min(pos_args, n_args_in); + size_t args_copied = 0; + + // 0. Inject new-style `self` argument + if (func.is_new_style_constructor) { + // The `value` may have been preallocated by an old-style `__init__` + // if it was a preceding candidate for overload resolution. + if (self_value_and_holder) + self_value_and_holder.type->dealloc(self_value_and_holder); + + call.init_self = PyTuple_GET_ITEM(args_in, 0); + call.args.push_back(reinterpret_cast(&self_value_and_holder)); + call.args_convert.push_back(false); + ++args_copied; + } + + // 1. Copy any position arguments given. + bool bad_arg = false; + for (; args_copied < args_to_copy; ++args_copied) { + argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr; + if (kwargs_in && arg_rec && arg_rec->name && PyDict_GetItemString(kwargs_in, arg_rec->name)) { + bad_arg = true; + break; + } + + handle arg(PyTuple_GET_ITEM(args_in, args_copied)); + if (arg_rec && !arg_rec->none && arg.is_none()) { + bad_arg = true; + break; + } + call.args.push_back(arg); + call.args_convert.push_back(arg_rec ? arg_rec->convert : true); + } + if (bad_arg) + continue; // Maybe it was meant for another overload (issue #688) + + // We'll need to copy this if we steal some kwargs for defaults + dict kwargs = reinterpret_borrow(kwargs_in); + + // 2. Check kwargs and, failing that, defaults that may help complete the list + if (args_copied < pos_args) { + bool copied_kwargs = false; + + for (; args_copied < pos_args; ++args_copied) { + const auto &arg = func.args[args_copied]; + + handle value; + if (kwargs_in && arg.name) + value = PyDict_GetItemString(kwargs.ptr(), arg.name); + + if (value) { + // Consume a kwargs value + if (!copied_kwargs) { + kwargs = reinterpret_steal(PyDict_Copy(kwargs.ptr())); + copied_kwargs = true; + } + PyDict_DelItemString(kwargs.ptr(), arg.name); + } else if (arg.value) { + value = arg.value; + } + + if (value) { + call.args.push_back(value); + call.args_convert.push_back(arg.convert); + } + else + break; + } + + if (args_copied < pos_args) + continue; // Not enough arguments, defaults, or kwargs to fill the positional arguments + } + + // 3. Check everything was consumed (unless we have a kwargs arg) + if (kwargs && kwargs.size() > 0 && !func.has_kwargs) + continue; // Unconsumed kwargs, but no py::kwargs argument to accept them + + // 4a. If we have a py::args argument, create a new tuple with leftovers + if (func.has_args) { + tuple extra_args; + if (args_to_copy == 0) { + // We didn't copy out any position arguments from the args_in tuple, so we + // can reuse it directly without copying: + extra_args = reinterpret_borrow(args_in); + } else if (args_copied >= n_args_in) { + extra_args = tuple(0); + } else { + size_t args_size = n_args_in - args_copied; + extra_args = tuple(args_size); + for (size_t i = 0; i < args_size; ++i) { + extra_args[i] = PyTuple_GET_ITEM(args_in, args_copied + i); + } + } + call.args.push_back(extra_args); + call.args_convert.push_back(false); + call.args_ref = std::move(extra_args); + } + + // 4b. If we have a py::kwargs, pass on any remaining kwargs + if (func.has_kwargs) { + if (!kwargs.ptr()) + kwargs = dict(); // If we didn't get one, send an empty one + call.args.push_back(kwargs); + call.args_convert.push_back(false); + call.kwargs_ref = std::move(kwargs); + } + + // 5. Put everything in a vector. Not technically step 5, we've been building it + // in `call.args` all along. + #if !defined(NDEBUG) + if (call.args.size() != func.nargs || call.args_convert.size() != func.nargs) + pybind11_fail("Internal error: function call dispatcher inserted wrong number of arguments!"); + #endif + + std::vector second_pass_convert; + if (overloaded) { + // We're in the first no-convert pass, so swap out the conversion flags for a + // set of all-false flags. If the call fails, we'll swap the flags back in for + // the conversion-allowed call below. + second_pass_convert.resize(func.nargs, false); + call.args_convert.swap(second_pass_convert); + } + + // 6. Call the function. + try { + loader_life_support guard{}; + result = func.impl(call); + } catch (reference_cast_error &) { + result = PYBIND11_TRY_NEXT_OVERLOAD; + } + + if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) + break; + + if (overloaded) { + // The (overloaded) call failed; if the call has at least one argument that + // permits conversion (i.e. it hasn't been explicitly specified `.noconvert()`) + // then add this call to the list of second pass overloads to try. + for (size_t i = func.is_method ? 1 : 0; i < pos_args; i++) { + if (second_pass_convert[i]) { + // Found one: swap the converting flags back in and store the call for + // the second pass. + call.args_convert.swap(second_pass_convert); + second_pass.push_back(std::move(call)); + break; + } + } + } + } + + if (overloaded && !second_pass.empty() && result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) { + // The no-conversion pass finished without success, try again with conversion allowed + for (auto &call : second_pass) { + try { + loader_life_support guard{}; + result = call.func.impl(call); + } catch (reference_cast_error &) { + result = PYBIND11_TRY_NEXT_OVERLOAD; + } + + if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) + break; + } + } + } catch (error_already_set &e) { + e.restore(); + return nullptr; + } catch (...) { + /* When an exception is caught, give each registered exception + translator a chance to translate it to a Python exception + in reverse order of registration. + + A translator may choose to do one of the following: + + - catch the exception and call PyErr_SetString or PyErr_SetObject + to set a standard (or custom) Python exception, or + - do nothing and let the exception fall through to the next translator, or + - delegate translation to the next translator by throwing a new type of exception. */ + + auto last_exception = std::current_exception(); + auto ®istered_exception_translators = get_internals().registered_exception_translators; + for (auto& translator : registered_exception_translators) { + try { + translator(last_exception); + } catch (...) { + last_exception = std::current_exception(); + continue; + } + return nullptr; + } + PyErr_SetString(PyExc_SystemError, "Exception escaped from default exception translator!"); + return nullptr; + } + + auto append_note_if_missing_header_is_suspected = [](std::string &msg) { + if (msg.find("std::") != std::string::npos) { + msg += "\n\n" + "Did you forget to `#include `? Or ,\n" + ", , etc. Some automatic\n" + "conversions are optional and require extra headers to be included\n" + "when compiling your pybind11 module."; + } + }; + + if (result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) { + if (overloads->is_operator) + return handle(Py_NotImplemented).inc_ref().ptr(); + + std::string msg = std::string(overloads->name) + "(): incompatible " + + std::string(overloads->is_constructor ? "constructor" : "function") + + " arguments. The following argument types are supported:\n"; + + int ctr = 0; + for (function_record *it2 = overloads; it2 != nullptr; it2 = it2->next) { + msg += " "+ std::to_string(++ctr) + ". "; + + bool wrote_sig = false; + if (overloads->is_constructor) { + // For a constructor, rewrite `(self: Object, arg0, ...) -> NoneType` as `Object(arg0, ...)` + std::string sig = it2->signature; + size_t start = sig.find('(') + 7; // skip "(self: " + if (start < sig.size()) { + // End at the , for the next argument + size_t end = sig.find(", "), next = end + 2; + size_t ret = sig.rfind(" -> "); + // Or the ), if there is no comma: + if (end >= sig.size()) next = end = sig.find(')'); + if (start < end && next < sig.size()) { + msg.append(sig, start, end - start); + msg += '('; + msg.append(sig, next, ret - next); + wrote_sig = true; + } + } + } + if (!wrote_sig) msg += it2->signature; + + msg += "\n"; + } + msg += "\nInvoked with: "; + auto args_ = reinterpret_borrow(args_in); + bool some_args = false; + for (size_t ti = overloads->is_constructor ? 1 : 0; ti < args_.size(); ++ti) { + if (!some_args) some_args = true; + else msg += ", "; + msg += pybind11::repr(args_[ti]); + } + if (kwargs_in) { + auto kwargs = reinterpret_borrow(kwargs_in); + if (kwargs.size() > 0) { + if (some_args) msg += "; "; + msg += "kwargs: "; + bool first = true; + for (auto kwarg : kwargs) { + if (first) first = false; + else msg += ", "; + msg += pybind11::str("{}={!r}").format(kwarg.first, kwarg.second); + } + } + } + + append_note_if_missing_header_is_suspected(msg); + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return nullptr; + } else if (!result) { + std::string msg = "Unable to convert function return value to a " + "Python type! The signature was\n\t"; + msg += it->signature; + append_note_if_missing_header_is_suspected(msg); + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return nullptr; + } else { + if (overloads->is_constructor && !self_value_and_holder.holder_constructed()) { + auto *pi = reinterpret_cast(parent.ptr()); + self_value_and_holder.type->init_instance(pi, nullptr); + } + return result.ptr(); + } + } +}; + +/// Wrapper for Python extension modules +class module : public object { +public: + PYBIND11_OBJECT_DEFAULT(module, object, PyModule_Check) + + /// Create a new top-level Python module with the given name and docstring + explicit module(const char *name, const char *doc = nullptr) { + if (!options::show_user_defined_docstrings()) doc = nullptr; +#if PY_MAJOR_VERSION >= 3 + PyModuleDef *def = new PyModuleDef(); + std::memset(def, 0, sizeof(PyModuleDef)); + def->m_name = name; + def->m_doc = doc; + def->m_size = -1; + Py_INCREF(def); + m_ptr = PyModule_Create(def); +#else + m_ptr = Py_InitModule3(name, nullptr, doc); +#endif + if (m_ptr == nullptr) + pybind11_fail("Internal error in module::module()"); + inc_ref(); + } + + /** \rst + Create Python binding for a new function within the module scope. ``Func`` + can be a plain C++ function, a function pointer, or a lambda function. For + details on the ``Extra&& ... extra`` argument, see section :ref:`extras`. + \endrst */ + template + module &def(const char *name_, Func &&f, const Extra& ... extra) { + cpp_function func(std::forward(f), name(name_), scope(*this), + sibling(getattr(*this, name_, none())), extra...); + // NB: allow overwriting here because cpp_function sets up a chain with the intention of + // overwriting (and has already checked internally that it isn't overwriting non-functions). + add_object(name_, func, true /* overwrite */); + return *this; + } + + /** \rst + Create and return a new Python submodule with the given name and docstring. + This also works recursively, i.e. + + .. code-block:: cpp + + py::module m("example", "pybind11 example plugin"); + py::module m2 = m.def_submodule("sub", "A submodule of 'example'"); + py::module m3 = m2.def_submodule("subsub", "A submodule of 'example.sub'"); + \endrst */ + module def_submodule(const char *name, const char *doc = nullptr) { + std::string full_name = std::string(PyModule_GetName(m_ptr)) + + std::string(".") + std::string(name); + auto result = reinterpret_borrow(PyImport_AddModule(full_name.c_str())); + if (doc && options::show_user_defined_docstrings()) + result.attr("__doc__") = pybind11::str(doc); + attr(name) = result; + return result; + } + + /// Import and return a module or throws `error_already_set`. + static module import(const char *name) { + PyObject *obj = PyImport_ImportModule(name); + if (!obj) + throw error_already_set(); + return reinterpret_steal(obj); + } + + /// Reload the module or throws `error_already_set`. + void reload() { + PyObject *obj = PyImport_ReloadModule(ptr()); + if (!obj) + throw error_already_set(); + *this = reinterpret_steal(obj); + } + + // Adds an object to the module using the given name. Throws if an object with the given name + // already exists. + // + // overwrite should almost always be false: attempting to overwrite objects that pybind11 has + // established will, in most cases, break things. + PYBIND11_NOINLINE void add_object(const char *name, handle obj, bool overwrite = false) { + if (!overwrite && hasattr(*this, name)) + pybind11_fail("Error during initialization: multiple incompatible definitions with name \"" + + std::string(name) + "\""); + + PyModule_AddObject(ptr(), name, obj.inc_ref().ptr() /* steals a reference */); + } +}; + +/// \ingroup python_builtins +/// Return a dictionary representing the global variables in the current execution frame, +/// or ``__main__.__dict__`` if there is no frame (usually when the interpreter is embedded). +inline dict globals() { + PyObject *p = PyEval_GetGlobals(); + return reinterpret_borrow(p ? p : module::import("__main__").attr("__dict__").ptr()); +} + +NAMESPACE_BEGIN(detail) +/// Generic support for creating new Python heap types +class generic_type : public object { + template friend class class_; +public: + PYBIND11_OBJECT_DEFAULT(generic_type, object, PyType_Check) +protected: + void initialize(const type_record &rec) { + if (rec.scope && hasattr(rec.scope, rec.name)) + pybind11_fail("generic_type: cannot initialize type \"" + std::string(rec.name) + + "\": an object with that name is already defined"); + + if (rec.module_local ? get_local_type_info(*rec.type) : get_global_type_info(*rec.type)) + pybind11_fail("generic_type: type \"" + std::string(rec.name) + + "\" is already registered!"); + + m_ptr = make_new_python_type(rec); + + /* Register supplemental type information in C++ dict */ + auto *tinfo = new detail::type_info(); + tinfo->type = (PyTypeObject *) m_ptr; + tinfo->cpptype = rec.type; + tinfo->type_size = rec.type_size; + tinfo->operator_new = rec.operator_new; + tinfo->holder_size_in_ptrs = size_in_ptrs(rec.holder_size); + tinfo->init_instance = rec.init_instance; + tinfo->dealloc = rec.dealloc; + tinfo->simple_type = true; + tinfo->simple_ancestors = true; + tinfo->default_holder = rec.default_holder; + tinfo->module_local = rec.module_local; + + auto &internals = get_internals(); + auto tindex = std::type_index(*rec.type); + tinfo->direct_conversions = &internals.direct_conversions[tindex]; + if (rec.module_local) + registered_local_types_cpp()[tindex] = tinfo; + else + internals.registered_types_cpp[tindex] = tinfo; + internals.registered_types_py[(PyTypeObject *) m_ptr] = { tinfo }; + + if (rec.bases.size() > 1 || rec.multiple_inheritance) { + mark_parents_nonsimple(tinfo->type); + tinfo->simple_ancestors = false; + } + else if (rec.bases.size() == 1) { + auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr()); + tinfo->simple_ancestors = parent_tinfo->simple_ancestors; + } + + if (rec.module_local) { + // Stash the local typeinfo and loader so that external modules can access it. + tinfo->module_local_load = &type_caster_generic::local_load; + setattr(m_ptr, PYBIND11_MODULE_LOCAL_ID, capsule(tinfo)); + } + } + + /// Helper function which tags all parents of a type using mult. inheritance + void mark_parents_nonsimple(PyTypeObject *value) { + auto t = reinterpret_borrow(value->tp_bases); + for (handle h : t) { + auto tinfo2 = get_type_info((PyTypeObject *) h.ptr()); + if (tinfo2) + tinfo2->simple_type = false; + mark_parents_nonsimple((PyTypeObject *) h.ptr()); + } + } + + void install_buffer_funcs( + buffer_info *(*get_buffer)(PyObject *, void *), + void *get_buffer_data) { + PyHeapTypeObject *type = (PyHeapTypeObject*) m_ptr; + auto tinfo = detail::get_type_info(&type->ht_type); + + if (!type->ht_type.tp_as_buffer) + pybind11_fail( + "To be able to register buffer protocol support for the type '" + + std::string(tinfo->type->tp_name) + + "' the associated class<>(..) invocation must " + "include the pybind11::buffer_protocol() annotation!"); + + tinfo->get_buffer = get_buffer; + tinfo->get_buffer_data = get_buffer_data; + } + + void def_property_static_impl(const char *name, + handle fget, handle fset, + detail::function_record *rec_fget) { + const auto is_static = !(rec_fget->is_method && rec_fget->scope); + const auto has_doc = rec_fget->doc && pybind11::options::show_user_defined_docstrings(); + + auto property = handle((PyObject *) (is_static ? get_internals().static_property_type + : &PyProperty_Type)); + attr(name) = property(fget.ptr() ? fget : none(), + fset.ptr() ? fset : none(), + /*deleter*/none(), + pybind11::str(has_doc ? rec_fget->doc : "")); + } +}; + +/// Set the pointer to operator new if it exists. The cast is needed because it can be overloaded. +template (T::operator new))>> +void set_operator_new(type_record *r) { r->operator_new = &T::operator new; } + +template void set_operator_new(...) { } + +template struct has_operator_delete : std::false_type { }; +template struct has_operator_delete(T::operator delete))>> + : std::true_type { }; +template struct has_operator_delete_size : std::false_type { }; +template struct has_operator_delete_size(T::operator delete))>> + : std::true_type { }; +/// Call class-specific delete if it exists or global otherwise. Can also be an overload set. +template ::value, int> = 0> +void call_operator_delete(T *p, size_t) { T::operator delete(p); } +template ::value && has_operator_delete_size::value, int> = 0> +void call_operator_delete(T *p, size_t s) { T::operator delete(p, s); } + +inline void call_operator_delete(void *p, size_t) { ::operator delete(p); } + +NAMESPACE_END(detail) + +/// Given a pointer to a member function, cast it to its `Derived` version. +/// Forward everything else unchanged. +template +auto method_adaptor(F &&f) -> decltype(std::forward(f)) { return std::forward(f); } + +template +auto method_adaptor(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) { return pmf; } + +template +auto method_adaptor(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const { return pmf; } + +template +class class_ : public detail::generic_type { + template using is_holder = detail::is_holder_type; + template using is_subtype = detail::is_strict_base_of; + template using is_base = detail::is_strict_base_of; + // struct instead of using here to help MSVC: + template struct is_valid_class_option : + detail::any_of, is_subtype, is_base> {}; + +public: + using type = type_; + using type_alias = detail::exactly_one_t; + constexpr static bool has_alias = !std::is_void::value; + using holder_type = detail::exactly_one_t, options...>; + + static_assert(detail::all_of...>::value, + "Unknown/invalid class_ template parameters provided"); + + static_assert(!has_alias || std::is_polymorphic::value, + "Cannot use an alias class with a non-polymorphic type"); + + PYBIND11_OBJECT(class_, generic_type, PyType_Check) + + template + class_(handle scope, const char *name, const Extra &... extra) { + using namespace detail; + + // MI can only be specified via class_ template options, not constructor parameters + static_assert( + none_of...>::value || // no base class arguments, or: + ( constexpr_sum(is_pyobject::value...) == 1 && // Exactly one base + constexpr_sum(is_base::value...) == 0 && // no template option bases + none_of...>::value), // no multiple_inheritance attr + "Error: multiple inheritance bases must be specified via class_ template options"); + + type_record record; + record.scope = scope; + record.name = name; + record.type = &typeid(type); + record.type_size = sizeof(conditional_t); + record.holder_size = sizeof(holder_type); + record.init_instance = init_instance; + record.dealloc = dealloc; + record.default_holder = std::is_same>::value; + + set_operator_new(&record); + + /* Register base classes specified via template arguments to class_, if any */ + PYBIND11_EXPAND_SIDE_EFFECTS(add_base(record)); + + /* Process optional arguments, if any */ + process_attributes::init(extra..., &record); + + generic_type::initialize(record); + + if (has_alias) { + auto &instances = record.module_local ? registered_local_types_cpp() : get_internals().registered_types_cpp; + instances[std::type_index(typeid(type_alias))] = instances[std::type_index(typeid(type))]; + } + } + + template ::value, int> = 0> + static void add_base(detail::type_record &rec) { + rec.add_base(typeid(Base), [](void *src) -> void * { + return static_cast(reinterpret_cast(src)); + }); + } + + template ::value, int> = 0> + static void add_base(detail::type_record &) { } + + template + class_ &def(const char *name_, Func&& f, const Extra&... extra) { + cpp_function cf(method_adaptor(std::forward(f)), name(name_), is_method(*this), + sibling(getattr(*this, name_, none())), extra...); + attr(cf.name()) = cf; + return *this; + } + + template class_ & + def_static(const char *name_, Func &&f, const Extra&... extra) { + static_assert(!std::is_member_function_pointer::value, + "def_static(...) called with a non-static member function pointer"); + cpp_function cf(std::forward(f), name(name_), scope(*this), + sibling(getattr(*this, name_, none())), extra...); + attr(cf.name()) = cf; + return *this; + } + + template + class_ &def(const detail::op_ &op, const Extra&... extra) { + op.execute(*this, extra...); + return *this; + } + + template + class_ & def_cast(const detail::op_ &op, const Extra&... extra) { + op.execute_cast(*this, extra...); + return *this; + } + + template + class_ &def(const detail::initimpl::constructor &init, const Extra&... extra) { + init.execute(*this, extra...); + return *this; + } + + template + class_ &def(const detail::initimpl::alias_constructor &init, const Extra&... extra) { + init.execute(*this, extra...); + return *this; + } + + template + class_ &def(detail::initimpl::factory &&init, const Extra&... extra) { + std::move(init).execute(*this, extra...); + return *this; + } + + template + class_ &def(detail::initimpl::pickle_factory &&pf, const Extra &...extra) { + std::move(pf).execute(*this, extra...); + return *this; + } + + template class_& def_buffer(Func &&func) { + struct capture { Func func; }; + capture *ptr = new capture { std::forward(func) }; + install_buffer_funcs([](PyObject *obj, void *ptr) -> buffer_info* { + detail::make_caster caster; + if (!caster.load(obj, false)) + return nullptr; + return new buffer_info(((capture *) ptr)->func(caster)); + }, ptr); + return *this; + } + + template + class_ &def_buffer(Return (Class::*func)(Args...)) { + return def_buffer([func] (type &obj) { return (obj.*func)(); }); + } + + template + class_ &def_buffer(Return (Class::*func)(Args...) const) { + return def_buffer([func] (const type &obj) { return (obj.*func)(); }); + } + + template + class_ &def_readwrite(const char *name, D C::*pm, const Extra&... extra) { + static_assert(std::is_base_of::value, "def_readwrite() requires a class member (or base class member)"); + cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)), + fset([pm](type &c, const D &value) { c.*pm = value; }, is_method(*this)); + def_property(name, fget, fset, return_value_policy::reference_internal, extra...); + return *this; + } + + template + class_ &def_readonly(const char *name, const D C::*pm, const Extra& ...extra) { + static_assert(std::is_base_of::value, "def_readonly() requires a class member (or base class member)"); + cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)); + def_property_readonly(name, fget, return_value_policy::reference_internal, extra...); + return *this; + } + + template + class_ &def_readwrite_static(const char *name, D *pm, const Extra& ...extra) { + cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this)), + fset([pm](object, const D &value) { *pm = value; }, scope(*this)); + def_property_static(name, fget, fset, return_value_policy::reference, extra...); + return *this; + } + + template + class_ &def_readonly_static(const char *name, const D *pm, const Extra& ...extra) { + cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this)); + def_property_readonly_static(name, fget, return_value_policy::reference, extra...); + return *this; + } + + /// Uses return_value_policy::reference_internal by default + template + class_ &def_property_readonly(const char *name, const Getter &fget, const Extra& ...extra) { + return def_property_readonly(name, cpp_function(method_adaptor(fget)), + return_value_policy::reference_internal, extra...); + } + + /// Uses cpp_function's return_value_policy by default + template + class_ &def_property_readonly(const char *name, const cpp_function &fget, const Extra& ...extra) { + return def_property(name, fget, cpp_function(), extra...); + } + + /// Uses return_value_policy::reference by default + template + class_ &def_property_readonly_static(const char *name, const Getter &fget, const Extra& ...extra) { + return def_property_readonly_static(name, cpp_function(fget), return_value_policy::reference, extra...); + } + + /// Uses cpp_function's return_value_policy by default + template + class_ &def_property_readonly_static(const char *name, const cpp_function &fget, const Extra& ...extra) { + return def_property_static(name, fget, cpp_function(), extra...); + } + + /// Uses return_value_policy::reference_internal by default + template + class_ &def_property(const char *name, const Getter &fget, const Setter &fset, const Extra& ...extra) { + return def_property(name, fget, cpp_function(method_adaptor(fset)), extra...); + } + template + class_ &def_property(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) { + return def_property(name, cpp_function(method_adaptor(fget)), fset, + return_value_policy::reference_internal, extra...); + } + + /// Uses cpp_function's return_value_policy by default + template + class_ &def_property(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) { + return def_property_static(name, fget, fset, is_method(*this), extra...); + } + + /// Uses return_value_policy::reference by default + template + class_ &def_property_static(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) { + return def_property_static(name, cpp_function(fget), fset, return_value_policy::reference, extra...); + } + + /// Uses cpp_function's return_value_policy by default + template + class_ &def_property_static(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) { + auto rec_fget = get_function_record(fget), rec_fset = get_function_record(fset); + char *doc_prev = rec_fget->doc; /* 'extra' field may include a property-specific documentation string */ + detail::process_attributes::init(extra..., rec_fget); + if (rec_fget->doc && rec_fget->doc != doc_prev) { + free(doc_prev); + rec_fget->doc = strdup(rec_fget->doc); + } + if (rec_fset) { + doc_prev = rec_fset->doc; + detail::process_attributes::init(extra..., rec_fset); + if (rec_fset->doc && rec_fset->doc != doc_prev) { + free(doc_prev); + rec_fset->doc = strdup(rec_fset->doc); + } + } + def_property_static_impl(name, fget, fset, rec_fget); + return *this; + } + +private: + /// Initialize holder object, variant 1: object derives from enable_shared_from_this + template + static void init_holder(detail::instance *inst, detail::value_and_holder &v_h, + const holder_type * /* unused */, const std::enable_shared_from_this * /* dummy */) { + try { + auto sh = std::dynamic_pointer_cast( + v_h.value_ptr()->shared_from_this()); + if (sh) { + new (&v_h.holder()) holder_type(std::move(sh)); + v_h.set_holder_constructed(); + } + } catch (const std::bad_weak_ptr &) {} + + if (!v_h.holder_constructed() && inst->owned) { + new (&v_h.holder()) holder_type(v_h.value_ptr()); + v_h.set_holder_constructed(); + } + } + + static void init_holder_from_existing(const detail::value_and_holder &v_h, + const holder_type *holder_ptr, std::true_type /*is_copy_constructible*/) { + new (&v_h.holder()) holder_type(*reinterpret_cast(holder_ptr)); + } + + static void init_holder_from_existing(const detail::value_and_holder &v_h, + const holder_type *holder_ptr, std::false_type /*is_copy_constructible*/) { + new (&v_h.holder()) holder_type(std::move(*const_cast(holder_ptr))); + } + + /// Initialize holder object, variant 2: try to construct from existing holder object, if possible + static void init_holder(detail::instance *inst, detail::value_and_holder &v_h, + const holder_type *holder_ptr, const void * /* dummy -- not enable_shared_from_this) */) { + if (holder_ptr) { + init_holder_from_existing(v_h, holder_ptr, std::is_copy_constructible()); + v_h.set_holder_constructed(); + } else if (inst->owned || detail::always_construct_holder::value) { + new (&v_h.holder()) holder_type(v_h.value_ptr()); + v_h.set_holder_constructed(); + } + } + + /// Performs instance initialization including constructing a holder and registering the known + /// instance. Should be called as soon as the `type` value_ptr is set for an instance. Takes an + /// optional pointer to an existing holder to use; if not specified and the instance is + /// `.owned`, a new holder will be constructed to manage the value pointer. + static void init_instance(detail::instance *inst, const void *holder_ptr) { + auto v_h = inst->get_value_and_holder(detail::get_type_info(typeid(type))); + if (!v_h.instance_registered()) { + register_instance(inst, v_h.value_ptr(), v_h.type); + v_h.set_instance_registered(); + } + init_holder(inst, v_h, (const holder_type *) holder_ptr, v_h.value_ptr()); + } + + /// Deallocates an instance; via holder, if constructed; otherwise via operator delete. + static void dealloc(detail::value_and_holder &v_h) { + if (v_h.holder_constructed()) { + v_h.holder().~holder_type(); + v_h.set_holder_constructed(false); + } + else { + detail::call_operator_delete(v_h.value_ptr(), v_h.type->type_size); + } + v_h.value_ptr() = nullptr; + } + + static detail::function_record *get_function_record(handle h) { + h = detail::get_function(h); + return h ? (detail::function_record *) reinterpret_borrow(PyCFunction_GET_SELF(h.ptr())) + : nullptr; + } +}; + +/// Binds an existing constructor taking arguments Args... +template detail::initimpl::constructor init() { return {}; } +/// Like `init()`, but the instance is always constructed through the alias class (even +/// when not inheriting on the Python side). +template detail::initimpl::alias_constructor init_alias() { return {}; } + +/// Binds a factory function as a constructor +template > +Ret init(Func &&f) { return {std::forward(f)}; } + +/// Dual-argument factory function: the first function is called when no alias is needed, the second +/// when an alias is needed (i.e. due to python-side inheritance). Arguments must be identical. +template > +Ret init(CFunc &&c, AFunc &&a) { + return {std::forward(c), std::forward(a)}; +} + +/// Binds pickling functions `__getstate__` and `__setstate__` and ensures that the type +/// returned by `__getstate__` is the same as the argument accepted by `__setstate__`. +template +detail::initimpl::pickle_factory pickle(GetState &&g, SetState &&s) { + return {std::forward(g), std::forward(s)}; +} + +/// Binds C++ enumerations and enumeration classes to Python +template class enum_ : public class_ { +public: + using class_::def; + using class_::def_property_readonly_static; + using Scalar = typename std::underlying_type::type; + + template + enum_(const handle &scope, const char *name, const Extra&... extra) + : class_(scope, name, extra...), m_entries(), m_parent(scope) { + + constexpr bool is_arithmetic = detail::any_of...>::value; + + auto m_entries_ptr = m_entries.inc_ref().ptr(); + def("__repr__", [name, m_entries_ptr](Type value) -> pybind11::str { + for (const auto &kv : reinterpret_borrow(m_entries_ptr)) { + if (pybind11::cast(kv.second) == value) + return pybind11::str("{}.{}").format(name, kv.first); + } + return pybind11::str("{}.???").format(name); + }); + def_property_readonly_static("__members__", [m_entries_ptr](object /* self */) { + dict m; + for (const auto &kv : reinterpret_borrow(m_entries_ptr)) + m[kv.first] = kv.second; + return m; + }, return_value_policy::copy); + def(init([](Scalar i) { return static_cast(i); })); + def("__int__", [](Type value) { return (Scalar) value; }); + #if PY_MAJOR_VERSION < 3 + def("__long__", [](Type value) { return (Scalar) value; }); + #endif + def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; }); + def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; }); + if (is_arithmetic) { + def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; }); + def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; }); + def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; }); + def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; }); + } + if (std::is_convertible::value) { + // Don't provide comparison with the underlying type if the enum isn't convertible, + // i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly + // convert Type to Scalar below anyway because this needs to compile). + def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; }); + def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; }); + if (is_arithmetic) { + def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; }); + def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; }); + def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; }); + def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; }); + def("__invert__", [](const Type &value) { return ~((Scalar) value); }); + def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; }); + def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; }); + def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; }); + def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; }); + def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; }); + def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; }); + def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; }); + def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; }); + def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; }); + } + } + def("__hash__", [](const Type &value) { return (Scalar) value; }); + // Pickling and unpickling -- needed for use with the 'multiprocessing' module + def(pickle([](const Type &value) { return pybind11::make_tuple((Scalar) value); }, + [](tuple t) { return static_cast(t[0].cast()); })); + } + + /// Export enumeration entries into the parent scope + enum_& export_values() { + for (const auto &kv : m_entries) + m_parent.attr(kv.first) = kv.second; + return *this; + } + + /// Add an enumeration entry + enum_& value(char const* name, Type value) { + auto v = pybind11::cast(value, return_value_policy::copy); + this->attr(name) = v; + m_entries[pybind11::str(name)] = v; + return *this; + } + +private: + dict m_entries; + handle m_parent; +}; + +NAMESPACE_BEGIN(detail) + + +inline void keep_alive_impl(handle nurse, handle patient) { + if (!nurse || !patient) + pybind11_fail("Could not activate keep_alive!"); + + if (patient.is_none() || nurse.is_none()) + return; /* Nothing to keep alive or nothing to be kept alive by */ + + auto tinfo = all_type_info(Py_TYPE(nurse.ptr())); + if (!tinfo.empty()) { + /* It's a pybind-registered type, so we can store the patient in the + * internal list. */ + add_patient(nurse.ptr(), patient.ptr()); + } + else { + /* Fall back to clever approach based on weak references taken from + * Boost.Python. This is not used for pybind-registered types because + * the objects can be destroyed out-of-order in a GC pass. */ + cpp_function disable_lifesupport( + [patient](handle weakref) { patient.dec_ref(); weakref.dec_ref(); }); + + weakref wr(nurse, disable_lifesupport); + + patient.inc_ref(); /* reference patient and leak the weak reference */ + (void) wr.release(); + } +} + +PYBIND11_NOINLINE inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) { + auto get_arg = [&](size_t n) { + if (n == 0) + return ret; + else if (n == 1 && call.init_self) + return call.init_self; + else if (n <= call.args.size()) + return call.args[n - 1]; + return handle(); + }; + + keep_alive_impl(get_arg(Nurse), get_arg(Patient)); +} + +inline std::pair all_type_info_get_cache(PyTypeObject *type) { + auto res = get_internals().registered_types_py +#ifdef __cpp_lib_unordered_map_try_emplace + .try_emplace(type); +#else + .emplace(type, std::vector()); +#endif + if (res.second) { + // New cache entry created; set up a weak reference to automatically remove it if the type + // gets destroyed: + weakref((PyObject *) type, cpp_function([type](handle wr) { + get_internals().registered_types_py.erase(type); + wr.dec_ref(); + })).release(); + } + + return res; +} + +template +struct iterator_state { + Iterator it; + Sentinel end; + bool first_or_done; +}; + +NAMESPACE_END(detail) + +/// Makes a python iterator from a first and past-the-end C++ InputIterator. +template ()), + typename... Extra> +iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) { + typedef detail::iterator_state state; + + if (!detail::get_type_info(typeid(state), false)) { + class_(handle(), "iterator", pybind11::module_local()) + .def("__iter__", [](state &s) -> state& { return s; }) + .def("__next__", [](state &s) -> ValueType { + if (!s.first_or_done) + ++s.it; + else + s.first_or_done = false; + if (s.it == s.end) { + s.first_or_done = true; + throw stop_iteration(); + } + return *s.it; + }, std::forward(extra)..., Policy); + } + + return cast(state{first, last, true}); +} + +/// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a +/// first and past-the-end InputIterator. +template ()).first), + typename... Extra> +iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) { + typedef detail::iterator_state state; + + if (!detail::get_type_info(typeid(state), false)) { + class_(handle(), "iterator", pybind11::module_local()) + .def("__iter__", [](state &s) -> state& { return s; }) + .def("__next__", [](state &s) -> KeyType { + if (!s.first_or_done) + ++s.it; + else + s.first_or_done = false; + if (s.it == s.end) { + s.first_or_done = true; + throw stop_iteration(); + } + return (*s.it).first; + }, std::forward(extra)..., Policy); + } + + return cast(state{first, last, true}); +} + +/// Makes an iterator over values of an stl container or other container supporting +/// `std::begin()`/`std::end()` +template iterator make_iterator(Type &value, Extra&&... extra) { + return make_iterator(std::begin(value), std::end(value), extra...); +} + +/// Makes an iterator over the keys (`.first`) of a stl map-like container supporting +/// `std::begin()`/`std::end()` +template iterator make_key_iterator(Type &value, Extra&&... extra) { + return make_key_iterator(std::begin(value), std::end(value), extra...); +} + +template void implicitly_convertible() { + struct set_flag { + bool &flag; + set_flag(bool &flag) : flag(flag) { flag = true; } + ~set_flag() { flag = false; } + }; + auto implicit_caster = [](PyObject *obj, PyTypeObject *type) -> PyObject * { + static bool currently_used = false; + if (currently_used) // implicit conversions are non-reentrant + return nullptr; + set_flag flag_helper(currently_used); + if (!detail::make_caster().load(obj, false)) + return nullptr; + tuple args(1); + args[0] = obj; + PyObject *result = PyObject_Call((PyObject *) type, args.ptr(), nullptr); + if (result == nullptr) + PyErr_Clear(); + return result; + }; + + if (auto tinfo = detail::get_type_info(typeid(OutputType))) + tinfo->implicit_conversions.push_back(implicit_caster); + else + pybind11_fail("implicitly_convertible: Unable to find type " + type_id()); +} + +template +void register_exception_translator(ExceptionTranslator&& translator) { + detail::get_internals().registered_exception_translators.push_front( + std::forward(translator)); +} + +/** + * Wrapper to generate a new Python exception type. + * + * This should only be used with PyErr_SetString for now. + * It is not (yet) possible to use as a py::base. + * Template type argument is reserved for future use. + */ +template +class exception : public object { +public: + exception(handle scope, const char *name, PyObject *base = PyExc_Exception) { + std::string full_name = scope.attr("__name__").cast() + + std::string(".") + name; + m_ptr = PyErr_NewException(const_cast(full_name.c_str()), base, NULL); + if (hasattr(scope, name)) + pybind11_fail("Error during initialization: multiple incompatible " + "definitions with name \"" + std::string(name) + "\""); + scope.attr(name) = *this; + } + + // Sets the current python exception to this exception object with the given message + void operator()(const char *message) { + PyErr_SetString(m_ptr, message); + } +}; + +/** + * Registers a Python exception in `m` of the given `name` and installs an exception translator to + * translate the C++ exception to the created Python exception using the exceptions what() method. + * This is intended for simple exception translations; for more complex translation, register the + * exception object and translator directly. + */ +template +exception ®ister_exception(handle scope, + const char *name, + PyObject *base = PyExc_Exception) { + static exception ex(scope, name, base); + register_exception_translator([](std::exception_ptr p) { + if (!p) return; + try { + std::rethrow_exception(p); + } catch (const CppException &e) { + ex(e.what()); + } + }); + return ex; +} + +NAMESPACE_BEGIN(detail) +PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) { + auto strings = tuple(args.size()); + for (size_t i = 0; i < args.size(); ++i) { + strings[i] = str(args[i]); + } + auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" "); + auto line = sep.attr("join")(strings); + + object file; + if (kwargs.contains("file")) { + file = kwargs["file"].cast(); + } else { + try { + file = module::import("sys").attr("stdout"); + } catch (const error_already_set &) { + /* If print() is called from code that is executed as + part of garbage collection during interpreter shutdown, + importing 'sys' can fail. Give up rather than crashing the + interpreter in this case. */ + return; + } + } + + auto write = file.attr("write"); + write(line); + write(kwargs.contains("end") ? kwargs["end"] : cast("\n")); + + if (kwargs.contains("flush") && kwargs["flush"].cast()) + file.attr("flush")(); +} +NAMESPACE_END(detail) + +template +void print(Args &&...args) { + auto c = detail::collect_arguments(std::forward(args)...); + detail::print(c.args(), c.kwargs()); +} + +#if defined(WITH_THREAD) && !defined(PYPY_VERSION) + +/* The functions below essentially reproduce the PyGILState_* API using a RAII + * pattern, but there are a few important differences: + * + * 1. When acquiring the GIL from an non-main thread during the finalization + * phase, the GILState API blindly terminates the calling thread, which + * is often not what is wanted. This API does not do this. + * + * 2. The gil_scoped_release function can optionally cut the relationship + * of a PyThreadState and its associated thread, which allows moving it to + * another thread (this is a fairly rare/advanced use case). + * + * 3. The reference count of an acquired thread state can be controlled. This + * can be handy to prevent cases where callbacks issued from an external + * thread would otherwise constantly construct and destroy thread state data + * structures. + * + * See the Python bindings of NanoGUI (http://github.com/wjakob/nanogui) for an + * example which uses features 2 and 3 to migrate the Python thread of + * execution to another thread (to run the event loop on the original thread, + * in this case). + */ + +class gil_scoped_acquire { +public: + PYBIND11_NOINLINE gil_scoped_acquire() { + auto const &internals = detail::get_internals(); + tstate = (PyThreadState *) PyThread_get_key_value(internals.tstate); + + if (!tstate) { + tstate = PyThreadState_New(internals.istate); + #if !defined(NDEBUG) + if (!tstate) + pybind11_fail("scoped_acquire: could not create thread state!"); + #endif + tstate->gilstate_counter = 0; + #if PY_MAJOR_VERSION < 3 + PyThread_delete_key_value(internals.tstate); + #endif + PyThread_set_key_value(internals.tstate, tstate); + } else { + release = detail::get_thread_state_unchecked() != tstate; + } + + if (release) { + /* Work around an annoying assertion in PyThreadState_Swap */ + #if defined(Py_DEBUG) + PyInterpreterState *interp = tstate->interp; + tstate->interp = nullptr; + #endif + PyEval_AcquireThread(tstate); + #if defined(Py_DEBUG) + tstate->interp = interp; + #endif + } + + inc_ref(); + } + + void inc_ref() { + ++tstate->gilstate_counter; + } + + PYBIND11_NOINLINE void dec_ref() { + --tstate->gilstate_counter; + #if !defined(NDEBUG) + if (detail::get_thread_state_unchecked() != tstate) + pybind11_fail("scoped_acquire::dec_ref(): thread state must be current!"); + if (tstate->gilstate_counter < 0) + pybind11_fail("scoped_acquire::dec_ref(): reference count underflow!"); + #endif + if (tstate->gilstate_counter == 0) { + #if !defined(NDEBUG) + if (!release) + pybind11_fail("scoped_acquire::dec_ref(): internal error!"); + #endif + PyThreadState_Clear(tstate); + PyThreadState_DeleteCurrent(); + PyThread_delete_key_value(detail::get_internals().tstate); + release = false; + } + } + + PYBIND11_NOINLINE ~gil_scoped_acquire() { + dec_ref(); + if (release) + PyEval_SaveThread(); + } +private: + PyThreadState *tstate = nullptr; + bool release = true; +}; + +class gil_scoped_release { +public: + explicit gil_scoped_release(bool disassoc = false) : disassoc(disassoc) { + // `get_internals()` must be called here unconditionally in order to initialize + // `internals.tstate` for subsequent `gil_scoped_acquire` calls. Otherwise, an + // initialization race could occur as multiple threads try `gil_scoped_acquire`. + const auto &internals = detail::get_internals(); + tstate = PyEval_SaveThread(); + if (disassoc) { + auto key = internals.tstate; + #if PY_MAJOR_VERSION < 3 + PyThread_delete_key_value(key); + #else + PyThread_set_key_value(key, nullptr); + #endif + } + } + ~gil_scoped_release() { + if (!tstate) + return; + PyEval_RestoreThread(tstate); + if (disassoc) { + auto key = detail::get_internals().tstate; + #if PY_MAJOR_VERSION < 3 + PyThread_delete_key_value(key); + #endif + PyThread_set_key_value(key, tstate); + } + } +private: + PyThreadState *tstate; + bool disassoc; +}; +#elif defined(PYPY_VERSION) +class gil_scoped_acquire { + PyGILState_STATE state; +public: + gil_scoped_acquire() { state = PyGILState_Ensure(); } + ~gil_scoped_acquire() { PyGILState_Release(state); } +}; + +class gil_scoped_release { + PyThreadState *state; +public: + gil_scoped_release() { state = PyEval_SaveThread(); } + ~gil_scoped_release() { PyEval_RestoreThread(state); } +}; +#else +class gil_scoped_acquire { }; +class gil_scoped_release { }; +#endif + +error_already_set::~error_already_set() { + if (type) { + gil_scoped_acquire gil; + type.release().dec_ref(); + value.release().dec_ref(); + trace.release().dec_ref(); + } +} + +inline function get_type_overload(const void *this_ptr, const detail::type_info *this_type, const char *name) { + handle self = detail::get_object_handle(this_ptr, this_type); + if (!self) + return function(); + handle type = self.get_type(); + auto key = std::make_pair(type.ptr(), name); + + /* Cache functions that aren't overloaded in Python to avoid + many costly Python dictionary lookups below */ + auto &cache = detail::get_internals().inactive_overload_cache; + if (cache.find(key) != cache.end()) + return function(); + + function overload = getattr(self, name, function()); + if (overload.is_cpp_function()) { + cache.insert(key); + return function(); + } + + /* Don't call dispatch code if invoked from overridden function. + Unfortunately this doesn't work on PyPy. */ +#if !defined(PYPY_VERSION) + PyFrameObject *frame = PyThreadState_Get()->frame; + if (frame && (std::string) str(frame->f_code->co_name) == name && + frame->f_code->co_argcount > 0) { + PyFrame_FastToLocals(frame); + PyObject *self_caller = PyDict_GetItem( + frame->f_locals, PyTuple_GET_ITEM(frame->f_code->co_varnames, 0)); + if (self_caller == self.ptr()) + return function(); + } +#else + /* PyPy currently doesn't provide a detailed cpyext emulation of + frame objects, so we have to emulate this using Python. This + is going to be slow..*/ + dict d; d["self"] = self; d["name"] = pybind11::str(name); + PyObject *result = PyRun_String( + "import inspect\n" + "frame = inspect.currentframe()\n" + "if frame is not None:\n" + " frame = frame.f_back\n" + " if frame is not None and str(frame.f_code.co_name) == name and " + "frame.f_code.co_argcount > 0:\n" + " self_caller = frame.f_locals[frame.f_code.co_varnames[0]]\n" + " if self_caller == self:\n" + " self = None\n", + Py_file_input, d.ptr(), d.ptr()); + if (result == nullptr) + throw error_already_set(); + if (d["self"].is_none()) + return function(); + Py_DECREF(result); +#endif + + return overload; +} + +template function get_overload(const T *this_ptr, const char *name) { + auto tinfo = detail::get_type_info(typeid(T)); + return tinfo ? get_type_overload(this_ptr, tinfo, name) : function(); +} + +#define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \ + pybind11::gil_scoped_acquire gil; \ + pybind11::function overload = pybind11::get_overload(static_cast(this), name); \ + if (overload) { \ + auto o = overload(__VA_ARGS__); \ + if (pybind11::detail::cast_is_temporary_value_reference::value) { \ + static pybind11::detail::overload_caster_t caster; \ + return pybind11::detail::cast_ref(std::move(o), caster); \ + } \ + else return pybind11::detail::cast_safe(std::move(o)); \ + } \ + } + +#define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \ + PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \ + return cname::fn(__VA_ARGS__) + +#define PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, name, fn, ...) \ + PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \ + pybind11::pybind11_fail("Tried to call pure virtual function \"" #cname "::" name "\""); + +#define PYBIND11_OVERLOAD(ret_type, cname, fn, ...) \ + PYBIND11_OVERLOAD_NAME(ret_type, cname, #fn, fn, __VA_ARGS__) + +#define PYBIND11_OVERLOAD_PURE(ret_type, cname, fn, ...) \ + PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, #fn, fn, __VA_ARGS__) + +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(_MSC_VER) +# pragma warning(pop) +#elif defined(__INTEL_COMPILER) +/* Leave ignored warnings on */ +#elif defined(__GNUG__) && !defined(__clang__) +# pragma GCC diagnostic pop +#endif diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/pytypes.h b/ml/dlib/dlib/external/pybind11/include/pybind11/pytypes.h new file mode 100644 index 000000000..d7fa17775 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/pytypes.h @@ -0,0 +1,1332 @@ +/* + pybind11/pytypes.h: Convenience wrapper classes for basic Python types + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "detail/common.h" +#include "buffer_info.h" +#include +#include + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +/* A few forward declarations */ +class handle; class object; +class str; class iterator; +struct arg; struct arg_v; + +NAMESPACE_BEGIN(detail) +class args_proxy; +inline bool isinstance_generic(handle obj, const std::type_info &tp); + +// Accessor forward declarations +template class accessor; +namespace accessor_policies { + struct obj_attr; + struct str_attr; + struct generic_item; + struct sequence_item; + struct list_item; + struct tuple_item; +} +using obj_attr_accessor = accessor; +using str_attr_accessor = accessor; +using item_accessor = accessor; +using sequence_accessor = accessor; +using list_accessor = accessor; +using tuple_accessor = accessor; + +/// Tag and check to identify a class which implements the Python object API +class pyobject_tag { }; +template using is_pyobject = std::is_base_of>; + +/** \rst + A mixin class which adds common functions to `handle`, `object` and various accessors. + The only requirement for `Derived` is to implement ``PyObject *Derived::ptr() const``. +\endrst */ +template +class object_api : public pyobject_tag { + const Derived &derived() const { return static_cast(*this); } + +public: + /** \rst + Return an iterator equivalent to calling ``iter()`` in Python. The object + must be a collection which supports the iteration protocol. + \endrst */ + iterator begin() const; + /// Return a sentinel which ends iteration. + iterator end() const; + + /** \rst + Return an internal functor to invoke the object's sequence protocol. Casting + the returned ``detail::item_accessor`` instance to a `handle` or `object` + subclass causes a corresponding call to ``__getitem__``. Assigning a `handle` + or `object` subclass causes a call to ``__setitem__``. + \endrst */ + item_accessor operator[](handle key) const; + /// See above (the only difference is that they key is provided as a string literal) + item_accessor operator[](const char *key) const; + + /** \rst + Return an internal functor to access the object's attributes. Casting the + returned ``detail::obj_attr_accessor`` instance to a `handle` or `object` + subclass causes a corresponding call to ``getattr``. Assigning a `handle` + or `object` subclass causes a call to ``setattr``. + \endrst */ + obj_attr_accessor attr(handle key) const; + /// See above (the only difference is that they key is provided as a string literal) + str_attr_accessor attr(const char *key) const; + + /** \rst + Matches * unpacking in Python, e.g. to unpack arguments out of a ``tuple`` + or ``list`` for a function call. Applying another * to the result yields + ** unpacking, e.g. to unpack a dict as function keyword arguments. + See :ref:`calling_python_functions`. + \endrst */ + args_proxy operator*() const; + + /// Check if the given item is contained within this object, i.e. ``item in obj``. + template bool contains(T &&item) const; + + /** \rst + Assuming the Python object is a function or implements the ``__call__`` + protocol, ``operator()`` invokes the underlying function, passing an + arbitrary set of parameters. The result is returned as a `object` and + may need to be converted back into a Python object using `handle::cast()`. + + When some of the arguments cannot be converted to Python objects, the + function will throw a `cast_error` exception. When the Python function + call fails, a `error_already_set` exception is thrown. + \endrst */ + template + object operator()(Args &&...args) const; + template + PYBIND11_DEPRECATED("call(...) was deprecated in favor of operator()(...)") + object call(Args&&... args) const; + + /// Equivalent to ``obj is other`` in Python. + bool is(object_api const& other) const { return derived().ptr() == other.derived().ptr(); } + /// Equivalent to ``obj is None`` in Python. + bool is_none() const { return derived().ptr() == Py_None; } + PYBIND11_DEPRECATED("Use py::str(obj) instead") + pybind11::str str() const; + + /// Get or set the object's docstring, i.e. ``obj.__doc__``. + str_attr_accessor doc() const; + + /// Return the object's current reference count + int ref_count() const { return static_cast(Py_REFCNT(derived().ptr())); } + /// Return a handle to the Python type object underlying the instance + handle get_type() const; +}; + +NAMESPACE_END(detail) + +/** \rst + Holds a reference to a Python object (no reference counting) + + The `handle` class is a thin wrapper around an arbitrary Python object (i.e. a + ``PyObject *`` in Python's C API). It does not perform any automatic reference + counting and merely provides a basic C++ interface to various Python API functions. + + .. seealso:: + The `object` class inherits from `handle` and adds automatic reference + counting features. +\endrst */ +class handle : public detail::object_api { +public: + /// The default constructor creates a handle with a ``nullptr``-valued pointer + handle() = default; + /// Creates a ``handle`` from the given raw Python object pointer + handle(PyObject *ptr) : m_ptr(ptr) { } // Allow implicit conversion from PyObject* + + /// Return the underlying ``PyObject *`` pointer + PyObject *ptr() const { return m_ptr; } + PyObject *&ptr() { return m_ptr; } + + /** \rst + Manually increase the reference count of the Python object. Usually, it is + preferable to use the `object` class which derives from `handle` and calls + this function automatically. Returns a reference to itself. + \endrst */ + const handle& inc_ref() const & { Py_XINCREF(m_ptr); return *this; } + + /** \rst + Manually decrease the reference count of the Python object. Usually, it is + preferable to use the `object` class which derives from `handle` and calls + this function automatically. Returns a reference to itself. + \endrst */ + const handle& dec_ref() const & { Py_XDECREF(m_ptr); return *this; } + + /** \rst + Attempt to cast the Python object into the given C++ type. A `cast_error` + will be throw upon failure. + \endrst */ + template T cast() const; + /// Return ``true`` when the `handle` wraps a valid Python object + explicit operator bool() const { return m_ptr != nullptr; } + /** \rst + Deprecated: Check that the underlying pointers are the same. + Equivalent to ``obj1 is obj2`` in Python. + \endrst */ + PYBIND11_DEPRECATED("Use obj1.is(obj2) instead") + bool operator==(const handle &h) const { return m_ptr == h.m_ptr; } + PYBIND11_DEPRECATED("Use !obj1.is(obj2) instead") + bool operator!=(const handle &h) const { return m_ptr != h.m_ptr; } + PYBIND11_DEPRECATED("Use handle::operator bool() instead") + bool check() const { return m_ptr != nullptr; } +protected: + PyObject *m_ptr = nullptr; +}; + +/** \rst + Holds a reference to a Python object (with reference counting) + + Like `handle`, the `object` class is a thin wrapper around an arbitrary Python + object (i.e. a ``PyObject *`` in Python's C API). In contrast to `handle`, it + optionally increases the object's reference count upon construction, and it + *always* decreases the reference count when the `object` instance goes out of + scope and is destructed. When using `object` instances consistently, it is much + easier to get reference counting right at the first attempt. +\endrst */ +class object : public handle { +public: + object() = default; + PYBIND11_DEPRECATED("Use reinterpret_borrow() or reinterpret_steal()") + object(handle h, bool is_borrowed) : handle(h) { if (is_borrowed) inc_ref(); } + /// Copy constructor; always increases the reference count + object(const object &o) : handle(o) { inc_ref(); } + /// Move constructor; steals the object from ``other`` and preserves its reference count + object(object &&other) noexcept { m_ptr = other.m_ptr; other.m_ptr = nullptr; } + /// Destructor; automatically calls `handle::dec_ref()` + ~object() { dec_ref(); } + + /** \rst + Resets the internal pointer to ``nullptr`` without without decreasing the + object's reference count. The function returns a raw handle to the original + Python object. + \endrst */ + handle release() { + PyObject *tmp = m_ptr; + m_ptr = nullptr; + return handle(tmp); + } + + object& operator=(const object &other) { + other.inc_ref(); + dec_ref(); + m_ptr = other.m_ptr; + return *this; + } + + object& operator=(object &&other) noexcept { + if (this != &other) { + handle temp(m_ptr); + m_ptr = other.m_ptr; + other.m_ptr = nullptr; + temp.dec_ref(); + } + return *this; + } + + // Calling cast() on an object lvalue just copies (via handle::cast) + template T cast() const &; + // Calling on an object rvalue does a move, if needed and/or possible + template T cast() &&; + +protected: + // Tags for choosing constructors from raw PyObject * + struct borrowed_t { }; + struct stolen_t { }; + + template friend T reinterpret_borrow(handle); + template friend T reinterpret_steal(handle); + +public: + // Only accessible from derived classes and the reinterpret_* functions + object(handle h, borrowed_t) : handle(h) { inc_ref(); } + object(handle h, stolen_t) : handle(h) { } +}; + +/** \rst + Declare that a `handle` or ``PyObject *`` is a certain type and borrow the reference. + The target type ``T`` must be `object` or one of its derived classes. The function + doesn't do any conversions or checks. It's up to the user to make sure that the + target type is correct. + + .. code-block:: cpp + + PyObject *p = PyList_GetItem(obj, index); + py::object o = reinterpret_borrow(p); + // or + py::tuple t = reinterpret_borrow(p); // <-- `p` must be already be a `tuple` +\endrst */ +template T reinterpret_borrow(handle h) { return {h, object::borrowed_t{}}; } + +/** \rst + Like `reinterpret_borrow`, but steals the reference. + + .. code-block:: cpp + + PyObject *p = PyObject_Str(obj); + py::str s = reinterpret_steal(p); // <-- `p` must be already be a `str` +\endrst */ +template T reinterpret_steal(handle h) { return {h, object::stolen_t{}}; } + +NAMESPACE_BEGIN(detail) +inline std::string error_string(); +NAMESPACE_END(detail) + +/// Fetch and hold an error which was already set in Python. An instance of this is typically +/// thrown to propagate python-side errors back through C++ which can either be caught manually or +/// else falls back to the function dispatcher (which then raises the captured error back to +/// python). +class error_already_set : public std::runtime_error { +public: + /// Constructs a new exception from the current Python error indicator, if any. The current + /// Python error indicator will be cleared. + error_already_set() : std::runtime_error(detail::error_string()) { + PyErr_Fetch(&type.ptr(), &value.ptr(), &trace.ptr()); + } + + inline ~error_already_set(); + + /// Give the currently-held error back to Python, if any. If there is currently a Python error + /// already set it is cleared first. After this call, the current object no longer stores the + /// error variables (but the `.what()` string is still available). + void restore() { PyErr_Restore(type.release().ptr(), value.release().ptr(), trace.release().ptr()); } + + // Does nothing; provided for backwards compatibility. + PYBIND11_DEPRECATED("Use of error_already_set.clear() is deprecated") + void clear() {} + + /// Check if the currently trapped error type matches the given Python exception class (or a + /// subclass thereof). May also be passed a tuple to search for any exception class matches in + /// the given tuple. + bool matches(handle ex) const { return PyErr_GivenExceptionMatches(ex.ptr(), type.ptr()); } + +private: + object type, value, trace; +}; + +/** \defgroup python_builtins _ + Unless stated otherwise, the following C++ functions behave the same + as their Python counterparts. + */ + +/** \ingroup python_builtins + \rst + Return true if ``obj`` is an instance of ``T``. Type ``T`` must be a subclass of + `object` or a class which was exposed to Python as ``py::class_``. +\endrst */ +template ::value, int> = 0> +bool isinstance(handle obj) { return T::check_(obj); } + +template ::value, int> = 0> +bool isinstance(handle obj) { return detail::isinstance_generic(obj, typeid(T)); } + +template <> inline bool isinstance(handle obj) = delete; +template <> inline bool isinstance(handle obj) { return obj.ptr() != nullptr; } + +/// \ingroup python_builtins +/// Return true if ``obj`` is an instance of the ``type``. +inline bool isinstance(handle obj, handle type) { + const auto result = PyObject_IsInstance(obj.ptr(), type.ptr()); + if (result == -1) + throw error_already_set(); + return result != 0; +} + +/// \addtogroup python_builtins +/// @{ +inline bool hasattr(handle obj, handle name) { + return PyObject_HasAttr(obj.ptr(), name.ptr()) == 1; +} + +inline bool hasattr(handle obj, const char *name) { + return PyObject_HasAttrString(obj.ptr(), name) == 1; +} + +inline object getattr(handle obj, handle name) { + PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr()); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); +} + +inline object getattr(handle obj, const char *name) { + PyObject *result = PyObject_GetAttrString(obj.ptr(), name); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); +} + +inline object getattr(handle obj, handle name, handle default_) { + if (PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr())) { + return reinterpret_steal(result); + } else { + PyErr_Clear(); + return reinterpret_borrow(default_); + } +} + +inline object getattr(handle obj, const char *name, handle default_) { + if (PyObject *result = PyObject_GetAttrString(obj.ptr(), name)) { + return reinterpret_steal(result); + } else { + PyErr_Clear(); + return reinterpret_borrow(default_); + } +} + +inline void setattr(handle obj, handle name, handle value) { + if (PyObject_SetAttr(obj.ptr(), name.ptr(), value.ptr()) != 0) { throw error_already_set(); } +} + +inline void setattr(handle obj, const char *name, handle value) { + if (PyObject_SetAttrString(obj.ptr(), name, value.ptr()) != 0) { throw error_already_set(); } +} + +inline ssize_t hash(handle obj) { + auto h = PyObject_Hash(obj.ptr()); + if (h == -1) { throw error_already_set(); } + return h; +} + +/// @} python_builtins + +NAMESPACE_BEGIN(detail) +inline handle get_function(handle value) { + if (value) { +#if PY_MAJOR_VERSION >= 3 + if (PyInstanceMethod_Check(value.ptr())) + value = PyInstanceMethod_GET_FUNCTION(value.ptr()); + else +#endif + if (PyMethod_Check(value.ptr())) + value = PyMethod_GET_FUNCTION(value.ptr()); + } + return value; +} + +// Helper aliases/functions to support implicit casting of values given to python accessors/methods. +// When given a pyobject, this simply returns the pyobject as-is; for other C++ type, the value goes +// through pybind11::cast(obj) to convert it to an `object`. +template ::value, int> = 0> +auto object_or_cast(T &&o) -> decltype(std::forward(o)) { return std::forward(o); } +// The following casting version is implemented in cast.h: +template ::value, int> = 0> +object object_or_cast(T &&o); +// Match a PyObject*, which we want to convert directly to handle via its converting constructor +inline handle object_or_cast(PyObject *ptr) { return ptr; } + + +template +class accessor : public object_api> { + using key_type = typename Policy::key_type; + +public: + accessor(handle obj, key_type key) : obj(obj), key(std::move(key)) { } + accessor(const accessor &) = default; + accessor(accessor &&) = default; + + // accessor overload required to override default assignment operator (templates are not allowed + // to replace default compiler-generated assignments). + void operator=(const accessor &a) && { std::move(*this).operator=(handle(a)); } + void operator=(const accessor &a) & { operator=(handle(a)); } + + template void operator=(T &&value) && { + Policy::set(obj, key, object_or_cast(std::forward(value))); + } + template void operator=(T &&value) & { + get_cache() = reinterpret_borrow(object_or_cast(std::forward(value))); + } + + template + PYBIND11_DEPRECATED("Use of obj.attr(...) as bool is deprecated in favor of pybind11::hasattr(obj, ...)") + explicit operator enable_if_t::value || + std::is_same::value, bool>() const { + return hasattr(obj, key); + } + template + PYBIND11_DEPRECATED("Use of obj[key] as bool is deprecated in favor of obj.contains(key)") + explicit operator enable_if_t::value, bool>() const { + return obj.contains(key); + } + + operator object() const { return get_cache(); } + PyObject *ptr() const { return get_cache().ptr(); } + template T cast() const { return get_cache().template cast(); } + +private: + object &get_cache() const { + if (!cache) { cache = Policy::get(obj, key); } + return cache; + } + +private: + handle obj; + key_type key; + mutable object cache; +}; + +NAMESPACE_BEGIN(accessor_policies) +struct obj_attr { + using key_type = object; + static object get(handle obj, handle key) { return getattr(obj, key); } + static void set(handle obj, handle key, handle val) { setattr(obj, key, val); } +}; + +struct str_attr { + using key_type = const char *; + static object get(handle obj, const char *key) { return getattr(obj, key); } + static void set(handle obj, const char *key, handle val) { setattr(obj, key, val); } +}; + +struct generic_item { + using key_type = object; + + static object get(handle obj, handle key) { + PyObject *result = PyObject_GetItem(obj.ptr(), key.ptr()); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); + } + + static void set(handle obj, handle key, handle val) { + if (PyObject_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) { throw error_already_set(); } + } +}; + +struct sequence_item { + using key_type = size_t; + + static object get(handle obj, size_t index) { + PyObject *result = PySequence_GetItem(obj.ptr(), static_cast(index)); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); + } + + static void set(handle obj, size_t index, handle val) { + // PySequence_SetItem does not steal a reference to 'val' + if (PySequence_SetItem(obj.ptr(), static_cast(index), val.ptr()) != 0) { + throw error_already_set(); + } + } +}; + +struct list_item { + using key_type = size_t; + + static object get(handle obj, size_t index) { + PyObject *result = PyList_GetItem(obj.ptr(), static_cast(index)); + if (!result) { throw error_already_set(); } + return reinterpret_borrow(result); + } + + static void set(handle obj, size_t index, handle val) { + // PyList_SetItem steals a reference to 'val' + if (PyList_SetItem(obj.ptr(), static_cast(index), val.inc_ref().ptr()) != 0) { + throw error_already_set(); + } + } +}; + +struct tuple_item { + using key_type = size_t; + + static object get(handle obj, size_t index) { + PyObject *result = PyTuple_GetItem(obj.ptr(), static_cast(index)); + if (!result) { throw error_already_set(); } + return reinterpret_borrow(result); + } + + static void set(handle obj, size_t index, handle val) { + // PyTuple_SetItem steals a reference to 'val' + if (PyTuple_SetItem(obj.ptr(), static_cast(index), val.inc_ref().ptr()) != 0) { + throw error_already_set(); + } + } +}; +NAMESPACE_END(accessor_policies) + +/// STL iterator template used for tuple, list, sequence and dict +template +class generic_iterator : public Policy { + using It = generic_iterator; + +public: + using difference_type = ssize_t; + using iterator_category = typename Policy::iterator_category; + using value_type = typename Policy::value_type; + using reference = typename Policy::reference; + using pointer = typename Policy::pointer; + + generic_iterator() = default; + generic_iterator(handle seq, ssize_t index) : Policy(seq, index) { } + + reference operator*() const { return Policy::dereference(); } + reference operator[](difference_type n) const { return *(*this + n); } + pointer operator->() const { return **this; } + + It &operator++() { Policy::increment(); return *this; } + It operator++(int) { auto copy = *this; Policy::increment(); return copy; } + It &operator--() { Policy::decrement(); return *this; } + It operator--(int) { auto copy = *this; Policy::decrement(); return copy; } + It &operator+=(difference_type n) { Policy::advance(n); return *this; } + It &operator-=(difference_type n) { Policy::advance(-n); return *this; } + + friend It operator+(const It &a, difference_type n) { auto copy = a; return copy += n; } + friend It operator+(difference_type n, const It &b) { return b + n; } + friend It operator-(const It &a, difference_type n) { auto copy = a; return copy -= n; } + friend difference_type operator-(const It &a, const It &b) { return a.distance_to(b); } + + friend bool operator==(const It &a, const It &b) { return a.equal(b); } + friend bool operator!=(const It &a, const It &b) { return !(a == b); } + friend bool operator< (const It &a, const It &b) { return b - a > 0; } + friend bool operator> (const It &a, const It &b) { return b < a; } + friend bool operator>=(const It &a, const It &b) { return !(a < b); } + friend bool operator<=(const It &a, const It &b) { return !(a > b); } +}; + +NAMESPACE_BEGIN(iterator_policies) +/// Quick proxy class needed to implement ``operator->`` for iterators which can't return pointers +template +struct arrow_proxy { + T value; + + arrow_proxy(T &&value) : value(std::move(value)) { } + T *operator->() const { return &value; } +}; + +/// Lightweight iterator policy using just a simple pointer: see ``PySequence_Fast_ITEMS`` +class sequence_fast_readonly { +protected: + using iterator_category = std::random_access_iterator_tag; + using value_type = handle; + using reference = const handle; + using pointer = arrow_proxy; + + sequence_fast_readonly(handle obj, ssize_t n) : ptr(PySequence_Fast_ITEMS(obj.ptr()) + n) { } + + reference dereference() const { return *ptr; } + void increment() { ++ptr; } + void decrement() { --ptr; } + void advance(ssize_t n) { ptr += n; } + bool equal(const sequence_fast_readonly &b) const { return ptr == b.ptr; } + ssize_t distance_to(const sequence_fast_readonly &b) const { return ptr - b.ptr; } + +private: + PyObject **ptr; +}; + +/// Full read and write access using the sequence protocol: see ``detail::sequence_accessor`` +class sequence_slow_readwrite { +protected: + using iterator_category = std::random_access_iterator_tag; + using value_type = object; + using reference = sequence_accessor; + using pointer = arrow_proxy; + + sequence_slow_readwrite(handle obj, ssize_t index) : obj(obj), index(index) { } + + reference dereference() const { return {obj, static_cast(index)}; } + void increment() { ++index; } + void decrement() { --index; } + void advance(ssize_t n) { index += n; } + bool equal(const sequence_slow_readwrite &b) const { return index == b.index; } + ssize_t distance_to(const sequence_slow_readwrite &b) const { return index - b.index; } + +private: + handle obj; + ssize_t index; +}; + +/// Python's dictionary protocol permits this to be a forward iterator +class dict_readonly { +protected: + using iterator_category = std::forward_iterator_tag; + using value_type = std::pair; + using reference = const value_type; + using pointer = arrow_proxy; + + dict_readonly() = default; + dict_readonly(handle obj, ssize_t pos) : obj(obj), pos(pos) { increment(); } + + reference dereference() const { return {key, value}; } + void increment() { if (!PyDict_Next(obj.ptr(), &pos, &key, &value)) { pos = -1; } } + bool equal(const dict_readonly &b) const { return pos == b.pos; } + +private: + handle obj; + PyObject *key, *value; + ssize_t pos = -1; +}; +NAMESPACE_END(iterator_policies) + +#if !defined(PYPY_VERSION) +using tuple_iterator = generic_iterator; +using list_iterator = generic_iterator; +#else +using tuple_iterator = generic_iterator; +using list_iterator = generic_iterator; +#endif + +using sequence_iterator = generic_iterator; +using dict_iterator = generic_iterator; + +inline bool PyIterable_Check(PyObject *obj) { + PyObject *iter = PyObject_GetIter(obj); + if (iter) { + Py_DECREF(iter); + return true; + } else { + PyErr_Clear(); + return false; + } +} + +inline bool PyNone_Check(PyObject *o) { return o == Py_None; } + +inline bool PyUnicode_Check_Permissive(PyObject *o) { return PyUnicode_Check(o) || PYBIND11_BYTES_CHECK(o); } + +class kwargs_proxy : public handle { +public: + explicit kwargs_proxy(handle h) : handle(h) { } +}; + +class args_proxy : public handle { +public: + explicit args_proxy(handle h) : handle(h) { } + kwargs_proxy operator*() const { return kwargs_proxy(*this); } +}; + +/// Python argument categories (using PEP 448 terms) +template using is_keyword = std::is_base_of; +template using is_s_unpacking = std::is_same; // * unpacking +template using is_ds_unpacking = std::is_same; // ** unpacking +template using is_positional = satisfies_none_of; +template using is_keyword_or_ds = satisfies_any_of; + +// Call argument collector forward declarations +template +class simple_collector; +template +class unpacking_collector; + +NAMESPACE_END(detail) + +// TODO: After the deprecated constructors are removed, this macro can be simplified by +// inheriting ctors: `using Parent::Parent`. It's not an option right now because +// the `using` statement triggers the parent deprecation warning even if the ctor +// isn't even used. +#define PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ + public: \ + PYBIND11_DEPRECATED("Use reinterpret_borrow<"#Name">() or reinterpret_steal<"#Name">()") \ + Name(handle h, bool is_borrowed) : Parent(is_borrowed ? Parent(h, borrowed_t{}) : Parent(h, stolen_t{})) { } \ + Name(handle h, borrowed_t) : Parent(h, borrowed_t{}) { } \ + Name(handle h, stolen_t) : Parent(h, stolen_t{}) { } \ + PYBIND11_DEPRECATED("Use py::isinstance(obj) instead") \ + bool check() const { return m_ptr != nullptr && (bool) CheckFun(m_ptr); } \ + static bool check_(handle h) { return h.ptr() != nullptr && CheckFun(h.ptr()); } + +#define PYBIND11_OBJECT_CVT(Name, Parent, CheckFun, ConvertFun) \ + PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ + /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \ + Name(const object &o) \ + : Parent(check_(o) ? o.inc_ref().ptr() : ConvertFun(o.ptr()), stolen_t{}) \ + { if (!m_ptr) throw error_already_set(); } \ + Name(object &&o) \ + : Parent(check_(o) ? o.release().ptr() : ConvertFun(o.ptr()), stolen_t{}) \ + { if (!m_ptr) throw error_already_set(); } \ + template \ + Name(const ::pybind11::detail::accessor &a) : Name(object(a)) { } + +#define PYBIND11_OBJECT(Name, Parent, CheckFun) \ + PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ + /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \ + Name(const object &o) : Parent(o) { } \ + Name(object &&o) : Parent(std::move(o)) { } + +#define PYBIND11_OBJECT_DEFAULT(Name, Parent, CheckFun) \ + PYBIND11_OBJECT(Name, Parent, CheckFun) \ + Name() : Parent() { } + +/// \addtogroup pytypes +/// @{ + +/** \rst + Wraps a Python iterator so that it can also be used as a C++ input iterator + + Caveat: copying an iterator does not (and cannot) clone the internal + state of the Python iterable. This also applies to the post-increment + operator. This iterator should only be used to retrieve the current + value using ``operator*()``. +\endrst */ +class iterator : public object { +public: + using iterator_category = std::input_iterator_tag; + using difference_type = ssize_t; + using value_type = handle; + using reference = const handle; + using pointer = const handle *; + + PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) + + iterator& operator++() { + advance(); + return *this; + } + + iterator operator++(int) { + auto rv = *this; + advance(); + return rv; + } + + reference operator*() const { + if (m_ptr && !value.ptr()) { + auto& self = const_cast(*this); + self.advance(); + } + return value; + } + + pointer operator->() const { operator*(); return &value; } + + /** \rst + The value which marks the end of the iteration. ``it == iterator::sentinel()`` + is equivalent to catching ``StopIteration`` in Python. + + .. code-block:: cpp + + void foo(py::iterator it) { + while (it != py::iterator::sentinel()) { + // use `*it` + ++it; + } + } + \endrst */ + static iterator sentinel() { return {}; } + + friend bool operator==(const iterator &a, const iterator &b) { return a->ptr() == b->ptr(); } + friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); } + +private: + void advance() { + value = reinterpret_steal(PyIter_Next(m_ptr)); + if (PyErr_Occurred()) { throw error_already_set(); } + } + +private: + object value = {}; +}; + +class iterable : public object { +public: + PYBIND11_OBJECT_DEFAULT(iterable, object, detail::PyIterable_Check) +}; + +class bytes; + +class str : public object { +public: + PYBIND11_OBJECT_CVT(str, object, detail::PyUnicode_Check_Permissive, raw_str) + + str(const char *c, size_t n) + : object(PyUnicode_FromStringAndSize(c, (ssize_t) n), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate string object!"); + } + + // 'explicit' is explicitly omitted from the following constructors to allow implicit conversion to py::str from C++ string-like objects + str(const char *c = "") + : object(PyUnicode_FromString(c), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate string object!"); + } + + str(const std::string &s) : str(s.data(), s.size()) { } + + explicit str(const bytes &b); + + /** \rst + Return a string representation of the object. This is analogous to + the ``str()`` function in Python. + \endrst */ + explicit str(handle h) : object(raw_str(h.ptr()), stolen_t{}) { } + + operator std::string() const { + object temp = *this; + if (PyUnicode_Check(m_ptr)) { + temp = reinterpret_steal(PyUnicode_AsUTF8String(m_ptr)); + if (!temp) + pybind11_fail("Unable to extract string contents! (encoding issue)"); + } + char *buffer; + ssize_t length; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length)) + pybind11_fail("Unable to extract string contents! (invalid type)"); + return std::string(buffer, (size_t) length); + } + + template + str format(Args &&...args) const { + return attr("format")(std::forward(args)...); + } + +private: + /// Return string representation -- always returns a new reference, even if already a str + static PyObject *raw_str(PyObject *op) { + PyObject *str_value = PyObject_Str(op); +#if PY_MAJOR_VERSION < 3 + if (!str_value) throw error_already_set(); + PyObject *unicode = PyUnicode_FromEncodedObject(str_value, "utf-8", nullptr); + Py_XDECREF(str_value); str_value = unicode; +#endif + return str_value; + } +}; +/// @} pytypes + +inline namespace literals { +/** \rst + String literal version of `str` + \endrst */ +inline str operator"" _s(const char *s, size_t size) { return {s, size}; } +} + +/// \addtogroup pytypes +/// @{ +class bytes : public object { +public: + PYBIND11_OBJECT(bytes, object, PYBIND11_BYTES_CHECK) + + // Allow implicit conversion: + bytes(const char *c = "") + : object(PYBIND11_BYTES_FROM_STRING(c), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate bytes object!"); + } + + bytes(const char *c, size_t n) + : object(PYBIND11_BYTES_FROM_STRING_AND_SIZE(c, (ssize_t) n), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate bytes object!"); + } + + // Allow implicit conversion: + bytes(const std::string &s) : bytes(s.data(), s.size()) { } + + explicit bytes(const pybind11::str &s); + + operator std::string() const { + char *buffer; + ssize_t length; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(m_ptr, &buffer, &length)) + pybind11_fail("Unable to extract bytes contents!"); + return std::string(buffer, (size_t) length); + } +}; + +inline bytes::bytes(const pybind11::str &s) { + object temp = s; + if (PyUnicode_Check(s.ptr())) { + temp = reinterpret_steal(PyUnicode_AsUTF8String(s.ptr())); + if (!temp) + pybind11_fail("Unable to extract string contents! (encoding issue)"); + } + char *buffer; + ssize_t length; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length)) + pybind11_fail("Unable to extract string contents! (invalid type)"); + auto obj = reinterpret_steal(PYBIND11_BYTES_FROM_STRING_AND_SIZE(buffer, length)); + if (!obj) + pybind11_fail("Could not allocate bytes object!"); + m_ptr = obj.release().ptr(); +} + +inline str::str(const bytes& b) { + char *buffer; + ssize_t length; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(b.ptr(), &buffer, &length)) + pybind11_fail("Unable to extract bytes contents!"); + auto obj = reinterpret_steal(PyUnicode_FromStringAndSize(buffer, (ssize_t) length)); + if (!obj) + pybind11_fail("Could not allocate string object!"); + m_ptr = obj.release().ptr(); +} + +class none : public object { +public: + PYBIND11_OBJECT(none, object, detail::PyNone_Check) + none() : object(Py_None, borrowed_t{}) { } +}; + +class bool_ : public object { +public: + PYBIND11_OBJECT_CVT(bool_, object, PyBool_Check, raw_bool) + bool_() : object(Py_False, borrowed_t{}) { } + // Allow implicit conversion from and to `bool`: + bool_(bool value) : object(value ? Py_True : Py_False, borrowed_t{}) { } + operator bool() const { return m_ptr && PyLong_AsLong(m_ptr) != 0; } + +private: + /// Return the truth value of an object -- always returns a new reference + static PyObject *raw_bool(PyObject *op) { + const auto value = PyObject_IsTrue(op); + if (value == -1) return nullptr; + return handle(value ? Py_True : Py_False).inc_ref().ptr(); + } +}; + +NAMESPACE_BEGIN(detail) +// Converts a value to the given unsigned type. If an error occurs, you get back (Unsigned) -1; +// otherwise you get back the unsigned long or unsigned long long value cast to (Unsigned). +// (The distinction is critically important when casting a returned -1 error value to some other +// unsigned type: (A)-1 != (B)-1 when A and B are unsigned types of different sizes). +template +Unsigned as_unsigned(PyObject *o) { + if (sizeof(Unsigned) <= sizeof(unsigned long) +#if PY_VERSION_HEX < 0x03000000 + || PyInt_Check(o) +#endif + ) { + unsigned long v = PyLong_AsUnsignedLong(o); + return v == (unsigned long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v; + } + else { + unsigned long long v = PyLong_AsUnsignedLongLong(o); + return v == (unsigned long long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v; + } +} +NAMESPACE_END(detail) + +class int_ : public object { +public: + PYBIND11_OBJECT_CVT(int_, object, PYBIND11_LONG_CHECK, PyNumber_Long) + int_() : object(PyLong_FromLong(0), stolen_t{}) { } + // Allow implicit conversion from C++ integral types: + template ::value, int> = 0> + int_(T value) { + if (sizeof(T) <= sizeof(long)) { + if (std::is_signed::value) + m_ptr = PyLong_FromLong((long) value); + else + m_ptr = PyLong_FromUnsignedLong((unsigned long) value); + } else { + if (std::is_signed::value) + m_ptr = PyLong_FromLongLong((long long) value); + else + m_ptr = PyLong_FromUnsignedLongLong((unsigned long long) value); + } + if (!m_ptr) pybind11_fail("Could not allocate int object!"); + } + + template ::value, int> = 0> + operator T() const { + return std::is_unsigned::value + ? detail::as_unsigned(m_ptr) + : sizeof(T) <= sizeof(long) + ? (T) PyLong_AsLong(m_ptr) + : (T) PYBIND11_LONG_AS_LONGLONG(m_ptr); + } +}; + +class float_ : public object { +public: + PYBIND11_OBJECT_CVT(float_, object, PyFloat_Check, PyNumber_Float) + // Allow implicit conversion from float/double: + float_(float value) : object(PyFloat_FromDouble((double) value), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate float object!"); + } + float_(double value = .0) : object(PyFloat_FromDouble((double) value), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate float object!"); + } + operator float() const { return (float) PyFloat_AsDouble(m_ptr); } + operator double() const { return (double) PyFloat_AsDouble(m_ptr); } +}; + +class weakref : public object { +public: + PYBIND11_OBJECT_DEFAULT(weakref, object, PyWeakref_Check) + explicit weakref(handle obj, handle callback = {}) + : object(PyWeakref_NewRef(obj.ptr(), callback.ptr()), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate weak reference!"); + } +}; + +class slice : public object { +public: + PYBIND11_OBJECT_DEFAULT(slice, object, PySlice_Check) + slice(ssize_t start_, ssize_t stop_, ssize_t step_) { + int_ start(start_), stop(stop_), step(step_); + m_ptr = PySlice_New(start.ptr(), stop.ptr(), step.ptr()); + if (!m_ptr) pybind11_fail("Could not allocate slice object!"); + } + bool compute(size_t length, size_t *start, size_t *stop, size_t *step, + size_t *slicelength) const { + return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr, + (ssize_t) length, (ssize_t *) start, + (ssize_t *) stop, (ssize_t *) step, + (ssize_t *) slicelength) == 0; + } +}; + +class capsule : public object { +public: + PYBIND11_OBJECT_DEFAULT(capsule, object, PyCapsule_CheckExact) + PYBIND11_DEPRECATED("Use reinterpret_borrow() or reinterpret_steal()") + capsule(PyObject *ptr, bool is_borrowed) : object(is_borrowed ? object(ptr, borrowed_t{}) : object(ptr, stolen_t{})) { } + + explicit capsule(const void *value, const char *name = nullptr, void (*destructor)(PyObject *) = nullptr) + : object(PyCapsule_New(const_cast(value), name, destructor), stolen_t{}) { + if (!m_ptr) + pybind11_fail("Could not allocate capsule object!"); + } + + PYBIND11_DEPRECATED("Please pass a destructor that takes a void pointer as input") + capsule(const void *value, void (*destruct)(PyObject *)) + : object(PyCapsule_New(const_cast(value), nullptr, destruct), stolen_t{}) { + if (!m_ptr) + pybind11_fail("Could not allocate capsule object!"); + } + + capsule(const void *value, void (*destructor)(void *)) { + m_ptr = PyCapsule_New(const_cast(value), nullptr, [](PyObject *o) { + auto destructor = reinterpret_cast(PyCapsule_GetContext(o)); + void *ptr = PyCapsule_GetPointer(o, nullptr); + destructor(ptr); + }); + + if (!m_ptr) + pybind11_fail("Could not allocate capsule object!"); + + if (PyCapsule_SetContext(m_ptr, (void *) destructor) != 0) + pybind11_fail("Could not set capsule context!"); + } + + capsule(void (*destructor)()) { + m_ptr = PyCapsule_New(reinterpret_cast(destructor), nullptr, [](PyObject *o) { + auto destructor = reinterpret_cast(PyCapsule_GetPointer(o, nullptr)); + destructor(); + }); + + if (!m_ptr) + pybind11_fail("Could not allocate capsule object!"); + } + + template operator T *() const { + auto name = this->name(); + T * result = static_cast(PyCapsule_GetPointer(m_ptr, name)); + if (!result) pybind11_fail("Unable to extract capsule contents!"); + return result; + } + + const char *name() const { return PyCapsule_GetName(m_ptr); } +}; + +class tuple : public object { +public: + PYBIND11_OBJECT_CVT(tuple, object, PyTuple_Check, PySequence_Tuple) + explicit tuple(size_t size = 0) : object(PyTuple_New((ssize_t) size), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate tuple object!"); + } + size_t size() const { return (size_t) PyTuple_Size(m_ptr); } + detail::tuple_accessor operator[](size_t index) const { return {*this, index}; } + detail::tuple_iterator begin() const { return {*this, 0}; } + detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; } +}; + +class dict : public object { +public: + PYBIND11_OBJECT_CVT(dict, object, PyDict_Check, raw_dict) + dict() : object(PyDict_New(), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate dict object!"); + } + template ...>::value>, + // MSVC workaround: it can't compile an out-of-line definition, so defer the collector + typename collector = detail::deferred_t, Args...>> + explicit dict(Args &&...args) : dict(collector(std::forward(args)...).kwargs()) { } + + size_t size() const { return (size_t) PyDict_Size(m_ptr); } + detail::dict_iterator begin() const { return {*this, 0}; } + detail::dict_iterator end() const { return {}; } + void clear() const { PyDict_Clear(ptr()); } + bool contains(handle key) const { return PyDict_Contains(ptr(), key.ptr()) == 1; } + bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; } + +private: + /// Call the `dict` Python type -- always returns a new reference + static PyObject *raw_dict(PyObject *op) { + if (PyDict_Check(op)) + return handle(op).inc_ref().ptr(); + return PyObject_CallFunctionObjArgs((PyObject *) &PyDict_Type, op, nullptr); + } +}; + +class sequence : public object { +public: + PYBIND11_OBJECT_DEFAULT(sequence, object, PySequence_Check) + size_t size() const { return (size_t) PySequence_Size(m_ptr); } + detail::sequence_accessor operator[](size_t index) const { return {*this, index}; } + detail::sequence_iterator begin() const { return {*this, 0}; } + detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; } +}; + +class list : public object { +public: + PYBIND11_OBJECT_CVT(list, object, PyList_Check, PySequence_List) + explicit list(size_t size = 0) : object(PyList_New((ssize_t) size), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate list object!"); + } + size_t size() const { return (size_t) PyList_Size(m_ptr); } + detail::list_accessor operator[](size_t index) const { return {*this, index}; } + detail::list_iterator begin() const { return {*this, 0}; } + detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; } + template void append(T &&val) const { + PyList_Append(m_ptr, detail::object_or_cast(std::forward(val)).ptr()); + } +}; + +class args : public tuple { PYBIND11_OBJECT_DEFAULT(args, tuple, PyTuple_Check) }; +class kwargs : public dict { PYBIND11_OBJECT_DEFAULT(kwargs, dict, PyDict_Check) }; + +class set : public object { +public: + PYBIND11_OBJECT_CVT(set, object, PySet_Check, PySet_New) + set() : object(PySet_New(nullptr), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate set object!"); + } + size_t size() const { return (size_t) PySet_Size(m_ptr); } + template bool add(T &&val) const { + return PySet_Add(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 0; + } + void clear() const { PySet_Clear(m_ptr); } +}; + +class function : public object { +public: + PYBIND11_OBJECT_DEFAULT(function, object, PyCallable_Check) + handle cpp_function() const { + handle fun = detail::get_function(m_ptr); + if (fun && PyCFunction_Check(fun.ptr())) + return fun; + return handle(); + } + bool is_cpp_function() const { return (bool) cpp_function(); } +}; + +class buffer : public object { +public: + PYBIND11_OBJECT_DEFAULT(buffer, object, PyObject_CheckBuffer) + + buffer_info request(bool writable = false) { + int flags = PyBUF_STRIDES | PyBUF_FORMAT; + if (writable) flags |= PyBUF_WRITABLE; + Py_buffer *view = new Py_buffer(); + if (PyObject_GetBuffer(m_ptr, view, flags) != 0) { + delete view; + throw error_already_set(); + } + return buffer_info(view); + } +}; + +class memoryview : public object { +public: + explicit memoryview(const buffer_info& info) { + static Py_buffer buf { }; + // Py_buffer uses signed sizes, strides and shape!.. + static std::vector py_strides { }; + static std::vector py_shape { }; + buf.buf = info.ptr; + buf.itemsize = info.itemsize; + buf.format = const_cast(info.format.c_str()); + buf.ndim = (int) info.ndim; + buf.len = info.size; + py_strides.clear(); + py_shape.clear(); + for (size_t i = 0; i < (size_t) info.ndim; ++i) { + py_strides.push_back(info.strides[i]); + py_shape.push_back(info.shape[i]); + } + buf.strides = py_strides.data(); + buf.shape = py_shape.data(); + buf.suboffsets = nullptr; + buf.readonly = false; + buf.internal = nullptr; + + m_ptr = PyMemoryView_FromBuffer(&buf); + if (!m_ptr) + pybind11_fail("Unable to create memoryview from buffer descriptor"); + } + + PYBIND11_OBJECT_CVT(memoryview, object, PyMemoryView_Check, PyMemoryView_FromObject) +}; +/// @} pytypes + +/// \addtogroup python_builtins +/// @{ +inline size_t len(handle h) { + ssize_t result = PyObject_Length(h.ptr()); + if (result < 0) + pybind11_fail("Unable to compute length of object"); + return (size_t) result; +} + +inline str repr(handle h) { + PyObject *str_value = PyObject_Repr(h.ptr()); + if (!str_value) throw error_already_set(); +#if PY_MAJOR_VERSION < 3 + PyObject *unicode = PyUnicode_FromEncodedObject(str_value, "utf-8", nullptr); + Py_XDECREF(str_value); str_value = unicode; + if (!str_value) throw error_already_set(); +#endif + return reinterpret_steal(str_value); +} + +inline iterator iter(handle obj) { + PyObject *result = PyObject_GetIter(obj.ptr()); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); +} +/// @} python_builtins + +NAMESPACE_BEGIN(detail) +template iterator object_api::begin() const { return iter(derived()); } +template iterator object_api::end() const { return iterator::sentinel(); } +template item_accessor object_api::operator[](handle key) const { + return {derived(), reinterpret_borrow(key)}; +} +template item_accessor object_api::operator[](const char *key) const { + return {derived(), pybind11::str(key)}; +} +template obj_attr_accessor object_api::attr(handle key) const { + return {derived(), reinterpret_borrow(key)}; +} +template str_attr_accessor object_api::attr(const char *key) const { + return {derived(), key}; +} +template args_proxy object_api::operator*() const { + return args_proxy(derived().ptr()); +} +template template bool object_api::contains(T &&item) const { + return attr("__contains__")(std::forward(item)).template cast(); +} + +template +pybind11::str object_api::str() const { return pybind11::str(derived()); } + +template +str_attr_accessor object_api::doc() const { return attr("__doc__"); } + +template +handle object_api::get_type() const { return (PyObject *) Py_TYPE(derived().ptr()); } + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/stl.h b/ml/dlib/dlib/external/pybind11/include/pybind11/stl.h new file mode 100644 index 000000000..90eb7ea2e --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/stl.h @@ -0,0 +1,370 @@ +/* + pybind11/stl.h: Transparent conversion for STL data types + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +#ifdef __has_include +// std::optional (but including it in c++14 mode isn't allowed) +# if defined(PYBIND11_CPP17) && __has_include() +# include +# define PYBIND11_HAS_OPTIONAL 1 +# endif +// std::experimental::optional (but not allowed in c++11 mode) +# if defined(PYBIND11_CPP14) && (__has_include() && \ + !__has_include()) +# include +# define PYBIND11_HAS_EXP_OPTIONAL 1 +# endif +// std::variant +# if defined(PYBIND11_CPP17) && __has_include() +# include +# define PYBIND11_HAS_VARIANT 1 +# endif +#elif defined(_MSC_VER) && defined(PYBIND11_CPP17) +# include +# include +# define PYBIND11_HAS_OPTIONAL 1 +# define PYBIND11_HAS_VARIANT 1 +#endif + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for +/// forwarding a container element). Typically used indirect via forwarded_type(), below. +template +using forwarded_type = conditional_t< + std::is_lvalue_reference::value, remove_reference_t &, remove_reference_t &&>; + +/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically +/// used for forwarding a container's elements. +template +forwarded_type forward_like(U &&u) { + return std::forward>(std::forward(u)); +} + +template struct set_caster { + using type = Type; + using key_conv = make_caster; + + bool load(handle src, bool convert) { + if (!isinstance(src)) + return false; + auto s = reinterpret_borrow(src); + value.clear(); + for (auto entry : s) { + key_conv conv; + if (!conv.load(entry, convert)) + return false; + value.insert(cast_op(std::move(conv))); + } + return true; + } + + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + pybind11::set s; + for (auto &&value : src) { + auto value_ = reinterpret_steal(key_conv::cast(forward_like(value), policy, parent)); + if (!value_ || !s.add(value_)) + return handle(); + } + return s.release(); + } + + PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name() + _("]")); +}; + +template struct map_caster { + using key_conv = make_caster; + using value_conv = make_caster; + + bool load(handle src, bool convert) { + if (!isinstance(src)) + return false; + auto d = reinterpret_borrow(src); + value.clear(); + for (auto it : d) { + key_conv kconv; + value_conv vconv; + if (!kconv.load(it.first.ptr(), convert) || + !vconv.load(it.second.ptr(), convert)) + return false; + value.emplace(cast_op(std::move(kconv)), cast_op(std::move(vconv))); + } + return true; + } + + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + dict d; + for (auto &&kv : src) { + auto key = reinterpret_steal(key_conv::cast(forward_like(kv.first), policy, parent)); + auto value = reinterpret_steal(value_conv::cast(forward_like(kv.second), policy, parent)); + if (!key || !value) + return handle(); + d[key] = value; + } + return d.release(); + } + + PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name() + _(", ") + value_conv::name() + _("]")); +}; + +template struct list_caster { + using value_conv = make_caster; + + bool load(handle src, bool convert) { + if (!isinstance(src)) + return false; + auto s = reinterpret_borrow(src); + value.clear(); + reserve_maybe(s, &value); + for (auto it : s) { + value_conv conv; + if (!conv.load(it, convert)) + return false; + value.push_back(cast_op(std::move(conv))); + } + return true; + } + +private: + template ().reserve(0)), void>::value, int> = 0> + void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); } + void reserve_maybe(sequence, void *) { } + +public: + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + list l(src.size()); + size_t index = 0; + for (auto &&value : src) { + auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); + if (!value_) + return handle(); + PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference + } + return l.release(); + } + + PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name() + _("]")); +}; + +template struct type_caster> + : list_caster, Type> { }; + +template struct type_caster> + : list_caster, Type> { }; + +template struct array_caster { + using value_conv = make_caster; + +private: + template + bool require_size(enable_if_t size) { + if (value.size() != size) + value.resize(size); + return true; + } + template + bool require_size(enable_if_t size) { + return size == Size; + } + +public: + bool load(handle src, bool convert) { + if (!isinstance(src)) + return false; + auto l = reinterpret_borrow(src); + if (!require_size(l.size())) + return false; + size_t ctr = 0; + for (auto it : l) { + value_conv conv; + if (!conv.load(it, convert)) + return false; + value[ctr++] = cast_op(std::move(conv)); + } + return true; + } + + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + list l(src.size()); + size_t index = 0; + for (auto &&value : src) { + auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); + if (!value_) + return handle(); + PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference + } + return l.release(); + } + + PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name() + _(_(""), _("[") + _() + _("]")) + _("]")); +}; + +template struct type_caster> + : array_caster, Type, false, Size> { }; + +template struct type_caster> + : array_caster, Type, true> { }; + +template struct type_caster> + : set_caster, Key> { }; + +template struct type_caster> + : set_caster, Key> { }; + +template struct type_caster> + : map_caster, Key, Value> { }; + +template struct type_caster> + : map_caster, Key, Value> { }; + +// This type caster is intended to be used for std::optional and std::experimental::optional +template struct optional_caster { + using value_conv = make_caster; + + template + static handle cast(T_ &&src, return_value_policy policy, handle parent) { + if (!src) + return none().inc_ref(); + return value_conv::cast(*std::forward(src), policy, parent); + } + + bool load(handle src, bool convert) { + if (!src) { + return false; + } else if (src.is_none()) { + return true; // default-constructed value is already empty + } + value_conv inner_caster; + if (!inner_caster.load(src, convert)) + return false; + + value.emplace(cast_op(std::move(inner_caster))); + return true; + } + + PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name() + _("]")); +}; + +#if PYBIND11_HAS_OPTIONAL +template struct type_caster> + : public optional_caster> {}; + +template<> struct type_caster + : public void_caster {}; +#endif + +#if PYBIND11_HAS_EXP_OPTIONAL +template struct type_caster> + : public optional_caster> {}; + +template<> struct type_caster + : public void_caster {}; +#endif + +/// Visit a variant and cast any found type to Python +struct variant_caster_visitor { + return_value_policy policy; + handle parent; + + using result_type = handle; // required by boost::variant in C++11 + + template + result_type operator()(T &&src) const { + return make_caster::cast(std::forward(src), policy, parent); + } +}; + +/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar +/// `namespace::variant` types which provide a `namespace::visit()` function are handled here +/// automatically using argument-dependent lookup. Users can provide specializations for other +/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`. +template class Variant> +struct visit_helper { + template + static auto call(Args &&...args) -> decltype(visit(std::forward(args)...)) { + return visit(std::forward(args)...); + } +}; + +/// Generic variant caster +template struct variant_caster; + +template class V, typename... Ts> +struct variant_caster> { + static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative."); + + template + bool load_alternative(handle src, bool convert, type_list) { + auto caster = make_caster(); + if (caster.load(src, convert)) { + value = cast_op(caster); + return true; + } + return load_alternative(src, convert, type_list{}); + } + + bool load_alternative(handle, bool, type_list<>) { return false; } + + bool load(handle src, bool convert) { + // Do a first pass without conversions to improve constructor resolution. + // E.g. `py::int_(1).cast>()` needs to fill the `int` + // slot of the variant. Without two-pass loading `double` would be filled + // because it appears first and a conversion is possible. + if (convert && load_alternative(src, false, type_list{})) + return true; + return load_alternative(src, convert, type_list{}); + } + + template + static handle cast(Variant &&src, return_value_policy policy, handle parent) { + return visit_helper::call(variant_caster_visitor{policy, parent}, + std::forward(src)); + } + + using Type = V; + PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster::name()...) + _("]")); +}; + +#if PYBIND11_HAS_VARIANT +template +struct type_caster> : variant_caster> { }; +#endif +NAMESPACE_END(detail) + +inline std::ostream &operator<<(std::ostream &os, const handle &obj) { + os << (std::string) str(obj); + return os; +} + +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/ml/dlib/dlib/external/pybind11/include/pybind11/stl_bind.h b/ml/dlib/dlib/external/pybind11/include/pybind11/stl_bind.h new file mode 100644 index 000000000..38dd68f69 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/include/pybind11/stl_bind.h @@ -0,0 +1,599 @@ +/* + pybind11/std_bind.h: Binding generators for STL data types + + Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "detail/common.h" +#include "operators.h" + +#include +#include + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +/* SFINAE helper class used by 'is_comparable */ +template struct container_traits { + template static std::true_type test_comparable(decltype(std::declval() == std::declval())*); + template static std::false_type test_comparable(...); + template static std::true_type test_value(typename T2::value_type *); + template static std::false_type test_value(...); + template static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *); + template static std::false_type test_pair(...); + + static constexpr const bool is_comparable = std::is_same(nullptr))>::value; + static constexpr const bool is_pair = std::is_same(nullptr, nullptr))>::value; + static constexpr const bool is_vector = std::is_same(nullptr))>::value; + static constexpr const bool is_element = !is_pair && !is_vector; +}; + +/* Default: is_comparable -> std::false_type */ +template +struct is_comparable : std::false_type { }; + +/* For non-map data structures, check whether operator== can be instantiated */ +template +struct is_comparable< + T, enable_if_t::is_element && + container_traits::is_comparable>> + : std::true_type { }; + +/* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */ +template +struct is_comparable::is_vector>> { + static constexpr const bool value = + is_comparable::value; +}; + +/* For pairs, recursively check the two data types */ +template +struct is_comparable::is_pair>> { + static constexpr const bool value = + is_comparable::value && + is_comparable::value; +}; + +/* Fallback functions */ +template void vector_if_copy_constructible(const Args &...) { } +template void vector_if_equal_operator(const Args &...) { } +template void vector_if_insertion_operator(const Args &...) { } +template void vector_modifiers(const Args &...) { } + +template +void vector_if_copy_constructible(enable_if_t::value, Class_> &cl) { + cl.def(init(), "Copy constructor"); +} + +template +void vector_if_equal_operator(enable_if_t::value, Class_> &cl) { + using T = typename Vector::value_type; + + cl.def(self == self); + cl.def(self != self); + + cl.def("count", + [](const Vector &v, const T &x) { + return std::count(v.begin(), v.end(), x); + }, + arg("x"), + "Return the number of times ``x`` appears in the list" + ); + + cl.def("remove", [](Vector &v, const T &x) { + auto p = std::find(v.begin(), v.end(), x); + if (p != v.end()) + v.erase(p); + else + throw value_error(); + }, + arg("x"), + "Remove the first item from the list whose value is x. " + "It is an error if there is no such item." + ); + + cl.def("__contains__", + [](const Vector &v, const T &x) { + return std::find(v.begin(), v.end(), x) != v.end(); + }, + arg("x"), + "Return true the container contains ``x``" + ); +} + +// Vector modifiers -- requires a copyable vector_type: +// (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems +// silly to allow deletion but not insertion, so include them here too.) +template +void vector_modifiers(enable_if_t::value, Class_> &cl) { + using T = typename Vector::value_type; + using SizeType = typename Vector::size_type; + using DiffType = typename Vector::difference_type; + + cl.def("append", + [](Vector &v, const T &value) { v.push_back(value); }, + arg("x"), + "Add an item to the end of the list"); + + cl.def(init([](iterable it) { + auto v = std::unique_ptr(new Vector()); + v->reserve(len(it)); + for (handle h : it) + v->push_back(h.cast()); + return v.release(); + })); + + cl.def("extend", + [](Vector &v, const Vector &src) { + v.insert(v.end(), src.begin(), src.end()); + }, + arg("L"), + "Extend the list by appending all the items in the given list" + ); + + cl.def("insert", + [](Vector &v, SizeType i, const T &x) { + if (i > v.size()) + throw index_error(); + v.insert(v.begin() + (DiffType) i, x); + }, + arg("i") , arg("x"), + "Insert an item at a given position." + ); + + cl.def("pop", + [](Vector &v) { + if (v.empty()) + throw index_error(); + T t = v.back(); + v.pop_back(); + return t; + }, + "Remove and return the last item" + ); + + cl.def("pop", + [](Vector &v, SizeType i) { + if (i >= v.size()) + throw index_error(); + T t = v[i]; + v.erase(v.begin() + (DiffType) i); + return t; + }, + arg("i"), + "Remove and return the item at index ``i``" + ); + + cl.def("__setitem__", + [](Vector &v, SizeType i, const T &t) { + if (i >= v.size()) + throw index_error(); + v[i] = t; + } + ); + + /// Slicing protocol + cl.def("__getitem__", + [](const Vector &v, slice slice) -> Vector * { + size_t start, stop, step, slicelength; + + if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) + throw error_already_set(); + + Vector *seq = new Vector(); + seq->reserve((size_t) slicelength); + + for (size_t i=0; ipush_back(v[start]); + start += step; + } + return seq; + }, + arg("s"), + "Retrieve list elements using a slice object" + ); + + cl.def("__setitem__", + [](Vector &v, slice slice, const Vector &value) { + size_t start, stop, step, slicelength; + if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) + throw error_already_set(); + + if (slicelength != value.size()) + throw std::runtime_error("Left and right hand size of slice assignment have different sizes!"); + + for (size_t i=0; i= v.size()) + throw index_error(); + v.erase(v.begin() + DiffType(i)); + }, + "Delete the list elements at index ``i``" + ); + + cl.def("__delitem__", + [](Vector &v, slice slice) { + size_t start, stop, step, slicelength; + + if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) + throw error_already_set(); + + if (step == 1 && false) { + v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength)); + } else { + for (size_t i = 0; i < slicelength; ++i) { + v.erase(v.begin() + DiffType(start)); + start += step - 1; + } + } + }, + "Delete list elements using a slice object" + ); + +} + +// If the type has an operator[] that doesn't return a reference (most notably std::vector), +// we have to access by copying; otherwise we return by reference. +template using vector_needs_copy = negation< + std::is_same()[typename Vector::size_type()]), typename Vector::value_type &>>; + +// The usual case: access and iterate by reference +template +void vector_accessor(enable_if_t::value, Class_> &cl) { + using T = typename Vector::value_type; + using SizeType = typename Vector::size_type; + using ItType = typename Vector::iterator; + + cl.def("__getitem__", + [](Vector &v, SizeType i) -> T & { + if (i >= v.size()) + throw index_error(); + return v[i]; + }, + return_value_policy::reference_internal // ref + keepalive + ); + + cl.def("__iter__", + [](Vector &v) { + return make_iterator< + return_value_policy::reference_internal, ItType, ItType, T&>( + v.begin(), v.end()); + }, + keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); +} + +// The case for special objects, like std::vector, that have to be returned-by-copy: +template +void vector_accessor(enable_if_t::value, Class_> &cl) { + using T = typename Vector::value_type; + using SizeType = typename Vector::size_type; + using ItType = typename Vector::iterator; + cl.def("__getitem__", + [](const Vector &v, SizeType i) -> T { + if (i >= v.size()) + throw index_error(); + return v[i]; + } + ); + + cl.def("__iter__", + [](Vector &v) { + return make_iterator< + return_value_policy::copy, ItType, ItType, T>( + v.begin(), v.end()); + }, + keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); +} + +template auto vector_if_insertion_operator(Class_ &cl, std::string const &name) + -> decltype(std::declval() << std::declval(), void()) { + using size_type = typename Vector::size_type; + + cl.def("__repr__", + [name](Vector &v) { + std::ostringstream s; + s << name << '['; + for (size_type i=0; i < v.size(); ++i) { + s << v[i]; + if (i != v.size() - 1) + s << ", "; + } + s << ']'; + return s.str(); + }, + "Return the canonical string representation of this list." + ); +} + +// Provide the buffer interface for vectors if we have data() and we have a format for it +// GCC seems to have "void std::vector::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer +template +struct vector_has_data_and_format : std::false_type {}; +template +struct vector_has_data_and_format::format(), std::declval().data()), typename Vector::value_type*>::value>> : std::true_type {}; + +// Add the buffer interface to a vector +template +enable_if_t...>::value> +vector_buffer(Class_& cl) { + using T = typename Vector::value_type; + + static_assert(vector_has_data_and_format::value, "There is not an appropriate format descriptor for this vector"); + + // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here + format_descriptor::format(); + + cl.def_buffer([](Vector& v) -> buffer_info { + return buffer_info(v.data(), static_cast(sizeof(T)), format_descriptor::format(), 1, {v.size()}, {sizeof(T)}); + }); + + cl.def(init([](buffer buf) { + auto info = buf.request(); + if (info.ndim != 1 || info.strides[0] % static_cast(sizeof(T))) + throw type_error("Only valid 1D buffers can be copied to a vector"); + if (!detail::compare_buffer_info::compare(info) || (ssize_t) sizeof(T) != info.itemsize) + throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor::format() + ")"); + + auto vec = std::unique_ptr(new Vector()); + vec->reserve((size_t) info.shape[0]); + T *p = static_cast(info.ptr); + ssize_t step = info.strides[0] / static_cast(sizeof(T)); + T *end = p + info.shape[0] * step; + for (; p != end; p += step) + vec->push_back(*p); + return vec.release(); + })); + + return; +} + +template +enable_if_t...>::value> vector_buffer(Class_&) {} + +NAMESPACE_END(detail) + +// +// std::vector +// +template , typename... Args> +class_ bind_vector(handle scope, std::string const &name, Args&&... args) { + using Class_ = class_; + + // If the value_type is unregistered (e.g. a converting type) or is itself registered + // module-local then make the vector binding module-local as well: + using vtype = typename Vector::value_type; + auto vtype_info = detail::get_type_info(typeid(vtype)); + bool local = !vtype_info || vtype_info->module_local; + + Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); + + // Declare the buffer interface if a buffer_protocol() is passed in + detail::vector_buffer(cl); + + cl.def(init<>()); + + // Register copy constructor (if possible) + detail::vector_if_copy_constructible(cl); + + // Register comparison-related operators and functions (if possible) + detail::vector_if_equal_operator(cl); + + // Register stream insertion operator (if possible) + detail::vector_if_insertion_operator(cl, name); + + // Modifiers require copyable vector value type + detail::vector_modifiers(cl); + + // Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive + detail::vector_accessor(cl); + + cl.def("__bool__", + [](const Vector &v) -> bool { + return !v.empty(); + }, + "Check whether the list is nonempty" + ); + + cl.def("__len__", &Vector::size); + + + + +#if 0 + // C++ style functions deprecated, leaving it here as an example + cl.def(init()); + + cl.def("resize", + (void (Vector::*) (size_type count)) & Vector::resize, + "changes the number of elements stored"); + + cl.def("erase", + [](Vector &v, SizeType i) { + if (i >= v.size()) + throw index_error(); + v.erase(v.begin() + i); + }, "erases element at index ``i``"); + + cl.def("empty", &Vector::empty, "checks whether the container is empty"); + cl.def("size", &Vector::size, "returns the number of elements"); + cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end"); + cl.def("pop_back", &Vector::pop_back, "removes the last element"); + + cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements"); + cl.def("reserve", &Vector::reserve, "reserves storage"); + cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage"); + cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory"); + + cl.def("clear", &Vector::clear, "clears the contents"); + cl.def("swap", &Vector::swap, "swaps the contents"); + + cl.def("front", [](Vector &v) { + if (v.size()) return v.front(); + else throw index_error(); + }, "access the first element"); + + cl.def("back", [](Vector &v) { + if (v.size()) return v.back(); + else throw index_error(); + }, "access the last element "); + +#endif + + return cl; +} + + + +// +// std::map, std::unordered_map +// + +NAMESPACE_BEGIN(detail) + +/* Fallback functions */ +template void map_if_insertion_operator(const Args &...) { } +template void map_assignment(const Args &...) { } + +// Map assignment when copy-assignable: just copy the value +template +void map_assignment(enable_if_t::value, Class_> &cl) { + using KeyType = typename Map::key_type; + using MappedType = typename Map::mapped_type; + + cl.def("__setitem__", + [](Map &m, const KeyType &k, const MappedType &v) { + auto it = m.find(k); + if (it != m.end()) it->second = v; + else m.emplace(k, v); + } + ); +} + +// Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting +template +void map_assignment(enable_if_t< + !std::is_copy_assignable::value && + is_copy_constructible::value, + Class_> &cl) { + using KeyType = typename Map::key_type; + using MappedType = typename Map::mapped_type; + + cl.def("__setitem__", + [](Map &m, const KeyType &k, const MappedType &v) { + // We can't use m[k] = v; because value type might not be default constructable + auto r = m.emplace(k, v); + if (!r.second) { + // value type is not copy assignable so the only way to insert it is to erase it first... + m.erase(r.first); + m.emplace(k, v); + } + } + ); +} + + +template auto map_if_insertion_operator(Class_ &cl, std::string const &name) +-> decltype(std::declval() << std::declval() << std::declval(), void()) { + + cl.def("__repr__", + [name](Map &m) { + std::ostringstream s; + s << name << '{'; + bool f = false; + for (auto const &kv : m) { + if (f) + s << ", "; + s << kv.first << ": " << kv.second; + f = true; + } + s << '}'; + return s.str(); + }, + "Return the canonical string representation of this map." + ); +} + + +NAMESPACE_END(detail) + +template , typename... Args> +class_ bind_map(handle scope, const std::string &name, Args&&... args) { + using KeyType = typename Map::key_type; + using MappedType = typename Map::mapped_type; + using Class_ = class_; + + // If either type is a non-module-local bound type then make the map binding non-local as well; + // otherwise (e.g. both types are either module-local or converting) the map will be + // module-local. + auto tinfo = detail::get_type_info(typeid(MappedType)); + bool local = !tinfo || tinfo->module_local; + if (local) { + tinfo = detail::get_type_info(typeid(KeyType)); + local = !tinfo || tinfo->module_local; + } + + Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); + + cl.def(init<>()); + + // Register stream insertion operator (if possible) + detail::map_if_insertion_operator(cl, name); + + cl.def("__bool__", + [](const Map &m) -> bool { return !m.empty(); }, + "Check whether the map is nonempty" + ); + + cl.def("__iter__", + [](Map &m) { return make_key_iterator(m.begin(), m.end()); }, + keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); + + cl.def("items", + [](Map &m) { return make_iterator(m.begin(), m.end()); }, + keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); + + cl.def("__getitem__", + [](Map &m, const KeyType &k) -> MappedType & { + auto it = m.find(k); + if (it == m.end()) + throw key_error(); + return it->second; + }, + return_value_policy::reference_internal // ref + keepalive + ); + + // Assignment provided only if the type is copyable + detail::map_assignment(cl); + + cl.def("__delitem__", + [](Map &m, const KeyType &k) { + auto it = m.find(k); + if (it == m.end()) + throw key_error(); + m.erase(it); + } + ); + + cl.def("__len__", &Map::size); + + return cl; +} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/ml/dlib/dlib/external/pybind11/tools/FindCatch.cmake b/ml/dlib/dlib/external/pybind11/tools/FindCatch.cmake new file mode 100644 index 000000000..9d490c5aa --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/tools/FindCatch.cmake @@ -0,0 +1,57 @@ +# - Find the Catch test framework or download it (single header) +# +# This is a quick module for internal use. It assumes that Catch is +# REQUIRED and that a minimum version is provided (not EXACT). If +# a suitable version isn't found locally, the single header file +# will be downloaded and placed in the build dir: PROJECT_BINARY_DIR. +# +# This code sets the following variables: +# CATCH_INCLUDE_DIR - path to catch.hpp +# CATCH_VERSION - version number + +if(NOT Catch_FIND_VERSION) + message(FATAL_ERROR "A version number must be specified.") +elseif(Catch_FIND_REQUIRED) + message(FATAL_ERROR "This module assumes Catch is not required.") +elseif(Catch_FIND_VERSION_EXACT) + message(FATAL_ERROR "Exact version numbers are not supported, only minimum.") +endif() + +# Extract the version number from catch.hpp +function(_get_catch_version) + file(STRINGS "${CATCH_INCLUDE_DIR}/catch.hpp" version_line REGEX "Catch v.*" LIMIT_COUNT 1) + if(version_line MATCHES "Catch v([0-9]+)\\.([0-9]+)\\.([0-9]+)") + set(CATCH_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}" PARENT_SCOPE) + endif() +endfunction() + +# Download the single-header version of Catch +function(_download_catch version destination_dir) + message(STATUS "Downloading catch v${version}...") + set(url https://github.com/philsquared/Catch/releases/download/v${version}/catch.hpp) + file(DOWNLOAD ${url} "${destination_dir}/catch.hpp" STATUS status) + list(GET status 0 error) + if(error) + message(FATAL_ERROR "Could not download ${url}") + endif() + set(CATCH_INCLUDE_DIR "${destination_dir}" CACHE INTERNAL "") +endfunction() + +# Look for catch locally +find_path(CATCH_INCLUDE_DIR NAMES catch.hpp PATH_SUFFIXES catch) +if(CATCH_INCLUDE_DIR) + _get_catch_version() +endif() + +# Download the header if it wasn't found or if it's outdated +if(NOT CATCH_VERSION OR CATCH_VERSION VERSION_LESS ${Catch_FIND_VERSION}) + if(DOWNLOAD_CATCH) + _download_catch(${Catch_FIND_VERSION} "${PROJECT_BINARY_DIR}/catch/") + _get_catch_version() + else() + set(CATCH_FOUND FALSE) + return() + endif() +endif() + +set(CATCH_FOUND TRUE) diff --git a/ml/dlib/dlib/external/pybind11/tools/FindEigen3.cmake b/ml/dlib/dlib/external/pybind11/tools/FindEigen3.cmake new file mode 100644 index 000000000..9c546a05d --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/tools/FindEigen3.cmake @@ -0,0 +1,81 @@ +# - Try to find Eigen3 lib +# +# This module supports requiring a minimum version, e.g. you can do +# find_package(Eigen3 3.1.2) +# to require version 3.1.2 or newer of Eigen3. +# +# Once done this will define +# +# EIGEN3_FOUND - system has eigen lib with correct version +# EIGEN3_INCLUDE_DIR - the eigen include directory +# EIGEN3_VERSION - eigen version + +# Copyright (c) 2006, 2007 Montel Laurent, +# Copyright (c) 2008, 2009 Gael Guennebaud, +# Copyright (c) 2009 Benoit Jacob +# Redistribution and use is allowed according to the terms of the 2-clause BSD license. + +if(NOT Eigen3_FIND_VERSION) + if(NOT Eigen3_FIND_VERSION_MAJOR) + set(Eigen3_FIND_VERSION_MAJOR 2) + endif(NOT Eigen3_FIND_VERSION_MAJOR) + if(NOT Eigen3_FIND_VERSION_MINOR) + set(Eigen3_FIND_VERSION_MINOR 91) + endif(NOT Eigen3_FIND_VERSION_MINOR) + if(NOT Eigen3_FIND_VERSION_PATCH) + set(Eigen3_FIND_VERSION_PATCH 0) + endif(NOT Eigen3_FIND_VERSION_PATCH) + + set(Eigen3_FIND_VERSION "${Eigen3_FIND_VERSION_MAJOR}.${Eigen3_FIND_VERSION_MINOR}.${Eigen3_FIND_VERSION_PATCH}") +endif(NOT Eigen3_FIND_VERSION) + +macro(_eigen3_check_version) + file(READ "${EIGEN3_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h" _eigen3_version_header) + + string(REGEX MATCH "define[ \t]+EIGEN_WORLD_VERSION[ \t]+([0-9]+)" _eigen3_world_version_match "${_eigen3_version_header}") + set(EIGEN3_WORLD_VERSION "${CMAKE_MATCH_1}") + string(REGEX MATCH "define[ \t]+EIGEN_MAJOR_VERSION[ \t]+([0-9]+)" _eigen3_major_version_match "${_eigen3_version_header}") + set(EIGEN3_MAJOR_VERSION "${CMAKE_MATCH_1}") + string(REGEX MATCH "define[ \t]+EIGEN_MINOR_VERSION[ \t]+([0-9]+)" _eigen3_minor_version_match "${_eigen3_version_header}") + set(EIGEN3_MINOR_VERSION "${CMAKE_MATCH_1}") + + set(EIGEN3_VERSION ${EIGEN3_WORLD_VERSION}.${EIGEN3_MAJOR_VERSION}.${EIGEN3_MINOR_VERSION}) + if(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) + set(EIGEN3_VERSION_OK FALSE) + else(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) + set(EIGEN3_VERSION_OK TRUE) + endif(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) + + if(NOT EIGEN3_VERSION_OK) + + message(STATUS "Eigen3 version ${EIGEN3_VERSION} found in ${EIGEN3_INCLUDE_DIR}, " + "but at least version ${Eigen3_FIND_VERSION} is required") + endif(NOT EIGEN3_VERSION_OK) +endmacro(_eigen3_check_version) + +if (EIGEN3_INCLUDE_DIR) + + # in cache already + _eigen3_check_version() + set(EIGEN3_FOUND ${EIGEN3_VERSION_OK}) + +else (EIGEN3_INCLUDE_DIR) + + find_path(EIGEN3_INCLUDE_DIR NAMES signature_of_eigen3_matrix_library + PATHS + ${CMAKE_INSTALL_PREFIX}/include + ${KDE4_INCLUDE_DIR} + PATH_SUFFIXES eigen3 eigen + ) + + if(EIGEN3_INCLUDE_DIR) + _eigen3_check_version() + endif(EIGEN3_INCLUDE_DIR) + + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(Eigen3 DEFAULT_MSG EIGEN3_INCLUDE_DIR EIGEN3_VERSION_OK) + + mark_as_advanced(EIGEN3_INCLUDE_DIR) + +endif(EIGEN3_INCLUDE_DIR) + diff --git a/ml/dlib/dlib/external/pybind11/tools/FindPythonLibsNew.cmake b/ml/dlib/dlib/external/pybind11/tools/FindPythonLibsNew.cmake new file mode 100644 index 000000000..b29b287de --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/tools/FindPythonLibsNew.cmake @@ -0,0 +1,195 @@ +# - Find python libraries +# This module finds the libraries corresponding to the Python interpreter +# FindPythonInterp provides. +# This code sets the following variables: +# +# PYTHONLIBS_FOUND - have the Python libs been found +# PYTHON_PREFIX - path to the Python installation +# PYTHON_LIBRARIES - path to the python library +# PYTHON_INCLUDE_DIRS - path to where Python.h is found +# PYTHON_MODULE_EXTENSION - lib extension, e.g. '.so' or '.pyd' +# PYTHON_MODULE_PREFIX - lib name prefix: usually an empty string +# PYTHON_SITE_PACKAGES - path to installation site-packages +# PYTHON_IS_DEBUG - whether the Python interpreter is a debug build +# +# Thanks to talljimbo for the patch adding the 'LDVERSION' config +# variable usage. + +#============================================================================= +# Copyright 2001-2009 Kitware, Inc. +# Copyright 2012 Continuum Analytics, Inc. +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the names of Kitware, Inc., the Insight Software Consortium, +# nor the names of their contributors may be used to endorse or promote +# products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#============================================================================= + +# Checking for the extension makes sure that `LibsNew` was found and not just `Libs`. +if(PYTHONLIBS_FOUND AND PYTHON_MODULE_EXTENSION) + return() +endif() + +# Use the Python interpreter to find the libs. +if(PythonLibsNew_FIND_REQUIRED) + find_package(PythonInterp ${PythonLibsNew_FIND_VERSION} REQUIRED) +else() + find_package(PythonInterp ${PythonLibsNew_FIND_VERSION}) +endif() + +if(NOT PYTHONINTERP_FOUND) + set(PYTHONLIBS_FOUND FALSE) + return() +endif() + +# According to http://stackoverflow.com/questions/646518/python-how-to-detect-debug-interpreter +# testing whether sys has the gettotalrefcount function is a reliable, cross-platform +# way to detect a CPython debug interpreter. +# +# The library suffix is from the config var LDVERSION sometimes, otherwise +# VERSION. VERSION will typically be like "2.7" on unix, and "27" on windows. +execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" + "from distutils import sysconfig as s;import sys;import struct; +print('.'.join(str(v) for v in sys.version_info)); +print(sys.prefix); +print(s.get_python_inc(plat_specific=True)); +print(s.get_python_lib(plat_specific=True)); +print(s.get_config_var('SO')); +print(hasattr(sys, 'gettotalrefcount')+0); +print(struct.calcsize('@P')); +print(s.get_config_var('LDVERSION') or s.get_config_var('VERSION')); +print(s.get_config_var('LIBDIR') or ''); +print(s.get_config_var('MULTIARCH') or ''); +" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE _PYTHON_VALUES + ERROR_VARIABLE _PYTHON_ERROR_VALUE) + +if(NOT _PYTHON_SUCCESS MATCHES 0) + if(PythonLibsNew_FIND_REQUIRED) + message(FATAL_ERROR + "Python config failure:\n${_PYTHON_ERROR_VALUE}") + endif() + set(PYTHONLIBS_FOUND FALSE) + return() +endif() + +# Convert the process output into a list +string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES}) +string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES}) +list(GET _PYTHON_VALUES 0 _PYTHON_VERSION_LIST) +list(GET _PYTHON_VALUES 1 PYTHON_PREFIX) +list(GET _PYTHON_VALUES 2 PYTHON_INCLUDE_DIR) +list(GET _PYTHON_VALUES 3 PYTHON_SITE_PACKAGES) +list(GET _PYTHON_VALUES 4 PYTHON_MODULE_EXTENSION) +list(GET _PYTHON_VALUES 5 PYTHON_IS_DEBUG) +list(GET _PYTHON_VALUES 6 PYTHON_SIZEOF_VOID_P) +list(GET _PYTHON_VALUES 7 PYTHON_LIBRARY_SUFFIX) +list(GET _PYTHON_VALUES 8 PYTHON_LIBDIR) +list(GET _PYTHON_VALUES 9 PYTHON_MULTIARCH) + +# Make sure the Python has the same pointer-size as the chosen compiler +# Skip if CMAKE_SIZEOF_VOID_P is not defined +if(CMAKE_SIZEOF_VOID_P AND (NOT "${PYTHON_SIZEOF_VOID_P}" STREQUAL "${CMAKE_SIZEOF_VOID_P}")) + if(PythonLibsNew_FIND_REQUIRED) + math(EXPR _PYTHON_BITS "${PYTHON_SIZEOF_VOID_P} * 8") + math(EXPR _CMAKE_BITS "${CMAKE_SIZEOF_VOID_P} * 8") + message(FATAL_ERROR + "Python config failure: Python is ${_PYTHON_BITS}-bit, " + "chosen compiler is ${_CMAKE_BITS}-bit") + endif() + set(PYTHONLIBS_FOUND FALSE) + return() +endif() + +# The built-in FindPython didn't always give the version numbers +string(REGEX REPLACE "\\." ";" _PYTHON_VERSION_LIST ${_PYTHON_VERSION_LIST}) +list(GET _PYTHON_VERSION_LIST 0 PYTHON_VERSION_MAJOR) +list(GET _PYTHON_VERSION_LIST 1 PYTHON_VERSION_MINOR) +list(GET _PYTHON_VERSION_LIST 2 PYTHON_VERSION_PATCH) + +# Make sure all directory separators are '/' +string(REGEX REPLACE "\\\\" "/" PYTHON_PREFIX ${PYTHON_PREFIX}) +string(REGEX REPLACE "\\\\" "/" PYTHON_INCLUDE_DIR ${PYTHON_INCLUDE_DIR}) +string(REGEX REPLACE "\\\\" "/" PYTHON_SITE_PACKAGES ${PYTHON_SITE_PACKAGES}) + +if(CMAKE_HOST_WIN32) + set(PYTHON_LIBRARY + "${PYTHON_PREFIX}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib") + + # when run in a venv, PYTHON_PREFIX points to it. But the libraries remain in the + # original python installation. They may be found relative to PYTHON_INCLUDE_DIR. + if(NOT EXISTS "${PYTHON_LIBRARY}") + get_filename_component(_PYTHON_ROOT ${PYTHON_INCLUDE_DIR} DIRECTORY) + set(PYTHON_LIBRARY + "${_PYTHON_ROOT}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib") + endif() + + # raise an error if the python libs are still not found. + if(NOT EXISTS "${PYTHON_LIBRARY}") + message(FATAL_ERROR "Python libraries not found") + endif() + +else() + if(PYTHON_MULTIARCH) + set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}/${PYTHON_MULTIARCH}" "${PYTHON_LIBDIR}") + else() + set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}") + endif() + #message(STATUS "Searching for Python libs in ${_PYTHON_LIBS_SEARCH}") + # Probably this needs to be more involved. It would be nice if the config + # information the python interpreter itself gave us were more complete. + find_library(PYTHON_LIBRARY + NAMES "python${PYTHON_LIBRARY_SUFFIX}" + PATHS ${_PYTHON_LIBS_SEARCH} + NO_DEFAULT_PATH) + + # If all else fails, just set the name/version and let the linker figure out the path. + if(NOT PYTHON_LIBRARY) + set(PYTHON_LIBRARY python${PYTHON_LIBRARY_SUFFIX}) + endif() +endif() + +MARK_AS_ADVANCED( + PYTHON_LIBRARY + PYTHON_INCLUDE_DIR +) + +# We use PYTHON_INCLUDE_DIR, PYTHON_LIBRARY and PYTHON_DEBUG_LIBRARY for the +# cache entries because they are meant to specify the location of a single +# library. We now set the variables listed by the documentation for this +# module. +SET(PYTHON_INCLUDE_DIRS "${PYTHON_INCLUDE_DIR}") +SET(PYTHON_LIBRARIES "${PYTHON_LIBRARY}") +SET(PYTHON_DEBUG_LIBRARIES "${PYTHON_DEBUG_LIBRARY}") + +find_package_message(PYTHON + "Found PythonLibs: ${PYTHON_LIBRARY}" + "${PYTHON_EXECUTABLE}${PYTHON_VERSION}") + +set(PYTHONLIBS_FOUND TRUE) diff --git a/ml/dlib/dlib/external/pybind11/tools/check-style.sh b/ml/dlib/dlib/external/pybind11/tools/check-style.sh new file mode 100755 index 000000000..0a9f7d24f --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/tools/check-style.sh @@ -0,0 +1,70 @@ +#!/bin/bash +# +# Script to check include/test code for common pybind11 code style errors. +# +# This script currently checks for +# +# 1. use of tabs instead of spaces +# 2. MSDOS-style CRLF endings +# 3. trailing spaces +# 4. missing space between keyword and parenthesis, e.g.: for(, if(, while( +# 5. Missing space between right parenthesis and brace, e.g. 'for (...){' +# 6. opening brace on its own line. It should always be on the same line as the +# if/while/for/do statement. +# +# Invoke as: tools/check-style.sh +# + +check_style_errors=0 +IFS=$'\n' + +found="$( GREP_COLORS='mt=41' GREP_COLOR='41' grep $'\t' include tests/*.{cpp,py,h} docs/*.rst -rn --color=always )" +if [ -n "$found" ]; then + # The mt=41 sets a red background for matched tabs: + echo -e '\033[31;01mError: found tab characters in the following files:\033[0m' + check_style_errors=1 + echo "$found" | sed -e 's/^/ /' +fi + + +found="$( grep -IUlr $'\r' include tests/*.{cpp,py,h} docs/*.rst --color=always )" +if [ -n "$found" ]; then + echo -e '\033[31;01mError: found CRLF characters in the following files:\033[0m' + check_style_errors=1 + echo "$found" | sed -e 's/^/ /' +fi + +found="$(GREP_COLORS='mt=41' GREP_COLOR='41' grep '[[:blank:]]\+$' include tests/*.{cpp,py,h} docs/*.rst -rn --color=always )" +if [ -n "$found" ]; then + # The mt=41 sets a red background for matched trailing spaces + echo -e '\033[31;01mError: found trailing spaces in the following files:\033[0m' + check_style_errors=1 + echo "$found" | sed -e 's/^/ /' +fi + +found="$(grep '\<\(if\|for\|while\|catch\)(\|){' include tests/*.{cpp,h} -rn --color=always)" +if [ -n "$found" ]; then + echo -e '\033[31;01mError: found the following coding style problems:\033[0m' + check_style_errors=1 + echo "$found" | sed -e 's/^/ /' +fi + +found="$(awk ' +function prefix(filename, lineno) { + return " \033[35m" filename "\033[36m:\033[32m" lineno "\033[36m:\033[0m" +} +function mark(pattern, string) { sub(pattern, "\033[01;31m&\033[0m", string); return string } +last && /^\s*{/ { + print prefix(FILENAME, FNR-1) mark("\\)\\s*$", last) + print prefix(FILENAME, FNR) mark("^\\s*{", $0) + last="" +} +{ last = /(if|for|while|catch|switch)\s*\(.*\)\s*$/ ? $0 : "" } +' $(find include -type f) tests/*.{cpp,h} docs/*.rst)" +if [ -n "$found" ]; then + check_style_errors=1 + echo -e '\033[31;01mError: braces should occur on the same line as the if/while/.. statement. Found issues in the following files:\033[0m' + echo "$found" +fi + +exit $check_style_errors diff --git a/ml/dlib/dlib/external/pybind11/tools/libsize.py b/ml/dlib/dlib/external/pybind11/tools/libsize.py new file mode 100644 index 000000000..5dcb8b0d0 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/tools/libsize.py @@ -0,0 +1,38 @@ +from __future__ import print_function, division +import os +import sys + +# Internal build script for generating debugging test .so size. +# Usage: +# python libsize.py file.so save.txt -- displays the size of file.so and, if save.txt exists, compares it to the +# size in it, then overwrites save.txt with the new size for future runs. + +if len(sys.argv) != 3: + sys.exit("Invalid arguments: usage: python libsize.py file.so save.txt") + +lib = sys.argv[1] +save = sys.argv[2] + +if not os.path.exists(lib): + sys.exit("Error: requested file ({}) does not exist".format(lib)) + +libsize = os.path.getsize(lib) + +print("------", os.path.basename(lib), "file size:", libsize, end='') + +if os.path.exists(save): + with open(save) as sf: + oldsize = int(sf.readline()) + + if oldsize > 0: + change = libsize - oldsize + if change == 0: + print(" (no change)") + else: + print(" (change of {:+} bytes = {:+.2%})".format(change, change / oldsize)) +else: + print() + +with open(save, 'w') as sf: + sf.write(str(libsize)) + diff --git a/ml/dlib/dlib/external/pybind11/tools/mkdoc.py b/ml/dlib/dlib/external/pybind11/tools/mkdoc.py new file mode 100644 index 000000000..1fd8cceed --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/tools/mkdoc.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# +# Syntax: mkdoc.py [-I ..] [.. a list of header files ..] +# +# Extract documentation from C++ header files to use it in Python bindings +# + +import os +import sys +import platform +import re +import textwrap + +from clang import cindex +from clang.cindex import CursorKind +from collections import OrderedDict +from threading import Thread, Semaphore +from multiprocessing import cpu_count + +RECURSE_LIST = [ + CursorKind.TRANSLATION_UNIT, + CursorKind.NAMESPACE, + CursorKind.CLASS_DECL, + CursorKind.STRUCT_DECL, + CursorKind.ENUM_DECL, + CursorKind.CLASS_TEMPLATE +] + +PRINT_LIST = [ + CursorKind.CLASS_DECL, + CursorKind.STRUCT_DECL, + CursorKind.ENUM_DECL, + CursorKind.ENUM_CONSTANT_DECL, + CursorKind.CLASS_TEMPLATE, + CursorKind.FUNCTION_DECL, + CursorKind.FUNCTION_TEMPLATE, + CursorKind.CONVERSION_FUNCTION, + CursorKind.CXX_METHOD, + CursorKind.CONSTRUCTOR, + CursorKind.FIELD_DECL +] + +CPP_OPERATORS = { + '<=': 'le', '>=': 'ge', '==': 'eq', '!=': 'ne', '[]': 'array', + '+=': 'iadd', '-=': 'isub', '*=': 'imul', '/=': 'idiv', '%=': + 'imod', '&=': 'iand', '|=': 'ior', '^=': 'ixor', '<<=': 'ilshift', + '>>=': 'irshift', '++': 'inc', '--': 'dec', '<<': 'lshift', '>>': + 'rshift', '&&': 'land', '||': 'lor', '!': 'lnot', '~': 'bnot', + '&': 'band', '|': 'bor', '+': 'add', '-': 'sub', '*': 'mul', '/': + 'div', '%': 'mod', '<': 'lt', '>': 'gt', '=': 'assign', '()': 'call' +} + +CPP_OPERATORS = OrderedDict( + sorted(CPP_OPERATORS.items(), key=lambda t: -len(t[0]))) + +job_count = cpu_count() +job_semaphore = Semaphore(job_count) + +output = [] + +def d(s): + return s.decode('utf8') + + +def sanitize_name(name): + name = re.sub(r'type-parameter-0-([0-9]+)', r'T\1', name) + for k, v in CPP_OPERATORS.items(): + name = name.replace('operator%s' % k, 'operator_%s' % v) + name = re.sub('<.*>', '', name) + name = ''.join([ch if ch.isalnum() else '_' for ch in name]) + name = re.sub('_$', '', re.sub('_+', '_', name)) + return '__doc_' + name + + +def process_comment(comment): + result = '' + + # Remove C++ comment syntax + leading_spaces = float('inf') + for s in comment.expandtabs(tabsize=4).splitlines(): + s = s.strip() + if s.startswith('/*'): + s = s[2:].lstrip('*') + elif s.endswith('*/'): + s = s[:-2].rstrip('*') + elif s.startswith('///'): + s = s[3:] + if s.startswith('*'): + s = s[1:] + if len(s) > 0: + leading_spaces = min(leading_spaces, len(s) - len(s.lstrip())) + result += s + '\n' + + if leading_spaces != float('inf'): + result2 = "" + for s in result.splitlines(): + result2 += s[leading_spaces:] + '\n' + result = result2 + + # Doxygen tags + cpp_group = '([\w:]+)' + param_group = '([\[\w:\]]+)' + + s = result + s = re.sub(r'\\c\s+%s' % cpp_group, r'``\1``', s) + s = re.sub(r'\\a\s+%s' % cpp_group, r'*\1*', s) + s = re.sub(r'\\e\s+%s' % cpp_group, r'*\1*', s) + s = re.sub(r'\\em\s+%s' % cpp_group, r'*\1*', s) + s = re.sub(r'\\b\s+%s' % cpp_group, r'**\1**', s) + s = re.sub(r'\\ingroup\s+%s' % cpp_group, r'', s) + s = re.sub(r'\\param%s?\s+%s' % (param_group, cpp_group), + r'\n\n$Parameter ``\2``:\n\n', s) + s = re.sub(r'\\tparam%s?\s+%s' % (param_group, cpp_group), + r'\n\n$Template parameter ``\2``:\n\n', s) + + for in_, out_ in { + 'return': 'Returns', + 'author': 'Author', + 'authors': 'Authors', + 'copyright': 'Copyright', + 'date': 'Date', + 'remark': 'Remark', + 'sa': 'See also', + 'see': 'See also', + 'extends': 'Extends', + 'throw': 'Throws', + 'throws': 'Throws' + }.items(): + s = re.sub(r'\\%s\s*' % in_, r'\n\n$%s:\n\n' % out_, s) + + s = re.sub(r'\\details\s*', r'\n\n', s) + s = re.sub(r'\\brief\s*', r'', s) + s = re.sub(r'\\short\s*', r'', s) + s = re.sub(r'\\ref\s*', r'', s) + + s = re.sub(r'\\code\s?(.*?)\s?\\endcode', + r"```\n\1\n```\n", s, flags=re.DOTALL) + + # HTML/TeX tags + s = re.sub(r'(.*?)', r'``\1``', s, flags=re.DOTALL) + s = re.sub(r'
(.*?)
', r"```\n\1\n```\n", s, flags=re.DOTALL) + s = re.sub(r'(.*?)', r'*\1*', s, flags=re.DOTALL) + s = re.sub(r'(.*?)', r'**\1**', s, flags=re.DOTALL) + s = re.sub(r'\\f\$(.*?)\\f\$', r'$\1$', s, flags=re.DOTALL) + s = re.sub(r'
  • ', r'\n\n* ', s) + s = re.sub(r'', r'', s) + s = re.sub(r'
  • ', r'\n\n', s) + + s = s.replace('``true``', '``True``') + s = s.replace('``false``', '``False``') + + # Re-flow text + wrapper = textwrap.TextWrapper() + wrapper.expand_tabs = True + wrapper.replace_whitespace = True + wrapper.drop_whitespace = True + wrapper.width = 70 + wrapper.initial_indent = wrapper.subsequent_indent = '' + + result = '' + in_code_segment = False + for x in re.split(r'(```)', s): + if x == '```': + if not in_code_segment: + result += '```\n' + else: + result += '\n```\n\n' + in_code_segment = not in_code_segment + elif in_code_segment: + result += x.strip() + else: + for y in re.split(r'(?: *\n *){2,}', x): + wrapped = wrapper.fill(re.sub(r'\s+', ' ', y).strip()) + if len(wrapped) > 0 and wrapped[0] == '$': + result += wrapped[1:] + '\n' + wrapper.initial_indent = \ + wrapper.subsequent_indent = ' ' * 4 + else: + if len(wrapped) > 0: + result += wrapped + '\n\n' + wrapper.initial_indent = wrapper.subsequent_indent = '' + return result.rstrip().lstrip('\n') + + +def extract(filename, node, prefix): + if not (node.location.file is None or + os.path.samefile(d(node.location.file.name), filename)): + return 0 + if node.kind in RECURSE_LIST: + sub_prefix = prefix + if node.kind != CursorKind.TRANSLATION_UNIT: + if len(sub_prefix) > 0: + sub_prefix += '_' + sub_prefix += d(node.spelling) + for i in node.get_children(): + extract(filename, i, sub_prefix) + if node.kind in PRINT_LIST: + comment = d(node.raw_comment) if node.raw_comment is not None else '' + comment = process_comment(comment) + sub_prefix = prefix + if len(sub_prefix) > 0: + sub_prefix += '_' + if len(node.spelling) > 0: + name = sanitize_name(sub_prefix + d(node.spelling)) + global output + output.append((name, filename, comment)) + + +class ExtractionThread(Thread): + def __init__(self, filename, parameters): + Thread.__init__(self) + self.filename = filename + self.parameters = parameters + job_semaphore.acquire() + + def run(self): + print('Processing "%s" ..' % self.filename, file=sys.stderr) + try: + index = cindex.Index( + cindex.conf.lib.clang_createIndex(False, True)) + tu = index.parse(self.filename, self.parameters) + extract(self.filename, tu.cursor, '') + finally: + job_semaphore.release() + +if __name__ == '__main__': + parameters = ['-x', 'c++', '-std=c++11'] + filenames = [] + + if platform.system() == 'Darwin': + dev_path = '/Applications/Xcode.app/Contents/Developer/' + lib_dir = dev_path + 'Toolchains/XcodeDefault.xctoolchain/usr/lib/' + sdk_dir = dev_path + 'Platforms/MacOSX.platform/Developer/SDKs' + libclang = lib_dir + 'libclang.dylib' + + if os.path.exists(libclang): + cindex.Config.set_library_path(os.path.dirname(libclang)) + + if os.path.exists(sdk_dir): + sysroot_dir = os.path.join(sdk_dir, next(os.walk(sdk_dir))[1][0]) + parameters.append('-isysroot') + parameters.append(sysroot_dir) + + for item in sys.argv[1:]: + if item.startswith('-'): + parameters.append(item) + else: + filenames.append(item) + + if len(filenames) == 0: + print('Syntax: %s [.. a list of header files ..]' % sys.argv[0]) + exit(-1) + + print('''/* + This file contains docstrings for the Python bindings. + Do not edit! These were automatically extracted by mkdoc.py + */ + +#define __EXPAND(x) x +#define __COUNT(_1, _2, _3, _4, _5, _6, _7, COUNT, ...) COUNT +#define __VA_SIZE(...) __EXPAND(__COUNT(__VA_ARGS__, 7, 6, 5, 4, 3, 2, 1)) +#define __CAT1(a, b) a ## b +#define __CAT2(a, b) __CAT1(a, b) +#define __DOC1(n1) __doc_##n1 +#define __DOC2(n1, n2) __doc_##n1##_##n2 +#define __DOC3(n1, n2, n3) __doc_##n1##_##n2##_##n3 +#define __DOC4(n1, n2, n3, n4) __doc_##n1##_##n2##_##n3##_##n4 +#define __DOC5(n1, n2, n3, n4, n5) __doc_##n1##_##n2##_##n3##_##n4##_##n5 +#define __DOC6(n1, n2, n3, n4, n5, n6) __doc_##n1##_##n2##_##n3##_##n4##_##n5##_##n6 +#define __DOC7(n1, n2, n3, n4, n5, n6, n7) __doc_##n1##_##n2##_##n3##_##n4##_##n5##_##n6##_##n7 +#define DOC(...) __EXPAND(__EXPAND(__CAT2(__DOC, __VA_SIZE(__VA_ARGS__)))(__VA_ARGS__)) + +#if defined(__GNUG__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#endif +''') + + output.clear() + for filename in filenames: + thr = ExtractionThread(filename, parameters) + thr.start() + + print('Waiting for jobs to finish ..', file=sys.stderr) + for i in range(job_count): + job_semaphore.acquire() + + name_ctr = 1 + name_prev = None + for name, _, comment in list(sorted(output, key=lambda x: (x[0], x[1]))): + if name == name_prev: + name_ctr += 1 + name = name + "_%i" % name_ctr + else: + name_prev = name + name_ctr = 1 + print('\nstatic const char *%s =%sR"doc(%s)doc";' % + (name, '\n' if '\n' in comment else ' ', comment)) + + print(''' +#if defined(__GNUG__) +#pragma GCC diagnostic pop +#endif +''') diff --git a/ml/dlib/dlib/external/pybind11/tools/pybind11Config.cmake.in b/ml/dlib/dlib/external/pybind11/tools/pybind11Config.cmake.in new file mode 100644 index 000000000..3dd1b2c1a --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/tools/pybind11Config.cmake.in @@ -0,0 +1,100 @@ +# pybind11Config.cmake +# -------------------- +# +# PYBIND11 cmake module. +# This module sets the following variables in your project:: +# +# pybind11_FOUND - true if pybind11 and all required components found on the system +# pybind11_VERSION - pybind11 version in format Major.Minor.Release +# pybind11_INCLUDE_DIRS - Directories where pybind11 and python headers are located. +# pybind11_INCLUDE_DIR - Directory where pybind11 headers are located. +# pybind11_DEFINITIONS - Definitions necessary to use pybind11, namely USING_pybind11. +# pybind11_LIBRARIES - compile flags and python libraries (as needed) to link against. +# pybind11_LIBRARY - empty. +# CMAKE_MODULE_PATH - appends location of accompanying FindPythonLibsNew.cmake and +# pybind11Tools.cmake modules. +# +# +# Available components: None +# +# +# Exported targets:: +# +# If pybind11 is found, this module defines the following :prop_tgt:`IMPORTED` +# interface library targets:: +# +# pybind11::module - for extension modules +# pybind11::embed - for embedding the Python interpreter +# +# Python headers, libraries (as needed by platform), and the C++ standard +# are attached to the target. Set PythonLibsNew variables to influence +# python detection and PYBIND11_CPP_STANDARD (-std=c++11 or -std=c++14) to +# influence standard setting. :: +# +# find_package(pybind11 CONFIG REQUIRED) +# message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIRS}") +# +# # Create an extension module +# add_library(mylib MODULE main.cpp) +# target_link_libraries(mylib pybind11::module) +# +# # Or embed the Python interpreter into an executable +# add_executable(myexe main.cpp) +# target_link_libraries(myexe pybind11::embed) +# +# Suggested usage:: +# +# find_package with version info is not recommended except for release versions. :: +# +# find_package(pybind11 CONFIG) +# find_package(pybind11 2.0 EXACT CONFIG REQUIRED) +# +# +# The following variables can be set to guide the search for this package:: +# +# pybind11_DIR - CMake variable, set to directory containing this Config file +# CMAKE_PREFIX_PATH - CMake variable, set to root directory of this package +# PATH - environment variable, set to bin directory of this package +# CMAKE_DISABLE_FIND_PACKAGE_pybind11 - CMake variable, disables +# find_package(pybind11) when not REQUIRED, perhaps to force internal build + +@PACKAGE_INIT@ + +set(PN pybind11) + +# location of pybind11/pybind11.h +set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/@CMAKE_INSTALL_INCLUDEDIR@") + +set(${PN}_LIBRARY "") +set(${PN}_DEFINITIONS USING_${PN}) + +check_required_components(${PN}) + +# make detectable the FindPythonLibsNew.cmake module +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}) + +include(pybind11Tools) + +if(NOT (CMAKE_VERSION VERSION_LESS 3.0)) +#----------------------------------------------------------------------------- +# Don't include targets if this file is being picked up by another +# project which has already built this as a subproject +#----------------------------------------------------------------------------- +if(NOT TARGET ${PN}::pybind11) + include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake") + + find_package(PythonLibsNew ${PYBIND11_PYTHON_VERSION} MODULE REQUIRED) + set_property(TARGET ${PN}::pybind11 APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${PYTHON_INCLUDE_DIRS}) + set_property(TARGET ${PN}::embed APPEND PROPERTY INTERFACE_LINK_LIBRARIES ${PYTHON_LIBRARIES}) + if(WIN32 OR CYGWIN) + set_property(TARGET ${PN}::module APPEND PROPERTY INTERFACE_LINK_LIBRARIES ${PYTHON_LIBRARIES}) + endif() + + set_property(TARGET ${PN}::pybind11 APPEND PROPERTY INTERFACE_COMPILE_OPTIONS "${PYBIND11_CPP_STANDARD}") + + get_property(_iid TARGET ${PN}::pybind11 PROPERTY INTERFACE_INCLUDE_DIRECTORIES) + get_property(_ill TARGET ${PN}::module PROPERTY INTERFACE_LINK_LIBRARIES) + set(${PN}_INCLUDE_DIRS ${_iid}) + set(${PN}_LIBRARIES ${_ico} ${_ill}) +endif() +endif() diff --git a/ml/dlib/dlib/external/pybind11/tools/pybind11Tools.cmake b/ml/dlib/dlib/external/pybind11/tools/pybind11Tools.cmake new file mode 100644 index 000000000..a7c471a07 --- /dev/null +++ b/ml/dlib/dlib/external/pybind11/tools/pybind11Tools.cmake @@ -0,0 +1,202 @@ +# tools/pybind11Tools.cmake -- Build system for the pybind11 modules +# +# Copyright (c) 2015 Wenzel Jakob +# +# All rights reserved. Use of this source code is governed by a +# BSD-style license that can be found in the LICENSE file. + +cmake_minimum_required(VERSION 2.8.12) + +# Add a CMake parameter for choosing a desired Python version +if(NOT PYBIND11_PYTHON_VERSION) + set(PYBIND11_PYTHON_VERSION "" CACHE STRING "Python version to use for compiling modules") +endif() + +set(Python_ADDITIONAL_VERSIONS 3.7 3.6 3.5 3.4) +find_package(PythonLibsNew ${PYBIND11_PYTHON_VERSION} REQUIRED) + +include(CheckCXXCompilerFlag) +include(CMakeParseArguments) + +if(NOT PYBIND11_CPP_STANDARD AND NOT CMAKE_CXX_STANDARD) + if(NOT MSVC) + check_cxx_compiler_flag("-std=c++14" HAS_CPP14_FLAG) + + if (HAS_CPP14_FLAG) + set(PYBIND11_CPP_STANDARD -std=c++14) + else() + check_cxx_compiler_flag("-std=c++11" HAS_CPP11_FLAG) + if (HAS_CPP11_FLAG) + set(PYBIND11_CPP_STANDARD -std=c++11) + else() + message(FATAL_ERROR "Unsupported compiler -- pybind11 requires C++11 support!") + endif() + endif() + elseif(MSVC) + set(PYBIND11_CPP_STANDARD /std:c++14) + endif() + + set(PYBIND11_CPP_STANDARD ${PYBIND11_CPP_STANDARD} CACHE STRING + "C++ standard flag, e.g. -std=c++11, -std=c++14, /std:c++14. Defaults to C++14 mode." FORCE) +endif() + +# Checks whether the given CXX/linker flags can compile and link a cxx file. cxxflags and +# linkerflags are lists of flags to use. The result variable is a unique variable name for each set +# of flags: the compilation result will be cached base on the result variable. If the flags work, +# sets them in cxxflags_out/linkerflags_out internal cache variables (in addition to ${result}). +function(_pybind11_return_if_cxx_and_linker_flags_work result cxxflags linkerflags cxxflags_out linkerflags_out) + set(CMAKE_REQUIRED_LIBRARIES ${linkerflags}) + check_cxx_compiler_flag("${cxxflags}" ${result}) + if (${result}) + set(${cxxflags_out} "${cxxflags}" CACHE INTERNAL "" FORCE) + set(${linkerflags_out} "${linkerflags}" CACHE INTERNAL "" FORCE) + endif() +endfunction() + +# Internal: find the appropriate link time optimization flags for this compiler +function(_pybind11_add_lto_flags target_name prefer_thin_lto) + if (NOT DEFINED PYBIND11_LTO_CXX_FLAGS) + set(PYBIND11_LTO_CXX_FLAGS "" CACHE INTERNAL "") + set(PYBIND11_LTO_LINKER_FLAGS "" CACHE INTERNAL "") + + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + set(cxx_append "") + set(linker_append "") + if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND NOT APPLE) + # Clang Gold plugin does not support -Os; append -O3 to MinSizeRel builds to override it + set(linker_append ";$<$:-O3>") + elseif(CMAKE_CXX_COMPILER_ID MATCHES "GNU") + set(cxx_append ";-fno-fat-lto-objects") + endif() + + if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND prefer_thin_lto) + _pybind11_return_if_cxx_and_linker_flags_work(HAS_FLTO_THIN + "-flto=thin${cxx_append}" "-flto=thin${linker_append}" + PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS) + endif() + + if (NOT HAS_FLTO_THIN) + _pybind11_return_if_cxx_and_linker_flags_work(HAS_FLTO + "-flto${cxx_append}" "-flto${linker_append}" + PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS) + endif() + elseif (CMAKE_CXX_COMPILER_ID MATCHES "Intel") + # Intel equivalent to LTO is called IPO + _pybind11_return_if_cxx_and_linker_flags_work(HAS_INTEL_IPO + "-ipo" "-ipo" PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS) + elseif(MSVC) + # cmake only interprets libraries as linker flags when they start with a - (otherwise it + # converts /LTCG to \LTCG as if it was a Windows path). Luckily MSVC supports passing flags + # with - instead of /, even if it is a bit non-standard: + _pybind11_return_if_cxx_and_linker_flags_work(HAS_MSVC_GL_LTCG + "/GL" "-LTCG" PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS) + endif() + + if (PYBIND11_LTO_CXX_FLAGS) + message(STATUS "LTO enabled") + else() + message(STATUS "LTO disabled (not supported by the compiler and/or linker)") + endif() + endif() + + # Enable LTO flags if found, except for Debug builds + if (PYBIND11_LTO_CXX_FLAGS) + target_compile_options(${target_name} PRIVATE "$<$>:${PYBIND11_LTO_CXX_FLAGS}>") + endif() + if (PYBIND11_LTO_LINKER_FLAGS) + target_link_libraries(${target_name} PRIVATE "$<$>:${PYBIND11_LTO_LINKER_FLAGS}>") + endif() +endfunction() + +# Build a Python extension module: +# pybind11_add_module( [MODULE | SHARED] [EXCLUDE_FROM_ALL] +# [NO_EXTRAS] [THIN_LTO] source1 [source2 ...]) +# +function(pybind11_add_module target_name) + set(options MODULE SHARED EXCLUDE_FROM_ALL NO_EXTRAS THIN_LTO) + cmake_parse_arguments(ARG "${options}" "" "" ${ARGN}) + + if(ARG_MODULE AND ARG_SHARED) + message(FATAL_ERROR "Can't be both MODULE and SHARED") + elseif(ARG_SHARED) + set(lib_type SHARED) + else() + set(lib_type MODULE) + endif() + + if(ARG_EXCLUDE_FROM_ALL) + set(exclude_from_all EXCLUDE_FROM_ALL) + endif() + + add_library(${target_name} ${lib_type} ${exclude_from_all} ${ARG_UNPARSED_ARGUMENTS}) + + target_include_directories(${target_name} + PRIVATE ${PYBIND11_INCLUDE_DIR} # from project CMakeLists.txt + PRIVATE ${pybind11_INCLUDE_DIR} # from pybind11Config + PRIVATE ${PYTHON_INCLUDE_DIRS}) + + # The prefix and extension are provided by FindPythonLibsNew.cmake + set_target_properties(${target_name} PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}") + set_target_properties(${target_name} PROPERTIES SUFFIX "${PYTHON_MODULE_EXTENSION}") + + # -fvisibility=hidden is required to allow multiple modules compiled against + # different pybind versions to work properly, and for some features (e.g. + # py::module_local). We force it on everything inside the `pybind11` + # namespace; also turning it on for a pybind module compilation here avoids + # potential warnings or issues from having mixed hidden/non-hidden types. + set_target_properties(${target_name} PROPERTIES CXX_VISIBILITY_PRESET "hidden") + + if(WIN32 OR CYGWIN) + # Link against the Python shared library on Windows + target_link_libraries(${target_name} PRIVATE ${PYTHON_LIBRARIES}) + elseif(APPLE) + # It's quite common to have multiple copies of the same Python version + # installed on one's system. E.g.: one copy from the OS and another copy + # that's statically linked into an application like Blender or Maya. + # If we link our plugin library against the OS Python here and import it + # into Blender or Maya later on, this will cause segfaults when multiple + # conflicting Python instances are active at the same time (even when they + # are of the same version). + + # Windows is not affected by this issue since it handles DLL imports + # differently. The solution for Linux and Mac OS is simple: we just don't + # link against the Python library. The resulting shared library will have + # missing symbols, but that's perfectly fine -- they will be resolved at + # import time. + + target_link_libraries(${target_name} PRIVATE "-undefined dynamic_lookup") + + if(ARG_SHARED) + # Suppress CMake >= 3.0 warning for shared libraries + set_target_properties(${target_name} PROPERTIES MACOSX_RPATH ON) + endif() + endif() + + # Make sure C++11/14 are enabled + target_compile_options(${target_name} PUBLIC ${PYBIND11_CPP_STANDARD}) + + if(ARG_NO_EXTRAS) + return() + endif() + + _pybind11_add_lto_flags(${target_name} ${ARG_THIN_LTO}) + + if (NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug) + # Strip unnecessary sections of the binary on Linux/Mac OS + if(CMAKE_STRIP) + if(APPLE) + add_custom_command(TARGET ${target_name} POST_BUILD + COMMAND ${CMAKE_STRIP} -x $) + else() + add_custom_command(TARGET ${target_name} POST_BUILD + COMMAND ${CMAKE_STRIP} $) + endif() + endif() + endif() + + if(MSVC) + # /MP enables multithreaded builds (relevant when there are many files), /bigobj is + # needed for bigger binding projects due to the limit to 64k addressable sections + target_compile_options(${target_name} PRIVATE /MP /bigobj) + endif() +endfunction() diff --git a/ml/dlib/dlib/external/zlib/README b/ml/dlib/dlib/external/zlib/README new file mode 100644 index 000000000..5ca9d127e --- /dev/null +++ b/ml/dlib/dlib/external/zlib/README @@ -0,0 +1,115 @@ +ZLIB DATA COMPRESSION LIBRARY + +zlib 1.2.8 is a general purpose data compression library. All the code is +thread safe. The data format used by the zlib library is described by RFCs +(Request for Comments) 1950 to 1952 in the files +http://tools.ietf.org/html/rfc1950 (zlib format), rfc1951 (deflate format) and +rfc1952 (gzip format). + +All functions of the compression library are documented in the file zlib.h +(volunteer to write man pages welcome, contact zlib@gzip.org). A usage example +of the library is given in the file test/example.c which also tests that +the library is working correctly. Another example is given in the file +test/minigzip.c. The compression library itself is composed of all source +files in the root directory. + +To compile all files and run the test program, follow the instructions given at +the top of Makefile.in. In short "./configure; make test", and if that goes +well, "make install" should work for most flavors of Unix. For Windows, use +one of the special makefiles in win32/ or contrib/vstudio/ . For VMS, use +make_vms.com. + +Questions about zlib should be sent to , or to Gilles Vollant + for the Windows DLL version. The zlib home page is +http://zlib.net/ . Before reporting a problem, please check this site to +verify that you have the latest version of zlib; otherwise get the latest +version and check whether the problem still exists or not. + +PLEASE read the zlib FAQ http://zlib.net/zlib_faq.html before asking for help. + +Mark Nelson wrote an article about zlib for the Jan. 1997 +issue of Dr. Dobb's Journal; a copy of the article is available at +http://marknelson.us/1997/01/01/zlib-engine/ . + +The changes made in version 1.2.8 are documented in the file ChangeLog. + +Unsupported third party contributions are provided in directory contrib/ . + +zlib is available in Java using the java.util.zip package, documented at +http://java.sun.com/developer/technicalArticles/Programming/compression/ . + +A Perl interface to zlib written by Paul Marquess is available +at CPAN (Comprehensive Perl Archive Network) sites, including +http://search.cpan.org/~pmqs/IO-Compress-Zlib/ . + +A Python interface to zlib written by A.M. Kuchling is +available in Python 1.5 and later versions, see +http://docs.python.org/library/zlib.html . + +zlib is built into tcl: http://wiki.tcl.tk/4610 . + +An experimental package to read and write files in .zip format, written on top +of zlib by Gilles Vollant , is available in the +contrib/minizip directory of zlib. + + +Notes for some targets: + +- For Windows DLL versions, please see win32/DLL_FAQ.txt + +- For 64-bit Irix, deflate.c must be compiled without any optimization. With + -O, one libpng test fails. The test works in 32 bit mode (with the -n32 + compiler flag). The compiler bug has been reported to SGI. + +- zlib doesn't work with gcc 2.6.3 on a DEC 3000/300LX under OSF/1 2.1 it works + when compiled with cc. + +- On Digital Unix 4.0D (formely OSF/1) on AlphaServer, the cc option -std1 is + necessary to get gzprintf working correctly. This is done by configure. + +- zlib doesn't work on HP-UX 9.05 with some versions of /bin/cc. It works with + other compilers. Use "make test" to check your compiler. + +- gzdopen is not supported on RISCOS or BEOS. + +- For PalmOs, see http://palmzlib.sourceforge.net/ + + +Acknowledgments: + + The deflate format used by zlib was defined by Phil Katz. The deflate and + zlib specifications were written by L. Peter Deutsch. Thanks to all the + people who reported problems and suggested various improvements in zlib; they + are too numerous to cite here. + +Copyright notice: + + (C) 1995-2013 Jean-loup Gailly and Mark Adler + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + Jean-loup Gailly Mark Adler + jloup@gzip.org madler@alumni.caltech.edu + +If you use the zlib library in a product, we would appreciate *not* receiving +lengthy legal documents to sign. The sources are provided for free but without +warranty of any kind. The library has been entirely written by Jean-loup +Gailly and Mark Adler; it does not include third-party code. + +If you redistribute modified sources, we would appreciate that you include in +the file ChangeLog history information documenting your changes. Please read +the FAQ for more information on the distribution of modified source versions. diff --git a/ml/dlib/dlib/external/zlib/adler32.c b/ml/dlib/dlib/external/zlib/adler32.c new file mode 100644 index 000000000..a868f073d --- /dev/null +++ b/ml/dlib/dlib/external/zlib/adler32.c @@ -0,0 +1,179 @@ +/* adler32.c -- compute the Adler-32 checksum of a data stream + * Copyright (C) 1995-2011 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* @(#) $Id$ */ + +#include "zutil.h" + +#define local static + +local uLong adler32_combine_ OF((uLong adler1, uLong adler2, z_off64_t len2)); + +#define BASE 65521 /* largest prime smaller than 65536 */ +#define NMAX 5552 +/* NMAX is the largest n such that 255n(n+1)/2 + (n+1)(BASE-1) <= 2^32-1 */ + +#define DO1(buf,i) {adler += (buf)[i]; sum2 += adler;} +#define DO2(buf,i) DO1(buf,i); DO1(buf,i+1); +#define DO4(buf,i) DO2(buf,i); DO2(buf,i+2); +#define DO8(buf,i) DO4(buf,i); DO4(buf,i+4); +#define DO16(buf) DO8(buf,0); DO8(buf,8); + +/* use NO_DIVIDE if your processor does not do division in hardware -- + try it both ways to see which is faster */ +#ifdef NO_DIVIDE +/* note that this assumes BASE is 65521, where 65536 % 65521 == 15 + (thank you to John Reiser for pointing this out) */ +# define CHOP(a) \ + do { \ + unsigned long tmp = a >> 16; \ + a &= 0xffffUL; \ + a += (tmp << 4) - tmp; \ + } while (0) +# define MOD28(a) \ + do { \ + CHOP(a); \ + if (a >= BASE) a -= BASE; \ + } while (0) +# define MOD(a) \ + do { \ + CHOP(a); \ + MOD28(a); \ + } while (0) +# define MOD63(a) \ + do { /* this assumes a is not negative */ \ + z_off64_t tmp = a >> 32; \ + a &= 0xffffffffL; \ + a += (tmp << 8) - (tmp << 5) + tmp; \ + tmp = a >> 16; \ + a &= 0xffffL; \ + a += (tmp << 4) - tmp; \ + tmp = a >> 16; \ + a &= 0xffffL; \ + a += (tmp << 4) - tmp; \ + if (a >= BASE) a -= BASE; \ + } while (0) +#else +# define MOD(a) a %= BASE +# define MOD28(a) a %= BASE +# define MOD63(a) a %= BASE +#endif + +/* ========================================================================= */ +uLong ZEXPORT adler32(adler, buf, len) + uLong adler; + const Bytef *buf; + uInt len; +{ + unsigned long sum2; + unsigned n; + + /* split Adler-32 into component sums */ + sum2 = (adler >> 16) & 0xffff; + adler &= 0xffff; + + /* in case user likes doing a byte at a time, keep it fast */ + if (len == 1) { + adler += buf[0]; + if (adler >= BASE) + adler -= BASE; + sum2 += adler; + if (sum2 >= BASE) + sum2 -= BASE; + return adler | (sum2 << 16); + } + + /* initial Adler-32 value (deferred check for len == 1 speed) */ + if (buf == Z_NULL) + return 1L; + + /* in case short lengths are provided, keep it somewhat fast */ + if (len < 16) { + while (len--) { + adler += *buf++; + sum2 += adler; + } + if (adler >= BASE) + adler -= BASE; + MOD28(sum2); /* only added so many BASE's */ + return adler | (sum2 << 16); + } + + /* do length NMAX blocks -- requires just one modulo operation */ + while (len >= NMAX) { + len -= NMAX; + n = NMAX / 16; /* NMAX is divisible by 16 */ + do { + DO16(buf); /* 16 sums unrolled */ + buf += 16; + } while (--n); + MOD(adler); + MOD(sum2); + } + + /* do remaining bytes (less than NMAX, still just one modulo) */ + if (len) { /* avoid modulos if none remaining */ + while (len >= 16) { + len -= 16; + DO16(buf); + buf += 16; + } + while (len--) { + adler += *buf++; + sum2 += adler; + } + MOD(adler); + MOD(sum2); + } + + /* return recombined sums */ + return adler | (sum2 << 16); +} + +/* ========================================================================= */ +local uLong adler32_combine_(adler1, adler2, len2) + uLong adler1; + uLong adler2; + z_off64_t len2; +{ + unsigned long sum1; + unsigned long sum2; + unsigned rem; + + /* for negative len, return invalid adler32 as a clue for debugging */ + if (len2 < 0) + return 0xffffffffUL; + + /* the derivation of this formula is left as an exercise for the reader */ + MOD63(len2); /* assumes len2 >= 0 */ + rem = (unsigned)len2; + sum1 = adler1 & 0xffff; + sum2 = rem * sum1; + MOD(sum2); + sum1 += (adler2 & 0xffff) + BASE - 1; + sum2 += ((adler1 >> 16) & 0xffff) + ((adler2 >> 16) & 0xffff) + BASE - rem; + if (sum1 >= BASE) sum1 -= BASE; + if (sum1 >= BASE) sum1 -= BASE; + if (sum2 >= (BASE << 1)) sum2 -= (BASE << 1); + if (sum2 >= BASE) sum2 -= BASE; + return sum1 | (sum2 << 16); +} + +/* ========================================================================= */ +uLong ZEXPORT adler32_combine(adler1, adler2, len2) + uLong adler1; + uLong adler2; + z_off_t len2; +{ + return adler32_combine_(adler1, adler2, len2); +} + +uLong ZEXPORT adler32_combine64(adler1, adler2, len2) + uLong adler1; + uLong adler2; + z_off64_t len2; +{ + return adler32_combine_(adler1, adler2, len2); +} diff --git a/ml/dlib/dlib/external/zlib/compress.c b/ml/dlib/dlib/external/zlib/compress.c new file mode 100644 index 000000000..6e9762676 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/compress.c @@ -0,0 +1,80 @@ +/* compress.c -- compress a memory buffer + * Copyright (C) 1995-2005 Jean-loup Gailly. + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* @(#) $Id$ */ + +#define ZLIB_INTERNAL +#include "zlib.h" + +/* =========================================================================== + Compresses the source buffer into the destination buffer. The level + parameter has the same meaning as in deflateInit. sourceLen is the byte + length of the source buffer. Upon entry, destLen is the total size of the + destination buffer, which must be at least 0.1% larger than sourceLen plus + 12 bytes. Upon exit, destLen is the actual size of the compressed buffer. + + compress2 returns Z_OK if success, Z_MEM_ERROR if there was not enough + memory, Z_BUF_ERROR if there was not enough room in the output buffer, + Z_STREAM_ERROR if the level parameter is invalid. +*/ +int ZEXPORT compress2 (dest, destLen, source, sourceLen, level) + Bytef *dest; + uLongf *destLen; + const Bytef *source; + uLong sourceLen; + int level; +{ + z_stream stream; + int err; + + stream.next_in = (z_const Bytef *)source; + stream.avail_in = (uInt)sourceLen; +#ifdef MAXSEG_64K + /* Check for source > 64K on 16-bit machine: */ + if ((uLong)stream.avail_in != sourceLen) return Z_BUF_ERROR; +#endif + stream.next_out = dest; + stream.avail_out = (uInt)*destLen; + if ((uLong)stream.avail_out != *destLen) return Z_BUF_ERROR; + + stream.zalloc = (alloc_func)0; + stream.zfree = (free_func)0; + stream.opaque = (voidpf)0; + + err = deflateInit(&stream, level); + if (err != Z_OK) return err; + + err = deflate(&stream, Z_FINISH); + if (err != Z_STREAM_END) { + deflateEnd(&stream); + return err == Z_OK ? Z_BUF_ERROR : err; + } + *destLen = stream.total_out; + + err = deflateEnd(&stream); + return err; +} + +/* =========================================================================== + */ +int ZEXPORT compress (dest, destLen, source, sourceLen) + Bytef *dest; + uLongf *destLen; + const Bytef *source; + uLong sourceLen; +{ + return compress2(dest, destLen, source, sourceLen, Z_DEFAULT_COMPRESSION); +} + +/* =========================================================================== + If the default memLevel or windowBits for deflateInit() is changed, then + this function needs to be updated. + */ +uLong ZEXPORT compressBound (sourceLen) + uLong sourceLen; +{ + return sourceLen + (sourceLen >> 12) + (sourceLen >> 14) + + (sourceLen >> 25) + 13; +} diff --git a/ml/dlib/dlib/external/zlib/crc32.c b/ml/dlib/dlib/external/zlib/crc32.c new file mode 100644 index 000000000..979a7190a --- /dev/null +++ b/ml/dlib/dlib/external/zlib/crc32.c @@ -0,0 +1,425 @@ +/* crc32.c -- compute the CRC-32 of a data stream + * Copyright (C) 1995-2006, 2010, 2011, 2012 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + * + * Thanks to Rodney Brown for his contribution of faster + * CRC methods: exclusive-oring 32 bits of data at a time, and pre-computing + * tables for updating the shift register in one step with three exclusive-ors + * instead of four steps with four exclusive-ors. This results in about a + * factor of two increase in speed on a Power PC G4 (PPC7455) using gcc -O3. + */ + +/* @(#) $Id$ */ + +/* + Note on the use of DYNAMIC_CRC_TABLE: there is no mutex or semaphore + protection on the static variables used to control the first-use generation + of the crc tables. Therefore, if you #define DYNAMIC_CRC_TABLE, you should + first call get_crc_table() to initialize the tables before allowing more than + one thread to use crc32(). + + DYNAMIC_CRC_TABLE and MAKECRCH can be #defined to write out crc32.h. + */ + +#ifdef MAKECRCH +# include +# ifndef DYNAMIC_CRC_TABLE +# define DYNAMIC_CRC_TABLE +# endif /* !DYNAMIC_CRC_TABLE */ +#endif /* MAKECRCH */ + +#include "zutil.h" /* for STDC and FAR definitions */ + +#define local static + +/* Definitions for doing the crc four data bytes at a time. */ +#if !defined(NOBYFOUR) && defined(Z_U4) +# define BYFOUR +#endif +#ifdef BYFOUR + local unsigned long crc32_little OF((unsigned long, + const unsigned char FAR *, unsigned)); + local unsigned long crc32_big OF((unsigned long, + const unsigned char FAR *, unsigned)); +# define TBLS 8 +#else +# define TBLS 1 +#endif /* BYFOUR */ + +/* Local functions for crc concatenation */ +local unsigned long gf2_matrix_times OF((unsigned long *mat, + unsigned long vec)); +local void gf2_matrix_square OF((unsigned long *square, unsigned long *mat)); +local uLong crc32_combine_ OF((uLong crc1, uLong crc2, z_off64_t len2)); + + +#ifdef DYNAMIC_CRC_TABLE + +local volatile int crc_table_empty = 1; +local z_crc_t FAR crc_table[TBLS][256]; +local void make_crc_table OF((void)); +#ifdef MAKECRCH + local void write_table OF((FILE *, const z_crc_t FAR *)); +#endif /* MAKECRCH */ +/* + Generate tables for a byte-wise 32-bit CRC calculation on the polynomial: + x^32+x^26+x^23+x^22+x^16+x^12+x^11+x^10+x^8+x^7+x^5+x^4+x^2+x+1. + + Polynomials over GF(2) are represented in binary, one bit per coefficient, + with the lowest powers in the most significant bit. Then adding polynomials + is just exclusive-or, and multiplying a polynomial by x is a right shift by + one. If we call the above polynomial p, and represent a byte as the + polynomial q, also with the lowest power in the most significant bit (so the + byte 0xb1 is the polynomial x^7+x^3+x+1), then the CRC is (q*x^32) mod p, + where a mod b means the remainder after dividing a by b. + + This calculation is done using the shift-register method of multiplying and + taking the remainder. The register is initialized to zero, and for each + incoming bit, x^32 is added mod p to the register if the bit is a one (where + x^32 mod p is p+x^32 = x^26+...+1), and the register is multiplied mod p by + x (which is shifting right by one and adding x^32 mod p if the bit shifted + out is a one). We start with the highest power (least significant bit) of + q and repeat for all eight bits of q. + + The first table is simply the CRC of all possible eight bit values. This is + all the information needed to generate CRCs on data a byte at a time for all + combinations of CRC register values and incoming bytes. The remaining tables + allow for word-at-a-time CRC calculation for both big-endian and little- + endian machines, where a word is four bytes. +*/ +local void make_crc_table() +{ + z_crc_t c; + int n, k; + z_crc_t poly; /* polynomial exclusive-or pattern */ + /* terms of polynomial defining this crc (except x^32): */ + static volatile int first = 1; /* flag to limit concurrent making */ + static const unsigned char p[] = {0,1,2,4,5,7,8,10,11,12,16,22,23,26}; + + /* See if another task is already doing this (not thread-safe, but better + than nothing -- significantly reduces duration of vulnerability in + case the advice about DYNAMIC_CRC_TABLE is ignored) */ + if (first) { + first = 0; + + /* make exclusive-or pattern from polynomial (0xedb88320UL) */ + poly = 0; + for (n = 0; n < (int)(sizeof(p)/sizeof(unsigned char)); n++) + poly |= (z_crc_t)1 << (31 - p[n]); + + /* generate a crc for every 8-bit value */ + for (n = 0; n < 256; n++) { + c = (z_crc_t)n; + for (k = 0; k < 8; k++) + c = c & 1 ? poly ^ (c >> 1) : c >> 1; + crc_table[0][n] = c; + } + +#ifdef BYFOUR + /* generate crc for each value followed by one, two, and three zeros, + and then the byte reversal of those as well as the first table */ + for (n = 0; n < 256; n++) { + c = crc_table[0][n]; + crc_table[4][n] = ZSWAP32(c); + for (k = 1; k < 4; k++) { + c = crc_table[0][c & 0xff] ^ (c >> 8); + crc_table[k][n] = c; + crc_table[k + 4][n] = ZSWAP32(c); + } + } +#endif /* BYFOUR */ + + crc_table_empty = 0; + } + else { /* not first */ + /* wait for the other guy to finish (not efficient, but rare) */ + while (crc_table_empty) + ; + } + +#ifdef MAKECRCH + /* write out CRC tables to crc32.h */ + { + FILE *out; + + out = fopen("crc32.h", "w"); + if (out == NULL) return; + fprintf(out, "/* crc32.h -- tables for rapid CRC calculation\n"); + fprintf(out, " * Generated automatically by crc32.c\n */\n\n"); + fprintf(out, "local const z_crc_t FAR "); + fprintf(out, "crc_table[TBLS][256] =\n{\n {\n"); + write_table(out, crc_table[0]); +# ifdef BYFOUR + fprintf(out, "#ifdef BYFOUR\n"); + for (k = 1; k < 8; k++) { + fprintf(out, " },\n {\n"); + write_table(out, crc_table[k]); + } + fprintf(out, "#endif\n"); +# endif /* BYFOUR */ + fprintf(out, " }\n};\n"); + fclose(out); + } +#endif /* MAKECRCH */ +} + +#ifdef MAKECRCH +local void write_table(out, table) + FILE *out; + const z_crc_t FAR *table; +{ + int n; + + for (n = 0; n < 256; n++) + fprintf(out, "%s0x%08lxUL%s", n % 5 ? "" : " ", + (unsigned long)(table[n]), + n == 255 ? "\n" : (n % 5 == 4 ? ",\n" : ", ")); +} +#endif /* MAKECRCH */ + +#else /* !DYNAMIC_CRC_TABLE */ +/* ======================================================================== + * Tables of CRC-32s of all single-byte values, made by make_crc_table(). + */ +#include "crc32.h" +#endif /* DYNAMIC_CRC_TABLE */ + +/* ========================================================================= + * This function can be used by asm versions of crc32() + */ +const z_crc_t FAR * ZEXPORT get_crc_table() +{ +#ifdef DYNAMIC_CRC_TABLE + if (crc_table_empty) + make_crc_table(); +#endif /* DYNAMIC_CRC_TABLE */ + return (const z_crc_t FAR *)crc_table; +} + +/* ========================================================================= */ +#define DO1 crc = crc_table[0][((int)crc ^ (*buf++)) & 0xff] ^ (crc >> 8) +#define DO8 DO1; DO1; DO1; DO1; DO1; DO1; DO1; DO1 + +/* ========================================================================= */ +unsigned long ZEXPORT crc32(crc, buf, len) + unsigned long crc; + const unsigned char FAR *buf; + uInt len; +{ + if (buf == Z_NULL) return 0UL; + +#ifdef DYNAMIC_CRC_TABLE + if (crc_table_empty) + make_crc_table(); +#endif /* DYNAMIC_CRC_TABLE */ + +#ifdef BYFOUR + if (sizeof(void *) == sizeof(ptrdiff_t)) { + z_crc_t endian; + + endian = 1; + if (*((unsigned char *)(&endian))) + return crc32_little(crc, buf, len); + else + return crc32_big(crc, buf, len); + } +#endif /* BYFOUR */ + crc = crc ^ 0xffffffffUL; + while (len >= 8) { + DO8; + len -= 8; + } + if (len) do { + DO1; + } while (--len); + return crc ^ 0xffffffffUL; +} + +#ifdef BYFOUR + +/* ========================================================================= */ +#define DOLIT4 c ^= *buf4++; \ + c = crc_table[3][c & 0xff] ^ crc_table[2][(c >> 8) & 0xff] ^ \ + crc_table[1][(c >> 16) & 0xff] ^ crc_table[0][c >> 24] +#define DOLIT32 DOLIT4; DOLIT4; DOLIT4; DOLIT4; DOLIT4; DOLIT4; DOLIT4; DOLIT4 + +/* ========================================================================= */ +local unsigned long crc32_little(crc, buf, len) + unsigned long crc; + const unsigned char FAR *buf; + unsigned len; +{ + register z_crc_t c; + register const z_crc_t FAR *buf4; + + c = (z_crc_t)crc; + c = ~c; + while (len && ((ptrdiff_t)buf & 3)) { + c = crc_table[0][(c ^ *buf++) & 0xff] ^ (c >> 8); + len--; + } + + buf4 = (const z_crc_t FAR *)(const void FAR *)buf; + while (len >= 32) { + DOLIT32; + len -= 32; + } + while (len >= 4) { + DOLIT4; + len -= 4; + } + buf = (const unsigned char FAR *)buf4; + + if (len) do { + c = crc_table[0][(c ^ *buf++) & 0xff] ^ (c >> 8); + } while (--len); + c = ~c; + return (unsigned long)c; +} + +/* ========================================================================= */ +#define DOBIG4 c ^= *++buf4; \ + c = crc_table[4][c & 0xff] ^ crc_table[5][(c >> 8) & 0xff] ^ \ + crc_table[6][(c >> 16) & 0xff] ^ crc_table[7][c >> 24] +#define DOBIG32 DOBIG4; DOBIG4; DOBIG4; DOBIG4; DOBIG4; DOBIG4; DOBIG4; DOBIG4 + +/* ========================================================================= */ +local unsigned long crc32_big(crc, buf, len) + unsigned long crc; + const unsigned char FAR *buf; + unsigned len; +{ + register z_crc_t c; + register const z_crc_t FAR *buf4; + + c = ZSWAP32((z_crc_t)crc); + c = ~c; + while (len && ((ptrdiff_t)buf & 3)) { + c = crc_table[4][(c >> 24) ^ *buf++] ^ (c << 8); + len--; + } + + buf4 = (const z_crc_t FAR *)(const void FAR *)buf; + buf4--; + while (len >= 32) { + DOBIG32; + len -= 32; + } + while (len >= 4) { + DOBIG4; + len -= 4; + } + buf4++; + buf = (const unsigned char FAR *)buf4; + + if (len) do { + c = crc_table[4][(c >> 24) ^ *buf++] ^ (c << 8); + } while (--len); + c = ~c; + return (unsigned long)(ZSWAP32(c)); +} + +#endif /* BYFOUR */ + +#define GF2_DIM 32 /* dimension of GF(2) vectors (length of CRC) */ + +/* ========================================================================= */ +local unsigned long gf2_matrix_times(mat, vec) + unsigned long *mat; + unsigned long vec; +{ + unsigned long sum; + + sum = 0; + while (vec) { + if (vec & 1) + sum ^= *mat; + vec >>= 1; + mat++; + } + return sum; +} + +/* ========================================================================= */ +local void gf2_matrix_square(square, mat) + unsigned long *square; + unsigned long *mat; +{ + int n; + + for (n = 0; n < GF2_DIM; n++) + square[n] = gf2_matrix_times(mat, mat[n]); +} + +/* ========================================================================= */ +local uLong crc32_combine_(crc1, crc2, len2) + uLong crc1; + uLong crc2; + z_off64_t len2; +{ + int n; + unsigned long row; + unsigned long even[GF2_DIM]; /* even-power-of-two zeros operator */ + unsigned long odd[GF2_DIM]; /* odd-power-of-two zeros operator */ + + /* degenerate case (also disallow negative lengths) */ + if (len2 <= 0) + return crc1; + + /* put operator for one zero bit in odd */ + odd[0] = 0xedb88320UL; /* CRC-32 polynomial */ + row = 1; + for (n = 1; n < GF2_DIM; n++) { + odd[n] = row; + row <<= 1; + } + + /* put operator for two zero bits in even */ + gf2_matrix_square(even, odd); + + /* put operator for four zero bits in odd */ + gf2_matrix_square(odd, even); + + /* apply len2 zeros to crc1 (first square will put the operator for one + zero byte, eight zero bits, in even) */ + do { + /* apply zeros operator for this bit of len2 */ + gf2_matrix_square(even, odd); + if (len2 & 1) + crc1 = gf2_matrix_times(even, crc1); + len2 >>= 1; + + /* if no more bits set, then done */ + if (len2 == 0) + break; + + /* another iteration of the loop with odd and even swapped */ + gf2_matrix_square(odd, even); + if (len2 & 1) + crc1 = gf2_matrix_times(odd, crc1); + len2 >>= 1; + + /* if no more bits set, then done */ + } while (len2 != 0); + + /* return combined crc */ + crc1 ^= crc2; + return crc1; +} + +/* ========================================================================= */ +uLong ZEXPORT crc32_combine(crc1, crc2, len2) + uLong crc1; + uLong crc2; + z_off_t len2; +{ + return crc32_combine_(crc1, crc2, len2); +} + +uLong ZEXPORT crc32_combine64(crc1, crc2, len2) + uLong crc1; + uLong crc2; + z_off64_t len2; +{ + return crc32_combine_(crc1, crc2, len2); +} diff --git a/ml/dlib/dlib/external/zlib/crc32.h b/ml/dlib/dlib/external/zlib/crc32.h new file mode 100644 index 000000000..9e0c77810 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/crc32.h @@ -0,0 +1,441 @@ +/* crc32.h -- tables for rapid CRC calculation + * Generated automatically by crc32.c + */ + +local const z_crc_t FAR crc_table[TBLS][256] = +{ + { + 0x00000000UL, 0x77073096UL, 0xee0e612cUL, 0x990951baUL, 0x076dc419UL, + 0x706af48fUL, 0xe963a535UL, 0x9e6495a3UL, 0x0edb8832UL, 0x79dcb8a4UL, + 0xe0d5e91eUL, 0x97d2d988UL, 0x09b64c2bUL, 0x7eb17cbdUL, 0xe7b82d07UL, + 0x90bf1d91UL, 0x1db71064UL, 0x6ab020f2UL, 0xf3b97148UL, 0x84be41deUL, + 0x1adad47dUL, 0x6ddde4ebUL, 0xf4d4b551UL, 0x83d385c7UL, 0x136c9856UL, + 0x646ba8c0UL, 0xfd62f97aUL, 0x8a65c9ecUL, 0x14015c4fUL, 0x63066cd9UL, + 0xfa0f3d63UL, 0x8d080df5UL, 0x3b6e20c8UL, 0x4c69105eUL, 0xd56041e4UL, + 0xa2677172UL, 0x3c03e4d1UL, 0x4b04d447UL, 0xd20d85fdUL, 0xa50ab56bUL, + 0x35b5a8faUL, 0x42b2986cUL, 0xdbbbc9d6UL, 0xacbcf940UL, 0x32d86ce3UL, + 0x45df5c75UL, 0xdcd60dcfUL, 0xabd13d59UL, 0x26d930acUL, 0x51de003aUL, + 0xc8d75180UL, 0xbfd06116UL, 0x21b4f4b5UL, 0x56b3c423UL, 0xcfba9599UL, + 0xb8bda50fUL, 0x2802b89eUL, 0x5f058808UL, 0xc60cd9b2UL, 0xb10be924UL, + 0x2f6f7c87UL, 0x58684c11UL, 0xc1611dabUL, 0xb6662d3dUL, 0x76dc4190UL, + 0x01db7106UL, 0x98d220bcUL, 0xefd5102aUL, 0x71b18589UL, 0x06b6b51fUL, + 0x9fbfe4a5UL, 0xe8b8d433UL, 0x7807c9a2UL, 0x0f00f934UL, 0x9609a88eUL, + 0xe10e9818UL, 0x7f6a0dbbUL, 0x086d3d2dUL, 0x91646c97UL, 0xe6635c01UL, + 0x6b6b51f4UL, 0x1c6c6162UL, 0x856530d8UL, 0xf262004eUL, 0x6c0695edUL, + 0x1b01a57bUL, 0x8208f4c1UL, 0xf50fc457UL, 0x65b0d9c6UL, 0x12b7e950UL, + 0x8bbeb8eaUL, 0xfcb9887cUL, 0x62dd1ddfUL, 0x15da2d49UL, 0x8cd37cf3UL, + 0xfbd44c65UL, 0x4db26158UL, 0x3ab551ceUL, 0xa3bc0074UL, 0xd4bb30e2UL, + 0x4adfa541UL, 0x3dd895d7UL, 0xa4d1c46dUL, 0xd3d6f4fbUL, 0x4369e96aUL, + 0x346ed9fcUL, 0xad678846UL, 0xda60b8d0UL, 0x44042d73UL, 0x33031de5UL, + 0xaa0a4c5fUL, 0xdd0d7cc9UL, 0x5005713cUL, 0x270241aaUL, 0xbe0b1010UL, + 0xc90c2086UL, 0x5768b525UL, 0x206f85b3UL, 0xb966d409UL, 0xce61e49fUL, + 0x5edef90eUL, 0x29d9c998UL, 0xb0d09822UL, 0xc7d7a8b4UL, 0x59b33d17UL, + 0x2eb40d81UL, 0xb7bd5c3bUL, 0xc0ba6cadUL, 0xedb88320UL, 0x9abfb3b6UL, + 0x03b6e20cUL, 0x74b1d29aUL, 0xead54739UL, 0x9dd277afUL, 0x04db2615UL, + 0x73dc1683UL, 0xe3630b12UL, 0x94643b84UL, 0x0d6d6a3eUL, 0x7a6a5aa8UL, + 0xe40ecf0bUL, 0x9309ff9dUL, 0x0a00ae27UL, 0x7d079eb1UL, 0xf00f9344UL, + 0x8708a3d2UL, 0x1e01f268UL, 0x6906c2feUL, 0xf762575dUL, 0x806567cbUL, + 0x196c3671UL, 0x6e6b06e7UL, 0xfed41b76UL, 0x89d32be0UL, 0x10da7a5aUL, + 0x67dd4accUL, 0xf9b9df6fUL, 0x8ebeeff9UL, 0x17b7be43UL, 0x60b08ed5UL, + 0xd6d6a3e8UL, 0xa1d1937eUL, 0x38d8c2c4UL, 0x4fdff252UL, 0xd1bb67f1UL, + 0xa6bc5767UL, 0x3fb506ddUL, 0x48b2364bUL, 0xd80d2bdaUL, 0xaf0a1b4cUL, + 0x36034af6UL, 0x41047a60UL, 0xdf60efc3UL, 0xa867df55UL, 0x316e8eefUL, + 0x4669be79UL, 0xcb61b38cUL, 0xbc66831aUL, 0x256fd2a0UL, 0x5268e236UL, + 0xcc0c7795UL, 0xbb0b4703UL, 0x220216b9UL, 0x5505262fUL, 0xc5ba3bbeUL, + 0xb2bd0b28UL, 0x2bb45a92UL, 0x5cb36a04UL, 0xc2d7ffa7UL, 0xb5d0cf31UL, + 0x2cd99e8bUL, 0x5bdeae1dUL, 0x9b64c2b0UL, 0xec63f226UL, 0x756aa39cUL, + 0x026d930aUL, 0x9c0906a9UL, 0xeb0e363fUL, 0x72076785UL, 0x05005713UL, + 0x95bf4a82UL, 0xe2b87a14UL, 0x7bb12baeUL, 0x0cb61b38UL, 0x92d28e9bUL, + 0xe5d5be0dUL, 0x7cdcefb7UL, 0x0bdbdf21UL, 0x86d3d2d4UL, 0xf1d4e242UL, + 0x68ddb3f8UL, 0x1fda836eUL, 0x81be16cdUL, 0xf6b9265bUL, 0x6fb077e1UL, + 0x18b74777UL, 0x88085ae6UL, 0xff0f6a70UL, 0x66063bcaUL, 0x11010b5cUL, + 0x8f659effUL, 0xf862ae69UL, 0x616bffd3UL, 0x166ccf45UL, 0xa00ae278UL, + 0xd70dd2eeUL, 0x4e048354UL, 0x3903b3c2UL, 0xa7672661UL, 0xd06016f7UL, + 0x4969474dUL, 0x3e6e77dbUL, 0xaed16a4aUL, 0xd9d65adcUL, 0x40df0b66UL, + 0x37d83bf0UL, 0xa9bcae53UL, 0xdebb9ec5UL, 0x47b2cf7fUL, 0x30b5ffe9UL, + 0xbdbdf21cUL, 0xcabac28aUL, 0x53b39330UL, 0x24b4a3a6UL, 0xbad03605UL, + 0xcdd70693UL, 0x54de5729UL, 0x23d967bfUL, 0xb3667a2eUL, 0xc4614ab8UL, + 0x5d681b02UL, 0x2a6f2b94UL, 0xb40bbe37UL, 0xc30c8ea1UL, 0x5a05df1bUL, + 0x2d02ef8dUL +#ifdef BYFOUR + }, + { + 0x00000000UL, 0x191b3141UL, 0x32366282UL, 0x2b2d53c3UL, 0x646cc504UL, + 0x7d77f445UL, 0x565aa786UL, 0x4f4196c7UL, 0xc8d98a08UL, 0xd1c2bb49UL, + 0xfaefe88aUL, 0xe3f4d9cbUL, 0xacb54f0cUL, 0xb5ae7e4dUL, 0x9e832d8eUL, + 0x87981ccfUL, 0x4ac21251UL, 0x53d92310UL, 0x78f470d3UL, 0x61ef4192UL, + 0x2eaed755UL, 0x37b5e614UL, 0x1c98b5d7UL, 0x05838496UL, 0x821b9859UL, + 0x9b00a918UL, 0xb02dfadbUL, 0xa936cb9aUL, 0xe6775d5dUL, 0xff6c6c1cUL, + 0xd4413fdfUL, 0xcd5a0e9eUL, 0x958424a2UL, 0x8c9f15e3UL, 0xa7b24620UL, + 0xbea97761UL, 0xf1e8e1a6UL, 0xe8f3d0e7UL, 0xc3de8324UL, 0xdac5b265UL, + 0x5d5daeaaUL, 0x44469febUL, 0x6f6bcc28UL, 0x7670fd69UL, 0x39316baeUL, + 0x202a5aefUL, 0x0b07092cUL, 0x121c386dUL, 0xdf4636f3UL, 0xc65d07b2UL, + 0xed705471UL, 0xf46b6530UL, 0xbb2af3f7UL, 0xa231c2b6UL, 0x891c9175UL, + 0x9007a034UL, 0x179fbcfbUL, 0x0e848dbaUL, 0x25a9de79UL, 0x3cb2ef38UL, + 0x73f379ffUL, 0x6ae848beUL, 0x41c51b7dUL, 0x58de2a3cUL, 0xf0794f05UL, + 0xe9627e44UL, 0xc24f2d87UL, 0xdb541cc6UL, 0x94158a01UL, 0x8d0ebb40UL, + 0xa623e883UL, 0xbf38d9c2UL, 0x38a0c50dUL, 0x21bbf44cUL, 0x0a96a78fUL, + 0x138d96ceUL, 0x5ccc0009UL, 0x45d73148UL, 0x6efa628bUL, 0x77e153caUL, + 0xbabb5d54UL, 0xa3a06c15UL, 0x888d3fd6UL, 0x91960e97UL, 0xded79850UL, + 0xc7cca911UL, 0xece1fad2UL, 0xf5facb93UL, 0x7262d75cUL, 0x6b79e61dUL, + 0x4054b5deUL, 0x594f849fUL, 0x160e1258UL, 0x0f152319UL, 0x243870daUL, + 0x3d23419bUL, 0x65fd6ba7UL, 0x7ce65ae6UL, 0x57cb0925UL, 0x4ed03864UL, + 0x0191aea3UL, 0x188a9fe2UL, 0x33a7cc21UL, 0x2abcfd60UL, 0xad24e1afUL, + 0xb43fd0eeUL, 0x9f12832dUL, 0x8609b26cUL, 0xc94824abUL, 0xd05315eaUL, + 0xfb7e4629UL, 0xe2657768UL, 0x2f3f79f6UL, 0x362448b7UL, 0x1d091b74UL, + 0x04122a35UL, 0x4b53bcf2UL, 0x52488db3UL, 0x7965de70UL, 0x607eef31UL, + 0xe7e6f3feUL, 0xfefdc2bfUL, 0xd5d0917cUL, 0xcccba03dUL, 0x838a36faUL, + 0x9a9107bbUL, 0xb1bc5478UL, 0xa8a76539UL, 0x3b83984bUL, 0x2298a90aUL, + 0x09b5fac9UL, 0x10aecb88UL, 0x5fef5d4fUL, 0x46f46c0eUL, 0x6dd93fcdUL, + 0x74c20e8cUL, 0xf35a1243UL, 0xea412302UL, 0xc16c70c1UL, 0xd8774180UL, + 0x9736d747UL, 0x8e2de606UL, 0xa500b5c5UL, 0xbc1b8484UL, 0x71418a1aUL, + 0x685abb5bUL, 0x4377e898UL, 0x5a6cd9d9UL, 0x152d4f1eUL, 0x0c367e5fUL, + 0x271b2d9cUL, 0x3e001cddUL, 0xb9980012UL, 0xa0833153UL, 0x8bae6290UL, + 0x92b553d1UL, 0xddf4c516UL, 0xc4eff457UL, 0xefc2a794UL, 0xf6d996d5UL, + 0xae07bce9UL, 0xb71c8da8UL, 0x9c31de6bUL, 0x852aef2aUL, 0xca6b79edUL, + 0xd37048acUL, 0xf85d1b6fUL, 0xe1462a2eUL, 0x66de36e1UL, 0x7fc507a0UL, + 0x54e85463UL, 0x4df36522UL, 0x02b2f3e5UL, 0x1ba9c2a4UL, 0x30849167UL, + 0x299fa026UL, 0xe4c5aeb8UL, 0xfdde9ff9UL, 0xd6f3cc3aUL, 0xcfe8fd7bUL, + 0x80a96bbcUL, 0x99b25afdUL, 0xb29f093eUL, 0xab84387fUL, 0x2c1c24b0UL, + 0x350715f1UL, 0x1e2a4632UL, 0x07317773UL, 0x4870e1b4UL, 0x516bd0f5UL, + 0x7a468336UL, 0x635db277UL, 0xcbfad74eUL, 0xd2e1e60fUL, 0xf9ccb5ccUL, + 0xe0d7848dUL, 0xaf96124aUL, 0xb68d230bUL, 0x9da070c8UL, 0x84bb4189UL, + 0x03235d46UL, 0x1a386c07UL, 0x31153fc4UL, 0x280e0e85UL, 0x674f9842UL, + 0x7e54a903UL, 0x5579fac0UL, 0x4c62cb81UL, 0x8138c51fUL, 0x9823f45eUL, + 0xb30ea79dUL, 0xaa1596dcUL, 0xe554001bUL, 0xfc4f315aUL, 0xd7626299UL, + 0xce7953d8UL, 0x49e14f17UL, 0x50fa7e56UL, 0x7bd72d95UL, 0x62cc1cd4UL, + 0x2d8d8a13UL, 0x3496bb52UL, 0x1fbbe891UL, 0x06a0d9d0UL, 0x5e7ef3ecUL, + 0x4765c2adUL, 0x6c48916eUL, 0x7553a02fUL, 0x3a1236e8UL, 0x230907a9UL, + 0x0824546aUL, 0x113f652bUL, 0x96a779e4UL, 0x8fbc48a5UL, 0xa4911b66UL, + 0xbd8a2a27UL, 0xf2cbbce0UL, 0xebd08da1UL, 0xc0fdde62UL, 0xd9e6ef23UL, + 0x14bce1bdUL, 0x0da7d0fcUL, 0x268a833fUL, 0x3f91b27eUL, 0x70d024b9UL, + 0x69cb15f8UL, 0x42e6463bUL, 0x5bfd777aUL, 0xdc656bb5UL, 0xc57e5af4UL, + 0xee530937UL, 0xf7483876UL, 0xb809aeb1UL, 0xa1129ff0UL, 0x8a3fcc33UL, + 0x9324fd72UL + }, + { + 0x00000000UL, 0x01c26a37UL, 0x0384d46eUL, 0x0246be59UL, 0x0709a8dcUL, + 0x06cbc2ebUL, 0x048d7cb2UL, 0x054f1685UL, 0x0e1351b8UL, 0x0fd13b8fUL, + 0x0d9785d6UL, 0x0c55efe1UL, 0x091af964UL, 0x08d89353UL, 0x0a9e2d0aUL, + 0x0b5c473dUL, 0x1c26a370UL, 0x1de4c947UL, 0x1fa2771eUL, 0x1e601d29UL, + 0x1b2f0bacUL, 0x1aed619bUL, 0x18abdfc2UL, 0x1969b5f5UL, 0x1235f2c8UL, + 0x13f798ffUL, 0x11b126a6UL, 0x10734c91UL, 0x153c5a14UL, 0x14fe3023UL, + 0x16b88e7aUL, 0x177ae44dUL, 0x384d46e0UL, 0x398f2cd7UL, 0x3bc9928eUL, + 0x3a0bf8b9UL, 0x3f44ee3cUL, 0x3e86840bUL, 0x3cc03a52UL, 0x3d025065UL, + 0x365e1758UL, 0x379c7d6fUL, 0x35dac336UL, 0x3418a901UL, 0x3157bf84UL, + 0x3095d5b3UL, 0x32d36beaUL, 0x331101ddUL, 0x246be590UL, 0x25a98fa7UL, + 0x27ef31feUL, 0x262d5bc9UL, 0x23624d4cUL, 0x22a0277bUL, 0x20e69922UL, + 0x2124f315UL, 0x2a78b428UL, 0x2bbade1fUL, 0x29fc6046UL, 0x283e0a71UL, + 0x2d711cf4UL, 0x2cb376c3UL, 0x2ef5c89aUL, 0x2f37a2adUL, 0x709a8dc0UL, + 0x7158e7f7UL, 0x731e59aeUL, 0x72dc3399UL, 0x7793251cUL, 0x76514f2bUL, + 0x7417f172UL, 0x75d59b45UL, 0x7e89dc78UL, 0x7f4bb64fUL, 0x7d0d0816UL, + 0x7ccf6221UL, 0x798074a4UL, 0x78421e93UL, 0x7a04a0caUL, 0x7bc6cafdUL, + 0x6cbc2eb0UL, 0x6d7e4487UL, 0x6f38fadeUL, 0x6efa90e9UL, 0x6bb5866cUL, + 0x6a77ec5bUL, 0x68315202UL, 0x69f33835UL, 0x62af7f08UL, 0x636d153fUL, + 0x612bab66UL, 0x60e9c151UL, 0x65a6d7d4UL, 0x6464bde3UL, 0x662203baUL, + 0x67e0698dUL, 0x48d7cb20UL, 0x4915a117UL, 0x4b531f4eUL, 0x4a917579UL, + 0x4fde63fcUL, 0x4e1c09cbUL, 0x4c5ab792UL, 0x4d98dda5UL, 0x46c49a98UL, + 0x4706f0afUL, 0x45404ef6UL, 0x448224c1UL, 0x41cd3244UL, 0x400f5873UL, + 0x4249e62aUL, 0x438b8c1dUL, 0x54f16850UL, 0x55330267UL, 0x5775bc3eUL, + 0x56b7d609UL, 0x53f8c08cUL, 0x523aaabbUL, 0x507c14e2UL, 0x51be7ed5UL, + 0x5ae239e8UL, 0x5b2053dfUL, 0x5966ed86UL, 0x58a487b1UL, 0x5deb9134UL, + 0x5c29fb03UL, 0x5e6f455aUL, 0x5fad2f6dUL, 0xe1351b80UL, 0xe0f771b7UL, + 0xe2b1cfeeUL, 0xe373a5d9UL, 0xe63cb35cUL, 0xe7fed96bUL, 0xe5b86732UL, + 0xe47a0d05UL, 0xef264a38UL, 0xeee4200fUL, 0xeca29e56UL, 0xed60f461UL, + 0xe82fe2e4UL, 0xe9ed88d3UL, 0xebab368aUL, 0xea695cbdUL, 0xfd13b8f0UL, + 0xfcd1d2c7UL, 0xfe976c9eUL, 0xff5506a9UL, 0xfa1a102cUL, 0xfbd87a1bUL, + 0xf99ec442UL, 0xf85cae75UL, 0xf300e948UL, 0xf2c2837fUL, 0xf0843d26UL, + 0xf1465711UL, 0xf4094194UL, 0xf5cb2ba3UL, 0xf78d95faUL, 0xf64fffcdUL, + 0xd9785d60UL, 0xd8ba3757UL, 0xdafc890eUL, 0xdb3ee339UL, 0xde71f5bcUL, + 0xdfb39f8bUL, 0xddf521d2UL, 0xdc374be5UL, 0xd76b0cd8UL, 0xd6a966efUL, + 0xd4efd8b6UL, 0xd52db281UL, 0xd062a404UL, 0xd1a0ce33UL, 0xd3e6706aUL, + 0xd2241a5dUL, 0xc55efe10UL, 0xc49c9427UL, 0xc6da2a7eUL, 0xc7184049UL, + 0xc25756ccUL, 0xc3953cfbUL, 0xc1d382a2UL, 0xc011e895UL, 0xcb4dafa8UL, + 0xca8fc59fUL, 0xc8c97bc6UL, 0xc90b11f1UL, 0xcc440774UL, 0xcd866d43UL, + 0xcfc0d31aUL, 0xce02b92dUL, 0x91af9640UL, 0x906dfc77UL, 0x922b422eUL, + 0x93e92819UL, 0x96a63e9cUL, 0x976454abUL, 0x9522eaf2UL, 0x94e080c5UL, + 0x9fbcc7f8UL, 0x9e7eadcfUL, 0x9c381396UL, 0x9dfa79a1UL, 0x98b56f24UL, + 0x99770513UL, 0x9b31bb4aUL, 0x9af3d17dUL, 0x8d893530UL, 0x8c4b5f07UL, + 0x8e0de15eUL, 0x8fcf8b69UL, 0x8a809decUL, 0x8b42f7dbUL, 0x89044982UL, + 0x88c623b5UL, 0x839a6488UL, 0x82580ebfUL, 0x801eb0e6UL, 0x81dcdad1UL, + 0x8493cc54UL, 0x8551a663UL, 0x8717183aUL, 0x86d5720dUL, 0xa9e2d0a0UL, + 0xa820ba97UL, 0xaa6604ceUL, 0xaba46ef9UL, 0xaeeb787cUL, 0xaf29124bUL, + 0xad6fac12UL, 0xacadc625UL, 0xa7f18118UL, 0xa633eb2fUL, 0xa4755576UL, + 0xa5b73f41UL, 0xa0f829c4UL, 0xa13a43f3UL, 0xa37cfdaaUL, 0xa2be979dUL, + 0xb5c473d0UL, 0xb40619e7UL, 0xb640a7beUL, 0xb782cd89UL, 0xb2cddb0cUL, + 0xb30fb13bUL, 0xb1490f62UL, 0xb08b6555UL, 0xbbd72268UL, 0xba15485fUL, + 0xb853f606UL, 0xb9919c31UL, 0xbcde8ab4UL, 0xbd1ce083UL, 0xbf5a5edaUL, + 0xbe9834edUL + }, + { + 0x00000000UL, 0xb8bc6765UL, 0xaa09c88bUL, 0x12b5afeeUL, 0x8f629757UL, + 0x37def032UL, 0x256b5fdcUL, 0x9dd738b9UL, 0xc5b428efUL, 0x7d084f8aUL, + 0x6fbde064UL, 0xd7018701UL, 0x4ad6bfb8UL, 0xf26ad8ddUL, 0xe0df7733UL, + 0x58631056UL, 0x5019579fUL, 0xe8a530faUL, 0xfa109f14UL, 0x42acf871UL, + 0xdf7bc0c8UL, 0x67c7a7adUL, 0x75720843UL, 0xcdce6f26UL, 0x95ad7f70UL, + 0x2d111815UL, 0x3fa4b7fbUL, 0x8718d09eUL, 0x1acfe827UL, 0xa2738f42UL, + 0xb0c620acUL, 0x087a47c9UL, 0xa032af3eUL, 0x188ec85bUL, 0x0a3b67b5UL, + 0xb28700d0UL, 0x2f503869UL, 0x97ec5f0cUL, 0x8559f0e2UL, 0x3de59787UL, + 0x658687d1UL, 0xdd3ae0b4UL, 0xcf8f4f5aUL, 0x7733283fUL, 0xeae41086UL, + 0x525877e3UL, 0x40edd80dUL, 0xf851bf68UL, 0xf02bf8a1UL, 0x48979fc4UL, + 0x5a22302aUL, 0xe29e574fUL, 0x7f496ff6UL, 0xc7f50893UL, 0xd540a77dUL, + 0x6dfcc018UL, 0x359fd04eUL, 0x8d23b72bUL, 0x9f9618c5UL, 0x272a7fa0UL, + 0xbafd4719UL, 0x0241207cUL, 0x10f48f92UL, 0xa848e8f7UL, 0x9b14583dUL, + 0x23a83f58UL, 0x311d90b6UL, 0x89a1f7d3UL, 0x1476cf6aUL, 0xaccaa80fUL, + 0xbe7f07e1UL, 0x06c36084UL, 0x5ea070d2UL, 0xe61c17b7UL, 0xf4a9b859UL, + 0x4c15df3cUL, 0xd1c2e785UL, 0x697e80e0UL, 0x7bcb2f0eUL, 0xc377486bUL, + 0xcb0d0fa2UL, 0x73b168c7UL, 0x6104c729UL, 0xd9b8a04cUL, 0x446f98f5UL, + 0xfcd3ff90UL, 0xee66507eUL, 0x56da371bUL, 0x0eb9274dUL, 0xb6054028UL, + 0xa4b0efc6UL, 0x1c0c88a3UL, 0x81dbb01aUL, 0x3967d77fUL, 0x2bd27891UL, + 0x936e1ff4UL, 0x3b26f703UL, 0x839a9066UL, 0x912f3f88UL, 0x299358edUL, + 0xb4446054UL, 0x0cf80731UL, 0x1e4da8dfUL, 0xa6f1cfbaUL, 0xfe92dfecUL, + 0x462eb889UL, 0x549b1767UL, 0xec277002UL, 0x71f048bbUL, 0xc94c2fdeUL, + 0xdbf98030UL, 0x6345e755UL, 0x6b3fa09cUL, 0xd383c7f9UL, 0xc1366817UL, + 0x798a0f72UL, 0xe45d37cbUL, 0x5ce150aeUL, 0x4e54ff40UL, 0xf6e89825UL, + 0xae8b8873UL, 0x1637ef16UL, 0x048240f8UL, 0xbc3e279dUL, 0x21e91f24UL, + 0x99557841UL, 0x8be0d7afUL, 0x335cb0caUL, 0xed59b63bUL, 0x55e5d15eUL, + 0x47507eb0UL, 0xffec19d5UL, 0x623b216cUL, 0xda874609UL, 0xc832e9e7UL, + 0x708e8e82UL, 0x28ed9ed4UL, 0x9051f9b1UL, 0x82e4565fUL, 0x3a58313aUL, + 0xa78f0983UL, 0x1f336ee6UL, 0x0d86c108UL, 0xb53aa66dUL, 0xbd40e1a4UL, + 0x05fc86c1UL, 0x1749292fUL, 0xaff54e4aUL, 0x322276f3UL, 0x8a9e1196UL, + 0x982bbe78UL, 0x2097d91dUL, 0x78f4c94bUL, 0xc048ae2eUL, 0xd2fd01c0UL, + 0x6a4166a5UL, 0xf7965e1cUL, 0x4f2a3979UL, 0x5d9f9697UL, 0xe523f1f2UL, + 0x4d6b1905UL, 0xf5d77e60UL, 0xe762d18eUL, 0x5fdeb6ebUL, 0xc2098e52UL, + 0x7ab5e937UL, 0x680046d9UL, 0xd0bc21bcUL, 0x88df31eaUL, 0x3063568fUL, + 0x22d6f961UL, 0x9a6a9e04UL, 0x07bda6bdUL, 0xbf01c1d8UL, 0xadb46e36UL, + 0x15080953UL, 0x1d724e9aUL, 0xa5ce29ffUL, 0xb77b8611UL, 0x0fc7e174UL, + 0x9210d9cdUL, 0x2aacbea8UL, 0x38191146UL, 0x80a57623UL, 0xd8c66675UL, + 0x607a0110UL, 0x72cfaefeUL, 0xca73c99bUL, 0x57a4f122UL, 0xef189647UL, + 0xfdad39a9UL, 0x45115eccUL, 0x764dee06UL, 0xcef18963UL, 0xdc44268dUL, + 0x64f841e8UL, 0xf92f7951UL, 0x41931e34UL, 0x5326b1daUL, 0xeb9ad6bfUL, + 0xb3f9c6e9UL, 0x0b45a18cUL, 0x19f00e62UL, 0xa14c6907UL, 0x3c9b51beUL, + 0x842736dbUL, 0x96929935UL, 0x2e2efe50UL, 0x2654b999UL, 0x9ee8defcUL, + 0x8c5d7112UL, 0x34e11677UL, 0xa9362eceUL, 0x118a49abUL, 0x033fe645UL, + 0xbb838120UL, 0xe3e09176UL, 0x5b5cf613UL, 0x49e959fdUL, 0xf1553e98UL, + 0x6c820621UL, 0xd43e6144UL, 0xc68bceaaUL, 0x7e37a9cfUL, 0xd67f4138UL, + 0x6ec3265dUL, 0x7c7689b3UL, 0xc4caeed6UL, 0x591dd66fUL, 0xe1a1b10aUL, + 0xf3141ee4UL, 0x4ba87981UL, 0x13cb69d7UL, 0xab770eb2UL, 0xb9c2a15cUL, + 0x017ec639UL, 0x9ca9fe80UL, 0x241599e5UL, 0x36a0360bUL, 0x8e1c516eUL, + 0x866616a7UL, 0x3eda71c2UL, 0x2c6fde2cUL, 0x94d3b949UL, 0x090481f0UL, + 0xb1b8e695UL, 0xa30d497bUL, 0x1bb12e1eUL, 0x43d23e48UL, 0xfb6e592dUL, + 0xe9dbf6c3UL, 0x516791a6UL, 0xccb0a91fUL, 0x740cce7aUL, 0x66b96194UL, + 0xde0506f1UL + }, + { + 0x00000000UL, 0x96300777UL, 0x2c610eeeUL, 0xba510999UL, 0x19c46d07UL, + 0x8ff46a70UL, 0x35a563e9UL, 0xa395649eUL, 0x3288db0eUL, 0xa4b8dc79UL, + 0x1ee9d5e0UL, 0x88d9d297UL, 0x2b4cb609UL, 0xbd7cb17eUL, 0x072db8e7UL, + 0x911dbf90UL, 0x6410b71dUL, 0xf220b06aUL, 0x4871b9f3UL, 0xde41be84UL, + 0x7dd4da1aUL, 0xebe4dd6dUL, 0x51b5d4f4UL, 0xc785d383UL, 0x56986c13UL, + 0xc0a86b64UL, 0x7af962fdUL, 0xecc9658aUL, 0x4f5c0114UL, 0xd96c0663UL, + 0x633d0ffaUL, 0xf50d088dUL, 0xc8206e3bUL, 0x5e10694cUL, 0xe44160d5UL, + 0x727167a2UL, 0xd1e4033cUL, 0x47d4044bUL, 0xfd850dd2UL, 0x6bb50aa5UL, + 0xfaa8b535UL, 0x6c98b242UL, 0xd6c9bbdbUL, 0x40f9bcacUL, 0xe36cd832UL, + 0x755cdf45UL, 0xcf0dd6dcUL, 0x593dd1abUL, 0xac30d926UL, 0x3a00de51UL, + 0x8051d7c8UL, 0x1661d0bfUL, 0xb5f4b421UL, 0x23c4b356UL, 0x9995bacfUL, + 0x0fa5bdb8UL, 0x9eb80228UL, 0x0888055fUL, 0xb2d90cc6UL, 0x24e90bb1UL, + 0x877c6f2fUL, 0x114c6858UL, 0xab1d61c1UL, 0x3d2d66b6UL, 0x9041dc76UL, + 0x0671db01UL, 0xbc20d298UL, 0x2a10d5efUL, 0x8985b171UL, 0x1fb5b606UL, + 0xa5e4bf9fUL, 0x33d4b8e8UL, 0xa2c90778UL, 0x34f9000fUL, 0x8ea80996UL, + 0x18980ee1UL, 0xbb0d6a7fUL, 0x2d3d6d08UL, 0x976c6491UL, 0x015c63e6UL, + 0xf4516b6bUL, 0x62616c1cUL, 0xd8306585UL, 0x4e0062f2UL, 0xed95066cUL, + 0x7ba5011bUL, 0xc1f40882UL, 0x57c40ff5UL, 0xc6d9b065UL, 0x50e9b712UL, + 0xeab8be8bUL, 0x7c88b9fcUL, 0xdf1ddd62UL, 0x492dda15UL, 0xf37cd38cUL, + 0x654cd4fbUL, 0x5861b24dUL, 0xce51b53aUL, 0x7400bca3UL, 0xe230bbd4UL, + 0x41a5df4aUL, 0xd795d83dUL, 0x6dc4d1a4UL, 0xfbf4d6d3UL, 0x6ae96943UL, + 0xfcd96e34UL, 0x468867adUL, 0xd0b860daUL, 0x732d0444UL, 0xe51d0333UL, + 0x5f4c0aaaUL, 0xc97c0dddUL, 0x3c710550UL, 0xaa410227UL, 0x10100bbeUL, + 0x86200cc9UL, 0x25b56857UL, 0xb3856f20UL, 0x09d466b9UL, 0x9fe461ceUL, + 0x0ef9de5eUL, 0x98c9d929UL, 0x2298d0b0UL, 0xb4a8d7c7UL, 0x173db359UL, + 0x810db42eUL, 0x3b5cbdb7UL, 0xad6cbac0UL, 0x2083b8edUL, 0xb6b3bf9aUL, + 0x0ce2b603UL, 0x9ad2b174UL, 0x3947d5eaUL, 0xaf77d29dUL, 0x1526db04UL, + 0x8316dc73UL, 0x120b63e3UL, 0x843b6494UL, 0x3e6a6d0dUL, 0xa85a6a7aUL, + 0x0bcf0ee4UL, 0x9dff0993UL, 0x27ae000aUL, 0xb19e077dUL, 0x44930ff0UL, + 0xd2a30887UL, 0x68f2011eUL, 0xfec20669UL, 0x5d5762f7UL, 0xcb676580UL, + 0x71366c19UL, 0xe7066b6eUL, 0x761bd4feUL, 0xe02bd389UL, 0x5a7ada10UL, + 0xcc4add67UL, 0x6fdfb9f9UL, 0xf9efbe8eUL, 0x43beb717UL, 0xd58eb060UL, + 0xe8a3d6d6UL, 0x7e93d1a1UL, 0xc4c2d838UL, 0x52f2df4fUL, 0xf167bbd1UL, + 0x6757bca6UL, 0xdd06b53fUL, 0x4b36b248UL, 0xda2b0dd8UL, 0x4c1b0aafUL, + 0xf64a0336UL, 0x607a0441UL, 0xc3ef60dfUL, 0x55df67a8UL, 0xef8e6e31UL, + 0x79be6946UL, 0x8cb361cbUL, 0x1a8366bcUL, 0xa0d26f25UL, 0x36e26852UL, + 0x95770cccUL, 0x03470bbbUL, 0xb9160222UL, 0x2f260555UL, 0xbe3bbac5UL, + 0x280bbdb2UL, 0x925ab42bUL, 0x046ab35cUL, 0xa7ffd7c2UL, 0x31cfd0b5UL, + 0x8b9ed92cUL, 0x1daede5bUL, 0xb0c2649bUL, 0x26f263ecUL, 0x9ca36a75UL, + 0x0a936d02UL, 0xa906099cUL, 0x3f360eebUL, 0x85670772UL, 0x13570005UL, + 0x824abf95UL, 0x147ab8e2UL, 0xae2bb17bUL, 0x381bb60cUL, 0x9b8ed292UL, + 0x0dbed5e5UL, 0xb7efdc7cUL, 0x21dfdb0bUL, 0xd4d2d386UL, 0x42e2d4f1UL, + 0xf8b3dd68UL, 0x6e83da1fUL, 0xcd16be81UL, 0x5b26b9f6UL, 0xe177b06fUL, + 0x7747b718UL, 0xe65a0888UL, 0x706a0fffUL, 0xca3b0666UL, 0x5c0b0111UL, + 0xff9e658fUL, 0x69ae62f8UL, 0xd3ff6b61UL, 0x45cf6c16UL, 0x78e20aa0UL, + 0xeed20dd7UL, 0x5483044eUL, 0xc2b30339UL, 0x612667a7UL, 0xf71660d0UL, + 0x4d476949UL, 0xdb776e3eUL, 0x4a6ad1aeUL, 0xdc5ad6d9UL, 0x660bdf40UL, + 0xf03bd837UL, 0x53aebca9UL, 0xc59ebbdeUL, 0x7fcfb247UL, 0xe9ffb530UL, + 0x1cf2bdbdUL, 0x8ac2bacaUL, 0x3093b353UL, 0xa6a3b424UL, 0x0536d0baUL, + 0x9306d7cdUL, 0x2957de54UL, 0xbf67d923UL, 0x2e7a66b3UL, 0xb84a61c4UL, + 0x021b685dUL, 0x942b6f2aUL, 0x37be0bb4UL, 0xa18e0cc3UL, 0x1bdf055aUL, + 0x8def022dUL + }, + { + 0x00000000UL, 0x41311b19UL, 0x82623632UL, 0xc3532d2bUL, 0x04c56c64UL, + 0x45f4777dUL, 0x86a75a56UL, 0xc796414fUL, 0x088ad9c8UL, 0x49bbc2d1UL, + 0x8ae8effaUL, 0xcbd9f4e3UL, 0x0c4fb5acUL, 0x4d7eaeb5UL, 0x8e2d839eUL, + 0xcf1c9887UL, 0x5112c24aUL, 0x1023d953UL, 0xd370f478UL, 0x9241ef61UL, + 0x55d7ae2eUL, 0x14e6b537UL, 0xd7b5981cUL, 0x96848305UL, 0x59981b82UL, + 0x18a9009bUL, 0xdbfa2db0UL, 0x9acb36a9UL, 0x5d5d77e6UL, 0x1c6c6cffUL, + 0xdf3f41d4UL, 0x9e0e5acdUL, 0xa2248495UL, 0xe3159f8cUL, 0x2046b2a7UL, + 0x6177a9beUL, 0xa6e1e8f1UL, 0xe7d0f3e8UL, 0x2483dec3UL, 0x65b2c5daUL, + 0xaaae5d5dUL, 0xeb9f4644UL, 0x28cc6b6fUL, 0x69fd7076UL, 0xae6b3139UL, + 0xef5a2a20UL, 0x2c09070bUL, 0x6d381c12UL, 0xf33646dfUL, 0xb2075dc6UL, + 0x715470edUL, 0x30656bf4UL, 0xf7f32abbUL, 0xb6c231a2UL, 0x75911c89UL, + 0x34a00790UL, 0xfbbc9f17UL, 0xba8d840eUL, 0x79dea925UL, 0x38efb23cUL, + 0xff79f373UL, 0xbe48e86aUL, 0x7d1bc541UL, 0x3c2ade58UL, 0x054f79f0UL, + 0x447e62e9UL, 0x872d4fc2UL, 0xc61c54dbUL, 0x018a1594UL, 0x40bb0e8dUL, + 0x83e823a6UL, 0xc2d938bfUL, 0x0dc5a038UL, 0x4cf4bb21UL, 0x8fa7960aUL, + 0xce968d13UL, 0x0900cc5cUL, 0x4831d745UL, 0x8b62fa6eUL, 0xca53e177UL, + 0x545dbbbaUL, 0x156ca0a3UL, 0xd63f8d88UL, 0x970e9691UL, 0x5098d7deUL, + 0x11a9ccc7UL, 0xd2fae1ecUL, 0x93cbfaf5UL, 0x5cd76272UL, 0x1de6796bUL, + 0xdeb55440UL, 0x9f844f59UL, 0x58120e16UL, 0x1923150fUL, 0xda703824UL, + 0x9b41233dUL, 0xa76bfd65UL, 0xe65ae67cUL, 0x2509cb57UL, 0x6438d04eUL, + 0xa3ae9101UL, 0xe29f8a18UL, 0x21cca733UL, 0x60fdbc2aUL, 0xafe124adUL, + 0xeed03fb4UL, 0x2d83129fUL, 0x6cb20986UL, 0xab2448c9UL, 0xea1553d0UL, + 0x29467efbUL, 0x687765e2UL, 0xf6793f2fUL, 0xb7482436UL, 0x741b091dUL, + 0x352a1204UL, 0xf2bc534bUL, 0xb38d4852UL, 0x70de6579UL, 0x31ef7e60UL, + 0xfef3e6e7UL, 0xbfc2fdfeUL, 0x7c91d0d5UL, 0x3da0cbccUL, 0xfa368a83UL, + 0xbb07919aUL, 0x7854bcb1UL, 0x3965a7a8UL, 0x4b98833bUL, 0x0aa99822UL, + 0xc9fab509UL, 0x88cbae10UL, 0x4f5def5fUL, 0x0e6cf446UL, 0xcd3fd96dUL, + 0x8c0ec274UL, 0x43125af3UL, 0x022341eaUL, 0xc1706cc1UL, 0x804177d8UL, + 0x47d73697UL, 0x06e62d8eUL, 0xc5b500a5UL, 0x84841bbcUL, 0x1a8a4171UL, + 0x5bbb5a68UL, 0x98e87743UL, 0xd9d96c5aUL, 0x1e4f2d15UL, 0x5f7e360cUL, + 0x9c2d1b27UL, 0xdd1c003eUL, 0x120098b9UL, 0x533183a0UL, 0x9062ae8bUL, + 0xd153b592UL, 0x16c5f4ddUL, 0x57f4efc4UL, 0x94a7c2efUL, 0xd596d9f6UL, + 0xe9bc07aeUL, 0xa88d1cb7UL, 0x6bde319cUL, 0x2aef2a85UL, 0xed796bcaUL, + 0xac4870d3UL, 0x6f1b5df8UL, 0x2e2a46e1UL, 0xe136de66UL, 0xa007c57fUL, + 0x6354e854UL, 0x2265f34dUL, 0xe5f3b202UL, 0xa4c2a91bUL, 0x67918430UL, + 0x26a09f29UL, 0xb8aec5e4UL, 0xf99fdefdUL, 0x3accf3d6UL, 0x7bfde8cfUL, + 0xbc6ba980UL, 0xfd5ab299UL, 0x3e099fb2UL, 0x7f3884abUL, 0xb0241c2cUL, + 0xf1150735UL, 0x32462a1eUL, 0x73773107UL, 0xb4e17048UL, 0xf5d06b51UL, + 0x3683467aUL, 0x77b25d63UL, 0x4ed7facbUL, 0x0fe6e1d2UL, 0xccb5ccf9UL, + 0x8d84d7e0UL, 0x4a1296afUL, 0x0b238db6UL, 0xc870a09dUL, 0x8941bb84UL, + 0x465d2303UL, 0x076c381aUL, 0xc43f1531UL, 0x850e0e28UL, 0x42984f67UL, + 0x03a9547eUL, 0xc0fa7955UL, 0x81cb624cUL, 0x1fc53881UL, 0x5ef42398UL, + 0x9da70eb3UL, 0xdc9615aaUL, 0x1b0054e5UL, 0x5a314ffcUL, 0x996262d7UL, + 0xd85379ceUL, 0x174fe149UL, 0x567efa50UL, 0x952dd77bUL, 0xd41ccc62UL, + 0x138a8d2dUL, 0x52bb9634UL, 0x91e8bb1fUL, 0xd0d9a006UL, 0xecf37e5eUL, + 0xadc26547UL, 0x6e91486cUL, 0x2fa05375UL, 0xe836123aUL, 0xa9070923UL, + 0x6a542408UL, 0x2b653f11UL, 0xe479a796UL, 0xa548bc8fUL, 0x661b91a4UL, + 0x272a8abdUL, 0xe0bccbf2UL, 0xa18dd0ebUL, 0x62defdc0UL, 0x23efe6d9UL, + 0xbde1bc14UL, 0xfcd0a70dUL, 0x3f838a26UL, 0x7eb2913fUL, 0xb924d070UL, + 0xf815cb69UL, 0x3b46e642UL, 0x7a77fd5bUL, 0xb56b65dcUL, 0xf45a7ec5UL, + 0x370953eeUL, 0x763848f7UL, 0xb1ae09b8UL, 0xf09f12a1UL, 0x33cc3f8aUL, + 0x72fd2493UL + }, + { + 0x00000000UL, 0x376ac201UL, 0x6ed48403UL, 0x59be4602UL, 0xdca80907UL, + 0xebc2cb06UL, 0xb27c8d04UL, 0x85164f05UL, 0xb851130eUL, 0x8f3bd10fUL, + 0xd685970dUL, 0xe1ef550cUL, 0x64f91a09UL, 0x5393d808UL, 0x0a2d9e0aUL, + 0x3d475c0bUL, 0x70a3261cUL, 0x47c9e41dUL, 0x1e77a21fUL, 0x291d601eUL, + 0xac0b2f1bUL, 0x9b61ed1aUL, 0xc2dfab18UL, 0xf5b56919UL, 0xc8f23512UL, + 0xff98f713UL, 0xa626b111UL, 0x914c7310UL, 0x145a3c15UL, 0x2330fe14UL, + 0x7a8eb816UL, 0x4de47a17UL, 0xe0464d38UL, 0xd72c8f39UL, 0x8e92c93bUL, + 0xb9f80b3aUL, 0x3cee443fUL, 0x0b84863eUL, 0x523ac03cUL, 0x6550023dUL, + 0x58175e36UL, 0x6f7d9c37UL, 0x36c3da35UL, 0x01a91834UL, 0x84bf5731UL, + 0xb3d59530UL, 0xea6bd332UL, 0xdd011133UL, 0x90e56b24UL, 0xa78fa925UL, + 0xfe31ef27UL, 0xc95b2d26UL, 0x4c4d6223UL, 0x7b27a022UL, 0x2299e620UL, + 0x15f32421UL, 0x28b4782aUL, 0x1fdeba2bUL, 0x4660fc29UL, 0x710a3e28UL, + 0xf41c712dUL, 0xc376b32cUL, 0x9ac8f52eUL, 0xada2372fUL, 0xc08d9a70UL, + 0xf7e75871UL, 0xae591e73UL, 0x9933dc72UL, 0x1c259377UL, 0x2b4f5176UL, + 0x72f11774UL, 0x459bd575UL, 0x78dc897eUL, 0x4fb64b7fUL, 0x16080d7dUL, + 0x2162cf7cUL, 0xa4748079UL, 0x931e4278UL, 0xcaa0047aUL, 0xfdcac67bUL, + 0xb02ebc6cUL, 0x87447e6dUL, 0xdefa386fUL, 0xe990fa6eUL, 0x6c86b56bUL, + 0x5bec776aUL, 0x02523168UL, 0x3538f369UL, 0x087faf62UL, 0x3f156d63UL, + 0x66ab2b61UL, 0x51c1e960UL, 0xd4d7a665UL, 0xe3bd6464UL, 0xba032266UL, + 0x8d69e067UL, 0x20cbd748UL, 0x17a11549UL, 0x4e1f534bUL, 0x7975914aUL, + 0xfc63de4fUL, 0xcb091c4eUL, 0x92b75a4cUL, 0xa5dd984dUL, 0x989ac446UL, + 0xaff00647UL, 0xf64e4045UL, 0xc1248244UL, 0x4432cd41UL, 0x73580f40UL, + 0x2ae64942UL, 0x1d8c8b43UL, 0x5068f154UL, 0x67023355UL, 0x3ebc7557UL, + 0x09d6b756UL, 0x8cc0f853UL, 0xbbaa3a52UL, 0xe2147c50UL, 0xd57ebe51UL, + 0xe839e25aUL, 0xdf53205bUL, 0x86ed6659UL, 0xb187a458UL, 0x3491eb5dUL, + 0x03fb295cUL, 0x5a456f5eUL, 0x6d2fad5fUL, 0x801b35e1UL, 0xb771f7e0UL, + 0xeecfb1e2UL, 0xd9a573e3UL, 0x5cb33ce6UL, 0x6bd9fee7UL, 0x3267b8e5UL, + 0x050d7ae4UL, 0x384a26efUL, 0x0f20e4eeUL, 0x569ea2ecUL, 0x61f460edUL, + 0xe4e22fe8UL, 0xd388ede9UL, 0x8a36abebUL, 0xbd5c69eaUL, 0xf0b813fdUL, + 0xc7d2d1fcUL, 0x9e6c97feUL, 0xa90655ffUL, 0x2c101afaUL, 0x1b7ad8fbUL, + 0x42c49ef9UL, 0x75ae5cf8UL, 0x48e900f3UL, 0x7f83c2f2UL, 0x263d84f0UL, + 0x115746f1UL, 0x944109f4UL, 0xa32bcbf5UL, 0xfa958df7UL, 0xcdff4ff6UL, + 0x605d78d9UL, 0x5737bad8UL, 0x0e89fcdaUL, 0x39e33edbUL, 0xbcf571deUL, + 0x8b9fb3dfUL, 0xd221f5ddUL, 0xe54b37dcUL, 0xd80c6bd7UL, 0xef66a9d6UL, + 0xb6d8efd4UL, 0x81b22dd5UL, 0x04a462d0UL, 0x33cea0d1UL, 0x6a70e6d3UL, + 0x5d1a24d2UL, 0x10fe5ec5UL, 0x27949cc4UL, 0x7e2adac6UL, 0x494018c7UL, + 0xcc5657c2UL, 0xfb3c95c3UL, 0xa282d3c1UL, 0x95e811c0UL, 0xa8af4dcbUL, + 0x9fc58fcaUL, 0xc67bc9c8UL, 0xf1110bc9UL, 0x740744ccUL, 0x436d86cdUL, + 0x1ad3c0cfUL, 0x2db902ceUL, 0x4096af91UL, 0x77fc6d90UL, 0x2e422b92UL, + 0x1928e993UL, 0x9c3ea696UL, 0xab546497UL, 0xf2ea2295UL, 0xc580e094UL, + 0xf8c7bc9fUL, 0xcfad7e9eUL, 0x9613389cUL, 0xa179fa9dUL, 0x246fb598UL, + 0x13057799UL, 0x4abb319bUL, 0x7dd1f39aUL, 0x3035898dUL, 0x075f4b8cUL, + 0x5ee10d8eUL, 0x698bcf8fUL, 0xec9d808aUL, 0xdbf7428bUL, 0x82490489UL, + 0xb523c688UL, 0x88649a83UL, 0xbf0e5882UL, 0xe6b01e80UL, 0xd1dadc81UL, + 0x54cc9384UL, 0x63a65185UL, 0x3a181787UL, 0x0d72d586UL, 0xa0d0e2a9UL, + 0x97ba20a8UL, 0xce0466aaUL, 0xf96ea4abUL, 0x7c78ebaeUL, 0x4b1229afUL, + 0x12ac6fadUL, 0x25c6adacUL, 0x1881f1a7UL, 0x2feb33a6UL, 0x765575a4UL, + 0x413fb7a5UL, 0xc429f8a0UL, 0xf3433aa1UL, 0xaafd7ca3UL, 0x9d97bea2UL, + 0xd073c4b5UL, 0xe71906b4UL, 0xbea740b6UL, 0x89cd82b7UL, 0x0cdbcdb2UL, + 0x3bb10fb3UL, 0x620f49b1UL, 0x55658bb0UL, 0x6822d7bbUL, 0x5f4815baUL, + 0x06f653b8UL, 0x319c91b9UL, 0xb48adebcUL, 0x83e01cbdUL, 0xda5e5abfUL, + 0xed3498beUL + }, + { + 0x00000000UL, 0x6567bcb8UL, 0x8bc809aaUL, 0xeeafb512UL, 0x5797628fUL, + 0x32f0de37UL, 0xdc5f6b25UL, 0xb938d79dUL, 0xef28b4c5UL, 0x8a4f087dUL, + 0x64e0bd6fUL, 0x018701d7UL, 0xb8bfd64aUL, 0xddd86af2UL, 0x3377dfe0UL, + 0x56106358UL, 0x9f571950UL, 0xfa30a5e8UL, 0x149f10faUL, 0x71f8ac42UL, + 0xc8c07bdfUL, 0xada7c767UL, 0x43087275UL, 0x266fcecdUL, 0x707fad95UL, + 0x1518112dUL, 0xfbb7a43fUL, 0x9ed01887UL, 0x27e8cf1aUL, 0x428f73a2UL, + 0xac20c6b0UL, 0xc9477a08UL, 0x3eaf32a0UL, 0x5bc88e18UL, 0xb5673b0aUL, + 0xd00087b2UL, 0x6938502fUL, 0x0c5fec97UL, 0xe2f05985UL, 0x8797e53dUL, + 0xd1878665UL, 0xb4e03addUL, 0x5a4f8fcfUL, 0x3f283377UL, 0x8610e4eaUL, + 0xe3775852UL, 0x0dd8ed40UL, 0x68bf51f8UL, 0xa1f82bf0UL, 0xc49f9748UL, + 0x2a30225aUL, 0x4f579ee2UL, 0xf66f497fUL, 0x9308f5c7UL, 0x7da740d5UL, + 0x18c0fc6dUL, 0x4ed09f35UL, 0x2bb7238dUL, 0xc518969fUL, 0xa07f2a27UL, + 0x1947fdbaUL, 0x7c204102UL, 0x928ff410UL, 0xf7e848a8UL, 0x3d58149bUL, + 0x583fa823UL, 0xb6901d31UL, 0xd3f7a189UL, 0x6acf7614UL, 0x0fa8caacUL, + 0xe1077fbeUL, 0x8460c306UL, 0xd270a05eUL, 0xb7171ce6UL, 0x59b8a9f4UL, + 0x3cdf154cUL, 0x85e7c2d1UL, 0xe0807e69UL, 0x0e2fcb7bUL, 0x6b4877c3UL, + 0xa20f0dcbUL, 0xc768b173UL, 0x29c70461UL, 0x4ca0b8d9UL, 0xf5986f44UL, + 0x90ffd3fcUL, 0x7e5066eeUL, 0x1b37da56UL, 0x4d27b90eUL, 0x284005b6UL, + 0xc6efb0a4UL, 0xa3880c1cUL, 0x1ab0db81UL, 0x7fd76739UL, 0x9178d22bUL, + 0xf41f6e93UL, 0x03f7263bUL, 0x66909a83UL, 0x883f2f91UL, 0xed589329UL, + 0x546044b4UL, 0x3107f80cUL, 0xdfa84d1eUL, 0xbacff1a6UL, 0xecdf92feUL, + 0x89b82e46UL, 0x67179b54UL, 0x027027ecUL, 0xbb48f071UL, 0xde2f4cc9UL, + 0x3080f9dbUL, 0x55e74563UL, 0x9ca03f6bUL, 0xf9c783d3UL, 0x176836c1UL, + 0x720f8a79UL, 0xcb375de4UL, 0xae50e15cUL, 0x40ff544eUL, 0x2598e8f6UL, + 0x73888baeUL, 0x16ef3716UL, 0xf8408204UL, 0x9d273ebcUL, 0x241fe921UL, + 0x41785599UL, 0xafd7e08bUL, 0xcab05c33UL, 0x3bb659edUL, 0x5ed1e555UL, + 0xb07e5047UL, 0xd519ecffUL, 0x6c213b62UL, 0x094687daUL, 0xe7e932c8UL, + 0x828e8e70UL, 0xd49eed28UL, 0xb1f95190UL, 0x5f56e482UL, 0x3a31583aUL, + 0x83098fa7UL, 0xe66e331fUL, 0x08c1860dUL, 0x6da63ab5UL, 0xa4e140bdUL, + 0xc186fc05UL, 0x2f294917UL, 0x4a4ef5afUL, 0xf3762232UL, 0x96119e8aUL, + 0x78be2b98UL, 0x1dd99720UL, 0x4bc9f478UL, 0x2eae48c0UL, 0xc001fdd2UL, + 0xa566416aUL, 0x1c5e96f7UL, 0x79392a4fUL, 0x97969f5dUL, 0xf2f123e5UL, + 0x05196b4dUL, 0x607ed7f5UL, 0x8ed162e7UL, 0xebb6de5fUL, 0x528e09c2UL, + 0x37e9b57aUL, 0xd9460068UL, 0xbc21bcd0UL, 0xea31df88UL, 0x8f566330UL, + 0x61f9d622UL, 0x049e6a9aUL, 0xbda6bd07UL, 0xd8c101bfUL, 0x366eb4adUL, + 0x53090815UL, 0x9a4e721dUL, 0xff29cea5UL, 0x11867bb7UL, 0x74e1c70fUL, + 0xcdd91092UL, 0xa8beac2aUL, 0x46111938UL, 0x2376a580UL, 0x7566c6d8UL, + 0x10017a60UL, 0xfeaecf72UL, 0x9bc973caUL, 0x22f1a457UL, 0x479618efUL, + 0xa939adfdUL, 0xcc5e1145UL, 0x06ee4d76UL, 0x6389f1ceUL, 0x8d2644dcUL, + 0xe841f864UL, 0x51792ff9UL, 0x341e9341UL, 0xdab12653UL, 0xbfd69aebUL, + 0xe9c6f9b3UL, 0x8ca1450bUL, 0x620ef019UL, 0x07694ca1UL, 0xbe519b3cUL, + 0xdb362784UL, 0x35999296UL, 0x50fe2e2eUL, 0x99b95426UL, 0xfcdee89eUL, + 0x12715d8cUL, 0x7716e134UL, 0xce2e36a9UL, 0xab498a11UL, 0x45e63f03UL, + 0x208183bbUL, 0x7691e0e3UL, 0x13f65c5bUL, 0xfd59e949UL, 0x983e55f1UL, + 0x2106826cUL, 0x44613ed4UL, 0xaace8bc6UL, 0xcfa9377eUL, 0x38417fd6UL, + 0x5d26c36eUL, 0xb389767cUL, 0xd6eecac4UL, 0x6fd61d59UL, 0x0ab1a1e1UL, + 0xe41e14f3UL, 0x8179a84bUL, 0xd769cb13UL, 0xb20e77abUL, 0x5ca1c2b9UL, + 0x39c67e01UL, 0x80fea99cUL, 0xe5991524UL, 0x0b36a036UL, 0x6e511c8eUL, + 0xa7166686UL, 0xc271da3eUL, 0x2cde6f2cUL, 0x49b9d394UL, 0xf0810409UL, + 0x95e6b8b1UL, 0x7b490da3UL, 0x1e2eb11bUL, 0x483ed243UL, 0x2d596efbUL, + 0xc3f6dbe9UL, 0xa6916751UL, 0x1fa9b0ccUL, 0x7ace0c74UL, 0x9461b966UL, + 0xf10605deUL +#endif + } +}; diff --git a/ml/dlib/dlib/external/zlib/deflate.c b/ml/dlib/dlib/external/zlib/deflate.c new file mode 100644 index 000000000..696957705 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/deflate.c @@ -0,0 +1,1967 @@ +/* deflate.c -- compress data using the deflation algorithm + * Copyright (C) 1995-2013 Jean-loup Gailly and Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* + * ALGORITHM + * + * The "deflation" process depends on being able to identify portions + * of the input text which are identical to earlier input (within a + * sliding window trailing behind the input currently being processed). + * + * The most straightforward technique turns out to be the fastest for + * most input files: try all possible matches and select the longest. + * The key feature of this algorithm is that insertions into the string + * dictionary are very simple and thus fast, and deletions are avoided + * completely. Insertions are performed at each input character, whereas + * string matches are performed only when the previous match ends. So it + * is preferable to spend more time in matches to allow very fast string + * insertions and avoid deletions. The matching algorithm for small + * strings is inspired from that of Rabin & Karp. A brute force approach + * is used to find longer strings when a small match has been found. + * A similar algorithm is used in comic (by Jan-Mark Wams) and freeze + * (by Leonid Broukhis). + * A previous version of this file used a more sophisticated algorithm + * (by Fiala and Greene) which is guaranteed to run in linear amortized + * time, but has a larger average cost, uses more memory and is patented. + * However the F&G algorithm may be faster for some highly redundant + * files if the parameter max_chain_length (described below) is too large. + * + * ACKNOWLEDGEMENTS + * + * The idea of lazy evaluation of matches is due to Jan-Mark Wams, and + * I found it in 'freeze' written by Leonid Broukhis. + * Thanks to many people for bug reports and testing. + * + * REFERENCES + * + * Deutsch, L.P.,"DEFLATE Compressed Data Format Specification". + * Available in http://tools.ietf.org/html/rfc1951 + * + * A description of the Rabin and Karp algorithm is given in the book + * "Algorithms" by R. Sedgewick, Addison-Wesley, p252. + * + * Fiala,E.R., and Greene,D.H. + * Data Compression with Finite Windows, Comm.ACM, 32,4 (1989) 490-595 + * + */ + +/* @(#) $Id$ */ + +#include "deflate.h" + +const char deflate_copyright[] = + " deflate 1.2.8 Copyright 1995-2013 Jean-loup Gailly and Mark Adler "; +/* + If you use the zlib library in a product, an acknowledgment is welcome + in the documentation of your product. If for some reason you cannot + include such an acknowledgment, I would appreciate that you keep this + copyright string in the executable of your product. + */ + +/* =========================================================================== + * Function prototypes. + */ +typedef enum { + need_more, /* block not completed, need more input or more output */ + block_done, /* block flush performed */ + finish_started, /* finish started, need only more output at next deflate */ + finish_done /* finish done, accept no more input or output */ +} block_state; + +typedef block_state (*compress_func) OF((deflate_state *s, int flush)); +/* Compression function. Returns the block state after the call. */ + +local void fill_window OF((deflate_state *s)); +local block_state deflate_stored OF((deflate_state *s, int flush)); +local block_state deflate_fast OF((deflate_state *s, int flush)); +#ifndef FASTEST +local block_state deflate_slow OF((deflate_state *s, int flush)); +#endif +local block_state deflate_rle OF((deflate_state *s, int flush)); +local block_state deflate_huff OF((deflate_state *s, int flush)); +local void lm_init OF((deflate_state *s)); +local void putShortMSB OF((deflate_state *s, uInt b)); +local void flush_pending OF((z_streamp strm)); +local int read_buf OF((z_streamp strm, Bytef *buf, unsigned size)); +#ifdef ASMV + void match_init OF((void)); /* asm code initialization */ + uInt longest_match OF((deflate_state *s, IPos cur_match)); +#else +local uInt longest_match OF((deflate_state *s, IPos cur_match)); +#endif + +#ifdef DEBUG +local void check_match OF((deflate_state *s, IPos start, IPos match, + int length)); +#endif + +/* =========================================================================== + * Local data + */ + +#define NIL 0 +/* Tail of hash chains */ + +#ifndef TOO_FAR +# define TOO_FAR 4096 +#endif +/* Matches of length 3 are discarded if their distance exceeds TOO_FAR */ + +/* Values for max_lazy_match, good_match and max_chain_length, depending on + * the desired pack level (0..9). The values given below have been tuned to + * exclude worst case performance for pathological files. Better values may be + * found for specific files. + */ +typedef struct config_s { + ush good_length; /* reduce lazy search above this match length */ + ush max_lazy; /* do not perform lazy search above this match length */ + ush nice_length; /* quit search above this match length */ + ush max_chain; + compress_func func; +} config; + +#ifdef FASTEST +local const config configuration_table[2] = { +/* good lazy nice chain */ +/* 0 */ {0, 0, 0, 0, deflate_stored}, /* store only */ +/* 1 */ {4, 4, 8, 4, deflate_fast}}; /* max speed, no lazy matches */ +#else +local const config configuration_table[10] = { +/* good lazy nice chain */ +/* 0 */ {0, 0, 0, 0, deflate_stored}, /* store only */ +/* 1 */ {4, 4, 8, 4, deflate_fast}, /* max speed, no lazy matches */ +/* 2 */ {4, 5, 16, 8, deflate_fast}, +/* 3 */ {4, 6, 32, 32, deflate_fast}, + +/* 4 */ {4, 4, 16, 16, deflate_slow}, /* lazy matches */ +/* 5 */ {8, 16, 32, 32, deflate_slow}, +/* 6 */ {8, 16, 128, 128, deflate_slow}, +/* 7 */ {8, 32, 128, 256, deflate_slow}, +/* 8 */ {32, 128, 258, 1024, deflate_slow}, +/* 9 */ {32, 258, 258, 4096, deflate_slow}}; /* max compression */ +#endif + +/* Note: the deflate() code requires max_lazy >= MIN_MATCH and max_chain >= 4 + * For deflate_fast() (levels <= 3) good is ignored and lazy has a different + * meaning. + */ + +#define EQUAL 0 +/* result of memcmp for equal strings */ + +#ifndef NO_DUMMY_DECL +struct static_tree_desc_s {int dummy;}; /* for buggy compilers */ +#endif + +/* rank Z_BLOCK between Z_NO_FLUSH and Z_PARTIAL_FLUSH */ +#define RANK(f) (((f) << 1) - ((f) > 4 ? 9 : 0)) + +/* =========================================================================== + * Update a hash value with the given input byte + * IN assertion: all calls to to UPDATE_HASH are made with consecutive + * input characters, so that a running hash key can be computed from the + * previous key instead of complete recalculation each time. + */ +#define UPDATE_HASH(s,h,c) (h = (((h)<hash_shift) ^ (c)) & s->hash_mask) + + +/* =========================================================================== + * Insert string str in the dictionary and set match_head to the previous head + * of the hash chain (the most recent string with same hash key). Return + * the previous length of the hash chain. + * If this file is compiled with -DFASTEST, the compression level is forced + * to 1, and no hash chains are maintained. + * IN assertion: all calls to to INSERT_STRING are made with consecutive + * input characters and the first MIN_MATCH bytes of str are valid + * (except for the last MIN_MATCH-1 bytes of the input file). + */ +#ifdef FASTEST +#define INSERT_STRING(s, str, match_head) \ + (UPDATE_HASH(s, s->ins_h, s->window[(str) + (MIN_MATCH-1)]), \ + match_head = s->head[s->ins_h], \ + s->head[s->ins_h] = (Pos)(str)) +#else +#define INSERT_STRING(s, str, match_head) \ + (UPDATE_HASH(s, s->ins_h, s->window[(str) + (MIN_MATCH-1)]), \ + match_head = s->prev[(str) & s->w_mask] = s->head[s->ins_h], \ + s->head[s->ins_h] = (Pos)(str)) +#endif + +/* =========================================================================== + * Initialize the hash table (avoiding 64K overflow for 16 bit systems). + * prev[] will be initialized on the fly. + */ +#define CLEAR_HASH(s) \ + s->head[s->hash_size-1] = NIL; \ + zmemzero((Bytef *)s->head, (unsigned)(s->hash_size-1)*sizeof(*s->head)); + +/* ========================================================================= */ +int ZEXPORT deflateInit_(strm, level, version, stream_size) + z_streamp strm; + int level; + const char *version; + int stream_size; +{ + return deflateInit2_(strm, level, Z_DEFLATED, MAX_WBITS, DEF_MEM_LEVEL, + Z_DEFAULT_STRATEGY, version, stream_size); + /* To do: ignore strm->next_in if we use it as window */ +} + +/* ========================================================================= */ +int ZEXPORT deflateInit2_(strm, level, method, windowBits, memLevel, strategy, + version, stream_size) + z_streamp strm; + int level; + int method; + int windowBits; + int memLevel; + int strategy; + const char *version; + int stream_size; +{ + deflate_state *s; + int wrap = 1; + static const char my_version[] = ZLIB_VERSION; + + ushf *overlay; + /* We overlay pending_buf and d_buf+l_buf. This works since the average + * output size for (length,distance) codes is <= 24 bits. + */ + + if (version == Z_NULL || version[0] != my_version[0] || + stream_size != sizeof(z_stream)) { + return Z_VERSION_ERROR; + } + if (strm == Z_NULL) return Z_STREAM_ERROR; + + strm->msg = Z_NULL; + if (strm->zalloc == (alloc_func)0) { +#ifdef Z_SOLO + return Z_STREAM_ERROR; +#else + strm->zalloc = zcalloc; + strm->opaque = (voidpf)0; +#endif + } + if (strm->zfree == (free_func)0) +#ifdef Z_SOLO + return Z_STREAM_ERROR; +#else + strm->zfree = zcfree; +#endif + +#ifdef FASTEST + if (level != 0) level = 1; +#else + if (level == Z_DEFAULT_COMPRESSION) level = 6; +#endif + + if (windowBits < 0) { /* suppress zlib wrapper */ + wrap = 0; + windowBits = -windowBits; + } +#ifdef GZIP + else if (windowBits > 15) { + wrap = 2; /* write gzip wrapper instead */ + windowBits -= 16; + } +#endif + if (memLevel < 1 || memLevel > MAX_MEM_LEVEL || method != Z_DEFLATED || + windowBits < 8 || windowBits > 15 || level < 0 || level > 9 || + strategy < 0 || strategy > Z_FIXED) { + return Z_STREAM_ERROR; + } + if (windowBits == 8) windowBits = 9; /* until 256-byte window bug fixed */ + s = (deflate_state *) ZALLOC(strm, 1, sizeof(deflate_state)); + if (s == Z_NULL) return Z_MEM_ERROR; + strm->state = (struct internal_state FAR *)s; + s->strm = strm; + + s->wrap = wrap; + s->gzhead = Z_NULL; + s->w_bits = windowBits; + s->w_size = 1 << s->w_bits; + s->w_mask = s->w_size - 1; + + s->hash_bits = memLevel + 7; + s->hash_size = 1 << s->hash_bits; + s->hash_mask = s->hash_size - 1; + s->hash_shift = ((s->hash_bits+MIN_MATCH-1)/MIN_MATCH); + + s->window = (Bytef *) ZALLOC(strm, s->w_size, 2*sizeof(Byte)); + s->prev = (Posf *) ZALLOC(strm, s->w_size, sizeof(Pos)); + s->head = (Posf *) ZALLOC(strm, s->hash_size, sizeof(Pos)); + + s->high_water = 0; /* nothing written to s->window yet */ + + s->lit_bufsize = 1 << (memLevel + 6); /* 16K elements by default */ + + overlay = (ushf *) ZALLOC(strm, s->lit_bufsize, sizeof(ush)+2); + s->pending_buf = (uchf *) overlay; + s->pending_buf_size = (ulg)s->lit_bufsize * (sizeof(ush)+2L); + + if (s->window == Z_NULL || s->prev == Z_NULL || s->head == Z_NULL || + s->pending_buf == Z_NULL) { + s->status = FINISH_STATE; + strm->msg = ERR_MSG(Z_MEM_ERROR); + deflateEnd (strm); + return Z_MEM_ERROR; + } + s->d_buf = overlay + s->lit_bufsize/sizeof(ush); + s->l_buf = s->pending_buf + (1+sizeof(ush))*s->lit_bufsize; + + s->level = level; + s->strategy = strategy; + s->method = (Byte)method; + + return deflateReset(strm); +} + +/* ========================================================================= */ +int ZEXPORT deflateSetDictionary (strm, dictionary, dictLength) + z_streamp strm; + const Bytef *dictionary; + uInt dictLength; +{ + deflate_state *s; + uInt str, n; + int wrap; + unsigned avail; + z_const unsigned char *next; + + if (strm == Z_NULL || strm->state == Z_NULL || dictionary == Z_NULL) + return Z_STREAM_ERROR; + s = strm->state; + wrap = s->wrap; + if (wrap == 2 || (wrap == 1 && s->status != INIT_STATE) || s->lookahead) + return Z_STREAM_ERROR; + + /* when using zlib wrappers, compute Adler-32 for provided dictionary */ + if (wrap == 1) + strm->adler = adler32(strm->adler, dictionary, dictLength); + s->wrap = 0; /* avoid computing Adler-32 in read_buf */ + + /* if dictionary would fill window, just replace the history */ + if (dictLength >= s->w_size) { + if (wrap == 0) { /* already empty otherwise */ + CLEAR_HASH(s); + s->strstart = 0; + s->block_start = 0L; + s->insert = 0; + } + dictionary += dictLength - s->w_size; /* use the tail */ + dictLength = s->w_size; + } + + /* insert dictionary into window and hash */ + avail = strm->avail_in; + next = strm->next_in; + strm->avail_in = dictLength; + strm->next_in = (z_const Bytef *)dictionary; + fill_window(s); + while (s->lookahead >= MIN_MATCH) { + str = s->strstart; + n = s->lookahead - (MIN_MATCH-1); + do { + UPDATE_HASH(s, s->ins_h, s->window[str + MIN_MATCH-1]); +#ifndef FASTEST + s->prev[str & s->w_mask] = s->head[s->ins_h]; +#endif + s->head[s->ins_h] = (Pos)str; + str++; + } while (--n); + s->strstart = str; + s->lookahead = MIN_MATCH-1; + fill_window(s); + } + s->strstart += s->lookahead; + s->block_start = (long)s->strstart; + s->insert = s->lookahead; + s->lookahead = 0; + s->match_length = s->prev_length = MIN_MATCH-1; + s->match_available = 0; + strm->next_in = next; + strm->avail_in = avail; + s->wrap = wrap; + return Z_OK; +} + +/* ========================================================================= */ +int ZEXPORT deflateResetKeep (strm) + z_streamp strm; +{ + deflate_state *s; + + if (strm == Z_NULL || strm->state == Z_NULL || + strm->zalloc == (alloc_func)0 || strm->zfree == (free_func)0) { + return Z_STREAM_ERROR; + } + + strm->total_in = strm->total_out = 0; + strm->msg = Z_NULL; /* use zfree if we ever allocate msg dynamically */ + strm->data_type = Z_UNKNOWN; + + s = (deflate_state *)strm->state; + s->pending = 0; + s->pending_out = s->pending_buf; + + if (s->wrap < 0) { + s->wrap = -s->wrap; /* was made negative by deflate(..., Z_FINISH); */ + } + s->status = s->wrap ? INIT_STATE : BUSY_STATE; + strm->adler = +#ifdef GZIP + s->wrap == 2 ? crc32(0L, Z_NULL, 0) : +#endif + adler32(0L, Z_NULL, 0); + s->last_flush = Z_NO_FLUSH; + + _tr_init(s); + + return Z_OK; +} + +/* ========================================================================= */ +int ZEXPORT deflateReset (strm) + z_streamp strm; +{ + int ret; + + ret = deflateResetKeep(strm); + if (ret == Z_OK) + lm_init(strm->state); + return ret; +} + +/* ========================================================================= */ +int ZEXPORT deflateSetHeader (strm, head) + z_streamp strm; + gz_headerp head; +{ + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + if (strm->state->wrap != 2) return Z_STREAM_ERROR; + strm->state->gzhead = head; + return Z_OK; +} + +/* ========================================================================= */ +int ZEXPORT deflatePending (strm, pending, bits) + unsigned *pending; + int *bits; + z_streamp strm; +{ + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + if (pending != Z_NULL) + *pending = strm->state->pending; + if (bits != Z_NULL) + *bits = strm->state->bi_valid; + return Z_OK; +} + +/* ========================================================================= */ +int ZEXPORT deflatePrime (strm, bits, value) + z_streamp strm; + int bits; + int value; +{ + deflate_state *s; + int put; + + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + s = strm->state; + if ((Bytef *)(s->d_buf) < s->pending_out + ((Buf_size + 7) >> 3)) + return Z_BUF_ERROR; + do { + put = Buf_size - s->bi_valid; + if (put > bits) + put = bits; + s->bi_buf |= (ush)((value & ((1 << put) - 1)) << s->bi_valid); + s->bi_valid += put; + _tr_flush_bits(s); + value >>= put; + bits -= put; + } while (bits); + return Z_OK; +} + +/* ========================================================================= */ +int ZEXPORT deflateParams(strm, level, strategy) + z_streamp strm; + int level; + int strategy; +{ + deflate_state *s; + compress_func func; + int err = Z_OK; + + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + s = strm->state; + +#ifdef FASTEST + if (level != 0) level = 1; +#else + if (level == Z_DEFAULT_COMPRESSION) level = 6; +#endif + if (level < 0 || level > 9 || strategy < 0 || strategy > Z_FIXED) { + return Z_STREAM_ERROR; + } + func = configuration_table[s->level].func; + + if ((strategy != s->strategy || func != configuration_table[level].func) && + strm->total_in != 0) { + /* Flush the last buffer: */ + err = deflate(strm, Z_BLOCK); + if (err == Z_BUF_ERROR && s->pending == 0) + err = Z_OK; + } + if (s->level != level) { + s->level = level; + s->max_lazy_match = configuration_table[level].max_lazy; + s->good_match = configuration_table[level].good_length; + s->nice_match = configuration_table[level].nice_length; + s->max_chain_length = configuration_table[level].max_chain; + } + s->strategy = strategy; + return err; +} + +/* ========================================================================= */ +int ZEXPORT deflateTune(strm, good_length, max_lazy, nice_length, max_chain) + z_streamp strm; + int good_length; + int max_lazy; + int nice_length; + int max_chain; +{ + deflate_state *s; + + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + s = strm->state; + s->good_match = good_length; + s->max_lazy_match = max_lazy; + s->nice_match = nice_length; + s->max_chain_length = max_chain; + return Z_OK; +} + +/* ========================================================================= + * For the default windowBits of 15 and memLevel of 8, this function returns + * a close to exact, as well as small, upper bound on the compressed size. + * They are coded as constants here for a reason--if the #define's are + * changed, then this function needs to be changed as well. The return + * value for 15 and 8 only works for those exact settings. + * + * For any setting other than those defaults for windowBits and memLevel, + * the value returned is a conservative worst case for the maximum expansion + * resulting from using fixed blocks instead of stored blocks, which deflate + * can emit on compressed data for some combinations of the parameters. + * + * This function could be more sophisticated to provide closer upper bounds for + * every combination of windowBits and memLevel. But even the conservative + * upper bound of about 14% expansion does not seem onerous for output buffer + * allocation. + */ +uLong ZEXPORT deflateBound(strm, sourceLen) + z_streamp strm; + uLong sourceLen; +{ + deflate_state *s; + uLong complen, wraplen; + Bytef *str; + + /* conservative upper bound for compressed data */ + complen = sourceLen + + ((sourceLen + 7) >> 3) + ((sourceLen + 63) >> 6) + 5; + + /* if can't get parameters, return conservative bound plus zlib wrapper */ + if (strm == Z_NULL || strm->state == Z_NULL) + return complen + 6; + + /* compute wrapper length */ + s = strm->state; + switch (s->wrap) { + case 0: /* raw deflate */ + wraplen = 0; + break; + case 1: /* zlib wrapper */ + wraplen = 6 + (s->strstart ? 4 : 0); + break; + case 2: /* gzip wrapper */ + wraplen = 18; + if (s->gzhead != Z_NULL) { /* user-supplied gzip header */ + if (s->gzhead->extra != Z_NULL) + wraplen += 2 + s->gzhead->extra_len; + str = s->gzhead->name; + if (str != Z_NULL) + do { + wraplen++; + } while (*str++); + str = s->gzhead->comment; + if (str != Z_NULL) + do { + wraplen++; + } while (*str++); + if (s->gzhead->hcrc) + wraplen += 2; + } + break; + default: /* for compiler happiness */ + wraplen = 6; + } + + /* if not default parameters, return conservative bound */ + if (s->w_bits != 15 || s->hash_bits != 8 + 7) + return complen + wraplen; + + /* default settings: return tight bound for that case */ + return sourceLen + (sourceLen >> 12) + (sourceLen >> 14) + + (sourceLen >> 25) + 13 - 6 + wraplen; +} + +/* ========================================================================= + * Put a short in the pending buffer. The 16-bit value is put in MSB order. + * IN assertion: the stream state is correct and there is enough room in + * pending_buf. + */ +local void putShortMSB (s, b) + deflate_state *s; + uInt b; +{ + put_byte(s, (Byte)(b >> 8)); + put_byte(s, (Byte)(b & 0xff)); +} + +/* ========================================================================= + * Flush as much pending output as possible. All deflate() output goes + * through this function so some applications may wish to modify it + * to avoid allocating a large strm->next_out buffer and copying into it. + * (See also read_buf()). + */ +local void flush_pending(strm) + z_streamp strm; +{ + unsigned len; + deflate_state *s = strm->state; + + _tr_flush_bits(s); + len = s->pending; + if (len > strm->avail_out) len = strm->avail_out; + if (len == 0) return; + + zmemcpy(strm->next_out, s->pending_out, len); + strm->next_out += len; + s->pending_out += len; + strm->total_out += len; + strm->avail_out -= len; + s->pending -= len; + if (s->pending == 0) { + s->pending_out = s->pending_buf; + } +} + +/* ========================================================================= */ +int ZEXPORT deflate (strm, flush) + z_streamp strm; + int flush; +{ + int old_flush; /* value of flush param for previous deflate call */ + deflate_state *s; + + if (strm == Z_NULL || strm->state == Z_NULL || + flush > Z_BLOCK || flush < 0) { + return Z_STREAM_ERROR; + } + s = strm->state; + + if (strm->next_out == Z_NULL || + (strm->next_in == Z_NULL && strm->avail_in != 0) || + (s->status == FINISH_STATE && flush != Z_FINISH)) { + ERR_RETURN(strm, Z_STREAM_ERROR); + } + if (strm->avail_out == 0) ERR_RETURN(strm, Z_BUF_ERROR); + + s->strm = strm; /* just in case */ + old_flush = s->last_flush; + s->last_flush = flush; + + /* Write the header */ + if (s->status == INIT_STATE) { +#ifdef GZIP + if (s->wrap == 2) { + strm->adler = crc32(0L, Z_NULL, 0); + put_byte(s, 31); + put_byte(s, 139); + put_byte(s, 8); + if (s->gzhead == Z_NULL) { + put_byte(s, 0); + put_byte(s, 0); + put_byte(s, 0); + put_byte(s, 0); + put_byte(s, 0); + put_byte(s, s->level == 9 ? 2 : + (s->strategy >= Z_HUFFMAN_ONLY || s->level < 2 ? + 4 : 0)); + put_byte(s, OS_CODE); + s->status = BUSY_STATE; + } + else { + put_byte(s, (s->gzhead->text ? 1 : 0) + + (s->gzhead->hcrc ? 2 : 0) + + (s->gzhead->extra == Z_NULL ? 0 : 4) + + (s->gzhead->name == Z_NULL ? 0 : 8) + + (s->gzhead->comment == Z_NULL ? 0 : 16) + ); + put_byte(s, (Byte)(s->gzhead->time & 0xff)); + put_byte(s, (Byte)((s->gzhead->time >> 8) & 0xff)); + put_byte(s, (Byte)((s->gzhead->time >> 16) & 0xff)); + put_byte(s, (Byte)((s->gzhead->time >> 24) & 0xff)); + put_byte(s, s->level == 9 ? 2 : + (s->strategy >= Z_HUFFMAN_ONLY || s->level < 2 ? + 4 : 0)); + put_byte(s, s->gzhead->os & 0xff); + if (s->gzhead->extra != Z_NULL) { + put_byte(s, s->gzhead->extra_len & 0xff); + put_byte(s, (s->gzhead->extra_len >> 8) & 0xff); + } + if (s->gzhead->hcrc) + strm->adler = crc32(strm->adler, s->pending_buf, + s->pending); + s->gzindex = 0; + s->status = EXTRA_STATE; + } + } + else +#endif + { + uInt header = (Z_DEFLATED + ((s->w_bits-8)<<4)) << 8; + uInt level_flags; + + if (s->strategy >= Z_HUFFMAN_ONLY || s->level < 2) + level_flags = 0; + else if (s->level < 6) + level_flags = 1; + else if (s->level == 6) + level_flags = 2; + else + level_flags = 3; + header |= (level_flags << 6); + if (s->strstart != 0) header |= PRESET_DICT; + header += 31 - (header % 31); + + s->status = BUSY_STATE; + putShortMSB(s, header); + + /* Save the adler32 of the preset dictionary: */ + if (s->strstart != 0) { + putShortMSB(s, (uInt)(strm->adler >> 16)); + putShortMSB(s, (uInt)(strm->adler & 0xffff)); + } + strm->adler = adler32(0L, Z_NULL, 0); + } + } +#ifdef GZIP + if (s->status == EXTRA_STATE) { + if (s->gzhead->extra != Z_NULL) { + uInt beg = s->pending; /* start of bytes to update crc */ + + while (s->gzindex < (s->gzhead->extra_len & 0xffff)) { + if (s->pending == s->pending_buf_size) { + if (s->gzhead->hcrc && s->pending > beg) + strm->adler = crc32(strm->adler, s->pending_buf + beg, + s->pending - beg); + flush_pending(strm); + beg = s->pending; + if (s->pending == s->pending_buf_size) + break; + } + put_byte(s, s->gzhead->extra[s->gzindex]); + s->gzindex++; + } + if (s->gzhead->hcrc && s->pending > beg) + strm->adler = crc32(strm->adler, s->pending_buf + beg, + s->pending - beg); + if (s->gzindex == s->gzhead->extra_len) { + s->gzindex = 0; + s->status = NAME_STATE; + } + } + else + s->status = NAME_STATE; + } + if (s->status == NAME_STATE) { + if (s->gzhead->name != Z_NULL) { + uInt beg = s->pending; /* start of bytes to update crc */ + int val; + + do { + if (s->pending == s->pending_buf_size) { + if (s->gzhead->hcrc && s->pending > beg) + strm->adler = crc32(strm->adler, s->pending_buf + beg, + s->pending - beg); + flush_pending(strm); + beg = s->pending; + if (s->pending == s->pending_buf_size) { + val = 1; + break; + } + } + val = s->gzhead->name[s->gzindex++]; + put_byte(s, val); + } while (val != 0); + if (s->gzhead->hcrc && s->pending > beg) + strm->adler = crc32(strm->adler, s->pending_buf + beg, + s->pending - beg); + if (val == 0) { + s->gzindex = 0; + s->status = COMMENT_STATE; + } + } + else + s->status = COMMENT_STATE; + } + if (s->status == COMMENT_STATE) { + if (s->gzhead->comment != Z_NULL) { + uInt beg = s->pending; /* start of bytes to update crc */ + int val; + + do { + if (s->pending == s->pending_buf_size) { + if (s->gzhead->hcrc && s->pending > beg) + strm->adler = crc32(strm->adler, s->pending_buf + beg, + s->pending - beg); + flush_pending(strm); + beg = s->pending; + if (s->pending == s->pending_buf_size) { + val = 1; + break; + } + } + val = s->gzhead->comment[s->gzindex++]; + put_byte(s, val); + } while (val != 0); + if (s->gzhead->hcrc && s->pending > beg) + strm->adler = crc32(strm->adler, s->pending_buf + beg, + s->pending - beg); + if (val == 0) + s->status = HCRC_STATE; + } + else + s->status = HCRC_STATE; + } + if (s->status == HCRC_STATE) { + if (s->gzhead->hcrc) { + if (s->pending + 2 > s->pending_buf_size) + flush_pending(strm); + if (s->pending + 2 <= s->pending_buf_size) { + put_byte(s, (Byte)(strm->adler & 0xff)); + put_byte(s, (Byte)((strm->adler >> 8) & 0xff)); + strm->adler = crc32(0L, Z_NULL, 0); + s->status = BUSY_STATE; + } + } + else + s->status = BUSY_STATE; + } +#endif + + /* Flush as much pending output as possible */ + if (s->pending != 0) { + flush_pending(strm); + if (strm->avail_out == 0) { + /* Since avail_out is 0, deflate will be called again with + * more output space, but possibly with both pending and + * avail_in equal to zero. There won't be anything to do, + * but this is not an error situation so make sure we + * return OK instead of BUF_ERROR at next call of deflate: + */ + s->last_flush = -1; + return Z_OK; + } + + /* Make sure there is something to do and avoid duplicate consecutive + * flushes. For repeated and useless calls with Z_FINISH, we keep + * returning Z_STREAM_END instead of Z_BUF_ERROR. + */ + } else if (strm->avail_in == 0 && RANK(flush) <= RANK(old_flush) && + flush != Z_FINISH) { + ERR_RETURN(strm, Z_BUF_ERROR); + } + + /* User must not provide more input after the first FINISH: */ + if (s->status == FINISH_STATE && strm->avail_in != 0) { + ERR_RETURN(strm, Z_BUF_ERROR); + } + + /* Start a new block or continue the current one. + */ + if (strm->avail_in != 0 || s->lookahead != 0 || + (flush != Z_NO_FLUSH && s->status != FINISH_STATE)) { + block_state bstate; + + bstate = s->strategy == Z_HUFFMAN_ONLY ? deflate_huff(s, flush) : + (s->strategy == Z_RLE ? deflate_rle(s, flush) : + (*(configuration_table[s->level].func))(s, flush)); + + if (bstate == finish_started || bstate == finish_done) { + s->status = FINISH_STATE; + } + if (bstate == need_more || bstate == finish_started) { + if (strm->avail_out == 0) { + s->last_flush = -1; /* avoid BUF_ERROR next call, see above */ + } + return Z_OK; + /* If flush != Z_NO_FLUSH && avail_out == 0, the next call + * of deflate should use the same flush parameter to make sure + * that the flush is complete. So we don't have to output an + * empty block here, this will be done at next call. This also + * ensures that for a very small output buffer, we emit at most + * one empty block. + */ + } + if (bstate == block_done) { + if (flush == Z_PARTIAL_FLUSH) { + _tr_align(s); + } else if (flush != Z_BLOCK) { /* FULL_FLUSH or SYNC_FLUSH */ + _tr_stored_block(s, (char*)0, 0L, 0); + /* For a full flush, this empty block will be recognized + * as a special marker by inflate_sync(). + */ + if (flush == Z_FULL_FLUSH) { + CLEAR_HASH(s); /* forget history */ + if (s->lookahead == 0) { + s->strstart = 0; + s->block_start = 0L; + s->insert = 0; + } + } + } + flush_pending(strm); + if (strm->avail_out == 0) { + s->last_flush = -1; /* avoid BUF_ERROR at next call, see above */ + return Z_OK; + } + } + } + Assert(strm->avail_out > 0, "bug2"); + + if (flush != Z_FINISH) return Z_OK; + if (s->wrap <= 0) return Z_STREAM_END; + + /* Write the trailer */ +#ifdef GZIP + if (s->wrap == 2) { + put_byte(s, (Byte)(strm->adler & 0xff)); + put_byte(s, (Byte)((strm->adler >> 8) & 0xff)); + put_byte(s, (Byte)((strm->adler >> 16) & 0xff)); + put_byte(s, (Byte)((strm->adler >> 24) & 0xff)); + put_byte(s, (Byte)(strm->total_in & 0xff)); + put_byte(s, (Byte)((strm->total_in >> 8) & 0xff)); + put_byte(s, (Byte)((strm->total_in >> 16) & 0xff)); + put_byte(s, (Byte)((strm->total_in >> 24) & 0xff)); + } + else +#endif + { + putShortMSB(s, (uInt)(strm->adler >> 16)); + putShortMSB(s, (uInt)(strm->adler & 0xffff)); + } + flush_pending(strm); + /* If avail_out is zero, the application will call deflate again + * to flush the rest. + */ + if (s->wrap > 0) s->wrap = -s->wrap; /* write the trailer only once! */ + return s->pending != 0 ? Z_OK : Z_STREAM_END; +} + +/* ========================================================================= */ +int ZEXPORT deflateEnd (strm) + z_streamp strm; +{ + int status; + + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + + status = strm->state->status; + if (status != INIT_STATE && + status != EXTRA_STATE && + status != NAME_STATE && + status != COMMENT_STATE && + status != HCRC_STATE && + status != BUSY_STATE && + status != FINISH_STATE) { + return Z_STREAM_ERROR; + } + + /* Deallocate in reverse order of allocations: */ + TRY_FREE(strm, strm->state->pending_buf); + TRY_FREE(strm, strm->state->head); + TRY_FREE(strm, strm->state->prev); + TRY_FREE(strm, strm->state->window); + + ZFREE(strm, strm->state); + strm->state = Z_NULL; + + return status == BUSY_STATE ? Z_DATA_ERROR : Z_OK; +} + +/* ========================================================================= + * Copy the source state to the destination state. + * To simplify the source, this is not supported for 16-bit MSDOS (which + * doesn't have enough memory anyway to duplicate compression states). + */ +int ZEXPORT deflateCopy (dest, source) + z_streamp dest; + z_streamp source; +{ +#ifdef MAXSEG_64K + return Z_STREAM_ERROR; +#else + deflate_state *ds; + deflate_state *ss; + ushf *overlay; + + + if (source == Z_NULL || dest == Z_NULL || source->state == Z_NULL) { + return Z_STREAM_ERROR; + } + + ss = source->state; + + zmemcpy((voidpf)dest, (voidpf)source, sizeof(z_stream)); + + ds = (deflate_state *) ZALLOC(dest, 1, sizeof(deflate_state)); + if (ds == Z_NULL) return Z_MEM_ERROR; + dest->state = (struct internal_state FAR *) ds; + zmemcpy((voidpf)ds, (voidpf)ss, sizeof(deflate_state)); + ds->strm = dest; + + ds->window = (Bytef *) ZALLOC(dest, ds->w_size, 2*sizeof(Byte)); + ds->prev = (Posf *) ZALLOC(dest, ds->w_size, sizeof(Pos)); + ds->head = (Posf *) ZALLOC(dest, ds->hash_size, sizeof(Pos)); + overlay = (ushf *) ZALLOC(dest, ds->lit_bufsize, sizeof(ush)+2); + ds->pending_buf = (uchf *) overlay; + + if (ds->window == Z_NULL || ds->prev == Z_NULL || ds->head == Z_NULL || + ds->pending_buf == Z_NULL) { + deflateEnd (dest); + return Z_MEM_ERROR; + } + /* following zmemcpy do not work for 16-bit MSDOS */ + zmemcpy(ds->window, ss->window, ds->w_size * 2 * sizeof(Byte)); + zmemcpy((voidpf)ds->prev, (voidpf)ss->prev, ds->w_size * sizeof(Pos)); + zmemcpy((voidpf)ds->head, (voidpf)ss->head, ds->hash_size * sizeof(Pos)); + zmemcpy(ds->pending_buf, ss->pending_buf, (uInt)ds->pending_buf_size); + + ds->pending_out = ds->pending_buf + (ss->pending_out - ss->pending_buf); + ds->d_buf = overlay + ds->lit_bufsize/sizeof(ush); + ds->l_buf = ds->pending_buf + (1+sizeof(ush))*ds->lit_bufsize; + + ds->l_desc.dyn_tree = ds->dyn_ltree; + ds->d_desc.dyn_tree = ds->dyn_dtree; + ds->bl_desc.dyn_tree = ds->bl_tree; + + return Z_OK; +#endif /* MAXSEG_64K */ +} + +/* =========================================================================== + * Read a new buffer from the current input stream, update the adler32 + * and total number of bytes read. All deflate() input goes through + * this function so some applications may wish to modify it to avoid + * allocating a large strm->next_in buffer and copying from it. + * (See also flush_pending()). + */ +local int read_buf(strm, buf, size) + z_streamp strm; + Bytef *buf; + unsigned size; +{ + unsigned len = strm->avail_in; + + if (len > size) len = size; + if (len == 0) return 0; + + strm->avail_in -= len; + + zmemcpy(buf, strm->next_in, len); + if (strm->state->wrap == 1) { + strm->adler = adler32(strm->adler, buf, len); + } +#ifdef GZIP + else if (strm->state->wrap == 2) { + strm->adler = crc32(strm->adler, buf, len); + } +#endif + strm->next_in += len; + strm->total_in += len; + + return (int)len; +} + +/* =========================================================================== + * Initialize the "longest match" routines for a new zlib stream + */ +local void lm_init (s) + deflate_state *s; +{ + s->window_size = (ulg)2L*s->w_size; + + CLEAR_HASH(s); + + /* Set the default configuration parameters: + */ + s->max_lazy_match = configuration_table[s->level].max_lazy; + s->good_match = configuration_table[s->level].good_length; + s->nice_match = configuration_table[s->level].nice_length; + s->max_chain_length = configuration_table[s->level].max_chain; + + s->strstart = 0; + s->block_start = 0L; + s->lookahead = 0; + s->insert = 0; + s->match_length = s->prev_length = MIN_MATCH-1; + s->match_available = 0; + s->ins_h = 0; +#ifndef FASTEST +#ifdef ASMV + match_init(); /* initialize the asm code */ +#endif +#endif +} + +#ifndef FASTEST +/* =========================================================================== + * Set match_start to the longest match starting at the given string and + * return its length. Matches shorter or equal to prev_length are discarded, + * in which case the result is equal to prev_length and match_start is + * garbage. + * IN assertions: cur_match is the head of the hash chain for the current + * string (strstart) and its distance is <= MAX_DIST, and prev_length >= 1 + * OUT assertion: the match length is not greater than s->lookahead. + */ +#ifndef ASMV +/* For 80x86 and 680x0, an optimized version will be provided in match.asm or + * match.S. The code will be functionally equivalent. + */ +local uInt longest_match(s, cur_match) + deflate_state *s; + IPos cur_match; /* current match */ +{ + unsigned chain_length = s->max_chain_length;/* max hash chain length */ + register Bytef *scan = s->window + s->strstart; /* current string */ + register Bytef *match; /* matched string */ + register int len; /* length of current match */ + int best_len = s->prev_length; /* best match length so far */ + int nice_match = s->nice_match; /* stop if match long enough */ + IPos limit = s->strstart > (IPos)MAX_DIST(s) ? + s->strstart - (IPos)MAX_DIST(s) : NIL; + /* Stop when cur_match becomes <= limit. To simplify the code, + * we prevent matches with the string of window index 0. + */ + Posf *prev = s->prev; + uInt wmask = s->w_mask; + +#ifdef UNALIGNED_OK + /* Compare two bytes at a time. Note: this is not always beneficial. + * Try with and without -DUNALIGNED_OK to check. + */ + register Bytef *strend = s->window + s->strstart + MAX_MATCH - 1; + register ush scan_start = *(ushf*)scan; + register ush scan_end = *(ushf*)(scan+best_len-1); +#else + register Bytef *strend = s->window + s->strstart + MAX_MATCH; + register Byte scan_end1 = scan[best_len-1]; + register Byte scan_end = scan[best_len]; +#endif + + /* The code is optimized for HASH_BITS >= 8 and MAX_MATCH-2 multiple of 16. + * It is easy to get rid of this optimization if necessary. + */ + Assert(s->hash_bits >= 8 && MAX_MATCH == 258, "Code too clever"); + + /* Do not waste too much time if we already have a good match: */ + if (s->prev_length >= s->good_match) { + chain_length >>= 2; + } + /* Do not look for matches beyond the end of the input. This is necessary + * to make deflate deterministic. + */ + if ((uInt)nice_match > s->lookahead) nice_match = s->lookahead; + + Assert((ulg)s->strstart <= s->window_size-MIN_LOOKAHEAD, "need lookahead"); + + do { + Assert(cur_match < s->strstart, "no future"); + match = s->window + cur_match; + + /* Skip to next match if the match length cannot increase + * or if the match length is less than 2. Note that the checks below + * for insufficient lookahead only occur occasionally for performance + * reasons. Therefore uninitialized memory will be accessed, and + * conditional jumps will be made that depend on those values. + * However the length of the match is limited to the lookahead, so + * the output of deflate is not affected by the uninitialized values. + */ +#if (defined(UNALIGNED_OK) && MAX_MATCH == 258) + /* This code assumes sizeof(unsigned short) == 2. Do not use + * UNALIGNED_OK if your compiler uses a different size. + */ + if (*(ushf*)(match+best_len-1) != scan_end || + *(ushf*)match != scan_start) continue; + + /* It is not necessary to compare scan[2] and match[2] since they are + * always equal when the other bytes match, given that the hash keys + * are equal and that HASH_BITS >= 8. Compare 2 bytes at a time at + * strstart+3, +5, ... up to strstart+257. We check for insufficient + * lookahead only every 4th comparison; the 128th check will be made + * at strstart+257. If MAX_MATCH-2 is not a multiple of 8, it is + * necessary to put more guard bytes at the end of the window, or + * to check more often for insufficient lookahead. + */ + Assert(scan[2] == match[2], "scan[2]?"); + scan++, match++; + do { + } while (*(ushf*)(scan+=2) == *(ushf*)(match+=2) && + *(ushf*)(scan+=2) == *(ushf*)(match+=2) && + *(ushf*)(scan+=2) == *(ushf*)(match+=2) && + *(ushf*)(scan+=2) == *(ushf*)(match+=2) && + scan < strend); + /* The funny "do {}" generates better code on most compilers */ + + /* Here, scan <= window+strstart+257 */ + Assert(scan <= s->window+(unsigned)(s->window_size-1), "wild scan"); + if (*scan == *match) scan++; + + len = (MAX_MATCH - 1) - (int)(strend-scan); + scan = strend - (MAX_MATCH-1); + +#else /* UNALIGNED_OK */ + + if (match[best_len] != scan_end || + match[best_len-1] != scan_end1 || + *match != *scan || + *++match != scan[1]) continue; + + /* The check at best_len-1 can be removed because it will be made + * again later. (This heuristic is not always a win.) + * It is not necessary to compare scan[2] and match[2] since they + * are always equal when the other bytes match, given that + * the hash keys are equal and that HASH_BITS >= 8. + */ + scan += 2, match++; + Assert(*scan == *match, "match[2]?"); + + /* We check for insufficient lookahead only every 8th comparison; + * the 256th check will be made at strstart+258. + */ + do { + } while (*++scan == *++match && *++scan == *++match && + *++scan == *++match && *++scan == *++match && + *++scan == *++match && *++scan == *++match && + *++scan == *++match && *++scan == *++match && + scan < strend); + + Assert(scan <= s->window+(unsigned)(s->window_size-1), "wild scan"); + + len = MAX_MATCH - (int)(strend - scan); + scan = strend - MAX_MATCH; + +#endif /* UNALIGNED_OK */ + + if (len > best_len) { + s->match_start = cur_match; + best_len = len; + if (len >= nice_match) break; +#ifdef UNALIGNED_OK + scan_end = *(ushf*)(scan+best_len-1); +#else + scan_end1 = scan[best_len-1]; + scan_end = scan[best_len]; +#endif + } + } while ((cur_match = prev[cur_match & wmask]) > limit + && --chain_length != 0); + + if ((uInt)best_len <= s->lookahead) return (uInt)best_len; + return s->lookahead; +} +#endif /* ASMV */ + +#else /* FASTEST */ + +/* --------------------------------------------------------------------------- + * Optimized version for FASTEST only + */ +local uInt longest_match(s, cur_match) + deflate_state *s; + IPos cur_match; /* current match */ +{ + register Bytef *scan = s->window + s->strstart; /* current string */ + register Bytef *match; /* matched string */ + register int len; /* length of current match */ + register Bytef *strend = s->window + s->strstart + MAX_MATCH; + + /* The code is optimized for HASH_BITS >= 8 and MAX_MATCH-2 multiple of 16. + * It is easy to get rid of this optimization if necessary. + */ + Assert(s->hash_bits >= 8 && MAX_MATCH == 258, "Code too clever"); + + Assert((ulg)s->strstart <= s->window_size-MIN_LOOKAHEAD, "need lookahead"); + + Assert(cur_match < s->strstart, "no future"); + + match = s->window + cur_match; + + /* Return failure if the match length is less than 2: + */ + if (match[0] != scan[0] || match[1] != scan[1]) return MIN_MATCH-1; + + /* The check at best_len-1 can be removed because it will be made + * again later. (This heuristic is not always a win.) + * It is not necessary to compare scan[2] and match[2] since they + * are always equal when the other bytes match, given that + * the hash keys are equal and that HASH_BITS >= 8. + */ + scan += 2, match += 2; + Assert(*scan == *match, "match[2]?"); + + /* We check for insufficient lookahead only every 8th comparison; + * the 256th check will be made at strstart+258. + */ + do { + } while (*++scan == *++match && *++scan == *++match && + *++scan == *++match && *++scan == *++match && + *++scan == *++match && *++scan == *++match && + *++scan == *++match && *++scan == *++match && + scan < strend); + + Assert(scan <= s->window+(unsigned)(s->window_size-1), "wild scan"); + + len = MAX_MATCH - (int)(strend - scan); + + if (len < MIN_MATCH) return MIN_MATCH - 1; + + s->match_start = cur_match; + return (uInt)len <= s->lookahead ? (uInt)len : s->lookahead; +} + +#endif /* FASTEST */ + +#ifdef DEBUG +/* =========================================================================== + * Check that the match at match_start is indeed a match. + */ +local void check_match(s, start, match, length) + deflate_state *s; + IPos start, match; + int length; +{ + /* check that the match is indeed a match */ + if (zmemcmp(s->window + match, + s->window + start, length) != EQUAL) { + fprintf(stderr, " start %u, match %u, length %d\n", + start, match, length); + do { + fprintf(stderr, "%c%c", s->window[match++], s->window[start++]); + } while (--length != 0); + z_error("invalid match"); + } + if (z_verbose > 1) { + fprintf(stderr,"\\[%d,%d]", start-match, length); + do { putc(s->window[start++], stderr); } while (--length != 0); + } +} +#else +# define check_match(s, start, match, length) +#endif /* DEBUG */ + +/* =========================================================================== + * Fill the window when the lookahead becomes insufficient. + * Updates strstart and lookahead. + * + * IN assertion: lookahead < MIN_LOOKAHEAD + * OUT assertions: strstart <= window_size-MIN_LOOKAHEAD + * At least one byte has been read, or avail_in == 0; reads are + * performed for at least two bytes (required for the zip translate_eol + * option -- not supported here). + */ +local void fill_window(s) + deflate_state *s; +{ + register unsigned n, m; + register Posf *p; + unsigned more; /* Amount of free space at the end of the window. */ + uInt wsize = s->w_size; + + Assert(s->lookahead < MIN_LOOKAHEAD, "already enough lookahead"); + + do { + more = (unsigned)(s->window_size -(ulg)s->lookahead -(ulg)s->strstart); + + /* Deal with !@#$% 64K limit: */ + if (sizeof(int) <= 2) { + if (more == 0 && s->strstart == 0 && s->lookahead == 0) { + more = wsize; + + } else if (more == (unsigned)(-1)) { + /* Very unlikely, but possible on 16 bit machine if + * strstart == 0 && lookahead == 1 (input done a byte at time) + */ + more--; + } + } + + /* If the window is almost full and there is insufficient lookahead, + * move the upper half to the lower one to make room in the upper half. + */ + if (s->strstart >= wsize+MAX_DIST(s)) { + + zmemcpy(s->window, s->window+wsize, (unsigned)wsize); + s->match_start -= wsize; + s->strstart -= wsize; /* we now have strstart >= MAX_DIST */ + s->block_start -= (long) wsize; + + /* Slide the hash table (could be avoided with 32 bit values + at the expense of memory usage). We slide even when level == 0 + to keep the hash table consistent if we switch back to level > 0 + later. (Using level 0 permanently is not an optimal usage of + zlib, so we don't care about this pathological case.) + */ + n = s->hash_size; + p = &s->head[n]; + do { + m = *--p; + *p = (Pos)(m >= wsize ? m-wsize : NIL); + } while (--n); + + n = wsize; +#ifndef FASTEST + p = &s->prev[n]; + do { + m = *--p; + *p = (Pos)(m >= wsize ? m-wsize : NIL); + /* If n is not on any hash chain, prev[n] is garbage but + * its value will never be used. + */ + } while (--n); +#endif + more += wsize; + } + if (s->strm->avail_in == 0) break; + + /* If there was no sliding: + * strstart <= WSIZE+MAX_DIST-1 && lookahead <= MIN_LOOKAHEAD - 1 && + * more == window_size - lookahead - strstart + * => more >= window_size - (MIN_LOOKAHEAD-1 + WSIZE + MAX_DIST-1) + * => more >= window_size - 2*WSIZE + 2 + * In the BIG_MEM or MMAP case (not yet supported), + * window_size == input_size + MIN_LOOKAHEAD && + * strstart + s->lookahead <= input_size => more >= MIN_LOOKAHEAD. + * Otherwise, window_size == 2*WSIZE so more >= 2. + * If there was sliding, more >= WSIZE. So in all cases, more >= 2. + */ + Assert(more >= 2, "more < 2"); + + n = read_buf(s->strm, s->window + s->strstart + s->lookahead, more); + s->lookahead += n; + + /* Initialize the hash value now that we have some input: */ + if (s->lookahead + s->insert >= MIN_MATCH) { + uInt str = s->strstart - s->insert; + s->ins_h = s->window[str]; + UPDATE_HASH(s, s->ins_h, s->window[str + 1]); +#if MIN_MATCH != 3 + Call UPDATE_HASH() MIN_MATCH-3 more times +#endif + while (s->insert) { + UPDATE_HASH(s, s->ins_h, s->window[str + MIN_MATCH-1]); +#ifndef FASTEST + s->prev[str & s->w_mask] = s->head[s->ins_h]; +#endif + s->head[s->ins_h] = (Pos)str; + str++; + s->insert--; + if (s->lookahead + s->insert < MIN_MATCH) + break; + } + } + /* If the whole input has less than MIN_MATCH bytes, ins_h is garbage, + * but this is not important since only literal bytes will be emitted. + */ + + } while (s->lookahead < MIN_LOOKAHEAD && s->strm->avail_in != 0); + + /* If the WIN_INIT bytes after the end of the current data have never been + * written, then zero those bytes in order to avoid memory check reports of + * the use of uninitialized (or uninitialised as Julian writes) bytes by + * the longest match routines. Update the high water mark for the next + * time through here. WIN_INIT is set to MAX_MATCH since the longest match + * routines allow scanning to strstart + MAX_MATCH, ignoring lookahead. + */ + if (s->high_water < s->window_size) { + ulg curr = s->strstart + (ulg)(s->lookahead); + ulg init; + + if (s->high_water < curr) { + /* Previous high water mark below current data -- zero WIN_INIT + * bytes or up to end of window, whichever is less. + */ + init = s->window_size - curr; + if (init > WIN_INIT) + init = WIN_INIT; + zmemzero(s->window + curr, (unsigned)init); + s->high_water = curr + init; + } + else if (s->high_water < (ulg)curr + WIN_INIT) { + /* High water mark at or above current data, but below current data + * plus WIN_INIT -- zero out to current data plus WIN_INIT, or up + * to end of window, whichever is less. + */ + init = (ulg)curr + WIN_INIT - s->high_water; + if (init > s->window_size - s->high_water) + init = s->window_size - s->high_water; + zmemzero(s->window + s->high_water, (unsigned)init); + s->high_water += init; + } + } + + Assert((ulg)s->strstart <= s->window_size - MIN_LOOKAHEAD, + "not enough room for search"); +} + +/* =========================================================================== + * Flush the current block, with given end-of-file flag. + * IN assertion: strstart is set to the end of the current match. + */ +#define FLUSH_BLOCK_ONLY(s, last) { \ + _tr_flush_block(s, (s->block_start >= 0L ? \ + (charf *)&s->window[(unsigned)s->block_start] : \ + (charf *)Z_NULL), \ + (ulg)((long)s->strstart - s->block_start), \ + (last)); \ + s->block_start = s->strstart; \ + flush_pending(s->strm); \ + Tracev((stderr,"[FLUSH]")); \ +} + +/* Same but force premature exit if necessary. */ +#define FLUSH_BLOCK(s, last) { \ + FLUSH_BLOCK_ONLY(s, last); \ + if (s->strm->avail_out == 0) return (last) ? finish_started : need_more; \ +} + +/* =========================================================================== + * Copy without compression as much as possible from the input stream, return + * the current block state. + * This function does not insert new strings in the dictionary since + * uncompressible data is probably not useful. This function is used + * only for the level=0 compression option. + * NOTE: this function should be optimized to avoid extra copying from + * window to pending_buf. + */ +local block_state deflate_stored(s, flush) + deflate_state *s; + int flush; +{ + /* Stored blocks are limited to 0xffff bytes, pending_buf is limited + * to pending_buf_size, and each stored block has a 5 byte header: + */ + ulg max_block_size = 0xffff; + ulg max_start; + + if (max_block_size > s->pending_buf_size - 5) { + max_block_size = s->pending_buf_size - 5; + } + + /* Copy as much as possible from input to output: */ + for (;;) { + /* Fill the window as much as possible: */ + if (s->lookahead <= 1) { + + Assert(s->strstart < s->w_size+MAX_DIST(s) || + s->block_start >= (long)s->w_size, "slide too late"); + + fill_window(s); + if (s->lookahead == 0 && flush == Z_NO_FLUSH) return need_more; + + if (s->lookahead == 0) break; /* flush the current block */ + } + Assert(s->block_start >= 0L, "block gone"); + + s->strstart += s->lookahead; + s->lookahead = 0; + + /* Emit a stored block if pending_buf will be full: */ + max_start = s->block_start + max_block_size; + if (s->strstart == 0 || (ulg)s->strstart >= max_start) { + /* strstart == 0 is possible when wraparound on 16-bit machine */ + s->lookahead = (uInt)(s->strstart - max_start); + s->strstart = (uInt)max_start; + FLUSH_BLOCK(s, 0); + } + /* Flush if we may have to slide, otherwise block_start may become + * negative and the data will be gone: + */ + if (s->strstart - (uInt)s->block_start >= MAX_DIST(s)) { + FLUSH_BLOCK(s, 0); + } + } + s->insert = 0; + if (flush == Z_FINISH) { + FLUSH_BLOCK(s, 1); + return finish_done; + } + if ((long)s->strstart > s->block_start) + FLUSH_BLOCK(s, 0); + return block_done; +} + +/* =========================================================================== + * Compress as much as possible from the input stream, return the current + * block state. + * This function does not perform lazy evaluation of matches and inserts + * new strings in the dictionary only for unmatched strings or for short + * matches. It is used only for the fast compression options. + */ +local block_state deflate_fast(s, flush) + deflate_state *s; + int flush; +{ + IPos hash_head; /* head of the hash chain */ + int bflush; /* set if current block must be flushed */ + + for (;;) { + /* Make sure that we always have enough lookahead, except + * at the end of the input file. We need MAX_MATCH bytes + * for the next match, plus MIN_MATCH bytes to insert the + * string following the next match. + */ + if (s->lookahead < MIN_LOOKAHEAD) { + fill_window(s); + if (s->lookahead < MIN_LOOKAHEAD && flush == Z_NO_FLUSH) { + return need_more; + } + if (s->lookahead == 0) break; /* flush the current block */ + } + + /* Insert the string window[strstart .. strstart+2] in the + * dictionary, and set hash_head to the head of the hash chain: + */ + hash_head = NIL; + if (s->lookahead >= MIN_MATCH) { + INSERT_STRING(s, s->strstart, hash_head); + } + + /* Find the longest match, discarding those <= prev_length. + * At this point we have always match_length < MIN_MATCH + */ + if (hash_head != NIL && s->strstart - hash_head <= MAX_DIST(s)) { + /* To simplify the code, we prevent matches with the string + * of window index 0 (in particular we have to avoid a match + * of the string with itself at the start of the input file). + */ + s->match_length = longest_match (s, hash_head); + /* longest_match() sets match_start */ + } + if (s->match_length >= MIN_MATCH) { + check_match(s, s->strstart, s->match_start, s->match_length); + + _tr_tally_dist(s, s->strstart - s->match_start, + s->match_length - MIN_MATCH, bflush); + + s->lookahead -= s->match_length; + + /* Insert new strings in the hash table only if the match length + * is not too large. This saves time but degrades compression. + */ +#ifndef FASTEST + if (s->match_length <= s->max_insert_length && + s->lookahead >= MIN_MATCH) { + s->match_length--; /* string at strstart already in table */ + do { + s->strstart++; + INSERT_STRING(s, s->strstart, hash_head); + /* strstart never exceeds WSIZE-MAX_MATCH, so there are + * always MIN_MATCH bytes ahead. + */ + } while (--s->match_length != 0); + s->strstart++; + } else +#endif + { + s->strstart += s->match_length; + s->match_length = 0; + s->ins_h = s->window[s->strstart]; + UPDATE_HASH(s, s->ins_h, s->window[s->strstart+1]); +#if MIN_MATCH != 3 + Call UPDATE_HASH() MIN_MATCH-3 more times +#endif + /* If lookahead < MIN_MATCH, ins_h is garbage, but it does not + * matter since it will be recomputed at next deflate call. + */ + } + } else { + /* No match, output a literal byte */ + Tracevv((stderr,"%c", s->window[s->strstart])); + _tr_tally_lit (s, s->window[s->strstart], bflush); + s->lookahead--; + s->strstart++; + } + if (bflush) FLUSH_BLOCK(s, 0); + } + s->insert = s->strstart < MIN_MATCH-1 ? s->strstart : MIN_MATCH-1; + if (flush == Z_FINISH) { + FLUSH_BLOCK(s, 1); + return finish_done; + } + if (s->last_lit) + FLUSH_BLOCK(s, 0); + return block_done; +} + +#ifndef FASTEST +/* =========================================================================== + * Same as above, but achieves better compression. We use a lazy + * evaluation for matches: a match is finally adopted only if there is + * no better match at the next window position. + */ +local block_state deflate_slow(s, flush) + deflate_state *s; + int flush; +{ + IPos hash_head; /* head of hash chain */ + int bflush; /* set if current block must be flushed */ + + /* Process the input block. */ + for (;;) { + /* Make sure that we always have enough lookahead, except + * at the end of the input file. We need MAX_MATCH bytes + * for the next match, plus MIN_MATCH bytes to insert the + * string following the next match. + */ + if (s->lookahead < MIN_LOOKAHEAD) { + fill_window(s); + if (s->lookahead < MIN_LOOKAHEAD && flush == Z_NO_FLUSH) { + return need_more; + } + if (s->lookahead == 0) break; /* flush the current block */ + } + + /* Insert the string window[strstart .. strstart+2] in the + * dictionary, and set hash_head to the head of the hash chain: + */ + hash_head = NIL; + if (s->lookahead >= MIN_MATCH) { + INSERT_STRING(s, s->strstart, hash_head); + } + + /* Find the longest match, discarding those <= prev_length. + */ + s->prev_length = s->match_length, s->prev_match = s->match_start; + s->match_length = MIN_MATCH-1; + + if (hash_head != NIL && s->prev_length < s->max_lazy_match && + s->strstart - hash_head <= MAX_DIST(s)) { + /* To simplify the code, we prevent matches with the string + * of window index 0 (in particular we have to avoid a match + * of the string with itself at the start of the input file). + */ + s->match_length = longest_match (s, hash_head); + /* longest_match() sets match_start */ + + if (s->match_length <= 5 && (s->strategy == Z_FILTERED +#if TOO_FAR <= 32767 + || (s->match_length == MIN_MATCH && + s->strstart - s->match_start > TOO_FAR) +#endif + )) { + + /* If prev_match is also MIN_MATCH, match_start is garbage + * but we will ignore the current match anyway. + */ + s->match_length = MIN_MATCH-1; + } + } + /* If there was a match at the previous step and the current + * match is not better, output the previous match: + */ + if (s->prev_length >= MIN_MATCH && s->match_length <= s->prev_length) { + uInt max_insert = s->strstart + s->lookahead - MIN_MATCH; + /* Do not insert strings in hash table beyond this. */ + + check_match(s, s->strstart-1, s->prev_match, s->prev_length); + + _tr_tally_dist(s, s->strstart -1 - s->prev_match, + s->prev_length - MIN_MATCH, bflush); + + /* Insert in hash table all strings up to the end of the match. + * strstart-1 and strstart are already inserted. If there is not + * enough lookahead, the last two strings are not inserted in + * the hash table. + */ + s->lookahead -= s->prev_length-1; + s->prev_length -= 2; + do { + if (++s->strstart <= max_insert) { + INSERT_STRING(s, s->strstart, hash_head); + } + } while (--s->prev_length != 0); + s->match_available = 0; + s->match_length = MIN_MATCH-1; + s->strstart++; + + if (bflush) FLUSH_BLOCK(s, 0); + + } else if (s->match_available) { + /* If there was no match at the previous position, output a + * single literal. If there was a match but the current match + * is longer, truncate the previous match to a single literal. + */ + Tracevv((stderr,"%c", s->window[s->strstart-1])); + _tr_tally_lit(s, s->window[s->strstart-1], bflush); + if (bflush) { + FLUSH_BLOCK_ONLY(s, 0); + } + s->strstart++; + s->lookahead--; + if (s->strm->avail_out == 0) return need_more; + } else { + /* There is no previous match to compare with, wait for + * the next step to decide. + */ + s->match_available = 1; + s->strstart++; + s->lookahead--; + } + } + Assert (flush != Z_NO_FLUSH, "no flush?"); + if (s->match_available) { + Tracevv((stderr,"%c", s->window[s->strstart-1])); + _tr_tally_lit(s, s->window[s->strstart-1], bflush); + s->match_available = 0; + } + s->insert = s->strstart < MIN_MATCH-1 ? s->strstart : MIN_MATCH-1; + if (flush == Z_FINISH) { + FLUSH_BLOCK(s, 1); + return finish_done; + } + if (s->last_lit) + FLUSH_BLOCK(s, 0); + return block_done; +} +#endif /* FASTEST */ + +/* =========================================================================== + * For Z_RLE, simply look for runs of bytes, generate matches only of distance + * one. Do not maintain a hash table. (It will be regenerated if this run of + * deflate switches away from Z_RLE.) + */ +local block_state deflate_rle(s, flush) + deflate_state *s; + int flush; +{ + int bflush; /* set if current block must be flushed */ + uInt prev; /* byte at distance one to match */ + Bytef *scan, *strend; /* scan goes up to strend for length of run */ + + for (;;) { + /* Make sure that we always have enough lookahead, except + * at the end of the input file. We need MAX_MATCH bytes + * for the longest run, plus one for the unrolled loop. + */ + if (s->lookahead <= MAX_MATCH) { + fill_window(s); + if (s->lookahead <= MAX_MATCH && flush == Z_NO_FLUSH) { + return need_more; + } + if (s->lookahead == 0) break; /* flush the current block */ + } + + /* See how many times the previous byte repeats */ + s->match_length = 0; + if (s->lookahead >= MIN_MATCH && s->strstart > 0) { + scan = s->window + s->strstart - 1; + prev = *scan; + if (prev == *++scan && prev == *++scan && prev == *++scan) { + strend = s->window + s->strstart + MAX_MATCH; + do { + } while (prev == *++scan && prev == *++scan && + prev == *++scan && prev == *++scan && + prev == *++scan && prev == *++scan && + prev == *++scan && prev == *++scan && + scan < strend); + s->match_length = MAX_MATCH - (int)(strend - scan); + if (s->match_length > s->lookahead) + s->match_length = s->lookahead; + } + Assert(scan <= s->window+(uInt)(s->window_size-1), "wild scan"); + } + + /* Emit match if have run of MIN_MATCH or longer, else emit literal */ + if (s->match_length >= MIN_MATCH) { + check_match(s, s->strstart, s->strstart - 1, s->match_length); + + _tr_tally_dist(s, 1, s->match_length - MIN_MATCH, bflush); + + s->lookahead -= s->match_length; + s->strstart += s->match_length; + s->match_length = 0; + } else { + /* No match, output a literal byte */ + Tracevv((stderr,"%c", s->window[s->strstart])); + _tr_tally_lit (s, s->window[s->strstart], bflush); + s->lookahead--; + s->strstart++; + } + if (bflush) FLUSH_BLOCK(s, 0); + } + s->insert = 0; + if (flush == Z_FINISH) { + FLUSH_BLOCK(s, 1); + return finish_done; + } + if (s->last_lit) + FLUSH_BLOCK(s, 0); + return block_done; +} + +/* =========================================================================== + * For Z_HUFFMAN_ONLY, do not look for matches. Do not maintain a hash table. + * (It will be regenerated if this run of deflate switches away from Huffman.) + */ +local block_state deflate_huff(s, flush) + deflate_state *s; + int flush; +{ + int bflush; /* set if current block must be flushed */ + + for (;;) { + /* Make sure that we have a literal to write. */ + if (s->lookahead == 0) { + fill_window(s); + if (s->lookahead == 0) { + if (flush == Z_NO_FLUSH) + return need_more; + break; /* flush the current block */ + } + } + + /* Output a literal byte */ + s->match_length = 0; + Tracevv((stderr,"%c", s->window[s->strstart])); + _tr_tally_lit (s, s->window[s->strstart], bflush); + s->lookahead--; + s->strstart++; + if (bflush) FLUSH_BLOCK(s, 0); + } + s->insert = 0; + if (flush == Z_FINISH) { + FLUSH_BLOCK(s, 1); + return finish_done; + } + if (s->last_lit) + FLUSH_BLOCK(s, 0); + return block_done; +} diff --git a/ml/dlib/dlib/external/zlib/deflate.h b/ml/dlib/dlib/external/zlib/deflate.h new file mode 100644 index 000000000..ce0299edd --- /dev/null +++ b/ml/dlib/dlib/external/zlib/deflate.h @@ -0,0 +1,346 @@ +/* deflate.h -- internal compression state + * Copyright (C) 1995-2012 Jean-loup Gailly + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* WARNING: this file should *not* be used by applications. It is + part of the implementation of the compression library and is + subject to change. Applications should only use zlib.h. + */ + +/* @(#) $Id$ */ + +#ifndef DEFLATE_H +#define DEFLATE_H + +#include "zutil.h" + +/* define NO_GZIP when compiling if you want to disable gzip header and + trailer creation by deflate(). NO_GZIP would be used to avoid linking in + the crc code when it is not needed. For shared libraries, gzip encoding + should be left enabled. */ +#ifndef NO_GZIP +# define GZIP +#endif + +/* =========================================================================== + * Internal compression state. + */ + +#define LENGTH_CODES 29 +/* number of length codes, not counting the special END_BLOCK code */ + +#define LITERALS 256 +/* number of literal bytes 0..255 */ + +#define L_CODES (LITERALS+1+LENGTH_CODES) +/* number of Literal or Length codes, including the END_BLOCK code */ + +#define D_CODES 30 +/* number of distance codes */ + +#define BL_CODES 19 +/* number of codes used to transfer the bit lengths */ + +#define HEAP_SIZE (2*L_CODES+1) +/* maximum heap size */ + +#define MAX_BITS 15 +/* All codes must not exceed MAX_BITS bits */ + +#define Buf_size 16 +/* size of bit buffer in bi_buf */ + +#define INIT_STATE 42 +#define EXTRA_STATE 69 +#define NAME_STATE 73 +#define COMMENT_STATE 91 +#define HCRC_STATE 103 +#define BUSY_STATE 113 +#define FINISH_STATE 666 +/* Stream status */ + + +/* Data structure describing a single value and its code string. */ +typedef struct ct_data_s { + union { + ush freq; /* frequency count */ + ush code; /* bit string */ + } fc; + union { + ush dad; /* father node in Huffman tree */ + ush len; /* length of bit string */ + } dl; +} FAR ct_data; + +#define Freq fc.freq +#define Code fc.code +#define Dad dl.dad +#define Len dl.len + +typedef struct static_tree_desc_s static_tree_desc; + +typedef struct tree_desc_s { + ct_data *dyn_tree; /* the dynamic tree */ + int max_code; /* largest code with non zero frequency */ + static_tree_desc *stat_desc; /* the corresponding static tree */ +} FAR tree_desc; + +typedef ush Pos; +typedef Pos FAR Posf; +typedef unsigned IPos; + +/* A Pos is an index in the character window. We use short instead of int to + * save space in the various tables. IPos is used only for parameter passing. + */ + +typedef struct internal_state { + z_streamp strm; /* pointer back to this zlib stream */ + int status; /* as the name implies */ + Bytef *pending_buf; /* output still pending */ + ulg pending_buf_size; /* size of pending_buf */ + Bytef *pending_out; /* next pending byte to output to the stream */ + uInt pending; /* nb of bytes in the pending buffer */ + int wrap; /* bit 0 true for zlib, bit 1 true for gzip */ + gz_headerp gzhead; /* gzip header information to write */ + uInt gzindex; /* where in extra, name, or comment */ + Byte method; /* can only be DEFLATED */ + int last_flush; /* value of flush param for previous deflate call */ + + /* used by deflate.c: */ + + uInt w_size; /* LZ77 window size (32K by default) */ + uInt w_bits; /* log2(w_size) (8..16) */ + uInt w_mask; /* w_size - 1 */ + + Bytef *window; + /* Sliding window. Input bytes are read into the second half of the window, + * and move to the first half later to keep a dictionary of at least wSize + * bytes. With this organization, matches are limited to a distance of + * wSize-MAX_MATCH bytes, but this ensures that IO is always + * performed with a length multiple of the block size. Also, it limits + * the window size to 64K, which is quite useful on MSDOS. + * To do: use the user input buffer as sliding window. + */ + + ulg window_size; + /* Actual size of window: 2*wSize, except when the user input buffer + * is directly used as sliding window. + */ + + Posf *prev; + /* Link to older string with same hash index. To limit the size of this + * array to 64K, this link is maintained only for the last 32K strings. + * An index in this array is thus a window index modulo 32K. + */ + + Posf *head; /* Heads of the hash chains or NIL. */ + + uInt ins_h; /* hash index of string to be inserted */ + uInt hash_size; /* number of elements in hash table */ + uInt hash_bits; /* log2(hash_size) */ + uInt hash_mask; /* hash_size-1 */ + + uInt hash_shift; + /* Number of bits by which ins_h must be shifted at each input + * step. It must be such that after MIN_MATCH steps, the oldest + * byte no longer takes part in the hash key, that is: + * hash_shift * MIN_MATCH >= hash_bits + */ + + long block_start; + /* Window position at the beginning of the current output block. Gets + * negative when the window is moved backwards. + */ + + uInt match_length; /* length of best match */ + IPos prev_match; /* previous match */ + int match_available; /* set if previous match exists */ + uInt strstart; /* start of string to insert */ + uInt match_start; /* start of matching string */ + uInt lookahead; /* number of valid bytes ahead in window */ + + uInt prev_length; + /* Length of the best match at previous step. Matches not greater than this + * are discarded. This is used in the lazy match evaluation. + */ + + uInt max_chain_length; + /* To speed up deflation, hash chains are never searched beyond this + * length. A higher limit improves compression ratio but degrades the + * speed. + */ + + uInt max_lazy_match; + /* Attempt to find a better match only when the current match is strictly + * smaller than this value. This mechanism is used only for compression + * levels >= 4. + */ +# define max_insert_length max_lazy_match + /* Insert new strings in the hash table only if the match length is not + * greater than this length. This saves time but degrades compression. + * max_insert_length is used only for compression levels <= 3. + */ + + int level; /* compression level (1..9) */ + int strategy; /* favor or force Huffman coding*/ + + uInt good_match; + /* Use a faster search when the previous match is longer than this */ + + int nice_match; /* Stop searching when current match exceeds this */ + + /* used by trees.c: */ + /* Didn't use ct_data typedef below to suppress compiler warning */ + struct ct_data_s dyn_ltree[HEAP_SIZE]; /* literal and length tree */ + struct ct_data_s dyn_dtree[2*D_CODES+1]; /* distance tree */ + struct ct_data_s bl_tree[2*BL_CODES+1]; /* Huffman tree for bit lengths */ + + struct tree_desc_s l_desc; /* desc. for literal tree */ + struct tree_desc_s d_desc; /* desc. for distance tree */ + struct tree_desc_s bl_desc; /* desc. for bit length tree */ + + ush bl_count[MAX_BITS+1]; + /* number of codes at each bit length for an optimal tree */ + + int heap[2*L_CODES+1]; /* heap used to build the Huffman trees */ + int heap_len; /* number of elements in the heap */ + int heap_max; /* element of largest frequency */ + /* The sons of heap[n] are heap[2*n] and heap[2*n+1]. heap[0] is not used. + * The same heap array is used to build all trees. + */ + + uch depth[2*L_CODES+1]; + /* Depth of each subtree used as tie breaker for trees of equal frequency + */ + + uchf *l_buf; /* buffer for literals or lengths */ + + uInt lit_bufsize; + /* Size of match buffer for literals/lengths. There are 4 reasons for + * limiting lit_bufsize to 64K: + * - frequencies can be kept in 16 bit counters + * - if compression is not successful for the first block, all input + * data is still in the window so we can still emit a stored block even + * when input comes from standard input. (This can also be done for + * all blocks if lit_bufsize is not greater than 32K.) + * - if compression is not successful for a file smaller than 64K, we can + * even emit a stored file instead of a stored block (saving 5 bytes). + * This is applicable only for zip (not gzip or zlib). + * - creating new Huffman trees less frequently may not provide fast + * adaptation to changes in the input data statistics. (Take for + * example a binary file with poorly compressible code followed by + * a highly compressible string table.) Smaller buffer sizes give + * fast adaptation but have of course the overhead of transmitting + * trees more frequently. + * - I can't count above 4 + */ + + uInt last_lit; /* running index in l_buf */ + + ushf *d_buf; + /* Buffer for distances. To simplify the code, d_buf and l_buf have + * the same number of elements. To use different lengths, an extra flag + * array would be necessary. + */ + + ulg opt_len; /* bit length of current block with optimal trees */ + ulg static_len; /* bit length of current block with static trees */ + uInt matches; /* number of string matches in current block */ + uInt insert; /* bytes at end of window left to insert */ + +#ifdef DEBUG + ulg compressed_len; /* total bit length of compressed file mod 2^32 */ + ulg bits_sent; /* bit length of compressed data sent mod 2^32 */ +#endif + + ush bi_buf; + /* Output buffer. bits are inserted starting at the bottom (least + * significant bits). + */ + int bi_valid; + /* Number of valid bits in bi_buf. All bits above the last valid bit + * are always zero. + */ + + ulg high_water; + /* High water mark offset in window for initialized bytes -- bytes above + * this are set to zero in order to avoid memory check warnings when + * longest match routines access bytes past the input. This is then + * updated to the new high water mark. + */ + +} FAR deflate_state; + +/* Output a byte on the stream. + * IN assertion: there is enough room in pending_buf. + */ +#define put_byte(s, c) {s->pending_buf[s->pending++] = (c);} + + +#define MIN_LOOKAHEAD (MAX_MATCH+MIN_MATCH+1) +/* Minimum amount of lookahead, except at the end of the input file. + * See deflate.c for comments about the MIN_MATCH+1. + */ + +#define MAX_DIST(s) ((s)->w_size-MIN_LOOKAHEAD) +/* In order to simplify the code, particularly on 16 bit machines, match + * distances are limited to MAX_DIST instead of WSIZE. + */ + +#define WIN_INIT MAX_MATCH +/* Number of bytes after end of data in window to initialize in order to avoid + memory checker errors from longest match routines */ + + /* in trees.c */ +void ZLIB_INTERNAL _tr_init OF((deflate_state *s)); +int ZLIB_INTERNAL _tr_tally OF((deflate_state *s, unsigned dist, unsigned lc)); +void ZLIB_INTERNAL _tr_flush_block OF((deflate_state *s, charf *buf, + ulg stored_len, int last)); +void ZLIB_INTERNAL _tr_flush_bits OF((deflate_state *s)); +void ZLIB_INTERNAL _tr_align OF((deflate_state *s)); +void ZLIB_INTERNAL _tr_stored_block OF((deflate_state *s, charf *buf, + ulg stored_len, int last)); + +#define d_code(dist) \ + ((dist) < 256 ? _dist_code[dist] : _dist_code[256+((dist)>>7)]) +/* Mapping from a distance to a distance code. dist is the distance - 1 and + * must not have side effects. _dist_code[256] and _dist_code[257] are never + * used. + */ + +#ifndef DEBUG +/* Inline versions of _tr_tally for speed: */ + +#if defined(GEN_TREES_H) || !defined(STDC) + extern uch ZLIB_INTERNAL _length_code[]; + extern uch ZLIB_INTERNAL _dist_code[]; +#else + extern const uch ZLIB_INTERNAL _length_code[]; + extern const uch ZLIB_INTERNAL _dist_code[]; +#endif + +# define _tr_tally_lit(s, c, flush) \ + { uch cc = (c); \ + s->d_buf[s->last_lit] = 0; \ + s->l_buf[s->last_lit++] = cc; \ + s->dyn_ltree[cc].Freq++; \ + flush = (s->last_lit == s->lit_bufsize-1); \ + } +# define _tr_tally_dist(s, distance, length, flush) \ + { uch len = (length); \ + ush dist = (distance); \ + s->d_buf[s->last_lit] = dist; \ + s->l_buf[s->last_lit++] = len; \ + dist--; \ + s->dyn_ltree[_length_code[len]+LITERALS+1].Freq++; \ + s->dyn_dtree[d_code(dist)].Freq++; \ + flush = (s->last_lit == s->lit_bufsize-1); \ + } +#else +# define _tr_tally_lit(s, c, flush) flush = _tr_tally(s, 0, c) +# define _tr_tally_dist(s, distance, length, flush) \ + flush = _tr_tally(s, distance, length) +#endif + +#endif /* DEFLATE_H */ diff --git a/ml/dlib/dlib/external/zlib/gzclose.c b/ml/dlib/dlib/external/zlib/gzclose.c new file mode 100644 index 000000000..caeb99a31 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/gzclose.c @@ -0,0 +1,25 @@ +/* gzclose.c -- zlib gzclose() function + * Copyright (C) 2004, 2010 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +#include "gzguts.h" + +/* gzclose() is in a separate file so that it is linked in only if it is used. + That way the other gzclose functions can be used instead to avoid linking in + unneeded compression or decompression routines. */ +int ZEXPORT gzclose(file) + gzFile file; +{ +#ifndef NO_GZCOMPRESS + gz_statep state; + + if (file == NULL) + return Z_STREAM_ERROR; + state = (gz_statep)file; + + return state->mode == GZ_READ ? gzclose_r(file) : gzclose_w(file); +#else + return gzclose_r(file); +#endif +} diff --git a/ml/dlib/dlib/external/zlib/gzguts.h b/ml/dlib/dlib/external/zlib/gzguts.h new file mode 100644 index 000000000..2bb0b0499 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/gzguts.h @@ -0,0 +1,219 @@ +/* gzguts.h -- zlib internal header definitions for gz* operations + * Copyright (C) 2004, 2005, 2010, 2011, 2012, 2013 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +#ifdef _MSC_VER +// Disable the following warnings for Visual Studio +// This is a warning you get from visual studio 2005 about things in the standard C++ +// library being "deprecated." I checked the C++ standard and it doesn't say jack +// about any of them (I checked the searchable PDF). So this warning is total Bunk. +#pragma warning(disable : 4996) +#endif + +#ifdef _LARGEFILE64_SOURCE +# ifndef _LARGEFILE_SOURCE +# define _LARGEFILE_SOURCE 1 +# endif +# ifdef _FILE_OFFSET_BITS +# undef _FILE_OFFSET_BITS +# endif +#endif + +#ifdef HAVE_HIDDEN +# define ZLIB_INTERNAL __attribute__((visibility ("hidden"))) +#else +# define ZLIB_INTERNAL +#endif + +#include +#include "zlib.h" +#ifdef STDC +# include +# include +# include +#endif +#include + +#ifdef _WIN32 +# include +#endif + +#if defined(__TURBOC__) || defined(_MSC_VER) || defined(_WIN32) +# include +#else +# include +#endif + +#ifdef WINAPI_FAMILY +# define open _open +# define read _read +# define write _write +# define close _close +#endif + +#ifdef NO_DEFLATE /* for compatibility with old definition */ +# define NO_GZCOMPRESS +#endif + +#if defined(STDC99) || (defined(__TURBOC__) && __TURBOC__ >= 0x550) +# ifndef HAVE_VSNPRINTF +# define HAVE_VSNPRINTF +# endif +#endif + +#if defined(__CYGWIN__) +# ifndef HAVE_VSNPRINTF +# define HAVE_VSNPRINTF +# endif +#endif + +#if defined(MSDOS) && defined(__BORLANDC__) && (BORLANDC > 0x410) +# ifndef HAVE_VSNPRINTF +# define HAVE_VSNPRINTF +# endif +#endif + +#ifndef HAVE_VSNPRINTF +# ifdef MSDOS +/* vsnprintf may exist on some MS-DOS compilers (DJGPP?), + but for now we just assume it doesn't. */ +# define NO_vsnprintf +# endif +# ifdef __TURBOC__ +# define NO_vsnprintf +# endif +# ifdef WIN32 +/* In Win32, vsnprintf is available as the "non-ANSI" _vsnprintf. */ +# if !defined(vsnprintf) && !defined(NO_vsnprintf) +# if !defined(_MSC_VER) || ( defined(_MSC_VER) && _MSC_VER < 1500 ) +# define vsnprintf _vsnprintf +# endif +# endif +# endif +# ifdef __SASC +# define NO_vsnprintf +# endif +# ifdef VMS +# define NO_vsnprintf +# endif +# ifdef __OS400__ +# define NO_vsnprintf +# endif +# ifdef __MVS__ +# define NO_vsnprintf +# endif +#endif + +/* unlike snprintf (which is required in C99, yet still not supported by + Microsoft more than a decade later!), _snprintf does not guarantee null + termination of the result -- however this is only used in gzlib.c where + the result is assured to fit in the space provided */ +#ifdef _MSC_VER +# define snprintf _snprintf +#endif + +#ifndef local +# define local static +#endif +/* compile with -Dlocal if your debugger can't find static symbols */ + +/* gz* functions always use library allocation functions */ +#ifndef STDC + extern voidp malloc OF((uInt size)); + extern void free OF((voidpf ptr)); +#endif + +/* get errno and strerror definition */ +#if defined UNDER_CE +# include +# define zstrerror() gz_strwinerror((DWORD)GetLastError()) +#else +# ifndef NO_STRERROR +# include +# define zstrerror() strerror(errno) +# else +# define zstrerror() "stdio error (consult errno)" +# endif +#endif + +/* provide prototypes for these when building zlib without LFS */ +#if !defined(_LARGEFILE64_SOURCE) || _LFS64_LARGEFILE-0 == 0 + ZEXTERN gzFile ZEXPORT gzopen64 OF((const char *, const char *)); + ZEXTERN z_off64_t ZEXPORT gzseek64 OF((gzFile, z_off64_t, int)); + ZEXTERN z_off64_t ZEXPORT gztell64 OF((gzFile)); + ZEXTERN z_off64_t ZEXPORT gzoffset64 OF((gzFile)); +#endif + +/* default memLevel */ +#if MAX_MEM_LEVEL >= 8 +# define DEF_MEM_LEVEL 8 +#else +# define DEF_MEM_LEVEL MAX_MEM_LEVEL +#endif + +/* default i/o buffer size -- double this for output when reading (this and + twice this must be able to fit in an unsigned type) */ +#define GZBUFSIZE 8192 + +/* gzip modes, also provide a little integrity check on the passed structure */ +#define GZ_NONE 0 +#define GZ_READ 7247 +#define GZ_WRITE 31153 +#define GZ_APPEND 1 /* mode set to GZ_WRITE after the file is opened */ + +/* values for gz_state how */ +#define LOOK 0 /* look for a gzip header */ +#define COPY 1 /* copy input directly */ +#define GZIP 2 /* decompress a gzip stream */ + +/* internal gzip file state data structure */ +typedef struct { + /* exposed contents for gzgetc() macro */ + struct gzFile_s x; /* "x" for exposed */ + /* x.have: number of bytes available at x.next */ + /* x.next: next output data to deliver or write */ + /* x.pos: current position in uncompressed data */ + /* used for both reading and writing */ + int mode; /* see gzip modes above */ + int fd; /* file descriptor */ + char *path; /* path or fd for error messages */ + unsigned size; /* buffer size, zero if not allocated yet */ + unsigned want; /* requested buffer size, default is GZBUFSIZE */ + unsigned char *in; /* input buffer */ + unsigned char *out; /* output buffer (double-sized when reading) */ + int direct; /* 0 if processing gzip, 1 if transparent */ + /* just for reading */ + int how; /* 0: get header, 1: copy, 2: decompress */ + z_off64_t start; /* where the gzip data started, for rewinding */ + int eof; /* true if end of input file reached */ + int past; /* true if read requested past end */ + /* just for writing */ + int level; /* compression level */ + int strategy; /* compression strategy */ + /* seek request */ + z_off64_t skip; /* amount to skip (already rewound if backwards) */ + int seek; /* true if seek request pending */ + /* error information */ + int err; /* error code */ + char *msg; /* error message */ + /* zlib inflate or deflate stream */ + z_stream strm; /* stream structure in-place (not a pointer) */ +} gz_state; +typedef gz_state FAR *gz_statep; + +/* shared functions */ +void ZLIB_INTERNAL gz_error OF((gz_statep, int, const char *)); +#if defined UNDER_CE +char ZLIB_INTERNAL *gz_strwinerror OF((DWORD error)); +#endif + +/* GT_OFF(x), where x is an unsigned value, is true if x > maximum z_off64_t + value -- needed when comparing unsigned to z_off64_t, which is signed + (possible z_off64_t types off_t, off64_t, and long are all signed) */ +#ifdef INT_MAX +# define GT_OFF(x) (sizeof(int) == sizeof(z_off64_t) && (x) > INT_MAX) +#else +unsigned ZLIB_INTERNAL gz_intmax OF((void)); +# define GT_OFF(x) (sizeof(int) == sizeof(z_off64_t) && (x) > gz_intmax()) +#endif diff --git a/ml/dlib/dlib/external/zlib/gzlib.c b/ml/dlib/dlib/external/zlib/gzlib.c new file mode 100644 index 000000000..fae202ef8 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/gzlib.c @@ -0,0 +1,634 @@ +/* gzlib.c -- zlib functions common to reading and writing gzip files + * Copyright (C) 2004, 2010, 2011, 2012, 2013 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +#include "gzguts.h" + +#if defined(_WIN32) && !defined(__BORLANDC__) +# define LSEEK _lseeki64 +#else +#if defined(_LARGEFILE64_SOURCE) && _LFS64_LARGEFILE-0 +# define LSEEK lseek64 +#else +# define LSEEK lseek +#endif +#endif + +/* Local functions */ +local void gz_reset OF((gz_statep)); +local gzFile gz_open OF((const void *, int, const char *)); + +#if defined UNDER_CE + +/* Map the Windows error number in ERROR to a locale-dependent error message + string and return a pointer to it. Typically, the values for ERROR come + from GetLastError. + + The string pointed to shall not be modified by the application, but may be + overwritten by a subsequent call to gz_strwinerror + + The gz_strwinerror function does not change the current setting of + GetLastError. */ +char ZLIB_INTERNAL *gz_strwinerror (error) + DWORD error; +{ + static char buf[1024]; + + wchar_t *msgbuf; + DWORD lasterr = GetLastError(); + DWORD chars = FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM + | FORMAT_MESSAGE_ALLOCATE_BUFFER, + NULL, + error, + 0, /* Default language */ + (LPVOID)&msgbuf, + 0, + NULL); + if (chars != 0) { + /* If there is an \r\n appended, zap it. */ + if (chars >= 2 + && msgbuf[chars - 2] == '\r' && msgbuf[chars - 1] == '\n') { + chars -= 2; + msgbuf[chars] = 0; + } + + if (chars > sizeof (buf) - 1) { + chars = sizeof (buf) - 1; + msgbuf[chars] = 0; + } + + wcstombs(buf, msgbuf, chars + 1); + LocalFree(msgbuf); + } + else { + sprintf(buf, "unknown win32 error (%ld)", error); + } + + SetLastError(lasterr); + return buf; +} + +#endif /* UNDER_CE */ + +/* Reset gzip file state */ +local void gz_reset(state) + gz_statep state; +{ + state->x.have = 0; /* no output data available */ + if (state->mode == GZ_READ) { /* for reading ... */ + state->eof = 0; /* not at end of file */ + state->past = 0; /* have not read past end yet */ + state->how = LOOK; /* look for gzip header */ + } + state->seek = 0; /* no seek request pending */ + gz_error(state, Z_OK, NULL); /* clear error */ + state->x.pos = 0; /* no uncompressed data yet */ + state->strm.avail_in = 0; /* no input data yet */ +} + +/* Open a gzip file either by name or file descriptor. */ +local gzFile gz_open(path, fd, mode) + const void *path; + int fd; + const char *mode; +{ + gz_statep state; + size_t len; + int oflag; +#ifdef O_CLOEXEC + int cloexec = 0; +#endif +#ifdef O_EXCL + int exclusive = 0; +#endif + + /* check input */ + if (path == NULL) + return NULL; + + /* allocate gzFile structure to return */ + state = (gz_statep)malloc(sizeof(gz_state)); + if (state == NULL) + return NULL; + state->size = 0; /* no buffers allocated yet */ + state->want = GZBUFSIZE; /* requested buffer size */ + state->msg = NULL; /* no error message yet */ + + /* interpret mode */ + state->mode = GZ_NONE; + state->level = Z_DEFAULT_COMPRESSION; + state->strategy = Z_DEFAULT_STRATEGY; + state->direct = 0; + while (*mode) { + if (*mode >= '0' && *mode <= '9') + state->level = *mode - '0'; + else + switch (*mode) { + case 'r': + state->mode = GZ_READ; + break; +#ifndef NO_GZCOMPRESS + case 'w': + state->mode = GZ_WRITE; + break; + case 'a': + state->mode = GZ_APPEND; + break; +#endif + case '+': /* can't read and write at the same time */ + free(state); + return NULL; + case 'b': /* ignore -- will request binary anyway */ + break; +#ifdef O_CLOEXEC + case 'e': + cloexec = 1; + break; +#endif +#ifdef O_EXCL + case 'x': + exclusive = 1; + break; +#endif + case 'f': + state->strategy = Z_FILTERED; + break; + case 'h': + state->strategy = Z_HUFFMAN_ONLY; + break; + case 'R': + state->strategy = Z_RLE; + break; + case 'F': + state->strategy = Z_FIXED; + break; + case 'T': + state->direct = 1; + break; + default: /* could consider as an error, but just ignore */ + ; + } + mode++; + } + + /* must provide an "r", "w", or "a" */ + if (state->mode == GZ_NONE) { + free(state); + return NULL; + } + + /* can't force transparent read */ + if (state->mode == GZ_READ) { + if (state->direct) { + free(state); + return NULL; + } + state->direct = 1; /* for empty file */ + } + + /* save the path name for error messages */ +#ifdef _WIN32 + if (fd == -2) { + len = wcstombs(NULL, path, 0); + if (len == (size_t)-1) + len = 0; + } + else +#endif + len = strlen((const char *)path); + state->path = (char *)malloc(len + 1); + if (state->path == NULL) { + free(state); + return NULL; + } +#ifdef _WIN32 + if (fd == -2) + if (len) + wcstombs(state->path, path, len + 1); + else + *(state->path) = 0; + else +#endif +#if !defined(NO_snprintf) && !defined(NO_vsnprintf) + snprintf(state->path, len + 1, "%s", (const char *)path); +#else + strcpy(state->path, path); +#endif + + /* compute the flags for open() */ + oflag = +#ifdef O_LARGEFILE + O_LARGEFILE | +#endif +#ifdef O_BINARY + O_BINARY | +#endif +#ifdef O_CLOEXEC + (cloexec ? O_CLOEXEC : 0) | +#endif + (state->mode == GZ_READ ? + O_RDONLY : + (O_WRONLY | O_CREAT | +#ifdef O_EXCL + (exclusive ? O_EXCL : 0) | +#endif + (state->mode == GZ_WRITE ? + O_TRUNC : + O_APPEND))); + + /* open the file with the appropriate flags (or just use fd) */ + state->fd = fd > -1 ? fd : ( +#ifdef _WIN32 + fd == -2 ? _wopen(path, oflag, 0666) : +#endif + open((const char *)path, oflag, 0666)); + if (state->fd == -1) { + free(state->path); + free(state); + return NULL; + } + if (state->mode == GZ_APPEND) + state->mode = GZ_WRITE; /* simplify later checks */ + + /* save the current position for rewinding (only if reading) */ + if (state->mode == GZ_READ) { + state->start = LSEEK(state->fd, 0, SEEK_CUR); + if (state->start == -1) state->start = 0; + } + + /* initialize stream */ + gz_reset(state); + + /* return stream */ + return (gzFile)state; +} + +/* -- see zlib.h -- */ +gzFile ZEXPORT gzopen(path, mode) + const char *path; + const char *mode; +{ + return gz_open(path, -1, mode); +} + +/* -- see zlib.h -- */ +gzFile ZEXPORT gzopen64(path, mode) + const char *path; + const char *mode; +{ + return gz_open(path, -1, mode); +} + +/* -- see zlib.h -- */ +gzFile ZEXPORT gzdopen(fd, mode) + int fd; + const char *mode; +{ + char *path; /* identifier for error messages */ + gzFile gz; + + if (fd == -1 || (path = (char *)malloc(7 + 3 * sizeof(int))) == NULL) + return NULL; +#if !defined(NO_snprintf) && !defined(NO_vsnprintf) + snprintf(path, 7 + 3 * sizeof(int), "", fd); /* for debugging */ +#else + sprintf(path, "", fd); /* for debugging */ +#endif + gz = gz_open(path, fd, mode); + free(path); + return gz; +} + +/* -- see zlib.h -- */ +#ifdef _WIN32 +gzFile ZEXPORT gzopen_w(path, mode) + const wchar_t *path; + const char *mode; +{ + return gz_open(path, -2, mode); +} +#endif + +/* -- see zlib.h -- */ +int ZEXPORT gzbuffer(file, size) + gzFile file; + unsigned size; +{ + gz_statep state; + + /* get internal structure and check integrity */ + if (file == NULL) + return -1; + state = (gz_statep)file; + if (state->mode != GZ_READ && state->mode != GZ_WRITE) + return -1; + + /* make sure we haven't already allocated memory */ + if (state->size != 0) + return -1; + + /* check and set requested size */ + if (size < 2) + size = 2; /* need two bytes to check magic header */ + state->want = size; + return 0; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzrewind(file) + gzFile file; +{ + gz_statep state; + + /* get internal structure */ + if (file == NULL) + return -1; + state = (gz_statep)file; + + /* check that we're reading and that there's no error */ + if (state->mode != GZ_READ || + (state->err != Z_OK && state->err != Z_BUF_ERROR)) + return -1; + + /* back up and start over */ + if (LSEEK(state->fd, state->start, SEEK_SET) == -1) + return -1; + gz_reset(state); + return 0; +} + +/* -- see zlib.h -- */ +z_off64_t ZEXPORT gzseek64(file, offset, whence) + gzFile file; + z_off64_t offset; + int whence; +{ + unsigned n; + z_off64_t ret; + gz_statep state; + + /* get internal structure and check integrity */ + if (file == NULL) + return -1; + state = (gz_statep)file; + if (state->mode != GZ_READ && state->mode != GZ_WRITE) + return -1; + + /* check that there's no error */ + if (state->err != Z_OK && state->err != Z_BUF_ERROR) + return -1; + + /* can only seek from start or relative to current position */ + if (whence != SEEK_SET && whence != SEEK_CUR) + return -1; + + /* normalize offset to a SEEK_CUR specification */ + if (whence == SEEK_SET) + offset -= state->x.pos; + else if (state->seek) + offset += state->skip; + state->seek = 0; + + /* if within raw area while reading, just go there */ + if (state->mode == GZ_READ && state->how == COPY && + state->x.pos + offset >= 0) { + ret = LSEEK(state->fd, offset - state->x.have, SEEK_CUR); + if (ret == -1) + return -1; + state->x.have = 0; + state->eof = 0; + state->past = 0; + state->seek = 0; + gz_error(state, Z_OK, NULL); + state->strm.avail_in = 0; + state->x.pos += offset; + return state->x.pos; + } + + /* calculate skip amount, rewinding if needed for back seek when reading */ + if (offset < 0) { + if (state->mode != GZ_READ) /* writing -- can't go backwards */ + return -1; + offset += state->x.pos; + if (offset < 0) /* before start of file! */ + return -1; + if (gzrewind(file) == -1) /* rewind, then skip to offset */ + return -1; + } + + /* if reading, skip what's in output buffer (one less gzgetc() check) */ + if (state->mode == GZ_READ) { + n = GT_OFF(state->x.have) || (z_off64_t)state->x.have > offset ? + (unsigned)offset : state->x.have; + state->x.have -= n; + state->x.next += n; + state->x.pos += n; + offset -= n; + } + + /* request skip (if not zero) */ + if (offset) { + state->seek = 1; + state->skip = offset; + } + return state->x.pos + offset; +} + +/* -- see zlib.h -- */ +z_off_t ZEXPORT gzseek(file, offset, whence) + gzFile file; + z_off_t offset; + int whence; +{ + z_off64_t ret; + + ret = gzseek64(file, (z_off64_t)offset, whence); + return ret == (z_off_t)ret ? (z_off_t)ret : -1; +} + +/* -- see zlib.h -- */ +z_off64_t ZEXPORT gztell64(file) + gzFile file; +{ + gz_statep state; + + /* get internal structure and check integrity */ + if (file == NULL) + return -1; + state = (gz_statep)file; + if (state->mode != GZ_READ && state->mode != GZ_WRITE) + return -1; + + /* return position */ + return state->x.pos + (state->seek ? state->skip : 0); +} + +/* -- see zlib.h -- */ +z_off_t ZEXPORT gztell(file) + gzFile file; +{ + z_off64_t ret; + + ret = gztell64(file); + return ret == (z_off_t)ret ? (z_off_t)ret : -1; +} + +/* -- see zlib.h -- */ +z_off64_t ZEXPORT gzoffset64(file) + gzFile file; +{ + z_off64_t offset; + gz_statep state; + + /* get internal structure and check integrity */ + if (file == NULL) + return -1; + state = (gz_statep)file; + if (state->mode != GZ_READ && state->mode != GZ_WRITE) + return -1; + + /* compute and return effective offset in file */ + offset = LSEEK(state->fd, 0, SEEK_CUR); + if (offset == -1) + return -1; + if (state->mode == GZ_READ) /* reading */ + offset -= state->strm.avail_in; /* don't count buffered input */ + return offset; +} + +/* -- see zlib.h -- */ +z_off_t ZEXPORT gzoffset(file) + gzFile file; +{ + z_off64_t ret; + + ret = gzoffset64(file); + return ret == (z_off_t)ret ? (z_off_t)ret : -1; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzeof(file) + gzFile file; +{ + gz_statep state; + + /* get internal structure and check integrity */ + if (file == NULL) + return 0; + state = (gz_statep)file; + if (state->mode != GZ_READ && state->mode != GZ_WRITE) + return 0; + + /* return end-of-file state */ + return state->mode == GZ_READ ? state->past : 0; +} + +/* -- see zlib.h -- */ +const char * ZEXPORT gzerror(file, errnum) + gzFile file; + int *errnum; +{ + gz_statep state; + + /* get internal structure and check integrity */ + if (file == NULL) + return NULL; + state = (gz_statep)file; + if (state->mode != GZ_READ && state->mode != GZ_WRITE) + return NULL; + + /* return error information */ + if (errnum != NULL) + *errnum = state->err; + return state->err == Z_MEM_ERROR ? "out of memory" : + (state->msg == NULL ? "" : state->msg); +} + +/* -- see zlib.h -- */ +void ZEXPORT gzclearerr(file) + gzFile file; +{ + gz_statep state; + + /* get internal structure and check integrity */ + if (file == NULL) + return; + state = (gz_statep)file; + if (state->mode != GZ_READ && state->mode != GZ_WRITE) + return; + + /* clear error and end-of-file */ + if (state->mode == GZ_READ) { + state->eof = 0; + state->past = 0; + } + gz_error(state, Z_OK, NULL); +} + +/* Create an error message in allocated memory and set state->err and + state->msg accordingly. Free any previous error message already there. Do + not try to free or allocate space if the error is Z_MEM_ERROR (out of + memory). Simply save the error message as a static string. If there is an + allocation failure constructing the error message, then convert the error to + out of memory. */ +void ZLIB_INTERNAL gz_error(state, err, msg) + gz_statep state; + int err; + const char *msg; +{ + /* free previously allocated message and clear */ + if (state->msg != NULL) { + if (state->err != Z_MEM_ERROR) + free(state->msg); + state->msg = NULL; + } + + /* if fatal, set state->x.have to 0 so that the gzgetc() macro fails */ + if (err != Z_OK && err != Z_BUF_ERROR) + state->x.have = 0; + + /* set error code, and if no message, then done */ + state->err = err; + if (msg == NULL) + return; + + /* for an out of memory error, return literal string when requested */ + if (err == Z_MEM_ERROR) + return; + + /* construct error message with path */ + if ((state->msg = (char *)malloc(strlen(state->path) + strlen(msg) + 3)) == + NULL) { + state->err = Z_MEM_ERROR; + return; + } +#if !defined(NO_snprintf) && !defined(NO_vsnprintf) + snprintf(state->msg, strlen(state->path) + strlen(msg) + 3, + "%s%s%s", state->path, ": ", msg); +#else + strcpy(state->msg, state->path); + strcat(state->msg, ": "); + strcat(state->msg, msg); +#endif + return; +} + +#ifndef INT_MAX +/* portably return maximum value for an int (when limits.h presumed not + available) -- we need to do this to cover cases where 2's complement not + used, since C standard permits 1's complement and sign-bit representations, + otherwise we could just use ((unsigned)-1) >> 1 */ +unsigned ZLIB_INTERNAL gz_intmax() +{ + unsigned p, q; + + p = 1; + do { + q = p; + p <<= 1; + p++; + } while (p > q); + return q >> 1; +} +#endif diff --git a/ml/dlib/dlib/external/zlib/gzread.c b/ml/dlib/dlib/external/zlib/gzread.c new file mode 100644 index 000000000..bf4538eb2 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/gzread.c @@ -0,0 +1,594 @@ +/* gzread.c -- zlib functions for reading gzip files + * Copyright (C) 2004, 2005, 2010, 2011, 2012, 2013 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +#include "gzguts.h" + +/* Local functions */ +local int gz_load OF((gz_statep, unsigned char *, unsigned, unsigned *)); +local int gz_avail OF((gz_statep)); +local int gz_look OF((gz_statep)); +local int gz_decomp OF((gz_statep)); +local int gz_fetch OF((gz_statep)); +local int gz_skip OF((gz_statep, z_off64_t)); + +/* Use read() to load a buffer -- return -1 on error, otherwise 0. Read from + state->fd, and update state->eof, state->err, and state->msg as appropriate. + This function needs to loop on read(), since read() is not guaranteed to + read the number of bytes requested, depending on the type of descriptor. */ +local int gz_load(state, buf, len, have) + gz_statep state; + unsigned char *buf; + unsigned len; + unsigned *have; +{ + int ret; + + *have = 0; + do { + ret = read(state->fd, buf + *have, len - *have); + if (ret <= 0) + break; + *have += ret; + } while (*have < len); + if (ret < 0) { + gz_error(state, Z_ERRNO, zstrerror()); + return -1; + } + if (ret == 0) + state->eof = 1; + return 0; +} + +/* Load up input buffer and set eof flag if last data loaded -- return -1 on + error, 0 otherwise. Note that the eof flag is set when the end of the input + file is reached, even though there may be unused data in the buffer. Once + that data has been used, no more attempts will be made to read the file. + If strm->avail_in != 0, then the current data is moved to the beginning of + the input buffer, and then the remainder of the buffer is loaded with the + available data from the input file. */ +local int gz_avail(state) + gz_statep state; +{ + unsigned got; + z_streamp strm = &(state->strm); + + if (state->err != Z_OK && state->err != Z_BUF_ERROR) + return -1; + if (state->eof == 0) { + if (strm->avail_in) { /* copy what's there to the start */ + unsigned char *p = state->in; + unsigned const char *q = strm->next_in; + unsigned n = strm->avail_in; + do { + *p++ = *q++; + } while (--n); + } + if (gz_load(state, state->in + strm->avail_in, + state->size - strm->avail_in, &got) == -1) + return -1; + strm->avail_in += got; + strm->next_in = state->in; + } + return 0; +} + +/* Look for gzip header, set up for inflate or copy. state->x.have must be 0. + If this is the first time in, allocate required memory. state->how will be + left unchanged if there is no more input data available, will be set to COPY + if there is no gzip header and direct copying will be performed, or it will + be set to GZIP for decompression. If direct copying, then leftover input + data from the input buffer will be copied to the output buffer. In that + case, all further file reads will be directly to either the output buffer or + a user buffer. If decompressing, the inflate state will be initialized. + gz_look() will return 0 on success or -1 on failure. */ +local int gz_look(state) + gz_statep state; +{ + z_streamp strm = &(state->strm); + + /* allocate read buffers and inflate memory */ + if (state->size == 0) { + /* allocate buffers */ + state->in = (unsigned char *)malloc(state->want); + state->out = (unsigned char *)malloc(state->want << 1); + if (state->in == NULL || state->out == NULL) { + if (state->out != NULL) + free(state->out); + if (state->in != NULL) + free(state->in); + gz_error(state, Z_MEM_ERROR, "out of memory"); + return -1; + } + state->size = state->want; + + /* allocate inflate memory */ + state->strm.zalloc = Z_NULL; + state->strm.zfree = Z_NULL; + state->strm.opaque = Z_NULL; + state->strm.avail_in = 0; + state->strm.next_in = Z_NULL; + if (inflateInit2(&(state->strm), 15 + 16) != Z_OK) { /* gunzip */ + free(state->out); + free(state->in); + state->size = 0; + gz_error(state, Z_MEM_ERROR, "out of memory"); + return -1; + } + } + + /* get at least the magic bytes in the input buffer */ + if (strm->avail_in < 2) { + if (gz_avail(state) == -1) + return -1; + if (strm->avail_in == 0) + return 0; + } + + /* look for gzip magic bytes -- if there, do gzip decoding (note: there is + a logical dilemma here when considering the case of a partially written + gzip file, to wit, if a single 31 byte is written, then we cannot tell + whether this is a single-byte file, or just a partially written gzip + file -- for here we assume that if a gzip file is being written, then + the header will be written in a single operation, so that reading a + single byte is sufficient indication that it is not a gzip file) */ + if (strm->avail_in > 1 && + strm->next_in[0] == 31 && strm->next_in[1] == 139) { + inflateReset(strm); + state->how = GZIP; + state->direct = 0; + return 0; + } + + /* no gzip header -- if we were decoding gzip before, then this is trailing + garbage. Ignore the trailing garbage and finish. */ + if (state->direct == 0) { + strm->avail_in = 0; + state->eof = 1; + state->x.have = 0; + return 0; + } + + /* doing raw i/o, copy any leftover input to output -- this assumes that + the output buffer is larger than the input buffer, which also assures + space for gzungetc() */ + state->x.next = state->out; + if (strm->avail_in) { + memcpy(state->x.next, strm->next_in, strm->avail_in); + state->x.have = strm->avail_in; + strm->avail_in = 0; + } + state->how = COPY; + state->direct = 1; + return 0; +} + +/* Decompress from input to the provided next_out and avail_out in the state. + On return, state->x.have and state->x.next point to the just decompressed + data. If the gzip stream completes, state->how is reset to LOOK to look for + the next gzip stream or raw data, once state->x.have is depleted. Returns 0 + on success, -1 on failure. */ +local int gz_decomp(state) + gz_statep state; +{ + int ret = Z_OK; + unsigned had; + z_streamp strm = &(state->strm); + + /* fill output buffer up to end of deflate stream */ + had = strm->avail_out; + do { + /* get more input for inflate() */ + if (strm->avail_in == 0 && gz_avail(state) == -1) + return -1; + if (strm->avail_in == 0) { + gz_error(state, Z_BUF_ERROR, "unexpected end of file"); + break; + } + + /* decompress and handle errors */ + ret = inflate(strm, Z_NO_FLUSH); + if (ret == Z_STREAM_ERROR || ret == Z_NEED_DICT) { + gz_error(state, Z_STREAM_ERROR, + "internal error: inflate stream corrupt"); + return -1; + } + if (ret == Z_MEM_ERROR) { + gz_error(state, Z_MEM_ERROR, "out of memory"); + return -1; + } + if (ret == Z_DATA_ERROR) { /* deflate stream invalid */ + gz_error(state, Z_DATA_ERROR, + strm->msg == NULL ? "compressed data error" : strm->msg); + return -1; + } + } while (strm->avail_out && ret != Z_STREAM_END); + + /* update available output */ + state->x.have = had - strm->avail_out; + state->x.next = strm->next_out - state->x.have; + + /* if the gzip stream completed successfully, look for another */ + if (ret == Z_STREAM_END) + state->how = LOOK; + + /* good decompression */ + return 0; +} + +/* Fetch data and put it in the output buffer. Assumes state->x.have is 0. + Data is either copied from the input file or decompressed from the input + file depending on state->how. If state->how is LOOK, then a gzip header is + looked for to determine whether to copy or decompress. Returns -1 on error, + otherwise 0. gz_fetch() will leave state->how as COPY or GZIP unless the + end of the input file has been reached and all data has been processed. */ +local int gz_fetch(state) + gz_statep state; +{ + z_streamp strm = &(state->strm); + + do { + switch(state->how) { + case LOOK: /* -> LOOK, COPY (only if never GZIP), or GZIP */ + if (gz_look(state) == -1) + return -1; + if (state->how == LOOK) + return 0; + break; + case COPY: /* -> COPY */ + if (gz_load(state, state->out, state->size << 1, &(state->x.have)) + == -1) + return -1; + state->x.next = state->out; + return 0; + case GZIP: /* -> GZIP or LOOK (if end of gzip stream) */ + strm->avail_out = state->size << 1; + strm->next_out = state->out; + if (gz_decomp(state) == -1) + return -1; + } + } while (state->x.have == 0 && (!state->eof || strm->avail_in)); + return 0; +} + +/* Skip len uncompressed bytes of output. Return -1 on error, 0 on success. */ +local int gz_skip(state, len) + gz_statep state; + z_off64_t len; +{ + unsigned n; + + /* skip over len bytes or reach end-of-file, whichever comes first */ + while (len) + /* skip over whatever is in output buffer */ + if (state->x.have) { + n = GT_OFF(state->x.have) || (z_off64_t)state->x.have > len ? + (unsigned)len : state->x.have; + state->x.have -= n; + state->x.next += n; + state->x.pos += n; + len -= n; + } + + /* output buffer empty -- return if we're at the end of the input */ + else if (state->eof && state->strm.avail_in == 0) + break; + + /* need more data to skip -- load up output buffer */ + else { + /* get more output, looking for header if required */ + if (gz_fetch(state) == -1) + return -1; + } + return 0; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzread(file, buf, len) + gzFile file; + voidp buf; + unsigned len; +{ + unsigned got, n; + gz_statep state; + z_streamp strm; + + /* get internal structure */ + if (file == NULL) + return -1; + state = (gz_statep)file; + strm = &(state->strm); + + /* check that we're reading and that there's no (serious) error */ + if (state->mode != GZ_READ || + (state->err != Z_OK && state->err != Z_BUF_ERROR)) + return -1; + + /* since an int is returned, make sure len fits in one, otherwise return + with an error (this avoids the flaw in the interface) */ + if ((int)len < 0) { + gz_error(state, Z_DATA_ERROR, "requested length does not fit in int"); + return -1; + } + + /* if len is zero, avoid unnecessary operations */ + if (len == 0) + return 0; + + /* process a skip request */ + if (state->seek) { + state->seek = 0; + if (gz_skip(state, state->skip) == -1) + return -1; + } + + /* get len bytes to buf, or less than len if at the end */ + got = 0; + do { + /* first just try copying data from the output buffer */ + if (state->x.have) { + n = state->x.have > len ? len : state->x.have; + memcpy(buf, state->x.next, n); + state->x.next += n; + state->x.have -= n; + } + + /* output buffer empty -- return if we're at the end of the input */ + else if (state->eof && strm->avail_in == 0) { + state->past = 1; /* tried to read past end */ + break; + } + + /* need output data -- for small len or new stream load up our output + buffer */ + else if (state->how == LOOK || len < (state->size << 1)) { + /* get more output, looking for header if required */ + if (gz_fetch(state) == -1) + return -1; + continue; /* no progress yet -- go back to copy above */ + /* the copy above assures that we will leave with space in the + output buffer, allowing at least one gzungetc() to succeed */ + } + + /* large len -- read directly into user buffer */ + else if (state->how == COPY) { /* read directly */ + if (gz_load(state, (unsigned char *)buf, len, &n) == -1) + return -1; + } + + /* large len -- decompress directly into user buffer */ + else { /* state->how == GZIP */ + strm->avail_out = len; + strm->next_out = (unsigned char *)buf; + if (gz_decomp(state) == -1) + return -1; + n = state->x.have; + state->x.have = 0; + } + + /* update progress */ + len -= n; + buf = (char *)buf + n; + got += n; + state->x.pos += n; + } while (len); + + /* return number of bytes read into user buffer (will fit in int) */ + return (int)got; +} + +/* -- see zlib.h -- */ +#ifdef Z_PREFIX_SET +# undef z_gzgetc +#else +# undef gzgetc +#endif +int ZEXPORT gzgetc(file) + gzFile file; +{ + int ret; + unsigned char buf[1]; + gz_statep state; + + /* get internal structure */ + if (file == NULL) + return -1; + state = (gz_statep)file; + + /* check that we're reading and that there's no (serious) error */ + if (state->mode != GZ_READ || + (state->err != Z_OK && state->err != Z_BUF_ERROR)) + return -1; + + /* try output buffer (no need to check for skip request) */ + if (state->x.have) { + state->x.have--; + state->x.pos++; + return *(state->x.next)++; + } + + /* nothing there -- try gzread() */ + ret = gzread(file, buf, 1); + return ret < 1 ? -1 : buf[0]; +} + +int ZEXPORT gzgetc_(file) +gzFile file; +{ + return gzgetc(file); +} + +/* -- see zlib.h -- */ +int ZEXPORT gzungetc(c, file) + int c; + gzFile file; +{ + gz_statep state; + + /* get internal structure */ + if (file == NULL) + return -1; + state = (gz_statep)file; + + /* check that we're reading and that there's no (serious) error */ + if (state->mode != GZ_READ || + (state->err != Z_OK && state->err != Z_BUF_ERROR)) + return -1; + + /* process a skip request */ + if (state->seek) { + state->seek = 0; + if (gz_skip(state, state->skip) == -1) + return -1; + } + + /* can't push EOF */ + if (c < 0) + return -1; + + /* if output buffer empty, put byte at end (allows more pushing) */ + if (state->x.have == 0) { + state->x.have = 1; + state->x.next = state->out + (state->size << 1) - 1; + state->x.next[0] = c; + state->x.pos--; + state->past = 0; + return c; + } + + /* if no room, give up (must have already done a gzungetc()) */ + if (state->x.have == (state->size << 1)) { + gz_error(state, Z_DATA_ERROR, "out of room to push characters"); + return -1; + } + + /* slide output data if needed and insert byte before existing data */ + if (state->x.next == state->out) { + unsigned char *src = state->out + state->x.have; + unsigned char *dest = state->out + (state->size << 1); + while (src > state->out) + *--dest = *--src; + state->x.next = dest; + } + state->x.have++; + state->x.next--; + state->x.next[0] = c; + state->x.pos--; + state->past = 0; + return c; +} + +/* -- see zlib.h -- */ +char * ZEXPORT gzgets(file, buf, len) + gzFile file; + char *buf; + int len; +{ + unsigned left, n; + char *str; + unsigned char *eol; + gz_statep state; + + /* check parameters and get internal structure */ + if (file == NULL || buf == NULL || len < 1) + return NULL; + state = (gz_statep)file; + + /* check that we're reading and that there's no (serious) error */ + if (state->mode != GZ_READ || + (state->err != Z_OK && state->err != Z_BUF_ERROR)) + return NULL; + + /* process a skip request */ + if (state->seek) { + state->seek = 0; + if (gz_skip(state, state->skip) == -1) + return NULL; + } + + /* copy output bytes up to new line or len - 1, whichever comes first -- + append a terminating zero to the string (we don't check for a zero in + the contents, let the user worry about that) */ + str = buf; + left = (unsigned)len - 1; + if (left) do { + /* assure that something is in the output buffer */ + if (state->x.have == 0 && gz_fetch(state) == -1) + return NULL; /* error */ + if (state->x.have == 0) { /* end of file */ + state->past = 1; /* read past end */ + break; /* return what we have */ + } + + /* look for end-of-line in current output buffer */ + n = state->x.have > left ? left : state->x.have; + eol = (unsigned char *)memchr(state->x.next, '\n', n); + if (eol != NULL) + n = (unsigned)(eol - state->x.next) + 1; + + /* copy through end-of-line, or remainder if not found */ + memcpy(buf, state->x.next, n); + state->x.have -= n; + state->x.next += n; + state->x.pos += n; + left -= n; + buf += n; + } while (left && eol == NULL); + + /* return terminated string, or if nothing, end of file */ + if (buf == str) + return NULL; + buf[0] = 0; + return str; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzdirect(file) + gzFile file; +{ + gz_statep state; + + /* get internal structure */ + if (file == NULL) + return 0; + state = (gz_statep)file; + + /* if the state is not known, but we can find out, then do so (this is + mainly for right after a gzopen() or gzdopen()) */ + if (state->mode == GZ_READ && state->how == LOOK && state->x.have == 0) + (void)gz_look(state); + + /* return 1 if transparent, 0 if processing a gzip stream */ + return state->direct; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzclose_r(file) + gzFile file; +{ + int ret, err; + gz_statep state; + + /* get internal structure */ + if (file == NULL) + return Z_STREAM_ERROR; + state = (gz_statep)file; + + /* check that we're reading */ + if (state->mode != GZ_READ) + return Z_STREAM_ERROR; + + /* free memory and close file */ + if (state->size) { + inflateEnd(&(state->strm)); + free(state->out); + free(state->in); + } + err = state->err == Z_BUF_ERROR ? Z_BUF_ERROR : Z_OK; + gz_error(state, Z_OK, NULL); + free(state->path); + ret = close(state->fd); + free(state); + return ret ? Z_ERRNO : err; +} diff --git a/ml/dlib/dlib/external/zlib/gzwrite.c b/ml/dlib/dlib/external/zlib/gzwrite.c new file mode 100644 index 000000000..aa767fbf6 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/gzwrite.c @@ -0,0 +1,577 @@ +/* gzwrite.c -- zlib functions for writing gzip files + * Copyright (C) 2004, 2005, 2010, 2011, 2012, 2013 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +#include "gzguts.h" + +/* Local functions */ +local int gz_init OF((gz_statep)); +local int gz_comp OF((gz_statep, int)); +local int gz_zero OF((gz_statep, z_off64_t)); + +/* Initialize state for writing a gzip file. Mark initialization by setting + state->size to non-zero. Return -1 on failure or 0 on success. */ +local int gz_init(state) + gz_statep state; +{ + int ret; + z_streamp strm = &(state->strm); + + /* allocate input buffer */ + state->in = (unsigned char *)malloc(state->want); + if (state->in == NULL) { + gz_error(state, Z_MEM_ERROR, "out of memory"); + return -1; + } + + /* only need output buffer and deflate state if compressing */ + if (!state->direct) { + /* allocate output buffer */ + state->out = (unsigned char *)malloc(state->want); + if (state->out == NULL) { + free(state->in); + gz_error(state, Z_MEM_ERROR, "out of memory"); + return -1; + } + + /* allocate deflate memory, set up for gzip compression */ + strm->zalloc = Z_NULL; + strm->zfree = Z_NULL; + strm->opaque = Z_NULL; + ret = deflateInit2(strm, state->level, Z_DEFLATED, + MAX_WBITS + 16, DEF_MEM_LEVEL, state->strategy); + if (ret != Z_OK) { + free(state->out); + free(state->in); + gz_error(state, Z_MEM_ERROR, "out of memory"); + return -1; + } + } + + /* mark state as initialized */ + state->size = state->want; + + /* initialize write buffer if compressing */ + if (!state->direct) { + strm->avail_out = state->size; + strm->next_out = state->out; + state->x.next = strm->next_out; + } + return 0; +} + +/* Compress whatever is at avail_in and next_in and write to the output file. + Return -1 if there is an error writing to the output file, otherwise 0. + flush is assumed to be a valid deflate() flush value. If flush is Z_FINISH, + then the deflate() state is reset to start a new gzip stream. If gz->direct + is true, then simply write to the output file without compressing, and + ignore flush. */ +local int gz_comp(state, flush) + gz_statep state; + int flush; +{ + int ret, got; + unsigned have; + z_streamp strm = &(state->strm); + + /* allocate memory if this is the first time through */ + if (state->size == 0 && gz_init(state) == -1) + return -1; + + /* write directly if requested */ + if (state->direct) { + got = write(state->fd, strm->next_in, strm->avail_in); + if (got < 0 || (unsigned)got != strm->avail_in) { + gz_error(state, Z_ERRNO, zstrerror()); + return -1; + } + strm->avail_in = 0; + return 0; + } + + /* run deflate() on provided input until it produces no more output */ + ret = Z_OK; + do { + /* write out current buffer contents if full, or if flushing, but if + doing Z_FINISH then don't write until we get to Z_STREAM_END */ + if (strm->avail_out == 0 || (flush != Z_NO_FLUSH && + (flush != Z_FINISH || ret == Z_STREAM_END))) { + have = (unsigned)(strm->next_out - state->x.next); + if (have && ((got = write(state->fd, state->x.next, have)) < 0 || + (unsigned)got != have)) { + gz_error(state, Z_ERRNO, zstrerror()); + return -1; + } + if (strm->avail_out == 0) { + strm->avail_out = state->size; + strm->next_out = state->out; + } + state->x.next = strm->next_out; + } + + /* compress */ + have = strm->avail_out; + ret = deflate(strm, flush); + if (ret == Z_STREAM_ERROR) { + gz_error(state, Z_STREAM_ERROR, + "internal error: deflate stream corrupt"); + return -1; + } + have -= strm->avail_out; + } while (have); + + /* if that completed a deflate stream, allow another to start */ + if (flush == Z_FINISH) + deflateReset(strm); + + /* all done, no errors */ + return 0; +} + +/* Compress len zeros to output. Return -1 on error, 0 on success. */ +local int gz_zero(state, len) + gz_statep state; + z_off64_t len; +{ + int first; + unsigned n; + z_streamp strm = &(state->strm); + + /* consume whatever's left in the input buffer */ + if (strm->avail_in && gz_comp(state, Z_NO_FLUSH) == -1) + return -1; + + /* compress len zeros (len guaranteed > 0) */ + first = 1; + while (len) { + n = GT_OFF(state->size) || (z_off64_t)state->size > len ? + (unsigned)len : state->size; + if (first) { + memset(state->in, 0, n); + first = 0; + } + strm->avail_in = n; + strm->next_in = state->in; + state->x.pos += n; + if (gz_comp(state, Z_NO_FLUSH) == -1) + return -1; + len -= n; + } + return 0; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzwrite(file, buf, len) + gzFile file; + voidpc buf; + unsigned len; +{ + unsigned put = len; + gz_statep state; + z_streamp strm; + + /* get internal structure */ + if (file == NULL) + return 0; + state = (gz_statep)file; + strm = &(state->strm); + + /* check that we're writing and that there's no error */ + if (state->mode != GZ_WRITE || state->err != Z_OK) + return 0; + + /* since an int is returned, make sure len fits in one, otherwise return + with an error (this avoids the flaw in the interface) */ + if ((int)len < 0) { + gz_error(state, Z_DATA_ERROR, "requested length does not fit in int"); + return 0; + } + + /* if len is zero, avoid unnecessary operations */ + if (len == 0) + return 0; + + /* allocate memory if this is the first time through */ + if (state->size == 0 && gz_init(state) == -1) + return 0; + + /* check for seek request */ + if (state->seek) { + state->seek = 0; + if (gz_zero(state, state->skip) == -1) + return 0; + } + + /* for small len, copy to input buffer, otherwise compress directly */ + if (len < state->size) { + /* copy to input buffer, compress when full */ + do { + unsigned have, copy; + + if (strm->avail_in == 0) + strm->next_in = state->in; + have = (unsigned)((strm->next_in + strm->avail_in) - state->in); + copy = state->size - have; + if (copy > len) + copy = len; + memcpy(state->in + have, buf, copy); + strm->avail_in += copy; + state->x.pos += copy; + buf = (const char *)buf + copy; + len -= copy; + if (len && gz_comp(state, Z_NO_FLUSH) == -1) + return 0; + } while (len); + } + else { + /* consume whatever's left in the input buffer */ + if (strm->avail_in && gz_comp(state, Z_NO_FLUSH) == -1) + return 0; + + /* directly compress user buffer to file */ + strm->avail_in = len; + strm->next_in = (z_const Bytef *)buf; + state->x.pos += len; + if (gz_comp(state, Z_NO_FLUSH) == -1) + return 0; + } + + /* input was all buffered or compressed (put will fit in int) */ + return (int)put; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzputc(file, c) + gzFile file; + int c; +{ + unsigned have; + unsigned char buf[1]; + gz_statep state; + z_streamp strm; + + /* get internal structure */ + if (file == NULL) + return -1; + state = (gz_statep)file; + strm = &(state->strm); + + /* check that we're writing and that there's no error */ + if (state->mode != GZ_WRITE || state->err != Z_OK) + return -1; + + /* check for seek request */ + if (state->seek) { + state->seek = 0; + if (gz_zero(state, state->skip) == -1) + return -1; + } + + /* try writing to input buffer for speed (state->size == 0 if buffer not + initialized) */ + if (state->size) { + if (strm->avail_in == 0) + strm->next_in = state->in; + have = (unsigned)((strm->next_in + strm->avail_in) - state->in); + if (have < state->size) { + state->in[have] = c; + strm->avail_in++; + state->x.pos++; + return c & 0xff; + } + } + + /* no room in buffer or not initialized, use gz_write() */ + buf[0] = c; + if (gzwrite(file, buf, 1) != 1) + return -1; + return c & 0xff; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzputs(file, str) + gzFile file; + const char *str; +{ + int ret; + unsigned len; + + /* write string */ + len = (unsigned)strlen(str); + ret = gzwrite(file, str, len); + return ret == 0 && len != 0 ? -1 : ret; +} + +#if defined(STDC) || defined(Z_HAVE_STDARG_H) +#include + +/* -- see zlib.h -- */ +int ZEXPORTVA gzvprintf(gzFile file, const char *format, va_list va) +{ + int size, len; + gz_statep state; + z_streamp strm; + + /* get internal structure */ + if (file == NULL) + return -1; + state = (gz_statep)file; + strm = &(state->strm); + + /* check that we're writing and that there's no error */ + if (state->mode != GZ_WRITE || state->err != Z_OK) + return 0; + + /* make sure we have some buffer space */ + if (state->size == 0 && gz_init(state) == -1) + return 0; + + /* check for seek request */ + if (state->seek) { + state->seek = 0; + if (gz_zero(state, state->skip) == -1) + return 0; + } + + /* consume whatever's left in the input buffer */ + if (strm->avail_in && gz_comp(state, Z_NO_FLUSH) == -1) + return 0; + + /* do the printf() into the input buffer, put length in len */ + size = (int)(state->size); + state->in[size - 1] = 0; +#ifdef NO_vsnprintf +# ifdef HAS_vsprintf_void + (void)vsprintf((char *)(state->in), format, va); + for (len = 0; len < size; len++) + if (state->in[len] == 0) break; +# else + len = vsprintf((char *)(state->in), format, va); +# endif +#else +# ifdef HAS_vsnprintf_void + (void)vsnprintf((char *)(state->in), size, format, va); + len = strlen((char *)(state->in)); +# else + len = vsnprintf((char *)(state->in), size, format, va); +# endif +#endif + + /* check that printf() results fit in buffer */ + if (len <= 0 || len >= (int)size || state->in[size - 1] != 0) + return 0; + + /* update buffer and position, defer compression until needed */ + strm->avail_in = (unsigned)len; + strm->next_in = state->in; + state->x.pos += len; + return len; +} + +int ZEXPORTVA gzprintf(gzFile file, const char *format, ...) +{ + va_list va; + int ret; + + va_start(va, format); + ret = gzvprintf(file, format, va); + va_end(va); + return ret; +} + +#else /* !STDC && !Z_HAVE_STDARG_H */ + +/* -- see zlib.h -- */ +int ZEXPORTVA gzprintf (file, format, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, + a11, a12, a13, a14, a15, a16, a17, a18, a19, a20) + gzFile file; + const char *format; + int a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, + a11, a12, a13, a14, a15, a16, a17, a18, a19, a20; +{ + int size, len; + gz_statep state; + z_streamp strm; + + /* get internal structure */ + if (file == NULL) + return -1; + state = (gz_statep)file; + strm = &(state->strm); + + /* check that can really pass pointer in ints */ + if (sizeof(int) != sizeof(void *)) + return 0; + + /* check that we're writing and that there's no error */ + if (state->mode != GZ_WRITE || state->err != Z_OK) + return 0; + + /* make sure we have some buffer space */ + if (state->size == 0 && gz_init(state) == -1) + return 0; + + /* check for seek request */ + if (state->seek) { + state->seek = 0; + if (gz_zero(state, state->skip) == -1) + return 0; + } + + /* consume whatever's left in the input buffer */ + if (strm->avail_in && gz_comp(state, Z_NO_FLUSH) == -1) + return 0; + + /* do the printf() into the input buffer, put length in len */ + size = (int)(state->size); + state->in[size - 1] = 0; +#ifdef NO_snprintf +# ifdef HAS_sprintf_void + sprintf((char *)(state->in), format, a1, a2, a3, a4, a5, a6, a7, a8, + a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20); + for (len = 0; len < size; len++) + if (state->in[len] == 0) break; +# else + len = sprintf((char *)(state->in), format, a1, a2, a3, a4, a5, a6, a7, a8, + a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20); +# endif +#else +# ifdef HAS_snprintf_void + snprintf((char *)(state->in), size, format, a1, a2, a3, a4, a5, a6, a7, a8, + a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20); + len = strlen((char *)(state->in)); +# else + len = snprintf((char *)(state->in), size, format, a1, a2, a3, a4, a5, a6, + a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, + a19, a20); +# endif +#endif + + /* check that printf() results fit in buffer */ + if (len <= 0 || len >= (int)size || state->in[size - 1] != 0) + return 0; + + /* update buffer and position, defer compression until needed */ + strm->avail_in = (unsigned)len; + strm->next_in = state->in; + state->x.pos += len; + return len; +} + +#endif + +/* -- see zlib.h -- */ +int ZEXPORT gzflush(file, flush) + gzFile file; + int flush; +{ + gz_statep state; + + /* get internal structure */ + if (file == NULL) + return -1; + state = (gz_statep)file; + + /* check that we're writing and that there's no error */ + if (state->mode != GZ_WRITE || state->err != Z_OK) + return Z_STREAM_ERROR; + + /* check flush parameter */ + if (flush < 0 || flush > Z_FINISH) + return Z_STREAM_ERROR; + + /* check for seek request */ + if (state->seek) { + state->seek = 0; + if (gz_zero(state, state->skip) == -1) + return -1; + } + + /* compress remaining data with requested flush */ + gz_comp(state, flush); + return state->err; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzsetparams(file, level, strategy) + gzFile file; + int level; + int strategy; +{ + gz_statep state; + z_streamp strm; + + /* get internal structure */ + if (file == NULL) + return Z_STREAM_ERROR; + state = (gz_statep)file; + strm = &(state->strm); + + /* check that we're writing and that there's no error */ + if (state->mode != GZ_WRITE || state->err != Z_OK) + return Z_STREAM_ERROR; + + /* if no change is requested, then do nothing */ + if (level == state->level && strategy == state->strategy) + return Z_OK; + + /* check for seek request */ + if (state->seek) { + state->seek = 0; + if (gz_zero(state, state->skip) == -1) + return -1; + } + + /* change compression parameters for subsequent input */ + if (state->size) { + /* flush previous input with previous parameters before changing */ + if (strm->avail_in && gz_comp(state, Z_PARTIAL_FLUSH) == -1) + return state->err; + deflateParams(strm, level, strategy); + } + state->level = level; + state->strategy = strategy; + return Z_OK; +} + +/* -- see zlib.h -- */ +int ZEXPORT gzclose_w(file) + gzFile file; +{ + int ret = Z_OK; + gz_statep state; + + /* get internal structure */ + if (file == NULL) + return Z_STREAM_ERROR; + state = (gz_statep)file; + + /* check that we're writing */ + if (state->mode != GZ_WRITE) + return Z_STREAM_ERROR; + + /* check for seek request */ + if (state->seek) { + state->seek = 0; + if (gz_zero(state, state->skip) == -1) + ret = state->err; + } + + /* flush, free memory, and close file */ + if (gz_comp(state, Z_FINISH) == -1) + ret = state->err; + if (state->size) { + if (!state->direct) { + (void)deflateEnd(&(state->strm)); + free(state->out); + } + free(state->in); + } + gz_error(state, Z_OK, NULL); + free(state->path); + if (close(state->fd) == -1) + ret = Z_ERRNO; + free(state); + return ret; +} diff --git a/ml/dlib/dlib/external/zlib/infback.c b/ml/dlib/dlib/external/zlib/infback.c new file mode 100644 index 000000000..f3833c2e4 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/infback.c @@ -0,0 +1,640 @@ +/* infback.c -- inflate using a call-back interface + * Copyright (C) 1995-2011 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* + This code is largely copied from inflate.c. Normally either infback.o or + inflate.o would be linked into an application--not both. The interface + with inffast.c is retained so that optimized assembler-coded versions of + inflate_fast() can be used with either inflate.c or infback.c. + */ + +#include "zutil.h" +#include "inftrees.h" +#include "inflate.h" +#include "inffast.h" + +/* function prototypes */ +local void fixedtables OF((struct inflate_state FAR *state)); + +/* + strm provides memory allocation functions in zalloc and zfree, or + Z_NULL to use the library memory allocation functions. + + windowBits is in the range 8..15, and window is a user-supplied + window and output buffer that is 2**windowBits bytes. + */ +int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size) +z_streamp strm; +int windowBits; +unsigned char FAR *window; +const char *version; +int stream_size; +{ + struct inflate_state FAR *state; + + if (version == Z_NULL || version[0] != ZLIB_VERSION[0] || + stream_size != (int)(sizeof(z_stream))) + return Z_VERSION_ERROR; + if (strm == Z_NULL || window == Z_NULL || + windowBits < 8 || windowBits > 15) + return Z_STREAM_ERROR; + strm->msg = Z_NULL; /* in case we return an error */ + if (strm->zalloc == (alloc_func)0) { +#ifdef Z_SOLO + return Z_STREAM_ERROR; +#else + strm->zalloc = zcalloc; + strm->opaque = (voidpf)0; +#endif + } + if (strm->zfree == (free_func)0) +#ifdef Z_SOLO + return Z_STREAM_ERROR; +#else + strm->zfree = zcfree; +#endif + state = (struct inflate_state FAR *)ZALLOC(strm, 1, + sizeof(struct inflate_state)); + if (state == Z_NULL) return Z_MEM_ERROR; + Tracev((stderr, "inflate: allocated\n")); + strm->state = (struct internal_state FAR *)state; + state->dmax = 32768U; + state->wbits = windowBits; + state->wsize = 1U << windowBits; + state->window = window; + state->wnext = 0; + state->whave = 0; + return Z_OK; +} + +/* + Return state with length and distance decoding tables and index sizes set to + fixed code decoding. Normally this returns fixed tables from inffixed.h. + If BUILDFIXED is defined, then instead this routine builds the tables the + first time it's called, and returns those tables the first time and + thereafter. This reduces the size of the code by about 2K bytes, in + exchange for a little execution time. However, BUILDFIXED should not be + used for threaded applications, since the rewriting of the tables and virgin + may not be thread-safe. + */ +local void fixedtables(state) +struct inflate_state FAR *state; +{ +#ifdef BUILDFIXED + static int virgin = 1; + static code *lenfix, *distfix; + static code fixed[544]; + + /* build fixed huffman tables if first call (may not be thread safe) */ + if (virgin) { + unsigned sym, bits; + static code *next; + + /* literal/length table */ + sym = 0; + while (sym < 144) state->lens[sym++] = 8; + while (sym < 256) state->lens[sym++] = 9; + while (sym < 280) state->lens[sym++] = 7; + while (sym < 288) state->lens[sym++] = 8; + next = fixed; + lenfix = next; + bits = 9; + inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work); + + /* distance table */ + sym = 0; + while (sym < 32) state->lens[sym++] = 5; + distfix = next; + bits = 5; + inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work); + + /* do this just once */ + virgin = 0; + } +#else /* !BUILDFIXED */ +# include "inffixed.h" +#endif /* BUILDFIXED */ + state->lencode = lenfix; + state->lenbits = 9; + state->distcode = distfix; + state->distbits = 5; +} + +/* Macros for inflateBack(): */ + +/* Load returned state from inflate_fast() */ +#define LOAD() \ + do { \ + put = strm->next_out; \ + left = strm->avail_out; \ + next = strm->next_in; \ + have = strm->avail_in; \ + hold = state->hold; \ + bits = state->bits; \ + } while (0) + +/* Set state from registers for inflate_fast() */ +#define RESTORE() \ + do { \ + strm->next_out = put; \ + strm->avail_out = left; \ + strm->next_in = next; \ + strm->avail_in = have; \ + state->hold = hold; \ + state->bits = bits; \ + } while (0) + +/* Clear the input bit accumulator */ +#define INITBITS() \ + do { \ + hold = 0; \ + bits = 0; \ + } while (0) + +/* Assure that some input is available. If input is requested, but denied, + then return a Z_BUF_ERROR from inflateBack(). */ +#define PULL() \ + do { \ + if (have == 0) { \ + have = in(in_desc, &next); \ + if (have == 0) { \ + next = Z_NULL; \ + ret = Z_BUF_ERROR; \ + goto inf_leave; \ + } \ + } \ + } while (0) + +/* Get a byte of input into the bit accumulator, or return from inflateBack() + with an error if there is no input available. */ +#define PULLBYTE() \ + do { \ + PULL(); \ + have--; \ + hold += (unsigned long)(*next++) << bits; \ + bits += 8; \ + } while (0) + +/* Assure that there are at least n bits in the bit accumulator. If there is + not enough available input to do that, then return from inflateBack() with + an error. */ +#define NEEDBITS(n) \ + do { \ + while (bits < (unsigned)(n)) \ + PULLBYTE(); \ + } while (0) + +/* Return the low n bits of the bit accumulator (n < 16) */ +#define BITS(n) \ + ((unsigned)hold & ((1U << (n)) - 1)) + +/* Remove n bits from the bit accumulator */ +#define DROPBITS(n) \ + do { \ + hold >>= (n); \ + bits -= (unsigned)(n); \ + } while (0) + +/* Remove zero to seven bits as needed to go to a byte boundary */ +#define BYTEBITS() \ + do { \ + hold >>= bits & 7; \ + bits -= bits & 7; \ + } while (0) + +/* Assure that some output space is available, by writing out the window + if it's full. If the write fails, return from inflateBack() with a + Z_BUF_ERROR. */ +#define ROOM() \ + do { \ + if (left == 0) { \ + put = state->window; \ + left = state->wsize; \ + state->whave = left; \ + if (out(out_desc, put, left)) { \ + ret = Z_BUF_ERROR; \ + goto inf_leave; \ + } \ + } \ + } while (0) + +/* + strm provides the memory allocation functions and window buffer on input, + and provides information on the unused input on return. For Z_DATA_ERROR + returns, strm will also provide an error message. + + in() and out() are the call-back input and output functions. When + inflateBack() needs more input, it calls in(). When inflateBack() has + filled the window with output, or when it completes with data in the + window, it calls out() to write out the data. The application must not + change the provided input until in() is called again or inflateBack() + returns. The application must not change the window/output buffer until + inflateBack() returns. + + in() and out() are called with a descriptor parameter provided in the + inflateBack() call. This parameter can be a structure that provides the + information required to do the read or write, as well as accumulated + information on the input and output such as totals and check values. + + in() should return zero on failure. out() should return non-zero on + failure. If either in() or out() fails, than inflateBack() returns a + Z_BUF_ERROR. strm->next_in can be checked for Z_NULL to see whether it + was in() or out() that caused in the error. Otherwise, inflateBack() + returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format + error, or Z_MEM_ERROR if it could not allocate memory for the state. + inflateBack() can also return Z_STREAM_ERROR if the input parameters + are not correct, i.e. strm is Z_NULL or the state was not initialized. + */ +int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc) +z_streamp strm; +in_func in; +void FAR *in_desc; +out_func out; +void FAR *out_desc; +{ + struct inflate_state FAR *state; + z_const unsigned char FAR *next; /* next input */ + unsigned char FAR *put; /* next output */ + unsigned have, left; /* available input and output */ + unsigned long hold; /* bit buffer */ + unsigned bits; /* bits in bit buffer */ + unsigned copy; /* number of stored or match bytes to copy */ + unsigned char FAR *from; /* where to copy match bytes from */ + code here; /* current decoding table entry */ + code last; /* parent table entry */ + unsigned len; /* length to copy for repeats, bits to drop */ + int ret; /* return code */ + static const unsigned short order[19] = /* permutation of code lengths */ + {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + + /* Check that the strm exists and that the state was initialized */ + if (strm == Z_NULL || strm->state == Z_NULL) + return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + + /* Reset the state */ + strm->msg = Z_NULL; + state->mode = TYPE; + state->last = 0; + state->whave = 0; + next = strm->next_in; + have = next != Z_NULL ? strm->avail_in : 0; + hold = 0; + bits = 0; + put = state->window; + left = state->wsize; + + /* Inflate until end of block marked as last */ + for (;;) + switch (state->mode) { + case TYPE: + /* determine and dispatch block type */ + if (state->last) { + BYTEBITS(); + state->mode = DONE; + break; + } + NEEDBITS(3); + state->last = BITS(1); + DROPBITS(1); + switch (BITS(2)) { + case 0: /* stored block */ + Tracev((stderr, "inflate: stored block%s\n", + state->last ? " (last)" : "")); + state->mode = STORED; + break; + case 1: /* fixed block */ + fixedtables(state); + Tracev((stderr, "inflate: fixed codes block%s\n", + state->last ? " (last)" : "")); + state->mode = LEN; /* decode codes */ + break; + case 2: /* dynamic block */ + Tracev((stderr, "inflate: dynamic codes block%s\n", + state->last ? " (last)" : "")); + state->mode = TABLE; + break; + case 3: + strm->msg = (char *)"invalid block type"; + state->mode = BAD; + } + DROPBITS(2); + break; + + case STORED: + /* get and verify stored block length */ + BYTEBITS(); /* go to byte boundary */ + NEEDBITS(32); + if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) { + strm->msg = (char *)"invalid stored block lengths"; + state->mode = BAD; + break; + } + state->length = (unsigned)hold & 0xffff; + Tracev((stderr, "inflate: stored length %u\n", + state->length)); + INITBITS(); + + /* copy stored block from input to output */ + while (state->length != 0) { + copy = state->length; + PULL(); + ROOM(); + if (copy > have) copy = have; + if (copy > left) copy = left; + zmemcpy(put, next, copy); + have -= copy; + next += copy; + left -= copy; + put += copy; + state->length -= copy; + } + Tracev((stderr, "inflate: stored end\n")); + state->mode = TYPE; + break; + + case TABLE: + /* get dynamic table entries descriptor */ + NEEDBITS(14); + state->nlen = BITS(5) + 257; + DROPBITS(5); + state->ndist = BITS(5) + 1; + DROPBITS(5); + state->ncode = BITS(4) + 4; + DROPBITS(4); +#ifndef PKZIP_BUG_WORKAROUND + if (state->nlen > 286 || state->ndist > 30) { + strm->msg = (char *)"too many length or distance symbols"; + state->mode = BAD; + break; + } +#endif + Tracev((stderr, "inflate: table sizes ok\n")); + + /* get code length code lengths (not a typo) */ + state->have = 0; + while (state->have < state->ncode) { + NEEDBITS(3); + state->lens[order[state->have++]] = (unsigned short)BITS(3); + DROPBITS(3); + } + while (state->have < 19) + state->lens[order[state->have++]] = 0; + state->next = state->codes; + state->lencode = (code const FAR *)(state->next); + state->lenbits = 7; + ret = inflate_table(CODES, state->lens, 19, &(state->next), + &(state->lenbits), state->work); + if (ret) { + strm->msg = (char *)"invalid code lengths set"; + state->mode = BAD; + break; + } + Tracev((stderr, "inflate: code lengths ok\n")); + + /* get length and distance code code lengths */ + state->have = 0; + while (state->have < state->nlen + state->ndist) { + for (;;) { + here = state->lencode[BITS(state->lenbits)]; + if ((unsigned)(here.bits) <= bits) break; + PULLBYTE(); + } + if (here.val < 16) { + DROPBITS(here.bits); + state->lens[state->have++] = here.val; + } + else { + if (here.val == 16) { + NEEDBITS(here.bits + 2); + DROPBITS(here.bits); + if (state->have == 0) { + strm->msg = (char *)"invalid bit length repeat"; + state->mode = BAD; + break; + } + len = (unsigned)(state->lens[state->have - 1]); + copy = 3 + BITS(2); + DROPBITS(2); + } + else if (here.val == 17) { + NEEDBITS(here.bits + 3); + DROPBITS(here.bits); + len = 0; + copy = 3 + BITS(3); + DROPBITS(3); + } + else { + NEEDBITS(here.bits + 7); + DROPBITS(here.bits); + len = 0; + copy = 11 + BITS(7); + DROPBITS(7); + } + if (state->have + copy > state->nlen + state->ndist) { + strm->msg = (char *)"invalid bit length repeat"; + state->mode = BAD; + break; + } + while (copy--) + state->lens[state->have++] = (unsigned short)len; + } + } + + /* handle error breaks in while */ + if (state->mode == BAD) break; + + /* check for end-of-block code (better have one) */ + if (state->lens[256] == 0) { + strm->msg = (char *)"invalid code -- missing end-of-block"; + state->mode = BAD; + break; + } + + /* build code tables -- note: do not change the lenbits or distbits + values here (9 and 6) without reading the comments in inftrees.h + concerning the ENOUGH constants, which depend on those values */ + state->next = state->codes; + state->lencode = (code const FAR *)(state->next); + state->lenbits = 9; + ret = inflate_table(LENS, state->lens, state->nlen, &(state->next), + &(state->lenbits), state->work); + if (ret) { + strm->msg = (char *)"invalid literal/lengths set"; + state->mode = BAD; + break; + } + state->distcode = (code const FAR *)(state->next); + state->distbits = 6; + ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist, + &(state->next), &(state->distbits), state->work); + if (ret) { + strm->msg = (char *)"invalid distances set"; + state->mode = BAD; + break; + } + Tracev((stderr, "inflate: codes ok\n")); + state->mode = LEN; + + case LEN: + /* use inflate_fast() if we have enough input and output */ + if (have >= 6 && left >= 258) { + RESTORE(); + if (state->whave < state->wsize) + state->whave = state->wsize - left; + inflate_fast(strm, state->wsize); + LOAD(); + break; + } + + /* get a literal, length, or end-of-block code */ + for (;;) { + here = state->lencode[BITS(state->lenbits)]; + if ((unsigned)(here.bits) <= bits) break; + PULLBYTE(); + } + if (here.op && (here.op & 0xf0) == 0) { + last = here; + for (;;) { + here = state->lencode[last.val + + (BITS(last.bits + last.op) >> last.bits)]; + if ((unsigned)(last.bits + here.bits) <= bits) break; + PULLBYTE(); + } + DROPBITS(last.bits); + } + DROPBITS(here.bits); + state->length = (unsigned)here.val; + + /* process literal */ + if (here.op == 0) { + Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ? + "inflate: literal '%c'\n" : + "inflate: literal 0x%02x\n", here.val)); + ROOM(); + *put++ = (unsigned char)(state->length); + left--; + state->mode = LEN; + break; + } + + /* process end of block */ + if (here.op & 32) { + Tracevv((stderr, "inflate: end of block\n")); + state->mode = TYPE; + break; + } + + /* invalid code */ + if (here.op & 64) { + strm->msg = (char *)"invalid literal/length code"; + state->mode = BAD; + break; + } + + /* length code -- get extra bits, if any */ + state->extra = (unsigned)(here.op) & 15; + if (state->extra != 0) { + NEEDBITS(state->extra); + state->length += BITS(state->extra); + DROPBITS(state->extra); + } + Tracevv((stderr, "inflate: length %u\n", state->length)); + + /* get distance code */ + for (;;) { + here = state->distcode[BITS(state->distbits)]; + if ((unsigned)(here.bits) <= bits) break; + PULLBYTE(); + } + if ((here.op & 0xf0) == 0) { + last = here; + for (;;) { + here = state->distcode[last.val + + (BITS(last.bits + last.op) >> last.bits)]; + if ((unsigned)(last.bits + here.bits) <= bits) break; + PULLBYTE(); + } + DROPBITS(last.bits); + } + DROPBITS(here.bits); + if (here.op & 64) { + strm->msg = (char *)"invalid distance code"; + state->mode = BAD; + break; + } + state->offset = (unsigned)here.val; + + /* get distance extra bits, if any */ + state->extra = (unsigned)(here.op) & 15; + if (state->extra != 0) { + NEEDBITS(state->extra); + state->offset += BITS(state->extra); + DROPBITS(state->extra); + } + if (state->offset > state->wsize - (state->whave < state->wsize ? + left : 0)) { + strm->msg = (char *)"invalid distance too far back"; + state->mode = BAD; + break; + } + Tracevv((stderr, "inflate: distance %u\n", state->offset)); + + /* copy match from window to output */ + do { + ROOM(); + copy = state->wsize - state->offset; + if (copy < left) { + from = put + copy; + copy = left - copy; + } + else { + from = put - state->offset; + copy = left; + } + if (copy > state->length) copy = state->length; + state->length -= copy; + left -= copy; + do { + *put++ = *from++; + } while (--copy); + } while (state->length != 0); + break; + + case DONE: + /* inflate stream terminated properly -- write leftover output */ + ret = Z_STREAM_END; + if (left < state->wsize) { + if (out(out_desc, state->window, state->wsize - left)) + ret = Z_BUF_ERROR; + } + goto inf_leave; + + case BAD: + ret = Z_DATA_ERROR; + goto inf_leave; + + default: /* can't happen, but makes compilers happy */ + ret = Z_STREAM_ERROR; + goto inf_leave; + } + + /* Return unused input */ + inf_leave: + strm->next_in = next; + strm->avail_in = have; + return ret; +} + +int ZEXPORT inflateBackEnd(strm) +z_streamp strm; +{ + if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0) + return Z_STREAM_ERROR; + ZFREE(strm, strm->state); + strm->state = Z_NULL; + Tracev((stderr, "inflate: end\n")); + return Z_OK; +} diff --git a/ml/dlib/dlib/external/zlib/inffast.c b/ml/dlib/dlib/external/zlib/inffast.c new file mode 100644 index 000000000..bda59ceb6 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/inffast.c @@ -0,0 +1,340 @@ +/* inffast.c -- fast decoding + * Copyright (C) 1995-2008, 2010, 2013 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +#include "zutil.h" +#include "inftrees.h" +#include "inflate.h" +#include "inffast.h" + +#ifndef ASMINF + +/* Allow machine dependent optimization for post-increment or pre-increment. + Based on testing to date, + Pre-increment preferred for: + - PowerPC G3 (Adler) + - MIPS R5000 (Randers-Pehrson) + Post-increment preferred for: + - none + No measurable difference: + - Pentium III (Anderson) + - M68060 (Nikl) + */ +#ifdef POSTINC +# define OFF 0 +# define PUP(a) *(a)++ +#else +# define OFF 1 +# define PUP(a) *++(a) +#endif + +/* + Decode literal, length, and distance codes and write out the resulting + literal and match bytes until either not enough input or output is + available, an end-of-block is encountered, or a data error is encountered. + When large enough input and output buffers are supplied to inflate(), for + example, a 16K input buffer and a 64K output buffer, more than 95% of the + inflate execution time is spent in this routine. + + Entry assumptions: + + state->mode == LEN + strm->avail_in >= 6 + strm->avail_out >= 258 + start >= strm->avail_out + state->bits < 8 + + On return, state->mode is one of: + + LEN -- ran out of enough output space or enough available input + TYPE -- reached end of block code, inflate() to interpret next block + BAD -- error in block data + + Notes: + + - The maximum input bits used by a length/distance pair is 15 bits for the + length code, 5 bits for the length extra, 15 bits for the distance code, + and 13 bits for the distance extra. This totals 48 bits, or six bytes. + Therefore if strm->avail_in >= 6, then there is enough input to avoid + checking for available input while decoding. + + - The maximum bytes that a single length/distance pair can output is 258 + bytes, which is the maximum length that can be coded. inflate_fast() + requires strm->avail_out >= 258 for each loop to avoid checking for + output space. + */ +void ZLIB_INTERNAL inflate_fast(strm, start) +z_streamp strm; +unsigned start; /* inflate()'s starting value for strm->avail_out */ +{ + struct inflate_state FAR *state; + z_const unsigned char FAR *in; /* local strm->next_in */ + z_const unsigned char FAR *last; /* have enough input while in < last */ + unsigned char FAR *out; /* local strm->next_out */ + unsigned char FAR *beg; /* inflate()'s initial strm->next_out */ + unsigned char FAR *end; /* while out < end, enough space available */ +#ifdef INFLATE_STRICT + unsigned dmax; /* maximum distance from zlib header */ +#endif + unsigned wsize; /* window size or zero if not using window */ + unsigned whave; /* valid bytes in the window */ + unsigned wnext; /* window write index */ + unsigned char FAR *window; /* allocated sliding window, if wsize != 0 */ + unsigned long hold; /* local strm->hold */ + unsigned bits; /* local strm->bits */ + code const FAR *lcode; /* local strm->lencode */ + code const FAR *dcode; /* local strm->distcode */ + unsigned lmask; /* mask for first level of length codes */ + unsigned dmask; /* mask for first level of distance codes */ + code here; /* retrieved table entry */ + unsigned op; /* code bits, operation, extra bits, or */ + /* window position, window bytes to copy */ + unsigned len; /* match length, unused bytes */ + unsigned dist; /* match distance */ + unsigned char FAR *from; /* where to copy match from */ + + /* copy state to local variables */ + state = (struct inflate_state FAR *)strm->state; + in = strm->next_in - OFF; + last = in + (strm->avail_in - 5); + out = strm->next_out - OFF; + beg = out - (start - strm->avail_out); + end = out + (strm->avail_out - 257); +#ifdef INFLATE_STRICT + dmax = state->dmax; +#endif + wsize = state->wsize; + whave = state->whave; + wnext = state->wnext; + window = state->window; + hold = state->hold; + bits = state->bits; + lcode = state->lencode; + dcode = state->distcode; + lmask = (1U << state->lenbits) - 1; + dmask = (1U << state->distbits) - 1; + + /* decode literals and length/distances until end-of-block or not enough + input data or output space */ + do { + if (bits < 15) { + hold += (unsigned long)(PUP(in)) << bits; + bits += 8; + hold += (unsigned long)(PUP(in)) << bits; + bits += 8; + } + here = lcode[hold & lmask]; + dolen: + op = (unsigned)(here.bits); + hold >>= op; + bits -= op; + op = (unsigned)(here.op); + if (op == 0) { /* literal */ + Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ? + "inflate: literal '%c'\n" : + "inflate: literal 0x%02x\n", here.val)); + PUP(out) = (unsigned char)(here.val); + } + else if (op & 16) { /* length base */ + len = (unsigned)(here.val); + op &= 15; /* number of extra bits */ + if (op) { + if (bits < op) { + hold += (unsigned long)(PUP(in)) << bits; + bits += 8; + } + len += (unsigned)hold & ((1U << op) - 1); + hold >>= op; + bits -= op; + } + Tracevv((stderr, "inflate: length %u\n", len)); + if (bits < 15) { + hold += (unsigned long)(PUP(in)) << bits; + bits += 8; + hold += (unsigned long)(PUP(in)) << bits; + bits += 8; + } + here = dcode[hold & dmask]; + dodist: + op = (unsigned)(here.bits); + hold >>= op; + bits -= op; + op = (unsigned)(here.op); + if (op & 16) { /* distance base */ + dist = (unsigned)(here.val); + op &= 15; /* number of extra bits */ + if (bits < op) { + hold += (unsigned long)(PUP(in)) << bits; + bits += 8; + if (bits < op) { + hold += (unsigned long)(PUP(in)) << bits; + bits += 8; + } + } + dist += (unsigned)hold & ((1U << op) - 1); +#ifdef INFLATE_STRICT + if (dist > dmax) { + strm->msg = (char *)"invalid distance too far back"; + state->mode = BAD; + break; + } +#endif + hold >>= op; + bits -= op; + Tracevv((stderr, "inflate: distance %u\n", dist)); + op = (unsigned)(out - beg); /* max distance in output */ + if (dist > op) { /* see if copy from window */ + op = dist - op; /* distance back in window */ + if (op > whave) { + if (state->sane) { + strm->msg = + (char *)"invalid distance too far back"; + state->mode = BAD; + break; + } +#ifdef INFLATE_ALLOW_INVALID_DISTANCE_TOOFAR_ARRR + if (len <= op - whave) { + do { + PUP(out) = 0; + } while (--len); + continue; + } + len -= op - whave; + do { + PUP(out) = 0; + } while (--op > whave); + if (op == 0) { + from = out - dist; + do { + PUP(out) = PUP(from); + } while (--len); + continue; + } +#endif + } + from = window - OFF; + if (wnext == 0) { /* very common case */ + from += wsize - op; + if (op < len) { /* some from window */ + len -= op; + do { + PUP(out) = PUP(from); + } while (--op); + from = out - dist; /* rest from output */ + } + } + else if (wnext < op) { /* wrap around window */ + from += wsize + wnext - op; + op -= wnext; + if (op < len) { /* some from end of window */ + len -= op; + do { + PUP(out) = PUP(from); + } while (--op); + from = window - OFF; + if (wnext < len) { /* some from start of window */ + op = wnext; + len -= op; + do { + PUP(out) = PUP(from); + } while (--op); + from = out - dist; /* rest from output */ + } + } + } + else { /* contiguous in window */ + from += wnext - op; + if (op < len) { /* some from window */ + len -= op; + do { + PUP(out) = PUP(from); + } while (--op); + from = out - dist; /* rest from output */ + } + } + while (len > 2) { + PUP(out) = PUP(from); + PUP(out) = PUP(from); + PUP(out) = PUP(from); + len -= 3; + } + if (len) { + PUP(out) = PUP(from); + if (len > 1) + PUP(out) = PUP(from); + } + } + else { + from = out - dist; /* copy direct from output */ + do { /* minimum length is three */ + PUP(out) = PUP(from); + PUP(out) = PUP(from); + PUP(out) = PUP(from); + len -= 3; + } while (len > 2); + if (len) { + PUP(out) = PUP(from); + if (len > 1) + PUP(out) = PUP(from); + } + } + } + else if ((op & 64) == 0) { /* 2nd level distance code */ + here = dcode[here.val + (hold & ((1U << op) - 1))]; + goto dodist; + } + else { + strm->msg = (char *)"invalid distance code"; + state->mode = BAD; + break; + } + } + else if ((op & 64) == 0) { /* 2nd level length code */ + here = lcode[here.val + (hold & ((1U << op) - 1))]; + goto dolen; + } + else if (op & 32) { /* end-of-block */ + Tracevv((stderr, "inflate: end of block\n")); + state->mode = TYPE; + break; + } + else { + strm->msg = (char *)"invalid literal/length code"; + state->mode = BAD; + break; + } + } while (in < last && out < end); + + /* return unused bytes (on entry, bits < 8, so in won't go too far back) */ + len = bits >> 3; + in -= len; + bits -= len << 3; + hold &= (1U << bits) - 1; + + /* update state and return */ + strm->next_in = in + OFF; + strm->next_out = out + OFF; + strm->avail_in = (unsigned)(in < last ? 5 + (last - in) : 5 - (in - last)); + strm->avail_out = (unsigned)(out < end ? + 257 + (end - out) : 257 - (out - end)); + state->hold = hold; + state->bits = bits; + return; +} + +/* + inflate_fast() speedups that turned out slower (on a PowerPC G3 750CXe): + - Using bit fields for code structure + - Different op definition to avoid & for extra bits (do & for table bits) + - Three separate decoding do-loops for direct, window, and wnext == 0 + - Special case for distance > 1 copies to do overlapped load and store copy + - Explicit branch predictions (based on measured branch probabilities) + - Deferring match copy and interspersed it with decoding subsequent codes + - Swapping literal/length else + - Swapping window/direct else + - Larger unrolled copy loops (three is about right) + - Moving len -= 3 statement into middle of loop + */ + +#endif /* !ASMINF */ diff --git a/ml/dlib/dlib/external/zlib/inffast.h b/ml/dlib/dlib/external/zlib/inffast.h new file mode 100644 index 000000000..e5c1aa4ca --- /dev/null +++ b/ml/dlib/dlib/external/zlib/inffast.h @@ -0,0 +1,11 @@ +/* inffast.h -- header to use inffast.c + * Copyright (C) 1995-2003, 2010 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* WARNING: this file should *not* be used by applications. It is + part of the implementation of the compression library and is + subject to change. Applications should only use zlib.h. + */ + +void ZLIB_INTERNAL inflate_fast OF((z_streamp strm, unsigned start)); diff --git a/ml/dlib/dlib/external/zlib/inffixed.h b/ml/dlib/dlib/external/zlib/inffixed.h new file mode 100644 index 000000000..d62832776 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/inffixed.h @@ -0,0 +1,94 @@ + /* inffixed.h -- table for decoding fixed codes + * Generated automatically by makefixed(). + */ + + /* WARNING: this file should *not* be used by applications. + It is part of the implementation of this library and is + subject to change. Applications should only use zlib.h. + */ + + static const code lenfix[512] = { + {96,7,0},{0,8,80},{0,8,16},{20,8,115},{18,7,31},{0,8,112},{0,8,48}, + {0,9,192},{16,7,10},{0,8,96},{0,8,32},{0,9,160},{0,8,0},{0,8,128}, + {0,8,64},{0,9,224},{16,7,6},{0,8,88},{0,8,24},{0,9,144},{19,7,59}, + {0,8,120},{0,8,56},{0,9,208},{17,7,17},{0,8,104},{0,8,40},{0,9,176}, + {0,8,8},{0,8,136},{0,8,72},{0,9,240},{16,7,4},{0,8,84},{0,8,20}, + {21,8,227},{19,7,43},{0,8,116},{0,8,52},{0,9,200},{17,7,13},{0,8,100}, + {0,8,36},{0,9,168},{0,8,4},{0,8,132},{0,8,68},{0,9,232},{16,7,8}, + {0,8,92},{0,8,28},{0,9,152},{20,7,83},{0,8,124},{0,8,60},{0,9,216}, + {18,7,23},{0,8,108},{0,8,44},{0,9,184},{0,8,12},{0,8,140},{0,8,76}, + {0,9,248},{16,7,3},{0,8,82},{0,8,18},{21,8,163},{19,7,35},{0,8,114}, + {0,8,50},{0,9,196},{17,7,11},{0,8,98},{0,8,34},{0,9,164},{0,8,2}, + {0,8,130},{0,8,66},{0,9,228},{16,7,7},{0,8,90},{0,8,26},{0,9,148}, + {20,7,67},{0,8,122},{0,8,58},{0,9,212},{18,7,19},{0,8,106},{0,8,42}, + {0,9,180},{0,8,10},{0,8,138},{0,8,74},{0,9,244},{16,7,5},{0,8,86}, + {0,8,22},{64,8,0},{19,7,51},{0,8,118},{0,8,54},{0,9,204},{17,7,15}, + {0,8,102},{0,8,38},{0,9,172},{0,8,6},{0,8,134},{0,8,70},{0,9,236}, + {16,7,9},{0,8,94},{0,8,30},{0,9,156},{20,7,99},{0,8,126},{0,8,62}, + {0,9,220},{18,7,27},{0,8,110},{0,8,46},{0,9,188},{0,8,14},{0,8,142}, + {0,8,78},{0,9,252},{96,7,0},{0,8,81},{0,8,17},{21,8,131},{18,7,31}, + {0,8,113},{0,8,49},{0,9,194},{16,7,10},{0,8,97},{0,8,33},{0,9,162}, + {0,8,1},{0,8,129},{0,8,65},{0,9,226},{16,7,6},{0,8,89},{0,8,25}, + {0,9,146},{19,7,59},{0,8,121},{0,8,57},{0,9,210},{17,7,17},{0,8,105}, + {0,8,41},{0,9,178},{0,8,9},{0,8,137},{0,8,73},{0,9,242},{16,7,4}, + {0,8,85},{0,8,21},{16,8,258},{19,7,43},{0,8,117},{0,8,53},{0,9,202}, + {17,7,13},{0,8,101},{0,8,37},{0,9,170},{0,8,5},{0,8,133},{0,8,69}, + {0,9,234},{16,7,8},{0,8,93},{0,8,29},{0,9,154},{20,7,83},{0,8,125}, + {0,8,61},{0,9,218},{18,7,23},{0,8,109},{0,8,45},{0,9,186},{0,8,13}, + {0,8,141},{0,8,77},{0,9,250},{16,7,3},{0,8,83},{0,8,19},{21,8,195}, + {19,7,35},{0,8,115},{0,8,51},{0,9,198},{17,7,11},{0,8,99},{0,8,35}, + {0,9,166},{0,8,3},{0,8,131},{0,8,67},{0,9,230},{16,7,7},{0,8,91}, + {0,8,27},{0,9,150},{20,7,67},{0,8,123},{0,8,59},{0,9,214},{18,7,19}, + {0,8,107},{0,8,43},{0,9,182},{0,8,11},{0,8,139},{0,8,75},{0,9,246}, + {16,7,5},{0,8,87},{0,8,23},{64,8,0},{19,7,51},{0,8,119},{0,8,55}, + {0,9,206},{17,7,15},{0,8,103},{0,8,39},{0,9,174},{0,8,7},{0,8,135}, + {0,8,71},{0,9,238},{16,7,9},{0,8,95},{0,8,31},{0,9,158},{20,7,99}, + {0,8,127},{0,8,63},{0,9,222},{18,7,27},{0,8,111},{0,8,47},{0,9,190}, + {0,8,15},{0,8,143},{0,8,79},{0,9,254},{96,7,0},{0,8,80},{0,8,16}, + {20,8,115},{18,7,31},{0,8,112},{0,8,48},{0,9,193},{16,7,10},{0,8,96}, + {0,8,32},{0,9,161},{0,8,0},{0,8,128},{0,8,64},{0,9,225},{16,7,6}, + {0,8,88},{0,8,24},{0,9,145},{19,7,59},{0,8,120},{0,8,56},{0,9,209}, + {17,7,17},{0,8,104},{0,8,40},{0,9,177},{0,8,8},{0,8,136},{0,8,72}, + {0,9,241},{16,7,4},{0,8,84},{0,8,20},{21,8,227},{19,7,43},{0,8,116}, + {0,8,52},{0,9,201},{17,7,13},{0,8,100},{0,8,36},{0,9,169},{0,8,4}, + {0,8,132},{0,8,68},{0,9,233},{16,7,8},{0,8,92},{0,8,28},{0,9,153}, + {20,7,83},{0,8,124},{0,8,60},{0,9,217},{18,7,23},{0,8,108},{0,8,44}, + {0,9,185},{0,8,12},{0,8,140},{0,8,76},{0,9,249},{16,7,3},{0,8,82}, + {0,8,18},{21,8,163},{19,7,35},{0,8,114},{0,8,50},{0,9,197},{17,7,11}, + {0,8,98},{0,8,34},{0,9,165},{0,8,2},{0,8,130},{0,8,66},{0,9,229}, + {16,7,7},{0,8,90},{0,8,26},{0,9,149},{20,7,67},{0,8,122},{0,8,58}, + {0,9,213},{18,7,19},{0,8,106},{0,8,42},{0,9,181},{0,8,10},{0,8,138}, + {0,8,74},{0,9,245},{16,7,5},{0,8,86},{0,8,22},{64,8,0},{19,7,51}, + {0,8,118},{0,8,54},{0,9,205},{17,7,15},{0,8,102},{0,8,38},{0,9,173}, + {0,8,6},{0,8,134},{0,8,70},{0,9,237},{16,7,9},{0,8,94},{0,8,30}, + {0,9,157},{20,7,99},{0,8,126},{0,8,62},{0,9,221},{18,7,27},{0,8,110}, + {0,8,46},{0,9,189},{0,8,14},{0,8,142},{0,8,78},{0,9,253},{96,7,0}, + {0,8,81},{0,8,17},{21,8,131},{18,7,31},{0,8,113},{0,8,49},{0,9,195}, + {16,7,10},{0,8,97},{0,8,33},{0,9,163},{0,8,1},{0,8,129},{0,8,65}, + {0,9,227},{16,7,6},{0,8,89},{0,8,25},{0,9,147},{19,7,59},{0,8,121}, + {0,8,57},{0,9,211},{17,7,17},{0,8,105},{0,8,41},{0,9,179},{0,8,9}, + {0,8,137},{0,8,73},{0,9,243},{16,7,4},{0,8,85},{0,8,21},{16,8,258}, + {19,7,43},{0,8,117},{0,8,53},{0,9,203},{17,7,13},{0,8,101},{0,8,37}, + {0,9,171},{0,8,5},{0,8,133},{0,8,69},{0,9,235},{16,7,8},{0,8,93}, + {0,8,29},{0,9,155},{20,7,83},{0,8,125},{0,8,61},{0,9,219},{18,7,23}, + {0,8,109},{0,8,45},{0,9,187},{0,8,13},{0,8,141},{0,8,77},{0,9,251}, + {16,7,3},{0,8,83},{0,8,19},{21,8,195},{19,7,35},{0,8,115},{0,8,51}, + {0,9,199},{17,7,11},{0,8,99},{0,8,35},{0,9,167},{0,8,3},{0,8,131}, + {0,8,67},{0,9,231},{16,7,7},{0,8,91},{0,8,27},{0,9,151},{20,7,67}, + {0,8,123},{0,8,59},{0,9,215},{18,7,19},{0,8,107},{0,8,43},{0,9,183}, + {0,8,11},{0,8,139},{0,8,75},{0,9,247},{16,7,5},{0,8,87},{0,8,23}, + {64,8,0},{19,7,51},{0,8,119},{0,8,55},{0,9,207},{17,7,15},{0,8,103}, + {0,8,39},{0,9,175},{0,8,7},{0,8,135},{0,8,71},{0,9,239},{16,7,9}, + {0,8,95},{0,8,31},{0,9,159},{20,7,99},{0,8,127},{0,8,63},{0,9,223}, + {18,7,27},{0,8,111},{0,8,47},{0,9,191},{0,8,15},{0,8,143},{0,8,79}, + {0,9,255} + }; + + static const code distfix[32] = { + {16,5,1},{23,5,257},{19,5,17},{27,5,4097},{17,5,5},{25,5,1025}, + {21,5,65},{29,5,16385},{16,5,3},{24,5,513},{20,5,33},{28,5,8193}, + {18,5,9},{26,5,2049},{22,5,129},{64,5,0},{16,5,2},{23,5,385}, + {19,5,25},{27,5,6145},{17,5,7},{25,5,1537},{21,5,97},{29,5,24577}, + {16,5,4},{24,5,769},{20,5,49},{28,5,12289},{18,5,13},{26,5,3073}, + {22,5,193},{64,5,0} + }; diff --git a/ml/dlib/dlib/external/zlib/inflate.c b/ml/dlib/dlib/external/zlib/inflate.c new file mode 100644 index 000000000..870f89bb4 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/inflate.c @@ -0,0 +1,1512 @@ +/* inflate.c -- zlib decompression + * Copyright (C) 1995-2012 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* + * Change history: + * + * 1.2.beta0 24 Nov 2002 + * - First version -- complete rewrite of inflate to simplify code, avoid + * creation of window when not needed, minimize use of window when it is + * needed, make inffast.c even faster, implement gzip decoding, and to + * improve code readability and style over the previous zlib inflate code + * + * 1.2.beta1 25 Nov 2002 + * - Use pointers for available input and output checking in inffast.c + * - Remove input and output counters in inffast.c + * - Change inffast.c entry and loop from avail_in >= 7 to >= 6 + * - Remove unnecessary second byte pull from length extra in inffast.c + * - Unroll direct copy to three copies per loop in inffast.c + * + * 1.2.beta2 4 Dec 2002 + * - Change external routine names to reduce potential conflicts + * - Correct filename to inffixed.h for fixed tables in inflate.c + * - Make hbuf[] unsigned char to match parameter type in inflate.c + * - Change strm->next_out[-state->offset] to *(strm->next_out - state->offset) + * to avoid negation problem on Alphas (64 bit) in inflate.c + * + * 1.2.beta3 22 Dec 2002 + * - Add comments on state->bits assertion in inffast.c + * - Add comments on op field in inftrees.h + * - Fix bug in reuse of allocated window after inflateReset() + * - Remove bit fields--back to byte structure for speed + * - Remove distance extra == 0 check in inflate_fast()--only helps for lengths + * - Change post-increments to pre-increments in inflate_fast(), PPC biased? + * - Add compile time option, POSTINC, to use post-increments instead (Intel?) + * - Make MATCH copy in inflate() much faster for when inflate_fast() not used + * - Use local copies of stream next and avail values, as well as local bit + * buffer and bit count in inflate()--for speed when inflate_fast() not used + * + * 1.2.beta4 1 Jan 2003 + * - Split ptr - 257 statements in inflate_table() to avoid compiler warnings + * - Move a comment on output buffer sizes from inffast.c to inflate.c + * - Add comments in inffast.c to introduce the inflate_fast() routine + * - Rearrange window copies in inflate_fast() for speed and simplification + * - Unroll last copy for window match in inflate_fast() + * - Use local copies of window variables in inflate_fast() for speed + * - Pull out common wnext == 0 case for speed in inflate_fast() + * - Make op and len in inflate_fast() unsigned for consistency + * - Add FAR to lcode and dcode declarations in inflate_fast() + * - Simplified bad distance check in inflate_fast() + * - Added inflateBackInit(), inflateBack(), and inflateBackEnd() in new + * source file infback.c to provide a call-back interface to inflate for + * programs like gzip and unzip -- uses window as output buffer to avoid + * window copying + * + * 1.2.beta5 1 Jan 2003 + * - Improved inflateBack() interface to allow the caller to provide initial + * input in strm. + * - Fixed stored blocks bug in inflateBack() + * + * 1.2.beta6 4 Jan 2003 + * - Added comments in inffast.c on effectiveness of POSTINC + * - Typecasting all around to reduce compiler warnings + * - Changed loops from while (1) or do {} while (1) to for (;;), again to + * make compilers happy + * - Changed type of window in inflateBackInit() to unsigned char * + * + * 1.2.beta7 27 Jan 2003 + * - Changed many types to unsigned or unsigned short to avoid warnings + * - Added inflateCopy() function + * + * 1.2.0 9 Mar 2003 + * - Changed inflateBack() interface to provide separate opaque descriptors + * for the in() and out() functions + * - Changed inflateBack() argument and in_func typedef to swap the length + * and buffer address return values for the input function + * - Check next_in and next_out for Z_NULL on entry to inflate() + * + * The history for versions after 1.2.0 are in ChangeLog in zlib distribution. + */ + +#include "zutil.h" +#include "inftrees.h" +#include "inflate.h" +#include "inffast.h" + +#ifdef MAKEFIXED +# ifndef BUILDFIXED +# define BUILDFIXED +# endif +#endif + +/* function prototypes */ +local void fixedtables OF((struct inflate_state FAR *state)); +local int updatewindow OF((z_streamp strm, const unsigned char FAR *end, + unsigned copy)); +#ifdef BUILDFIXED + void makefixed OF((void)); +#endif +local unsigned syncsearch OF((unsigned FAR *have, const unsigned char FAR *buf, + unsigned len)); + +int ZEXPORT inflateResetKeep(strm) +z_streamp strm; +{ + struct inflate_state FAR *state; + + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + strm->total_in = strm->total_out = state->total = 0; + strm->msg = Z_NULL; + if (state->wrap) /* to support ill-conceived Java test suite */ + strm->adler = state->wrap & 1; + state->mode = HEAD; + state->last = 0; + state->havedict = 0; + state->dmax = 32768U; + state->head = Z_NULL; + state->hold = 0; + state->bits = 0; + state->lencode = state->distcode = state->next = state->codes; + state->sane = 1; + state->back = -1; + Tracev((stderr, "inflate: reset\n")); + return Z_OK; +} + +int ZEXPORT inflateReset(strm) +z_streamp strm; +{ + struct inflate_state FAR *state; + + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + state->wsize = 0; + state->whave = 0; + state->wnext = 0; + return inflateResetKeep(strm); +} + +int ZEXPORT inflateReset2(strm, windowBits) +z_streamp strm; +int windowBits; +{ + int wrap; + struct inflate_state FAR *state; + + /* get the state */ + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + + /* extract wrap request from windowBits parameter */ + if (windowBits < 0) { + wrap = 0; + windowBits = -windowBits; + } + else { + wrap = (windowBits >> 4) + 1; +#ifdef GUNZIP + if (windowBits < 48) + windowBits &= 15; +#endif + } + + /* set number of window bits, free window if different */ + if (windowBits && (windowBits < 8 || windowBits > 15)) + return Z_STREAM_ERROR; + if (state->window != Z_NULL && state->wbits != (unsigned)windowBits) { + ZFREE(strm, state->window); + state->window = Z_NULL; + } + + /* update state and reset the rest of it */ + state->wrap = wrap; + state->wbits = (unsigned)windowBits; + return inflateReset(strm); +} + +int ZEXPORT inflateInit2_(strm, windowBits, version, stream_size) +z_streamp strm; +int windowBits; +const char *version; +int stream_size; +{ + int ret; + struct inflate_state FAR *state; + + if (version == Z_NULL || version[0] != ZLIB_VERSION[0] || + stream_size != (int)(sizeof(z_stream))) + return Z_VERSION_ERROR; + if (strm == Z_NULL) return Z_STREAM_ERROR; + strm->msg = Z_NULL; /* in case we return an error */ + if (strm->zalloc == (alloc_func)0) { +#ifdef Z_SOLO + return Z_STREAM_ERROR; +#else + strm->zalloc = zcalloc; + strm->opaque = (voidpf)0; +#endif + } + if (strm->zfree == (free_func)0) +#ifdef Z_SOLO + return Z_STREAM_ERROR; +#else + strm->zfree = zcfree; +#endif + state = (struct inflate_state FAR *) + ZALLOC(strm, 1, sizeof(struct inflate_state)); + if (state == Z_NULL) return Z_MEM_ERROR; + Tracev((stderr, "inflate: allocated\n")); + strm->state = (struct internal_state FAR *)state; + state->window = Z_NULL; + ret = inflateReset2(strm, windowBits); + if (ret != Z_OK) { + ZFREE(strm, state); + strm->state = Z_NULL; + } + return ret; +} + +int ZEXPORT inflateInit_(strm, version, stream_size) +z_streamp strm; +const char *version; +int stream_size; +{ + return inflateInit2_(strm, DEF_WBITS, version, stream_size); +} + +int ZEXPORT inflatePrime(strm, bits, value) +z_streamp strm; +int bits; +int value; +{ + struct inflate_state FAR *state; + + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + if (bits < 0) { + state->hold = 0; + state->bits = 0; + return Z_OK; + } + if (bits > 16 || state->bits + bits > 32) return Z_STREAM_ERROR; + value &= (1L << bits) - 1; + state->hold += value << state->bits; + state->bits += bits; + return Z_OK; +} + +/* + Return state with length and distance decoding tables and index sizes set to + fixed code decoding. Normally this returns fixed tables from inffixed.h. + If BUILDFIXED is defined, then instead this routine builds the tables the + first time it's called, and returns those tables the first time and + thereafter. This reduces the size of the code by about 2K bytes, in + exchange for a little execution time. However, BUILDFIXED should not be + used for threaded applications, since the rewriting of the tables and virgin + may not be thread-safe. + */ +local void fixedtables(state) +struct inflate_state FAR *state; +{ +#ifdef BUILDFIXED + static int virgin = 1; + static code *lenfix, *distfix; + static code fixed[544]; + + /* build fixed huffman tables if first call (may not be thread safe) */ + if (virgin) { + unsigned sym, bits; + static code *next; + + /* literal/length table */ + sym = 0; + while (sym < 144) state->lens[sym++] = 8; + while (sym < 256) state->lens[sym++] = 9; + while (sym < 280) state->lens[sym++] = 7; + while (sym < 288) state->lens[sym++] = 8; + next = fixed; + lenfix = next; + bits = 9; + inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work); + + /* distance table */ + sym = 0; + while (sym < 32) state->lens[sym++] = 5; + distfix = next; + bits = 5; + inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work); + + /* do this just once */ + virgin = 0; + } +#else /* !BUILDFIXED */ +# include "inffixed.h" +#endif /* BUILDFIXED */ + state->lencode = lenfix; + state->lenbits = 9; + state->distcode = distfix; + state->distbits = 5; +} + +#ifdef MAKEFIXED +#include + +/* + Write out the inffixed.h that is #include'd above. Defining MAKEFIXED also + defines BUILDFIXED, so the tables are built on the fly. makefixed() writes + those tables to stdout, which would be piped to inffixed.h. A small program + can simply call makefixed to do this: + + void makefixed(void); + + int main(void) + { + makefixed(); + return 0; + } + + Then that can be linked with zlib built with MAKEFIXED defined and run: + + a.out > inffixed.h + */ +void makefixed() +{ + unsigned low, size; + struct inflate_state state; + + fixedtables(&state); + puts(" /* inffixed.h -- table for decoding fixed codes"); + puts(" * Generated automatically by makefixed()."); + puts(" */"); + puts(""); + puts(" /* WARNING: this file should *not* be used by applications."); + puts(" It is part of the implementation of this library and is"); + puts(" subject to change. Applications should only use zlib.h."); + puts(" */"); + puts(""); + size = 1U << 9; + printf(" static const code lenfix[%u] = {", size); + low = 0; + for (;;) { + if ((low % 7) == 0) printf("\n "); + printf("{%u,%u,%d}", (low & 127) == 99 ? 64 : state.lencode[low].op, + state.lencode[low].bits, state.lencode[low].val); + if (++low == size) break; + putchar(','); + } + puts("\n };"); + size = 1U << 5; + printf("\n static const code distfix[%u] = {", size); + low = 0; + for (;;) { + if ((low % 6) == 0) printf("\n "); + printf("{%u,%u,%d}", state.distcode[low].op, state.distcode[low].bits, + state.distcode[low].val); + if (++low == size) break; + putchar(','); + } + puts("\n };"); +} +#endif /* MAKEFIXED */ + +/* + Update the window with the last wsize (normally 32K) bytes written before + returning. If window does not exist yet, create it. This is only called + when a window is already in use, or when output has been written during this + inflate call, but the end of the deflate stream has not been reached yet. + It is also called to create a window for dictionary data when a dictionary + is loaded. + + Providing output buffers larger than 32K to inflate() should provide a speed + advantage, since only the last 32K of output is copied to the sliding window + upon return from inflate(), and since all distances after the first 32K of + output will fall in the output data, making match copies simpler and faster. + The advantage may be dependent on the size of the processor's data caches. + */ +local int updatewindow(strm, end, copy) +z_streamp strm; +const Bytef *end; +unsigned copy; +{ + struct inflate_state FAR *state; + unsigned dist; + + state = (struct inflate_state FAR *)strm->state; + + /* if it hasn't been done already, allocate space for the window */ + if (state->window == Z_NULL) { + state->window = (unsigned char FAR *) + ZALLOC(strm, 1U << state->wbits, + sizeof(unsigned char)); + if (state->window == Z_NULL) return 1; + } + + /* if window not in use yet, initialize */ + if (state->wsize == 0) { + state->wsize = 1U << state->wbits; + state->wnext = 0; + state->whave = 0; + } + + /* copy state->wsize or less output bytes into the circular window */ + if (copy >= state->wsize) { + zmemcpy(state->window, end - state->wsize, state->wsize); + state->wnext = 0; + state->whave = state->wsize; + } + else { + dist = state->wsize - state->wnext; + if (dist > copy) dist = copy; + zmemcpy(state->window + state->wnext, end - copy, dist); + copy -= dist; + if (copy) { + zmemcpy(state->window, end - copy, copy); + state->wnext = copy; + state->whave = state->wsize; + } + else { + state->wnext += dist; + if (state->wnext == state->wsize) state->wnext = 0; + if (state->whave < state->wsize) state->whave += dist; + } + } + return 0; +} + +/* Macros for inflate(): */ + +/* check function to use adler32() for zlib or crc32() for gzip */ +#ifdef GUNZIP +# define UPDATE(check, buf, len) \ + (state->flags ? crc32(check, buf, len) : adler32(check, buf, len)) +#else +# define UPDATE(check, buf, len) adler32(check, buf, len) +#endif + +/* check macros for header crc */ +#ifdef GUNZIP +# define CRC2(check, word) \ + do { \ + hbuf[0] = (unsigned char)(word); \ + hbuf[1] = (unsigned char)((word) >> 8); \ + check = crc32(check, hbuf, 2); \ + } while (0) + +# define CRC4(check, word) \ + do { \ + hbuf[0] = (unsigned char)(word); \ + hbuf[1] = (unsigned char)((word) >> 8); \ + hbuf[2] = (unsigned char)((word) >> 16); \ + hbuf[3] = (unsigned char)((word) >> 24); \ + check = crc32(check, hbuf, 4); \ + } while (0) +#endif + +/* Load registers with state in inflate() for speed */ +#define LOAD() \ + do { \ + put = strm->next_out; \ + left = strm->avail_out; \ + next = strm->next_in; \ + have = strm->avail_in; \ + hold = state->hold; \ + bits = state->bits; \ + } while (0) + +/* Restore state from registers in inflate() */ +#define RESTORE() \ + do { \ + strm->next_out = put; \ + strm->avail_out = left; \ + strm->next_in = next; \ + strm->avail_in = have; \ + state->hold = hold; \ + state->bits = bits; \ + } while (0) + +/* Clear the input bit accumulator */ +#define INITBITS() \ + do { \ + hold = 0; \ + bits = 0; \ + } while (0) + +/* Get a byte of input into the bit accumulator, or return from inflate() + if there is no input available. */ +#define PULLBYTE() \ + do { \ + if (have == 0) goto inf_leave; \ + have--; \ + hold += (unsigned long)(*next++) << bits; \ + bits += 8; \ + } while (0) + +/* Assure that there are at least n bits in the bit accumulator. If there is + not enough available input to do that, then return from inflate(). */ +#define NEEDBITS(n) \ + do { \ + while (bits < (unsigned)(n)) \ + PULLBYTE(); \ + } while (0) + +/* Return the low n bits of the bit accumulator (n < 16) */ +#define BITS(n) \ + ((unsigned)hold & ((1U << (n)) - 1)) + +/* Remove n bits from the bit accumulator */ +#define DROPBITS(n) \ + do { \ + hold >>= (n); \ + bits -= (unsigned)(n); \ + } while (0) + +/* Remove zero to seven bits as needed to go to a byte boundary */ +#define BYTEBITS() \ + do { \ + hold >>= bits & 7; \ + bits -= bits & 7; \ + } while (0) + +/* + inflate() uses a state machine to process as much input data and generate as + much output data as possible before returning. The state machine is + structured roughly as follows: + + for (;;) switch (state) { + ... + case STATEn: + if (not enough input data or output space to make progress) + return; + ... make progress ... + state = STATEm; + break; + ... + } + + so when inflate() is called again, the same case is attempted again, and + if the appropriate resources are provided, the machine proceeds to the + next state. The NEEDBITS() macro is usually the way the state evaluates + whether it can proceed or should return. NEEDBITS() does the return if + the requested bits are not available. The typical use of the BITS macros + is: + + NEEDBITS(n); + ... do something with BITS(n) ... + DROPBITS(n); + + where NEEDBITS(n) either returns from inflate() if there isn't enough + input left to load n bits into the accumulator, or it continues. BITS(n) + gives the low n bits in the accumulator. When done, DROPBITS(n) drops + the low n bits off the accumulator. INITBITS() clears the accumulator + and sets the number of available bits to zero. BYTEBITS() discards just + enough bits to put the accumulator on a byte boundary. After BYTEBITS() + and a NEEDBITS(8), then BITS(8) would return the next byte in the stream. + + NEEDBITS(n) uses PULLBYTE() to get an available byte of input, or to return + if there is no input available. The decoding of variable length codes uses + PULLBYTE() directly in order to pull just enough bytes to decode the next + code, and no more. + + Some states loop until they get enough input, making sure that enough + state information is maintained to continue the loop where it left off + if NEEDBITS() returns in the loop. For example, want, need, and keep + would all have to actually be part of the saved state in case NEEDBITS() + returns: + + case STATEw: + while (want < need) { + NEEDBITS(n); + keep[want++] = BITS(n); + DROPBITS(n); + } + state = STATEx; + case STATEx: + + As shown above, if the next state is also the next case, then the break + is omitted. + + A state may also return if there is not enough output space available to + complete that state. Those states are copying stored data, writing a + literal byte, and copying a matching string. + + When returning, a "goto inf_leave" is used to update the total counters, + update the check value, and determine whether any progress has been made + during that inflate() call in order to return the proper return code. + Progress is defined as a change in either strm->avail_in or strm->avail_out. + When there is a window, goto inf_leave will update the window with the last + output written. If a goto inf_leave occurs in the middle of decompression + and there is no window currently, goto inf_leave will create one and copy + output to the window for the next call of inflate(). + + In this implementation, the flush parameter of inflate() only affects the + return code (per zlib.h). inflate() always writes as much as possible to + strm->next_out, given the space available and the provided input--the effect + documented in zlib.h of Z_SYNC_FLUSH. Furthermore, inflate() always defers + the allocation of and copying into a sliding window until necessary, which + provides the effect documented in zlib.h for Z_FINISH when the entire input + stream available. So the only thing the flush parameter actually does is: + when flush is set to Z_FINISH, inflate() cannot return Z_OK. Instead it + will return Z_BUF_ERROR if it has not reached the end of the stream. + */ + +int ZEXPORT inflate(strm, flush) +z_streamp strm; +int flush; +{ + struct inflate_state FAR *state; + z_const unsigned char FAR *next; /* next input */ + unsigned char FAR *put; /* next output */ + unsigned have, left; /* available input and output */ + unsigned long hold; /* bit buffer */ + unsigned bits; /* bits in bit buffer */ + unsigned in, out; /* save starting available input and output */ + unsigned copy; /* number of stored or match bytes to copy */ + unsigned char FAR *from; /* where to copy match bytes from */ + code here; /* current decoding table entry */ + code last; /* parent table entry */ + unsigned len; /* length to copy for repeats, bits to drop */ + int ret; /* return code */ +#ifdef GUNZIP + unsigned char hbuf[4]; /* buffer for gzip header crc calculation */ +#endif + static const unsigned short order[19] = /* permutation of code lengths */ + {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + + if (strm == Z_NULL || strm->state == Z_NULL || strm->next_out == Z_NULL || + (strm->next_in == Z_NULL && strm->avail_in != 0)) + return Z_STREAM_ERROR; + + state = (struct inflate_state FAR *)strm->state; + if (state->mode == TYPE) state->mode = TYPEDO; /* skip check */ + LOAD(); + in = have; + out = left; + ret = Z_OK; + for (;;) + switch (state->mode) { + case HEAD: + if (state->wrap == 0) { + state->mode = TYPEDO; + break; + } + NEEDBITS(16); +#ifdef GUNZIP + if ((state->wrap & 2) && hold == 0x8b1f) { /* gzip header */ + state->check = crc32(0L, Z_NULL, 0); + CRC2(state->check, hold); + INITBITS(); + state->mode = FLAGS; + break; + } + state->flags = 0; /* expect zlib header */ + if (state->head != Z_NULL) + state->head->done = -1; + if (!(state->wrap & 1) || /* check if zlib header allowed */ +#else + if ( +#endif + ((BITS(8) << 8) + (hold >> 8)) % 31) { + strm->msg = (char *)"incorrect header check"; + state->mode = BAD; + break; + } + if (BITS(4) != Z_DEFLATED) { + strm->msg = (char *)"unknown compression method"; + state->mode = BAD; + break; + } + DROPBITS(4); + len = BITS(4) + 8; + if (state->wbits == 0) + state->wbits = len; + else if (len > state->wbits) { + strm->msg = (char *)"invalid window size"; + state->mode = BAD; + break; + } + state->dmax = 1U << len; + Tracev((stderr, "inflate: zlib header ok\n")); + strm->adler = state->check = adler32(0L, Z_NULL, 0); + state->mode = hold & 0x200 ? DICTID : TYPE; + INITBITS(); + break; +#ifdef GUNZIP + case FLAGS: + NEEDBITS(16); + state->flags = (int)(hold); + if ((state->flags & 0xff) != Z_DEFLATED) { + strm->msg = (char *)"unknown compression method"; + state->mode = BAD; + break; + } + if (state->flags & 0xe000) { + strm->msg = (char *)"unknown header flags set"; + state->mode = BAD; + break; + } + if (state->head != Z_NULL) + state->head->text = (int)((hold >> 8) & 1); + if (state->flags & 0x0200) CRC2(state->check, hold); + INITBITS(); + state->mode = TIME; + case TIME: + NEEDBITS(32); + if (state->head != Z_NULL) + state->head->time = hold; + if (state->flags & 0x0200) CRC4(state->check, hold); + INITBITS(); + state->mode = OS; + case OS: + NEEDBITS(16); + if (state->head != Z_NULL) { + state->head->xflags = (int)(hold & 0xff); + state->head->os = (int)(hold >> 8); + } + if (state->flags & 0x0200) CRC2(state->check, hold); + INITBITS(); + state->mode = EXLEN; + case EXLEN: + if (state->flags & 0x0400) { + NEEDBITS(16); + state->length = (unsigned)(hold); + if (state->head != Z_NULL) + state->head->extra_len = (unsigned)hold; + if (state->flags & 0x0200) CRC2(state->check, hold); + INITBITS(); + } + else if (state->head != Z_NULL) + state->head->extra = Z_NULL; + state->mode = EXTRA; + case EXTRA: + if (state->flags & 0x0400) { + copy = state->length; + if (copy > have) copy = have; + if (copy) { + if (state->head != Z_NULL && + state->head->extra != Z_NULL) { + len = state->head->extra_len - state->length; + zmemcpy(state->head->extra + len, next, + len + copy > state->head->extra_max ? + state->head->extra_max - len : copy); + } + if (state->flags & 0x0200) + state->check = crc32(state->check, next, copy); + have -= copy; + next += copy; + state->length -= copy; + } + if (state->length) goto inf_leave; + } + state->length = 0; + state->mode = NAME; + case NAME: + if (state->flags & 0x0800) { + if (have == 0) goto inf_leave; + copy = 0; + do { + len = (unsigned)(next[copy++]); + if (state->head != Z_NULL && + state->head->name != Z_NULL && + state->length < state->head->name_max) + state->head->name[state->length++] = len; + } while (len && copy < have); + if (state->flags & 0x0200) + state->check = crc32(state->check, next, copy); + have -= copy; + next += copy; + if (len) goto inf_leave; + } + else if (state->head != Z_NULL) + state->head->name = Z_NULL; + state->length = 0; + state->mode = COMMENT; + case COMMENT: + if (state->flags & 0x1000) { + if (have == 0) goto inf_leave; + copy = 0; + do { + len = (unsigned)(next[copy++]); + if (state->head != Z_NULL && + state->head->comment != Z_NULL && + state->length < state->head->comm_max) + state->head->comment[state->length++] = len; + } while (len && copy < have); + if (state->flags & 0x0200) + state->check = crc32(state->check, next, copy); + have -= copy; + next += copy; + if (len) goto inf_leave; + } + else if (state->head != Z_NULL) + state->head->comment = Z_NULL; + state->mode = HCRC; + case HCRC: + if (state->flags & 0x0200) { + NEEDBITS(16); + if (hold != (state->check & 0xffff)) { + strm->msg = (char *)"header crc mismatch"; + state->mode = BAD; + break; + } + INITBITS(); + } + if (state->head != Z_NULL) { + state->head->hcrc = (int)((state->flags >> 9) & 1); + state->head->done = 1; + } + strm->adler = state->check = crc32(0L, Z_NULL, 0); + state->mode = TYPE; + break; +#endif + case DICTID: + NEEDBITS(32); + strm->adler = state->check = ZSWAP32(hold); + INITBITS(); + state->mode = DICT; + case DICT: + if (state->havedict == 0) { + RESTORE(); + return Z_NEED_DICT; + } + strm->adler = state->check = adler32(0L, Z_NULL, 0); + state->mode = TYPE; + case TYPE: + if (flush == Z_BLOCK || flush == Z_TREES) goto inf_leave; + case TYPEDO: + if (state->last) { + BYTEBITS(); + state->mode = CHECK; + break; + } + NEEDBITS(3); + state->last = BITS(1); + DROPBITS(1); + switch (BITS(2)) { + case 0: /* stored block */ + Tracev((stderr, "inflate: stored block%s\n", + state->last ? " (last)" : "")); + state->mode = STORED; + break; + case 1: /* fixed block */ + fixedtables(state); + Tracev((stderr, "inflate: fixed codes block%s\n", + state->last ? " (last)" : "")); + state->mode = LEN_; /* decode codes */ + if (flush == Z_TREES) { + DROPBITS(2); + goto inf_leave; + } + break; + case 2: /* dynamic block */ + Tracev((stderr, "inflate: dynamic codes block%s\n", + state->last ? " (last)" : "")); + state->mode = TABLE; + break; + case 3: + strm->msg = (char *)"invalid block type"; + state->mode = BAD; + } + DROPBITS(2); + break; + case STORED: + BYTEBITS(); /* go to byte boundary */ + NEEDBITS(32); + if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) { + strm->msg = (char *)"invalid stored block lengths"; + state->mode = BAD; + break; + } + state->length = (unsigned)hold & 0xffff; + Tracev((stderr, "inflate: stored length %u\n", + state->length)); + INITBITS(); + state->mode = COPY_; + if (flush == Z_TREES) goto inf_leave; + case COPY_: + state->mode = COPY; + case COPY: + copy = state->length; + if (copy) { + if (copy > have) copy = have; + if (copy > left) copy = left; + if (copy == 0) goto inf_leave; + zmemcpy(put, next, copy); + have -= copy; + next += copy; + left -= copy; + put += copy; + state->length -= copy; + break; + } + Tracev((stderr, "inflate: stored end\n")); + state->mode = TYPE; + break; + case TABLE: + NEEDBITS(14); + state->nlen = BITS(5) + 257; + DROPBITS(5); + state->ndist = BITS(5) + 1; + DROPBITS(5); + state->ncode = BITS(4) + 4; + DROPBITS(4); +#ifndef PKZIP_BUG_WORKAROUND + if (state->nlen > 286 || state->ndist > 30) { + strm->msg = (char *)"too many length or distance symbols"; + state->mode = BAD; + break; + } +#endif + Tracev((stderr, "inflate: table sizes ok\n")); + state->have = 0; + state->mode = LENLENS; + case LENLENS: + while (state->have < state->ncode) { + NEEDBITS(3); + state->lens[order[state->have++]] = (unsigned short)BITS(3); + DROPBITS(3); + } + while (state->have < 19) + state->lens[order[state->have++]] = 0; + state->next = state->codes; + state->lencode = (const code FAR *)(state->next); + state->lenbits = 7; + ret = inflate_table(CODES, state->lens, 19, &(state->next), + &(state->lenbits), state->work); + if (ret) { + strm->msg = (char *)"invalid code lengths set"; + state->mode = BAD; + break; + } + Tracev((stderr, "inflate: code lengths ok\n")); + state->have = 0; + state->mode = CODELENS; + case CODELENS: + while (state->have < state->nlen + state->ndist) { + for (;;) { + here = state->lencode[BITS(state->lenbits)]; + if ((unsigned)(here.bits) <= bits) break; + PULLBYTE(); + } + if (here.val < 16) { + DROPBITS(here.bits); + state->lens[state->have++] = here.val; + } + else { + if (here.val == 16) { + NEEDBITS(here.bits + 2); + DROPBITS(here.bits); + if (state->have == 0) { + strm->msg = (char *)"invalid bit length repeat"; + state->mode = BAD; + break; + } + len = state->lens[state->have - 1]; + copy = 3 + BITS(2); + DROPBITS(2); + } + else if (here.val == 17) { + NEEDBITS(here.bits + 3); + DROPBITS(here.bits); + len = 0; + copy = 3 + BITS(3); + DROPBITS(3); + } + else { + NEEDBITS(here.bits + 7); + DROPBITS(here.bits); + len = 0; + copy = 11 + BITS(7); + DROPBITS(7); + } + if (state->have + copy > state->nlen + state->ndist) { + strm->msg = (char *)"invalid bit length repeat"; + state->mode = BAD; + break; + } + while (copy--) + state->lens[state->have++] = (unsigned short)len; + } + } + + /* handle error breaks in while */ + if (state->mode == BAD) break; + + /* check for end-of-block code (better have one) */ + if (state->lens[256] == 0) { + strm->msg = (char *)"invalid code -- missing end-of-block"; + state->mode = BAD; + break; + } + + /* build code tables -- note: do not change the lenbits or distbits + values here (9 and 6) without reading the comments in inftrees.h + concerning the ENOUGH constants, which depend on those values */ + state->next = state->codes; + state->lencode = (const code FAR *)(state->next); + state->lenbits = 9; + ret = inflate_table(LENS, state->lens, state->nlen, &(state->next), + &(state->lenbits), state->work); + if (ret) { + strm->msg = (char *)"invalid literal/lengths set"; + state->mode = BAD; + break; + } + state->distcode = (const code FAR *)(state->next); + state->distbits = 6; + ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist, + &(state->next), &(state->distbits), state->work); + if (ret) { + strm->msg = (char *)"invalid distances set"; + state->mode = BAD; + break; + } + Tracev((stderr, "inflate: codes ok\n")); + state->mode = LEN_; + if (flush == Z_TREES) goto inf_leave; + case LEN_: + state->mode = LEN; + case LEN: + if (have >= 6 && left >= 258) { + RESTORE(); + inflate_fast(strm, out); + LOAD(); + if (state->mode == TYPE) + state->back = -1; + break; + } + state->back = 0; + for (;;) { + here = state->lencode[BITS(state->lenbits)]; + if ((unsigned)(here.bits) <= bits) break; + PULLBYTE(); + } + if (here.op && (here.op & 0xf0) == 0) { + last = here; + for (;;) { + here = state->lencode[last.val + + (BITS(last.bits + last.op) >> last.bits)]; + if ((unsigned)(last.bits + here.bits) <= bits) break; + PULLBYTE(); + } + DROPBITS(last.bits); + state->back += last.bits; + } + DROPBITS(here.bits); + state->back += here.bits; + state->length = (unsigned)here.val; + if ((int)(here.op) == 0) { + Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ? + "inflate: literal '%c'\n" : + "inflate: literal 0x%02x\n", here.val)); + state->mode = LIT; + break; + } + if (here.op & 32) { + Tracevv((stderr, "inflate: end of block\n")); + state->back = -1; + state->mode = TYPE; + break; + } + if (here.op & 64) { + strm->msg = (char *)"invalid literal/length code"; + state->mode = BAD; + break; + } + state->extra = (unsigned)(here.op) & 15; + state->mode = LENEXT; + case LENEXT: + if (state->extra) { + NEEDBITS(state->extra); + state->length += BITS(state->extra); + DROPBITS(state->extra); + state->back += state->extra; + } + Tracevv((stderr, "inflate: length %u\n", state->length)); + state->was = state->length; + state->mode = DIST; + case DIST: + for (;;) { + here = state->distcode[BITS(state->distbits)]; + if ((unsigned)(here.bits) <= bits) break; + PULLBYTE(); + } + if ((here.op & 0xf0) == 0) { + last = here; + for (;;) { + here = state->distcode[last.val + + (BITS(last.bits + last.op) >> last.bits)]; + if ((unsigned)(last.bits + here.bits) <= bits) break; + PULLBYTE(); + } + DROPBITS(last.bits); + state->back += last.bits; + } + DROPBITS(here.bits); + state->back += here.bits; + if (here.op & 64) { + strm->msg = (char *)"invalid distance code"; + state->mode = BAD; + break; + } + state->offset = (unsigned)here.val; + state->extra = (unsigned)(here.op) & 15; + state->mode = DISTEXT; + case DISTEXT: + if (state->extra) { + NEEDBITS(state->extra); + state->offset += BITS(state->extra); + DROPBITS(state->extra); + state->back += state->extra; + } +#ifdef INFLATE_STRICT + if (state->offset > state->dmax) { + strm->msg = (char *)"invalid distance too far back"; + state->mode = BAD; + break; + } +#endif + Tracevv((stderr, "inflate: distance %u\n", state->offset)); + state->mode = MATCH; + case MATCH: + if (left == 0) goto inf_leave; + copy = out - left; + if (state->offset > copy) { /* copy from window */ + copy = state->offset - copy; + if (copy > state->whave) { + if (state->sane) { + strm->msg = (char *)"invalid distance too far back"; + state->mode = BAD; + break; + } +#ifdef INFLATE_ALLOW_INVALID_DISTANCE_TOOFAR_ARRR + Trace((stderr, "inflate.c too far\n")); + copy -= state->whave; + if (copy > state->length) copy = state->length; + if (copy > left) copy = left; + left -= copy; + state->length -= copy; + do { + *put++ = 0; + } while (--copy); + if (state->length == 0) state->mode = LEN; + break; +#endif + } + if (copy > state->wnext) { + copy -= state->wnext; + from = state->window + (state->wsize - copy); + } + else + from = state->window + (state->wnext - copy); + if (copy > state->length) copy = state->length; + } + else { /* copy from output */ + from = put - state->offset; + copy = state->length; + } + if (copy > left) copy = left; + left -= copy; + state->length -= copy; + do { + *put++ = *from++; + } while (--copy); + if (state->length == 0) state->mode = LEN; + break; + case LIT: + if (left == 0) goto inf_leave; + *put++ = (unsigned char)(state->length); + left--; + state->mode = LEN; + break; + case CHECK: + if (state->wrap) { + NEEDBITS(32); + out -= left; + strm->total_out += out; + state->total += out; + if (out) + strm->adler = state->check = + UPDATE(state->check, put - out, out); + out = left; + if (( +#ifdef GUNZIP + state->flags ? hold : +#endif + ZSWAP32(hold)) != state->check) { + strm->msg = (char *)"incorrect data check"; + state->mode = BAD; + break; + } + INITBITS(); + Tracev((stderr, "inflate: check matches trailer\n")); + } +#ifdef GUNZIP + state->mode = LENGTH; + case LENGTH: + if (state->wrap && state->flags) { + NEEDBITS(32); + if (hold != (state->total & 0xffffffffUL)) { + strm->msg = (char *)"incorrect length check"; + state->mode = BAD; + break; + } + INITBITS(); + Tracev((stderr, "inflate: length matches trailer\n")); + } +#endif + state->mode = DONE; + case DONE: + ret = Z_STREAM_END; + goto inf_leave; + case BAD: + ret = Z_DATA_ERROR; + goto inf_leave; + case MEM: + return Z_MEM_ERROR; + case SYNC: + default: + return Z_STREAM_ERROR; + } + + /* + Return from inflate(), updating the total counts and the check value. + If there was no progress during the inflate() call, return a buffer + error. Call updatewindow() to create and/or update the window state. + Note: a memory error from inflate() is non-recoverable. + */ + inf_leave: + RESTORE(); + if (state->wsize || (out != strm->avail_out && state->mode < BAD && + (state->mode < CHECK || flush != Z_FINISH))) + if (updatewindow(strm, strm->next_out, out - strm->avail_out)) { + state->mode = MEM; + return Z_MEM_ERROR; + } + in -= strm->avail_in; + out -= strm->avail_out; + strm->total_in += in; + strm->total_out += out; + state->total += out; + if (state->wrap && out) + strm->adler = state->check = + UPDATE(state->check, strm->next_out - out, out); + strm->data_type = state->bits + (state->last ? 64 : 0) + + (state->mode == TYPE ? 128 : 0) + + (state->mode == LEN_ || state->mode == COPY_ ? 256 : 0); + if (((in == 0 && out == 0) || flush == Z_FINISH) && ret == Z_OK) + ret = Z_BUF_ERROR; + return ret; +} + +int ZEXPORT inflateEnd(strm) +z_streamp strm; +{ + struct inflate_state FAR *state; + if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0) + return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + if (state->window != Z_NULL) ZFREE(strm, state->window); + ZFREE(strm, strm->state); + strm->state = Z_NULL; + Tracev((stderr, "inflate: end\n")); + return Z_OK; +} + +int ZEXPORT inflateGetDictionary(strm, dictionary, dictLength) +z_streamp strm; +Bytef *dictionary; +uInt *dictLength; +{ + struct inflate_state FAR *state; + + /* check state */ + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + + /* copy dictionary */ + if (state->whave && dictionary != Z_NULL) { + zmemcpy(dictionary, state->window + state->wnext, + state->whave - state->wnext); + zmemcpy(dictionary + state->whave - state->wnext, + state->window, state->wnext); + } + if (dictLength != Z_NULL) + *dictLength = state->whave; + return Z_OK; +} + +int ZEXPORT inflateSetDictionary(strm, dictionary, dictLength) +z_streamp strm; +const Bytef *dictionary; +uInt dictLength; +{ + struct inflate_state FAR *state; + unsigned long dictid; + int ret; + + /* check state */ + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + if (state->wrap != 0 && state->mode != DICT) + return Z_STREAM_ERROR; + + /* check for correct dictionary identifier */ + if (state->mode == DICT) { + dictid = adler32(0L, Z_NULL, 0); + dictid = adler32(dictid, dictionary, dictLength); + if (dictid != state->check) + return Z_DATA_ERROR; + } + + /* copy dictionary to window using updatewindow(), which will amend the + existing dictionary if appropriate */ + ret = updatewindow(strm, dictionary + dictLength, dictLength); + if (ret) { + state->mode = MEM; + return Z_MEM_ERROR; + } + state->havedict = 1; + Tracev((stderr, "inflate: dictionary set\n")); + return Z_OK; +} + +int ZEXPORT inflateGetHeader(strm, head) +z_streamp strm; +gz_headerp head; +{ + struct inflate_state FAR *state; + + /* check state */ + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + if ((state->wrap & 2) == 0) return Z_STREAM_ERROR; + + /* save header structure */ + state->head = head; + head->done = 0; + return Z_OK; +} + +/* + Search buf[0..len-1] for the pattern: 0, 0, 0xff, 0xff. Return when found + or when out of input. When called, *have is the number of pattern bytes + found in order so far, in 0..3. On return *have is updated to the new + state. If on return *have equals four, then the pattern was found and the + return value is how many bytes were read including the last byte of the + pattern. If *have is less than four, then the pattern has not been found + yet and the return value is len. In the latter case, syncsearch() can be + called again with more data and the *have state. *have is initialized to + zero for the first call. + */ +local unsigned syncsearch(have, buf, len) +unsigned FAR *have; +const unsigned char FAR *buf; +unsigned len; +{ + unsigned got; + unsigned next; + + got = *have; + next = 0; + while (next < len && got < 4) { + if ((int)(buf[next]) == (got < 2 ? 0 : 0xff)) + got++; + else if (buf[next]) + got = 0; + else + got = 4 - got; + next++; + } + *have = got; + return next; +} + +int ZEXPORT inflateSync(strm) +z_streamp strm; +{ + unsigned len; /* number of bytes to look at or looked at */ + unsigned long in, out; /* temporary to save total_in and total_out */ + unsigned char buf[4]; /* to restore bit buffer to byte string */ + struct inflate_state FAR *state; + + /* check parameters */ + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + if (strm->avail_in == 0 && state->bits < 8) return Z_BUF_ERROR; + + /* if first time, start search in bit buffer */ + if (state->mode != SYNC) { + state->mode = SYNC; + state->hold <<= state->bits & 7; + state->bits -= state->bits & 7; + len = 0; + while (state->bits >= 8) { + buf[len++] = (unsigned char)(state->hold); + state->hold >>= 8; + state->bits -= 8; + } + state->have = 0; + syncsearch(&(state->have), buf, len); + } + + /* search available input */ + len = syncsearch(&(state->have), strm->next_in, strm->avail_in); + strm->avail_in -= len; + strm->next_in += len; + strm->total_in += len; + + /* return no joy or set up to restart inflate() on a new block */ + if (state->have != 4) return Z_DATA_ERROR; + in = strm->total_in; out = strm->total_out; + inflateReset(strm); + strm->total_in = in; strm->total_out = out; + state->mode = TYPE; + return Z_OK; +} + +/* + Returns true if inflate is currently at the end of a block generated by + Z_SYNC_FLUSH or Z_FULL_FLUSH. This function is used by one PPP + implementation to provide an additional safety check. PPP uses + Z_SYNC_FLUSH but removes the length bytes of the resulting empty stored + block. When decompressing, PPP checks that at the end of input packet, + inflate is waiting for these length bytes. + */ +int ZEXPORT inflateSyncPoint(strm) +z_streamp strm; +{ + struct inflate_state FAR *state; + + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + return state->mode == STORED && state->bits == 0; +} + +int ZEXPORT inflateCopy(dest, source) +z_streamp dest; +z_streamp source; +{ + struct inflate_state FAR *state; + struct inflate_state FAR *copy; + unsigned char FAR *window; + unsigned wsize; + + /* check input */ + if (dest == Z_NULL || source == Z_NULL || source->state == Z_NULL || + source->zalloc == (alloc_func)0 || source->zfree == (free_func)0) + return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)source->state; + + /* allocate space */ + copy = (struct inflate_state FAR *) + ZALLOC(source, 1, sizeof(struct inflate_state)); + if (copy == Z_NULL) return Z_MEM_ERROR; + window = Z_NULL; + if (state->window != Z_NULL) { + window = (unsigned char FAR *) + ZALLOC(source, 1U << state->wbits, sizeof(unsigned char)); + if (window == Z_NULL) { + ZFREE(source, copy); + return Z_MEM_ERROR; + } + } + + /* copy state */ + zmemcpy((voidpf)dest, (voidpf)source, sizeof(z_stream)); + zmemcpy((voidpf)copy, (voidpf)state, sizeof(struct inflate_state)); + if (state->lencode >= state->codes && + state->lencode <= state->codes + ENOUGH - 1) { + copy->lencode = copy->codes + (state->lencode - state->codes); + copy->distcode = copy->codes + (state->distcode - state->codes); + } + copy->next = copy->codes + (state->next - state->codes); + if (window != Z_NULL) { + wsize = 1U << state->wbits; + zmemcpy(window, state->window, wsize); + } + copy->window = window; + dest->state = (struct internal_state FAR *)copy; + return Z_OK; +} + +int ZEXPORT inflateUndermine(strm, subvert) +z_streamp strm; +int subvert; +{ + struct inflate_state FAR *state; + + if (strm == Z_NULL || strm->state == Z_NULL) return Z_STREAM_ERROR; + state = (struct inflate_state FAR *)strm->state; + state->sane = !subvert; +#ifdef INFLATE_ALLOW_INVALID_DISTANCE_TOOFAR_ARRR + return Z_OK; +#else + state->sane = 1; + return Z_DATA_ERROR; +#endif +} + +long ZEXPORT inflateMark(strm) +z_streamp strm; +{ + struct inflate_state FAR *state; + + if (strm == Z_NULL || strm->state == Z_NULL) return -1L << 16; + state = (struct inflate_state FAR *)strm->state; + return ((long)(state->back) << 16) + + (state->mode == COPY ? state->length : + (state->mode == MATCH ? state->was - state->length : 0)); +} diff --git a/ml/dlib/dlib/external/zlib/inflate.h b/ml/dlib/dlib/external/zlib/inflate.h new file mode 100644 index 000000000..95f4986d4 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/inflate.h @@ -0,0 +1,122 @@ +/* inflate.h -- internal inflate state definition + * Copyright (C) 1995-2009 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* WARNING: this file should *not* be used by applications. It is + part of the implementation of the compression library and is + subject to change. Applications should only use zlib.h. + */ + +/* define NO_GZIP when compiling if you want to disable gzip header and + trailer decoding by inflate(). NO_GZIP would be used to avoid linking in + the crc code when it is not needed. For shared libraries, gzip decoding + should be left enabled. */ +#ifndef NO_GZIP +# define GUNZIP +#endif + +/* Possible inflate modes between inflate() calls */ +typedef enum { + HEAD, /* i: waiting for magic header */ + FLAGS, /* i: waiting for method and flags (gzip) */ + TIME, /* i: waiting for modification time (gzip) */ + OS, /* i: waiting for extra flags and operating system (gzip) */ + EXLEN, /* i: waiting for extra length (gzip) */ + EXTRA, /* i: waiting for extra bytes (gzip) */ + NAME, /* i: waiting for end of file name (gzip) */ + COMMENT, /* i: waiting for end of comment (gzip) */ + HCRC, /* i: waiting for header crc (gzip) */ + DICTID, /* i: waiting for dictionary check value */ + DICT, /* waiting for inflateSetDictionary() call */ + TYPE, /* i: waiting for type bits, including last-flag bit */ + TYPEDO, /* i: same, but skip check to exit inflate on new block */ + STORED, /* i: waiting for stored size (length and complement) */ + COPY_, /* i/o: same as COPY below, but only first time in */ + COPY, /* i/o: waiting for input or output to copy stored block */ + TABLE, /* i: waiting for dynamic block table lengths */ + LENLENS, /* i: waiting for code length code lengths */ + CODELENS, /* i: waiting for length/lit and distance code lengths */ + LEN_, /* i: same as LEN below, but only first time in */ + LEN, /* i: waiting for length/lit/eob code */ + LENEXT, /* i: waiting for length extra bits */ + DIST, /* i: waiting for distance code */ + DISTEXT, /* i: waiting for distance extra bits */ + MATCH, /* o: waiting for output space to copy string */ + LIT, /* o: waiting for output space to write literal */ + CHECK, /* i: waiting for 32-bit check value */ + LENGTH, /* i: waiting for 32-bit length (gzip) */ + DONE, /* finished check, done -- remain here until reset */ + BAD, /* got a data error -- remain here until reset */ + MEM, /* got an inflate() memory error -- remain here until reset */ + SYNC /* looking for synchronization bytes to restart inflate() */ +} inflate_mode; + +/* + State transitions between above modes - + + (most modes can go to BAD or MEM on error -- not shown for clarity) + + Process header: + HEAD -> (gzip) or (zlib) or (raw) + (gzip) -> FLAGS -> TIME -> OS -> EXLEN -> EXTRA -> NAME -> COMMENT -> + HCRC -> TYPE + (zlib) -> DICTID or TYPE + DICTID -> DICT -> TYPE + (raw) -> TYPEDO + Read deflate blocks: + TYPE -> TYPEDO -> STORED or TABLE or LEN_ or CHECK + STORED -> COPY_ -> COPY -> TYPE + TABLE -> LENLENS -> CODELENS -> LEN_ + LEN_ -> LEN + Read deflate codes in fixed or dynamic block: + LEN -> LENEXT or LIT or TYPE + LENEXT -> DIST -> DISTEXT -> MATCH -> LEN + LIT -> LEN + Process trailer: + CHECK -> LENGTH -> DONE + */ + +/* state maintained between inflate() calls. Approximately 10K bytes. */ +struct inflate_state { + inflate_mode mode; /* current inflate mode */ + int last; /* true if processing last block */ + int wrap; /* bit 0 true for zlib, bit 1 true for gzip */ + int havedict; /* true if dictionary provided */ + int flags; /* gzip header method and flags (0 if zlib) */ + unsigned dmax; /* zlib header max distance (INFLATE_STRICT) */ + unsigned long check; /* protected copy of check value */ + unsigned long total; /* protected copy of output count */ + gz_headerp head; /* where to save gzip header information */ + /* sliding window */ + unsigned wbits; /* log base 2 of requested window size */ + unsigned wsize; /* window size or zero if not using window */ + unsigned whave; /* valid bytes in the window */ + unsigned wnext; /* window write index */ + unsigned char FAR *window; /* allocated sliding window, if needed */ + /* bit accumulator */ + unsigned long hold; /* input bit accumulator */ + unsigned bits; /* number of bits in "in" */ + /* for string and stored block copying */ + unsigned length; /* literal or length of data to copy */ + unsigned offset; /* distance back to copy string from */ + /* for table and code decoding */ + unsigned extra; /* extra bits needed */ + /* fixed and dynamic code tables */ + code const FAR *lencode; /* starting table for length/literal codes */ + code const FAR *distcode; /* starting table for distance codes */ + unsigned lenbits; /* index bits for lencode */ + unsigned distbits; /* index bits for distcode */ + /* dynamic table building */ + unsigned ncode; /* number of code length code lengths */ + unsigned nlen; /* number of length code lengths */ + unsigned ndist; /* number of distance code lengths */ + unsigned have; /* number of code lengths in lens[] */ + code FAR *next; /* next available space in codes[] */ + unsigned short lens[320]; /* temporary storage for code lengths */ + unsigned short work[288]; /* work area for code table building */ + code codes[ENOUGH]; /* space for code tables */ + int sane; /* if false, allow invalid distance too far */ + int back; /* bits back of last unprocessed length/lit */ + unsigned was; /* initial length of match */ +}; diff --git a/ml/dlib/dlib/external/zlib/inftrees.c b/ml/dlib/dlib/external/zlib/inftrees.c new file mode 100644 index 000000000..44d89cf24 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/inftrees.c @@ -0,0 +1,306 @@ +/* inftrees.c -- generate Huffman trees for efficient decoding + * Copyright (C) 1995-2013 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +#include "zutil.h" +#include "inftrees.h" + +#define MAXBITS 15 + +const char inflate_copyright[] = + " inflate 1.2.8 Copyright 1995-2013 Mark Adler "; +/* + If you use the zlib library in a product, an acknowledgment is welcome + in the documentation of your product. If for some reason you cannot + include such an acknowledgment, I would appreciate that you keep this + copyright string in the executable of your product. + */ + +/* + Build a set of tables to decode the provided canonical Huffman code. + The code lengths are lens[0..codes-1]. The result starts at *table, + whose indices are 0..2^bits-1. work is a writable array of at least + lens shorts, which is used as a work area. type is the type of code + to be generated, CODES, LENS, or DISTS. On return, zero is success, + -1 is an invalid code, and +1 means that ENOUGH isn't enough. table + on return points to the next available entry's address. bits is the + requested root table index bits, and on return it is the actual root + table index bits. It will differ if the request is greater than the + longest code or if it is less than the shortest code. + */ +int ZLIB_INTERNAL inflate_table(type, lens, codes, table, bits, work) +codetype type; +unsigned short FAR *lens; +unsigned codes; +code FAR * FAR *table; +unsigned FAR *bits; +unsigned short FAR *work; +{ + unsigned len; /* a code's length in bits */ + unsigned sym; /* index of code symbols */ + unsigned min, max; /* minimum and maximum code lengths */ + unsigned root; /* number of index bits for root table */ + unsigned curr; /* number of index bits for current table */ + unsigned drop; /* code bits to drop for sub-table */ + int left; /* number of prefix codes available */ + unsigned used; /* code entries in table used */ + unsigned huff; /* Huffman code */ + unsigned incr; /* for incrementing code, index */ + unsigned fill; /* index for replicating entries */ + unsigned low; /* low bits for current root entry */ + unsigned mask; /* mask for low root bits */ + code here; /* table entry for duplication */ + code FAR *next; /* next available space in table */ + const unsigned short FAR *base; /* base value table to use */ + const unsigned short FAR *extra; /* extra bits table to use */ + int end; /* use base and extra for symbol > end */ + unsigned short count[MAXBITS+1]; /* number of codes of each length */ + unsigned short offs[MAXBITS+1]; /* offsets in table for each length */ + static const unsigned short lbase[31] = { /* Length codes 257..285 base */ + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; + static const unsigned short lext[31] = { /* Length codes 257..285 extra */ + 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 18, 18, 18, 18, + 19, 19, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 16, 72, 78}; + static const unsigned short dbase[32] = { /* Distance codes 0..29 base */ + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, + 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, + 8193, 12289, 16385, 24577, 0, 0}; + static const unsigned short dext[32] = { /* Distance codes 0..29 extra */ + 16, 16, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, + 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, + 28, 28, 29, 29, 64, 64}; + + /* + Process a set of code lengths to create a canonical Huffman code. The + code lengths are lens[0..codes-1]. Each length corresponds to the + symbols 0..codes-1. The Huffman code is generated by first sorting the + symbols by length from short to long, and retaining the symbol order + for codes with equal lengths. Then the code starts with all zero bits + for the first code of the shortest length, and the codes are integer + increments for the same length, and zeros are appended as the length + increases. For the deflate format, these bits are stored backwards + from their more natural integer increment ordering, and so when the + decoding tables are built in the large loop below, the integer codes + are incremented backwards. + + This routine assumes, but does not check, that all of the entries in + lens[] are in the range 0..MAXBITS. The caller must assure this. + 1..MAXBITS is interpreted as that code length. zero means that that + symbol does not occur in this code. + + The codes are sorted by computing a count of codes for each length, + creating from that a table of starting indices for each length in the + sorted table, and then entering the symbols in order in the sorted + table. The sorted table is work[], with that space being provided by + the caller. + + The length counts are used for other purposes as well, i.e. finding + the minimum and maximum length codes, determining if there are any + codes at all, checking for a valid set of lengths, and looking ahead + at length counts to determine sub-table sizes when building the + decoding tables. + */ + + /* accumulate lengths for codes (assumes lens[] all in 0..MAXBITS) */ + for (len = 0; len <= MAXBITS; len++) + count[len] = 0; + for (sym = 0; sym < codes; sym++) + count[lens[sym]]++; + + /* bound code lengths, force root to be within code lengths */ + root = *bits; + for (max = MAXBITS; max >= 1; max--) + if (count[max] != 0) break; + if (root > max) root = max; + if (max == 0) { /* no symbols to code at all */ + here.op = (unsigned char)64; /* invalid code marker */ + here.bits = (unsigned char)1; + here.val = (unsigned short)0; + *(*table)++ = here; /* make a table to force an error */ + *(*table)++ = here; + *bits = 1; + return 0; /* no symbols, but wait for decoding to report error */ + } + for (min = 1; min < max; min++) + if (count[min] != 0) break; + if (root < min) root = min; + + /* check for an over-subscribed or incomplete set of lengths */ + left = 1; + for (len = 1; len <= MAXBITS; len++) { + left <<= 1; + left -= count[len]; + if (left < 0) return -1; /* over-subscribed */ + } + if (left > 0 && (type == CODES || max != 1)) + return -1; /* incomplete set */ + + /* generate offsets into symbol table for each length for sorting */ + offs[1] = 0; + for (len = 1; len < MAXBITS; len++) + offs[len + 1] = offs[len] + count[len]; + + /* sort symbols by length, by symbol order within each length */ + for (sym = 0; sym < codes; sym++) + if (lens[sym] != 0) work[offs[lens[sym]]++] = (unsigned short)sym; + + /* + Create and fill in decoding tables. In this loop, the table being + filled is at next and has curr index bits. The code being used is huff + with length len. That code is converted to an index by dropping drop + bits off of the bottom. For codes where len is less than drop + curr, + those top drop + curr - len bits are incremented through all values to + fill the table with replicated entries. + + root is the number of index bits for the root table. When len exceeds + root, sub-tables are created pointed to by the root entry with an index + of the low root bits of huff. This is saved in low to check for when a + new sub-table should be started. drop is zero when the root table is + being filled, and drop is root when sub-tables are being filled. + + When a new sub-table is needed, it is necessary to look ahead in the + code lengths to determine what size sub-table is needed. The length + counts are used for this, and so count[] is decremented as codes are + entered in the tables. + + used keeps track of how many table entries have been allocated from the + provided *table space. It is checked for LENS and DIST tables against + the constants ENOUGH_LENS and ENOUGH_DISTS to guard against changes in + the initial root table size constants. See the comments in inftrees.h + for more information. + + sym increments through all symbols, and the loop terminates when + all codes of length max, i.e. all codes, have been processed. This + routine permits incomplete codes, so another loop after this one fills + in the rest of the decoding tables with invalid code markers. + */ + + /* set up for code type */ + switch (type) { + case CODES: + base = extra = work; /* dummy value--not used */ + end = 19; + break; + case LENS: + base = lbase; + base -= 257; + extra = lext; + extra -= 257; + end = 256; + break; + default: /* DISTS */ + base = dbase; + extra = dext; + end = -1; + } + + /* initialize state for loop */ + huff = 0; /* starting code */ + sym = 0; /* starting code symbol */ + len = min; /* starting code length */ + next = *table; /* current table to fill in */ + curr = root; /* current table index bits */ + drop = 0; /* current bits to drop from code for index */ + low = (unsigned)(-1); /* trigger new sub-table when len > root */ + used = 1U << root; /* use root table entries */ + mask = used - 1; /* mask for comparing low */ + + /* check available table space */ + if ((type == LENS && used > ENOUGH_LENS) || + (type == DISTS && used > ENOUGH_DISTS)) + return 1; + + /* process all codes and make table entries */ + for (;;) { + /* create table entry */ + here.bits = (unsigned char)(len - drop); + if ((int)(work[sym]) < end) { + here.op = (unsigned char)0; + here.val = work[sym]; + } + else if ((int)(work[sym]) > end) { + here.op = (unsigned char)(extra[work[sym]]); + here.val = base[work[sym]]; + } + else { + here.op = (unsigned char)(32 + 64); /* end of block */ + here.val = 0; + } + + /* replicate for those indices with low len bits equal to huff */ + incr = 1U << (len - drop); + fill = 1U << curr; + min = fill; /* save offset to next table */ + do { + fill -= incr; + next[(huff >> drop) + fill] = here; + } while (fill != 0); + + /* backwards increment the len-bit code huff */ + incr = 1U << (len - 1); + while (huff & incr) + incr >>= 1; + if (incr != 0) { + huff &= incr - 1; + huff += incr; + } + else + huff = 0; + + /* go to next symbol, update count, len */ + sym++; + if (--(count[len]) == 0) { + if (len == max) break; + len = lens[work[sym]]; + } + + /* create new sub-table if needed */ + if (len > root && (huff & mask) != low) { + /* if first time, transition to sub-tables */ + if (drop == 0) + drop = root; + + /* increment past last table */ + next += min; /* here min is 1 << curr */ + + /* determine length of next table */ + curr = len - drop; + left = (int)(1 << curr); + while (curr + drop < max) { + left -= count[curr + drop]; + if (left <= 0) break; + curr++; + left <<= 1; + } + + /* check for enough space */ + used += 1U << curr; + if ((type == LENS && used > ENOUGH_LENS) || + (type == DISTS && used > ENOUGH_DISTS)) + return 1; + + /* point entry in root table to sub-table */ + low = huff & mask; + (*table)[low].op = (unsigned char)curr; + (*table)[low].bits = (unsigned char)root; + (*table)[low].val = (unsigned short)(next - *table); + } + } + + /* fill in remaining table entry if code is incomplete (guaranteed to have + at most one remaining entry, since if the code is incomplete, the + maximum code length that was allowed to get this far is one bit) */ + if (huff != 0) { + here.op = (unsigned char)64; /* invalid code marker */ + here.bits = (unsigned char)(len - drop); + here.val = (unsigned short)0; + next[huff] = here; + } + + /* set return parameters */ + *table += used; + *bits = root; + return 0; +} diff --git a/ml/dlib/dlib/external/zlib/inftrees.h b/ml/dlib/dlib/external/zlib/inftrees.h new file mode 100644 index 000000000..baa53a0b1 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/inftrees.h @@ -0,0 +1,62 @@ +/* inftrees.h -- header to use inftrees.c + * Copyright (C) 1995-2005, 2010 Mark Adler + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* WARNING: this file should *not* be used by applications. It is + part of the implementation of the compression library and is + subject to change. Applications should only use zlib.h. + */ + +/* Structure for decoding tables. Each entry provides either the + information needed to do the operation requested by the code that + indexed that table entry, or it provides a pointer to another + table that indexes more bits of the code. op indicates whether + the entry is a pointer to another table, a literal, a length or + distance, an end-of-block, or an invalid code. For a table + pointer, the low four bits of op is the number of index bits of + that table. For a length or distance, the low four bits of op + is the number of extra bits to get after the code. bits is + the number of bits in this code or part of the code to drop off + of the bit buffer. val is the actual byte to output in the case + of a literal, the base length or distance, or the offset from + the current table to the next table. Each entry is four bytes. */ +typedef struct { + unsigned char op; /* operation, extra bits, table bits */ + unsigned char bits; /* bits in this part of the code */ + unsigned short val; /* offset in table or code value */ +} code; + +/* op values as set by inflate_table(): + 00000000 - literal + 0000tttt - table link, tttt != 0 is the number of table index bits + 0001eeee - length or distance, eeee is the number of extra bits + 01100000 - end of block + 01000000 - invalid code + */ + +/* Maximum size of the dynamic table. The maximum number of code structures is + 1444, which is the sum of 852 for literal/length codes and 592 for distance + codes. These values were found by exhaustive searches using the program + examples/enough.c found in the zlib distribtution. The arguments to that + program are the number of symbols, the initial root table size, and the + maximum bit length of a code. "enough 286 9 15" for literal/length codes + returns returns 852, and "enough 30 6 15" for distance codes returns 592. + The initial root table size (9 or 6) is found in the fifth argument of the + inflate_table() calls in inflate.c and infback.c. If the root table size is + changed, then these maximum sizes would be need to be recalculated and + updated. */ +#define ENOUGH_LENS 852 +#define ENOUGH_DISTS 592 +#define ENOUGH (ENOUGH_LENS+ENOUGH_DISTS) + +/* Type of code to build for inflate_table() */ +typedef enum { + CODES, + LENS, + DISTS +} codetype; + +int ZLIB_INTERNAL inflate_table OF((codetype type, unsigned short FAR *lens, + unsigned codes, code FAR * FAR *table, + unsigned FAR *bits, unsigned short FAR *work)); diff --git a/ml/dlib/dlib/external/zlib/trees.c b/ml/dlib/dlib/external/zlib/trees.c new file mode 100644 index 000000000..1fd7759ef --- /dev/null +++ b/ml/dlib/dlib/external/zlib/trees.c @@ -0,0 +1,1226 @@ +/* trees.c -- output deflated data using Huffman coding + * Copyright (C) 1995-2012 Jean-loup Gailly + * detect_data_type() function provided freely by Cosmin Truta, 2006 + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* + * ALGORITHM + * + * The "deflation" process uses several Huffman trees. The more + * common source values are represented by shorter bit sequences. + * + * Each code tree is stored in a compressed form which is itself + * a Huffman encoding of the lengths of all the code strings (in + * ascending order by source values). The actual code strings are + * reconstructed from the lengths in the inflate process, as described + * in the deflate specification. + * + * REFERENCES + * + * Deutsch, L.P.,"'Deflate' Compressed Data Format Specification". + * Available in ftp.uu.net:/pub/archiving/zip/doc/deflate-1.1.doc + * + * Storer, James A. + * Data Compression: Methods and Theory, pp. 49-50. + * Computer Science Press, 1988. ISBN 0-7167-8156-5. + * + * Sedgewick, R. + * Algorithms, p290. + * Addison-Wesley, 1983. ISBN 0-201-06672-6. + */ + +/* @(#) $Id$ */ + +/* #define GEN_TREES_H */ + +#include "deflate.h" + +#ifdef DEBUG +# include +#endif + +/* =========================================================================== + * Constants + */ + +#define MAX_BL_BITS 7 +/* Bit length codes must not exceed MAX_BL_BITS bits */ + +#define END_BLOCK 256 +/* end of block literal code */ + +#define REP_3_6 16 +/* repeat previous bit length 3-6 times (2 bits of repeat count) */ + +#define REPZ_3_10 17 +/* repeat a zero length 3-10 times (3 bits of repeat count) */ + +#define REPZ_11_138 18 +/* repeat a zero length 11-138 times (7 bits of repeat count) */ + +local const int extra_lbits[LENGTH_CODES] /* extra bits for each length code */ + = {0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0}; + +local const int extra_dbits[D_CODES] /* extra bits for each distance code */ + = {0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13}; + +local const int extra_blbits[BL_CODES]/* extra bits for each bit length code */ + = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,3,7}; + +local const uch bl_order[BL_CODES] + = {16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15}; +/* The lengths of the bit length codes are sent in order of decreasing + * probability, to avoid transmitting the lengths for unused bit length codes. + */ + +/* =========================================================================== + * Local data. These are initialized only once. + */ + +#define DIST_CODE_LEN 512 /* see definition of array dist_code below */ + +#if defined(GEN_TREES_H) || !defined(STDC) +/* non ANSI compilers may not accept trees.h */ + +local ct_data static_ltree[L_CODES+2]; +/* The static literal tree. Since the bit lengths are imposed, there is no + * need for the L_CODES extra codes used during heap construction. However + * The codes 286 and 287 are needed to build a canonical tree (see _tr_init + * below). + */ + +local ct_data static_dtree[D_CODES]; +/* The static distance tree. (Actually a trivial tree since all codes use + * 5 bits.) + */ + +uch _dist_code[DIST_CODE_LEN]; +/* Distance codes. The first 256 values correspond to the distances + * 3 .. 258, the last 256 values correspond to the top 8 bits of + * the 15 bit distances. + */ + +uch _length_code[MAX_MATCH-MIN_MATCH+1]; +/* length code for each normalized match length (0 == MIN_MATCH) */ + +local int base_length[LENGTH_CODES]; +/* First normalized length for each code (0 = MIN_MATCH) */ + +local int base_dist[D_CODES]; +/* First normalized distance for each code (0 = distance of 1) */ + +#else +# include "trees.h" +#endif /* GEN_TREES_H */ + +struct static_tree_desc_s { + const ct_data *static_tree; /* static tree or NULL */ + const intf *extra_bits; /* extra bits for each code or NULL */ + int extra_base; /* base index for extra_bits */ + int elems; /* max number of elements in the tree */ + int max_length; /* max bit length for the codes */ +}; + +local static_tree_desc static_l_desc = +{static_ltree, extra_lbits, LITERALS+1, L_CODES, MAX_BITS}; + +local static_tree_desc static_d_desc = +{static_dtree, extra_dbits, 0, D_CODES, MAX_BITS}; + +local static_tree_desc static_bl_desc = +{(const ct_data *)0, extra_blbits, 0, BL_CODES, MAX_BL_BITS}; + +/* =========================================================================== + * Local (static) routines in this file. + */ + +local void tr_static_init OF((void)); +local void init_block OF((deflate_state *s)); +local void pqdownheap OF((deflate_state *s, ct_data *tree, int k)); +local void gen_bitlen OF((deflate_state *s, tree_desc *desc)); +local void gen_codes OF((ct_data *tree, int max_code, ushf *bl_count)); +local void build_tree OF((deflate_state *s, tree_desc *desc)); +local void scan_tree OF((deflate_state *s, ct_data *tree, int max_code)); +local void send_tree OF((deflate_state *s, ct_data *tree, int max_code)); +local int build_bl_tree OF((deflate_state *s)); +local void send_all_trees OF((deflate_state *s, int lcodes, int dcodes, + int blcodes)); +local void compress_block OF((deflate_state *s, const ct_data *ltree, + const ct_data *dtree)); +local int detect_data_type OF((deflate_state *s)); +local unsigned bi_reverse OF((unsigned value, int length)); +local void bi_windup OF((deflate_state *s)); +local void bi_flush OF((deflate_state *s)); +local void copy_block OF((deflate_state *s, charf *buf, unsigned len, + int header)); + +#ifdef GEN_TREES_H +local void gen_trees_header OF((void)); +#endif + +#ifndef DEBUG +# define send_code(s, c, tree) send_bits(s, tree[c].Code, tree[c].Len) + /* Send a code of the given tree. c and tree must not have side effects */ + +#else /* DEBUG */ +# define send_code(s, c, tree) \ + { if (z_verbose>2) fprintf(stderr,"\ncd %3d ",(c)); \ + send_bits(s, tree[c].Code, tree[c].Len); } +#endif + +/* =========================================================================== + * Output a short LSB first on the stream. + * IN assertion: there is enough room in pendingBuf. + */ +#define put_short(s, w) { \ + put_byte(s, (uch)((w) & 0xff)); \ + put_byte(s, (uch)((ush)(w) >> 8)); \ +} + +/* =========================================================================== + * Send a value on a given number of bits. + * IN assertion: length <= 16 and value fits in length bits. + */ +#ifdef DEBUG +local void send_bits OF((deflate_state *s, int value, int length)); + +local void send_bits(s, value, length) + deflate_state *s; + int value; /* value to send */ + int length; /* number of bits */ +{ + Tracevv((stderr," l %2d v %4x ", length, value)); + Assert(length > 0 && length <= 15, "invalid length"); + s->bits_sent += (ulg)length; + + /* If not enough room in bi_buf, use (valid) bits from bi_buf and + * (16 - bi_valid) bits from value, leaving (width - (16-bi_valid)) + * unused bits in value. + */ + if (s->bi_valid > (int)Buf_size - length) { + s->bi_buf |= (ush)value << s->bi_valid; + put_short(s, s->bi_buf); + s->bi_buf = (ush)value >> (Buf_size - s->bi_valid); + s->bi_valid += length - Buf_size; + } else { + s->bi_buf |= (ush)value << s->bi_valid; + s->bi_valid += length; + } +} +#else /* !DEBUG */ + +#define send_bits(s, value, length) \ +{ int len = length;\ + if (s->bi_valid > (int)Buf_size - len) {\ + int val = value;\ + s->bi_buf |= (ush)val << s->bi_valid;\ + put_short(s, s->bi_buf);\ + s->bi_buf = (ush)val >> (Buf_size - s->bi_valid);\ + s->bi_valid += len - Buf_size;\ + } else {\ + s->bi_buf |= (ush)(value) << s->bi_valid;\ + s->bi_valid += len;\ + }\ +} +#endif /* DEBUG */ + + +/* the arguments must not have side effects */ + +/* =========================================================================== + * Initialize the various 'constant' tables. + */ +local void tr_static_init() +{ +#if defined(GEN_TREES_H) || !defined(STDC) + static int static_init_done = 0; + int n; /* iterates over tree elements */ + int bits; /* bit counter */ + int length; /* length value */ + int code; /* code value */ + int dist; /* distance index */ + ush bl_count[MAX_BITS+1]; + /* number of codes at each bit length for an optimal tree */ + + if (static_init_done) return; + + /* For some embedded targets, global variables are not initialized: */ +#ifdef NO_INIT_GLOBAL_POINTERS + static_l_desc.static_tree = static_ltree; + static_l_desc.extra_bits = extra_lbits; + static_d_desc.static_tree = static_dtree; + static_d_desc.extra_bits = extra_dbits; + static_bl_desc.extra_bits = extra_blbits; +#endif + + /* Initialize the mapping length (0..255) -> length code (0..28) */ + length = 0; + for (code = 0; code < LENGTH_CODES-1; code++) { + base_length[code] = length; + for (n = 0; n < (1< dist code (0..29) */ + dist = 0; + for (code = 0 ; code < 16; code++) { + base_dist[code] = dist; + for (n = 0; n < (1<>= 7; /* from now on, all distances are divided by 128 */ + for ( ; code < D_CODES; code++) { + base_dist[code] = dist << 7; + for (n = 0; n < (1<<(extra_dbits[code]-7)); n++) { + _dist_code[256 + dist++] = (uch)code; + } + } + Assert (dist == 256, "tr_static_init: 256+dist != 512"); + + /* Construct the codes of the static literal tree */ + for (bits = 0; bits <= MAX_BITS; bits++) bl_count[bits] = 0; + n = 0; + while (n <= 143) static_ltree[n++].Len = 8, bl_count[8]++; + while (n <= 255) static_ltree[n++].Len = 9, bl_count[9]++; + while (n <= 279) static_ltree[n++].Len = 7, bl_count[7]++; + while (n <= 287) static_ltree[n++].Len = 8, bl_count[8]++; + /* Codes 286 and 287 do not exist, but we must include them in the + * tree construction to get a canonical Huffman tree (longest code + * all ones) + */ + gen_codes((ct_data *)static_ltree, L_CODES+1, bl_count); + + /* The static distance tree is trivial: */ + for (n = 0; n < D_CODES; n++) { + static_dtree[n].Len = 5; + static_dtree[n].Code = bi_reverse((unsigned)n, 5); + } + static_init_done = 1; + +# ifdef GEN_TREES_H + gen_trees_header(); +# endif +#endif /* defined(GEN_TREES_H) || !defined(STDC) */ +} + +/* =========================================================================== + * Genererate the file trees.h describing the static trees. + */ +#ifdef GEN_TREES_H +# ifndef DEBUG +# include +# endif + +# define SEPARATOR(i, last, width) \ + ((i) == (last)? "\n};\n\n" : \ + ((i) % (width) == (width)-1 ? ",\n" : ", ")) + +void gen_trees_header() +{ + FILE *header = fopen("trees.h", "w"); + int i; + + Assert (header != NULL, "Can't open trees.h"); + fprintf(header, + "/* header created automatically with -DGEN_TREES_H */\n\n"); + + fprintf(header, "local const ct_data static_ltree[L_CODES+2] = {\n"); + for (i = 0; i < L_CODES+2; i++) { + fprintf(header, "{{%3u},{%3u}}%s", static_ltree[i].Code, + static_ltree[i].Len, SEPARATOR(i, L_CODES+1, 5)); + } + + fprintf(header, "local const ct_data static_dtree[D_CODES] = {\n"); + for (i = 0; i < D_CODES; i++) { + fprintf(header, "{{%2u},{%2u}}%s", static_dtree[i].Code, + static_dtree[i].Len, SEPARATOR(i, D_CODES-1, 5)); + } + + fprintf(header, "const uch ZLIB_INTERNAL _dist_code[DIST_CODE_LEN] = {\n"); + for (i = 0; i < DIST_CODE_LEN; i++) { + fprintf(header, "%2u%s", _dist_code[i], + SEPARATOR(i, DIST_CODE_LEN-1, 20)); + } + + fprintf(header, + "const uch ZLIB_INTERNAL _length_code[MAX_MATCH-MIN_MATCH+1]= {\n"); + for (i = 0; i < MAX_MATCH-MIN_MATCH+1; i++) { + fprintf(header, "%2u%s", _length_code[i], + SEPARATOR(i, MAX_MATCH-MIN_MATCH, 20)); + } + + fprintf(header, "local const int base_length[LENGTH_CODES] = {\n"); + for (i = 0; i < LENGTH_CODES; i++) { + fprintf(header, "%1u%s", base_length[i], + SEPARATOR(i, LENGTH_CODES-1, 20)); + } + + fprintf(header, "local const int base_dist[D_CODES] = {\n"); + for (i = 0; i < D_CODES; i++) { + fprintf(header, "%5u%s", base_dist[i], + SEPARATOR(i, D_CODES-1, 10)); + } + + fclose(header); +} +#endif /* GEN_TREES_H */ + +/* =========================================================================== + * Initialize the tree data structures for a new zlib stream. + */ +void ZLIB_INTERNAL _tr_init(s) + deflate_state *s; +{ + tr_static_init(); + + s->l_desc.dyn_tree = s->dyn_ltree; + s->l_desc.stat_desc = &static_l_desc; + + s->d_desc.dyn_tree = s->dyn_dtree; + s->d_desc.stat_desc = &static_d_desc; + + s->bl_desc.dyn_tree = s->bl_tree; + s->bl_desc.stat_desc = &static_bl_desc; + + s->bi_buf = 0; + s->bi_valid = 0; +#ifdef DEBUG + s->compressed_len = 0L; + s->bits_sent = 0L; +#endif + + /* Initialize the first block of the first file: */ + init_block(s); +} + +/* =========================================================================== + * Initialize a new block. + */ +local void init_block(s) + deflate_state *s; +{ + int n; /* iterates over tree elements */ + + /* Initialize the trees. */ + for (n = 0; n < L_CODES; n++) s->dyn_ltree[n].Freq = 0; + for (n = 0; n < D_CODES; n++) s->dyn_dtree[n].Freq = 0; + for (n = 0; n < BL_CODES; n++) s->bl_tree[n].Freq = 0; + + s->dyn_ltree[END_BLOCK].Freq = 1; + s->opt_len = s->static_len = 0L; + s->last_lit = s->matches = 0; +} + +#define SMALLEST 1 +/* Index within the heap array of least frequent node in the Huffman tree */ + + +/* =========================================================================== + * Remove the smallest element from the heap and recreate the heap with + * one less element. Updates heap and heap_len. + */ +#define pqremove(s, tree, top) \ +{\ + top = s->heap[SMALLEST]; \ + s->heap[SMALLEST] = s->heap[s->heap_len--]; \ + pqdownheap(s, tree, SMALLEST); \ +} + +/* =========================================================================== + * Compares to subtrees, using the tree depth as tie breaker when + * the subtrees have equal frequency. This minimizes the worst case length. + */ +#define smaller(tree, n, m, depth) \ + (tree[n].Freq < tree[m].Freq || \ + (tree[n].Freq == tree[m].Freq && depth[n] <= depth[m])) + +/* =========================================================================== + * Restore the heap property by moving down the tree starting at node k, + * exchanging a node with the smallest of its two sons if necessary, stopping + * when the heap property is re-established (each father smaller than its + * two sons). + */ +local void pqdownheap(s, tree, k) + deflate_state *s; + ct_data *tree; /* the tree to restore */ + int k; /* node to move down */ +{ + int v = s->heap[k]; + int j = k << 1; /* left son of k */ + while (j <= s->heap_len) { + /* Set j to the smallest of the two sons: */ + if (j < s->heap_len && + smaller(tree, s->heap[j+1], s->heap[j], s->depth)) { + j++; + } + /* Exit if v is smaller than both sons */ + if (smaller(tree, v, s->heap[j], s->depth)) break; + + /* Exchange v with the smallest son */ + s->heap[k] = s->heap[j]; k = j; + + /* And continue down the tree, setting j to the left son of k */ + j <<= 1; + } + s->heap[k] = v; +} + +/* =========================================================================== + * Compute the optimal bit lengths for a tree and update the total bit length + * for the current block. + * IN assertion: the fields freq and dad are set, heap[heap_max] and + * above are the tree nodes sorted by increasing frequency. + * OUT assertions: the field len is set to the optimal bit length, the + * array bl_count contains the frequencies for each bit length. + * The length opt_len is updated; static_len is also updated if stree is + * not null. + */ +local void gen_bitlen(s, desc) + deflate_state *s; + tree_desc *desc; /* the tree descriptor */ +{ + ct_data *tree = desc->dyn_tree; + int max_code = desc->max_code; + const ct_data *stree = desc->stat_desc->static_tree; + const intf *extra = desc->stat_desc->extra_bits; + int base = desc->stat_desc->extra_base; + int max_length = desc->stat_desc->max_length; + int h; /* heap index */ + int n, m; /* iterate over the tree elements */ + int bits; /* bit length */ + int xbits; /* extra bits */ + ush f; /* frequency */ + int overflow = 0; /* number of elements with bit length too large */ + + for (bits = 0; bits <= MAX_BITS; bits++) s->bl_count[bits] = 0; + + /* In a first pass, compute the optimal bit lengths (which may + * overflow in the case of the bit length tree). + */ + tree[s->heap[s->heap_max]].Len = 0; /* root of the heap */ + + for (h = s->heap_max+1; h < HEAP_SIZE; h++) { + n = s->heap[h]; + bits = tree[tree[n].Dad].Len + 1; + if (bits > max_length) bits = max_length, overflow++; + tree[n].Len = (ush)bits; + /* We overwrite tree[n].Dad which is no longer needed */ + + if (n > max_code) continue; /* not a leaf node */ + + s->bl_count[bits]++; + xbits = 0; + if (n >= base) xbits = extra[n-base]; + f = tree[n].Freq; + s->opt_len += (ulg)f * (bits + xbits); + if (stree) s->static_len += (ulg)f * (stree[n].Len + xbits); + } + if (overflow == 0) return; + + Trace((stderr,"\nbit length overflow\n")); + /* This happens for example on obj2 and pic of the Calgary corpus */ + + /* Find the first bit length which could increase: */ + do { + bits = max_length-1; + while (s->bl_count[bits] == 0) bits--; + s->bl_count[bits]--; /* move one leaf down the tree */ + s->bl_count[bits+1] += 2; /* move one overflow item as its brother */ + s->bl_count[max_length]--; + /* The brother of the overflow item also moves one step up, + * but this does not affect bl_count[max_length] + */ + overflow -= 2; + } while (overflow > 0); + + /* Now recompute all bit lengths, scanning in increasing frequency. + * h is still equal to HEAP_SIZE. (It is simpler to reconstruct all + * lengths instead of fixing only the wrong ones. This idea is taken + * from 'ar' written by Haruhiko Okumura.) + */ + for (bits = max_length; bits != 0; bits--) { + n = s->bl_count[bits]; + while (n != 0) { + m = s->heap[--h]; + if (m > max_code) continue; + if ((unsigned) tree[m].Len != (unsigned) bits) { + Trace((stderr,"code %d bits %d->%d\n", m, tree[m].Len, bits)); + s->opt_len += ((long)bits - (long)tree[m].Len) + *(long)tree[m].Freq; + tree[m].Len = (ush)bits; + } + n--; + } + } +} + +/* =========================================================================== + * Generate the codes for a given tree and bit counts (which need not be + * optimal). + * IN assertion: the array bl_count contains the bit length statistics for + * the given tree and the field len is set for all tree elements. + * OUT assertion: the field code is set for all tree elements of non + * zero code length. + */ +local void gen_codes (tree, max_code, bl_count) + ct_data *tree; /* the tree to decorate */ + int max_code; /* largest code with non zero frequency */ + ushf *bl_count; /* number of codes at each bit length */ +{ + ush next_code[MAX_BITS+1]; /* next code value for each bit length */ + ush code = 0; /* running code value */ + int bits; /* bit index */ + int n; /* code index */ + + /* The distribution counts are first used to generate the code values + * without bit reversal. + */ + for (bits = 1; bits <= MAX_BITS; bits++) { + next_code[bits] = code = (code + bl_count[bits-1]) << 1; + } + /* Check that the bit counts in bl_count are consistent. The last code + * must be all ones. + */ + Assert (code + bl_count[MAX_BITS]-1 == (1<dyn_tree; + const ct_data *stree = desc->stat_desc->static_tree; + int elems = desc->stat_desc->elems; + int n, m; /* iterate over heap elements */ + int max_code = -1; /* largest code with non zero frequency */ + int node; /* new node being created */ + + /* Construct the initial heap, with least frequent element in + * heap[SMALLEST]. The sons of heap[n] are heap[2*n] and heap[2*n+1]. + * heap[0] is not used. + */ + s->heap_len = 0, s->heap_max = HEAP_SIZE; + + for (n = 0; n < elems; n++) { + if (tree[n].Freq != 0) { + s->heap[++(s->heap_len)] = max_code = n; + s->depth[n] = 0; + } else { + tree[n].Len = 0; + } + } + + /* The pkzip format requires that at least one distance code exists, + * and that at least one bit should be sent even if there is only one + * possible code. So to avoid special checks later on we force at least + * two codes of non zero frequency. + */ + while (s->heap_len < 2) { + node = s->heap[++(s->heap_len)] = (max_code < 2 ? ++max_code : 0); + tree[node].Freq = 1; + s->depth[node] = 0; + s->opt_len--; if (stree) s->static_len -= stree[node].Len; + /* node is 0 or 1 so it does not have extra bits */ + } + desc->max_code = max_code; + + /* The elements heap[heap_len/2+1 .. heap_len] are leaves of the tree, + * establish sub-heaps of increasing lengths: + */ + for (n = s->heap_len/2; n >= 1; n--) pqdownheap(s, tree, n); + + /* Construct the Huffman tree by repeatedly combining the least two + * frequent nodes. + */ + node = elems; /* next internal node of the tree */ + do { + pqremove(s, tree, n); /* n = node of least frequency */ + m = s->heap[SMALLEST]; /* m = node of next least frequency */ + + s->heap[--(s->heap_max)] = n; /* keep the nodes sorted by frequency */ + s->heap[--(s->heap_max)] = m; + + /* Create a new node father of n and m */ + tree[node].Freq = tree[n].Freq + tree[m].Freq; + s->depth[node] = (uch)((s->depth[n] >= s->depth[m] ? + s->depth[n] : s->depth[m]) + 1); + tree[n].Dad = tree[m].Dad = (ush)node; +#ifdef DUMP_BL_TREE + if (tree == s->bl_tree) { + fprintf(stderr,"\nnode %d(%d), sons %d(%d) %d(%d)", + node, tree[node].Freq, n, tree[n].Freq, m, tree[m].Freq); + } +#endif + /* and insert the new node in the heap */ + s->heap[SMALLEST] = node++; + pqdownheap(s, tree, SMALLEST); + + } while (s->heap_len >= 2); + + s->heap[--(s->heap_max)] = s->heap[SMALLEST]; + + /* At this point, the fields freq and dad are set. We can now + * generate the bit lengths. + */ + gen_bitlen(s, (tree_desc *)desc); + + /* The field len is now set, we can generate the bit codes */ + gen_codes ((ct_data *)tree, max_code, s->bl_count); +} + +/* =========================================================================== + * Scan a literal or distance tree to determine the frequencies of the codes + * in the bit length tree. + */ +local void scan_tree (s, tree, max_code) + deflate_state *s; + ct_data *tree; /* the tree to be scanned */ + int max_code; /* and its largest code of non zero frequency */ +{ + int n; /* iterates over all tree elements */ + int prevlen = -1; /* last emitted length */ + int curlen; /* length of current code */ + int nextlen = tree[0].Len; /* length of next code */ + int count = 0; /* repeat count of the current code */ + int max_count = 7; /* max repeat count */ + int min_count = 4; /* min repeat count */ + + if (nextlen == 0) max_count = 138, min_count = 3; + tree[max_code+1].Len = (ush)0xffff; /* guard */ + + for (n = 0; n <= max_code; n++) { + curlen = nextlen; nextlen = tree[n+1].Len; + if (++count < max_count && curlen == nextlen) { + continue; + } else if (count < min_count) { + s->bl_tree[curlen].Freq += count; + } else if (curlen != 0) { + if (curlen != prevlen) s->bl_tree[curlen].Freq++; + s->bl_tree[REP_3_6].Freq++; + } else if (count <= 10) { + s->bl_tree[REPZ_3_10].Freq++; + } else { + s->bl_tree[REPZ_11_138].Freq++; + } + count = 0; prevlen = curlen; + if (nextlen == 0) { + max_count = 138, min_count = 3; + } else if (curlen == nextlen) { + max_count = 6, min_count = 3; + } else { + max_count = 7, min_count = 4; + } + } +} + +/* =========================================================================== + * Send a literal or distance tree in compressed form, using the codes in + * bl_tree. + */ +local void send_tree (s, tree, max_code) + deflate_state *s; + ct_data *tree; /* the tree to be scanned */ + int max_code; /* and its largest code of non zero frequency */ +{ + int n; /* iterates over all tree elements */ + int prevlen = -1; /* last emitted length */ + int curlen; /* length of current code */ + int nextlen = tree[0].Len; /* length of next code */ + int count = 0; /* repeat count of the current code */ + int max_count = 7; /* max repeat count */ + int min_count = 4; /* min repeat count */ + + /* tree[max_code+1].Len = -1; */ /* guard already set */ + if (nextlen == 0) max_count = 138, min_count = 3; + + for (n = 0; n <= max_code; n++) { + curlen = nextlen; nextlen = tree[n+1].Len; + if (++count < max_count && curlen == nextlen) { + continue; + } else if (count < min_count) { + do { send_code(s, curlen, s->bl_tree); } while (--count != 0); + + } else if (curlen != 0) { + if (curlen != prevlen) { + send_code(s, curlen, s->bl_tree); count--; + } + Assert(count >= 3 && count <= 6, " 3_6?"); + send_code(s, REP_3_6, s->bl_tree); send_bits(s, count-3, 2); + + } else if (count <= 10) { + send_code(s, REPZ_3_10, s->bl_tree); send_bits(s, count-3, 3); + + } else { + send_code(s, REPZ_11_138, s->bl_tree); send_bits(s, count-11, 7); + } + count = 0; prevlen = curlen; + if (nextlen == 0) { + max_count = 138, min_count = 3; + } else if (curlen == nextlen) { + max_count = 6, min_count = 3; + } else { + max_count = 7, min_count = 4; + } + } +} + +/* =========================================================================== + * Construct the Huffman tree for the bit lengths and return the index in + * bl_order of the last bit length code to send. + */ +local int build_bl_tree(s) + deflate_state *s; +{ + int max_blindex; /* index of last bit length code of non zero freq */ + + /* Determine the bit length frequencies for literal and distance trees */ + scan_tree(s, (ct_data *)s->dyn_ltree, s->l_desc.max_code); + scan_tree(s, (ct_data *)s->dyn_dtree, s->d_desc.max_code); + + /* Build the bit length tree: */ + build_tree(s, (tree_desc *)(&(s->bl_desc))); + /* opt_len now includes the length of the tree representations, except + * the lengths of the bit lengths codes and the 5+5+4 bits for the counts. + */ + + /* Determine the number of bit length codes to send. The pkzip format + * requires that at least 4 bit length codes be sent. (appnote.txt says + * 3 but the actual value used is 4.) + */ + for (max_blindex = BL_CODES-1; max_blindex >= 3; max_blindex--) { + if (s->bl_tree[bl_order[max_blindex]].Len != 0) break; + } + /* Update opt_len to include the bit length tree and counts */ + s->opt_len += 3*(max_blindex+1) + 5+5+4; + Tracev((stderr, "\ndyn trees: dyn %ld, stat %ld", + s->opt_len, s->static_len)); + + return max_blindex; +} + +/* =========================================================================== + * Send the header for a block using dynamic Huffman trees: the counts, the + * lengths of the bit length codes, the literal tree and the distance tree. + * IN assertion: lcodes >= 257, dcodes >= 1, blcodes >= 4. + */ +local void send_all_trees(s, lcodes, dcodes, blcodes) + deflate_state *s; + int lcodes, dcodes, blcodes; /* number of codes for each tree */ +{ + int rank; /* index in bl_order */ + + Assert (lcodes >= 257 && dcodes >= 1 && blcodes >= 4, "not enough codes"); + Assert (lcodes <= L_CODES && dcodes <= D_CODES && blcodes <= BL_CODES, + "too many codes"); + Tracev((stderr, "\nbl counts: ")); + send_bits(s, lcodes-257, 5); /* not +255 as stated in appnote.txt */ + send_bits(s, dcodes-1, 5); + send_bits(s, blcodes-4, 4); /* not -3 as stated in appnote.txt */ + for (rank = 0; rank < blcodes; rank++) { + Tracev((stderr, "\nbl code %2d ", bl_order[rank])); + send_bits(s, s->bl_tree[bl_order[rank]].Len, 3); + } + Tracev((stderr, "\nbl tree: sent %ld", s->bits_sent)); + + send_tree(s, (ct_data *)s->dyn_ltree, lcodes-1); /* literal tree */ + Tracev((stderr, "\nlit tree: sent %ld", s->bits_sent)); + + send_tree(s, (ct_data *)s->dyn_dtree, dcodes-1); /* distance tree */ + Tracev((stderr, "\ndist tree: sent %ld", s->bits_sent)); +} + +/* =========================================================================== + * Send a stored block + */ +void ZLIB_INTERNAL _tr_stored_block(s, buf, stored_len, last) + deflate_state *s; + charf *buf; /* input block */ + ulg stored_len; /* length of input block */ + int last; /* one if this is the last block for a file */ +{ + send_bits(s, (STORED_BLOCK<<1)+last, 3); /* send block type */ +#ifdef DEBUG + s->compressed_len = (s->compressed_len + 3 + 7) & (ulg)~7L; + s->compressed_len += (stored_len + 4) << 3; +#endif + copy_block(s, buf, (unsigned)stored_len, 1); /* with header */ +} + +/* =========================================================================== + * Flush the bits in the bit buffer to pending output (leaves at most 7 bits) + */ +void ZLIB_INTERNAL _tr_flush_bits(s) + deflate_state *s; +{ + bi_flush(s); +} + +/* =========================================================================== + * Send one empty static block to give enough lookahead for inflate. + * This takes 10 bits, of which 7 may remain in the bit buffer. + */ +void ZLIB_INTERNAL _tr_align(s) + deflate_state *s; +{ + send_bits(s, STATIC_TREES<<1, 3); + send_code(s, END_BLOCK, static_ltree); +#ifdef DEBUG + s->compressed_len += 10L; /* 3 for block type, 7 for EOB */ +#endif + bi_flush(s); +} + +/* =========================================================================== + * Determine the best encoding for the current block: dynamic trees, static + * trees or store, and output the encoded block to the zip file. + */ +void ZLIB_INTERNAL _tr_flush_block(s, buf, stored_len, last) + deflate_state *s; + charf *buf; /* input block, or NULL if too old */ + ulg stored_len; /* length of input block */ + int last; /* one if this is the last block for a file */ +{ + ulg opt_lenb, static_lenb; /* opt_len and static_len in bytes */ + int max_blindex = 0; /* index of last bit length code of non zero freq */ + + /* Build the Huffman trees unless a stored block is forced */ + if (s->level > 0) { + + /* Check if the file is binary or text */ + if (s->strm->data_type == Z_UNKNOWN) + s->strm->data_type = detect_data_type(s); + + /* Construct the literal and distance trees */ + build_tree(s, (tree_desc *)(&(s->l_desc))); + Tracev((stderr, "\nlit data: dyn %ld, stat %ld", s->opt_len, + s->static_len)); + + build_tree(s, (tree_desc *)(&(s->d_desc))); + Tracev((stderr, "\ndist data: dyn %ld, stat %ld", s->opt_len, + s->static_len)); + /* At this point, opt_len and static_len are the total bit lengths of + * the compressed block data, excluding the tree representations. + */ + + /* Build the bit length tree for the above two trees, and get the index + * in bl_order of the last bit length code to send. + */ + max_blindex = build_bl_tree(s); + + /* Determine the best encoding. Compute the block lengths in bytes. */ + opt_lenb = (s->opt_len+3+7)>>3; + static_lenb = (s->static_len+3+7)>>3; + + Tracev((stderr, "\nopt %lu(%lu) stat %lu(%lu) stored %lu lit %u ", + opt_lenb, s->opt_len, static_lenb, s->static_len, stored_len, + s->last_lit)); + + if (static_lenb <= opt_lenb) opt_lenb = static_lenb; + + } else { + Assert(buf != (char*)0, "lost buf"); + opt_lenb = static_lenb = stored_len + 5; /* force a stored block */ + } + +#ifdef FORCE_STORED + if (buf != (char*)0) { /* force stored block */ +#else + if (stored_len+4 <= opt_lenb && buf != (char*)0) { + /* 4: two words for the lengths */ +#endif + /* The test buf != NULL is only necessary if LIT_BUFSIZE > WSIZE. + * Otherwise we can't have processed more than WSIZE input bytes since + * the last block flush, because compression would have been + * successful. If LIT_BUFSIZE <= WSIZE, it is never too late to + * transform a block into a stored block. + */ + _tr_stored_block(s, buf, stored_len, last); + +#ifdef FORCE_STATIC + } else if (static_lenb >= 0) { /* force static trees */ +#else + } else if (s->strategy == Z_FIXED || static_lenb == opt_lenb) { +#endif + send_bits(s, (STATIC_TREES<<1)+last, 3); + compress_block(s, (const ct_data *)static_ltree, + (const ct_data *)static_dtree); +#ifdef DEBUG + s->compressed_len += 3 + s->static_len; +#endif + } else { + send_bits(s, (DYN_TREES<<1)+last, 3); + send_all_trees(s, s->l_desc.max_code+1, s->d_desc.max_code+1, + max_blindex+1); + compress_block(s, (const ct_data *)s->dyn_ltree, + (const ct_data *)s->dyn_dtree); +#ifdef DEBUG + s->compressed_len += 3 + s->opt_len; +#endif + } + Assert (s->compressed_len == s->bits_sent, "bad compressed size"); + /* The above check is made mod 2^32, for files larger than 512 MB + * and uLong implemented on 32 bits. + */ + init_block(s); + + if (last) { + bi_windup(s); +#ifdef DEBUG + s->compressed_len += 7; /* align on byte boundary */ +#endif + } + Tracev((stderr,"\ncomprlen %lu(%lu) ", s->compressed_len>>3, + s->compressed_len-7*last)); +} + +/* =========================================================================== + * Save the match info and tally the frequency counts. Return true if + * the current block must be flushed. + */ +int ZLIB_INTERNAL _tr_tally (s, dist, lc) + deflate_state *s; + unsigned dist; /* distance of matched string */ + unsigned lc; /* match length-MIN_MATCH or unmatched char (if dist==0) */ +{ + s->d_buf[s->last_lit] = (ush)dist; + s->l_buf[s->last_lit++] = (uch)lc; + if (dist == 0) { + /* lc is the unmatched char */ + s->dyn_ltree[lc].Freq++; + } else { + s->matches++; + /* Here, lc is the match length - MIN_MATCH */ + dist--; /* dist = match distance - 1 */ + Assert((ush)dist < (ush)MAX_DIST(s) && + (ush)lc <= (ush)(MAX_MATCH-MIN_MATCH) && + (ush)d_code(dist) < (ush)D_CODES, "_tr_tally: bad match"); + + s->dyn_ltree[_length_code[lc]+LITERALS+1].Freq++; + s->dyn_dtree[d_code(dist)].Freq++; + } + +#ifdef TRUNCATE_BLOCK + /* Try to guess if it is profitable to stop the current block here */ + if ((s->last_lit & 0x1fff) == 0 && s->level > 2) { + /* Compute an upper bound for the compressed length */ + ulg out_length = (ulg)s->last_lit*8L; + ulg in_length = (ulg)((long)s->strstart - s->block_start); + int dcode; + for (dcode = 0; dcode < D_CODES; dcode++) { + out_length += (ulg)s->dyn_dtree[dcode].Freq * + (5L+extra_dbits[dcode]); + } + out_length >>= 3; + Tracev((stderr,"\nlast_lit %u, in %ld, out ~%ld(%ld%%) ", + s->last_lit, in_length, out_length, + 100L - out_length*100L/in_length)); + if (s->matches < s->last_lit/2 && out_length < in_length/2) return 1; + } +#endif + return (s->last_lit == s->lit_bufsize-1); + /* We avoid equality with lit_bufsize because of wraparound at 64K + * on 16 bit machines and because stored blocks are restricted to + * 64K-1 bytes. + */ +} + +/* =========================================================================== + * Send the block data compressed using the given Huffman trees + */ +local void compress_block(s, ltree, dtree) + deflate_state *s; + const ct_data *ltree; /* literal tree */ + const ct_data *dtree; /* distance tree */ +{ + unsigned dist; /* distance of matched string */ + int lc; /* match length or unmatched char (if dist == 0) */ + unsigned lx = 0; /* running index in l_buf */ + unsigned code; /* the code to send */ + int extra; /* number of extra bits to send */ + + if (s->last_lit != 0) do { + dist = s->d_buf[lx]; + lc = s->l_buf[lx++]; + if (dist == 0) { + send_code(s, lc, ltree); /* send a literal byte */ + Tracecv(isgraph(lc), (stderr," '%c' ", lc)); + } else { + /* Here, lc is the match length - MIN_MATCH */ + code = _length_code[lc]; + send_code(s, code+LITERALS+1, ltree); /* send the length code */ + extra = extra_lbits[code]; + if (extra != 0) { + lc -= base_length[code]; + send_bits(s, lc, extra); /* send the extra length bits */ + } + dist--; /* dist is now the match distance - 1 */ + code = d_code(dist); + Assert (code < D_CODES, "bad d_code"); + + send_code(s, code, dtree); /* send the distance code */ + extra = extra_dbits[code]; + if (extra != 0) { + dist -= base_dist[code]; + send_bits(s, dist, extra); /* send the extra distance bits */ + } + } /* literal or match pair ? */ + + /* Check that the overlay between pending_buf and d_buf+l_buf is ok: */ + Assert((uInt)(s->pending) < s->lit_bufsize + 2*lx, + "pendingBuf overflow"); + + } while (lx < s->last_lit); + + send_code(s, END_BLOCK, ltree); +} + +/* =========================================================================== + * Check if the data type is TEXT or BINARY, using the following algorithm: + * - TEXT if the two conditions below are satisfied: + * a) There are no non-portable control characters belonging to the + * "black list" (0..6, 14..25, 28..31). + * b) There is at least one printable character belonging to the + * "white list" (9 {TAB}, 10 {LF}, 13 {CR}, 32..255). + * - BINARY otherwise. + * - The following partially-portable control characters form a + * "gray list" that is ignored in this detection algorithm: + * (7 {BEL}, 8 {BS}, 11 {VT}, 12 {FF}, 26 {SUB}, 27 {ESC}). + * IN assertion: the fields Freq of dyn_ltree are set. + */ +local int detect_data_type(s) + deflate_state *s; +{ + /* black_mask is the bit mask of black-listed bytes + * set bits 0..6, 14..25, and 28..31 + * 0xf3ffc07f = binary 11110011111111111100000001111111 + */ + unsigned long black_mask = 0xf3ffc07fUL; + int n; + + /* Check for non-textual ("black-listed") bytes. */ + for (n = 0; n <= 31; n++, black_mask >>= 1) + if ((black_mask & 1) && (s->dyn_ltree[n].Freq != 0)) + return Z_BINARY; + + /* Check for textual ("white-listed") bytes. */ + if (s->dyn_ltree[9].Freq != 0 || s->dyn_ltree[10].Freq != 0 + || s->dyn_ltree[13].Freq != 0) + return Z_TEXT; + for (n = 32; n < LITERALS; n++) + if (s->dyn_ltree[n].Freq != 0) + return Z_TEXT; + + /* There are no "black-listed" or "white-listed" bytes: + * this stream either is empty or has tolerated ("gray-listed") bytes only. + */ + return Z_BINARY; +} + +/* =========================================================================== + * Reverse the first len bits of a code, using straightforward code (a faster + * method would use a table) + * IN assertion: 1 <= len <= 15 + */ +local unsigned bi_reverse(code, len) + unsigned code; /* the value to invert */ + int len; /* its bit length */ +{ + register unsigned res = 0; + do { + res |= code & 1; + code >>= 1, res <<= 1; + } while (--len > 0); + return res >> 1; +} + +/* =========================================================================== + * Flush the bit buffer, keeping at most 7 bits in it. + */ +local void bi_flush(s) + deflate_state *s; +{ + if (s->bi_valid == 16) { + put_short(s, s->bi_buf); + s->bi_buf = 0; + s->bi_valid = 0; + } else if (s->bi_valid >= 8) { + put_byte(s, (Byte)s->bi_buf); + s->bi_buf >>= 8; + s->bi_valid -= 8; + } +} + +/* =========================================================================== + * Flush the bit buffer and align the output on a byte boundary + */ +local void bi_windup(s) + deflate_state *s; +{ + if (s->bi_valid > 8) { + put_short(s, s->bi_buf); + } else if (s->bi_valid > 0) { + put_byte(s, (Byte)s->bi_buf); + } + s->bi_buf = 0; + s->bi_valid = 0; +#ifdef DEBUG + s->bits_sent = (s->bits_sent+7) & ~7; +#endif +} + +/* =========================================================================== + * Copy a stored block, storing first the length and its + * one's complement if requested. + */ +local void copy_block(s, buf, len, header) + deflate_state *s; + charf *buf; /* the input data */ + unsigned len; /* its length */ + int header; /* true if block header must be written */ +{ + bi_windup(s); /* align on byte boundary */ + + if (header) { + put_short(s, (ush)len); + put_short(s, (ush)~len); +#ifdef DEBUG + s->bits_sent += 2*16; +#endif + } +#ifdef DEBUG + s->bits_sent += (ulg)len<<3; +#endif + while (len--) { + put_byte(s, *buf++); + } +} diff --git a/ml/dlib/dlib/external/zlib/trees.h b/ml/dlib/dlib/external/zlib/trees.h new file mode 100644 index 000000000..d35639d82 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/trees.h @@ -0,0 +1,128 @@ +/* header created automatically with -DGEN_TREES_H */ + +local const ct_data static_ltree[L_CODES+2] = { +{{ 12},{ 8}}, {{140},{ 8}}, {{ 76},{ 8}}, {{204},{ 8}}, {{ 44},{ 8}}, +{{172},{ 8}}, {{108},{ 8}}, {{236},{ 8}}, {{ 28},{ 8}}, {{156},{ 8}}, +{{ 92},{ 8}}, {{220},{ 8}}, {{ 60},{ 8}}, {{188},{ 8}}, {{124},{ 8}}, +{{252},{ 8}}, {{ 2},{ 8}}, {{130},{ 8}}, {{ 66},{ 8}}, {{194},{ 8}}, +{{ 34},{ 8}}, {{162},{ 8}}, {{ 98},{ 8}}, {{226},{ 8}}, {{ 18},{ 8}}, +{{146},{ 8}}, {{ 82},{ 8}}, {{210},{ 8}}, {{ 50},{ 8}}, {{178},{ 8}}, +{{114},{ 8}}, {{242},{ 8}}, {{ 10},{ 8}}, {{138},{ 8}}, {{ 74},{ 8}}, +{{202},{ 8}}, {{ 42},{ 8}}, {{170},{ 8}}, {{106},{ 8}}, {{234},{ 8}}, +{{ 26},{ 8}}, {{154},{ 8}}, {{ 90},{ 8}}, {{218},{ 8}}, {{ 58},{ 8}}, +{{186},{ 8}}, {{122},{ 8}}, {{250},{ 8}}, {{ 6},{ 8}}, {{134},{ 8}}, +{{ 70},{ 8}}, {{198},{ 8}}, {{ 38},{ 8}}, {{166},{ 8}}, {{102},{ 8}}, +{{230},{ 8}}, {{ 22},{ 8}}, {{150},{ 8}}, {{ 86},{ 8}}, {{214},{ 8}}, +{{ 54},{ 8}}, {{182},{ 8}}, {{118},{ 8}}, {{246},{ 8}}, {{ 14},{ 8}}, +{{142},{ 8}}, {{ 78},{ 8}}, {{206},{ 8}}, {{ 46},{ 8}}, {{174},{ 8}}, +{{110},{ 8}}, {{238},{ 8}}, {{ 30},{ 8}}, {{158},{ 8}}, {{ 94},{ 8}}, +{{222},{ 8}}, {{ 62},{ 8}}, {{190},{ 8}}, {{126},{ 8}}, {{254},{ 8}}, +{{ 1},{ 8}}, {{129},{ 8}}, {{ 65},{ 8}}, {{193},{ 8}}, {{ 33},{ 8}}, +{{161},{ 8}}, {{ 97},{ 8}}, {{225},{ 8}}, {{ 17},{ 8}}, {{145},{ 8}}, +{{ 81},{ 8}}, {{209},{ 8}}, {{ 49},{ 8}}, {{177},{ 8}}, {{113},{ 8}}, +{{241},{ 8}}, {{ 9},{ 8}}, {{137},{ 8}}, {{ 73},{ 8}}, {{201},{ 8}}, +{{ 41},{ 8}}, {{169},{ 8}}, {{105},{ 8}}, {{233},{ 8}}, {{ 25},{ 8}}, +{{153},{ 8}}, {{ 89},{ 8}}, {{217},{ 8}}, {{ 57},{ 8}}, {{185},{ 8}}, +{{121},{ 8}}, {{249},{ 8}}, {{ 5},{ 8}}, {{133},{ 8}}, {{ 69},{ 8}}, +{{197},{ 8}}, {{ 37},{ 8}}, {{165},{ 8}}, {{101},{ 8}}, {{229},{ 8}}, +{{ 21},{ 8}}, {{149},{ 8}}, {{ 85},{ 8}}, {{213},{ 8}}, {{ 53},{ 8}}, +{{181},{ 8}}, {{117},{ 8}}, {{245},{ 8}}, {{ 13},{ 8}}, {{141},{ 8}}, +{{ 77},{ 8}}, {{205},{ 8}}, {{ 45},{ 8}}, {{173},{ 8}}, {{109},{ 8}}, +{{237},{ 8}}, {{ 29},{ 8}}, {{157},{ 8}}, {{ 93},{ 8}}, {{221},{ 8}}, +{{ 61},{ 8}}, {{189},{ 8}}, {{125},{ 8}}, {{253},{ 8}}, {{ 19},{ 9}}, +{{275},{ 9}}, {{147},{ 9}}, {{403},{ 9}}, {{ 83},{ 9}}, {{339},{ 9}}, +{{211},{ 9}}, {{467},{ 9}}, {{ 51},{ 9}}, {{307},{ 9}}, {{179},{ 9}}, +{{435},{ 9}}, {{115},{ 9}}, {{371},{ 9}}, {{243},{ 9}}, {{499},{ 9}}, +{{ 11},{ 9}}, {{267},{ 9}}, {{139},{ 9}}, {{395},{ 9}}, {{ 75},{ 9}}, +{{331},{ 9}}, {{203},{ 9}}, {{459},{ 9}}, {{ 43},{ 9}}, {{299},{ 9}}, +{{171},{ 9}}, {{427},{ 9}}, {{107},{ 9}}, {{363},{ 9}}, {{235},{ 9}}, +{{491},{ 9}}, {{ 27},{ 9}}, {{283},{ 9}}, {{155},{ 9}}, {{411},{ 9}}, +{{ 91},{ 9}}, {{347},{ 9}}, {{219},{ 9}}, {{475},{ 9}}, {{ 59},{ 9}}, +{{315},{ 9}}, {{187},{ 9}}, {{443},{ 9}}, {{123},{ 9}}, {{379},{ 9}}, +{{251},{ 9}}, {{507},{ 9}}, {{ 7},{ 9}}, {{263},{ 9}}, {{135},{ 9}}, +{{391},{ 9}}, {{ 71},{ 9}}, {{327},{ 9}}, {{199},{ 9}}, {{455},{ 9}}, +{{ 39},{ 9}}, {{295},{ 9}}, {{167},{ 9}}, {{423},{ 9}}, {{103},{ 9}}, +{{359},{ 9}}, {{231},{ 9}}, {{487},{ 9}}, {{ 23},{ 9}}, {{279},{ 9}}, +{{151},{ 9}}, {{407},{ 9}}, {{ 87},{ 9}}, {{343},{ 9}}, {{215},{ 9}}, +{{471},{ 9}}, {{ 55},{ 9}}, {{311},{ 9}}, {{183},{ 9}}, {{439},{ 9}}, +{{119},{ 9}}, {{375},{ 9}}, {{247},{ 9}}, {{503},{ 9}}, {{ 15},{ 9}}, +{{271},{ 9}}, {{143},{ 9}}, {{399},{ 9}}, {{ 79},{ 9}}, {{335},{ 9}}, +{{207},{ 9}}, {{463},{ 9}}, {{ 47},{ 9}}, {{303},{ 9}}, {{175},{ 9}}, +{{431},{ 9}}, {{111},{ 9}}, {{367},{ 9}}, {{239},{ 9}}, {{495},{ 9}}, +{{ 31},{ 9}}, {{287},{ 9}}, {{159},{ 9}}, {{415},{ 9}}, {{ 95},{ 9}}, +{{351},{ 9}}, {{223},{ 9}}, {{479},{ 9}}, {{ 63},{ 9}}, {{319},{ 9}}, +{{191},{ 9}}, {{447},{ 9}}, {{127},{ 9}}, {{383},{ 9}}, {{255},{ 9}}, +{{511},{ 9}}, {{ 0},{ 7}}, {{ 64},{ 7}}, {{ 32},{ 7}}, {{ 96},{ 7}}, +{{ 16},{ 7}}, {{ 80},{ 7}}, {{ 48},{ 7}}, {{112},{ 7}}, {{ 8},{ 7}}, +{{ 72},{ 7}}, {{ 40},{ 7}}, {{104},{ 7}}, {{ 24},{ 7}}, {{ 88},{ 7}}, +{{ 56},{ 7}}, {{120},{ 7}}, {{ 4},{ 7}}, {{ 68},{ 7}}, {{ 36},{ 7}}, +{{100},{ 7}}, {{ 20},{ 7}}, {{ 84},{ 7}}, {{ 52},{ 7}}, {{116},{ 7}}, +{{ 3},{ 8}}, {{131},{ 8}}, {{ 67},{ 8}}, {{195},{ 8}}, {{ 35},{ 8}}, +{{163},{ 8}}, {{ 99},{ 8}}, {{227},{ 8}} +}; + +local const ct_data static_dtree[D_CODES] = { +{{ 0},{ 5}}, {{16},{ 5}}, {{ 8},{ 5}}, {{24},{ 5}}, {{ 4},{ 5}}, +{{20},{ 5}}, {{12},{ 5}}, {{28},{ 5}}, {{ 2},{ 5}}, {{18},{ 5}}, +{{10},{ 5}}, {{26},{ 5}}, {{ 6},{ 5}}, {{22},{ 5}}, {{14},{ 5}}, +{{30},{ 5}}, {{ 1},{ 5}}, {{17},{ 5}}, {{ 9},{ 5}}, {{25},{ 5}}, +{{ 5},{ 5}}, {{21},{ 5}}, {{13},{ 5}}, {{29},{ 5}}, {{ 3},{ 5}}, +{{19},{ 5}}, {{11},{ 5}}, {{27},{ 5}}, {{ 7},{ 5}}, {{23},{ 5}} +}; + +const uch ZLIB_INTERNAL _dist_code[DIST_CODE_LEN] = { + 0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, + 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, +10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, +11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, +12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, +13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, +13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, +14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, +14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, +14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, +15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, +15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, +15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 0, 0, 16, 17, +18, 18, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, +23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, +24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, +26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, +26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, +27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, +27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, +28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, +28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, +28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, +29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, +29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, +29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29 +}; + +const uch ZLIB_INTERNAL _length_code[MAX_MATCH-MIN_MATCH+1]= { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 12, 12, +13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, +17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, +19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, +21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, +22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, +23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, +24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, +25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, +25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, +26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, +26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, +27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28 +}; + +local const int base_length[LENGTH_CODES] = { +0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, +64, 80, 96, 112, 128, 160, 192, 224, 0 +}; + +local const int base_dist[D_CODES] = { + 0, 1, 2, 3, 4, 6, 8, 12, 16, 24, + 32, 48, 64, 96, 128, 192, 256, 384, 512, 768, + 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576 +}; + diff --git a/ml/dlib/dlib/external/zlib/uncompr.c b/ml/dlib/dlib/external/zlib/uncompr.c new file mode 100644 index 000000000..242e9493d --- /dev/null +++ b/ml/dlib/dlib/external/zlib/uncompr.c @@ -0,0 +1,59 @@ +/* uncompr.c -- decompress a memory buffer + * Copyright (C) 1995-2003, 2010 Jean-loup Gailly. + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* @(#) $Id$ */ + +#define ZLIB_INTERNAL +#include "zlib.h" + +/* =========================================================================== + Decompresses the source buffer into the destination buffer. sourceLen is + the byte length of the source buffer. Upon entry, destLen is the total + size of the destination buffer, which must be large enough to hold the + entire uncompressed data. (The size of the uncompressed data must have + been saved previously by the compressor and transmitted to the decompressor + by some mechanism outside the scope of this compression library.) + Upon exit, destLen is the actual size of the compressed buffer. + + uncompress returns Z_OK if success, Z_MEM_ERROR if there was not + enough memory, Z_BUF_ERROR if there was not enough room in the output + buffer, or Z_DATA_ERROR if the input data was corrupted. +*/ +int ZEXPORT uncompress (dest, destLen, source, sourceLen) + Bytef *dest; + uLongf *destLen; + const Bytef *source; + uLong sourceLen; +{ + z_stream stream; + int err; + + stream.next_in = (z_const Bytef *)source; + stream.avail_in = (uInt)sourceLen; + /* Check for source > 64K on 16-bit machine: */ + if ((uLong)stream.avail_in != sourceLen) return Z_BUF_ERROR; + + stream.next_out = dest; + stream.avail_out = (uInt)*destLen; + if ((uLong)stream.avail_out != *destLen) return Z_BUF_ERROR; + + stream.zalloc = (alloc_func)0; + stream.zfree = (free_func)0; + + err = inflateInit(&stream); + if (err != Z_OK) return err; + + err = inflate(&stream, Z_FINISH); + if (err != Z_STREAM_END) { + inflateEnd(&stream); + if (err == Z_NEED_DICT || (err == Z_BUF_ERROR && stream.avail_in == 0)) + return Z_DATA_ERROR; + return err; + } + *destLen = stream.total_out; + + err = inflateEnd(&stream); + return err; +} diff --git a/ml/dlib/dlib/external/zlib/zconf.h b/ml/dlib/dlib/external/zlib/zconf.h new file mode 100644 index 000000000..9987a7755 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/zconf.h @@ -0,0 +1,511 @@ +/* zconf.h -- configuration of the zlib compression library + * Copyright (C) 1995-2013 Jean-loup Gailly. + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* @(#) $Id$ */ + +#ifndef ZCONF_H +#define ZCONF_H + +/* + * If you *really* need a unique prefix for all types and library functions, + * compile with -DZ_PREFIX. The "standard" zlib should be compiled without it. + * Even better than compiling with -DZ_PREFIX would be to use configure to set + * this permanently in zconf.h using "./configure --zprefix". + */ +#ifdef Z_PREFIX /* may be set to #if 1 by ./configure */ +# define Z_PREFIX_SET + +/* all linked symbols */ +# define _dist_code z__dist_code +# define _length_code z__length_code +# define _tr_align z__tr_align +# define _tr_flush_bits z__tr_flush_bits +# define _tr_flush_block z__tr_flush_block +# define _tr_init z__tr_init +# define _tr_stored_block z__tr_stored_block +# define _tr_tally z__tr_tally +# define adler32 z_adler32 +# define adler32_combine z_adler32_combine +# define adler32_combine64 z_adler32_combine64 +# ifndef Z_SOLO +# define compress z_compress +# define compress2 z_compress2 +# define compressBound z_compressBound +# endif +# define crc32 z_crc32 +# define crc32_combine z_crc32_combine +# define crc32_combine64 z_crc32_combine64 +# define deflate z_deflate +# define deflateBound z_deflateBound +# define deflateCopy z_deflateCopy +# define deflateEnd z_deflateEnd +# define deflateInit2_ z_deflateInit2_ +# define deflateInit_ z_deflateInit_ +# define deflateParams z_deflateParams +# define deflatePending z_deflatePending +# define deflatePrime z_deflatePrime +# define deflateReset z_deflateReset +# define deflateResetKeep z_deflateResetKeep +# define deflateSetDictionary z_deflateSetDictionary +# define deflateSetHeader z_deflateSetHeader +# define deflateTune z_deflateTune +# define deflate_copyright z_deflate_copyright +# define get_crc_table z_get_crc_table +# ifndef Z_SOLO +# define gz_error z_gz_error +# define gz_intmax z_gz_intmax +# define gz_strwinerror z_gz_strwinerror +# define gzbuffer z_gzbuffer +# define gzclearerr z_gzclearerr +# define gzclose z_gzclose +# define gzclose_r z_gzclose_r +# define gzclose_w z_gzclose_w +# define gzdirect z_gzdirect +# define gzdopen z_gzdopen +# define gzeof z_gzeof +# define gzerror z_gzerror +# define gzflush z_gzflush +# define gzgetc z_gzgetc +# define gzgetc_ z_gzgetc_ +# define gzgets z_gzgets +# define gzoffset z_gzoffset +# define gzoffset64 z_gzoffset64 +# define gzopen z_gzopen +# define gzopen64 z_gzopen64 +# ifdef _WIN32 +# define gzopen_w z_gzopen_w +# endif +# define gzprintf z_gzprintf +# define gzvprintf z_gzvprintf +# define gzputc z_gzputc +# define gzputs z_gzputs +# define gzread z_gzread +# define gzrewind z_gzrewind +# define gzseek z_gzseek +# define gzseek64 z_gzseek64 +# define gzsetparams z_gzsetparams +# define gztell z_gztell +# define gztell64 z_gztell64 +# define gzungetc z_gzungetc +# define gzwrite z_gzwrite +# endif +# define inflate z_inflate +# define inflateBack z_inflateBack +# define inflateBackEnd z_inflateBackEnd +# define inflateBackInit_ z_inflateBackInit_ +# define inflateCopy z_inflateCopy +# define inflateEnd z_inflateEnd +# define inflateGetHeader z_inflateGetHeader +# define inflateInit2_ z_inflateInit2_ +# define inflateInit_ z_inflateInit_ +# define inflateMark z_inflateMark +# define inflatePrime z_inflatePrime +# define inflateReset z_inflateReset +# define inflateReset2 z_inflateReset2 +# define inflateSetDictionary z_inflateSetDictionary +# define inflateGetDictionary z_inflateGetDictionary +# define inflateSync z_inflateSync +# define inflateSyncPoint z_inflateSyncPoint +# define inflateUndermine z_inflateUndermine +# define inflateResetKeep z_inflateResetKeep +# define inflate_copyright z_inflate_copyright +# define inflate_fast z_inflate_fast +# define inflate_table z_inflate_table +# ifndef Z_SOLO +# define uncompress z_uncompress +# endif +# define zError z_zError +# ifndef Z_SOLO +# define zcalloc z_zcalloc +# define zcfree z_zcfree +# endif +# define zlibCompileFlags z_zlibCompileFlags +# define zlibVersion z_zlibVersion + +/* all zlib typedefs in zlib.h and zconf.h */ +# define Byte z_Byte +# define Bytef z_Bytef +# define alloc_func z_alloc_func +# define charf z_charf +# define free_func z_free_func +# ifndef Z_SOLO +# define gzFile z_gzFile +# endif +# define gz_header z_gz_header +# define gz_headerp z_gz_headerp +# define in_func z_in_func +# define intf z_intf +# define out_func z_out_func +# define uInt z_uInt +# define uIntf z_uIntf +# define uLong z_uLong +# define uLongf z_uLongf +# define voidp z_voidp +# define voidpc z_voidpc +# define voidpf z_voidpf + +/* all zlib structs in zlib.h and zconf.h */ +# define gz_header_s z_gz_header_s +# define internal_state z_internal_state + +#endif + +#if defined(__MSDOS__) && !defined(MSDOS) +# define MSDOS +#endif +#if (defined(OS_2) || defined(__OS2__)) && !defined(OS2) +# define OS2 +#endif +#if defined(_WINDOWS) && !defined(WINDOWS) +# define WINDOWS +#endif +#if defined(_WIN32) || defined(_WIN32_WCE) || defined(__WIN32__) +# ifndef WIN32 +# define WIN32 +# endif +#endif +#if (defined(MSDOS) || defined(OS2) || defined(WINDOWS)) && !defined(WIN32) +# if !defined(__GNUC__) && !defined(__FLAT__) && !defined(__386__) +# ifndef SYS16BIT +# define SYS16BIT +# endif +# endif +#endif + +/* + * Compile with -DMAXSEG_64K if the alloc function cannot allocate more + * than 64k bytes at a time (needed on systems with 16-bit int). + */ +#ifdef SYS16BIT +# define MAXSEG_64K +#endif +#ifdef MSDOS +# define UNALIGNED_OK +#endif + +#ifdef __STDC_VERSION__ +# ifndef STDC +# define STDC +# endif +# if __STDC_VERSION__ >= 199901L +# ifndef STDC99 +# define STDC99 +# endif +# endif +#endif +#if !defined(STDC) && (defined(__STDC__) || defined(__cplusplus)) +# define STDC +#endif +#if !defined(STDC) && (defined(__GNUC__) || defined(__BORLANDC__)) +# define STDC +#endif +#if !defined(STDC) && (defined(MSDOS) || defined(WINDOWS) || defined(WIN32)) +# define STDC +#endif +#if !defined(STDC) && (defined(OS2) || defined(__HOS_AIX__)) +# define STDC +#endif + +#if defined(__OS400__) && !defined(STDC) /* iSeries (formerly AS/400). */ +# define STDC +#endif + +#ifndef STDC +# ifndef const /* cannot use !defined(STDC) && !defined(const) on Mac */ +# define const /* note: need a more gentle solution here */ +# endif +#endif + +#if defined(ZLIB_CONST) && !defined(z_const) +# define z_const const +#else +# define z_const +#endif + +/* Some Mac compilers merge all .h files incorrectly: */ +#if defined(__MWERKS__)||defined(applec)||defined(THINK_C)||defined(__SC__) +# define NO_DUMMY_DECL +#endif + +/* Maximum value for memLevel in deflateInit2 */ +#ifndef MAX_MEM_LEVEL +# ifdef MAXSEG_64K +# define MAX_MEM_LEVEL 8 +# else +# define MAX_MEM_LEVEL 9 +# endif +#endif + +/* Maximum value for windowBits in deflateInit2 and inflateInit2. + * WARNING: reducing MAX_WBITS makes minigzip unable to extract .gz files + * created by gzip. (Files created by minigzip can still be extracted by + * gzip.) + */ +#ifndef MAX_WBITS +# define MAX_WBITS 15 /* 32K LZ77 window */ +#endif + +/* The memory requirements for deflate are (in bytes): + (1 << (windowBits+2)) + (1 << (memLevel+9)) + that is: 128K for windowBits=15 + 128K for memLevel = 8 (default values) + plus a few kilobytes for small objects. For example, if you want to reduce + the default memory requirements from 256K to 128K, compile with + make CFLAGS="-O -DMAX_WBITS=14 -DMAX_MEM_LEVEL=7" + Of course this will generally degrade compression (there's no free lunch). + + The memory requirements for inflate are (in bytes) 1 << windowBits + that is, 32K for windowBits=15 (default value) plus a few kilobytes + for small objects. +*/ + + /* Type declarations */ + +#ifndef OF /* function prototypes */ +# ifdef STDC +# define OF(args) args +# else +# define OF(args) () +# endif +#endif + +#ifndef Z_ARG /* function prototypes for stdarg */ +# if defined(STDC) || defined(Z_HAVE_STDARG_H) +# define Z_ARG(args) args +# else +# define Z_ARG(args) () +# endif +#endif + +/* The following definitions for FAR are needed only for MSDOS mixed + * model programming (small or medium model with some far allocations). + * This was tested only with MSC; for other MSDOS compilers you may have + * to define NO_MEMCPY in zutil.h. If you don't need the mixed model, + * just define FAR to be empty. + */ +#ifdef SYS16BIT +# if defined(M_I86SM) || defined(M_I86MM) + /* MSC small or medium model */ +# define SMALL_MEDIUM +# ifdef _MSC_VER +# define FAR _far +# else +# define FAR far +# endif +# endif +# if (defined(__SMALL__) || defined(__MEDIUM__)) + /* Turbo C small or medium model */ +# define SMALL_MEDIUM +# ifdef __BORLANDC__ +# define FAR _far +# else +# define FAR far +# endif +# endif +#endif + +#if defined(WINDOWS) || defined(WIN32) + /* If building or using zlib as a DLL, define ZLIB_DLL. + * This is not mandatory, but it offers a little performance increase. + */ +# ifdef ZLIB_DLL +# if defined(WIN32) && (!defined(__BORLANDC__) || (__BORLANDC__ >= 0x500)) +# ifdef ZLIB_INTERNAL +# define ZEXTERN extern __declspec(dllexport) +# else +# define ZEXTERN extern __declspec(dllimport) +# endif +# endif +# endif /* ZLIB_DLL */ + /* If building or using zlib with the WINAPI/WINAPIV calling convention, + * define ZLIB_WINAPI. + * Caution: the standard ZLIB1.DLL is NOT compiled using ZLIB_WINAPI. + */ +# ifdef ZLIB_WINAPI +# ifdef FAR +# undef FAR +# endif +# include + /* No need for _export, use ZLIB.DEF instead. */ + /* For complete Windows compatibility, use WINAPI, not __stdcall. */ +# define ZEXPORT WINAPI +# ifdef WIN32 +# define ZEXPORTVA WINAPIV +# else +# define ZEXPORTVA FAR CDECL +# endif +# endif +#endif + +#if defined (__BEOS__) +# ifdef ZLIB_DLL +# ifdef ZLIB_INTERNAL +# define ZEXPORT __declspec(dllexport) +# define ZEXPORTVA __declspec(dllexport) +# else +# define ZEXPORT __declspec(dllimport) +# define ZEXPORTVA __declspec(dllimport) +# endif +# endif +#endif + +#ifndef ZEXTERN +# define ZEXTERN extern +#endif +#ifndef ZEXPORT +# define ZEXPORT +#endif +#ifndef ZEXPORTVA +# define ZEXPORTVA +#endif + +#ifndef FAR +# define FAR +#endif + +#if !defined(__MACTYPES__) +typedef unsigned char Byte; /* 8 bits */ +#endif +typedef unsigned int uInt; /* 16 bits or more */ +typedef unsigned long uLong; /* 32 bits or more */ + +#ifdef SMALL_MEDIUM + /* Borland C/C++ and some old MSC versions ignore FAR inside typedef */ +# define Bytef Byte FAR +#else + typedef Byte FAR Bytef; +#endif +typedef char FAR charf; +typedef int FAR intf; +typedef uInt FAR uIntf; +typedef uLong FAR uLongf; + +#ifdef STDC + typedef void const *voidpc; + typedef void FAR *voidpf; + typedef void *voidp; +#else + typedef Byte const *voidpc; + typedef Byte FAR *voidpf; + typedef Byte *voidp; +#endif + +#if !defined(Z_U4) && !defined(Z_SOLO) && defined(STDC) +# include +# if (UINT_MAX == 0xffffffffUL) +# define Z_U4 unsigned +# elif (ULONG_MAX == 0xffffffffUL) +# define Z_U4 unsigned long +# elif (USHRT_MAX == 0xffffffffUL) +# define Z_U4 unsigned short +# endif +#endif + +#ifdef Z_U4 + typedef Z_U4 z_crc_t; +#else + typedef unsigned long z_crc_t; +#endif + +#ifdef HAVE_UNISTD_H /* may be set to #if 1 by ./configure */ +# define Z_HAVE_UNISTD_H +#endif + +#ifdef HAVE_STDARG_H /* may be set to #if 1 by ./configure */ +# define Z_HAVE_STDARG_H +#endif + +#ifdef STDC +# ifndef Z_SOLO +# include /* for off_t */ +# endif +#endif + +#if defined(STDC) || defined(Z_HAVE_STDARG_H) +# ifndef Z_SOLO +# include /* for va_list */ +# endif +#endif + +#ifdef _WIN32 +# ifndef Z_SOLO +# include /* for wchar_t */ +# endif +#endif + +/* a little trick to accommodate both "#define _LARGEFILE64_SOURCE" and + * "#define _LARGEFILE64_SOURCE 1" as requesting 64-bit operations, (even + * though the former does not conform to the LFS document), but considering + * both "#undef _LARGEFILE64_SOURCE" and "#define _LARGEFILE64_SOURCE 0" as + * equivalently requesting no 64-bit operations + */ +#if defined(_LARGEFILE64_SOURCE) && -_LARGEFILE64_SOURCE - -1 == 1 +# undef _LARGEFILE64_SOURCE +#endif + +#if defined(__WATCOMC__) && !defined(Z_HAVE_UNISTD_H) +# define Z_HAVE_UNISTD_H +#endif +#ifndef Z_SOLO +# if defined(Z_HAVE_UNISTD_H) || defined(_LARGEFILE64_SOURCE) +# include /* for SEEK_*, off_t, and _LFS64_LARGEFILE */ +# ifdef VMS +# include /* for off_t */ +# endif +# ifndef z_off_t +# define z_off_t off_t +# endif +# endif +#endif + +#if defined(_LFS64_LARGEFILE) && _LFS64_LARGEFILE-0 +# define Z_LFS64 +#endif + +#if defined(_LARGEFILE64_SOURCE) && defined(Z_LFS64) +# define Z_LARGE64 +#endif + +#if defined(_FILE_OFFSET_BITS) && _FILE_OFFSET_BITS-0 == 64 && defined(Z_LFS64) +# define Z_WANT64 +#endif + +#if !defined(SEEK_SET) && !defined(Z_SOLO) +# define SEEK_SET 0 /* Seek from beginning of file. */ +# define SEEK_CUR 1 /* Seek from current position. */ +# define SEEK_END 2 /* Set file pointer to EOF plus "offset" */ +#endif + +#ifndef z_off_t +# define z_off_t long +#endif + +#if !defined(_WIN32) && defined(Z_LARGE64) +# define z_off64_t off64_t +#else +# if defined(_WIN32) && !defined(__GNUC__) && !defined(Z_SOLO) +# define z_off64_t __int64 +# else +# define z_off64_t z_off_t +# endif +#endif + +/* MVS linker does not support external names larger than 8 bytes */ +#if defined(__MVS__) + #pragma map(deflateInit_,"DEIN") + #pragma map(deflateInit2_,"DEIN2") + #pragma map(deflateEnd,"DEEND") + #pragma map(deflateBound,"DEBND") + #pragma map(inflateInit_,"ININ") + #pragma map(inflateInit2_,"ININ2") + #pragma map(inflateEnd,"INEND") + #pragma map(inflateSync,"INSY") + #pragma map(inflateSetDictionary,"INSEDI") + #pragma map(compressBound,"CMBND") + #pragma map(inflate_table,"INTABL") + #pragma map(inflate_fast,"INFA") + #pragma map(inflate_copyright,"INCOPY") +#endif + +#endif /* ZCONF_H */ diff --git a/ml/dlib/dlib/external/zlib/zlib.h b/ml/dlib/dlib/external/zlib/zlib.h new file mode 100644 index 000000000..3e0c7672a --- /dev/null +++ b/ml/dlib/dlib/external/zlib/zlib.h @@ -0,0 +1,1768 @@ +/* zlib.h -- interface of the 'zlib' general purpose compression library + version 1.2.8, April 28th, 2013 + + Copyright (C) 1995-2013 Jean-loup Gailly and Mark Adler + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + Jean-loup Gailly Mark Adler + jloup@gzip.org madler@alumni.caltech.edu + + + The data format used by the zlib library is described by RFCs (Request for + Comments) 1950 to 1952 in the files http://tools.ietf.org/html/rfc1950 + (zlib format), rfc1951 (deflate format) and rfc1952 (gzip format). +*/ + +#ifndef ZLIB_H +#define ZLIB_H + +#include "zconf.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define ZLIB_VERSION "1.2.8" +#define ZLIB_VERNUM 0x1280 +#define ZLIB_VER_MAJOR 1 +#define ZLIB_VER_MINOR 2 +#define ZLIB_VER_REVISION 8 +#define ZLIB_VER_SUBREVISION 0 + +/* + The 'zlib' compression library provides in-memory compression and + decompression functions, including integrity checks of the uncompressed data. + This version of the library supports only one compression method (deflation) + but other algorithms will be added later and will have the same stream + interface. + + Compression can be done in a single step if the buffers are large enough, + or can be done by repeated calls of the compression function. In the latter + case, the application must provide more input and/or consume the output + (providing more output space) before each call. + + The compressed data format used by default by the in-memory functions is + the zlib format, which is a zlib wrapper documented in RFC 1950, wrapped + around a deflate stream, which is itself documented in RFC 1951. + + The library also supports reading and writing files in gzip (.gz) format + with an interface similar to that of stdio using the functions that start + with "gz". The gzip format is different from the zlib format. gzip is a + gzip wrapper, documented in RFC 1952, wrapped around a deflate stream. + + This library can optionally read and write gzip streams in memory as well. + + The zlib format was designed to be compact and fast for use in memory + and on communications channels. The gzip format was designed for single- + file compression on file systems, has a larger header than zlib to maintain + directory information, and uses a different, slower check method than zlib. + + The library does not install any signal handler. The decoder checks + the consistency of the compressed data, so the library should never crash + even in case of corrupted input. +*/ + +typedef voidpf (*alloc_func) OF((voidpf opaque, uInt items, uInt size)); +typedef void (*free_func) OF((voidpf opaque, voidpf address)); + +struct internal_state; + +typedef struct z_stream_s { + z_const Bytef *next_in; /* next input byte */ + uInt avail_in; /* number of bytes available at next_in */ + uLong total_in; /* total number of input bytes read so far */ + + Bytef *next_out; /* next output byte should be put there */ + uInt avail_out; /* remaining free space at next_out */ + uLong total_out; /* total number of bytes output so far */ + + z_const char *msg; /* last error message, NULL if no error */ + struct internal_state FAR *state; /* not visible by applications */ + + alloc_func zalloc; /* used to allocate the internal state */ + free_func zfree; /* used to free the internal state */ + voidpf opaque; /* private data object passed to zalloc and zfree */ + + int data_type; /* best guess about the data type: binary or text */ + uLong adler; /* adler32 value of the uncompressed data */ + uLong reserved; /* reserved for future use */ +} z_stream; + +typedef z_stream FAR *z_streamp; + +/* + gzip header information passed to and from zlib routines. See RFC 1952 + for more details on the meanings of these fields. +*/ +typedef struct gz_header_s { + int text; /* true if compressed data believed to be text */ + uLong time; /* modification time */ + int xflags; /* extra flags (not used when writing a gzip file) */ + int os; /* operating system */ + Bytef *extra; /* pointer to extra field or Z_NULL if none */ + uInt extra_len; /* extra field length (valid if extra != Z_NULL) */ + uInt extra_max; /* space at extra (only when reading header) */ + Bytef *name; /* pointer to zero-terminated file name or Z_NULL */ + uInt name_max; /* space at name (only when reading header) */ + Bytef *comment; /* pointer to zero-terminated comment or Z_NULL */ + uInt comm_max; /* space at comment (only when reading header) */ + int hcrc; /* true if there was or will be a header crc */ + int done; /* true when done reading gzip header (not used + when writing a gzip file) */ +} gz_header; + +typedef gz_header FAR *gz_headerp; + +/* + The application must update next_in and avail_in when avail_in has dropped + to zero. It must update next_out and avail_out when avail_out has dropped + to zero. The application must initialize zalloc, zfree and opaque before + calling the init function. All other fields are set by the compression + library and must not be updated by the application. + + The opaque value provided by the application will be passed as the first + parameter for calls of zalloc and zfree. This can be useful for custom + memory management. The compression library attaches no meaning to the + opaque value. + + zalloc must return Z_NULL if there is not enough memory for the object. + If zlib is used in a multi-threaded application, zalloc and zfree must be + thread safe. + + On 16-bit systems, the functions zalloc and zfree must be able to allocate + exactly 65536 bytes, but will not be required to allocate more than this if + the symbol MAXSEG_64K is defined (see zconf.h). WARNING: On MSDOS, pointers + returned by zalloc for objects of exactly 65536 bytes *must* have their + offset normalized to zero. The default allocation function provided by this + library ensures this (see zutil.c). To reduce memory requirements and avoid + any allocation of 64K objects, at the expense of compression ratio, compile + the library with -DMAX_WBITS=14 (see zconf.h). + + The fields total_in and total_out can be used for statistics or progress + reports. After compression, total_in holds the total size of the + uncompressed data and may be saved for use in the decompressor (particularly + if the decompressor wants to decompress everything in a single step). +*/ + + /* constants */ + +#define Z_NO_FLUSH 0 +#define Z_PARTIAL_FLUSH 1 +#define Z_SYNC_FLUSH 2 +#define Z_FULL_FLUSH 3 +#define Z_FINISH 4 +#define Z_BLOCK 5 +#define Z_TREES 6 +/* Allowed flush values; see deflate() and inflate() below for details */ + +#define Z_OK 0 +#define Z_STREAM_END 1 +#define Z_NEED_DICT 2 +#define Z_ERRNO (-1) +#define Z_STREAM_ERROR (-2) +#define Z_DATA_ERROR (-3) +#define Z_MEM_ERROR (-4) +#define Z_BUF_ERROR (-5) +#define Z_VERSION_ERROR (-6) +/* Return codes for the compression/decompression functions. Negative values + * are errors, positive values are used for special but normal events. + */ + +#define Z_NO_COMPRESSION 0 +#define Z_BEST_SPEED 1 +#define Z_BEST_COMPRESSION 9 +#define Z_DEFAULT_COMPRESSION (-1) +/* compression levels */ + +#define Z_FILTERED 1 +#define Z_HUFFMAN_ONLY 2 +#define Z_RLE 3 +#define Z_FIXED 4 +#define Z_DEFAULT_STRATEGY 0 +/* compression strategy; see deflateInit2() below for details */ + +#define Z_BINARY 0 +#define Z_TEXT 1 +#define Z_ASCII Z_TEXT /* for compatibility with 1.2.2 and earlier */ +#define Z_UNKNOWN 2 +/* Possible values of the data_type field (though see inflate()) */ + +#define Z_DEFLATED 8 +/* The deflate compression method (the only one supported in this version) */ + +#define Z_NULL 0 /* for initializing zalloc, zfree, opaque */ + +#define zlib_version zlibVersion() +/* for compatibility with versions < 1.0.2 */ + + + /* basic functions */ + +ZEXTERN const char * ZEXPORT zlibVersion OF((void)); +/* The application can compare zlibVersion and ZLIB_VERSION for consistency. + If the first character differs, the library code actually used is not + compatible with the zlib.h header file used by the application. This check + is automatically made by deflateInit and inflateInit. + */ + +/* +ZEXTERN int ZEXPORT deflateInit OF((z_streamp strm, int level)); + + Initializes the internal stream state for compression. The fields + zalloc, zfree and opaque must be initialized before by the caller. If + zalloc and zfree are set to Z_NULL, deflateInit updates them to use default + allocation functions. + + The compression level must be Z_DEFAULT_COMPRESSION, or between 0 and 9: + 1 gives best speed, 9 gives best compression, 0 gives no compression at all + (the input data is simply copied a block at a time). Z_DEFAULT_COMPRESSION + requests a default compromise between speed and compression (currently + equivalent to level 6). + + deflateInit returns Z_OK if success, Z_MEM_ERROR if there was not enough + memory, Z_STREAM_ERROR if level is not a valid compression level, or + Z_VERSION_ERROR if the zlib library version (zlib_version) is incompatible + with the version assumed by the caller (ZLIB_VERSION). msg is set to null + if there is no error message. deflateInit does not perform any compression: + this will be done by deflate(). +*/ + + +ZEXTERN int ZEXPORT deflate OF((z_streamp strm, int flush)); +/* + deflate compresses as much data as possible, and stops when the input + buffer becomes empty or the output buffer becomes full. It may introduce + some output latency (reading input without producing any output) except when + forced to flush. + + The detailed semantics are as follows. deflate performs one or both of the + following actions: + + - Compress more input starting at next_in and update next_in and avail_in + accordingly. If not all input can be processed (because there is not + enough room in the output buffer), next_in and avail_in are updated and + processing will resume at this point for the next call of deflate(). + + - Provide more output starting at next_out and update next_out and avail_out + accordingly. This action is forced if the parameter flush is non zero. + Forcing flush frequently degrades the compression ratio, so this parameter + should be set only when necessary (in interactive applications). Some + output may be provided even if flush is not set. + + Before the call of deflate(), the application should ensure that at least + one of the actions is possible, by providing more input and/or consuming more + output, and updating avail_in or avail_out accordingly; avail_out should + never be zero before the call. The application can consume the compressed + output when it wants, for example when the output buffer is full (avail_out + == 0), or after each call of deflate(). If deflate returns Z_OK and with + zero avail_out, it must be called again after making room in the output + buffer because there might be more output pending. + + Normally the parameter flush is set to Z_NO_FLUSH, which allows deflate to + decide how much data to accumulate before producing output, in order to + maximize compression. + + If the parameter flush is set to Z_SYNC_FLUSH, all pending output is + flushed to the output buffer and the output is aligned on a byte boundary, so + that the decompressor can get all input data available so far. (In + particular avail_in is zero after the call if enough output space has been + provided before the call.) Flushing may degrade compression for some + compression algorithms and so it should be used only when necessary. This + completes the current deflate block and follows it with an empty stored block + that is three bits plus filler bits to the next byte, followed by four bytes + (00 00 ff ff). + + If flush is set to Z_PARTIAL_FLUSH, all pending output is flushed to the + output buffer, but the output is not aligned to a byte boundary. All of the + input data so far will be available to the decompressor, as for Z_SYNC_FLUSH. + This completes the current deflate block and follows it with an empty fixed + codes block that is 10 bits long. This assures that enough bytes are output + in order for the decompressor to finish the block before the empty fixed code + block. + + If flush is set to Z_BLOCK, a deflate block is completed and emitted, as + for Z_SYNC_FLUSH, but the output is not aligned on a byte boundary, and up to + seven bits of the current block are held to be written as the next byte after + the next deflate block is completed. In this case, the decompressor may not + be provided enough bits at this point in order to complete decompression of + the data provided so far to the compressor. It may need to wait for the next + block to be emitted. This is for advanced applications that need to control + the emission of deflate blocks. + + If flush is set to Z_FULL_FLUSH, all output is flushed as with + Z_SYNC_FLUSH, and the compression state is reset so that decompression can + restart from this point if previous compressed data has been damaged or if + random access is desired. Using Z_FULL_FLUSH too often can seriously degrade + compression. + + If deflate returns with avail_out == 0, this function must be called again + with the same value of the flush parameter and more output space (updated + avail_out), until the flush is complete (deflate returns with non-zero + avail_out). In the case of a Z_FULL_FLUSH or Z_SYNC_FLUSH, make sure that + avail_out is greater than six to avoid repeated flush markers due to + avail_out == 0 on return. + + If the parameter flush is set to Z_FINISH, pending input is processed, + pending output is flushed and deflate returns with Z_STREAM_END if there was + enough output space; if deflate returns with Z_OK, this function must be + called again with Z_FINISH and more output space (updated avail_out) but no + more input data, until it returns with Z_STREAM_END or an error. After + deflate has returned Z_STREAM_END, the only possible operations on the stream + are deflateReset or deflateEnd. + + Z_FINISH can be used immediately after deflateInit if all the compression + is to be done in a single step. In this case, avail_out must be at least the + value returned by deflateBound (see below). Then deflate is guaranteed to + return Z_STREAM_END. If not enough output space is provided, deflate will + not return Z_STREAM_END, and it must be called again as described above. + + deflate() sets strm->adler to the adler32 checksum of all input read + so far (that is, total_in bytes). + + deflate() may update strm->data_type if it can make a good guess about + the input data type (Z_BINARY or Z_TEXT). In doubt, the data is considered + binary. This field is only for information purposes and does not affect the + compression algorithm in any manner. + + deflate() returns Z_OK if some progress has been made (more input + processed or more output produced), Z_STREAM_END if all input has been + consumed and all output has been produced (only when flush is set to + Z_FINISH), Z_STREAM_ERROR if the stream state was inconsistent (for example + if next_in or next_out was Z_NULL), Z_BUF_ERROR if no progress is possible + (for example avail_in or avail_out was zero). Note that Z_BUF_ERROR is not + fatal, and deflate() can be called again with more input and more output + space to continue compressing. +*/ + + +ZEXTERN int ZEXPORT deflateEnd OF((z_streamp strm)); +/* + All dynamically allocated data structures for this stream are freed. + This function discards any unprocessed input and does not flush any pending + output. + + deflateEnd returns Z_OK if success, Z_STREAM_ERROR if the + stream state was inconsistent, Z_DATA_ERROR if the stream was freed + prematurely (some input or output was discarded). In the error case, msg + may be set but then points to a static string (which must not be + deallocated). +*/ + + +/* +ZEXTERN int ZEXPORT inflateInit OF((z_streamp strm)); + + Initializes the internal stream state for decompression. The fields + next_in, avail_in, zalloc, zfree and opaque must be initialized before by + the caller. If next_in is not Z_NULL and avail_in is large enough (the + exact value depends on the compression method), inflateInit determines the + compression method from the zlib header and allocates all data structures + accordingly; otherwise the allocation will be deferred to the first call of + inflate. If zalloc and zfree are set to Z_NULL, inflateInit updates them to + use default allocation functions. + + inflateInit returns Z_OK if success, Z_MEM_ERROR if there was not enough + memory, Z_VERSION_ERROR if the zlib library version is incompatible with the + version assumed by the caller, or Z_STREAM_ERROR if the parameters are + invalid, such as a null pointer to the structure. msg is set to null if + there is no error message. inflateInit does not perform any decompression + apart from possibly reading the zlib header if present: actual decompression + will be done by inflate(). (So next_in and avail_in may be modified, but + next_out and avail_out are unused and unchanged.) The current implementation + of inflateInit() does not process any header information -- that is deferred + until inflate() is called. +*/ + + +ZEXTERN int ZEXPORT inflate OF((z_streamp strm, int flush)); +/* + inflate decompresses as much data as possible, and stops when the input + buffer becomes empty or the output buffer becomes full. It may introduce + some output latency (reading input without producing any output) except when + forced to flush. + + The detailed semantics are as follows. inflate performs one or both of the + following actions: + + - Decompress more input starting at next_in and update next_in and avail_in + accordingly. If not all input can be processed (because there is not + enough room in the output buffer), next_in is updated and processing will + resume at this point for the next call of inflate(). + + - Provide more output starting at next_out and update next_out and avail_out + accordingly. inflate() provides as much output as possible, until there is + no more input data or no more space in the output buffer (see below about + the flush parameter). + + Before the call of inflate(), the application should ensure that at least + one of the actions is possible, by providing more input and/or consuming more + output, and updating the next_* and avail_* values accordingly. The + application can consume the uncompressed output when it wants, for example + when the output buffer is full (avail_out == 0), or after each call of + inflate(). If inflate returns Z_OK and with zero avail_out, it must be + called again after making room in the output buffer because there might be + more output pending. + + The flush parameter of inflate() can be Z_NO_FLUSH, Z_SYNC_FLUSH, Z_FINISH, + Z_BLOCK, or Z_TREES. Z_SYNC_FLUSH requests that inflate() flush as much + output as possible to the output buffer. Z_BLOCK requests that inflate() + stop if and when it gets to the next deflate block boundary. When decoding + the zlib or gzip format, this will cause inflate() to return immediately + after the header and before the first block. When doing a raw inflate, + inflate() will go ahead and process the first block, and will return when it + gets to the end of that block, or when it runs out of data. + + The Z_BLOCK option assists in appending to or combining deflate streams. + Also to assist in this, on return inflate() will set strm->data_type to the + number of unused bits in the last byte taken from strm->next_in, plus 64 if + inflate() is currently decoding the last block in the deflate stream, plus + 128 if inflate() returned immediately after decoding an end-of-block code or + decoding the complete header up to just before the first byte of the deflate + stream. The end-of-block will not be indicated until all of the uncompressed + data from that block has been written to strm->next_out. The number of + unused bits may in general be greater than seven, except when bit 7 of + data_type is set, in which case the number of unused bits will be less than + eight. data_type is set as noted here every time inflate() returns for all + flush options, and so can be used to determine the amount of currently + consumed input in bits. + + The Z_TREES option behaves as Z_BLOCK does, but it also returns when the + end of each deflate block header is reached, before any actual data in that + block is decoded. This allows the caller to determine the length of the + deflate block header for later use in random access within a deflate block. + 256 is added to the value of strm->data_type when inflate() returns + immediately after reaching the end of the deflate block header. + + inflate() should normally be called until it returns Z_STREAM_END or an + error. However if all decompression is to be performed in a single step (a + single call of inflate), the parameter flush should be set to Z_FINISH. In + this case all pending input is processed and all pending output is flushed; + avail_out must be large enough to hold all of the uncompressed data for the + operation to complete. (The size of the uncompressed data may have been + saved by the compressor for this purpose.) The use of Z_FINISH is not + required to perform an inflation in one step. However it may be used to + inform inflate that a faster approach can be used for the single inflate() + call. Z_FINISH also informs inflate to not maintain a sliding window if the + stream completes, which reduces inflate's memory footprint. If the stream + does not complete, either because not all of the stream is provided or not + enough output space is provided, then a sliding window will be allocated and + inflate() can be called again to continue the operation as if Z_NO_FLUSH had + been used. + + In this implementation, inflate() always flushes as much output as + possible to the output buffer, and always uses the faster approach on the + first call. So the effects of the flush parameter in this implementation are + on the return value of inflate() as noted below, when inflate() returns early + when Z_BLOCK or Z_TREES is used, and when inflate() avoids the allocation of + memory for a sliding window when Z_FINISH is used. + + If a preset dictionary is needed after this call (see inflateSetDictionary + below), inflate sets strm->adler to the Adler-32 checksum of the dictionary + chosen by the compressor and returns Z_NEED_DICT; otherwise it sets + strm->adler to the Adler-32 checksum of all output produced so far (that is, + total_out bytes) and returns Z_OK, Z_STREAM_END or an error code as described + below. At the end of the stream, inflate() checks that its computed adler32 + checksum is equal to that saved by the compressor and returns Z_STREAM_END + only if the checksum is correct. + + inflate() can decompress and check either zlib-wrapped or gzip-wrapped + deflate data. The header type is detected automatically, if requested when + initializing with inflateInit2(). Any information contained in the gzip + header is not retained, so applications that need that information should + instead use raw inflate, see inflateInit2() below, or inflateBack() and + perform their own processing of the gzip header and trailer. When processing + gzip-wrapped deflate data, strm->adler32 is set to the CRC-32 of the output + producted so far. The CRC-32 is checked against the gzip trailer. + + inflate() returns Z_OK if some progress has been made (more input processed + or more output produced), Z_STREAM_END if the end of the compressed data has + been reached and all uncompressed output has been produced, Z_NEED_DICT if a + preset dictionary is needed at this point, Z_DATA_ERROR if the input data was + corrupted (input stream not conforming to the zlib format or incorrect check + value), Z_STREAM_ERROR if the stream structure was inconsistent (for example + next_in or next_out was Z_NULL), Z_MEM_ERROR if there was not enough memory, + Z_BUF_ERROR if no progress is possible or if there was not enough room in the + output buffer when Z_FINISH is used. Note that Z_BUF_ERROR is not fatal, and + inflate() can be called again with more input and more output space to + continue decompressing. If Z_DATA_ERROR is returned, the application may + then call inflateSync() to look for a good compression block if a partial + recovery of the data is desired. +*/ + + +ZEXTERN int ZEXPORT inflateEnd OF((z_streamp strm)); +/* + All dynamically allocated data structures for this stream are freed. + This function discards any unprocessed input and does not flush any pending + output. + + inflateEnd returns Z_OK if success, Z_STREAM_ERROR if the stream state + was inconsistent. In the error case, msg may be set but then points to a + static string (which must not be deallocated). +*/ + + + /* Advanced functions */ + +/* + The following functions are needed only in some special applications. +*/ + +/* +ZEXTERN int ZEXPORT deflateInit2 OF((z_streamp strm, + int level, + int method, + int windowBits, + int memLevel, + int strategy)); + + This is another version of deflateInit with more compression options. The + fields next_in, zalloc, zfree and opaque must be initialized before by the + caller. + + The method parameter is the compression method. It must be Z_DEFLATED in + this version of the library. + + The windowBits parameter is the base two logarithm of the window size + (the size of the history buffer). It should be in the range 8..15 for this + version of the library. Larger values of this parameter result in better + compression at the expense of memory usage. The default value is 15 if + deflateInit is used instead. + + windowBits can also be -8..-15 for raw deflate. In this case, -windowBits + determines the window size. deflate() will then generate raw deflate data + with no zlib header or trailer, and will not compute an adler32 check value. + + windowBits can also be greater than 15 for optional gzip encoding. Add + 16 to windowBits to write a simple gzip header and trailer around the + compressed data instead of a zlib wrapper. The gzip header will have no + file name, no extra data, no comment, no modification time (set to zero), no + header crc, and the operating system will be set to 255 (unknown). If a + gzip stream is being written, strm->adler is a crc32 instead of an adler32. + + The memLevel parameter specifies how much memory should be allocated + for the internal compression state. memLevel=1 uses minimum memory but is + slow and reduces compression ratio; memLevel=9 uses maximum memory for + optimal speed. The default value is 8. See zconf.h for total memory usage + as a function of windowBits and memLevel. + + The strategy parameter is used to tune the compression algorithm. Use the + value Z_DEFAULT_STRATEGY for normal data, Z_FILTERED for data produced by a + filter (or predictor), Z_HUFFMAN_ONLY to force Huffman encoding only (no + string match), or Z_RLE to limit match distances to one (run-length + encoding). Filtered data consists mostly of small values with a somewhat + random distribution. In this case, the compression algorithm is tuned to + compress them better. The effect of Z_FILTERED is to force more Huffman + coding and less string matching; it is somewhat intermediate between + Z_DEFAULT_STRATEGY and Z_HUFFMAN_ONLY. Z_RLE is designed to be almost as + fast as Z_HUFFMAN_ONLY, but give better compression for PNG image data. The + strategy parameter only affects the compression ratio but not the + correctness of the compressed output even if it is not set appropriately. + Z_FIXED prevents the use of dynamic Huffman codes, allowing for a simpler + decoder for special applications. + + deflateInit2 returns Z_OK if success, Z_MEM_ERROR if there was not enough + memory, Z_STREAM_ERROR if any parameter is invalid (such as an invalid + method), or Z_VERSION_ERROR if the zlib library version (zlib_version) is + incompatible with the version assumed by the caller (ZLIB_VERSION). msg is + set to null if there is no error message. deflateInit2 does not perform any + compression: this will be done by deflate(). +*/ + +ZEXTERN int ZEXPORT deflateSetDictionary OF((z_streamp strm, + const Bytef *dictionary, + uInt dictLength)); +/* + Initializes the compression dictionary from the given byte sequence + without producing any compressed output. When using the zlib format, this + function must be called immediately after deflateInit, deflateInit2 or + deflateReset, and before any call of deflate. When doing raw deflate, this + function must be called either before any call of deflate, or immediately + after the completion of a deflate block, i.e. after all input has been + consumed and all output has been delivered when using any of the flush + options Z_BLOCK, Z_PARTIAL_FLUSH, Z_SYNC_FLUSH, or Z_FULL_FLUSH. The + compressor and decompressor must use exactly the same dictionary (see + inflateSetDictionary). + + The dictionary should consist of strings (byte sequences) that are likely + to be encountered later in the data to be compressed, with the most commonly + used strings preferably put towards the end of the dictionary. Using a + dictionary is most useful when the data to be compressed is short and can be + predicted with good accuracy; the data can then be compressed better than + with the default empty dictionary. + + Depending on the size of the compression data structures selected by + deflateInit or deflateInit2, a part of the dictionary may in effect be + discarded, for example if the dictionary is larger than the window size + provided in deflateInit or deflateInit2. Thus the strings most likely to be + useful should be put at the end of the dictionary, not at the front. In + addition, the current implementation of deflate will use at most the window + size minus 262 bytes of the provided dictionary. + + Upon return of this function, strm->adler is set to the adler32 value + of the dictionary; the decompressor may later use this value to determine + which dictionary has been used by the compressor. (The adler32 value + applies to the whole dictionary even if only a subset of the dictionary is + actually used by the compressor.) If a raw deflate was requested, then the + adler32 value is not computed and strm->adler is not set. + + deflateSetDictionary returns Z_OK if success, or Z_STREAM_ERROR if a + parameter is invalid (e.g. dictionary being Z_NULL) or the stream state is + inconsistent (for example if deflate has already been called for this stream + or if not at a block boundary for raw deflate). deflateSetDictionary does + not perform any compression: this will be done by deflate(). +*/ + +ZEXTERN int ZEXPORT deflateCopy OF((z_streamp dest, + z_streamp source)); +/* + Sets the destination stream as a complete copy of the source stream. + + This function can be useful when several compression strategies will be + tried, for example when there are several ways of pre-processing the input + data with a filter. The streams that will be discarded should then be freed + by calling deflateEnd. Note that deflateCopy duplicates the internal + compression state which can be quite large, so this strategy is slow and can + consume lots of memory. + + deflateCopy returns Z_OK if success, Z_MEM_ERROR if there was not + enough memory, Z_STREAM_ERROR if the source stream state was inconsistent + (such as zalloc being Z_NULL). msg is left unchanged in both source and + destination. +*/ + +ZEXTERN int ZEXPORT deflateReset OF((z_streamp strm)); +/* + This function is equivalent to deflateEnd followed by deflateInit, + but does not free and reallocate all the internal compression state. The + stream will keep the same compression level and any other attributes that + may have been set by deflateInit2. + + deflateReset returns Z_OK if success, or Z_STREAM_ERROR if the source + stream state was inconsistent (such as zalloc or state being Z_NULL). +*/ + +ZEXTERN int ZEXPORT deflateParams OF((z_streamp strm, + int level, + int strategy)); +/* + Dynamically update the compression level and compression strategy. The + interpretation of level and strategy is as in deflateInit2. This can be + used to switch between compression and straight copy of the input data, or + to switch to a different kind of input data requiring a different strategy. + If the compression level is changed, the input available so far is + compressed with the old level (and may be flushed); the new level will take + effect only at the next call of deflate(). + + Before the call of deflateParams, the stream state must be set as for + a call of deflate(), since the currently available input may have to be + compressed and flushed. In particular, strm->avail_out must be non-zero. + + deflateParams returns Z_OK if success, Z_STREAM_ERROR if the source + stream state was inconsistent or if a parameter was invalid, Z_BUF_ERROR if + strm->avail_out was zero. +*/ + +ZEXTERN int ZEXPORT deflateTune OF((z_streamp strm, + int good_length, + int max_lazy, + int nice_length, + int max_chain)); +/* + Fine tune deflate's internal compression parameters. This should only be + used by someone who understands the algorithm used by zlib's deflate for + searching for the best matching string, and even then only by the most + fanatic optimizer trying to squeeze out the last compressed bit for their + specific input data. Read the deflate.c source code for the meaning of the + max_lazy, good_length, nice_length, and max_chain parameters. + + deflateTune() can be called after deflateInit() or deflateInit2(), and + returns Z_OK on success, or Z_STREAM_ERROR for an invalid deflate stream. + */ + +ZEXTERN uLong ZEXPORT deflateBound OF((z_streamp strm, + uLong sourceLen)); +/* + deflateBound() returns an upper bound on the compressed size after + deflation of sourceLen bytes. It must be called after deflateInit() or + deflateInit2(), and after deflateSetHeader(), if used. This would be used + to allocate an output buffer for deflation in a single pass, and so would be + called before deflate(). If that first deflate() call is provided the + sourceLen input bytes, an output buffer allocated to the size returned by + deflateBound(), and the flush value Z_FINISH, then deflate() is guaranteed + to return Z_STREAM_END. Note that it is possible for the compressed size to + be larger than the value returned by deflateBound() if flush options other + than Z_FINISH or Z_NO_FLUSH are used. +*/ + +ZEXTERN int ZEXPORT deflatePending OF((z_streamp strm, + unsigned *pending, + int *bits)); +/* + deflatePending() returns the number of bytes and bits of output that have + been generated, but not yet provided in the available output. The bytes not + provided would be due to the available output space having being consumed. + The number of bits of output not provided are between 0 and 7, where they + await more bits to join them in order to fill out a full byte. If pending + or bits are Z_NULL, then those values are not set. + + deflatePending returns Z_OK if success, or Z_STREAM_ERROR if the source + stream state was inconsistent. + */ + +ZEXTERN int ZEXPORT deflatePrime OF((z_streamp strm, + int bits, + int value)); +/* + deflatePrime() inserts bits in the deflate output stream. The intent + is that this function is used to start off the deflate output with the bits + leftover from a previous deflate stream when appending to it. As such, this + function can only be used for raw deflate, and must be used before the first + deflate() call after a deflateInit2() or deflateReset(). bits must be less + than or equal to 16, and that many of the least significant bits of value + will be inserted in the output. + + deflatePrime returns Z_OK if success, Z_BUF_ERROR if there was not enough + room in the internal buffer to insert the bits, or Z_STREAM_ERROR if the + source stream state was inconsistent. +*/ + +ZEXTERN int ZEXPORT deflateSetHeader OF((z_streamp strm, + gz_headerp head)); +/* + deflateSetHeader() provides gzip header information for when a gzip + stream is requested by deflateInit2(). deflateSetHeader() may be called + after deflateInit2() or deflateReset() and before the first call of + deflate(). The text, time, os, extra field, name, and comment information + in the provided gz_header structure are written to the gzip header (xflag is + ignored -- the extra flags are set according to the compression level). The + caller must assure that, if not Z_NULL, name and comment are terminated with + a zero byte, and that if extra is not Z_NULL, that extra_len bytes are + available there. If hcrc is true, a gzip header crc is included. Note that + the current versions of the command-line version of gzip (up through version + 1.3.x) do not support header crc's, and will report that it is a "multi-part + gzip file" and give up. + + If deflateSetHeader is not used, the default gzip header has text false, + the time set to zero, and os set to 255, with no extra, name, or comment + fields. The gzip header is returned to the default state by deflateReset(). + + deflateSetHeader returns Z_OK if success, or Z_STREAM_ERROR if the source + stream state was inconsistent. +*/ + +/* +ZEXTERN int ZEXPORT inflateInit2 OF((z_streamp strm, + int windowBits)); + + This is another version of inflateInit with an extra parameter. The + fields next_in, avail_in, zalloc, zfree and opaque must be initialized + before by the caller. + + The windowBits parameter is the base two logarithm of the maximum window + size (the size of the history buffer). It should be in the range 8..15 for + this version of the library. The default value is 15 if inflateInit is used + instead. windowBits must be greater than or equal to the windowBits value + provided to deflateInit2() while compressing, or it must be equal to 15 if + deflateInit2() was not used. If a compressed stream with a larger window + size is given as input, inflate() will return with the error code + Z_DATA_ERROR instead of trying to allocate a larger window. + + windowBits can also be zero to request that inflate use the window size in + the zlib header of the compressed stream. + + windowBits can also be -8..-15 for raw inflate. In this case, -windowBits + determines the window size. inflate() will then process raw deflate data, + not looking for a zlib or gzip header, not generating a check value, and not + looking for any check values for comparison at the end of the stream. This + is for use with other formats that use the deflate compressed data format + such as zip. Those formats provide their own check values. If a custom + format is developed using the raw deflate format for compressed data, it is + recommended that a check value such as an adler32 or a crc32 be applied to + the uncompressed data as is done in the zlib, gzip, and zip formats. For + most applications, the zlib format should be used as is. Note that comments + above on the use in deflateInit2() applies to the magnitude of windowBits. + + windowBits can also be greater than 15 for optional gzip decoding. Add + 32 to windowBits to enable zlib and gzip decoding with automatic header + detection, or add 16 to decode only the gzip format (the zlib format will + return a Z_DATA_ERROR). If a gzip stream is being decoded, strm->adler is a + crc32 instead of an adler32. + + inflateInit2 returns Z_OK if success, Z_MEM_ERROR if there was not enough + memory, Z_VERSION_ERROR if the zlib library version is incompatible with the + version assumed by the caller, or Z_STREAM_ERROR if the parameters are + invalid, such as a null pointer to the structure. msg is set to null if + there is no error message. inflateInit2 does not perform any decompression + apart from possibly reading the zlib header if present: actual decompression + will be done by inflate(). (So next_in and avail_in may be modified, but + next_out and avail_out are unused and unchanged.) The current implementation + of inflateInit2() does not process any header information -- that is + deferred until inflate() is called. +*/ + +ZEXTERN int ZEXPORT inflateSetDictionary OF((z_streamp strm, + const Bytef *dictionary, + uInt dictLength)); +/* + Initializes the decompression dictionary from the given uncompressed byte + sequence. This function must be called immediately after a call of inflate, + if that call returned Z_NEED_DICT. The dictionary chosen by the compressor + can be determined from the adler32 value returned by that call of inflate. + The compressor and decompressor must use exactly the same dictionary (see + deflateSetDictionary). For raw inflate, this function can be called at any + time to set the dictionary. If the provided dictionary is smaller than the + window and there is already data in the window, then the provided dictionary + will amend what's there. The application must insure that the dictionary + that was used for compression is provided. + + inflateSetDictionary returns Z_OK if success, Z_STREAM_ERROR if a + parameter is invalid (e.g. dictionary being Z_NULL) or the stream state is + inconsistent, Z_DATA_ERROR if the given dictionary doesn't match the + expected one (incorrect adler32 value). inflateSetDictionary does not + perform any decompression: this will be done by subsequent calls of + inflate(). +*/ + +ZEXTERN int ZEXPORT inflateGetDictionary OF((z_streamp strm, + Bytef *dictionary, + uInt *dictLength)); +/* + Returns the sliding dictionary being maintained by inflate. dictLength is + set to the number of bytes in the dictionary, and that many bytes are copied + to dictionary. dictionary must have enough space, where 32768 bytes is + always enough. If inflateGetDictionary() is called with dictionary equal to + Z_NULL, then only the dictionary length is returned, and nothing is copied. + Similary, if dictLength is Z_NULL, then it is not set. + + inflateGetDictionary returns Z_OK on success, or Z_STREAM_ERROR if the + stream state is inconsistent. +*/ + +ZEXTERN int ZEXPORT inflateSync OF((z_streamp strm)); +/* + Skips invalid compressed data until a possible full flush point (see above + for the description of deflate with Z_FULL_FLUSH) can be found, or until all + available input is skipped. No output is provided. + + inflateSync searches for a 00 00 FF FF pattern in the compressed data. + All full flush points have this pattern, but not all occurrences of this + pattern are full flush points. + + inflateSync returns Z_OK if a possible full flush point has been found, + Z_BUF_ERROR if no more input was provided, Z_DATA_ERROR if no flush point + has been found, or Z_STREAM_ERROR if the stream structure was inconsistent. + In the success case, the application may save the current current value of + total_in which indicates where valid compressed data was found. In the + error case, the application may repeatedly call inflateSync, providing more + input each time, until success or end of the input data. +*/ + +ZEXTERN int ZEXPORT inflateCopy OF((z_streamp dest, + z_streamp source)); +/* + Sets the destination stream as a complete copy of the source stream. + + This function can be useful when randomly accessing a large stream. The + first pass through the stream can periodically record the inflate state, + allowing restarting inflate at those points when randomly accessing the + stream. + + inflateCopy returns Z_OK if success, Z_MEM_ERROR if there was not + enough memory, Z_STREAM_ERROR if the source stream state was inconsistent + (such as zalloc being Z_NULL). msg is left unchanged in both source and + destination. +*/ + +ZEXTERN int ZEXPORT inflateReset OF((z_streamp strm)); +/* + This function is equivalent to inflateEnd followed by inflateInit, + but does not free and reallocate all the internal decompression state. The + stream will keep attributes that may have been set by inflateInit2. + + inflateReset returns Z_OK if success, or Z_STREAM_ERROR if the source + stream state was inconsistent (such as zalloc or state being Z_NULL). +*/ + +ZEXTERN int ZEXPORT inflateReset2 OF((z_streamp strm, + int windowBits)); +/* + This function is the same as inflateReset, but it also permits changing + the wrap and window size requests. The windowBits parameter is interpreted + the same as it is for inflateInit2. + + inflateReset2 returns Z_OK if success, or Z_STREAM_ERROR if the source + stream state was inconsistent (such as zalloc or state being Z_NULL), or if + the windowBits parameter is invalid. +*/ + +ZEXTERN int ZEXPORT inflatePrime OF((z_streamp strm, + int bits, + int value)); +/* + This function inserts bits in the inflate input stream. The intent is + that this function is used to start inflating at a bit position in the + middle of a byte. The provided bits will be used before any bytes are used + from next_in. This function should only be used with raw inflate, and + should be used before the first inflate() call after inflateInit2() or + inflateReset(). bits must be less than or equal to 16, and that many of the + least significant bits of value will be inserted in the input. + + If bits is negative, then the input stream bit buffer is emptied. Then + inflatePrime() can be called again to put bits in the buffer. This is used + to clear out bits leftover after feeding inflate a block description prior + to feeding inflate codes. + + inflatePrime returns Z_OK if success, or Z_STREAM_ERROR if the source + stream state was inconsistent. +*/ + +ZEXTERN long ZEXPORT inflateMark OF((z_streamp strm)); +/* + This function returns two values, one in the lower 16 bits of the return + value, and the other in the remaining upper bits, obtained by shifting the + return value down 16 bits. If the upper value is -1 and the lower value is + zero, then inflate() is currently decoding information outside of a block. + If the upper value is -1 and the lower value is non-zero, then inflate is in + the middle of a stored block, with the lower value equaling the number of + bytes from the input remaining to copy. If the upper value is not -1, then + it is the number of bits back from the current bit position in the input of + the code (literal or length/distance pair) currently being processed. In + that case the lower value is the number of bytes already emitted for that + code. + + A code is being processed if inflate is waiting for more input to complete + decoding of the code, or if it has completed decoding but is waiting for + more output space to write the literal or match data. + + inflateMark() is used to mark locations in the input data for random + access, which may be at bit positions, and to note those cases where the + output of a code may span boundaries of random access blocks. The current + location in the input stream can be determined from avail_in and data_type + as noted in the description for the Z_BLOCK flush parameter for inflate. + + inflateMark returns the value noted above or -1 << 16 if the provided + source stream state was inconsistent. +*/ + +ZEXTERN int ZEXPORT inflateGetHeader OF((z_streamp strm, + gz_headerp head)); +/* + inflateGetHeader() requests that gzip header information be stored in the + provided gz_header structure. inflateGetHeader() may be called after + inflateInit2() or inflateReset(), and before the first call of inflate(). + As inflate() processes the gzip stream, head->done is zero until the header + is completed, at which time head->done is set to one. If a zlib stream is + being decoded, then head->done is set to -1 to indicate that there will be + no gzip header information forthcoming. Note that Z_BLOCK or Z_TREES can be + used to force inflate() to return immediately after header processing is + complete and before any actual data is decompressed. + + The text, time, xflags, and os fields are filled in with the gzip header + contents. hcrc is set to true if there is a header CRC. (The header CRC + was valid if done is set to one.) If extra is not Z_NULL, then extra_max + contains the maximum number of bytes to write to extra. Once done is true, + extra_len contains the actual extra field length, and extra contains the + extra field, or that field truncated if extra_max is less than extra_len. + If name is not Z_NULL, then up to name_max characters are written there, + terminated with a zero unless the length is greater than name_max. If + comment is not Z_NULL, then up to comm_max characters are written there, + terminated with a zero unless the length is greater than comm_max. When any + of extra, name, or comment are not Z_NULL and the respective field is not + present in the header, then that field is set to Z_NULL to signal its + absence. This allows the use of deflateSetHeader() with the returned + structure to duplicate the header. However if those fields are set to + allocated memory, then the application will need to save those pointers + elsewhere so that they can be eventually freed. + + If inflateGetHeader is not used, then the header information is simply + discarded. The header is always checked for validity, including the header + CRC if present. inflateReset() will reset the process to discard the header + information. The application would need to call inflateGetHeader() again to + retrieve the header from the next gzip stream. + + inflateGetHeader returns Z_OK if success, or Z_STREAM_ERROR if the source + stream state was inconsistent. +*/ + +/* +ZEXTERN int ZEXPORT inflateBackInit OF((z_streamp strm, int windowBits, + unsigned char FAR *window)); + + Initialize the internal stream state for decompression using inflateBack() + calls. The fields zalloc, zfree and opaque in strm must be initialized + before the call. If zalloc and zfree are Z_NULL, then the default library- + derived memory allocation routines are used. windowBits is the base two + logarithm of the window size, in the range 8..15. window is a caller + supplied buffer of that size. Except for special applications where it is + assured that deflate was used with small window sizes, windowBits must be 15 + and a 32K byte window must be supplied to be able to decompress general + deflate streams. + + See inflateBack() for the usage of these routines. + + inflateBackInit will return Z_OK on success, Z_STREAM_ERROR if any of + the parameters are invalid, Z_MEM_ERROR if the internal state could not be + allocated, or Z_VERSION_ERROR if the version of the library does not match + the version of the header file. +*/ + +typedef unsigned (*in_func) OF((void FAR *, + z_const unsigned char FAR * FAR *)); +typedef int (*out_func) OF((void FAR *, unsigned char FAR *, unsigned)); + +ZEXTERN int ZEXPORT inflateBack OF((z_streamp strm, + in_func in, void FAR *in_desc, + out_func out, void FAR *out_desc)); +/* + inflateBack() does a raw inflate with a single call using a call-back + interface for input and output. This is potentially more efficient than + inflate() for file i/o applications, in that it avoids copying between the + output and the sliding window by simply making the window itself the output + buffer. inflate() can be faster on modern CPUs when used with large + buffers. inflateBack() trusts the application to not change the output + buffer passed by the output function, at least until inflateBack() returns. + + inflateBackInit() must be called first to allocate the internal state + and to initialize the state with the user-provided window buffer. + inflateBack() may then be used multiple times to inflate a complete, raw + deflate stream with each call. inflateBackEnd() is then called to free the + allocated state. + + A raw deflate stream is one with no zlib or gzip header or trailer. + This routine would normally be used in a utility that reads zip or gzip + files and writes out uncompressed files. The utility would decode the + header and process the trailer on its own, hence this routine expects only + the raw deflate stream to decompress. This is different from the normal + behavior of inflate(), which expects either a zlib or gzip header and + trailer around the deflate stream. + + inflateBack() uses two subroutines supplied by the caller that are then + called by inflateBack() for input and output. inflateBack() calls those + routines until it reads a complete deflate stream and writes out all of the + uncompressed data, or until it encounters an error. The function's + parameters and return types are defined above in the in_func and out_func + typedefs. inflateBack() will call in(in_desc, &buf) which should return the + number of bytes of provided input, and a pointer to that input in buf. If + there is no input available, in() must return zero--buf is ignored in that + case--and inflateBack() will return a buffer error. inflateBack() will call + out(out_desc, buf, len) to write the uncompressed data buf[0..len-1]. out() + should return zero on success, or non-zero on failure. If out() returns + non-zero, inflateBack() will return with an error. Neither in() nor out() + are permitted to change the contents of the window provided to + inflateBackInit(), which is also the buffer that out() uses to write from. + The length written by out() will be at most the window size. Any non-zero + amount of input may be provided by in(). + + For convenience, inflateBack() can be provided input on the first call by + setting strm->next_in and strm->avail_in. If that input is exhausted, then + in() will be called. Therefore strm->next_in must be initialized before + calling inflateBack(). If strm->next_in is Z_NULL, then in() will be called + immediately for input. If strm->next_in is not Z_NULL, then strm->avail_in + must also be initialized, and then if strm->avail_in is not zero, input will + initially be taken from strm->next_in[0 .. strm->avail_in - 1]. + + The in_desc and out_desc parameters of inflateBack() is passed as the + first parameter of in() and out() respectively when they are called. These + descriptors can be optionally used to pass any information that the caller- + supplied in() and out() functions need to do their job. + + On return, inflateBack() will set strm->next_in and strm->avail_in to + pass back any unused input that was provided by the last in() call. The + return values of inflateBack() can be Z_STREAM_END on success, Z_BUF_ERROR + if in() or out() returned an error, Z_DATA_ERROR if there was a format error + in the deflate stream (in which case strm->msg is set to indicate the nature + of the error), or Z_STREAM_ERROR if the stream was not properly initialized. + In the case of Z_BUF_ERROR, an input or output error can be distinguished + using strm->next_in which will be Z_NULL only if in() returned an error. If + strm->next_in is not Z_NULL, then the Z_BUF_ERROR was due to out() returning + non-zero. (in() will always be called before out(), so strm->next_in is + assured to be defined if out() returns non-zero.) Note that inflateBack() + cannot return Z_OK. +*/ + +ZEXTERN int ZEXPORT inflateBackEnd OF((z_streamp strm)); +/* + All memory allocated by inflateBackInit() is freed. + + inflateBackEnd() returns Z_OK on success, or Z_STREAM_ERROR if the stream + state was inconsistent. +*/ + +ZEXTERN uLong ZEXPORT zlibCompileFlags OF((void)); +/* Return flags indicating compile-time options. + + Type sizes, two bits each, 00 = 16 bits, 01 = 32, 10 = 64, 11 = other: + 1.0: size of uInt + 3.2: size of uLong + 5.4: size of voidpf (pointer) + 7.6: size of z_off_t + + Compiler, assembler, and debug options: + 8: DEBUG + 9: ASMV or ASMINF -- use ASM code + 10: ZLIB_WINAPI -- exported functions use the WINAPI calling convention + 11: 0 (reserved) + + One-time table building (smaller code, but not thread-safe if true): + 12: BUILDFIXED -- build static block decoding tables when needed + 13: DYNAMIC_CRC_TABLE -- build CRC calculation tables when needed + 14,15: 0 (reserved) + + Library content (indicates missing functionality): + 16: NO_GZCOMPRESS -- gz* functions cannot compress (to avoid linking + deflate code when not needed) + 17: NO_GZIP -- deflate can't write gzip streams, and inflate can't detect + and decode gzip streams (to avoid linking crc code) + 18-19: 0 (reserved) + + Operation variations (changes in library functionality): + 20: PKZIP_BUG_WORKAROUND -- slightly more permissive inflate + 21: FASTEST -- deflate algorithm with only one, lowest compression level + 22,23: 0 (reserved) + + The sprintf variant used by gzprintf (zero is best): + 24: 0 = vs*, 1 = s* -- 1 means limited to 20 arguments after the format + 25: 0 = *nprintf, 1 = *printf -- 1 means gzprintf() not secure! + 26: 0 = returns value, 1 = void -- 1 means inferred string length returned + + Remainder: + 27-31: 0 (reserved) + */ + +#ifndef Z_SOLO + + /* utility functions */ + +/* + The following utility functions are implemented on top of the basic + stream-oriented functions. To simplify the interface, some default options + are assumed (compression level and memory usage, standard memory allocation + functions). The source code of these utility functions can be modified if + you need special options. +*/ + +ZEXTERN int ZEXPORT compress OF((Bytef *dest, uLongf *destLen, + const Bytef *source, uLong sourceLen)); +/* + Compresses the source buffer into the destination buffer. sourceLen is + the byte length of the source buffer. Upon entry, destLen is the total size + of the destination buffer, which must be at least the value returned by + compressBound(sourceLen). Upon exit, destLen is the actual size of the + compressed buffer. + + compress returns Z_OK if success, Z_MEM_ERROR if there was not + enough memory, Z_BUF_ERROR if there was not enough room in the output + buffer. +*/ + +ZEXTERN int ZEXPORT compress2 OF((Bytef *dest, uLongf *destLen, + const Bytef *source, uLong sourceLen, + int level)); +/* + Compresses the source buffer into the destination buffer. The level + parameter has the same meaning as in deflateInit. sourceLen is the byte + length of the source buffer. Upon entry, destLen is the total size of the + destination buffer, which must be at least the value returned by + compressBound(sourceLen). Upon exit, destLen is the actual size of the + compressed buffer. + + compress2 returns Z_OK if success, Z_MEM_ERROR if there was not enough + memory, Z_BUF_ERROR if there was not enough room in the output buffer, + Z_STREAM_ERROR if the level parameter is invalid. +*/ + +ZEXTERN uLong ZEXPORT compressBound OF((uLong sourceLen)); +/* + compressBound() returns an upper bound on the compressed size after + compress() or compress2() on sourceLen bytes. It would be used before a + compress() or compress2() call to allocate the destination buffer. +*/ + +ZEXTERN int ZEXPORT uncompress OF((Bytef *dest, uLongf *destLen, + const Bytef *source, uLong sourceLen)); +/* + Decompresses the source buffer into the destination buffer. sourceLen is + the byte length of the source buffer. Upon entry, destLen is the total size + of the destination buffer, which must be large enough to hold the entire + uncompressed data. (The size of the uncompressed data must have been saved + previously by the compressor and transmitted to the decompressor by some + mechanism outside the scope of this compression library.) Upon exit, destLen + is the actual size of the uncompressed buffer. + + uncompress returns Z_OK if success, Z_MEM_ERROR if there was not + enough memory, Z_BUF_ERROR if there was not enough room in the output + buffer, or Z_DATA_ERROR if the input data was corrupted or incomplete. In + the case where there is not enough room, uncompress() will fill the output + buffer with the uncompressed data up to that point. +*/ + + /* gzip file access functions */ + +/* + This library supports reading and writing files in gzip (.gz) format with + an interface similar to that of stdio, using the functions that start with + "gz". The gzip format is different from the zlib format. gzip is a gzip + wrapper, documented in RFC 1952, wrapped around a deflate stream. +*/ + +typedef struct gzFile_s *gzFile; /* semi-opaque gzip file descriptor */ + +/* +ZEXTERN gzFile ZEXPORT gzopen OF((const char *path, const char *mode)); + + Opens a gzip (.gz) file for reading or writing. The mode parameter is as + in fopen ("rb" or "wb") but can also include a compression level ("wb9") or + a strategy: 'f' for filtered data as in "wb6f", 'h' for Huffman-only + compression as in "wb1h", 'R' for run-length encoding as in "wb1R", or 'F' + for fixed code compression as in "wb9F". (See the description of + deflateInit2 for more information about the strategy parameter.) 'T' will + request transparent writing or appending with no compression and not using + the gzip format. + + "a" can be used instead of "w" to request that the gzip stream that will + be written be appended to the file. "+" will result in an error, since + reading and writing to the same gzip file is not supported. The addition of + "x" when writing will create the file exclusively, which fails if the file + already exists. On systems that support it, the addition of "e" when + reading or writing will set the flag to close the file on an execve() call. + + These functions, as well as gzip, will read and decode a sequence of gzip + streams in a file. The append function of gzopen() can be used to create + such a file. (Also see gzflush() for another way to do this.) When + appending, gzopen does not test whether the file begins with a gzip stream, + nor does it look for the end of the gzip streams to begin appending. gzopen + will simply append a gzip stream to the existing file. + + gzopen can be used to read a file which is not in gzip format; in this + case gzread will directly read from the file without decompression. When + reading, this will be detected automatically by looking for the magic two- + byte gzip header. + + gzopen returns NULL if the file could not be opened, if there was + insufficient memory to allocate the gzFile state, or if an invalid mode was + specified (an 'r', 'w', or 'a' was not provided, or '+' was provided). + errno can be checked to determine if the reason gzopen failed was that the + file could not be opened. +*/ + +ZEXTERN gzFile ZEXPORT gzdopen OF((int fd, const char *mode)); +/* + gzdopen associates a gzFile with the file descriptor fd. File descriptors + are obtained from calls like open, dup, creat, pipe or fileno (if the file + has been previously opened with fopen). The mode parameter is as in gzopen. + + The next call of gzclose on the returned gzFile will also close the file + descriptor fd, just like fclose(fdopen(fd, mode)) closes the file descriptor + fd. If you want to keep fd open, use fd = dup(fd_keep); gz = gzdopen(fd, + mode);. The duplicated descriptor should be saved to avoid a leak, since + gzdopen does not close fd if it fails. If you are using fileno() to get the + file descriptor from a FILE *, then you will have to use dup() to avoid + double-close()ing the file descriptor. Both gzclose() and fclose() will + close the associated file descriptor, so they need to have different file + descriptors. + + gzdopen returns NULL if there was insufficient memory to allocate the + gzFile state, if an invalid mode was specified (an 'r', 'w', or 'a' was not + provided, or '+' was provided), or if fd is -1. The file descriptor is not + used until the next gz* read, write, seek, or close operation, so gzdopen + will not detect if fd is invalid (unless fd is -1). +*/ + +ZEXTERN int ZEXPORT gzbuffer OF((gzFile file, unsigned size)); +/* + Set the internal buffer size used by this library's functions. The + default buffer size is 8192 bytes. This function must be called after + gzopen() or gzdopen(), and before any other calls that read or write the + file. The buffer memory allocation is always deferred to the first read or + write. Two buffers are allocated, either both of the specified size when + writing, or one of the specified size and the other twice that size when + reading. A larger buffer size of, for example, 64K or 128K bytes will + noticeably increase the speed of decompression (reading). + + The new buffer size also affects the maximum length for gzprintf(). + + gzbuffer() returns 0 on success, or -1 on failure, such as being called + too late. +*/ + +ZEXTERN int ZEXPORT gzsetparams OF((gzFile file, int level, int strategy)); +/* + Dynamically update the compression level or strategy. See the description + of deflateInit2 for the meaning of these parameters. + + gzsetparams returns Z_OK if success, or Z_STREAM_ERROR if the file was not + opened for writing. +*/ + +ZEXTERN int ZEXPORT gzread OF((gzFile file, voidp buf, unsigned len)); +/* + Reads the given number of uncompressed bytes from the compressed file. If + the input file is not in gzip format, gzread copies the given number of + bytes into the buffer directly from the file. + + After reaching the end of a gzip stream in the input, gzread will continue + to read, looking for another gzip stream. Any number of gzip streams may be + concatenated in the input file, and will all be decompressed by gzread(). + If something other than a gzip stream is encountered after a gzip stream, + that remaining trailing garbage is ignored (and no error is returned). + + gzread can be used to read a gzip file that is being concurrently written. + Upon reaching the end of the input, gzread will return with the available + data. If the error code returned by gzerror is Z_OK or Z_BUF_ERROR, then + gzclearerr can be used to clear the end of file indicator in order to permit + gzread to be tried again. Z_OK indicates that a gzip stream was completed + on the last gzread. Z_BUF_ERROR indicates that the input file ended in the + middle of a gzip stream. Note that gzread does not return -1 in the event + of an incomplete gzip stream. This error is deferred until gzclose(), which + will return Z_BUF_ERROR if the last gzread ended in the middle of a gzip + stream. Alternatively, gzerror can be used before gzclose to detect this + case. + + gzread returns the number of uncompressed bytes actually read, less than + len for end of file, or -1 for error. +*/ + +ZEXTERN int ZEXPORT gzwrite OF((gzFile file, + voidpc buf, unsigned len)); +/* + Writes the given number of uncompressed bytes into the compressed file. + gzwrite returns the number of uncompressed bytes written or 0 in case of + error. +*/ + +ZEXTERN int ZEXPORTVA gzprintf Z_ARG((gzFile file, const char *format, ...)); +/* + Converts, formats, and writes the arguments to the compressed file under + control of the format string, as in fprintf. gzprintf returns the number of + uncompressed bytes actually written, or 0 in case of error. The number of + uncompressed bytes written is limited to 8191, or one less than the buffer + size given to gzbuffer(). The caller should assure that this limit is not + exceeded. If it is exceeded, then gzprintf() will return an error (0) with + nothing written. In this case, there may also be a buffer overflow with + unpredictable consequences, which is possible only if zlib was compiled with + the insecure functions sprintf() or vsprintf() because the secure snprintf() + or vsnprintf() functions were not available. This can be determined using + zlibCompileFlags(). +*/ + +ZEXTERN int ZEXPORT gzputs OF((gzFile file, const char *s)); +/* + Writes the given null-terminated string to the compressed file, excluding + the terminating null character. + + gzputs returns the number of characters written, or -1 in case of error. +*/ + +ZEXTERN char * ZEXPORT gzgets OF((gzFile file, char *buf, int len)); +/* + Reads bytes from the compressed file until len-1 characters are read, or a + newline character is read and transferred to buf, or an end-of-file + condition is encountered. If any characters are read or if len == 1, the + string is terminated with a null character. If no characters are read due + to an end-of-file or len < 1, then the buffer is left untouched. + + gzgets returns buf which is a null-terminated string, or it returns NULL + for end-of-file or in case of error. If there was an error, the contents at + buf are indeterminate. +*/ + +ZEXTERN int ZEXPORT gzputc OF((gzFile file, int c)); +/* + Writes c, converted to an unsigned char, into the compressed file. gzputc + returns the value that was written, or -1 in case of error. +*/ + +ZEXTERN int ZEXPORT gzgetc OF((gzFile file)); +/* + Reads one byte from the compressed file. gzgetc returns this byte or -1 + in case of end of file or error. This is implemented as a macro for speed. + As such, it does not do all of the checking the other functions do. I.e. + it does not check to see if file is NULL, nor whether the structure file + points to has been clobbered or not. +*/ + +ZEXTERN int ZEXPORT gzungetc OF((int c, gzFile file)); +/* + Push one character back onto the stream to be read as the first character + on the next read. At least one character of push-back is allowed. + gzungetc() returns the character pushed, or -1 on failure. gzungetc() will + fail if c is -1, and may fail if a character has been pushed but not read + yet. If gzungetc is used immediately after gzopen or gzdopen, at least the + output buffer size of pushed characters is allowed. (See gzbuffer above.) + The pushed character will be discarded if the stream is repositioned with + gzseek() or gzrewind(). +*/ + +ZEXTERN int ZEXPORT gzflush OF((gzFile file, int flush)); +/* + Flushes all pending output into the compressed file. The parameter flush + is as in the deflate() function. The return value is the zlib error number + (see function gzerror below). gzflush is only permitted when writing. + + If the flush parameter is Z_FINISH, the remaining data is written and the + gzip stream is completed in the output. If gzwrite() is called again, a new + gzip stream will be started in the output. gzread() is able to read such + concatented gzip streams. + + gzflush should be called only when strictly necessary because it will + degrade compression if called too often. +*/ + +/* +ZEXTERN z_off_t ZEXPORT gzseek OF((gzFile file, + z_off_t offset, int whence)); + + Sets the starting position for the next gzread or gzwrite on the given + compressed file. The offset represents a number of bytes in the + uncompressed data stream. The whence parameter is defined as in lseek(2); + the value SEEK_END is not supported. + + If the file is opened for reading, this function is emulated but can be + extremely slow. If the file is opened for writing, only forward seeks are + supported; gzseek then compresses a sequence of zeroes up to the new + starting position. + + gzseek returns the resulting offset location as measured in bytes from + the beginning of the uncompressed stream, or -1 in case of error, in + particular if the file is opened for writing and the new starting position + would be before the current position. +*/ + +ZEXTERN int ZEXPORT gzrewind OF((gzFile file)); +/* + Rewinds the given file. This function is supported only for reading. + + gzrewind(file) is equivalent to (int)gzseek(file, 0L, SEEK_SET) +*/ + +/* +ZEXTERN z_off_t ZEXPORT gztell OF((gzFile file)); + + Returns the starting position for the next gzread or gzwrite on the given + compressed file. This position represents a number of bytes in the + uncompressed data stream, and is zero when starting, even if appending or + reading a gzip stream from the middle of a file using gzdopen(). + + gztell(file) is equivalent to gzseek(file, 0L, SEEK_CUR) +*/ + +/* +ZEXTERN z_off_t ZEXPORT gzoffset OF((gzFile file)); + + Returns the current offset in the file being read or written. This offset + includes the count of bytes that precede the gzip stream, for example when + appending or when using gzdopen() for reading. When reading, the offset + does not include as yet unused buffered input. This information can be used + for a progress indicator. On error, gzoffset() returns -1. +*/ + +ZEXTERN int ZEXPORT gzeof OF((gzFile file)); +/* + Returns true (1) if the end-of-file indicator has been set while reading, + false (0) otherwise. Note that the end-of-file indicator is set only if the + read tried to go past the end of the input, but came up short. Therefore, + just like feof(), gzeof() may return false even if there is no more data to + read, in the event that the last read request was for the exact number of + bytes remaining in the input file. This will happen if the input file size + is an exact multiple of the buffer size. + + If gzeof() returns true, then the read functions will return no more data, + unless the end-of-file indicator is reset by gzclearerr() and the input file + has grown since the previous end of file was detected. +*/ + +ZEXTERN int ZEXPORT gzdirect OF((gzFile file)); +/* + Returns true (1) if file is being copied directly while reading, or false + (0) if file is a gzip stream being decompressed. + + If the input file is empty, gzdirect() will return true, since the input + does not contain a gzip stream. + + If gzdirect() is used immediately after gzopen() or gzdopen() it will + cause buffers to be allocated to allow reading the file to determine if it + is a gzip file. Therefore if gzbuffer() is used, it should be called before + gzdirect(). + + When writing, gzdirect() returns true (1) if transparent writing was + requested ("wT" for the gzopen() mode), or false (0) otherwise. (Note: + gzdirect() is not needed when writing. Transparent writing must be + explicitly requested, so the application already knows the answer. When + linking statically, using gzdirect() will include all of the zlib code for + gzip file reading and decompression, which may not be desired.) +*/ + +ZEXTERN int ZEXPORT gzclose OF((gzFile file)); +/* + Flushes all pending output if necessary, closes the compressed file and + deallocates the (de)compression state. Note that once file is closed, you + cannot call gzerror with file, since its structures have been deallocated. + gzclose must not be called more than once on the same file, just as free + must not be called more than once on the same allocation. + + gzclose will return Z_STREAM_ERROR if file is not valid, Z_ERRNO on a + file operation error, Z_MEM_ERROR if out of memory, Z_BUF_ERROR if the + last read ended in the middle of a gzip stream, or Z_OK on success. +*/ + +ZEXTERN int ZEXPORT gzclose_r OF((gzFile file)); +ZEXTERN int ZEXPORT gzclose_w OF((gzFile file)); +/* + Same as gzclose(), but gzclose_r() is only for use when reading, and + gzclose_w() is only for use when writing or appending. The advantage to + using these instead of gzclose() is that they avoid linking in zlib + compression or decompression code that is not used when only reading or only + writing respectively. If gzclose() is used, then both compression and + decompression code will be included the application when linking to a static + zlib library. +*/ + +ZEXTERN const char * ZEXPORT gzerror OF((gzFile file, int *errnum)); +/* + Returns the error message for the last error which occurred on the given + compressed file. errnum is set to zlib error number. If an error occurred + in the file system and not in the compression library, errnum is set to + Z_ERRNO and the application may consult errno to get the exact error code. + + The application must not modify the returned string. Future calls to + this function may invalidate the previously returned string. If file is + closed, then the string previously returned by gzerror will no longer be + available. + + gzerror() should be used to distinguish errors from end-of-file for those + functions above that do not distinguish those cases in their return values. +*/ + +ZEXTERN void ZEXPORT gzclearerr OF((gzFile file)); +/* + Clears the error and end-of-file flags for file. This is analogous to the + clearerr() function in stdio. This is useful for continuing to read a gzip + file that is being written concurrently. +*/ + +#endif /* !Z_SOLO */ + + /* checksum functions */ + +/* + These functions are not related to compression but are exported + anyway because they might be useful in applications using the compression + library. +*/ + +ZEXTERN uLong ZEXPORT adler32 OF((uLong adler, const Bytef *buf, uInt len)); +/* + Update a running Adler-32 checksum with the bytes buf[0..len-1] and + return the updated checksum. If buf is Z_NULL, this function returns the + required initial value for the checksum. + + An Adler-32 checksum is almost as reliable as a CRC32 but can be computed + much faster. + + Usage example: + + uLong adler = adler32(0L, Z_NULL, 0); + + while (read_buffer(buffer, length) != EOF) { + adler = adler32(adler, buffer, length); + } + if (adler != original_adler) error(); +*/ + +/* +ZEXTERN uLong ZEXPORT adler32_combine OF((uLong adler1, uLong adler2, + z_off_t len2)); + + Combine two Adler-32 checksums into one. For two sequences of bytes, seq1 + and seq2 with lengths len1 and len2, Adler-32 checksums were calculated for + each, adler1 and adler2. adler32_combine() returns the Adler-32 checksum of + seq1 and seq2 concatenated, requiring only adler1, adler2, and len2. Note + that the z_off_t type (like off_t) is a signed integer. If len2 is + negative, the result has no meaning or utility. +*/ + +ZEXTERN uLong ZEXPORT crc32 OF((uLong crc, const Bytef *buf, uInt len)); +/* + Update a running CRC-32 with the bytes buf[0..len-1] and return the + updated CRC-32. If buf is Z_NULL, this function returns the required + initial value for the crc. Pre- and post-conditioning (one's complement) is + performed within this function so it shouldn't be done by the application. + + Usage example: + + uLong crc = crc32(0L, Z_NULL, 0); + + while (read_buffer(buffer, length) != EOF) { + crc = crc32(crc, buffer, length); + } + if (crc != original_crc) error(); +*/ + +/* +ZEXTERN uLong ZEXPORT crc32_combine OF((uLong crc1, uLong crc2, z_off_t len2)); + + Combine two CRC-32 check values into one. For two sequences of bytes, + seq1 and seq2 with lengths len1 and len2, CRC-32 check values were + calculated for each, crc1 and crc2. crc32_combine() returns the CRC-32 + check value of seq1 and seq2 concatenated, requiring only crc1, crc2, and + len2. +*/ + + + /* various hacks, don't look :) */ + +/* deflateInit and inflateInit are macros to allow checking the zlib version + * and the compiler's view of z_stream: + */ +ZEXTERN int ZEXPORT deflateInit_ OF((z_streamp strm, int level, + const char *version, int stream_size)); +ZEXTERN int ZEXPORT inflateInit_ OF((z_streamp strm, + const char *version, int stream_size)); +ZEXTERN int ZEXPORT deflateInit2_ OF((z_streamp strm, int level, int method, + int windowBits, int memLevel, + int strategy, const char *version, + int stream_size)); +ZEXTERN int ZEXPORT inflateInit2_ OF((z_streamp strm, int windowBits, + const char *version, int stream_size)); +ZEXTERN int ZEXPORT inflateBackInit_ OF((z_streamp strm, int windowBits, + unsigned char FAR *window, + const char *version, + int stream_size)); +#define deflateInit(strm, level) \ + deflateInit_((strm), (level), ZLIB_VERSION, (int)sizeof(z_stream)) +#define inflateInit(strm) \ + inflateInit_((strm), ZLIB_VERSION, (int)sizeof(z_stream)) +#define deflateInit2(strm, level, method, windowBits, memLevel, strategy) \ + deflateInit2_((strm),(level),(method),(windowBits),(memLevel),\ + (strategy), ZLIB_VERSION, (int)sizeof(z_stream)) +#define inflateInit2(strm, windowBits) \ + inflateInit2_((strm), (windowBits), ZLIB_VERSION, \ + (int)sizeof(z_stream)) +#define inflateBackInit(strm, windowBits, window) \ + inflateBackInit_((strm), (windowBits), (window), \ + ZLIB_VERSION, (int)sizeof(z_stream)) + +#ifndef Z_SOLO + +/* gzgetc() macro and its supporting function and exposed data structure. Note + * that the real internal state is much larger than the exposed structure. + * This abbreviated structure exposes just enough for the gzgetc() macro. The + * user should not mess with these exposed elements, since their names or + * behavior could change in the future, perhaps even capriciously. They can + * only be used by the gzgetc() macro. You have been warned. + */ +struct gzFile_s { + unsigned have; + unsigned char *next; + z_off64_t pos; +}; +ZEXTERN int ZEXPORT gzgetc_ OF((gzFile file)); /* backward compatibility */ +#ifdef Z_PREFIX_SET +# undef z_gzgetc +# define z_gzgetc(g) \ + ((g)->have ? ((g)->have--, (g)->pos++, *((g)->next)++) : gzgetc(g)) +#else +# define gzgetc(g) \ + ((g)->have ? ((g)->have--, (g)->pos++, *((g)->next)++) : gzgetc(g)) +#endif + +/* provide 64-bit offset functions if _LARGEFILE64_SOURCE defined, and/or + * change the regular functions to 64 bits if _FILE_OFFSET_BITS is 64 (if + * both are true, the application gets the *64 functions, and the regular + * functions are changed to 64 bits) -- in case these are set on systems + * without large file support, _LFS64_LARGEFILE must also be true + */ +#ifdef Z_LARGE64 + ZEXTERN gzFile ZEXPORT gzopen64 OF((const char *, const char *)); + ZEXTERN z_off64_t ZEXPORT gzseek64 OF((gzFile, z_off64_t, int)); + ZEXTERN z_off64_t ZEXPORT gztell64 OF((gzFile)); + ZEXTERN z_off64_t ZEXPORT gzoffset64 OF((gzFile)); + ZEXTERN uLong ZEXPORT adler32_combine64 OF((uLong, uLong, z_off64_t)); + ZEXTERN uLong ZEXPORT crc32_combine64 OF((uLong, uLong, z_off64_t)); +#endif + +#if !defined(ZLIB_INTERNAL) && defined(Z_WANT64) +# ifdef Z_PREFIX_SET +# define z_gzopen z_gzopen64 +# define z_gzseek z_gzseek64 +# define z_gztell z_gztell64 +# define z_gzoffset z_gzoffset64 +# define z_adler32_combine z_adler32_combine64 +# define z_crc32_combine z_crc32_combine64 +# else +# define gzopen gzopen64 +# define gzseek gzseek64 +# define gztell gztell64 +# define gzoffset gzoffset64 +# define adler32_combine adler32_combine64 +# define crc32_combine crc32_combine64 +# endif +# ifndef Z_LARGE64 + ZEXTERN gzFile ZEXPORT gzopen64 OF((const char *, const char *)); + ZEXTERN z_off_t ZEXPORT gzseek64 OF((gzFile, z_off_t, int)); + ZEXTERN z_off_t ZEXPORT gztell64 OF((gzFile)); + ZEXTERN z_off_t ZEXPORT gzoffset64 OF((gzFile)); + ZEXTERN uLong ZEXPORT adler32_combine64 OF((uLong, uLong, z_off_t)); + ZEXTERN uLong ZEXPORT crc32_combine64 OF((uLong, uLong, z_off_t)); +# endif +#else + ZEXTERN gzFile ZEXPORT gzopen OF((const char *, const char *)); + ZEXTERN z_off_t ZEXPORT gzseek OF((gzFile, z_off_t, int)); + ZEXTERN z_off_t ZEXPORT gztell OF((gzFile)); + ZEXTERN z_off_t ZEXPORT gzoffset OF((gzFile)); + ZEXTERN uLong ZEXPORT adler32_combine OF((uLong, uLong, z_off_t)); + ZEXTERN uLong ZEXPORT crc32_combine OF((uLong, uLong, z_off_t)); +#endif + +#else /* Z_SOLO */ + + ZEXTERN uLong ZEXPORT adler32_combine OF((uLong, uLong, z_off_t)); + ZEXTERN uLong ZEXPORT crc32_combine OF((uLong, uLong, z_off_t)); + +#endif /* !Z_SOLO */ + +/* hack for buggy compilers */ +#if !defined(ZUTIL_H) && !defined(NO_DUMMY_DECL) + struct internal_state {int dummy;}; +#endif + +/* undocumented functions */ +ZEXTERN const char * ZEXPORT zError OF((int)); +ZEXTERN int ZEXPORT inflateSyncPoint OF((z_streamp)); +ZEXTERN const z_crc_t FAR * ZEXPORT get_crc_table OF((void)); +ZEXTERN int ZEXPORT inflateUndermine OF((z_streamp, int)); +ZEXTERN int ZEXPORT inflateResetKeep OF((z_streamp)); +ZEXTERN int ZEXPORT deflateResetKeep OF((z_streamp)); +#if defined(_WIN32) && !defined(Z_SOLO) +ZEXTERN gzFile ZEXPORT gzopen_w OF((const wchar_t *path, + const char *mode)); +#endif +#if defined(STDC) || defined(Z_HAVE_STDARG_H) +# ifndef Z_SOLO +ZEXTERN int ZEXPORTVA gzvprintf Z_ARG((gzFile file, + const char *format, + va_list va)); +# endif +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* ZLIB_H */ diff --git a/ml/dlib/dlib/external/zlib/zutil.c b/ml/dlib/dlib/external/zlib/zutil.c new file mode 100644 index 000000000..23d2ebef0 --- /dev/null +++ b/ml/dlib/dlib/external/zlib/zutil.c @@ -0,0 +1,324 @@ +/* zutil.c -- target dependent utility functions for the compression library + * Copyright (C) 1995-2005, 2010, 2011, 2012 Jean-loup Gailly. + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* @(#) $Id$ */ + +#include "zutil.h" +#ifndef Z_SOLO +# include "gzguts.h" +#endif + +#ifndef NO_DUMMY_DECL +struct internal_state {int dummy;}; /* for buggy compilers */ +#endif + +z_const char * const z_errmsg[10] = { +"need dictionary", /* Z_NEED_DICT 2 */ +"stream end", /* Z_STREAM_END 1 */ +"", /* Z_OK 0 */ +"file error", /* Z_ERRNO (-1) */ +"stream error", /* Z_STREAM_ERROR (-2) */ +"data error", /* Z_DATA_ERROR (-3) */ +"insufficient memory", /* Z_MEM_ERROR (-4) */ +"buffer error", /* Z_BUF_ERROR (-5) */ +"incompatible version",/* Z_VERSION_ERROR (-6) */ +""}; + + +const char * ZEXPORT zlibVersion() +{ + return ZLIB_VERSION; +} + +uLong ZEXPORT zlibCompileFlags() +{ + uLong flags; + + flags = 0; + switch ((int)(sizeof(uInt))) { + case 2: break; + case 4: flags += 1; break; + case 8: flags += 2; break; + default: flags += 3; + } + switch ((int)(sizeof(uLong))) { + case 2: break; + case 4: flags += 1 << 2; break; + case 8: flags += 2 << 2; break; + default: flags += 3 << 2; + } + switch ((int)(sizeof(voidpf))) { + case 2: break; + case 4: flags += 1 << 4; break; + case 8: flags += 2 << 4; break; + default: flags += 3 << 4; + } + switch ((int)(sizeof(z_off_t))) { + case 2: break; + case 4: flags += 1 << 6; break; + case 8: flags += 2 << 6; break; + default: flags += 3 << 6; + } +#ifdef DEBUG + flags += 1 << 8; +#endif +#if defined(ASMV) || defined(ASMINF) + flags += 1 << 9; +#endif +#ifdef ZLIB_WINAPI + flags += 1 << 10; +#endif +#ifdef BUILDFIXED + flags += 1 << 12; +#endif +#ifdef DYNAMIC_CRC_TABLE + flags += 1 << 13; +#endif +#ifdef NO_GZCOMPRESS + flags += 1L << 16; +#endif +#ifdef NO_GZIP + flags += 1L << 17; +#endif +#ifdef PKZIP_BUG_WORKAROUND + flags += 1L << 20; +#endif +#ifdef FASTEST + flags += 1L << 21; +#endif +#if defined(STDC) || defined(Z_HAVE_STDARG_H) +# ifdef NO_vsnprintf + flags += 1L << 25; +# ifdef HAS_vsprintf_void + flags += 1L << 26; +# endif +# else +# ifdef HAS_vsnprintf_void + flags += 1L << 26; +# endif +# endif +#else + flags += 1L << 24; +# ifdef NO_snprintf + flags += 1L << 25; +# ifdef HAS_sprintf_void + flags += 1L << 26; +# endif +# else +# ifdef HAS_snprintf_void + flags += 1L << 26; +# endif +# endif +#endif + return flags; +} + +#ifdef DEBUG + +# ifndef verbose +# define verbose 0 +# endif +int ZLIB_INTERNAL z_verbose = verbose; + +void ZLIB_INTERNAL z_error (m) + char *m; +{ + fprintf(stderr, "%s\n", m); + exit(1); +} +#endif + +/* exported to allow conversion of error code to string for compress() and + * uncompress() + */ +const char * ZEXPORT zError(err) + int err; +{ + return ERR_MSG(err); +} + +#if defined(_WIN32_WCE) + /* The Microsoft C Run-Time Library for Windows CE doesn't have + * errno. We define it as a global variable to simplify porting. + * Its value is always 0 and should not be used. + */ + int errno = 0; +#endif + +#ifndef HAVE_MEMCPY + +void ZLIB_INTERNAL zmemcpy(dest, source, len) + Bytef* dest; + const Bytef* source; + uInt len; +{ + if (len == 0) return; + do { + *dest++ = *source++; /* ??? to be unrolled */ + } while (--len != 0); +} + +int ZLIB_INTERNAL zmemcmp(s1, s2, len) + const Bytef* s1; + const Bytef* s2; + uInt len; +{ + uInt j; + + for (j = 0; j < len; j++) { + if (s1[j] != s2[j]) return 2*(s1[j] > s2[j])-1; + } + return 0; +} + +void ZLIB_INTERNAL zmemzero(dest, len) + Bytef* dest; + uInt len; +{ + if (len == 0) return; + do { + *dest++ = 0; /* ??? to be unrolled */ + } while (--len != 0); +} +#endif + +#ifndef Z_SOLO + +#ifdef SYS16BIT + +#ifdef __TURBOC__ +/* Turbo C in 16-bit mode */ + +# define MY_ZCALLOC + +/* Turbo C malloc() does not allow dynamic allocation of 64K bytes + * and farmalloc(64K) returns a pointer with an offset of 8, so we + * must fix the pointer. Warning: the pointer must be put back to its + * original form in order to free it, use zcfree(). + */ + +#define MAX_PTR 10 +/* 10*64K = 640K */ + +local int next_ptr = 0; + +typedef struct ptr_table_s { + voidpf org_ptr; + voidpf new_ptr; +} ptr_table; + +local ptr_table table[MAX_PTR]; +/* This table is used to remember the original form of pointers + * to large buffers (64K). Such pointers are normalized with a zero offset. + * Since MSDOS is not a preemptive multitasking OS, this table is not + * protected from concurrent access. This hack doesn't work anyway on + * a protected system like OS/2. Use Microsoft C instead. + */ + +voidpf ZLIB_INTERNAL zcalloc (voidpf opaque, unsigned items, unsigned size) +{ + voidpf buf = opaque; /* just to make some compilers happy */ + ulg bsize = (ulg)items*size; + + /* If we allocate less than 65520 bytes, we assume that farmalloc + * will return a usable pointer which doesn't have to be normalized. + */ + if (bsize < 65520L) { + buf = farmalloc(bsize); + if (*(ush*)&buf != 0) return buf; + } else { + buf = farmalloc(bsize + 16L); + } + if (buf == NULL || next_ptr >= MAX_PTR) return NULL; + table[next_ptr].org_ptr = buf; + + /* Normalize the pointer to seg:0 */ + *((ush*)&buf+1) += ((ush)((uch*)buf-0) + 15) >> 4; + *(ush*)&buf = 0; + table[next_ptr++].new_ptr = buf; + return buf; +} + +void ZLIB_INTERNAL zcfree (voidpf opaque, voidpf ptr) +{ + int n; + if (*(ush*)&ptr != 0) { /* object < 64K */ + farfree(ptr); + return; + } + /* Find the original pointer */ + for (n = 0; n < next_ptr; n++) { + if (ptr != table[n].new_ptr) continue; + + farfree(table[n].org_ptr); + while (++n < next_ptr) { + table[n-1] = table[n]; + } + next_ptr--; + return; + } + ptr = opaque; /* just to make some compilers happy */ + Assert(0, "zcfree: ptr not found"); +} + +#endif /* __TURBOC__ */ + + +#ifdef M_I86 +/* Microsoft C in 16-bit mode */ + +# define MY_ZCALLOC + +#if (!defined(_MSC_VER) || (_MSC_VER <= 600)) +# define _halloc halloc +# define _hfree hfree +#endif + +voidpf ZLIB_INTERNAL zcalloc (voidpf opaque, uInt items, uInt size) +{ + if (opaque) opaque = 0; /* to make compiler happy */ + return _halloc((long)items, size); +} + +void ZLIB_INTERNAL zcfree (voidpf opaque, voidpf ptr) +{ + if (opaque) opaque = 0; /* to make compiler happy */ + _hfree(ptr); +} + +#endif /* M_I86 */ + +#endif /* SYS16BIT */ + + +#ifndef MY_ZCALLOC /* Any system without a special alloc function */ + +#ifndef STDC +extern voidp malloc OF((uInt size)); +extern voidp calloc OF((uInt items, uInt size)); +extern void free OF((voidpf ptr)); +#endif + +voidpf ZLIB_INTERNAL zcalloc (opaque, items, size) + voidpf opaque; + unsigned items; + unsigned size; +{ + if (opaque) items += size - size; /* make compiler happy */ + return sizeof(uInt) > 2 ? (voidpf)malloc(items * size) : + (voidpf)calloc(items, size); +} + +void ZLIB_INTERNAL zcfree (opaque, ptr) + voidpf opaque; + voidpf ptr; +{ + free(ptr); + if (opaque) return; /* make compiler happy */ +} + +#endif /* MY_ZCALLOC */ + +#endif /* !Z_SOLO */ diff --git a/ml/dlib/dlib/external/zlib/zutil.h b/ml/dlib/dlib/external/zlib/zutil.h new file mode 100644 index 000000000..24ab06b1c --- /dev/null +++ b/ml/dlib/dlib/external/zlib/zutil.h @@ -0,0 +1,253 @@ +/* zutil.h -- internal interface and configuration of the compression library + * Copyright (C) 1995-2013 Jean-loup Gailly. + * For conditions of distribution and use, see copyright notice in zlib.h + */ + +/* WARNING: this file should *not* be used by applications. It is + part of the implementation of the compression library and is + subject to change. Applications should only use zlib.h. + */ + +/* @(#) $Id$ */ + +#ifndef ZUTIL_H +#define ZUTIL_H + +#ifdef HAVE_HIDDEN +# define ZLIB_INTERNAL __attribute__((visibility ("hidden"))) +#else +# define ZLIB_INTERNAL +#endif + +#include "zlib.h" + +#if defined(STDC) && !defined(Z_SOLO) +# if !(defined(_WIN32_WCE) && defined(_MSC_VER)) +# include +# endif +# include +# include +#endif + +#ifdef Z_SOLO + typedef long ptrdiff_t; /* guess -- will be caught if guess is wrong */ +#endif + +#ifndef local +# define local static +#endif +/* compile with -Dlocal if your debugger can't find static symbols */ + +typedef unsigned char uch; +typedef uch FAR uchf; +typedef unsigned short ush; +typedef ush FAR ushf; +typedef unsigned long ulg; + +extern z_const char * const z_errmsg[10]; /* indexed by 2-zlib_error */ +/* (size given to avoid silly warnings with Visual C++) */ + +#define ERR_MSG(err) z_errmsg[Z_NEED_DICT-(err)] + +#define ERR_RETURN(strm,err) \ + return (strm->msg = ERR_MSG(err), (err)) +/* To be used only when the state is known to be valid */ + + /* common constants */ + +#ifndef DEF_WBITS +# define DEF_WBITS MAX_WBITS +#endif +/* default windowBits for decompression. MAX_WBITS is for compression only */ + +#if MAX_MEM_LEVEL >= 8 +# define DEF_MEM_LEVEL 8 +#else +# define DEF_MEM_LEVEL MAX_MEM_LEVEL +#endif +/* default memLevel */ + +#define STORED_BLOCK 0 +#define STATIC_TREES 1 +#define DYN_TREES 2 +/* The three kinds of block type */ + +#define MIN_MATCH 3 +#define MAX_MATCH 258 +/* The minimum and maximum match lengths */ + +#define PRESET_DICT 0x20 /* preset dictionary flag in zlib header */ + + /* target dependencies */ + +#if defined(MSDOS) || (defined(WINDOWS) && !defined(WIN32)) +# define OS_CODE 0x00 +# ifndef Z_SOLO +# if defined(__TURBOC__) || defined(__BORLANDC__) +# if (__STDC__ == 1) && (defined(__LARGE__) || defined(__COMPACT__)) + /* Allow compilation with ANSI keywords only enabled */ + void _Cdecl farfree( void *block ); + void *_Cdecl farmalloc( unsigned long nbytes ); +# else +# include +# endif +# else /* MSC or DJGPP */ +# include +# endif +# endif +#endif + +#ifdef AMIGA +# define OS_CODE 0x01 +#endif + +#if defined(VAXC) || defined(VMS) +# define OS_CODE 0x02 +# define F_OPEN(name, mode) \ + fopen((name), (mode), "mbc=60", "ctx=stm", "rfm=fix", "mrs=512") +#endif + +#if defined(ATARI) || defined(atarist) +# define OS_CODE 0x05 +#endif + +#ifdef OS2 +# define OS_CODE 0x06 +# if defined(M_I86) && !defined(Z_SOLO) +# include +# endif +#endif + +#if defined(MACOS) || defined(TARGET_OS_MAC) +# define OS_CODE 0x07 +# ifndef Z_SOLO +# if defined(__MWERKS__) && __dest_os != __be_os && __dest_os != __win32_os +# include /* for fdopen */ +# else +# ifndef fdopen +# define fdopen(fd,mode) NULL /* No fdopen() */ +# endif +# endif +# endif +#endif + +#ifdef TOPS20 +# define OS_CODE 0x0a +#endif + +#ifdef WIN32 +# ifndef __CYGWIN__ /* Cygwin is Unix, not Win32 */ +# define OS_CODE 0x0b +# endif +#endif + +#ifdef __50SERIES /* Prime/PRIMOS */ +# define OS_CODE 0x0f +#endif + +#if defined(_BEOS_) || defined(RISCOS) +# define fdopen(fd,mode) NULL /* No fdopen() */ +#endif + +#if (defined(_MSC_VER) && (_MSC_VER > 600)) && !defined __INTERIX +# if defined(_WIN32_WCE) +# define fdopen(fd,mode) NULL /* No fdopen() */ +# ifndef _PTRDIFF_T_DEFINED + typedef int ptrdiff_t; +# define _PTRDIFF_T_DEFINED +# endif +# else +# define fdopen(fd,type) _fdopen(fd,type) +# endif +#endif + +#if defined(__BORLANDC__) && !defined(MSDOS) + #pragma warn -8004 + #pragma warn -8008 + #pragma warn -8066 +#endif + +/* provide prototypes for these when building zlib without LFS */ +#if !defined(_WIN32) && \ + (!defined(_LARGEFILE64_SOURCE) || _LFS64_LARGEFILE-0 == 0) + ZEXTERN uLong ZEXPORT adler32_combine64 OF((uLong, uLong, z_off_t)); + ZEXTERN uLong ZEXPORT crc32_combine64 OF((uLong, uLong, z_off_t)); +#endif + + /* common defaults */ + +#ifndef OS_CODE +# define OS_CODE 0x03 /* assume Unix */ +#endif + +#ifndef F_OPEN +# define F_OPEN(name, mode) fopen((name), (mode)) +#endif + + /* functions */ + +#if defined(pyr) || defined(Z_SOLO) +# define NO_MEMCPY +#endif +#if defined(SMALL_MEDIUM) && !defined(_MSC_VER) && !defined(__SC__) + /* Use our own functions for small and medium model with MSC <= 5.0. + * You may have to use the same strategy for Borland C (untested). + * The __SC__ check is for Symantec. + */ +# define NO_MEMCPY +#endif +#if defined(STDC) && !defined(HAVE_MEMCPY) && !defined(NO_MEMCPY) +# define HAVE_MEMCPY +#endif +#ifdef HAVE_MEMCPY +# ifdef SMALL_MEDIUM /* MSDOS small or medium model */ +# define zmemcpy _fmemcpy +# define zmemcmp _fmemcmp +# define zmemzero(dest, len) _fmemset(dest, 0, len) +# else +# define zmemcpy memcpy +# define zmemcmp memcmp +# define zmemzero(dest, len) memset(dest, 0, len) +# endif +#else + void ZLIB_INTERNAL zmemcpy OF((Bytef* dest, const Bytef* source, uInt len)); + int ZLIB_INTERNAL zmemcmp OF((const Bytef* s1, const Bytef* s2, uInt len)); + void ZLIB_INTERNAL zmemzero OF((Bytef* dest, uInt len)); +#endif + +/* Diagnostic functions */ +#ifdef DEBUG +# include + extern int ZLIB_INTERNAL z_verbose; + extern void ZLIB_INTERNAL z_error OF((char *m)); +# define Assert(cond,msg) {if(!(cond)) z_error(msg);} +# define Trace(x) {if (z_verbose>=0) fprintf x ;} +# define Tracev(x) {if (z_verbose>0) fprintf x ;} +# define Tracevv(x) {if (z_verbose>1) fprintf x ;} +# define Tracec(c,x) {if (z_verbose>0 && (c)) fprintf x ;} +# define Tracecv(c,x) {if (z_verbose>1 && (c)) fprintf x ;} +#else +# define Assert(cond,msg) +# define Trace(x) +# define Tracev(x) +# define Tracevv(x) +# define Tracec(c,x) +# define Tracecv(c,x) +#endif + +#ifndef Z_SOLO + voidpf ZLIB_INTERNAL zcalloc OF((voidpf opaque, unsigned items, + unsigned size)); + void ZLIB_INTERNAL zcfree OF((voidpf opaque, voidpf ptr)); +#endif + +#define ZALLOC(strm, items, size) \ + (*((strm)->zalloc))((strm)->opaque, (items), (size)) +#define ZFREE(strm, addr) (*((strm)->zfree))((strm)->opaque, (voidpf)(addr)) +#define TRY_FREE(s, p) {if (p) ZFREE(s, p);} + +/* Reverse the bytes in a 32-bit value */ +#define ZSWAP32(q) ((((q) >> 24) & 0xff) + (((q) >> 8) & 0xff00) + \ + (((q) & 0xff00) << 8) + (((q) & 0xff) << 24)) + +#endif /* ZUTIL_H */ diff --git a/ml/dlib/dlib/filtering.h b/ml/dlib/dlib/filtering.h new file mode 100644 index 000000000..764d54e81 --- /dev/null +++ b/ml/dlib/dlib/filtering.h @@ -0,0 +1,12 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FILTERiNG_HEADER +#define DLIB_FILTERiNG_HEADER + +#include "filtering/kalman_filter.h" +#include "filtering/rls_filter.h" + +#endif // DLIB_FILTERiNG_HEADER + + + diff --git a/ml/dlib/dlib/filtering/kalman_filter.cpp b/ml/dlib/dlib/filtering/kalman_filter.cpp new file mode 100644 index 000000000..5d47793a9 --- /dev/null +++ b/ml/dlib/dlib/filtering/kalman_filter.cpp @@ -0,0 +1,104 @@ +// Copyright (C) 2018 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_KALMAN_FiLTER_CPp_ +#define DLIB_KALMAN_FiLTER_CPp_ + +#include "kalman_filter.h" +#include "../global_optimization.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + momentum_filter find_optimal_momentum_filter ( + const std::vector>& sequences, + const double smoothness + ) + { + DLIB_CASSERT(sequences.size() != 0); + for (auto& vals : sequences) + DLIB_CASSERT(vals.size() > 4); + DLIB_CASSERT(smoothness >= 0); + + // define the objective function we optimize to find the best filter + auto obj = [&](double measurement_noise, double typical_acceleration, double max_measurement_deviation) + { + running_stats rs; + for (auto& vals : sequences) + { + momentum_filter filt(measurement_noise, typical_acceleration, max_measurement_deviation); + double prev_filt = 0; + for (size_t i = 0; i < vals.size(); ++i) + { + // we care about smoothness and fitting the data. + if (i > 0) + { + // the filter should fit the data + rs.add(std::abs(vals[i]-filt.get_predicted_next_position())); + } + double next_filt = filt(vals[i]); + if (i > 0) + { + // the filter should also output a smooth trajectory + rs.add(smoothness*std::abs(next_filt-prev_filt)); + } + prev_filt = next_filt; + } + } + return rs.mean(); + }; + + running_stats avgdiff; + for (auto& vals : sequences) + { + for (size_t i = 1; i < vals.size(); ++i) + avgdiff.add(vals[i]-vals[i-1]); + } + const double scale = avgdiff.stddev(); + + function_evaluation opt = find_min_global(obj, {scale*0.01, scale*0.0001, 0.00001}, {scale*10, scale*10, 10}, max_function_calls(400)); + + momentum_filter filt(opt.x(0), opt.x(1), opt.x(2)); + + return filt; + } + +// ---------------------------------------------------------------------------------------- + + momentum_filter find_optimal_momentum_filter ( + const std::vector& sequence, + const double smoothness + ) + { + return find_optimal_momentum_filter({1,sequence}, smoothness); + } + +// ---------------------------------------------------------------------------------------- + + rect_filter find_optimal_rect_filter ( + const std::vector& rects, + const double smoothness + ) + { + DLIB_CASSERT(rects.size() > 4); + DLIB_CASSERT(smoothness >= 0); + + std::vector> vals(4); + for (auto& r : rects) + { + vals[0].push_back(r.left()); + vals[1].push_back(r.top()); + vals[2].push_back(r.right()); + vals[3].push_back(r.bottom()); + } + return rect_filter(find_optimal_momentum_filter(vals, smoothness)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KALMAN_FiLTER_CPp_ + diff --git a/ml/dlib/dlib/filtering/kalman_filter.h b/ml/dlib/dlib/filtering/kalman_filter.h new file mode 100644 index 000000000..30289fa42 --- /dev/null +++ b/ml/dlib/dlib/filtering/kalman_filter.h @@ -0,0 +1,382 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_KALMAN_FiLTER_Hh_ +#define DLIB_KALMAN_FiLTER_Hh_ + +#include "kalman_filter_abstract.h" +#include "../matrix.h" +#include "../geometry.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + long states, + long measurements + > + class kalman_filter + { + public: + + kalman_filter() + { + H = 0; + A = 0; + Q = 0; + R = 0; + x = 0; + xb = 0; + P = identity_matrix(states); + got_first_meas = false; + } + + void set_observation_model ( const matrix& H_) { H = H_; } + void set_transition_model ( const matrix& A_) { A = A_; } + void set_process_noise ( const matrix& Q_) { Q = Q_; } + void set_measurement_noise ( const matrix& R_) { R = R_; } + void set_estimation_error_covariance( const matrix& P_) { P = P_; } + void set_state ( const matrix& xb_) + { + xb = xb_; + if (!got_first_meas) + { + x = xb_; + got_first_meas = true; + } + } + + const matrix& get_observation_model ( + ) const { return H; } + + const matrix& get_transition_model ( + ) const { return A; } + + const matrix& get_process_noise ( + ) const { return Q; } + + const matrix& get_measurement_noise ( + ) const { return R; } + + void update ( + ) + { + // propagate estimation error covariance forward + P = A*P*trans(A) + Q; + + // propagate state forward + x = xb; + xb = A*x; + } + + void update (const matrix& z) + { + // propagate estimation error covariance forward + P = A*P*trans(A) + Q; + + // compute Kalman gain matrix + const matrix K = P*trans(H)*pinv(H*P*trans(H) + R); + + if (got_first_meas) + { + const matrix res = z - H*xb; + // correct the current state estimate + x = xb + K*res; + } + else + { + // Since we don't have a previous state estimate at the start of filtering, + // we will just set the current state to whatever is indicated by the measurement + x = pinv(H)*z; + got_first_meas = true; + } + + // propagate state forward in time + xb = A*x; + + // update estimation error covariance since we got a measurement. + P = (identity_matrix() - K*H)*P; + } + + const matrix& get_current_state( + ) const + { + return x; + } + + const matrix& get_predicted_next_state( + ) const + { + return xb; + } + + const matrix& get_current_estimation_error_covariance( + ) const + { + return P; + } + + friend inline void serialize(const kalman_filter& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.got_first_meas, out); + serialize(item.x, out); + serialize(item.xb, out); + serialize(item.P, out); + serialize(item.H, out); + serialize(item.A, out); + serialize(item.Q, out); + serialize(item.R, out); + } + + friend inline void deserialize(kalman_filter& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw dlib::serialization_error("Unknown version number found while deserializing kalman_filter object."); + + deserialize(item.got_first_meas, in); + deserialize(item.x, in); + deserialize(item.xb, in); + deserialize(item.P, in); + deserialize(item.H, in); + deserialize(item.A, in); + deserialize(item.Q, in); + deserialize(item.R, in); + } + + private: + + bool got_first_meas; + matrix x, xb; + matrix P; + + matrix H; + matrix A; + matrix Q; + matrix R; + + + }; + +// ---------------------------------------------------------------------------------------- + + class momentum_filter + { + public: + + momentum_filter( + double meas_noise, + double acc, + double max_meas_dev + ) : + measurement_noise(meas_noise), + typical_acceleration(acc), + max_measurement_deviation(max_meas_dev) + { + DLIB_CASSERT(meas_noise >= 0); + DLIB_CASSERT(acc >= 0); + DLIB_CASSERT(max_meas_dev >= 0); + + kal.set_observation_model({1, 0}); + kal.set_transition_model( {1, 1, + 0, 1}); + kal.set_process_noise({0, 0, + 0, typical_acceleration*typical_acceleration}); + + kal.set_measurement_noise({measurement_noise*measurement_noise}); + } + + momentum_filter() = default; + + double get_measurement_noise ( + ) const { return measurement_noise; } + + double get_typical_acceleration ( + ) const { return typical_acceleration; } + + double get_max_measurement_deviation ( + ) const { return max_measurement_deviation; } + + void reset() + { + *this = momentum_filter(measurement_noise, typical_acceleration, max_measurement_deviation); + } + + double get_predicted_next_position( + ) const + { + return kal.get_predicted_next_state()(0); + } + + double operator()( + const double measured_position + ) + { + auto x = kal.get_predicted_next_state(); + const auto max_deviation = max_measurement_deviation*measurement_noise; + // Check if measured_position has suddenly jumped in value by a whole lot. This + // could happen if the velocity term experiences a much larger than normal + // acceleration, e.g. because the underlying object is doing a maneuver. If + // this happens then we clamp the state so that the predicted next value is no + // more than max_deviation away from measured_position at all times. + if (x(0) > measured_position + max_deviation) + { + x(0) = measured_position + max_deviation; + kal.set_state(x); + } + else if (x(0) < measured_position - max_deviation) + { + x(0) = measured_position - max_deviation; + kal.set_state(x); + } + + kal.update({measured_position}); + + return kal.get_current_state()(0); + } + + friend std::ostream& operator << (std::ostream& out, const momentum_filter& item) + { + out << "measurement_noise: " << item.measurement_noise << "\n"; + out << "typical_acceleration: " << item.typical_acceleration << "\n"; + out << "max_measurement_deviation: " << item.max_measurement_deviation; + return out; + } + + friend void serialize(const momentum_filter& item, std::ostream& out) + { + int version = 15; + serialize(version, out); + serialize(item.measurement_noise, out); + serialize(item.typical_acceleration, out); + serialize(item.max_measurement_deviation, out); + serialize(item.kal, out); + } + + friend void deserialize(momentum_filter& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 15) + throw serialization_error("Unexpected version found while deserializing momentum_filter."); + deserialize(item.measurement_noise, in); + deserialize(item.typical_acceleration, in); + deserialize(item.max_measurement_deviation, in); + deserialize(item.kal, in); + } + + private: + + double measurement_noise = 2; + double typical_acceleration = 0.1; + double max_measurement_deviation = 3; // nominally number of standard deviations + + kalman_filter<2,1> kal; + }; + +// ---------------------------------------------------------------------------------------- + + momentum_filter find_optimal_momentum_filter ( + const std::vector>& sequences, + const double smoothness = 1 + ); + +// ---------------------------------------------------------------------------------------- + + momentum_filter find_optimal_momentum_filter ( + const std::vector& sequence, + const double smoothness = 1 + ); + +// ---------------------------------------------------------------------------------------- + + class rect_filter + { + public: + rect_filter() = default; + + rect_filter( + double meas_noise, + double acc, + double max_meas_dev + ) : rect_filter(momentum_filter(meas_noise, acc, max_meas_dev)) {} + + rect_filter( + const momentum_filter& filt + ) : + left(filt), + top(filt), + right(filt), + bottom(filt) + { + } + + drectangle operator()(const drectangle& r) + { + return drectangle(left(r.left()), + top(r.top()), + right(r.right()), + bottom(r.bottom())); + } + + drectangle operator()(const rectangle& r) + { + return drectangle(left(r.left()), + top(r.top()), + right(r.right()), + bottom(r.bottom())); + } + + const momentum_filter& get_left () const { return left; } + momentum_filter& get_left () { return left; } + const momentum_filter& get_top () const { return top; } + momentum_filter& get_top () { return top; } + const momentum_filter& get_right () const { return right; } + momentum_filter& get_right () { return right; } + const momentum_filter& get_bottom () const { return bottom; } + momentum_filter& get_bottom () { return bottom; } + + friend void serialize(const rect_filter& item, std::ostream& out) + { + int version = 123; + serialize(version, out); + serialize(item.left, out); + serialize(item.top, out); + serialize(item.right, out); + serialize(item.bottom, out); + } + + friend void deserialize(rect_filter& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 123) + throw dlib::serialization_error("Unknown version number found while deserializing rect_filter object."); + deserialize(item.left, in); + deserialize(item.top, in); + deserialize(item.right, in); + deserialize(item.bottom, in); + } + + private: + + momentum_filter left, top, right, bottom; + }; + +// ---------------------------------------------------------------------------------------- + + rect_filter find_optimal_rect_filter ( + const std::vector& rects, + const double smoothness = 1 + ); + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KALMAN_FiLTER_Hh_ + diff --git a/ml/dlib/dlib/filtering/kalman_filter_abstract.h b/ml/dlib/dlib/filtering/kalman_filter_abstract.h new file mode 100644 index 000000000..cdac2c569 --- /dev/null +++ b/ml/dlib/dlib/filtering/kalman_filter_abstract.h @@ -0,0 +1,492 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_KALMAN_FiLTER_ABSTRACT_Hh_ +#ifdef DLIB_KALMAN_FiLTER_ABSTRACT_Hh_ + +#include "../serialize.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + long states, + long measurements + > + class kalman_filter + { + /*! + REQUIREMENTS ON states + states > 0 + + REQUIREMENTS ON measurements + measurements > 0 + + WHAT THIS OBJECT REPRESENTS + This object implements the Kalman filter, which is a tool for + recursively estimating the state of a process given measurements + related to that process. To use this tool you will have to + be familiar with the workings of the Kalman filter. An excellent + introduction can be found in the paper: + An Introduction to the Kalman Filter + by Greg Welch and Gary Bishop + + !*/ + + public: + + kalman_filter( + ); + /*! + - #get_observation_model() == 0 + - #get_transition_model() == 0 + - #get_process_noise() == 0 + - #get_measurement_noise() == 0 + - #get_current_state() == 0 + - #get_predicted_next_state() == 0 + - #get_current_estimation_error_covariance() == the identity matrix + !*/ + + void set_observation_model ( + const matrix& H + ); + /*! + ensures + - #get_observation_model() == H + !*/ + + void set_transition_model ( + const matrix& A + ); + /*! + ensures + - #get_transition_model() == A + !*/ + + void set_process_noise ( + const matrix& Q + ); + /*! + ensures + - #get_process_noise() == Q + !*/ + + void set_measurement_noise ( + const matrix& R + ); + /*! + ensures + - #get_measurement_noise() == R + !*/ + + void set_estimation_error_covariance ( + const matrix& P + ); + /*! + ensures + - #get_current_estimation_error_covariance() == P + (Note that you should only set this before you start filtering + since the Kalman filter will maintain the value of P on its own. + So only set this during initialization unless you are sure you + understand what you are doing.) + !*/ + + void set_state ( + const matrix& xb + ); + /*! + ensures + - This function can be used when the initial state is known, or if the + state needs to be corrected before the next update(). + - #get_predicted_next_state() == xb + - If (update() hasn't been called yet) then + - #get_current_state() == xb + !*/ + + const matrix& get_observation_model ( + ) const; + /*! + ensures + - Returns the matrix "H" which relates process states x to measurements z. + The relation is linear, therefore, z = H*x. That is, multiplying a + state by H gives the measurement you expect to observe for that state. + !*/ + + const matrix& get_transition_model ( + ) const; + /*! + ensures + - Returns the matrix "A" which determines how process states change over time. + The relation is linear, therefore, given a state vector x, the value you + expect it to have at the next time step is A*x. + !*/ + + const matrix& get_process_noise ( + ) const; + /*! + ensures + - returns the process noise covariance matrix. You can think of this + covariance matrix as a measure of how wrong the assumption of + linear state transitions is. + !*/ + + const matrix& get_measurement_noise ( + ) const; + /*! + ensures + - returns the measurement noise covariance matrix. That is, when we + measure a state x we only obtain H*x corrupted by Gaussian noise. + The measurement noise is the covariance matrix of this Gaussian + noise which corrupts our measurements. + !*/ + + void update ( + ); + /*! + ensures + - propagates the current state estimate forward in time one + time step. In particular: + - #get_current_state() == get_predicted_next_state() + - #get_predicted_next_state() == get_transition_model()*get_current_state() + - #get_current_estimation_error_covariance() == the propagated value of this covariance matrix + !*/ + + void update ( + const matrix& z + ); + /*! + ensures + - propagates the current state estimate forward in time one time step. + Also applies a correction based on the given measurement z. In particular: + - #get_current_state(), #get_predicted_next_state(), and + #get_current_estimation_error_covariance() are updated using the + Kalman filter method based on the new measurement in z. + !*/ + + const matrix& get_current_state( + ) const; + /*! + ensures + - returns the current estimate of the state of the process. This + estimate is based on all the measurements supplied to the update() + method. + !*/ + + const matrix& get_predicted_next_state( + ) const; + /*! + ensures + - returns the next expected value of the process state. + - Specifically, returns get_transition_model()*get_current_state() + + !*/ + + const matrix& get_current_estimation_error_covariance( + ) const; + /*! + ensures + - returns the current state error estimation covariance matrix. + This matrix captures our uncertainty about the value of get_current_state(). + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const kalman_filter& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + kalman_filter& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class momentum_filter + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple tool for filtering a single scalar value that + measures the location of a moving object that has some non-trivial + momentum. Importantly, the measurements are noisy and the object can + experience sudden unpredictable accelerations. To accomplish this + filtering we use a simple Kalman filter with a state transition model of: + + position_{i+1} = position_{i} + velocity_{i} + velocity_{i+1} = velocity_{i} + some_unpredictable_acceleration + + and a measurement model of: + + measured_position_{i} = position_{i} + measurement_noise + + Where some_unpredictable_acceleration and measurement_noise are 0 mean Gaussian + noise sources with standard deviations of get_typical_acceleration() and + get_measurement_noise() respectively. + + To allow for really sudden and large but infrequent accelerations, at each + step we check if the current measured position deviates from the predicted + filtered position by more than get_max_measurement_deviation()*get_measurement_noise() + and if so we adjust the filter's state to keep it within these bounds. + This allows the moving object to undergo large unmodeled accelerations, far + in excess of what would be suggested by get_typical_acceleration(), without + then experiencing a long lag time where the Kalman filter has to "catch + up" to the new position. + !*/ + + public: + + momentum_filter( + ) = default; + /*! + ensures + - #get_measurement_noise() == 2 + - #get_typical_acceleration() == 0.1 + - #get_max_measurement_deviation() == 3 + !*/ + + momentum_filter( + double meas_noise, + double acc, + double max_meas_dev + ); + /*! + requires + - meas_noise >= 0 + - acc >= 0 + - max_meas_dev >= 0 + ensures + - #get_measurement_noise() == meas_noise + - #get_typical_acceleration() == acc + - #get_max_measurement_deviation() == max_meas_dev + !*/ + + + double get_measurement_noise ( + ) const; + /*! + ensures + - Returns the standard deviation of the 0 mean Gaussian noise that corrupts + measurements of the moving object. + !*/ + + double get_typical_acceleration ( + ) const; + /*! + ensures + - We assume that the moving object experiences random accelerations that + are distributed by 0 mean Gaussian noise with get_typical_acceleration() + standard deviation. + !*/ + + double get_max_measurement_deviation ( + ) const; + /*! + ensures + - This object will never let the filtered location of the object deviate + from the measured location by much more than + get_max_measurement_deviation()*get_measurement_noise(). + !*/ + + void reset( + ); + /*! + ensures + - Returns this object to the state immediately after construction. To be precise, we do: + *this = momentum_filter(get_measurement_noise(), get_typical_acceleration(), get_max_measurement_deviation()); + !*/ + + double operator()( + const double measured_position + ); + /*! + ensures + - Updates the Kalman filter with the new measured position of the object + and returns the new filtered estimate of the object's position, now that + we have seen the latest measured position. + - #get_predicted_next_position() == the prediction for the *next* place we + will see the object. That is, where we think it will be in the future + rather than where it is now. + !*/ + + double get_predicted_next_position ( + ) const; + /*! + ensures + - Returns the Kalman filter's estimate of the next position we will see the object. + !*/ + }; + + std::ostream& operator << (std::ostream& out, const momentum_filter& item); + void serialize(const momentum_filter& item, std::ostream& out); + void deserialize(momentum_filter& item, std::istream& in); + /*! + Provide printing and serialization support. + !*/ + +// ---------------------------------------------------------------------------------------- + + momentum_filter find_optimal_momentum_filter ( + const std::vector>& sequences, + const double smoothness = 1 + ); + /*! + requires + - sequences.size() != 0 + - for all valid i: sequences[i].size() > 4 + - smoothness >= 0 + ensures + - This function finds the "optimal" settings of a momentum_filter based on + recorded measurement data stored in sequences. Here we assume that each + vector in sequences is a complete track history of some object's measured + positions. What we do is find the momentum_filter that minimizes the + following objective function: + sum of abs(predicted_location[i] - measured_location[i]) + smoothness*abs(filtered_location[i]-filtered_location[i-1]) + Where i is a time index. + The sum runs over all the data in sequences. So what we do is find the + filter settings that produce smooth filtered trajectories but also produce + filtered outputs that are as close to the measured positions as possible. + The larger the value of smoothness the less jittery the filter outputs will + be, but they might become biased or laggy if smoothness is set really high. + !*/ + +// ---------------------------------------------------------------------------------------- + + momentum_filter find_optimal_momentum_filter ( + const std::vector& sequence, + const double smoothness = 1 + ); + /*! + requires + - sequence.size() > 4 + - smoothness >= 0 + ensures + - performs: find_optimal_momentum_filter({1,sequence}, smoothness); + !*/ + +// ---------------------------------------------------------------------------------------- + + class rect_filter + { + /*! + WHAT THIS OBJECT REPRESENTS + This object simply contains four momentum_filters and applies them to the + 4 components of a dlib::rectangle's position. It therefore allows you to + easily filter a sequence of rectangles. For instance, it can be used to + smooth the output of an object detector running on a video. + !*/ + + public: + rect_filter( + ) = default; + /*! + ensures + - The four momentum_filters in this object are default initialized. + !*/ + + rect_filter( + const momentum_filter& filt + ); + /*! + ensures + - #get_left() == filt + - #get_top() == filt + - #get_right() == filt + - #get_bottom() == filt + !*/ + + rect_filter( + double meas_noise, + double acc, + double max_meas_dev + ) : rect_filter(momentum_filter(meas_noise, acc, max_meas_dev)) {} + /*! + requires + - meas_noise >= 0 + - acc >= 0 + - max_meas_dev >= 0 + ensures + - Initializes this object with momentum_filter(meas_noise, acc, max_meas_dev) + !*/ + + drectangle operator()( + const drectangle& r + ); + /*! + ensures + - Runs the given rectangle through the momentum_filters and returns the + filtered rectangle location. That is, performs: + return drectangle(get_left()(r.left()), + get_top()(r.top()), + get_right()(r.right()), + get_bottom()(r.bottom())); + !*/ + + drectangle operator()( + const rectangle& r + ); + /*! + ensures + - Runs the given rectangle through the momentum_filters and returns the + filtered rectangle location. That is, performs: + return drectangle(get_left()(r.left()), + get_top()(r.top()), + get_right()(r.right()), + get_bottom()(r.bottom())); + !*/ + + const momentum_filter& get_left() const; + momentum_filter& get_left(); + const momentum_filter& get_top() const; + momentum_filter& get_top(); + const momentum_filter& get_right() const; + momentum_filter& get_right(); + const momentum_filter& get_bottom() const; + momentum_filter& get_bottom(); + /*! + Provides access to the 4 momentum_filters used to filter the 4 coordinates that define a rectangle. + !*/ + }; + + void serialize(const rect_filter& item, std::ostream& out); + void deserialize(rect_filter& item, std::istream& in); + /*! + Provide serialization support. + !*/ + +// ---------------------------------------------------------------------------------------- + + rect_filter find_optimal_rect_filter ( + const std::vector& rects, + const double smoothness = 1 + ); + /*! + requires + - rects.size() > 4 + - smoothness >= 0 + ensures + - This routine simply invokes find_optimal_momentum_filter() to find the + momentum_filter that works best on the provided sequence of rectangles. It + then constructs a rect_filter using that momentum_filter and returns it. + Therefore, this routine finds the rect_filter that is "optimal" for filtering + the given sequence of rectangles. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KALMAN_FiLTER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/filtering/rls_filter.h b/ml/dlib/dlib/filtering/rls_filter.h new file mode 100644 index 000000000..4481ab3f4 --- /dev/null +++ b/ml/dlib/dlib/filtering/rls_filter.h @@ -0,0 +1,198 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RLS_FiLTER_Hh_ +#define DLIB_RLS_FiLTER_Hh_ + +#include "rls_filter_abstract.h" +#include "../svm/rls.h" +#include +#include "../matrix.h" +#include "../sliding_buffer.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rls_filter + { + /*! + CONVENTION + - data.size() == the number of variables in a measurement + - data[i].size() == data[j].size() for all i and j. + - data[i].size() == get_window_size() + - data[i][0] == most recent measurement of i-th variable given to update. + - data[i].back() == oldest measurement of i-th variable given to update + (or zero if we haven't seen this much data yet). + + - if (count <= 2) then + - count == number of times update(z) has been called + !*/ + public: + + rls_filter() + { + size = 5; + count = 0; + filter = rls(0.8, 100); + } + + explicit rls_filter ( + unsigned long size_, + double forget_factor = 0.8, + double C = 100 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < forget_factor && forget_factor <= 1 && + 0 < C && size_ >= 2, + "\t rls_filter::rls_filter()" + << "\n\t invalid arguments were given to this function" + << "\n\t forget_factor: " << forget_factor + << "\n\t C: " << C + << "\n\t size_: " << size_ + << "\n\t this: " << this + ); + + size = size_; + count = 0; + filter = rls(forget_factor, C); + } + + double get_c( + ) const + { + return filter.get_c(); + } + + double get_forget_factor( + ) const + { + return filter.get_forget_factor(); + } + + unsigned long get_window_size ( + ) const + { + return size; + } + + void update ( + ) + { + if (filter.get_w().size() == 0) + return; + + for (unsigned long i = 0; i < data.size(); ++i) + { + // Put old predicted value into the circular buffer as if it was + // the measurement we just observed. But don't update the rls filter. + data[i].push_front(next(i)); + } + + // predict next state + for (long i = 0; i < next.size(); ++i) + next(i) = filter(mat(data[i])); + } + + template + void update ( + const matrix_exp& z + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(z) == true && + z.size() != 0 && + (get_predicted_next_state().size()==0 || z.size()==get_predicted_next_state().size()), + "\t void rls_filter::update(z)" + << "\n\t invalid arguments were given to this function" + << "\n\t is_col_vector(z): " << is_col_vector(z) + << "\n\t z.size(): " << z.size() + << "\n\t get_predicted_next_state().size(): " << get_predicted_next_state().size() + << "\n\t this: " << this + ); + + // initialize data if necessary + if (data.size() == 0) + { + data.resize(z.size()); + for (long i = 0; i < z.size(); ++i) + data[i].assign(size, 0); + } + + + for (unsigned long i = 0; i < data.size(); ++i) + { + // Once there is some stuff in the circular buffer, start + // showing it to the rls filter so it can do its thing. + if (count >= 2) + { + filter.train(mat(data[i]), z(i)); + } + + // keep track of the measurements in our circular buffer + data[i].push_front(z(i)); + } + + // Don't bother with the filter until we have seen two samples + if (count >= 2) + { + // predict next state + for (long i = 0; i < z.size(); ++i) + next(i) = filter(mat(data[i])); + } + else + { + // Use current measurement as the next state prediction + // since we don't know any better at this point. + ++count; + next = matrix_cast(z); + } + } + + const matrix& get_predicted_next_state( + ) const + { + return next; + } + + friend inline void serialize(const rls_filter& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.count, out); + serialize(item.size, out); + serialize(item.filter, out); + serialize(item.next, out); + serialize(item.data, out); + } + + friend inline void deserialize(rls_filter& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw dlib::serialization_error("Unknown version number found while deserializing rls_filter object."); + + deserialize(item.count, in); + deserialize(item.size, in); + deserialize(item.filter, in); + deserialize(item.next, in); + deserialize(item.data, in); + } + + private: + + unsigned long count; + unsigned long size; + rls filter; + matrix next; + std::vector > data; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RLS_FiLTER_Hh_ + diff --git a/ml/dlib/dlib/filtering/rls_filter_abstract.h b/ml/dlib/dlib/filtering/rls_filter_abstract.h new file mode 100644 index 000000000..0a932cb87 --- /dev/null +++ b/ml/dlib/dlib/filtering/rls_filter_abstract.h @@ -0,0 +1,171 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RLS_FiLTER_ABSTRACT_Hh_ +#ifdef DLIB_RLS_FiLTER_ABSTRACT_Hh_ + +#include "../svm/rls_abstract.h" +#include "../matrix/matrix_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rls_filter + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for doing time series prediction using linear + recursive least squares. In particular, this object takes a sequence + of points from the user and, at each step, attempts to predict the + value of the next point. + + To accomplish this, this object maintains a fixed size buffer of recent + points. Each prediction is a linear combination of the points in this + history buffer. It uses the recursive least squares algorithm to + determine how to best combine the contents of the history buffer to + predict each point. Therefore, each time update() is called with + a point, recursive least squares updates the linear combination weights, + and then it inserts the point into the history buffer. After that, the + next prediction is based on these updated weights and the current history + buffer. + !*/ + + public: + + rls_filter( + ); + /*! + ensures + - #get_window_size() == 5 + - #get_forget_factor() == 0.8 + - #get_c() == 100 + - #get_predicted_next_state().size() == 0 + !*/ + + explicit rls_filter ( + unsigned long size, + double forget_factor = 0.8, + double C = 100 + ); + /*! + requires + - 0 < forget_factor <= 1 + - 0 < C + - size >= 2 + ensures + - #get_window_size() == size + - #get_forget_factor() == forget_factor + - #get_c() == C + - #get_predicted_next_state().size() == 0 + !*/ + + double get_c( + ) const; + /*! + ensures + - returns the regularization parameter. It is the parameter that determines + the trade-off between trying to fit the data points given to update() or + allowing more errors but hopefully improving the generalization of the + predictions. Larger values encourage exact fitting while smaller values + of C may encourage better generalization. + !*/ + + double get_forget_factor( + ) const; + /*! + ensures + - This object uses exponential forgetting in its implementation of recursive + least squares. Therefore, this function returns the "forget factor". + - if (get_forget_factor() == 1) then + - In this case, exponential forgetting is disabled. + - The recursive least squares algorithm will implicitly take all previous + calls to update(z) into account when estimating the optimal weights for + linearly combining the history buffer into a prediction of the next point. + - else + - Old calls to update(z) are eventually forgotten. That is, the smaller + the forget factor, the less recursive least squares will care about + attempting to find linear combination weights which would have make + good predictions on old points. It will care more about fitting recent + points. This is appropriate if the statistical properties of the time + series we are modeling are not constant. + !*/ + + unsigned long get_window_size ( + ) const; + /*! + ensures + - returns the size of the history buffer. This is the number of points which + are linearly combine to make the predictions returned by get_predicted_next_state(). + !*/ + + void update ( + ); + /*! + ensures + - Propagates the prediction forward in time. + - In particular, the value in get_predicted_next_state() is inserted + into the history buffer and then the next prediction is estimated + based on this updated history buffer. + - #get_predicted_next_state() == the prediction for the next point + in the time series. + !*/ + + template + void update ( + const matrix_exp& z + ); + /*! + requires + - is_col_vector(z) == true + - z.size() != 0 + - if (get_predicted_next_state().size() != 0) then + - z.size() == get_predicted_next_state().size() + (i.e. z must be the same size as all the previous z values given + to this function) + ensures + - Updates the state of this filter based on the current measurement in z. + - In particular, the filter weights are updated and z is inserted into + the history buffer. Then the next prediction is estimated based on + these updated weights and history buffer. + - #get_predicted_next_state() == the prediction for the next point + in the time series. + - #get_predicted_next_state().size() == z.size() + !*/ + + const matrix& get_predicted_next_state( + ) const; + /*! + ensures + - returns the estimate of the next point we will observe in the + time series data. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const rls_filter& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + rls_filter& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + +} + +#endif // DLIB_RLS_FiLTER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/float_details.h b/ml/dlib/dlib/float_details.h new file mode 100644 index 000000000..3dc7eae49 --- /dev/null +++ b/ml/dlib/dlib/float_details.h @@ -0,0 +1,161 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FLOAT_DEtAILS_Hh_ +#define DLIB_FLOAT_DEtAILS_Hh_ + +#include +#include "algs.h" +#include + +namespace dlib +{ + struct float_details + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for converting floating point numbers into an + explicit integer representation and then also converting back. In + particular, a float_details object represents a floating point number with + a 64 bit mantissa and 16 bit exponent. These are stored in the public + fields of the same names. + + The main use of this object is to convert floating point values into a + known uniform representation so they can be serialized to an output stream. + This allows dlib serialization code to work on any system, regardless of + the floating point representation used by the hardware. It also means + that, for example, a double can be serialized and then deserialized into a + float and it will perform the appropriate conversion. + + + In more detail, this object represents a floating point value equal to + mantissa*pow(2,exponent), except when exponent takes on any of the + following special values: + - is_inf + - is_ninf + - is_nan + These values are used to indicate that the floating point value should be + either infinity, negative infinity, or not-a-number respectively. + !*/ + + float_details( + int64 man, + int16 exp + ) : mantissa(man), exponent(exp) {} + /*! + ensures + - #mantissa == man + - #exponent == exp + !*/ + + float_details() : + mantissa(0), exponent(0) + {} + /*! + ensures + - this object represents a floating point value of 0 + !*/ + + float_details ( const double& val) { *this = val; } + float_details ( const float& val) { *this = val; } + float_details ( const long double& val) { *this = val; } + /*! + ensures + - converts the given value into a float_details representation. This + means that converting #*this back into a floating point number should + recover the input val. + !*/ + + float_details& operator= ( const double& val) { convert_from_T(val); return *this; } + float_details& operator= ( const float& val) { convert_from_T(val); return *this; } + float_details& operator= ( const long double& val) { convert_from_T(val); return *this; } + /*! + ensures + - converts the given value into a float_details representation. This + means that converting #*this back into a floating point number should + recover the input val. + !*/ + + operator double () const { return convert_to_T(); } + operator float () const { return convert_to_T(); } + operator long double () const { return convert_to_T(); } + /*! + ensures + - converts the contents of this float_details object into a floating point number. + !*/ + + const static int16 is_inf = 32000; + const static int16 is_ninf = 32001; + const static int16 is_nan = 32002; + + int64 mantissa; + int16 exponent; + + + private: + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION DETAILS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + void convert_from_T ( + const T& val + ) + { + mantissa = 0; + + const int digits = dlib::tmin::digits, 63>::value; + + if (val == std::numeric_limits::infinity()) + { + exponent = is_inf; + } + else if (val == -std::numeric_limits::infinity()) + { + exponent = is_ninf; + } + else if (val < std::numeric_limits::infinity()) + { + int exp; + mantissa = static_cast(std::frexp(val, &exp)*(((uint64)1)<>= 8; + exponent += 8; + } + } + else + { + exponent = is_nan; + } + } + + template + T convert_to_T ( + ) const + { + if (exponent < is_inf) + return std::ldexp((T)mantissa, exponent); + else if (exponent == is_inf) + return std::numeric_limits::infinity(); + else if (exponent == is_ninf) + return -std::numeric_limits::infinity(); + else + return std::numeric_limits::quiet_NaN(); + } + + }; + +} + +#endif // DLIB_FLOAT_DEtAILS_Hh_ + diff --git a/ml/dlib/dlib/fstream b/ml/dlib/dlib/fstream new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/fstream @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/general_hash/count_bits.h b/ml/dlib/dlib/general_hash/count_bits.h new file mode 100644 index 000000000..01acad2ab --- /dev/null +++ b/ml/dlib/dlib/general_hash/count_bits.h @@ -0,0 +1,82 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COUNT_BiTS_Hh_ +#define DLIB_COUNT_BiTS_Hh_ + +#include "../algs.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T count_bits ( + T v + ) + /*! + requires + - T is an unsigned integral type + ensures + - returns the number of bits in v which are set to 1. + !*/ + { + COMPILE_TIME_ASSERT(is_unsigned_type::value && sizeof(T) <= 8); + + // This bit of bit trickery is from: + // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSet64 + + v = v - ((v >> 1) & (T)~(T)0/3); + v = (v & (T)~(T)0/15*3) + ((v >> 2) & (T)~(T)0/15*3); + v = (v + (v >> 4)) & (T)~(T)0/255*15; + return (T)(v * ((T)~(T)0/255)) >> (sizeof(T) - 1) * CHAR_BIT; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T hamming_distance ( + const T& a, + const T& b + ) + /*! + requires + - T is an unsigned integral type + ensures + - returns the number of bits which differ between a and b. + !*/ + { + return count_bits(a^b); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T hamming_distance ( + const std::pair& a, + const std::pair& b + ) + /*! + requires + - T is an unsigned integral type or a std::pair that, recursively, eventually + contains unsigned integral types. + ensures + - returns the number of bits which differ between a and b. + !*/ + { + return hamming_distance(a.first,b.first) + hamming_distance(a.second, b.second); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COUNT_BiTS_Hh_ + diff --git a/ml/dlib/dlib/general_hash/count_bits_abstract.h b/ml/dlib/dlib/general_hash/count_bits_abstract.h new file mode 100644 index 000000000..ff5d4482c --- /dev/null +++ b/ml/dlib/dlib/general_hash/count_bits_abstract.h @@ -0,0 +1,48 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_COUNT_BiTS_ABSTRACT_Hh_ +#ifdef DLIB_COUNT_BiTS_ABSTRACT_Hh_ + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T count_bits ( + T v + ); + /*! + requires + - T is an unsigned integral type + ensures + - returns the number of bits in v which are set to 1. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T hamming_distance ( + const T& a, + const T& b + ); + /*! + requires + - T is an unsigned integral type + ensures + - returns the number of bits which differ between a and b. (I.e. returns + count_bits(a^b).) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COUNT_BiTS_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/general_hash/general_hash.h b/ml/dlib/dlib/general_hash/general_hash.h new file mode 100644 index 000000000..3de0b2698 --- /dev/null +++ b/ml/dlib/dlib/general_hash/general_hash.h @@ -0,0 +1,80 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GENERAL_HASh_ +#define DLIB_GENERAL_HASh_ + + +#include +#include "hash.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ----------------------- provide a general hashing function object ---------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class general_hash + { + public: + inline unsigned long operator() ( + const T& item + ) const; + }; + /*! + Note that the default behavior of general hash is to attempt to cast + an object of type T to an unsigned long and use that as the hash. + + REQUIREMENTS ON general_hash + - must have a default constructor + - must be a function object which overloads operator() as follows: + unsigned long operator()(const T& item) + - must take item, compute a hash number and return it + - must not throw + - must define the hash in such a way that all equivalent objects have + the same hash. where equivalent means the following: + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + +// --------------- + + template < + typename T + > + unsigned long general_hash:: + operator() ( + const T& item + ) const + { + // hash any types that have a conversion to uint64 + return hash(static_cast(item)); + } + + +// --------------- + + // std::string hash + template <> + inline unsigned long general_hash:: + operator() ( + const std::string& item + ) const + { + return hash(item); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GENERAL_HASh_ + diff --git a/ml/dlib/dlib/general_hash/hash.h b/ml/dlib/dlib/general_hash/hash.h new file mode 100644 index 000000000..6edb99b99 --- /dev/null +++ b/ml/dlib/dlib/general_hash/hash.h @@ -0,0 +1,142 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HAsH_Hh_ +#define DLIB_HAsH_Hh_ + +#include "hash_abstract.h" +#include +#include +#include +#include "murmur_hash3.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline uint32 hash ( + const std::string& item, + uint32 seed = 0 + ) + { + if (item.size() == 0) + return 0; + else + return murmur_hash3(&item[0], sizeof(item[0])*item.size(), seed); + } + +// ---------------------------------------------------------------------------------------- + + inline uint32 hash ( + const std::wstring& item, + uint32 seed = 0 + ) + { + if (item.size() == 0) + return 0; + else + return murmur_hash3(&item[0], sizeof(item[0])*item.size(), seed); + } + +// ---------------------------------------------------------------------------------------- + + template + uint32 hash ( + const std::vector& item, + uint32 seed = 0 + ) + { + DLIB_ASSERT_HAS_STANDARD_LAYOUT(T); + + if (item.size() == 0) + return 0; + else + return murmur_hash3(&item[0], sizeof(T)*item.size(), seed); + } + +// ---------------------------------------------------------------------------------------- + + template + uint32 hash ( + const std::vector,alloc>& item, + uint32 seed = 0 + ) + { + DLIB_ASSERT_HAS_STANDARD_LAYOUT(T); + DLIB_ASSERT_HAS_STANDARD_LAYOUT(U); + + if (item.size() == 0) + return 0; + else + return murmur_hash3(&item[0], sizeof(item[0])*item.size(), seed); + } + +// ---------------------------------------------------------------------------------------- + + template + uint32 hash ( + const std::map& item, + uint32 seed = 0 + ) + { + return hash(std::vector >(item.begin(), item.end()), seed); + } + +// ---------------------------------------------------------------------------------------- + + inline uint32 hash ( + uint32 val, + uint32 seed = 0 + ) + { + return murmur_hash3_2(val,seed); + } + +// ---------------------------------------------------------------------------------------- + + inline uint32 hash ( + uint64 val, + uint32 seed = 0 + ) + { + return static_cast(murmur_hash3_128bit_3(val,seed,0).first); + } + +// ---------------------------------------------------------------------------------------- + + inline uint32 hash ( + const std::pair& item, + uint32 seed = 0 + ) + { + return static_cast(murmur_hash3_128bit_3(item.first,item.second,seed).first); + } + +// ---------------------------------------------------------------------------------------- + + inline uint32 hash ( + const std::pair& item, + uint32 seed = 0 + ) + { + return murmur_hash3_3(item.first,item.second,seed); + } + +// ---------------------------------------------------------------------------------------- + + template + uint32 hash ( + const std::pair& item, + uint32 seed = 0 + ) + { + return hash(item.first, seed) ^ hash(item.second, seed+1); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HAsH_Hh_ + diff --git a/ml/dlib/dlib/general_hash/hash_abstract.h b/ml/dlib/dlib/general_hash/hash_abstract.h new file mode 100644 index 000000000..8959bbe44 --- /dev/null +++ b/ml/dlib/dlib/general_hash/hash_abstract.h @@ -0,0 +1,182 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_HAsH_ABSTRACT_Hh_ +#ifdef DLIB_HAsH_ABSTRACT_Hh_ + +#include "murmur_hash3_abstract.h" +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + uint32 hash ( + const std::string& item, + uint32 seed = 0 + ); + /*! + ensures + - returns a 32bit hash of the data stored in item. + - Each value of seed results in a different hash function being used. + (e.g. hash(item,0) should generally not be equal to hash(item,1)) + - uses the murmur_hash3() routine to compute the actual hash. + - This routine will always give the same hash value when presented + with the same input string. + !*/ + +// ---------------------------------------------------------------------------------------- + + uint32 hash ( + const std::wstring& item, + uint32 seed = 0 + ); + /*! + ensures + - returns a 32bit hash of the data stored in item. + - Each value of seed results in a different hash function being used. + (e.g. hash(item,0) should generally not be equal to hash(item,1)) + - uses the murmur_hash3() routine to compute the actual hash. + - Note that if the memory layout of the elements in item change between + hardware platforms then hash() will give different outputs. If you want + hash() to always give the same output for the same input then you must + ensure that elements of item always have the same layout in memory. + Typically this means using fixed width types and performing byte swapping + to account for endianness before passing item to hash(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template + uint32 hash ( + const std::vector& item, + uint32 seed = 0 + ); + /*! + requires + - T is a standard layout type (e.g. a POD type like int, float, + or a simple struct). + ensures + - returns a 32bit hash of the data stored in item. + - Each value of seed results in a different hash function being used. + (e.g. hash(item,0) should generally not be equal to hash(item,1)) + - uses the murmur_hash3() routine to compute the actual hash. + - Note that if the memory layout of the elements in item change between + hardware platforms then hash() will give different outputs. If you want + hash() to always give the same output for the same input then you must + ensure that elements of item always have the same layout in memory. + Typically this means using fixed width types and performing byte swapping + to account for endianness before passing item to hash(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template + uint32 hash ( + const std::vector,alloc>& item, + uint32 seed = 0 + ); + /*! + requires + - T and U are standard layout types (e.g. POD types like int, float, + or simple structs). + ensures + - returns a 32bit hash of the data stored in item. + - Each value of seed results in a different hash function being used. + (e.g. hash(item,0) should generally not be equal to hash(item,1)) + - uses the murmur_hash3() routine to compute the actual hash. + - Note that if the memory layout of the elements in item change between + hardware platforms then hash() will give different outputs. If you want + hash() to always give the same output for the same input then you must + ensure that elements of item always have the same layout in memory. + Typically this means using fixed width types and performing byte swapping + to account for endianness before passing item to hash(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template + uint32 hash ( + const std::map& item, + uint32 seed = 0 + ); + /*! + requires + - T and U are standard layout types (e.g. POD types like int, float, + or simple structs). + ensures + - returns a 32bit hash of the data stored in item. + - Each value of seed results in a different hash function being used. + (e.g. hash(item,0) should generally not be equal to hash(item,1)) + - uses the murmur_hash3() routine to compute the actual hash. + - Note that if the memory layout of the elements in item change between + hardware platforms then hash() will give different outputs. If you want + hash() to always give the same output for the same input then you must + ensure that elements of item always have the same layout in memory. + Typically this means using fixed width types and performing byte swapping + to account for endianness before passing item to hash(). However, since + you can't modify the keys in a map you may have to copy it into a + std::vector and then work from there. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline uint32 hash ( + uint32 item, + uint32 seed = 0 + ); + /*! + ensures + - returns a 32bit hash of the data stored in item. + - Each value of seed results in a different hash function being used. + (e.g. hash(item,0) should generally not be equal to hash(item,1)) + - uses the murmur_hash3_2() routine to compute the actual hash. + - This routine will always give the same hash value when presented + with the same input. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline uint32 hash ( + uint64 item, + uint32 seed = 0 + ); + /*! + ensures + - returns a 32bit hash of the data stored in item. + - Each value of seed results in a different hash function being used. + (e.g. hash(item,0) should generally not be equal to hash(item,1)) + - uses the murmur_hash3_128bit_3() routine to compute the actual hash. + - This routine will always give the same hash value when presented + with the same input. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + uint32 hash ( + const std::pair& item, + uint32 seed = 0 + ); + /*! + requires + - hash() is defined for objects of type T and U + ensures + - returns a 32bit hash of the data stored in item. + - Each value of seed results in a different hash function being used. + (e.g. hash(item,0) should generally not be equal to hash(item,1)) + - if (calling hash() on items of type T and U is always guaranteed to give the + same hash values when presented with the same input) then + - This routine will always give the same hash value when presented + with the same input. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HAsH_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/general_hash/murmur_hash3.h b/ml/dlib/dlib/general_hash/murmur_hash3.h new file mode 100644 index 000000000..b8e37260e --- /dev/null +++ b/ml/dlib/dlib/general_hash/murmur_hash3.h @@ -0,0 +1,519 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MURMUR_HAsH_3_Hh_ +#define DLIB_MURMUR_HAsH_3_Hh_ + +#include "murmur_hash3_abstract.h" +#include "../uintn.h" +#include +#include + +namespace dlib +{ + //----------------------------------------------------------------------------- + // The original MurmurHash3 code was written by Austin Appleby, and is placed + // in the public domain. The author hereby disclaims copyright to this source code. + // The code in this particular file was modified by Davis E. King. In + // particular, endian-swapping was added along with some other minor code + // changes like avoiding strict aliasing violations. + + + //----------------------------------------------------------------------------- + // Platform-specific functions and macros + + // Microsoft Visual Studio + +#if defined(_MSC_VER) + +#define DLIB_FORCE_INLINE __forceinline + +#include + +#define DLIB_ROTL32(x,y) _rotl(x,y) +#define DLIB_ROTL64(x,y) _rotl64(x,y) + +#define DLIB_BIG_CONSTANT(x) (x) + + // Other compilers + +#else // defined(_MSC_VER) + +#define DLIB_FORCE_INLINE __attribute__((always_inline)) inline + + inline uint32 murmur_rotl32 ( uint32 x, int8 r ) + { + return (x << r) | (x >> (32 - r)); + } + + inline uint64 murmur_rotl64 ( uint64 x, int8 r ) + { + return (x << r) | (x >> (64 - r)); + } + +#define DLIB_ROTL32(x,y) dlib::murmur_rotl32(x,y) +#define DLIB_ROTL64(x,y) dlib::murmur_rotl64(x,y) + +#define DLIB_BIG_CONSTANT(x) (x##LLU) + +#endif // !defined(_MSC_VER) + +// ---------------------------------------------------------------------------------------- + // Block read - if your platform needs to do endian-swapping or can only + // handle aligned reads, do the conversion here + + DLIB_FORCE_INLINE uint32 murmur_getblock ( const uint32 * p, int i ) + { + // The reason we do a memcpy() here instead of simply returning p[i] is because + // doing it this way avoids violations of the strict aliasing rule when all these + // functions are inlined into the user's code. + uint32 temp; + memcpy(&temp, p+i, 4); + return temp; + } + + DLIB_FORCE_INLINE uint32 murmur_getblock_byte_swap ( const uint32 * p, int i ) + { + union + { + uint8 bytes[4]; + uint32 val; + } temp; + + const uint8* pp = reinterpret_cast(p + i); + temp.bytes[0] = pp[3]; + temp.bytes[1] = pp[2]; + temp.bytes[2] = pp[1]; + temp.bytes[3] = pp[0]; + + return temp.val; + } + + DLIB_FORCE_INLINE uint64 murmur_getblock ( const uint64 * p, int i ) + { + // The reason we do a memcpy() here instead of simply returning p[i] is because + // doing it this way avoids violations of the strict aliasing rule when all these + // functions are inlined into the user's code. + uint64 temp; + memcpy(&temp, p+i, 8); + return temp; + } + + DLIB_FORCE_INLINE uint64 murmur_getblock_byte_swap ( const uint64 * p, int i ) + { + union + { + uint8 bytes[8]; + uint64 val; + } temp; + + const uint8* pp = reinterpret_cast(p + i); + temp.bytes[0] = pp[7]; + temp.bytes[1] = pp[6]; + temp.bytes[2] = pp[5]; + temp.bytes[3] = pp[4]; + temp.bytes[4] = pp[3]; + temp.bytes[5] = pp[2]; + temp.bytes[6] = pp[1]; + temp.bytes[7] = pp[0]; + + return temp.val; + } + +// ---------------------------------------------------------------------------------------- + // Finalization mix - force all bits of a hash block to avalanche + + DLIB_FORCE_INLINE uint32 murmur_fmix ( uint32 h ) + { + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; + } + +// ---------------------------------------------------------------------------------------- + + DLIB_FORCE_INLINE uint64 murmur_fmix ( uint64 k ) + { + k ^= k >> 33; + k *= DLIB_BIG_CONSTANT(0xff51afd7ed558ccd); + k ^= k >> 33; + k *= DLIB_BIG_CONSTANT(0xc4ceb9fe1a85ec53); + k ^= k >> 33; + + return k; + } + +// ---------------------------------------------------------------------------------------- + + inline uint32 murmur_hash3 ( + const void * key, + const int len, + const uint32 seed = 0 + ) + { + const uint8 * data = (const uint8*)key; + const int nblocks = len / 4; + + uint32 h1 = seed; + + uint32 c1 = 0xcc9e2d51; + uint32 c2 = 0x1b873593; + + //---------- + // body + + const uint32 * blocks = (const uint32 *)(data + nblocks*4); + + bool is_little_endian = true; + uint32 endian_test = 1; + if (*reinterpret_cast(&endian_test) != 1) + is_little_endian = false; + + + if (is_little_endian) + { + for(int i = -nblocks; i; i++) + { + uint32 k1 = murmur_getblock(blocks,i); + + k1 *= c1; + k1 = DLIB_ROTL32(k1,15); + k1 *= c2; + + h1 ^= k1; + h1 = DLIB_ROTL32(h1,13); + h1 = h1*5+0xe6546b64; + } + } + else + { + for(int i = -nblocks; i; i++) + { + uint32 k1 = murmur_getblock_byte_swap(blocks,i); + + k1 *= c1; + k1 = DLIB_ROTL32(k1,15); + k1 *= c2; + + h1 ^= k1; + h1 = DLIB_ROTL32(h1,13); + h1 = h1*5+0xe6546b64; + } + } + + //---------- + // tail + + const uint8 * tail = (const uint8*)(data + nblocks*4); + + uint32 k1 = 0; + + switch(len & 3) + { + case 3: k1 ^= tail[2] << 16; + // fall through + case 2: k1 ^= tail[1] << 8; + // fall through + case 1: k1 ^= tail[0]; + k1 *= c1; k1 = DLIB_ROTL32(k1,15); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; + + h1 = murmur_fmix(h1); + + return h1; + } + +// ---------------------------------------------------------------------------------------- + + inline uint32 murmur_hash3_2 ( + const uint32 v1, + const uint32 v2 + ) + { + uint32 h1 = v2; + + uint32 c1 = 0xcc9e2d51; + uint32 c2 = 0x1b873593; + + //---------- + // body + + + uint32 k1 = v1; + + k1 *= c1; + k1 = DLIB_ROTL32(k1,15); + k1 *= c2; + + h1 ^= k1; + h1 = DLIB_ROTL32(h1,13); + h1 = h1*5+0xe6546b64; + + + //---------- + // finalization + + h1 ^= 4; // =^ by length in bytes + + h1 = murmur_fmix(h1); + + return h1; + } + +// ---------------------------------------------------------------------------------------- + + inline uint32 murmur_hash3_3 ( + const uint32 v1, + const uint32 v2, + const uint32 v3 + ) + { + + uint32 h1 = v3; + + uint32 c1 = 0xcc9e2d51; + uint32 c2 = 0x1b873593; + + //---------- + // body + + + uint32 k1 = v1; + + k1 *= c1; + k1 = DLIB_ROTL32(k1,15); + k1 *= c2; + + h1 ^= k1; + h1 = DLIB_ROTL32(h1,13); + h1 = h1*5+0xe6546b64; + + k1 = v2; + k1 *= c1; + k1 = DLIB_ROTL32(k1,15); + k1 *= c2; + + h1 ^= k1; + h1 = DLIB_ROTL32(h1,13); + h1 = h1*5+0xe6546b64; + + //---------- + // finalization + + h1 ^= 8; // =^ by length in bytes + + h1 = murmur_fmix(h1); + + return h1; + } + +// ---------------------------------------------------------------------------------------- + + inline std::pair murmur_hash3_128bit ( + const void* key, + const int len, + const uint32 seed = 0 + ) + { + const uint8 * data = (const uint8*)key; + const int nblocks = len / 16; + + uint64 h1 = seed; + uint64 h2 = seed; + + uint64 c1 = DLIB_BIG_CONSTANT(0x87c37b91114253d5); + uint64 c2 = DLIB_BIG_CONSTANT(0x4cf5ad432745937f); + + //---------- + // body + + const uint64 * blocks = (const uint64 *)(data); + + bool is_little_endian = true; + uint32 endian_test = 1; + if (*reinterpret_cast(&endian_test) != 1) + is_little_endian = false; + + + if (is_little_endian) + { + for(int i = 0; i < nblocks; i++) + { + uint64 k1 = murmur_getblock(blocks,i*2+0); + uint64 k2 = murmur_getblock(blocks,i*2+1); + + k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2; h1 ^= k1; + + h1 = DLIB_ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; + + k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + h2 = DLIB_ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; + } + } + else + { + for(int i = 0; i < nblocks; i++) + { + uint64 k1 = murmur_getblock_byte_swap(blocks,i*2+0); + uint64 k2 = murmur_getblock_byte_swap(blocks,i*2+1); + + k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2; h1 ^= k1; + + h1 = DLIB_ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; + + k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + h2 = DLIB_ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; + } + } + + //---------- + // tail + + const uint8 * tail = (const uint8*)(data + nblocks*16); + + uint64 k1 = 0; + uint64 k2 = 0; + + switch(len & 15) + { + case 15: k2 ^= uint64(tail[14]) << 48; // fall through + case 14: k2 ^= uint64(tail[13]) << 40; // fall through + case 13: k2 ^= uint64(tail[12]) << 32; // fall through + case 12: k2 ^= uint64(tail[11]) << 24; // fall through + case 11: k2 ^= uint64(tail[10]) << 16; // fall through + case 10: k2 ^= uint64(tail[ 9]) << 8; // fall through + case 9: k2 ^= uint64(tail[ 8]) << 0; + k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1; h2 ^= k2; // fall through + + case 8: k1 ^= uint64(tail[ 7]) << 56; // fall through + case 7: k1 ^= uint64(tail[ 6]) << 48; // fall through + case 6: k1 ^= uint64(tail[ 5]) << 40; // fall through + case 5: k1 ^= uint64(tail[ 4]) << 32; // fall through + case 4: k1 ^= uint64(tail[ 3]) << 24; // fall through + case 3: k1 ^= uint64(tail[ 2]) << 16; // fall through + case 2: k1 ^= uint64(tail[ 1]) << 8; // fall through + case 1: k1 ^= uint64(tail[ 0]) << 0; + k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2; h1 ^= k1; // fall through + }; + + //---------- + // finalization + + h1 ^= len; h2 ^= len; + + h1 += h2; + h2 += h1; + + h1 = murmur_fmix(h1); + h2 = murmur_fmix(h2); + + h1 += h2; + h2 += h1; + + return std::make_pair(h1,h2); + } + +// ---------------------------------------------------------------------------------------- + + inline std::pair murmur_hash3_128bit ( + const uint32& v1, + const uint32& v2, + const uint32& v3, + const uint32& v4 + ) + { + uint64 h1 = 0; + uint64 h2 = 0; + + const uint64 c1 = DLIB_BIG_CONSTANT(0x87c37b91114253d5); + const uint64 c2 = DLIB_BIG_CONSTANT(0x4cf5ad432745937f); + + //---------- + // body + + uint64 k1 = (static_cast(v2)<<32)|v1; + uint64 k2 = (static_cast(v4)<<32)|v3; + + k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2; + + h1 = DLIB_ROTL64(k1,27); h1 = h1*5+0x52dce729; + + k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1; + + h2 = DLIB_ROTL64(k2,31); h2 += h1; h2 = h2*5+0x38495ab5; + + //---------- + // finalization + + h1 ^= 16; h2 ^= 16; + + h1 += h2; + h2 += h1; + + h1 = murmur_fmix(h1); + h2 = murmur_fmix(h2); + + h1 += h2; + h2 += h1; + + return std::make_pair(h1,h2); + } + +// ---------------------------------------------------------------------------------------- + + inline std::pair murmur_hash3_128bit_3 ( + uint64 k1, + uint64 k2, + uint64 k3 + ) + { + uint64 h1 = k3; + uint64 h2 = k3; + + const uint64 c1 = DLIB_BIG_CONSTANT(0x87c37b91114253d5); + const uint64 c2 = DLIB_BIG_CONSTANT(0x4cf5ad432745937f); + + //---------- + // body + + k1 *= c1; k1 = DLIB_ROTL64(k1,31); k1 *= c2; h1 ^= k1; + + h1 = DLIB_ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; + + k2 *= c2; k2 = DLIB_ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + h2 = DLIB_ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; + + //---------- + // finalization + + h1 ^= 16; h2 ^= 16; + + h1 += h2; + h2 += h1; + + h1 = murmur_fmix(h1); + h2 = murmur_fmix(h2); + + h1 += h2; + h2 += h1; + + return std::make_pair(h1,h2); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MURMUR_HAsH_3_Hh_ + diff --git a/ml/dlib/dlib/general_hash/murmur_hash3_abstract.h b/ml/dlib/dlib/general_hash/murmur_hash3_abstract.h new file mode 100644 index 000000000..ef7bc0b41 --- /dev/null +++ b/ml/dlib/dlib/general_hash/murmur_hash3_abstract.h @@ -0,0 +1,125 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MURMUR_HAsH_3_ABSTRACT_Hh_ +#ifdef DLIB_MURMUR_HAsH_3_ABSTRACT_Hh_ + +#include "../uintn.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + uint32 murmur_hash3 ( + const void* key, + const int len, + const uint32 seed = 0 + ); + /*! + requires + - key == a pointer to a block of memory len bytes long + ensures + - returns a 32bit hash of the len bytes pointed to by key. + - Each value of seed results in a different hash function being used. + (e.g. murmur_hash3(key,len,0) should generally not be equal to murmur_hash3(key,len,1)) + - This function is machine architecture agnostic and should always give the same + hash value when presented with the same inputs. + - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x86_32. + See: http://code.google.com/p/smhasher/ + !*/ + +// ---------------------------------------------------------------------------------------- + + inline uint32 murmur_hash3_2 ( + const uint32 v1, + const uint32 v2 + ); + /*! + ensures + - returns a 32bit hash of the two integers given to this function. + - This function is machine architecture agnostic and should always give the same + hash value when presented with the same inputs. + - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x86_32. + See: http://code.google.com/p/smhasher/ + !*/ + +// ---------------------------------------------------------------------------------------- + + inline uint32 murmur_hash3_3 ( + const uint32 v1, + const uint32 v2, + const uint32 v3 + ); + /*! + ensures + - returns a 32bit hash of the three integers given to this function. + - This function is machine architecture agnostic and should always give the same + hash value when presented with the same inputs. + - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x86_32. + See: http://code.google.com/p/smhasher/ + !*/ + +// ---------------------------------------------------------------------------------------- + + std::pair murmur_hash3_128bit ( + const void* key, + const int len, + const uint32 seed = 0 + ); + /*! + requires + - key == a pointer to a block of memory len bytes long + ensures + - returns a 128bit hash (as two 64bit numbers) of the len bytes pointed to by key. + - Each value of seed results in a different hash function being used. + (e.g. murmur_hash3_128bit(key,len,0) should generally not be equal to + murmur_hash3_128bit(key,len,1)) + - This function is machine architecture agnostic and should always give the same + hash value when presented with the same inputs. + - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x64_128. + See: http://code.google.com/p/smhasher/ + !*/ + +// ---------------------------------------------------------------------------------------- + + std::pair murmur_hash3_128bit ( + const uint32& v1, + const uint32& v2, + const uint32& v3, + const uint32& v4 + ); + /*! + ensures + - returns a 128bit hash (as two 64bit numbers) of the 4 integers given to this + function. + - This function is machine architecture agnostic and should always give the + same hash value when presented with the same inputs. + - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x64_128. + See: http://code.google.com/p/smhasher/ + !*/ + +// ---------------------------------------------------------------------------------------- + + std::pair murmur_hash3_128bit_3 ( + uint64 k1, + uint64 k2, + uint64 k3 + ); + /*! + ensures + - returns a 128bit hash (as two 64bit numbers) of the 3 integers given to this + function. + - This function is machine architecture agnostic and should always give the + same hash value when presented with the same inputs. + - This hashing algorithm is Austin Appleby's excellent MurmurHash3_x64_128. + See: http://code.google.com/p/smhasher/ + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MURMUR_HAsH_3_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/general_hash/random_hashing.h b/ml/dlib/dlib/general_hash/random_hashing.h new file mode 100644 index 000000000..7a06a6878 --- /dev/null +++ b/ml/dlib/dlib/general_hash/random_hashing.h @@ -0,0 +1,877 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RANDOM_HAsHING_Hh_ +#define DLIB_RANDOM_HAsHING_Hh_ + +#include "random_hashing_abstract.h" +#include "murmur_hash3.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline double uniform_random_hash ( + const uint64& k1, + const uint64& k2, + const uint64& k3 + ) + { + const std::pair h = murmur_hash3_128bit_3(k1,k2,k3); + const uint64 mask = DLIB_BIG_CONSTANT(0xFFFFFFFFFF); + const double max = mask+1; + return static_cast(h.first&mask)/max; + } + +// ---------------------------------------------------------------------------------------- + + inline double gaussian_random_hash ( + const uint64& k1, + const uint64& k2, + const uint64& k3 + ) + { + const std::pair h = murmur_hash3_128bit_3(k1,k2,k3); + + const static unsigned int max = 4096; + + const static double logvals[max] = { + 4.079, 3.905, 3.8, 3.723, 3.663, 3.613, 3.57, 3.532, 3.499, 3.468, + 3.441, 3.416, 3.392, 3.37, 3.35, 3.33, 3.312, 3.295, 3.278, 3.263, + 3.248, 3.233, 3.219, 3.206, 3.193, 3.181, 3.169, 3.158, 3.147, 3.136, + 3.125, 3.115, 3.105, 3.096, 3.086, 3.077, 3.068, 3.059, 3.051, 3.043, + 3.035, 3.027, 3.019, 3.011, 3.004, 2.996, 2.989, 2.982, 2.975, 2.968, + 2.962, 2.955, 2.949, 2.942, 2.936, 2.93, 2.924, 2.918, 2.912, 2.906, + 2.901, 2.895, 2.89, 2.884, 2.879, 2.873, 2.868, 2.863, 2.858, 2.853, + 2.848, 2.843, 2.838, 2.833, 2.829, 2.824, 2.819, 2.815, 2.81, 2.806, + 2.801, 2.797, 2.792, 2.788, 2.784, 2.78, 2.776, 2.771, 2.767, 2.763, + 2.759, 2.755, 2.751, 2.748, 2.744, 2.74, 2.736, 2.732, 2.729, 2.725, + 2.721, 2.718, 2.714, 2.71, 2.707, 2.703, 2.7, 2.697, 2.693, 2.69, + 2.686, 2.683, 2.68, 2.676, 2.673, 2.67, 2.667, 2.663, 2.66, 2.657, + 2.654, 2.651, 2.648, 2.645, 2.642, 2.639, 2.636, 2.633, 2.63, 2.627, + 2.624, 2.621, 2.618, 2.615, 2.612, 2.61, 2.607, 2.604, 2.601, 2.599, + 2.596, 2.593, 2.59, 2.588, 2.585, 2.582, 2.58, 2.577, 2.574, 2.572, + 2.569, 2.567, 2.564, 2.562, 2.559, 2.557, 2.554, 2.552, 2.549, 2.547, + 2.544, 2.542, 2.539, 2.537, 2.534, 2.532, 2.53, 2.527, 2.525, 2.523, + 2.52, 2.518, 2.516, 2.513, 2.511, 2.509, 2.507, 2.504, 2.502, 2.5, + 2.498, 2.495, 2.493, 2.491, 2.489, 2.487, 2.485, 2.482, 2.48, 2.478, + 2.476, 2.474, 2.472, 2.47, 2.468, 2.466, 2.464, 2.462, 2.459, 2.457, + 2.455, 2.453, 2.451, 2.449, 2.447, 2.445, 2.443, 2.441, 2.439, 2.437, + 2.436, 2.434, 2.432, 2.43, 2.428, 2.426, 2.424, 2.422, 2.42, 2.418, + 2.416, 2.415, 2.413, 2.411, 2.409, 2.407, 2.405, 2.404, 2.402, 2.4, + 2.398, 2.396, 2.394, 2.393, 2.391, 2.389, 2.387, 2.386, 2.384, 2.382, + 2.38, 2.379, 2.377, 2.375, 2.373, 2.372, 2.37, 2.368, 2.367, 2.365, + 2.363, 2.361, 2.36, 2.358, 2.356, 2.355, 2.353, 2.352, 2.35, 2.348, + 2.347, 2.345, 2.343, 2.342, 2.34, 2.338, 2.337, 2.335, 2.334, 2.332, + 2.331, 2.329, 2.327, 2.326, 2.324, 2.323, 2.321, 2.32, 2.318, 2.316, + 2.315, 2.313, 2.312, 2.31, 2.309, 2.307, 2.306, 2.304, 2.303, 2.301, + 2.3, 2.298, 2.297, 2.295, 2.294, 2.292, 2.291, 2.289, 2.288, 2.286, + 2.285, 2.284, 2.282, 2.281, 2.279, 2.278, 2.276, 2.275, 2.274, 2.272, + 2.271, 2.269, 2.268, 2.266, 2.265, 2.264, 2.262, 2.261, 2.259, 2.258, + 2.257, 2.255, 2.254, 2.253, 2.251, 2.25, 2.248, 2.247, 2.246, 2.244, + 2.243, 2.242, 2.24, 2.239, 2.238, 2.236, 2.235, 2.234, 2.232, 2.231, + 2.23, 2.228, 2.227, 2.226, 2.225, 2.223, 2.222, 2.221, 2.219, 2.218, + 2.217, 2.215, 2.214, 2.213, 2.212, 2.21, 2.209, 2.208, 2.207, 2.205, + 2.204, 2.203, 2.202, 2.2, 2.199, 2.198, 2.197, 2.195, 2.194, 2.193, + 2.192, 2.19, 2.189, 2.188, 2.187, 2.185, 2.184, 2.183, 2.182, 2.181, + 2.179, 2.178, 2.177, 2.176, 2.175, 2.173, 2.172, 2.171, 2.17, 2.169, + 2.168, 2.166, 2.165, 2.164, 2.163, 2.162, 2.16, 2.159, 2.158, 2.157, + 2.156, 2.155, 2.154, 2.152, 2.151, 2.15, 2.149, 2.148, 2.147, 2.146, + 2.144, 2.143, 2.142, 2.141, 2.14, 2.139, 2.138, 2.136, 2.135, 2.134, + 2.133, 2.132, 2.131, 2.13, 2.129, 2.128, 2.126, 2.125, 2.124, 2.123, + 2.122, 2.121, 2.12, 2.119, 2.118, 2.117, 2.116, 2.114, 2.113, 2.112, + 2.111, 2.11, 2.109, 2.108, 2.107, 2.106, 2.105, 2.104, 2.103, 2.102, + 2.101, 2.1, 2.099, 2.097, 2.096, 2.095, 2.094, 2.093, 2.092, 2.091, + 2.09, 2.089, 2.088, 2.087, 2.086, 2.085, 2.084, 2.083, 2.082, 2.081, + 2.08, 2.079, 2.078, 2.077, 2.076, 2.075, 2.074, 2.073, 2.072, 2.071, + 2.07, 2.069, 2.068, 2.067, 2.066, 2.065, 2.064, 2.063, 2.062, 2.061, + 2.06, 2.059, 2.058, 2.057, 2.056, 2.055, 2.054, 2.053, 2.052, 2.051, + 2.05, 2.049, 2.048, 2.047, 2.046, 2.045, 2.044, 2.043, 2.042, 2.041, + 2.04, 2.039, 2.038, 2.037, 2.036, 2.036, 2.035, 2.034, 2.033, 2.032, + 2.031, 2.03, 2.029, 2.028, 2.027, 2.026, 2.025, 2.024, 2.023, 2.022, + 2.021, 2.02, 2.02, 2.019, 2.018, 2.017, 2.016, 2.015, 2.014, 2.013, + 2.012, 2.011, 2.01, 2.009, 2.008, 2.008, 2.007, 2.006, 2.005, 2.004, + 2.003, 2.002, 2.001, 2, 1.999, 1.998, 1.998, 1.997, 1.996, 1.995, + 1.994, 1.993, 1.992, 1.991, 1.99, 1.99, 1.989, 1.988, 1.987, 1.986, + 1.985, 1.984, 1.983, 1.982, 1.982, 1.981, 1.98, 1.979, 1.978, 1.977, + 1.976, 1.975, 1.975, 1.974, 1.973, 1.972, 1.971, 1.97, 1.969, 1.969, + 1.968, 1.967, 1.966, 1.965, 1.964, 1.963, 1.963, 1.962, 1.961, 1.96, + 1.959, 1.958, 1.957, 1.957, 1.956, 1.955, 1.954, 1.953, 1.952, 1.952, + 1.951, 1.95, 1.949, 1.948, 1.947, 1.947, 1.946, 1.945, 1.944, 1.943, + 1.942, 1.942, 1.941, 1.94, 1.939, 1.938, 1.937, 1.937, 1.936, 1.935, + 1.934, 1.933, 1.933, 1.932, 1.931, 1.93, 1.929, 1.928, 1.928, 1.927, + 1.926, 1.925, 1.924, 1.924, 1.923, 1.922, 1.921, 1.92, 1.92, 1.919, + 1.918, 1.917, 1.916, 1.916, 1.915, 1.914, 1.913, 1.912, 1.912, 1.911, + 1.91, 1.909, 1.908, 1.908, 1.907, 1.906, 1.905, 1.904, 1.904, 1.903, + 1.902, 1.901, 1.901, 1.9, 1.899, 1.898, 1.897, 1.897, 1.896, 1.895, + 1.894, 1.894, 1.893, 1.892, 1.891, 1.89, 1.89, 1.889, 1.888, 1.887, + 1.887, 1.886, 1.885, 1.884, 1.884, 1.883, 1.882, 1.881, 1.88, 1.88, + 1.879, 1.878, 1.877, 1.877, 1.876, 1.875, 1.874, 1.874, 1.873, 1.872, + 1.871, 1.871, 1.87, 1.869, 1.868, 1.868, 1.867, 1.866, 1.865, 1.865, + 1.864, 1.863, 1.862, 1.862, 1.861, 1.86, 1.859, 1.859, 1.858, 1.857, + 1.857, 1.856, 1.855, 1.854, 1.854, 1.853, 1.852, 1.851, 1.851, 1.85, + 1.849, 1.848, 1.848, 1.847, 1.846, 1.846, 1.845, 1.844, 1.843, 1.843, + 1.842, 1.841, 1.84, 1.84, 1.839, 1.838, 1.838, 1.837, 1.836, 1.835, + 1.835, 1.834, 1.833, 1.833, 1.832, 1.831, 1.83, 1.83, 1.829, 1.828, + 1.828, 1.827, 1.826, 1.825, 1.825, 1.824, 1.823, 1.823, 1.822, 1.821, + 1.821, 1.82, 1.819, 1.818, 1.818, 1.817, 1.816, 1.816, 1.815, 1.814, + 1.814, 1.813, 1.812, 1.811, 1.811, 1.81, 1.809, 1.809, 1.808, 1.807, + 1.807, 1.806, 1.805, 1.805, 1.804, 1.803, 1.802, 1.802, 1.801, 1.8, + 1.8, 1.799, 1.798, 1.798, 1.797, 1.796, 1.796, 1.795, 1.794, 1.794, + 1.793, 1.792, 1.792, 1.791, 1.79, 1.79, 1.789, 1.788, 1.787, 1.787, + 1.786, 1.785, 1.785, 1.784, 1.783, 1.783, 1.782, 1.781, 1.781, 1.78, + 1.779, 1.779, 1.778, 1.777, 1.777, 1.776, 1.775, 1.775, 1.774, 1.773, + 1.773, 1.772, 1.771, 1.771, 1.77, 1.769, 1.769, 1.768, 1.767, 1.767, + 1.766, 1.766, 1.765, 1.764, 1.764, 1.763, 1.762, 1.762, 1.761, 1.76, + 1.76, 1.759, 1.758, 1.758, 1.757, 1.756, 1.756, 1.755, 1.754, 1.754, + 1.753, 1.752, 1.752, 1.751, 1.751, 1.75, 1.749, 1.749, 1.748, 1.747, + 1.747, 1.746, 1.745, 1.745, 1.744, 1.743, 1.743, 1.742, 1.742, 1.741, + 1.74, 1.74, 1.739, 1.738, 1.738, 1.737, 1.736, 1.736, 1.735, 1.735, + 1.734, 1.733, 1.733, 1.732, 1.731, 1.731, 1.73, 1.729, 1.729, 1.728, + 1.728, 1.727, 1.726, 1.726, 1.725, 1.724, 1.724, 1.723, 1.723, 1.722, + 1.721, 1.721, 1.72, 1.719, 1.719, 1.718, 1.718, 1.717, 1.716, 1.716, + 1.715, 1.715, 1.714, 1.713, 1.713, 1.712, 1.711, 1.711, 1.71, 1.71, + 1.709, 1.708, 1.708, 1.707, 1.706, 1.706, 1.705, 1.705, 1.704, 1.703, + 1.703, 1.702, 1.702, 1.701, 1.7, 1.7, 1.699, 1.699, 1.698, 1.697, + 1.697, 1.696, 1.696, 1.695, 1.694, 1.694, 1.693, 1.692, 1.692, 1.691, + 1.691, 1.69, 1.689, 1.689, 1.688, 1.688, 1.687, 1.686, 1.686, 1.685, + 1.685, 1.684, 1.683, 1.683, 1.682, 1.682, 1.681, 1.68, 1.68, 1.679, + 1.679, 1.678, 1.678, 1.677, 1.676, 1.676, 1.675, 1.675, 1.674, 1.673, + 1.673, 1.672, 1.672, 1.671, 1.67, 1.67, 1.669, 1.669, 1.668, 1.667, + 1.667, 1.666, 1.666, 1.665, 1.665, 1.664, 1.663, 1.663, 1.662, 1.662, + 1.661, 1.66, 1.66, 1.659, 1.659, 1.658, 1.658, 1.657, 1.656, 1.656, + 1.655, 1.655, 1.654, 1.653, 1.653, 1.652, 1.652, 1.651, 1.651, 1.65, + 1.649, 1.649, 1.648, 1.648, 1.647, 1.647, 1.646, 1.645, 1.645, 1.644, + 1.644, 1.643, 1.643, 1.642, 1.641, 1.641, 1.64, 1.64, 1.639, 1.639, + 1.638, 1.637, 1.637, 1.636, 1.636, 1.635, 1.635, 1.634, 1.633, 1.633, + 1.632, 1.632, 1.631, 1.631, 1.63, 1.629, 1.629, 1.628, 1.628, 1.627, + 1.627, 1.626, 1.625, 1.625, 1.624, 1.624, 1.623, 1.623, 1.622, 1.622, + 1.621, 1.62, 1.62, 1.619, 1.619, 1.618, 1.618, 1.617, 1.617, 1.616, + 1.615, 1.615, 1.614, 1.614, 1.613, 1.613, 1.612, 1.612, 1.611, 1.61, + 1.61, 1.609, 1.609, 1.608, 1.608, 1.607, 1.607, 1.606, 1.605, 1.605, + 1.604, 1.604, 1.603, 1.603, 1.602, 1.602, 1.601, 1.6, 1.6, 1.599, + 1.599, 1.598, 1.598, 1.597, 1.597, 1.596, 1.596, 1.595, 1.594, 1.594, + 1.593, 1.593, 1.592, 1.592, 1.591, 1.591, 1.59, 1.59, 1.589, 1.588, + 1.588, 1.587, 1.587, 1.586, 1.586, 1.585, 1.585, 1.584, 1.584, 1.583, + 1.582, 1.582, 1.581, 1.581, 1.58, 1.58, 1.579, 1.579, 1.578, 1.578, + 1.577, 1.577, 1.576, 1.576, 1.575, 1.574, 1.574, 1.573, 1.573, 1.572, + 1.572, 1.571, 1.571, 1.57, 1.57, 1.569, 1.569, 1.568, 1.567, 1.567, + 1.566, 1.566, 1.565, 1.565, 1.564, 1.564, 1.563, 1.563, 1.562, 1.562, + 1.561, 1.561, 1.56, 1.56, 1.559, 1.558, 1.558, 1.557, 1.557, 1.556, + 1.556, 1.555, 1.555, 1.554, 1.554, 1.553, 1.553, 1.552, 1.552, 1.551, + 1.551, 1.55, 1.55, 1.549, 1.549, 1.548, 1.547, 1.547, 1.546, 1.546, + 1.545, 1.545, 1.544, 1.544, 1.543, 1.543, 1.542, 1.542, 1.541, 1.541, + 1.54, 1.54, 1.539, 1.539, 1.538, 1.538, 1.537, 1.537, 1.536, 1.536, + 1.535, 1.534, 1.534, 1.533, 1.533, 1.532, 1.532, 1.531, 1.531, 1.53, + 1.53, 1.529, 1.529, 1.528, 1.528, 1.527, 1.527, 1.526, 1.526, 1.525, + 1.525, 1.524, 1.524, 1.523, 1.523, 1.522, 1.522, 1.521, 1.521, 1.52, + 1.52, 1.519, 1.519, 1.518, 1.518, 1.517, 1.517, 1.516, 1.516, 1.515, + 1.515, 1.514, 1.514, 1.513, 1.512, 1.512, 1.511, 1.511, 1.51, 1.51, + 1.509, 1.509, 1.508, 1.508, 1.507, 1.507, 1.506, 1.506, 1.505, 1.505, + 1.504, 1.504, 1.503, 1.503, 1.502, 1.502, 1.501, 1.501, 1.5, 1.5, + 1.499, 1.499, 1.498, 1.498, 1.497, 1.497, 1.496, 1.496, 1.495, 1.495, + 1.494, 1.494, 1.493, 1.493, 1.492, 1.492, 1.491, 1.491, 1.49, 1.49, + 1.489, 1.489, 1.488, 1.488, 1.487, 1.487, 1.486, 1.486, 1.485, 1.485, + 1.484, 1.484, 1.483, 1.483, 1.482, 1.482, 1.481, 1.481, 1.48, 1.48, + 1.48, 1.479, 1.479, 1.478, 1.478, 1.477, 1.477, 1.476, 1.476, 1.475, + 1.475, 1.474, 1.474, 1.473, 1.473, 1.472, 1.472, 1.471, 1.471, 1.47, + 1.47, 1.469, 1.469, 1.468, 1.468, 1.467, 1.467, 1.466, 1.466, 1.465, + 1.465, 1.464, 1.464, 1.463, 1.463, 1.462, 1.462, 1.461, 1.461, 1.46, + 1.46, 1.459, 1.459, 1.458, 1.458, 1.458, 1.457, 1.457, 1.456, 1.456, + 1.455, 1.455, 1.454, 1.454, 1.453, 1.453, 1.452, 1.452, 1.451, 1.451, + 1.45, 1.45, 1.449, 1.449, 1.448, 1.448, 1.447, 1.447, 1.446, 1.446, + 1.445, 1.445, 1.444, 1.444, 1.444, 1.443, 1.443, 1.442, 1.442, 1.441, + 1.441, 1.44, 1.44, 1.439, 1.439, 1.438, 1.438, 1.437, 1.437, 1.436, + 1.436, 1.435, 1.435, 1.434, 1.434, 1.434, 1.433, 1.433, 1.432, 1.432, + 1.431, 1.431, 1.43, 1.43, 1.429, 1.429, 1.428, 1.428, 1.427, 1.427, + 1.426, 1.426, 1.425, 1.425, 1.424, 1.424, 1.424, 1.423, 1.423, 1.422, + 1.422, 1.421, 1.421, 1.42, 1.42, 1.419, 1.419, 1.418, 1.418, 1.417, + 1.417, 1.416, 1.416, 1.416, 1.415, 1.415, 1.414, 1.414, 1.413, 1.413, + 1.412, 1.412, 1.411, 1.411, 1.41, 1.41, 1.409, 1.409, 1.409, 1.408, + 1.408, 1.407, 1.407, 1.406, 1.406, 1.405, 1.405, 1.404, 1.404, 1.403, + 1.403, 1.402, 1.402, 1.402, 1.401, 1.401, 1.4, 1.4, 1.399, 1.399, + 1.398, 1.398, 1.397, 1.397, 1.396, 1.396, 1.395, 1.395, 1.395, 1.394, + 1.394, 1.393, 1.393, 1.392, 1.392, 1.391, 1.391, 1.39, 1.39, 1.389, + 1.389, 1.389, 1.388, 1.388, 1.387, 1.387, 1.386, 1.386, 1.385, 1.385, + 1.384, 1.384, 1.383, 1.383, 1.383, 1.382, 1.382, 1.381, 1.381, 1.38, + 1.38, 1.379, 1.379, 1.378, 1.378, 1.378, 1.377, 1.377, 1.376, 1.376, + 1.375, 1.375, 1.374, 1.374, 1.373, 1.373, 1.373, 1.372, 1.372, 1.371, + 1.371, 1.37, 1.37, 1.369, 1.369, 1.368, 1.368, 1.367, 1.367, 1.367, + 1.366, 1.366, 1.365, 1.365, 1.364, 1.364, 1.363, 1.363, 1.362, 1.362, + 1.362, 1.361, 1.361, 1.36, 1.36, 1.359, 1.359, 1.358, 1.358, 1.358, + 1.357, 1.357, 1.356, 1.356, 1.355, 1.355, 1.354, 1.354, 1.353, 1.353, + 1.353, 1.352, 1.352, 1.351, 1.351, 1.35, 1.35, 1.349, 1.349, 1.349, + 1.348, 1.348, 1.347, 1.347, 1.346, 1.346, 1.345, 1.345, 1.344, 1.344, + 1.344, 1.343, 1.343, 1.342, 1.342, 1.341, 1.341, 1.34, 1.34, 1.34, + 1.339, 1.339, 1.338, 1.338, 1.337, 1.337, 1.336, 1.336, 1.336, 1.335, + 1.335, 1.334, 1.334, 1.333, 1.333, 1.332, 1.332, 1.332, 1.331, 1.331, + 1.33, 1.33, 1.329, 1.329, 1.328, 1.328, 1.328, 1.327, 1.327, 1.326, + 1.326, 1.325, 1.325, 1.324, 1.324, 1.324, 1.323, 1.323, 1.322, 1.322, + 1.321, 1.321, 1.32, 1.32, 1.32, 1.319, 1.319, 1.318, 1.318, 1.317, + 1.317, 1.316, 1.316, 1.316, 1.315, 1.315, 1.314, 1.314, 1.313, 1.313, + 1.312, 1.312, 1.312, 1.311, 1.311, 1.31, 1.31, 1.309, 1.309, 1.309, + 1.308, 1.308, 1.307, 1.307, 1.306, 1.306, 1.305, 1.305, 1.305, 1.304, + 1.304, 1.303, 1.303, 1.302, 1.302, 1.302, 1.301, 1.301, 1.3, 1.3, + 1.299, 1.299, 1.298, 1.298, 1.298, 1.297, 1.297, 1.296, 1.296, 1.295, + 1.295, 1.295, 1.294, 1.294, 1.293, 1.293, 1.292, 1.292, 1.291, 1.291, + 1.291, 1.29, 1.29, 1.289, 1.289, 1.288, 1.288, 1.288, 1.287, 1.287, + 1.286, 1.286, 1.285, 1.285, 1.285, 1.284, 1.284, 1.283, 1.283, 1.282, + 1.282, 1.281, 1.281, 1.281, 1.28, 1.28, 1.279, 1.279, 1.278, 1.278, + 1.278, 1.277, 1.277, 1.276, 1.276, 1.275, 1.275, 1.275, 1.274, 1.274, + 1.273, 1.273, 1.272, 1.272, 1.272, 1.271, 1.271, 1.27, 1.27, 1.269, + 1.269, 1.269, 1.268, 1.268, 1.267, 1.267, 1.266, 1.266, 1.266, 1.265, + 1.265, 1.264, 1.264, 1.263, 1.263, 1.263, 1.262, 1.262, 1.261, 1.261, + 1.26, 1.26, 1.26, 1.259, 1.259, 1.258, 1.258, 1.257, 1.257, 1.257, + 1.256, 1.256, 1.255, 1.255, 1.254, 1.254, 1.254, 1.253, 1.253, 1.252, + 1.252, 1.251, 1.251, 1.251, 1.25, 1.25, 1.249, 1.249, 1.248, 1.248, + 1.248, 1.247, 1.247, 1.246, 1.246, 1.245, 1.245, 1.245, 1.244, 1.244, + 1.243, 1.243, 1.242, 1.242, 1.242, 1.241, 1.241, 1.24, 1.24, 1.239, + 1.239, 1.239, 1.238, 1.238, 1.237, 1.237, 1.237, 1.236, 1.236, 1.235, + 1.235, 1.234, 1.234, 1.234, 1.233, 1.233, 1.232, 1.232, 1.231, 1.231, + 1.231, 1.23, 1.23, 1.229, 1.229, 1.228, 1.228, 1.228, 1.227, 1.227, + 1.226, 1.226, 1.226, 1.225, 1.225, 1.224, 1.224, 1.223, 1.223, 1.223, + 1.222, 1.222, 1.221, 1.221, 1.22, 1.22, 1.22, 1.219, 1.219, 1.218, + 1.218, 1.218, 1.217, 1.217, 1.216, 1.216, 1.215, 1.215, 1.215, 1.214, + 1.214, 1.213, 1.213, 1.212, 1.212, 1.212, 1.211, 1.211, 1.21, 1.21, + 1.21, 1.209, 1.209, 1.208, 1.208, 1.207, 1.207, 1.207, 1.206, 1.206, + 1.205, 1.205, 1.204, 1.204, 1.204, 1.203, 1.203, 1.202, 1.202, 1.202, + 1.201, 1.201, 1.2, 1.2, 1.199, 1.199, 1.199, 1.198, 1.198, 1.197, + 1.197, 1.197, 1.196, 1.196, 1.195, 1.195, 1.194, 1.194, 1.194, 1.193, + 1.193, 1.192, 1.192, 1.192, 1.191, 1.191, 1.19, 1.19, 1.189, 1.189, + 1.189, 1.188, 1.188, 1.187, 1.187, 1.187, 1.186, 1.186, 1.185, 1.185, + 1.184, 1.184, 1.184, 1.183, 1.183, 1.182, 1.182, 1.182, 1.181, 1.181, + 1.18, 1.18, 1.179, 1.179, 1.179, 1.178, 1.178, 1.177, 1.177, 1.177, + 1.176, 1.176, 1.175, 1.175, 1.175, 1.174, 1.174, 1.173, 1.173, 1.172, + 1.172, 1.172, 1.171, 1.171, 1.17, 1.17, 1.17, 1.169, 1.169, 1.168, + 1.168, 1.167, 1.167, 1.167, 1.166, 1.166, 1.165, 1.165, 1.165, 1.164, + 1.164, 1.163, 1.163, 1.163, 1.162, 1.162, 1.161, 1.161, 1.16, 1.16, + 1.16, 1.159, 1.159, 1.158, 1.158, 1.158, 1.157, 1.157, 1.156, 1.156, + 1.156, 1.155, 1.155, 1.154, 1.154, 1.153, 1.153, 1.153, 1.152, 1.152, + 1.151, 1.151, 1.151, 1.15, 1.15, 1.149, 1.149, 1.149, 1.148, 1.148, + 1.147, 1.147, 1.146, 1.146, 1.146, 1.145, 1.145, 1.144, 1.144, 1.144, + 1.143, 1.143, 1.142, 1.142, 1.142, 1.141, 1.141, 1.14, 1.14, 1.139, + 1.139, 1.139, 1.138, 1.138, 1.137, 1.137, 1.137, 1.136, 1.136, 1.135, + 1.135, 1.135, 1.134, 1.134, 1.133, 1.133, 1.133, 1.132, 1.132, 1.131, + 1.131, 1.13, 1.13, 1.13, 1.129, 1.129, 1.128, 1.128, 1.128, 1.127, + 1.127, 1.126, 1.126, 1.126, 1.125, 1.125, 1.124, 1.124, 1.124, 1.123, + 1.123, 1.122, 1.122, 1.121, 1.121, 1.121, 1.12, 1.12, 1.119, 1.119, + 1.119, 1.118, 1.118, 1.117, 1.117, 1.117, 1.116, 1.116, 1.115, 1.115, + 1.115, 1.114, 1.114, 1.113, 1.113, 1.113, 1.112, 1.112, 1.111, 1.111, + 1.11, 1.11, 1.11, 1.109, 1.109, 1.108, 1.108, 1.108, 1.107, 1.107, + 1.106, 1.106, 1.106, 1.105, 1.105, 1.104, 1.104, 1.104, 1.103, 1.103, + 1.102, 1.102, 1.102, 1.101, 1.101, 1.1, 1.1, 1.099, 1.099, 1.099, + 1.098, 1.098, 1.097, 1.097, 1.097, 1.096, 1.096, 1.095, 1.095, 1.095, + 1.094, 1.094, 1.093, 1.093, 1.093, 1.092, 1.092, 1.091, 1.091, 1.091, + 1.09, 1.09, 1.089, 1.089, 1.089, 1.088, 1.088, 1.087, 1.087, 1.086, + 1.086, 1.086, 1.085, 1.085, 1.084, 1.084, 1.084, 1.083, 1.083, 1.082, + 1.082, 1.082, 1.081, 1.081, 1.08, 1.08, 1.08, 1.079, 1.079, 1.078, + 1.078, 1.078, 1.077, 1.077, 1.076, 1.076, 1.076, 1.075, 1.075, 1.074, + 1.074, 1.074, 1.073, 1.073, 1.072, 1.072, 1.072, 1.071, 1.071, 1.07, + 1.07, 1.069, 1.069, 1.069, 1.068, 1.068, 1.067, 1.067, 1.067, 1.066, + 1.066, 1.065, 1.065, 1.065, 1.064, 1.064, 1.063, 1.063, 1.063, 1.062, + 1.062, 1.061, 1.061, 1.061, 1.06, 1.06, 1.059, 1.059, 1.059, 1.058, + 1.058, 1.057, 1.057, 1.057, 1.056, 1.056, 1.055, 1.055, 1.055, 1.054, + 1.054, 1.053, 1.053, 1.053, 1.052, 1.052, 1.051, 1.051, 1.05, 1.05, + 1.05, 1.049, 1.049, 1.048, 1.048, 1.048, 1.047, 1.047, 1.046, 1.046, + 1.046, 1.045, 1.045, 1.044, 1.044, 1.044, 1.043, 1.043, 1.042, 1.042, + 1.042, 1.041, 1.041, 1.04, 1.04, 1.04, 1.039, 1.039, 1.038, 1.038, + 1.038, 1.037, 1.037, 1.036, 1.036, 1.036, 1.035, 1.035, 1.034, 1.034, + 1.034, 1.033, 1.033, 1.032, 1.032, 1.032, 1.031, 1.031, 1.03, 1.03, + 1.03, 1.029, 1.029, 1.028, 1.028, 1.028, 1.027, 1.027, 1.026, 1.026, + 1.026, 1.025, 1.025, 1.024, 1.024, 1.023, 1.023, 1.023, 1.022, 1.022, + 1.021, 1.021, 1.021, 1.02, 1.02, 1.019, 1.019, 1.019, 1.018, 1.018, + 1.017, 1.017, 1.017, 1.016, 1.016, 1.015, 1.015, 1.015, 1.014, 1.014, + 1.013, 1.013, 1.013, 1.012, 1.012, 1.011, 1.011, 1.011, 1.01, 1.01, + 1.009, 1.009, 1.009, 1.008, 1.008, 1.007, 1.007, 1.007, 1.006, 1.006, + 1.005, 1.005, 1.005, 1.004, 1.004, 1.003, 1.003, 1.003, 1.002, 1.002, + 1.001, 1.001, 1.001, 1, 0.9997, 0.9993, 0.9989, 0.9985, 0.9981, 0.9977, + 0.9973, 0.9969, 0.9965, 0.9961, 0.9957, 0.9953, 0.9949, 0.9945, 0.9941, 0.9937, + 0.9933, 0.9929, 0.9925, 0.9921, 0.9917, 0.9913, 0.9909, 0.9905, 0.9901, 0.9897, + 0.9893, 0.9889, 0.9885, 0.9881, 0.9877, 0.9873, 0.9869, 0.9865, 0.9861, 0.9856, + 0.9852, 0.9848, 0.9844, 0.984, 0.9836, 0.9832, 0.9828, 0.9824, 0.982, 0.9816, + 0.9812, 0.9808, 0.9804, 0.98, 0.9796, 0.9792, 0.9788, 0.9784, 0.978, 0.9776, + 0.9772, 0.9768, 0.9764, 0.976, 0.9756, 0.9752, 0.9748, 0.9744, 0.974, 0.9736, + 0.9732, 0.9728, 0.9724, 0.972, 0.9716, 0.9712, 0.9707, 0.9703, 0.9699, 0.9695, + 0.9691, 0.9687, 0.9683, 0.9679, 0.9675, 0.9671, 0.9667, 0.9663, 0.9659, 0.9655, + 0.9651, 0.9647, 0.9643, 0.9639, 0.9635, 0.9631, 0.9627, 0.9623, 0.9619, 0.9615, + 0.9611, 0.9607, 0.9603, 0.9599, 0.9595, 0.9591, 0.9587, 0.9583, 0.9579, 0.9574, + 0.957, 0.9566, 0.9562, 0.9558, 0.9554, 0.955, 0.9546, 0.9542, 0.9538, 0.9534, + 0.953, 0.9526, 0.9522, 0.9518, 0.9514, 0.951, 0.9506, 0.9502, 0.9498, 0.9494, + 0.949, 0.9486, 0.9482, 0.9478, 0.9474, 0.947, 0.9466, 0.9462, 0.9457, 0.9453, + 0.9449, 0.9445, 0.9441, 0.9437, 0.9433, 0.9429, 0.9425, 0.9421, 0.9417, 0.9413, + 0.9409, 0.9405, 0.9401, 0.9397, 0.9393, 0.9389, 0.9385, 0.9381, 0.9377, 0.9373, + 0.9369, 0.9365, 0.9361, 0.9356, 0.9352, 0.9348, 0.9344, 0.934, 0.9336, 0.9332, + 0.9328, 0.9324, 0.932, 0.9316, 0.9312, 0.9308, 0.9304, 0.93, 0.9296, 0.9292, + 0.9288, 0.9284, 0.928, 0.9276, 0.9272, 0.9267, 0.9263, 0.9259, 0.9255, 0.9251, + 0.9247, 0.9243, 0.9239, 0.9235, 0.9231, 0.9227, 0.9223, 0.9219, 0.9215, 0.9211, + 0.9207, 0.9203, 0.9199, 0.9195, 0.9191, 0.9186, 0.9182, 0.9178, 0.9174, 0.917, + 0.9166, 0.9162, 0.9158, 0.9154, 0.915, 0.9146, 0.9142, 0.9138, 0.9134, 0.913, + 0.9126, 0.9122, 0.9118, 0.9113, 0.9109, 0.9105, 0.9101, 0.9097, 0.9093, 0.9089, + 0.9085, 0.9081, 0.9077, 0.9073, 0.9069, 0.9065, 0.9061, 0.9057, 0.9053, 0.9049, + 0.9044, 0.904, 0.9036, 0.9032, 0.9028, 0.9024, 0.902, 0.9016, 0.9012, 0.9008, + 0.9004, 0.9, 0.8996, 0.8992, 0.8988, 0.8983, 0.8979, 0.8975, 0.8971, 0.8967, + 0.8963, 0.8959, 0.8955, 0.8951, 0.8947, 0.8943, 0.8939, 0.8935, 0.8931, 0.8926, + 0.8922, 0.8918, 0.8914, 0.891, 0.8906, 0.8902, 0.8898, 0.8894, 0.889, 0.8886, + 0.8882, 0.8878, 0.8873, 0.8869, 0.8865, 0.8861, 0.8857, 0.8853, 0.8849, 0.8845, + 0.8841, 0.8837, 0.8833, 0.8829, 0.8825, 0.882, 0.8816, 0.8812, 0.8808, 0.8804, + 0.88, 0.8796, 0.8792, 0.8788, 0.8784, 0.878, 0.8775, 0.8771, 0.8767, 0.8763, + 0.8759, 0.8755, 0.8751, 0.8747, 0.8743, 0.8739, 0.8735, 0.873, 0.8726, 0.8722, + 0.8718, 0.8714, 0.871, 0.8706, 0.8702, 0.8698, 0.8694, 0.869, 0.8685, 0.8681, + 0.8677, 0.8673, 0.8669, 0.8665, 0.8661, 0.8657, 0.8653, 0.8649, 0.8644, 0.864, + 0.8636, 0.8632, 0.8628, 0.8624, 0.862, 0.8616, 0.8612, 0.8607, 0.8603, 0.8599, + 0.8595, 0.8591, 0.8587, 0.8583, 0.8579, 0.8575, 0.857, 0.8566, 0.8562, 0.8558, + 0.8554, 0.855, 0.8546, 0.8542, 0.8538, 0.8533, 0.8529, 0.8525, 0.8521, 0.8517, + 0.8513, 0.8509, 0.8505, 0.85, 0.8496, 0.8492, 0.8488, 0.8484, 0.848, 0.8476, + 0.8472, 0.8467, 0.8463, 0.8459, 0.8455, 0.8451, 0.8447, 0.8443, 0.8439, 0.8434, + 0.843, 0.8426, 0.8422, 0.8418, 0.8414, 0.841, 0.8406, 0.8401, 0.8397, 0.8393, + 0.8389, 0.8385, 0.8381, 0.8377, 0.8372, 0.8368, 0.8364, 0.836, 0.8356, 0.8352, + 0.8348, 0.8343, 0.8339, 0.8335, 0.8331, 0.8327, 0.8323, 0.8319, 0.8314, 0.831, + 0.8306, 0.8302, 0.8298, 0.8294, 0.8289, 0.8285, 0.8281, 0.8277, 0.8273, 0.8269, + 0.8265, 0.826, 0.8256, 0.8252, 0.8248, 0.8244, 0.824, 0.8235, 0.8231, 0.8227, + 0.8223, 0.8219, 0.8215, 0.821, 0.8206, 0.8202, 0.8198, 0.8194, 0.819, 0.8185, + 0.8181, 0.8177, 0.8173, 0.8169, 0.8165, 0.816, 0.8156, 0.8152, 0.8148, 0.8144, + 0.814, 0.8135, 0.8131, 0.8127, 0.8123, 0.8119, 0.8114, 0.811, 0.8106, 0.8102, + 0.8098, 0.8094, 0.8089, 0.8085, 0.8081, 0.8077, 0.8073, 0.8068, 0.8064, 0.806, + 0.8056, 0.8052, 0.8047, 0.8043, 0.8039, 0.8035, 0.8031, 0.8026, 0.8022, 0.8018, + 0.8014, 0.801, 0.8005, 0.8001, 0.7997, 0.7993, 0.7989, 0.7984, 0.798, 0.7976, + 0.7972, 0.7968, 0.7963, 0.7959, 0.7955, 0.7951, 0.7947, 0.7942, 0.7938, 0.7934, + 0.793, 0.7926, 0.7921, 0.7917, 0.7913, 0.7909, 0.7904, 0.79, 0.7896, 0.7892, + 0.7888, 0.7883, 0.7879, 0.7875, 0.7871, 0.7866, 0.7862, 0.7858, 0.7854, 0.7849, + 0.7845, 0.7841, 0.7837, 0.7833, 0.7828, 0.7824, 0.782, 0.7816, 0.7811, 0.7807, + 0.7803, 0.7799, 0.7794, 0.779, 0.7786, 0.7782, 0.7777, 0.7773, 0.7769, 0.7765, + 0.776, 0.7756, 0.7752, 0.7748, 0.7743, 0.7739, 0.7735, 0.7731, 0.7726, 0.7722, + 0.7718, 0.7714, 0.7709, 0.7705, 0.7701, 0.7697, 0.7692, 0.7688, 0.7684, 0.7679, + 0.7675, 0.7671, 0.7667, 0.7662, 0.7658, 0.7654, 0.765, 0.7645, 0.7641, 0.7637, + 0.7632, 0.7628, 0.7624, 0.762, 0.7615, 0.7611, 0.7607, 0.7602, 0.7598, 0.7594, + 0.759, 0.7585, 0.7581, 0.7577, 0.7572, 0.7568, 0.7564, 0.756, 0.7555, 0.7551, + 0.7547, 0.7542, 0.7538, 0.7534, 0.7529, 0.7525, 0.7521, 0.7516, 0.7512, 0.7508, + 0.7504, 0.7499, 0.7495, 0.7491, 0.7486, 0.7482, 0.7478, 0.7473, 0.7469, 0.7465, + 0.746, 0.7456, 0.7452, 0.7447, 0.7443, 0.7439, 0.7434, 0.743, 0.7426, 0.7421, + 0.7417, 0.7413, 0.7408, 0.7404, 0.74, 0.7395, 0.7391, 0.7387, 0.7382, 0.7378, + 0.7374, 0.7369, 0.7365, 0.7361, 0.7356, 0.7352, 0.7348, 0.7343, 0.7339, 0.7335, + 0.733, 0.7326, 0.7321, 0.7317, 0.7313, 0.7308, 0.7304, 0.73, 0.7295, 0.7291, + 0.7287, 0.7282, 0.7278, 0.7273, 0.7269, 0.7265, 0.726, 0.7256, 0.7252, 0.7247, + 0.7243, 0.7238, 0.7234, 0.723, 0.7225, 0.7221, 0.7216, 0.7212, 0.7208, 0.7203, + 0.7199, 0.7195, 0.719, 0.7186, 0.7181, 0.7177, 0.7173, 0.7168, 0.7164, 0.7159, + 0.7155, 0.7151, 0.7146, 0.7142, 0.7137, 0.7133, 0.7128, 0.7124, 0.712, 0.7115, + 0.7111, 0.7106, 0.7102, 0.7098, 0.7093, 0.7089, 0.7084, 0.708, 0.7075, 0.7071, + 0.7066, 0.7062, 0.7058, 0.7053, 0.7049, 0.7044, 0.704, 0.7035, 0.7031, 0.7027, + 0.7022, 0.7018, 0.7013, 0.7009, 0.7004, 0.7, 0.6995, 0.6991, 0.6986, 0.6982, + 0.6978, 0.6973, 0.6969, 0.6964, 0.696, 0.6955, 0.6951, 0.6946, 0.6942, 0.6937, + 0.6933, 0.6928, 0.6924, 0.6919, 0.6915, 0.691, 0.6906, 0.6901, 0.6897, 0.6892, + 0.6888, 0.6883, 0.6879, 0.6874, 0.687, 0.6865, 0.6861, 0.6856, 0.6852, 0.6847, + 0.6843, 0.6838, 0.6834, 0.6829, 0.6825, 0.682, 0.6816, 0.6811, 0.6807, 0.6802, + 0.6798, 0.6793, 0.6789, 0.6784, 0.678, 0.6775, 0.6771, 0.6766, 0.6762, 0.6757, + 0.6752, 0.6748, 0.6743, 0.6739, 0.6734, 0.673, 0.6725, 0.6721, 0.6716, 0.6711, + 0.6707, 0.6702, 0.6698, 0.6693, 0.6689, 0.6684, 0.668, 0.6675, 0.667, 0.6666, + 0.6661, 0.6657, 0.6652, 0.6648, 0.6643, 0.6638, 0.6634, 0.6629, 0.6625, 0.662, + 0.6615, 0.6611, 0.6606, 0.6602, 0.6597, 0.6592, 0.6588, 0.6583, 0.6579, 0.6574, + 0.6569, 0.6565, 0.656, 0.6556, 0.6551, 0.6546, 0.6542, 0.6537, 0.6532, 0.6528, + 0.6523, 0.6519, 0.6514, 0.6509, 0.6505, 0.65, 0.6495, 0.6491, 0.6486, 0.6481, + 0.6477, 0.6472, 0.6468, 0.6463, 0.6458, 0.6454, 0.6449, 0.6444, 0.644, 0.6435, + 0.643, 0.6426, 0.6421, 0.6416, 0.6412, 0.6407, 0.6402, 0.6397, 0.6393, 0.6388, + 0.6383, 0.6379, 0.6374, 0.6369, 0.6365, 0.636, 0.6355, 0.6351, 0.6346, 0.6341, + 0.6336, 0.6332, 0.6327, 0.6322, 0.6318, 0.6313, 0.6308, 0.6303, 0.6299, 0.6294, + 0.6289, 0.6285, 0.628, 0.6275, 0.627, 0.6266, 0.6261, 0.6256, 0.6251, 0.6247, + 0.6242, 0.6237, 0.6232, 0.6228, 0.6223, 0.6218, 0.6213, 0.6208, 0.6204, 0.6199, + 0.6194, 0.6189, 0.6185, 0.618, 0.6175, 0.617, 0.6165, 0.6161, 0.6156, 0.6151, + 0.6146, 0.6142, 0.6137, 0.6132, 0.6127, 0.6122, 0.6117, 0.6113, 0.6108, 0.6103, + 0.6098, 0.6093, 0.6089, 0.6084, 0.6079, 0.6074, 0.6069, 0.6064, 0.606, 0.6055, + 0.605, 0.6045, 0.604, 0.6035, 0.603, 0.6026, 0.6021, 0.6016, 0.6011, 0.6006, + 0.6001, 0.5996, 0.5992, 0.5987, 0.5982, 0.5977, 0.5972, 0.5967, 0.5962, 0.5957, + 0.5952, 0.5948, 0.5943, 0.5938, 0.5933, 0.5928, 0.5923, 0.5918, 0.5913, 0.5908, + 0.5903, 0.5898, 0.5894, 0.5889, 0.5884, 0.5879, 0.5874, 0.5869, 0.5864, 0.5859, + 0.5854, 0.5849, 0.5844, 0.5839, 0.5834, 0.5829, 0.5824, 0.5819, 0.5814, 0.5809, + 0.5804, 0.5799, 0.5794, 0.5789, 0.5784, 0.5779, 0.5774, 0.5769, 0.5764, 0.5759, + 0.5754, 0.5749, 0.5744, 0.5739, 0.5734, 0.5729, 0.5724, 0.5719, 0.5714, 0.5709, + 0.5704, 0.5699, 0.5694, 0.5689, 0.5684, 0.5679, 0.5674, 0.5669, 0.5664, 0.5659, + 0.5654, 0.5649, 0.5644, 0.5639, 0.5633, 0.5628, 0.5623, 0.5618, 0.5613, 0.5608, + 0.5603, 0.5598, 0.5593, 0.5588, 0.5582, 0.5577, 0.5572, 0.5567, 0.5562, 0.5557, + 0.5552, 0.5547, 0.5541, 0.5536, 0.5531, 0.5526, 0.5521, 0.5516, 0.5511, 0.5505, + 0.55, 0.5495, 0.549, 0.5485, 0.548, 0.5474, 0.5469, 0.5464, 0.5459, 0.5454, + 0.5448, 0.5443, 0.5438, 0.5433, 0.5428, 0.5422, 0.5417, 0.5412, 0.5407, 0.5402, + 0.5396, 0.5391, 0.5386, 0.5381, 0.5375, 0.537, 0.5365, 0.536, 0.5354, 0.5349, + 0.5344, 0.5339, 0.5333, 0.5328, 0.5323, 0.5317, 0.5312, 0.5307, 0.5302, 0.5296, + 0.5291, 0.5286, 0.528, 0.5275, 0.527, 0.5264, 0.5259, 0.5254, 0.5248, 0.5243, + 0.5238, 0.5232, 0.5227, 0.5222, 0.5216, 0.5211, 0.5206, 0.52, 0.5195, 0.5189, + 0.5184, 0.5179, 0.5173, 0.5168, 0.5162, 0.5157, 0.5152, 0.5146, 0.5141, 0.5135, + 0.513, 0.5124, 0.5119, 0.5114, 0.5108, 0.5103, 0.5097, 0.5092, 0.5086, 0.5081, + 0.5075, 0.507, 0.5064, 0.5059, 0.5053, 0.5048, 0.5043, 0.5037, 0.5032, 0.5026, + 0.502, 0.5015, 0.5009, 0.5004, 0.4998, 0.4993, 0.4987, 0.4982, 0.4976, 0.4971, + 0.4965, 0.496, 0.4954, 0.4948, 0.4943, 0.4937, 0.4932, 0.4926, 0.492, 0.4915, + 0.4909, 0.4904, 0.4898, 0.4892, 0.4887, 0.4881, 0.4875, 0.487, 0.4864, 0.4859, + 0.4853, 0.4847, 0.4842, 0.4836, 0.483, 0.4825, 0.4819, 0.4813, 0.4807, 0.4802, + 0.4796, 0.479, 0.4785, 0.4779, 0.4773, 0.4767, 0.4762, 0.4756, 0.475, 0.4744, + 0.4739, 0.4733, 0.4727, 0.4721, 0.4716, 0.471, 0.4704, 0.4698, 0.4692, 0.4687, + 0.4681, 0.4675, 0.4669, 0.4663, 0.4657, 0.4652, 0.4646, 0.464, 0.4634, 0.4628, + 0.4622, 0.4616, 0.461, 0.4605, 0.4599, 0.4593, 0.4587, 0.4581, 0.4575, 0.4569, + 0.4563, 0.4557, 0.4551, 0.4545, 0.4539, 0.4533, 0.4527, 0.4521, 0.4515, 0.451, + 0.4504, 0.4498, 0.4491, 0.4485, 0.4479, 0.4473, 0.4467, 0.4461, 0.4455, 0.4449, + 0.4443, 0.4437, 0.4431, 0.4425, 0.4419, 0.4413, 0.4407, 0.4401, 0.4394, 0.4388, + 0.4382, 0.4376, 0.437, 0.4364, 0.4358, 0.4351, 0.4345, 0.4339, 0.4333, 0.4327, + 0.4321, 0.4314, 0.4308, 0.4302, 0.4296, 0.4289, 0.4283, 0.4277, 0.4271, 0.4264, + 0.4258, 0.4252, 0.4246, 0.4239, 0.4233, 0.4227, 0.422, 0.4214, 0.4208, 0.4201, + 0.4195, 0.4189, 0.4182, 0.4176, 0.4169, 0.4163, 0.4157, 0.415, 0.4144, 0.4137, + 0.4131, 0.4125, 0.4118, 0.4112, 0.4105, 0.4099, 0.4092, 0.4086, 0.4079, 0.4073, + 0.4066, 0.406, 0.4053, 0.4047, 0.404, 0.4034, 0.4027, 0.402, 0.4014, 0.4007, + 0.4001, 0.3994, 0.3987, 0.3981, 0.3974, 0.3967, 0.3961, 0.3954, 0.3947, 0.3941, + 0.3934, 0.3927, 0.3921, 0.3914, 0.3907, 0.39, 0.3894, 0.3887, 0.388, 0.3873, + 0.3866, 0.386, 0.3853, 0.3846, 0.3839, 0.3832, 0.3825, 0.3819, 0.3812, 0.3805, + 0.3798, 0.3791, 0.3784, 0.3777, 0.377, 0.3763, 0.3756, 0.3749, 0.3742, 0.3735, + 0.3728, 0.3721, 0.3714, 0.3707, 0.37, 0.3693, 0.3686, 0.3679, 0.3672, 0.3665, + 0.3657, 0.365, 0.3643, 0.3636, 0.3629, 0.3622, 0.3614, 0.3607, 0.36, 0.3593, + 0.3585, 0.3578, 0.3571, 0.3564, 0.3556, 0.3549, 0.3542, 0.3534, 0.3527, 0.352, + 0.3512, 0.3505, 0.3497, 0.349, 0.3483, 0.3475, 0.3468, 0.346, 0.3453, 0.3445, + 0.3438, 0.343, 0.3422, 0.3415, 0.3407, 0.34, 0.3392, 0.3384, 0.3377, 0.3369, + 0.3361, 0.3354, 0.3346, 0.3338, 0.3331, 0.3323, 0.3315, 0.3307, 0.3299, 0.3292, + 0.3284, 0.3276, 0.3268, 0.326, 0.3252, 0.3244, 0.3236, 0.3228, 0.3221, 0.3213, + 0.3205, 0.3196, 0.3188, 0.318, 0.3172, 0.3164, 0.3156, 0.3148, 0.314, 0.3132, + 0.3123, 0.3115, 0.3107, 0.3099, 0.309, 0.3082, 0.3074, 0.3065, 0.3057, 0.3049, + 0.304, 0.3032, 0.3023, 0.3015, 0.3007, 0.2998, 0.2989, 0.2981, 0.2972, 0.2964, + 0.2955, 0.2946, 0.2938, 0.2929, 0.292, 0.2912, 0.2903, 0.2894, 0.2885, 0.2877, + 0.2868, 0.2859, 0.285, 0.2841, 0.2832, 0.2823, 0.2814, 0.2805, 0.2796, 0.2787, + 0.2778, 0.2768, 0.2759, 0.275, 0.2741, 0.2732, 0.2722, 0.2713, 0.2704, 0.2694, + 0.2685, 0.2675, 0.2666, 0.2656, 0.2647, 0.2637, 0.2628, 0.2618, 0.2608, 0.2599, + 0.2589, 0.2579, 0.2569, 0.256, 0.255, 0.254, 0.253, 0.252, 0.251, 0.25, + 0.249, 0.248, 0.2469, 0.2459, 0.2449, 0.2439, 0.2428, 0.2418, 0.2408, 0.2397, + 0.2387, 0.2376, 0.2365, 0.2355, 0.2344, 0.2333, 0.2323, 0.2312, 0.2301, 0.229, + 0.2279, 0.2268, 0.2257, 0.2246, 0.2235, 0.2223, 0.2212, 0.2201, 0.2189, 0.2178, + 0.2166, 0.2155, 0.2143, 0.2132, 0.212, 0.2108, 0.2096, 0.2084, 0.2072, 0.206, + 0.2048, 0.2036, 0.2023, 0.2011, 0.1999, 0.1986, 0.1974, 0.1961, 0.1948, 0.1935, + 0.1923, 0.191, 0.1896, 0.1883, 0.187, 0.1857, 0.1843, 0.183, 0.1816, 0.1802, + 0.1789, 0.1775, 0.1761, 0.1747, 0.1732, 0.1718, 0.1703, 0.1689, 0.1674, 0.1659, + 0.1644, 0.1629, 0.1614, 0.1599, 0.1583, 0.1567, 0.1551, 0.1535, 0.1519, 0.1503, + 0.1486, 0.147, 0.1453, 0.1436, 0.1418, 0.1401, 0.1383, 0.1365, 0.1347, 0.1329, + 0.131, 0.1291, 0.1272, 0.1252, 0.1233, 0.1213, 0.1192, 0.1171, 0.115, 0.1129, + 0.1107, 0.1084, 0.1061, 0.1038, 0.1014, 0.09894, 0.09643, 0.09385, 0.0912, 0.08847, + 0.08566, 0.08275, 0.07974, 0.0766, 0.07334, 0.06992, 0.06633, 0.06253, 0.05849, 0.05415, + 0.04943, 0.0442, 0.03828, 0.03125, 0.0221, -0}; + + const static double cosvals[max] = { + 1, 1, 1, 1, 1, 1, 0.9999, 0.9999, 0.9999, 0.9999, + 0.9999, 0.9998, 0.9998, 0.9998, 0.9997, 0.9997, 0.9997, 0.9996, 0.9996, 0.9995, + 0.9995, 0.9994, 0.9994, 0.9993, 0.9993, 0.9992, 0.9991, 0.9991, 0.999, 0.9989, + 0.9989, 0.9988, 0.9987, 0.9986, 0.9986, 0.9985, 0.9984, 0.9983, 0.9982, 0.9981, + 0.998, 0.9979, 0.9978, 0.9977, 0.9976, 0.9975, 0.9974, 0.9973, 0.9972, 0.9971, + 0.9969, 0.9968, 0.9967, 0.9966, 0.9964, 0.9963, 0.9962, 0.996, 0.9959, 0.9958, + 0.9956, 0.9955, 0.9953, 0.9952, 0.995, 0.9949, 0.9947, 0.9946, 0.9944, 0.9942, + 0.9941, 0.9939, 0.9937, 0.9936, 0.9934, 0.9932, 0.993, 0.9929, 0.9927, 0.9925, + 0.9923, 0.9921, 0.9919, 0.9917, 0.9915, 0.9913, 0.9911, 0.9909, 0.9907, 0.9905, + 0.9903, 0.9901, 0.9898, 0.9896, 0.9894, 0.9892, 0.989, 0.9887, 0.9885, 0.9883, + 0.988, 0.9878, 0.9875, 0.9873, 0.9871, 0.9868, 0.9866, 0.9863, 0.9861, 0.9858, + 0.9855, 0.9853, 0.985, 0.9847, 0.9845, 0.9842, 0.9839, 0.9837, 0.9834, 0.9831, + 0.9828, 0.9825, 0.9823, 0.982, 0.9817, 0.9814, 0.9811, 0.9808, 0.9805, 0.9802, + 0.9799, 0.9796, 0.9793, 0.9789, 0.9786, 0.9783, 0.978, 0.9777, 0.9774, 0.977, + 0.9767, 0.9764, 0.976, 0.9757, 0.9754, 0.975, 0.9747, 0.9743, 0.974, 0.9736, + 0.9733, 0.9729, 0.9726, 0.9722, 0.9719, 0.9715, 0.9711, 0.9708, 0.9704, 0.97, + 0.9697, 0.9693, 0.9689, 0.9685, 0.9681, 0.9678, 0.9674, 0.967, 0.9666, 0.9662, + 0.9658, 0.9654, 0.965, 0.9646, 0.9642, 0.9638, 0.9634, 0.963, 0.9625, 0.9621, + 0.9617, 0.9613, 0.9609, 0.9604, 0.96, 0.9596, 0.9591, 0.9587, 0.9583, 0.9578, + 0.9574, 0.9569, 0.9565, 0.956, 0.9556, 0.9551, 0.9547, 0.9542, 0.9538, 0.9533, + 0.9528, 0.9524, 0.9519, 0.9514, 0.951, 0.9505, 0.95, 0.9495, 0.949, 0.9486, + 0.9481, 0.9476, 0.9471, 0.9466, 0.9461, 0.9456, 0.9451, 0.9446, 0.9441, 0.9436, + 0.9431, 0.9426, 0.9421, 0.9415, 0.941, 0.9405, 0.94, 0.9395, 0.9389, 0.9384, + 0.9379, 0.9373, 0.9368, 0.9363, 0.9357, 0.9352, 0.9346, 0.9341, 0.9335, 0.933, + 0.9324, 0.9319, 0.9313, 0.9308, 0.9302, 0.9296, 0.9291, 0.9285, 0.9279, 0.9274, + 0.9268, 0.9262, 0.9256, 0.925, 0.9245, 0.9239, 0.9233, 0.9227, 0.9221, 0.9215, + 0.9209, 0.9203, 0.9197, 0.9191, 0.9185, 0.9179, 0.9173, 0.9167, 0.9161, 0.9154, + 0.9148, 0.9142, 0.9136, 0.913, 0.9123, 0.9117, 0.9111, 0.9104, 0.9098, 0.9092, + 0.9085, 0.9079, 0.9072, 0.9066, 0.9059, 0.9053, 0.9046, 0.904, 0.9033, 0.9027, + 0.902, 0.9013, 0.9007, 0.9, 0.8993, 0.8987, 0.898, 0.8973, 0.8966, 0.896, + 0.8953, 0.8946, 0.8939, 0.8932, 0.8925, 0.8918, 0.8911, 0.8904, 0.8897, 0.889, + 0.8883, 0.8876, 0.8869, 0.8862, 0.8855, 0.8848, 0.8841, 0.8834, 0.8826, 0.8819, + 0.8812, 0.8805, 0.8797, 0.879, 0.8783, 0.8775, 0.8768, 0.8761, 0.8753, 0.8746, + 0.8738, 0.8731, 0.8723, 0.8716, 0.8708, 0.8701, 0.8693, 0.8686, 0.8678, 0.867, + 0.8663, 0.8655, 0.8647, 0.864, 0.8632, 0.8624, 0.8616, 0.8609, 0.8601, 0.8593, + 0.8585, 0.8577, 0.8569, 0.8561, 0.8554, 0.8546, 0.8538, 0.853, 0.8522, 0.8514, + 0.8505, 0.8497, 0.8489, 0.8481, 0.8473, 0.8465, 0.8457, 0.8449, 0.844, 0.8432, + 0.8424, 0.8416, 0.8407, 0.8399, 0.8391, 0.8382, 0.8374, 0.8365, 0.8357, 0.8349, + 0.834, 0.8332, 0.8323, 0.8315, 0.8306, 0.8298, 0.8289, 0.828, 0.8272, 0.8263, + 0.8255, 0.8246, 0.8237, 0.8228, 0.822, 0.8211, 0.8202, 0.8193, 0.8185, 0.8176, + 0.8167, 0.8158, 0.8149, 0.814, 0.8131, 0.8123, 0.8114, 0.8105, 0.8096, 0.8087, + 0.8078, 0.8068, 0.8059, 0.805, 0.8041, 0.8032, 0.8023, 0.8014, 0.8005, 0.7995, + 0.7986, 0.7977, 0.7968, 0.7958, 0.7949, 0.794, 0.793, 0.7921, 0.7912, 0.7902, + 0.7893, 0.7883, 0.7874, 0.7865, 0.7855, 0.7846, 0.7836, 0.7827, 0.7817, 0.7807, + 0.7798, 0.7788, 0.7779, 0.7769, 0.7759, 0.775, 0.774, 0.773, 0.772, 0.7711, + 0.7701, 0.7691, 0.7681, 0.7671, 0.7662, 0.7652, 0.7642, 0.7632, 0.7622, 0.7612, + 0.7602, 0.7592, 0.7582, 0.7572, 0.7562, 0.7552, 0.7542, 0.7532, 0.7522, 0.7512, + 0.7502, 0.7491, 0.7481, 0.7471, 0.7461, 0.7451, 0.744, 0.743, 0.742, 0.741, + 0.7399, 0.7389, 0.7379, 0.7368, 0.7358, 0.7347, 0.7337, 0.7327, 0.7316, 0.7306, + 0.7295, 0.7285, 0.7274, 0.7264, 0.7253, 0.7242, 0.7232, 0.7221, 0.7211, 0.72, + 0.7189, 0.7179, 0.7168, 0.7157, 0.7147, 0.7136, 0.7125, 0.7114, 0.7104, 0.7093, + 0.7082, 0.7071, 0.706, 0.7049, 0.7038, 0.7028, 0.7017, 0.7006, 0.6995, 0.6984, + 0.6973, 0.6962, 0.6951, 0.694, 0.6929, 0.6918, 0.6907, 0.6895, 0.6884, 0.6873, + 0.6862, 0.6851, 0.684, 0.6828, 0.6817, 0.6806, 0.6795, 0.6784, 0.6772, 0.6761, + 0.675, 0.6738, 0.6727, 0.6716, 0.6704, 0.6693, 0.6681, 0.667, 0.6659, 0.6647, + 0.6636, 0.6624, 0.6613, 0.6601, 0.659, 0.6578, 0.6567, 0.6555, 0.6543, 0.6532, + 0.652, 0.6508, 0.6497, 0.6485, 0.6473, 0.6462, 0.645, 0.6438, 0.6427, 0.6415, + 0.6403, 0.6391, 0.6379, 0.6368, 0.6356, 0.6344, 0.6332, 0.632, 0.6308, 0.6296, + 0.6284, 0.6273, 0.6261, 0.6249, 0.6237, 0.6225, 0.6213, 0.6201, 0.6189, 0.6176, + 0.6164, 0.6152, 0.614, 0.6128, 0.6116, 0.6104, 0.6092, 0.6079, 0.6067, 0.6055, + 0.6043, 0.6031, 0.6018, 0.6006, 0.5994, 0.5982, 0.5969, 0.5957, 0.5945, 0.5932, + 0.592, 0.5908, 0.5895, 0.5883, 0.587, 0.5858, 0.5846, 0.5833, 0.5821, 0.5808, + 0.5796, 0.5783, 0.5771, 0.5758, 0.5746, 0.5733, 0.572, 0.5708, 0.5695, 0.5683, + 0.567, 0.5657, 0.5645, 0.5632, 0.5619, 0.5607, 0.5594, 0.5581, 0.5568, 0.5556, + 0.5543, 0.553, 0.5517, 0.5505, 0.5492, 0.5479, 0.5466, 0.5453, 0.544, 0.5428, + 0.5415, 0.5402, 0.5389, 0.5376, 0.5363, 0.535, 0.5337, 0.5324, 0.5311, 0.5298, + 0.5285, 0.5272, 0.5259, 0.5246, 0.5233, 0.522, 0.5207, 0.5194, 0.518, 0.5167, + 0.5154, 0.5141, 0.5128, 0.5115, 0.5102, 0.5088, 0.5075, 0.5062, 0.5049, 0.5035, + 0.5022, 0.5009, 0.4996, 0.4982, 0.4969, 0.4956, 0.4942, 0.4929, 0.4916, 0.4902, + 0.4889, 0.4876, 0.4862, 0.4849, 0.4835, 0.4822, 0.4808, 0.4795, 0.4781, 0.4768, + 0.4755, 0.4741, 0.4727, 0.4714, 0.47, 0.4687, 0.4673, 0.466, 0.4646, 0.4633, + 0.4619, 0.4605, 0.4592, 0.4578, 0.4564, 0.4551, 0.4537, 0.4523, 0.451, 0.4496, + 0.4482, 0.4469, 0.4455, 0.4441, 0.4427, 0.4414, 0.44, 0.4386, 0.4372, 0.4359, + 0.4345, 0.4331, 0.4317, 0.4303, 0.4289, 0.4276, 0.4262, 0.4248, 0.4234, 0.422, + 0.4206, 0.4192, 0.4178, 0.4164, 0.415, 0.4136, 0.4122, 0.4108, 0.4094, 0.408, + 0.4066, 0.4052, 0.4038, 0.4024, 0.401, 0.3996, 0.3982, 0.3968, 0.3954, 0.394, + 0.3926, 0.3912, 0.3898, 0.3883, 0.3869, 0.3855, 0.3841, 0.3827, 0.3813, 0.3798, + 0.3784, 0.377, 0.3756, 0.3742, 0.3727, 0.3713, 0.3699, 0.3685, 0.367, 0.3656, + 0.3642, 0.3628, 0.3613, 0.3599, 0.3585, 0.357, 0.3556, 0.3542, 0.3527, 0.3513, + 0.3499, 0.3484, 0.347, 0.3455, 0.3441, 0.3427, 0.3412, 0.3398, 0.3383, 0.3369, + 0.3354, 0.334, 0.3326, 0.3311, 0.3297, 0.3282, 0.3268, 0.3253, 0.3239, 0.3224, + 0.321, 0.3195, 0.318, 0.3166, 0.3151, 0.3137, 0.3122, 0.3108, 0.3093, 0.3078, + 0.3064, 0.3049, 0.3035, 0.302, 0.3005, 0.2991, 0.2976, 0.2962, 0.2947, 0.2932, + 0.2918, 0.2903, 0.2888, 0.2873, 0.2859, 0.2844, 0.2829, 0.2815, 0.28, 0.2785, + 0.277, 0.2756, 0.2741, 0.2726, 0.2711, 0.2697, 0.2682, 0.2667, 0.2652, 0.2638, + 0.2623, 0.2608, 0.2593, 0.2578, 0.2563, 0.2549, 0.2534, 0.2519, 0.2504, 0.2489, + 0.2474, 0.246, 0.2445, 0.243, 0.2415, 0.24, 0.2385, 0.237, 0.2355, 0.234, + 0.2326, 0.2311, 0.2296, 0.2281, 0.2266, 0.2251, 0.2236, 0.2221, 0.2206, 0.2191, + 0.2176, 0.2161, 0.2146, 0.2131, 0.2116, 0.2101, 0.2086, 0.2071, 0.2056, 0.2041, + 0.2026, 0.2011, 0.1996, 0.1981, 0.1966, 0.1951, 0.1936, 0.1921, 0.1906, 0.1891, + 0.1876, 0.1861, 0.1845, 0.183, 0.1815, 0.18, 0.1785, 0.177, 0.1755, 0.174, + 0.1725, 0.171, 0.1695, 0.1679, 0.1664, 0.1649, 0.1634, 0.1619, 0.1604, 0.1589, + 0.1573, 0.1558, 0.1543, 0.1528, 0.1513, 0.1498, 0.1482, 0.1467, 0.1452, 0.1437, + 0.1422, 0.1407, 0.1391, 0.1376, 0.1361, 0.1346, 0.1331, 0.1315, 0.13, 0.1285, + 0.127, 0.1255, 0.1239, 0.1224, 0.1209, 0.1194, 0.1178, 0.1163, 0.1148, 0.1133, + 0.1117, 0.1102, 0.1087, 0.1072, 0.1056, 0.1041, 0.1026, 0.1011, 0.09954, 0.09802, + 0.09649, 0.09496, 0.09344, 0.09191, 0.09038, 0.08885, 0.08733, 0.0858, 0.08427, 0.08274, + 0.08121, 0.07968, 0.07815, 0.07662, 0.07509, 0.07356, 0.07203, 0.0705, 0.06897, 0.06744, + 0.06591, 0.06438, 0.06285, 0.06132, 0.05979, 0.05826, 0.05673, 0.0552, 0.05366, 0.05213, + 0.0506, 0.04907, 0.04754, 0.046, 0.04447, 0.04294, 0.04141, 0.03987, 0.03834, 0.03681, + 0.03527, 0.03374, 0.03221, 0.03067, 0.02914, 0.02761, 0.02607, 0.02454, 0.02301, 0.02147, + 0.01994, 0.01841, 0.01687, 0.01534, 0.01381, 0.01227, 0.01074, 0.009204, 0.00767, 0.006136, + 0.004602, 0.003068, 0.001534, 6.123e-17, -0.001534, -0.003068, -0.004602, -0.006136, -0.00767, -0.009204, + -0.01074, -0.01227, -0.01381, -0.01534, -0.01687, -0.01841, -0.01994, -0.02147, -0.02301, -0.02454, + -0.02607, -0.02761, -0.02914, -0.03067, -0.03221, -0.03374, -0.03527, -0.03681, -0.03834, -0.03987, + -0.04141, -0.04294, -0.04447, -0.046, -0.04754, -0.04907, -0.0506, -0.05213, -0.05366, -0.0552, + -0.05673, -0.05826, -0.05979, -0.06132, -0.06285, -0.06438, -0.06591, -0.06744, -0.06897, -0.0705, + -0.07203, -0.07356, -0.07509, -0.07662, -0.07815, -0.07968, -0.08121, -0.08274, -0.08427, -0.0858, + -0.08733, -0.08885, -0.09038, -0.09191, -0.09344, -0.09496, -0.09649, -0.09802, -0.09954, -0.1011, + -0.1026, -0.1041, -0.1056, -0.1072, -0.1087, -0.1102, -0.1117, -0.1133, -0.1148, -0.1163, + -0.1178, -0.1194, -0.1209, -0.1224, -0.1239, -0.1255, -0.127, -0.1285, -0.13, -0.1315, + -0.1331, -0.1346, -0.1361, -0.1376, -0.1391, -0.1407, -0.1422, -0.1437, -0.1452, -0.1467, + -0.1482, -0.1498, -0.1513, -0.1528, -0.1543, -0.1558, -0.1573, -0.1589, -0.1604, -0.1619, + -0.1634, -0.1649, -0.1664, -0.1679, -0.1695, -0.171, -0.1725, -0.174, -0.1755, -0.177, + -0.1785, -0.18, -0.1815, -0.183, -0.1845, -0.1861, -0.1876, -0.1891, -0.1906, -0.1921, + -0.1936, -0.1951, -0.1966, -0.1981, -0.1996, -0.2011, -0.2026, -0.2041, -0.2056, -0.2071, + -0.2086, -0.2101, -0.2116, -0.2131, -0.2146, -0.2161, -0.2176, -0.2191, -0.2206, -0.2221, + -0.2236, -0.2251, -0.2266, -0.2281, -0.2296, -0.2311, -0.2326, -0.234, -0.2355, -0.237, + -0.2385, -0.24, -0.2415, -0.243, -0.2445, -0.246, -0.2474, -0.2489, -0.2504, -0.2519, + -0.2534, -0.2549, -0.2563, -0.2578, -0.2593, -0.2608, -0.2623, -0.2638, -0.2652, -0.2667, + -0.2682, -0.2697, -0.2711, -0.2726, -0.2741, -0.2756, -0.277, -0.2785, -0.28, -0.2815, + -0.2829, -0.2844, -0.2859, -0.2873, -0.2888, -0.2903, -0.2918, -0.2932, -0.2947, -0.2962, + -0.2976, -0.2991, -0.3005, -0.302, -0.3035, -0.3049, -0.3064, -0.3078, -0.3093, -0.3108, + -0.3122, -0.3137, -0.3151, -0.3166, -0.318, -0.3195, -0.321, -0.3224, -0.3239, -0.3253, + -0.3268, -0.3282, -0.3297, -0.3311, -0.3326, -0.334, -0.3354, -0.3369, -0.3383, -0.3398, + -0.3412, -0.3427, -0.3441, -0.3455, -0.347, -0.3484, -0.3499, -0.3513, -0.3527, -0.3542, + -0.3556, -0.357, -0.3585, -0.3599, -0.3613, -0.3628, -0.3642, -0.3656, -0.367, -0.3685, + -0.3699, -0.3713, -0.3727, -0.3742, -0.3756, -0.377, -0.3784, -0.3798, -0.3813, -0.3827, + -0.3841, -0.3855, -0.3869, -0.3883, -0.3898, -0.3912, -0.3926, -0.394, -0.3954, -0.3968, + -0.3982, -0.3996, -0.401, -0.4024, -0.4038, -0.4052, -0.4066, -0.408, -0.4094, -0.4108, + -0.4122, -0.4136, -0.415, -0.4164, -0.4178, -0.4192, -0.4206, -0.422, -0.4234, -0.4248, + -0.4262, -0.4276, -0.4289, -0.4303, -0.4317, -0.4331, -0.4345, -0.4359, -0.4372, -0.4386, + -0.44, -0.4414, -0.4427, -0.4441, -0.4455, -0.4469, -0.4482, -0.4496, -0.451, -0.4523, + -0.4537, -0.4551, -0.4564, -0.4578, -0.4592, -0.4605, -0.4619, -0.4633, -0.4646, -0.466, + -0.4673, -0.4687, -0.47, -0.4714, -0.4727, -0.4741, -0.4755, -0.4768, -0.4781, -0.4795, + -0.4808, -0.4822, -0.4835, -0.4849, -0.4862, -0.4876, -0.4889, -0.4902, -0.4916, -0.4929, + -0.4942, -0.4956, -0.4969, -0.4982, -0.4996, -0.5009, -0.5022, -0.5035, -0.5049, -0.5062, + -0.5075, -0.5088, -0.5102, -0.5115, -0.5128, -0.5141, -0.5154, -0.5167, -0.518, -0.5194, + -0.5207, -0.522, -0.5233, -0.5246, -0.5259, -0.5272, -0.5285, -0.5298, -0.5311, -0.5324, + -0.5337, -0.535, -0.5363, -0.5376, -0.5389, -0.5402, -0.5415, -0.5428, -0.544, -0.5453, + -0.5466, -0.5479, -0.5492, -0.5505, -0.5517, -0.553, -0.5543, -0.5556, -0.5568, -0.5581, + -0.5594, -0.5607, -0.5619, -0.5632, -0.5645, -0.5657, -0.567, -0.5683, -0.5695, -0.5708, + -0.572, -0.5733, -0.5746, -0.5758, -0.5771, -0.5783, -0.5796, -0.5808, -0.5821, -0.5833, + -0.5846, -0.5858, -0.587, -0.5883, -0.5895, -0.5908, -0.592, -0.5932, -0.5945, -0.5957, + -0.5969, -0.5982, -0.5994, -0.6006, -0.6018, -0.6031, -0.6043, -0.6055, -0.6067, -0.6079, + -0.6092, -0.6104, -0.6116, -0.6128, -0.614, -0.6152, -0.6164, -0.6176, -0.6189, -0.6201, + -0.6213, -0.6225, -0.6237, -0.6249, -0.6261, -0.6273, -0.6284, -0.6296, -0.6308, -0.632, + -0.6332, -0.6344, -0.6356, -0.6368, -0.6379, -0.6391, -0.6403, -0.6415, -0.6427, -0.6438, + -0.645, -0.6462, -0.6473, -0.6485, -0.6497, -0.6508, -0.652, -0.6532, -0.6543, -0.6555, + -0.6567, -0.6578, -0.659, -0.6601, -0.6613, -0.6624, -0.6636, -0.6647, -0.6659, -0.667, + -0.6681, -0.6693, -0.6704, -0.6716, -0.6727, -0.6738, -0.675, -0.6761, -0.6772, -0.6784, + -0.6795, -0.6806, -0.6817, -0.6828, -0.684, -0.6851, -0.6862, -0.6873, -0.6884, -0.6895, + -0.6907, -0.6918, -0.6929, -0.694, -0.6951, -0.6962, -0.6973, -0.6984, -0.6995, -0.7006, + -0.7017, -0.7028, -0.7038, -0.7049, -0.706, -0.7071, -0.7082, -0.7093, -0.7104, -0.7114, + -0.7125, -0.7136, -0.7147, -0.7157, -0.7168, -0.7179, -0.7189, -0.72, -0.7211, -0.7221, + -0.7232, -0.7242, -0.7253, -0.7264, -0.7274, -0.7285, -0.7295, -0.7306, -0.7316, -0.7327, + -0.7337, -0.7347, -0.7358, -0.7368, -0.7379, -0.7389, -0.7399, -0.741, -0.742, -0.743, + -0.744, -0.7451, -0.7461, -0.7471, -0.7481, -0.7491, -0.7502, -0.7512, -0.7522, -0.7532, + -0.7542, -0.7552, -0.7562, -0.7572, -0.7582, -0.7592, -0.7602, -0.7612, -0.7622, -0.7632, + -0.7642, -0.7652, -0.7662, -0.7671, -0.7681, -0.7691, -0.7701, -0.7711, -0.772, -0.773, + -0.774, -0.775, -0.7759, -0.7769, -0.7779, -0.7788, -0.7798, -0.7807, -0.7817, -0.7827, + -0.7836, -0.7846, -0.7855, -0.7865, -0.7874, -0.7883, -0.7893, -0.7902, -0.7912, -0.7921, + -0.793, -0.794, -0.7949, -0.7958, -0.7968, -0.7977, -0.7986, -0.7995, -0.8005, -0.8014, + -0.8023, -0.8032, -0.8041, -0.805, -0.8059, -0.8068, -0.8078, -0.8087, -0.8096, -0.8105, + -0.8114, -0.8123, -0.8131, -0.814, -0.8149, -0.8158, -0.8167, -0.8176, -0.8185, -0.8193, + -0.8202, -0.8211, -0.822, -0.8228, -0.8237, -0.8246, -0.8255, -0.8263, -0.8272, -0.828, + -0.8289, -0.8298, -0.8306, -0.8315, -0.8323, -0.8332, -0.834, -0.8349, -0.8357, -0.8365, + -0.8374, -0.8382, -0.8391, -0.8399, -0.8407, -0.8416, -0.8424, -0.8432, -0.844, -0.8449, + -0.8457, -0.8465, -0.8473, -0.8481, -0.8489, -0.8497, -0.8505, -0.8514, -0.8522, -0.853, + -0.8538, -0.8546, -0.8554, -0.8561, -0.8569, -0.8577, -0.8585, -0.8593, -0.8601, -0.8609, + -0.8616, -0.8624, -0.8632, -0.864, -0.8647, -0.8655, -0.8663, -0.867, -0.8678, -0.8686, + -0.8693, -0.8701, -0.8708, -0.8716, -0.8723, -0.8731, -0.8738, -0.8746, -0.8753, -0.8761, + -0.8768, -0.8775, -0.8783, -0.879, -0.8797, -0.8805, -0.8812, -0.8819, -0.8826, -0.8834, + -0.8841, -0.8848, -0.8855, -0.8862, -0.8869, -0.8876, -0.8883, -0.889, -0.8897, -0.8904, + -0.8911, -0.8918, -0.8925, -0.8932, -0.8939, -0.8946, -0.8953, -0.896, -0.8966, -0.8973, + -0.898, -0.8987, -0.8993, -0.9, -0.9007, -0.9013, -0.902, -0.9027, -0.9033, -0.904, + -0.9046, -0.9053, -0.9059, -0.9066, -0.9072, -0.9079, -0.9085, -0.9092, -0.9098, -0.9104, + -0.9111, -0.9117, -0.9123, -0.913, -0.9136, -0.9142, -0.9148, -0.9154, -0.9161, -0.9167, + -0.9173, -0.9179, -0.9185, -0.9191, -0.9197, -0.9203, -0.9209, -0.9215, -0.9221, -0.9227, + -0.9233, -0.9239, -0.9245, -0.925, -0.9256, -0.9262, -0.9268, -0.9274, -0.9279, -0.9285, + -0.9291, -0.9296, -0.9302, -0.9308, -0.9313, -0.9319, -0.9324, -0.933, -0.9335, -0.9341, + -0.9346, -0.9352, -0.9357, -0.9363, -0.9368, -0.9373, -0.9379, -0.9384, -0.9389, -0.9395, + -0.94, -0.9405, -0.941, -0.9415, -0.9421, -0.9426, -0.9431, -0.9436, -0.9441, -0.9446, + -0.9451, -0.9456, -0.9461, -0.9466, -0.9471, -0.9476, -0.9481, -0.9486, -0.949, -0.9495, + -0.95, -0.9505, -0.951, -0.9514, -0.9519, -0.9524, -0.9528, -0.9533, -0.9538, -0.9542, + -0.9547, -0.9551, -0.9556, -0.956, -0.9565, -0.9569, -0.9574, -0.9578, -0.9583, -0.9587, + -0.9591, -0.9596, -0.96, -0.9604, -0.9609, -0.9613, -0.9617, -0.9621, -0.9625, -0.963, + -0.9634, -0.9638, -0.9642, -0.9646, -0.965, -0.9654, -0.9658, -0.9662, -0.9666, -0.967, + -0.9674, -0.9678, -0.9681, -0.9685, -0.9689, -0.9693, -0.9697, -0.97, -0.9704, -0.9708, + -0.9711, -0.9715, -0.9719, -0.9722, -0.9726, -0.9729, -0.9733, -0.9736, -0.974, -0.9743, + -0.9747, -0.975, -0.9754, -0.9757, -0.976, -0.9764, -0.9767, -0.977, -0.9774, -0.9777, + -0.978, -0.9783, -0.9786, -0.9789, -0.9793, -0.9796, -0.9799, -0.9802, -0.9805, -0.9808, + -0.9811, -0.9814, -0.9817, -0.982, -0.9823, -0.9825, -0.9828, -0.9831, -0.9834, -0.9837, + -0.9839, -0.9842, -0.9845, -0.9847, -0.985, -0.9853, -0.9855, -0.9858, -0.9861, -0.9863, + -0.9866, -0.9868, -0.9871, -0.9873, -0.9875, -0.9878, -0.988, -0.9883, -0.9885, -0.9887, + -0.989, -0.9892, -0.9894, -0.9896, -0.9898, -0.9901, -0.9903, -0.9905, -0.9907, -0.9909, + -0.9911, -0.9913, -0.9915, -0.9917, -0.9919, -0.9921, -0.9923, -0.9925, -0.9927, -0.9929, + -0.993, -0.9932, -0.9934, -0.9936, -0.9937, -0.9939, -0.9941, -0.9942, -0.9944, -0.9946, + -0.9947, -0.9949, -0.995, -0.9952, -0.9953, -0.9955, -0.9956, -0.9958, -0.9959, -0.996, + -0.9962, -0.9963, -0.9964, -0.9966, -0.9967, -0.9968, -0.9969, -0.9971, -0.9972, -0.9973, + -0.9974, -0.9975, -0.9976, -0.9977, -0.9978, -0.9979, -0.998, -0.9981, -0.9982, -0.9983, + -0.9984, -0.9985, -0.9986, -0.9986, -0.9987, -0.9988, -0.9989, -0.9989, -0.999, -0.9991, + -0.9991, -0.9992, -0.9993, -0.9993, -0.9994, -0.9994, -0.9995, -0.9995, -0.9996, -0.9996, + -0.9997, -0.9997, -0.9997, -0.9998, -0.9998, -0.9998, -0.9999, -0.9999, -0.9999, -0.9999, + -0.9999, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -0.9999, -0.9999, -0.9999, -0.9999, -0.9999, -0.9998, + -0.9998, -0.9998, -0.9997, -0.9997, -0.9997, -0.9996, -0.9996, -0.9995, -0.9995, -0.9994, + -0.9994, -0.9993, -0.9993, -0.9992, -0.9991, -0.9991, -0.999, -0.9989, -0.9989, -0.9988, + -0.9987, -0.9986, -0.9986, -0.9985, -0.9984, -0.9983, -0.9982, -0.9981, -0.998, -0.9979, + -0.9978, -0.9977, -0.9976, -0.9975, -0.9974, -0.9973, -0.9972, -0.9971, -0.9969, -0.9968, + -0.9967, -0.9966, -0.9964, -0.9963, -0.9962, -0.996, -0.9959, -0.9958, -0.9956, -0.9955, + -0.9953, -0.9952, -0.995, -0.9949, -0.9947, -0.9946, -0.9944, -0.9942, -0.9941, -0.9939, + -0.9937, -0.9936, -0.9934, -0.9932, -0.993, -0.9929, -0.9927, -0.9925, -0.9923, -0.9921, + -0.9919, -0.9917, -0.9915, -0.9913, -0.9911, -0.9909, -0.9907, -0.9905, -0.9903, -0.9901, + -0.9898, -0.9896, -0.9894, -0.9892, -0.989, -0.9887, -0.9885, -0.9883, -0.988, -0.9878, + -0.9875, -0.9873, -0.9871, -0.9868, -0.9866, -0.9863, -0.9861, -0.9858, -0.9855, -0.9853, + -0.985, -0.9847, -0.9845, -0.9842, -0.9839, -0.9837, -0.9834, -0.9831, -0.9828, -0.9825, + -0.9823, -0.982, -0.9817, -0.9814, -0.9811, -0.9808, -0.9805, -0.9802, -0.9799, -0.9796, + -0.9793, -0.9789, -0.9786, -0.9783, -0.978, -0.9777, -0.9774, -0.977, -0.9767, -0.9764, + -0.976, -0.9757, -0.9754, -0.975, -0.9747, -0.9743, -0.974, -0.9736, -0.9733, -0.9729, + -0.9726, -0.9722, -0.9719, -0.9715, -0.9711, -0.9708, -0.9704, -0.97, -0.9697, -0.9693, + -0.9689, -0.9685, -0.9681, -0.9678, -0.9674, -0.967, -0.9666, -0.9662, -0.9658, -0.9654, + -0.965, -0.9646, -0.9642, -0.9638, -0.9634, -0.963, -0.9625, -0.9621, -0.9617, -0.9613, + -0.9609, -0.9604, -0.96, -0.9596, -0.9591, -0.9587, -0.9583, -0.9578, -0.9574, -0.9569, + -0.9565, -0.956, -0.9556, -0.9551, -0.9547, -0.9542, -0.9538, -0.9533, -0.9528, -0.9524, + -0.9519, -0.9514, -0.951, -0.9505, -0.95, -0.9495, -0.949, -0.9486, -0.9481, -0.9476, + -0.9471, -0.9466, -0.9461, -0.9456, -0.9451, -0.9446, -0.9441, -0.9436, -0.9431, -0.9426, + -0.9421, -0.9415, -0.941, -0.9405, -0.94, -0.9395, -0.9389, -0.9384, -0.9379, -0.9373, + -0.9368, -0.9363, -0.9357, -0.9352, -0.9346, -0.9341, -0.9335, -0.933, -0.9324, -0.9319, + -0.9313, -0.9308, -0.9302, -0.9296, -0.9291, -0.9285, -0.9279, -0.9274, -0.9268, -0.9262, + -0.9256, -0.925, -0.9245, -0.9239, -0.9233, -0.9227, -0.9221, -0.9215, -0.9209, -0.9203, + -0.9197, -0.9191, -0.9185, -0.9179, -0.9173, -0.9167, -0.9161, -0.9154, -0.9148, -0.9142, + -0.9136, -0.913, -0.9123, -0.9117, -0.9111, -0.9104, -0.9098, -0.9092, -0.9085, -0.9079, + -0.9072, -0.9066, -0.9059, -0.9053, -0.9046, -0.904, -0.9033, -0.9027, -0.902, -0.9013, + -0.9007, -0.9, -0.8993, -0.8987, -0.898, -0.8973, -0.8966, -0.896, -0.8953, -0.8946, + -0.8939, -0.8932, -0.8925, -0.8918, -0.8911, -0.8904, -0.8897, -0.889, -0.8883, -0.8876, + -0.8869, -0.8862, -0.8855, -0.8848, -0.8841, -0.8834, -0.8826, -0.8819, -0.8812, -0.8805, + -0.8797, -0.879, -0.8783, -0.8775, -0.8768, -0.8761, -0.8753, -0.8746, -0.8738, -0.8731, + -0.8723, -0.8716, -0.8708, -0.8701, -0.8693, -0.8686, -0.8678, -0.867, -0.8663, -0.8655, + -0.8647, -0.864, -0.8632, -0.8624, -0.8616, -0.8609, -0.8601, -0.8593, -0.8585, -0.8577, + -0.8569, -0.8561, -0.8554, -0.8546, -0.8538, -0.853, -0.8522, -0.8514, -0.8505, -0.8497, + -0.8489, -0.8481, -0.8473, -0.8465, -0.8457, -0.8449, -0.844, -0.8432, -0.8424, -0.8416, + -0.8407, -0.8399, -0.8391, -0.8382, -0.8374, -0.8365, -0.8357, -0.8349, -0.834, -0.8332, + -0.8323, -0.8315, -0.8306, -0.8298, -0.8289, -0.828, -0.8272, -0.8263, -0.8255, -0.8246, + -0.8237, -0.8228, -0.822, -0.8211, -0.8202, -0.8193, -0.8185, -0.8176, -0.8167, -0.8158, + -0.8149, -0.814, -0.8131, -0.8123, -0.8114, -0.8105, -0.8096, -0.8087, -0.8078, -0.8068, + -0.8059, -0.805, -0.8041, -0.8032, -0.8023, -0.8014, -0.8005, -0.7995, -0.7986, -0.7977, + -0.7968, -0.7958, -0.7949, -0.794, -0.793, -0.7921, -0.7912, -0.7902, -0.7893, -0.7883, + -0.7874, -0.7865, -0.7855, -0.7846, -0.7836, -0.7827, -0.7817, -0.7807, -0.7798, -0.7788, + -0.7779, -0.7769, -0.7759, -0.775, -0.774, -0.773, -0.772, -0.7711, -0.7701, -0.7691, + -0.7681, -0.7671, -0.7662, -0.7652, -0.7642, -0.7632, -0.7622, -0.7612, -0.7602, -0.7592, + -0.7582, -0.7572, -0.7562, -0.7552, -0.7542, -0.7532, -0.7522, -0.7512, -0.7502, -0.7491, + -0.7481, -0.7471, -0.7461, -0.7451, -0.744, -0.743, -0.742, -0.741, -0.7399, -0.7389, + -0.7379, -0.7368, -0.7358, -0.7347, -0.7337, -0.7327, -0.7316, -0.7306, -0.7295, -0.7285, + -0.7274, -0.7264, -0.7253, -0.7242, -0.7232, -0.7221, -0.7211, -0.72, -0.7189, -0.7179, + -0.7168, -0.7157, -0.7147, -0.7136, -0.7125, -0.7114, -0.7104, -0.7093, -0.7082, -0.7071, + -0.706, -0.7049, -0.7038, -0.7028, -0.7017, -0.7006, -0.6995, -0.6984, -0.6973, -0.6962, + -0.6951, -0.694, -0.6929, -0.6918, -0.6907, -0.6895, -0.6884, -0.6873, -0.6862, -0.6851, + -0.684, -0.6828, -0.6817, -0.6806, -0.6795, -0.6784, -0.6772, -0.6761, -0.675, -0.6738, + -0.6727, -0.6716, -0.6704, -0.6693, -0.6681, -0.667, -0.6659, -0.6647, -0.6636, -0.6624, + -0.6613, -0.6601, -0.659, -0.6578, -0.6567, -0.6555, -0.6543, -0.6532, -0.652, -0.6508, + -0.6497, -0.6485, -0.6473, -0.6462, -0.645, -0.6438, -0.6427, -0.6415, -0.6403, -0.6391, + -0.6379, -0.6368, -0.6356, -0.6344, -0.6332, -0.632, -0.6308, -0.6296, -0.6284, -0.6273, + -0.6261, -0.6249, -0.6237, -0.6225, -0.6213, -0.6201, -0.6189, -0.6176, -0.6164, -0.6152, + -0.614, -0.6128, -0.6116, -0.6104, -0.6092, -0.6079, -0.6067, -0.6055, -0.6043, -0.6031, + -0.6018, -0.6006, -0.5994, -0.5982, -0.5969, -0.5957, -0.5945, -0.5932, -0.592, -0.5908, + -0.5895, -0.5883, -0.587, -0.5858, -0.5846, -0.5833, -0.5821, -0.5808, -0.5796, -0.5783, + -0.5771, -0.5758, -0.5746, -0.5733, -0.572, -0.5708, -0.5695, -0.5683, -0.567, -0.5657, + -0.5645, -0.5632, -0.5619, -0.5607, -0.5594, -0.5581, -0.5568, -0.5556, -0.5543, -0.553, + -0.5517, -0.5505, -0.5492, -0.5479, -0.5466, -0.5453, -0.544, -0.5428, -0.5415, -0.5402, + -0.5389, -0.5376, -0.5363, -0.535, -0.5337, -0.5324, -0.5311, -0.5298, -0.5285, -0.5272, + -0.5259, -0.5246, -0.5233, -0.522, -0.5207, -0.5194, -0.518, -0.5167, -0.5154, -0.5141, + -0.5128, -0.5115, -0.5102, -0.5088, -0.5075, -0.5062, -0.5049, -0.5035, -0.5022, -0.5009, + -0.4996, -0.4982, -0.4969, -0.4956, -0.4942, -0.4929, -0.4916, -0.4902, -0.4889, -0.4876, + -0.4862, -0.4849, -0.4835, -0.4822, -0.4808, -0.4795, -0.4781, -0.4768, -0.4755, -0.4741, + -0.4727, -0.4714, -0.47, -0.4687, -0.4673, -0.466, -0.4646, -0.4633, -0.4619, -0.4605, + -0.4592, -0.4578, -0.4564, -0.4551, -0.4537, -0.4523, -0.451, -0.4496, -0.4482, -0.4469, + -0.4455, -0.4441, -0.4427, -0.4414, -0.44, -0.4386, -0.4372, -0.4359, -0.4345, -0.4331, + -0.4317, -0.4303, -0.4289, -0.4276, -0.4262, -0.4248, -0.4234, -0.422, -0.4206, -0.4192, + -0.4178, -0.4164, -0.415, -0.4136, -0.4122, -0.4108, -0.4094, -0.408, -0.4066, -0.4052, + -0.4038, -0.4024, -0.401, -0.3996, -0.3982, -0.3968, -0.3954, -0.394, -0.3926, -0.3912, + -0.3898, -0.3883, -0.3869, -0.3855, -0.3841, -0.3827, -0.3813, -0.3798, -0.3784, -0.377, + -0.3756, -0.3742, -0.3727, -0.3713, -0.3699, -0.3685, -0.367, -0.3656, -0.3642, -0.3628, + -0.3613, -0.3599, -0.3585, -0.357, -0.3556, -0.3542, -0.3527, -0.3513, -0.3499, -0.3484, + -0.347, -0.3455, -0.3441, -0.3427, -0.3412, -0.3398, -0.3383, -0.3369, -0.3354, -0.334, + -0.3326, -0.3311, -0.3297, -0.3282, -0.3268, -0.3253, -0.3239, -0.3224, -0.321, -0.3195, + -0.318, -0.3166, -0.3151, -0.3137, -0.3122, -0.3108, -0.3093, -0.3078, -0.3064, -0.3049, + -0.3035, -0.302, -0.3005, -0.2991, -0.2976, -0.2962, -0.2947, -0.2932, -0.2918, -0.2903, + -0.2888, -0.2873, -0.2859, -0.2844, -0.2829, -0.2815, -0.28, -0.2785, -0.277, -0.2756, + -0.2741, -0.2726, -0.2711, -0.2697, -0.2682, -0.2667, -0.2652, -0.2638, -0.2623, -0.2608, + -0.2593, -0.2578, -0.2563, -0.2549, -0.2534, -0.2519, -0.2504, -0.2489, -0.2474, -0.246, + -0.2445, -0.243, -0.2415, -0.24, -0.2385, -0.237, -0.2355, -0.234, -0.2326, -0.2311, + -0.2296, -0.2281, -0.2266, -0.2251, -0.2236, -0.2221, -0.2206, -0.2191, -0.2176, -0.2161, + -0.2146, -0.2131, -0.2116, -0.2101, -0.2086, -0.2071, -0.2056, -0.2041, -0.2026, -0.2011, + -0.1996, -0.1981, -0.1966, -0.1951, -0.1936, -0.1921, -0.1906, -0.1891, -0.1876, -0.1861, + -0.1845, -0.183, -0.1815, -0.18, -0.1785, -0.177, -0.1755, -0.174, -0.1725, -0.171, + -0.1695, -0.1679, -0.1664, -0.1649, -0.1634, -0.1619, -0.1604, -0.1589, -0.1573, -0.1558, + -0.1543, -0.1528, -0.1513, -0.1498, -0.1482, -0.1467, -0.1452, -0.1437, -0.1422, -0.1407, + -0.1391, -0.1376, -0.1361, -0.1346, -0.1331, -0.1315, -0.13, -0.1285, -0.127, -0.1255, + -0.1239, -0.1224, -0.1209, -0.1194, -0.1178, -0.1163, -0.1148, -0.1133, -0.1117, -0.1102, + -0.1087, -0.1072, -0.1056, -0.1041, -0.1026, -0.1011, -0.09954, -0.09802, -0.09649, -0.09496, + -0.09344, -0.09191, -0.09038, -0.08885, -0.08733, -0.0858, -0.08427, -0.08274, -0.08121, -0.07968, + -0.07815, -0.07662, -0.07509, -0.07356, -0.07203, -0.0705, -0.06897, -0.06744, -0.06591, -0.06438, + -0.06285, -0.06132, -0.05979, -0.05826, -0.05673, -0.0552, -0.05366, -0.05213, -0.0506, -0.04907, + -0.04754, -0.046, -0.04447, -0.04294, -0.04141, -0.03987, -0.03834, -0.03681, -0.03527, -0.03374, + -0.03221, -0.03067, -0.02914, -0.02761, -0.02607, -0.02454, -0.02301, -0.02147, -0.01994, -0.01841, + -0.01687, -0.01534, -0.01381, -0.01227, -0.01074, -0.009204, -0.00767, -0.006136, -0.004602, -0.003068, + -0.001534, -1.837e-16, 0.001534, 0.003068, 0.004602, 0.006136, 0.00767, 0.009204, 0.01074, 0.01227, + 0.01381, 0.01534, 0.01687, 0.01841, 0.01994, 0.02147, 0.02301, 0.02454, 0.02607, 0.02761, + 0.02914, 0.03067, 0.03221, 0.03374, 0.03527, 0.03681, 0.03834, 0.03987, 0.04141, 0.04294, + 0.04447, 0.046, 0.04754, 0.04907, 0.0506, 0.05213, 0.05366, 0.0552, 0.05673, 0.05826, + 0.05979, 0.06132, 0.06285, 0.06438, 0.06591, 0.06744, 0.06897, 0.0705, 0.07203, 0.07356, + 0.07509, 0.07662, 0.07815, 0.07968, 0.08121, 0.08274, 0.08427, 0.0858, 0.08733, 0.08885, + 0.09038, 0.09191, 0.09344, 0.09496, 0.09649, 0.09802, 0.09954, 0.1011, 0.1026, 0.1041, + 0.1056, 0.1072, 0.1087, 0.1102, 0.1117, 0.1133, 0.1148, 0.1163, 0.1178, 0.1194, + 0.1209, 0.1224, 0.1239, 0.1255, 0.127, 0.1285, 0.13, 0.1315, 0.1331, 0.1346, + 0.1361, 0.1376, 0.1391, 0.1407, 0.1422, 0.1437, 0.1452, 0.1467, 0.1482, 0.1498, + 0.1513, 0.1528, 0.1543, 0.1558, 0.1573, 0.1589, 0.1604, 0.1619, 0.1634, 0.1649, + 0.1664, 0.1679, 0.1695, 0.171, 0.1725, 0.174, 0.1755, 0.177, 0.1785, 0.18, + 0.1815, 0.183, 0.1845, 0.1861, 0.1876, 0.1891, 0.1906, 0.1921, 0.1936, 0.1951, + 0.1966, 0.1981, 0.1996, 0.2011, 0.2026, 0.2041, 0.2056, 0.2071, 0.2086, 0.2101, + 0.2116, 0.2131, 0.2146, 0.2161, 0.2176, 0.2191, 0.2206, 0.2221, 0.2236, 0.2251, + 0.2266, 0.2281, 0.2296, 0.2311, 0.2326, 0.234, 0.2355, 0.237, 0.2385, 0.24, + 0.2415, 0.243, 0.2445, 0.246, 0.2474, 0.2489, 0.2504, 0.2519, 0.2534, 0.2549, + 0.2563, 0.2578, 0.2593, 0.2608, 0.2623, 0.2638, 0.2652, 0.2667, 0.2682, 0.2697, + 0.2711, 0.2726, 0.2741, 0.2756, 0.277, 0.2785, 0.28, 0.2815, 0.2829, 0.2844, + 0.2859, 0.2873, 0.2888, 0.2903, 0.2918, 0.2932, 0.2947, 0.2962, 0.2976, 0.2991, + 0.3005, 0.302, 0.3035, 0.3049, 0.3064, 0.3078, 0.3093, 0.3108, 0.3122, 0.3137, + 0.3151, 0.3166, 0.318, 0.3195, 0.321, 0.3224, 0.3239, 0.3253, 0.3268, 0.3282, + 0.3297, 0.3311, 0.3326, 0.334, 0.3354, 0.3369, 0.3383, 0.3398, 0.3412, 0.3427, + 0.3441, 0.3455, 0.347, 0.3484, 0.3499, 0.3513, 0.3527, 0.3542, 0.3556, 0.357, + 0.3585, 0.3599, 0.3613, 0.3628, 0.3642, 0.3656, 0.367, 0.3685, 0.3699, 0.3713, + 0.3727, 0.3742, 0.3756, 0.377, 0.3784, 0.3798, 0.3813, 0.3827, 0.3841, 0.3855, + 0.3869, 0.3883, 0.3898, 0.3912, 0.3926, 0.394, 0.3954, 0.3968, 0.3982, 0.3996, + 0.401, 0.4024, 0.4038, 0.4052, 0.4066, 0.408, 0.4094, 0.4108, 0.4122, 0.4136, + 0.415, 0.4164, 0.4178, 0.4192, 0.4206, 0.422, 0.4234, 0.4248, 0.4262, 0.4276, + 0.4289, 0.4303, 0.4317, 0.4331, 0.4345, 0.4359, 0.4372, 0.4386, 0.44, 0.4414, + 0.4427, 0.4441, 0.4455, 0.4469, 0.4482, 0.4496, 0.451, 0.4523, 0.4537, 0.4551, + 0.4564, 0.4578, 0.4592, 0.4605, 0.4619, 0.4633, 0.4646, 0.466, 0.4673, 0.4687, + 0.47, 0.4714, 0.4727, 0.4741, 0.4755, 0.4768, 0.4781, 0.4795, 0.4808, 0.4822, + 0.4835, 0.4849, 0.4862, 0.4876, 0.4889, 0.4902, 0.4916, 0.4929, 0.4942, 0.4956, + 0.4969, 0.4982, 0.4996, 0.5009, 0.5022, 0.5035, 0.5049, 0.5062, 0.5075, 0.5088, + 0.5102, 0.5115, 0.5128, 0.5141, 0.5154, 0.5167, 0.518, 0.5194, 0.5207, 0.522, + 0.5233, 0.5246, 0.5259, 0.5272, 0.5285, 0.5298, 0.5311, 0.5324, 0.5337, 0.535, + 0.5363, 0.5376, 0.5389, 0.5402, 0.5415, 0.5428, 0.544, 0.5453, 0.5466, 0.5479, + 0.5492, 0.5505, 0.5517, 0.553, 0.5543, 0.5556, 0.5568, 0.5581, 0.5594, 0.5607, + 0.5619, 0.5632, 0.5645, 0.5657, 0.567, 0.5683, 0.5695, 0.5708, 0.572, 0.5733, + 0.5746, 0.5758, 0.5771, 0.5783, 0.5796, 0.5808, 0.5821, 0.5833, 0.5846, 0.5858, + 0.587, 0.5883, 0.5895, 0.5908, 0.592, 0.5932, 0.5945, 0.5957, 0.5969, 0.5982, + 0.5994, 0.6006, 0.6018, 0.6031, 0.6043, 0.6055, 0.6067, 0.6079, 0.6092, 0.6104, + 0.6116, 0.6128, 0.614, 0.6152, 0.6164, 0.6176, 0.6189, 0.6201, 0.6213, 0.6225, + 0.6237, 0.6249, 0.6261, 0.6273, 0.6284, 0.6296, 0.6308, 0.632, 0.6332, 0.6344, + 0.6356, 0.6368, 0.6379, 0.6391, 0.6403, 0.6415, 0.6427, 0.6438, 0.645, 0.6462, + 0.6473, 0.6485, 0.6497, 0.6508, 0.652, 0.6532, 0.6543, 0.6555, 0.6567, 0.6578, + 0.659, 0.6601, 0.6613, 0.6624, 0.6636, 0.6647, 0.6659, 0.667, 0.6681, 0.6693, + 0.6704, 0.6716, 0.6727, 0.6738, 0.675, 0.6761, 0.6772, 0.6784, 0.6795, 0.6806, + 0.6817, 0.6828, 0.684, 0.6851, 0.6862, 0.6873, 0.6884, 0.6895, 0.6907, 0.6918, + 0.6929, 0.694, 0.6951, 0.6962, 0.6973, 0.6984, 0.6995, 0.7006, 0.7017, 0.7028, + 0.7038, 0.7049, 0.706, 0.7071, 0.7082, 0.7093, 0.7104, 0.7114, 0.7125, 0.7136, + 0.7147, 0.7157, 0.7168, 0.7179, 0.7189, 0.72, 0.7211, 0.7221, 0.7232, 0.7242, + 0.7253, 0.7264, 0.7274, 0.7285, 0.7295, 0.7306, 0.7316, 0.7327, 0.7337, 0.7347, + 0.7358, 0.7368, 0.7379, 0.7389, 0.7399, 0.741, 0.742, 0.743, 0.744, 0.7451, + 0.7461, 0.7471, 0.7481, 0.7491, 0.7502, 0.7512, 0.7522, 0.7532, 0.7542, 0.7552, + 0.7562, 0.7572, 0.7582, 0.7592, 0.7602, 0.7612, 0.7622, 0.7632, 0.7642, 0.7652, + 0.7662, 0.7671, 0.7681, 0.7691, 0.7701, 0.7711, 0.772, 0.773, 0.774, 0.775, + 0.7759, 0.7769, 0.7779, 0.7788, 0.7798, 0.7807, 0.7817, 0.7827, 0.7836, 0.7846, + 0.7855, 0.7865, 0.7874, 0.7883, 0.7893, 0.7902, 0.7912, 0.7921, 0.793, 0.794, + 0.7949, 0.7958, 0.7968, 0.7977, 0.7986, 0.7995, 0.8005, 0.8014, 0.8023, 0.8032, + 0.8041, 0.805, 0.8059, 0.8068, 0.8078, 0.8087, 0.8096, 0.8105, 0.8114, 0.8123, + 0.8131, 0.814, 0.8149, 0.8158, 0.8167, 0.8176, 0.8185, 0.8193, 0.8202, 0.8211, + 0.822, 0.8228, 0.8237, 0.8246, 0.8255, 0.8263, 0.8272, 0.828, 0.8289, 0.8298, + 0.8306, 0.8315, 0.8323, 0.8332, 0.834, 0.8349, 0.8357, 0.8365, 0.8374, 0.8382, + 0.8391, 0.8399, 0.8407, 0.8416, 0.8424, 0.8432, 0.844, 0.8449, 0.8457, 0.8465, + 0.8473, 0.8481, 0.8489, 0.8497, 0.8505, 0.8514, 0.8522, 0.853, 0.8538, 0.8546, + 0.8554, 0.8561, 0.8569, 0.8577, 0.8585, 0.8593, 0.8601, 0.8609, 0.8616, 0.8624, + 0.8632, 0.864, 0.8647, 0.8655, 0.8663, 0.867, 0.8678, 0.8686, 0.8693, 0.8701, + 0.8708, 0.8716, 0.8723, 0.8731, 0.8738, 0.8746, 0.8753, 0.8761, 0.8768, 0.8775, + 0.8783, 0.879, 0.8797, 0.8805, 0.8812, 0.8819, 0.8826, 0.8834, 0.8841, 0.8848, + 0.8855, 0.8862, 0.8869, 0.8876, 0.8883, 0.889, 0.8897, 0.8904, 0.8911, 0.8918, + 0.8925, 0.8932, 0.8939, 0.8946, 0.8953, 0.896, 0.8966, 0.8973, 0.898, 0.8987, + 0.8993, 0.9, 0.9007, 0.9013, 0.902, 0.9027, 0.9033, 0.904, 0.9046, 0.9053, + 0.9059, 0.9066, 0.9072, 0.9079, 0.9085, 0.9092, 0.9098, 0.9104, 0.9111, 0.9117, + 0.9123, 0.913, 0.9136, 0.9142, 0.9148, 0.9154, 0.9161, 0.9167, 0.9173, 0.9179, + 0.9185, 0.9191, 0.9197, 0.9203, 0.9209, 0.9215, 0.9221, 0.9227, 0.9233, 0.9239, + 0.9245, 0.925, 0.9256, 0.9262, 0.9268, 0.9274, 0.9279, 0.9285, 0.9291, 0.9296, + 0.9302, 0.9308, 0.9313, 0.9319, 0.9324, 0.933, 0.9335, 0.9341, 0.9346, 0.9352, + 0.9357, 0.9363, 0.9368, 0.9373, 0.9379, 0.9384, 0.9389, 0.9395, 0.94, 0.9405, + 0.941, 0.9415, 0.9421, 0.9426, 0.9431, 0.9436, 0.9441, 0.9446, 0.9451, 0.9456, + 0.9461, 0.9466, 0.9471, 0.9476, 0.9481, 0.9486, 0.949, 0.9495, 0.95, 0.9505, + 0.951, 0.9514, 0.9519, 0.9524, 0.9528, 0.9533, 0.9538, 0.9542, 0.9547, 0.9551, + 0.9556, 0.956, 0.9565, 0.9569, 0.9574, 0.9578, 0.9583, 0.9587, 0.9591, 0.9596, + 0.96, 0.9604, 0.9609, 0.9613, 0.9617, 0.9621, 0.9625, 0.963, 0.9634, 0.9638, + 0.9642, 0.9646, 0.965, 0.9654, 0.9658, 0.9662, 0.9666, 0.967, 0.9674, 0.9678, + 0.9681, 0.9685, 0.9689, 0.9693, 0.9697, 0.97, 0.9704, 0.9708, 0.9711, 0.9715, + 0.9719, 0.9722, 0.9726, 0.9729, 0.9733, 0.9736, 0.974, 0.9743, 0.9747, 0.975, + 0.9754, 0.9757, 0.976, 0.9764, 0.9767, 0.977, 0.9774, 0.9777, 0.978, 0.9783, + 0.9786, 0.9789, 0.9793, 0.9796, 0.9799, 0.9802, 0.9805, 0.9808, 0.9811, 0.9814, + 0.9817, 0.982, 0.9823, 0.9825, 0.9828, 0.9831, 0.9834, 0.9837, 0.9839, 0.9842, + 0.9845, 0.9847, 0.985, 0.9853, 0.9855, 0.9858, 0.9861, 0.9863, 0.9866, 0.9868, + 0.9871, 0.9873, 0.9875, 0.9878, 0.988, 0.9883, 0.9885, 0.9887, 0.989, 0.9892, + 0.9894, 0.9896, 0.9898, 0.9901, 0.9903, 0.9905, 0.9907, 0.9909, 0.9911, 0.9913, + 0.9915, 0.9917, 0.9919, 0.9921, 0.9923, 0.9925, 0.9927, 0.9929, 0.993, 0.9932, + 0.9934, 0.9936, 0.9937, 0.9939, 0.9941, 0.9942, 0.9944, 0.9946, 0.9947, 0.9949, + 0.995, 0.9952, 0.9953, 0.9955, 0.9956, 0.9958, 0.9959, 0.996, 0.9962, 0.9963, + 0.9964, 0.9966, 0.9967, 0.9968, 0.9969, 0.9971, 0.9972, 0.9973, 0.9974, 0.9975, + 0.9976, 0.9977, 0.9978, 0.9979, 0.998, 0.9981, 0.9982, 0.9983, 0.9984, 0.9985, + 0.9986, 0.9986, 0.9987, 0.9988, 0.9989, 0.9989, 0.999, 0.9991, 0.9991, 0.9992, + 0.9993, 0.9993, 0.9994, 0.9994, 0.9995, 0.9995, 0.9996, 0.9996, 0.9997, 0.9997, + 0.9997, 0.9998, 0.9998, 0.9998, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 1, + 1, 1, 1, 1, 1, 1}; + + const static uint64 mask = max-1; + return logvals[h.first&mask]*cosvals[h.second&mask]; + + // Note that we are just using the Box-Muller transform to compute the result. In + // particular, we are doing this (where u1 and u2 are uniform random variables in + // the range [0,1]): + // return sqrt(-2*log(u1)) * cos(2*PI*u2); + // It is just that we use table lookups to avoid calling sqrt(), log() and cos(). + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANDOM_HAsHING_Hh_ + diff --git a/ml/dlib/dlib/general_hash/random_hashing_abstract.h b/ml/dlib/dlib/general_hash/random_hashing_abstract.h new file mode 100644 index 000000000..3d196d8c0 --- /dev/null +++ b/ml/dlib/dlib/general_hash/random_hashing_abstract.h @@ -0,0 +1,58 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RANDOM_HAsHING_ABSTRACT_Hh_ +#ifdef DLIB_RANDOM_HAsHING_ABSTRACT_Hh_ + +#include "random_hashing_abstract.h" +#include "murmur_hash3.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + double uniform_random_hash ( + const uint64& k1, + const uint64& k2, + const uint64& k3 + ); + /*! + ensures + - This function uses hashing to generate uniform random values in the range [0,1). + - To define this function precisely, assume we have an arbitrary sequence of + input triplets. Then calling uniform_random_hash() on each of them should + result in a sequence of double values that look like numbers sampled + independently and uniformly at random from the interval [0,1). This is true + even if there is some simple pattern in the inputs. For example, (0,0,0), + (1,0,0), (2,0,0), (3,0,0), etc. + - This function is deterministic. That is, the same output is always returned + when given the same input. + !*/ + +// ---------------------------------------------------------------------------------------- + + double gaussian_random_hash ( + const uint64& k1, + const uint64& k2, + const uint64& k3 + ); + /*! + ensures + - This function uses hashing to generate Gaussian distributed random values + with mean 0 and variance 1. + - To define this function precisely, assume we have an arbitrary sequence of + input triplets. Then calling gaussian_random_hash() on each of them should + result in a sequence of double values that look like numbers sampled + independently from a standard normal distribution. This is true even if + there is some simple pattern in the inputs. For example, (0,0,0), (1,0,0), + (2,0,0), (3,0,0), etc. + - This function is deterministic. That is, the same output is always returned + when given the same input. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANDOM_HAsHING_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/geometry.h b/ml/dlib/dlib/geometry.h new file mode 100644 index 000000000..9d326b150 --- /dev/null +++ b/ml/dlib/dlib/geometry.h @@ -0,0 +1,14 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GEOMETRy_HEADER +#define DLIB_GEOMETRy_HEADER + +#include "geometry/rectangle.h" +#include "geometry/drectangle.h" +#include "geometry/vector.h" +#include "geometry/border_enumerator.h" +#include "geometry/point_transforms.h" + +#endif // DLIB_GEOMETRy_HEADER + + diff --git a/ml/dlib/dlib/geometry/border_enumerator.h b/ml/dlib/dlib/geometry/border_enumerator.h new file mode 100644 index 000000000..0c69cc37f --- /dev/null +++ b/ml/dlib/dlib/geometry/border_enumerator.h @@ -0,0 +1,186 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BORDER_EnUMERATOR_H_ +#define DLIB_BORDER_EnUMERATOR_H_ + +#include "border_enumerator_abstract.h" +#include "rectangle.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class border_enumerator + { + public: + border_enumerator( + ) + { + reset(); + } + + border_enumerator( + const rectangle& rect_, + unsigned long border_size + ) : + rect(rect_), + inner_rect(shrink_rect(rect_, border_size)) + { + reset(); + } + + border_enumerator( + const rectangle& rect_, + const rectangle& non_border_region + ) : + rect(rect_), + inner_rect(non_border_region.intersect(rect)) + { + reset(); + } + + void reset ( + ) + { + // make the four rectangles that surround inner_rect and intersect them + // with rect. + bleft = rect.intersect(rectangle(std::numeric_limits::min(), + std::numeric_limits::min(), + inner_rect.left()-1, + std::numeric_limits::max())); + + bright = rect.intersect(rectangle(inner_rect.right()+1, + std::numeric_limits::min(), + std::numeric_limits::max(), + std::numeric_limits::max())); + + btop = rect.intersect(rectangle(inner_rect.left(), + std::numeric_limits::min(), + inner_rect.right(), + inner_rect.top()-1)); + + bbottom = rect.intersect(rectangle(inner_rect.left(), + inner_rect.bottom()+1, + inner_rect.right(), + std::numeric_limits::max())); + + p = bleft.tl_corner(); + p.x() -= 1; + + mode = atleft; + } + + bool at_start ( + ) const + { + point temp = bleft.tl_corner(); + temp.x() -=1; + return temp == p; + } + + bool current_element_valid( + ) const + { + return rect.contains(p); + } + + bool move_next() + { + if (mode == atleft) + { + if (advance_point(bleft, p)) + return true; + + mode = attop; + p = btop.tl_corner(); + p.x() -= 1; + } + if (mode == attop) + { + if (advance_point(btop, p)) + return true; + + mode = atright; + p = bright.tl_corner(); + p.x() -= 1; + } + if (mode == atright) + { + if (advance_point(bright, p)) + return true; + + mode = atbottom; + p = bbottom.tl_corner(); + p.x() -= 1; + } + + if (advance_point(bbottom, p)) + return true; + + // put p outside rect since there are no more points to enumerate + p = rect.br_corner(); + p.x() += 1; + + return false; + } + + size_t size ( + ) const + { + return rect.area() - inner_rect.area(); + } + + const point& element ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid(), + "\t point border_enumerator::element()" + << "\n\t This function can't be called unless the element is valid." + << "\n\t this: " << this + ); + + return p; + } + + private: + + bool advance_point ( + const rectangle& r, + point& p + ) const + { + p.x() += 1; + if (p.x() > r.right()) + { + p.x() = r.left(); + p.y() += 1; + } + + return r.contains(p); + } + + point p; + rectangle rect; + rectangle inner_rect; // the non-border regions of rect + + enum emode + { + atleft, + atright, + atbottom, + attop + }; + + emode mode; + + rectangle btop, bleft, bright, bbottom; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BORDER_EnUMERATOR_H_ + diff --git a/ml/dlib/dlib/geometry/border_enumerator_abstract.h b/ml/dlib/dlib/geometry/border_enumerator_abstract.h new file mode 100644 index 000000000..11118d571 --- /dev/null +++ b/ml/dlib/dlib/geometry/border_enumerator_abstract.h @@ -0,0 +1,126 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BORDER_EnUMERATOR_ABSTRACT_H_ +#ifdef DLIB_BORDER_EnUMERATOR_ABSTRACT_H_ + +#include "rectangle_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class border_enumerator + { + /*! + POINTERS AND REFERENCES TO INTERNAL DATA + All operations on this object other than calling element() invalidate + pointers and references to internal data. + + WHAT THIS OBJECT REPRESENTS + This object is an enumerator over the border points of a rectangle. + !*/ + public: + + border_enumerator( + ); + /*! + ensures + - #move_next() == false + (i.e. this object is "empty" and won't enumerate anything) + - current_element_valid() == false + - at_start() == true + - size() == 0 + !*/ + + border_enumerator( + const rectangle& rect, + unsigned long border_size + ); + /*! + ensures + - This object will enumerate over the border points which are inside rect + but within border_size of the edge. For example, if border_size == 1 + then it enumerates over the single point wide strip of points all around + the interior edge of rect. + - current_element_valid() == false + - at_start() == true + - size() == rect.area() - shrink_rect(rect,border_size).area() + (i.e. the number of points in the border area of rect) + !*/ + + border_enumerator( + const rectangle& rect, + const rectangle& non_border_region + ); + /*! + ensures + - This object will enumerate over all points which are in rect but + not in non_border_region. + - current_element_valid() == false + - at_start() == true + - size() == rect.area() - rect.intersect(non_border_region).area() + !*/ + + bool at_start ( + ) const; + /*! + ensures + - returns true if *this represents one position before the first point + (this would also make the current element invalid) else returns false + !*/ + + void reset ( + ); + /*! + ensures + - #current_element_valid() == false + - #at_start() == true + !*/ + + bool current_element_valid( + ) const; + /*! + ensures + - returns true if we are currently at a valid element else + returns false + !*/ + + bool move_next( + ); + /*! + ensures + - moves to the next element. i.e. #element() will now + return the next border point. + - the return value will be equal to #current_element_valid() + - #at_start() == false + + - returns true if there is another element + - returns false if there are no more elements in the container + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns the number of border points + !*/ + + const point& element ( + ) const; + /*! + requires + - current_element_valid() == true + ensures + - returns the current border point + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BORDER_EnUMERATOR_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/geometry/drectangle.h b/ml/dlib/dlib/geometry/drectangle.h new file mode 100644 index 000000000..9ccc5c0ee --- /dev/null +++ b/ml/dlib/dlib/geometry/drectangle.h @@ -0,0 +1,488 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DRECTANGLe_ +#define DLIB_DRECTANGLe_ + +#include "drectangle_abstract.h" +#include "rectangle.h" + +namespace dlib +{ + class drectangle; + drectangle operator* ( + const drectangle& rect, + const double& scale + ); + +// ---------------------------------------------------------------------------------------- + + class drectangle + { + public: + + drectangle ( + ) : l(0), t(0), r(-1), b(-1) {} + + drectangle ( + double l_, + double t_, + double r_, + double b_ + ) : + l(l_), + t(t_), + r(r_), + b(b_) + {} + + drectangle ( + const dlib::vector& p + ) : + l(p.x()), + t(p.y()), + r(p.x()), + b(p.y()) + { + } + + template + drectangle ( + const vector& p1, + const vector& p2 + ) + { + *this = drectangle(p1) + drectangle(p2); + } + + drectangle ( + const rectangle& rect + ) : l(rect.left()), + t(rect.top()), + r(rect.right()), + b(rect.bottom()) {} + + operator rectangle ( + ) const + { + return rectangle((long)std::floor(l+0.5), + (long)std::floor(t+0.5), + (long)std::floor(r+0.5), + (long)std::floor(b+0.5)); + } + + double left() const { return l; } + double top() const { return t; } + double right() const { return r; } + double bottom() const { return b; } + + double& left() { return l; } + double& top() { return t; } + double& right() { return r; } + double& bottom() { return b; } + + const dlib::vector tl_corner ( + ) const { return dlib::vector(left(), top()); } + + const dlib::vector bl_corner ( + ) const { return dlib::vector(left(), bottom()); } + + const dlib::vector tr_corner ( + ) const { return dlib::vector(right(), top()); } + + const dlib::vector br_corner ( + ) const { return dlib::vector(right(), bottom()); } + + double width ( + ) const + { + if (is_empty()) + return 0; + else + return r - l + 1; + } + + double height ( + ) const + { + if (is_empty()) + return 0; + else + return b - t + 1; + } + + double area ( + ) const + { + return width()*height(); + } + + bool is_empty ( + ) const { return (t > b || l > r); } + + drectangle operator + ( + const drectangle& rhs + ) const + { + if (rhs.is_empty()) + return *this; + else if (is_empty()) + return rhs; + + return drectangle ( + std::min(l,rhs.l), + std::min(t,rhs.t), + std::max(r,rhs.r), + std::max(b,rhs.b) + ); + } + + drectangle intersect ( + const drectangle& rhs + ) const + { + return drectangle ( + std::max(l,rhs.l), + std::max(t,rhs.t), + std::min(r,rhs.r), + std::min(b,rhs.b) + ); + } + + bool contains ( + const dlib::vector& p + ) const + { + if (p.x() < l || p.x() > r || p.y() < t || p.y() > b) + return false; + return true; + } + + bool contains ( + const drectangle& rect + ) const + { + if (rect.is_empty()) + return true; + if (l <= rect.left() && + r >= rect.right() && + t <= rect.top() && + b >= rect.bottom()) + return true; + return false; + } + + drectangle& operator *= ( + const double& scale + ) + { + *this = *this*scale; + return *this; + } + + drectangle& operator /= ( + const double& scale + ) + { + *this = *this*(1.0/scale); + return *this; + } + + drectangle& operator += ( + const dlib::vector& p + ) + { + *this = *this + drectangle(p); + return *this; + } + + bool operator== ( + const drectangle& rect + ) const + { + return (l == rect.l) && (t == rect.t) && (r == rect.r) && (b == rect.b); + } + + bool operator!= ( + const drectangle& rect + ) const + { + return !(*this == rect); + } + + private: + double l; + double t; + double r; + double b; + }; + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const drectangle& item, + std::ostream& out + ) + { + try + { + serialize(item.left(),out); + serialize(item.top(),out); + serialize(item.right(),out); + serialize(item.bottom(),out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing an object of type drectangle"); + } + } + + inline void deserialize ( + drectangle& item, + std::istream& in + ) + { + try + { + deserialize(item.left(),in); + deserialize(item.top(),in); + deserialize(item.right(),in); + deserialize(item.bottom(),in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing an object of type drectangle"); + } + } + + inline std::ostream& operator<< ( + std::ostream& out, + const drectangle& item + ) + { + out << "[(" << item.left() << ", " << item.top() << ") (" << item.right() << ", " << item.bottom() << ")]"; + return out; + } + + inline std::istream& operator>>( + std::istream& in, + drectangle& item + ) + { + // ignore any whitespace + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n') + in.get(); + // now eat the leading '[' character + if (in.get() != '[') + { + in.setstate(in.rdstate() | std::ios::failbit); + return in; + } + + dlib::vector p1, p2; + in >> p1; + in >> p2; + item = drectangle(p1) + drectangle(p2); + + // ignore any whitespace + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n') + in.get(); + // now eat the trailing ']' character + if (in.get() != ']') + { + in.setstate(in.rdstate() | std::ios::failbit); + } + return in; + } + +// ---------------------------------------------------------------------------------------- + + inline dlib::vector center ( + const drectangle& rect + ) + { + dlib::vector temp(rect.left() + rect.right(), + rect.top() + rect.bottom()); + + return temp/2.0; + } + + inline dlib::vector dcenter ( + const drectangle& rect + ) + { + return center(rect); + } + + inline drectangle operator* ( + const drectangle& rect, + const double& scale + ) + { + if (!rect.is_empty()) + { + const double width = (rect.right()-rect.left())*scale; + const double height = (rect.bottom()-rect.top())*scale; + const dlib::vector p = center(rect); + return drectangle(p.x()-width/2, p.y()-height/2, p.x()+width/2, p.y()+height/2); + } + else + { + return rect; + } + } + + inline drectangle operator* ( + const double& scale, + const drectangle& rect + ) + { + return rect*scale; + } + + inline drectangle operator/ ( + const drectangle& rect, + const double& scale + ) + { + return rect*(1.0/scale); + } + + inline drectangle operator+ ( + const drectangle& r, + const dlib::vector& p + ) + { + return r + drectangle(p); + } + + inline drectangle operator+ ( + const dlib::vector& p, + const drectangle& r + ) + { + return r + drectangle(p); + } + + template + inline drectangle translate_rect ( + const drectangle& rect, + const dlib::vector& p + ) + { + drectangle result; + result.top () = rect.top() + p.y(); + result.bottom () = rect.bottom() + p.y(); + result.left () = rect.left() + p.x(); + result.right () = rect.right() + p.x(); + return result; + } + + inline drectangle intersect ( + const drectangle& a, + const drectangle& b + ) { return a.intersect(b); } + + inline double area ( + const drectangle& a + ) { return a.area(); } + + inline drectangle centered_drect ( + const dlib::vector& p, + double width, + double height + ) + { + width--; + height--; + + return drectangle(p.x()-width/2, p.y()-height/2, p.x()+width/2, p.y()+height/2); + } + + inline drectangle centered_drect ( + const drectangle& rect, + double width, + double height + ) + { + return centered_drect(dcenter(rect), width, height); + } + + inline const drectangle shrink_rect ( + const drectangle& rect, + double num + ) + { + return drectangle(rect.left()+num, rect.top()+num, rect.right()-num, rect.bottom()-num); + } + + inline const drectangle grow_rect ( + const drectangle& rect, + double num + ) + { + return shrink_rect(rect, -num); + } + + inline const drectangle shrink_rect ( + const drectangle& rect, + double width, + double height + ) + { + return drectangle(rect.left()+width, rect.top()+height, rect.right()-width, rect.bottom()-height); + } + + inline const drectangle grow_rect ( + const drectangle& rect, + double width, + double height + ) + { + return shrink_rect(rect, -width, -height); + } + + inline drectangle set_rect_area ( + const drectangle& rect, + double area + ) + { + DLIB_ASSERT(area >= 0, "drectangle can't have a negative area."); + + if (area == 0) + return drectangle(dcenter(rect)); + + if (rect.area() == 0) + { + // In this case we will make the output rectangle a square with the requested + // area. + double scale = std::sqrt(area); + return centered_drect(rect, scale, scale); + } + else + { + double scale = std::sqrt(area/rect.area()); + return centered_drect(rect, rect.width()*scale, rect.height()*scale); + } + } + + inline drectangle set_aspect_ratio ( + const drectangle& rect, + double ratio + ) + { + DLIB_ASSERT(ratio > 0, + "\t drectangle set_aspect_ratio()" + << "\n\t ratio: " << ratio + ); + + const double h = std::sqrt(rect.area()/ratio); + const double w = h*ratio; + return centered_drect(rect, w, h); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DRECTANGLe_ + diff --git a/ml/dlib/dlib/geometry/drectangle_abstract.h b/ml/dlib/dlib/geometry/drectangle_abstract.h new file mode 100644 index 000000000..0f2221353 --- /dev/null +++ b/ml/dlib/dlib/geometry/drectangle_abstract.h @@ -0,0 +1,628 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DRECTANGLe_ABSTRACT_H_ +#ifdef DLIB_DRECTANGLe_ABSTRACT_H_ + +#include "rectangle_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class drectangle + { + /*! + INITIAL VALUE + The initial value of this object is defined by its constructor. + + WHAT THIS OBJECT REPRESENTS + This object is just like dlib::rectangle except that it stores the + coordinates of the rectangle using double rather than long variables. As + such, this object represents a rectangular region inside an image. The + region is the rectangle with its top left corner at position (left(),top()) + and its bottom right corner at (right(),bottom()). + + Note that the origin of the coordinate system, i.e. (0,0), is located at + the upper left corner. That is, points such as (1,1) or (3,5) represent + locations that are below and to the right of the origin. + + Also note that rectangles where top() > bottom() or left() > right() + represent empty rectangles. + !*/ + public: + + drectangle ( + ); + /*! + ensures + - #left() == 0 + - #top() == 0 + - #right() == -1 + - #bottom() == -1 + - #is_empty() == true + !*/ + + drectangle ( + double left_, + double top_, + double right_, + double bottom_ + ); + /*! + ensures + - #left() == left_ + - #top() == top_ + - #right() == right_ + - #bottom() == bottom_ + !*/ + + drectangle ( + const vector& p + ); + /*! + ensures + - #left() == p.x() + - #top() == p.y() + - #right() == p.x() + - #bottom() == p.y() + !*/ + + template + drectangle ( + const vector& p1, + const vector& p2 + ); + /*! + ensures + - #*this == drectangle(p1) + drectangle(p2) + !*/ + + drectangle ( + const drectangle& rect + ); + /*! + ensures + - #*this represents the same rectangle as rect + !*/ + + drectangle ( + const rectangle& rect + ); + /*! + ensures + - left() == rect.left() + - top() == rect.top() + - right() == rect.right() + - bottom() == rect.bottom() + - dcenter(*this) == dcenter(rect) + - width() == rect.width() + - height() == rect.height() + !*/ + + operator rectangle ( + ) const; + /*! + ensures + - returns a rectangle where left(), top(), right(), and bottom() have been + rounded to the nearest integer values. + !*/ + + double left ( + ) const; + /*! + ensures + - returns the x coordinate for the left side of this rectangle + !*/ + + double& left ( + ); + /*! + ensures + - returns a non-const reference to the x coordinate for the left side + of this rectangle + !*/ + + double top ( + ) const; + /*! + ensures + - returns the y coordinate for the top of this rectangle + !*/ + + double& top ( + ); + /*! + ensures + - returns a non-const reference to the y coordinate for the + top of this rectangle + !*/ + + double right ( + ) const; + /*! + ensures + - returns the x coordinate for the right side of this rectangle + !*/ + + double& right ( + ); + /*! + ensures + - returns a non-const reference to the x coordinate for the right + side of this rectangle + !*/ + + double bottom ( + ) const; + /*! + ensures + - returns the y coordinate for the bottom of this rectangle + !*/ + + double& bottom ( + ); + /*! + ensures + - returns a non-const reference to the y coordinate for the bottom + of this rectangle + !*/ + + const vector tl_corner ( + ) const; + /*! + ensures + - returns vector(left(), top()) + (i.e. returns the top left corner point for this rectangle) + !*/ + + const vector bl_corner ( + ) const; + /*! + ensures + - returns vector(left(), bottom()) + (i.e. returns the bottom left corner point for this rectangle) + !*/ + + const vector tr_corner ( + ) const; + /*! + ensures + - returns vector(right(), top()) + (i.e. returns the top right corner point for this rectangle) + !*/ + + const vector br_corner ( + ) const; + /*! + ensures + - returns vector(right(), bottom()) + (i.e. returns the bottom right corner point for this rectangle) + !*/ + + double width ( + ) const; + /*! + ensures + - if (is_empty()) then + - returns 0 + - else + - returns the width of this rectangle. + (i.e. right() - left() + 1) + !*/ + + double height ( + ) const; + /*! + ensures + - if (is_empty()) then + - returns 0 + - else + - returns the height of this rectangle. + (i.e. bottom() - top() + 1) + !*/ + + double area ( + ) const; + /*! + ensures + - returns width()*height() + !*/ + + bool is_empty ( + ) const; + /*! + ensures + - if (top() > bottom() || left() > right()) then + - returns true + - else + - returns false + !*/ + + drectangle operator + ( + const drectangle& rhs + ) const; + /*! + ensures + - if (rhs.is_empty() == false && this->is_empty() == false) then + - returns the smallest rectangle that contains both *this and + rhs. + - if (rhs.is_empty() == true && this->is_empty() == false) then + - returns *this + - if (rhs.is_empty() == false && this->is_empty() == true) then + - returns rhs + - if (rhs.is_empty() == true && this->is_empty() == true) then + - returns a rectangle that has is_empty() == true + !*/ + + drectangle intersect ( + const drectangle& rhs + ) const; + /*! + ensures + - if (there is a region of intersection between *this and rhs) then + - returns a rectangle that represents the intersection of *this + and rhs. + - else + - returns a rectangle where is_empty() == true + !*/ + + bool contains ( + const vector& p + ) const; + /*! + ensures + - if (the point (p.x(),p.y()) is contained in this rectangle) then + - returns true + - else + - returns false + !*/ + + bool contains ( + const drectangle& rect + ) const + /*! + ensures + - if (rect + *this == *this) then + - returns true + (i.e. returns true if *this contains the given rectangle) + - else + - returns false + !*/ + + drectangle& operator *= ( + const double& scale + ); + /*! + ensures + - performs: *this = *this*scale; + - returns #*this + !*/ + + drectangle& operator /= ( + const double& scale + ); + /*! + requires + - scale != 0 + ensures + - performs: *this = *this*(1.0/scale); + - returns #*this + !*/ + + drectangle& operator += ( + const dlib::vector& p + ); + /*! + ensures + - performs: *this = *this + drectangle(p) + - returns #*this + !*/ + + bool operator== ( + const drectangle& rect + ) const; + /*! + ensures + - if (top() == rect.top() && left() == rect.left() && + right() == rect.right() && bottom() == rect.bottom()) then + - returns true + - else + - returns false + !*/ + + bool operator!= ( + const drectangle& rect + ) const; + /*! + ensures + - returns !(*this == rect) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const drectangle& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + drectangle& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out, + const drectangle& item + ); + /*! + ensures + - writes item to out in the form "[(left, top) (right, bottom)]" + !*/ + +// ---------------------------------------------------------------------------------------- + + std::istream& operator>>( + std::istream& in, + drectangle& item + ); + /*! + ensures + - reads a drectangle from the input stream in and stores it in #item. The data + in the input stream should be of the form [(left, top) (right, bottom)] + !*/ + +// ---------------------------------------------------------------------------------------- + + vector center ( + const drectangle& rect + ); + /*! + ensures + - returns the center of the given rectangle + !*/ + +// ---------------------------------------------------------------------------------------- + + vector dcenter ( + const drectangle& rect + ); + /*! + ensures + - returns the center of the given rectangle. (Both center() and dcenter() are + identical when applied to drectangle objects) + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle operator* ( + const drectangle& rect, + const double& scale + ); + /*! + ensures + - This function returns a rectangle that has the same center as rect but with + dimensions that are scale times larger. That is, we return a new rectangle R + such that: + - center(R) == center(rect) + - R.right()-R.left() == (rect.right()-rect.left())*scale + - R.bottom()-R.top() == (rect.bottom()-rect.top())*scale + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle operator* ( + const double& scale, + const drectangle& rect + ); + /*! + ensures + - returns rect*scale + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle operator/ ( + const drectangle& rect, + const double& scale + ); + /*! + ensures + - returns rect*(1.0/scale) + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle operator+ ( + const drectangle& r, + const vector& p + ); + /*! + ensures + - returns r + drectangle(p) + (i.e. returns the rectangle that contains both r and p) + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle operator+ ( + const vector& p, + const drectangle& r + ); + /*! + ensures + - returns r + drectangle(p) + (i.e. returns the rectangle that contains both r and p) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + drectangle translate_rect ( + const drectangle& rect, + const vector& p + ); + /*! + ensures + - returns a rectangle R such that: + - R.left() == rect.left() + p.x() + - R.right() == rect.right() + p.x() + - R.top() == rect.top() + p.y() + - R.bottom() == rect.bottom() + p.y() + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle intersect ( + const drectangle& a, + const drectangle& b + ); + /*! + ensures + - returns a.intersect(b) + (i.e. returns a rectangle representing the intersection of a and b) + !*/ + +// ---------------------------------------------------------------------------------------- + + double area ( + const drectangle& a + ); + /*! + ensures + - returns a.area() + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle centered_drect ( + const vector& p, + double width, + double height + ); + /*! + ensures + - returns a rectangle R such that: + - center(R) == p + - if (width < 1 || height < 1) + - R.width() == 0 + - R.height() == 0 + - R.is_empty() == true + - else + - R.width() == width + - R.height() == height + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle centered_drect ( + const drectangle& rect, + double width, + double height + ); + /*! + ensures + - returns centered_drect(center(rect), width, height) + !*/ + +// ---------------------------------------------------------------------------------------- + + const drectangle shrink_rect ( + const drectangle& rect, + double num + ); + /*! + ensures + - returns drectangle(rect.left()+num, rect.top()+num, rect.right()-num, rect.bottom()-num) + (i.e. shrinks the given drectangle by shrinking its border by num) + !*/ + +// ---------------------------------------------------------------------------------------- + + const drectangle grow_rect ( + const drectangle& rect, + double num + ); + /*! + ensures + - return shrink_rect(rect, -num) + (i.e. grows the given drectangle by expanding its border by num) + !*/ + +// ---------------------------------------------------------------------------------------- + + const drectangle shrink_rect ( + const drectangle& rect, + double width, + double height + ); + /*! + ensures + - returns drectangle(rect.left()+width, rect.top()+height, rect.right()-width, rect.bottom()-height) + (i.e. shrinks the given drectangle by shrinking its left and right borders by width + and its top and bottom borders by height. ) + !*/ + +// ---------------------------------------------------------------------------------------- + + const drectangle grow_rect ( + const drectangle& rect, + double width, + double height + ); + /*! + ensures + - return shrink_rect(rect, -width, -height) + (i.e. grows the given drectangle by expanding its border) + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle set_rect_area ( + const drectangle& rect, + double area + ); + /*! + requires + - area >= 0 + ensures + - Returns a rectangle R such that: + - center(R) == center(rect) + - R has the same aspect ratio as rect. If rect.area() == 0 then the + returned rect has a 1:1 aspect ratio. + - R.area() == area + !*/ + +// ---------------------------------------------------------------------------------------- + + drectangle set_aspect_ratio ( + const drectangle& rect, + double ratio + ); + /*! + requires + - ratio > 0 + ensures + - This function reshapes the given rectangle so that it has the given aspect + ratio. In particular, this means we return a rectangle R such that the + following equations are true: + - R.width()/R.height() == ratio + - R.area() == rect.area() + - dcenter(rect) == dcenter(R) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DRECTANGLe_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/geometry/point_transforms.h b/ml/dlib/dlib/geometry/point_transforms.h new file mode 100644 index 000000000..e789fd2a2 --- /dev/null +++ b/ml/dlib/dlib/geometry/point_transforms.h @@ -0,0 +1,989 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_POINT_TrANSFORMS_H_ +#define DLIB_POINT_TrANSFORMS_H_ + +#include "point_transforms_abstract.h" +#include "../algs.h" +#include "vector.h" +#include "../matrix.h" +#include "../matrix/matrix_la.h" +#include "../optimization/optimization.h" +#include "rectangle.h" +#include "drectangle.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class point_rotator + { + public: + point_rotator ( + ) + { + sin_angle = 0; + cos_angle = 1; + } + + point_rotator ( + const double& angle + ) + { + sin_angle = std::sin(angle); + cos_angle = std::cos(angle); + } + + template + const dlib::vector operator() ( + const dlib::vector& p + ) const + { + double x = cos_angle*p.x() - sin_angle*p.y(); + double y = sin_angle*p.x() + cos_angle*p.y(); + + return dlib::vector(x,y); + } + + const matrix get_m( + ) const + { + matrix temp; + temp = cos_angle, -sin_angle, + sin_angle, cos_angle; + return temp; + } + + inline friend void serialize (const point_rotator& item, std::ostream& out) + { + serialize(item.sin_angle, out); + serialize(item.cos_angle, out); + } + + inline friend void deserialize (point_rotator& item, std::istream& in) + { + deserialize(item.sin_angle, in); + deserialize(item.cos_angle, in); + } + + private: + double sin_angle; + double cos_angle; + }; + +// ---------------------------------------------------------------------------------------- + + class point_transform + { + public: + + point_transform ( + ) + { + sin_angle = 0; + cos_angle = 1; + translate.x() = 0; + translate.y() = 0; + } + + point_transform ( + const double& angle, + const dlib::vector& translate_ + ) + { + sin_angle = std::sin(angle); + cos_angle = std::cos(angle); + translate = translate_; + } + + template + const dlib::vector operator() ( + const dlib::vector& p + ) const + { + double x = cos_angle*p.x() - sin_angle*p.y(); + double y = sin_angle*p.x() + cos_angle*p.y(); + + return dlib::vector(x,y) + translate; + } + + const matrix get_m( + ) const + { + matrix temp; + temp = cos_angle, -sin_angle, + sin_angle, cos_angle; + return temp; + } + + const dlib::vector get_b( + ) const { return translate; } + + inline friend void serialize (const point_transform& item, std::ostream& out) + { + serialize(item.sin_angle, out); + serialize(item.cos_angle, out); + serialize(item.translate, out); + } + + inline friend void deserialize (point_transform& item, std::istream& in) + { + deserialize(item.sin_angle, in); + deserialize(item.cos_angle, in); + deserialize(item.translate, in); + } + + private: + double sin_angle; + double cos_angle; + dlib::vector translate; + }; + +// ---------------------------------------------------------------------------------------- + + class point_transform_affine + { + public: + + point_transform_affine ( + ) + { + m = identity_matrix(2); + b.x() = 0; + b.y() = 0; + } + + point_transform_affine ( + const matrix& m_, + const dlib::vector& b_ + ) :m(m_), b(b_) + { + } + + const dlib::vector operator() ( + const dlib::vector& p + ) const + { + return m*p + b; + } + + const matrix& get_m( + ) const { return m; } + + const dlib::vector& get_b( + ) const { return b; } + + inline friend void serialize (const point_transform_affine& item, std::ostream& out) + { + serialize(item.m, out); + serialize(item.b, out); + } + + inline friend void deserialize (point_transform_affine& item, std::istream& in) + { + deserialize(item.m, in); + deserialize(item.b, in); + } + + private: + matrix m; + dlib::vector b; + }; + +// ---------------------------------------------------------------------------------------- + + class rectangle_transform + { + public: + + rectangle_transform ( + ) + { + } + + rectangle_transform ( + const point_transform_affine& tform_ + ) :tform(tform_) + { + } + + drectangle operator() ( + const drectangle& r + ) const + { + dpoint tl = r.tl_corner(); + dpoint tr = r.tr_corner(); + dpoint bl = r.bl_corner(); + dpoint br = r.br_corner(); + // The new rectangle wouold ideally have this area if we could actually rotrate + // the box. + double new_area = length(tform(tl)-tform(tr))*length(tform(tl)-tform(bl)); + + // But if we rotate the coners of the rectangle and then find the rectangle + // that contains them we get this, which might have a much larger area than we + // want. + drectangle temp; + temp += tform(tl); + temp += tform(tr); + temp += tform(bl); + temp += tform(br); + // so we adjust the area to match the target area and have the same center as + // the above box. + double scale = std::sqrt(new_area/temp.area()); + + return centered_rect(center(temp), static_cast(temp.width()*scale+0.5), static_cast(temp.height()*scale+0.5)); + } + + rectangle operator() ( + const rectangle& r + ) const + { + return (*this)(drectangle(r)); + } + + const point_transform_affine& get_tform( + ) const { return tform; } + + inline friend void serialize (const rectangle_transform& item, std::ostream& out) + { + serialize(item.tform, out); + } + + inline friend void deserialize (rectangle_transform& item, std::istream& in) + { + deserialize(item.tform, in); + } + + private: + point_transform_affine tform; + }; + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine operator* ( + const point_transform_affine& lhs, + const point_transform_affine& rhs + ) + { + return point_transform_affine(lhs.get_m()*rhs.get_m(), lhs.get_m()*rhs.get_b()+lhs.get_b()); + } + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine inv ( + const point_transform_affine& trans + ) + { + matrix im = inv(trans.get_m()); + return point_transform_affine(im, -im*trans.get_b()); + } + +// ---------------------------------------------------------------------------------------- + + template + point_transform_affine find_affine_transform ( + const std::vector >& from_points, + const std::vector >& to_points + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(from_points.size() == to_points.size() && + from_points.size() >= 3, + "\t point_transform_affine find_affine_transform(from_points, to_points)" + << "\n\t Invalid inputs were given to this function." + << "\n\t from_points.size(): " << from_points.size() + << "\n\t to_points.size(): " << to_points.size() + ); + + matrix P(3, from_points.size()); + matrix Q(2, from_points.size()); + + for (unsigned long i = 0; i < from_points.size(); ++i) + { + P(0,i) = from_points[i].x(); + P(1,i) = from_points[i].y(); + P(2,i) = 1; + + Q(0,i) = to_points[i].x(); + Q(1,i) = to_points[i].y(); + } + + const matrix m = Q*pinv(P); + return point_transform_affine(subm(m,0,0,2,2), colm(m,2)); + } + +// ---------------------------------------------------------------------------------------- + + template + point_transform_affine find_similarity_transform ( + const std::vector >& from_points, + const std::vector >& to_points + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(from_points.size() == to_points.size() && + from_points.size() >= 2, + "\t point_transform_affine find_similarity_transform(from_points, to_points)" + << "\n\t Invalid inputs were given to this function." + << "\n\t from_points.size(): " << from_points.size() + << "\n\t to_points.size(): " << to_points.size() + ); + + // We use the formulas from the paper: Least-squares estimation of transformation + // parameters between two point patterns by Umeyama. They are equations 34 through + // 43. + + dlib::vector mean_from, mean_to; + double sigma_from = 0, sigma_to = 0; + matrix cov; + cov = 0; + + for (unsigned long i = 0; i < from_points.size(); ++i) + { + mean_from += from_points[i]; + mean_to += to_points[i]; + } + mean_from /= from_points.size(); + mean_to /= from_points.size(); + + for (unsigned long i = 0; i < from_points.size(); ++i) + { + sigma_from += length_squared(from_points[i] - mean_from); + sigma_to += length_squared(to_points[i] - mean_to); + cov += (to_points[i] - mean_to)*trans(from_points[i] - mean_from); + } + + sigma_from /= from_points.size(); + sigma_to /= from_points.size(); + cov /= from_points.size(); + + matrix u, v, s, d; + svd(cov, u,d,v); + s = identity_matrix(cov); + if (det(cov) < 0 || (det(cov) == 0 && det(u)*det(v)<0)) + { + if (d(1,1) < d(0,0)) + s(1,1) = -1; + else + s(0,0) = -1; + } + + matrix r = u*s*trans(v); + double c = 1; + if (sigma_from != 0) + c = 1.0/sigma_from * trace(d*s); + vector t = mean_to - c*r*mean_from; + + return point_transform_affine(c*r, t); + } + +// ---------------------------------------------------------------------------------------- + + class point_transform_projective + { + public: + + point_transform_projective ( + ) + { + m = identity_matrix(3); + } + + point_transform_projective ( + const matrix& m_ + ) :m(m_) + { + } + + point_transform_projective ( + const point_transform_affine& tran + ) + { + set_subm(m, 0,0, 2,2) = tran.get_m(); + set_subm(m, 0,2, 2,1) = tran.get_b(); + m(2,0) = 0; + m(2,1) = 0; + m(2,2) = 1; + } + + + const dlib::vector operator() ( + const dlib::vector& p + ) const + { + dlib::vector temp(p); + temp.z() = 1; + temp = m*temp; + if (temp.z() != 0) + temp = temp/temp.z(); + return temp; + } + + const matrix& get_m( + ) const { return m; } + + inline friend void serialize (const point_transform_projective& item, std::ostream& out) + { + serialize(item.m, out); + } + + inline friend void deserialize (point_transform_projective& item, std::istream& in) + { + deserialize(item.m, in); + } + + private: + matrix m; + }; + +// ---------------------------------------------------------------------------------------- + + inline point_transform_projective operator* ( + const point_transform_projective& lhs, + const point_transform_projective& rhs + ) + { + return point_transform_projective(lhs.get_m()*rhs.get_m()); + } + +// ---------------------------------------------------------------------------------------- + + inline point_transform_projective inv ( + const point_transform_projective& trans + ) + { + return point_transform_projective(inv(trans.get_m())); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl_proj + { + + inline point_transform_projective find_projective_transform_basic ( + const std::vector >& from_points, + const std::vector >& to_points + ) + /*! + ensures + - Uses the system of equations approach to finding a projective transform. + This is "Method 3" from Estimating Projective Transformation Matrix by + Zhengyou Zhang. + - It should be emphasized that the find_projective_transform_basic() + routine, which uses the most popular method for finding projective + transformations, doesn't really work well when the minimum error solution + doesn't have zero error. In this case, it can deviate by a large amount + from the proper minimum mean squared error transformation. Therefore, + our overall strategy will be to use the solution from + find_projective_transform_basic() as a starting point for a BFGS based + non-linear optimizer which will optimize the correct mean squared error + criterion. + + A great essay on this subject is Homography Estimation by Elan Dubrofsky. + !*/ + { + // make sure requires clause is not broken + DLIB_ASSERT(from_points.size() == to_points.size() && + from_points.size() >= 4, + "\t point_transform_projective find_projective_transform_basic(from_points, to_points)" + << "\n\t Invalid inputs were given to this function." + << "\n\t from_points.size(): " << from_points.size() + << "\n\t to_points.size(): " << to_points.size() + ); + + matrix accum, u, v; + matrix w; + matrix B; + accum = 0; + B = 0; + for (unsigned long i = 0; i < from_points.size(); ++i) + { + dlib::vector f = from_points[i]; + f.z() = 1; + dlib::vector t = to_points[i]; + t.z() = 1; + + set_subm(B,0,0,1,3) = t.y()*trans(f); + set_subm(B,1,0,1,3) = trans(f); + + set_subm(B,0,3,1,3) = -t.x()*trans(f); + set_subm(B,1,6,1,3) = -t.x()*trans(f); + + accum += trans(B)*B; + } + + svd2(true, false, accum, u, w, v); + long j = index_of_min(w); + + return point_transform_projective(reshape(colm(u,j),3,3)); + } + + // ---------------------------------------------------------------------------------------- + + struct obj + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the objective function we really want to minimize when looking + for a transformation matrix. That is, we would like the transformed + points to be as close as possible to their "to" points. Here, + closeness is measured using Euclidean distance. + + !*/ + obj( + const std::vector >& from_points_, + const std::vector >& to_points_ + ) : + from_points(from_points_) , + to_points(to_points_) + {} + const std::vector >& from_points; + const std::vector >& to_points; + + double operator() ( + const matrix& p + ) const + { + point_transform_projective tran(reshape(p,3,3)); + + double sum = 0; + for (unsigned long i = 0; i < from_points.size(); ++i) + { + sum += length_squared(tran(from_points[i]) - to_points[i]); + } + return sum; + } + }; + + struct obj_der + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the derivative of obj. + !*/ + + obj_der( + const std::vector >& from_points_, + const std::vector >& to_points_ + ) : + from_points(from_points_) , + to_points(to_points_) + {} + const std::vector >& from_points; + const std::vector >& to_points; + + matrix operator() ( + const matrix& p + ) const + { + const matrix H = reshape(p,3,3); + + matrix grad; + grad = 0; + for (unsigned long i = 0; i < from_points.size(); ++i) + { + dlib::vector from, to; + from = from_points[i]; + from.z() = 1; + to = to_points[i]; + to.z() = 1; + + matrix w = H*from; + const double scale = (w(2) != 0) ? (1.0/w(2)) : (1); + w *= scale; + matrix residual = (w-to)*2*scale; + + grad(0,0) += from.x()*residual(0); + grad(0,1) += from.y()*residual(0); + grad(0,2) += residual(0); + + grad(1,0) += from.x()*residual(1); + grad(1,1) += from.y()*residual(1); + grad(1,2) += residual(1); + + grad(2,0) += -(from.x()*w(0)*residual(0) + from.x()*w(1)*residual(1)); + grad(2,1) += -(from.y()*w(0)*residual(0) + from.y()*w(1)*residual(1)); + grad(2,2) += -( w(0)*residual(0) + w(1)*residual(1)); + + } + return reshape_to_column_vector(grad); + } + }; + } + +// ---------------------------------------------------------------------------------------- + + inline point_transform_projective find_projective_transform ( + const std::vector >& from_points, + const std::vector >& to_points + ) + { + using namespace impl_proj; + // make sure requires clause is not broken + DLIB_ASSERT(from_points.size() == to_points.size() && + from_points.size() >= 4, + "\t point_transform_projective find_projective_transform(from_points, to_points)" + << "\n\t Invalid inputs were given to this function." + << "\n\t from_points.size(): " << from_points.size() + << "\n\t to_points.size(): " << to_points.size() + ); + + + // Find a candidate projective transformation. Also, find the best affine + // transform and then compare it with the projective transform estimated using the + // direct SVD method. Use whichever one works better as the starting point for a + // BFGS based optimizer. If the best solution has large mean squared error and is + // also close to affine then find_projective_transform_basic() might give a very + // bad initial guess. So also checking for a good affine transformation can + // produce a much better final result in many cases. + point_transform_projective tran1 = find_projective_transform_basic(from_points, to_points); + point_transform_affine tran2 = find_affine_transform(from_points, to_points); + + // check which is best + double error1 = 0; + double error2 = 0; + for (unsigned long i = 0; i < from_points.size(); ++i) + { + error1 += length_squared(tran1(from_points[i])-to_points[i]); + error2 += length_squared(tran2(from_points[i])-to_points[i]); + } + matrix params; + // Pick the minimum error solution among the two so far. + if (error1 < error2) + params = reshape_to_column_vector(tran1.get_m()); + else + params = reshape_to_column_vector(point_transform_projective(tran2).get_m()); + + + // Now refine the transformation matrix so that we can be sure we have + // at least a local minimizer. + obj o(from_points, to_points); + obj_der der(from_points, to_points); + find_min(bfgs_search_strategy(), + objective_delta_stop_strategy(1e-6,100), + o, + der, + params, + 0); + + return point_transform_projective(reshape(params,3,3)); + } + +// ---------------------------------------------------------------------------------------- + + template + const dlib::vector rotate_point ( + const dlib::vector& center, + const dlib::vector& p, + double angle + ) + { + point_rotator rot(angle); + return rot(p-center)+center; + } + +// ---------------------------------------------------------------------------------------- + + inline matrix rotation_matrix ( + double angle + ) + { + const double ca = std::cos(angle); + const double sa = std::sin(angle); + + matrix m; + m = ca, -sa, + sa, ca; + return m; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class point_transform_affine3d + { + public: + + point_transform_affine3d ( + ) + { + m = identity_matrix(3); + b.x() = 0; + b.y() = 0; + } + + point_transform_affine3d ( + const matrix& m_, + const dlib::vector& b_ + ) :m(m_), b(b_) + { + } + + const dlib::vector operator() ( + const dlib::vector& p + ) const + { + return m*p + b; + } + + const matrix& get_m( + ) const { return m; } + + const dlib::vector& get_b( + ) const { return b; } + + inline friend void serialize (const point_transform_affine3d& item, std::ostream& out) + { + serialize(item.m, out); + serialize(item.b, out); + } + + inline friend void deserialize (point_transform_affine3d& item, std::istream& in) + { + deserialize(item.m, in); + deserialize(item.b, in); + } + + private: + matrix m; + dlib::vector b; + }; + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine3d operator* ( + const point_transform_affine3d& lhs, + const point_transform_affine& rhs + ) + { + matrix m; + m = 0; + set_subm(m, get_rect(rhs.get_m())) = rhs.get_m(); + vector b = rhs.get_b(); + + return point_transform_affine3d(lhs.get_m()*m, lhs.get_m()*b+lhs.get_b()); + } + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine3d operator* ( + const point_transform_affine3d& lhs, + const point_transform_affine3d& rhs + ) + { + return point_transform_affine3d(lhs.get_m()*rhs.get_m(), lhs.get_m()*rhs.get_b()+lhs.get_b()); + } + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine3d inv ( + const point_transform_affine3d& trans + ) + { + matrix im = inv(trans.get_m()); + return point_transform_affine3d(im, -im*trans.get_b()); + } + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine3d rotate_around_x ( + double angle + ) + { + const double ca = std::cos(angle); + const double sa = std::sin(angle); + + matrix m; + m = 1, 0, 0, + 0, ca, -sa, + 0, sa, ca; + + vector b; + + return point_transform_affine3d(m,b); + } + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine3d rotate_around_y ( + double angle + ) + { + const double ca = std::cos(angle); + const double sa = std::sin(angle); + + matrix m; + m = ca, 0, sa, + 0, 1, 0, + -sa, 0, ca; + + vector b; + + return point_transform_affine3d(m,b); + } + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine3d rotate_around_z ( + double angle + ) + { + const double ca = std::cos(angle); + const double sa = std::sin(angle); + + matrix m; + m = ca, -sa, 0, + sa, ca, 0, + 0, 0, 1; + + vector b; + + return point_transform_affine3d(m,b); + } + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine3d translate_point ( + const vector& delta + ) + { + return point_transform_affine3d(identity_matrix(3),delta); + } + + inline point_transform_affine3d translate_point ( + double x, + double y, + double z + ) + { + return translate_point(vector(x,y,z)); + } + +// ---------------------------------------------------------------------------------------- + + class camera_transform + { + + public: + + camera_transform ( + ) + { + *this = camera_transform(vector(1,1,1), + vector(0,0,0), + vector(0,0,1), + 90, + 1); + } + + camera_transform ( + const vector& camera_pos_, + const vector& camera_looking_at_, + const vector& camera_up_direction_, + const double camera_field_of_view_, + const unsigned long num_pixels_ + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(0 < camera_field_of_view_ && camera_field_of_view_ < 180, + "\t camera_transform::camera_transform()" + << "\n\t Invalid inputs were given to this function." + << "\n\t camera_field_of_view_: " << camera_field_of_view_ + ); + + camera_pos = camera_pos_; + camera_looking_at = camera_looking_at_; + camera_up_direction = camera_up_direction_; + camera_field_of_view = camera_field_of_view_; + num_pixels = num_pixels_; + + dlib::vector X,Y,Z; + Z = (camera_looking_at - camera_pos).normalize(); + Y = camera_up_direction - dot(camera_up_direction,Z)*Z; + Y = Y.normalize(); + X = Z.cross(Y); + + set_rowm(proj,0) = trans(X); + // Minus because images have y axis going down but we want the 3d projection to appear using a normal coordinate system with y going up. + set_rowm(proj,1) = -trans(Y); + set_rowm(proj,2) = trans(Z); + + width = num_pixels/2.0; + dist_scale = width/std::tan(pi/180*camera_field_of_view/2); + } + + vector get_camera_pos() const { return camera_pos; } + vector get_camera_looking_at() const { return camera_looking_at; } + vector get_camera_up_direction()const { return camera_up_direction; } + double get_camera_field_of_view() const { return camera_field_of_view; } + unsigned long get_num_pixels() const { return num_pixels; } + + inline dpoint operator() ( + const vector& p, + double& scale, + double& distance + ) const + { + vector temp = p-camera_pos; + temp = proj*temp; + distance = temp.z(); + scale = dist_scale/(temp.z()>0 ? temp.z() : 1e-9); + temp.x() = temp.x()*scale + width; + temp.y() = temp.y()*scale + width; + return temp; + } + + dpoint operator() ( + const vector& p + ) const + { + double scale, distance; + return (*this)(p,scale,distance); + } + + inline friend void serialize (const camera_transform& item, std::ostream& out) + { + serialize(item.camera_pos, out); + serialize(item.camera_looking_at, out); + serialize(item.camera_up_direction, out); + serialize(item.camera_field_of_view, out); + serialize(item.num_pixels, out); + serialize(item.proj, out); + serialize(item.dist_scale, out); + serialize(item.width, out); + } + + inline friend void deserialize (camera_transform& item, std::istream& in) + { + deserialize(item.camera_pos, in); + deserialize(item.camera_looking_at, in); + deserialize(item.camera_up_direction, in); + deserialize(item.camera_field_of_view, in); + deserialize(item.num_pixels, in); + deserialize(item.proj, in); + deserialize(item.dist_scale, in); + deserialize(item.width, in); + } + + private: + + vector camera_pos; + vector camera_looking_at; + vector camera_up_direction; + double camera_field_of_view; + unsigned long num_pixels; + matrix proj; + double dist_scale; + double width; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_POINT_TrANSFORMS_H_ + diff --git a/ml/dlib/dlib/geometry/point_transforms_abstract.h b/ml/dlib/dlib/geometry/point_transforms_abstract.h new file mode 100644 index 000000000..492ae745c --- /dev/null +++ b/ml/dlib/dlib/geometry/point_transforms_abstract.h @@ -0,0 +1,797 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_POINT_TrANSFORMS_ABSTRACT_Hh_ +#ifdef DLIB_POINT_TrANSFORMS_ABSTRACT_Hh_ + +#include "../matrix/matrix_abstract.h" +#include "vector_abstract.h" +#include "rectangle_abstract.h" +#include "drectangle_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class point_transform_affine + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an object that takes 2D points or vectors and + applies an affine transformation to them. + + THREAD SAFETY + It is safe for multiple threads to make concurrent accesses to this object + without synchronization. + !*/ + public: + + point_transform_affine ( + ); + /*! + ensures + - This object will perform the identity transform. That is, given a point + as input it will return the same point as output. + !*/ + + point_transform_affine ( + const matrix& m, + const dlib::vector& b + ); + /*! + ensures + - #get_m() == m + - #get_b() == b + - When (*this)(p) is invoked it will return a point P such that: + - P == m*p + b + !*/ + + const dlib::vector operator() ( + const dlib::vector& p + ) const; + /*! + ensures + - applies the affine transformation defined by this object's constructor + to p and returns the result. + !*/ + + const matrix& get_m( + ) const; + /*! + ensures + - returns the transformation matrix used by this object. + !*/ + + const dlib::vector& get_b( + ) const; + /*! + ensures + - returns the offset vector used by this object. + !*/ + + }; + + void serialize (const point_transform_affine& item, std::ostream& out); + void deserialize (point_transform_affine& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_affine operator* ( + const point_transform_affine& lhs, + const point_transform_affine& rhs + ); + /*! + ensures + - returns a transformation TFORM(x) that is equivalent to lhs(rhs(x)). That + is, for all valid x: TFORM(x) == lhs(rhs(x)). + !*/ + + // ---------------------------------------------------------------------------------------- + + point_transform_affine inv ( + const point_transform_affine& trans + ); + /*! + ensures + - If trans is an invertible transformation then this function returns a new + transformation that is the inverse of trans. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + point_transform_affine find_affine_transform ( + const std::vector >& from_points, + const std::vector >& to_points + ); + /*! + requires + - from_points.size() == to_points.size() + - from_points.size() >= 3 + ensures + - returns a point_transform_affine object, T, such that for all valid i: + length(T(from_points[i]) - to_points[i]) + is minimized as often as possible. That is, this function finds the affine + transform that maps points in from_points to points in to_points. If no + affine transform exists which performs this mapping exactly then the one + which minimizes the mean squared error is selected. Additionally, if many + equally good transformations exist, then the transformation with the smallest + squared parameters is selected (i.e. if you wrote the transformation as a + matrix then we say we select the transform with minimum Frobenius norm among + all possible solutions). + !*/ + +// ---------------------------------------------------------------------------------------- + + template + point_transform_affine find_similarity_transform ( + const std::vector >& from_points, + const std::vector >& to_points + ); + /*! + requires + - from_points.size() == to_points.size() + - from_points.size() >= 2 + ensures + - This function is just like find_affine_transform() except it finds the best + similarity transform instead of a full affine transform. This means that it + optimizes over only the space of rotations, scale changes, and translations. + So for example, if you mapped the 3 vertices of a triangle through a + similarity transform then the output would still be the same triangle. + However, the triangle itself may be larger or smaller, rotated, or at a + different location in the coordinate system. This is not the case for a + general affine transform which can stretch points in ways that cause, for + example, an equilateral triangle to turn into an isosceles triangle. + !*/ + +// ---------------------------------------------------------------------------------------- + + class rectangle_transform + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is just a point_transform_affine wrapped up so that it can + transform rectangle objects. It will take a rectangle and transform it + according to an affine transformation. + + THREAD SAFETY + It is safe for multiple threads to make concurrent accesses to this object + without synchronization. + !*/ + public: + + rectangle_transform ( + ); + /*! + ensures + - This object will perform the identity transform. That is, given a rectangle + as input it will return the same rectangle as output. + !*/ + + rectangle_transform ( + const point_transform_affine& tform + ); + /*! + ensures + - #get_tform() == tform + !*/ + + drectangle operator() ( + const drectangle& r + ) const; + /*! + ensures + - Applies the transformation get_tform() to r and returns the resulting + rectangle. If the transformation doesn't have any rotation then the + transformation simply maps the corners of the rectangle according to + get_tform() and returns the exact result. However, since + dlib::drectangle can't represent rotated rectangles, if there is any + rotation in the affine transform we will attempt to produce the most + faithful possible outputs by ensuring the output rectangle has the + correct center point and that its area and aspect ratio match the correct + rotated rectangle's as much as possible. + !*/ + + rectangle operator() ( + const rectangle& r + ) const; + /*! + ensures + - returns (*this)(drectangle(r)) + !*/ + + const point_transform_affine& get_tform( + ) const; + /*! + ensures + - returns the affine transformation this object uses to transform rectangles. + !*/ + + }; + + void serialize (const rectangle_transform& item, std::ostream& out); + void deserialize (rectangle_transform& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + class point_transform_projective + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an object that takes 2D points or vectors and + applies a projective transformation to them. + + THREAD SAFETY + It is safe for multiple threads to make concurrent accesses to this object + without synchronization. + !*/ + + public: + + point_transform_projective ( + ); + /*! + ensures + - This object will perform the identity transform. That is, given a point + as input it will return the same point as output. + !*/ + + point_transform_projective ( + const matrix& m + ); + /*! + ensures + - #get_m() == m + !*/ + + point_transform_projective ( + const point_transform_affine& tran + ); + /*! + ensures + - This object will perform exactly the same transformation as the given + affine transform. + !*/ + + const dlib::vector operator() ( + const dlib::vector& p + ) const; + /*! + ensures + - Applies the projective transformation defined by this object's constructor + to p and returns the result. To define this precisely: + - let p_h == the point p in homogeneous coordinates. That is: + - p_h.x() == p.x() + - p_h.y() == p.y() + - p_h.z() == 1 + - let x == get_m()*p_h + - Then this function returns the value x/x.z() + !*/ + + const matrix& get_m( + ) const; + /*! + ensures + - returns the transformation matrix used by this object. + !*/ + + }; + + void serialize (const point_transform_projective& item, std::ostream& out); + void deserialize (point_transform_projective& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_projective operator* ( + const point_transform_projective& lhs, + const point_transform_projective& rhs + ); + /*! + ensures + - returns a transformation TFORM(x) that is equivalent to lhs(rhs(x)). That + is, for all valid x: TFORM(x) == lhs(rhs(x)). + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_projective inv ( + const point_transform_projective& trans + ); + /*! + ensures + - If trans is an invertible transformation then this function returns a new + transformation that is the inverse of trans. + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_projective find_projective_transform ( + const std::vector >& from_points, + const std::vector >& to_points + ); + /*! + requires + - from_points.size() == to_points.size() + - from_points.size() >= 4 + ensures + - returns a point_transform_projective object, T, such that for all valid i: + length(T(from_points[i]) - to_points[i]) + is minimized as often as possible. That is, this function finds the projective + transform that maps points in from_points to points in to_points. If no + projective transform exists which performs this mapping exactly then the one + which minimizes the mean squared error is selected. + !*/ + +// ---------------------------------------------------------------------------------------- + + class point_transform + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an object that takes 2D points or vectors and + rotates them around the origin by a given angle and then + translates them. + + THREAD SAFETY + It is safe for multiple threads to make concurrent accesses to this object + without synchronization. + !*/ + public: + + point_transform ( + ); + /*! + ensures + - This object will perform the identity transform. That is, given a point + as input it will return the same point as output. + !*/ + + point_transform ( + const double& angle, + const dlib::vector& translate + ) + /*! + ensures + - When (*this)(p) is invoked it will return a point P such that: + - P is the point p rotated counter-clockwise around the origin + angle radians and then shifted by having translate added to it. + (Note that this is counter clockwise with respect to the normal + coordinate system with positive y going up and positive x going + to the right) + !*/ + + template + const dlib::vector operator() ( + const dlib::vector& p + ) const; + /*! + ensures + - rotates p, then translates it and returns the result. The output + of this function is therefore equal to get_m()*p + get_b(). + !*/ + + const matrix get_m( + ) const; + /*! + ensures + - returns the transformation matrix used by this object. + !*/ + + const dlib::vector get_b( + ) const; + /*! + ensures + - returns the offset vector used by this object. + !*/ + + }; + + void serialize (const point_transform& item, std::ostream& out); + void deserialize (point_transform& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + class point_rotator + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an object that takes 2D points or vectors and + rotates them around the origin by a given angle. + + THREAD SAFETY + It is safe for multiple threads to make concurrent accesses to this object + without synchronization. + !*/ + public: + + point_rotator ( + ); + /*! + ensures + - This object will perform the identity transform. That is, given a point + as input it will return the same point as output. + !*/ + + point_rotator ( + const double& angle + ); + /*! + ensures + - When (*this)(p) is invoked it will return a point P such that: + - P is the point p rotated counter-clockwise around the origin + angle radians. + (Note that this is counter clockwise with respect to the normal + coordinate system with positive y going up and positive x going + to the right) + !*/ + + template + const dlib::vector operator() ( + const dlib::vector& p + ) const; + /*! + ensures + - rotates p and returns the result. The output of this function is + therefore equal to get_m()*p. + !*/ + + const matrix get_m( + ) const; + /*! + ensures + - returns the transformation matrix used by this object. + !*/ + }; + + void serialize (const point_rotator& item, std::ostream& out); + void deserialize (point_rotator& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const dlib::vector rotate_point ( + const dlib::vector center, + const dlib::vector p, + double angle + ); + /*! + ensures + - returns a point P such that: + - P is the point p rotated counter-clockwise around the given + center point by angle radians. + (Note that this is counter clockwise with respect to the normal + coordinate system with positive y going up and positive x going + to the right) + !*/ + +// ---------------------------------------------------------------------------------------- + + matrix rotation_matrix ( + double angle + ); + /*! + ensures + - returns a rotation matrix which rotates points around the origin in a + counter-clockwise direction by angle radians. + (Note that this is counter clockwise with respect to the normal + coordinate system with positive y going up and positive x going + to the right) + Or in other words, this function returns a matrix M such that, given a + point P, M*P gives a point which is P rotated by angle radians around + the origin in a counter-clockwise direction. + !*/ + +// ---------------------------------------------------------------------------------------- + + class point_transform_affine3d + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an object that takes 3D points or vectors and + applies an affine transformation to them. + + THREAD SAFETY + It is safe for multiple threads to make concurrent accesses to this object + without synchronization. + !*/ + public: + + point_transform_affine3d ( + ); + /*! + ensures + - This object will perform the identity transform. That is, given a point + as input it will return the same point as output. + !*/ + + point_transform_affine3d ( + const matrix& m, + const dlib::vector& b + ); + /*! + ensures + - #get_m() == m + - #get_b() == b + - When (*this)(p) is invoked it will return a point P such that: + - P == m*p + b + !*/ + + const dlib::vector operator() ( + const dlib::vector& p + ) const; + /*! + ensures + - applies the affine transformation defined by this object's constructor + to p and returns the result. + !*/ + + const matrix& get_m( + ) const; + /*! + ensures + - returns the transformation matrix used by this object. + !*/ + + const dlib::vector& get_b( + ) const; + /*! + ensures + - returns the offset vector used by this object. + !*/ + + }; + + void serialize (const point_transform_affine3d& item, std::ostream& out); + void deserialize (point_transform_affine3d& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_affine3d operator* ( + const point_transform_affine3d& lhs, + const point_transform_affine3d& rhs + ); + /*! + ensures + - returns a transformation TFORM(x) that is equivalent to lhs(rhs(x)). That + is, for all valid x: TFORM(x) == lhs(rhs(x)). + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_affine3d operator* ( + const point_transform_affine3d& lhs, + const point_transform_affine& rhs + ); + /*! + ensures + - returns a transformation TFORM(x) that is equivalent to lhs(rhs(x)). That + is, for all valid x: TFORM(x) == lhs(rhs(x)). + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_affine3d inv ( + const point_transform_affine3d& trans + ); + /*! + ensures + - If trans is an invertible transformation then this function returns a new + transformation that is the inverse of trans. + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_affine3d rotate_around_x ( + double angle + ); + /*! + ensures + - Returns a transformation that rotates a point around the x axis in a + counter-clockwise direction by angle radians. That is, the rotation appears + counter-clockwise when the x axis points toward the observer, the coordinate + system is right-handed, and the angle is positive. + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_affine3d rotate_around_y ( + double angle + ); + /*! + ensures + - Returns a transformation that rotates a point around the y axis in a + counter-clockwise direction by angle radians. That is, the rotation appears + counter-clockwise when the y axis points toward the observer, the coordinate + system is right-handed, and the angle is positive. + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_affine3d rotate_around_z ( + double angle + ); + /*! + ensures + - Returns a transformation that rotates a point around the z axis in a + counter-clockwise direction by angle radians. That is, the rotation appears + counter-clockwise when the z axis points toward the observer, the coordinate + system is right-handed, and the angle is positive. + !*/ + +// ---------------------------------------------------------------------------------------- + + point_transform_affine3d translate_point ( + const vector& delta + ); + /*! + ensures + - returns a transformation that simply translates points by adding delta to + them. That is, this function returns: + point_transform_affine3d(identity_matrix(3),delta); + !*/ + + point_transform_affine3d translate_point ( + double x, + double y, + double z + ); + /*! + ensures + - returns translate_point(vector(x,y,z)) + !*/ + +// ---------------------------------------------------------------------------------------- + + class camera_transform + { + /*! + WHAT THIS OBJECT REPRESENTS + This object maps 3D points into the image plane of a camera. Therefore, + you can use it to compute 2D representations of 3D data from the point of + view of some camera in 3D space. + + THREAD SAFETY + It is safe for multiple threads to make concurrent accesses to this object + without synchronization. + !*/ + + public: + + camera_transform ( + ); + /*! + ensures + - #get_camera_pos() == vector(1,1,1) + - #get_camera_looking_at() == vector(0,0,0) + - #get_camera_up_direction() == vector(0,0,1) + - #get_camera_field_of_view() == 90 + - #get_num_pixels() == 1 + !*/ + + camera_transform ( + const vector& camera_pos, + const vector& camera_looking_at, + const vector& camera_up_direction, + const double camera_field_of_view, + const unsigned long num_pixels + ); + /*! + requires + - 0 < camera_field_of_view < 180 + ensures + - #get_camera_pos() == camera_pos + - #get_camera_looking_at() == camera_looking_at + - #get_camera_up_direction() == camera_up_direction + - #get_camera_field_of_view() == camera_field_of_view + - #get_num_pixels() == num_pixels + !*/ + + dpoint operator() ( + const vector& p + ) const; + /*! + ensures + - Maps the given 3D point p into the 2D image plane defined by the camera + parameters given to this object's constructor. The 2D point in the image + plane is returned. + !*/ + + dpoint operator() ( + const vector& p, + double& scale, + double& distance + ) const; + /*! + ensures + - Maps the given 3D point p into the 2D image plane defined by the camera + parameters given to this object's constructor. The 2D point in the image + plane is returned. + - #scale == a number that tells you how large things are at the point p. + Objects further from the camera appear smaller, in particular, they + appear #scale times their normal size. + - #distance == how far away the point is from the image plane. Objects in + front of the camera will have a positive distance and those behind a + negative distance. + !*/ + + vector get_camera_pos( + ) const; + /*! + ensures + - returns the position, in 3D space, of the camera. When operator() is + invoked it maps 3D points into the image plane of this camera. + !*/ + + vector get_camera_looking_at( + ) const; + /*! + ensures + - returns the point in 3D space the camera is pointed at. + !*/ + + vector get_camera_up_direction( + ) const; + /*! + ensures + - returns a vector that defines what direction is "up" for the camera. + This means that as you travel from the bottom of the image plane to the + top you will be traveling in the direction of this vector. Note that + get_camera_up_direction() doesn't need to be orthogonal to the camera's + line of sight (i.e. get_camera_looking_at()-get_camera_pos()), it just + needs to not be an exact multiple of the line of sight. Any necessary + orthogonalization will be taken care of internally. + !*/ + + double get_camera_field_of_view( + ) const; + /*! + ensures + - returns the field of view of the camera in degrees. + !*/ + + unsigned long get_num_pixels( + ) const; + /*! + ensures + - 3D points that fall within the field of view of the camera are mapped by + operator() into the pixel coordinates of a get_num_pixels() by + get_num_pixels() image. Therefore, you can use the output of operator() + to index into an image. However, you still need to perform bounds + checking as there might be 3D points outside the field of view of the + camera and those will be mapped to 2D points outside the image. + !*/ + + }; + + void serialize (const camera_transform& item, std::ostream& out); + void deserialize (camera_transform& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_POINT_TrANSFORMS_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/geometry/rectangle.h b/ml/dlib/dlib/geometry/rectangle.h new file mode 100644 index 000000000..3d67ca8c4 --- /dev/null +++ b/ml/dlib/dlib/geometry/rectangle.h @@ -0,0 +1,824 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RECTANGLe_ +#define DLIB_RECTANGLe_ + +#include "rectangle_abstract.h" +#include "../algs.h" +#include +#include +#include "../serialize.h" +#include "vector.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rectangle + { + /*! + INITIAL VALUE + The initial value of this object is defined by its constructor. + + CONVENTION + left() == l + top() == t + right() == r + bottom() == b + !*/ + + public: + + rectangle ( + long l_, + long t_, + long r_, + long b_ + ) : + l(l_), + t(t_), + r(r_), + b(b_) + {} + + rectangle ( + unsigned long w, + unsigned long h + ) : + l(0), + t(0), + r(static_cast(w)-1), + b(static_cast(h)-1) + { + DLIB_ASSERT((w > 0 && h > 0) || (w == 0 && h == 0), + "\trectangle(width,height)" + << "\n\twidth and height must be > 0 or both == 0" + << "\n\twidth: " << w + << "\n\theight: " << h + << "\n\tthis: " << this + ); + } + + rectangle ( + const point& p + ) : + l(p.x()), + t(p.y()), + r(p.x()), + b(p.y()) + { + } + + rectangle ( + const point& p1, + const point& p2 + ) + { + *this = rectangle(p1) + rectangle(p2); + } + + template + rectangle ( + const vector& p1, + const vector& p2 + ) + { + *this = rectangle(p1) + rectangle(p2); + } + + rectangle ( + ) : + l(0), + t(0), + r(-1), + b(-1) + {} + + long top ( + ) const { return t; } + + long& top ( + ) { return t; } + + void set_top ( + long top_ + ) { t = top_; } + + long left ( + ) const { return l; } + + long& left ( + ) { return l; } + + void set_left ( + long left_ + ) { l = left_; } + + long right ( + ) const { return r; } + + long& right ( + ) { return r; } + + void set_right ( + long right_ + ) { r = right_; } + + long bottom ( + ) const { return b; } + + long& bottom ( + ) { return b; } + + void set_bottom ( + long bottom_ + ) { b = bottom_; } + + const point tl_corner ( + ) const { return point(left(), top()); } + + const point bl_corner ( + ) const { return point(left(), bottom()); } + + const point tr_corner ( + ) const { return point(right(), top()); } + + const point br_corner ( + ) const { return point(right(), bottom()); } + + unsigned long width ( + ) const + { + if (is_empty()) + return 0; + else + return r - l + 1; + } + + unsigned long height ( + ) const + { + if (is_empty()) + return 0; + else + return b - t + 1; + } + + unsigned long area ( + ) const + { + return width()*height(); + } + + bool is_empty ( + ) const { return (t > b || l > r); } + + rectangle operator + ( + const rectangle& rhs + ) const + { + if (rhs.is_empty()) + return *this; + else if (is_empty()) + return rhs; + + return rectangle ( + std::min(l,rhs.l), + std::min(t,rhs.t), + std::max(r,rhs.r), + std::max(b,rhs.b) + ); + } + + rectangle intersect ( + const rectangle& rhs + ) const + { + return rectangle ( + std::max(l,rhs.l), + std::max(t,rhs.t), + std::min(r,rhs.r), + std::min(b,rhs.b) + ); + } + + bool contains ( + const point& p + ) const + { + if (p.x() < l || p.x() > r || p.y() < t || p.y() > b) + return false; + return true; + } + + bool contains ( + long x, + long y + ) const + { + if (x < l || x > r || y < t || y > b) + return false; + return true; + } + + bool contains ( + const rectangle& rect + ) const + { + return (rect + *this == *this); + } + + rectangle& operator+= ( + const point& p + ) + { + *this = *this + rectangle(p); + return *this; + } + + rectangle& operator+= ( + const rectangle& rect + ) + { + *this = *this + rect; + return *this; + } + + bool operator== ( + const rectangle& rect + ) const + { + return (l == rect.l) && (t == rect.t) && (r == rect.r) && (b == rect.b); + } + + bool operator!= ( + const rectangle& rect + ) const + { + return !(*this == rect); + } + + inline bool operator< (const dlib::rectangle& b) const + { + if (left() < b.left()) return true; + else if (left() > b.left()) return false; + else if (top() < b.top()) return true; + else if (top() > b.top()) return false; + else if (right() < b.right()) return true; + else if (right() > b.right()) return false; + else if (bottom() < b.bottom()) return true; + else if (bottom() > b.bottom()) return false; + else return false; + } + + private: + long l; + long t; + long r; + long b; + }; + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const rectangle& item, + std::ostream& out + ) + { + try + { + serialize(item.left(),out); + serialize(item.top(),out); + serialize(item.right(),out); + serialize(item.bottom(),out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing an object of type rectangle"); + } + } + + inline void deserialize ( + rectangle& item, + std::istream& in + ) + { + try + { + deserialize(item.left(),in); + deserialize(item.top(),in); + deserialize(item.right(),in); + deserialize(item.bottom(),in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing an object of type rectangle"); + } + } + + inline std::ostream& operator<< ( + std::ostream& out, + const rectangle& item + ) + { + out << "[(" << item.left() << ", " << item.top() << ") (" << item.right() << ", " << item.bottom() << ")]"; + return out; + } + + inline std::istream& operator>>( + std::istream& in, + rectangle& item + ) + { + // ignore any whitespace + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n') + in.get(); + // now eat the leading '[' character + if (in.get() != '[') + { + in.setstate(in.rdstate() | std::ios::failbit); + return in; + } + + point p1, p2; + in >> p1; + in >> p2; + item = rectangle(p1) + rectangle(p2); + + // ignore any whitespace + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n') + in.get(); + // now eat the trailing ']' character + if (in.get() != ']') + { + in.setstate(in.rdstate() | std::ios::failbit); + } + return in; + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle centered_rect ( + long x, + long y, + unsigned long width, + unsigned long height + ) + { + rectangle result; + result.set_left ( x - static_cast(width) / 2 ); + result.set_top ( y - static_cast(height) / 2 ); + result.set_right ( result.left() + width - 1 ); + result.set_bottom ( result.top() + height - 1 ); + return result; + } + +// ---------------------------------------------------------------------------------------- + + inline rectangle intersect ( + const rectangle& a, + const rectangle& b + ) { return a.intersect(b); } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long area ( + const rectangle& a + ) { return a.area(); } + +// ---------------------------------------------------------------------------------------- + + inline point center ( + const dlib::rectangle& rect + ) + { + point temp(rect.left() + rect.right() + 1, + rect.top() + rect.bottom() + 1); + + if (temp.x() < 0) + temp.x() -= 1; + + if (temp.y() < 0) + temp.y() -= 1; + + return temp/2; + } + +// ---------------------------------------------------------------------------------------- + + inline dlib::vector dcenter ( + const dlib::rectangle& rect + ) + { + dlib::vector temp(rect.left() + rect.right(), + rect.top() + rect.bottom()); + + return temp/2.0; + } + +// ---------------------------------------------------------------------------------------- + + inline long distance_to_rect_edge ( + const rectangle& rect, + const point& p + ) + { + using std::max; + using std::min; + using std::abs; + + const long dist_x = min(abs(p.x()-rect.left()), abs(p.x()-rect.right())); + const long dist_y = min(abs(p.y()-rect.top()), abs(p.y()-rect.bottom())); + + if (rect.contains(p)) + return min(dist_x,dist_y); + else if (rect.left() <= p.x() && p.x() <= rect.right()) + return dist_y; + else if (rect.top() <= p.y() && p.y() <= rect.bottom()) + return dist_x; + else + return dist_x + dist_y; + } + +// ---------------------------------------------------------------------------------------- + + inline const point nearest_point ( + const rectangle& rect, + const point& p + ) + { + point temp(p); + if (temp.x() < rect.left()) + temp.x() = rect.left(); + else if (temp.x() > rect.right()) + temp.x() = rect.right(); + + if (temp.y() < rect.top()) + temp.y() = rect.top(); + else if (temp.y() > rect.bottom()) + temp.y() = rect.bottom(); + + return temp; + } + +// ---------------------------------------------------------------------------------------- + + inline size_t nearest_rect ( + const std::vector& rects, + const point& p + ) + { + DLIB_ASSERT(rects.size() > 0); + size_t idx = 0; + double best_dist = std::numeric_limits::infinity(); + + for (size_t i = 0; i < rects.size(); ++i) + { + if (rects[i].contains(p)) + { + return i; + } + else + { + double dist = (nearest_point(rects[i],p)-p).length(); + if (dist < best_dist) + { + best_dist = dist; + idx = i; + } + } + } + return idx; + } + +// ---------------------------------------------------------------------------------------- + + template + double distance_to_line ( + const std::pair,vector >& line, + const vector& p + ) + { + const vector delta = p-line.second; + const double along_dist = (line.first-line.second).normalize().dot(delta); + return std::sqrt(std::max(0.0,delta.length_squared() - along_dist*along_dist)); + } + +// ---------------------------------------------------------------------------------------- + + inline void clip_line_to_rectangle ( + const rectangle& box, + point& p1, + point& p2 + ) + { + // Now clip the line segment so it is contained inside box. + if (p1.x() == p2.x()) + { + if (!box.contains(p1)) + p1.y() = box.top(); + if (!box.contains(p2)) + p2.y() = box.bottom(); + } + else if (p1.y() == p2.y()) + { + if (!box.contains(p1)) + p1.x() = box.left(); + if (!box.contains(p2)) + p2.x() = box.right(); + } + else + { + // We use these relations to find alpha values. These values tell us + // how to generate points intersecting the rectangle boundaries. We then + // test the resulting points for ones that are inside the rectangle and output + // those. + //box.left() == alpha1*(p1.x()-p2.x()) + p2.x(); + //box.right() == alpha2*(p1.x()-p2.x()) + p2.x(); + + const point d = p1-p2; + double alpha1 = (box.left() -p2.x())/(double)d.x(); + double alpha2 = (box.right() -p2.x())/(double)d.x(); + double alpha3 = (box.top() -p2.y())/(double)d.y(); + double alpha4 = (box.bottom()-p2.y())/(double)d.y(); + + const point c1 = alpha1*d + p2; + const point c2 = alpha2*d + p2; + const point c3 = alpha3*d + p2; + const point c4 = alpha4*d + p2; + + if (!box.contains(p1)) + p1 = c1; + if (!box.contains(p2)) + p2 = c2; + if (box.contains(c3)) + { + if (!box.contains(p2)) + p2 = c3; + else if (!box.contains(p1)) + p1 = c3; + } + if (box.contains(c4)) + { + if (!box.contains(p2)) + p2 = c4; + else if (!box.contains(p1)) + p1 = c4; + } + } + + p1 = nearest_point(box, p1); + p2 = nearest_point(box, p2); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle centered_rect ( + const point& p, + unsigned long width, + unsigned long height + ) + { + return centered_rect(p.x(),p.y(),width,height); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle centered_rect ( + const rectangle& rect, + unsigned long width, + unsigned long height + ) + { + return centered_rect((rect.left()+rect.right())/2, (rect.top()+rect.bottom())/2, width, height); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle shrink_rect ( + const rectangle& rect, + long num + ) + { + return rectangle(rect.left()+num, rect.top()+num, rect.right()-num, rect.bottom()-num); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle grow_rect ( + const rectangle& rect, + long num + ) + { + return shrink_rect(rect, -num); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle shrink_rect ( + const rectangle& rect, + long width, + long height + ) + { + return rectangle(rect.left()+width, rect.top()+height, rect.right()-width, rect.bottom()-height); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle grow_rect ( + const rectangle& rect, + long width, + long height + ) + { + return shrink_rect(rect, -width, -height); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle translate_rect ( + const rectangle& rect, + const point& p + ) + { + rectangle result; + result.set_top ( rect.top() + p.y() ); + result.set_bottom ( rect.bottom() + p.y() ); + result.set_left ( rect.left() + p.x() ); + result.set_right ( rect.right() + p.x() ); + return result; + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle translate_rect ( + const rectangle& rect, + long x, + long y + ) + { + rectangle result; + result.set_top ( rect.top() + y ); + result.set_bottom ( rect.bottom() + y ); + result.set_left ( rect.left() + x ); + result.set_right ( rect.right() + x ); + return result; + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle resize_rect ( + const rectangle& rect, + unsigned long width, + unsigned long height + ) + { + return rectangle(rect.left(),rect.top(), + rect.left()+width-1, + rect.top()+height-1); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle resize_rect_width ( + const rectangle& rect, + unsigned long width + ) + { + return rectangle(rect.left(),rect.top(), + rect.left()+width-1, + rect.bottom()); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle resize_rect_height ( + const rectangle& rect, + unsigned long height + ) + { + return rectangle(rect.left(),rect.top(), + rect.right(), + rect.top()+height-1); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle move_rect ( + const rectangle& rect, + const point& p + ) + { + return rectangle(p.x(), p.y(), p.x()+rect.width()-1, p.y()+rect.height()-1); + } + +// ---------------------------------------------------------------------------------------- + + inline const rectangle move_rect ( + const rectangle& rect, + long x, + long y + ) + { + return rectangle(x, y, x+rect.width()-1, y+rect.height()-1); + } + +// ---------------------------------------------------------------------------------------- + + inline rectangle set_rect_area ( + const rectangle& rect, + unsigned long area + ) + { + DLIB_ASSERT(area > 0); + + if (rect.area() == 0) + { + // In this case we will make the output rectangle a square with the requested + // area. + unsigned long scale = std::round(std::sqrt(area)); + return centered_rect(rect, scale, scale); + } + else + { + double scale = std::sqrt(area/(double)rect.area()); + return centered_rect(rect, (long)std::round(rect.width()*scale), (long)std::round(rect.height()*scale)); + } + } + +// ---------------------------------------------------------------------------------------- + + inline rectangle set_aspect_ratio ( + const rectangle& rect, + double ratio + ) + { + DLIB_ASSERT(ratio > 0, + "\t rectangle set_aspect_ratio()" + << "\n\t ratio: " << ratio + ); + + // aspect ratio is w/h + + // we need to find the rectangle that is nearest to rect in area but + // with an aspect ratio of ratio. + + // w/h == ratio + // w*h == rect.area() + + if (ratio >= 1) + { + const long h = static_cast(std::sqrt(rect.area()/ratio) + 0.5); + const long w = static_cast(h*ratio + 0.5); + return centered_rect(rect, w, h); + } + else + { + const long w = static_cast(std::sqrt(rect.area()*ratio) + 0.5); + const long h = static_cast(w/ratio + 0.5); + return centered_rect(rect, w, h); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + inline const rectangle get_rect ( + const T& m + ) + { + return rectangle(0, 0, num_columns(m)-1, num_rows(m)-1); + } + +// ---------------------------------------------------------------------------------------- + + inline rectangle operator+ ( + const rectangle& r, + const point& p + ) + { + return r + rectangle(p); + } + +// ---------------------------------------------------------------------------------------- + + inline rectangle operator+ ( + const point& p, + const rectangle& r + ) + { + return r + rectangle(p); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RECTANGLe_ + + diff --git a/ml/dlib/dlib/geometry/rectangle_abstract.h b/ml/dlib/dlib/geometry/rectangle_abstract.h new file mode 100644 index 000000000..0ff0f0a8d --- /dev/null +++ b/ml/dlib/dlib/geometry/rectangle_abstract.h @@ -0,0 +1,836 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RECTANGLe_ABSTRACT_ +#ifdef DLIB_RECTANGLe_ABSTRACT_ + +#include "vector_abstract.h" +#include +#include "../serialize.h" + +namespace dlib +{ + + class rectangle + { + /*! + INITIAL VALUE + The initial value of this object is defined by its constructor. + + WHAT THIS OBJECT REPRESENTS + This object represents a rectangular region inside a Cartesian + coordinate system. The region is the rectangle with its top + left corner at position (left(),top()) and its bottom right corner + at (right(),bottom()). + + Note that the origin of the coordinate system, i.e. (0,0), is located + at the upper left corner. That is, points such as (1,1) or (3,5) + represent locations that are below and to the right of the origin. + + Also note that rectangles where top() > bottom() or left() > right() + represent empty rectangles. + !*/ + + public: + + rectangle ( + const rectangle& rect + ); + /*! + ensures + - #*this represents the same rectangle as rect + !*/ + + rectangle ( + ); + /*! + ensures + - #left() == 0 + - #top() == 0 + - #right() == -1 + - #bottom() == -1 + - #is_empty() == true + !*/ + + rectangle ( + long left_, + long top_, + long right_, + long bottom_ + ); + /*! + ensures + - #left() == left_ + - #top() == top_ + - #right() == right_ + - #bottom() == bottom_ + !*/ + + rectangle ( + unsigned long width_, + unsigned long height_ + ); + /*! + requires + - (width_ > 0 && height_ > 0) || (width_ == 0 && height_ == 0) + ensures + - #left() == 0 + - #top() == 0 + - #width() == width_ + - #height() == height_ + !*/ + + rectangle ( + const point& p + ); + /*! + ensures + - #left() == p.x() + - #top() == p.y() + - #right() == p.x() + - #bottom() == p.y() + !*/ + + template + rectangle ( + const vector& p1, + const vector& p2 + ); + /*! + ensures + - #*this == rectangle(p1) + rectangle(p2) + !*/ + + long left ( + ) const; + /*! + ensures + - returns the x coordinate for the left side of this rectangle + !*/ + + long& left ( + ); + /*! + ensures + - returns a non-const reference to the x coordinate for the left side + of this rectangle + !*/ + + void set_left ( + long left_ + ); + /*! + ensures + - #left() == left_ + !*/ + + long top ( + ) const; + /*! + ensures + - returns the y coordinate for the top of this rectangle + !*/ + + long& top ( + ); + /*! + ensures + - returns a non-const reference to the y coordinate for the + top of this rectangle + !*/ + + void set_top ( + long top_ + ); + /*! + ensures + - #top() == top_ + !*/ + + long right ( + ) const; + /*! + ensures + - returns the x coordinate for the right side of this rectangle + !*/ + + long& right ( + ); + /*! + ensures + - returns a non-const reference to the x coordinate for the right + side of this rectangle + !*/ + + void set_right ( + long right_ + ); + /*! + ensures + - #right() == right_ + !*/ + + long bottom ( + ) const; + /*! + ensures + - returns the y coordinate for the bottom of this rectangle + !*/ + + long& bottom ( + ); + /*! + ensures + - returns a non-const reference to the y coordinate for the bottom + of this rectangle + !*/ + + void set_bottom ( + long bottom_ + ); + /*! + ensures + - #bottom() == bottom_ + !*/ + + const point tl_corner ( + ) const; + /*! + ensures + - returns point(left(), top()) + (i.e. returns the top left corner point for this rectangle) + !*/ + + const point bl_corner ( + ) const; + /*! + ensures + - returns point(left(), bottom()) + (i.e. returns the bottom left corner point for this rectangle) + !*/ + + const point tr_corner ( + ) const; + /*! + ensures + - returns point(right(), top()) + (i.e. returns the top right corner point for this rectangle) + !*/ + + const point br_corner ( + ) const; + /*! + ensures + - returns point(right(), bottom()) + (i.e. returns the bottom right corner point for this rectangle) + !*/ + + bool is_empty ( + ) const; + /*! + ensures + - if (top() > bottom() || left() > right()) then + - returns true + - else + - returns false + !*/ + + unsigned long width ( + ) const; + /*! + ensures + - if (is_empty()) then + - returns 0 + - else + - returns the width of this rectangle. + (i.e. right() - left() + 1) + !*/ + + unsigned long height ( + ) const; + /*! + ensures + - if (is_empty()) then + - returns 0 + - else + - returns the height of this rectangle. + (i.e. bottom() - top() + 1) + !*/ + + unsigned long area ( + ) const; + /*! + ensures + - returns width()*height() + !*/ + + rectangle operator + ( + const rectangle& rhs + ) const; + /*! + ensures + - if (rhs.is_empty() == false && this->is_empty() == false) then + - returns the smallest rectangle that contains both *this and + rhs. + - if (rhs.is_empty() == true && this->is_empty() == false) then + - returns *this + - if (rhs.is_empty() == false && this->is_empty() == true) then + - returns rhs + - if (rhs.is_empty() == true && this->is_empty() == true) then + - returns a rectangle that has is_empty() == true + !*/ + + rectangle intersect ( + const rectangle& rhs + ) const; + /*! + ensures + - if (there is a region of intersection between *this and rhs) then + - returns a rectangle that represents the intersection of *this + and rhs. + - else + - returns a rectangle where is_empty() == true + !*/ + + bool contains ( + long x, + long y + ) const; + /*! + ensures + - if (the point (x,y) is contained in this rectangle) then + - returns true + - else + - returns false + !*/ + + bool contains ( + const point& p + ) const; + /*! + ensures + - if (the point (p.x(),p.y()) is contained in this rectangle) then + - returns true + - else + - returns false + !*/ + + bool contains ( + const rectangle& rect + ) const; + /*! + ensures + - if (rect + *this == *this) then + - returns true + (i.e. returns true if *this contains the given rectangle) + - else + - returns false + !*/ + + rectangle& operator= ( + const rectangle& rect + ); + /*! + ensures + - #*this represents the same rectangle as rect + - returns #*this + !*/ + + rectangle& operator+= ( + const rectangle& rect + ); + /*! + ensures + - #*this == *this + rect + - returns #*this + !*/ + + bool operator== ( + const rectangle& rect + ) const; + /*! + ensures + - if (top() == rect.top() && left() == rect.left() && + right() == rect.right() && bottom() == rect.bottom()) then + - returns true + - else + - returns false + !*/ + + bool operator!= ( + const rectangle& rect + ) const; + /*! + ensures + - returns !(*this == rect) + !*/ + + bool operator< ( + const dlib::rectangle& a, + const dlib::rectangle& b + ) const; + /*! + ensures + - Defines a total ordering over rectangles so they can be used in + associative containers. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const rectangle& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + rectangle& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + std::ostream& operator<< ( + std::ostream& out, + const rectangle& item + ); + /*! + ensures + - writes item to out in the form "[(left, top) (right, bottom)]" + !*/ + + std::istream& operator>>( + std::istream& in, + rectangle& item + ); + /*! + ensures + - reads a rectangle from the input stream in and stores it in #item. + The data in the input stream should be of the form [(left, top) (right, bottom)] + !*/ + +// ---------------------------------------------------------------------------------------- + + point center ( + const dlib::rectangle& rect + ); + /*! + ensures + - returns the center of the given rectangle + !*/ + +// ---------------------------------------------------------------------------------------- + + dlib::vector dcenter ( + const dlib::rectangle& rect + ); + /*! + ensures + - returns the center of the given rectangle using a real valued vector. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline const rectangle centered_rect ( + const point& p, + unsigned long width, + unsigned long height + ); + /*! + ensures + - returns a rectangle R such that: + - center(R) == p + - if (width == 0 || height == 0) + - R.width() == 0 + - R.height() == 0 + - else + - R.width() == width + - R.height() == height + - R.tl_corner() == point(p.x()-width/2, p.y()-height/2) + !*/ + +// ---------------------------------------------------------------------------------------- + + const rectangle centered_rect ( + long x, + long y, + unsigned long width, + unsigned long height + ); + /*! + ensures + - returns a rectangle R such that: + - center(R) == p + - if (width == 0 || height == 0) + - R.width() == 0 + - R.height() == 0 + - else + - R.width() == width + - R.height() == height + - R.tl_corner() == point(x-width/2, y-height/2) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline const rectangle centered_rect ( + const rectangle& rect, + unsigned long width, + unsigned long height + ); + /*! + ensures + - returns centered_rect( (rect.tl_corner() + rect.br_corner())/2, width, height) + (i.e. returns a rectangle centered on rect but with the given width + and height) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline rectangle set_rect_area ( + const rectangle& rect, + unsigned long area + ); + /*! + requires + - area > 0 + ensures + - Returns a rectangle R such that: + - center(R) == center(rect) + - R has the same aspect ratio as rect. If rect.area() == 0 then the + returned rect has a 1:1 aspect ratio. + - R.area() == area + !*/ + +// ---------------------------------------------------------------------------------------- + + inline rectangle set_aspect_ratio ( + const rectangle& rect, + double ratio + ); + /*! + requires + - ratio > 0 + ensures + - This function reshapes the given rectangle so that it has the given aspect + ratio. In particular, this means we return a rectangle R such that the + following equations are as true as possible: + - R.width()/R.height() == ratio + - R.area() == rect.area() + - center(rect) == center(R) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline rectangle intersect ( + const rectangle& a, + const rectangle& b + ); + /*! + ensures + - returns a.intersect(b) + (i.e. returns a rectangle representing the intersection of a and b) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline unsigned long area ( + const rectangle& a + ); + /*! + ensures + - returns a.area() + !*/ + +// ---------------------------------------------------------------------------------------- + + inline const rectangle shrink_rect ( + const rectangle& rect, + long num + ); + /*! + ensures + - returns rectangle(rect.left()+num, rect.top()+num, rect.right()-num, rect.bottom()-num) + (i.e. shrinks the given rectangle by shrinking its border by num) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline const rectangle grow_rect ( + const rectangle& rect, + long num + ); + /*! + ensures + - return shrink_rect(rect, -num) + (i.e. grows the given rectangle by expanding its border by num) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline const rectangle shrink_rect ( + const rectangle& rect, + long width, + long height + ); + /*! + ensures + - returns rectangle(rect.left()+width, rect.top()+height, rect.right()-width, rect.bottom()-height) + (i.e. shrinks the given rectangle by shrinking its left and right borders by width + and its top and bottom borders by height. ) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline const rectangle grow_rect ( + const rectangle& rect, + long width, + long height + ); + /*! + ensures + - return shrink_rect(rect, -width, -height) + (i.e. grows the given rectangle by expanding its border) + !*/ + +// ---------------------------------------------------------------------------------------- + + const rectangle translate_rect ( + const rectangle& rect, + const point& p + ); + /*! + ensures + - returns a rectangle R such that: + - R.left() == rect.left() + p.x() + - R.right() == rect.right() + p.x() + - R.top() == rect.top() + p.y() + - R.bottom() == rect.bottom() + p.y() + !*/ + +// ---------------------------------------------------------------------------------------- + + const rectangle translate_rect ( + const rectangle& rect, + long x, + long y + ); + /*! + ensures + - returns a rectangle R such that: + - R.left() == rect.left() + x + - R.right() == rect.right() + x + - R.top() == rect.top() + y + - R.bottom() == rect.bottom() + y + !*/ + +// ---------------------------------------------------------------------------------------- + + const rectangle resize_rect ( + const rectangle& rect, + unsigned long width, + unsigned long height + ); + /*! + ensures + - returns a rectangle R such that: + - if (width == 0 || height == 0) + - R.width() == 0 + - R.height() == 0 + - else + - R.width() == width + - R.height() == height + - R.left() == rect.left() + - R.top() == rect.top() + !*/ + +// ---------------------------------------------------------------------------------------- + + const rectangle resize_rect_width ( + const rectangle& rect, + unsigned long width + ); + /*! + ensures + - returns a rectangle R such that: + - R.width() == width + - R.left() == rect.left() + - R.top() == rect.top() + - R.bottom() == rect.bottom() + !*/ + +// ---------------------------------------------------------------------------------------- + + const rectangle resize_rect_height ( + const rectangle& rect, + unsigned long height + ); + /*! + ensures + - returns a rectangle R such that: + - R.height() == height + - R.left() == rect.left() + - R.top() == rect.top() + - R.right() == rect.right() + !*/ + +// ---------------------------------------------------------------------------------------- + + const rectangle move_rect ( + const rectangle& rect, + const point& p + ); + /*! + ensures + - returns a rectangle R such that: + - R.width() == rect.width() + - R.height() == rect.height() + - R.left() == p.x() + - R.top() == p.y() + !*/ + +// ---------------------------------------------------------------------------------------- + + const rectangle move_rect ( + const rectangle& rect, + long x, + long y + ); + /*! + ensures + - returns a rectangle R such that: + - R.width() == rect.width() + - R.height() == rect.height() + - R.left() == x + - R.top() == y + !*/ + +// ---------------------------------------------------------------------------------------- + + inline const point nearest_point ( + const rectangle& rect, + const point& p + ); + /*! + ensures + - if (rect.contains(p)) then + - returns p + - else + - returns the point in rect that is closest to p + !*/ + +// ---------------------------------------------------------------------------------------- + + inline size_t nearest_rect ( + const std::vector& rects, + const point& p + ); + /*! + requires + - rects.size() > 0 + ensures + - returns the index of the rectangle that is closest to the point p. In + particular, this function returns an IDX such that: + length(nearest_point(rects[IDX],p) - p) + is minimized. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline long distance_to_rect_edge ( + const rectangle& rect, + const point& p + ); + /*! + ensures + - returns the Manhattan distance between the edge of rect and p. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + double distance_to_line ( + const std::pair,vector >& line, + const vector& p + ); + /*! + ensures + - returns the euclidean distance between the given line and the point p. That + is, given a line that passes though the points line.first and line.second, + what is the distance between p and the nearest point on the line? This + function returns that distance. + !*/ + +// ---------------------------------------------------------------------------------------- + + void clip_line_to_rectangle ( + const rectangle& box, + point& p1, + point& p2 + ); + /*! + ensures + - clips the line segment that goes from points p1 to p2 so that it is entirely + within the given box. In particular, we will have: + - box.contains(#p1) == true + - box.contains(#p2) == true + - The line segment #p1 to #p2 is entirely contained within the line segment + p1 to p2. Moreover, #p1 to #p2 is the largest such line segment that + fits within the given box. + - If the line segment does not intersect the box then the result is some + arbitrary line segment inside the box. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const rectangle get_rect ( + const T& m + ); + /*! + requires + - It must be possible to determine the number of "rows" and "columns" in m. + Either by calling num_rows(m) and num_columns(m) or by calling m.nr() and + m.nc() to obtain the number of rows and columns respectively. Moreover, + these routines should return longs. + ensures + - returns rectangle(0, 0, num_columns(m)-1, num_rows(m)-1) + (i.e. assuming T represents some kind of rectangular grid, such as + the dlib::matrix or dlib::array2d objects, this function returns the + bounding rectangle for that gridded object.) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline rectangle operator+ ( + const rectangle& r, + const point& p + ); + /*! + ensures + - returns r + rectangle(p) + (i.e. returns the rectangle that contains both r and p) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline rectangle operator+ ( + const point& p, + const rectangle& r + ); + /*! + ensures + - returns r + rectangle(p) + (i.e. returns the rectangle that contains both r and p) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RECTANGLe_ABSTRACT_ + diff --git a/ml/dlib/dlib/geometry/vector.h b/ml/dlib/dlib/geometry/vector.h new file mode 100644 index 000000000..4ea53799d --- /dev/null +++ b/ml/dlib/dlib/geometry/vector.h @@ -0,0 +1,1330 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_VECTOr_H_ +#define DLIB_VECTOr_H_ + +#include +#include "vector_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include +#include +#include "../matrix/matrix.h" +#include + +#if defined(_MSC_VER) && _MSC_VER < 1400 +// Despite my efforts to disabuse visual studio of its usual nonsense I can't find a +// way to make this warning go away without just disabling it. This is the warning: +// dlib\geometry\vector.h(129) : warning C4805: '==' : unsafe mix of type 'std::numeric_limits<_Ty>::is_integer' and type 'bool' in operation +// +#pragma warning(disable:4805) +#endif + +namespace dlib +{ + + template < + typename T, + long NR = 3 + > + class vector; + +// ---------------------------------------------------------------------------------------- + + template + struct vect_promote; + + template + struct largest_type + { + typedef T type; + }; + template + struct largest_type + { + typedef U type; + }; + + template + struct vect_promote::is_integer == std::numeric_limits::is_integer>::type> + { + // If both T and U are both either integral or non-integral then just + // use the biggest one + typedef typename largest_type::type type; + }; + + template + struct vect_promote::is_integer != std::numeric_limits::is_integer>::type> + { + typedef double type; + }; + +// ---------------------------------------------------------------------------------------- + + // This insanity here is to work around a bug in visual studio 8. These two rebind + // structures are actually declared at a few points in this file because just having the + // one declaration here isn't enough for visual studio. It takes the three spread around + // to avoid all its bugs. + template + struct vc_rebind + { + typedef vector type; + }; + template + struct vc_rebind_promote + { + typedef vector::type,N> type; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct vector_assign_helper + { + template + static void assign ( + vector& dest, + const vector& src + ) + { + dest.x() = static_cast(src.x()); + dest.y() = static_cast(src.y()); + } + + template + static void assign ( + vector& dest, + const vector& src + ) + { + dest.x() = static_cast(src.x()); + dest.y() = static_cast(src.y()); + dest.z() = static_cast(src.z()); + } + + template + static void assign ( + vector& dest, + const matrix_exp& m + ) + { + T x = static_cast(m(0)); + T y = static_cast(m(1)); + dest.x() = x; + dest.y() = y; + } + + template + static void assign ( + vector& dest, + const matrix_exp& m + ) + { + T x = static_cast(m(0)); + T y = static_cast(m(1)); + T z = static_cast(m(2)); + + dest.x() = x; + dest.y() = y; + dest.z() = z; + } + }; + + // This is an overload for the case where you are converting from a floating point + // type to an integral type. These overloads make sure values are rounded to + // the nearest integral value. + template + struct vector_assign_helper::is_integer == true && + std::numeric_limits::is_integer == false>::type> + { + template + static void assign ( + vector& dest, + const vector& src + ) + { + dest.x() = static_cast(std::floor(src.x() + 0.5)); + dest.y() = static_cast(std::floor(src.y() + 0.5)); + } + + template + static void assign ( + vector& dest, + const vector& src + ) + { + dest.x() = static_cast(std::floor(src.x() + 0.5)); + dest.y() = static_cast(std::floor(src.y() + 0.5)); + dest.z() = static_cast(std::floor(src.z() + 0.5)); + } + + template + static void assign ( + vector& dest, + const matrix_exp& m + ) + { + dest.x() = static_cast(std::floor(m(0) + 0.5)); + dest.y() = static_cast(std::floor(m(1) + 0.5)); + dest.z() = static_cast(std::floor(m(2) + 0.5)); + } + + template + static void assign ( + vector& dest, + const matrix_exp& m + ) + { + dest.x() = static_cast(std::floor(m(0) + 0.5)); + dest.y() = static_cast(std::floor(m(1) + 0.5)); + } + + }; + +// ---------------------------------------------------------------------------------------- + + template + class vector : public matrix + { + /*! + INITIAL VALUE + - x() == 0 + - y() == 0 + - z() == 0 + + CONVENTION + - (*this)(0) == x() + - (*this)(1) == y() + - (*this)(2) == z() + + !*/ + + // This insanity here is to work around a bug in visual studio 8. + template + struct vc_rebind + { + typedef vector type; + }; + template + struct vc_rebind_promote + { + typedef vector::type,N> type; + }; + + public: + + typedef T type; + + vector ( + ) + { + x() = 0; + y() = 0; + z() = 0; + } + + // --------------------------------------- + + vector ( + const T _x, + const T _y, + const T _z + ) + { + x() = _x; + y() = _y; + z() = _z; + } + + // --------------------------------------- + + vector ( + const vector& item + ) : matrix(item) + { + } + + // --------------------------------------- + + template + vector ( + const vector& item + ) + { + // Do this so that we get the appropriate rounding depending on the relative + // type of T and U. + vector temp(item); + x() = temp.x(); + y() = temp.y(); + z() = 0; + } + + // --------------------------------------- + + vector ( + const vector& item + ) + { + x() = item.x(); + y() = item.y(); + z() = 0; + } + + // --------------------------------------- + + template + vector ( + const vector& item + ) + { + (*this) = item; + } + + // --------------------------------------- + + template + vector ( const matrix_exp& m) + { + (*this) = m; + } + + // --------------------------------------- + + template + vector& operator = ( + const matrix_exp& m + ) + { + // you can only assign vectors with 3 elements to a dlib::vector object + COMPILE_TIME_ASSERT(EXP::NR*EXP::NC == 3 || EXP::NR*EXP::NC == 0); + + // make sure requires clause is not broken + DLIB_ASSERT((m.nr() == 1 || m.nc() == 1) && (m.size() == 3), + "\t vector(const matrix_exp& m)" + << "\n\t the given matrix is of the wrong size" + << "\n\t m.nr(): " << m.nr() + << "\n\t m.nc(): " << m.nc() + << "\n\t m.size(): " << m.size() + << "\n\t this: " << this + ); + + vector_assign_helper::assign(*this, m); + return *this; + } + + // --------------------------------------- + + template + vector& operator = ( + const vector& item + ) + { + vector_assign_helper::assign(*this, item); + return *this; + } + + // --------------------------------------- + + vector& operator= ( + const vector& item + ) + { + x() = item.x(); + y() = item.y(); + z() = item.z(); + return *this; + } + + // --------------------------------------- + + double length( + ) const + { + return std::sqrt((double)(x()*x() + y()*y() + z()*z())); + } + + // --------------------------------------- + + double length_squared( + ) const + { + return (double)(x()*x() + y()*y() + z()*z()); + } + + // --------------------------------------- + + typename vc_rebind::type normalize ( + ) const + { + const double tmp = std::sqrt((double)(x()*x() + y()*y() + z()*z())); + return vector ( x()/tmp, + y()/tmp, + z()/tmp + ); + } + + // --------------------------------------- + + T& x ( + ) + { + return (*this)(0); + } + + // --------------------------------------- + + T& y ( + ) + { + return (*this)(1); + } + + // --------------------------------------- + + T& z ( + ) + { + return (*this)(2); + } + + // --------------------------------------- + + const T& x ( + ) const + { + return (*this)(0); + } + + // --------------------------------------- + + const T& y ( + ) const + { + return (*this)(1); + } + + // --------------------------------------- + + const T& z ( + ) const + { + return (*this)(2); + } + + // --------------------------------------- + + T dot ( + const vector& rhs + ) const + { + return x()*rhs.x() + y()*rhs.y() + z()*rhs.z(); + } + + // --------------------------------------- + + template + typename vect_promote::type dot ( + const vector& rhs + ) const + { + return x()*rhs.x() + y()*rhs.y() + z()*rhs.z(); + } + + // --------------------------------------- + + template + typename vc_rebind_promote::type cross ( + const vector& rhs + ) const + { + typedef vector::type,3> ret_type; + + return ret_type ( + y()*rhs.z() - z()*rhs.y(), + z()*rhs.x() - x()*rhs.z(), + x()*rhs.y() - y()*rhs.x() + ); + } + + // --------------------------------------- + + vector& operator += ( + const vector& rhs + ) + { + x() += rhs.x(); + y() += rhs.y(); + z() += rhs.z(); + return *this; + } + + // --------------------------------------- + + vector& operator -= ( + const vector& rhs + ) + { + x() -= rhs.x(); + y() -= rhs.y(); + z() -= rhs.z(); + return *this; + } + + // --------------------------------------- + + vector& operator /= ( + const T& rhs + ) + { + x() /= rhs; + y() /= rhs; + z() /= rhs; + return *this; + } + + // --------------------------------------- + + vector& operator *= ( + const T& rhs + ) + { + x() *= rhs; + y() *= rhs; + z() *= rhs; + return *this; + } + + // --------------------------------------- + + vector operator - ( + ) const + { + return vector(-x(), -y(), -z()); + } + + // --------------------------------------- + + template + typename vc_rebind_promote::type operator / ( + const U& val + ) const + { + typedef vector::type,3> ret_type; + return ret_type(x()/val, y()/val, z()/val); + } + + // --------------------------------------- + + template + bool operator== ( + const vector& rhs + ) const + { + return x()==rhs.x() && y()==rhs.y() && z()==rhs.z(); + } + + // --------------------------------------- + + template + bool operator!= ( + const vector& rhs + ) const + { + return !(*this == rhs); + } + + // --------------------------------------- + + void swap ( + vector& item + ) + { + dlib::exchange(x(), item.x()); + dlib::exchange(y(), item.y()); + dlib::exchange(z(), item.z()); + } + + // --------------------------------------- + + }; + +// ---------------------------------------------------------------------------------------- + + template + class vector : public matrix + { + /*! + INITIAL VALUE + - x() == 0 + - y() == 0 + + CONVENTION + - (*this)(0) == x() + - (*this)(1) == y() + - z() == 0 + !*/ + + // This insanity here is to work around a bug in visual studio 8. + template + struct vc_rebind + { + typedef vector type; + }; + template + struct vc_rebind_promote + { + typedef vector::type,N> type; + }; + + + public: + + typedef T type; + + vector ( + ) + { + x() = 0; + y() = 0; + } + + // --------------------------------------- + + vector ( + const T _x, + const T _y + ) + { + x() = _x; + y() = _y; + } + + // --------------------------------------- + + template + vector ( + const vector& item + ) + { + // Do this so that we get the appropriate rounding depending on the relative + // type of T and U. + vector temp(item); + x() = temp.x(); + y() = temp.y(); + } + + // --------------------------------------- + + vector ( + const vector& item + ) : matrix(item) + { + } + + // --------------------------------------- + + vector ( + const vector& item + ) + { + x() = item.x(); + y() = item.y(); + } + + // --------------------------------------- + + template + vector ( + const vector& item + ) + { + (*this) = item; + } + + // --------------------------------------- + + template + vector ( const matrix_exp& m) + { + (*this) = m; + } + + // --------------------------------------- + + template + vector& operator = ( + const matrix_exp& m + ) + { + // you can only assign vectors with 2 elements to a dlib::vector object + COMPILE_TIME_ASSERT(EXP::NR*EXP::NC == 2 || EXP::NR*EXP::NC == 0); + + // make sure requires clause is not broken + DLIB_ASSERT((m.nr() == 1 || m.nc() == 1) && (m.size() == 2), + "\t vector(const matrix_exp& m)" + << "\n\t the given matrix is of the wrong size" + << "\n\t m.nr(): " << m.nr() + << "\n\t m.nc(): " << m.nc() + << "\n\t m.size(): " << m.size() + << "\n\t this: " << this + ); + + vector_assign_helper::assign(*this, m); + return *this; + } + + // --------------------------------------- + + template + vector& operator = ( + const vector& item + ) + { + vector_assign_helper::assign(*this, item); + return *this; + } + + // --------------------------------------- + + vector& operator= ( + const vector& item + ) + { + x() = item.x(); + y() = item.y(); + return *this; + } + + // --------------------------------------- + + double length( + ) const + { + return std::sqrt((double)(x()*x() + y()*y())); + } + + // --------------------------------------- + + double length_squared( + ) const + { + return (double)(x()*x() + y()*y()); + } + + // --------------------------------------- + + typename vc_rebind::type normalize ( + ) const + { + const double tmp = std::sqrt((double)(x()*x() + y()*y())); + return vector ( x()/tmp, + y()/tmp + ); + } + + // --------------------------------------- + + T& x ( + ) + { + return (*this)(0); + } + + // --------------------------------------- + + T& y ( + ) + { + return (*this)(1); + } + + // --------------------------------------- + + const T& x ( + ) const + { + return (*this)(0); + } + + // --------------------------------------- + + const T& y ( + ) const + { + return (*this)(1); + } + + // --------------------------------------- + + const T z ( + ) const + { + return 0; + } + + // --------------------------------------- + + T dot ( + const vector& rhs + ) const + { + return x()*rhs.x() + y()*rhs.y(); + } + + // --------------------------------------- + + template + typename vect_promote::type dot ( + const vector& rhs + ) const + { + return x()*rhs.x() + y()*rhs.y() + z()*rhs.z(); + } + + // --------------------------------------- + + vector& operator += ( + const vector& rhs + ) + { + x() += rhs.x(); + y() += rhs.y(); + return *this; + } + + // --------------------------------------- + + vector& operator -= ( + const vector& rhs + ) + { + x() -= rhs.x(); + y() -= rhs.y(); + return *this; + } + + // --------------------------------------- + + vector& operator /= ( + const T& rhs + ) + { + x() /= rhs; + y() /= rhs; + return *this; + } + + // --------------------------------------- + + vector& operator *= ( + const T& rhs + ) + { + x() *= rhs; + y() *= rhs; + return *this; + } + + // --------------------------------------- + + vector operator - ( + ) const + { + return vector(-x(), -y()); + } + + // --------------------------------------- + + template + typename vc_rebind_promote::type operator / ( + const U& val + ) const + { + typedef vector::type,2> ret_type; + return ret_type(x()/val, y()/val); + } + + // --------------------------------------- + + template + bool operator== ( + const vector& rhs + ) const + { + return x()==rhs.x() && y()==rhs.y() && z()==rhs.z(); + } + + // --------------------------------------- + + bool operator== ( + const vector& rhs + ) const + { + return x()==rhs.x() && y()==rhs.y(); + } + + // --------------------------------------- + + template + bool operator!= ( + const vector& rhs + ) const + { + return !(*this == rhs); + } + + // --------------------------------------- + + bool operator!= ( + const vector& rhs + ) const + { + return !(*this == rhs); + } + + // --------------------------------------- + + void swap ( + vector& item + ) + { + dlib::exchange(x(), item.x()); + dlib::exchange(y(), item.y()); + } + + // --------------------------------------- + + template + typename vc_rebind_promote::type cross ( + const vector& rhs + ) const + { + typedef vector::type,3> ret_type; + return ret_type ( + y()*rhs.z(), + - x()*rhs.z(), + x()*rhs.y() - y()*rhs.x() + ); + } + + // --------------------------------------- + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + inline const typename vc_rebind_promote::type operator+ ( + const vector& lhs, + const vector& rhs + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(lhs.x()+rhs.x(), lhs.y()+rhs.y()); + } + +// ---------------------------------------------------------------------------------------- + + template + inline const typename vc_rebind_promote::type operator+ ( + const vector& lhs, + const vector& rhs + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(lhs.x()+rhs.x(), lhs.y()+rhs.y(), lhs.z()+rhs.z()); + } + +// ---------------------------------------------------------------------------------------- + + template + inline const typename vc_rebind_promote::type operator+ ( + const vector& lhs, + const vector& rhs + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(lhs.x()+rhs.x(), lhs.y()+rhs.y(), rhs.z()); + } + +// ---------------------------------------------------------------------------------------- + + template + inline const typename vc_rebind_promote::type operator+ ( + const vector& lhs, + const vector& rhs + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(lhs.x()+rhs.x(), lhs.y()+rhs.y(), lhs.z()); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + inline const typename vc_rebind_promote::type operator- ( + const vector& lhs, + const vector& rhs + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(lhs.x()-rhs.x(), lhs.y()-rhs.y()); + } + +// ---------------------------------------------------------------------------------------- + + template + inline const typename vc_rebind_promote::type operator- ( + const vector& lhs, + const vector& rhs + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(lhs.x()-rhs.x(), lhs.y()-rhs.y(), lhs.z()-rhs.z()); + } + +// ---------------------------------------------------------------------------------------- + + template + inline const typename vc_rebind_promote::type operator- ( + const vector& lhs, + const vector& rhs + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(lhs.x()-rhs.x(), lhs.y()-rhs.y(), -rhs.z()); + } + +// ---------------------------------------------------------------------------------------- + + template + inline const typename vc_rebind_promote::type operator- ( + const vector& lhs, + const vector& rhs + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(lhs.x()-rhs.x(), lhs.y()-rhs.y(), lhs.z()); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + inline typename disable_if, const typename vc_rebind_promote::type >::type operator* ( + const vector& v, + const U& s + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(v.x()*s, v.y()*s); + } + +// ---------------------------------------------------------------------------------------- + + template + inline typename disable_if, const typename vc_rebind_promote::type >::type operator* ( + const U& s, + const vector& v + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(v.x()*s, v.y()*s); + } + +// ---------------------------------------------------------------------------------------- + + template + inline typename disable_if, const typename vc_rebind_promote::type >::type operator* ( + const vector& v, + const U& s + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(v.x()*s, v.y()*s, v.z()*s); + } + +// ---------------------------------------------------------------------------------------- + + template + inline typename disable_if, const typename vc_rebind_promote::type >::type operator* ( + const U& s, + const vector& v + ) + { + typedef typename vc_rebind_promote::type ret_type; + return ret_type(v.x()*s, v.y()*s, v.z()*s); + } + +// ---------------------------------------------------------------------------------------- + + template + inline void swap ( + vector & a, + vector & b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template + inline void serialize ( + const vector& item, + std::ostream& out + ) + { + try + { + serialize(item.x(),out); + serialize(item.y(),out); + serialize(item.z(),out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type vector"); + } + } + + template + inline void deserialize ( + vector& item, + std::istream& in + ) + { + try + { + deserialize(item.x(),in); + deserialize(item.y(),in); + deserialize(item.z(),in); + } + catch (serialization_error& e) + { + item.x() = 0; + item.y() = 0; + item.z() = 0; + throw serialization_error(e.info + "\n while deserializing object of type vector"); + } + } + +// ---------------------------------------------------------------------------------------- + + template + inline void serialize ( + const vector& item, + std::ostream& out + ) + { + try + { + serialize(item.x(),out); + serialize(item.y(),out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type vector"); + } + } + + template + inline void deserialize ( + vector& item, + std::istream& in + ) + { + try + { + deserialize(item.x(),in); + deserialize(item.y(),in); + } + catch (serialization_error& e) + { + item.x() = 0; + item.y() = 0; + throw serialization_error(e.info + "\n while deserializing object of type vector"); + } + } + +// ---------------------------------------------------------------------------------------- + + template + std::ostream& operator<< ( + std::ostream& out, + const vector& item + ) + { + out << "(" << item.x() << ", " << item.y() << ", " << item.z() << ")"; + return out; + } + + template + std::istream& operator>>( + std::istream& in, + vector& item + ) + { + + // eat all the crap up to the '(' + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n') + in.get(); + + // there should be a '(' if not then this is an error + if (in.get() != '(') + { + in.setstate(in.rdstate() | std::ios::failbit); + return in; + } + + // eat all the crap up to the first number + while (in.peek() == ' ' || in.peek() == '\t') + in.get(); + in >> item.x(); + + if (!in.good()) + return in; + + // eat all the crap up to the next number + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == ',') + in.get(); + in >> item.y(); + + if (!in.good()) + return in; + + // eat all the crap up to the next number + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == ',') + in.get(); + in >> item.z(); + + if (!in.good()) + return in; + + // eat all the crap up to the ')' + while (in.peek() == ' ' || in.peek() == '\t') + in.get(); + + // there should be a ')' if not then this is an error + if (in.get() != ')') + in.setstate(in.rdstate() | std::ios::failbit); + return in; + } + +// ---------------------------------------------------------------------------------------- + + + template + std::ostream& operator<< ( + std::ostream& out, + const vector& item + ) + { + out << "(" << item.x() << ", " << item.y() << ")"; + return out; + } + + template + std::istream& operator>>( + std::istream& in, + vector& item + ) + { + + // eat all the crap up to the '(' + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\r' || in.peek() == '\n') + in.get(); + + // there should be a '(' if not then this is an error + if (in.get() != '(') + { + in.setstate(in.rdstate() | std::ios::failbit); + return in; + } + + // eat all the crap up to the first number + while (in.peek() == ' ' || in.peek() == '\t') + in.get(); + in >> item.x(); + + if (!in.good()) + return in; + + // eat all the crap up to the next number + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == ',') + in.get(); + in >> item.y(); + + if (!in.good()) + return in; + + // eat all the crap up to the ')' + while (in.peek() == ' ' || in.peek() == '\t') + in.get(); + + // there should be a ')' if not then this is an error + if (in.get() != ')') + in.setstate(in.rdstate() | std::ios::failbit); + return in; + } + +// ---------------------------------------------------------------------------------------- + + typedef vector point; + typedef vector dpoint; + +// ---------------------------------------------------------------------------------------- + +} + +namespace std +{ + /*! + Define std::less > so that you can use vectors in the associative containers. + !*/ + template + struct less > + { + typedef dlib::vector first_argument_type; + typedef dlib::vector second_argument_type; + typedef bool result_type; + inline bool operator() (const dlib::vector & a, const dlib::vector & b) const + { + if (a.x() < b.x()) return true; + else if (a.x() > b.x()) return false; + else if (a.y() < b.y()) return true; + else if (a.y() > b.y()) return false; + else if (a.z() < b.z()) return true; + else if (a.z() > b.z()) return false; + else return false; + } + }; + + /*! + Define std::less > so that you can use vectors in the associative containers. + !*/ + template + struct less > + { + typedef dlib::vector first_argument_type; + typedef dlib::vector second_argument_type; + typedef bool result_type; + inline bool operator() (const dlib::vector & a, const dlib::vector & b) const + { + if (a.x() < b.x()) return true; + else if (a.x() > b.x()) return false; + else if (a.y() < b.y()) return true; + else if (a.y() > b.y()) return false; + else return false; + } + }; +} + +#if defined(_MSC_VER) && _MSC_VER < 1400 +// turn this warning back on +#pragma warning(default:4805) +#endif + +#endif // DLIB_VECTOr_H_ + diff --git a/ml/dlib/dlib/geometry/vector_abstract.h b/ml/dlib/dlib/geometry/vector_abstract.h new file mode 100644 index 000000000..4aee8e32d --- /dev/null +++ b/ml/dlib/dlib/geometry/vector_abstract.h @@ -0,0 +1,489 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_VECTOR_ABSTRACT_ +#ifdef DLIB_VECTOR_ABSTRACT_ + +#include "../serialize.h" +#include +#include +#include "../matrix/matrix_abstract.h" + +namespace dlib +{ + template < + typename T, + long NR = 3 + > + class vector : public matrix + { + /*! + REQUIREMENTS ON T + T should be some object that provides an interface that is + compatible with double, float, int, long and the like. + + REQUIREMENTS ON NR + NR == 3 || NR == 2 + + INITIAL VALUE + x() == 0 + y() == 0 + z() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a three dimensional vector. If NR == 2 then + this object is limited to representing points on the XY plane where + Z is set to 0. + + Also note that this object performs the appropriate integer and + floating point conversions and promotions when vectors of mixed + type are used together. For example: + vector vi; + vector vd; + vd + vi == a vector object type since that is what + is needed to contain the result of vi+vd without + any loss of information. + !*/ + + public: + + typedef T type; + + vector ( + ); + /*! + ensures + - #*this has been properly initialized + !*/ + + vector ( + const T _x, + const T _y, + const T _z + ); + /*! + requires + - NR == 3 + ensures + - #x() == _x + - #y() == _y + - #z() == _z + !*/ + + vector ( + const T _x, + const T _y + ); + /*! + requires + - NR == 2 + ensures + - #x() == _x + - #y() == _y + - #z() == 0 + !*/ + + template + vector ( + const vector& v + ); + /*! + ensures + - Initializes *this with the contents of v and does any rounding if necessary and also + takes care of converting between 2 and 3 dimensional vectors. + - if (U is a real valued type like float or double and T is an integral type like long) then + - if (NR == 3) then + - #x() == floor(v.x() + 0.5) + - #y() == floor(v.y() + 0.5) + - #z() == floor(v.z() + 0.5) + - else // NR == 2 + - #x() == floor(v.x() + 0.5) + - #y() == floor(v.y() + 0.5) + - #z() == 0 + - else + - if (NR == 3) then + - #x() == v.x() + - #y() == v.y() + - #z() == v.z() + - else // NR == 2 + - #x() == v.x() + - #y() == v.y() + - #z() == 0 + !*/ + + template + vector ( + const matrix_exp& m + ); + /*! + requires + - m.size() == NR + - m.nr() == 1 || m.nc() == 1 (i.e. m must be a row or column matrix) + ensures + - Initializes *this with the contents of m and does any rounding if necessary and also + takes care of converting between 2 and 3 dimensional vectors. + - if (m contains real valued values like float or double and T is an integral type like long) then + - #x() == floor(m(0) + 0.5) + - #y() == floor(m(1) + 0.5) + - if (NR == 3) then + - #z() == floor(m(2) + 0.5) + - else + - #z() == 0 + - else + - #x() == m(0) + - #y() == m(1) + - if (NR == 3) then + - #z() == m(2) + - else + - #z() == 0 + !*/ + + ~vector ( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + + double length( + ) const; + /*! + ensures + - returns the length of the vector + !*/ + + double length_squared( + ) const; + /*! + ensures + - returns length()*length() + !*/ + + T& x ( + ); + /*! + ensures + - returns a reference to the x component of the vector + !*/ + + T& y ( + ); + /*! + ensures + - returns a reference to the y component of the vector + !*/ + + T& z ( + ); + /*! + requires + - NR == 3 (this function actually doesn't exist when NR != 3) + ensures + - returns a reference to the z component of the vector + !*/ + + const T& x ( + ) const; + /*! + ensures + - returns a const reference to the x component of the vector + !*/ + + const T& y ( + ) const; + /*! + ensures + - returns a const reference to the y component of the vector + !*/ + + const T& z ( + ) const; + /*! + ensures + - if (NR == 3) then + - returns a const reference to the z component of the vector + - else + - return 0 + (there isn't really a z in this case so we just return 0) + !*/ + + T dot ( + const vector& rhs + ) const; + /*! + ensures + - returns the result of the dot product between *this and rhs + !*/ + + vector cross ( + const vector& rhs + ) const; + /*! + ensures + - returns the result of the cross product between *this and rhs + !*/ + + vector normalize ( + ) const; + /*! + ensures + - returns a vector with length() == 1 and in the same direction as *this + !*/ + + vector operator+ ( + const vector& rhs + ) const; + /*! + ensures + - returns the result of adding *this to rhs + !*/ + + vector operator- ( + const vector& rhs + ) const; + /*! + ensures + - returns the result of subtracting rhs from *this + !*/ + + vector operator- ( + ) const; + /*! + ensures + - returns -1*(*this) + !*/ + + vector operator/ ( + const T rhs + ) const; + /*! + ensures + - returns the result of dividing *this by rhs + !*/ + + vector& operator= ( + const vector& rhs + ); + /*! + ensures + - #x() == rhs.x() + - #y() == rhs.y() + - #z() == rhs.z() + - returns #*this + !*/ + + vector& operator += ( + const vector& rhs + ); + /*! + ensures + - #*this == *this + rhs + - returns #*this + !*/ + + vector& operator -= ( + const vector& rhs + ); + /*! + ensures + - #*this == *this - rhs + - returns #*this + !*/ + + vector& operator *= ( + const T rhs + ); + /*! + ensures + - #*this == *this * rhs + - returns #*this + !*/ + + vector& operator /= ( + const T rhs + ); + /*! + ensures + - #*this == *this / rhs + - returns #*this + !*/ + + template + bool operator== ( + const vector& rhs + ) const; + /*! + ensures + - if (x() == rhs.x() && y() == rhs.y() && z() == rhs.z()) then + - returns true + - else + - returns false + !*/ + + template + bool operator!= ( + const vector& rhs + ) const; + /*! + ensures + - returns !((*this) == rhs) + !*/ + + void swap ( + vector& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + vector operator* ( + const vector & lhs, + const U rhs + ); + /*! + ensures + - returns the result of multiplying the scalar rhs by lhs + !*/ + + template + vector operator* ( + const U lhs, + const vector & rhs + ); + /*! + ensures + - returns the result of multiplying the scalar lhs by rhs + !*/ + + template + inline void swap ( + vector & a, + vector & b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template + void serialize ( + const vector& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template + void deserialize ( + vector& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + template + std::ostream& operator<< ( + std::ostream& out, + const vector& item + ); + /*! + ensures + - writes item to out in the form "(x, y, z)" + !*/ + + template + std::istream& operator>>( + std::istream& in, + vector& item + ); + /*! + ensures + - reads a vector from the input stream in and stores it in #item. + The data in the input stream should be of the form (x, y, z) + !*/ + + template + std::ostream& operator<< ( + std::ostream& out, + const vector& item + ); + /*! + ensures + - writes item to out in the form "(x, y)" + !*/ + + template + std::istream& operator>>( + std::istream& in, + vector& item + ); + /*! + ensures + - reads a vector from the input stream in and stores it in #item. + The data in the input stream should be of the form (x, y) + !*/ + +// ---------------------------------------------------------------------------------------- + + /*!A point + This is just a typedef of the vector object. + !*/ + + typedef vector point; + + /*!A dpoint + This is just a typedef of the vector object. + !*/ + + typedef vector dpoint; + +// ---------------------------------------------------------------------------------------- + +} + +namespace std +{ + /*! + Define std::less > so that you can use vectors in the associative containers. + !*/ + template + struct less > : public binary_function ,dlib::vector ,bool> + { + inline bool operator() (const dlib::vector & a, const dlib::vector & b) const + { + if (a.x() < b.x()) return true; + else if (a.x() > b.x()) return false; + else if (a.y() < b.y()) return true; + else if (a.y() > b.y()) return false; + else if (a.z() < b.z()) return true; + else if (a.z() > b.z()) return false; + else return false; + } + }; + + /*! + Define std::less > so that you can use vectors in the associative containers. + !*/ + template + struct less > : public binary_function ,dlib::vector ,bool> + { + inline bool operator() (const dlib::vector & a, const dlib::vector & b) const + { + if (a.x() < b.x()) return true; + else if (a.x() > b.x()) return false; + else if (a.y() < b.y()) return true; + else if (a.y() > b.y()) return false; + else return false; + } + }; +} + +#endif // DLIB_VECTOR_ABSTRACT_ + diff --git a/ml/dlib/dlib/global_optimization.h b/ml/dlib/dlib/global_optimization.h new file mode 100644 index 000000000..26b40fcd9 --- /dev/null +++ b/ml/dlib/dlib/global_optimization.h @@ -0,0 +1,14 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GLOBAL_OPTIMIZATIOn_HEADER +#define DLIB_GLOBAL_OPTIMIZATIOn_HEADER + +#include "global_optimization/upper_bound_function.h" +#include "global_optimization/global_function_search.h" +#include "global_optimization/find_max_global.h" + +#endif // DLIB_GLOBAL_OPTIMIZATIOn_HEADER + + + + diff --git a/ml/dlib/dlib/global_optimization/find_max_global.h b/ml/dlib/dlib/global_optimization/find_max_global.h new file mode 100644 index 000000000..5356129f5 --- /dev/null +++ b/ml/dlib/dlib/global_optimization/find_max_global.h @@ -0,0 +1,511 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FiND_GLOBAL_MAXIMUM_hH_ +#define DLIB_FiND_GLOBAL_MAXIMUM_hH_ + +#include "find_max_global_abstract.h" +#include "global_function_search.h" +#include "../metaprogramming.h" +#include +#include + +namespace dlib +{ + namespace gopt_impl + { + + // ---------------------------------------------------------------------------------------- + + class disable_decay_to_scalar + { + const matrix& a; + public: + disable_decay_to_scalar(const matrix& a) : a(a){} + operator const matrix&() const { return a;} + }; + + + template + auto _cwv ( + T&& f, + const matrix& a, + compile_time_integer_list + ) -> decltype(f(a(indices-1)...)) + { + DLIB_CASSERT(a.size() == sizeof...(indices), + "You invoked dlib::call_function_and_expand_args(f,a) but the number of arguments expected by f() doesn't match the size of 'a'. " + << "Expected " << sizeof...(indices) << " arguments but got " << a.size() << "." + ); + return f(a(indices-1)...); + } + + // Visual studio, as of November 2017, doesn't support C++11 and can't compile this code. + // So we write the terrible garbage in the #else for visual studio. When Visual Studio supports C++11 I'll update this #ifdef to use the C++11 code. +#ifndef _MSC_VER + template + struct call_function_and_expand_args + { + template + static auto go(T&& f, const matrix& a) -> decltype(_cwv(std::forward(f),a,typename make_compile_time_integer_range::type())) + { + return _cwv(std::forward(f),a,typename make_compile_time_integer_range::type()); + } + + template + static auto go(T&& f, const matrix& a) -> decltype(call_function_and_expand_args::template go(std::forward(f),a)) + { + return call_function_and_expand_args::go(std::forward(f),a); + } + }; + + template <> + struct call_function_and_expand_args<0> + { + template + static auto go(T&& f, const matrix& a) -> decltype(f(disable_decay_to_scalar(a))) + { + return f(disable_decay_to_scalar(a)); + } + }; +#else + template + struct call_function_and_expand_args + { +template static auto go(T&& f, const matrix& a) -> decltype(f(disable_decay_to_scalar(a))) {return f(disable_decay_to_scalar(a)); } +template static auto go(T&& f, const matrix& a) -> decltype(f(a(0))) { DLIB_CASSERT(a.size() == 1); return f(a(0)); } +template static auto go(T&& f, const matrix& a) -> decltype(f(a(0),a(1))) { DLIB_CASSERT(a.size() == 2); return f(a(0),a(1)); } +template static auto go(T&& f, const matrix& a) -> decltype(f(a(0), a(1), a(2))) { DLIB_CASSERT(a.size() == 3); return f(a(0), a(1),a(2)); } +template static auto go(T&& f, const matrix& a) -> decltype(f(a(0), a(1), a(2), a(3))) { DLIB_CASSERT(a.size() == 4); return f(a(0), a(1), a(2), a(3)); } +template static auto go(T&& f, const matrix& a) -> decltype(f(a(0), a(1), a(2), a(3), a(4))) { DLIB_CASSERT(a.size() == 5); return f(a(0), a(1), a(2), a(3), a(4)); } +template static auto go(T&& f, const matrix& a) -> decltype(f(a(0), a(1), a(2), a(3), a(4), a(5))) { DLIB_CASSERT(a.size() == 6); return f(a(0), a(1), a(2), a(3), a(4), a(5)); } +template static auto go(T&& f, const matrix& a) -> decltype(f(a(0), a(1), a(2), a(3), a(4), a(5), a(6))) { DLIB_CASSERT(a.size() == 7); return f(a(0), a(1), a(2), a(3), a(4), a(5), a(6)); } + }; +#endif + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + auto call_function_and_expand_args( + T&& f, + const matrix& a + ) -> decltype(gopt_impl::call_function_and_expand_args<40>::go(f,a)) + { + // unpack up to 40 parameters when calling f() + return gopt_impl::call_function_and_expand_args<40>::go(std::forward(f),a); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct max_function_calls + { + max_function_calls() = default; + explicit max_function_calls(size_t max_calls) : max_calls(max_calls) {} + size_t max_calls = std::numeric_limits::max(); + }; + +// ---------------------------------------------------------------------------------------- + + const auto FOREVER = std::chrono::hours(24*356*290); // 290 years + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename funct + > + std::pair find_max_global ( + std::vector& functions, + std::vector specs, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon, + double ymult + ) + { + // Decide which parameters should be searched on a log scale. Basically, it's + // common for machine learning models to have parameters that should be searched on + // a log scale (e.g. SVM C). These parameters are usually identifiable because + // they have bounds like [1e-5 1e10], that is, they span a very large range of + // magnitudes from really small to really big. So there we are going to check for + // that and if we find parameters with that kind of bound constraints we will + // transform them to a log scale automatically. + std::vector> log_scale(specs.size()); + for (size_t i = 0; i < specs.size(); ++i) + { + for (long j = 0; j < specs[i].lower.size(); ++j) + { + if (!specs[i].is_integer_variable[j] && specs[i].lower(j) > 0 && specs[i].upper(j)/specs[i].lower(j) >= 1000) + { + log_scale[i].push_back(true); + specs[i].lower(j) = std::log(specs[i].lower(j)); + specs[i].upper(j) = std::log(specs[i].upper(j)); + } + else + { + log_scale[i].push_back(false); + } + } + } + + global_function_search opt(specs); + opt.set_solver_epsilon(solver_epsilon); + + const auto time_to_stop = std::chrono::steady_clock::now() + max_runtime; + + // Now run the main solver loop. + for (size_t i = 0; i < num.max_calls && std::chrono::steady_clock::now() < time_to_stop; ++i) + { + auto next = opt.get_next_x(); + matrix x = next.x(); + // Undo any log-scaling that was applied to the variables before we pass them + // to the functions being optimized. + for (long j = 0; j < x.size(); ++j) + { + if (log_scale[next.function_idx()][j]) + x(j) = std::exp(x(j)); + } + double y = ymult*call_function_and_expand_args(functions[next.function_idx()], x); + next.set(y); + } + + + matrix x; + double y; + size_t function_idx; + opt.get_best_function_eval(x,y,function_idx); + // Undo any log-scaling that was applied to the variables before we output them. + for (long j = 0; j < x.size(); ++j) + { + if (log_scale[function_idx][j]) + x(j) = std::exp(x(j)); + } + return std::make_pair(function_idx, function_evaluation(x,y/ymult)); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + std::pair find_max_global ( + std::vector& functions, + std::vector specs, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return impl::find_max_global(functions, std::move(specs), num, max_runtime, solver_epsilon, +1); + } + + template < + typename funct + > + std::pair find_min_global ( + std::vector& functions, + std::vector specs, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return impl::find_max_global(functions, std::move(specs), num, max_runtime, solver_epsilon, -1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + std::vector functions(1,std::move(f)); + std::vector specs(1, function_spec(bound1, bound2, is_integer_variable)); + return find_max_global(functions, std::move(specs), num, max_runtime, solver_epsilon).second; + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + std::vector functions(1,std::move(f)); + std::vector specs(1, function_spec(bound1, bound2, is_integer_variable)); + return find_min_global(functions, std::move(specs), num, max_runtime, solver_epsilon).second; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const max_function_calls num, + double solver_epsilon + ) + { + return find_max_global(std::move(f), bound1, bound2, is_integer_variable, num, FOREVER, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const max_function_calls num, + double solver_epsilon + ) + { + return find_min_global(std::move(f), bound1, bound2, is_integer_variable, num, FOREVER, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const max_function_calls num, + double solver_epsilon + ) + { + return find_max_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, FOREVER, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const max_function_calls num, + double solver_epsilon + ) + { + return find_min_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, FOREVER, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const double bound1, + const double bound2, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), matrix({bound1}), matrix({bound2}), num, max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const double bound1, + const double bound2, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), matrix({bound1}), matrix({bound2}), num, max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const double bound1, + const double bound2, + const max_function_calls num, + double solver_epsilon + ) + { + return find_max_global(std::move(f), matrix({bound1}), matrix({bound2}), num, FOREVER, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const double bound1, + const double bound2, + const max_function_calls num, + double solver_epsilon + ) + { + return find_min_global(std::move(f), matrix({bound1}), matrix({bound2}), num, FOREVER, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const double bound1, + const double bound2, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const double bound1, + const double bound2, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), bound1, bound2, is_integer_variable, max_function_calls(), max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), bound1, bound2, is_integer_variable, max_function_calls(), max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FiND_GLOBAL_MAXIMUM_hH_ + diff --git a/ml/dlib/dlib/global_optimization/find_max_global_abstract.h b/ml/dlib/dlib/global_optimization/find_max_global_abstract.h new file mode 100644 index 000000000..4be62b154 --- /dev/null +++ b/ml/dlib/dlib/global_optimization/find_max_global_abstract.h @@ -0,0 +1,496 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FiND_GLOBAL_MAXIMUM_ABSTRACT_hH_ +#ifdef DLIB_FiND_GLOBAL_MAXIMUM_ABSTRACT_hH_ + +#include "upper_bound_function_abstract.h" +#include "global_function_search_abstract.h" +#include "../metaprogramming.h" +#include "../matrix.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + auto call_function_and_expand_args( + T&& f, + const matrix& args + ) -> decltype(f(args or args expanded out as discussed below)); + /*! + requires + - f is a function object with one of the following signatures: + auto f(matrix) + auto f(double) + auto f(double,double) + auto f(double,double,double) + ... + auto f(double,double,...,double) // up to 40 double arguments + - if (f() explicitly expands its arguments) then + - args.size() == the number of arguments taken by f. + ensures + - This function invokes f() with the given arguments and returns the result. + However, the signature of f() is allowed to vary. In particular, if f() + takes a matrix as a single argument then this function simply + calls f(args). However, if f() takes double arguments then args is expanded + appropriately, i.e. it calls one of the following as appropriate: + f(args(0)) + f(args(0),args(1)) + ... + f(args(0),args(1),...,args(N)) + and the result of f() is returned. + !*/ + +// ---------------------------------------------------------------------------------------- + + struct max_function_calls + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple typed integer class used to strongly type the "max number + of function calls" argument to find_max_global() and find_min_global(). + + !*/ + + max_function_calls() = default; + + explicit max_function_calls(size_t max_calls) : max_calls(max_calls) {} + + size_t max_calls = std::numeric_limits::max(); + }; + +// ---------------------------------------------------------------------------------------- + + const auto FOREVER = std::chrono::hours(24*356*290); // 290 years, basically forever + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + std::pair find_max_global ( + std::vector& functions, + const std::vector& specs, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ); + /*! + requires + - functions.size() != 0 + - functions.size() == specs.size() + - solver_epsilon >= 0 + - for all valid i: + - functions[i] is a real valued multi-variate function object. Moreover, + it must be callable via an expression of the form: + call_function_and_expand_args(functions[i], specs.lower). This means + function[i] should have a signature like one of the following: + double f(matrix) + double f(double) + double f(double,double) + etc. + - The range of inputs defined by specs[i] must be valid inputs to + functions[i]. + ensures + - This function performs global optimization on the set of given functions. + The goal is to maximize the following objective function: + max_{i,x_i}: functions[i](x_i) + subject to the constraints on x_i defined by specs[i]. + Once found, the return value of find_max_global() is: + make_pair(i, function_evaluation(x_i,functions[i](x_i))). + That is, we search for the settings of i and x that return the largest output + and return those settings. + - The search is performed using the global_function_search object. See its + documentation for details of the algorithm. + - We set the global_function_search::get_solver_epsilon() parameter to + solver_epsilon. Therefore, the search will only attempt to find a global + maximizer to at most solver_epsilon accuracy. Once a local maximizer is + found to that accuracy the search will focus entirely on finding other maxima + elsewhere rather than on further improving the current local optima found so + far. That is, once a local maxima is identified to about solver_epsilon + accuracy, the algorithm will spend all its time exploring the functions to + find other local maxima to investigate. An epsilon of 0 means it will keep + solving until it reaches full floating point precision. Larger values will + cause it to switch to pure global exploration sooner and therefore might be + more effective if your objective function has many local maxima and you don't + care about a super high precision solution. + - find_max_global() runs until one of the following is true: + - The total number of calls to the provided functions is == num.max_calls + - More than max_runtime time has elapsed since the start of this function. + - Any variables that satisfy the following conditions are optimized on a log-scale: + - The lower bound on the variable is > 0 + - The ratio of the upper bound to lower bound is >= 1000 + - The variable is not an integer variable + We do this because it's common to optimize machine learning models that have + parameters with bounds in a range such as [1e-5 to 1e10] (e.g. the SVM C + parameter) and it's much more appropriate to optimize these kinds of + variables on a log scale. So we transform them by applying std::log() to + them and then undo the transform via std::exp() before invoking the function + being optimized. Therefore, this transformation is invisible to the user + supplied functions. In most cases, it improves the efficiency of the + optimizer. + !*/ + + template < + typename funct + > + std::pair find_min_global ( + std::vector& functions, + const std::vector& specs, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ); + /*! + This function is identical to the find_max_global() defined immediately above, + except that we perform minimization rather than maximization. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ); + /*! + requires + - bound1.size() == bound2.size() == is_integer_variable.size() + - for all valid i: bound1(i) != bound2(i) + - solver_epsilon >= 0 + - f() is a real valued multi-variate function object. Moreover, it must be + callable via an expression of the form: call_function_and_expand_args(f, + bound1). This means f() should have a signature like one of the following: + double f(matrix) + double f(double) + double f(double,double) + etc. + - The range of inputs defined by function_spec(bound1,bound2,is_integer_variable) + must be valid inputs to f(). + ensures + - This function performs global optimization on the given f() function. + The goal is to maximize the following objective function: + f(x) + subject to the constraints on x defined by function_spec(bound1,bound2,is_integer_variable). + Once found, the return value of find_max_global() is: + function_evaluation(x,f(x))). + That is, we search for the setting of x that returns the largest output and + return that setting. + - The search is performed using the global_function_search object. See its + documentation for details of the algorithm. + - We set the global_function_search::get_solver_epsilon() parameter to + solver_epsilon. Therefore, the search will only attempt to find a global + maximizer to at most solver_epsilon accuracy. Once a local maximizer is + found to that accuracy the search will focus entirely on finding other maxima + elsewhere rather than on further improving the current local optima found so + far. That is, once a local maxima is identified to about solver_epsilon + accuracy, the algorithm will spend all its time exploring the function to + find other local maxima to investigate. An epsilon of 0 means it will keep + solving until it reaches full floating point precision. Larger values will + cause it to switch to pure global exploration sooner and therefore might be + more effective if your objective function has many local maxima and you don't + care about a super high precision solution. + - find_max_global() runs until one of the following is true: + - The total number of calls to f() is == num.max_calls + - More than max_runtime time has elapsed since the start of this function. + - Any variables that satisfy the following conditions are optimized on a log-scale: + - The lower bound on the variable is > 0 + - The ratio of the upper bound to lower bound is >= 1000 + - The variable is not an integer variable + We do this because it's common to optimize machine learning models that have + parameters with bounds in a range such as [1e-5 to 1e10] (e.g. the SVM C + parameter) and it's much more appropriate to optimize these kinds of + variables on a log scale. So we transform them by applying std::log() to + them and then undo the transform via std::exp() before invoking the function + being optimized. Therefore, this transformation is invisible to the user + supplied functions. In most cases, it improves the efficiency of the + optimizer. + !*/ + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ); + /*! + This function is identical to the find_max_global() defined immediately above, + except that we perform minimization rather than maximization. + !*/ + +// ---------------------------------------------------------------------------------------- +// The following functions are just convenient overloads for calling the above defined +// find_max_global() and find_min_global() routines. +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const max_function_calls num, + double solver_epsilon + ) + { + return find_max_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, FOREVER, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const max_function_calls num, + double solver_epsilon + ) + { + return find_min_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, FOREVER, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const max_function_calls num, + double solver_epsilon + ) + { + return find_max_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, FOREVER, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const max_function_calls num, + double solver_epsilon + ) + { + return find_min_global(std::move(f), bound1, bound2, std::vector(bound1.size(),false), num, FOREVER, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const double bound1, + const double bound2, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), matrix({bound1}), matrix({bound2}), num, max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const double bound1, + const double bound2, + const max_function_calls num, + const std::chrono::nanoseconds max_runtime = FOREVER, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), matrix({bound1}), matrix({bound2}), num, max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const double bound1, + const double bound2, + const max_function_calls num, + double solver_epsilon + ) + { + return find_max_global(std::move(f), matrix({bound1}), matrix({bound2}), num, FOREVER, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const double bound1, + const double bound2, + const max_function_calls num, + double solver_epsilon + ) + { + return find_min_global(std::move(f), matrix({bound1}), matrix({bound2}), num, FOREVER, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const double bound1, + const double bound2, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const double bound1, + const double bound2, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), bound1, bound2, max_function_calls(), max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + function_evaluation find_max_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_max_global(std::move(f), bound1, bound2, is_integer_variable, max_function_calls(), max_runtime, solver_epsilon); + } + + template < + typename funct + > + function_evaluation find_min_global ( + funct f, + const matrix& bound1, + const matrix& bound2, + const std::vector& is_integer_variable, + const std::chrono::nanoseconds max_runtime, + double solver_epsilon = 0 + ) + { + return find_min_global(std::move(f), bound1, bound2, is_integer_variable, max_function_calls(), max_runtime, solver_epsilon); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FiND_GLOBAL_MAXIMUM_ABSTRACT_hH_ + + diff --git a/ml/dlib/dlib/global_optimization/global_function_search.cpp b/ml/dlib/dlib/global_optimization/global_function_search.cpp new file mode 100644 index 000000000..fada289a4 --- /dev/null +++ b/ml/dlib/dlib/global_optimization/global_function_search.cpp @@ -0,0 +1,942 @@ + +#include "global_function_search.h" +#include "upper_bound_function.h" +#include "../optimization.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace qopt_impl + { + void fit_quadratic_to_points_mse( + const matrix& X, + const matrix& Y, + matrix& H, + matrix& g, + double& c + ) + { + DLIB_CASSERT(X.size() > 0); + DLIB_CASSERT(X.nc() == Y.size()); + DLIB_CASSERT(X.nc() >= (X.nr()+1)*(X.nr()+2)/2); + + const long dims = X.nr(); + const long M = X.nc(); + + matrix W((X.nr()+1)*(X.nr()+2)/2, M); + + set_subm(W, 0,0, dims, M) = X; + set_subm(W, dims,0, 1, M) = 1; + for (long c = 0; c < X.nc(); ++c) + { + long wr = dims+1; + for (long r = 0; r < X.nr(); ++r) + { + for (long r2 = r; r2 < X.nr(); ++r2) + { + W(wr,c) = X(r,c)*X(r2,c); + if (r2 == r) + W(wr,c) *= 0.5; + ++wr; + } + } + } + + matrix z = pinv(trans(W))*Y; + + c = z(dims); + g = rowm(z, range(0,dims-1)); + + H.set_size(dims,dims); + + long wr = dims+1; + for (long r = 0; r < X.nr(); ++r) + { + for (long r2 = r; r2 < X.nr(); ++r2) + { + H(r,r2) = H(r2,r) = z(wr++); + } + } + } + + // ---------------------------------------------------------------------------------------- + + void fit_quadratic_to_points( + const matrix& X, + const matrix& Y, + matrix& H, + matrix& g, + double& c + ) + /*! + requires + - X.size() > 0 + - X.nc() == Y.size() + - X.nr()+1 <= X.nc() + ensures + - This function finds a quadratic function, Q(x), that interpolates the + given set of points. If there aren't enough points to uniquely define + Q(x) then the Q(x) that fits the given points with the minimum Frobenius + norm hessian matrix is selected. + - To be precise: + - Let: Q(x) == 0.5*trans(x)*H*x + trans(x)*g + c + - Then this function finds H, g, and c that minimizes the following: + sum(squared(H)) + such that: + Q(colm(X,i)) == Y(i), for all valid i + - If there are more points than necessary to constrain Q then the Q + that best interpolates the function in the mean squared sense is + found. + !*/ + { + DLIB_CASSERT(X.size() > 0); + DLIB_CASSERT(X.nc() == Y.size()); + DLIB_CASSERT(X.nr()+1 <= X.nc()); + + + if (X.nc() >= (X.nr()+1)*(X.nr()+2)/2) + { + fit_quadratic_to_points_mse(X,Y,H,g,c); + return; + } + + + const long dims = X.nr(); + const long M = X.nc(); + + /* + Our implementation uses the equations 3.9 - 3.12 from the paper: + The NEWUOA software for unconstrained optimization without derivatives + By M.J.D. Powell, 40th Workshop on Large Scale Nonlinear Optimization (Erice, Italy, 2004) + */ + + matrix W(M + dims + 1, M + dims + 1); + + set_subm(W, 0, 0, M, M) = 0.5*squared(tmp(trans(X)*X)); + set_subm(W, 0, M, M, 1) = 1; + set_subm(W, M, 0, 1, M) = 1; + set_subm(W, M, M, dims+1, dims+1) = 0; + set_subm(W, 0, M+1, X.nc(), X.nr()) = trans(X); + set_subm(W, M+1, 0, X.nr(), X.nc()) = X; + + + const matrix r = join_cols(Y, zeros_matrix(dims+1,1)); + + //matrix z = pinv(W)*r; + lu_decomposition lu(W); + matrix z = lu.solve(r); + //if (lu.is_singular()) std::cout << "WARNING, THE W MATRIX IS SINGULAR!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" << std::endl; + + matrix lambda = rowm(z, range(0,M-1)); + + c = z(M); + g = rowm(z, range(M+1,z.size()-1)); + H = X*diagm(lambda)*trans(X); + } + + // ---------------------------------------------------------------------------------------- + + struct quad_interp_result + { + quad_interp_result() = default; + + template + quad_interp_result( + const matrix_exp& best_x, + double predicted_improvement + ) : best_x(best_x), predicted_improvement(predicted_improvement) {} + + matrix best_x; + double predicted_improvement = std::numeric_limits::quiet_NaN(); + }; + + // ---------------------------------------------------------------------------------------- + + quad_interp_result find_max_quadraticly_interpolated_vector ( + const matrix& anchor, + const double radius, + const std::vector>& x, + const std::vector& y, + const matrix& lower, + const matrix& upper + ) + { + DLIB_CASSERT(x.size() == y.size()); + DLIB_CASSERT(x.size() > 0); + for (size_t i = 0; i < x.size(); ++i) + DLIB_CASSERT(anchor.size() == x[i].size()); + + const long x_size = static_cast(x.size()); + DLIB_CASSERT(anchor.size()+1 <= x_size && x_size <= (anchor.size()+1)*(anchor.size()+2)/2); + + + matrix X(anchor.size(), x.size()); + matrix Y(x.size()); + for (size_t i = 0; i < x.size(); ++i) + { + set_colm(X,i) = x[i] - anchor; + Y(i) = y[i]; + } + + matrix H; + matrix g; + double c; + + fit_quadratic_to_points(X, Y, H, g, c); + + matrix p; + + solve_trust_region_subproblem_bounded(-H,-g, radius, p, 0.001, 500, lower-anchor, upper-anchor); + + // ensure we never move more than radius from the anchor. This might happen if the + // trust region subproblem isn't solved accurately enough. + if (length(p) >= radius) + p *= radius/length(p); + + + double predicted_improvement = 0.5*trans(p)*H*p + trans(p)*g; + return quad_interp_result{clamp(anchor+p,lower,upper), predicted_improvement}; + } + + // ---------------------------------------------------------------------------------------- + + quad_interp_result pick_next_sample_using_trust_region ( + const std::vector& samples, + double& radius, + const matrix& lower, + const matrix& upper, + const std::vector& is_integer_variable + ) + { + DLIB_CASSERT(samples.size() > 0); + // We don't use the QP to optimize integer variables. Instead, we just fix them at + // their best observed value and use the QP to optimize the real variables. So the + // number of dimensions, as far as the QP is concerned, is the number of non-integer + // variables. + size_t dims = 0; + for (auto is_int : is_integer_variable) + { + if (!is_int) + ++dims; + } + + DLIB_CASSERT(samples.size() >= dims+1); + + // Use enough points to fill out a quadratic model or the max available if we don't + // have quite enough. + const long N = std::min(samples.size(), (dims+1)*(dims+2)/2); + + + // first find the best sample; + double best_val = -1e300; + matrix best_x; + for (auto& v : samples) + { + if (v.y > best_val) + { + best_val = v.y; + best_x = v.x; + } + } + + // if there are only integer variables then there isn't really anything to do. So just + // return the best_x and say there is no improvement. + if (dims == 0) + return quad_interp_result(best_x, 0); + + matrix active_dims(dims); + long j = 0; + for (size_t i = 0; i < is_integer_variable.size(); ++i) + { + if (!is_integer_variable[i]) + active_dims(j++) = i; + } + + // now find the N-1 nearest neighbors of best_x + std::vector> distances; + for (size_t i = 0; i < samples.size(); ++i) + distances.emplace_back(length(best_x-samples[i].x), i); + std::sort(distances.begin(), distances.end()); + distances.resize(N); + + std::vector> x; + std::vector y; + for (auto& idx : distances) + { + x.emplace_back(rowm(samples[idx.second].x, active_dims)); + y.emplace_back(samples[idx.second].y); + } + + if (radius == 0) + { + for (auto& idx : distances) + radius = std::max(radius, length(rowm(best_x-samples[idx.second].x, active_dims)) ); + // Shrink the radius a little so we are always going to be making the sampling of + // points near the best current point smaller. + radius *= 0.95; + } + + + auto tmp = find_max_quadraticly_interpolated_vector(rowm(best_x,active_dims), radius, x, y, rowm(lower,active_dims), rowm(upper,active_dims)); + + // stick the integer variables back into the solution + for (long i = 0; i < active_dims.size(); ++i) + best_x(active_dims(i)) = tmp.best_x(i); + + tmp.best_x = best_x; + return tmp; + } + + // ---------------------------------------------------------------------------------------- + + matrix make_random_vector( + dlib::rand& rnd, + const matrix& lower, + const matrix& upper, + const std::vector& is_integer_variable + ) + { + matrix temp(lower.size()); + for (long i = 0; i < temp.size(); ++i) + { + temp(i) = rnd.get_double_in_range(lower(i), upper(i)); + if (is_integer_variable[i]) + temp(i) = std::round(temp(i)); + } + return temp; + } + + // ---------------------------------------------------------------------------------------- + + struct max_upper_bound_function + { + max_upper_bound_function() = default; + + template + max_upper_bound_function( + const matrix_exp& x, + double predicted_improvement, + double upper_bound + ) : x(x), predicted_improvement(predicted_improvement), upper_bound(upper_bound) {} + + matrix x; + double predicted_improvement = 0; + double upper_bound = 0; + }; + + // ------------------------------------------------------------------------------------ + + max_upper_bound_function pick_next_sample_as_max_upper_bound ( + dlib::rand& rnd, + const upper_bound_function& ub, + const matrix& lower, + const matrix& upper, + const std::vector& is_integer_variable, + const size_t num_random_samples + ) + { + DLIB_CASSERT(ub.num_points() > 0); + + + + // now do a simple random search to find the maximum upper bound + double best_ub_so_far = -std::numeric_limits::infinity(); + matrix vtemp(lower.size()), v; + for (size_t rounds = 0; rounds < num_random_samples; ++rounds) + { + vtemp = make_random_vector(rnd, lower, upper, is_integer_variable); + + double bound = ub(vtemp); + if (bound > best_ub_so_far) + { + best_ub_so_far = bound; + v = vtemp; + } + } + + double max_value = -std::numeric_limits::infinity(); + for (auto& v : ub.get_points()) + max_value = std::max(max_value, v.y); + + return max_upper_bound_function(v, best_ub_so_far - max_value, best_ub_so_far); + } + + } // end of namespace qopt_impl; + + using namespace qopt_impl; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + function_spec::function_spec( + matrix bound1, + matrix bound2 + ) : + lower(std::move(bound1)), upper(std::move(bound2)) + { + DLIB_CASSERT(lower.size() == upper.size()); + for (long i = 0; i < lower.size(); ++i) + { + if (upper(i) < lower(i)) + std::swap(lower(i), upper(i)); + DLIB_CASSERT(upper(i) != lower(i), "The upper and lower bounds can't be equal."); + } + is_integer_variable.assign(lower.size(), false); + } + +// ---------------------------------------------------------------------------------------- + + function_spec::function_spec( + matrix bound1, + matrix bound2, + std::vector is_integer + ) : + function_spec(std::move(bound1),std::move(bound2)) + { + is_integer_variable = std::move(is_integer); + DLIB_CASSERT(lower.size() == (long)is_integer_variable.size()); + + + // Make sure any integer variables have integer bounds. + for (size_t i = 0; i < is_integer_variable.size(); ++i) + { + if (is_integer_variable[i]) + { + DLIB_CASSERT(std::round(lower(i)) == lower(i), "If you say a variable is an integer variable then it must have an integer lower bound. \n" + << "lower[i] = " << lower(i)); + DLIB_CASSERT(std::round(upper(i)) == upper(i), "If you say a variable is an integer variable then it must have an integer upper bound. \n" + << "upper[i] = " << upper(i)); + } + } + } + +// ---------------------------------------------------------------------------------------- + + namespace gopt_impl + { + upper_bound_function funct_info::build_upper_bound_with_all_function_evals ( + ) const + { + upper_bound_function tmp(ub); + + // we are going to add the outstanding evals into this and assume the + // outstanding evals are going to take y values equal to their nearest + // neighbor complete evals. + for (auto& eval : outstanding_evals) + { + function_evaluation e; + e.x = eval.x; + e.y = find_nn(ub.get_points(), eval.x); + tmp.add(e); + } + + return tmp; + } + + // ------------------------------------------------------------------------------------ + + double funct_info::find_nn ( + const std::vector& evals, + const matrix& x + ) + { + double best_y = 0; + double best_dist = std::numeric_limits::infinity(); + for (auto& v : evals) + { + double dist = length_squared(v.x-x); + if (dist < best_dist) + { + best_dist = dist; + best_y = v.y; + } + } + return best_y; + } + + } // end namespace gopt_impl + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + function_evaluation_request::function_evaluation_request( + function_evaluation_request&& item + ) + { + m_has_been_evaluated = item.m_has_been_evaluated; + req = item.req; + info = item.info; + item.info.reset(); + + item.m_has_been_evaluated = true; + } + +// ---------------------------------------------------------------------------------------- + + function_evaluation_request& function_evaluation_request:: + operator=( + function_evaluation_request&& item + ) + { + function_evaluation_request(std::move(item)).swap(*this); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + void function_evaluation_request:: + swap( + function_evaluation_request& item + ) + { + std::swap(m_has_been_evaluated, item.m_has_been_evaluated); + std::swap(req, item.req); + std::swap(info, item.info); + } + +// ---------------------------------------------------------------------------------------- + + size_t function_evaluation_request:: + function_idx ( + ) const + { + return info->function_idx; + } + + const matrix& function_evaluation_request:: + x ( + ) const + { + return req.x; + } + +// ---------------------------------------------------------------------------------------- + + bool function_evaluation_request:: + has_been_evaluated ( + ) const + { + return m_has_been_evaluated; + } + +// ---------------------------------------------------------------------------------------- + + function_evaluation_request:: + ~function_evaluation_request() + { + if (!m_has_been_evaluated) + { + std::lock_guard lock(*info->m); + + // remove the evaluation request from the outstanding list. + auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req); + info->outstanding_evals.erase(i); + } + } + +// ---------------------------------------------------------------------------------------- + + void function_evaluation_request:: + set ( + double y + ) + { + DLIB_CASSERT(has_been_evaluated() == false); + std::lock_guard lock(*info->m); + + m_has_been_evaluated = true; + + + // move the evaluation from outstanding to complete + auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req); + DLIB_CASSERT(i != info->outstanding_evals.end()); + info->outstanding_evals.erase(i); + info->ub.add(function_evaluation(req.x,y)); + + + // Now do trust region radius maintenance and keep track of the best objective + // values and all that. + if (req.was_trust_region_generated_request) + { + // Adjust trust region radius based on how good this evaluation + // was. + double measured_improvement = y-req.anchor_objective_value; + double rho = measured_improvement/std::abs(req.predicted_improvement); + //std::cout << "rho: "<< rho << std::endl; + //std::cout << "radius: "<< info->radius << std::endl; + if (rho < 0.25) + info->radius *= 0.5; + else if (rho > 0.75) + info->radius *= 2; + } + + if (y > info->best_objective_value) + { + if (!req.was_trust_region_generated_request && length(req.x - info->best_x) > info->radius*1.001) + { + //std::cout << "reset radius because of big move, " << length(req.x - info->best_x) << " radius was " << info->radius << std::endl; + // reset trust region radius since we made a big move. Doing this will + // cause the radius to be reset to the size of the local region. + info->radius = 0; + } + info->best_objective_value = y; + info->best_x = std::move(req.x); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + global_function_search:: + global_function_search( + const function_spec& function + ) : global_function_search(std::vector(1,function)) {} + +// ---------------------------------------------------------------------------------------- + + global_function_search:: + global_function_search( + const std::vector& functions_ + ) + { + DLIB_CASSERT(functions_.size() > 0); + m = std::make_shared(); + functions.reserve(functions_.size()); + for (size_t i = 0; i < functions_.size(); ++i) + functions.emplace_back(std::make_shared(functions_[i],i,m)); + } + +// ---------------------------------------------------------------------------------------- + + global_function_search:: + global_function_search( + const std::vector& functions_, + const std::vector>& initial_function_evals, + const double relative_noise_magnitude_ + ) : + global_function_search(functions_) + { + DLIB_CASSERT(functions_.size() > 0); + DLIB_CASSERT(functions_.size() == initial_function_evals.size()); + DLIB_CASSERT(relative_noise_magnitude >= 0); + relative_noise_magnitude = relative_noise_magnitude_; + for (size_t i = 0; i < initial_function_evals.size(); ++i) + { + functions[i]->ub = upper_bound_function(initial_function_evals[i], relative_noise_magnitude); + } + } + +// ---------------------------------------------------------------------------------------- + + size_t global_function_search:: + num_functions( + ) const + { + return functions.size(); + } + +// ---------------------------------------------------------------------------------------- + + void global_function_search:: + set_seed ( + time_t seed + ) + { + rnd = dlib::rand(seed); + } + +// ---------------------------------------------------------------------------------------- + + void global_function_search:: + get_function_evaluations ( + std::vector& specs, + std::vector>& function_evals + ) const + { + std::lock_guard lock(*m); + specs.clear(); + function_evals.clear(); + for (size_t i = 0; i < functions.size(); ++i) + { + specs.emplace_back(functions[i]->spec); + function_evals.emplace_back(functions[i]->ub.get_points()); + } + } + +// ---------------------------------------------------------------------------------------- + + void global_function_search:: + get_best_function_eval ( + matrix& x, + double& y, + size_t& function_idx + ) const + { + DLIB_CASSERT(num_functions() != 0); + + std::lock_guard lock(*m); + + // find the largest value + auto& info = *best_function(function_idx); + y = info.best_objective_value; + x = info.best_x; + } + +// ---------------------------------------------------------------------------------------- + + function_evaluation_request global_function_search:: + get_next_x ( + ) + { + DLIB_CASSERT(num_functions() != 0); + + using namespace gopt_impl; + + std::lock_guard lock(*m); + + + // the first thing we do is make sure each function has at least max(3,dimensionality of function) evaluations + for (auto& info : functions) + { + const long dims = info->spec.lower.size(); + // If this is the very beginning of the optimization process + if (info->ub.num_points()+info->outstanding_evals.size() < 1) + { + outstanding_function_eval_request new_req; + new_req.request_id = next_request_id++; + // Pick the point right in the center of the bounds to evaluate first since + // people will commonly center the bound on a location they think is good. + // So might as well try there first. + new_req.x = (info->spec.lower + info->spec.upper)/2.0; + for (long i = 0; i < new_req.x.size(); ++i) + { + if (info->spec.is_integer_variable[i]) + new_req.x(i) = std::round(new_req.x(i)); + } + info->outstanding_evals.emplace_back(new_req); + return function_evaluation_request(new_req,info); + } + else if (info->ub.num_points() < std::max(3,dims)) + { + outstanding_function_eval_request new_req; + new_req.request_id = next_request_id++; + new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable); + info->outstanding_evals.emplace_back(new_req); + return function_evaluation_request(new_req,info); + } + } + + + if (do_trust_region_step && !has_outstanding_trust_region_request()) + { + // find the currently best performing function, we will do a trust region + // step on it. + auto info = best_function(); + const long dims = info->spec.lower.size(); + // if we have enough points to do a trust region step + if (info->ub.num_points() > dims+1) + { + auto tmp = pick_next_sample_using_trust_region(info->ub.get_points(), + info->radius, info->spec.lower, info->spec.upper, info->spec.is_integer_variable); + //std::cout << "QP predicted improvement: "<< tmp.predicted_improvement << std::endl; + if (tmp.predicted_improvement > min_trust_region_epsilon) + { + do_trust_region_step = false; + outstanding_function_eval_request new_req; + new_req.request_id = next_request_id++; + new_req.x = tmp.best_x; + new_req.was_trust_region_generated_request = true; + new_req.anchor_objective_value = info->best_objective_value; + new_req.predicted_improvement = tmp.predicted_improvement; + info->outstanding_evals.emplace_back(new_req); + return function_evaluation_request(new_req, info); + } + } + } + + // make it so we alternate between upper bounded and trust region steps. + do_trust_region_step = true; + + if (rnd.get_random_double() >= pure_random_search_probability) + { + // pick a point at random to sample according to the upper bound + double best_upper_bound = -std::numeric_limits::infinity(); + std::shared_ptr best_funct; + matrix next_sample; + // so figure out if any function has a good upper bound and if so pick the + // function with the largest upper bound for evaluation. + for (auto& info : functions) + { + auto tmp = pick_next_sample_as_max_upper_bound(rnd, + info->build_upper_bound_with_all_function_evals(), info->spec.lower, info->spec.upper, + info->spec.is_integer_variable, num_random_samples); + if (tmp.predicted_improvement > 0 && tmp.upper_bound > best_upper_bound) + { + best_upper_bound = tmp.upper_bound; + next_sample = std::move(tmp.x); + best_funct = info; + } + } + + // if we found a good function to evaluate then return that. + if (best_funct) + { + outstanding_function_eval_request new_req; + new_req.request_id = next_request_id++; + new_req.x = std::move(next_sample); + best_funct->outstanding_evals.emplace_back(new_req); + return function_evaluation_request(new_req, best_funct); + } + } + + + // pick entirely at random + size_t function_idx = rnd.get_integer(functions.size()); + auto info = functions[function_idx]; + outstanding_function_eval_request new_req; + new_req.request_id = next_request_id++; + new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable); + info->outstanding_evals.emplace_back(new_req); + return function_evaluation_request(new_req, info); + + } + +// ---------------------------------------------------------------------------------------- + + double global_function_search:: + get_pure_random_search_probability ( + ) const + { + return pure_random_search_probability; + } + +// ---------------------------------------------------------------------------------------- + + void global_function_search:: + set_pure_random_search_probability ( + double prob + ) + { + DLIB_CASSERT(0 <= prob && prob <= 1); + pure_random_search_probability = prob; + } + +// ---------------------------------------------------------------------------------------- + + double global_function_search:: + get_solver_epsilon ( + ) const + { + return min_trust_region_epsilon; + } + +// ---------------------------------------------------------------------------------------- + + void global_function_search:: + set_solver_epsilon ( + double eps + ) + { + DLIB_CASSERT(0 <= eps); + min_trust_region_epsilon = eps; + } + +// ---------------------------------------------------------------------------------------- + + double global_function_search:: + get_relative_noise_magnitude ( + ) const + { + return relative_noise_magnitude; + } + +// ---------------------------------------------------------------------------------------- + + void global_function_search:: + set_relative_noise_magnitude ( + double value + ) + { + DLIB_CASSERT(0 <= value); + relative_noise_magnitude = value; + if (m) + { + std::lock_guard lock(*m); + // recreate all the upper bound functions with the new relative noise magnitude + for (auto& f : functions) + f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude); + } + } + +// ---------------------------------------------------------------------------------------- + + size_t global_function_search:: + get_monte_carlo_upper_bound_sample_num ( + ) const + { + return num_random_samples; + } + +// ---------------------------------------------------------------------------------------- + + void global_function_search:: + set_monte_carlo_upper_bound_sample_num ( + size_t num + ) + { + DLIB_CASSERT(0 <= num); + num_random_samples = num; + } + +// ---------------------------------------------------------------------------------------- + + std::shared_ptr global_function_search:: + best_function( + ) const + { + size_t idx = 0; + return best_function(idx); + } + +// ---------------------------------------------------------------------------------------- + + std::shared_ptr global_function_search:: + best_function( + size_t& idx + ) const + { + auto compare = [](const std::shared_ptr& a, const std::shared_ptr& b) + { return a->best_objective_value < b->best_objective_value; }; + + auto i = std::max_element(functions.begin(), functions.end(), compare); + + idx = std::distance(functions.begin(),i); + return *i; + } + +// ---------------------------------------------------------------------------------------- + + bool global_function_search:: + has_outstanding_trust_region_request ( + ) const + { + for (auto& f : functions) + { + for (auto& i : f->outstanding_evals) + { + if (i.was_trust_region_generated_request) + return true; + } + } + return false; + } + +// ---------------------------------------------------------------------------------------- + +} + diff --git a/ml/dlib/dlib/global_optimization/global_function_search.h b/ml/dlib/dlib/global_optimization/global_function_search.h new file mode 100644 index 000000000..fa036884a --- /dev/null +++ b/ml/dlib/dlib/global_optimization/global_function_search.h @@ -0,0 +1,245 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GLOBAL_FuNCTION_SEARCH_Hh_ +#define DLIB_GLOBAL_FuNCTION_SEARCH_Hh_ + +#include "global_function_search_abstract.h" +#include +#include "../matrix.h" +#include +#include "../rand.h" +#include "upper_bound_function.h" +#include "../test_for_odr_violations.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct function_spec + { + function_spec( + matrix bound1, + matrix bound2 + ); + + function_spec( + matrix bound1, + matrix bound2, + std::vector is_integer + ); + + matrix lower; + matrix upper; + std::vector is_integer_variable; + }; + +// ---------------------------------------------------------------------------------------- + + namespace gopt_impl + { + struct outstanding_function_eval_request + { + size_t request_id = 0; // unique id for this eval request + matrix x; // function x to evaluate + + // trust region specific stuff + bool was_trust_region_generated_request = false; + double predicted_improvement = std::numeric_limits::quiet_NaN(); + double anchor_objective_value = std::numeric_limits::quiet_NaN(); // objective value at center of TR step + + bool operator==(const outstanding_function_eval_request& item) const { return request_id == item.request_id; } + }; + + struct funct_info + { + funct_info() = delete; + funct_info(const funct_info&) = delete; + funct_info& operator=(const funct_info&) = delete; + + funct_info( + const function_spec& spec, + size_t function_idx, + const std::shared_ptr& m + ) : + spec(spec), function_idx(function_idx), m(m) + { + best_x = zeros_matrix(spec.lower); + } + + upper_bound_function build_upper_bound_with_all_function_evals ( + ) const; + + static double find_nn ( + const std::vector& evals, + const matrix& x + ); + + + function_spec spec; + size_t function_idx = 0; + std::shared_ptr m; + upper_bound_function ub; + std::vector outstanding_evals; + matrix best_x; + double best_objective_value = -std::numeric_limits::infinity(); + double radius = 0; + }; + + } + +// ---------------------------------------------------------------------------------------- + + class function_evaluation_request + { + public: + + function_evaluation_request() = delete; + function_evaluation_request(const function_evaluation_request&) = delete; + function_evaluation_request& operator=(const function_evaluation_request&) = delete; + + + function_evaluation_request(function_evaluation_request&& item); + function_evaluation_request& operator=(function_evaluation_request&& item); + + ~function_evaluation_request(); + + size_t function_idx ( + ) const; + + const matrix& x ( + ) const; + + bool has_been_evaluated ( + ) const; + + void set ( + double y + ); + + void swap(function_evaluation_request& item); + + private: + + friend class global_function_search; + + explicit function_evaluation_request( + const gopt_impl::outstanding_function_eval_request& req, + const std::shared_ptr& info + ) : req(req), info(info) {} + + bool m_has_been_evaluated = false; + gopt_impl::outstanding_function_eval_request req; + std::shared_ptr info; + }; + +// ---------------------------------------------------------------------------------------- + + class global_function_search + { + public: + + global_function_search() = default; + + explicit global_function_search( + const function_spec& function + ); + + explicit global_function_search( + const std::vector& functions_ + ); + + global_function_search( + const std::vector& functions_, + const std::vector>& initial_function_evals, + const double relative_noise_magnitude = 0.001 + ); + + global_function_search(const global_function_search&) = delete; + global_function_search& operator=(const global_function_search& item) = delete; + + global_function_search(global_function_search&& item) = default; + global_function_search& operator=(global_function_search&& item) = default; + + size_t num_functions( + ) const; + + void set_seed ( + time_t seed + ); + + void get_function_evaluations ( + std::vector& specs, + std::vector>& function_evals + ) const; + + void get_best_function_eval ( + matrix& x, + double& y, + size_t& function_idx + ) const; + + function_evaluation_request get_next_x ( + ); + + double get_pure_random_search_probability ( + ) const; + + void set_pure_random_search_probability ( + double prob + ); + + double get_solver_epsilon ( + ) const; + + void set_solver_epsilon ( + double eps + ); + + double get_relative_noise_magnitude ( + ) const; + + void set_relative_noise_magnitude ( + double value + ); + + size_t get_monte_carlo_upper_bound_sample_num ( + ) const; + + void set_monte_carlo_upper_bound_sample_num ( + size_t num + ); + + private: + + std::shared_ptr best_function( + ) const; + + std::shared_ptr best_function( + size_t& idx + ) const; + + bool has_outstanding_trust_region_request ( + ) const; + + + dlib::rand rnd; + double pure_random_search_probability = 0.02; + double min_trust_region_epsilon = 0; + double relative_noise_magnitude = 0.001; + size_t num_random_samples = 5000; + bool do_trust_region_step = true; + + size_t next_request_id = 1; + + std::vector> functions; + std::shared_ptr m; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GLOBAL_FuNCTION_SEARCH_Hh_ + diff --git a/ml/dlib/dlib/global_optimization/global_function_search_abstract.h b/ml/dlib/dlib/global_optimization/global_function_search_abstract.h new file mode 100644 index 000000000..c8bfc3993 --- /dev/null +++ b/ml/dlib/dlib/global_optimization/global_function_search_abstract.h @@ -0,0 +1,605 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GLOBAL_FuNCTION_SEARCH_ABSTRACT_Hh_ +#ifdef DLIB_GLOBAL_FuNCTION_SEARCH_ABSTRACT_Hh_ + +#include +#include "../matrix.h" +#include "upper_bound_function_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct function_spec + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple struct that lets you define the valid inputs to a + multivariate function. It lets you define bound constraints for each + variable as well as say if a variable is integer valued or not. Therefore, + an instance of this struct says that a function takes upper.size() input + variables, where the ith variable must be in the range [lower(i) upper(i)] + and be an integer if is_integer_variable[i]==true. + !*/ + + function_spec( + matrix bound1, + matrix bound2 + ); + /*! + requires + - bound1.size() == bound2.size() + - for all valid i: bound1(i) != bound2(i) + ensures + - #is_integer_variable.size() == bound1.size() + - #lower.size() == bound1.size() + - #upper.size() == bound1.size() + - for all valid i: + - #is_integer_variable[i] == false + - #lower(i) == min(bound1(i), bound2(i)) + - #upper(i) == max(bound1(i), bound2(i)) + !*/ + + function_spec( + matrix lower, + matrix upper, + std::vector is_integer + ); + /*! + requires + - bound1.size() == bound2.size() == is_integer.size() + - for all valid i: bound1(i) != bound2(i) + ensures + - #is_integer_variable.size() == bound1.size() + - #lower.size() == bound1.size() + - #upper.size() == bound1.size() + - for all valid i: + - #is_integer_variable[i] == is_integer[i] + - #lower(i) == min(bound1(i), bound2(i)) + - #upper(i) == max(bound1(i), bound2(i)) + !*/ + + matrix lower; + matrix upper; + std::vector is_integer_variable; + }; + +// ---------------------------------------------------------------------------------------- + + class function_evaluation_request + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a request, by the global_function_search object, to + evaluate a real-valued function and report back the results. + + THREAD SAFETY + You shouldn't let more than one thread touch a function_evaluation_request + at the same time. However, it is safe to send instances of this class to + other threads for processing. This lets you evaluate multiple + function_evaluation_requests in parallel. Any appropriate synchronization + with regard to the originating global_function_search instance is handled + automatically. + !*/ + + public: + + // You can't make or copy this object, the only way to get one is from the + // global_function_search class via get_next_x(). + function_evaluation_request() = delete; + function_evaluation_request(const function_evaluation_request&) = delete; + function_evaluation_request& operator=(const function_evaluation_request&) = delete; + + // You can however move and swap this object. + function_evaluation_request(function_evaluation_request&& item); + function_evaluation_request& operator=(function_evaluation_request&& item); + /*! + ensures + - *this takes the state of item. + - #item.has_been_evaluated() == true + !*/ + + ~function_evaluation_request( + ); + /*! + ensures + - frees all resources associated with this object. + - It's fine to destruct function_evaluation_requests even if they haven't + been evaluated yet. If this happens it will simply be as if the request + was never issued. + !*/ + + size_t function_idx ( + ) const; + /*! + ensures + - Returns the function index that identifies which function is to be + evaluated. + !*/ + + const matrix& x ( + ) const; + /*! + ensures + - returns the input parameters to the function to be evaluated. + !*/ + + bool has_been_evaluated ( + ) const; + /*! + ensures + - If this evaluation request is still outstanding then returns false, + otherwise returns true. That is, if the global_function_search is still + waiting for you report back by calling set() then + has_been_evaluated()==false. + !*/ + + void set ( + double y + ); + /*! + requires + - has_been_evaluated() == false + ensures + - #has_been_evaluated() == true + - Notifies the global_function_search instance that created this object + that when the function_idx()th function is evaluated with x() as input + then the output is y. + !*/ + + void swap( + function_evaluation_request& item + ); + /*! + ensures + - swaps the state of *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + class global_function_search + { + /*! + WHAT THIS OBJECT REPRESENTS + This object performs global optimization of a set of user supplied + functions. The goal is to maximize the following objective function: + max_{function_i,x_i}: function_i(x_i) + subject to bound constraints on each element of x_i. Moreover, each + element of x_i can be either real valued or integer valued. Each of the + functions can also take a different number of variables. Therefore, the + final result of the optimization tells you which function produced the + largest output and what input (i.e. the x value) to that function is + necessary to obtain that maximal value. + + Importantly, the global_function_search object does not require the user to + supply derivatives. Moreover, the functions may contain discontinuities, + behave stochastically, and have many local maxima. The global_function_search + object will attempt to find the global optima in the face of these challenges. + It is also designed to use as few function evaluations as possible, making + it suitable for optimizing functions that are very expensive to evaluate. + + It does this by alternating between two modes. A global exploration mode + and a local optima refinement mode. This is accomplished by building and + maintaining two models of the objective function: + 1. A global model that upper bounds our objective function. This is a + non-parametric piecewise linear model based on all function + evaluations ever seen by the global_function_search object. + 2. A local quadratic model fit around the best point seen so far. + + The optimization procedure therefore looks like this: + + while(not done) + { + DO GLOBAL EXPLORE STEP: + Find the point that maximizes the upper bounding model since + that is the point with the largest possible improvement in the + objective function. + + Evaluate the new point and incorporate it into our models. + + DO LOCAL REFINEMENT STEP: + Find the optimal solution to the local quadratic model. + + If this point looks like it will improve on the "best point seen + so far" by at least get_solver_epsilon() then we evaluate that + point and incorporate it into our models, otherwise we ignore + it. + } + + You can see that we alternate between global search and local refinement, + except in the case where the local model seems to have converged to within + get_solver_epsilon() accuracy. In that case only global search steps are + used. We do this in the hope that the global search will find a new and + better local optima to explore, which would then reactivate local + refinement when it has something productive to do. + + + Now let's turn our attention to the specific API defined by the + global_function_search object. We will begin by showing a short example of + its use: + + // Suppose we want to find which of these functions, F() and G(), have + // the largest output and what input is necessary to produce the + // maximal output. + auto F = [](double a, double b) { return -std::pow(a-2,2.0) - std::pow(b-4,2.0); }; + auto G = [](double x) { return 2-std::pow(x-5,2.0); }; + + // We first define function_spec objects that specify bounds on the + // inputs to each function. The search process will only search within + // these bounds. + function_spec spec_F({-10,-10}, {10,10}); + function_spec spec_G({-2}, {6}); + // Then we create a global_function_search object with those function specifications. + global_function_search opt({spec_F, spec_G}); + + // Here we run 15 iterations of the search process. Note that the user + // of global_function_search writes the main solver loop, which is + // somewhat unusual. We will discuss why that is in a moment, but for + // now let's look at this example. + for (int i = 0; i < 15; ++i) + { + // All we do here is ask the global_function_search object what to + // evaluate next, then do what it asked, and then report the + // results back by calling function_evaluation_request's set() + // method. + function_evaluation_request next = opt.get_next_x(); + // next.function_idx() tells you which of the functions you should + // evaluate. We have 2 functions here (F and G) so function_idx() + // can take only the values 0 and 1. If, for example, we had 10 + // functions it would take the values 0 through 9. + if (next.function_idx() == 0) + { + // Call F with the inputs requested by the + // global_function_search and report them back. + double a = next.x()(0); + double b = next.x()(1); + next.set(F(a,b)); // Tell the solver what happened. + } + else + { + double x = next.x()(0); + next.set(G(x)); + } + } + + // Find out what point gave the largest outputs: + matrix x; + double y; + size_t function_idx; + opt.get_best_function_eval(x,y,function_idx); + + cout << "function_idx: "<< function_idx << endl; + cout << "y: " << y << endl; + cout << "x: " << x << endl; + + The above cout statements will print this: + + function_idx: 1 + y: 2 + x: 5 + + Which is the correct result since G(5) gives the largest possible output in + our example. + + So why does the user write the main loop? Why isn't it embedded inside + dlib? Well, there are two answers to this. The first is that it is. Most + users should just call dlib::find_max_global() which does exactly that, it + runs the loop for you. However, the API shown above gives you the + opportunity to run multiple function evaluations in parallel. For + instance, it is perfectly valid to call get_next_x() multiple times and + send the resulting function_evaluation_request objects to separate threads + for processing. Those separate threads can run the functions being + optimized (e.g. F and G or whatever) and report back by calling + function_evaluation_request::set(). You could even spread the work across + a compute cluster if you have one. + + So what happens if you have N outstanding function evaluation requests? + Or in other words, what happens if you called get_next_x() N times and + haven't yet called their set() methods? Well, 1 of the N requests will be + a local refinement step while the N-1 other requests will be global + exploration steps generated from the current upper bounding model. This + should give you an idea of the usefulness of this kind of parallelism. If + for example, your functions being optimized were simple convex functions + this kind of parallelism wouldn't help since essentially all the + interesting work in the solver is going to be done by the local optimizer. + However, if your function has a lot of local optima, running many global + exploration steps in parallel might significantly reduce the time it takes + to find a good solution. + + It should also be noted that our upper bounding model is implemented by the + dlib::upper_bound_function object, which is a tool that allows us to create + a tight upper bound on our objective function. This upper bound is + non-parametric and gets progressively more accurate as the optimization + progresses, but also more and more expensive to maintain. It causes the + runtime of the entire optimization procedure to be O(N^2) where N is the + number of objective function evaluations. So problems that require millions + of function evaluations to find a good solution are not appropriate for the + global_function_search tool. However, if your objective function is very + expensive to evaluate then this relatively expensive upper bounding model + is well worth its computational cost. + + Finally, let's introduce some background literature on this algorithm. The + two most relevant papers in the optimization literature are: + Global optimization of Lipschitz functions Malherbe, Cédric and Vayatis, + Nicolas International Conference on Machine Learning - 2017 + and + The NEWUOA software for unconstrained optimization without derivatives By + M.J.D. Powell, 40th Workshop on Large Scale Nonlinear Optimization (Erice, + Italy, 2004) + + Our upper bounding model is an extension of the AdaLIPO method in the + Malherbe. See the documentation of dlib::upper_bound_function for more + details on that, as we make a number of important extensions. The other + part of our method, our local refinement model, is essentially the same + type of trust region model proposed by Powell in the above paper. That is, + each time we do a local refinement step we identify the best point seen so + far, fit a quadratic function around it using the function evaluations we + have collected so far, and then use a simple trust region procedure to + decide the next best point to evaluate based on our quadratic model. + + The method proposed by Malherbe gives excellent global search performance + but has terrible convergence properties in the area around a maxima. + Powell's method on the other hand has excellent convergence in the area + around a local maxima, as expected by a quadratic trust region method, but + is aggressively local maxima seeking. It will simply get stuck in the + nearest local optima. Combining the two together as we do here gives us + excellent performance in both global search and final convergence speed + near a local optima. Causing the global_function_search to perform well + for functions with many local optima while still giving high precision + solutions. For instance, on typical tests problems, like the Holder table + function, the global_function_search object can reliably find the globally + optimal solution to full floating point precision in under a few hundred + steps. + + + THREAD SAFETY + You shouldn't let more than one thread touch a global_function_search + instance at the same time. + !*/ + + public: + + global_function_search( + ); + /*! + ensures + - #num_functions() == 0 + - #get_relative_noise_magnitude() == 0.001 + - #get_solver_epsilon() == 0 + - #get_monte_carlo_upper_bound_sample_num() == 5000 + - #get_pure_random_search_probability() == 0.02 + !*/ + + explicit global_function_search( + const function_spec& function + ); + /*! + ensures + - #num_functions() == 1 + - #get_function_evaluations() will indicate that there are no function evaluations yet. + - #get_relative_noise_magnitude() == 0.001 + - #get_solver_epsilon() == 0 + - #get_monte_carlo_upper_bound_sample_num() == 5000 + - #get_pure_random_search_probability() == 0.02 + !*/ + + explicit global_function_search( + const std::vector& functions + ); + /*! + ensures + - #num_functions() == functions.size() + - #get_function_evaluations() will indicate that there are no function evaluations yet. + - #get_relative_noise_magnitude() == 0.001 + - #get_solver_epsilon() == 0 + - #get_monte_carlo_upper_bound_sample_num() == 5000 + - #get_pure_random_search_probability() == 0.02 + !*/ + + global_function_search( + const std::vector& functions, + const std::vector>& initial_function_evals, + const double relative_noise_magnitude = 0.001 + ); + /*! + requires + - functions.size() == initial_function_evals.size() + - relative_noise_magnitude >= 0 + ensures + - #num_functions() == functions.size() + - #get_function_evaluations() will return the provided initial_function_evals. + - #get_relative_noise_magnitude() == relative_noise_magnitude + - #get_solver_epsilon() == 0 + - #get_monte_carlo_upper_bound_sample_num() == 5000 + - #get_pure_random_search_probability() == 0.02 + !*/ + + // This object can't be copied. + global_function_search(const global_function_search&) = delete; + global_function_search& operator=(const global_function_search& item) = delete; + // But it can be moved + global_function_search(global_function_search&& item) = default; + global_function_search& operator=(global_function_search&& item) = default; + /*! + ensures + - moves the state of item into *this + - #item.num_functions() == 0 + !*/ + + void set_seed ( + time_t seed + ); + /*! + ensures + - Part of this object's algorithm uses random sampling to decide what + points to evaluate next. Calling set_seed() lets you set the seed used + by the random number generator. Note that if you don't call set_seed() + you will always get the same deterministic behavior. + !*/ + + size_t num_functions( + ) const; + /*! + ensures + - returns the number of functions being optimized. + !*/ + + void get_function_evaluations ( + std::vector& specs, + std::vector>& function_evals + ) const; + /*! + ensures + - #specs.size() == num_functions() + - #function_evals.size() == num_functions() + - This function allows you to query the state of the solver. In + particular, you can find the function_specs for each function being + optimized and their recorded evaluations. + - for all valid i: + - function_evals[i] == all the function evaluations that have been + recorded for the ith function (i.e. the function with the + function_spec #specs[i]). That is, this is the record of all the x + and y values reported back by function_evaluation_request::set() + calls. + !*/ + + void get_best_function_eval ( + matrix& x, + double& y, + size_t& function_idx + ) const; + /*! + requires + - num_functions() != 0 + ensures + - if (no function evaluations have been recorded yet) then + - The outputs of this function are in a valid but undefined state. + - else + - This function tells you which function has produced the largest + output seen so far. It also tells you the inputs to that function + that leads to those outputs (x) as well as the output value itself (y). + - 0 <= #function_idx < num_functions() + - #function_idx == the index of the function that produced the largest output seen so far. + - #x == the input parameters to the function that produced the largest outputs seen so far. + - #y == the largest output seen so far. + !*/ + + function_evaluation_request get_next_x ( + ); + /*! + requires + - num_functions() != 0 + ensures + - Generates and returns a function evaluation request. See the discussion + in the WHAT THIS OBJECT REPRESENTS section above for details. + !*/ + + double get_pure_random_search_probability ( + ) const; + /*! + ensures + - When we decide to do a global explore step we will, with probability + get_pure_random_search_probability(), sample a point completely at random + rather than using the upper bounding model. Therefore, if you set this + probability to 0 then we will depend entirely on the upper bounding + model. Alternatively, if you set get_pure_random_search_probability() to + 1 then we won't use the upper bounding model at all and instead use pure + random search to do global exploration. Pure random search is much + faster than using the upper bounding model, so if you know that your + objective function is especially simple it can be faster to use pure + random search. However, if you really know your function that well you + should probably use a gradient based optimizer :) + !*/ + + void set_pure_random_search_probability ( + double prob + ); + /*! + requires + - prob >= 0 + ensures + - #get_pure_random_search_probability() == prob + !*/ + + double get_solver_epsilon ( + ) const; + /*! + ensures + - As discussed in the WHAT THIS OBJECT REPRESENTS section, we only do a + local refinement step if we haven't already found the peak of the current + local optima. get_solver_epsilon() sets the tolerance for deciding if + the local search method has found the local optima. Therefore, when the + local trust region model runs we check if its predicted improvement in + the objective function is greater than get_solver_epsilon(). If it isn't + then we assume it has converged and we should focus entirely on global + search. + + This means that, for instance, setting get_solver_epsilon() to 0 + essentially instructs the solver to find each local optima to full + floating point precision and only then to focus on pure global search. + !*/ + + void set_solver_epsilon ( + double eps + ); + /*! + requires + - eps >= 0 + ensures + - #get_solver_epsilon() == eps + !*/ + + double get_relative_noise_magnitude ( + ) const; + /*! + ensures + - Returns the value of the relative noise magnitude parameter to the + dlib::upper_bound_function's used by this object. See the + upper_bound_function's documentation for a detailed discussion of this + parameter's meaning. Most users should leave this value as its default + setting. + !*/ + + void set_relative_noise_magnitude ( + double value + ); + /*! + requires + - value >= 0 + ensures + - #get_relative_noise_magnitude() == value + !*/ + + size_t get_monte_carlo_upper_bound_sample_num ( + ) const; + /*! + ensures + - To find the point that maximizes the upper bounding model we use + get_monte_carlo_upper_bound_sample_num() random evaluations and select + the largest upper bound from that set. So this parameter influences how + well we estimate the maximum point on the upper bounding model. + !*/ + + void set_monte_carlo_upper_bound_sample_num ( + size_t num + ); + /*! + requires + - num > 0 + ensures + - #get_monte_carlo_upper_bound_sample_num() == num + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GLOBAL_FuNCTION_SEARCH_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/global_optimization/upper_bound_function.h b/ml/dlib/dlib/global_optimization/upper_bound_function.h new file mode 100644 index 000000000..d1957623e --- /dev/null +++ b/ml/dlib/dlib/global_optimization/upper_bound_function.h @@ -0,0 +1,286 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_UPPER_bOUND_FUNCTION_Hh_ +#define DLIB_UPPER_bOUND_FUNCTION_Hh_ + +#include "upper_bound_function_abstract.h" +#include "../svm/svm_c_linear_dcd_trainer.h" +#include "../statistics.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct function_evaluation + { + function_evaluation() = default; + function_evaluation(const matrix& x, double y) :x(x), y(y) {} + + matrix x; + double y = std::numeric_limits::quiet_NaN(); + }; + +// ---------------------------------------------------------------------------------------- + + class upper_bound_function + { + + public: + + upper_bound_function( + ) = default; + + upper_bound_function( + const double relative_noise_magnitude, + const double solver_eps + ) : relative_noise_magnitude(relative_noise_magnitude), solver_eps(solver_eps) + { + DLIB_CASSERT(relative_noise_magnitude >= 0); + DLIB_CASSERT(solver_eps > 0); + } + + explicit upper_bound_function( + const std::vector& _points, + const double relative_noise_magnitude = 0.001, + const double solver_eps = 0.0001 + ) : relative_noise_magnitude(relative_noise_magnitude), solver_eps(solver_eps), points(_points) + { + DLIB_CASSERT(relative_noise_magnitude >= 0); + DLIB_CASSERT(solver_eps > 0); + + if (points.size() > 1) + { + DLIB_CASSERT(points[0].x.size() > 0, "The vectors can't be empty."); + + const long dims = points[0].x.size(); + for (auto& p : points) + DLIB_CASSERT(p.x.size() == dims, "All the vectors given to upper_bound_function must have the same dimensionality."); + + learn_params(); + } + + } + + void add ( + const function_evaluation& point + ) + { + DLIB_CASSERT(point.x.size() != 0, "The vectors can't be empty."); + if (points.size() == 0) + { + points.push_back(point); + return; + } + + DLIB_CASSERT(point.x.size() == dimensionality(), "All the vectors given to upper_bound_function must have the same dimensionality."); + + if (points.size() < 4) + { + points.push_back(point); + *this = upper_bound_function(points, relative_noise_magnitude, solver_eps); + return; + } + + points.push_back(point); + // add constraints between the new point and the old points + for (size_t i = 0; i < points.size()-1; ++i) + active_constraints.push_back(std::make_pair(i,points.size()-1)); + + learn_params(); + } + + long num_points( + ) const + { + return points.size(); + } + + long dimensionality( + ) const + { + if (points.size() == 0) + return 0; + else + return points[0].x.size(); + } + + const std::vector& get_points( + ) const + { + return points; + } + + double operator() ( + const matrix& x + ) const + { + DLIB_CASSERT(num_points() > 0); + DLIB_CASSERT(x.size() == dimensionality()); + + + + double upper_bound = std::numeric_limits::infinity(); + + for (size_t i = 0; i < points.size(); ++i) + { + const double local_bound = points[i].y + std::sqrt(offsets[i] + dot(slopes, squared(x-points[i].x))); + upper_bound = std::min(upper_bound, local_bound); + } + + return upper_bound; + } + + private: + + void learn_params ( + ) + { + const long dims = points[0].x.size(); + + using sample_type = std::vector>; + using kernel_type = sparse_linear_kernel; + std::vector x; + std::vector y; + + // We are going to normalize the data so the values aren't extreme. First, we + // collect statistics on our data. + std::vector> x_rs(dims); + running_stats y_rs; + for (auto& v : points) + { + for (long i = 0; i < v.x.size(); ++i) + x_rs[i].add(v.x(i)); + y_rs.add(v.y); + } + + + // compute normalization vectors for the data. The only reason we do this is + // to make the optimization well conditioned. In particular, scaling the y + // values will prevent numerical errors in the 1-diff*diff computation below that + // would otherwise result when diff is really big. Also, scaling the xvalues + // to be about 1 will similarly make the optimization more stable and it also + // has the added benefit of keeping the relative_noise_magnitude's scale + // constant regardless of the size of x values. + const double yscale = 1.0/y_rs.stddev(); + std::vector xscale(dims); + for (size_t i = 0; i < xscale.size(); ++i) + xscale[i] = 1.0/(x_rs[i].stddev()*yscale); // make it so that xscale[i]*yscale == 1/x_rs[i].stddev() + + sample_type samp; + auto add_constraint = [&](long i, long j) { + samp.clear(); + for (long k = 0; k < dims; ++k) + { + double temp = (points[i].x(k) - points[j].x(k))*xscale[k]*yscale; + samp.push_back(std::make_pair(k, temp*temp)); + } + + if (points[i].y > points[j].y) + samp.push_back(std::make_pair(dims + j, relative_noise_magnitude)); + else + samp.push_back(std::make_pair(dims + i, relative_noise_magnitude)); + + const double diff = (points[i].y - points[j].y)*yscale; + samp.push_back(std::make_pair(dims + points.size(), 1-diff*diff)); + + x.push_back(samp); + y.push_back(1); + }; + + if (active_constraints.size() == 0) + { + x.reserve(points.size()*(points.size()-1)/2); + y.reserve(points.size()*(points.size()-1)/2); + for (size_t i = 0; i < points.size(); ++i) + { + for (size_t j = i+1; j < points.size(); ++j) + { + add_constraint(i,j); + } + } + } + else + { + for (auto& p : active_constraints) + add_constraint(p.first, p.second); + } + + + + + svm_c_linear_dcd_trainer trainer; + trainer.set_c(std::numeric_limits::infinity()); + //trainer.be_verbose(); + trainer.force_last_weight_to_1(true); + trainer.set_epsilon(solver_eps); + + svm_c_linear_dcd_trainer::optimizer_state state; + auto df = trainer.train(x,y, state); + + // save the active constraints for later so we can use them inside add() to add + // new points efficiently. + if (active_constraints.size() == 0) + { + long k = 0; + for (size_t i = 0; i < points.size(); ++i) + { + for (size_t j = i+1; j < points.size(); ++j) + { + if (state.get_alpha()[k++] != 0) + active_constraints.push_back(std::make_pair(i,j)); + } + } + } + else + { + DLIB_CASSERT(state.get_alpha().size() == active_constraints.size()); + new_active_constraints.clear(); + for (size_t i = 0; i < state.get_alpha().size(); ++i) + { + if (state.get_alpha()[i] != 0) + new_active_constraints.push_back(active_constraints[i]); + } + active_constraints.swap(new_active_constraints); + } + + //std::cout << "points.size(): " << points.size() << std::endl; + //std::cout << "active_constraints.size(): " << active_constraints.size() << std::endl; + + + const auto& bv = df.basis_vectors(0); + slopes.set_size(dims); + for (long i = 0; i < dims; ++i) + slopes(i) = bv[i].second*xscale[i]*xscale[i]; + + //std::cout << "slopes:" << trans(slopes); + + offsets.assign(points.size(),0); + + + for (size_t i = 0; i < points.size(); ++i) + { + offsets[i] += bv[slopes.size()+i].second*relative_noise_magnitude; + } + } + + + + double relative_noise_magnitude = 0.001; + double solver_eps = 0.0001; + std::vector> active_constraints, new_active_constraints; + + std::vector points; + std::vector offsets; // offsets.size() == points.size() + matrix slopes; // slopes.size() == points[0].first.size() + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_UPPER_bOUND_FUNCTION_Hh_ + + diff --git a/ml/dlib/dlib/global_optimization/upper_bound_function_abstract.h b/ml/dlib/dlib/global_optimization/upper_bound_function_abstract.h new file mode 100644 index 000000000..56b361597 --- /dev/null +++ b/ml/dlib/dlib/global_optimization/upper_bound_function_abstract.h @@ -0,0 +1,212 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_UPPER_bOUND_FUNCTION_ABSTRACT_Hh_ +#ifdef DLIB_UPPER_bOUND_FUNCTION_ABSTRACT_Hh_ + +#include "../matrix.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct function_evaluation + { + /*! + WHAT THIS OBJECT REPRESENTS + This object records the output of a real valued function in response to + some input. + + In particular, if you have a function F(x) then the function_evaluation is + simply a struct that records x and the scalar value F(x). + !*/ + + function_evaluation() = default; + function_evaluation(const matrix& x, double y) :x(x), y(y) {} + + matrix x; + double y = std::numeric_limits::quiet_NaN(); + }; + +// ---------------------------------------------------------------------------------------- + + class upper_bound_function + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a piecewise linear non-parametric function that can + be used to define an upper bound on some more complex and unknown function. + To describe this precisely, lets assume there is a function F(x) which you + are capable of sampling from but otherwise know nothing about, and that you + would like to find an upper bounding function U(x) such that U(x) >= F(x) + for any x. It would also be good if U(x)-F(x) was minimal. I.e. we would + like U(x) to be a tight upper bound, not something vacuous like U(x) = + infinity. + + The upper_bound_function class is a tool for creating this kind of upper + bounding function from a set of function_evaluations of F(x). We do this + by considering only U(x) of the form: + U = [](matrix x) { + double min_ub = infinity; + for (size_t i = 0; i < POINTS.size(); ++i) { + function_evaluation p = POINTS[i] + double local_bound = p.y + sqrt(noise_terms[i] + trans(p.x-x)*M*(p.x-x)) + min_ub = min(min_ub, local_bound) + } + return min_ub; + } + Where POINTS is an array of function_evaluation instances drawn from F(x), + M is a diagonal matrix, and noise_terms is an array of scalars. + + To create an upper bound U(x), the upper_bound_function takes a POINTS array + containing evaluations of F(x) as input and solves the following quadratic + program to find the parameters of U(x): + + min_{M,noise_terms}: sum(squared(M)) + sum(squared(noise_terms/relative_noise_magnitude)) + s.t. U(POINTS[i].x) >= POINTS[i].y, for all i + noise_terms[i] >= 0 + min(M) >= 0 + M is a diagonal matrix + + Therefore, the quadratic program finds the U(x) that always upper bounds + F(x) on the supplied POINTS, but is otherwise as small as possible. + + + + The inspiration for the upper_bound_function object came from the AdaLIPO + algorithm from this excellent paper: + Global optimization of Lipschitz functions + Malherbe, Cédric and Vayatis, Nicolas + International Conference on Machine Learning - 2017 + In that paper, they propose to use a simpler U(x) where noise_terms is + always 0 and M is a diagonal matrix where each diagonal element is the same + value. Therefore, there is only a single scalar parameter for U(x) in + their formulation of the problem. This causes difficulties if F(x) is + stochastic or has discontinuities since, without the noise term, M will + become really huge and the upper bound becomes vacuously large. It is also + problematic if the gradient of F(x) with respect to x contains elements of + widely varying magnitude since the simpler formulation of U(x) assumes a + uniform rate of change regardless of which dimension is varying. + !*/ + + public: + + upper_bound_function( + ); + /*! + ensures + - #num_points() == 0 + - #dimensionality() == 0 + !*/ + + explicit upper_bound_function( + const std::vector& points, + const double relative_noise_magnitude = 0.001, + const double solver_eps = 0.0001 + ); + /*! + requires + - all the x vectors in points must have the same non-zero dimensionality. + - relative_noise_magnitude >= 0 + - solver_eps > 0 + ensures + - Creates an upper bounding function U(x), as described above, assuming that + the given points are drawn from F(x). + - Uses the provided relative_noise_magnitude when solving the QP, as + described above. Note that relative_noise_magnitude can be set to 0. If + you do this then all the noise terms are constrained to 0. You should + only do this if you know F(x) is non-stochastic and continuous + everywhere. + - When solving the QP used to find the parameters of U(x), the upper + bounding function, we solve the QP to solver_eps accuracy. It's + possible that large enough solver_eps can lead to upper bounds that don't + upper bound all the supplied points. But for reasonable epsilon values + this shouldn't be a problem. + - #num_points() == points.size() + - #dimensionality() == points[0].x.size() + !*/ + + upper_bound_function( + const double relative_noise_magnitude, + const double solver_eps + ); + /*! + requires + - relative_noise_magnitude >= 0 + - solver_eps > 0 + ensures + - #num_points() == 0 + - #dimensionality() == 0 + - This destructor is the same as calling the above constructor with points.size()==0 + !*/ + + + void add ( + const function_evaluation& point + ); + /*! + requires + - num_points() == 0 || point.x.size() == dimensionality() + - point.x.size() != 0 + ensures + - Adds point to get_points(). + - Incrementally updates the upper bounding function with the given function + evaluation. That is, we assume that F(point.x)==point.y and solve the QP + described above to find the new U(x) that upper bounds all the points + this object knows about (i.e. all the points in get_points() and the new point). + - Calling add() is much faster than recreating the upper_bound_function + from scratch with all the points. This is because we warm start with the + previous solution to the QP. This is done by discarding any non-active + constraints and solving the QP again with only the previously active + constraints and the new constraints formed by all the pairs of the new + point and the old points. This means the QP solved by add() is much + smaller than the QP that would be solved by a fresh call to the + upper_bound_function constructor. + !*/ + + const std::vector& get_points( + ) const; + /*! + ensures + - returns the points from F(x) used to define this upper bounding function. + These are all the function_evaluation objects given to this object via + its constructor and add(). + !*/ + + long num_points( + ) const; + /*! + ensures + - returns the number of points used to define the upper bounding function. + (i.e. returns get_points().size()) + !*/ + + long dimensionality( + ) const; + /*! + ensures + - returns the dimensionality of the input vectors to the upper bounding function. + !*/ + + double operator() ( + const matrix& x + ) const; + /*! + requires + - num_points() > 0 + - x.size() == dimensionality() + ensures + - return U(x) + (i.e. returns the upper bound on F(x) at x given by our upper bounding function) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_UPPER_bOUND_FUNCTION_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/graph.h b/ml/dlib/dlib/graph.h new file mode 100644 index 000000000..39f7cfbb6 --- /dev/null +++ b/ml/dlib/dlib/graph.h @@ -0,0 +1,37 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GRAPh_ +#define DLIB_GRAPh_ + +#include "graph/graph_kernel_1.h" + +#include "algs.h" + +namespace dlib +{ + + template < + typename T, + typename E = char, + typename mem_manager = default_memory_manager + > + class graph + { + graph() {} + public: + + + //----------- kernels --------------- + + // kernel_1a + typedef graph_kernel_1 + kernel_1a; + typedef graph_kernel_1 + kernel_1a_c; + + }; +} + +#endif // DLIB_GRAPh_ + + diff --git a/ml/dlib/dlib/graph/graph_kernel_1.h b/ml/dlib/dlib/graph/graph_kernel_1.h new file mode 100644 index 000000000..fb0d6e7a6 --- /dev/null +++ b/ml/dlib/dlib/graph/graph_kernel_1.h @@ -0,0 +1,629 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GRAPH_KERNEl_1_ +#define DLIB_GRAPH_KERNEl_1_ + +#include +#include + +#include "../serialize.h" +#include "../noncopyable.h" +#include "../std_allocator.h" +#include "../algs.h" +#include "graph_kernel_abstract.h" +#include "../is_kind.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + struct graph_checker_helper + { + /*! + This object is used to check preconditions based on the value of is_checked + !*/ + + static void check_neighbor ( + unsigned long edge_index, + const node_type& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(edge_index < self.number_of_neighbors(), + "\tnode_type& graph::node_type::neighbor(edge_index)" + << "\n\tYou have specified an invalid index" + << "\n\tedge_index: " << edge_index + << "\n\tnumber_of_neighbors(): " << self.number_of_neighbors() + << "\n\tthis: " << &self + ); + } + + static void check_edge ( + unsigned long edge_index, + const node_type& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(edge_index < self.number_of_neighbors(), + "\tE& graph::node_type::edge(edge_index)" + << "\n\tYou have specified an invalid index" + << "\n\tedge_index: " << edge_index + << "\n\tnumber_of_neighbors(): " << self.number_of_neighbors() + << "\n\tthis: " << &self + ); + } + + static void check_node ( + unsigned long index, + const graph& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(index < self.number_of_nodes(), + "\tnode_type& graph::node(index)" + << "\n\tYou have specified an invalid index" + << "\n\tindex: " << index + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + } + + static void check_has_edge ( + unsigned long node_index1, + unsigned long node_index2, + const graph& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(node_index1 < self.number_of_nodes() && + node_index2 < self.number_of_nodes(), + "\tvoid graph::has_edge(node_index1, node_index2)" + << "\n\tYou have specified an invalid index" + << "\n\tnode_index1: " << node_index1 + << "\n\tnode_index2: " << node_index2 + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + } + + static void check_add_edge ( + unsigned long node_index1, + unsigned long node_index2, + const graph& self + ) + { + DLIB_CASSERT(node_index1 < self.number_of_nodes() && + node_index2 < self.number_of_nodes(), + "\tvoid graph::add_edge(node_index1, node_index2)" + << "\n\tYou have specified an invalid index" + << "\n\tnode_index1: " << node_index1 + << "\n\tnode_index2: " << node_index2 + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + + DLIB_CASSERT( self.has_edge(node_index1, node_index2) == false, + "\tvoid graph::add_edge(node_index1, node_index2)" + << "\n\tYou can't add an edge if it already exists in the graph" + << "\n\tnode_index1: " << node_index1 + << "\n\tnode_index2: " << node_index2 + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + + } + + static void check_remove_edge ( + unsigned long node_index1, + unsigned long node_index2, + const graph& self + ) + { + DLIB_CASSERT(node_index1 < self.number_of_nodes() && + node_index2 < self.number_of_nodes(), + "\tvoid graph::remove_edge(node_index1, node_index2)" + << "\n\tYou have specified an invalid index" + << "\n\tnode_index1: " << node_index1 + << "\n\tnode_index2: " << node_index2 + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + + DLIB_CASSERT( self.has_edge(node_index1, node_index2) == true, + "\tvoid graph::remove_edge(node_index1, node_index2)" + << "\n\tYou can't remove an edge if it isn't in the graph" + << "\n\tnode_index1: " << node_index1 + << "\n\tnode_index2: " << node_index2 + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + + } + + static void check_remove_node ( + unsigned long index, + const graph& self + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(index < self.number_of_nodes(), + "\tvoid graph::remove_node(index)" + << "\n\tYou have specified an invalid index" + << "\n\tindex: " << index + << "\n\tnumber_of_nodes(): " << self.number_of_nodes() + << "\n\tthis: " << &self + ); + } + }; + + template + struct graph_checker_helper + { + static inline void check_edge ( unsigned long , const node_type& ) { } + static inline void check_neighbor ( unsigned long , const node_type& ) { } + static inline void check_node ( unsigned long , const graph& ) { } + static inline void check_has_edge ( unsigned long , unsigned long , const graph& ) { } + static inline void check_add_edge ( unsigned long , unsigned long , const graph& ) { } + static inline void check_remove_edge ( unsigned long , unsigned long , const graph& ) { } + static inline void check_remove_node ( unsigned long , const graph& ) { } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E = char, + typename mem_manager = default_memory_manager, + bool is_checked = true + > + class graph_kernel_1 : noncopyable + { + + /*! + INITIAL VALUE + - nodes.size() == 0 + + CONVENTION + - nodes.size() == number_of_nodes() + - for all valid i: + - *nodes[i] == node(i) + - nodes[i]->neighbors.size() == nodes[i]->number_of_neighbors(i) + - nodes[i]->idx == i == nodes[i]->index() + - for all valid n: + - nodes[i]->neighbors[n] == pointer to the n'th parent node of i + - *nodes[i]->neighbors[n] == node(i).neighbor(n) + - *nodes[i]->edges[n] == node(i).edge(n) + !*/ + + public: + struct node_type; + + private: + typedef graph_checker_helper checker; + + + public: + + typedef T type; + typedef E edge_type; + typedef mem_manager mem_manager_type; + + graph_kernel_1( + ) {} + + virtual ~graph_kernel_1( + ) {} + + void clear( + ) { nodes.clear(); } + + void set_number_of_nodes ( + unsigned long new_size + ); + + unsigned long number_of_nodes ( + ) const { return nodes.size(); } + + node_type& node ( + unsigned long index + ) { checker::check_node(index,*this); return *nodes[index]; } + + const node_type& node ( + unsigned long index + ) const { checker::check_node(index,*this); return *nodes[index]; } + + bool has_edge ( + unsigned long node_index1, + unsigned long node_index2 + ) const; + + void add_edge ( + unsigned long node_index1, + unsigned long node_index2 + ); + + void remove_edge ( + unsigned long node_index1, + unsigned long node_index2 + ); + + unsigned long add_node ( + ); + + void remove_node ( + unsigned long index + ); + + void swap ( + graph_kernel_1& item + ) { nodes.swap(item.nodes); } + + public: + + struct node_type + { + T data; + typedef graph_kernel_1 graph_type; + + unsigned long index( + ) const { return idx; } + + unsigned long number_of_neighbors ( + ) const { return neighbors.size(); } + + const node_type& neighbor ( + unsigned long edge_index + ) const { checker::check_neighbor(edge_index,*this); return *neighbors[edge_index]; } + + node_type& neighbor ( + unsigned long edge_index + ) { checker::check_neighbor(edge_index,*this); return *neighbors[edge_index]; } + + const E& edge ( + unsigned long edge_index + ) const { checker::check_edge(edge_index,*this); return *edges[edge_index]; } + + E& edge ( + unsigned long edge_index + ) { checker::check_edge(edge_index,*this); return *edges[edge_index]; } + + private: + friend class graph_kernel_1; + typedef std_allocator alloc_type; + typedef std_allocator,mem_manager> alloc_edge_type; + std::vector neighbors; + std::vector,alloc_edge_type> edges; + unsigned long idx; + }; + + private: + + typedef std_allocator,mem_manager> alloc_type; + typedef std::vector, alloc_type> vector_type; + vector_type nodes; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + inline void swap ( + graph_kernel_1& a, + graph_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + struct is_graph > + { + static const bool value = true; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void serialize ( + const graph_kernel_1& item, + std::ostream& out + ) + { + try + { + serialize(item.number_of_nodes(), out); + + // serialize each node + for (unsigned long i = 0; i < item.number_of_nodes(); ++i) + { + serialize(item.node(i).data, out); + + // serialize all the edges + for (unsigned long n = 0; n < item.node(i).number_of_neighbors(); ++n) + { + // only serialize edges that we haven't already serialized + if (item.node(i).neighbor(n).index() >= i) + { + serialize(item.node(i).neighbor(n).index(), out); + serialize(item.node(i).edge(n), out); + } + } + const unsigned long stop_mark = 0xFFFFFFFF; + serialize(stop_mark, out); + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type graph_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void deserialize ( + graph_kernel_1& item, + std::istream& in + ) + { + try + { + unsigned long size; + deserialize(size, in); + + item.clear(); + item.set_number_of_nodes(size); + + // deserialize each node + for (unsigned long i = 0; i < item.number_of_nodes(); ++i) + { + deserialize(item.node(i).data, in); + + const unsigned long stop_mark = 0xFFFFFFFF; + // Add all the edges going to this node's neighbors + unsigned long index; + deserialize(index, in); + while (index != stop_mark) + { + item.add_edge(i, index); + // find the edge + unsigned long j = 0; + for (j = 0; j < item.node(i).number_of_neighbors(); ++j) + if (item.node(i).neighbor(j).index() == index) + break; + + deserialize(item.node(i).edge(j), in); + deserialize(index, in); + } + + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type graph_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void graph_kernel_1:: + set_number_of_nodes ( + unsigned long new_size + ) + { + try + { + nodes.resize(new_size); + for (unsigned long i = 0; i < nodes.size(); ++i) + { + nodes[i].reset(new node_type); + nodes[i]->idx = i; + } + } + catch (...) + { + clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + bool graph_kernel_1:: + has_edge ( + unsigned long node_index1, + unsigned long node_index2 + ) const + { + checker::check_has_edge(node_index1, node_index2, *this); + + node_type& n = *nodes[node_index1]; + + // search all the child nodes to see if there is a link to the right node + for (unsigned long i = 0; i < n.neighbors.size(); ++i) + { + if (n.neighbors[i]->idx == node_index2) + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void graph_kernel_1:: + add_edge ( + unsigned long node_index1, + unsigned long node_index2 + ) + { + checker::check_add_edge(node_index1, node_index2, *this); + try + { + node_type& n1 = *nodes[node_index1]; + node_type& n2 = *nodes[node_index2]; + + n1.neighbors.push_back(&n2); + + std::shared_ptr e(new E); + n1.edges.push_back(e); + + // don't add this twice if this is an edge from node_index1 back to itself + if (node_index1 != node_index2) + { + n2.neighbors.push_back(&n1); + n2.edges.push_back(e); + } + } + catch (...) + { + clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void graph_kernel_1:: + remove_edge ( + unsigned long node_index1, + unsigned long node_index2 + ) + { + checker::check_remove_edge(node_index1, node_index2, *this); + + node_type& n1 = *nodes[node_index1]; + node_type& n2 = *nodes[node_index2]; + + // remove the record of the link from n1 + unsigned long pos = static_cast(find(n1.neighbors.begin(), n1.neighbors.end(), &n2) - n1.neighbors.begin()); + n1.neighbors.erase(n1.neighbors.begin() + pos); + n1.edges.erase(n1.edges.begin() + pos); + + // check if this is an edge that goes from node_index1 back to itself + if (node_index1 != node_index2) + { + // remove the record of the link from n2 + unsigned long pos = static_cast(find(n2.neighbors.begin(), n2.neighbors.end(), &n1) - n2.neighbors.begin()); + n2.neighbors.erase(n2.neighbors.begin() + pos); + n2.edges.erase(n2.edges.begin() + pos); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + unsigned long graph_kernel_1:: + add_node ( + ) + { + try + { + std::shared_ptr n(new node_type); + n->idx = nodes.size(); + nodes.push_back(n); + return n->idx; + } + catch (...) + { + clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename E, + typename mem_manager, + bool is_checked + > + void graph_kernel_1:: + remove_node ( + unsigned long index + ) + { + checker::check_remove_node(index,*this); + + node_type& n = *nodes[index]; + + // remove all edges pointing to this node from its neighbors + for (unsigned long i = 0; i < n.neighbors.size(); ++i) + { + // remove the edge from this specific parent + unsigned long pos = static_cast(find(n.neighbors[i]->neighbors.begin(), n.neighbors[i]->neighbors.end(), &n) - + n.neighbors[i]->neighbors.begin()); + n.neighbors[i]->neighbors.erase(n.neighbors[i]->neighbors.begin() + pos); + n.neighbors[i]->edges.erase(n.neighbors[i]->edges.begin() + pos); + } + + // now remove this node by replacing it with the last node in the nodes vector + nodes[index] = nodes[nodes.size()-1]; + + // update the index for the node we just moved + nodes[index]->idx = index; + + // now remove the duplicated node at the end of the vector + nodes.pop_back(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GRAPH_KERNEl_1_ + diff --git a/ml/dlib/dlib/graph/graph_kernel_abstract.h b/ml/dlib/dlib/graph/graph_kernel_abstract.h new file mode 100644 index 000000000..e6e699332 --- /dev/null +++ b/ml/dlib/dlib/graph/graph_kernel_abstract.h @@ -0,0 +1,329 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GRAPH_KERNEl_ABSTRACT_ +#ifdef DLIB_GRAPH_KERNEl_ABSTRACT_ + +#include "../serialize.h" +#include "../algs.h" +#include "../noncopyable.h" + +namespace dlib +{ + + template < + typename T, + typename E = char, + typename mem_manager = default_memory_manager + > + class graph : noncopyable + { + + /*! + REQUIREMENTS ON T + T must be swappable by a global swap() and + T must have a default constructor + + REQUIREMENTS ON E + E must be swappable by a global swap() and + E must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + The only time pointers or references to nodes or edges become invalid is when + they reference nodes or edges that have been removed from a graph. + + INITIAL VALUE + number_of_nodes() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents an undirected graph which is a set of nodes with undirected + edges connecting various nodes. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + !*/ + + public: + + typedef T type; + typedef E edge_type; + typedef mem_manager mem_manager_type; + + graph( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor. + !*/ + + virtual ~graph( + ); + /*! + ensures + - all resources associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void set_number_of_nodes ( + unsigned long new_size + ); + /*! + ensures + - #number_of_nodes() == new_size + - for all i < new_size: + - number_of_neighbors(i) == 0 + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in this graph + !*/ + + struct node_type + { + T data; + typedef graph graph_type; + + unsigned long index( + ) const; + /*! + ensures + - let G be the graph that contains the node *this + - returns a number N such that G.node(N) == *this + (i.e. returns the index of this node in the graph) + !*/ + + unsigned long number_of_neighbors ( + ) const; + /*! + ensures + - returns the number of nodes in this graph that are + adjacent to this node. I.e. the number of nodes + that are directly connected to this node via an edge. + !*/ + + const node_type& neighbor ( + unsigned long edge_index + ) const; + /*! + requires + - edge_index < number_of_neighbors() + ensures + - returns a const reference to the edge_index'th neighbor of *this + !*/ + + node_type& neighbor ( + unsigned long edge_index + ); + /*! + requires + - edge_index < number_of_neighbors() + ensures + - returns a non-const reference to the edge_index'th neighbor of *this + !*/ + + const E& edge ( + unsigned long edge_index + ) const; + /*! + requires + - edge_index < number_of_neighbors() + ensures + - returns a const reference to the edge_index'th edge data for the + edge connecting to neighbor this->neighbor(edge_index) + !*/ + + E& edge ( + unsigned long edge_index + ); + /*! + requires + - edge_index < number_of_neighbors() + ensures + - returns a non-const reference to the edge_index'th edge data for the + edge connecting to neighbor this->neighbor(edge_index) + !*/ + + }; + + node_type& node ( + unsigned long index + ); + /*! + requires + - index < number_of_nodes() + ensures + - returns a non-const reference to the node with the given index + !*/ + + const node_type& node ( + unsigned long index + ) const; + /*! + requires + - index < number_of_nodes() + ensures + - returns a const reference to the node with the given index + !*/ + + bool has_edge ( + unsigned long node_index1, + unsigned long node_index2 + ) const; + /*! + requires + - node_index1 < number_of_nodes() + - node_index2 < number_of_nodes() + ensures + - if (there is an edge connecting node(node_index1) and node(node_index2)) then + - returns true + - else + - returns false + !*/ + + void add_edge ( + unsigned long node_index1, + unsigned long node_index2 + ); + /*! + requires + - node_index1 < number_of_nodes() + - node_index2 < number_of_nodes() + - has_edge(node_index1, node_index2) == false + ensures + - #has_edge(node_index1, node_index2) == true + throws + - std::bad_alloc + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + void remove_edge ( + unsigned long node_index1, + unsigned long node_index2 + ); + /*! + requires + - node_index1 < number_of_nodes() + - node_index2 < number_of_nodes() + - has_edge(node_index1, node_index2) == true + ensures + - #has_edge(node_index1, node_index2) == false + throws + - std::bad_alloc + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + unsigned long add_node ( + ); + /*! + ensures + - does not change the index number of existing nodes + - adds a node with index N == number_of_nodes() such that: + - #node(N).number_of_neighbors() == 0 + - #number_of_nodes() == number_of_nodes() + 1 + - returns N + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + void remove_node ( + unsigned long index + ); + /*! + requires + - index < number_of_nodes() + ensures + - removes the node with the given index from the graph. + - removes all edges linking the removed node to the rest + of the graph. + - the remaining node indexes are remapped so that they remain + contiguous. (This means that for all valid N, node(N) doesn't + necessarily reference the same node as #node(N)) + - #number_of_nodes() == number_of_nodes() - 1 + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then this object reverts back + to its initial state. + !*/ + + void swap ( + graph& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + template < + typename T, + typename E, + typename mem_manager + > + inline void swap ( + graph& a, + graph& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename E, + typename mem_manager + > + void serialize ( + const graph& item, + std::ostream& out + ); + /*! + provides deserialization support + !*/ + + template < + typename T, + typename E, + typename mem_manager + > + void deserialize ( + graph& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_GRAPH_KERNEl_ABSTRACT_ + + diff --git a/ml/dlib/dlib/graph_cuts.h b/ml/dlib/dlib/graph_cuts.h new file mode 100644 index 000000000..c245b2be4 --- /dev/null +++ b/ml/dlib/dlib/graph_cuts.h @@ -0,0 +1,14 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GRAPH_CUTs_HEADER_ +#define DLIB_GRAPH_CUTs_HEADER_ + +#include "graph_cuts/min_cut.h" +#include "graph_cuts/general_flow_graph.h" +#include "graph_cuts/find_max_factor_graph_potts.h" +#include "graph_cuts/graph_labeler.h" + +#endif // DLIB_GRAPH_CUTs_HEADER_ + + + diff --git a/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h new file mode 100644 index 000000000..f035442bf --- /dev/null +++ b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h @@ -0,0 +1,959 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_ +#define DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_ + +#include "find_max_factor_graph_potts_abstract.h" +#include "../matrix.h" +#include "min_cut.h" +#include "general_potts_problem.h" +#include "../algs.h" +#include "../graph_utils.h" +#include "../array2d.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + template < + typename potts_problem, + typename T = void + > + class flows_container + { + /* + This object notionally represents a matrix of flow values. It's + overloaded to represent this matrix efficiently though. In this case + it represents the matrix using a sparse representation. + */ + + typedef typename potts_problem::value_type edge_type; + std::vector > flows; + public: + + void setup( + const potts_problem& p + ) + { + flows.resize(p.number_of_nodes()); + for (unsigned long i = 0; i < flows.size(); ++i) + { + flows[i].resize(p.number_of_neighbors(i)); + } + } + + edge_type& operator() ( + const long r, + const long c + ) { return flows[r][c]; } + + const edge_type& operator() ( + const long r, + const long c + ) const { return flows[r][c]; } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename potts_problem + > + class flows_container::type> + { + /* + This object notionally represents a matrix of flow values. It's + overloaded to represent this matrix efficiently though. In this case + it represents the matrix using a dense representation. + + */ + typedef typename potts_problem::value_type edge_type; + const static unsigned long max_number_of_neighbors = potts_problem::max_number_of_neighbors; + matrix flows; + public: + + void setup( + const potts_problem& p + ) + { + flows.set_size(p.number_of_nodes(), max_number_of_neighbors); + } + + edge_type& operator() ( + const long r, + const long c + ) { return flows(r,c); } + + const edge_type& operator() ( + const long r, + const long c + ) const { return flows(r,c); } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename potts_problem + > + class potts_flow_graph + { + public: + typedef typename potts_problem::value_type edge_type; + private: + /*! + This is a utility class used by dlib::min_cut to convert a potts_problem + into the kind of flow graph expected by the min_cut object's main block + of code. + + Within this object, we will use the convention that one past + potts_problem::number_of_nodes() is the source node and two past is + the sink node. + !*/ + + potts_problem& g; + + // flows(i,j) == the flow from node id i to it's jth neighbor + flows_container flows; + // source_flows(i,0) == flow from source to node i, + // source_flows(i,1) == flow from node i to source + matrix source_flows; + + // sink_flows(i,0) == flow from sink to node i, + // sink_flows(i,1) == flow from node i to sink + matrix sink_flows; + + node_label source_label, sink_label; + public: + + potts_flow_graph( + potts_problem& g_ + ) : g(g_) + { + flows.setup(g); + + source_flows.set_size(g.number_of_nodes(), 2); + sink_flows.set_size(g.number_of_nodes(), 2); + source_flows = 0; + sink_flows = 0; + + source_label = FREE_NODE; + sink_label = FREE_NODE; + + // setup flows based on factor potentials + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + const edge_type temp = g.factor_value(i); + if (temp < 0) + source_flows(i,0) = -temp; + else + sink_flows(i,1) = temp; + + for (unsigned long j = 0; j < g.number_of_neighbors(i); ++j) + { + flows(i,j) = g.factor_value_disagreement(i, g.get_neighbor(i,j)); + } + } + } + + class out_edge_iterator + { + friend class potts_flow_graph; + unsigned long idx; // base node idx + unsigned long cnt; // count over the neighbors of idx + public: + + out_edge_iterator( + ):idx(0),cnt(0){} + + out_edge_iterator( + unsigned long idx_, + unsigned long cnt_ + ):idx(idx_),cnt(cnt_) + {} + + bool operator!= ( + const out_edge_iterator& item + ) const { return cnt != item.cnt; } + + out_edge_iterator& operator++( + ) + { + ++cnt; + return *this; + } + }; + + class in_edge_iterator + { + friend class potts_flow_graph; + unsigned long idx; // base node idx + unsigned long cnt; // count over the neighbors of idx + public: + + in_edge_iterator( + ):idx(0),cnt(0) + {} + + + in_edge_iterator( + unsigned long idx_, + unsigned long cnt_ + ):idx(idx_),cnt(cnt_) + {} + + bool operator!= ( + const in_edge_iterator& item + ) const { return cnt != item.cnt; } + + in_edge_iterator& operator++( + ) + { + ++cnt; + return *this; + } + }; + + unsigned long number_of_nodes ( + ) const { return g.number_of_nodes() + 2; } + + out_edge_iterator out_begin( + const unsigned long& it + ) const { return out_edge_iterator(it, 0); } + + in_edge_iterator in_begin( + const unsigned long& it + ) const { return in_edge_iterator(it, 0); } + + out_edge_iterator out_end( + const unsigned long& it + ) const + { + if (it >= g.number_of_nodes()) + return out_edge_iterator(it, g.number_of_nodes()); + else + return out_edge_iterator(it, g.number_of_neighbors(it)+2); + } + + in_edge_iterator in_end( + const unsigned long& it + ) const + { + if (it >= g.number_of_nodes()) + return in_edge_iterator(it, g.number_of_nodes()); + else + return in_edge_iterator(it, g.number_of_neighbors(it)+2); + } + + + template + unsigned long node_id ( + const iterator_type& it + ) const + { + // if this isn't an iterator over the source or sink nodes + if (it.idx < g.number_of_nodes()) + { + const unsigned long num = g.number_of_neighbors(it.idx); + if (it.cnt < num) + return g.get_neighbor(it.idx, it.cnt); + else if (it.cnt == num) + return g.number_of_nodes(); + else + return g.number_of_nodes()+1; + } + else + { + return it.cnt; + } + } + + + edge_type get_flow ( + const unsigned long& it1, + const unsigned long& it2 + ) const + { + if (it1 >= g.number_of_nodes()) + { + // if it1 is the source + if (it1 == g.number_of_nodes()) + { + return source_flows(it2,0); + } + else // if it1 is the sink + { + return sink_flows(it2,0); + } + } + else if (it2 >= g.number_of_nodes()) + { + // if it2 is the source + if (it2 == g.number_of_nodes()) + { + return source_flows(it1,1); + } + else // if it2 is the sink + { + return sink_flows(it1,1); + } + } + else + { + return flows(it1, g.get_neighbor_idx(it1, it2)); + } + + } + + edge_type get_flow ( + const out_edge_iterator& it + ) const + { + if (it.idx < g.number_of_nodes()) + { + const unsigned long num = g.number_of_neighbors(it.idx); + if (it.cnt < num) + return flows(it.idx, it.cnt); + else if (it.cnt == num) + return source_flows(it.idx,1); + else + return sink_flows(it.idx,1); + } + else + { + // if it.idx is the source + if (it.idx == g.number_of_nodes()) + { + return source_flows(it.cnt,0); + } + else // if it.idx is the sink + { + return sink_flows(it.cnt,0); + } + } + } + + edge_type get_flow ( + const in_edge_iterator& it + ) const + { + return get_flow(node_id(it), it.idx); + } + + void adjust_flow ( + const unsigned long& it1, + const unsigned long& it2, + const edge_type& value + ) + { + if (it1 >= g.number_of_nodes()) + { + // if it1 is the source + if (it1 == g.number_of_nodes()) + { + source_flows(it2,0) += value; + source_flows(it2,1) -= value; + } + else // if it1 is the sink + { + sink_flows(it2,0) += value; + sink_flows(it2,1) -= value; + } + } + else if (it2 >= g.number_of_nodes()) + { + // if it2 is the source + if (it2 == g.number_of_nodes()) + { + source_flows(it1,1) += value; + source_flows(it1,0) -= value; + } + else // if it2 is the sink + { + sink_flows(it1,1) += value; + sink_flows(it1,0) -= value; + } + } + else + { + flows(it1, g.get_neighbor_idx(it1, it2)) += value; + flows(it2, g.get_neighbor_idx(it2, it1)) -= value; + } + + } + + void set_label ( + const unsigned long& it, + node_label value + ) + { + if (it < g.number_of_nodes()) + g.set_label(it, value); + else if (it == g.number_of_nodes()) + source_label = value; + else + sink_label = value; + } + + node_label get_label ( + const unsigned long& it + ) const + { + if (it < g.number_of_nodes()) + return g.get_label(it); + if (it == g.number_of_nodes()) + return source_label; + else + return sink_label; + } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename label_image_type, + typename image_potts_model + > + class potts_grid_problem + { + label_image_type& label_img; + long nc; + long num_nodes; + unsigned char* labels; + const image_potts_model& model; + + public: + const static unsigned long max_number_of_neighbors = 4; + + potts_grid_problem ( + label_image_type& label_img_, + const image_potts_model& image_potts_model_ + ) : + label_img(label_img_), + model(image_potts_model_) + { + num_nodes = model.nr()*model.nc(); + nc = model.nc(); + labels = &label_img[0][0]; + } + + unsigned long number_of_nodes ( + ) const { return num_nodes; } + + unsigned long number_of_neighbors ( + unsigned long + ) const + { + return 4; + } + + unsigned long get_neighbor_idx ( + long node_id1, + long node_id2 + ) const + { + long diff = node_id2-node_id1; + if (diff > nc) + diff -= (long)number_of_nodes(); + else if (diff < -nc) + diff += (long)number_of_nodes(); + + if (diff == 1) + return 0; + else if (diff == -1) + return 1; + else if (diff == nc) + return 2; + else + return 3; + } + + unsigned long get_neighbor ( + long node_id, + long idx + ) const + { + switch(idx) + { + case 0: + { + long temp = node_id+1; + if (temp < (long)number_of_nodes()) + return temp; + else + return temp - (long)number_of_nodes(); + } + case 1: + { + long temp = node_id-1; + if (node_id >= 1) + return temp; + else + return temp + (long)number_of_nodes(); + } + case 2: + { + long temp = node_id+nc; + if (temp < (long)number_of_nodes()) + return temp; + else + return temp - (long)number_of_nodes(); + } + case 3: + { + long temp = node_id-nc; + if (node_id >= nc) + return temp; + else + return temp + (long)number_of_nodes(); + } + } + return 0; + } + + void set_label ( + const unsigned long& idx, + node_label value + ) + { + *(labels+idx) = value; + } + + node_label get_label ( + const unsigned long& idx + ) const + { + return *(labels+idx); + } + + typedef typename image_potts_model::value_type value_type; + + value_type factor_value (unsigned long idx) const + { + return model.factor_value(idx); + } + + value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const + { + return model.factor_value_disagreement(idx1,idx2); + } + + }; + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename potts_model + > + typename potts_model::value_type potts_model_score ( + const potts_model& prob + ) + { +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < prob.number_of_nodes(); ++i) + { + for (unsigned long jj = 0; jj < prob.number_of_neighbors(i); ++jj) + { + unsigned long j = prob.get_neighbor(i,jj); + DLIB_ASSERT(prob.factor_value_disagreement(i,j) >= 0, + "\t value_type potts_model_score(prob)" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + << "\n\t j: " << j + << "\n\t prob.factor_value_disagreement(i,j): " << prob.factor_value_disagreement(i,j) + ); + DLIB_ASSERT(prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i), + "\t value_type potts_model_score(prob)" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + << "\n\t j: " << j + << "\n\t prob.factor_value_disagreement(i,j): " << prob.factor_value_disagreement(i,j) + << "\n\t prob.factor_value_disagreement(j,i): " << prob.factor_value_disagreement(j,i) + ); + } + } +#endif + + typename potts_model::value_type score = 0; + for (unsigned long i = 0; i < prob.number_of_nodes(); ++i) + { + const bool label = (prob.get_label(i)!=0); + if (label) + score += prob.factor_value(i); + } + + for (unsigned long i = 0; i < prob.number_of_nodes(); ++i) + { + for (unsigned long n = 0; n < prob.number_of_neighbors(i); ++n) + { + const unsigned long idx2 = prob.get_neighbor(i,n); + const bool label_i = (prob.get_label(i)!=0); + const bool label_idx2 = (prob.get_label(idx2)!=0); + if (label_i != label_idx2 && i < idx2) + score -= prob.factor_value_disagreement(i, idx2); + } + } + + return score; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type + > + typename graph_type::edge_type potts_model_score ( + const graph_type& g, + const std::vector& labels + ) + { + DLIB_ASSERT(graph_contains_length_one_cycle(g) == false, + "\t edge_type potts_model_score(g,labels)" + << "\n\t Invalid inputs were given to this function." + ); + typedef typename graph_type::edge_type edge_type; + typedef typename graph_type::type type; + + // The edges and node's have to use the same type to represent factor weights! + COMPILE_TIME_ASSERT((is_same_type::value == true)); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + for (unsigned long jj = 0; jj < g.node(i).number_of_neighbors(); ++jj) + { + unsigned long j = g.node(i).neighbor(jj).index(); + DLIB_ASSERT(edge(g,i,j) >= 0, + "\t edge_type potts_model_score(g,labels)" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + << "\n\t j: " << j + << "\n\t edge(g,i,j): " << edge(g,i,j) + ); + } + } +#endif + + typename graph_type::edge_type score = 0; + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + const bool label = (labels[i]!=0); + if (label) + score += g.node(i).data; + } + + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n) + { + const unsigned long idx2 = g.node(i).neighbor(n).index(); + const bool label_i = (labels[i]!=0); + const bool label_idx2 = (labels[idx2]!=0); + if (label_i != label_idx2 && i < idx2) + score -= g.node(i).edge(n); + } + } + + return score; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename potts_grid_problem, + typename mem_manager + > + typename potts_grid_problem::value_type potts_model_score ( + const potts_grid_problem& prob, + const array2d& labels + ) + { + DLIB_ASSERT(prob.nr() == labels.nr() && prob.nc() == labels.nc(), + "\t value_type potts_model_score(prob,labels)" + << "\n\t Invalid inputs were given to this function." + << "\n\t prob.nr(): " << labels.nr() + << "\n\t prob.nc(): " << labels.nc() + ); + typedef array2d image_type; + // This const_cast is ok because the model object won't actually modify labels + dlib::impl::potts_grid_problem model(const_cast(labels),prob); + return potts_model_score(model); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename potts_model + > + void find_max_factor_graph_potts ( + potts_model& prob + ) + { +#ifdef ENABLE_ASSERTS + for (unsigned long node_i = 0; node_i < prob.number_of_nodes(); ++node_i) + { + for (unsigned long jj = 0; jj < prob.number_of_neighbors(node_i); ++jj) + { + unsigned long node_j = prob.get_neighbor(node_i,jj); + DLIB_ASSERT(prob.get_neighbor_idx(node_j,node_i) < prob.number_of_neighbors(node_j), + "\t void find_max_factor_graph_potts(prob)" + << "\n\t The supplied potts problem defines an invalid graph." + << "\n\t node_i: " << node_i + << "\n\t node_j: " << node_j + << "\n\t prob.get_neighbor_idx(node_j,node_i): " << prob.get_neighbor_idx(node_j,node_i) + << "\n\t prob.number_of_neighbors(node_j): " << prob.number_of_neighbors(node_j) + ); + + DLIB_ASSERT(prob.get_neighbor_idx(node_i,prob.get_neighbor(node_i,jj)) == jj, + "\t void find_max_factor_graph_potts(prob)" + << "\n\t The get_neighbor_idx() and get_neighbor() functions must be inverses of each other." + << "\n\t node_i: " << node_i + << "\n\t jj: " << jj + << "\n\t prob.get_neighbor(node_i,jj): " << prob.get_neighbor(node_i,jj) + << "\n\t prob.get_neighbor_idx(node_i,prob.get_neighbor(node_i,jj)): " << prob.get_neighbor_idx(node_i,node_j) + ); + + DLIB_ASSERT(prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i))==node_i, + "\t void find_max_factor_graph_potts(prob)" + << "\n\t The get_neighbor_idx() and get_neighbor() functions must be inverses of each other." + << "\n\t node_i: " << node_i + << "\n\t node_j: " << node_j + << "\n\t prob.get_neighbor_idx(node_j,node_i): " << prob.get_neighbor_idx(node_j,node_i) + << "\n\t prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i)): " << prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i)) + ); + + DLIB_ASSERT(prob.factor_value_disagreement(node_i,node_j) >= 0, + "\t void find_max_factor_graph_potts(prob)" + << "\n\t Invalid inputs were given to this function." + << "\n\t node_i: " << node_i + << "\n\t node_j: " << node_j + << "\n\t prob.factor_value_disagreement(node_i,node_j): " << prob.factor_value_disagreement(node_i,node_j) + ); + DLIB_ASSERT(prob.factor_value_disagreement(node_i,node_j) == prob.factor_value_disagreement(node_j,node_i), + "\t void find_max_factor_graph_potts(prob)" + << "\n\t Invalid inputs were given to this function." + << "\n\t node_i: " << node_i + << "\n\t node_j: " << node_j + << "\n\t prob.factor_value_disagreement(node_i,node_j): " << prob.factor_value_disagreement(node_i,node_j) + << "\n\t prob.factor_value_disagreement(node_j,node_i): " << prob.factor_value_disagreement(node_j,node_i) + ); + } + } +#endif + COMPILE_TIME_ASSERT(is_signed_type::value); + min_cut mc; + dlib::impl::potts_flow_graph pfg(prob); + mc(pfg, prob.number_of_nodes(), prob.number_of_nodes()+1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type + > + void find_max_factor_graph_potts ( + const graph_type& g, + std::vector& labels + ) + { + DLIB_ASSERT(graph_contains_length_one_cycle(g) == false, + "\t void find_max_factor_graph_potts(g,labels)" + << "\n\t Invalid inputs were given to this function." + ); + typedef typename graph_type::edge_type edge_type; + typedef typename graph_type::type type; + + // The edges and node's have to use the same type to represent factor weights! + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT(is_signed_type::value); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + for (unsigned long jj = 0; jj < g.node(i).number_of_neighbors(); ++jj) + { + unsigned long j = g.node(i).neighbor(jj).index(); + DLIB_ASSERT(edge(g,i,j) >= 0, + "\t void find_max_factor_graph_potts(g,labels)" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + << "\n\t j: " << j + << "\n\t edge(g,i,j): " << edge(g,i,j) + ); + } + } +#endif + + dlib::impl::general_potts_problem gg(g, labels); + find_max_factor_graph_potts(gg); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename potts_grid_problem, + typename mem_manager + > + void find_max_factor_graph_potts ( + const potts_grid_problem& prob, + array2d& labels + ) + { + typedef array2d image_type; + labels.set_size(prob.nr(), prob.nc()); + dlib::impl::potts_grid_problem model(labels,prob); + find_max_factor_graph_potts(model); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename pixel_type1, + typename pixel_type2, + typename model_type + > + struct potts_grid_image_pair_model + { + const pixel_type1* data1; + const pixel_type2* data2; + const model_type& model; + const long nr_; + const long nc_; + template + potts_grid_image_pair_model( + const model_type& model_, + const image_type1& img1, + const image_type2& img2 + ) : + model(model_), + nr_(img1.nr()), + nc_(img1.nc()) + { + data1 = &img1[0][0]; + data2 = &img2[0][0]; + } + + typedef typename model_type::value_type value_type; + + long nr() const { return nr_; } + long nc() const { return nc_; } + + value_type factor_value ( + unsigned long idx + ) const + { + return model.factor_value(*(data1 + idx), *(data2 + idx)); + } + + value_type factor_value_disagreement ( + unsigned long idx1, + unsigned long idx2 + ) const + { + return model.factor_value_disagreement(*(data1 + idx1), *(data1 + idx2)); + } + }; + + // ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename model_type + > + struct potts_grid_image_single_model + { + const typename image_type::type* data1; + const model_type& model; + const long nr_; + const long nc_; + potts_grid_image_single_model( + const model_type& model_, + const image_type& img1 + ) : + model(model_), + nr_(img1.nr()), + nc_(img1.nc()) + { + data1 = &img1[0][0]; + } + + typedef typename model_type::value_type value_type; + + long nr() const { return nr_; } + long nc() const { return nc_; } + + value_type factor_value ( + unsigned long idx + ) const + { + return model.factor_value(*(data1 + idx)); + } + + value_type factor_value_disagreement ( + unsigned long idx1, + unsigned long idx2 + ) const + { + return model.factor_value_disagreement(*(data1 + idx1), *(data1 + idx2)); + } + }; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename pair_image_model, + typename pixel_type1, + typename pixel_type2, + typename mem_manager + > + impl::potts_grid_image_pair_model make_potts_grid_problem ( + const pair_image_model& model, + const array2d& img1, + const array2d& img2 + ) + { + DLIB_ASSERT(get_rect(img1) == get_rect(img2), + "\t potts_grid_problem make_potts_grid_problem()" + << "\n\t Invalid inputs were given to this function." + << "\n\t get_rect(img1): " << get_rect(img1) + << "\n\t get_rect(img2): " << get_rect(img2) + ); + typedef impl::potts_grid_image_pair_model potts_type; + return potts_type(model,img1,img2); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename single_image_model, + typename pixel_type, + typename mem_manager + > + impl::potts_grid_image_single_model, single_image_model> make_potts_grid_problem ( + const single_image_model& model, + const array2d& img + ) + { + typedef impl::potts_grid_image_single_model, single_image_model> potts_type; + return potts_type(model,img); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_ + diff --git a/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts_abstract.h b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts_abstract.h new file mode 100644 index 000000000..69aa59256 --- /dev/null +++ b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts_abstract.h @@ -0,0 +1,636 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_ABSTRACT_Hh_ +#ifdef DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_ABSTRACT_Hh_ + +#include "../matrix.h" +#include "min_cut_abstract.h" +#include "../graph_utils.h" +#include "../array2d/array2d_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class potts_problem + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a boolean valued factor graph or graphical model + that can be efficiently operated on using graph cuts. In particular, this + object defines the interface a MAP problem on a factor graph must + implement if it is to be solved using the find_max_factor_graph_potts() + routine defined at the bottom of this file. + + Note that there is no dlib::potts_problem object. What you are looking + at here is simply the interface definition for a Potts problem. You must + implement your own version of this object for the problem you wish to + solve and then pass it to the find_max_factor_graph_potts() routine. + + Note also that a factor graph should not have any nodes which are + neighbors with themselves. Additionally, the graph is undirected. This + mean that if A is a neighbor of B then B must be a neighbor of A for + the MAP problem to be valid. + !*/ + + public: + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the factor graph. Or in other words, + returns the number of variables in the MAP problem/Potts model. + !*/ + + unsigned long number_of_neighbors ( + unsigned long idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns the number of neighbors of node idx. + !*/ + + // This is an optional variable which specifies a number that is always + // greater than or equal to number_of_neighbors(idx). If you don't know + // the value at compile time then either don't include max_number_of_neighbors + // in your potts_problem object or set it to 0. + const static unsigned long max_number_of_neighbors = 0; + + unsigned long get_neighbor ( + unsigned long idx, + unsigned long n + ) const; + /*! + requires + - idx < number_of_nodes() + - n < number_of_neighbors(idx) + ensures + - returns the node index value of the n-th neighbor of + the node with index value idx. + - The neighbor relationship is reciprocal. That is, if + get_neighbor(A,i)==B then there is a value of j such + that get_neighbor(B,j)==A. + - A node is never its own neighbor. That is, there is + no i such that get_neighbor(idx,i)==idx. + !*/ + + unsigned long get_neighbor_idx ( + unsigned long idx1, + unsigned long idx2 + ) const; + /*! + requires + - idx1 < number_of_nodes() + - idx2 < number_of_nodes() + ensures + - This function is basically the inverse of get_neighbor(). + - returns a number IDX such that: + - get_neighbor(idx1,IDX) == idx2 + - IDX < number_of_neighbors(idx1) + !*/ + + void set_label ( + const unsigned long& idx, + node_label value + ); + /*! + requires + - idx < number_of_nodes() + ensures + - #get_label(idx) == value + !*/ + + node_label get_label ( + const unsigned long& idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns the current label for the idx-th node. This is a value which is + 0 if the node's label is false and is any other value if it is true. + + Note that this value is not used by factor_value() or factor_value_disagreement(). + It is simply here to provide a mechanism for find_max_factor_graph_potts() + to return its labeled result. Additionally, the reason it returns a + node_label rather than a bool is because doing it this way facilitates + use of a graph cut algorithm for the solution of the MAP problem. For + more of an explanation you should read the paper referenced by the min_cut + object. + !*/ + + // This typedef should be for a type like int or double. It + // must also be capable of representing signed values. + typedef an_integer_or_real_type value_type; + + value_type factor_value ( + unsigned long idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns a value which indicates how "good" it is to assign the idx-th + node the label of true. The larger the value, the more desirable it is + to give it this label. Similarly, a negative value indicates that it is + better to give the node a label of false. + - It is valid for the returned value to be positive or negative infinity. + A value of positive infinity indicates that the idx-th node must be labeled + true while negative infinity means it must be labeled false. + !*/ + + value_type factor_value_disagreement ( + unsigned long idx1, + unsigned long idx2 + ) const; + /*! + requires + - idx1 < number_of_nodes() + - idx2 < number_of_nodes() + - idx1 != idx2 + - the idx1-th node and idx2-th node are neighbors in the graph. That is, + get_neighbor(idx1,i)==idx2 for some value of i. + ensures + - returns a number >= 0. This is the penalty for giving node idx1 and idx2 + different labels. Larger values indicate a larger penalty. + - this function is symmetric. That is, it is true that: + factor_value_disagreement(i,j) == factor_value_disagreement(j,i) + - It is valid for the returned value to be positive infinity. Returning + infinity indicates that the idx1-th and idx2-th nodes must share the same + label. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + class potts_grid_problem + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a specialization of a potts_problem to the case where + the graph is a regular grid where each node is connected to its four + neighbors. An example of this is an image where each pixel is a node + and is connected to its four immediate neighboring pixels. Therefore, + this object defines the interface this special kind of MAP problem + must implement if it is to be solved by the find_max_factor_graph_potts(potts_grid_problem,array2d) + routine defined at the end of this file. + + + Note that all nodes always have four neighbors, even nodes on the edge + of the graph. This is because these border nodes are connected to + the border nodes on the other side of the graph. That is, the graph + "wraps" around at the borders. + !*/ + + public: + + // This typedef should be for a type like int or double. It + // must also be capable of representing signed values. + typedef an_integer_or_real_type value_type; + + long nr( + ) const; + /*! + ensures + - returns the number of rows in the grid + !*/ + + long nc( + ) const; + /*! + ensures + - returns the number of columns in the grid + !*/ + + value_type factor_value ( + unsigned long idx + ) const; + /*! + requires + - idx < nr()*nc() + ensures + - The grid is represented in row-major-order format. Therefore, idx + identifies a node according to its position in the row-major-order + representation of the grid graph. Or in other words, idx corresponds + to the following row and column location in the graph: + - row == idx/nc() + - col == idx%nc() + - returns a value which indicates how "good" it is to assign the idx-th + node the label of true. The larger the value, the more desirable it is + to give it this label. Similarly, a negative value indicates that it is + better to give the node a label of false. + - It is valid for the returned value to be positive or negative infinity. + A value of positive infinity indicates that the idx-th node must be labeled + true while negative infinity means it must be labeled false. + !*/ + + value_type factor_value_disagreement ( + unsigned long idx1, + unsigned long idx2 + ) const; + /*! + requires + - idx1 < nr()*nc() + - idx2 < nr()*nc() + - idx1 != idx2 + - the idx1-th node and idx2-th node are neighbors in the grid graph. + ensures + - The grid is represented in row-major-order format. Therefore, idx1 and + idx2 identify nodes according to their positions in the row-major-order + representation of the grid graph. For example, idx1 corresponds + to the following row and column location in the graph: + - row == idx1/nc() + - col == idx1%nc() + - returns a number >= 0. This is the penalty for giving node idx1 and idx2 + different labels. Larger values indicate a larger penalty. + - this function is symmetric. That is, it is true that: + factor_value_disagreement(i,j) == factor_value_disagreement(j,i) + - It is valid for the returned value to be positive infinity. Returning + infinity indicates that the idx1-th and idx2-th nodes must share the same + label. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename potts_problem + > + typename potts_problem::value_type potts_model_score ( + const potts_problem& prob + ); + /*! + requires + - potts_problem == an object with an interface compatible with the potts_problem + object defined at the top of this file. + - for all valid i and j: + - prob.factor_value_disagreement(i,j) >= 0 + - prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i) + ensures + - computes the model score for the given potts_problem. We define this + precisely below: + - let L(i) == the boolean label of the i-th variable in prob. Or in other + words, L(i) == (prob.get_label(i) != 0). + - let F == the sum of values of prob.factor_value(i) for only i values + where L(i) == true. + - Let D == the sum of values of prob.factor_value_disagreement(i,j) + for only i and j values which meet the following conditions: + - i and j are neighbors in the graph defined by prob, that is, + it is valid to call prob.factor_value_disagreement(i,j). + - L(i) != L(j) + - i < j + (i.e. We want to make sure to only count the edge between i and j once) + + - Then this function returns F - D + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type + > + typename graph_type::edge_type potts_model_score ( + const graph_type& g, + const std::vector& labels + ); + /*! + requires + - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h + - graph_type::edge_type is some signed type such as int or double + - graph_type::type must be the same type as graph_type::edge_type + - graph_contains_length_one_cycle(g) == false + - for all valid i and j: + - edge(g,i,j) >= 0 + ensures + - This function does the same thing as the version of potts_model_score() + defined above, except that this version operates on a dlib::graph + instead of a potts_problem object. + - computes the model score for the given graph and labeling. We define this + precisely below: + - let L(i) == the boolean label of the i-th variable in g. Or in other + words, L(i) == (labels[i] != 0). + - let F == the sum of values of g.node(i).data for only i values + where L(i) == true. + - Let D == the sum of values of edge(g,i,j) for only i and j + values which meet the following conditions: + - i and j are neighbors in the graph defined by g, that is, + it is valid to call edge(g,i,j). + - L(i) != L(j) + - i < j + (i.e. We want to make sure to only count the edge between i and j once) + + - Then this function returns F - D + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename potts_grid_problem, + typename mem_manager + > + typename potts_grid_problem::value_type potts_model_score ( + const potts_grid_problem& prob, + const array2d& labels + ); + /*! + requires + - prob.nr() == labels.nr() + - prob.nc() == labels.nc() + - potts_grid_problem == an object with an interface compatible with the + potts_grid_problem object defined above. + - for all valid i and j: + - prob.factor_value_disagreement(i,j) >= 0 + - prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i) + ensures + - computes the model score for the given potts_grid_problem. We define this + precisely below: + - let L(i) == the boolean label of the i-th variable in prob. Or in other + words, L(i) == (labels[i/labels.nc()][i%labels.nc()] != 0). + - let F == the sum of values of prob.factor_value(i) for only i values + where L(i) == true. + - Let D == the sum of values of prob.factor_value_disagreement(i,j) + for only i and j values which meet the following conditions: + - i and j are neighbors in the graph defined by prob, that is, + it is valid to call prob.factor_value_disagreement(i,j). + - L(i) != L(j) + - i < j + (i.e. We want to make sure to only count the edge between i and j once) + + - Then this function returns F - D + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename potts_problem + > + void find_max_factor_graph_potts ( + potts_problem& prob + ); + /*! + requires + - potts_problem == an object with an interface compatible with the potts_problem + object defined at the top of this file. + - for all valid i and j: + - prob.factor_value_disagreement(i,j) >= 0 + - prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i) + ensures + - This function is a tool for exactly solving the MAP problem in a Potts + model. In particular, this means that this function finds the assignments + to all the labels in prob which maximizes potts_model_score(#prob). + - The optimal labels are stored in #prob. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type + > + void find_max_factor_graph_potts ( + const graph_type& g, + std::vector& labels + ); + /*! + requires + - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h + - graph_type::edge_type is some signed type such as int or double + - graph_type::type must be the same type as graph_type::edge_type + - graph_contains_length_one_cycle(g) == false + - for all valid i and j: + - edge(g,i,j) >= 0 + ensures + - This routine simply converts g into a potts_problem and calls the + version of find_max_factor_graph_potts() defined above on it. Therefore, + this routine is just a convenience wrapper that lets you use a dlib::graph + to represent a potts problem. This means that this function maximizes + the value of potts_model_score(g, #labels). + - #labels.size() == g.number_of_nodes() + - for all valid i: + - #labels[i] == the optimal label for g.node(i) + - The correspondence between g and a potts_problem is the following: + - the factor_value() for a node is stored in g.node(i).data. + - the factor_value_disagreement(i,j) is stored in edge(g,i,j). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename potts_grid_problem, + typename mem_manager + > + void find_max_factor_graph_potts ( + const potts_grid_problem& prob, + array2d& labels + ); + /*! + requires + - potts_grid_problem == an object with an interface compatible with the + potts_grid_problem object defined above. + - for all valid i and j: + - prob.factor_value_disagreement(i,j) >= 0 + - prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i) + ensures + - This routine solves a version of a potts problem where the graph is a + regular grid where each node is connected to its four immediate neighbors. + In particular, this means that this function finds the assignments + to all the labels in prob which maximizes potts_model_score(prob,#labels). + - The optimal labels are stored in #labels. + - #labels.nr() == prob.nr() + - #labels.nc() == prob.nc() + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// The following functions and interface definitions are convenience routines for use +// with the potts grid problem version of find_max_factor_graph_potts() defined above. +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct single_image_model + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines a slightly more convenient interface for creating + potts_grid_problems which operate on an image. In this case, the goal + is to assign a binary label to each pixel in an image. In particular, + this object defines the interface used by the make_potts_grid_problem() + routine defined below. + + In the following comments, we will refer to the image supplied to + make_potts_grid_problem() as IMG. + !*/ + + // This typedef should be for a type like int or double. It + // must also be capable of representing signed values. + typedef an_integer_or_real_type value_type; + + template + value_type factor_value ( + const pixel_type& v + ) const; + /*! + requires + - v is a pixel value from IMG. + ensures + - returns a value which indicates how "good" it is to assign the location + in IMG corresponding to v with the label of true. The larger the value, + the more desirable it is to give it this label. Similarly, a negative + value indicates that it is better to give the node a label of false. + - It is valid for the returned value to be positive or negative infinity. + A value of positive infinity indicates that the pixel must be labeled + true while negative infinity means it must be labeled false. + !*/ + + template + value_type factor_value_disagreement ( + const pixel_type& v1, + const pixel_type& v2 + ) const; + /*! + requires + - v1 and v2 are pixel values from neighboring pixels in the IMG image. + ensures + - returns a number >= 0. This is the penalty for giving neighboring pixels + with values v1 and v2 different labels. Larger values indicate a larger + penalty. + - this function is symmetric. That is, it is true that: + factor_value_disagreement(i,j) == factor_value_disagreement(j,i) + - It is valid for the returned value to be positive infinity. Returning + infinity indicates that the idx1-th and idx2-th nodes must share the same + label. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + struct pair_image_model + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines a slightly more convenient interface for creating + potts_grid_problems which operate on a pair of identically sized images. + In this case, the goal is to assign a label to each pixel in the first + image of the pair. In particular, this object defines the interface + used by the make_potts_grid_problem() routine defined below. + + In the following comments, we will refer to the two images supplied to + make_potts_grid_problem() as IMG1 and IMG2. The goal of the potts + problem will be to assign labels to each pixel in IMG1 (IMG2 is + not labeled, it is simply used as a place to keep auxiliary data). + !*/ + + // This typedef should be for a type like int or double. It + // must also be capable of representing signed values. + typedef an_integer_or_real_type value_type; + + template + value_type factor_value ( + const pixel_type1& v1, + const pixel_type2& v2 + ) const; + /*! + requires + - v1 and v2 are corresponding pixels from IMG1 and IMG2 respectively. + That is, both pixel values have the same coordinates in the images. + So for example, if v1 is the value of IMG1[4][5] then v2 is the value + of IMG2[4][5]. + ensures + - returns a value which indicates how "good" it is to assign the location + in IMG1 corresponding to v1 with the label of true. The larger the value, + the more desirable it is to give it this label. Similarly, a negative + value indicates that it is better to give the node a label of false. + - It is valid for the returned value to be positive or negative infinity. + A value of positive infinity indicates that the pixel must be labeled + true while negative infinity means it must be labeled false. + !*/ + + template + value_type factor_value_disagreement ( + const pixel_type& v1, + const pixel_type& v2 + ) const; + /*! + requires + - v1 and v2 are pixel values from neighboring pixels in the IMG1 image. + ensures + - returns a number >= 0. This is the penalty for giving neighboring pixels + with values v1 and v2 different labels. Larger values indicate a larger + penalty. + - this function is symmetric. That is, it is true that: + factor_value_disagreement(i,j) == factor_value_disagreement(j,i) + - It is valid for the returned value to be positive infinity. Returning + infinity indicates that the idx1-th and idx2-th nodes must share the same + label. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename single_image_model, + typename pixel_type, + typename mem_manager + > + potts_grid_problem make_potts_grid_problem ( + const single_image_model& model, + const array2d& img + ); + /*! + requires + - single_image_model == an object with an interface compatible with the + single_image_model object defined above. + - for all valid i and j: + - model.factor_value_disagreement(i,j) >= 0 + - model.factor_value_disagreement(i,j) == model.factor_value_disagreement(j,i) + ensures + - returns a potts_grid_problem which can be solved using the + find_max_factor_graph_potts(prob,array2d) routine defined above. That is, + given an image to store the labels, the following statement would solve the + potts problem defined by the given model and image: + find_max_factor_graph_potts(make_potts_grid_problem(model,img),labels); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pair_image_model, + typename pixel_type1, + typename pixel_type2, + typename mem_manager + > + potts_grid_problem make_potts_grid_problem ( + const pair_image_model& model, + const array2d& img1, + const array2d& img2 + ); + /*! + requires + - get_rect(img1) == get_rect(img2) + (i.e. img1 and img2 have the same dimensions) + - pair_image_model == an object with an interface compatible with the + pair_image_model object defined above. + - for all valid i and j: + - model.factor_value_disagreement(i,j) >= 0 + - model.factor_value_disagreement(i,j) == model.factor_value_disagreement(j,i) + ensures + - returns a potts_grid_problem which can be solved using the + find_max_factor_graph_potts(prob,array2d) routine defined above. That is, + given an image to store the labels, the following statement would solve the + potts problem defined by the given model and pair of images: + find_max_factor_graph_potts(make_potts_grid_problem(model,img1,img2),labels); + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/graph_cuts/general_flow_graph.h b/ml/dlib/dlib/graph_cuts/general_flow_graph.h new file mode 100644 index 000000000..d0b93e311 --- /dev/null +++ b/ml/dlib/dlib/graph_cuts/general_flow_graph.h @@ -0,0 +1,172 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GENERAL_FLOW_GRaPH_Hh_ +#define DLIB_GENERAL_FLOW_GRaPH_Hh_ + +#include "../graph_utils.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + template < + typename directed_graph_type + > + class general_flow_graph + { + /*! + this is a utility class used by dlib::min_cut to convert a directed_graph + into the kind of flow graph expected by the min_cut object's main block + of code. + !*/ + + directed_graph_type& g; + + typedef typename directed_graph_type::node_type node_type; + typedef typename directed_graph_type::type node_label; + + public: + + general_flow_graph( + directed_graph_type& g_ + ) : g(g_) + { + } + + class out_edge_iterator + { + friend class general_flow_graph; + unsigned long idx; // base node idx + unsigned long cnt; // count over the neighbors of idx + public: + + out_edge_iterator( + ):idx(0),cnt(0) {} + + out_edge_iterator( + unsigned long idx_, + unsigned long cnt_ + ):idx(idx_),cnt(cnt_) + {} + + bool operator!= ( + const out_edge_iterator& item + ) const { return cnt != item.cnt; } + + out_edge_iterator& operator++( + ) + { + ++cnt; + return *this; + } + }; + + class in_edge_iterator + { + + friend class general_flow_graph; + unsigned long idx; // base node idx + unsigned long cnt; // count over the neighbors of idx + public: + + in_edge_iterator( + ):idx(0),cnt(0) {} + + in_edge_iterator( + unsigned long idx_, + unsigned long cnt_ + ):idx(idx_),cnt(cnt_) + {} + + bool operator!= ( + const in_edge_iterator& item + ) const { return cnt != item.cnt; } + + in_edge_iterator& operator++( + ) + { + ++cnt; + return *this; + } + }; + + unsigned long number_of_nodes ( + ) const { return g.number_of_nodes(); } + + out_edge_iterator out_begin( + const unsigned long& it + ) const { return out_edge_iterator(it, 0); } + + in_edge_iterator in_begin( + const unsigned long& it + ) const { return in_edge_iterator(it, 0); } + + out_edge_iterator out_end( + const unsigned long& it + ) const { return out_edge_iterator(it, g.node(it).number_of_children()); } + + in_edge_iterator in_end( + const unsigned long& it + ) const { return in_edge_iterator(it, g.node(it).number_of_parents()); } + + unsigned long node_id ( + const out_edge_iterator& it + ) const { return g.node(it.idx).child(it.cnt).index(); } + unsigned long node_id ( + const in_edge_iterator& it + ) const { return g.node(it.idx).parent(it.cnt).index(); } + + typedef typename directed_graph_type::edge_type edge_type; + + edge_type get_flow (const unsigned long& it1, const unsigned long& it2) const + { + return edge(g, it1, it2); + } + edge_type get_flow (const out_edge_iterator& it) const + { + return g.node(it.idx).child_edge(it.cnt); + } + edge_type get_flow (const in_edge_iterator& it) const + { + return g.node(it.idx).parent_edge(it.cnt); + } + + void adjust_flow ( + const unsigned long& it1, + const unsigned long& it2, + const edge_type& value + ) + { + edge(g, it1, it2) += value; + edge(g, it2, it1) -= value; + } + + void set_label ( + const unsigned long& it, + node_label value + ) + { + g.node(it).data = value; + } + + node_label get_label ( + const unsigned long& it + ) const + { + return g.node(it).data; + } + + }; + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GENERAL_FLOW_GRaPH_Hh_ + diff --git a/ml/dlib/dlib/graph_cuts/general_potts_problem.h b/ml/dlib/dlib/graph_cuts/general_potts_problem.h new file mode 100644 index 000000000..ebedbc572 --- /dev/null +++ b/ml/dlib/dlib/graph_cuts/general_potts_problem.h @@ -0,0 +1,99 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GENERAL_POTTS_PRoBLEM_Hh_ +#define DLIB_GENERAL_POTTS_PRoBLEM_Hh_ + +#include "../graph_utils.h" +#include "min_cut.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename graph_type + > + class general_potts_problem + { + + const graph_type& g; + std::vector& labels; + public: + general_potts_problem ( + const graph_type& g_, + std::vector& labels_ + ) : g(g_), labels(labels_) + { + labels.resize(g.number_of_nodes()); + } + + unsigned long number_of_nodes ( + ) const { return g.number_of_nodes(); } + + unsigned long number_of_neighbors ( + unsigned long idx + ) const { return g.node(idx).number_of_neighbors(); } + + unsigned long get_neighbor ( + unsigned long idx, + unsigned long n + ) const { return g.node(idx).neighbor(n).index(); } + + unsigned long get_neighbor_idx ( + unsigned long idx1, + unsigned long idx2 + ) const + { + for (unsigned long i = 0; i < g.node(idx1).number_of_neighbors(); ++i) + { + if (g.node(idx1).neighbor(i).index() == idx2) + return i; + } + + // This should never ever execute + return 0; + } + + void set_label ( + const unsigned long& idx, + node_label value + ) + { + labels[idx] = value; + } + + node_label get_label ( + const unsigned long& idx + ) const { return labels[idx]; } + + typedef typename graph_type::edge_type value_type; + + value_type factor_value ( + unsigned long idx + ) const + { + return g.node(idx).data; + } + + value_type factor_value_disagreement ( + unsigned long idx1, + unsigned long idx2 + ) const + { + return edge(g, idx1, idx2); + } + + }; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GENERAL_POTTS_PRoBLEM_Hh_ + + diff --git a/ml/dlib/dlib/graph_cuts/graph_labeler.h b/ml/dlib/dlib/graph_cuts/graph_labeler.h new file mode 100644 index 000000000..34c3fcb5e --- /dev/null +++ b/ml/dlib/dlib/graph_cuts/graph_labeler.h @@ -0,0 +1,211 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GRAPH_LaBELER_Hh_ +#define DLIB_GRAPH_LaBELER_Hh_ + +#include "graph_labeler_abstract.h" +#include "../matrix.h" +#include "../string.h" +#include +#include "find_max_factor_graph_potts.h" +#include "../svm/sparse_vector.h" +#include "../graph.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + class graph_labeler + { + + public: + + typedef std::vector label_type; + typedef label_type result_type; + + graph_labeler() + { + } + + graph_labeler( + const vector_type& edge_weights_, + const vector_type& node_weights_ + ) : + edge_weights(edge_weights_), + node_weights(node_weights_) + { + // make sure requires clause is not broken + DLIB_ASSERT(edge_weights.size() == 0 || min(edge_weights) >= 0, + "\t graph_labeler::graph_labeler()" + << "\n\t Invalid inputs were given to this function." + << "\n\t min(edge_weights): " << min(edge_weights) + << "\n\t this: " << this + ); + } + + const vector_type& get_edge_weights ( + ) const { return edge_weights; } + + const vector_type& get_node_weights ( + ) const { return node_weights; } + + template + void operator() ( + const graph_type& sample, + std::vector& labels + ) const + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + DLIB_ASSERT(graph_contains_length_one_cycle(sample) == false, + "\t void graph_labeler::operator()" + << "\n\t Invalid inputs were given to this function." + << "\n\t get_edge_weights().size(): " << get_edge_weights().size() + << "\n\t get_node_weights().size(): " << get_node_weights().size() + << "\n\t graph_contains_length_one_cycle(sample): " << graph_contains_length_one_cycle(sample) + << "\n\t this: " << this + ); + for (unsigned long i = 0; i < sample.number_of_nodes(); ++i) + { + if (is_matrix::value && + is_matrix::value) + { + // check that dot() is legal. + DLIB_ASSERT((unsigned long)get_node_weights().size() == (unsigned long)sample.node(i).data.size(), + "\t void graph_labeler::operator()" + << "\n\t The size of the node weight vector must match the one in the node." + << "\n\t get_node_weights().size(): " << get_node_weights().size() + << "\n\t sample.node(i).data.size(): " << sample.node(i).data.size() + << "\n\t i: " << i + << "\n\t this: " << this + ); + } + + for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n) + { + if (is_matrix::value && + is_matrix::value) + { + // check that dot() is legal. + DLIB_ASSERT((unsigned long)get_edge_weights().size() == (unsigned long)sample.node(i).edge(n).size(), + "\t void graph_labeler::operator()" + << "\n\t The size of the edge weight vector must match the one in graph's edge." + << "\n\t get_edge_weights().size(): " << get_edge_weights().size() + << "\n\t sample.node(i).edge(n).size(): " << sample.node(i).edge(n).size() + << "\n\t i: " << i + << "\n\t this: " << this + ); + } + + DLIB_ASSERT(sample.node(i).edge(n).size() == 0 || min(sample.node(i).edge(n)) >= 0, + "\t void graph_labeler::operator()" + << "\n\t No edge vectors are allowed to have negative elements." + << "\n\t min(sample.node(i).edge(n)): " << min(sample.node(i).edge(n)) + << "\n\t i: " << i + << "\n\t n: " << n + << "\n\t this: " << this + ); + } + } +#endif + + + graph::kernel_1a g; + copy_graph_structure(sample, g); + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + g.node(i).data = dot(node_weights, sample.node(i).data); + + for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n) + { + const unsigned long j = g.node(i).neighbor(n).index(); + // Don't compute an edge weight more than once. + if (i < j) + { + g.node(i).edge(n) = dot(edge_weights, sample.node(i).edge(n)); + } + } + + } + + labels.clear(); + std::vector temp; + find_max_factor_graph_potts(g, temp); + for (unsigned long i = 0; i < temp.size(); ++i) + { + if (temp[i] != 0) + labels.push_back(true); + else + labels.push_back(false); + } + } + + template + std::vector operator() ( + const graph_type& sample + ) const + { + std::vector temp; + (*this)(sample, temp); + return temp; + } + + private: + + vector_type edge_weights; + vector_type node_weights; + }; + + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void serialize ( + const graph_labeler& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + serialize(item.get_edge_weights(), out); + serialize(item.get_node_weights(), out); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void deserialize ( + graph_labeler& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + { + throw dlib::serialization_error("While deserializing graph_labeler, found unexpected version number of " + + cast_to_string(version) + "."); + } + + vector_type edge_weights, node_weights; + deserialize(edge_weights, in); + deserialize(node_weights, in); + + item = graph_labeler(edge_weights, node_weights); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GRAPH_LaBELER_Hh_ + + diff --git a/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h b/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h new file mode 100644 index 000000000..a0821b696 --- /dev/null +++ b/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h @@ -0,0 +1,185 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GRAPH_LaBELER_ABSTRACT_Hh_ +#ifdef DLIB_GRAPH_LaBELER_ABSTRACT_Hh_ + +#include "find_max_factor_graph_potts_abstract.h" +#include "../graph/graph_kernel_abstract.h" +#include "../matrix/matrix_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + class graph_labeler + { + /*! + REQUIREMENTS ON vector_type + - vector_type is a dlib::matrix capable of representing column + vectors or it is a sparse vector type as defined in dlib/svm/sparse_vector_abstract.h. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for labeling each node in a graph with a value + of true or false, subject to a labeling consistency constraint between + nodes that share an edge. In particular, this object is useful for + representing a graph labeling model learned via some machine learning + method. + + To elaborate, suppose we have a graph we want to label. Moreover, + suppose we can assign a score to each node which represents how much + we want to label the node as true, and we also have scores for each + edge which represent how much we wanted the nodes sharing the edge to + have the same label. If we could do this then we could find the optimal + labeling using the find_max_factor_graph_potts() routine. Therefore, + the graph_labeler is just an object which contains the necessary data + to compute these score functions and then call find_max_factor_graph_potts(). + Additionally, this object uses linear functions to represent these score + functions. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call the const members of this object from multiple + threads. This is because the const members are purely read-only + operations. However, any operation that modifies a graph_labeler is + not threadsafe. + !*/ + + public: + + typedef std::vector label_type; + typedef label_type result_type; + + graph_labeler( + ); + /*! + ensures + - this object is properly initialized + - #get_node_weights() == an initial value of type vector_type. + - #get_edge_weights() == an initial value of type vector_type. + !*/ + + graph_labeler( + const vector_type& edge_weights, + const vector_type& node_weights + ); + /*! + requires + - min(edge_weights) >= 0 + ensures + - #get_edge_weights() == edge_weights + - #get_node_weights() == node_weights + !*/ + + const vector_type& get_edge_weights ( + ) const; + /*! + ensures + - Recall that the score function for an edge is a linear function of + the vector stored at that edge. This means there is some vector, E, + which we dot product with the vector in the graph to compute the + score. Therefore, this function returns that E vector which defines + the edge score function. + !*/ + + const vector_type& get_node_weights ( + ) const; + /*! + ensures + - Recall that the score function for a node is a linear function of + the vector stored in that node. This means there is some vector, W, + which we dot product with the vector in the graph to compute the score. + Therefore, this function returns that W vector which defines the node + score function. + !*/ + + template + void operator() ( + const graph_type& sample, + std::vector& labels + ) const; + /*! + requires + - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h + - graph_type::type and graph_type::edge_type must be either matrix objects + capable of representing column vectors or some kind of sparse vector + type as defined in dlib/svm/sparse_vector_abstract.h. + - graph_contains_length_one_cycle(sample) == false + - for all valid i and j: + - min(edge(sample,i,j)) >= 0 + - it must be legal to call dot(edge(sample,i,j), get_edge_weights()) + - it must be legal to call dot(sample.node(i).data, get_node_weights()) + ensures + - Computes a labeling for each node in the given graph and stores the result + in #labels. + - #labels.size() == sample.number_of_nodes() + - for all valid i: + - #labels[i] == the label of the node sample.node(i). + - The labels are computed by creating a graph, G, with scalar values on each node + and edge. The scalar values are calculated according to the following: + - for all valid i: + - G.node(i).data == dot(get_node_weights(), sample.node(i).data) + - for all valid i and j: + - edge(G,i,j) == dot(get_edge_weights(), edge(sample,i,j)) + Then the labels are computed by calling find_max_factor_graph_potts(G,#labels). + !*/ + + template + std::vector operator() ( + const graph_type& sample + ) const; + /*! + requires + - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h + - graph_contains_length_one_cycle(sample) == false + - for all valid i and j: + - min(edge(sample,i,j)) >= 0 + - it must be legal to call dot(edge(sample,i,j), get_edge_weights()) + - it must be legal to call dot(sample.node(i).data, get_node_weights()) + ensures + - Performs (*this)(sample, labels); return labels; + (i.e. This is just another version of the above operator() routine + but instead of returning the labels via the second argument, it + returns them as the normal return value). + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void serialize ( + const graph_labeler& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void deserialize ( + graph_labeler& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GRAPH_LaBELER_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/graph_cuts/min_cut.h b/ml/dlib/dlib/graph_cuts/min_cut.h new file mode 100644 index 000000000..6bbb57608 --- /dev/null +++ b/ml/dlib/dlib/graph_cuts/min_cut.h @@ -0,0 +1,571 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MIN_CuT_Hh_ +#define DLIB_MIN_CuT_Hh_ + +#include "min_cut_abstract.h" +#include "../matrix.h" +#include "general_flow_graph.h" +#include "../is_kind.h" + +#include +#include +#include + + +// ---------------------------------------------------------------------------------------- + + +namespace dlib +{ + + typedef unsigned char node_label; + +// ---------------------------------------------------------------------------------------- + + const node_label SOURCE_CUT = 0; + const node_label SINK_CUT = 254; + const node_label FREE_NODE = 255; + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,typename flow_graph::edge_type>::type + graph_cut_score ( + const flow_graph& g + ) + { + typedef typename flow_graph::edge_type edge_weight_type; + edge_weight_type score = 0; + typedef typename flow_graph::out_edge_iterator out_edge_iterator; + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + if (g.get_label(i) != SOURCE_CUT) + continue; + + for (out_edge_iterator n = g.out_begin(i); n != g.out_end(i); ++n) + { + if (g.get_label(g.node_id(n)) != SOURCE_CUT) + { + score += g.get_flow(n); + } + } + } + + return score; + } + + template + typename enable_if,typename directed_graph::edge_type>::type + graph_cut_score ( + const directed_graph& g + ) + { + return graph_cut_score(dlib::impl::general_flow_graph(g)); + } + +// ---------------------------------------------------------------------------------------- + + class min_cut + { + + public: + + min_cut() + { + } + + min_cut( const min_cut& ) + { + // Intentionally left empty since all the member variables + // don't logically contribute to the state of this object. + // This copy constructor is here to explicitly avoid the overhead + // of copying these transient variables. + } + + template < + typename directed_graph + > + typename enable_if >::type operator() ( + directed_graph& g, + const unsigned long source_node, + const unsigned long sink_node + ) const + { + DLIB_ASSERT(graph_contains_length_one_cycle(g) == false, + "\t void min_cut::operator()" + << "\n\t Invalid arguments were given to this function." + ); + DLIB_ASSERT(graph_has_symmetric_edges(g) == true, + "\t void min_cut::operator()" + << "\n\t Invalid arguments were given to this function." + ); + + dlib::impl::general_flow_graph temp(g); + (*this)(temp, source_node, sink_node); + } + + template < + typename flow_graph + > + typename disable_if >::type operator() ( + flow_graph& g, + const unsigned long source_node, + const unsigned long sink_node + ) const + { +#ifdef ENABLE_ASSERTS + DLIB_ASSERT(source_node != sink_node && + source_node < g.number_of_nodes() && + sink_node < g.number_of_nodes(), + "\t void min_cut::operator()" + << "\n\t Invalid arguments were given to this function." + << "\n\t g.number_of_nodes(): " << g.number_of_nodes() + << "\n\t source_node: " << source_node + << "\n\t sink_node: " << sink_node + << "\n\t this: " << this + ); + + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + typename flow_graph::out_edge_iterator j, end = g.out_end(i); + for (j = g.out_begin(i); j != end; ++j) + { + const unsigned long jj = g.node_id(j); + DLIB_ASSERT(g.get_flow(i,jj) >= 0, + "\t void min_cut::operator()" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: "<< i + << "\n\t jj: "<< jj + << "\n\t g.get_flow(i,jj): "<< g.get_flow(i,jj) + << "\n\t this: "<< this + ); + + } + } +#endif + parent.clear(); + active.clear(); + orphans.clear(); + + typedef typename flow_graph::edge_type edge_type; + COMPILE_TIME_ASSERT(is_signed_type::value); + + typedef typename flow_graph::out_edge_iterator out_edge_iterator; + typedef typename flow_graph::in_edge_iterator in_edge_iterator; + + // initialize labels + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + g.set_label(i, FREE_NODE); + + g.set_label(source_node, SOURCE_CUT); + g.set_label(sink_node, SINK_CUT); + + // used to indicate "no parent" + const unsigned long no_parent = g.number_of_nodes(); + + parent.assign(g.number_of_nodes(), no_parent); + + time = 1; + dist.assign(g.number_of_nodes(), 0); + ts.assign(g.number_of_nodes(), time); + + active.push_back(source_node); + active.push_back(sink_node); + + + in_edge_iterator in_begin = g.in_begin(active[0]); + out_edge_iterator out_begin = g.out_begin(active[0]); + + unsigned long source_side, sink_side; + while (grow(g,source_side,sink_side, in_begin, out_begin)) + { + ++time; + ts[source_node] = time; + ts[sink_node] = time; + + augment(g, source_node, sink_node, source_side, sink_side); + adopt(g, source_node, sink_node); + } + + } + + + private: + + unsigned long distance_to_origin ( + const unsigned long no_parent, + unsigned long p, + unsigned long + ) const + { + unsigned long start = p; + unsigned long count = 0; + while (p != no_parent) + { + if (ts[p] == time) + { + count += dist[p]; + + unsigned long count_down = count; + // adjust the dist and ts for the nodes on this path. + while (start != p) + { + ts[start] = time; + dist[start] = count_down; + --count_down; + start = parent[start]; + } + + return count; + } + p = parent[p]; + ++count; + } + + return std::numeric_limits::max(); + } + + template + void adopt ( + flow_graph& g, + const unsigned long source, + const unsigned long sink + ) const + { + typedef typename flow_graph::out_edge_iterator out_edge_iterator; + typedef typename flow_graph::in_edge_iterator in_edge_iterator; + + // used to indicate "no parent" + const unsigned long no_parent = g.number_of_nodes(); + + while (orphans.size() > 0) + { + const unsigned long p = orphans.back(); + orphans.pop_back(); + + const unsigned char label_p = g.get_label(p); + + // Try to find a valid parent for p. + if (label_p == SOURCE_CUT) + { + const in_edge_iterator begin(g.in_begin(p)); + const in_edge_iterator end(g.in_end(p)); + unsigned long best_dist = std::numeric_limits::max(); + unsigned long best_node = 0; + for(in_edge_iterator q = begin; q != end; ++q) + { + const unsigned long id = g.node_id(q); + + if (g.get_label(id) != label_p || g.get_flow(q) <= 0 ) + continue; + + unsigned long temp = distance_to_origin(no_parent, id,source); + if (temp < best_dist) + { + best_dist = temp; + best_node = id; + } + + } + if (best_dist != std::numeric_limits::max()) + { + parent[p] = best_node; + dist[p] = dist[best_node] + 1; + ts[p] = time; + } + + // if we didn't find a parent for p + if (parent[p] == no_parent) + { + for(in_edge_iterator q = begin; q != end; ++q) + { + const unsigned long id = g.node_id(q); + + if (g.get_label(id) != SOURCE_CUT) + continue; + + if (g.get_flow(q) > 0) + active.push_back(id); + + if (parent[id] == p) + { + parent[id] = no_parent; + orphans.push_back(id); + } + } + g.set_label(p, FREE_NODE); + } + } + else + { + unsigned long best_node = 0; + unsigned long best_dist = std::numeric_limits::max(); + const out_edge_iterator begin(g.out_begin(p)); + const out_edge_iterator end(g.out_end(p)); + for(out_edge_iterator q = begin; q != end; ++q) + { + const unsigned long id = g.node_id(q); + if (g.get_label(id) != label_p || g.get_flow(q) <= 0) + continue; + + unsigned long temp = distance_to_origin(no_parent, id,sink); + + if (temp < best_dist) + { + best_dist = temp; + best_node = id; + } + } + + if (best_dist != std::numeric_limits::max()) + { + parent[p] = best_node; + dist[p] = dist[best_node] + 1; + ts[p] = time; + } + + // if we didn't find a parent for p + if (parent[p] == no_parent) + { + for(out_edge_iterator q = begin; q != end; ++q) + { + const unsigned long id = g.node_id(q); + + if (g.get_label(id) != SINK_CUT) + continue; + + if (g.get_flow(q) > 0) + active.push_back(id); + + if (parent[id] == p) + { + parent[id] = no_parent; + orphans.push_back(id); + } + } + + g.set_label(p, FREE_NODE); + } + } + + + } + + } + + template + void augment ( + flow_graph& g, + const unsigned long& source, + const unsigned long& sink, + const unsigned long& source_side, + const unsigned long& sink_side + ) const + { + typedef typename flow_graph::edge_type edge_type; + + // used to indicate "no parent" + const unsigned long no_parent = g.number_of_nodes(); + + unsigned long s = source_side; + unsigned long t = sink_side; + edge_type min_cap = g.get_flow(s,t); + + // find the bottleneck capacity on the current path. + + // check from source_side back to the source for the min capacity link. + t = s; + while (t != source) + { + s = parent[t]; + const edge_type temp = g.get_flow(s, t); + if (temp < min_cap) + { + min_cap = temp; + } + t = s; + } + + // check from sink_side back to the sink for the min capacity link + s = sink_side; + while (s != sink) + { + t = parent[s]; + const edge_type temp = g.get_flow(s, t); + if (temp < min_cap) + { + min_cap = temp; + } + s = t; + } + + + // now push the max possible amount of flow though the path + s = source_side; + t = sink_side; + g.adjust_flow(t,s, min_cap); + + // trace back towards the source + t = s; + while (t != source) + { + s = parent[t]; + g.adjust_flow(t,s, min_cap); + if (g.get_flow(s,t) <= 0) + { + parent[t] = no_parent; + orphans.push_back(t); + } + + t = s; + } + + // trace back towards the sink + s = sink_side; + while (s != sink) + { + t = parent[s]; + g.adjust_flow(t,s, min_cap); + if (g.get_flow(s,t) <= 0) + { + parent[s] = no_parent; + orphans.push_back(s); + } + s = t; + } + } + + template + bool grow ( + flow_graph& g, + unsigned long& source_side, + unsigned long& sink_side, + typename flow_graph::in_edge_iterator& in_begin, + typename flow_graph::out_edge_iterator& out_begin + ) const + /*! + ensures + - if (an augmenting path was found) then + - returns true + - (#source_side, #sink_side) == the point where the two trees meet. + #source_side is part of the source tree and #sink_side is part of + the sink tree. + - else + - returns false + !*/ + { + typedef typename flow_graph::out_edge_iterator out_edge_iterator; + typedef typename flow_graph::in_edge_iterator in_edge_iterator; + + + while (active.size() != 0) + { + // pick an active node + const unsigned long A = active[0]; + + const unsigned char label_A = g.get_label(A); + + // process its neighbors + if (label_A == SOURCE_CUT) + { + const out_edge_iterator out_end = g.out_end(A); + for(out_edge_iterator& i = out_begin; i != out_end; ++i) + { + if (g.get_flow(i) > 0) + { + const unsigned long id = g.node_id(i); + const unsigned char label_i = g.get_label(id); + if (label_i == FREE_NODE) + { + active.push_back(id); + g.set_label(id,SOURCE_CUT); + parent[id] = A; + ts[id] = ts[A]; + dist[id] = dist[A] + 1; + } + else if (label_A != label_i) + { + source_side = A; + sink_side = id; + return true; + } + else if (is_closer(A, id)) + { + parent[id] = A; + ts[id] = ts[A]; + dist[id] = dist[A] + 1; + } + } + } + } + else if (label_A == SINK_CUT) + { + const in_edge_iterator in_end = g.in_end(A); + for(in_edge_iterator& i = in_begin; i != in_end; ++i) + { + if (g.get_flow(i) > 0) + { + const unsigned long id = g.node_id(i); + const unsigned char label_i = g.get_label(id); + if (label_i == FREE_NODE) + { + active.push_back(id); + g.set_label(id,SINK_CUT); + parent[id] = A; + ts[id] = ts[A]; + dist[id] = dist[A] + 1; + } + else if (label_A != label_i) + { + sink_side = A; + source_side = id; + return true; + } + else if (is_closer(A, id)) + { + parent[id] = A; + ts[id] = ts[A]; + dist[id] = dist[A] + 1; + } + } + } + } + + active.pop_front(); + if (active.size() != 0) + { + in_begin = g.in_begin(active[0]); + out_begin = g.out_begin(active[0]); + } + } + + return false; + } + + inline bool is_closer ( + unsigned long p, + unsigned long q + ) const + { + // return true if p is closer to a terminal than q + return ts[q] <= ts[p] && dist[q] > dist[p]; + } + + mutable std::vector dist; + mutable std::vector ts; + mutable uint32 time; + mutable std::vector parent; + + mutable std::deque active; + mutable std::vector orphans; + }; + +// ---------------------------------------------------------------------------------------- + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_MIN_CuT_Hh_ + diff --git a/ml/dlib/dlib/graph_cuts/min_cut_abstract.h b/ml/dlib/dlib/graph_cuts/min_cut_abstract.h new file mode 100644 index 000000000..748aca950 --- /dev/null +++ b/ml/dlib/dlib/graph_cuts/min_cut_abstract.h @@ -0,0 +1,476 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MIN_CuT_ABSTRACT_Hh_ +#ifdef DLIB_MIN_CuT_ABSTRACT_Hh_ + +#include "../graph_utils.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + /*!A node_label + The node_label type is the type used to label which part of a graph cut + a node is on. It is used by all the graph cut tools. The three possible + values of a node label are SOURCE_CUT, SINK_CUT, or FREE_NODE. + !*/ + + typedef unsigned char node_label; + const node_label SOURCE_CUT = 0; + const node_label SINK_CUT = 254; + const node_label FREE_NODE = 255; + +// ---------------------------------------------------------------------------------------- + + class flow_graph + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a flow capacity graph for use with the + min_cut algorithm defined below. In particular, this object + is a kind of directed graph where the edge weights specify the + flow capacities. + + Note that there is no dlib::flow_graph object. What you are + looking at here is simply the interface definition for a graph + which can be used with the min_cut algorithm. You must implement + your own version of this object for the graph you wish to work with + and then pass it to the min_cut::operator() routine. + + It's also worth pointing out that this graph has symmetric edge + connections. That is, if there is an edge from node A to node B + then there must also be an edge from node B to node A. + !*/ + + public: + + class out_edge_iterator + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple forward iterator for iterating over the neighbors + of a node in the graph. It also represents the fact that the neighbors + are on the end of an outgoing edge. That is, the edge represents + the amount of flow which can flow towards the neighbor. + !*/ + + public: + out_edge_iterator( + ); + /*! + ensures + - constructs an iterator in an undefined state. It can't + be used until assigned with a valid iterator. + !*/ + + out_edge_iterator( + const out_edge_iterator& item + ); + /*! + ensures + - #*this is a copy of item + !*/ + + out_edge_iterator& operator=( + const out_edge_iterator& item + ); + /*! + ensures + - #*this is a copy of item + - returns #*this + !*/ + + bool operator!= ( + const out_edge_iterator& item + ) const; + /*! + requires + - *this and item are iterators over the neighbors for the + same node. + ensures + - returns false if *this and item both reference the same + node in the graph and true otherwise. + !*/ + + out_edge_iterator& operator++( + ); + /*! + ensures + - advances *this to the next neighbor node. + - returns a reference to the updated *this + (i.e. this is the ++object form of the increment operator) + !*/ + }; + + class in_edge_iterator + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple forward iterator for iterating over the neighbors + of a node in the graph. It also represents the fact that the neighbors + are on the end of an incoming edge. That is, the edge represents + the amount of flow which can flow out of the neighbor node. + !*/ + + public: + + in_edge_iterator( + ); + /*! + ensures + - constructs an iterator in an undefined state. It can't + be used until assigned with a valid iterator. + !*/ + + in_edge_iterator( + const in_edge_iterator& item + ); + /*! + ensures + - #*this is a copy of item + !*/ + + in_edge_iterator& operator=( + const in_edge_iterator& item + ); + /*! + ensures + - #*this is a copy of item + - returns #*this + !*/ + + bool operator!= ( + const in_edge_iterator& item + ) const; + /*! + requires + - *this and item are iterators over the neighbors for the + same node. + ensures + - returns false if *this and item both reference the same + node in the graph and true otherwise. + !*/ + + in_edge_iterator& operator++( + ); + /*! + ensures + - advances *this to the next neighbor node. + - returns a reference to the updated *this + (i.e. this is the ++object form of the increment operator) + !*/ + }; + + + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the graph. + !*/ + + out_edge_iterator out_begin( + const unsigned long& idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns an iterator pointing to the first neighboring node of + the idx-th node. If no such node exists then returns out_end(idx). + - The returned iterator also represents the directed edge going from + node idx to the neighbor. + !*/ + + in_edge_iterator in_begin( + const unsigned long& idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns an iterator pointing to the first neighboring node of + the idx-th node. If no such node exists then returns in_end(idx). + - The returned iterator also represents the directed edge going from + the neighbor to node idx. + !*/ + + out_edge_iterator out_end( + const unsigned long& idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns an iterator to one past the last neighboring node of + the idx-th node. + !*/ + + in_edge_iterator in_end( + const unsigned long& idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns an iterator to one past the last neighboring node of + the idx-th node. + !*/ + + + unsigned long node_id ( + const out_edge_iterator& it + ) const; + /*! + requires + - it == a valid iterator (i.e. it must be in the range [out_begin(idx), out_end(idx)) + for some valid idx) + ensures + - returns a number IDX such that: + - 0 <= IDX < number_of_nodes() + - IDX == The index which uniquely identifies the node pointed to by the + iterator it. This number can be used with any member function in this + object which expect a node index. (e.g. get_label(IDX) == the label for the + node pointed to by it) + !*/ + + unsigned long node_id ( + const in_edge_iterator& it + ) const; + /*! + requires + - it == a valid iterator (i.e. it must be in the range [in_begin(idx), in_end(idx)) + for some valid idx) + ensures + - returns a number IDX such that: + - 0 <= IDX < number_of_nodes() + - IDX == The index which uniquely identifies the node pointed to by the + iterator it. This number can be used with any member function in this + object which expect a node index. (e.g. get_label(IDX) == the label for the + node pointed to by it) + !*/ + + // This typedef should be for a type like int or double. It + // must also be capable of representing signed values. + typedef an_integer_or_real_type edge_type; + + edge_type get_flow ( + const unsigned long& idx1, + const unsigned long& idx2 + ) const; + /*! + requires + - idx1 < number_of_nodes() + - idx2 < number_of_nodes() + - idx1 and idx2 are neighbors in the graph + ensures + - returns the residual flow capacity from the idx1-th node to the idx2-th node. + - It is valid for this function to return a floating point value of infinity. + This value means this edge has an unlimited capacity. + !*/ + + edge_type get_flow ( + const out_edge_iterator& it + ) const; + /*! + requires + - it == a valid iterator (i.e. it must be in the range [out_begin(idx), out_end(idx)) + for some valid idx) + ensures + - let IDX = node_id(it) + - it represents the directed edge from a node, call it H, to the node IDX. Therefore, + this function returns get_flow(H,IDX) + - It is valid for this function to return a floating point value of infinity. + This value means this edge has an unlimited capacity. + !*/ + + edge_type get_flow ( + const in_edge_iterator& it + ) const; + /*! + requires + - it == a valid iterator (i.e. it must be in the range [in_begin(idx), in_end(idx)) + for some valid idx) + ensures + - let IDX = node_id(it) + - it represents the directed edge from node IDX to another node, call it H. Therefore, + this function returns get_flow(IDX,H) + - It is valid for this function to return a floating point value of infinity. + This value means this edge has an unlimited capacity. + !*/ + + void adjust_flow ( + const unsigned long& idx1, + const unsigned long& idx2, + const edge_type& value + ); + /*! + requires + - idx1 < number_of_nodes() + - idx2 < number_of_nodes() + - idx1 and idx2 are neighbors in the graph + ensures + - #get_flow(idx1,idx2) == get_flow(idx1,idx2) + value + - #get_flow(idx2,idx1) == get_flow(idx2,idx1) - value + !*/ + + void set_label ( + const unsigned long& idx, + node_label value + ); + /*! + requires + - idx < number_of_nodes() + ensures + - #get_label(idx) == value + !*/ + + node_label get_label ( + const unsigned long& idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns the label for the idx-th node in the graph. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename flow_graph + > + typename flow_graph::edge_type graph_cut_score ( + const flow_graph& g + ); + /*! + requires + - flow_graph == an object with an interface compatible with the flow_graph + object defined at the top of this file, or, an implementation of + dlib/directed_graph/directed_graph_kernel_abstract.h. + ensures + - returns the sum of the outgoing flows from nodes with a label of SOURCE_CUT + to nodes with a label != SOURCE_CUT. Note that for a directed_graph object, + the labels are stored in the node's data field. + !*/ + +// ---------------------------------------------------------------------------------------- + + class min_cut + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a function object which can be used to find the min cut + on a graph. + + The implementation is based on the method described in the following + paper: + An Experimental Comparison of Min-Cut/Max-Flow Algorithms for + Energy Minimization in Vision, by Yuri Boykov and Vladimir Kolmogorov, + in PAMI 2004. + + !*/ + + public: + + min_cut( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template < + typename flow_graph + > + void operator() ( + flow_graph& g, + const unsigned long source_node, + const unsigned long sink_node + ) const; + /*! + requires + - flow_graph == an object with an interface compatible with the flow_graph + object defined at the top of this file. + - source_node != sink_node + - source_node < g.number_of_nodes() + - sink_node < g.number_of_nodes() + - for all valid i and j: + - g.get_flow(i,j) >= 0 + (i.e. all the flow capacities/edge weights are non-negative) + - g does not contain any self loops. That is, no nodes are neighbors with + themselves. + ensures + - Finds the minimum cut on the given graph. That is, this function finds + a labeling of nodes in g such that graph_cut_score(g) would be minimized. Note + that the flow values in #g are modified by this algorithm so if you want + to obtain the min cut score you must call min_cut::operator(), then copy + the flow values back into #g, and then call graph_cut_score(#g). But in most + cases you don't care about the value of the min cut score, rather, you + just want the labels in #g. + - #g.get_label(source_node) == SOURCE_CUT + - #g.get_label(sink_node) == SINK_CUT + - for all valid i: + - #g.get_label(i) == SOURCE_CUT, SINK_CUT, or FREE_NODE + - if (#g.get_label(i) == SOURCE_CUT) then + - The minimum cut of g places node i into the source side of the cut. + - if (#g.get_label(i) == SINK_CUT) then + - The minimum cut of g places node i into the sink side of the cut. + - if (#g.get_label(i) == FREE_NODE) then + - Node i can be labeled SOURCE_CUT or SINK_CUT. Both labelings + result in the same cut score. + - When interpreting g as a graph of flow capacities from the source_node + to the sink_node we can say that the min cut problem is equivalent to + the max flow problem. This equivalent problem is to find out how to push + as much "flow" from the source node to the sink node as possible. + Upon termination, #g will contain the final flow residuals in addition to + the graph cut labels. That is, for all valid i and j: + - #g.get_flow(i,j) == the residual flow capacity left after the max + possible amount of flow is passing from the source node to the sink + node. For example, this means that #g.get_flow(i,j) == 0 whenever + node i is in the SOURCE_CUT and j is in the SINK_CUT. + - #g.get_flow(i,j) >= 0 + !*/ + + template < + typename directed_graph + > + void operator() ( + directed_graph& g, + const unsigned long source_node, + const unsigned long sink_node + ) const; + /*! + requires + - directed_graph == an implementation of dlib/directed_graph/directed_graph_kernel_abstract.h + - directed_graph::type == node_label + - directed_graph::edge_type == and integer or double type + - source_node != sink_node + - source_node < g.number_of_nodes() + - sink_node < g.number_of_nodes() + - for all valid i and j: + - edge(g,i,j) >= 0 + (i.e. all the flow capacities/edge weights are positive) + - graph_contains_length_one_cycle(g) == false + - graph_has_symmetric_edges(g) == true + ensures + - This routine simply converts g into a flow graph and calls the version + of operator() defined above. Note that the conversion is done in O(1) + time, it's just an interface adaptor. + - edge weights in g correspond to network flows while the .data field of + each node in g corresponds to the graph node labels. + - upon termination, the flows and labels in g will have been modified + as described in the above operator() routine. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MIN_CuT_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/graph_utils.h b/ml/dlib/dlib/graph_utils.h new file mode 100644 index 000000000..c79e05b64 --- /dev/null +++ b/ml/dlib/dlib/graph_utils.h @@ -0,0 +1,12 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GRAPH_UTILs_H_ +#define DLIB_GRAPH_UTILs_H_ + +#include "graph_utils/graph_utils.h" +#include "graph_utils/edge_list_graphs.h" +#include "graph_utils/function_objects.h" + +#endif // DLIB_GRAPH_UTILs_H_ + + diff --git a/ml/dlib/dlib/graph_utils/edge_list_graphs.h b/ml/dlib/dlib/graph_utils/edge_list_graphs.h new file mode 100644 index 000000000..d2447acdb --- /dev/null +++ b/ml/dlib/dlib/graph_utils/edge_list_graphs.h @@ -0,0 +1,593 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_EDGE_LIST_GrAPHS_Hh_ +#define DLIB_EDGE_LIST_GrAPHS_Hh_ + +#include "edge_list_graphs_abstract.h" +#include +#include +#include "../string.h" +#include "../rand.h" +#include +#include "sample_pair.h" +#include "ordered_sample_pair.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_duplicate_edges ( + vector_type& pairs + ) + { + typedef typename vector_type::value_type T; + if (pairs.size() > 0) + { + // sort pairs so that we can avoid duplicates in the loop below + std::sort(pairs.begin(), pairs.end(), &order_by_index); + + // now put edges into temp while avoiding duplicates + vector_type temp; + temp.reserve(pairs.size()); + temp.push_back(pairs[0]); + for (unsigned long i = 1; i < pairs.size(); ++i) + { + if (pairs[i] != pairs[i-1]) + { + temp.push_back(pairs[i]); + } + } + + temp.swap(pairs); + } + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + iterator iterator_of_worst ( + iterator begin, + const iterator& end + ) + /*! + ensures + - returns an iterator that points to the element in the given range + that has the biggest distance + !*/ + { + double dist = begin->distance(); + iterator worst = begin; + for (; begin != end; ++begin) + { + if (begin->distance() > dist) + { + dist = begin->distance(); + worst = begin; + } + } + + return worst; + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename alloc, + typename T + > + void find_percent_shortest_edges_randomly ( + const vector_type& samples, + const distance_function_type& dist_funct, + const double percent, + const unsigned long num, + const T& random_seed, + std::vector& out + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( 0 < percent && percent <= 1 && + num > 0, + "\t void find_percent_shortest_edges_randomly()" + << "\n\t Invalid inputs were given to this function." + << "\n\t samples.size(): " << samples.size() + << "\n\t percent: " << percent + << "\n\t num: " << num + ); + + out.clear(); + + if (samples.size() <= 1) + { + return; + } + + std::vector edges; + edges.reserve(num); + + dlib::rand rnd; + rnd.set_seed(cast_to_string(random_seed)); + + // randomly sample a bunch of edges + for (unsigned long i = 0; i < num; ++i) + { + const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size(); + const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size(); + if (idx1 != idx2) + { + const double dist = dist_funct(samples[idx1], samples[idx2]); + if (dist < std::numeric_limits::infinity()) + { + edges.push_back(sample_pair(idx1, idx2, dist)); + } + } + } + + + // now put edges into out while avoiding duplicates + if (edges.size() > 0) + { + remove_duplicate_edges(edges); + + // now sort all the edges by distance and take the percent with the smallest distance + std::sort(edges.begin(), edges.end(), &order_by_distance); + + const unsigned long out_size = std::min((unsigned long)(num*percent), edges.size()); + out.assign(edges.begin(), edges.begin() + out_size); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename alloc, + typename T + > + void find_approximate_k_nearest_neighbors ( + const vector_type& samples, + const distance_function_type& dist_funct, + const unsigned long k, + unsigned long num, + const T& random_seed, + std::vector& out + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( num > 0 && k > 0, + "\t void find_approximate_k_nearest_neighbors()" + << "\n\t Invalid inputs were given to this function." + << "\n\t samples.size(): " << samples.size() + << "\n\t k: " << k + << "\n\t num: " << num + ); + + out.clear(); + + if (samples.size() <= 1) + { + return; + } + + // we add each edge twice in the following loop. So multiply num by 2 to account for that. + num *= 2; + + std::vector edges; + edges.reserve(num); + std::vector temp; + temp.reserve(num); + + dlib::rand rnd; + rnd.set_seed(cast_to_string(random_seed)); + + // randomly sample a bunch of edges + for (unsigned long i = 0; i < num; ++i) + { + const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size(); + const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size(); + if (idx1 != idx2) + { + const double dist = dist_funct(samples[idx1], samples[idx2]); + if (dist < std::numeric_limits::infinity()) + { + edges.push_back(ordered_sample_pair(idx1, idx2, dist)); + edges.push_back(ordered_sample_pair(idx2, idx1, dist)); + } + } + } + + std::sort(edges.begin(), edges.end(), &order_by_index); + + std::vector::iterator beg, itr; + // now copy edges into temp when they aren't duplicates and also only move in the k shortest for + // each index. + itr = edges.begin(); + while (itr != edges.end()) + { + // first find the bounding range for all the edges connected to node itr->index1() + beg = itr; + while (itr != edges.end() && itr->index1() == beg->index1()) + ++itr; + + // If the node has more than k edges then sort them by distance so that + // we will end up with the k best. + if (static_cast(itr - beg) > k) + { + std::sort(beg, itr, &order_by_distance_and_index); + } + + // take the k best unique edges from the range [beg,itr) + temp.push_back(sample_pair(beg->index1(), beg->index2(), beg->distance())); + unsigned long prev_index2 = beg->index2(); + ++beg; + unsigned long count = 1; + for (; beg != itr && count < k; ++beg) + { + if (beg->index2() != prev_index2) + { + temp.push_back(sample_pair(beg->index1(), beg->index2(), beg->distance())); + ++count; + } + prev_index2 = beg->index2(); + } + } + + + remove_duplicate_edges(temp); + temp.swap(out); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename alloc + > + void find_k_nearest_neighbors ( + const vector_type& samples, + const distance_function_type& dist_funct, + const unsigned long k, + std::vector& out + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(k > 0, + "\t void find_k_nearest_neighbors()" + << "\n\t Invalid inputs were given to this function." + << "\n\t samples.size(): " << samples.size() + << "\n\t k: " << k + ); + + out.clear(); + + if (samples.size() <= 1) + { + return; + } + + using namespace impl; + std::vector edges; + + // Initialize all the edges to an edge with an invalid index + edges.resize(samples.size()*k, + sample_pair(samples.size(),samples.size(),std::numeric_limits::infinity())); + + // Hold the length for the longest edge for each node. Initially they are all infinity. + std::vector worst_dists(samples.size(), std::numeric_limits::infinity()); + + std::vector::iterator begin_i, end_i, begin_j, end_j; + begin_i = edges.begin(); + end_i = begin_i + k; + + // Loop over all combinations of samples. We will maintain the iterator ranges so that + // within the inner for loop we have: + // [begin_i, end_i) == the range in edges that contains neighbors of samples[i] + // [begin_j, end_j) == the range in edges that contains neighbors of samples[j] + for (unsigned long i = 0; i+1 < samples.size(); ++i) + { + begin_j = begin_i; + end_j = end_i; + + for (unsigned long j = i+1; j < samples.size(); ++j) + { + begin_j += k; + end_j += k; + + const double dist = dist_funct(samples[i], samples[j]); + + if (dist < worst_dists[i]) + { + *iterator_of_worst(begin_i, end_i) = sample_pair(i, j, dist); + worst_dists[i] = iterator_of_worst(begin_i, end_i)->distance(); + } + + if (dist < worst_dists[j]) + { + *iterator_of_worst(begin_j, end_j) = sample_pair(i, j, dist); + worst_dists[j] = iterator_of_worst(begin_j, end_j)->distance(); + } + } + + begin_i += k; + end_i += k; + } + + // sort the edges so that duplicate edges will be adjacent + std::sort(edges.begin(), edges.end(), &order_by_index); + + // if the first edge is valid + if (edges[0].index1() < samples.size()) + { + // now put edges into out while avoiding duplicates and any remaining invalid edges. + out.reserve(edges.size()); + out.push_back(edges[0]); + for (unsigned long i = 1; i < edges.size(); ++i) + { + // if we hit an invalid edge then we can stop + if (edges[i].index1() >= samples.size()) + break; + + // if this isn't a duplicate edge + if (edges[i] != edges[i-1]) + { + out.push_back(edges[i]); + } + } + } + + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + bool contains_duplicate_pairs ( + const vector_type& pairs + ) + { + typedef typename vector_type::value_type T; + vector_type temp(pairs); + std::sort(temp.begin(), temp.end(), &order_by_index); + + for (unsigned long i = 1; i < temp.size(); ++i) + { + // if we found a duplicate + if (temp[i-1] == temp[i]) + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + typename enable_if_c<(is_same_type::value || + is_same_type::value), + unsigned long>::type + max_index_plus_one ( + const vector_type& pairs + ) + { + if (pairs.size() == 0) + { + return 0; + } + else + { + unsigned long max_idx = 0; + for (unsigned long i = 0; i < pairs.size(); ++i) + { + if (pairs[i].index1() > max_idx) + max_idx = pairs[i].index1(); + if (pairs[i].index2() > max_idx) + max_idx = pairs[i].index2(); + } + + return max_idx + 1; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_long_edges ( + vector_type& pairs, + double distance_threshold + ) + { + vector_type temp; + temp.reserve(pairs.size()); + + // add all the pairs shorter than the given threshold into temp + for (unsigned long i = 0; i < pairs.size(); ++i) + { + if (pairs[i].distance() <= distance_threshold) + temp.push_back(pairs[i]); + } + + // move temp into the output vector + temp.swap(pairs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_short_edges ( + vector_type& pairs, + double distance_threshold + ) + { + vector_type temp; + temp.reserve(pairs.size()); + + // add all the pairs longer than the given threshold into temp + for (unsigned long i = 0; i < pairs.size(); ++i) + { + if (pairs[i].distance() >= distance_threshold) + temp.push_back(pairs[i]); + } + + // move temp into the output vector + temp.swap(pairs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_percent_longest_edges ( + vector_type& pairs, + double percent + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( 0 <= percent && percent < 1, + "\t void remove_percent_longest_edges()" + << "\n\t Invalid inputs were given to this function." + << "\n\t percent: " << percent + ); + + typedef typename vector_type::value_type T; + std::sort(pairs.begin(), pairs.end(), &order_by_distance); + + const unsigned long num = static_cast((1.0-percent)*pairs.size()); + + // pick out the num shortest pairs + vector_type temp(pairs.begin(), pairs.begin() + num); + + // move temp into the output vector + temp.swap(pairs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_percent_shortest_edges ( + vector_type& pairs, + double percent + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( 0 <= percent && percent < 1, + "\t void remove_percent_shortest_edges()" + << "\n\t Invalid inputs were given to this function." + << "\n\t percent: " << percent + ); + + typedef typename vector_type::value_type T; + std::sort(pairs.rbegin(), pairs.rend(), &order_by_distance); + + const unsigned long num = static_cast((1.0-percent)*pairs.size()); + + // pick out the num shortest pairs + vector_type temp(pairs.begin(), pairs.begin() + num); + + // move temp into the output vector + temp.swap(pairs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + bool is_ordered_by_index ( + const vector_type& edges + ) + { + for (unsigned long i = 1; i < edges.size(); ++i) + { + if (order_by_index(edges[i], edges[i-1])) + return false; + } + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename alloc1, + typename alloc2 + > + void find_neighbor_ranges ( + const std::vector& edges, + std::vector,alloc2>& neighbors + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ordered_by_index(edges), + "\t void find_neighbor_ranges()" + << "\n\t Invalid inputs were given to this function" + ); + + + // setup neighbors so that [neighbors[i].first, neighbors[i].second) is the range + // within edges that contains all node i's edges. + const unsigned long num_nodes = max_index_plus_one(edges); + neighbors.assign(num_nodes, std::make_pair(0,0)); + unsigned long cur_node = 0; + unsigned long start_idx = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + if (edges[i].index1() != cur_node) + { + neighbors[cur_node] = std::make_pair(start_idx, i); + start_idx = i; + cur_node = edges[i].index1(); + } + } + if (neighbors.size() != 0) + neighbors[cur_node] = std::make_pair(start_idx, (unsigned long)edges.size()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename alloc1, + typename alloc2 + > + void convert_unordered_to_ordered ( + const std::vector& edges, + std::vector& out_edges + ) + { + out_edges.clear(); + out_edges.reserve(edges.size()*2); + for (unsigned long i = 0; i < edges.size(); ++i) + { + out_edges.push_back(ordered_sample_pair(edges[i].index1(), edges[i].index2(), edges[i].distance())); + if (edges[i].index1() != edges[i].index2()) + out_edges.push_back(ordered_sample_pair(edges[i].index2(), edges[i].index1(), edges[i].distance())); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_EDGE_LIST_GrAPHS_Hh_ + + diff --git a/ml/dlib/dlib/graph_utils/edge_list_graphs_abstract.h b/ml/dlib/dlib/graph_utils/edge_list_graphs_abstract.h new file mode 100644 index 000000000..1f72c2739 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/edge_list_graphs_abstract.h @@ -0,0 +1,358 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_EDGE_LIST_GrAPHS_ABSTRACT_Hh_ +#ifdef DLIB_EDGE_LIST_GrAPHS_ABSTRACT_Hh_ + +#include +#include "../string.h" +#include "sample_pair_abstract.h" +#include "ordered_sample_pair_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename alloc, + typename T + > + void find_percent_shortest_edges_randomly ( + const vector_type& samples, + const distance_function_type& dist_funct, + const double percent, + const unsigned long num, + const T& random_seed, + std::vector& out + ); + /*! + requires + - 0 < percent <= 1 + - num > 0 + - random_seed must be convertible to a string by dlib::cast_to_string() + - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates + to a floating point number + ensures + - This function randomly samples the space of pairs of integers between + 0 and samples.size()-1 inclusive. For each of these pairs, (i,j), a + sample_pair is created as follows: + sample_pair(i, j, dist_funct(samples[i], samples[j])) + num such sample_pair objects are generated, duplicates and pairs with distance + values == infinity are removed, and then the top percent of them with the + smallest distance are stored into out. + - #out.size() <= num*percent + - contains_duplicate_pairs(#out) == false + - for all valid i: + - #out[i].distance() == dist_funct(samples[#out[i].index1()], samples[#out[i].index2()]) + - #out[i].distance() < std::numeric_limits::infinity() + - random_seed is used to seed the random number generator used by this + function. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename alloc, + typename T + > + void find_approximate_k_nearest_neighbors ( + const vector_type& samples, + const distance_function_type& dist_funct, + const unsigned long k, + const unsigned long num, + const T& random_seed, + std::vector& out + ); + /*! + requires + - k > 0 + - num > 0 + - random_seed must be convertible to a string by dlib::cast_to_string() + - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates + to a floating point number + ensures + - This function computes an approximate form of k nearest neighbors. As num grows + larger the output of this function converges to the output of the + find_k_nearest_neighbors() function defined below. + - Specifically, this function randomly samples the space of pairs of integers between + 0 and samples.size()-1 inclusive. For each of these pairs, (i,j), a + sample_pair is created as follows: + sample_pair(i, j, dist_funct(samples[i], samples[j])) + num such sample_pair objects are generated and then exact k-nearest-neighbors + is performed amongst these sample_pairs and the results are stored into #out. + Note that samples with an infinite distance between them are considered to + be not connected at all. + - contains_duplicate_pairs(#out) == false + - for all valid i: + - #out[i].distance() == dist_funct(samples[#out[i].index1()], samples[#out[i].index2()]) + - #out[i].distance() < std::numeric_limits::infinity() + - random_seed is used to seed the random number generator used by this + function. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename alloc + > + void find_k_nearest_neighbors ( + const vector_type& samples, + const distance_function_type& dist_funct, + const unsigned long k, + std::vector& out + ); + /*! + requires + - k > 0 + - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates + to a floating point number + ensures + - #out == a set of sample_pair objects that represent all the k nearest + neighbors in samples according to the given distance function dist_funct. + Note that samples with an infinite distance between them are considered to + be not connected at all. + - for all valid i: + - #out[i].distance() == dist_funct(samples[#out[i].index1()], samples[#out[i].index2()]) + - #out[i].distance() < std::numeric_limits::infinity() + - contains_duplicate_pairs(#out) == false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + bool contains_duplicate_pairs ( + const vector_type& pairs + ); + /*! + requires + - vector_type == a type with an interface compatible with std::vector and it + must in turn contain objects with an interface compatible with + dlib::sample_pair or dlib::ordered_sample_pair. + ensures + - if (pairs contains any elements that are equal according to operator==) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + unsigned long max_index_plus_one ( + const vector_type& pairs + ); + /*! + requires + - vector_type == a type with an interface compatible with std::vector and it + must in turn contain objects with an interface compatible with + dlib::sample_pair or dlib::ordered_sample_pair. + ensures + - if (pairs.size() == 0) then + - returns 0 + - else + - returns a number N such that: + - for all i: pairs[i].index1() < N && pairs[i].index2() < N + - for some j: pairs[j].index1()+1 == N || pairs[j].index2()+1 == N + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_long_edges ( + vector_type& pairs, + double distance_threshold + ); + /*! + requires + - vector_type == a type with an interface compatible with std::vector and it + must in turn contain objects with an interface compatible with + dlib::sample_pair or dlib::ordered_sample_pair. + ensures + - Removes all elements of pairs that have a distance value greater than the + given threshold. + - #pairs.size() <= pairs.size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_short_edges ( + vector_type& pairs, + double distance_threshold + ); + /*! + requires + - vector_type == a type with an interface compatible with std::vector and it + must in turn contain objects with an interface compatible with + dlib::sample_pair or dlib::ordered_sample_pair. + ensures + - Removes all elements of pairs that have a distance value less than the + given threshold. + - #pairs.size() <= pairs.size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_percent_longest_edges ( + vector_type& pairs, + double percent + ); + /*! + requires + - 0 <= percent < 1 + - vector_type == a type with an interface compatible with std::vector and it + must in turn contain objects with an interface compatible with + dlib::sample_pair or dlib::ordered_sample_pair. + ensures + - Removes the given upper percentage of the longest edges in pairs. I.e. + this function removes the long edges from pairs. + - #pairs.size() == (1-percent)*pairs.size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_percent_shortest_edges ( + vector_type& pairs, + double percent + ); + /*! + requires + - 0 <= percent < 1 + - vector_type == a type with an interface compatible with std::vector and it + must in turn contain objects with an interface compatible with + dlib::sample_pair or dlib::ordered_sample_pair. + ensures + - Removes the given upper percentage of the shortest edges in pairs. I.e. + this function removes the short edges from pairs. + - #pairs.size() == (1-percent)*pairs.size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + void remove_duplicate_edges ( + vector_type& pairs + ); + /*! + requires + - vector_type == a type with an interface compatible with std::vector and it + must in turn contain objects with an interface compatible with + dlib::sample_pair or dlib::ordered_sample_pair. + ensures + - Removes any duplicate edges from pairs. That is, for all elements of pairs, + A and B, such that A == B, only one of A or B will be in pairs after this + function terminates. + - #pairs.size() <= pairs.size() + - is_ordered_by_index(#pairs) == true + - contains_duplicate_pairs(#pairs) == false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + bool is_ordered_by_index ( + const vector_type& edges + ); + /*! + requires + - vector_type == a type with an interface compatible with std::vector and it + must in turn contain objects with an interface compatible with + dlib::sample_pair or dlib::ordered_sample_pair. + ensures + - returns true if and only if the contents of edges are in sorted order + according to order_by_index(). That is, we return true if calling + std::stable_sort(edges.begin(), edges.end(), &order_by_index) would not + change the ordering of elements of edges. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename alloc1, + typename alloc2 + > + void find_neighbor_ranges ( + const std::vector& edges, + std::vector,alloc2>& neighbors + ); + /*! + requires + - is_ordered_by_index(edges) == true + (i.e. edges is sorted so that all the edges for a particular node are grouped + together) + ensures + - This function takes a graph, represented by its list of edges, and finds the + ranges that contain the edges for each node in the graph. In particular, + #neighbors[i] will tell you which edges correspond to the ith node in the + graph. + - #neighbors.size() == max_index_plus_one(edges) + (i.e. neighbors will have an entry for each node in the graph defined by the + list of edges) + - for all valid i: + - all elements of edges such that their index1() value == i are in the + range [neighbors[i].first, neighbors[i].second). That is, for all k such + that neighbors[i].first <= k < neighbors[i].second: + - edges[k].index1() == i. + - all edges outside this range have an index1() value != i + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename alloc1, + typename alloc2 + > + void convert_unordered_to_ordered ( + const std::vector& edges, + std::vector& out_edges + ); + /*! + ensures + - interprets edges a defining an undirected graph. + - This function populates out_edges with a directed graph that represents the + same graph as the one in edges. In particular, this means that for all valid + i we have the following: + - if (edges[i].index1() != edges[i].index2()) then + - #out_edges contains two edges corresponding to edges[i]. They + represent the two directions of this edge. The distance value from + edges[i] is also copied into the output edges. + - else + - #out_edges contains one edge corresponding to edges[i] since this is + a self edge. The distance value from edges[i] is also copied into + the output edge. + - max_index_plus_one(edges) == max_index_plus_one(#out_edges) + (i.e. both graphs have the same number of nodes) + - In all but the most trivial cases, we will have is_ordered_by_index(#out_edges) == false + - contains_duplicate_pairs(#out_edges) == contains_duplicate_pairs(edges) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_EDGE_LIST_GrAPHS_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh.h b/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh.h new file mode 100644 index 000000000..433a588a5 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh.h @@ -0,0 +1,217 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_ +#define DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_ + +#include "find_k_nearest_neighbors_lsh_abstract.h" +#include "../threads.h" +#include "../lsh/hashes.h" +#include +#include +#include "sample_pair.h" +#include "edge_list_graphs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + struct compare_sample_pair_with_distance + { + inline bool operator() (const sample_pair& a, const sample_pair& b) const + { + return a.distance() < b.distance(); + } + }; + + template < + typename vector_type, + typename hash_function_type + > + class hash_block + { + public: + hash_block( + const vector_type& samples_, + const hash_function_type& hash_funct_, + std::vector& hashes_ + ) : + samples(samples_), + hash_funct(hash_funct_), + hashes(hashes_) + {} + + void operator() (long i) const + { + hashes[i] = hash_funct(samples[i]); + } + + const vector_type& samples; + const hash_function_type& hash_funct; + std::vector& hashes; + }; + + template < + typename vector_type, + typename distance_function_type, + typename hash_function_type, + typename alloc + > + class scan_find_k_nearest_neighbors_lsh + { + public: + scan_find_k_nearest_neighbors_lsh ( + const vector_type& samples_, + const distance_function_type& dist_funct_, + const hash_function_type& hash_funct_, + const unsigned long k_, + std::vector& edges_, + const unsigned long k_oversample_, + const std::vector& hashes_ + ) : + samples(samples_), + dist_funct(dist_funct_), + hash_funct(hash_funct_), + k(k_), + edges(edges_), + k_oversample(k_oversample_), + hashes(hashes_) + { + edges.clear(); + edges.reserve(samples.size()*k/2); + } + + mutex m; + const vector_type& samples; + const distance_function_type& dist_funct; + const hash_function_type& hash_funct; + const unsigned long k; + std::vector& edges; + const unsigned long k_oversample; + const std::vector& hashes; + + void operator() (unsigned long i) const + { + const unsigned long k_hash = k*k_oversample; + + std::priority_queue > best_hashes; + std::priority_queue, dlib::impl::compare_sample_pair_with_distance> best_samples; + unsigned long worst_distance = std::numeric_limits::max(); + // scan over the hashes and find the best matches for hashes[i] + for (unsigned long j = 0; j < hashes.size(); ++j) + { + if (i == j) + continue; + + const unsigned long dist = hash_funct.distance(hashes[i], hashes[j]); + if (dist < worst_distance || best_hashes.size() < k_hash) + { + if (best_hashes.size() >= k_hash) + best_hashes.pop(); + best_hashes.push(std::make_pair(dist, j)); + worst_distance = best_hashes.top().first; + } + } + + // Now figure out which of the best_hashes are actually the k best matches + // according to dist_funct() + while (best_hashes.size() != 0) + { + const unsigned long j = best_hashes.top().second; + best_hashes.pop(); + + const double dist = dist_funct(samples[i], samples[j]); + if (dist < std::numeric_limits::infinity()) + { + if (best_samples.size() >= k) + best_samples.pop(); + best_samples.push(sample_pair(i,j,dist)); + } + } + + // Finally, now put the k best matches according to dist_funct() into edges + auto_mutex lock(m); + while (best_samples.size() != 0) + { + edges.push_back(best_samples.top()); + best_samples.pop(); + } + } + }; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename hash_function_type + > + void hash_samples ( + const vector_type& samples, + const hash_function_type& hash_funct, + const unsigned long num_threads, + std::vector& hashes + ) + { + hashes.resize(samples.size()); + + typedef impl::hash_block block_type; + block_type temp(samples, hash_funct, hashes); + parallel_for(num_threads, 0, samples.size(), temp); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename hash_function_type, + typename alloc + > + void find_k_nearest_neighbors_lsh ( + const vector_type& samples, + const distance_function_type& dist_funct, + const hash_function_type& hash_funct, + const unsigned long k, + const unsigned long num_threads, + std::vector& edges, + const unsigned long k_oversample = 20 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(k > 0 && k_oversample > 0, + "\t void find_k_nearest_neighbors_lsh()" + << "\n\t Invalid inputs were given to this function." + << "\n\t samples.size(): " << samples.size() + << "\n\t k: " << k + << "\n\t k_oversample: " << k_oversample + ); + + edges.clear(); + + if (samples.size() <= 1) + { + return; + } + + typedef typename hash_function_type::result_type hash_type; + std::vector hashes; + hash_samples(samples, hash_funct, num_threads, hashes); + + typedef impl::scan_find_k_nearest_neighbors_lsh scan_type; + scan_type temp(samples, dist_funct, hash_funct, k, edges, k_oversample, hashes); + parallel_for(num_threads, 0, hashes.size(), temp); + + remove_duplicate_edges(edges); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_ + + diff --git a/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h b/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h new file mode 100644 index 000000000..1de159be4 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h @@ -0,0 +1,102 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_Hh_ +#ifdef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_Hh_ + +#include "../lsh/hashes_abstract.h" +#include "sample_pair_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename hash_function_type + > + void hash_samples ( + const vector_type& samples, + const hash_function_type& hash_funct, + const unsigned long num_threads, + std::vector& hashes + ); + /*! + requires + - hash_funct() is threadsafe. This means that it must be safe for multiple + threads to invoke the member functions of hash_funct() at the same time. + - vector_type is any container that looks like a std::vector or dlib::array. + - hash_funct must be a function object with an interface compatible with the + objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct + must be capable of hashing the elements in the samples vector. + ensures + - This function hashes all the elements in samples and stores the results in + hashes. It will also use num_threads concurrent threads to do this. You + should set this value equal to the number of processing cores on your + computer for maximum speed. + - #hashes.size() == 0 + - for all valid i: + - #hashes[i] = hash_funct(samples[i]) + (i.e. #hashes[i] will contain the hash of samples[i]) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type, + typename distance_function_type, + typename hash_function_type, + typename alloc + > + void find_k_nearest_neighbors_lsh ( + const vector_type& samples, + const distance_function_type& dist_funct, + const hash_function_type& hash_funct, + const unsigned long k, + const unsigned long num_threads, + std::vector& edges, + const unsigned long k_oversample = 20 + ); + /*! + requires + - hash_funct and dist_funct are threadsafe. This means that it must be safe + for multiple threads to invoke the member functions of these objects at the + same time. + - k > 0 + - k_oversample > 0 + - dist_funct(samples[i], samples[j]) must be a valid expression that evaluates + to a floating point number + - vector_type is any container that looks like a std::vector or dlib::array. + - hash_funct must be a function object with an interface compatible with the + objects defined in dlib/lsh/hashes_abstract.h. In particular, hash_funct + must be capable of hashing the elements in the samples vector. + ensures + - This function computes an approximate form of a k nearest neighbors graph of + the elements in samples. In particular, the way it works is that it first + hashes all elements in samples using the provided locality sensitive hash + function hash_funct(). Then it performs an exact k nearest neighbors on the + hashes which can be done very quickly. For each of these neighbors we + compute the true distance using dist_funct() and the k nearest neighbors for + each sample are stored into #edges. + - Note that samples with an infinite distance between them are considered to be + not connected at all. Therefore, we exclude edges with such distances from + being output. + - for all valid i: + - #edges[i].distance() == dist_funct(samples[#edges[i].index1()], samples[#edges[i].index2()]) + - #edges[i].distance() < std::numeric_limits::infinity() + - contains_duplicate_pairs(#edges) == false + - This function will use num_threads concurrent threads of processing. You + should set this value equal to the number of processing cores on your + computer for maximum speed. + - The hash based k nearest neighbor step is approximate, however, you can + improve the output accuracy by using a larger k value for this first step. + Therefore, this function finds k*k_oversample nearest neighbors during the + first hashing based step. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/graph_utils/function_objects.h b/ml/dlib/dlib/graph_utils/function_objects.h new file mode 100644 index 000000000..33b6e51ff --- /dev/null +++ b/ml/dlib/dlib/graph_utils/function_objects.h @@ -0,0 +1,129 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MR_FUNCTION_ObJECTS_Hh_ +#define DLIB_MR_FUNCTION_ObJECTS_Hh_ + +#include "function_objects_abstract.h" +#include "../matrix.h" +#include "../svm/sparse_vector.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct squared_euclidean_distance + { + squared_euclidean_distance ( + ) : + lower(0), + upper(std::numeric_limits::infinity()) + {} + + squared_euclidean_distance ( + const double l, + const double u + ) : + lower(l), + upper(u) + {} + + const double lower; + const double upper; + + template + double operator() ( + const sample_type& a, + const sample_type& b + ) const + { + const double len = length_squared(a-b); + if (lower <= len && len <= upper) + return len; + else + return std::numeric_limits::infinity(); + } + }; + +// ---------------------------------------------------------------------------------------- + + struct cosine_distance + { + template + double operator() ( + const sample_type& a, + const sample_type& b + ) const + { + const double temp = length(a)*length(b); + if (temp == 0) + return 0; + else + return 1-dot(a,b)/temp; + } + }; + +// ---------------------------------------------------------------------------------------- + + struct negative_dot_product_distance + { + template + double operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return -dot(a,b); + } + }; + +// ---------------------------------------------------------------------------------------- + + struct use_weights_of_one + { + template + double operator() ( + const edge_type& + ) const + { + return 1; + } + }; + +// ---------------------------------------------------------------------------------------- + + struct use_gaussian_weights + { + use_gaussian_weights ( + ) + { + gamma = 0.1; + } + + use_gaussian_weights ( + double g + ) + { + gamma = g; + } + + double gamma; + + template + double operator() ( + const edge_type& e + ) const + { + return std::exp(-gamma*e.distance()); + } + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MR_FUNCTION_ObJECTS_Hh_ + + diff --git a/ml/dlib/dlib/graph_utils/function_objects_abstract.h b/ml/dlib/dlib/graph_utils/function_objects_abstract.h new file mode 100644 index 000000000..394b99397 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/function_objects_abstract.h @@ -0,0 +1,209 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MR_FUNCTION_ObJECTS_ABSTRACT_Hh_ +#ifdef DLIB_MR_FUNCTION_ObJECTS_ABSTRACT_Hh_ + +#include "../matrix.h" +#include +#include "../svm/sparse_vector_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct squared_euclidean_distance + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple function object that computes squared euclidean distance + between two dlib::matrix objects. + + THREAD SAFETY + This object has no mutable members. Therefore, it is safe to call + operator() on a single instance of this object simultaneously from multiple + threads. + !*/ + + squared_euclidean_distance ( + ); + /*! + ensures + - #lower == 0 + - #upper == std::numeric_limits::infinity() + !*/ + + squared_euclidean_distance ( + const double l, + const double u + ); + /*! + ensures + - #lower == l + - #upper == u + !*/ + + const double lower; + const double upper; + + template + double operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - sample_type should be a kind of dlib::matrix + ensures + - let LEN = length_squared(a-b) + - if (lower <= LEN <= upper) then + - returns LEN + - else + - returns std::numeric_limits::infinity() + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + struct cosine_distance + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple function object that computes the cosine of the angle + between two vectors and returns 1 - this quantity. Moreover, this object + works for both sparse and dense vectors. + + THREAD SAFETY + This object has no mutable members. Therefore, it is safe to call + operator() on a single instance of this object simultaneously from multiple + threads. + !*/ + + template + double operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - sample_type is a dense vector (e.g. a dlib::matrix) or a sparse + vector as defined at the top of dlib/svm/sparse_vector_abstract.h + ensures + - let theta = the angle between a and b. + - returns 1 - cos(theta) + (e.g. this function returns 0 when a and b have an angle of 0 between + each other, 1 if they have a 90 degree angle, and a maximum of 2 if the + vectors have a 180 degree angle between each other). + - zero length vectors are considered to have angles of 0 between all other + vectors. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + struct negative_dot_product_distance + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple function object that computes the dot product between two + vectors and returns the negation of this value. Moreover, this object + works for both sparse and dense vectors. + + THREAD SAFETY + This object has no mutable members. Therefore, it is safe to call + operator() on a single instance of this object simultaneously from multiple + threads. + !*/ + + template + double operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - sample_type is a dense vector (e.g. a dlib::matrix) or a sparse + vector as defined at the top of dlib/svm/sparse_vector_abstract.h + ensures + - returns -dot(a,b) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + struct use_weights_of_one + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple function object that takes a single argument + and always returns 1 + + THREAD SAFETY + This object has no mutable members. Therefore, it is safe to call + operator() on a single instance of this object simultaneously from multiple + threads. + !*/ + + template + double operator() ( + const edge_type& + ) const; + /*! + ensures + - returns 1 + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + struct use_gaussian_weights + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple function object that takes a single argument + which should be an object similar to dlib::sample_pair. + + THREAD SAFETY + This object has no mutable members. Therefore, it is safe to call + operator() on a single instance of this object simultaneously from multiple + threads. + !*/ + + use_gaussian_weights ( + ); + /*! + ensures + - #gamma == 0.1 + !*/ + + use_gaussian_weights ( + double g + ); + /*! + ensures + - #gamma == g + !*/ + + double gamma; + + template + double operator() ( + const edge_type& e + ) const; + /*! + requires + - e.distance() must be a valid expression that returns a number + (e.g. edge_type might be dlib::sample_pair) + ensures + - returns std::exp(-gamma*e.distance()); + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MR_FUNCTION_ObJECTS_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/graph_utils/graph_utils.h b/ml/dlib/dlib/graph_utils/graph_utils.h new file mode 100644 index 000000000..81262b7f5 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/graph_utils.h @@ -0,0 +1,1227 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GRAPH_UTILs_ +#define DLIB_GRAPH_UTILs_ + +#include "../algs.h" +#include +#include "graph_utils_abstract.h" +#include "../is_kind.h" +#include "../enable_if.h" +#include +#include "../set.h" +#include "../memory_manager.h" +#include "../set_utils.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + typename enable_if,typename T::edge_type>::type& edge( + T& g, + unsigned long idx_i, + unsigned long idx_j + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(g.has_edge(idx_i,idx_j) == true, + "\tT::edge_type& edge(g, idx_i, idx_j)" + << "\n\t you have requested an invalid edge" + << "\n\t idx_i: " << idx_i + << "\n\t idx_j: " << idx_j + ); + + for (unsigned long i = 0; i < g.node(idx_i).number_of_neighbors(); ++i) + { + if (g.node(idx_i).neighbor(i).index() == idx_j) + return g.node(idx_i).edge(i); + } + + // put this here just so compilers don't complain about a lack of + // a return here + DLIB_CASSERT(false, + "\tT::edge_type& edge(g, idx_i, idx_j)" + << "\n\t you have requested an invalid edge" + << "\n\t idx_i: " << idx_i + << "\n\t idx_j: " << idx_j + ); + } + + template + const typename enable_if,typename T::edge_type>::type& edge( + const T& g, + unsigned long idx_i, + unsigned long idx_j + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(g.has_edge(idx_i,idx_j) == true, + "\tT::edge_type& edge(g, idx_i, idx_j)" + << "\n\t you have requested an invalid edge" + << "\n\t idx_i: " << idx_i + << "\n\t idx_j: " << idx_j + ); + + for (unsigned long i = 0; i < g.node(idx_i).number_of_neighbors(); ++i) + { + if (g.node(idx_i).neighbor(i).index() == idx_j) + return g.node(idx_i).edge(i); + } + + // put this here just so compilers don't complain about a lack of + // a return here + DLIB_CASSERT(false, + "\tT::edge_type& edge(g, idx_i, idx_j)" + << "\n\t you have requested an invalid edge" + << "\n\t idx_i: " << idx_i + << "\n\t idx_j: " << idx_j + ); + } + +// ---------------------------------------------------------------------------------------- + + template + typename enable_if,typename T::edge_type>::type& edge( + T& g, + unsigned long parent_idx, + unsigned long child_idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(g.has_edge(parent_idx,child_idx) == true, + "\t T::edge_type& edge(g, parent_idx, child_idx)" + << "\n\t you have requested an invalid edge" + << "\n\t parent_idx: " << parent_idx + << "\n\t child_idx: " << child_idx + ); + + for (unsigned long i = 0; i < g.node(parent_idx).number_of_children(); ++i) + { + if (g.node(parent_idx).child(i).index() == child_idx) + return g.node(parent_idx).child_edge(i); + } + + // put this here just so compilers don't complain about a lack of + // a return here + DLIB_CASSERT(false, + "\t T::edge_type& edge(g, parent_idx, child_idx)" + << "\n\t you have requested an invalid edge" + << "\n\t parent_idx: " << parent_idx + << "\n\t child_idx: " << child_idx + ); + } + + template + const typename enable_if,typename T::edge_type>::type& edge( + const T& g, + unsigned long parent_idx, + unsigned long child_idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(g.has_edge(parent_idx,child_idx) == true, + "\t T::edge_type& edge(g, parent_idx, child_idx)" + << "\n\t you have requested an invalid edge" + << "\n\t parent_idx: " << parent_idx + << "\n\t child_idx: " << child_idx + ); + + for (unsigned long i = 0; i < g.node(parent_idx).number_of_children(); ++i) + { + if (g.node(parent_idx).child(i).index() == child_idx) + return g.node(parent_idx).child_edge(i); + } + + // put this here just so compilers don't complain about a lack of + // a return here + DLIB_ASSERT(false, + "\t T::edge_type& edge(g, parent_idx, child_idx)" + << "\n\t you have requested an invalid edge" + << "\n\t parent_idx: " << parent_idx + << "\n\t child_idx: " << child_idx + ); + } + +// ---------------------------------------------------------------------------------------- + + namespace graph_helpers + { + template + inline bool is_same_object ( + const T& a, + const U& b + ) + { + if (is_same_type::value == false) + return false; + if ((void*)&a == (void*)&b) + return true; + else + return false; + } + + template < + typename T + > + bool search_for_directed_cycles ( + const T& node, + std::vector& visited, + std::vector& temp + ) + /*! + requires + - visited.size() >= number of nodes in the graph that contains the given node + - temp.size() >= number of nodes in the graph that contains the given node + - for all i in temp: + - temp[i] == false + ensures + - checks the connected subgraph containing the given node for directed cycles + and returns true if any are found and false otherwise. + - for all nodes N in the connected subgraph containing the given node: + - #visited[N.index()] == true + - for all i in temp: + - #temp[i] == false + !*/ + { + if (temp[node.index()] == true) + return true; + + visited[node.index()] = true; + temp[node.index()] = true; + + for (unsigned long i = 0; i < node.number_of_children(); ++i) + { + if (search_for_directed_cycles(node.child(i), visited, temp)) + return true; + } + + temp[node.index()] = false; + + return false; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + typename enable_if,bool>::type search_for_undirected_cycles ( + const T& node, + std::vector& visited, + unsigned long prev = std::numeric_limits::max() + ) + /*! + requires + - visited.size() >= number of nodes in the graph that contains the given node + - for all nodes N in the connected subgraph containing the given node: + - visited[N.index] == false + ensures + - checks the connected subgraph containing the given node for directed cycles + and returns true if any are found and false otherwise. + - for all nodes N in the connected subgraph containing the given node: + - #visited[N.index()] == true + !*/ + { + using namespace std; + if (visited[node.index()] == true) + return true; + + visited[node.index()] = true; + + for (unsigned long i = 0; i < node.number_of_children(); ++i) + { + if (node.child(i).index() != prev && + search_for_undirected_cycles(node.child(i), visited, node.index())) + return true; + } + + for (unsigned long i = 0; i < node.number_of_parents(); ++i) + { + if (node.parent(i).index() != prev && + search_for_undirected_cycles(node.parent(i), visited, node.index())) + return true; + } + + return false; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + typename enable_if,bool>::type search_for_undirected_cycles ( + const T& node, + std::vector& visited, + unsigned long prev = std::numeric_limits::max() + ) + /*! + requires + - visited.size() >= number of nodes in the graph that contains the given node + - for all nodes N in the connected subgraph containing the given node: + - visited[N.index] == false + ensures + - checks the connected subgraph containing the given node for directed cycles + and returns true if any are found and false otherwise. + - for all nodes N in the connected subgraph containing the given node: + - #visited[N.index()] == true + !*/ + { + using namespace std; + if (visited[node.index()] == true) + return true; + + visited[node.index()] = true; + + for (unsigned long i = 0; i < node.number_of_neighbors(); ++i) + { + if (node.neighbor(i).index() != prev && + search_for_undirected_cycles(node.neighbor(i), visited, node.index())) + return true; + } + + return false; + } + + } + +// ------------------------------------------------------------------------------------ + + template < + typename graph_type1, + typename graph_type2 + > + typename enable_if >::type copy_graph_structure ( + const graph_type1& src, + graph_type2& dest + ) + { + COMPILE_TIME_ASSERT(is_graph::value); + COMPILE_TIME_ASSERT(is_graph::value); + if (graph_helpers::is_same_object(src,dest)) + return; + + dest.clear(); + dest.set_number_of_nodes(src.number_of_nodes()); + + // copy all the edges from src into dest + for (unsigned long i = 0; i < src.number_of_nodes(); ++i) + { + for (unsigned long j = 0; j < src.node(i).number_of_neighbors(); ++j) + { + const unsigned long nidx = src.node(i).neighbor(j).index(); + if (nidx >= i) + { + dest.add_edge(i,nidx); + } + } + } + } + + template < + typename graph_type1, + typename graph_type2 + > + typename enable_if >::type copy_graph_structure ( + const graph_type1& src, + graph_type2& dest + ) + { + COMPILE_TIME_ASSERT(is_directed_graph::value); + COMPILE_TIME_ASSERT(is_directed_graph::value || is_graph::value ); + if (graph_helpers::is_same_object(src,dest)) + return; + + dest.clear(); + dest.set_number_of_nodes(src.number_of_nodes()); + + // copy all the edges from src into dest + for (unsigned long i = 0; i < src.number_of_nodes(); ++i) + { + for (unsigned long j = 0; j < src.node(i).number_of_children(); ++j) + { + const unsigned long nidx = src.node(i).child(j).index(); + if (dest.has_edge(i,nidx) == false) + { + dest.add_edge(i,nidx); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type1, + typename graph_type2 + > + typename enable_if >::type copy_graph ( + const graph_type1& src, + graph_type2& dest + ) + { + COMPILE_TIME_ASSERT(is_graph::value); + COMPILE_TIME_ASSERT(is_graph::value); + if (graph_helpers::is_same_object(src,dest)) + return; + + copy_graph_structure(src,dest); + + // copy all the node and edge content + for (unsigned long i = 0; i < src.number_of_nodes(); ++i) + { + dest.node(i).data = src.node(i).data; + + for (unsigned long j = 0; j < src.node(i).number_of_neighbors(); ++j) + { + const unsigned long nidx = src.node(i).neighbor(j).index(); + if (nidx >= i) + { + dest.node(i).edge(j) = src.node(i).edge(j); + } + } + } + } + + template < + typename graph_type1, + typename graph_type2 + > + typename enable_if >::type copy_graph ( + const graph_type1& src, + graph_type2& dest + ) + { + COMPILE_TIME_ASSERT(is_directed_graph::value); + COMPILE_TIME_ASSERT(is_directed_graph::value); + if (graph_helpers::is_same_object(src,dest)) + return; + + copy_graph_structure(src,dest); + + // copy all the node and edge content + for (unsigned long i = 0; i < src.number_of_nodes(); ++i) + { + dest.node(i).data = src.node(i).data; + for (unsigned long j = 0; j < src.node(i).number_of_children(); ++j) + { + dest.node(i).child_edge(j) = src.node(i).child_edge(j); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename S + > + typename enable_if >::type find_connected_nodes ( + const T& n, + S& visited + ) + { + if (visited.is_member(n.index()) == false) + { + unsigned long temp = n.index(); + visited.add(temp); + + for (unsigned long i = 0; i < n.number_of_neighbors(); ++i) + find_connected_nodes(n.neighbor(i), visited); + } + } + + template < + typename T, + typename S + > + typename enable_if >::type find_connected_nodes ( + const T& n, + S& visited + ) + { + if (visited.is_member(n.index()) == false) + { + unsigned long temp = n.index(); + visited.add(temp); + + for (unsigned long i = 0; i < n.number_of_parents(); ++i) + find_connected_nodes(n.parent(i), visited); + for (unsigned long i = 0; i < n.number_of_children(); ++i) + find_connected_nodes(n.child(i), visited); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool graph_is_connected ( + const T& g + ) + { + if (g.number_of_nodes() == 0) + return true; + + set::kernel_1b_c visited; + find_connected_nodes(g.node(0), visited); + return (visited.size() == g.number_of_nodes()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool graph_has_symmetric_edges ( + const T& graph + ) + { + for (unsigned long i = 0; i < graph.number_of_nodes(); ++i) + { + for (unsigned long j = 0; j < graph.node(i).number_of_children(); ++j) + { + const unsigned long jj = graph.node(i).child(j).index(); + // make sure every edge from a parent to a child has an edge linking back + if (graph.has_edge(jj,i) == false) + return false; + } + + for (unsigned long j = 0; j < graph.node(i).number_of_parents(); ++j) + { + const unsigned long jj = graph.node(i).parent(j).index(); + // make sure every edge from a child to a parent has an edge linking back + if (graph.has_edge(i,jj) == false) + return false; + } + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool graph_contains_directed_cycle ( + const T& graph + ) + { + using namespace std; + using namespace graph_helpers; + std::vector visited(graph.number_of_nodes(), false); + std::vector temp(graph.number_of_nodes(), false); + + while (true) + { + // find the first node that hasn't been visited yet + unsigned long i; + for (i = 0; i < visited.size(); ++i) + { + if (visited[i] == false) + break; + } + + // if we didn't find any non-visited nodes then we are done + if (i == visited.size()) + return false; + + if (search_for_directed_cycles(graph.node(i), visited, temp)) + return true; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool graph_contains_undirected_cycle ( + const T& graph + ) + { + using namespace std; + using namespace graph_helpers; + std::vector visited(graph.number_of_nodes(), false); + + while (true) + { + // find the first node that hasn't been visited yet + unsigned long i; + for (i = 0; i < visited.size(); ++i) + { + if (visited[i] == false) + break; + } + + // if we didn't find any non-visited nodes then we are done + if (i == visited.size()) + return false; + + if (search_for_undirected_cycles(graph.node(i), visited)) + return true; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename directed_graph_type, + typename graph_type + > + void create_moral_graph ( + const directed_graph_type& g, + graph_type& moral_graph + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(graph_contains_directed_cycle(g) == false, + "\tvoid create_moral_graph(g, moral_graph)" + << "\n\tYou can only make moral graphs if g doesn't have directed cycles" + ); + COMPILE_TIME_ASSERT(is_graph::value); + COMPILE_TIME_ASSERT(is_directed_graph::value); + + copy_graph_structure(g, moral_graph); + + // now marry all the parents (i.e. add edges between parent nodes) + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + // loop over all combinations of parents of g.node(i) + for (unsigned long j = 0; j < g.node(i).number_of_parents(); ++j) + { + for (unsigned long k = 0; k < g.node(i).number_of_parents(); ++k) + { + const unsigned long p1 = g.node(i).parent(j).index(); + const unsigned long p2 = g.node(i).parent(k).index(); + if (p1 == p2) + continue; + + if (moral_graph.has_edge(p1,p2) == false) + moral_graph.add_edge(p1,p2); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type, + typename sets_of_int + > + bool is_clique ( + const graph_type& g, + const sets_of_int& clique + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(graph_contains_length_one_cycle(g) == false, + "\tvoid is_clique(g, clique)" + << "\n\tinvalid graph" + ); +#ifdef ENABLE_ASSERTS + clique.reset(); + while (clique.move_next()) + { + const unsigned long x = clique.element(); + DLIB_ASSERT( x < g.number_of_nodes(), + "\tvoid is_clique(g, clique)" + << "\n\tthe clique set contained an invalid node index" + << "\n\tx: " << x + << "\n\tg.number_of_nodes(): " << g.number_of_nodes() + ); + } +#endif + + COMPILE_TIME_ASSERT(is_graph::value); + + std::vector v; + v.reserve(clique.size()); + clique.reset(); + while (clique.move_next()) + { + v.push_back(clique.element()); + } + + for (unsigned long i = 0; i < v.size(); ++i) + { + for (unsigned long j = 0; j < v.size(); ++j) + { + if (v[i] == v[j]) + continue; + if (g.has_edge(v[i], v[j]) == false) + return false; + } + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type, + typename sets_of_int + > + bool is_maximal_clique ( + const graph_type& g, + const sets_of_int& clique + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(graph_contains_length_one_cycle(g) == false, + "\tvoid is_maximal_clique(g, clique)" + << "\n\tinvalid graph" + ); + DLIB_ASSERT(is_clique(g,clique) == true, + "\tvoid is_maximal_clique(g, clique)" + << "\n\tinvalid graph" + ); +#ifdef ENABLE_ASSERTS + clique.reset(); + while (clique.move_next()) + { + const unsigned long x = clique.element(); + DLIB_ASSERT( x < g.number_of_nodes(), + "\tvoid is_maximal_clique(g, clique)" + << "\n\tthe clique set contained an invalid node index" + << "\n\tx: " << x + << "\n\tg.number_of_nodes(): " << g.number_of_nodes() + ); + } +#endif + + COMPILE_TIME_ASSERT(is_graph::value); + + if (clique.size() == 0) + return true; + + // get an element in the clique and make sure that + // none of its neighbors that aren't in the clique are connected + // to all the elements of the clique. + clique.reset(); + clique.move_next(); + const unsigned long idx = clique.element(); + + for (unsigned long i = 0; i < g.node(idx).number_of_neighbors(); ++i) + { + const unsigned long n = g.node(idx).neighbor(i).index(); + if (clique.is_member(n)) + continue; + + // now loop over all the clique members and make sure they don't all + // share an edge with node n + bool all_share_edge = true; + clique.reset(); + while (clique.move_next()) + { + if (g.has_edge(clique.element(), n) == false) + { + all_share_edge = false; + break; + } + } + + if (all_share_edge == true) + return false; + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename enable_if,bool>::type graph_contains_length_one_cycle ( + const T& graph + ) + { + for (unsigned long i = 0; i < graph.number_of_nodes(); ++i) + { + // make sure none of this guys children are actually itself + for (unsigned long n = 0; n < graph.node(i).number_of_children(); ++n) + { + if (graph.node(i).child(n).index() == i) + return true; + } + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename enable_if,bool>::type graph_contains_length_one_cycle ( + const T& graph + ) + { + for (unsigned long i = 0; i < graph.number_of_nodes(); ++i) + { + // make sure none of this guys neighbors are actually itself + for (unsigned long n = 0; n < graph.node(i).number_of_neighbors(); ++n) + { + if (graph.node(i).neighbor(n).index() == i) + return true; + } + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + namespace graph_helpers + { + struct pair + { + unsigned long index; + unsigned long num_neighbors; + + bool operator< (const pair& p) const { return num_neighbors < p.num_neighbors; } + }; + + template < + typename T, + typename S, + typename V + > + void search_graph_for_triangulate ( + const T& n, + S& visited, + V& order_visited + ) + { + // base case of recursion. stop when we hit a node we have + // already visited. + if (visited.is_member(n.index())) + return; + + // record that we have visited this node + order_visited.push_back(n.index()); + unsigned long temp = n.index(); + visited.add(temp); + + // we want to visit all the neighbors of this node but do + // so by visiting the nodes with the most neighbors first. So + // lets make a vector that lists the nodes in the order we + // want to visit them + std::vector neighbors; + for (unsigned long i = 0; i < n.number_of_neighbors(); ++i) + { + pair p; + p.index = i; + p.num_neighbors = n.neighbor(i).number_of_neighbors(); + neighbors.push_back(p); + } + + // now sort the neighbors array so that the neighbors with the + // most neighbors come first. + std::sort(neighbors.rbegin(), neighbors.rend()); + + // now visit all the nodes + for (unsigned long i = 0; i < neighbors.size(); ++i) + { + search_graph_for_triangulate(n.neighbor(neighbors[i].index), visited, order_visited); + } + } + } // end namespace graph_helpers + + template < + typename graph_type, + typename set_of_sets_of_int + > + void triangulate_graph_and_find_cliques ( + graph_type& g, + set_of_sets_of_int& cliques + ) + { + + // make sure requires clause is not broken + DLIB_ASSERT(graph_contains_length_one_cycle(g) == false, + "\tvoid triangulate_graph_and_find_cliques(g, cliques)" + << "\n\tInvalid graph" + ); + DLIB_ASSERT(graph_is_connected(g) == true, + "\tvoid triangulate_graph_and_find_cliques(g, cliques)" + << "\n\tInvalid graph" + ); + + COMPILE_TIME_ASSERT(is_graph::value); + + + using namespace graph_helpers; + using namespace std; + typedef typename set_of_sets_of_int::type set_of_int; + + cliques.clear(); + + // first we find the node with the most neighbors + unsigned long max_index = 0; + unsigned long num_neighbors = 0; + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + if (g.node(i).number_of_neighbors() > num_neighbors) + { + max_index = i; + num_neighbors = g.node(i).number_of_neighbors(); + } + } + + // now we do a depth first search of the entire graph starting + // with the node we just found. We record the order in which + // we visit each node in the vector order_visited. + std::vector order_visited; + set_of_int visited; + search_graph_for_triangulate(g.node(max_index), visited, order_visited); + + set_of_int clique; + + // now add edges to the graph to make it triangulated + while (visited.size() > 0) + { + // we are going to enumerate over the nodes in the reverse of the + // order in which they were visited. So get the last node out. + const unsigned long idx = order_visited.back(); + order_visited.pop_back(); + visited.destroy(idx); + + // as a start add this node to our current clique + unsigned long temp = idx; + clique.clear(); + clique.add(temp); + + // now we want to make a clique that contains node g.node(idx) and + // all of its neighbors that are still recorded in the visited set + // (except for neighbors that have only one edge). + for (unsigned long i = 0; i < g.node(idx).number_of_neighbors(); ++i) + { + // get the index of the i'th neighbor + unsigned long nidx = g.node(idx).neighbor(i).index(); + + // add it to the clique if it is still in visited and it isn't + // a node with only one neighbor + if (visited.is_member(nidx) == true && + g.node(nidx).number_of_neighbors() != 1) + { + // add edges between this new node and all the nodes + // that are already in the clique + clique.reset(); + while (clique.move_next()) + { + if (g.has_edge(nidx, clique.element()) == false) + g.add_edge(nidx, clique.element()); + } + + // now also record that we added this node to the clique + clique.add(nidx); + } + } + + if (cliques.is_member(clique) == false && is_maximal_clique(g,clique) ) + { + cliques.add(clique); + } + + // now it is possible that we are missing some cliques of size 2 since + // above we didn't add nodes with only one edge to any of our cliques. + // Now lets make sure all these nodes are accounted for + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + clique.clear(); + if (g.node(i).number_of_neighbors() == 1) + { + unsigned long temp = i; + clique.add(temp); + temp = g.node(i).neighbor(0).index(); + clique.add(temp); + + if (cliques.is_member(clique) == false) + cliques.add(clique); + } + } + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type, + typename join_tree_type + > + void create_join_tree ( + const graph_type& g, + join_tree_type& join_tree + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(graph_contains_length_one_cycle(g) == false, + "\tvoid create_join_tree(g, join_tree)" + << "\n\tInvalid graph" + ); + DLIB_ASSERT(graph_is_connected(g) == true, + "\tvoid create_join_tree(g, join_tree)" + << "\n\tInvalid graph" + ); + + COMPILE_TIME_ASSERT(is_graph::value); + COMPILE_TIME_ASSERT(is_graph::value); + + + + typedef typename join_tree_type::type set_of_int; + typedef typename join_tree_type::edge_type set_of_int_edge; + typedef typename set::kernel_1b_c set_of_sets_of_int; + + copy_graph_structure(g, join_tree); + + // don't even bother in this case + if (g.number_of_nodes() == 0) + return; + + set_of_sets_of_int cliques; + set_of_int s; + + triangulate_graph_and_find_cliques(join_tree, cliques); + + join_tree.set_number_of_nodes(cliques.size()); + + // copy the cliques into each of the nodes of tree + for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i) + { + cliques.remove_any(s); + s.swap(join_tree.node(i).data); + } + + set_of_int_edge e; + + // add all possible edges to the join_tree + for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i) + { + for (unsigned long j = i+1; j < join_tree.number_of_nodes(); ++j) + { + set_intersection( + join_tree.node(i).data, + join_tree.node(j).data, + e); + + if (e.size() > 0) + { + join_tree.add_edge(i,j); + edge(join_tree,i,j).swap(e); + } + } + } + + // now we just need to remove the unnecessary edges so that we get a + // proper join tree + s.clear(); + set_of_int& good = s; // rename s to something slightly more meaningful + // good will contain nodes that have been "approved" + unsigned long n = 0; + good.add(n); + + std::vector vtemp; + + while (good.size() < join_tree.number_of_nodes()) + { + // figure out which of the neighbors of nodes in good has the best edge + unsigned long best_bad_idx = 0; + unsigned long best_good_idx = 0; + unsigned long best_overlap = 0; + good.reset(); + while (good.move_next()) + { + // loop over all the neighbors of the current node in good + for (unsigned long i = 0; i < join_tree.node(good.element()).number_of_neighbors(); ++i) + { + const unsigned long idx = join_tree.node(good.element()).neighbor(i).index(); + if (!good.is_member(idx)) + { + const unsigned long overlap = join_tree.node(good.element()).edge(i).size(); + + if (overlap > best_overlap) + { + best_overlap = overlap; + best_bad_idx = idx; + best_good_idx = good.element(); + } + } + } + } + + // now remove all the edges from best_bad_idx to the nodes in good except for the + // edge to best_good_idx. + for (unsigned long i = 0; i < join_tree.node(best_bad_idx).number_of_neighbors(); ++i) + { + const unsigned long idx = join_tree.node(best_bad_idx).neighbor(i).index(); + if (idx != best_good_idx && good.is_member(idx)) + { + vtemp.push_back(idx); + } + } + + for (unsigned long i = 0; i < vtemp.size(); ++i) + join_tree.remove_edge(vtemp[i], best_bad_idx); + + vtemp.clear(); + + + // and finally add this bad index into the good set + good.add(best_bad_idx); + } + } + +// ---------------------------------------------------------------------------------------- + + namespace graph_helpers + { + template < + typename T, + typename U + > + bool validate_join_tree ( + const T& n, + U& deads, + unsigned long parent = 0xffffffff + ) + /*! + this function makes sure that a join tree satisfies the following criterion for paths starting at the given node: + - for all valid i and j such that i and j are both < #join_tree.number_of_nodes() + - let X be the set of numbers that is contained in both #join_tree.node(i).data + and #join_tree.node(j).data + - It is the case that all nodes on the unique path between #join_tree.node(i) + and #join_tree.node(j) contain the numbers from X in their sets. + + returns true if validation passed and false if there is a problem with the tree + !*/ + { + n.data.reset(); + while (n.data.move_next()) + { + if (deads.is_member(n.data.element())) + return false; + } + + + for (unsigned long i = 0; i < n.number_of_neighbors(); ++i) + { + if (n.neighbor(i).index() == parent) + continue; + + // add anything to dead stuff + n.data.reset(); + while (n.data.move_next()) + { + if (n.neighbor(i).data.is_member(n.data.element()) == false) + { + unsigned long temp = n.data.element(); + deads.add(temp); + } + } + + if (validate_join_tree(n.neighbor(i), deads, n.index()) == false) + return false; + + // remove this nodes stuff from dead stuff + n.data.reset(); + while (n.data.move_next()) + { + if (n.neighbor(i).data.is_member(n.data.element()) == false) + { + unsigned long temp = n.data.element(); + deads.destroy(temp); + } + } + } + + return true; + } + } + + template < + typename graph_type, + typename join_tree_type + > + bool is_join_tree ( + const graph_type& g, + const join_tree_type& join_tree + ) + { + + // make sure requires clause is not broken + DLIB_ASSERT(graph_contains_length_one_cycle(g) == false, + "\tvoid create_join_tree(g, join_tree)" + << "\n\tInvalid graph" + ); + DLIB_ASSERT(graph_is_connected(g) == true, + "\tvoid create_join_tree(g, join_tree)" + << "\n\tInvalid graph" + ); + + COMPILE_TIME_ASSERT(is_graph::value || is_directed_graph::value); + COMPILE_TIME_ASSERT(is_graph::value); + + + if (graph_contains_undirected_cycle(join_tree)) + return false; + + if (graph_is_connected(join_tree) == false) + return false; + + // verify that the path condition of the join tree is valid + for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i) + { + typename join_tree_type::type deads; + if (graph_helpers::validate_join_tree(join_tree.node(i), deads) == false) + return false; + } + + typename join_tree_type::edge_type e; + typename join_tree_type::edge_type all; + // now make sure that the edges contain correct intersections + for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i) + { + set_union(all,join_tree.node(i).data, all); + for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j) + { + set_intersection(join_tree.node(i).data, + join_tree.node(i).neighbor(j).data, + e); + + if (!(e == join_tree.node(i).edge(j))) + return false; + } + } + + // and finally check that all the nodes in g show up in the join tree + if (all.size() != g.number_of_nodes()) + return false; + all.reset(); + while (all.move_next()) + { + if (all.element() >= g.number_of_nodes()) + return false; + } + + + return true; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GRAPH_UTILs_ + + diff --git a/ml/dlib/dlib/graph_utils/graph_utils_abstract.h b/ml/dlib/dlib/graph_utils/graph_utils_abstract.h new file mode 100644 index 000000000..52e170237 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/graph_utils_abstract.h @@ -0,0 +1,452 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GRAPH_UTILs_ABSTRACT_ +#ifdef DLIB_GRAPH_UTILs_ABSTRACT_ + +#include "../directed_graph.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename T::edge_type& edge( + T& g, + unsigned long i, + unsigned long j + ); + /*! + requires + - T is an implementation of graph/graph_kernel_abstract.h + - g.has_edge(i,j) + ensures + - returns a reference to the edge data for the edge connecting nodes i and j + (i.e. returns g.node(i).edge(x) such that g.node(i).neighbor(x).index() == j) + !*/ + + template < + typename T + > + typename const T::edge_type& edge( + const T& g, + unsigned long i, + unsigned long j + ); + /*! + requires + - T is an implementation of graph/graph_kernel_abstract.h + - g.has_edge(i,j) + ensures + - returns a const reference to the edge data for the edge connecting nodes i and j + (i.e. returns g.node(i).edge(x) such that g.node(i).neighbor(x).index() == j) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename T::edge_type& edge( + T& g, + unsigned long parent_idx, + unsigned long child_idx + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - g.has_edge(parent_idx,child_idx) + ensures + - returns a reference to the edge data for the directed edge connecting parent + node g.node(parent_idx) to child node g.node(child_idx). + !*/ + + template < + typename T + > + typename const T::edge_type& edge( + const T& g, + unsigned long parent_idx, + unsigned long child_idx + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - g.has_edge(parent_idx,child_idx) + ensures + - returns a const reference to the edge data for the directed edge connecting + parent node g.node(parent_idx) to child node g.node(child_idx). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool graph_has_symmetric_edges ( + const T& graph + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + ensures + - if (All nodes have either 0 edges between them or 2 edges between them. + That is, if there is an edge pointing from node A to node B then there is + also an edge from B to A) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool graph_contains_directed_cycle ( + const T& graph + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + ensures + - if (there is a directed cycle in the given graph) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool graph_contains_undirected_cycle ( + const T& graph + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h or + T is an implementation of graph/graph_kernel_abstract.h + ensures + - if (there is an undirected cycle in the given graph) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool graph_contains_length_one_cycle ( + const T& graph + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h or + T is an implementation of graph/graph_kernel_abstract.h + ensures + - if (it is the case that graph.has_edge(i,i) == true for some i) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename src_type, + typename dest_type + > + void copy_graph_structure ( + const src_type& src, + dest_type& dest + ); + /*! + requires + - src_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or + src_type is an implementation of graph/graph_kernel_abstract.h + - dest_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or + dest_type is an implementation of graph/graph_kernel_abstract.h + - dest_type is not a directed_graph when src_type is a graph + ensures + - this function copies the graph structure from src into dest + - #dest.number_of_nodes() == src.number_of_nodes() + - for all valid i: #dest.node(i).item has an initial value for its type + - for all valid i and j: + - if (src.has_edge(i,j) == true) then + - #dest.has_edge(i,j) == true + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename src_type, + typename dest_type + > + void copy_graph ( + const src_type& src, + dest_type& dest + ); + /*! + requires + - src_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or + src_type is an implementation of graph/graph_kernel_abstract.h + - dest_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or + dest_type is an implementation of graph/graph_kernel_abstract.h + - src_type and dest_type are both the same kind of graph. That is, they + are either both directed or both undirected. + - the node and edge data in the graphs are copyable via operator=(). + ensures + - #dest is a complete duplicate of src. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename directed_graph_type, + typename graph_type + > + void create_moral_graph ( + const directed_graph_type& g, + graph_type& moral_graph + ); + /*! + requires + - directed_graph_type is an implementation of directed_graph/directed_graph_kernel_abstract.h + - graph_type is an implementation of graph/graph_kernel_abstract.h + - graph_contains_directed_cycle(g) == false + ensures + - #moral_graph == the moralized version of the directed graph g + - #moral_graph.number_of_nodes() == g.number_of_nodes() + - for all valid i and j: + - if (g.has_edge(i,j) == true) then + - #moral_graph.has_edge(i,j) == true + (i.e. all the edges that are in g are also in moral_graph) + - for all valid i: + - for all pairs p1 and p2 such that p1 != p2 and g.node(p1) and g.node(p2) are both + parents of node g.node(i): + - #moral_graph.has_edge(p1,p2) == true + (i.e. all the parents of a node are connected in the moral graph) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename S + > + void find_connected_nodes ( + const T& n, + S& visited + ); + /*! + requires + - T is a node_type from an implementation of directed_graph/directed_graph_kernel_abstract.h or + T is a node_type from an implementation of graph/graph_kernel_abstract.h + - S is an implementation of set/set_kernel_abstract.h + ensures + - let G be the graph that contains node n + - #visited.is_member(n.index()) == true + - for all i such that there is an undirected path from n to G.node(i): + - #visited.is_member(i) == true + - for all i such that visited.is_member(i): + - #visited.is_member(i) == true + (i.e. this function doesn't remove anything from visited. So if + it contains stuff when you call this function then it will still + contain those things once the function ends) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool graph_is_connected ( + const T& g + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h or + T is an implementation of graph/graph_kernel_abstract.h + ensures + - every node in g has an undirected path to every other node in g. + I.e. g is a connected graph + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type, + typename sets_of_int + > + bool is_clique ( + const graph_type& g, + const sets_of_int& clique + ); + /*! + requires + - graph_type is an implementation of graph/graph_kernel_abstract.h + - sets_of_int is an implementation of set/set_kernel_abstract.h + and it contains unsigned long objects. + - graph_contains_length_one_cycle(g) == false + - for all x such that clique.is_member(x): + - x < g.number_of_nodes() + ensures + - if (it is true that for all i and j such that clique.is_member(i) and + clique.is_member(j) then g.has_edge(i,j) == true) then + - returns true + - else + - returns false + - if (clique.size() == 0) then + - returns true + (this is just a special case of the above condition) + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type, + typename sets_of_int + > + bool is_maximal_clique ( + const graph_type& g, + const sets_of_int& clique + ); + /*! + requires + - graph_type is an implementation of graph/graph_kernel_abstract.h + - sets_of_int is an implementation of set/set_kernel_abstract.h + and it contains unsigned long objects. + - graph_contains_length_one_cycle(g) == false + - for all x such that clique.is_member(x): + - x < g.number_of_nodes() + - is_clique(g,clique) == true + ensures + - if (there is no x such that clique.is_member(x) == false + and g.has_edge(i,x) for all i such that cliques.is_member(i)) then + - returns true + - else + - returns false + - if (clique.size() == 0) then + - returns true + (this is just a special case of the above condition) + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type, + typename set_of_sets_of_int + > + void triangulate_graph_and_find_cliques ( + graph_type& g, + set_of_sets_of_int& cliques + ); + /*! + requires + - graph_type is an implementation of graph/graph_kernel_abstract.h + - set_of_sets_of_int is an implementation of set/set_kernel_abstract.h + and it contains another set object which is comparable by operator< and + itself contains unsigned long objects. + (e.g. set::compare_1a>::kernel_1a) + - graph_contains_length_one_cycle(g) == false + - graph_is_connected(g) == true + ensures + - #g.number_of_nodes() == g.number_of_nodes() + - all this function does to g is add edges to it until g becomes a + chordal graph where a chordal graph is a graph where each cycle + in the graph of 4 or more nodes has an edge joining two nodes + that are not adjacent in the cycle. + - #cliques.size() == the number of maximal cliques in the graph #g + - for all valid sets S such that #cliques.is_member(S): + - for all valid integers i and j such that S.is_member(i) == true + and S.is_member(j) == true and i != j: + - #g.has_edge(i,j) == true + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type, + typename join_tree_type + > + bool is_join_tree ( + const graph_type& g, + const join_tree_type& join_tree + ); + /*! + requires + - graph_type is an implementation of directed_graph/directed_graph_kernel_abstract.h or + graph_type is an implementation of graph/graph_kernel_abstract.h + - join_tree_type is an implementation of graph/graph_kernel_abstract.h + - join_tree_type::type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - graph_contains_length_one_cycle(g) == false + - graph_is_connected(g) == true + ensures + - if (join_tree is a valid join tree of graph g. That is, join_tree is a + tree decomposition of g) then + - returns true + - else + - returns false + + - a join tree of graph g is defined as follows: + - graph_contains_undirected_cycle(join_tree) == false + - graph_is_connected(join_tree) == true + - for all valid i: + - join_tree.node(i).item == a non-empty set containing node indexes + from g. That is, this set contains all the nodes from g that are + in this cluster in the join tree + - for all valid i and j such that i and j are both < join_tree.number_of_nodes() + - let X be the set of numbers that is contained in both join_tree.node(i).item + and join_tree.node(j).item + - It is the case that all nodes on the unique path between join_tree.node(i) + and join_tree.node(j) contain the numbers from X in their sets. + - edge(join_tree,i,j) == a set containing the intersection of + join_tree.node(i).item and join_tree.node(j).item + - the node index for every node in g appears in some node in join_tree at + least once. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type, + typename join_tree_type + > + void create_join_tree ( + const graph_type& g, + join_tree_type& join_tree + ); + /*! + requires + - graph_type is an implementation of graph/graph_kernel_abstract.h + - join_tree_type is an implementation of graph/graph_kernel_abstract.h + - join_tree_type::type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - graph_contains_length_one_cycle(g) == false + - graph_is_connected(g) == true + ensures + - #is_join_tree(g, join_tree) == true + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GRAPH_UTILs_ABSTRACT_ + diff --git a/ml/dlib/dlib/graph_utils/ordered_sample_pair.h b/ml/dlib/dlib/graph_utils/ordered_sample_pair.h new file mode 100644 index 000000000..7d510e122 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/ordered_sample_pair.h @@ -0,0 +1,125 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ORDERED_SAMPLE_PaIR_Hh_ +#define DLIB_ORDERED_SAMPLE_PaIR_Hh_ + +#include "ordered_sample_pair_abstract.h" +#include +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class ordered_sample_pair + { + public: + ordered_sample_pair( + ) : + _index1(0), + _index2(0) + { + _distance = 1; + } + + ordered_sample_pair ( + const unsigned long idx1, + const unsigned long idx2 + ) + { + _distance = 1; + _index1 = idx1; + _index2 = idx2; + } + + ordered_sample_pair ( + const unsigned long idx1, + const unsigned long idx2, + const double dist + ) + { + _distance = dist; + _index1 = idx1; + _index2 = idx2; + } + + const unsigned long& index1 ( + ) const { return _index1; } + + const unsigned long& index2 ( + ) const { return _index2; } + + const double& distance ( + ) const { return _distance; } + + private: + unsigned long _index1; + unsigned long _index2; + double _distance; + }; + +// ---------------------------------------------------------------------------------------- + + inline bool operator == ( + const ordered_sample_pair& a, + const ordered_sample_pair& b + ) + { + return a.index1() == b.index1() && a.index2() == b.index2(); + } + + inline bool operator != ( + const ordered_sample_pair& a, + const ordered_sample_pair& b + ) + { + return !(a == b); + } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const ordered_sample_pair& item, + std::ostream& out + ) + { + try + { + serialize(item.index1(),out); + serialize(item.index2(),out); + serialize(item.distance(),out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type ordered_sample_pair"); + } + } + + inline void deserialize ( + ordered_sample_pair& item, + std::istream& in + ) + { + try + { + unsigned long idx1, idx2; + double dist; + + deserialize(idx1,in); + deserialize(idx2,in); + deserialize(dist,in); + item = ordered_sample_pair(idx1, idx2, dist); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type ordered_sample_pair"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ORDERED_SAMPLE_PaIR_Hh_ + diff --git a/ml/dlib/dlib/graph_utils/ordered_sample_pair_abstract.h b/ml/dlib/dlib/graph_utils/ordered_sample_pair_abstract.h new file mode 100644 index 000000000..9d150e257 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/ordered_sample_pair_abstract.h @@ -0,0 +1,128 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ORDERED_SAMPLE_PaIR_ABSTRACT_Hh_ +#ifdef DLIB_ORDERED_SAMPLE_PaIR_ABSTRACT_Hh_ + +#include +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class ordered_sample_pair + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is intended to represent an edge in a directed graph which has + data samples at its vertices. So it contains two integers (index1 and + index2) which represent the identifying indices of the samples at the ends + of an edge. + + This object also contains a double which can be used for any purpose. + !*/ + + public: + ordered_sample_pair( + ); + /*! + ensures + - #index1() == 0 + - #index2() == 0 + - #distance() == 1 + !*/ + + ordered_sample_pair ( + const unsigned long idx1, + const unsigned long idx2 + ); + /*! + ensures + - #index1() == idx1 + - #index2() == idx2 + - #distance() == 1 + !*/ + + ordered_sample_pair ( + const unsigned long idx1, + const unsigned long idx2, + const double dist + ); + /*! + ensures + - #index1() == idx1 + - #index2() == idx2 + - #distance() == dist + !*/ + + const unsigned long& index1 ( + ) const; + /*! + ensures + - returns the first index value stored in this object + !*/ + + const unsigned long& index2 ( + ) const; + /*! + ensures + - returns the second index value stored in this object + !*/ + + const double& distance ( + ) const; + /*! + ensures + - returns the floating point number stored in this object + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + inline bool operator == ( + const ordered_sample_pair& a, + const ordered_sample_pair& b + ); + /*! + ensures + - returns a.index1() == b.index1() && a.index2() == b.index2(); + I.e. returns true if a and b both represent the same pair and false otherwise. + Note that the distance field is not involved in this comparison. + !*/ + + inline bool operator != ( + const ordered_sample_pair& a, + const ordered_sample_pair& b + ); + /*! + ensures + - returns !(a == b) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const ordered_sample_pair& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + inline void deserialize ( + ordered_sample_pair& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ORDERED_SAMPLE_PaIR_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/graph_utils/sample_pair.h b/ml/dlib/dlib/graph_utils/sample_pair.h new file mode 100644 index 000000000..88ad458f6 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/sample_pair.h @@ -0,0 +1,179 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SAMPLE_PaIR_Hh_ +#define DLIB_SAMPLE_PaIR_Hh_ + +#include "sample_pair_abstract.h" +#include +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class sample_pair + { + public: + sample_pair( + ) : + _index1(0), + _index2(0) + { + _distance = 1; + } + + sample_pair ( + const unsigned long idx1, + const unsigned long idx2 + ) + { + _distance = 1; + if (idx1 < idx2) + { + _index1 = idx1; + _index2 = idx2; + } + else + { + _index1 = idx2; + _index2 = idx1; + } + } + + sample_pair ( + const unsigned long idx1, + const unsigned long idx2, + const double dist + ) + { + _distance = dist; + if (idx1 < idx2) + { + _index1 = idx1; + _index2 = idx2; + } + else + { + _index1 = idx2; + _index2 = idx1; + } + } + + const unsigned long& index1 ( + ) const { return _index1; } + + const unsigned long& index2 ( + ) const { return _index2; } + + const double& distance ( + ) const { return _distance; } + + private: + unsigned long _index1; + unsigned long _index2; + double _distance; + }; + +// ---------------------------------------------------------------------------------------- + + template + inline bool order_by_index ( + const T& a, + const T& b + ) + { + return a.index1() < b.index1() || (a.index1() == b.index1() && a.index2() < b.index2()); + } + + template + inline bool order_by_distance ( + const T& a, + const T& b + ) + { + return a.distance() < b.distance(); + } + + template + inline bool order_by_descending_distance ( + const T& a, + const T& b + ) + { + return a.distance() > b.distance(); + } + + template + bool order_by_distance_and_index ( + const T& a, + const T& b + ) + { + return a.distance() < b.distance() || (a.distance() == b.distance() && order_by_index(a,b)); + } + +// ---------------------------------------------------------------------------------------- + + inline bool operator == ( + const sample_pair& a, + const sample_pair& b + ) + { + return a.index1() == b.index1() && a.index2() == b.index2(); + } + + inline bool operator != ( + const sample_pair& a, + const sample_pair& b + ) + { + return !(a == b); + } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const sample_pair& item, + std::ostream& out + ) + { + try + { + serialize(item.index1(),out); + serialize(item.index2(),out); + serialize(item.distance(),out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type sample_pair"); + } + } + + inline void deserialize ( + sample_pair& item, + std::istream& in + ) + { + try + { + unsigned long idx1, idx2; + double dist; + + deserialize(idx1,in); + deserialize(idx2,in); + deserialize(dist,in); + item = sample_pair(idx1, idx2, dist); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type sample_pair"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SAMPLE_PaIR_Hh_ + diff --git a/ml/dlib/dlib/graph_utils/sample_pair_abstract.h b/ml/dlib/dlib/graph_utils/sample_pair_abstract.h new file mode 100644 index 000000000..3306899e3 --- /dev/null +++ b/ml/dlib/dlib/graph_utils/sample_pair_abstract.h @@ -0,0 +1,192 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SAMPLE_PaIR_ABSTRACT_Hh_ +#ifdef DLIB_SAMPLE_PaIR_ABSTRACT_Hh_ + +#include +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class sample_pair + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is intended to represent an edge in an undirected graph + which has data samples at its vertices. So it contains two integers + (index1 and index2) which represent the identifying indices of + the samples at the ends of an edge. Note that this object enforces + the constraint that index1 <= index2. This has the effect of + making the edges undirected since a sample_pair is incapable + of representing a single edge in more than one way. That is, + sample_pair(i,j) == sample_pair(j,i) for any value of i and j. + + This object also contains a double which can be used for any purpose. + !*/ + + public: + sample_pair( + ); + /*! + ensures + - #index1() == 0 + - #index2() == 0 + - #distance() == 1 + !*/ + + sample_pair ( + const unsigned long idx1, + const unsigned long idx2 + ); + /*! + ensures + - #index1() == min(idx1,idx2) + - #index2() == max(idx1,idx2) + - #distance() == 1 + !*/ + + sample_pair ( + const unsigned long idx1, + const unsigned long idx2, + const double dist + ); + /*! + ensures + - #index1() == min(idx1,idx2) + - #index2() == max(idx1,idx2) + - #distance() == dist + !*/ + + const unsigned long& index1 ( + ) const; + /*! + ensures + - returns the first index value stored in this object + !*/ + + const unsigned long& index2 ( + ) const; + /*! + ensures + - returns the second index value stored in this object + !*/ + + const double& distance ( + ) const; + /*! + ensures + - returns the floating point number stored in this object + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + bool order_by_index ( + const T& a, + const T& b + ) { return a.index1() < b.index1() || (a.index1() == b.index1() && a.index2() < b.index2()); } + /*! + requires + - T is a type with an interface compatible with sample_pair. + ensures + - provides a total ordering of sample_pair objects that will cause pairs that are + equal to be adjacent when sorted. So for example, this function can be used + with std::sort() to first sort a sequence of sample_pair objects and then + find duplicate edges. + !*/ + + template + bool order_by_distance ( + const T& a, + const T& b + ) { return a.distance() < b.distance(); } + /*! + requires + - T is a type with an interface compatible with sample_pair. + ensures + - provides a total ordering of sample_pair objects that causes pairs with + smallest distance to be the first in a sorted list. This function can be + used with std::sort(). + !*/ + + template + bool order_by_descending_distance ( + const T& a, + const T& b + ) { return a.distance() > b.distance(); } + /*! + requires + - T is a type with an interface compatible with sample_pair. + ensures + - provides a total ordering of sample_pair objects that causes pairs with + largest distance to be the first in a sorted list. This function can be + used with std::sort(). + !*/ + + template + bool order_by_distance_and_index ( + const T& a, + const T& b + ) { return a.distance() < b.distance() || (a.distance() == b.distance() && order_by_index(a,b)); } + /*! + requires + - T is a type with an interface compatible with sample_pair. + ensures + - provides a total ordering of sample_pair objects that causes pairs with + smallest distance to be the first in a sorted list but also orders samples + with equal distances according to order_by_index(). This function can be + used with std::sort(). + !*/ + +// ---------------------------------------------------------------------------------------- + + inline bool operator == ( + const sample_pair& a, + const sample_pair& b + ); + /*! + ensures + - returns a.index1() == b.index1() && a.index2() == b.index2(); + I.e. returns true if a and b both represent the same pair and false otherwise. + Note that the distance field is not involved in this comparison. + !*/ + + inline bool operator != ( + const sample_pair& a, + const sample_pair& b + ); + /*! + ensures + - returns !(a == b) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const sample_pair& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + inline void deserialize ( + sample_pair& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SAMPLE_PaIR_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/graph_utils_threaded.h b/ml/dlib/dlib/graph_utils_threaded.h new file mode 100644 index 000000000..c9938fd80 --- /dev/null +++ b/ml/dlib/dlib/graph_utils_threaded.h @@ -0,0 +1,12 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GRAPH_UTILs_THREADED_H_ +#define DLIB_GRAPH_UTILs_THREADED_H_ + +#include "graph_utils.h" +#include "graph_utils/find_k_nearest_neighbors_lsh.h" + +#endif // DLIB_GRAPH_UTILs_THREADED_H_ + + + diff --git a/ml/dlib/dlib/gui_core.h b/ml/dlib/dlib/gui_core.h new file mode 100644 index 000000000..6ba54b11c --- /dev/null +++ b/ml/dlib/dlib/gui_core.h @@ -0,0 +1,20 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GUI_CORe_ +#define DLIB_GUI_CORe_ + + +#include "platform.h" + + + +#ifdef WIN32 +#include "gui_core/windows.h" +#else +#include "gui_core/xlib.h" +#endif + + + +#endif // DLIB_GUI_CORe_ + diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_1.cpp b/ml/dlib/dlib/gui_core/gui_core_kernel_1.cpp new file mode 100644 index 000000000..2a6efa4c9 --- /dev/null +++ b/ml/dlib/dlib/gui_core/gui_core_kernel_1.cpp @@ -0,0 +1,2204 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GUI_CORE_KERNEL_1_CPp_ +#define DLIB_GUI_CORE_KERNEL_1_CPp_ +#include "../platform.h" + +#ifdef WIN32 + +#include "gui_core_kernel_1.h" + +// tell visual studio to link to the libraries we need if we are +// in fact using visual studio +#ifdef _MSC_VER +#pragma comment (lib, "gdi32.lib") +#pragma comment (lib, "comctl32.lib") +#pragma comment (lib, "user32.lib") +#pragma comment (lib, "imm32.lib") +#endif + +#include +#include +#include +#include + +#include "../threads.h" +#include "../assert.h" +#include "../queue.h" +#include "../sync_extension.h" +#include "../queue.h" +#include "../logger.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace gui_core_kernel_1_globals + { + + + struct user_event_type + { + HWND w; + void* p; + int i; + }; + + typedef sync_extension::kernel_1a>::kernel_1a window_table_type; + typedef sync_extension::kernel_1b>::kernel_2a_c>::kernel_1a queue_of_user_events; + + + enum USER_OFFSETS + { + CREATE_WINDOW, + DESTROY_WINDOW, + SET_ACTIVE_WINDOW, + QUIT_EVENT_HANDLER_THREAD, + USER_EVENTS_READY, + CALL_MOVE_WINDOW, + SHOW_WINDOW_SHOW, + SHOW_WINDOW_HIDE, + CALL_SET_WINDOW_TITLE + }; + + // ---------------------------------------------------------------------------------------- + + const std::shared_ptr& global_mutex() + { + static std::shared_ptr m(new dlib::mutex); + return m; + } + + class event_handler_thread : public threaded_object + { + public: + + enum et_state + { + uninitialized, + initialized, + failure_to_init + }; + + et_state status; + + queue_of_user_events user_events; + queue_of_user_events user_events_temp; + logger dlog; + + HINSTANCE hInstance; + HWND helper_window; + const TCHAR* window_class_name; + + bool quit_windows_loop; + bool set_window_title_done; + std::wstring window_title; + bool move_window_done; + HWND move_window_hwnd; + int move_window_width; + int move_window_height; + int move_window_x; + int move_window_y; + bool request_new_window; + DWORD dwStyle; + HWND new_window; + bool in_ime_composition; + bool event_thread_started; + // the window_table.get_mutex() mutex locks the above 11 variables + + + // this variable holds a mapping from window handles to the base_window + // objects which represent them. Note that this objects mutex is always locked + // when inside the event loop. + window_table_type window_table; + rsignaler window_close_signaler; + rsignaler et_signaler; + + // note that this is the thread that will perform all the event + // processing. + thread_id_type event_thread_id; + + std::shared_ptr reference_to_global_mutex; + + event_handler_thread( + ) : + dlog("dlib.gui_core"), + hInstance(0), + helper_window(0), + window_class_name(TEXT ("w3049u6qc2d94thw9m34f4we0gvwa3-tgkser0-b9gm 05")), + quit_windows_loop(false), + set_window_title_done(true), + move_window_done(true), + move_window_hwnd(0), + move_window_width(0), + move_window_height(0), + move_window_x(0), + move_window_y(0), + request_new_window(false), + dwStyle(0), + new_window(0), + in_ime_composition(false), + event_thread_started(false), + window_close_signaler(window_table.get_mutex()), + et_signaler(window_table.get_mutex()), + reference_to_global_mutex(global_mutex()) + { + status = uninitialized; + } + + void start_event_thread ( + ) + /*! + we can't call this function from this objects constructor because + starting the event thread in windows involves sending messages to the + WndProc() and that requires this object to be fully constructed. + !*/ + { + + if (event_thread_started == false) + { + auto_mutex M(window_table.get_mutex()); + if (event_thread_started == false) + { + event_thread_started = true; + // start up the event handler thread + start(); + + // wait for the event thread to get up and running + while (status == uninitialized) + et_signaler.wait(); + + if (status == failure_to_init) + throw gui_error("Failed to start event thread"); + } + } + } + + ~event_handler_thread () + { + using namespace gui_core_kernel_1_globals; + + if (is_alive()) + { + if (PostMessage(helper_window,WM_USER+QUIT_EVENT_HANDLER_THREAD,0,0)==0) + { + dlog << LWARN << "Unable to schedule function for execution in event handling thread."; + // No point calling wait() here since the thread isn't going to + // terminate gracefully in this case. So we just let the program + // end as it will and hope for the best. + } + else + { + // wait for the event handler thread to terminate. + wait(); + } + } + + } + + private: + + void thread ( + ) + { + event_thread_id = get_thread_id(); + + hInstance = GetModuleHandle(NULL); + if (hInstance == NULL) + { + dlog << LFATAL << "Error gathering needed resources"; + + // signal that an error has occurred + window_table.get_mutex().lock(); + status = failure_to_init; + et_signaler.broadcast(); + window_table.get_mutex().unlock(); + return; + } + + // register the main window class + WNDCLASS wndclass ; + + wndclass.style = CS_DBLCLKS; + wndclass.lpfnWndProc = dlib::gui_core_kernel_1_globals::WndProc ; + wndclass.cbClsExtra = 0 ; + wndclass.cbWndExtra = 0 ; + wndclass.hInstance = hInstance ; + wndclass.hIcon = LoadIcon (NULL, IDI_APPLICATION) ; + wndclass.hCursor = LoadCursor (NULL, IDC_ARROW) ; + wndclass.hbrBackground = 0; + wndclass.lpszMenuName = NULL ; + wndclass.lpszClassName = window_class_name ; + + if (!RegisterClass (&wndclass)) + { + dlog << LFATAL << "Error registering window class"; + + // signal that an error has occurred + window_table.get_mutex().lock(); + status = failure_to_init; + et_signaler.broadcast(); + window_table.get_mutex().unlock(); + return; + } + + + // make the helper window that is used to trigger events in the + // event handler loop from other threads + TCHAR nothing[] = TEXT(""); + helper_window = CreateWindow(window_class_name,nothing,WS_DISABLED,0,0,0,0,HWND_MESSAGE,NULL,hInstance,NULL); + if (helper_window == NULL) + { + dlog << LFATAL << "Error gathering needed resources"; + + // signal that an error has occurred + window_table.get_mutex().lock(); + status = failure_to_init; + et_signaler.broadcast(); + window_table.get_mutex().unlock(); + return; + } + + // signal that the event thread is now up and running + window_table.get_mutex().lock(); + status = initialized; + et_signaler.broadcast(); + window_table.get_mutex().unlock(); + + // start the event handler loop. + /* + A note about this quit_windows_loop thing. If the user is holding + the mouse button down on the title bar of a window it will cause + the PostQuitMessage() function to be ignored!! This extra bool + is a work around to prevent that from happening. + */ + MSG msg; + while (GetMessage (&msg, NULL, 0, 0) && + quit_windows_loop == false) + { + TranslateMessage (&msg) ; + DispatchMessage (&msg) ; + } + } + }; + + // Do all this just to make sure global_mutex() is initialized at program start + // and thus hopefully before any threads have the chance to startup and call + // global_data() concurrently. + struct call_global_mutex { call_global_mutex() { global_mutex(); } }; + static call_global_mutex call_global_mutex_instance; + + // Note that we need to use dlib::shared_ptr_thread_safe here rather than + // std::shared_ptr. This is because the destructor of event_handler_thread + // triggers an event in the main event loop telling it to stop and that event loop + // holds a shared pointer to the event_handler_thread. So what can happen is + // global_data()'s shared pointer gets destructed because the program is + // terminating, which causes the event_handler_thread destructor to run, which + // eventually causes the event loop to ask global_data() for a handle to the event + // thread. This is bad for std::shared_ptr since (at least as of 2017) most + // implementations of std::shared_ptr decrement the reference count to 0 before + // invoking the event_handler_thread's destructor and then when the event thread + // calls global_data() it increments the counter again, then decrements it back to + // 0, triggering a double deletion, when the event handler routine finally + // finishes. + // + // dlib::shared_ptr_thread_safe doesn't have this problem. An alternative would be + // to somehow avoid this kind of self reference. But it's not obvious how to do + // that given the limitations of the Win32 event WndProc() structure imposed by + // windows. So in any case, we just use the old dlib::shared_ptr_thread_safe to + // avoid this problem. + const shared_ptr_thread_safe& global_data() + { + auto_mutex M(*global_mutex()); + static shared_ptr_thread_safe p; + if (p.get() == 0) + { + p.reset(new event_handler_thread()); + M.unlock(); + p->start_event_thread(); + } + return p; + } + + // ---------------------------------------------------------------------------------------- + + struct ebh_param + { + std::string text; + std::string title; + }; + + static void error_box_helper(void* param) + { + ebh_param& p = *static_cast(param); +#ifdef UNICODE + MessageBox (NULL, convert_mbstring_to_wstring(p.text).c_str(), + convert_mbstring_to_wstring(p.title).c_str(), MB_OK|MB_ICONERROR|MB_SYSTEMMODAL + ); +#else + MessageBox (NULL, p.text.c_str(), + p.title.c_str(), MB_OK|MB_ICONERROR|MB_SYSTEMMODAL + ); +#endif + delete &p; + } + + static void error_box ( + const char* title, + const char* text, + bool nonblocking = false + ) + { + try + { + if (nonblocking) + { + ebh_param* param = new ebh_param; + param->text = text; + param->title = title; + dlib::create_new_thread(error_box_helper,param); + } + else + { +#ifdef UNICODE + MessageBox (NULL, convert_mbstring_to_wstring(text).c_str(), + convert_mbstring_to_wstring(title).c_str(), + MB_OK|MB_ICONERROR|MB_SYSTEMMODAL + ); +#else + MessageBox (NULL, text, + title, MB_OK|MB_ICONERROR|MB_SYSTEMMODAL + ); +#endif + } + } + catch (...) + { + // we are totally screwed if this happens so just quit + exit(0); + } + } + + // ---------------------------------------------------------------------------------------- + + static bool map_keys ( + unsigned long keycode, + bool shift, + bool caps, + unsigned long& result, + bool& is_printable + ) + /*! + requires + - if (shift was down for this key) then + - shift == true + - if (caps lock was on for this key) then + - caps == true + - keycode == the keycode from windows that we are to process + - keycode < keyboard_keys_size + ensures + - if (this key should be ignored) then + - returns false + - else + - returns true + - #is_printable == true if result is a printable ascii character + - #result == the keycode converted into the proper number to tbe + returned by the event handler. + !*/ + { + is_printable = true; + + if (keycode <= '9' && keycode >= '0') + { + result = keycode; + if (shift) + { + switch (result) + { + case '0': result = ')'; break; + case '1': result = '!'; break; + case '2': result = '@'; break; + case '3': result = '#'; break; + case '4': result = '$'; break; + case '5': result = '%'; break; + case '6': result = '^'; break; + case '7': result = '&'; break; + case '8': result = '*'; break; + case '9': result = '('; break; + } + } + } + else if (keycode <= 'Z' && keycode >= 'A') + { + result = keycode; + + // make the result lower case if we need to. + if ((shift && caps) || (!caps && !shift)) + result = result - 'A' + 'a'; + } + else + { + switch (keycode) + { + case VK_BACK: + is_printable = false; + result = base_window::KEY_BACKSPACE; + break; + + case VK_SHIFT: + is_printable = false; + result = base_window::KEY_SHIFT; + break; + + case VK_CONTROL: + is_printable = false; + result = base_window::KEY_CTRL; + break; + + case VK_MENU: + is_printable = false; + result = base_window::KEY_ALT; + break; + + case VK_PAUSE: + is_printable = false; + result = base_window::KEY_PAUSE; + break; + + case VK_CAPITAL: + is_printable = false; + result = base_window::KEY_CAPS_LOCK; + break; + + case VK_ESCAPE: + is_printable = false; + result = base_window::KEY_ESC; + break; + + case VK_PRIOR: + is_printable = false; + result = base_window::KEY_PAGE_UP; + break; + + case VK_NEXT: + is_printable = false; + result = base_window::KEY_PAGE_DOWN; + break; + + case VK_END: + is_printable = false; + result = base_window::KEY_END; + break; + + case VK_HOME: + is_printable = false; + result = base_window::KEY_HOME; + break; + + case VK_LEFT: + is_printable = false; + result = base_window::KEY_LEFT; + break; + + case VK_RIGHT: + is_printable = false; + result = base_window::KEY_RIGHT; + break; + + case VK_UP: + is_printable = false; + result = base_window::KEY_UP; + break; + + case VK_DOWN: + is_printable = false; + result = base_window::KEY_DOWN; + break; + + case VK_INSERT: + is_printable = false; + result = base_window::KEY_INSERT; + break; + + case VK_DELETE: + is_printable = false; + result = base_window::KEY_DELETE; + break; + + case 0x91: + is_printable = false; + result = base_window::KEY_SCROLL_LOCK; + break; + + case VK_F1: + is_printable = false; + result = base_window::KEY_F1; + break; + + case VK_F2: + is_printable = false; + result = base_window::KEY_F2; + break; + + case VK_F3: + is_printable = false; + result = base_window::KEY_F3; + break; + + case VK_F4: + is_printable = false; + result = base_window::KEY_F4; + break; + + case VK_F5: + is_printable = false; + result = base_window::KEY_F5; + break; + + case VK_F6: + is_printable = false; + result = base_window::KEY_F6; + break; + + case VK_F7: + is_printable = false; + result = base_window::KEY_F7; + break; + + case VK_F8: + is_printable = false; + result = base_window::KEY_F8; + break; + + case VK_F9: + is_printable = false; + result = base_window::KEY_F9; + break; + + case VK_F10: + is_printable = false; + result = base_window::KEY_F10; + break; + + case VK_F11: + is_printable = false; + result = base_window::KEY_F11; + break; + + case VK_F12: + is_printable = false; + result = base_window::KEY_F12; + break; + + + case VK_SPACE: result = ' '; break; + case VK_TAB: result = '\t'; break; + case VK_RETURN: result = '\n'; break; + case VK_NUMPAD0: result = '0'; break; + case VK_NUMPAD1: result = '1'; break; + case VK_NUMPAD2: result = '2'; break; + case VK_NUMPAD3: result = '3'; break; + case VK_NUMPAD4: result = '4'; break; + case VK_NUMPAD5: result = '5'; break; + case VK_NUMPAD6: result = '6'; break; + case VK_NUMPAD7: result = '7'; break; + case VK_NUMPAD8: result = '8'; break; + case VK_NUMPAD9: result = '9'; break; + + case VK_MULTIPLY: result = '*'; break; + case VK_ADD: result = '+'; break; + case VK_SUBTRACT: result = '-'; break; + case VK_DECIMAL: result = '.'; break; + case VK_DIVIDE: result = '/'; break; + + case VK_OEM_1: + if (shift) result = ':'; + else result = ';'; + break; + + case VK_OEM_PLUS: + if (shift) result = '+'; + else result = '='; + break; + + case VK_OEM_COMMA: + if (shift) result = '<'; + else result = ','; + break; + + case VK_OEM_MINUS: + if (shift) result = '_'; + else result = '-'; + break; + + case VK_OEM_PERIOD: + if (shift) result = '>'; + else result = '.'; + break; + + case VK_OEM_2: + if (shift) result = '?'; + else result = '/'; + break; + + case VK_OEM_3: + if (shift) result = '~'; + else result = '`'; + break; + + case VK_OEM_4: + if (shift) result = '{'; + else result = '['; + break; + + case VK_OEM_5: + if (shift) result = '|'; + else result = '\\'; + break; + + case VK_OEM_6: + if (shift) result = '}'; + else result = ']'; + break; + + case VK_OEM_7: + if (shift) result = '"'; + else result = '\''; + break; + + default: + return false; + } + } + + return true; + } + + // ------------------------------------------------------------------------------------ + + LRESULT CALLBACK WndProc ( + HWND hwnd, + UINT message, + WPARAM wParam, + LPARAM lParam + ) + { + using namespace gui_core_kernel_1_globals; + // Make the event processing thread have a priority slightly above normal. + // This makes the GUI smother if you do heavy processing in other threads. + HANDLE hand = OpenThread(THREAD_ALL_ACCESS,FALSE,GetCurrentThreadId()); + SetThreadPriority(hand,THREAD_PRIORITY_ABOVE_NORMAL); + CloseHandle(hand); + + auto globals = global_data(); + + window_table_type& window_table = globals->window_table; + HWND& helper_window = globals->helper_window; + + auto_mutex M(window_table.get_mutex()); + + try + { + std::vector bitmap_buffer; + + bool is_double = false; + unsigned long btn = base_window::NONE; + + switch (message) + { + case WM_USER+QUIT_EVENT_HANDLER_THREAD: + if (hwnd == helper_window) + { + globals->quit_windows_loop = true; + PostQuitMessage(0); + } + return 0; + + case WM_USER+DESTROY_WINDOW: + if (hwnd == helper_window) + { + DestroyWindow((HWND)wParam); + } + return 0; + + case WM_USER+CALL_MOVE_WINDOW: + if (hwnd == helper_window) + { + MoveWindow( + globals->move_window_hwnd, + globals->move_window_x, + globals->move_window_y, + globals->move_window_width, + globals->move_window_height, + TRUE); + globals->move_window_done = true; + globals->et_signaler.broadcast(); + } + return 0; + + case WM_USER+USER_EVENTS_READY: + if (hwnd == helper_window) + { + // this is the signal to look in the user_events queue + globals->user_events.lock(); + globals->user_events.swap(globals->user_events_temp); + globals->user_events.unlock(); + globals->user_events_temp.reset(); + // now dispatch all these user events + while (globals->user_events_temp.move_next()) + { + base_window** win_ = window_table[globals->user_events_temp.element().w]; + base_window* win; + // if this window exists in the window table then dispatch + // its event. + if (win_) + { + win = *win_; + win->on_user_event( + globals->user_events_temp.element().p, + globals->user_events_temp.element().i + ); + } + } + globals->user_events_temp.clear(); + } + return 0; + + case WM_USER+SET_ACTIVE_WINDOW: + if (hwnd == helper_window) + { + SetActiveWindow((HWND)wParam); + } + return 0; + + case WM_USER+SHOW_WINDOW_SHOW: + if (hwnd == helper_window) + { + ShowWindow((HWND)wParam,SW_SHOW); + BringWindowToTop((HWND)wParam); + } + return 0; + + case WM_USER+SHOW_WINDOW_HIDE: + if (hwnd == helper_window) + { + ShowWindow((HWND)wParam,SW_HIDE); + } + return 0; + + case WM_USER+CALL_SET_WINDOW_TITLE: + if (hwnd == helper_window) + { + SetWindowTextW((HWND)wParam,globals->window_title.c_str()); + globals->set_window_title_done = true; + globals->et_signaler.broadcast(); + } + return 0; + + + case WM_USER+CREATE_WINDOW: + if (hwnd == helper_window) + { + + // if this is stupposed to be a popup window then do the popup window thing + if (globals->dwStyle == WS_CHILD) + { + TCHAR nothing[] = TEXT(""); + globals->new_window = CreateWindowEx (WS_EX_TOOLWINDOW,globals->window_class_name, nothing, + globals->dwStyle, + CW_USEDEFAULT, CW_USEDEFAULT, + CW_USEDEFAULT, CW_USEDEFAULT, + helper_window, NULL, globals->hInstance, NULL); + SetParent(globals->new_window,NULL); + } + else + { + TCHAR nothing[] = TEXT(""); + globals->new_window = CreateWindow (globals->window_class_name, nothing, + globals->dwStyle, + CW_USEDEFAULT, CW_USEDEFAULT, + CW_USEDEFAULT, CW_USEDEFAULT, + NULL, NULL, globals->hInstance, NULL); + } + // use the helper_window to indicate that CreateWindow failed + if (globals->new_window == NULL) + globals->new_window = helper_window; + globals->et_signaler.broadcast(); + } + return 0; + + case WM_SYSKEYDOWN: + case WM_KEYDOWN: + { + if (globals->in_ime_composition) break; + + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + unsigned long state = 0; + + bool shift = ((GetKeyState(VK_SHIFT)&0x8000)!=0); + bool ctrl = ((GetKeyState(VK_CONTROL)&0x8000)!=0); + bool caps = ((GetKeyState(VK_CAPITAL)&0x0001)!=0); + if(shift) + state = base_window::KBD_MOD_SHIFT; + if(ctrl) + state |= base_window::KBD_MOD_CONTROL; + if(caps) + state |= base_window::KBD_MOD_CAPS_LOCK; + if((GetKeyState(VK_MENU)&0x8000)!=0) + state |= base_window::KBD_MOD_ALT; + if((GetKeyState(VK_NUMLOCK)&0x0001)!=0) + state |= base_window::KBD_MOD_NUM_LOCK; + if((GetKeyState(VK_SCROLL)&0x0001)!=0) + state |= base_window::KBD_MOD_SCROLL_LOCK; + + + bool is_printable; + unsigned long result; + + if (map_keys(wParam,shift,caps,result,is_printable)) + { + // signal the keyboard event + win->on_keydown(result,is_printable,state); + } + + } + break; + + // treat the user releasing the mouse button on the non client area (e.g. the title bar) + // like focus being lost since that is what X11 does + case WM_NCLBUTTONUP: + case WM_NCMBUTTONUP: + case WM_NCRBUTTONUP: + case WM_SETFOCUS: + { + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + // signal that the window is gaining focus + win->on_focus_gained(); + } + break; + + // treat the user clicking on the non client area (e.g. the title bar) + // like focus being lost since that is what X11 does + case WM_NCLBUTTONDBLCLK: + case WM_NCMBUTTONDBLCLK: + case WM_NCRBUTTONDBLCLK: + case WM_NCLBUTTONDOWN: + case WM_NCMBUTTONDOWN: + case WM_NCRBUTTONDOWN: + case WM_KILLFOCUS: + { + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + // signal that the window is gaining focus + win->on_focus_lost(); + } + break; + + case WM_SIZE: + { + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + + // signal that the window has been resized + win->on_window_resized(); + + } + return 0; + + case WM_MOVE: + { + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + + // signal that the window has moved + win->on_window_moved(); + + } + return 0; + + case WM_MOUSELEAVE: + { + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + + // signal that the mouse has left the window + if (win->mouse_in) + { + win->on_mouse_leave(); + win->mouse_in = false; + } + + } + return 0; + + case WM_MOUSEWHEEL: + { + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + unsigned long state = 0; + if (wParam & MK_CONTROL) + state |= base_window::CONTROL; + if (wParam & MK_LBUTTON) + state |= base_window::LEFT; + if (wParam & MK_MBUTTON) + state |= base_window::MIDDLE; + if (wParam & MK_RBUTTON) + state |= base_window::RIGHT; + if (wParam & MK_SHIFT) + state |= base_window::SHIFT; + + // signal the mouse wheel event + if (GET_WHEEL_DELTA_WPARAM(wParam) > 0) + { + win->on_wheel_up(state); + } + else + { + win->on_wheel_down(state); + } + + } + return 0; + + case WM_LBUTTONUP: + btn = base_window::LEFT; + case WM_MBUTTONUP: + if (btn == base_window::NONE) + btn = base_window::MIDDLE; + case WM_RBUTTONUP: + if (btn == base_window::NONE) + btn = base_window::RIGHT; + { + // release the mouse capture if the user isn't holding any + // other mouse buttons + if (!((wParam & MK_LBUTTON) | (wParam & MK_MBUTTON) | (wParam & MK_RBUTTON))) + ReleaseCapture(); + + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + unsigned long state = 0; + if (wParam & MK_CONTROL) + state |= base_window::CONTROL; + if (wParam & MK_LBUTTON) + state |= base_window::LEFT; + if (wParam & MK_MBUTTON) + state |= base_window::MIDDLE; + if (wParam & MK_RBUTTON) + state |= base_window::RIGHT; + if (wParam & MK_SHIFT) + state |= base_window::SHIFT; + + // remove the clicked button from the state + state &= (~btn); + + // signal the mouse click + win->on_mouse_up(btn,state,GET_X_LPARAM(lParam),GET_Y_LPARAM(lParam)); + + } + return 0; + + + + case WM_LBUTTONDBLCLK: + if (btn == base_window::NONE) + btn = base_window::LEFT; + case WM_MBUTTONDBLCLK: + if (btn == base_window::NONE) + btn = base_window::MIDDLE; + case WM_RBUTTONDBLCLK: + if (btn == base_window::NONE) + btn = base_window::RIGHT; + is_double = true; + case WM_LBUTTONDOWN: + if (btn == base_window::NONE) + btn = base_window::LEFT; + case WM_MBUTTONDOWN: + if (btn == base_window::NONE) + btn = base_window::MIDDLE; + case WM_RBUTTONDOWN: + if (btn == base_window::NONE) + btn = base_window::RIGHT; + { + SetCapture(hwnd); + + + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + unsigned long state = 0; + if (wParam & MK_CONTROL) + state |= base_window::CONTROL; + if (wParam & MK_LBUTTON) + state |= base_window::LEFT; + if (wParam & MK_MBUTTON) + state |= base_window::MIDDLE; + if (wParam & MK_RBUTTON) + state |= base_window::RIGHT; + if (wParam & MK_SHIFT) + state |= base_window::SHIFT; + + // remove the clicked button from the state + state &= (~btn); + + // signal the mouse click + win->on_mouse_down(btn,state,GET_X_LPARAM(lParam),GET_Y_LPARAM(lParam),is_double); + + } + return 0; + + case WM_MOUSEMOVE: + { + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + unsigned long state = 0; + bool mouse_button_down = false; + if (wParam & MK_CONTROL) + state |= base_window::CONTROL; + if (wParam & MK_LBUTTON) + { + state |= base_window::LEFT; + mouse_button_down = true; + } + if (wParam & MK_MBUTTON) + { + mouse_button_down = true; + state |= base_window::MIDDLE; + } + if (wParam & MK_RBUTTON) + { + state |= base_window::RIGHT; + mouse_button_down = true; + } + if (wParam & MK_SHIFT) + state |= base_window::SHIFT; + + // signal the mouse movement if this mouse event isn't identical to the + // last one we got + if ( GET_X_LPARAM(lParam) != win->prevx || + GET_Y_LPARAM(lParam) != win->prevy || + state != win->prev_state) + { + win->on_mouse_move(state,GET_X_LPARAM(lParam),GET_Y_LPARAM(lParam)); + } + + // save the event data into the prev* member variables + win->prevx = GET_X_LPARAM(lParam); + win->prevy = GET_Y_LPARAM(lParam); + win->prev_state = state; + + // The following block of code checks if the mouse is moving + // into or out of the window. + if (mouse_button_down == false) + { + // if there isn't any mouse button down then the fact that + // we are getting a mouse move message means it is in the + // window + if (win->mouse_in == false) + { + win->on_mouse_enter(); + win->mouse_in = true; + + // set the tracker for the mouse + TRACKMOUSEEVENT tm; + tm.hwndTrack = hwnd; + tm.cbSize = sizeof(tm); + tm.dwFlags = TME_LEAVE; + _TrackMouseEvent(&tm); + } + } + else if (win->mouse_in) + { + // check if the mouse is currently outside the window + const long mouse_x = GET_X_LPARAM(lParam); + const long mouse_y = GET_Y_LPARAM(lParam); + if (mouse_x < 0 || mouse_y < 0) + { + // the mouse is not in the window + win->mouse_in = false; + win->on_mouse_leave(); + } + else + { + unsigned long width, height; + win->get_size(width,height); + if (mouse_x >= static_cast(width) || + mouse_y >= static_cast(height)) + { + // the mouse is not in the window + win->mouse_in = false; + win->on_mouse_leave(); + } + } + } + else if (win->mouse_in == false) + { + // at this point we know that the mouse is moving around + // with some of its buttons down. So it might be outside the window. + // get the window size and see if the mouse is outside + // it. + const long mouse_x = GET_X_LPARAM(lParam); + const long mouse_y = GET_Y_LPARAM(lParam); + unsigned long width, height; + win->get_size(width,height); + if (mouse_x < static_cast(width) && + mouse_y < static_cast(height) && + mouse_x >= 0 && + mouse_y >= 0) + { + // The mouse has gone inside the window + win->mouse_in = true; + win->on_mouse_enter(); + + // set the tracker for the mouse + TRACKMOUSEEVENT tm; + tm.hwndTrack = hwnd; + tm.cbSize = sizeof(tm); + tm.dwFlags = TME_LEAVE; + _TrackMouseEvent(&tm); + } + + } + + + } + return 0; + + case WM_PAINT : + { + + PAINTSTRUCT ps; + HDC hdc = NULL; + + hdc = BeginPaint (hwnd, &ps) ; + + try + { + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + + + + LONG x = ps.rcPaint.left; + LONG y = ps.rcPaint.top; + LONG width = ps.rcPaint.right - x; + LONG height = ps.rcPaint.bottom - y; + + if (width != 0 && height != 0) + { + + BITMAPINFO bmap_info; + bmap_info.bmiColors[0].rgbBlue = 0; + bmap_info.bmiColors[0].rgbGreen = 0; + bmap_info.bmiColors[0].rgbRed = 0; + bmap_info.bmiColors[0].rgbReserved = 0; + bmap_info.bmiHeader.biSize = sizeof(bmap_info.bmiHeader); + bmap_info.bmiHeader.biWidth = width; + bmap_info.bmiHeader.biHeight = -1*height; + bmap_info.bmiHeader.biPlanes = 1; + bmap_info.bmiHeader.biBitCount = 24; + bmap_info.bmiHeader.biCompression = BI_RGB; + bmap_info.bmiHeader.biSizeImage = 0; + bmap_info.bmiHeader.biXPelsPerMeter = 0; + bmap_info.bmiHeader.biYPelsPerMeter = 0; + bmap_info.bmiHeader.biClrUsed = 0; + bmap_info.bmiHeader.biClrImportant = 0; + + + + unsigned char* bitmap ; + unsigned long size; + unsigned long padding = 0; + if ((width*3)%sizeof(LONG) != 0) + { + padding = sizeof(LONG) - (width*3)%sizeof(LONG); + size = (width*3+padding)*height; + } + else + { + size = width*height*3; + } + + if (bitmap_buffer.size() < size) + bitmap_buffer.resize(size); + bitmap = &bitmap_buffer[0]; + + canvas bits(bitmap,padding,x,y,x+width-1,y+height-1); + + + + win->paint(bits); + + + + SetDIBitsToDevice ( + hdc, + ps.rcPaint.left, + ps.rcPaint.top, + width, + height, + 0, + 0, + 0, + height, + bitmap, + &bmap_info, + DIB_RGB_COLORS + ); + } + + EndPaint (hwnd, &ps) ; + + } + catch (...) + { + // make sure EndPaint is called even if an exception + // is thrown. + if (hdc != NULL) + EndPaint (hwnd, &ps); + throw; + } + } + return 0 ; + + case WM_ERASEBKGND: + return 1; + + + + + case WM_CLOSE: + { + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + + + // signal that the window is being closed + if (win->on_window_close() == base_window::DO_NOT_CLOSE_WINDOW) + { + DLIB_ASSERT(win->has_been_destroyed == false, + "\tYou called close_window() inside the on_window_close() event but" + << "\n\tthen returned DO_NOT_CLOSE_WINDOW. You can do one or the other but not both." + << "\n\tthis: " << win + ); + // this happens if the on_window_close() callback + // tells us to ignore the close event. + return 0; + } + else + { + if (window_table[hwnd]) + { + window_table.destroy(hwnd); + win->has_been_destroyed = true; + win->hwnd = 0; + globals->window_close_signaler.broadcast(); + } + else + { + // in this case the window must have self destructed by + // calling delete this; + return 0; + } + } + + } + return DefWindowProc (hwnd, message, wParam, lParam); + + case WM_IME_STARTCOMPOSITION: + globals->in_ime_composition = true; + break; + + case WM_IME_COMPOSITION: + { + globals->in_ime_composition = false; + base_window** win_ = window_table[hwnd]; + base_window* win; + if (win_) + win = *win_; + else + break; + HIMC hImc = ImmGetContext(hwnd); + if (lParam & GCS_RESULTSTR){ + WCHAR wc; + LONG bufbyte = ImmGetCompositionStringW(hImc, GCS_RESULTSTR, &wc, 0); + if (bufbyte != IMM_ERROR_NODATA && bufbyte != IMM_ERROR_GENERAL){ + bufbyte += sizeof(WCHAR); + + + WCHAR *buf = new WCHAR[bufbyte / sizeof(WCHAR)]; + ImmGetCompositionStringW(hImc, GCS_RESULTSTR, buf, bufbyte); + buf[bufbyte / sizeof(WCHAR) - 1] = L'\0'; + + // signal the putstring event + win->on_string_put(std::wstring(buf)); + delete [] buf; + } + } + ImmReleaseContext(hwnd, hImc); + } + break; + + default: + break; + + } // switch (message) + + + } + catch (std::exception& e) + { + error_box("Exception thrown in event handler",e.what()); + globals->quit_windows_loop = true; + } + catch (...) + { + error_box("Exception thrown in event handler","Unknown Exception type."); + globals->quit_windows_loop = true; + } + + return DefWindowProc (hwnd, message, wParam, lParam) ; + + } + + // ---------------------------------------------------------------------------------------- + + void show_window ( + HWND hwnd + ) + { + using namespace gui_core_kernel_1_globals; + PostMessage(global_data()->helper_window,WM_USER+SHOW_WINDOW_SHOW,(WPARAM)hwnd,0); + } + + // ---------------------------------------------------------------------------------------- + + void hide_window ( + HWND hwnd + ) + { + using namespace gui_core_kernel_1_globals; + PostMessage(global_data()->helper_window,WM_USER+SHOW_WINDOW_HIDE,(WPARAM)hwnd,0); + } + + // ---------------------------------------------------------------------------------------- + + void give_window_focus ( + HWND hwnd + ) + /*! + ensures + - calls SetActiveWindow(hwnd) from the event handling thread. + !*/ + { + using namespace gui_core_kernel_1_globals; + PostMessage(global_data()->helper_window,WM_USER+SET_ACTIVE_WINDOW,(WPARAM)hwnd,0); + } + + // ---------------------------------------------------------------------------------------- + + void destroy_window ( + HWND hwnd + ) + /*! + ensures + - calls DestroyWindow(hwnd) from the event handling thread. + !*/ + { + using namespace gui_core_kernel_1_globals; + PostMessage(global_data()->helper_window,WM_USER+DESTROY_WINDOW,(WPARAM)hwnd,0); + } + + // ---------------------------------------------------------------------------------------- + + HWND make_window ( + DWORD dwStyle_ + ) + /*! + ensures + - creates a window by calling CreateWindow and passes on the + dwStyle argument. + - returns the HWND that is returned by CreateWindow + - ensures that CreateWindow is called from the event handler thread + - if (it was unable to create a window) then + - returns NULL or helper_window + !*/ + { + using namespace gui_core_kernel_1_globals; + auto globals = global_data(); + // if we are running in the event handling thread then just call + // CreateWindow directly + if (get_thread_id() == globals->event_thread_id) + { + // if this is stupposed to be a popup window then do the popup window thing + if (dwStyle_ == WS_CHILD) + { + TCHAR nothing[] = TEXT(""); + HWND tmp = CreateWindowEx (WS_EX_TOOLWINDOW|WS_EX_TOPMOST, globals->window_class_name, nothing, + dwStyle_, + CW_USEDEFAULT, CW_USEDEFAULT, + CW_USEDEFAULT, CW_USEDEFAULT, + globals->helper_window, NULL, globals->hInstance, NULL); + SetParent(tmp,NULL); + return tmp; + } + else + { + TCHAR nothing[] = TEXT(""); + return CreateWindow (globals->window_class_name, nothing, + dwStyle_, + CW_USEDEFAULT, CW_USEDEFAULT, + CW_USEDEFAULT, CW_USEDEFAULT, + NULL, NULL, globals->hInstance, NULL); + } + } + else + { + auto_mutex M(globals->window_table.get_mutex()); + // wait for our chance to make a new window request + while (globals->request_new_window) + globals->et_signaler.wait(); + + + globals->dwStyle = dwStyle_; + if (PostMessage(globals->helper_window,WM_USER+CREATE_WINDOW,0,0)==0) + { + throw gui_error("Unable to schedule function for execution in event handling thread."); + } + + // wait for our request to be serviced + while (globals->new_window == NULL) + globals->et_signaler.wait(); + + HWND temp = globals->new_window; + globals->new_window = NULL; + globals->request_new_window = false; + globals->et_signaler.broadcast(); + + // if make_window() returns the helper_window then it means it failed + // to make a new window + if (temp == globals->helper_window) + temp = NULL; + + return temp; + } + } + + // ------------------------------------------------------------------------------------ + + + } // end namespace gui_core_kernel_1_globals + +// ---------------------------------------------------------------------------------------- + + void canvas:: + fill ( + unsigned char red_, + unsigned char green_, + unsigned char blue_ + ) const + { + const unsigned long red = red_; + const unsigned long green = green_; + const unsigned long blue = blue_; + + const LONG block1 = (blue<<24) | (red<<16) | (green<<8) | blue; + const LONG block2 = (green<<24) | (blue<<16) | (red<<8) | green; + const LONG block3 = (red<<24) | (green<<16) | (blue<<8) | red; + + // remember that row_width is a multiple of 4 because windows + // requires that all bitmaps have row widths that are multiples of 4. + unsigned long size = row_width/4; + for (unsigned long i = 0; i < height_; ++i) + { + unsigned long padding = size%3; + LONG* start = reinterpret_cast(bits+row_width*i); + LONG* end = reinterpret_cast(start) + size - padding; + while (start != end) + { + *start = block1; + ++start; + *start = block2; + ++start; + *start = block3; + ++start; + } + if (padding) + { + *start = block1; + ++start; + --padding; + } + if (padding) + { + *start = block2; + } + } + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + trigger_user_event ( + void* p, + int i + ) + { + using namespace gui_core_kernel_1_globals; + + user_event_type e; + e.w = hwnd; + e.p = p; + e.i = i; + { + auto_mutex M(globals->user_events.get_mutex()); + globals->user_events.enqueue(e); + } + + if (PostMessage(globals->helper_window,WM_USER+USER_EVENTS_READY,0,0)==0) + { + throw gui_error("Unable to schedule function for execution in event handling thread."); + } + } + +// ---------------------------------------------------------------------------------------- + + base_window:: + base_window ( + bool resizable, + bool undecorated + ) : + globals(gui_core_kernel_1_globals::global_data()), + has_been_destroyed(false), + prevx(-1), + prevy(-1), + prev_state(0), + wm(globals->window_table.get_mutex()) + { + using namespace gui_core_kernel_1_globals; + DLIB_ASSERT(!(undecorated == true && resizable == true), + "\tbase_window::base_window()" + << "\n\tThere is no such thing as an undecorated window that is resizable by the user." + << "\n\tthis: " << this + ); + + if (resizable) + style = WS_OVERLAPPEDWINDOW; + else if (undecorated) + style = WS_CHILD; + else + style = WS_OVERLAPPEDWINDOW ^ WS_THICKFRAME ^ WS_MAXIMIZEBOX; + + hwnd = gui_core_kernel_1_globals::make_window(style); + + if (hwnd == NULL) + throw gui_error("unable to create base_window"); + + auto_mutex M(wm); + + mouse_in = false; + + HWND temp = hwnd; + base_window* ttemp = this; + globals->window_table.add(temp,ttemp); + } + +// ---------------------------------------------------------------------------------------- + + base_window:: + ~base_window ( + ) + { + using namespace gui_core_kernel_1_globals; + close_window(); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + close_window ( + ) + { + using namespace gui_core_kernel_1_globals; + auto_mutex M(wm); + if (has_been_destroyed == false) + { + // do this just to make sure no one tries to call this window's + // calbacks. + globals->window_table.destroy(hwnd); + gui_core_kernel_1_globals::destroy_window(hwnd); + hwnd = 0; + has_been_destroyed = true; + globals->window_close_signaler.broadcast(); + } + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + wait_until_closed ( + ) const + { + using namespace gui_core_kernel_1_globals; + auto_mutex M(wm); + while (has_been_destroyed == false) + globals->window_close_signaler.wait(); + } + +// ---------------------------------------------------------------------------------------- + + bool base_window:: + is_closed ( + ) const + { + auto_mutex M(wm); + return has_been_destroyed; + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + set_title ( + const std::string &title + ) + { + set_title(convert_mbstring_to_wstring(title)); + } + + void base_window:: + set_title ( + const ustring &title + ) + { + set_title(convert_utf32_to_wstring(title)); + } + + void base_window:: + set_title ( + const std::wstring& title + ) + { + using namespace gui_core_kernel_1_globals; + + auto_mutex M(wm); + if (has_been_destroyed == true) + return; + + + // call the SetWindowText function with our arguments but make sure it is from + // the event thread. We have to do this because the SetWindowText() apparently blocks + // until something happens in the event thread so we have to + // do this to avoid possible deadlocks. + if (get_thread_id() == globals->event_thread_id) + { + SetWindowTextW(hwnd,title.c_str()); + } + else + { + globals->window_title = title; + globals->set_window_title_done = false; + + if (PostMessage(globals->helper_window,WM_USER+CALL_SET_WINDOW_TITLE,(WPARAM)hwnd,0)==0) + { + throw gui_error("Unable to schedule SetWindowText function for execution in event handling thread."); + } + + // wait for any SetWindowText() calls to finish + while (globals->set_window_title_done == false) + globals->et_signaler.wait(); + } + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + show ( + ) + { + using namespace gui_core_kernel_1_globals; + auto_mutex M(wm); + if (has_been_destroyed == true) + return; + + show_window(hwnd); + if (style != WS_CHILD) + give_window_focus(hwnd); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + hide( + ) + { + using namespace gui_core_kernel_1_globals; + auto_mutex M(wm); + if (has_been_destroyed == true) + return; + + hide_window(hwnd); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + set_size ( + int width_, + int height_ + ) + { + using namespace gui_core_kernel_1_globals; + auto_mutex M(wm); + if (has_been_destroyed == true) + return; + + if (get_thread_id() == globals->event_thread_id) + { + RECT info; + GetWindowRect(hwnd,&info); + + int x = info.left; + int y = info.top; + int width; + int height; + + RECT rect; + rect.top = 0; + rect.left = 0; + rect.bottom = height_; + rect.right = width_; + AdjustWindowRectEx(&rect,style,FALSE,0); + + width = std::abs(rect.right - rect.left); + height = std::abs(rect.bottom - rect.top); + + MoveWindow( + hwnd, + x, + y, + width, + height, + TRUE); + } + else + { + RECT info; + GetWindowRect(hwnd,&info); + + int x = info.left; + int y = info.top; + int width; + int height; + + RECT rect; + rect.top = 0; + rect.left = 0; + rect.bottom = height_; + rect.right = width_; + AdjustWindowRectEx(&rect,style,FALSE,0); + + width = std::abs(rect.right - rect.left); + height = std::abs(rect.bottom - rect.top); + + // call the MoveWindow function with our arguments. We + // have to do this because the MoveWindow() apparently blocks + // until something happens in the event thread so we have to + // do this to avoid possible deadlocks. + globals->move_window_hwnd = hwnd; + globals->move_window_x = x; + globals->move_window_y = y; + globals->move_window_width = width; + globals->move_window_height = height; + globals->move_window_done = false; + + if (PostMessage(globals->helper_window,WM_USER+CALL_MOVE_WINDOW,0,0)==0) + { + throw gui_error("Unable to schedule MoveWindow function for execution in event handling thread."); + } + + // wait for any MoveWindow calls to finish + while (globals->move_window_done == false) + globals->et_signaler.wait(); + } + + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + set_pos ( + long x_, + long y_ + ) + { + using namespace gui_core_kernel_1_globals; + auto_mutex M(wm); + if (has_been_destroyed == true) + return; + + if (get_thread_id() == globals->event_thread_id) + { + RECT info; + GetWindowRect(hwnd,&info); + int width = info.right - info.left; + int height = info.bottom - info.top; + + MoveWindow( + hwnd, + x_, + y_, + width, + height, + TRUE); + + } + else + { + RECT info; + GetWindowRect(hwnd,&info); + int width = info.right - info.left; + int height = info.bottom - info.top; + + + + // call the MoveWindow function with our arguments. We + // have to do this because the MoveWindow() apparently blocks + // until something happens in the event thread so we have to + // do this to avoid possible deadlocks. + globals->move_window_hwnd = hwnd; + globals->move_window_x = x_; + globals->move_window_y = y_; + globals->move_window_width = width; + globals->move_window_height = height; + globals->move_window_done = false; + + if (PostMessage(globals->helper_window,WM_USER+CALL_MOVE_WINDOW,0,0)==0) + { + throw gui_error("Unable to schedule MoveWindow function for execution in event handling thread."); + } + + // wait for any MoveWindow calls to finish + while (globals->move_window_done == false) + globals->et_signaler.wait(); + } + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + get_pos ( + long& x_, + long& y_ + ) + { + auto_mutex M(wm); + x_ = 0; + y_ = 0; + if (has_been_destroyed == true) + return; + + POINT p; + p.x = 0; + p.y = 0; + ClientToScreen(hwnd,&p); + + x_ = p.x; + y_ = p.y; + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + get_display_size ( + unsigned long& width, + unsigned long& height + ) const + { + auto_mutex M(wm); + width = 0; + height = 0; + if (has_been_destroyed == true) + return; + + + RECT rc; + GetWindowRect(hwnd, &rc); + + HMONITOR hMonitor; + MONITORINFO mi; + // + // get the nearest monitor to the passed rect. + // + hMonitor = MonitorFromRect(&rc, MONITOR_DEFAULTTONEAREST); + + // + // get the work area or entire monitor rect. + // + mi.cbSize = sizeof(mi); + GetMonitorInfo(hMonitor, &mi); + + rc = mi.rcMonitor; + + width = static_cast(rc.right - rc.left); + height = static_cast(rc.bottom - rc.top); + + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + get_size ( + unsigned long& width, + unsigned long& height + ) const + { + auto_mutex M(wm); + width = 0; + height = 0; + if (has_been_destroyed == true) + return; + + + RECT r; + GetClientRect(hwnd,&r); + + width = r.right - r.left; + height = r.bottom - r.top; + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + invalidate_rectangle ( + const rectangle& rect + ) + { + auto_mutex M(wm); + if (rect.is_empty() == false && !has_been_destroyed) + { + RECT info; + info.top = rect.top(); + info.left = rect.left(); + info.right = rect.right()+1; + info.bottom = rect.bottom()+1; + + InvalidateRect(hwnd,&info,FALSE); + } + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + set_im_pos ( + long x, + long y + ) + { + auto_mutex M(wm); + if (has_been_destroyed == true) + return; + + HIMC hImc = ImmGetContext(hwnd); + + COMPOSITIONFORM cf; + cf.dwStyle = CFS_POINT; + cf.ptCurrentPos.x = x; + cf.ptCurrentPos.y = y; + ImmSetCompositionWindow(hImc, &cf); + ImmReleaseContext(hwnd, hImc); + } + +// ---------------------------------------------------------------------------------------- + + void put_on_clipboard ( + const std::string& str + ) + { + put_on_clipboard(convert_mbstring_to_wstring(str)); + } + + void put_on_clipboard ( + const dlib::ustring& str + ) + { + put_on_clipboard(convert_utf32_to_wstring(str)); + } + + void put_on_clipboard ( + const std::wstring& str + ) + { + using namespace gui_core_kernel_1_globals; + using namespace std; + + auto globals = global_data(); + + if (OpenClipboard(globals->helper_window)) + { + EmptyClipboard(); + auto_mutex M(globals->window_table.get_mutex()); + + const unsigned long newlines = count(str.begin(),str.end(),L'\n'); + + HGLOBAL mem = GlobalAlloc(GMEM_MOVEABLE,(str.size()+newlines+1)*sizeof(wchar_t)); + if (mem != NULL) + { + wchar_t* buf = reinterpret_cast(GlobalLock(mem)); + + if (buf != NULL) + { + // copy str into buf while also replacing all the \n with \r\n + for (wstring::size_type i = 0; i < str.size(); ++i) + { + if (str[i] != L'\n') + { + *buf = str[i]; + ++buf; + } + else + { + *buf = L'\r'; + ++buf; + *buf = L'\n'; + ++buf; + } + } + *buf = L'\0'; + GlobalUnlock(mem); + SetClipboardData(CF_UNICODETEXT,mem); + } + } + CloseClipboard(); + } + } + +// ---------------------------------------------------------------------------------------- + + void get_from_clipboard ( + std::string& str + ) + { + std::wstring wstr; + get_from_clipboard(wstr); + str = convert_wstring_to_mbstring(wstr); + } + + void get_from_clipboard ( + dlib::ustring& str + ) + { + std::wstring wstr; + get_from_clipboard(wstr); + str = convert_wstring_to_utf32(wstr); + } + + void get_from_clipboard ( + std::wstring& str + ) + { + using namespace gui_core_kernel_1_globals; + using namespace std; + auto globals = global_data(); + + auto_mutex M(globals->window_table.get_mutex()); + if (OpenClipboard(globals->helper_window)) + { + + HANDLE data = GetClipboardData(CF_UNICODETEXT); + if (data != NULL) + { + wchar_t* buf = reinterpret_cast(GlobalLock(data)); + if (buf != 0) + { + str.clear(); + + // copy the data from buf into str while also removing any '\r' + // characters. + while (*buf != L'\0') + { + if (*buf != L'\r') + str += *buf; + ++buf; + } + + GlobalUnlock(data); + } + else + { + Beep(500,500); + } + } + + CloseClipboard(); + } + } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // WIN32 + +#endif // DLIB_GUI_CORE_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_1.h b/ml/dlib/dlib/gui_core/gui_core_kernel_1.h new file mode 100644 index 000000000..b7077ac01 --- /dev/null +++ b/ml/dlib/dlib/gui_core/gui_core_kernel_1.h @@ -0,0 +1,420 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GUI_CORE_KERNEl_1_ +#define DLIB_GUI_CORE_KERNEl_1_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + +#ifdef DLIB_NO_GUI_SUPPORT +#error "DLIB_NO_GUI_SUPPORT is defined so you can't use the GUI code. Turn DLIB_NO_GUI_SUPPORT off if you want to use it." +#endif + +#include + +#include "../windows_magic.h" + + +#include +#include +#include +#include + +#include "gui_core_kernel_abstract.h" + +#ifdef _MSC_VER +// Disable the following warnings for Visual Studio +// +// These two warnings have to do with converting points to and from the LONG +// type. But both these types are 32 bits in windows so it is fine. +#pragma warning(disable: 4244; disable: 4312) +#endif + +#include "../algs.h" +#include "../sync_extension.h" +#include "../binary_search_tree.h" +#include "../threads.h" +#include "../geometry/rectangle.h" +#include "../assert.h" +#include "../queue.h" +#include "../pixel.h" +#include "../unicode.h" +#include "../smart_pointers/shared_ptr_thread_safe.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class base_window; + namespace gui_core_kernel_1_globals + { + + LRESULT CALLBACK WndProc (HWND, UINT, WPARAM, LPARAM); + class event_handler_thread; + + } + +// ---------------------------------------------------------------------------------------- + + class canvas : public rectangle + { + public: + struct pixel + { + unsigned char blue; + unsigned char green; + unsigned char red; + }; + + ~canvas() { } + + inline pixel* operator[] ( + unsigned long row + ) const + { + DLIB_ASSERT(row < height(), + "\tpixel* canvas::operator[]" + << "\n\tyou have to give a row that is less than the height()" + << "\n\tthis: " << this + << "\n\trow: " << row + << "\n\theight(): " << height() + ); + unsigned char* temp = bits + row_width*row; + return reinterpret_cast(temp); + } + + void fill ( + unsigned char red_, + unsigned char green_, + unsigned char blue_ + ) const; + + private: + + friend LRESULT CALLBACK gui_core_kernel_1_globals::WndProc (HWND, UINT, WPARAM, LPARAM); + + canvas ( + unsigned char* bits_, + unsigned long padding_, + unsigned long left_, + unsigned long top_, + unsigned long right_, + unsigned long bottom_ + ) : + rectangle(left_,top_,right_,bottom_), + bits(bits_), + width_(width()), + height_(height()), + row_width(width_*3+padding_) + {} + + // restricted functions + canvas(); // normal constructor + canvas(canvas&); // copy constructor + canvas& operator=(canvas&); // assignment operator + + unsigned char* const bits; + const unsigned long width_; + const unsigned long height_; + const unsigned long row_width; + }; + + template <> + struct pixel_traits + { + constexpr static bool rgb = true; + constexpr static bool rgb_alpha = false; + constexpr static bool grayscale = false; + constexpr static bool hsi = false; + constexpr static long num = 3; + typedef unsigned char basic_pixel_type; + static basic_pixel_type min() { return 0;} + static basic_pixel_type max() { return 255;} + constexpr static bool is_unsigned = true; + constexpr static bool has_alpha = false; + }; + +// ---------------------------------------------------------------------------------------- + + void put_on_clipboard ( + const std::string& str + ); + + void put_on_clipboard ( + const std::wstring& str + ); + + void put_on_clipboard ( + const dlib::ustring& str + ); + +// ---------------------------------------------------------------------------------------- + + void get_from_clipboard ( + std::string& str + ); + + void get_from_clipboard ( + std::wstring& str + ); + + void get_from_clipboard ( + dlib::ustring& str + ); + +// ---------------------------------------------------------------------------------------- + + class base_window + { + friend LRESULT CALLBACK gui_core_kernel_1_globals::WndProc (HWND, UINT, WPARAM, LPARAM); + dlib::shared_ptr_thread_safe globals; + + HWND hwnd; + DWORD style; + bool has_been_destroyed; + + // This is true if the mouse is in this window. false otherwise. + // also note that this variable is only accessed from the event handling thread + // (except for being initialized below in the constructor, but that is inside + // the window_table mutex so it doesn't matter). + bool mouse_in; + + // this is a copy of the last inputs we sent to the on_mouse_move() event. + long prevx; + long prevy; + unsigned long prev_state; + + protected: + const rmutex& wm; + + public: + + base_window ( + bool resizable = true, + bool undecorated = false + ); + + virtual ~base_window ( + ); + + void close_window ( + ); + + bool is_closed ( + ) const; + + void set_title ( + const std::string& title + ); + + void set_title ( + const std::wstring& title + ); + + void set_title ( + const ustring& title + ); + + virtual void show ( + ); + + virtual void hide( + ); + + void set_size ( + int width_, + int height_ + ); + + void set_pos ( + long x_, + long y_ + ); + + void get_pos ( + long& x_, + long& y_ + ); + + void get_size ( + unsigned long& width, + unsigned long& height + ) const; + + void get_display_size ( + unsigned long& width, + unsigned long& height + ) const; + + void invalidate_rectangle ( + const rectangle& rect + ); + + void trigger_user_event ( + void* p, + int i + ); + + void wait_until_closed ( + ) const; + + void set_im_pos ( + long x_, + long y_ + ); + + enum on_close_return_code + { + DO_NOT_CLOSE_WINDOW, + CLOSE_WINDOW + }; + + enum mouse_state_masks + { + NONE = 0, + LEFT = 1, + RIGHT = 2, + MIDDLE = 4, + SHIFT = 8, + CONTROL = 16 + }; + + enum keyboard_state_masks + { + KBD_MOD_NONE = 0, + KBD_MOD_SHIFT = 1, + KBD_MOD_CONTROL = 2, + KBD_MOD_ALT = 4, + KBD_MOD_META = 8, + KBD_MOD_CAPS_LOCK = 16, + KBD_MOD_NUM_LOCK = 32, + KBD_MOD_SCROLL_LOCK = 64 + }; + + enum non_printable_keyboard_keys + { + KEY_BACKSPACE, + KEY_SHIFT, + KEY_CTRL, + KEY_ALT, + KEY_PAUSE, + KEY_CAPS_LOCK, + KEY_ESC, + KEY_PAGE_UP, + KEY_PAGE_DOWN, + KEY_END, + KEY_HOME, + KEY_LEFT, // This is the left arrow key + KEY_RIGHT, // This is the right arrow key + KEY_UP, // This is the up arrow key + KEY_DOWN, // This is the down arrow key + KEY_INSERT, + KEY_DELETE, + KEY_SCROLL_LOCK, + + // Function Keys + KEY_F1, + KEY_F2, + KEY_F3, + KEY_F4, + KEY_F5, + KEY_F6, + KEY_F7, + KEY_F8, + KEY_F9, + KEY_F10, + KEY_F11, + KEY_F12 + }; + + protected: + + virtual on_close_return_code on_window_close( + ){return CLOSE_WINDOW;} + + virtual void on_user_event ( + void* , + int + ){} + + virtual void on_window_resized( + ){} + + virtual void on_window_moved( + ){} + + virtual void on_mouse_down ( + unsigned long , + unsigned long , + long , + long , + bool + ){} + + virtual void on_mouse_up ( + unsigned long , + unsigned long , + long , + long + ){} + + virtual void on_mouse_move ( + unsigned long , + long , + long + ){} + + virtual void on_mouse_leave ( + ){} + + virtual void on_mouse_enter ( + ){} + + virtual void on_wheel_up ( + unsigned long + ){} + + virtual void on_wheel_down ( + unsigned long + ){} + + virtual void on_focus_gained ( + ){} + + virtual void on_focus_lost ( + ){} + + virtual void on_keydown ( + unsigned long , + bool , + unsigned long + ){} + + virtual void on_string_put ( + const std::wstring& + ){} + + private: + + virtual void paint ( + const canvas& + ) =0; + + base_window(base_window&); // copy constructor + base_window& operator=(base_window&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + +} + + +#ifdef NO_MAKEFILE +#include "gui_core_kernel_1.cpp" +#endif + +#endif // DLIB_GUI_CORE_KERNEl_1_ + diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_2.cpp b/ml/dlib/dlib/gui_core/gui_core_kernel_2.cpp new file mode 100644 index 000000000..feca4bf22 --- /dev/null +++ b/ml/dlib/dlib/gui_core/gui_core_kernel_2.cpp @@ -0,0 +1,1996 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GUI_CORE_KERNEL_2_CPp_ +#define DLIB_GUI_CORE_KERNEL_2_CPp_ +#include "../platform.h" + +#ifdef POSIX + +#include "gui_core_kernel_2.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include "../assert.h" +#include "../queue.h" +#include "../sync_extension.h" +#include "../logger.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace gui_core_kernel_2_globals + { + void init_keyboard_mod_masks(); + struct user_event_type + { + Window w; + void* p; + int i; + }; + + typedef sync_extension::kernel_1b>::kernel_2a_c>::kernel_1a queue_of_user_events; + + typedef sync_extension::kernel_1a>::kernel_1a + window_table_type; + + // ---------------------------------------------------------------------------------------- + + const std::shared_ptr& global_mutex() + { + static std::shared_ptr m(new dlib::mutex); + return m; + } + + class event_handler_thread : public threaded_object + { + public: + + enum et_state + { + uninitialized, + initialized, + failure_to_init + }; + + et_state status; + logger dlog; + + + int depth; + Display* disp; + XIM xim; + XIMStyle xim_style; + Screen* screen; + + Atom delete_window; + Window exit_window; + std::wstring clipboard; + + int alt_mask; + int meta_mask; + int num_lock_mask; + int scroll_lock_mask; + + // the mutex in this object is the global mutex used to protect everything + // in the gui_core and gui_widgets components. + window_table_type window_table; + + rsignaler window_close_signaler; + rsignaler et_signaler; + + queue_of_user_events user_events; + queue_of_user_events user_events_temp; + + std::shared_ptr reference_to_global_mutex; + + event_handler_thread( + ) : + dlog("dlib.gui_core"), + depth(0), + disp(0), + xim(0), + screen(0), + alt_mask(0), + meta_mask(0), + num_lock_mask(0), + scroll_lock_mask(0), + window_close_signaler(window_table.get_mutex()), + et_signaler(window_table.get_mutex()), + reference_to_global_mutex(global_mutex()) + { + auto_mutex M(window_table.get_mutex()); + + status = uninitialized; + + // start up the event handler thread + start(); + + // wait for the event thread to get up and running + while (status == uninitialized) + et_signaler.wait(); + + if (status == failure_to_init) + throw gui_error("Failed to initialize X11 resources"); + + init_keyboard_mod_masks(); + } + + ~event_handler_thread () + { + + if (is_alive()) + { + + if (status != failure_to_init) + { + XConfigureEvent event; + event.type = ConfigureNotify; + event.send_event = True; + event.display = disp; + event.window = exit_window; + event.x = 1; + XFlush(disp); + XPutBackEvent(disp,reinterpret_cast(&event)); + XFlush(disp); + + // This should cause XNextEvent() to unblock so that it will see + // this ConfigureNotify event we are putting onto the event queue. + XSendEvent(disp,exit_window,False,0,reinterpret_cast(&event)); + XFlush(disp); + + wait(); + + if (xim != NULL) + { + XCloseIM(xim); + } + + XCloseDisplay(disp); + + + } + else + { + + wait(); + } + } + + + } + + private: + + void thread ( + ) + { + using namespace std; + using namespace dlib; + try + { + + // You are supposed to call this if using XLib in a threaded program. Note + // however that at one point I noticed that calling this causes a dead-lock + // when using XIM. But I can't reproduce that anymore and not calling it + // sometimes causes XCloseDisplay() to hang. + if (XInitThreads() == 0) + { + dlog << LFATAL << "Unable to initialize threading support."; + // signal that an error has occurred + window_table.get_mutex().lock(); + status = failure_to_init; + et_signaler.broadcast(); + window_table.get_mutex().unlock(); + return; + } + + window_table.get_mutex().lock(); + disp = XOpenDisplay(NULL); + window_table.get_mutex().unlock(); + if (disp == 0) + { + window_table.get_mutex().lock(); + disp = XOpenDisplay(":0.0"); + window_table.get_mutex().unlock(); + if (disp == 0) + { + dlog << LFATAL << "Unable to connect to the X display."; + // signal that an error has occurred + window_table.get_mutex().lock(); + status = failure_to_init; + et_signaler.broadcast(); + window_table.get_mutex().unlock(); + return; + } + } + + window_table.get_mutex().lock(); + screen = DefaultScreenOfDisplay(disp); + depth = DefaultDepthOfScreen(screen); + delete_window = XInternAtom(disp,"WM_DELETE_WINDOW",1); + window_table.get_mutex().unlock(); + + xim = NULL; + // I'm disabling XIM usage all together because calling XSetICValues() + // in set_im_pos() randomly hangs the application (on Ubuntu 13.10 at + // least). + /* + window_table.get_mutex().lock(); + std::string saved_locale(setlocale (LC_CTYPE, NULL)); + if (setlocale( LC_CTYPE, "" ) && XSupportsLocale() && XSetLocaleModifiers("")) + xim = XOpenIM(disp, NULL, NULL, NULL); + else + setlocale( LC_CTYPE, saved_locale.c_str() ); + window_table.get_mutex().unlock(); + */ + if (xim) + { + const static XIMStyle preedit_styles[] = + {XIMPreeditPosition, XIMPreeditNothing, XIMPreeditNone, 0}; + const static XIMStyle status_styles[] = + {XIMStatusNothing, XIMStatusNone, 0}; + xim_style = 0; + + XIMStyles *xim_styles; + window_table.get_mutex().lock(); + + XGetIMValues (xim, XNQueryInputStyle, &xim_styles, (const void*)NULL); + window_table.get_mutex().unlock(); + std::set xims; + for (int i = 0; i < xim_styles->count_styles; ++i){ + xims.insert(xim_styles->supported_styles[i]); + } + for (int j = 0; status_styles[j]; ++j){ + for (int i = 0; preedit_styles[i]; ++i){ + xim_style = (status_styles[j] | preedit_styles[i]); + if (xims.count(xim_style)) break; + } + if (xim_style) break; + } + XFree(xim_styles); + } + + // make this window just so we can send messages to it and trigger + // events in the event thread + XSetWindowAttributes attr; + window_table.get_mutex().lock(); + exit_window = XCreateWindow( + disp, + DefaultRootWindow(disp), + 0, + 0, + 10, // this is the default width of a window + 10, // this is the default width of a window + 0, + depth, + InputOutput, + CopyFromParent, + 0, + &attr + ); + window_table.get_mutex().unlock(); + + // signal that the event thread is now up and running + window_table.get_mutex().lock(); + status = initialized; + et_signaler.broadcast(); + window_table.get_mutex().unlock(); + + // start the event handler + event_handler(); + } + catch (std::exception& e) + { + cout << "\nEXCEPTION THROWN: \n" << e.what() << endl; + abort(); + } + catch (...) + { + cout << "UNKNOWN EXCEPTION THROWN.\n" << endl; + abort(); + } + } + + void event_handler(); + void init_keyboard_mod_masks(); + }; + + struct x11_base_windowstuff + { + Window hwnd; + Time last_click_time; + XIC xic; + XFontSet fs; + std::shared_ptr globals; + }; + + // Do all this just to make sure global_mutex() is initialized at program start + // and thus hopefully before any threads have the chance to startup and call + // global_data() concurrently. + struct call_global_mutex { call_global_mutex() { global_mutex(); } }; + static call_global_mutex call_global_mutex_instance; + + const std::shared_ptr& global_data() + { + auto_mutex M(*global_mutex()); + static std::shared_ptr p; + if (p.get() == 0) + p.reset(new event_handler_thread()); + return p; + } + + // ---------------------------------------------------------------------------------------- + + Bool XCheckIfEventPredicate ( + Display* , + XEvent* event, + XPointer arg + ) + /*! + ensures + - if (event is an Expose event for the window pointed to by arg) then + - returns true + - else + - returns false + !*/ + { + if (event->type == Expose) + { + XExposeEvent* e = reinterpret_cast(event); + Window* win= reinterpret_cast(arg); + if (e->window == *win) + { + return 1; + } + } + return 0; + } + + // ---------------------------------------------------------------------------------------- + + static bool map_keys ( + KeySym keycode, + bool , + bool , + unsigned long& result, + bool& is_printable + ) + /*! + requires + - if (shift was down for this key) then + - shift == true + - if (caps lock was on for this key) then + - caps == true + - keycode == the keycode from windows that we are to process + - keycode < keyboard_keys_size + ensures + - if (this key should be ignored) then + - returns false + - else + - returns true + - #is_printable == true if result is a printable ascii character + - #result == the keycode converted into the proper number to tbe + returned by the event handler. + !*/ + { + is_printable = true; + if ((keycode <= 'z' && keycode >= 'a') || + (keycode <= 'Z' && keycode >= 'A') || + (keycode <= '9' && keycode >= '0')) + { + result = keycode; + } + else + { + is_printable = false; + switch (keycode) + { + case XK_Home: result = base_window::KEY_HOME; break; + case XK_Left: result = base_window::KEY_LEFT; break; + case XK_Right: result = base_window::KEY_RIGHT; break; + case XK_Down: result = base_window::KEY_DOWN; break; + case XK_Up: result = base_window::KEY_UP; break; + case XK_Prior: result = base_window::KEY_PAGE_UP; break; + case XK_Next: result = base_window::KEY_PAGE_DOWN; break; + case XK_End: result = base_window::KEY_END; break; + case XK_Escape: result = base_window::KEY_ESC; break; + + case XK_KP_Delete: result = base_window::KEY_DELETE; break; + case XK_KP_Prior: result = base_window::KEY_PAGE_UP; break; + case XK_KP_Next: result = base_window::KEY_PAGE_DOWN; break; + + + case XK_F1: result = base_window::KEY_F1; break; + case XK_F2: result = base_window::KEY_F2; break; + case XK_F3: result = base_window::KEY_F3; break; + case XK_F4: result = base_window::KEY_F4; break; + case XK_F5: result = base_window::KEY_F5; break; + case XK_F6: result = base_window::KEY_F6; break; + case XK_F7: result = base_window::KEY_F7; break; + case XK_F8: result = base_window::KEY_F8; break; + case XK_F9: result = base_window::KEY_F9; break; + case XK_F10: result = base_window::KEY_F10; break; + case XK_F11: result = base_window::KEY_F11; break; + case XK_F12: result = base_window::KEY_F12; break; + + + case XK_Shift_L: result = base_window::KEY_SHIFT; break; + case XK_Shift_R: result = base_window::KEY_SHIFT; break; + case XK_Control_L: result = base_window::KEY_CTRL; break; + case XK_Control_R: result = base_window::KEY_CTRL; break; + case XK_Caps_Lock: result = base_window::KEY_CAPS_LOCK; break; + case XK_Alt_L: result = base_window::KEY_ALT; break; + case XK_Alt_R: result = base_window::KEY_ALT; break; + + + case XK_BackSpace: result = base_window::KEY_BACKSPACE; break; + case XK_Delete: result = base_window::KEY_DELETE; break; + case XK_Scroll_Lock: result = base_window::KEY_SCROLL_LOCK; break; + case XK_Pause: result = base_window::KEY_PAUSE; break; + case XK_Insert: result = base_window::KEY_INSERT; break; + case XK_KP_Insert: result = base_window::KEY_INSERT; break; + + + + + case XK_exclam: + is_printable = true; + result = '!'; break; + case XK_quotedbl: + is_printable = true; + result = '"'; break; + case XK_numbersign: + is_printable = true; + result = '#'; break; + case XK_dollar: + is_printable = true; + result = '$'; break; + case XK_percent: + is_printable = true; + result = '%'; break; + case XK_ampersand: + is_printable = true; + result = '&'; break; + case XK_apostrophe: + is_printable = true; + result = '\''; break; + case XK_parenleft: + is_printable = true; + result = '('; break; + case XK_parenright: + is_printable = true; + result = ')'; break; + case XK_asterisk: + is_printable = true; + result = '*'; break; + case XK_plus: + is_printable = true; + result = '+'; break; + case XK_comma: + is_printable = true; + result = ','; break; + case XK_minus: + is_printable = true; + result = '-'; break; + case XK_period: + is_printable = true; + result = '.'; break; + case XK_slash: + is_printable = true; + result = '/'; break; + case XK_colon: + is_printable = true; + result = ':'; break; + case XK_semicolon: + is_printable = true; + result = ';'; break; + case XK_less: + is_printable = true; + result = '<'; break; + case XK_equal: + is_printable = true; + result = '='; break; + case XK_greater: + is_printable = true; + result = '>'; break; + case XK_question: + is_printable = true; + result = '?'; break; + case XK_at: + is_printable = true; + result = '@'; break; + case XK_grave: + is_printable = true; + result = '`'; break; + case XK_underscore: + is_printable = true; + result = '_'; break; + case XK_asciicircum: + is_printable = true; + result = '^'; break; + case XK_bracketleft: + is_printable = true; + result = '['; break; + case XK_backslash: + is_printable = true; + result = '\\'; break; + case XK_bracketright: + is_printable = true; + result = ']'; break; + case XK_asciitilde: + is_printable = true; + result = '~'; break; + case XK_braceleft: + is_printable = true; + result = '{'; break; + case XK_bar: + is_printable = true; + result = '|'; break; + case XK_braceright: + is_printable = true; + result = '}'; break; + + + + + case XK_space: + is_printable = true; + result = ' '; break; + case XK_Return: + is_printable = true; + result = '\n'; break; + case XK_Tab: + is_printable = true; + result = '\t'; break; + case XK_KP_Divide: + is_printable = true; + result = '/'; break; + case XK_KP_Decimal: + is_printable = true; + result = '.'; break; + case XK_KP_Subtract: + is_printable = true; + result = '-'; break; + case XK_KP_Add: + is_printable = true; + result = '+'; break; + case XK_KP_Multiply: + is_printable = true; + result = '*'; break; + case XK_KP_Equal: + is_printable = true; + result = '='; break; + + case XK_KP_0: + is_printable = true; + result = '0'; break; + case XK_KP_1: + is_printable = true; + result = '1'; break; + case XK_KP_2: + is_printable = true; + result = '2'; break; + case XK_KP_3: + is_printable = true; + result = '3'; break; + case XK_KP_4: + is_printable = true; + result = '4'; break; + case XK_KP_5: + is_printable = true; + result = '5'; break; + case XK_KP_6: + is_printable = true; + result = '6'; break; + case XK_KP_7: + is_printable = true; + result = '7'; break; + case XK_KP_8: + is_printable = true; + result = '8'; break; + case XK_KP_9: + is_printable = true; + result = '9'; break; + + default: + return false; + } + } + + return true; + } + + // ---------------------------------------------------------------------------------------- + + void event_handler_thread:: + event_handler ( + ) + /*! + ensures + - will handle all events and event dispatching + !*/ + { + try + { + std::vector bitmap_buffer; + bool quit_event_loop = false; + while (quit_event_loop == false) + { + // get a lock on the window_table's mutex + auto_mutex window_table_locker(window_table.get_mutex()); + + XEvent ev; + memset(&ev, 0, sizeof(ev)); + while (XPending(disp) == 0){ + window_table.get_mutex().unlock(); + // wait until receiving X11 next event + struct pollfd pfd; + pfd.fd = ConnectionNumber(disp); + pfd.events = POLLIN | POLLPRI; + poll(&pfd, 1, -1); + + window_table.get_mutex().lock(); + } + XNextEvent(disp,&ev); + + // pass events to input method. + // if this event is needed by input method, XFilterEvent returns True + if (XFilterEvent(&ev, None) == True){ + continue; + } + + // if this event is for one of the windows in the window_table + // then get that window out of the table and put it into win. + XAnyEvent* _ae = reinterpret_cast(&ev); + base_window** win_ = window_table[_ae->window]; + base_window* win = 0; + if (win_) + win = *win_; + + + // ignore messages for unmapped windows + if (ev.type != MapNotify && win != 0) + { + if (win->is_mapped == false) + continue; + } + + + switch (ev.type) + { + + case SelectionRequest: + { + Atom a_ct = XInternAtom(disp, "COMPOUND_TEXT", False); + XSelectionRequestEvent* req = reinterpret_cast(&ev.xselectionrequest); + XEvent respond; + + if (req->target == XA_STRING) + { + XChangeProperty (disp, + req->requestor, + req->property, + XA_STRING, + 8, + PropModeReplace, + reinterpret_cast(convert_wstring_to_mbstring(clipboard).c_str()), + clipboard.size()+1); + respond.xselection.property=req->property; + } + else if (req->target == a_ct) + { + XChangeProperty (disp, + req->requestor, + req->property, + a_ct, + sizeof(wchar_t)*8, + PropModeReplace, + reinterpret_cast(clipboard.c_str()), + clipboard.size()+1); + respond.xselection.property=req->property; + } + else + { + respond.xselection.property= None; + } + respond.xselection.type= SelectionNotify; + respond.xselection.display= req->display; + respond.xselection.requestor= req->requestor; + respond.xselection.selection=req->selection; + respond.xselection.target= req->target; + respond.xselection.time = req->time; + XSendEvent (disp, req->requestor,0,0,&respond); + XFlush (disp); + + } break; + + case MapNotify: + { + if (win == 0) + break; + + win->is_mapped = true; + + if (win->resizable == false) + { + XSizeHints* hints = XAllocSizeHints(); + hints->flags = PMinSize|PMaxSize; + hints->min_width = win->width; + hints->max_width = win->width; + hints->max_height = win->height; + hints->min_height = win->height; + XSetNormalHints(disp,win->x11_stuff.hwnd,hints); + XFree(hints); + } + + + + if (win->has_been_resized) + { + XResizeWindow(disp,win->x11_stuff.hwnd,win->width,win->height); + win->has_been_resized = false; + win->on_window_resized(); + } + + if (win->has_been_moved) + { + XMoveWindow(disp,win->x11_stuff.hwnd,win->x,win->y); + win->has_been_moved = false; + win->on_window_moved(); + } + XFlush(disp); + + + } break; + + + case KeyPress: + { + XKeyPressedEvent* e = reinterpret_cast(&ev); + + if (win == 0) + break; + + unsigned long state = 0; + bool shift = ((e->state & ShiftMask)!=0); + bool ctrl = ((e->state & ControlMask)!=0); + bool caps = ((e->state & LockMask)!=0); + if(shift) + state |= base_window::KBD_MOD_SHIFT; + if(ctrl) + state |= base_window::KBD_MOD_CONTROL; + if(caps) + state |= base_window::KBD_MOD_CAPS_LOCK; + if((e->state & alt_mask)!=0) + state |= base_window::KBD_MOD_ALT; + if((e->state & meta_mask)!=0) + state |= base_window::KBD_MOD_META; + if((e->state & num_lock_mask)!=0) + state |= base_window::KBD_MOD_NUM_LOCK; + if((e->state & scroll_lock_mask)!=0) + state |= base_window::KBD_MOD_SCROLL_LOCK; + + KeySym key; + Status status; + + if (win->x11_stuff.xic) { + std::wstring wstr; + wstr.resize(2); + int len = XwcLookupString(win->x11_stuff.xic,e,&wstr[0],wstr.size(),&key,&status); + if (status == XBufferOverflow){ + wstr.resize(len); + len = XwcLookupString(win->x11_stuff.xic,e,&wstr[0],wstr.size(),&key,&status); + } + if (status == XLookupChars){ + win->on_string_put(wstr); + } + } else { + char buffer[2]; + XLookupString(e, buffer, sizeof(buffer), &key, NULL); + status = XLookupKeySym; + } + + if (status == XLookupKeySym || status == XLookupBoth){ + + bool is_printable; + unsigned long result; + + if (map_keys(key,shift,caps,result,is_printable)) + { + // signal the keyboard event + win->on_keydown(result,is_printable,state); + } + } + + } break; + + case FocusIn: + { + if (win == 0) + break; + + // signal the focus event + win->on_focus_gained(); + } break; + + case FocusOut: + { + if (win == 0) + break; + + // signal the focus event + win->on_focus_lost(); + } break; + + case ButtonPress: + case ButtonRelease: + { + XButtonEvent* e = reinterpret_cast(&ev); + + if (win == 0) + break; + + unsigned long btn = base_window::NONE; + if (e->button == Button1) + btn = base_window::LEFT; + else if (e->button == Button3) + btn = base_window::RIGHT; + else if (e->button == Button2) + btn = base_window::MIDDLE; + + unsigned long state = 0; + if (e->state & ControlMask) + state |= base_window::CONTROL; + if (e->state & Button1Mask) + state |= base_window::LEFT; + if (e->state & Button2Mask) + state |= base_window::MIDDLE; + if (e->state & Button3Mask) + state |= base_window::RIGHT; + if (e->state & ShiftMask) + state |= base_window::SHIFT; + + // only send the event if this is a button we support + if (btn != (unsigned long)base_window::NONE) + { + + + if (ev.type == ButtonPress) + { + bool is_double_click = false; + if (win->last_click_button == btn && + std::abs((long)win->last_click_x - (long)e->x) < 5 && + std::abs((long)win->last_click_y - (long)e->y) < 5 && + e->time - win->x11_stuff.last_click_time <= 400) + { + // this is a double click + is_double_click = true; + // set this to make sure the next click can't be + // interpreted as a double click + win->last_click_button = base_window::NONE; + } + else + { + win->last_click_button = btn; + win->last_click_x = e->x; + win->last_click_y = e->y; + win->x11_stuff.last_click_time = e->time; + } + + // remove the clicked button from the state + state &= (~btn); + win->on_mouse_down(btn,state,e->x,e->y,is_double_click); + + } + else + { + // remove the clicked button from the state + state &= (~btn); + win->on_mouse_up(btn,state,e->x,e->y); + } + } + else if (e->button == Button4 && ev.type == ButtonPress) + { + win->on_wheel_up(state); + } + else if (e->button == Button5 && ev.type == ButtonPress) + { + win->on_wheel_down(state); + } + + } break; + + case LeaveNotify: + { + if (win == 0) + break; + + win->on_mouse_leave(); + + } break; + + case EnterNotify: + { + if (win == 0) + break; + + win->on_mouse_enter(); + } break; + + case MotionNotify: + { + XMotionEvent* e = reinterpret_cast(&ev); + + if (win == 0) + break; + + unsigned long state = 0; + if (e->state & ControlMask) + state |= base_window::CONTROL; + if (e->state & Button1Mask) + state |= base_window::LEFT; + if (e->state & Button2Mask) + state |= base_window::MIDDLE; + if (e->state & Button3Mask) + state |= base_window::RIGHT; + if (e->state & ShiftMask) + state |= base_window::SHIFT; + + win->on_mouse_move(state,e->x,e->y); + + } break; + + case ConfigureNotify: + { + XConfigureEvent* e = reinterpret_cast(&ev); + if (e->window == exit_window) + { + // this is the signal to quit the event handler + quit_event_loop = true; + break; + } + + if (win == 0) + break; + + if (win->width != e->width || + win->height != e->height || + win->has_been_resized) + { + win->has_been_resized = false; + // this is a resize + win->width = e->width; + win->height = e->height; + win->on_window_resized(); + } + if (win->x != e->x || + win->y != e->y || + win->has_been_moved) + { + win->has_been_moved = false; + // this is a move + win->x = e->x; + win->y = e->y; + win->on_window_moved(); + } + + } break; + + case ClientMessage: + { + XClientMessageEvent* e = reinterpret_cast(&ev); + if ((Atom)e->data.l[0] == delete_window) + { + if (win == 0) + break; + + + if (win->on_window_close() == base_window::DO_NOT_CLOSE_WINDOW) + { + DLIB_ASSERT(win->has_been_destroyed == false, + "\tYou called close_window() inside the on_window_close() event but" + << "\n\tthen returned DO_NOT_CLOSE_WINDOW. You can do one or the other but not both." + << "\n\tthis: " << win + ); + // the client has decided not to close the window + // after all + } + else + { + if (window_table[e->window]) + { + window_table.destroy(e->window); + XDestroyWindow(disp,e->window); + win->has_been_destroyed = true; + window_close_signaler.broadcast(); + } + else + { + // in this case the window must have self destructed by + // calling delete this; so we don't have to do anything. + } + } + } + } break; + + case Expose: + { + XExposeEvent* e = reinterpret_cast(&ev); + + if (win == 0) + break; + + // take all the expose events for this window out + XEvent etemp; + int x = e->x; + int y = e->y; + int width = e->width; + int height = e->height; + + + + // What we are doing here with this loop is we are combining + // all of the Expose events for this window that are + // currently in the queue. + while (XCheckIfEvent(disp,&etemp,XCheckIfEventPredicate,reinterpret_cast(&(e->window)))) + { + XExposeEvent* e2 = reinterpret_cast(&etemp); + if (e2->x < x) + { + width += x - e2->x; + x = e2->x; + } + if (e2->y < y) + { + height += y - e2->y; + y = e2->y; + } + if (e2->width + e2->x > width + x) + { + width = e2->width + e2->x - x; + } + if (e2->height + e2->y > height + y) + { + height = e2->height + e2->y - y; + } + } + + // I'm not sure if this sort of thing can happen but + // if it does then just ignore this entire event. + if (width == 0 || height == 0) + { + break; + } + + if (bitmap_buffer.size() < static_cast(width*height*4)) + bitmap_buffer.resize(width*height*4); + + unsigned char* const bitmap = &bitmap_buffer[0]; + unsigned char* const end = bitmap + width*height*4; + + unsigned char* temp; + canvas c(bitmap,x,y,x+width-1,y+height-1); + + + win->paint(c); + + // the user might have called win->close_window() and if they did + // then just stop right here. We don't want to paint the window. + if (win->has_been_destroyed) + break; + + // if the color depth we are working with isn't 24bits then we need + // to transform our image into whatever it is supposed to be. + if (depth != 24) + { + // convert this image into an 8 bit image + unsigned int red_bits = 0; + unsigned int green_bits = 0; + unsigned int blue_bits = 0; + if (depth != 16) + { + unsigned int bits = depth/3; + unsigned int extra = depth%3; + red_bits = bits; + green_bits = bits; + blue_bits = bits; + if (extra) + { + ++red_bits; + --extra; + } + if (extra) + { + ++green_bits; + } + } + else if (depth == 16) + { + red_bits = 5; + green_bits = 6; + blue_bits = 5; + } + + if (depth == 16) + { + temp = bitmap; + unsigned char *red, *green, *blue; + while (temp != end) + { + blue = temp; + ++temp; + green = temp; + ++temp; + red = temp; + ++temp; + ++temp; + + const unsigned long r = static_cast(*red)>>(8-red_bits); + const unsigned long g = static_cast(*green)>>(8-green_bits); + const unsigned long b = static_cast(*blue)>>(8-blue_bits); + + unsigned long color = (r<<(depth-red_bits))| (g<<(depth-red_bits-green_bits))| b; + + *blue = (color>>0)&0xFF; + *green = (color>>8)&0xFF; + } + } + else if (depth < 24) + { + temp = bitmap; + unsigned char *red, *green, *blue; + while (temp != end) + { + blue = temp; + ++temp; + green = temp; + ++temp; + red = temp; + ++temp; + ++temp; + + const unsigned long r = static_cast(*red)>>(8-red_bits); + const unsigned long g = static_cast(*green)>>(8-green_bits); + const unsigned long b = static_cast(*blue)>>(8-blue_bits); + + unsigned long color = (b<<(depth-blue_bits))| (g<<(depth-blue_bits-green_bits))| r; + + *blue = (color>>0)&0xFF; + *green = (color>>8)&0xFF; + *red = (color>>16)&0xFF; + } + } + else if (depth > 24) + { + temp = bitmap; + unsigned char *red, *green, *blue, *four; + while (temp != end) + { + blue = temp; + ++temp; + green = temp; + ++temp; + red = temp; + ++temp; + four = temp; + ++temp; + + const unsigned long r = static_cast(*red)<<(red_bits-8); + const unsigned long g = static_cast(*green)<<(green_bits-8); + const unsigned long b = static_cast(*blue)<<(blue_bits-8); + + unsigned long color = (b<<(depth-blue_bits))| (g<<(depth-blue_bits-green_bits))| r; + + *blue = (color>>0)&0xFF; + *green = (color>>8)&0xFF; + *red = (color>>16)&0xFF; + *four = (color>>24)&0xFF; + } + } + } // if (depth != 24) + + + + XImage img; + memset(&img,0,sizeof(img)); + img.width = width; + img.height = height; + img.depth = depth; + img.data = reinterpret_cast(bitmap); + img.bitmap_bit_order = LSBFirst; + img.byte_order = LSBFirst; + img.format = ZPixmap; + img.bitmap_pad = 32; + img.bitmap_unit = 32; + img.bits_per_pixel = 32; + + + XInitImage(&img); + + GC gc = XCreateGC(disp, e->window, 0, NULL); + + XPutImage(disp,e->window,gc,&img,0,0,x,y,width,height); + + XFreeGC(disp,gc); + } break; + } // switch (ev.type) + } + } + catch (std::exception& e) + { + dlog << LFATAL << "Exception thrown in event handler: " << e.what(); + } + catch (...) + { + dlog << LFATAL << "Unknown exception thrown in event handler."; + } + } + + // ---------------------------------------------------------------------------------------- + + + int index_to_modmask(unsigned long n) + { + switch ( n ) + { + case 0: + return Mod1Mask; + case 1: + return Mod2Mask; + case 2: + return Mod3Mask; + case 3: + return Mod4Mask; + } + return Mod5Mask; + } + + void event_handler_thread:: + init_keyboard_mod_masks() + { + XModifierKeymap* map = XGetModifierMapping( disp ); + KeyCode* codes = map->modifiermap + map->max_keypermod * Mod1MapIndex; + for (int n = 0; n < 5 * map->max_keypermod; n++ ) + { + if ( codes[n] == 0 ) + continue; + switch(XkbKeycodeToKeysym( disp, codes[n], 0, 0 )) + { + case XK_Alt_L: + alt_mask = index_to_modmask(n / map->max_keypermod); + continue; + case XK_Alt_R: + if(alt_mask == 0) + alt_mask = index_to_modmask(n / map->max_keypermod); + continue; + case XK_Meta_L: + case XK_Meta_R: + meta_mask = index_to_modmask(n / map->max_keypermod); + continue; + case XK_Scroll_Lock: + scroll_lock_mask = index_to_modmask(n / map->max_keypermod); + continue; + case XK_Num_Lock: + num_lock_mask = index_to_modmask(n / map->max_keypermod); + default: + continue; + } + } + XFreeModifiermap( map ); + if ( alt_mask == 0 ) + { + dlog << LWARN << "Search for Alt-key faild."; + if ( meta_mask != 0 ) + alt_mask = meta_mask; + else + alt_mask = Mod1Mask; // resort to guessing + } + } + + // ---------------------------------------------------------------------------------------- + + + + + + } // namespace gui_core_kernel_2_globals + +// ---------------------------------------------------------------------------------------- + + void canvas:: + fill ( + unsigned char red_, + unsigned char green_, + unsigned char blue_ + ) const + { + pixel pixel_value; + pixel_value.red = red_; + pixel_value.green = green_; + pixel_value.blue = blue_; + pixel_value._padding = 0; + + pixel* start = reinterpret_cast(bits); + pixel* end = start + width_*height_; + + while (start != end) + { + *start = pixel_value; + ++start; + } + } + +// ---------------------------------------------------------------------------------------- + + void put_on_clipboard ( + const std::string& str + ) + { + put_on_clipboard(convert_mbstring_to_wstring(str)); + } + + void put_on_clipboard ( + const dlib::ustring& str + ) + { + put_on_clipboard(convert_utf32_to_wstring(str)); + } + + void put_on_clipboard ( + const std::wstring& str + ) + { + using namespace gui_core_kernel_2_globals; + + std::shared_ptr globals(global_data()); + + auto_mutex M(globals->window_table.get_mutex()); + globals->clipboard = str.c_str(); + + XSetSelectionOwner(globals->disp,XA_PRIMARY,globals->exit_window,CurrentTime); + } + +// ---------------------------------------------------------------------------------------- + + Bool clip_peek_helper ( + Display*, + XEvent* event, + XPointer + ) + { + if ( event->type == SelectionNotify) + { + return True; + } + else + { + return False; + } + } + + void get_from_clipboard ( + std::string& str + ) + { + std::wstring wstr; + get_from_clipboard(wstr); + str = convert_wstring_to_mbstring(wstr); + } + + void get_from_clipboard ( + dlib::ustring& str + ) + { + std::wstring wstr; + get_from_clipboard(wstr); + str = convert_wstring_to_utf32(wstr); + } + + void get_from_clipboard ( + std::wstring& str + ) + { + using namespace gui_core_kernel_2_globals; + std::shared_ptr globals(global_data()); + + auto_mutex M(globals->window_table.get_mutex()); + str.clear(); + unsigned char *data = 0; + wchar_t **plist = 0; + Window sown; + Atom type; + int format, result; + unsigned long len, bytes_left, dummy; + XEvent e; + + try + { + Atom atom_ct = XInternAtom(globals->disp, "COMPOUND_TEXT", False); + sown = XGetSelectionOwner (globals->disp, XA_PRIMARY); + if (sown == globals->exit_window) + { + // if we are copying from ourselfs then don't fool with the Xwindows junk. + str = globals->clipboard.c_str(); + } + else if (sown != None) + { + // request that the selection be copied into the XA_PRIMARY property + // of the exit_window. It doesn't matter what window we put it in + // so long as it is one under the control of this process and exit_window + // is easy to use here so that is what I'm using. + XConvertSelection (globals->disp, XA_PRIMARY, atom_ct, XA_PRIMARY, + globals->exit_window, CurrentTime); + + // This will wait until we get a SelectionNotify event which should happen + // really soon. + XPeekIfEvent(globals->disp,&e,clip_peek_helper,0); + + // See how much data we got + XGetWindowProperty (globals->disp, globals->exit_window, + XA_PRIMARY, // Tricky.. + 0, 0, // offset - len + 0, // Delete 0==FALSE + AnyPropertyType, //flag + &type, // return type + &format, // return format + &len, &bytes_left, //that + &data); + if (data) + { + XFree(data); + data = 0; + } + if (bytes_left > 0 && type == atom_ct) + { + XTextProperty p; + result = XGetWindowProperty (globals->disp, globals->exit_window, + XA_PRIMARY, 0,bytes_left,0, + AnyPropertyType, &p.encoding,&p.format, + &p.nitems, &dummy, &p.value); + if (result == Success && p.encoding == atom_ct) + { + int n; + XwcTextPropertyToTextList(globals->disp, &p, &plist, &n); + str = plist[0]; + } + if (plist) + { + XwcFreeStringList(plist); + plist = 0; + } + } + } + } + catch (...) + { + if (data) + XFree(data); + if (plist) + { + XwcFreeStringList(plist); + plist = 0; + } + } + } + +// ---------------------------------------------------------------------------------------- + + namespace gui_core_kernel_2_globals + { + void trigger_user_event_threadproc ( + void* + ) + { + std::shared_ptr globals(global_data()); + auto_mutex M(globals->window_table.get_mutex()); + + globals->user_events.lock(); + globals->user_events.swap(globals->user_events_temp); + globals->user_events.unlock(); + + + globals->user_events_temp.reset(); + // now dispatch all these user events + while (globals->user_events_temp.move_next()) + { + base_window** win_ = globals->window_table[globals->user_events_temp.element().w]; + base_window* win; + // if this window exists in the window table then dispatch + // its event. + if (win_) + { + win = *win_; + win->on_user_event( + globals->user_events_temp.element().p, + globals->user_events_temp.element().i + ); + } + } + globals->user_events_temp.clear(); + } + } + + void base_window:: + trigger_user_event ( + void* p, + int i + ) + { + using namespace gui_core_kernel_2_globals; + user_event_type e; + e.w = x11_stuff.hwnd; + e.p = p; + e.i = i; + { + std::shared_ptr globals(global_data()); + auto_mutex M(globals->user_events.get_mutex()); + globals->user_events.enqueue(e); + + // we only need to start a thread to deal with this if there isn't already + // one out working on the queue + if (globals->user_events.size() == 1) + create_new_thread (trigger_user_event_threadproc,0); + } + } + +// ---------------------------------------------------------------------------------------- + + base_window:: + base_window ( + bool resizable_, + bool undecorated + ) : + x11_stuff(*(new gui_core_kernel_2_globals::x11_base_windowstuff)), + is_mapped(false), + resizable(resizable_), + has_been_destroyed(false), + has_been_resized(false), + has_been_moved(false), + wm(gui_core_kernel_2_globals::global_data()->window_table.get_mutex()) + { + DLIB_ASSERT(!(undecorated == true && resizable_ == true), + "\tbase_window::base_window()" + << "\n\tThere is no such thing as an undecorated window that is resizable by the user." + << "\n\tthis: " << this + ); + using namespace gui_core_kernel_2_globals; + + auto_mutex M(wm); + + x11_stuff.globals = global_data(); + + x11_stuff.last_click_time = 0; + last_click_x = 0; + last_click_y = 0; + last_click_button = NONE; + + XSetWindowAttributes attr; + memset(&attr,'\0',sizeof(attr)); + + unsigned long valuemask = 0; + if (undecorated) + { + attr.override_redirect = True; + valuemask = CWOverrideRedirect; + } + + + x11_stuff.hwnd = XCreateWindow( + x11_stuff.globals->disp, + DefaultRootWindow(x11_stuff.globals->disp), + 0, + 0, + 10, // this is the default width of a window + 10, // this is the default width of a window + 0, + x11_stuff.globals->depth, + InputOutput, + CopyFromParent, + valuemask, + &attr + ); + + x11_stuff.xic = NULL; + if (x11_stuff.globals->xim) + { + XVaNestedList xva_nlist; + XPoint xpoint; + + char **mlist; + int mcount; + char *def_str; + char fontset[256]; + const long native_font_height = 12; + sprintf(fontset, "-*-*-medium-r-normal--%lu-*-*-*-", native_font_height); + x11_stuff.fs = XCreateFontSet(x11_stuff.globals->disp, fontset, &mlist, &mcount, &def_str); + xpoint.x = 0; + xpoint.y = 0; + xva_nlist = XVaCreateNestedList(0, XNSpotLocation, &xpoint, XNFontSet, x11_stuff.fs, (const void*)NULL); + x11_stuff.xic = XCreateIC( + x11_stuff.globals->xim, + XNInputStyle, x11_stuff.globals->xim_style, + XNClientWindow, x11_stuff.hwnd, + XNPreeditAttributes, xva_nlist, + (const void*)NULL + ); + XFree(xva_nlist); + XFreeStringList(mlist); + } + + Window temp = x11_stuff.hwnd; + base_window* ttemp = this; + x11_stuff.globals->window_table.add(temp,ttemp); + + // query event mask required by input method + unsigned long event_xim = 0; + if (x11_stuff.xic) + XGetICValues( x11_stuff.xic, XNFilterEvents, &event_xim, (const void*)NULL ); + + XSelectInput( + x11_stuff.globals->disp, + x11_stuff.hwnd, + StructureNotifyMask|ExposureMask|ButtonPressMask|ButtonReleaseMask| + PointerMotionMask|LeaveWindowMask|EnterWindowMask|KeyPressMask| + KeyReleaseMask| FocusChangeMask | event_xim + ); + + XSetWMProtocols( + x11_stuff.globals->disp, + x11_stuff.hwnd, + &x11_stuff.globals->delete_window, + 1 + ); + + + // these are just default values + x = 0; + y = 0; + width = 10; + height = 10; + + if (resizable == false) + { + XSizeHints* hints = XAllocSizeHints(); + hints->flags = PMinSize|PMaxSize; + hints->min_width = width; + hints->max_width = width; + hints->max_height = height; + hints->min_height = height; + XSetNormalHints(x11_stuff.globals->disp,x11_stuff.hwnd,hints); + XFree(hints); + } + } + +// ---------------------------------------------------------------------------------------- + + base_window:: + ~base_window ( + ) + { + using namespace gui_core_kernel_2_globals; + close_window(); + + if (x11_stuff.globals->xim != NULL) + { + XDestroyIC(x11_stuff.xic); + x11_stuff.xic = 0; + XFreeFontSet(x11_stuff.globals->disp,x11_stuff.fs); + } + + delete &x11_stuff; + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + close_window ( + ) + { + using namespace gui_core_kernel_2_globals; + auto_mutex M(wm); + if (has_been_destroyed == false) + { + has_been_destroyed = true; + + x11_stuff.globals->window_table.destroy(x11_stuff.hwnd); + + XDestroyWindow(x11_stuff.globals->disp,x11_stuff.hwnd); + x11_stuff.hwnd = 0; + x11_stuff.globals->window_close_signaler.broadcast(); + } + } + +// ---------------------------------------------------------------------------------------- + + bool base_window:: + is_closed ( + ) const + { + auto_mutex M(wm); + return has_been_destroyed; + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + set_title ( + const std::string& title_ + ) + { + set_title(convert_mbstring_to_wstring(title_)); + } + + void base_window:: + set_title ( + const ustring& title_ + ) + { + set_title(convert_utf32_to_wstring(title_)); + } + + void base_window:: + set_title ( + const std::wstring& title_ + ) + { + using namespace gui_core_kernel_2_globals; + auto_mutex M(wm); + if (has_been_destroyed == true) + return; + + // I'm pretty sure the pointer won't be modified even though + // it isn't const anymore. + wchar_t *title = const_cast(title_.c_str()); + XTextProperty property; + XwcTextListToTextProperty(x11_stuff.globals->disp,&title,1,XStdICCTextStyle, &property); + XSetWMName(x11_stuff.globals->disp,x11_stuff.hwnd,&property); + XFree(property.value); + XFlush(x11_stuff.globals->disp); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + show ( + ) + { + using namespace gui_core_kernel_2_globals; + auto_mutex M(wm); + if (has_been_destroyed == true) + return; + + XMapRaised(x11_stuff.globals->disp,x11_stuff.hwnd); + XFlush(x11_stuff.globals->disp); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + wait_until_closed ( + ) const + { + using namespace gui_core_kernel_2_globals; + auto_mutex M(wm); + while (has_been_destroyed == false) + x11_stuff.globals->window_close_signaler.wait(); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + hide ( + ) + { + using namespace gui_core_kernel_2_globals; + auto_mutex M(wm); + if (has_been_destroyed == true) + return; + + XUnmapWindow(x11_stuff.globals->disp,x11_stuff.hwnd); + XFlush(x11_stuff.globals->disp); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + set_size ( + int width_, + int height_ + ) + { + using namespace gui_core_kernel_2_globals; + auto_mutex a(wm); + if (has_been_destroyed == true) + return; + + + // do some sanity checking on these values + if (width_ < 1) + width_ = 1; + if (height_ < 1) + height_ = 1; + + width = width_; + height = height_; + has_been_resized = true; + + if (resizable == false) + { + XSizeHints* hints = XAllocSizeHints(); + hints->flags = PMinSize|PMaxSize; + hints->min_width = width; + hints->max_width = width; + hints->max_height = height; + hints->min_height = height; + XSetNormalHints(x11_stuff.globals->disp,x11_stuff.hwnd,hints); + XFree(hints); + } + + XResizeWindow(x11_stuff.globals->disp,x11_stuff.hwnd,width,height); + + XFlush(x11_stuff.globals->disp); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + set_pos ( + long x_, + long y_ + ) + { + using namespace gui_core_kernel_2_globals; + auto_mutex a(wm); + if (has_been_destroyed == true) + return; + + + x = x_; + y = y_; + + has_been_moved = true; + + XMoveWindow(x11_stuff.globals->disp,x11_stuff.hwnd,x,y); + XFlush(x11_stuff.globals->disp); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + get_pos ( + long& x_, + long& y_ + ) + { + using namespace gui_core_kernel_2_globals; + auto_mutex a(wm); + x_ = 0; + y_ = 0; + if (has_been_destroyed == true) + return; + + // we can't really trust the values we have for x and y because some window managers + // will have reported bogus values back in the ConfigureNotify event. So just to be + // on the safe side we will use XTranslateCoordinates() + int rx, ry; + Window desktop_window = DefaultRootWindow(x11_stuff.globals->disp); + Window junk; + XTranslateCoordinates(x11_stuff.globals->disp,x11_stuff.hwnd,desktop_window,0,0,&rx, &ry, &junk); + x_ = rx; + y_ = ry; + x = rx; + y = ry; + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + get_size ( + unsigned long& width_, + unsigned long& height_ + ) const + { + auto_mutex M(wm); + width_ = 0; + height_ = 0; + if (has_been_destroyed == true) + return; + + + width_ = width; + height_ = height; + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + get_display_size ( + unsigned long& width_, + unsigned long& height_ + ) const + { + using namespace gui_core_kernel_2_globals; + auto_mutex M(wm); + width_ = 0; + height_ = 0; + if (has_been_destroyed == true) + return; + + int screen_number = XScreenNumberOfScreen(x11_stuff.globals->screen); + width_ = DisplayWidth(x11_stuff.globals->disp, screen_number); + height_ = DisplayHeight(x11_stuff.globals->disp, screen_number); + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + invalidate_rectangle ( + const rectangle& rect + ) + { + using namespace gui_core_kernel_2_globals; + auto_mutex a(wm); + if (is_mapped == false) + return; + + if (rect.is_empty() == false && !has_been_destroyed) + { + const long x = rect.left(); + const long y = rect.top(); + const unsigned long width = rect.width(); + const unsigned long height = rect.height(); + + XClearArea(x11_stuff.globals->disp,x11_stuff.hwnd,x,y,width,height,1); + XFlush(x11_stuff.globals->disp); + } + } + +// ---------------------------------------------------------------------------------------- + + void base_window:: + set_im_pos ( + long x, + long y + ) + { + using namespace gui_core_kernel_2_globals; + auto_mutex a(wm); + if (has_been_destroyed == true) + return; + + if (!x11_stuff.xic || !(x11_stuff.globals->xim_style & XIMPreeditPosition)) return; + + XVaNestedList xva_nlist; + XPoint xpoint; + + xpoint.x = x; + xpoint.y = y; + + xva_nlist = XVaCreateNestedList(0, XNSpotLocation, &xpoint, (const void*)NULL); + XSetICValues(x11_stuff.xic, XNPreeditAttributes, xva_nlist, (const void*)NULL); + XFree(xva_nlist); + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // POSIX + +#endif // DLIB_GUI_CORE_KERNEL_2_CPp_ + diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_2.h b/ml/dlib/dlib/gui_core/gui_core_kernel_2.h new file mode 100644 index 000000000..efcd4ba19 --- /dev/null +++ b/ml/dlib/dlib/gui_core/gui_core_kernel_2.h @@ -0,0 +1,419 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GUI_CORE_KERNEl_2_ +#define DLIB_GUI_CORE_KERNEl_2_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + +#ifdef DLIB_NO_GUI_SUPPORT +#error "DLIB_NO_GUI_SUPPORT is defined so you can't use the GUI code. Turn DLIB_NO_GUI_SUPPORT off if you want to use it." +#error "Also make sure you have libx11-dev installed on your system" +#endif + +#include + +#include "gui_core_kernel_abstract.h" +#include "../algs.h" +#include "../threads.h" +#include "../geometry/rectangle.h" +#include "../binary_search_tree.h" +#include +#include "../pixel.h" +#include "../unicode.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace gui_core_kernel_2_globals + { + class event_handler_thread; + void trigger_user_event_threadproc (void*); + + // This is a forward declaration for a struct that contains any + // X11 variables. This allows me to avoid having any dlib header files + // include the X11 headers. Which in turn speeds build times and simplifies + // build setups. + struct x11_base_windowstuff; + } + +// ---------------------------------------------------------------------------------------- + + void put_on_clipboard ( + const std::string& str + ); + + void put_on_clipboard ( + const std::wstring& str + ); + + void put_on_clipboard ( + const dlib::ustring& str + ); + +// ---------------------------------------------------------------------------------------- + + void get_from_clipboard ( + std::string& str + ); + + void get_from_clipboard ( + std::wstring& str + ); + + void get_from_clipboard ( + dlib::ustring& str + ); + +// ---------------------------------------------------------------------------------------- + + class canvas : public rectangle + { + public: + struct pixel + { + unsigned char blue; + unsigned char green; + unsigned char red; + private: + friend class canvas; + unsigned char _padding; + }; + + ~canvas() {} + + inline pixel* operator[] ( + unsigned long row + ) const + { + DLIB_ASSERT(row < height(), + "\tpixel* canvas::operator[]" + << "\n\tyou have to give a row that is less than the height()" + << "\n\tthis: " << this + << "\n\trow: " << row + << "\n\theight(): " << height() + ); + unsigned char* temp = bits + row_width*row; + return reinterpret_cast(temp); + } + + void fill ( + unsigned char red_, + unsigned char green_, + unsigned char blue_ + ) const; + + private: + + friend class gui_core_kernel_2_globals::event_handler_thread; + + + canvas ( + unsigned char* bits_, + unsigned long left_, + unsigned long top_, + unsigned long right_, + unsigned long bottom_ + ) : + rectangle(left_,top_,right_,bottom_), + bits(bits_), + width_(width()), + height_(height()), + row_width(width_*4) + {} + + // restricted functions + canvas(); // normal constructor + canvas(canvas&); // copy constructor + canvas& operator=(canvas&); // assignment operator + + unsigned char* const bits; + const unsigned long width_; + const unsigned long height_; + const unsigned long row_width; + }; + + template <> + struct pixel_traits + { + constexpr static bool rgb = true; + constexpr static bool rgb_alpha = false; + constexpr static bool grayscale = false; + constexpr static bool hsi = false; + constexpr static long num = 3; + typedef unsigned char basic_pixel_type; + static basic_pixel_type min() { return 0;} + static basic_pixel_type max() { return 255;} + constexpr static bool is_unsigned = true; + constexpr static bool has_alpha = false; + }; + +// ----------------- + + class base_window + { + friend class gui_core_kernel_2_globals::event_handler_thread; + friend void gui_core_kernel_2_globals::trigger_user_event_threadproc (void*); + + public: + + enum mouse_state_masks + { + NONE = 0, + LEFT = 1, + RIGHT = 2, + MIDDLE = 4, + SHIFT = 8, + CONTROL = 16 + }; + + enum keyboard_state_masks + { + KBD_MOD_NONE = 0, + KBD_MOD_SHIFT = 1, + KBD_MOD_CONTROL = 2, + KBD_MOD_ALT = 4, + KBD_MOD_META = 8, + KBD_MOD_CAPS_LOCK = 16, + KBD_MOD_NUM_LOCK = 32, + KBD_MOD_SCROLL_LOCK = 64 + }; + + enum on_close_return_code + { + DO_NOT_CLOSE_WINDOW, + CLOSE_WINDOW + }; + + enum non_printable_keyboard_keys + { + KEY_BACKSPACE, + KEY_SHIFT, + KEY_CTRL, + KEY_ALT, + KEY_PAUSE, + KEY_CAPS_LOCK, + KEY_ESC, + KEY_PAGE_UP, + KEY_PAGE_DOWN, + KEY_END, + KEY_HOME, + KEY_LEFT, // This is the left arrow key + KEY_RIGHT, // This is the right arrow key + KEY_UP, // This is the up arrow key + KEY_DOWN, // This is the down arrow key + KEY_INSERT, + KEY_DELETE, + KEY_SCROLL_LOCK, + + // Function Keys + KEY_F1, + KEY_F2, + KEY_F3, + KEY_F4, + KEY_F5, + KEY_F6, + KEY_F7, + KEY_F8, + KEY_F9, + KEY_F10, + KEY_F11, + KEY_F12 + }; + + private: + + gui_core_kernel_2_globals::x11_base_windowstuff& x11_stuff; + + int x, y, width, height; + bool is_mapped; + + const bool resizable; + bool has_been_destroyed; + bool has_been_resized; // true if someone called set_size() and the on_window_resized() event + // hasn't yet occurred. + bool has_been_moved; // true if someone called set_pos() and the on_window_moved() event + // hasn't yet occurred. + + + // The following 3 variables (and x11_stuff.last_click_time) are only accessed from the + // event handling loop (except for being initialized below). They record the last + // mouse click event details. + long last_click_x, last_click_y; + unsigned long last_click_button; + + + protected: + const rmutex& wm; + + public: + + base_window ( + bool resizable_ = true, + bool undecorated = false + ); + + virtual ~base_window ( + ); + + void close_window ( + ); + + void wait_until_closed ( + ) const; + + void set_im_pos ( + long x_, + long y_ + ); + + bool is_closed ( + ) const; + + void set_title ( + const std::string& title_ + ); + + void set_title ( + const std::wstring& title_ + ); + + void set_title ( + const dlib::ustring& title_ + ); + + virtual void show ( + ); + + virtual void hide( + ); + + void set_size ( + int width_, + int height_ + ); + + void set_pos ( + long x_, + long y_ + ); + + void get_pos ( + long& x_, + long& y_ + ); + + void get_size ( + unsigned long& width_, + unsigned long& height_ + ) const; + + void get_display_size ( + unsigned long& width, + unsigned long& height + ) const; + + void invalidate_rectangle ( + const rectangle& rect + ); + + void trigger_user_event ( + void* p, + int i + ); + + protected: + + virtual on_close_return_code on_window_close( + ){return CLOSE_WINDOW;} + + virtual void on_window_resized( + ){} + + virtual void on_window_moved( + ){} + virtual void on_user_event ( + void* , + int + ){} + + virtual void on_mouse_down ( + unsigned long , + unsigned long , + long , + long , + bool + ){} + + virtual void on_mouse_up ( + unsigned long , + unsigned long , + long , + long + ){} + + virtual void on_mouse_move ( + unsigned long , + long , + long + ){} + + virtual void on_mouse_leave ( + ){} + + virtual void on_mouse_enter ( + ){} + + virtual void on_wheel_up ( + unsigned long + ){} + + virtual void on_wheel_down ( + unsigned long + ){} + + virtual void on_focus_gained ( + ){} + + virtual void on_focus_lost ( + ){} + + virtual void on_keydown ( + unsigned long , + bool , + unsigned long + ){} + + virtual void on_string_put ( + const std::wstring& + ){} + + private: + + virtual void paint ( + const canvas& c + ) =0; + + + + base_window(base_window&); // copy constructor + base_window& operator=(base_window&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + + +} + + +#ifdef NO_MAKEFILE +#include "gui_core_kernel_2.cpp" +#endif + +#endif // DLIB_GUI_CORE_KERNEl_2_ + diff --git a/ml/dlib/dlib/gui_core/gui_core_kernel_abstract.h b/ml/dlib/dlib/gui_core/gui_core_kernel_abstract.h new file mode 100644 index 000000000..a773e4287 --- /dev/null +++ b/ml/dlib/dlib/gui_core/gui_core_kernel_abstract.h @@ -0,0 +1,792 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GUI_CORE_KERNEl_ABSTRACT_ +#ifdef DLIB_GUI_CORE_KERNEl_ABSTRACT_ + +#include +#include "../algs.h" +#include "../geometry/rectangle_abstract.h" +#include "../unicode/unicode_abstract.h" + +namespace dlib +{ + + /*! + OVERVIEW: + This is a set of objects and functions which provide a very basic + framework for manipulating windows. It is intended to provide a + portable interface which can be used to build a more complex windowing + toolkit. + + EXCEPTIONS + Do not let an exception leave any of the base_window event handlers. + The results of doing so are undefined. + + THREAD SAFETY + Event Handlers + All event handlers are executed in a special event handling thread. + This means that you must not do anything that will take a long time or + block while in an event handler. Doing so will freeze all event + processing. + + Also, don't rely on get_thread_id() always returning the same ID from + inside event handlers. + + canvas + Never access a canvas object outside of the paint() callback + that supplied it. Only access a canvas object from the event + handling thread. After the paint() event handler has returned do + not access that canvas object again. + + base_window + All methods for this class are thread safe. You may call them + from any thread and do not need to serialize access. + !*/ + +// ---------------------------------------------------------------------------------------- + + void put_on_clipboard ( + const std::string& str + ); + /*! + ensures + - posts the contents of str to the system clipboard + throws + - std::bad_alloc + - dlib::gui_error + - dlib::thread_error + !*/ + + // overloads for wide character strings + void put_on_clipboard (const std::wstring& str); + void put_on_clipboard (const dlib::ustring& str); + +// ---------------------------------------------------------------------------------------- + + void get_from_clipboard ( + std::string& str + ); + /*! + ensures + - if (there is string data on the system clipboard) then + - #str == the data from the clipboard + - else + - #str == "" + throws + - std::bad_alloc + - dlib::gui_error + - dlib::thread_error + !*/ + + // overloads for wide character strings + void get_from_clipboard (std::wtring& str); + void get_from_clipboard (dlib::utring& str); + +// ---------------------------------------------------------------------------------------- + + + class canvas : public rectangle + { + /*! + POINTERS AND REFERENCES TO INTERNAL DATA + All functions of this object may invalidate pointers and references + to internal data. + + INITIAL VALUE + The initial value of each pixel is undefined. + is_empty() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a rectangular area of a window that you + can draw on. + + Each pixel can be accessed with the following syntax: + canvas_instance[y][x].red == the red value for this pixel + canvas_instance[y][x].blue == the blue value for this pixel + canvas_instance[y][x].green == the green value for this pixel + + The origin, i.e. (0,0), of the x,y coordinate plane of the canvas is in + the upper left corner of the canvas. Note that the upper left corner + of the canvas appears at the point (left(),top()) in its window. + !*/ + + public: + + struct pixel + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a single pixel. Each pixel's value + ranges from 0 to 255 with 0 indicating that the color is not + present in the pixel at all and 255 indicating that the color + is present in the pixel with maximum intensity. + + Note that the structure, order, and size of of this struct are + implementation dependent. It will always contain fields called + red, green, and blue but they may not be in that order and there + may be padding. + + Also note that pixel_traits<> is defined for this pixel type, + thus you can use it in assign_pixel() calls. + !*/ + unsigned char red; + unsigned char green; + unsigned char blue; + }; + + + pixel* operator[] ( + unsigned long row + ) const; + /*! + requires + - row < height() + ensures + - returns an array of width() pixel structs that represents the given + row of pixels in the canvas. + !*/ + + void fill ( + unsigned char red, + unsigned char green, + unsigned char blue + ) const; + /*! + ensures + - for all valid values of x and y: + - (#*this)[y][x].red = red + - (#*this)[y][x].green = green + - (#*this)[y][x].blue = blue + !*/ + + private: + + // restricted functions + canvas(); // normal constructor + canvas(canvas&); // copy constructor + canvas& operator=(canvas&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class base_window + { + + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a window on the desktop. A window has a "client + area" that is a region of the screen that you can draw whatever you like + on. You implement the paint() callback and use the canvas object to do + this drawing. + + INITIAL STATE + - The initial state of the window is to be hidden. This means you need + to call show() to make it appear. + - is_closed() == false + + paint() callback: + This is where you will do all your drawing. It is triggered when + part of the window needs to be drawn/redrawn. + + mouse events: + It is important to note a few things about the mouse events. First, + the on_mouse_move() event is not triggered for each pixel the mouse crosses + but rather its frequency and precision is implementation dependent. + + Second, it is possible that a mouse button may be depressed but the + corresponding button release event does not go to the window. For instance, + if the mouse is outside the window and some other application jumps to the + top it is possible that the new application will receive any mouse button + release events rather than the original window. But the point is that + you should not rely on always getting a button up event for every button + down event. + + keydown event: + Note that the existence of a typematic action (holding down a key + and having it start to repeat itself after a moment) for each key is + totally implementation dependent. So don't rely on it for any key + and conversely don't assume it isn't present either. + + The base_window::wm mutex + This is a reference to a global rmutex. All instances of base_window make + reference to the same global rmutex. It is used to synchronize access to + the base_window to make it thread safe. It is also always locked before + an event handler is called. + !*/ + + public: + + enum on_close_return_code + { + DO_NOT_CLOSE_WINDOW, + CLOSE_WINDOW + }; + + enum mouse_state_masks + { + /*! + These constants represent the various buttons referenced by + mouse events. + !*/ + NONE = 0, + LEFT = 1, + RIGHT = 2, + MIDDLE = 4, + SHIFT = 8, + CONTROL = 16 + }; + + enum keyboard_state_masks + { + /*! + These constants represent the various modifier buttons that + could be in effect during a key press on the keyboard + !*/ + KBD_MOD_NONE = 0, + KBD_MOD_SHIFT = 1, + KBD_MOD_CONTROL = 2, + KBD_MOD_ALT = 4, + KBD_MOD_META = 8, + KBD_MOD_CAPS_LOCK = 16, + KBD_MOD_NUM_LOCK = 32, + KBD_MOD_SCROLL_LOCK = 64 + }; + + enum non_printable_keyboard_keys + { + KEY_BACKSPACE, + KEY_SHIFT, + KEY_CTRL, + KEY_ALT, + KEY_PAUSE, + KEY_CAPS_LOCK, + KEY_ESC, + KEY_PAGE_UP, + KEY_PAGE_DOWN, + KEY_END, + KEY_HOME, + KEY_LEFT, // This is the left arrow key + KEY_RIGHT, // This is the right arrow key + KEY_UP, // This is the up arrow key + KEY_DOWN, // This is the down arrow key + KEY_INSERT, + KEY_DELETE, + KEY_SCROLL_LOCK, + + // Function Keys + KEY_F1, + KEY_F2, + KEY_F3, + KEY_F4, + KEY_F5, + KEY_F6, + KEY_F7, + KEY_F8, + KEY_F9, + KEY_F10, + KEY_F11, + KEY_F12 + }; + + base_window ( + bool resizable = true, + bool undecorated = false + ); + /*! + requires + - if (undecorated == true) then + - resizable == false + ensures + - #*this has been properly initialized + - if (resizable == true) then + - this window will be resizable by the user + - else + - this window will not be resizable by the user + - if (undecorated == true) then + - this window will not have any graphical elements outside + of its drawable area or appear in the system task bar. It + also won't take the input focus from other windows. + (it is suitable for making things such as popup menus) + throws + - std::bad_alloc + - dlib::thread_error + - dlib::gui_error + This exception is thrown if there is an error while + creating this window. + !*/ + + virtual ~base_window ( + ); + /*! + ensures + - does NOT trigger the on_window_close() event + - all resources associated with *this have been released + - closes this window + !*/ + + void close_window ( + ); + /*! + ensures + - #is_closed() == true + (i.e. permanently closes this window. The window is removed from the + screen and no more events will be dispatched to this window. ) + - does NOT trigger the on_window_close() event + !*/ + + void wait_until_closed ( + ) const; + /*! + ensures + - blocks until is_closed() == true + !*/ + + bool is_closed ( + ) const; + /*! + ensures + - returns true if this window has been closed, false otherwise. + (Note that closed windows do not receive any callbacks at all. + They are also not visible on the screen.) + !*/ + + void set_title ( + const std::string& title + ); + /*! + ensures + - if (is_closed() == false) then + - sets the title of the window + !*/ + + void set_title ( + const std::wstring& title + ); + /*! + ensures + - if (is_closed() == false) then + - sets the title of the window + !*/ + + void set_title ( + const dlib::ustring& title + ); + /*! + ensures + - if (is_closed() == false) then + - sets the title of the window + !*/ + + virtual void show ( + ); + /*! + ensures + - if (is_closed() == false) then + - this window will appear on the screen + !*/ + + virtual void hide( + ); + /*! + ensures + - if (is_closed() == false) then + - the window does not appear on the screen + !*/ + + void set_size ( + int width, + int height + ); + /*! + ensures + - if (is_closed() == false) then + - The width of the client area of this window is at least width + pixels. + - The height of the client area of this window is at least height + pixels. + - if (the window wasn't already this size) then + - triggers the on_window_resized() callback + !*/ + + void set_pos ( + long x, + long y + ); + /*! + ensures + - if (is_closed() == false) then + - sets the upper left corner of this window to the position (x,y) + on the desktop. Note that the origin (0,0) is at the upper left + corner of the desktop. + !*/ + + void get_pos ( + long& x, + long& y + ) const; + /*! + ensures + - if (is_closed() == false) then + - #x == the x coordinate of the upper left corner of the client area of + this window. + - #y == the y coordinate of the upper left corner of the client area of + this window. + - i.e. the point (#x,#y) on the desktop is coincident with the point + (0,0) in the client area of this window. + - else + - #x == 0 + - #y == 0 + !*/ + + void get_size ( + unsigned long& width, + unsigned long& height + ) const; + /*! + ensures + - if (is_closed() == false) then + - #width == the width of the client area of this window in pixels + - #height == the height of the client area of this window in pixels + - else + - #width == 0 + - #height == 0 + !*/ + + void get_display_size ( + unsigned long& width, + unsigned long& height + ) const; + /*! + ensures + - if (is_closed() == false) then + - #width == the width in pixels of the display device that contains this window + - #height == the height in pixels of the display device that contains this window + - else + - #width == 0 + - #height == 0 + !*/ + + void invalidate_rectangle ( + const rectangle& rect + ); + /*! + ensures + - if (is_closed() == false) then + - causes the area of this window defined by rect to become invalid. + This means that a paint() message will be dispatched to repaint + this area of the window. Note that it is possible that the + resulting paint() message may include a bigger rectangle than + the one defined by rect. + !*/ + + void trigger_user_event ( + void* p, + int i + ); + /*! + ensures + - will never block (even if some other thread has a lock on the + global mutex referenced by wm.) + - if (is_closed() == false) then + - causes the on_user_event() event to be called with + the given arguments. + !*/ + + void set_im_pos ( + long x_, + long y_ + ); + /*! + ensures + - if (is_closed() == false) then + - sets the left-top position of input method rectangle used + for wide character input methods. + !*/ + + protected: + const rmutex& wm; + + // let the window close by default + virtual on_close_return_code on_window_close( + ){return CLOSE_WINDOW;} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when the user attempts to close this window + - if (this function returns CLOSE_WINDOW) then + - #is_closed() == true (i.e. this window will be closed) + - it is safe to call "delete this;" inside on_window_close() + if *this was allocated on the heap and no one will try to + access *this anymore. + - else + - this window will not be closed and the attempt to close it + by the user will have no effect. + - #is_closed() == false + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_user_event ( + void* p, + int i + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called whenever someone calls trigger_user_event() + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_window_resized( + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when this window is resized + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_window_moved( + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when this window's position changes + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when the user depresses one of the mouse buttons + - btn == the button that was depressed. (either LEFT, MIDDLE, or RIGHT) + - state == the bitwise OR of the buttons that are currently depressed + excluding the button given by btn. (from the mouse_state_masks enum) + - (x,y) == the position of the mouse (relative to the upper left corner + of the window) when this event occurred. Note that the mouse may be + outside the window. + - if (this is the second button press of a double click) then + - is_double_click == true + - else + - is_double_click == false + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when the user releases one of the mouse buttons + - btn == the button that was released. (either LEFT, MIDDLE, or RIGHT) + - state == the bitwise OR of the buttons that are currently depressed + (from the mouse_state_masks enum) + - (x,y) == the position of the mouse (relative to the upper left corner + of the window) when this event occurred. Note that the mouse may be + outside the window. + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_mouse_move ( + unsigned long state, + long x, + long y + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when the user moves the mouse + - state == the bitwise OR of the buttons that are currently depressed + (from the mouse_state_masks enum) + - (x,y) == the position of the mouse (relative to the upper left corner + of the window) when this event occurred. + - if (the user is holding down one or more of the mouse buttons) then + - the mouse move events will continue to track the mouse even if + it goes out of the window. This will continue until the user + releases all the mouse buttons. + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_mouse_leave ( + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when the mouse leaves this window + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_mouse_enter ( + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when the mouse enters this window + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_focus_gained ( + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when this window gains input focus + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_focus_lost ( + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when this window loses input focus + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_wheel_up ( + unsigned long state + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called every time the mouse wheel is scrolled up one notch + - state == the bitwise OR of the buttons that are currently depressed + (from the mouse_state_masks enum) + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_wheel_down ( + unsigned long state + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called every time the mouse wheel is scrolled down one notch + - state == the bitwise OR of the buttons that are currently depressed + (from the mouse_state_masks enum) + ensures + - does not change the state of mutex wm + !*/ + + // do nothing by default + virtual void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when a keyboard key is pressed or if a key is held + down then this is called repeatedly at a certain rate once the + typematic action begins (note that some keys might not have any + typematic action on some platforms). + - if (is_printable) then + - key == the character that was pressed. (e.g. 'a', 'b', '1' etc.) + - this is a printable character. Note that ' ', '\t', and + '\n' (this is the return/enter key) are all considered printable. + - else + - key == one of the non_printable_keyboard_keys enums. + - state == the bitwise OR of the keyboard modifiers that are currently + depressed (taken from keyboard_state_masks). + - if (key is not in the range 'a' to 'z' or 'A' to 'Z') then + - if (the shift key was down when this key was pressed) then + - (state & KBD_MOD_SHIFT) != 0 + - else + - (state & KBD_MOD_SHIFT) == 0 + - else + - the state of the shift key is implementation defined + ensures + - does not change the state of mutex wm + !*/ + + virtual void on_string_put ( + const std::wstring &str + ){} + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when a wide/multibyte character input method determines a string + that is being input to the window. + - str == the string that is being input + ensures + - does not change the state of mutex wm + !*/ + + private: + + virtual void paint ( + const canvas& c + ) =0; + /*! + requires + - is_closed() == false + - mutex wm is locked + - is called when part of the window needs to be repainted for + any reason. + - c == a canvas object that represents the invalid area of this + window which needs to be painted. + ensures + - does not change the state of mutex wm + !*/ + + base_window(base_window&); // copy constructor + base_window& operator=(base_window&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GUI_CORE_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/gui_core/windows.h b/ml/dlib/dlib/gui_core/windows.h new file mode 100644 index 000000000..8d71fd70e --- /dev/null +++ b/ml/dlib/dlib/gui_core/windows.h @@ -0,0 +1,6 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GUI_CORE_KERNEl_2_ +#include "gui_core_kernel_1.h" +#endif + diff --git a/ml/dlib/dlib/gui_core/xlib.h b/ml/dlib/dlib/gui_core/xlib.h new file mode 100644 index 000000000..de9c4666b --- /dev/null +++ b/ml/dlib/dlib/gui_core/xlib.h @@ -0,0 +1,6 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GUI_CORE_KERNEl_1_ +#include "gui_core_kernel_2.h" +#endif + diff --git a/ml/dlib/dlib/gui_widgets.h b/ml/dlib/dlib/gui_widgets.h new file mode 100644 index 000000000..5b243ef8f --- /dev/null +++ b/ml/dlib/dlib/gui_widgets.h @@ -0,0 +1,18 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_GUI_WIDGETs_ +#define DLIB_GUI_WIDGETs_ + + + +#include "gui_widgets/widgets.h" + + + +#endif // DLIB_GUI_WIDGETs_ + diff --git a/ml/dlib/dlib/gui_widgets/base_widgets.cpp b/ml/dlib/dlib/gui_widgets/base_widgets.cpp new file mode 100644 index 000000000..2f2eb8e9c --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/base_widgets.cpp @@ -0,0 +1,3343 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASE_WIDGETs_CPP_ +#define DLIB_BASE_WIDGETs_CPP_ + +#include +#include + +#include "base_widgets.h" +#include "../assert.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // button object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void button:: + set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex M(m); + rectangle min_rect = style->get_min_size(name_,*mfont); + // only change the size if it isn't going to be too small to fit the name + if (height >= min_rect.height() && + width >= min_rect.width()) + { + rectangle old(rect); + rect = resize_rect(rect,width,height); + parent.invalidate_rectangle(style->get_invalidation_rect(rect+old)); + btn_tooltip.set_size(width,height); + } + } + +// ---------------------------------------------------------------------------------------- + + void button:: + show ( + ) + { + button_action::show(); + btn_tooltip.show(); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + hide ( + ) + { + button_action::hide(); + btn_tooltip.hide(); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + enable ( + ) + { + button_action::enable(); + btn_tooltip.enable(); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + disable ( + ) + { + button_action::disable(); + btn_tooltip.disable(); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + set_tooltip_text ( + const std::string& text + ) + { + btn_tooltip.set_text(text); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + set_tooltip_text ( + const std::wstring& text + ) + { + btn_tooltip.set_text(text); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + set_tooltip_text ( + const ustring& text + ) + { + btn_tooltip.set_text(text); + } + +// ---------------------------------------------------------------------------------------- + + const std::string button:: + tooltip_text ( + ) const + { + return btn_tooltip.text(); + } + + const std::wstring button:: + tooltip_wtext ( + ) const + { + return btn_tooltip.wtext(); + } + + const dlib::ustring button:: + tooltip_utext ( + ) const + { + return btn_tooltip.utext(); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + set_name(name_); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + set_pos ( + long x, + long y + ) + { + auto_mutex M(m); + button_action::set_pos(x,y); + btn_tooltip.set_pos(x,y); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + set_name ( + const std::string& name + ) + { + set_name(convert_mbstring_to_wstring(name)); + } + + void button:: + set_name ( + const std::wstring& name + ) + { + set_name(convert_wstring_to_utf32(name)); + } + + void button:: + set_name ( + const ustring& name + ) + { + auto_mutex M(m); + name_ = name; + // do this to get rid of any reference counting that may be present in + // the std::string implementation. + name_[0] = name_[0]; + + rectangle old(rect); + rect = move_rect(style->get_min_size(name,*mfont),rect.left(),rect.top()); + btn_tooltip.set_size(rect.width(),rect.height()); + + parent.invalidate_rectangle(style->get_invalidation_rect(rect+old)); + } + +// ---------------------------------------------------------------------------------------- + + const std::string button:: + name ( + ) const + { + auto_mutex M(m); + std::string temp = convert_wstring_to_mbstring(wname()); + // do this to get rid of any reference counting that may be present in + // the std::string implementation. + char c = temp[0]; + temp[0] = c; + return temp; + } + + const std::wstring button:: + wname ( + ) const + { + auto_mutex M(m); + std::wstring temp = convert_utf32_to_wstring(uname()); + // do this to get rid of any reference counting that may be present in + // the std::wstring implementation. + wchar_t w = temp[0]; + temp[0] = w; + return temp; + } + + const dlib::ustring button:: + uname ( + ) const + { + auto_mutex M(m); + dlib::ustring temp = name_; + // do this to get rid of any reference counting that may be present in + // the dlib::ustring implementation. + temp[0] = name_[0]; + return temp; + } + +// ---------------------------------------------------------------------------------------- + + void button:: + on_button_up ( + bool mouse_over + ) + { + if (mouse_over) + { + // this is a valid button click + if (event_handler.is_set()) + event_handler(); + if (event_handler_self.is_set()) + event_handler_self(*this); + } + if (button_up_handler.is_set()) + button_up_handler(mouse_over); + if (button_up_handler_self.is_set()) + button_up_handler_self(mouse_over,*this); + } + +// ---------------------------------------------------------------------------------------- + + void button:: + on_button_down ( + ) + { + if (button_down_handler.is_set()) + button_down_handler(); + if (button_down_handler_self.is_set()) + button_down_handler_self(*this); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // draggable object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + draggable::~draggable() {} + +// ---------------------------------------------------------------------------------------- + + void draggable:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + if (drag && (state & base_window::LEFT) && enabled && !hidden) + { + // the user is trying to drag this object. we should calculate the new + // x and y positions for the upper left corner of this object's rectangle + + long new_x = x - this->x; + long new_y = y - this->y; + + // make sure these points are inside the draggable area. + if (new_x < area.left()) + new_x = area.left(); + if (new_x + static_cast(rect.width()) - 1 > area.right()) + new_x = area.right() - rect.width() + 1; + + if (new_y + static_cast(rect.height()) - 1 > area.bottom()) + new_y = area.bottom() - rect.height() + 1; + if (new_y < area.top()) + new_y = area.top(); + + // now make the new rectangle for this object + rectangle new_rect( + new_x, + new_y, + new_x + rect.width() - 1, + new_y + rect.height() - 1 + ); + + // only do anything if this is a new rectangle and it is inside area + if (new_rect != rect && area.intersect(new_rect) == new_rect) + { + parent.invalidate_rectangle(new_rect + rect); + rect = new_rect; + + // call the on_drag() event handler + on_drag(); + } + } + else + { + drag = false; + on_drag_stop(); + } + } + +// ---------------------------------------------------------------------------------------- + + void draggable:: + on_mouse_up ( + unsigned long , + unsigned long state, + long , + long + ) + { + if (drag && (state & base_window::LEFT) == 0) + { + drag = false; + on_drag_stop(); + } + } + +// ---------------------------------------------------------------------------------------- + + void draggable:: + on_mouse_down ( + unsigned long btn, + unsigned long , + long x, + long y, + bool + ) + { + if (enabled && !hidden && rect.contains(x,y) && btn == base_window::LEFT) + { + drag = true; + this->x = x - rect.left(); + this->y = y - rect.top(); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // mouse_over_event object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + mouse_over_event::~mouse_over_event() {} + +// ---------------------------------------------------------------------------------------- + + void mouse_over_event:: + on_mouse_leave ( + ) + { + if (is_mouse_over_) + { + is_mouse_over_ = false; + on_mouse_not_over(); + } + } + +// ---------------------------------------------------------------------------------------- + + void mouse_over_event:: + on_mouse_move ( + unsigned long , + long x, + long y + ) + { + if (rect.contains(x,y) == false) + { + if (is_mouse_over_) + { + is_mouse_over_ = false; + on_mouse_not_over(); + } + } + else if (is_mouse_over_ == false) + { + is_mouse_over_ = true; + if (enabled && !hidden) + on_mouse_over(); + } + } + +// ---------------------------------------------------------------------------------------- + + bool mouse_over_event:: + is_mouse_over ( + ) const + { + // check if the mouse is still really over this button + if (is_mouse_over_ && rect.contains(lastx,lasty) == false) + { + // trigger a user event to call on_mouse_not_over() and repaint this object. + // we must do this in another event because someone might call is_mouse_over() + // from draw() and you don't want this function to end up calling + // parent.invalidate_rectangle(). It would lead to draw() being called over + // and over. + parent.trigger_user_event((void*)this,drawable::next_free_user_event_number()); + return false; + } + + return is_mouse_over_; + } + +// ---------------------------------------------------------------------------------------- + + void mouse_over_event:: + on_user_event ( + int num + ) + { + if (is_mouse_over_ && num == drawable::next_free_user_event_number()) + { + is_mouse_over_ = false; + on_mouse_not_over(); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // button_action object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + button_action::~button_action() {} + +// ---------------------------------------------------------------------------------------- + + void button_action:: + on_mouse_down ( + unsigned long btn, + unsigned long , + long x, + long y, + bool + ) + { + if (enabled && !hidden && btn == base_window::LEFT && rect.contains(x,y)) + { + is_depressed_ = true; + seen_click = true; + parent.invalidate_rectangle(rect); + on_button_down(); + } + } + +// ---------------------------------------------------------------------------------------- + + void button_action:: + on_mouse_not_over ( + ) + { + if (is_depressed_) + { + is_depressed_ = false; + parent.invalidate_rectangle(rect); + on_button_up(false); + } + } + +// ---------------------------------------------------------------------------------------- + + void button_action:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + // forward event to the parent class so it can do it's thing as well as us + mouse_over_event::on_mouse_move(state,x,y); + + if (enabled == false || hidden == true) + return; + + + if ((state & base_window::LEFT) == 0) + { + seen_click = false; + if (is_depressed_) + { + is_depressed_ = false; + parent.invalidate_rectangle(rect); + on_button_up(false); + } + + // the left button isn't down so we don't care about anything else + return; + } + + if (rect.contains(x,y) == false) + { + if (is_depressed_) + { + is_depressed_ = false; + parent.invalidate_rectangle(rect); + on_button_up(false); + } + } + else if (is_depressed_ == false && seen_click) + { + is_depressed_ = true; + parent.invalidate_rectangle(rect); + on_button_down(); + } + } + +// ---------------------------------------------------------------------------------------- + + void button_action:: + on_mouse_up ( + unsigned long btn, + unsigned long, + long x, + long y + ) + { + if (enabled && !hidden && btn == base_window::LEFT) + { + if (is_depressed_) + { + is_depressed_ = false; + parent.invalidate_rectangle(rect); + + if (rect.contains(x,y)) + { + on_button_up(true); + } + else + { + on_button_up(false); + } + } + else if (seen_click && rect.contains(x,y)) + { + // this case here covers the unlikly event that you click on a button, + // move the mouse off the button and then move it back very quickly and + // release the mouse button. It is possible that this mouse up event + // will occurr before any mouse move event so you might not have set + // that the button is depressed yet. + + // So we should say that this triggers an on_button_down() event and + // then an on_button_up(true) event. + + parent.invalidate_rectangle(rect); + + on_button_down(); + on_button_up(true); + } + + seen_click = false; + } + } + +// ---------------------------------------------------------------------------------------- + + bool button_action:: + is_depressed ( + ) const + { + // check if the mouse is still really over this button + if (enabled && !hidden && is_depressed_ && rect.contains(lastx,lasty) == false) + { + // trigger a user event to call on_button_up() and repaint this object. + // we must do this in another event because someone might call is_depressed() + // from draw() and you don't want this function to end up calling + // parent.invalidate_rectangle(). It would lead to draw() being called over + // and over. + parent.trigger_user_event((void*)this,mouse_over_event::next_free_user_event_number()); + return false; + } + + return is_depressed_; + } + +// ---------------------------------------------------------------------------------------- + + void button_action:: + on_user_event ( + int num + ) + { + // forward event to the parent class so it can do it's thing as well as us + mouse_over_event::on_user_event(num); + + if (is_depressed_ && num == mouse_over_event::next_free_user_event_number()) + { + is_depressed_ = false; + parent.invalidate_rectangle(rect); + on_button_up(false); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // scroll_bar object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + scroll_bar:: + scroll_bar( + drawable_window& w, + bar_orientation orientation + ) : + drawable(w), + b1(w), + b2(w), + slider(w,*this,&scroll_bar::on_slider_drag), + ori(orientation), + top_filler(w,*this,&scroll_bar::top_filler_down,&scroll_bar::top_filler_up), + bottom_filler(w,*this,&scroll_bar::bottom_filler_down,&scroll_bar::bottom_filler_up), + pos(0), + max_pos(0), + js(10), + b1_timer(*this,&scroll_bar::b1_down_t), + b2_timer(*this,&scroll_bar::b2_down_t), + top_filler_timer(*this,&scroll_bar::top_filler_down_t), + bottom_filler_timer(*this,&scroll_bar::bottom_filler_down_t) + { + set_style(scroll_bar_style_default()); + + // don't show the slider when there is no place it can move. + slider.hide(); + + set_length(100); + + b1.set_button_down_handler(*this,&scroll_bar::b1_down); + b2.set_button_down_handler(*this,&scroll_bar::b2_down); + + b1.set_button_up_handler(*this,&scroll_bar::b1_up); + b2.set_button_up_handler(*this,&scroll_bar::b2_up); + b1.disable(); + b2.disable(); + enable_events(); + } + +// ---------------------------------------------------------------------------------------- + + scroll_bar:: + ~scroll_bar( + ) + { + disable_events(); + parent.invalidate_rectangle(rect); + // wait for all the timers to be stopped + b1_timer.stop_and_wait(); + b2_timer.stop_and_wait(); + top_filler_timer.stop_and_wait(); + bottom_filler_timer.stop_and_wait(); + } + +// ---------------------------------------------------------------------------------------- + + scroll_bar::bar_orientation scroll_bar:: + orientation ( + ) const + { + auto_mutex M(m); + return ori; + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + set_length ( + unsigned long length + ) + { + auto_mutex M(m); + // make the min length be at least 1 + if (length == 0) + { + length = 1; + } + + + parent.invalidate_rectangle(rect); + + if (ori == HORIZONTAL) + { + rect.set_right(rect.left() + length - 1); + rect.set_bottom(rect.top() + style->get_width() - 1); + + const long btn_size = style->get_button_length(rect.width(), max_pos); + + b1.set_size(btn_size,style->get_width()); + b2.set_size(btn_size,style->get_width()); + + slider.set_size(get_slider_size(),style->get_width()); + } + else + { + rect.set_right(rect.left() + style->get_width() - 1); + rect.set_bottom(rect.top() + length - 1); + + const long btn_size = style->get_button_length(rect.height(), max_pos); + + b1.set_size(style->get_width(),btn_size); + b2.set_size(style->get_width(),btn_size); + + slider.set_size(style->get_width(),get_slider_size()); + } + + // call this to put everything is in the right spot. + set_pos (rect.left(),rect.top()); + + if ((b2.get_rect().top() - b1.get_rect().bottom() - 1 <= 8 && ori == VERTICAL) || + (b2.get_rect().left() - b1.get_rect().right() - 1 <= 8 && ori == HORIZONTAL) || + max_pos == 0) + { + hide_slider(); + } + else if (enabled && !hidden) + { + show_slider(); + } + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + set_pos ( + long x, + long y + ) + { + auto_mutex M(m); + drawable::set_pos(x,y); + + b1.set_pos(rect.left(),rect.top()); + if (ori == HORIZONTAL) + { + // make the b2 button appear at the end of the scroll_bar + b2.set_pos(rect.right()-b2.get_rect().width() + 1,rect.top()); + + if (max_pos != 0) + { + double range = b2.get_rect().left() - b1.get_rect().right() - slider.get_rect().width() - 1; + double slider_pos = pos; + slider_pos /= max_pos; + slider_pos *= range; + slider.set_pos( + static_cast(slider_pos)+rect.left() + b1.get_rect().width(), + rect.top() + ); + + // move the draggable area for the slider to the new location + rectangle area = rect; + area.set_left(area.left() + style->get_width()); + area.set_right(area.right() - style->get_width()); + slider.set_draggable_area(area); + + } + + + } + else + { + // make the b2 button appear at the end of the scroll_bar + b2.set_pos(rect.left(), rect.bottom() - b2.get_rect().height() + 1); + + if (max_pos != 0) + { + double range = b2.get_rect().top() - b1.get_rect().bottom() - slider.get_rect().height() - 1; + double slider_pos = pos; + slider_pos /= max_pos; + slider_pos *= range; + slider.set_pos( + rect.left(), + static_cast(slider_pos) + rect.top() + b1.get_rect().height() + ); + + // move the draggable area for the slider to the new location + rectangle area = rect; + area.set_top(area.top() + style->get_width()); + area.set_bottom(area.bottom() - style->get_width()); + slider.set_draggable_area(area); + } + } + + adjust_fillers(); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long scroll_bar:: + get_slider_size ( + ) const + { + if (ori == HORIZONTAL) + return style->get_slider_length(rect.width(),max_pos); + else + return style->get_slider_length(rect.height(),max_pos); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + adjust_fillers ( + ) + { + rectangle top(rect), bottom(rect); + + if (ori == HORIZONTAL) + { + if (slider.is_hidden()) + { + top.set_left(b1.get_rect().right()+1); + top.set_right(b2.get_rect().left()-1); + bottom.set_left(1); + bottom.set_right(-1); + } + else + { + top.set_left(b1.get_rect().right()+1); + top.set_right(slider.get_rect().left()-1); + bottom.set_left(slider.get_rect().right()+1); + bottom.set_right(b2.get_rect().left()-1); + } + } + else + { + if (slider.is_hidden()) + { + top.set_top(b1.get_rect().bottom()+1); + top.set_bottom(b2.get_rect().top()-1); + bottom.set_top(1); + bottom.set_bottom(-1); + } + else + { + top.set_top(b1.get_rect().bottom()+1); + top.set_bottom(slider.get_rect().top()-1); + bottom.set_top(slider.get_rect().bottom()+1); + bottom.set_bottom(b2.get_rect().top()-1); + } + } + + top_filler.rect = top; + bottom_filler.rect = bottom; + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + hide_slider ( + ) + { + rectangle top(rect), bottom(rect); + slider.hide(); + top_filler.disable(); + bottom_filler.disable(); + bottom_filler.hide(); + if (ori == HORIZONTAL) + { + top.set_left(b1.get_rect().right()+1); + top.set_right(b2.get_rect().left()-1); + } + else + { + top.set_top(b1.get_rect().bottom()+1); + top.set_bottom(b2.get_rect().top()-1); + } + top_filler.rect = top; + bottom_filler.rect = bottom; + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + show_slider ( + ) + { + if ((b2.get_rect().top() - b1.get_rect().bottom() - 1 <= 8 && ori == VERTICAL) || + (b2.get_rect().left() - b1.get_rect().right() - 1 <= 8 && ori == HORIZONTAL) || + max_pos == 0) + return; + + rectangle top(rect), bottom(rect); + slider.show(); + top_filler.enable(); + bottom_filler.enable(); + bottom_filler.show(); + if (ori == HORIZONTAL) + { + top.set_left(b1.get_rect().right()+1); + top.set_right(slider.get_rect().left()-1); + bottom.set_left(slider.get_rect().right()+1); + bottom.set_right(b2.get_rect().left()-1); + } + else + { + top.set_top(b1.get_rect().bottom()+1); + top.set_bottom(slider.get_rect().top()-1); + bottom.set_top(slider.get_rect().bottom()+1); + bottom.set_bottom(b2.get_rect().top()-1); + } + top_filler.rect = top; + bottom_filler.rect = bottom; + } + +// ---------------------------------------------------------------------------------------- + + long scroll_bar:: + max_slider_pos ( + ) const + { + auto_mutex M(m); + return max_pos; + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + set_max_slider_pos ( + long mpos + ) + { + auto_mutex M(m); + max_pos = mpos; + if (pos > mpos) + pos = mpos; + + if (ori == HORIZONTAL) + set_length(rect.width()); + else + set_length(rect.height()); + + if (mpos != 0 && enabled) + { + b1.enable(); + b2.enable(); + } + else + { + b1.disable(); + b2.disable(); + } + + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + set_slider_pos ( + long pos + ) + { + auto_mutex M(m); + if (pos < 0) + pos = 0; + if (pos > max_pos) + pos = max_pos; + + this->pos = pos; + + // move the slider object to its new position + set_pos(rect.left(),rect.top()); + } + +// ---------------------------------------------------------------------------------------- + + long scroll_bar:: + slider_pos ( + ) const + { + auto_mutex M(m); + return pos; + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + on_slider_drag ( + ) + { + if (ori == HORIZONTAL) + { + double slider_pos = slider.get_rect().left() - b1.get_rect().right() - 1; + double range = b2.get_rect().left() - b1.get_rect().right() - slider.get_rect().width() - 1; + slider_pos /= range; + slider_pos *= max_pos; + pos = static_cast(slider_pos); + } + else + { + double slider_pos = slider.get_rect().top() - b1.get_rect().bottom() - 1; + double range = b2.get_rect().top() - b1.get_rect().bottom() - slider.get_rect().height() - 1; + slider_pos /= range; + slider_pos *= max_pos; + pos = static_cast(slider_pos); + } + + adjust_fillers(); + + if (scroll_handler.is_set()) + scroll_handler(); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + draw ( + const canvas& + ) const + { + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + b1_down ( + ) + { + if (pos != 0) + { + set_slider_pos(pos-1); + if (scroll_handler.is_set()) + scroll_handler(); + + if (b1_timer.delay_time() == 1000) + b1_timer.set_delay_time(500); + else + b1_timer.set_delay_time(50); + b1_timer.start(); + } + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + b1_up ( + bool + ) + { + b1_timer.stop(); + b1_timer.set_delay_time(1000); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + b2_down ( + ) + { + if (pos != max_pos) + { + set_slider_pos(pos+1); + if (scroll_handler.is_set()) + scroll_handler(); + + if (b2_timer.delay_time() == 1000) + b2_timer.set_delay_time(500); + else + b2_timer.set_delay_time(50); + b2_timer.start(); + } + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + b2_up ( + bool + ) + { + b2_timer.stop(); + b2_timer.set_delay_time(1000); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + top_filler_down ( + ) + { + // ignore this if the mouse is now outside this object. This could happen + // since the timers are also calling this function. + if (top_filler.rect.contains(lastx,lasty) == false) + { + top_filler_up(false); + return; + } + + if (pos != 0) + { + if (pos < js) + { + // if there is less than jump_size() space left then jump the remaining + // amount. + delayed_set_slider_pos(0); + } + else + { + delayed_set_slider_pos(pos-js); + } + + if (top_filler_timer.delay_time() == 1000) + top_filler_timer.set_delay_time(500); + else + top_filler_timer.set_delay_time(50); + top_filler_timer.start(); + } + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + top_filler_up ( + bool + ) + { + top_filler_timer.stop(); + top_filler_timer.set_delay_time(1000); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + bottom_filler_down ( + ) + { + // ignore this if the mouse is now outside this object. This could happen + // since the timers are also calling this function. + if (bottom_filler.rect.contains(lastx,lasty) == false) + { + bottom_filler_up(false); + return; + } + + if (pos != max_pos) + { + if (max_pos - pos < js) + { + // if there is less than jump_size() space left then jump the remaining + // amount. + delayed_set_slider_pos(max_pos); + } + else + { + delayed_set_slider_pos(pos+js); + } + + if (bottom_filler_timer.delay_time() == 1000) + bottom_filler_timer.set_delay_time(500); + else + bottom_filler_timer.set_delay_time(50); + bottom_filler_timer.start(); + } + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + bottom_filler_up ( + bool + ) + { + bottom_filler_timer.stop(); + bottom_filler_timer.set_delay_time(1000); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + set_jump_size ( + long js_ + ) + { + auto_mutex M(m); + if (js_ < 1) + js = 1; + else + js = js_; + } + +// ---------------------------------------------------------------------------------------- + + long scroll_bar:: + jump_size ( + ) const + { + auto_mutex M(m); + return js; + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + on_user_event ( + int i + ) + { + switch (i) + { + case 0: + b1_down(); + break; + case 1: + b2_down(); + break; + case 2: + top_filler_down(); + break; + case 3: + bottom_filler_down(); + break; + case 4: + // if the position we are supposed to switch the slider too isn't + // already set + if (delayed_pos != pos) + { + set_slider_pos(delayed_pos); + if (scroll_handler.is_set()) + scroll_handler(); + } + break; + default: + break; + } + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + delayed_set_slider_pos ( + unsigned long dpos + ) + { + delayed_pos = dpos; + parent.trigger_user_event(this,4); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + b1_down_t ( + ) + { + parent.trigger_user_event(this,0); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + b2_down_t ( + ) + { + parent.trigger_user_event(this,1); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + top_filler_down_t ( + ) + { + parent.trigger_user_event(this,2); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar:: + bottom_filler_down_t ( + ) + { + parent.trigger_user_event(this,3); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// widget_group object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void widget_group:: + empty ( + ) + { + auto_mutex M(m); + widgets.clear(); + wg_widgets.clear(); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + add ( + drawable& widget, + unsigned long x, + unsigned long y + ) + { + auto_mutex M(m); + drawable* w = &widget; + relpos rp; + rp.x = x; + rp.y = y; + if (widgets.is_in_domain(w)) + { + widgets[w].x = x; + widgets[w].y = y; + } + else + { + widgets.add(w,rp); + } + if (is_hidden()) + widget.hide(); + else + widget.show(); + + if (is_enabled()) + widget.enable(); + else + widget.disable(); + + widget.set_z_order(z_order()); + widget.set_pos(x+rect.left(),y+rect.top()); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + add ( + widget_group& widget, + unsigned long x, + unsigned long y + ) + { + auto_mutex M(m); + drawable& w = widget; + add(w, x, y); + + widget_group* wg = &widget; + wg_widgets.add(wg); + } + +// ---------------------------------------------------------------------------------------- + + bool widget_group:: + is_member ( + const drawable& widget + ) const + { + auto_mutex M(m); + drawable* w = const_cast(&widget); + return widgets.is_in_domain(w); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + remove ( + const drawable& widget + ) + { + auto_mutex M(m); + drawable* w = const_cast(&widget); + if (widgets.is_in_domain(w)) + { + widgets.destroy(w); + + // check if we also have an entry in the wg_widgets set and if + // so then remove that too + widget_group* wg = reinterpret_cast(w); + if (wg_widgets.is_member(wg)) + { + wg_widgets.destroy(wg); + } + } + } + +// ---------------------------------------------------------------------------------------- + + size_t widget_group:: + size ( + ) const + { + auto_mutex M(m); + return widgets.size(); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + set_pos ( + long x, + long y + ) + { + auto_mutex M(m); + widgets.reset(); + while (widgets.move_next()) + { + const unsigned long rx = widgets.element().value().x; + const unsigned long ry = widgets.element().value().y; + widgets.element().key()->set_pos(x+rx,y+ry); + } + drawable::set_pos(x,y); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + set_z_order ( + long order + ) + { + auto_mutex M(m); + widgets.reset(); + while (widgets.move_next()) + widgets.element().key()->set_z_order(order); + drawable::set_z_order(order); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + show ( + ) + { + auto_mutex M(m); + widgets.reset(); + while (widgets.move_next()) + widgets.element().key()->show(); + drawable::show(); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + hide ( + ) + { + auto_mutex M(m); + widgets.reset(); + while (widgets.move_next()) + widgets.element().key()->hide(); + drawable::hide(); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + enable ( + ) + { + auto_mutex M(m); + widgets.reset(); + while (widgets.move_next()) + widgets.element().key()->enable(); + drawable::enable(); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + disable () + { + auto_mutex M(m); + widgets.reset(); + while (widgets.move_next()) + widgets.element().key()->disable(); + drawable::disable(); + } + +// ---------------------------------------------------------------------------------------- + + void widget_group:: + fit_to_contents ( + ) + { + auto_mutex M(m); + + // call fit_to_contents on all the widget_groups we contain + wg_widgets.reset(); + while (wg_widgets.move_next()) + wg_widgets.element()->fit_to_contents(); + + // now accumulate a rectangle that contains everything in this widget_group + rectangle r; + widgets.reset(); + while (widgets.move_next()) + r = r + widgets.element().key()->get_rect(); + + if (r.is_empty()) + { + // make sure it is still empty after we set it at the correct position + r.set_right(rect.left()-1); + r.set_bottom(rect.top()-1); + } + + r.set_left(rect.left()); + r.set_top(rect.top()); + rect = r; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// class popup_menu +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + popup_menu:: + popup_menu ( + ) : + base_window(false,true), + pad(2), + item_pad(3), + cur_rect(pad,pad,pad-1,pad-1), + left_width(0), + middle_width(0), + selected_item(0), + submenu_open(false) + { + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + enable_menu_item ( + unsigned long idx + ) + { + DLIB_ASSERT ( idx < size() , + "\tvoid popup_menu::enable_menu_item()" + << "\n\tidx: " << idx + << "\n\tsize(): " << size() + ); + auto_mutex M(wm); + item_enabled[idx] = true; + invalidate_rectangle(cur_rect); + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + disable_menu_item ( + unsigned long idx + ) + { + DLIB_ASSERT ( idx < size() , + "\tvoid popup_menu::enable_menu_item()" + << "\n\tidx: " << idx + << "\n\tsize(): " << size() + ); + auto_mutex M(wm); + item_enabled[idx] = false; + invalidate_rectangle(cur_rect); + } + +// ---------------------------------------------------------------------------------------- + + size_t popup_menu:: + size ( + ) const + { + auto_mutex M(wm); + return items.size(); + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + clear ( + ) + { + auto_mutex M(wm); + hide(); + cur_rect = rectangle(pad,pad,pad-1,pad-1); + win_rect = rectangle(); + left_width = 0; + middle_width = 0; + items.clear(); + item_enabled.clear(); + left_rects.clear(); + middle_rects.clear(); + right_rects.clear(); + line_rects.clear(); + submenus.clear(); + selected_item = 0; + submenu_open = false; + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + show ( + ) + { + auto_mutex M(wm); + selected_item = submenus.size(); + base_window::show(); + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + hide ( + ) + { + auto_mutex M(wm); + // hide ourselves + close_submenu(); + selected_item = submenus.size(); + base_window::hide(); + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + select_first_item ( + ) + { + auto_mutex M(wm); + close_submenu(); + selected_item = items.size(); + for (unsigned long i = 0; i < items.size(); ++i) + { + if ((items[i]->has_click_event() || submenus[i]) && item_enabled[i]) + { + selected_item = i; + break; + } + } + invalidate_rectangle(cur_rect); + } + +// ---------------------------------------------------------------------------------------- + + bool popup_menu:: + forwarded_on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + auto_mutex M(wm); + // do nothing if this popup menu is empty + if (items.size() == 0) + return false; + + + // check if the selected item is a submenu + if (selected_item != submenus.size() && submenus[selected_item] != 0 && submenu_open) + { + // send the key to the submenu and return if that menu used the key + if (submenus[selected_item]->forwarded_on_keydown(key,is_printable,state) == true) + return true; + } + + if (key == KEY_UP) + { + for (unsigned long i = 0; i < items.size(); ++i) + { + selected_item = (selected_item + items.size() - 1)%items.size(); + // only stop looking if this one is enabled and has a click event or is a submenu + if (item_enabled[selected_item] && (items[selected_item]->has_click_event() || submenus[selected_item]) ) + break; + } + invalidate_rectangle(cur_rect); + return true; + } + else if (key == KEY_DOWN) + { + for (unsigned long i = 0; i < items.size(); ++i) + { + selected_item = (selected_item + 1)%items.size(); + // only stop looking if this one is enabled and has a click event or is a submenu + if (item_enabled[selected_item] && (items[selected_item]->has_click_event() || submenus[selected_item])) + break; + } + invalidate_rectangle(cur_rect); + return true; + } + else if (key == KEY_RIGHT && submenu_open == false && display_selected_submenu()) + { + submenus[selected_item]->select_first_item(); + return true; + } + else if (key == KEY_LEFT && selected_item != submenus.size() && + submenus[selected_item] != 0 && submenu_open) + { + close_submenu(); + return true; + } + else if (key == '\n') + { + if (selected_item != submenus.size() && (items[selected_item]->has_click_event() || submenus[selected_item])) + { + const long idx = selected_item; + // only hide this popup window if this isn't a submenu + if (submenus[idx] == 0) + { + hide(); + hide_handlers.reset(); + while (hide_handlers.move_next()) + hide_handlers.element()(); + } + else + { + display_selected_submenu(); + submenus[idx]->select_first_item(); + } + items[idx]->on_click(); + return true; + } + } + else if (is_printable) + { + // check if there is a hotkey for this key + for (unsigned long i = 0; i < items.size(); ++i) + { + if (std::tolower(key) == std::tolower(items[i]->get_hot_key()) && + (items[i]->has_click_event() || submenus[i]) && item_enabled[i] ) + { + // only hide this popup window if this isn't a submenu + if (submenus[i] == 0) + { + hide(); + hide_handlers.reset(); + while (hide_handlers.move_next()) + hide_handlers.element()(); + } + else + { + if (selected_item != items.size()) + invalidate_rectangle(line_rects[selected_item]); + + selected_item = i; + display_selected_submenu(); + invalidate_rectangle(line_rects[i]); + submenus[i]->select_first_item(); + } + items[i]->on_click(); + } + } + + // always say we use a printable key for hotkeys + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + on_submenu_hide ( + ) + { + hide(); + hide_handlers.reset(); + while (hide_handlers.move_next()) + hide_handlers.element()(); + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + on_window_resized( + ) + { + invalidate_rectangle(win_rect); + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + on_mouse_up ( + unsigned long btn, + unsigned long, + long x, + long y + ) + { + if (cur_rect.contains(x,y) && btn == LEFT) + { + // figure out which item this was on + for (unsigned long i = 0; i < items.size(); ++i) + { + if (line_rects[i].contains(x,y) && item_enabled[i] && items[i]->has_click_event()) + { + // only hide this popup window if this isn't a submenu + if (submenus[i] == 0) + { + hide(); + hide_handlers.reset(); + while (hide_handlers.move_next()) + hide_handlers.element()(); + } + items[i]->on_click(); + break; + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + on_mouse_move ( + unsigned long , + long x, + long y + ) + { + if (cur_rect.contains(x,y)) + { + // check if the mouse is still in the same rect it was in last time + rectangle last_rect; + if (selected_item != submenus.size()) + { + last_rect = line_rects[selected_item]; + } + + // if the mouse isn't in the same rectangle any more + if (last_rect.contains(x,y) == false) + { + if (selected_item != submenus.size()) + { + invalidate_rectangle(last_rect); + close_submenu(); + selected_item = submenus.size(); + } + + + // figure out if we should redraw any menu items + for (unsigned long i = 0; i < items.size(); ++i) + { + if (items[i]->has_click_event() || submenus[i]) + { + if (line_rects[i].contains(x,y)) + { + selected_item = i; + break; + } + } + } + + // if we found a rectangle that contains the mouse then + // tell it to redraw itself + if (selected_item != submenus.size()) + { + display_selected_submenu(); + invalidate_rectangle(line_rects[selected_item]); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + close_submenu ( + ) + { + if (selected_item != submenus.size() && submenus[selected_item] && submenu_open) + { + submenus[selected_item]->hide(); + submenu_open = false; + } + } + +// ---------------------------------------------------------------------------------------- + + bool popup_menu:: + display_selected_submenu ( + ) + { + // show the submenu if one exists + if (selected_item != submenus.size() && submenus[selected_item]) + { + long wx, wy; + get_pos(wx,wy); + wx += line_rects[selected_item].right(); + wy += line_rects[selected_item].top(); + submenus[selected_item]->set_pos(wx+1,wy-2); + submenus[selected_item]->show(); + submenu_open = true; + return true; + } + return false; + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + on_mouse_leave ( + ) + { + if (selected_item != submenus.size()) + { + // only unhighlight a menu item if it isn't a submenu item + if (submenus[selected_item] == 0) + { + invalidate_rectangle(line_rects[selected_item]); + selected_item = submenus.size(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu:: + paint ( + const canvas& c + ) + { + c.fill(200,200,200); + draw_rectangle(c, win_rect); + for (unsigned long i = 0; i < items.size(); ++i) + { + bool is_selected = false; + if (selected_item != submenus.size() && i == selected_item && + item_enabled[i]) + is_selected = true; + + items[i]->draw_background(c,line_rects[i], item_enabled[i], is_selected); + items[i]->draw_left(c,left_rects[i], item_enabled[i], is_selected); + items[i]->draw_middle(c,middle_rects[i], item_enabled[i], is_selected); + items[i]->draw_right(c,right_rects[i], item_enabled[i], is_selected); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// class zoomable_region +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + zoomable_region:: + zoomable_region ( + drawable_window& w, + unsigned long events + ) : + drawable(w,MOUSE_CLICK | MOUSE_WHEEL | MOUSE_MOVE | events), + min_scale(0.15), + max_scale(1.0), + zoom_increment_(0.90), + vsb(w, scroll_bar::VERTICAL), + hsb(w, scroll_bar::HORIZONTAL) + { + scale = 1; + mouse_drag_screen = false; + style.reset(new scrollable_region_style_default()); + + hsb.set_scroll_handler(*this,&zoomable_region::on_h_scroll); + vsb.set_scroll_handler(*this,&zoomable_region::on_v_scroll); + } + +// ---------------------------------------------------------------------------------------- + + zoomable_region:: + ~zoomable_region() + { + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + set_pos ( + long x, + long y + ) + { + auto_mutex M(m); + drawable::set_pos(x,y); + const long border_size = style->get_border_size(); + vsb.set_pos(rect.right()-border_size+1-vsb.width(),rect.top()+border_size); + hsb.set_pos(rect.left()+border_size,rect.bottom()-border_size+1-hsb.height()); + + display_rect_ = rectangle(rect.left()+border_size, + rect.top()+border_size, + rect.right()-border_size-vsb.width(), + rect.bottom()-border_size-hsb.height()); + + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + set_zoom_increment ( + double zi + ) + { + DLIB_ASSERT(0.0 < zi && zi < 1.0, + "\tvoid zoomable_region::set_zoom_increment(zi)" + << "\n\t the zoom increment must be between 0 and 1" + << "\n\t zi: " << zi + << "\n\t this: " << this + ); + + auto_mutex M(m); + zoom_increment_ = zi; + } + +// ---------------------------------------------------------------------------------------- + + double zoomable_region:: + zoom_increment ( + ) const + { + auto_mutex M(m); + return zoom_increment_; + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + set_max_zoom_scale ( + double ms + ) + { + DLIB_ASSERT(ms > 0, + "\tvoid zoomable_region::set_max_zoom_scale(ms)" + << "\n\t the max zoom scale must be greater than 0" + << "\n\t ms: " << ms + << "\n\t this: " << this + ); + + auto_mutex M(m); + max_scale = ms; + if (scale > ms) + { + scale = max_scale; + lr_point = gui_to_graph_space(point(display_rect_.right(),display_rect_.bottom())); + redraw_graph(); + } + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + set_min_zoom_scale ( + double ms + ) + { + DLIB_ASSERT(ms > 0, + "\tvoid zoomable_region::set_min_zoom_scale(ms)" + << "\n\t the min zoom scale must be greater than 0" + << "\n\t ms: " << ms + << "\n\t this: " << this + ); + + auto_mutex M(m); + min_scale = ms; + + if (scale < ms) + { + scale = min_scale; + } + + // just call set_size so that everything gets redrawn right + set_size(rect.width(), rect.height()); + } + +// ---------------------------------------------------------------------------------------- + + double zoomable_region:: + min_zoom_scale ( + ) const + { + auto_mutex M(m); + return min_scale; + } + +// ---------------------------------------------------------------------------------------- + + double zoomable_region:: + max_zoom_scale ( + ) const + { + auto_mutex M(m); + return max_scale; + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex M(m); + rectangle old(rect); + const long border_size = style->get_border_size(); + rect = resize_rect(rect,width,height); + vsb.set_pos(rect.right()-border_size+1-vsb.width(), rect.top()+border_size); + hsb.set_pos(rect.left()+border_size, rect.bottom()-border_size+1-hsb.height()); + + display_rect_ = rectangle(rect.left()+border_size, + rect.top()+border_size, + rect.right()-border_size-vsb.width(), + rect.bottom()-border_size-hsb.height()); + vsb.set_length(display_rect_.height()); + hsb.set_length(display_rect_.width()); + parent.invalidate_rectangle(rect+old); + + const double old_scale = scale; + const vector old_gr_orig(gr_orig); + scale = min_scale; + gr_orig = vector(0,0); + lr_point = gui_to_graph_space(point(display_rect_.right(),display_rect_.bottom())); + scale = old_scale; + + // call adjust_origin() so that the scroll bars get their max slider positions + // setup right + const point rect_corner(display_rect_.left(), display_rect_.top()); + adjust_origin(rect_corner, old_gr_orig); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + show ( + ) + { + auto_mutex M(m); + drawable::show(); + hsb.show(); + vsb.show(); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + hide ( + ) + { + auto_mutex M(m); + drawable::hide(); + hsb.hide(); + vsb.hide(); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + enable ( + ) + { + auto_mutex M(m); + drawable::enable(); + hsb.enable(); + vsb.enable(); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + disable ( + ) + { + auto_mutex M(m); + drawable::disable(); + hsb.disable(); + vsb.disable(); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + set_z_order ( + long order + ) + { + auto_mutex M(m); + drawable::set_z_order(order); + hsb.set_z_order(order); + vsb.set_z_order(order); + } + +// ---------------------------------------------------------------------------------------- + + point zoomable_region:: + graph_to_gui_space ( + const vector& p + ) const + { + const point rect_corner(display_rect_.left(), display_rect_.top()); + return (p - gr_orig)*scale + rect_corner; + } + +// ---------------------------------------------------------------------------------------- + + vector zoomable_region:: + gui_to_graph_space ( + const point& p + ) const + { + const point rect_corner(display_rect_.left(), display_rect_.top()); + return (p - rect_corner)/scale + gr_orig; + } + +// ---------------------------------------------------------------------------------------- + + point zoomable_region:: + max_graph_point ( + ) const + { + return lr_point; + } + +// ---------------------------------------------------------------------------------------- + + rectangle zoomable_region:: + display_rect ( + ) const + { + return display_rect_; + } + +// ---------------------------------------------------------------------------------------- + + double zoomable_region:: + zoom_scale ( + ) const + { + return scale; + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + set_zoom_scale ( + double new_scale + ) + { + // if new_scale isn't in the right range then put it back in range before we do the + // rest of this function + if (!(min_scale <= new_scale && new_scale <= max_scale)) + { + if (new_scale > max_scale) + new_scale = max_scale; + else + new_scale = min_scale; + } + + // find the point in the center of the graph area + point center((display_rect_.left()+display_rect_.right())/2, (display_rect_.top()+display_rect_.bottom())/2); + point graph_p(gui_to_graph_space(center)); + scale = new_scale; + adjust_origin(center, graph_p); + redraw_graph(); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + center_display_at_graph_point ( + const vector& p + ) + { + // find the point in the center of the graph area + point center((display_rect_.left()+display_rect_.right())/2, (display_rect_.top()+display_rect_.bottom())/2); + adjust_origin(center, p); + redraw_graph(); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + on_wheel_down ( + unsigned long + ) + { + // zoom out + if (enabled && !hidden && scale > min_scale && display_rect_.contains(lastx,lasty)) + { + point gui_p(lastx,lasty); + point graph_p(gui_to_graph_space(gui_p)); + const double old_scale = scale; + scale *= zoom_increment_; + if (scale < min_scale) + scale = min_scale; + redraw_graph(); + adjust_origin(gui_p, graph_p); + + if (scale != old_scale) + on_view_changed(); + } + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + on_wheel_up ( + unsigned long + ) + { + // zoom in + if (enabled && !hidden && scale < max_scale && display_rect_.contains(lastx,lasty)) + { + point gui_p(lastx,lasty); + point graph_p(gui_to_graph_space(gui_p)); + const double old_scale = scale; + scale /= zoom_increment_; + if (scale > max_scale) + scale = max_scale; + redraw_graph(); + adjust_origin(gui_p, graph_p); + + if (scale != old_scale) + on_view_changed(); + } + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + if (enabled && !hidden && mouse_drag_screen) + { + adjust_origin(point(x,y), drag_screen_point); + redraw_graph(); + on_view_changed(); + } + + // check if the mouse isn't being dragged anymore + if ((state & base_window::LEFT) == 0) + { + mouse_drag_screen = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + on_mouse_up ( + unsigned long , + unsigned long , + long , + long + ) + { + mouse_drag_screen = false; + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + on_mouse_down ( + unsigned long btn, + unsigned long , + long x, + long y, + bool + ) + { + if (enabled && !hidden && display_rect_.contains(x,y) && btn == base_window::LEFT) + { + mouse_drag_screen = true; + drag_screen_point = gui_to_graph_space(point(x,y)); + } + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + draw ( + const canvas& c + ) const + { + style->draw_scrollable_region_border(c, rect, enabled); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + on_h_scroll ( + ) + { + gr_orig.x() = hsb.slider_pos(); + redraw_graph(); + + on_view_changed(); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + on_v_scroll ( + ) + { + gr_orig.y() = vsb.slider_pos(); + redraw_graph(); + + on_view_changed(); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + redraw_graph ( + ) + { + parent.invalidate_rectangle(display_rect_); + } + +// ---------------------------------------------------------------------------------------- + + void zoomable_region:: + adjust_origin ( + const point& gui_p, + const vector& graph_p + ) + { + const point rect_corner(display_rect_.left(), display_rect_.top()); + const dlib::vector v(gui_p - rect_corner); + gr_orig = graph_p - v/scale; + + + // make sure the origin isn't outside the point (0,0) + if (gr_orig.x() < 0) + gr_orig.x() = 0; + if (gr_orig.y() < 0) + gr_orig.y() = 0; + + // make sure the lower right corner of the display_rect_ doesn't map to a point beyond lr_point + point lr_rect_corner(display_rect_.right(), display_rect_.bottom()); + point p = graph_to_gui_space(lr_point); + vector lr_rect_corner_graph_space(gui_to_graph_space(lr_rect_corner)); + vector delta(lr_point - lr_rect_corner_graph_space); + if (lr_rect_corner.x() > p.x()) + { + gr_orig.x() += delta.x(); + } + + if (lr_rect_corner.y() > p.y()) + { + gr_orig.y() += delta.y(); + } + + + const vector ul_rect_corner_graph_space(gui_to_graph_space(rect_corner)); + lr_rect_corner_graph_space = gui_to_graph_space(lr_rect_corner); + // now adjust the scroll bars + + hsb.set_max_slider_pos((unsigned long)std::max(lr_point.x()-(lr_rect_corner_graph_space.x()-ul_rect_corner_graph_space.x()),0.0)); + vsb.set_max_slider_pos((unsigned long)std::max(lr_point.y()-(lr_rect_corner_graph_space.y()-ul_rect_corner_graph_space.y()),0.0)); + // adjust slider position now. + hsb.set_slider_pos(static_cast(ul_rect_corner_graph_space.x())); + vsb.set_slider_pos(static_cast(ul_rect_corner_graph_space.y())); + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// class scrollable_region +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + scrollable_region:: + scrollable_region ( + drawable_window& w, + unsigned long events + ) : + drawable(w, MOUSE_WHEEL|events|MOUSE_CLICK|MOUSE_MOVE), + hsb(w,scroll_bar::HORIZONTAL), + vsb(w,scroll_bar::VERTICAL), + hscroll_bar_inc(1), + vscroll_bar_inc(1), + h_wheel_scroll_bar_inc(1), + v_wheel_scroll_bar_inc(1), + mouse_drag_enabled_(false), + user_is_dragging_mouse(false) + { + style.reset(new scrollable_region_style_default()); + + hsb.set_scroll_handler(*this,&scrollable_region::on_h_scroll); + vsb.set_scroll_handler(*this,&scrollable_region::on_v_scroll); + } + +// ---------------------------------------------------------------------------------------- + + scrollable_region:: + ~scrollable_region ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + show ( + ) + { + auto_mutex M(m); + drawable::show(); + if (need_h_scroll()) + hsb.show(); + if (need_v_scroll()) + vsb.show(); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + hide ( + ) + { + auto_mutex M(m); + drawable::hide(); + hsb.hide(); + vsb.hide(); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + enable ( + ) + { + auto_mutex M(m); + drawable::enable(); + hsb.enable(); + vsb.enable(); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + disable ( + ) + { + auto_mutex M(m); + drawable::disable(); + hsb.disable(); + vsb.disable(); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_z_order ( + long order + ) + { + auto_mutex M(m); + drawable::set_z_order(order); + hsb.set_z_order(order); + vsb.set_z_order(order); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex M(m); + rectangle old(rect); + rect = resize_rect(rect,width,height); + vsb.set_pos(rect.right()-style->get_border_size()-vsb.width()+1, rect.top()+style->get_border_size()); + hsb.set_pos(rect.left()+style->get_border_size(), rect.bottom()-style->get_border_size()-hsb.height()+1); + + // adjust the display_rect_ + if (need_h_scroll() && need_v_scroll()) + { + // both scroll bars aren't hidden + if (!hidden) + { + vsb.show(); + hsb.show(); + } + display_rect_ = rectangle( rect.left()+style->get_border_size(), + rect.top()+style->get_border_size(), + rect.right()-style->get_border_size()-vsb.width(), + rect.bottom()-style->get_border_size()-hsb.height()); + + // figure out how many scroll bar positions there should be + unsigned long hdelta = total_rect_.width()-display_rect_.width(); + unsigned long vdelta = total_rect_.height()-display_rect_.height(); + hdelta = (hdelta+hscroll_bar_inc-1)/hscroll_bar_inc; + vdelta = (vdelta+vscroll_bar_inc-1)/vscroll_bar_inc; + + hsb.set_max_slider_pos(hdelta); + vsb.set_max_slider_pos(vdelta); + + vsb.set_jump_size((display_rect_.height()+vscroll_bar_inc-1)/vscroll_bar_inc/2+1); + hsb.set_jump_size((display_rect_.width()+hscroll_bar_inc-1)/hscroll_bar_inc/2+1); + } + else if (need_h_scroll()) + { + // only hsb is hidden + if (!hidden) + { + hsb.show(); + vsb.hide(); + } + display_rect_ = rectangle( rect.left()+style->get_border_size(), + rect.top()+style->get_border_size(), + rect.right()-style->get_border_size(), + rect.bottom()-style->get_border_size()-hsb.height()); + + // figure out how many scroll bar positions there should be + unsigned long hdelta = total_rect_.width()-display_rect_.width(); + hdelta = (hdelta+hscroll_bar_inc-1)/hscroll_bar_inc; + + hsb.set_max_slider_pos(hdelta); + vsb.set_max_slider_pos(0); + + hsb.set_jump_size((display_rect_.width()+hscroll_bar_inc-1)/hscroll_bar_inc/2+1); + } + else if (need_v_scroll()) + { + // only vsb is hidden + if (!hidden) + { + hsb.hide(); + vsb.show(); + } + display_rect_ = rectangle( rect.left()+style->get_border_size(), + rect.top()+style->get_border_size(), + rect.right()-style->get_border_size()-vsb.width(), + rect.bottom()-style->get_border_size()); + + unsigned long vdelta = total_rect_.height()-display_rect_.height(); + vdelta = (vdelta+vscroll_bar_inc-1)/vscroll_bar_inc; + + hsb.set_max_slider_pos(0); + vsb.set_max_slider_pos(vdelta); + + vsb.set_jump_size((display_rect_.height()+vscroll_bar_inc-1)/vscroll_bar_inc/2+1); + } + else + { + // both are hidden + if (!hidden) + { + hsb.hide(); + vsb.hide(); + } + display_rect_ = rectangle( rect.left()+style->get_border_size(), + rect.top()+style->get_border_size(), + rect.right()-style->get_border_size(), + rect.bottom()-style->get_border_size()); + + hsb.set_max_slider_pos(0); + vsb.set_max_slider_pos(0); + } + + vsb.set_length(display_rect_.height()); + hsb.set_length(display_rect_.width()); + + // adjust the total_rect_ position by trigging the scroll events + on_h_scroll(); + on_v_scroll(); + + parent.invalidate_rectangle(rect+old); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long scrollable_region:: + horizontal_mouse_wheel_scroll_increment ( + ) const + { + auto_mutex M(m); + return h_wheel_scroll_bar_inc; + } + +// ---------------------------------------------------------------------------------------- + + unsigned long scrollable_region:: + vertical_mouse_wheel_scroll_increment ( + ) const + { + auto_mutex M(m); + return v_wheel_scroll_bar_inc; + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_horizontal_mouse_wheel_scroll_increment ( + unsigned long inc + ) + { + auto_mutex M(m); + h_wheel_scroll_bar_inc = inc; + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_vertical_mouse_wheel_scroll_increment ( + unsigned long inc + ) + { + auto_mutex M(m); + v_wheel_scroll_bar_inc = inc; + } + +// ---------------------------------------------------------------------------------------- + + unsigned long scrollable_region:: + horizontal_scroll_increment ( + ) const + { + auto_mutex M(m); + return hscroll_bar_inc; + } + +// ---------------------------------------------------------------------------------------- + + unsigned long scrollable_region:: + vertical_scroll_increment ( + ) const + { + auto_mutex M(m); + return vscroll_bar_inc; + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_horizontal_scroll_increment ( + unsigned long inc + ) + { + auto_mutex M(m); + hscroll_bar_inc = inc; + // call set_size to reset the scroll bars + set_size(rect.width(),rect.height()); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_vertical_scroll_increment ( + unsigned long inc + ) + { + auto_mutex M(m); + vscroll_bar_inc = inc; + // call set_size to reset the scroll bars + set_size(rect.width(),rect.height()); + } + +// ---------------------------------------------------------------------------------------- + + long scrollable_region:: + horizontal_scroll_pos ( + ) const + { + auto_mutex M(m); + return hsb.slider_pos(); + } + +// ---------------------------------------------------------------------------------------- + + long scrollable_region:: + vertical_scroll_pos ( + ) const + { + auto_mutex M(m); + return vsb.slider_pos(); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_horizontal_scroll_pos ( + long pos + ) + { + auto_mutex M(m); + + hsb.set_slider_pos(pos); + on_h_scroll(); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_vertical_scroll_pos ( + long pos + ) + { + auto_mutex M(m); + + vsb.set_slider_pos(pos); + on_v_scroll(); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_pos ( + long x, + long y + ) + { + auto_mutex M(m); + drawable::set_pos(x,y); + vsb.set_pos(rect.right()-style->get_border_size()-vsb.width()+1, rect.top()+style->get_border_size()); + hsb.set_pos(rect.left()+style->get_border_size(), rect.bottom()-style->get_border_size()-hsb.height()+1); + + const long delta_x = total_rect_.left() - display_rect_.left(); + const long delta_y = total_rect_.top() - display_rect_.top(); + + display_rect_ = move_rect(display_rect_, rect.left()+style->get_border_size(), rect.top()+style->get_border_size()); + + total_rect_ = move_rect(total_rect_, display_rect_.left()+delta_x, display_rect_.top()+delta_y); + } + +// ---------------------------------------------------------------------------------------- + + bool scrollable_region:: + mouse_drag_enabled ( + ) const + { + auto_mutex M(m); + return mouse_drag_enabled_; + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + enable_mouse_drag ( + ) + { + auto_mutex M(m); + mouse_drag_enabled_ = true; + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + disable_mouse_drag ( + ) + { + auto_mutex M(m); + mouse_drag_enabled_ = false; + } + +// ---------------------------------------------------------------------------------------- + + const rectangle& scrollable_region:: + display_rect ( + ) const + { + return display_rect_; + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + set_total_rect_size ( + unsigned long width, + unsigned long height + ) + { + DLIB_ASSERT((width > 0 && height > 0) || (width == 0 && height == 0), + "\tvoid scrollable_region::set_total_rect_size(width,height)" + << "\n\twidth and height must be > 0 or both == 0" + << "\n\twidth: " << width + << "\n\theight: " << height + << "\n\tthis: " << this + ); + + total_rect_ = move_rect(rectangle(width,height), + display_rect_.left()-static_cast(hsb.slider_pos()), + display_rect_.top()-static_cast(vsb.slider_pos())); + + // call this just to reconfigure the scroll bars + set_size(rect.width(),rect.height()); + } + +// ---------------------------------------------------------------------------------------- + + const rectangle& scrollable_region:: + total_rect ( + ) const + { + return total_rect_; + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + scroll_to_rect ( + const rectangle& r_ + ) + { + const rectangle r(total_rect_.intersect(r_)); + const rectangle old(total_rect_); + // adjust the horizontal scroll bar so that r fits as best as possible + if (r.left() < display_rect_.left()) + { + long distance = (r.left()-total_rect_.left())/hscroll_bar_inc; + hsb.set_slider_pos(distance); + } + else if (r.right() > display_rect_.right()) + { + long distance = (r.right()-total_rect_.left()-display_rect_.width()+hscroll_bar_inc)/hscroll_bar_inc; + hsb.set_slider_pos(distance); + } + + // adjust the vertical scroll bar so that r fits as best as possible + if (r.top() < display_rect_.top()) + { + long distance = (r.top()-total_rect_.top())/vscroll_bar_inc; + vsb.set_slider_pos(distance); + } + else if (r.bottom() > display_rect_.bottom()) + { + long distance = (r.bottom()-total_rect_.top()-display_rect_.height()+vscroll_bar_inc)/vscroll_bar_inc; + vsb.set_slider_pos(distance); + } + + + // adjust total_rect_ so that it matches where the scroll bars are now + total_rect_ = move_rect(total_rect_, + display_rect_.left()-hscroll_bar_inc*hsb.slider_pos(), + display_rect_.top()-vscroll_bar_inc*vsb.slider_pos()); + + // only redraw if we actually changed something + if (total_rect_ != old) + { + parent.invalidate_rectangle(display_rect_); + } + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + on_wheel_down ( + unsigned long + ) + { + if (rect.contains(lastx,lasty) && enabled && !hidden) + { + if (need_v_scroll()) + { + long pos = vsb.slider_pos(); + vsb.set_slider_pos(pos+(long)v_wheel_scroll_bar_inc); + on_v_scroll(); + } + else if (need_h_scroll()) + { + long pos = hsb.slider_pos(); + hsb.set_slider_pos(pos+(long)h_wheel_scroll_bar_inc); + on_h_scroll(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + if (enabled && !hidden && user_is_dragging_mouse && state==base_window::LEFT) + { + point current_delta = point(x,y) - point(total_rect().left(), total_rect().top()); + rectangle new_rect(translate_rect(display_rect(), drag_origin - current_delta)); + new_rect = centered_rect(new_rect, new_rect.width()-hscroll_bar_inc, new_rect.height()-vscroll_bar_inc); + scroll_to_rect(new_rect); + on_view_changed(); + } + else + { + user_is_dragging_mouse = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + on_mouse_down ( + unsigned long btn, + unsigned long , + long x, + long y, + bool + ) + { + if (mouse_drag_enabled_ && enabled && !hidden && display_rect().contains(x,y) && (btn==base_window::LEFT)) + { + drag_origin = point(x,y) - point(total_rect().left(), total_rect().top()); + user_is_dragging_mouse = true; + } + else + { + user_is_dragging_mouse = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + on_mouse_up ( + unsigned long , + unsigned long , + long , + long + ) + { + user_is_dragging_mouse = false; + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + on_wheel_up ( + unsigned long + ) + { + if (rect.contains(lastx,lasty) && enabled && !hidden) + { + if (need_v_scroll()) + { + long pos = vsb.slider_pos(); + vsb.set_slider_pos(pos-(long)v_wheel_scroll_bar_inc); + on_v_scroll(); + } + else if (need_h_scroll()) + { + long pos = hsb.slider_pos(); + hsb.set_slider_pos(pos-(long)h_wheel_scroll_bar_inc); + on_h_scroll(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + draw ( + const canvas& c + ) const + { + style->draw_scrollable_region_border(c, rect, enabled); + } + +// ---------------------------------------------------------------------------------------- + + bool scrollable_region:: + need_h_scroll ( + ) const + { + if (total_rect_.width() > rect.width()-style->get_border_size()*2) + { + return true; + } + else + { + // check if we would need a vertical scroll bar and if adding one would make us need + // a horizontal one + if (total_rect_.height() > rect.height()-style->get_border_size()*2 && + total_rect_.width() > rect.width()-style->get_border_size()*2-vsb.width()) + return true; + else + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + bool scrollable_region:: + need_v_scroll ( + ) const + { + if (total_rect_.height() > rect.height()-style->get_border_size()*2) + { + return true; + } + else + { + // check if we would need a horizontal scroll bar and if adding one would make us need + // a vertical_scroll_pos one + if (total_rect_.width() > rect.width()-style->get_border_size()*2 && + total_rect_.height() > rect.height()-style->get_border_size()*2-hsb.height()) + return true; + else + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + on_h_scroll ( + ) + { + total_rect_ = move_rect(total_rect_, display_rect_.left()-hscroll_bar_inc*hsb.slider_pos(), total_rect_.top()); + parent.invalidate_rectangle(display_rect_); + if (events_are_enabled()) + on_view_changed(); + } + +// ---------------------------------------------------------------------------------------- + + void scrollable_region:: + on_v_scroll ( + ) + { + total_rect_ = move_rect(total_rect_, total_rect_.left(), display_rect_.top()-vscroll_bar_inc*vsb.slider_pos()); + parent.invalidate_rectangle(display_rect_); + if (events_are_enabled()) + on_view_changed(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// class popup_menu_region +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + popup_menu_region:: + popup_menu_region( + drawable_window& w + ) : + drawable(w,MOUSE_CLICK | KEYBOARD_EVENTS | FOCUS_EVENTS | WINDOW_MOVED), + popup_menu_shown(false) + { + + menu_.set_on_hide_handler(*this,&popup_menu_region::on_menu_becomes_hidden); + enable_events(); + } + +// ---------------------------------------------------------------------------------------- + + popup_menu_region:: + ~popup_menu_region( + ) + { + disable_events(); + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex M(m); + rect = resize_rect(rect,width,height); + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + set_rect ( + const rectangle& new_rect + ) + { + auto_mutex M(m); + rect = new_rect; + } + +// ---------------------------------------------------------------------------------------- + + popup_menu& popup_menu_region:: + menu ( + ) + { + return menu_; + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + hide ( + ) + { + auto_mutex M(m); + drawable::hide(); + menu_.hide(); + popup_menu_shown = false; + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + disable ( + ) + { + auto_mutex M(m); + drawable::disable(); + menu_.hide(); + popup_menu_shown = false; + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + if (enabled && !hidden && popup_menu_shown) + { + menu_.forwarded_on_keydown(key, is_printable, state); + } + else if (popup_menu_shown) + { + menu_.hide(); + popup_menu_shown = false; + } + + if (key == (unsigned long)base_window::KEY_ESC) + { + menu_.hide(); + popup_menu_shown = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + on_menu_becomes_hidden ( + ) + { + popup_menu_shown = false; + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + on_focus_lost ( + ) + { + if (popup_menu_shown) + { + menu_.hide(); + popup_menu_shown = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + on_focus_gained ( + ) + { + if (popup_menu_shown) + { + menu_.hide(); + popup_menu_shown = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + on_window_moved( + ) + { + if (popup_menu_shown) + { + menu_.hide(); + popup_menu_shown = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + on_mouse_down ( + unsigned long btn, + unsigned long , + long x, + long y, + bool + ) + { + if (enabled && !hidden && rect.contains(x,y) && btn == base_window::RIGHT) + { + long orig_x, orig_y; + parent.get_pos(orig_x, orig_y); + menu_.set_pos(orig_x+x, orig_y+y); + menu_.show(); + popup_menu_shown = true; + } + else if (popup_menu_shown) + { + menu_.hide(); + popup_menu_shown = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void popup_menu_region:: + draw ( + const canvas& + ) const + { + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BASE_WIDGETs_CPP_ + diff --git a/ml/dlib/dlib/gui_widgets/base_widgets.h b/ml/dlib/dlib/gui_widgets/base_widgets.h new file mode 100644 index 000000000..7c6d25097 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/base_widgets.h @@ -0,0 +1,2678 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. + +#ifndef DLIB_BASE_WIDGETs_ +#define DLIB_BASE_WIDGETs_ + +#include +#include + +#include "base_widgets_abstract.h" +#include "drawable.h" +#include "../gui_core.h" +#include "../algs.h" +#include "../member_function_pointer.h" +#include "../timer.h" +#include "../map.h" +#include "../set.h" +#include "../array2d.h" +#include "../pixel.h" +#include "../image_transforms/assign_image.h" +#include "../array.h" +#include "style.h" +#include "../unicode.h" +#include "../any.h" + + +namespace dlib +{ + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class draggable +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class draggable : public drawable + { + /*! + INITIAL VALUE + - drag == false + + CONVENTION + - if (the user is holding the left button down over this object) then + - drag == true + - x == the x position of the mouse relative to the upper left corner + of this object. + - y == the y position of the mouse relative to the upper left corner + of this object. + - else + - drag == false + !*/ + + public: + + draggable( + drawable_window& w, + unsigned long events = 0 + ) : + drawable(w,events | MOUSE_MOVE | MOUSE_CLICK), + drag(false) + {} + + virtual ~draggable( + ) = 0; + + rectangle draggable_area ( + ) const { auto_mutex M(m); return area; } + + void set_draggable_area ( + const rectangle& area_ + ) { auto_mutex M(m); area = area_; } + + protected: + + bool is_being_dragged ( + ) const { return drag; } + + virtual void on_drag ( + ){} + + virtual void on_drag_stop ( + ){} + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long , + long x, + long y, + bool + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ); + + private: + + rectangle area; + bool drag; + long x, y; + + // restricted functions + draggable(draggable&); // copy constructor + draggable& operator=(draggable&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class mouse_over_event +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class mouse_over_event : public drawable + { + /*! + INITIAL VALUE + - is_mouse_over_ == false + + CONVENTION + - is_mouse_over_ == is_mouse_over() + !*/ + + public: + + mouse_over_event( + drawable_window& w, + unsigned long events = 0 + ) : + drawable(w,events | MOUSE_MOVE), + is_mouse_over_(false) + {} + + + virtual ~mouse_over_event( + ) = 0; + + int next_free_user_event_number() const + { + return drawable::next_free_user_event_number()+1; + } + + protected: + + bool is_mouse_over ( + ) const; + + virtual void on_mouse_over ( + ){} + + virtual void on_mouse_not_over ( + ){} + + void on_mouse_leave ( + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_user_event ( + int num + ); + + private: + mutable bool is_mouse_over_; + + // restricted functions + mouse_over_event(mouse_over_event&); // copy constructor + mouse_over_event& operator=(mouse_over_event&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class button_action +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class button_action : public mouse_over_event + { + /*! + INITIAL VALUE + - is_depressed_ == false + - seen_click == false + + CONVENTION + - is_depressed_ == is_depressed() + - if (the user has clicked the button but hasn't yet released the + left mouse button) then + - seen_click == true + - else + - seen_click == false + !*/ + + public: + + button_action( + drawable_window& w, + unsigned long events = 0 + ) : + mouse_over_event(w,events | MOUSE_MOVE | MOUSE_CLICK), + is_depressed_(false), + seen_click(false) + {} + + + virtual ~button_action( + ) = 0; + + int next_free_user_event_number() const + { + return mouse_over_event::next_free_user_event_number()+1; + } + + protected: + + bool is_depressed ( + ) const; + + virtual void on_button_down ( + ){} + + virtual void on_button_up ( + bool + ){} + + void on_mouse_not_over ( + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long , + long x, + long y, + bool + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long, + long x, + long y + ); + + + private: + mutable bool is_depressed_; + bool seen_click; + + void on_user_event ( + int num + ); + + // restricted functions + button_action(button_action&); // copy constructor + button_action& operator=(button_action&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class widget_group +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class widget_group : public drawable + { + /*! + INITIAL VALUE + widgets.size() == 0 + + CONVENTION + - widgets contains all the drawable objects and their relative positions + that are in *this. + - wg_widgets contains pointers to just the widgets that happen + to be widget_group objects. + !*/ + + struct relpos + { + unsigned long x; + unsigned long y; + }; + + public: + widget_group( + drawable_window& w + ) : drawable(w) { rect = rectangle(0,0,-1,-1); enable_events();} + + virtual ~widget_group( + ){ disable_events(); } + + void empty ( + ); + + void add ( + drawable& widget, + unsigned long x, + unsigned long y + ); + + void add ( + widget_group& widget, + unsigned long x, + unsigned long y + ); + + bool is_member ( + const drawable& widget + ) const; + + void remove ( + const drawable& widget + ); + + size_t size ( + ) const; + + void set_pos ( + long x, + long y + ); + + void set_z_order ( + long order + ); + + void show ( + ); + + void hide ( + ); + + void enable ( + ); + + void disable ( + ); + + void fit_to_contents ( + ); + + protected: + + // this object doesn't draw anything but also isn't abstract + void draw ( + const canvas& + ) const {} + + private: + + map::kernel_1a_c widgets; + set::kernel_1a_c wg_widgets; + + + // restricted functions + widget_group(widget_group&); // copy constructor + widget_group& operator=(widget_group&); // assignment operator + }; + + +// ---------------------------------------------------------------------------------------- + + class image_widget : public draggable + { + /*! + INITIAL VALUE + - img.size() == 0 + + CONVENTION + - img == the image this object displays + !*/ + + public: + + image_widget( + drawable_window& w + ): draggable(w) { enable_events(); } + + ~image_widget( + ) + { + disable_events(); + parent.invalidate_rectangle(rect); + } + + template < + typename image_type + > + void set_image ( + const image_type& new_img + ) + { + auto_mutex M(m); + assign_image_scaled(img,new_img); + rectangle old(rect); + rect.set_right(rect.left()+num_columns(img)-1); + rect.set_bottom(rect.top()+num_rows(img)-1); + parent.invalidate_rectangle(rect+old); + } + + private: + + void draw ( + const canvas& c + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + draw_image(c, point(rect.left(),rect.top()), img); + } + + array2d img; + + // restricted functions + image_widget(image_widget&); // copy constructor + image_widget& operator=(image_widget&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class tooltip +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class tooltip : public mouse_over_event + { + /*! + INITIAL VALUE + - stuff.get() == 0 + - events_are_enabled() == false + + CONVENTION + - if (events_are_enabled() == true) then + - stuff.get() != 0 + !*/ + + public: + + tooltip( + drawable_window& w + ) : + mouse_over_event(w,MOUSE_CLICK) + {} + + ~tooltip( + ){ disable_events();} + + void set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex M(m); + rect = resize_rect(rect,width,height); + } + + + void set_text ( + const std::string& str + ) + { + set_text(convert_mbstring_to_wstring(str)); + } + + void set_text ( + const std::wstring& str + ) + { + set_text(convert_wstring_to_utf32(str)); + } + + void set_text ( + const ustring& str + ) + { + auto_mutex M(m); + if (!stuff) + { + stuff.reset(new data(*this)); + enable_events(); + } + + stuff->win.set_text(str); + } + + const std::string text ( + ) const + { + return convert_wstring_to_mbstring(wtext()); + } + + const std::wstring wtext ( + ) const + { + return convert_utf32_to_wstring(utext()); + } + + const dlib::ustring utext ( + ) const + { + auto_mutex M(m); + dlib::ustring temp; + if (stuff) + { + temp = stuff->win.text; + } + return temp.c_str(); + } + + void hide ( + ) + { + auto_mutex M(m); + mouse_over_event::hide(); + if (stuff) + { + stuff->tt_timer.stop(); + stuff->win.hide(); + } + } + + void disable ( + ) + { + auto_mutex M(m); + mouse_over_event::disable(); + if (stuff) + { + stuff->tt_timer.stop(); + stuff->win.hide(); + } + } + + protected: + + void on_mouse_over() + { + stuff->x = lastx; + stuff->y = lasty; + stuff->tt_timer.start(); + } + + void on_mouse_not_over () + { + stuff->tt_timer.stop(); + stuff->win.hide(); + } + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ) + { + mouse_over_event::on_mouse_down(btn,state,x,y,is_double_click); + stuff->tt_timer.stop(); + stuff->win.hide(); + } + + void draw ( + const canvas& + ) const{} + + private: + + class tooltip_window : public base_window + { + public: + tooltip_window (const std::shared_ptr& f) : base_window(false,true), pad(3), mfont(f) + { + } + + ustring text; + rectangle rect_all; + rectangle rect_text; + const unsigned long pad; + const std::shared_ptr mfont; + + void set_text ( + const std::string& str + ) + { + set_text(convert_mbstring_to_wstring(str)); + } + + void set_text ( + const std::wstring& str + ) + { + set_text(convert_wstring_to_utf32(str)); + } + + void set_text ( + const dlib::ustring& str + ) + { + text = str.c_str(); + + unsigned long width, height; + mfont->compute_size(text,width,height); + + set_size(width+pad*2, height+pad*2); + rect_all.set_left(0); + rect_all.set_top(0); + rect_all.set_right(width+pad*2-1); + rect_all.set_bottom(height+pad*2-1); + + rect_text = move_rect(rectangle(width,height),pad,pad); + } + + void paint(const canvas& c) + { + c.fill(255,255,150); + draw_rectangle(c, rect_all); + mfont->draw_string(c,rect_text,text); + } + }; + + void show_tooltip ( + ) + { + auto_mutex M(m); + long x, y; + // if the mouse has moved since we started the timer then + // keep waiting until the user stops moving it + if (lastx != stuff->x || lasty != stuff->y) + { + stuff->x = lastx; + stuff->y = lasty; + return; + } + + unsigned long display_width, display_height; + // stop the timer + stuff->tt_timer.stop(); + parent.get_pos(x,y); + x += lastx+15; + y += lasty+15; + + // make sure the tooltip isn't going to be off the screen + parent.get_display_size(display_width, display_height); + rectangle wrect(move_rect(stuff->win.rect_all,x,y)); + rectangle srect(display_width, display_height); + if (srect.contains(wrect) == false) + { + rectangle temp(srect.intersect(wrect)); + x -= wrect.width()-temp.width(); + y -= wrect.height()-temp.height(); + } + + stuff->win.set_pos(x,y); + stuff->win.show(); + } + + // put all this stuff in data so we can arrange to only + // construct it when someone is actually using the tooltip widget + // rather than just instantiating it. + struct data + { + data( + tooltip& self + ) : + x(-1), + y(-1), + win(self.mfont), + tt_timer(self,&tooltip::show_tooltip) + { + tt_timer.set_delay_time(400); + } + + long x, y; + tooltip_window win; + timer tt_timer; + + }; + friend struct data; + std::unique_ptr stuff; + + + + // restricted functions + tooltip(tooltip&); // copy constructor + tooltip& operator=(tooltip&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class button +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class button : public button_action + { + public: + button( + drawable_window& w + ) : + button_action(w), + btn_tooltip(w) + { + style.reset(new button_style_default()); + enable_events(); + } + + ~button() { disable_events(); parent.invalidate_rectangle(style->get_invalidation_rect(rect)); } + + void set_size ( + unsigned long width, + unsigned long height + ); + + void set_name ( + const std::string& name_ + ); + + void set_name ( + const std::wstring& name_ + ); + + void set_name ( + const dlib::ustring& name_ + ); + + const std::string name ( + ) const; + + const std::wstring wname ( + ) const; + + const dlib::ustring uname ( + ) const; + + void set_tooltip_text ( + const std::string& text + ); + + void set_tooltip_text ( + const std::wstring& text + ); + + void set_tooltip_text ( + const dlib::ustring& text + ); + + void set_pos( + long x, + long y + ); + + const std::string tooltip_text ( + ) const; + + const std::wstring tooltip_wtext ( + ) const; + + const dlib::ustring tooltip_utext ( + ) const; + + void set_main_font ( + const std::shared_ptr& f + ); + + void show ( + ); + + void hide ( + ); + + void enable ( + ); + + void disable ( + ); + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ) + { + auto_mutex M(m); + style.reset(new style_type(style_)); + rect = move_rect(style->get_min_size(name_,*mfont), rect.left(), rect.top()); + parent.invalidate_rectangle(style->get_invalidation_rect(rect)); + } + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler_)() + ) + { + auto_mutex M(m); + event_handler = make_mfp(object,event_handler_); + event_handler_self.clear(); + } + + void set_click_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + event_handler = event_handler_; + event_handler_self.clear(); + } + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler_)(button&) + ) + { + auto_mutex M(m); + event_handler_self = make_mfp(object,event_handler_); + event_handler.clear(); + } + + void set_sourced_click_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + event_handler_self = event_handler_; + event_handler.clear(); + } + + bool is_depressed ( + ) const + { + auto_mutex M(m); + return button_action::is_depressed(); + } + + template < + typename T + > + void set_button_down_handler ( + T& object, + void (T::*event_handler)() + ) + { + auto_mutex M(m); + button_down_handler = make_mfp(object,event_handler); + } + + void set_button_down_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + button_down_handler = event_handler; + } + + template < + typename T + > + void set_button_up_handler ( + T& object, + void (T::*event_handler)(bool mouse_over) + ) + { + auto_mutex M(m); + button_up_handler = make_mfp(object,event_handler); + } + + void set_button_up_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + button_up_handler = event_handler; + } + + template < + typename T + > + void set_button_down_handler ( + T& object, + void (T::*event_handler)(button&) + ) + { + auto_mutex M(m); + button_down_handler_self = make_mfp(object,event_handler); + } + + void set_sourced_button_down_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + button_down_handler_self = event_handler; + } + + template < + typename T + > + void set_button_up_handler ( + T& object, + void (T::*event_handler)(bool mouse_over, button&) + ) + { + auto_mutex M(m); + button_up_handler_self = make_mfp(object,event_handler); + } + + void set_sourced_button_up_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + button_up_handler_self = event_handler; + } + + private: + + // restricted functions + button(button&); // copy constructor + button& operator=(button&); // assignment operator + + dlib::ustring name_; + tooltip btn_tooltip; + + any_function event_handler; + any_function event_handler_self; + any_function button_down_handler; + any_function button_up_handler; + any_function button_down_handler_self; + any_function button_up_handler_self; + + std::unique_ptr style; + + protected: + + void draw ( + const canvas& c + ) const { style->draw_button(c,rect,enabled,*mfont,lastx,lasty,name_,is_depressed()); } + + void on_button_up ( + bool mouse_over + ); + + void on_button_down ( + ); + + void on_mouse_over ( + ){ if (style->redraw_on_mouse_over()) parent.invalidate_rectangle(style->get_invalidation_rect(rect)); } + + void on_mouse_not_over ( + ){ if (style->redraw_on_mouse_over()) parent.invalidate_rectangle(style->get_invalidation_rect(rect)); } + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class scroll_bar +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class scroll_bar : public drawable + { + /*! + INITIAL VALUE + - ori == a value given by the constructor + - style == a scroll_bar_style_default object + - pos == 0 + - max_pos == 0 + - js == 10 + + CONVENTION + - ori == orientation() + - b1 == the button that is near the 0 end of the scroll bar + - b2 == the button that is near the max_pos() end of the scroll bar + + - max_pos == max_slider_pos() + - pos == slider_pos() + - js == jump_size() + !*/ + + public: + enum bar_orientation + { + HORIZONTAL, + VERTICAL + }; + + scroll_bar( + drawable_window& w, + bar_orientation orientation_ + ); + + virtual ~scroll_bar( + ); + + bar_orientation orientation ( + ) const; + + void set_length ( + unsigned long length + ); + + long max_slider_pos ( + ) const; + + void set_max_slider_pos ( + long mpos + ); + + void set_slider_pos ( + long pos + ); + + long slider_pos ( + ) const; + + template < + typename T + > + void set_scroll_handler ( + T& object, + void (T::*eh)() + ) { auto_mutex M(m); scroll_handler = make_mfp(object,eh); } + + void set_scroll_handler ( + const any_function& eh + ) { auto_mutex M(m); scroll_handler = eh; } + + void set_pos ( + long x, + long y + ); + + void enable ( + ) + { + auto_mutex M(m); + if (!hidden) + show_slider(); + if (max_pos != 0) + { + b1.enable(); + b2.enable(); + } + drawable::enable(); + } + + void disable ( + ) + { + auto_mutex M(m); + hide_slider(); + b1.disable(); + b2.disable(); + drawable::disable(); + } + + void hide ( + ) + { + auto_mutex M(m); + hide_slider(); + top_filler.hide(); + bottom_filler.hide(); + b1.hide(); + b2.hide(); + drawable::hide(); + } + + void show ( + ) + { + auto_mutex M(m); + b1.show(); + b2.show(); + drawable::show(); + top_filler.show(); + if (enabled) + show_slider(); + } + + void set_z_order ( + long order + ) + { + auto_mutex M(m); + slider.set_z_order(order); + top_filler.set_z_order(order); + bottom_filler.set_z_order(order); + b1.set_z_order(order); + b2.set_z_order(order); + drawable::set_z_order(order); + } + + void set_jump_size ( + long js + ); + + long jump_size ( + ) const; + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ) + { + auto_mutex M(m); + style.reset(new style_type(style_)); + + if (ori == HORIZONTAL) + { + b1.set_style(style_.get_left_button_style()); + b2.set_style(style_.get_right_button_style()); + set_length(rect.width()); + } + else + { + b1.set_style(style_.get_up_button_style()); + b2.set_style(style_.get_down_button_style()); + set_length(rect.height()); + } + + } + + private: + + void hide_slider ( + ); + /*! + ensures + - hides the slider and makes any other changes needed so that the + scroll_bar still looks right. + !*/ + + void show_slider ( + ); + /*! + ensures + - shows the slider and makes any other changes needed so that the + scroll_bar still looks right. + !*/ + + + void on_slider_drag ( + ); + /*! + requires + - is called whenever the user drags the slider + !*/ + + void draw ( + const canvas& c + ) const; + + void b1_down ( + ); + + void b1_up ( + bool mouse_over + ); + + void b2_down ( + ); + + void b2_up ( + bool mouse_over + ); + + void top_filler_down ( + ); + + void top_filler_up ( + bool mouse_over + ); + + void bottom_filler_down ( + ); + + void bottom_filler_up ( + bool mouse_over + ); + + void on_user_event ( + int i + ); + + void delayed_set_slider_pos ( + unsigned long dpos + ); + + void b1_down_t ( + ); + + void b2_down_t ( + ); + + void top_filler_down_t ( + ); + + void bottom_filler_down_t ( + ); + + friend class filler; + class filler : public button_action + { + friend class scroll_bar; + public: + filler ( + drawable_window& w, + scroll_bar& object, + void (scroll_bar::*down)(), + void (scroll_bar::*up)(bool) + ): + button_action(w), + my_scroll_bar(object) + { + bup = make_mfp(object,up); + bdown = make_mfp(object,down); + + enable_events(); + } + + ~filler ( + ) + { + disable_events(); + } + + void set_size ( + unsigned long width, + unsigned long height + ) + { + rectangle old(rect); + const unsigned long x = rect.left(); + const unsigned long y = rect.top(); + rect.set_right(x+width-1); + rect.set_bottom(y+height-1); + + parent.invalidate_rectangle(rect+old); + } + + private: + + void draw ( + const canvas& c + ) const + { + my_scroll_bar.style->draw_scroll_bar_background(c,rect,enabled,lastx,lasty,is_depressed()); + } + + void on_button_down ( + ) { bdown(); } + + void on_button_up ( + bool mouse_over + ) { bup(mouse_over); } + + scroll_bar& my_scroll_bar; + any_function bdown; + any_function bup; + }; + + friend class slider_class; + class slider_class : public draggable + { + friend class scroll_bar; + public: + slider_class ( + drawable_window& w, + scroll_bar& object, + void (scroll_bar::*handler)() + ) : + draggable(w, MOUSE_MOVE), + mouse_in_widget(false), + my_scroll_bar(object) + { + callback = make_mfp(object,handler); + enable_events(); + } + + ~slider_class ( + ) + { + disable_events(); + } + + void set_size ( + unsigned long width, + unsigned long height + ) + { + rectangle old(rect); + const unsigned long x = rect.left(); + const unsigned long y = rect.top(); + rect.set_right(x+width-1); + rect.set_bottom(y+height-1); + + parent.invalidate_rectangle(rect+old); + } + + private: + virtual void on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + draggable::on_mouse_move(state,x,y); + if (!hidden && my_scroll_bar.style->redraw_on_mouse_over_slider()) + { + if (rect.contains(x,y) && !mouse_in_widget) + { + mouse_in_widget = true; + parent.invalidate_rectangle(rect); + } + else if (rect.contains(x,y) == false && mouse_in_widget) + { + mouse_in_widget = false; + parent.invalidate_rectangle(rect); + } + } + } + + void on_mouse_leave ( + ) + { + if (mouse_in_widget && my_scroll_bar.style->redraw_on_mouse_over_slider()) + { + mouse_in_widget = false; + parent.invalidate_rectangle(rect); + } + } + + void on_drag_stop ( + ) + { + if (my_scroll_bar.style->redraw_on_mouse_over_slider()) + parent.invalidate_rectangle(rect); + } + + void on_drag ( + ) + { + callback(); + } + + void draw ( + const canvas& c + ) const + { + my_scroll_bar.style->draw_scroll_bar_slider(c,rect,enabled,lastx,lasty, is_being_dragged()); + } + + bool mouse_in_widget; + scroll_bar& my_scroll_bar; + any_function callback; + }; + + + void adjust_fillers ( + ); + /*! + ensures + - top_filler and bottom_filler appear in their correct positions + relative to the current positions of the slider and the b1 and + b2 buttons + !*/ + + unsigned long get_slider_size ( + ) const; + /*! + ensures + - returns the length in pixels the slider should have based on the current + state of this scroll bar + !*/ + + + button b1, b2; + slider_class slider; + bar_orientation ori; + filler top_filler, bottom_filler; + any_function scroll_handler; + + long pos; + long max_pos; + long js; + + timer b1_timer; + timer b2_timer; + timer top_filler_timer; + timer bottom_filler_timer; + long delayed_pos; + std::unique_ptr style; + + // restricted functions + scroll_bar(scroll_bar&); // copy constructor + scroll_bar& operator=(scroll_bar&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class popup_menu +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class menu_item + { + public: + virtual ~menu_item() {} + + virtual rectangle get_left_size ( + ) const { return rectangle(); } + virtual rectangle get_middle_size ( + ) const = 0; + virtual rectangle get_right_size ( + ) const { return rectangle(); } + + virtual unichar get_hot_key ( + ) const { return 0; } + + virtual void draw_background ( + const canvas& , + const rectangle& , + const bool , + const bool + ) const {} + + virtual void draw_left ( + const canvas& , + const rectangle& , + const bool , + const bool + ) const {} + + virtual void draw_middle ( + const canvas& , + const rectangle& , + const bool , + const bool + ) const = 0; + + virtual void draw_right ( + const canvas& , + const rectangle& , + const bool , + const bool + ) const {} + + virtual void on_click ( + ) const {} + + virtual bool has_click_event ( + ) const { return false; } + + }; + +// ---------------------------------------------------------------------------------------- + + class menu_item_submenu : public menu_item + { + void initialize ( + unichar hk + ) + { + const dlib::ustring &str = text; + if (hk != 0) + { + std::string::size_type pos = str.find_first_of(hk); + if (pos != std::string::npos) + { + // now compute the location of the underline bar + rectangle r1 = f->compute_cursor_rect( rectangle(100000,100000), str, pos); + rectangle r2 = f->compute_cursor_rect( rectangle(100000,100000), str, pos+1); + + underline_p1.x() = r1.left()+1; + underline_p2.x() = r2.left()-1; + underline_p1.y() = r1.bottom()-f->height()+f->ascender()+2; + underline_p2.y() = r2.bottom()-f->height()+f->ascender()+2; + } + } + } + public: + menu_item_submenu ( + const std::string& str, + unichar hk = 0 + ) : + text(convert_wstring_to_utf32(convert_mbstring_to_wstring(str))), + f(default_font::get_font()), + hotkey(hk) + { + initialize(hk); + } + + menu_item_submenu ( + const std::wstring& str, + unichar hk = 0 + ) : + text(convert_wstring_to_utf32(str)), + f(default_font::get_font()), + hotkey(hk) + { + initialize(hk); + } + + menu_item_submenu ( + const dlib::ustring& str, + unichar hk = 0 + ) : + text(str), + f(default_font::get_font()), + hotkey(hk) + { + initialize(hk); + } + + virtual unichar get_hot_key ( + ) const { return hotkey; } + + virtual rectangle get_middle_size ( + ) const + { + unsigned long width, height; + f->compute_size(text,width,height); + return rectangle(width+30,height); + } + + virtual rectangle get_right_size ( + ) const + { + return rectangle(15, 5); + } + + virtual void draw_background ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const bool is_selected + ) const + { + if (c.intersect(rect).is_empty()) + return; + + if (enabled && is_selected) + { + fill_rect_with_vertical_gradient(c, rect,rgb_alpha_pixel(0,200,0,100), rgb_alpha_pixel(0,0,0,100)); + draw_rectangle(c, rect,rgb_alpha_pixel(0,0,0,100)); + } + } + + virtual void draw_right ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const bool + ) const + { + if (c.intersect(rect).is_empty()) + return; + + unsigned char color = 0; + + if (enabled == false) + color = 128; + + long x, y; + x = rect.right() - 7; + y = rect.top() + rect.height()/2; + + for ( unsigned long i = 0; i < 5; ++i) + draw_line (c, point(x - i, y + i), point(x - i, y - i), color); + } + + virtual void draw_middle ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const bool + ) const + { + if (c.intersect(rect).is_empty()) + return; + + if (enabled) + { + f->draw_string(c,rect,text); + } + else + { + f->draw_string(c,rect,text,128); + } + + if (underline_p1 != underline_p2) + { + point base(rect.left(),rect.top()); + draw_line(c, base+underline_p1, base+underline_p2); + } + } + + private: + dlib::ustring text; + const std::shared_ptr f; + any_function action; + unichar hotkey; + point underline_p1; + point underline_p2; + }; + +// ---------------------------------------------------------------------------------------- + + class menu_item_text : public menu_item + { + void initialize ( + const any_function& event_handler_, + unichar hk + ) + { + dlib::ustring &str = text; + action = event_handler_; + + if (hk != 0) + { + std::string::size_type pos = str.find_first_of(hk); + if (pos != std::string::npos) + { + // now compute the location of the underline bar + rectangle r1 = f->compute_cursor_rect( rectangle(100000,100000), str, pos); + rectangle r2 = f->compute_cursor_rect( rectangle(100000,100000), str, pos+1); + + underline_p1.x() = r1.left()+1; + underline_p2.x() = r2.left()-1; + underline_p1.y() = r1.bottom()-f->height()+f->ascender()+2; + underline_p2.y() = r2.bottom()-f->height()+f->ascender()+2; + } + } + } + + public: + template + menu_item_text ( + const std::string& str, + T& object, + void (T::*event_handler_)(), + unichar hk = 0 + ) : + text(convert_wstring_to_utf32(convert_mbstring_to_wstring(str))), + f(default_font::get_font()), + hotkey(hk) + { + initialize(make_mfp(object, event_handler_), hk); + } + + menu_item_text ( + const std::string& str, + const any_function& event_handler_, + unichar hk = 0 + ) : + text(convert_wstring_to_utf32(convert_mbstring_to_wstring(str))), + f(default_font::get_font()), + hotkey(hk) + { + initialize(event_handler_, hk); + } + + template + menu_item_text ( + const std::wstring& str, + T& object, + void (T::*event_handler_)(), + unichar hk = 0 + ) : + text(convert_wstring_to_utf32(str)), + f(default_font::get_font()), + hotkey(hk) + { + initialize(make_mfp(object, event_handler_), hk); + } + + menu_item_text ( + const std::wstring& str, + const any_function& event_handler_, + unichar hk = 0 + ) : + text(convert_wstring_to_utf32(str)), + f(default_font::get_font()), + hotkey(hk) + { + initialize(event_handler_, hk); + } + + template + menu_item_text ( + const dlib::ustring& str, + T& object, + void (T::*event_handler_)(), + unichar hk = 0 + ) : + text(str), + f(default_font::get_font()), + hotkey(hk) + { + initialize(make_mfp(object, event_handler_), hk); + } + + menu_item_text ( + const dlib::ustring& str, + const any_function& event_handler_, + unichar hk = 0 + ) : + text(str), + f(default_font::get_font()), + hotkey(hk) + { + initialize(event_handler_, hk); + } + + virtual unichar get_hot_key ( + ) const { return hotkey; } + + virtual rectangle get_middle_size ( + ) const + { + unsigned long width, height; + f->compute_size(text,width,height); + return rectangle(width,height); + } + + virtual void draw_background ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const bool is_selected + ) const + { + if (c.intersect(rect).is_empty()) + return; + + if (enabled && is_selected) + { + fill_rect_with_vertical_gradient(c, rect,rgb_alpha_pixel(0,200,0,100), rgb_alpha_pixel(0,0,0,100)); + draw_rectangle(c, rect,rgb_alpha_pixel(0,0,0,100)); + } + } + + virtual void draw_middle ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const bool + ) const + { + if (c.intersect(rect).is_empty()) + return; + + unsigned char color = 0; + + if (enabled == false) + color = 128; + + f->draw_string(c,rect,text,color); + + if (underline_p1 != underline_p2) + { + point base(rect.left(),rect.top()); + draw_line(c, base+underline_p1, base+underline_p2, color); + } + } + + virtual void on_click ( + ) const + { + action(); + } + + virtual bool has_click_event ( + ) const { return true; } + + private: + dlib::ustring text; + const std::shared_ptr f; + any_function action; + unichar hotkey; + point underline_p1; + point underline_p2; + }; + +// ---------------------------------------------------------------------------------------- + + class menu_item_separator : public menu_item + { + public: + virtual rectangle get_middle_size ( + ) const + { + return rectangle(10,4); + } + + virtual void draw_background ( + const canvas& c, + const rectangle& rect, + const bool , + const bool + ) const + { + if (c.intersect(rect).is_empty()) + return; + + point p1(rect.left(),rect.top()+rect.height()/2-1); + point p2(rect.right(),rect.top()+rect.height()/2-1); + + point p3(rect.left(),rect.top()+rect.height()/2); + point p4(rect.right(),rect.top()+rect.height()/2); + draw_line(c, p1,p2,128); + draw_line(c, p3,p4,255); + } + + virtual void draw_middle ( + const canvas& , + const rectangle& , + const bool , + const bool + ) const + { + } + }; + +// ---------------------------------------------------------------------------------------- + + class popup_menu : public base_window + { + /*! + INITIAL VALUE + - pad == 2 + - item_pad == 3 + - cur_rect == rectangle(pad,pad,pad-1,pad-1) + - left_width == 0 + - middle_width == 0 + - selected_item == 0 + - submenu_open == false + - items.size() == 0 + - item_enabled.size() == 0 + - left_rects.size() == 0 + - middle_rects.size() == 0 + - right_rects.size() == 0 + - line_rects.size() == 0 + - submenus.size() == 0 + - hide_handlers.size() == 0 + + CONVENTION + - pad = 2 + - item_pad = 3 + - all of the following arrays have the same size: + - items.size() + - item_enabled.size() + - left_rects.size() + - middle_rects.size() + - right_rects.size() + - line_rects.size() + - submenus.size() + + - win_rect == a rectangle that is the exact size of this window and with + its upper left corner at (0,0) + - cur_rect == the rect inside which all the menu items are drawn + + - if (a menu_item is supposed to be selected) then + - selected_item == the index in menus of the menu_item + - else + - selected_item == submenus.size() + + - if (there is a selected submenu and it is currently open) then + - submenu_open == true + - else + - submenu_open == false + + - for all valid i: + - items[i] == a pointer to the ith menu_item + - item_enabled[i] == true if the ith menu_item is enabled, false otherwise + - left_rects[i] == the left rectangle for the ith menu item + - middle_rects[i] == the middle rectangle for the ith menu item + - right_rects[i] == the right rectangle for the ith menu item + - line_rects[i] == the rectangle for the entire line on which the ith menu + item appears. + - if (submenus[i] != 0) then + - the ith menu item has a submenu and it is pointed to by submenus[i] + + - hide_handlers == an array of all the on_hide events registered for + this popup_menu + !*/ + + public: + + popup_menu ( + ); + + template < + typename menu_item_type + > + unsigned long add_menu_item ( + const menu_item_type& new_item + ) + { + auto_mutex M(wm); + bool t = true; + std::unique_ptr item(new menu_item_type(new_item)); + items.push_back(item); + item_enabled.push_back(t); + + // figure out how big the window should be now and what not + rectangle left = new_item.get_left_size(); + rectangle middle = new_item.get_middle_size(); + rectangle right = new_item.get_right_size(); + + bool recalc_rect_positions = false; + const rectangle all = left+middle+right; + + + // make sure left_width contains the max of all the left rectangles + if (left.width() > left_width) + { + left_width = left.width(); + recalc_rect_positions = true; + } + // make sure middle_width contains the max of all the middle rectangles + if (middle.width() > middle_width) + { + middle_width = middle.width(); + recalc_rect_positions = true; + } + + // make the current rectangle wider if necessary + if (cur_rect.width() < left_width + middle_width + right.width() + 2*item_pad) + { + cur_rect = resize_rect_width(cur_rect, left_width + middle_width + right.width() + 2*item_pad); + recalc_rect_positions = true; + } + + const long y = cur_rect.bottom()+1 + item_pad; + const long x = cur_rect.left() + item_pad; + + // make the current rectangle taller to account for this new menu item + cur_rect.set_bottom(cur_rect.bottom()+all.height() + 2*item_pad); + + // adjust all the saved rectangles since the width of the window changed + // or left_width changed + if (recalc_rect_positions) + { + long y = cur_rect.top() + item_pad; + for (unsigned long i = 0; i < left_rects.size(); ++i) + { + middle_rects[i] = move_rect(middle_rects[i], x+left_width, y); + right_rects[i] = move_rect(right_rects[i], x+cur_rect.width()-right_rects[i].width()-item_pad, y); + line_rects[i] = resize_rect_width(line_rects[i], cur_rect.width()); + + y += line_rects[i].height(); + } + } + + // save the rectangles for later use. Also position them at the + // right spots + left = move_rect(left,x,y); + middle = move_rect(middle,x+left_width,y); + right = move_rect(right,x+cur_rect.width()-right.width()-item_pad,y); + rectangle line(move_rect(rectangle(cur_rect.width(),all.height()+2*item_pad), x-item_pad, y-item_pad)); + + // make sure the left, middle, and right rectangles are centered in the + // line. + if (left.height() < all.height()) + left = translate_rect(left,0, (all.height()-left.height())/2); + if (middle.height() < all.height()) + middle = translate_rect(middle,0, (all.height()-middle.height())/2); + if (right.height() < all.height()) + right = translate_rect(right,0, (all.height()-right.height())/2); + + left_rects.push_back(left); + middle_rects.push_back(middle); + right_rects.push_back(right); + line_rects.push_back(line); + + popup_menu* junk = 0; + submenus.push_back(junk); + + win_rect.set_right(cur_rect.right()+pad); + win_rect.set_bottom(cur_rect.bottom()+pad); + set_size(win_rect.width(),win_rect.height()); + + // make it so that nothing is selected + selected_item = submenus.size(); + + return items.size()-1; + } + + template < + typename menu_item_type + > + unsigned long add_submenu ( + const menu_item_type& new_item, + popup_menu& submenu + ) + { + auto_mutex M(wm); + + submenus[add_menu_item(new_item)] = &submenu; + + submenu.set_on_hide_handler(*this,&popup_menu::on_submenu_hide); + + return items.size()-1; + } + + void enable_menu_item ( + unsigned long idx + ); + + void disable_menu_item ( + unsigned long idx + ); + + size_t size ( + ) const; + + void clear ( + ); + + void show ( + ); + + void hide ( + ); + + template + void set_on_hide_handler ( + T& object, + void (T::*event_handler)() + ) + { + auto_mutex M(wm); + + member_function_pointer<> temp; + temp.set(object,event_handler); + + // if this handler isn't already registered then add it + bool found_handler = false; + for (unsigned long i = 0; i < hide_handlers.size(); ++i) + { + if (hide_handlers[i] == temp) + { + found_handler = true; + break; + } + } + + if (found_handler == false) + { + hide_handlers.push_back(temp); + } + } + + void select_first_item ( + ); + + bool forwarded_on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + private: + + void on_submenu_hide ( + ); + + void on_window_resized( + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long, + long x, + long y + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void close_submenu ( + ); + + bool display_selected_submenu ( + ); + /*! + ensures + - if (submenus[selected_item] isn't null) then + - displays the selected submenu + - returns true + - else + - returns false + !*/ + + void on_mouse_leave ( + ); + + void paint ( + const canvas& c + ); + + const long pad; + const long item_pad; + rectangle cur_rect; + rectangle win_rect; + unsigned long left_width; + unsigned long middle_width; + array > items; + array item_enabled; + array left_rects; + array middle_rects; + array right_rects; + array line_rects; + array submenus; + unsigned long selected_item; + bool submenu_open; + array > hide_handlers; + + // restricted functions + popup_menu(popup_menu&); // copy constructor + popup_menu& operator=(popup_menu&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class zoomable_region : public drawable + { + /*! + INITIAL VALUE + - min_scale == 0.15 + - max_scale == 1.0 + - zoom_increment_ == 0.02 + - scale == 1.0 + - mouse_drag_screen == false + + + CONVENTION + - zoom_increment() == zoom_increment_ + - min_zoom_scale() == min_scale + - max_zoom_scale() == max_scale + - zoom_scale() == scale + - if (the user is currently dragging the graph around via the mouse) then + - mouse_drag_screen == true + - else + - mouse_drag_screen == false + + - max_graph_point() == lr_point + - display_rect() == display_rect_ + - gui_to_graph_space(point(display_rect.left(),display_rect.top())) == gr_orig + !*/ + + public: + + zoomable_region ( + drawable_window& w, + unsigned long events = 0 + ); + + virtual ~zoomable_region ( + )= 0; + + virtual void set_pos ( + long x, + long y + ); + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ) + { + auto_mutex M(m); + style.reset(new style_type(style_)); + hsb.set_style(style_.get_horizontal_scroll_bar_style()); + vsb.set_style(style_.get_vertical_scroll_bar_style()); + + // do this just so that everything gets redrawn right + set_size(rect.width(), rect.height()); + } + + void set_zoom_increment ( + double zi + ); + + double zoom_increment ( + ) const; + + void set_max_zoom_scale ( + double ms + ); + + void set_min_zoom_scale ( + double ms + ); + + double min_zoom_scale ( + ) const; + + double max_zoom_scale ( + ) const; + + virtual void set_size ( + unsigned long width, + unsigned long height + ); + + void show ( + ); + + void hide ( + ); + + void enable ( + ); + + void disable ( + ); + + void set_z_order ( + long order + ); + + protected: + + virtual void on_view_changed () {} + + point graph_to_gui_space ( + const vector& p + ) const; + + vector gui_to_graph_space ( + const point& p + ) const; + + point max_graph_point ( + ) const; + + rectangle display_rect ( + ) const; + + double zoom_scale ( + ) const; + + void set_zoom_scale ( + double new_scale + ); + + void center_display_at_graph_point ( + const vector& p + ); + + // ----------- event handlers --------------- + + void on_wheel_down ( + unsigned long state + ); + + void on_wheel_up ( + unsigned long state + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void draw ( + const canvas& c + ) const; + + private: + + void on_h_scroll ( + ); + + void on_v_scroll ( + ); + + void redraw_graph ( + ); + + void adjust_origin ( + const point& gui_p, + const vector& graph_p + ); + /*! + ensures + - adjusts gr_orig so that we are as close to the following as possible: + - graph_to_gui_space(graph_p) == gui_p + - gui_to_graph_space(gui_p) == graph_p + !*/ + + + vector gr_orig; // point in graph space such that it's gui space point is the upper left of display_rect_ + vector lr_point; // point in graph space such that it is at the lower right corner of the screen at max zoom + + mutable std::ostringstream sout; + + double scale; // 0 < scale <= 1 + double min_scale; + double max_scale; + double zoom_increment_; + rectangle display_rect_; + + bool mouse_drag_screen; // true if the user is dragging the white background area + vector drag_screen_point; // the starting point the mouse was at in graph space for the background area drag + + scroll_bar vsb; + scroll_bar hsb; + + std::unique_ptr style; + + // restricted functions + zoomable_region(zoomable_region&); // copy constructor + zoomable_region& operator=(zoomable_region&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + + class scrollable_region : public drawable + { + /*! + INITIAL VALUE + - hscroll_bar_inc == 1 + - vscroll_bar_inc == 1 + - h_wheel_scroll_bar_inc == 1 + - v_wheel_scroll_bar_inc == 1 + - mouse_drag_enabled_ == false + - user_is_dragging_mouse == false + + CONVENTION + - mouse_drag_enabled() == mouse_drag_enabled_ + - horizontal_scroll_increment() == hscroll_bar_inc + - vertical_scroll_increment() == vscroll_bar_inc + - horizontal_mouse_wheel_scroll_increment() == h_wheel_scroll_bar_inc + - vertical_mouse_wheel_scroll_increment() == v_wheel_scroll_bar_inc + - vertical_scroll_pos() == vsb.slider_pos() + - horizontal_scroll_pos() == hsb.slider_pos() + - total_rect() == total_rect_ + - display_rect() == display_rect_ + + - if (the user is currently dragging the total_rect around with a mouse drag) then + - user_is_dragging_mouse == true + - drag_origin == the point the mouse was at, with respect to total_rect, + when the dragging started + - else + - user_is_dragging_mouse == false + !*/ + + public: + + scrollable_region ( + drawable_window& w, + unsigned long events = 0 + ); + + virtual ~scrollable_region ( + ) = 0; + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ) + { + auto_mutex M(m); + style.reset(new style_type(style_)); + hsb.set_style(style_.get_horizontal_scroll_bar_style()); + vsb.set_style(style_.get_vertical_scroll_bar_style()); + + // do this just so that everything gets redrawn right + set_size(rect.width(), rect.height()); + } + + void show ( + ); + + void hide ( + ); + + void enable ( + ); + + void disable ( + ); + + void set_z_order ( + long order + ); + + virtual void set_size ( + unsigned long width, + unsigned long height + ); + + unsigned long horizontal_mouse_wheel_scroll_increment ( + ) const; + + unsigned long vertical_mouse_wheel_scroll_increment ( + ) const; + + void set_horizontal_mouse_wheel_scroll_increment ( + unsigned long inc + ); + + void set_vertical_mouse_wheel_scroll_increment ( + unsigned long inc + ); + + unsigned long horizontal_scroll_increment ( + ) const; + + unsigned long vertical_scroll_increment ( + ) const; + + void set_horizontal_scroll_increment ( + unsigned long inc + ); + + void set_vertical_scroll_increment ( + unsigned long inc + ); + + long horizontal_scroll_pos ( + ) const; + + long vertical_scroll_pos ( + ) const; + + void set_horizontal_scroll_pos ( + long pos + ); + + void set_vertical_scroll_pos ( + long pos + ); + + virtual void set_pos ( + long x, + long y + ); + + bool mouse_drag_enabled ( + ) const; + + void enable_mouse_drag ( + ); + + void disable_mouse_drag ( + ); + + protected: + + virtual void on_view_changed () {} + + const rectangle& display_rect ( + ) const; + + void set_total_rect_size ( + unsigned long width, + unsigned long height + ); + + const rectangle& total_rect ( + ) const; + + void scroll_to_rect ( + const rectangle& r_ + ); + + void on_wheel_down ( + unsigned long state + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ); + + void on_wheel_up ( + unsigned long state + ); + + void draw ( + const canvas& c + ) const; + + private: + + bool need_h_scroll ( + ) const; + + bool need_v_scroll ( + ) const; + + void on_h_scroll ( + ); + + void on_v_scroll ( + ); + + rectangle total_rect_; + rectangle display_rect_; + scroll_bar hsb; + scroll_bar vsb; + unsigned long hscroll_bar_inc; + unsigned long vscroll_bar_inc; + unsigned long h_wheel_scroll_bar_inc; + unsigned long v_wheel_scroll_bar_inc; + bool mouse_drag_enabled_; + bool user_is_dragging_mouse; + point drag_origin; + std::unique_ptr style; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class popup_menu_region +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class popup_menu_region : public drawable + { + /*! + CONVENTION + popup_menu_visible() == popup_menu_shown + !*/ + + public: + + popup_menu_region( + drawable_window& w + ); + + virtual ~popup_menu_region( + ); + + void set_size ( + unsigned long width, + unsigned long height + ); + + void set_rect ( + const rectangle& new_rect + ); + + popup_menu& menu ( + ); + + void hide ( + ); + + void disable ( + ); + + bool popup_menu_visible ( + ) const { auto_mutex M(m); return popup_menu_shown; } + + protected: + + void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + void on_focus_lost ( + ); + + void on_focus_gained ( + ); + + void on_window_moved( + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void on_menu_becomes_hidden ( + ); + + void draw ( + const canvas& + ) const; + + private: + + popup_menu menu_; + bool popup_menu_shown; + + // restricted functions + popup_menu_region(popup_menu_region&); // copy constructor + popup_menu_region& operator=(popup_menu_region&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "base_widgets.cpp" +#endif + +#endif // DLIB_BASE_WIDGETs_ + diff --git a/ml/dlib/dlib/gui_widgets/base_widgets_abstract.h b/ml/dlib/dlib/gui_widgets/base_widgets_abstract.h new file mode 100644 index 000000000..3dcee0d5a --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/base_widgets_abstract.h @@ -0,0 +1,2290 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BASE_WIDGETs_ABSTRACT_ +#ifdef DLIB_BASE_WIDGETs_ABSTRACT_ + +#include "fonts_abstract.h" +#include "drawable_abstract.h" + +#include "../gui_core.h" +#include + +namespace dlib +{ + + /*! + GENERAL REMARKS + This file contains objects that are useful for creating complex drawable + widgets. + + THREAD SAFETY + All objects and functions defined in this file are thread safe. You may + call them from any thread without serializing access to them. + + EVENT HANDLERS + If you derive from any of the drawable objects and redefine any of the on_*() + event handlers then you should ensure that your version calls the same event + handler in the base object so that the base class part of your object will also + be able to process the event. + + Also note that all event handlers, including the user registered callback + functions, are executed in the event handling thread. Additionally, + the drawable::m mutex will always be locked while these event handlers + are running. Also, don't rely on get_thread_id() always returning the + same ID from inside event handlers. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class draggable +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class draggable : public drawable + { + /*! + INITIAL VALUE + draggable_area() == an initial value for its type + + WHAT THIS OBJECT REPRESENTS + This object represents a drawable object that is draggable by the mouse. + You use it by inheriting from it and defining the draw() method and any + of the on_*() event handlers you need. + + This object is draggable by the user when is_enabled() == true and + not draggable otherwise. + !*/ + + public: + + draggable( + drawable_window& w, + unsigned long events = 0 + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + - This object will not receive any events or draw() requests until + enable_events() is called + - the events flags are passed on to the drawable object's + constructor. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~draggable( + ) = 0; + /*! + ensures + - all resources associated with *this have been released + !*/ + + rectangle draggable_area ( + ) const; + /*! + ensures + - returns the area that this draggable can be dragged around in. + !*/ + + void set_draggable_area ( + const rectangle& area + ); + /*! + ensures + - #draggable_area() == area + !*/ + + protected: + + bool is_being_dragged ( + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - if (this widget is currently being dragged by the user) then + - returns true + - else + - returns false + !*/ + + // does nothing by default + virtual void on_drag ( + ){} + /*! + requires + - enable_events() has been called + - is_enabled() == true + - is_hidden() == false + - mutex drawable::m is locked + - is called when the user drags this object + - get_rect() == the rectangle that defines the new position + of this object. + - is_being_dragged() == true + ensures + - does not change the state of mutex drawable::m. + !*/ + + // does nothing by default + virtual void on_drag_stop ( + ){} + /*! + requires + - enable_events() has been called + - mutex drawable::m is locked + - is called when the user stops dragging this object + - is_being_dragged() == false + ensures + - does not change the state of mutex drawable::m. + !*/ + + private: + + // restricted functions + draggable(draggable&); // copy constructor + draggable& operator=(draggable&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class mouse_over_event +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class mouse_over_event : public drawable + { + /*! + INITIAL VALUE + is_mouse_over() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a drawable object with the addition of two events + that will alert you when the mouse enters or leaves your drawable object. + + You use it by inheriting from it and defining the draw() method and any + of the on_*() event handlers you need. + !*/ + + public: + + mouse_over_event( + drawable_window& w, + unsigned long events = 0 + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + - #*this will not receive any events or draw() requests until + enable_events() is called + - the events flags are passed on to the drawable object's + constructor. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~mouse_over_event( + ) = 0; + /*! + ensures + - all resources associated with *this have been released + !*/ + + protected: + + bool is_mouse_over ( + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - if (the mouse is currently over this widget) then + - returns true + - else + - returns false + !*/ + + // does nothing by default + virtual void on_mouse_over ( + ){} + /*! + requires + - enable_events() has been called + - mutex drawable::m is locked + - is_enabled() == true + - is_hidden() == false + - is called whenever this object transitions from the state where + is_mouse_over() == false to is_mouse_over() == true + ensures + - does not change the state of mutex drawable::m. + !*/ + + // does nothing by default + virtual void on_mouse_not_over ( + ){} + /*! + requires + - enable_events() has been called + - mutex drawable::m is locked + - is called whenever this object transitions from the state where + is_mouse_over() == true to is_mouse_over() == false + ensures + - does not change the state of mutex drawable::m. + !*/ + + private: + + // restricted functions + mouse_over_event(mouse_over_event&); // copy constructor + mouse_over_event& operator=(mouse_over_event&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class button_action +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class button_action : public mouse_over_event + { + /*! + INITIAL VALUE + is_depressed() == false + + WHAT THIS OBJECT REPRESENTS + This object represents the clicking action of a push button. It provides + simple callbacks that can be used to make various kinds of button + widgets. + + You use it by inheriting from it and defining the draw() method and any + of the on_*() event handlers you need. + !*/ + + public: + + button_action( + drawable_window& w, + unsigned long events = 0 + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + - #*this will not receive any events or draw() requests until + enable_events() is called + - the events flags are passed on to the drawable object's + constructor. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~button_action( + ) = 0; + /*! + ensures + - all resources associated with *this have been released + !*/ + + protected: + + bool is_depressed ( + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - if (this button is currently in a depressed state) then + - the user has left clicked on this drawable and is still + holding the left mouse button down over it. + - returns true + - else + - returns false + !*/ + + // does nothing by default + virtual void on_button_down ( + ){} + /*! + requires + - enable_events() has been called + - mutex drawable::m is locked + - is_enabled() == true + - is_hidden() == false + - the area in parent_window() defined by get_rect() has been invalidated. + (This means you don't have to call invalidate_rectangle()) + - is called whenever this object transitions from the state where + is_depressed() == false to is_depressed() == true + ensures + - does not change the state of mutex drawable::m. + !*/ + + // does nothing by default + virtual void on_button_up ( + bool mouse_over + ){} + /*! + requires + - enable_events() has been called + - mutex drawable::m is locked + - the area in parent_window() defined by get_rect() has been invalidated. + (This means you don't have to call invalidate_rectangle()) + - is called whenever this object transitions from the state where + is_depressed() == true to is_depressed() == false + - if (the mouse was over this button when this event occurred) then + - mouse_over == true + - else + - mouse_over == false + ensures + - does not change the state of mutex drawable::m. + !*/ + + private: + + // restricted functions + button_action(button_action&); // copy constructor + button_action& operator=(button_action&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class button +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class button : public button_action + { + /*! + INITIAL VALUE + name() == "" + tooltip_text() == "" (i.e. there is no tooltip by default) + + WHAT THIS OBJECT REPRESENTS + This object represents a simple button. + + When this object is disabled it means it will not respond to user clicks. + !*/ + + public: + + button( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~button( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_size ( + unsigned long width_, + unsigned long height_ + ); + /*! + ensures + - if (width and height are big enough to contain the name of this button) then + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this button stays the + same but its width and height are modified + !*/ + + void set_name (const std::wstring& name); + void set_name (const dlib::ustring& name); + void set_name ( + const std::string& name + ); + /*! + ensures + - #name() == name + - this button has been resized such that it is big enough to contain + the new name. + throws + - std::bad_alloc + !*/ + + const std::wstring wname () const; + const dlib::string uname () const; + const std::string name ( + ) const; + /*! + ensures + - returns the name of this button + throws + - std::bad_alloc + !*/ + + void set_tooltip_text (const std::wstring& text); + void set_tooltip_text (const dlib::ustring& text); + void set_tooltip_text ( + const std::string& text + ); + /*! + ensures + - #tooltip_text() == text + - enables the tooltip for this button + !*/ + + const dlib::ustring tooltip_utext () const; + const std::wstring tooltip_wtext () const; + const std::string tooltip_text ( + ) const; + /*! + ensures + - returns the text that is displayed in the tooltip for this button + !*/ + + bool is_depressed ( + ) const; + /*! + ensures + - if (this button is currently in a depressed state) then + - the user has left clicked on this widget and is still + holding the left mouse button down over it. + - returns true + - else + - returns false + !*/ + + template < + typename style_type + > + void set_style ( + const style_type& style + ); + /*! + requires + - style_type == a type that inherits from button_style + ensures + - this button object will draw itself using the given + button style + !*/ + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the button is + clicked by the user. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_click_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the button is clicked by + the user. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler)(button& self) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - &self == this + - the event_handler function is called on object when the button is + clicked by the user. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_sourced_click_handler ( + const any_function& event_handler + ); + /*! + ensures + - &self == this + - the event_handler function is called when the button is clicked by + the user. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_button_down_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the user causes + the button to go into its depressed state. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_button_down_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the user causes the button + to go into its depressed state. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_button_up_handler ( + T& object, + void (T::*event_handler)(bool mouse_over) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the user causes + the button to go into its non-depressed state. + - if (the mouse is over this button when this event occurs) then + - mouse_over == true + - else + - mouse_over == false + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_button_up_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the user causes the + button to go into its non-depressed state. + - if (the mouse is over this button when this event occurs) then + - mouse_over == true + - else + - mouse_over == false + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_button_down_handler ( + T& object, + void (T::*event_handler)(button& self) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - &self == this + - the event_handler function is called on object when the user causes + the button to go into its depressed state. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_sourced_button_down_handler ( + const any_function& event_handler + ); + /*! + ensures + - &self == this + - the event_handler function is called when the user causes the button + to go into its depressed state. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_button_up_handler ( + T& object, + void (T::*event_handler)(bool mouse_over, button& self) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - &self == this + - the event_handler function is called on object when the user causes + the button to go into its non-depressed state. + - if (the mouse is over this button when this event occurs) then + - mouse_over == true + - else + - mouse_over == false + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_sourced_button_up_handler ( + const any_function& event_handler + ); + /*! + ensures + - &self == this + - the event_handler function is called when the user causes the + button to go into its non-depressed state. + - if (the mouse is over this button when this event occurs) then + - mouse_over == true + - else + - mouse_over == false + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + button(button&); // copy constructor + button& operator=(button&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class scroll_bar +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class scroll_bar : public drawable + { + /*! + INITIAL VALUE + orientation() == a value given to the constructor. + max_slider_pos() == 0 + slider_pos() == 0 + jump_size() == 10 + + WHAT THIS OBJECT REPRESENTS + This object represents a scroll bar. The slider_pos() of the scroll bar + ranges from 0 to max_slider_pos(). The 0 position of the scroll_bar is + in the top or left side of the scroll_bar depending on its orientation. + + When this object is disabled it means it will not respond to user clicks. + !*/ + + public: + enum bar_orientation + { + HORIZONTAL, + VERTICAL + }; + + scroll_bar( + drawable_window& w, + bar_orientation orientation + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #orientation() == orientation + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~scroll_bar( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + bar_orientation orientation ( + ) const; + /*! + ensures + - returns the orientation of this scroll_bar + !*/ + + template < + typename style_type + > + void set_style ( + const style_type& style + ); + /*! + requires + - style_type == a type that inherits from scroll_bar_style + ensures + - this scroll_bar object will draw itself using the given + scroll bar style + !*/ + + void set_length ( + unsigned long length, + ); + /*! + ensures + - if (orientation() == HORIZONTAL) then + - #width() == max(length,1) + - else + - #height() == max(length,1) + !*/ + + long max_slider_pos ( + ) const; + /*! + ensures + - returns the maximum value that slider_pos() can take. + !*/ + + void set_max_slider_pos ( + long mpos + ); + /*! + ensures + - if (mpos < 0) then + - #max_slider_pos() == 0 + - else + - #max_slider_pos() == mpos + - if (slider_pos() > #max_slider_pos()) then + - #slider_pos() == #max_slider_pos() + - else + - #slider_pos() == slider_pos() + !*/ + + void set_slider_pos ( + unsigned long pos + ); + /*! + ensures + - if (pos < 0) then + - #slider_pos() == 0 + - else if (pos > max_slider_pos()) then + - #slider_pos() == max_slider_pos() + - else + - #slider_pos() == pos + !*/ + + long slider_pos ( + ) const; + /*! + ensures + - returns the current position of the slider box within the scroll bar. + !*/ + + long jump_size ( + ) const; + /*! + ensures + - returns the number of positions that the slider bar will jump when the + user clicks on the empty gaps above or below the slider bar. + (note that the slider will jump less than the jump size if it hits the + end of the scroll bar) + !*/ + + void set_jump_size ( + long js + ); + /*! + ensures + - if (js < 1) then + - #jump_size() == 1 + - else + - #jump_size() == js + !*/ + + + template < + typename T + > + void set_scroll_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - The event_handler function is called whenever the user causes the slider box + to move. + - This event is NOT triggered by calling set_slider_pos() + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_scroll_handler ( + const any_function& event_handler + ); + /*! + ensures + - The event_handler function is called whenever the user causes the slider box + to move. + - This event is NOT triggered by calling set_slider_pos() + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + scroll_bar(scroll_bar&); // copy constructor + scroll_bar& operator=(scroll_bar&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class widget_group +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class widget_group : public drawable + { + /*! + INITIAL VALUE + size() == 0 + get_rect().is_empty() == true + left() == 0 + top() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a grouping of drawable widgets. It doesn't draw + anything itself, rather it lets you manipulate the position, enabled + status, and visibility of a set of widgets as a group. + !*/ + + public: + widget_group( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~widget_group( + ); + /*! + ensures + - all resources associated with *this have been released. + !*/ + + void empty ( + ); + /*! + ensures + - #size() == 0 + !*/ + + void fit_to_contents ( + ); + /*! + ensures + - does not change the position of this object. + (i.e. the upper left corner of get_rect() remains at the same position) + - if (size() == 0) then + - #get_rect().is_empty() == true + - else + - recursively calls fit_to_contents() on any widget_groups inside + this object. + - #get_rect() will be the smallest rectangle that contains all the + widgets in this group and the upper left corner of get_rect(). + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns the number of widgets currently in *this. + !*/ + + void add ( + drawable& widget, + unsigned long x, + unsigned long y + ); + /*! + ensures + - #is_member(widget) == true + - if (is_member(widget) == false) then + - #size() == size() + 1 + - else + - #size() == size() + - The following conditions apply to this function as well as to all of the + following functions so long as is_member(widget) == true: + enable(), disable(), hide(), show(), set_z_order(), and set_pos(). + - #widget.left() == left()+x + - #widget.width() == widget.width() + - #widget.top() == top()+y + - #widget.height() == widget.height() + - #widget.is_hidden() == is_hidden() + - #widget.is_enabled() == is_enabled() + - #widget.z_order() == z_order() + throws + - std::bad_alloc + !*/ + + bool is_member ( + const drawable& widget + ) const; + /*! + ensures + - returns true if widget is currently in this object, returns false otherwise. + !*/ + + void remove ( + const drawable& widget + ); + /*! + ensures + - #is_member(widget) == false + - if (is_member(widget) == true) then + - #size() == size() - 1 + - else + - #size() == size() + !*/ + + protected: + + // this object doesn't draw anything but also isn't abstract + void draw ( + const canvas& c + ) const {} + + private: + + // restricted functions + widget_group(widget_group&); // copy constructor + widget_group& operator=(widget_group&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class image_widget +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class image_widget : public draggable + { + /*! + INITIAL VALUE + draggable_area() == an initial value for its type. + This object isn't displaying anything. + + WHAT THIS OBJECT REPRESENTS + This object represents a draggable image. You give it an image to display + by calling set_image(). + + Also note that initially the draggable area is empty so it won't be + draggable unless you call set_draggable_area() to some non-empty region. + + The image is drawn such that: + - the pixel img[0][0] is the upper left corner of the image. + - the pixel img[img.nr()-1][0] is the lower left corner of the image. + - the pixel img[0][img.nc()-1] is the upper right corner of the image. + - the pixel img[img.nr()-1][img.nc()-1] is the lower right corner of the image. + + !*/ + + public: + + image_widget( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~image_widget( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + template < + typename image_type + > + void set_image ( + const image_type& img + ); + /*! + requires + - image_type == an implementation of array2d/array2d_kernel_abstract.h + - pixel_traits must be defined + ensures + - #width() == img.nc() + - #height() == img.nr() + - #*this widget is now displaying the given image img. + !*/ + + private: + + // restricted functions + image_widget(image_widget&); // copy constructor + image_widget& operator=(image_widget&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class tooltip +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class tooltip : public mouse_over_event + { + /*! + INITIAL VALUE + - text() == "" + - the tooltip is inactive until the text is changed to + a non-empty string. + + WHAT THIS OBJECT REPRESENTS + This object represents a region on a window where if the user + hovers the mouse over this region a tooltip with a message + appears. + !*/ + + public: + + tooltip( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~tooltip( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_size ( + unsigned long width_, + unsigned long height_ + ); + /*! + ensures + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this widget stays the + same but its width and height are modified + !*/ + + void set_text (const std::wstring& str); + void set_text (const dlib::ustring& str); + void set_text ( + const std::string& str + ); + /*! + ensures + - #text() == str + - activates the tooltip. i.e. after this function the tooltip + will display on the screen when the user hovers the mouse over it + !*/ + + const std::wstring wtext () const; + const dlib::ustring utext () const; + const std::string text ( + ) const; + /*! + ensures + - returns the text that is displayed inside this + tooltip + !*/ + + private: + + // restricted functions + tooltip(tooltip&); // copy constructor + tooltip& operator=(tooltip&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // popup menu stuff +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class menu_item + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an abstract class that defines the interface a + menu item in a popup_menu must implement. + + Note that a menu_item is drawn as 3 separate pieces: + --------------------------------- + | left | middle | right | + --------------------------------- + + Also note that derived classes must be copyable via + their copy constructors. + !*/ + + public: + + virtual ~menu_item() {} + + virtual void on_click ( + ) const {} + /*! + requires + - the mutex drawable::m is locked + - if (has_click_event()) then + - this function is called when the user clicks on this menu_item + !*/ + + virtual bool has_click_event ( + ) const { return false; } + /*! + ensures + - if (this menu_item wants to receive on_click events) then + - returns true + - else + - returns false + !*/ + + virtual unichar get_hot_key ( + ) const { return 0; } + /*! + ensures + - if (this menu item has a keyboard hot key) then + - returns the unicode value of the key + - else + - returns 0 + !*/ + + virtual rectangle get_left_size ( + ) const { return rectangle(); } // return empty rect by default + /*! + ensures + - returns the dimensions of the left part of the menu_item + !*/ + + virtual rectangle get_middle_size ( + ) const = 0; + /*! + ensures + - returns the dimensions of the middle part of the menu_item + !*/ + + virtual rectangle get_right_size ( + ) const { return rectangle(); } // return empty rect by default + /*! + ensures + - returns the dimensions of the right part of the menu_item + !*/ + + virtual void draw_background ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const bool is_selected + ) const {} + /*! + requires + - the mutex drawable::m is locked + requires + - c == the canvas to draw on + - rect == the rectangle in which we are to draw the background + - enabled == true if the menu_item is to be drawn enabled + - is_selected == true if the menu_item is to be drawn selected + ensures + - draws the background of the menu_item on the canvas c at the location + given by rect. + !*/ + + virtual void draw_left ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const bool is_selected + ) const {} + /*! + requires + - the mutex drawable::m is locked + requires + - c == the canvas to draw on + - rect == the rectangle in which we are to draw the background + - enabled == true if the menu_item is to be drawn enabled + - is_selected == true if the menu_item is to be drawn selected + ensures + - draws the left part of the menu_item on the canvas c at the location + given by rect. + !*/ + + virtual void draw_middle ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const bool is_selected + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + requires + - c == the canvas to draw on + - rect == the rectangle in which we are to draw the background + - enabled == true if the menu_item is to be drawn enabled + - is_selected == true if the menu_item is to be drawn selected + ensures + - draws the middle part of the menu_item on the canvas c at the location + given by rect. + !*/ + + virtual void draw_right ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const bool is_selected + ) const {} + /*! + requires + - the mutex drawable::m is locked + requires + - c == the canvas to draw on + - rect == the rectangle in which we are to draw the background + - enabled == true if the menu_item is to be drawn enabled + - is_selected == true if the menu_item is to be drawn selected + ensures + - draws the right part of the menu_item on the canvas c at the location + given by rect. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class menu_item_text : public menu_item + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple text menu item + !*/ + + public: + + template < + typename T + > + menu_item_text ( + const std::string& str, + T& object, + void (T::*on_click_handler)(), + unichar hotkey = 0 + ); + /*! + ensures + - The text of this menu item will be str + - the on_click_handler function is called on object when this menu_item + clicked by the user. + - #get_hot_key() == hotkey + !*/ + + menu_item_text ( + const std::string& str, + const any_function& on_click_handler, + unichar hotkey = 0 + ); + /*! + ensures + - The text of this menu item will be str + - the on_click_handler function is called when this menu_item + clicked by the user. + - #get_hot_key() == hotkey + !*/ + + // overloads for wide character strings + template < + typename T + > + menu_item_text ( + const std::wstring& str, + T& object, + void (T::*on_click_handler)(), + unichar hotkey = 0 + ); + + menu_item_text ( + const std::wstring& str, + const any_function& on_click_handler, + unichar hotkey = 0 + ); + + template < + typename T + > + menu_item_text ( + const dlib::ustring& str, + T& object, + void (T::*on_click_handler)(), + unichar hotkey = 0 + ); + + template < + typename T + > + menu_item_text ( + const dlib::ustring& str, + const any_function& on_click_handler, + unichar hotkey = 0 + ); + }; + +// ---------------------------------------------------------------------------------------- + + class menu_item_submenu : public menu_item + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple text item intended to be used with + submenus inside a popup_menu. + !*/ + + public: + + menu_item_submenu ( + const std::string& str, + unichar hotkey = 0 + ); + /*! + ensures + - The text of this menu item will be str + - #get_hot_key() == hotkey + !*/ + + //overloads for wide character strings + menu_item_submenu ( + const std::wstring& str, + unichar hotkey = 0 + ); + + menu_item_submenu ( + const dlib::ustring& str, + unichar hotkey = 0 + ); + }; + +// ---------------------------------------------------------------------------------------- + + class menu_item_separator : public menu_item + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a horizontal separator in a popup menu + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class popup_menu : public base_window + { + /*! + INITIAL VALUE + - size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a popup menu window capable of containing + menu_item objects. + !*/ + + public: + + popup_menu ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + - dlib::thread_error + - dlib::gui_error + !*/ + + void clear ( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + template < + typename menu_item_type + > + unsigned long add_menu_item ( + const menu_item_type& new_item + ); + /*! + requires + - menu_item_type == a type that inherits from menu_item + ensures + - adds new_item onto the bottom of this popup_menu. + - returns size() + (This is also the index by which this item can be + referenced by the enable_menu_item() and disable_menu_item() + functions.) + !*/ + + template < + typename menu_item_type + > + unsigned long add_submenu ( + const menu_item_type& new_item, + popup_menu& submenu + ); + /*! + requires + - menu_item_type == a type that inherits from menu_item + ensures + - adds new_item onto the bottom of this popup_menu. + - when the user puts the mouse above this menu_item the given + submenu popup_menu will be displayed. + - returns size() + (This is also the index by which this item can be + referenced by the enable_menu_item() and disable_menu_item() + functions.) + !*/ + + void enable_menu_item ( + unsigned long idx + ); + /*! + requires + - idx < size() + ensures + - the menu_item in this with the index idx has been enabled + !*/ + + void disable_menu_item ( + unsigned long idx + ); + /*! + requires + - idx < size() + ensures + - the menu_item in this with the index idx has been disabled + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns the number of menu_item objects in this popup_menu + !*/ + + template + void set_on_hide_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + ensures + - the event_handler function is called on object when this popup_menu + hides itself due to an action by the user. + - Note that you can register multiple handlers for this event. + !*/ + + void select_first_item ( + ); + /*! + ensures + - causes this popup menu to highlight the first + menu item that it contains which has a click event + and is enabled. + !*/ + + bool forwarded_on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + /*! + requires + - key, is_printable, and state are the variables from the + base_window::on_keydown() event + ensures + - forwards this keyboard event to this popup window so that it + may deal with keyboard events from other windows. + - if (this popup_menu uses the keyboard event) then + - returns true + - else + - returns false + !*/ + + private: + + // restricted functions + popup_menu(popup_menu&); // copy constructor + popup_menu& operator=(popup_menu&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class popup_menu_region +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class popup_menu_region : public drawable + { + /*! + INITIAL VALUE + - popup_menu_visible() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a region on a window where if the user + right clicks the mouse over this region a popup_menu pops up. + + Note that this widget doesn't actually draw anything, it just + provides a region the user can click on to get a popup menu. + !*/ + + public: + + popup_menu_region( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~popup_menu_region( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_size ( + unsigned long width_, + unsigned long height_ + ); + /*! + ensures + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this widget stays the + same but its width and height are modified + !*/ + + void set_rect ( + const rectangle& new_rect + ); + /*! + ensures + - #get_rect() == new_rect + !*/ + + bool popup_menu_visible ( + ) const; + /*! + ensures + - if (the popup menu is currently visible on the screen) then + - returns true + - else + - returns false + !*/ + + popup_menu& menu ( + ); + /*! + ensures + - returns a reference to the popup_menu for this object. It is + the menu that is displayed when the user right clicks on + this widget + !*/ + + private: + + // restricted functions + popup_menu_region(popup_menu_region&); // copy constructor + popup_menu_region& operator=(popup_menu_region&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class zoomable_region +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class zoomable_region : public drawable + { + /* + INITIAL VALUE + - min_zoom_scale() == 0.15 + - max_zoom_scale() == 1.0 + - zoom_increment() == 0.90 + - zoom_scale() == 1.0 + + WHAT THIS OBJECT REPRESENTS + This object represents a 2D Cartesian graph that you can zoom into and + out of. It is a graphical widget that draws a rectangle with + a horizontal and vertical scroll bar that allow the user to scroll + around on a Cartesian graph that is much larger than the actual + area occupied by this object on the screen. It also allows + the user to zoom in and out. + + To use this object you inherit from it and make use of its public and + protected member functions. It provides functions for converting between + pixel locations and the points in our 2D Cartesian graph so that when the + user is scrolling/zooming the widget you can still determine where + things are to be placed on the screen and what screen pixels correspond + to in the Cartesian graph. + + Note that the Cartesian graph in this object is bounded by the point + (0,0), corresponding to the upper left corner when we are zoomed all + the way out, and max_graph_point() which corresponds to the lower right + corner when zoomed all the way out. The value of max_graph_point() is + determined automatically from the size of this object's on screen + rectangle and the value of min_zoom_scale() which determines how far + out you can zoom. + */ + + public: + + zoomable_region ( + drawable_window& w, + unsigned long events = 0 + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + - This object will not receive any events or draw() requests until + enable_events() is called + - the events flags are passed on to the drawable object's + constructor. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~zoomable_region ( + ) = 0; + /*! + ensures + - all resources associated with *this have been released + !*/ + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ); + /*! + requires + - style_type == a type that inherits from scrollable_region_style + ensures + - this zoomable_region object will draw itself using the given + style + !*/ + + void set_zoom_increment ( + double zi + ); + /*! + requires + - 0 < zi < 1 + ensures + - #zoom_increment() == zi + !*/ + + double zoom_increment ( + ) const; + /*! + ensures + - When the user zooms in using the mouse wheel: + - #zoom_scale() == zoom_scale() / zoom_increment() + - When the user zooms out using the mouse wheel: + - #zoom_scale() == zoom_scale() * zoom_increment() + - So this function returns the number that determines how much the zoom + changes when the mouse wheel is moved. + !*/ + + void set_max_zoom_scale ( + double ms + ); + /*! + requires + - ms > 0 + ensures + - #max_zoom_scale() == ms + !*/ + + void set_min_zoom_scale ( + double ms + ); + /*! + requires + - ms > 0 + ensures + - #min_zoom_scale() == ms + !*/ + + double min_zoom_scale ( + ) const; + /*! + ensures + - returns the minimum allowed value of zoom_scale() + (i.e. this is the number that determines how far out the user is allowed to zoom) + !*/ + + double max_zoom_scale ( + ) const; + /*! + ensures + - returns the maximum allowed value of zoom_scale() + (i.e. this is the number that determines how far in the user is allowed to zoom) + !*/ + + virtual void set_size ( + unsigned long width, + unsigned long height + ); + /*! + ensures + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this button stays the + same but its width and height are modified + !*/ + + protected: + + rectangle display_rect ( + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - returns the rectangle on the screen that contains the Cartesian + graph in this widget. I.e. this is the area of this widget minus + the area taken up by the scroll bars and border decorations. + !*/ + + point graph_to_gui_space ( + const vector& graph_point + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - returns the location of the pixel on the screen that corresponds + to the given point in Cartesian graph space + !*/ + + vector gui_to_graph_space ( + const point& pixel_point + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - returns the point in Cartesian graph space that corresponds to the given + pixel location + !*/ + + vector max_graph_point ( + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - returns the pixel farthest from the graph point (0,0) that is still + in the graph. I.e. returns the point in graph space that corresponds + to the lower right corner of the display_rect() when we are zoomed + all the way out. + !*/ + + double zoom_scale ( + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - returns a double Z that represents the current zoom. + - Smaller values of Z represent the user zooming out. + - Bigger values of Z represent the user zooming in. + - The default unzoomed case is when Z == 1 + - objects should be drawn such that they are zoom_scale() + times their normal size + !*/ + + void set_zoom_scale ( + double new_scale + ); + /*! + requires + - mutex drawable::m is locked + ensures + - invalidates the display_rect() so that it will be redrawn + - if (min_zoom_scale() <= new_scale && new_scale <= max_zoom_scale()) then + - #zoom_scale() == new_scale + - else if (new_scale < min_zoom_scale()) then + - #zoom_scale() == min_zoom_scale() + - else if (new_scale > max_zoom_scale()) then + - #zoom_scale() == max_zoom_scale() + !*/ + + void center_display_at_graph_point ( + const vector& graph_point + ); + /*! + requires + - mutex drawable::m is locked + ensures + - causes the given graph point to be centered in the display + if possible + - invalidates the display_rect() so that it will be redrawn + !*/ + + virtual void on_view_changed ( + ) {} + /*! + requires + - events_are_enabled() == true + - mutex drawable::m is locked + ensures + - on_view_changed() is called whenever the user causes the view of the + zoomable_region to change. That is, this function is called when the + user scrolls or zooms around in the region. + !*/ + + // ---------------------------- event handlers ---------------------------- + // The following event handlers are used in this object. So if you + // use any of them in your derived object you should pass the events + // back to it so that they still operate unless you wish to hijack the + // event for your own reasons (e.g. to override the mouse drag this object + // performs) + + void on_wheel_down (unsigned long state); + void on_wheel_up (unsigned long state); + void on_mouse_move ( unsigned long state, long x, long y); + void on_mouse_up ( unsigned long btn, unsigned long state, long x, long y); + void on_mouse_down ( unsigned long btn, unsigned long state, long x, long y, bool is_double_click); + void draw ( const canvas& c) const; + + private: + + // restricted functions + zoomable_region(zoomable_region&); // copy constructor + zoomable_region& operator=(zoomable_region&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + + class scrollable_region : public drawable + { + /*! + INITIAL VALUE + - horizontal_scroll_pos() == 0 + - horizontal_scroll_increment() == 1 + - horizontal_mouse_wheel_scroll_increment() == 1 + - vertical_scroll_pos() == 0 + - vertical_scroll_increment() == 1 + - vertical_mouse_wheel_scroll_increment() == 1 + - total_rect().empty() == true + - mouse_drag_enabled() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a 2D region of arbitrary size that is displayed + within a possibly smaller scrollable gui widget. That is, it is a + graphical widget that draws a rectangle with a horizontal and vertical + scroll bar that allows the user to scroll around on a region that is much + larger than the actual area occupied by this object on the screen. + + To use this object you inherit from it and make use of its public and + protected member functions. It provides a function, total_rect(), that + tells you where the 2D region is on the screen. You draw your stuff + inside total_rect() as you would normally except that you only modify + pixels that are also inside display_rect(). When the user moves the + scroll bars the position of total_rect() is updated accordingly, causing + the widget's content to scroll across the screen. + !*/ + + public: + scrollable_region ( + drawable_window& w, + unsigned long events = 0 + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + - This object will not receive any events or draw() requests until + enable_events() is called + - the events flags are passed on to the drawable object's + constructor. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~scrollable_region ( + ) = 0; + /*! + ensures + - all resources associated with *this have been released + !*/ + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ); + /*! + requires + - style_type == a type that inherits from scrollable_region_style + ensures + - this scrollable_region object will draw itself using the given + style + !*/ + + virtual void set_size ( + unsigned long width, + unsigned long height + ); + /*! + ensures + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this widget stays the + same but its width and height are modified. + !*/ + + long horizontal_scroll_pos ( + ) const; + /*! + ensures + - returns the current position of the horizontal scroll bar. + 0 means it is at the far left while bigger values represent + scroll positions closer to the right. + !*/ + + long vertical_scroll_pos ( + ) const; + /*! + ensures + - returns the current position of the vertical scroll bar. + 0 means it is at the top and bigger values represent scroll positions + closer to the bottom. + !*/ + + void set_horizontal_scroll_pos ( + long pos + ); + /*! + ensures + - if (pos is a valid horizontal scroll position) then + - #horizontal_scroll_pos() == pos + - else + - #horizontal_scroll_pos() == the valid scroll position closest to pos + !*/ + + void set_vertical_scroll_pos ( + long pos + ); + /*! + ensures + - if (pos is a valid vertical scroll position) then + - #vertical_scroll_pos() == pos + - else + - #vertical_scroll_pos() == the valid scroll position closest to pos + !*/ + + unsigned long horizontal_mouse_wheel_scroll_increment ( + ) const; + /*! + ensures + - returns the number of positions the horizontal scroll bar + moves when the user scrolls the mouse wheel. + !*/ + + unsigned long vertical_mouse_wheel_scroll_increment ( + ) const; + /*! + ensures + - returns the number of positions the vertical scroll bar + moves when the user scrolls the mouse wheel. + !*/ + + void set_horizontal_mouse_wheel_scroll_increment ( + unsigned long inc + ); + /*! + ensures + - #horizontal_mouse_wheel_scroll_increment() == inc + !*/ + + void set_vertical_mouse_wheel_scroll_increment ( + unsigned long inc + ); + /*! + ensures + - #vertical_mouse_wheel_scroll_increment() == inc + !*/ + + + unsigned long horizontal_scroll_increment ( + ) const; + /*! + ensures + - returns the number of pixels that total_rect() is moved by when + the horizontal scroll bar moves by one position + !*/ + + unsigned long vertical_scroll_increment ( + ) const; + /*! + ensures + - returns the number of pixels that total_rect() is moved by when + the vertical scroll bar moves by one position + !*/ + + void set_horizontal_scroll_increment ( + unsigned long inc + ); + /*! + ensures + - #horizontal_scroll_increment() == inc + !*/ + + void set_vertical_scroll_increment ( + unsigned long inc + ); + /*! + ensures + - #vertical_scroll_increment() == inc + !*/ + + bool mouse_drag_enabled ( + ) const; + /*! + ensures + - if (the user can drag this contents of this widget around by + holding down the left mouse button and dragging) then + - returns true + - else + - returns false + !*/ + + void enable_mouse_drag ( + ); + /*! + ensures + - #mouse_drag_enabled() == true + !*/ + + void disable_mouse_drag ( + ); + /*! + ensures + - #mouse_drag_enabled() == false + !*/ + + protected: + + rectangle display_rect ( + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - returns the rectangle on the screen that contains the scrollable + area in this widget. I.e. this is the area of this widget minus + the area taken up by the scroll bars and border decorations. + !*/ + + void set_total_rect_size ( + unsigned long width, + unsigned long height + ); + /*! + requires + - mutex drawable::m is locked + - (width > 0 && height > 0) || (width == 0 && height == 0) + ensures + - #total_rect().width() == width + - #total_rect().height() == height + - The scroll bars as well as the position of #total_rect() + is updated so that the total rect is still in the correct + position with respect to the scroll bars. + !*/ + + const rectangle& total_rect ( + ) const; + /*! + requires + - mutex drawable::m is locked + ensures + - returns a rectangle that represents the entire scrollable + region inside this widget, even the parts that are outside + display_rect(). + !*/ + + void scroll_to_rect ( + const rectangle& r + ); + /*! + requires + - mutex drawable::m is locked + ensures + - Adjusts the scroll bars of this object so that the part of + the total_rect() rectangle that overlaps with r is displayed in + the display_rect() rectangle on the screen. + !*/ + + virtual void on_view_changed ( + ) {} + /*! + requires + - events_are_enabled() == true + - mutex drawable::m is locked + ensures + - on_view_changed() is called whenever the user causes the view of the + scrollable_region to change. That is, this function is called when the + user scrolls around in the region. + !*/ + + // ---------------------------- event handlers ---------------------------- + // The following event handlers are used in this object. So if you + // use any of them in your derived object you should pass the events + // back to it so that they still operate unless you wish to hijack the + // event for your own reasons (e.g. to override the mouse wheel action + // this object performs) + + void on_wheel_down (unsigned long state); + void on_wheel_up (unsigned long state); + void on_mouse_move (unsigned long state, long x, long y); + void on_mouse_down (unsigned long btn, unsigned long state, long x, long y, bool is_double_click); + void on_mouse_up (unsigned long btn, unsigned long state, long x, long y); + void draw (const canvas& c) const; + + private: + + // restricted functions + scrollable_region(scrollable_region&); // copy constructor + scrollable_region& operator=(scrollable_region&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BASE_WIDGETs_ABSTRACT_ + + diff --git a/ml/dlib/dlib/gui_widgets/canvas_drawing.cpp b/ml/dlib/dlib/gui_widgets/canvas_drawing.cpp new file mode 100644 index 000000000..0fecd1cd3 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/canvas_drawing.cpp @@ -0,0 +1,101 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CANVAS_DRAWINg_CPP_ +#define DLIB_CANVAS_DRAWINg_CPP_ + +#include "canvas_drawing.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + void draw_sunken_rectangle ( + const canvas& c, + const rectangle& border, + unsigned char alpha + ) + { + rectangle area = border.intersect(c); + if (area.is_empty() == false) + { + const rgb_alpha_pixel dark_gray(64,64,64,alpha); + const rgb_alpha_pixel gray(128,128,128,alpha); + const rgb_alpha_pixel white(255,255,255,alpha); + const rgb_alpha_pixel background(212,208,200,alpha); + + draw_line(c,point(border.left(),border.top()),point(border.right()-1,border.top()),gray); + + draw_line(c,point(border.left(),border.bottom()),point(border.right(),border.bottom()),white); + draw_line(c,point(border.left()+1,border.bottom()-1),point(border.right()-1,border.bottom()-1),background); + + draw_line(c,point(border.left(),border.top()+1),point(border.left(),border.bottom()-1),gray); + + draw_line(c,point(border.right(),border.top()),point(border.right(),border.bottom()-1),white); + draw_line(c,point(border.right()-1,border.top()+1),point(border.right()-1,border.bottom()-2),background); + + draw_line(c,point(border.left()+1,border.top()+1),point(border.left()+1,border.bottom()-2),dark_gray); + draw_line(c,point(border.left()+1,border.top()+1),point(border.right()-2,border.top()+1),dark_gray); + } + } + +// ---------------------------------------------------------------------------------------- + + void draw_button_down ( + const canvas& c, + const rectangle& btn, + unsigned char alpha + ) + { + rectangle area = btn.intersect(c); + if (area.is_empty() == false) + { + const rgb_alpha_pixel dark_gray(64,64,64,alpha); + const rgb_alpha_pixel gray(128,128,128,alpha); + const rgb_alpha_pixel black(0,0,0,alpha); + + draw_line(c,point(btn.left(),btn.top()),point(btn.right(),btn.top()),black); + + draw_line(c,point(btn.left()+1,btn.bottom()),point(btn.right(),btn.bottom()),dark_gray); + draw_line(c,point(btn.left()+1,btn.top()+1),point(btn.right()-1,btn.top()+1),gray); + + draw_line(c,point(btn.left(),btn.top()+1),point(btn.left(),btn.bottom()),black); + + draw_line(c,point(btn.right(),btn.top()+1),point(btn.right(),btn.bottom()-1),dark_gray); + draw_line(c,point(btn.left()+1,btn.top()+1),point(btn.left()+1,btn.bottom()-1),gray); + } + } + +// ---------------------------------------------------------------------------------------- + + void draw_button_up ( + const canvas& c, + const rectangle& btn, + unsigned char alpha + ) + { + rectangle area = btn.intersect(c); + if (area.is_empty() == false) + { + const rgb_alpha_pixel dark_gray(64,64,64,alpha); + const rgb_alpha_pixel gray(128,128,128,alpha); + const rgb_alpha_pixel white(255,255,255,alpha); + + draw_line(c,point(btn.left(),btn.top()),point(btn.right()-1,btn.top()),white); + + draw_line(c,point(btn.left(),btn.bottom()),point(btn.right(),btn.bottom()),dark_gray); + draw_line(c,point(btn.left()+1,btn.bottom()-1),point(btn.right()-1,btn.bottom()-1),gray); + + draw_line(c,point(btn.left(),btn.top()+1),point(btn.left(),btn.bottom()-1),white); + + draw_line(c,point(btn.right(),btn.top()),point(btn.right(),btn.bottom()-1),dark_gray); + draw_line(c,point(btn.right()-1,btn.top()+1),point(btn.right()-1,btn.bottom()-2),gray); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CANVAS_DRAWINg_CPP_ + diff --git a/ml/dlib/dlib/gui_widgets/canvas_drawing.h b/ml/dlib/dlib/gui_widgets/canvas_drawing.h new file mode 100644 index 000000000..61f688112 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/canvas_drawing.h @@ -0,0 +1,964 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. + +#ifndef DLIB_GUI_CANVAS_DRAWINg_ +#define DLIB_GUI_CANVAS_DRAWINg_ + +#include "canvas_drawing_abstract.h" +#include "../gui_core.h" +#include "../algs.h" +#include "../array2d.h" +#include "../pixel.h" +#include "../image_transforms/assign_image.h" +#include "../geometry.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + void draw_line ( + const canvas& c, + const point& p1, + const point& p2, + const pixel_type& pixel, + const rectangle& area = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) + { + rectangle valid_area(c.intersect(area)); + long x1 = p1.x(); + long y1 = p1.y(); + long x2 = p2.x(); + long y2 = p2.y(); + if (x1 == x2) + { + // if the x coordinate is inside the canvas's area + if (x1 <= valid_area.right() && x1 >= valid_area.left()) + { + // make sure y1 comes before y2 + if (y1 > y2) + swap(y1,y2); + + y1 = std::max(y1,valid_area.top()); + y2 = std::min(y2,valid_area.bottom()); + // this is a vertical line + for (long y = y1; y <= y2; ++y) + { + assign_pixel(c[y-c.top()][x1-c.left()], pixel); + } + } + } + else if (y1 == y2) + { + // if the y coordinate is inside the canvas's area + if (y1 <= valid_area.bottom() && y1 >= valid_area.top()) + { + // make sure x1 comes before x2 + if (x1 > x2) + swap(x1,x2); + + x1 = std::max(x1,valid_area.left()); + x2 = std::min(x2,valid_area.right()); + // this is a horizontal line + for (long x = x1; x <= x2; ++x) + { + assign_pixel(c[y1-c.top()][x-c.left()], pixel); + } + } + } + else + { + rgb_alpha_pixel alpha_pixel; + assign_pixel(alpha_pixel, pixel); + const unsigned char max_alpha = alpha_pixel.alpha; + + const long rise = (((long)y2) - ((long)y1)); + const long run = (((long)x2) - ((long)x1)); + if (std::abs(rise) < std::abs(run)) + { + const double slope = ((double)rise)/run; + + double first, last; + + if (x1 > x2) + { + first = std::max(x2,valid_area.left()); + last = std::min(x1,valid_area.right()); + } + else + { + first = std::max(x1,valid_area.left()); + last = std::min(x2,valid_area.right()); + } + + + long y; + long x; + const double x1f = x1; + const double y1f = y1; + for (double i = first; i <= last; ++i) + { + const double dy = slope*(i-x1f) + y1f; + const double dx = i; + + y = static_cast(dy); + x = static_cast(dx); + + + if (y >= valid_area.top() && y <= valid_area.bottom()) + { + alpha_pixel.alpha = static_cast((1.0-(dy-y))*max_alpha); + assign_pixel(c[y-c.top()][x-c.left()], alpha_pixel); + } + if (y+1 >= valid_area.top() && y+1 <= valid_area.bottom()) + { + alpha_pixel.alpha = static_cast((dy-y)*max_alpha); + assign_pixel(c[y+1-c.top()][x-c.left()], alpha_pixel); + } + } + } + else + { + const double slope = ((double)run)/rise; + + double first, last; + + if (y1 > y2) + { + first = std::max(y2,valid_area.top()); + last = std::min(y1,valid_area.bottom()); + } + else + { + first = std::max(y1,valid_area.top()); + last = std::min(y2,valid_area.bottom()); + } + + long x; + long y; + const double x1f = x1; + const double y1f = y1; + for (double i = first; i <= last; ++i) + { + const double dx = slope*(i-y1f) + x1f; + const double dy = i; + + y = static_cast(dy); + x = static_cast(dx); + + if (x >= valid_area.left() && x <= valid_area.right()) + { + alpha_pixel.alpha = static_cast((1.0-(dx-x))*max_alpha); + assign_pixel(c[y-c.top()][x-c.left()], alpha_pixel); + } + if (x+1 >= valid_area.left() && x+1 <= valid_area.right()) + { + alpha_pixel.alpha = static_cast((dx-x)*max_alpha); + assign_pixel(c[y-c.top()][x+1-c.left()], alpha_pixel); + } + } + } + } + + } + inline void draw_line ( + const canvas& c, + const point& p1, + const point& p2 + ){ draw_line(c,p1,p2,0); } + +// ---------------------------------------------------------------------------------------- + + void draw_sunken_rectangle ( + const canvas& c, + const rectangle& border, + unsigned char alpha = 255 + ); + +// ---------------------------------------------------------------------------------------- + + template + inline void draw_pixel ( + const canvas& c, + const point& p, + const pixel_type& pixel + ) + { + if (c.contains(p)) + { + assign_pixel(c[p.y()-c.top()][p.x()-c.left()],pixel); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void draw_checkered ( + const canvas& c, + const rectangle& a, + const pixel_type& pixel1, + const pixel_type& pixel2 + ) + { + rectangle area = a.intersect(c); + if (area.is_empty()) + return; + + for (long i = area.left(); i <= area.right(); ++i) + { + for (long j = area.top(); j <= area.bottom(); ++j) + { + canvas::pixel& p = c[j - c.top()][i - c.left()]; + if ((j&0x1) ^ (i&0x1)) + { + assign_pixel(p,pixel1); + } + else + { + assign_pixel(p,pixel2); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void draw_button_down ( + const canvas& c, + const rectangle& btn, + unsigned char alpha = 255 + ); + +// ---------------------------------------------------------------------------------------- + + void draw_button_up ( + const canvas& c, + const rectangle& btn, + unsigned char alpha = 255 + ); + +// ---------------------------------------------------------------------------------------- + + template + void draw_circle ( + const canvas& c, + const point& center_point, + double radius, + const pixel_type& pixel, + const rectangle& area = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) + { + using std::sqrt; + rectangle valid_area(c.intersect(area)); + const long x = center_point.x(); + const long y = center_point.y(); + if (radius > 1) + { + long first_x = static_cast(x - radius + 0.5); + long last_x = static_cast(x + radius + 0.5); + const double rs = radius*radius; + + // ensure that we only loop over the part of the x dimension that this + // canvas contains. + if (first_x < valid_area.left()) + first_x = valid_area.left(); + if (last_x > valid_area.right()) + last_x = valid_area.right(); + + long top, bottom; + + top = static_cast(sqrt(std::max(rs - (first_x-x-0.5)*(first_x-x-0.5),0.0))+0.5); + top += y; + long last = top; + + // draw the left half of the circle + long middle = std::min(x-1,last_x); + for (long i = first_x; i <= middle; ++i) + { + double a = i - x + 0.5; + // find the top of the arc + top = static_cast(sqrt(std::max(rs - a*a,0.0))+0.5); + top += y; + long temp = top; + + while(top >= last) + { + bottom = y - top + y; + if (top >= valid_area.top() && top <= valid_area.bottom() ) + { + assign_pixel(c[top-c.top()][i-c.left()],pixel); + } + + if (bottom >= valid_area.top() && bottom <= valid_area.bottom() ) + { + assign_pixel(c[bottom-c.top()][i-c.left()],pixel); + } + --top; + } + + last = temp; + } + + middle = std::max(x,first_x); + top = static_cast(sqrt(std::max(rs - (last_x-x+0.5)*(last_x-x+0.5),0.0))+0.5); + top += y; + last = top; + // draw the right half of the circle + for (long i = last_x; i >= middle; --i) + { + double a = i - x - 0.5; + // find the top of the arc + top = static_cast(sqrt(std::max(rs - a*a,0.0))+0.5); + top += y; + long temp = top; + + while(top >= last) + { + bottom = y - top + y; + if (top >= valid_area.top() && top <= valid_area.bottom() ) + { + assign_pixel(c[top-c.top()][i-c.left()],pixel); + } + + if (bottom >= valid_area.top() && bottom <= valid_area.bottom() ) + { + assign_pixel(c[bottom-c.top()][i-c.left()],pixel); + } + --top; + } + + last = temp; + } + } + else if (radius == 1 && + x >= valid_area.left() && x <= valid_area.right() && + y >= valid_area.top() && y <= valid_area.bottom() ) + { + assign_pixel(c[y-c.top()][x-c.left()], pixel); + } + } + inline void draw_circle ( + const canvas& c, + const point& center_point, + double radius + ){ draw_circle(c, center_point, radius, 0); } + +// ---------------------------------------------------------------------------------------- + + template + void draw_solid_circle ( + const canvas& c, + const point& center_point, + double radius, + const pixel_type& pixel, + const rectangle& area = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) + { + using std::sqrt; + rectangle valid_area(c.intersect(area)); + const long x = center_point.x(); + const long y = center_point.y(); + if (radius > 1) + { + long first_x = static_cast(x - radius + 0.5); + long last_x = static_cast(x + radius + 0.5); + const double rs = radius*radius; + + // ensure that we only loop over the part of the x dimension that this + // canvas contains. + if (first_x < valid_area.left()) + first_x = valid_area.left(); + if (last_x > valid_area.right()) + last_x = valid_area.right(); + + long top, bottom; + + top = static_cast(sqrt(std::max(rs - (first_x-x-0.5)*(first_x-x-0.5),0.0))+0.5); + top += y; + long last = top; + + // draw the left half of the circle + long middle = std::min(x-1,last_x); + for (long i = first_x; i <= middle; ++i) + { + double a = i - x + 0.5; + // find the top of the arc + top = static_cast(sqrt(std::max(rs - a*a,0.0))+0.5); + top += y; + long temp = top; + + while(top >= last) + { + bottom = y - top + y; + draw_line(c, point(i,top),point(i,bottom),pixel,area); + --top; + } + + last = temp; + } + + middle = std::max(x,first_x); + top = static_cast(sqrt(std::max(rs - (last_x-x+0.5)*(last_x-x+0.5),0.0))+0.5); + top += y; + last = top; + // draw the right half of the circle + for (long i = last_x; i >= middle; --i) + { + double a = i - x - 0.5; + // find the top of the arc + top = static_cast(sqrt(std::max(rs - a*a,0.0))+0.5); + top += y; + long temp = top; + + while(top >= last) + { + bottom = y - top + y; + draw_line(c, point(i,top),point(i,bottom),pixel,area); + --top; + } + + last = temp; + } + } + else if (radius == 1 && + x >= valid_area.left() && x <= valid_area.right() && + y >= valid_area.top() && y <= valid_area.bottom() ) + { + assign_pixel(c[y-c.top()][x-c.left()], pixel); + } + } + inline void draw_solid_circle ( + const canvas& c, + const point& center_point, + double radius + ) { draw_solid_circle(c, center_point, radius, 0); } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + template + void get_convex_polygon_shape ( + const std::vector& points, + const long top, + const long bottom, + std::vector& left_boundary, + std::vector& right_boundary + ) + /*! + requires + - 0 <= top <= bottom + ensures + - interprets points as the coordinates defining a convex polygon. In + particular, we interpret points as a list of the vertices of the polygon + and assume they are ordered in clockwise order. + - #left_boundary.size() == bottom-top+1 + - #right_boundary.size() == bottom-top+1 + - for all top <= y <= bottom: + - #left_boundary[y-top] == the x coordinate for the left most side of + the polygon at coordinate y. + - #right_boundary[y-top] == the x coordinate for the right most side of + the polygon at coordinate y. + !*/ + { + using std::min; + using std::max; + + left_boundary.assign(bottom-top+1, std::numeric_limits::infinity()); + right_boundary.assign(bottom-top+1, -std::numeric_limits::infinity()); + + // trace out the points along the edge of the polynomial and record them + for (unsigned long i = 0; i < points.size(); ++i) + { + const point p1 = points[i]; + const point p2 = points[(i+1)%points.size()]; + + if (p1.y() == p2.y()) + { + if (top <= p1.y() && p1.y() <= bottom) + { + const long y = p1.y() - top; + const double xmin = min(p1.x(), p2.x()); + const double xmax = min(p1.x(), p2.x()); + left_boundary[y] = min(left_boundary[y], xmin); + right_boundary[y] = max(right_boundary[y], xmax); + } + } + else + { + // Here we trace out the line from p1 to p2 and record where it hits. + + // x = m*y + b + const double m = (p2.x() - p1.x())/(double)(p2.y()-p1.y()); + const double b = p1.x() - m*p1.y(); // because: x1 = m*y1 + b + + const long ymin = max(top,min(p1.y(), p2.y())); + const long ymax = min(bottom,max(p1.y(), p2.y())); + for (long y = ymin; y <= ymax; ++y) + { + const double x = m*y + b; + const unsigned long idx = y-top; + left_boundary[idx] = min(left_boundary[idx], x); + right_boundary[idx] = max(right_boundary[idx], x); + } + } + } + } + + // ------------------------------------------------------------------------------------ + + } + + template + void draw_solid_convex_polygon ( + const canvas& c, + const std::vector& polygon, + const pixel_type& pixel, + const rectangle& area = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) + { + using std::max; + using std::min; + const rectangle valid_area(c.intersect(area)); + + rectangle bounding_box; + for (unsigned long i = 0; i < polygon.size(); ++i) + bounding_box += polygon[i]; + + // Don't do anything if the polygon is totally outside the area we can draw in + // right now. + if (bounding_box.intersect(valid_area).is_empty()) + return; + + rgb_alpha_pixel alpha_pixel; + assign_pixel(alpha_pixel, pixel); + const unsigned char max_alpha = alpha_pixel.alpha; + + // we will only want to loop over the part of left_boundary that is part of the + // valid_area. + long top = max(valid_area.top(),bounding_box.top()); + long bottom = min(valid_area.bottom(),bounding_box.bottom()); + + // Since we look at the adjacent rows of boundary information when doing the alpha + // blending, we want to make sure we always have some boundary information unless + // we are at the absolute edge of the polygon. + const long top_offset = (top == bounding_box.top()) ? 0 : 1; + const long bottom_offset = (bottom == bounding_box.bottom()) ? 0 : 1; + if (top != bounding_box.top()) + top -= 1; + if (bottom != bounding_box.bottom()) + bottom += 1; + + std::vector left_boundary; + std::vector right_boundary; + impl::get_convex_polygon_shape(polygon, top, bottom, left_boundary, right_boundary); + + + // draw the polygon row by row + for (unsigned long i = top_offset; i < left_boundary.size(); ++i) + { + long left_x = static_cast(std::ceil(left_boundary[i])); + long right_x = static_cast(std::floor(right_boundary[i])); + + left_x = max(left_x, valid_area.left()); + right_x = min(right_x, valid_area.right()); + + if (i < left_boundary.size()-bottom_offset) + { + // draw the main body of the polygon + for (long x = left_x; x <= right_x; ++x) + { + const long y = i+top; + assign_pixel(c[y-c.top()][x-c.left()], pixel); + } + } + + if (i == 0) + continue; + + // Now draw anti-aliased edges so they don't look all pixely. + + // Alpha blend the edges on the left side. + double delta = left_boundary[i-1] - left_boundary[i]; + if (std::abs(delta) <= 1) + { + if (std::floor(left_boundary[i]) != left_x) + { + const point p(static_cast(std::floor(left_boundary[i])), i+top); + rgb_alpha_pixel temp = alpha_pixel; + temp.alpha = max_alpha-static_cast((left_boundary[i]-p.x())*max_alpha); + if (valid_area.contains(p)) + assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp); + } + } + else if (delta < 0) // on the bottom side + { + for (long x = static_cast(std::ceil(left_boundary[i-1])); x < left_x; ++x) + { + const point p(x, i+top); + rgb_alpha_pixel temp = alpha_pixel; + temp.alpha = static_cast((x-left_boundary[i-1])/std::abs(delta)*max_alpha); + if (valid_area.contains(p)) + assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp); + } + } + else // on the top side + { + const long old_left_x = static_cast(std::ceil(left_boundary[i-1])); + for (long x = left_x; x < old_left_x; ++x) + { + const point p(x, i+top-1); + rgb_alpha_pixel temp = alpha_pixel; + temp.alpha = static_cast((x-left_boundary[i])/delta*max_alpha); + if (valid_area.contains(p)) + assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp); + } + } + + + // Alpha blend the edges on the right side + delta = right_boundary[i-1] - right_boundary[i]; + if (std::abs(delta) <= 1) + { + if (std::ceil(right_boundary[i]) != right_x) + { + const point p(static_cast(std::ceil(right_boundary[i])), i+top); + rgb_alpha_pixel temp = alpha_pixel; + temp.alpha = max_alpha-static_cast((p.x()-right_boundary[i])*max_alpha); + if (valid_area.contains(p)) + assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp); + } + } + else if (delta < 0) // on the top side + { + for (long x = static_cast(std::floor(right_boundary[i-1]))+1; x <= right_x; ++x) + { + const point p(x, i+top-1); + rgb_alpha_pixel temp = alpha_pixel; + temp.alpha = static_cast((right_boundary[i]-x)/std::abs(delta)*max_alpha); + if (valid_area.contains(p)) + assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp); + } + } + else // on the bottom side + { + const long old_right_x = static_cast(std::floor(right_boundary[i-1])); + for (long x = right_x+1; x <= old_right_x; ++x) + { + const point p(x, i+top); + rgb_alpha_pixel temp = alpha_pixel; + temp.alpha = static_cast((right_boundary[i-1]-x)/delta*max_alpha); + if (valid_area.contains(p)) + assign_pixel(c[p.y()-c.top()][p.x()-c.left()],temp); + } + } + } + } + inline void draw_solid_convex_polygon ( + const canvas& c, + const std::vector& polygon + ) { draw_solid_convex_polygon(c, polygon, 0); } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void draw_image ( + const canvas& c, + const point& p, + const image_type& img, + const rectangle& area_ = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) + { + const long x = p.x(); + const long y = p.y(); + rectangle rect(x,y,num_columns(img)+x-1,num_rows(img)+y-1); + rectangle area = c.intersect(rect).intersect(area_); + if (area.is_empty()) + return; + + for (long row = area.top(); row <= area.bottom(); ++row) + { + for (long col = area.left(); col <= area.right(); ++col) + { + assign_pixel(c[row-c.top()][col-c.left()], img[row-rect.top()][col-rect.left()]); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void draw_image ( + const canvas& c, + const rectangle& rect, + const image_type& img, + const rectangle& area_ = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) + { + const rectangle area = c.intersect(rect).intersect(area_); + if (area.is_empty() || num_columns(img) * num_rows(img) == 0) + return; + + const matrix x = matrix_cast(round(linspace(0, num_columns(img)-1, rect.width()))); + const matrix y = matrix_cast(round(linspace(0, num_rows(img)-1, rect.height()))); + + for (long row = area.top(); row <= area.bottom(); ++row) + { + const long r = y(row-rect.top()); + long cc = area.left() - rect.left(); + for (long col = area.left(); col <= area.right(); ++col) + { + assign_pixel(c[row-c.top()][col-c.left()], img[r][x(cc++)]); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void draw_rounded_rectangle ( + const canvas& c, + const rectangle& rect, + unsigned radius, + const pixel_type& color, + const rectangle& area_ = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) + { + if ( rect.intersect ( c ).is_empty() ) + return; + + draw_line ( c, point(rect.left() + radius + 1, rect.bottom()), + point(rect.right() - radius - 1, rect.bottom()), color,area_ ); + + draw_line ( c, point(rect.left() + radius + 1, rect.top()), + point(rect.right() - radius - 1, rect.top()), color,area_ ); + + draw_line ( c, point(rect.left(), rect.top() + radius + 1), + point(rect.left(), rect.bottom() - radius - 1), color,area_ ); + + draw_line ( c, point(rect.right(), rect.top() + radius + 1), + point(rect.right(), rect.bottom() - radius - 1), color,area_ ); + + unsigned x = radius, y = 0, old_x = x; + + point p; + while ( x > y ) + { + p = point(rect.left() + radius - y, rect.top() + radius - x); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.right() - radius + y, rect.top() + radius - x); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.right() - radius + y, rect.bottom() - radius + x); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.left() + radius - y, rect.bottom() - radius + x); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.left() + radius - x, rect.top() + radius - y); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.right() - radius + x, rect.top() + radius - y); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.right() - radius + x, rect.bottom() - radius + y); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.left() + radius - x, rect.bottom() - radius + y); + if (area_.contains(p)) draw_pixel (c, p , color ); + y++; + old_x = x; + x = square_root ( ( radius * radius - y * y ) * 4 ) / 2; + } + + if ( x == y && old_x != x ) + { + p = point(rect.left() + radius - y, rect.top() + radius - x); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.right() - radius + y, rect.top() + radius - x); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.right() - radius + y, rect.bottom() - radius + x); + if (area_.contains(p)) draw_pixel (c, p , color ); + p = point(rect.left() + radius - y, rect.bottom() - radius + x); + if (area_.contains(p)) draw_pixel (c, p , color ); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void fill_gradient_rounded ( + const canvas& c, + const rectangle& rect, + unsigned long radius, + const pixel_type& top_color, + const pixel_type& bottom_color, + const rectangle& area = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + + ) + { + rectangle valid_area(c.intersect(area.intersect(rect))); + if ( valid_area.is_empty() ) + return; + + + unsigned long m_prev = 0, m = radius, c_div = valid_area.height() - 1; + + const long c_top = valid_area.top(); + const long c_bottom = valid_area.bottom(); + + for ( long y = c_top; y <= c_bottom;y++ ) + { + + unsigned long c_s = y - c_top; + + unsigned long c_t = c_bottom - y; + + + if ( c_div == 0 ) + { + // only a single round, just take the average color + c_div = 2; + c_s = c_t = 1; + } + + rgb_alpha_pixel color; + vector_to_pixel(color, + ((pixel_to_vector(top_color)*c_t + pixel_to_vector(bottom_color)*c_s)/c_div)); + + unsigned long s = y - rect.top(); + + unsigned long t = rect.bottom() - y; + + if ( s < radius ) + { + m = radius - square_root ( ( radius * radius - ( radius - s ) * ( radius - s ) ) * 4 ) / 2; + + if ( s == m && m + 1 < m_prev ) // these are hacks to remove distracting artefacts at small radii + m++; + } + else if ( t < radius ) + { + m = radius - square_root ( ( radius * radius - ( radius - t ) * ( radius - t ) ) * 4 ) / 2; + + if ( t == m && m == m_prev ) + m++; + } + else + { + m = 0; + } + + m_prev = m; + + draw_line ( c, point(rect.left() + m, y), + point(rect.right() - m, y), color, valid_area ); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void draw_rectangle ( + const canvas& c, + rectangle rect, + const pixel_type& pixel, + const rectangle& area = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) + { + // top line + draw_line(c, point(rect.left(),rect.top()), + point(rect.right(),rect.top()), + pixel, area); + + // bottom line + draw_line(c, point(rect.left(),rect.bottom()), + point(rect.right(),rect.bottom()), + pixel, area); + + // left line + draw_line(c, point(rect.left(),rect.top()), + point(rect.left(),rect.bottom()), + pixel, area); + + // right line + draw_line(c, point(rect.right(),rect.top()), + point(rect.right(),rect.bottom()), + pixel, area); + } + inline void draw_rectangle ( + const canvas& c, + rectangle rect + ){ draw_rectangle(c, rect, 0); } + +// ---------------------------------------------------------------------------------------- + + template + void fill_rect ( + const canvas& c, + const rectangle& rect, + const pixel_type& pixel + ) + { + rectangle area = rect.intersect(c); + for (long y = area.top(); y <= area.bottom(); ++y) + { + for (long x = area.left(); x <= area.right(); ++x) + { + assign_pixel(c[y-c.top()][x-c.left()], pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void fill_rect_with_vertical_gradient ( + const canvas& c, + const rectangle& rect, + const pixel_type& pixel_top, + const pixel_type& pixel_bottom, + const rectangle& area_ = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) + { + rectangle area = rect.intersect(c).intersect(area_); + pixel_type pixel; + + const long s = rect.bottom()-rect.top(); + + for (long y = area.top(); y <= area.bottom(); ++y) + { + const long t = rect.bottom()-y; + const long b = y-rect.top(); + vector_to_pixel(pixel, + ((pixel_to_vector(pixel_top)*t + + pixel_to_vector(pixel_bottom)*b)/s)); + + for (long x = area.left(); x <= area.right(); ++x) + { + assign_pixel(c[y-c.top()][x-c.left()], pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "canvas_drawing.cpp" +#endif + +#endif // DLIB_GUI_CANVAS_DRAWINg_ + diff --git a/ml/dlib/dlib/gui_widgets/canvas_drawing_abstract.h b/ml/dlib/dlib/gui_widgets/canvas_drawing_abstract.h new file mode 100644 index 000000000..e4a298c76 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/canvas_drawing_abstract.h @@ -0,0 +1,364 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GUI_CANVAS_DRAWINg_ABSTRACT_ +#ifdef DLIB_GUI_CANVAS_DRAWINg_ABSTRACT_ + +#include "../gui_core.h" +#include "../pixel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void draw_line ( + const canvas& c, + const point& p1, + const point& p2, + const pixel_type& pixel = rgb_pixel(0,0,0), + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - pixel_traits is defined + ensures + - draws the part of the line from p1 to p1 that overlaps with + the canvas and area onto the canvas. + - Uses the given pixel color. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void draw_rectangle ( + const canvas& c, + rectangle rect, + const pixel_type& pixel = rgb_pixel(0,0,0), + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - pixel_traits is defined + ensures + - Draws the part of the rectangle that overlaps with + the canvas and area onto the canvas. + - Uses the given pixel color. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void draw_circle ( + const canvas& c, + const point& center_point, + double radius, + const pixel_type& pixel = rgb_pixel(0,0,0), + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - pixel_traits is defined + ensures + - draws the part of the circle centered at center_point with the given radius + that overlaps with the canvas and area onto the canvas. + - Uses the given pixel color. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void draw_pixel ( + const canvas& c, + const point& p, + const pixel_type& pixel + ); + /*! + requires + - pixel_traits is defined + ensures + - if (c.contains(p)) then + - sets the pixel in c that represents the point p to the + given pixel color. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void draw_solid_circle ( + const canvas& c, + const point& center_point, + double radius, + const pixel_type& pixel = rgb_pixel(0,0,0), + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - pixel_traits is defined + ensures + - draws the part of the solid circle centered at center_point with the given + radius that overlaps with the canvas and area onto the canvas. + ("solid" means that the interior is also filled in with the given + pixel color) + - Uses the given pixel color. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void draw_solid_convex_polygon ( + const canvas& c, + const std::vector& polygon, + const pixel_type& pixel = rgb_pixel(0,0,0), + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - pixel_traits is defined + ensures + - Interprets the given std::vector polygon object as defining a convex polygon + shape. In particular, the polygon is given by taking the points and drawing + lines between them. That is, imagine drawing a line connecting polygon[i] + and polygon[(i+1)%polygon.size()], for all valid i, and then filling in the + interior of the polygon. That is what this function does. + - When drawing the polygon, only the part of the polygon which overlaps both + the given canvas and area rectangle is drawn. + - Uses the given pixel color to draw the polygon. + !*/ + +// ---------------------------------------------------------------------------------------- + + void draw_button_down ( + const canvas& c, + const rectangle& btn, + unsigned char alpha = 255 + ); + /*! + requires + - 0 <= alpha <= 255 + ensures + - draws the border of a button onto canvas c: + - the border will be that of a button that is depressed + - only the part of the border that overlaps with the canvas object + will be drawn. + - the border will be for the button whose area is defined by the + rectangle btn. + - performs alpha blending such that the button is drawn with full opacity + when alpha is 255 and fully transparent when alpha is 0. + !*/ + +// ---------------------------------------------------------------------------------------- + + void draw_sunken_rectangle ( + const canvas& c, + const rectangle& border, + unsigned char alpha = 255 + ); + /*! + requires + - 0 <= alpha <= 255 + ensures + - draws a sunken rectangle around the given border. + (This is the type of border used for text_fields and + check_boxes and the like). + - performs alpha blending such that the rectangle is drawn with full opacity + when alpha is 255 and fully transparent when alpha is 0. + !*/ + +// ---------------------------------------------------------------------------------------- + + void draw_button_up ( + const canvas& c, + const rectangle& btn, + unsigned char alpha = 255 + ); + /*! + requires + - 0 <= alpha <= 255 + ensures + - draws the border of a button onto canvas c: + - the border will be that of a button that is NOT depressed + - only the part of the border that overlaps with the canvas object + will be drawn. + - the border will be for the button whose area is defined by the + rectangle btn. + - performs alpha blending such that the button is drawn with full opacity + when alpha is 255 and fully transparent when alpha is 0. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void draw_checkered ( + const canvas& c, + const rectangle& area, + const pixel_type& pixel1, + const pixel_type& pixel2 + ); + /*! + requires + - pixel_traits is defined + ensures + - fills the area on the given canvas defined by the rectangle area with a checkers + board pattern where every other pixel gets assigned either pixel1 or pixel2. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void draw_image ( + const canvas& c + const point& p, + const image_type& image, + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - image_type == an implementation of array2d/array2d_kernel_abstract.h + - pixel_traits is defined + ensures + - draws the given image object onto the canvas such that the upper left corner of the + image will appear at the point p in the canvas's window. (note that the + upper left corner of the image is assumed to be the pixel image[0][0] and the + lower right corner of the image is assumed to be image[image.nr()-1][image.nc()-1]) + - only draws the part of the image that overlaps with the area rectangle + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void draw_image ( + const canvas& c, + const rectangle& rect, + const image_type& img, + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - image_type == an implementation of array2d/array2d_kernel_abstract.h + - pixel_traits is defined + ensures + - draws the given image object onto the canvas such that the upper left corner + of the image will appear at the point rect.tl_corner() in the canvas's window + and the lower right corner of the image will appear at rect.br_corner() in + the canvas's window. (note that the upper left corner of the image is + assumed to be the pixel image[0][0] and the lower right corner of the image + is assumed to be image[image.nr()-1][image.nc()-1]) + - only draws the part of the image that overlaps with the area rectangle + - Uses nearest neighbor interpolation when the given rect isn't the same size + as the input image. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void fill_rect ( + const canvas& c, + const rectangle& rect, + const pixel_type& pixel + ); + /*! + requires + - pixel_traits is defined + ensures + - fills the area defined by rect in the given canvas with the given pixel color. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void fill_rect_with_vertical_gradient ( + const canvas& c, + const rectangle& rect, + const pixel_type& pixel_top, + const pixel_type& pixel_bottom, + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - pixel_traits is defined + ensures + - fills the rectangle defined by rect in the given canvas with the given colors. + The top of the area will have the pixel_top color and will slowly fade + towards the pixel_bottom color towards the bottom of rect. + - only draws the part of the image that overlaps with the area rectangle + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void fill_gradient_rounded ( + const canvas& c, + const rectangle& rect, + unsigned long radius, + const pixel_type& top_color, + const pixel_type& bottom_color, + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - pixel_traits is defined + ensures + - Fills the region defined by rect in the given canvas with the given colors. + The top of the region will have the top_color color and will slowly fade + towards the bottom_color color towards the bottom of rect. + - The drawn rectangle will have rounded corners and with the amount of + - rounding given by the radius argument. + - only the part of this object that overlaps with area and the canvas + will be drawn on the canvas + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pixel_type + > + void draw_rounded_rectangle ( + const canvas& c, + const rectangle& rect, + unsigned radius, + const pixel_type& color, + const rectangle& area = rectangle(-infinity,-infinity,infinity,infinity) + ); + /*! + requires + - pixel_traits is defined + ensures + - Draws the part of the rectangle that overlaps with + the canvas onto the canvas. + - The drawn rectangle will have rounded corners and with the amount of + rounding given by the radius argument. + - Uses the given pixel color. + - only draws the part of the image that overlaps with the area rectangle + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GUI_CANVAS_DRAWINg_ABSTRACT_ + diff --git a/ml/dlib/dlib/gui_widgets/drawable.cpp b/ml/dlib/dlib/gui_widgets/drawable.cpp new file mode 100644 index 000000000..8cf114950 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/drawable.cpp @@ -0,0 +1,544 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DRAWABLe_CPP_ +#define DLIB_DRAWABLe_CPP_ + +#include "drawable.h" + +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ----------- drawable_window object ------------------------------------------------------------ +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + rgb_pixel drawable_window:: + background_color ( + ) const + { + auto_mutex M(wm); + return bg_color; + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + set_background_color ( + unsigned long red_, + unsigned long green_, + unsigned long blue_ + ) + { + wm.lock(); + bg_color.red = red_; + bg_color.green = green_; + bg_color.blue = blue_; + wm.unlock(); + // now repaint the window + unsigned long width,height; + get_size(width,height); + rectangle rect(0,0,width-1,height-1); + invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + paint ( + const canvas& c + ) + { + ++event_id; + c.fill(bg_color.red,bg_color.green,bg_color.blue); + + widgets.reset(); + while (widgets.move_next()) + { + widgets.element().value().reset(); + while (widgets.element().value().move_next()) + { + // only dispatch a draw() call if this widget isn't hidden + if (widgets.element().value().element()->hidden == false && + widgets.element().value().element()->event_id != event_id) + { + widgets.element().value().element()->event_id = event_id; + widgets.element().value().element()->draw(c); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_user_event ( + void* p, + int i + ) + { + drawable* d = static_cast(p); + if (widget_set.is_member(d)) + { + d->on_user_event(i); + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_window_moved( + ) + { + ++event_id; + window_moved.reset(); + while (window_moved.move_next()) + { + if (window_moved.element()->event_id != event_id) + { + window_moved.element()->event_id = event_id; + window_moved.element()->on_window_moved(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_window_resized( + ) + { + ++event_id; + window_resized.reset(); + while (window_resized.move_next()) + { + if (window_resized.element()->event_id != event_id) + { + window_resized.element()->event_id = event_id; + window_resized.element()->on_window_resized(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + ++event_id; + keyboard.reset(); + while (keyboard.move_next()) + { + if (keyboard.element()->event_id != event_id) + { + keyboard.element()->event_id = event_id; + keyboard.element()->on_keydown(key,is_printable,state); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_focus_gained ( + ) + { + ++event_id; + focus.reset(); + while (focus.move_next()) + { + if (focus.element()->event_id != event_id) + { + focus.element()->event_id = event_id; + focus.element()->on_focus_gained(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_focus_lost ( + ) + { + ++event_id; + focus.reset(); + while (focus.move_next()) + { + if (focus.element()->event_id != event_id) + { + focus.element()->event_id = event_id; + focus.element()->on_focus_lost(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ) + { + lastx = x; + lasty = y; + + ++event_id; + mouse_click.reset(); + while (mouse_click.move_next()) + { + if (mouse_click.element()->event_id != event_id) + { + mouse_click.element()->event_id = event_id; + mouse_click.element()->on_mouse_down(btn,state,x,y,is_double_click); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ) + { + lastx = x; + lasty = y; + + ++event_id; + mouse_click.reset(); + while (mouse_click.move_next()) + { + if (mouse_click.element()->event_id != event_id) + { + mouse_click.element()->event_id = event_id; + mouse_click.element()->on_mouse_up(btn,state,x,y); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + lastx = x; + lasty = y; + + ++event_id; + mouse_move.reset(); + while (mouse_move.move_next()) + { + if (mouse_move.element()->event_id != event_id) + { + mouse_move.element()->event_id = event_id; + mouse_move.element()->on_mouse_move(state,x,y); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_mouse_leave ( + ) + { + lastx = -1; + lasty = -1; + + ++event_id; + mouse_move.reset(); + while (mouse_move.move_next()) + { + if (mouse_move.element()->event_id != event_id) + { + mouse_move.element()->event_id = event_id; + mouse_move.element()->on_mouse_leave(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_mouse_enter ( + ) + { + ++event_id; + mouse_move.reset(); + while (mouse_move.move_next()) + { + if (mouse_move.element()->event_id != event_id) + { + mouse_move.element()->event_id = event_id; + mouse_move.element()->on_mouse_enter(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_wheel_up ( + unsigned long state + ) + { + ++event_id; + mouse_wheel.reset(); + while (mouse_wheel.move_next()) + { + if (mouse_wheel.element()->event_id != event_id) + { + mouse_wheel.element()->event_id = event_id; + mouse_wheel.element()->on_wheel_up(state); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_wheel_down ( + unsigned long state + ) + { + ++event_id; + mouse_wheel.reset(); + while (mouse_wheel.move_next()) + { + if (mouse_wheel.element()->event_id != event_id) + { + mouse_wheel.element()->event_id = event_id; + mouse_wheel.element()->on_wheel_down(state); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable_window:: + on_string_put ( + const std::wstring &str + ) + { + ++event_id; + string_put.reset(); + while (string_put.move_next()) + { + if (string_put.element()->event_id != event_id) + { + string_put.element()->event_id = event_id; + string_put.element()->on_string_put(str); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ----------- drawable object ---------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void drawable:: + enable_events ( + ) + { + auto_mutex M(m); + if (enabled_events == false) + { + enabled_events = true; + drawable* temp = this; + long zo = z_order_value; + + drawable_window::set_of_drawables* sod = parent.widgets[zo]; + if (sod == 0) + { + // this drawable is the first widget at this z order so we need + // to make its containing set + drawable_window::set_of_drawables s; + s.add(temp); + parent.widgets.add(zo,s); + } + else + { + sod->add(temp); + } + + temp = this; + parent.widget_set.add(temp); + + if (events & MOUSE_MOVE) + { + temp = this; + parent.mouse_move.add(temp); + } + if (events & MOUSE_CLICK) + { + temp = this; + parent.mouse_click.add(temp); + } + if (events & MOUSE_WHEEL) + { + temp = this; + parent.mouse_wheel.add(temp); + } + if (events & WINDOW_RESIZED) + { + temp = this; + parent.window_resized.add(temp); + } + if (events & KEYBOARD_EVENTS) + { + temp = this; + parent.keyboard.add(temp); + } + if (events & FOCUS_EVENTS) + { + temp = this; + parent.focus.add(temp); + } + if (events & WINDOW_MOVED) + { + temp = this; + parent.window_moved.add(temp); + } + if (events & STRING_PUT) + { + temp = this; + parent.string_put.add(temp); + } + parent.invalidate_rectangle(rect); + } + } + +// ---------------------------------------------------------------------------------------- + + void drawable:: + set_z_order ( + long order + ) + { + auto_mutex M(m); + if (order != z_order_value && enabled_events) + { + // first remove this drawable from widgets + drawable_window::set_of_drawables* sod = parent.widgets[z_order_value]; + drawable* junk; + sod->remove(this,junk); + + // if there are no more drawables at this z order then destroy the + // set for this order + if (sod->size() == 0) + parent.widgets.destroy(z_order_value); + + // now add this drawable to its new z order + sod = parent.widgets[order]; + if (sod == 0) + { + // this drawable is the first widget at this z order so we need + // to make its containing set + drawable_window::set_of_drawables s, x; + s.add(junk); + long temp_order = order; + parent.widgets.add(temp_order,s); + } + else + { + sod->add(junk); + } + parent.invalidate_rectangle(rect); + + } + z_order_value = order; + } + +// ---------------------------------------------------------------------------------------- + + void drawable:: + disable_events ( + ) + { + auto_mutex M(m); + if (enabled_events) + { + enabled_events = false; + // first remove this drawable from widgets + drawable_window::set_of_drawables* sod = parent.widgets[z_order_value]; + drawable* junk; + sod->remove(this,junk); + + // if there are no more drawables at this z order then destroy the + // set for this order + if (sod->size() == 0) + parent.widgets.destroy(z_order_value); + + parent.widget_set.remove(this,junk); + + // now unregister this drawable from all the events it has registered for. + if (events & MOUSE_MOVE) + parent.mouse_move.remove(this,junk); + if (events & MOUSE_CLICK) + parent.mouse_click.remove(this,junk); + if (events & MOUSE_WHEEL) + parent.mouse_wheel.remove(this,junk); + if (events & WINDOW_RESIZED) + parent.window_resized.remove(this,junk); + if (events & KEYBOARD_EVENTS) + parent.keyboard.remove(this,junk); + if (events & FOCUS_EVENTS) + parent.focus.remove(this,junk); + if (events & WINDOW_MOVED) + parent.window_moved.remove(this,junk); + if (events & STRING_PUT) + parent.string_put.remove(this,junk); + } + } + +// ---------------------------------------------------------------------------------------- + + drawable:: + ~drawable ( + ) + { + try + { + DLIB_ASSERT(events_are_enabled() == false, + "\tdrawable::~drawable()" + << "\n\tYou must disable events for drawable objects in their destructor by calling disable_events()." + << "\n\tthis: " << this + ); + } + catch (std::exception& e) + { + std::cerr << e.what() << std::endl; + assert(false); + abort(); + } + disable_events(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DRAWABLe_CPP_ + diff --git a/ml/dlib/dlib/gui_widgets/drawable.h b/ml/dlib/dlib/gui_widgets/drawable.h new file mode 100644 index 000000000..a270b53c8 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/drawable.h @@ -0,0 +1,527 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. + +#ifndef DLIB_DRAWABLe_ +#define DLIB_DRAWABLe_ + +#include + +#include "drawable_abstract.h" +#include "../gui_core.h" +#include "../set.h" +#include "../binary_search_tree.h" +#include "../algs.h" +#include "../pixel.h" +#include "fonts.h" +#include "../matrix.h" +#include "canvas_drawing.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class drawable_window +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class drawable; + class drawable_window : public base_window + { + /*! + INITIAL VALUE + - lastx == -1 + - lasty == -1 + - event_id == 1 + + CONVENTION + - bg_color == background_color() + + - widgets == this binary search tree contains every drawable that is in + this window. It is a mapping of each drawable's z-order to a pointer + to said drawable. + - widget_set == a set that contains all the widgets in this window and + want to receive events. + + - mouse_move == this is a set of drawables that are in this window and + want to receive the mouse movement events. + - mouse_wheel == this is a set of drawables that are in this window and + want to receive the mouse wheel events. + - mouse_click == this is a set of drawables that are in this window and + want to receive the mouse click events. + - window_resized == this is a set of drawables that are in this window and + want to receive the window_resized event. + - keyboard == this is a set of drawables that are in this window and + want to receive keyboard events. + - focus == this is a set of drawables that are in this window and + want to receive focus events. + - window_moved == this is a set of drawables that are in this window and + want to receive window move events. + + - lastx == the x coordinate that we last saw the mouse at or -1 if the + mouse is outside this window. + - lasty == the y coordinate that we last saw the mouse at or -1 if the + mouse is outside this window. + + - event_id == a number we use to tag events so we don't end up sending + an event to a drawable more than once. This could happen if one of the + event handlers does something to reset the enumerator while we are + dispatching events (e.g. creating a new widget). + !*/ + public: + + drawable_window( + bool resizable = true, + bool undecorated = false + ) : + base_window(resizable,undecorated), + bg_color(rgb_pixel(212,208,200)), + lastx(-1), + lasty(-1), + event_id(1) + {} + + void set_background_color ( + unsigned long red, + unsigned long green, + unsigned long blue + ); + + rgb_pixel background_color ( + ) const; + + virtual inline ~drawable_window()=0; + + private: + + void paint ( + const canvas& c + ); + + protected: + + void on_window_resized( + ); + + void on_window_moved( + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_mouse_leave ( + ); + + void on_mouse_enter ( + ); + + void on_wheel_up ( + unsigned long state + ); + + void on_wheel_down ( + unsigned long state + ); + + void on_focus_gained ( + ); + + void on_focus_lost ( + ); + + void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + void on_string_put ( + const std::wstring &str + ); + + void on_user_event ( + void* p, + int i + ); + + private: + + friend class drawable; + + + rgb_pixel bg_color; + + typedef set::kernel_1a_c set_of_drawables; + + binary_search_tree::kernel_1a_c widgets; + + set_of_drawables widget_set; + set_of_drawables mouse_move; + set_of_drawables mouse_wheel; + set_of_drawables mouse_click; + set_of_drawables window_resized; + set_of_drawables keyboard; + set_of_drawables focus; + set_of_drawables window_moved; + set_of_drawables string_put; + + long lastx, lasty; + unsigned long event_id; + + + // restricted functions + drawable_window(drawable_window&); // copy constructor + drawable_window& operator=(drawable_window&); // assignment operator + + + }; + + drawable_window::~drawable_window(){ close_window();} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class drawable +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + enum + { + MOUSE_MOVE = 1, + MOUSE_CLICK = 2, + MOUSE_WHEEL = 4, + WINDOW_RESIZED = 8, + KEYBOARD_EVENTS = 16, + FOCUS_EVENTS = 32, + WINDOW_MOVED = 64, + STRING_PUT = 128 + }; + + class drawable + { + + /*! + INITIAL VALUE + - enabled_events == false + - event_id == 0 + + CONVENTION + - events == a bitset specifying what events this drawable is to receive. + + - z_order_value == z_order() + + - if (this drawable has been added to the parent window's sets and + binary search tree) then + - enabled_events == true + - else + - enabled_events == false + + - event_id == the id of the last event we got from our parent window + !*/ + + public: + + friend class drawable_window; + + drawable ( + drawable_window& w, + unsigned long events_ = 0 + ) : + m(w.wm), + parent(w), + hidden(false), + enabled(true), + lastx(w.lastx), + lasty(w.lasty), + mfont(default_font::get_font()), + z_order_value(0), + events(events_), + enabled_events(false), + event_id(0) + {} + + virtual ~drawable ( + ); + + long z_order ( + ) const + { + m.lock(); + long temp = z_order_value; + m.unlock(); + return temp; + } + + virtual void set_z_order ( + long order + ); + + const rectangle get_rect ( + ) const + { + auto_mutex M(m); + return rect; + } + + long bottom ( + ) const + { + auto_mutex M(m); + return rect.bottom(); + } + + long top ( + ) const + { + auto_mutex M(m); + return rect.top(); + } + + long left ( + ) const + { + auto_mutex M(m); + return rect.left(); + } + + long right ( + ) const + { + auto_mutex M(m); + return rect.right(); + } + + long width ( + ) const + { + auto_mutex M(m); + return rect.width(); + } + + long height ( + ) const + { + auto_mutex M(m); + return rect.height(); + } + + bool is_enabled ( + ) const + { + auto_mutex M(m); + return enabled; + } + + virtual void enable ( + ) + { + auto_mutex M(m); + enabled = true; + parent.invalidate_rectangle(rect); + } + + virtual void disable ( + ) + { + auto_mutex M(m); + enabled = false; + parent.invalidate_rectangle(rect); + } + + virtual void set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + parent.invalidate_rectangle(rect); + } + + const std::shared_ptr main_font ( + ) const + { + auto_mutex M(m); + return mfont; + } + + bool is_hidden ( + ) const + { + auto_mutex M(m); + return hidden; + } + + virtual void set_pos ( + long x, + long y + ) + { + m.lock(); + rectangle old(rect); + + const unsigned long width = rect.width(); + const unsigned long height = rect.height(); + rect.set_top(y); + rect.set_left(x); + rect.set_right(static_cast(x+width)-1); + rect.set_bottom(static_cast(y+height)-1); + + parent.invalidate_rectangle(rect+old); + m.unlock(); + } + + virtual void show ( + ) + { + m.lock(); + hidden = false; + parent.invalidate_rectangle(rect); + m.unlock(); + } + + virtual void hide ( + ) + { + m.lock(); + hidden = true; + parent.invalidate_rectangle(rect); + m.unlock(); + } + + base_window& parent_window ( + ) { return parent; } + + const base_window& parent_window ( + ) const { return parent; } + + virtual int next_free_user_event_number ( + )const { return 0; } + + protected: + rectangle rect; + const rmutex& m; + drawable_window& parent; + bool hidden; + bool enabled; + const long& lastx; + const long& lasty; + std::shared_ptr mfont; + + + void enable_events ( + ); + + bool events_are_enabled ( + ) const { auto_mutex M(m); return enabled_events; } + + void disable_events ( + ); + + private: + + long z_order_value; + const unsigned long events; + bool enabled_events; + unsigned long event_id; + + + // restricted functions + drawable(drawable&); // copy constructor + drawable& operator=(drawable&); // assignment operator + + + protected: + + virtual void draw ( + const canvas& c + ) const=0; + + virtual void on_user_event ( + int + ){} + + virtual void on_window_resized( + ){} + + virtual void on_window_moved( + ){} + + virtual void on_mouse_down ( + unsigned long , + unsigned long , + long , + long , + bool + ){} + + virtual void on_mouse_up ( + unsigned long , + unsigned long , + long , + long + ){} + + virtual void on_mouse_move ( + unsigned long , + long , + long + ){} + + virtual void on_mouse_leave ( + ){} + + virtual void on_mouse_enter ( + ){} + + virtual void on_wheel_up ( + unsigned long + ){} + + virtual void on_wheel_down ( + unsigned long + ){} + + virtual void on_focus_gained ( + ){} + + virtual void on_focus_lost ( + ){} + + virtual void on_keydown ( + unsigned long , + bool , + unsigned long + ){} + + virtual void on_string_put ( + const std::wstring& + ){} + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "drawable.cpp" +#endif + +#endif // DLIB_DRAWABLe_ + diff --git a/ml/dlib/dlib/gui_widgets/drawable_abstract.h b/ml/dlib/dlib/gui_widgets/drawable_abstract.h new file mode 100644 index 000000000..8f741d8bb --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/drawable_abstract.h @@ -0,0 +1,717 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#undef DLIB_DRAWABLe_ABSTRACT_ +#ifdef DLIB_DRAWABLe_ABSTRACT_ + +#include "../gui_core.h" +#include "fonts_abstract.h" +#include "canvas_drawing_abstract.h" + +namespace dlib +{ + + /*! + GENERAL REMARKS + This file defines the drawable interface class and the drawable_window which + is just a window that is capable of displaying drawable objects (i.e. objects + that implement the drawable interface). + + The drawable interface is a simple framework for creating more complex + graphical widgets. It provides a default set of functionality and a + set of events which a gui widget may use. + + THREAD SAFETY + All objects and functions defined in this file are thread safe. You may + call them from any thread without serializing access to them. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class drawable_window +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class drawable_window : public base_window + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a window on the desktop that is capable of + containing drawable objects. + + INITIAL STATE + The initial state of the drawable_window is to be hidden. This means + you need to call show() to make it appear. + + EVENTS + The drawable_window object uses all the events provided by base_window + except for the on_window_close() event. This means that if you + define handlers for these events yourself you will have to call + the drawable_window's version of them so that the drawable_window + can continue to process and forward these events to its drawable + objects. + !*/ + public: + + drawable_window ( + bool resizable = true, + bool undecorated = false + ); + /*! + requires + - if (undecorated == true) then + - resizable == false + ensures + - #*this has been properly initialized + - #background_color() == rgb_pixel(212,208,200) + - if (resizable == true) then + - this window will be resizable by the user + - else + - this window will not be resizable by the user + - if (undecorated == true) then + - this window will not have any graphical elements outside + of its drawable area or appear in the system task bar. + (e.g. a popup menu) + throws + - std::bad_alloc + - dlib::thread_error + - dlib::gui_error + This exception is thrown if there is an error while + creating this window. + !*/ + + virtual ~drawable_window( + )=0; + /*! + ensures + - if (this window has not already been closed) then + - closes the window + - does NOT trigger the on_window_close() event + - all resources associated with *this have been released + !*/ + + void set_background_color ( + unsigned long red, + unsigned long green, + unsigned long blue + ); + /*! + ensures + - #background_color().red == red + - #background_color().green == green + - #background_color().blue == blue + !*/ + + rgb_pixel background_color ( + ) const; + /*! + ensures + - returns the background color this window paints its canvas + with before it passes it onto its drawable widgets + !*/ + + private: + // restricted functions + drawable_window(drawable_window&); // copy constructor + drawable_window& operator=(drawable_window&); // assignment operator + + friend class drawable; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class drawable +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + enum + { + MOUSE_MOVE = 1, + MOUSE_CLICK = 2, + MOUSE_WHEEL = 4, + WINDOW_RESIZED = 8, + KEYBOARD_EVENTS = 16, + FOCUS_EVENTS = 32, + WINDOW_MOVED = 64, + STRING_PUT = 128 + }; + + class drawable + { + /*! + INITIAL VALUE + top() == 0 + left() == 0 + right() == -1 + bottom() == -1 + get_rect().is_empty() == true + is_hidden() == false + is_enabled() == true + z_order() == 0 + main_font() == default_font::get_font() + + WHAT THIS OBJECT REPRESENTS + This is an interface that all drawable widgets implement. It + provides a standard method (draw()) to draw a widget onto a canvas + and many other convenient functions for drawable objects. + + EVENT FORWARDING + All the events that come to a drawable object are forwarded from its + parent window. Additionally, there is no filtering. This means that + if a drawable registers to receive a certain kind of event then whenever + its parent window receives that event the drawable object will get a + forwarded copy of it as well even if the event occurred outside the + drawable's rectangle. + + The only events that have anything in the way of filtering are the + draw() and on_user_event() events. draw() is only called on a drawable + object when that object is not hidden. on_user_event() is only called + for drawables that the on_user_event()'s first argument specifically + references. All other events are not filtered at all though. + + Z ORDER + Z order defines the order in which drawable objects are drawn. The + lower numbered drawables are drawn first and then the higher numbered + ones. So a drawable with a z order of 0 is drawn before one with a + z order of 1 and so on. + !*/ + + public: + + friend class drawable_window; + + drawable ( + drawable_window& w, + unsigned long events = 0 + ) : + m(w.wm), + parent(w), + hidden(false), + enabled(true) + {} + /*! + ensures + - #*this is properly initialized + - #parent_window() == w + - #*this will not receive any events or draw() requests until + enable_events() is called + - once events_are_enabled() == true this drawable will receive + the on_user_event() event. (i.e. you always get this event, you don't + have to enable it by setting something in the events bitset). + - if (events & MOUSE_MOVE) then + - once events_are_enabled() == true this drawable will receive + the following events related to mouse movement: on_mouse_move, + on_mouse_leave, and on_mouse_enter. + - if (events & MOUSE_CLICK) then + - once events_are_enabled() == true this drawable will receive + the following events related to mouse clicks: on_mouse_down and + on_mouse_up. + - if (events & MOUSE_WHEEL) then + - once events_are_enabled() == true this drawable will receive + the following events related to mouse wheel scrolling: + on_wheel_up and on_wheel_down. + - if (events & WINDOW_RESIZED) then + - once events_are_enabled() == true this drawable will receive + the following event related to its parent window resizing: + on_window_resized. + - if (events & KEYBOARD_EVENTS) then + - once events_are_enabled() == true this drawable will receive + the following keyboard event: on_keydown. + - if (events & FOCUS_EVENTS) then + - once events_are_enabled() == true this drawable will receive + the following focus events: on_focus_gained and on_focus_lost. + - if (events & WINDOW_MOVED) then + - once events_are_enabled() == true this drawable will receive + the following event related to its parent window moving: + on_window_moved. + - if (events & STRING_PUT) then + - once events_are_enabled() == true this drawable will receive + the following event related to wide character string input: + on_string_put. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~drawable ( + ); + /*! + requires + - events_are_enabled() == false + ensures + - any resources associated with *this have been released + - *this has been removed from its containing window parent_window() and + its parent window will no longer try to dispatch events to it. + Note that this does not trigger a redraw of the parent window. If you + want to do that you must do it yourself. + !*/ + + long z_order ( + ) const; + /*! + ensures + - returns the z order for this drawable. + !*/ + + virtual void set_z_order ( + long order + ); + /*! + ensures + - #z_order() == order + - if (events_are_enabled() == true) then + - parent_window() is updated to reflect the new state of #*this + throws + - std::bad_alloc + !*/ + + const rectangle get_rect ( + ) const; + /*! + ensures + - returns the rectangle that defines the area and position of this + drawable inside its containing window parent_window(). + !*/ + + long bottom ( + ) const; + /*! + ensures + - returns get_rect().bottom() + !*/ + + long top ( + ) const; + /*! + ensures + - returns get_rect().top() + !*/ + + long left ( + ) const; + /*! + ensures + - returns get_rect().left() + !*/ + + long right ( + ) const; + /*! + ensures + - returns get_rect().right() + !*/ + + unsigned long width ( + ) const; + /*! + ensures + - returns get_rect().width() + !*/ + + unsigned long height ( + ) const; + /*! + ensures + - returns get_rect().height() + !*/ + + virtual void set_pos ( + long x, + long y + ); + /*! + ensures + - #top() == y + - #left() == x + - #width() == width() + - #height() == height() + - if (events_are_enabled() == true) then + - parent_window() is updated to reflect the new state of #*this + - i.e. This just sets the upper left corner of this drawable to the + location (x,y) + !*/ + + bool is_enabled ( + ) const; + /*! + ensures + - returns true if this object is enabled and false otherwise. + (it is up to derived classes to define exactly what it means to be + "enabled") + !*/ + + virtual void enable ( + ); + /*! + ensures + - #is_enabled() == true + - if (events_are_enabled() == true) then + - parent_window() is updated to reflect the new state of #*this + !*/ + + virtual void disable ( + ); + /*! + ensures + - #is_enabled() == false + - if (events_are_enabled() == true) then + - parent_window() is updated to reflect the new state of #*this + !*/ + + virtual void set_main_font ( + const shared_ptr_thread_safe& f + ); + /*! + ensures + - #main_font() == f + - if (events_are_enabled() == true) then + - parent_window() is updated to reflect the new state of #*this + !*/ + + const shared_ptr_thread_safe main_font ( + ) const; + /*! + ensures + - returns the current main font being used by this widget + !*/ + + bool is_hidden ( + ) const; + /*! + ensures + - returns true if this object is NOT currently displayed on parent_window() + and false otherwise. + !*/ + + virtual void show ( + ); + /*! + ensures + - #is_hidden() == false + - if (events_are_enabled() == true) then + - parent_window() is updated to reflect the new state of #*this + !*/ + + virtual void hide ( + ); + /*! + ensures + - #is_hidden() == true + - if (events_are_enabled() == true) then + - parent_window() is updated to reflect the new state of #*this + !*/ + + drawable_window& parent_window ( + ); + /*! + ensures + - returns a reference to the drawable_window that this drawable is + being drawn on and receiving events from. + !*/ + + const drawable_window& parent_window ( + ) const; + /*! + ensures + - returns a const reference to the drawable_window that this drawable + is being drawn on and receiving events from. + !*/ + + virtual int next_free_user_event_number ( + )const { return 0; } + /*! + ensures + - returns the smallest number, i, that is the next user event number you + can use in calls to parent.trigger_user_event((void*)this,i). + - This function exists because of the following scenario. Suppose + you make a class called derived1 that inherits from drawable and + in derived1 you use a user event to do something. Then suppose + you inherit from derived1 to make derived2. Now in derived2 you + may want to use a user event to do something as well. How are you + to know which user event numbers are in use already? This function + solves that problem. You would define derived1::next_free_user_event_number() + so that it returned a number bigger than any user event numbers used by + derived1 or its ancestors. Then derived2 could just call + derived1::next_free_user_event_number() to find out what numbers it could use. + !*/ + + protected: + /*!A drawable_protected_variables + + These protected members are provided because they are needed to + implement drawable widgets. + !*/ + + // This is the rectangle that is returned by get_rect() + rectangle rect; + + // This is the mutex used to serialize access to this class. + const rmutex& m; + + // This is the parent window of this drawable + drawable_window& parent; + + // This is the bool returned by is_hidden() + bool hidden; + + // This is the bool returned by is_enabled() + bool enabled; + + // This is the font pointer returned by main_font() + shared_ptr_thread_safe mfont; + + // This is the x coordinate that we last saw the mouse at or -1 if the mouse + // is outside the parent window. + const long& lastx; + + // This is the y coordinate that we last saw the mouse at or -1 if the mouse + // is outside the parent window. + const long& lasty; + + + void enable_events ( + ); + /*! + ensures + - #events_are_enabled() == true + !*/ + + void disable_events ( + ); + /*! + ensures + - #events_are_enabled() == false + !*/ + + bool events_are_enabled ( + ) const; + /*! + ensures + - returns true if this object is receiving events and draw() + requests from its parent window. + - returns false otherwise + !*/ + + // ---------------- EVENT HANDLERS ------------------ + + virtual void on_user_event ( + int i + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - is called whenever the parent window receives an on_user_event(p,i) event + where p == this. (i.e. this is just a redirect of on_user_event for + cases where the first argument of on_user_event is equal to the + this pointer). + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_window_resized( + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_window_resized() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_window_moved( + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_window_moved() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_mouse_down() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_mouse_up() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_mouse_move ( + unsigned long state, + long x, + long y + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - x == lastx + - y == lasty + - this is just the base_window::on_mouse_move() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_mouse_leave ( + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_mouse_leave() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_mouse_enter ( + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_mouse_enter() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_wheel_up ( + unsigned long state + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_wheel_up() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_wheel_down ( + unsigned long state + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_wheel_down() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_focus_gained ( + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_focus_gained() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_focus_lost ( + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_focus_lost() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_keydown() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void on_string_put ( + const std::wstring &str + ){} + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - this is just the base_window::on_put_string() event forwarded to + this object. See the gui_core specs for the details about this event. + ensures + - does not change the state of mutex m. + !*/ + + virtual void draw ( + const canvas& c + ) const=0; + /*! + requires + - events_are_enabled() == true + - mutex m is locked + - is_hidden() == false + - is called by parent_window() when it needs to repaint itself. + - c == the canvas object for the area of parent_window() that needs + to be repainted. + ensures + - does not change the state of mutex m. + - draws the area of *this that intersects with the canvas onto + the canvas object c. + !*/ + + private: + + // restricted functions + drawable(drawable&); // copy constructor + drawable& operator=(drawable&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DRAWABLe_ABSTRACT_ + diff --git a/ml/dlib/dlib/gui_widgets/fonts.cpp b/ml/dlib/dlib/gui_widgets/fonts.cpp new file mode 100644 index 000000000..dfbf9f720 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/fonts.cpp @@ -0,0 +1,673 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt, Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FONTs_CPP_ +#define DLIB_FONTs_CPP_ + +#include "fonts.h" + +#include +#include +#include + +#include "../serialize.h" +#include "../base64.h" +#include "../compress_stream.h" +#include "../tokenizer.h" +#include "nativefont.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + const std::string get_decoded_string_with_default_font_data() + { + dlib::base64::kernel_1a base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + /* + SOURCE BDF FILE (helvR12.bdf) COMMENTS + COMMENT $XConsortium: helvR12.bdf,v 1.15 95/01/26 18:02:58 gildea Exp $ + COMMENT $Id: helvR12.bdf,v 1.26 2004-11-28 20:08:46+00 mgk25 Rel $ + COMMENT + COMMENT + + COMMENT Copyright 1984-1989, 1994 Adobe Systems Incorporated. + COMMENT Copyright 1988, 1994 Digital Equipment Corporation. + COMMENT + COMMENT Adobe is a trademark of Adobe Systems Incorporated which may be + COMMENT registered in certain jurisdictions. + COMMENT Permission to use these trademarks is hereby granted only in + COMMENT association with the images described in this file. + COMMENT + COMMENT Permission to use, copy, modify, distribute and sell this software + COMMENT and its documentation for any purpose and without fee is hereby + COMMENT granted, provided that the above copyright notices appear in all + COMMENT copies and that both those copyright notices and this permission + COMMENT notice appear in supporting documentation, and that the names of + COMMENT Adobe Systems and Digital Equipment Corporation not be used in + COMMENT advertising or publicity pertaining to distribution of the software + COMMENT without specific, written prior permission. Adobe Systems and + COMMENT Digital Equipment Corporation make no representations about the + COMMENT suitability of this software for any purpose. It is provided "as + COMMENT is" without express or implied warranty. + COMMENT - + */ + + // The base64 encoded data we want to decode and return. + sout << "AXF+zOQzCgGitrKiOCGEL4hlIv1ZenWJyjMQ4rJ6f/oPMeHqsZn+8XnpehwFQTz3dtUGlZRAUoOa"; + sout << "uVo8UiplcFxuK69A+94rpMCMAyEeeOwZ/tRzkX4eKuU3L4xtsJDknMiYUNKaMrYimb1QJ0E+SRqQ"; + sout << "wATrMTecYNZvJJm02WibiwE4cJ5scvkHNl4KJT5QfdwRdGopTyUVdZvRvtbTLLjsJP0fQEQLqemf"; + sout << "qPE4kDD79ehrBIwLO1Y6TzxtrrIoQR57zlwTUyLenqRtSN3VLtjWYd82cehRIlTLtuxBg2s+zZVq"; + sout << "jNlNnYTSM+Swy06qnQgg+Dt0lhtlB9shR1OAlcfCtTW6HKoBk/FGeDmjTGW4bNCGv7RjgM6TlLDg"; + sout << "ZYSSA6ZCCAKBgE++U32gLHCCiVkPTkkp9P6ioR+e3SSKRNm9p5MHf+ZQ3LJkW8KFJ/K9gKT1yvyv"; + sout << "F99pAvOOq16tHRFvzBs+xZj/mUpH0lGIS7kLWr9oP2KuccVrz25aJn3kDruwTYoD+CYlOqtPO0Mv"; + sout << "dEI0LUR0Ykp1M2rWo76fJ/fpzHjV7737hjkNPJ13nO72RMDr4R5V3uG7Dw7Ng+vGX3WgJZ4wh1JX"; + sout << "pl2VMqC5JXccctzvnQvnuvBvRm7THgwQUgMKKT3WK6afUUVlJy8DHKuU4k1ibfVMxAmrwKdTUX2w"; + sout << "cje3A05Qji3aop65qEdwgI5O17HIVoRQOG/na+XRMowOfUvI4H8Z4+JGACfRrQctgYDAM9eJzm8i"; + sout << "PibyutmJfZBGg0a3oC75S5R9lTxEjPocnEyJRYNnmVnVAmKKbTbTsznuaD+D1XhPdr2t3A4bRTsp"; + sout << "toKKtlFnd9YGwLWwONDwLnoQ/IXwyF7txrRHNSVToh772U0Aih/yn5vnmcMF750eiMzRAgXu5sbR"; + sout << "VXEOVCiLgVevN5umkvjZt1eGTSSzDMrIvnv4nyOfaFsD+I76wQfgLqd71rheozGtjNc0AOTx4Ggc"; + sout << "eUSFHTDAVfTExBzckurtyuIAqF986a0JLHCtsDpBa2wWNuiQYOH3/LX1zkdU2hdamhBW774bpEwr"; + sout << "dguMxxOeDGOBgIlM5gxXGYXSf5IN3fUAEPfOPRxB7T+tpjFnWd7cg+JMabci3zhJ9ANaYT7HGeTX"; + sout << "bulKnGHjYrR1BxdK3YeliogQRU4ytmxlyL5zlNFU/759mA8XSfIPMEZn9Vxkb00q1htF7REiDcr3"; + sout << "kW1rtPAc7VQNEhT54vK/YF6rMvjO7kBZ/vLYo7E8e8hDKEnY8ucrC3KGmeo31Gei74BBcEbvJBd3"; + sout << "/YAaIKgXWwU2wSUw9wLq2RwGwyguvKBx0J/gn27tjcVAHorRBwxzPpk8r+YPyN+SifSzEL7LEy1G"; + sout << "lPHxmXTrcqnH9qraeAqXJUJvU8SJJpf/tmsAE+XSKD/kpVBnT5qXsJ1SRFS7MtfPjE1j/NYbaQBI"; + sout << "bOrh81zaYCEJR0IKHWCIsu/MC3zKXfkxFgQ9XpYAuWjSSK64YpgkxSMe8VG8yYvigOw2ODg/z4FU"; + sout << "+HpnEKF/M/mKfLKK1i/8BV7xcYVHrhEww1QznoFklJs/pEg3Kd5PE1lRii6hvTn6McVAkw+YbH9q"; + sout << "/sg4gFIAvai64hMcZ1oIZYppj3ZN6KMdyhK5s4++ZS/YOV2nNhW73ovivyi2Tjg7lxjJJtsYrLKb"; + sout << "zIN1slOICKYwBq42TFBcFXaZ6rf0Czd09tL+q6A1Ztgr3BNuhCenjhWN5ji0LccGYZo6bLTggRG/"; + sout << "Uz6K3CBBU/byLs79c5qCohrr7rlpDSdbuR+aJgNiWoU6T0i2Tvua6h51LcWEHy5P2n146/Ae2di4"; + sout << "eh20WQvclrsgm1oFTGD0Oe85GKOTA7vvwKmLBc1wwA0foTuxzVgj0TMTFBiYLTLG4ujUyBYy1N6e"; + sout << "H8EKi8H+ZAlqezrjABO3BQr33ewdZL5IeJ4w7gdGUDA6+P+7cODcBW50X9++6YTnKctuEw6aXBpy"; + sout << "GgcMfPE61G8YKBbFGFic3TVvGCLvre1iURv+F+hU4/ee6ILuPnpYnSXX2iCIK/kmkBse8805d4Qe"; + sout << "DG/8rBW9ojvAgc0jX7CatPEMHGkcz+KIZoKMI7XXK4PJpGQUdq6EdIhJC4koXEynjwwXMeC+jJqH"; + sout << "agwrlDNssq/8AA=="; + + + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + + default_font:: + default_font ( + ) + { + using namespace std; + l = new letter[256]; + + try + { + istringstream sin(get_decoded_string_with_default_font_data()); + + for (int i = 0; i < 256; ++i) + { + deserialize(l[i],sin); + } + + } + catch (...) + { + delete [] l; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const letter& item, + std::ostream& out + ) + { + try + { + serialize(item.w,out); + serialize(item.count,out); + + for (unsigned long i = 0; i < item.count; ++i) + { + serialize(item.points[i].x,out); + serialize(item.points[i].y,out); + } + } + catch (serialization_error e) + { + throw serialization_error(e.info + "\n while serializing object of type letter"); + } + } + + void deserialize ( + letter& item, + std::istream& in + ) + { + try + { + if (item.points) + delete [] item.points; + + deserialize(item.w,in); + deserialize(item.count,in); + + if (item.count > 0) + item.points = new letter::point[item.count]; + else + item.points = 0; + + for (unsigned long i = 0; i < item.count; ++i) + { + deserialize(item.points[i].x,in); + deserialize(item.points[i].y,in); + } + } + catch (serialization_error e) + { + item.w = 0; + item.count = 0; + item.points = 0; + throw serialization_error(e.info + "\n while deserializing object of type letter"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace bdf_font_helpers + { + class bdf_parser + { + public: + bdf_parser( std::istream& in ) : in_( in ) + { + std::string str_tmp; + int int_tmp; + + str_tmp = "STARTFONT"; int_tmp = STARTFONT; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "FONTBOUNDINGBOX";int_tmp = FONTBOUNDINGBOX; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "DWIDTH"; int_tmp = DWIDTH; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "CHARS"; int_tmp = CHARS; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "STARTCHAR"; int_tmp = STARTCHAR; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "ENCODING"; int_tmp = ENCODING; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "BBX"; int_tmp = BBX; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "BITMAP"; int_tmp = BITMAP; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "ENDCHAR"; int_tmp = ENDCHAR; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "ENDFONT"; int_tmp = ENDFONT; keyword_map.add( str_tmp, int_tmp ); + str_tmp = "DEFAULT_CHAR"; int_tmp = DEFAULT_CHAR; keyword_map.add( str_tmp, int_tmp ); + + tokzr.set_identifier_token( tokzr.uppercase_letters(), tokzr.uppercase_letters() + "_" ); + tokzr.set_stream( in ); + + } + + enum bdf_enums + { + NO_KEYWORD = 0, + STARTFONT = 1, + FONTBOUNDINGBOX = 2, + DWIDTH = 4, + DEFAULT_CHAR = 8, + CHARS = 16, + STARTCHAR = 32, + ENCODING = 64, + BBX = 128, + BITMAP = 256, + ENDCHAR = 512, + ENDFONT = 1024 + + }; + struct header_info + { + int FBBx, FBBy, Xoff, Yoff; + int dwx0, dwy0; + bool has_global_dw; + long default_char; + }; + struct char_info + { + int dwx0, dwy0; + int BBw, BBh, BBxoff0x, BByoff0y; + array2d bitmap; + bool has_dw; + }; + bool parse_header( header_info& info ) + { + if ( required_keyword( STARTFONT ) == false ) + return false; // parse_error: required keyword missing + info.has_global_dw = false; + int find = FONTBOUNDINGBOX | DWIDTH | DEFAULT_CHAR; + int stop = CHARS | STARTCHAR | ENCODING | BBX | BITMAP | ENDCHAR | ENDFONT; + int res; + while ( 1 ) + { + res = find_keywords( find | stop ); + if ( res & FONTBOUNDINGBOX ) + { + in_ >> info.FBBx >> info.FBBy >> info.Xoff >> info.Yoff; + if ( in_.fail() ) + return false; // parse_error + find &= ~FONTBOUNDINGBOX; + continue; + } + if ( res & DWIDTH ) + { + in_ >> info.dwx0 >> info.dwy0; + if ( in_.fail() ) + return false; // parse_error + find &= ~DWIDTH; + info.has_global_dw = true; + continue; + } + if ( res & DEFAULT_CHAR ) + { + in_ >> info.default_char; + if ( in_.fail() ) + return false; // parse_error + find &= ~DEFAULT_CHAR; + continue; + } + if ( res & NO_KEYWORD ) + return false; // parse_error: unexpected EOF + break; + } + if ( res != CHARS || ( find & FONTBOUNDINGBOX ) ) + return false; // parse_error: required keyword missing or unexpeced keyword + return true; + } + int parse_glyph( char_info& info, unichar& enc ) + { + info.has_dw = false; + int e; + int res; + while ( 1 ) + { + res = find_keywords( ENCODING ); + if ( res != ENCODING ) + return 0; // no more glyphs + in_ >> e; + if ( in_.fail() ) + return -1; // parse_error + if ( e >= static_cast(enc) ) + break; + } + int find = BBX | DWIDTH; + int stop = STARTCHAR | ENCODING | BITMAP | ENDCHAR | ENDFONT; + while ( 1 ) + { + res = find_keywords( find | stop ); + if ( res & BBX ) + { + in_ >> info.BBw >> info.BBh >> info.BBxoff0x >> info.BByoff0y; + if ( in_.fail() ) + return -1; // parse_error + find &= ~BBX; + continue; + } + if ( res & DWIDTH ) + { + in_ >> info.dwx0 >> info.dwy0; + if ( in_.fail() ) + return -1; // parse_error + find &= ~DWIDTH; + info.has_dw = true; + continue; + } + if ( res & NO_KEYWORD ) + return -1; // parse_error: unexpected EOF + break; + } + if ( res != BITMAP || ( find != NO_KEYWORD ) ) + return -1; // parse_error: required keyword missing or unexpeced keyword + unsigned h = info.BBh; + unsigned w = ( info.BBw + 7 ) / 8 * 2; + info.bitmap.set_size( h, w ); + for ( unsigned r = 0;r < h;r++ ) + { + trim(); + std::string str = ""; + extract_hex(str); + if(str.size() < w) + return -1; // parse_error + for ( unsigned c = 0;c < w;c++ ) + info.bitmap[r][c] = str[c]; + } + if ( in_.fail() ) + return -1; // parse_error + if ( required_keyword( ENDCHAR ) == false ) + return -1; // parse_error: required keyword missing + enc = e; + return 1; + } + private: + map::kernel_1a_c keyword_map; + tokenizer::kernel_1a_c tokzr; + std::istream& in_; + void extract_hex(std::string& str) + { + int type; + std::string token; + while ( 1 ) + { + type = tokzr.peek_type(); + if ( type == tokenizer::kernel_1a_c::IDENTIFIER || type == tokenizer::kernel_1a_c::NUMBER ) + { + tokzr.get_token( type, token ); + str += token; + continue; + } + break; + } + } + void trim() + { + int type; + std::string token; + while ( 1 ) + { + type = tokzr.peek_type(); + if ( type == tokenizer::kernel_1a_c::WHITE_SPACE || type == tokenizer::kernel_1a_c::END_OF_LINE ) + { + tokzr.get_token( type, token ); + continue; + } + break; + } + } + bool required_keyword( int kw ) + { + int type; + std::string token; + while ( 1 ) + { + tokzr.get_token( type, token ); + if ( type == tokenizer::kernel_1a_c::WHITE_SPACE || type == tokenizer::kernel_1a_c::END_OF_LINE ) + continue; + if ( type != tokenizer::kernel_1a_c::IDENTIFIER || keyword_map.is_in_domain( token ) == false || ( keyword_map[token] & kw ) == 0 ) + return false; + break; + } + return true; + } + int find_keywords( int find ) + { + int type; + std::string token; + while ( 1 ) + { + tokzr.get_token( type, token ); + if ( type == tokenizer::kernel_1a_c::END_OF_FILE ) + return NO_KEYWORD; + if ( type == tokenizer::kernel_1a_c::IDENTIFIER && keyword_map.is_in_domain( token ) == true ) + { + int kw = keyword_map[token]; + if ( kw & find ) + return kw; + } + } + return true; + } + + }; + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// bdf_font functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + bdf_font::bdf_font( + long default_char_ + ) : + default_char(0), + is_initialized( false ), + right_overflow_( 0 ), + has_global_width( false ), + specified_default_char( default_char_ ) + { + // make sure gl contains at least one letter + gl.resize(1); + } + +// ---------------------------------------------------------------------------------------- + + void bdf_font::adjust_metrics( + ) + { + if ( is_initialized == false ) + return; + // set starting values for fbb + if ( gl[default_char].num_of_points() > 0 ) + { + letter& g = gl[default_char]; + fbb.set_top( g[0].y ); + fbb.set_bottom( g[0].y ); + fbb.set_left( g[0].x ); + fbb.set_right( g[0].x ); + } + else + { + // ok, the default char was a space + // let's just choose some safe arbitrary values then... + fbb.set_top( 10000 ); + fbb.set_bottom( -10000 ); + fbb.set_left( 10000 ); + fbb.set_right( -10000 ); + } + right_overflow_ = 0; + for ( unichar n = 0; n < gl.size(); n++ ) + { + letter& g = gl[n]; + unsigned short nr_pts = g.num_of_points(); + for ( unsigned short k = 0;k < nr_pts;k++ ) + { + fbb.set_top( std::min( fbb.top(), (long)g[k].y ) ); + fbb.set_left( std::min( fbb.left(), (long)g[k].x ) ); + fbb.set_bottom( std::max( fbb.bottom(), (long)g[k].y ) ); + fbb.set_right( std::max( fbb.right(), (long)g[k].x ) ); + right_overflow_ = std::max( right_overflow_, (unsigned long)(g[k].x - g.width()) ); // superfluous? + } + } + } + +// ---------------------------------------------------------------------------------------- + + long bdf_font:: + read_bdf_file( + std::istream& in, + unichar max_enc, + unichar min_enc + ) + { + using namespace bdf_font_helpers; + + bdf_parser parser( in ); + bdf_parser::header_info hinfo; + bdf_parser::char_info cinfo; + + gl.resize(max_enc+1); + hinfo.default_char = - 1; + if ( is_initialized == false || static_cast(in.tellg()) == std::ios::beg ) + { + if ( parser.parse_header( hinfo ) == false ) + return 0; // parse_error: invalid or missing header + } + else + { + // not start of file, so use values from previous read. + hinfo.has_global_dw = has_global_width; + hinfo.dwx0 = global_width; + } + int res; + unichar nr_letters_added = 0; + unsigned width; + for ( unichar n = min_enc; n <= max_enc; n++ ) + { + if ( in.eof() ) + break; + long pos = in.tellg(); + res = parser.parse_glyph( cinfo, n ); + if ( res < 0 ) + return 0; // parse_error + if ( res == 0 ) + continue; + if ( n > max_enc ) + { + in.seekg( pos ); + break; + } + + if ( cinfo.has_dw == false ) + { + if ( hinfo.has_global_dw == false ) + return 0; // neither width info for the glyph, nor for the font as a whole (monospace). + width = hinfo.dwx0; + } + else + width = cinfo.dwx0; + + + if ( bitmap_to_letter( cinfo.bitmap, n, width, cinfo.BBxoff0x, cinfo.BByoff0y ) == false ) + return 0; + nr_letters_added++; + + if ( is_initialized == false ) + { + // Bonding rectangle for the font. + fbb.set_top( -( hinfo.Yoff + hinfo.FBBy - 1 ) ); + fbb.set_bottom( -hinfo.Yoff ); + fbb.set_left( hinfo.Xoff ); + fbb.set_right( hinfo.Xoff + hinfo.FBBx - 1 ); + // We need to compute this after all the glyphs are loaded. + right_overflow_ = 0; + // set this to something valid now, just in case. + default_char = n; + // Save any global width in case we later read from the same file. + has_global_width = hinfo.has_global_dw; + if ( has_global_width ) + global_width = hinfo.dwx0; + // dont override value specified in the constructor with value specified in the file + if ( specified_default_char < 0 && hinfo.default_char >= 0 ) + specified_default_char = hinfo.default_char; + + is_initialized = true; + } + } + if ( is_initialized == false ) + return 0; // Not a single glyph was found within the specified range. + + if ( specified_default_char >= 0 ) + default_char = specified_default_char; + // no default char specified, try find something sane. + else + default_char = 0; + + return nr_letters_added; + } + +// ---------------------------------------------------------------------------------------- + + bool bdf_font:: + bitmap_to_letter( + array2d& bitmap, + unichar enc, + unsigned long width, + int x_offset, + int y_offset + ) + { + unsigned nr_points = 0; + bitmap.reset(); + while ( bitmap.move_next() ) + { + unsigned char ch = bitmap.element(); + if ( ch > '9' ) + ch -= 'A' - '9' - 1; + ch -= '0'; + if ( ch > 0xF ) + return false; // parse error: invalid hex digit + bitmap.element() = ch; + if ( ch & 8 ) + nr_points++; + if ( ch & 4 ) + nr_points++; + if ( ch & 2 ) + nr_points++; + if ( ch & 1 ) + nr_points++; + } + + letter( width, nr_points ).swap(gl[enc]); + + unsigned index = 0; + for ( int r = 0;r < bitmap.nr();r++ ) + { + for ( int c = 0;c < bitmap.nc();c++ ) + { + int x = x_offset + c * 4; + int y = -( y_offset + bitmap.nr() - r - 1 ); + char ch = bitmap[r][c]; + letter& glyph = gl[enc]; + if ( ch & 8 ) + { + glyph[index] = letter::point( x, y ); + right_overflow_ = std::max( right_overflow_, x - width ); + index++; + } + if ( ch & 4 ) + { + glyph[index] = letter::point( x + 1, y ); + right_overflow_ = std::max( right_overflow_, x + 1 - width ); + index++; + } + if ( ch & 2 ) + { + glyph[index] = letter::point( x + 2, y ); + right_overflow_ = std::max( right_overflow_, x + 2 - width ); + index++; + } + if ( ch & 1 ) + { + glyph[index] = letter::point( x + 3, y ); + right_overflow_ = std::max( right_overflow_, x + 3 - width ); + index++; + } + } + } + return true; + } + +// ---------------------------------------------------------------------------------------- + + const std::shared_ptr get_native_font ( + ) + { + return nativefont::native_font::get_font(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FONTs_CPP_ + diff --git a/ml/dlib/dlib/gui_widgets/fonts.h b/ml/dlib/dlib/gui_widgets/fonts.h new file mode 100644 index 000000000..5d3181aaa --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/fonts.h @@ -0,0 +1,628 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), and Nils Labugt, Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. + +#ifndef DLIB_FONTs_ +#define DLIB_FONTs_ + +#include +#include + +#include "fonts_abstract.h" +#include "../gui_core.h" +#include "../algs.h" +#include "../serialize.h" +#include "../unicode.h" +#include "../array.h" +#include "../array2d.h" +#include "../threads.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class letter + { + /*! + INITIAL VALUE + - defined by constructor + + CONVENTION + - if (points != 0) then + - points == an array of count point structs + - w == width() + - count == num_of_points() + !*/ + public: + struct point + { + point (){} + + point ( + signed char x_, + signed char y_ + ) : + x(x_), + y(y_) + {} + + signed char x; + signed char y; + }; + + letter ( + ) : + points(0), + w(0), + count(0) + {} + + letter ( + unsigned short width_, + unsigned short point_count + ) : + points(new point[point_count]), + w(width_), + count(point_count) + {} + + ~letter( + ) + { + if (points) + delete [] points; + } + + unsigned short width ( + ) const { return w; } + + unsigned short num_of_points ( + ) const { return count;} + + point& operator[] ( + unsigned short i + ) + { + DLIB_ASSERT (i < num_of_points(), + "\tvoid letter::operator[]()" + << "\n\ti: " << i + << "\n\tnum_of_points(): " << num_of_points() ); + return points[i]; + } + + const point& operator[] ( + unsigned short i + ) const + { + DLIB_ASSERT (i < num_of_points(), + "\tvoid letter::operator[]()" + << "\n\ti: " << i + << "\n\tnum_of_points(): " << num_of_points() ); + return points[i]; + } + + friend void serialize ( + const letter& item, + std::ostream& out + ); + + friend void deserialize ( + letter& item, + std::istream& in + ); + + void swap ( + letter& item + ) + { + exchange(points, item.points); + exchange(w, item.w); + exchange(count, item.count); + } + + private: + // restricted functions + letter(letter&); // copy constructor + letter& operator=(letter&); // assignment operator + + point* points; + unsigned short w; + unsigned short count; + }; + + inline void swap ( + letter& a, + letter& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + class font + { + public: + virtual ~font() {} + + virtual bool has_character ( + unichar ch + )const=0; + bool has_character(char ch) const { return this->has_character(zero_extend_cast(ch)); } + bool has_character(wchar_t ch) const { return this->has_character(zero_extend_cast(ch)); } + + const letter& operator[] (char ch) const { return (*this)[zero_extend_cast(ch)]; }; + const letter& operator[] (wchar_t ch)const { return (*this)[zero_extend_cast(ch)]; }; + + virtual const letter& operator[] ( + unichar ch + )const=0; + + virtual unsigned long height ( + ) const = 0; + + virtual unsigned long ascender ( + ) const = 0; + + virtual unsigned long left_overflow ( + ) const = 0; + + virtual unsigned long right_overflow ( + ) const = 0; + + // ------------------------------------------------------------------------------------ + + template + void compute_size ( + const std::basic_string& str, + unsigned long& width, + unsigned long& height, + typename std::basic_string::size_type first = 0, + typename std::basic_string::size_type last = (std::basic_string::npos) + ) const + { + typedef std::basic_string string; + DLIB_ASSERT ( (last == string::npos) || (first <= last && last < str.size()) , + "\tvoid font::compute_size()" + << "\n\tlast == string::npos: " << ((last == string::npos)?"true":"false") + << "\n\tfirst: " << (unsigned long)first + << "\n\tlast: " << (unsigned long)last + << "\n\tstr.size(): " << (unsigned long)str.size() ); + + unsigned long line_width = 0; + unsigned long newlines = 0; + width = 0; + height = 0; + + if (str.size()) + { + if (last == string::npos) + last = str.size()-1; + const font& f = *this; + + for (typename string::size_type i = first; i <= last; ++i) + { + // ignore '\r' characters + if (str[i] == '\r') + continue; + + if (str[i] == '\n') + { + ++newlines; + width = std::max(width,line_width); + line_width = 0; + } + else + { + if (is_combining_char(str[i]) == false) + line_width += f[str[i]].width(); + } + } + width = std::max(width,line_width); + + height = (newlines+1)*f.height(); + width += f.left_overflow() + f.right_overflow(); + } + } + + // ------------------------------------------------------------------------------------ + + template + void draw_string ( + const canvas& c, + const rectangle& rect, + const std::basic_string& str, + const pixel_type& color, + typename std::basic_string::size_type first = 0, + typename std::basic_string::size_type last = (std::basic_string::npos), + const rectangle area_ = rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max()) + ) const + { + typedef std::basic_string string; + DLIB_ASSERT ( (last == string::npos) || (first <= last && last < str.size()) , + "\tvoid font::draw_string()" + << "\n\tlast == string::npos: " << ((last == string::npos)?"true":"false") + << "\n\tfirst: " << (unsigned long)first + << "\n\tlast: " << (unsigned long)last + << "\n\tstr.size(): " << (unsigned long)str.size() ); + + rectangle area = rect.intersect(c).intersect(area_); + if (area.is_empty() || str.size() == 0) + return; + + if (last == string::npos) + last = str.size()-1; + + const font& f = *this; + + long y_offset = rect.top() + f.ascender() - 1; + + long pos = rect.left()+f.left_overflow(); + for (typename string::size_type i = first; i <= last; ++i) + { + // ignore the '\r' character + if (str[i] == '\r') + continue; + + // A combining character should be applied to the previous character, and we + // therefore make one step back. If a combining comes right after a newline, + // then there must be some kind of error in the string, and we don't combine. + if(is_combining_char(str[i]) && + pos > rect.left() + static_cast(f.left_overflow())) + { + pos -= f[str[i]].width(); + } + + if (str[i] == '\n') + { + y_offset += f.height(); + pos = rect.left()+f.left_overflow(); + continue; + } + + // only look at letters in the intersection area + if (area.bottom() + static_cast(f.height()) < y_offset) + { + // the string is now below our rectangle so we are done + break; + } + else if (area.left() > pos - static_cast(f.left_overflow()) && + pos + static_cast(f[str[i]].width() + f.right_overflow()) < area.left() ) + { + pos += f[str[i]].width(); + continue; + } + else if (area.right() + static_cast(f.right_overflow()) < pos) + { + // keep looking because there might be a '\n' in the string that + // will wrap us around and put us back into our rectangle. + continue; + } + + // at this point in the loop we know that f[str[i]] overlaps + // horizontally with the intersection rectangle area. + + const letter& l = f[str[i]]; + for (unsigned short i = 0; i < l.num_of_points(); ++i) + { + const long x = l[i].x + pos; + const long y = l[i].y + y_offset; + // draw each pixel of the letter if it is inside the intersection + // rectangle + if (area.contains(x,y)) + { + assign_pixel(c[y-c.top()][x-c.left()], color); + } + } + + pos += l.width(); + } + } + template + void draw_string ( + const canvas& c, + const rectangle& rect, + const std::basic_string& str + ) const + { + draw_string(c,rect, str, 0, 0, (std::basic_string::npos), + rectangle(std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::max(), std::numeric_limits::max())); + } + + // ------------------------------------------------------------------------------------ + + template + const rectangle compute_cursor_rect ( + const rectangle& rect, + const std::basic_string& str, + unsigned long index, + typename std::basic_string::size_type first = 0, + typename std::basic_string::size_type last = (std::basic_string::npos) + ) const + { + typedef std::basic_string string; + DLIB_ASSERT ( (last == string::npos) || (first <= last && last < str.size()) , + "\trectangle font::compute_cursor_rect()" + << "\n\tlast == string::npos: " << ((last == string::npos)?"true":"false") + << "\n\tfirst: " << (unsigned long)first + << "\n\tlast: " << (unsigned long)last + << "\n\tindex: " << index + << "\n\tstr.size(): " << (unsigned long)str.size() ); + + const font& f = *this; + + if (last == string::npos) + last = str.size()-1; + + long x = f.left_overflow(); + long y = 0; + int count = 0; + + if (str.size() != 0) + { + for (typename string::size_type i = first; i <= last && i < index; ++i) + { + ++count; + if (str[i] == '\n') + { + x = f.left_overflow(); + y += f.height(); + count = 0; + } + else if (is_combining_char(str[i]) == false && + str[i] != '\r') + { + x += f[str[i]].width(); + } + } + } + + x += rect.left(); + y += rect.top(); + + // if the cursor is at the start of a line then back it up one pixel + if (count == 0) + --x; + + return rectangle(x,y,x,y+f.height()-1); + } + + // ------------------------------------------------------------------------------------ + + template + unsigned long compute_cursor_pos ( + const rectangle& rect, + const std::basic_string& str, + long x, + long y, + typename std::basic_string::size_type first = 0, + typename std::basic_string::size_type last = (std::basic_string::npos) + ) const + { + typedef std::basic_string string; + DLIB_ASSERT ( (last == string::npos) || (first <= last && last < str.size()) , + "\tunsigned long font::compute_cursor_pos()" + << "\n\tlast == string::npos: " << ((last == string::npos)?"true":"false") + << "\n\tfirst: " << (unsigned long)first + << "\n\tlast: " << (unsigned long)last + << "\n\tx: " << x + << "\n\ty: " << y + << "\n\tstr.size(): " << (unsigned long)str.size() ); + const font& f = *this; + + + if (str.size() == 0) + return 0; + else if (first >= str.size()) + return static_cast(str.size()); + + y -= rect.top(); + x -= rect.left(); + if (y < 0) + y = 0; + if (x < 0) + x = 0; + + if (last == string::npos) + last = str.size()-1; + + + // first figure out what line we are on + typename string::size_type pos = first; + long line = 0; + while (static_cast(y) >= f.height()) + { + ++line; + y -= f.height(); + } + + // find the start of the given line + for (typename string::size_type i = first; i <= last && line != 0; ++i) + { + if (str[i] == '\n') + { + --line; + pos = i + 1; + } + } + + + // now str[pos] == the first character of the start of the line + // that contains the cursor. + const typename string::size_type start_of_line = pos; + + + long cur_x = f.left_overflow(); + // set the current cursor position to where the mouse clicked + while (pos <= last) + { + if (x <= cur_x || str[pos] == '\n') + break; + + if (is_combining_char(str[pos]) == false && + str[pos] != '\r') + { + cur_x += f[str[pos]].width(); + } + ++pos; + } + + if (x <= cur_x) + { + if (pos != start_of_line) + { + // we might actually be closer to the previous character + // so check for that and if so then jump us back one. + const long width = f[str[pos-1]].width(); + if (x < cur_x - width/2) + --pos; + } + } + return static_cast(pos); + } + + }; + +// ---------------------------------------------------------------------------------------- + + const std::shared_ptr get_native_font (); + +// ---------------------------------------------------------------------------------------- + + class default_font : public font + { + letter* l; + + + default_font( + ); + default_font(default_font&); // copy constructor + default_font& operator=(default_font&); // assignment operator + + + + public: + static const std::shared_ptr& get_font ( + ) + { + static mutex m; + static std::shared_ptr f; + auto_mutex M(m); + if (f.get() == 0) + f.reset(new default_font); + + return f; + } + + ~default_font( + ) + { + delete [] l; + } + + unsigned long height ( + ) const { return 16; } + + unsigned long ascender ( + ) const { return 12; } + + unsigned long left_overflow ( + ) const { return 1; } + + unsigned long right_overflow ( + ) const { return 2; } + + bool has_character ( + unichar ch + )const + { + if (ch < 256 && (l[ch].width() != 0 || l[ch].num_of_points() != 0)) + return true; + else + return false; + } + + const letter& operator[] ( + unichar ch + ) const + { + if(ch < 256) + return l[ch]; + return l[0]; // just return one of the empty characters in this case + } + }; + + +// ---------------------------------------------------------------------------------------- + + class bdf_font : public font + { + + public: + bdf_font( long default_char_ = -1 ); + + long read_bdf_file( std::istream& in, unichar max_enc, unichar min_enc = 0 ); + unsigned long height() const + { + return fbb.height(); + } + unsigned long ascender() const + { + return std::max( 0L, 1 - fbb.top() ); + } + unsigned long left_overflow() const + { + return std::max( 0L, -fbb.left() ); + } + unsigned long right_overflow() const + { + return right_overflow_; + } + const letter& operator[] ( unichar uch ) const + { + if ( !has_character(uch) ) + { + return gl[default_char]; + } + return gl[uch]; + } + + bool has_character ( + unichar ch + )const + { + if (ch < gl.size() && (gl[ch].width() != 0 || gl[ch].num_of_points() != 0)) + return true; + else + return false; + } + + void adjust_metrics(); + private: + + bool bitmap_to_letter( array2d& bitmap, unichar enc, unsigned long width, int x_offset, int y_offset ); + + array gl; + unichar default_char; // if (is_intialized == true), then this MUST be an actual glyph + bool is_initialized; + rectangle fbb; + unsigned long right_overflow_; + + unsigned global_width; + bool has_global_width; + long specified_default_char; + + bdf_font( bdf_font& ); // copy constructor + bdf_font& operator=( bdf_font& ); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "fonts.cpp" +#endif + +#endif // DLIB_FONTs_ + diff --git a/ml/dlib/dlib/gui_widgets/fonts_abstract.h b/ml/dlib/dlib/gui_widgets/fonts_abstract.h new file mode 100644 index 000000000..df194bccd --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/fonts_abstract.h @@ -0,0 +1,492 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Nils Labugt, Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FONTs_ABSTRACT_ +#ifdef DLIB_FONTs_ABSTRACT_ + +#include "../gui_core.h" +#include +#include "../serialize.h" +#include "../unicode.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class letter + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a letter in a font. It tells you the nominal + width of the letter and which pixels form the letter. + + THREAD SAFETY + const versions of this object are thread safe but if you are going to + be modifying it then you must serialize access to it. + !*/ + public: + struct point + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents one of the pixels of a letter. + + The origin (i.e. (0,0)) of the coordinate plane is at the left + side of the letter's baseline. Also note that y is negative when + above the baseline and positive below (it is zero on the baseline + itself). + + The x value is positive going to the right and negative to the left. + The meaning of a negative x value is that any points with a negative + x value will overlap with the preceding letter. + !*/ + + point ( + ); + /*! + ensures + - This constructor does nothing. The value of x and y + are undefined after its execution. + !*/ + + point ( + signed char x_, + signed char y_ + ); + /*! + ensures + - #x == x_ + - #y == y_ + !*/ + + + signed char x; + signed char y; + }; + + // --------------------------------- + + letter ( + ); + /*! + ensures + - #width() == 0 + - #num_of_points() == 0 + !*/ + + letter ( + unsigned short width_, + unsigned short point_count + ); + /*! + ensures + - #width() == width_ + - #num_of_points() == point_count + !*/ + + ~letter( + ); + /*! + ensures + - any resources used by *this have been freed + !*/ + + const unsigned short width ( + ) const; + /*! + ensures + - returns the width reserved for this letter in pixels. This is the + number of pixels that are reserved for this letter between adjoining + letters. It isn't necessarily the width of the actual letter itself. + (for example, you can make a letter with a width less than how wide it + actually is so that it overlaps with its neighbor letters.) + !*/ + + const unsigned short num_of_points ( + ) const; + /*! + ensures + - returns the number of pixels that make up this letter. + !*/ + + point& operator[] ( + unsigned short i + ); + /*! + requires + - i < num_of_points() + ensures + - returns a non-const reference to the ith point in this letter. + !*/ + + const point& operator[] ( + unsigned short i + ) const; + /*! + requires + - i < num_of_points() + ensures + - returns a const reference to the ith point in this letter. + !*/ + + void swap ( + letter& item + ); + /*! + ensures + - swaps *this with item + !*/ + + private: + + // restricted functions + letter(letter&); // copy constructor + letter& operator=(letter&); // assignment operator + }; + + inline void swap ( + letter& a, + letter& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const letter& item, + std::ostream& out + ); + /*! + provides serialization support for letter objects + !*/ + + void deserialize ( + letter& item, + std::istream& in + ); + /*! + provides deserialization support for letter objects + !*/ + +// ---------------------------------------------------------------------------------------- + + class font + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines an interface for a font type. It provides metrics + for the font and functions to help you draw strings on a canvas object. + + THREAD SAFETY + All the functions in this class are thread safe. + !*/ + + public: + + virtual bool has_character ( + unichar ch + )const=0; + /*! + ensures + - if (this font has a glyph for the given character) then + - returns true + - else + - returns false + !*/ + bool has_character(char ch) const { return this->has_character(zero_extend_cast(ch)); } + bool has_character(wchar_t ch) const { return this->has_character(zero_extend_cast(ch)); } + /* Cast char and wchar_t to unichar correctly when char or wchar_t is a signed type */ + + virtual const letter& operator[] ( + unichar ch + )const=0; + /*! + ensures + - if (has_character(ch) == true) then + - returns a letter object that tells you how to draw this character. + - else + - returns some default glyph for characters that aren't in this font. + !*/ + const letter& operator[] (char ch) const { return (*this)[zero_extend_cast(ch)]; }; + const letter& operator[] (wchar_t ch) const { return (*this)[zero_extend_cast(ch)]; }; + /* Cast char and wchar_t to unichar correctly when char or wchar_t is a signed type */ + + virtual const unsigned long height ( + ) const = 0; + /*! + ensures + - returns the height in pixels of the tallest letter in the font + !*/ + + virtual const unsigned long ascender ( + ) const = 0; + /*! + ensures + - returns the height() minus the number of pixels below the baseline used + by the letter that hangs the lowest below the baseline. + !*/ + + virtual const unsigned long left_overflow ( + ) const = 0; + /*! + ensures + - returns how far outside and to the left of its width a letter + from this font may set pixels. (i.e. how many extra pixels to its + left may a font use) + !*/ + + virtual const unsigned long right_overflow ( + ) const = 0; + /*! + ensures + - returns how far outside and to the right of its width a letter + from this font may set pixels. (i.e. how many extra pixels to its + right may a font use) + !*/ + + template + void compute_size ( + const std::basic_string& str, + unsigned long& width, + unsigned long& height, + typename std::basic_string::size_type first = 0, + typename std::basic_string::size_type last = std::basic_string::npos + ) const; + /*! + requires + - if (last != std::basic_string::npos) then + - first <= last + - last < str.size() + ensures + - all characters in str with an index < first are ignored by this + function. + - if (last != std::basic_string::npos) then + - all characters in str with an index > last are ignored by + this function. + - if (str.size() == 0) then + - #width == 0 + - #height == 0 + - else + - #width == sum of the widths of the characters in the widest + line in str + left_overflow() + right_overflow(). + - #height == (count(str.begin(),str.end(),'\n')+1)*height() + !*/ + + template + void draw_string ( + const canvas& c, + const rectangle& rect, + const std::basic_string& str, + const pixel_type& color = rgb_pixel(0,0,0), + typename std::basic_string::size_type first = 0, + typename std::basic_string::size_type last = std::basic_string::npos, + const rectangle area = rectangle(-infinity,-infinity,infinity,infinity) + ) const; + /*! + requires + - if (last != std::basic_string::npos) then + - first <= last + - last < str.size() + ensures + - all characters in str with an index < first are ignored by this + function. + - if (last != std::basic_string::npos) then + - all characters in str with an index > last are ignored by + this function. + - if (str.size() == 0) then + - does nothing + - else + - draws str on the given canvas at the position defined by rect. + Also uses the given pixel colors for the font color. + - If the string is too big to fit in rect then the right and + bottom sides of it will be clipped to make it fit. + - only the part of the string that is contained inside the area + rectangle will be drawn + !*/ + + template + const rectangle compute_cursor_rect ( + const rectangle& rect, + const std::basic_string& str, + unsigned long index, + typename std::basic_string::size_type first = 0, + typename std::basic_string::size_type last = std::basic_string::npos + ) const; + /*! + requires + - if (last != std::basic_string::npos) then + - first <= last + - last < str.size() + ensures + - the returned rectangle has a width of 1 and a + height of this->height(). + - computes the location of the cursor that would sit just before + the character str[index] if str were drawn on the screen by + draw_string(rect,str,...,first,last). The cursor location is + returned in the form of a rectangle. + - if (index < first) then + - the returned cursor will be just before the character str[first]. + - if (last != std::basic_string::npos && index > last) then + - the returned cursor will be just after the character str[last] + - if (str.size() == 0) then + - the returned cursor will be just at the start of the rectangle where + str would be drawn if it wasn't empty. + - if (index > str.size()-1) then + - the returned cursor will be just after the character str[str.size()-1] + !*/ + + template + const unsigned long compute_cursor_pos ( + const rectangle& rect, + const std::basic_string& str, + long x, + long y, + typename std::basic_string::size_type first = 0, + typename std::basic_string::size_type last = std::basic_string::npos + ) const; + /*! + requires + - if (last != std::basic_string::npos) then + - first <= last + - last < str.size() + ensures + - returns a number idx that has the following properties: + - if (first < str.size()) then + - first <= idx + - else + - idx == str.size() + - if (last != std::basic_string::npos) then + - idx <= last + 1 + - compute_cursor_rect(rect,str,idx,first,last) == the cursor + position that is closest to the pixel (x,y) + !*/ + + + private: + + // restricted functions + font(font&); // copy constructor + font& operator=(font&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class default_font : public font + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the Helvetica 12 point font. + + THREAD SAFETY + It is safe to call get_font() and access the returned font from any + thread and no synchronization is needed as long as it is called + after the main() function has been entered. + !*/ + + public: + static const shared_ptr_thread_safe get_font( + ); + /*! + ensures + - returns an instance of this font. + throws + - std::bad_alloc + This exception is thrown if there is a problem gathering the needed + memory for the font object. + !*/ + + private: + + // restricted functions + default_font(); // normal constructor + default_font(default_font&); // copy constructor + default_font& operator=(default_font&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class bdf_font : public font + { + + /*! + WHAT THIS OBJECT REPRESENTS + This is a font object that is capable of loading of loading BDF (Glyph + Bitmap Distribution Format) font files. + + THREAD SAFETY + If you only access this object via the functions in the parent class font + then this object is thread safe. But if you need to call any of the + functions introduced in this derived class then you need to serialize + access to this object while you call these functions. + !*/ + + public: + + bdf_font( + long default_char = -1 + ); + /*! + ensures + - for all x: + - #has_character(x) == false + (i.e. this font starts out empty. You have to call read_bdf_file() + to load it with data) + - if (default_char == -1) then + - the letter returned by (*this)[ch] for values of + ch where has_character(ch) == false will be the + default glyph defined in the bdf file. + - else + - the letter returned by (*this)[ch] for values of + ch where has_character(ch) == false will be the + letter (*this)[default_char]. + !*/ + + long read_bdf_file( + std::istream& in, + unichar max_enc, + unichar min_enc = 0 + ); + /*! + ensures + - attempts to read the font data from the given input stream into + *this. The input stream is expected to contain a valid BDF file. + - reads in characters with encodings in the range min_enc to max_enc + into this font. All characters in the font file outside this range + are ignored. + - returns the number of characters loaded into this font from the + given input stream. + !*/ + + void adjust_metrics(); + /*! + ensures + - Computes metrics based on actual glyphs loaded, instead of using + the values in the bdf file header. (May be useful if loading glyphs + from more than one file or a small part of a file.) + !*/ + + private: + + bdf_font( bdf_font& ); // copy constructor + bdf_font& operator=( bdf_font& ); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + + const shared_ptr_thread_safe get_native_font( + ); + /*! + ensures + - returns a font object that uses the local font + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FONTs_ABSTRACT_ + diff --git a/ml/dlib/dlib/gui_widgets/nativefont.h b/ml/dlib/dlib/gui_widgets/nativefont.h new file mode 100644 index 000000000..2de0edfaa --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/nativefont.h @@ -0,0 +1,612 @@ +// Copyright (C) 2006 Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IGG_FONT_RENDERER_H_ +#define DLIB_IGG_FONT_RENDERER_H_ +#include "../platform.h" + + +#include "../gui_widgets.h" +#include "../unicode.h" +#include "../uintn.h" + +#include +#include + +#include +#include +#include +#include + +#if defined(WIN32) +#include +#include +#elif defined(POSIX) +#include +#include +#include +#include +#include +#include +#endif + +namespace nativefont +{ +// ---------------------------------------------------------------------------------------- + + namespace font_renderer + { + typedef dlib::uint8 byte; + + +#ifdef WIN32 + template struct input2native_trait{ + }; + template <> struct input2native_trait{ + typedef char type_t; + }; + template <> struct input2native_trait{ + typedef wchar_t type_t; + }; + template <> struct input2native_trait{ + typedef wchar_t type_t; + }; +#endif + // T : N : sizeof_source_type + template struct size2inner_trait{ + }; + template <> struct size2inner_trait<1>{ + typedef char type_t; + }; + template <> struct size2inner_trait<2>{ + typedef dlib::uint16 type_t; + }; + template <> struct size2inner_trait<4>{ + typedef dlib::unichar type_t; + }; + + +// ---------------------------------------------------------------------------------------- + + template struct create_helper{ }; + template <> struct create_helper<1>{ + typedef char type_t; + type_t *istr; + int len; + create_helper(char *str){ + len = (int)strlen(str); + istr = str; + } + ~create_helper(){} + }; + template <> struct create_helper<2>{ + typedef wchar_t type_t; + type_t *istr; + bool allocated; + int len; + create_helper(wchar_t *str){ + allocated = false; + len = (int)wcslen(str); + istr = str; + } + create_helper(dlib::unichar *str){ + allocated = true; + len = 0; + int unicount = 0; + dlib::unichar *p = str; + while(*p){ + if (*p > 0xffff){ + len += 2; + }else{ + len++; + } + unicount++; + p++; + } + istr = new wchar_t[len+1]; + for (int i = 0, wi = 0; i < unicount; ++i){ + dlib::unichar high, low; + if (str[i] > 0xffff){ + dlib::unichar_to_surrogate_pair(str[i], high, low); + istr[wi] = (wchar_t)high, istr[wi+1] = (wchar_t)low; + wi += 2; + }else{ + istr[wi] = (wchar_t)str[i]; + wi += 1; + } + } + istr[len] = L'\0'; + } + + ~create_helper(){ + if (allocated) delete[] istr; + } + }; + template <> struct create_helper<4>{ + typedef wchar_t type_t; + type_t *istr; + int len; + create_helper(dlib::unichar *str){ + len = (int)wcslen((wchar_t *)str); + istr = (type_t *)str; + } + ~create_helper(){} + }; + +// ---------------------------------------------------------------------------------------- + + class font_renderer{ + public: + + struct rgb_type{ + byte r, g, b; + rgb_type() : r(0), g(0), b(0){}; + rgb_type(byte r_, byte g_, byte b_) : r(r_), g(g_), b(b_){}; + }; + private: + + byte *image; + int width, height; + void destroy(){ + width = height = 0; + delete image; + image = 0; + } + struct vals_internal{ + int width, height; +#ifdef WIN32 + COLORREF rgb2RGB(rgb_type &rgb){ + return RGB(rgb.r, rgb.g, rgb.b); + } + HBITMAP hBmp, hBmpOld; + HDC hDCBmp; + BYTE *pixelint; + HFONT hFont, hFontOld; + HBRUSH hBrush; + int pix_width_prev, pix_height_prev; + bool first; + int ascender, descender; + int height_prev; + char attribute_prev; + + template void create(T *str, int height_want, bool italic, bool bold, bool fixed, rgb_type &background, rgb_type &foreground){ + struct inner{ + inline static BOOL GetTextExtentPoint32(HDC hDC, LPCSTR str, int len, LPSIZE lpsize){ + return ::GetTextExtentPoint32A(hDC, str, len, lpsize); + } + inline static BOOL GetTextExtentPoint32(HDC hDC, LPCWSTR str, int len, LPSIZE lpsize){ + return ::GetTextExtentPoint32W(hDC, str, len, lpsize); + } + inline static BOOL TextOut(HDC hDC, int nxstart, int nystart, LPCSTR str, int cbstr){ + return ::TextOutA(hDC, nxstart, nystart, str, cbstr); + } + inline static BOOL TextOut(HDC hDC, int nxstart, int nystart, LPCWSTR str, int cbstr){ + return ::TextOutW(hDC, nxstart, nystart, str, cbstr); + } + }; + + create_helper::type_t)> ch(str); + + if (hDCBmp == NULL){ + HWND hWnd = GetDesktopWindow(); + HDC hDC = GetDC(hWnd); + hDCBmp = CreateCompatibleDC(hDC); + ReleaseDC(hWnd, hDC); + } + SetTextColor(hDCBmp, rgb2RGB(foreground)); + SetBkColor(hDCBmp, rgb2RGB(background)); + + char attribute = (italic ? 1 : 0) | (bold ? 2 : 0) | (fixed ? 4 : 0); + if (!hFont || height_prev != height || attribute != attribute_prev){ + attribute_prev = attribute; + height_prev = height_want; + if (hFont){ + SelectObject(hDCBmp, hFontOld); + DeleteObject(hFont); + } + hFont = CreateFont(height_want, 0, 0, 0, bold ? FW_BOLD : FW_DONTCARE, italic ? TRUE : FALSE, + FALSE, FALSE, DEFAULT_CHARSET, OUT_DEFAULT_PRECIS, CLIP_DEFAULT_PRECIS, DEFAULT_QUALITY, + fixed ? (FIXED_PITCH | FF_DONTCARE) : (VARIABLE_PITCH | FF_DONTCARE), NULL); + hFontOld = (HFONT)SelectObject(hDCBmp, hFont); + } + + { + SIZE sz; + inner::GetTextExtentPoint32(hDCBmp, ch.istr, ch.len, &sz); + width = ((sz.cx + 3) / 4) * 4; + height = sz.cy; + } + + if (pix_width_prev < width || pix_height_prev < height){ + if (hBmp){ + SelectObject(hDCBmp, hBmpOld); + DeleteObject(hBmp); + } + pix_width_prev = width * 2; + pix_height_prev = height * 2; + BITMAPINFO bi; + ZeroMemory(&bi, sizeof(bi)); + bi.bmiHeader.biSize = sizeof(BITMAPINFOHEADER); + bi.bmiHeader.biBitCount = 24; + bi.bmiHeader.biPlanes = 1; + bi.bmiHeader.biWidth = pix_width_prev; + bi.bmiHeader.biHeight = -pix_height_prev; + hBmp = CreateDIBSection(NULL, &bi, DIB_RGB_COLORS, (void **)&pixelint, NULL, 0); + hBmpOld = (HBITMAP)SelectObject(hDCBmp, hBmp); + } + + { + HBRUSH hBrush = CreateSolidBrush(rgb2RGB(background)); + RECT rc; + rc.left = rc.top = 0; + rc.right = pix_width_prev; + rc.bottom = pix_height_prev; + FillRect(hDCBmp, &rc, hBrush); + } + + inner::TextOut(hDCBmp, 0, 0, ch.istr, ch.len); + TEXTMETRICW tm; + GetTextMetricsW(hDCBmp,&tm); + ascender = tm.tmAscent; + descender = tm.tmDescent; + } + + template vals_internal(T *str, int height_want, bool italic = false, + bool bold = false, bool fixed = false, rgb_type background = rgb_type(), rgb_type foreground = rgb_type()){ + first = true; + hFont = NULL; + hDCBmp = 0; + hBmpOld = 0; + hBmp = 0; + hDCBmp = 0; + pixelint = 0; + pix_width_prev = pix_height_prev = 0; + height_prev = -1; + attribute_prev = 0; + create(str, height_want, italic, bold, fixed, background, foreground); + first = false; + } + + inline int get_ascender(){ + return ascender; + } + + inline int get_descender(){ + return descender; + } + + inline void get_pixel(int x, int y, byte &r, byte &g, byte &b){ + byte *p = pixelint + (y * pix_width_prev + x) * 3; + r = *(p+2), g = *(p+1), b = *p; + } + + void destroy(){ + SelectObject(hDCBmp, hBmpOld); + DeleteObject(hBmp); + SelectObject(hDCBmp, hFontOld); + DeleteObject(hFont); + DeleteDC(hDCBmp); + hFont = NULL; + hDCBmp = 0; + hBmpOld = 0; + hBmp = 0; + hDCBmp = 0; + pixelint = 0; + } + ~vals_internal(){ + destroy(); + } +#elif defined(POSIX) + XImage *ximg; + Display *d; + GC gc; + XFontSet fs; + Pixmap pix; + Colormap cmap; + int ascender, descender; + int pix_width_prev, pix_height_prev; + char fontset_prev[256]; + unsigned long rgb2color(rgb_type col, Display *d, Colormap &cmap){ + XColor xcol; + xcol.red = col.r * 257; + xcol.green = col.g * 257; + xcol.blue = col.b * 257; + XAllocColor(d, cmap, &xcol); + return xcol.pixel; + } + template void create(T *str, int height_want, bool italic, bool bold, bool fixed, rgb_type background, rgb_type foreground){ + struct inner{ + inline static int XTextExtents (XFontSet fs, char *str, int len, XRectangle *ink, XRectangle *logical){ + return XmbTextExtents(fs, str, len, ink, logical); + } + inline static int XTextExtents (XFontSet fs, wchar_t *str, int len, XRectangle *ink, XRectangle *logical){ + return XwcTextExtents(fs, str, len, ink, logical); + } + inline static void XDrawString(Display *d, Window w, XFontSet fs, GC gc, int x, int y, char *str, int num_bytes){ + XmbDrawString(d, w, fs, gc, x, y, str, num_bytes); + } + inline static void XDrawString(Display *d, Window w, XFontSet fs, GC gc, int x, int y, wchar_t *str, int num_bytes){ + XwcDrawString(d, w, fs, gc, x, y, str, num_bytes); + } + }; + create_helper ch((typename size2inner_trait::type_t *)str); + setlocale(LC_CTYPE, ""); + if (d == NULL){ + d = XOpenDisplay(NULL); + if (d == 0) + { + d = XOpenDisplay(":0.0"); + if (d == 0) + { + throw dlib::gui_error("Unable to connect to the X display."); + } + } + + cmap = DefaultColormap(d, DefaultScreen(d)); + } + char fontset[256]; + { + char *p = fontset; + p += sprintf(fontset, "-*-*-%s-%c-normal--%d-*-*-*-%c", + bold ? "bold" : "medium", italic ? 'i' : 'r', height_want, fixed ? 'c' : 'p'); + if (fixed){ + sprintf(p, ",-*-*-%s-%c-normal--%d-*-*-*-m", + bold ? "bold" : "medium", italic ? 'i' : 'r', height_want); + } + } + bool equal_font; + if (strcmp(fontset, fontset_prev) == 0){ + equal_font = true; + }else{ + equal_font = false; + strcpy(fontset_prev, fontset); + } + + char **mlist; + int mcount; + char *def_str; + if (!equal_font){ + if (fs){ + XFreeFontSet(d, fs); + } + fs = XCreateFontSet(d, fontset, &mlist, &mcount, &def_str); + if (fs == NULL) + throw dlib::gui_error("gui_error: XCreateFontSet() failure"); + + XFontSetExtents *extent; + extent = XExtentsOfFontSet(fs); + ascender = -extent->max_logical_extent.y; + descender = extent->max_logical_extent.height - ascender; + XFreeStringList(mlist); + } + XRectangle ink, logical; + inner::XTextExtents (fs, ch.istr, ch.len, &ink, &logical); + width = logical.width; + height = height_want; + + if (pix == None || pix_width_prev < width || pix_height_prev < height){ + if (pix != None){ + XFreeGC(d, gc); + XFreePixmap(d, pix); + } + pix_width_prev = width * 2; + pix_height_prev = height * 2; + pix = XCreatePixmap(d, DefaultRootWindow(d), pix_width_prev, pix_height_prev, XDefaultDepth(d, DefaultScreen(d))); + gc = XCreateGC(d, pix, 0, NULL); + } + + unsigned long backcolor = rgb2color(background, d, cmap); + XSetForeground(d, gc, backcolor); + XSetBackground(d, gc, backcolor); + XFillRectangle(d, pix, gc, 0, 0, width, height); + XSetForeground(d, gc, rgb2color(foreground, d, cmap)); + inner::XDrawString(d, pix, fs, gc, 0, ascender, ch.istr, ch.len); + + if (ximg) XDestroyImage(ximg); + ximg = XGetImage(d, pix, 0, 0, width, height, AllPlanes, ZPixmap ); + } + + template vals_internal(T *str, int height_want, bool italic = false, + bool bold = false, bool fixed = false, rgb_type background = rgb_type(), rgb_type foreground = rgb_type()){ + fontset_prev[0] = '\0'; + ximg = NULL; + d = NULL; + pix = None; + fs = NULL; + ascender = descender = -1; + pix_width_prev = pix_height_prev = -1; + create(str, height_want, italic, bold, fixed, background, foreground); + } + + inline int get_ascender(){ + return ascender; + } + + inline int get_descender(){ + return descender; + } + + std::map col2rgb; + rgb_type color2rgb(unsigned long color, Display *d, Colormap &cmap){ + if (col2rgb.count(color)){ + return col2rgb[color]; + }else{ + XColor xcol; + xcol.pixel = color; + XQueryColor(d, cmap, &xcol); + rgb_type rgb_((byte)(xcol.red/257), (byte)(xcol.green/257), (byte)(xcol.blue/257)); + col2rgb[color] = rgb_; + return rgb_; + } + } + inline void get_pixel(int x, int y, byte &r, byte &g, byte &b){ + rgb_type c = color2rgb(XGetPixel(ximg,x,y), d, cmap); + r = c.r, g = c.g, b = c.b; + } + + ~vals_internal(){ + XDestroyImage(ximg); + + XFreeGC(d, gc); + XFreeFontSet(d, fs); + XFreePixmap(d, pix); + XCloseDisplay(d); + } +#endif + }; + + struct image_size_setter{ + void operator()(int&, int&){ + } + }; + + int ascender, descender; + vals_internal *vi; + public: + font_renderer() : image(0), width(0), height(0){ + ascender = descender = 0; + vi = NULL; + } + + template font_renderer(T *str, int height_want, bool italic = false, bool bold = false, bool fixed = false, rgb_type background = rgb_type(0,0,0), rgb_type foreground = rgb_type(255,255,255)){ + render(str, height_want, italic, bold, fixed, background, foreground); + } + + template void render(T *str, int height_want, + bool italic = false, bool bold = false, bool fixed = false, + rgb_type background = rgb_type(0,0,0), rgb_type foreground = rgb_type(255,255,255)){ + if (vi == NULL){ + vi = new vals_internal(str, height_want, italic, bold, fixed, background, foreground); + }else{ + vi->create(str, height_want, italic, bold, fixed, background, foreground); + } + width = vi->width, height = vi->height; + image = new byte[width * height * 3]; + ascender = vi->get_ascender(); + descender = vi->get_descender(); + + int h = height, w = width; + for (int j = 0, i3 = 0; j < h; ++j){ + for (int i = 0; i < w; ++i, i3 += 3){ + vi->get_pixel(i, j, image[i3], image[i3+1], image[i3+2]); + } + } + } + + ~font_renderer(){ + if (vi) delete vi; + destroy(); + } + int get_width(){ + return width; + } + int get_height(){ + return height; + } + inline int get_ascender(){ + return ascender; + } + inline int get_descender(){ + return descender; + } + + const byte *get_image(){ + return image; + } + }; + } + +// ---------------------------------------------------------------------------------------- + + class native_font : public dlib::font + { + unsigned long ascender_; + native_font(){ + setlocale(LC_CTYPE, ""); + ascender_ = 0; + get_letter((int)('x')); + } + typedef std::map letters_map_type; + letters_map_type letters; + font_renderer::font_renderer fl; + public: + + virtual ~native_font() + { + // delete all the letter objects we have in our letters map + letters_map_type::iterator i; + for (i = letters.begin(); i != letters.end(); ++i) + { + delete i->second; + } + } + + virtual bool has_character ( + dlib::unichar ch + )const{ + return (*this)[ch].width() > 0; + } + + static const std::shared_ptr& get_font ( + ) + { + static std::shared_ptr f(new native_font); + return f; + } + + virtual const dlib::letter& operator[] (dlib::unichar ch) const{ + return (const_cast(this))->get_letter(ch); + } + + dlib::letter& get_letter ( + dlib::unichar ch + ){ + if (letters.count(ch)){ + dlib::letter *l = letters.find(ch)->second; + return *l; + } + + dlib::unichar c[2]; + c[0] = ch; + c[1] = 0; + + fl.render(c, height(),false,false,true); + if (ascender_ == 0){ + ascender_ = fl.get_ascender(); + } + std::vector v; + const font_renderer::byte *bp = fl.get_image(); + for (int j = 0; j < fl.get_height(); ++j){ + for (int i = 0; i < fl.get_width(); ++i, bp += 3){ + if (*bp){ + v.push_back(dlib::letter::point(i,j-ascender()+1)); + } + } + } + dlib::letter *l = new dlib::letter(fl.get_width(), (unsigned long)v.size()); + + letters.insert(std::make_pair(ch,l)); + for (int i = 0; i < (int)v.size(); ++i){ + (*l)[i] = v.at(i); + } + return *l; + } + + virtual unsigned long height ( + ) const { return 12; } + + virtual unsigned long ascender ( + ) const { return ascender_; } + + virtual unsigned long left_overflow ( + ) const { return 1; } + + virtual unsigned long right_overflow ( + ) const { return 2; } + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IGG_FONT_RENDERER_H_ + diff --git a/ml/dlib/dlib/gui_widgets/style.cpp b/ml/dlib/dlib/gui_widgets/style.cpp new file mode 100644 index 000000000..a3d22d10f --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/style.cpp @@ -0,0 +1,998 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net), and Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_WIDGETs_STYLE_CPP_ +#define DLIB_WIDGETs_STYLE_CPP_ + +#include "style.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // button style stuff +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void button_style_default::draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long , + const long , + const ustring& name, + const bool is_depressed + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + fill_rect(c,rect,rgb_pixel(212,208,200)); + + unsigned char red, green, blue; + if (enabled) + { + red = 0; + green = 0; + blue = 0; + } + else + { + red = 128; + green = 128; + blue = 128; + } + + // compute the name length if it hasn't already been computed + if (name_width == 0) + { + unsigned long height; + mfont.compute_size(name,name_width,height); + } + + // figure out where the name string should appear + rectangle name_rect; + const unsigned long width = name_width; + const unsigned long height = mfont.height(); + name_rect.set_left((rect.right() + rect.left() - width)/2); + name_rect.set_top((rect.bottom() + rect.top() - height)/2 + 1); + name_rect.set_right(name_rect.left()+width-1); + name_rect.set_bottom(name_rect.top()+height); + + + if (is_depressed) + { + name_rect.set_left(name_rect.left()+1); + name_rect.set_right(name_rect.right()+1); + name_rect.set_top(name_rect.top()+1); + name_rect.set_bottom(name_rect.bottom()+1); + + mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue)); + + draw_button_down(c,rect); + } + else + { + mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue)); + + // now draw the edge of the button + draw_button_up(c,rect); + } + } + +// ---------------------------------------------------------------------------------------- + + rectangle button_style_default:: + get_min_size ( + const ustring& name, + const font& mfont + ) const + { + + unsigned long width; + unsigned long height; + mfont.compute_size(name,width,height); + name_width = width; + + return rectangle(width+2*padding, height+2*padding); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void button_style_toolbar1::draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + const long radius = 4; + + unsigned char red, green, blue; + if (enabled) + { + red = 0; + green = 0; + blue = 0; + + long d = 0; + if (rect.contains(lastx,lasty)) + d = -70; + + if (is_depressed) + d = 20; + + if (d != 0) + { + rectangle temp(rect); + temp.left()--; temp.top()--; temp.right()++; temp.bottom()++; + draw_rounded_rectangle(c, temp, radius, rgb_alpha_pixel(255,255,0,120)); + temp.left()--; temp.top()--; temp.right()++; temp.bottom()++; + draw_rounded_rectangle(c, temp, radius, rgb_alpha_pixel(255,255,0,40)); + } + + fill_gradient_rounded(c,rect,radius,rgb_alpha_pixel(255, 255, 255,120-d), + rgb_alpha_pixel(255, 255, 255,0)); + draw_rounded_rectangle(c,rect,radius, rgb_alpha_pixel(30,30,30,200)); + } + else + { + red = 128; + green = 128; + blue = 128; + draw_rounded_rectangle(c,rect,radius, rgb_alpha_pixel(red,green,blue,210)); + } + + + // compute the name length if it hasn't already been computed + if (name_width == 0) + { + unsigned long height; + mfont.compute_size(name,name_width,height); + } + + // figure out where the name string should appear + rectangle name_rect; + const unsigned long width = name_width; + const unsigned long height = mfont.height(); + name_rect.set_left((rect.right() + rect.left() - width)/2); + name_rect.set_top((rect.bottom() + rect.top() - height)/2 + 1); + name_rect.set_right(name_rect.left()+width-1); + name_rect.set_bottom(name_rect.top()+height); + + + if (is_depressed) + { + name_rect.set_left(name_rect.left()+1); + name_rect.set_right(name_rect.right()+1); + name_rect.set_top(name_rect.top()+1); + name_rect.set_bottom(name_rect.bottom()+1); + + mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue)); + + } + else + { + mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue)); + } + } + +// ---------------------------------------------------------------------------------------- + + rectangle button_style_toolbar1:: + get_min_size ( + const ustring& name, + const font& mfont + ) const + { + + unsigned long width; + unsigned long height; + mfont.compute_size(name,width,height); + name_width = width; + + return rectangle(width+2*padding, height+2*padding); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void button_style_toolbar_icon1::draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& , + const long lastx, + const long lasty, + const ustring& , + const bool is_depressed + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + const long radius = padding; + + if (enabled) + { + if (rect.contains(lastx,lasty)) + { + if (is_depressed) + { + fill_gradient_rounded(c,rect,radius,rgb_alpha_pixel(100,100,200,150), + rgb_alpha_pixel(50,50,100,100)); + draw_rounded_rectangle(c,rect,radius, rgb_alpha_pixel(150,150,30,200)); + } + else + { + fill_gradient_rounded(c,rect,radius,rgb_alpha_pixel(150,150,250,130), + rgb_alpha_pixel(100,100,150,90)); + draw_rounded_rectangle(c,rect,radius, rgb_alpha_pixel(150,150,30,200)); + } + } + + if (is_depressed) + { + rectangle img_rect(translate_rect(centered_rect(rect,img_mouseover.nc(),img_mouseover.nr()),1,1)); + point p(img_rect.left(),img_rect.top()); + draw_image(c,p,img_mouseover); + } + else + { + rectangle img_rect(centered_rect(rect,img_normal.nc(),img_normal.nr())); + point p(img_rect.left(),img_rect.top()); + if (rect.contains(lastx,lasty)) + draw_image(c,p,img_mouseover); + else + draw_image(c,p,img_normal); + } + + } + else + { + rectangle img_rect(centered_rect(rect,img_normal.nc(),img_normal.nr())); + point p(img_rect.left(),img_rect.top()); + draw_image(c,p,img_disabled); + } + } + +// ---------------------------------------------------------------------------------------- + + rectangle button_style_toolbar_icon1:: + get_min_size ( + const ustring& , + const font& + ) const + { + return rectangle(img_normal.nc()+2*padding, img_normal.nr()+2*padding); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void button_style_arrow:: + draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& , + const long , + const long , + const ustring& , + const bool is_depressed + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + fill_rect(c,rect,rgb_pixel(212,208,200)); + + const long height = rect.height(); + const long width = rect.width(); + + const long smallest = (width < height) ? width : height; + + const long rows = (smallest+3)/4; + const long start = rows + rows/2-1; + long dep; + + long tip_x = 0; + long tip_y = 0; + long wy = 0; + long hy = 0; + long wx = 0; + long hx = 0; + + if (is_depressed) + { + dep = 0; + + // draw the button's border + draw_button_down(c,rect); + } + else + { + dep = -1; + + // draw the button's border + draw_button_up(c,rect); + } + + + switch (dir) + { + case UP: + tip_x = width/2 + rect.left() + dep; + tip_y = (height - start)/2 + rect.top() + dep + 1; + wy = 0; + hy = 1; + wx = 1; + hx = 0; + break; + + case DOWN: + tip_x = width/2 + rect.left() + dep; + tip_y = rect.bottom() - (height - start)/2 + dep; + wy = 0; + hy = -1; + wx = 1; + hx = 0; + break; + + case LEFT: + tip_x = rect.left() + (width - start)/2 + dep + 1; + tip_y = height/2 + rect.top() + dep; + wy = 1; + hy = 0; + wx = 0; + hx = 1; + break; + + case RIGHT: + tip_x = rect.right() - (width - start)/2 + dep; + tip_y = height/2 + rect.top() + dep; + wy = 1; + hy = 0; + wx = 0; + hx = -1; + break; + } + + + rgb_pixel color; + if (enabled) + { + color.red = 0; + color.green = 0; + color.blue = 0; + } + else + { + color.red = 128; + color.green = 128; + color.blue = 128; + } + + + + for (long i = 0; i < rows; ++i) + { + draw_line(c,point(tip_x + wx*i + hx*i, tip_y + wy*i + hy*i), + point(tip_x + wx*i*-1 + hx*i, tip_y + wy*i*-1 + hy*i), + color); + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // toggle button style stuff +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void toggle_button_style_default::draw_toggle_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long , + const long , + const ustring& name, + const bool is_depressed, + const bool is_checked + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + fill_rect(c,rect,rgb_pixel(212,208,200)); + + unsigned char red, green, blue; + if (enabled) + { + red = 0; + green = 0; + blue = 0; + } + else + { + red = 128; + green = 128; + blue = 128; + } + + // compute the name length if it hasn't already been computed + if (name_width == 0) + { + unsigned long height; + mfont.compute_size(name,name_width,height); + } + + // figure out where the name string should appear + rectangle name_rect; + const unsigned long width = name_width; + const unsigned long height = mfont.height(); + name_rect.set_left((rect.right() + rect.left() - width)/2); + name_rect.set_top((rect.bottom() + rect.top() - height)/2 + 1); + name_rect.set_right(name_rect.left()+width-1); + name_rect.set_bottom(name_rect.top()+height); + + long d = 0; + if (is_checked) + d = 1; + + if (is_depressed) + d = 2; + + name_rect.set_left(name_rect.left()+d); + name_rect.set_right(name_rect.right()+d); + name_rect.set_top(name_rect.top()+d); + name_rect.set_bottom(name_rect.bottom()+d); + + mfont.draw_string(c,name_rect,name,rgb_pixel(red,green,blue)); + + // now draw the edge of the button + if (is_checked || is_depressed) + draw_button_down(c,rect); + else + draw_button_up(c,rect); + } + +// ---------------------------------------------------------------------------------------- + + rectangle toggle_button_style_default:: + get_min_size ( + const ustring& name, + const font& mfont + ) const + { + + unsigned long width; + unsigned long height; + mfont.compute_size(name,width,height); + name_width = width; + + return rectangle(width+2*padding, height+2*padding); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void toggle_button_style_check_box::draw_toggle_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long , + const long , + const ustring& name, + const bool is_depressed, + const bool is_checked + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + + rgb_pixel color; + if (enabled) + { + color.red = 0; + color.green = 0; + color.blue = 0; + } + else + { + color.red = 128; + color.green = 128; + color.blue = 128; + } + + + // figure out where the name string should appear + rectangle name_rect, box_rect; + unsigned long padding = 0; + if (mfont.height() < 13) + padding = (rect.height() - mfont.height())/2; + + name_rect = rect; + name_rect.set_left(rect.left() + 17-1); + name_rect.set_top(rect.top() + padding); + name_rect.set_bottom(rect.bottom() - padding); + + box_rect = rect; + box_rect.set_right(rect.left() + 12); + box_rect.set_bottom(rect.top() + 12); + + mfont.draw_string(c,name_rect,name,color); + + if (enabled && is_depressed == false) + fill_rect(c, box_rect,rgb_pixel(255,255,255)); + else + fill_rect(c, box_rect,rgb_pixel(212,208,200)); + + draw_sunken_rectangle(c, box_rect); + + + if (is_checked) + { + const long x = box_rect.left(); + const long y = box_rect.top(); + draw_line(c,point(3+x,5+y),point(6+x,8+y),color); + draw_line(c,point(3+x,6+y),point(5+x,8+y),color); + draw_line(c,point(3+x,7+y),point(5+x,9+y),color); + draw_line(c,point(6+x,6+y),point(9+x,3+y),color); + draw_line(c,point(6+x,7+y),point(9+x,4+y),color); + draw_line(c,point(6+x,8+y),point(9+x,5+y),color); + } + } + +// ---------------------------------------------------------------------------------------- + + rectangle toggle_button_style_check_box:: + get_min_size ( + const ustring& name, + const font& mfont + ) const + { + unsigned long width; + unsigned long height; + mfont.compute_size(name,width,height); + + if (height < 13) + height = 13; + + return rectangle(width + 17 -1, height -1); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void toggle_button_style_radio_button::draw_toggle_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long , + const long , + const ustring& name, + const bool is_depressed, + const bool is_checked + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + + rgb_pixel color; + + // figure out where the name string should appear + rectangle name_rect, box_rect; + unsigned long padding = 0; + if (mfont.height() < 13) + padding = (rect.height() - mfont.height())/2; + + name_rect = rect; + name_rect.set_left(rect.left() + 17-1); + name_rect.set_top(rect.top() + padding); + name_rect.set_bottom(rect.bottom() - padding); + + box_rect = rect; + box_rect.set_right(rect.left() + 12); + box_rect.set_bottom(rect.top() + 12); + + + const long x = box_rect.left(); + const long y = box_rect.top(); + + if (enabled && is_depressed == false) + draw_solid_circle(c,point(rect.left()+5,rect.top()+5),4.5,rgb_pixel(255,255,255)); + else + draw_solid_circle(c,point(rect.left()+5,rect.top()+5),4.5,rgb_pixel(212,208,200)); + + + color = rgb_pixel(128,128,128); + draw_line(c,point(0+x,4+y),point(0+x,7+y),color); + draw_line(c,point(1+x,2+y),point(1+x,9+y),color); + draw_line(c,point(2+x,1+y),point(9+x,1+y),color); + draw_line(c,point(4+x,0+y),point(7+x,0+y),color); + + color = rgb_pixel(255,255,255); + draw_line(c,point(4+x,11+y),point(7+x,11+y),color); + draw_line(c,point(2+x,10+y),point(9+x,10+y),color); + draw_line(c,point(10+x,2+y),point(10+x,9+y),color); + draw_line(c,point(11+x,4+y),point(11+x,7+y),color); + + color = rgb_pixel(64,64,64); + draw_line(c,point(1+x,4+y),point(1+x,7+y),color); + draw_line(c,point(4+x,1+y),point(7+x,1+y),color); + draw_pixel(c,point(2+x,3+y),color); + draw_pixel(c,point(3+x,2+y),color); + draw_pixel(c,point(2+x,2+y),color); + draw_pixel(c,point(2+x,8+y),color); + draw_pixel(c,point(8+x,2+y),color); + draw_pixel(c,point(9+x,2+y),color); + + color = rgb_pixel(212,208,200); + draw_line(c,point(4+x,10+y),point(7+x,10+y),color); + draw_line(c,point(10+x,4+y),point(10+x,7+y),color); + draw_pixel(c,point(3+x,9+y),color); + draw_pixel(c,point(9+x,3+y),color); + + if (enabled) + { + color.red = 0; + color.green = 0; + color.blue = 0; + } + else + { + color.red = 128; + color.green = 128; + color.blue = 128; + } + + mfont.draw_string(c,name_rect,name,color); + + if (is_checked) + { + draw_line(c,point(5+x,4+y),point(6+x,4+y),color); + draw_line(c,point(4+x,5+y),point(7+x,5+y),color); + draw_line(c,point(4+x,6+y),point(7+x,6+y),color); + draw_line(c,point(5+x,7+y),point(6+x,7+y),color); + } + + } + +// ---------------------------------------------------------------------------------------- + + rectangle toggle_button_style_radio_button:: + get_min_size ( + const ustring& name, + const font& mfont + ) const + { + unsigned long width; + unsigned long height; + mfont.compute_size(name,width,height); + + if (height < 13) + height = 13; + + return rectangle(width + 17 -1, height -1); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // scroll bar style stuff +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + long scroll_bar_style_default:: + get_slider_length ( + long total_length, + long max_pos + ) const + { + // if the length is too small then we have to smash up the arrow buttons + // and hide the slider. + if (total_length <= get_width()*2) + { + return 0; + } + else + { + double range = total_length - get_button_length(total_length, max_pos)*2; + + double scale_factor = 30.0/(max_pos + 30.0); + + if (scale_factor < 0.1) + scale_factor = 0.1; + + + double fraction = range/(max_pos + range)*scale_factor; + double result = fraction * range; + long res = static_cast(result); + if (res < 8) + res = 8; + return res; + } + } + +// ---------------------------------------------------------------------------------------- + + long scroll_bar_style_default:: + get_button_length ( + long total_length, + long + ) const + { + // if the length is too small then we have to smash up the arrow buttons + // and hide the slider. + if (total_length <= get_width()*2) + { + return total_length/2; + } + else + { + return get_width(); + } + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar_style_default:: + draw_scroll_bar_background ( + const canvas& c, + const rectangle& rect, + const bool , + const long , + const long , + const bool is_depressed + ) const + { + if (is_depressed) + draw_checkered(c, rect,rgb_pixel(0,0,0),rgb_pixel(43,47,55)); + else + draw_checkered(c, rect,rgb_pixel(255,255,255),rgb_pixel(212,208,200)); + } + +// ---------------------------------------------------------------------------------------- + + void scroll_bar_style_default:: + draw_scroll_bar_slider ( + const canvas& c, + const rectangle& rect, + const bool , + const long , + const long , + const bool + ) const + { + fill_rect(c, rect, rgb_pixel(212,208,200)); + draw_button_up(c, rect); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // text_field styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + unsigned long text_field_style_default:: + get_padding ( + const font& mfont + ) const + { + return mfont.height()-mfont.ascender(); + } + +// ---------------------------------------------------------------------------------------- + + void text_field_style_default:: + draw_text_field ( + const canvas& c, + const rectangle& rect, + const rectangle& text_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const unsigned long cursor_x, + const unsigned long text_pos, + const rgb_pixel& text_color, + const rgb_pixel& bg_color, + const bool has_focus, + const bool cursor_visible, + const long highlight_start, + const long highlight_end + ) const + { + rectangle area = rect.intersect(c); + + if (enabled) + { + // first fill our area with the bg_color + fill_rect(c, area,bg_color); + } + else + { + // first fill our area with gray + fill_rect(c, area,rgb_pixel(212,208,200)); + } + + + if (enabled) + mfont.draw_string(c,text_rect,text,text_color,text_pos); + else + mfont.draw_string(c,text_rect,text,rgb_pixel(128,128,128),text_pos); + + // now draw the edge of the text_field + draw_sunken_rectangle(c, rect); + + if (highlight_start <= highlight_end && enabled) + { + rectangle highlight_rect = text_rect; + unsigned long left_pad = 0, right_pad = mfont.left_overflow(); + + long i; + for (i = text_pos; i <= highlight_end; ++i) + { + if (i == highlight_start) + left_pad = right_pad; + + right_pad += mfont[text[i]].width(); + } + + highlight_rect.set_left(text_rect.left()+left_pad); + highlight_rect.set_right(text_rect.left()+right_pad); + + // highlight the highlight_rect area + highlight_rect = highlight_rect.intersect(c); + for (long row = highlight_rect.top(); row <= highlight_rect.bottom(); ++row) + { + for (long col = highlight_rect.left(); col <= highlight_rect.right(); ++col) + { + canvas::pixel& pixel = c[row-c.top()][col-c.left()]; + if (pixel.red == 255 && pixel.green == 255 && pixel.blue == 255) + { + // this is a background (and white) pixel so set it to a dark + // blueish color. + pixel.red = 10; + pixel.green = 36; + pixel.blue = 106; + } + else + { + // this should be a pixel that is part of a letter so set it to white + pixel.red = 255; + pixel.green = 255; + pixel.blue = 255; + } + } + } + } + + // now draw the cursor if we need to + if (cursor_visible && has_focus && enabled) + { + const unsigned long top = rect.top()+3; + const unsigned long bottom = rect.bottom()-3; + draw_line(c, point(rect.left()+cursor_x,top),point(rect.left()+cursor_x,bottom)); + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // text_box styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void text_box_style_default:: + draw_text_box ( + const canvas& c, + const rectangle& display_rect, + const rectangle& text_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const rectangle& cursor_rect, + const rgb_pixel& text_color, + const rgb_pixel& bg_color, + const bool has_focus, + const bool cursor_visible, + const long highlight_start, + const long highlight_end + ) const + { + rectangle area = display_rect.intersect(c); + + if (enabled) + { + // first fill our area with the bg_color + fill_rect(c, area,bg_color); + } + else + { + // first fill our area with gray + fill_rect(c, area,rgb_pixel(212,208,200)); + } + + + if (enabled) + mfont.draw_string(c,text_rect,text,text_color, 0, ustring::npos, area); + else + mfont.draw_string(c,text_rect,text,rgb_pixel(128,128,128), 0, ustring::npos, area); + + + // now draw the highlight if there is any + if (highlight_start <= highlight_end && enabled) + { + const rectangle first_pos = mfont.compute_cursor_rect(text_rect, text, highlight_start); + const rectangle last_pos = mfont.compute_cursor_rect(text_rect, text, highlight_end+1); + + const rgb_alpha_pixel color(10, 30, 106, 90); + + // if the highlighted text is all on one line + if (first_pos.top() == last_pos.top()) + { + fill_rect(c, (first_pos + last_pos).intersect(display_rect), color); + } + else + { + const rectangle min_boundary(display_rect.left()+4, display_rect.top()+4, + display_rect.right()-4, display_rect.bottom()-4); + const rectangle boundary( display_rect.intersect(text_rect) + min_boundary); + + rectangle first_row, last_row, middle_rows; + first_row += first_pos; + first_row += point(boundary.right(), first_pos.top()); + last_row += last_pos; + last_row += point(boundary.left(), last_pos.bottom()); + + middle_rows.left() = boundary.left(); + middle_rows.right() = boundary.right(); + middle_rows.top() = first_row.bottom()+1; + middle_rows.bottom() = last_row.top()-1; + + fill_rect(c, first_row.intersect(display_rect), color); + fill_rect(c, middle_rows, color); + fill_rect(c, last_row.intersect(display_rect), color); + } + } + + // now draw the cursor if we need to + if (cursor_visible && has_focus && enabled) + { + draw_line(c, point(cursor_rect.left(), cursor_rect.top()),point(cursor_rect.left(), cursor_rect.bottom()), 0, area); + } + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_WIDGETs_STYLE_CPP_ + diff --git a/ml/dlib/dlib/gui_widgets/style.h b/ml/dlib/dlib/gui_widgets/style.h new file mode 100644 index 000000000..f31caee30 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/style.h @@ -0,0 +1,825 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net), and Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_WIDGETs_STYLE_ +#define DLIB_WIDGETs_STYLE_ + +#include "../algs.h" +#include "style_abstract.h" +#include "../gui_core.h" +#include "canvas_drawing.h" +#include +#include +#include "../unicode.h" +#include "../array2d.h" +#include "../pixel.h" +#include "fonts.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // button styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class button_style + { + public: + + button_style() + { + } + + virtual ~button_style() + {} + + virtual bool redraw_on_mouse_over ( + ) const { return false; } + + virtual rectangle get_invalidation_rect ( + const rectangle& rect + ) const { return rect; } + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const = 0; + + virtual void draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed + ) const = 0; + }; + +// ---------------------------------------------------------------------------------------- + + class button_style_default : public button_style + { + public: + button_style_default () : padding(4), name_width(0) {} + + virtual void draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed + ) const; + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const; + + private: + + // this is the minimum amount of padding that can separate the name from the + // edge of the button + const unsigned long padding; + // this is the width of the name string + mutable unsigned long name_width; + + }; + +// ---------------------------------------------------------------------------------------- + + class button_style_toolbar1 : public button_style + { + public: + button_style_toolbar1 () : padding(4), name_width(0) {} + + virtual void draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed + ) const; + + virtual rectangle get_invalidation_rect ( + const rectangle& rect + ) const + { + rectangle temp(rect); + temp.left() -= 2; + temp.top() -= 2; + temp.right() += 2; + temp.bottom() += 2; + return temp; + } + + virtual bool redraw_on_mouse_over ( + ) const { return true; } + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const; + + private: + + // this is the minimum amount of padding that can separate the name from the + // edge of the button + const unsigned long padding; + // this is the width of the name string + mutable unsigned long name_width; + + }; + +// ---------------------------------------------------------------------------------------- + + class button_style_toolbar_icon1 : public button_style + { + public: + template + button_style_toolbar_icon1 (const image_type& img_, unsigned long pad = 6) : padding(pad) + { + assign_image(img_mouseover,img_); + make_images(); + } + + button_style_toolbar_icon1( const button_style_toolbar_icon1& item): button_style(item), padding(item.padding) + { + assign_image(img_mouseover, item.img_mouseover); + assign_image(img_normal, item.img_normal); + assign_image(img_disabled, item.img_disabled); + } + + virtual void draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed + ) const; + + virtual bool redraw_on_mouse_over ( + ) const { return true; } + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const; + + private: + + void make_images ( + ) + { + // make the disabled image grayscale and make both non-mouseover images have weaker alpha channels + img_disabled.set_size(img_mouseover.nr(), img_mouseover.nc()); + img_normal.set_size(img_mouseover.nr(), img_mouseover.nc()); + + for (long r = 0; r < img_mouseover.nr(); ++r) + { + for (long c = 0; c < img_mouseover.nc(); ++c) + { + rgb_alpha_pixel p = img_mouseover[r][c]; + long avg = p.red; + avg += p.green; + avg += p.blue; + avg /= 3; + + if (p.alpha > 40) + p.alpha -= 40; + else + p.alpha = 0; + + img_normal[r][c] = p; + + if (p.alpha > 80) + p.alpha -= 80; + else + p.alpha = 0; + + p.red = avg; + p.green = avg; + p.blue = avg; + img_disabled[r][c] = p; + } + } + } + + array2d img_mouseover; + array2d img_normal; + array2d img_disabled; + + // this is the minimum amount of padding that can separate the name from the + // edge of the button + const unsigned long padding; + + }; + +// ---------------------------------------------------------------------------------------- + + class button_style_arrow : public button_style + { + + public: + + enum arrow_direction + { + UP, + DOWN, + LEFT, + RIGHT + }; + + button_style_arrow ( + arrow_direction dir_ + ) : dir(dir_) {} + + virtual void draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed + ) const; + + virtual rectangle get_min_size ( + const ustring& , + const font& + ) const { return rectangle(); } + + private: + arrow_direction dir; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // toggle button styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class toggle_button_style + { + public: + + toggle_button_style() + { + } + + virtual ~toggle_button_style() + {} + + virtual bool redraw_on_mouse_over ( + ) const { return false; } + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const = 0; + + virtual void draw_toggle_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed, + const bool is_checked + ) const = 0; + }; + +// ---------------------------------------------------------------------------------------- + + class toggle_button_style_default : public toggle_button_style + { + public: + toggle_button_style_default () : padding(4), name_width(0) {} + + virtual void draw_toggle_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed, + const bool is_checked + ) const; + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const; + + private: + + // this is the minimum amount of padding that can separate the name from the + // edge of the button + const unsigned long padding; + // this is the width of the name string + mutable unsigned long name_width; + + }; + +// ---------------------------------------------------------------------------------------- + + class toggle_button_style_check_box : public toggle_button_style + { + public: + virtual void draw_toggle_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed, + const bool is_checked + ) const; + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const; + + }; + +// ---------------------------------------------------------------------------------------- + + class toggle_button_style_radio_button : public toggle_button_style + { + public: + virtual void draw_toggle_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed, + const bool is_checked + ) const; + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // scroll_bar styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class scroll_bar_style + { + public: + + virtual ~scroll_bar_style() {} + + virtual bool redraw_on_mouse_over_slider ( + ) const { return false; } + + virtual long get_width ( + ) const = 0; + + virtual long get_slider_length ( + long total_length, + long max_pos + ) const = 0; + + virtual long get_button_length ( + long total_length, + long max_pos + ) const = 0; + + virtual void draw_scroll_bar_background ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const long lastx, + const long lasty, + const bool is_depressed + ) const = 0; + + virtual void draw_scroll_bar_slider ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const long lastx, + const long lasty, + const bool is_being_dragged + ) const = 0; + + }; + +// ---------------------------------------------------------------------------------------- + + class scroll_bar_style_default : public scroll_bar_style + { + public: + button_style_arrow get_up_button_style ( + ) const { return button_style_arrow(button_style_arrow::UP); } + + button_style_arrow get_down_button_style ( + ) const { return button_style_arrow(button_style_arrow::DOWN); } + + button_style_arrow get_left_button_style ( + ) const { return button_style_arrow(button_style_arrow::LEFT); } + + button_style_arrow get_right_button_style ( + ) const { return button_style_arrow(button_style_arrow::RIGHT); } + + virtual long get_width ( + ) const { return 16; } + + virtual long get_slider_length ( + long total_length, + long max_pos + ) const; + + virtual long get_button_length ( + long total_length, + long max_pos + ) const; + + virtual void draw_scroll_bar_background ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const long lastx, + const long lasty, + const bool is_depressed + ) const; + + virtual void draw_scroll_bar_slider ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const long lastx, + const long lasty, + const bool is_being_dragged + ) const; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // scrollable_region styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class scrollable_region_style + { + public: + + virtual ~scrollable_region_style() {} + + virtual long get_border_size ( + ) const = 0; + + virtual void draw_scrollable_region_border ( + const canvas& c, + const rectangle& rect, + const bool enabled + ) const = 0; + + }; + +// ---------------------------------------------------------------------------------------- + + class scrollable_region_style_default : public scrollable_region_style + { + public: + scroll_bar_style_default get_horizontal_scroll_bar_style ( + ) const { return scroll_bar_style_default(); } + + scroll_bar_style_default get_vertical_scroll_bar_style ( + ) const { return scroll_bar_style_default(); } + + virtual long get_border_size ( + ) const { return 2; } + + virtual void draw_scrollable_region_border ( + const canvas& c, + const rectangle& rect, + const bool + ) const { draw_sunken_rectangle(c,rect); } + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // list_box styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class list_box_style + { + public: + + virtual ~list_box_style() {} + + virtual void draw_list_box_background ( + const canvas& c, + const rectangle& display_rect, + const bool enabled + ) const = 0; + + virtual void draw_list_box_item ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const std::string& text, + const bool is_selected + ) const = 0; + + virtual void draw_list_box_item ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const std::wstring& text, + const bool is_selected + ) const = 0; + + virtual void draw_list_box_item ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const bool is_selected + ) const = 0; + + }; + +// ---------------------------------------------------------------------------------------- + + class list_box_style_default : public list_box_style + { + public: + scrollable_region_style_default get_scrollable_region_style ( + ) const { return scrollable_region_style_default(); } + + virtual void draw_list_box_item ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const std::string& text, + const bool is_selected + ) const { draw_list_box_item_template(c,rect,display_rect, enabled, mfont, text, is_selected); } + + virtual void draw_list_box_item ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const std::wstring& text, + const bool is_selected + ) const { draw_list_box_item_template(c,rect,display_rect, enabled, mfont, text, is_selected); } + + virtual void draw_list_box_item ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const bool is_selected + ) const { draw_list_box_item_template(c,rect,display_rect, enabled, mfont, text, is_selected); } + + template + void draw_list_box_item_template ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const string_type& text, + const bool is_selected + ) const + { + if (is_selected) + { + if (enabled) + fill_rect_with_vertical_gradient(c,rect,rgb_pixel(110,160,255), rgb_pixel(100,130,250),display_rect); + else + fill_rect_with_vertical_gradient(c,rect,rgb_pixel(140,190,255), rgb_pixel(130,160,250),display_rect); + } + + if (enabled) + mfont.draw_string(c,rect,text,rgb_pixel(0,0,0),0,std::string::npos,display_rect); + else + mfont.draw_string(c,rect,text,rgb_pixel(128,128,128),0,std::string::npos,display_rect); + } + + virtual void draw_list_box_background ( + const canvas& c, + const rectangle& display_rect, + const bool enabled + ) const + { + if (enabled) + { + // first fill our area with white + fill_rect(c, display_rect,rgb_pixel(255,255,255)); + } + else + { + // first fill our area with gray + fill_rect(c, display_rect,rgb_pixel(212,208,200)); + } + } + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // text_box styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class text_box_style + { + public: + + text_box_style() + { + } + + virtual ~text_box_style() + {} + + virtual unsigned long get_padding ( + const font& mfont + ) const = 0; + + virtual void draw_text_box ( + const canvas& c, + const rectangle& display_rect, + const rectangle& text_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const rectangle& cursor_rect, + const rgb_pixel& text_color, + const rgb_pixel& bg_color, + const bool has_focus, + const bool cursor_visible, + const long highlight_start, + const long highlight_end + ) const = 0; + }; + +// ---------------------------------------------------------------------------------------- + + class text_box_style_default : public text_box_style + { + public: + + text_box_style_default() + { + } + + scrollable_region_style_default get_scrollable_region_style ( + ) const { return scrollable_region_style_default(); } + + virtual ~text_box_style_default() + {} + + virtual unsigned long get_padding ( + const font& + ) const { return 1; } + + virtual void draw_text_box ( + const canvas& c, + const rectangle& display_rect, + const rectangle& text_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const rectangle& cursor_rect, + const rgb_pixel& text_color, + const rgb_pixel& bg_color, + const bool has_focus, + const bool cursor_visible, + const long highlight_start, + const long highlight_end + ) const; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // text_field styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class text_field_style + { + public: + + text_field_style() + { + } + + virtual ~text_field_style() + {} + + virtual unsigned long get_padding ( + const font& mfont + ) const = 0; + + virtual void draw_text_field ( + const canvas& c, + const rectangle& rect, + const rectangle& text_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const unsigned long cursor_x, + const unsigned long text_pos, + const rgb_pixel& text_color, + const rgb_pixel& bg_color, + const bool has_focus, + const bool cursor_visible, + const long highlight_start, + const long highlight_end + ) const = 0; + }; + +// ---------------------------------------------------------------------------------------- + + class text_field_style_default : public text_field_style + { + public: + + text_field_style_default() + { + } + + virtual ~text_field_style_default() + {} + + virtual unsigned long get_padding ( + const font& mfont + ) const; + + virtual void draw_text_field ( + const canvas& c, + const rectangle& rect, + const rectangle& text_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const unsigned long cursor_x, + const unsigned long text_pos, + const rgb_pixel& text_color, + const rgb_pixel& bg_color, + const bool has_focus, + const bool cursor_visible, + const long highlight_start, + const long highlight_end + ) const; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "style.cpp" +#endif + +#endif // DLIB_WIDGETs_STYLE_ + + diff --git a/ml/dlib/dlib/gui_widgets/style_abstract.h b/ml/dlib/dlib/gui_widgets/style_abstract.h new file mode 100644 index 000000000..e4d3245df --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/style_abstract.h @@ -0,0 +1,777 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net), and Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_WIDGETs_STYLE_ABSTRACT_ +#ifdef DLIB_WIDGETs_STYLE_ABSTRACT_ + +#include "../algs.h" +#include "../gui_core.h" +#include "widgets_abstract.h" +#include "../unicode/unicode_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // button styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class button_style + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an abstract class that defines the interface a + button style object must implement. + + Note that derived classes must be copyable via + their copy constructors. + !*/ + + public: + + virtual ~button_style() {} + + virtual bool redraw_on_mouse_over ( + ) const { return false; } + /*! + ensures + - if (this style draws buttons differently when a mouse is over them) then + - returns true + - else + - returns false + !*/ + + virtual rectangle get_invalidation_rect ( + const rectangle& rect + ) const { return rect; } + /*! + requires + - the mutex drawable::m is locked + - rect == the get_rect() that defines where the button is + ensures + - returns a rectangle that should be invalidated whenever a button + needs to redraw itself. (e.g. If you wanted your button style to + draw outside the button then you could return a larger rectangle) + !*/ + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + ensures + - returns a rectangle that represents the minimum size of the button + given the name and font. + !*/ + + virtual void draw_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - c == the canvas to draw on + - rect, enabled, mfont, lastx, and lasty are the variables + defined in the protected section of the drawable class. + - name == the name of the button to be drawn + - is_depressed == true if the button is to be drawn in a depressed state + ensures + - draws the button on the canvas c at the location given by rect. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class button_style_default : public button_style + { + /*! + This is the default style for button objects. It will cause + a button to appear as the simple MS Windows 2000 button style. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class button_style_toolbar1 : public button_style + { + /*! + This draws a simple toolbar style button that displays its name in the + middle of itself. When the mouse moves over it it will light up. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class button_style_toolbar_icon1 : public button_style + { + /*! + This draws a simple toolbar style button that displays an image in the + middle of itself. When the mouse moves over it it will light up. + !*/ + template + button_style_toolbar_icon1 ( + const image_type& img, + unsigned long border_size = 6 + ); + /*! + requires + - image_type == an implementation of array2d/array2d_kernel_abstract.h + - pixel_traits is defined + ensures + - displays image img in the middle of the button + - the distance between the edge of the button and the image + will be border_size pixels + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class button_style_arrow : public button_style + { + public: + /*! + This draws a simple button with an arrow in it + !*/ + + enum arrow_direction + { + UP, + DOWN, + LEFT, + RIGHT + }; + + button_style_arrow ( + arrow_direction dir + ); + /*! + ensures + - the arrow in the button will point in the given direction + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // toggle button styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class toggle_button_style + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an abstract class that defines the interface a + toggle button style object must implement. + + Note that derived classes must be copyable via + their copy constructors. + !*/ + + public: + + virtual ~toggle_button_style() {} + + virtual bool redraw_on_mouse_over ( + ) const { return false; } + /*! + ensures + - if (this style draws buttons differently when a mouse is over them) then + - returns true + - else + - returns false + !*/ + + virtual rectangle get_min_size ( + const ustring& name, + const font& mfont + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + ensures + - returns a rectangle that represents the minimum size of the button + given the name and font. + !*/ + + virtual void draw_toggle_button ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const font& mfont, + const long lastx, + const long lasty, + const ustring& name, + const bool is_depressed, + const bool is_checked + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - c == the canvas to draw on + - rect, enabled, mfont, lastx, and lasty are the variables + defined in the protected section of the drawable class. + - name == the name of the button to be drawn + - is_depressed == true if the button is to be drawn in a depressed state + - is_checked == true if the toggle_button is in the checked state + ensures + - draws the button on the canvas c at the location given by rect. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class toggle_button_style_default : public toggle_button_style + { + /*! + This is the default style for toggle_button objects. It will cause + a button to appear as the simple MS Windows 2000 button style. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class toggle_button_style_check_box : public toggle_button_style + { + /*! + This draws a simple check box style toggle button that displays its + name to the right of a check box. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class toggle_button_style_radio_button : public toggle_button_style + { + /*! + This draws a simple radio button style toggle button that displays its + name to the right of a circular radio button. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // scroll_bar styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class scroll_bar_style + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an abstract class that defines the interface a + scroll_bar style object must implement. + + Note that derived classes must be copyable via + their copy constructors. + + There are three parts of a scroll bar, the slider, the background, + and the two buttons on its ends. The "slider" is the thing that you + drag around on the scroll bar and the "background" is the part + in between the slider and the buttons on the ends. + !*/ + + public: + + virtual ~scroll_bar_style() {} + + virtual bool redraw_on_mouse_over_slider ( + ) const { return false; } + /*! + ensures + - if (this style draws a scroll_bar's slider differently when a mouse is over it + or it is being dragged) then + - returns true + - else + - returns false + !*/ + + virtual long get_width ( + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + ensures + - returns the width in pixels of the scroll bar + !*/ + + virtual long get_slider_length ( + long total_length, + long max_pos + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - total_length == the total length in pixels of the scroll bar + - max_pos == the value of scroll_bar::max_slider_pos() for this + scroll bar + ensures + - returns the length in pixels of the scroll bar's slider + !*/ + + virtual long get_button_length ( + long total_length, + long max_pos + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - total_length == the total length in pixels of the scroll bar + - max_pos == the value of scroll_bar::max_slider_pos() for this + scroll bar + ensures + - returns the length in pixels of each of the scroll bar's + buttons + !*/ + + virtual void draw_scroll_bar_background ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const long lastx, + const long lasty, + const bool is_depressed + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - c == the canvas to draw on + - rect, enabled, lastx, and lasty are the variables + defined in the protected section of the drawable class. + - is_depressed == true if the background area of the scroll_bar is to + be drawn in a depressed state (because the user is clicking on it) + ensures + - draws the background part of the scroll_bar on the canvas c at the + location given by rect. + !*/ + + virtual void draw_scroll_bar_slider ( + const canvas& c, + const rectangle& rect, + const bool enabled, + const long lastx, + const long lasty, + const bool is_being_dragged + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - c == the canvas to draw on + - rect, enabled, lastx, and lasty are the variables + defined in the protected section of the drawable class + - is_being_dragged == true if the user is dragging the slider + ensures + - draws the slider part of the scroll_bar on the canvas c at the + location given by rect. + !*/ + + button_style_type get_up_button_style ( + ) const; + /*! + ensures + - returns the type of button_style to use for a button on the + top side of a vertical scroll bar. + !*/ + + button_style_type get_down_button_style ( + ) const; + /*! + ensures + - returns the type of button_style to use for a button on the + bottom side of a vertical scroll bar. + !*/ + + button_style_type get_left_button_style ( + ) const; + /*! + ensures + - returns the type of button_style to use for a button on the + left side of a horizontal scroll bar. + !*/ + + button_style_type get_right_button_style ( + ) const; + /*! + ensures + - returns the type of button_style to use for a button on the + right side of a horizontal scroll bar. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class scroll_bar_style_default : public scroll_bar_style + { + /*! + This is the default style for scroll_bar objects. It will cause + a scroll_bar to appear as the simple MS Windows 2000 scroll_bar style. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // scrollable_region (and zoomable_region) styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class scrollable_region_style + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an abstract class that defines the interface a + scrollable_region and zoomable_region style object must implement. + + Note that derived classes must be copyable via + their copy constructors. + !*/ + public: + + virtual ~scrollable_region_style() {} + + virtual long get_border_size ( + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + ensures + - returns the size of the border region in pixels + !*/ + + virtual void draw_scrollable_region_border ( + const canvas& c, + const rectangle& rect, + const bool enabled + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - c == the canvas to draw on + - rect and enabled are the variables defined in the protected section + of the drawable class. + ensures + - draws the border part of a scrollable_region on the canvas c at the + location given by rect. + !*/ + + scroll_bar_style_type get_horizontal_scroll_bar_style ( + ) const; + /*! + ensures + - returns the style of scroll_bar to use for the + horizontal scroll_bar in this widget. + !*/ + + scroll_bar_style_type get_vertical_scroll_bar_style ( + ) const; + /*! + ensures + - returns the style of scroll_bar to use for the + vertical scroll_bar in this widget. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class scrollable_region_style_default : public scrollable_region_style + { + public: + /*! + This is the default style for scrollable_region and zoomable_region objects. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // text_box styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class text_box_style + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an abstract class that defines the interface a + text_box style object must implement. + + Note that derived classes must be copyable via + their copy constructors. + !*/ + public: + + virtual ~text_field_style() {} + + scrollable_region_style_type get_scrollable_region_style ( + ) const; + /*! + ensures + - returns the style of scrollable_region to use for the + text_box. + !*/ + + virtual unsigned long get_padding ( + const font& mfont + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + ensures + - returns the number of pixels that separate the text in the text_box + from the edge of the text_box widget itself. + !*/ + + virtual void draw_text_box ( + const canvas& c, + const rectangle& display_rect, + const rectangle& text_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const rectangle& cursor_rect, + const rgb_pixel& text_color, + const rgb_pixel& bg_color, + const bool has_focus, + const bool cursor_visible, + const long highlight_start, + const long highlight_end + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - c == the canvas to draw on + - enabled and mfont are the variables defined in the protected section + - text_rect == the rectangle in which we should draw the given text + of the drawable class. + - display_rect == the rectangle returned by scrollable_region::display_rect() + - text == the current text in the text_box + - cursor_rect == A rectangle of width 1 that represents the current + position of the cursor on the screen. + - text_color == the color of the text to be drawn + - bg_color == the background color of the text field + - has_focus == true if this text field has keyboard input focus + - cursor_visible == true if the cursor should be drawn + - if (highlight_start <= highlight_end) then + - text[highlight_start] though text[highlight_end] should be + highlighted + ensures + - draws the text_box on the canvas c at the location given by text_rect. + (Note that the scroll bars and borders are drawn by the scrollable_region + and therefore the style returned by get_scrollable_region_style() + controls how those appear) + - doesn't draw anything outside display_rect + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class text_box_style_default : public text_box_style + { + public: + /*! + This is the default style for text_box objects. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // list_box styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class list_box_style + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an abstract class that defines the interface a + list_box style object must implement. + + Note that derived classes must be copyable via + their copy constructors. + !*/ + public: + + virtual ~list_box_style() {} + + virtual void draw_list_box_background ( + const canvas& c, + const rectangle& display_rect, + const bool enabled + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - c == the canvas to draw on + - display_rect == the display_rect for the list_box. This is the area + in which list box items are drawn (see display_rect in the scrollable_region + widget for more info) + - enabled == true if the list box is enabled + ensures + - draws the background of a list box on the canvas c at the location given + by display_rect. + !*/ + + scrollable_region_style_type get_scrollable_region_style ( + ) const; + /*! + ensures + - returns the style of scrollable_region to use for the + list_box. + !*/ + + virtual void draw_list_box_item ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const std::string& text, + const bool is_selected + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - c == the canvas to draw on + - rect == the rectangle that defines where on the screen this list box item is. + - display_rect == the display_rect for the list_box. This is the area + in which list box items are drawn (see display_rect in the scrollable_region + widget for more info) + - mfont == the font to use to draw the list box item + - text == the text of the list box item to be drawn + - enabled == true if the list box is enabled + - is_selected == true if the item is to be drawn in a selected state + ensures + - draws the list box item on the canvas c at the location given by rect. + !*/ + + // wide character overloads + virtual void draw_list_box_item ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const std::wstring& text, + const bool is_selected + ) const = 0; + + virtual void draw_list_box_item ( + const canvas& c, + const rectangle& rect, + const rectangle& display_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const bool is_selected + ) const = 0; + + }; + +// ---------------------------------------------------------------------------------------- + + class list_box_style_default : public list_box_style + { + public: + /*! + This is the default style for list_box objects. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // text_field styles +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class text_field_style + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an abstract class that defines the interface a + text_field style object must implement. + + Note that derived classes must be copyable via + their copy constructors. + !*/ + public: + + virtual ~text_field_style() {} + + virtual unsigned long get_padding ( + const font& mfont + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + ensures + - returns the number of pixels that separate the text in the text_field + from the edge of the text_field widget itself. + !*/ + + virtual void draw_text_field ( + const canvas& c, + const rectangle& rect, + const rectangle& text_rect, + const bool enabled, + const font& mfont, + const ustring& text, + const unsigned long cursor_x, + const unsigned long text_pos, + const rgb_pixel& text_color, + const rgb_pixel& bg_color, + const bool has_focus, + const bool cursor_visible, + const long highlight_start, + const long highlight_end + ) const = 0; + /*! + requires + - the mutex drawable::m is locked + - c == the canvas to draw on + - rect, enabled, and mfont are the variables defined in the protected section + of the drawable class. + - text == the current text in the text_field + - text_rect == the rectangle in which we should draw the given text + - cursor_x == the x coordinate of the cursor relative to the left side + of rect. i.e. the number of pixels that separate the cursor from the + left side of the text_field. + - text_pos == the index of the first letter in text that appears in + this text field. + - text_color == the color of the text to be drawn + - bg_color == the background color of the text field + - has_focus == true if this text field has keyboard input focus + - cursor_visible == true if the cursor should be drawn + - if (highlight_start <= highlight_end) then + - text[highlight_start] though text[highlight_end] should be + highlighted + ensures + - draws the text_field on the canvas c at the location given by rect. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + class text_field_style_default : public text_field_style + { + public: + /*! + This is the default style for text_field objects. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_WIDGETs_STYLE_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/gui_widgets/widgets.cpp b/ml/dlib/dlib/gui_widgets/widgets.cpp new file mode 100644 index 000000000..c460d946d --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/widgets.cpp @@ -0,0 +1,7341 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_WIDGETs_CPP_ +#define DLIB_WIDGETs_CPP_ + +#include +#include + +#include "widgets.h" +#include "../string.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // toggle_button object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex M(m); + rectangle min_rect = style->get_min_size(name_,*mfont); + // only change the size if it isn't going to be too small to fit the name + if (height >= min_rect.height() && + width >= min_rect.width()) + { + rectangle old(rect); + rect = resize_rect(rect,width,height); + parent.invalidate_rectangle(rect+old); + btn_tooltip.set_size(width,height); + } + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + set_checked ( + ) + { + auto_mutex M(m); + checked = true; + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + set_unchecked ( + ) + { + auto_mutex M(m); + checked = false; + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + bool toggle_button:: + is_checked ( + ) const + { + auto_mutex M(m); + return checked; + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + show ( + ) + { + button_action::show(); + btn_tooltip.show(); + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + hide ( + ) + { + button_action::hide(); + btn_tooltip.hide(); + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + enable ( + ) + { + button_action::enable(); + btn_tooltip.enable(); + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + disable ( + ) + { + button_action::disable(); + btn_tooltip.disable(); + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + set_tooltip_text ( + const std::string& text + ) + { + btn_tooltip.set_text(text); + } + + void toggle_button:: + set_tooltip_text ( + const std::wstring& text + ) + { + btn_tooltip.set_text(text); + } + + void toggle_button:: + set_tooltip_text ( + const dlib::ustring& text + ) + { + btn_tooltip.set_text(text); + } + +// ---------------------------------------------------------------------------------------- + + const std::string toggle_button:: + tooltip_text ( + ) const + { + return btn_tooltip.text(); + } + + const std::wstring toggle_button:: + tooltip_wtext ( + ) const + { + return btn_tooltip.wtext(); + } + + const dlib::ustring toggle_button:: + tooltip_utext ( + ) const + { + return btn_tooltip.utext(); + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + set_name(name_); + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + set_pos ( + long x, + long y + ) + { + auto_mutex M(m); + button_action::set_pos(x,y); + btn_tooltip.set_pos(x,y); + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + set_name ( + const std::string& name + ) + { + set_name(convert_mbstring_to_wstring(name)); + } + + void toggle_button:: + set_name ( + const std::wstring& name + ) + { + set_name(convert_wstring_to_utf32(name)); + } + + void toggle_button:: + set_name ( + const dlib::ustring& name + ) + { + auto_mutex M(m); + name_ = name; + // do this to get rid of any reference counting that may be present in + // the std::string implementation. + name_[0] = name_[0]; + + rectangle old(rect); + rect = move_rect(style->get_min_size(name,*mfont),rect.left(),rect.top()); + btn_tooltip.set_size(rect.width(),rect.height()); + + parent.invalidate_rectangle(rect+old); + } + +// ---------------------------------------------------------------------------------------- + + const std::string toggle_button:: + name ( + ) const + { + return convert_wstring_to_mbstring(wname()); + } + + const std::wstring toggle_button:: + wname ( + ) const + { + return convert_utf32_to_wstring(uname()); + } + + const dlib::ustring toggle_button:: + uname ( + ) const + { + auto_mutex M(m); + dlib::ustring temp = name_; + // do this to get rid of any reference counting that may be present in + // the std::string implementation. + temp[0] = name_[0]; + return temp; + } + +// ---------------------------------------------------------------------------------------- + + void toggle_button:: + on_button_up ( + bool mouse_over + ) + { + if (mouse_over) + { + checked = !checked; + // this is a valid toggle_button click + if (event_handler.is_set()) + event_handler(); + else if (event_handler_self.is_set()) + event_handler_self(*this); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // label object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void label:: + draw ( + const canvas& c + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty() || text_.size() == 0) + return; + + using namespace std; + unsigned char r = text_color_.red; + unsigned char g = text_color_.green; + unsigned char b = text_color_.blue; + if (!enabled) + { + r = 128; + g = 128; + b = 128; + } + + rectangle text_rect(rect); + + string::size_type first, last; + first = 0; + last = text_.find_first_of('\n'); + mfont->draw_string(c,text_rect,text_,rgb_pixel(r,g,b),first,last); + + while (last != string::npos) + { + first = last+1; + last = text_.find_first_of('\n',first); + text_rect.set_top(text_rect.top()+mfont->height()); + mfont->draw_string(c,text_rect,text_,rgb_pixel(r,g,b),first,last); + } + } + +// ---------------------------------------------------------------------------------------- + + void label:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + set_text(text_); + } + +// ---------------------------------------------------------------------------------------- + + + void label:: + set_text ( + const std::string& text + ) + { + set_text(convert_mbstring_to_wstring(text)); + } + + void label:: + set_text ( + const std::wstring& text + ) + { + set_text(convert_wstring_to_utf32(text)); + } + + void label:: + set_text ( + const dlib::ustring& text + ) + { + using namespace std; + auto_mutex M(m); + text_ = text; + // do this to get rid of any reference counting that may be present in + // the std::string implementation. + text_[0] = text[0]; + + rectangle old(rect); + + unsigned long width; + unsigned long height; + mfont->compute_size(text,width,height); + + rect.set_right(rect.left() + width - 1); + rect.set_bottom(rect.top() + height - 1); + + parent.invalidate_rectangle(rect+old); + } + +// ---------------------------------------------------------------------------------------- + + const std::string label:: + text ( + ) const + { + return convert_wstring_to_mbstring(wtext()); + } + + const std::wstring label:: + wtext ( + ) const + { + return convert_utf32_to_wstring(utext()); + } + + const dlib::ustring label:: + utext ( + ) const + { + auto_mutex M(m); + dlib::ustring temp = text_; + // do this to get rid of any reference counting that may be present in + // the std::string implementation. + temp[0] = text_[0]; + return temp; + } + +// ---------------------------------------------------------------------------------------- + + void label:: + set_text_color ( + const rgb_pixel color + ) + { + m.lock(); + text_color_ = color; + parent.invalidate_rectangle(rect); + m.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + const rgb_pixel label:: + text_color ( + ) const + { + auto_mutex M(m); + return text_color_; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // text_field object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + rectangle text_field:: + get_text_rect ( + ) const + { + // figure out where the text string should appear + unsigned long vertical_pad = (rect.height() - mfont->height())/2+1; + + rectangle text_rect; + text_rect.set_left(rect.left()+style->get_padding(*mfont)); + text_rect.set_top(rect.top()+vertical_pad); + text_rect.set_right(rect.right()-style->get_padding(*mfont)); + text_rect.set_bottom(text_rect.top()+mfont->height()-1); + return text_rect; + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + enable ( + ) + { + drawable::enable(); + right_click_menu.enable(); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + give_input_focus ( + ) + { + auto_mutex M(m); + has_focus = true; + cursor_visible = true; + parent.invalidate_rectangle(rect); + t.start(); + } + +// ---------------------------------------------------------------------------------------- + + bool text_field:: + has_input_focus ( + ) const + { + auto_mutex M(m); + return has_focus; + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + select_all_text ( + ) + { + auto_mutex M(m); + on_select_all(); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_cut ( + ) + { + on_copy(); + on_delete_selected(); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_copy ( + ) + { + if (highlight_start <= highlight_end) + { + put_on_clipboard(text_.substr(highlight_start, highlight_end-highlight_start+1)); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_paste ( + ) + { + ustring temp_str; + get_from_clipboard(temp_str); + + // If this is a multi line string then just take the first line. + ustring::size_type pos = temp_str.find_first_of('\n'); + if (pos != ustring::npos) + { + temp_str = temp_str.substr(0,pos); + } + + if (highlight_start <= highlight_end) + { + text_ = text_.substr(0,highlight_start) + temp_str + + text_.substr(highlight_end+1,text_.size()-highlight_end-1); + move_cursor(highlight_start+temp_str.size()); + highlight_start = 0; + highlight_end = -1; + parent.invalidate_rectangle(rect); + on_no_text_selected(); + + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + } + else + { + text_ = text_.substr(0,cursor_pos) + temp_str + + text_.substr(cursor_pos,text_.size()-cursor_pos); + move_cursor(cursor_pos+temp_str.size()); + + // send out the text modified event + if (temp_str.size() != 0 && text_modified_handler.is_set()) + text_modified_handler(); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_select_all ( + ) + { + move_cursor(static_cast(text_.size())); + highlight_start = 0; + highlight_end = static_cast(text_.size()-1); + if (highlight_start <= highlight_end) + on_text_is_selected(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_delete_selected ( + ) + { + if (highlight_start <= highlight_end) + { + text_ = text_.erase(highlight_start,highlight_end-highlight_start+1); + move_cursor(highlight_start); + highlight_start = 0; + highlight_end = -1; + + on_no_text_selected(); + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + + parent.invalidate_rectangle(rect); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_text_is_selected ( + ) + { + right_click_menu.menu().enable_menu_item(0); + right_click_menu.menu().enable_menu_item(1); + right_click_menu.menu().enable_menu_item(3); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_no_text_selected ( + ) + { + right_click_menu.menu().disable_menu_item(0); + right_click_menu.menu().disable_menu_item(1); + right_click_menu.menu().disable_menu_item(3); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + show ( + ) + { + drawable::show(); + right_click_menu.show(); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + disable ( + ) + { + auto_mutex M(m); + drawable::disable(); + t.stop(); + has_focus = false; + cursor_visible = false; + right_click_menu.disable(); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + hide ( + ) + { + auto_mutex M(m); + drawable::hide(); + t.stop(); + has_focus = false; + cursor_visible = false; + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + // adjust the height of this text field so that it is appropriate for the current + // font size + rect.set_bottom(rect.top() + mfont->height()+ (style->get_padding(*mfont))*2); + set_text(text_); + right_click_menu.set_rect(get_text_rect()); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + draw ( + const canvas& c + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + style->draw_text_field(c,rect,get_text_rect(), enabled, *mfont, text_, cursor_x, text_pos, + text_color_, bg_color_, has_focus, cursor_visible, highlight_start, + highlight_end); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + set_text ( + const std::string& text + ) + { + set_text(convert_mbstring_to_wstring(text)); + } + + void text_field:: + set_text ( + const std::wstring& text + ) + { + set_text(convert_wstring_to_utf32(text)); + } + + void text_field:: + set_text ( + const dlib::ustring& text + ) + { + DLIB_ASSERT ( text.find_first_of('\n') == std::string::npos , + "\tvoid text_field::set_text()" + << "\n\ttext: " << narrow(text) ); + auto_mutex M(m); + // do this to get rid of any reference counting that may be present in + // the std::string implementation. + text_ = text.c_str(); + + move_cursor(0); + + highlight_start = 0; + highlight_end = -1; + + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + const std::string text_field:: + text ( + ) const + { + std::string temp = convert_wstring_to_mbstring(wtext()); + return temp; + } + + const std::wstring text_field:: + wtext ( + ) const + { + std::wstring temp = convert_utf32_to_wstring(utext()); + return temp; + } + + const dlib::ustring text_field:: + utext ( + ) const + { + auto_mutex M(m); + // do this to get rid of any reference counting that may be present in + // the dlib::ustring implementation. + dlib::ustring temp = text_.c_str(); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + set_width ( + unsigned long width + ) + { + auto_mutex M(m); + if (width < style->get_padding(*mfont)*2) + return; + + rectangle old(rect); + + rect.set_right(rect.left() + width - 1); + + right_click_menu.set_rect(get_text_rect()); + parent.invalidate_rectangle(rect+old); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + set_pos ( + long x, + long y + ) + { + drawable::set_pos(x,y); + right_click_menu.set_rect(get_text_rect()); + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + set_background_color ( + const rgb_pixel color + ) + { + auto_mutex M(m); + bg_color_ = color; + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + const rgb_pixel text_field:: + background_color ( + ) const + { + auto_mutex M(m); + return bg_color_; + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + set_text_color ( + const rgb_pixel color + ) + { + auto_mutex M(m); + text_color_ = color; + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + const rgb_pixel text_field:: + text_color ( + ) const + { + auto_mutex M(m); + return text_color_; + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + if (!enabled || hidden || !has_focus) + { + return; + } + + if (state & base_window::LEFT) + { + if (highlight_start <= highlight_end) + { + if (highlight_start == cursor_pos) + shift_pos = highlight_end + 1; + else + shift_pos = highlight_start; + } + + unsigned long new_pos = mfont->compute_cursor_pos(get_text_rect(),text_,x,y,text_pos); + if (static_cast(new_pos) != cursor_pos) + { + move_cursor(new_pos); + parent.invalidate_rectangle(rect); + } + } + else if (shift_pos != -1) + { + shift_pos = -1; + } + + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_mouse_up ( + unsigned long btn, + unsigned long, + long , + long + ) + { + if (!enabled || hidden) + return; + + if (btn == base_window::LEFT) + shift_pos = -1; + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool double_clicked + ) + { + using namespace std; + if (!enabled || hidden || btn != (unsigned long)base_window::LEFT) + return; + + if (rect.contains(x,y)) + { + has_focus = true; + cursor_visible = true; + parent.invalidate_rectangle(rect); + t.start(); + + if (double_clicked) + { + // highlight the double clicked word + string::size_type first, last; + const ustring ustr = convert_utf8_to_utf32(std::string(" \t\n")); + first = text_.substr(0,cursor_pos).find_last_of(ustr.c_str()); + last = text_.find_first_of(ustr.c_str(),cursor_pos); + long f = static_cast(first); + long l = static_cast(last); + if (first == string::npos) + f = -1; + if (last == string::npos) + l = static_cast(text_.size()); + + ++f; + --l; + + move_cursor(l+1); + highlight_start = f; + highlight_end = l; + on_text_is_selected(); + } + else + { + if (state & base_window::SHIFT) + { + if (highlight_start <= highlight_end) + { + if (highlight_start == cursor_pos) + shift_pos = highlight_end + 1; + else + shift_pos = highlight_start; + } + else + { + shift_pos = cursor_pos; + } + } + + bool at_end = false; + if (cursor_pos == 0 || cursor_pos == static_cast(text_.size())) + at_end = true; + const long old_pos = cursor_pos; + + unsigned long new_pos = mfont->compute_cursor_pos(get_text_rect(),text_,x,y,text_pos); + if (static_cast(new_pos) != cursor_pos) + { + move_cursor(new_pos); + parent.invalidate_rectangle(rect); + } + shift_pos = cursor_pos; + + if (at_end && cursor_pos == old_pos) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + } + + } + else if (has_focus) + { + t.stop(); + has_focus = false; + cursor_visible = false; + shift_pos = -1; + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + + if (focus_lost_handler.is_set()) + focus_lost_handler(); + parent.invalidate_rectangle(rect); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + // If the right click menu is up then we don't want to do anything with + // the keyboard ourselves. Let the popup menu use the keyboard for now. + if (right_click_menu.popup_menu_visible()) + return; + + const ustring space_str = convert_utf8_to_utf32(std::string(" \t\n")); + const bool shift = (state&base_window::KBD_MOD_SHIFT) != 0; + const bool ctrl = (state&base_window::KBD_MOD_CONTROL) != 0; + if (has_focus && enabled && !hidden) + { + if (shift && is_printable == false) + { + if (shift_pos == -1) + { + if (highlight_start <= highlight_end) + { + if (highlight_start == cursor_pos) + shift_pos = highlight_end + 1; + else + shift_pos = highlight_start; + } + else + { + shift_pos = cursor_pos; + } + } + } + else + { + shift_pos = -1; + } + + if (key == base_window::KEY_LEFT || + key == base_window::KEY_UP) + { + if (cursor_pos != 0) + { + unsigned long new_pos; + if (ctrl) + { + // find the first non-whitespace to our left + std::string::size_type pos = text_.find_last_not_of(space_str.c_str(),cursor_pos); + if (pos != std::string::npos) + { + pos = text_.find_last_of(space_str.c_str(),pos); + if (pos != std::string::npos) + new_pos = static_cast(pos); + else + new_pos = 0; + } + else + { + new_pos = 0; + } + } + else + { + new_pos = cursor_pos-1; + } + + move_cursor(new_pos); + } + else if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + + } + else if (key == base_window::KEY_RIGHT || + key == base_window::KEY_DOWN) + { + if (cursor_pos != static_cast(text_.size())) + { + unsigned long new_pos; + if (ctrl) + { + // find the first non-whitespace to our left + std::string::size_type pos = text_.find_first_not_of(space_str.c_str(),cursor_pos); + if (pos != std::string::npos) + { + pos = text_.find_first_of(space_str.c_str(),pos); + if (pos != std::string::npos) + new_pos = static_cast(pos+1); + else + new_pos = static_cast(text_.size()); + } + else + { + new_pos = static_cast(text_.size()); + } + } + else + { + new_pos = cursor_pos+1; + } + + move_cursor(new_pos); + } + else if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + } + else if (is_printable) + { + if (ctrl) + { + if (key == 'a') + { + on_select_all(); + } + else if (key == 'c') + { + on_copy(); + } + else if (key == 'v') + { + on_paste(); + } + else if (key == 'x') + { + on_cut(); + } + } + else if (key != '\n') + { + if (highlight_start <= highlight_end) + { + text_ = text_.substr(0,highlight_start) + static_cast(key) + + text_.substr(highlight_end+1,text_.size()-highlight_end-1); + move_cursor(highlight_start+1); + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + else + { + text_ = text_.substr(0,cursor_pos) + static_cast(key) + + text_.substr(cursor_pos,text_.size()-cursor_pos); + move_cursor(cursor_pos+1); + } + unsigned long height; + mfont->compute_size(text_,text_width,height,text_pos); + + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + } + else if (key == '\n') + { + if (enter_key_handler.is_set()) + enter_key_handler(); + } + } + else if (key == base_window::KEY_BACKSPACE) + { + // if something is highlighted then delete that + if (highlight_start <= highlight_end) + { + on_delete_selected(); + } + else if (cursor_pos != 0) + { + text_ = text_.erase(cursor_pos-1,1); + move_cursor(cursor_pos-1); + + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + } + else + { + // do this just so it repaints itself right + move_cursor(cursor_pos); + } + unsigned long height; + mfont->compute_size(text_,text_width,height,text_pos); + parent.invalidate_rectangle(rect); + } + else if (key == base_window::KEY_DELETE) + { + // if something is highlighted then delete that + if (highlight_start <= highlight_end) + { + on_delete_selected(); + } + else if (cursor_pos != static_cast(text_.size())) + { + text_ = text_.erase(cursor_pos,1); + + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + } + else + { + // do this just so it repaints itself right + move_cursor(cursor_pos); + } + parent.invalidate_rectangle(rect); + + unsigned long height; + mfont->compute_size(text_,text_width,height,text_pos); + } + else if (key == base_window::KEY_HOME) + { + move_cursor(0); + if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + } + else if (key == base_window::KEY_END) + { + move_cursor(static_cast(text_.size())); + if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + } + cursor_visible = true; + recent_movement = true; + + } + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + on_string_put( + const std::wstring &str + ) + { + if (has_focus && enabled && !hidden){ + ustring ustr = convert_wstring_to_utf32(str); + if (highlight_start <= highlight_end) + { + text_ = text_.substr(0,highlight_start) + ustr + + text_.substr(highlight_end+1,text_.size()-highlight_end-1); + move_cursor(highlight_start+ustr.size()); + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + else + { + text_ = text_.substr(0,cursor_pos) + ustr + + text_.substr(cursor_pos,text_.size()-cursor_pos); + move_cursor(cursor_pos+ustr.size()); + } + unsigned long height; + mfont->compute_size(text_,text_width,height,text_pos); + + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_field:: + move_cursor ( + unsigned long pos + ) + { + using namespace std; + const long old_cursor_pos = cursor_pos; + + if (text_pos >= pos) + { + // the cursor should go all the way to the left side of the text + if (pos >= 6) + text_pos = pos-6; + else + text_pos = 0; + + cursor_pos = pos; + unsigned long height; + mfont->compute_size(text_,text_width,height,text_pos); + + unsigned long width; + unsigned long new_x = style->get_padding(*mfont); + if (static_cast(cursor_pos)-1 >= static_cast(text_pos)) + { + mfont->compute_size(text_,width,height,text_pos,cursor_pos-1); + if (cursor_pos != 0) + new_x += width - mfont->right_overflow(); + } + + cursor_x = new_x; + } + else + { + unsigned long height; + unsigned long width; + mfont->compute_size(text_,width,height,text_pos,pos-1); + + unsigned long new_x = style->get_padding(*mfont) + + width - mfont->right_overflow(); + + // move the text to the left if necessary + if (new_x + 4 > rect.width()) + { + while (new_x > rect.width() - rect.width()/5) + { + new_x -= (*mfont)[text_[text_pos]].width(); + ++text_pos; + } + } + + cursor_x = new_x; + cursor_pos = pos; + mfont->compute_size(text_,text_width,height,text_pos); + } + + parent.set_im_pos(rect.left()+cursor_x, rect.top()); + + if (old_cursor_pos != cursor_pos) + { + if (shift_pos != -1) + { + highlight_start = std::min(shift_pos,cursor_pos); + highlight_end = std::max(shift_pos,cursor_pos)-1; + } + else + { + highlight_start = 0; + highlight_end = -1; + } + + if (highlight_start > highlight_end) + on_no_text_selected(); + else + on_text_is_selected(); + + recent_movement = true; + cursor_visible = true; + parent.invalidate_rectangle(rect); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// tabbed_display object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + tabbed_display:: + tabbed_display( + drawable_window& w + ) : + drawable(w,MOUSE_CLICK), + selected_tab_(0), + left_pad(6), + right_pad(4), + top_pad(3), + bottom_pad(3) + { + rect = rectangle(0,0,40,mfont->height()+top_pad+bottom_pad); + enable_events(); + tabs.set_max_size(1); + tabs.set_size(1); + } + +// ---------------------------------------------------------------------------------------- + + tabbed_display:: + ~tabbed_display( + ) + { + disable_events(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + set_pos ( + long x, + long y + ) + { + auto_mutex M(m); + // we have to adjust the positions of all the tab rectangles + const long xdelta = rect.left() - x; + const long ydelta = rect.top() - y; + for (unsigned long i = 0; i < tabs.size(); ++i) + { + tabs[i].rect.set_left(tabs[i].rect.left()+xdelta); + tabs[i].rect.set_right(tabs[i].rect.right()+xdelta); + + tabs[i].rect.set_top(tabs[i].rect.top()+ydelta); + tabs[i].rect.set_bottom(tabs[i].rect.bottom()+ydelta); + + + // adjust the position of the group associated with this tab if it exists + if (tabs[i].group) + tabs[i].group->set_pos(x+3, y+mfont->height()+top_pad+bottom_pad+3); + } + drawable::set_pos(x,y); + recompute_tabs(); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + fit_to_contents ( + ) + { + auto_mutex M(m); + rectangle new_rect; + point p(rect.left(),rect.top()); + new_rect += p; + + for (unsigned long i = 0; i < tabs.size(); ++i) + { + if (tabs[i].group) + { + tabs[i].group->fit_to_contents(); + new_rect += tabs[i].group->get_rect(); + } + } + + // and give the new rect an additional 4 pixels on the bottom and right sides + // so that the contents to hit the edge of the tabbed display + new_rect = resize_rect(new_rect, new_rect.width()+4, new_rect.height()+4); + + parent.invalidate_rectangle(new_rect+rect); + rect = new_rect; + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex M(m); + rectangle old(rect); + const long x = rect.left(); + const long y = rect.top(); + rect.set_right(x+width-1); + rect.set_bottom(y+height-1); + + recompute_tabs(); + + parent.invalidate_rectangle(rect+old); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + set_number_of_tabs ( + unsigned long num + ) + { + auto_mutex M(m); + + DLIB_ASSERT ( num > 0 , + "\tvoid tabbed_display::set_number_of_tabs()" + << "\n\tnum: " << num ); + + tabs.set_max_size(num); + tabs.set_size(num); + + selected_tab_ = 0; + + recompute_tabs(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long tabbed_display:: + selected_tab ( + ) const + { + auto_mutex M(m); + return selected_tab_; + } + + unsigned long tabbed_display:: + number_of_tabs ( + ) const + { + auto_mutex M(m); + return tabs.size(); + } + +// ---------------------------------------------------------------------------------------- + + const std::string tabbed_display:: + tab_name ( + unsigned long idx + ) const + { + return convert_wstring_to_mbstring(tab_wname(idx)); + } + + const std::wstring tabbed_display:: + tab_wname ( + unsigned long idx + ) const + { + return convert_utf32_to_wstring(tab_uname(idx)); + } + + const dlib::ustring& tabbed_display:: + tab_uname ( + unsigned long idx + ) const + { + auto_mutex M(m); + + DLIB_ASSERT ( idx < number_of_tabs() , + "\tvoid tabbed_display::tab_name()" + << "\n\tidx: " << idx + << "\n\tnumber_of_tabs(): " << number_of_tabs() ); + + return tabs[idx].name; + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + set_tab_name ( + unsigned long idx, + const std::string& new_name + ) + { + set_tab_name(idx, convert_mbstring_to_wstring(new_name)); + } + + void tabbed_display:: + set_tab_name ( + unsigned long idx, + const std::wstring& new_name + ) + { + set_tab_name(idx, convert_wstring_to_utf32(new_name)); + } + + void tabbed_display:: + set_tab_name ( + unsigned long idx, + const dlib::ustring& new_name + ) + { + auto_mutex M(m); + + + DLIB_ASSERT ( idx < number_of_tabs() , + "\tvoid tabbed_display::set_tab_name()" + << "\n\tidx: " << idx + << "\n\tnumber_of_tabs(): " << number_of_tabs() ); + + + tabs[idx].name = new_name; + // do this so that there isn't any reference counting going on + tabs[idx].name[0] = tabs[idx].name[0]; + unsigned long height; + mfont->compute_size(new_name,tabs[idx].width,height); + + + recompute_tabs(); + + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + on_mouse_down ( + unsigned long btn, + unsigned long, + long x, + long y, + bool + ) + { + if (rect.contains(x,y) && btn == base_window::LEFT && enabled && !hidden) + { + rectangle temp = rect; + const long offset = mfont->height() + bottom_pad + top_pad; + temp.set_bottom(rect.top()+offset); + if (temp.contains(x,y)) + { + // now we have to figure out which tab was clicked + for (unsigned long i = 0; i < tabs.size(); ++i) + { + if (selected_tab_ != i && tabs[i].rect.contains(x,y) && + tabs[selected_tab_].rect.contains(x,y) == false) + { + unsigned long old_idx = selected_tab_; + selected_tab_ = i; + recompute_tabs(); + parent.invalidate_rectangle(temp); + + // adjust the widget_group objects for these tabs if they exist + if (tabs[i].group) + tabs[i].group->show(); + if (tabs[old_idx].group) + tabs[old_idx].group->hide(); + + if (event_handler.is_set()) + event_handler(i,old_idx); + break; + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + set_tab_group ( + unsigned long idx, + widget_group& group + ) + { + auto_mutex M(m); + + DLIB_ASSERT ( idx < number_of_tabs() , + "\tvoid tabbed_display::set_tab_group()" + << "\n\tidx: " << idx + << "\n\tnumber_of_tabs(): " << number_of_tabs() ); + + + tabs[idx].group = &group; + group.set_pos(rect.left()+3,rect.top()+mfont->height()+top_pad+bottom_pad+2); + if (idx == selected_tab_) + group.show(); + else + group.hide(); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + disable ( + ) + { + auto_mutex M(m); + if (tabs[selected_tab_].group) + tabs[selected_tab_].group->disable(); + drawable::disable(); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + enable ( + ) + { + auto_mutex M(m); + if (tabs[selected_tab_].group) + tabs[selected_tab_].group->enable(); + drawable::enable(); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + hide ( + ) + { + auto_mutex M(m); + if (tabs[selected_tab_].group) + tabs[selected_tab_].group->hide(); + drawable::hide(); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + show ( + ) + { + auto_mutex M(m); + if (tabs[selected_tab_].group) + tabs[selected_tab_].group->show(); + drawable::show(); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + draw ( + const canvas& c + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + // draw the main border first + rectangle main_box(rect.left(),rect.top()+mfont->height()+top_pad+bottom_pad,rect.right(),rect.bottom()); + draw_button_up(c,main_box); + draw_pixel(c,point(main_box.right()-1,main_box.top()),rgb_pixel(128,128,128)); + + rgb_pixel color; + if (enabled) + { + color.red = 0; + color.green = 0; + color.blue = 0; + } + else + { + color.red = 128; + color.green = 128; + color.blue = 128; + } + + // draw the tabs + for (unsigned long i = 0; i < tabs.size(); ++i) + { + if (selected_tab_ != i) + draw_tab(tabs[i].rect,c); + + // draw the name string + rectangle temp = tabs[i].rect; + temp.set_top(temp.top()+top_pad); + temp.set_bottom(temp.bottom()+bottom_pad); + temp.set_left(temp.left()+left_pad); + temp.set_right(temp.right()+right_pad); + mfont->draw_string(c,temp,tabs[i].name,color); + } + draw_tab(tabs[selected_tab_].rect,c); + draw_line(c, + point(tabs[selected_tab_].rect.left()+1, + tabs[selected_tab_].rect.bottom()), + point(tabs[selected_tab_].rect.right()-2, + tabs[selected_tab_].rect.bottom()), + rgb_pixel(212,208,200)); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + draw_tab ( + const rectangle& tab, + const canvas& c + ) const + { + const rgb_pixel white(255,255,255); + const rgb_pixel background(212,208,200); + const rgb_pixel dark_gray(64,64,64); + const rgb_pixel gray(128,128,128); + draw_line(c,point(tab.left(),tab.top()+2),point(tab.left(),tab.bottom()),white); + draw_line(c,point(tab.left()+1,tab.top()+2),point(tab.left()+1,tab.bottom()),background); + draw_line(c,point(tab.right(),tab.top()+2),point(tab.right(),tab.bottom()),dark_gray); + draw_line(c,point(tab.right()-1,tab.top()+2),point(tab.right()-1,tab.bottom()),gray); + draw_line(c,point(tab.left()+2,tab.top()),point(tab.right()-2,tab.top()),white); + draw_pixel(c,point(tab.left()+1,tab.top()+1),white); + draw_pixel(c,point(tab.right()-1,tab.top()+1),dark_gray); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + + for (unsigned long i = 0; i < tabs.size(); ++i) + { + unsigned long height; + mfont->compute_size(tabs[i].name,tabs[i].width,height); + } + + recompute_tabs(); + set_pos(rect.left(), rect.top()); + + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void tabbed_display:: + recompute_tabs ( + ) + { + const long offset = mfont->height() + bottom_pad + top_pad; + + + // figure out the size and position of all the tabs + rectangle sel_tab_rect, other_tab; + sel_tab_rect.set_top(rect.top()); + sel_tab_rect.set_bottom(rect.top()+offset); + + other_tab.set_top(rect.top()+2); + other_tab.set_bottom(rect.top()+offset-1); + + long cur_x = rect.left(); + for (unsigned long i = 0; i < tabs.size(); ++i) + { + const unsigned long str_width = tabs[i].width; + if (selected_tab_ != i) + { + other_tab.set_left(cur_x); + cur_x += left_pad + str_width + right_pad; + other_tab.set_right(cur_x); + tabs[i].rect = other_tab; + ++cur_x; + + } + else + { + if (i != 0) + sel_tab_rect.set_left(cur_x-2); + else + sel_tab_rect.set_left(cur_x); + + cur_x += left_pad + str_width + right_pad; + + if (i != tabs.size()-1) + sel_tab_rect.set_right(cur_x+2); + else + sel_tab_rect.set_right(cur_x); + ++cur_x; + + tabs[i].rect = sel_tab_rect; + } + } + + // make sure this object is wide enough + const rectangle& last = tabs[tabs.size()-1].rect; + const rectangle& first = tabs[0].rect; + rect = last + rect + first; + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// named_rectangle object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + named_rectangle:: + named_rectangle( + drawable_window& w + ) : + drawable(w), + name_width(0), + name_height(0) + { + make_name_fit_in_rect(); + enable_events(); + } + +// ---------------------------------------------------------------------------------------- + + named_rectangle:: + ~named_rectangle( + ) + { + disable_events(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void named_rectangle:: + set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex M(m); + rectangle old(rect); + const long x = rect.left(); + const long y = rect.top(); + rect.set_right(x+width-1); + rect.set_bottom(y+height-1); + + make_name_fit_in_rect(); + parent.invalidate_rectangle(rect+old); + } + +// ---------------------------------------------------------------------------------------- + + void named_rectangle:: + wrap_around ( + const rectangle& r + ) + { + auto_mutex M(m); + rectangle old(rect); + const unsigned long pad = name_height/2; + + rect = rectangle(r.left()-pad, r.top()-name_height*4/3, r.right()+pad, r.bottom()+pad); + + make_name_fit_in_rect(); + parent.invalidate_rectangle(rect+old); + } + +// ---------------------------------------------------------------------------------------- + + void named_rectangle:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + mfont->compute_size(name_,name_width,name_height); + make_name_fit_in_rect(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void named_rectangle:: + make_name_fit_in_rect ( + ) + { + // make sure the named rectangle is big enough to contain the name + const unsigned long wtemp = mfont->height() + name_width; + const unsigned long htemp = mfont->height() + name_height; + if (rect.width() < wtemp) + rect.set_right(rect.left() + wtemp - 1 ); + if (rect.height() < htemp) + rect.set_bottom(rect.bottom() + htemp - 1 ); + } + +// ---------------------------------------------------------------------------------------- + + void named_rectangle:: + set_name ( + const std::string& name + ) + { + set_name(convert_mbstring_to_wstring(name)); + } + + void named_rectangle:: + set_name ( + const std::wstring& name + ) + { + set_name(convert_wstring_to_utf32(name)); + } + + void named_rectangle:: + set_name ( + const dlib::ustring& name + ) + { + auto_mutex M(m); + name_ = name.c_str(); + mfont->compute_size(name_,name_width,name_height); + + make_name_fit_in_rect(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + const std::string named_rectangle:: + name ( + ) const + { + return convert_wstring_to_mbstring(wname()); + } + + const std::wstring named_rectangle:: + wname ( + ) const + { + return convert_utf32_to_wstring(uname()); + } + + const dlib::ustring named_rectangle:: + uname ( + ) const + { + auto_mutex M(m); + return dlib::ustring(name_.c_str()); + } + +// ---------------------------------------------------------------------------------------- + + void named_rectangle:: + draw ( + const canvas& c + ) const + { + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + const unsigned long gap = mfont->height()/2; + rectangle strrect = rect; + strrect.set_left(rect.left() + gap); + + const unsigned long rtop = rect.top() + name_height/2; + + const rgb_pixel white(255,255,255); + const rgb_pixel gray(128,128,128); + + mfont->draw_string(c,strrect,name_); + draw_line(c,point(rect.left(), rtop), + point(rect.left()+gap/2, rtop), gray); + draw_line(c,point(rect.left(), rtop), + point(rect.left(), rect.bottom()-1), gray); + draw_line(c,point(rect.left(), rect.bottom()-1), + point(rect.right()-1, rect.bottom()-1), gray); + draw_line(c,point(rect.right()-1, rtop), + point(rect.right()-1, rect.bottom()-2), gray); + draw_line(c,point(strrect.left() + name_width + 2, rtop), + point(rect.right()-1, rtop), gray); + + draw_line(c,point(strrect.left() + name_width + 2, rtop+1), + point( rect.right()-2, rtop+1), white); + draw_line(c,point(rect.right(), rtop), + point(rect.right(), rect.bottom()), white); + draw_line(c,point(rect.left(), rect.bottom()), + point(rect.right(), rect.bottom()), white); + draw_line(c,point(rect.left()+1, rtop+1), + point(rect.left()+1, rect.bottom()-2), white); + draw_line(c,point(rect.left()+1, rtop+1), + point(rect.left()+gap/2, rtop+1), white); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class mouse_tracker +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + mouse_tracker:: + mouse_tracker( + drawable_window& w + ) : + draggable(w), + offset(18), + nr(w), + x_label(w), + y_label(w), + click_x(-1), + click_y(-1) + { + set_draggable_area(rectangle(0,0,500,500)); + + + x_label.set_text("x: "); + y_label.set_text("y: "); + nr.set_name("mouse position"); + + + x_label.set_pos(offset,offset); + y_label.set_pos(x_label.get_rect().left(), x_label.get_rect().bottom()+3); + + nr.wrap_around(x_label.get_rect() + y_label.get_rect()); + rect = nr.get_rect(); + + set_z_order(2000000000); + x_label.set_z_order(2000000001); + y_label.set_z_order(2000000001); + nr.set_z_order(2000000001); + + enable_events(); + } + +// ---------------------------------------------------------------------------------------- + + mouse_tracker:: + ~mouse_tracker( + ) + { + disable_events(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + nr.set_main_font(f); + x_label.set_main_font(f); + y_label.set_main_font(f); + mfont = f; + nr.wrap_around(x_label.get_rect() + y_label.get_rect()); + rect = nr.get_rect(); + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + set_pos ( + long x, + long y + ) + { + draggable::set_pos(x,y); + nr.set_pos(x,y); + x_label.set_pos(rect.left()+offset,rect.top()+offset); + y_label.set_pos(x_label.get_rect().left(), x_label.get_rect().bottom()+3); + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + show ( + ) + { + draggable::show(); + nr.show(); + x_label.show(); + y_label.show(); + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + hide ( + ) + { + draggable::hide(); + nr.hide(); + x_label.hide(); + y_label.hide(); + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + enable ( + ) + { + draggable::enable(); + nr.enable(); + x_label.enable(); + y_label.enable(); + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + disable ( + ) + { + draggable::disable(); + nr.disable(); + x_label.disable(); + y_label.disable(); + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool double_clicked + ) + { + draggable::on_mouse_down(btn,state,x,y,double_clicked); + if ((state & base_window::SHIFT) && (btn == base_window::LEFT) && enabled && !hidden) + { + parent.invalidate_rectangle(rectangle(x,y,x,y)); + parent.invalidate_rectangle(rectangle(click_x,click_y,click_x,click_y)); + click_x = x; + click_y = y; + + y_label.set_text("y: 0"); + x_label.set_text("x: 0"); + } + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + if (!hidden && enabled) + { + parent.invalidate_rectangle(rect); + draggable::on_mouse_move(state,x,y); + + long dx = 0; + long dy = 0; + if (click_x != -1) + dx = click_x; + if (click_y != -1) + dy = click_y; + + sout.str(""); + sout << "y: " << y - dy; + y_label.set_text(sout.str()); + + sout.str(""); + sout << "x: " << x - dx; + x_label.set_text(sout.str()); + } + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + on_drag ( + ) + { + nr.set_pos(rect.left(),rect.top()); + x_label.set_pos(rect.left()+offset,rect.top()+offset); + y_label.set_pos(x_label.get_rect().left(), x_label.get_rect().bottom()+3); + + long x = 0; + long y = 0; + if (click_x != -1) + x = click_x; + if (click_y != -1) + y = click_y; + + sout.str(""); + sout << "y: " << lasty - y; + y_label.set_text(sout.str()); + + sout.str(""); + sout << "x: " << lastx - x; + x_label.set_text(sout.str()); + } + +// ---------------------------------------------------------------------------------------- + + void mouse_tracker:: + draw ( + const canvas& c + ) const + { + fill_rect(c, rect,rgb_pixel(212,208,200)); + draw_pixel(c, point(click_x,click_y),rgb_pixel(255,0,0)); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class list_box +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace list_box_helper{ + template + list_box:: + list_box( + drawable_window& w + ) : + scrollable_region(w,MOUSE_WHEEL|MOUSE_CLICK), + ms_enabled(false), + last_selected(0) + { + set_vertical_scroll_increment(mfont->height()); + set_horizontal_scroll_increment(mfont->height()); + + style.reset(new list_box_style_default()); + enable_events(); + } + +// ---------------------------------------------------------------------------------------- + + template + list_box:: + ~list_box( + ) + { + disable_events(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + template + void list_box:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + // recompute the sizes of all the items + for (unsigned long i = 0; i < items.size(); ++i) + { + mfont->compute_size(items[i].name,items[i].width, items[i].height); + } + set_vertical_scroll_increment(mfont->height()); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + template + bool list_box:: + is_selected ( + unsigned long index + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( index < size() , + "\tbool list_box::is_selected(index)" + << "\n\tindex: " << index + << "\n\tsize(): " << size() ); + + return items[index].is_selected; + } + +// ---------------------------------------------------------------------------------------- + + template + void list_box:: + select ( + unsigned long index + ) + { + auto_mutex M(m); + DLIB_ASSERT ( index < size() , + "\tvoid list_box::select(index)" + << "\n\tindex: " << index + << "\n\tsize(): " << size() ); + + last_selected = index; + items[index].is_selected = true; + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + template + void list_box:: + unselect ( + unsigned long index + ) + { + auto_mutex M(m); + DLIB_ASSERT ( index < size() , + "\tvoid list_box::unselect(index)" + << "\n\tindex: " << index + << "\n\tsize(): " << size() ); + items[index].is_selected = false; + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + template + const S& list_box::operator [] ( + unsigned long index + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( index < size() , + "\tconst std::string& list_box::operator[](index)" + << "\n\tindex: " << index + << "\n\tsize(): " << size() ); + return items[index].name; + } + +// ---------------------------------------------------------------------------------------- + + template + bool list_box:: + multiple_select_enabled ( + ) const + { + auto_mutex M(m); + return ms_enabled; + } + +// ---------------------------------------------------------------------------------------- + + template + void list_box:: + enable_multiple_select ( + ) + { + auto_mutex M(m); + ms_enabled = true; + } + +// ---------------------------------------------------------------------------------------- + + template + void list_box:: + disable_multiple_select ( + ) + { + auto_mutex M(m); + ms_enabled = false; + } + +// ---------------------------------------------------------------------------------------- + + template + bool list_box:: + at_start ( + ) const + { + auto_mutex M(m); + return items.at_start(); + } + +// ---------------------------------------------------------------------------------------- + + template + void list_box:: + reset ( + ) const + { + auto_mutex M(m); + items.reset(); + } + +// ---------------------------------------------------------------------------------------- + + template + bool list_box:: + current_element_valid ( + ) const + { + auto_mutex M(m); + return items.current_element_valid(); + } + +// ---------------------------------------------------------------------------------------- + + template + const S &list_box:: + element ( + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( current_element_valid() , + "\tconst std::string& list_box::element()" + ); + return items.element().name; + } + +// ---------------------------------------------------------------------------------------- + + template + const S &list_box:: + element ( + ) + { + auto_mutex M(m); + DLIB_ASSERT ( current_element_valid() , + "\tconst std::string& list_box::element()" + ); + return items.element().name; + } + +// ---------------------------------------------------------------------------------------- + + template + bool list_box:: + move_next ( + ) const + { + auto_mutex M(m); + return items.move_next(); + } + +// ---------------------------------------------------------------------------------------- + + template + size_t list_box:: + size ( + ) const + { + auto_mutex M(m); + return items.size(); + } + +// ---------------------------------------------------------------------------------------- + + template + void list_box:: + draw ( + const canvas& c + ) const + { + scrollable_region::draw(c); + + rectangle area = display_rect().intersect(c); + if (area.is_empty()) + return; + + style->draw_list_box_background(c, display_rect(), enabled); + + long y = total_rect().top(); + for (unsigned long i = 0; i < items.size(); ++i) + { + if (y+(long)items[i].height <= area.top()) + { + y += items[i].height; + continue; + } + + rectangle r(total_rect().left(), y, display_rect().right(), y+items[i].height-1); + + style->draw_list_box_item(c,r, display_rect(), enabled, *mfont, items[i].name, items[i].is_selected); + + + y += items[i].height; + + if (y > area.bottom()) + break; + } + } + +// ---------------------------------------------------------------------------------------- + + template + void list_box:: + on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ) + { + if (display_rect().contains(x,y) && btn == base_window::LEFT && enabled && !hidden ) + { + if ( ms_enabled == false || + ((!(state&base_window::CONTROL)) && !(state&base_window::SHIFT))) + { + items.reset(); + while (items.move_next()) + { + items.element().is_selected = false; + } + } + + y -= total_rect().top(); + long h = 0; + for (unsigned long i = 0; i < items.size(); ++i) + { + h += items[i].height; + if (h >= y) + { + if (ms_enabled) + { + if (state&base_window::CONTROL) + { + items[i].is_selected = !items[i].is_selected; + if (items[i].is_selected) + last_selected = i; + } + else if (state&base_window::SHIFT) + { + // we want to select everything between (and including) the + // current thing clicked and last_selected. + const unsigned long first = std::min(i,last_selected); + const unsigned long last = std::max(i,last_selected); + for (unsigned long j = first; j <= last; ++j) + items[j].is_selected = true; + } + else + { + items[i].is_selected = true; + last_selected = i; + if (is_double_click && event_handler.is_set()) + event_handler(i); + else if (single_click_event_handler.is_set()) + single_click_event_handler(i); + } + } + else + { + items[i].is_selected = true; + last_selected = i; + if (is_double_click && event_handler.is_set()) + event_handler(i); + else if (single_click_event_handler.is_set()) + single_click_event_handler(i); + } + + break; + } + } + + parent.invalidate_rectangle(rect); + } + } + +// ---------------------------------------------------------------------------------------- + + template + unsigned long list_box:: + get_selected ( + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( multiple_select_enabled() == false, + "\tunsigned long list_box::get_selected()" + ); + for (unsigned long i = 0; i < items.size(); ++i) + { + if (items[i].is_selected) + return i; + } + return items.size(); + } +// ---------------------------------------------------------------------------------------- + + // making instance of template + template class list_box; + template class list_box; + template class list_box; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // function message_box() +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace message_box_helper + { + void box_win:: + initialize ( + ) + { + msg.set_pos(20,20); + msg.set_text(message); + rectangle msg_rect = msg.get_rect(); + btn_ok.set_name("OK"); + btn_ok.set_size(60,btn_ok.height()); + if (msg_rect.width() >= 60) + btn_ok.set_pos(msg_rect.width()/2+msg_rect.left()-btn_ok.width()/2,msg_rect.bottom()+15); + else + btn_ok.set_pos(20,msg_rect.bottom()+15); + btn_ok.set_click_handler(*this,&box_win::on_click); + + rectangle size = btn_ok.get_rect() + msg_rect; + set_size(size.right()+20,size.bottom()+20); + + + show(); + set_title(title); + } + + // ------------------------------------------------------------------------------------ + + box_win:: + box_win ( + const std::string& title_, + const std::string& message_ + ) : + drawable_window(false), + title(convert_mbstring_to_wstring(title_)), + message(convert_mbstring_to_wstring(message_)), + msg(*this), + btn_ok(*this) + { + initialize(); + } + + // ------------------------------------------------------------------------------------ + + box_win:: + box_win ( + const std::wstring& title_, + const std::wstring& message_ + ) : + drawable_window(false), + title(title_), + message(message_), + msg(*this), + btn_ok(*this) + { + initialize(); + } + + // ------------------------------------------------------------------------------------ + + box_win:: + box_win ( + const dlib::ustring& title_, + const dlib::ustring& message_ + ) : + drawable_window(false), + title(convert_utf32_to_wstring(title_)), + message(convert_utf32_to_wstring(message_)), + msg(*this), + btn_ok(*this) + { + initialize(); + } + + // ------------------------------------------------------------------------------------ + + box_win:: + ~box_win ( + ) + { + close_window(); + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + deleter_thread ( + void* param + ) + { + // The point of this extra event_handler stuff is to allow the user + // to end the program from within the callback. So we want to destroy the + // window *before* we call their callback. + box_win& w = *static_cast(param); + w.close_window(); + any_function event_handler(w.event_handler); + delete &w; + if (event_handler.is_set()) + event_handler(); + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + on_click ( + ) + { + hide(); + create_new_thread(&deleter_thread,this); + } + + // ------------------------------------------------------------------------------------ + + base_window::on_close_return_code box_win:: + on_window_close ( + ) + { + // The point of this extra event_handler stuff is to allow the user + // to end the program within the callback. So we want to destroy the + // window *before* we call their callback. + any_function event_handler_copy(event_handler); + delete this; + if (event_handler_copy.is_set()) + event_handler_copy(); + return CLOSE_WINDOW; + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + void blocking_box_win:: + initialize ( + ) + { + msg.set_pos(20,20); + msg.set_text(message); + rectangle msg_rect = msg.get_rect(); + btn_ok.set_name("OK"); + btn_ok.set_size(60,btn_ok.height()); + if (msg_rect.width() >= 60) + btn_ok.set_pos(msg_rect.width()/2+msg_rect.left()-btn_ok.width()/2,msg_rect.bottom()+15); + else + btn_ok.set_pos(20,msg_rect.bottom()+15); + btn_ok.set_click_handler(*this,&blocking_box_win::on_click); + + rectangle size = btn_ok.get_rect() + msg_rect; + set_size(size.right()+20,size.bottom()+20); + + + set_title(title); + show(); + } + + // ------------------------------------------------------------------------------------ + + blocking_box_win:: + blocking_box_win ( + const std::string& title_, + const std::string& message_ + ) : + drawable_window(false), + title(convert_mbstring_to_wstring(title_)), + message(convert_mbstring_to_wstring(message_)), + msg(*this), + btn_ok(*this) + { + initialize(); + } + + // ------------------------------------------------------------------------------------ + + blocking_box_win:: + blocking_box_win ( + const std::wstring& title_, + const std::wstring& message_ + ) : + drawable_window(false), + title(title_), + message(message_), + msg(*this), + btn_ok(*this) + { + initialize(); + } + + // ------------------------------------------------------------------------------------ + + blocking_box_win:: + blocking_box_win ( + const dlib::ustring& title_, + const dlib::ustring& message_ + ) : + drawable_window(false), + title(convert_utf32_to_wstring(title_)), + message(convert_utf32_to_wstring(message_)), + msg(*this), + btn_ok(*this) + { + initialize(); + } + + // ------------------------------------------------------------------------------------ + + blocking_box_win:: + ~blocking_box_win ( + ) + { + close_window(); + } + + // ------------------------------------------------------------------------------------ + + void blocking_box_win:: + on_click ( + ) + { + close_window(); + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // function open_file_box() +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace open_file_box_helper + { + box_win:: + box_win ( + const std::string& title, + bool has_text_field + ) : + lbl_dirs(*this), + lbl_files(*this), + lbl_file_name(*this), + lb_dirs(*this), + lb_files(*this), + btn_ok(*this), + btn_cancel(*this), + btn_root(*this), + tf_file_name(*this) + { + if (has_text_field == false) + { + tf_file_name.hide(); + lbl_file_name.hide(); + } + else + { + lbl_file_name.set_text("File: "); + } + + cur_dir = -1; + set_size(500,300); + + lbl_dirs.set_text("Directories:"); + lbl_files.set_text("Files:"); + btn_ok.set_name("Ok"); + btn_cancel.set_name("Cancel"); + btn_root.set_name("/"); + + btn_root.set_click_handler(*this,&box_win::on_root_click); + btn_cancel.set_click_handler(*this,&box_win::on_cancel_click); + btn_ok.set_click_handler(*this,&box_win::on_open_click); + lb_dirs.set_double_click_handler(*this,&box_win::on_dirs_click); + lb_files.set_click_handler(*this,&box_win::on_files_click); + lb_files.set_double_click_handler(*this,&box_win::on_files_double_click); + + + btn_root.set_pos(5,5); + + set_sizes(); + set_title(title); + + on_root_click(); + + // make it so that the file box starts out in our current working + // directory + std::string full_name(get_current_dir()); + + while (full_name.size() > 0) + { + std::string::size_type pos = full_name.find_first_of("\\/"); + std::string left(full_name.substr(0,pos)); + if (pos != std::string::npos) + full_name = full_name.substr(pos+1); + else + full_name.clear(); + + if (left.size() > 0) + enter_folder(left); + } + + + show(); + } + + // ------------------------------------------------------------------------------------ + + box_win:: + ~box_win ( + ) + { + close_window(); + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + set_sizes( + ) + { + unsigned long width, height; + get_size(width,height); + + + if (lbl_file_name.is_hidden()) + { + lbl_dirs.set_pos(0,btn_root.bottom()+5); + lb_dirs.set_pos(0,lbl_dirs.bottom()); + lb_dirs.set_size(width/2,height-lb_dirs.top()-btn_cancel.height()-10); + + lbl_files.set_pos(lb_dirs.right(),btn_root.bottom()+5); + lb_files.set_pos(lb_dirs.right(),lbl_files.bottom()); + lb_files.set_size(width-lb_files.left(),height-lb_files.top()-btn_cancel.height()-10); + + btn_ok.set_pos(width - btn_ok.width()-25,lb_files.bottom()+5); + btn_cancel.set_pos(btn_ok.left() - btn_cancel.width()-5,lb_files.bottom()+5); + } + else + { + + lbl_dirs.set_pos(0,btn_root.bottom()+5); + lb_dirs.set_pos(0,lbl_dirs.bottom()); + lb_dirs.set_size(width/2,height-lb_dirs.top()-btn_cancel.height()-10-tf_file_name.height()); + + lbl_files.set_pos(lb_dirs.right(),btn_root.bottom()+5); + lb_files.set_pos(lb_dirs.right(),lbl_files.bottom()); + lb_files.set_size(width-lb_files.left(),height-lb_files.top()-btn_cancel.height()-10-tf_file_name.height()); + + lbl_file_name.set_pos(lb_files.left(), lb_files.bottom()+8); + tf_file_name.set_pos(lbl_file_name.right(), lb_files.bottom()+5); + tf_file_name.set_width(width-tf_file_name.left()-5); + + btn_ok.set_pos(width - btn_ok.width()-25,tf_file_name.bottom()+5); + btn_cancel.set_pos(btn_ok.left() - btn_cancel.width()-5,tf_file_name.bottom()+5); + } + + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + on_window_resized ( + ) + { + set_sizes(); + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + deleter_thread ( + ) + { + close_window(); + delete this; + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + enter_folder ( + const std::string& folder_name + ) + { + if (btn_root.is_checked()) + btn_root.set_unchecked(); + if (cur_dir != -1) + sob[cur_dir]->set_unchecked(); + + + const std::string old_path = path; + const long old_cur_dir = cur_dir; + + std::unique_ptr new_btn(new toggle_button(*this)); + new_btn->set_name(folder_name); + new_btn->set_click_handler(*this,&box_win::on_path_button_click); + + // remove any path buttons that won't be part of the path anymore + if (sob.size()) + { + while (sob.size() > (unsigned long)(cur_dir+1)) + { + std::unique_ptr junk; + sob.remove(cur_dir+1,junk); + } + } + + if (sob.size()) + new_btn->set_pos(sob[sob.size()-1]->right()+5,sob[sob.size()-1]->top()); + else + new_btn->set_pos(btn_root.right()+5,btn_root.top()); + + cur_dir = sob.size(); + sob.add(sob.size(),new_btn); + + path += folder_name + directory::get_separator(); + if (set_dir(prefix + path) == false) + { + sob.remove(sob.size()-1,new_btn); + path = old_path; + cur_dir = old_cur_dir; + } + else + { + + sob[cur_dir]->set_checked(); + } + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + on_dirs_click ( + unsigned long idx + ) + { + enter_folder(lb_dirs[idx]); + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + on_files_click ( + unsigned long idx + ) + { + if (tf_file_name.is_hidden() == false) + { + tf_file_name.set_text(lb_files[idx]); + } + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + on_files_double_click ( + unsigned long + ) + { + on_open_click(); + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + on_cancel_click ( + ) + { + hide(); + create_new_thread(*this); + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + on_open_click ( + ) + { + if (lb_files.get_selected() != lb_files.size() || tf_file_name.text().size() > 0) + { + if (event_handler.is_set()) + { + if (tf_file_name.is_hidden()) + event_handler(prefix + path + lb_files[lb_files.get_selected()]); + else if (tf_file_name.text().size() > 0) + event_handler(prefix + path + tf_file_name.text()); + } + hide(); + create_new_thread(*this); + } + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + on_path_button_click ( + toggle_button& btn + ) + { + if (btn_root.is_checked()) + btn_root.set_unchecked(); + if (cur_dir != -1) + sob[cur_dir]->set_unchecked(); + std::string new_path; + + for (unsigned long i = 0; i < sob.size(); ++i) + { + new_path += sob[i]->name() + directory::get_separator(); + if (sob[i].get() == &btn) + { + cur_dir = i; + sob[i]->set_checked(); + break; + } + } + if (path != new_path) + { + path = new_path; + set_dir(prefix+path); + } + } + + // ------------------------------------------------------------------------------------ + + struct case_insensitive_compare + { + bool operator() ( + const std::string& a, + const std::string& b + ) const + { + std::string::size_type i, size; + size = std::min(a.size(),b.size()); + for (i = 0; i < size; ++i) + { + if (std::tolower(a[i]) < std::tolower(b[i])) + return true; + else if (std::tolower(a[i]) > std::tolower(b[i])) + return false; + } + if (a.size() < b.size()) + return true; + else + return false; + } + }; + + // ------------------------------------------------------------------------------------ + + bool box_win:: + set_dir ( + const std::string& dir + ) + { + try + { + directory d(dir); + queue::kernel_1a_c qod; + queue::kernel_1a_c qof; + queue::sort_1a_c qos; + d.get_dirs(qod); + d.get_files(qof); + + qod.reset(); + while (qod.move_next()) + { + std::string temp = qod.element().name(); + qos.enqueue(temp); + } + qos.sort(case_insensitive_compare()); + lb_dirs.load(qos); + qos.clear(); + + qof.reset(); + while (qof.move_next()) + { + std::string temp = qof.element().name(); + qos.enqueue(temp); + } + qos.sort(case_insensitive_compare()); + lb_files.load(qos); + return true; + } + catch (directory::listing_error& ) + { + return false; + } + catch (directory::dir_not_found&) + { + return false; + } + } + + // ------------------------------------------------------------------------------------ + + void box_win:: + on_root_click ( + ) + { + btn_root.set_checked(); + if (cur_dir != -1) + sob[cur_dir]->set_unchecked(); + + queue::kernel_1a_c qod, qod2; + queue::kernel_1a_c qof; + queue::sort_1a_c qos; + get_filesystem_roots(qod); + path.clear(); + cur_dir = -1; + if (qod.size() == 1) + { + qod.current().get_files(qof); + qod.current().get_dirs(qod2); + prefix = qod.current().full_name(); + + qod2.reset(); + while (qod2.move_next()) + { + std::string temp = qod2.element().name(); + qos.enqueue(temp); + } + qos.sort(case_insensitive_compare()); + lb_dirs.load(qos); + qos.clear(); + + qof.reset(); + while (qof.move_next()) + { + std::string temp = qof.element().name(); + qos.enqueue(temp); + } + qos.sort(case_insensitive_compare()); + lb_files.load(qos); + } + else + { + prefix.clear(); + qod.reset(); + while (qod.move_next()) + { + std::string temp = qod.element().full_name(); + temp = temp.substr(0,temp.size()-1); + qos.enqueue(temp); + } + qos.sort(case_insensitive_compare()); + lb_dirs.load(qos); + qos.clear(); + lb_files.load(qos); + } + } + + // ------------------------------------------------------------------------------------ + + base_window::on_close_return_code box_win:: + on_window_close ( + ) + { + delete this; + return CLOSE_WINDOW; + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class menu_bar +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + menu_bar:: + menu_bar( + drawable_window& w + ) : + drawable(w, 0xFFFF), // listen for all events + open_menu(0) + { + adjust_position(); + enable_events(); + } + +// ---------------------------------------------------------------------------------------- + + menu_bar:: + ~menu_bar() + { + disable_events(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + adjust_position(); + compute_menu_geometry(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + set_number_of_menus ( + unsigned long num + ) + { + auto_mutex M(m); + menus.set_max_size(num); + menus.set_size(num); + open_menu = menus.size(); + compute_menu_geometry(); + + for (unsigned long i = 0; i < menus.size(); ++i) + { + menus[i].menu.set_on_hide_handler(*this,&menu_bar::on_popup_hide); + } + + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long menu_bar:: + number_of_menus ( + ) const + { + auto_mutex M(m); + return menus.size(); + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + set_menu_name ( + unsigned long idx, + const std::string name, + char underline_ch + ) + { + set_menu_name(idx, convert_mbstring_to_wstring(name), underline_ch); + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + set_menu_name ( + unsigned long idx, + const std::wstring name, + char underline_ch + ) + { + set_menu_name(idx, convert_wstring_to_utf32(name), underline_ch); + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + set_menu_name ( + unsigned long idx, + const dlib::ustring name, + char underline_ch + ) + { + DLIB_ASSERT ( idx < number_of_menus() , + "\tvoid menu_bar::set_menu_name()" + << "\n\tidx: " << idx + << "\n\tnumber_of_menus(): " << number_of_menus() + ); + auto_mutex M(m); + menus[idx].name = name.c_str(); + menus[idx].underline_pos = name.find_first_of(underline_ch); + compute_menu_geometry(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + const std::string menu_bar:: + menu_name ( + unsigned long idx + ) const + { + return convert_wstring_to_mbstring(menu_wname(idx)); + } + +// ---------------------------------------------------------------------------------------- + + const std::wstring menu_bar:: + menu_wname ( + unsigned long idx + ) const + { + return convert_utf32_to_wstring(menu_uname(idx)); + } + +// ---------------------------------------------------------------------------------------- + + const dlib::ustring menu_bar:: + menu_uname ( + unsigned long idx + ) const + { + DLIB_ASSERT ( idx < number_of_menus() , + "\tstd::string menu_bar::menu_name()" + << "\n\tidx: " << idx + << "\n\tnumber_of_menus(): " << number_of_menus() + ); + auto_mutex M(m); + return menus[idx].name.c_str(); + } + +// ---------------------------------------------------------------------------------------- + + popup_menu& menu_bar:: + menu ( + unsigned long idx + ) + { + DLIB_ASSERT ( idx < number_of_menus() , + "\tpopup_menu& menu_bar::menu()" + << "\n\tidx: " << idx + << "\n\tnumber_of_menus(): " << number_of_menus() + ); + auto_mutex M(m); + return menus[idx].menu; + } + +// ---------------------------------------------------------------------------------------- + + const popup_menu& menu_bar:: + menu ( + unsigned long idx + ) const + { + DLIB_ASSERT ( idx < number_of_menus() , + "\tconst popup_menu& menu_bar::menu()" + << "\n\tidx: " << idx + << "\n\tnumber_of_menus(): " << number_of_menus() + ); + auto_mutex M(m); + return menus[idx].menu; + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + on_window_resized ( + ) + { + adjust_position(); + hide_menu(); + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + draw ( + const canvas& c + ) const + { + rectangle area(rect.intersect(c)); + if (area.is_empty()) + return; + + const unsigned char opacity = 40; + fill_rect_with_vertical_gradient(c, rect,rgb_alpha_pixel(255,255,255,opacity), + rgb_alpha_pixel(0,0,0,opacity)); + + // first draw the border between the menu and the rest of the window + draw_line(c, point(rect.left(),rect.bottom()-1), + point(rect.right(),rect.bottom()-1), 100); + draw_line(c, point(rect.left(),rect.bottom()), + point(rect.right(),rect.bottom()), 255); + + // now draw all the menu buttons + for (unsigned long i = 0; i < menus.size(); ++i) + { + mfont->draw_string(c,menus[i].rect, menus[i].name ); + if (menus[i].underline_p1 != menus[i].underline_p2) + draw_line(c, menus[i].underline_p1, menus[i].underline_p2); + + if (open_menu == i) + { + fill_rect_with_vertical_gradient(c, menus[i].bgrect,rgb_alpha_pixel(255,255,0,40), rgb_alpha_pixel(0,0,0,40)); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + on_window_moved ( + ) + { + hide_menu(); + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + on_focus_lost ( + ) + { + hide_menu(); + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + on_mouse_down ( + unsigned long btn, + unsigned long , + long x, + long y, + bool + ) + { + + if (rect.contains(x,y) == false || btn != (unsigned long)base_window::LEFT) + { + hide_menu(); + return; + } + + unsigned long old_menu = menus.size(); + + // if a menu is currently open then save its index + if (open_menu != menus.size()) + { + old_menu = open_menu; + hide_menu(); + } + + // figure out which menu should be open if any + for (unsigned long i = 0; i < menus.size(); ++i) + { + if (menus[i].bgrect.contains(x,y)) + { + if (old_menu != i) + show_menu(i); + + break; + } + } + + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + on_mouse_move ( + unsigned long , + long x, + long y + ) + { + // if the mouse is over the menu_bar and some menu is currently open + if (rect.contains(x,y) && open_menu != menus.size()) + { + // if the mouse is still in the same rectangle then don't do anything + if (menus[open_menu].bgrect.contains(x,y) == false) + { + // figure out which menu should be instead + for (unsigned long i = 0; i < menus.size(); ++i) + { + if (menus[i].bgrect.contains(x,y)) + { + show_menu(i); + break; + } + } + + } + } + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + if (state&base_window::KBD_MOD_ALT) + { + // check if the key matches any of our underlined keys + for (unsigned long i = 0; i < menus.size(); ++i) + { + // if we have found a matching key + if (is_printable && + menus[i].underline_pos != std::string::npos && + std::tolower(menus[i].name[menus[i].underline_pos]) == std::tolower(key)) + { + show_menu(i); + menus[open_menu].menu.select_first_item(); + return; + } + } + } + + if (open_menu != menus.size()) + { + unsigned long i = open_menu; + // if the submenu doesn't use this key for something then we will + if (menus[open_menu].menu.forwarded_on_keydown(key,is_printable,state) == false) + { + if (key == base_window::KEY_LEFT) + { + i = (i+menus.size()-1)%menus.size(); + show_menu(i); + menus[open_menu].menu.select_first_item(); + } + else if (key == base_window::KEY_RIGHT) + { + i = (i+1)%menus.size(); + show_menu(i); + menus[open_menu].menu.select_first_item(); + } + else if (key == base_window::KEY_ESC) + { + hide_menu(); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + show_menu ( + unsigned long i + ) + { + rectangle temp; + + // menu already open so do nothing + if (i == open_menu) + return; + + // if a menu is currently open + if (open_menu != menus.size()) + { + menus[open_menu].menu.hide(); + temp = menus[open_menu].bgrect; + } + + // display the new menu + open_menu = i; + long wx, wy; + parent.get_pos(wx,wy); + wx += menus[i].bgrect.left(); + wy += menus[i].bgrect.bottom()+1; + menus[i].menu.set_pos(wx,wy); + menus[i].menu.show(); + parent.invalidate_rectangle(menus[i].bgrect+temp); + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + hide_menu ( + ) + { + // if a menu is currently open + if (open_menu != menus.size()) + { + menus[open_menu].menu.hide(); + parent.invalidate_rectangle(menus[open_menu].bgrect); + open_menu = menus.size(); + } + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + on_popup_hide ( + ) + { + // if a menu is currently open + if (open_menu != menus.size()) + { + parent.invalidate_rectangle(menus[open_menu].bgrect); + open_menu = menus.size(); + } + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + compute_menu_geometry ( + ) + { + long x = 7; + long bg_x = 0; + for (unsigned long i = 0; i < menus.size(); ++i) + { + // compute the locations of the text rectangles + menus[i].rect.set_top(5); + menus[i].rect.set_left(x); + menus[i].rect.set_bottom(rect.bottom()-2); + + unsigned long width, height; + mfont->compute_size(menus[i].name,width,height); + menus[i].rect = resize_rect_width(menus[i].rect, width); + x = menus[i].rect.right()+10; + + menus[i].bgrect.set_top(0); + menus[i].bgrect.set_left(bg_x); + menus[i].bgrect.set_bottom(rect.bottom()-2); + menus[i].bgrect.set_right(x-5); + bg_x = menus[i].bgrect.right()+1; + + if (menus[i].underline_pos != std::string::npos) + { + // now compute the location of the underline bar + rectangle r1 = mfont->compute_cursor_rect( + menus[i].rect, + menus[i].name, + menus[i].underline_pos); + + rectangle r2 = mfont->compute_cursor_rect( + menus[i].rect, + menus[i].name, + menus[i].underline_pos+1); + + menus[i].underline_p1.x() = r1.left()+1; + menus[i].underline_p2.x() = r2.left()-1; + menus[i].underline_p1.y() = r1.bottom()-mfont->height()+mfont->ascender()+2; + menus[i].underline_p2.y() = r2.bottom()-mfont->height()+mfont->ascender()+2; + } + else + { + // there is no underline in this case + menus[i].underline_p1 = menus[i].underline_p2; + } + + } + } + +// ---------------------------------------------------------------------------------------- + + void menu_bar:: + adjust_position ( + ) + { + unsigned long width, height; + rectangle old(rect); + parent.get_size(width,height); + rect.set_left(0); + rect.set_top(0); + rect = resize_rect(rect,width,mfont->height()+10); + parent.invalidate_rectangle(old+rect); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// class text_grid +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + text_grid:: + text_grid ( + drawable_window& w + ) : + scrollable_region(w, KEYBOARD_EVENTS | MOUSE_CLICK | FOCUS_EVENTS ), + has_focus(false), + cursor_timer(*this,&text_grid::timer_action), + border_color_(128,128,128) + { + + cursor_timer.set_delay_time(500); + set_vertical_scroll_increment(10); + set_horizontal_scroll_increment(10); + enable_events(); + } + +// ---------------------------------------------------------------------------------------- + + text_grid:: + ~text_grid ( + ) + { + // Disable all further events for this drawable object. We have to do this + // because we don't want draw() events coming to this object while or after + // it has been destructed. + disable_events(); + + // wait for the timer to stop doing its thing + cursor_timer.stop_and_wait(); + // Tell the parent window to redraw its area that previously contained this + // drawable object. + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_grid_size ( + unsigned long rows, + unsigned long cols + ) + { + auto_mutex M(m); + row_height.set_max_size(rows); + row_height.set_size(rows); + + col_width.set_max_size(cols); + col_width.set_size(cols); + + grid.set_size(rows,cols); + + for (unsigned long i = 0; i < row_height.size(); ++i) + row_height[i] = (mfont->height()*3)/2; + for (unsigned long i = 0; i < col_width.size(); ++i) + col_width[i] = mfont->height()*5; + + compute_total_rect(); + compute_bg_rects(); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long text_grid:: + number_of_columns ( + ) const + { + auto_mutex M(m); + return grid.nc(); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long text_grid:: + number_of_rows ( + ) const + { + auto_mutex M(m); + return grid.nr(); + } + +// ---------------------------------------------------------------------------------------- + + int text_grid:: + next_free_user_event_number ( + ) const + { + return scrollable_region::next_free_user_event_number()+1; + } + +// ---------------------------------------------------------------------------------------- + + rgb_pixel text_grid:: + border_color ( + ) const + { + auto_mutex M(m); + return border_color_; + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_border_color ( + rgb_pixel color + ) + { + auto_mutex M(m); + border_color_ = color; + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + const std::string text_grid:: + text ( + unsigned long row, + unsigned long col + ) const + { + return convert_wstring_to_mbstring(wtext(row, col)); + } + +// ---------------------------------------------------------------------------------------- + + const std::wstring text_grid:: + wtext ( + unsigned long row, + unsigned long col + ) const + { + return convert_utf32_to_wstring(utext(row, col)); + } + +// ---------------------------------------------------------------------------------------- + + const dlib::ustring text_grid:: + utext ( + unsigned long row, + unsigned long col + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(), + "\tconst std::string text_grid::text(row,col)" + << "\n\trow: " << row + << "\n\tcol: " << col + << "\n\tnumber_of_rows(): " << number_of_rows() + << "\n\tnumber_of_columns(): " << number_of_columns() + << "\n\tthis: " << this + ); + return grid[row][col].text.c_str(); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_text ( + unsigned long row, + unsigned long col, + const std::string& str + ) + { + set_text(row, col, convert_mbstring_to_wstring(str)); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_text ( + unsigned long row, + unsigned long col, + const std::wstring& str + ) + { + set_text(row, col, convert_wstring_to_utf32(str)); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_text ( + unsigned long row, + unsigned long col, + const dlib::ustring& str + ) + { + auto_mutex M(m); + DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(), + "\tvoid text_grid::set_text(row,col)" + << "\n\trow: " << row + << "\n\tcol: " << col + << "\n\tnumber_of_rows(): " << number_of_rows() + << "\n\tnumber_of_columns(): " << number_of_columns() + << "\n\tthis: " << this + ); + grid[row][col].text = str.c_str(); + parent.invalidate_rectangle(get_text_rect(row,col)); + } + +// ---------------------------------------------------------------------------------------- + + const rgb_pixel text_grid:: + text_color ( + unsigned long row, + unsigned long col + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(), + "\tconst rgb_pixel text_grid::text_color(row,col)" + << "\n\trow: " << row + << "\n\tcol: " << col + << "\n\tnumber_of_rows(): " << number_of_rows() + << "\n\tnumber_of_columns(): " << number_of_columns() + << "\n\tthis: " << this + ); + return grid[row][col].text_color; + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_text_color ( + unsigned long row, + unsigned long col, + const rgb_pixel color + ) + { + auto_mutex M(m); + DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(), + "\tvoid text_grid::set_text_color(row,col,color)" + << "\n\trow: " << row + << "\n\tcol: " << col + << "\n\tnumber_of_rows(): " << number_of_rows() + << "\n\tnumber_of_columns(): " << number_of_columns() + << "\n\tthis: " << this + ); + grid[row][col].text_color = color; + parent.invalidate_rectangle(get_text_rect(row,col)); + } + +// ---------------------------------------------------------------------------------------- + + const rgb_pixel text_grid:: + background_color ( + unsigned long row, + unsigned long col + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(), + "\tconst rgb_pixel text_grid::background_color(row,col,color)" + << "\n\trow: " << row + << "\n\tcol: " << col + << "\n\tnumber_of_rows(): " << number_of_rows() + << "\n\tnumber_of_columns(): " << number_of_columns() + << "\n\tthis: " << this + ); + return grid[row][col].bg_color; + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_background_color ( + unsigned long row, + unsigned long col, + const rgb_pixel color + ) + { + auto_mutex M(m); + DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(), + "\tvoid text_grid::set_background_color(row,col,color)" + << "\n\trow: " << row + << "\n\tcol: " << col + << "\n\tnumber_of_rows(): " << number_of_rows() + << "\n\tnumber_of_columns(): " << number_of_columns() + << "\n\tthis: " << this + ); + grid[row][col].bg_color = color; + parent.invalidate_rectangle(get_bg_rect(row,col)); + } + +// ---------------------------------------------------------------------------------------- + + bool text_grid:: + is_editable ( + unsigned long row, + unsigned long col + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(), + "\tbool text_grid::is_editable(row,col)" + << "\n\trow: " << row + << "\n\tcol: " << col + << "\n\tnumber_of_rows(): " << number_of_rows() + << "\n\tnumber_of_columns(): " << number_of_columns() + << "\n\tthis: " << this + ); + return grid[row][col].is_editable; + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_editable ( + unsigned long row, + unsigned long col, + bool editable + ) + { + auto_mutex M(m); + DLIB_ASSERT ( row < number_of_rows() && col < number_of_columns(), + "\tvoid text_grid::set_editable(row,col,editable)" + << "\n\trow: " << row + << "\n\tcol: " << col + << "\n\tnumber_of_rows(): " << number_of_rows() + << "\n\tnumber_of_columns(): " << number_of_columns() + << "\n\teditable: " << editable + << "\n\tthis: " << this + ); + grid[row][col].is_editable = editable; + if (has_focus && active_row == static_cast(row) && active_col == static_cast(col)) + { + drop_input_focus(); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_column_width ( + unsigned long col, + unsigned long width + ) + { + auto_mutex M(m); + DLIB_ASSERT ( col < number_of_columns(), + "\tvoid text_grid::set_column_width(col,width)" + << "\n\tcol: " << col + << "\n\tnumber_of_columns(): " << number_of_columns() + << "\n\twidth: " << width + << "\n\tthis: " << this + ); + col_width[col] = width; + compute_total_rect(); + compute_bg_rects(); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + set_row_height ( + unsigned long row, + unsigned long height + ) + { + auto_mutex M(m); + DLIB_ASSERT ( row < number_of_rows() , + "\tvoid text_grid::set_row_height(row,height)" + << "\n\trow: " << row + << "\n\tnumber_of_rows(): " << number_of_rows() + << "\n\theight: " << height + << "\n\tthis: " << this + ); + row_height[row] = height; + compute_total_rect(); + compute_bg_rects(); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + disable ( + ) + { + auto_mutex M(m); + scrollable_region::disable(); + drop_input_focus(); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + hide ( + ) + { + auto_mutex M(m); + scrollable_region::hide(); + drop_input_focus(); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + on_user_event ( + int num + ) + { + // ignore this user event if it isn't for us + if (num != scrollable_region::next_free_user_event_number()) + return; + + if (has_focus && !recent_cursor_move && enabled && !hidden) + { + show_cursor = !show_cursor; + parent.invalidate_rectangle(get_text_rect(active_row,active_col)); + } + recent_cursor_move = false; + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + timer_action ( + ) + { + parent.trigger_user_event(this,scrollable_region::next_free_user_event_number()); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + compute_bg_rects ( + ) + { + // loop over each element in the grid and figure out what its rectangle should be + // with respect to the total_rect() + point p1, p2; + p1.y() = total_rect().top(); + for (long row = 0; row < grid.nr(); ++row) + { + p1.x() = total_rect().left(); + p2.y() = p1.y() + row_height[row]-1; + for (long col = 0; col < grid.nc(); ++col) + { + // if this is the last box in this row make it super wide so that it always + // goes to the end of the widget + if (col+1 == grid.nc()) + p2.x() = 1000000; + else + p2.x() = p1.x() + col_width[col]-1; + + // at this point p1 is the upper left corner of this box and p2 is the + // lower right corner of the box; + rectangle bg_rect(p1); + bg_rect += p2; + + grid[row][col].bg_rect = translate_rect(bg_rect, -total_rect().left(), -total_rect().top()); + + + p1.x() += 1 + col_width[col]; + } + p1.y() += 1 + row_height[row]; + } + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + compute_total_rect ( + ) + { + if (grid.size() == 0) + { + set_total_rect_size(0,0); + } + else + { + unsigned long width = col_width.size()-1; + unsigned long height = row_height.size()-1; + + for (unsigned long i = 0; i < col_width.size(); ++i) + width += col_width[i]; + for (unsigned long i = 0; i < row_height.size(); ++i) + height += row_height[i]; + + set_total_rect_size(width,height); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + // ignore this event if we are disabled or hidden + if (!enabled || hidden) + return; + + if (has_focus) + { + if (is_printable) + { + // if the user hit the tab key then jump to the next box + if (key == '\t') + { + if (active_col+1 == grid.nc()) + { + if (active_row+1 == grid.nr()) + move_cursor(0,0,0); + else + move_cursor(active_row+1,0,0); + } + else + { + move_cursor(active_row,active_col+1,0); + } + } + if (key == '\n') + { + // ignore the enter key + } + else if (grid[active_row][active_col].is_editable) + { + // insert the key the user pressed into the string + grid[active_row][active_col].text.insert(cursor_pos,1,static_cast(key)); + move_cursor(active_row,active_col,cursor_pos+1); + + if (text_modified_handler.is_set()) + text_modified_handler(active_row,active_col); + } + } + else if ((state & base_window::KBD_MOD_CONTROL)) + { + if (key == base_window::KEY_LEFT) + move_cursor(active_row,active_col-1,0); + else if (key == base_window::KEY_RIGHT) + move_cursor(active_row,active_col+1,0); + else if (key == base_window::KEY_UP) + move_cursor(active_row-1,active_col,0); + else if (key == base_window::KEY_DOWN) + move_cursor(active_row+1,active_col,0); + else if (key == base_window::KEY_END) + move_cursor(active_row,active_col,grid[active_row][active_col].text.size()); + else if (key == base_window::KEY_HOME) + move_cursor(active_row,active_col,0); + } + else + { + if (key == base_window::KEY_LEFT) + move_cursor(active_row,active_col,cursor_pos-1); + else if (key == base_window::KEY_RIGHT) + move_cursor(active_row,active_col,cursor_pos+1); + else if (key == base_window::KEY_UP) + move_cursor(active_row-1,active_col,0); + else if (key == base_window::KEY_DOWN) + move_cursor(active_row+1,active_col,0); + else if (key == base_window::KEY_END) + move_cursor(active_row,active_col,grid[active_row][active_col].text.size()); + else if (key == base_window::KEY_HOME) + move_cursor(active_row,active_col,0); + else if (key == base_window::KEY_BACKSPACE) + { + if (cursor_pos > 0 && grid[active_row][active_col].is_editable) + { + grid[active_row][active_col].text.erase( + grid[active_row][active_col].text.begin()+cursor_pos-1, + grid[active_row][active_col].text.begin()+cursor_pos); + move_cursor(active_row,active_col,cursor_pos-1); + + if (text_modified_handler.is_set()) + text_modified_handler(active_row,active_col); + } + } + else if (key == base_window::KEY_DELETE) + { + if (cursor_pos < static_cast(grid[active_row][active_col].text.size()) && + grid[active_row][active_col].is_editable) + { + grid[active_row][active_col].text.erase( + grid[active_row][active_col].text.begin()+cursor_pos); + move_cursor(active_row,active_col,cursor_pos); + + if (text_modified_handler.is_set()) + text_modified_handler(active_row,active_col); + } + } + } + } // if (has_focus) + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ) + { + scrollable_region::on_mouse_down(btn, state, x, y, is_double_click); + if (display_rect().contains(x,y) && enabled && !hidden) + { + // figure out which box this click landed in + rectangle hit; + + // find which column we hit + unsigned long col = 0; + long box_x = total_rect().left(); + for (unsigned long i = 0; i < col_width.size(); ++i) + { + if (box_x <= x && (x < box_x+static_cast(col_width[i]) || (i+1 == col_width.size()))) + { + col = i; + hit.set_left(box_x); + hit.set_right(box_x+col_width[i]-1); + break; + } + else + { + box_x += col_width[i]+1; + } + } + + // find which row we hit + unsigned long row = 0; + long box_y = total_rect().top(); + for (unsigned long i = 0; i < row_height.size(); ++i) + { + if (box_y <= y && y < box_y+static_cast(row_height[i])) + { + row = i; + hit.set_top(box_y); + hit.set_bottom(box_y+row_height[i]-1); + break; + } + else + { + box_y += row_height[i]+1; + } + } + + // if we hit a box + if (hit.is_empty() == false) + { + move_cursor(row, + col, + mfont->compute_cursor_pos(get_text_rect(row,col), grid[row][col].text, x, y, grid[row][col].first) + ); + } + else + { + drop_input_focus(); + } + } + else + { + drop_input_focus(); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ) + { + scrollable_region::on_mouse_up(btn, state, x, y); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + on_focus_lost ( + ) + { + drop_input_focus(); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + draw ( + const canvas& c + ) const + { + scrollable_region::draw(c); + rectangle area = c.intersect(display_rect()); + if (area.is_empty() == true) + return; + + if (enabled) + fill_rect(c, area, 255); + + // don't do anything if the grid is empty + if (grid.size() == 0) + return; + + // draw all the vertical lines + point p1, p2; + p1.x() = p2.x() = total_rect().left(); + p1.y() = total_rect().top(); + p2.y() = total_rect().bottom(); + for (unsigned long i = 0; i < col_width.size()-1; ++i) + { + p1.x() += col_width[i]; + p2.x() += col_width[i]; + if (enabled) + draw_line(c,p1,p2,border_color_,area); + else + draw_line(c,p1,p2,128,area); + p1.x() += 1; + p2.x() += 1; + } + + // draw all the horizontal lines + p1.y() = p2.y() = total_rect().top(); + p1.x() = display_rect().left(); + p2.x() = display_rect().right(); + for (unsigned long i = 0; i < row_height.size(); ++i) + { + p1.y() += row_height[i]; + p2.y() += row_height[i]; + if (enabled) + draw_line(c,p1,p2,border_color_,area); + else + draw_line(c,p1,p2,128,area); + p1.y() += 1; + p2.y() += 1; + } + + // draw the backgrounds and text for each box + for (long row = 0; row < grid.nr(); ++row) + { + for (long col = 0; col < grid.nc(); ++col) + { + rectangle bg_rect(get_bg_rect(row,col)); + + rectangle text_rect(get_text_rect(row,col)); + + if (enabled) + { + fill_rect(c,bg_rect.intersect(area),grid[row][col].bg_color); + + mfont->draw_string(c, + text_rect, + grid[row][col].text, + grid[row][col].text_color, + grid[row][col].first, + std::string::npos, + area); + } + else + { + mfont->draw_string(c, + text_rect, + grid[row][col].text, + 128, + grid[row][col].first, + std::string::npos, + area); + } + + // if this box has input focus then draw it with a cursor + if (has_focus && active_col == col && active_row == row && show_cursor) + { + rectangle cursor_rect = mfont->compute_cursor_rect(text_rect, + grid[row][col].text, + cursor_pos, + grid[row][col].first); + draw_rectangle(c,cursor_rect,0,area); + } + + } + } + + + } + +// ---------------------------------------------------------------------------------------- + + rectangle text_grid:: + get_text_rect ( + unsigned long row, + unsigned long col + ) const + { + rectangle bg_rect(get_bg_rect(row,col)); + long padding = (bg_rect.height() - mfont->height())/2 + (bg_rect.height() - mfont->height())%2; + if (padding < 0) + padding = 0; + bg_rect.set_left(bg_rect.left()+padding); + bg_rect.set_top(bg_rect.top()+padding); + bg_rect.set_right(bg_rect.right()-padding); + bg_rect.set_bottom(bg_rect.bottom()-padding); + return bg_rect; + } + +// ---------------------------------------------------------------------------------------- + + rectangle text_grid:: + get_bg_rect ( + unsigned long row, + unsigned long col + ) const + { + return translate_rect(grid[row][col].bg_rect, total_rect().left(), total_rect().top()); + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + drop_input_focus ( + ) + { + if (has_focus) + { + parent.invalidate_rectangle(get_text_rect(active_row,active_col)); + has_focus = false; + show_cursor = false; + cursor_timer.stop(); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_grid:: + move_cursor ( + long row, + long col, + long new_cursor_pos + ) + { + // don't do anything if the grid is empty + if (grid.size() == 0) + { + return; + } + + if (row < 0) + row = 0; + if (row >= grid.nr()) + row = grid.nr()-1; + if (col < 0) + col = 0; + if (col >= grid.nc()) + col = grid.nc()-1; + + if (new_cursor_pos < 0) + { + if (col == 0) + { + new_cursor_pos = 0; + } + else + { + --col; + new_cursor_pos = grid[row][col].text.size(); + } + } + + if (new_cursor_pos > static_cast(grid[row][col].text.size())) + { + if (col+1 == grid.nc()) + { + new_cursor_pos = grid[row][col].text.size(); + } + else + { + ++col; + new_cursor_pos = 0; + } + } + + // if some other box had the input focus then redraw it + if (has_focus && (active_row != row || active_col != col )) + { + parent.invalidate_rectangle(get_text_rect(active_row,active_col)); + } + + if (has_focus == false) + { + cursor_timer.start(); + } + + has_focus = true; + recent_cursor_move = true; + show_cursor = true; + active_row = row; + active_col = col; + cursor_pos = new_cursor_pos; + + // adjust the first character to draw so that the string is displayed well + rectangle text_rect(get_text_rect(active_row,active_col)); + rectangle cursor_rect = mfont->compute_cursor_rect(text_rect, + grid[row][col].text, + cursor_pos, + grid[row][col].first); + + // if the cursor rect is too far to the left of the string + if (cursor_pos < static_cast(grid[row][col].first)) + { + if (cursor_pos > 5) + { + grid[row][col].first = cursor_pos - 5; + } + else + { + grid[row][col].first = 0; + } + } + // if the cursor rect is too far to the right of the string + else if (cursor_rect.left() > text_rect.right()) + { + long distance = (cursor_rect.left() - text_rect.right()) + text_rect.width()/3; + // find the letter that is distance pixels from the start of the string + long sum = 0; + for (unsigned long i = grid[row][col].first; i < grid[row][col].text.size(); ++i) + { + sum += (*mfont)[grid[row][col].text[i]].width(); + if (sum >= distance) + { + grid[row][col].first = i; + break; + } + } + } + + scroll_to_rect(get_bg_rect(row,col)); + + // redraw our box + parent.invalidate_rectangle(text_rect); + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // text_field object methods +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + rectangle text_box:: + get_text_rect ( + ) const + { + const unsigned long padding = style->get_padding(*mfont); + + rectangle text_rect; + text_rect.set_left(total_rect().left()+padding); + text_rect.set_top(total_rect().top()+padding); + text_rect.set_right(total_rect().right()-padding); + text_rect.set_bottom(total_rect().bottom()-padding); + return text_rect; + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + enable ( + ) + { + scrollable_region::enable(); + right_click_menu.enable(); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_cut ( + ) + { + on_copy(); + on_delete_selected(); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_copy ( + ) + { + if (highlight_start <= highlight_end) + { + put_on_clipboard(text_.substr(highlight_start, highlight_end-highlight_start+1)); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_paste ( + ) + { + ustring temp_str; + get_from_clipboard(temp_str); + + + if (highlight_start <= highlight_end) + { + text_ = text_.substr(0,highlight_start) + temp_str + + text_.substr(highlight_end+1,text_.size()-highlight_end-1); + move_cursor(highlight_start+temp_str.size()); + highlight_start = 0; + highlight_end = -1; + parent.invalidate_rectangle(rect); + on_no_text_selected(); + + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + } + else + { + text_ = text_.substr(0,cursor_pos) + temp_str + + text_.substr(cursor_pos,text_.size()-cursor_pos); + move_cursor(cursor_pos+temp_str.size()); + + // send out the text modified event + if (temp_str.size() != 0 && text_modified_handler.is_set()) + text_modified_handler(); + } + + adjust_total_rect(); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_select_all ( + ) + { + move_cursor(static_cast(text_.size())); + highlight_start = 0; + highlight_end = static_cast(text_.size()-1); + if (highlight_start <= highlight_end) + on_text_is_selected(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_delete_selected ( + ) + { + if (highlight_start <= highlight_end) + { + text_ = text_.erase(highlight_start,highlight_end-highlight_start+1); + move_cursor(highlight_start); + highlight_start = 0; + highlight_end = -1; + + on_no_text_selected(); + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + + adjust_total_rect(); + + parent.invalidate_rectangle(rect); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_text_is_selected ( + ) + { + right_click_menu.menu().enable_menu_item(0); + right_click_menu.menu().enable_menu_item(1); + right_click_menu.menu().enable_menu_item(3); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_no_text_selected ( + ) + { + right_click_menu.menu().disable_menu_item(0); + right_click_menu.menu().disable_menu_item(1); + right_click_menu.menu().disable_menu_item(3); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + show ( + ) + { + scrollable_region::show(); + right_click_menu.show(); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + disable ( + ) + { + auto_mutex M(m); + scrollable_region::disable(); + t.stop(); + has_focus = false; + cursor_visible = false; + right_click_menu.disable(); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + hide ( + ) + { + auto_mutex M(m); + scrollable_region::hide(); + t.stop(); + has_focus = false; + cursor_visible = false; + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + adjust_total_rect ( + ) + { + const unsigned long padding = style->get_padding(*mfont); + unsigned long text_width; + unsigned long text_height; + + mfont->compute_size(text_, text_width, text_height); + + set_total_rect_size(text_width + padding*2, text_height + padding*2); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + set_main_font ( + const std::shared_ptr& f + ) + { + auto_mutex M(m); + mfont = f; + adjust_total_rect(); + right_click_menu.set_rect(display_rect()); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + draw ( + const canvas& c + ) const + { + scrollable_region::draw(c); + rectangle area = rect.intersect(c); + if (area.is_empty()) + return; + + const point origin(total_rect().left(), total_rect().top()); + + style->draw_text_box(c,display_rect(),get_text_rect(), enabled, *mfont, text_, + translate_rect(cursor_rect, origin), + text_color_, bg_color_, has_focus, cursor_visible, highlight_start, + highlight_end); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + set_text ( + const std::string& text + ) + { + set_text(convert_mbstring_to_wstring(text)); + } + + void text_box:: + set_text ( + const std::wstring& text + ) + { + set_text(convert_wstring_to_utf32(text)); + } + + void text_box:: + set_text ( + const dlib::ustring& text + ) + { + auto_mutex M(m); + // do this to get rid of any reference counting that may be present in + // the std::string implementation. + text_ = text.c_str(); + + adjust_total_rect(); + move_cursor(0); + + highlight_start = 0; + highlight_end = -1; + } + +// ---------------------------------------------------------------------------------------- + + const std::string text_box:: + text ( + ) const + { + std::string temp = convert_wstring_to_mbstring(wtext()); + return temp; + } + + const std::wstring text_box:: + wtext ( + ) const + { + std::wstring temp = convert_utf32_to_wstring(utext()); + return temp; + } + + const dlib::ustring text_box:: + utext ( + ) const + { + auto_mutex M(m); + // do this to get rid of any reference counting that may be present in + // the dlib::ustring implementation. + dlib::ustring temp = text_.c_str(); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex M(m); + scrollable_region::set_size(width,height); + right_click_menu.set_rect(display_rect()); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + set_pos ( + long x, + long y + ) + { + scrollable_region::set_pos(x,y); + right_click_menu.set_rect(get_text_rect()); + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + set_background_color ( + const rgb_pixel color + ) + { + auto_mutex M(m); + bg_color_ = color; + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + const rgb_pixel text_box:: + background_color ( + ) const + { + auto_mutex M(m); + return bg_color_; + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + set_text_color ( + const rgb_pixel color + ) + { + auto_mutex M(m); + text_color_ = color; + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + const rgb_pixel text_box:: + text_color ( + ) const + { + auto_mutex M(m); + return text_color_; + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + if (!enabled || hidden || !has_focus) + { + return; + } + + if (state & base_window::LEFT) + { + if (highlight_start <= highlight_end) + { + if (highlight_start == cursor_pos) + shift_pos = highlight_end + 1; + else + shift_pos = highlight_start; + } + + unsigned long new_pos = mfont->compute_cursor_pos(get_text_rect(),text_,x,y); + if (static_cast(new_pos) != cursor_pos) + { + move_cursor(new_pos); + parent.invalidate_rectangle(rect); + } + } + else if (shift_pos != -1) + { + shift_pos = -1; + } + + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_mouse_up ( + unsigned long btn, + unsigned long, + long , + long + ) + { + if (!enabled || hidden) + return; + + if (btn == base_window::LEFT) + shift_pos = -1; + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool double_clicked + ) + { + using namespace std; + if (!enabled || hidden || btn != (unsigned long)base_window::LEFT) + return; + + if (display_rect().contains(x,y)) + { + has_focus = true; + cursor_visible = true; + parent.invalidate_rectangle(rect); + t.start(); + + + if (double_clicked) + { + // highlight the double clicked word + string::size_type first, last; + const ustring ustr = convert_utf8_to_utf32(std::string(" \t\n")); + first = text_.substr(0,cursor_pos).find_last_of(ustr.c_str()); + last = text_.find_first_of(ustr.c_str(),cursor_pos); + long f = static_cast(first); + long l = static_cast(last); + if (first == string::npos) + f = -1; + if (last == string::npos) + l = static_cast(text_.size()); + + ++f; + --l; + + move_cursor(l+1); + highlight_start = f; + highlight_end = l; + on_text_is_selected(); + } + else + { + if (state & base_window::SHIFT) + { + if (highlight_start <= highlight_end) + { + if (highlight_start == cursor_pos) + shift_pos = highlight_end + 1; + else + shift_pos = highlight_start; + } + else + { + shift_pos = cursor_pos; + } + } + + bool at_end = false; + if (cursor_pos == 0 || cursor_pos == static_cast(text_.size())) + at_end = true; + const long old_pos = cursor_pos; + + unsigned long new_pos = mfont->compute_cursor_pos(get_text_rect(),text_,x,y); + move_cursor(new_pos); + + shift_pos = cursor_pos; + + if (at_end && cursor_pos == old_pos) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + } + } + + } + else if (has_focus && rect.contains(x,y) == false) + { + t.stop(); + has_focus = false; + cursor_visible = false; + shift_pos = -1; + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + + if (focus_lost_handler.is_set()) + focus_lost_handler(); + parent.invalidate_rectangle(rect); + } + else + { + has_focus = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + // If the right click menu is up then we don't want to do anything with + // the keyboard ourselves. Let the popup menu use the keyboard for now. + if (right_click_menu.popup_menu_visible()) + return; + + if (has_focus && enabled && !hidden) + { + const ustring space_str = convert_utf8_to_utf32(std::string(" \t\n")); + const bool shift = (state&base_window::KBD_MOD_SHIFT) != 0; + const bool ctrl = (state&base_window::KBD_MOD_CONTROL) != 0; + + if (shift && is_printable == false) + { + if (shift_pos == -1) + { + if (highlight_start <= highlight_end) + { + if (highlight_start == cursor_pos) + shift_pos = highlight_end + 1; + else + shift_pos = highlight_start; + } + else + { + shift_pos = cursor_pos; + } + } + } + else + { + shift_pos = -1; + } + + if (key == base_window::KEY_LEFT) + { + if (cursor_pos != 0) + { + unsigned long new_pos; + if (ctrl) + { + // find the first non-whitespace to our left + std::string::size_type pos = text_.find_last_not_of(space_str.c_str(),cursor_pos); + if (pos != std::string::npos) + { + pos = text_.find_last_of(space_str.c_str(),pos); + if (pos != std::string::npos) + new_pos = static_cast(pos); + else + new_pos = 0; + } + else + { + new_pos = 0; + } + } + else + { + new_pos = cursor_pos-1; + } + + move_cursor(new_pos); + } + else if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + + } + else if (key == base_window::KEY_RIGHT) + { + if (cursor_pos != static_cast(text_.size())) + { + unsigned long new_pos; + if (ctrl) + { + // find the first non-whitespace to our left + std::string::size_type pos = text_.find_first_not_of(space_str.c_str(),cursor_pos); + if (pos != std::string::npos) + { + pos = text_.find_first_of(space_str.c_str(),pos); + if (pos != std::string::npos) + new_pos = static_cast(pos+1); + else + new_pos = static_cast(text_.size()); + } + else + { + new_pos = static_cast(text_.size()); + } + } + else + { + new_pos = cursor_pos+1; + } + + move_cursor(new_pos); + } + else if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + } + else if (key == base_window::KEY_UP) + { + if (ctrl) + { + move_cursor(0); + } + else + { + const point origin(total_rect().left(), total_rect().top()); + // move the cursor so the position that is just a few pixels above + // the current cursor_rect + move_cursor(mfont->compute_cursor_pos( + get_text_rect(), text_, cursor_rect.left()+origin.x(), + cursor_rect.top()+origin.y()-mfont->height()/2)); + + } + + if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + } + else if (key == base_window::KEY_DOWN) + { + if (ctrl) + { + move_cursor(static_cast(text_.size())); + } + else + { + const point origin(total_rect().left(), total_rect().top()); + // move the cursor so the position that is just a few pixels above + // the current cursor_rect + move_cursor(mfont->compute_cursor_pos( + get_text_rect(), text_, cursor_rect.left()+origin.x(), + cursor_rect.bottom()+origin.y()+mfont->height()/2)); + } + + if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + } + else if (is_printable) + { + if (ctrl) + { + if (key == 'a') + { + on_select_all(); + } + else if (key == 'c') + { + on_copy(); + } + else if (key == 'v') + { + on_paste(); + } + else if (key == 'x') + { + on_cut(); + } + } + else + { + if (highlight_start <= highlight_end) + { + text_ = text_.substr(0,highlight_start) + static_cast(key) + + text_.substr(highlight_end+1,text_.size()-highlight_end-1); + + adjust_total_rect(); + move_cursor(highlight_start+1); + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + } + else + { + text_ = text_.substr(0,cursor_pos) + static_cast(key) + + text_.substr(cursor_pos,text_.size()-cursor_pos); + adjust_total_rect(); + move_cursor(cursor_pos+1); + } + + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + + } + + if (key == '\n') + { + if (enter_key_handler.is_set()) + enter_key_handler(); + } + } + else if (key == base_window::KEY_BACKSPACE) + { + // if something is highlighted then delete that + if (highlight_start <= highlight_end) + { + on_delete_selected(); + } + else if (cursor_pos != 0) + { + text_ = text_.erase(cursor_pos-1,1); + adjust_total_rect(); + move_cursor(cursor_pos-1); + + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + } + else + { + // do this just so it repaints itself right + move_cursor(cursor_pos); + } + + } + else if (key == base_window::KEY_DELETE) + { + // if something is highlighted then delete that + if (highlight_start <= highlight_end) + { + on_delete_selected(); + } + else if (cursor_pos != static_cast(text_.size())) + { + text_ = text_.erase(cursor_pos,1); + + adjust_total_rect(); + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + } + else + { + // do this just so it repaints itself right + move_cursor(cursor_pos); + } + + } + else if (key == base_window::KEY_HOME) + { + if (ctrl) + { + move_cursor(0); + } + else if (cursor_pos != 0) + { + // find the start of the current line + ustring::size_type pos = text_.find_last_of('\n',cursor_pos-1); + if (pos == ustring::npos) + pos = 0; + else + pos += 1; + move_cursor(static_cast(pos)); + + } + + if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + } + else if (key == base_window::KEY_END) + { + if (ctrl) + { + move_cursor(static_cast(text_.size())); + } + { + ustring::size_type pos = text_.find_first_of('\n',cursor_pos); + if (pos == ustring::npos) + pos = text_.size(); + + move_cursor(static_cast(pos)); + } + + if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + parent.invalidate_rectangle(rect); + } + } + else if (key == base_window::KEY_PAGE_DOWN || key == base_window::KEY_PAGE_UP) + { + long jump_size = display_rect().height() - + std::min(mfont->height()*3, display_rect().height()/5); + + // if we are supposed to page up then just jump in the other direction + if (key == base_window::KEY_PAGE_UP) + jump_size = -jump_size; + + scroll_to_rect(translate_rect(display_rect(), point(0, jump_size ))); + } + + cursor_visible = true; + recent_movement = true; + + } + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + on_string_put( + const std::wstring &str + ) + { + if (has_focus && enabled && !hidden) + { + ustring ustr = convert_wstring_to_utf32(str); + if (highlight_start <= highlight_end) + { + text_ = text_.substr(0,highlight_start) + ustr + + text_.substr(highlight_end+1,text_.size()-highlight_end-1); + + adjust_total_rect(); + move_cursor(highlight_start+ustr.size()); + highlight_start = 0; + highlight_end = -1; + on_no_text_selected(); + } + else + { + text_ = text_.substr(0,cursor_pos) + ustr + + text_.substr(cursor_pos,text_.size()-cursor_pos); + + adjust_total_rect(); + move_cursor(cursor_pos+ustr.size()); + } + + + // send out the text modified event + if (text_modified_handler.is_set()) + text_modified_handler(); + } + } + +// ---------------------------------------------------------------------------------------- + + void text_box:: + move_cursor ( + unsigned long pos + ) + { + using namespace std; + const long old_cursor_pos = cursor_pos; + + + + // figure out where the cursor is supposed to be + cursor_rect = mfont->compute_cursor_rect(get_text_rect(), text_, pos); + const point origin(total_rect().left(), total_rect().top()); + + + cursor_pos = pos; + + + const unsigned long padding = style->get_padding(*mfont); + + // now scroll us so that we can see the current cursor + scroll_to_rect(centered_rect(cursor_rect, cursor_rect.width() + padding + 6, cursor_rect.height() + 1)); + + // adjust the cursor_rect so that it is relative to the total_rect + cursor_rect = translate_rect(cursor_rect, -origin); + + parent.set_im_pos(cursor_rect.left(), cursor_rect.top()); + + if (old_cursor_pos != cursor_pos) + { + if (shift_pos != -1) + { + highlight_start = std::min(shift_pos,cursor_pos); + highlight_end = std::max(shift_pos,cursor_pos)-1; + } + + if (highlight_start > highlight_end) + on_no_text_selected(); + else + on_text_is_selected(); + + recent_movement = true; + cursor_visible = true; + parent.invalidate_rectangle(display_rect()); + } + + if (shift_pos == -1) + { + highlight_start = 0; + highlight_end = -1; + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// perspective_display member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + perspective_display:: + perspective_display( + drawable_window& w + ) : + drawable(w,MOUSE_MOVE|MOUSE_CLICK|MOUSE_WHEEL) + { + clear_overlay(); + enable_events(); + } + +// ---------------------------------------------------------------------------------------- + + perspective_display:: + ~perspective_display( + ) + { + disable_events(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex lock(m); + rectangle old(rect); + rect = resize_rect(rect,width,height); + tform = camera_transform(tform.get_camera_pos(), + tform.get_camera_looking_at(), + tform.get_camera_up_direction(), + tform.get_camera_field_of_view(), + std::min(rect.width(),rect.height())); + parent.invalidate_rectangle(old+rect); + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + add_overlay ( + const std::vector& overlay + ) + { + auto_mutex M(m); + if (overlay.size() == 0) + return; + // push this new overlay into our overlay vector + overlay_lines.insert(overlay_lines.end(), overlay.begin(), overlay.end()); + + for (unsigned long i = 0; i < overlay.size(); ++i) + { + sum_pts += overlay[i].p1; + sum_pts += overlay[i].p2; + max_pts.x() = std::max(overlay[i].p1.x(), max_pts.x()); + max_pts.x() = std::max(overlay[i].p2.x(), max_pts.x()); + max_pts.y() = std::max(overlay[i].p1.y(), max_pts.y()); + max_pts.y() = std::max(overlay[i].p2.y(), max_pts.y()); + max_pts.z() = std::max(overlay[i].p1.z(), max_pts.z()); + max_pts.z() = std::max(overlay[i].p2.z(), max_pts.z()); + } + + tform = camera_transform(max_pts, + sum_pts/(overlay_lines.size()*2+overlay_dots.size()), + vector(0,0,1), + tform.get_camera_field_of_view(), + std::min(rect.width(),rect.height())); + + + // make the parent window redraw us now that we changed the overlay + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + add_overlay ( + const std::vector& overlay + ) + { + auto_mutex M(m); + if (overlay.size() == 0) + return; + + for (unsigned long i = 0; i < overlay.size(); ++i) + { + overlay_dots.push_back(overlay[i]); + + sum_pts += overlay[i].p; + max_pts.x() = std::max(overlay[i].p.x(), max_pts.x()); + max_pts.y() = std::max(overlay[i].p.y(), max_pts.y()); + max_pts.z() = std::max(overlay[i].p.z(), max_pts.z()); + } + + tform = camera_transform(max_pts, + sum_pts/(overlay_lines.size()*2+overlay_dots.size()), + vector(0,0,1), + tform.get_camera_field_of_view(), + std::min(rect.width(),rect.height())); + + + // make the parent window redraw us now that we changed the overlay + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + clear_overlay ( + ) + { + auto_mutex lock(m); + overlay_dots.clear(); + overlay_lines.clear(); + sum_pts = vector(); + max_pts = vector(-std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + -std::numeric_limits::infinity()); + + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + set_dot_double_clicked_handler ( + const any_function&)>& event_handler_ + ) + { + auto_mutex M(m); + dot_clicked_event_handler = event_handler_; + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + draw ( + const canvas& c + ) const + { + if (depth.nr() < (long)c.height() || depth.nc() < (long)c.width()) + depth.set_size(c.height(), c.width()); + assign_all_pixels(depth, std::numeric_limits::infinity()); + + rectangle area = rect.intersect(c); + fill_rect(c, area, 0); + for (unsigned long i = 0; i < overlay_lines.size(); ++i) + { + draw_line(c, tform(overlay_lines[i].p1)+rect.tl_corner(), + tform(overlay_lines[i].p2)+rect.tl_corner(), + overlay_lines[i].color, + area); + } + for (unsigned long i = 0; i < overlay_dots.size(); ++i) + { + double scale, distance; + point p = tform(overlay_dots[i].p, scale, distance) + rect.tl_corner(); + if (area.contains(p) && depth[p.y()-c.top()][p.x()-c.left()] > distance) + { + depth[p.y()-c.top()][p.x()-c.left()] = distance; + assign_pixel(c[p.y()-c.top()][p.x()-c.left()], overlay_dots[i].color); + } + } + + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + on_wheel_up ( + unsigned long //state + ) + { + if (rect.contains(lastx,lasty) == false || hidden || !enabled) + return; + + const double alpha = 0.10; + const vector delta = alpha*(tform.get_camera_pos() - tform.get_camera_looking_at()); + tform = camera_transform( + tform.get_camera_pos() - delta, + tform.get_camera_looking_at(), + tform.get_camera_up_direction(), + tform.get_camera_field_of_view(), + std::min(rect.width(),rect.height())); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + on_wheel_down ( + unsigned long //state + ) + { + if (rect.contains(lastx,lasty) == false || hidden || !enabled) + return; + + const double alpha = 0.10; + const vector delta = alpha*(tform.get_camera_pos() - tform.get_camera_looking_at()); + tform = camera_transform( + tform.get_camera_pos() + delta, + tform.get_camera_looking_at(), + tform.get_camera_up_direction(), + tform.get_camera_field_of_view(), + std::min(rect.width(),rect.height())); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + on_mouse_down ( + unsigned long btn, + unsigned long, //state + long x, + long y, + bool is_double_click + ) + { + if (btn == base_window::LEFT || btn == base_window::RIGHT) + { + last = point(x,y); + } + if (is_double_click && btn == base_window::LEFT && enabled && !hidden && overlay_dots.size() != 0) + { + double best_dist = std::numeric_limits::infinity(); + unsigned long best_idx = 0; + const dpoint pp(x,y); + for (unsigned long i = 0; i < overlay_dots.size(); ++i) + { + dpoint p = tform(overlay_dots[i].p) + rect.tl_corner(); + double dist = length_squared(p-pp); + if (dist < best_dist) + { + best_dist = dist; + best_idx = i; + } + } + if (dot_clicked_event_handler.is_set()) + dot_clicked_event_handler(overlay_dots[best_idx].p); + } + } + +// ---------------------------------------------------------------------------------------- + + void perspective_display:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + if (!enabled || hidden) + return; + + if (state == base_window::LEFT) + { + const point cur(x, y); + dpoint delta = last-cur; + last = cur; + + const vector radius = tform.get_camera_pos()-tform.get_camera_looking_at(); + delta *= 2*pi*length(radius)/600.0; + vector tangent_x = tform.get_camera_up_direction().cross(radius).normalize(); + vector tangent_y = radius.cross(tangent_x).normalize(); + vector new_pos = tform.get_camera_pos() + tangent_x*delta.x() + tangent_y*-delta.y(); + + // now make it have the correct radius relative to the looking at point. + new_pos = (new_pos-tform.get_camera_looking_at()).normalize()*length(radius) + tform.get_camera_looking_at(); + + tform = camera_transform(new_pos, + tform.get_camera_looking_at(), + tangent_y, + tform.get_camera_field_of_view(), + std::min(rect.width(),rect.height())); + parent.invalidate_rectangle(rect); + } + else if (state == (base_window::LEFT|base_window::SHIFT) || + state == base_window::RIGHT) + { + const point cur(x, y); + dpoint delta = last-cur; + last = cur; + + const vector radius = tform.get_camera_pos()-tform.get_camera_looking_at(); + delta *= 2*pi*length(radius)/600.0; + vector tangent_x = tform.get_camera_up_direction().cross(radius).normalize(); + vector tangent_y = radius.cross(tangent_x).normalize(); + + vector offset = tangent_x*delta.x() + tangent_y*-delta.y(); + + + tform = camera_transform( + tform.get_camera_pos()+offset, + tform.get_camera_looking_at()+offset, + tform.get_camera_up_direction(), + tform.get_camera_field_of_view(), + std::min(rect.width(),rect.height())); + parent.invalidate_rectangle(rect); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// image_display member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + class image_display_functor + { + const std::string str; + const member_function_pointer mfp; + public: + image_display_functor ( + const std::string& str_, + const member_function_pointer& mfp_ + ) : str(str_), + mfp(mfp_) + {} + + void operator() ( + ) const { mfp(str); } + }; + } + + image_display:: + image_display( + drawable_window& w + ): + scrollable_region(w,KEYBOARD_EVENTS), + zoom_in_scale(1), + zoom_out_scale(1), + drawing_rect(true), + rect_is_selected(false), + selected_rect(0), + default_rect_color(255,0,0,255), + parts_menu(w), + part_width(100), // "parts" circles are drawn 1.0/part_width size on the screen relative to the size of the bounding rectangle. + overlay_editing_enabled(true), + highlight_timer(*this, &image_display::timer_event_unhighlight_rect), + highlighted_rect(std::numeric_limits::max()), + holding_shift_key(false) + { + enable_mouse_drag(); + + highlight_timer.set_delay_time(250); + set_horizontal_scroll_increment(1); + set_vertical_scroll_increment(1); + set_horizontal_mouse_wheel_scroll_increment(30); + set_vertical_mouse_wheel_scroll_increment(30); + + parts_menu.disable(); + + + enable_events(); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + on_part_add ( + const std::string& part_name + ) + { + if (!rect_is_selected) + return; + + const point loc = last_right_click_pos; + + // Transform loc from gui window space into the space used by the overlay + // rectangles (i.e. relative to the raw image) + const point origin(total_rect().tl_corner()); + point c1 = loc - origin; + if (zoom_in_scale != 1) + { + c1 = c1/(double)zoom_in_scale; + } + else if (zoom_out_scale != 1) + { + c1 = c1*(double)zoom_out_scale; + } + + overlay_rects[selected_rect].parts[part_name] = c1; + parent.invalidate_rectangle(rect); + + if (event_handler.is_set()) + event_handler(); + } + +// ---------------------------------------------------------------------------------------- + + image_display:: + ~image_display( + ) + { + highlight_timer.stop_and_wait(); + disable_events(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + rectangle image_display:: + get_image_display_rect ( + ) const + { + if (zoom_in_scale != 1) + { + return rectangle(0,0, img.nc()*zoom_in_scale-1, img.nr()*zoom_in_scale-1); + } + else if (zoom_out_scale != 1) + { + return rectangle(0,0, img.nc()/zoom_out_scale-1, img.nr()/zoom_out_scale-1); + } + else + { + return dlib::get_rect(img); + } + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + add_overlay ( + const overlay_rect& overlay + ) + { + auto_mutex M(m); + // push this new overlay into our overlay vector + overlay_rects.push_back(overlay); + + // make the parent window redraw us now that we changed the overlay + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + add_overlay ( + const overlay_line& overlay + ) + { + auto_mutex M(m); + + // push this new overlay into our overlay vector + overlay_lines.push_back(overlay); + + // make the parent window redraw us now that we changed the overlay + parent.invalidate_rectangle(get_rect_on_screen(rectangle(overlay.p1, overlay.p2))); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + add_overlay ( + const overlay_circle& overlay + ) + { + auto_mutex M(m); + + // push this new overlay into our overlay vector + overlay_circles.push_back(overlay); + + // make the parent window redraw us now that we changed the overlay + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + add_overlay ( + const std::vector& overlay + ) + { + auto_mutex M(m); + + // push this new overlay into our overlay vector + overlay_rects.insert(overlay_rects.end(), overlay.begin(), overlay.end()); + + // make the parent window redraw us now that we changed the overlay + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + add_overlay ( + const std::vector& overlay + ) + { + auto_mutex M(m); + + // push this new overlay into our overlay vector + overlay_lines.insert(overlay_lines.end(), overlay.begin(), overlay.end()); + + // make the parent window redraw us now that we changed the overlay + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + add_overlay ( + const std::vector& overlay + ) + { + auto_mutex M(m); + + // push this new overlay into our overlay vector + overlay_circles.insert(overlay_circles.end(), overlay.begin(), overlay.end()); + + // make the parent window redraw us now that we changed the overlay + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + clear_overlay ( + ) + { + auto_mutex M(m); + overlay_rects.clear(); + overlay_lines.clear(); + overlay_circles.clear(); + parent.invalidate_rectangle(rect); + } + +// ---------------------------------------------------------------------------------------- + + rectangle image_display:: + get_rect_on_screen ( + rectangle orect + ) const + { + const point origin(total_rect().tl_corner()); + orect.left() = orect.left()*zoom_in_scale/zoom_out_scale; + orect.top() = orect.top()*zoom_in_scale/zoom_out_scale; + if (zoom_in_scale != 1) + { + // make it so the box surrounds the pixels when we zoom in. + orect.right() = (orect.right()+1)*zoom_in_scale/zoom_out_scale; + orect.bottom() = (orect.bottom()+1)*zoom_in_scale/zoom_out_scale; + } + else + { + orect.right() = orect.right()*zoom_in_scale/zoom_out_scale; + orect.bottom() = orect.bottom()*zoom_in_scale/zoom_out_scale; + } + + return translate_rect(orect, origin); + } + +// ---------------------------------------------------------------------------------------- + + rectangle image_display:: + get_rect_on_screen ( + unsigned long idx + ) const + { + return get_rect_on_screen(overlay_rects[idx].rect); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + draw ( + const canvas& c + ) const + { + scrollable_region::draw(c); + + rectangle area = display_rect().intersect(c); + if (area.is_empty()) + return; + + const point origin(total_rect().tl_corner()); + + // draw the image on the screen + const double scale = zoom_out_scale/(double)zoom_in_scale; + const rectangle img_area = total_rect().intersect(area); + for (long row = img_area.top(); row <= img_area.bottom(); ++row) + { + const long rc = row-c.top(); + const long rimg = (row-origin.y())*scale; + for (long col = img_area.left(); col <= img_area.right(); ++col) + { + assign_pixel(c[rc][col-c.left()], + img[rimg][(col-origin.x())*scale]); + } + } + + // draw the mouse cross-hairs + if (holding_shift_key && total_rect().contains(lastx,lasty) ) + { + draw_line(c, point(lastx,-10000), point(lastx,100000),rgb_pixel(255,255,0), area); + draw_line(c, point(-10000,lasty), point(100000,lasty),rgb_pixel(255,255,0), area); + } + + // now draw all the overlay rectangles + for (unsigned long i = 0; i < overlay_rects.size(); ++i) + { + const rectangle orect = get_rect_on_screen(i); + rgb_alpha_pixel color = overlay_rects[i].color; + // draw crossed out boxes slightly faded + if (overlay_rects[i].crossed_out) + color.alpha = 150; + + if (rect_is_selected && selected_rect == i) + { + draw_rectangle(c, orect, invert_pixel(color), area); + } + else if (highlighted_rect < overlay_rects.size() && highlighted_rect == i) + { + // Draw the rectangle wider and with a slightly different color that tapers + // out at the edges of the line. + hsi_pixel temp; + assign_pixel(temp, 0); + assign_pixel(temp, overlay_rects[i].color); + temp.s = 255; + temp.h = temp.h + 20; + if (temp.i < 245) + temp.i += 10; + rgb_pixel p; + assign_pixel(p, temp); + rgb_alpha_pixel po, po2; + assign_pixel(po, p); + po.alpha = 160; + po2 = po; + po2.alpha = 90; + draw_rectangle(c, grow_rect(orect,2), po2, area); + draw_rectangle(c, grow_rect(orect,1), po, area); + draw_rectangle(c, orect, p, area); + draw_rectangle(c, shrink_rect(orect,1), po, area); + draw_rectangle(c, shrink_rect(orect,2), po2, area); + } + else + { + draw_rectangle(c, orect, color, area); + } + + if (overlay_rects[i].label.size() != 0) + { + // make a rectangle that is at the spot we want to draw our string + rectangle r(orect.br_corner(), c.br_corner()); + mfont->draw_string(c, r, overlay_rects[i].label, color, 0, + std::string::npos, area); + } + + + // draw circles for each "part" in this overlay rectangle. + std::map::const_iterator itr; + for (itr = overlay_rects[i].parts.begin(); itr != overlay_rects[i].parts.end(); ++itr) + { + if (itr->second == OBJECT_PART_NOT_PRESENT) + continue; + + const long part_size = (long)std::max(1.0,std::round(std::sqrt(orect.area())/part_width)); + rectangle temp = centered_rect(get_rect_on_screen(centered_rect(itr->second,1,1)), part_size, part_size); + + if (rect_is_selected && selected_rect == i && + selected_part_name.size() != 0 && selected_part_name == itr->first) + { + draw_circle(c, center(temp), temp.width(), invert_pixel(color), area); + } + else + { + draw_circle(c, center(temp), temp.width(), color, area); + } + + // make a rectangle that is at the spot we want to draw our string + rectangle r((temp.br_corner() + temp.bl_corner())/2, + c.br_corner()); + mfont->draw_string(c, r, itr->first, color, 0, + std::string::npos, area); + } + + if (overlay_rects[i].crossed_out) + { + if (rect_is_selected && selected_rect == i) + { + draw_line(c, orect.tl_corner(), orect.br_corner(),invert_pixel(color), area); + draw_line(c, orect.bl_corner(), orect.tr_corner(),invert_pixel(color), area); + } + else + { + draw_line(c, orect.tl_corner(), orect.br_corner(),color, area); + draw_line(c, orect.bl_corner(), orect.tr_corner(),color, area); + } + } + } + + // now draw all the overlay lines + for (unsigned long i = 0; i < overlay_lines.size(); ++i) + { + draw_line(c, + zoom_in_scale*overlay_lines[i].p1/zoom_out_scale + origin, + zoom_in_scale*overlay_lines[i].p2/zoom_out_scale + origin, + overlay_lines[i].color, area); + } + + // now draw all the overlay circles + for (unsigned long i = 0; i < overlay_circles.size(); ++i) + { + const point center = zoom_in_scale*overlay_circles[i].center/zoom_out_scale + origin; + const int radius = zoom_in_scale*overlay_circles[i].radius/zoom_out_scale; + draw_circle(c, + center, + radius, + overlay_circles[i].color, area); + + if (overlay_circles[i].label.size() != 0) + { + const point temp = center + point(0,radius); + + // make a rectangle that is at the spot we want to draw our string + rectangle r(temp, c.br_corner()); + mfont->draw_string(c, r, overlay_circles[i].label, overlay_circles[i].color, 0, + std::string::npos, area); + } + } + + if (drawing_rect) + draw_rectangle(c, rect_to_draw, invert_pixel(default_rect_color), area); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + scrollable_region::on_keydown(key,is_printable, state); + + if (!is_printable && key==base_window::KEY_SHIFT) + { + if (!holding_shift_key) + { + holding_shift_key = true; + parent.invalidate_rectangle(rect); + } + } + else if (holding_shift_key) + { + holding_shift_key = false; + parent.invalidate_rectangle(rect); + } + + if (!is_printable && !hidden && enabled && rect_is_selected && + (key == base_window::KEY_BACKSPACE || key == base_window::KEY_DELETE)) + { + moving_overlay = false; + rect_is_selected = false; + parts_menu.disable(); + if (selected_part_name.size() == 0) + overlay_rects.erase(overlay_rects.begin() + selected_rect); + else + overlay_rects[selected_rect].parts.erase(selected_part_name); + parent.invalidate_rectangle(rect); + + if (event_handler.is_set()) + event_handler(); + } + + if (!hidden && enabled && rect_is_selected && + ((is_printable && key == 'i') || (!is_printable && key==base_window::KEY_END))) + { + overlay_rects[selected_rect].crossed_out = !overlay_rects[selected_rect].crossed_out; + parent.invalidate_rectangle(rect); + + if (event_handler.is_set()) + event_handler(); + } + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + add_labelable_part_name ( + const std::string& name + ) + { + auto_mutex lock(m); + if (part_names.insert(name).second) + { + member_function_pointer mfp; + mfp.set(*this,&image_display::on_part_add); + parts_menu.menu().add_menu_item(menu_item_text("Add " + name,impl::image_display_functor(name,mfp))); + } + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + clear_labelable_part_names ( + ) + { + auto_mutex lock(m); + part_names.clear(); + parts_menu.menu().clear(); + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ) + { + scrollable_region::on_mouse_down(btn, state, x, y, is_double_click); + + if (state&base_window::SHIFT) + { + holding_shift_key = true; + } + else if (holding_shift_key) + { + holding_shift_key = false; + parent.invalidate_rectangle(rect); + } + + if (rect.contains(x,y) == false || hidden || !enabled) + return; + + if (image_clicked_handler.is_set()) + { + const point origin(total_rect().tl_corner()); + point p(x,y); + p -= origin; + if (zoom_in_scale != 1) + p = p/zoom_in_scale; + else if (zoom_out_scale != 1) + p = p*zoom_out_scale; + + if (dlib::get_rect(img).contains(p)) + image_clicked_handler(p, is_double_click, btn); + } + + if (!overlay_editing_enabled) + return; + + if (btn == base_window::RIGHT && (state&base_window::SHIFT)) + { + const bool rect_was_selected = rect_is_selected; + rect_is_selected = false; + parts_menu.disable(); + + long best_dist = std::numeric_limits::max(); + long best_idx = 0; + std::string best_part; + + // check if this click landed on any of the overlay rectangles + for (unsigned long i = 0; i < overlay_rects.size(); ++i) + { + const rectangle orect = get_rect_on_screen(i); + + const long dist = distance_to_rect_edge(orect, point(x,y)); + + if (dist < best_dist) + { + best_dist = dist; + best_idx = i; + best_part.clear(); + } + + std::map::const_iterator itr; + for (itr = overlay_rects[i].parts.begin(); itr != overlay_rects[i].parts.end(); ++itr) + { + if (itr->second == OBJECT_PART_NOT_PRESENT) + continue; + + const long part_size = (long)std::max(1.0,std::round(std::sqrt(orect.area())/part_width)); + rectangle temp = centered_rect(get_rect_on_screen(centered_rect(itr->second,1,1)), part_size, part_size); + point c = center(temp); + + // distance from edge of part circle + const long dist = static_cast(std::abs(length(c - point(x,y)) + 0.5 - temp.width())); + if (dist < best_dist) + { + best_idx = i; + best_dist = dist; + best_part = itr->first; + } + } + } + + + if (best_dist < 13) + { + moving_overlay = true; + moving_rect = best_idx; + moving_part_name = best_part; + // If we are moving one of the sides of the rectangle rather than one of + // the parts circles then we need to figure out which side of the rectangle + // we are moving. + if (best_part.size() == 0) + { + // which side is the click closest to? + const rectangle orect = get_rect_on_screen(best_idx); + const point p = nearest_point(orect,point(x,y)); + long dist_left = std::abs(p.x()-orect.left()); + long dist_top = std::abs(p.y()-orect.top()); + long dist_right = std::abs(p.x()-orect.right()); + long dist_bottom = std::abs(p.y()-orect.bottom()); + long min_val = std::min(std::min(dist_left,dist_right),std::min(dist_top,dist_bottom)); + if (dist_left == min_val) + moving_what = MOVING_RECT_LEFT; + else if (dist_top == min_val) + moving_what = MOVING_RECT_TOP; + else if (dist_right == min_val) + moving_what = MOVING_RECT_RIGHT; + else + moving_what = MOVING_RECT_BOTTOM; + } + else + { + moving_what = MOVING_PART; + } + // Do this to make the moving stuff snap to the mouse immediately. + on_mouse_move(state|btn,x,y); + } + + if (rect_was_selected) + parent.invalidate_rectangle(rect); + + return; + } + + if (btn == base_window::RIGHT && rect_is_selected) + { + last_right_click_pos = point(x,y); + parts_menu.set_rect(rect); + return; + } + + if (btn == base_window::LEFT && (state&base_window::CONTROL) && !drawing_rect) + { + long best_dist = std::numeric_limits::max(); + long best_idx = 0; + // check if this click landed on any of the overlay rectangles + for (unsigned long i = 0; i < overlay_rects.size(); ++i) + { + const rectangle orect = get_rect_on_screen(i); + const long dist = distance_to_rect_edge(orect, point(x,y)); + + if (dist < best_dist) + { + best_dist = dist; + best_idx = i; + } + } + if (best_dist < 13) + { + overlay_rects[best_idx].label = default_rect_label; + overlay_rects[best_idx].color = default_rect_color; + highlighted_rect = best_idx; + highlight_timer.stop(); + highlight_timer.start(); + if (event_handler.is_set()) + event_handler(); + parent.invalidate_rectangle(rect); + } + return; + } + + + if (!is_double_click && btn == base_window::LEFT && (state&base_window::SHIFT)) + { + drawing_rect = true; + rect_anchor = point(x,y); + + if (rect_is_selected) + { + rect_is_selected = false; + parts_menu.disable(); + parent.invalidate_rectangle(rect); + } + } + else if (drawing_rect) + { + if (rect_is_selected) + { + rect_is_selected = false; + parts_menu.disable(); + } + + drawing_rect = false; + parent.invalidate_rectangle(rect); + } + else if (is_double_click) + { + const bool rect_was_selected = rect_is_selected; + rect_is_selected = false; + parts_menu.disable(); + + long best_dist = std::numeric_limits::max(); + long best_idx = 0; + std::string best_part; + + // check if this click landed on any of the overlay rectangles + for (unsigned long i = 0; i < overlay_rects.size(); ++i) + { + const rectangle orect = get_rect_on_screen(i); + + const long dist = distance_to_rect_edge(orect, point(x,y)); + + if (dist < best_dist) + { + best_dist = dist; + best_idx = i; + best_part.clear(); + } + + std::map::const_iterator itr; + for (itr = overlay_rects[i].parts.begin(); itr != overlay_rects[i].parts.end(); ++itr) + { + if (itr->second == OBJECT_PART_NOT_PRESENT) + continue; + + const long part_size = (long)std::max(1.0,std::round(std::sqrt(orect.area())/part_width)); + rectangle temp = centered_rect(get_rect_on_screen(centered_rect(itr->second,1,1)), part_size, part_size); + point c = center(temp); + + // distance from edge of part circle + const long dist = static_cast(std::abs(length(c - point(x,y)) + 0.5 - temp.width())); + if (dist < best_dist) + { + best_idx = i; + best_dist = dist; + best_part = itr->first; + } + } + } + + + if (best_dist < 13) + { + rect_is_selected = true; + if (part_names.size() != 0) + parts_menu.enable(); + selected_rect = best_idx; + selected_part_name = best_part; + if (orect_selected_event_handler.is_set()) + orect_selected_event_handler(overlay_rects[best_idx]); + } + + if (rect_is_selected || rect_was_selected) + parent.invalidate_rectangle(rect); + } + else if (rect_is_selected) + { + rect_is_selected = false; + parts_menu.disable(); + parent.invalidate_rectangle(rect); + } + } + +// ---------------------------------------------------------------------------------------- + + std::vector image_display:: + get_overlay_rects ( + ) const + { + auto_mutex lock(m); + return overlay_rects; + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + set_default_overlay_rect_label ( + const std::string& label + ) + { + auto_mutex lock(m); + default_rect_label = label; + } + +// ---------------------------------------------------------------------------------------- + + std::string image_display:: + get_default_overlay_rect_label ( + ) const + { + auto_mutex lock(m); + return default_rect_label; + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + set_default_overlay_rect_color ( + const rgb_alpha_pixel& color + ) + { + auto_mutex lock(m); + default_rect_color = color; + } + +// ---------------------------------------------------------------------------------------- + + rgb_alpha_pixel image_display:: + get_default_overlay_rect_color ( + ) const + { + auto_mutex lock(m); + return default_rect_color; + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ) + { + scrollable_region::on_mouse_up(btn,state,x,y); + + if (state&base_window::SHIFT) + { + holding_shift_key = true; + } + else if (holding_shift_key) + { + holding_shift_key = false; + parent.invalidate_rectangle(rect); + } + + if (drawing_rect && btn == base_window::LEFT && (state&base_window::SHIFT) && + !hidden && enabled) + { + const point origin(total_rect().tl_corner()); + point c1 = point(x,y) - origin; + point c2 = rect_anchor - origin; + + if (zoom_in_scale != 1) + { + c1 = c1/(double)zoom_in_scale; + c2 = c2/(double)zoom_in_scale; + } + else if (zoom_out_scale != 1) + { + c1 = c1*(double)zoom_out_scale; + c2 = c2*(double)zoom_out_scale; + } + + rectangle new_rect(c1,c2); + if (zoom_in_scale != 1) + { + // When we are zoomed in we adjust the rectangles a little so they + // are drown surrounding the pixels inside the rect. This adjustment + // is necessary to make this code consistent with this goal. + new_rect.right() -= 1; + new_rect.bottom() -= 1; + } + + + if (new_rect.width() > 0 && new_rect.height() > 0) + { + add_overlay(overlay_rect(new_rect, default_rect_color, default_rect_label)); + + if (event_handler.is_set()) + event_handler(); + } + } + + if (drawing_rect) + { + drawing_rect = false; + parent.invalidate_rectangle(rect); + } + if (moving_overlay) + { + moving_overlay = false; + } + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + scrollable_region::on_mouse_move(state,x,y); + + if (enabled && !hidden) + { + if (holding_shift_key) + parent.invalidate_rectangle(rect); + + if (state&base_window::SHIFT) + holding_shift_key = true; + else if (holding_shift_key) + holding_shift_key = false; + } + + if (drawing_rect) + { + if ((state&base_window::LEFT) && (state&base_window::SHIFT) && !hidden && enabled) + { + rectangle new_rect(point(x,y), rect_anchor); + parent.invalidate_rectangle(new_rect + rect_to_draw); + rect_to_draw = new_rect; + } + else + { + drawing_rect = false; + parent.invalidate_rectangle(rect); + } + moving_overlay = false; + } + else if (moving_overlay) + { + if ((state&base_window::RIGHT) && (state&base_window::SHIFT) && !hidden && enabled) + { + // map point(x,y) into the image coordinate space. + point p = point(x,y) - total_rect().tl_corner(); + if (zoom_in_scale != 1) + { + if (moving_what == MOVING_PART) + p = p/(double)zoom_in_scale-dpoint(0.5,0.5); + else + p = p/(double)zoom_in_scale; + } + else if (zoom_out_scale != 1) + { + p = p*(double)zoom_out_scale; + } + + + if (moving_what == MOVING_PART) + { + if (overlay_rects[moving_rect].parts[moving_part_name] != p) + { + overlay_rects[moving_rect].parts[moving_part_name] = p; + parent.invalidate_rectangle(rect); + if (event_handler.is_set()) + event_handler(); + } + } + else + { + rectangle original = overlay_rects[moving_rect].rect; + if (moving_what == MOVING_RECT_LEFT) + overlay_rects[moving_rect].rect.left() = std::min(p.x(), overlay_rects[moving_rect].rect.right()); + else if (moving_what == MOVING_RECT_RIGHT) + overlay_rects[moving_rect].rect.right() = std::max(p.x()-1, overlay_rects[moving_rect].rect.left()); + else if (moving_what == MOVING_RECT_TOP) + overlay_rects[moving_rect].rect.top() = std::min(p.y(), overlay_rects[moving_rect].rect.bottom()); + else + overlay_rects[moving_rect].rect.bottom() = std::max(p.y()-1, overlay_rects[moving_rect].rect.top()); + + if (original != overlay_rects[moving_rect].rect) + { + parent.invalidate_rectangle(rect); + if (event_handler.is_set()) + event_handler(); + } + } + } + else + { + moving_overlay = false; + } + } + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + on_wheel_up ( + unsigned long state + ) + { + // disable mouse wheel if the user is drawing a rectangle + if (drawing_rect) + return; + + // if CONTROL is not being held down + if ((state & base_window::CONTROL) == 0) + { + scrollable_region::on_wheel_up(state); + return; + } + + if (rect.contains(lastx,lasty) == false || hidden || !enabled) + return; + + + if (zoom_in_scale < 100 && zoom_out_scale == 1) + { + const point mouse_loc(lastx, lasty); + // the pixel in img that the mouse is over + const point pix_loc = (mouse_loc - total_rect().tl_corner())/zoom_in_scale; + + zoom_in_scale = zoom_in_scale*10/9 + 1; + + set_total_rect_size(img.nc()*zoom_in_scale, img.nr()*zoom_in_scale); + + // make is to the pixel under the mouse doesn't move while we zoom + const point delta = total_rect().tl_corner() - (mouse_loc - pix_loc*zoom_in_scale); + scroll_to_rect(translate_rect(display_rect(), delta)); + } + else if (zoom_out_scale != 1) + { + const point mouse_loc(lastx, lasty); + // the pixel in img that the mouse is over + const point pix_loc = (mouse_loc - total_rect().tl_corner())*zoom_out_scale; + + zoom_out_scale = zoom_out_scale*9/10; + if (zoom_out_scale == 0) + zoom_out_scale = 1; + + set_total_rect_size(img.nc()/zoom_out_scale, img.nr()/zoom_out_scale); + + // make is to the pixel under the mouse doesn't move while we zoom + const point delta = total_rect().tl_corner() - (mouse_loc - pix_loc/zoom_out_scale); + scroll_to_rect(translate_rect(display_rect(), delta)); + } + } + +// ---------------------------------------------------------------------------------------- + + void image_display:: + on_wheel_down ( + unsigned long state + ) + { + // disable mouse wheel if the user is drawing a rectangle + if (drawing_rect) + return; + + // if CONTROL is not being held down + if ((state & base_window::CONTROL) == 0) + { + scrollable_region::on_wheel_down(state); + return; + } + + if (rect.contains(lastx,lasty) == false || hidden || !enabled) + return; + + + if (zoom_in_scale != 1) + { + const point mouse_loc(lastx, lasty); + // the pixel in img that the mouse is over + const point pix_loc = (mouse_loc - total_rect().tl_corner())/zoom_in_scale; + + zoom_in_scale = zoom_in_scale*9/10; + if (zoom_in_scale == 0) + zoom_in_scale = 1; + + set_total_rect_size(img.nc()*zoom_in_scale, img.nr()*zoom_in_scale); + + // make is to the pixel under the mouse doesn't move while we zoom + const point delta = total_rect().tl_corner() - (mouse_loc - pix_loc*zoom_in_scale); + scroll_to_rect(translate_rect(display_rect(), delta)); + } + else if (std::max(img.nr(), img.nc())/zoom_out_scale > 10) + { + const point mouse_loc(lastx, lasty); + // the pixel in img that the mouse is over + const point pix_loc = (mouse_loc - total_rect().tl_corner())*zoom_out_scale; + + zoom_out_scale = zoom_out_scale*10/9 + 1; + + set_total_rect_size(img.nc()/zoom_out_scale, img.nr()/zoom_out_scale); + + // make is to the pixel under the mouse doesn't move while we zoom + const point delta = total_rect().tl_corner() - (mouse_loc - pix_loc/zoom_out_scale); + scroll_to_rect(translate_rect(display_rect(), delta)); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// image_window member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + image_window:: + image_window( + ) : + gui_img(*this), + window_has_closed(false), + have_last_click(false), + mouse_btn(0), + clicked_signaler(this->wm), + tie_input_events(false) + { + + gui_img.set_image_clicked_handler(*this, &image_window::on_image_clicked); + gui_img.disable_overlay_editing(); + // show this window on the screen + show(); + } + +// ---------------------------------------------------------------------------------------- + + image_window:: + ~image_window( + ) + { + // You should always call close_window() in the destructor of window + // objects to ensure that no events will be sent to this window while + // it is being destructed. + close_window(); + } + +// ---------------------------------------------------------------------------------------- + + base_window::on_close_return_code image_window:: + on_window_close( + ) + { + window_has_closed = true; + clicked_signaler.broadcast(); + return base_window::CLOSE_WINDOW; + } + +// ---------------------------------------------------------------------------------------- + + bool image_window:: + get_next_keypress ( + unsigned long& key, + bool& is_printable, + unsigned long& state + ) + { + auto_mutex lock(wm); + while (have_last_keypress == false && !window_has_closed && + (have_last_click == false || !tie_input_events)) + { + clicked_signaler.wait(); + } + + if (window_has_closed) + return false; + + if (have_last_keypress) + { + // Mark that we are taking the key click so the next call to get_next_keypress() + // will have to wait for another click. + have_last_keypress = false; + key = next_key; + is_printable = next_is_printable; + state = next_state; + return true; + } + else + { + key = 0; + is_printable = true; + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + dlib::drawable_window::on_keydown(key,is_printable,state); + + have_last_keypress = true; + next_key = key; + next_is_printable = is_printable; + next_state = state; + clicked_signaler.signal(); + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + tie_events ( + ) + { + auto_mutex lock(wm); + tie_input_events = true; + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + untie_events ( + ) + { + auto_mutex lock(wm); + tie_input_events = false; + } + +// ---------------------------------------------------------------------------------------- + + bool image_window:: + events_tied ( + ) const + { + auto_mutex lock(wm); + return tie_input_events; + } + +// ---------------------------------------------------------------------------------------- + + bool image_window:: + get_next_double_click ( + point& p, + unsigned long& mouse_button + ) + { + p = point(-1,-1); + + auto_mutex lock(wm); + while (have_last_click == false && !window_has_closed && + (have_last_keypress==false || !tie_input_events)) + { + clicked_signaler.wait(); + } + + if (window_has_closed) + return false; + + if (have_last_click) + { + // Mark that we are taking the point click so the next call to + // get_next_double_click() will have to wait for another click. + have_last_click = false; + mouse_button = mouse_btn; + p = last_clicked_point; + return true; + } + else + { + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + on_image_clicked ( + const point& p, + bool is_double_click, + unsigned long btn + ) + { + if (is_double_click) + { + have_last_click = true; + last_clicked_point = p; + mouse_btn = btn; + clicked_signaler.signal(); + } + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + add_overlay ( + const overlay_rect& overlay + ) + { + gui_img.add_overlay(overlay); + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + add_overlay ( + const overlay_line& overlay + ) + { + gui_img.add_overlay(overlay); + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + add_overlay ( + const overlay_circle& overlay + ) + { + gui_img.add_overlay(overlay); + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + add_overlay ( + const std::vector& overlay + ) + { + gui_img.add_overlay(overlay); + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + add_overlay ( + const std::vector& overlay + ) + { + gui_img.add_overlay(overlay); + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + add_overlay ( + const std::vector& overlay + ) + { + gui_img.add_overlay(overlay); + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + clear_overlay ( + ) + { + gui_img.clear_overlay(); + } + +// ---------------------------------------------------------------------------------------- + + void image_window:: + on_window_resized( + ) + { + drawable_window::on_window_resized(); + unsigned long width, height; + get_size(width,height); + gui_img.set_size(width, height); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_WIDGETs_CPP_ + diff --git a/ml/dlib/dlib/gui_widgets/widgets.h b/ml/dlib/dlib/gui_widgets/widgets.h new file mode 100644 index 000000000..305d7dc00 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/widgets.h @@ -0,0 +1,4165 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. + +#ifndef DLIB_WIDGETs_ +#define DLIB_WIDGETs_ + +#include +#include +#include +#include +#include +#include + +#include "../algs.h" +#include "widgets_abstract.h" +#include "drawable.h" +#include "../gui_core.h" +#include "fonts.h" +#include "../timer.h" +#include "base_widgets.h" +#include "../member_function_pointer.h" +#include "../array.h" +#include "../array2d.h" +#include "../sequence.h" +#include "../dir_nav.h" +#include "../queue.h" +#include "style.h" +#include "../string.h" +#include "../misc_api.h" +#include "../any.h" +#include "../image_processing/full_object_detection.h" + +#ifdef _MSC_VER +// This #pragma directive is also located in the algs.h file but for whatever +// reason visual studio 9 just ignores it when it is only there. + +// this is to disable the "'this' : used in base member initializer list" +// warning you get from some of the GUI objects since all the objects +// require that their parent class be passed into their constructor. +// In this case though it is totally safe so it is ok to disable this warning. +#pragma warning(disable : 4355) +#endif + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class label +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class label : public drawable + { + public: + label( + drawable_window& w + ) : + drawable(w), + text_color_(0,0,0) + { + enable_events(); + } + + ~label() + { disable_events(); parent.invalidate_rectangle(rect); } + + void set_text ( + const std::string& text + ); + + void set_text ( + const std::wstring& text + ); + + void set_text ( + const dlib::ustring& text + ); + + const std::string text ( + ) const; + + const std::wstring wtext ( + ) const; + + const dlib::ustring utext ( + ) const; + + void set_text_color ( + const rgb_pixel color + ); + + const rgb_pixel text_color ( + ) const; + + void set_main_font ( + const std::shared_ptr& f + ); + + private: + dlib::ustring text_; + rgb_pixel text_color_; + + + // restricted functions + label(label&); // copy constructor + label& operator=(label&); // assignment operator + + protected: + + void draw ( + const canvas& c + ) const; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class toggle_button +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class toggle_button : public button_action + { + /*! + INITIAL VALUE + - checked == false + + CONVENTION + - is_checked() == checked + !*/ + + public: + + toggle_button( + drawable_window& w + ) : + button_action(w), + btn_tooltip(w), + checked(false) + { + style.reset(new toggle_button_style_default()); + enable_events(); + } + + ~toggle_button() { disable_events(); parent.invalidate_rectangle(rect); } + + void set_name ( + const std::string& name + ); + + void set_name ( + const std::wstring& name + ); + + void set_name ( + const dlib::ustring& name + ); + + void set_size ( + unsigned long width_, + unsigned long height_ + ); + + void set_tooltip_text ( + const std::string& text + ); + + void set_tooltip_text ( + const std::wstring& text + ); + + void set_tooltip_text ( + const ustring& text + ); + + const std::string tooltip_text ( + ) const; + + const std::wstring tooltip_wtext ( + ) const; + + const dlib::ustring tooltip_utext ( + ) const; + + bool is_checked ( + ) const; + + const std::string name ( + ) const; + + const std::wstring wname ( + ) const; + + const dlib::ustring uname ( + ) const; + + void set_checked ( + ); + + void set_unchecked ( + ); + + void show ( + ); + + void hide ( + ); + + void enable ( + ); + + void disable ( + ); + + void set_main_font ( + const std::shared_ptr& f + ); + + void set_pos ( + long x, + long y + ); + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ) + { + auto_mutex M(m); + style.reset(new style_type(style_)); + rect = move_rect(style->get_min_size(name_,*mfont), rect.left(), rect.top()); + parent.invalidate_rectangle(rect); + } + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler_)() + ) + { + auto_mutex M(m); + event_handler = make_mfp(object,event_handler_); + event_handler_self.clear(); + } + + void set_click_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + event_handler = event_handler_; + event_handler_self.clear(); + } + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler_)(toggle_button&) + ) + { + auto_mutex M(m); + event_handler_self = make_mfp(object,event_handler_); + event_handler.clear(); + } + + void set_sourced_click_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + event_handler_self = event_handler_; + event_handler.clear(); + } + + private: + + // restricted functions + toggle_button(toggle_button&); // copy constructor + toggle_button& operator=(toggle_button&); // assignment operator + + dlib::ustring name_; + tooltip btn_tooltip; + bool checked; + + any_function event_handler; + any_function event_handler_self; + + std::unique_ptr style; + + protected: + + void draw ( + const canvas& c + ) const { style->draw_toggle_button(c,rect,enabled,*mfont,lastx,lasty,name_,is_depressed(),checked); } + + void on_button_up ( + bool mouse_over + ); + + void on_mouse_over ( + ){ if (style->redraw_on_mouse_over()) parent.invalidate_rectangle(rect); } + + void on_mouse_not_over ( + ){ if (style->redraw_on_mouse_over()) parent.invalidate_rectangle(rect); } + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class text_field +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class text_field : public drawable + { + /*! + INITIAL VALUE + text_color_ == rgb_pixel(0,0,0) + bg_color_ == rgb_pixel(255,255,255) + cursor_pos == 0 + text_width == 0 + text_ == "" + has_focus == false + cursor_visible == false + recent_movement == false + highlight_start == 0 + highlight_end == -1 + shift_pos == -1 + text_pos == 0 + + CONVENTION + - cursor_pos == the position of the cursor in the string text_. The + cursor appears before the letter text_[cursor_pos] + - cursor_x == the x coordinate of the cursor relative to the left side + of rect. i.e. the number of pixels that separate the cursor from the + left side of the text_field. + - has_focus == true if this text field has keyboard input focus + - cursor_visible == true if the cursor should be painted + - text_ == text() + - text_pos == the index of the first letter in text_ that appears in + this text field. + - text_width == the width of text_[text_pos] though text_[text.size()-1] + + - if (has_focus && the user has recently moved the cursor) then + - recent_movement == true + - else + - recent_movement == false + + - if (highlight_start <= highlight_end) then + - text[highlight_start] though text[highlight_end] should be + highlighted + + - if (shift_pos != -1) then + - has_focus == true + - the shift key is being held down or the left mouse button is + being held down. + - shift_pos == the position of the cursor when the shift or mouse key + was first pressed. + + - text_color() == text_color_ + - background_color() == bg_color_ + !*/ + + public: + text_field( + drawable_window& w + ) : + drawable(w,MOUSE_CLICK | KEYBOARD_EVENTS | MOUSE_MOVE | STRING_PUT), + text_color_(0,0,0), + bg_color_(255,255,255), + text_width(0), + text_pos(0), + recent_movement(false), + has_focus(false), + cursor_visible(false), + cursor_pos(0), + highlight_start(0), + highlight_end(-1), + shift_pos(-1), + t(*this,&text_field::timer_action), + right_click_menu(w) + { + style.reset(new text_field_style_default()); + rect.set_bottom(mfont->height()+ (style->get_padding(*mfont))*2); + rect.set_right((style->get_padding(*mfont))*2); + cursor_x = style->get_padding(*mfont); + + right_click_menu.menu().add_menu_item(menu_item_text("Cut",*this,&text_field::on_cut,'t')); + right_click_menu.menu().add_menu_item(menu_item_text("Copy",*this,&text_field::on_copy,'C')); + right_click_menu.menu().add_menu_item(menu_item_text("Paste",*this,&text_field::on_paste,'P')); + right_click_menu.menu().add_menu_item(menu_item_text("Delete",*this,&text_field::on_delete_selected,'D')); + right_click_menu.menu().add_menu_item(menu_item_separator()); + right_click_menu.menu().add_menu_item(menu_item_text("Select All",*this,&text_field::on_select_all,'A')); + + right_click_menu.set_rect(get_text_rect()); + enable_events(); + + t.set_delay_time(500); + } + + ~text_field ( + ) + { + disable_events(); + parent.invalidate_rectangle(rect); + t.stop_and_wait(); + } + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ) + { + auto_mutex M(m); + style.reset(new style_type(style_)); + // call this just so that this widget redraws itself with the new style + set_main_font(mfont); + } + + void set_text ( + const std::string& text_ + ); + + void set_text ( + const std::wstring& text_ + ); + + void give_input_focus ( + ); + + bool has_input_focus ( + ) const; + + void select_all_text ( + ); + + void set_text ( + const dlib::ustring& text_ + ); + + const std::string text ( + ) const; + + const std::wstring wtext ( + ) const; + + const dlib::ustring utext ( + ) const; + + void set_text_color ( + const rgb_pixel color + ); + + const rgb_pixel text_color ( + ) const; + + void set_background_color ( + const rgb_pixel color + ); + + const rgb_pixel background_color ( + ) const; + + void set_width ( + unsigned long width + ); + + void set_pos ( + long x, + long y + ); + + void set_main_font ( + const std::shared_ptr& f + ); + + int next_free_user_event_number ( + ) const + { + return drawable::next_free_user_event_number()+1; + } + + void disable ( + ); + + void enable ( + ); + + void hide ( + ); + + void show ( + ); + + template < + typename T + > + void set_text_modified_handler ( + T& object, + void (T::*event_handler)() + ) + { + auto_mutex M(m); + text_modified_handler = make_mfp(object,event_handler); + } + + template < + typename T + > + void set_enter_key_handler ( + T& object, + void (T::*event_handler)() + ) + { + auto_mutex M(m); + enter_key_handler = make_mfp(object,event_handler); + } + + void set_text_modified_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + text_modified_handler = event_handler; + } + + void set_enter_key_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + enter_key_handler = event_handler; + } + + template < + typename T + > + void set_focus_lost_handler ( + T& object, + void (T::*event_handler)() + ) + { + auto_mutex M(m); + focus_lost_handler = make_mfp(object,event_handler); + } + + void set_focus_lost_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + focus_lost_handler = event_handler; + } + + private: + + void on_cut ( + ); + + void on_copy ( + ); + + void on_paste ( + ); + + void on_select_all ( + ); + + void on_delete_selected ( + ); + + void on_text_is_selected ( + ); + + void on_no_text_selected ( + ); + + void on_user_event ( + int num + ) + { + // ignore this user event if it isn't for us + if (num != drawable::next_free_user_event_number()) + return; + + if (recent_movement == false) + { + cursor_visible = !cursor_visible; + parent.invalidate_rectangle(rect); + } + else + { + if (cursor_visible == false) + { + cursor_visible = true; + parent.invalidate_rectangle(rect); + } + recent_movement = false; + } + } + + void timer_action ( + ) { parent.trigger_user_event(this,drawable::next_free_user_event_number()); } + /*! + ensures + - flips the state of cursor_visible + !*/ + + void move_cursor ( + unsigned long pos + ); + /*! + requires + - pos <= text_.size() + ensures + - moves the cursor to the position given by pos and moves the text + in the text box if necessary + - if the position changes then the parent window will be updated + !*/ + + rectangle get_text_rect ( + ) const; + /*! + ensures + - returns the rectangle that should contain the text in this widget + !*/ + + dlib::ustring text_; + rgb_pixel text_color_; + rgb_pixel bg_color_; + + unsigned long text_width; + unsigned long text_pos; + + + bool recent_movement; + bool has_focus; + bool cursor_visible; + long cursor_pos; + unsigned long cursor_x; + + // this tells you what part of the text is highlighted + long highlight_start; + long highlight_end; + long shift_pos; + any_function text_modified_handler; + any_function enter_key_handler; + any_function focus_lost_handler; + + std::unique_ptr style; + + timer t; + + popup_menu_region right_click_menu; + + // restricted functions + text_field(text_field&); // copy constructor + text_field& operator=(text_field&); // assignment operator + + + protected: + + void draw ( + const canvas& c + ) const; + + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + void on_string_put ( + const std::wstring &str + ); + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class text_box +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class text_box : public scrollable_region + { + /*! + INITIAL VALUE + text_color_ == rgb_pixel(0,0,0) + bg_color_ == rgb_pixel(255,255,255) + cursor_pos == 0 + text_ == "" + has_focus == false + cursor_visible == false + recent_movement == false + highlight_start == 0 + highlight_end == -1 + shift_pos == -1 + + CONVENTION + - cursor_pos == the position of the cursor in the string text_. The + cursor appears before the letter text_[cursor_pos] + - cursor_rect == The rectangle that should be drawn for the cursor. + The position is relative to total_rect(). + - has_focus == true if this text field has keyboard input focus + - cursor_visible == true if the cursor should be painted + - text_ == text() + + - if (has_focus && the user has recently moved the cursor) then + - recent_movement == true + - else + - recent_movement == false + + - if (highlight_start <= highlight_end) then + - text[highlight_start] though text[highlight_end] should be + highlighted + + - if (shift_pos != -1) then + - has_focus == true + - the shift key is being held down or the left mouse button is + being held down. + - shift_pos == the position of the cursor when the shift or mouse key + was first pressed. + + - text_color() == text_color_ + - background_color() == bg_color_ + !*/ + + public: + text_box( + drawable_window& w + ) : + scrollable_region(w,MOUSE_CLICK | KEYBOARD_EVENTS | MOUSE_MOVE | STRING_PUT), + text_color_(0,0,0), + bg_color_(255,255,255), + recent_movement(false), + has_focus(false), + cursor_visible(false), + cursor_pos(0), + highlight_start(0), + highlight_end(-1), + shift_pos(-1), + t(*this,&text_box::timer_action), + right_click_menu(w) + { + style.reset(new text_box_style_default()); + + const long padding = static_cast(style->get_padding(*mfont)); + cursor_rect = mfont->compute_cursor_rect(rectangle(padding,padding,1000000,1000000), text_, 0); + + adjust_total_rect(); + + set_vertical_mouse_wheel_scroll_increment(mfont->height()); + set_horizontal_mouse_wheel_scroll_increment(mfont->height()); + + right_click_menu.menu().add_menu_item(menu_item_text("Cut",*this,&text_box::on_cut,'t')); + right_click_menu.menu().add_menu_item(menu_item_text("Copy",*this,&text_box::on_copy,'C')); + right_click_menu.menu().add_menu_item(menu_item_text("Paste",*this,&text_box::on_paste,'P')); + right_click_menu.menu().add_menu_item(menu_item_text("Delete",*this,&text_box::on_delete_selected,'D')); + right_click_menu.menu().add_menu_item(menu_item_separator()); + right_click_menu.menu().add_menu_item(menu_item_text("Select All",*this,&text_box::on_select_all,'A')); + + right_click_menu.set_rect(get_text_rect()); + + set_size(100,100); + + enable_events(); + + t.set_delay_time(500); + } + + ~text_box ( + ) + { + disable_events(); + parent.invalidate_rectangle(rect); + t.stop_and_wait(); + } + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ) + { + auto_mutex M(m); + style.reset(new style_type(style_)); + + scrollable_region::set_style(style_.get_scrollable_region_style()); + // call this just so that this widget redraws itself with the new style + set_main_font(mfont); + } + + void set_text ( + const std::string& text_ + ); + + void set_text ( + const std::wstring& text_ + ); + + void set_text ( + const dlib::ustring& text_ + ); + + const std::string text ( + ) const; + + const std::wstring wtext ( + ) const; + + const dlib::ustring utext ( + ) const; + + void set_text_color ( + const rgb_pixel color + ); + + const rgb_pixel text_color ( + ) const; + + void set_background_color ( + const rgb_pixel color + ); + + const rgb_pixel background_color ( + ) const; + + void set_size ( + unsigned long width, + unsigned long height + ); + + void set_pos ( + long x, + long y + ); + + void set_main_font ( + const std::shared_ptr& f + ); + + int next_free_user_event_number ( + ) const + { + return scrollable_region::next_free_user_event_number()+1; + } + + void disable ( + ); + + void enable ( + ); + + void hide ( + ); + + void show ( + ); + + template < + typename T + > + void set_text_modified_handler ( + T& object, + void (T::*event_handler)() + ) + { + auto_mutex M(m); + text_modified_handler = make_mfp(object,event_handler); + } + + void set_text_modified_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + text_modified_handler = event_handler; + } + + template < + typename T + > + void set_enter_key_handler ( + T& object, + void (T::*event_handler)() + ) + { + auto_mutex M(m); + enter_key_handler = make_mfp(object,event_handler); + } + + void set_enter_key_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + enter_key_handler = event_handler; + } + + template < + typename T + > + void set_focus_lost_handler ( + T& object, + void (T::*event_handler)() + ) + { + auto_mutex M(m); + focus_lost_handler = make_mfp(object,event_handler); + } + + void set_focus_lost_handler ( + const any_function& event_handler + ) + { + auto_mutex M(m); + focus_lost_handler = event_handler; + } + + private: + + void on_cut ( + ); + + void on_copy ( + ); + + void on_paste ( + ); + + void on_select_all ( + ); + + void on_delete_selected ( + ); + + void on_text_is_selected ( + ); + + void on_no_text_selected ( + ); + + void on_user_event ( + int num + ) + { + // ignore this user event if it isn't for us + if (num != scrollable_region::next_free_user_event_number()) + return; + + if (recent_movement == false) + { + cursor_visible = !cursor_visible; + parent.invalidate_rectangle(rect); + } + else + { + if (cursor_visible == false) + { + cursor_visible = true; + parent.invalidate_rectangle(rect); + } + recent_movement = false; + } + } + + // The reason for using user actions here rather than just having the timer just call + // what it needs directly is to avoid a potential deadlock during destruction of this widget. + void timer_action ( + ) { parent.trigger_user_event(this,scrollable_region::next_free_user_event_number()); } + /*! + ensures + - flips the state of cursor_visible + !*/ + + void move_cursor ( + unsigned long pos + ); + /*! + requires + - pos <= text_.size() + ensures + - moves the cursor to the position given by pos and moves the text + in the text box if necessary + - if the position changes then the parent window will be updated + !*/ + + rectangle get_text_rect ( + ) const; + /*! + ensures + - returns the rectangle that should contain the text in this widget + !*/ + + void adjust_total_rect ( + ); + /*! + ensures + - adjusts total_rect() so that it is big enough to contain the text + currently in this object. + !*/ + + dlib::ustring text_; + rgb_pixel text_color_; + rgb_pixel bg_color_; + + + + bool recent_movement; + bool has_focus; + bool cursor_visible; + long cursor_pos; + rectangle cursor_rect; + + // this tells you what part of the text is highlighted + long highlight_start; + long highlight_end; + long shift_pos; + any_function text_modified_handler; + any_function enter_key_handler; + any_function focus_lost_handler; + + std::unique_ptr style; + + timer t; + + popup_menu_region right_click_menu; + + // restricted functions + text_box(text_box&); // copy constructor + text_box& operator=(text_box&); // assignment operator + + + protected: + + void draw ( + const canvas& c + ) const; + + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + void on_string_put ( + const std::wstring &str + ); + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class check_box +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class check_box : public toggle_button + { + public: + check_box( + drawable_window& w + ) : toggle_button(w) + { + set_style(toggle_button_style_check_box()); + } + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class radio_button +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class radio_button : public toggle_button + { + public: + radio_button ( + drawable_window& w + ) : toggle_button(w) + { + set_style(toggle_button_style_radio_button()); + } + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class tabbed_display +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class tabbed_display : public drawable + { + /*! + INITIAL VALUE + - tabs.size() == 0 + - selected_tab_ == 0 + + CONVENTION + - number_of_tabs() == tabs.size() + - tab_name(idx) == tabs[idx] + - if (tabs.size() > 0) then + - selected_tab_ == the index of the tab that is currently selected + + - for all valid i: + - tabs[i].width == mfont->compute_size(tabs[i].name) + - tabs[i].rect == the rectangle that defines where this tab is + - if (tabs[i].group != 0) then + - tabs[i].group == a pointer to the widget_group for this tab. + + - left_pad == the amount of padding in a tab to the left of the name string. + - right_pad == the amount of padding in a tab to the right of the name string. + - top_pad == the amount of padding in a tab to the top of the name string. + - bottom_pad == the amount of padding in a tab to the bottom of the name string. + + - if (event_handler.is_set()) then + - event_handler() is what is called to process click events + on this object. + !*/ + + public: + + tabbed_display( + drawable_window& w + ); + + virtual ~tabbed_display( + ); + + void set_size ( + unsigned long width, + unsigned long height + ); + + void set_number_of_tabs ( + unsigned long num + ); + + unsigned long selected_tab ( + ) const; + + unsigned long number_of_tabs ( + ) const; + + const std::string tab_name ( + unsigned long idx + ) const; + + const std::wstring tab_wname ( + unsigned long idx + ) const; + + const dlib::ustring& tab_uname ( + unsigned long idx + ) const; + + void set_tab_name ( + unsigned long idx, + const std::string& new_name + ); + + void set_tab_name ( + unsigned long idx, + const std::wstring& new_name + ); + + void set_tab_name ( + unsigned long idx, + const dlib::ustring& new_name + ); + + void set_pos ( + long x, + long y + ); + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*eh)(unsigned long new_idx,unsigned long old_idx) + ) + { + auto_mutex M(m); + event_handler = make_mfp(object,eh); + } + + void set_click_handler ( + const any_function& eh + ) + { + auto_mutex M(m); + event_handler = eh; + } + + void set_tab_group ( + unsigned long idx, + widget_group& group + ); + + void show ( + ); + + void hide ( + ); + + void enable ( + ); + + void disable ( + ); + + void set_main_font ( + const std::shared_ptr& f + ); + + void fit_to_contents ( + ); + + protected: + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void draw ( + const canvas& c + ) const; + + private: + void recompute_tabs ( + ); + /*! + ensures + - recomputes the rectangles for all the tabs and makes this object + wider if needed + !*/ + + void draw_tab ( + const rectangle& tab, + const canvas& c + ) const; + /*! + ensures + - draws the outline of a tab as given by the rectangle onto c + !*/ + + struct tab_data + { + tab_data() : width(0), group(0) {} + + dlib::ustring name; + unsigned long width; + rectangle rect; + widget_group* group; + }; + + unsigned long selected_tab_; + + array tabs; + + const long left_pad; + const long right_pad; + const long top_pad; + const long bottom_pad; + + any_function event_handler; + + // restricted functions + tabbed_display(tabbed_display&); // copy constructor + tabbed_display& operator=(tabbed_display&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class named_rectangle +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class named_rectangle : public drawable + { + /*! + INITIAL VALUE + name == "" + + CONVENTION + name_ == name() + !*/ + + public: + + named_rectangle( + drawable_window& w + ); + + virtual ~named_rectangle( + ); + + void set_size ( + unsigned long width, + unsigned long height + ); + + void set_name ( + const std::string& name + ); + + void set_name ( + const std::wstring& name + ); + + void set_name ( + const dlib::ustring& name + ); + + const std::string name ( + ) const; + + const std::wstring wname ( + ) const; + + const dlib::ustring uname ( + ) const; + + void wrap_around ( + const rectangle& rect + ); + + void set_main_font ( + const std::shared_ptr& f + ); + + protected: + + void draw ( + const canvas& c + ) const; + + private: + + void make_name_fit_in_rect ( + ); + + dlib::ustring name_; + unsigned long name_width; + unsigned long name_height; + + // restricted functions + named_rectangle(named_rectangle&); // copy constructor + named_rectangle& operator=(named_rectangle&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class mouse_tracker +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class mouse_tracker : public draggable + { + + public: + + mouse_tracker( + drawable_window& w + ); + + ~mouse_tracker( + ); + + void show ( + ); + + void hide ( + ); + + void enable ( + ); + + void disable ( + ); + + void set_pos ( + long x, + long y + ); + + void set_main_font ( + const std::shared_ptr& f + ); + + protected: + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_drag ( + ); + + void draw ( + const canvas& c + ) const; + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + + private: + + const long offset; + named_rectangle nr; + label x_label; + label y_label; + std::ostringstream sout; + + long click_x, click_y; + + // restricted functions + mouse_tracker(mouse_tracker&); // copy constructor + mouse_tracker& operator=(mouse_tracker&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // function message_box() +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace message_box_helper + { + class box_win : public drawable_window + { + void initialize ( + ); + public: + box_win ( + const std::string& title_, + const std::string& message_ + ); + + box_win ( + const std::wstring& title_, + const std::wstring& message_ + ); + + box_win ( + const dlib::ustring& title_, + const dlib::ustring& message_ + ); + + ~box_win ( + ); + + void set_click_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(wm); + event_handler = event_handler_; + } + + private: + + static void deleter_thread ( + void* param + ); + + void on_click ( + ); + + on_close_return_code on_window_close ( + ); + + const std::wstring title; + const std::wstring message; + label msg; + button btn_ok; + + any_function event_handler; + }; + + class blocking_box_win : public drawable_window + { + void initialize ( + ); + + public: + blocking_box_win ( + const std::string& title_, + const std::string& message_ + ); + + blocking_box_win ( + const std::wstring& title_, + const std::wstring& message_ + ); + + blocking_box_win ( + const dlib::ustring& title_, + const dlib::ustring& message_ + ); + + ~blocking_box_win ( + ); + + private: + + void on_click ( + ); + + const std::wstring title; + const std::wstring message; + label msg; + button btn_ok; + }; + } + + template < + typename T + > + void message_box ( + const std::string& title, + const std::string& message, + T& object, + void (T::*event_handler)() + ) + { + using namespace message_box_helper; + box_win* win = new box_win(title,message); + win->set_click_handler(make_mfp(object,event_handler)); + } + + inline void message_box ( + const std::string& title, + const std::string& message, + const any_function& event_handler + ) + { + using namespace message_box_helper; + box_win* win = new box_win(title,message); + win->set_click_handler(event_handler); + } + + inline void message_box ( + const std::string& title, + const std::string& message + ) + { + using namespace message_box_helper; + new box_win(title,message); + } + + inline void message_box_blocking ( + const std::string& title, + const std::string& message + ) + { + using namespace message_box_helper; + blocking_box_win w(title,message); + w.wait_until_closed(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class list_box +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace list_box_helper{ + template + class list_box : public scrollable_region, + public enumerable + { + /*! + INITIAL VALUE + - ms_enabled == false + - items.size() == 0 + - last_selected = 0 + + CONVENTION + - size() == items.size() + - (*this)[i] == items[i].name + - is_selected(i) == items[i].is_selected + - ms_enabled == multiple_select_enabled() + + - items[i].width == the width of items[i].name as given by font::compute_size() + - items[i].height == the height of items[i].name as given by font::compute_size() + + - last_selected == the last item the user selected + !*/ + + public: + + list_box( + drawable_window& w + ); + + ~list_box( + ); + + bool is_selected ( + unsigned long index + ) const; + + void select ( + unsigned long index + ); + + void unselect ( + unsigned long index + ); + + template < + typename style_type + > + void set_style ( + const style_type& style_ + ) + { + auto_mutex M(m); + style.reset(new style_type(style_)); + scrollable_region::set_style(style_.get_scrollable_region_style()); + parent.invalidate_rectangle(rect); + } + + template + void get_selected ( + T& list + ) const + { + auto_mutex M(m); + list.clear(); + for (unsigned long i = 0; i < items.size(); ++i) + { + if (items[i].is_selected) + { + unsigned long idx = i; + list.enqueue(idx); + } + } + } + + template + void load ( + const T& list + ) + { + auto_mutex M(m); + items.clear(); + unsigned long i = 0; + items.set_max_size(list.size()); + items.set_size(list.size()); + list.reset(); + unsigned long max_width = 0; + unsigned long total_height = 0; + while (list.move_next()) + { + items[i].is_selected = false; + items[i].name = list.element(); + mfont->compute_size(items[i].name,items[i].width, items[i].height); + + if (items[i].width > max_width) + max_width = items[i].width; + total_height += items[i].height; + + ++i; + } + set_total_rect_size(max_width, total_height); + + parent.invalidate_rectangle(rect); + last_selected = 0; + } + + const S& operator[] ( + unsigned long index + ) const; + + bool multiple_select_enabled ( + ) const; + + void enable_multiple_select ( + ); + + void disable_multiple_select ( + ); + + template < + typename T + > + void set_double_click_handler ( + T& object, + void (T::*eh)(unsigned long index) + ) { auto_mutex M(m); event_handler = make_mfp(object,eh); } + + void set_double_click_handler ( + const any_function& eh + ) { auto_mutex M(m); event_handler = eh; } + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*eh)(unsigned long index) + ) { auto_mutex M(m); single_click_event_handler = make_mfp(object,eh); } + + void set_click_handler ( + const any_function& eh + ) { auto_mutex M(m); single_click_event_handler = eh; } + + bool at_start ( + ) const; + + void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const S& element ( + ) const; + + const S& element ( + ); + + bool move_next ( + ) const; + + size_t size ( + ) const; + + unsigned long get_selected ( + ) const; + + void set_main_font ( + const std::shared_ptr& f + ); + + private: + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void draw ( + const canvas& c + ) const; + + template + struct data + { + SS name; + bool is_selected; + unsigned long width; + unsigned long height; + }; + + bool ms_enabled; + array > items; + any_function event_handler; + any_function single_click_event_handler; + unsigned long last_selected; + + std::unique_ptr style; + + // restricted functions + list_box(list_box&); // copy constructor + list_box& operator=(list_box&); // assignment operator + }; + } + typedef list_box_helper::list_box list_box; + typedef list_box_helper::list_box wlist_box; + typedef list_box_helper::list_box ulist_box; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // function open_file_box() +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace open_file_box_helper + { + class box_win : public drawable_window + { + public: + box_win ( + const std::string& title, + bool has_text_field = false + ); + + ~box_win ( + ); + + void set_click_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(wm); + event_handler = event_handler_; + } + + private: + + void set_sizes( + ); + + void on_window_resized ( + ); + + void deleter_thread ( + ); + + void enter_folder ( + const std::string& folder_name + ); + + void on_dirs_click ( + unsigned long idx + ); + + void on_files_click ( + unsigned long idx + ); + + void on_files_double_click ( + unsigned long + ); + + void on_cancel_click ( + ); + + void on_open_click ( + ); + + void on_path_button_click ( + toggle_button& btn + ); + + bool set_dir ( + const std::string& dir + ); + + void on_root_click ( + ); + + on_close_return_code on_window_close ( + ); + + label lbl_dirs; + label lbl_files; + label lbl_file_name; + list_box lb_dirs; + list_box lb_files; + button btn_ok; + button btn_cancel; + toggle_button btn_root; + text_field tf_file_name; + std::string path; + std::string prefix; + int cur_dir; + + any_function event_handler; + sequence >::kernel_2a_c sob; + }; + } + + template < + typename T + > + void open_file_box ( + T& object, + void (T::*event_handler)(const std::string&) + ) + { + using namespace open_file_box_helper; + box_win* win = new box_win("Open File",true); + win->set_click_handler(make_mfp(object,event_handler)); + } + + inline void open_file_box ( + const any_function& event_handler + ) + { + using namespace open_file_box_helper; + box_win* win = new box_win("Open File",true); + win->set_click_handler(event_handler); + } + + template < + typename T + > + void open_existing_file_box ( + T& object, + void (T::*event_handler)(const std::string&) + ) + { + using namespace open_file_box_helper; + box_win* win = new box_win("Open File"); + win->set_click_handler(make_mfp(object,event_handler)); + } + + inline void open_existing_file_box ( + const any_function& event_handler + ) + { + using namespace open_file_box_helper; + box_win* win = new box_win("Open File"); + win->set_click_handler(event_handler); + } + + template < + typename T + > + void save_file_box ( + T& object, + void (T::*event_handler)(const std::string&) + ) + { + using namespace open_file_box_helper; + box_win* win = new box_win("Save File",true); + win->set_click_handler(make_mfp(object,event_handler)); + } + + inline void save_file_box ( + const any_function& event_handler + ) + { + using namespace open_file_box_helper; + box_win* win = new box_win("Save File",true); + win->set_click_handler(event_handler); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class menu_bar +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class menu_bar : public drawable + { + /*! + INITIAL VALUE + - menus.size() == 0 + - open_menu == 0 + + CONVENTION + - size() == menus.size() + - all menu data is stored in menus + - menus[x].name == the name of the xth menu + - if (menus[x].underline_pos != std::string::npos) then + - menus[x].underline_pos == the position of the character in the + menu name that should be underlined + - menus[x].underline_p1 != menus[x].underline_p2 + and these two points define the underline bar + - else + - menus[x].underline_p1 == menus[x].underline_p2 + - menus[x].menu == menu(x) + - menus[x].rect == the rectangle in which menus[x].name is drawn + - menus[x].bgrect == the rectangle for the xth menu button + + - if (there is an open menu on the screen) then + - open_menu == the index of the open menu from menus + - else + - open_menu == menus.size() + !*/ + + public: + menu_bar( + drawable_window& w + ); + + ~menu_bar(); + + // this function does nothing + void set_pos(long,long){} + + void set_main_font ( + const std::shared_ptr& f + ); + + void set_number_of_menus ( + unsigned long num + ); + + unsigned long number_of_menus ( + ) const; + + void set_menu_name ( + unsigned long idx, + const std::string name, + char underline_ch = '\0' + ); + + void set_menu_name ( + unsigned long idx, + const std::wstring name, + char underline_ch = '\0' + ); + + void set_menu_name ( + unsigned long idx, + const dlib::ustring name, + char underline_ch = '\0' + ); + + const std::string menu_name ( + unsigned long idx + ) const; + + const std::wstring menu_wname ( + unsigned long idx + ) const; + + const dlib::ustring menu_uname ( + unsigned long idx + ) const; + + popup_menu& menu ( + unsigned long idx + ); + + const popup_menu& menu ( + unsigned long idx + ) const; + + protected: + + void on_window_resized ( + ); + + void draw ( + const canvas& c + ) const; + + void on_window_moved ( + ); + + void on_focus_lost ( + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long , + long x, + long y, + bool + ); + + void on_mouse_move ( + unsigned long , + long x, + long y + ); + + void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + private: + + void show_menu ( + unsigned long i + ); + + void hide_menu ( + ); + + void on_popup_hide ( + ); + + void compute_menu_geometry ( + ); + + void adjust_position ( + ); + + struct menu_data + { + menu_data():underline_pos(dlib::ustring::npos){} + + dlib::ustring name; + dlib::ustring::size_type underline_pos; + popup_menu menu; + rectangle rect; + rectangle bgrect; + point underline_p1; + point underline_p2; + }; + + array menus; + unsigned long open_menu; + + // restricted functions + menu_bar(menu_bar&); // copy constructor + menu_bar& operator=(menu_bar&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class directed_graph_drawer +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class directed_graph_drawer : public zoomable_region + { + /*! + INITIAL VALUE + - edge_selected == false + - mouse_drag == false + - selected_node == 0 + - graph_.number_of_nodes() == 0 + - external_graph.number_of_nodes() == 0 + - radius == 25 + - last_mouse_click_in_display == false + + CONVENTION + - radius == the radius of the nodes when they aren't zoomed + - external_graph and graph_ have the same graph structure + - external_graph == graph() + - external_graph.node(i) == graph_node(i) + + - if (one of the nodes is selected) then + - selected_node < graph_.number_of_nodes() + - graph_.node(selected_node) == the selected node + - else + - selected_node == graph_.number_of_nodes() + + - if (the user is dragging a node with the mouse) then + - mouse_drag == true + - drag_offset == the vector from the mouse position to the + center of the node + - else + - mouse_drag == false + + - if (the user has selected an edge) then + - edge_selected == true + - the parent node is graph_.node(selected_edge_parent) + - the child node is graph_.node(selected_edge_parent) + - else + - edge_selected == false + + - for all valid i: + - graph_.node(i).data.p == the center of the node in graph space + - graph_.node(i).data.name == node_label(i) + - graph_.node(i).data.color == node_color(i) + - graph_.node(i).data.str_rect == a rectangle sized to contain graph_.node(i).data.name + + - if (the last mouse click in our parent window as in our display_rect_ ) then + - last_mouse_click_in_display == true + - else + - last_mouse_click_in_display == false + !*/ + + public: + directed_graph_drawer ( + drawable_window& w + ) : + zoomable_region(w,MOUSE_CLICK | MOUSE_WHEEL | KEYBOARD_EVENTS), + radius(25), + edge_selected(false), + last_mouse_click_in_display(false) + { + mouse_drag = false; + selected_node = 0; + + // Whenever you make your own drawable (or inherit from draggable or button_action) + // you have to remember to call this function to enable the events. The idea + // here is that you can perform whatever setup you need to do to get your + // object into a valid state without needing to worry about event handlers + // triggering before you are ready. + enable_events(); + } + + ~directed_graph_drawer ( + ) + { + // Disable all further events for this drawable object. We have to do this + // because we don't want draw() events coming to this object while or after + // it has been destructed. + disable_events(); + + // Tell the parent window to redraw its area that previously contained this + // drawable object. + parent.invalidate_rectangle(rect); + } + + void clear_graph ( + ) + { + auto_mutex M(m); + graph_.clear(); + external_graph.clear(); + parent.invalidate_rectangle(display_rect()); + } + + const typename graph_type::node_type& graph_node ( + unsigned long i + ) const + { + DLIB_ASSERT ( i < number_of_nodes() , + "\tgraph_type::node_type& directed_graph_drawer::graph_node(i)" + << "\n\ti: " << i + << "\n\tnumber_of_nodes(): " << number_of_nodes() + ); + return external_graph.node(i); + } + + typename graph_type::node_type& graph_node ( + unsigned long i + ) + { + DLIB_ASSERT ( i < number_of_nodes() , + "\tgraph_type::node_type& directed_graph_drawer::graph_node(i)" + << "\n\ti: " << i + << "\n\tnumber_of_nodes(): " << number_of_nodes() + ); + return external_graph.node(i); + } + + const graph_type& graph ( + ) const + { + return external_graph; + } + + void save_graph ( + std::ostream& out + ) + { + auto_mutex M(m); + serialize(external_graph, out); + serialize(graph_, out); + parent.invalidate_rectangle(display_rect()); + } + + void load_graph ( + std::istream& in + ) + { + auto_mutex M(m); + deserialize(external_graph, in); + deserialize(graph_, in); + parent.invalidate_rectangle(display_rect()); + } + + unsigned long number_of_nodes ( + ) const + { + auto_mutex M(m); + return graph_.number_of_nodes(); + } + + void set_node_label ( + unsigned long i, + const std::string& label + ) + { + set_node_label(i, convert_mbstring_to_wstring(label)); + } + + void set_node_label ( + unsigned long i, + const std::wstring& label + ) + { + set_node_label(i, convert_wstring_to_utf32(label)); + } + + void set_node_label ( + unsigned long i, + const dlib::ustring& label + ) + { + auto_mutex M(m); + DLIB_ASSERT ( i < number_of_nodes() , + "\tvoid directed_graph_drawer::set_node_label(i,label)" + << "\n\ti: " << i + << "\n\tlabel: " << narrow(label) + << "\n\tnumber_of_nodes(): " << number_of_nodes() + ); + graph_.node(i).data.name = label.c_str(); + unsigned long width, height; + mfont->compute_size(label,width,height); + graph_.node(i).data.str_rect = rectangle(width,height); + parent.invalidate_rectangle(display_rect()); + } + + void set_node_color ( + unsigned long i, + rgb_pixel color + ) + { + auto_mutex M(m); + DLIB_ASSERT ( i < number_of_nodes() , + "\tvoid directed_graph_drawer::set_node_color(i,label)" + << "\n\ti: " << i + << "\n\tnumber_of_nodes(): " << number_of_nodes() + ); + graph_.node(i).data.color = color; + parent.invalidate_rectangle(display_rect()); + } + + rgb_pixel node_color ( + unsigned long i + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( i < number_of_nodes() , + "\trgb_pixel directed_graph_drawer::node_color(i)" + << "\n\ti: " << i + << "\n\tnumber_of_nodes(): " << number_of_nodes() + ); + return graph_.node(i).data.color; + } + + const std::string node_label ( + unsigned long i + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( i < number_of_nodes() , + "\tconst std::ustring directed_graph_drawer::node_label(i)" + << "\n\ti: " << i + << "\n\tnumber_of_nodes(): " << number_of_nodes() + ); + return narrow(graph_.node(i).data.name); + } + + const std::wstring node_wlabel ( + unsigned long i + ) const + { + return convert_utf32_to_wstring(node_ulabel(i)); + } + + const dlib::ustring node_ulabel ( + unsigned long i + ) const + { + auto_mutex M(m); + DLIB_ASSERT ( i < number_of_nodes() , + "\tconst std::ustring directed_graph_drawer::node_label(i)" + << "\n\ti: " << i + << "\n\tnumber_of_nodes(): " << number_of_nodes() + ); + return graph_.node(i).data.name.c_str(); + } + + template < + typename T + > + void set_node_selected_handler ( + T& object, + void (T::*event_handler_)(unsigned long) + ) + { + auto_mutex M(m); + node_selected_handler = make_mfp(object,event_handler_); + } + + void set_node_selected_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + node_selected_handler = event_handler_; + } + + template < + typename T + > + void set_node_deselected_handler ( + T& object, + void (T::*event_handler_)(unsigned long) + ) + { + auto_mutex M(m); + node_deselected_handler = make_mfp(object,event_handler_); + } + + void set_node_deselected_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + node_deselected_handler = event_handler_; + } + + template < + typename T + > + void set_node_deleted_handler ( + T& object, + void (T::*event_handler_)() + ) + { + auto_mutex M(m); + node_deleted_handler = make_mfp(object,event_handler_); + } + + void set_node_deleted_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + node_deleted_handler = event_handler_; + } + + template < + typename T + > + void set_graph_modified_handler ( + T& object, + void (T::*event_handler_)() + ) + { + auto_mutex M(m); + graph_modified_handler = make_mfp(object,event_handler_); + } + + void set_graph_modified_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + graph_modified_handler = event_handler_; + } + + protected: + + void on_keydown ( + unsigned long key, + bool , + unsigned long + ) + { + // ignore all keyboard input if the last thing the user clicked on + // wasn't the display area + if (last_mouse_click_in_display == false) + return; + + // if a node is selected + if (selected_node != graph_.number_of_nodes()) + { + // deselect the node if the user hits escape + if (key == base_window::KEY_ESC) + { + parent.invalidate_rectangle(display_rect()); + if (node_deselected_handler.is_set()) + node_deselected_handler(selected_node); + selected_node = graph_.number_of_nodes(); + } + + // delete the node if the user hits delete + if (key == base_window::KEY_DELETE || key == base_window::KEY_BACKSPACE) + { + parent.invalidate_rectangle(display_rect()); + graph_.remove_node(selected_node); + external_graph.remove_node(selected_node); + selected_node = graph_.number_of_nodes(); + mouse_drag = false; + if (graph_modified_handler.is_set()) + graph_modified_handler(); + if (node_deleted_handler.is_set()) + node_deleted_handler(); + } + } + + // if an edge is selected + if (edge_selected) + { + // deselect the node if the user hits escape + if (key == base_window::KEY_ESC) + { + parent.invalidate_rectangle(display_rect()); + edge_selected = false; + } + + // delete the node if the user hits delete + if (key == base_window::KEY_DELETE || key == base_window::KEY_BACKSPACE) + { + parent.invalidate_rectangle(display_rect()); + graph_.remove_edge(selected_edge_parent, selected_edge_child); + external_graph.remove_edge(selected_edge_parent, selected_edge_child); + edge_selected = false; + + if (graph_modified_handler.is_set()) + graph_modified_handler(); + } + } + } + + + void on_mouse_move ( + unsigned long state, + long x, + long y + ) + { + if (mouse_drag) + { + const point p(nearest_point(display_rect(),point(x,y))); + + point center = drag_offset + p; + graph_.node(selected_node).data.p = gui_to_graph_space(center); + parent.invalidate_rectangle(display_rect()); + } + else + { + zoomable_region::on_mouse_move(state,x,y); + } + + // check if the mouse isn't being dragged anymore + if ((state & base_window::LEFT) == 0) + { + mouse_drag = false; + } + } + + void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ) + { + mouse_drag = false; + zoomable_region::on_mouse_up(btn,state,x,y); + } + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ) + { + bool redraw = false; + + if (display_rect().contains(x,y) && + (btn == base_window::RIGHT || btn == base_window::LEFT) && + (state & base_window::SHIFT) == 0 ) + { + // start out saying no edge is selected + if (edge_selected) + { + edge_selected = false; + redraw = true; + } + + bool click_hit_node = false; + dlib::vector p(gui_to_graph_space(point(x,y))); + // check if this click is on an existing node + for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i) + { + dlib::vector n(graph_.node(i).data.p); + if ((p-n).length() < radius) + { + click_hit_node = true; + point center = graph_to_gui_space(graph_.node(i).data.p); + mouse_drag = true; + drag_offset = center - point(x,y); + + // only do something if the click isn't on the currently + // selected node + if (selected_node != i) + { + // send out the deselected event if appropriate + if (selected_node != graph_.number_of_nodes() && node_deselected_handler.is_set()) + node_deselected_handler(selected_node); + + selected_node = i; + redraw = true; + if (node_selected_handler.is_set()) + node_selected_handler(selected_node); + } + break; + } + } + + // if the click didn't hit any node then make sure nothing is selected + if (click_hit_node == false && selected_node != graph_.number_of_nodes()) + { + if (node_deselected_handler.is_set()) + node_deselected_handler(selected_node); + selected_node = graph_.number_of_nodes(); + redraw = true; + } + + + // check if this click is on an edge if we didn't click on a node + if (click_hit_node == false) + { + for (unsigned long n = 0; n < graph_.number_of_nodes() && edge_selected == false; ++n) + { + const dlib::vector parent_center(graph_to_gui_space(graph_.node(n).data.p)); + for (unsigned long e = 0; e < graph_.node(n).number_of_children() && edge_selected == false; ++e) + { + const dlib::vector child_center(graph_to_gui_space(graph_.node(n).child(e).data.p)); + + rectangle area; + area += parent_center; + area += child_center; + // if the point(x,y) is between the two nodes then lets consider it further + if (area.contains(point(x,y))) + { + p = point(x,y); + const dlib::vector z(0,0,1); + // find the distance from the line between the two nodes + const dlib::vector perpendicular(z.cross(parent_center-child_center).normalize()); + double distance = std::abs((child_center-p).dot(perpendicular)); + if (distance < 8) + { + edge_selected = true; + selected_edge_parent = n; + selected_edge_child = graph_.node(n).child(e).index(); + redraw = true; + } + } + } + } + } + + + // if the click didn't land on any node then add a new one if this was + // a right mouse button click + if (click_hit_node == false && btn == base_window::RIGHT) + { + const unsigned long n = graph_.add_node(); + external_graph.add_node(); + + graph_.node(n).data.p = gui_to_graph_space(point(x,y)); + + redraw = true; + selected_node = n; + mouse_drag = false; + if (graph_modified_handler.is_set()) + graph_modified_handler(); + + if (node_selected_handler.is_set()) + node_selected_handler(selected_node); + + } + else if (selected_node == graph_.number_of_nodes()) + { + // in this case the click landed in the white area between nodes + zoomable_region::on_mouse_down( btn, state, x, y, is_double_click); + } + } + + // If the user is shift clicking with the mouse then see if we + // should add a new edge. + if (display_rect().contains(x,y) && + btn == base_window::LEFT && + (state & base_window::SHIFT) && + selected_node != graph_.number_of_nodes() ) + { + dlib::vector p(gui_to_graph_space(point(x,y))); + // check if this click is on an existing node + for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i) + { + dlib::vector n(graph_.node(i).data.p); + if ((p-n).length() < radius) + { + // add the edge if it doesn't already exist and isn't an edge back to + // the same node + if (graph_.has_edge(selected_node,i) == false && selected_node != i && + graph_.has_edge(i, selected_node) == false) + { + graph_.add_edge(selected_node,i); + external_graph.add_edge(selected_node,i); + redraw = true; + + if (graph_modified_handler.is_set()) + graph_modified_handler(); + } + break; + } + } + } + + + if (redraw) + parent.invalidate_rectangle(display_rect()); + + + if (display_rect().contains(x,y) == false) + last_mouse_click_in_display = false; + else + last_mouse_click_in_display = true; + } + + void draw ( + const canvas& c + ) const + { + zoomable_region::draw(c); + + rectangle area = c.intersect(display_rect()); + if (area.is_empty() == true) + return; + + + if (enabled) + fill_rect(c,display_rect(),255); + else + fill_rect(c,display_rect(),128); + + + const unsigned long rad = static_cast(radius*zoom_scale()); + point center; + + + // first draw all the edges + for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i) + { + center = graph_to_gui_space(graph_.node(i).data.p); + const rectangle circle_area(centered_rect(center,2*(rad+8),2*(rad+8))); + + // draw lines to all this node's parents + const dlib::vector z(0,0,1); + for (unsigned long j = 0; j < graph_.node(i).number_of_parents(); ++j) + { + point p(graph_to_gui_space(graph_.node(i).parent(j).data.p)); + + rgb_pixel color(0,0,0); + // if this is the selected edge then draw it with red instead of black + if (edge_selected && selected_edge_child == i && selected_edge_parent == graph_.node(i).parent(j).index()) + { + color.red = 255; + // we need to be careful when drawing this line to not draw it over the node dots since it + // has a different color from them and would look weird + dlib::vector v(p-center); + v = v.normalize()*rad; + draw_line(c,center+v,p-v ,color, area); + } + else + { + draw_line(c,center,p ,color, area); + } + + + // draw the triangle pointing to this node + if (area.intersect(circle_area).is_empty() == false) + { + dlib::vector v(p-center); + v = v.normalize(); + + dlib::vector cross = z.cross(v).normalize(); + dlib::vector r(center + v*rad); + for (double i = 0; i < 8*zoom_scale(); i += 0.1) + draw_line(c,(r+v*i)+cross*i, (r+v*i)-cross*i,color,area); + } + } + } + + + // now draw all the node dots + for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i) + { + center = graph_to_gui_space(graph_.node(i).data.p); + const rectangle circle_area(centered_rect(center,2*(rad+8),2*(rad+8))); + + // draw the actual dot for this node + if (area.intersect(circle_area).is_empty()==false) + { + rgb_alpha_pixel color; + assign_pixel(color, graph_.node(i).data.color); + // this node is in area so lets draw it and all of it's edges as well + draw_solid_circle(c,center,rad-3,color,area); + color.alpha = 240; + draw_circle(c,center,rad-3,color,area); + color.alpha = 200; + draw_circle(c,center,rad-2.5,color,area); + color.alpha = 160; + draw_circle(c,center,rad-2.0,color,area); + color.alpha = 120; + draw_circle(c,center,rad-1.5,color,area); + color.alpha = 80; + draw_circle(c,center,rad-1.0,color,area); + color.alpha = 40; + draw_circle(c,center,rad-0.5,color,area); + + } + + + if (i == selected_node) + draw_circle(c,center,rad+5,rgb_pixel(0,0,255),area); + } + + + // now draw all the strings last + for (unsigned long i = 0; i < graph_.number_of_nodes(); ++i) + { + center = graph_to_gui_space(graph_.node(i).data.p); + rectangle circle_area(centered_rect(center,2*rad+3,2*rad+3)); + if (area.intersect(circle_area).is_empty()==false) + { + rgb_pixel color = graph_.node(i).data.color; + // invert this color + color.red = 255-color.red; + color.green = 255-color.green; + color.blue = 255-color.blue; + sout << i; + unsigned long width, height; + mfont->compute_size(sout.str(),width,height); + rectangle str_rect(centered_rect(center, width,height)); + if (circle_area.contains(str_rect)) + { + mfont->draw_string(c,str_rect,sout.str(),color,0,std::string::npos,area); + + // draw the label for this node if it isn't empty + if(graph_.node(i).data.name.size() > 0) + { + rectangle str_rect(graph_.node(i).data.str_rect); + str_rect = centered_rect(center.x(), center.y()-rad-mfont->height(), str_rect.width(), str_rect.height()); + mfont->draw_string(c,str_rect,graph_.node(i).data.name,0,0,std::string::npos,area); + } + } + sout.str(""); + } + } + } + + private: + + struct data + { + data() : color(0,0,0) {} + vector p; + dlib::ustring name; + rectangle str_rect; + rgb_pixel color; + }; + + friend void serialize(const data& item, std::ostream& out) + { + serialize(item.p, out); + serialize(item.name, out); + serialize(item.str_rect, out); + serialize(item.color, out); + } + + friend void deserialize(data& item, std::istream& in) + { + deserialize(item.p, in); + deserialize(item.name, in); + deserialize(item.str_rect, in); + deserialize(item.color, in); + } + + mutable std::ostringstream sout; + + const double radius; + unsigned long selected_node; + bool mouse_drag; // true if the user is dragging a node + point drag_offset; + + bool edge_selected; + unsigned long selected_edge_parent; + unsigned long selected_edge_child; + + any_function node_selected_handler; + any_function node_deselected_handler; + any_function node_deleted_handler; + any_function graph_modified_handler; + + graph_type external_graph; + // rebind the graph_ type to make us a graph_ of data structs + typename graph_type::template rebind::other graph_; + + bool last_mouse_click_in_display; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class text_grid +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class text_grid : public scrollable_region + { + /*! + INITIAL VALUE + - has_focus == false + - vertical_scroll_increment() == 10 + - horizontal_scroll_increment() == 10 + - border_color_ == rgb_pixel(128,128,128) + + CONVENTION + - grid.nr() == row_height.size() + - grid.nc() == col_width.size() + - border_color() == border_color_ + - text(r,c) == grid[r][c].text + - text_color(r,c) == grid[r][c].text_color + - background_color(r,c) == grid[r][c].bg_color + + - if (the user has clicked on this widget and caused one of the + boxes to have input focus) then + - has_focus == true + - grid[active_row][active_col] == the active text box + - cursor_pos == the position of the cursor in the above box + - if (the cursor should be displayed) then + - show_cursor == true + - else + - show_cursor == false + - else + - has_focus == false + !*/ + + public: + text_grid ( + drawable_window& w + ); + + ~text_grid ( + ); + + void set_grid_size ( + unsigned long rows, + unsigned long cols + ); + + unsigned long number_of_columns ( + ) const; + + unsigned long number_of_rows ( + ) const; + + int next_free_user_event_number ( + ) const; + + rgb_pixel border_color ( + ) const; + + void set_border_color ( + rgb_pixel color + ); + + const std::string text ( + unsigned long row, + unsigned long col + ) const; + + const std::wstring wtext ( + unsigned long row, + unsigned long col + ) const; + + const dlib::ustring utext ( + unsigned long row, + unsigned long col + ) const; + + void set_text ( + unsigned long row, + unsigned long col, + const std::string& str + ); + + void set_text ( + unsigned long row, + unsigned long col, + const std::wstring& str + ); + + void set_text ( + unsigned long row, + unsigned long col, + const dlib::ustring& str + ); + + const rgb_pixel text_color ( + unsigned long row, + unsigned long col + ) const; + + void set_text_color ( + unsigned long row, + unsigned long col, + const rgb_pixel color + ); + + const rgb_pixel background_color ( + unsigned long row, + unsigned long col + ) const; + + void set_background_color ( + unsigned long row, + unsigned long col, + const rgb_pixel color + ); + + bool is_editable ( + unsigned long row, + unsigned long col + ) const; + + void set_editable ( + unsigned long row, + unsigned long col, + bool editable + ); + + void set_column_width ( + unsigned long col, + unsigned long width + ); + + void set_row_height ( + unsigned long row, + unsigned long height + ); + + void disable ( + ); + + void hide ( + ); + + template < + typename T + > + void set_text_modified_handler ( + T& object, + void (T::*eh)(unsigned long, unsigned long) + ) { text_modified_handler = make_mfp(object,eh); } + + void set_text_modified_handler ( + const any_function& eh + ) { text_modified_handler = eh; } + + private: + + void on_user_event ( + int num + ); + + void timer_action ( + ); + /*! + ensures + - flips the state of show_cursor + !*/ + + void compute_bg_rects ( + ); + + void compute_total_rect ( + ); + + void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ); + + void on_focus_lost ( + ); + + void draw ( + const canvas& c + ) const; + + rectangle get_text_rect ( + unsigned long row, + unsigned long col + ) const; + + rectangle get_bg_rect ( + unsigned long row, + unsigned long col + ) const; + + struct data_type + { + data_type(): text_color(0,0,0), bg_color(255,255,255), + first(0), is_editable(true) + {} + + dlib::ustring text; + rgb_pixel text_color; + rgb_pixel bg_color; + rectangle bg_rect; + dlib::ustring::size_type first; + bool is_editable; + }; + + void drop_input_focus ( + ); + + void move_cursor ( + long row, + long col, + long new_cursor_pos + ); + + array2d grid; + array col_width; + array row_height; + bool has_focus; + long active_col; + long active_row; + long cursor_pos; + bool show_cursor; + bool recent_cursor_move; + timer cursor_timer; + rgb_pixel border_color_; + any_function text_modified_handler; + }; + +// ---------------------------------------------------------------------------------------- + + class image_display : public scrollable_region + { + /*! + INITIAL VALUE + - img.size() == 0 + - overlay_rects.size() == 0 + - overlay_lines.size() == 0 + - drawing_rect == false + - rect_is_selected == false + + CONVENTION + - img == the image this object displays + - overlay_rects == the overlay rectangles this object displays + - overlay_lines == the overlay lines this object displays + + - if (drawing_rect) then + - the user is drawing a rectangle on the screen and is + thus holding down CTRL and the left mouse button. + - rect_anchor == the point on the screen where the user + clicked to begin drawing the rectangle. + - rect_to_draw == the rectangle which should appear on the screen. + + - if (rect_is_selected) then + - selected_rect == the index in overlay_rects of the user selected + rectangle. + - last_right_click_pos == the last place we saw the user right click + the mouse. + - parts_menu.is_enabled() == true + - if (it is actually a part of this rect that is selected) then + - selected_part_name == the name of the part in overlay_rects[selected_rect].parts + that is selected. + - else + - selected_part_name.size() == 0 + - else + - parts_menu.is_enabled() == false + - selected_part_name.size() == 0 + + - if (moving_overlay) then + - moving_rect == the index in overlay_rects that the move applies to. + - if (moving_what == MOVING_PART) then + - moving_part_name == the name of the part in + overlay_rects[moving_rect] that is being moved around with the + mouse. + - else + - moving_what will tell us which side of the rectangle in + overlay_rects[moving_rect] is being moved by the mouse. + !*/ + + public: + + image_display( + drawable_window& w + ); + + ~image_display( + ); + + template < + typename image_type + > + void set_image ( + const image_type& new_img + ) + { + auto_mutex M(m); + + // if the new image has a different size when compared to the previous image + // then we should readjust the total rectangle size. + if (num_rows(new_img) != img.nr() || num_columns(new_img) != img.nc()) + { + if (zoom_in_scale != 1) + set_total_rect_size(num_columns(new_img)*zoom_in_scale, num_rows(new_img)*zoom_in_scale); + else + set_total_rect_size(num_columns(new_img)/zoom_out_scale, num_rows(new_img)/zoom_out_scale); + } + else + { + parent.invalidate_rectangle(rect); + } + + highlighted_rect = std::numeric_limits::max(); + rect_is_selected = false; + parts_menu.disable(); + assign_image_scaled(img,new_img); + } + + virtual void set_pos ( + long x, + long y + ) + { + auto_mutex lock(m); + scrollable_region::set_pos(x,y); + parts_menu.set_rect(rect); + } + + virtual void set_size ( + unsigned long width, + unsigned long height + ) + { + auto_mutex lock(m); + scrollable_region::set_size(width,height); + parts_menu.set_rect(rect); + } + + struct overlay_rect + { + overlay_rect() :crossed_out(false) { assign_pixel(color, 0);} + + template + overlay_rect(const rectangle& r, pixel_type p) + : rect(r),crossed_out(false) { assign_pixel(color, p); } + + template + overlay_rect(const rectangle& r, pixel_type p, const std::string& l) + : rect(r),label(l),crossed_out(false) { assign_pixel(color, p); } + + template + overlay_rect(const rectangle& r, pixel_type p, const std::string& l, const std::map& parts_) + : rect(r),label(l),parts(parts_),crossed_out(false) { assign_pixel(color, p); } + + rectangle rect; + rgb_alpha_pixel color; + std::string label; + std::map parts; + bool crossed_out; + }; + + struct overlay_line + { + overlay_line() { assign_pixel(color, 0);} + + template + overlay_line(const point& p1_, const point& p2_, pixel_type p) + : p1(p1_), p2(p2_) { assign_pixel(color, p); } + + point p1; + point p2; + rgb_alpha_pixel color; + }; + + struct overlay_circle + { + overlay_circle():radius(0) { assign_pixel(color, 0);} + + template + overlay_circle(const point& center_, const int radius_, pixel_type p) + : center(center_), radius(radius_) { assign_pixel(color, p); } + + template + overlay_circle(const point& center_, const int radius_, pixel_type p, const std::string& l) + : center(center_), radius(radius_), label(l) { assign_pixel(color, p); } + + point center; + int radius; + rgb_alpha_pixel color; + std::string label; + }; + + void add_overlay ( + const overlay_rect& overlay + ); + + void add_overlay ( + const overlay_line& overlay + ); + + void add_overlay ( + const overlay_circle& overlay + ); + + void add_overlay ( + const std::vector& overlay + ); + + void add_overlay ( + const std::vector& overlay + ); + + void add_overlay ( + const std::vector& overlay + ); + + void clear_overlay ( + ); + + rectangle get_image_display_rect ( + ) const; + + std::vector get_overlay_rects ( + ) const; + + void set_default_overlay_rect_label ( + const std::string& label + ); + + std::string get_default_overlay_rect_label ( + ) const; + + void set_default_overlay_rect_color ( + const rgb_alpha_pixel& color + ); + + rgb_alpha_pixel get_default_overlay_rect_color ( + ) const; + + template < + typename T + > + void set_overlay_rects_changed_handler ( + T& object, + void (T::*event_handler_)() + ) + { + auto_mutex M(m); + event_handler = make_mfp(object,event_handler_); + } + + void set_overlay_rects_changed_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + event_handler = event_handler_; + } + + template < + typename T + > + void set_overlay_rect_selected_handler ( + T& object, + void (T::*event_handler_)(const overlay_rect& orect) + ) + { + auto_mutex M(m); + orect_selected_event_handler = make_mfp(object,event_handler_); + } + + void set_overlay_rect_selected_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + orect_selected_event_handler = event_handler_; + } + + template < + typename T + > + void set_image_clicked_handler ( + T& object, + void (T::*event_handler_)(const point& p, bool is_double_click, unsigned long btn) + ) + { + auto_mutex M(m); + image_clicked_handler = make_mfp(object,event_handler_); + } + + void set_image_clicked_handler ( + const any_function& event_handler_ + ) + { + auto_mutex M(m); + image_clicked_handler = event_handler_; + } + + void add_labelable_part_name ( + const std::string& name + ); + + void clear_labelable_part_names ( + ); + + void enable_overlay_editing ( + ) { auto_mutex M(m); overlay_editing_enabled = true; } + + void disable_overlay_editing ( + ) + { + auto_mutex M(m); + overlay_editing_enabled = false; + rect_is_selected = false; + drawing_rect = false; + parent.invalidate_rectangle(rect); + } + + bool overlay_editing_is_enabled ( + ) const { auto_mutex M(m); return overlay_editing_enabled; } + + private: + + void draw ( + const canvas& c + ) const; + + void on_wheel_up ( + unsigned long state + ); + + void on_wheel_down ( + unsigned long state + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void on_mouse_up ( + unsigned long btn, + unsigned long state, + long x, + long y + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + void on_part_add ( + const std::string& part_name + ); + + rectangle get_rect_on_screen ( + unsigned long idx + ) const; + + rectangle get_rect_on_screen ( + rectangle orect + ) const; + + rgb_alpha_pixel invert_pixel (const rgb_alpha_pixel& p) const + { return rgb_alpha_pixel(255-p.red, 255-p.green, 255-p.blue, p.alpha); } + + virtual int next_free_user_event_number ( + ) const { return scrollable_region::next_free_user_event_number()+1; } + // The reason for using user actions here rather than just having the timer just call + // what it needs directly is to avoid a potential deadlock during destruction of this widget. + void timer_event_unhighlight_rect() + { + highlight_timer.stop(); + parent.trigger_user_event(this,scrollable_region::next_free_user_event_number()); + } + void on_user_event (int num) + { + // ignore this user event if it isn't for us + if (num != scrollable_region::next_free_user_event_number()) + return; + if (highlighted_rect < overlay_rects.size()) + { + highlighted_rect = std::numeric_limits::max(); + parent.invalidate_rectangle(rect); + } + } + + + array2d img; + + + std::vector overlay_rects; + std::vector overlay_lines; + std::vector overlay_circles; + + long zoom_in_scale; + long zoom_out_scale; + bool drawing_rect; + point rect_anchor; + rectangle rect_to_draw; + bool rect_is_selected; + std::string selected_part_name; + unsigned long selected_rect; + rgb_alpha_pixel default_rect_color; + std::string default_rect_label; + any_function event_handler; + any_function orect_selected_event_handler; + any_function image_clicked_handler; + popup_menu_region parts_menu; + point last_right_click_pos; + const double part_width; + std::set part_names; + bool overlay_editing_enabled; + timer highlight_timer; + unsigned long highlighted_rect; + bool holding_shift_key; + + bool moving_overlay; + unsigned long moving_rect; + enum { + MOVING_RECT_LEFT, + MOVING_RECT_TOP, + MOVING_RECT_RIGHT, + MOVING_RECT_BOTTOM, + MOVING_PART + } moving_what; + std::string moving_part_name; + + // restricted functions + image_display(image_display&); // copy constructor + image_display& operator=(image_display&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class perspective_display : public drawable, noncopyable + { + public: + + perspective_display( + drawable_window& w + ); + + ~perspective_display( + ); + + virtual void set_size ( + unsigned long width, + unsigned long height + ); + + struct overlay_line + { + overlay_line() { assign_pixel(color, 0);} + + overlay_line(const vector& p1_, const vector& p2_) + : p1(p1_), p2(p2_) { assign_pixel(color, 255); } + + template + overlay_line(const vector& p1_, const vector& p2_, pixel_type p) + : p1(p1_), p2(p2_) { assign_pixel(color, p); } + + vector p1; + vector p2; + rgb_pixel color; + }; + + struct overlay_dot + { + overlay_dot() { assign_pixel(color, 0);} + + overlay_dot(const vector& p_) + : p(p_) { assign_pixel(color, 255); } + + template + overlay_dot(const vector& p_, pixel_type color_) + : p(p_) { assign_pixel(color, color_); } + + vector p; + rgb_pixel color; + }; + + + void add_overlay ( + const std::vector& overlay + ); + + void add_overlay ( + const std::vector& overlay + ); + + void clear_overlay ( + ); + + template < + typename T + > + void set_dot_double_clicked_handler ( + T& object, + void (T::*event_handler_)(const vector&) + ) + { + auto_mutex M(m); + dot_clicked_event_handler = make_mfp(object,event_handler_); + } + + void set_dot_double_clicked_handler ( + const any_function&)>& event_handler_ + ); + + private: + + void draw ( + const canvas& c + ) const; + + void on_wheel_up ( + unsigned long state + ); + + void on_wheel_down ( + unsigned long state + ); + + void on_mouse_down ( + unsigned long btn, + unsigned long state, + long x, + long y, + bool is_double_click + ); + + void on_mouse_move ( + unsigned long state, + long x, + long y + ); + + static bool compare_second ( + const std::pair& a, + const std::pair& b + ) { return a.second < b.second; } + + + point last; + std::vector overlay_lines; + std::vector overlay_dots; + + camera_transform tform; + vector sum_pts; + vector max_pts; + any_function&)> dot_clicked_event_handler; + mutable array2d depth; + }; + +// ---------------------------------------------------------------------------------------- + + class perspective_window : public drawable_window, noncopyable + { + public: + + typedef perspective_display::overlay_line overlay_line; + typedef perspective_display::overlay_dot overlay_dot; + + perspective_window( + ) : disp(*this) + { + set_size(100,100); + on_window_resized(); + show(); + } + + perspective_window( + const std::vector >& point_cloud + ) : + disp(*this) + { + set_size(100,100); + on_window_resized(); + add_overlay(point_cloud); + show(); + } + + perspective_window( + const std::vector >& point_cloud, + const std::string& title + ) : + disp(*this) + { + set_size(100,100); + on_window_resized(); + add_overlay(point_cloud); + set_title(title); + show(); + } + + ~perspective_window( + ) + { + // You should always call close_window() in the destructor of window + // objects to ensure that no events will be sent to this window while + // it is being destructed. + close_window(); + } + + void add_overlay ( + const std::vector& overlay + ) + { + disp.add_overlay(overlay); + } + + void add_overlay ( + const std::vector& overlay + ) + { + disp.add_overlay(overlay); + } + + void clear_overlay ( + ) + { + disp.clear_overlay(); + } + + template + void add_overlay(const vector& p1, const vector& p2, pixel_type p) + { + add_overlay(std::vector(1,overlay_line(p1,p2,p))); + } + + void add_overlay(const std::vector >& d) + { + add_overlay(d, 255); + } + + template + void add_overlay(const std::vector >& d, pixel_type p) + { + std::vector temp; + temp.resize(d.size()); + for (unsigned long i = 0; i < temp.size(); ++i) + temp[i] = overlay_dot(d[i], p); + + add_overlay(temp); + } + + template < + typename T + > + void set_dot_double_clicked_handler ( + T& object, + void (T::*event_handler_)(const vector&) + ) + { + disp.set_dot_double_clicked_handler(object,event_handler_); + } + + void set_dot_double_clicked_handler ( + const any_function&)>& event_handler_ + ) + { + disp.set_dot_double_clicked_handler(event_handler_); + } + + private: + + void on_window_resized( + ) + { + drawable_window::on_window_resized(); + unsigned long width, height; + get_size(width,height); + disp.set_pos(0,0); + disp.set_size(width, height); + } + + perspective_display disp; + }; + +// ---------------------------------------------------------------------------------------- + + class image_window : public drawable_window + { + public: + + typedef image_display::overlay_rect overlay_rect; + typedef image_display::overlay_line overlay_line; + typedef image_display::overlay_circle overlay_circle; + + image_window( + ); + + template < typename image_type > + image_window( + const image_type& img + ) : + gui_img(*this), + window_has_closed(false), + have_last_click(false), + mouse_btn(0), + clicked_signaler(this->wm), + have_last_keypress(false), + tie_input_events(false) + { + gui_img.set_image_clicked_handler(*this, &image_window::on_image_clicked); + gui_img.disable_overlay_editing(); + set_image(img); + show(); + } + + template < typename image_type > + image_window( + const image_type& img, + const std::string& title + ) : + gui_img(*this), + window_has_closed(false), + have_last_click(false), + mouse_btn(0), + clicked_signaler(this->wm), + have_last_keypress(false), + tie_input_events(false) + { + gui_img.set_image_clicked_handler(*this, &image_window::on_image_clicked); + gui_img.disable_overlay_editing(); + set_image(img); + set_title(title); + show(); + } + + + ~image_window( + ); + + template < typename image_type > + void set_image ( + const image_type& img + ) + { + const unsigned long padding = scrollable_region_style_default().get_border_size(); + auto_mutex M(wm); + gui_img.set_image(img); + + // Only ever mess with the size of the window if the user is giving us an image + // that is a different size. Otherwise we assume that they will have already + // sized the window to whatever they feel is reasonable for an image of the + // current size. + if (previous_image_size != get_rect(img)) + { + const rectangle r = gui_img.get_image_display_rect(); + if (image_rect != r) + { + // set the size of this window to match the size of the input image + set_size(r.width()+padding*2,r.height()+padding*2); + + // call this to make sure everything else is setup properly + on_window_resized(); + + image_rect = r; + } + previous_image_size = get_rect(img); + } + } + + void add_overlay ( + const overlay_rect& overlay + ); + + template + void add_overlay(const rectangle& r, pixel_type p) + { add_overlay(image_display::overlay_rect(r,p)); } + + void add_overlay(const rectangle& r) + { add_overlay(image_display::overlay_rect(r,rgb_pixel(255,0,0))); } + + template + void add_overlay(const rectangle& r, pixel_type p, const std::string& l) + { add_overlay(image_display::overlay_rect(r,p,l)); } + + template + void add_overlay(const std::vector& r, pixel_type p) + { + std::vector temp; + temp.resize(r.size()); + for (unsigned long i = 0; i < temp.size(); ++i) + temp[i] = overlay_rect(r[i], p); + + add_overlay(temp); + } + + void add_overlay(const std::vector& r) + { add_overlay(r, rgb_pixel(255,0,0)); } + + void add_overlay( + const full_object_detection& object, + const std::vector& part_names + ) + { + + add_overlay(overlay_rect(object.get_rect(), rgb_pixel(255,0,0))); + + std::vector temp; + temp.reserve(object.num_parts()); + for (unsigned long i = 0; i < object.num_parts(); ++i) + { + if (object.part(i) != OBJECT_PART_NOT_PRESENT) + { + if (i < part_names.size()) + temp.push_back(overlay_circle(object.part(i), 7, rgb_pixel(0,255,0), part_names[i])); + else + temp.push_back(overlay_circle(object.part(i), 7, rgb_pixel(0,255,0))); + } + } + + add_overlay(temp); + } + + void add_overlay( + const full_object_detection& object + ) + { + std::vector part_names; + add_overlay(object, part_names); + } + + void add_overlay( + const std::vector& objects, + const std::vector& part_names + ) + { + std::vector rtemp; + rtemp.reserve(objects.size()); + for (unsigned long i = 0; i < objects.size(); ++i) + { + rtemp.push_back(overlay_rect(objects[i].get_rect(), rgb_pixel(255,0,0))); + } + + add_overlay(rtemp); + + std::vector temp; + + for (unsigned long i = 0; i < objects.size(); ++i) + { + for (unsigned long j = 0; j < objects[i].num_parts(); ++j) + { + if (objects[i].part(j) != OBJECT_PART_NOT_PRESENT) + { + if (j < part_names.size()) + temp.push_back(overlay_circle(objects[i].part(j), 7, rgb_pixel(0,255,0),part_names[j])); + else + temp.push_back(overlay_circle(objects[i].part(j), 7, rgb_pixel(0,255,0))); + } + } + } + + add_overlay(temp); + } + + void add_overlay( + const std::vector& objects + ) + { + std::vector part_names; + add_overlay(objects, part_names); + } + + void add_overlay ( + const overlay_line& overlay + ); + + void add_overlay ( + const overlay_circle& overlay + ); + + template + void add_overlay(const point& p1, const point& p2, pixel_type p) + { add_overlay(image_display::overlay_line(p1,p2,p)); } + + void add_overlay ( + const std::vector& overlay + ); + + void add_overlay ( + const std::vector& overlay + ); + + void add_overlay ( + const std::vector& overlay + ); + + void clear_overlay ( + ); + + bool get_next_double_click ( + point& p, + unsigned long& mouse_button + ); + + void tie_events ( + ); + + void untie_events ( + ); + + bool events_tied ( + ) const; + + bool get_next_double_click ( + point& p + ) + { + unsigned long mouse_button; + return get_next_double_click(p, mouse_button); + } + + bool get_next_keypress ( + unsigned long& key, + bool& is_printable, + unsigned long& state + ); + + bool get_next_keypress ( + unsigned long& key, + bool& is_printable + ) + { + unsigned long state; + return get_next_keypress(key,is_printable,state); + } + + private: + + virtual base_window::on_close_return_code on_window_close( + ); + + void on_window_resized( + ); + + void on_image_clicked ( + const point& p, + bool is_double_click, + unsigned long btn + ); + + virtual void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ); + + // restricted functions + image_window(image_window&); + image_window& operator= (image_window&); + + image_display gui_img; + rectangle image_rect; + rectangle previous_image_size; + bool window_has_closed; + bool have_last_click; + point last_clicked_point; + unsigned long mouse_btn; + rsignaler clicked_signaler; + + bool have_last_keypress; + unsigned long next_key; + bool next_is_printable; + unsigned long next_state; + bool tie_input_events; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "widgets.cpp" +#endif + +#endif // DLIB_WIDGETs_ + diff --git a/ml/dlib/dlib/gui_widgets/widgets_abstract.h b/ml/dlib/dlib/gui_widgets/widgets_abstract.h new file mode 100644 index 000000000..2b4dc4486 --- /dev/null +++ b/ml/dlib/dlib/gui_widgets/widgets_abstract.h @@ -0,0 +1,3461 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net), Keita Mochizuki +// License: Boost Software License See LICENSE.txt for the full license. + +#undef DLIB_WIDGETs_ABSTRACT_ +#ifdef DLIB_WIDGETs_ABSTRACT_ + +#include "fonts_abstract.h" +#include "drawable_abstract.h" +#include "base_widgets_abstract.h" + +#include "../gui_core.h" +#include +#include +#include "../interfaces/enumerable.h" +#include "style_abstract.h" +#include "../image_processing/full_object_detection_abstract.h" + +namespace dlib +{ + + /*! + GENERAL REMARKS + This component is a collection of various windowing widgets such as buttons, + labels, text boxes, and so on. This component also includes the drawable + interface, drawable_window, and font handling objects. The file you are + currently viewing defines all the high level graphical widgets which are + provided by this component that can appear in a drawable_window. To view + the specifications for the other members of this component look at + fonts_abstract.h, base_widgets_abstract.h, and drawable_abstract.h + + THREAD SAFETY + All objects and functions defined in this file are thread safe. You may + call them from any thread without serializing access to them. + + EVENT HANDLERS + Note that all event handlers, including the user registered callback + functions, are executed in the event handling thread. Additionally, + the drawable::m mutex will always be locked while these event handlers + are running. Also, don't rely on get_thread_id() always returning the + same ID from inside event handlers. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // function open_file_box(), open_existing_file_box(), and save_file_box() +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void open_file_box ( + T& object, + void (T::*event_handler)(const std::string&) + ); + /*! + requires + - event_handler == a valid pointer to a member function of object T. + ensures + - Displays a window titled "Open File" that will allow the user to select a + file. + - The displayed window will start out showing the directory get_current_dir() + (i.e. it starts in the current working directory) + - The event_handler function is called on object if the user selects + a file. If the user closes the window without selecting a file + then nothing occurs. + !*/ + + void open_file_box ( + const any_function& event_handler + ); + /*! + ensures + - Displays a window titled "Open File" that will allow the user to select a + file. + - The displayed window will start out showing the directory get_current_dir() + (i.e. it starts in the current working directory) + - The event_handler function is called if the user selects + a file. If the user closes the window without selecting a file + then nothing occurs. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void open_existing_file_box ( + T& object, + void (T::*event_handler)(const std::string&) + ); + /*! + requires + - event_handler == a valid pointer to a member function of object T. + ensures + - Displays a window titled "Open File" that will allow the user to select + a file. But only a file that already exists. + - The displayed window will start out showing the directory get_current_dir() + (i.e. it starts in the current working directory) + - The event_handler function is called on object if the user selects + a file. If the user closes the window without selecting a file + then nothing occurs. + !*/ + + void open_existing_file_box ( + const any_function& event_handler + ); + /*! + ensures + - Displays a window titled "Open File" that will allow the user to select + a file. But only a file that already exists. + - The displayed window will start out showing the directory get_current_dir() + (i.e. it starts in the current working directory) + - The event_handler function is called if the user selects + a file. If the user closes the window without selecting a file + then nothing occurs. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void save_file_box ( + T& object, + void (T::*event_handler)(const std::string&) + ); + /*! + requires + - event_handler == a valid pointer to a member function of object T. + ensures + - Displays a window titled "Save File" that will allow the user to select + a file. + - The displayed window will start out showing the directory get_current_dir() + (i.e. it starts in the current working directory) + - The event_handler function is called on object if the user selects + a file. If the user closes the window without selecting a file + then nothing occurs. + !*/ + + void save_file_box ( + const any_function& event_handler + ); + /*! + ensures + - Displays a window titled "Save File" that will allow the user to select + a file. + - The displayed window will start out showing the directory get_current_dir() + (i.e. it starts in the current working directory) + - The event_handler function is called if the user selects + a file. If the user closes the window without selecting a file + then nothing occurs. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // function message_box() +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void message_box ( + const std::string& title, + const std::string& message + ); + /*! + ensures + - displays a message box with the given title and message. It will have a + single button and when the user clicks it the message box will go away. + - this function does not block but instead returns immediately. + !*/ + + void message_box_blocking ( + const std::string& title, + const std::string& message + ); + /*! + ensures + - displays a message box with the given title and message. It will have a + single button and when the user clicks it the message box will go away. + - this function blocks until the user clicks on the message box and + causes it to go away. + !*/ + + template < + typename T + > + void message_box ( + const std::string& title, + const std::string& message, + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler == a valid pointer to a member function of object T. + ensures + - Displays a message box with the given title and message. It will have a + single button and when the user clicks it the message box will go away. + - The event_handler function is called on object when the user clicks + ok or otherwise closes the message box window. + - this function does not block but instead returns immediately. + !*/ + + void message_box ( + const std::string& title, + const std::string& message, + const any_function& event_handler + ); + /*! + ensures + - Displays a message box with the given title and message. It will have a + single button and when the user clicks it the message box will go away. + - The event_handler function is called when the user clicks + ok or otherwise closes the message box window. + - this function does not block but instead returns immediately. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class label +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class label : public drawable + { + /*! + INITIAL VALUE + text() == "" + the text color will be black + + WHAT THIS OBJECT REPRESENTS + This object represents a simple text label. The size of the label + is automatically set to be just big enough to contain its text. + !*/ + + public: + + label( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~label( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_text (const std::wstring& text); + void set_text (const dlib::ustring& text); + void set_text ( + const std::string& text + ); + /*! + ensures + - #text() == text + throws + - std::bad_alloc + !*/ + + const std::wstring wtext () const; + const dlib::ustring utext () const; + const std::string text ( + ) const; + /*! + ensures + - returns the text of this label + throws + - std::bad_alloc + !*/ + + void set_text_color ( + const rgb_pixel color + ); + /*! + ensures + - #text_color() == color + !*/ + + const rgb_pixel text_color ( + ) const; + /*! + ensures + - returns the color used to draw the text in this widget + !*/ + + private: + + // restricted functions + label(label&); // copy constructor + label& operator=(label&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class toggle_button +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class toggle_button : public button_action + { + /*! + INITIAL VALUE + name() == "" + is_checked() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a simple two state toggle button. Is is either + in the checked or unchecked state and when a user clicks on it it toggles its + state. + + When this object is disabled it means it will not respond to user clicks. + !*/ + + public: + + toggle_button( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~toggle_button( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_name (const std::wstring& name); + void set_name (const dlib::ustring& name); + void set_name ( + const std::string& name + ); + /*! + ensures + - #name() == name + - this toggle_button has been resized such that it is big enough to contain + the new name. + throws + - std::bad_alloc + !*/ + + void set_size ( + unsigned long width_, + unsigned long height_ + ); + /*! + ensures + - if (width and height are big enough to contain the name of this button) then + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this button stays the + same but its width and height are modified + !*/ + + void set_tooltip_text (const std::wstring& text); + void set_tooltip_text (const dlib::ustring& text); + void set_tooltip_text ( + const std::string& text + ); + /*! + ensures + - #tooltip_text() == text + - enables the tooltip for this toggle_button + !*/ + + const dlib::ustring tooltip_utext () const; + const std::wstring tooltip_wtext () const; + const std::string tooltip_text ( + ) const; + /*! + ensures + - returns the text that is displayed in the tooltip for this toggle_button + !*/ + + template < + typename style_type + > + void set_style ( + const style_type& style + ); + /*! + requires + - style_type == a type that inherits from toggle_button_style + ensures + - this toggle_button object will draw itself using the given + button style + !*/ + + bool is_checked ( + ) const; + /*! + ensures + - if (this box is currently checked) then + - returns true + - else + - returns false + !*/ + + const std::wstring wname () const; + const dlib::ustring uname () const; + const std::string name ( + ) const; + /*! + ensures + - returns the name of this toggle_button. The name is a string + that appears to the right of the actual check box. + throws + - std::bad_alloc + !*/ + + void set_checked ( + ); + /*! + ensures + - #is_checked() == true + !*/ + + void set_unchecked ( + ); + /*! + ensures + - #is_checked() == false + !*/ + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the toggle_button is + toggled by the user. + - this event is NOT triggered by calling set_checked() or set_unchecked(). + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_click_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the toggle_button is + toggled by the user. + - this event is NOT triggered by calling set_checked() or set_unchecked(). + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler)(toggle_button& self) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T. + ensures + - the event_handler function is called on object when the toggle_button is + toggled by the user. self will be a reference to the toggle_button object + that the user clicked. + - this event is NOT triggered by calling set_checked() or set_unchecked(). + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_sourced_click_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the toggle_button is + toggled by the user. self will be a reference to the toggle_button object + that the user clicked. + - this event is NOT triggered by calling set_checked() or set_unchecked(). + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + toggle_button(toggle_button&); // copy constructor + toggle_button& operator=(toggle_button&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class text_field +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class text_field : public drawable + { + /*! + INITIAL VALUE + - text() == "" + - width() == 10 + - height() == a height appropriate for the font used. The text color will + be black. + - has_input_focus() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a simple one line text input field. + !*/ + + public: + + text_field( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~text_field( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + template < + typename style_type + > + void set_style ( + const style_type& style + ); + /*! + requires + - style_type == a type that inherits from text_field_style + ensures + - this text_field object will draw itself using the given + text field style + !*/ + + void set_text (const std::wstring& text); + void set_text (const dlib::ustring& text); + void set_text ( + const std::string& text + ); + /*! + requires + - text.find_first_of('\n') == std::string::npos + (i.e. there aren't any new lines in text) + ensures + - #text() == text + throws + - std::bad_alloc + !*/ + + const std::wstring wtext () const; + const dlib::ustring utext () const; + const std::string text ( + ) const; + /*! + ensures + - returns the text of this text_field + throws + - std::bad_alloc + !*/ + + void set_width ( + unsigned long width_ + ); + /*! + ensures + - if (width >= 10) then + - #width() == width_ + - #height() == height() + - #top() == top() + - #left() == left() + - i.e. The width of this drawable is set to the given width but + nothing else changes. + !*/ + + void give_input_focus ( + ); + /*! + ensures + - #has_input_focus() == true + !*/ + + bool has_input_focus ( + ); + /*! + ensures + - Returns true if this txt field has input keyboard focus. If this + is the case then it means that when the user types on the keyboard + the output will appear inside the text field. + !*/ + + void select_all_text ( + ); + /*! + ensures + - causes all the text in the text field to become selected. + (note that it doesn't give input focus) + !*/ + + void set_text_color ( + const rgb_pixel color + ); + /*! + ensures + - #text_color() == color + !*/ + + const rgb_pixel text_color ( + ) const; + /*! + ensures + - returns the color used to draw the text in this widget + !*/ + + void set_background_color ( + const rgb_pixel color + ); + /*! + ensures + - #background_color() == color + !*/ + + const rgb_pixel background_color ( + ) const; + /*! + ensures + - returns the color used to fill in the background of this widget + !*/ + + template < + typename T + > + void set_text_modified_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the text + in this text_field is modified by the user. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_text_modified_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the text in this text_field + is modified by the user. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_enter_key_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when this text field + has input focus and the user hits the enter key on their keyboard. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_enter_key_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when this text field has input + focus and the user hits the enter key on their keyboard. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_focus_lost_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when this object + loses input focus due to the user clicking outside the text field + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_focus_lost_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when this object loses input + focus due to the user clicking outside the text field + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + text_field(text_field&); // copy constructor + text_field& operator=(text_field&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class text_box +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class text_box : public scrollable_region + { + /*! + INITIAL VALUE + - text() == "" + - The text color will be black. + - width() == 100 + - height() == 100 + + WHAT THIS OBJECT REPRESENTS + This object represents a simple multi-line text input box. + !*/ + + public: + + text_box( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~text_box( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + template < + typename style_type + > + void set_style ( + const style_type& style + ); + /*! + requires + - style_type == a type that inherits from text_box_style + ensures + - this text_box object will draw itself using the given + text box style + !*/ + + void set_text (const std::wstring& text); + void set_text (const dlib::ustring& text); + void set_text ( + const std::string& text + ); + /*! + ensures + - #text() == text + throws + - std::bad_alloc + !*/ + + const std::wstring wtext () const; + const dlib::ustring utext () const; + const std::string text ( + ) const; + /*! + ensures + - returns the text of this text_box + throws + - std::bad_alloc + !*/ + + void set_size ( + unsigned long width, + unsigned long height + ); + /*! + ensures + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this widget stays the + same but its width and height are modified + !*/ + + void set_text_color ( + const rgb_pixel color + ); + /*! + ensures + - #text_color() == color + !*/ + + const rgb_pixel text_color ( + ) const; + /*! + ensures + - returns the color used to draw the text in this widget + !*/ + + void set_background_color ( + const rgb_pixel color + ); + /*! + ensures + - #background_color() == color + !*/ + + const rgb_pixel background_color ( + ) const; + /*! + ensures + - returns the color used to fill in the background of this widget + !*/ + + template < + typename T + > + void set_text_modified_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the text + in this text_box is modified by the user. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_text_modified_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the text in this text_box + is modified by the user. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_enter_key_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when this text box + has input focus and the user hits the enter key on their keyboard. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_enter_key_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when this text box has input + focus and the user hits the enter key on their keyboard. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_focus_lost_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when this object + loses input focus due to the user clicking outside the text box + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_focus_lost_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when this object loses input + focus due to the user clicking outside the text box + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + text_box(text_box&); // copy constructor + text_box& operator=(text_box&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class check_box +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class check_box : public toggle_button + { + /*! + This is just a toggle button with the style set to + toggle_button_style_check_box. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class radio_button +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class radio_button : public toggle_button + { + /*! + This is just a toggle button with the style set to + toggle_button_style_radio_button. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class tabbed_display +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class tabbed_display : public drawable + { + /*! + INITIAL VALUE + number_of_tabs() == 1 + selected_tab() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a row of tabs that are user selectable. + + When this object is disabled it means it will not respond to user clicks. + !*/ + + public: + + tabbed_display( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~tabbed_display( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_size ( + unsigned long width_, + unsigned long height_ + ); + /*! + ensures + - if (width and height are big enough to contain the tabs) then + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this widget stays the + same but its width and height are modified + !*/ + + void set_number_of_tabs ( + unsigned long num + ); + /*! + requires + - num > 0 + ensures + - #number_of_tabs() == num + - no tabs have any widget_groups associated with them. + - for all valid idx: + - #tab_name(idx) == "" + throws + - std::bad_alloc + !*/ + + unsigned long selected_tab ( + ) const; + /*! + ensures + - returns the index of the currently selected tab + !*/ + + unsigned long number_of_tabs ( + ) const; + /*! + ensures + - returns the number of tabs in this tabbed_display + !*/ + + const std::wstring& tab_wname (unsigned long idx) const; + const dlib::ustring& tab_uname (unsigned long idx) const; + const std::string& tab_name ( + unsigned long idx + ) const; + /*! + requires + - idx < number_of_tabs() + ensures + - returns a const reference to the name of the tab given by idx + !*/ + + void set_tab_name (unsigned long idx, const std::wstring& new_name); + void set_tab_name (unsigned long idx, const dlib::ustring& new_name); + void set_tab_name ( + unsigned long idx, + const std::string& new_name + ); + /*! + requires + - idx < number_of_tabs() + ensures + - #tab_name(idx) == new_name + throws + - std::bad_alloc + !*/ + + void set_tab_group ( + unsigned long idx, + widget_group& group + ); + /*! + requires + - idx < number_of_tabs() + ensures + - if (is_hidden()) then + - group.is_hidden() == true + - else + - whenever the tab with index idx is selected group.is_hidden() == false + - whenever the tab with index idx is deselected group.is_hidden() == true + - whenever the position of *this changes the position of group will be + updated so that it is still inside the tabbed_display. The position of group + will also be updated after this call to set_tab_group(). + - any previous calls to set_tab_group() with this index are overridden by this + new call. (i.e. you can only have one widget_group associated with a single + tab at a time) + !*/ + + void fit_to_contents ( + ); + /*! + ensures + - Adjusts the size this tabbed_display so that it nicely contains + all of its widget_group objects. + - does not change the position of this object. + (i.e. the upper left corner of get_rect() remains at the same position) + !*/ + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler)(unsigned long new_idx, unsigned long old_idx) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - The event_handler function is called on object when the user clicks + on a tab that isn't already selected. new_idx will give the index of + the newly selected tab and old_idx will give the index of the tab + that was previously selected. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_click_handler ( + const any_function& eh + ); + /*! + ensures + - The event_handler function is called when the user clicks on a tab + that isn't already selected. new_idx will give the index of the + newly selected tab and old_idx will give the index of the tab that + was previously selected. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + tabbed_display(tabbed_display&); // copy constructor + tabbed_display& operator=(tabbed_display&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class named_rectangle +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class named_rectangle : public drawable + { + /*! + INITIAL VALUE + name() == "" + + WHAT THIS OBJECT REPRESENTS + This object represents a simple named rectangle. + !*/ + + public: + + named_rectangle( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~named_rectangle( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_size ( + unsigned long width_, + unsigned long height_ + ); + /*! + ensures + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this widget stays the + same but its width and height are modified + !*/ + + void wrap_around ( + const rectangle& rect + ); + /*! + ensures + - This object will be repositioned and sized so that it fits + around the given rectangle. + !*/ + + void set_name (const std::wstring& name); + void set_name (const dlib::ustring& name); + void set_name ( + const std::string& name + ); + /*! + ensures + - #name() == name + throws + - std::bad_alloc + !*/ + + const std::wstring wname () const; + const dlib::ustring uname () const; + const std::string name ( + ) const; + /*! + ensures + - returns the name of this named_rectangle + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + named_rectangle(named_rectangle&); // copy constructor + named_rectangle& operator=(named_rectangle&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class mouse_tracker +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class mouse_tracker : public draggable + { + /*! + INITIAL VALUE + draggable_area() == rectangle(0,0,500,500) + + WHAT THIS OBJECT REPRESENTS + This object represents a simple draggable box that displays the + current location of the mouse. + + Also, if you hold shift and left click on the parent window then the + mouse_tracker will place a single red pixel where you clicked and will + display the mouse position relative to that point. + !*/ + + public: + + mouse_tracker( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~mouse_tracker( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + private: + + // restricted functions + mouse_tracker(mouse_tracker&); // copy constructor + mouse_tracker& operator=(mouse_tracker&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class list_box +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class list_box : public scrollable_region, + public enumerable + { + /*! + INITIAL VALUE + multiple_select_enabled() == false + size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements in the list_box from + the 0th element to the (size()-1)th element. i.e. (*this)[0] to + (*this)[size()-1]. + + WHAT THIS OBJECT REPRESENTS + This object represents a simple textual list box. It contains a + vertical list of strings which the user may select from. + !*/ + + public: + + list_box( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~list_box( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + template < + typename style_type + > + void set_style ( + const style_type& style + ); + /*! + requires + - style_type == a type that inherits from list_box_style + ensures + - this list_box object will draw itself using the given style + !*/ + + void set_size ( + unsigned long width_, + unsigned long height_ + ); + /*! + ensures + - #width() == width_ + - #height() == height_ + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this widget stays the + same but its width and height are modified + !*/ + + bool is_selected ( + unsigned long index + ) const; + /*! + requires + - index < size() + ensures + - if (the item given by index is currently selected) then + - returns true + - else + - returns false + !*/ + + void select ( + unsigned long index + ); + /*! + requires + - index < size() + ensures + - #is_selected(index) == true + !*/ + + void unselect ( + unsigned long index + ); + /*! + requires + - index < size() + ensures + - #is_selected(index) == false + !*/ + + template + void get_selected ( + T& list + ) const; + /*! + requires + - T == an implementation of dlib/queue/queue_kernel_abstract.h + - T::type == unsigned long + ensures + - #list == a list of all the currently selected indices for this list_box. + !*/ + + unsigned long get_selected ( + ) const; + /*! + requires + - multiple_select_enabled() == false + ensures + - if (there is currently something selected) then + - returns the index of the selected item + - else + - returns size() + !*/ + + template + void load ( + const T& list + ); + /*! + requires + - T == compatible with dlib::enumerable + ensures + - #size() == list.size() + - Copies all the strings from list into *this in the order in which they are enumerated. + (i.e. The first one goes into (*this)[0], the second into (*this)[1], and so on...) + !*/ + + const std::string& operator[] ( + unsigned long index + ) const; + /*! + requires + - index < size() + ensures + - returns the name of the indexth item/row in this list box. + !*/ + + bool multiple_select_enabled ( + ) const; + /*! + ensures + - if (this object will allow the user to select more than one item at a time) then + - returns true + - else + - returns false + !*/ + + void enable_multiple_select ( + ); + /*! + ensures + - #multiple_select_enabled() == true + !*/ + + void disable_multiple_select ( + ); + /*! + ensures + - #multiple_select_enabled() == false + !*/ + + template < + typename T + > + void set_double_click_handler ( + T& object, + void (T::*event_handler)(unsigned long index) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T. + ensures + - The event_handler function is called on object when the user double + clicks on one of the rows in this list box. index gives the row + number for the item the user clicked. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_double_click_handler ( + const any_function& event_handler + ); + /*! + ensures + - The event_handler function is called when the user double clicks on + one of the rows in this list box. index gives the row number for + the item the user clicked. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_click_handler ( + T& object, + void (T::*event_handler)(unsigned long index) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T. + ensures + - The event_handler function is called on object when the user + clicks on one of the rows in this list box. index gives the row + number for the item the user clicked. (Note that the second click + in a double click triggers the double click handler above instead + of this event) + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_click_handler ( + const any_function& event_handler + ); + /*! + ensures + - The event_handler function is called when the user clicks on one + of the rows in this list box. index gives the row number for the + item the user clicked. (Note that the second click in a double + click triggers the double click handler above instead of this event) + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + list_box(list_box&); // copy constructor + list_box& operator=(list_box&); // assignment operator + }; + + class wlist_box : public scrollable_region, + public enumerable; + /*! + same as list_box except for std::wstring instead of std::string + !*/ + + class ulist_box : public scrollable_region, + public enumerable; + /*! + same as list_box except for dlib::ustring instead of std::string + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // class menu_bar +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class menu_bar : public drawable + { + /*! + INITIAL VALUE + - number_of_menus() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a menu bar that appears at the top of a + window. + !*/ + + public: + + menu_bar( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~menu_bar( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_number_of_menus ( + unsigned long num + ); + /*! + ensures + - #number_of_menus() == num + !*/ + + unsigned long number_of_menus ( + ) const; + /*! + ensures + - returns the number of menus in this menu_bar + !*/ + + void set_menu_name (unsigned long idx, const std::wstring name, char underline_ch = '\0'); + void set_menu_name (unsigned long idx, const dlib::ustring name, char underline_ch = '\0'); + void set_menu_name ( + unsigned long idx, + const std::string name, + char underline_ch = '\0' + ); + /*! + requires + - idx < number_of_menus() + ensures + - #menu_name(idx) == name + - if (underline_ch is present in name) then + - The menu with index idx will have the first underline_ch character + in its name underlined and users will be able to activate the menu + by hitting alt+underline_char + !*/ + + const std::wstring menu_wname (unsigned long idx) const; + const dlib::ustring menu_uname (unsigned long idx) const; + const std::string menu_name ( + unsigned long idx + ) const; + /*! + requires + - idx < number_of_menus() + ensures + - returns the name of the menu with index idx + !*/ + + popup_menu& menu ( + unsigned long idx + ); + /*! + requires + - idx < number_of_menus() + ensures + - returns a non-const reference to the popup_menu for the menu with + index idx. + !*/ + + const popup_menu& menu ( + unsigned long idx + ) const; + /*! + requires + - idx < number_of_menus() + ensures + - returns a const reference to the popup_menu for the menu with + index idx. + !*/ + + private: + + // restricted functions + menu_bar(menu_bar&); // copy constructor + menu_bar& operator=(menu_bar&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type + > + class directed_graph_drawer : public zoomable_region + { + /*! + REQUIREMENTS ON graph_type + - must be an implementation of directed_graph/directed_graph_kernel_abstract.h + + INITIAL VALUE + - get_graph().size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a graphical widget that allows the user to draw + a directed graph. + + The user can create nodes by right clicking on the draw area and add + edges by selecting a node (via left clicking on it) and then holding + shift and clicking on the node that is to be the child node of the + selected node. + !*/ + + public: + + directed_graph_drawer ( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~directed_graph_drawer ( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + const graph_type& graph ( + ) const; + /*! + requires + - drawable::m is locked + ensures + - returns a const reference to the graph that this widget has been drawing + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns graph().number_of_nodes() + !*/ + + void clear_graph ( + ); + /*! + ensures + - #number_of_nodes() == 0 + !*/ + + const typename graph_type::node_type& graph_node ( + unsigned long i + ) const; + /*! + requires + - drawable::m is locked + - i < number_of_nodes() + ensures + - returns a const reference to get_graph().node(i) + !*/ + + typename graph_type::node_type& graph_node ( + unsigned long i + ); + /*! + requires + - drawable::m is locked + - i < number_of_nodes() + ensures + - returns a non-const reference to get_graph().node(i) + !*/ + + void save_graph ( + std::ostream& out + ); + /*! + ensures + - saves the state of the graph to the output stream. Does so in a + way that not only preserves the state of the graph this->graph() + but also preserves the graphical layout of the graph in this + GUI widget. + - Also, the first part of the saved state is a serialized + version of this->graph(). Thus, you can deserialize just the + this->graph() object from the serialized data if you like. + !*/ + + void load_graph ( + std::istream& in + ); + /*! + ensures + - loads a saved graph from the given input stream. + !*/ + + void set_node_label (unsigned long i, const std::wstring& label); + void set_node_label (unsigned long i, const dlib::ustring& label); + void set_node_label ( + unsigned long i, + const std::string& label + ); + /*! + requires + - i < number_of_nodes() + ensures + - #node_label(i) == label + !*/ + + void set_node_color ( + unsigned long i, + rgb_pixel color + ); + /*! + requires + - i < number_of_nodes() + ensures + - #node_color(i) == color + !*/ + + rgb_pixel node_color ( + unsigned long i + ) const; + /*! + requires + - i < number_of_nodes() + ensures + - returns the color used to draw node graph_node(i) + !*/ + + const std::wstring node_wlabel (unsigned long i) const; + const dlib::ustring node_ulabel (unsigned long i) const; + const std::string node_label ( + unsigned long i + ) const; + /*! + requires + - i < number_of_nodes() + ensures + - returns the text label for node graph_node(i) + !*/ + + template < + typename T + > + void set_node_selected_handler ( + T& object, + void (T::*event_handler)(unsigned long node_index) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the user selects + a node. + - node_index == the index of the node that was selected + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_node_selected_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the user selects + a node. + - node_index == the index of the node that was selected + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_node_deselected_handler ( + T& object, + void (T::*event_handler)(unsigned long node_index) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the user + deselects a node. + - node_index == the index of the node that was deselected + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_node_deselected_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the user deselects a node. + - node_index == the index of the node that was deselected + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_node_deleted_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the user + deletes a node. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_node_deleted_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the user deletes a node. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_graph_modified_handler ( + T& object, + void (T::*event_handler)() + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the user + modifies the graph (i.e. adds or removes a node or edge) + - the event_handler function is not called when the user just + moves nodes around on the screen. + - This event is always dispatched before any more specific event + that results from the user modifying the graph. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_graph_modified_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the user modifies + the graph (i.e. adds or removes a node or edge) + - the event_handler function is not called when the user just + moves nodes around on the screen. + - This event is always dispatched before any more specific event + that results from the user modifying the graph. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + directed_graph_drawer(directed_graph_drawer&); // copy constructor + directed_graph_drawer& operator=(directed_graph_drawer&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class text_grid : public scrollable_region + { + /*! + INITIAL VALUE + - vertical_scroll_increment() == 10 + - horizontal_scroll_increment() == 10 + - border_color() == rgb_pixel(128,128,128) + - number_of_columns() == 0 + - number_of_rows() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a simple grid of square text fields that + looks more or less like a spreadsheet grid. + !*/ + + public: + + text_grid ( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~text_grid ( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_grid_size ( + unsigned long rows, + unsigned long cols + ); + /*! + ensures + - #number_of_rows() == rows + - #number_of_columns() == cols + - for all valid r and c: + - #text(r,c) == "" + - #text_color(r,c) == rgb_pixel(0,0,0) + - #background_color(r,c) == rgb_pixel(255,255,255) + - #is_editable(r,c) == true + !*/ + + unsigned long number_of_columns ( + ) const; + /*! + ensures + - returns the number of columns contained in this grid + !*/ + + unsigned long number_of_rows ( + ) const; + /*! + ensures + - returns the number of rows contained in this grid + !*/ + + rgb_pixel border_color ( + ) const; + /*! + ensures + - returns the color of the lines drawn between the grid elements + !*/ + + void set_border_color ( + rgb_pixel color + ); + /*! + ensures + - #border_color() == color + !*/ + + const std::wstring wtext (unsigned long row, unsigned long col) const; + const dlib::ustring utext (unsigned long row, unsigned long col) const; + const std::string text ( + unsigned long row, + unsigned long col + ) const; + /*! + requires + - row < number_of_rows() + - col < number_of_columns() + ensures + - returns the text in the given grid location + !*/ + + void set_text (unsigned long row, unsigned long col, const std::wstring& str); + void set_text (unsigned long row, unsigned long col, const dlib::ustring& str); + void set_text ( + unsigned long row, + unsigned long col, + const std::string& str + ); + /*! + requires + - row < number_of_rows() + - col < number_of_columns() + ensures + - #text(row,col) == str + !*/ + + const rgb_pixel text_color ( + unsigned long row, + unsigned long col + ) const; + /*! + requires + - row < number_of_rows() + - col < number_of_columns() + ensures + - returns the color of the text in the given grid location + !*/ + + void set_text_color ( + unsigned long row, + unsigned long col, + const rgb_pixel color + ); + /*! + requires + - row < number_of_rows() + - col < number_of_columns() + ensures + - #text_color(row,col) == color + !*/ + + const rgb_pixel background_color ( + unsigned long row, + unsigned long col + ) const; + /*! + requires + - row < number_of_rows() + - col < number_of_columns() + ensures + - returns the background color of the given grid location + !*/ + + void set_background_color ( + unsigned long row, + unsigned long col, + const rgb_pixel color + ); + /*! + requires + - row < number_of_rows() + - col < number_of_columns() + ensures + - #background_color(row,col) == color + !*/ + + bool is_editable ( + unsigned long row, + unsigned long col + ) const; + /*! + requires + - row < number_of_rows() + - col < number_of_columns() + ensures + - if (the given grid location is editable by the user) then + - returns true + - else + - returns false + !*/ + + void set_editable ( + unsigned long row, + unsigned long col, + bool editable + ); + /*! + requires + - row < number_of_rows() + - col < number_of_columns() + ensures + - #is_editable(row,col) == editable + !*/ + + void set_column_width ( + unsigned long col, + unsigned long width + ); + /*! + requires + - col < number_of_columns() + ensures + - the given column will be displayed such that it is width pixels wide + !*/ + + void set_row_height ( + unsigned long row, + unsigned long height + ); + /*! + requires + - row < number_of_rows() + ensures + - the given row will be displayed such that it is height pixels wide + !*/ + + template < + typename T + > + void set_text_modified_handler ( + T& object, + void (T::*event_handler)(unsigned long row, unsigned long col) + ); + /*! + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the user selects + a node. + - row == row will give the row of the grid item that was modified + - col == col will give the column of the grid item that was modified + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + void set_text_modified_handler ( + const any_function& event_handler + ); + /*! + ensures + - the event_handler function is called when the user selects a node. + - row == row will give the row of the grid item that was modified + - col == col will give the column of the grid item that was modified + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + text_grid(text_grid&); // copy constructor + text_grid& operator=(text_grid&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class image_display : public scrollable_region + { + /*! + INITIAL VALUE + - This object isn't displaying anything. + - get_overlay_rects().size() == 0 + - get_default_overlay_rect_label() == "" + - get_default_overlay_rect_color() == rgb_alpha_pixel(255,0,0,255) (i.e. RED) + - This object does not have any user labelable parts defined. + - overlay_editing_is_enabled() == true + + WHAT THIS OBJECT REPRESENTS + This object represents an image inside a scrollable region. + You give it an image to display by calling set_image(). + This widget also allows you to add rectangle and line overlays that + will be drawn on top of the image. + + If you hold the Ctrl key you can zoom in and out using the mouse wheel. + You can also add new overlay rectangles by holding shift, left clicking, + and dragging the mouse. Additionally, you can delete an overlay rectangle + by double clicking on it and hitting delete or backspace. Finally, you + can also add part labels (if they have been defined by calling add_labelable_part_name()) + by selecting an overlay rectangle with the mouse and then right clicking + on the part. If you want to move any rectangle or an object part then + shift+right click and drag it. + + Finally, if you hold Ctrl and left click an overlay rectangle it will + change its label to get_default_overlay_rect_label() and color to + get_default_overlay_rect_color(). + + The image is drawn such that: + - the pixel img[0][0] is the upper left corner of the image. + - the pixel img[img.nr()-1][0] is the lower left corner of the image. + - the pixel img[0][img.nc()-1] is the upper right corner of the image. + - the pixel img[img.nr()-1][img.nc()-1] is the lower right corner of the image. + !*/ + + public: + + image_display( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + !*/ + + ~image_display( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + template < + typename image_type + > + void set_image ( + const image_type& new_img + ); + /*! + requires + - image_type == an implementation of array2d/array2d_kernel_abstract.h or + a dlib::matrix or something convertible to a matrix via mat() + - pixel_traits must be defined + ensures + - #*this widget is now displaying the given image new_img. + !*/ + + struct overlay_rect + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a rectangle that is drawn on top of the + image shown by this object. Each rectangle is represented by + a rectangle object as well as a color and text label. The label + is drawn below the lower right corner of the rectangle. + + Moreover, the rectangle can have sub-parts. Each part is listed + in the parts member variable. This variable maps the name of the + part to its position. + + Rectangles with crossed_out == true will be drawn with an X through + them. + !*/ + + rectangle rect; + rgb_alpha_pixel color; + std::string label; + std::map parts; + bool crossed_out; + + overlay_rect( + ); + /*! + ensures + - #color == rgb_alpha_pixel(0,0,0,0) + - #rect == rectangle() + - #label.size() == 0 + - #crossed_out == false + !*/ + + template + overlay_rect( + const rectangle& r, + pixel_type p + ); + /*! + ensures + - #rect == r + - performs assign_pixel(color, p) + - #label.size() == 0 + - #crossed_out == false + !*/ + + template + overlay_rect( + const rectangle& r, + pixel_type p, + const std::string& l + ); + /*! + ensures + - #rect == r + - performs assign_pixel(color, p) + - #label == l + - #crossed_out == false + !*/ + + template + overlay_rect( + const rectangle& r, + pixel_type p, + const std::string& l, + const std::map& parts_ + ); + /*! + ensures + - #rect == r + - performs assign_pixel(color, p) + - #label == l + - #parts == parts_ + - #crossed_out == false + !*/ + + }; + + struct overlay_line + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a line that is drawn on top of the + image shown by this object. Each line is represented by + its two end points (p1 and p2) as well as a color. + !*/ + + point p1; + point p2; + rgb_alpha_pixel color; + + overlay_line( + ); + /*! + ensures + - #color == rgb_alpha_pixel(0,0,0,0) + - #p1 == point() + - #p2 == point() + !*/ + + template + overlay_line( + const point& p1_, + const point& p2_, + pixel_type p + ); + /*! + ensures + - performs assign_pixel(color, p) + - #p1 == p1_ + - #p2 == p2_ + !*/ + + }; + + struct overlay_circle + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a circle that is drawn on top of the + image shown by this object. Each circle is represented by + its center, radius, and color. It can also have an optional + text label which will appear below the circle. + !*/ + + point center; + int radius; + rgb_alpha_pixel color; + std::string label; + + overlay_circle( + ); + /*! + ensures + - #center == point(0,0) + - #radius == 0 + - #color == rgb_alpha_pixel(0,0,0,0) + - #label.size() == 0 + !*/ + + template + overlay_circle( + const point& center_, + const int radius_, + pixel_type p + ); + /*! + ensures + - performs assign_pixel(color, p) + - #center == center_ + - #radius == radius_ + !*/ + + template + overlay_circle( + const point& center_, + const int radius_, + pixel_type p, + const std::string& label_ + ); + /*! + ensures + - performs assign_pixel(color, p) + - #center == center_ + - #radius == radius_ + - #label == label_ + !*/ + + }; + + void add_overlay ( + const overlay_rect& overlay + ); + /*! + ensures + - adds the given overlay rectangle into this object such + that it will be displayed. + !*/ + + void add_overlay ( + const overlay_line& overlay + ); + /*! + ensures + - adds the given overlay line into this object such + that it will be displayed. + !*/ + + void add_overlay ( + const overlay_circle& overlay + ); + /*! + ensures + - adds the given overlay circle into this object such + that it will be displayed. + !*/ + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - adds the given set of overlay rectangles into this object such + that they will be displayed. + !*/ + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - adds the given set of overlay lines into this object such + that they will be displayed. + !*/ + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - adds the given set of overlay circles into this object such + that they will be displayed. + !*/ + + void clear_overlay ( + ); + /*! + ensures + - removes all overlays from this object. + - #get_overlay_rects().size() == 0 + !*/ + + std::vector get_overlay_rects ( + ) const; + /*! + ensures + - returns a copy of all the overlay_rect objects currently displayed. + !*/ + + void set_default_overlay_rect_label ( + const std::string& label + ); + /*! + ensures + - #get_default_overlay_rect_label() == label + !*/ + + std::string get_default_overlay_rect_label ( + ) const; + /*! + ensures + - returns the label given to new overlay rectangles created by the user + (i.e. when the user holds shift and adds them with the mouse) + !*/ + + void set_default_overlay_rect_color ( + const rgb_alpha_pixel& color + ); + /*! + ensures + - #get_default_overlay_rect_color() == color + !*/ + + rgb_alpha_pixel get_default_overlay_rect_color ( + ) const; + /*! + ensures + - returns the color given to new overlay rectangles created by the user + (i.e. when the user holds shift and adds them with the mouse) + !*/ + + void add_labelable_part_name ( + const std::string& name + ); + /*! + ensures + - adds a user labelable part with the given name. If the name has + already been added then this function has no effect. + - These parts can be added by the user by selecting an overlay box + and then right clicking anywhere in it. A popup menu will appear + listing the parts. The user can then click a part name and it will + add it into the overlay_rect::parts variable and also show it on the + screen. + !*/ + + void clear_labelable_part_names ( + ); + /*! + ensures + - removes all use labelable parts. Calling this function undoes + all previous calls to add_labelable_part_name(). Therefore, the + user won't be able to label any parts after clear_labelable_part_names() + is called. + !*/ + + rectangle get_image_display_rect ( + ) const; + /*! + ensures + - returns a rectangle R that tells you how big the image inside the + display is when it appears on the screen. Note that it takes the + current zoom level into account. + - R.width() == the width of the displayed image + - R.height() == the height of the displayed image + - R.tl_corner() == (0,0) + !*/ + + void enable_overlay_editing ( + ); + /*! + ensures + - #overlay_editing_is_enabled() == true + !*/ + + void disable_overlay_editing ( + ); + /*! + ensures + - #overlay_editing_is_enabled() == false + !*/ + + bool overlay_editing_is_enabled ( + ) const; + /*! + ensures + - if this function returns true then it is possible for the user to add or + remove overlay objects (e.g. rectangles) using the mouse and keyboard. + If it returns false then the overlay is not user editable. + !*/ + + template < + typename T + > + void set_overlay_rects_changed_handler ( + T& object, + void (T::*event_handler)() + ); + /* + requires + - event_handler is a valid pointer to a member function in T + ensures + - the event_handler function is called on object when the user adds, + removes, or modifies an overlay rectangle. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + */ + + void set_overlay_rects_changed_handler ( + const any_function& event_handler + ); + /* + ensures + - the event_handler function is called when the user adds or removes + an overlay rectangle. + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + */ + + template < + typename T + > + void set_overlay_rect_selected_handler ( + T& object, + void (T::*event_handler)(const overlay_rect& orect) + ); + /* + requires + - event_handler is a valid pointer to a member function in T + ensures + - The event_handler function is called on object when the user selects + an overlay rectangle by double clicking on it. The selected rectangle + will be passed to event_handler(). + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + */ + + void set_overlay_rect_selected_handler ( + const any_function& event_handler + ); + /* + ensures + - The event_handler function is called when the user selects an overlay + rectangle by double clicking on it. The selected rectangle will be + passed to event_handler(). + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + */ + + template < + typename T + > + void set_image_clicked_handler ( + T& object, + void (T::*event_handler)(const point& p, bool is_double_click, unsigned long btn) + ); + /* + requires + - event_handler is a valid pointer to a member function in T + ensures + - The event_handler function is called on object when the user left clicks + anywhere on the image. When they do so this callback is called with the + location of the image pixel which was clicked. The is_double_click bool + will also tell you if it was a double click or single click. + - btn == the button that was released. (either base_window::LEFT, base_window::MIDDLE, or base_window::RIGHT) + - any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + throws + - std::bad_alloc + */ + + void set_image_clicked_handler ( + const any_function& event_handler + ); + /* + ensures + - The event_handler function is called when the user left clicks anywhere + on the image. When they do so this callback is called with the location + of the image pixel which was clicked. The is_double_click bool will also + tell you if it was a double click or single click. + - btn == the button that was released. (either base_window::LEFT, base_window::MIDDLE, or base_window::RIGHT) + - Any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this event at a + time) + throws + - std::bad_alloc + */ + + private: + + // restricted functions + image_display(image_display&); // copy constructor + image_display& operator=(image_display&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class image_window : public drawable_window + { + /*! + INITIAL VALUE + - initially, this object is visible on the screen + - events_tied() == false + + WHAT THIS OBJECT REPRESENTS + This is a simple window that is just a container for an image_display. + It exists to make it easy to throw image_displays onto the screen + without having to put together your own drawable_window objects. + !*/ + public: + + typedef image_display::overlay_rect overlay_rect; + typedef image_display::overlay_line overlay_line; + typedef image_display::overlay_circle overlay_circle; + + image_window( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template + image_window( + const image_type& img + ); + /*! + requires + - image_type == an implementation of array2d/array2d_kernel_abstract.h or + a dlib::matrix or something convertible to a matrix via mat() + - pixel_traits must be defined + ensures + - this object is properly initialized + - #*this window is now displaying the given image img. + !*/ + + template < typename image_type> + image_window( + const image_type& img, + const std::string& title + ); + /*! + requires + - image_type == an implementation of array2d/array2d_kernel_abstract.h or + a dlib::matrix or something convertible to a matrix via mat() + - pixel_traits must be defined + ensures + - this object is properly initialized + - #*this window is now displaying the given image img. + - The title of the window will be set to the given title string. + !*/ + + ~image_window( + ); + /*! + ensures + - any resources associated with this object have been released + !*/ + + template + void set_image ( + const image_type& img + ); + /*! + requires + - image_type == an implementation of array2d/array2d_kernel_abstract.h + - pixel_traits must be defined + ensures + - #*this window is now displaying the given image img. + !*/ + + void add_overlay ( + const overlay_rect& overlay + ); + /*! + ensures + - adds the given overlay rectangle into this object such + that it will be displayed. + !*/ + + template + void add_overlay( + const rectangle& r, + pixel_type p = rgb_pixel(255,0,0) + ); + /*! + ensures + - performs: add_overlay(overlay_rect(r,p)); + !*/ + + template + void add_overlay( + const rectangle& r, + pixel_type p, + const std::string& l + ); + /*! + ensures + - performs: add_overlay(overlay_rect(r,p,l)); + !*/ + + template + void add_overlay( + const std::vector& r, + pixel_type p = rgb_pixel(255,0,0) + ); + /*! + ensures + - adds the given set of rectangles into this object such + that they will be displayed with the color specific by p. + !*/ + + void add_overlay( + const full_object_detection& object, + const std::vector& part_names + ); + /*! + ensures + - adds the given full_object_detection to the overlays + and shows it on the screen. This includes any of its + parts that are not set equal to OBJECT_PART_NOT_PRESENT. + - for all valid i < part_names.size(): + - the part object.part(i) will be labeled with the string + part_names[i]. + !*/ + + void add_overlay( + const full_object_detection& object + ); + /*! + ensures + - adds the given full_object_detection to the overlays + and shows it on the screen. This includes any of its + parts that are not set equal to OBJECT_PART_NOT_PRESENT. + !*/ + + void add_overlay( + const std::vector& objects, + const std::vector& part_names + ); + /*! + ensures + - calling this function is equivalent to calling the following + sequence of functions, for all valid i: + - add_overlay(objects[i], part_names); + !*/ + + void add_overlay( + const std::vector& objects + ); + /*! + ensures + - calling this function is equivalent to calling the following + sequence of functions, for all valid i: + - add_overlay(objects[i]); + !*/ + + void add_overlay ( + const overlay_line& overlay + ); + /*! + ensures + - adds the given overlay line into this object such + that it will be displayed. + !*/ + + void add_overlay ( + const overlay_circle& overlay + ); + /*! + ensures + - adds the given overlay circle into this object such + that it will be displayed. + !*/ + + template + void add_overlay( + const point& p1, + const point& p2, + pixel_type p + ); + /*! + ensures + - performs: add_overlay(overlay_line(p1,p2,p)); + !*/ + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - adds the given set of overlay rectangles into this object such + that they will be displayed. + !*/ + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - adds the given set of overlay lines into this object such + that they will be displayed. + !*/ + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - adds the given set of overlay circles into this object such + that they will be displayed. + !*/ + + void clear_overlay ( + ); + /*! + ensures + - removes all overlays from this object. + !*/ + + void tie_events ( + ); + /*! + ensures + - #events_tied() == true + !*/ + + void untie_events ( + ); + /*! + ensures + - #events_tied() == false + !*/ + + bool events_tied ( + ) const; + /*! + ensures + - returns true if and only if the get_next_double_click() and + get_next_keypress() events are tied together. If they are tied it means + that you can use a loop of the following form to listen for both events + simultaneously: + while (mywindow.get_next_double_click(p) || mywindow.get_next_keypress(key,printable)) + { + if (p.x() < 0) + // Do something with the keyboard event + else + // Do something with the mouse event + } + !*/ + + bool get_next_double_click ( + point& p + ); + /*! + ensures + - This function blocks until the user double clicks on the image or the + window is closed by the user. It will also unblock for a keyboard key + press if events_tied() == true. + - if (this function returns true) then + - This means the user double clicked the mouse. + - #p == the next image pixel the user clicked. + - else + - #p == point(-1,1) + !*/ + + bool get_next_double_click ( + point& p, + unsigned long& mouse_button + ); + /*! + ensures + - This function blocks until the user double clicks on the image or the + window is closed by the user. It will also unblock for a keyboard key + press if events_tied() == true. + - if (this function returns true) then + - This means the user double clicked the mouse. + - #p == the next image pixel the user clicked. + - #mouse_button == the mouse button which was used to double click. + This will be either dlib::base_window::LEFT, + dlib::base_window::MIDDLE, or dlib::base_window::RIGHT + - else + - #p == point(-1,1) + (Note that this point is outside any possible image) + !*/ + + bool get_next_keypress ( + unsigned long& key, + bool& is_printable, + unsigned long& state + ); + /*! + ensures + - This function blocks until the user presses a keyboard key or the + window is closed by the user. It will also unblock for a mouse double + click if events_tied() == true. + - if (this function returns true) then + - This means the user pressed a keyboard key. + - The keyboard button press is recorded into #key, #is_printable, and + #state. In particular, these variables are populated with the three + identically named arguments to the base_window::on_keydown(key,is_printable,state) + event. + !*/ + + bool get_next_keypress ( + unsigned long& key, + bool& is_printable + ); + /*! + ensures + - This function blocks until the user presses a keyboard key or the + window is closed by the user. It will also unblock for a mouse double + click if events_tied() == true. + - This function is the equivalent to calling get_next_keypress(key,is_printable,temp) + and then discarding temp. + !*/ + + private: + + // restricted functions + image_window(image_window&); + image_window& operator= (image_window&); + }; + +// ---------------------------------------------------------------------------------------- + + class perspective_display : public drawable, noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for displaying 3D point clouds on a screen. You can + navigate the display with the mouse. Left click and drag rotates the + camera around the displayed data. Scrolling the mouse wheel zooms and + shift+left click (or just right click) and drag pans the view around. + !*/ + + public: + + perspective_display( + drawable_window& w + ); + /*! + ensures + - #*this is properly initialized + - #*this has been added to window w + - #parent_window() == w + !*/ + + ~perspective_display( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_size ( + unsigned long width, + unsigned long height + ); + /*! + ensures + - #width() == width + - #height() == height + - #top() == top() + - #left() == left() + - i.e. The location of the upper left corner of this widget stays the + same but its width and height are modified. + !*/ + + struct overlay_line + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a line that is drawn on the screen. Each line + is represented by its two end points (p1 and p2) as well as a color. + !*/ + + overlay_line() { assign_pixel(color, 0);} + + overlay_line(const vector& p1_, const vector& p2_) + : p1(p1_), p2(p2_) { assign_pixel(color, 255); } + + template + overlay_line(const vector& p1_, const vector& p2_, pixel_type p) + : p1(p1_), p2(p2_) { assign_pixel(color, p); } + + vector p1; + vector p2; + rgb_pixel color; + }; + + struct overlay_dot + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a dot that is drawn on the screen. Each dot is + represented by one point and a color. + !*/ + + overlay_dot() { assign_pixel(color, 0);} + + overlay_dot(const vector& p_) + : p(p_) { assign_pixel(color, 255); } + + template + overlay_dot(const vector& p_, pixel_type color_) + : p(p_) { assign_pixel(color, color_); } + + vector p; // The location of the dot + rgb_pixel color; + }; + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - Adds the given overlay lines into this object such that it will be + displayed. + !*/ + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - Adds the given overlay dots into this object such that it will be + displayed. + !*/ + + void clear_overlay ( + ); + /*! + ensures + - Removes all overlays from this object. The display will be empty. + !*/ + + template + void set_dot_double_clicked_handler ( + T& object, + void (T::*event_handler)(const vector&) + ); + /* + requires + - event_handler is a valid pointer to a member function in T + ensures + - The event_handler function is called on object when the user double + clicks on one of the overlay dots. The selected dot will be passed to + event_handler(). + - Any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + */ + + void set_dot_double_clicked_handler ( + const any_function&)>& event_handler + ); + /* + ensures + - The event_handler function is called when the user double clicks on one + of the overlay dots. The selected dot will be passed to event_handler(). + - Any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + */ + }; + +// ---------------------------------------------------------------------------------------- + + class perspective_window : public drawable_window, noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple window that is just a container for a perspective_display. + It exists to make it easy to throw perspective_displays onto the screen + without having to put together your own drawable_window objects. + !*/ + public: + + typedef perspective_display::overlay_line overlay_line; + typedef perspective_display::overlay_dot overlay_dot; + + perspective_window( + ); + /*! + ensures + - The window is displayed on the screen and is 100x100 pixels in size. + !*/ + + perspective_window( + const std::vector >& point_cloud + ); + /*! + ensures + - The window is displayed on the screen and is 100x100 pixels in size. + - This window will have point_cloud added to it via add_overlay() and the + points will all be white. + !*/ + + perspective_window( + const std::vector >& point_cloud, + const std::string& title + ); + /*! + ensures + - The window is displayed on the screen and is 100x100 pixels in size. + - This window will have point_cloud added to it via add_overlay() and the + points will all be white. + - The title of the window will be set to the given title string. + !*/ + + ~perspective_window( + ); + /*! + ensures + - any resources associated with this object have been released + !*/ + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - Adds the given overlay lines into this object such that it will be + displayed. + !*/ + + void add_overlay ( + const std::vector& overlay + ); + /*! + ensures + - Adds the given overlay dots into this object such that it will be + displayed. + !*/ + + void clear_overlay ( + ); + /*! + ensures + - Removes all overlays from this object. The display will be empty. + !*/ + + void add_overlay( + const std::vector >& d + ); + /*! + ensures + - Adds the given dots into this object such that it will be + displayed. They will be colored white. + !*/ + + template + void add_overlay( + const std::vector >& d, + pixel_type p + ); + /*! + ensures + - Adds the given dots into this object such that it will be + displayed. They will be colored by pixel color p. + !*/ + + template + void add_overlay( + const vector& p1, + const vector& p2, + pixel_type color + ); + /*! + ensures + - Adds an overlay line going from p1 to p2 with the given color. + !*/ + + template < typename T > + void set_dot_double_clicked_handler ( + T& object, + void (T::*event_handler)(const vector&) + ); + /* + requires + - event_handler is a valid pointer to a member function in T + ensures + - The event_handler function is called on object when the user double + clicks on one of the overlay dots. The selected dot will be passed to + event_handler(). + - Any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + */ + + void set_dot_double_clicked_handler ( + const any_function&)>& event_handler + ); + /* + ensures + - The event_handler function is called when the user double clicks on one + of the overlay dots. The selected dot will be passed to event_handler(). + - Any previous calls to this function are overridden by this new call. + (i.e. you can only have one event handler associated with this + event at a time) + */ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_WIDGETs_ABSTRACT_ + diff --git a/ml/dlib/dlib/hash.h b/ml/dlib/dlib/hash.h new file mode 100644 index 000000000..5a018b438 --- /dev/null +++ b/ml/dlib/dlib/hash.h @@ -0,0 +1,14 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASh_ +#define DLIB_HASh_ + + +#include "general_hash/hash.h" +#include "general_hash/random_hashing.h" +#include "general_hash/count_bits.h" + + +#endif // DLIB_HASh_ + + diff --git a/ml/dlib/dlib/hash_map.h b/ml/dlib/dlib/hash_map.h new file mode 100644 index 000000000..225ebd466 --- /dev/null +++ b/ml/dlib/dlib/hash_map.h @@ -0,0 +1,63 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASH_MAp_ +#define DLIB_HASH_MAp_ + +#include "hash_map/hash_map_kernel_1.h" +#include "hash_map/hash_map_kernel_c.h" + +#include "hash_table.h" +#include "algs.h" + +#include "algs.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + unsigned long expnum, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class hash_map + { + hash_map() {} + + typedef typename hash_table::kernel_1a + hash_table_1; + typedef typename hash_table::kernel_2a + hash_table_2; + typedef typename hash_table::kernel_2b + hash_table_3; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef hash_map_kernel_1 + kernel_1a; + typedef hash_map_kernel_c + kernel_1a_c; + + // kernel_1b + typedef hash_map_kernel_1 + kernel_1b; + typedef hash_map_kernel_c + kernel_1b_c; + + // kernel_1c + typedef hash_map_kernel_1 + kernel_1c; + typedef hash_map_kernel_c + kernel_1c_c; + + + }; +} + +#endif // DLIB_HASH_MAp_ + diff --git a/ml/dlib/dlib/hash_map/hash_map_kernel_1.h b/ml/dlib/dlib/hash_map/hash_map_kernel_1.h new file mode 100644 index 000000000..e8b1d6d71 --- /dev/null +++ b/ml/dlib/dlib/hash_map/hash_map_kernel_1.h @@ -0,0 +1,460 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASH_MAP_KERNEl_1_ +#define DLIB_HASH_MAP_KERNEl_1_ + +#include "hash_map_kernel_abstract.h" +#include "../algs.h" +#include "../general_hash/general_hash.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/remover.h" +#include "../assert.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager = default_memory_manager + > + class hash_map_kernel_1 : public enumerable >, + public pair_remover + { + + /*! + REQUIREMENTS ON hash_table + hash_table is instantiated with domain and range and + T_is_POD must be set to false and + implements hash_table/hash_table_kernel_abstract.h + + INITIAL VALUE + table.size() == 0 + + CONVENTION + table.size() = size() == the number of elements in the map + the elements in this hash_map are stored in table + !*/ + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef typename hash_table::compare_type compare_type; + typedef mem_manager mem_manager_type; + + hash_map_kernel_1( + ) : + table(expnum) + { + COMPILE_TIME_ASSERT(expnum < 32); + } + + virtual ~hash_map_kernel_1( + ) + {} + + inline void clear( + ); + + void add ( + domain& d, + range& r + ); + + inline bool is_in_domain ( + const domain& d + ) const; + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + range& operator[] ( + const domain& d + ); + + const range& operator[] ( + const domain& d + ) const; + + inline void swap ( + hash_map_kernel_1& item + ); + + // functions from the remover interface + inline void remove_any ( + domain& d, + range& r + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + inline bool current_element_valid ( + ) const; + + inline const map_pair& element ( + ) const; + + inline map_pair& element ( + ); + + inline bool move_next ( + ) const; + + private: + + hash_table table; + + // restricted functions + hash_map_kernel_1(hash_map_kernel_1&); + hash_map_kernel_1& operator= ( hash_map_kernel_1&); + + }; + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + inline void swap ( + hash_map_kernel_1& a, + hash_map_kernel_1& b + ) { a.swap(b); } + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void deserialize ( + hash_map_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type hash_map_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_map_kernel_1:: + clear ( + ) + { + table.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_map_kernel_1:: + add ( + domain& d, + range& r + ) + { + table.add(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + bool hash_map_kernel_1:: + is_in_domain( + const domain& d + ) const + { + return (table[d] != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_map_kernel_1:: + remove_any ( + domain& d, + range& r + ) + { + table.remove_any(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_map_kernel_1:: + remove( + const domain& d, + domain& d_copy, + range& r + ) + { + table.remove(d,d_copy,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_map_kernel_1:: + destroy( + const domain& d + ) + { + table.destroy(d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + range& hash_map_kernel_1:: + operator[]( + const domain& d + ) + { + return *table[d]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + const range& hash_map_kernel_1:: + operator[]( + const domain& d + ) const + { + return *table[d]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + size_t hash_map_kernel_1:: + size ( + ) const + { + return table.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_map_kernel_1:: + swap ( + hash_map_kernel_1& item + ) + { + table.swap(item.table); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + bool hash_map_kernel_1:: + at_start ( + ) const + { + return table.at_start(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_map_kernel_1:: + reset ( + ) const + { + table.reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + bool hash_map_kernel_1:: + current_element_valid ( + ) const + { + return table.current_element_valid(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + const map_pair& hash_map_kernel_1:: + element ( + ) const + { + return table.element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + map_pair& hash_map_kernel_1:: + element ( + ) + { + return table.element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + bool hash_map_kernel_1:: + move_next ( + ) const + { + return table.move_next(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HASH_MAP_KERNEl_1_ + diff --git a/ml/dlib/dlib/hash_map/hash_map_kernel_abstract.h b/ml/dlib/dlib/hash_map/hash_map_kernel_abstract.h new file mode 100644 index 000000000..cee404f54 --- /dev/null +++ b/ml/dlib/dlib/hash_map/hash_map_kernel_abstract.h @@ -0,0 +1,247 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_HASH_MAP_KERNEl_ABSTRACT_ +#ifdef DLIB_HASH_MAP_KERNEl_ABSTRACT_ + +#include "../general_hash/general_hash.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../interfaces/map_pair.h" +#include "../serialize.h" +#include "../algs.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + unsigned long expnum, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class hash_map : public enumerable >, + public pair_remover + { + + /*! + REQUIREMENTS ON domain + domain must be comparable by compare where compare is a functor compatible with std::less and + domain must be hashable by general_hash + (general_hash is defined in dlib/general_hash) and + domain must be swappable by a global swap() and + domain must have a default constructor + + REQUIREMENTS ON range + range must be swappable by a global swap() and + range must have a default constructor + + REQUIREMENTS ON expnum + expnum < 32 + 2^expnum is the number of buckets to hash items of type T into. + Note that this is really just a suggestion to the hash table. + Implementations are free to manage the table size however is most + appropriate. + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + swap(), is_in_domain(), and operator[] functions do + not invalidate pointers or references to internal data. + All other functions have no such guarantees. + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + No order is specified. Only that each element will be visited once + and only once. + + WHAT THIS OBJECT REPRESENTS + hash_map contains items of type domain and range + + This object is similar an array. It maps items of type domain on to + items of type range. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + hash_map( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + !*/ + + virtual ~hash_map( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void add ( + domain& d, + range& r + ); + /*! + requires + - &d != &r (i.e. d and r cannot be the same variable) + - is_in_domain(d) == false + ensures + - #is_in_domain(d) == true + - #operator[](d) == r + - #d and #r have initial values for their types + - #size() == size() + 1 + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if add() throws then it has no effect + !*/ + + bool is_in_domain ( + const domain& d + ) const; + /*! + ensures + - returns whether or not an element equivalent to d is in the + domain of *this + !*/ + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - &d != &r (i.e. d and r cannot be the same variable) + - &d != &d_copy (i.e. d and d_copy cannot be the same variable) + - &r != &d_copy (i.e. r and d_copy cannot be the same variable) + - is_in_domain(d) == true + ensures + - #is_in_domain(d) == false + - #d_copy is equivalent to d + - the element in the range of *this associated with #d_copy has + been swapped into #r + - #size() == size() - 1 + - #at_start() == true + !*/ + + void destroy ( + const domain& d + ); + /*! + requires + - is_in_domain(d) == true + ensures + - #is_in_domain(d) == false + - #size() == size() - 1 + - #at_start() == true + !*/ + + range& operator[] ( + const domain& d + ); + /*! + requires + - is_in_domain(d) == true + ensures + - returns a non-const reference to the element in the range of *this + associated with the element equivalent to d + !*/ + + const range& operator[] ( + const domain& d + ) const; + /*! + requires + - is_in_domain(d) == true + ensures + - returns a const reference to the element in the range of *this + associated with the element equivalent to d + !*/ + + void swap ( + hash_map& item + ); + /*! + ensures + - swaps *this and item + !*/ + + + private: + + // restricted functions + hash_map(hash_map&); + hash_map& operator=(hash_map&); + }; + + template < + typename domain, + typename range, + unsigned long expnum, + typename mem_manager, + typename compare + > + inline void swap ( + hash_map& a, + hash_map& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename domain, + typename range, + unsigned long expnum, + typename mem_manager, + typename compare + > + void deserialize ( + hash_map& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_HASH_MAP_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/hash_map/hash_map_kernel_c.h b/ml/dlib/dlib/hash_map/hash_map_kernel_c.h new file mode 100644 index 000000000..4db937249 --- /dev/null +++ b/ml/dlib/dlib/hash_map/hash_map_kernel_c.h @@ -0,0 +1,276 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASH_MAP_KERNEl_C_ +#define DLIB_HASH_MAP_KERNEl_C_ + +#include "hash_map_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename hash_map_base + > + class hash_map_kernel_c : public hash_map_base + { + + typedef typename hash_map_base::domain_type domain; + typedef typename hash_map_base::range_type range; + + + public: + void add ( + domain& d, + range& r + ); + + void remove_any ( + domain& d, + range& r + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + range& operator[] ( + const domain& d + ); + + const range& operator[] ( + const domain& d + ) const; + + const map_pair& element ( + ) const; + + map_pair& element ( + ); + }; + + template < + typename hash_map_base + > + inline void swap ( + hash_map_kernel_c& a, + hash_map_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename hash_map_base + > + void hash_map_kernel_c:: + add ( + domain& d, + range& r + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT( (!this->is_in_domain(d)) && + (static_cast(&d) != static_cast(&r)), + "\tvoid hash_map::add" + << "\n\tdomain element being added must not already be in the hash_map" + << "\n\tand d and r must not be the same variable" + << "\n\tis_in_domain(d): " << (this->is_in_domain(d) ? "true" : "false") + << "\n\tthis: " << this + << "\n\t&d: " << static_cast(&d) + << "\n\t&r: " << static_cast(&r) + ); + + + // call the real function + hash_map_base::add(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_map_base + > + void hash_map_kernel_c:: + remove_any ( + domain& d, + range& r + ) + { + + + // make sure requires clause is not broken + DLIB_CASSERT( (this->size() > 0) && + (static_cast(&d) != static_cast(&r)), + "\tvoid hash_map::remove_any" + << "\n\tsize() must be greater than zero if something is going to be removed" + << "\n\tand d and r must not be the same variable." + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + << "\n\t&d: " << static_cast(&d) + << "\n\t&r: " << static_cast(&r) + ); + + + // call the real function + hash_map_base::remove_any(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_map_base + > + void hash_map_kernel_c:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + + + // make sure requires clause is not broken + DLIB_CASSERT( (this->is_in_domain(d)) && + (static_cast(&d) != static_cast(&r)) && + (static_cast(&r) != static_cast(&d_copy)) && + (static_cast(&d) != static_cast(&d_copy)), + "\tvoid hash_map::remove" + << "\n\tcan't remove something that isn't in the hash_map or if the paremeters" + << "\n\tare actually the same variable. Either way can't remove." + << "\n\tis_in_domain(d): " << (this->is_in_domain(d) ? "true" : "false") + << "\n\tthis: " << this + << "\n\t&d: " << static_cast(&d) + << "\n\t&r: " << static_cast(&r) + << "\n\t&d_copy: " << static_cast(&d_copy) + ); + + + // call the real function + hash_map_base::remove(d,d_copy,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_map_base + > + void hash_map_kernel_c:: + destroy ( + const domain& d + ) + { + + + // make sure requires clause is not broken + DLIB_CASSERT( this->is_in_domain(d), + "\tvoid hash_map::destroy" + << "\n\tcan't remove something that isn't in the hash_map" + << "\n\tthis: " << this + << "\n\t&d: " << static_cast(&d) + ); + + + // call the real function + hash_map_base::destroy(d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_map_base + > + typename hash_map_base::range_type& hash_map_kernel_c:: + operator[] ( + const domain& d + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->is_in_domain(d), + "\trange& hash_map::operator[]" + << "\n\td must be in the domain of the hash_map" + << "\n\tthis: " << this + ); + + // call the real function + return hash_map_base::operator[](d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_map_base + > + const typename hash_map_base::range_type& hash_map_kernel_c:: + operator[] ( + const domain& d + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->is_in_domain(d), + "\tconst range& hash_map::operator[]" + << "\n\td must be in the domain of the hash_map" + << "\n\tthis: " << this + ); + + // call the real function + return hash_map_base::operator[](d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_map_base + > + const map_pair& hash_map_kernel_c:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst map_pair& hash_map::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return hash_map_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_map_base + > + map_pair& hash_map_kernel_c:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tmap_pair& hash_map::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return hash_map_base::element(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HASH_MAP_KERNEl_C_ + diff --git a/ml/dlib/dlib/hash_set.h b/ml/dlib/dlib/hash_set.h new file mode 100644 index 000000000..90a51fd20 --- /dev/null +++ b/ml/dlib/dlib/hash_set.h @@ -0,0 +1,63 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASH_SEt_ +#define DLIB_HASH_SEt_ + +#include "hash_set/hash_set_kernel_1.h" +#include "hash_set/hash_set_kernel_c.h" + +#include "hash_table.h" +#include "algs.h" + + +#include "algs.h" +#include + + +namespace dlib +{ + + template < + typename T, + unsigned long expnum, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class hash_set + { + hash_set() {} + + typedef typename hash_table::kernel_1a ht1a; + typedef typename hash_table::kernel_1a ht2a; + typedef typename hash_table::kernel_1a ht2b; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef hash_set_kernel_1 + kernel_1a; + typedef hash_set_kernel_c + kernel_1a_c; + + // kernel_1b + typedef hash_set_kernel_1 + kernel_1b; + typedef hash_set_kernel_c + kernel_1b_c; + + // kernel_1c + typedef hash_set_kernel_1 + kernel_1c; + typedef hash_set_kernel_c + kernel_1c_c; + + + + + }; +} + +#endif // DLIB_HASH_SEt_ + diff --git a/ml/dlib/dlib/hash_set/hash_set_kernel_1.h b/ml/dlib/dlib/hash_set/hash_set_kernel_1.h new file mode 100644 index 000000000..c40770b91 --- /dev/null +++ b/ml/dlib/dlib/hash_set/hash_set_kernel_1.h @@ -0,0 +1,391 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASH_SET_KERNEl_1_ +#define DLIB_HASH_SET_KERNEl_1_ + +#include "hash_set_kernel_abstract.h" +#include "../algs.h" +#include "../general_hash/general_hash.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../assert.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager = default_memory_manager + > + class hash_set_kernel_1 : public enumerable, + public remover + { + + /*! + REQUIREMENTS ON hash_table + hash_table is instantiated with and + T_is_POD must be set to false and + is an implementation of hash_table/hash_table_kernel_abstract.h + + INITIAL VALUE + table.size() == 0 + + CONVENTION + table.size() = size() == the number of elements in the set and + the elements in this hash_set are stored in table + !*/ + + public: + + typedef T type; + typedef typename hash_table::compare_type compare_type; + typedef mem_manager mem_manager_type; + + hash_set_kernel_1( + ) : + table(expnum) + { + COMPILE_TIME_ASSERT(expnum < 32); + } + + virtual ~hash_set_kernel_1( + ) + {} + + inline void clear( + ); + + inline void add ( + T& item + ); + + inline bool is_member ( + const T& item + ) const; + + inline void remove ( + const T& item, + T& item_copy + ); + + inline void destroy ( + const T& item + ); + + inline void swap ( + hash_set_kernel_1& item + ); + + // functions from the remover interface + inline void remove_any ( + T& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + inline bool current_element_valid ( + ) const; + + inline const T& element ( + ) const; + + inline const T& element ( + ); + + inline bool move_next ( + ) const; + + private: + + hash_table table; + char junk; + + // restricted functions + hash_set_kernel_1(hash_set_kernel_1&); + hash_set_kernel_1& operator= ( hash_set_kernel_1&); + + }; + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + inline void swap ( + hash_set_kernel_1& a, + hash_set_kernel_1& b + ) { a.swap(b); } + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void deserialize ( + hash_set_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + T temp; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(temp,in); + item.add(temp); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type hash_set_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_set_kernel_1:: + clear ( + ) + { + table.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_set_kernel_1:: + add ( + T& item + ) + { + table.add(item,junk); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + bool hash_set_kernel_1:: + is_member( + const T& item + ) const + { + return (table[item] != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_set_kernel_1:: + remove_any ( + T& item + ) + { + table.remove_any(item,junk); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_set_kernel_1:: + remove( + const T& item, + T& item_copy + ) + { + table.remove(item,item_copy,junk); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_set_kernel_1:: + destroy( + const T& item + ) + { + table.destroy(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + size_t hash_set_kernel_1:: + size ( + ) const + { + return table.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_set_kernel_1:: + swap ( + hash_set_kernel_1& item + ) + { + table.swap(item.table); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + bool hash_set_kernel_1:: + at_start ( + ) const + { + return table.at_start(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + void hash_set_kernel_1:: + reset ( + ) const + { + table.reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + bool hash_set_kernel_1:: + current_element_valid ( + ) const + { + return table.current_element_valid(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + const T& hash_set_kernel_1:: + element ( + ) const + { + return table.element().key(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + const T& hash_set_kernel_1:: + element ( + ) + { + return table.element().key(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long expnum, + typename hash_table, + typename mem_manager + > + bool hash_set_kernel_1:: + move_next ( + ) const + { + return table.move_next(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HASH_SET_KERNEl_1_ + diff --git a/ml/dlib/dlib/hash_set/hash_set_kernel_abstract.h b/ml/dlib/dlib/hash_set/hash_set_kernel_abstract.h new file mode 100644 index 000000000..35d1de74e --- /dev/null +++ b/ml/dlib/dlib/hash_set/hash_set_kernel_abstract.h @@ -0,0 +1,207 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_HASH_SET_KERNEl_ABSTRACT_ +#ifdef DLIB_HASH_SET_KERNEl_ABSTRACT_ + +#include "../general_hash/general_hash.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" +#include + +namespace dlib +{ + + template < + typename T, + unsigned long expnum, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class hash_set : public enumerable, + public remover + { + + /*! + REQUIREMENTS ON T + domain must be comparable by compare where compare is a functor compatible with std::less and + T must be hashable by general_hash + (general_hash is defined in dlib/general_hash) and + T must be swappable by a global swap() and + T must have a default constructor + + REQUIREMENTS ON expnum + expnum < 32 + 2^expnum is the number of buckets to hash items of type T into. + Note that this is really just a suggestion to the hash table. + Implementations are free to manage the table size however is most + appropriate. + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + swap() and is_member() functions do not invalidate + pointers or references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + No order is specified. Only that each element will be visited once + and only once. + + WHAT THIS OBJECT REPRESENTS + hash_set contains items of type T + + This object represents an unaddressed collection + of items. Every element in a hash_set is unique. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + public: + + typedef T type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + hash_set( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + virtual ~hash_set( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void add ( + T& item + ); + /*! + requires + - is_member(item) == false + ensures + - #is_member(item) == true + - #item has an initial value for its type + - #size() == size() + 1 + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + if add() throws then it has no effect + !*/ + + bool is_member ( + const T& item + ) const; + /*! + ensures + - returns whether or not there is an element in *this equivalent + to item + !*/ + + void remove ( + const T& item, + T& item_copy + ); + /*! + requires + - is_member(item) == true + - &item != &item_copy (i.e. item and item_copy cannot be the + same variable) + ensures + - #is_member(item) == false + - the element in *this equivalent to item has been removed and + swapped into #item_copy + - #size() == size() - 1 + - #at_start() == true + !*/ + + void destroy ( + const T& item + ); + /*! + requires + - is_member(item) == true + ensures + - #is_member(item) == false + - #size() == size() - 1 + - #at_start() == true + !*/ + + void swap ( + hash_set& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + hash_set(hash_set&); // copy constructor + hash_set& operator=(hash_set&); // assignment operator + + }; + + template < + typename T, + unsigned long expnum, + typename mem_manager, + typename compare + > + inline void swap ( + hash_set& a, + hash_set& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + unsigned long expnum, + typename mem_manager, + typename compare + > + void deserialize ( + hash_set& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_HASH_SET_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/hash_set/hash_set_kernel_c.h b/ml/dlib/dlib/hash_set/hash_set_kernel_c.h new file mode 100644 index 000000000..bd0abd848 --- /dev/null +++ b/ml/dlib/dlib/hash_set/hash_set_kernel_c.h @@ -0,0 +1,190 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASH_SET_KERNEl_C_ +#define DLIB_HASH_SET_KERNEl_C_ + +#include "hash_set_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename hash_set_base + > + class hash_set_kernel_c : public hash_set_base + { + typedef typename hash_set_base::type T; + public: + + void add ( + T& item + ); + + void remove_any ( + T& item + ); + + void remove ( + const T& item, + T& item_copy + ); + + void destroy ( + const T& item + ); + + const T& element ( + ) const; + + const T& element ( + ); + + + }; + + + template < + typename hash_set_base + > + inline void swap ( + hash_set_kernel_c& a, + hash_set_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename hash_set_base + > + void hash_set_kernel_c:: + add( + T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !this->is_member(item), + "\tvoid hash_set::add" + << "\n\titem being added must not already be in the hash_set" + << "\n\tthis: " << this + ); + + // call the real function + hash_set_base::add(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_set_base + > + void hash_set_kernel_c:: + remove ( + const T& item, + T& item_copy + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->is_member(item) && + (static_cast(&item) != static_cast(&item_copy)), + "\tvoid hash_set::remove" + << "\n\titem should be in the hash_set if it's going to be removed" + << "\n\tthis: " << this + << "\n\t&item: " << &item + << "\n\t&item_copy: " << &item_copy + ); + + // call the real function + hash_set_base::remove(item,item_copy); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_set_base + > + void hash_set_kernel_c:: + destroy ( + const T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->is_member(item), + "\tvoid hash_set::destroy" + << "\n\titem should be in the hash_set if it's going to be removed" + << "\n\tthis: " << this + << "\n\t&item: " << &item + ); + + // call the real function + hash_set_base::destroy(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_set_base + > + void hash_set_kernel_c:: + remove_any ( + T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->size() != 0, + "\tvoid hash_set::remove_any" + << "\n\tsize must be greater than zero if an item is to be removed" + << "\n\tthis: " << this + ); + + // call the real function + hash_set_base::remove_any(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_set_base + > + const typename hash_set_base::type& hash_set_kernel_c:: + element ( + ) const + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst T& hash_set::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return hash_set_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename hash_set_base + > + const typename hash_set_base::type& hash_set_kernel_c:: + element ( + ) + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tT& hash_set::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return hash_set_base::element(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HASH_SET_KERNEl_C_ + diff --git a/ml/dlib/dlib/hash_table.h b/ml/dlib/dlib/hash_table.h new file mode 100644 index 000000000..b49feb916 --- /dev/null +++ b/ml/dlib/dlib/hash_table.h @@ -0,0 +1,60 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASH_TABLe_ +#define DLIB_HASH_TABLe_ + + +#include "hash_table/hash_table_kernel_1.h" +#include "hash_table/hash_table_kernel_2.h" +#include "hash_table/hash_table_kernel_c.h" +#include "algs.h" + +#include "binary_search_tree.h" +#include + + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class hash_table + { + hash_table() {} + + typedef typename binary_search_tree::kernel_1a + bst_1; + typedef typename binary_search_tree::kernel_2a + bst_2; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef hash_table_kernel_1 + kernel_1a; + typedef hash_table_kernel_c + kernel_1a_c; + + + // kernel_2a + typedef hash_table_kernel_2 + kernel_2a; + typedef hash_table_kernel_c + kernel_2a_c; + + // kernel_2b + typedef hash_table_kernel_2 + kernel_2b; + typedef hash_table_kernel_c + kernel_2b_c; + }; +} + +#endif // DLIB_HASH_TABLe_ + diff --git a/ml/dlib/dlib/hash_table/hash_table_kernel_1.h b/ml/dlib/dlib/hash_table/hash_table_kernel_1.h new file mode 100644 index 000000000..f06847a72 --- /dev/null +++ b/ml/dlib/dlib/hash_table/hash_table_kernel_1.h @@ -0,0 +1,819 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASH_TABLE_KERNEl_1_ +#define DLIB_HASH_TABLE_KERNEl_1_ + +#include "hash_table_kernel_abstract.h" +#include "../general_hash/general_hash.h" +#include "../algs.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../assert.h" +#include "../serialize.h" +#include + + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class hash_table_kernel_1 : public enumerable >, + public pair_remover + { + + /*! + INITIAL VALUE + hash_size == 0 + table == pointer to an array of num_of_buckets node pointers + num_of_buckets == the number of buckets in the hash table + current_element == 0 + at_start_ == true + mask == num_of_buckets-1 + + CONVENTION + current_element_valid() == (current_element != 0) + element() == current_element->d and current_element->r + at_start_ == at_start() + if (current_element != 0) then + table[current_bucket] == a pointer to the linked list that contains + the node pointed to by current_element + + mask == num_of_buckets-1 + + + + hash_size = size() == the number of elements in the hash_table and + table == pointer to an array of num_of_buckets node pointers and + num_of_buckets == the number of buckets in the hash table and + for all i: + table[i] == pointer to the first node in a linked list or + table[i] == 0 if this bucket is currently not in use + + + for all nodes: + d == the domain element stored in this node + r == the range element stored in this node which is associated with + d. + next == pointer to the next node in the linked list or + next == 0 if this is the last node in the linked list + + !*/ + + struct node + { + node* next; + domain d; + range r; + }; + + + class mpair : public map_pair + { + public: + const domain* d; + range* r; + + const domain& key( + ) const { return *d; } + + const range& value( + ) const { return *r; } + + range& value( + ) { return *r; } + }; + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + explicit hash_table_kernel_1( + unsigned long expnum + ); + + virtual ~hash_table_kernel_1( + ); + + void clear( + ); + + unsigned long count ( + const domain& item + ) const; + + void add ( + domain& d, + range& r + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + const range* operator[] ( + const domain& d + ) const; + + range* operator[] ( + const domain& d + ); + + void swap ( + hash_table_kernel_1& item + ); + + // functions from the remover interface + void remove_any ( + domain& d, + range& r + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const map_pair& element ( + ) const; + + map_pair& element ( + ); + + bool move_next ( + ) const; + + private: + + // data members + typename mem_manager::template rebind::other pool; + typename mem_manager::template rebind::other ppool; + unsigned long hash_size; + node** table; + general_hash hash; + unsigned long num_of_buckets; + unsigned long mask; + + mutable mpair p; + + mutable unsigned long current_bucket; + mutable node* current_element; + mutable bool at_start_; + compare comp; + + // restricted functions + hash_table_kernel_1(hash_table_kernel_1&); + hash_table_kernel_1& operator=(hash_table_kernel_1&); + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + hash_table_kernel_1& a, + hash_table_kernel_1& b + ) { a.swap(b); } + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + hash_table_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type hash_table_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + hash_table_kernel_1:: + hash_table_kernel_1( + unsigned long expnum + ) : + hash_size(0), + current_element(0), + at_start_(true) + { + + num_of_buckets = 1; + while (expnum != 0) + { + --expnum; + num_of_buckets <<= 1; + } + mask = num_of_buckets-1; + + table = ppool.allocate_array(num_of_buckets); + for (unsigned long i = 0; i < num_of_buckets; ++i) + { + table[i] = 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + hash_table_kernel_1:: + ~hash_table_kernel_1( + ) + { + for (unsigned long i = 0; i < num_of_buckets; ++i) + { + // delete this linked list + node* temp = table[i]; + while (temp) + { + node* t = temp; + temp = temp->next; + pool.deallocate(t); + } + table[i] = 0; + } + ppool.deallocate_array(table); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void hash_table_kernel_1:: + clear( + ) + { + if (hash_size > 0) + { + for (unsigned long i = 0; i < num_of_buckets; ++i) + { + // delete this linked list + node* temp = table[i]; + while (temp) + { + node* t = temp; + temp = temp->next; + pool.deallocate(t); + } + table[i] = 0; + } + hash_size = 0; + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + size_t hash_table_kernel_1:: + size( + ) const + { + return hash_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long hash_table_kernel_1:: + count( + const domain& d + ) const + { + unsigned long items_found = 0; + node* temp = table[hash(d)&mask]; + + while (temp != 0) + { + // look for an element equivalent to d + if ( !(comp(temp->d , d) || comp(d , temp->d)) ) + { + ++items_found; + } + temp = temp->next; + } + + return items_found; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void hash_table_kernel_1:: + add( + domain& d, + range& r + ) + { + unsigned long hash_value = hash(d)&mask; + + // make a new node for this item + node& temp = *(pool.allocate()); + exchange(d,temp.d); + exchange(r,temp.r); + + // add this new node to the head of the linked list in bucket number hash_value + temp.next = table[hash_value]; + table[hash_value] = &temp; + + ++hash_size; + + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void hash_table_kernel_1:: + destroy( + const domain& d + ) + { + node* last; + const unsigned long hash_value = hash(d)&mask; + node* temp = table[hash_value]; + + // if there is more than one thing in this bucket + if (temp->next != 0) + { + // start looking with the second item in the list + last = temp; + temp = temp->next; + while (true) + { + // if we hit the end of the list without finding item then it must + // be the first element in the list so splice it out + if (temp == 0) + { + temp = table[hash_value]; + table[hash_value] = temp->next; + + break; + } + + // look for an element equivalent to item + if ( !(comp(temp->d , d) || comp(d , temp->d)) ) + { + // splice out the node we want to remove + last->next = temp->next; + break; + } + + last = temp; + temp = temp->next; + } + + } + // else there is only one node in this linked list + else + { + table[hash_value] = 0; + } + + pool.deallocate(temp); + + --hash_size; + + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void hash_table_kernel_1:: + remove( + const domain& d, + domain& d_copy, + range& r + ) + { + node* last; + const unsigned long hash_value = hash(d)&mask; + node* temp = table[hash_value]; + + // if there is more than one thing in this bucket + if (temp->next != 0) + { + // start looking with the second item in the list + last = temp; + temp = temp->next; + while (true) + { + // if we hit the end of the list without finding item then it must + // be the first element in the list so splice it out + if (temp == 0) + { + temp = table[hash_value]; + table[hash_value] = temp->next; + + break; + } + + // look for an element equivalent to item + if ( !(comp(temp->d , d) || comp(d , temp->d)) ) + { + // splice out the node we want to remove + last->next = temp->next; + break; + } + + last = temp; + temp = temp->next; + } + + } + // else there is only one node in this linked list + else + { + table[hash_value] = 0; + } + + + exchange(d_copy,temp->d); + exchange(r,temp->r); + pool.deallocate(temp); + + --hash_size; + + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void hash_table_kernel_1:: + remove_any( + domain& d, + range& r + ) + { + unsigned long i = 0; + + // while the ith bucket is empty keep looking + while (table[i] == 0) + { + ++i; + } + + // remove the first node in the linked list in the ith bucket + node& temp = *(table[i]); + + exchange(temp.d,d); + exchange(temp.r,r); + table[i] = temp.next; + + pool.deallocate(&temp); + + --hash_size; + + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* hash_table_kernel_1:: + operator[]( + const domain& d + ) const + { + node* temp = table[hash(d)&mask]; + + while (temp != 0) + { + // look for an element equivalent to item + if ( !(comp(temp->d , d) || comp(d , temp->d)) ) + return &(temp->r); + + temp = temp->next; + } + + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* hash_table_kernel_1:: + operator[]( + const domain& d + ) + { + node* temp = table[hash(d)&mask]; + + while (temp != 0) + { + // look for an element equivalent to item + if ( !(comp(temp->d , d) || comp(d , temp->d)) ) + return &(temp->r); + + temp = temp->next; + } + + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void hash_table_kernel_1:: + swap( + hash_table_kernel_1& item + ) + { + exchange(mask,item.mask); + exchange(table,item.table); + exchange(hash_size,item.hash_size); + exchange(num_of_buckets,item.num_of_buckets); + exchange(current_bucket,item.current_bucket); + exchange(current_element,item.current_element); + exchange(at_start_,item.at_start_); + pool.swap(item.pool); + ppool.swap(item.ppool); + exchange(p,item.p); + exchange(comp,item.comp); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool hash_table_kernel_1:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void hash_table_kernel_1:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool hash_table_kernel_1:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const map_pair& hash_table_kernel_1:: + element ( + ) const + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + map_pair& hash_table_kernel_1:: + element ( + ) + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool hash_table_kernel_1:: + move_next ( + ) const + { + if (at_start_) + { + at_start_ = false; + // if the queue is empty then there is nothing to do + if (hash_size == 0) + { + return false; + } + else + { + // find the first element in the hash table + for (current_bucket = 0; true ; ++current_bucket) + { + if (table[current_bucket] != 0) + { + current_element = table[current_bucket]; + break; + } + } + return true; + } + } + else + { + // if we have already enumerated every element + if (current_element == 0) + { + return false; + } + else + { + // find the next element if it exists + if (current_element->next != 0) + { + current_element = current_element->next; + return true; + } + else + { + // find next bucket with something in it + for (current_bucket+=1; current_bucket + +namespace dlib +{ + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class hash_table_kernel_2 : public enumerable >, + public pair_remover + { + + /*! + REQUIREMENTS ON bst_base + bst_base is instantiated with domain and range and + implements binray_search_tree/binary_search_tree_kernel_abstract.h + + INITIAL VALUE + hash_size == 0 + table == pointer to an array of num_of_buckets bst_base objects + num_of_buckets == the number of buckets in the hash table + current_bucket == 0 + at_start_ == true + + CONVENTION + current_element_valid() == (current_bucket != 0) + element() == current_bucket->element() + at_start_ == at_start() + + mask == num_of_buckets-1 + + for all integers i where &table[i] != current_bucket + table[i].at_start() == true + + + hash_size = size() == the number of elements in the hash_table and + table == pointer to an array of num_of_buckets bst_base objects + num_of_buckets == the number of buckets in the hash table and + the elements in this hash table are stored in the bst_base objects in the + array table + + !*/ + + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + explicit hash_table_kernel_2( + unsigned long expnum + ); + + virtual ~hash_table_kernel_2( + ) + { pool.deallocate_array(table); } + + void clear( + ); + + unsigned long count ( + const domain& item + ) const; + + inline void add ( + domain& d, + range& r + ); + + void destroy ( + const domain& d + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + const range* operator[] ( + const domain& item + ) const; + + range* operator[] ( + const domain& item + ); + + inline void swap ( + hash_table_kernel_2& item + ); + + // functions from the remover interface + void remove_any ( + domain& d, + range& r + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + inline const map_pair& element ( + ) const; + + inline map_pair& element ( + ); + + bool move_next ( + ) const; + + private: + + // data members + typename mem_manager::template rebind::other pool; + unsigned long mask; + unsigned long hash_size; + unsigned long num_of_buckets; + bst_base* table; + general_hash hash; + mutable bst_base* current_bucket; + mutable bool at_start_; + compare comp; + + // restricted functions + hash_table_kernel_2(hash_table_kernel_2&); + hash_table_kernel_2& operator=(hash_table_kernel_2&); + + }; + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + inline void swap ( + hash_table_kernel_2& a, + hash_table_kernel_2& b + ) { a.swap(b); } + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + void deserialize ( + hash_table_kernel_2& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type hash_table_kernel_2"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + hash_table_kernel_2:: + hash_table_kernel_2( + unsigned long expnum + ) : + hash_size(0), + current_bucket(0), + at_start_(true) + { + + num_of_buckets = 1; + while (expnum != 0) + { + --expnum; + num_of_buckets <<= 1; + } + mask = num_of_buckets-1; + + table = pool.allocate_array(num_of_buckets); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + void hash_table_kernel_2:: + clear( + ) + { + if (hash_size != 0) + { + hash_size = 0; + for (unsigned long i = 0; i < num_of_buckets; ++i) + table[i].clear(); + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + size_t hash_table_kernel_2:: + size( + ) const + { + return hash_size; + } +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + unsigned long hash_table_kernel_2:: + count( + const domain& item + ) const + { + return table[hash(item)&mask].count(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + void hash_table_kernel_2:: + destroy( + const domain& item + ) + { + table[hash(item)&mask].destroy(item); + --hash_size; + + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + void hash_table_kernel_2:: + add( + domain& d, + range& r + ) + { + table[hash(d)&mask].add(d,r); + ++hash_size; + + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + void hash_table_kernel_2:: + remove( + const domain& d, + domain& d_copy, + range& r + ) + { + table[hash(d)&mask].remove(d,d_copy,r); + --hash_size; + + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + void hash_table_kernel_2:: + remove_any( + domain& d, + range& r + ) + { + unsigned long i = 0; + while (table[i].size() == 0) + { + ++i; + } + table[i].remove_any(d,r); + --hash_size; + + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + const range* hash_table_kernel_2:: + operator[]( + const domain& d + ) const + { + return table[hash(d)&mask][d]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + range* hash_table_kernel_2:: + operator[]( + const domain& d + ) + { + return table[hash(d)&mask][d]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + void hash_table_kernel_2:: + swap( + hash_table_kernel_2& item + ) + { + pool.swap(item.pool); + exchange(mask,item.mask); + exchange(hash_size,item.hash_size); + exchange(num_of_buckets,item.num_of_buckets); + exchange(table,item.table); + exchange(current_bucket,item.current_bucket); + exchange(at_start_,item.at_start_); + exchange(comp,item.comp); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + bool hash_table_kernel_2:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + void hash_table_kernel_2:: + reset ( + ) const + { + at_start_ = true; + if (current_bucket != 0) + { + current_bucket->reset(); + current_bucket = 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + bool hash_table_kernel_2:: + current_element_valid ( + ) const + { + return (current_bucket != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + const map_pair& hash_table_kernel_2:: + element ( + ) const + { + return current_bucket->element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + map_pair& hash_table_kernel_2:: + element ( + ) + { + return current_bucket->element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager, + typename compare + > + bool hash_table_kernel_2:: + move_next ( + ) const + { + if (at_start_) + { + at_start_ = false; + // if the queue is empty then there is nothing to do + if (hash_size == 0) + { + return false; + } + else + { + // find the first element in the hash table + current_bucket = table; + while (current_bucket->size() == 0) + { + ++current_bucket; + } + + current_bucket->move_next(); + + return true; + } + } + else + { + // if we have already enumerated every element + if (current_bucket == 0) + { + return false; + } + else + { + if (current_bucket->move_next()) + { + // if there is another element in this current bucket then use that + return true; + } + else + { + // find the next bucket + bst_base* end = table + num_of_buckets; + current_bucket->reset(); + + while (true) + { + ++current_bucket; + // if we ran out of buckets and didn't find anything + if (current_bucket == end) + { + current_bucket = 0; + return false; + } + if (current_bucket->size() > 0) + { + current_bucket->move_next(); + return true; + } + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HASH_TABLE_KERNEl_2_ + diff --git a/ml/dlib/dlib/hash_table/hash_table_kernel_abstract.h b/ml/dlib/dlib/hash_table/hash_table_kernel_abstract.h new file mode 100644 index 000000000..52480e91c --- /dev/null +++ b/ml/dlib/dlib/hash_table/hash_table_kernel_abstract.h @@ -0,0 +1,253 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_HASH_TABLE_KERNEl_ABSTRACT_ +#ifdef DLIB_HASH_TABLE_KERNEl_ABSTRACT_ + +#include "../interfaces/map_pair.h" +#include "../general_hash/general_hash.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class hash_table : public enumerable >, + public pair_remover + { + + /*! + REQUIREMENTS ON domain + domain must be comparable by compare where compare is a functor compatible with std::less and + domain must be hashable by general_hash + (general_hash is defined in dlib/general_hash) and + domain must be swappable by a global swap() and + domain must have a default constructor + + REQUIREMENTS ON range + range must be swappable by a global swap() and + range must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + swap(), count(), and operator[] functions do + not invalidate pointers or references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + No order is specified. Only that each element will be visited once + and only once. + + WHAT THIS OBJECT REPRESENTS + hash_table contains items of type T + + This object represents a data dictionary that is built on top of some + kind of hash table. The number of buckets in the hash table is + defined by the constructor argument and is some power of 2. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + NOTE: + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + explicit hash_table( + unsigned long expnum + ); + /*! + requires + - expnum < 32 + ensures + - #*this is properly initialized + - #*this will use 2^expnum as a suggestion for the initial number + of buckets. + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + !*/ + + virtual ~hash_table( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + unsigned long count ( + const domain& d + ) const; + /*! + ensures + - returns the number of elements in the domain of *this that are + equivalent to d + !*/ + + void add ( + domain& d, + range& r + ); + /*! + requires + - &d != &r (i.e. d and r cannot be the same variable) + ensures + - adds a mapping between d and r to *this + - if (count(d) == 0) then + - #*(*this)[d] == r + - else + - #(*this)[d] != 0 + - #d and #r have initial values for their types + - #count(d) == count(d) + 1 + - #at_start() == true + - #size() == size() + 1 + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if add() throws then it has no effect + !*/ + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - (*this)[d] != 0 + - &d != &r (i.e. d and r cannot be the same variable) + - &d != &d_copy (i.e. d and d_copy cannot be the same variable) + - &r != &d_copy (i.e. r and d_copy cannot be the same variable) + ensures + - some element in the domain of *this that is equivalent to d has + been removed and swapped into #d_copy. Additionally, its + associated range element has been removed and swapped into #r. + - #count(d) = count(d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void destroy ( + const domain& d + ); + /*! + requires + - (*this)[d] != 0 + ensures + - an element in the domain of *this equivalent to d has been removed. + The element in the range of *this associated with d has also been + removed. + - #count(d) == count(d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + const range* operator[] ( + const domain& d + ) const; + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + range* operator[] ( + const domain& d + ); + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + void swap ( + hash_table& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + hash_table(hash_table&); + hash_table& operator=(hash_table&); + + }; + + template < + typename domain, + typename range, + typename mem_manager + > + inline void swap ( + hash_table& a, + hash_table& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename domain, + typename range, + typename mem_manager + > + void deserialize ( + hash_table& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_HASH_TABLE_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/hash_table/hash_table_kernel_c.h b/ml/dlib/dlib/hash_table/hash_table_kernel_c.h new file mode 100644 index 000000000..93ef16d0b --- /dev/null +++ b/ml/dlib/dlib/hash_table/hash_table_kernel_c.h @@ -0,0 +1,194 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASH_TABLE_KERNEl_C_ +#define DLIB_HASH_TABLE_KERNEl_C_ + +#include "hash_table_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/map_pair.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename ht_base + > + class hash_table_kernel_c : public ht_base + { + typedef typename ht_base::domain_type domain; + typedef typename ht_base::range_type range; + public: + + explicit hash_table_kernel_c ( + unsigned long expnum + ) : + ht_base(expnum) + { + DLIB_CASSERT(expnum < 32, + "\thash_table::hash_table(unsigned long)" + << "\n\tyou can't set expnum >= 32" + << "\n\tthis: " << this + << "\n\texpnum: " << expnum + ); + } + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void remove_any ( + domain& d, + range& r + ); + + void add ( + domain& d, + range& r + ); + + void destroy ( + const domain& d + ); + + const map_pair& element ( + ) const + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst map_pair& hash_table::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return ht_base::element(); + } + + map_pair& element ( + ) + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tmap_pair& hash_table::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return ht_base::element(); + } + + + }; + + + template < + typename ht_base + > + inline void swap ( + hash_table_kernel_c& a, + hash_table_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename ht_base + > + void hash_table_kernel_c:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + DLIB_CASSERT(this->operator[](d) != 0 && + (static_cast(&d) != static_cast(&d_copy)) && + (static_cast(&d) != static_cast(&r)) && + (static_cast(&r) != static_cast(&d_copy)), + "\tvoid binary_search_tree::remove" + << "\n\tthe element must be in the table for it to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&d_copy: " << &d_copy + << "\n\t&r: " << &r + ); + + ht_base::remove(d,d_copy,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename ht_base + > + void hash_table_kernel_c:: + add( + domain& d, + range& r + ) + { + DLIB_CASSERT( static_cast(&d) != static_cast(&r), + "\tvoid binary_search_tree::add" + << "\n\tyou can't call add() and give the same object to both arguments." + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&r: " << &r + << "\n\tsize(): " << this->size() + ); + + ht_base::add(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename ht_base + > + void hash_table_kernel_c:: + destroy( + const domain& d + ) + { + DLIB_CASSERT((*this)[d] != 0, + "\tvoid hash_table::destroy" + << "\n\tthe element must be in the table for it to be destroyed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + ); + + ht_base::destroy(d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename ht_base + > + void hash_table_kernel_c:: + remove_any( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->size() != 0 && + (static_cast(&d) != static_cast(&r)), + "\tvoid hash_table::remove_any" + << "\n\ttable must not be empty if something is going to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&r: " << &r + ); + + ht_base::remove_any(d,r); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HASH_TABLE_KERNEl_C_ + diff --git a/ml/dlib/dlib/http_client/http_client.cpp b/ml/dlib/dlib/http_client/http_client.cpp new file mode 100644 index 000000000..75e838a73 --- /dev/null +++ b/ml/dlib/dlib/http_client/http_client.cpp @@ -0,0 +1,743 @@ + + +#include "../sockets.h" +#include "../string.h" +#include "../logger.h" +#include "../sockstreambuf.h" +#include "../timeout.h" +#include "http_client.h" +#include +#include +#include +#include +#include + +namespace dlib +{ + + typedef std::shared_ptr timeout_ptr; + + +#ifdef _MSC_VER +#define BR_CASECMP strnicmp +#else +#define BR_CASECMP strncasecmp +#endif +// Default timeout after 60 seconds +#define DEFAULT_TIMEOUT 60000 + +// ---------------------------------------------------------------------------------------- + + inline bool isXdigit( char c ) + { + return (c >= '0' && c <= '9') || + (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z'); + } + +// ---------------------------------------------------------------------------------------- + + std::string http_client::urldecode( const std::string& s ) + { + std::stringstream ss; + + for ( char const * p_read = s.c_str(), * p_end = (s.c_str() + s.size()); p_read < p_end; p_read++ ) + { + if ( p_read[0] == '%' && p_read+1 != p_end && p_read+2 != p_end && isXdigit(p_read[1]) && isXdigit(p_read[2]) ) + { + ss << static_cast((( (p_read[1] & 0xf) + ((p_read[1] >= 'A') ? 9 : 0) ) << 4 ) | ( (p_read[2] & 0xf) + ((p_read[2] >= 'A') ? 9 : 0) )); + p_read += 2; + } + else if ( p_read[0] == '+' ) + { + // Undo the encoding that replaces spaces with plus signs. + ss << ' '; + } + else + { + ss << p_read[0]; + } + } + + return ss.str(); + } + +// ---------------------------------------------------------------------------------------- + +//! \return modified string ``s'' with spaces trimmed from left + inline std::string& triml(std::string& s) + { + int pos(0); + for ( ; s[pos] == ' ' || s[pos] == '\t' || s[pos] == '\r' || s[pos] == '\n' ; ++pos ); + s.erase(0, pos); + return s; + } + +// ---------------------------------------------------------------------------------------- + +//! \return modified string ``s'' with spaces trimmed from right + inline std::string& trimr(std::string& s) + { + int pos(s.size()); + for ( ; pos && (s[pos-1] == ' ' || s[pos-1] == '\t' || s[pos-1] == '\r' || s[pos-1] == '\n') ; --pos ); + s.erase(pos, s.size()-pos); + return s; + } + +// ---------------------------------------------------------------------------------------- + +//! \return modified string ``s'' with spaces trimmed from edges + inline std::string& trim(std::string& s) + { + return triml(trimr(s)); + } + +// ---------------------------------------------------------------------------------------- + + http_client:: + http_client( + ) : + http_return(0), + timeout(DEFAULT_TIMEOUT), + OnDownload(0) + { + } + +// ---------------------------------------------------------------------------------------- + + std::string http_client::get_header(const std::string& header_name) const + { + stringmap::const_iterator ci = headers.find(header_name); + return ci != headers.end() ? ci->second : std::string(); + } + +// ---------------------------------------------------------------------------------------- + + void http_client::set_header(const std::string& header_name, long header_value) + { + char buf[21] = { 0 }; +#ifdef __WXMSW__ + ::ltoa(header_value, buf, 10); +#else + sprintf(buf, "%ld", header_value); +#endif + set_header(header_name, buf); + } + +// ---------------------------------------------------------------------------------------- + + void http_client::set_header(const std::string& header_name, const std::string& header_value) + { + headers[header_name] = header_value; + } + +// ---------------------------------------------------------------------------------------- + + bool http_client::is_header_set(const std::string& header_name) const + { + stringmap::const_iterator ci = headers.find(header_name); + return ci != headers.end() && !ci->second.empty(); + } + +// ---------------------------------------------------------------------------------------- + + void http_client::remove_header(const std::string& header_name) + { + headers.erase(header_name); + } + +// ---------------------------------------------------------------------------------------- + + void http_client::set_cookie(const std::string& cookie_name, long cookie_value) + { + char buf[21] = { 0 }; +#ifdef __WXMSW__ + ::ltoa(cookie_value, buf, 10); +#else + sprintf(buf, "%ld", cookie_value); +#endif + set_cookie(cookie_name, buf); + } + +// ---------------------------------------------------------------------------------------- + + void http_client::set_cookie(const std::string& cookie_name, const std::string& cookie_value) + { + cookies[cookie_name] = cookie_value; + } + +// ---------------------------------------------------------------------------------------- + + void http_client::remove_cookie(const std::string& cookie_name) + { + cookies.erase(cookie_name); + } + +// ---------------------------------------------------------------------------------------- + +// POST + const std::string& http_client::post_url (const std::string& url, const string_to_stringmap& postvars, const string_to_stringmap& filenames) + { + std::string CT; + std::string postBody = build_post(CT, postvars, filenames); + set_header("Content-Type", CT); + set_header("Content-Length", static_cast(postBody.size())); + + grab_url(url, "POST", postBody); + + return returned_body; + } + +// ---------------------------------------------------------------------------------------- + + const std::string& http_client::post_url (const std::string& url, const std::string& postbuffer) + { + if ( !is_header_set("Content-Type") ) // Maybe they just forgot it? + set_header("Content-Type", "application/x-www-form-urlencoded"); + + set_header("Content-Length", static_cast(postbuffer.size())); + + grab_url(url, "POST", postbuffer); + + return returned_body; + } + +// ---------------------------------------------------------------------------------------- + + std::string http_client::get_random_string( size_t length ) const + { + static bool has_seeded(false); + static std::string allowed_chars("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"); + + if ( !has_seeded ) + { + has_seeded = true; + ::srand( static_cast(::time(NULL)) ); + } + + std::string retVal; retVal.reserve(length); + while ( retVal.size() < length ) + { + retVal += allowed_chars[(rand() % allowed_chars.size())]; + } + + return retVal; + } + +// ---------------------------------------------------------------------------------------- + +// static + std::string http_client::urlencode(const std::string& in, bool post_encode) + { + static std::string allowed_chars("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"); + + std::stringstream ss; + ss << std::hex; + for (std::string::const_iterator ci = in.begin(); ci != in.end(); ++ci) + { + if ( allowed_chars.find(*ci) != std::string::npos ) + { + ss << *ci; + } + else if ( post_encode && *ci == ' ' ) + { + ss << '+'; + } + else + { + ss << '%' << std::setfill('0') << std::setw(2) << std::right << static_cast(*ci); + } + } + + return ss.str(); + } + +// ---------------------------------------------------------------------------------------- + + std::string http_client::get_basename( const std::string& filename ) const + { + std::string::size_type pos = filename.find_last_of("\\/"); + if ( pos == std::string::npos ) + return filename; + else + return filename.substr(pos+1); + } + +// ---------------------------------------------------------------------------------------- + + bool http_client::parse_url( + const std::string& url, + std::string& scheme, + std::string& user, + std::string& pass, + std::string& host, + short& port, + std::string& path + ) const + { + scheme.clear(); + user.clear(); + pass.clear(); + host.clear(); + path.clear(); + port = 0; + + // Find scheme + std::string::size_type pos_scheme = url.find("://"); + if ( pos_scheme == std::string::npos ) + { + pos_scheme = 0; + } + else + { + scheme = strtolower(url.substr(0, pos_scheme)); + pos_scheme += 3; + } + + std::string::size_type pos_path = url.find('/', pos_scheme); + if ( pos_path == std::string::npos ) + { + host = url.substr(pos_scheme); + } + else + { + host = url.substr(pos_scheme, pos_path - pos_scheme); + path = url.substr(pos_path); + } + + std::string::size_type pos_at = host.find('@'); + if ( pos_at != std::string::npos ) + { + std::string::size_type pos_dp = host.find(':'); + if ( pos_dp != std::string::npos && pos_dp < pos_at ) + { + user = host.substr(0, pos_dp); + pass = host.substr(pos_dp+1, pos_at-pos_dp-1); + } + else + { + user = host.substr(0, pos_at); + } + host = host.substr(pos_at+1); + } + + std::string::size_type pos_dp = host.find(':'); + if ( pos_dp != std::string::npos ) + { + port = dlib::string_cast(host.substr(pos_dp+1)); + host = host.substr(0, pos_dp); + } + + host = strtolower(host); + + if ( port == 0 ) + { + if ( scheme == "http" ) + port = 80; + else if ( scheme == "ftp" ) + port = 21; + else if ( scheme == "https" ) + port = 443; + } + + return !host.empty(); + } + +// ---------------------------------------------------------------------------------------- + + std::string http_client::strtolower(const std::string& in) const + { + std::string retVal = in; + + for (std::string::iterator ii = retVal.begin(); ii != retVal.end(); ++ii) + { + *ii = ::tolower(*ii); + } + + return retVal; + } + +// ---------------------------------------------------------------------------------------- + + std::string http_client::strtoupper(const std::string& in) const + { + std::string retVal = in; + + for (std::string::iterator ii = retVal.begin(); ii != retVal.end(); ++ii) + { + *ii = ::toupper(*ii); + } + + return retVal; + } + +// ---------------------------------------------------------------------------------------- + +// GET + const std::string& http_client::get_url(const std::string& url) + { + std::string CT = get_header("Content-Type"); + + // You do a GET with a POST header?? + if ( CT == "application/x-www-form-urlencoded" || CT == "multipart/form-data" ) + remove_header("Content-Type"); + + grab_url(url); + + return returned_body; + } + +// ---------------------------------------------------------------------------------------- + + std::string http_client::build_post(std::string& content_type, const string_to_stringmap& postvars, const string_to_stringmap& filenames_in) const + { + if ( postvars.empty() && filenames_in.empty() ) + return std::string(); + + string_to_stringmap filenames = filenames_in; + + // sanitize the files + if ( !filenames.empty() ) + { + string_to_stringmap::iterator var_names = filenames.begin(); + while (var_names != filenames.end()) + { + stringmap::iterator fnames = var_names->second.begin(); + + while( fnames != var_names->second.end() ) + { + FILE *fp = ::fopen(fnames->second.c_str(), "rb"); + if ( fp == NULL ) + { + stringmap::iterator old_one = fnames++; + var_names->second.erase(old_one); + } + else + { + fclose(fp); + ++fnames; + } + } + + if ( fnames->second.empty() ) + { + string_to_stringmap::iterator old_one = var_names++; + filenames.erase(old_one); + } + else + { + ++var_names; + } + } + } + + content_type = !filenames.empty() ? "multipart/form-data" : "application/x-www-form-urlencoded"; + std::stringstream postBody; + if ( !filenames.empty() ) + { + std::string mime_boundary = get_random_string(32); + + // First add the form vars + for (string_to_stringmap::const_iterator ci = postvars.begin(); ci != postvars.end(); ++ci) + { + for (stringmap::const_iterator si = ci->second.begin(); si != ci->second.end(); ++si) + { + postBody << "--" << mime_boundary << "\r\n" + "Content-Disposition: form-data; name=\"" << ci->first << "\"\r\n\r\n" + << si->second << "\r\n"; + } + } + + // Then add the files + for (string_to_stringmap::const_iterator ci = filenames.begin(); ci != filenames.end(); ++ci) + { + for (stringmap::const_iterator si = ci->second.begin(); si != ci->second.end(); ++si) + { + std::ifstream in(si->second.c_str()); + postBody << "--" << mime_boundary << "\r\n" + "Content-Disposition: form-data; name=\"" << ci->first << "\"; filename=\"" << get_basename(si->second) << "\"\r\n\r\n" + << in.rdbuf() << "\r\n"; + } + } + + postBody << "--" << mime_boundary << "--\r\n"; + } + else + { + // No files... + for (string_to_stringmap::const_iterator ci = postvars.begin(); ci != postvars.end(); ++ci) + { + for (stringmap::const_iterator si = ci->second.begin(); si != ci->second.end(); ++si) + { + postBody << urlencode(ci->first) << '=' << urlencode(si->second) << '&'; + } + } + + // read the last '&' + char c; + postBody.read(&c, 1); + } + + return postBody.str(); + } + +// ---------------------------------------------------------------------------------------- + + bool http_client::grab_url(const std::string& url, const std::string& method, const std::string& post_body) + { + error_field.clear(); + returned_headers.clear(); + http_return = 0; + returned_body.clear(); + + std::string to_use_method = strtoupper(method); + + std::string scheme, user, pass, host, path; + short port; + if ( !parse_url(url, scheme, user, pass, host, port, path) ) + { + error_field = "Couldn't parse the URL!"; + return false; + } + + // Build request + std::stringstream ret; + ret << to_use_method << ' ' << path << " HTTP/1.0\r\n" + << "Host: " << host; + if (port != 80 && port != 443) ret << ':' << port; + ret << "\r\n"; + + bool content_length_said = false; + + set_header("Connection", "Close"); + for (stringmap::iterator ci = headers.begin(); ci != headers.end(); ++ci) + { + std::string head = strtolower(ci->first); + + if ( head == "content-length" ) + { + content_length_said = true; + } + + ret << ci->first << ':' << ' ' << ci->second << "\r\n"; + } + + if ( !content_length_said && to_use_method != "GET" ) + ret << "Content-Length: " << static_cast(post_body.size()) << "\r\n"; + + std::stringstream cookie_ss; + for (stringmap::iterator ci = cookies.begin(); ci != cookies.end(); ++ci) + { + std::string var = ci->first ; trim(var); + std::string val = ci->second; trim(val); + + if ( val.empty() || var.empty() ) + continue; + + if ( !cookie_ss.str().empty() ) + cookie_ss << ';' << ' '; + + cookie_ss << urlencode(var) << '=' << urlencode(val); + } + + if ( !cookie_ss.str().empty() ) + ret << "Cookie: " << cookie_ss.str() << "\r\n"; + + ret << "\r\n"; + ret << post_body; + + std::string request_build = ret.str(); + + std::stringstream ss; + { + dlib::connection * conn(0); + try + { + if (timeout > 0) + conn = dlib::connect(host, port, timeout); + else + conn = dlib::connect(host, port); + } + catch (const dlib::socket_error& e) + { + error_field = e.what(); + return false; + } + + // Implement a timeout + timeout_ptr t; + if ( timeout > 0 ) + t.reset( new dlib::timeout(*conn, &dlib::connection::shutdown, timeout) ); + + // Write our request + conn->write(request_build.c_str(), static_cast(request_build.size())); + + t.reset(); + + // And read the response + char buf[512]; + long bytes_read(0), bytes_total(0); + bool read_headers(true); + + if ( timeout > 0 ) + t.reset( new dlib::timeout(*conn, &dlib::connection::shutdown, timeout) ); + + while ( (bytes_read = conn->read(buf, 512)) > 0 ) + { + ss.write(buf, bytes_read); + + // Incremental read headers + if ( read_headers ) + { + std::string body_with_headers = ss.str(); + std::string::size_type ctr(0); + + while ( true ) + { + std::string::size_type pos = body_with_headers.find("\r\n", ctr); + if ( pos == std::string::npos ) + { + // This is our last position of "\r\n" + ss.str(""); + ss.write( body_with_headers.substr(ctr).c_str(), body_with_headers.size() - ctr ); + break; + } + + std::string header = body_with_headers.substr(ctr, pos-ctr); + if ( header.empty() ) + { + // Ok, we're done reading the headers + read_headers = false; + // What follows now is the body + ss.str(""); + ss.write( body_with_headers.substr(pos + 2).c_str(), body_with_headers.size() - pos - 2 ); + break; + } + ctr = pos + 2; + + if ( returned_headers.empty() ) + { + if ( + header[0] == 'H' && + header[1] == 'T' && + header[2] == 'T' && + header[3] == 'P' && + header[4] == '/' && + (header[5] >= '0' && header[5] <= '9') && + header[6] == '.' && + (header[7] >= '0' && header[7] <= '9') && + header[8] == ' ' + ) + { + http_return = (header[9 ] - '0') * 100 + + (header[10] - '0') * 10 + + (header[11] - '0'); + continue; + } + } + + std::string::size_type pos_dp = header.find_first_of(':'); + std::string header_name, header_value; + if ( pos_dp == std::string::npos ) + { + // **TODO** what should I do here?? + header_name = header; + } + else + { + header_name = trim(header.substr(0, pos_dp)); + header_value = trim(header.substr(pos_dp+1)); + } + + returned_headers[ header_name ].push_back(header_value); + + if ( BR_CASECMP(header_name.c_str(), "Content-Length", 14) == 0 ) + { + bytes_total = atol( header_value.c_str() ); + } + else if ( BR_CASECMP(header_name.c_str(), "Set-Cookie", 10) == 0 ) + { + std::string::size_type cur_pos(0), pos_pk, pos_is; + std::string work, var, val; + for ( cur_pos = 0; cur_pos < header_value.size(); cur_pos++ ) + { + pos_pk = header_value.find(';', cur_pos); + work = trim( header_value.substr(cur_pos, pos_pk - cur_pos) ); + + pos_is = work.find('='); + if ( pos_is != std::string::npos ) + { // Hmmm? what in the else case? + var = trim( http_client::urldecode( work.substr(0, pos_is) ) ); + val = trim( http_client::urldecode( work.substr(pos_is + 1) ) ); + + if ( var != "expires" && var != "domain" && var != "path" ) + set_cookie( var, val ); + } + cur_pos = pos_pk == std::string::npos ? pos_pk - 1 : pos_pk; + } + } // Set-Cookie? + + } // while (true) + } // read_headers? + + // Call the OnDownload function if it's set + if ( OnDownload && !read_headers ) + { + if ( (*OnDownload)(static_cast(ss.tellp()), bytes_total, user_info) == false ) + { + t.reset(); + break; + } + } + + if ( bytes_total != 0 && static_cast(ss.tellp()) == bytes_total ) + { + t.reset(); + break; + } + + if ( timeout > 0 ) + t.reset( new dlib::timeout(*conn, &dlib::connection::shutdown, timeout) ); + } // while still data to read + + t.reset(); + + delete conn; + + + switch ( bytes_read ) + { + case dlib::TIMEOUT: error_field = "Timeout"; return false; break; + case dlib::WOULDBLOCK: error_field = "Would block"; return false; break; + case dlib::OTHER_ERROR: error_field = "Other error"; return false; break; + case dlib::SHUTDOWN: error_field = "Timeout"; return false; break; + case dlib::PORTINUSE: error_field = "Port in use"; return false; break; + } + } + + returned_body = ss.str(); + + return true; + } + +// ---------------------------------------------------------------------------------------- + + void http_client::clear() + { + headers.clear(); + cookies.clear(); + } + +// ---------------------------------------------------------------------------------------- + + void http_client::prepare_for_next_url( ) + { + remove_header("Content-Type"); + remove_header("Content-Length"); + } + +// ---------------------------------------------------------------------------------------- + +} + + diff --git a/ml/dlib/dlib/http_client/http_client.h b/ml/dlib/dlib/http_client/http_client.h new file mode 100644 index 000000000..fd2996ab1 --- /dev/null +++ b/ml/dlib/dlib/http_client/http_client.h @@ -0,0 +1,101 @@ +#ifndef DLIB_BROWSERhH +#define DLIB_BROWSERhH + + +#include +#include +#include +#include "http_client_abstract.h" + + +// Default timeout after 60 seconds +#define DEFAULT_TIMEOUT 60000 + +namespace dlib +{ + + // Function which is called when there is data available. + // Return false to stop the download process... + typedef bool (*fnOnDownload)(long already_downloaded, long total_to_download, void * userInfo); + + + class http_client + { + public: + http_client(); + + typedef std::map< std::string, std::string > stringmap; + typedef std::map< std::string, stringmap > string_to_stringmap; + typedef std::map< std::string, std::vector > string_to_stringvector; + + // Header functions + void set_header(const std::string& header_name, const std::string& header_value); + void set_header(const std::string& header_name, long header_value); + std::string get_header(const std::string& header_name) const; + void remove_header(const std::string& header_name); + bool is_header_set(const std::string& header_name) const; + + // This function will clear out all cookies & headers set until now + void clear(); + // This function will clear out the Content-Type header + void prepare_for_next_url(); + + void set_callback_function( fnOnDownload od, void * _user_info ) { OnDownload = od; user_info = _user_info; } + + void set_cookie(const std::string& cookie_name, const std::string& cookie_value); + void set_cookie(const std::string& cookie_name, long cookie_value); + void remove_cookie(const std::string& cookie_name); + + void set_user_agent(const std::string& new_agent) { set_header("User-Agent", new_agent); } + + + void set_timeout( unsigned int milliseconds = DEFAULT_TIMEOUT ) { timeout = milliseconds; } + + + string_to_stringvector get_returned_headers() const { return returned_headers; } + short get_http_return () const { return http_return; } + const std::string& get_body () const { return returned_body; } + + // POST + const std::string& post_url (const std::string& url, const string_to_stringmap& postvars, const string_to_stringmap& filenames = string_to_stringmap()); + const std::string& post_url (const std::string& url, const std::string& postbuffer); + // GET + const std::string& get_url (const std::string& url); + + bool has_error( ) const { return !error_field.empty(); } + const std::string& get_error( ) const { return error_field; } + + static std::string urlencode(const std::string& in, bool post_encode = false); + static std::string urldecode(const std::string& in); + private: + bool grab_url(const std::string& url, const std::string& method = "GET", const std::string& post_body = ""); + std::string build_post(std::string& content_type, const string_to_stringmap& postvars, const string_to_stringmap& filenames) const; + + std::string get_random_string( size_t length = 32 ) const; + std::string get_basename( const std::string& filename ) const; + std::string strtolower(const std::string& in) const; + std::string strtoupper(const std::string& in) const; + + bool parse_url(const std::string& url, std::string& scheme, std::string& user, std::string& pass, std::string& host, short& port, std::string& path) const; + + stringmap headers; + stringmap cookies; + + string_to_stringvector returned_headers; + short http_return; + std::string returned_body, error_field; + + unsigned int timeout; + + fnOnDownload OnDownload; + void * user_info; + }; + +} + +#ifdef NO_MAKEFILE +#include "http_client.cpp" +#endif + +#endif // DLIB_BROWSERhH + diff --git a/ml/dlib/dlib/http_client/http_client_abstract.h b/ml/dlib/dlib/http_client/http_client_abstract.h new file mode 100644 index 000000000..5e9d1d5e6 --- /dev/null +++ b/ml/dlib/dlib/http_client/http_client_abstract.h @@ -0,0 +1,218 @@ +#undef DLIB_BROWSER_ABSTRACh_ +#ifdef DLIB_BROWSER_ABSTRACh_ + + + +namespace dlib +{ + + // Function which is called when there is data available. + // Return false to stop the download process... + typedef bool (*fnOnDownload)(long already_downloaded, long total_to_download, void * userInfo); + + +// ---------------------------------------------------------------------------------------- +/* +TODO: +- Timed cookie support +- POSTing files: check it! +- Don't timeout when still downloading! +*/ +// ---------------------------------------------------------------------------------------- + + + class Browser + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a possibility for the end user to download webpages (HTTP/1.0) + from the internet like a normal webbrowser would do. + !*/ + + public: + + Browser( + ); + /*! + Constructor + !*/ + + void set_header( + const std::string& header_name, + const std::string& header_value + ); + /*! + Set a header to a certain value + Example: set_header("User-Agent", "Internet Explorer") + !*/ + + + void set_header( + const std::string& header_name, + long header_value + ); + /*! + Set a header to a certain number + Example: set_header("Content-Length", 1234) + !*/ + + std::string get_header( + const std::string& header_name + ) const; + /*! + Get the value of the header or an empty string when it's not set. + Example: get_header("Content-Length") would return "1234" + !*/ + + void remove_header( + const std::string& header_name + ); + /*! + Removes a certain header + !*/ + + bool is_header_set( + const std::string& header_name + ) const; + /*! + Returns when a header is set and is not empty + !*/ + + void set_user_agent( + const std::string& new_agent + ) { set_header("User-Agent", new_agent); } + /*! + Convenience function for setting a user agent + !*/ + + void clear( + ); + /*! + Clear out all cookies & headers set until now + !*/ + + void prepare_for_next_url( + ); + /*! + Clear out any header and/or cookie which would obstruct getting a next page. + At this moment this is cleared: + - the Content-Type header + !*/ + + void set_callback_function( + fnOnDownload od, + void * _user_info + ); + /*! + Set a callback function for one of the following events: + - OnDownload: this will tell you how much is downloaded and how much will need to be downloaded + !*/ + + void set_cookie( + const std::string& cookie_name, + const std::string& cookie_value + ); + /*! + Set a cookie + !*/ + + void set_cookie( + const std::string& cookie_name, + long cookie_value + ); + /*! + Set a cookie + !*/ + + void remove_cookie( + const std::string& cookie_name + ); + /*! + Remove a cookie if it's set + !*/ + + void set_timeout( + unsigned int milliseconds + ); + /*! + Set the maximum time how long a request can take. Setting this to 0 disables + this behavior. + !*/ + + string_to_stringvector get_returned_headers( + ) const; + /*! + Returns all the headers which are returned in the download of the webpage. + !*/ + + short get_http_return ( + ) const; + /*! + Retrieves the HTTP return code. + !*/ + + const std::string& get_body ( + ) const; + /*! + Retrieves the HTTP body. + !*/ + + const std::string& post_url ( + const std::string& url, + const string_to_stringmap& postvars, + const string_to_stringmap& filenames = string_to_stringmap() + ); + /*! + POST an url to the internet. + You can pass the post variables as well as a list of filenames + !*/ + + const std::string& post_url ( + const std::string& url, + const std::string& postbuffer + ); + /*! + POST an url to the internet. + In this function you have constructed the POST string yourselves + !*/ + + const std::string& get_url ( + const std::string& url + ); + /*! + GET an url from the internet. + !*/ + + bool has_error( + ) const; + /*! + Has there happened an error? + !*/ + + const std::string& get_error( + ) const; + /*! + Get the error explanation + !*/ + + static std::string urlencode( + const std::string& in, + bool post_encode = false + ); + /*! + Convenience function to URLencode a string + !*/ + + static std::string urldecode( + const std::string& in + ); + /*! + Convenience function to URLdecode a string + !*/ + + }; + +} + +#endif // DLIB_BROWSER_ABSTRACh_ + diff --git a/ml/dlib/dlib/image_io.h b/ml/dlib/dlib/image_io.h new file mode 100644 index 000000000..4fdc79881 --- /dev/null +++ b/ml/dlib/dlib/image_io.h @@ -0,0 +1,20 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_IMAGe_IO_ +#define DLIB_IMAGe_IO_ + +#include "image_loader/image_loader.h" +#include "image_loader/png_loader.h" +#include "image_loader/jpeg_loader.h" +#include "image_loader/load_image.h" +#include "image_saver/image_saver.h" +#include "image_saver/save_png.h" +#include "image_saver/save_jpeg.h" + +#endif // DLIB_IMAGe_IO_ + diff --git a/ml/dlib/dlib/image_keypoint.h b/ml/dlib/dlib/image_keypoint.h new file mode 100644 index 000000000..335d620e8 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint.h @@ -0,0 +1,16 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMAGE_KEYPOINt_H_ +#define DLIB_IMAGE_KEYPOINt_H_ + +#include "image_keypoint/surf.h" +#include "image_keypoint/hessian_pyramid.h" +#include "image_keypoint/hog.h" +#include "image_keypoint/poly_image.h" +#include "image_keypoint/fine_hog_image.h" +#include "image_keypoint/hashed_feature_image.h" +#include "image_keypoint/nearest_neighbor_feature_image.h" +#include "image_keypoint/binned_vector_feature_image.h" + +#endif // DLIB_IMAGE_KEYPOINt_H_ + diff --git a/ml/dlib/dlib/image_keypoint/binned_vector_feature_image.h b/ml/dlib/dlib/image_keypoint/binned_vector_feature_image.h new file mode 100644 index 000000000..019a12739 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/binned_vector_feature_image.h @@ -0,0 +1,433 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINNED_VECTOR_IMAGE_FEATUrES_Hh_ +#define DLIB_BINNED_VECTOR_IMAGE_FEATUrES_Hh_ + +#include "../lsh/projection_hash.h" +#include "binned_vector_feature_image_abstract.h" +#include +#include "../algs.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type_ = projection_hash + > + class binned_vector_feature_image : noncopyable + { + + public: + typedef feature_extractor feature_extractor_type; + typedef hash_function_type_ hash_function_type; + + typedef std::vector > descriptor_type; + + binned_vector_feature_image ( + ); + + void clear ( + ); + + void set_hash ( + const hash_function_type& hash_ + ); + + const hash_function_type& get_hash ( + ) const; + + void copy_configuration ( + const feature_extractor& item + ); + + void copy_configuration ( + const binned_vector_feature_image& item + ); + + template < + typename image_type + > + inline void load ( + const image_type& img + ); + + inline size_t size ( + ) const; + + inline long nr ( + ) const; + + inline long nc ( + ) const; + + inline long get_num_dimensions ( + ) const; + + inline const descriptor_type& operator() ( + long row, + long col + ) const; + + inline const rectangle get_block_rect ( + long row, + long col + ) const; + + inline const point image_to_feat_space ( + const point& p + ) const; + + inline const rectangle image_to_feat_space ( + const rectangle& rect + ) const; + + inline const point feat_to_image_space ( + const point& p + ) const; + + inline const rectangle feat_to_image_space ( + const rectangle& rect + ) const; + + template + friend void serialize ( + const binned_vector_feature_image& item, + std::ostream& out + ); + + template + friend void deserialize ( + binned_vector_feature_image& item, + std::istream& in + ); + + private: + + array2d feats; + feature_extractor fe; + hash_function_type phash; + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const binned_vector_feature_image& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + serialize(item.feats, out); + serialize(item.fe, out); + serialize(item.phash, out); + } + + template + void deserialize ( + binned_vector_feature_image& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw dlib::serialization_error("Unexpected version found while deserializing dlib::binned_vector_feature_image"); + deserialize(item.feats, in); + deserialize(item.fe, in); + deserialize(item.phash, in); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// binned_vector_feature_image member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + binned_vector_feature_image:: + binned_vector_feature_image ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void binned_vector_feature_image:: + clear ( + ) + { + fe.clear(); + phash = hash_function_type(); + feats.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void binned_vector_feature_image:: + set_hash ( + const hash_function_type& hash_ + ) + { + phash = hash_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const hash_function_type& binned_vector_feature_image:: + get_hash ( + ) const + { + return phash; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void binned_vector_feature_image:: + copy_configuration ( + const feature_extractor& item + ) + { + fe.copy_configuration(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void binned_vector_feature_image:: + copy_configuration ( + const binned_vector_feature_image& item + ) + { + fe.copy_configuration(item.fe); + phash = item.phash; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + template < + typename image_type + > + void binned_vector_feature_image:: + load ( + const image_type& img + ) + { + fe.load(img); + + if (fe.size() != 0) + { + feats.set_size(fe.nr(), fe.nc()); + for (long r = 0; r < feats.nr(); ++r) + { + for (long c = 0; c < feats.nc(); ++c) + { + feats[r][c].clear(); + feats[r][c].reserve(fe.get_num_dimensions()+1); + const typename feature_extractor::descriptor_type& des = fe(r,c); + const unsigned long idx = phash(des); + const unsigned long offset = idx*(fe.get_num_dimensions()+1); + + for (long i = 0; i < des.size(); ++i) + { + feats[r][c].push_back(std::make_pair(offset + i, des(i))); + } + feats[r][c].push_back(std::make_pair(offset + des.size(), 1.0)); + } + } + } + else + { + feats.set_size(0,0); + } + + fe.unload(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + size_t binned_vector_feature_image:: + size ( + ) const + { + return feats.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + long binned_vector_feature_image:: + nr ( + ) const + { + return feats.nr(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + long binned_vector_feature_image:: + nc ( + ) const + { + return feats.nc(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + long binned_vector_feature_image:: + get_num_dimensions ( + ) const + { + return phash.num_hash_bins()*(fe.get_num_dimensions()+1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const std::vector >& binned_vector_feature_image:: + operator() ( + long row, + long col + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= row && row < nr() && + 0 <= col && col < nc(), + "\t descriptor_type binned_vector_feature_image::operator(row,col)" + << "\n\t Invalid inputs were given to this function" + << "\n\t row: " << row + << "\n\t col: " << col + << "\n\t nr(): " << nr() + << "\n\t nc(): " << nc() + << "\n\t this: " << this + ); + + return feats[row][col]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const rectangle binned_vector_feature_image:: + get_block_rect ( + long row, + long col + ) const + { + return fe.get_block_rect(row,col); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const point binned_vector_feature_image:: + image_to_feat_space ( + const point& p + ) const + { + return fe.image_to_feat_space(p); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const rectangle binned_vector_feature_image:: + image_to_feat_space ( + const rectangle& rect + ) const + { + return fe.image_to_feat_space(rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const point binned_vector_feature_image:: + feat_to_image_space ( + const point& p + ) const + { + return fe.feat_to_image_space(p); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const rectangle binned_vector_feature_image:: + feat_to_image_space ( + const rectangle& rect + ) const + { + return fe.feat_to_image_space(rect); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINNED_VECTOR_IMAGE_FEATUrES_Hh_ + + diff --git a/ml/dlib/dlib/image_keypoint/binned_vector_feature_image_abstract.h b/ml/dlib/dlib/image_keypoint/binned_vector_feature_image_abstract.h new file mode 100644 index 000000000..6bd6cdbb8 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/binned_vector_feature_image_abstract.h @@ -0,0 +1,287 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BINNED_VECTOR_FEATUrES_ABSTRACT_Hh_ +#ifdef DLIB_BINNED_VECTOR_FEATUrES_ABSTRACT_Hh_ + +#include "../lsh/projection_hash_abstract.h" +#include +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type_ = projection_hash + > + class binned_vector_feature_image : noncopyable + { + /*! + REQUIREMENTS ON feature_extractor + - must be an object with an interface compatible with dlib::hog_image + + REQUIREMENTS ON hash_function_type_ + - must be an object with an interface compatible with projection_hash + + INITIAL VALUE + - size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object is a tool for performing image feature extraction. In + particular, it wraps another image feature extractor and converts the + wrapped image feature vectors into a high dimensional sparse vector. For + example, if the lower level feature extractor outputs the vector [3,4,5] + and this vector is hashed into the second bin of four bins then the output + sparse vector is: + [0,0,0,0, 3,4,5,1, 0,0,0,0, 0,0,0,0]. + That is, the output vector has a dimensionality that is equal to the number + of hash bins times the dimensionality of the lower level vector plus one. + The value in the extra dimension concatenated onto the end of the vector is + always a constant value of of 1 and serves as a bias value. This means + that, if there are N hash bins, these vectors are capable of representing N + different linear functions, each operating on the vectors that fall into + their corresponding hash bin. + + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be + protected by a mutex lock except for the case where you are copying the + configuration (via copy_configuration()) of a binned_vector_feature_image + object to many other threads. In this case, it is safe to copy the + configuration of a shared object so long as no other operations are + performed on it. + + + NOTATION + let BASE_FE denote the base feature_extractor object contained inside the + binned_vector_feature_image. + !*/ + + public: + + typedef feature_extractor feature_extractor_type; + typedef hash_function_type_ hash_function_type; + typedef std::vector > descriptor_type; + + binned_vector_feature_image ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear ( + ); + /*! + ensures + - this object will have its initial value + !*/ + + void set_hash ( + const hash_function_type& hash + ); + /*! + ensures + - #get_hash() == hash + !*/ + + const hash_function_type& get_hash ( + ) const; + /*! + ensures + - returns the hash function used by this object to hash + base feature vectors into integers. + !*/ + + void copy_configuration ( + const feature_extractor& item + ); + /*! + ensures + - performs BASE_FE.copy_configuration(item) + !*/ + + void copy_configuration ( + const binned_vector_feature_image& item + ); + /*! + ensures + - copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two binned_vector_feature_image + objects H1 and H2, the following sequence of instructions should always + result in both of them having the exact same state. + H2.copy_configuration(H1); + H1.load(img); + H2.load(img); + !*/ + + template < + typename image_type + > + void load ( + const image_type& img + ); + /*! + requires + - image_type == any type that can be supplied to feature_extractor::load() + ensures + - performs BASE_FE.load(img) + i.e. does feature extraction. The features can be accessed using + operator() as defined below. + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns BASE_FE.size() + !*/ + + long nr ( + ) const; + /*! + ensures + - returns BASE_FE.nr() + !*/ + + long nc ( + ) const; + /*! + ensures + - returns BASE_FE.nc() + !*/ + + long get_num_dimensions ( + ) const; + /*! + ensures + - returns the dimensionality of the feature vectors returned by operator(). + In this case, this is the number of hash bins times the dimensionality of + the features produced by BASE_FE plus one. That is, this function + returns get_hash().num_hash_bins()*(BASE_FE.get_num_dimensions()+1) + !*/ + + const descriptor_type& operator() ( + long row, + long col + ) const; + /*! + requires + - 0 <= row < nr() + - 0 <= col < nc() + - It must be legal to evaluate expressions of the form: get_hash()(BASE_FE(row,col)) + (e.g. the hash function must be properly configured to process the feature + vectors produced by the base feature extractor) + ensures + - hashes BASE_FE(row,col) and returns the resulting sparse vector. In + particular, we return a vector that is a copy of BASE_FE(row,col) that + has been shifted into the part of the sparse vector indicated by the hash + function. It will also have a constant bias value of 1 appended to it. + - To be precise, this function returns a sparse vector V such that: + - V.size() == BASE_FE.get_num_dimensions()+1 + - let IDX = get_hash()(BASE_FE(row,col)) + - for i where 0 <= i < BASE_FE.get_num_dimensions(): + - V[i].first == IDX*(BASE_FE.get_num_dimensions()+1) + i + - V[i].second == BASE_FE(row,col)(i) + - V[BASE_FE.get_num_dimensions()].first == IDX*(BASE_FE.get_num_dimensions()+1) + BASE_FE.get_num_dimensions() + - V[BASE_FE.get_num_dimensions()].second == 1 + !*/ + + const rectangle get_block_rect ( + long row, + long col + ) const; + /*! + ensures + - returns BASE_FE.get_block_rect(row,col) + I.e. returns a rectangle that tells you what part of the original image is associated + with a particular feature vector. + !*/ + + const point image_to_feat_space ( + const point& p + ) const; + /*! + ensures + - returns BASE_FE.image_to_feat_space(p) + I.e. Each local feature is extracted from a certain point in the input image. + This function returns the identity of the local feature corresponding + to the image location p. Or in other words, let P == image_to_feat_space(p), + then (*this)(P.y(),P.x()) == the local feature closest to, or centered at, + the point p in the input image. Note that some image points might not have + corresponding feature locations. E.g. border points or points outside the + image. In these cases the returned point will be outside get_rect(*this). + !*/ + + const rectangle image_to_feat_space ( + const rectangle& rect + ) const; + /*! + ensures + - returns BASE_FE.image_to_feat_space(rect) + I.e. returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner())); + (i.e. maps a rectangle from image space to feature space) + !*/ + + const point feat_to_image_space ( + const point& p + ) const; + /*! + ensures + - returns BASE_FE.feat_to_image_space(p) + I.e. returns the location in the input image space corresponding to the center + of the local feature at point p. In other words, this function computes + the inverse of image_to_feat_space(). Note that it may only do so approximately, + since more than one image location might correspond to the same local feature. + That is, image_to_feat_space() might not be invertible so this function gives + the closest possible result. + !*/ + + const rectangle feat_to_image_space ( + const rectangle& rect + ) const; + /*! + ensures + - returns BASE_FE.feat_to_image_space(rect) + I.e. return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner())); + (i.e. maps a rectangle from feature space to image space) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + void serialize ( + const binned_vector_feature_image& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + void deserialize ( + binned_vector_feature_image& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINNED_VECTOR_FEATUrES_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_keypoint/build_separable_poly_filters.h b/ml/dlib/dlib/image_keypoint/build_separable_poly_filters.h new file mode 100644 index 000000000..aea59067d --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/build_separable_poly_filters.h @@ -0,0 +1,186 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BUILD_SEPARABLE_PoLY_FILTERS_Hh_ +#define DLIB_BUILD_SEPARABLE_PoLY_FILTERS_Hh_ + +#include "../matrix.h" +#include "surf.h" +#include "../uintn.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + typedef std::pair, matrix > separable_filter_type; + typedef std::pair, matrix > separable_int32_filter_type; + +// ---------------------------------------------------------------------------------------- + + inline std::vector > build_separable_poly_filters ( + const long order, + const long window_size + ) + /*! + requires + - 1 <= order <= 6 + - window_size >= 3 && window_size is odd + ensures + - the "first" element is the row_filter, the second is the col_filter. + - Some filters are not totally separable and that's why they are grouped + into vectors of vectors. The groups are all the parts of a partially + separable filter. + !*/ + { + long num_filters = 6; + switch (order) + { + case 1: num_filters = 3; break; + case 2: num_filters = 6; break; + case 3: num_filters = 10; break; + case 4: num_filters = 15; break; + case 5: num_filters = 21; break; + case 6: num_filters = 28; break; + } + + matrix X(window_size*window_size,num_filters); + matrix G(window_size*window_size,1); + const double sigma = window_size/4.0; + + + long cnt = 0; + for (double x = -window_size/2; x <= window_size/2; ++x) + { + for (double y = -window_size/2; y <= window_size/2; ++y) + { + X(cnt, 0) = 1; + X(cnt, 1) = x; + X(cnt, 2) = y; + + if (X.nc() > 5) + { + X(cnt, 3) = x*x; + X(cnt, 4) = x*y; + X(cnt, 5) = y*y; + } + if (X.nc() > 9) + { + X(cnt, 6) = x*x*x; + X(cnt, 7) = y*x*x; + X(cnt, 8) = y*y*x; + X(cnt, 9) = y*y*y; + } + if (X.nc() > 14) + { + X(cnt, 10) = x*x*x*x; + X(cnt, 11) = y*x*x*x; + X(cnt, 12) = y*y*x*x; + X(cnt, 13) = y*y*y*x; + X(cnt, 14) = y*y*y*y; + } + if (X.nc() > 20) + { + X(cnt, 15) = x*x*x*x*x; + X(cnt, 16) = y*x*x*x*x; + X(cnt, 17) = y*y*x*x*x; + X(cnt, 18) = y*y*y*x*x; + X(cnt, 19) = y*y*y*y*x; + X(cnt, 20) = y*y*y*y*y; + } + if (X.nc() > 27) + { + X(cnt, 21) = x*x*x*x*x*x; + X(cnt, 22) = y*x*x*x*x*x; + X(cnt, 23) = y*y*x*x*x*x; + X(cnt, 24) = y*y*y*x*x*x; + X(cnt, 25) = y*y*y*y*x*x; + X(cnt, 26) = y*y*y*y*y*x; + X(cnt, 27) = y*y*y*y*y*y; + } + + G(cnt) = std::sqrt(gaussian(x,y,sigma)); + ++cnt; + } + } + + X = diagm(G)*X; + + const matrix S = inv(trans(X)*X)*trans(X)*diagm(G); + + matrix row_filter, col_filter; + + matrix u,v, temp; + matrix w; + + std::vector > results(num_filters); + + for (long r = 0; r < S.nr(); ++r) + { + temp = reshape(rowm(S,r), window_size, window_size); + svd3(temp,u,w,v); + const double thresh = max(w)*1e-8; + for (long i = 0; i < w.size(); ++i) + { + if (w(i) > thresh) + { + col_filter = std::sqrt(w(i))*colm(u,i); + row_filter = std::sqrt(w(i))*colm(v,i); + results[r].push_back(std::make_pair(row_filter, col_filter)); + } + } + } + + return results; + } + +// ---------------------------------------------------------------------------------------- + + inline std::vector > build_separable_int32_poly_filters ( + const long order, + const long window_size, + const double max_range = 300.0 + ) + /*! + requires + - 1 <= order <= 6 + - window_size >= 3 && window_size is odd + - max_range > 1 + ensures + - the "first" element is the row_filter, the second is the col_filter. + !*/ + { + const std::vector >& filters = build_separable_poly_filters(order, window_size); + std::vector > int_filters(filters.size()); + + for (unsigned long i = 0; i < filters.size(); ++i) + { + + double max_val = 0; + for (unsigned long j = 0; j < filters[i].size(); ++j) + { + const separable_filter_type& filt = filters[i][j]; + max_val = std::max(max_val, max(abs(filt.first))); + max_val = std::max(max_val, max(abs(filt.second))); + } + if (max_val == 0) + max_val = 1; + + int_filters[i].resize(filters[i].size()); + for (unsigned long j = 0; j < filters[i].size(); ++j) + { + const separable_filter_type& filt = filters[i][j]; + int_filters[i][j].first = matrix_cast(round(filt.first*max_range/max_val)); + int_filters[i][j].second = matrix_cast(round(filt.second*max_range/max_val)); + } + } + + return int_filters; + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_BUILD_SEPARABLE_PoLY_FILTERS_Hh_ + diff --git a/ml/dlib/dlib/image_keypoint/draw_surf_points.h b/ml/dlib/dlib/image_keypoint/draw_surf_points.h new file mode 100644 index 000000000..b16c28f5d --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/draw_surf_points.h @@ -0,0 +1,40 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DRAW_SURf_POINTS_H_ +#define DLIB_DRAW_SURf_POINTS_H_ + +#include "surf.h" +#include "../gui_widgets.h" +#include "draw_surf_points_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + inline void draw_surf_points ( + image_window& win, + const std::vector& sp + ) + { + for (unsigned long i = 0; i < sp.size(); ++i) + { + const unsigned long radius = static_cast(sp[i].p.scale*3); + const point center(sp[i].p.center); + point direction = center + point(radius,0); + // SURF descriptors are rotated by sp[i].angle. So we want to include a visual + // indication of this rotation on our overlay. + direction = rotate_point(center, direction, sp[i].angle); + + win.add_overlay(image_display::overlay_circle(center, radius, rgb_pixel(0,255,0))); + // Draw a line showing the orientation of the SURF descriptor. + win.add_overlay(center, direction, rgb_pixel(255,0,0)); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DRAW_SURf_POINTS_H_ + + diff --git a/ml/dlib/dlib/image_keypoint/draw_surf_points_abstract.h b/ml/dlib/dlib/image_keypoint/draw_surf_points_abstract.h new file mode 100644 index 000000000..86a66ef49 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/draw_surf_points_abstract.h @@ -0,0 +1,30 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DRAW_SURf_POINTS_ABSTRACT_H_ +#ifdef DLIB_DRAW_SURf_POINTS_ABSTRACT_H_ + +#include "surf_abstract.h" +#include "../gui_widgets.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + void draw_surf_points ( + image_window& win, + const std::vector& sp + ); + /*! + ensures + - draws all the SURF points in sp onto the given image_window. They + are drawn as overlay circles with extra lines to indicate the rotation + of the SURF descriptor. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DRAW_SURf_POINTS_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/image_keypoint/fine_hog_image.h b/ml/dlib/dlib/image_keypoint/fine_hog_image.h new file mode 100644 index 000000000..a421ffe7c --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/fine_hog_image.h @@ -0,0 +1,378 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FINE_HOG_IMaGE_Hh_ +#define DLIB_FINE_HOG_IMaGE_Hh_ + +#include "fine_hog_image_abstract.h" +#include "../array2d.h" +#include "../matrix.h" +#include "hog.h" + + +namespace dlib +{ + template < + unsigned long cell_size_, + unsigned long block_size_, + unsigned long pixel_stride_, + unsigned char num_orientation_bins_, + int gradient_type_ + > + class fine_hog_image : noncopyable + { + COMPILE_TIME_ASSERT(cell_size_ > 1); + COMPILE_TIME_ASSERT(block_size_ > 0); + COMPILE_TIME_ASSERT(pixel_stride_ > 0); + COMPILE_TIME_ASSERT(num_orientation_bins_ > 0); + + COMPILE_TIME_ASSERT( gradient_type_ == hog_signed_gradient || + gradient_type_ == hog_unsigned_gradient); + + + public: + + const static unsigned long cell_size = cell_size_; + const static unsigned long block_size = block_size_; + const static unsigned long pixel_stride = pixel_stride_; + const static unsigned long num_orientation_bins = num_orientation_bins_; + const static int gradient_type = gradient_type_; + + const static long min_size = cell_size*block_size+2; + + typedef matrix descriptor_type; + + fine_hog_image ( + ) : + num_block_rows(0), + num_block_cols(0) + {} + + void clear ( + ) + { + num_block_rows = 0; + num_block_cols = 0; + hist_counts.clear(); + } + + void copy_configuration ( + const fine_hog_image& + ){} + + template < + typename image_type + > + inline void load ( + const image_type& img + ) + { + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + load_impl(mat(img)); + } + + inline void unload( + ) { clear(); } + + inline size_t size ( + ) const { return static_cast(nr()*nc()); } + + inline long nr ( + ) const { return num_block_rows; } + + inline long nc ( + ) const { return num_block_cols; } + + long get_num_dimensions ( + ) const + { + return block_size*block_size*num_orientation_bins; + } + + inline const descriptor_type& operator() ( + long row, + long col + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( 0 <= row && row < nr() && + 0 <= col && col < nc(), + "\t descriptor_type fine_hog_image::operator()()" + << "\n\t invalid row or col argument" + << "\n\t row: " << row + << "\n\t col: " << col + << "\n\t nr(): " << nr() + << "\n\t nc(): " << nc() + << "\n\t this: " << this + ); + + row *= pixel_stride; + col *= pixel_stride; + + des = 0; + unsigned long off = 0; + for (unsigned long r = 0; r < block_size; ++r) + { + for (unsigned long c = 0; c < block_size; ++c) + { + for (unsigned long rr = 0; rr < cell_size; ++rr) + { + for (unsigned long cc = 0; cc < cell_size; ++cc) + { + const histogram_count& hist = hist_counts[row + r*cell_size + rr][col + c*cell_size + cc]; + des(off + hist.quantized_angle_lower) += hist.lower_strength; + des(off + hist.quantized_angle_upper) += hist.upper_strength; + } + } + + off += num_orientation_bins; + } + } + + des /= length(des) + 1e-8; + + return des; + } + + const rectangle get_block_rect ( + long row, + long col + ) const + { + row *= pixel_stride; + col *= pixel_stride; + + // do this to account for the 1 pixel padding we use all around the image + ++row; + ++col; + + return rectangle(col, row, col+cell_size*block_size-1, row+cell_size*block_size-1); + } + + const point image_to_feat_space ( + const point& p + ) const + { + const long border_size = 1 + cell_size*block_size/2; + return (p-point(border_size,border_size))/(long)pixel_stride; + } + + const rectangle image_to_feat_space ( + const rectangle& rect + ) const + { + return rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner())); + } + + const point feat_to_image_space ( + const point& p + ) const + { + const long border_size = 1 + cell_size*block_size/2; + return p*(long)pixel_stride + point(border_size,border_size); + } + + const rectangle feat_to_image_space ( + const rectangle& rect + ) const + { + return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner())); + } + + + + // these _PRIVATE_ functions are only here as a workaround for a bug in visual studio 2005. + void _PRIVATE_serialize (std::ostream& out) const + { + // serialize hist_counts + serialize(hist_counts.nc(),out); + serialize(hist_counts.nr(),out); + hist_counts.reset(); + while (hist_counts.move_next()) + hist_counts.element().serialize(out); + hist_counts.reset(); + + + serialize(num_block_rows, out); + serialize(num_block_cols, out); + } + + void _PRIVATE_deserialize (std::istream& in ) + { + // deserialize item.hist_counts + long nc, nr; + deserialize(nc,in); + deserialize(nr,in); + hist_counts.set_size(nr,nc); + while (hist_counts.move_next()) + hist_counts.element().deserialize(in); + hist_counts.reset(); + + + deserialize(num_block_rows, in); + deserialize(num_block_cols, in); + } + + private: + + template < + typename image_type + > + void load_impl ( + const image_type& img + ) + { + // Note that we keep a border of 1 pixel all around the image so that we don't have + // to worry about running outside the image when computing the horizontal and vertical + // gradients. + + + + // check if the window is just too small + if (img.nr() < min_size || img.nc() < min_size) + { + // If the image is smaller than our windows then there aren't any descriptors at all! + num_block_rows = 0; + num_block_cols = 0; + hist_counts.clear(); + return; + } + + hist_counts.set_size(img.nr()-2, img.nc()-2); + + + + + for (long r = 0; r < hist_counts.nr(); ++r) + { + for (long c = 0; c < hist_counts.nc(); ++c) + { + unsigned long left; + unsigned long right; + unsigned long top; + unsigned long bottom; + + assign_pixel(left, img(r+1,c)); + assign_pixel(right, img(r+1,c+2)); + assign_pixel(top, img(r ,c+1)); + assign_pixel(bottom, img(r+2,c+1)); + + double grad_x = (long)right-(long)left; + double grad_y = (long)top-(long)bottom; + + // obtain the angle of the gradient. Make sure it is scaled between 0 and 1. + double angle = std::max(0.0, std::atan2(grad_y, grad_x)/pi + 1)/2; + + + if (gradient_type == hog_unsigned_gradient) + { + angle *= 2; + if (angle >= 1) + angle -= 1; + } + + + // now scale angle to between 0 and num_orientation_bins + angle *= num_orientation_bins; + + + const double strength = std::sqrt(grad_y*grad_y + grad_x*grad_x); + + + unsigned char quantized_angle_lower = static_cast(std::floor(angle)); + unsigned char quantized_angle_upper = static_cast(std::ceil(angle)); + + quantized_angle_lower %= num_orientation_bins; + quantized_angle_upper %= num_orientation_bins; + + const double angle_split = (angle-std::floor(angle)); + const double upper_strength = angle_split*strength; + const double lower_strength = (1-angle_split)*strength; + + // Stick into gradient counts. Note that we linearly interpolate between neighboring + // histogram buckets. + hist_counts[r][c].quantized_angle_lower = quantized_angle_lower; + hist_counts[r][c].quantized_angle_upper = quantized_angle_upper; + hist_counts[r][c].lower_strength = lower_strength; + hist_counts[r][c].upper_strength = upper_strength; + + } + } + + + // Now figure out how many feature extraction blocks we should have. + num_block_rows = (hist_counts.nr() - block_size*cell_size + 1)/(long)pixel_stride; + num_block_cols = (hist_counts.nc() - block_size*cell_size + 1)/(long)pixel_stride; + + } + + struct histogram_count + { + unsigned char quantized_angle_lower; + unsigned char quantized_angle_upper; + float lower_strength; + float upper_strength; + + void serialize(std::ostream& out) const + { + dlib::serialize(quantized_angle_lower, out); + dlib::serialize(quantized_angle_upper, out); + dlib::serialize(lower_strength, out); + dlib::serialize(upper_strength, out); + } + void deserialize(std::istream& in) + { + dlib::deserialize(quantized_angle_lower, in); + dlib::deserialize(quantized_angle_upper, in); + dlib::deserialize(lower_strength, in); + dlib::deserialize(upper_strength, in); + } + }; + + array2d hist_counts; + + mutable descriptor_type des; + + long num_block_rows; + long num_block_cols; + + + }; + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long T1, + unsigned long T2, + unsigned long T3, + unsigned char T4, + int T5 + > + void serialize ( + const fine_hog_image& item, + std::ostream& out + ) + { + item._PRIVATE_serialize(out); + } + + template < + unsigned long T1, + unsigned long T2, + unsigned long T3, + unsigned char T4, + int T5 + > + void deserialize ( + fine_hog_image& item, + std::istream& in + ) + { + item._PRIVATE_deserialize(in); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FINE_HOG_IMaGE_Hh_ + diff --git a/ml/dlib/dlib/image_keypoint/fine_hog_image_abstract.h b/ml/dlib/dlib/image_keypoint/fine_hog_image_abstract.h new file mode 100644 index 000000000..50be85afe --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/fine_hog_image_abstract.h @@ -0,0 +1,276 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FINE_HOG_IMaGE_ABSTRACT_Hh_ +#ifdef DLIB_FINE_HOG_IMaGE_ABSTRACT_Hh_ + +#include "../array2d.h" +#include "../matrix.h" +#include "hog_abstract.h" + + +namespace dlib +{ + template < + unsigned long cell_size_, + unsigned long block_size_, + unsigned long pixel_stride_, + unsigned char num_orientation_bins_, + int gradient_type_ + > + class fine_hog_image : noncopyable + { + /*! + REQUIREMENTS ON TEMPLATE PARAMETERS + - cell_size_ > 1 + - block_size_ > 0 + - pixel_stride_ > 0 + - num_orientation_bins_ > 0 + - gradient_type_ == hog_signed_gradient or hog_unsigned_gradient + + INITIAL VALUE + - size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object is a version of the hog_image that allows you to extract HOG + features at a finer resolution. The hog_image can only extract HOG features + cell_size_ pixels apart. However, this object, the fine_hog_image can + extract HOG features from every pixel location. + + The template arguments to this class have the same meaning as they do for + the hog_image, except for pixel_stride_. This controls the stepping between + HOG extraction locations. A value of 1 indicates HOG features should be + extracted from every pixel location. A value of 2 indicates every other pixel + location, etc. + + Finally, note that the interpolation used by this object is equivalent + to using hog_angle_interpolation with hog_image. + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be protected + by a mutex lock except for the case where you are copying the configuration + (via copy_configuration()) of a fine_hog_image object to many other threads. + In this case, it is safe to copy the configuration of a shared object so long + as no other operations are performed on it. + !*/ + + public: + + const static unsigned long cell_size = cell_size_; + const static unsigned long block_size = block_size_; + const static unsigned long pixel_stride = pixel_stride_; + const static unsigned long num_orientation_bins = num_orientation_bins_; + const static int gradient_type = gradient_type_; + + const static long min_size = cell_size*block_size+2; + + typedef matrix descriptor_type; + + fine_hog_image ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear ( + ); + /*! + ensures + - this object will have its initial value + !*/ + + void copy_configuration ( + const fine_hog_image& + ); + /*! + ensures + - copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two fine_hog_image + objects H1 and H2, the following sequence of instructions should always + result in both of them having the exact same state. + H2.copy_configuration(H1); + H1.load(img); + H2.load(img); + !*/ + + template < + typename image_type + > + inline void load ( + const image_type& img + ); + /*! + requires + - image_type is a dlib::matrix or something convertible to a matrix + via mat() + - pixel_traits::pixel_type>::has_alpha == false + ensures + - if (img.nr() < min_size || img.nc() < min_size) then + - the image is too small so we don't compute anything on it + - #size() == 0 + - else + - generates a HOG image from the given image. + - #size() > 0 + !*/ + + inline void unload( + ); + /*! + ensures + - #nr() == 0 + - #nc() == 0 + - clears only the state information which is populated by load(). For + example, let H be a fine_hog_image object. Then consider the two + sequences of instructions: + Sequence 1: + H.load(img); + H.unload(); + H.load(img); + + Sequence 2: + H.load(img); + Both sequence 1 and sequence 2 should have the same effect on H. + !*/ + + inline size_t size ( + ) const; + /*! + ensures + - returns nr()*nc() + !*/ + + inline long nr ( + ) const; + /*! + ensures + - returns the number of rows in this HOG image + !*/ + + inline long nc ( + ) const; + /*! + ensures + - returns the number of columns in this HOG image + !*/ + + long get_num_dimensions ( + ) const; + /*! + ensures + - returns the number of dimensions in the feature vectors generated by + this object. + - In particular, returns the value block_size*block_size*num_orientation_bins + !*/ + + inline const descriptor_type& operator() ( + long row, + long col + ) const; + /*! + requires + - 0 <= row < nr() + - 0 <= col < nc() + ensures + - returns the descriptor for the HOG block at the given row and column. This descriptor + will include information from a window that is located at get_block_rect(row,col) in + the original image given to load(). + - The returned descriptor vector will have get_num_dimensions() elements. + !*/ + + const rectangle get_block_rect ( + long row, + long col + ) const; + /*! + ensures + - returns a rectangle that tells you what part of the original image is associated + with a particular HOG block. That is, what part of the input image is associated + with (*this)(row,col). + - The returned rectangle will be cell_size*block_size pixels wide and tall. + !*/ + + const point image_to_feat_space ( + const point& p + ) const; + /*! + ensures + - Each local feature is extracted from a certain point in the input image. + This function returns the identity of the local feature corresponding + to the image location p. Or in other words, let P == image_to_feat_space(p), + then (*this)(P.y(),P.x()) == the local feature closest to, or centered at, + the point p in the input image. Note that some image points might not have + corresponding feature locations. E.g. border points or points outside the + image. In these cases the returned point will be outside get_rect(*this). + !*/ + + const rectangle image_to_feat_space ( + const rectangle& rect + ) const; + /*! + ensures + - returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner())); + (i.e. maps a rectangle from image space to feature space) + !*/ + + const point feat_to_image_space ( + const point& p + ) const; + /*! + ensures + - returns the location in the input image space corresponding to the center + of the local feature at point p. In other words, this function computes + the inverse of image_to_feat_space(). Note that it may only do so approximately, + since more than one image location might correspond to the same local feature. + That is, image_to_feat_space() might not be invertible so this function gives + the closest possible result. + !*/ + + const rectangle feat_to_image_space ( + const rectangle& rect + ) const; + /*! + ensures + - return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner())); + (i.e. maps a rectangle from feature space to image space) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long T1, + unsigned long T2, + unsigned long T3, + unsigned char T4, + int T5 + > + void serialize ( + const fine_hog_image& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + unsigned long T1, + unsigned long T2, + unsigned long T3, + unsigned char T4, + int T5 + > + void deserialize ( + fine_hog_image& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FINE_HOG_IMaGE_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_keypoint/hashed_feature_image.h b/ml/dlib/dlib/image_keypoint/hashed_feature_image.h new file mode 100644 index 000000000..80f429330 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/hashed_feature_image.h @@ -0,0 +1,518 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HASHED_IMAGE_FEATUrES_Hh_ +#define DLIB_HASHED_IMAGE_FEATUrES_Hh_ + +#include "../lsh/projection_hash.h" +#include "hashed_feature_image_abstract.h" +#include +#include "../algs.h" +#include "../matrix.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type_ = projection_hash + > + class hashed_feature_image : noncopyable + { + + public: + typedef feature_extractor feature_extractor_type; + typedef hash_function_type_ hash_function_type; + + typedef std::vector > descriptor_type; + + hashed_feature_image ( + ); + + void clear ( + ); + + void set_hash ( + const hash_function_type& hash_ + ); + + const hash_function_type& get_hash ( + ) const; + + void copy_configuration ( + const feature_extractor& item + ); + + void copy_configuration ( + const hashed_feature_image& item + ); + + template < + typename image_type + > + inline void load ( + const image_type& img + ); + + inline size_t size ( + ) const; + + inline long nr ( + ) const; + + inline long nc ( + ) const; + + inline long get_num_dimensions ( + ) const; + + void use_relative_feature_weights ( + ); + + void use_uniform_feature_weights ( + ); + + bool uses_uniform_feature_weights ( + ) const; + + inline const descriptor_type& operator() ( + long row, + long col + ) const; + + inline const rectangle get_block_rect ( + long row, + long col + ) const; + + inline const point image_to_feat_space ( + const point& p + ) const; + + inline const rectangle image_to_feat_space ( + const rectangle& rect + ) const; + + inline const point feat_to_image_space ( + const point& p + ) const; + + inline const rectangle feat_to_image_space ( + const rectangle& rect + ) const; + + template + friend void serialize ( + const hashed_feature_image& item, + std::ostream& out + ); + + template + friend void deserialize ( + hashed_feature_image& item, + std::istream& in + ); + + private: + + array2d feats; + feature_extractor fe; + hash_function_type phash; + std::vector feat_counts; + bool uniform_feature_weights; + + + // This is a transient variable. It is just here so it doesn't have to be + // reallocated over and over inside operator() + mutable descriptor_type hash_feats; + + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const hashed_feature_image& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + serialize(item.feats, out); + serialize(item.fe, out); + serialize(item.phash, out); + serialize(item.feat_counts, out); + serialize(item.uniform_feature_weights, out); + } + + template + void deserialize ( + hashed_feature_image& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing a dlib::hashed_feature_image object."); + + deserialize(item.feats, in); + deserialize(item.fe, in); + deserialize(item.phash, in); + deserialize(item.feat_counts, in); + deserialize(item.uniform_feature_weights, in); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// hashed_feature_image member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + hashed_feature_image:: + hashed_feature_image ( + ) + { + clear(); + hash_feats.resize(1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void hashed_feature_image:: + clear ( + ) + { + fe.clear(); + phash = hash_function_type(); + feats.clear(); + feat_counts.clear(); + uniform_feature_weights = false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void hashed_feature_image:: + set_hash ( + const hash_function_type& hash_ + ) + { + phash = hash_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const hash_function_type& hashed_feature_image:: + get_hash ( + ) const + { + return phash; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void hashed_feature_image:: + copy_configuration ( + const feature_extractor& item + ) + { + fe.copy_configuration(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void hashed_feature_image:: + copy_configuration ( + const hashed_feature_image& item + ) + { + fe.copy_configuration(item.fe); + phash = item.phash; + uniform_feature_weights = item.uniform_feature_weights; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + template < + typename image_type + > + void hashed_feature_image:: + load ( + const image_type& img + ) + { + fe.load(img); + + if (fe.size() != 0) + { + feats.set_size(fe.nr(), fe.nc()); + feat_counts.assign(phash.num_hash_bins(),1); + if (uniform_feature_weights) + { + for (long r = 0; r < feats.nr(); ++r) + { + for (long c = 0; c < feats.nc(); ++c) + { + feats[r][c] = phash(fe(r,c)); + } + } + } + else + { + for (long r = 0; r < feats.nr(); ++r) + { + for (long c = 0; c < feats.nc(); ++c) + { + feats[r][c] = phash(fe(r,c)); + feat_counts[feats[r][c]]++; + } + } + } + } + else + { + feats.set_size(0,0); + } + + if (!uniform_feature_weights) + { + // use the inverse frequency as the scale for each feature. We also scale + // these counts so that they are invariant to the size of the image (we scale + // them so they all look like they come from a 500x400 images). + const double scale = image_size(img)/(500.0*400.0); + for (unsigned long i = 0; i < feat_counts.size(); ++i) + { + feat_counts[i] = scale/feat_counts[i]; + } + } + + fe.unload(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + size_t hashed_feature_image:: + size ( + ) const + { + return feats.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + long hashed_feature_image:: + nr ( + ) const + { + return feats.nr(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + long hashed_feature_image:: + nc ( + ) const + { + return feats.nc(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + long hashed_feature_image:: + get_num_dimensions ( + ) const + { + return phash.num_hash_bins(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void hashed_feature_image:: + use_relative_feature_weights ( + ) + { + uniform_feature_weights = false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + void hashed_feature_image:: + use_uniform_feature_weights ( + ) + { + uniform_feature_weights = true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + bool hashed_feature_image:: + uses_uniform_feature_weights ( + ) const + { + return uniform_feature_weights; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const std::vector >& hashed_feature_image:: + operator() ( + long row, + long col + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= row && row < nr() && + 0 <= col && col < nc(), + "\t descriptor_type hashed_feature_image::operator(row,col)" + << "\n\t Invalid inputs were given to this function" + << "\n\t row: " << row + << "\n\t col: " << col + << "\n\t nr(): " << nr() + << "\n\t nc(): " << nc() + << "\n\t this: " << this + ); + + hash_feats[0] = std::make_pair(feats[row][col],feat_counts[feats[row][col]]); + return hash_feats; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const rectangle hashed_feature_image:: + get_block_rect ( + long row, + long col + ) const + { + return fe.get_block_rect(row,col); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const point hashed_feature_image:: + image_to_feat_space ( + const point& p + ) const + { + return fe.image_to_feat_space(p); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const rectangle hashed_feature_image:: + image_to_feat_space ( + const rectangle& rect + ) const + { + return fe.image_to_feat_space(rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const point hashed_feature_image:: + feat_to_image_space ( + const point& p + ) const + { + return fe.feat_to_image_space(p); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type + > + const rectangle hashed_feature_image:: + feat_to_image_space ( + const rectangle& rect + ) const + { + return fe.feat_to_image_space(rect); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HASHED_IMAGE_FEATUrES_Hh_ + + diff --git a/ml/dlib/dlib/image_keypoint/hashed_feature_image_abstract.h b/ml/dlib/dlib/image_keypoint/hashed_feature_image_abstract.h new file mode 100644 index 000000000..90c1348c5 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/hashed_feature_image_abstract.h @@ -0,0 +1,303 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_HASHED_IMAGE_FEATUrES_ABSTRACT_Hh_ +#ifdef DLIB_HASHED_IMAGE_FEATUrES_ABSTRACT_Hh_ + +#include "../lsh/projection_hash_abstract.h" +#include +#include "../matrix.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor, + typename hash_function_type_ = projection_hash + > + class hashed_feature_image : noncopyable + { + /*! + REQUIREMENTS ON feature_extractor + - must be an object with an interface compatible with dlib::hog_image + + REQUIREMENTS ON hash_function_type_ + - must be an object with an interface compatible with projection_hash + + INITIAL VALUE + - size() == 0 + - uses_uniform_feature_weights() == false + + WHAT THIS OBJECT REPRESENTS + This object is a tool for performing image feature extraction. In + particular, it wraps another image feature extractor and converts the + wrapped image feature vectors into sparse indicator vectors. It does this + by hashing each feature vector into the range [0, get_num_dimensions()-1] + and then returns a new vector which is zero everywhere except for the + position determined by the hash. + + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be protected + by a mutex lock except for the case where you are copying the configuration + (via copy_configuration()) of a hashed_feature_image object to many other threads. + In this case, it is safe to copy the configuration of a shared object so long + as no other operations are performed on it. + + + NOTATION + let BASE_FE denote the base feature_extractor object contained inside + the hashed_feature_image. + !*/ + + public: + + typedef feature_extractor feature_extractor_type; + typedef hash_function_type_ hash_function_type; + typedef std::vector > descriptor_type; + + hashed_feature_image ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear ( + ); + /*! + ensures + - this object will have its initial value + !*/ + + void set_hash ( + const hash_function_type& hash + ); + /*! + ensures + - #get_hash() == hash + !*/ + + const hash_function_type& get_hash ( + ) const; + /*! + ensures + - returns the hash function used by this object to hash + base feature vectors into integers. + !*/ + + void copy_configuration ( + const feature_extractor& item + ); + /*! + ensures + - performs BASE_FE.copy_configuration(item) + !*/ + + void copy_configuration ( + const hashed_feature_image& item + ); + /*! + ensures + - copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two hashed_feature_image + objects H1 and H2, the following sequence of instructions should always + result in both of them having the exact same state. + H2.copy_configuration(H1); + H1.load(img); + H2.load(img); + !*/ + + template < + typename image_type + > + void load ( + const image_type& img + ); + /*! + requires + - image_type == any type that can be supplied to feature_extractor::load() + ensures + - performs BASE_FE.load(img) + i.e. does feature extraction. The features can be accessed using + operator() as defined below. + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns BASE_FE.size() + !*/ + + long nr ( + ) const; + /*! + ensures + - returns BASE_FE.nr() + !*/ + + long nc ( + ) const; + /*! + ensures + - returns BASE_FE.nc() + !*/ + + long get_num_dimensions ( + ) const; + /*! + ensures + - returns the dimensionality of the feature vectors returned by operator(). + In this case, this is the number of hash bins. That is, get_hash().num_hash_bins() + !*/ + + void use_relative_feature_weights ( + ); + /*! + ensures + - #uses_uniform_feature_weights() == false + !*/ + + void use_uniform_feature_weights ( + ); + /*! + ensures + - #uses_uniform_feature_weights() == true + !*/ + + bool uses_uniform_feature_weights ( + ) const; + /*! + ensures + - returns true if this object weights each feature with a value of 1 and + false if it uses a weighting of 1/N where N is the number of occurrences + of the feature in an image (note that we normalize N so that it is + invariant to the size of the image given to load()). + !*/ + + const descriptor_type& operator() ( + long row, + long col + ) const; + /*! + requires + - 0 <= row < nr() + - 0 <= col < nc() + - It must be legal to evaluate expressions of the form: get_hash()(BASE_FE(row,col)) + (e.g. the hash function must be properly configured to process the feature + vectors produced by the base feature extractor) + ensures + - hashes BASE_FE(row,col) and returns the resulting indicator vector. + - To be precise, this function returns a sparse vector V such that: + - V.size() == 1 + - V[0].first == get_hash()(BASE_FE(row,col)) + - if (uses_uniform_feature_weights()) then + - V[0].second == 1 + - else + - V[0].second == 1/N where N is the number of times a feature in + hash bin V[0].first was observed in the image given to load(). + Note that we scale all the counts so that they are invariant to + the size of the image. + !*/ + + const rectangle get_block_rect ( + long row, + long col + ) const; + /*! + ensures + - returns BASE_FE.get_block_rect(row,col) + I.e. returns a rectangle that tells you what part of the original image is associated + with a particular feature vector. + !*/ + + const point image_to_feat_space ( + const point& p + ) const; + /*! + ensures + - returns BASE_FE.image_to_feat_space(p) + I.e. Each local feature is extracted from a certain point in the input image. + This function returns the identity of the local feature corresponding + to the image location p. Or in other words, let P == image_to_feat_space(p), + then (*this)(P.y(),P.x()) == the local feature closest to, or centered at, + the point p in the input image. Note that some image points might not have + corresponding feature locations. E.g. border points or points outside the + image. In these cases the returned point will be outside get_rect(*this). + !*/ + + const rectangle image_to_feat_space ( + const rectangle& rect + ) const; + /*! + ensures + - returns BASE_FE.image_to_feat_space(rect) + I.e. returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner())); + (i.e. maps a rectangle from image space to feature space) + !*/ + + const point feat_to_image_space ( + const point& p + ) const; + /*! + ensures + - returns BASE_FE.feat_to_image_space(p) + I.e. returns the location in the input image space corresponding to the center + of the local feature at point p. In other words, this function computes + the inverse of image_to_feat_space(). Note that it may only do so approximately, + since more than one image location might correspond to the same local feature. + That is, image_to_feat_space() might not be invertible so this function gives + the closest possible result. + !*/ + + const rectangle feat_to_image_space ( + const rectangle& rect + ) const; + /*! + ensures + - returns BASE_FE.feat_to_image_space(rect) + I.e. return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner())); + (i.e. maps a rectangle from feature space to image space) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + void serialize ( + const hashed_feature_image& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + void deserialize ( + hashed_feature_image& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HASHED_IMAGE_FEATUrES_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/image_keypoint/hessian_pyramid.h b/ml/dlib/dlib/image_keypoint/hessian_pyramid.h new file mode 100644 index 000000000..2e672c0d0 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/hessian_pyramid.h @@ -0,0 +1,531 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HESSIAN_PYRAMId_Hh_ +#define DLIB_HESSIAN_PYRAMId_Hh_ + +#include "hessian_pyramid_abstract.h" +#include "../algs.h" +#include "../image_transforms/integral_image.h" +#include "../array.h" +#include "../array2d.h" +#include "../noncopyable.h" +#include "../matrix.h" +#include "../stl_checked.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct interest_point + { + interest_point() : scale(0), score(0), laplacian(0) {} + + dlib::vector center; + double scale; + double score; + double laplacian; + + bool operator < (const interest_point& p) const { return score < p.score; } + }; + +// ---------------------------------------------------------------------------------------- + + inline void serialize( + const interest_point& item, + std::ostream& out + ) + { + try + { + serialize(item.center,out); + serialize(item.scale,out); + serialize(item.score,out); + serialize(item.laplacian,out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type interest_point"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void deserialize( + interest_point& item, + std::istream& in + ) + { + try + { + deserialize(item.center,in); + deserialize(item.scale,in); + deserialize(item.score,in); + deserialize(item.laplacian,in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type interest_point"); + } + } + +// ---------------------------------------------------------------------------------------- + + class hessian_pyramid : noncopyable + { + public: + hessian_pyramid() + { + num_octaves = 0; + num_intervals = 0; + initial_step_size = 0; + } + + template + void build_pyramid ( + const integral_image_type& img, + long num_octaves, + long num_intervals, + long initial_step_size + ) + { + DLIB_ASSERT(num_octaves > 0 && num_intervals > 0 && initial_step_size > 0, + "\tvoid build_pyramid()" + << "\n\tAll arguments to this function must be > 0" + << "\n\t this: " << this + << "\n\t num_octaves: " << num_octaves + << "\n\t num_intervals: " << num_intervals + << "\n\t initial_step_size: " << initial_step_size + ); + + this->num_octaves = num_octaves; + this->num_intervals = num_intervals; + this->initial_step_size = initial_step_size; + + // allocate space for the pyramid + pyramid.resize(num_octaves*num_intervals); + for (long o = 0; o < num_octaves; ++o) + { + const long step_size = get_step_size(o); + for (long i = 0; i < num_intervals; ++i) + { + pyramid[num_intervals*o + i].set_size(img.nr()/step_size, img.nc()/step_size); + } + } + + // now fill out the pyramid with data + for (long o = 0; o < num_octaves; ++o) + { + const long step_size = get_step_size(o); + + for (long i = 0; i < num_intervals; ++i) + { + const long border_size = get_border_size(i)*step_size; + const long lobe_size = static_cast(std::pow(2.0, o+1.0)+0.5)*(i+1) + 1; + const double area_inv = 1.0/std::pow(3.0*lobe_size, 2.0); + + const long lobe_offset = lobe_size/2+1; + const point tl(-lobe_offset,-lobe_offset); + const point tr(lobe_offset,-lobe_offset); + const point bl(-lobe_offset,lobe_offset); + const point br(lobe_offset,lobe_offset); + + for (long r = border_size; r < img.nr() - border_size; r += step_size) + { + for (long c = border_size; c < img.nc() - border_size; c += step_size) + { + const point p(c,r); + + double Dxx = img.get_sum_of_area(centered_rect(p, lobe_size*3, 2*lobe_size-1)) - + img.get_sum_of_area(centered_rect(p, lobe_size, 2*lobe_size-1))*3.0; + + double Dyy = img.get_sum_of_area(centered_rect(p, 2*lobe_size-1, lobe_size*3)) - + img.get_sum_of_area(centered_rect(p, 2*lobe_size-1, lobe_size))*3.0; + + double Dxy = img.get_sum_of_area(centered_rect(p+bl, lobe_size, lobe_size)) + + img.get_sum_of_area(centered_rect(p+tr, lobe_size, lobe_size)) - + img.get_sum_of_area(centered_rect(p+tl, lobe_size, lobe_size)) - + img.get_sum_of_area(centered_rect(p+br, lobe_size, lobe_size)); + + // now we normalize the filter responses + Dxx *= area_inv; + Dyy *= area_inv; + Dxy *= area_inv; + + + double sign_of_laplacian = +1; + if (Dxx + Dyy < 0) + sign_of_laplacian = -1; + + double determinant = Dxx*Dyy - 0.81*Dxy*Dxy; + + // If the determinant is negative then just blank it out by setting + // it to zero. + if (determinant < 0) + determinant = 0; + + // Save the determinant of the Hessian into our image pyramid. Also + // pack the laplacian sign into the value so we can get it out later. + pyramid[o*num_intervals + i][r/step_size][c/step_size] = sign_of_laplacian*determinant; + + } + } + + } + } + } + + long get_border_size ( + long interval + ) const + { + DLIB_ASSERT(0 <= interval && interval < intervals(), + "\tlong get_border_size(interval)" + << "\n\tInvalid interval value" + << "\n\t this: " << this + << "\n\t interval: " << interval + ); + + const double lobe_size = 2.0*(interval+1) + 1; + const double filter_size = 3*lobe_size; + + const long bs = static_cast(std::ceil(filter_size/2.0)); + return bs; + } + + long get_step_size ( + long octave + ) const + { + DLIB_ASSERT(0 <= octave && octave < octaves(), + "\tlong get_step_size(octave)" + << "\n\tInvalid octave value" + << "\n\t this: " << this + << "\n\t octave: " << octave + ); + + return initial_step_size*static_cast(std::pow(2.0, (double)octave)+0.5); + } + + long nr ( + long octave + ) const + { + DLIB_ASSERT(0 <= octave && octave < octaves(), + "\tlong nr(octave)" + << "\n\tInvalid octave value" + << "\n\t this: " << this + << "\n\t octave: " << octave + ); + + return pyramid[num_intervals*octave].nr(); + } + + long nc ( + long octave + ) const + { + DLIB_ASSERT(0 <= octave && octave < octaves(), + "\tlong nc(octave)" + << "\n\tInvalid octave value" + << "\n\t this: " << this + << "\n\t octave: " << octave + ); + + return pyramid[num_intervals*octave].nc(); + } + + double get_value ( + long octave, + long interval, + long r, + long c + ) const + { + DLIB_ASSERT(0 <= octave && octave < octaves() && + 0 <= interval && interval < intervals() && + get_border_size(interval) <= r && r < nr(octave)-get_border_size(interval) && + get_border_size(interval) <= c && c < nc(octave)-get_border_size(interval), + "\tdouble get_value(octave, interval, r, c)" + << "\n\tInvalid inputs to this function" + << "\n\t this: " << this + << "\n\t octave: " << octave + << "\n\t interval: " << interval + << "\n\t octaves: " << octaves() + << "\n\t intervals: " << intervals() + << "\n\t r: " << r + << "\n\t c: " << c + << "\n\t nr(octave): " << nr(octave) + << "\n\t nc(octave): " << nc(octave) + << "\n\t get_border_size(interval): " << get_border_size(interval) + ); + + return std::abs(pyramid[num_intervals*octave + interval][r][c]); + } + + double get_laplacian ( + long octave, + long interval, + long r, + long c + ) const + { + DLIB_ASSERT(0 <= octave && octave < octaves() && + 0 <= interval && interval < intervals() && + get_border_size(interval) <= r && r < nr(octave)-get_border_size(interval) && + get_border_size(interval) <= c && c < nc(octave)-get_border_size(interval), + "\tdouble get_laplacian(octave, interval, r, c)" + << "\n\tInvalid inputs to this function" + << "\n\t this: " << this + << "\n\t octave: " << octave + << "\n\t interval: " << interval + << "\n\t octaves: " << octaves() + << "\n\t intervals: " << intervals() + << "\n\t r: " << r + << "\n\t c: " << c + << "\n\t nr(octave): " << nr(octave) + << "\n\t nc(octave): " << nc(octave) + << "\n\t get_border_size(interval): " << get_border_size(interval) + ); + + // return the sign of the laplacian + if (pyramid[num_intervals*octave + interval][r][c] > 0) + return +1; + else + return -1; + } + + long octaves ( + ) const { return num_octaves; } + + long intervals ( + ) const { return num_intervals; } + + private: + + long num_octaves; + long num_intervals; + long initial_step_size; + + typedef array2d image_type; + typedef array pyramid_type; + + pyramid_type pyramid; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace hessian_pyramid_helpers + { + inline bool is_maximum_in_region( + const hessian_pyramid& pyr, + long o, + long i, + long r, + long c + ) + { + // First check if this point is near the edge of the octave + // If it is then we say it isn't a maximum as these points are + // not as reliable. + if (i <= 0 || i+1 >= pyr.intervals()) + { + return false; + } + + const double val = pyr.get_value(o,i,r,c); + + // now check if there are any bigger values around this guy + for (long ii = i-1; ii <= i+1; ++ii) + { + for (long rr = r-1; rr <= r+1; ++rr) + { + for (long cc = c-1; cc <= c+1; ++cc) + { + if (pyr.get_value(o,ii,rr,cc) > val) + return false; + } + } + } + + return true; + } + + // ------------------------------------------------------------------------------------ + + inline const matrix get_hessian_gradient ( + const hessian_pyramid& pyr, + long o, + long i, + long r, + long c + ) + { + matrix grad; + grad(0) = (pyr.get_value(o,i,r,c+1) - pyr.get_value(o,i,r,c-1))/2.0; + grad(1) = (pyr.get_value(o,i,r+1,c) - pyr.get_value(o,i,r-1,c))/2.0; + grad(2) = (pyr.get_value(o,i+1,r,c) - pyr.get_value(o,i-1,r,c))/2.0; + return grad; + } + + // ------------------------------------------------------------------------------------ + + inline const matrix get_hessian_hessian ( + const hessian_pyramid& pyr, + long o, + long i, + long r, + long c + ) + { + matrix hess; + const double val = pyr.get_value(o,i,r,c); + + double Dxx = (pyr.get_value(o,i,r,c+1) + pyr.get_value(o,i,r,c-1)) - 2*val; + double Dyy = (pyr.get_value(o,i,r+1,c) + pyr.get_value(o,i,r-1,c)) - 2*val; + double Dss = (pyr.get_value(o,i+1,r,c) + pyr.get_value(o,i-1,r,c)) - 2*val; + + double Dxy = (pyr.get_value(o,i,r+1,c+1) + pyr.get_value(o,i,r-1,c-1) - + pyr.get_value(o,i,r-1,c+1) - pyr.get_value(o,i,r+1,c-1)) / 4.0; + + double Dxs = (pyr.get_value(o,i+1,r,c+1) + pyr.get_value(o,i-1,r,c-1) - + pyr.get_value(o,i-1,r,c+1) - pyr.get_value(o,i+1,r,c-1)) / 4.0; + + double Dys = (pyr.get_value(o,i+1,r+1,c) + pyr.get_value(o,i-1,r-1,c) - + pyr.get_value(o,i-1,r+1,c) - pyr.get_value(o,i+1,r-1,c)) / 4.0; + + + hess = Dxx, Dxy, Dxs, + Dxy, Dyy, Dys, + Dxs, Dys, Dss; + + return hess; + } + + // ------------------------------------------------------------------------------------ + + inline const interest_point interpolate_point ( + const hessian_pyramid& pyr, + long o, + long i, + long r, + long c + ) + { + dlib::vector p(c,r); + + dlib::vector start_point(c,r,i); + dlib::vector interpolated_point = -inv(get_hessian_hessian(pyr,o,i,r,c))*get_hessian_gradient(pyr,o,i,r,c); + + //cout << "inter: " << trans(interpolated_point); + + interest_point temp; + if (max(abs(interpolated_point)) < 0.5) + { + p = (start_point+interpolated_point)*pyr.get_step_size(o); + const double lobe_size = std::pow(2.0, o+1.0)*(i+interpolated_point.z()+1) + 1; + const double filter_size = 3*lobe_size; + const double scale = 1.2/9.0 * filter_size; + + temp.center = p; + temp.scale = scale; + temp.score = pyr.get_value(o,i,r,c); + temp.laplacian = pyr.get_laplacian(o,i,r,c); + } + else + { + // this indicates to the caller that no interest point was found. + temp.score = -1; + } + + return temp; + } + + } + +// ---------------------------------------------------------------------------------------- + + template + void get_interest_points ( + const hessian_pyramid& pyr, + double threshold, + std::vector& result_points + ) + { + DLIB_ASSERT(threshold >= 0, + "\tvoid get_interest_points()" + << "\n\t Invalid arguments to this function" + << "\n\t threshold: " << threshold + ); + using namespace std; + using namespace hessian_pyramid_helpers; + + result_points.clear(); + + for (long o = 0; o < pyr.octaves(); ++o) + { + const long nr = pyr.nr(o); + const long nc = pyr.nc(o); + + // do non-maximum suppression on all the intervals in the current octave and + // accumulate the results in result_points + for (long i = 1; i < pyr.intervals()-1; i += 1) + { + const long border_size = pyr.get_border_size(i+1); + for (long r = border_size+1; r < nr - border_size-1; r += 1) + { + for (long c = border_size+1; c < nc - border_size-1; c += 1) + { + double max_val = pyr.get_value(o,i,r,c); + long max_i = i; + long max_r = r; + long max_c = c; + + + // If the max point we found is really a maximum in its own region and + // is big enough then add it to the results. + if (max_val >= threshold && is_maximum_in_region(pyr, o, max_i, max_r, max_c)) + { + //cout << max_val << endl; + interest_point sp = interpolate_point (pyr, o, max_i, max_r, max_c); + if (sp.score >= threshold) + { + result_points.push_back(sp); + } + } + + } + } + } + } + + } + +// ---------------------------------------------------------------------------------------- + + template + void get_interest_points ( + const hessian_pyramid& pyr, + double threshold, + std_vector_c& result_points + ) + /*! + This function is just an overload that automatically casts std_vector_c objects + into std::vector objects. (Usually this is automatic but the template argument + there messes up the conversion so we have to do it explicitly) + !*/ + { + std::vector& v = result_points; + get_interest_points(pyr, threshold, v); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HESSIAN_PYRAMId_Hh_ + diff --git a/ml/dlib/dlib/image_keypoint/hessian_pyramid_abstract.h b/ml/dlib/dlib/image_keypoint/hessian_pyramid_abstract.h new file mode 100644 index 000000000..2db39c210 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/hessian_pyramid_abstract.h @@ -0,0 +1,244 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_HESSIAN_PYRAMId_ABSTRACT_Hh_ +#ifdef DLIB_HESSIAN_PYRAMId_ABSTRACT_Hh_ + +#include "../image_transforms/integral_image_abstract.h" +#include "../noncopyable.h" +#include + +namespace dlib +{ + + class hessian_pyramid : noncopyable + { + /*! + INITIAL VALUE + - octaves() == 0 + - intervals() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents an image pyramid where each level in the + pyramid holds determinants of Hessian matrices for the original + input image. This object can be used to find stable interest + points in an image. For further details consult the following + papers. + + This object is an implementation of the fast Hessian pyramid + as described in the paper: + SURF: Speeded Up Robust Features + By Herbert Bay, Tinne Tuytelaars, and Luc Van Gool + + This implementation was also influenced by the very well documented + OpenSURF library and its corresponding description of how the fast + Hessian algorithm functions: + Notes on the OpenSURF Library + Christopher Evans + !*/ + public: + + template + void build_pyramid ( + const integral_image_type& img, + long num_octaves, + long num_intervals, + long initial_step_size + ); + /*! + requires + - num_octaves > 0 + - num_intervals > 0 + - initial_step_size > 0 + - integral_image_type == an object such as dlib::integral_image or another + type that implements the interface defined in image_transforms/integral_image_abstract.h + ensures + - #get_step_size(0) == initial_step_size + - #octaves() == num_octaves + - #intervals() == num_intervals + - creates a Hessian pyramid from the given input image. + !*/ + + long octaves ( + ) const; + /*! + ensures + - returns the number of octaves in this pyramid + !*/ + + long intervals ( + ) const; + /*! + ensures + - returns the number of intervals in this pyramid + !*/ + + long get_border_size ( + long interval + ) const; + /*! + requires + - 0 <= interval < intervals() + ensures + - Each interval of the pyramid has a certain sized border region where we + can't compute the Hessian values since they are too close to the edge + of the input image. This function returns the size of that border. + !*/ + + long get_step_size ( + long octave + ) const; + /*! + requires + - 0 <= octave < octaves() + ensures + - Each octave has a step size value. This value determines how many + input image pixels separate each pixel in the given pyramid octave. + As the octave gets larger (i.e. as it goes to the top of the pyramid) the + step size gets bigger and thus the pyramid narrows. + !*/ + + long nr ( + long octave + ) const; + /*! + requires + - 0 <= octave < octaves() + ensures + - returns the number of rows there are per layer in the given + octave of pyramid + !*/ + + long nc ( + long octave + ) const; + /*! + requires + - 0 <= octave < octaves() + ensures + - returns the number of columns there are per layer in the given + octave of pyramid + !*/ + + double get_value ( + long octave, + long interval, + long r, + long c + ) const; + /*! + requires + - 0 <= octave < octaves() + - 0 <= interval < intervals() + - Let BS == get_border_size(interval): then + - BS <= r < nr(octave)-BS + - BS <= c < nc(octave)-BS + ensures + - returns the determinant of the Hessian from the given octave and interval + of the pyramid. The specific point sampled at this pyramid level is + the one that corresponds to the input image point (point(r,c)*get_step_size(octave)). + !*/ + + double get_laplacian ( + long octave, + long interval, + long r, + long c + ) const; + /*! + requires + - 0 <= octave < octaves() + - 0 <= interval < intervals() + - Let BS == get_border_size(interval): then + - BS <= r < nr(octave)-BS + - BS <= c < nc(octave)-BS + ensures + - returns the sign of the laplacian for the given octave and interval + of the pyramid. The specific point sampled at this pyramid level is + the one that corresponds to the input image point (point(r,c)*get_step_size(octave)). + - The laplacian is the trace of the Hessian at the given point. So this + function returns either +1 or -1 depending on this number's sign. This + value can be used to distinguish bright blobs on dark backgrounds from + the reverse. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + struct interest_point + { + /*! + WHAT THIS OBJECT REPRESENTS + This object contains the interest points found using the + hessian_pyramid object. Its fields have the following + meanings: + - center == the x/y location of the center of the interest point + (in image space coordinates. y gives the row and x gives the + column in the image) + - scale == the scale at which the point was detected. This is a number + >= 1. If it is 1 then it means the interest point was detected at + the lowest scale in the image pyramid. Larger numbers indicate that + the interest point is from high up in the image pyramid. For + example, a scale of 4 would mean the interest point was located at a + point in the pyramid where the image had been shrunk by a factor of 4. + - score == the determinant of the Hessian for the interest point + - laplacian == the sign of the laplacian for the interest point + !*/ + + interest_point() : scale(0), score(0), laplacian(0) {} + + dlib::vector center; + double scale; + double score; + double laplacian; + + bool operator < (const interest_point& p) const { return score < p.score; } + /*! + This function is here so you can sort interest points according to + their scores + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const interest_point& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + interest_point& item, + std::istream& in + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void get_interest_points ( + const hessian_pyramid& pyr, + double threshold, + std::vector& result_points + ) + /*! + requires + - threshold >= 0 + ensures + - extracts interest points from the pyramid pyr and stores them into + result_points (note that result_points is cleared before these new interest + points are added to it). + - Only interest points with determinant values in the pyramid larger than + threshold are output. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HESSIAN_PYRAMId_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_keypoint/hog.h b/ml/dlib/dlib/image_keypoint/hog.h new file mode 100644 index 000000000..823c25d6d --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/hog.h @@ -0,0 +1,514 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HoG_Hh_ +#define DLIB_HoG_Hh_ + +#include "hog_abstract.h" +#include "../algs.h" +#include "../matrix.h" +#include "../array2d.h" +#include "../geometry.h" +#include + +namespace dlib +{ + enum + { + hog_no_interpolation, + hog_angle_interpolation, + hog_full_interpolation, + hog_signed_gradient, + hog_unsigned_gradient + }; + + template < + unsigned long cell_size_, + unsigned long block_size_, + unsigned long cell_stride_, + unsigned long num_orientation_bins_, + int gradient_type_, + int interpolation_type_ + > + class hog_image : noncopyable + { + COMPILE_TIME_ASSERT(cell_size_ > 1); + COMPILE_TIME_ASSERT(block_size_ > 0); + COMPILE_TIME_ASSERT(cell_stride_ > 0); + COMPILE_TIME_ASSERT(num_orientation_bins_ > 0); + + COMPILE_TIME_ASSERT( gradient_type_ == hog_signed_gradient || + gradient_type_ == hog_unsigned_gradient); + + COMPILE_TIME_ASSERT( interpolation_type_ == hog_no_interpolation || + interpolation_type_ == hog_angle_interpolation || + interpolation_type_ == hog_full_interpolation ); + + + public: + + const static unsigned long cell_size = cell_size_; + const static unsigned long block_size = block_size_; + const static unsigned long cell_stride = cell_stride_; + const static unsigned long num_orientation_bins = num_orientation_bins_; + const static int gradient_type = gradient_type_; + const static int interpolation_type = interpolation_type_; + + const static long min_size = cell_size*block_size+2; + + typedef matrix descriptor_type; + + hog_image ( + ) : + num_block_rows(0), + num_block_cols(0) + {} + + void clear ( + ) + { + num_block_rows = 0; + num_block_cols = 0; + hist_cells.clear(); + } + + void copy_configuration ( + const hog_image& + ){} + + template < + typename image_type + > + inline void load ( + const image_type& img + ) + { + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + load_impl(mat(img)); + } + + inline void unload( + ) { clear(); } + + inline size_t size ( + ) const { return static_cast(nr()*nc()); } + + inline long nr ( + ) const { return num_block_rows; } + + inline long nc ( + ) const { return num_block_cols; } + + long get_num_dimensions ( + ) const + { + return block_size*block_size*num_orientation_bins; + } + + inline const descriptor_type& operator() ( + long row, + long col + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( 0 <= row && row < nr() && + 0 <= col && col < nc(), + "\t descriptor_type hog_image::operator()()" + << "\n\t invalid row or col argument" + << "\n\t row: " << row + << "\n\t col: " << col + << "\n\t nr(): " << nr() + << "\n\t nc(): " << nc() + << "\n\t this: " << this + ); + + row *= cell_stride; + col *= cell_stride; + ++row; + ++col; + + int feat = 0; + for (unsigned long r = 0; r < block_size; ++r) + { + for (unsigned long c = 0; c < block_size; ++c) + { + for (unsigned long i = 0; i < num_orientation_bins; ++i) + { + des(feat++) = hist_cells[row+r][col+c].values[i]; + } + } + } + + des /= length(des) + 1e-8; + + return des; + } + + const rectangle get_block_rect ( + long row, + long col + ) const + { + row *= cell_stride; + col *= cell_stride; + + row *= cell_size; + col *= cell_size; + + // do this to account for the 1 pixel padding we use all around the image + ++row; + ++col; + + return rectangle(col, row, col+cell_size*block_size-1, row+cell_size*block_size-1); + } + + const point image_to_feat_space ( + const point& p + ) const + { + + const long half_block = block_size/2; + if ((block_size%2) == 0) + { + return point(((p.x()-1)/(long)cell_size - half_block)/(long)cell_stride, + ((p.y()-1)/(long)cell_size - half_block)/(long)cell_stride); + } + else + { + return point(((p.x()-1-(long)cell_size/2)/(long)cell_size - half_block)/(long)cell_stride, + ((p.y()-1-(long)cell_size/2)/(long)cell_size - half_block)/(long)cell_stride); + } + } + + const rectangle image_to_feat_space ( + const rectangle& rect + ) const + { + return rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner())); + } + + const point feat_to_image_space ( + const point& p + ) const + { + const long half_block = block_size/2; + if ((block_size%2) == 0) + { + return point((p.x()*cell_stride + half_block)*cell_size + 1, + (p.y()*cell_stride + half_block)*cell_size + 1); + } + else + { + return point((p.x()*cell_stride + half_block)*cell_size + 1 + cell_size/2, + (p.y()*cell_stride + half_block)*cell_size + 1 + cell_size/2); + } + } + + const rectangle feat_to_image_space ( + const rectangle& rect + ) const + { + return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner())); + } + + + + // these _PRIVATE_ functions are only here as a workaround for a bug in visual studio 2005. + void _PRIVATE_serialize (std::ostream& out) const + { + // serialize hist_cells + serialize(hist_cells.nc(),out); + serialize(hist_cells.nr(),out); + hist_cells.reset(); + while (hist_cells.move_next()) + serialize(hist_cells.element().values,out); + hist_cells.reset(); + + + serialize(num_block_rows, out); + serialize(num_block_cols, out); + } + + void _PRIVATE_deserialize (std::istream& in ) + { + // deserialize item.hist_cells + long nc, nr; + deserialize(nc,in); + deserialize(nr,in); + hist_cells.set_size(nr,nc); + while (hist_cells.move_next()) + deserialize(hist_cells.element().values,in); + hist_cells.reset(); + + + deserialize(num_block_rows, in); + deserialize(num_block_cols, in); + } + + private: + + template < + typename image_type + > + void load_impl ( + const image_type& img + ) + { + // Note that we keep a border of 1 pixel all around the image so that we don't have + // to worry about running outside the image when computing the horizontal and vertical + // gradients. + + // Note also that we have a border of unused cells around the hist_cells array so that we + // don't have to worry about edge effects when doing the interpolation in the main loop + // below. + + + // check if the window is just too small + if (img.nr() < min_size || img.nc() < min_size) + { + // If the image is smaller than our windows then there aren't any descriptors at all! + num_block_rows = 0; + num_block_cols = 0; + return; + } + + // Make sure we have the right number of cell histograms and that they are + // all set to zero. + hist_cells.set_size((img.nr()-2)/cell_size+2, (img.nc()-2)/cell_size+2); + for (long r = 0; r < hist_cells.nr(); ++r) + { + for (long c = 0; c < hist_cells.nc(); ++c) + { + hist_cells[r][c].zero(); + } + } + + + // loop over all the histogram cells and fill them out + for (long rh = 1; rh < hist_cells.nr()-1; ++rh) + { + for (long ch = 1; ch < hist_cells.nc()-1; ++ch) + { + // Fill out the current histogram cell. + // First, figure out the row and column offsets into the image for the current histogram cell. + const long roff = (rh-1)*cell_size + 1; + const long coff = (ch-1)*cell_size + 1; + + for (long r = 0; r < (long)cell_size; ++r) + { + for (long c = 0; c < (long)cell_size; ++c) + { + unsigned long left; + unsigned long right; + unsigned long top; + unsigned long bottom; + + assign_pixel(left, img(r+roff,c+coff-1)); + assign_pixel(right, img(r+roff,c+coff+1)); + assign_pixel(top, img(r+roff-1,c+coff)); + assign_pixel(bottom, img(r+roff+1,c+coff)); + + double grad_x = (long)right-(long)left; + double grad_y = (long)top-(long)bottom; + + // obtain the angle of the gradient. Make sure it is scaled between 0 and 1. + double angle = std::max(0.0, std::atan2(grad_y, grad_x)/pi + 1)/2; + + + if (gradient_type == hog_unsigned_gradient) + { + angle *= 2; + if (angle >= 1) + angle -= 1; + } + + + // now scale angle to between 0 and num_orientation_bins + angle *= num_orientation_bins; + + + const double strength = std::sqrt(grad_y*grad_y + grad_x*grad_x); + + + if (interpolation_type == hog_no_interpolation) + { + // no interpolation + hist_cells[rh][ch].values[round_to_int(angle)%num_orientation_bins] += strength; + } + else // if we should do some interpolation + { + unsigned long quantized_angle_lower = static_cast(std::floor(angle)); + unsigned long quantized_angle_upper = static_cast(std::ceil(angle)); + + quantized_angle_lower %= num_orientation_bins; + quantized_angle_upper %= num_orientation_bins; + + const double angle_split = (angle-std::floor(angle)); + const double upper_strength = angle_split*strength; + const double lower_strength = (1-angle_split)*strength; + + if (interpolation_type == hog_angle_interpolation) + { + // Stick into gradient histogram. Note that we linearly interpolate between neighboring + // histogram buckets. + hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength; + hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength; + } + else // here we do hog_full_interpolation + { + const double center_r = (cell_size-1)/2.0; + const double center_c = (cell_size-1)/2.0; + + const double lin_neighbor_r = std::abs(center_r - r)/cell_size; + const double lin_main_r = 1-lin_neighbor_r; + + const double lin_neighbor_c = std::abs(center_c - c)/cell_size; + const double lin_main_c = 1-lin_neighbor_c; + + // Which neighboring cells we interpolate into depends on which + // corner of our main cell we are nearest. + if (r < center_r) + { + if (c < center_c) + { + hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength * lin_main_r*lin_main_c; + hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength * lin_main_r*lin_main_c; + + hist_cells[rh-1][ch].values[quantized_angle_upper] += upper_strength * lin_neighbor_r*lin_main_c; + hist_cells[rh-1][ch].values[quantized_angle_lower] += lower_strength * lin_neighbor_r*lin_main_c; + + hist_cells[rh][ch-1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_main_r; + hist_cells[rh][ch-1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_main_r; + + hist_cells[rh-1][ch-1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_neighbor_r; + hist_cells[rh-1][ch-1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_neighbor_r; + } + else + { + hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength * lin_main_r*lin_main_c; + hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength * lin_main_r*lin_main_c; + + hist_cells[rh-1][ch].values[quantized_angle_upper] += upper_strength * lin_neighbor_r*lin_main_c; + hist_cells[rh-1][ch].values[quantized_angle_lower] += lower_strength * lin_neighbor_r*lin_main_c; + + hist_cells[rh][ch+1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_main_r; + hist_cells[rh][ch+1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_main_r; + + hist_cells[rh-1][ch+1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_neighbor_r; + hist_cells[rh-1][ch+1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_neighbor_r; + } + } + else + { + if (c < center_c) + { + hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength * lin_main_r*lin_main_c; + hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength * lin_main_r*lin_main_c; + + hist_cells[rh+1][ch].values[quantized_angle_upper] += upper_strength * lin_neighbor_r*lin_main_c; + hist_cells[rh+1][ch].values[quantized_angle_lower] += lower_strength * lin_neighbor_r*lin_main_c; + + hist_cells[rh][ch-1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_main_r; + hist_cells[rh][ch-1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_main_r; + + hist_cells[rh+1][ch-1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_neighbor_r; + hist_cells[rh+1][ch-1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_neighbor_r; + } + else + { + hist_cells[rh][ch].values[quantized_angle_upper] += upper_strength * lin_main_r*lin_main_c; + hist_cells[rh][ch].values[quantized_angle_lower] += lower_strength * lin_main_r*lin_main_c; + + hist_cells[rh+1][ch].values[quantized_angle_upper] += upper_strength * lin_neighbor_r*lin_main_c; + hist_cells[rh+1][ch].values[quantized_angle_lower] += lower_strength * lin_neighbor_r*lin_main_c; + + hist_cells[rh][ch+1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_main_r; + hist_cells[rh][ch+1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_main_r; + + hist_cells[rh+1][ch+1].values[quantized_angle_upper] += upper_strength * lin_neighbor_c*lin_neighbor_r; + hist_cells[rh+1][ch+1].values[quantized_angle_lower] += lower_strength * lin_neighbor_c*lin_neighbor_r; + } + } + } + } + + + } + } + } + } + + + // Now figure out how many blocks we should have. Note again that the hist_cells has a border of + // unused cells (thats where that -2 comes from). + num_block_rows = (hist_cells.nr()-2 - (block_size-1) + cell_stride - 1)/cell_stride; + num_block_cols = (hist_cells.nc()-2 - (block_size-1) + cell_stride - 1)/cell_stride; + + } + + unsigned long round_to_int( + double val + ) const + { + return static_cast(std::floor(val + 0.5)); + } + + struct histogram + { + void zero() + { + for (unsigned long i = 0; i < num_orientation_bins; ++i) + values[i] = 0; + } + double values[num_orientation_bins]; + }; + + array2d hist_cells; + + mutable descriptor_type des; + + long num_block_rows; + long num_block_cols; + + + }; + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long T1, + unsigned long T2, + unsigned long T3, + unsigned long T4, + int T5, + int T6 + > + void serialize ( + const hog_image& item, + std::ostream& out + ) + { + item._PRIVATE_serialize(out); + } + + template < + unsigned long T1, + unsigned long T2, + unsigned long T3, + unsigned long T4, + int T5, + int T6 + > + void deserialize ( + hog_image& item, + std::istream& in + ) + { + item._PRIVATE_deserialize(in); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HoG_Hh_ + diff --git a/ml/dlib/dlib/image_keypoint/hog_abstract.h b/ml/dlib/dlib/image_keypoint/hog_abstract.h new file mode 100644 index 000000000..26c8cab64 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/hog_abstract.h @@ -0,0 +1,335 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_HoG_ABSTRACT_Hh_ +#ifdef DLIB_HoG_ABSTRACT_Hh_ + +#include "../algs.h" +#include "../matrix.h" +#include "../array2d.h" +#include "../geometry.h" +#include + +namespace dlib +{ + enum + { + hog_no_interpolation, + hog_angle_interpolation, + hog_full_interpolation, + hog_signed_gradient, + hog_unsigned_gradient + }; + + template < + unsigned long cell_size_, + unsigned long block_size_, + unsigned long cell_stride_, + unsigned long num_orientation_bins_, + int gradient_type_, + int interpolation_type_ + > + class hog_image : noncopyable + { + /*! + REQUIREMENTS ON TEMPLATE PARAMETERS + - cell_size_ > 1 + - block_size_ > 0 + - cell_stride_ > 0 + - num_orientation_bins_ > 0 + - gradient_type_ == hog_signed_gradient or hog_unsigned_gradient + - interpolation_type_ == hog_no_interpolation, hog_angle_interpolation, or + hog_full_interpolation + + INITIAL VALUE + - size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object is a tool for performing the image feature extraction algorithm + described in the following paper: + Histograms of Oriented Gradients for Human Detection + by Navneet Dalal and Bill Triggs + + + To summarize the technique, this object tiles non-overlapping cells over an + image. Each of these cells is a box that is cell_size by cell_size pixels + in size. Each cell contains an array of size num_orientation_bins. The array + in a cell is used to store a histogram of all the edge orientations contained + within the cell's image region. + + Once the grid of cells and their histograms has been computed (via load()) + you can obtain descriptors for each "block" in the image. A block is just a + group of cells and blocks are allowed to overlap. Each block is square and + made up of block_size*block_size cells. So when you call operator()(r,c) + what you obtain is a vector that is just a bunch of cell histograms that + have been concatenated (and length normalized). + + The template arguments control the various parameters of this algorithm. + + The interpolation_type parameter controls the amount of interpolation + that happens during the creation of the edge orientation histograms. It + varies from no interpolation at all to full spatial and angle interpolation. + + Angle interpolation means that an edge doesn't just go into its nearest + histogram bin but instead gets interpolated into its two nearest neighbors. + Similarly, spatial interpolation means that an edge doesn't just go into + the cell it is in but it also contributes to nearby cells depending on how + close they are. + + The gradient_type parameter controls how edge orientations are measured. + Consider the following ASCII art: + signed gradients: unsigned gradients: + /\ | + || | + <--- ----> ------+------ + || | + \/ | + + An image is full of gradients caused by edges between objects. The direction + of a gradient is determined by which end of it has pixels of highest intensity. + So for example, suppose you had a picture containing black and white stripes. + Then the magnitude of the gradient at each point in the image tells you if you + are on the edge of a stripe and the gradient's orientation tells you which + direction you have to move get into the white stripe. + + Signed gradients preserve this direction information while unsigned gradients + do not. An unsigned gradient will only tell you the orientation of the stripe + but not which direction leads to the white stripe. + + Finally, the cell_stride parameter controls how much overlap you get between + blocks. The maximum amount of overlap is obtained when cell_stride == 1. + At the other extreme, you would have no overlap if cell_stride == block_size. + + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be protected + by a mutex lock except for the case where you are copying the configuration + (via copy_configuration()) of a hog_image object to many other threads. + In this case, it is safe to copy the configuration of a shared object so long + as no other operations are performed on it. + !*/ + + public: + + const static unsigned long cell_size = cell_size_; + const static unsigned long block_size = block_size_; + const static unsigned long cell_stride = cell_stride_; + const static unsigned long num_orientation_bins = num_orientation_bins_; + const static int gradient_type = gradient_type_; + const static int interpolation_type = interpolation_type_; + + const static long min_size = cell_size*block_size+2; + + typedef matrix descriptor_type; + + hog_image ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear ( + ); + /*! + ensures + - this object will have its initial value + !*/ + + void copy_configuration ( + const hog_image& item + ); + /*! + ensures + - copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two hog_image + objects H1 and H2, the following sequence of instructions should always + result in both of them having the exact same state. + H2.copy_configuration(H1); + H1.load(img); + H2.load(img); + !*/ + + template < + typename image_type + > + inline void load ( + const image_type& img + ); + /*! + requires + - image_type is a dlib::matrix or something convertible to a matrix + via mat(). + - pixel_traits::pixel_type>::has_alpha == false + ensures + - if (img.nr() < min_size || img.nc() < min_size) then + - the image is too small so we don't compute anything on it + - #size() == 0 + - else + - generates a HOG image from the given image. + - #size() > 0 + !*/ + + inline void unload ( + ); + /*! + ensures + - #nr() == 0 + - #nc() == 0 + - clears only the state information which is populated by load(). For + example, let H be a hog_image object. Then consider the two sequences + of instructions: + Sequence 1: + H.load(img); + H.unload(); + H.load(img); + + Sequence 2: + H.load(img); + Both sequence 1 and sequence 2 should have the same effect on H. + !*/ + + inline size_t size ( + ) const; + /*! + ensures + - returns nr()*nc() + !*/ + + inline long nr ( + ) const; + /*! + ensures + - returns the number of rows in this HOG image + !*/ + + inline long nc ( + ) const; + /*! + ensures + - returns the number of columns in this HOG image + !*/ + + long get_num_dimensions ( + ) const; + /*! + ensures + - returns the number of dimensions in the feature vectors generated by + this object. + - In particular, returns the value block_size*block_size*num_orientation_bins + !*/ + + inline const descriptor_type& operator() ( + long row, + long col + ) const; + /*! + requires + - 0 <= row < nr() + - 0 <= col < nc() + ensures + - returns the descriptor for the HOG block at the given row and column. This descriptor + will include information from a window that is located at get_block_rect(row,col) in + the original image given to load(). + - The returned descriptor vector will have get_num_dimensions() elements. + !*/ + + const rectangle get_block_rect ( + long row, + long col + ) const; + /*! + ensures + - returns a rectangle that tells you what part of the original image is associated + with a particular HOG block. That is, what part of the input image is associated + with (*this)(row,col). + - The returned rectangle will be cell_size*block_size pixels wide and tall. + !*/ + + const point image_to_feat_space ( + const point& p + ) const; + /*! + ensures + - Each local feature is extracted from a certain point in the input image. + This function returns the identity of the local feature corresponding + to the image location p. Or in other words, let P == image_to_feat_space(p), + then (*this)(P.y(),P.x()) == the local feature closest to, or centered at, + the point p in the input image. Note that some image points might not have + corresponding feature locations. E.g. border points or points outside the + image. In these cases the returned point will be outside get_rect(*this). + !*/ + + const rectangle image_to_feat_space ( + const rectangle& rect + ) const; + /*! + ensures + - returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner())); + (i.e. maps a rectangle from image space to feature space) + !*/ + + const point feat_to_image_space ( + const point& p + ) const; + /*! + ensures + - returns the location in the input image space corresponding to the center + of the local feature at point p. In other words, this function computes + the inverse of image_to_feat_space(). Note that it may only do so approximately, + since more than one image location might correspond to the same local feature. + That is, image_to_feat_space() might not be invertible so this function gives + the closest possible result. + !*/ + + const rectangle feat_to_image_space ( + const rectangle& rect + ) const; + /*! + ensures + - return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner())); + (i.e. maps a rectangle from feature space to image space) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long T1, + unsigned long T2, + unsigned long T3, + unsigned long T4, + int T5, + int T6 + > + void serialize ( + const hog_image& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + unsigned long T1, + unsigned long T2, + unsigned long T3, + unsigned long T4, + int T5, + int T6 + > + void deserialize ( + hog_image& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_HoG_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image.h b/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image.h new file mode 100644 index 000000000..2ee45da2f --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image.h @@ -0,0 +1,408 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_Hh_ +#define DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_Hh_ + +#include "nearest_neighbor_feature_image_abstract.h" +#include +#include "../algs.h" +#include "../matrix.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class nearest_neighbor_feature_image : noncopyable + { + /*! + INITIAL VALUE + - nn_feats.size() == 1 + + CONVENTION + - nn_feats.size() == 1 + + !*/ + + public: + + typedef std::vector > descriptor_type; + + nearest_neighbor_feature_image ( + ); + + void clear ( + ); + + void copy_configuration ( + const feature_extractor& item + ); + + void copy_configuration ( + const nearest_neighbor_feature_image& item + ); + + template < + typename image_type + > + inline void load ( + const image_type& img + ); + + inline size_t size ( + ) const; + + inline long nr ( + ) const; + + inline long nc ( + ) const; + + inline long get_num_dimensions ( + ) const; + + template + void set_basis ( + const vector_type& new_basis + ); + + inline const descriptor_type& operator() ( + long row, + long col + ) const; + + inline const rectangle get_block_rect ( + long row, + long col + ) const; + + inline const point image_to_feat_space ( + const point& p + ) const; + + inline const rectangle image_to_feat_space ( + const rectangle& rect + ) const; + + inline const point feat_to_image_space ( + const point& p + ) const; + + inline const rectangle feat_to_image_space ( + const rectangle& rect + ) const; + + template + friend void serialize ( + const nearest_neighbor_feature_image& item, + std::ostream& out + ); + + template + friend void deserialize ( + nearest_neighbor_feature_image& item, + std::istream& in + ); + + private: + + array2d feats; + feature_extractor fe; + std::vector basis; + + // This is a transient variable. It is just here so it doesn't have to be + // reallocated over and over inside operator() + mutable descriptor_type nn_feats; + + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const nearest_neighbor_feature_image& item, + std::ostream& out + ) + { + serialize(item.feats, out); + serialize(item.fe, out); + serialize(item.basis, out); + } + + template + void deserialize ( + nearest_neighbor_feature_image& item, + std::istream& in + ) + { + deserialize(item.feats, in); + deserialize(item.fe, in); + deserialize(item.basis, in); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// nearest_neighbor_feature_image member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + nearest_neighbor_feature_image:: + nearest_neighbor_feature_image ( + ) + { + nn_feats.resize(1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void nearest_neighbor_feature_image:: + clear ( + ) + { + feats.clear(); + fe.clear(); + basis.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void nearest_neighbor_feature_image:: + copy_configuration ( + const feature_extractor& item + ) + { + fe.copy_configuration(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void nearest_neighbor_feature_image:: + copy_configuration ( + const nearest_neighbor_feature_image& item + ) + { + fe.copy_configuration(item.fe); + basis = item.basis; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + template < + typename image_type + > + void nearest_neighbor_feature_image:: + load ( + const image_type& img + ) + { + fe.load(img); + + feats.set_size(fe.nr(), fe.nc()); + + // find the nearest neighbor for each feature vector and store the + // result in feats. + for (long r = 0; r < feats.nr(); ++r) + { + for (long c = 0; c < feats.nc(); ++c) + { + const typename feature_extractor::descriptor_type& local_feat = fe(r,c); + + double best_dist = std::numeric_limits::infinity(); + unsigned long best_idx = 0; + for (unsigned long i = 0; i < basis.size(); ++i) + { + double dist = length_squared(local_feat - basis[i]); + if (dist < best_dist) + { + best_dist = dist; + best_idx = i; + } + } + + feats[r][c] = best_idx; + } + } + + fe.unload(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + size_t nearest_neighbor_feature_image:: + size ( + ) const + { + return feats.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + long nearest_neighbor_feature_image:: + nr ( + ) const + { + return feats.nr(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + long nearest_neighbor_feature_image:: + nc ( + ) const + { + return feats.nc(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + long nearest_neighbor_feature_image:: + get_num_dimensions ( + ) const + { + return basis.size(); + } + +// ---------------------------------------------------------------------------------------- + + template + template + void nearest_neighbor_feature_image:: + set_basis ( + const vector_type& new_basis + ) + { + basis.assign(new_basis.begin(), new_basis.end()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + const typename nearest_neighbor_feature_image::descriptor_type& + nearest_neighbor_feature_image:: + operator() ( + long row, + long col + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= row && row < nr() && + 0 <= col && col < nc(), + "\t descriptor_type nearest_neighbor_feature_image::operator(row,col)" + << "\n\t Invalid inputs were given to this function" + << "\n\t row: " << row + << "\n\t col: " << col + << "\n\t nr(): " << nr() + << "\n\t nc(): " << nc() + << "\n\t this: " << this + ); + + nn_feats[0] = std::make_pair(feats[row][col],1); + return nn_feats; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + const rectangle nearest_neighbor_feature_image:: + get_block_rect ( + long row, + long col + ) const + { + return fe.get_block_rect(row,col); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + const point nearest_neighbor_feature_image:: + image_to_feat_space ( + const point& p + ) const + { + return fe.image_to_feat_space(p); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + const rectangle nearest_neighbor_feature_image:: + image_to_feat_space ( + const rectangle& rect + ) const + { + return fe.image_to_feat_space(rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + const point nearest_neighbor_feature_image:: + feat_to_image_space ( + const point& p + ) const + { + return fe.feat_to_image_space(p); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + const rectangle nearest_neighbor_feature_image:: + feat_to_image_space ( + const rectangle& rect + ) const + { + return fe.feat_to_image_space(rect); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_Hh_ + + diff --git a/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h b/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h new file mode 100644 index 000000000..59d7cfeb7 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h @@ -0,0 +1,254 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_ABSTRACT_Hh_ +#ifdef DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_ABSTRACT_Hh_ + +#include +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class nearest_neighbor_feature_image : noncopyable + { + /*! + REQUIREMENTS ON feature_extractor + - must be an object with an interface compatible with dlib::hog_image + + INITIAL VALUE + - size() == 0 + - get_num_dimensions() == 0 + + WHAT THIS OBJECT REPRESENTS + This object is a tool for performing image feature extraction. In + particular, it wraps another image feature extractor and converts + the wrapped image feature vectors into sparse indicator vectors. It does + this by finding the nearest neighbor for each feature vector and returning an + indicator vector that is zero everywhere except for the position indicated by + the nearest neighbor. + + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be protected + by a mutex lock except for the case where you are copying the configuration + (via copy_configuration()) of a nearest_neighbor_feature_image object to many other + threads. In this case, it is safe to copy the configuration of a shared object so + long as no other operations are performed on it. + + + NOTATION + let BASE_FE denote the base feature_extractor object contained inside + the nearest_neighbor_feature_image. + !*/ + + public: + + typedef std::vector > descriptor_type; + + nearest_neighbor_feature_image ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear ( + ); + /*! + ensures + - this object will have its initial value + !*/ + + void copy_configuration ( + const feature_extractor& item + ); + /*! + ensures + - performs BASE_FE.copy_configuration(item) + !*/ + + void copy_configuration ( + const nearest_neighbor_feature_image& item + ); + /*! + ensures + - copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two + nearest_neighbor_feature_image objects H1 and H2, the following sequence + of instructions should always result in both of them having the exact + same state. + H2.copy_configuration(H1); + H1.load(img); + H2.load(img); + !*/ + + template < + typename image_type + > + inline void load ( + const image_type& img + ); + /*! + requires + - image_type == any type that can be supplied to feature_extractor::load() + ensures + - performs BASE_FE.load(img) + i.e. does feature extraction. The features can be accessed using + operator() as defined below. + !*/ + + inline size_t size ( + ) const; + /*! + ensures + - returns BASE_FE.size() + !*/ + + inline long nr ( + ) const; + /*! + ensures + - returns BASE_FE.nr() + !*/ + + inline long nc ( + ) const; + /*! + ensures + - returns BASE_FE.nc() + !*/ + + inline long get_num_dimensions ( + ) const; + /*! + ensures + - returns the dimensionality of the feature vectors returned by operator(). + In this case, this is the number of basis elements. That is, it is the number + of vectors given to the set_basis() member function. + !*/ + + template + void set_basis ( + const vector_type& new_basis + ); + /*! + ensures + - #get_num_dimensions() == new_basis.size() + - The operator() member function defined below will use new_basis to + determine nearest neighbors. + !*/ + + inline const descriptor_type& operator() ( + long row, + long col + ) const; + /*! + requires + - 0 <= row < nr() + - 0 <= col < nc() + - get_num_dimensions() > 0 + ensures + - determines which basis element is nearest to BASE_FE(row,col) and returns a sparse + indicator vector identifying the nearest neighbor. + - To be precise, this function returns a sparse vector V such that: + - V.size() == 1 + - V[0].first == The basis element index for the basis vector nearest to BASE_FE(row,col). + "nearness" is determined using Euclidean distance. + - V[0].second == 1 + !*/ + + inline const rectangle get_block_rect ( + long row, + long col + ) const; + /*! + ensures + - returns BASE_FE.get_block_rect(row,col) + I.e. returns a rectangle that tells you what part of the original image is associated + with a particular feature vector. + !*/ + + inline const point image_to_feat_space ( + const point& p + ) const; + /*! + ensures + - returns BASE_FE.image_to_feat_space(p) + I.e. Each local feature is extracted from a certain point in the input image. + This function returns the identity of the local feature corresponding + to the image location p. Or in other words, let P == image_to_feat_space(p), + then (*this)(P.y(),P.x()) == the local feature closest to, or centered at, + the point p in the input image. Note that some image points might not have + corresponding feature locations. E.g. border points or points outside the + image. In these cases the returned point will be outside get_rect(*this). + !*/ + + inline const rectangle image_to_feat_space ( + const rectangle& rect + ) const; + /*! + ensures + - returns BASE_FE.image_to_feat_space(rect) + I.e. returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner())); + (i.e. maps a rectangle from image space to feature space) + !*/ + + inline const point feat_to_image_space ( + const point& p + ) const; + /*! + ensures + - returns BASE_FE.feat_to_image_space(p) + I.e. returns the location in the input image space corresponding to the center + of the local feature at point p. In other words, this function computes + the inverse of image_to_feat_space(). Note that it may only do so approximately, + since more than one image location might correspond to the same local feature. + That is, image_to_feat_space() might not be invertible so this function gives + the closest possible result. + !*/ + + inline const rectangle feat_to_image_space ( + const rectangle& rect + ) const; + /*! + ensures + - returns BASE_FE.feat_to_image_space(rect) + I.e. return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner())); + (i.e. maps a rectangle from feature space to image space) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const nearest_neighbor_feature_image& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template + void deserialize ( + nearest_neighbor_feature_image& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/image_keypoint/poly_image.h b/ml/dlib/dlib/image_keypoint/poly_image.h new file mode 100644 index 000000000..8abb912f0 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/poly_image.h @@ -0,0 +1,649 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_POLY_ImAGE_Hh_ +#define DLIB_POLY_ImAGE_Hh_ + +#include "poly_image_abstract.h" +#include "build_separable_poly_filters.h" +#include "../algs.h" +#include "../matrix.h" +#include "../array2d.h" +#include "../geometry.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + long Downsample + > + class poly_image : noncopyable + { + COMPILE_TIME_ASSERT(Downsample >= 1); + public: + const static long downsample = Downsample; + typedef matrix descriptor_type; + + poly_image( + long order_, + long window_size_, + bool normalization = true, + bool rotation_invariance_ = false + ) + { + setup(order_, window_size_); + set_uses_normalization(normalization); + set_is_rotationally_invariant(rotation_invariance_); + } + + poly_image ( + ) + { + clear(); + } + + void clear ( + ) + { + normalize = true; + rotation_invariance = false; + poly_coef.clear(); + order = 3; + window_size = 13; + border_size = (long)std::ceil(std::floor(window_size/2.0)/downsample); + num_rows = 0; + num_cols = 0; + filters = build_separable_poly_filters(order, window_size); + } + + long get_order ( + ) const + { + return order; + } + + long get_window_size ( + ) const + { + return window_size; + } + + void setup ( + long order_, + long window_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(1 <= order_ && order_ <= 6 && + window_size_ >= 3 && (window_size_%2) == 1, + "\t descriptor_type poly_image::setup()" + << "\n\t Invalid arguments were given to this function." + << "\n\t order_: " << order_ + << "\n\t window_size_: " << window_size_ + << "\n\t this: " << this + ); + + + poly_coef.clear(); + order = order_; + window_size = window_size_; + border_size = (long)std::ceil(std::floor(window_size/2.0)/downsample); + num_rows = 0; + num_cols = 0; + filters = build_separable_poly_filters(order, window_size); + } + + bool uses_normalization ( + ) const { return normalize; } + + void set_uses_normalization ( + bool normalization + ) + { + normalize = normalization; + } + + bool is_rotationally_invariant ( + ) const { return rotation_invariance; } + + void set_is_rotationally_invariant ( + bool rotation_invariance_ + ) + { + rotation_invariance = rotation_invariance_; + } + + void copy_configuration ( + const poly_image& item + ) + { + normalize = item.normalize; + rotation_invariance = item.rotation_invariance; + if (order != item.order || + window_size != item.window_size) + { + order = item.order; + window_size = item.window_size; + border_size = item.border_size; + filters = item.filters; + } + } + + template < + typename image_type + > + inline void load ( + const image_type& img + ) + { + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + + poly_coef.resize(get_num_dimensions()); + des.set_size(get_num_dimensions()); + + + if (normalize) + { + array2d coef0; + rectangle rect = filter_image(img, coef0, filters[0]); + num_rows = rect.height(); + num_cols = rect.width(); + + for (unsigned long i = 1; i < filters.size(); ++i) + { + filter_image(img, poly_coef[i-1], filters[i]); + + // intensity normalize everything + for (long r = 0; r < coef0.nr(); ++r) + { + for (long c = 0; c < coef0.nc(); ++c) + { + if (coef0[r][c] >= 1) + poly_coef[i-1][r][c] /= coef0[r][c]; + else + poly_coef[i-1][r][c] = 0; + } + } + } + + if (rotation_invariance) + rotate_polys(rect); + } + else + { + rectangle rect; + for (unsigned long i = 0; i < filters.size(); ++i) + { + rect = filter_image(img, poly_coef[i], filters[i]); + } + num_rows = rect.height(); + num_cols = rect.width(); + + if (rotation_invariance) + rotate_polys(rect); + } + } + + void unload() + { + poly_coef.clear(); + num_rows = 0; + num_cols = 0; + } + + inline size_t size ( + ) const { return static_cast(nr()*nc()); } + + inline long nr ( + ) const { return num_rows; } + + inline long nc ( + ) const { return num_cols; } + + long get_num_dimensions ( + ) const + { + if (normalize) + { + // -1 because we discard the constant term of the polynomial. + return filters.size()-1; + } + else + { + return filters.size(); + } + } + + inline const descriptor_type& operator() ( + long row, + long col + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( 0 <= row && row < nr() && + 0 <= col && col < nc(), + "\t descriptor_type poly_image::operator()()" + << "\n\t invalid row or col argument" + << "\n\t row: " << row + << "\n\t col: " << col + << "\n\t nr(): " << nr() + << "\n\t nc(): " << nc() + << "\n\t this: " << this + ); + + // add because of the zero border around the poly_coef images + row += border_size; + col += border_size; + + for (long i = 0; i < des.size(); ++i) + des(i) = poly_coef[i][row][col]; + + return des; + } + + const rectangle get_block_rect ( + long row, + long col + ) const + { + return centered_rect(Downsample*point(col+border_size, row+border_size), + window_size, window_size); + } + + const point image_to_feat_space ( + const point& p + ) const + { + return p/Downsample - point(border_size, border_size); + } + + const rectangle image_to_feat_space ( + const rectangle& rect + ) const + { + return rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner())); + } + + const point feat_to_image_space ( + const point& p + ) const + { + return (p + point(border_size, border_size))*Downsample; + } + + const rectangle feat_to_image_space ( + const rectangle& rect + ) const + { + return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner())); + } + + + + friend void serialize (const poly_image& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.poly_coef, out); + serialize(item.order, out); + serialize(item.window_size, out); + serialize(item.border_size, out); + serialize(item.num_rows, out); + serialize(item.num_cols, out); + serialize(item.normalize, out); + serialize(item.rotation_invariance, out); + serialize(item.filters, out); + } + + friend void deserialize (poly_image& item, std::istream& in ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw dlib::serialization_error("Unexpected version found while deserializing dlib::poly_image"); + + deserialize(item.poly_coef, in); + deserialize(item.order, in); + deserialize(item.window_size, in); + deserialize(item.border_size, in); + deserialize(item.num_rows, in); + deserialize(item.num_cols, in); + deserialize(item.normalize, in); + deserialize(item.rotation_invariance, in); + deserialize(item.filters, in); + } + + private: + + matrix rotate_order_1 ( + const matrix& w, + double cos_theta, + double sin_theta + ) const + { + const double w1 = w(0); + const double w2 = w(1); + matrix M; + M = w1, w2, + w2, -w1; + + matrix x; + x = cos_theta, + sin_theta; + + return matrix_cast(M*x); + } + + matrix rotate_order_2 ( + const matrix& w, + double cos_theta, + double sin_theta + ) const + { + const double w1 = w(0); + const double w2 = w(1); + const double w3 = w(2); + matrix M; + M = w1, w2, w3, + w2, (2*w3-2*w1), -w2, + w3, -w2, w1; + + matrix x; + x = std::pow(cos_theta,2.0), + cos_theta*sin_theta, + std::pow(sin_theta,2.0); + + return matrix_cast(M*x); + } + + matrix rotate_order_3 ( + const matrix& w, + double cos_theta, + double sin_theta + ) const + { + const double w1 = w(0); + const double w2 = w(1); + const double w3 = w(2); + const double w4 = w(3); + matrix M; + M = w1, w2, w3, w4, + w2, (2*w3-3*w1), (3*w4-2*w2), -w3, + w3, (3*w4-2*w2), (3*w1-2*w3), w2, + w4, -w3, w2, -w1; + + matrix x; + x = std::pow(cos_theta,3.0), + std::pow(cos_theta,2.0)*sin_theta, + cos_theta*std::pow(sin_theta,2.0), + std::pow(sin_theta,3.0); + + return matrix_cast(M*x); + } + + matrix rotate_order_4 ( + const matrix& w, + double cos_theta, + double sin_theta + ) const + { + const double w1 = w(0); + const double w2 = w(1); + const double w3 = w(2); + const double w4 = w(3); + const double w5 = w(4); + matrix M; + M = w1, w2, w3, w4, w5, + w2, (2*w3-4*w1), (3*w4-3*w2), (4*w5-2*w3), -w4, + w3, (3*w4-3*w2), (6*w1-4*w3+6*w5), (3*w2-3*w4), w3, + w4, (4*w5-2*w3), (3*w2-3*w4), (2*w3-4*w1), -w2, + w5, -w4, w3, -w2, w1; + + matrix x; + x = std::pow(cos_theta,4.0), + std::pow(cos_theta,3.0)*sin_theta, + std::pow(cos_theta,2.0)*std::pow(sin_theta,2.0), + cos_theta*std::pow(sin_theta,3.0), + std::pow(sin_theta,4.0); + + return matrix_cast(M*x); + } + + matrix rotate_order_5 ( + const matrix& w, + double cos_theta, + double sin_theta + ) const + { + const double w1 = w(0); + const double w2 = w(1); + const double w3 = w(2); + const double w4 = w(3); + const double w5 = w(4); + const double w6 = w(5); + matrix M; + M = w1, w2, w3, w4, w5, w6, + w2, (2*w3-5*w1), (3*w4-4*w2), (4*w5-3*w3), (5*w6-2*w4), -w5, + w3, (3*w4-4*w2), (10*w1-6*w3+6*w5), (6*w2-6*w4+10*w6), (3*w3-4*w5), w4, + w4, (4*w5-3*w3), (6*w2-6*w4+10*w6), (-10*w1+6*w3-6*w5), (3*w4-4*w2), -w3, + w5, (5*w6-2*w4), (3*w3-4*w5), (3*w4-4*w2), (5*w1-2*w3), w2, + w6, -w5, w4, -w3, w2, -w1; + + matrix x; + x = std::pow(cos_theta,5.0), + std::pow(cos_theta,4.0)*sin_theta, + std::pow(cos_theta,3.0)*std::pow(sin_theta,2.0), + std::pow(cos_theta,2.0)*std::pow(sin_theta,3.0), + cos_theta*std::pow(sin_theta,4.0), + std::pow(sin_theta,5.0); + + return matrix_cast(M*x); + } + + matrix rotate_order_6 ( + const matrix& w, + double cos_theta, + double sin_theta + ) const + { + const double w1 = w(0); + const double w2 = w(1); + const double w3 = w(2); + const double w4 = w(3); + const double w5 = w(4); + const double w6 = w(5); + const double w7 = w(6); + matrix M; + M = w1, w2, w3, w4, w5, w6, w7, + w2, (2*w3-6*w1), (3*w4-5*w2), (4*w5-4*w3), (5*w6-3*w4), (6*w7-2*w5), -w6, + w3, (3*w4-5*w2), (15*w1-8*w3+ 6*w5), ( 10*w2 -9*w4+10*w6), ( 6*w3-8*w5+15*w7), (3*w4-5*w6), w5, + w4, (4*w5-4*w3), (10*w2-9*w4+10*w6), (-20*w1+12*w3-12*w5+20*w7), (-10*w2+9*w4-10*w6), (4*w5-4*w3), -w4, + w5, (5*w6-3*w4), ( 6*w3-8*w5+15*w7), (-10*w2 +9*w4-10*w6), ( 15*w1-8*w3 +6*w5), (5*w2-3*w4), w3, + w6, (6*w7-2*w5), (3*w4-5*w6), (4*w5-4*w3), (5*w2-3*w4), (2*w3-6*w1), -w2, + w7, -w6, w5, -w4, w3, -w2, w1; + + matrix x; + x = std::pow(cos_theta,6.0), + std::pow(cos_theta,5.0)*sin_theta, + std::pow(cos_theta,4.0)*std::pow(sin_theta,2.0), + std::pow(cos_theta,3.0)*std::pow(sin_theta,3.0), + std::pow(cos_theta,2.0)*std::pow(sin_theta,4.0), + cos_theta*std::pow(sin_theta,5.0), + std::pow(sin_theta,6.0); + + return matrix_cast(M*x); + } + + void rotate_polys ( + const rectangle& rect + ) + /*! + ensures + - rotates all the polynomials in poly_coef so that they are + rotationally invariant + !*/ + { + // The idea here is to use a rotation matrix to rotate the + // coordinate system for the polynomial so that the x axis + // always lines up with the gradient vector (or direction of + // max curvature). This way we can make the representation + // rotation invariant. + + // Note that the rotation matrix is given by: + // [ cos_theta -sin_theta ] + // [ sin_theta cos_theta ] + + // need to offset poly_coef to get past the constant term if there isn't any normalization. + const int off = (normalize) ? 0 : 1; + + for (long r = rect.top(); r <= rect.bottom(); ++r) + { + for (long c = rect.left(); c <= rect.right(); ++c) + { + dlib::vector g(poly_coef[off+0][r][c], + poly_coef[off+1][r][c]); + + const double len = g.length(); + if (len != 0) + { + g /= len; + } + else + { + g.x() = 1; + g.y() = 0; + } + // since we normalized g we can find the sin/cos of its angle easily. + const double cos_theta = g.x(); + const double sin_theta = g.y(); + + if (order >= 1) + { + matrix w; + w = poly_coef[off+0][r][c], + poly_coef[off+1][r][c]; + w = rotate_order_1(w, cos_theta, sin_theta); + poly_coef[off+0][r][c] = w(0); + poly_coef[off+1][r][c] = w(1); + } + if (order >= 2) + { + matrix w; + w = poly_coef[off+2][r][c], + poly_coef[off+3][r][c], + poly_coef[off+4][r][c]; + w = rotate_order_2(w, cos_theta, sin_theta); + poly_coef[off+2][r][c] = w(0); + poly_coef[off+3][r][c] = w(1); + poly_coef[off+4][r][c] = w(2); + } + if (order >= 3) + { + matrix w; + w = poly_coef[off+5][r][c], + poly_coef[off+6][r][c], + poly_coef[off+7][r][c], + poly_coef[off+8][r][c]; + w = rotate_order_3(w, cos_theta, sin_theta); + poly_coef[off+5][r][c] = w(0); + poly_coef[off+6][r][c] = w(1); + poly_coef[off+7][r][c] = w(2); + poly_coef[off+8][r][c] = w(3); + } + if (order >= 4) + { + matrix w; + w = poly_coef[off+9][r][c], + poly_coef[off+10][r][c], + poly_coef[off+11][r][c], + poly_coef[off+12][r][c], + poly_coef[off+13][r][c]; + w = rotate_order_4(w, cos_theta, sin_theta); + poly_coef[off+9][r][c] = w(0); + poly_coef[off+10][r][c] = w(1); + poly_coef[off+11][r][c] = w(2); + poly_coef[off+12][r][c] = w(3); + poly_coef[off+13][r][c] = w(4); + } + if (order >= 5) + { + matrix w; + w = poly_coef[off+14][r][c], + poly_coef[off+15][r][c], + poly_coef[off+16][r][c], + poly_coef[off+17][r][c], + poly_coef[off+18][r][c], + poly_coef[off+19][r][c]; + w = rotate_order_5(w, cos_theta, sin_theta); + poly_coef[off+14][r][c] = w(0); + poly_coef[off+15][r][c] = w(1); + poly_coef[off+16][r][c] = w(2); + poly_coef[off+17][r][c] = w(3); + poly_coef[off+18][r][c] = w(4); + poly_coef[off+19][r][c] = w(5); + } + if (order >= 6) + { + matrix w; + w = poly_coef[off+20][r][c], + poly_coef[off+21][r][c], + poly_coef[off+22][r][c], + poly_coef[off+23][r][c], + poly_coef[off+24][r][c], + poly_coef[off+25][r][c], + poly_coef[off+26][r][c]; + w = rotate_order_6(w, cos_theta, sin_theta); + poly_coef[off+20][r][c] = w(0); + poly_coef[off+21][r][c] = w(1); + poly_coef[off+22][r][c] = w(2); + poly_coef[off+23][r][c] = w(3); + poly_coef[off+24][r][c] = w(4); + poly_coef[off+25][r][c] = w(5); + poly_coef[off+26][r][c] = w(6); + } + } + } + + } + + template + rectangle filter_image ( + const image_type& img, + array2d& out, + const std::vector& filter + ) const + { + rectangle rect = spatially_filter_image_separable_down(downsample, img, out, filter[0].first, filter[0].second); + for (unsigned long i = 1; i < filter.size(); ++i) + { + spatially_filter_image_separable_down(downsample, img, out, filter[i].first, filter[i].second, 1, false, true); + } + return rect; + } + + + + std::vector > filters; + + dlib::array > poly_coef; + long order; + long window_size; + long border_size; + long num_rows; + long num_cols; + + bool normalize; + bool rotation_invariance; + + mutable descriptor_type des; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_POLY_ImAGE_Hh_ + + diff --git a/ml/dlib/dlib/image_keypoint/poly_image_abstract.h b/ml/dlib/dlib/image_keypoint/poly_image_abstract.h new file mode 100644 index 000000000..2f17bb31e --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/poly_image_abstract.h @@ -0,0 +1,335 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_POLY_ImAGE_ABSTRACT_Hh_ +#ifdef DLIB_POLY_ImAGE_ABSTRACT_Hh_ + +#include "../algs.h" +#include "../matrix.h" +#include "../geometry/rectangle_abstract.h" +#include +#include "../image_processing/generic_image.h" + +namespace dlib +{ + template < + long Downsample + > + class poly_image : noncopyable + { + /*! + REQUIREMENTS ON TEMPLATE PARAMETERS + - Downsample >= 1 + + WHAT THIS OBJECT REPRESENTS + This object is a tool for extracting local feature descriptors from an image. + In particular, it fits polynomials to local pixel patches and allows you to + query the coefficients of these polynomials. Additionally, the coefficients + may be intensity normalized by dividing them by the constant term of the fitted + polynomial and then the constant term is discarded. + + Finally, the user can specify a downsampling rate. If the template argument + Downsample is set to 1 then feature extraction is performed at every pixel of + an input image (except for a small area around the image border). However, + if Downsample is set to 2 then feature extraction is only performed at every + other pixel location. More generally, if Downsample is set to N then feature + extraction is performed only every N pixels. + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be protected + by a mutex lock except for the case where you are copying the configuration + (via copy_configuration()) of a poly_image object to many other threads. + In this case, it is safe to copy the configuration of a shared object so long + as no other operations are performed on it. + !*/ + + public: + + typedef matrix descriptor_type; + const static long downsample = Downsample; + + poly_image ( + ); + /*! + ensures + - #get_order() == 3 + - #get_window_size() == 13 + - #size() == 0 + - #uses_normalization() == true + - #is_rotationally_invariant() == false + !*/ + + poly_image( + long order, + long window_size, + bool normalization = true, + bool rotation_invariance = false + ); + /*! + requires + - 1 <= order <= 6 + - window_size >= 3 && window_size is odd + ensures + - #get_order() == order + - #get_window_size() == window_size + - #size() == 0 + - #uses_normalization() == normalization + - #is_rotationally_invariant() == rotation_invariance + !*/ + + void clear ( + ); + /*! + ensures + - this object will have its initial value + !*/ + + void setup ( + long order, + long window_size + ); + /*! + requires + - 1 <= order <= 6 + - window_size >= 3 && window_size is odd + ensures + - #get_order() == order + - #get_window_size() == window_size + !*/ + + long get_order ( + ) const; + /*! + ensures + - returns the order of the polynomial that will be fitted to + each local pixel patch during feature extraction. + !*/ + + long get_window_size ( + ) const; + /*! + ensures + - returns the size of the window used for local feature extraction. + This is the width and height of the window in pixels. + !*/ + + bool uses_normalization ( + ) const; + /*! + ensures + - returns true if the polynomial coefficients are intensity normalized + and false otherwise. + !*/ + + void set_uses_normalization ( + bool normalization + ); + /*! + ensures + - #uses_normalization() == normalization + !*/ + + bool is_rotationally_invariant ( + ); + /*! + ensures + - returns true if the feature extractor will adjust the output so that it + is rotationally invariant. This is done by rotating each patch such that + the gradient vector always points in the same direction. + !*/ + + void set_is_rotationally_invariant ( + bool rotation_invariance + ); + /*! + ensures + - #is_rotationally_invariant() == rotation_invariance + !*/ + + void copy_configuration ( + const poly_image& item + ); + /*! + ensures + - copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two poly_image + objects H1 and H2, the following sequence of instructions should always + result in both of them having the exact same state. + H2.copy_configuration(H1); + H1.load(img); + H2.load(img); + !*/ + + template < + typename image_type + > + inline void load ( + const image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits::pixel_type>::has_alpha == false + ensures + - Performs the feature extraction described in the WHAT THIS OBJECT REPRESENTS + section above. This means after load() finishes you can call (*this)(row,col) + to obtain the polynomial coefficients for an order get_order() polynomial which + was fitted to the image patch get_block_rect(row,col). + - #size() > 0 + !*/ + + void unload( + ); + /*! + ensures + - #nr() == 0 + - #nc() == 0 + - clears only the state information which is populated by load(). For + example, let H be a poly_image object. Then consider the two sequences + of instructions: + Sequence 1: + H.load(img); + H.unload(); + H.load(img); + + Sequence 2: + H.load(img); + Both sequence 1 and sequence 2 should have the same effect on H. + !*/ + + inline size_t size ( + ) const; + /*! + ensures + - returns nr()*nc() + !*/ + + inline long nr ( + ) const; + /*! + ensures + - returns the number of rows in this polynomial feature image + !*/ + + inline long nc ( + ) const; + /*! + ensures + - returns the number of columns in this polynomial feature image + !*/ + + long get_num_dimensions ( + ) const; + /*! + ensures + - returns the number of dimensions in the feature vectors generated by + this object. + - In this case, this will be the number of coefficients in an order + get_order() polynomial, except for the constant term of the polynomial + if uses_normalization() == true. + !*/ + + inline const descriptor_type& operator() ( + long row, + long col + ) const; + /*! + requires + - 0 <= row < nr() + - 0 <= col < nc() + ensures + - returns the descriptor for the polynomial filtering block at the given row and column. + This vector will contain the polynomial coefficients for a polynomial fitted to the + image patch located at get_block_rect(row,col) in the original image given to load(). + - The returned descriptor vector will have get_num_dimensions() elements. + !*/ + + const rectangle get_block_rect ( + long row, + long col + ) const; + /*! + ensures + - returns a rectangle that tells you what part of the original image is associated + with a particular polynomial filter block. That is, what part of the input image + is associated with (*this)(row,col). + - The returned rectangle will be get_window_size() pixels wide and tall. + !*/ + + const point image_to_feat_space ( + const point& p + ) const; + /*! + ensures + - Each local feature is extracted from a certain point in the input image. + This function returns the identity of the local feature corresponding + to the image location p. Or in other words, let P == image_to_feat_space(p), + then (*this)(P.y(),P.x()) == the local feature closest to, or centered at, + the point p in the input image. Note that some image points might not have + corresponding feature locations. E.g. border points or points outside the + image. In these cases the returned point will be outside get_rect(*this). + !*/ + + const rectangle image_to_feat_space ( + const rectangle& rect + ) const; + /*! + ensures + - returns rectangle(image_to_feat_space(rect.tl_corner()), image_to_feat_space(rect.br_corner())); + (i.e. maps a rectangle from image space to feature space) + !*/ + + const point feat_to_image_space ( + const point& p + ) const; + /*! + ensures + - returns the location in the input image space corresponding to the center + of the local feature at point p. In other words, this function computes + the inverse of image_to_feat_space(). Note that it may only do so approximately, + since more than one image location might correspond to the same local feature. + That is, image_to_feat_space() might not be invertible so this function gives + the closest possible result. + !*/ + + const rectangle feat_to_image_space ( + const rectangle& rect + ) const; + /*! + ensures + - return rectangle(feat_to_image_space(rect.tl_corner()), feat_to_image_space(rect.br_corner())); + (i.e. maps a rectangle from feature space to image space) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + long downsample + > + void serialize ( + const poly_image& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + long downsample + > + void deserialize ( + poly_image& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_POLY_ImAGE_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_keypoint/surf.h b/ml/dlib/dlib/image_keypoint/surf.h new file mode 100644 index 000000000..d12b30840 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/surf.h @@ -0,0 +1,295 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SURf_H_ +#define DLIB_SURf_H_ + +#include "surf_abstract.h" +#include "hessian_pyramid.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct surf_point + { + interest_point p; + matrix des; + double angle; + }; + +// ---------------------------------------------------------------------------------------- + + inline void serialize( + const surf_point& item, + std::ostream& out + ) + { + try + { + serialize(item.p,out); + serialize(item.des,out); + serialize(item.angle,out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type surf_point"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void deserialize( + surf_point& item, + std::istream& in + ) + { + try + { + deserialize(item.p,in); + deserialize(item.des,in); + deserialize(item.angle,in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type surf_point"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline double gaussian (double x, double y, double sig) + { + DLIB_ASSERT(sig > 0, + "\tdouble gaussian()" + << "\n\t sig must be bigger than 0" + << "\n\t sig: " << sig + ); + const double sqrt_2_pi = 2.5066282746310002416123552393401041626930; + return 1.0/(sig*sqrt_2_pi) * std::exp( -(x*x + y*y)/(2*sig*sig)); + } + +// ---------------------------------------------------------------------------------------- + + template + double compute_dominant_angle ( + const integral_image_type& img, + const dlib::vector& center, + const double& scale + ) + { + DLIB_ASSERT(get_rect(img).contains(centered_rect(center, (unsigned long)(17*scale),(unsigned long)(17*scale))) == true && + scale > 0, + "\tdouble compute_dominant_angle(img, center, scale)" + << "\n\tAll arguments to this function must be > 0" + << "\n\t get_rect(img): " << get_rect(img) + << "\n\t center: " << center + << "\n\t scale: " << scale + ); + + + std::vector ang; + std::vector > samples; + + const long sc = static_cast(scale+0.5); + + // accumulate a bunch of angle and vector samples + dlib::vector vect; + for (long r = -6; r <= 6; ++r) + { + for (long c = -6; c <= 6; ++c) + { + if (r*r + c*c < 36) + { + // compute a Gaussian weighted gradient and the gradient's angle. + const double gauss = gaussian(c,r, 2.5); + vect.x() = gauss*haar_x(img, sc*point(c,r)+center, 4*sc); + vect.y() = gauss*haar_y(img, sc*point(c,r)+center, 4*sc); + samples.push_back(vect); + ang.push_back(atan2(vect.y(), vect.x())); + } + } + } + + + // now find the dominant direction + double max_length = 0; + double best_ang = 0; + // look at a bunch of pie shaped slices of a circle + const long slices = 45; + const double ang_step = (2*pi)/slices; + for (long ang_i = 0; ang_i < slices; ++ang_i) + { + // compute the bounding angles + double ang1 = ang_step*ang_i - pi; + double ang2 = ang1 + pi/3; + + + // compute sum of all vectors that are within the above two angles + vect.x() = 0; + vect.y() = 0; + for (unsigned long i = 0; i < ang.size(); ++i) + { + if (ang1 <= ang[i] && ang[i] <= ang2) + { + vect += samples[i]; + } + else if (ang2 > pi && (ang[i] >= ang1 || ang[i] <= (-2*pi+ang2))) + { + vect += samples[i]; + } + } + + + // record the angle of the best vectors + if (length_squared(vect) > max_length) + { + max_length = length_squared(vect); + best_ang = atan2(vect.y(), vect.x()); + } + } + + return best_ang; + } + +// ---------------------------------------------------------------------------------------- + + template + void compute_surf_descriptor ( + const integral_image_type& img, + const dlib::vector& center, + const double scale, + const double angle, + matrix& des + ) + { + DLIB_ASSERT(get_rect(img).contains(centered_rect(center, (unsigned long)(32*scale),(unsigned long)(32*scale))) == true && + scale > 0, + "\tvoid compute_surf_descriptor(img, center, scale, angle)" + << "\n\tAll arguments to this function must be > 0" + << "\n\t get_rect(img): " << get_rect(img) + << "\n\t center: " << center + << "\n\t scale: " << scale + ); + + point_rotator rot(angle); + point_rotator inv_rot(-angle); + + const long sc = static_cast(scale+0.5); + long count = 0; + + // loop over the 4x4 grid of histogram buckets + for (long r = -10; r < 10; r += 5) + { + for (long c = -10; c < 10; c += 5) + { + dlib::vector vect, abs_vect, temp; + + // now loop over 25 points in this bucket and sum their features. Note + // that we include 1 pixels worth of padding around the outside of each 5x5 + // cell. This is to help neighboring cells interpolate their counts into + // each other a little bit. + for (long y = r-1; y < r+5+1; ++y) + { + if (y < -10 || y >= 10) + continue; + for (long x = c-1; x < c+5+1; ++x) + { + if (x < -10 || x >= 10) + continue; + + // get the rotated point for this extraction point + point p(rot(point(x,y)*scale) + center); + + // Give points farther from the center of the bucket a lower weight. + const long center_r = r+2; + const long center_c = c+2; + const double weight = 1.0/(4+std::abs(center_r-y) + std::abs(center_c-x)); + + temp.x() = weight*haar_x(img, p, 2*sc); + temp.y() = weight*haar_y(img, p, 2*sc); + + // rotate this vector into alignment with the surf descriptor box + temp = inv_rot(temp); + + vect += temp; + abs_vect += abs(temp); + } + } + + des(count++) = vect.x(); + des(count++) = vect.y(); + des(count++) = abs_vect.x(); + des(count++) = abs_vect.y(); + } + } + + // Return the length normalized descriptor. Add a small number + // to guard against division by zero. + const double len = length(des) + 1e-7; + des = des/len; + } + +// ---------------------------------------------------------------------------------------- + + template + const std::vector get_surf_points ( + const image_type& img, + long max_points = 10000, + double detection_threshold = 30.0 + ) + { + DLIB_ASSERT(max_points > 0 && detection_threshold >= 0, + "\t std::vector get_surf_points()" + << "\n\t Invalid arguments were given to this function." + << "\n\t max_points: " << max_points + << "\n\t detection_threshold: " << detection_threshold + ); + + // Figure out the proper scalar type we should use to work with these pixels. + typedef typename pixel_traits::pixel_type>::basic_pixel_type bp_type; + typedef typename promote::type working_pixel_type; + + // make an integral image first + integral_image_generic int_img; + int_img.load(img); + + // now make a hessian pyramid + hessian_pyramid pyr; + pyr.build_pyramid(int_img, 4, 6, 2); + + // now get all the interest points from the hessian pyramid + std::vector points; + get_interest_points(pyr, detection_threshold, points); + std::vector spoints; + + // sort all the points by how strong their detect is + std::sort(points.rbegin(), points.rend()); + + // now extract SURF descriptors for the points + surf_point sp; + for (unsigned long i = 0; i < std::min((size_t)max_points,points.size()); ++i) + { + // ignore points that are close to the edge of the image + const double border = 32; + const unsigned long border_size = static_cast(border*points[i].scale); + if (get_rect(int_img).contains(centered_rect(points[i].center, border_size, border_size))) + { + sp.angle = compute_dominant_angle(int_img, points[i].center, points[i].scale); + compute_surf_descriptor(int_img, points[i].center, points[i].scale, sp.angle, sp.des); + sp.p = points[i]; + + spoints.push_back(sp); + } + } + + return spoints; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SURf_H_ + diff --git a/ml/dlib/dlib/image_keypoint/surf_abstract.h b/ml/dlib/dlib/image_keypoint/surf_abstract.h new file mode 100644 index 000000000..e539f3e24 --- /dev/null +++ b/ml/dlib/dlib/image_keypoint/surf_abstract.h @@ -0,0 +1,163 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SURf_ABSTRACT_H_ +#ifdef DLIB_SURf_ABSTRACT_H_ + +#include "hessian_pyramid_abstract.h" +#include "../geometry/vector_abstract.h" +#include "../matrix/matrix_abstract.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + /* + The functions in this file implement the components of the SURF algorithm + for extracting scale invariant feature descriptors from images. + + For the full story on what this algorithm does and how it works + you should refer to the following papers. + + This is the original paper which introduced the algorithm: + SURF: Speeded Up Robust Features + By Herbert Bay, Tinne Tuytelaars, and Luc Van Gool + + This paper provides a nice detailed overview of how the algorithm works: + Notes on the OpenSURF Library by Christopher Evans + */ + +// ---------------------------------------------------------------------------------------- + + double gaussian ( + double x, + double y, + double sig + ); + /*! + requires + - sig > 0 + ensures + - computes and returns the value of a 2D Gaussian function with mean 0 + and standard deviation sig at the given (x,y) point. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + double compute_dominant_angle ( + const integral_image_type& img, + const dlib::vector& center, + const double& scale + ); + /*! + requires + - integral_image_type == an object such as dlib::integral_image or another + type that implements the interface defined in image_transforms/integral_image_abstract.h + - scale > 0 + - get_rect(img).contains(centered_rect(center, 17*scale, 17*scale)) == true + (i.e. center can't be within 17*scale pixels of the edge of the image) + ensures + - computes and returns the dominant angle (i.e. the angle of the dominant gradient) + at the given center point and scale in img. + - The returned angle is in radians. Specifically, if the angle is described by + a vector vect then the angle is exactly the value of std::atan2(vect.y(), vect.x()) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void compute_surf_descriptor ( + const integral_image_type& img, + const dlib::vector& center, + const double scale, + const double angle, + matrix& des + ) + /*! + requires + - integral_image_type == an object such as dlib::integral_image or another + type that implements the interface defined in image_transforms/integral_image_abstract.h + - scale > 0 + - get_rect(img).contains(centered_rect(center, 32*scale, 32*scale)) == true + (i.e. center can't be within 32*scale pixels of the edge of the image) + ensures + - computes the 64 dimensional SURF descriptor vector of a box centered + at the given center point, tilted at an angle determined by the given + angle, and sized according to the given scale. + - #des == the computed SURF descriptor vector extracted from the img object. + - The angle is measured in radians and measures the degree of counter-clockwise + rotation around the center point. This is the same kind of rotation as is + performed by the dlib::rotate_point() function. + !*/ + +// ---------------------------------------------------------------------------------------- + + struct surf_point + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a detected SURF point. The meanings of + its fields are defined below in the get_surf_points() function. + !*/ + + interest_point p; + matrix des; + double angle; + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const surf_point& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + surf_point& item, + std::istream& in + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const std::vector get_surf_points ( + const image_type& img, + long max_points = 10000, + double detection_threshold = 30.0 + ); + /*! + requires + - max_points > 0 + - detection_threshold >= 0 + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - Let P denote the type of pixel in img, then we require: + - pixel_traits

    ::has_alpha == false + ensures + - This function runs the complete SURF algorithm on the given input image and + returns the points it found. + - returns a vector V such that: + - V.size() <= max_points + - for all valid i: + - V[i] == a SURF point found in the given input image img + - V[i].p == the interest_point extracted from the hessian pyramid for this + SURF point. + - V[i].des == the SURF descriptor for this point (calculated using + compute_surf_descriptor()) + - V[i].angle == the angle of the SURF box at this point (calculated using + compute_dominant_angle()) + - V[i].p.score >= detection_threshold + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SURf_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/image_loader/image_loader.h b/ml/dlib/dlib/image_loader/image_loader.h new file mode 100644 index 000000000..4fa29dab2 --- /dev/null +++ b/ml/dlib/dlib/image_loader/image_loader.h @@ -0,0 +1,863 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMAGE_LOADEr_ +#define DLIB_IMAGE_LOADEr_ + +#include "image_loader_abstract.h" +#include +#include +#include "../algs.h" +#include "../pixel.h" +#include "../image_saver/dng_shared.h" +#include "../entropy_decoder_model.h" +#include "../entropy_decoder.h" +#include "../uintn.h" +#include "../image_transforms/assign_image.h" +#include +#include "../vectorstream.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class image_load_error : public dlib::error { + public: image_load_error(const std::string& str) : error(EIMAGE_LOAD,str){} + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void load_bmp ( + image_type& image_, + std::istream& in_ + ) + { + image_view image(image_); + try + { + unsigned long bytes_read_so_far = 0; + unsigned long bfSize; + unsigned long bfOffBits; + unsigned long bfReserved; + unsigned long biSize; + unsigned long biWidth; + unsigned long biHeight; + unsigned short biBitCount; + unsigned long biCompression; + /* + unsigned long biSizeImage; + unsigned long biClrUsed; + unsigned long biClrImportant; + */ + unsigned long a, b, c, d, i; + + using namespace std; + + streambuf& in = *in_.rdbuf(); + // streamsize num; + unsigned char buf[100]; + + + // first make sure the BMP starts with BM + if (in.sgetn(reinterpret_cast(buf),2) != 2) + throw image_load_error("bmp load error 1: header error"); + bytes_read_so_far += 2; + + if (buf[0] != 'B' || buf[1] != 'M') + throw image_load_error("bmp load error 2: header error"); + + // now read the BITMAPFILEHEADER + if (in.sgetn(reinterpret_cast(buf),12) != 12) + throw image_load_error("bmp load error 3: header error"); + + bytes_read_so_far += 12; + + i = 0; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + bfSize = a | (b<<8) | (c<<16) | (d<<24); + + i = 4; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + bfReserved = a | (b<<8) | (c<<16) | (d<<24); + + i = 8; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + bfOffBits = a | (b<<8) | (c<<16) | (d<<24); + + // if this value isn't zero then there is something wrong + // with this bitmap. + if (bfReserved != 0) + throw image_load_error("bmp load error 4: reserved area not zero"); + + + // load the BITMAPINFOHEADER + if (in.sgetn(reinterpret_cast(buf),40) != 40) + throw image_load_error("bmp load error 5: file too short"); + bytes_read_so_far += 40; + + + i = 0; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + biSize = a | (b<<8) | (c<<16) | (d<<24); + + i += 4; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + biWidth = a | (b<<8) | (c<<16) | (d<<24); + + i += 4; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + biHeight = a | (b<<8) | (c<<16) | (d<<24); + + i += 4+2; + a = buf[i]; b = buf[i+1]; + biBitCount = static_cast(a | (b<<8)); + + i += 2; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + biCompression = a | (b<<8) | (c<<16) | (d<<24); + + /* + i += 4; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + biSizeImage = a | (b<<8) | (c<<16) | (d<<24); + + i += 4+4+4; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + biClrUsed = a | (b<<8) | (c<<16) | (d<<24); + + i += 4; + a = buf[i]; b = buf[i+1]; c = buf[i+2]; d = buf[i+3]; + biClrImportant = a | (b<<8) | (c<<16) | (d<<24); + */ + + + if (biSize != 40) + throw image_load_error("bmp load error 6: header too small"); + + // read and discard any extra bytes that are part of the header + if (biSize > 40) + { + if (in.sgetn(reinterpret_cast(buf),biSize-40) != static_cast(biSize - 40)) + { + throw image_load_error("bmp load error 7: header too small"); + } + bytes_read_so_far += biSize-40; + } + + image.set_size(biHeight, biWidth); + + switch (biBitCount) + { + case 1: + { + // figure out how the pixels are packed + long padding; + if (bfSize - bfOffBits == biWidth*biHeight/8) + padding = 0; + else + padding = 4 - ((biWidth+7)/8)%4; + + const unsigned int palette_size = 2; + unsigned char red[palette_size]; + unsigned char green[palette_size]; + unsigned char blue[palette_size]; + + for (unsigned int i = 0; i < palette_size; ++i) + { + if (in.sgetn(reinterpret_cast(buf),4) != 4) + { + throw image_load_error("bmp load error 20: color palette missing"); + } + bytes_read_so_far += 4; + blue[i] = buf[0]; + green[i] = buf[1]; + red[i] = buf[2]; + } + + + // seek to the start of the pixel data + while (bytes_read_so_far != bfOffBits) + { + const long to_read = (long)std::min(bfOffBits - bytes_read_so_far, (unsigned long)sizeof(buf)); + if (in.sgetn(reinterpret_cast(buf), to_read) != to_read) + { + throw image_load_error("bmp load error: missing data"); + } + bytes_read_so_far += to_read; + } + + // load the image data + for (long row = biHeight-1; row >= 0; --row) + { + for (unsigned long col = 0; col < biWidth; col+=8) + { + if (in.sgetn(reinterpret_cast(buf),1) != 1) + { + throw image_load_error("bmp load error 21.6: file too short"); + } + + unsigned char pixels[8]; + + pixels[0] = (buf[0]>>7); + pixels[1] = ((buf[0]>>6)&0x01); + pixels[2] = ((buf[0]>>5)&0x01); + pixels[3] = ((buf[0]>>4)&0x01); + pixels[4] = ((buf[0]>>3)&0x01); + pixels[5] = ((buf[0]>>2)&0x01); + pixels[6] = ((buf[0]>>1)&0x01); + pixels[7] = ((buf[0])&0x01); + + for (int i = 0; i < 8 && col+i < biWidth; ++i) + { + rgb_pixel p; + p.red = red[pixels[i]]; + p.green = green[pixels[i]]; + p.blue = blue[pixels[i]]; + assign_pixel(image[row][col+i],p); + } + } + if (in.sgetn(reinterpret_cast(buf),padding) != padding) + throw image_load_error("bmp load error 9: file too short"); + } + + + + } break; + case 4: + { + // figure out how the pixels are packed + long padding; + if (bfSize - bfOffBits == biWidth*biHeight/2) + padding = 0; + else + padding = 4 - ((biWidth+1)/2)%4; + + const unsigned int palette_size = 16; + unsigned char red[palette_size]; + unsigned char green[palette_size]; + unsigned char blue[palette_size]; + + for (unsigned int i = 0; i < palette_size; ++i) + { + if (in.sgetn(reinterpret_cast(buf),4) != 4) + { + throw image_load_error("bmp load error 20: color palette missing"); + } + bytes_read_so_far += 4; + blue[i] = buf[0]; + green[i] = buf[1]; + red[i] = buf[2]; + } + + + // seek to the start of the pixel data + while (bytes_read_so_far != bfOffBits) + { + const long to_read = (long)std::min(bfOffBits - bytes_read_so_far, (unsigned long)sizeof(buf)); + if (in.sgetn(reinterpret_cast(buf), to_read) != to_read) + { + throw image_load_error("bmp load error: missing data"); + } + bytes_read_so_far += to_read; + } + + // load the image data + for (long row = biHeight-1; row >= 0; --row) + { + for (unsigned long col = 0; col < biWidth; col+=2) + { + if (in.sgetn(reinterpret_cast(buf),1) != 1) + { + throw image_load_error("bmp load error 21.7: file too short"); + } + + const unsigned char pixel1 = (buf[0]>>4); + const unsigned char pixel2 = (buf[0]&0x0F); + + rgb_pixel p; + p.red = red[pixel1]; + p.green = green[pixel1]; + p.blue = blue[pixel1]; + assign_pixel(image[row][col], p); + + if (col+1 < biWidth) + { + p.red = red[pixel2]; + p.green = green[pixel2]; + p.blue = blue[pixel2]; + assign_pixel(image[row][col+1], p); + } + } + if (in.sgetn(reinterpret_cast(buf),padding) != padding) + throw image_load_error("bmp load error 9: file too short"); + } + + + + } break; + case 8: + { + // figure out how the pixels are packed + long padding; + if (bfSize - bfOffBits == biWidth*biHeight) + padding = 0; + else + padding = 4 - biWidth%4; + + // check for this case. It shouldn't happen but some BMP writers screw up the files + // so we have to do this. + if (biHeight*(biWidth+padding) > bfSize - bfOffBits) + padding = 0; + + const unsigned int palette_size = 256; + unsigned char red[palette_size]; + unsigned char green[palette_size]; + unsigned char blue[palette_size]; + + for (unsigned int i = 0; i < palette_size; ++i) + { + if (in.sgetn(reinterpret_cast(buf),4) != 4) + { + throw image_load_error("bmp load error 20: color palette missing"); + } + bytes_read_so_far += 4; + blue[i] = buf[0]; + green[i] = buf[1]; + red[i] = buf[2]; + } + + + // seek to the start of the pixel data + while (bytes_read_so_far != bfOffBits) + { + const long to_read = (long)std::min(bfOffBits - bytes_read_so_far, (unsigned long)sizeof(buf)); + if (in.sgetn(reinterpret_cast(buf), to_read) != to_read) + { + throw image_load_error("bmp load error: missing data"); + } + bytes_read_so_far += to_read; + } + + // Next we load the image data. + + // if there is no RLE compression + if (biCompression == 0) + { + for (long row = biHeight-1; row >= 0; --row) + { + for (unsigned long col = 0; col < biWidth; ++col) + { + if (in.sgetn(reinterpret_cast(buf),1) != 1) + { + throw image_load_error("bmp load error 21.8: file too short"); + } + + rgb_pixel p; + p.red = red[buf[0]]; + p.green = green[buf[0]]; + p.blue = blue[buf[0]]; + assign_pixel(image[row][col],p); + } + if (in.sgetn(reinterpret_cast(buf),padding) != padding) + throw image_load_error("bmp load error 9: file too short"); + } + } + else + { + // Here we deal with the psychotic RLE used by BMP files. + + // First zero the image since the RLE sometimes jumps over + // pixels and assumes the image has been zero initialized. + assign_all_pixels(image, 0); + + long row = biHeight-1; + long col = 0; + while (true) + { + if (in.sgetn(reinterpret_cast(buf),2) != 2) + { + throw image_load_error("bmp load error 21.9: file too short"); + } + + const unsigned char count = buf[0]; + const unsigned char command = buf[1]; + + if (count == 0 && command == 0) + { + // This is an escape code that means go to the next row + // of the image + --row; + col = 0; + continue; + } + else if (count == 0 && command == 1) + { + // This is the end of the image. So quit this loop. + break; + } + else if (count == 0 && command == 2) + { + // This is the escape code for the command to jump to + // a new part of the image relative to where we are now. + if (in.sgetn(reinterpret_cast(buf),2) != 2) + { + throw image_load_error("bmp load error 21.1: file too short"); + } + col += buf[0]; + row -= buf[1]; + continue; + } + else if (count == 0) + { + // This is the escape code for a run of uncompressed bytes + + if (row < 0 || col + command > image.nc()) + { + // If this is just some padding bytes at the end then ignore them + if (row >= 0 && col + count <= image.nc() + padding) + continue; + + throw image_load_error("bmp load error 21.2: file data corrupt"); + } + + // put the bytes into the image + for (unsigned int i = 0; i < command; ++i) + { + if (in.sgetn(reinterpret_cast(buf),1) != 1) + { + throw image_load_error("bmp load error 21.3: file too short"); + } + rgb_pixel p; + p.red = red[buf[0]]; + p.green = green[buf[0]]; + p.blue = blue[buf[0]]; + assign_pixel(image[row][col],p); + + ++col; + } + + // if we read an uneven number of bytes then we need to read and + // discard the next byte. + if ((command&1) != 1) + { + if (in.sgetn(reinterpret_cast(buf),1) != 1) + { + throw image_load_error("bmp load error 21.4: file too short"); + } + } + + continue; + } + + rgb_pixel p; + + if (row < 0 || col + count > image.nc()) + { + // If this is just some padding bytes at the end then ignore them + if (row >= 0 && col + count <= image.nc() + padding) + continue; + + throw image_load_error("bmp load error 21.5: file data corrupt"); + } + + // put the bytes into the image + for (unsigned int i = 0; i < count; ++i) + { + p.red = red[command]; + p.green = green[command]; + p.blue = blue[command]; + assign_pixel(image[row][col],p); + + ++col; + } + } + } + + + + } + break; + case 16: + throw image_load_error ("16 bit BMP images not supported"); + case 24: + { + // figure out how the pixels are packed + long padding; + if (bfSize - bfOffBits == biWidth*biHeight*3) + padding = 0; + else + padding = 4 - (biWidth*3)%4; + + // check for this case. It shouldn't happen but some BMP writers screw up the files + // so we have to do this. + if (biHeight*(biWidth*3+padding) > bfSize - bfOffBits) + padding = 0; + + // seek to the start of the pixel data + while (bytes_read_so_far != bfOffBits) + { + const long to_read = (long)std::min(bfOffBits - bytes_read_so_far, (unsigned long)sizeof(buf)); + if (in.sgetn(reinterpret_cast(buf), to_read) != to_read) + { + throw image_load_error("bmp load error: missing data"); + } + bytes_read_so_far += to_read; + } + + // load the image data + for (long row = biHeight-1; row >= 0; --row) + { + for (unsigned long col = 0; col < biWidth; ++col) + { + if (in.sgetn(reinterpret_cast(buf),3) != 3) + { + throw image_load_error("bmp load error 8: file too short"); + } + + rgb_pixel p; + p.red = buf[2]; + p.green = buf[1]; + p.blue = buf[0]; + assign_pixel(image[row][col], p); + + } + if (in.sgetn(reinterpret_cast(buf),padding) != padding) + throw image_load_error("bmp load error 9: file too short"); + } + + break; + } + case 32: + throw image_load_error ("32 bit BMP images not supported"); + default: + throw image_load_error("bmp load error 10: unknown color depth"); + + } + } + catch (...) + { + image.clear(); + throw; + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void load_dng ( + image_type& image_, + std::istream& in + ) + { + image_view image(image_); + using namespace dng_helpers_namespace; + try + { + if (in.get() != 'D' || in.get() != 'N' || in.get() != 'G') + throw image_load_error("the stream does not contain a dng image file"); + + unsigned long version; + deserialize(version,in); + if (version != 1) + throw image_load_error("You need the new version of the dlib library to read this dng file"); + + unsigned long type; + deserialize(type,in); + + long width, height; + deserialize(width,in); + deserialize(height,in); + + if (width > 0 && height > 0) + image.set_size(height,width); + else + image.clear(); + + if (type != grayscale_float) + { + typedef entropy_decoder::kernel_2a decoder_type; + decoder_type decoder; + decoder.set_stream(in); + + entropy_decoder_model<256,decoder_type>::kernel_5a edm(decoder); + unsigned long symbol; + rgb_pixel p_rgb; + rgb_alpha_pixel p_rgba; + hsi_pixel p_hsi; + switch (type) + { + case rgb_alpha_paeth: + + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + p_rgba = predictor_rgb_alpha_paeth(image,r,c); + edm.decode(symbol); + p_rgba.red += static_cast(symbol); + + edm.decode(symbol); + p_rgba.green += static_cast(symbol); + + edm.decode(symbol); + p_rgba.blue += static_cast(symbol); + + edm.decode(symbol); + p_rgba.alpha += static_cast(symbol); + + assign_pixel(image[r][c],p_rgba); + } + } + break; + + case rgb_alpha: + + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + p_rgba = predictor_rgb_alpha(image,r,c); + edm.decode(symbol); + p_rgba.red += static_cast(symbol); + + edm.decode(symbol); + p_rgba.green += static_cast(symbol); + + edm.decode(symbol); + p_rgba.blue += static_cast(symbol); + + edm.decode(symbol); + p_rgba.alpha += static_cast(symbol); + + assign_pixel(image[r][c],p_rgba); + } + } + break; + + case rgb_paeth: + + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + p_rgb = predictor_rgb_paeth(image,r,c); + edm.decode(symbol); + p_rgb.red += static_cast(symbol); + + edm.decode(symbol); + p_rgb.green += static_cast(symbol); + + edm.decode(symbol); + p_rgb.blue += static_cast(symbol); + + assign_pixel(image[r][c],p_rgb); + } + } + break; + + case rgb: + + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + p_rgb = predictor_rgb(image,r,c); + edm.decode(symbol); + p_rgb.red += static_cast(symbol); + + edm.decode(symbol); + p_rgb.green += static_cast(symbol); + + edm.decode(symbol); + p_rgb.blue += static_cast(symbol); + + assign_pixel(image[r][c],p_rgb); + } + } + break; + + case hsi: + + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + p_hsi = predictor_hsi(image,r,c); + edm.decode(symbol); + p_hsi.h += static_cast(symbol); + + edm.decode(symbol); + p_hsi.s += static_cast(symbol); + + edm.decode(symbol); + p_hsi.i += static_cast(symbol); + + assign_pixel(image[r][c],p_hsi); + } + } + break; + + case grayscale: + { + unsigned char p; + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + edm.decode(symbol); + p = static_cast(symbol); + p += predictor_grayscale(image,r,c); + assign_pixel(image[r][c],p); + } + } + } + break; + + case grayscale_16bit: + { + uint16 p; + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + edm.decode(symbol); + p = static_cast(symbol); + p <<= 8; + edm.decode(symbol); + p |= static_cast(symbol); + + p += predictor_grayscale_16(image,r,c); + assign_pixel(image[r][c],p); + } + } + } + break; + + default: + throw image_load_error("corruption detected in the dng file"); + } // switch (type) + + edm.decode(symbol); + if (symbol != dng_magic_byte) + throw image_load_error("corruption detected in the dng file"); + edm.decode(symbol); + if (symbol != dng_magic_byte) + throw image_load_error("corruption detected in the dng file"); + edm.decode(symbol); + if (symbol != dng_magic_byte) + throw image_load_error("corruption detected in the dng file"); + edm.decode(symbol); + if (symbol != dng_magic_byte) + throw image_load_error("corruption detected in the dng file"); + } + else // if this is a grayscale_float type image + { + std::vector man(image.size()); + std::vector expbuf; + // get the mantissa data + for (unsigned long j = 0; j < man.size(); ++j) + deserialize(man[j], in); + // get the compressed exponent data + deserialize(expbuf, in); + typedef entropy_decoder::kernel_2a decoder_type; + typedef entropy_decoder_model<256,decoder_type>::kernel_4a edm_exp_type; + vectorstream inexp(expbuf); + decoder_type decoder; + decoder.set_stream(inexp); + + edm_exp_type edm_exp(decoder); + float_details prev; + unsigned long i = 0; + // fill out the image + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + unsigned long exp1, exp2; + edm_exp.decode(exp1); + edm_exp.decode(exp2); + + float_details cur(man[i++],(exp2<<8) | exp1); + cur.exponent += prev.exponent; + cur.mantissa += prev.mantissa; + prev = cur; + + // Only use long double precision if the target image contains long + // doubles because it's slower to use those. + if (!is_same_type::pixel_type,long double>::value) + { + double temp = cur; + assign_pixel(image[r][c],temp); + } + else + { + long double temp = cur; + assign_pixel(image[r][c],temp); + } + } + } + unsigned long symbol; + edm_exp.decode(symbol); + if (symbol != dng_magic_byte) + throw image_load_error("corruption detected in the dng file"); + edm_exp.decode(symbol); + if (symbol != dng_magic_byte) + throw image_load_error("corruption detected in the dng file"); + edm_exp.decode(symbol); + if (symbol != dng_magic_byte) + throw image_load_error("corruption detected in the dng file"); + edm_exp.decode(symbol); + if (symbol != dng_magic_byte) + throw image_load_error("corruption detected in the dng file"); + } + } + catch (...) + { + image.clear(); + throw; + } + + } + +// ---------------------------------------------------------------------------------------- + + template + void load_bmp ( + image_type& image, + const std::string& file_name + ) + { + std::ifstream fin(file_name.c_str(), std::ios::binary); + if (!fin) + throw image_load_error("Unable to open " + file_name + " for reading."); + load_bmp(image, fin); + } + +// ---------------------------------------------------------------------------------------- + + template + void load_dng ( + image_type& image, + const std::string& file_name + ) + { + std::ifstream fin(file_name.c_str(), std::ios::binary); + if (!fin) + throw image_load_error("Unable to open " + file_name + " for reading."); + load_dng(image, fin); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IMAGE_LOADEr_ + + + diff --git a/ml/dlib/dlib/image_loader/image_loader_abstract.h b/ml/dlib/dlib/image_loader/image_loader_abstract.h new file mode 100644 index 000000000..cd66b3699 --- /dev/null +++ b/ml/dlib/dlib/image_loader/image_loader_abstract.h @@ -0,0 +1,136 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_IMAGE_LOADEr_ABSTRACT_ +#ifdef DLIB_IMAGE_LOADEr_ABSTRACT_ + +#include +#include "../algs.h" +#include "../pixel.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + class image_load_error : public dlib::error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an exception used to indicate a failure to load an image. + Its type member variable will be set to EIMAGE_LOAD. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void load_bmp ( + image_type& image, + std::istream& in + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - #image == the image of the MS Windows BMP file that was available + in the input stream in. + - #image[0][0] will be the upper left corner of the image + - #image[image.nr()-1][image.nc()-1] will be the lower right + corner of the image + - Performs any color space conversion necessary to convert the + BMP image data into the pixel type used by the given image + object. + throws + - image_load_error + This exception is thrown if there is an error that prevents us + from loading the image. If this exception is thrown then + #image will have an initial value for its type. + - std::bad_alloc + If this exception is thrown then #image will have an initial + value for its type. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void load_bmp ( + image_type& image, + const std::string& file_name + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - opens the file indicated by file_name with an input file stream named fin + and performs: + load_bmp(image,fin); + !*/ + +// ---------------------------------------------------------------------------------------- + + /*! + dlib dng file format: + This is a file format I created for this library. It is a lossless + compressed image format that is similar to the PNG format but uses + the dlib PPM compression algorithms instead of the DEFLATE algorithm. + !*/ + + template < + typename image_type + > + void load_dng ( + image_type& image, + std::istream& in + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - #image == the image of the dlib dng file that was available + in the input stream in. + - #image[0][0] will be the upper left corner of the image + - #image[image.nr()-1][image.nc()-1] will be the lower right + corner of the image + - Performs any color space conversion necessary to convert the + dng image data into the pixel type used by the given image + object. + throws + - image_load_error + This exception is thrown if there is an error that prevents us + from loading the image. If this exception is thrown then + #image will have an initial value for its type. + - std::bad_alloc + If this exception is thrown then #image will have an initial + value for its type. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void load_dng ( + image_type& image, + const std::string& file_name + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - opens the file indicated by file_name with an input file stream named fin + and performs: + load_dng(image,fin); + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IMAGE_LOADEr_ABSTRACT_ + diff --git a/ml/dlib/dlib/image_loader/jpeg_loader.cpp b/ml/dlib/dlib/image_loader/jpeg_loader.cpp new file mode 100644 index 000000000..710d7586f --- /dev/null +++ b/ml/dlib/dlib/image_loader/jpeg_loader.cpp @@ -0,0 +1,173 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_JPEG_LOADER_CPp_ +#define DLIB_JPEG_LOADER_CPp_ + +// only do anything with this file if DLIB_JPEG_SUPPORT is defined +#ifdef DLIB_JPEG_SUPPORT + +#include "../array2d.h" +#include "../pixel.h" +#include "../dir_nav.h" +#include "jpeg_loader.h" +#include +#ifdef DLIB_JPEG_STATIC +# include "../external/libjpeg/jpeglib.h" +#else +# include +#endif +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + jpeg_loader:: + jpeg_loader( const char* filename ) : height_( 0 ), width_( 0 ), output_components_(0) + { + read_image( filename ); + } + +// ---------------------------------------------------------------------------------------- + + jpeg_loader:: + jpeg_loader( const std::string& filename ) : height_( 0 ), width_( 0 ), output_components_(0) + { + read_image( filename.c_str() ); + } + +// ---------------------------------------------------------------------------------------- + + jpeg_loader:: + jpeg_loader( const dlib::file& f ) : height_( 0 ), width_( 0 ), output_components_(0) + { + read_image( f.full_name().c_str() ); + } + +// ---------------------------------------------------------------------------------------- + + bool jpeg_loader::is_gray() const + { + return (output_components_ == 1); + } + +// ---------------------------------------------------------------------------------------- + + bool jpeg_loader::is_rgb() const + { + return (output_components_ == 3); + } + +// ---------------------------------------------------------------------------------------- + + bool jpeg_loader::is_rgba() const + { + return (output_components_ == 4); + } + +// ---------------------------------------------------------------------------------------- + + struct jpeg_loader_error_mgr + { + jpeg_error_mgr pub; /* "public" fields */ + jmp_buf setjmp_buffer; /* for return to caller */ + }; + + void jpeg_loader_error_exit (j_common_ptr cinfo) + { + /* cinfo->err really points to a jpeg_loader_error_mgr struct, so coerce pointer */ + jpeg_loader_error_mgr* myerr = (jpeg_loader_error_mgr*) cinfo->err; + + /* Return control to the setjmp point */ + longjmp(myerr->setjmp_buffer, 1); + } + +// ---------------------------------------------------------------------------------------- + + void jpeg_loader::read_image( const char* filename ) + { + if ( filename == NULL ) + { + throw image_load_error("jpeg_loader: invalid filename, it is NULL"); + } + FILE *fp = fopen( filename, "rb" ); + if ( !fp ) + { + throw image_load_error(std::string("jpeg_loader: unable to open file ") + filename); + } + + jpeg_decompress_struct cinfo; + jpeg_loader_error_mgr jerr; + + cinfo.err = jpeg_std_error(&jerr.pub); + + jerr.pub.error_exit = jpeg_loader_error_exit; + + /* Establish the setjmp return context for my_error_exit to use. */ + if (setjmp(jerr.setjmp_buffer)) + { + /* If we get here, the JPEG code has signaled an error. + * We need to clean up the JPEG object, close the input file, and return. + */ + jpeg_destroy_decompress(&cinfo); + fclose(fp); + throw image_load_error(std::string("jpeg_loader: error while reading ") + filename); + } + + + jpeg_create_decompress(&cinfo); + + jpeg_stdio_src(&cinfo, fp); + + jpeg_read_header(&cinfo, TRUE); + + jpeg_start_decompress(&cinfo); + + height_ = cinfo.output_height; + width_ = cinfo.output_width; + output_components_ = cinfo.output_components; + + if (output_components_ != 1 && + output_components_ != 3 && + output_components_ != 4) + { + fclose( fp ); + jpeg_destroy_decompress(&cinfo); + std::ostringstream sout; + sout << "jpeg_loader: Unsupported number of colors (" << output_components_ << ") in file " << filename; + throw image_load_error(sout.str()); + } + + std::vector rows; + rows.resize(height_); + + // size the image buffer + data.resize(height_*width_*output_components_); + + // setup pointers to each row + for (unsigned long i = 0; i < rows.size(); ++i) + rows[i] = &data[i*width_*output_components_]; + + // read the data into the buffer + while (cinfo.output_scanline < cinfo.output_height) + { + jpeg_read_scanlines(&cinfo, &rows[cinfo.output_scanline], 100); + } + + jpeg_finish_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + + fclose( fp ); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_JPEG_SUPPORT + +#endif // DLIB_JPEG_LOADER_CPp_ + + diff --git a/ml/dlib/dlib/image_loader/jpeg_loader.h b/ml/dlib/dlib/image_loader/jpeg_loader.h new file mode 100644 index 000000000..097a461f8 --- /dev/null +++ b/ml/dlib/dlib/image_loader/jpeg_loader.h @@ -0,0 +1,109 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_JPEG_IMPORT +#define DLIB_JPEG_IMPORT + +#include + +#include "jpeg_loader_abstract.h" +#include "image_loader.h" +#include "../pixel.h" +#include "../dir_nav.h" +#include "../test_for_odr_violations.h" + +namespace dlib +{ + + class jpeg_loader : noncopyable + { + public: + + jpeg_loader( const char* filename ); + jpeg_loader( const std::string& filename ); + jpeg_loader( const dlib::file& f ); + + bool is_gray() const; + bool is_rgb() const; + bool is_rgba() const; + + template + void get_image( T& t_) const + { +#ifndef DLIB_JPEG_SUPPORT + /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + You are getting this error because you are trying to use the jpeg_loader + object but you haven't defined DLIB_JPEG_SUPPORT. You must do so to use + this object. You must also make sure you set your build environment + to link against the libjpeg library. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ + COMPILE_TIME_ASSERT(sizeof(T) == 0); +#endif + image_view t(t_); + t.set_size( height_, width_ ); + for ( unsigned n = 0; n < height_;n++ ) + { + const unsigned char* v = get_row( n ); + for ( unsigned m = 0; m < width_;m++ ) + { + if ( is_gray() ) + { + unsigned char p = v[m]; + assign_pixel( t[n][m], p ); + } + else if ( is_rgba() ) { + rgb_alpha_pixel p; + p.red = v[m*4]; + p.green = v[m*4+1]; + p.blue = v[m*4+2]; + p.alpha = v[m*4+3]; + assign_pixel( t[n][m], p ); + } + else // if ( is_rgb() ) + { + rgb_pixel p; + p.red = v[m*3]; + p.green = v[m*3+1]; + p.blue = v[m*3+2]; + assign_pixel( t[n][m], p ); + } + } + } + } + + private: + const unsigned char* get_row( unsigned long i ) const + { + return &data[i*width_*output_components_]; + } + + void read_image( const char* filename ); + unsigned long height_; + unsigned long width_; + unsigned long output_components_; + std::vector data; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void load_jpeg ( + image_type& image, + const std::string& file_name + ) + { + jpeg_loader(file_name).get_image(image); + } + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "jpeg_loader.cpp" +#endif + +#endif // DLIB_JPEG_IMPORT + + diff --git a/ml/dlib/dlib/image_loader/jpeg_loader_abstract.h b/ml/dlib/dlib/image_loader/jpeg_loader_abstract.h new file mode 100644 index 000000000..48b5bb031 --- /dev/null +++ b/ml/dlib/dlib/image_loader/jpeg_loader_abstract.h @@ -0,0 +1,133 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_JPEG_IMPORT_ABSTRACT +#ifdef DLIB_JPEG_IMPORT_ABSTRACT + +#include "image_loader_abstract.h" +#include "../algs.h" +#include "../pixel.h" +#include "../dir_nav.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + + class jpeg_loader : noncopyable + { + /*! + INITIAL VALUE + Defined by the constructors + + WHAT THIS OBJECT REPRESENTS + This object represents a class capable of loading JPEG image files. + Once an instance of it is created to contain a JPEG file from + disk you can obtain the image stored in it via get_image(). + !*/ + + public: + + jpeg_loader( + const char* filename + ); + /*! + ensures + - loads the JPEG file with the given file name into this object + throws + - std::bad_alloc + - image_load_error + This exception is thrown if there is some error that prevents + us from loading the given JPEG file. + !*/ + + jpeg_loader( + const std::string& filename + ); + /*! + ensures + - loads the JPEG file with the given file name into this object + throws + - std::bad_alloc + - image_load_error + This exception is thrown if there is some error that prevents + us from loading the given JPEG file. + !*/ + + jpeg_loader( + const dlib::file& f + ); + /*! + ensures + - loads the JPEG file with the given file name into this object + throws + - std::bad_alloc + - image_load_error + This exception is thrown if there is some error that prevents + us from loading the given JPEG file. + !*/ + + ~jpeg_loader( + ); + /*! + ensures + - all resources associated with *this has been released + !*/ + + bool is_gray( + ) const; + /*! + ensures + - if (this object contains a grayscale image) then + - returns true + - else + - returns false + !*/ + + bool is_rgb( + ) const; + /*! + ensures + - if (this object contains a 3 channel RGB image) then + - returns true + - else + - returns false + !*/ + + template< + typename image_type + > + void get_image( + image_type& img + ) const; + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - loads the JPEG image stored in this object into img + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void load_jpeg ( + image_type& image, + const std::string& file_name + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - performs: jpeg_loader(file_name).get_image(image); + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_JPEG_IMPORT_ABSTRACT + diff --git a/ml/dlib/dlib/image_loader/load_image.h b/ml/dlib/dlib/image_loader/load_image.h new file mode 100644 index 000000000..64ccea9f2 --- /dev/null +++ b/ml/dlib/dlib/image_loader/load_image.h @@ -0,0 +1,226 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net), Nils Labugt, Changjiang Yang (yangcha@leidos.com) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LOAd_IMAGE_Hh_ +#define DLIB_LOAd_IMAGE_Hh_ + +#include "load_image_abstract.h" +#include "../string.h" +#include "png_loader.h" +#include "jpeg_loader.h" +#include "image_loader.h" +#include +#include +#ifdef DLIB_GIF_SUPPORT +#include +#endif + +namespace dlib +{ + namespace image_file_type + { + enum type + { + BMP, + JPG, + PNG, + DNG, + GIF, + UNKNOWN + }; + + inline type read_type(const std::string& file_name) + { + std::ifstream file(file_name.c_str(), std::ios::in|std::ios::binary); + if (!file) + throw image_load_error("Unable to open file: " + file_name); + + char buffer[9]; + file.read((char*)buffer, 8); + buffer[8] = 0; + + // Determine the true image type using link: + // http://en.wikipedia.org/wiki/List_of_file_signatures + + if (strcmp(buffer, "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A") == 0) + return PNG; + else if(buffer[0]=='\xff' && buffer[1]=='\xd8' && buffer[2]=='\xff') + return JPG; + else if(buffer[0]=='B' && buffer[1]=='M') + return BMP; + else if(buffer[0]=='D' && buffer[1]=='N' && buffer[2] == 'G') + return DNG; + else if(buffer[0]=='G' && buffer[1]=='I' && buffer[2] == 'F') + return GIF; + + return UNKNOWN; + } + }; + +// ---------------------------------------------------------------------------------------- + +// handle the differences in API between libgif v5 and older. +#if defined(GIFLIB_MAJOR) && GIFLIB_MAJOR >= 5 +#define DLIB_GIFLIB_HANDLE_DIFF_VERSIONS ,0 +#else +#define DLIB_GIFLIB_HANDLE_DIFF_VERSIONS +#endif + + template + void load_image ( + image_type& image, + const std::string& file_name + ) + { + const image_file_type::type im_type = image_file_type::read_type(file_name); + switch (im_type) + { + case image_file_type::BMP: load_bmp(image, file_name); return; + case image_file_type::DNG: load_dng(image, file_name); return; +#ifdef DLIB_PNG_SUPPORT + case image_file_type::PNG: load_png(image, file_name); return; +#endif +#ifdef DLIB_JPEG_SUPPORT + case image_file_type::JPG: load_jpeg(image, file_name); return; +#endif +#ifdef DLIB_GIF_SUPPORT + case image_file_type::GIF: + { + image_view img(image); + GifFileType* gif = DGifOpenFileName(file_name.c_str() DLIB_GIFLIB_HANDLE_DIFF_VERSIONS); + try + { + if (gif == 0) throw image_load_error("Couldn't open file " + file_name); + if (DGifSlurp(gif) != GIF_OK) + throw image_load_error("Error reading from " + file_name); + + if (gif->ImageCount != 1) throw image_load_error("Dlib only supports reading GIF files containing one image."); + if (gif->SavedImages == 0) throw image_load_error("Unsupported GIF format 1."); + + ColorMapObject* cmo=gif->SColorMap?gif->SColorMap:gif->SavedImages->ImageDesc.ColorMap; + + if (cmo==0) throw image_load_error("Unsupported GIF format 2."); + if (cmo->Colors == 0) throw image_load_error("Unsupported GIF format 3."); + if (gif->SavedImages->ImageDesc.Width != gif->SWidth) throw image_load_error("Unsupported GIF format 4."); + if (gif->SavedImages->ImageDesc.Height != gif->SHeight) throw image_load_error("Unsupported GIF format 5."); + if (gif->SavedImages->RasterBits == 0) throw image_load_error("Unsupported GIF format 6."); + if (gif->Image.Top != 0) throw image_load_error("Unsupported GIF format 7."); + if (gif->Image.Left != 0) throw image_load_error("Unsupported GIF format 8."); + + img.set_size(gif->SHeight, gif->SWidth); + unsigned char* raster = gif->SavedImages->RasterBits; + GifColorType* colormap = cmo->Colors; + if (gif->Image.Interlace) + { + const long interlaced_offset[] = { 0, 4, 2, 1 }; + const long interlaced_jumps[] = { 8, 8, 4, 2 }; + for (int i = 0; i < 4; ++i) + { + for (long r = interlaced_offset[i]; r < img.nr(); r += interlaced_jumps[i]) + { + for (long c = 0; c < img.nc(); ++c) + { + if (*raster >= cmo->ColorCount) + throw image_load_error("Invalid GIF color value"); + rgb_pixel p; + p.red = colormap[*raster].Red; + p.green = colormap[*raster].Green; + p.blue = colormap[*raster].Blue; + assign_pixel(img[r][c], p); + ++raster; + } + } + } + } + else + { + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + if (*raster >= cmo->ColorCount) + throw image_load_error("Invalid GIF color value"); + rgb_pixel p; + p.red = colormap[*raster].Red; + p.green = colormap[*raster].Green; + p.blue = colormap[*raster].Blue; + assign_pixel(img[r][c], p); + ++raster; + } + } + } + DGifCloseFile(gif DLIB_GIFLIB_HANDLE_DIFF_VERSIONS); + } + catch(...) + { + if (gif) + DGifCloseFile(gif DLIB_GIFLIB_HANDLE_DIFF_VERSIONS); + throw; + } + return; + } +#endif + default: ; + } + + if (im_type == image_file_type::JPG) + { + std::ostringstream sout; + sout << "Unable to load image in file " + file_name + ".\n" + + "You must #define DLIB_JPEG_SUPPORT and link to libjpeg to read JPEG files.\n" + + "Do this by following the instructions at http://dlib.net/compile.html.\n\n"; +#ifdef _MSC_VER + sout << "Note that you must cause DLIB_JPEG_SUPPORT to be defined for your entire project.\n"; + sout << "So don't #define it in one file. Instead, add it to the C/C++->Preprocessor->Preprocessor Definitions\n"; + sout << "field in Visual Studio's Property Pages window so it takes effect for your entire application."; +#else + sout << "Note that you must cause DLIB_JPEG_SUPPORT to be defined for your entire project.\n"; + sout << "So don't #define it in one file. Instead, use a compiler switch like -DDLIB_JPEG_SUPPORT\n"; + sout << "so it takes effect for your entire application."; +#endif + throw image_load_error(sout.str()); + } + else if (im_type == image_file_type::PNG) + { + std::ostringstream sout; + sout << "Unable to load image in file " + file_name + ".\n" + + "You must #define DLIB_PNG_SUPPORT and link to libpng to read PNG files.\n" + + "Do this by following the instructions at http://dlib.net/compile.html.\n\n"; +#ifdef _MSC_VER + sout << "Note that you must cause DLIB_PNG_SUPPORT to be defined for your entire project.\n"; + sout << "So don't #define it in one file. Instead, add it to the C/C++->Preprocessor->Preprocessor Definitions\n"; + sout << "field in Visual Studio's Property Pages window so it takes effect for your entire application.\n"; +#else + sout << "Note that you must cause DLIB_PNG_SUPPORT to be defined for your entire project.\n"; + sout << "So don't #define it in one file. Instead, use a compiler switch like -DDLIB_PNG_SUPPORT\n"; + sout << "so it takes effect for your entire application."; +#endif + throw image_load_error(sout.str()); + } + else if (im_type == image_file_type::GIF) + { + std::ostringstream sout; + sout << "Unable to load image in file " + file_name + ".\n" + + "You must #define DLIB_GIF_SUPPORT and link to libgif to read GIF files.\n\n"; +#ifdef _MSC_VER + sout << "Note that you must cause DLIB_GIF_SUPPORT to be defined for your entire project.\n"; + sout << "So don't #define it in one file. Instead, add it to the C/C++->Preprocessor->Preprocessor Definitions\n"; + sout << "field in Visual Studio's Property Pages window so it takes effect for your entire application.\n"; +#else + sout << "Note that you must cause DLIB_GIF_SUPPORT to be defined for your entire project.\n"; + sout << "So don't #define it in one file. Instead, use a compiler switch like -DDLIB_GIF_SUPPORT\n"; + sout << "so it takes effect for your entire application."; +#endif + throw image_load_error(sout.str()); + } + else + { + throw image_load_error("Unknown image file format: Unable to load image in file " + file_name); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LOAd_IMAGE_Hh_ + diff --git a/ml/dlib/dlib/image_loader/load_image_abstract.h b/ml/dlib/dlib/image_loader/load_image_abstract.h new file mode 100644 index 000000000..f357bb278 --- /dev/null +++ b/ml/dlib/dlib/image_loader/load_image_abstract.h @@ -0,0 +1,37 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LOAd_IMAGE_ABSTRACT_ +#ifdef DLIB_LOAd_IMAGE_ABSTRACT_ + +#include "../image_processing/generic_image.h" + +namespace dlib +{ + template + void load_image ( + image_type& image, + const std::string& file_name + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - This function loads an image from disk, in the indicated file file_name, and + writes it to the indicated image object. + - It is capable of reading the PNG, JPEG, BMP, GIF, and DNG image formats. It + is always capable of reading BMP and DNG images. However, for PNG, JPEG, and + GIF you must #define DLIB_PNG_SUPPORT, DLIB_JPEG_SUPPORT, and + DLIB_GIF_SUPPORT respectively and link your program to libpng, libjpeg, and + libgif respectively. + throws + - image_load_error + This exception is thrown if there is some error that prevents + us from loading the given image file. + !*/ + +} + +#endif // DLIB_LOAd_IMAGE_ABSTRACT_ + + diff --git a/ml/dlib/dlib/image_loader/png_loader.cpp b/ml/dlib/dlib/image_loader/png_loader.cpp new file mode 100644 index 000000000..3346ddb6a --- /dev/null +++ b/ml/dlib/dlib/image_loader/png_loader.cpp @@ -0,0 +1,222 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PNG_LOADER_CPp_ +#define DLIB_PNG_LOADER_CPp_ + +// only do anything with this file if DLIB_PNG_SUPPORT is defined +#ifdef DLIB_PNG_SUPPORT + +#include "../array2d.h" +#include "../pixel.h" +#include "../dir_nav.h" +#include "png_loader.h" +#include +#include "../string.h" +#include "../byte_orderer.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct LibpngData + { + png_bytep* row_pointers_; + png_structp png_ptr_; + png_infop info_ptr_; + png_infop end_info_; + }; + +// ---------------------------------------------------------------------------------------- + + png_loader:: + png_loader( const char* filename ) : height_( 0 ), width_( 0 ) + { + read_image( filename ); + } + +// ---------------------------------------------------------------------------------------- + + png_loader:: + png_loader( const std::string& filename ) : height_( 0 ), width_( 0 ) + { + read_image( filename.c_str() ); + } + +// ---------------------------------------------------------------------------------------- + + png_loader:: + png_loader( const dlib::file& f ) : height_( 0 ), width_( 0 ) + { + read_image( f.full_name().c_str() ); + } + +// ---------------------------------------------------------------------------------------- + + const unsigned char* png_loader::get_row( unsigned i ) const + { + return ld_->row_pointers_[i]; + } + +// ---------------------------------------------------------------------------------------- + + png_loader::~png_loader() + { + if ( ld_ && ld_->row_pointers_ != NULL ) + png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) ); + } + +// ---------------------------------------------------------------------------------------- + + bool png_loader::is_gray() const + { + return ( color_type_ == PNG_COLOR_TYPE_GRAY ); + } + +// ---------------------------------------------------------------------------------------- + + bool png_loader::is_graya() const + { + return ( color_type_ == PNG_COLOR_TYPE_GRAY_ALPHA ); + } + +// ---------------------------------------------------------------------------------------- + + bool png_loader::is_rgb() const + { + return ( color_type_ == PNG_COLOR_TYPE_RGB ); + } + +// ---------------------------------------------------------------------------------------- + + bool png_loader::is_rgba() const + { + return ( color_type_ == PNG_COLOR_TYPE_RGB_ALPHA ); + } + +// ---------------------------------------------------------------------------------------- + + // Don't do anything when libpng calls us to tell us about an error. Just return to + // our own code and throw an exception (at the long jump target). + void png_loader_user_error_fn_silent(png_structp png_struct, png_const_charp ) + { + longjmp(png_jmpbuf(png_struct),1); + } + void png_loader_user_warning_fn_silent(png_structp , png_const_charp ) + { + } + + void png_loader::read_image( const char* filename ) + { + ld_.reset(new LibpngData); + if ( filename == NULL ) + { + throw image_load_error("png_loader: invalid filename, it is NULL"); + } + FILE *fp = fopen( filename, "rb" ); + if ( !fp ) + { + throw image_load_error(std::string("png_loader: unable to open file ") + filename); + } + png_byte sig[8]; + if (fread( sig, 1, 8, fp ) != 8) + { + fclose( fp ); + throw image_load_error(std::string("png_loader: error reading file ") + filename); + } + if ( png_sig_cmp( sig, 0, 8 ) != 0 ) + { + fclose( fp ); + throw image_load_error(std::string("png_loader: format error in file ") + filename); + } + ld_->png_ptr_ = png_create_read_struct( PNG_LIBPNG_VER_STRING, NULL, &png_loader_user_error_fn_silent, &png_loader_user_warning_fn_silent ); + if ( ld_->png_ptr_ == NULL ) + { + fclose( fp ); + std::ostringstream sout; + sout << "Error, unable to allocate png structure while opening file " << filename << std::endl; + const char* runtime_version = png_get_header_ver(NULL); + if (runtime_version && std::strcmp(PNG_LIBPNG_VER_STRING, runtime_version) != 0) + { + sout << "This is happening because you compiled against one version of libpng, but then linked to another." << std::endl; + sout << "Compiled against libpng version: " << PNG_LIBPNG_VER_STRING << std::endl; + sout << "Linking to this version of libpng: " << runtime_version << std::endl; + } + throw image_load_error(sout.str()); + } + ld_->info_ptr_ = png_create_info_struct( ld_->png_ptr_ ); + if ( ld_->info_ptr_ == NULL ) + { + fclose( fp ); + png_destroy_read_struct( &( ld_->png_ptr_ ), ( png_infopp )NULL, ( png_infopp )NULL ); + throw image_load_error(std::string("png_loader: parse error in file ") + filename); + } + ld_->end_info_ = png_create_info_struct( ld_->png_ptr_ ); + if ( ld_->end_info_ == NULL ) + { + fclose( fp ); + png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), ( png_infopp )NULL ); + throw image_load_error(std::string("png_loader: parse error in file ") + filename); + } + + if (setjmp(png_jmpbuf(ld_->png_ptr_))) + { + // If we get here, we had a problem writing the file + fclose(fp); + png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) ); + throw image_load_error(std::string("png_loader: parse error in file ") + filename); + } + + png_set_palette_to_rgb(ld_->png_ptr_); + + png_init_io( ld_->png_ptr_, fp ); + png_set_sig_bytes( ld_->png_ptr_, 8 ); + // flags force one byte per channel output + byte_orderer bo; + int png_transforms = PNG_TRANSFORM_PACKING; + if (bo.host_is_little_endian()) + png_transforms |= PNG_TRANSFORM_SWAP_ENDIAN; + png_read_png( ld_->png_ptr_, ld_->info_ptr_, png_transforms, NULL ); + height_ = png_get_image_height( ld_->png_ptr_, ld_->info_ptr_ ); + width_ = png_get_image_width( ld_->png_ptr_, ld_->info_ptr_ ); + bit_depth_ = png_get_bit_depth( ld_->png_ptr_, ld_->info_ptr_ ); + color_type_ = png_get_color_type( ld_->png_ptr_, ld_-> info_ptr_ ); + + + if (color_type_ != PNG_COLOR_TYPE_GRAY && + color_type_ != PNG_COLOR_TYPE_RGB && + color_type_ != PNG_COLOR_TYPE_RGB_ALPHA && + color_type_ != PNG_COLOR_TYPE_GRAY_ALPHA) + { + fclose( fp ); + png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) ); + throw image_load_error(std::string("png_loader: unsupported color type in file ") + filename); + } + + if (bit_depth_ != 8 && bit_depth_ != 16) + { + fclose( fp ); + png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) ); + throw image_load_error("png_loader: unsupported bit depth of " + cast_to_string(bit_depth_) + " in file " + std::string(filename)); + } + + ld_->row_pointers_ = png_get_rows( ld_->png_ptr_, ld_->info_ptr_ ); + + fclose( fp ); + if ( ld_->row_pointers_ == NULL ) + { + png_destroy_read_struct( &( ld_->png_ptr_ ), &( ld_->info_ptr_ ), &( ld_->end_info_ ) ); + throw image_load_error(std::string("png_loader: parse error in file ") + filename); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_PNG_SUPPORT + +#endif // DLIB_PNG_LOADER_CPp_ + diff --git a/ml/dlib/dlib/image_loader/png_loader.h b/ml/dlib/dlib/image_loader/png_loader.h new file mode 100644 index 000000000..291d3fddd --- /dev/null +++ b/ml/dlib/dlib/image_loader/png_loader.h @@ -0,0 +1,223 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PNG_IMPORT +#define DLIB_PNG_IMPORT + +#include + +#include "png_loader_abstract.h" +#include "image_loader.h" +#include "../pixel.h" +#include "../dir_nav.h" +#include "../test_for_odr_violations.h" + +namespace dlib +{ + + struct LibpngData; + class png_loader : noncopyable + { + public: + + png_loader( const char* filename ); + png_loader( const std::string& filename ); + png_loader( const dlib::file& f ); + ~png_loader(); + + bool is_gray() const; + bool is_graya() const; + bool is_rgb() const; + bool is_rgba() const; + + unsigned int bit_depth () const { return bit_depth_; } + + template + void get_image( T& t_) const + { +#ifndef DLIB_PNG_SUPPORT + /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + You are getting this error because you are trying to use the png_loader + object but you haven't defined DLIB_PNG_SUPPORT. You must do so to use + this object. You must also make sure you set your build environment + to link against the libpng library. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ + COMPILE_TIME_ASSERT(sizeof(T) == 0); +#endif + + typedef typename image_traits::pixel_type pixel_type; + image_view t(t_); + t.set_size( height_, width_ ); + + + if (is_gray() && bit_depth_ == 8) + { + for ( unsigned n = 0; n < height_;n++ ) + { + const unsigned char* v = get_row( n ); + for ( unsigned m = 0; m < width_;m++ ) + { + unsigned char p = v[m]; + assign_pixel( t[n][m], p ); + } + } + } + else if (is_gray() && bit_depth_ == 16) + { + for ( unsigned n = 0; n < height_;n++ ) + { + const uint16* v = (uint16*)get_row( n ); + for ( unsigned m = 0; m < width_;m++ ) + { + dlib::uint16 p = v[m]; + assign_pixel( t[n][m], p ); + } + } + } + else if (is_graya() && bit_depth_ == 8) + { + for ( unsigned n = 0; n < height_;n++ ) + { + const unsigned char* v = get_row( n ); + for ( unsigned m = 0; m < width_; m++ ) + { + unsigned char p = v[m*2]; + if (!pixel_traits::has_alpha) + { + assign_pixel( t[n][m], p ); + } + else + { + unsigned char pa = v[m*2+1]; + rgb_alpha_pixel pix; + assign_pixel(pix, p); + assign_pixel(pix.alpha, pa); + assign_pixel(t[n][m], pix); + } + } + } + } + else if (is_graya() && bit_depth_ == 16) + { + for ( unsigned n = 0; n < height_;n++ ) + { + const uint16* v = (uint16*)get_row( n ); + for ( unsigned m = 0; m < width_; m++ ) + { + dlib::uint16 p = v[m*2]; + if (!pixel_traits::has_alpha) + { + assign_pixel( t[n][m], p ); + } + else + { + dlib::uint16 pa = v[m*2+1]; + rgb_alpha_pixel pix; + assign_pixel(pix, p); + assign_pixel(pix.alpha, pa); + assign_pixel(t[n][m], pix); + } + } + } + } + else if (is_rgb() && bit_depth_ == 8) + { + for ( unsigned n = 0; n < height_;n++ ) + { + const unsigned char* v = get_row( n ); + for ( unsigned m = 0; m < width_;m++ ) + { + rgb_pixel p; + p.red = v[m*3]; + p.green = v[m*3+1]; + p.blue = v[m*3+2]; + assign_pixel( t[n][m], p ); + } + } + } + else if (is_rgb() && bit_depth_ == 16) + { + for ( unsigned n = 0; n < height_;n++ ) + { + const uint16* v = (uint16*)get_row( n ); + for ( unsigned m = 0; m < width_;m++ ) + { + rgb_pixel p; + p.red = static_cast(v[m*3]); + p.green = static_cast(v[m*3+1]); + p.blue = static_cast(v[m*3+2]); + assign_pixel( t[n][m], p ); + } + } + } + else if (is_rgba() && bit_depth_ == 8) + { + if (!pixel_traits::has_alpha) + assign_all_pixels(t,0); + + for ( unsigned n = 0; n < height_;n++ ) + { + const unsigned char* v = get_row( n ); + for ( unsigned m = 0; m < width_;m++ ) + { + rgb_alpha_pixel p; + p.red = v[m*4]; + p.green = v[m*4+1]; + p.blue = v[m*4+2]; + p.alpha = v[m*4+3]; + assign_pixel( t[n][m], p ); + } + } + } + else if (is_rgba() && bit_depth_ == 16) + { + if (!pixel_traits::has_alpha) + assign_all_pixels(t,0); + + for ( unsigned n = 0; n < height_;n++ ) + { + const uint16* v = (uint16*)get_row( n ); + for ( unsigned m = 0; m < width_;m++ ) + { + rgb_alpha_pixel p; + p.red = static_cast(v[m*4]); + p.green = static_cast(v[m*4+1]); + p.blue = static_cast(v[m*4+2]); + p.alpha = static_cast(v[m*4+3]); + assign_pixel( t[n][m], p ); + } + } + } + } + + private: + const unsigned char* get_row( unsigned i ) const; + void read_image( const char* filename ); + unsigned height_, width_; + unsigned bit_depth_; + int color_type_; + std::unique_ptr ld_; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void load_png ( + image_type& image, + const std::string& file_name + ) + { + png_loader(file_name).get_image(image); + } + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "png_loader.cpp" +#endif + +#endif // DLIB_PNG_IMPORT + diff --git a/ml/dlib/dlib/image_loader/png_loader_abstract.h b/ml/dlib/dlib/image_loader/png_loader_abstract.h new file mode 100644 index 000000000..d81e7f83a --- /dev/null +++ b/ml/dlib/dlib/image_loader/png_loader_abstract.h @@ -0,0 +1,162 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_PNG_IMPORT_ABSTRACT +#ifdef DLIB_PNG_IMPORT_ABSTRACT + +#include "image_loader_abstract.h" +#include "../algs.h" +#include "../pixel.h" +#include "../dir_nav.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + + class png_loader : noncopyable + { + /*! + INITIAL VALUE + Defined by the constructors + + WHAT THIS OBJECT REPRESENTS + This object represents a class capable of loading PNG image files. + Once an instance of it is created to contain a PNG file from + disk you can obtain the image stored in it via get_image(). + !*/ + + public: + + png_loader( + const char* filename + ); + /*! + ensures + - loads the PNG file with the given file name into this object + throws + - std::bad_alloc + - image_load_error + This exception is thrown if there is some error that prevents + us from loading the given PNG file. + !*/ + + png_loader( + const std::string& filename + ); + /*! + ensures + - loads the PNG file with the given file name into this object + throws + - std::bad_alloc + - image_load_error + This exception is thrown if there is some error that prevents + us from loading the given PNG file. + !*/ + + png_loader( + const dlib::file& f + ); + /*! + ensures + - loads the PNG file with the given file name into this object + throws + - std::bad_alloc + - image_load_error + This exception is thrown if there is some error that prevents + us from loading the given PNG file. + !*/ + + ~png_loader( + ); + /*! + ensures + - all resources associated with *this has been released + !*/ + + bool is_gray( + ) const; + /*! + ensures + - if (this object contains a grayscale image without an alpha channel) then + - returns true + - else + - returns false + !*/ + + bool is_graya( + ) const; + /*! + ensures + - if (this object contains a grayscale image with an alpha channel) then + - returns true + - else + - returns false + !*/ + + bool is_rgb( + ) const; + /*! + ensures + - if (this object contains a 3 channel RGB image) then + - returns true + - else + - returns false + !*/ + + bool is_rgba( + ) const; + /*! + ensures + - if (this object contains a 4 channel RGB alpha image) then + - returns true + - else + - returns false + !*/ + + unsigned int bit_depth ( + ) const; + /*! + ensures + - returns the number of bits per channel in the image contained by this + object. The possible values are 8 or 16. + !*/ + + template< + typename image_type + > + void get_image( + image_type& img + ) const; + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - loads the PNG image stored in this object into img + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void load_png ( + image_type& image, + const std::string& file_name + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - performs: png_loader(file_name).get_image(image); + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_PNG_IMPORT_ABSTRACT + + diff --git a/ml/dlib/dlib/image_processing.h b/ml/dlib/dlib/image_processing.h new file mode 100644 index 000000000..a53f4a9d1 --- /dev/null +++ b/ml/dlib/dlib/image_processing.h @@ -0,0 +1,28 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_IMAGE_PROCESSInG_H_h_ +#define DLIB_IMAGE_PROCESSInG_H_h_ + +#include "image_processing/scan_image.h" +#include "image_processing/scan_image_pyramid.h" +#include "image_processing/detection_template_tools.h" +#include "image_processing/object_detector.h" +#include "image_processing/box_overlap_testing.h" +#include "image_processing/scan_image_pyramid_tools.h" +#include "image_processing/setup_hashed_features.h" +#include "image_processing/scan_image_boxes.h" +#include "image_processing/scan_image_custom.h" +#include "image_processing/remove_unobtainable_rectangles.h" +#include "image_processing/scan_fhog_pyramid.h" +#include "image_processing/shape_predictor.h" +#include "image_processing/shape_predictor_trainer.h" +#include "image_processing/correlation_tracker.h" + +#endif // DLIB_IMAGE_PROCESSInG_H_h_ + + diff --git a/ml/dlib/dlib/image_processing/box_overlap_testing.h b/ml/dlib/dlib/image_processing/box_overlap_testing.h new file mode 100644 index 000000000..32409d13e --- /dev/null +++ b/ml/dlib/dlib/image_processing/box_overlap_testing.h @@ -0,0 +1,215 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOX_OVERlAP_TESTING_Hh_ +#define DLIB_BOX_OVERlAP_TESTING_Hh_ + +#include "box_overlap_testing_abstract.h" +#include "../geometry.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline double box_intersection_over_union ( + const drectangle& a, + const drectangle& b + ) + { + const double inner = a.intersect(b).area(); + if (inner == 0) + return 0; + const double outer = (a+b).area(); + return inner/outer; + } + +// ---------------------------------------------------------------------------------------- + + inline double box_intersection_over_union ( + const rectangle& a, + const rectangle& b + ) + { + return box_intersection_over_union(drectangle(a),drectangle(b)); + } + +// ---------------------------------------------------------------------------------------- + + inline double box_percent_covered ( + const drectangle& a, + const drectangle& b + ) + { + const double inner = a.intersect(b).area(); + if (inner == 0) + return 0; + return std::max(inner/a.area(), inner/b.area()); + } + +// ---------------------------------------------------------------------------------------- + + inline double box_percent_covered ( + const rectangle& a, + const rectangle& b + ) + { + return box_percent_covered(drectangle(a), drectangle(b)); + } + +// ---------------------------------------------------------------------------------------- + + class test_box_overlap + { + public: + test_box_overlap ( + ) : iou_thresh(0.5), percent_covered_thresh(1.0) + {} + + explicit test_box_overlap ( + double iou_thresh_, + double percent_covered_thresh_ = 1.0 + ) : iou_thresh(iou_thresh_), percent_covered_thresh(percent_covered_thresh_) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= iou_thresh && iou_thresh <= 1 && + 0 <= percent_covered_thresh && percent_covered_thresh <= 1, + "\t test_box_overlap::test_box_overlap(iou_thresh, percent_covered_thresh)" + << "\n\t Invalid inputs were given to this function " + << "\n\t iou_thresh: " << iou_thresh + << "\n\t percent_covered_thresh: " << percent_covered_thresh + << "\n\t this: " << this + ); + + } + + bool operator() ( + const dlib::rectangle& a, + const dlib::rectangle& b + ) const + { + const double inner = a.intersect(b).area(); + if (inner == 0) + return false; + + const double outer = (a+b).area(); + if (inner/outer > iou_thresh || + inner/a.area() > percent_covered_thresh || + inner/b.area() > percent_covered_thresh) + return true; + else + return false; + } + + double get_percent_covered_thresh ( + ) const + { + return percent_covered_thresh; + } + + double get_iou_thresh ( + ) const + { + return iou_thresh; + } + + private: + double iou_thresh; + double percent_covered_thresh; + }; + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const test_box_overlap& item, + std::ostream& out + ) + { + serialize(item.get_iou_thresh(), out); + serialize(item.get_percent_covered_thresh(), out); + } + + inline void deserialize ( + test_box_overlap& item, + std::istream& in + ) + { + double percent_covered_thresh, iou_thresh; + deserialize(iou_thresh, in); + deserialize(percent_covered_thresh, in); + item = test_box_overlap(iou_thresh, percent_covered_thresh); + } + +// ---------------------------------------------------------------------------------------- + + inline test_box_overlap find_tight_overlap_tester ( + const std::vector >& rects + ) + { + double max_pcov = 0; + double max_iou_score = 0; + for (unsigned long i = 0; i < rects.size(); ++i) + { + for (unsigned long j = 0; j < rects[i].size(); ++j) + { + for (unsigned long k = j+1; k < rects[i].size(); ++k) + { + const rectangle a = rects[i][j]; + const rectangle b = rects[i][k]; + const double iou_score = (a.intersect(b)).area()/(double)(a+b).area(); + const double pcov_a = (a.intersect(b)).area()/(double)(a).area(); + const double pcov_b = (a.intersect(b)).area()/(double)(b).area(); + + if (iou_score > max_iou_score) + max_iou_score = iou_score; + + if (pcov_a > max_pcov) + max_pcov = pcov_a; + if (pcov_b > max_pcov) + max_pcov = pcov_b; + } + } + } + + // Relax these thresholds very slightly. We do this because on some systems the + // boxes that generated the max values erroneously trigger a box overlap iou even + // though their percent covered and iou values are *equal* to the thresholds but + // not greater. That is, sometimes when double values get moved around they change + // their values slightly, so this avoids the problems that can create. + max_iou_score = std::min(1.0000001*max_iou_score, 1.0); + max_pcov = std::min(1.0000001*max_pcov, 1.0); + return test_box_overlap(max_iou_score, max_pcov); + } + +// ---------------------------------------------------------------------------------------- + + inline bool overlaps_any_box ( + const test_box_overlap& tester, + const std::vector& rects, + const rectangle& rect + ) + { + for (unsigned long i = 0; i < rects.size(); ++i) + { + if (tester(rects[i],rect)) + return true; + } + return false; + } + +// ---------------------------------------------------------------------------------------- + + inline bool overlaps_any_box ( + const std::vector& rects, + const rectangle& rect + ) + { + return overlaps_any_box(test_box_overlap(),rects,rect); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOX_OVERlAP_TESTING_Hh_ + diff --git a/ml/dlib/dlib/image_processing/box_overlap_testing_abstract.h b/ml/dlib/dlib/image_processing/box_overlap_testing_abstract.h new file mode 100644 index 000000000..1bb4a28ae --- /dev/null +++ b/ml/dlib/dlib/image_processing/box_overlap_testing_abstract.h @@ -0,0 +1,201 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BOX_OVERlAP_TESTING_ABSTRACT_Hh_ +#ifdef DLIB_BOX_OVERlAP_TESTING_ABSTRACT_Hh_ + +#include "../geometry.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline double box_intersection_over_union ( + const drectangle& a, + const drectangle& b + ); + /*! + ensures + - returns area of the intersection of a and b divided by (a+b).area(). If both + boxes are empty then returns 0. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline double box_intersection_over_union ( + const rectangle& a, + const rectangle& b + ); + /*! + ensures + - returns area of the intersection of a and b divided by (a+b).area(). If both + boxes are empty then returns 0. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline double box_percent_covered ( + const drectangle& a, + const drectangle& b + ); + /*! + ensures + - let OVERLAP = a.intersect(b).area() + - This function returns max(OVERLAP/a.area(), OVERLAP/b.area()) + e.g. If one box entirely contains another then this function returns 1, if + they don't overlap at all it returns 0. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline double box_percent_covered ( + const rectangle& a, + const rectangle& b + ); + /*! + ensures + - let OVERLAP = a.intersect(b).area() + - This function returns max(OVERLAP/a.area(), OVERLAP/b.area()) + e.g. If one box entirely contains another then this function returns 1, if + they don't overlap at all it returns 0. + !*/ + +// ---------------------------------------------------------------------------------------- + + class test_box_overlap + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple function object for determining if two rectangles + overlap. + + THREAD SAFETY + Concurrent access to an instance of this object is safe provided that + only const member functions are invoked. Otherwise, access must be + protected by a mutex lock. + !*/ + + public: + test_box_overlap ( + ); + /*! + ensures + - #get_iou_thresh() == 0.5 + - #get_percent_covered_thresh() == 1.0 + !*/ + + explicit test_box_overlap ( + double iou_thresh, + double percent_covered_thresh = 1.0 + ); + /*! + requires + - 0 <= iou_thresh <= 1 + - 0 <= percent_covered_thresh <= 1 + ensures + - #get_iou_thresh() == iou_thresh + - #get_percent_covered_thresh() == percent_covered_thresh + !*/ + + bool operator() ( + const dlib::rectangle& a, + const dlib::rectangle& b + ) const; + /*! + ensures + - returns true if a and b overlap "enough". This is defined precisely below. + - if (a.intersect(b).area()/(a+b).area() > get_iou_thresh() || + a.intersect(b).area()/a.area() > get_percent_covered_thresh() || + a.intersect(b).area()/b.area() > get_percent_covered_thresh() ) then + - returns true + - else + - returns false + !*/ + + double get_iou_thresh ( + ) const; + /*! + ensures + - returns the threshold used to determine if two rectangle's intersection + over union value is big enough to be considered a match. Note that the + iou score varies from 0 to 1 and only becomes 1 when two rectangles are + identical. + !*/ + + double get_percent_covered_thresh ( + ) const; + /*! + ensures + - returns the threshold used to determine if two rectangles overlap. This + value is the percent of a rectangle's area covered by another rectangle. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const test_box_overlap& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + test_box_overlap& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + test_box_overlap find_tight_overlap_tester ( + const std::vector >& rects + ); + /*! + ensures + - This function finds the most restrictive test_box_overlap object possible + that is consistent with the given set of sets of rectangles. + - To be precise, this function finds and returns a test_box_overlap object + TBO such that: + - TBO.get_iou_thresh() and TBO.get_percent_covered_thresh() are as small + as possible such that the following conditions are satisfied. + - for all valid i: + - for all distinct rectangles A and B in rects[i]: + - TBO(A,B) == false + !*/ + +// ---------------------------------------------------------------------------------------- + + bool overlaps_any_box ( + const test_box_overlap& tester, + const std::vector& rects, + const rectangle& rect + ); + /*! + ensures + - returns true if rect overlaps any box in rects and false otherwise. Overlap + is determined based on the given tester object. + !*/ + +// ---------------------------------------------------------------------------------------- + + bool overlaps_any_box ( + const std::vector& rects, + const rectangle& rect + ); + /*! + ensures + - returns overlaps_any_box(test_box_overlap(), rects, rect) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOX_OVERlAP_TESTING_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/correlation_tracker.h b/ml/dlib/dlib/image_processing/correlation_tracker.h new file mode 100644 index 000000000..f005ddc7b --- /dev/null +++ b/ml/dlib/dlib/image_processing/correlation_tracker.h @@ -0,0 +1,404 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CORRELATION_TrACKER_H_ +#define DLIB_CORRELATION_TrACKER_H_ + +#include "correlation_tracker_abstract.h" +#include "../geometry.h" +#include "../matrix.h" +#include "../array2d.h" +#include "../image_transforms/assign_image.h" +#include "../image_transforms/interpolation.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class correlation_tracker + { + public: + + explicit correlation_tracker (unsigned long filter_size = 6, + unsigned long num_scale_levels = 5, + unsigned long scale_window_size = 23, + double regularizer_space = 0.001, + double nu_space = 0.025, + double regularizer_scale = 0.001, + double nu_scale = 0.025, + double scale_pyramid_alpha = 1.020 + ) + : filter_size(1 << filter_size), num_scale_levels(1 << num_scale_levels), + scale_window_size(scale_window_size), + regularizer_space(regularizer_space), nu_space(nu_space), + regularizer_scale(regularizer_scale), nu_scale(nu_scale), + scale_pyramid_alpha(scale_pyramid_alpha) + { + // Create the cosine mask used for space filtering. + mask = make_cosine_mask(); + + // Create the cosine mask used for the scale filtering. + scale_cos_mask.resize(get_num_scale_levels()); + const long max_level = get_num_scale_levels()/2; + for (unsigned long k = 0; k < get_num_scale_levels(); ++k) + { + double dist = std::abs((double)k-max_level)/max_level*pi/2; + dist = std::min(dist, pi/2); + scale_cos_mask[k] = std::cos(dist); + } + } + + template + void start_track ( + const image_type& img, + const drectangle& p + ) + { + DLIB_CASSERT(p.is_empty() == false, + "\t void correlation_tracker::start_track()" + << "\n\t You can't give an empty rectangle." + ); + + B.set_size(0,0); + + point_transform_affine tform = inv(make_chip(img, p, F)); + for (unsigned long i = 0; i < F.size(); ++i) + fft_inplace(F[i]); + make_target_location_image(tform(center(p)), G); + A.resize(F.size()); + for (unsigned long i = 0; i < F.size(); ++i) + { + A[i] = pointwise_multiply(G, F[i]); + B += squared(real(F[i]))+squared(imag(F[i])); + } + + position = p; + + // now do the scale space stuff + make_scale_space(img, Fs); + for (unsigned long i = 0; i < Fs.size(); ++i) + fft_inplace(Fs[i]); + make_scale_target_location_image(get_num_scale_levels()/2, Gs); + Bs.set_size(0); + As.resize(Fs.size()); + for (unsigned long i = 0; i < Fs.size(); ++i) + { + As[i] = pointwise_multiply(Gs, Fs[i]); + Bs += squared(real(Fs[i]))+squared(imag(Fs[i])); + } + } + + + unsigned long get_filter_size ( + ) const { return filter_size; } + + unsigned long get_num_scale_levels( + ) const { return num_scale_levels; } + + unsigned long get_scale_window_size ( + ) const { return scale_window_size; } + + double get_regularizer_space ( + ) const { return regularizer_space; } + inline double get_nu_space ( + ) const { return nu_space;} + + double get_regularizer_scale ( + ) const { return regularizer_scale; } + double get_nu_scale ( + ) const { return nu_scale;} + + drectangle get_position ( + ) const + { + return position; + } + + double get_scale_pyramid_alpha ( + ) const { return scale_pyramid_alpha; } + + + template + double update_noscale( + const image_type& img, + const drectangle& guess + ) + { + DLIB_CASSERT(get_position().is_empty() == false, + "\t double correlation_tracker::update()" + << "\n\t You must call start_track() first before calling update()." + ); + + + const point_transform_affine tform = make_chip(img, guess, F); + for (unsigned long i = 0; i < F.size(); ++i) + fft_inplace(F[i]); + + // use the current filter to predict the object's location + G = 0; + for (unsigned long i = 0; i < F.size(); ++i) + G += pointwise_multiply(F[i],conj(A[i])); + G = pointwise_multiply(G, reciprocal(B+get_regularizer_space())); + ifft_inplace(G); + const dlib::vector pp = max_point_interpolated(real(G)); + + + // Compute the peak to side lobe ratio. + const point p = pp; + running_stats rs; + const rectangle peak = centered_rect(p, 8,8); + for (long r = 0; r < G.nr(); ++r) + { + for (long c = 0; c < G.nc(); ++c) + { + if (!peak.contains(point(c,r))) + rs.add(G(r,c).real()); + } + } + const double psr = (G(p.y(),p.x()).real()-rs.mean())/rs.stddev(); + + // update the position of the object + position = translate_rect(guess, tform(pp)-center(guess)); + + // now update the position filters + make_target_location_image(pp, G); + B *= (1-get_nu_space()); + for (unsigned long i = 0; i < F.size(); ++i) + { + A[i] = get_nu_space()*pointwise_multiply(G, F[i]) + (1-get_nu_space())*A[i]; + B += get_nu_space()*(squared(real(F[i]))+squared(imag(F[i]))); + } + + return psr; + } + + template + double update ( + const image_type& img, + const drectangle& guess + ) + { + double psr = update_noscale(img, guess); + + // Now predict the scale change + make_scale_space(img, Fs); + for (unsigned long i = 0; i < Fs.size(); ++i) + fft_inplace(Fs[i]); + Gs = 0; + for (unsigned long i = 0; i < Fs.size(); ++i) + Gs += pointwise_multiply(Fs[i],conj(As[i])); + Gs = pointwise_multiply(Gs, reciprocal(Bs+get_regularizer_scale())); + ifft_inplace(Gs); + const double pos = max_point_interpolated(real(Gs)).y(); + + // update the rectangle's scale + position *= std::pow(get_scale_pyramid_alpha(), pos-(double)get_num_scale_levels()/2); + + + + // Now update the scale filters + make_scale_target_location_image(pos, Gs); + Bs *= (1-get_nu_scale()); + for (unsigned long i = 0; i < Fs.size(); ++i) + { + As[i] = get_nu_scale()*pointwise_multiply(Gs, Fs[i]) + (1-get_nu_scale())*As[i]; + Bs += get_nu_scale()*(squared(real(Fs[i]))+squared(imag(Fs[i]))); + } + + + return psr; + } + + template + double update_noscale ( + const image_type& img + ) + { + return update_noscale(img, get_position()); + } + + template + double update( + const image_type& img + ) + { + return update(img, get_position()); + } + + private: + + template + void make_scale_space( + const image_type& img, + std::vector,0,1> >& Fs + ) const + { + typedef typename image_traits::pixel_type pixel_type; + + // Make an image pyramid and put it into the chips array. + const long chip_size = get_scale_window_size(); + drectangle ppp = position*std::pow(get_scale_pyramid_alpha(), -(double)get_num_scale_levels()/2); + dlib::array > chips; + std::vector > from_points, to_points; + from_points.push_back(point(0,0)); + from_points.push_back(point(chip_size-1,0)); + from_points.push_back(point(chip_size-1,chip_size-1)); + for (unsigned long i = 0; i < get_num_scale_levels(); ++i) + { + array2d chip(chip_size,chip_size); + + // pull box into chip + to_points.clear(); + to_points.push_back(ppp.tl_corner()); + to_points.push_back(ppp.tr_corner()); + to_points.push_back(ppp.br_corner()); + transform_image(img,chip,interpolate_bilinear(),find_affine_transform(from_points, to_points)); + + chips.push_back(chip); + ppp *= get_scale_pyramid_alpha(); + } + + + // extract HOG for each chip + dlib::array > > hogs(chips.size()); + for (unsigned long i = 0; i < chips.size(); ++i) + { + extract_fhog_features(chips[i], hogs[i], 4); + hogs[i].resize(32); + assign_image(hogs[i][31], chips[i]); + assign_image(hogs[i][31], mat(hogs[i][31])/255.0); + } + + // Now copy the hog features into the Fs outputs and also apply the cosine + // windowing. + Fs.resize(hogs[0].size()*hogs[0][0].size()); + unsigned long i = 0; + for (long r = 0; r < hogs[0][0].nr(); ++r) + { + for (long c = 0; c < hogs[0][0].nc(); ++c) + { + for (unsigned long j = 0; j < hogs[0].size(); ++j) + { + Fs[i].set_size(hogs.size()); + for (unsigned long k = 0; k < hogs.size(); ++k) + { + Fs[i](k) = hogs[k][j][r][c]*scale_cos_mask[k]; + } + ++i; + } + } + } + } + + template + point_transform_affine make_chip ( + const image_type& img, + drectangle p, + std::vector > >& chip + ) const + { + typedef typename image_traits::pixel_type pixel_type; + array2d temp; + const double padding = 1.4; + const chip_details details(p*padding, chip_dims(get_filter_size(), get_filter_size())); + extract_image_chip(img, details, temp); + + + chip.resize(32); + dlib::array > hog; + extract_fhog_features(temp, hog, 1, 3,3 ); + for (unsigned long i = 0; i < hog.size(); ++i) + assign_image(chip[i], pointwise_multiply(matrix_cast(mat(hog[i])), mask)); + + assign_image(chip[31], temp); + assign_image(chip[31], pointwise_multiply(mat(chip[31]), mask)/255.0); + + return inv(get_mapping_to_chip(details)); + } + + void make_target_location_image ( + const dlib::vector& p, + matrix >& g + ) const + { + g.set_size(get_filter_size(), get_filter_size()); + g = 0; + rectangle area = centered_rect(p, 21,21).intersect(get_rect(g)); + for (long r = area.top(); r <= area.bottom(); ++r) + { + for (long c = area.left(); c <= area.right(); ++c) + { + double dist = length(point(c,r)-p); + g(r,c) = std::exp(-dist/3.0); + } + } + fft_inplace(g); + g = conj(g); + } + + + void make_scale_target_location_image ( + const double scale, + matrix,0,1>& g + ) const + { + g.set_size(get_num_scale_levels()); + for (long i = 0; i < g.size(); ++i) + { + double dist = std::pow((i-scale),2.0); + g(i) = std::exp(-dist/1.000); + } + fft_inplace(g); + g = conj(g); + } + + matrix make_cosine_mask ( + ) const + { + const long size = get_filter_size(); + matrix temp(size,size); + point cent = center(get_rect(temp)); + for (long r = 0; r < temp.nr(); ++r) + { + for (long c = 0; c < temp.nc(); ++c) + { + point delta = point(c,r)-cent; + double dist = length(delta)/(size/2.0)*(pi/2); + dist = std::min(dist*1.0, pi/2); + + temp(r,c) = std::cos(dist); + } + } + return temp; + } + + + std::vector > > A, F; + matrix B; + + std::vector,0,1> > As, Fs; + matrix Bs; + drectangle position; + + matrix mask; + std::vector scale_cos_mask; + + // G and Gs do not logically contribute to the state of this object. They are + // here just so we can void reallocating them over and over. + matrix > G; + matrix,0,1> Gs; + + unsigned long filter_size; + unsigned long num_scale_levels; + unsigned long scale_window_size; + double regularizer_space; + double nu_space; + double regularizer_scale; + double nu_scale; + double scale_pyramid_alpha; + }; +} + +#endif // DLIB_CORRELATION_TrACKER_H_ + diff --git a/ml/dlib/dlib/image_processing/correlation_tracker_abstract.h b/ml/dlib/dlib/image_processing/correlation_tracker_abstract.h new file mode 100644 index 000000000..5514f5e76 --- /dev/null +++ b/ml/dlib/dlib/image_processing/correlation_tracker_abstract.h @@ -0,0 +1,162 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CORRELATION_TrACKER_ABSTRACT_H_ +#ifdef DLIB_CORRELATION_TrACKER_ABSTRACT_H_ + +#include "../geometry/drectangle_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class correlation_tracker + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool for tracking moving objects in a video stream. You give it + the bounding box of an object in the first frame and it attempts to track the + object in the box from frame to frame. + + This tool is an implementation of the method described in the following paper: + Danelljan, Martin, et al. "Accurate scale estimation for robust visual + tracking." Proceedings of the British Machine Vision Conference BMVC. 2014. + !*/ + + public: + + explicit correlation_tracker (unsigned long filter_size = 6, + unsigned long num_scale_levels = 5, + unsigned long scale_window_size = 23, + double regularizer_space = 0.001, + double nu_space = 0.025, + double regularizer_scale = 0.001, + double nu_scale = 0.025, + double scale_pyramid_alpha = 1.020 + ); + /*! + requires + - p.is_empty() == false + ensures + - Initializes correlation_tracker. Higher value of filter_size and + num_scale_levels increases tracking precision but requires more CPU + for processing. Recommended values for filter_size = 5-7, + default = 6, for num_scale_levels = 4-6, default = 5 + - #get_position().is_empty() == true + !*/ + + template < + typename image_type + > + void start_track ( + const image_type& img, + const drectangle& p + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - p.is_empty() == false + ensures + - This object will start tracking the thing inside the bounding box in the + given image. That is, if you call update() with subsequent video frames + then it will try to keep track of the position of the object inside p. + - #get_position() == p + !*/ + + drectangle get_position ( + ) const; + /*! + ensures + - returns the predicted position of the object under track. + !*/ + + template < + typename image_type + > + double update_noscale ( + const image_type& img, + const drectangle& guess + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - get_position().is_empty() == false + (i.e. you must have started tracking by calling start_track()) + ensures + - When searching for the object in img, we search in the area around the + provided guess. This function only tracks object position without trying + to track the scale + - #get_position() == the new predicted location of the object in img. This + location will be a copy of guess that has been translated and NOT scaled + appropriately based on the content of img so that it, hopefully, bounds + the object in img. + - Returns the peak to side-lobe ratio. This is a number that measures how + confident the tracker is that the object is inside #get_position(). + Larger values indicate higher confidence. + !*/ + + template < + typename image_type + > + double update ( + const image_type& img, + const drectangle& guess + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - get_position().is_empty() == false + (i.e. you must have started tracking by calling start_track()) + ensures + - When searching for the object in img, we search in the area around the + provided guess. + - #get_position() == the new predicted location of the object in img. This + location will be a copy of guess that has been translated and scaled + appropriately based on the content of img so that it, hopefully, bounds + the object in img. + - Returns the peak to side-lobe ratio. This is a number that measures how + confident the tracker is that the object is inside #get_position(). + Larger values indicate higher confidence. + !*/ + + template < + typename image_type + > + double update_noscale ( + const image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - get_position().is_empty() == false + (i.e. you must have started tracking by calling start_track()) + ensures + - performs: return update_noscale(img, get_position()) + !*/ + template < + typename image_type + > + double update ( + const image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - get_position().is_empty() == false + (i.e. you must have started tracking by calling start_track()) + ensures + - performs: return update(img, get_position()) + !*/ + + }; +} + +#endif // DLIB_CORRELATION_TrACKER_ABSTRACT_H_ + + + diff --git a/ml/dlib/dlib/image_processing/detection_template_tools.h b/ml/dlib/dlib/image_processing/detection_template_tools.h new file mode 100644 index 000000000..b22c109fe --- /dev/null +++ b/ml/dlib/dlib/image_processing/detection_template_tools.h @@ -0,0 +1,113 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DETECTION_TEMPlATE_TOOLS_Hh_ +#define DLIB_DETECTION_TEMPlATE_TOOLS_Hh_ + +#include "detection_template_tools_abstract.h" +#include "../geometry.h" +#include "../matrix.h" +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline rectangle compute_box_dimensions ( + const double width_to_height_ratio, + const double area + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(width_to_height_ratio > 0 && area > 0, + "\t rectangle compute_box_dimensions()" + << "\n\t Invalid arguments were given to this function. " + << "\n\t width_to_height_ratio: " << width_to_height_ratio + << "\n\t area: " << area + ); + + /* + width*height == area + width/height == width_to_height_ratio + */ + using namespace std; + + const int height = (int)std::floor(std::sqrt(area/width_to_height_ratio) + 0.5); + const int width = (int)std::floor(area/height + 0.5); + + return centered_rect(0,0,width,height); + } + +// ---------------------------------------------------------------------------------------- + + inline std::vector create_single_box_detection_template ( + const rectangle& object_box + ) + { + std::vector temp; + temp.push_back(object_box); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + inline std::vector create_overlapped_2x2_detection_template ( + const rectangle& object_box + ) + { + std::vector result; + + const point c = center(object_box); + + result.push_back(rectangle() + c + object_box.tl_corner() + object_box.tr_corner()); + result.push_back(rectangle() + c + object_box.bl_corner() + object_box.br_corner()); + result.push_back(rectangle() + c + object_box.tl_corner() + object_box.bl_corner()); + result.push_back(rectangle() + c + object_box.tr_corner() + object_box.br_corner()); + + return result; + } + +// ---------------------------------------------------------------------------------------- + + inline std::vector create_grid_detection_template ( + const rectangle& object_box, + unsigned int cells_x, + unsigned int cells_y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(cells_x > 0 && cells_y > 0, + "\t std::vector create_grid_detection_template()" + << "\n\t The number of cells along a dimension can't be zero. " + << "\n\t cells_x: " << cells_x + << "\n\t cells_y: " << cells_y + ); + + std::vector result; + + const matrix x = linspace(object_box.left(), object_box.right(), cells_x+1); + const matrix y = linspace(object_box.top(), object_box.bottom(), cells_y+1); + + for (long j = 0; j+1 < y.size(); ++j) + { + for (long i = 0; i+1 < x.size(); ++i) + { + const dlib::vector tl(x(i),y(j)); + const dlib::vector br(x(i+1),y(j+1)); + result.push_back(rectangle(tl,br)); + } + } + + return result; + } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_DETECTION_TEMPlATE_TOOLS_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/detection_template_tools_abstract.h b/ml/dlib/dlib/image_processing/detection_template_tools_abstract.h new file mode 100644 index 000000000..30b0ad5b9 --- /dev/null +++ b/ml/dlib/dlib/image_processing/detection_template_tools_abstract.h @@ -0,0 +1,95 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DETECTION_TEMPlATE_TOOLS_ABSTRACT_Hh_ +#ifdef DLIB_DETECTION_TEMPlATE_TOOLS_ABSTRACT_Hh_ + +#include "../geometry.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + rectangle compute_box_dimensions ( + const double width_to_height_ratio, + const double area + ); + /*! + requires + - area > 0 + - width_to_height_ratio > 0 + ensures + - returns a rectangle with the given area and width_to_height_ratio. + - In particular, returns a rectangle R such that: + - R.area() == area (to within integer precision) + - R.width()/R.height() == width_to_height_ratio (to within integer precision) + - center(R) == point(0,0) + !*/ + +// ---------------------------------------------------------------------------------------- + + std::vector create_single_box_detection_template ( + const rectangle& object_box + ); + /*! + ensures + - returns a vector that contains only object_box. + - In particular, returns a vector V such that: + - V.size() == 1 + - V[0] == object_box + !*/ + +// ---------------------------------------------------------------------------------------- + + std::vector create_overlapped_2x2_detection_template ( + const rectangle& object_box + ); + /*! + ensures + - Divides object_box up into four overlapping regions, the + top half, bottom half, left half, and right half. These + four rectangles are returned inside a std::vector. + - In particular, returns a vector V such that: + - V.size() == 4 + - V[0] == top half of object_box + - V[1] == bottom half of object_box + - V[2] == left half of object_box + - V[3] == right half of object_box + - for all valid i: object_box.contains(V[i]) == true + !*/ + +// ---------------------------------------------------------------------------------------- + + std::vector create_grid_detection_template ( + const rectangle& object_box, + unsigned int cells_x, + unsigned int cells_y + ); + /*! + requires + - cells_x > 0 + - cells_y > 0 + ensures + - Divides object_box up into a grid and returns a vector + containing all the rectangles corresponding to elements + of the grid. Moreover, the grid will be cells_x elements + wide and cells_y elements tall. + - In particular, returns a vector V such that: + - V.size() == cells_x*cells_y + - for all valid i: + - object_box.contains(V[i]) == true + - V[i] == The rectangle corresponding to the ith grid + element. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_DETECTION_TEMPlATE_TOOLS_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/image_processing/frontal_face_detector.h b/ml/dlib/dlib/image_processing/frontal_face_detector.h new file mode 100644 index 000000000..3f4b59769 --- /dev/null +++ b/ml/dlib/dlib/image_processing/frontal_face_detector.h @@ -0,0 +1,2373 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FRONTAL_FACE_DETECTOr_Hh_ +#define DLIB_FRONTAL_FACE_DETECTOr_Hh_ + +#include "frontal_face_detector_abstract.h" +#include "../image_processing/object_detector.h" +#include "../image_processing/scan_fhog_pyramid.h" +#include +#include "../compress_stream.h" +#include "../base64.h" + +namespace dlib +{ + typedef object_detector > > frontal_face_detector; + inline const std::string get_serialized_frontal_faces(); + + inline frontal_face_detector get_frontal_face_detector() + { + std::istringstream sin(get_serialized_frontal_faces()); + frontal_face_detector detector; + deserialize(detector, sin); + return detector; + } + +// ---------------------------------------------------------------------------------------- + + /* + It is built out of 5 HOG filters. A front looking, left looking, right looking, + front looking but rotated left, and finally a front looking but rotated right one. + + Moreover, here is the training log and parameters used to generate the filters: + The front detector: + trained on mirrored set of labeled_faces_in_the_wild/frontal_faces.xml + upsampled each image by 2:1 + used pyramid_down<6> + loss per missed target: 1 + epsilon: 0.05 + padding: 0 + detection window size: 80 80 + C: 700 + nuclear norm regularizer: 9 + cell_size: 8 + num filters: 78 + num images: 4748 + Train detector (precision,recall,AP): 0.999793 0.895517 0.895368 + singular value threshold: 0.15 + + The left detector: + trained on labeled_faces_in_the_wild/left_faces.xml + upsampled each image by 2:1 + used pyramid_down<6> + loss per missed target: 2 + epsilon: 0.05 + padding: 0 + detection window size: 80 80 + C: 250 + nuclear norm regularizer: 8 + cell_size: 8 + num filters: 63 + num images: 493 + Train detector (precision,recall,AP): 0.991803 0.86019 0.859486 + singular value threshold: 0.15 + + The right detector: + trained left-right flip of labeled_faces_in_the_wild/left_faces.xml + upsampled each image by 2:1 + used pyramid_down<6> + loss per missed target: 2 + epsilon: 0.05 + padding: 0 + detection window size: 80 80 + C: 250 + nuclear norm regularizer: 8 + cell_size: 8 + num filters: 66 + num images: 493 + Train detector (precision,recall,AP): 0.991781 0.85782 0.857341 + singular value threshold: 0.19 + + The front-rotate-left detector: + trained on mirrored set of labeled_faces_in_the_wild/frontal_faces.xml + upsampled each image by 2:1 + used pyramid_down<6> + rotated left 27 degrees + loss per missed target: 1 + epsilon: 0.05 + padding: 0 + detection window size: 80 80 + C: 700 + nuclear norm regularizer: 9 + cell_size: 8 + num images: 4748 + singular value threshold: 0.12 + + The front-rotate-right detector: + trained on mirrored set of labeled_faces_in_the_wild/frontal_faces.xml + upsampled each image by 2:1 + used pyramid_down<6> + rotated right 27 degrees + loss per missed target: 1 + epsilon: 0.05 + padding: 0 + detection window size: 80 80 + C: 700 + nuclear norm regularizer: 9 + cell_size: 8 + num filters: 89 + num images: 4748 + Train detector (precision,recall,AP): 1 0.897369 0.897369 + singular value threshold: 0.15 + */ + inline const std::string get_serialized_frontal_faces() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'object_detector.dat' we want to decode and return. + sout << "AW2B5ZIvv09mlKLVYjKqbJC05yeR2KsCpPGEGOgn2QlwM92S4UT4HgQkV0V9WqYRf6xETTSVKz7Z"; + sout << "YcJ84Jc4C3+VdPgZDhV+LDt6qAt3OI4nA9zN4Y9cCIb6ivlETkN/JMmapbOAUW2mrSzDif5zjAaq"; + sout << "+NFvw/5V0Jciopw9tR6nYtV41unWGvyyfsO9CcqvDy81QIydToHh0a7UaL0jCtA2DYzkViDufxyv"; + sout << "Kpsn4xMyiU0haM1ge3UktIO48io/gSzjEKu0YYAffbD2YO1IE34tUH15Z3Z9NjkBFxTytDgrMxk8"; + sout << "i9MYq+Nl9nS421aogmec3ugExJYjLZMHs4KAk71jvG8vtJyJEA3qyLY6lvONt98gzQwGQ9+2B6de"; + sout << "ocb/DDJUza6mvudHQNJBYraR4gCWcIn9gFu2rJiRHf4IiqP4GEB3B1zKiHfJRo9jZbhxQUitAxAx"; + sout << "U/E2SuuHGZDilqK9AJ4K41RAudraxF9li/Bs4f+CK3G8Z/c97P7WLVekJL2ws+MsCdL9ObHE5ePD"; + sout << "uLLQWBy5NUbgPVM6HEnhnOiZk3rA4DYNqbABy3uemablAln9BLGkk4wrm2UcicacnzY8Aq054Ttb"; + sout << "3CCTcG4SOSPfePl/7T1M6Uy1hOesp5MpXfUR8gBKr4466dbdXCDHSahI05gra6NzxkOpOo2mOqBg"; + sout << "LYNGZUkHK4tdRyyD12N1MH+nJiMJbgk+qj54t5i3AuEr/71HTRXoTT8AEYbvc9y4f2WAlliQYXPn"; + sout << "O2Uaza3lKYrH7mFjKMNhLfvrezy9fe+1asbSlRKelnU3eY4lhD6fTVJjXqZypBfMnfmGQQJ0Q7g5"; + sout << "1Z/9GzpRyZnPSzQljtJgzVp8Gk0z3fuKiXPO9g+s4XL2cEuxBOFij0KGTy4eNitM0gcPc6xzp3tz"; + sout << "6Wv0W2h7w4h+V8Bzvyn8ag1sbEO0G1Lf2BrDVM9+pNxFoWFxYHqdoOmJPVvb8PRQqoC5bkqhplFr"; + sout << "TR5l3XsQedgwsnkadxNZQ3MbRJyo0JU0kvV1cfphLcn24MIIKqAnw3daXqbJaba+oCUep5GTuzI7"; + sout << "nad7ykHNN0iFkgYXMmXJl+F5TsS8y+izuHlXAX6wX1qRVzWJwCpM5oVVG/5eYTzg0J9C1bCcNyHL"; + sout << "2w5TJFYrD8bq3O+Y3fiO5LJ8F5/vsu2EBUMi1+eP1WfsTwd6N9jFtF5gA5sHX3zI925aDqVx9byr"; + sout << "j4X5yr68p5P6f8wSLL8jzW8i4a0yP3zXlqN6QQDY1ssfNsMf43tOTtmbBlmxviL2egs4gvadD7Gd"; + sout << "fRNowL71P3mkqRmnrnihlI01NbDl+Trzsh3EOn43PRC9nl8yo+fYVH8GqS8JGy1xOw4G479vOifI"; + sout << "9GC4BGnSDJdKgSnBwI1AJQ2TT8EZ//56lkRlgusg25TwC7uQ1zreeL6baYdgfXSggx3ULdNDGl5o"; + sout << "ftRK9LDaop6XvB6I0ITsLYvAoGP/5sHfttDj6HlQW/LlzkSPmzY/FtV6h6bE+k1gG7BANrQjwOW5"; + sout << "sfHNYadD1v4zIFdt2su3docGbGP/iDMvM+BmYIBP86zX5eIlTYwDmxXht95T6GCCjS/XuMMy12hd"; + sout << "Fdb6lm1O42ieM4KQ/2EOFy3Ij+YOIapzYA6p6Jz9dtINpCojgUHyo6xc4HTNnEKRy+YN+awhb1l2"; + sout << "FJdy2/QI3xGVNNTnWcQrsvjGZb/Z3VaZUltrIbnCeEZOeOCM0TxkBEhqFfI3qwMx8PUj+imUlTDM"; + sout << "7N+p5sxmKLliHHovOO32ajBTKUSI9IMQzf3QY6dZDts4JkMYQ1xc6lpm679s1KMVVrWuOqiAU5Vs"; + sout << "qehfnl+oMRngi0G0BnMne45CjU5RECvhg+Vkkxx0kAp38+9pY3XiO/DuyIxpOSPip2o0+9rZLF1Z"; + sout << "cAUGnG85CFEXl96wpxvqVlIULUV2+pNJxdU+q1MkCsxDeXrvfjhEAJpPE38dUb3t4blsNUZ3wJ2w"; + sout << "s6cXe0nEPWNkZlmEsXcFpw5zHe0Gd7YpXigz7Z+IVhvplpv686TJiLTpVPW2T1uJvSmMuG/FqvT5"; + sout << "JIIMg2of1ydicw5EbWrqhIUzllX3l0u00gFziPmKAioiqCxjWojd9l3Q0Q6IsaZAH+WzV2xFabbY"; + sout << "4b8SwoFvhe4qnUQLFdOSTbzeDIKP9B8bSiQwbjUBg3jYEWUrMz+eR9lpGu8603vChIEXaTxyMrO5"; + sout << "SCeaVOgPE77potDoSUV1hsoW7ZqGCFH+AGyVTohitS0iqZbIxC7+7rnVP8XfXw5YpSajF94z2TSd"; + sout << "jW0KpmuCZ88DTCPFamf5zh917qp/PzQOGTdalr+Ov+ogvrJraDnoE+ONWrdHqBm7Adgn8/wy5vzX"; + sout << "fNu1AT14eYrEmWmXvt6JDAbBYqP8Aw8b1QRZff11MblUh0IpztedWhifGy/RFJUN0/e66Mh0cKeF"; + sout << "plmK6NqchTzOQMKJVq9jxdyurcjcA0uu4dVJ1XXkAtxBim2J2m0zcwX/+HcRe9VbeNehmDbUC49o"; + sout << "ktNvrwbbB1IUV/c0MNCruV359DVINXskQTK12g2X5qprOLW+YPO6CnTFpJRsiFBoLllF1sUTjROH"; + sout << "SrHHRYp3W5t5gqfT4afBxmtTmpJEG0oG4eNfMhxEhQ7HjoVhahOM6px9Be9S+4ca/w+zII7NnUkY"; + sout << "Iaas+FW7vhOIDOiV82SpJqBjdY9eIP//XGR1DFQKI5cLKmT2/DF8tB9XcqTgmVWNMVt9Xw21CaeR"; + sout << "eYeoWvLHlm8o7ahtJCSQ0iHypTZMA16wdJ5IJD5WoYd50rUn58RBa9sTXT/t/KhxJfG5OWXl55eq"; + sout << "abYojSlluFyvFSk7Z/wu/EqFUEBD8r4OIrlJCMZl6kKy4EncmjUrb3mG6aDKxsaRBRBkRRya9t77"; + sout << "epMG60v3MRCcY9E+n9sXAOUpf+ErN7iD6FY4XFpq5R0Z+6MiLRE0af/JQ9R42quTl8CLH7609DDd"; + sout << "s8+8bKA2zjvSJhWbwGURRCW8SK9tNKuemwkt3Eutm+xMJemP2JIVFVXYxjCvDmxIIODneu1vmcSy"; + sout << "XadKkyjtYDwacddFAqGh0kLqHX9i/WoedVKC1Vuup+AYPkyZ1lPraGVqjq0nsiwp/vxm9c/+4/wS"; + sout << "hW99Q+zoAZ0IWWeYAqcXGdZqvd58gx0/fmU/Pq4FqtCdJ2qnoUDMvjZeyWE7lA/Xf7AdLcz4XHNz"; + sout << "VAidxMj7/K8p3KdK+XqED94Ey1WzpUQ2mH+10Zq/6jebtoYJlht9meMsjvjWxg4nwFIZY1QAMZPV"; + sout << "phcmEwrLA+Z/Xjo+FEq4hKD8pIriQi4xT4uAoPzOFGp/ziwBbAb/EfYsspnVxpnERKblbDsV9bFK"; + sout << "df5VSgeqg7p2auZBk/WkX/wOeXkulbiJA5lXTsgInJGoREJ+uaudFadnLD1pmjqq+VFJW6XOT++I"; + sout << "arRHT1sYJY5mhqFztTeUGH5VXZNtRGl1nWpewvmgyK6T5XLUcuZqsyZVtzkkQ0eSR4h+nBuRAQZK"; + sout << "EmqzcPrRKObVC7Xv+kMcnM+M2+zCuZoUSO/zt7OOXNt/B51oQ2DRqthPxgzUrWOvoOgZleeayImR"; + sout << "rqG2QnkA8+Kb/mhxJ+SAqOjsIJATfLc13SzVIKVumz+uX5jUiZXWfWF/e1cdS1w3Nf+dNinnGQ2v"; + sout << "Vf2SxiRlTfDTZTZXcLlT8VrCOP4UYvg0QrzBqU2myM41lZDUS2X/WOzvrNrRpEoCS6/OcbvMj5gf"; + sout << "dzoZ1oqvaL3dosQ9/QAwI7wPC2/QQTRyDIbl4EuhcX/ebyueLqlxKRLPrmq/mE2YU6aU7cf+t8PM"; + sout << "LX7J0eyNGl00TEWN0R7ui2xlfdnLfmILs+lNNthYvmUbtbncoqx0sCWgjk1Iqagp9uFlWA+6vMa6"; + sout << "nx7Qg0Jz+Qn2u4iZpyGDZKUWmHgDYhRcKfsnjnzbztyNms4tEnmwtIeLwDqFPlC4BKefz9gja+tt"; + sout << "0Om7TAtwcmWHve1ENSOQSKTLNvVwhjRLgmfK9SFUUjKeVSMp1g27Rf2WCoDyBXhRauHpqCdj7GH4"; + sout << "AmlI3EUHSjS+/1ZoT/2DnHwuN+GKVh0d8k/7sGrln94r3JuxEPvfyQFvPSRlFkWYyKPdxb+H37L7"; + sout << "DCYzgf43vxZ82JGYVXB7QBpO0VeEXiQcIgXV7uOsGKiXfFUwueL4kdNknk2hAfrFFQcpoBiQSz89"; + sout << "sNRT1tH5Ipbuf19R66cogiaAesXnm42jLjQpnkMwd3S4F6P8zDL7m13u6ahcrWvZiyDuBc/+Th2+"; + sout << "Swex+BSv2TkUorR6nwVoozQEA16/MlZB1acIFSrD0kSr22Vubdbo05svEAZ7DKIdQjDu2wTTOOvL"; + sout << "IVJOYPjdxXZytBv1jRIhUyyYBvRtaFsl57ZWAmvbFEXZXLihnrBskTqxrxNPqhg+bLxicTDlFyHI"; + sout << "UIipzL4AvofdYWolB8RvFyom18/szC67Flr1OW4axZ+k5S7249Y64eqtU1hk98joIhOdWaWBHxkL"; + sout << "nP+ooeHeEvB5hNRIA23Yxoh4zzsWUB1KKvg1XRzjt2CBQ2FPaCfHsOKf52aaj0W2FByC81rpryrm"; + sout << "Ye51T5zP5/N5j7yA+a907774PwIS3eYYyJRUCSh0ywfQ8rgkbjBdf2rKa0alzokz7Kmo2Iswnid0"; + sout << "WzpliQr9KaPwAk7hkLjprMjzdJIug3KOKVAgygXP7rkgETIfTfZRG49EdJOjlW8mlmHZsO+arTFW"; + sout << "vj0FgJCAQrrX0X9BOQ0MPu0friAGK0TNGsFs17lcHjaRNHXz3v6dY/MSR2TY82iyEkRofvNY/Xjr"; + sout << "FRB2KM6Aq2pPpIjY4EuSQS5sU9ur5oxrKo68jNzoB9iRvmKhQq5HRSYKL5ACBF85HM2oWtyVl23y"; + sout << "TqTW+jwNHfF+sc0FPS6xfwr3yvHi/OVlW046gnNLKOxO3RjGntaJeX8EVXGGpXDHLyle0UaZE0iG"; + sout << "1xeLZjyKq5wRJ/Q1MPry/JJbfCXSIHeO2Uznqn7O5rcs5v3Z1PdlF7BfPUhnP7+Wcqryfi4xJ8rS"; + sout << "BzyJkibOCzegvXnKTTw7q5/lrgh7LxfrY/4G6/Js8ibrUU9NGqBkOHUmxa9P7UPK43pz/bS7SWtl"; + sout << "yA/3hBa0bv6hN0OeXVaBxtr8sMfS7FcvR3wtvmtKn4BlIYer3LMSvigPCK3K5seTPH3cx0J2uGzf"; + sout << "SlPZus5idN8MnFCEiBUbs4W1M/BSw9EYA9rJyhDTyOYqKr6s1kagBUoVCXVlEPVgrJoppc4vLghu"; + sout << "NMgUpcakhT8SAulssCjPb1UWPF92XPpn8/byK8dJoSFe1lfFb5Yog5YZMjgoKKbokk0n3eMlrbm2"; + sout << "AGwIh0acdOXRR+lpeJQ240N/Waw3e+FhAI+AYfOkIXodtQcod08+F8uHCAAcd9dZvYyxZxNKjbCc"; + sout << "aYTFUYPN53OZEwEyCIFWwPf0QhdhlpyAGCj9gqVU4N9b5FJYX2ZqVAl5JF4nl9yDWrJ3zmhwL4r1"; + sout << "P8Pdv02ysNeZu76Y60+ffPXCqmjHjllu082gde9BXIEWdS1sd5qaH0qb8KRpV8WAYaM7/ccGTHQ+"; + sout << "H+0C5o2904WS3MG8rR6LI6EqO2fcBnJzZ5BJX2bHv4kNHhQiW2tZjBlwKjuMH8Ddayd1BVqzjeuH"; + sout << "5dfcL8xV4su36eRT/Vmanq/NZ80+KXsXZO1k88RIfQwwZdt5XribJfUSwzKGsKQrhu+8iUCjGP8l"; + sout << "ScrIRdj3gjy2brM+zBr9z9pvFZR5NLjYN1Ko2BptMbEDxdjnYkYWix8BF1P+/PtSEJeGATIyl2al"; + sout << "rAlEHX3ysdDjUic86ZNUx9c6N59ZcQkIr7IwFl6kc5sbuthroXmAnbW0A2UIO/LN6KFbbE53Up4Q"; + sout << "KwoMeMHxlgEwundK+LV5WZ136K5JoA6SpvxzuKhCckg0Ev4+KtyA+1wlna6AHOQaj24BzblSd9k4"; + sout << "2lWsVOwAOtGxFIRIxpou7S4yqPrvS93KPtVkrHDBqIveGcwoGfyw2ZSX+5o5SIZ5PUG3mFM/sNWw"; + sout << "twketaHdV/ndITa3aJyGpqChs3hcwMOgODnpC+vjtY1D8zdp3pn4MBgb33jxc5kOCpDktiKGyaQN"; + sout << "sQ7oaOy4aKmr4TFfWbrH3qeR9gz0utGL/iVHcgSlfl8rw4BFncc8HIB0SGJJhYE+lfEYpsP8H+1p"; + sout << "pfG0yzIA752vcaOIWIGt+C/EvuXl5PP8qyE0aBe637yQd1aMyRhf46rsAIlhzwZ28wPYZ9KCaC41"; + sout << "ap2+7/EJMw3HramAAo5OVqA6M5cV2V+MlGifSoVgTN+5TaY9EnqexQy2Gqw+9484Tv7QNVaEtwtY"; + sout << "/O+aQ8nzc6H+clWCWJkWDvoqIrIqP4jFUaJ/FnqlPEb2GkPoNluJV92HqQj6fD2Iz0TQKCVQkVWq"; + sout << "D/QuFVq8c+EC8Bz2j1cI0D30iwROmneb6XHTYVwn4yHkZ6LAoOz28fjT6dwJFdYo+Ci4Hhyl62tW"; + sout << "P0al3X19i/IjH4Xi+ZH+lISFmA0oEJo4AG/oAklXtGtRtIwfKGIuIzqqEztmX9tY+INu7PtgH/FP"; + sout << "z7d2f3CBZTZY4qTPMEPQ/th8jnjHrROZIM7Cej4v+zYms7NPlJ7x/k+eX5ISG7xEbWr8j+kr+R70"; + sout << "bjGaz/rED73YxTMBmhQSKMDUjNaW/qclrQuvaNwUgM/VCtnY7NANztFMXhCa2hGjZaG/bp8Yc9IN"; + sout << "T20nhrbTX+KPkcEmQjsHwyK8hT9XN6J+TD4iwdnb4A/KQI9JwaqpYPp0S1d99j0iqXlirvdPcotu"; + sout << "AsUmNf1YOlK1I5KxaFA42emXXmg1vr7USuKiX62IslSjknRY0+bPxOcn09P0VK+HTTdLIZ+8p9+k"; + sout << "fKsgY8ajl6qZ/LP5qbZ2KgHJHJNArwRSxNn4CR5ish97R+1A3DglEaWJ4vVuu4oaFIHc9eSgRMdJ"; + sout << "IPJ52p+8SKIpjM3Tnig/Gw35R+sPcuZlpauFplYb3vcIoY7vN/f6+RKxPtWnuOfBh1iPJxJJfz+H"; + sout << "MDVZihR471I+DLXgGrZ0fgMQZqVFelhF5eszKMOxB81TbRuPqUmneRijtWvR8QAySqzV3o+OoM/n"; + sout << "fpoLxmcVQm3LanGF1VfCbk7X9dhocgWTpk9XjDzVjIjPceJ2IPuFjHcrNtu6L/fe/6sMqkKWNRQH"; + sout << "8GGbrJU0/kqXIeYch5gXjiKFTIU/QIRt0e4YKlNV0Rpqkh8vY/X3OL8xbNCBd0bM4nCXMp5Ytwyw"; + sout << "DEyjzBl4SvxgGSqG6ehAzY1LrZ3bHU0Bn/Q7vD6RIEr/WcxUvdr8oy1JuIey4PllgfcCaDdW1+wG"; + sout << "YCz/81Acw7xOiG//gLZ+tApj+tGpMP3Z/vnC7bZmXAmXWCfZeWDwIcxX/V5Sco8G21PpzMYPM7k0"; + sout << "1MEkR1PgNhpKv5he6criGZ6D/xVAfVJbxc7blkovBkLh109MFBCAGiA8zk3MAShzI7cynZVbyWGw"; + sout << "+x5Bvl8/6xSUyG0MLFANqDilWFIEBpT8h6G6StX3WXoEfqJrO2pYMjQdqOg08AvXKWJg/xj4U2Mw"; + sout << "m8+nK+zX8aXHC333WcQ+1eG918/0TEDoQAXep1atGq3wir0iBvurJHbOXffjGQalMd3AeFCLWaFn"; + sout << "7tYSTWYcPnWwWuA47FxTSOPezm1PrihBIC7CyVjGHAGvtBdh2EjCVptJHYgft9Ivp0YpPFaGtT1c"; + sout << "IsaiWl+dF+Yg0K4FVIpNqRq/g1EEpGni0mrTmlTKeeSiKzAXdGnjOZ/9woea73BFZkY0kAqMlrn4"; + sout << "6AOIuXh9Af1UJe0OxzhcYKFjFuzj7Imjv0SgNaah5XePYFfLyqNUCctmTlFna9nZWZ4/Q/N0tqN1"; + sout << "QJwMtZOdFdKoSwFcDrSuMBc2kKNCEgnXAB9azTyR6Frs6RDNbCOdmMKEIF0Ra6v4fqO/rzc+m8nM"; + sout << "2GAyE9yBNQq1THcQSqlataFHDe9KkmlQ41F9hKifZEPJ2eMe4WbpMdXmjT0nNmxif9OiPMKR28EQ"; + sout << "pcqtuJxTE2oQArxmoOD6uUSUpm+Xc190raj1/JA7kfFQPkONEkNn9fYRh7J9VvPk58RIkyDL3RfG"; + sout << "SjlzXsvz0d2uU14U8ppyPSOUEgcvxUu9Zk/TcwZkWvQeJTPd/i7jbUyAHTPXy0secfKXWSoF4T1S"; + sout << "AuuRuErtEIXmJm6bd3v6ozR+Vc494q5Nu80EGIEy+09XWaDi8E0ChYGPUn15jWmkw9aZ2SUGju+0"; + sout << "OS7eaGTSBcbS7l3IaP+053oZvh6NN+iYo5Lb0rs+bog58fqpXLFLeaJFnHUmZipr4oX2EfpI3FuE"; + sout << "1I7xjgdZiWMq57u9UId2PuNahTVN62Du790tZhGfoAACZxKx9xxi8nRwxz1Rh8uFosXdHJridfzF"; + sout << "gzZhDTmxJjYCq7tg6769BcDtHxT2G+JOh2hMFV+aieGZEkBEfj6EWhuot2jR+VVjpLUhys154Fj1"; + sout << "NLN0d0zMnJDThQlNGigIaHgVdQ+l/lNtN9ovAuVJRib/fYnSDRBpOQpOU5NuwyHjeHnYg20iKuBT"; + sout << "ZWphFPD7M+zYlrVVH9Lg0AsaY5Yt0U7g+TXLuT/bi2tUz3rrrk/5bY7iLkGbEmOFZmxzXXqWdm8h"; + sout << "ENOKVj/yrgSa/l93WqyOESSZk7hLMvP+OVkSj8qAKKBQ1+XyqVLODZZae0volbIcZe3HAAIjdYTm"; + sout << "+JQfIGAWgkqcHwgv3WJiGPhq1WOVi4FSq2Dgxi5/J6cRg1Smsr9aCx0uNC2x362lI8Jd9yKn8m+m"; + sout << "te+3Zx6sx0NCnYKxaWcH3V7BfF0hp0WQJ3vQbPG20PD/ACHvMEgmo2dDFit6m4yfWAQxHzQZE/3N"; + sout << "E5TLT4EMnZxi00F6sV0G25nElE2t9CrGkLNxTUbK2sGKx+ybsveIWpoNtQty7hY8NF2KIICOd8QJ"; + sout << "FsAKxGHbydI+9NV/8KyW2UID4JpoNJOQkh4B8pp/1bkBRPsikKLyowC6RWuWmVBm/DCPSIwkiV2A"; + sout << "jNHVRmSoDO+U3eTMxbamjBV/H+xWgrBzBu+4aaFGH0MbKNtXG5COeCVMCtA5v9pmR65GLD/DYWcM"; + sout << "JltMV+H82nUN8qDVTMpCSzrlkiv4Gmvh6b9HkxZC03g+IrBKAkXWkhIl3iYkIjLYNudFSUddDb6g"; + sout << "/wHCk0lGJlbYim9VV0uRYJITZenRrzsMcb6g6Cm22cB9awV0qpixCGVW+jms3MfgcstzqdN36KPw"; + sout << "C0IDdKjN2Bu9aNqHqWafK8Vl+oTYVU6foPJSOmHD3MhFHhuZk0oPtptRs/0aSZKH3FI+jz6KyTM0"; + sout << "E9UDooIsxYAo7og8Ka1QCVel8cH4mmTBWTGLNNxuVwvHYgQc+j+QgKJ8DX3XzEHJQVL2fxCmm3i/"; + sout << "tjJTltGK8o3S66OO3dN5g1KaYyxCDkmKjsGpyqAKGhdrzQzwLru4oof/b4cM6E/3aqGWH9pI27G2"; + sout << "8jNYhu6r5LhYMpczurY9gssS94+RdUn4UFMt2zZSlpFsCY9E0NNaGwQ5sX0pcyk9r/FKWAWxT+e4"; + sout << "b/buzfSIVsHWrzytkKOYCHMylaPd+juDOWX/Y1x5IBmR/VnpsIWbuYFjRlK7bvNoVcwitIZyI1Ku"; + sout << "vmkH9u5YzndkbH9fj8FroFgMdZumIeRSFz2448yoIh/1+2wyEUXUvof32q5kktEutky9XtKCTIen"; + sout << "LlWO9/7k0Kcz2Cp1S8bugmULKSLHEWMTtScZhEOl/o3jyMjLpbHhSfY5IHwZXVp6MO/bxpk4F4ur"; + sout << "C2eAlsHUW0484VZFIm/GtgNRKq5H4MTRSmlzHxh0o5KnK87ZZNGKv2sGFoxhOKT8g9s8uz2ZfkI8"; + sout << "HS0VWQ5y0Y9dY00ShJj9FShAuForC1EW8TBcgW2wjk7uN4CjXgupadGHC4hMFxVjJJ4tPj53PX+w"; + sout << "KKTety47QKF0aeNAXeiNkzo0e/H8XYKYvyRKPpUhWbj5rzdkSev920dKjpq731kGRLUP8kljqmx4"; + sout << "j/1ukvHqJrarb/U0LWECXe1mHUjehedJNCuDsXmlX3OIT4557z3W9vMbzKyu+0R+LN2YtYUFTdXA"; + sout << "z442W3oqy7cJRIMfioDLTO4ry26sNo4uyq83j3iFx2iY4Wc41ZUGg9cwh5TVKg8XEh5US5xlsqVO"; + sout << "kDR3XfYXA3GwuKklaNN6vImd+oP4g2ZYSl51f7tj5hd9xpTSRwIy3RJJ5VoTz+36jpT4Y3fnlppg"; + sout << "GqBmWhJrY5UemTIoZbJ5X12NjQjW2HiKsuiCpLS9Wm0IXYWcRSYfiWYBLP+QZFyRA2VqVpwmY2X1"; + sout << "EafYVxAjG2au8TnbfK+PLccuRg+kYNExJfD/hLUMyVg4wkLxP95L8CB85+g/1VomueeKJFnlrnkO"; + sout << "ezCBls31aI/r4fMbdISFALkRwPav4rVwi8M67zuhxx/K97+5I4ONkaSU8/DI4SpqjfEIzl/y07Rg"; + sout << "VUou00laGIhidjtfwENl18fyXGmjmLI/Mn+/H8gU1mW4Z0stSN/NkPYZjTx1AnvjG/LgaY1750yS"; + sout << "4dk+ygLr07oWPhGB3BhIElS7VDxZnnPo2MFIPXTqWHqZ1/lNq8DE2EqgHgpFQGmp2MZVi060DA0Y"; + sout << "En5g8zk1NXq0irzIv/hXYLbDEnL4ieulF+BlWN1oeERYelY8VkqgMtqGwBlwiO/qN488MVobHHAk"; + sout << "VARDBpkSyX2bsF0KS4BCwybuQtNPVCaozYKWd8Q0RSNvsK72afBC+snd/y2KrFhcE4mE9ZhAwV7R"; + sout << "LRR4IBmgNkDPDi7YXFEVZ4No5G5dJYL3yfsZy4b0kBEplbOoIjYxwz2dXYtX5Wc3hcKzRblKZG2i"; + sout << "GkmTHabzN4BTwbGBxmCTbbyecAIO6MFJGlnxW6tQfdiQbcBbt1utUTpjVhZPVkGolN4VgU/qFPCj"; + sout << "UyO9bO+RUapMvtwhI9+1KPcGiTbQsAX/V9+dSCjQIgD5sLRjfVQcKmK6/R0VSppo3ab0+XHDv55p"; + sout << "FOPkhAKiKvI4Wl1JcKcsx8mwxCoTSchCxp5JhNn+WYBoINpTlmdRKI2hfXvfY+YXUzbATuTLKIZX"; + sout << "IsHeRrzLNmntT4lzgHtEArSwEcYDRXLKBd+L13FZBV8iMX3ON8vUBMLU8QKoSDXatEI//h8RcI2R"; + sout << "pOba7GU2f5TWFy5lB74tBKpcllmmid9w6jE2T3yhxU0E5GFWxWv64oSJDCfyD5GRfY7L2dOVBVwA"; + sout << "H1DuC3NeBQfgaY+DPYFyC2gR6vEihtW5biK4HZoQkEHaBD8nREBdMlh8DcGuXwsTwEH1co2xFaNz"; + sout << "53QpwalF61MYqPbQuFXZBvFEruliv3cYHUgIqtFo902pwFOK447zzj81l+5XzdVZHsA6dCGAjSqW"; + sout << "J//PGJo3M48ERSqeURrEwNN6lD6nOqs9XAkQyGp0xLcv1/EVyzMoYTWazSaTkHbocIh58BOJVDya"; + sout << "rjRytgcV9cAKzYvY9O3NBPvWMBSybUG0weBGTpWXNlydqxlAc7PBND1DfOL4XA6aDHpra2rRpJ/t"; + sout << "xQJvaFWVNRYBOpR34GsrLpczGcf/z5hhR1gpE5y9//b26xf7V66n3kn0w2qGADZz4eL+Y7Wl1rIJ"; + sout << "QXs4U95d6lfp26TVY7MsmQRf1GaO4keltA6LW8XkS9zXro/Ydl49AWToXe7suuJk6OGzaUqJImLB"; + sout << "fI1w0xXDoVdNfY1SgepZyQxrW7PqtQUlLTHccsTDUJqVdu9ZUMnCVlo+6fQNz5lS7wvRbv5iqgkz"; + sout << "DMyynFxFQvzk2L3sZUt1+xTw9r2d7urJ9VmGpj0arjR2+qb+2mfFqH0HaldqN+DGEiibZ7w9PmCT"; + sout << "MNDZvjC0zm2N87yPuRBSbwn4JoAD979lNhFSpExOt7v2zucluinLIqwESRQjWnyun+xTZbu1MAka"; + sout << "JAut97DUpQb5ALQ7TLqKOfk4vSSP07cVRJPSH6K3XnR+ZFX9W+7kb1mYRhJ60r3uKUYAoYJIdGqL"; + sout << "jgbNfvqdTZqUOVq/Sfc2/B2T3kY0W6facFDev+/YnpwWe95pYSfUbewbM35nEZGJ0HVSRHnBTWIO"; + sout << "n7C6Xeg9e29pfohDW3jy7vPL9HU7+GdvhZYMUfNeQTe0zYKuY0+/UtMIuMFzDJ1J9tBy/cLPuI4K"; + sout << "oyPNxmTBGCcf33xcff6ZvAePZPBFgjmbV8izFio89if2qmyhGPDi6LH2NxYGpjC0f+aPj3j3H7Ua"; + sout << "yX5PEPGDl+3l5jZjuY+sLwwqgrUV0skzdcjAyEbLPClOkj2BG5dELl4VcD8ESsOwyk7Yyb2mt44n"; + sout << "GKgKNGm2+EwSNyvECcoEksksg7gaE6ZNXazytt+kRITYczq/v57+U7/tSjyTRL5qPLxWX5OwESUw"; + sout << "y2zx0ulSrfH44+Xxr7ZnI82X5IgDehZJQvNPBmtPTB6JvDuUhJMd+hQF1lboLwEHAfZKpcN4v7FB"; + sout << "GEZi7Sp/iWCZQwtALzUDY4YKGUuS7uOyjpHcQp+hxIlbhXY9byIyhvvVy361/nbVwOnEHo5BaKYE"; + sout << "csaN1xi8WvBN108lpddsUUDRgBW7oKXoiDI06pfubDTDZHSJDABQlnor5sTsIQBMs35yYGuq0lMN"; + sout << "lDJ5h6Nb8r6h2HhenA8tSBmMXoq3j4IAq0jUDpeR9TXX4pBbGfN1HgWpbIrAKmSh9L6Pxa/tB97o"; + sout << "D5seIFPmORWWemSfAMoAs28YqCise3933/HPnCk83PcWH/4S7+KITJx0tIgF8ssoS36XP4J9L/1Y"; + sout << "ym65j7ffiEEDH2rDgip/UQ7utwOHAIW6rOjkComtHS0F+m+yOcdoIwecU9J4rPGLgpq7KW7oSdD8"; + sout << "5/ckw1VAwgUvvo4YT9forakPHIPB9BgroHxUjvbvABXJEkYGw7xUwn51NSj+7LieWTsE+IVs2kWx"; + sout << "TAU1B6hsUPr3f4Z2g74JI1AVGc9KSuJhTtognYLM0amQd7HkR9Y4gmTRYYrSbE1yCWj/gYd5Sn+W"; + sout << "/NUvdGmfqjcmItvBAkDr7lf79aevcKySCPfP5ZzDBfM6aJw/T6EC3KwBpY4obv/Zgx9dLKZhA9Uh"; + sout << "jCKQEEpTnfOOGI7D92zvtySthJNjrGzN8ZVdJHyzXMSYWgHEElfM3bB3LdAe54vVG/XyHag1EMMH"; + sout << "DnH9JOUvMeXHOLRDnkI0RNlGg21wNjl3HTxSiXkIwpANPsBpcoow32KWYqrygnB9iF30IdwfVTbz"; + sout << "OhTNM/4qyrwjdzxSTX4IeMQrviMB+gi12mTcB4G1ggqXuz1q6uFqfxrlmMx+gDAuoEbR0vFF/bXg"; + sout << "M+8PXQ/oKyGYtptl5gM50TsI2CxNaBAU7SUTG2zH9pDkoko6VO3mXfRblwFH4vjJ3XETsr2uAJlS"; + sout << "7wOiJOWfMj9dKFMH6efJkuZPegH2WtkRjomYXO3l/UVkWwW2KuLJgJhAgqJcI3ODJd/kVYoR+THn"; + sout << "IiPnJvoBXfLTKJ6r2lbjeImNg/CwzRVhDVVPJi401mloyrMU6JQ4DtwoqeiS6qAcLDlJMcu2A0bG"; + sout << "+F1isgRc72oPo86rpVxR7oJDEVRSsQqkOhv8O/lVaziMsLCBuqXUGfuohPNE/+mCdSrZZ5LzkKSe"; + sout << "iTlYATHr66c+jnkOETWaEPjUpwQ6ABfit3mbttnnONmDSnenmUnHUf20QyUonJGpFMsWB+DSPFs6"; + sout << "7zI9eOAhqKQh///VoPYY37AdsdacucmhBJY6lmIHHiDyT804IvuNqWLSt5/Cko8t8thgShjeSM8J"; + sout << "9med4U5W4XJ1fEoialm7jil6e/fr23OJJf3VJp8JEaibvAk+rbAbc5VIzwaUG1duo1O4783HLJu5"; + sout << "4we8QCONtekxRwXi5R2gUi//qD1kJKbKxnkYOXaKNWkUbEXPOSy+evfT6jfbYcdk7VvmfA0qioZc"; + sout << "B9nVWevTpoC/1sE/aSX7dqAWjOd5WH+KsReDpJAtB+uMhu4iyzaHV/gPuCAUdm16nVmHgcBP65Ix"; + sout << "Rz93awPYE4aI2yoWBvJnIN/GgUJPBW5rHFcTncTV1LSUMePStawPL7BZaY/V/HRBUOnQ3V+p3xwE"; + sout << "QqFY44ilI49X0t0OR04upM1hjEnx1lVyd/2bSw03lCDr+y6oNwi/hrk389KdFRtlxT0/rItCg+gX"; + sout << "JVV3Q5LyY4WEE5on7coii0m/ZyXMxNT/RitPnLWc2aPtEKhbVOWpVuQGw00eYcKTs+AN1SuqcsrA"; + sout << "7mVcPVIYhQ9/rDYzAzdG3HTcuiFrDWkGIOQp++BZEYitA7zEexC0xZPQZsqcKoH8RieRidrtPNXS"; + sout << "ihFNGNNuxYQWVchClJMvEBrl1ankneT/fJOLTob7xAG+o/n1zdtSeTUtXPN4O4ym3GiubaONLzLL"; + sout << "Z1TzkBa+H/t9NkI+Vp3kVRswJr3cEu6K607OPm6yGAxw/BwpQRBli6uf6SMAdNcAPMwZr9xV7Est"; + sout << "pLz8ibkfNdfMj6fMY9WKJ5CJhajqg0WPFTlnSnaRs5ERtBLK9r8Ip/XS2VUT9/rqeFivpq8OsInl"; + sout << "yV5iKygaW5OyZOtBbI4SrhN30LZZaoP4D4fjXqc5/EzHyzGCYCfgjKrytefR2F6CqUUdBOn0nVtH"; + sout << "Z4xjlb6IBw80vupy3KEjpjsl8eAiYM9JsV9aw4Fd2hjCdeg6yPCNN56pm59Yamga80+31oINYjri"; + sout << "3OcSjN6tVNwdf1Jr0s6Y1+0VgrXT+AbHuMkdPQbhTgQr/AAqHzUr+5FhrIZ7xM2vF3PrqUDakSkQ"; + sout << "P8xIrxYawDr6fXVDWeVPOlVhUSihPBMHjc187YnXDd8Hun9Lww0wUuzOPc9P3Wb8wBTFY2HiXNL1"; + sout << "ciWhFec1G2O1lNNBgYSeclowdwMNrC5z9lk0jhLKLrX2Ji+B5ypECjWGE7ZMNSuETIucCTh4wl/Z"; + sout << "fLIB5I8Lx2D0asU17GjJQk1UdQa9uWdNgpG07osHpTH5FoWxZcQSBl7cvfkqXltox1ItArv9yuKo"; + sout << "3gDp6AgZTFOqYhSdagGzYHdzB6KkEpIUJJlvMZsRzlSNIUtHJ4muh5SbP/X0AAGWnNjNZj95Yf4L"; + sout << "IS5+ZQRnfrzIl7Nvb4KkxbQicPMrtXCcZkWJ0zN+xlNOX4Ph69XZEpmkzj5OBi7H59Kcw6ZB8yEc"; + sout << "3SIw3oNS+6XAIMU1TvhPexpfDTyQNBbIgyycOPYaeA7eSgg6yz/4z1RfNMVZEj8PgPri6IzZc7h7"; + sout << "AzIGqSzGJWAiWCtBFSmDQ3KbDXDMAaG6e8g+zzdm5dnujiAJ+s3PneWlapo5dIvjh4MaL3w6iy6w"; + sout << "T62tjz17F9eEnJD8IM36+Wn13OSPk0iFfPKZfBDZPhEAGRYG7tzc/HKJ/d4m0hEg2GTY3M6pEjZj"; + sout << "nQIcccSE/e76TSkeBNrZGp7lplsixpLjBdRFSFQ59D4juFAU/8tf+MmgtxWd2VVPU/mtkYU9QXzq"; + sout << "JeDq/+MOHtQoMdFuxJvlEj0EE6Aa4E4Ya31LBEoSbp7ln5dDcP/R1LvqaHZr7+XU73GMzpgMec3D"; + sout << "7UY6Rnip+AXpYOaWcfz6XX6y6lLA3kdIsNptHnc+f85kigFZ1RsCXZxugGLjxcFWXVieSKv6PwVG"; + sout << "oQdmyR7KlT+tdjXvfGTP9AwdU3S4QGHi77l1FSZaebpelVkrMgWhcug3s1Ed0/1c55yvaZi/ymXU"; + sout << "OYEtOmPmMAB5wcOagBZsTBP4/6w8Zrfy+27SKh0W1vD/rHQMP083Xsv7HCqWppSVZMWOGJqyUkUV"; + sout << "rBnbjjEmLTyHXr0e5DP7TRCpx4MWMIFUI+fkRdrWSXxNJqPlST0J4BzbQl3XSppj3iURoccQBDpN"; + sout << "VoLZA+61XgLAwY4a+1HGcvVRLzlEhHCmWEaLIeWMfIpb6L7U/LQyfG4nDiMqt5HtmMbLHhSf1Iqy"; + sout << "swpB3hHI+UjL+bFOD5XkCRclzVPucjimVLtsH1QXWKYIcrC23dh5/tfoR4SUxzYn30LEbmcMNctf"; + sout << "ETO3ebcGC9+hvFGH1CwVowGxDhQfxf8tjORR4Vv+L0Xfp7yr0Li/Q8wPnrIQQkbasy2O7PwodnVx"; + sout << "W3iohD3htvNaL90vc5WOf7o0YUKcAWfM5ryT2OJoFFYCSJaI4qLPb/b9OLg2WUL0jV7lYLrK1mBE"; + sout << "JJhgJRQr1OiI4TnIIWrUQsdvkrRVYiUWzVS6RsE9dLpw0ThoXtu3Bi1gWJInLknUCJ/yUPkNqwQS"; + sout << "hKult4TehcOZBfHVc+BOtTdcLNEzQVWy+HPWssvhSNIYtoWp839hdGzzLoFxsIGik40aHU78d+cq"; + sout << "ksDCHIFvnbgBvPkpLxrmvXrKofATp+ywoYFeV0g808/Pl4kX+27zXT6ggZWAO/I9anenVXNcgtvx"; + sout << "ICdGbYeHkfXRr1/IVvTgtN2kaS4sSRbf9LPij72aJoftCIa5EknOgNSOQuqpEDidYdaXZXl+6tg9"; + sout << "lz1qKU3i0ivIjIcaGmxLEi1pBM1LwgEYWcovusXNcv5+pm6SXUgVzQkHu0Iz5MJEdrgsSsc4NN+2"; + sout << "swZHdmviqcDDIk7fuOSwmj4IdjAWUq5lmgYZbpZLZ0pPsmTjqX5uaBFXqmlpKVj/vEIKiFOCGtZu"; + sout << "uEek9ZEpH8aTYqjf+tGKsNNANsDNOFVwLsdD5edStQS0c3U9f2Q1KGKXw16BM7pArxVx6KxFjI4D"; + sout << "LQxcYx18Gm6V4sCn2J0ahj7IO389LWJQJBcfJNyNSFhfaRbVha6itGi8UaBr7Q5LvkCvV01WUcJu"; + sout << "AsuyKjRBScPvjzypYoCxSZp3ln/sXB58RGCVZ6c7UXeZnGs2ABzXEIRYIJyrsNNVky0aGKSHFRem"; + sout << "r2gsZ/RYPQBVw+xt8kGwAkM2km4waF7nHbkN5SYq3VIedvw1gU3UIkbpno8zJeJcwrnoVT9n686i"; + sout << "aE/9ltlEfn/OW7XUGFK4jXB9GBJ455E/9iUejULkvx7iqfRsDnhbI7UsVDn7Q4snN82f65MGUtU0"; + sout << "w9UaxqWKQUZqvP6rX/4u+2IfBKWAksqUG/Rl3O4krkxQRuuOS2KA+u512w/JhgdR/9O0BNG1YuBd"; + sout << "C14QpgMmqPdGEfNrXUZN7uSWJSdBiwqwh+yPFVqoclcjencYDg3ZzKNrfUMun9eKKRBJ2UqPNmJ0"; + sout << "zM06doKb85m39v5GBACWExd6vWsrP8JxcHeWtDkH3Bt6qhZ4YjU38qiTD1avmK12ti8n3lpzOpNB"; + sout << "ObN1g9F9JLpGQs4RrVt5xH7Xv2LsAC3UxcKZMW8nr+QVc9BykqYIU96dVj3kffJlvfM1fTyAtTN0"; + sout << "4016YMvWy9OzdbSeaW9c3ua91Eq0w9Ve7rR4D7rUm0DGwaPPNaCAQP41DDP23U2RkaV2yhcS2ntN"; + sout << "95eArvlyyr0JKWkMochvrYl2iHN/4cv15vog9n/9pUP15ttJdZdqiE+qwBaGA+B78y7kCf4X8dAE"; + sout << "ab3I0Gbc7FLCNEsJGcQlFTSmB+2ccRQtRh74pirUXd5BPNNQUEkZWXyD4tVcDCkWRqHbkxlneyqS"; + sout << "ziQQlbRyiM3MmvSmsUOWYYlq6iKu3rzomWTRukwdFP/LbClcTaMW611t5rXM2Jrl2JXWer6HsRk4"; + sout << "Z65qxcLwDmjokz84nvJ1zpNeDcCq19jnxNqEOaDIpDVQxtM6RY5L96Sn395ecZHJFpL/E1TEHSbk"; + sout << "V9ZnTcZwiga4d4FaVDa/L26ckv1+93o7KzMuxgopuJqf9GJ2c9inY+Y0m+71i6MNharLdfuGkLwr"; + sout << "/iEeyeu8K8QcwNyDDd8QfStXRGgIGhNR573Q6ARI6a9Uft8y0hUYrJSTcaryLZRnV0xYAGAq11p0"; + sout << "YOLr8U2iOncUfz0+Cfc3cu7nytOEpr+jDM70ojkQxjU7DmdCCYqdgik4v093Rv9hTdPEBFtzNNqh"; + sout << "hWlaCaul6E6Pe+PdjMStduOGI8+eNFpOJ7/K7IlXuQLMLcggaXELqeqzUTaXGnMigQeiNsXUhXJr"; + sout << "7g9PJYDPXXLNYIR1TuLPXCc6L5KuF+fjWQ9CwUxT6F0xCBMVMUVHQSkNoqngCeaHsQbIjpkFBbxH"; + sout << "BXjl24RA5TKttRF9mUgNUQK4VV9LU93FJFqPAegUWA8A5AbaRkJbylwWT26qCwRcNcMm+wVFjxdI"; + sout << "BYvJqx7TCrvo5ytBIlgRVx1HLjUcyITmeQ9CCl1j/Tfb4RwslDCgCeW0LHocZde6lCdknTwlOre8"; + sout << "FHdSxxvQAImwiZKBSxPYqsRLXEGMtcFxkpAjbUZjITfHP99qD+h18ywpMp1xQz1FAa6QBaDjFQUy"; + sout << "q6DWqkCYI2cOpovwq+eU9y0HKT/CxiAclgINEMJ26zRPgJDBK2vYmft5gfB0CHj8zUzBuYCK8n+u"; + sout << "6RiFrok3YKqfszQbsWj/M1nBKReS75d282S02qdrnm+OwlbRFmkUX5VNUsI48fffdLUzQAVOD42L"; + sout << "O04+nlesMESB1w0GezPWmsG+eNiUghaLxLnptRrcATVWseeWuqcdcDG9ct3BkyPEjaNSZGbvqN0q"; + sout << "8S0H1IjsnmKrctlK+1ELXGBaND79Uq4HC8NtNseb4gEQFBTj2rI85gXxRTqPwtYvB26mpBPrWmgz"; + sout << "JfLIOawOFX/GEe3W3NelU+CoBpxvGv2wmgqW2quks6TmilBZaQX1ewR61jVhKaI5oG96e4uUksjW"; + sout << "I/cXdAP2GFl20jLWia2m76GGbTimCyffGDV9v3uzu+lpoZ87JgudGUKn/jdJK1uqTad/YQj5t2N2"; + sout << "Y0PnVN59l6Do19E0YwY5vx5vDXs6q0SiWoy/zuY2hqcZ7paY1iJhIaanMAJjFK+3FelY/IH0Xo33"; + sout << "Uhv+4k5pPRuL3AU0nA+rg4JCT9YS3OwOpcwv4kNyQL1xRg2DnlcryXMHHCZPYeVApEruSVmO8nTE"; + sout << "g9cbFMz5nmDHKWdF+KH1Zs9jMTI7tfaOGS5qwX+gBwcMGUPeHf8OMiJ6y0zXo26vHzLJ61wPWLrP"; + sout << "juyapqVcV/YgQl9Ok54s8YGCv9ZB4rxnxjiABMls7ZZquK5kNJ2ShQWW5F9ibtec8EcjMuK4i20B"; + sout << "tTgq8ymele0eOIWVE5AZZ3yD64qqbY7kuCLJG1LxvQRe2zr3FYVYqWhfyRvywRIRFcequso1Bc9b"; + sout << "HKoPDpwJDaTrQn7uhEsb02WZKg7Q4XKZVRXxDGHHcBqwa4fnNM3IPScvskFpmyrZAvR1QrjNLTii"; + sout << "NdjPVn0Klyr0sQlppC50bu1eCy/Wt14QKDUA8OFYOmA5mEjVGtFktwxy9wLssgpAD8LoyPQSxuzD"; + sout << "cgiB2HqNtOwNGPvlE4/ZDfT/N/j8s3lk0q0cZmrAUCBXsBDiAHNbkm3WeDEBDY9+Un6fFF1U5chM"; + sout << "tKKpyvjeZd4bjQXK6zzZNavXSJzOvqVb3OKlH6OTvP11rgv3pMdHYo9T28C0onwMHN53QGPsWbzO"; + sout << "57SmomSDGYs+ERJRpGEVUXgj3D+Q4O4v/fR+XMAtiSOVmz1c3c2y7Ys9Pq9pFX3UF4q6DLIWfBmE"; + sout << "6omLA6O/y/Y6p++EZleosnni/RH3hMH8TvZRYFW4EojNCm7Ss+eyuktlVXQhcPUOuxQ/lK1SUx4k"; + sout << "vCBOk1YCMcl67xok/WdgM/lWJvovLTLqylpQlhsHIM3I8ccuOcJ++lhPABcaXmInXnPMEV9K20kk"; + sout << "d08Q72uJPoU4rjT6SMbBbq8L9UA9Ba7U2cOdK5dr3rUZwQyFZBswroQ7A9cZnuCD4ugJ4l8AnUQo"; + sout << "A7ghXFgGkmzItOKtHoFYz9HmxZa+3qX23EQk3jVdml/8fFh6VjpTK73RPKwrbZyclJs0pyN09eDb"; + sout << "RjZ153ucSgBH0jflZbSIoQhPdmmi+xQBqWV+YjqVYzyYJMJypf5ZrLCb682KW0KF2dpleDJUoX+Q"; + sout << "4pQNhHSZWXtdJTUcRdmM8Wl3AjY0QNsbjyqe8rYj/o7FJaX/Y8b85y1fGLF3qqPJDxZQgR+jKfTg"; + sout << "0vlbJdh3EbV5L//jZX7EcOTDU+dvkXOSyx8zQeS+5xwVXWPmTmaTNIriV/6EvNJBPQ0vmYCjsUoU"; + sout << "6hD4EOcBuuOXADFmEcgRZl9z46qgDwqRasacBLwpaICbLCnpc8Q7QrBhpbmeHsWmqYtK6SzfiQ4j"; + sout << "e19bNsz4SP4zzFzdEpLl/J7PeNWURM4SUBwZcNOZbDbD8as2KcD78JSCs0sG9zWL12JPjZ9lJ7AY"; + sout << "lN9vqJl+2N6H1VGiJR3eO8Zrb/lYX2LkSz87AzRggZxDXdv1DjnPzGj740McdjWa96DZeexzAKtj"; + sout << "DVHwn2PFsYDKzvmwr71zUNYLxwcK2U3ayJf8nuuP5nPkRDwl3b8ttN9QHNo2JWabnVJKWid0Dmiu"; + sout << "zRus2qOuzcXLkGbgE5DdmONYcg7qznU3ostY+QyHO4/UZbDpqOPG+uXuk3SVhp6yBEmO2yE3T/We"; + sout << "WCWw3dfW4DlOTxb4m+nf29ST7WIBoNR8omWSyxyZZodAXRy6NfOnpkYrgFAorXCprqzNiRLSbi8e"; + sout << "hZrbJNNoqEUTrw1R/hXMJHftJH8GotFVuFuXTBV0wcm9eM0UeN1cWvuT/0dc7ORsLdbhNW7X9Uke"; + sout << "tYEHwobKiM0mfa1dCdTWFvee2XkrmYsNjHfMNoQRUN/w+1VHqnFy1Q/qc1MoU/C6L1rPUjth3gNC"; + sout << "oNr7jNNWH/tXMAEGqsHPP3+Hw9pqk4XE/B3QSbeQYrZqcZojBhWcwLJQIbSyJmya3w+QYKqie3k4"; + sout << "/LNyngQb4on/sr1vLNCNc4c8yJAxV/nMoNx5cVDBVg/HNAh+qwhQ9tVhi7xRUpIrAPICcaWaX+RH"; + sout << "Wl9jEeq6PdI1bUaTdBvF7DgSvAriB22oHxG6Jy8X5WIycn4FHMF0/ZlCfwccg8HcjvZzDlvFHfbs"; + sout << "lepXFAJ21XIOWHwDzG19VnizLorKXM/FmOnFwClhG/+yTREVbWCjaTkDsOlWL6JGOehVHckhxWRM"; + sout << "03eFtiNGh2k/oqsqHkDxYtt0rMmGly3vLlLN82Eiijq2iNo19EF+euIVAl2h1iEmOGXQf6lkXCCZ"; + sout << "yscrDfP3XEbe0grXP0+/ETyFrAAl3/zoENKR5MUYxVziTQ2cy733o/8aX8J/X/hFTm3ZFVKKmZRp"; + sout << "vS7HzVzjj8i9zZNshaRzWt1jYnQxKtJ9w9BEn0VLGcb2spLOKyctfmsughG6DDA7wjjkmNDNPFMH"; + sout << "mefa7RvWjXog4gPP6SiKITag5LwuBDotZn966sOaBOWK85QelE/NBsk4hCsf5LNPkMDFJYwB2ZUK"; + sout << "g/+WbwWHIGFI2O/sNsM8W4xDyia/k6ZFpaki6uOZ48uR8uLOu9mMR7NSJ19gRoe/aHeOAi40FC8v"; + sout << "K6Bs4rbxF82hbJmSfp64b0d8pNPevQ4X0UEbQI8d9o8RjmDqgVwV49InO/hobuZNWyY5sIz+8b+0"; + sout << "swgWi4uhT8bsQvMswskKsWmV9bxrrQ4EJFOGtzCIeC1X0Kzm25gSf7biU85dr8/3dXDQcEqdL3x/"; + sout << "BVrjeWsXd3ko8dQTJzD2kNLqi2yNmymIarjj5qzTxlnAZYbJFtFxtPPO4bfPDeRRQ+D5PhiZVZn6"; + sout << "a9WUtuDFmwKmoZKZIaKQfQtwkMQq6F03sU/EsO5UuglsfN9gZmLVZrNR89YPC10gM3cSXmABMYcx"; + sout << "OHnaivD81i4KmkX23r4rltxlqsgzdUKiGvEpPhfRwD1bKlKb+dTFgA6x5cYaOQ+2/KqeGn0JvRHB"; + sout << "HWmQQG0aJvlLelva7sG2mahqaTpsRGunwr6EkeTDwSzY711r2cNcLBRq5VGIg9ODw/Pn2eMN1Sza"; + sout << "0xBt8eEzdGYywXR7zpcalcJfOOKUdpm5D/Lr0Q2y9qGhqr0yambYW4ltxSreDuBWLTp9lxjbPmeg"; + sout << "tpdUAqTqOEIfskeG6FVSzfTSzUN+q9BjZw4RE7aWCtqKm1M7hF7o9FRLMhqws47tamk3AZZTIC3x"; + sout << "sGaCMG1/h/gu5bH/VQUZpzIj7KWgf7PnJEL/WhAsjjgx7XRcEW4OE02pwmhe5C0WQehHYOdTSTQ3"; + sout << "Y6djpbeCYLGJNSW42W9N5Qvrp3GTYyqDlRcFgsZFDK0DmBonUAsSh6Ytr8pxPSNWAejairTkKoi+"; + sout << "gon4K6TAdRrK8VQSSf+eWE9oTlcteftAn4iWQzY99aisuJP0MzOJr6gZgp6s2GlsaiAA0KObTlwm"; + sout << "/SptTPhSn9K+d0Os8QMMXYHlhF6waJ4i2NrCiMulOp0vHYPOKmfyCI0+hQt9R3ArNvY1pqhwBHYo"; + sout << "+PMNqnEJOH28aU5s9HyHPOkQOOSvMTYnUo0kOns0sa2dSxuaZOq9kb7aqV4qG6+ZXzFL9FbAhth+"; + sout << "4FWqkdDxPUbYUh9pKNQucyvCOtJdlrQbrJgMOTHShadJl9g+eF5boVBDAMZI5WP08py+U2sw19IE"; + sout << "/uju0H1VxjRPA4xips+lxnZdrgWQu0zG2nHeLObhdS/gbO3R25LWZxUILdNWpVbxuQrE5dRWlIaj"; + sout << "aB5qbQO7zrwCwjB/ZjDc2dNEH/4lvPv8vnEehJBa/sIoieseVzcQgFRLK0n20sdaXil9vJ3B5qyA"; + sout << "UUJBjOT9I8dtvWAP41Y53UOjkkuqWOoAbfd0sp1wxpmjOgzm3BhPoGmLhJRlJ58LkVYTF2Ix/Hwt"; + sout << "ZvLSvOwNE8HdPFOU8BALTzAPtVpfQCKKDD0iOdAzDJC8NX77Ar/8UGHkgcJP+9LiycmuPnxOzhGa"; + sout << "tsZH//NolNfZ3+4HA2wir+hHAjw0ia6d7OSIJyLV03kjAsL6I/CajovAsT8uAkegs8ydmvLryqGB"; + sout << "+98uO61GMTuZNyrv6TIG0oRDpNiKywoYmXoqdum/Dj2UtsE0A+XAb9KK8Nl6RIM27j6vkFWjgC5p"; + sout << "7Y9u+TIHFd18fxQ20A8RuTKooQaC12h11+nFUR2kAUqkJq0e9LWklG9jiq2kIsfHr6jORWi5b0a5"; + sout << "Tdu2IDMtMCS/58L3+RggG0BAs+hc+uX0txQJUYEUCmJDlzReIAwTWj/5j6YRFpq534fp38muNl0k"; + sout << "OV5OE0UU/K+TLtfXsUQ1QCbKOvEW7RcDICz+5sC2/FIGC0B+5KY8GUeQILlr33+LIxJTKAQwdp3y"; + sout << "G5BZk4EXDjoLiis8OAO6TzgfktYXL/KwAicyHdRzSUAuzQ921QA8TeKMAPd+4UrzhUp3S/cgGxNF"; + sout << "sF9DVFCaYR7t8Wg/ecyZ/Kn71D8ClkIigOWuWyN2XIuHgLVb81DF3gDdcoiMQ9qroTmrCbJCukft"; + sout << "4zE53jcVbaVnhr7hK56cEbe3F1eFW6pEUFNLOxFrFOwyiwngw18cy/PUThn6kpenQoYS5f4FXPul"; + sout << "qlFnVMTXc/j5mWhRSNXddfti53WsTB9eh6UI/MSWNE9z8kjdMUWn4z2perojF4VUavHqea9U+6ms"; + sout << "samqyF9mL9CxEhDNL0+Uof3C+yPo4KSqMC3or2rqSe8IyrAbemjcfKGqMjsxBpSGH0YcRSNoHGw9"; + sout << "wC6wJ37LuOncQ3VLJM9f8zgTpgln3y5NnKz9nMrgXE017K+nW+I/U8o0XdPAJMgXlY66pfkJsShw"; + sout << "JmidbowoBQGAXWO35cYQ92Avtitrs8kvq0GIuUGdIsej9PogzuW1ZLTSNJkee/S9t39cXyQ9YpzC"; + sout << "KSj+H4uhCrGCQJ6Wts6l7kEhHG24Eu/mwSEw0fTTNqL5maeimxJsOdmR/IEVC/pps38ZEUuTJXuq"; + sout << "J2uTsoZmiVs3ImbIkjHnN7NF0I4ryYf11Qd4ULqa3eIjraoiDn1KOZIJK41CkKOYAYIts6aFS2iw"; + sout << "+34SlRxtIsIrn29zkOG3L0yc0o+fRT5gXHXPKv3eVmNca7kPGaqSrBzv+c1Fz6ShHK6tBjiPVtwp"; + sout << "S7k2/05j6t0tMny1O3QzQZDTdaAi0Xx5yg+1oyK3M8XrgLsEU5aN7XNRorjMSHv4XT+xhe7F2B7V"; + sout << "//JUutejh1StxytV/AswArYDSz63Q/7Vs1Eel0/Fkc/Rz9gYPKDqYo0UAFi/baWoUrSYZjYlIvBb"; + sout << "fQn7nAgDTWd8Ry8dx/j8kUSR/CBWx+xz+tNPRl0TjIOWZJtT/MOqLZyODDBisgCpV2en/f9VUr0o"; + sout << "V68sh5TYTczlcSA/lngZLj3uNH6n5Vh7u1kjINhOiYzYgaUzbruQJycDF096vEGTZ8JWnz8PNzT+"; + sout << "gkDfWt8JtyCsQg3UHfO58YSspPbapErvM6zc+nOzOde7dLkjj10pxRtQMGc+fES5ER9QLpNF8wIR"; + sout << "zGF4ENLZ63jytxO/Y3klAguefc14abKtLaOi/SfWseX0HoNpW+cA80GVq+oYU7krKhh4n7mg7NHX"; + sout << "PBsUV4gmzhj0DkVAgmaGd9qT9KlZPWHU4vf1pGE5P8mIoDIwWmpFsN2sz/mEcugGWAT2JhT6CP/J"; + sout << "sPOCAB66Ln2SoLzSJvBQwmtzdBZm0a0NF6TOloCdC18cJZfqYFuEJm2gUozzFMPjLKonOR3Zhr2w"; + sout << "CpBUdXL90wQKDIoTIjnOIorXzjFZv+O9O7y4OBDpjDsFrcdQG3AZPtHcQU49W42aH+6XcOUqlxRc"; + sout << "DYsJdg75GZliymNIWgfGGuQxJ5kysL8iPjmCxVGNnXM5FvYW/JYzcXy4A2DXqO7eyVBw6KXKHxq0"; + sout << "SEF7IIjjj0TgUp7tLKY7/L0aIVuvlNblobJNZDAHYheG6OG4dAAnwaggRO/dneeiQ6hDtFbenFmQ"; + sout << "SWdiYK09ZQIpXbdlXUlPiJIG+mtHudHcbaaZwVlYS5iG3NIOx/IYdKzYnStqv+5NpkBEPySa0sYy"; + sout << "npsOyO7Hz/09SMXbSOo+RBkLYrYXiLtH9qiSj+Lv2N2ueJw1fN2PeCd3I05jcLJwqCEsZ0ch7FfZ"; + sout << "1MaXVqdKYBD3BxBs+SY3qtHgffqUrGXDIbFARc807a65++J1NLuoD5ZXI4empkMcD1O2j46tQqdt"; + sout << "MtApg+jZVUSdVCJl9YArIMwOuUdgH0oPcOItEBlmYRmDDcWX0E/fuPfa6p7u4X6zwQNY5PdI8nG9"; + sout << "+Qp+PhYXKt4fB1xEivvnHXy1x3xMmDwDH7qyZEmvpiNPxZSO5ZGUbDdzUH5AwTCxVRA2/Yd9B0GZ"; + sout << "+WkLsX0Ds1K55I2cTSN5Tcb2IgT2VQnFb7ceok9bY3ABeIipxABdKzhLiDeYDfeI3is/pftpDxGC"; + sout << "vkyX1ZUnB7cpC9lwlwwbRdXhQ8WidAEFZNmt77xdTVeyU4cVby6WCBelfRiKd0+wXtQsSp8s7iRY"; + sout << "rNHSq8FQwmt5LPdhIsvLlD6254b5AV95lYKnriWhCJ3frsNaqWc001OhiyRT2FTtYIb45pxEtFw+"; + sout << "fY3kLR/xs0V7KWpXJCyt7uI51vRNCIlI2PbfEe6epAMg7oLrbXOtjJ+z6f6C5DPlVUf5pxNNg7PT"; + sout << "EEGuVTVEkWOD/hKQB6iJfhLqIbgfOnglwI+pR8JzGJSJg+yuK82g9rRvzhPvIQm8Y/xmsMoQTvg4"; + sout << "zBpA2kgegUxYS+0VXG6ZEbO9CjSdmtUvcIpa1yw6emAmjR1TgiDqtTtQVWksXhRNiQ2XBal/YlBE"; + sout << "+45tqfGRjSiInj056XBdqsQqcQVuXcxXMT90/uslRAEV7oTmVX/XwSahsuP1VIevdVI+rC4liNj8"; + sout << "yWmDEdVfQOM0u0MgevyXmDj8BMQHznlHulIVNQRhEUmX13ibjVIDLBHDyozBbvPneuY7hBejxnV+"; + sout << "0kaG1VGjxSaTpH+jOhXwK2UvlXPTN+FM4UPLS93zEzEiJ4EWDUR1wL9VUXEfx7IveXQSYixy/rMy"; + sout << "zpSX+UpoqFyiiKaG8DGxg2hu4CPjNXfKVI7/IZAFIT6KHcxPrrO/3BZ2HTqwpPUpD7SCiVQkqzwd"; + sout << "gA5zjgo5/vrSFtah1TLBrd3PiN2RBsGn4EzHBdR2IIrDQcE3uB3nEQRyWPhRcO2pkVry8RtcGh4b"; + sout << "hkH0XKI26GAIC3cEuRzdwWU70JlrbuN4beXlL3z3xqxz6J88vpg599N6Offn7V0Ki/z0w+kIgHJN"; + sout << "H2CSc9uBNeKfQMGNdQzutSckMpg/9neWQVhgnegiV9hJ+L2lQODdnUCJp6nBYPHqkgtbOAEFRL+K"; + sout << "tkKQcyc7YkJs20WUsItmeW36T26+0T3uyN72dbmGqBaZIonm32MKPgh+sVTbUspUiLdPGSFz6zuF"; + sout << "89/dwHkVzv/dndOHh5yb+itB1LaCjsDRmdHw9uBYto6eJDdPa+XYOBe2w2qLyOHph/3NaHrQI8ia"; + sout << "gfvPuVAZ+V4XXQgXAaWjPF0XtyBmRuEm+C0J9nWznqrKjja+HXr+hnmlEuJrHGFIJeTcpNrrbcRi"; + sout << "RBdh1oiBTCG1tipAQeDf4fqRy/zrvugJg5Rx6/4fA39XHXW+IjqLMuOpA1qCpOWf/4llleXazoM1"; + sout << "J392TqpJVxxg43Jfxehz+XXl//IZPMCoiNeuztiSaWuGsw8ZjBAc8mMnzehBVPv28axiOixuiofH"; + sout << "WF61WGuAUXWOZ5Jqsg+7dMuHhTZznTPwGVj6hzw0cyL4KEoj02zic8JaSnnprKzhM442UFIADacs"; + sout << "m6PchnrjUqfG3D9DwbGIwwmVFXePtscyV6vb19r5nITmXaQHm2zJdkpVmwe8rB0Vusg+d3Crpxe/"; + sout << "kOWVPy11Gv/rkYx/1BsKd7SQJfTdsZFABUTjLIaJb9pmry7FnI+pNxWNfIyEu4wHTyNSs8ZWPOlM"; + sout << "0oXY2CqzRqdnzt0rOdEp/ncMhjN0p8q93nMhfwSlH3Fvse9KZ5Vrk4Np8HDTOEjjgCQO3vpDveL4"; + sout << "S53oq7IfTsp6lKSyzR3h9iYqa+fnr2G5wDoIlLBPIRXOx9+l9OH64rrNITKFntHEkFFJi8rnkr54"; + sout << "0892mQk/ia9ELfn7V7pxKnEgXnSqDwldLYAWSD6hhGmvUYEvEKbS+4/D/zsfnauYNFkc+cdtPuFx"; + sout << "uo6oFl+6Th5VYBwrmk7xq7j0vu9xxY0NfYXkRcd2gp7tuxAtMn0gJVbfu+IfKW7cneAT1fszraXX"; + sout << "+qVo8SX9+d7phb7a4srtpA6513DcZSBSObMJ9RT+HZeYDn+3l986vZGAOJoQMGIohkU79sFTdqfR"; + sout << "i0lfN9X/JohokHdD9eroMq/F4vMWEDj4ax+OR1wvJ1T9syDfS7YT9IXXt1eXYL78aHwoCJ352IQf"; + sout << "Rv3phfVYK45OdBn0BqAQiUhWdQxPcwuvt9O4cs1JdChTjSSGz7bKtOPTBxst3pj0butIF1ORrs/h"; + sout << "n9iLG4GVA///xeK2m1OFKDuk9AQK8hX+Nm6lEu5zKZO6NBJjSzUX928LJ+Wx93h6F/I2sy6hmBNS"; + sout << "0bhas2TNmGX2Eiu/EslEzPudjbp8uBKLCyHvM5P0kiOdzOqfDD8L0qyuOLNumauOEVB4EuSoROZg"; + sout << "uGPoG2xyiVGdkKIXIfvb9YjLpU3mebY73zeInPcaQCVY0PcbvEYIs/wMXch9CU8FP4wyJRYsp5rP"; + sout << "XELbGL9uDF4FnhmFDjVOTX5j/zV/jVNNWOv98Ue2JkVZ1o1gMweglPMk5Gnxn+twxl/OY785uZkB"; + sout << "CJdsVSNbXanSF4dFQWhkEsmjXyA0PVsw7uD0tfIaHB/BekjR2hWcwbUmsAUzIm4EVvepSk3Ec3rc"; + sout << "v0ZNE6FO3KD0kNueAt8RvhpxbO2AOL5pm3vCt3q7lbz86dalSf4mFuv23Di0Dyl8iZqNAHzsrToW"; + sout << "50Wua4dgrUEhqSgn8ohZq5Y5tI7m3uiP2LTLAfVDag0TdA7IAREMvTT9F3N2f/2Ff3j3b9HW3MEI"; + sout << "2PU2HZOvUzJyyOfuc0TnK2DXpyXwGWfgoeAQ1X8c4CZhEfecXs/FXM+BfUyJEBAxnAkHxSV2oKW3"; + sout << "PPismlBkfTS/XPxnChKJu57uy88MTaL9cEfVexPwEiCWC8VOBAd2nQUYouKadJLVx2UliM/dKQu2"; + sout << "oYzUBJKlphHkLMGgUUuhXmGFa0KFkpisUVwFtyX1ey5myWvdyjm3xGkJJW7+6KzKIuFZ8oF6QEBS"; + sout << "thWf7/v+I+TMG2FI3cl96g0zhRnguGVWeQ1NCzPRsDmK0fDZcgnCGyQGtQkdVgRBysC5hLqgnq0W"; + sout << "+2S+oRxhL7AY+DaVGOHfYnI+7jyf+NVcHhBfqw9qmVH/rZk6cuG23Lj3nS4IzdR96a70EWufMtHo"; + sout << "Mu1zlfyTEcZ0HgZ8vZdaBzCDRT7nkqSou4/uomjQzFUBRkZYD+etcrNb70zi1Pa+AR9kYXVJIO04"; + sout << "TTRLgikAT4Ja36SlxjBbWDS0swHyWy3ecHhdKwVvZcklA8yxzLvdBbiANv51v/o4O8vP2AAqAQV8"; + sout << "AcOjIeo46sslQzYC0aYgaigzGmsx6nYGgQ2zFH8bpWR0KNnP394GdKHx0AuIjmC1/FKtN0KVVQeE"; + sout << "xcU6gbb8Lbe9Mw/WZYDiGtSVXrNBlD59+0dNG/SAyDZj1LqU1Fp0thu+CzID2LVAJtA7z4Wlx8f3"; + sout << "t3iMWYIq9+3fDuCta6hORK3L/FUG89sB3Rqbi6OaYW7nu60+tm+qRP7ECwoNcyVFtGLOOMDonoyM"; + sout << "uHGfkqQcEeQaAfPeWl7sEbnzmBbdqZ+xUj63WTpGQ42Ceqaa+LeoUJT3zdt8tDK4Hc9j39rZh0zX"; + sout << "HtwdVVVlozsSQO6HvvElDrBwO8EY/YPczI4dUGCc28KcHNgEQQ+M7UenP4mbqoTIbo4Q0i7FCvEu"; + sout << "9/TrBv0JQ88ae2xds6Lk5xsQomtH4ITr4VYliRfDO9bD60zxQBhhiHo5joRQ7sW/t3ms1tDzsco0"; + sout << "YG4xVvyLzUHdDh4FBNVXJ5ZHnHnTBnppmNZ3M3ucVQpk3mEXauKEdPt7AWwdX6rCqNR81/ZasFJh"; + sout << "qTU34x7ZH4Lg3Ut+95CHI1qubhh5W9feqPOLbvUaJzuqxlcsTZBhY43N1PSCerMFFez/b1/Rtw0m"; + sout << "OAmNLDlwUfsqtrs4q3OnKlfXyHlwuIDe3fClgcE4r+9i6QSjVzPDaPfRwE+eTFRKPm9uhlLQYAnP"; + sout << "YvVqOtnUCbLQod0SIakYVb24cwnvuHfGwx1tikIizumhn07N/8LXxtVGth+prxi0Lu8b3hMI52e0"; + sout << "LXdzbJfXUa44mXoyEL0ARMyhaombmUt6BX/HaBLx19n6j+munyoheRlpqIdIMWllIAaEMbxCoyXO"; + sout << "ExRZf8NPFANOB0hvtXhN5U17hleB7U8ri/6Uf/j30M+jglbucV015gMXP3h+E8qg8aOy/RoxhkWi"; + sout << "lJBmD8OIGytd0GcoSmpzYzpoOQ+DVNMvVlBJxzG2sQy87PtAiJW0u22n5LGUVEQLxAJrRDdEiKcO"; + sout << "zp/wdKOlFD102SjWbLp1mQIrqkjVgGRs+M8kIzTSyUw3HQPOvucyewaf30EkILRHwNOpUq18WxRt"; + sout << "dE3BXUtm6v0Xc1rn71+Kbka3pXmU3qoLLQaCLLFQXcDBnKGt0olk84XYOb0/cTlnxBJiips2Q0z1"; + sout << "CD8+ny5LOekJtzcze0PcYimx7Rl1EilhFrq9fF2N1cdyF2skG9UfoPVor/6U4dDGzN1qTzN/uflE"; + sout << "harEkP1c3/Dk4BqYx39Vym1nvTSs9dYVYli+6K9EPaxjUu6ng+3ad48uZWE0xA5i8DfEYU+VBYSg"; + sout << "ssaHO3OrblCccri8HgN1U17xnDR5C/zBxpcgoQ5BZfs6GYsjsDmIHz4Kq7dCvZ5xXStqZ9FEPnNf"; + sout << "oeAMmWp6H15Ci6pqj12u9lZRG4rj1kQvaBvnpKgWOiscli6Q6mCTEsgtIYvmQhds27mbO6Y+uwDd"; + sout << "ZP01VsjBxmPZaLxpUabmGIZiwAsAZKLN9qQgRt7EdMIDsyLehadWOzcjiYTHrWTrQQ8R7zN3e/eT"; + sout << "NYJ1wsWhxMslZ+Q96vKEmejIyWAvcVmxkCuiifp6VfXCF8BEYRsf0lY+Y1VFe7yW1BaOfy+8Q4z6"; + sout << "jtBh12vD54vp7dOo9xGlFYZD/w3KQKdKyVnnqm/tbT/pCrc/X6vbrgsf7sIbV7nMUJ12FRKYYX8J"; + sout << "p0C9T9nkvrPZkxNhUKIMkra9+NT4tMsBPw4MyQQsMmyUjLjZMTSRiIyupPjo8U01tXEXElUdaNxj"; + sout << "feSkHmcQS1PBXV8oQfQnwEIXLGUX84fe6q6i46YrhrEhjf0crtCdpYrDOuhJPAKnu46i0sHS+xRG"; + sout << "B6/685zGr8K+1ET/1RqwR7zo8t0vJf4zqhAY8OpWgvY5ASN2f0gK2dzBMfuqrai05tWaJaUhYRqL"; + sout << "7FufcaIyxbmDc1laARn/mvN6pYA8C0VxexvsQqF0OlOhpLOsY31GPy722kMJeKxFLuFYYp9DeHQS"; + sout << "Br1KtYMU2S/RnQ/sIbzZ15aSbYstM0K8J+9hxHCi4aq5K7KGzzNCSsuMTAz/LEDEV5UKCFd7se7E"; + sout << "rs4gwaKuHqH5b9n2cAb1wMpM+VMttHtz5oQplOyrZaz1r5sUV2KGkuU4oWHiHKSgu/tvrBvU9LqT"; + sout << "1TB0RV405rPe3HjtIAzSPD8/VUHZy8pNBjGi/vqZR/ybRNP8GQFWZZ+vt6l7YMX83DqWbTbfk4oR"; + sout << "C06g1YnZcYt9UsV0GldCGJQZ5oJ/UHLM4bfRzSaKiRAWTwOuDv5hHL+AKJcYOvDub5pCp+EsZh7h"; + sout << "kMS9v85LQNmI5vUHbhrS8XE/2Zuki86Oms0AtR3aVWWGWM28oQopU35Ayu0rQE4oexb5NDjsKcP5"; + sout << "A38n8gS7O4cbHdQ3XAIQaPoUx6TVK4xL5Pc5bUJ5aRsyivvPGd6meP4AGW8YBKObMMTL6TSRBD3R"; + sout << "e0B/yqLXd30yRNiazNfw/aajOIAeHfbNuFvNLeqe7IKdpRrDyLg29Fh5n3Pg4Y5+vdaJigi2Iy39"; + sout << "FKvhs6zp5SE0tJubMJuFDUfgYqmW44XVNv8z+ZUIJCYQScuSkCfMmexS3LczxvOTqTm4yShuNr5A"; + sout << "+o4wXi9n6jEa0pcjTVlU45mbOXiX5ELwfbSpYDorlP2BJTZqMAvnfMdxdX7cs123ysQUHxMqKK/Y"; + sout << "fSUaJORhU5zWNpmiSFeTCsH0uVuWMj11H2ttk1+gotSz7I/RCvotGNzuS7SqqOZzb3h6LpjHtplh"; + sout << "04W+r6/FtyAq5z9gOphIZQUhoyO2tvbY7aCciCBBX1lugJWKgCQ7Ui60G80225lAU4b9QKOGzwN6"; + sout << "hqKtxP6c9Dh+rxiERh4NoE6rctk+CfBoCAxDOEFVM/bcWlEbed+KvM49mqrDF9epNJP81oIZZb4i"; + sout << "afVphDQaM8LuV9LNhPSMGLOoPeYNPAxyCHnlNyAsOrZK7otQ4lFc4dLM66NFJBbSA5iBTYaW6MqM"; + sout << "qGCTCWxu5MzZbgNgwMv7QrTgPy4LBH1qwST759B3Um80nBimBa+8HbddXPbsqC/KzjW2yjodKmqx"; + sout << "2t6vnINn3v5Q9dFFKv4oaetAj2dqSfbfCAvcMvS/Tl0Ont6PZt8iFpH2NmwcEbrwSY/2cKKqZIh0"; + sout << "sVj8a5fxwXHAkMcov5J692QQf7P+OUzmkI/P/+TlBoWYW5Zor+79wh05NnrDHqdVMY24aSRKrdIE"; + sout << "qOeME0ZEnPkMyOHJRUzHoPPw5+r7Oowvbske31mERKc3Es0cCjhJ3bab+OSaoXt2xfzgLd1e2eXm"; + sout << "DtOlfyyLT4vK55WTN+IP8xpw+e+D5N4bdGDn4UC13jvhrviBYxasYYU/EtFWpFGonD17FPZ22PAX"; + sout << "pGNpKbM9Eet+Cabzi3hJjAXwKYCmTXHk8+aHceC6DtWExR06Dj1QKoxSJg8vcbuik11BP16e8fEB"; + sout << "UiPiV0hzsLwtHdbg7uZQgvd/e8qoyivF/NqZ9e6A0s4r0NoRO+6s8OPAL1QTpavvLtDhF37g+7HH"; + sout << "vjGkDbTbzU2TOdzl0OvxmSezNfASeloEOSXqIkjMStawtwC+DgLRXzMhACjjVuc+MY0D+/zGt7lp"; + sout << "hQlKeLuLWLDEUYNn3Dvb7G8/6beGyyjKHlHSOr9WK90JsS6OQT+xkEZnOibrIzgZBwmws4SPrhuf"; + sout << "wwwQCiHH0xjOXF7iT/vi5qEKbJePKxh4m6kbcxxaPmP56M7TfmPSYJ62RTfZbnq/FIPurztCgSHG"; + sout << "jmTJvBi4m1qzKgiwHVBoWjHK3wnjQxTuerxN6CMuTA8mv+ayvrabdbRwpVRJgjSXbb/Vzgj2e0PX"; + sout << "IeHUocfZtrz9J0D9f5D5nhIlqv/aSodiEn91GhsasFWwDWwSdvKwAnTtCaC8EF/+rLx/qzCneheg"; + sout << "4gGRJ/VgaQ4ib9Dx0+0SnOEWoE0GJAceSDJImKJ7yi3Oo7Dxy3J/tsJFST+JKNuakqFkNCnTetmm"; + sout << "FkVNEXbo8tKzbeuXlextGxSr3ZCyOYfxXsfNQUJawLDxpL6fVpaDbf6Ot7ip2nejQHL5wt1AqQIA"; + sout << "GNwV3Gszf6q/bx3vXKCTrkL+AFc1VmJR4ZShMwVlVz5pdzMWJWmal5ivqLFkRQ4q4HXA1bPVhS8O"; + sout << "tfnYwzNTBMRaNMpsKGTL7MOOHdmXUHDPKhoMki5bX7Oy66C0x35qEGyTozFQgygLOMFe6eOleXF0"; + sout << "D40Yu6VhauEQko7l/VN0rICXaTJSaemFIgmPLT7lMfrd3Ta6XNKYxXB449SOC63A4doQBIFJnekm"; + sout << "VO2t09QLePM0/ztdbA3toWuK8cTSNZgelZRbSlpgE0+d5rqxjlxmKom9I/8kj5bjS8FrQAbbsnR3"; + sout << "pLEArb7kuWg3W3Gc0pdlONrgWVvoYpTtwQsrY8QfZmIVmQ+Dkcz4q+LqfYRU+TEIM4Y06Gm67JzA"; + sout << "ncMkUYcVm5YS2Adqe/VFsctKTliHEPhPRFpo9Sfn7csGTexvc+IYB1mFIWy2DjRSsfIorQJ25wAy"; + sout << "FX+E8PlVJPf7S9IKNM7NmNX4k0aqJYJt9VVuDidECJrYiTW9vsJ/SDM39qR2bQ01w42aG7/ksF6B"; + sout << "tWPS6j9Gr6KDOtdAK9UOof2Oq6uI1ENoajWJuDMxe9fT3J9/1ruJMMDE4dn8eSXCi09ip7qbnrxv"; + sout << "5splSpCC7MvfBr1bpzbQ6gsZF7tPXz9R6FfM1idioIEJITqqaaYdABv+qwad9Pwv84maHHslw8O0"; + sout << "/SPCYyCpfOAbbYuVpw+cEryeS7RD7LNom1M2Qo6A5Orh6Qmd+xTExUOJCxkPNvyxUDezWtVCPxHk"; + sout << "lu2vn9lqjj3UaUjALvD9xB0WsGhdJbOUtlQw3e/gEloGvQ8rLamwCBAelQqTTwM10iHkvx/AzA4p"; + sout << "TZSpfZOyXY+3HEjROeej+gvP6VXL1lWcYSrEexizEzy2eN/H7kVTEgnHOzYwp2A8x6vlKHWL1gnU"; + sout << "lGV1R9H2bzmhj/iGFom/TMYWmiv6mtUe3PTRTIo3W2yjlLKnBiBFhWMh8TjfcLz63dadzCLEV3hL"; + sout << "D9DPn8fnJYh2HZyd/N2/lVDtEDnxqyd8C5lZnxY28bew7dB/n5euYTo1H6zWvpIHlbbbmXrHd8VQ"; + sout << "iaJew6T+G+bBxuUpUQlu1NscX6mcHZBoT+/g8Ng1Q8bmjiT7/c+Pl7wpRlF/0jwhWxPIL/5mlmPf"; + sout << "+pHBfYj5BWXTDrZNchJKF5Pvwdnl0GbmJsPPiQQ+Wvt+y3lt0wv1gwefzFdGhi77EJyRY841JGSh"; + sout << "LU/TVYCUFn1MdrtVLhGdomxo7Bl6Zj4bZaYlKXf3Qru1bRjD0Ug2tIePfjX6TPymfgxBkgGQ0zVx"; + sout << "dJuLeeRNmhIFVH8T5jJTMcmd/lS9u0hzXrEjvpSFKPmH+Fmv6LJF/D0bl/V0rYEI2gKmpAWJLd1M"; + sout << "iI6/XFF/bt6tQTRu+g4WgPWk8HG3VffsHUSAI7jDK8hjrwrJm5AHIa1a/PRxiSCH5Hi2EHCPRfkB"; + sout << "grfcTEwLwHJA+j0zK0tV668Bal1NBw5bJGMzQSONvQqB/vnVqcRDxhCyztjW4oAnEeeBKPlAmAsy"; + sout << "MlrkLbkVwjiaYbpzkYCsO/KlIAgoahAJEOLghLQGBuhP0voDuN99bKp5JZFaGlaaycEuYFHyk9f7"; + sout << "zo6Xfa+hhmjfgWJ8KCFKKvr50ndwdqRTafQVxjZRe7QkwYB5ZIpwYehKsXlr5SfTLp+ANmZqP832"; + sout << "yXRF39OZLGs7xtsMmPdI6WX/xsOKiCc0oLHdzH7y7tYzLVQVSFdRZtN7HA2r2Aun8gTbLExrbfRW"; + sout << "+16+go4YQEubnMwW3FeRQqbpM+0GCsiuTlBPTgnvNEWi1n3JHGHeJLs/0HRNo6ba+jVE3vIZO6B8"; + sout << "NN1s+3CSDGA0uqpGg/51P21F/Q7wHDgQDxRuTP1KQGNm4u8A7Kbdw5rGZUb0ZREWwGPfUpMU0qXi"; + sout << "TgY2XA4rESFQRvn3Uo8jPP7UOZgKM5bsdToC7Rp3wuKXJwBlmdTZgdly6CmELOILvHVNtGHg4yxO"; + sout << "XA9Co0RF7RsHOg1b9XzfeqibiK27ZiQnZ6sIvl+EdmsnxOY+Qa8oUZdrQ0JnFUiZjC0+P0760Si4"; + sout << "i36AOc4h+GzZ6WbG0yPZlFAFEcPiBbyOCp6VoY76H32BskEuSBJx0U5+iFKvkvK/6lyvr9WDTSX5"; + sout << "n74qoW6/CduCm96RcKbi2ywZ0Sk8vi6V+tWYUnrDaepg/iY84UA0qpAw/QsNZN99BqOlNEWftRVP"; + sout << "SZhPmpxzE2xg/SLUwR5y7NRijA8QalWWQf4G9TotViLV1IkftS1d9TvTMb38vGW4o55pWw5SZlzW"; + sout << "kFf9fbqlQ5ktTZVol/v/n2pwBjDeKu9wbNFlI7bhl2xhg/3nN2e3mW0HikB6K81JywLpxL+sa8+l"; + sout << "SW6bdGwbgj3BTAt9evqejNkQHVb8ALTkmZJUcPP3CmW0DWrKioqolNXrfVRjl9DLcxgQfhG7osjc"; + sout << "zE7/8gU0JYIk9Tr/XDHdk9zzOsMZAHu6nUbKpIL/5mytR193QuNE3x/eCQpsG1dMC/alF8EGn77V"; + sout << "QK8dhMFI3BQCXlo5TmLXd4RWOUmsNGt061t6ZQ/jvP7oPRskGSB6YUtr9Z6CN92wbnBUIf5VFJPD"; + sout << "qf0mPU0vKaMsbwQmD8+rNG/bYItW8lRAZLa4gJR5r6lULMKvcMiHeP9742lYL97T5w5VTo1c0Ec/"; + sout << "YTUMJKSvhfGPq4IfdtYmMOX9pmqd0VKlwnQ5w4pKiGRY3lj8zpXAHjraMA+46wTo9ZJ7Foq1XR1z"; + sout << "eCY3xmnahD5uHBR3Va66rw3W0WGRgbuCk63WTKDULa0dWuvNw0RWy+MkZzbn76QLX9dYGrWYUZWM"; + sout << "9mnQiC9yDObAz9vRcITf5PhEKduGFQgAl9CU0k/ZgmSmAYZ/wFJTb3U/roNqnptdm/KGuegjdQqT"; + sout << "5utMQZwwcgV8DCuE3/Y9DZcnEMoeSjZF5ugB4Hw30EJt5xAvCUpVRA1tygRwwWD5XiKt9srDjBrM"; + sout << "nmkQCbblTb/HTPqOvdkBO+4yHSGtpXfA3G0DdSFtFPO7OTwPLrUxf6rM6CQTgpzP6/eoupLM/I7z"; + sout << "tlfmNGcq+JYCy8csjKaNlPTJextxMVqNScbpeYwG/9SpjmhqObv+oYs/uga5UBWSyS5ls+eF/Tph"; + sout << "QSD9FkMCTY15Twa1LqEJy2TMZaDkE/UVJJB0Xe9eRQRXEgkM/7Olj2qsyN6VaLpzNBhcXT+Q7FDE"; + sout << "NyZ3C6b4MeoE7zRM0/KEXc25ma/uiHu142x4Ar2icGm7NYTp/gxVRUrSpqOOK2EYoHzCTFilsryB"; + sout << "JhzQqEIsDpNqmqyjy0BR5+rrfBJFxXAaEDryGMntT9XoLOF41ZMUAsypcv6yujSkbCemiaCTJ3gc"; + sout << "9p0IkAplKKgb0m1mJ3SrW2WR56ZGhZuVIQ0p24yHi1ECqcEK4GJq1jMpHpqIGfxVLUD8yZ6wFLDl"; + sout << "NP890I+aZC1kFSBSuj0Jee1VbQzTykN9oj6+8Bo8v9Yv/qt67WH1h/oIP5bxmCrjTzbncAT0zs5w"; + sout << "qzi+jN+AcQS4qh+IiuCke78wgaLtnMJIb1yA0GkDM6t9crv7vKAJwGkEFZSSg+p9TucFv2RqCCjQ"; + sout << "D5v0X3g0OpBql7+5IqXmuTiRCJE6BIBeeXVOtD0R6JORJ0e5Hy7SvT1LrIxISqZwat7//kb168WS"; + sout << "cdpRlCa5c0+ds81SVFkjTKViZAnFkOHPGVgzxrsdlfTt1T/2fx9MZO8HT3p5V9eM9XTASzXM/jMN"; + sout << "5X63GR7qz/5hpM12SMJvgcznXESiudVq7d7HYd0NeId04fzMlFaGtgGLDRr+15G0eWPCk7mcF0np"; + sout << "7XLz9ytRsWC/WXaVx2/ucmztaGzksA02If8fHmKuIHbfJ4YkjjKjVRrthMCOLsGYzfb9Hat6WbXv"; + sout << "jkdPQZd/EYGM/ZUd0wfuWo8W7yLidl+zMgMcJ+LeJjvRdBkHU3o70srnlKBaneUv7Ly/r2xPcsj6"; + sout << "wLMH9iuRZ7jZJ45EM3htNEHJzrT/9PxaGKiK3caGIilUnFPpyb+HgousVcBbn4XSdkIjy32UPdmQ"; + sout << "aG5w8GLvMEJ+SDAseg5/B2poqEiGstsA7EbJyJn3ptHk6I0lUFnpCgHI1H/vp7f5vhSIVkDyP4SD"; + sout << "GPPTMLfV3mmMzS/IeIdOLdshYQCrL9I/C1sgaASFLuh0fZoWnNmSxfS1dVA7fU597PnuEzGn4TlM"; + sout << "CVb8UdlIBfEmrFXegOnBEKu8RKspYa2TdawhfTZs1lyPJOxOewWH054k09q5m+p/QLHfmK4kr6wF"; + sout << "eShsc8VFczeBC2MtoBVF+a2gGsz+F2Q8WhtG35kJuTqZK+lqB8clFAVSljny5eVo3WJ83QHoZMO0"; + sout << "L8XspRqErV0+T/n7Uf9p72CQbw059rmyc2X1I1Mka7+cdqewt3RUZbnB7YD17+lg+do05keNDDea"; + sout << "kkjLqmK+oKKOJUyDULGM2syzbhMrJn3WY1v8RuEVsLIUrCxgvMzD/Hwcl+lC3idwjonUrQuWWDyK"; + sout << "dTjkimDohgLx3uoRs2ke/SCb7ERVtB0DMpLF3zytvY3D4ZGIdet4BGQ9H2nXNWMK9ahTfo4GFMxB"; + sout << "EBdLiQLAqBJteGuoJ1CWdYGj6BDbYoy9sXSdnY4Hw0a8ydeBmFqoRm5wZy6ozEpxX0IeEc3W0GfV"; + sout << "+4EojUBei+nfqmpeaDTWFgITbhplcgSU4snYS6dHxO4U66xa6N58YokpK0oSHkL42y9+Ap9qxiTk"; + sout << "eDRtTd1wgP4m4EZS6ld1rTROZMhwgwZoyKizgzW3cc6o3rS7wNBvrdRMtOPVo0YJD0UZJtwCyDzT"; + sout << "/LMXSVs5EQByRsZkKxNFMEky5F+D1IKlVny8ylC3oOI0VNJVxr+hIcuYjOf3adwKeslTQ2AO//HN"; + sout << "4WEyXck1inZFxnDybu7NGa/WhNdet+PzgzEI43go+Nn/9AYLjR3m7nAcBOGgtMqmrwOTPJl0Dhnh"; + sout << "QWv3JWK+OXqz3RFJ2q0lZP7A7EhU+3G5OHdHl6yFOh+Mb9TNh1ofYXPXL/mnnUhxCvUNYEXVdHGD"; + sout << "RALydLzf6oY5id6VdYgCOcQY6c32WcJUEk2zCeMFNxyZ1ntzmn2KlawbF6MPWHqemMEvJM7YCFXl"; + sout << "ObWzVlcxcnecWLltZB5xeRqDBcWetCMJvG5crQJk9V2jV3iXi56gFzOta7HzrEkKsjuYtys2c2aN"; + sout << "EMcOMedGcMH4q3xkzca56eoR5b0mXeZRnFRnK82LczMxJjitaFLOMTA84YAgWWpWt7ENpJ1qA7bu"; + sout << "QFdAdnQC+DKTJuLykS9eGpMZIuOarnlRH00AbK0j98XyjkaOhmL4ygSKNJQghgmFg2vyOfji/XNR"; + sout << "F2M6c1HdfQK3Im0qLTXaMZlEjkWPIZTcDtVt+CdUtiEZaqPQwW1H37Eqk9gZAs1TnLe5AOHAMWh3"; + sout << "mS458zVeG1d59r1BBc7pxrju9CWB1PjF19pxCUc71R8sh2D98ss2W02R8dqCAT2wdUuEC1WQwpL3"; + sout << "GGGEYk8jB/4ccjN/+EdYqAulxV/Bp+q1jspEm10/EU3Cpu6RkNFZj310v0t643E1MqzLjpg7ZSko"; + sout << "FkzcMRM2E7Z7WEzOQwbr4KVgXcXZ+76Xc1ahlvMqwkUQV+/PnYj68Ogv95pj6biD9vweY46XV2dF"; + sout << "uxp1FG2j2e/nzzFJgFfIcsB1411szIaEqmHTUPBo26qg/EyWg0775HTFSZwlw8iDOFG13OBlz7kQ"; + sout << "kTvrxy0iwvM6cTm7IqJUZi8hhC/eAX+Sdx8S6pKhOcy1q0fadSKF5cTynAQtrSYwDK0/xYRjVl28"; + sout << "4oTks7lHxGE3Jdu0+j2ufxNuem5fcEgWFSKZnZGM4XLvUnpVbdjPMEjf+KJ0w8LqDbTbmZYXW69q"; + sout << "Apl9sQ7AdAUcS0ZhGtACegAyg1QcJUjFvDKDhXwx5d1+0ZKChwYt8FFzqht6nWY9WjDT1i61XmbS"; + sout << "m2Emc68sgA9VMv1AL2bOhXq44//vlGYj0bX0gj7ZDEdUhCVNi5aXobYv5fn4nGfTNKp8njHyqGrZ"; + sout << "CT1xO+YrGMY/4qAGiBsOMGtbRrIcFHeeCMKOO/5so80js3nh2fw0L/XNVlY9RJXJA8WkEsalm5O8"; + sout << "MUDzmTukLJQu6YIjFhCHjT0FA9R2agaQnGFZXiiZ+0GnsRoUSrGC5mdkZ9H6jg7Odm+aB5fD7I5M"; + sout << "pHzEi++Jt8xHcbX0jyaQd4Y4pdku3xx1MrdgZht2WmCpUNxg66syuutjS6iWKQlvWSdDtlyf5zVU"; + sout << "rRn2ESMesAxpmuFd4sG0j8Qm/8jzAF0Fwolv5GmJFlwKx5X+RW5Mb3HZdJYg3xou0Uwb3FBgV6aK"; + sout << "mLA7GQUK92IMgRKKNXcDAPJdaF6wT4lin82Ss1QpWBBnWwLNTslobUbn1TJsg+MgurruWUB9byjb"; + sout << "56mNxeoTkMh9h9fEuDZYL7svjxh8U13AgnJCgchaDj7iK0TOd6b+sbz2DadE2SoyFYo4XDTbXcpE"; + sout << "LS8TbCb1rH7G2WAS2a4JKRyahhELZhOB/X+jPRtre+MNSMvoYOFYtHYXjIzoEpOgEGdSPOsLcA8s"; + sout << "JHCoPcU7bFQcnIlB/HZVqbDZwVzAtUxKoven135MSyoRVu8j8nTIcM0RBT1VjQa0L7eHcgQVHPvR"; + sout << "uhtFmhtUJJT6lDJmoolIJt+nxeT45+ndSTz16YrT4pODPMuawO3XIULSvkPDWNhQybkF+2jmUqlX"; + sout << "brWKyviTX/IZsHS/0YMkgXZRX5McLzUH3Vs1PcjNBoj2q6tdCq2HCcIdBKH0Mokm9DLqEi0oY4mq"; + sout << "kEsB6HFUnYZm4xezpAoY/wHkXRq9Y/6lOa4DtgBz8ng6/WwFwKwCIkT3/aTHzKDQSewYiU5jmQGe"; + sout << "R4+wtbbRqGm4yfRc7z0xEsMESF2FJnODPSSSBAFT60rKGdGK4ai0IaYbWoWaHeDxhHVeOUYbwUlr"; + sout << "dBHstgs0/k6t3QKbHDo0KwMv6LN0sYYDK0fD5cLA+92Pf8mOlOAS2YNlR0sG2GrSa8M2gZhpgTcf"; + sout << "dz4j66DDlLmQaC9UozZNf2PxDFLUN8NEHu2XnfzYEyuyt0BQcTQyqs8DvrayzIwV/qiD4o/xRI0T"; + sout << "Ma5RvRpagaW5EUApHgu09rkvFXk6ijL+2L+7j+lDQTgkvn+wHYwbTADm/b9tZs2egY+6pSDDnK97"; + sout << "invYOsWUNUDtgvRxgFTDDrjUNX4xNifDUnGCJrODnhg41dOG4M6ST3gBgxtKkrmuHkcgafuohkCm"; + sout << "w8bPHam0dnd55H+lj0J74uyhw/LreRAthXdcR3RbizO0rWIX5EDP5Y+JYxQfejANE9APE2zFHg6f"; + sout << "6y7DE+GVQHO9iN/N85TOUhN5ya9MZ9fk8M2JR/440Q7gqoqNJ7w1PpzfHCyQBSszPQTPBr5WeirK"; + sout << "3D9axZGpND29nMOQ8iZy/7b+5IinED5SDZq8MTekt6kgW/ScvByBVOeBbxh38DRAcClRnn7CdFT3"; + sout << "Ey+kaJAYlahuErS/gjU3VkqPKGds02dkPwumMWd1G6s37r4hQdPoO0F73+5dxWwqF1dX75nQ5tyz"; + sout << "J6f2pbTFnJcp4+CGMYs/SEoPzODOqMbMt5hDqQHZuEekNwMNI/X80PhnIOiMa7oK7h+8cOb9v4Fc"; + sout << "07AR19kD1T6r53s77u7o8wYLp7/uf7NLT7tA+bc6Qeaiss+JLJznbb0LHa9uSUZ7dK0hvHFAYBjh"; + sout << "fGqkQ2JI+adCJ/eUQLs9zvo/0GvEV4P3dZlRcuXw+y0uj/HhbMnTYgnD8AUbRYLCLS5Lx00m+Vrg"; + sout << "lVyzOLy3d84zSmZMR/xoX1Q27biAyEnvZK5vkX9CmMY4lhP3XnEBfcK/UVxxZDQf5HfZIamjAbOP"; + sout << "ZRyNX/+VYpWja1qGHZmW8Tc3aJfH9GsIKdls9/FKfuydYwfwr2kPuG+PTwXxiHKWWdzUXluLrmB5"; + sout << "T9UCUmf1mkBM+Jv5r5R7qUPldDHzJP3njQfT4SUrNNXL/ooR2NTX/jYMqkTlwSCM5w7jYa19IDe5"; + sout << "lYzspMfRfAnjG/DPX7m8H3fYqHXo6JKfWGL9NsuFzFPzbeVqVac5AQ7XcoEDXj8b5OmGu915yRqO"; + sout << "zs+Oag12V6PRX1w6EKKyhQT9BPvW9E8I9kSKGZrO6rSq9pBaZk5DWES/UX0U7yCSHaBAlYZP2fmt"; + sout << "tvEMLhzKwu5vjk35yr3Nm2ICYBow40OjKI6ft7EpneSgGbkBuHH7X2RxfCVcg0+7THSBMEMgnBzC"; + sout << "+8JERFD94nZDWrX7P/QTaejq4L8yWMb7uIChSHDnIIUleV3h5kGuM0MCRChSUPvXUESKuhkezPTF"; + sout << "KRt5t3r1CbOHFwnuIjB5JAMUhCUF0IEJtk1b0gnCY109aCl/sefPp3yznMH8tDAWTL6G/aJbRLxv"; + sout << "DpBb20+ZyK78kfTrFaJ6xGgAyi/fkcggLa8ANOwU+ZbVKvVEQOzU9d/8ygn1SjNBASKaHfVja6Ym"; + sout << "TffBgjcGFtnz2NkjUBccd83WqOyZOxsFqOugfmoavukdhDAaVtcB/KijSoBYQVhNITRC7/RLRJSE"; + sout << "/obYwojRxTfq//cIE33pXW5RCDP0odyGIr4YVCHioG+li9NK2/xrduroOiLdaKly5eAaFZgiFof2"; + sout << "jVYA3mXRxwH83YBdKUSwHDhItv7R2lAT0nsfXopxcJBQVNwPvCabAmsCf71FTl+Wn/LfALLRgtvl"; + sout << "b9JBLI/gPf7tKXx/MhCyz0Lq33VoyOGm246DvalV2QChpolNafxuxK5AnvLxYox3U+53A88SddqM"; + sout << "P+k7oftTdX+u7RhDmZtjHLYn4ikjLwK+afDTW4hCBGFR/bmDsZvAyp6OOVnnvsj9qSOWpeI0VGQe"; + sout << "tCaSgoPEwP8GtY4AX+Dhy1uIYPmoXZUPyNQ+ng7MlCk5S6HAzLlsd5giVhyzN4b4N7f6vseMubhI"; + sout << "/lG0Hfcg+rMX2VvTF54FMPgPjkf6CWr7QorPDB/Puw8c3dt69dr09Q/r16t6aJmtyVa0EzAxiqsR"; + sout << "n1POTejlH9K3Ul98w2mthFuaFVwk0ILJjCZZVgY1GxlvYRzPs1/qN+VQ6GdLYFaaBWN3mI20aPND"; + sout << "wcKX7LEodSgu63Da0BBrKBkoihTYg+IbO9m473LWf9U4dZDQmSjJeISGLa7HDoh/JAnfp0Ej3EVG"; + sout << "K9oe/ewohkPFTgXIV/qg4LDRny+zIJ+jsEOHeryCzBfml0mR+WWwwYYcPur4cCABhkZfKGVH5L9o"; + sout << "S2pTxevQeY8hDj9dUblfU7W9RBl9t0j6pYm3wP9wwP1yh5AKezedXy5VavBPk82SzWBqtsyTpODJ"; + sout << "AWQzgF5G6yLvPeUChHRDGedP6KzrPxkSxpyKFZq6dTNl8g+yLg/FHwOZH/R8GesbKKCxjjeZqD+4"; + sout << "UjB68IayH6ttcLEjncK/rLOWn8lWrnTsQdOf7GznjWDHILJ9B5hxEsKMKqOfYLKosPMpIJFk8UE6"; + sout << "CYffnCZFI5/M4stj/2A4zYt2kcqIMN5rIBJ0hJtBlfZRTMweBmJjhAKbM3b/ER6jQOlbPBzNNEmN"; + sout << "DvcgFIL1c9/mqfI041LVldLTpHeaNDRVVXzho5+iVi/R5UfYYTvvDU48E6cKh/ErrdGHXURNYRXZ"; + sout << "ghNAKIOA3YeXZk0SpcdzxRhQaFwLcim6U1j3kVLEYnk17GavItFxjv+22zr382ivLNfjVdx9vxBj"; + sout << "bRqj1fy/goe9ilHC5sxCphuoe9vfAe4Le0b2hqn0Y3LyO/lvzJEVbP2hkFbU06YXUOktXJVCH5bk"; + sout << "JO928XrFA98vOciWwqtRXhO26A5yELcjimSmqZ9zwt+TKz5cG53+CXufCssHPHDSaZCbF2gxST9/"; + sout << "TooDJUbagLQ7VfNZMBIwDVoLpl1dG1Z7GgobQ70lkUAA6Rlo7JfhNM5d8N8UgM2dPPukCAQ5yuMl"; + sout << "0KTHpW9a+qp3B3VE3+8mPoX16p+zpRlw5Y+jA5ctu84Pez0FqZcTmM4lFcevPe6C9LCxJGzKJJer"; + sout << "yGyryVbEzPmVk6nsuWnHFodPve6NeeYnlTkifu3FBfP10DzXYSUO6nPwSZiGiG47AYnyQXmIcj0N"; + sout << "iSwKDzHcPf5IVgT4qZswNoy+t8QBATl4run7klhzNdE1QSAp8eNZnjtdmZdGY4UhB3/mto70MIPV"; + sout << "OBRoyQCHkzWx8tPcn7NMUW9lDJdFFgNiHngYBe47fz+Tkuv1iBxRppU32ByeXHOSKMgaTyb5f1XS"; + sout << "ern+rWMqDivMybInX6vMV8BfQ8eOcnl9MS2OblGPj7XDsbLmr/eOe4cFoKdBKg791eavES07Kr6/"; + sout << "LXw/mELHEEXBqkuafGTUlUfatUz5OmHHM6ssKoQM1uUrGqzgcdiAe8vNxuMsapJYJjIb80meaYia"; + sout << "EyAvl8dEYiGSx5etaf8hwCnXAOzEfHNi+qjIC6Dqr5s87SQpYPDexkuUv1hEGSDnEbWuDo7BSwu5"; + sout << "hZCA4T6yoSSki1hhXMMsl0tR45bm6UlB74VX1IsL7VD6HiPY3BcveCMbk9aNCLt61KYwY+4+xEWk"; + sout << "smbgq1BOC31buf7xQxX9UtsyBHQig3dCScPoFBm0nKaEHK1fuf8zN2WIi7+p3pEQqRTphOxHtsVk"; + sout << "QyC7PmZJS4merPZpki9Mh/jo5MM76vCXgESUr91CY5neMliK6mlVDQjWBPenKw6JwM2UeDjjQHTa"; + sout << "5w1vSg47kgvLT4KUfqQjS/JDx+u/RIrKGZbsj7WnZlJXsBvoFYn5N2UGBj/SOLVmCXYhBfUHZxqH"; + sout << "GvrtVWHffKoSHcVlm0G0u0EKYYvrYCydASHOegavtcc66k4YIqge/pv8j5lTBiswafWuBz5OwbDe"; + sout << "zxzAcApflzsfgNLol/DNUutk3QpBTk2qW2XXt4r6tkfCG89NsMcr1L3E0+NqWc1IbIIxJuSBhq9F"; + sout << "wNvPOBUYmMD08jguGFHHFLFTVAhcxP90IP4NN6/ImZWqUmZ7DsQz27ritt7RAgjR0aXUJ2AipCss"; + sout << "LWr6mIptUwcOwPHwqwrb+s5xuLIGsT+hXzwpaNrOq7AjKx4GbRcL2UhNwDJ3jAQe8K6uJxffzABu"; + sout << "x1FLhLe7N63S+rKYHMLSSQMKCdf8uxVRz2NX44QRIdxXpV5/mHXgDd8qysZkv26IGp7DG6Pggag0"; + sout << "DtuCocsV49swsov9/gw2yR4r8ziyO0oUMOygSA0Uirda3B1MOiYAvCULRc/HyzndANKNjef+uagl"; + sout << "U9r+o2c2pqApAZY9JeKNj0vbW5mmIZPU9B+/LNGnLAmJvpbrcRTpkHI/0VUV541n26da7jXoJ9y9"; + sout << "ejOkZUY9xMjkLl4aqgwi4cfg3TJNC+0GXjM7aRlJPk2uPNujEiLSkrPk5GX1+jlZIdBmWSfIkNeo"; + sout << "XVQf2dVu1YlWKqmRHsgnJecopec6lhGHVIeamgsyyO9GEmQBG0cZ0lkkq3Fc3Mhakcbnw8InWFYe"; + sout << "skGjFM2kSHQ8ImGUuGx9UW/cKMZbq++oZJQ8y/rmTDfzASJWXezdCOJsOAEAqstGrVZXO0Aoq2Qr"; + sout << "rwO6oJYlxdCpc4MLBzriX0qlA/kj+3qc/lAb7svIugXusyTAshRsSeKgBb4qNUdJ+31j5UoZbwT1"; + sout << "LQt6WHxVFoqYzdjnYsdbnSbV6llixlEfdmdg+gyAQuYciR5z4B4D/dydNjmbFU2BT9hbeMqkam/3"; + sout << "vZSjziXOOsqMnGj5ZA9yIPwKcU9IqvlhxSsLPyFN6eTq7EUV6/njn97pMQhew7hrhXWpoq66A6lG"; + sout << "5kE6M2y39eZuQnh1GmWhF2Pr+0Y0hk2WHx+h/6sRNVTH2qOm0wFh3g2ZqCPgQSFSdam5f+Jo5yYi"; + sout << "IFxYEsif6HxxKGg100Yt59WpeKCVjYO26uc+bclSKZ9VNFsowfyCI3ZtFOp5aqZne3aB7tNkqUh0"; + sout << "387ffCRsqMoQTLLZcycLiIpcCgmOCb8ASCMi5WgPrtZxKA9oB44xUgosSQlRgchUdu5qtROn1Q7f"; + sout << "PHqeNv9EyJWiwHUyLp26f8pus1iHAYIJWvoFj09Q0vXiyvPM4iSd4tFX7kIXxOSNy3ch1FFHWc6j"; + sout << "FDiXSs9Lf1iWpvF0O9iqXZcMdez9dx3lnNuXBttsbNFg6vBABk8K3DFtwDVWg6a4CCFlVmucQD4F"; + sout << "Nf2eNdlVUwwpJn//rtDlvPnZXqpqx8JCEyBcYt0/oYT5vJdrhOKRNiteeUuWf51JxdJh4XKnLE+e"; + sout << "8JAmuTL3AJ/gp6Bxpi+4P3jAR9mZfA/sQnTqasw32kz427/v+hGWMq5j+Ak0wGrKdTlDeA7KLBGl"; + sout << "vbZUrSR60wvpL7SJSYkM5XraJzR5ICO+IxE+hfsDR2uTSrzfcsb0M6wnVatnXjOFNLL5AY6jpHc5"; + sout << "sHLhUHSziNpnsMsWJ4h22wbZmrbO0zeD6wTHvUtrTwOTiPfGkgvz3DLcNBzcn4oqfaEk5JnXoE06"; + sout << "eH4+bqe1/kZ7tlZLzjd+7lFAO8dlAA1KoIGOvAL1L5aqnPbCuJVaak81GSiwHlaztibjOB9Qolm7"; + sout << "y/ln+qnZ5nbMYuWnn3lN4m3Gnk692nuMfiqh4RQuwwxzR1FZtiCzGBPkVD50VMzUsxs5d170C4MH"; + sout << "ER71dvl0U9lLf8WKsHyiU79tJv4wTOpzSMszbf+WCpWZXCJrfyJ6ylRO3dvk1QfjPmFkUwiuQy67"; + sout << "q+lFfRo9N50ZvVB2bYNsne/tlNHPk9/WdeP1rtTyMBMdl+RfJUf27wduC7RQzMtdUuhFLYxdDGxm"; + sout << "1ZZBolFeteOW4IboI2KanbMlv+GOAwxkAWC9ymADZ9YBecvz1unjZkLSGzge1MY7ORdnLBA8fyUv"; + sout << "JsigTjg7KuX2kpkuKxQHYs2l78Tlkdwo3x2VDKEfNegyiZMGMr8CSksLfxnI9dHwbUmyBQWMZBOR"; + sout << "gWU+s45M6S9+HP3R+lwsT0Pn12HpljhnPde8ZpemkMsTbJW0rCQnYsrdryfL/X1GUxUxhNm5WJk4"; + sout << "blqylbW9X7IbbUCB3+8A0phl522/wLWusDD9ndRsiiXj2EDO2Ah5MXodILg7FBo8PpHDsgKBLZeK"; + sout << "WlYVQlWNtg6m/jC5004csHtx2ra6MjSxAhxyt2Y5qf90KBDt/61iJfKC0N9hWm5KKMsGntwWZBVe"; + sout << "h8eJjOzM8vP4azMfymC9U87PqPuI2/J+IB+JNVnjLsx2z2NumqgAKZc4mVUs0PiNK2sSVRlMU4uF"; + sout << "DdbEn+oK+iQ7ILBnkRhDu/C4h1eSyAQwF8XENwyYaclxhoVGRTCRkqrOaLh3SSb+YFSFvbDz49jo"; + sout << "ZcfxsSvLNgIeW0MnmYUH9ydCunC+y9aiO2sZTfKrl95BeX8aEjtxCjpqVj4daFz0diNEQSiWIWlU"; + sout << "/EqYRDfOIZoA5QkPBp2rioNMsElAHaosb90vh+LiDXpOj5cz0S9cKvQw7hI5Ma0gVmhQbAqFrnOD"; + sout << "Me/94YZgsabKZ5IxOJrajCoF1WO6GgdJCkdSUrfRZTkxe0y4qBPAbVg46BGOTjnOkSxUUzXAjXDE"; + sout << "wPf86SdrIWq0NDJ3N+D2lu0MloB/tabrVHmBXQwNDqk24uMVQPBQRqYUGCTzkjr6awEz6oICB2iC"; + sout << "14KqcsiwuSxIsY2jyL0VG32aHA/XNQTZetsvMq5fee0jzSZWAqA0EFzw6UFW/kUmVEsBdX5kpYpS"; + sout << "XuNVhtxVGey4zANa9P7nc40VmxDM++pP/45XNbpyIUOPxfcBM5YoHjnpAwZlTTUsBi4Kdx2D9oRx"; + sout << "PGlK+TGQF9T9ZvI0mxDTXiotRScPozQ8jom/oINYIVzdH7EWLEG+nmKib+3icO9A09pOYWw8TSZj"; + sout << "HNAxPczCwc98w5FdlLe4zDwHnILl4NwfnAz8x0QU+VyW6kugntyk+G1NvFHIGvOnFnZA/Ku1nqM6"; + sout << "3jErxhCl0Ii8h4dD7+HVUu7FWUMOM91LQHsgmphdHA+NQXm+/J/zj2WxamzK8pfpnq9SCYpNfjGe"; + sout << "NOJGOuytsDuG/Ct5fik3DquGqccuhve3sw9H5RDhaFqjWOEPT2rhlu66UfRsyy6Y2FwRWcB5YuLZ"; + sout << "+96O5+VUhbSynvOuygLQfmCZb4RR7h1x4RLgimV+8DdDaH8abjuhKv9bFlSFUNpireUlxkYk1NoK"; + sout << "q+3XpmpJyDI4N/qB35pf7RUdp0mlTU8/cF7g4OJ05ryskis2/tURAdoKcePaWSzn7xoA1tCGYC1y"; + sout << "+Z9m/7erwcD8n+1ZHu3499wbywGzscGo0q7Drc4SxiwUUDSH2FWpaoUmKEUjsYLMwx57QPZVIiH8"; + sout << "WjLd9hcjaht7OqgIQI7ihdSY5TTvihLif9siWmrJk2kpl7Qpx9B2VSVfFBrie6JU+6Ecn4XIMY4n"; + sout << "GYVF/bao4WKkobcbSYCmUfTr1gx7hh8VkIpmuPfiyBVHXI1dQtY/BDxkLIBiUIk509Kb/U0xA45N"; + sout << "+NPG833fYqJ6p9ygAJCQJfmymr3TXL/TxfXtSQrO42/upQob6+KEjSBpewaULLu3+6diY1/P9n79"; + sout << "gLnJOnGncD988L0KCthNKXM9Rm9hzGjnObcxg2IHckRu8F5Bn298963sUe8x5G5BKCLWTTPFbBxO"; + sout << "Mr/zSVZjlKaafRQYBWKE3BPQRbLKBYN+XaNuEQ/cfxnsN+gUimbvpZUueMU0Jeo1ReHqZpZnjvAE"; + sout << "Upf5Lrmjpa0Iyy3irUQ8HEVEisLMdpg1Qo5RwuWanqTVmGxW0sCjfrCOhV3TwIQqxeWPfc8o4cRO"; + sout << "uN7IkWSyck2HR/4ncq+tfaFXehMA9rxigFkeXNh3divWzV3W+U/SeAAQcrgUgecpcxqmGkQjqV3k"; + sout << "Lttwa2MF5SCMfM8LwEG3Dl5JPvPfTY15dIUGNjEIFJ25UuJ1vK9K/ns4ifG5QxNKGFXLvgdv2WXy"; + sout << "0968qdLYeIMXJAv3ieTFm2ivd6OVs7c+jp7k2Ondx5e/WeC8ciCOt7JKyAbJcCZYbasFlaOOeNOO"; + sout << "tm7/SmLUp1hZALfjTeoDc26EtdeNf2x0HzFsGivu+/qd+eYiQXLGqStyYPVc2IZ08AYLFClR2r+L"; + sout << "to5+RouZ3CCXiBqoaoZVsGiqyM9Dgt1E3PlaXtgLFzhpXNw/l4FXFKPiqS9DZh0X1aNQE1KAqLgM"; + sout << "F/VYjhOpbzycTrRhiJ0vhzaQUiCCe/8QUlSqCVVNZ9w6Dvc7uuNNp+exxN+IjPcWoSYDe2QW4MHi"; + sout << "3Sk99Bhu+5H4tVqc3ehExt01lIjQI/M1A1CpuwGv9Sz3/tspG0hYpOT+Mo7sHZ7ojYfkfLHOfp9J"; + sout << "gtUT2uqV0LtDgctrD5vx6NEgfMzP2dNnj7wopH9P5YsTv2XPW+bKxUp95K5KMz0Ea2xLmrZhZqG9"; + sout << "HXhubQS90G74ibC9p5nvDts1DoIC5swJpo6ZzVude5M6GmzjVUHTaS7LSjIbxQBcjyFrx+panbfB"; + sout << "TLLL2VaFjfxw82RGjjQtaeCeZt0BsnkfL1A4Vb5ItP+qhayO+9lqnSXdECHa1qFTXKDVPvD+SQC1"; + sout << "cXQnxXvCxuLmHf/ao8i5QFXd7tHr96ioCtng0PWHxbXyr3NQhriX2zbwD2kkrrL4udvZo2TRIe3D"; + sout << "/L46KWG8S+zpo/iXc68pt4sjfuc6DNIXDtx8xkFAre/q9dRKmuV1BrtlYiIBQlhDByyQuqD4Cdwj"; + sout << "1QMd3JkiF/WXCc3vsuLH5v54EMHM02MkS0TjzLjOinkMT9s8u+ssSyp0NUE83XRdUoND1Kw8DbI7"; + sout << "e8izxcz3VIpR0PE6MYUKS6SR1Gq+haGlgnZYxBCbrhSkTRn0VQ02rnBF9An6rmXRHbqs1rJZ4yXX"; + sout << "9/jj2dPwYrFW5NMXL35Za8fTfG6pXaDjPnqJokXt2rKX8X2/DVWhYQwBLpZytXTc9RpDis4VA6ao"; + sout << "0ZG/u1qgQeqsmhkwTayh6S+VFpwrQA5upwIL9KI6JdbCCsQhazF7qsEDoc1GeAntjyqntc+mml41"; + sout << "I1ZCYtP0sSOWfWiQ5yZavEsrp94P01thtx7s2UsrMtrYQc5hxs/J8VWtIIpg9AIqDW7g50qsBv4w"; + sout << "id8XQPCcfyXQszUJi+w4Kubu+MPwAMiAVkDmlJGAC9t5oXIIwIbvNZHrq3C1ixd84+VqhvZZgw+/"; + sout << "izIXi9IrhvkcOJmuwxQ9FCyd726/ks2e4H7TafuTXyE1RQ+6Ju3odB4meFU8SIylbCw2jwoI3kTX"; + sout << "XXqHbAcdgMWoxoDXPZ0XzVqC3utfaOARHji4+RXEseqA3tJhS5CPHj+D+6yXJivrVC6KWSwPqLyW"; + sout << "6JI3rVeT9B94EYtGH6qWvcA2GaSJjxiQUZQyqahHk/FxSkmnMTPwtvV6AcmYZjQESUWYI9UQyHvy"; + sout << "5e/ULXf2bVF3dmTWoKOsDchTNOQ/Vtswf5avr9tuz841NKT+1fojuuAhhYP7uvRruvMJqebeiMIL"; + sout << "QBzv27ji9+tQ6D42+Z5U40UW5XtytFKcGR1NWtaxNkWCPMGRLcC4mFLuQMYq8cK3ZZDhpzyVeyIi"; + sout << "nLJRFigzQfHlBeJmXm8UccbJYo6oruyFUEagjILnb+fb4y+uE7vN6WoqPcg4vvr4RzSkuHueZxHL"; + sout << "5ks+O6XyEywn5hU3kqOmI21d+C/HBRjJHbiL39BcHKdzhnHJ2huEtwqSV6SsDiqhxSUL7zYEDECj"; + sout << "zs+9/3LHJ/9TbZZ104stYMWXIWuLc2NSS0bF5NHarE4XedBAcXaJzKzlj2o6NYb4Ifxcs0shAxl4"; + sout << "6ptwCMOtvcTz1t5QB1vG5CHrc7emzriP7+H1tukmCRA2i7cZPumoiw81AM7SPv/A/9yr5EFRdBmv"; + sout << "FCOwJNDKfD9Mfb6U1Lk+ikWFR1wm3MSqq4ZSy1Q1s4F0npfVh1J5nWZTxd0Z/TqDhSAUU47cC9do"; + sout << "qRMBg/bBPxfrWGeDEOLActWvxMr2g6O4tdWLxTtWcGsGiatS9TahYclwL40gnzAC+4g1yOgc4/u4"; + sout << "Ye0T1KyUWUnAfpWmFibwphrKOrVj4ZBVzahAWMG4jWzDcxTaKoZvAtQZi2FOn4V/+bKMURZChDl8"; + sout << "DZpqrBSfSyY7DnPiGRgxOmTKK6/kb36nlas7zwcQvSgQTw3vPJSU+LJg4gKmDycFpl5deILsDRBp"; + sout << "78NSCvjJqK7UVNfdAG/2C9nQHNk9nHEf3jo4Er7ogKGkWOHshaIvzL98Nz8BHbr115gtKo06mD+C"; + sout << "0EJGtywTGXAzkO8J3T/gEuW35KH+kOIBMXzQcJNY1lNud+fALQxNv2Rw5u17KBXqLtABWVaOABgz"; + sout << "IVCkpM90yG5CwxkzyKO+CIxb8JyNs6Shk+CPNoiU8mbqMFigPgXw+yWEEdOMvTdzJuj5dKX28e+I"; + sout << "jwOz6wcXPYod1yzF0kPUlol0XaDR1m3abfEiTwquKsTY7upWPjU+Q4yUnrP0/+W4Oe6TgH7JvEvy"; + sout << "fhqLicE67NkotZ1/+AMZVDYWZOScrGmi8c9Yj5dg3lWi5OtfLYg74TSQ8AOiH6jL+wLc9J+QwOaL"; + sout << "gibI6x/dWW6eJGHxInVjQrt0xmP6I90cDPOnx8JC9tMAEv1MrLMt9VKpjdMut/Gok2doLdEHyTKL"; + sout << "wBNiLGq0viiPXgm9/sbi3KGrWhOdypAbECm0emMVQK/WdkSFZu6PSArdOTvHthXOvyTb58fm7wS6"; + sout << "oWvK8J4f/T9t9aNb4lBAYFtufB5S6ndUr0hEU0nirxgfsXl9ZsRpsQidq58dK6l6WI5CmUpYboL2"; + sout << "6T66Tndz5jnvtPAybpIus/Kbmh3tPsaVtI7e6IRl5mzbEngtlOqfbwlyJyUiC5UuGJAUCtqoKadJ"; + sout << "QKGtTuLv2VqavDmNIuZ1mRX7g+2Kbf6y+7L42iqivvX0Lmo8YSJFTuNkWTNlu7s+yYH91uC8BdT6"; + sout << "3DqKRgiRWVfgKuuJvDp72gyJknKesDa36+iCp2qaZZ+yU0XhK/6U2phZ6XeNJlQY5H8JD8pe+RJk"; + sout << "qgi9XqP8qvLIIv0RkfnDVJtl4V8pi6I6UiaD3dNUeHdN1YHCu+Ub225CFQIDsko0Ly87Iv4oULoI"; + sout << "1gmb6dLkogavqkQfz69GYinC/z5/En7ekQ2TnOgFiPV3SaF70p7RmyMnn1/ON9r8vd4NfFPXXCCL"; + sout << "NBSihPK2Tt3T5i+TRXipLYq4+kghSyTRhT2YRIbpvT3yYs1EGbEekCQQmIltN78NpyWqBwBks9Gt"; + sout << "xpIS1Mj7Ff20DpfjgaSok+XhDSDA8pjNzuM4Hv55IJbF5se3Nj27glEhioHCvKzx84FUuInAjQtT"; + sout << "0+5hsgLkTppvJTdSUg1hzhU7wa6NixwuZgfUenls7CwH9K5WYB7KN6ls8gzbK0V3QMoyZBRIoSuO"; + sout << "7yyzgY5erfXbwfcYQ2LliJQQyYyLiEcm9HCxJGMuBL/41LQxxm5KBXMZhJx4c/S/54GJmTMDfU+3"; + sout << "xceNujbOdtlz9ahcxUBPGOU5iTYBGVUPFscRy0nxuHhOu2VI9rYBhoIUZs01Sh3BDgksu4uUnGof"; + sout << "crDdCqjpR+1UVnpoEmxbW/Z1Ux03tiG90lCn58/g872xPsO/O/DDX4wfNU1yeibFvWrW57k0FND/"; + sout << "lNNeLErDW/nBpQSbaWal3O8lfmjD2SBwk3xtJ2qXx538rFZRM8DQJUOfUjvhycRGITPP4GZBuiE2"; + sout << "cVpeof6Vqbe6J4I3E9Bcrnm3moAyKGiPDJ6UqthQuILGNTH2GOOXAl2I1X0qcJ4G2H5xgLNliett"; + sout << "T+uB6q/kG7Ab+4QNVm0j0zdPdVZdUUtGceMJViJ+ZoRVjunK3QuzwgBHYr89OqNxvULc7tImHva3"; + sout << "xan2frwQ2ZzJy8hRCIdaNSNgdZp4WONm+5Bw8/MU0l3PjjXjtlzjUE+rHSwdv5abDy81XQwkDP0y"; + sout << "CS2C7MMahBsySi7v+PrLljFQbyG+deR6ch15UfFjtyd6Mctr8FcW2umvPYSnl1g2Dl0jE50ZLOHj"; + sout << "2gRAl1EvFXIuPjxcA7gsrlATb2SgCUtrelJSDOJpf8a3WMKrLRYp2MONLXacFRpkI0iU4fNhbzgB"; + sout << "o/CdP3O/AMGGnaZN5kIkmOwm+naBEF3mv1xFE3ZZ5VldoFrWrxZ+9s4lOAScPYzVpABjbJxzOzSo"; + sout << "fz7+pEOY7/OY4oNgqWA4MEtO2pMbjUpdTbSiVuXD87KHJYB48ffxkA5aPKFWpdRFCJHM+J2Zjkaq"; + sout << "ZRRUeBZVzo2rRvFBWrpbvAkkaXEDk+k+N3mOn1LF78Qbm5FWAn8aw8QL7oQnmrkEqUj9EAiYTx+q"; + sout << "v3oRQ/x10A2LctzTyy0mCuM4bR+WQW2h8dGoMUhjJpGPQzeR6owMDi1n09T84aVUDkxfR2PpIHxs"; + sout << "HBag5ryae61mzdzqNmyFzySbIbD83r7ML6FqMCW4xpWF6ru2DvNykb9GYL/kquN1rjNOe9XYEQM8"; + sout << "pyYvcipUqJdEeYj9sjgJBXkuDr8fCBpmvJuklMOeuAwAMXwxL1b53VM0XNgVokPmFTSF5zqesA/p"; + sout << "ah/7f2URRjrP9RR344YUdzGoccdhpNnFQ+g4KRdmd7FuPTJYePyDncJFnVZNT1TP026QrRPzg7H8"; + sout << "sFMoZQuzZxjLy8gSU3JYcBCnfqBdcvrtF5Wb+pKkEZxROE4qURll4Y7zYVxPv03Z2PvA6eovQyap"; + sout << "Pr2jTJPFgyEiAtcdsluxtc7NvY+B0Oxgk+oUyaXAhkihKKSGT2pPRUvd0LqDYU29/7v148UhItLJ"; + sout << "minqiK/TinqUPHOe8oyljHRIO6awrl1oc7HbFs/8ZrK8hd07LG27lMDcrBogEyUZ+mcBRykohy5w"; + sout << "vEPLCUA/hLP4NHaUts+q33Wtrq5n0m+4k79SRjwuvmcSXjf+Uf09CQD32oicQr0UEfo8uhrUPwta"; + sout << "WukfxXE/VF/jfynSXwEomRZopcNoc20gWzRlk+M7HpjY1X3TCi0T6f8qAJuTjSSbUZTFC926jhsk"; + sout << "a7Z67jTJ0woPzPFjK41KLWQcoArKTuqOytbd5fKao7qpxiFnUPd30dq6Xhpnt8oW2QESiOPq/4mT"; + sout << "Bdtd+wl8c3bzgxRpOGBKakq0iI5Vc30wc3vTUw8Mpae/beCHolCN52hJF1LAKRJT4xYWMHLpKwNl"; + sout << "ZqpKMybMAGM7O5X3xqcE4nH+bPkpkMILmNN2x4JfWYfy+667pQdIed0jEaDRHueSfyuXCp+Yu7sV"; + sout << "q7CTe3bwrUSklM7zpU6NxuX++oMNuruKzZ7GwsP/NkepKp8MGUyoT8EOz7yIq7k9z61OLxz0Ylxr"; + sout << "0irk3K8oCPttZz8y3u4ZmkAGcYOs41JR3Lu7D+/V0hjdGtNdb3ALJON1AB7r+OQ2LW3SEP87POnT"; + sout << "R3wayKImuVIpIjZi+yBSMqomZMtH3nExOSKTHzLzkSyefGR3WTSIXI5GqtWNetBTk9y2eeRzbSu5"; + sout << "+7c0XsuaWRQd+2qfns49HNwQocfbM5kNMKa8npRCg8YUBMimQiihwEgM6NJQWJ3ig6vwXncIOAYI"; + sout << "kwcYWkLBme6vSiaeyLRG8y6A78zeoc93lGy87yjRWYkr9CeARb8B3tHX4E1APF0Dcs3/6MTZhnmv"; + sout << "eMdEj9TW/MFATPy0BELuV7nLBpSrH2Intvo1jzPU7dGsw/36MbHr2LFO9a2PrZkMBCixQ5kApER9"; + sout << "HQCKGangWrdU/jO76/8j+mcrxFNUEc7y0OUUFK8ZXwHMXQp7MOew6OVUmRuFSl/jv22Zdz16njaZ"; + sout << "Vd7fVrgOdTG+EXRseg0LvvzZ40qWpRahwg7zqP/sAWp1ZhH1yJVoalrCFuTLgabjWGo5Amp4vyk4"; + sout << "zzseJ+O10oZZrWsFOg1lbPjs4j+U1IhQ48qNS/zk4Y2kaiM7/0q7PkaUxdwdcdHwbOmp7zszAyfu"; + sout << "rwXlt1PBiRyEsBKKKsBLwUzP/KJEoRvFvgc2VgzDxxe7pOsX3E6+ey7gKt08W7RNTTGx70YF2uVG"; + sout << "4cEQHQNqLygscv4DllKnH7Jz+2cHnSES9us1PaWqRwpQ07XG/sSI6t9o8P0xMBVkR6XB9JkdKP1o"; + sout << "Bj1MgxVN8pMbv3ANGQzTbmCQ5IczVK1Io/OQbSEVlZb8ZOatdQ9nq2YZvedMelM3C+t00b5AaQa6"; + sout << "DSslaJ3FWzJLguvwYQgYUKvJ+jOC8dMr74YNAxt06b5RaijSib2TNDjnZ3tIySixM/dYOnygj2QA"; + sout << "gubdj4tuvNpWh04VszgcZd7TbLntvd2IA8b+FDDzMkyT067BC51elULM0CWKoZgaaMglt3LqR43C"; + sout << "H4TQMRoitKUyfjAPE07UmUi6lEGYcrCyawT+Si9YQy+T5vGMpryrpy9gaHAmFQTGCJgit/0g/c5k"; + sout << "mbyZ4pilxX69VO9VHrc23YN0CMQ0W3pP/DpPNYQTcW3jeG2T3d8b48B+iY9boq4ONEGcx6y4wWok"; + sout << "SuOSzK6rTR834/xqlNGMLNZWSOFjLdq7G4oQLmVZVRIkUXrQImDl5W+1EjdglFE0jAnnuxjT6+DU"; + sout << "B9VIMbIzVlVId6xVrp6Kg+Fb5gcxGsKcpH83c18DbuuIoLsywmIEiWQfrSV7bPVuPKJVof/t+Mar"; + sout << "Q32TiRHMO2ZsUXC8w46fkI7EsH7mWMWOMfhkvnY3AVoAj2i6+orDHNwnTmWQiGj7FR9Aw5CQcCbh"; + sout << "W57bPF88lkQMxdEHBd0Hk1c8Pm+p3x71xzdIofXjTnCd71R8FIIvoduLv6XQttPFatTHY0CeRmjI"; + sout << "49yWBFJ+w+BSEP6cnTQeSo7uAOyit2S3EdLudn7XEBr9eN4wIVAF5lhti7HqJqV/pfqpkgaAQeEv"; + sout << "mbTG0L3+6Mr4RPAKaa1H1aQXcYT32jdC7w4BKaxvACcyjNuukUwhi8GxxXxIktPrOJnwNCs719MY"; + sout << "TbFkKOxloVafybwRZx5DdrEKMq+jKyCffePTaM77bvL1hNyK/78Q9TdEO4hx2q0m853zlnibtbwM"; + sout << "YcO9adPVjghDSU5amF4Ul5+rMxroEoaGKLW8+gL9DbmcW6Uj/vg7TRQohnGbzmKwPWT0hu22ID7w"; + sout << "+ILYIhGhG337Pr0Osu6m80nETubjTijMQHAb1/Pwv2eNOOdu8X2EM2awT+qtxqHCq6il4YFSbUdi"; + sout << "+i8vY32x0cXZ1EdKfgciuKtK180q1S57ZNZTCwe26z5eLXomvjtxLxBZl5z2TJyPEnJprnZWNOo7"; + sout << "5nuXsYypcrV3W2KlRmybUjRfc+/6nMKKBtU8U2phCUg1uVbzAXgyKJuGRmwxUOXXWeiS6f+X02yn"; + sout << "ZpJda7aMY2VBMEuxMRw7P+S18C/drO223DaUROCv25mTOZGKRI8iXQYEa9UxLvXgRfUB+eEGOuAS"; + sout << "7N4NVz2kfzGWlQdKqRoh7fWDkM6R8J3BgWhXvl6dacJcZyeRQGjpR0ZEKDFTtLK8Dr9NFYXk6Pnq"; + sout << "N9kMrQpwK3bDmdv++XISq0WHnf6SBJcNhSx74TUlUDBxsYAcoFefI1FNbFFobGmT7lEX+L7uQNmS"; + sout << "/TMxpzGwSeRavNmjkanqTTZFpWNvRXrfsod+TmJPe53ApSSZpiuOAedN0/3AGsVuqjNtcxf4DUCZ"; + sout << "13v1FAYWluQfSXtsM0weQ0q6ZvhM2SyX3VjNleA0/X/PPmOLLK423iG/d4m5+D95t3NRAbU+TEUp"; + sout << "Nff+BWEiVSlWCTwbP7mTaJPnAUWND7rjif1IIrQLL9IdyauKUOP04yMgCk/LQs0HsDkLxSexCMHZ"; + sout << "rK0S8Bh+/2V/dNF54MVJEqln+Urz2REIGmzqSjly38FZqDwG8Rn3dPYImGASFgGgrCAbYza3ZyEb"; + sout << "yaPzMiO8wN+1F+DbMDfvMtlfkY3fO8LvJ5csTKK94rSw2erFACsU0cXqThvGma5gOcpL1H9CwUsa"; + sout << "zKLfBclmCCGrP4rX1/SN7OBpNjbrCHZnkEebf2TrzxT2moacNtJt1cwVUVGhiJ/DYdZawCRUE8IS"; + sout << "IKDq/tSGREdGSMLEwJdXyRk+4LyVF1V+YF3Jp55yFh5skh3Be2lP5YY4+YiaZDb4IjJh7qxjjO0R"; + sout << "RmvES9HZTw0GhCa3ep3x8dvbHTp1NB4//onnMx9JheeyZRZOdEK/Suw5BlXNgCvc8UJJR4ZLtyuU"; + sout << "7HxX7orjWgkHNdix0VnJKSFYfn2uJ0LTPrz3sARjeVnQ3W0On0DIXqNe4thTZDwUFWUKklmQGV2d"; + sout << "bgVLPmuKff8DHP0xgGyiy1kHollAqwEUzD6bW3Tkrtu2f+LzTN2yAQxKkVsiW+sqEKqUtzNoXVQD"; + sout << "36tnDjAFotkirlS5Bzk5pXgDpbAxFnVsTDq/TcT0XbIR7QvBd0bXq4am1uoZjaLGA6bF0l3Ej6Ub"; + sout << "pB0ZMs2nLo/8EHNIgx4cBgXL2FN0DWyb2CzchL+jsPU4s7xBYuVvGA1bLaM/nlRV7DohA29pE5jw"; + sout << "MXQPYIDzraGVsOdJ0ghhlqqnwPgO0oCaGDxIlBehjzkJcgXPhWarYGGBdk19SUjt9y+aqDQmxShW"; + sout << "PrJagqBImi4S4N99hOHnIxGa46BFlk72ddnCvaJ4n5LahIZXXAhqOrxRiCKGKBX20/vzOh2ICqUn"; + sout << "MgfLtH5T87Mho0k7hh5gVFpYktOqOowwqPJp7Q7Qlwty07GrnkriILYvkp1bftSHHyepPxAAgaJo"; + sout << "jeyuGTtsa+K0jHc8WaDSZncIqmuKs8oVUZ/1YD2EkVlVXIh5NjyDc6RFfYhP5/w043WKp7xqeRwq"; + sout << "zaMTKjrUr02omLUduz14J+ka9uUErd5O2xPaV2O0LpLuZaMjoZ83ZirW6U6ipV6d7c0B3kC4CcVk"; + sout << "kUdzqBoH3VVbTav1vugceepwBjBiXsRqReTJBzNRJK4PPu/ApBKQKp6jyZoWCfhJiuJNB1sEjCbc"; + sout << "o6XEZWbBEjqRaCujvVzo/dLLGHUUOXyGA8sKJq/iAM2dqvDoaduLBnCsp2oBIFw1q0s2psDRLdvY"; + sout << "KiTAKEAypmqRB37gV45Mrmn0zYeDR2/EsWDoo1HFu7Kh3FX/y2fL0155dtiBJxkHDJIXj7POIea9"; + sout << "Bl/qheUN9FwqXsVheoacBp+Uh9NulkH01NsKCyX3nV9ec1sdzwh8hj3zWZIpD+ATVH3S9Hh0pt2+"; + sout << "FHzfu7z+YtwgpWfSqim/C0/yfDF1QoRws09DqLTke3TmnvxoRfkdOwsc5y1s4YR72xFXVrmIRpkW"; + sout << "pqHx43/MxHaQ8WrEABai43+FqGJmR70fvCNecLDPm0ZlGySkOCrfgMH9lK018QD+DRu5D3n6alPI"; + sout << "3bEqruPXlP/N8eNnJFEeDbxQ0aQFJFxSOmTkYWr3LZBjtjX1Cc4qujxcCaUH2hnKvdyxejlt8fEa"; + sout << "h/I2aSaD3ejPS2xz8y5D7Ve0mxx6znMEO8E61RbXztaIRzze57GvfwpTa+NbAjtX43ltW48u0yeu"; + sout << "5UM25BNbXDdmfUzzVFpiaflImZj4Ga/k8TjlTVURwRXX1GzB7lfdEukFRdQnY6VQcsjzo3O0gKDA"; + sout << "+ZUtku8C2JpdWVYYsTVoyR+5rNlgHdaGqbqzADSrSqNBBs0OLaEH+pqoVvN7gNNk/hCHN7tvmi8n"; + sout << "LZQHQ+0Ukp9p3bPTUapH+h+NJ/SXUHdjHrJpaFJR+ANmN8a79zacX7v09ig91M4yRTSxzMzsbQef"; + sout << "xnNKm4+2JxpFLH2fmM4FM+Rx30KcqLmUdBGOaJYNHQqBVJYYMNVfdRCo/KwBX7r+V9SRlLQ5ijED"; + sout << "zKbJlHi7LyjYOq0iv0R4MIiW67OR+5O0aHeoXQ5RJ6Rs1Xytaiipn7aXphPVPo+fZ2xH6nLTQfFq"; + sout << "fEutvtOrYIR8G0dGUSYsbZEhdPC/gpSdHsO0bOJlxgXhNGNeOFtYohwIKP9Z7Giv+lM15GPlxnm6"; + sout << "usrSaqlZWs75RuyOn7etmGBSg86+J5wE3ihgVXcsprkgs/y5mnP9iPifs1UBfDO+2L6UnmceygSm"; + sout << "sIOIuMzGUkrM6iOLYHtp15txD6bUYZsIEeCCnjaIJujQc0lvbU9bh5AVae/HG+SfuWqE5nzp2TD0"; + sout << "jug6A2KoPK/ufGzjCJJ6VIe9v2bG3WXApQGWyj8lWCO+d6YtKFA4agCPTv4FSnS+Ip3Bws18ZKxj"; + sout << "M+AFJKpr7lmSJURUl38hITa//w1TKSA3vSoMaEFGEH8OFkOCDFOHoGK/z/ievfyBkC0/WgvwcSbj"; + sout << "uvaOd+M4umm/BCNqcPjZ86HSej9mSb24Kl3Rn2pxywK74b8P3z41Z2yVbRIjK1Dwvo6t0zPp8vtD"; + sout << "QHW/9v2uZ0VhBdbRAL2+qFD6qKg/sacjR0kE/o/zRR5vsW4K0efFZ8uerG5hLs4rYhZPtiDOHr/M"; + sout << "xOGIfBpL30xzMdMs2zTnbXF7wECjcklZcwlefbO5ubaT0doKkVSZ6YkCQqVXL9LNVNHqULonUgBs"; + sout << "Xrbe3KKGFrToxCljsw4+3+RUVnNN48CklcKfEfTHEXoUNZ/p5msufJct4JL4K+35Ecfauq0GOsLX"; + sout << "sTkVaLfkelv8Uiju3VEggHb+lLdrY7oKT7aC+kyQu8SjtJQY7b2aKQNELwrX9u95mPNZkySc5vaA"; + sout << "xwL/6ar2Y5/UA/Lh+5d4hGSkQ0PbSacBStFaLzJ8AVdOcXihah1a1ddYNtfH6kNsck9XJqz9Y4G0"; + sout << "YdFr6DGHNLBlu2ksVL0ta/tfPZy9BVFo6Vz3tbpnBwkPJ1vo9EZrJxo5GJ309oeLz6mYG+0ctXJr"; + sout << "PC90nYz7s4RxXkWZBzq3RbNiBGNgP1iOU3x9INOf1CbKBhNSp3b5QJDfM+hVzL0Xd78pBpO/XWQq"; + sout << "hdl2dL5XZtd3XYslKjHDaZQyZgnaoWD9hhRlk9lhweUSB5bFjHOH/FZjlAH8et4Eri7N9Y/OFHFP"; + sout << "+a+FoLLT0YiBGFsJiqe9udoECrxUR/A1Yi3Wen2xOQv+eMMQ6IcvA5xYONHWrS8WgMT/I7oELwnx"; + sout << "GK3kwjKy8lQJ9HOmVa0yuQsV5bbrtjQcP72TmuWKOcyNWf9vHvLbWSZUDEXchz4HfADCJEehJzh2"; + sout << "vIYiVvotktpiC/MzHmHoTw0Sf/hOT8AALmyPqVFGrkZHhLU4GZTqcJEzBNUvkvuwo3qMHzOc3zfk"; + sout << "VtoQHlg0pOcKczFo/44eMRlpf0Cn4KTcU6vEmUncnvx1nNMHEYKQNGp/IXHpXp+PN7s9zFDFbQHY"; + sout << "iXURtZZ/cF6jW5TshrT/CgpldM4oi/gn3yTm8AJoqATWeF83ft0AphzHDfz+B0qf2wF9Te8adqcb"; + sout << "ekFiLrNf0VfA/wMCE3drkqLhZKO6L99EK9SZqsly4Q1rwLr+sIgOrI5v5g82DOFpY9zL8Sxvg7Gp"; + sout << "03KDCEsW/2eFfxNtGGZ2P6JJOWLIe/6Zsx+1F5lcbYxPhSirbDYp+1bMZvbm4uLWHttHS6e5NQH9"; + sout << "pW9+qFTlk68lJWtpI5Nzy2CutML4mI1uVDgIctZLSHZOTlWW3LICiFxYVLjvE7JExUZMI5J+3VZx"; + sout << "OTTeFGiuzoEveGJ34tOMrHxoj3dIMkRN9kh2dN10crD8MKoo1sb2EqZvhKu3yb+XxKKl/RkHTuBv"; + sout << "JZjCnoTnGa9JltcMDgzO6demBWSN9950HXwS0cZRcDIgG+eDQnfSfHGGJ0UuL9ydA0KYWYAPxVMt"; + sout << "tqR+S8lKMzk18k7FuGBsFjloOSgsychm0GSTKu+AVUJJohpnT5KhE9Z7Y1iSc1pBOxyf4iS+FF5v"; + sout << "AzB7BEivwSPqcjPplzSQ1uGNZsa/mCfGy77XBBm5wYUpKDshd4l8mOHPFofPW9WrpI4X2e3LuiMJ"; + sout << "8BijYVutWZv/8BrhhFuUnngfLsaJBhO+KL1m/JsQEvWIJ53jPt5qgEOn/4Tei953J2hanoVzKelM"; + sout << "l2OhgHFUxvIYGvKMwLRMPQ4kUH3C1vvBJivvyIZ/JWcHRq/0CCjvg1aPUXSXhBrgkqFwQduoNW9G"; + sout << "D3BowPV9v7Mk+IPddFxHdQeyHdaquVKxWoWHzlDnv0gbzlv/EdZVUd7k8rbqYgUdPl9os0dbh8gq"; + sout << "nRpOW6MTBdFojgzrdOHPhdHM+Rs/TIntPPd2YmryOjdHhAyrVgFTS/rx+yDRrP+ChjTPAF+k8lG+"; + sout << "qvN+Z8gDFq3B2vfIgAUs0rbSMRfOahM4+EH0pX7fe9ylY9yPHtWKs9i3mU+uVTgB5ZtRKuSKkpnO"; + sout << "MgWOp94H7Ui/tI7b7hAC673dprCx1PZgdIncutPRIlxzv8exEVhjhH5AsW4gRNxLX9pukC5fVdXF"; + sout << "nWif4xcZbEcsW+mm6IV8NE7iokkhEC0bahxpHeS0H0RB28nZY2idhHP6hLDMkSg9b8mlJgqcGBtj"; + sout << "Q2Vd396H38VydtreELEV2FgC6Ski/PFP37rc3/NDStzcY7d8twoEiWa1msraDn7zlvORuk+Cu/Ec"; + sout << "VaHwuWQdyqf7dhOhrH+/wK0RTPMKHDyLxkwJmO31Ka0LEOvqMcsUzEaxAMUNz406l0r6pLW2MbP4"; + sout << "AdGu/pH2HwxZ7b1ps+8JiOu+U3SDrmhMZH7kS8LUhCNpH21Ty6QzgJcaw2jC9HYUoNeobJSQaOFu"; + sout << "kiOHd69hw7VbD752KMiVur1AYmRegfhkEm9850zS/AjT+SntphKTtFNsc5cR074M9iucA3cH6Dv2"; + sout << "AUSKvl15U83cbs/B8NA/4iSp1OTEb58O97rM0iLsWyBLSq2Lx5ZblFZeLi/rPkoVGMo/o9UV8pNa"; + sout << "Cy/ushHj3CMku0REQZQfcxl3OgQ6U6avNywmwdncjoDCwX2j7MwoZ9SrAKvxSedrSkdRhMvVkuJO"; + sout << "l9c6Mp86g8KafJQIp2ZS7CMC16uoiaUZCqdcdIEsKvzMreCoM+XGxXE589e5BsfV2gTuxpvnWVCs"; + sout << "8BaKUgLnLKCHzyioUVC7bzz8Ov1/Wqa1LHzsrobPw7DV+dIDHqAS0d7TRra2q6rXAqYA/0P99NTp"; + sout << "uhJkixu4ACmGbuBO2ob+t8cdczxIBSRPHHZNCXQG1t6A1IYpozEWiFjATpoD+gEL7Mv8AZW7AF3Y"; + sout << "wFQEt2JapH373jpvJ/lAIt33yTMgyTQHKQ6bsTOWpK1EK8NvnOSOeiFfmtseS85l33K8vl8pyXBL"; + sout << "Fx3dqms9fdbwhUmUWGcpzynCrKyV6DVGFznQtpr4QCl67Oz7RaaYXGe+lOmyrBgrnjKcQk4ue99C"; + sout << "ucZa94y9SaRBTFNbXtpRm9jjdzoDmtd0JPCGzSACGikNArk4gC2AVKMJpFdJrS3e11Ew+pprc2uR"; + sout << "fq/nf3C9WQ9lQ+KbpwEokeXHWNmWBwDZwFCf2w41zxYqjTu2XjGt27ea7nleb1IA1PtQtcSz/qoC"; + sout << "6A7G4EFk2WE6VNL7/Wv81DjrdTdYy8DFQBzab3Qc7T0+T+aNZdNfoZSj5YK35BB0mmPbxw2U4sfx"; + sout << "1nahYLG8OEQgoOVG/3vG+EmA+iCAcCBHPObthul5KdW5JXwPw3AMgL0UL/3GIpXRrzE4JTPrFfzX"; + sout << "zFmbikWuDkDw9sP8UNfHAOGWE+OorN/Y43rcSfVA6xNgMhqr0pBbvVTIXbeIy/nR34IJGR5xN6YS"; + sout << "G1SriivqytwRmJLC+Z4OoSVtqNPzPvUBzs2wHPF00gGAE1FpUa/CfxOfXhbx8VoZWMcuysBgpRCg"; + sout << "Xne4u7qvre3CRkuXLazM6qKDt47OA1is3gG6ZA/XCvcI8EA61YccN6zTpFNT2cBGQNQ11tK5jqEV"; + sout << "JeONoq3L4SCF6qCiCKMg0nsc6sk/xJ5Tam9Rl292lzxTV1Xa3be+lfSjUdiWnmbU4ml2Y7FEVbWv"; + sout << "4A3WHted4mKb0HiCz8wgbHrjTnNQyIntFcXuNAUUzri4nGsV128EuJrPmbYXqWuKIe1wPT2Lbco+"; + sout << "7qH3xg/zxl+t/bqlyqghqoYQMPw1FJXz7Z1TMI9TNbxUtCd+J+mX3xkrYI5AtCbTZh3lMvbwWJZa"; + sout << "4AOY4WRavT6XHsBs7hKl0iHOmTyMe1bkM6sg2BChZ2B4T3BoeFISya1Qn2U3R0Y61plw8OFrTsAT"; + sout << "FKgD4pSc2VaV5bOPM/vqg4B1In01SOw4itnoIC9N4+IvHvzQFYzeqiueAcHYGss+uaTMqqc9c+gz"; + sout << "cgiJCv480X1iiauIKK1Szm/06GiT5WWnkyqWEqxmNyMPrj125XXg75AZeRDazTWXjObzZkJ7UA6i"; + sout << "YQRFCFH0enKb75qIjHQ7635m/Xegoboat9VCJDGGVNVG58hQAVMAWHmnIb6oK9t1aaqXmAYXXgFo"; + sout << "3uchUubvOjQXN+txW64ZRXzM5lAw4TklqmhDjxruXRdddoBONkvSPSh/XiDTkdmO6QKTCh/t6NLo"; + sout << "Jb9Z7nfmQUd4adkS/+aUfb9LNOPd0GHxq2O/y9XXHvRTBASUV6xlJzpnOBAtTJSDrbLj9NIhiWpF"; + sout << "nsQ+Nq9aEthi7RHvc6rujDGF1gqfJjsdPUTxveNk22Xc0fe4Qm+07HEg0HvhfFWJdFJJHirD5j82"; + sout << "bZ+QZTppInMAQRh+dxZhlCzPKdqS7WPS8dpXp45xBpYH4GsG7SaRgbCBm8R96j5iSW3oYd+ql99U"; + sout << "OC5fYGDbbGT8zl+mtYi23CAV6AY/izASsZHq7KzSV5taKzEZyJXf9D+wRaopIZ5JPPtmi/fr2HAC"; + sout << "XyKoAPNAywLXUTtDezuKTJiyy/DBlHKxloa3uVTeG9YBJfiloyym8p9GsCOcz1p0s4K1Oyph1NCo"; + sout << "CyklDAHvlMNjTA4hkWIHTyIStVmgvVkNN93qpPYYNnV+fIETW0tvJDnyJ2OnHVP4h4VBWbntk+uE"; + sout << "JDG+DLleC5n87Kk+e/lurPTik1Yqpicwe2ct0AgA9XJCrDO6CYGT7QzFhEujgOYqzUmFe+KGn9u8"; + sout << "4QMY5HbYXY1P9V3pR/QyoyxtuZNU/caTzRM58Hxe62q7BHmsL53db7cDUE5dYVR3Lx2uK/pm3QWH"; + sout << "IIoyUhcLmk04dLwoi7YAhnao9IJQ12KlFISFoa/3WvhWbb5M/GttUDXGX1A9JuGrhjzXMIfGyS/u"; + sout << "wugpWw/5SnWaCis9sNGmmMNC2rniqemMrjAUv5mGPYA57KyKzvVGGKK7VM8qyrtwSKOQ7fWRevDE"; + sout << "NVAwhR5NnzdQmL1x2UrFcEkARtk/8RU/tfy3nXqg6Bi8g+myCn4hH4BOtK5E9GItI4R68YjNDVBR"; + sout << "MiW9KS5gXcigTV1fXrChyjn4Enh+GrdLBpWDJQr7MKSVSmqGXntmhyTAD8Hgi111whRmKXgwi7OC"; + sout << "k1Jn/NAURbtbi18YCjCc/+NfStTpuMv01C7+iZSQbO2rzP2aqaYSSvIx2e3/e1fWfZfTJBMQzcfs"; + sout << "zITZqhDilOIDxPPYv9xAkBxnwGXIkPAi5133yWC5gcNh9hzEq146cKvQ8BFi1RHGL/yOiDjADrns"; + sout << "Ky2PGu1fBfFNfw7tqF03f6zNB5IVWVjmnKtXhM2bIGqI/EM6Qx79bcYKYXcFg8yYbTc3z9KAcgp8"; + sout << "es3R4Rwfg9irMdI0DK1CyYnpoa+zWhUXUzpTbSlFx0zZ0Vfx4xnPh7ZJqJ7/U/QNAU3lUP8LXVAP"; + sout << "+jMIEj/H2mx80xaTNCDp/Sl+QcU1jZYx1Ycl1B4gH2/wsYYr59q/mDD0gE70hQpsE9ODlZzx9hoE"; + sout << "HFRmDUTI0um1u5fLykqAJFVPcfk920vT5TqEZSMX0WvMVxqvN92IUxy2JNpXXMlBiekCZhEKKQqg"; + sout << "D3MqZCUlGFkFpO/zc35vlTBUCuTvbcnak5HVcREFV0yNbTyukhEJzcdZArcCZs0xIbM873u1y/mZ"; + sout << "+dUcW4zZJ4jNSa/Vz3W6sj41hSQKrQfPyeeKpFd5iKxJPXDxV2iuv+ZBgJi584KVpYSVOUA8Bzi0"; + sout << "0kQh/NO5BdU1qaekY5lfrpA+O/tTzkTL1ZslaaBimJapn5KZcTtsZR0YmrE0E5okd4DxUccbElFb"; + sout << "81SuaH/YCvib5kUbkXA0ooelAU5UmmMCZ1ArVVgwlLti65IqtSWD3LB8rIOUT+QKutHVEOpNMOrJ"; + sout << "A8DnPVgF2kmeOngOlOwb8Hgpg9I9YCuu2cojbmROM8BZh2V6E9oiBTeKPgqgZ2tkgqGSf27SEpPo"; + sout << "xHPyKMhQewkcq3WQ132hIb0s96PdEfniPUSig2MLn+5qBZDFLHHuB6K+pl/6OF/h7ZaAaCa4NCA9"; + sout << "32ecTvY2MzUHmBM6CI/QKLRtBjmE09TB8s/5/X++k3cde6xPE77c0/G6qLrOUKFITVb+iLVADv3O"; + sout << "zhj95j7URImm87lmo9ZBDTm6bbKoQtNAzEpBAiGs/VDuFfA5ZoQcwTwEnhmOs/4bT4591uwosx/T"; + sout << "G4whvzgfcbPt9yHc2nnSSrIdYTSAwFZ1Ksh36RgvDXxaOHozLvlduvPaJkYSk9HijbmULXMUcyvD"; + sout << "6IrzOnn/8NY6m++hE3KxgFM+hlrlAp3qQBDEGj5FAYFu7AcCT2r7YB5kiuwjM/7J8KyJYWg7Z6Nh"; + sout << "zJbYleW8QA54s40d7D6AmOBX4+5Mx4e0roOM7WEFoqvXqK0ebAOrgwZ7MmqtGwe+5/bqkfNqg06m"; + sout << "o0UCVDLmF/SzqAsxgIJt1P4/l8HzWtrFINUcBwcDY4ZhV1fbWICOmej8aBMJ4j/iUHLkxbnzQFBr"; + sout << "xuNgC5/5I+iLjiYTd5+bdsITz2aRpXgUHx94KdOtcdEJ7uLjUeo7jeYvQVt50x/V2thmd8zHebNA"; + sout << "6Yj14+3dHQH8c29EghfihdNobQHSyw9eVo8RyVI40oUxykuk65ILT1mLbnASJcKIInvh57YMVTRv"; + sout << "rpPmXm/b5ZUfTMbzpTwO5DThvD+TTGvLY+eC0awfepSJj9CjiBTqNBJDbd5yKVjH+oM0wTufEUgP"; + sout << "Fr0fondoPNTB61iwLzHrGVbzU9Rl8rsuiQOr6Yz/Dj9OjgzyuvDb3zMilAEcLydUFmkZlq3uVCl+"; + sout << "LTYgz2tAQIl0+6SWZCel1bGyRKY50U2bJO0dbHawgGrUgZ/OO2k2HJnC+0+WKZeSj4s3v1OfCZbq"; + sout << "F1nQDtk64Q2X857gH4kTx3MQ8gZCok7M+sd3p3eA43nUuis/SwComUQODyRA+TmuoE9AdASWTZQw"; + sout << "Vz/To25UixKQtov8has8TRMEPOWqtZGHg/1tGGyZKkTbHQlulKwyb9uusuzC+FAQUEtpTcXkQzDE"; + sout << "RNHKa/I1/zTSLVnaADsLgrQA13Igh5w86Bpn1bFgE8iEW5zQoNs6KDYwsRQ7mMiJI75ylVIvfbN1"; + sout << "Cp7tpGCKCL/8xtWWII5ygWogcGhi5abUDxG0WQOaYk3+kPU6w8OPpbaT7h9TDO94XhcwUb5q11O7"; + sout << "LAKIZY5+/7DhdVfMnplHpLNftDihrGwf8r46FUrZ8MRHAREnHSsch5dYY3boGxsFoNLc5r+fbJVD"; + sout << "TxslZ5TEbJCFrBtFqbT3xuLReyo4FsOPZd0N3ZtFXt/ChTutUudGtCLfD3ucoJohZuz0Y5Mq5sHQ"; + sout << "Kb1vAx9sRwpFnpO5o+RKYaGh9QkdrPXqUlcyXkSrXeNVfGGyET2vZB4VtQEAlG0nF2dkiF2uxKbc"; + sout << "6swYBEYWLaftOzle9xXK6NrwQCrH/guAo10Ct+7uLmhK6P5MZW4kxyLi5nhOPkB/jiASFRPN1NOw"; + sout << "KdNeaakEO7PEKPdcIhWjvw3sePbhOyHnmdJGHRB5WUvjOqD553+DFu4pgpUPHx2PV2IfEF2Wtfvk"; + sout << "1J4pPO9DvXw5kBSFmoR/VdU1bGkVfkDMafaHt2diRtBZwbJb6XgkzJsuSiEtzaZ3BLaieOYJKCFf"; + sout << "LK0/Jj4SEEDYj9NIHq5MsWH9PyVYvV3+IpJa/AB/WHXF6XAMOsLR7ulEmViNjv3UYOiWYEyT2v2o"; + sout << "hYYg+6/lbwOg3frwE4tHu8IMfIZZzoFHbfT6MnP9ajoZWuSbbIj5KwvJ27wLUsQPevjbTypUXIEv"; + sout << "W5WLBcBCegLZyIdElFV+1kJmSF5dZhY9RJeDOs2/CpgRctMX8hUGtMjLIAMjm2yzS6FgTtuuPJdr"; + sout << "ErWoYhOWgGSZEitvtLZn6tGl+JZghT2F7dfmMaJH8y86txR7YCaPQgThH1zCh+yE5hvZPXW7+OCj"; + sout << "D/MUIXp648v2+9ItiAlAZ+ho+JzhIA1jgC3h5fWWmEc7RnWvtyiatDFAGo9Z4UCz1qtRqoKdlhr9"; + sout << "loA5N8Q/E1vPO/G3X++gNdEU6/ykRYyAODNhxYTnkiT47qRUzqSgOwsi6ePyQVuLDbSCwcpZK0ps"; + sout << "qR803wEQBh1/nA3ZUeO9g7ZGaybxNoTj6Qs0u5vRNudI6OkjlUMMTdXVWf5ZyV338dihwHho3RQT"; + sout << "Y6nYvFbud7IP9HMA5ga98s8z1ImzGbKGTMqy15rL5/PVkyddf7MArGuMSKWR4F0njVU40ILJh4Ns"; + sout << "xkuzzPGZHb68JdMit3OazsaLvkZ6kBlq3KwWH6BISfiFv8eLYRxCKSfGWoqcw1qlUaFM1MQTJ0Zf"; + sout << "4IK+R9+6/GGQJmSEDCyK1t4jbWDr6pjoXOIsihE/hsJj+Uy0mcQ/pRQpxGc9DvL2/7rA09Lr+Hk1"; + sout << "xmm09tLOMEtd6mNexizTXHotFeN0+M/fMRwjIUY3lZSOHFVhb3AcNRcVE8IhsOnrUu9SuZCP5cLI"; + sout << "xpc66aGsMww2jY/TEMBmWn5IUpX3vNe26iJ8kXZYYqPKz67I1BmH1wOXBaHZVznvMHXBMF7WhwhN"; + sout << "gq7Pf38r16eou+L80O0aAaHE840lzt+jynl/vNb5tZ6F75G6jIPW+ARUy002bD+n/k1lnmwZ1WgP"; + sout << "WnxwFloLan9gTO8dZG6uPyDHL8HaXU7cpty67AFbDMnYnajqpixUyDYCZeeYDvcCv+lqDm3YgKPx"; + sout << "XQPEQQP+qpwvPcwLIebl3F15EhdJuy1gdwt3Dyca13/+aQOq4Rd4AYFl9U8ehlJsrUOus7Q6xohm"; + sout << "J6wNJJzJpnRwNH+qekXnqKqRdTUwcQ8JwcN9u4sUE4pmBWS4/hJSYHpFVaUbMPmun0Dou54UPzod"; + sout << "HFGXXZd76cZ+1MqaoyofxHU95d3rh/47EjPtQInAKAXmQlud1wYLjOhdDwGyvhPN3XfohEPrH+s0"; + sout << "n/tMDShXPyFIaS+WLQZbaid5LuarlwA1rQc5cJ16xtSrW9sRjpEM22MxGmCyyuR5mMhbYLlULR1R"; + sout << "b9UJ6/BbJ9iqO2LOh4geqx7Pl+kvdiOwAfXLP3mFZ8PHTHk9mJMa1TlA/fLSo8O0YqHMTnFe7OtW"; + sout << "qMNzs0Pz4ucJ5uGzptFUXyCqq4vmSS3Cx4eolZv36mgGa7Ll2TJDCJckb6m3lsroRCTs6zF4B2yN"; + sout << "VfTekTPXttFqhi4GI85Y+ZBvMoOIAsvysWCsV31AzKg/qlb2+dOumyMScT0GX5+l3OMO0iqG1QXG"; + sout << "3fJZOpDQHq07WuCisGH2DOO1sEE43+d8fZHb1l7PQ2w25w6hq0IA8u73oDWRyTr5IqY+oDM5d2K4"; + sout << "Ysuzkl+odreYxlUvxCjwgG1UkXw31DQ50xIwpChB8rEd/5gW10lnXK8XTbVgL3fPuy0cd74toIhY"; + sout << "JFg1jyA1ZST25w82oJBYq8f0NQtHc7bK01cS4OM0hmOGd7PDKiRGiIrsLcjgfw/J/jx3mBuEaR5J"; + sout << "X0NRxcrJZN0RdwhmwfRIybRGdSX46JVUCrf15yD6vFi08cwkMxqpdn6fDfh0BTXzhmdJU0+219/a"; + sout << "JFJBFD8iktFAl629X8denKKlc2sMyBT0IODc2OYlS1+xn1t13uEeLar2l/EgG15/eeBsELvM+M61"; + sout << "7uSI7t1szDmeZmWrO5yVcOBq1HFoEmHIlDJNtYQ7sOqRiaJehiRNhnj/9umPDOWCmBdnIXYAxWMB"; + sout << "TJXpWzOOeRvpFNXUHN8MDttwWjx7Osmal/5ZqIbX66+jdQrLih+L4//J8oLyCBjawxIj+4BVLSpI"; + sout << "/jyMUEbzBmRM3HiOoqqgiTFlLeRAM666O6oGhNmrFC/Zf32boZ6YmU0OicO1P1zsaRLJkvMbCYaU"; + sout << "Dmi4reAoKBG1PBqaWEhDE2VmjiPPCs2WqZfll85wyGQDgKKdyboLm26UWwW44oM698E6stYbg8cv"; + sout << "P8eyOR8awT6WwPgOS3fJZFuKTgN5LXJxk2BVCzx1F3gnmMQ5FHhIVm1srDDFIIwSkuP9mgjN9lEI"; + sout << "SGCLsFYdEhVedNZuOWMYCvV4RChI5XCPh0ZRilVbCP4kkj0Xaqq3SbNWD90lwwJvFdIkZOEPeYc3"; + sout << "xfLARYjHqHbNkYZLpnTw6BPH8L4BCDl4v7ndyl5Ef5zvShk7L2VIyKUK+3lboiVjwDoI+k6A7iNC"; + sout << "UW5f3fTA9D/s1hxS/LxaFU+9Z/bOMjiRaPLQ/ioOkwyrUJpsy8/K+T78Vv5XXqtOct0aXrtUydgP"; + sout << "uWZKhAFxJe5467wlC1aja4hHwyEeQmn0P8ilBq2G2hDYGEnKEF3XN7bwyFS9RGFdU06jkKlop8cm"; + sout << "s9bVHFI4NeItHNO48LVlRFQaVdq1p5vD5elrH1h+J0s0BzDauJaraINeTlF/PrpGzdxiA/vq7sB2"; + sout << "DHPBfrxe/yTTu3YTnSzMYR5mvPdjKh9mH6PmzTYyu/lUN780SUyW7mmEqoNT3LcgPg6QGUfg7hqn"; + sout << "HP7ml/WtBRhZh1eLok4o0lVOhX5S5vMkXFa7K2pObSkGyyItCW+Kk9SkH2BboW7aYomw/3tAd2ce"; + sout << "FUm8H6Q9NP8KDpaebg/67KZijw5kCSH9dBi/XzfZY8vzWGPW6Uj6zpbcDfRQDb1jwrPKZRhQHHiT"; + sout << "dU6ZscvJMqSdHXsCNByBtKrUX37K8C59AaEmCdshNlaIOQLvceIIpJ7XTfmtnjSSvK0zJzBWt+xf"; + sout << "RdpAeZwDNK66MCl0DIJwG0aRK0nwuePf0f2k51xlsrnl5nY5rM+hfblfsm2BKH6BH/SYlUSLxQlC"; + sout << "MhdXA3MGhxN9yuHCm0m8WVc4/4g+SCKnp70ZghQORzaijlU/0KS6P3e1po+y554gEqicyXXndnwA"; + sout << "qmRp58+1pxbwORmCQIdgriAZxPQl77hXxCMES9WQkBVb0RKNnbWCWvVdkwxxYHBGhO+m4kZVjFwW"; + sout << "0jxlEcKNa7WdQStRXCap2VDFA0c5kUS7WN6R9HbM/u8xMNLzzE6jrH8l2h8QrLkjHE2R5y/NIVAX"; + sout << "G2+cQp2X9XxHVWQmfflpXPKcDGGYR70ALFrgeRkQQAxfR4WBDLLQitSat8H/q+yyl6hdGj5nmYXW"; + sout << "BRy7URAh0ibk4K/Bs4d2mN8ivkxahOB1wPZS2FFHP/I9C7nikHmQwlg4VlPJTwLp25JDsOypt2sH"; + sout << "J8CK02T35fcv89Dg6GKoIp9RFkJ2RcSUXyWP3JWFxH/u1rdTmOlVP7mAW4MgT8AQdXWXkfaJxq4S"; + sout << "n0ZRo9RT68RvAcrdv2z6Hwh4klmlRw9qFtokZMGLk8BOtPGSOU/Av3k0RjEsqHLiQaRo0JpnBCw7"; + sout << "JcdRwIuqvfWRJMpGF6DhHkIQCFZcTSftnNDKUBLVBn41O/VnUCakShsRIy0a3hjQpg0M/7hFbZI1"; + sout << "wqiEu/aR4jZE3wixAJhXFC86bb59UwlH6TZ1Eszn7eH0inEqCSlMMRKTo5Bz5kybjP8As1pkPS4Z"; + sout << "/A8L6SnY/HIm6jy1MiZJmzdimBLwv14JNZ2oUM/TWMzgnFayE0C1dwiX+75KOrJqLl5hnM/5r+Rz"; + sout << "dB3emOb6i0OQI444QHxHLK6ZkLlJr8kY5opZlAfjyzuWzJUG9p1CXDfL4t99GOoq8TebYYdjmbyV"; + sout << "BLn51c8VBQ63UM/dzKKgL6n5aUWAenJ9sNV3/b+fmN4cPWAv6TZQTDt6XZxPz+GT74rFhMv5YcwV"; + sout << "zkuiYPllmx7WE0pegs+TV7BpKef2DBNlo6D/7ELNsr7L3+rz6w3l+m95i8jtIpeoPRzNxMWfvFZw"; + sout << "OP46naSo+j+8Euy3CnMnx5HQVwBpEeGbqjwyhsKmRyev5wzgYGe4mxNTgrTKsVh4exFOYkFJovWm"; + sout << "l9VKTRY4/Zd0H4dodHVquSWT/i6ogNNwB1O4T0I8hs8T8i4XybD4hM//dNvbPxq82HobB1R77rqz"; + sout << "7wGi/X9MmX3ct7FIpU/sYUT2DeHZ8QFMfiBhSLqoHcJj2RgNr4cX9ZbFXk2nzw4Dz/ggTDOx8XRO"; + sout << "qLovxOk8Sw7KXBDVmW1/kBhLVN/yffDJjzwwk/JqiUI3kPaLOiiJnnLRH+fS52mC/+M/xg7JHde/"; + sout << "oGmlCCGQb/1q2j7he+nBCFZM8F6BVHnlA7DoGCzSetTLkOLo4OZssfDw3OSU/2vxTxkZ9qsarMCF"; + sout << "uIl0MrtO59MbXtiMtGS5gcZL7Tu0SMDJUQYFtoZ6+QnDgHD3P3BsqyXfjRR+DTGvcVSyh878HpJV"; + sout << "3R786KM3ZwW8iVxFabz1OOFx0Ox1fCspZa3L1ckftSTD6ypZAN4f5ox3Ad2iY3l2J7iPeE6/a9sk"; + sout << "MnHQmuMy7pMum4VVepnd3lfqJljkCG5ptnznCTmp2wPgkqmvzXl1UU7kyYrdMt1jTzg6fiKsbROp"; + sout << "wGSxXLbRGT2sFjU0KExb/rjG21a0NVVp2PnYMQuyF6glGmCeSqohLsV3nv7eCbUd42IacVWAR0qN"; + sout << "c+5Dj8hdAmRMU5DJrbaLpuHeg0gouaBYrmCYGrcMqcI77VLyeQGF5hzq6t6IhiPk9uLrW2tpF5Df"; + sout << "gHr/I94WMSy4Wf0RnACjVxSeRewTXJDuA5pKq0EqNs2RzmiB7LoktPBzqBgMjCRyKlCAzwZ63XOt"; + sout << "vOvKGWCbjXcVEOZGAY3s4lQRiNIJ1Wat62u0MSj9BERVU8kd/NWueMn+tbdGTSR6Hgyk9bvHF+Q8"; + sout << "CdFQntiNHHU2Qx5ZbzPPoNNjqwy+FGicBBG8/xV2bPd/JWm3uo/8BFPqzORzGAnG/BHSARPbEvkJ"; + sout << "ix1Nw+gepwJ/M8YNp2aE+hYTBUBJueaSPyPR5FS6ocR5oLd5j6biPlTv9gHGg03Wc4W3BqrC4/He"; + sout << "R5sSsRThAGj/wbDu6SMIoT+2gUNTbUZ7Lx+8rVcWGY4OLaySvre1BzyhGPpNDuIhmbW7n2jfEO9C"; + sout << "MYbtPL7skgqV4L9JlLup+vbCNC5ccOpxm0Y+XSA/Gidd+aZu9D4R9Ddz66ns1Ida6bWLFM7T8BpE"; + sout << "dYgcQBJ63ixc2nYknR7w5Ez6JPnmfUSDjXTrY+tnXlbyWIrF1fI1Og7V5i0d3G2jcuitqs0v6SwR"; + sout << "rgEDVnVz0O+cd56TOssSv1d2KPIY6tv22oH9P1sNzjBYN9np1jRH7e1GpnoEvY7MA2gNyBIWdrKA"; + sout << "uUs0c5gYr88/DbJvZtSnmPUMRr2Lp/RtOHnabXkbAevAovdIvmh/BfnXnSjrOkH+BH3Ici47ohO2"; + sout << "Da/7OlZj4wAESZO3WO0SKN4fUwC6M6LUM2L7kg1tvqLA5KtYokFOc2dMSRs8qz6qpneGUv/f8NMu"; + sout << "PphJtFjOixUrhCOZVv+0dN94L3dNPRyjXlDp4EejgUxWdeYVPxiyXor+LSIP/s4yPNmn6wIYWF4l"; + sout << "hIFjh3drkdxwsaS6QWeYL3knA6I2rR0NERayOsVKxcVftkonwoZJHpdRkhDhBj0Sa8TicGcSgQt5"; + sout << "kqnjLvOqNA3A4XkXj445r/HiNWVbd/v3DRYn0lh09HU7ihGo/qrC1DDjIQjRm5MO64mM9xZY58yU"; + sout << "2fHWbglRconR2QeXmQ1D71aEvNZwBs1G6s1VqbCNgpeA+7kvjlzwFzyvfIqMuLTlvUo+4uCHf03O"; + sout << "im1AxMxq1zvmRPM844aIenx1bhx4vKmeEwTeHKYvjUkTCMFqOXxq0UeK5XSKJkXrK0F98MORKeEI"; + sout << "lhpQXL7ky8QZChU2k8DDBaocWzF9oRtrXDY6vfuvcituo7gbfVIDzQrp4ran2eY/pvucYyn/Maai"; + sout << "W5IQzzordTdlVKX99bpoGJrmFzYrrpUa7xxb5Hs+YvwicR8IE/2ePoZ0t04dBaIwPP3ZuPO0G3EA"; + sout << "2XGdG7f9AARr1098ES4qU/JRanR/5XZY16deNuVPRKMdUEBMKijfHoQCCorvA2HoQz/jD1nhg2ZR"; + sout << "cB19h0shhjnz/OVgb60bN4izBugSDZfyiuZvvdpnCbc7bReVBCM8456EG4mHdI1NSdGJfATpvYbT"; + sout << "lKBwrESSbmkYIVCcBxn845i5lrPHMdlt0s0qwtcqBTvjBmhRpezoF/U1BnPdETs303TMgApql/xS"; + sout << "MWS/Jme3ioI6E35BzkTXABoxIxyjdFkFTququq686oqyGCSpCxlJA2OyXwUc7XkBI4Ki35ewV3Yh"; + sout << "DbONeeMe9NAred9EhUBg+uLO3W8oI3094ZbMa7N9b0tVpewsJuqP7nfW4hZtLSpFEtjlQVGsqfQv"; + sout << "TqljECWVoZ4cbclYUXTMqtd/5okENdDEk4BJnmd5b091PoGNeSqYPf94LVrlyVSM6G9DmQm1UTFr"; + sout << "bUipf0LcSMZ+DKYhwIJLpP/GppEMORYKRhDtgPy66GNfhiNKzLP+fyPpSf6LbtCUbE3Ni+PJ8zux"; + sout << "k810i0by9lVlFUW9tNnUcSE7IPrcFs8fozpLwNSI/gTKFy0mNEnHJ4hlSgaVB157PLKmHSv6WJIe"; + sout << "6LE89xQbc/Rl2TRjh2SxZXB62FoRLXApUFcSdCl4ZRdfA5p5zXcQVWxl3RUdJbpUhsqNA7vmHT9q"; + sout << "qinoGtKuPZvANZr9Yn5hB8zzHLiRvk9NHhFdR/Ibw2gGXq6XMuJmrShSFRMZrTiIVtQeBWZNzQBL"; + sout << "jViE75H59XWkpeUM8hAUUawEH+CFcYynOsbTECPSbgIbt1w6DRB096vYoMnxnttX3J7SIf4JtpEZ"; + sout << "iSI/Xzdcu9L3QX7bQaHzaf/FIqWvdl3quN2CrSnBd/UbGCdgylgCQyBtQjcr2lP9O3QKav9laEz2"; + sout << "YJbjjXldgdumq2W/0rTopOIX3V3xCHB6xkZq1SS773LCHWAaoPRMwgfrJplzRSNIaD9fRsVQuS79"; + sout << "TUwphVg4iGrl/0u8YLEr9TmSTojfTS3M23HNJU/K8l1zOB1gfCYwmj6jGxTCuvDHvgpJ/UnViT2j"; + sout << "fnQjGXg/WUwEoaD+66ZePipeu+PlIHDSX3Jj2+jHAmMMsFX7PZlP3xhWUyG6U0hJ1guDe4HOcMWQ"; + sout << "97gJDdYzwSE+S7vB1LQWKNdEnZPmZ+rQ79gpjNYVfzE2djCO9Sn0rQD5W6O7Wd/QCV+qu619i1wN"; + sout << "IziA40bBSmDLiEtZv9Aviu7q7Au24X7ZcXxRqruz9zKPSFd6dhpWAK9lD8PCNwUrDXAvlHT8pS+U"; + sout << "8K2ISz2knDjOiHIb1P4jF374K89RNSNvefvkGBNrJ8H7JiseiuhumYfAiSVqBQwGUeh8lnPq4aeQ"; + sout << "V0gFbCj8kNEs8/cDLBnnwCGZbqRfSGgSXTKWu3/oSmn/bMralWdW35AgMM2tJ3Fo5fIkfzsq4ykT"; + sout << "ywlPUtOOHIS8ImXHCQgVt1+/OPbTpgCly1i3ssNeIkjPY5DyGk08WWYiih7IAKxjFeHSavuYKy1P"; + sout << "1QMI0dJ3y2Wp717Zg46c+iKxlg+oUR7It9UT+ePKJCm/UAayU7LyOWThMfP73hM9k5G/h+8oENw2"; + sout << "D/Kj5IkwB3iK6iC/6QrZaIHTWVRNYioSzsEndB0X9NhBMlcTmyvPJfctFydhK5JjQwjjx96k4FvS"; + sout << "P3C7zh1ww1K/x1J5vK/2NViZxgSFmIz7L3NbwUb87eXPictC8NVReqskmSw8hxNmgTHG/WEU6NZi"; + sout << "XRI+HIV5erFIeNRIixP7gp1z/o/5bL1QH5ALgBk4TRc1B9qG+fcWfJg7o0NckPbkkueM3u1J79mu"; + sout << "pcs18OQ2A8zehkFZeaTCC6Q6BRijJFBpPNnDIxJ7RDazpi3i366Hds6mD/r8IDRT8TbWqedF8bNv"; + sout << "vGRPs1aEWykMumZdtkRZXfByK14T0gF0CSjUeIRnPuhXbfMBznBUGue7eCj3jHnLynEKdtuA7CyU"; + sout << "FlXmJWrixsWkohwNSY+2jCNgmIWsUsvKS+dlNktC7OdSo+QEIDnCfPio5AwCTWBBvUaxwR90GeWW"; + sout << "yk3Rob5grT0wgxQGoCoew1J9DJT6iLEuy/jJ4FjFSB581f5186b4lOYaMoSCpOp+uWEz+I0c2G0o"; + sout << "DZKFloHMTpZyEt8LggzchgxypR0SGSFxm1X0/lhdPwa+8lKRdcXmhe2ai6H0R3lhR9ZTBaCWhMMo"; + sout << "GJ6SvT1ttMF3rg9qfFc1RKjK4+TsRxwYOP1XDz1rjEp7KWHLxZ/JKzp4IXomPm2f6exs+miIPME/"; + sout << "yo+Dtmev/tYJLd9poc+LAVkFIDqQ5ID08LgCGwAwLz7J+nLxSww2OjOxhbVfRIm0Bj0ERLmvy0jv"; + sout << "9Oz6hHcYVHPqSGN2aMsPtpWUWOgND2ofYfAJqqUMbwalMYUyaimuIU9ODyoT0xIMH3VeN/nrIHDB"; + sout << "u1J1iY/Lshk4okLUB8OnbtJY8t3KATibzVtoLaoOcmcQzOJahRLw861zYSlCs7SQpAqVEoRjYtZC"; + sout << "5JPGKLY9zqoogM41u3uy16DfVEvDYr7XZm9PRYY7H33QZiOYhPRgzZvs9nJZAJ7QGRDYiNQG5V6a"; + sout << "Zc/qLzF9UaRbUOxVKJ6FkdhuIuK8d38cN1ky2uUM2m2cFAyJmzKosMn1+ZKo85F097lSyuDqgXvF"; + sout << "hD9fu2ze0hTo2zqFmHHoDvA/3dN3Ilt5omMSSf/HQdvOeXj/2vtmJhpnERGJo0NCiOZfKI5QG4qP"; + sout << "313wIpwQPZxAvJzh9YEHdb3fv6GpwlEygu93G8Uu4WZkT7P+XXqhH5PL3cZ2MYQeYGR70nvUzJWM"; + sout << "lz7g/DvPE2aQHxGSlH+W6JEabNL1YUvAim6/uLWXvjbU+QdKDUQsml0P/auI6oHYMAgXnifheH85"; + sout << "xyT4TLqA33Bc0SRpwPV6lIigxQg8OEX1czyeYInQ7Y+YhaPF9Dh+YR2Kf6xLS/qnhOJrAr102gf1"; + sout << "apYuPnSogxRcir2pFtltl7k7xjaAcfcOUfM10VkYOoq92X9XTfbEBbeEAhFtw5AgWJw/iryapVb0"; + sout << "nNPvXwZ6g8e0LM8cky8/CR+68D7KCkxTQ+C3cq6Fco2niuPO1/IbF4nnpLB7vDS7qjrgqU4/t0bq"; + sout << "qMPMFgX4MhBikxwzWh6xnQgylA47Rbzy0EKMD1j4NLSkySg7zpm29sTTvEh68OnSBaaiNm7yoIpv"; + sout << "41yj7kewoFyIY6wqbc6bxfAc+cKLdxYIxwwXRknJQuZW2mx4wpsjXYqEwZ6OCDldMgYkQWuLd1r0"; + sout << "LNY80XRGlKZw4bV3eBiEAo8E7yU6VBqi/C+fTWCdaD3kjC1mu7V9XX2tMg4G2kaSqc6JMjOZ9VF1"; + sout << "06FJujVmJxzZ8t4Z8yI3JGUpmWgYnvY3baanLpy54Fd7M2JgoXJrElrfsEMFsci44CmK54NpntkB"; + sout << "wFoMnSxymuW02SsBU3DEBbZJmXspP2yqCUVHyqpp2Wez21u9T2LV2yNLpKHm8uNRf++WJ49zGBDY"; + sout << "vJ2xirDEOgRWCQ/z4J7fg6AwMrG0PhqX5Rw1NStM8rPEWqKd9+6KKTL3lMj0rvvVjUCylKvev2Vw"; + sout << "aWeVNhBWgelJ9v/Kdp8lQRsSyqJ7JP9CQ9iRicBTKpKio0NTaYn8LEo0yH5B8HVycOVJmPjjYCg6"; + sout << "eiZ6Vxr0laxchTlNPwRAU9hsZky0eO68C/cLrfgqB2IxJeZH0z1XWmsA5dGcYAWMDamqBNPTm1jW"; + sout << "USE/0iR1KPxvIemOxt2V565of+jareTk/Q77E+jkopN7K2AmKOGouPeRUGEmfrtZQY2sarl0tF2r"; + sout << "f3TvsoqK47umT9s3R208qwO7nC1TQjaEdP2KIfcPxbyuZ38d+GxQcyuID5+qS3IQR4X1k+lyUxY9"; + sout << "wtfNYYbmsXnqAXnM/mKlfGlYalVxalvxo9eJdugBIVbaQ3JmEovseS8IV0dZz9ItdrW3HtHquix1"; + sout << "xKf4VBhFqpCdVXBY0/mMlcWws6nItmS20yCf8i6jJ1HcaHGuPQvhyiv+xL8kELcu2avfpf93i4m5"; + sout << "KaJb4xkjSXNAcBerhBlZ+DO58TBlEZSI8UTw1+yuzymxOyNbq2o5rwqAqMDxmfJlJvbum5OnTWCc"; + sout << "2AgSoTUVOUmg1br3xDxH5xX81Ty+fMaizY/8P+dFg6MGikZZuUky2R/nJRilgi4thJV2q9iCm3sX"; + sout << "KVW0vt0AHRnUb+QkaLl4ZJwGUcx37S2YcsFCAk/GG9F3g1zo3zUNgbl7WelqICfme9vwcl+LDnvK"; + sout << "q8jgEUNwGe4ZcgDuPAxiOmIo0W0iRDbt97R56EnRs9n0E4W+/DWBvHY09nBYtVrQWpodvZ3ZkEp7"; + sout << "9zg0d0rwaLtb21iZPH1swSnf3hTYC2sxq8ppPJKENrTOYj6qiaye/kqg/u+rZmhd/LwNckkjFy4H"; + sout << "9OA+Te/sjSoRCoSSGXtGSexAhh9WfFMnivoJ+9kAIGRUCJPQCx+NPQ5s1+d5sMZHyGs6CsR2Hpjf"; + sout << "ZvXMpbZVEV/4Fgf7VMtSexPTV27hFxD0wHrhU4lsQp0T3QcI20VYx9kTz025uiMWUD29Fm3b7peD"; + sout << "K2bXo7y3G3k4ykOPrF4kR0VSkjB1NbhXftQImcpZiSDce0MZOZZvQ3lH0pZ7LaU5d++Dyp8xsxLt"; + sout << "4toEQgoyLaS22tEe1jSQsT7v4AWwjMBAQngr8MsqmcWyu1UYshp3kmosoxHZ5N0w+RN6P8dnRWuR"; + sout << "N3gIp6dBvllTNVG1Nf1zJ1r0w8KITbzsS4bBLijKlUilL4b/P8N81fEUDaXCrcdcL9b0fzv64vg4"; + sout << "9ZXsKIfw3rjw+/ZiDt6a4WtwCZvCyI7cRRB2YGxaeynx/Yyqq9xwnWThHA0a/SAwkokg/zbxfi3u"; + sout << "kzPOBau+RpDKYDL6DBkYaxMeTnbF9dRZnckTEhu8kcSh5uHhVNSeyYSngJ0YXZV+tEzNi8yI0wA+"; + sout << "6M9OX5gcHv1pbem3Fyx/3Rqu/LbjD19TtdwuyhXnnGcCXPGzqFrpEdN6504IdM7WDRrkdjpyTlOR"; + sout << "5hzk1GymO/IPAk8uRPEHeU+873xDLoxeuNseeNDRuNwFa5FcIKmIAEGeSydVPv2uPH6l0VQezasr"; + sout << "XMXlWOCOFo9eGGJW0p4tz4qX35tClsU+0ml1uaRR/fFeJyJ7+WrqyqDdHpTqnGWp9OdyyCiowxFD"; + sout << "T7MJ/Py4+ytZ++sArWxfC5KATdztRSzuquLbzKfLtF5x3vihfHzonsdHzZNUTi6py1JV0uQO6by/"; + sout << "1HjNlxiNqfOEaQxfLe/KwZvvvSVasZQz+64BSVLPlmwFnha6iUXYl0m8pNQ/E3QlysaeDvf6j6nC"; + sout << "By2Ft84JgN/2Gy/XJuQRobhPiDh2Fro4ebBwOjIIfgQx9Q9FNzKRDHAS0/HXB+P41/CzXWr2eTu2"; + sout << "PXZ0O8TshRiD5gH7Su23qpwB7qlNrPbULE06oc8ftAJtuKAGDPFXbom3zFVjsj/yeTmS2HvU/b80"; + sout << "8FCgSSxx8cvIqfPHLLOSPBUxJAVyCFPvDIUzHE3sBPyw9OY40bgSyaYoC+hoKuf0lHgt20r1Tz5A"; + sout << "Em0sAZfnbWRESfaM8lafrA+OjaSDQQW7wAI4+XO+HbWYxz6Zj3hu4qfmID2Iz2kuo57Ci1RLOUzl"; + sout << "sWcTkVqUrdnoGzjtese/OnbbsCUQIfCM5ygX1KKgmlb331Er2So/dqrYpCFAqBo3SmctlCtmERuC"; + sout << "hkB59TSoGqV6ePxidHGWKBsEHudcQyRv+pwNhnRkW+l1XmpzMmdQqeBjqcw0RAVJgDmHS1SMrpPE"; + sout << "GZmNu+LmuN8xdGV1tbylOTC4TVND5NOWaPDE+vx0YaRi2orc7KgrLc9OhtAQdTC7hlk/8jlmYxFp"; + sout << "xl0Fe0oW1fU0avYFk0p+dUXfYCA6HXTetcQJdo7Ai4JveZIUdZ5xW2snKTZXd/IAE1UG1o5Uz0LA"; + sout << "/y3Q2+1hBR+4cVnmqVftEf4ZQQoJlY2CN5Wltg8kWf3nsgSmu1JNtRXEV0DKODWTAGSizMJJleG7"; + sout << "ilV6GJU335MapNMUrlHYVNicy1NY8CX/PhZtbGzqNM3gfwS0KhWmKN6HS2BYwfPllMYJk4koLHWS"; + sout << "R9dpGgz8n+PDrKOcLG4U8xYocSvLUVJ6YmdbYWtWunGa0I0egn3YmUc6RTflJ6G2rRpjY/B//CY/"; + sout << "tAKUuY+4e+4/VTH/sm1z3yHSycLOorAeMjKnPpDfN5d4hyrqeHgG/PKdZ/QlMZYw0M2Kb9fpvefl"; + sout << "cqeu2IZaH1E7G/YbqnL4wRlPq6HSQcT7C4r7vJxwYGGwZ2wClOXdvvbdjAOxZITv116e6jaMQ3vN"; + sout << "NNd09s9zZlXyXDyBfdLgFebbKLjrxC+2eB16aURJkorv9jQydbdMMsZ7DnW75DppDplDyPvMC3TT"; + sout << "0o4PDSd+YBeWo6CAZcWNR+lB/ZB36tA/RIyCopqeUOL/P8cfeJSX+3k1yghKm4HeY6VaWJSALiiP"; + sout << "/0nEgGc4dbVmspZGzaORpGOTqqGnHm7TiZmOJW0PFrCVY0N0VqHcSkTOZn6PtuQKFXtFHYEtpepB"; + sout << "0d332dQEz4AxgF0qrhrviqEZpXTxG1bWE4uC3H0tDWIJhyJe16xF96vE3QN6z4P2+esYwjsfRdQp"; + sout << "3CYcxcjagfWBngiEfRsseTLhVf1KsFIXgafpK36gWlZan3O/7uyqhOYyg3W7aJ20bvZWuZZ94Gls"; + sout << "7VLJyTKS4LNKHcb9Y2PtX7nzosb7tL2hcYgGChva73FTDSuf56HcC5+KM+t/MXf+BPBr4FUsgN3M"; + sout << "1QrQiRM4dZte2DfAvSaZmWQuqtsH2uOLimHU2AAsfZav6k+7Qb4mQs6zdRPTMPcd8cUv8MWf3aOW"; + sout << "skdmBW3ausbs+iQFW1MUqr/4lwe4DUekn8nfniBCI129E9p0JLXLw9NrT+669orbCX3UOGc5UHTg"; + sout << "0SvFTMRqw7Y9mw4NMLOMCn8/mINMrmmaTvtEwLhHSOTwYjRCJ1A353ENYX3T551pA3QS1CrWfIrI"; + sout << "apV92iQiSH5A/rzOjhq3BDCU8HIXihgVCsR/w5Kcy7nCkI4pvbsIjPNivZsDhx8H+IKNTUNpSRt/"; + sout << "CPzh885kcdofM5k/nqVvvJhQzhpk18+IRuxTp5Dqa62MYZwgEgURS1lVVzWQX86wJl6/4cbIVEcb"; + sout << "1dowX6FrycYfy81NkQdF06egVOZVYCaZdW7f1W4Xn4l7PWb6pb1HRb3/0Wu8YuYjm8pSY/wA70cJ"; + sout << "G141hphqUxQeblVInhyk40zcaroWoWWGTkiHUnAp2cQgjZJRHbrfoJgF5UVmdDJrwdwiMREmR7zZ"; + sout << "uSS4xTZ85fXMH1NGcOAu3tFt97bY2uh/Nzk+XucsND69TUp4s7RBmiahbRw8XQa5izZMRPgYkVss"; + sout << "A4ggkTd9wiuotgjMtjd+NsitkGhuoQiQHLW0R0PVPEfG5mtdlp6/PlqCsTum8bRuxivsPkCcjEB5"; + sout << "9gdgEeSfSOSRTe0hnm/gDQXs4BnFeyG/YzHm1pHHIP2VHkr3GlWVu0w7tLPXFLXF8kRESZgo1Np5"; + sout << "WFYh5OmLS+J3xhJMfzNjUJ0bzfN5BnlQMDT8vttfAovg4IO7zFTexfLBxgKLWJZ2aQnoPeoBIJFk"; + sout << "uPBxTjUY29133ygl2fAylnqP9zIipFbgkLEw0NrfKrfKIPK2Kf9VI1+adXQlsHo0vsI7l/uWbW5n"; + sout << "XMP598ROkdbP5aSkpXPdld++HGzYtfD76teXubxsFcCpmm3OSbcLo2vAOMDQdpfFTOLM0ig1VhC6"; + sout << "tCC32GUIWix0VLb2dZCdidXJXu40BNL+ASvXUAipVBD+qBL0uSHZXjns8B6XFD88zrW+reKH9ZO/"; + sout << "gRZI+FH7eqxkZlu5ChFHJeZVSK1MwHk1ikQxAe8jVXTzzVAMcK3EVGvH/l0RKX/SJvU58t0SuHOr"; + sout << "HVule85Z1pwxt5DWS5IYXibDgxOWfrcKp1iTXhhKsmOmqLA4M8iGbtJ59RjGp3ndoc8k5lgx0YXO"; + sout << "CZKSNjMhF/swTPJF3fvVwzBU9vvWe2HCUTejmLi+bivqW3n+qtHiI1y9VWPY9B90N9OalQHJD2OB"; + sout << "myHYFjs7AVdwFYoQZIBFtaKpozISrGZBa47RfYaS3zH2aNzYnT6zDsgKJYdA8skR/tNgMQPdgEtD"; + sout << "d/NZPBXg0YnjIJNX9z3kcvXuwNnpv/ZafofCuHJQDRlWe2hvfHXZJf22IWrkeFjTGAzVl7bGry9d"; + sout << "zaaFJKt73ogJ4IM36xHJ4pgGuP+daAnif+2sdHoEavTV6zvkV8T2UXgdP5LmPVnoz0zue+DBTt9q"; + sout << "Lco7T8RhIgRDQL3K/PB23bc/gYdGKh2ISHY1klB2pUWu30sBa+g1ap7cTFIC5eDQ0L19jaJNlLWf"; + sout << "MQR0Tshn9dRsznAfmjIm4RvhMvf97Wer3t9CjFOOaQfSXxdPQ0KpnxeyfyuBpRYScUX2cCEkIXjA"; + sout << "IiOHn7aTJ22O3S7vZ2a4jVygPosfhvd/nNfWhAukM9UfVIFi/kG8CXvuPjKhn77k1ddfIbsp7Q2L"; + sout << "Na7xoC6HiPus1PI5kd/p1NkclRSCiMBCl2qQfHqVIioK25jAjHl8SxsiPpkGDgiBJTO+HR39syAt"; + sout << "iyjvIv5KH/yAzl+w13S68pDbAGZ7U8hDI/NKlhgVycVmkfKZzGB/9wzJVemHhvosb5tY/P9ZHqgB"; + sout << "HuIUgdxusvm5IyzBBl0DbaMxwdHevHQgJ7NmxtvGDDofkHQ8JQXkr+N7+Oswa8KAEdJc2t9fST82"; + sout << "X2NPa6s0wEQ9PvZhUQd45UcFtXCBbkCoisIEA4X5Af1xG7gqwo/PqEW3g0hYwmbWDgNTnEHaHSsI"; + sout << "4k18Jcem/+3cNZx6zFKEZ6u6/JlvGJsNjvoZQlRluwF64lvxtmvxwD7ak+UmFc9Ll+cYhiHhdy7h"; + sout << "e2RAF5HLTVf4BZkUs+NI4h4c5z+pU4dxM4JOVNQ5OoBisyMeFbFDlBd2x6wRMtiREIF4h/yDm5/g"; + sout << "QIQN6sXCt4zrBLoPvFBvC7UvVPaLEpW4Ay7yeYoFhTUcZ7G11HdyJHIrWhXj8mazBuT33wiXPXD0"; + sout << "fgDO2G5kVrYv837OEagI68UWoYwHFVQyKy7rXAqnJOH6rUOiSeXwrAs0VXCOo/iWz0+H7Vgkjvo3"; + sout << "a+RGXFU0PakfqhrS4elT76gTx27rIG2VuKjWnNWi91g2MnD9dwS4xGNu0nB/HiOTDKKfzQ3uoE/D"; + sout << "oP/Fn1wFPtBDC6935HgjJA30RBpTT4lZtqaVhy4F/pFfqFBL/RUU0jmlG8WvHSF28NkXMJFxqJqr"; + sout << "FdZj/ImUTWIBtAOXIMgu8OQ9LXpd6kyh7dpy2lLG5ik8X18yJnJS6WZnRAx6nX3Egg7zi+jq6P+B"; + sout << "AGQecbLJ+ngHGqQd6fP8hOU7ujU4yyAEKX3CWnvjK3DKN+fH9enZ3VHAgBS6uy4BicQeqNJx0D2k"; + sout << "ab3+B1WPW98ieaG+G6TjyFGNFwqA7ComKR7X745Px3MW0qoQXJVK6FkV1x5zOjrI09hc3G/NqzC8"; + sout << "5hR5EA4ebicmVwYgc2rftfhZ9FY8vTslWklZHaZosv3/QsoFuoDNMA82bg0/tDUo7G47EAETe7lj"; + sout << "7MRS+dqjw9rMDlry6erLVu9vOvz5gVxhNM2+Q2VKx2f2ugkn8/VGq3ufEEoEkKtnLI0aJuli0ZFM"; + sout << "yEDEHqombS/JdKacUNhM/zevqMA+F63pABb/wTCYhNIZWZM5SRCoYiduOSQYs7eR4ex+Y9a2T2cc"; + sout << "UyRYpcFRySR4XB1Nf5GVdBmuuVqvg7wuRwCKXDkxcesKSrg7HLlUT9xaIReIHoFsZj5w3RcZvSCT"; + sout << "q+d61nOKgP5ym2c5BwxytP42yQ3SisRyykroBgipHm7QmrJnrsw0OZvAkFh/bo/mhstLDPKkoA89"; + sout << "Dg/kr5VTRG7Id11Qb6CKvSb9hq0mtnGAGnYa7XTmmojYhVXeFF1aj7E94bUVjtl1rJHlpNSBRl8U"; + sout << "LffnXik6Ju9PK+gw/WXTgcyOM0Gmw91IqJsBiNVmLMB8VPpunL+J27aoPVdSxIpR1H5PjWhWyqNz"; + sout << "jDTlxcIKpM0PBs7PyNJ3a0QSQYQrGDYI9VkTQJpxwYhPOAWdIjTKDBfx/wl5UyxbcAzD2e4mXkGq"; + sout << "535EjhmfhG0SRvXP4heNN4Vzl7KMDjO9VNMB4/L7500528rL2b1wC81igCwNWpUL7Pn8/XTK8jRz"; + sout << "KCcb92liptbbjGVb1V9hh6o2aR5l6ZQtS2MHlkZG5rucr+NT/SFYoVKB93lLnq2CbGPlzZZdPHLv"; + sout << "+RG2+80boVipNp8XD53hLWtOvWMQGWYwVWSqQ6jrO24yXdRjZCn+F7mn9mP/8U2dEi546WLWvaOi"; + sout << "RzQJdz+5IcL8RBPWCcUqR2VSyPIT5Fzw4D2+W6G3cvc7mWgiEjT9Vzjpeo5D5rWM3pAPl/kI+gMg"; + sout << "42h/H45O4HY7CTb1Isa5GyX4L9vPfLXyfqU0N/8v8/UvQyB8TyKaDMedoiA7+hdWtXRMJ3cFxepR"; + sout << "WqHTCK3o0Lhh29qPjMh7hKtC4xNVL5/Cz9KmIu8ZYAC4aWXX0g7wzBb9o/Xb6mJ+ZEDmtI25DiFy"; + sout << "McNqoVFxoxYhTrUYNOaAqJYVTOauV8N3Mv9WiR+s00mjLAIEOrcExjt6WrCUVAI6XCqsWPibNFTM"; + sout << "tSBPtg/w6jDsX9H13vPzwv2Wt1gGzejC2nuqaEVXlr1UKExiXjQGAJyiw13PjRHSOD5zUWOhkPBr"; + sout << "h7lzK/HcCTA6S8PzM3yIPWYUtlzfFT1a1gIoBvO6dx3EbGGRYcePQ+xCln/07SFsUXfbKu43/RIV"; + sout << "LUXeP19XnLIFDk9NRlK577rEfl3rZN2TkeFXMESPA2E3WtUXLw/rQxAi+E9n4pHJ1dfLJDrXcuH4"; + sout << "iY3QbOnJj/ay/dY0eDCjoR8XBlbvqDCvuKDZd/SPNLseIwXXTSFxnRBONagPDgQTwM6s0bqn99DJ"; + sout << "CEVWRIrkCH9o44FsWgF/gZEcEwFbZifnkjRSbiYPRn+euaJrvU8t+x86gbdywBJg/NdAPdn7gmyR"; + sout << "V0mxYX035zo6xzjvYlOVP8cDewPK76P5sFFM/00tE134eK0jCmw2uMJ+AX2mEm3RhtmUdShzk8Y1"; + sout << "lVbWyEn8VGdiZmsaB9SbeWAX4vj763YnH8pogS5AfRtR7+Z/B8dB7wBowpfs0Kvtm6LzPAqpyUYI"; + sout << "OqvwEK1/f+fwq70lm3K0jFwda0ycYlBWhmGhpuFWYh60ictb+b19JLmdT59NrIme9aHqB8+SrEnN"; + sout << "C5eh/Ee6eAU7fcsai415NdkTphEqSj/0IJNNqpoAm6WYXyNcFeXDma7VCa/6mXZFTHi18+qWqdYa"; + sout << "FjFmYunmwGz5eEXqWwZrN0RQlZLzGTjAkNuVnrDhUqfPtTQ3Jg9Q/g514VBc/e1f4tcmY4lBDeGZ"; + sout << "yAfWnvwG76t3EfpMmnfgTyHXf+cf2g6OUsha7eTpLNh6DUT6TBPtMElNMOWMDlCRrNUAvUGaiDCk"; + sout << "dkZ0LcNkPZmrS9YTvzVVXU/DfKQk2Lc4kJw08UoAFRTXRPA3reX47HbrXeo7bU4PAhA6Flkzays6"; + sout << "NWx7RVOmdSvo0Jc0Ec2rzSfShigqkmjXo/CSxvmuemnMZaBJb6o+R3iVOdMpCsw5+PwCtlG9ghH+"; + sout << "egfrWDE2MA4hENDmgbONs1XrTLrEWl26avRL15OcLNfjb4Q85Db+r+8Q8ApGlnw3Sk76VrAoqY5W"; + sout << "RIsD0jH/CD5VBtJnFNU+qFtYVNp3kBNKcNTP9RDpi9BxY0qZuEcRGtfGJMIAtqrjo9STIe7RxdfE"; + sout << "wy30LUHKaHZSURLZgUYeY5h+OTOYDpTMjwm+xT9r2GyOOQSmNL+2QVXyrBq7BxtbTwDUOd+a/9SU"; + sout << "rCqZfwqmA6OrPbLHkRkQHYJmtsR9k/+C3R4bSl28PQDEf4jfiA2ztVZ3Ot3FLHooVKF5Vk8kIVsp"; + sout << "ZnbVJA52R4CLnmUyLsqqwDaUJYVfiOQ3zuqR8lDzhNDKYuTzkS2g+X9b91wF58wA+Mg/PyhUpWFs"; + sout << "iYsELBbYlTW6eZYdDsIlI2x7khrgJuralZPQpwt2+3cg4NThKyKTpCIDQee6stuiUXqh1NKE4F/k"; + sout << "U2M6WifViFrXuKqIChGe9DSraR3UfTyjDm4l3uGX6oE2MWXRyigPCfkHCqunq34357V0Bds/Q341"; + sout << "hA6uDgedbPjXqPj5MF7/bZUHvqnnwD22WkGXfzgCpGAwN+WOWFxUZL2adBjJoEC7hYwx+umcXQZa"; + sout << "H+p04+MqMvyzyfnI8d20mZQaZef/nY6vdNy5rXna6ZyFJnKR757m/OUaL9wrvO5Wlk0nzhQYnexX"; + sout << "KyOvQYBbd4Hu5tWQ518f2RtMZzm7ReM7Xr78leb7GVB0claKWD+7ptwWwFVn7eiphI0rnxmSY7Fm"; + sout << "BABTROBL//vZp3c2VqmpdZbhgeSyWEvyZ8KUheldonXlpbm0OahvTVNZHafDX2aUph5zYHSzT8YE"; + sout << "cvm85ivZA1XdQfLkQot0Y3H6rQKiW89VE3qNoU2IReWEYZYbTk2FHIgIEicG9LECWuoPaHMLf+tr"; + sout << "8dwW4q37Bz4IpwhsPrlBmXDdlPJnWDvq/C5AWpDkgr5JEg2E1CL2QA0YsUqLMt3ogOyqhr8jvAFI"; + sout << "mWgYRweFTHjCMSbsRHRKr/WRRljRHqIs4sV1X7wzsXXR9V0GwYnkwHWu7SekTSivJHoedUspgymz"; + sout << "Xk/OkmO3hJlvow4lyNXI/QsF4ObXSF5xvUuUtNHy4EF54IRqc1chr8E073MVyHdYbK/LYDoHJ1Kn"; + sout << "wx+A4+BWLmeVS5weTfUHronUZlRIv1YgrSzw8M3Uxf3tA9+Vi9Pzu29X2IeNt7DbiZzGzMZWkgeD"; + sout << "ye0enuQfIDcZ/AV+5B2mG6uxjLnMFp3S5HxSQEfbCwyvpxHEwImq2NR+9/J4UcxXNsqe2k92Wz8H"; + sout << "jPlROh4/lFbmzxE2kT49ro3ewacP+bvpRpWyXpuVQo/oMQ0NuljpBtIoig4csJHhp/cREvXVsp2O"; + sout << "q3nF/kxgs6BrfptKS6RE5R+gz1s9WNJ0yJk2QkurEWQReloMW/w0ktcfENcFevr/HkOtTr28gP1i"; + sout << "vMleP2QxZrdNad2iBLdpJRs9SH663kVa67RTfBczDeXTqQiFtCpI16Dr7GZWnBoYtT8d2AXtCsNQ"; + sout << "wUfRuiVc23aXnOkgD5PtRmcp7WvBDczaGtBBmTbKSy4DhlWol+B8kUcajMRGnrKDC6j1a5mYtWkK"; + sout << "XzGrqT7Bsg0Tffj61OjYU6nCDLPm2Mm2XLHmhf8Ud3WJWANjsiWQKFTNH7eZUhzL3XhK/Iet+dPk"; + sout << "Y1h1o4tGWeZPgFm9lLIhXuT/SkLR+XAGNNgrbAMGSlyTP6d2EPn5WvAL8Txt+1EaMITteGO0uKwM"; + sout << "X+vpdPJC8cocJroQOethNqV+vK7brbtz7NZsdgFov5VYoLNWpwQLigIcaDxORQBC5yMh1ygeDQjD"; + sout << "exiUJVF6Y0blMVyLqBz5OwIO1ftDAxeirjE5CIwSAs/E/CslH8dTJAL6hmTJ5Ncw70pf0NvzXkJz"; + sout << "GUFSaWzRNyrpqFh7vtXVSqy/YAPOmvbU6TxUR8I/1O5afmTKkETJQP90BiuzjYs1N84IYoVt17YD"; + sout << "xPWkQWH4LxwDdA51WSGCIwE/svX3BMYH/0TB8EMdL3FwVpNMgjCVW4srZP0KShXr6oJ18/RVSe8e"; + sout << "Ls5HmsOL5vbvGarpKLvt0IXmveJx1Lr7h+9FAOG/n0HHD9S3Ug1JrgweszosdeIho7KDV0tGCEZR"; + sout << "lLtI0p0KkoNEc6Hcm7VZyMOx6d4aN12NuVL3l6DFtd4SCFwBrIgqyDIf6ZAb4xeH2RckVs/uTBQo"; + sout << "tNCZQKMaxcgc3Eee/IOjDVzFhjW4lo33FHhTrzYzvCISWREM4eDr4/ce+UZ6E00IsClZI2h2HkiO"; + sout << "3k6ydBh51dHWHf5CZPp6gU218nIU7usc+XttaZpsodclSmh+LpXGroacfKyUGTB+I7lf+yzzBrmk"; + sout << "Eo+WFBgemz3Rbcyem+zZaRS8YXKi4NlUX5lS3heqA2ogFwFDbQJB42FBCEMi3WVr33GELllqWSqw"; + sout << "FZPBsn6RrIcIwI0/ehYLjO01/iOmI2Z+YAg/DRuxYDI/eV49z/xv0fCU2Xs3vYv8q7ZH3fjwyusW"; + sout << "R7lKhO5nksDaJ1OS3bf0u237OCQYoc7dB/BkC9BHX2eo2Ou4x3pKBn5HQnldPNqA8Om/tdCgQeVx"; + sout << "izQq9KJefumUdCmn4ChiVdPs7SK1DfhBnyL2xVHck9bC4fb3MuWZYL2XgwMwQ4OOOtUY+/sX8z/U"; + sout << "O1EM/l936v45oXW1LJJYQtIhBf/rlqxqc/3l2YtHR7Sa40RknaaukggYYQJB2PJqXNBs3kI3FwIO"; + sout << "z1mMafTy12d1ROBlqBee7AF5QJ25fCM90GgxqRwuF8yrGyXwv9pbnqmLBC6tz5DjQcYWpwm6qric"; + sout << "pQu/xFI6RKtMM+tq+ETci06pbHQIcQaTvvW6iX5tEJ1iMFJaCueVEx7Q0MOgNhux+338+BEe7Gwq"; + sout << "5g8xr553nO6KTDIrcTAgkRnQIL29WN3AOoFl5JWtP4uo1/cKqKpN3JcAIW9yQTjKtVOfSvZgSRxg"; + sout << "Eir9UhU9SgMrM0/cdNLZed3+P8U6AYn5LxZDNWIlFzmxCBRh0SuJJ+yXtPkqApgvzso7d60FhUPa"; + sout << "H0LYx+zhjx7jPhABHXhu3lY5uzsMcRnndIEOx5+N27vFUwxJC+KmgOi31g6XcPIBqwaWV5YxdVf3"; + sout << "7cYOF9FYgRaCwiE0OKGADJJA6mbTYv7s4tzcPtHEyg3XK4UuT3lJjaxjsl/VLXbWW3JKKXwOuYGQ"; + sout << "w7ExwYxyfm0tbXgwlK5nctw7Xg9U5CJA09jh2RXOQku+PmwbtseSP1HbWFC0pOskZatQneNQ6jd7"; + sout << "kcYZ49MxSI5q5+rZuEOke5NytoWvPIeCFsKMUC9HqsRPivSrczNf51yX0fIuSTXts/ch5fakqj8f"; + sout << "cA3Ln8e3wASkQo6qtG3PP2tAyEOP3qj39DAZT9/a7e5m49vjCx+ui8Kq7tj5ehGuKhrteERJ02q2"; + sout << "ObOq3gYffJ4LDddCB1HQlrK1vV0ZSDctR4QZbEhD5PC0ToKsarT+FKHCDN2GngDIh7zSmEkq1YHk"; + sout << "VXzD3V406Uij54UzOXv+6riqiv139xC8jo5TntvqR7S7/CNQu7D2ix+l+WMlUbFWy3PT6pL+Al2g"; + sout << "ydAx24ZbJ3PQLivg44d64nA87Ii+eDsrJNIFt7jitB7w8tDfpJRsCyGWDOkA81Y6Us8udhe5KDUR"; + sout << "7jekzGrKuX5VfUGAQi5+Vmk9W4d6C8U9Ma/D4wHIWkU+M0ffMVEgJjETYNMDyefi1gcCwh+OMCgC"; + sout << "wbPlRBhdqaUMFlMBv7ysGeYJUQ6VoDw/TFcNtDMZV7G5mMwtsMhfJQV32hxLDhscF8nsafoeWIhU"; + sout << "9aoihwKib0H9EvT2gg3KyvAIwNM/EjpninJXaaMNGbEFJ/7gEAZ2W/jAJ/eGmkLjYvv7Hh2oAkYn"; + sout << "5UOk7TabvEFSsoCdeTjR4YjQS18zZcz3Ta7ffCgp2htZaAyJwVeHbBBL6KoJAlsgKi1yVtP8HXY/"; + sout << "lpulZcwuaKXjQJVcHS/RcgZsaam9WxNbYh+X/CBouHSoZ2/LbtimWXb0WbQE3xk2Ntq3j7FJArKl"; + sout << "gLGaykVfPzTfXFPKF6BduSPmxLCmgLgeZU2rIYBC0UXMVtEtizbU/4MF/Qo+0jjbO2BZGb/6tyea"; + sout << "PRE/tXWyjGHkcx2T0ZUYTTpPtWI4MQN2LxopbcmDkZszXvxH9MlWC1jkhv2xBnddamb3jc1d4qmP"; + sout << "YetgmnmQzoozyXiu95T4s6V+H6gNZVhle4+QmwxOe+zNzZaLhn9GBiwfLDGfBqc2usa2FvFWpqke"; + sout << "PGE7LZ5HN+A4XFJ7NrEva9PtoPUOCM3FU9gyGdzKUyyW2ilrPX6rAil7+ZF6KDAVYOmnQQCtp5Ls"; + sout << "P8t9MlqGMM5uwELErplaidbu2fNVTXvySFbZkga6AYWWV+xQfekzRYoN3LSs+Y0OuUU86RBqyOe3"; + sout << "+5F9fvPztVD8Zc/c1jfY7/3lgJMtgEO/+TDa+gDQAZoalgf/SGpvv66MuiABp8LALajU0a2k1DYT"; + sout << "otIdBHlVLj/JKD2tfHDi0FevTo3l2N2do9NEx2tyMGoLQwuJplvP40vHr3aII9Acuim4GEqh8cXT"; + sout << "XWT5KzFF5L8f+QK27m4lgotZVF1ujXVuQRT+CHFwv5LTiwHaZ/KFgVVY3gIkjCRoP0XRARMk5oIh"; + sout << "F7G9FnoQM7oV8f2b4mTKVvqMMv3uWzd4zr0Tx7GTSgL5YfcJox9my7ibHgpBpBNtej8uX6MtzHXT"; + sout << "SbBoG3lDrdZ4fhMx8d9+/oACjVThf3PPa0exxvme5s6T3GaTnZt2BgzLrSknMaRZ1Rx2/kT/1ecn"; + sout << "GALfb5irpfKmQPLdGJv8EIAB/Br50DqHGC7DfCtZ1BshNTwgY77mI9cus0SGrgmPRXYuuEUVhFHf"; + sout << "OfrUwDNX/L8+ztdvlgJ4NnE5tUnegEz0ApOyZWywtxDodl3AZoZfBV5ODhoHNsCyZCqD/HYZXHb+"; + sout << "u3ZJ2PB7wcBbxSCf3mLSo8Kknq1d87o9YAQdyyNR0s2XvQO9IakDz3HfBFugDO0LQn0ewgxAFhMs"; + sout << "O6hA4M0Q9NS3X+Dr/Roy7wn51BkJhrLvh9TSO0+VCa1EBxo3mI5vM0m+Aji9HwySVBcd9L5FcNIm"; + sout << "PTCkaKyAd3DsqBIoHFhdLhauODisNzUFC32TZiz6VxSwxFnNvyrRQUDidLEwWWcp8Kmm//Q4nq1f"; + sout << "F2Qq30IqkLC9nObGbVSjZb+n+EwqMlE8mJxqbgkh5TQ94V+zQY1ykiRa5a98cFA5MLazMniE0JXv"; + sout << "YqUtmkvvbyFjbAZ+/RPUqG3eQxyvWxBSQSC/zpbG8rJ58ZUzuPkmmXe2z3PIzFpb2BMVi7U9t8Kh"; + sout << "+ywyagKP5gLgUTqOS6Sz6p30dvtDsayhngpRtVfmOoJ+Px1iw8f+3o+JEdJGnXEcecaGIzABCOST"; + sout << "mEC/TQQ5ecV4yEwN/cN695Nu76+cSiHR+7UyWNSVV1A9WJhsaW6NJl/YY7dZDlMYWCtlB8kUCfWr"; + sout << "4bhrudo8xW5XeM1EtmmpONd8HeMN0479Y0KqMHYCnxYz+UnC7XZBJkweyFPa6TroZoqjaIZMEyRm"; + sout << "YizuyvyaX1eYGM3Q58/qti2hIgrRovGl8VYnJWJbaU41OB2CY+R/TsuQo4EcqfvAUCgNxZwwmvkG"; + sout << "brIAJ5+fOVfI4Fa6swJ2PulbJBXXS+fCrjCEvBAqg2j4GxupCDk92IF6cUH0dgmx1TBFgRTNDfzf"; + sout << "ENcymno2Q8HKz6uvMenEfNDAsmV+giNq1jIY6gm09kMwtDk2hs1c0hqFuno7pOrkS1RFbd8onyGC"; + sout << "qoN7GtGpMk6q82mLqrEc4chaIOF5UpPSH3sRa8tbiP44q8ItnbXnMlzox3ZhZqJN6QgrNWnQa3wD"; + sout << "4qp2gUxGkFE4i242i/9xzvfYu8h461f7e3Es8rW0sB5PZ+FQ8SImlpE5vh+Q1qhXoNof+I0fGGb7"; + sout << "fbTZmz4ZZrhCS6KgP6HBmbohfPXCKubfc8Q+F9PCaAN31Fp9ycTV8FAfXySbQzdShgnnhnddJl/T"; + sout << "Vhx2PFPythxz7GEv62hwjLuZdxZq3nCykVHZ/yTE5yoLBaYUvsVVZn9rv1YD5+oSYeJZTbjR6Hif"; + sout << "t7geARjzf6oe85YZdQvODK9SMHYo/QwdcPwwnSf4Dtmvvg5XnSG41s16X0n7o7g5ZgFr+HTLwMzR"; + sout << "k1Zk3offnOMpkcpEXvM8OhFvFiXoN91CoCvin7f2gX7IWrNv0mSj92fbKzQ9Qa7N5cH1pldFuDcA"; + sout << "kQ9wHFmz+rrnzgEl03Br7NqKsB6koxE5YEzHkNLHmg+pTOyu+yDREbopMHBi+jTUReVS+5fKFMLY"; + sout << "LEB6T96s1i+ygPVSt/4H/DQsU+0caW2dCfrm0onr04auJcU+oBYjbR3OlM+6SD9UaUjt2PtNEQ8j"; + sout << "fZocgKhSxt855hYm0y5qvfBLoDaADhziF6Exuh7YM/G6ywKMi04Ab7qecuut0c2bWHXYrQPhmmCE"; + sout << "NXCiCX64ZIGfsFR1pf8eCChyiA/GW+rQvxv23bl1RxbDFk2ZO9o5581NjEPjQKJ8AyE57W0bZjol"; + sout << "8s4S1ZNwMF46UgZEJwCPLuqJEJvZyLBwahHwC1B7dgoDdZ4hjaJdH+jYEZusfsvV8ZC7hS5e8V6U"; + sout << "rkOARJZwmYES8JJlVtfbfpCqBXZ0gTPuBKt2yYIlGTqGwGu+biiNMVt3spu2Ov6MlVoxY37UZax8"; + sout << "B0Bl1cbrqIBC3LN3uWsmnR99chFqEmhi1cRQR7+3DtD9gpkt1Uo0STq0YCgbVmr4azLoJ3kcEdRa"; + sout << "cfZ4LQ+7zlFTAqOhh/kBPsL51ZLQ83ipFUgz+c1f71aM0eCFharOyhIPy6YAAPHpm0FdMv9Q53V2"; + sout << "p7flk/h/s2eqe6M2jbtud3bxxUdercWJvaioQZWMkNSXcP4jKaRN88np7Xc44pjUnPgO+AYc3Kl0"; + sout << "2DXf19IXDjAegihkFzqvAKV4zttcsT5gi8wN3M+zdoG1TcnG5gJfXumw+Bu5giUudQ9P0mUZPc4t"; + sout << "vMWo3r7ajnqeiUaN3PZaItRimu7vzun5+qOEIT5nxHNDuwKVN8ZTigvOrFRUxYkyzoUzUsOPUdla"; + sout << "nE3MUbT1bj2cs6ih5jqtbqgutSNQSGupC8Xve/W8tiQcrfRPBMdtvJIp6Iz+Z7g/SXpunCoFaBSa"; + sout << "mAuCHeTbUrCQ5xbQ5k8G5qIYRglHRADsWO/oM+vWtc1jEEBZ9UwLdFZxIdj1ytNnTwIZWstliRQw"; + sout << "hwNPyxfJQwng2WNJFG9LPvzxxnpCwYG4Fk75wPd05j14gLMcERPyWpwYurfqeQG673BTCuBeF/DV"; + sout << "zZHI0nIV9CLhSSbuo46Vkb1+FaZwwpsD5sBg4+EwKcglHUzrqiskq83hIryVcvSKRNdb2R8VDLCa"; + sout << "pPblCQx811bg68OJ9UHqbw8FPcSu8M5zbahisVzxszdjmyZ7fqWR1twCBsJu+kLjNEsgC1cmC1ca"; + sout << "OLyaDYs0sad92UjkYqqBQijASoQgafhSPuJ7sEoXvLXrPP9GWb1DV6S6mD1ZBRicwlo7pWb6LREQ"; + sout << "HFXvVkX4/seCn3O76cNuqenVr3EtIOhKC+lKDCBbnKZ2ggyoIqgO5BIteeFcs5rz4yuQRM+lB7gT"; + sout << "Nvdwa7iYLDOsYOrtBpNMD5KO95LQw6m/Okd9RKQ9yKTcV3jQFGfAlYcvXt3g7tfz7Zyj8ED/thPO"; + sout << "pcrsgQE1E99T07ozCCDNRaPc+DhgHyyIEHNNxmgGUWziRBKqhsnei3nPCg123K00ifq005MgO5wF"; + sout << "oHUXDMjbWBF+ELX8MWGaMh+a5OWq50xs6wPw1WhPug4vnEJa1j1rJxeGRSa3SsELXmeY2QoAX/9b"; + sout << "S8HeDJq2O9OdVy8bMKdfxxvxL+Q/5m35EGdExoccnU+LcpGc1fIKqLfJn5oKPS6BlNKZnMpkyV+8"; + sout << "z6VQA7zKvYP/Coj8NSmZkBJm4SoUkTBt0hZ9KEXEAAzR7Xf7BzR710lXcFdRCYWxz4KEjsu8z1e6"; + sout << "NcjhBpcgvFTnowMThMLtaWotUJ4KnGU2ys5wMNh9E94+pZNkEQ8xShAEt9le6/gSLLuXpssWJIZE"; + sout << "JR6PlTgQtFoSwYfx4+/iqbMVtbH7f6tBpLbnrIzQXRUY07Uw13p7kblM28k91GxRDqbu0FR3cWnI"; + sout << "nTZBplbiYcOM/SlFJFfThLbgAtTO0RWgJkSr0n0n4duBka7ZzfUiMXjyAHvBAQTGkHIb+YidnC0w"; + sout << "eY86YsFjo3jTqJpLYuciQLJ/ZUI1g9v681JJVICEw4CxgguvGJJNOgQDh2CEZI3AGLlbm81ftzRs"; + sout << "MpKPE/RoI5kzoo43tcr3WCUImlVdrNak6K4gaNXH4J8aui6MN3kmZ0LWS0l/7Of/bVs56dEOozSm"; + sout << "PcGfVLmvLExSbXU90kmknOftcUcTG6J9r1S6PR6qaF2w/G1t3wx4mdberrM9hOIMoXQhXEFShNHn"; + sout << "uZ/DLE6t4Dl1pvD707gtoY6UpHpv9oV5xVVPVnq7smExwf3gQsuiHnDDGrPlz+DmLnkZ1XYCaKSf"; + sout << "EQleJCCuZApUfEl9F0EjFjSDBgPoPpwQ4mhDXFjhYjYhdWmOfANCDsfXm2PUoBm4JJNI5tV6fa1+"; + sout << "S+yaFaIw3TiqWvaOGIA0GyicfP4SaHodizLUVo+yfYJhycZrWqclh3OkuT2SVndIDLs+8Xc2oSBY"; + sout << "uJr8P3Rm+Du02X7X3D94XvNI+tl22SU9CMMfG5E2kReYpkvtVsejWaeg7QKil8gyRLtkH1kIb4Bl"; + sout << "Nu1oeg25ufTr+pyX5nqCNqeStAsIMH6ynVRecIVr0R+AF0+winGiFVVoeedIQzWc2pe1nD1F8LwD"; + sout << "H1NTNdcm8ukf6AA3O4pl1uvLPIZIwC3Qk+AW3u/Xwhjq6AxEpkCGoLrUztB4Xj+uzoBpQ2ka3Lu8"; + sout << "K5XooiNcanVGjWC+0/IBmB9gioabhNXit99vFEdTQCsZzlDY8D9le1IZ5lXIGgSZR5LMPOAzSaTA"; + sout << "b7vXFps/tixvj3n9wxtOTAMREIVaGkJgunwSNST9lQ3tcveAJWnowLztcn861ystgFGuDgW65xgF"; + sout << "Fo1EHK/Sq38jEw9frNatPCJK3eIJRih53VdvB9A+viAX9IoSK+KH70Jo25Bcrv1c+6Gr2H/Rhv4o"; + sout << "quxm7d9XpPAA5dGZb20fFloxerWOvWO+QHghXZjyD81o/1hhR6ZkOiK0trPkvDLcLOfPl++YIjPP"; + sout << "vI1hZ7uTmn6qlN20FCN6P5H2YVJBnhSZrKx8gFkULKgXdT5SpVSR194vK5cwoOUY0V5b4EdSAkWa"; + sout << "3cQm6ibiokDAktBSDS8vpitwP0sLI5K/7QfLuXjihWXONCSf+RobFWqWPXzYW51q43+MP7xb0gbM"; + sout << "39I93rn8s5XNNwVdO/IXzd/HmMHiElOr6xEHMzCUErUuQQQ59NbJ/N/iVuv2c4JZ7ppFWcz1ZQRz"; + sout << "dgtOzq49kBMIxksFoFuEWuRV8LEqBbNBQkW2zFNNWpV59mqwpikzNsraFEvUPBUJKz5JNSuXVLF5"; + sout << "vq+hYqVmhd3UC5RU50BWbQgMgRm1Bw98ScOPxfMMqBFPiMntTDcd2UuFptT066sbXg8slg/itSSO"; + sout << "/CQmixHP0Km3Lz7K5OqVdIr8GYAeP1M3pjxZCEsaLGF8YGrPnIjMyjyFxFhyteYsgU6330hc0TGN"; + sout << "LKJw1F8lbEyGt5lujwnumUFn4Wdg+gtd91ewchV+uCkT8Atrrd5gvCc2AEWxRceQHrUTuW+u6UCf"; + sout << "LZvF/BC/8PJszVQsoy6dHruw4E7UwlX+TCUvP3G/wMeDf9SBsgnn/KsRr0aQmzuIJpWNNULXGyXZ"; + sout << "0OHSxs3fJYq/X1FB2EgjJMXrAczU6aKHAj1AbCUbBIaJcB2umY4YAQqt4x2q/iOJDbJ+k3K28+Ry"; + sout << "vW5w7JdiinwwjNpQkF5OO+4bQ+rqPzJare1PT1eBn5OxRpVlZYpn6AZIt6/La8Ir2iWtAdUu9Krm"; + sout << "9F4qrhoykAlcXnI2GZi+U3I5ftBgWfDVodX8hxzyS5Soo2LFCwmG4aYayxu+RigV4YRyuXdHe3vD"; + sout << "L6Vs2qmyKQ4JE+IbtJQq9jkC2PrFMwfP3nOstccqgnlGMrKMmN/UX2R0emMXsImuMGTn+rKctWJn"; + sout << "sXIysVIEupHUUD4qC1RuZGvdGk/LsV7mQO4SkPJWD1JjUADB+8Cw2nkMVCxn8gFTf1JU+ovxgHl6"; + sout << "nMrlYnN46eRJSFIFvZP0S0gZf18aJaxkgZt/YPxm7m8ZT6NLqLE54+st+e+sWHlP+9VeqA5u4aLh"; + sout << "XEJGaJozsDZxkNfKN9lGMkacM7LZzR3lhjX1o7cjZ9D5l4m40M7DS8sbMg0/5wXn3WxlPimZFLvC"; + sout << "IBLsEAoUM+oWB1xRq58SAPIAgic3jETi8vYhDR3vlN8ubWsXyTqOJKIXJ8O9Tz+aKkgLq1P/1y21"; + sout << "Py1VM0VW3XRnnzs1mN1NLO0A9Ot7NMluCUEJekVYi3jrhhLRG/tGYkLUiYcUq3ozoUZjBI3nX+g3"; + sout << "4d9uDIXJT7GH+ovsh1q5Rw7H9bijVF/qT2TNJH0DoI25kuplXnihc1+PAMp9XseXwA3lPZhxMVLU"; + sout << "ftPlCMbdP/LUPrfoq2MaAQ9oKokaK99YSEumcKzgcNyx6wUrB7ILTO3zqFDPsxTO6pxaPgOXC+yM"; + sout << "mLm+LXoEFMsByb+F8xXsrtCFDwNBR8XU92UgjMQLbvymvFE/fEeYMwOd5VOPcPeOwPs8Jwn6K1TK"; + sout << "Gl/QNu8NQOgVK9rIISGbRdpmcLRW7JY4sp84q8CHBaZ+uVRBYzdFDw8VokNyWestrK06wQG7UOMa"; + sout << "AZGqcttVsN0aePOk13IfTG1llEL2fczoNioOpCaVqzzUSU4wA0QBOa33h1VHOY2Py3x7PWI2QIJt"; + sout << "v1IJSEmnpKgVzYjSnpt7kUlB2dYv7w3qwDeiE47NuK4ywMZAUuONlL/ZMpR3+I/EIu3VJRo0RknW"; + sout << "jKkYUjOdo0FrGhQWjuISkZB8prJxAFrTDxI4aiHSnfMB0q+L2LJU8LMLLI3Ok2NIxWmdIJaFKckH"; + sout << "f3JUFpn784RDpdHlIUU5EOzfqal9flGrMlz03gMrvUiq+9ZtJR44EXoUwLLdAWN+FO8a9Rrj3PUw"; + sout << "E9q2KZ9M6BqLc/YdZFHlpZYOiIlE5x34OqkVqcQOmlWmPVP4bSwbXwmk6oWjruRxmaA2YQ/sLIHm"; + sout << "XBZZQdMe+i8TqEjte84I5EebJRi0oiAzTFUlf89NWonCzSnzdFQPOiB6Ik6dX8o4607SkyijR81F"; + sout << "5ZiusAkWNx6zksMfeZzCxnS93TqiOSmFVrbxMx6BkArmc1IWPCFu0OAQYC2af0KciOEndo00R0xf"; + sout << "wQAtJyKyfM68kaHebkMhRWDwk9HEwznfn52y6evH3n7NnDttgw8PnphD4/l19wWJdZ2l4tcmY4eX"; + sout << "ly2zjWCcaIsx4QMz1gEGc7lvMq0J6RbzsoH4oa6Zdu6kaCN0Fy246vkDDBaW4n/fZt+RvT9YbiqO"; + sout << "fzMs1dcXd9mCOz2yB0/i+XzzG7awbdiR7bafVk+7/8H4y3qVyfGrhilzYSP7qEjF0ossijkvU1bo"; + sout << "OblpwSw4RLL1npCLQDtXh5ugQGcsEesLEBq2uFpxPTSLKWgqTr3P/XL+EwNJgkWDIXdMgG0Ft52X"; + sout << "bNQpevk0mqynVmh3HiVnAZ/mUaRXZG40kqcfQzxtH1gvf41S71sDXT+EOQDdwxygNV9BhOfyA4eO"; + sout << "gNeFlhuCPE3cAzKDUrE3ITYl1++4JIpP3BAcsT03GlBlSL4N1nXi8YuTCliEmN3TRF6FLJ5DkSN6"; + sout << "tCk/xGsBkaZlQXjj2P4FV/BdS//Uqh6EBRnRMoRH7LgLl0QJxJrj9RxDdRwW91fNOAJjg53IIBX/"; + sout << "XCsZr09+jke+Ci0r1NWHRhmF6qNY27CVKs3Wub6dYS/GifyrN++qpK4aTOLkkXQv5v6Zhs48qFAZ"; + sout << "4XORp6I1s7vCQ/yU6R4UeBNoRAT4naUFGc3gvmszzFb0Cp+li0RM9wgW25miE5Y1PMpne0nBvV8a"; + sout << "k1RYRJA1ZyOEQbmFxLMqanwITgwxZT7tK+nSITSmHAo7/m9IaopgYOumkPGdy/KsrjkR2gTq0XxL"; + sout << "OSjhTS/sOrZhSV0HJ4pBgjZMIAlEwv6k9XsefXyREo+rudBS9CiXW6TzIViAWCMc/hQo5zo5tUeA"; + sout << "vmTrWxkRgUmdsj6FGS21aQFPADj3oPrAXsTiH1ZcwpyzJrT42+4P/SLFw/h0NBYRozI01AVVaxQ8"; + sout << "sbJ8VUFliiZnB++Mh46aAF1PzomK1NiHC7n+llPAdEqpqiWqVH5YK7CpKdmCPnTUuOb4a7A2Ylow"; + sout << "rodSKCOX4VHwxzARShJxnQz+FIfVIzpiCFs9ZJKM6x89ybXTKVRrmwTbTAXRbihUKa9d6t1vmoDP"; + sout << "07z8WyFDoeWoSmK4XZfegZeZMIXF3hrDPCYcZEluxujpPG3ZPPw+/pGjBOx3/QFJTQdeHBz/PQ/G"; + sout << "wM7DkD3NM3fVApST0LbPQFKxlEXx9tCL3GBul7YnnUmF+N3OIs8vDoJQqNqVi5ZQnGt2uFLOYRo/"; + sout << "67XsBQv/ZwYNC88VlJykWEbbVDN3rY6i1JLr/9WRygS6aqF2JJBTgLwXpjfvaq7Ygttm3WiT/dxj"; + sout << "QtLg9L/y3Q+LoLOKZ5syIdk1+oQhYnSqUgTNEuw4dcTXj3v2lHm5oxzst2dtierbluemr//Bnclq"; + sout << "8FdWTT2JC9N/YI+CZi6K1lz1Bq2yblxL+BC9UE3rtsq81Rm8kJxSIVQ9HsS2ligaxnEfCIPNvMhV"; + sout << "orLVKxoWuOfhJt++IepP8kap2JZtLdmdZSr9/6suSkInlANS9SkwaNWB+EiTUm5cfu7s90V+m4J2"; + sout << "SEiAMESJKLywmeQi/JcV+8lEiThyYTxhm5huK2dUGPWnzKK0HsdOuqRs0RnHn8SQIgLndjyTzcV9"; + sout << "2mpVtjcwCRKVBFiPdyeqyVI1j7srk8WrRc0GozuUkSBMVk5KHQa3/qZqiV1w9KN7eMRaYiT46bP4"; + sout << "2dpdvqNrms7/qmyv0q+Bxp7JYv8bYqlQPaXgHxS78I7HqFsLvP4kZVWxmhCXIoIObX304w3KRWH3"; + sout << "WlQEj/llet0DPOhF6EPfcuMeSvH1UhucSS2pnYwo/P0BvMYvj6gcbjTowWMk3vmjiZV6a1xq0Nm3"; + sout << "sgrrS/RqYJeGPM6THw80LulT4e0aflcyFN3QQ+C01vGHjpaw0OwIW/iWAMQolVwcZznJPq+6Aya/"; + sout << "aYbiphirZC6fZtUMIXiKRrUxQDyfgI0y8XzJgf28AxgdyCP9Dashm9uelYMGU0Jglf/iwa2K1pkL"; + sout << "FTS7zfGJmz66x2+7agbLuCHd7AQhZb/rfOwauykZXpCE9ooLMLgW+BlKpqgVynLrt+b0I3KqQZy9"; + sout << "KkFxFwC9cumQJPtKGKUigDLV6+orLNwgtVmNal9RFHppuj95RWG3XgwNLLGoE+K5TiLhLY1na4AF"; + sout << "1NXsHL3pFjM2VHEU6cOqVLqhEWwGhrmBoxIGrbmT//lUiGwSLcmDHCvv3Qig43HbRmuQWW90qFzc"; + sout << "Sp7atxofyP8SqmW2aRw2SRH/nvCTTpKw4s9bhDEWCukZKo+bpfp1ti0P4DzSiUxTy9uSjWs3+7Cf"; + sout << "nWsTsoUxU+YdC6rDJYq2TAcn3Zl6YBU9XZl3YLckRqpUGUQ0IkF0Fc0bb5kQCJH8h9qc0Olnqa6s"; + sout << "sjHyjTwkzZWFmuvJPfxEZzsZZFu9qWdk/Pzx/46e3J7/VaQhtyYidoCdV+wtN9IOuxnbyIV0A+vx"; + sout << "ZAjS2f7y5aPxzc2XAmlsmNLBqrRfMm9dvxN+gYVD3UcWgzTVxfIXRyZqmdAvuHoTq7BV8FYPMIqb"; + sout << "SKpCcdEztosRjzsCxioB0JPCGxMhHgPU4FCNBy24fa5RIz3rZkR7xFdTNw4LTkJ15BRKMsBxHaiS"; + sout << "E30FSLCvFvBVlUITI7aOrvc8+2IgX5cXLxHQeqaSTyKm76ioORshYyRkp8K1Ao0hZXCW8rks9fxm"; + sout << "FeRen1ureYGmkoKJUWq92iAPwNsFvvxTm4IprFoUDn4s56Ung85NTvcEJEyc4yVyevqb7hh4axAu"; + sout << "h9N2uFPhx9YPLbOHA9m84PWgt/IPOfdwhlylIL50MZ9qitPPG02an2gjybEvicqckLYZt4cY1Utu"; + sout << "pslZJgCrFZ7ITUBixaxhkuXW1wxwkF8RGosHdfB5zycnXkvXhdV41+QDvhZ1tR2E10jwqjVR/7W+"; + sout << "suB35kUFqrlvWs/bBTQHwwPdNmGiNFi+uq6hIBQCm/J0MbCaYcPPfhtO3LIsd2KeiTz5oV8AqBVu"; + sout << "NPQav6McdWJjg6dBQw4G+raxkUf9Qf9TPQspS2Ll/Kmt8WwbqiRHutPOaxaBBeE8Igx6BT5kIASc"; + sout << "OsAnqIzudE0cclrkxN4aeTglqmwwq/eeZUSOzkv6Ge0xtpYDgqfGbe53fxtYvItYYExMgy5HdojY"; + sout << "1zDaNXsfHTebUUjPnsXitzIQQpRobTCe5ttWhfeeXOvE+MPLRt19KxSrFKHaRlQtjYsmxdg78zc3"; + sout << "8ceNlnoQbh/puSKv+QWnnELfDnmgVZnU7lIKs2xM2B+pNy4Q2YX1wmIqlC53rsTzU/Ik7Db7Z59s"; + sout << "i5QWZ0R1ll6rdkCsUCTkK8aqsF+mSvWNbMd269Nch0fSx1D1q2NN/4lHDDoOX1JjLPMxVNbdxLMt"; + sout << "gP8gpiQrFJWZRmTgYBW+biZPtiKgOztsDrZ+eolNWpsGnLTK9Xe11lZrS2fPton4L4jnr+R5AnHk"; + sout << "F6tBVhSBRHtZdF/xNcCsSOKvTOZJCtftMO6XaKMfsI/yexeLz41D5qfMVSBHY4bIgL3Jgh952tlH"; + sout << "2blbe7lcZxZN2XkUXXaVaX60XDTLgZUbGzqJo0+x/TJIAuCAlGF1XeF52jTY7X1veE63l423NPsf"; + sout << "kCmiRcepJT1ihXQDd/VU37tJfeGZrN8FjJzx53PVp70r7eLJMedMbBUJBYmOofO2F+OU4Rfyab3Q"; + sout << "Wy1ngiJXrfjF1DVqmKNfY79SQWWjlzST97ldjTICsyQdHWn8j+Bof6sL5k/3d/cT/sohhSSl0Bux"; + sout << "GssAFa27VYsJYMjnriFZD13ZYmjuesSXyQ+EkUsHn56MB2xEwLRVEhq8ygDP2eO82PnsIXh+uVFe"; + sout << "0ZQ+eGSiGZtqeL1aoePOA51s2eE4DxGBHWdSNr/7Nc43oUU0P7sV2lut9VmhgDleV+Du8IaZUMIb"; + sout << "2MS/TkvMW5S/ebAoXaL68zmsU32tk3C0FG5oxbdZD1tcZ4yHNJ38HPSx39uOfRG+FrmDqomfs3k9"; + sout << "pu6KaH/Be6Urzd7mr4EvxaXHz58hsJGVH31tOdo1jhg7i+6lCJAuJX01FWfnClLz0mSJnIrcPge4"; + sout << "zQrhghg31+fLZw4FtNmiJJcU/hJPGhUZ5LDJWFHUi04u3+dfP1pYR4e+CEOyvlK2YXx8CHcvjWJ2"; + sout << "kCW8TMN9I5arBUTXdFfOg/v/EH4cw54bmTgzjpn10HlKqfJn/mJPPbBGCCS+pdGmNIzItVME9IES"; + sout << "Gco5a+DT/a/TThfkRVKhJShe7ccCC6uQPdU10P+DV2KPxOfKvar2ww3I+VZ3FKxhC0AcMZBbN0V/"; + sout << "0tTOJVH4GxjH84c9Jr6niQJ/2w4G6tuTcJ2RLb+HnzfingM7J9VOr9C3IfG0jndHCww96Exjjk0j"; + sout << "bE+Aq4H39i5pH3jEvdMspm1B3VY7EN2s+fC5jpmZpA5kRb6Vj+CYCbyEYNicyAJqCHdJLrtHDiGN"; + sout << "fIhpfrkpc4+g5xHHel+nnJKgduqzs5eecM68vOtBTdmLa/nxUz/VBywqWNmmN0v1/g0Jlw3U+IGb"; + sout << "bZOvv2Kg0FO5rJmnRIUsokPkcVxnfz9iyf+l1d2QIBmeR+YeNtLS3XUfIjHvHwSrZ3lJdfXNnnwr"; + sout << "zZcKmHHNAoRLIhM/0ZYQXMbOvfEqUBZ0GCzMdwiEuNwVPtbDkgmcWjYR6ew4YjRZR9alIas8o+qF"; + sout << "GkL+PT4LEHfASjQPnCZIhUISIpwzuf7Ii1I2Fk/v8jBbL47msvPsQcVH8JiFiWBGCBfSB8IfSSw8"; + sout << "sgc4bo2hwHMIES5W+rXXoQMSkGFP6tUfd4GvS+8t60UyDUePQbBGj1lt3ETsinD0btygvMOqDhD6"; + sout << "H+E5gRWBXOV6QrXtQSPG4ZSPlDvATmDdgLCBtQBBMTLMg36Kvia0W+Yyu0bFtFEJre9N/bmyyvRR"; + sout << "P4tHn8uOTi3ry5PT2PQiJ9c9utByvU3ydcBhoIlcdQz0f1Rf8y0eyKj6qt1/OhnzVXuLrkqXLylC"; + sout << "sthAWqyIavUltofLnsAP1uMHZ0pHeX6R1lwpIROxtRH6p/0+//OK4lIZx+7D8aMPIaefQAB6Cf3C"; + sout << "jKnvyCpR3QEh9JLv22OghKwswHT50P/+z2XbyoQsRySXNTUbTsGJuYBazgG10fe6YHQOTrUDUp24"; + sout << "6PS/wEco60z1OF49dVlYak9zGA6kTRndkjWDu49NFhJRY8sBd7TVxeiU34NVyjIDLrLlx7DcuZ6P"; + sout << "1/XcA8czU9m2n6VDgkkAB6eaeL0XhB6XZfOrg9lY6R+xjVYo50Fg9A8AGen5T6+m0PeUr20e3WV1"; + sout << "H4KANYIvY6+zN+Fe6K37VaO+CbDZfvFMzifvzTQWQ3kDDTWX/BLfFgLGZ6QBWrTwF7MNiQ6fCpG9"; + sout << "BAYy9V+QWy4Iz9lqnp19J4Q+cqlIUDBp8b8vBNyrOjjLjAC7ezFUujvB/7RrKtnbaYmS8vYF1f9+"; + sout << "mGqA78owo1zfad8DsvGEnr5J5mC8d10rXhVXXb/udkiU5iEhYPnSxRy62tgLbKHJOvW+R28r7lpc"; + sout << "yDK1++NYKRpDIHYMuaZ13oDNdIoQD+d42Su0NsP0wECAoMmQKkWUaJyYUyFqzmQ2FJmvl6hDSbrs"; + sout << "lDipexBL2U3slCAJbX/PyvE0KrBPBe+vT/w1Z5s2GBvSoGwmFlN/oa5I4TfcTA5W4ie1rBHUKqPr"; + sout << "az/4un38eXTaF0Gfiw86pwlgJXWr+D9qvXQApo6KmaJhjKo+4/MaPw+iEQrU3IM44eaqq9exmiNF"; + sout << "/SgBv76gC8hXFrucFFp5znYrl6ISQUedvk81jnI7ce0Up1jYsh8fpfp+0V54IUNKqxT9YlgfIkbG"; + sout << "THuQQ3p2F3gLaflJWfZEo/lPtKc6RqVadizBP/oxl7q+zTriiyK1JWiJojr97xQcNpj5j1QF6uqK"; + sout << "+7NQgfVevfP9FGVKOlIgvsqc9fnqWd7pzWKEt+kfjbpMGpr0XFERPld0+EqLts1kAj/ejnMymrXY"; + sout << "LF9QJotJBy/QvCuoaxdptGBtrC+qJIuoYKTNTwW4A+kaf3RrNr4ohuLBcH8RRpHFU+3fEDW/kOx0"; + sout << "fdRlEUJoxdfK4Za4moLfsmI8pfl5rBvlh9oumT0sgvPg1P3zp70UXJI0tSHisCb/uK+Hggdk5AJC"; + sout << "gC3NM9uf3AubM+CW8rJXM6qz+xwOeTuMawJTGTrSg6bB0jREYOXfvwZ7oNgiSdqe4+s7tIhd21Hd"; + sout << "A4cJw2aRWk0ft7jp9O/6A8II+hJoxReOnEFn54/kdM5LFasSjd/wwxsdtGevYixIItCTJ36K48dU"; + sout << "nndygDo7qFWaryImEjgErq5NT9VUfg/VdIhCfAro6xE/1C5mMZ083LOjvTdxdRO2qtgKYCvL3HWO"; + sout << "Q0k2/WqcaAKyh19cHsnoVKd1gXCummMSRahTGpmK82fAE2vP7MThuNel1CGti+nKVvzu7Yf3i9eJ"; + sout << "rUapLrviTwh0cQzVAWHXxgE1PB5O1Hj1/hxwRVStuKNsPjZKFmNu+j02TT98hpYCFQ3F2hE/jsIT"; + sout << "VN7hr7k16mXUqWbCzNWO/xo86NyCS/8HYFXMomE9KRLlP1cw1KrmBltivEnesHGCi9sAyHnE9bcX"; + sout << "mWXL/tg8wP96Eh10KeJKzx80xfSPz929wa2Z+fAfjHqVgroE6AJQNKengtaLjTo9vBo2E4mFbs+O"; + sout << "AUy5hDVB7Mm/koguhT3+BxefAXoUbaUkNlIewCbDUlLIT2o7ZLYhoR1WTBv5K0eQIt1wTFytuCVq"; + sout << "Ldoni/6ad1IzqamnGBam4hFQlbQgpezUea29Idj4dhUqTmksk0fV6sXYaT7P9ELB+W/OnF6IY/d3"; + sout << "6TUUvnSPfyfOHHxts3HnXQjDqrRhPeJown/p6tzdtytVCLeUtXu2o1SgfGJrs9uAtT5+3mt/NjFb"; + sout << "Z+VX86HZQ0A5XgYWiBDvd/Lu+ptgu5kMX/QAzBZp4Ubxo/sF57ZuA3TJMx3CjUFgNBwHgdWsiMHf"; + sout << "X+sBqB17McUR4Ar0OVoRewmYlRL/J9vN3ZJ/P24Ctp/7Y2Ozk6vJZBIKnTvcapys6mah6rzYurq1"; + sout << "vQ2xWL0+hq2nsbkZ6bMl0ummFaCrPrMmBLGAQTa0Qa6QqO8sCtepBHRzJdCH8InI/jRBJKrGM2jt"; + sout << "s/bM1qvX3U8jLY+vGwt8URt+H9VrF13rKxlgTJmu15oC0duyGyc0ejeDsopX4NkOQ6xScBsPAMV3"; + sout << "ObV53ImrQQxdypGM8UWByxISLsLxMDDB49DhgBvOqfJCJ62M/m3zZNV7PzUkUcI3iB8QDpovj4Iz"; + sout << "CmTXnRgtsARcRc977luV5QiiofiyUOVQFbBY7obGei9EnfRspuXIwDsShcRIEHEz49K5SdToZ3Ky"; + sout << "XLZUEdbMsZDXWjmE1A+hN8G1oAh+fkFoXav5S1xvmYqr28vAff3UhxXH1ZKVGM0ePGoaE+AgPIuC"; + sout << "NawQwcyWFoPqLjaWg9bgf1K+gbhxH4ot5ehVVMb5YiAbMrIp03ONvVFLGWzI3tzFMMTOemMgM7UI"; + sout << "1T5rYrWbm0q0IBIhEpRCpR1Smtgf3wK4pUX5y+1+xGk3bIl+Sk2Niyl0DmrmqZlenbbts1n/hDyX"; + sout << "ZS+148JBO7FJvG6L/oD7LaIzjBAPMWB2TLARdbc/ShtzxMY8l58fxDoOs8gZXcpLPHpdgEALIJWi"; + sout << "GHRB0U1osylCUfS64GUzB/mVA4677H/n31R1WaghcDrZRgZs/aeyy8DIO8fprsM8MwamVSDoXU4y"; + sout << "YNyDupAJA6GKeYC98bCgZ64TdmXX3pDp++8TOyDJ0VDNrp6LfdqOaMNZBoG4G8KduFcUssKyGs8D"; + sout << "mGZ49omOt6rvrlZpGgBxi6afE/7tg4ac2hHd368gOLuPyLY8UKBEqhRGG4POlbg+v4AEegzLQdmx"; + sout << "GrtxnwcojpDFU3k+HPD7wQwv6dSpzCGJy3d0682y/x3muJLZ/bQaZUV5yj2SCfgSIaIZcmVW6YFE"; + sout << "dB6GXs8kxAZgfLLPT49+5jIL2XXkcmK/LK4dY3ah3zGKd/Gl+NUxbZmgvOzG2LFmPvp1Mr73WZ4L"; + sout << "q/K4XId/B35/LLnT1bboTQ0L4BaJR0r8uURHEUO2fo9HtByhvRmaDXy/+EAxT01G0sjy4Nhwd0oj"; + sout << "AZrryg3KSdCTwUoi/CMme7l7udmRxUMSg+L6+pTXgONqyBNDwkksQ5aw5YP0q02WH8pd4kVHHLkc"; + sout << "oVv3diXYApXi3VBhtVmpyjAm9xy2SPL/iaoitL7/0d9nIlPymzGh3Ko+ghijUD0Ft1CJHr2pn4QI"; + sout << "zWL0rF8DTAP2TcD1dl/I1VNW6e+GLtWgI5+/iDElGyrELfamWyu63q14Yp/Uk/PrMTsxRxZBzZ6j"; + sout << "whO3UnC70N0EWPc5M1q2HdWuUadl9tbXFj4eX5F2OABfnMgQHDo9F+pux+TxpCZcNFWHsd3frjF4"; + sout << "HEg6qkqxyaqgZLUM7e90wLxYi2/XV59BWFuZC6kVAkBCNWMZm+jisGYdf+kLWVrbZ091mYTllIpt"; + sout << "qP2NaF948T+/rfiT1TbhmAGwj3NI4dd3azqqJuobPSjj9pdK9JEKhyC2QlkZ4HgRqyKTH0xPA1Bo"; + sout << "mBCwzYhubIB2Ro/oxS152TwESUPDvZXTsDbmJXGsdCrrmfQy4NCXNQEDQrjrRY40qXKi9Cxr7wjn"; + sout << "W6NrSshxpsV1NwlSGT1Omn7RUMTpe1JaKLINxTUJCCPOuNCAgbkmPfB2L/vzOq/PJ3/EGbolcvCc"; + sout << "b9zvJpUeKfNOK9oqBZ2dZQqgGAD70uTPitKu0/5pA8I+14sLkIfVpVAiJI/54Jl7cz9lhMQ/X6oh"; + sout << "ILDtsAHQjcDXP9BCfByca38PJy1k56vOjfg6Tuc+0hPcDkhWobXqi5xhJaS8WftAOhqwiYZKsimZ"; + sout << "yFuYCz0EeT8EGFa4APSgjPSmsV7jCpOWoY0RtUYvhNMLSmFuwOVOiUHLrlsb8gFXQ9nmw/K8hxHc"; + sout << "9kP96Rx3f+y79BSrhUcQyUrCSV50Dkvq07wfcAEvz2dzSKvz6zTDIKUFxfz6ejsAvrcx/7UgP0i+"; + sout << "1rtTnLXH/Qy6rlJAkxXknkQQ6YG8egWa0mX9Wh7RxpTIBHPaUU8gnIxo5/RZEPOqrF9DtYoJMOkB"; + sout << "T6davyvC84bSPrk5QW0DMqzZtHDDuzqiRYZt9PfEvURu7iI0IJxcVWYQcXJL/ZCp2GVS8689pmYU"; + sout << "zIse1tFwwd64kJJfyBqN/vW9A4aI/PutNu14PwgPVr+NeH8wQk1L6MhOvcnJwYDWm504takuedAe"; + sout << "hB3lePRnMEBawImStuD73hcZV1KzTKFvu+ebNR8414Wj6gVJYfPvHFv333u0ROqJ4/Mhb8occomU"; + sout << "iedo9jM8ZJCMNbPF2Vxxgg0ahEwTKyq49Qht9fhwUTm6pMQKUXOWU/rD3x2BwrKQ7NZoOUqTlfGI"; + sout << "/czGxGeP9PqmLcleyMhHuC4GdlvI+ulMTlngO760teQeuF7EYqbwwuwqtSV3E/pbXm15OFGSr+Jo"; + sout << "FrOdb5/WWAQ7OZPwdgYiWxbLELUCjNu22PoVAN+PlAYShXG/qR2lO6I8Mh6TXZ2Oo1tP68lGp83c"; + sout << "zpgoFpNuaQzLCWIjC4ka5v5k+Y06crZnzCzxaPsanjgOtAlQ+BOykH38ErkDlM4M2SGs8IAZXW0R"; + sout << "zHgZygNho43FDRQHPiYxFtGV2ktVmtGNTG9YtWQsBcnd/T8xtUP8DI3tR3nCg7Q9esbsZOgbzCIL"; + sout << "bbtn92I2iiTkjuPSuuYoWPzVS+hi4dB/BQWeHtnwdNAmqU8IAJki4jNkjh/6NfTgVepkT2Nc58pT"; + sout << "V2DzfqFrhOW+whXsJIFkMSW0dXp5TCuJBbXTIiChMjXI3c+Dzes1/6CF7l2lFA9Ol+AiPwjPvhhQ"; + sout << "/9WCwWzfsmp3+w+9nLLt5CtWRePIB4LC4nugM985fxf10qYES/K0vxKR1W/Ox90s3D4aG57SHpsX"; + sout << "8frJ3HU/ouV73ZOtCg/lOfdocNCQa8KsoKhy0T7tDnWoKJvk0tT0RhFWltVQHV+sVteGXOImaxcZ"; + sout << "MXDU6mnRzp5PEkQnNnBDeynPvs0vWfVSYpzQnO1tyIM2YYwfyf4vdn82ikk/Bq7onvdohQH02/jU"; + sout << "VQjCfYj7X93of1bK2ja4vdEtcFpH34YTnSyZ7Rc+Np8BjO915a5XsUagea3ZHQbC9bWM75QKp9f3"; + sout << "HuDpEn4tNg6RTcwWf8s2G2vZCfIDkH0geal//39BsVC9LqIyEF6urWV10NerZXv8fumNmX/5JraN"; + sout << "aZ4GGdnhHPnryfHkHG62ESpmUCBO+85oDXrZcI1riSFKCYP1eBn/KCWkjMf0oV/pYljo1KrZgQay"; + sout << "6vXXMdX3Ur6UUBO8iFht7ETaE5BeQ+3CZQEe3tdz1Z87V8rI1Qe0Mkq40CwniIxvk86JigLkt1Yo"; + sout << "iQBWhcgJXCdVPSB1XpOdtZjgMPQNJqJtcbu3SDHsWQ3TWA6BDtVbHH7xY7CPVN2si2K8B8/Czxmq"; + sout << "xptTLBIMfU0xaBZTXLSmRKNTEj9D0H+PwHvzPper0tQU5NEs0tZ5h6FZER1mgFzEmbZ6PUVgc79f"; + sout << "ZnMaLmKpirOVc7HwFfLuwxAlvZRWtykLc8AvkZAp+bYGHERo3uHgdIKkeW82HQUKsUMN1VglX/YJ"; + sout << "gT3Htmy3dEuYGNh+OKulcvAoRvLmTMYgACr/0ZNWoYjqAeilI7ZoqBWHYswLrvefHlIz6WAAHwKC"; + sout << "xkbk1H0uWe9IRcf0DgEvWkXSrWRFHSS7uTePFxtLM7OQW/GmGfS3wt4/YLomS1enK1L4tH8fEPoL"; + sout << "ZeZgjNYiMg4gQLm0+FyseqNjiUnQPkoFQg1jztPOZSxxZToCSdeUIawgr/3KKQ5+d2w3LEOQNAy+"; + sout << "Vf+o5CjsfXpYJ8HVxpLsEC3xwVEzhpeY/04yUW2ygAfZqbsLARAAWtW2GcPN77MnGAmm46PeCSe/"; + sout << "2QS7gM0ygtJewBCaxSfij+Xg4HQhXiKDGuwDYqUDmLInHskRXyVuuojyCbmuNmtyj+Uh+HsdFeOU"; + sout << "fqTBHfEQzzAFVyORf14oEhGBwDuvMzLjciKSfLu6KcMCFo0TGgW3Niu7kDdbxOxg6Kwa4wiaWWk5"; + sout << "mJpG6uhNFaDqpfVLY5D/o0Hn/QfMZwR/MzmiFjstPVEd9ZzuRvKKSDOM9Ult2lTU6z6Ci2b4unu/"; + sout << "qfLIE8CjUAZrZnEFbC/KWx9I9zuUuoXS8elD9aG4if80uCyJ/tT4MXjx8iRsbvZ9B2MoFdO9gAUn"; + sout << "qafo/IYR+5BOderTedEwe6FanPX4opAcpmpXM/cbD51QA9LCLyNtSoSPlwji/ZtYA97UUV6k5c47"; + sout << "TNFB+QtUaW7/DpPZbxsU92jE19ztQ8VV5Ylm8v9RdvZDQ6T37KVEYu/z2bHX5L0r/bkEtQQ3gf/e"; + sout << "TRfpy9j5f1lVMpN/UccOtsZfQpW1WvDcbAbOjcuPuuwNOudfzztYg8vjJNn0pMATK1+p4adLcxJ6"; + sout << "uCh+zVgxAS9W+z/t5X3YTMnFDULRXqs8lsN6H/1t+gVBp1uG04BBDO8YMXZsUB8dLyb5BHbkjPiU"; + sout << "c/zt703lv2U6So90+H+6hlnWDBvfLHI1W1C/RkelOCD8WVZ0rk/sgI9UYf+jk/SkIkZgbtVTB5S5"; + sout << "gd5IXkexjN/kP372AvWgX1t+uKVmKRCrYlG8TZzDrjRD4BjS1M72OGwm3MDT6GPRcaGc+jtSGAxd"; + sout << "eW+6BPofMujBNOBsrxKejRcLsddfOrGTBqP0Wc8B3cNOjEg/ITJ3LOu3FFA5TyLh0BvOEkuiD7Yf"; + sout << "SpB8L9jDYG9nCcgFC19CA0Ji/Oc+aYFDMiDM4be+gG9QUdAxk7Yg13HdJirKRiOvJkf++ayrJBbZ"; + sout << "xs8NEcd6lN3XoogW4gChg9mwlUNehzzjIl56wG01T+29FhD8Z8ebywHq0xEbP5aVGrEzwj4HF7kz"; + sout << "erUrxlRZbmRgleNhldZDjqUpMcHyJzF+OxZ7W2AfpjJf5FaF3E4Rb9IG6vBtPhZbk3X41VpKx5Ih"; + sout << "DedeX1BVHzRYAPt6JchdpOa+F0Y4Evke5XR97NecmGoYXrSkLrArSXDu2qqeQVx6XJ2Qginw4711"; + sout << "fJwfVb8QMMRouBWEYz7k2LflgRG3c3XL1KP8npyz3yNw+6b3/K5JSq4hRFzXfsIWoLArSWgfboMa"; + sout << "w/wIVdPgwPnUh3bfYKEXsQSenrazuqhPosJRMsX8qdZ1HB9PBwMKMHqDessyUX8u4kFMZwo1egfw"; + sout << "7x8xkykzBeY8Iy0y6uOVpowXPgEPg9PZnquqW/r7qQAiqgoW/4oMU8DZtcfT9fTEN8BI0YPPNPFY"; + sout << "1N02ocPS4kd5V8/bcTWkYmtCWIqlDtyPCEcEoBQbhneOdw6EmOmTqJVArAwydyN8TLsJq5A4WOf5"; + sout << "wWO407yhoizBhRjGlqsp+LW9DlG4LHKLWyKCXb+iozkdD8Bh9RX3iYz4RWNSm552G1u8YQtZkeN6"; + sout << "ISOxeylEI9OFkxvueA6u7juBnwpmt3sjcvCOLbbwwpc0JyWfAhiRdBfC7aXLjHm0NcwiULATnoJ8"; + sout << "AMfNt/qzosP8LYQCvRYwJYShbskncZCd5amEmN4eStNaebKyFX++T+XC9EiG9FysF2VZdRf3OhzG"; + sout << "dDBtLbkdALyBU3A6GhCueeZ4c5vEsHC8RMH5iJtYHXwLc0OWzUPC/DVlhVi82avRChrbJQ55qHza"; + sout << "DBnr3I00eUsfGxijrqaQ4bfLLv6S5e4jP6C6waRbf4RY/Q6kTb/9oECBuXgayar3WIvLp9txdLEi"; + sout << "dc0sokRgOMXjrheq/gZ+WHBm88SLRpecDwc4D0kGo7UQC/PIajFQ6b+rTJBKwYxeaaVpA2gaakqR"; + sout << "BeQurXrdpAW7hWwBHjzWtxG/qKwO/w1+x/zAICr9N2/1Bex8yhnnIEcltD2qN5ykzqBzrKrbAzaT"; + sout << "+6+ZV4ZTvaeeR9ELOBDYzkkUF6cgTWAZnwT0dKySp2C6FSPkpNseYrrgwg4Ddp49pN7ehwCE6V99"; + sout << "ebgMYWGWqJkBJ7ID8NSMDFjOTvuWzz08QpaygfGSGgipHcuEKdHEe+waRMOMZSPVD7uDdV67805a"; + sout << "67OTB9j4JiA9zhyXqx3STc/ICLFkSjZbl0VC5XAyVSrC0ssNFCVLp7OQ1REJYnj2bCPMpZCX7VHF"; + sout << "RY8HRrA7V+X0HGBVPAcCrBn47lFvImWGBWGAWFO+TltuyKwvuiBOU1VfG1mxFCJJj/+dv7jIdDp5"; + sout << "+VEE6pCWOJxzveqnJ31PvwrlGjxY8WondS64rNe2qnQ+grQCp69T4vHe7zuYfta+QyjtQxHiXSBV"; + sout << "Zyu+R1M648V5jmGpQUT1bek5FLy5LoDDjkHITygysq4NzrytxR7Ncmgns3tdB9LV5zcHBG2I2gzc"; + sout << "lHRBaE17Z3Lo6zdwTlGpfumcYid4Rlj7OgZIZTz3ogfaWU3OblHZ7Kqs3GGvctnHnaEtK9oqQQFP"; + sout << "zEvAG1nL5qjUKB70e9YgDiUYKAC7xyWVoB0BCIjIXdJh8JlwXMRQDQwEwAeV8FAYKQffZB/2JEeR"; + sout << "NnlaRxoQVuiNom9Dh7AokAEHuitMI+KD0n/bijH01pr9cxGoz9uimYQwc13mw+ZS0GsWeeEqYSWV"; + sout << "kUsJVrisNxDeJMKW0hQuZmNTkjve2aX70+V9ouwgMlYQusDxVGugu0HNUQBgWbnE7CW9fh8GQV6d"; + sout << "42/8i28kLjJ5wiiReXFbXDQ1XBz2lwyXXm4ITq5e3tqn/BDX5S6yQg5yKGl+2pISvDvZipWVNj0i"; + sout << "vNeVWDITOZw8OAceKLZJ0WDSaV5W+F2Io89LI4mdzDdvOk/Ct9OpxFAacJ+H3Kx2S6b70hcdrx2u"; + sout << "0Z+LB5tDRmxItNYMfz7br1fZKyAO8hUxeWN3EF6v5Kh8t0R7ZZaCEJgt19e/3sFnR8ssvG3yrdXG"; + sout << "RT6qvp2JAUyP53dm4Z8+6e/k7YA7WR1sv7d7omc56Q3LXhpvmJJPbAJ/qBXHog0pO1ypN8szYVUo"; + sout << "MbTlTlz0TECPbqdThQFSps4oXTeCcR+5fj7SR2K3FVjbxHfWhrUfw35Nh7nPfzyDHwfmX+CMdwoj"; + sout << "oosswTADZSwvEh46PBffUydxEG7X70qyn8VhwtZGHTFNTjjncB9AXeEpYpqxo9ZzOXcCM4JwI+jC"; + sout << "eDsx+HOZdie9keAE27XUcvKhNYv+G5Hq3rAVNe2GRgDdt47Ysz9vECFOBxO4RVAl0jr3TmfKoqmR"; + sout << "3WBN07KGMf2q1osu+RZ0I2wAQBwQOFG4c7IjxqEh1tlBsFUQBE207K8zhrYyURXulKfdHL+iWMFs"; + sout << "LxDsVCo3kG1L1WRBXynobD9DGvbx3EJf8eIgrKuHSEF46SHUWs6HjACerAjryYxvgHUoUez5cD88"; + sout << "TcQwFBS662AAs2ZAp/uA2Y5Pc+mnluC+5HYwxYI1hh9TVEQPXRPLccrfkitZYn/kV1nf7AzK8lEm"; + sout << "lavQkdl8e0rKXTunMcuayFgtPCyqiUPn16MkaQBfaNM9B+gw/fFeRctuRq46xbF0fTzBS8iQBkoB"; + sout << "jTV7UmXA0Ysi+iVFAod6e4o9eyW/WEhGUCrAWefyUqbytPJDH7/1VnTdlN0bpeXXSyqWgTpsWYay"; + sout << "GhdYTsPDdfTSzkEU/0wdKm/w8Mjcf3+ZcVNJhPtJpEXzoFsSL0VVGym0kDsJooOizfSqJPyczVD/"; + sout << "b3vN7al2kuDcNZjVyWiZCfPyED1UQTSQFVZBygAfmjg/spVyg3zbiW6BjcDuxOxOHmBGXuE2XOjM"; + sout << "Qa4Qcw7ioTk8dw1VobafnWiGN2S01drTqLFry6ZOG2es7IejQh6HsGig3Iis3xAxWWGS/Wazhw8v"; + sout << "M3OQzSmbW7j0L0KcLOMtSxiYo5LHXxhti/pAaI0VZSLkQYgyMD/mBdbs0B8i3diXoAM/BzqfoU7L"; + sout << "7SNslWLwXXscPrJsCwFqe2UiahyQyIvf+qXqBhKuFnIHVt0cF5cFAeEh1kYzAd0JDbi/LB2hYsVX"; + sout << "rosavz3xN715cRmjpjuQvlIDUGfmmJPPvUIbbXiT6q+TrDdQaPw56G7vz5llNd17QS7UUZymWjY8"; + sout << "p/8bq6MieVzWeOmiBPbmWRtmZ4GPTc0bmau8FH77JxwgrkdH/7jDL6uhcd3grTrsm8X8ZX9FGCZy"; + sout << "QZQ52pU3tWbfy7KsWnBZjkRwzZVRY3EVjg9V1PMQYKAqps7mU0LrDBPwrgHW0c9EvuYa9SIShUgX"; + sout << "JtiWswoJVbbJXltdhclBmvuDsyK/MfLJrSE7qnPOD009hILCojdSCmqdaFgpmw88EiYAVDqwcFzw"; + sout << "EiDdvqg91OFplvWoGEQx5wsjOKaQst1Z1LSc/C1ZZH+mmBMgHfYmr3ZiX/eYUjGzYpdKGa00v8gk"; + sout << "dDEkUCLQ1I0ayiaz9cMPWkZ3avnYb7TdUKa3iwERVVFEJWfhpWe/4KP5qBL3Ih8qLicgtdkt6FOM"; + sout << "ymOpNQmfe2X+P3U6042OQPIU1vmig0euGeKQmsE4K3NoSOEy2QheNv/2GSSNRzc925zhjAk1zxXQ"; + sout << "04KKCDA4HJKFK0tly5P6KjSVTiQXZ6e+IWH9CQfAJWumCVtHNCulD2epjCPlTEjHH332gl5c7h7u"; + sout << "eJRVbsJKZm9z1TgFN5w1iDWI1cBLZipzM0BznWEvEoYNixyZot+Yfm6IC0SZi8DA7c4Ngn9V0ydg"; + sout << "c8JNp+A7T7LqB/LWgma/E4+wv/cdoXiKi3mwg71ZQRM1d5oWknFfm8hq0KNOl2RsSOYO92PIF8Lp"; + sout << "kG2KPK2nBU7VkbNtaJwjmJpadvkAEPPpCbYDdRjrvKaSw9yehsWKfLVuCLQHlO7uJmGe2LHWKojj"; + sout << "BtYYau7+Vfyjj8K4eNqABvny8hj8zV0HyToirIxVpBmH/6/8p2imqYx8hc7yBE84LSNCEac3UJYf"; + sout << "8sxZluHs/ZD3t6s9jVisgtx67iv8u3A4toAOlYjBqj/RXVQRAb30MzAFy1YCdxedXIvKBa/jXMwk"; + sout << "0DsQ+YcRsnKmnGQ2yOxNVfiJZIg+KqN2Ork0y2QNEy7aTv25OEUM8U2wKaRkiyNYh9VmvZpYk5xM"; + sout << "OPOsReX2FOJOFTCUGZ+w4Lp13OI+QLc9TBllim4bhtbBscCMN78WOpnUccS17gXHnqtpK1Nx7R3N"; + sout << "dwpVTYBCxVVUSCHAREPvmPlHj7c7nnIfO5LRxuGxrg0TfCBUwCNiJvuUpiuqNyJha+eJCPS+SlXy"; + sout << "BN9m4jvEM4KGbkGSwelXon6LWOKq3cya+8xOqQ6KliFrdQaNk0Tdg1Mi2oRjBlqjyb55oXYNeRgT"; + sout << "wQ8wAPqY5TM/z/Zk9ELGeQz7JjS/ru5Q1LGNUU3tL0AjqUdU6adovph7CC0/g8lX/NqYg1oRT575"; + sout << "7m4ihHqVTIN3pMkB8TJBZNAa3i1rlom4qcOFGdjTi4WTRd9r3ll0y4PB7V81SddpOk8RPTOIDKR8"; + sout << "pgIa6xFOQ/GTIaB8i7uLax+IMs4dIq+gzpxE9g35BP7J5cbIYJJeMFIm92l7wJB93ANTT84yQABm"; + sout << "oBnlfu/fjcbGvfHw41CTRpMWraO5d9nABxuIX4wFKEzxse0iAVsJA21RbpdBtIHLBcbYKCh1bUW/"; + sout << "hRSk1yislMs9Mtjx/VSUn50NAompaF3NLpG10/hGd7pp1co4wZm7+fiW6sIPBCqLtkg6A1+qGUxG"; + sout << "nUU6NqUvmyWo5B4Re3mDggASRSIFdD8ckv7LlAbyDgVyIxsBrf80ISQN7jRDDd7MjkyAJ5CRlrUu"; + sout << "vFCwCo11+7E/Yfk2dadQJTjGvuyyDXW72l4ze1PuBrjnql0vBe7CCPZKgKJWWDzbGIEKGbBqDfBV"; + sout << "EVA0NrVpWYBu3uhj8sHi+cAMwiXy2T1ar59Cz8bvuIA+9egegV19YAay92S86BsJbhVfb/THW/G9"; + sout << "sbTgF3MWTUgvDFmaJzYn7mkR/2xsBn6UQsLYPZPxOQ4VaEUzAKFSP0zwGZfMYE/REnDIbdOR1ai+"; + sout << "zrA5gWXmbLlr09hkM7tIn2Tw/+X7zEzZ66oDyjDs6g4a+CBt8OtOLi7Ga/dO1DP2y4n1YPAsGADW"; + sout << "Jf24rVZGv8IWyGmlYJZWX2tBKV++HDUTxnmbsDSI7Jy3r1UXNWskyDC9zYWetsp8UxAVRYBDhwfv"; + sout << "fqlchxnmTpm6ozULqYI6ExncbbbK1IBSbY2A+VBUrEqdH7Rxg45S0HKRbvB2jAkksOyhbXUwngM8"; + sout << "6uG1TUcRdnxLvZniKptRcYJthyx0HDHCzLoTT1AlyTB9kvrktDArOUa1zmF2hCZAYP1swqb+zBjP"; + sout << "CHd4aExlUf/UokEBduXCWomJLGXsJpVG1EVz6kTZb2Twdt5x5vZJK3O/LCBKgZNMuS7/xLFZpvsB"; + sout << "cUA0MoO577pa7SECGSrUrSblkliCMKmIECRrQoEJaNKCs1r991ptT9aWnLSbl6HTTWXKcceN0vXQ"; + sout << "tKwTpiM5nhBmea6eOQn5JP5oeytyDul1Jd0SHR4WILgQvYQSrqdyQOf5zDBQHe2/EsinsfonJy50"; + sout << "4M4l2ArPeWEjJKtQBpgQaLF7CRGTjnv0TPJ1Xjo2Cy2OFvWpkA1ZjK3Q/I5PJsMrlkYkTSgQql/k"; + sout << "HwmLiFL0KGdVfo6F1HT5J/5ZxCDOWhO+DmvLXtvTQSEE2asJVZWDo1rpKOq6H3AhObvUgp8AtPec"; + sout << "+KXWiyzTH9YYq5RFU4JE8a+lt8+bGtYHxGOE6AuVgQsvUvAJ2kgqgi+kLYF19FZjeYbvuxzI3t7j"; + sout << "+qCO3v59HfjZWfQ7FOJ/ciLVAAEM65r1sX1JpZRodNnZUvbxT7QoVOm0LwjW3qw2VxZCbVU/jPkc"; + sout << "XQ+8lxiZpHO9mSxBXvQkwRW4K9MZ9dBl0oYdvq8IK33SkyTcaHBqlp2oszKYwDFyLo2C9uiE2BSE"; + sout << "gj9UpPG4M7XOUjUMJ3gs7aitPsTdI5oHG1xklLOTvyWYIJ6vk2/uBrMImGSExPSMMGw0NEbHFEX0"; + sout << "Fh4Td6+cO6QgCQeobzV6Hxsi+HTk5mtAbMGXiIKi7bAiSC3tkoiCeDfBxzqROYgSVQU9fW8PeyW9"; + sout << "xEmwZSABCNRog1xAwHaOuosBqYDMidyh7F5ID9XeYHH0qaMxLmzYkZ0SAl1qECFzJ7wUSiwRzvzC"; + sout << "VfatrCs84cDUKQIpr82LPouxQBeVFQAlj0Plcuo6CxrJgVR8q6ciC8eCnZ5WaF27pKHrfPc7ZjdL"; + sout << "h4AwTLGOpUg0N0/aUJxvFGOVbm/g8pH7duFC3+ycmYrOXZPbhz2qthsxf894+lkU6TfmN3OOoTwG"; + sout << "EXrda/bR2NQNvqJWBRTv//VdXIr1RmlNMDcz2d9lL3rwF9Gq5EupgYIfe1FhiiqDjyrTzh7fw2Qw"; + sout << "P5T1SfKa5Ww9mGxa1psmYlJ/IzOQlSlSzSHxMVvD7c6UDpWwKWFlzzRPjl3WPS9Dmzk3tzUo2ZYx"; + sout << "RdP7RBZg7Cveb3lkIn+gU5XII4cD79YzKPesFWpWR1ejkLVGW1YQB7jtRaJmoz8T81UTuqhxhBt4"; + sout << "NswM1L0VtG8asAidWd9nT3Y/WWI2ydJ9YgiepKeMdtpt8aYqoutZUFzPKGhdZsZIYikb9r1Svrd3"; + sout << "MAyceFX5RIKLVcyJJC9n1QlJVSXz2e47rrj99prVvMN4ROB4hpORryHzGuPMMhpn21ZywDk1aH9D"; + sout << "0WGZq8Nm4YZEjneUtkkW0nXcF0nxJ0wTFBmOrXEgjpk6jzfQ+9+b803WlwtqRUJ6S6FYirRyTv6X"; + sout << "r8xCqyXgNioZycZSgjxzLcoFqq/65u935jK7BEMtBGWa4zQcwq38IcmZu9xfCrQLCfHpIfydjxBs"; + sout << "8PrNV84ccG5Yfs4zN6oaLOahJBukoFHUZLLM+67oALQzsStEXS+HQhWMJCRj5M/8/SwoxZM0ACfQ"; + sout << "XY5G7OnF21/5NwpsUoYboT1wNFT2r8TKrFOx9bEIbKV8xTUgdrnZwbKGPemsypnnsFHoA3BwIWKx"; + sout << "+w0vNWVqK2vsbL/pTiZiov1lxvfKFV25Q9uylplUaSnzYuaDQFLOPBFr8nhcmnfZA8r4ljcjMNw6"; + sout << "0AUHO03MyUDfu0BUVYQEyGXpEvwnI5JmVOI/y3/TqnVnReemG4D+diuwm2Q5UgyY8KyykxJam4Zv"; + sout << "W+Tn3DHT0emJNWv3imi5itW0rUIaTpt/9a0LvqIWi0Q1OOvrxGgUcrsRzfUVi+ru+tX9YZrBOYXe"; + sout << "Ut+KwUCvITG88r9m7o8aovNyH9V7bC4axTrIxmtpRiRP5e0Z/IiH6b560Z9ixQ0MRv5SUn/lhX/l"; + sout << "AP2rTsqxLmt2mWz9b3GqtEO0iW8XisORLiHzupEX9jRbcqShrBoG9bu+4DUO9hGdemq5lX1782dh"; + sout << "FeZOjhIE0/8JttTmz7EZwzp7SeHup1yqlForPR0hhkFzfOOQk5CwfYR1tD+0ImdmbL/QyE/TRsnP"; + sout << "NoDB0NVyXcmlJQzRH+s/dK9kJjHFn4FAUJ7Eaw5xi8/O5VokSz3gfwwPHbKFexrY18fDiyeZ404M"; + sout << "cAGYSGyYzxkpa1HwuZE7UzUNlyzpuja8RR/8XaUv1XRK1lIkmDFA9cKT81vtXJBY8/04Jh2OsPBT"; + sout << "mFRnFF6+vItwlsmOaF/WUstZg2XxP+MePh2jGkKj3b4c4Vw9YwzD0/xNlkCk22jnN60wtM5WHiLT"; + sout << "Opu4ZnP6/woOP+7uOOHdjEJvo2d/v3TOmNmfx38TBCBdAbRpMzy4XDZdsWpP7l0DN9/OnbErxD30"; + sout << "2sECWhkxKygkNxeq+CoeiS6Zl3mzEKkQ9q/XHccCioWgiODFt6+Wg9MtPHGaR1cnaDDBBp4YQWJG"; + sout << "c0D1Urc9H4Tpn0lmJkx0p09nBoMHGIBUKmHuhMOnuEuZp9wDYSQ6UnjtLy/+QJcCB8QDvTCg9mhu"; + sout << "W2HNZPqa9DFD2tmV5e3+pTojHI4O0tOwA5B2OYmqFzdvMWLuFn0uMHOu5285KE0ZDykOOv4Nupq9"; + sout << "z/rjTtrxInMTU+a7hynQ7Ra75F5nCnEefKXL2lgD90IkuNDpmxHBX5OgOYb9RUwqcVGAFSZq1jYm"; + sout << "rg7sxuDJu/honmXsDohqfo3/vBVca9U28wnCJ2IsuJXRQwzBLOJAz44ijj6+Bx4R/7U3IbN6nSNW"; + sout << "5fXiqV1JD/r3AMAf6THZKnYy88mv6kMzt0ZXPd4PlyHCYhLTIzw3AZP36xPjEWBRP8FGaJw5Pm02"; + sout << "Kh91uM2QdF2jowNV/Ago7rHS09R3B/aWjmGdVTyhCjdogPtrhVHdaq33OftcCuCF7q/wZ8kL56HD"; + sout << "2T2BxxmsZY0XqSlU3757R/VoeGdk7XjAGbpOeyxn+yrP8/stbyIwxq8PSv6YhRN2c3+H9AYka8Lr"; + sout << "Vrb16X3aEEP49jxNwEOL4hM2rcgS7f3uayVLwsuviWxiUcQdWcmjTVgTos/11pdXpSaPqSAvpAMs"; + sout << "m79R8fdq1V5+kmZZ8UZKUdQwjJxeg83dCTDpYRTXAdSdKDLXeD4Kkvd2ebKOGb+1j345x4HUnuef"; + sout << "ijTom/Xz3fEw1jLgqZd/nVeUCBYcXpozL50LwzYDZzjxERSVCMCgmt6VJBgpvwFBXyAYJql5972d"; + sout << "vIP6cFU9K1AGAUipBxgaVR/5NBFTIjgo2CEOJulAQkwWXE1TOUXHF6hD+Tp9Zf3nZRmTBU2OXA/U"; + sout << "8M5Q6nyDRaV/qRdB3GMrYPHZIXcy+uh9xI6OYDWSfv6uZxGfKDzN2JjMNLuK5CEIu4hiLj3JHC5o"; + sout << "vHJ77Vimbm7wa/cTYx4ztSvSSFwSXy82nSd3o6d3Z7vL4FlzDjxkW+6AfP+8SQLvC8yiKiGr6U9/"; + sout << "cqFdi3xOg3Ska+3F1sn/OIrDZVLkP5eyqo2aDC++WnVv4Aig14Xf7lgD/nOV16X4sn6BmpIysM/i"; + sout << "Cjqh7AuHgDkB9kgr7jQXFeGAhdGCNby93jgIRBdWTY61JRX6Isek98cdoL1XqCheIxKckNMmljnr"; + sout << "lHL4Fi8okkonI43IKV+8NJ5eH/JmUEtrcOuaB9M4rEI1NsEQK72dIWyuuIYyocaeA87R2dt2biia"; + sout << "SsZLV5jE2O9YE/AWY5fDGAH9PA40TD8bC6bxgaik/QFegoo4tRw2yR+GwtUh9Utsu9LUofRnq6KZ"; + sout << "hq+3Bqzu3U9grOYNU9MUxNo5jZS52U28A8NSuHBDeS2870sdUvEM2RftmNusMnCDR1bw/tru/i33"; + sout << "iYIUoaX6EPN2UOJMm1TGqIOMs4QWkDTXYAsIdpBgq2nMfXVDb5y/nqpmQH2eJCgz/7Ly0VtZISf9"; + sout << "mkDrrok3cIEMKMgvM5X/dA4F6Bv929nkXqqSi6AtUER/jJJARkpAJHJTj0IxUgY2unVx4KP4OuL5"; + sout << "b1HeIkWdlln4V64OPJnBiKl17p4pX5BGtEXngsMR3kdDBJwliLv6ciLvPKybgggvLSQbQ/vqnZdh"; + sout << "1x1GlnvrpKfoLFH1S6f4HiDicapOP3wKOp8ECyNZp2IAVbRmmPhjKjPry1EdgySgpXFZUHVEHdR3"; + sout << "R7cGnH1AGhZinYuoN74cFsNooe+LxT74F4KWSW08+S659M7862DADzaxlHAbd9BL79Tu8RGS5CUB"; + sout << "i2ATQ4rUgRvzC9OGcwCpOAfbp33+jqLmAWRllB5f5uBJuO2nG7OiKHV9jeunOyHxNnp4UTpblCBW"; + sout << "yKx4EVA7p5T7qH8krIftUpBbUx38XhCtPMcsIZIMp2w+nmrlZeHyPGVGFrA+Z1+VjmRtiwGCghXG"; + sout << "JjZ9j+HPIniiHK78MNAPoRlt4G5mtBopMopDn/1cdVroLuJkxX1CpCOu3MmkQQghl/km3jphiehr"; + sout << "Bj/E9IZI3/LesK3Esv1J5VJCFl+Eyn8i90S2ij6sBBPF+eqY7TAMYAr7MK9mFtfM9orEpIJTvG+U"; + sout << "Uun4ibgHRw0t6UQ4nnvZKHL4qmZXZB4clGgvF2Sl1hj3FpUhQLr4YWJbMeBVddvi8Ifsl/KpN/J8"; + sout << "k0iL9nhnyDfLBNy9iT6Fnyrt6wKEJKAmvMTUS/U5YZ3CFAgWEmqI04ESYkZ1F/pVjwgfDtwjbkPV"; + sout << "Vr+9zaxAa2gT5yZZ8m6CqMJJ+Hr1857AgVw+msikx4c3G0zFZpUEv4o3PiprOr0t0bGSzr4rWnLb"; + sout << "rPNfMq0mQI2bWIMeZDtiAeMLOwrGdlnBUekp8NydJeL1XNAMMJjHKvNEnopNqkZ9Gt1Kmvl3Eqcx"; + sout << "TuwbsX2ew5yNJOUD7T9q4lYRdKSNs88fdTFzE6heGTV9UGs6z0PMP2HoLXSShBhvPAIX3t92KK+9"; + sout << "8986KZvRZ/QVN93Aqt1mczK4SEgo2X6Tm5814X6WcCWiZpH8CcwQjW23psQyguC4PyLhGJo92XJY"; + sout << "GRVVisGEARz1qjm/GY/jW9kE6alXImoLE+F0opY6QTQXAo+KKleXAWmMEUs972E/aFQtmR2V2mxo"; + sout << "Ga6EtSNXobw7jP+kMkYoSiav9pYhII3t/v/badrNBhKksNt56Zx6BYFOnfhPrwGkR+YoRz2+x+gR"; + sout << "GMHLwRFe8lMPR584WNMrC8r0r3i+qrp2TeG5vFM+VKwWLSBUoG+bgZShS7ITSJiSjkuDMpa7jy3w"; + sout << "1s1AIgtNyvxrKPTKr6C2FjVuhDkpSInVX+Cdr7utecsExirDn7uyb46oI46oNbsD/6HW3sIkvrVH"; + sout << "xWY1bwvcKlT1+aFWKYtdZEOk8JE60pwiat+MpxgOk9TC5EUdJvcxOP8M8zQsIuQhWlkgNhsm3G7X"; + sout << "1xWL302yhLrl66DlkKHjCH8+ee7/RCRm3l4nS62XeVQ9ZlzEIUbPP1tnnzg2aFi/VxMxk+EfF23P"; + sout << "auIEXL5jALZVnr1D93LlSiqGwKoAq7s1bblpIGyxcXOsuaZ8Ls/J7FsxcLD8BQiK9v3wotEzaOOD"; + sout << "jSfsH5DDUQ/nIdlBwZk1J8aOTSIo7DZ1wZRBSz66ptR9sSqcwrXvFndHulLtva04DfNDPzPNaCoH"; + sout << "cAjxsOZRKl+ZPi9X/qsVzCr6DijJYeWT+UqbhTOGX/Vrl3M3Mt7q4Pcs0vcdXH+SUV8YlF/zmaD/"; + sout << "xxZcFz5DPz8NRAculqklupW4Wc1D421TC78Kka2NTJKLzivgisWROR9c6g4ml5cqbrizhraRbvXP"; + sout << "ilWQ0PSOq+uyJqeo+1JOSH3KWoyGWfOfGxlRffmvsOOuY2oblonsSCz6sYKySsnKeTezRcSP9PFg"; + sout << "6EGCjFg1WodhnA/KWixJtQ8QsfhaRDoIYklDBWlPBEKxgvHi3yVF3vTyakoY/VR82B8UWfzxW+gL"; + sout << "ldM+gn3H/DTpzR8/h7xBK2yoOw2Gpv2TcEB5oMtPsXsN58QdoqAsgBcHI/GNdfqaQDK0lt+npV2M"; + sout << "8NAnFcK3r/wdAJ6OwJYtXPEIbhHxRkEY7GLUT1IwLeV2PIV8gy3Eh03ULZDcmii5xO/MsPgrUSKM"; + sout << "Z9s+JmBmJhQISUjumZ2SBkAU5V3SG3T8c8EaeN/9yNYMFmSXX/jAjlAViNK8bCfnPUkrmnTAMaMv"; + sout << "blW6tw1yhXV1WSAe075/Zl9slrZhcCJHYX+mZ1yKDpt7x/QxayGmwypWDo9ukTnE66UKiZ0udTiE"; + sout << "awN5uFe+qvj+zZLrnM4Bf4ZjMjKzjd67W8UaxMJCVtx2SHeNC7Ffjo9xN4bicDUnpqTt5phMV+hX"; + sout << "lBRiuUYA067qUb58MbkIE02F36n/ioHNYj15Eteh9wv1IMIpMLz/1UGm6+97m8xIAKH5B691KOsD"; + sout << "k1ovwfjzGXmFX12KsdSplPSmH9UCo6CsuJmFkhAkYd5vCw99+JVu7RNqhJGF4h6Lg22AqisjfHde"; + sout << "5CFUks0DocGrgDw3GcmVMFAV1Ix29iekfAgEJcOphJBGJ3FvwRmGkflrIJW5x2sPMaRHeU8Fkzso"; + sout << "o/IF8537YleWU45S1m+GXNdomkGIFVXZCKwiQ8bRr1MmzMcs+0EkU0r6ei7yxi+6LlKfp2NQOv/Q"; + sout << "owQ74v4YHvyNJcI+YLLMHaZySYlxIjL33H6b8vjrzsDntYXii8lVRRXhykAiLldwH+89cFk2tbeO"; + sout << "am9kZeEIlCX889URxsw58TRlwd0Lfag6IhNO/hHpi652lgEYcZu5NPaDQPRRC6vkl/+ocoMmTDhX"; + sout << "uGhphdczLsCBpIejuQVqI1dhrnko128Jzl6yIe6mpuhzRT1ReFL80F6VRjegnDa67Srztleu6qzi"; + sout << "OUgU3EDRfE8IJ4i97SsnTFJR7Eyou2/GP1U+ai1U7JTknHbg7fNeD5wjTIyQVE31oOR2WTNELWa0"; + sout << "/GWO3P6Em3J7vZJifeKTGYWf1124uh+oBYgv+j3vUuX01i1P9kmeUP0nmVgzWs6AGIvbIp6vF+gV"; + sout << "Uc+8TvNYH+0f19q/lPOBnBT46URHpJ9EZxodLin8K5aTCB3sLjpXiqQDPD9fbMzL2r0s1Z0YPWgH"; + sout << "Mllgxh7l6bIhAmymXIUX3+Iw9rnW/6ODO1huBbvgZOAUrLMnTA9j54trOL/n2DxwAU7JJ/fCeH5u"; + sout << "PKH6nM9l/PvCTslH7ahvT/uqGD3/++3FeHok90oU0QxyDVbsMDX+ksj5Hn15ZQZQtYqToONe/OI2"; + sout << "F+xJHU1CWejRZDJgDICcBYQSI9e6crc1vdf0y7aUeNRthj2xLC0zbtfy7PaIEy/R2E93I3jOoAq2"; + sout << "Mz+0icoUb3Hckf2St+HqfYPL59yCrEyqoFONN5XMArC/MTYpMn/XQfCQL2X9xV9T7WaScs9TAqQ2"; + sout << "/qZrrcfRxUaPCVIBtG5V+of/vfYoZzwQrLTioTohrX6SC0WvxS05UlbYo9sDwrchmIxPmrf5vI10"; + sout << "T4Sc1b7sRfiijqZA+YTnYyBPG5c2qqsmSXLmtctpzUn94XPhS76qLj6uQEcoaYTOwFK2dVTvFpfu"; + sout << "zlRNDzCCfzXc64lA9oKwEhn2tTH4ddClyYobly+2HH3xz3NQMrLxeYXblIKBFJzn6IlY26a+77qM"; + sout << "srMzEwRmnzUojuRBO6n7BxGFfP1JuP4FhIpKLTgA1Ql5mf7Z9LrsM2lORzMuO/og7LvjrO5jVs/w"; + sout << "+ZKmQzSzNBjEgq09UDk3gMG5OV38SgLXQCkUnCUThGU6En5OTQkSAF0IDQ9Bmmyadqp/+V6WiViv"; + sout << "Py/WYbFDSAb95ME/J9YGfTg9VWncLDEU4otj6Milbr8OLjlR5YbbGkieo2I4jYfuKHR+/C1uXVWk"; + sout << "HEOZBv2lYRKteXx18m8DLOZ1PO0PC20KRwcep1QSmw8SH0NJ2Gj5L2WvkpfCSwKYJOJECrAc2PcI"; + sout << "Bzf363+QgmckMylMuHJ0Kf6AAP2xFtKYw3qM7tTUjt14QkC1Bi0l7ViNGt1vvyTq8erWFJYveVwm"; + sout << "lJQMmyvrHq5g+hbj9ftT8cPxbKJ/EPYGKrtrM8AGxTxrtA0FHqmzjPTm9rHXhGYfm3ALWw+frbVc"; + sout << "d06D+iiyRXpTTDZoEcbMaZYCSyHI1QQmTR9Wfr64r6XSk+/jDLNlL/Syp/Hzs5dK3P2YR/UjTB3I"; + sout << "8U70K/yeV0v9rDUI41COSK/akYmpybimqmexTop7zqNUIFOqFHbH7TCqRxfq76DFebeHbE1Nq7l9"; + sout << "p8HRuOuavv9wKGBJx5BoYl0h+ujVfqWIJHas9QuCbJK8Z2eNht/j/FFSvdlPfWIrRu20Anay6EHN"; + sout << "bEDl6/voathXXz8Rtc+/xkb4A2D0+WwXHrYh0onectahnSjI6sK35wtf8UF653KENWSOVFiFAwjr"; + sout << "XHV4YZb+VWi5jQ7470jOdd+6eg9v47AymA+6Z6ZeTWazbxSgfV8hCsAepHyZ+U2Z0B1Zgeuc5rKp"; + sout << "MhBQZ8xf5x+yM1G9Jtu4Z5kgQ8NdliztMLkquT+mbY9U4eri/WsAQa0BvC6hisueLg7IGqjQGgIJ"; + sout << "29Y6iqD9fJZuGAgbEbW7qu7J3DyYzGaKyAx4+GOyctio9ChW1v4uq7gkc8XNpmYiGl/90mo20hzR"; + sout << "vEXjcg3Hh/aHJlVXH0hJ6W4kOuBFkZxptoDFi4vKdSCVWAgnQgWhmhkeI492Y9bUaDTxejYhTfDX"; + sout << "N08UiRXnAN96KK9iL1fJZxB5MUX7cJ/GvW7RjTVP1MNzy/FdlAN+k6Sb8vyst3Sz3eGHOgO1hpTq"; + sout << "Wz9Moc4BVabx5NLlb74ScQHCHX3dnKXle6sPCzfbPjM2Z8rbm8NvB/VH1FZ/EzriephAUYdra8OV"; + sout << "+f9V3o9zEYMGujKxqZysPfheqaA3gKKDVgKGhfUn51dEq5YC61aWqxjO1suUrW6rQpBHN/V/VadA"; + sout << "hpuSHr2UHCeL8yUzC2nnXxpHCifGCrx1/R+A7EXBjTlrxd0n83pEV5paIK1AB8Vsv7RNrT0bEPxx"; + sout << "uNxgTVdFUVHVLZf2zuhlqQCSr6zKzQgdlsW2Rbkzqm/RpdwRmONCWkS5IRTwdqLM3rfStz3zt3U2"; + sout << "PklWaPeSMTUzREfa9xfaxA99Yn8QcQZU0GK+zEe0d/iyeDr+7XPYH6ZOgQbFp3BgHWk9tAgeVx3d"; + sout << "hOwbut1u5WRa4azjGkYUUlvXLJpehNwKEiaej7u9jKrdQMTRi7gtSYliGdNr0nnw4ADV60LXD0Fi"; + sout << "uYvPQwJnbICCYdy/Yx0X/HBu0dUS9D15JuHuT9z1D/dqirMKNiWhXlEzC99gX4/mGGM4Q6l53SQw"; + sout << "0205XkJ21NasDxznXp8UE2e4GsI+N40b6mLvLZaauw9dB2IeNhwZvn7mkGDraXjDnpnnmX0iV0Y3"; + sout << "OH6OCq5ZBlk9/IIVJpS6Xtec6MiuVBYrhKj2MVn0tI7rRalXXqk+7Plmg/9S/Vh6U5RM3pCjTTf4"; + sout << "ZEAxFH7LhT/JRpkk7fPxcncaHQVrnrbM8xEYc9/2nptKYriJdXHXecsiH8xYjjzw8ACs/FeFyPsK"; + sout << "84MjAvnsSWnPVpJe9n51BIQ68dUG87+igoKtNjWomVVSAMlBKtbiNrzFN+4AvESKCufoTC4gpGPD"; + sout << "2sFkTO5yg8+v3rwFK4aAjBLA6eKzSGkcJqcThom+NfRjrvCtPysEisAdXHc40muhThMiIiOFXwbH"; + sout << "T4722yaSHEMGNwDqY9dlixhGz8L5G/B3hbrlqCPe2+GHcmmMI58oFMgZhEh4u/1t5nkCDizZ1COg"; + sout << "+CDWA4n4lcOXfVMwp/8u7GXZMMLuGSJUy38DgjLzy6f8bji0+EEZ8rbFsvMjssgcx25YJ9Jy23eq"; + sout << "rcwGu8iMKUQ6t8J2sZFPXuFrn9cyFP/vP6qFNWZ1CtiN2fRVbqwtBj6DXmSae7NDMmL8XfGlTD5i"; + sout << "aiIWK5SPqYzS0K6oPFagYWYjIKqkNRyIAseq0fqTLKnyKmv8oj///TlXGawbehr9clexoKbAM5Mv"; + sout << "1RKu3/pJiPVwOPXcwCIlQ+BemtjmsnBzJySAkA1I9mpgkPm4Pg4qblWhgq/aVulwN5uSkJP4ZU6n"; + sout << "reWj3khTO6Svsua3toMPYm6FOoioKOOrMMXQCoIjsybBmSaEHeAFxZmHODtbf/WrGL4YbhUMal5o"; + sout << "90Ay+c0apxZIV+8d7Bm5L0dRXmQfSUCnP1CKyUEGPaMAPK+DYRZNU97lzpowFRdjTlrJX5aN3jrk"; + sout << "WN8/xaU0PTVo1LtIesYvyG8n8gBSmj90hw5QLFt/94H1NDvPYMvw5A6oGoQUYdCXTXzrRJ2OEow1"; + sout << "0/qEzLCWExSS9q8nbsf9Ne5ZC24XA9KHmiJot5v7MZ4KBORjPy/Ub223U3xuJqVGHlSjYuHqmOFC"; + sout << "bhqGY4yNo/7PU1/cvg6LbtQSriMj8+85GQSheMBI73EnwIHgPOIV7EGEr/GHrfyM1jI763x7Pr0H"; + sout << "Bcsg5o34glMy6YWSMLd16/Djq277MW4XiMoGPK7Rptwh9ahtgjS+SAl92VTVyx+kr7gZXfecCivl"; + sout << "gXUmA8sQugTC3Zf7QqR5eJvYA/i+PMRWkLw4YCogDOYXK7tKY86djtBHftkYob6bMj8a7/1/XGzm"; + sout << "Xz1OGinw9yYz7N19wgkVZEubmK08ZnKINN7A9IU98iAauGaLizHRlQ1y84Sz13coFKe9lRtgnNHr"; + sout << "zsG5yqiTSIlH1jLyxVh1O94Qj6RlNT65LPJa714IW1ot0XsS8QgX7Jazdb8mqyiBky3/w+VoLWGx"; + sout << "kplBp4VZN5gRvTIHRBq4LaoGgzc1C5us3loLsxNUDgKBCBjUkywWAD+l7sCB+2btJo2HAcmgbNBQ"; + sout << "zITjO+dRirzIS6GerX9+zbzv4lC2d5K8ZmDUMHV2x2NKaXIU3mP6II/iBaP7W2VRx5qfAAChmMAP"; + sout << "B6zDwzc7LNopXp/GgOKhYzjctxbbpOcMjP1gnQ/2F4EAyXOU+iMZeXqlmbJcXk9S+bCK/Q8i96kW"; + sout << "k8lbmB/ecQxbYJKynBJoijaSzeRYL5wC/IZyqScchpePu8J84p7DEu4ECqd9+vhjp28/u6aCgnKe"; + sout << "T0mE9aIOVS9DLgIURa5qmWUFvolbrZ698AwXlEV8vkMgSDZyxgRSBdyZcmQawO26rEVQWeAr7dw2"; + sout << "zQIGGtM4wACd+Vv3e1FgjvkytZOwyo0NLlURyqdykYhZx+youz/Kmmri50XG/hpD1+5gXtVFWUhO"; + sout << "V5G9Aw6GG3+trxtMJnjCdkaA7DTjTIU8nUjNnx3TsCqXK0OqDMvem3e2TTeXorRPe2zYbhYB5pVY"; + sout << "lSJ5yYpUrIzJCb3xPGNxQP1qVv/tCEq7IIKGqPl2RS0NihXE1uz8xPMu3cxu1juHsw4Mx3qi2C3z"; + sout << "KH4LvPoEbP7fO9uELQD6ipi+AmmB7bKabD2AwsHFurh8oFtaCt+KuVy+676ggvwGEFm0j28Cj/ff"; + sout << "SQzCut2j0HQjychGdngbYIF3HLJQNr8rbh7k1FPpJhdpJc5lH4QXMddVxuydPbwQuB5zGOWZdc6y"; + sout << "0O4zjCgP0ibSPyEdhkolB5T7Sm6ftL3dctp49kJKtIyGd5scyxMreHGJxeaGADdgqken2Ahx9izm"; + sout << "VoVdjhj2tvLlJ5FbAlHqir5ZcZIu+prlMkAnW20mXjeWdh5lSCUFGcOTTyxirXkUfyDRFVwfMnOi"; + sout << "Y8j4v6X/LyNIKgBPY5qFKfbwAjYWyAPYYQVBurLEu5gsYGSYhmtm6BGZ39gnxbBfhB/JkCyE68Pp"; + sout << "wwAA3SA3Uxha5yMTYK75ClGNsWyI1AwnkfYuYgEyCbBv/psf+I/jSpT3yvcEDKiI7FwToriSgA61"; + sout << "3ksBBfw9E7Y66wmwTHq9O4alqsltIAVLkj/Lp+DxSc6tZn5lI66aOj7paImjnZac/GoJAxVUfcKL"; + sout << "q8pBZEwlJoAZb6T3wCuEdgZlYp3MZ5uInqAZSOsqQQ1/S1vCj5pwdzHrHF6Q0FJemS1AO2/GCk8p"; + sout << "dEaniaohFtWT+AMdEQi5TLVmtpltYP8J1Sm+U0/TazrQStLUXm7Wpc2DkMdU3cZw+ncK5znAzzn4"; + sout << "dQGBHVR7TMfIpx8anoWGhShI6oHh4zPXYtRtiV0Y514GWxAMB7oeLOjnJTWeiwO8VS9IPCjcKEbO"; + sout << "zTpPAwMiBrZEnglnKmIGxQXRmGMpag1JZfPRj/XyucX+LreWYF+PabGLPrjSnG5e2D4aDlNx9wGH"; + sout << "raYpvkG/LJJp81U+s3VsS2vUA6uWgz2XiHonASuttcV+a5vo5tcKUDXDaFckDzjpIJe2rn+Z1Xxo"; + sout << "X8OomKyPCCGmMEmE3SxPkGCkTIDrbb5Pc9j0NDVXbRyjVZUf2hrZ7pR+H8+jzUVgt4IfnB8Q7/Yq"; + sout << "uKW+YvA6w3hYv94BqXPHeO6FMLhH14iBLVLp1Yp2aOp1ehJUQ21OgrhNGjMmF04P/2EF8DXy2V9s"; + sout << "h+ohPrdH13fyVrQWRUkG+7Fn+GAYcpD1K5jvRavlli9pHys/axwdu089ivNpA5D3D+t+SHVGEw/L"; + sout << "ZcWpKRQrUZnOLTNjJEpEIAkryuxqTcGDzZlB6ngOVxK7oUiwzocC+zjKtmafaH1BqxhX3RxkXy/8"; + sout << "sXA9pDQ2HIn82Y7vv1+9L+88KXcTLTIUSI6iwrL43h6IhM018jTNfCgdVx2Uav7Vax1kmu9O8mhT"; + sout << "NOxDk0uYoMJfLPhr/nSWlKjgojoEj4IBmALtiMtqe8L4zNpHTRxgvLW8wMQAyj/woN1mLxLb7Lxr"; + sout << "UIY+ECqQomA4TLPWXuTsLMJXQyFgd4nQgqVTL3PUdAtwes3xrHTODTYv9f5CLEiTBV3BP5V6lo5O"; + sout << "Wj50oMM6yJWV4jZypwEp8osWHksLQjPEubu5CiBucOv3abc3Zvteygzp8J0la8wjOT3cd4+LHFwg"; + sout << "hQjO7JuNXwM1xhgdostGxjYExHC6KB9HDG3OUQ0wtxsrh5QITF3hmlvkE//xrWsyDzTcbWCgG5Wo"; + sout << "NRy0RpsXsMImYHlzO3hV5wQHz1Xjswc0ATMArZ+YJvXieYgEyhkIIzYlFlc1/GgnScyJs5Al2mvs"; + sout << "5GFjwMfC7wnMNRmV97SuDzQjKOiwMFochKCR8pSbMzJ4+dW0W9LMvasWGVAnwvQdfIm+E4+O0W2N"; + sout << "ArEB7utReECczQcBVWjdKqK6/N6QbXobZfvLto8AqTTw5CjOgTpoEybY4og+zNemINlmYgzuU7Jk"; + sout << "IwiyzZLohI6UaTrPWd9Ck2lLn9kJf07TnWVw5+2gIqhr5M0d4Yz1Xlgy7EfAWMhPEu6MuD0WKKB9"; + sout << "1MSe1zl2pbq14zDGyICSCXCLTTnLjc5Yj+E92s76igXq71sobtFNGzjQfZtK9VLTrAaeKf4Ql4jO"; + sout << "KWQJTFYbvd1ScHzbPyF3S/T0CHPDZssQQTd/VBOuspVr48Vl2/MRdLpdzjd2b9gwCh2EwGeJIy3o"; + sout << "R+AdmgHfA6PwhJKBkIliNMMR1DEhwN6QDY+b4GKnRHe3+k6tChrTNUqZpZ7Eyhret/Hz//LhmPn5"; + sout << "EJ7PSeJR7K8PKE5WoCzr0h/o7SeiHi9qi+2STfd6J940/2D4HBkYvJOEH4Tu+1XBPi589dh6vnhP"; + sout << "aQkicmf2fG4KuHO9f9Jlt21I4ksrEaibH+BKFFw2Rq2cHME8Bg2cG6HO+TyRwIROhy9yxtWQEYqY"; + sout << "g+zU+DM/WCXBCNDCPqy2q8nY19izJrU3R/9ZCUF6Ji4GjEzjjb+mErVhPcWpaLk1FWAPxtC/A+By"; + sout << "rEiAVG7+asI05YWpap2g7apUyGoyIdnHQ2g8QucKspJQhW5BiF4pi7Kh/zzPQPXkuxQ3biy1pOQ6"; + sout << "/6EEXB6W8dLKHTHn3cXCphbbQzqnti+yQdGuvwDiC/GWZbB7ePUsUSJ3+b7gcR54EjMeIX6rBUxF"; + sout << "2xXAs8El87oP8PCp023NStXDjpiRFCyQ9mYweS0bRh7MewCbvgDIGBuJfXL46eVRfj1sdwlnpS1w"; + sout << "aN7ePzPyTR66zIu6kwZzfv5zw10c2LC9cj/WcujPSKir/nCMFaAUroinmmNiH8C4M/CL4mU7SBkZ"; + sout << "/hidsIh9YDq7cQRP+vBwLgoVewGxnS2Q8yfZ0K9lAEsEo1Lpxg/Bhr6Ei21nnfaWcMQEVj3Gwpdb"; + sout << "fqKrvOzQrUbQBADmmksC7FSgEBUTYze282+qXvS5jnhFg23iV6bafe7LOnZmqLivJGA6S0Y2D0nN"; + sout << "HNP+0t/IutdM96z446RKbwvUrfNiPZF1CJWu+T7Gifwib9XOmMCvamYUqI+dplpbgCxmwbSnMwJM"; + sout << "yK2++dILo2W4OfcPyM9eSBgY/fWKdcv330fbSE74lTUiInCJr0hECkFD6gThcDVRY1MJY1hX2Dak"; + sout << "E+r/OKBGTM5i0AlXeZoi5+RyIiA6wRTKEzjznz3eHhAe8gRKj57DUbhXd0L7BK1ZPs/F4y3/GHtZ"; + sout << "rfh8fPWKi3TX77vcIOANXRmGtpDm3CDxlhcWq/bWxEL2u6wk4Cqqn7xxYLNmvnZHB3GF8aFfPRdW"; + sout << "iqdGfvdkoEWNNAd+OJd5dmu+kdqFXnoGtlItp5myxabhN/I5sdFxalICuAm+kAl0ocG98Q7v9YVI"; + sout << "PaajznXTdN79Fn3cFSHlgrJ7yDb5T5mzpk6ud45ux/SxLw4ERY0jykYeAYV7NQp87gSxDfp/xyJu"; + sout << "LKgD5aTz1RUrkC6VnCDdij8ZnqWx6FlSUdXt7+mDe7GAKFBHsKBUY0B7/nqnP+wS9GU7SqPuHa54"; + sout << "fxgnq2fQi/VaMA4ZTugMOvCKP6PgHWuzhFOorZKnkN/Au+/iOH4z+1NvTzBiBjCH1ZPzXEU8pXMC"; + sout << "d+XzQyzL1fd1xS5+FTOq/FV9fm5ihRm3f4qVxdJFmSnNKBdUt6+BnrDoReWLDOpNiOF9oOfADJMM"; + sout << "c55TRsb7ikJpOUhXHQsldLRMCBQOB1YRiDjLgRXnPM2Kh1/YUFV7dOS/tLUj6Th4Tmtq833I1AWD"; + sout << "CLADK2WJXlXL5s9HzGBwR4eVwzuvQsTfIFZmbnsjb+pIDR1Ek7JH3eyAEwy3WWATtUfSC0HXjEgo"; + sout << "lKpcSGD6BZsJ7IkSo7Yf/pVcQbpqOV0W0QyGcnu9IUu8q1O8iF7Nu5LNEFWBm+hpN/sUgzE+XIBn"; + sout << "3sInc0pG0ehnOT3HznAl24ljHBudSqVqoT/O/K2aswIhxOl6RHw20PypAGvTEDQThB3prX1RuTeg"; + sout << "kmQnQir2ffAp7pxEtxvQJ1kjA7Xp88ZdE+xOFX/8yvCnCWh7jzidGLOZrh6T/OlPIhd0Ipj5GiM/"; + sout << "CZqxIuPwEyKbZJ612Ny57dLmPK2J/J5AuWaX/wmMlItloLwNSvTwQ2tSiLUraT98oIVqhu/GSRs5"; + sout << "pRbWOeTspkUZVC1306DklVBAKqWSDalxyilnxToN3l1+DG+UgJtSevkNqlDpNfH9LqrcRrfmokTW"; + sout << "cymYpib926UvmWf9lsMzJfBqFg3QU5Sa/eF+BwdMgB8Cno/p7FsZ3hKytbOM5fnPzm2MYR7Kw5Mp"; + sout << "xPcptC3Al6ammAMjlp51IMU/1wyfDEZslTCZrBCGSeYBke5D9Qz31gl1Oyp+RV2lxhNngKZcBKez"; + sout << "HjNzzJHm/PvqfxyAILoth+5Cb5LxcBp49XjcuPhZ3sFIfNUs8IhjnT4RUZxJHs0PMFHwx9NwgryB"; + sout << "4UqLFn0+Xves6a9xW6mIzgSJZXxgKcWmA4MZFR1welVzD7l0Y6Rl+Wjw+qrklobDSNgntToguXwR"; + sout << "mrSM3jnAYy00HnoxJVcDFZ47Fa1aYpZvAlZHbgqJ42LrR1KvI4Qfe+cJVyV2iq3FNB+4aO6DUVkn"; + sout << "2dBbf2jmR8fNIjz4XkSCXPpwx9OZuYB2C6JjUkQ9lGfDGPfFcDNA0mjkX/FmKf67UaGkJsG+DV2m"; + sout << "sxojbh+zm48COhWz0xrBGSdj0viVa10iqcKs4+izaPgDSHMFdez7nT1aIBqf/ys792hN9h8kybET"; + sout << "GZw6HGYgD0u5+hbXeXVC11f+8aUi2vjGW9gc+R0cAnAjcLNBH6IPIhAzC3E3IJ3R37wqQ5huxyrm"; + sout << "qZ5AWYXo3PRswYjxwlXvUSwEbc+GxcFR+jirERKaqgwoVOw/2F2WHFQ5yQ5R0nI6uUhNMbTjlF9x"; + sout << "2MYa3H19/dgtY4UC496uQ0uVJ06wlVRCL1SCRpSn8y9IIvAEpkhX1b2o8NAMxR1duq4U0fUhCDo7"; + sout << "GtzXFiUKyLgKT6tEZdww8+FztcaEB5GmoIxy25ReE9Y2Yz1w/IjIvyOQvqDxqNMC64ETTqYYlF7o"; + sout << "oOGEBqO3JnTVdOtwbv18JsefEaqWs6hyUboe3zxvZqnrmSrF2Ezydw6jFhjHyTdI2rRNK7mTCg2i"; + sout << "/1fxkFC4Rlw3U4NctvjVXqhVG7/NafRynbtBxSVq17MH1Sz3IRDsKS1/HtORRuAYX5/KYcgDjNyL"; + sout << "ew/mVxTedPDQ+NDsj/5k4gjpHB28TQrYD4R8DheQ9liqV7R+shaUz98cdqoamEcpi+Vvo9eftqpo"; + sout << "/tfL8h88W/BFEBPqG5u0CrSasq88l2EWUflOfdHaVjPiUpA03DhJ6d7WEXIgKqR4KQ2QUFqhft1V"; + sout << "x6XQ5c6HvkYEA3JXB9WY7pqwhh807ucN0Kp30K87OElOIDXe7FWotG1Nu0IjnhIK93NW9W4g9CRw"; + sout << "jDvID7dpdENEY4wzLG3Q+PX6srxE/bjKkWXRrIVe4egRq1yc1BdTPd8SuxHKgVvQYGNlDHZww/Cm"; + sout << "sdtNn9gO2AN+zouZkXsG5JdkJVoGZq7SoLF1JNiBi0asFYTD3y6q6wvfonilbZSnc6V70VVmjY53"; + sout << "NUCkcTs9H8acP44DwzIAMSm+IWWynZnEmul7B9+ViqgqsHW63Q7AdBCd1O9oqYaLKa16iFiHXAGS"; + sout << "AAwYVBuirfpq/yGh2obX9w5LuIfc1Ohmm22xT1tL5Tbx1y6rOFL/LirpgtqOKACxNbxiCzEvUEpB"; + sout << "XHMNDcX/fRaVFYss25guCVEj9/YsVcDz+lf0+qb88TM8LGqYCb0A2lVHbL1O3jdvBtfw3MXHX+3f"; + sout << "UqH4kjiyAkgn/TLXzlbEjkGu9EdImol1ikOMjffpoXri7J3Rakpq19cJcXZEfN8Nl/J0IxjghnQc"; + sout << "/Ge2ls1WqEwMUrKyQfRwoJL8Pe/8JErvyUuIIIjG1PHDPAUDmEHVXWU7wDms6GPAxh5hWM9Gwghx"; + sout << "xDxn6q999jMjBmut7Xvl2yozhYpzs0KCNYzpFwSOyIjfwU/S1SvpXbV/fOlzGLqM8uDMveoDeLTG"; + sout << "5LmSt1eI2ZMnKoDquGtw9Up/Wrj191RIHkurW0RubmX95xu9KJaLOSQ07cDP/FDiME9LyrHlT1e7"; + sout << "4DkaBHyVw9fNhbeEWjF9dCOT/IQz3dQhiNizidYZgJX3d/coX3xjbBKhI18DAtQPLdx3EP6IFcPe"; + sout << "AiU6yzb6lcthTbZ+DAntjRMbEh4sCd/OtjC1HHKm0foDGYcpsN6R3pxA2fJ13yU0ZPLYSJT+3hCE"; + sout << "vyjvV1go70WzRw5TJ02F1ROusAkZnL+LVJj5cOIKQ+MNifUgd+jtwdiUhpFs4HWEnDxq2tWLNf5d"; + sout << "xQ3igiJjm1MTfsrex6ehzwmyueSZKxTBgVRQdLTcSSSiivdx0zO+/TqN/0s5WkgCufzMm6uMf/1o"; + sout << "eWPkERpbGs8c740XtOwHi/3JxDC135mV1Gfau8qPUA4VRqfTKMKwcCk2j7SYzx0RaQvaNDa94aNd"; + sout << "a8gVcQhLQcStjdUOR8smPQtqtEoIsHqLLXDZF/3gWIEjQfYjbchQcscXEi5iM4jMQCP1/Gf/oR2s"; + sout << "BY/eEjhzUNnDxyikEt+idxkKLnuRYJoRVTzfCk3Gdlc556Vz14By+uKeAmmgVVH2+U9WtTMOO7J6"; + sout << "vlpJ3lMEQ9ANonhWe788Jz9pzjzIHhx2UfmZM20I/t4U90+iTJ29/AoDc7g73bHfzugK+2WpYocd"; + sout << "ecImcUaQLdMqL8gyK882q967InpmujoXoDo4hrqpaLsowHhWiAbcPTIH6hrvkNYJ++PYo3QGCYwf"; + sout << "z7z1vc6pUEuKEx02TvYO3PX6e/i7RVBQ5HBjVsuNiAw9ai0fsc9sKjwOml7//NR/o6hdFvNnn7wE"; + sout << "Qk7pP/kasS/nVGYvsxAcqQBBXRDfDEB1mpkZX2bcNzZ/PrP5t9aS9t7LQYfqA7lWVgY+brfPTdmE"; + sout << "haXezSHgRgy8iQjBAVdBl+Y972kmR4zZuW/SiK3mkOe92GD0dVXz6l4/HL5VFflONDAHCpM8s0zR"; + sout << "HIlzwxAHl2wwiKqW2LtUIfnHA8Tb36Uix1nyKcGiYjHtv3+mc13GzYU4KZ4sj4suUs6Yr/XEtAXm"; + sout << "WYO9T8LJ3zzbxMNJZgsVHWx1UWKixJwGCv1HrEAzQNmm5aDBGYrutpNtcC8B2DKIBex+IGlpTwrJ"; + sout << "Avo0T5uJCfw3hCNwVpoSvpaOpP/hFVJYI5RLtDQbgshmFI4iFz4pIN55qCA98tYAg7a1m2g4CH0l"; + sout << "1R1ErZK5A08UpTFnxKWS4pVB1XAy7hqtEd8YTADZvNpFTkocwBFsYSsDU92hTS1GzEQ3+NPIj+pg"; + sout << "i9G8bDLaAY/bojHUktrnaYL47+BmrKQ7N9YWtFsjFrOaS39NIw6WUf4nPd5uxcoqlaM59SrOyEBT"; + sout << "r4emeEzPDzF7GJd69qFOek/DrjX87M0rDcyCOKABNJfIBaeEJKKnyniWOx+ZN4fb9meuM5bKTqr5"; + sout << "p6JLFbf7WFzqgs3h7tOcMhSvpVv4ve5Ap4yx4w8ftImJrW8f2j/OUiX6jtA9shoYI+4DOqYL1mRS"; + sout << "xYK5WKMnkftUB6lK1ldv0ogi8Z2Niil3ZBpWqhY4+tuEvTLAT8nQtX+p9/gtHxNCyGUV9nQk/T/N"; + sout << "lN/kBTdZs+KzJ0K+AIi7K1t4dVKggZ4vmqYVPj2WTKHboSzNg+oxAL8VmZAxnY2bS5AFR4wRSLgF"; + sout << "8yQj9xfyDKTPRK4D7b+Axh0s/ra4XzE0S+M2wk0cmRbb0I8/pG7yXSFzYPM0bVhRuiTVT+mrxcpA"; + sout << "qSJRKzFQ6h+KQLSYUVAspQeGnIQX70fF1/BcYviKGzwiN+T6zLH+thLJnZX/lgHwZZOtHYnCaBBc"; + sout << "t1WFi3G/Ex2+UHo+nPBCFPq9TNEWcAArNxPwPb9e2H+ZH7dNgLZjZtGQoEiEI2uwI5l7/X1++/3N"; + sout << "EWAKPziqUxX1sZeW/T9cLMiDsVajYh6wqn11+RM+lCUKvYZym15JcEKAxxsFxo3S3gbILFXSy6IL"; + sout << "AkGGUsKtITt1LB3PO3sIare5OP3YAk+PzEVZcwp80V9dHq72cSskofr8hvUkzsBZ/iH9dC1dupTF"; + sout << "UbI4YOrA0nQRnvx9AQVrPldyzgxv9y7mVakcgLLpd1YOUDS7We9Wh7exElCVuB73RtXfcd7hEjRb"; + sout << "fu0xg2KT3wOTQy7WDAnidLZLtWeKRqDLxmO/W7370N1TgXS2sKLG6lGk6UY5mEOmoLsBtu6G2bRj"; + sout << "QIZnC0q9x7sUVMyg0TOZ3698MgxWXpvNqocP7+TYdJeptTkg0RZdgON1WLlCs/yBSm9DYpN6u1ak"; + sout << "pr7GDx/Yira5gaDGvnLtoL/6L4/0bUH3UJmeNACcLec9V+qxUJnYjQPYPjGjALRfRNPV8YRsjWBn"; + sout << "eEP8OJG4i1dQAif8ehmV1Q++qHcmf5je2IQwcsrgj3+UIw6urMoRWVwDISw5Xy5Y1HwitWZ6VW2Y"; + sout << "eeeLRYBIWAsoJAw7rSyLqKjaihO8rlGw3IeVrAHt3k12wX2t2f7eLF/MKv/8JjK7d/rr7EbNKrGA"; + sout << "iyp1zQFtenTYrjr1/Wh12/lmTnipKS8ZDCJuMM0i8NTXB1fuisqNI35HCuXXoJQmc0Yfu0jA99jL"; + sout << "T87FUwkmKQA5+Q0HV2C70CfE01ALVU9ZbdP80mvHlMhiLURPjRmYjXl++FKpHU4wvRmflXG2K7KD"; + sout << "aKccgkYqc+Pw5T2AGX5LiJscIykcNaMLVtH/QYPasG8jfgVH3dENDXagJjuYDwkROZ7Dm36vvMZG"; + sout << "QkRDHXApzY0T9GizL6V9WWflM0lopSWELw4HNPExg+ZMSE1eAebW58HHTDZEnu/Yrr/itfWUid8t"; + sout << "od265bhRmkCVeEwIL6Kt7DBdDVDu268mwE/OlCRmE0Z1XHDIR8ggVbDjoQHV8Fzr9YCM/36qexHL"; + sout << "v5VxMx6PzZGBMIGuPHIiqnqaDPHhVpMGbD6Xowd50iE82GANJrgulQOlIg66T5znMtXg6g6hOqh4"; + sout << "WL4qrPm7T/2mH+T5PVIR/D+VJSoEtnRLdnO6zZnIZp7tj3jXjhNSo+YGzZmdp/H1iuSqC6GuWKue"; + sout << "agUBFs+p5Zz/yMrYI31yx1URhmsLJrIQU8llICpkRo3uqhPzpXcZs/MZIS5pWniqIfGnqUtzeusS"; + sout << "bouAVAOgf2rModIu5NWIwdnf+FEsDsw4sYDEI0jkJD8StwBgTg3uiq1jV3FP8nHWgu/QoWWXGtpx"; + sout << "u7OR7HYi+Aq4lK6sVu2qld+deujeO27r4atK+wQfmbuPyjJugTMJEdLfT2SkkW9mN2Qx4jPmABq+"; + sout << "5tCUgibAdWz5SqJBDr/jxW8WZ5YCfAhLifSbTdcM/IUF5Nq+NCT9vrcWaUak2K2uhfvoJjN7DWVW"; + sout << "6KZMzeN5XIRKMZc2ItCW7sUmAy120f3LWzGZtlX49LIY5nsawUDOrPIENNab3CWEPAuKAciOsFU7"; + sout << "CV6BvOxRekHvH/6clgwrfSCGgJsru77XV43H1/wZciyR/4QDMXPaewDGxdH1FNddnkZVAuYWrRzq"; + sout << "PF763f2SQmwM5GERk2bxwr4eeeBWBRqMCK4ZNWpdYBseupT694Rrk/bnUKjuo3kwYrs2LzRXDJVz"; + sout << "Fvj0UjorGcr63LSXNyLhyBGo/Uhn7dgtVUaYzsY5rChqvLthaBhab+268cfGyYhacN/E3vezmySL"; + sout << "J1Bj4twQ+kcbWWQE/Jx5zGkgorM3Yu1vbjilzv/UbmNb+Cul6MgnKDRFS+cGQ5UMw2DD9OyTKJZ4"; + sout << "5k+BcnLih6kE/s4UmFkeC4QRqLMBBhTlScS3wXS2lFsvMd/BXNXWtyFtk+dUP0AH5tgrRsaY2/h8"; + sout << "xTgN0sMK70yvSzojbnExkON5mDszehIEHnfBlL14eqImQRiUKlGQlP55T2AKSYBbpKFJtzy+/aXp"; + sout << "SxmkYMQcvFD+AH0oaqJFaB9JBLT7JW5D4ZwTHpMX7AS6W0LxgriTgPZnSqVgosl2DZLKBUt1zXjw"; + sout << "dbSmrROYcUAXJxMOAj6jNlv7nMkn2xV28QS7lFE+aq5yBIX+UBTF1uqk1XI6OddMKrtzh0SH9Dey"; + sout << "q2NWoOiH5YIxHHbWx2c4fiFrSCo9v2JQlk/rv0GKz5jX4GEI9+UPWzr4MA6GzJrYS9cydN/NXwY8"; + sout << "lvHx0e+zI0aiQkbZb3hNnWsjQ5q2gk2pj4yT4zCFhdL9NBTAwWrWqjSi/Zf8BU5R23WHHMMPBfEj"; + sout << "1gX3IdAZXw6tNUseFnBithp1Zo52/Rv+OUyNIC5o9whYx7T+e1Yyr2ytlKeiAHOWmiP8h+Mqnb0l"; + sout << "BXYuP1fiYcGClI5Krh/05hb4CCW0n385qHMIGDiFTw1emVktV9OTiM0Z5M8PUHDi7IGUyrJcwXyr"; + sout << "Bjh5HwdrsCA7NdLqE+8QPYnAPj+rMJWw+xrrlcPSpRBTTPTPiVnn/PXL6khN9lBiwz6Kn+X61bhO"; + sout << "8yN+OP1IWu/DZUzauOoVqhEPEARq316Uj15VXoPpy2oa9/L0Pqzna3d+8/VxnSlWWo7i7bWEXbc5"; + sout << "N9OKxooZQ6XAFihvJMIsVvkYp0oAQ4AIoOLUSm0w0ejFIWsQ/goNmLj+CQneb3+YUpdYtpLODhln"; + sout << "/4HgBsPjGnEug/fzuWbEqWXb/e+VL8RUX/r9nXe2rGeQtdQkdqsNj8avFShnnth+UNk/2TRAq9Cd"; + sout << "eDYJMWOQgP9NSecba1B8FGmBdvC6gRsp06uKfGesrktOPlzgKmZPoqsjHhrLLV58KwTavZSoIriF"; + sout << "Mzj6ZAalNMDEqUGuhtXZrhTVoRiYsy+3jjJ8IubSS/RxiQfPS3+3nudS8jzA5u4uHJcy42t04z/l"; + sout << "NkN5/kjem3F1fAu79/16XeDD/wk9A7FyjqsnyvNXsiX1yKcEC1Bw/yxCh33mgcAlh5ilP+gWc7sg"; + sout << "QQXi/dnqQebktoqyUYklPN/cYQpCQhhY+YRp1wwbLJYWYrBXiN3TbACI1koRROZw8pT8YKG4utqc"; + sout << "P+IjOrAVYg94bzGeuRVI/T+mpmGHs+lsgUY/WLx52PqnmLgBORLHZzM1sFp50ALnbYah3gAR19rz"; + sout << "twGYVQXIIhK8rewz/RwJJYwSvilMiJNx5UqJc8VFaEnwKcr2QTqlJ9FezDpuE5kvqlYl4tywsnwU"; + sout << "ta1MB0FXAC86i+4pXWfaRt5Y1g48PGEI3TdjkpiRJwEQ6YGSoDngFX5K5l+rZePE6FHNLBoRcIp1"; + sout << "5aNK7cS3MPyb1MN8zhjX400wFp208a28cBpCH0KmHvt8+ip26ejS1C1/aLvdz9j9sQtixw7JU4Jb"; + sout << "ytPOVlxKjAzFFFdaQnVEAkFCKetYCkMzrARSzcLQJcBhhChb8NQDZoQWxe2286/y9p1ImU+jHlzC"; + sout << "iTREz2ADRBGQO3JGz2h5GYjU6YjZn6ojyUyMECTZASV8RbvNy0/8yYPNONYWm8v4/LLOUM0fq8iv"; + sout << "xwag9U6L0Y+TWtW1YD13ThbTjTrg1lTHKzyp6KeGWNWGH141FgXz9+r04Qkvulcn+rGimg4SuvNJ"; + sout << "XLotrbWuhULd/esyJIAqfTtHzHZMSBoQ0WauyWV+so3c2rGDjo197k87SJkfGp+l78kqj03jh6ev"; + sout << "WyjhQZizdnuNdBpm26GnuDYmZdEaOAADCtkR9dvMWdeNsDCjyRcCbFEPiUX2xduSbw6aulm6mW4l"; + sout << "CckSOv3l0/pfw8bTKEszFqeXiPEI2SwIbSx9R1nWd62VVDrYmAkAXzdJ33QOlJm6YJT97+QuU+61"; + sout << "RVXpRanpmC4Eh3jEZKV2BWIuGPXzrDGgAHoHNhRG6vJ0Kp3RqaxFv3cPrLZdxDnQhZQH01O/dZ8S"; + sout << "swcpb3/D5GdWvUS6MsIPhO+nhoabkyxwvbBNzyaLVDoXZK6taVT2GWFyYqguu3ubVBGd0H0Mksh9"; + sout << "luAh4ogG6drswOFxJhR4S8VIMkF3x8PDUj5w8WTC2nlhxIYVh1fxvdj8fEfIH7WzBElS8vpZ1tXC"; + sout << "1JSgHngiEYafDFX4OJeHxKFu994g18YvLkJDdQKzTDvOiD9TaKnnaSGGXZuasKTtD9z87GX4O84P"; + sout << "rye5DvbYUP3bZx8GOmhf5YsaLI6b5m+iX0Mm8ulB0hOla8uE6EfLBovorLCKke5iPFU/gynxeCkR"; + sout << "8RL9zw58ATEm3YuzE6nDNaIekqV7Q0MqCpnpAdBPxjtEGo9yRRPeu8SSSGL7IGKUUcr/6Xp8ooyK"; + sout << "iK4f+mB0hgO0pQqffQHkEXxGg7Zd/eTGAO/n3Acq6bjGr03T6LEO0KFJl1/m96ZEBQK/iiV4m5+V"; + sout << "bZvr6xOrErUE+ih4g6vk30cSQqFvyH574K93bx/uyydqzyPEQTZ0oJmT/KoUrYti7CxlGwrqPDTZ"; + sout << "AuEOwQGqqxoW097Ql1bpxekOevJPSL21snAT+Lf7JZ/79YwZu2WZnWoDQRNNm4nwn5M6IPmErESJ"; + sout << "szsF5VmUs71cj2/gnK7NDU4c1kLjXSsVVnN3/VL1mOnYV25Nh/ktPRfTLn1TDNG+rkdJDjGcjjSd"; + sout << "JY9Ro1Rab18UKm9nE2tWjgbNUQbzsisy9P4F7cBLBd5K6y47wRYZ7MamnAUZMUEDmIiO1SR1uLey"; + sout << "gugfd3Sb93PYjImZCjulVD/w64IThgqDIgVd1BOdAN23laLSQpfuR0xW5k6GqGtiP98VwULbV9rJ"; + sout << "thalp5t6mwsOdo46101N512/p70XcZuX/VyFxw3bGO597RbO2nj+gM4UanMOfTEQch/kaJsk/WTu"; + sout << "m4QMSxd9SAn9/aRXKR55Im98Fx/m9q2GB0vn8RZqfqHAqjmO6F3usDuUFxk0kn5Hq+8jXBephw4/"; + sout << "YCOP8eBKnY9V2k2bnQO/9BWhCyuwun8huNz6keaC+qa9PFFwXmAf54BIYdgr+NspegaRej7bOISM"; + sout << "WpHZ77YuKL04uozXUz6B+0Jy/zDzI9EgLTCk9L1CUpGPSoNGpDfb9VqkOdTvjQphSOPB5PxD7NNC"; + sout << "tV4tl1nhB46JCCEalhhsip0/ZEtx0Hd+UNxYRA+qjaGiEXDPPDyjbYYHWNRhYKo1SDJLVZG8hoVh"; + sout << "stl5hebAI8t8rCBUF3x1HrMRdrRX/GEzLfx8xZgk/YAegRw68TfN9V1bY63aRS712i/twY6KDxD0"; + sout << "V0DeubXqkK+N1ERAn8ygvQ7EJNvGbp0OTu765R4zv8y9u+GqdmpA5E1Ti6+2l/L6008zbYQVWG6z"; + sout << "ejkRAp4UjMSDTf2zTNjP76zzQQOQwsmwodggF1paT6UlXpb5HHp9osSrdwGXb06FiXf2uVo1rtSh"; + sout << "Stzfee1i+BebVOqKM38qGjBhXwGiiiZh/mwuG4Fa4wOKlrqaTVEHgGm5N0uxr9by1E78BXj7Nha7"; + sout << "XuDtqivJf0brsiAUNNtfUjBT5WJrj+iYT16Zs7ds6Fm60mdF7vSBrxXPNw7ZuueSA5xuN2bkeP/m"; + sout << "CQ/3N3lGgTQlMNkMUzgY37ln3xYCXGcMpQ67ACMf4gf3l6FhI0S+JMwzxD4XMZBTxgt/2YX410Ia"; + sout << "ZijtokULoOBQRqLjxREcBBhm+tpSSU5cRb071EVhz4K+aSN7PfcFOTMC1mVGy4cibxeRIq5Un1o0"; + sout << "NYSkHLiVCA4OCLkemWvzxswd9Fduc9mK11xyqBPu688yHH9y2e4Fo5YjiKiZynd/kydWuNsYM9i0"; + sout << "U9ukHpMFmrDmIPZJMGDgZm0I2QPvoieRcpitQwNYh5x3cvsEs3eAAc4mfxQZUR+v0gClhfu7wSn2"; + sout << "XBPPNVyzE16RK8q96Z+w1paaFKHR2t4TRGhmeCcxR3vS+jvqpGaYFJ+YJe8hvJYK0IMZTBSPg6Qx"; + sout << "Cx/cwlr7gStMvrFBKjLGDMAsrw0OyDbIe6/E3Erq43GHsAplYSxYAvOFZUwX0b6VkfAB8CsBnRFb"; + sout << "jnZmZP+cZPIySLLGJJm/oIs/Hy1wYuyn2XpI/uCuBDZifE6fnNzdpzQ0BpH3qjaIF4OV4gghhruD"; + sout << "toj7yNaS1FZl8IhQ59Z7P1eSvQNulTHYsfBw6E8RE780AO+aRHAxVtuS08RWGrde4wV3ma1NzMaT"; + sout << "6fHsCr9IYY1qACpBtyGz389+cnXpeN2VZczNw7JBx3VthteVNdpX7cpftA0e/mazMuKLX0vcBfus"; + sout << "/p3Vj1SxPevb5pslgkfBTVcFxmEM9Doy5dSKQ32SVFgYwTppEuyviNKuUzQWwO/XhsL50gcuUhyU"; + sout << "ELpBLIkpABrbGRa9K9ye5qLWYYzzBzlPoyhEwt8cWi4hVnafg+RHEJ7lXXQ5GlEMlzxPJw18LwZs"; + sout << "jjbC/PE4abuILrfHlOygD1CqQeYzzBN+gwwaTbmquBtZyIMvGr5FrRqz6sRP/W2OvOBN5iVSUa50"; + sout << "VjKdnNUwSkeLCU14Lp1zN0WQQVVNeR3o1JfJsGvQNjD74MnSlHTFSUYloSFgHbM2ANeKgiYN3uB+"; + sout << "xLb+hRJPxIwiTshZyclnLmv5mh63OWU+7afdA07BMdmnK77uFB1rdXFJIWUBiUG8u2yqUk3PKVg6"; + sout << "9rJYvVu3eA7mjBqRvL//koo7VJt6g1Hj+8lnqu71OwsodU/t+W4tmcIKP9S0eafHEQC6hYCu8b5v"; + sout << "QgVuWf1LnTXwTGbm1ohar6OTxyZV07nJGC0xscQ2yzhxnssaCoAw5xFUzFrgo8a0J1Nj9aOaV40C"; + sout << "3y5lP6NgiO+3r2J8wUcKWEVFLHDAYMXiMfe3XMp9/ERswAGSkzfXSxZN8T4RPorh+pjPH7sZJExm"; + sout << "o4ZOdqaWLdLS03+0kRuPXm6FvZCaMY7Hz0e2Zv9X22KYmXluGqcgErFTjYWhQFhcpt8oZA5T3mU5"; + sout << "NDE3qI+bgtkN9vvNbhIrrOH7mKm0af/zrHA8nnU33sc6Gnv0wfCvolpiyTqla+HU7rwapX0+Ge6j"; + sout << "SZZht4l8OPe3/mSmwGjX9U8xWxiIhhr9VbVxZZ9UTnAhfcceB3TrDA7F2SxGmfBSRvQ50uRp9Yi4"; + sout << "14Sasx4kgF6e2wJh7TBRfDMQM6dew3rmIn1p7l8f43AcBe7T3rnWB6UJ5eh14+xb6HwyT260Njvy"; + sout << "vF8Z4nTq/1MCLv3c4lQoO8yuhFA+J3rWlYTQmc2ilfWbYpdiAOxT7k44QYoHTZDS58anIywl9J8Z"; + sout << "uBYa+giIPDINPDLGNobNEPl3kerfygAFfVkxd33eO9LwDAMGAZmbywt/OwRmraCFhT8Qz9jygFse"; + sout << "mT1tZ7+hF2rfTvyF4OiTz2nutfzVvYQ45DLjmbHFQsHMVGzU6xdLQlZYWq+h5tfVng4zhwIge8tB"; + sout << "mA9ndgrTZx5tspm45fVKTjE3MYL5qNsmzghyfasarz1MOD/2wml6uzWZjtljHIzap5gtIchaIF3f"; + sout << "JMOvFu8Y4R0Fu2UNkbOfy/YLRnKlUHEChwTeEavSkaHBkprvRZJuMsIWcvEZHDIWdRte0qfNm9V9"; + sout << "zWmj4B3RchqvXh5XsoDkMqergDisI4gCiHFIxzzfJBsX8TQDUQIvAGk0pc3ib1/6xwCvLs+NXC+8"; + sout << "0sI9e/aruBe9uEXGya5n6Vqy8z9soOoBCcFqc4B+1ZUOyN2zpIxVFd9/20NKmp1W8RjRforBvF7F"; + sout << "2E4IkWh/3RcL8tNyPpWkAA7Snlm8oO8L5cc+lMx0ph7DzYX5NcIc6V7otI+Bd7w8xLSVxzfX3Awu"; + sout << "S6bjrU2af39KbPFq2LpGOakR+1gAh15sd6WYscozdkSZu7kGpuEppnuMxYFmbMrCxtw2lDt4IwMY"; + sout << "ttBHEP2XwF5z6DQpRqsWBMcuBbztZ9BNa3y2aR8pASldncrnRjYQ1Iq6KGrHeau/uq80ui3VPbOD"; + sout << "IrV/JeNFGet12rX0zneHBGYUxG+Pbso07l1fJxd6OgaIrwVksovd1psBmD6qYsbc3V3dJLacNOdt"; + sout << "fvFnHWz6EDWivzMDWvumzAO1JPVFB1TebaaWaINSxaxS3h3IKUVPTu+2Ytw4Plwt6fc/LakedkFH"; + sout << "fxE4uN3LR8Y1bKQN+/VvTHXGUtf+UyZx/VWHaxuiqoxXUazpT+l2muie+/2l6v8JUCBTUYwa+2W7"; + sout << "WWvrBxyliD2V+KaN4dybWey+3RIp68UgV0gP8h579kezx0+Q2Ku60vjCIKJ4vJB8Gt06a15lHOmx"; + sout << "WcJq2DmQuyWTXsqaJPuljuHUhqYWMPnqY9gi53ddw9SwcpgE8cCC0cG65FKn23C1Ihh5NnuBohYp"; + sout << "ly/QVOFi/rGW9kKrF9FW7C93p2DxgBhygKC5E582EIjRXm15bQ5bUSwxfYfQ8VDCPlL1cR4yhpqf"; + sout << "CpAd5/ls0MJ3SFgUS4e96011A5intr3a3UMPqdMKWrwCW6cP+7TJKC22+LwvAvyS2dI+NdyLt0qY"; + sout << "o2t9/ml4u0bbBjwl24o/CIvm7T+045fz+3h4yrKYYPSbHPgGiN0+oTcGhtjBgtil6WCBMuzgEn4Y"; + sout << "zVaTD4gQV2l6Cnpi6fi3ZIOwuEi8aqfRdNrPc8+Rtik2wWdmbYZ66d52HweeALKwfzIm1DESb/Dd"; + sout << "qc+pz+yVNmYp5fffsjf70WrDyGWMDqSVvj3tXMkxKiKRetutfs0Tf3vGDdUEdTmtw4rEca/1K8p1"; + sout << "eNpotuLSio7xp4/gz4cmq3IUEXzk7b++TAs6GAq5ddjbcTJoyRSnOhu2oetJGzWJSglrWwGpAAoC"; + sout << "l5jHFfMXIyi4lggUHjIcpEosLLxntEai5cP9bcGzTw3bnnQMo2mI8bCKSrw3uzKapI6MdDJ7fXZD"; + sout << "aR+zuSlVjO98VU+BugUHWJcrHneh7tRZ5ijux8t2d6y/iGeJm0gs2C6ukA9PTwSV7UxnmVOCsGqp"; + sout << "kh37fINlLQWIjS+fVQZwYcm5XyzPzsyHa/x1DZ9CJgSaLrOCm8PDPk2BHFSdoYMrUGn9CSL6zRNN"; + sout << "TL8KprKSZLrhrk4wtgRrTAkO9twMovO8Pjq/xgO6djfJygflU/Yn3VPYjg31O0SQLXGQLTL2SMcf"; + sout << "NlIqmLEov/m6A5cs5pJF7be1yX0XneDEa7arl7v4tcvMw9Ot57OytQUTbFu2iUb6zviTVWMEMJST"; + sout << "eLj7c/+RNy+lhBjlxG756YIjtWULrrdPBkie+AYHMSJ7PnJ/QAl3D126W7D9ubCjagGVZ2cLX/ft"; + sout << "/pLcIuPRcRUeL/jTGshsfijbq8us89JPjLS08HMiaaPhTtSOKbyqHwVEvfu++thQZsu/3wM8f2a8"; + sout << "CnKdaRAkLethTUyv2xNyQmx36Qg3LVGf5W/tzC9U2AQ1XRE5l/IoL1AJ08JdQDqHSAEAiw0LWMIW"; + sout << "a2alrr54YOp0qiPcRTLKo0gZLr6Sy4jKmQqEFrPCxqPSGMe2pc/pzVuxlYhW9fDMjE1vvfzmWnhU"; + sout << "SSTXpxONjtMDtFWol0pYyjOLyJcATXCrOK3PXPX50e+S2Xjl1iJezlgtdqWVR6mK8yJ63GyAtjzT"; + sout << "oBIBQjydPb9tpuVY3Qy129f/Pnw/MT/6wggYCZDuhhNmOqmSRbm3runLfwSX3G8XbN8LEw94wy+n"; + sout << "yMs17FWzTwWk2fMEIV88TuqsxDPW0Ko+a+OuiOQ+7TGVdOLxtqmUDW/3n5QA6Ryc2kbaFhZAhONU"; + sout << "zJXyJzll6UPZfUsmAiX3u/qyTWlbB3EKngamgmD/1WNGX6ysD3RYPu2olopo9wOqihwweOtxfQYr"; + sout << "QdXhUKreENZAZ/lOjn0tSfAFsB21/MuDNpuz37dI0ACfvKyPdGnFpNocIXHWmSWZbOB6OypSbzMP"; + sout << "5r1fNPRk9qMY01lKWRz+lkFGzl1vbfKfSO17q073cCPouHosT+RpE5HBBf6E0yG84m5o3u1bLUBb"; + sout << "8V4YyMREOIa/2wkUH99l3YksFc3eeUMzHSIoHrvCQq1CTyaed3tNbZw0rHPLzwSSgQMZAhsi5mxp"; + sout << "ZnMABMsXSHdvgBRkVn6PH5rjrKAcSnikddL8F9/FXvJR21nKfo78g6RULWdlcc8tPQR0E+PyOoLW"; + sout << "VR3yw7L8Ddl1zHu7hoerz+cM4gW+67kjehlzV+BAA70h43CfxsmvZm0Kaw9GOkxjDs8F5b5XRtdN"; + sout << "gxh9Y5o/vlxI8f1w388jL7od20NPAYSqvONaJjF4nthDN+CqlNWjy3K44Xeghk1o4smFlbJM1ZdM"; + sout << "taj+ZlPozWzYEugvDe7/9woqdFtYqCWDlIAcMGqw6QGFGE9EVwP25LeGWnM4bcSWkdP/Fc+Bpssq"; + sout << "QaRpCt/+Igw9C26fM3zSM3y+RfDs6dseYokawF7xp7oo++opFuMO18Flo9dJw9HeXMZ6R8lXrNrD"; + sout << "l1RWaWakmYp2KgPJYT9USAV6S0JI0m9MxL0R0t7XT23cWgtQsG1m/gLb7HT+CFNm4SYnWMGpza7N"; + sout << "4w6hXTItZLuRrC3pH6AF/m8MuwdlKJEkTn0Mmu0e+MVe/DJmnA2w46oxML03wNy4qZMj0LywOnIu"; + sout << "Ip8p1oLIviccDGpnXqZORZGVMqzuidIfNBSgalxT1A+ySEAVroXtGjtW/L7sedcTElNyK+aspa3l"; + sout << "O10XSF3gjgcZnN7W1j3wiqePmtzL/RlZj81ruAoYJRWaNUDT28p67/k75Uhl4Sh7qWwc/mJnG9qm"; + sout << "rvZSX5GiW11EJs35JJ08G3UHAUa98VEIOesPMaL1Wr0CVHFVC9eW2iHwEq6vk3JrGfsKARrZgg9B"; + sout << "YDhkyO+3c+qNgTk1WMC3MXaSZg+mwqwx0TGeJRg6/+rcKr0Y+y0ZEDLz3kMD/aihZ0ictCZlcYAz"; + sout << "8pqdirBd1Dbi8AaMCJTWHGWaOIxmrdU2XyVWVk5Vpc0rmzyc3ya8WfXjkdDzXKUS1IIFi34b3/7G"; + sout << "qYpJNqtVAPT1TXnnM+2koRgGaCARu/RoKzl169YpDYPHVJuRJDjy6lfrKNEqPzJY9phbh6WhmZ4X"; + sout << "CZMPrgnrs+0Bs///61zNUJA3fArNyessRIkZqSxnKjVOHz6/LUB7/UzJXAd/XfMYi4yNw72BSAjJ"; + sout << "Z2AdnjGIO9p0E2KIOVx39bETt8l8JKOyT4JbimwFT0IN1AcvwH4wgwtKYAJ29L0kSNoBmMwaKCAj"; + sout << "1I1b8usj6EM/OUjn+e9xmLMkeAUo8HM9BvLP33TuKJwLINhNcggbZVeaFSwBsvj6hEQ8dN06eLJ/"; + sout << "yIGaRsni2n6oXeRUo8/TVWmNIbloEE7mQUuttR1pVBdkbJsHGWm/S5hwEmsKKOpvQo44Rao7fsDK"; + sout << "cCtKATplclwuCDCPIogRMkpHRSSrHUKE4tzdA7rWWj/jaYc1Sw6hG7OaDFSWuP09Qq7klmC3F+0U"; + sout << "CIAtQbKJevYhCSdWHfo4THHFR5CeuA+shfbdWyPNmP1HRA3pvo52jSsJi8HtPBp6XmGmM8dN8Vaz"; + sout << "pG5d28HH76fj2Ny9McHrxGFaV/WhhoST0GDqN4PnqV1X6d05aBKfwRbqdWo2XvmrC2YNI7Ou6+3M"; + sout << "Ovp0NH7MHXpLY0B+GMf+x2X8hoXkffnGCa5CQ1SwTbZRkeGjejwGhpOybsOl7AaSA/fqymEW9JMF"; + sout << "1E9XEtZtzyATkj0cBUlofwC6N3N//altqBAbmDKrHThV3SDsvn+tpPiZG4+pQQOWqQomy3QAWwVV"; + sout << "nZTorQIvoPMUdJWPwsVLaRldOrUEjmK0OF36+QysxIJLgj9x5uQDymixyS0Pvu6ybdB5Z+ggH+hs"; + sout << "A4wt9pW2wbaOGU6jjYN42C/ehSOClqGP9JE1qejqcLRAXm+JDDFZGmCBfZRw7W6R39cNfrbbACVF"; + sout << "de9yicnqvNbpaeAWAmfjrwg0QvVVmmlvsgDinO2OdZ+TLx0KKj+4+Zrt0fFfXQgUa2oIFGoeCuQv"; + sout << "eyjv6ZUnkYHtJLeMDfuhaxaV6dIG7US6tdGcxKxw0/Z3Cl1AqyDSOSeZIfABudI/tNrQF3bk/x3g"; + sout << "Kb/i4VNKLwEq6O1T8HrZSJzTWus7XcQncGAOJWTsuC0/0Cv/Eg7pY87c6VtF2/kltcQeH0qWNix8"; + sout << "qHpkEC5Nmjm2erJhQ2gsdVQFrVjxv+irSaO7GLhuVwFgj0pMzKDQTFvd49DRMyrwb9cml+u9MZgs"; + sout << "Uqwmc0zXzCCT51uF3LMqVbeDEgt/Q59+03U1N2QAJXtTScjJO97PYZSJvMEPONjSI5AeRlRqpD23"; + sout << "4O5xrQGkzb31WrZP54ZyZ7Gt4T4bxdwVIZSrgSrGJt8RjREWl6LuDpuRsPgnt85EKNRjf+icgibb"; + sout << "sGxnWWdqFdOLHdPKDyeCTiJb/5dc9YvjVxWINlgOMxdYpUU1OZ1hX3OmDDepo9HP5OWA4GnGXR8Q"; + sout << "cNS85y2Q+DS7RF+qxSbIO7hOcNL2PyFxScBXYyrkLUWWxzfc0BnyCaCg4lrMr6uH8vnM+7u16IsY"; + sout << "0f/y4nxYNholE4Q9JlbhXVTBL/ZZaC4qU+/mXrEBkUVY/xj8NqYSicBJpv2c0IYwtf6hwOFdMm8H"; + sout << "y+l26vbWkXEGS9XV/dKKKLbW8EqhKxtlRGB+Tle1HnOp7HhcINVwH97RL9vS8EurSNlSWJCHc/1P"; + sout << "Ui+ZsmkNIQtV6tATb57g/AaResCtNWa7znBxzJUW4stpxdxjD7fLZxyP1dSAlXDluOaJ3vaUpdnt"; + sout << "JqkG/gHcvI/0DK84ddN4VLKIZ0XTrCxvTQJLJcUhA0ENgzdIPQj96CN1XvJ/DdFxx4YKBnb5qyA3"; + sout << "51z6P84v5wm+ciOVpTFERffXwtPPfro5z7/ZWazbUjB4AcOFxj5Sk4EweqTrVSD2Ujv3quBT8SUj"; + sout << "A8Mc7LBMtANusOZeMURs5vqoI5O5hwicxsnqHJeGe7FMRRVH9lrM/+3W5dfm+jJ/ntoZJArHCrWQ"; + sout << "wJqFWJksHTraWTJ8crxkGSgZzFhSHv5pkz/FfVPRW91XprikZvTlVhgyTFbB5gVGeEu1X+M1m+29"; + sout << "8Gqo/AYNTsXS6KfOWnfc+jfQRkkjiLsBS2vsnV255hPBz7Fvr0mqaN6AawiESUSXcFBnfzWQbOdN"; + sout << "wzA468JTsVoOz1MUXZ71hOcXpGMVdKN5ABH+8imVfDYGVUUTXv1e1Z41jrzXVphOQcQA+bJYc3mP"; + sout << "K0HVQQfyuwo73PTJ1pINwbi4OrAJvTP1vEuo9xMtyc2d9zIN5lVg0gTlEE7uwGWVctS67onu5vd9"; + sout << "cQ4/+1axVD2yy/Z+s9JGRZ78DA8rIUqVhOTWNlqmqquxf/mh+nY9Q3PdHfcpNoqc4g465Xe3CvO1"; + sout << "cBGJT0vnZw1vV7rtQ2FqAB1LPAp0f8nIT3VExu/mABWauAImnDE7ZiEnOuiuq2vdS8WspLMjktCg"; + sout << "vlvwF+7AGxfwDCImKSzh4Ph5AA1VuYW3sVQ5I3XuONfappCoJjZwiV8jyKC91rHY1/jP4eMb5vwy"; + sout << "8T5ciRyzv977mW5MK5sx3ujnlVFWr8y8VDhp4qmhnvvswLIKBRbrmW7jkG8Jg1YB7mjY0rP0sVJl"; + sout << "zICdb33DVPWxgRHDjyO189HiKWkzbH7JTQfKOBQwufTNQ4v94HlyUYNTzHg9S6nqI7LiBq8zO784"; + sout << "hGYDRvfB37TziuwR5oHeQ0IiV3x79V0+qLUHMY7ZFpQB+LnGcxG41kvposK0UBUCPDLKsLNQDSn3"; + sout << "ltl6K1uTz/4gjdLjNVlkpO05xJXjK4JEoG+siPuABINp/7iLAgYGHq2G9p8KmKiu2CLtN/g5Aa37"; + sout << "pewz9BkOuvCHm+J9/V8+rMHTCxD3QT962/XOO+iguzglwO+YlN3BpWTnr0VhpQN4ZZZcZcucEze4"; + sout << "PgoLS6CcGSP6xrMdnasvqM1kFCvgOn2ROITQEJzaakPDaB7fzVIw6WyusM85BSNnyc8XT2RkDJaf"; + sout << "VTGEMba9qIAi36u+pC/PH6XTEW+x6nehFJu8q8e7B+VrnN9lkpvVHVDxjci3CHvG1QLPGl1GaGYJ"; + sout << "oLeINwB1VTKtRtEyKENgc0MS2FMk2G81MoKGKVKId2oddu9ACvHqzQMbOqnzp0l7RJewHez830/f"; + sout << "8oKpsUvl1N/b9wfTXJwN77MF/n2zAg3VSv4HInhkdGA9HRSg1JylghqBAOrqpRsh6r/TG6ByJ1uQ"; + sout << "rv0aYvVdVaHQ0kHH+ma5Mlc73U0x+ncDjlicbrIca7x6y/b90emfyRcN0uEKr0fZWtcbK0+1XJXX"; + sout << "WWMGGDUO5b5cJOpDzW7GJe6i1m6yIPPlfp4AWLa4M6F/ys1ONLrbXqjYCFJY4y62OeRK32Ff2zN2"; + sout << "/Lw3hMt6iPTm7oiJOJjDcamybQlWOoYIv8tkEQYN2nCX7k7DD6z183yX7KtJF8Sj95nOeYpzAnLG"; + sout << "YOfY00SA1cM4aLRUPG+FYrzn8UFF8Bl9UFrLXoJ/IxXOWKqrbjhsKxMvC4xriww6KZqHx0L5FWGO"; + sout << "lcFkhdpx2fFeswyfjQwtyqsGsiY7MXGsWA7fA1TOv+FZhJbgCixrx1rxwJaNel8uMo6pGHsZt1N5"; + sout << "Cr4CkiA5ZZ0ITa63oOGM0tD30yaSLkSDZahPT7xo5s33HNAgXNxZmYnv8M6p421yuNPraabZyJ5P"; + sout << "3FtNwzPpw7Q4BPSEBoKc/UbY+mdzm8ES0nKPI6Oeg1tzc9NgP5JaNVgWymA408iLfTjxaxS7FqOw"; + sout << "BCYBbz0EdJigWmF73di27QmfsygTWbHrfdqxCafzsOPuk7l+W26tfJTVxP7S5snKfhbqVNrV/4M3"; + sout << "uhpZthruy0mSHtBx5CuXAtGNQoesk06k22scTmX89BgNsYJAshikS3IHg+LMMfaxBe8KQaaySUX9"; + sout << "j/ULsdT1YZKTeV3NVn/Z01BxF/ZPvV33Dbj4bjEcf2N9byWwicz2CFpkFyQlXaSPPSzMYYtSPCZH"; + sout << "QqXL7ZKh1Xkgh68h5WyuoKkzDXGuZcl0bGVWXSOfO6gK1Mz22to6vJ3lfssNBUV1mxA7Yos1bUb+"; + sout << "tGBxyzoBmphWy8dFCdE5vN/WS+w7VV6O/JQ6j8qbpBn596kNhL4iaK9Lzh0yXRR85nqJbi8bkIZZ"; + sout << "FWskXD/A31jO2IUZKCdB/iqtIuEFZQnq4LLgCTlwF3GbxPeZTlJ1yKcy/2e7nIHZ7xHeG5tzyqjP"; + sout << "oi2AAttJbrL9n4Ebk20aqzYxdYUkR8DcflVhCjPDfeczGwzwJxwznHHhEV/OuxZpWYRvTQDOuUzJ"; + sout << "Qpu+291rSMqntqfrUo7uxCGJLmyRwWqfesAluH9y3GQco+V9yYzqf7ekdNpuygq+htLNc2lG6deM"; + sout << "02oxJHjYyjDlG48zlMEFy9OhLLLPvFVmQWVL/D4FwAwKSQfYWmkUbgqgB0QwHBFSA+xqZBWZbzoC"; + sout << "3xLt42o0hhdlZkXEQtGhtaqMfU4rBbKWxKx1gKhHYk/Gzba+PSi3J0e1A0fC4chtUfWc9UTZT0HO"; + sout << "9QLCXzuC90KrYTYOeaPSoxfOm2aboXBWQfyxZGX9fKRdLgfJ9+8EHw7rD2TUTJueSdfvVahIRi+t"; + sout << "eiEAG3CWoToa3S7i3tpR+YczkNVTl9DCuwD/QMAuPo2mul4bdVQt2NVnoz7PKMWim/JYSFPlJBXP"; + sout << "RaKg8manuhMoDfHTVx3FcToGnIDBd2in/rfbEjE46d+98FRs8p/zhmyQ6/+xw34DKX/NrnU42qli"; + sout << "1+erGDbJFhLmIoU3OhYTiABG5JVefsXQFgIL4kSF+Scratkn+s1aW28PLoooaClnGF14imxCA1Dc"; + sout << "Y/O0HqBFR5lAn+57fsKP195CJBvmyPR+UE8pBE818w09TvBmN6onwBCFtsrEnzTmstINfz0TX5lP"; + sout << "7eOD6+XiZM3y+5YBEK140cgAUae91wD0Hjz8chAlXmgE8xZfalAjHnS1c7ROuPxV9zhj5AjnXV/Q"; + sout << "7Bllgaqoyl3q8Sc6eXgEpJMz8fyZ+LfMkKkMBycdjKG0AZAkF5902y9b1VPVIhgIHueHDt5irhQE"; + sout << "2slLN6pSkRg/3FxPb97BOt0/zt96EvehPVGoKs7OlOd5C6tQLWio1Eh9IL67ldEsXOiwaSMw8QLs"; + sout << "wAw3Vx1VfWnDCReIT7/yxZGHmEGhjdtqs+gGPtZ5jkIp0l4QPQwLhVL1tC4RSWlt8BR2Nq0Ivon0"; + sout << "E4t8jhH5zWeLYZxLIVkyW0SKj5xi7s4Xo2XOZL95m43Kv2zrxzNngy1oPOXTAEUaEwuRKmo90B5k"; + sout << "xdxVS3x9BarxlfUKwNBvwQ8OVmJu6VPWQ5hx++9ZLDfwK63V0gH0tIU69q3diMOjXv+fmvQe28cx"; + sout << "iUh9o3Em/w+7P3DTZu35eSPUD6yzP67VH3pDvahh4uLzes3bdTYmpo39kD+gFc7ADodPWp2E4787"; + sout << "JZaS2A+efTMt/OPen2TV+wMSSK80nbAta3VE1KSQogWp2bFPi0NXqC6GZ+bHV1DVnw5YGs28V5Zd"; + sout << "RuuIzTLauK958DBeAc57QbcAun5sEsPtg+CICI1pcOGhv8hdLzT+YKcuUhji95EAYzCj41nxMM1b"; + sout << "wP1rH0TG7jaHvET8IY4dQb5VvzNXVne0FExHIyIFGIUWoyQgjROX+cHDLe39tvjZo2xs27/CAhZH"; + sout << "dsrgamGD6wj8ZMH+8gHzu6b1yy2crBAw1Ot2JdNhW0kp4DNdFgjqMcZVCfDM8qRZ71DBO82U9QU9"; + sout << "hXh0DCIAyryS/6W9mY2I4IxPrbSL4RkmKzhepcIoui74zu3p+lKY9LgUxYbGNOK46S2olWwJYf7Y"; + sout << "qkOPWcN6yen8rO/9Wkvmxl06qSDDhcYdsc5S409X/VeWz67rJ2vNDiyTKjUJaDWcSU/znMnhtosm"; + sout << "QT5LyX0ZVd++zQvXfjSjBLkJrki5vvKMZSUKhiYL3mIapZpcnsOQ1WRwnmSEave/MidONmZEP4rH"; + sout << "OgS00fnCK7eAWXsZb/rk7tq0cjAFsNeJsHe53E2hJK4TyaYcUi1IM7ZNSQ1yqMYbcFn5bEHBpuxS"; + sout << "+TUy1QppziHZ9xoLn8nOxRstSGxrhBG3WjfMhOwR6JNOe4Y1RDfxVHYeo2BAmtKUutztKsf04UNy"; + sout << "XTP5HSaKx1tZVMDn2fkIj6W+s07imc8sg68Sd1lwJx09bG9ggOw6D0eouy/IeJi9YQAyZT2BqBQ7"; + sout << "2D5CrrD6hqhVF4IVhBPCkJ8+9hbFs5EA8XhItGVBPyXg4/alIJq2ZeBsjhK+tuKMunwOl13sxzn7"; + sout << "iSO4+wO3geQ3+SzKShzywLk+uGm2BdS6teWxatVdm5hKNDirghiEzqdHdwGSoZUlLMYp3KTdzpMg"; + sout << "eAAPELcLl78jObmQkun0ECGaV0paIRhytIRpV3yJGmrRIxWNI5WrqLtocN7hnm4fTrzdlSGnDojf"; + sout << "CBmISIwp60UVjJijLDTVYBIP3VcwaoT9V2F2HkDZ7ifQUTvHmxRDFHuquws0bNRjLSBSPa4jh6th"; + sout << "6h62qlgbHQ/BBLIUXi/I12CwYwlg3ute/qQJIpTCsRxf4rgtEjvV2dG9ltlo8PPMhOd9fkqQcvua"; + sout << "BJaZYLYw9rxq4U4TY/IxFQe8MGJ6It9BOFHD83v3awKPKP84inEWXAkFyI0YYmzFjW/gofECo69R"; + sout << "Wk4/2wv+K5MFEEIzBUGsyrjz0TmiKvzEJ08KZOuWtFXeF2LUGvYeL3TO8FyLRpPnCAoErjZzVAhI"; + sout << "9Sd3ukmv2oZ5oYk52HrUQZ9WywqVgFL8r1KKEouCd3iNhmb+2qUSh8VcvAncWTQbn8yjpGe+kwUf"; + sout << "HssRskQ0nKLIvCz6JwWEjoaglBC3PiJ0EOwDLNP1fKXqPtcPglnu6aGaTrb/lOTCX70yvk2XpdPE"; + sout << "HzsxZHKsV2LxqDI/IT8yIxvM67v8b9CSGOv64a8MSBZ5zYOQR0hunA/iYWfjzuvX5dgW90GnM5ZR"; + sout << "VWoq9zxcBEr+Nk0HlWHpPm3o08OPGmfx/ndMbLLtuH85q4/7Kelm3NSXLXfOxjkRuSZkcAeiRk8e"; + sout << "64X1nXADtf6z1Uaj8H0JfREe91GPSUfyZjwI0pQLDnnsgGoG0DGn01qshIL3ANXsLTHDsdgjaFGP"; + sout << "zZ2f/7V/NKLrJ6HrVpWoQiwhNGY5OYv/KXa2q2XMGybUaoONCdnjGKzPYyXA6C0utQNy7JTKcbgV"; + sout << "/71QilRGXbV8OPBFLLN1VNHV2rQJDSH6UL4RYDdhRSsItpsl2tMJGGS2RLH/FwEv0jDo9e/218Fd"; + sout << "0P3DIy3wAevcgLQu6Ex2SFlTpZcvO+yidCgCv56rJHlVxgN7eSbcKS/l7CiAcafTsjsNm5VtgTx0"; + sout << "CM3ehbIeuPrV5GqbZPgbF1akOkP9vzeNsGGW5kep471Q+xn5o78MZBz7YsMT3IAHd6RViNn/Rrbj"; + sout << "gLWyleDnUQ13gY1xNXl1SD8xaQVvWvvpmlZKyB7uIPEL2VrDZSxmFYyhdhK64yb4Nvvawge8jIWM"; + sout << "a43Sgazodlobk0CkFktZz10gZEUobBos3idIWk0hQP+nGOLhiGO//fwpiplh+g6NBMLj6vI3GoHZ"; + sout << "MKoqqYh4LPDaEtnvILpMCP5NJoBbxNLrUALf17ObB5Kgz9YQsxRIhtrXoyK0XdU98ws4zxDGn5Od"; + sout << "Mj4Zpw/QpI2cCvc6fiue8+l/6BsFse0JL8pvOZDYhy73kaVarhJhQgoZabwnyATtN5VltM8niedj"; + sout << "reSbVYi75bWEdYZrO1PESHmhr7lx+s4gBJQd7kz8wbRADk4dJBf8QDccgMO2yxf6biY6ezjB9GtZ"; + sout << "DnGrIuG867zq9fMbooU1PUrXEMYkeeccI3d3ClrF11AiOxFAh/tSSj05pwl7UWVKDXJrP/T9sl37"; + sout << "qW5Uy21+/Mi+kNMCaunB+fgJHsjvoQpwSmEb9k7b35H04v/qe1deEKPua/xXlawyTt4vv9OFwLFV"; + sout << "ovcegUF44BJCRF7zlZnLMzyveCE2Uizux/z5zXOW21bFEaODnvc9N8oTl7XkQAyDT67JrBC7WbH3"; + sout << "g8lzwSjSyC4paGxi87EifET4QFZ/wdhP/mP8lZXzkq8xieb2ad9dSCHS1T36XB40s2QmiSu/dXNN"; + sout << "yaGFiWtFCKYhie16fkeBuCuzcW98RXOUQYBVCbSaDVM7nNTXU5+2B5NfM6ts6+7fO+ovsJh74pGx"; + sout << "HrPXVdWPNb1Y0v7YhX7fLu77zpFbr5Mo1tdL3H1P8k/y+ep5/vsBE+git78nNPeoZajrbKuf3QtR"; + sout << "uKx/sDQrnt3opQr2k0BEt2HZOPeDw0HuPJnpUNg0u31w+xWZN5YsTO6HBpfg9x5hsYX9YybRFoWJ"; + sout << "bpfcNHrNzlc940Pbn7cy/XKiIsegvY8JKScIsmP0pnv/3rr3M+FDocBoe9QDnj7a044KpD2ghnwl"; + sout << "4uJjqFHv7RoGkY4365yAFnD/vm7tEvO09ibX74jzreb5lcSu/f0hy3IA2IbFEYiY1MBL8MPDBh6h"; + sout << "dKL6Jf+UHWrz7MM2bSfDzf467cr65Pt+N3bmX/BR5yd5sPyaiqQhEewGImg/bOiaqpRnIKIFNrF1"; + sout << "dHziLHvI5n/RvsyKPc1ScVFFHdGs9VMXR3ODzG4Q4wNyGYm8Bt9ciTw+Cd1VuAODX8fjGFhCRq5m"; + sout << "wlIDo6rwxR8ME68IXNMqztq1JPQU1VqEXWGAOBIewAYPLOepF4PR28SUgEGqn32oSG6hutM8UAK/"; + sout << "4LEsqVBnc1XrCczlpF2vad4Oux/Qd27JfspaVThUqSc2x0kY1Ca8Keggn3j0s5GCSXxzsQh0CuYq"; + sout << "jSJjKBVxbOP7XU9EHt6w4/9SOnEmd2QVJ26x5o/znpYpW1zTwrIG1ZhvU5AhBtzFtlLEOs1WzBH9"; + sout << "c2jFEPc221taxwBmLaOqa22VULgiW0zJ0nFkXZ4vfoJghqWOK9N9ux14hpKyWUtPPwPHT+oqctBf"; + sout << "RzNzQ/ODTSIkbFdJzskMR10TAYT3zsY27dWLgdAnmwXE1xr4SCbilLOEZcsJJDEX+XkRUsMbxKxp"; + sout << "Cr/IrdTuAzx5MogSWKsUj0W3cT+EG8Nv+iHw/q5PIag2y98b99gk73KWnswo+QXx6nTCwuSuON2s"; + sout << "KlVN2mvKk/t6dNJCbLNQoEnIxoHumQimOrsmgtD5rTXK2fR8rZlnBzUG5Bp/45YBEK2a4JMWivLu"; + sout << "nY+zNxxg0wslFV8OJaH3yfssBMr7lzi3jX0D1dCD86mSslQZ9hAx/pkgNPogz8KDbmoQjozc05Dr"; + sout << "nUKPCCTYq7OUxpGXu9JAFsQsbqnWIkqyxAg3+s4M2cc9+P+dFpb99cQAHDIn0LSBR57f6o8lojFg"; + sout << "nPtdvUF4gNJd1EVqy2ERrWmwRQ8gnzLHVI6wJYlJVUa7p6gk2yZBz4yRak1JJUgxp3l1goGG/0Ci"; + sout << "PsE78UDJQW+3FmtOwLuphs51+Clq65O/n5RIeLQpQ89mh1govqglh7o7nHyyulhQTrEpBEvy2GWh"; + sout << "LhFNp5FWc1ZV/yl4SwQEIbSpGk/MTmlPM2vZcOwTGNIxHia1PISF2LNDMDMmqQ0N8y/NlaKha2T/"; + sout << "ZfRhOQdR9nOgI0svhZQtke/ExHkpYNHsF4bbGisnGyzmtrKgr47s3l527s8lP8Pxe2g5cgX1focZ"; + sout << "aeI+CIeKPFynBjP5GmbCoBpOBpFfOgPrMU5cPMvBktXq27wF0eKbZ1pi3WvJo5Yt5SNlRbs8cX6V"; + sout << "dM0WDcqw39ff+2ztBdDR+aeFRmHelayiPJ6Rmd3dWgH5Zk/vIGEaSsc7/KfYuJQXkca+gR1SVE67"; + sout << "wez78yJbRDcoXiWhFOC4ilWR1DDKgTy7ej+qsbEtgEQcAqLukxeAnICNaIYzrQsPacqSvE6MLBiN"; + sout << "DcoQ5TnAjtFPGGkkhYH8TwD4kTMQgSDXt5yqADxbAMLu+XsmgOJzyAL597+yMJm0w/COAOvVYPRw"; + sout << "kgsDb3IXuTpW+bL/xGC9RvOefz1UoxpwP7znW1IYeltHBqT9JLiGLMQSRTl2x+OudBurZ66LiqQf"; + sout << "TnX+HXrN7MnTdshMCQ339YwnY6b1baYKnxVwNdIUcLnpyVqMCKZNu92bPJsdj4af5nZuIP8HMwoR"; + sout << "g29BOeXRb8UNtWd/EMGiuHiNwMMhpDBN/L4H1f6GS6ZeNobN9k613HrCeJmlXQDBWCHE66vVsAhX"; + sout << "X/Qwi0w5PCLwwpoIsgMDnZDLa9ZCyfBez3VP48ur6Qs0llb1j5y2gHz24yGH/fVlWPdnc0pbKvFR"; + sout << "KKW9jVKicLU9NGaOq5vLvBmFMnx2YLwhGMIlOgnnDLUkB7wCY3ADrQau1lPLfQ4QHr919kbUpCmq"; + sout << "3syJxq1j5248HvzVQRCQV9H35RIQGwd0fjwtK8Y60fxa6W0tD7B/hf1b9lrodaHQL4gYt9/lrCQU"; + sout << "LezUF2Y3HN9LxAWQNfoCbYaz9HebHTfCM+ynCUXAXabGngBLCat1GxbYaWlbotAXawMKxRplVYnB"; + sout << "mORrYoF3NaBlDPznNhgiU0Uz6t7Nh03r3wR8NMho/VNPv4XggWVXyyoNvqY1NDZVtHEPVaSlVq12"; + sout << "yUOiQ2qEpATWET8P+19RL1KAtCoRzPcuY2Ikir9MR8sTbog1FWyTqZN42P6Bx+darDn7eQTRmaHv"; + sout << "DrC97SY8szr6hZEA2ei3vM0JpNWKYTOHo7pTlaDWt8gEDt/oQFPUSk7WJmJJ3DeXCGOdnRABlCXI"; + sout << "NXDZx79pTuebQNZWIysOeisk2h2TXVvghE7J7QyqIYh80zqZn6nuAy05mtUT5igW4twr21r/X+W+"; + sout << "Eu+dnXFU/Cw04SFh4MVLMaZ8kCerQ21GQn0NSgFCPd/bfeJyonCdXaO4rVATPw8tgbsI93po//q3"; + sout << "Ps0KQWV//ZPdqERqZJYkZKZm7CNJRC9mcLsPAfnE8PRF191kr/A7BjLJu4UgxcGZlyQbi/uOJFd2"; + sout << "mErOsNiJIC7MuRTh3SF5baHXsjiVg+l9m9lBGXmaSHAHBsKpGqpFHF2UuJnkWihfK1QoNhEkUhBV"; + sout << "QOI63RPKJlgU7h7hi8fEzxzDbXbpWhsMwcKQ+j3UBSBzUGbRYznWar5vHxf8GQ1LklLRr+8jBJwZ"; + sout << "fIPGjtKXXyMNJACNv0iulszYy3XUM56cIdLXqjVgEu41sMUfJGvbi3aP0tX6wl2GO5gptrX40j6d"; + sout << "Qys2Evu2rH26dj/oCWFOqK0EKcH/Nfj/+EY+gm1GI6rSTgKq8RtD5J9fEe1nro4Sz+B9J230FWd9"; + sout << "G89GSwEnpPOE6hRzubd/NyFZqe9sxVIFdNqRlOsPbXC/gJw7Gx9B83I5j1dpYAsqWyBKhalrGRz4"; + sout << "UMLThFH6rJkGO0YWO2FF3Y59cWk3rS7taKomqL37407f3VWOtRH3YNhvxXI/6m0LV6czsmFHrSp1"; + sout << "lkpYyq/UlH4kL0c/m/P2w8JrXwD+1SwgMtMjuwIQsETng1TeEgdyRlkRnmKXVFDfSLQSppfq3U4Y"; + sout << "zrlnUiKPk511ZcXkg7TOkHxIMUpkhlSihUaBPwT02QpbcJR73i50BzQNDVQbQfMrXny8LEv2D331"; + sout << "mtjKrWCWThul7ReXcJV7p3LpWuJWD8h/6dwx8WuBg4/zKPXZst6WHD/xoEjYG8QNoDDOE+SylSks"; + sout << "AwuI4F5meYJNpFAoCj7qQXJKKyzyeTz2Y0nISbP34Ue1H01zu4SBN7OMvhP/Kl7RSFpc6JkV2brz"; + sout << "zksgJZ4vpuqdij2S4lQfJbtdrk8smc1VWsVtkRKZzRfbPmECuz/0aNHsWe7sFyqISk2Pe5kwuj3O"; + sout << "3mvWDzXBK05J63FYRnpemGdnvIkKLddEmxNecMuotq2s7fIWy2xeu1Ge6/8ILJmtwm9Bq3afgFfw"; + sout << "nRBQ7dj40h9kCgGUoAx0a2oEwTdvdgHp3eEjlyGMqdodEC5F7Cn+uSh89UsHUYbxAI/qTpYekaVh"; + sout << "5MRydNwdIg+3XCHQi+QGEr41DECSXUAoL+mCG6W87WE70G1BsvA3ofdGtKUxTXskdWVtyNip/qXC"; + sout << "f7Fez1i7Cd5U8rzkdCZD56P4KKx7H9N10aU5jIyfba6IKOG6WEqgf92SISTXCVuQSHqVBNHoGiG3"; + sout << "r1xxH8KEzApRHJF+whOvt+pP8eDVfVIdLrzDH7YP5nzQvnD1eVv6iZw0izQu7ArxMoqgSXOZlKBW"; + sout << "L3nkAohU98uR9oYPGck8HNiyKzuiI+DCeu7BOg7unI0WJsn0qE/F6EFEFYoqusBXRgDVwmySLyAo"; + sout << "+0eB/sUawsnUh/w0lRHBso36uynLjUSMBCwppQUU84pJ8fqr5aUfdGfWzogc/q0pFxppiexM6xHd"; + sout << "P8/lVcbgKEclaNJRkE401wDUWHDmtVAHiaZA/532FWyrucLOp6IRgM+eZAtao2/hQUtwOOcaAQNa"; + sout << "DcCN488KJmPQntnz9zn7Ep2ZZvsrx6fg9ubEEjHHIiRYSUDClLUqqFAJZaGrGsXmrSDxF0AEfQgJ"; + sout << "aZF7CJ0WBl4qyUBcIIeWr/+u7Xm9WDnD3Aowl3/RVxH4L9ffGDEMkCeY/rmNp2SDuWfti7vSrxge"; + sout << "NSQEaqiPPMHrgQ0vuGU7ipACRYtW8MgxSwxMc8AaAbhPQu8L/ZEC+/pfxDBy9mtMKnF8Jj2scJtB"; + sout << "GKhZ4s1s7F1fqA/m3Ts4rfA0HLY7bAWeXa2sNAdvkjoQHMLAiP/5BB9B4vNqVGLSrislMJ0sYfcr"; + sout << "t/O0wfihW++wYxK5HJWn/g7U7+lu4OZHco87HB62wHcvy24LqwNKY1w033SQJzkLb6Nkqj4HU9cw"; + sout << "E82PHpVUAXqR6FRqOSGX9bMMqWL/e66vl0G3hbhIwQs5toUgg3XVv13F1o2sTzjiQ+22wf3PEuD+"; + sout << "aYIYKQiE4R8kl3/96og7QmVfWoPxZiVEGBzM9bTDHx88W1D/CkwPkH389oV1Xky/L9UHd8adnI6l"; + sout << "kfuHNP/geDSReFNJ3vtqUo6RnWHAFaTj+x5GDJDtol2hC0KaDLJCETOQRXUOOW371wB3I6slEKjG"; + sout << "gANfqlMyzsO/7bxUw8X71uDns683wpW+Kwr73bJwe0OeP4dgwgTdTLQApMZncaphcpuZCEn9iSpw"; + sout << "GX3EO37EZR2RvLBMvlHhkZSAWr8TqOALJeFEuYiduuUmyM2xh7mLw73AqGA0Wm6rnc7oZwXhwaBa"; + sout << "VhwUZyv5rGMNQl/p5bMO49+tQiaIP5UXxheMCUN8eD6m7qDyHsljVQ+Madsj5iaJw5rkNcO4oquG"; + sout << "mkrQOI62kZGM7reBAZjYYjf33ttAZspzXJVFkNt8geb8e9/AFBdfB2nkJebf8aAAie13umwDR09c"; + sout << "44fUO0jtn43FxMBMXRt8SN3zVf4ycLW0WBYliqRrfhbOlN4wfzHNnII0JBO4wuDMC0zXnurwcLA3"; + sout << "SWng9i7SI/XH9YZLdEn1cEtkTNulFFsZJGbLci2EvHEnG6D/m041JrCBRP7oCGonxdi2zaKe02B2"; + sout << "LHaMeX7Of28XebQgQ3ZQrpSxW+h5FqIVNNo9EYxnqDDKyB7TeQnuISmSOwrcARqs7EgNAQ3Wbepu"; + sout << "Fciu2AFnzjsW37Di4Z+L8tVWDN4nZstII68QCzlx78pYFiWD7lBQvuMYXGPujFgNGXIW+uk8NM/H"; + sout << "ZWKV4UopvDx32VBudzMGJn3XPfrFqeMDukDYZXe+8G/jo8poz/kPfQmBqMSoHZyoKZ6yPNEeo1pD"; + sout << "JyR5/neJsSu8rCPX61ZMuInitxMKRTS8OD7RhcO5o0kNsFwXxe2AhJm2HbOCqSnGktciP5M3gvpL"; + sout << "d25K7cTMNvZIpmwH8eQmOoicl8IRKpj6qzxZt4m8yZU5yFiYzkMfEPjUwCmJs7tCZelrNnaLoFpD"; + sout << "LxraAAsxLtsTcIL9ajBequXTFW3SFL7jsG8WiYRSYmu60gWPgUcmpaK+dk6HFy8SFCsaxlTmCn4G"; + sout << "cbIFfJU1O45770/tHsDxNOBu/vsdp5yZw97L/xl6jSM9gTMdFbTdYw4Va3nr3L99zk9WnSYVZtwl"; + sout << "C5LtQR0xUAldgCgSUfM06Y85VismlOis/GxVfYqYSwN17u3Rv5xj8JDpTxV8N2CwRXxrq013NXoj"; + sout << "eDztUxX0NWUEDZuVvZUIBv9IO+9UaBHLZ834PTzKolkznv0dfedf47H18EWGxhXcJs/bqIP7qlm2"; + sout << "B3bpaGxX3BbRVzuGrWRccOO2K6K1wxDPBSh03z/qH2ddZsTPArbhqlvRiD7kmMDGJzX/Y2i0SuCw"; + sout << "nwhTENF7U/QfBFD8hWqi6/5rHArt/YSfPhaJyBOoRhTWgGc5RPlvMKviM7fE41Yhyd0vD4BqylQh"; + sout << "iOKpy350GOo7g72UBTjlUHq0uWTX9rMpghHkZrMt2bvNzUn9SCx31vC/H7dVjcuEjCeUH4xkVpWu"; + sout << "g820oyLQSkreplnhgE8IyG9OjJ07t1A2YcKHq0sr3t+TMLp23O6D2b/ZAYcnBFVIu1T2LIuX7NAp"; + sout << "tOmFBbO315b3PV8bQdywYBSlIcILFc9AGlttfB2ksqGF9Juhx7UOPczWeBgaTOhBFFcBoSELVoqJ"; + sout << "YP1KMe2EreYgzjVdUEAlbUh5R7mOs/pPF6XjoP++qnu9YE82r567jLBo3fmMGQJqURjgAIxr12l1"; + sout << "+O9rTfxhjumcmJo6iOeK6/wCPuiD2QosXIlYHi/nFlS3qJ8d5z+tU3R9mi6v6I70GRSB4TkFuHmy"; + sout << "5AX3LA05BcrjPH5LO/ulytgvkKsNIgOA4lRgZ9Z+whgI/MxvgGEzmTEyM2nxUFfzxlM4Abi210+2"; + sout << "WFu1QMDXI6MQ54KGwi17BQOckYEFa+0PSrxwf0ntQcgcGnpPNCSlSATUQZKCN6E67UTP1KRbak1K"; + sout << "nMvD3D9t061X0fAWW1txiUca/iO3KLYtMqAKgW6DZXuUzJdgtk2aqUYTKoI7sVkyVHo03aBBr3fT"; + sout << "optSJzBz8UIRxzRauuWo6o1mCyja2eM8fn7TMvSsZx4QkIF9XZA6UNdeUpbSzwa4XkdTqdzmoa7J"; + sout << "V3z+mkCQ/AaaEctE6Oi63jOlzfPmAOobY9Vp9i+ojaeo77SXu/7cwoWgeASY6zOHdN66e/zedP/L"; + sout << "SNYH+B7QfAvUgBLK7QuP26mtws0qeeGWPYCB3x5+pbeS/PHqIF578D3lasqUUXwa3Nz+OzyBGXBl"; + sout << "AB587DXXvbqK5gsoPE+mh4CJ0qYcUGWOhDejTpuEG50OQiGORtaswXqr86d+3vWdsmoJq8YFEWTR"; + sout << "Pd+TwEF9FO3qrN+R73YWEZ5nihD5eG2Q9ABWg+C9Us9KqtlPYQOeUjwrEBhAJZ1j3UT+OKHXuR6e"; + sout << "6dFh0fmlpXewt/gZ2z2WMZT3okC5JyKIMROWN7m+W+mBMmYg19gvib1KdBeBLO4sxdB/mElqdYYt"; + sout << "k+mwhcuKUWYmaSmnzXdU8/Lxyvfr+hpPIPJHYLXftm1DpNHQjLY2bX+UROQ67weTACWOtfz0VJGv"; + sout << "auzt3Kd5UG7m9nLXgk8ZibGOie2MWMY8C2reyrfU5sqOnLO9VNwazWvKb34cFEPKrGU1QqVELim9"; + sout << "kq/0gFdbsyiotlds/Va5AkeVAVC81yYhYRnQm2W23NY62dMMgorLZoxuJV3RPVUoAa5RKRPq346V"; + sout << "HnC7z8pqOsVfoslTCivd2CEzNQK/cYT/LJ7n2Kw3bhT5niBsVNyAjZaK4oI5fqzMoqC75ipQz1IP"; + sout << "JwLqRuBTBjipG1Boum2TahnMriKcAhV1vAiHcB1GkYf9Bwnne8KiWpVdN54drWw7otSWxpsD7oyi"; + sout << "nPfN7o9CrAcERG94IxkZJD3uY748ly7UrScLvIS0M/zcC3f0FuBHb3k1LKOa53B4Q/wo9FfffpYA"; + sout << "RoixLpFQRoFkC4l2wfMenZEgaYkYyfWoWEfGD9ucOiu8h5IrffBDzs8Ck8/y32oHvL/FlouHrmpz"; + sout << "5SKBIez1dkvdFTI/378ArjAwoJ0i2QsGyS/uO7FqgTAgd/9iIqeGonMnJFRPmH/VdT5mZwwj1BhY"; + sout << "2/L0tu2kF3uwGX4c7fSDN8AkPK+DexcvcRBwG6cdmFs1871Ltv0A+i2UTgdk6/Q7MGnhRxnDUp56"; + sout << "7cTxaOSfGTgKcJJ/US0PZjyWLN6tinlrYxkDq7uODDifVo8kEIOzYn/Zo7YbrC2FEXSqw7khbirb"; + sout << "5ErpT55vaggbD/HtahzrQeYIc6lF5oVtW5TuAAdBBwnPwtYrkAVYQ2KZeGh68JKCEbKfNoAeKr3Y"; + sout << "hhU51cHOBC8p6hm5T5p4JZppu15YmKxbLZIscvWCfriYVwPycUde10Nu+AixLGcaJQCr5LdWm06+"; + sout << "ZJPovwV5yErl40B+WEa4YTBm9HJqYXOI3JRAEkwvioR9TotiOOqCaoo8YtIHRU4y3YOUcAdqvtIW"; + sout << "Sj9eLPFLgTQ33kOtsZKDDzBJ/eVl6eSv559u23AuaeuI0ynarOjvQPw/543W/lCKibBGhuaKbYJy"; + sout << "igAy49KqZ2IBHGJEOxd/7e4pnBgW2Yg7T7oeruGrIlJk3BFKQxEOO9Usc7XYSiXMVXwzGDA4huZt"; + sout << "2zq/k193xAGK86lvndpmzpJ/eyVqxqGSN7yyr9o9Mq8mfVJ6CYGXgb+CybtXX1r5FfNTT7ikJreQ"; + sout << "koaSxOGLvHfys0h09POLbTZ4jXBVrcuLuYoN2Dg6vpnMQSEoZzY85NOGxyCPUHXf9KFRfUe+r5zp"; + sout << "O4NnC66UUzrBan3eTYs5X7DEksg+bPx4KGMQ/byJm4Jr8Jzeug6S1qV1diRpkgR60KCS8Q2zCmqa"; + sout << "Y42riGdaBXcaZ3MM1onSKNWiIWa62QGu/GuwY64JUXkSTZViCgkPbieNk4J9WSjwKWPHwPq97qBy"; + sout << "UcJb7MNuA5cxGXOUpUgo1I6ewBWZJvTIU4o/b+hBbBntfgacsZDooFmU2xUYycL0/ornfsEuJczW"; + sout << "1dI+thnc10hcf/WWhUpOuMLQK3f34AUL7tUWtsuYFaRN+1dabnldREIXMFnZs6G+gmbaX0c4TDLh"; + sout << "i31s8ch2irdU8aSTT/1J7EkMvyDZt7NLHeTRgRRrVtGg8nYbBrB/MxwqGTjtD8X7+MYzloLLiUrc"; + sout << "7ST+dShb/DY4JS5RA8I0RBJbqOcMYzdv61UaxfYOnvZHtbCGlj5MmIMaPCDe4ZTSPsk1+NhUZSKr"; + sout << "NcbSth2EUMRZsJqDLN5Xh8LFoerrFLNfbb6r/Zo1I4T7NcVaSKim7u03ti6xCQj4Ds4yqDLrql0P"; + sout << "8IP3w1WmY07b+YRN9eqS7gFN4ReqCPVlk1qnZOzzPjH8pUPqgb+0cjuvQyJslbT98FrPljGtXbJ5"; + sout << "si6Cp+zT+aIEy9HTtQ8Wq3ojrVLtjMl/aRNPb94A0ncTZYuoZCnOC75i2lLru7HOwHCdA7Acyt4M"; + sout << "a7xgqqPiV13DaPpOQYp5rwKHNelo56FyhrmhBC2qQznvwqLJPRcnMEvfsg3Prs26ncYOOwXJ8DK0"; + sout << "GzXkJVznLs71ggUJY00mOvTS5+5sPWwxyjKhdckQfSil9orOJJWeDExnCVZYuIWBij5VXQiACLyN"; + sout << "WBKo52tef5NJwqNbjInxKmFChkX8N6en0uHyEAPxNLrYrn5I8+HibsvccOOIwb9yeT3TV/DB4SiF"; + sout << "rKq/nLdDdQjbReZfZUHhjKz7i/QB7eHinoBVgLcmAkN5lDJVec1Oq9GNwqKTpM9hzHl9KZueqC0h"; + sout << "EJILZgoJDs6Vy9NgjK+FYD6O/2Uv9CGfnCL47ziWfyWgLX2UPFX8DTrnWG71HfNjdlWO0oAmwJxh"; + sout << "0zoURT4B03KBdm54pk6TNeoThxauGpNkK+vU6qcILlI64ZFxaO42b5rjQXtnmYEGL9d5MroM39vC"; + sout << "hykPHjc2j2qFIRUQ2Eg4mtSR++AAriXGpDu3WuIygTiKMNy4Jbt8AnInqnHe5IeWgzE2eAomR8Yp"; + sout << "ZkI4LCBOIPRQbFZ+DnX9Ntxqx/6rJmXE9njSqgbepHd4e38SMP9iMBR6u5fX9joI6rcRikf9W5nu"; + sout << "G/TiDJNhEwvSOtgdEu3pv/urCx7CFwbkhSVz3r/HbQwup/g5FDuTxteiAXYn9TKCWXomd4oKtip5"; + sout << "aZw1aIIYHKR4zNPsNPw6uJUpwKU6aaJjAjXJtaB3Aq/fcvuHdCYL0AkL5AAItGnCq3TOli3VAHg0"; + sout << "zLm6OpVaCUravHHyYSuwXHB58w/zjz88c3tlOLRJO2CoFIUzvesT7FkKN8+y9itHmIdvgZCCWbm8"; + sout << "uoYwp4U7dy6bBVvcxjca4A9f5BcsnZ0u3h8JRxdhMpwk3H1kt6PwEFTrk2Gb4IemrBh4GugYC1NU"; + sout << "ZGgtiPD9Lk9pXqPb9LHcY4nWCEioiMUBd9IJlcWa5NOnZe+xeEHgVpy0vYrtqVHsoRAiXzRTxgaa"; + sout << "2RlSne0P1g0NgU4YCnRf12ZIS+LvPMbYe3mXfDpn099AvuQiUngIFGKm//F59pxxMjZg/ADElluv"; + sout << "0eOBWtYiaYVDL7D0hiffIcBcKQotLJ5p4q7t5v+7oowQ7cMyM4hvL5IjS3BqCpKNwWg3PA9jnhun"; + sout << "sAVHJa9lLMk3cdXix+BKNhfq3IfqpzXAMdlk7GeKbEl3F1/+q4sxVq7sd9X9XLNR2lhJVy93soM8"; + sout << "HfWHR7qaHnA/7XlDMDORDxuT7H3IKxDa+TkORiI9ldqwCEzC0ZGDBU3rBBxNK1jzXH7vzQ1vX7x7"; + sout << "eqYl74NRYH8j1jDxWeQx6bQWSy0V6N+c8uC9N9HPVuH/18KArH+sEDxGRW+Hxi1Wkls6+QyBiweA"; + sout << "tlHueW6h53s3Hc79AEK7gM+40nwvjCOG9vojSoI7bc2ndc+gdqRxYdUz8mu79bx8CI9Y/WlGO/Vk"; + sout << "opsB2ETODt3OdS5xMTUZfSO732Hj3nQUTmVVheFbxDpVPUXNx6AAp63iRR1oYv8b4w6rXrjpj2ru"; + sout << "8U+EXxv2W29LDJFC/JRBwj8JyTR9lDhARPqehW6EZmUYERihI4NZbM9mhgszbC+XDiM5rL80FVz+"; + sout << "4eIA3mJrPNgmevZC2Y8osV71pnmt54foH2yERar9tqSS+q0q7PWX3+5jIrM/IDT/TBrHQ+Boo0y3"; + sout << "bz8JHPnwCKXv34Kjg4NxbDjFT0G8hL1+LsNle+d+PyYPAKSLOjtzDicknXz3M6DgjwMrzD3zLf0O"; + sout << "jUlwIR13B4f1mwH/cOatF5LfXQRCmUxJt/16/VtPsB9FuC0sziuda+4/CGAIjJ5GB0gqzwN7DSOt"; + sout << "7cNdIO7w0CTNOOfVq9C1zPi4fndOnX9F3M+sPSzv7YjeUbkchXzJhip96etsRY855grSR//reW8P"; + sout << "XNodO/FvibrfkHKpG7fZOU+CRLqnikcNgIGBDYumpgdWA1ynvAOFPgAk5tqtLZKmrIe3ceptlG6X"; + sout << "E3UDGPMNVgExa9GFR1e5bIq0GZCNeIyGMq8cGLny0uRC/u9nPu0NoTMxFb9rr2iEp1DJB/zyiped"; + sout << "EoTpBPDptrXVQ9DOLvpZnplsmuuCMELZ7eP10CaDzNXodNNSFr6bk5oDTHJG3foIB/VHzdH7WDDG"; + sout << "DKMcvwwXAB+gAe0Uz3C+W3DlvI2xAHxpjx4hPvATyJwVgYLus9pHzKzZCJ/gLYMgcgyArdTPI7JT"; + sout << "ExgJcLp8QTosu4zuYCJ6GgPmtYlVtUdLWiPz6lHCW77bkkFmz+c6u9WF3yzzY1nyhecgnt/7jbyA"; + sout << "iSeDme7x/u9K3n3CTHc7+ZXHLT6HSdxKlwZ6811c5Z12X8aP6vvL/ysGTVwa8lRngJOOiGQuX6GT"; + sout << "PMohKZIhYGudZundQNXPOCavN+dpdT5m3AVMN9IrU/ht8o9eH+Hmhs0gw/iEwos7GgEJDl7IUcAg"; + sout << "TZb8blBMrDfdTyAKjYcuj6iA8aDYes9sB/W8R6AtNRG4S9vQcJuCD9JBKjxwd/5jG4hrpibV5bMD"; + sout << "F7dW4/5kQsc3AEdYIeGITaDjXorQ8UZ4zF9Zy81kTw6RAoTKdMYtDgnEffkDSRWGUqySn1CkTs/p"; + sout << "cdY3SRk1HMd98JwlVLpED+mWbcLwRuROBimfERC048b/vfUXzGSCgdYZ9AGQhHPzu8y/nr7Lp74h"; + sout << "I0ghlojt4zEkgBgK1JRMXD/dxwqIu+PIOo6s+8vwOMCLytvqCcTR8hqs63JgVGJ0c4/O8R1nbwFg"; + sout << "iQiCrQ2EjBbGaBMq8vhYNA1r9QzZwHejoHm2Z4a+JXoef2al/GqOWnuOUHBotsp70MwvCS8Q/EtU"; + sout << "TFGRttRKcrm7OhNmeiFLAr4P6C6O7gJxMBu78s0bIzpSIDTZFCZRYU7NWdSysmRq1ngerD9inrKg"; + sout << "zt/qQLnry4PmY8NjPwt7QYcQziZLN2ruCtbvk2DrPVSx24HDHGvB+0g8jLNLQDfS552/Xs2G4zc1"; + sout << "wwPpayeQy7HFGT7Gtf3oCI4jb35g8MTCXoOMgM1Dk+EYJIyDT2PBae1AcDcDVj2JKOI1UbvanFI2"; + sout << "sFpsLkIgj8SVUI5dX2PEEl7AjmoFJqyeVpqmBgNEfNwPEqfNNLUp6mic3WusrWO/jQwFtUIdEWAc"; + sout << "tjyIqyXeF1e1zO5po9/kJOrRJFAFxMFI8PVZNSQJ6KIQDhjPJHi/z5cCaf+lxqSo6OUt0D7z9o5A"; + sout << "hfA5GwHX6oy8tGqommBrb6nF1AWVInk19mynB6E17TQRo3T5W2OFG+BaMRO1xlg1znAurKht21fr"; + sout << "6VZscHHtctXifsg4dequ2SkchmjJ5yWpxqRqckES9B3nVTSiqn71/T2O/poDieJhqADWNw3gj96c"; + sout << "GI4vlUR4Zo1drdvtgwHFFR+RKndNuzv2BIYvB4nm5SBqx0nuKcvbbxkEEhIRVOb/GiMJ8ZwjeVxy"; + sout << "b07lDWMdBxLyxytbDoNZbefVc8N0YFrsFvrc16W8y4XKS2NG7Kfa4u977Rr6Yd6+u2+DIYwfsrzM"; + sout << "DWs67DC1+vHawIhIQ5NyB/R6hx/Tym42s7IK7l9w80SaBvpN2A6vaDlNlOD7ANeZsmaAG8L6zFsb"; + sout << "FLBPRNFxm+HxHYjpG3+dBySqcIyK0tSFfykLVffPMFCtgzTwDM3+5VJEghCyVlGEhDl5tv97jwVI"; + sout << "aTYM84hH/JqKcDUQwmilCXqbb8bh/kTEWKzsuuyaLxr3D6+5MhVMakjG3RkIKfjddRukZC341efa"; + sout << "fFJVWanSIf3lSrhFjsiH5k63EjbZmAmuMpfiLZ/Eh+zjBbh0dE1spYJ62S+PwkIXMTjCxmoJW46O"; + sout << "5Bdl71SqZW7lnQUubNcgRlxAZGKY3aQlIV2qHptzubV04SMqdcyVYPrPmvZSM5U40eDoJB9UYB/D"; + sout << "JaxlWFMYrW1ntJTv+LrtYepKdD3uC9SLKi5VYrWjx9WzRMxxkDcNt85Ak2cfpA3MQXI+uT7lQVoG"; + sout << "vDM472l4Ruls317G1X95HA3lcJNseSQnFsAeE9GhE6XXmxDwIDqUQ6Mk+F71rvDufNIj+GaUTISI"; + sout << "+lwC4tqse9dg+7wcqSTqe1cx0K7tZo3ZwR+Y3JEe+6HyURImluo6BRV0n6URS2MLpgN4geV0m+6z"; + sout << "ZIML0oUAKyJVdyj4XWEq5Q/pPDvz7jDfKyqcHTCwBCuBQEnhsxi7Pr3yEac/dwYDKRSkX+7Q9mdo"; + sout << "lAjgzkEnxlR2JUJxtrsEuhTiWql6HjmrkT6ohOeR/k24Y5cEiKx3QM7gBcB57IoxxMVaDJut+Q3x"; + sout << "o4HQ7J/IRtkKCPp0LLSK4Ue+nSgJY5LWtD2a4KzmVnahmXeYXCuAsKm3+7dGFISapet16Grci2+j"; + sout << "ETdZGgs0DEdZWncLxrCnhnvXHrfsSsVyqocOtTs1KiwUgGgcEhlSBZyH6bpF7IYhQhPzML/SIwy9"; + sout << "ByDnFZIvESXwQObXa1iYfR37SyV0Yw/75bcWAnx/+jbodjLZF+RMskqs/i+PnEasmlC5wrjYo55o"; + sout << "E0tqAakxpTNP8+TrRgNMLFP4tP25lfDLUka+pUXoeLC7gWQ0MTA8LDx5GyMlblaIFGHx/eMZv05r"; + sout << "he2nLJgKs5ZtTMh+8pMNkEBWXQDKyJn0oVLcRQL80bPfpUfsnTW+jrEkmH8azM7CN028QyPEhBPx"; + sout << "lvDsEjVbElUqBH+eo3o0CYLpnABUQ28KvsJQhHAsk2WVScBv/hm/T4c2zABqaDOl2LkRCK9ZeF2W"; + sout << "kqLNjVH/iZBVQcjo49++lINm4Nupi0GPnWpsuMnNoOdZzK7Oygqbxs9Zz5VqdjRCpxQGQgOhAD2x"; + sout << "aifYz6RSk7i0RB33QkjE3u+MMEQLL51zl2EOREXed6AJT0U9D0+hSLH39wqMbdEV9bIFcJ7XprJP"; + sout << "ueO/zZOmCHIBfz+47Kec5BAWkufEOfKxOxwV/Wvh9+VA2RcbA0exvVj3Fy/hPmRaw6E+SBZTJoi0"; + sout << "aQltfhFmfLQtT4SRl2FjDwlpNJHjmbmVOALldNBRhxSz4oWfurPqAB1A4MazDorJI74vqElH1SMQ"; + sout << "7RZI0hDhS100VqCdgnYhtMa/We1kNuLcLi7cSq28JPUE0tkIE8SVLldq2Ekn90XBP3mY8LJe0WMR"; + sout << "J/LzfvopGKYrqday/eYuiXl+lihXp0dpph9ZUpRzmAYWtns/djnzFSGS3bF7VOOaFuGxwYHi/1p1"; + sout << "SaiBJV2a4U1vr0mh2OU3wMuqd/t5nfl+Sv75zG3YOAvcoiUH76xxHvRnn+D97TegC72k4mUgBCPs"; + sout << "NZJUhjyKF5S3rerk5uhxFmlAmy+XMo9VJJaYs6OxiIeSQa3iAnbnWDzY3m1chupbDYK2fs8tBg3c"; + sout << "ta7LM5SgWZLMN5Lj1RtEbFBpAylXWrb3JXGuaK/C2biFNTSFU10Xoit4NcqZuYTXu6OhNahZqayp"; + sout << "IYXTMXUy1KPyHHMFWZ9gXXB/igOg80NZ2O+EYHHfRiqSOFJ7xhnY+H4X0kh9v6N7lxlv1WOudRnJ"; + sout << "lWg4FktfX+JVDflDvJ0IqKuM77O5fgtBpRL3tpS2dn4U227ZiXRybiv9zCnbpK7cfq7/t69AVgUG"; + sout << "qjyTGZ+YG+PDjVBWC0KGsJWIBIHwCoTgIZO5kgU/831jcG+LyuHWf1Q9bIEkOXy6UzexKYCJgIaf"; + sout << "EH8IsTAVe+uUvVSBpr/Eis7LJkyx6KOAbDR3eXjXjhXJTSTK2efctB2Bb3GraweN/LUZ7suO16Ng"; + sout << "odHLPnS7jsgag4B7sLNCE3NeGfd4VIE6bYg/9+Ci8k26WMJzOHIO3H+Kc7cTPCdohAQS93epRcLz"; + sout << "bz957op4v9NBlgMcIKW3Fp1opHX/+vApvBecnNDF2oJ0bUq4pFuM8Ag+2PdSDlJfcdnq6S0lNA8w"; + sout << "H/65WlKeVTHRt0nm/p3jJ8JNUlIZa8iZyGa/vpCaLqymASIcIfYEdfi4ugCOvbKgBOT//4cl8kji"; + sout << "3BpcImX1IRx1UzgcIcK4lqO8WoSpGqhClHHi+/ezdj06JLKP27zKeHZwhTMAyYifDnCOJG2XtuY6"; + sout << "gg4LTyI7BcZ5jAogPeibS+GdwlEfN+hOk1idevzrgEdzJTIR3i6p7ugPS4unLC2+waQfdJMs4Vjd"; + sout << "aCmDdalKnQlpHL0Smr7Ezu0P1jPQcyO2hGupDOHl/TslPpoPcA1N6OSTraByTWL4yo9aKG1QDIOX"; + sout << "NLGvos5UGUwmPxJMTeBGkxT7hyKmNzMmhfKB/yfqDaBaV6fosCm7S1FVsdnU4PvcFxa61Jrtok6n"; + sout << "/lwuY3BM5VCG9/kZQtVUtImeRB4hYy8fETZI/WH+WSGhZZaum+Z7r4Ngg8/DC21U33wyR/WElnVE"; + sout << "+Ny+yHoK9xkiLpLlx4ns79XnKV+S511udI6XniRyg36kCNdRWTfN81HfbS6h/KXzl2ByPVo0Q9Op"; + sout << "lTF2Ox2WJ0x0RMrE1SyS1KMyL1ff+ToYLKarrgalvFUvSMI7cxPlc6phczmtYojlm9R7JdJDpuGB"; + sout << "k9qA/UEnH0v+wl1+5rt+RZ4dHtD4FOXvnqW3GdN7XvISIFJLhq+UMhBPiI4M39ZRxGpuROiF3Dzd"; + sout << "gD19YX9c930yj3JEhkbYzjzwa9CWLcnXbSYe9sayBJmCjpUxYmfrWo+mH66POo9ZxJeeEOPiVS77"; + sout << "HXv52am6J9m3gm2Buruibflo5CGo6Ngw9slSm2rcP+pGEAyKjQoLrU6fb0TxqUnRB8q4PVvbBRFj"; + sout << "c1y5MU03KShFhFfcM7iND5az9RcBJwyNZtCZr7OxlUDVDJP5IJubY/5dahCfiX3EfMXT/dR6R8Yq"; + sout << "jnUNqjVGnIVm456KF1xlEZBDwCL5LNUhGCV+kp90STd77lypcQPP1q8iFDo1HhHnJDj8is+5eMY6"; + sout << "5bfMP1EP13tCqPnHgD4hK7Ega8bkvoEKZInxohWoaS6F6o3lslr267EKKvjHDmnA9uGWzxss7nuj"; + sout << "3PlaWuYpU9NfJp6Eq0xYNYcf4qLTN3Nhd/GRpptxzdMCuu8x2CYKBkHKMy5iJELoJ8ntEdLtVdd+"; + sout << "46L46j8oPIUyswAbmodfp82VIodnJ8xvJUyJWRCXpCa2esfid02xvel9h8w1V8Uc7BKu9HMsFCDn"; + sout << "+FbebdVIahMe67eUI3i9+mSZ3BNq0qbkn/dDabGwixw9lsCU5uyHZrzafG6lIA/MiepNzIRLO2Wp"; + sout << "iwhMc4f4l8ZRb89HhL4bjl0Wo44VEBP+Sqf+jWMPS8qo5st+MZbtu89RObDH4fOr+1v5dvGkl5/d"; + sout << "YROeKTCWyAWWSSoinCHoR/IcBi8J31l34Pk3nNrZ7BAyMn28ZY7mue6AF6hpieJ+o/0RRv5m6VM1"; + sout << "1OpAFyajq7E9/BjrAfRSpCtoicoIx+APfjrLbCkEehMCBZoOM9USTMi92Id+Fm2iI+/AMStEU6tJ"; + sout << "zumgP60vNm3VwZw6ep1m1DSb+2lduAxtuDtaspztXYx3QhksLmbzeKD91/v2Mu/JveEfy35tf1kY"; + sout << "fcFtGY/puCXnAIxEeKY26aiZroOWt6HLIafNvD4wg//A/9iTszmQauiTWQbBLxcK/KreAE2TOpJ4"; + sout << "klwh8NfDhULKVKoJ/AgNS4I/Y/LFZSX1VerAImW/WQmpkqvjnOUrpLNkuV2NjA2tKP1G45dInmG/"; + sout << "nAnaPoGs/bAdQLIJayWx42xsKM1pil6q86WvoKcEH1WrRRXsRJgw0dU+Z1QIFlC+xKqP7xS0LZEs"; + sout << "fjyeED07SGKBhUj1kbSp9E3PMr32yDHj23zmnFyBcwihktaRV8fEquoE0CacrSoxfdxzdWIDy9Ya"; + sout << "6tVBbCcxWoz4CSGzBKSrEXVjedHR6H3etlrGwzPWYwbiFQQu3oSNml6ORLh2fu87qIjz8VrO2804"; + sout << "dAqaUcJBtHOOB5KnNMSyZWA1+62WX3P/yJN8Be3AcHY7xf9bP1HUxyq5cWcAbtx9yJV1dhXyvPz6"; + sout << "D4k1sntDiWgivLkY7l3OBVRs/QKGOOTj1uLei/v6X1vR54lH5wsRwq6niPYsyEcCnWk2zzTKOcK/"; + sout << "OBhLQHtj5Wxf+NQ/S1zZnIBS0cww5BqZD0aATVsHfifZLEw2eRM+vap33o0ozqXIdbCzzsotLwlx"; + sout << "CfFT2a/MUDatTSTuGjOPpjbmVtK7qWkWmHAF0j/GQFs73K/8PGmAWomFb2WWB8frlHkkJoEPyHVR"; + sout << "N6hlsAEd8ylb6SwVawZglxdpqfhQpn71NaRU/ZJnk7fnSVmKwcGVeh1sppPYKEu/Q7i4YrAGtLDp"; + sout << "Osipwqh5gL/skuxQwsBI/VNqoxIELFbJKhGrVqr2kHWB3yE9cB5puaP/MPPvIwbPSNRZVVbvGV+m"; + sout << "Mmi4LDSKPlxDmjl3wWnzdldnNUb281O5dA7lMuPHgibIFjl8EeVHBFLocpnFbSLRiCZx+xh8sNRX"; + sout << "ti+5Zl/pGvlloP2k7yVNRFpB3Av30jzQePY/DOy5DoF2k4eY0qvSR7HtV2IG+9wpBAUCbIaD/2t3"; + sout << "FBcQT2cWEx0TPZbBqPLAxruNGcv1z23BdWUeQpeEV6ij78xVILQegtNiJlCgMTCmIRr1meCfuVJ5"; + sout << "vbvNLxRiL4JSZbjKgfonK98os72YhGB0Y2oL7ZpcthGFKCGAjo76XrMB/hoeXlx6UKYlLknZImWO"; + sout << "NQj5SAhZElv+LKdSZ1oc9U+0Qm8K4G9+8OnQ4wS61R0MpZKXnfFAMp99XHqN0tD6W5Br22/IOyj9"; + sout << "x2QH50KAfIYZFrJaLt0MMylpVhHPQ1RBembk0wDcwDt2phnU5nm3m22r81uUHp5Sc0Or/vrzpMEc"; + sout << "9vqhEFhFG74o+EGTigUiyzCGUHJu9ITqb55xDmPm829KOqkOEV3L4XX/jLwRFVkbAHaijjmg17xN"; + sout << "yJZkFlYG/RntbFd+p+10JuhI9ahBvvCjFnkC+8/F64Umx4jlWJRPIADPr3yH6tt+nVkC2Cbx+U4y"; + sout << "bgINetMWF1b3syNyJh4qJy/KXX5UckzU9A1Xkz2MvXWiB8Jp0f+Q/bZ0inTmecpyBGWCsD0xRYPv"; + sout << "Ktm2zqD2HL3Qtfpyi0khXGmUDKwvuFnjX5V57ustxKCq2h/ylqTDCMpqS/QQ122Tr2DuHHVjGwGG"; + sout << "PajaD64n0qEj+zaqyUgpDZatEVMgFxpLpbJ+sTyBaYKYPYqGxoo+XycDYIWoc7K+7d2NT76qQcI/"; + sout << "LHo0kGXRTNiciphMn4EX3PkH7zUqR/yCWBRsP4wtPax6aM2m6IpY3bDbOfwMzXDTZDLAfi3is8Lb"; + sout << "H5AfK2DDG45w4olFJuLKO27zafuz9GIAgoTsUwpJIJeR4xdtlP9EtMjG6SZdp22+MBSh6w4CXe82"; + sout << "iIquldB5HzaHvBr3UP5LGQR3W/Wky4udFj2+XCpKsDUll+4MZCeGz1VM5ZYjDKfRDzSaLjJspoVH"; + sout << "TT05yuagEN3+w/ROxo419ZEnvR+X4QLyp0JofhFpSsr+J9fQwaz3IGpPMs7wA+Q9CfEw3FkrFfeH"; + sout << "022qfnq3naUsXyh8xAeI4r+MP6WGlscgFkWPZhM/vDq2IcIbCwZ10Ivx8r9bwi6bAfPE5cru0oRf"; + sout << "YYs7kU0l9z3yzeSvjVjkdZbmgRNPpED67KlUQyRibkWe34KGglSuSClFu6PWrfFxb/9ch0KTFVQS"; + sout << "fo3NDryMrF6oeLHe7osiyJ1ismBimKN3idmZ7yUknYbpR7/Z2K0LbAIttidcQ8LpgUzy/6jROtHE"; + sout << "MRLxKl4ba6tx6BP8WpXWVzx6xYox42qMX24SkhhNVR/9ETntqyFEeDXXlS88Nm14AZARVAp8wJ3n"; + sout << "WaA9OydaCeVzFVZ614BRC41g8VoHz9iZxHF/7fL5gz+YTJH7ccf18hxv4U/y0ROiZAQflTW6+lYI"; + sout << "kIO8t9A9eQ7Djh7cOHCyA7E391UriFUZH8kBzKjnMeF27uYx5U9acC5YUVtDk7O7Xq3wqeGaxS1b"; + sout << "oTVKitvXCbdmvIiCK8N8RZeVlteuYhJ20u7LTHsO87c1Tv83pu9OiYDMwEpAFLwfHwi5YfZecAIA"; + sout << "bkkveDusgnh80pxc/24zp7W42Hm6hfq852ZLLWD2bHm3YZdey2WIbGsjvQS+io4xoEKGLHUOi7Vs"; + sout << "/Te80Yujo61Y3BjY8w58xNplBAXvK4Fn2EenYcDz839o62+xeUuGN4wSUpJF94aU9//hEF93J9Zf"; + sout << "CxFPlooKe1W7nU3sl2ziEh0PeM4RImmQ9zVNkDEx8ahdmvWWSWlmlv1K754XAdwp68cdfKUIY8kk"; + sout << "M666TRZrSCDUvDufPL4y1LTpQUmIhtNLkrYZg0WEXx5QRF00x90VrsIFcXA/3Sjc2MPi0iMck6vM"; + sout << "kh4kYbIqlvgQSwTqc6XmXEM/W3HRuVCH8kQKIXjN5my0C3RbbYIr4vAPaQL8PCWaIjbLSHDcutk1"; + sout << "LBjyPnJQzn2/vvm+3oKN2JXPxeKvTbaoYu7EvWtXxkFJlrKY+o+FQfxmtN8qxniGWVLuLX6k8keB"; + sout << "wCLZV6wF1XG5FsqBoVh8qvAV/Tykk0VO5FJnoJ0sVwOveUgbCFKsReM24lZlgzHsHFaTN4qSWHyi"; + sout << "BGg3nr0dETiOyEfLP/ckF4GhGMMQZmT9jMFyUEgDfszoDgF+jxA1buvCKTRm4iVgFT16pi4wTG8I"; + sout << "fCOM+SU2UltyA9EHdOW3rm8P8w4z0D0/IzH9Vd/O8Y2kB4TUOW4yHdB95MBCJbvEGYXR/A1eTsMW"; + sout << "J96ZVGCmlcOzWqPsX61LShr57TLbE7+hJzxUMEW+ri+goqnSehcpGurw0+d4K3+5jFHxUpNaT5n+"; + sout << "t0p0S42lOvrgigSZrQiJozBCCpm4sCCM91bzGge0VpyuIGTVkzbzXfjzF5J/wliLSVyzqSNq+HqL"; + sout << "06BqMB0leDtLT/msAUV6SRKr7/n6qMgkrIfrAri4d13kZTpexQTRTFXxcoIgwUelzsxhdwZQ4yMX"; + sout << "dAeL/oKOQdjCgnsVFD5BD7clhTEOA8gVVxASHinziac7ZkdYaHwESDZBxJhv4j1nkKNLWZLifph7"; + sout << "NdJBK51yqlBvkZKROBbTiQi7UCe/YGXusDPCBsu55iieo/o65HkF7olpANKxvHxBrQuCTxYnL1tl"; + sout << "pMjxmZc1BA2T7vyd4AzXRJ6tMUdqUqnjPVFY7OIKMgUn7qRieMqt95vzJqh8+jdApnY+xwnIKosv"; + sout << "ox55mijPLs9oUBzAJPpD3nDLR9pnTVIkY2RmVRQFUN/kuJHYbNtc0PRIAv6iDiZhe+jeCkTx/dXC"; + sout << "sSVwD5hp/v0TvaPa0XSPr1BbqlvK6KjtdsVJsUOjHFskNm/8qlIGKp9F5QCtLOBhp1eoy2AZlNlN"; + sout << "+eYQRzwMSsJNxq44rixF97d7qeiOkC/Uu3wNk7aL11AR5iS7gau10LHLs3YhMbUcb+4kf2j9NpWG"; + sout << "wqMklOYYJag/XNyoQs8g44qAha1rVyeq4eXodi0JegvjkXWEB4Mq8jBuHXbYjYiRiHoL68/9mry5"; + sout << "nlN2Duwp7g5yl982CZLZc0k7uSjKaDkWyynH60MwLnmVj2sA"; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + +} + +#endif // DLIB_FRONTAL_FACE_DETECTOr_Hh_ + diff --git a/ml/dlib/dlib/image_processing/frontal_face_detector_abstract.h b/ml/dlib/dlib/image_processing/frontal_face_detector_abstract.h new file mode 100644 index 000000000..20815cd0e --- /dev/null +++ b/ml/dlib/dlib/image_processing/frontal_face_detector_abstract.h @@ -0,0 +1,25 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FRONTAL_FACE_DETECTOr_ABSTRACT_Hh_ +#ifdef DLIB_FRONTAL_FACE_DETECTOr_ABSTRACT_Hh_ + +#include "object_detector_abstract.h" +#include "scan_fhog_pyramid_abstract.h" +#include "../image_transforms/image_pyramid_abstract.h" + +namespace dlib +{ + typedef object_detector > > frontal_face_detector; + + frontal_face_detector get_frontal_face_detector( + ); + /*! + ensures + - returns an object_detector that is configured to find human faces that are + looking more or less towards the camera. + !*/ + +} + +#endif // DLIB_FRONTAL_FACE_DETECTOr_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_processing/full_object_detection.h b/ml/dlib/dlib/image_processing/full_object_detection.h new file mode 100644 index 000000000..1dfc99b2d --- /dev/null +++ b/ml/dlib/dlib/image_processing/full_object_detection.h @@ -0,0 +1,191 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FULL_OBJECT_DeTECTION_Hh_ +#define DLIB_FULL_OBJECT_DeTECTION_Hh_ + +#include "../geometry.h" +#include "full_object_detection_abstract.h" +#include +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + const static point OBJECT_PART_NOT_PRESENT(0x7FFFFFFF, + 0x7FFFFFFF); + +// ---------------------------------------------------------------------------------------- + + class full_object_detection + { + public: + full_object_detection( + const rectangle& rect_, + const std::vector& parts_ + ) : rect(rect_), parts(parts_) {} + + full_object_detection(){} + + explicit full_object_detection( + const rectangle& rect_ + ) : rect(rect_) {} + + const rectangle& get_rect() const { return rect; } + rectangle& get_rect() { return rect; } + unsigned long num_parts() const { return parts.size(); } + + const point& part( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < num_parts(), + "\t point full_object_detection::part()" + << "\n\t Invalid inputs were given to this function " + << "\n\t idx: " << idx + << "\n\t num_parts(): " << num_parts() + << "\n\t this: " << this + ); + return parts[idx]; + } + + point& part( + unsigned long idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < num_parts(), + "\t point full_object_detection::part()" + << "\n\t Invalid inputs were given to this function " + << "\n\t idx: " << idx + << "\n\t num_parts(): " << num_parts() + << "\n\t this: " << this + ); + return parts[idx]; + } + + friend void serialize ( + const full_object_detection& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + serialize(item.rect, out); + serialize(item.parts, out); + } + + friend void deserialize ( + full_object_detection& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version encountered while deserializing dlib::full_object_detection."); + + deserialize(item.rect, in); + deserialize(item.parts, in); + } + + bool operator==( + const full_object_detection& rhs + ) const + { + if (rect != rhs.rect) + return false; + if (parts.size() != rhs.parts.size()) + return false; + for (size_t i = 0; i < parts.size(); ++i) + { + if (parts[i] != rhs.parts[i]) + return false; + } + return true; + } + + private: + rectangle rect; + std::vector parts; + }; + +// ---------------------------------------------------------------------------------------- + + inline bool all_parts_in_rect ( + const full_object_detection& obj + ) + { + for (unsigned long i = 0; i < obj.num_parts(); ++i) + { + if (obj.get_rect().contains(obj.part(i)) == false && + obj.part(i) != OBJECT_PART_NOT_PRESENT) + return false; + } + return true; + } + +// ---------------------------------------------------------------------------------------- + + struct mmod_rect + { + mmod_rect() = default; + mmod_rect(const rectangle& r) : rect(r) {} + mmod_rect(const rectangle& r, double score) : rect(r),detection_confidence(score) {} + mmod_rect(const rectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score), label(label) {} + + rectangle rect; + double detection_confidence = 0; + bool ignore = false; + std::string label; + + operator rectangle() const { return rect; } + bool operator == (const mmod_rect& rhs) const + { + return rect == rhs.rect + && detection_confidence == rhs.detection_confidence + && ignore == rhs.ignore + && label == rhs.label; + } + }; + + inline mmod_rect ignored_mmod_rect(const rectangle& r) + { + mmod_rect temp(r); + temp.ignore = true; + return temp; + } + + inline void serialize(const mmod_rect& item, std::ostream& out) + { + int version = 2; + serialize(version, out); + serialize(item.rect, out); + serialize(item.detection_confidence, out); + serialize(item.ignore, out); + serialize(item.label, out); + } + + inline void deserialize(mmod_rect& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1 && version != 2) + throw serialization_error("Unexpected version found while deserializing dlib::mmod_rect"); + deserialize(item.rect, in); + deserialize(item.detection_confidence, in); + deserialize(item.ignore, in); + if (version == 2) + deserialize(item.label, in); + else + item.label = ""; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FULL_OBJECT_DeTECTION_H_ + diff --git a/ml/dlib/dlib/image_processing/full_object_detection_abstract.h b/ml/dlib/dlib/image_processing/full_object_detection_abstract.h new file mode 100644 index 000000000..099ee01b0 --- /dev/null +++ b/ml/dlib/dlib/image_processing/full_object_detection_abstract.h @@ -0,0 +1,203 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_Hh_ +#ifdef DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_Hh_ + +#include +#include "../geometry.h" +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + const static point OBJECT_PART_NOT_PRESENT(0x7FFFFFFF, + 0x7FFFFFFF); + +// ---------------------------------------------------------------------------------------- + + class full_object_detection + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents the location of an object in an image along with the + positions of each of its constituent parts. + !*/ + + public: + + full_object_detection( + const rectangle& rect, + const std::vector& parts + ); + /*! + ensures + - #get_rect() == rect + - #num_parts() == parts.size() + - for all valid i: + - part(i) == parts[i] + !*/ + + full_object_detection( + ); + /*! + ensures + - #get_rect().is_empty() == true + - #num_parts() == 0 + !*/ + + explicit full_object_detection( + const rectangle& rect + ); + /*! + ensures + - #get_rect() == rect + - #num_parts() == 0 + !*/ + + const rectangle& get_rect( + ) const; + /*! + ensures + - returns the rectangle that indicates where this object is. In general, + this should be the bounding box for the object. + !*/ + + rectangle& get_rect( + ); + /*! + ensures + - returns the rectangle that indicates where this object is. In general, + this should be the bounding box for the object. + !*/ + + unsigned long num_parts( + ) const; + /*! + ensures + - returns the number of parts in this object. + !*/ + + const point& part( + unsigned long idx + ) const; + /*! + requires + - idx < num_parts() + ensures + - returns the location of the center of the idx-th part of this object. + Note that it is valid for a part to be "not present". This is indicated + when the return value of part() is equal to OBJECT_PART_NOT_PRESENT. + This is useful for modeling object parts that are not always observed. + !*/ + + point& part( + unsigned long idx + ); + /*! + requires + - idx < num_parts() + ensures + - returns the location of the center of the idx-th part of this object. + Note that it is valid for a part to be "not present". This is indicated + when the return value of part() is equal to OBJECT_PART_NOT_PRESENT. + This is useful for modeling object parts that are not always observed. + !*/ + + bool operator==( + const full_object_detection& rhs + ) const; + /*! + ensures + - returns true if and only if *this and rhs have identical state. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const full_object_detection& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + full_object_detection& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + bool all_parts_in_rect ( + const full_object_detection& obj + ); + /*! + ensures + - returns true if all the parts in obj are contained within obj.get_rect(). + That is, returns true if and only if, for all valid i, the following is + always true: + obj.get_rect().contains(obj.part(i)) == true || obj.part(i) == OBJECT_PART_NOT_PRESENT + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct mmod_rect + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple struct that is used to give training data and receive detections + from the Max-Margin Object Detection loss layer loss_mmod_ object. + !*/ + + mmod_rect() = default; + mmod_rect(const rectangle& r) : rect(r) {} + mmod_rect(const rectangle& r, double score) : rect(r),detection_confidence(score) {} + mmod_rect(const rectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score),label(label) {} + + rectangle rect; + double detection_confidence = 0; + bool ignore = false; + std::string label; + + operator rectangle() const { return rect; } + + bool operator == (const mmod_rect& rhs) const; + /*! + ensures + - returns true if and only if all the elements of this object compare equal + to the corresponding elements of rhs. + !*/ + }; + + mmod_rect ignored_mmod_rect( + const rectangle& r + ); + /*! + ensures + - returns a mmod_rect R such that: + - R.rect == r + - R.ignore == true + - R.detection_confidence == 0 + - R.label == "" + !*/ + + void serialize(const mmod_rect& item, std::ostream& out); + void deserialize(mmod_rect& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/generic_image.h b/ml/dlib/dlib/image_processing/generic_image.h new file mode 100644 index 000000000..362277368 --- /dev/null +++ b/ml/dlib/dlib/image_processing/generic_image.h @@ -0,0 +1,431 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GeNERIC_IMAGE_Hh_ +#define DLIB_GeNERIC_IMAGE_Hh_ + +#include "../assert.h" + +namespace dlib +{ + + /*! + In dlib, an "image" is any object that implements the generic image interface. In + particular, this simply means that an image type (let's refer to it as image_type + from here on) has the following seven global functions defined for it: + - long num_rows (const image_type& img) + - long num_columns (const image_type& img) + - void set_image_size( image_type& img, long rows, long cols) + - void* image_data ( image_type& img) + - const void* image_data (const image_type& img) + - long width_step (const image_type& img) + - void swap ( image_type& a, image_type& b) + And also provides a specialization of the image_traits template that looks like: + namespace dlib + { + template <> + struct image_traits + { + typedef the_type_of_pixel_used_in_image_type pixel_type; + }; + } + + Additionally, an image object must be default constructable. This means that + expressions of the form: + image_type img; + Must be legal. + + Finally, the type of pixel in image_type must have a pixel_traits specialization. + That is, pixel_traits::pixel_type> must be one of + the specializations of pixel_traits. + + + To be very precise, the seven functions defined above are defined thusly: + + long num_rows( + const image_type& img + ); + /!* + ensures + - returns the number of rows in the given image + *!/ + + long num_columns( + const image_type& img + ); + /!* + ensures + - returns the number of columns in the given image + *!/ + + void set_image_size( + image_type& img, + long rows, + long cols + ); + /!* + requires + - rows >= 0 && cols >= 0 + ensures + - num_rows(#img) == rows + - num_columns(#img) == cols + *!/ + + void* image_data( + image_type& img + ); + /!* + ensures + - returns a non-const pointer to the pixel at row and column position 0,0 + in the given image. Or if the image has zero rows or columns in it + then this function returns NULL. + - The image lays pixels down in row major order. However, there might + be padding at the end of each row. The amount of padding is given by + width_step(img). + *!/ + + const void* image_data( + const image_type& img + ); + /!* + ensures + - returns a const pointer to the pixel at row and column position 0,0 in + the given image. Or if the image has zero rows or columns in it then + this function returns NULL. + - The image lays pixels down in row major order. However, there might + be padding at the end of each row. The amount of padding is given by + width_step(img). + *!/ + + long width_step( + const image_type& img + ); + /!* + ensures + - returns the size of one row of the image, in bytes. More precisely, + return a number N such that: (char*)image_data(img) + N*R == a + pointer to the first pixel in the R-th row of the image. This means + that the image must lay its pixels down in row major order. + *!/ + + void swap( + image_type& a, + image_type& b + ); + /!* + ensures + - swaps the state of a and b + *!/ + !*/ + +// ---------------------------------------------------------------------------------------- + + template + struct image_traits; + /*! + WHAT THIS OBJECT REPRESENTS + This is a traits class for generic image objects. You can use it to find out + the pixel type contained within an image via an expression of the form: + image_traits::pixel_type + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// UTILITIES TO MAKE ACCESSING IMAGE PIXELS SIMPLER +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + class image_view + { + /*! + REQUIREMENTS ON image_type + image_type must be an image object as defined at the top of this file. + + WHAT THIS OBJECT REPRESENTS + This object takes an image object and wraps it with an interface that makes + it look like a dlib::array2d. That is, it makes it look similar to a + regular 2-dimensional C style array, making code which operates on the + pixels simple to read. + + Note that an image_view instance is valid until the image given to its + constructor is modified through an interface other than the image_view + instance. This is because, for example, someone might cause the underlying + image object to reallocate its memory, thus invalidating the pointer to its + pixel data stored in the image_view. + + As an side, the reason why this object stores a pointer to the image + object's data and uses that pointer instead of calling image_data() each + time a pixel is accessed is to allow for image objects to implement + complex, and possibly slow, image_data() functions. For example, an image + object might perform some kind of synchronization between a GPU and the + host memory during a call to image_data(). Therefore, we call image_data() + only in image_view's constructor to avoid the performance penalty of + calling it for each pixel access. + !*/ + + public: + typedef typename image_traits::pixel_type pixel_type; + + image_view( + image_type& img + ) : + _data((char*)image_data(img)), + _width_step(width_step(img)), + _nr(num_rows(img)), + _nc(num_columns(img)), + _img(&img) + {} + + long nr() const { return _nr; } + /*! + ensures + - returns the number of rows in this image. + !*/ + + long nc() const { return _nc; } + /*! + ensures + - returns the number of columns in this image. + !*/ + + unsigned long size() const { return static_cast(nr()*nc()); } + /*! + ensures + - returns the number of pixels in this image. + !*/ + +#ifndef ENABLE_ASSERTS + pixel_type* operator[] (long row) { return (pixel_type*)(_data+_width_step*row); } + /*! + requires + - 0 <= row < nr() + ensures + - returns a pointer to the first pixel in the row-th row. Therefore, the + pixel at row and column position r,c can be accessed via (*this)[r][c]. + !*/ + + const pixel_type* operator[] (long row) const { return (const pixel_type*)(_data+_width_step*row); } + /*! + requires + - 0 <= row < nr() + ensures + - returns a const pointer to the first pixel in the row-th row. Therefore, + the pixel at row and column position r,c can be accessed via + (*this)[r][c]. + !*/ +#else + // If asserts are enabled then we need to return a proxy class so we can make sure + // the column accesses don't go out of bounds. + struct pix_row + { + pix_row(pixel_type* data_, long nc_) : data(data_),_nc(nc_) {} + const pixel_type& operator[] (long col) const + { + DLIB_ASSERT(0 <= col && col < _nc, + "\t The given column index is out of range." + << "\n\t col: " << col + << "\n\t _nc: " << _nc); + return data[col]; + } + pixel_type& operator[] (long col) + { + DLIB_ASSERT(0 <= col && col < _nc, + "\t The given column index is out of range." + << "\n\t col: " << col + << "\n\t _nc: " << _nc); + return data[col]; + } + private: + pixel_type* const data; + const long _nc; + }; + pix_row operator[] (long row) + { + DLIB_ASSERT(0 <= row && row < _nr, + "\t The given row index is out of range." + << "\n\t row: " << row + << "\n\t _nr: " << _nr); + return pix_row((pixel_type*)(_data+_width_step*row), _nc); + } + const pix_row operator[] (long row) const + { + DLIB_ASSERT(0 <= row && row < _nr, + "\t The given row index is out of range." + << "\n\t row: " << row + << "\n\t _nr: " << _nr); + return pix_row((pixel_type*)(_data+_width_step*row), _nc); + } +#endif + + void set_size(long rows, long cols) + /*! + requires + - rows >= 0 && cols >= 0 + ensures + - Tells the underlying image to resize itself to have the given number of + rows and columns. + - #nr() == rows + - #nc() == cols + !*/ + { + DLIB_ASSERT((cols >= 0 && rows >= 0), + "\t image_view::set_size(long rows, long cols)" + << "\n\t The images can't have negative rows or columns." + << "\n\t cols: " << cols + << "\n\t rows: " << rows + ); + set_image_size(*_img, rows, cols); *this = *_img; + } + + void clear() { set_size(0,0); } + /*! + ensures + - sets the image to have 0 pixels in it. + !*/ + + private: + + char* _data; + long _width_step; + long _nr; + long _nc; + image_type* _img; + }; + +// ---------------------------------------------------------------------------------------- + + template + class const_image_view + { + /*! + REQUIREMENTS ON image_type + image_type must be an image object as defined at the top of this file. + + WHAT THIS OBJECT REPRESENTS + This object is just like the image_view except that it provides a "const" + view into an image. That is, it has the same interface as image_view + except that you can't modify the image through a const_image_view. + !*/ + + public: + typedef typename image_traits::pixel_type pixel_type; + + const_image_view( + const image_type& img + ) : + _data((char*)image_data(img)), + _width_step(width_step(img)), + _nr(num_rows(img)), + _nc(num_columns(img)) + {} + + long nr() const { return _nr; } + long nc() const { return _nc; } + unsigned long size() const { return static_cast(nr()*nc()); } +#ifndef ENABLE_ASSERTS + const pixel_type* operator[] (long row) const { return (const pixel_type*)(_data+_width_step*row); } +#else + // If asserts are enabled then we need to return a proxy class so we can make sure + // the column accesses don't go out of bounds. + struct pix_row + { + pix_row(pixel_type* data_, long nc_) : data(data_),_nc(nc_) {} + const pixel_type& operator[] (long col) const + { + DLIB_ASSERT(0 <= col && col < _nc, + "\t The given column index is out of range." + << "\n\t col: " << col + << "\n\t _nc: " << _nc); + return data[col]; + } + private: + pixel_type* const data; + const long _nc; + }; + const pix_row operator[] (long row) const + { + DLIB_ASSERT(0 <= row && row < _nr, + "\t The given row index is out of range." + << "\n\t row: " << row + << "\n\t _nr: " << _nr); + return pix_row((pixel_type*)(_data+_width_step*row), _nc); + } +#endif + + private: + const char* _data; + long _width_step; + long _nr; + long _nc; + }; + +// ---------------------------------------------------------------------------------------- + + template + image_view make_image_view ( image_type& img) + { return image_view(img); } + /*! + requires + - image_type == an image object that implements the interface defined at the + top of this file. + ensures + - constructs an image_view from an image object + !*/ + + template + const_image_view make_image_view (const image_type& img) + { return const_image_view(img); } + /*! + requires + - image_type == an image object that implements the interface defined at the + top of this file. + ensures + - constructs a const_image_view from an image object + !*/ + +// ---------------------------------------------------------------------------------------- + + template + inline unsigned long image_size( + const image_type& img + ) { return num_columns(img)*num_rows(img); } + /*! + requires + - image_type == an image object that implements the interface defined at the + top of this file. + ensures + - returns the number of pixels in the given image. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + inline long num_rows( + const image_type& img + ) { return img.nr(); } + /*! + ensures + - By default, try to use the member function .nr() to determine the number + of rows in an image. However, as stated at the top of this file, image + objects should provide their own overload of num_rows() if needed. + !*/ + + template + inline long num_columns( + const image_type& img + ) { return img.nc(); } + /*! + ensures + - By default, try to use the member function .nc() to determine the number + of columns in an image. However, as stated at the top of this file, image + objects should provide their own overload of num_rows() if needed. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GeNERIC_IMAGE_Hh_ + diff --git a/ml/dlib/dlib/image_processing/object_detector.h b/ml/dlib/dlib/image_processing/object_detector.h new file mode 100644 index 000000000..9f78abd19 --- /dev/null +++ b/ml/dlib/dlib/image_processing/object_detector.h @@ -0,0 +1,626 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OBJECT_DeTECTOR_Hh_ +#define DLIB_OBJECT_DeTECTOR_Hh_ + +#include "object_detector_abstract.h" +#include "../geometry.h" +#include +#include "box_overlap_testing.h" +#include "full_object_detection.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct rect_detection + { + double detection_confidence; + unsigned long weight_index; + rectangle rect; + + bool operator<(const rect_detection& item) const { return detection_confidence < item.detection_confidence; } + }; + + struct full_detection + { + double detection_confidence; + unsigned long weight_index; + full_object_detection rect; + + bool operator<(const full_detection& item) const { return detection_confidence < item.detection_confidence; } + }; + +// ---------------------------------------------------------------------------------------- + + template + struct processed_weight_vector + { + processed_weight_vector(){} + + typedef typename image_scanner_type::feature_vector_type feature_vector_type; + + void init ( + const image_scanner_type& + ) + /*! + requires + - w has already been assigned its value. Note that the point of this + function is to allow an image scanner to overload the + processed_weight_vector template and provide some different kind of + object as the output of get_detect_argument(). For example, the + scan_fhog_pyramid object uses an overload that causes + get_detect_argument() to return the special fhog_filterbank object + instead of a feature_vector_type. This avoids needing to construct the + fhog_filterbank during each call to detect and therefore speeds up + detection. + !*/ + {} + + // return the first argument to image_scanner_type::detect() + const feature_vector_type& get_detect_argument() const { return w; } + + feature_vector_type w; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type_ + > + class object_detector + { + public: + typedef image_scanner_type_ image_scanner_type; + typedef typename image_scanner_type::feature_vector_type feature_vector_type; + + object_detector ( + ); + + object_detector ( + const object_detector& item + ); + + object_detector ( + const image_scanner_type& scanner_, + const test_box_overlap& overlap_tester_, + const feature_vector_type& w_ + ); + + object_detector ( + const image_scanner_type& scanner_, + const test_box_overlap& overlap_tester_, + const std::vector& w_ + ); + + explicit object_detector ( + const std::vector& detectors + ); + + unsigned long num_detectors ( + ) const { return w.size(); } + + const feature_vector_type& get_w ( + unsigned long idx = 0 + ) const { return w[idx].w; } + + const processed_weight_vector& get_processed_w ( + unsigned long idx = 0 + ) const { return w[idx]; } + + const test_box_overlap& get_overlap_tester ( + ) const; + + const image_scanner_type& get_scanner ( + ) const; + + object_detector& operator= ( + const object_detector& item + ); + + template < + typename image_type + > + std::vector operator() ( + const image_type& img, + double adjust_threshold = 0 + ); + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector >& final_dets, + double adjust_threshold = 0 + ); + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector >& final_dets, + double adjust_threshold = 0 + ); + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector& final_dets, + double adjust_threshold = 0 + ); + + // These typedefs are here for backwards compatibility with previous versions of + // dlib. + typedef ::dlib::rect_detection rect_detection; + typedef ::dlib::full_detection full_detection; + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector& final_dets, + double adjust_threshold = 0 + ); + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector& final_dets, + double adjust_threshold = 0 + ); + + template + friend void serialize ( + const object_detector& item, + std::ostream& out + ); + + template + friend void deserialize ( + object_detector& item, + std::istream& in + ); + + private: + + bool overlaps_any_box ( + const std::vector& rects, + const dlib::rectangle& rect + ) const + { + for (unsigned long i = 0; i < rects.size(); ++i) + { + if (boxes_overlap(rects[i].rect, rect)) + return true; + } + return false; + } + + test_box_overlap boxes_overlap; + std::vector > w; + image_scanner_type scanner; + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const object_detector& item, + std::ostream& out + ) + { + int version = 2; + serialize(version, out); + + T scanner; + scanner.copy_configuration(item.scanner); + serialize(scanner, out); + serialize(item.boxes_overlap, out); + // serialize all the weight vectors + serialize(item.w.size(), out); + for (unsigned long i = 0; i < item.w.size(); ++i) + serialize(item.w[i].w, out); + } + +// ---------------------------------------------------------------------------------------- + + template + void deserialize ( + object_detector& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version == 1) + { + deserialize(item.scanner, in); + item.w.resize(1); + deserialize(item.w[0].w, in); + item.w[0].init(item.scanner); + deserialize(item.boxes_overlap, in); + } + else if (version == 2) + { + deserialize(item.scanner, in); + deserialize(item.boxes_overlap, in); + unsigned long num_detectors = 0; + deserialize(num_detectors, in); + item.w.resize(num_detectors); + for (unsigned long i = 0; i < item.w.size(); ++i) + { + deserialize(item.w[i].w, in); + item.w[i].init(item.scanner); + } + } + else + { + throw serialization_error("Unexpected version encountered while deserializing a dlib::object_detector object."); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// object_detector member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + object_detector:: + object_detector ( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + object_detector:: + object_detector ( + const object_detector& item + ) + { + boxes_overlap = item.boxes_overlap; + w = item.w; + scanner.copy_configuration(item.scanner); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + object_detector:: + object_detector ( + const image_scanner_type& scanner_, + const test_box_overlap& overlap_tester, + const feature_vector_type& w_ + ) : + boxes_overlap(overlap_tester) + { + // make sure requires clause is not broken + DLIB_ASSERT(scanner_.get_num_detection_templates() > 0 && + w_.size() == scanner_.get_num_dimensions() + 1, + "\t object_detector::object_detector(scanner_,overlap_tester,w_)" + << "\n\t Invalid inputs were given to this function " + << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates() + << "\n\t w_.size(): " << w_.size() + << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions() + << "\n\t this: " << this + ); + + scanner.copy_configuration(scanner_); + w.resize(1); + w[0].w = w_; + w[0].init(scanner); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + object_detector:: + object_detector ( + const image_scanner_type& scanner_, + const test_box_overlap& overlap_tester, + const std::vector& w_ + ) : + boxes_overlap(overlap_tester) + { + // make sure requires clause is not broken + DLIB_CASSERT(scanner_.get_num_detection_templates() > 0 && w_.size() > 0, + "\t object_detector::object_detector(scanner_,overlap_tester,w_)" + << "\n\t Invalid inputs were given to this function " + << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates() + << "\n\t w_.size(): " << w_.size() + << "\n\t this: " << this + ); + + for (unsigned long i = 0; i < w_.size(); ++i) + { + DLIB_CASSERT(w_[i].size() == scanner_.get_num_dimensions() + 1, + "\t object_detector::object_detector(scanner_,overlap_tester,w_)" + << "\n\t Invalid inputs were given to this function " + << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates() + << "\n\t w_["< + object_detector:: + object_detector ( + const std::vector& detectors + ) + { + DLIB_CASSERT(detectors.size() != 0, + "\t object_detector::object_detector(detectors)" + << "\n\t Invalid inputs were given to this function " + << "\n\t this: " << this + ); + std::vector weights; + weights.reserve(detectors.size()); + for (unsigned long i = 0; i < detectors.size(); ++i) + { + for (unsigned long j = 0; j < detectors[i].num_detectors(); ++j) + weights.push_back(detectors[i].get_w(j)); + } + + *this = object_detector(detectors[0].get_scanner(), detectors[0].get_overlap_tester(), weights); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + object_detector& object_detector:: + operator= ( + const object_detector& item + ) + { + if (this == &item) + return *this; + + boxes_overlap = item.boxes_overlap; + w = item.w; + scanner.copy_configuration(item.scanner); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + template < + typename image_type + > + void object_detector:: + operator() ( + const image_type& img, + std::vector& final_dets, + double adjust_threshold + ) + { + scanner.load(img); + std::vector > dets; + std::vector dets_accum; + for (unsigned long i = 0; i < w.size(); ++i) + { + const double thresh = w[i].w(scanner.get_num_dimensions()); + scanner.detect(w[i].get_detect_argument(), dets, thresh + adjust_threshold); + for (unsigned long j = 0; j < dets.size(); ++j) + { + rect_detection temp; + temp.detection_confidence = dets[j].first-thresh; + temp.weight_index = i; + temp.rect = dets[j].second; + dets_accum.push_back(temp); + } + } + + // Do non-max suppression + final_dets.clear(); + if (w.size() > 1) + std::sort(dets_accum.rbegin(), dets_accum.rend()); + for (unsigned long i = 0; i < dets_accum.size(); ++i) + { + if (overlaps_any_box(final_dets, dets_accum[i].rect)) + continue; + + final_dets.push_back(dets_accum[i]); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + template < + typename image_type + > + void object_detector:: + operator() ( + const image_type& img, + std::vector& final_dets, + double adjust_threshold + ) + { + std::vector dets; + (*this)(img,dets,adjust_threshold); + + final_dets.resize(dets.size()); + + // convert all the rectangle detections into full_object_detections. + for (unsigned long i = 0; i < dets.size(); ++i) + { + final_dets[i].detection_confidence = dets[i].detection_confidence; + final_dets[i].weight_index = dets[i].weight_index; + final_dets[i].rect = scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + template < + typename image_type + > + std::vector object_detector:: + operator() ( + const image_type& img, + double adjust_threshold + ) + { + std::vector dets; + (*this)(img,dets,adjust_threshold); + + std::vector final_dets(dets.size()); + for (unsigned long i = 0; i < dets.size(); ++i) + final_dets[i] = dets[i].rect; + + return final_dets; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + template < + typename image_type + > + void object_detector:: + operator() ( + const image_type& img, + std::vector >& final_dets, + double adjust_threshold + ) + { + std::vector dets; + (*this)(img,dets,adjust_threshold); + + final_dets.resize(dets.size()); + for (unsigned long i = 0; i < dets.size(); ++i) + final_dets[i] = std::make_pair(dets[i].detection_confidence,dets[i].rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + template < + typename image_type + > + void object_detector:: + operator() ( + const image_type& img, + std::vector >& final_dets, + double adjust_threshold + ) + { + std::vector dets; + (*this)(img,dets,adjust_threshold); + + final_dets.clear(); + final_dets.reserve(dets.size()); + + // convert all the rectangle detections into full_object_detections. + for (unsigned long i = 0; i < dets.size(); ++i) + { + final_dets.push_back(std::make_pair(dets[i].detection_confidence, + scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w))); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + template < + typename image_type + > + void object_detector:: + operator() ( + const image_type& img, + std::vector& final_dets, + double adjust_threshold + ) + { + std::vector dets; + (*this)(img,dets,adjust_threshold); + + final_dets.clear(); + final_dets.reserve(dets.size()); + + // convert all the rectangle detections into full_object_detections. + for (unsigned long i = 0; i < dets.size(); ++i) + { + final_dets.push_back(scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w)); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + const test_box_overlap& object_detector:: + get_overlap_tester ( + ) const + { + return boxes_overlap; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + const image_scanner_type& object_detector:: + get_scanner ( + ) const + { + return scanner; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OBJECT_DeTECTOR_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/object_detector_abstract.h b/ml/dlib/dlib/image_processing/object_detector_abstract.h new file mode 100644 index 000000000..9578d8b03 --- /dev/null +++ b/ml/dlib/dlib/image_processing/object_detector_abstract.h @@ -0,0 +1,404 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OBJECT_DeTECTOR_ABSTRACT_Hh_ +#ifdef DLIB_OBJECT_DeTECTOR_ABSTRACT_Hh_ + +#include "../geometry.h" +#include +#include "box_overlap_testing_abstract.h" +#include "full_object_detection_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct rect_detection + { + double detection_confidence; + unsigned long weight_index; + rectangle rect; + }; + + struct full_detection + { + double detection_confidence; + unsigned long weight_index; + full_object_detection rect; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type_ + > + class object_detector + { + /*! + REQUIREMENTS ON image_scanner_type_ + image_scanner_type_ must be an implementation of + dlib/image_processing/scan_image_pyramid_abstract.h or + dlib/image_processing/scan_fhog_pyramid.h or + dlib/image_processing/scan_image_custom.h or + dlib/image_processing/scan_image_boxes_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object is a tool for detecting the positions of objects in an image. + In particular, it is a simple container to aggregate an instance of an image + scanner (i.e. scan_image_pyramid, scan_fhog_pyramid, scan_image_custom, or + scan_image_boxes), the weight vector needed by one of these image scanners, + and finally an instance of test_box_overlap. The test_box_overlap object + is used to perform non-max suppression on the output of the image scanner + object. + + Note further that this object can contain multiple weight vectors. In this + case, it will run the image scanner multiple times, once with each of the + weight vectors. Then it will aggregate the results from all runs, perform + non-max suppression and then return the results. Therefore, the object_detector + can also be used as a container for a set of object detectors that all use + the same image scanner but different weight vectors. This is useful since + the object detection procedure has two parts. A loading step where the + image is loaded into the scanner, then a detect step which uses the weight + vector to locate objects in the image. Since the loading step is independent + of the weight vector it is most efficient to run multiple detectors by + performing one load into a scanner followed by multiple detect steps. This + avoids unnecessarily loading the same image into the scanner multiple times. + !*/ + public: + typedef image_scanner_type_ image_scanner_type; + typedef typename image_scanner_type::feature_vector_type feature_vector_type; + + object_detector ( + ); + /*! + ensures + - This detector won't generate any detections when + presented with an image. + - #num_detectors() == 0 + !*/ + + object_detector ( + const object_detector& item + ); + /*! + ensures + - #*this is a copy of item + - #get_scanner() == item.get_scanner() + (note that only the "configuration" of item.get_scanner() is copied. + I.e. the copy is done using copy_configuration()) + !*/ + + object_detector ( + const image_scanner_type& scanner, + const test_box_overlap& overlap_tester, + const feature_vector_type& w + ); + /*! + requires + - w.size() == scanner.get_num_dimensions() + 1 + - scanner.get_num_detection_templates() > 0 + ensures + - When the operator() member function is called it will + invoke scanner.detect(w,dets,w(w.size()-1)), suppress + overlapping detections, and then report the results. + - when #*this is used to detect objects, the set of + output detections will never contain any overlaps + with respect to overlap_tester. That is, for all + pairs of returned detections A and B, we will always + have: overlap_tester(A,B) == false + - #get_w() == w + - #get_overlap_tester() == overlap_tester + - #get_scanner() == scanner + (note that only the "configuration" of scanner is copied. + I.e. the copy is done using copy_configuration()) + - #num_detectors() == 1 + !*/ + + object_detector ( + const image_scanner_type& scanner, + const test_box_overlap& overlap_tester, + const std::vector& w + ); + /*! + requires + - for all valid i: + - w[i].size() == scanner.get_num_dimensions() + 1 + - scanner.get_num_detection_templates() > 0 + - w.size() > 0 + ensures + - When the operator() member function is called it will invoke + get_scanner().detect(w[i],dets,w[i](w[i].size()-1)) for all valid i. Then it + will take all the detections output by the calls to detect() and suppress + overlapping detections, and finally report the results. + - when #*this is used to detect objects, the set of output detections will + never contain any overlaps with respect to overlap_tester. That is, for + all pairs of returned detections A and B, we will always have: + overlap_tester(A,B) == false + - for all valid i: + - #get_w(i) == w[i] + - #num_detectors() == w.size() + - #get_overlap_tester() == overlap_tester + - #get_scanner() == scanner + (note that only the "configuration" of scanner is copied. + I.e. the copy is done using copy_configuration()) + !*/ + + explicit object_detector ( + const std::vector& detectors + ); + /*! + requires + - detectors.size() != 0 + - All the detectors must use compatibly configured scanners. That is, it + must make sense for the weight vector from one detector to be used with + the scanner from any other. + - for all valid i: + - detectors[i].get_scanner().get_num_dimensions() == detectors[0].get_scanner().get_num_dimensions() + (i.e. all the detectors use scanners that use the same kind of feature vectors.) + ensures + - Very much like the above constructor, this constructor takes all the + given detectors and packs them into #*this. That is, invoking operator() + on #*this will run all the detectors, perform non-max suppression, and + then report the results. + - When #*this is used to detect objects, the set of output detections will + never contain any overlaps with respect to overlap_tester. That is, for + all pairs of returned detections A and B, we will always have: + overlap_tester(A,B) == false + - #num_detectors() == The sum of detectors[i].num_detectors() for all valid i. + - #get_overlap_tester() == detectors[0].get_overlap_tester() + - #get_scanner() == detectors[0].get_scanner() + (note that only the "configuration" of scanner is copied. I.e. the copy + is done using copy_configuration()) + !*/ + + unsigned long num_detectors ( + ) const; + /*! + ensures + - returns the number of weight vectors in this object. Since each weight + vector logically represents an object detector, this returns the number + of object detectors contained in this object. + !*/ + + const feature_vector_type& get_w ( + unsigned long idx = 0 + ) const; + /*! + requires + - idx < num_detectors() + ensures + - returns the idx-th weight vector loaded into this object. All the weight vectors + have the same dimension and logically each represents a different detector. + !*/ + + const test_box_overlap& get_overlap_tester ( + ) const; + /*! + ensures + - returns the overlap tester used by this object + !*/ + + const image_scanner_type& get_scanner ( + ) const; + /*! + ensures + - returns the image scanner used by this object. + !*/ + + object_detector& operator= ( + const object_detector& item + ); + /*! + ensures + - #*this is a copy of item + - #get_scanner() == item.get_scanner() + (note that only the "configuration" of item.get_scanner() is + copied. I.e. the copy is done using copy_configuration()) + - returns #*this + !*/ + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector& dets, + double adjust_threshold = 0 + ); + /*! + requires + - img == an object which can be accepted by image_scanner_type::load() + ensures + - Performs object detection on the given image and stores the detected + objects into #dets. In particular, we will have that: + - #dets is sorted such that the highest confidence detections come + first. E.g. element 0 is the best detection, element 1 the next + best, and so on. + - #dets.size() == the number of detected objects. + - #dets[i].detection_confidence == The strength of the i-th detection. + Larger values indicate that the detector is more confident that + #dets[i] is a correct detection rather than being a false alarm. + Moreover, the detection_confidence is equal to the detection value + output by the scanner minus the threshold value stored at the end of + the weight vector in get_w(#dets[i].weight_index). + - #dets[i].weight_index == the index for the weight vector that + generated this detection. + - #dets[i].rect == the bounding box for the i-th detection. + - #get_scanner() will have been loaded with img. Therefore, you can call + #get_scanner().get_feature_vector() to obtain the feature vectors or + #get_scanner().get_full_object_detection() to get the + full_object_detections for the resulting object detection boxes. + - The detection threshold is adjusted by having adjust_threshold added to + it. Therefore, an adjust_threshold value > 0 makes detecting objects + harder while a negative value makes it easier. Moreover, the following + will be true for all valid i: + - #dets[i].detection_confidence >= adjust_threshold + This means that, for example, you can obtain the maximum possible number + of detections by setting adjust_threshold equal to negative infinity. + !*/ + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector& dets, + double adjust_threshold = 0 + ); + /*! + requires + - img == an object which can be accepted by image_scanner_type::load() + ensures + - This function is identical to the above operator() routine, except that + it outputs full_object_detections instead of rectangles. This means that + the output includes part locations. In particular, calling this function + is the same as calling the above operator() routine and then using + get_scanner().get_full_object_detection() to resolve all the rectangles + into full_object_detections. Therefore, this version of operator() is + simply a convenience function for performing this set of operations. + !*/ + + template < + typename image_type + > + std::vector operator() ( + const image_type& img, + const adjust_threshold = 0 + ); + /*! + requires + - img == an object which can be accepted by image_scanner_type::load() + ensures + - This function is identical to the above operator() routine, except that + it returns a std::vector which contains just the bounding + boxes of all the detections. + !*/ + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector >& dets, + double adjust_threshold = 0 + ); + /*! + requires + - img == an object which can be accepted by image_scanner_type::load() + ensures + - performs object detection on the given image and stores the + detected objects into #dets. In particular, we will have that: + - #dets is sorted such that the highest confidence detections + come first. E.g. element 0 is the best detection, element 1 + the next best, and so on. + - #dets.size() == the number of detected objects. + - #dets[i].first gives the "detection confidence", of the i-th + detection. This is the detection value output by the scanner minus + the threshold value stored at the end of the weight vector in get_w(). + - #dets[i].second == the bounding box for the i-th detection. + - #get_scanner() will have been loaded with img. Therefore, you can call + #get_scanner().get_feature_vector() to obtain the feature vectors or + #get_scanner().get_full_object_detection() to get the + full_object_detections for the resulting object detection boxes. + - The detection threshold is adjusted by having adjust_threshold added to + it. Therefore, an adjust_threshold value > 0 makes detecting objects + harder while a negative value makes it easier. Moreover, the following + will be true for all valid i: + - #dets[i].first >= adjust_threshold + This means that, for example, you can obtain the maximum possible number + of detections by setting adjust_threshold equal to negative infinity. + !*/ + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector >& dets, + double adjust_threshold = 0 + ); + /*! + requires + - img == an object which can be accepted by image_scanner_type::load() + ensures + - This function is identical to the above operator() routine, except that + it outputs full_object_detections instead of rectangles. This means that + the output includes part locations. In particular, calling this function + is the same as calling the above operator() routine and then using + get_scanner().get_full_object_detection() to resolve all the rectangles + into full_object_detections. Therefore, this version of operator() is + simply a convenience function for performing this set of operations. + !*/ + + template < + typename image_type + > + void operator() ( + const image_type& img, + std::vector& dets, + double adjust_threshold = 0 + ); + /*! + requires + - img == an object which can be accepted by image_scanner_type::load() + ensures + - This function is identical to the above operator() routine, except that + it doesn't include a double valued score. That is, it just outputs the + full_object_detections. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const object_detector& item, + std::ostream& out + ); + /*! + provides serialization support. Note that this function only saves the + configuration part of item.get_scanner(). That is, we use the scanner's + copy_configuration() function to get a copy of the scanner that doesn't contain any + loaded image data and we then save just the configuration part of the scanner. + This means that any serialized object_detectors won't remember any images they have + processed but will otherwise contain all their state and be able to detect objects + in new images. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void deserialize ( + object_detector& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OBJECT_DeTECTOR_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles.h b/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles.h new file mode 100644 index 000000000..95ab4f353 --- /dev/null +++ b/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles.h @@ -0,0 +1,317 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_Hh_ +#define DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_Hh_ + +#include "remove_unobtainable_rectangles_abstract.h" +#include "scan_image_pyramid.h" +#include "scan_image_boxes.h" +#include "scan_image_custom.h" +#include "scan_fhog_pyramid.h" +#include "../svm/structural_object_detection_trainer.h" +#include "../geometry.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline bool matches_rect ( + const std::vector& rects, + const rectangle& rect, + const double eps + ) + { + for (unsigned long i = 0; i < rects.size(); ++i) + { + const double score = (rect.intersect(rects[i])).area()/(double)(rect+rects[i]).area(); + if (score > eps) + return true; + } + + return false; + } + + inline rectangle get_best_matching_rect ( + const std::vector& rects, + const rectangle& rect + ) + { + double best_score = -1; + rectangle best_rect; + for (unsigned long i = 0; i < rects.size(); ++i) + { + const double score = (rect.intersect(rects[i])).area()/(double)(rect+rects[i]).area(); + if (score > best_score) + { + best_score = score; + best_rect = rects[i]; + } + } + return best_rect; + } + + // ------------------------------------------------------------------------------------ + + template < + typename image_array_type, + typename image_scanner_type + > + std::vector > pyramid_remove_unobtainable_rectangles ( + const structural_object_detection_trainer& trainer, + const image_array_type& images, + std::vector >& object_locations + ) + { + using namespace dlib::impl; + // make sure requires clause is not broken + DLIB_ASSERT(images.size() == object_locations.size(), + "\t std::vector> remove_unobtainable_rectangles()" + << "\n\t Invalid inputs were given to this function." + ); + + + std::vector > rejects(images.size()); + + // If the trainer is setup to automatically fit the overlap tester to the data then + // we should use the loosest possible overlap tester here. Otherwise we should use + // the tester the trainer will use. + test_box_overlap boxes_overlap(0.9999999,1); + if (!trainer.auto_set_overlap_tester()) + boxes_overlap = trainer.get_overlap_tester(); + + for (unsigned long k = 0; k < images.size(); ++k) + { + std::vector objs = object_locations[k]; + + // First remove things that don't have any matches with the candidate object + // locations. + std::vector good_rects; + for (unsigned long j = 0; j < objs.size(); ++j) + { + const rectangle rect = trainer.get_scanner().get_best_matching_rect(objs[j]); + const double score = (objs[j].intersect(rect)).area()/(double)(objs[j] + rect).area(); + if (score > trainer.get_match_eps()) + good_rects.push_back(objs[j]); + else + rejects[k].push_back(objs[j]); + } + object_locations[k] = good_rects; + + + // Remap these rectangles to the ones that can come out of the scanner. That + // way when we compare them to each other in the following loop we will know if + // any distinct truth rectangles get mapped to overlapping boxes. + objs.resize(good_rects.size()); + for (unsigned long i = 0; i < good_rects.size(); ++i) + objs[i] = trainer.get_scanner().get_best_matching_rect(good_rects[i]); + + good_rects.clear(); + // now check for truth rects that are too close together. + for (unsigned long i = 0; i < objs.size(); ++i) + { + // check if objs[i] hits another box + bool hit_box = false; + for (unsigned long j = i+1; j < objs.size(); ++j) + { + if (boxes_overlap(objs[i], objs[j])) + { + hit_box = true; + break; + } + } + if (hit_box) + rejects[k].push_back(object_locations[k][i]); + else + good_rects.push_back(object_locations[k][i]); + } + object_locations[k] = good_rects; + } + + return rejects; + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename Pyramid_type, + typename Feature_extractor_type + > + std::vector > remove_unobtainable_rectangles ( + const structural_object_detection_trainer >& trainer, + const image_array_type& images, + std::vector >& object_locations + ) + { + return impl::pyramid_remove_unobtainable_rectangles(trainer, images, object_locations); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename Pyramid_type, + typename Feature_extractor_type + > + std::vector > remove_unobtainable_rectangles ( + const structural_object_detection_trainer >& trainer, + const image_array_type& images, + std::vector >& object_locations + ) + { + return impl::pyramid_remove_unobtainable_rectangles(trainer, images, object_locations); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename image_array_type, + typename scanner_type, + typename get_boxes_functor + > + std::vector > remove_unobtainable_rectangles ( + get_boxes_functor& bg, + const structural_object_detection_trainer& trainer, + const image_array_type& images, + std::vector >& object_locations + ) + { + using namespace dlib::impl; + // make sure requires clause is not broken + DLIB_ASSERT(images.size() == object_locations.size(), + "\t std::vector> remove_unobtainable_rectangles()" + << "\n\t Invalid inputs were given to this function." + ); + + std::vector rects; + + std::vector > rejects(images.size()); + + // If the trainer is setup to automatically fit the overlap tester to the data then + // we should use the loosest possible overlap tester here. Otherwise we should use + // the tester the trainer will use. + test_box_overlap boxes_overlap(0.9999999,1); + if (!trainer.auto_set_overlap_tester()) + boxes_overlap = trainer.get_overlap_tester(); + + for (unsigned long k = 0; k < images.size(); ++k) + { + std::vector objs = object_locations[k]; + // Don't even bother computing the candidate rectangles if there aren't any + // object locations for this image since there isn't anything to do anyway. + if (objs.size() == 0) + continue; + + bg(images[k], rects); + + + // First remove things that don't have any matches with the candidate object + // locations. + std::vector good_rects; + for (unsigned long j = 0; j < objs.size(); ++j) + { + if (matches_rect(rects, objs[j], trainer.get_match_eps())) + good_rects.push_back(objs[j]); + else + rejects[k].push_back(objs[j]); + } + object_locations[k] = good_rects; + + + // Remap these rectangles to the ones that can come out of the scanner. That + // way when we compare them to each other in the following loop we will know if + // any distinct truth rectangles get mapped to overlapping boxes. + objs.resize(good_rects.size()); + for (unsigned long i = 0; i < good_rects.size(); ++i) + objs[i] = get_best_matching_rect(rects, good_rects[i]); + + good_rects.clear(); + // now check for truth rects that are too close together. + for (unsigned long i = 0; i < objs.size(); ++i) + { + // check if objs[i] hits another box + bool hit_box = false; + for (unsigned long j = i+1; j < objs.size(); ++j) + { + if (boxes_overlap(objs[i], objs[j])) + { + hit_box = true; + break; + } + } + if (hit_box) + rejects[k].push_back(object_locations[k][i]); + else + good_rects.push_back(object_locations[k][i]); + } + object_locations[k] = good_rects; + } + + return rejects; + } + + // ---------------------------------------------------------------------------------------- + + template + struct load_to_functor + { + load_to_functor(T& obj_) : obj(obj_) {} + T& obj; + + template + void operator()(const U& u, V& v) + { + obj.load(u,v); + } + }; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename feature_extractor, + typename box_generator + > + std::vector > remove_unobtainable_rectangles ( + const structural_object_detection_trainer >& trainer, + const image_array_type& images, + std::vector >& object_locations + ) + { + box_generator bg = trainer.get_scanner().get_box_generator(); + return impl::remove_unobtainable_rectangles(bg, trainer, images, object_locations); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename feature_extractor + > + std::vector > remove_unobtainable_rectangles ( + const structural_object_detection_trainer >& trainer, + const image_array_type& images, + std::vector >& object_locations + ) + { + feature_extractor fe; + fe.copy_configuration(trainer.get_scanner().get_feature_extractor()); + impl::load_to_functor bg(fe); + return impl::remove_unobtainable_rectangles(bg, trainer, images, object_locations); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_Hh_ + diff --git a/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles_abstract.h b/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles_abstract.h new file mode 100644 index 000000000..328326f1c --- /dev/null +++ b/ml/dlib/dlib/image_processing/remove_unobtainable_rectangles_abstract.h @@ -0,0 +1,56 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_ABSTRACT_Hh_ +#ifdef DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_ABSTRACT_Hh_ + +#include "scan_image_pyramid_abstract.h" +#include "scan_image_boxes_abstract.h" +#include "scan_image_custom_abstract.h" +#include "scan_fhog_pyramid_abstract.h" +#include "../svm/structural_object_detection_trainer_abstract.h" +#include "../geometry.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type, + typename image_array_type + > + std::vector > remove_unobtainable_rectangles ( + const structural_object_detection_trainer& trainer, + const image_array_type& images, + std::vector >& object_locations + ); + /*! + requires + - image_scanner_type must be either scan_image_boxes, scan_image_pyramid, + scan_image_custom, or scan_fhog_pyramid. + - images.size() == object_locations.size() + ensures + - Recall that the image scanner objects can't produce all possible rectangles + as object detections since they only consider a limited subset of all possible + object positions. Moreover, the structural_object_detection_trainer requires + its input training data to not contain any object positions which are unobtainable + by its scanner object. Therefore, remove_unobtainable_rectangles() is a tool + to filter out these unobtainable rectangles from the training data before giving + it to a structural_object_detection_trainer. + - This function interprets object_locations[i] as the set of object positions for + image[i], for all valid i. + - In particular, this function removes unobtainable rectangles from object_locations + and also returns a vector V such that: + - V.size() == object_locations.size() + - for all valid i: + - V[i] == the set of rectangles removed from object_locations[i] + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_REMOVE_UnOBTAINABLE_RECTANGLES_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/render_face_detections.h b/ml/dlib/dlib/image_processing/render_face_detections.h new file mode 100644 index 000000000..96ff8971f --- /dev/null +++ b/ml/dlib/dlib/image_processing/render_face_detections.h @@ -0,0 +1,99 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RENDER_FACE_DeTECTIONS_H_ +#define DLIB_RENDER_FACE_DeTECTIONS_H_ + +#include "full_object_detection.h" +#include "../gui_widgets.h" +#include "render_face_detections_abstract.h" +#include + +namespace dlib +{ + inline std::vector render_face_detections ( + const std::vector& dets, + const rgb_pixel color = rgb_pixel(0,255,0) + ) + { + std::vector lines; + for (unsigned long i = 0; i < dets.size(); ++i) + { + DLIB_CASSERT(dets[i].num_parts() == 68 || dets[i].num_parts() == 5, + "\t std::vector render_face_detections()" + << "\n\t You have to give either a 5 point or 68 point face landmarking output to this function. " + << "\n\t dets["< render_face_detections ( + const full_object_detection& det, + const rgb_pixel color = rgb_pixel(0,255,0) + ) + { + std::vector dets; + dets.push_back(det); + return render_face_detections(dets, color); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RENDER_FACE_DeTECTIONS_H_ + diff --git a/ml/dlib/dlib/image_processing/render_face_detections_abstract.h b/ml/dlib/dlib/image_processing/render_face_detections_abstract.h new file mode 100644 index 000000000..f609c8e8c --- /dev/null +++ b/ml/dlib/dlib/image_processing/render_face_detections_abstract.h @@ -0,0 +1,59 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RENDER_FACE_DeTECTIONS_ABSTRACT_H_ +#ifdef DLIB_RENDER_FACE_DeTECTIONS_ABSTRACT_H_ + +#include "full_object_detection_abstract.h" +#include "../gui_widgets.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline std::vector render_face_detections ( + const std::vector& dets, + const rgb_pixel color = rgb_pixel(0,255,0) + ); + /*! + requires + - for all valid i: + - dets[i].num_parts() == 68 || dets[i].num_parts() == 5 + ensures + - Interprets the given objects as face detections with parts annotated using + either the iBUG face landmark scheme or a 5 point face annotation. We then + return a set of overlay lines that will draw the objects onto the screen in a + way that properly draws the outline of the face features defined by the part + locations. + - returns a vector with dets.size() elements, each containing the lines + necessary to render a face detection from dets. + - The 5 point face annotation scheme is assumed to be: + - det part 0 == left eye corner, outside part of eye. + - det part 1 == left eye corner, inside part of eye. + - det part 2 == right eye corner, outside part of eye. + - det part 3 == right eye corner, inside part of eye. + - det part 4 == immediately under the nose, right at the top of the philtrum. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline std::vector render_face_detections ( + const full_object_detection& det, + const rgb_pixel color = rgb_pixel(0,255,0) + ); + /*! + requires + - det.num_parts() == 68 || det.num_parts() == 5 + ensures + - This function is identical to the above render_face_detections() routine + except that it takes just a single full_object_detection instead of a + std::vector of them. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RENDER_FACE_DeTECTIONS_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/image_processing/scan_fhog_pyramid.h b/ml/dlib/dlib/image_processing/scan_fhog_pyramid.h new file mode 100644 index 000000000..5ae0310af --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_fhog_pyramid.h @@ -0,0 +1,1348 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SCAN_fHOG_PYRAMID_Hh_ +#define DLIB_SCAN_fHOG_PYRAMID_Hh_ + +#include "scan_fhog_pyramid_abstract.h" +#include "../matrix.h" +#include "../image_transforms.h" +#include "../array.h" +#include "../array2d.h" +#include "object_detector.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class default_fhog_feature_extractor + { + public: + inline rectangle image_to_feats ( + const rectangle& rect, + int cell_size, + int filter_rows_padding, + int filter_cols_padding + ) const + { + return image_to_fhog(rect, cell_size, filter_rows_padding, filter_cols_padding); + } + + inline rectangle feats_to_image ( + const rectangle& rect, + int cell_size, + int filter_rows_padding, + int filter_cols_padding + ) const + { + return fhog_to_image(rect, cell_size, filter_rows_padding, filter_cols_padding); + } + + template < + typename image_type + > + void operator()( + const image_type& img, + dlib::array >& hog, + int cell_size, + int filter_rows_padding, + int filter_cols_padding + ) const + { + extract_fhog_features(img,hog,cell_size,filter_rows_padding,filter_cols_padding); + } + + inline unsigned long get_num_planes ( + ) const + { + return 31; + } + }; + + inline void serialize (const default_fhog_feature_extractor&, std::ostream&) {} + inline void deserialize (default_fhog_feature_extractor&, std::istream&) {} + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type = default_fhog_feature_extractor + > + class scan_fhog_pyramid : noncopyable + { + + public: + + typedef matrix feature_vector_type; + + typedef Pyramid_type pyramid_type; + typedef Feature_extractor_type feature_extractor_type; + + scan_fhog_pyramid ( + ); + + explicit scan_fhog_pyramid ( + const feature_extractor_type& fe_ + ); + + template < + typename image_type + > + void load ( + const image_type& img + ); + + inline bool is_loaded_with_image ( + ) const; + + inline void copy_configuration ( + const scan_fhog_pyramid& item + ); + + void set_detection_window_size ( + unsigned long width, + unsigned long height + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(width > 0 && height > 0, + "\t void scan_fhog_pyramid::set_detection_window_size()" + << "\n\t Invalid inputs were given to this function " + << "\n\t width: " << width + << "\n\t height: " << height + << "\n\t this: " << this + ); + + window_width = width; + window_height = height; + feats.clear(); + } + + inline unsigned long get_detection_window_width ( + ) const { return window_width; } + inline unsigned long get_detection_window_height ( + ) const { return window_height; } + + inline unsigned long get_num_detection_templates ( + ) const; + + inline unsigned long get_num_movable_components_per_detection_template ( + ) const; + + void set_padding ( + unsigned long new_padding + ) + { + padding = new_padding; + feats.clear(); + } + + unsigned long get_padding ( + ) const { return padding; } + + void set_cell_size ( + unsigned long new_cell_size + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(new_cell_size > 0 , + "\t void scan_fhog_pyramid::set_cell_size()" + << "\n\t You can't have zero sized fHOG cells. " + << "\n\t this: " << this + ); + + cell_size = new_cell_size; + feats.clear(); + } + + unsigned long get_cell_size ( + ) const { return cell_size; } + + inline long get_num_dimensions ( + ) const; + + unsigned long get_max_pyramid_levels ( + ) const; + + const feature_extractor_type& get_feature_extractor( + ) const { return fe; } + + void set_max_pyramid_levels ( + unsigned long max_levels + ); + + void set_min_pyramid_layer_size ( + unsigned long width, + unsigned long height + ); + + inline unsigned long get_min_pyramid_layer_width ( + ) const; + + inline unsigned long get_min_pyramid_layer_height ( + ) const; + + void detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_loaded_with_image() && + w.size() >= get_num_dimensions(), + "\t void scan_fhog_pyramid::detect()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t w.size(): " << w.size() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t this: " << this + ); + + fhog_filterbank temp = build_fhog_filterbank(w); + detect(temp, dets, thresh); + } + + class fhog_filterbank + { + friend class scan_fhog_pyramid; + public: + inline long get_num_dimensions() const + { + unsigned long dims = 0; + for (unsigned long i = 0; i < filters.size(); ++i) + { + dims += filters[i].size(); + } + return dims; + } + + const std::vector >& get_filters() const { return filters;} + + unsigned long num_separable_filters() const + { + unsigned long num = 0; + for (unsigned long i = 0; i < row_filters.size(); ++i) + { + num += row_filters[i].size(); + } + return num; + } + + std::vector > filters; + std::vector > > row_filters, col_filters; + }; + + fhog_filterbank build_fhog_filterbank ( + const feature_vector_type& weights + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(weights.size() >= get_num_dimensions(), + "\t fhog_filterbank scan_fhog_pyramid::build_fhog_filterbank()" + << "\n\t The number of weights isn't enough to fill out the filterbank. " + << "\n\t weights.size(): " << weights.size() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t this: " << this + ); + + fhog_filterbank temp; + temp.filters.resize(fe.get_num_planes()); + temp.row_filters.resize(fe.get_num_planes()); + temp.col_filters.resize(fe.get_num_planes()); + + // load filters from w + unsigned long width, height; + compute_fhog_window_size(width, height); + const long size = width*height; + for (unsigned long i = 0; i < temp.filters.size(); ++i) + { + matrix u,v,w,f; + f = reshape(rowm(weights, range(i*size, (i+1)*size-1)), height, width); + temp.filters[i] = matrix_cast(f); + + svd3(f, u,w,v); + + matrix w2 = w; + rsort_columns(u,w); + rsort_columns(v,w2); + + double thresh = std::max(1e-4, max(w)*0.001); + w = round_zeros(w, thresh); + + + for (long j = 0; j < w.size(); ++j) + { + if (w(j) != 0) + { + temp.col_filters[i].push_back(matrix_cast(colm(u,j)*std::sqrt(w(j)))); + temp.row_filters[i].push_back(matrix_cast(colm(v,j)*std::sqrt(w(j)))); + } + } + } + + return temp; + } + + void detect ( + const fhog_filterbank& w, + std::vector >& dets, + const double thresh + ) const; + + + void get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const; + + full_object_detection get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& w + ) const; + + const rectangle get_best_matching_rect ( + const rectangle& rect + ) const; + + double get_nuclear_norm_regularization_strength ( + ) const { return nuclear_norm_regularization_strength; } + + void set_nuclear_norm_regularization_strength ( + double strength + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(strength >= 0 , + "\t void scan_fhog_pyramid::set_nuclear_norm_regularization_strength()" + << "\n\t You can't have a negative regularization strength." + << "\n\t strength: " << strength + << "\n\t this: " << this + ); + + nuclear_norm_regularization_strength = strength; + } + + unsigned long get_fhog_window_width ( + ) const + { + unsigned long width, height; + compute_fhog_window_size(width, height); + return width; + } + + unsigned long get_fhog_window_height ( + ) const + { + unsigned long width, height; + compute_fhog_window_size(width, height); + return height; + } + + template + friend void serialize ( + const scan_fhog_pyramid& item, + std::ostream& out + ); + + template + friend void deserialize ( + scan_fhog_pyramid& item, + std::istream& in + ); + + private: + inline void compute_fhog_window_size( + unsigned long& width, + unsigned long& height + ) const + { + const rectangle rect = centered_rect(point(0,0),window_width,window_height); + const rectangle temp = grow_rect(fe.image_to_feats(rect, cell_size, 1, 1), padding); + width = temp.width(); + height = temp.height(); + } + + void get_mapped_rect_and_metadata ( + const unsigned long number_pyramid_levels, + const rectangle& rect, + rectangle& mapped_rect, + rectangle& fhog_rect, + unsigned long& best_level + ) const; + + double get_match_score ( + rectangle r1, + rectangle r2 + ) const + { + // make the rectangles overlap as much as possible before computing the match score. + r1 = move_rect(r1, r2.tl_corner()); + return (r1.intersect(r2).area())/(double)(r1 + r2).area(); + } + + typedef array > fhog_image; + + feature_extractor_type fe; + array feats; + int cell_size; + unsigned long padding; + unsigned long window_width; + unsigned long window_height; + unsigned long max_pyramid_levels; + unsigned long min_pyramid_layer_width; + unsigned long min_pyramid_layer_height; + double nuclear_norm_regularization_strength; + + void init() + { + cell_size = 8; + padding = 1; + window_width = 64; + window_height = 64; + max_pyramid_levels = 1000; + min_pyramid_layer_width = 64; + min_pyramid_layer_height = 64; + nuclear_norm_regularization_strength = 0; + } + + }; + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + rectangle apply_filters_to_fhog ( + const fhog_filterbank& w, + const array >& feats, + array2d& saliency_image + ) + { + const unsigned long num_separable_filters = w.num_separable_filters(); + rectangle area; + // use the separable filters if they would be faster than running the regular filters. + if (num_separable_filters > w.filters.size()*std::min(w.filters[0].nr(),w.filters[0].nc())/3.0) + { + area = spatially_filter_image(feats[0], saliency_image, w.filters[0]); + for (unsigned long i = 1; i < w.filters.size(); ++i) + { + // now we filter but the output adds to saliency_image rather than + // overwriting it. + spatially_filter_image(feats[i], saliency_image, w.filters[i], 1, false, true); + } + } + else + { + saliency_image.clear(); + array2d scratch; + + // find the first filter to apply + unsigned long i = 0; + while (i < w.row_filters.size() && w.row_filters[i].size() == 0) + ++i; + + for (; i < w.row_filters.size(); ++i) + { + for (unsigned long j = 0; j < w.row_filters[i].size(); ++j) + { + if (saliency_image.size() == 0) + area = float_spatially_filter_image_separable(feats[i], saliency_image, w.row_filters[i][j], w.col_filters[i][j],scratch,false); + else + area = float_spatially_filter_image_separable(feats[i], saliency_image, w.row_filters[i][j], w.col_filters[i][j],scratch,true); + } + } + if (saliency_image.size() == 0) + { + saliency_image.set_size(feats[0].nr(), feats[0].nc()); + assign_all_pixels(saliency_image, 0); + } + } + return area; + } + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const scan_fhog_pyramid& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + serialize(item.fe, out); + serialize(item.feats, out); + serialize(item.cell_size, out); + serialize(item.padding, out); + serialize(item.window_width, out); + serialize(item.window_height, out); + serialize(item.max_pyramid_levels, out); + serialize(item.min_pyramid_layer_width, out); + serialize(item.min_pyramid_layer_height, out); + serialize(item.nuclear_norm_regularization_strength, out); + serialize(item.get_num_dimensions(), out); + } + +// ---------------------------------------------------------------------------------------- + + template + void deserialize ( + scan_fhog_pyramid& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unsupported version found when deserializing a scan_fhog_pyramid object."); + + deserialize(item.fe, in); + deserialize(item.feats, in); + deserialize(item.cell_size, in); + deserialize(item.padding, in); + deserialize(item.window_width, in); + deserialize(item.window_height, in); + deserialize(item.max_pyramid_levels, in); + deserialize(item.min_pyramid_layer_width, in); + deserialize(item.min_pyramid_layer_height, in); + deserialize(item.nuclear_norm_regularization_strength, in); + + // When developing some feature extractor, it's easy to accidentally change its + // number of dimensions and then try to deserialize data from an older version of + // your extractor into the current code. This check is here to catch that kind of + // user error. + long dims; + deserialize(dims, in); + if (item.get_num_dimensions() != dims) + throw serialization_error("Number of dimensions in serialized scan_fhog_pyramid doesn't match the expected number."); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// scan_fhog_pyramid member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + scan_fhog_pyramid:: + scan_fhog_pyramid ( + ) + { + init(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + scan_fhog_pyramid:: + scan_fhog_pyramid ( + const feature_extractor_type& fe_ + ) + { + init(); + fe = fe_; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename pyramid_type, + typename image_type, + typename feature_extractor_type + > + void create_fhog_pyramid ( + const image_type& img, + const feature_extractor_type& fe, + array > >& feats, + int cell_size, + int filter_rows_padding, + int filter_cols_padding, + unsigned long min_pyramid_layer_width, + unsigned long min_pyramid_layer_height, + unsigned long max_pyramid_levels + ) + { + unsigned long levels = 0; + rectangle rect = get_rect(img); + + // figure out how many pyramid levels we should be using based on the image size + pyramid_type pyr; + do + { + rect = pyr.rect_down(rect); + ++levels; + } while (rect.width() >= min_pyramid_layer_width && rect.height() >= min_pyramid_layer_height && + levels < max_pyramid_levels); + + if (feats.max_size() < levels) + feats.set_max_size(levels); + feats.set_size(levels); + + + + // build our feature pyramid + fe(img, feats[0], cell_size,filter_rows_padding,filter_cols_padding); + DLIB_ASSERT(feats[0].size() == fe.get_num_planes(), + "Invalid feature extractor used with dlib::scan_fhog_pyramid. The output does not have the \n" + "indicated number of planes."); + + if (feats.size() > 1) + { + typedef typename image_traits::pixel_type pixel_type; + array2d temp1, temp2; + pyr(img, temp1); + fe(temp1, feats[1], cell_size,filter_rows_padding,filter_cols_padding); + swap(temp1,temp2); + + for (unsigned long i = 2; i < feats.size(); ++i) + { + pyr(temp2, temp1); + fe(temp1, feats[i], cell_size,filter_rows_padding,filter_cols_padding); + swap(temp1,temp2); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + template < + typename image_type + > + void scan_fhog_pyramid:: + load ( + const image_type& img + ) + { + unsigned long width, height; + compute_fhog_window_size(width,height); + impl::create_fhog_pyramid(img, fe, feats, cell_size, height, + width, min_pyramid_layer_width, min_pyramid_layer_height, + max_pyramid_levels); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + bool scan_fhog_pyramid:: + is_loaded_with_image ( + ) const + { + return feats.size() != 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + void scan_fhog_pyramid:: + copy_configuration ( + const scan_fhog_pyramid& item + ) + { + cell_size = item.cell_size; + padding = item.padding; + window_width = item.window_width; + window_height = item.window_height; + max_pyramid_levels = item.max_pyramid_levels; + min_pyramid_layer_width = item.min_pyramid_layer_width; + min_pyramid_layer_height = item.min_pyramid_layer_height; + nuclear_norm_regularization_strength = item.nuclear_norm_regularization_strength; + fe = item.fe; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + unsigned long scan_fhog_pyramid:: + get_num_detection_templates ( + ) const + { + return 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + unsigned long scan_fhog_pyramid:: + get_num_movable_components_per_detection_template ( + ) const + { + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + long scan_fhog_pyramid:: + get_num_dimensions ( + ) const + { + unsigned long width, height; + compute_fhog_window_size(width,height); + return width*height*fe.get_num_planes(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + unsigned long scan_fhog_pyramid:: + get_max_pyramid_levels ( + ) const + { + return max_pyramid_levels; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + void scan_fhog_pyramid:: + set_max_pyramid_levels ( + unsigned long max_levels + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_levels > 0 , + "\t void scan_fhog_pyramid::set_max_pyramid_levels()" + << "\n\t You can't have zero levels. " + << "\n\t max_levels: " << max_levels + << "\n\t this: " << this + ); + + max_pyramid_levels = max_levels; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline bool compare_pair_rect ( + const std::pair& a, + const std::pair& b + ) + { + return a.first < b.first; + } + + template < + typename pyramid_type, + typename feature_extractor_type, + typename fhog_filterbank + > + void detect_from_fhog_pyramid ( + const array > >& feats, + const feature_extractor_type& fe, + const fhog_filterbank& w, + const double thresh, + const unsigned long det_box_height, + const unsigned long det_box_width, + const int cell_size, + const int filter_rows_padding, + const int filter_cols_padding, + std::vector >& dets + ) + { + dets.clear(); + + array2d saliency_image; + pyramid_type pyr; + + // for all pyramid levels + for (unsigned long l = 0; l < feats.size(); ++l) + { + const rectangle area = apply_filters_to_fhog(w, feats[l], saliency_image); + + // now search the saliency image for any detections + for (long r = area.top(); r <= area.bottom(); ++r) + { + for (long c = area.left(); c <= area.right(); ++c) + { + // if we found a detection + if (saliency_image[r][c] >= thresh) + { + rectangle rect = fe.feats_to_image(centered_rect(point(c,r),det_box_width,det_box_height), + cell_size, filter_rows_padding, filter_cols_padding); + rect = pyr.rect_up(rect, l); + dets.push_back(std::make_pair(saliency_image[r][c], rect)); + } + } + } + } + + std::sort(dets.rbegin(), dets.rend(), compare_pair_rect); + } + + inline bool overlaps_any_box ( + const test_box_overlap& tester, + const std::vector& rects, + const rect_detection& rect + ) + { + for (unsigned long i = 0; i < rects.size(); ++i) + { + // Only compare detections from the same detector. That is, we don't want + // the output of one detector to stop on the output of another detector. + if (rects[i].weight_index == rect.weight_index && tester(rects[i].rect, rect.rect)) + return true; + } + return false; + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + void scan_fhog_pyramid:: + detect ( + const fhog_filterbank& w, + std::vector >& dets, + const double thresh + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_loaded_with_image() && + w.get_num_dimensions() == get_num_dimensions(), + "\t void scan_fhog_pyramid::detect()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t w.get_num_dimensions(): " << w.get_num_dimensions() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t this: " << this + ); + + unsigned long width, height; + compute_fhog_window_size(width,height); + + impl::detect_from_fhog_pyramid(feats, fe, w, thresh, + height-2*padding, width-2*padding, cell_size, height, width, dets); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + const rectangle scan_fhog_pyramid:: + get_best_matching_rect ( + const rectangle& rect + ) const + { + rectangle mapped_rect, fhog_rect; + unsigned long best_level; + get_mapped_rect_and_metadata(max_pyramid_levels, rect, mapped_rect, fhog_rect, best_level); + return mapped_rect; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + void scan_fhog_pyramid:: + get_mapped_rect_and_metadata ( + const unsigned long number_pyramid_levels, + const rectangle& rect, + rectangle& mapped_rect, + rectangle& fhog_rect, + unsigned long& best_level + ) const + { + pyramid_type pyr; + best_level = 0; + double best_match_score = -1; + + + unsigned long width, height; + compute_fhog_window_size(width,height); + + // Figure out the pyramid level which best matches rect against our detection + // window. + for (unsigned long l = 0; l < number_pyramid_levels; ++l) + { + const rectangle rect_fhog_space = fe.image_to_feats(pyr.rect_down(rect,l), cell_size, height,width); + + const rectangle win_image_space = pyr.rect_up(fe.feats_to_image(centered_rect(center(rect_fhog_space),width-2*padding,height-2*padding), cell_size, height,width), l); + + const double match_score = get_match_score(win_image_space, rect); + if (match_score > best_match_score) + { + best_match_score = match_score; + best_level = l; + fhog_rect = centered_rect(center(rect_fhog_space), width, height); + } + + if (rect_fhog_space.area() <= 1) + break; + } + mapped_rect = pyr.rect_up(fe.feats_to_image(shrink_rect(fhog_rect,padding), cell_size,height,width),best_level); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + full_object_detection scan_fhog_pyramid:: + get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& + ) const + { + return full_object_detection(rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + void scan_fhog_pyramid:: + get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_loaded_with_image() && + psi.size() >= get_num_dimensions() && + obj.num_parts() == 0, + "\t void scan_fhog_pyramid::get_feature_vector()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t psi.size(): " << psi.size() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t obj.num_parts(): " << obj.num_parts() + << "\n\t this: " << this + ); + + + + rectangle mapped_rect; + unsigned long best_level; + rectangle fhog_rect; + get_mapped_rect_and_metadata(feats.size(), obj.get_rect(), mapped_rect, fhog_rect, best_level); + + + long i = 0; + for (unsigned long ii = 0; ii < feats[best_level].size(); ++ii) + { + const rectangle rect = get_rect(feats[best_level][0]); + for (long r = fhog_rect.top(); r <= fhog_rect.bottom(); ++r) + { + for (long c = fhog_rect.left(); c <= fhog_rect.right(); ++c) + { + if (rect.contains(c,r)) + psi(i) += feats[best_level][ii][r][c]; + ++i; + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + void scan_fhog_pyramid:: + set_min_pyramid_layer_size ( + unsigned long width, + unsigned long height + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(width > 0 && height > 0 , + "\t void scan_fhog_pyramid::set_min_pyramid_layer_size()" + << "\n\t These sizes can't be zero. " + << "\n\t width: " << width + << "\n\t height: " << height + << "\n\t this: " << this + ); + + min_pyramid_layer_width = width; + min_pyramid_layer_height = height; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + unsigned long scan_fhog_pyramid:: + get_min_pyramid_layer_width ( + ) const + { + return min_pyramid_layer_width; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + unsigned long scan_fhog_pyramid:: + get_min_pyramid_layer_height ( + ) const + { + return min_pyramid_layer_height; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + matrix draw_fhog ( + const object_detector >& detector, + const unsigned long weight_index = 0, + const long cell_draw_size = 15 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(weight_index < detector.num_detectors(), + "\t matrix draw_fhog()" + << "\n\t Invalid arguments were given to this function. " + << "\n\t weight_index: " << weight_index + << "\n\t detector.num_detectors(): " << detector.num_detectors() + ); + DLIB_ASSERT(cell_draw_size > 0 && detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions(), + "\t matrix draw_fhog()" + << "\n\t Invalid arguments were given to this function. " + << "\n\t cell_draw_size: " << cell_draw_size + << "\n\t weight_index: " << weight_index + << "\n\t detector.get_w(weight_index).size(): " << detector.get_w(weight_index).size() + << "\n\t detector.get_scanner().get_num_dimensions(): " << detector.get_scanner().get_num_dimensions() + ); + + typename scan_fhog_pyramid::fhog_filterbank fb = detector.get_scanner().build_fhog_filterbank(detector.get_w(weight_index)); + return draw_fhog(fb.get_filters(),cell_draw_size); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + unsigned long num_separable_filters ( + const object_detector >& detector, + const unsigned long weight_index = 0 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(weight_index < detector.num_detectors(), + "\t unsigned long num_separable_filters()" + << "\n\t Invalid arguments were given to this function. " + << "\n\t weight_index: " << weight_index + << "\n\t detector.num_detectors(): " << detector.num_detectors() + ); + DLIB_ASSERT(detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions() , + "\t unsigned long num_separable_filters()" + << "\n\t Invalid arguments were given to this function. " + << "\n\t detector.get_w(weight_index).size(): " << detector.get_w(weight_index).size() + << "\n\t detector.get_scanner().get_num_dimensions(): " << detector.get_scanner().get_num_dimensions() + ); + + typename scan_fhog_pyramid::fhog_filterbank fb = detector.get_scanner().build_fhog_filterbank(detector.get_w(weight_index)); + return fb.num_separable_filters(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + object_detector > threshold_filter_singular_values ( + const object_detector >& detector, + double thresh, + const unsigned long weight_index = 0 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(thresh >= 0 , + "\t object_detector threshold_filter_singular_values()" + << "\n\t Invalid inputs were given to this function." + << "\n\t thresh: " << thresh + ); + + DLIB_ASSERT(weight_index < detector.num_detectors(), + "\t object_detector threshold_filter_singular_values()" + << "\n\t Invalid arguments were given to this function. " + << "\n\t weight_index: " << weight_index + << "\n\t detector.num_detectors(): " << detector.num_detectors() + ); + DLIB_ASSERT(detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions() , + "\t object_detector threshold_filter_singular_values()" + << "\n\t Invalid arguments were given to this function. " + << "\n\t detector.get_w(weight_index).size(): " << detector.get_w(weight_index).size() + << "\n\t detector.get_scanner().get_num_dimensions(): " << detector.get_scanner().get_num_dimensions() + ); + + + const unsigned long width = detector.get_scanner().get_fhog_window_width(); + const unsigned long height = detector.get_scanner().get_fhog_window_height(); + const long num_planes = detector.get_scanner().get_feature_extractor().get_num_planes(); + const long size = width*height; + + std::vector > detector_weights; + for (unsigned long j = 0; j < detector.num_detectors(); ++j) + { + matrix weights = detector.get_w(j); + + if (j == weight_index) + { + matrix u,v,w,f; + for (long i = 0; i < num_planes; ++i) + { + f = reshape(rowm(weights, range(i*size, (i+1)*size-1)), height, width); + + svd3(f, u,w,v); + const double scaled_thresh = std::max(1e-3, max(w)*thresh); + w = round_zeros(w, scaled_thresh); + f = u*diagm(w)*trans(v); + + set_rowm(weights,range(i*size, (i+1)*size-1)) = reshape_to_column_vector(f); + } + } + detector_weights.push_back(weights); + } + + return object_detector >(detector.get_scanner(), + detector.get_overlap_tester(), + detector_weights); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type, + typename svm_struct_prob_type + > + void configure_nuclear_norm_regularizer ( + const scan_fhog_pyramid& scanner, + svm_struct_prob_type& prob + ) + { + const double strength = scanner.get_nuclear_norm_regularization_strength(); + const long num_planes = scanner.get_feature_extractor().get_num_planes(); + if (strength != 0) + { + const unsigned long width = scanner.get_fhog_window_width(); + const unsigned long height = scanner.get_fhog_window_height(); + for (long i = 0; i < num_planes; ++i) + { + prob.add_nuclear_norm_regularizer(i*width*height, height, width, strength); + } + prob.set_cache_based_epsilon(0.001); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + struct processed_weight_vector > + { + processed_weight_vector(){} + + typedef matrix feature_vector_type; + typedef typename scan_fhog_pyramid::fhog_filterbank fhog_filterbank; + + void init ( + const scan_fhog_pyramid& scanner + ) + { + fb = scanner.build_fhog_filterbank(w); + } + + const fhog_filterbank& get_detect_argument() const { return fb; } + + feature_vector_type w; + fhog_filterbank fb; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type, + typename image_type + > + void evaluate_detectors ( + const std::vector > >& detectors, + const image_type& img, + std::vector& dets, + const double adjust_threshold = 0 + ) + { + typedef scan_fhog_pyramid scanner_type; + + dets.clear(); + if (detectors.size() == 0) + return; + + const unsigned long cell_size = detectors[0].get_scanner().get_cell_size(); + + // Find the maximum sized filters and also most extreme pyramiding settings used. + unsigned long max_filter_width = 0; + unsigned long max_filter_height = 0; + unsigned long min_pyramid_layer_width = std::numeric_limits::max(); + unsigned long min_pyramid_layer_height = std::numeric_limits::max(); + unsigned long max_pyramid_levels = 0; + bool all_cell_sizes_the_same = true; + for (unsigned long i = 0; i < detectors.size(); ++i) + { + const scanner_type& scanner = detectors[i].get_scanner(); + max_filter_width = std::max(max_filter_width, scanner.get_fhog_window_width()); + max_filter_height = std::max(max_filter_height, scanner.get_fhog_window_height()); + max_pyramid_levels = std::max(max_pyramid_levels, scanner.get_max_pyramid_levels()); + min_pyramid_layer_width = std::min(min_pyramid_layer_width, scanner.get_min_pyramid_layer_width()); + min_pyramid_layer_height = std::min(min_pyramid_layer_height, scanner.get_min_pyramid_layer_height()); + if (cell_size != scanner.get_cell_size()) + all_cell_sizes_the_same = false; + } + + std::vector dets_accum; + // Do to the HOG feature extraction to make the fhog pyramid. Again, note that we + // are making a pyramid that will work with any of the detectors. But only if all + // the cell sizes are the same. If they aren't then we have to calculate the + // pyramid for each detector individually. + array > > feats; + if (all_cell_sizes_the_same) + { + impl::create_fhog_pyramid(img, + detectors[0].get_scanner().get_feature_extractor(), feats, cell_size, + max_filter_height, max_filter_width, min_pyramid_layer_width, + min_pyramid_layer_height, max_pyramid_levels); + } + + std::vector > temp_dets; + for (unsigned long i = 0; i < detectors.size(); ++i) + { + const scanner_type& scanner = detectors[i].get_scanner(); + if (!all_cell_sizes_the_same) + { + impl::create_fhog_pyramid(img, + scanner.get_feature_extractor(), feats, scanner.get_cell_size(), + max_filter_height, max_filter_width, min_pyramid_layer_width, + min_pyramid_layer_height, max_pyramid_levels); + } + + const unsigned long det_box_width = scanner.get_fhog_window_width() - 2*scanner.get_padding(); + const unsigned long det_box_height = scanner.get_fhog_window_height() - 2*scanner.get_padding(); + // A single detector object might itself have multiple weight vectors in it. So + // we need to evaluate all of them. + for (unsigned d = 0; d < detectors[i].num_detectors(); ++d) + { + const double thresh = detectors[i].get_processed_w(d).w(scanner.get_num_dimensions()); + + impl::detect_from_fhog_pyramid(feats, scanner.get_feature_extractor(), + detectors[i].get_processed_w(d).get_detect_argument(), thresh+adjust_threshold, + det_box_height, det_box_width, cell_size, max_filter_height, + max_filter_width, temp_dets); + + for (unsigned long j = 0; j < temp_dets.size(); ++j) + { + rect_detection temp; + temp.detection_confidence = temp_dets[j].first-thresh; + temp.weight_index = i; + temp.rect = temp_dets[j].second; + dets_accum.push_back(temp); + } + } + } + + + // Do non-max suppression + if (detectors.size() > 1) + std::sort(dets_accum.rbegin(), dets_accum.rend()); + for (unsigned long i = 0; i < dets_accum.size(); ++i) + { + const test_box_overlap tester = detectors[dets_accum[i].weight_index].get_overlap_tester(); + if (impl::overlaps_any_box(tester, dets, dets_accum[i])) + continue; + + dets.push_back(dets_accum[i]); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename image_type + > + std::vector evaluate_detectors ( + const std::vector > >& detectors, + const image_type& img, + const double adjust_threshold = 0 + ) + { + std::vector out_dets; + std::vector dets; + evaluate_detectors(detectors, img, dets, adjust_threshold); + out_dets.reserve(dets.size()); + for (unsigned long i = 0; i < dets.size(); ++i) + out_dets.push_back(dets[i].rect); + return out_dets; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_fHOG_PYRAMID_Hh_ + diff --git a/ml/dlib/dlib/image_processing/scan_fhog_pyramid_abstract.h b/ml/dlib/dlib/image_processing/scan_fhog_pyramid_abstract.h new file mode 100644 index 000000000..d12a2b2b8 --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_fhog_pyramid_abstract.h @@ -0,0 +1,784 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SCAN_fHOG_PYRAMID_ABSTRACT_Hh_ +#ifdef DLIB_SCAN_fHOG_PYRAMID_ABSTRACT_Hh_ + +#include +#include "../image_transforms/fhog_abstract.h" +#include "object_detector_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + matrix draw_fhog ( + const object_detector >& detector, + const unsigned long weight_index = 0, + const long cell_draw_size = 15 + ); + /*! + requires + - cell_draw_size > 0 + - weight_index < detector.num_detectors() + - detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions() + (i.e. the detector must have been populated with a HOG filter) + ensures + - Converts the HOG filters in the given detector (specifically, the filters in + detector.get_w(weight_index)) into an image suitable for display on the + screen. In particular, we draw all the HOG cells into a grayscale image in a + way that shows the magnitude and orientation of the gradient energy in each + cell. The resulting image is then returned. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + unsigned long num_separable_filters ( + const object_detector >& detector, + const unsigned long weight_index = 0 + ); + /*! + requires + - weight_index < detector.num_detectors() + - detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions() + (i.e. the detector must have been populated with a HOG filter) + ensures + - Returns the number of separable filters necessary to represent the HOG + filters in the given detector's weight_index'th filter. This is the filter + defined by detector.get_w(weight_index). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename feature_extractor_type + > + object_detector > threshold_filter_singular_values ( + const object_detector >& detector, + double thresh, + const unsigned long weight_index = 0 + ); + /*! + requires + - thresh >= 0 + - weight_index < detector.num_detectors() + - detector.get_w(weight_index).size() >= detector.get_scanner().get_num_dimensions() + (i.e. the detector must have been populated with a HOG filter) + ensures + - Removes all components of the filters in the given detector that have + singular values that are smaller than the given threshold. Therefore, this + function allows you to control how many separable filters are in a detector. + In particular, as thresh gets larger the quantity + num_separable_filters(threshold_filter_singular_values(detector,thresh,weight_index),weight_index) + will generally get smaller and therefore give a faster running detector. + However, note that at some point a large enough thresh will drop too much + information from the filters and their accuracy will suffer. + - returns the updated detector + !*/ + +// ---------------------------------------------------------------------------------------- + + class default_fhog_feature_extractor + { + /*! + WHAT THIS OBJECT REPRESENTS + The scan_fhog_pyramid object defined below is primarily meant to be used + with the feature extraction technique implemented by extract_fhog_features(). + This technique can generally be understood as taking an input image and + outputting a multi-planed output image of floating point numbers that + somehow describe the image contents. Since there are many ways to define + how this feature mapping is performed, the scan_fhog_pyramid allows you to + replace the extract_fhog_features() method with a customized method of your + choosing. To do this you implement a class with the same interface as + default_fhog_feature_extractor. + + Therefore, the point of default_fhog_feature_extractor is two fold. First, + it provides the default FHOG feature extraction method used by scan_fhog_pyramid. + Second, it serves to document the interface you need to implement to define + your own custom HOG style feature extraction. + !*/ + + public: + + rectangle image_to_feats ( + const rectangle& rect, + int cell_size, + int filter_rows_padding, + int filter_cols_padding + ) const { return image_to_fhog(rect, cell_size, filter_rows_padding, filter_cols_padding); } + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + ensures + - Maps a rectangle from the coordinates in an input image to the corresponding + area in the output feature image. + !*/ + + rectangle feats_to_image ( + const rectangle& rect, + int cell_size, + int filter_rows_padding, + int filter_cols_padding + ) const { return fhog_to_image(rect, cell_size, filter_rows_padding, filter_cols_padding); } + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + ensures + - Maps a rectangle from the coordinates of the hog feature image back to + the input image. + - Mapping from feature space to image space is an invertible + transformation. That is, for any rectangle R we have: + R == image_to_feats(feats_to_image(R,cell_size,filter_rows_padding,filter_cols_padding), + cell_size,filter_rows_padding,filter_cols_padding). + !*/ + + template < + typename image_type + > + void operator()( + const image_type& img, + dlib::array >& hog, + int cell_size, + int filter_rows_padding, + int filter_cols_padding + ) const { extract_fhog_features(img,hog,cell_size,filter_rows_padding,filter_cols_padding); } + /*! + requires + - image_type == is an implementation of array2d/array2d_kernel_abstract.h + - img contains some kind of pixel type. + (i.e. pixel_traits is defined) + ensures + - Extracts FHOG features by calling extract_fhog_features(). The results are + stored into #hog. Note that if you are implementing your own feature extractor you can + pretty much do whatever you want in terms of feature extraction so long as the following + conditions are met: + - #hog.size() == get_num_planes() + - Each image plane in #hog has the same dimensions. + - for all valid i, r, and c: + - #hog[i][r][c] == a feature value describing the image content centered at the + following pixel location in img: + feats_to_image(point(c,r),cell_size,filter_rows_padding,filter_cols_padding) + !*/ + + inline unsigned long get_num_planes ( + ) const { return 31; } + /*! + ensures + - returns the number of planes in the hog image output by the operator() + method. + !*/ + }; + + inline void serialize (const default_fhog_feature_extractor&, std::ostream&) {} + inline void deserialize (default_fhog_feature_extractor&, std::istream&) {} + /*! + Provides serialization support. Note that there is no state in the default hog + feature extractor so these functions do nothing. But if you define a custom + feature extractor then make sure you remember to serialize any state in your + feature extractor. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type = default_fhog_feature_extractor + > + class scan_fhog_pyramid : noncopyable + { + /*! + REQUIREMENTS ON Pyramid_type + - Must be one of the pyramid_down objects defined in + dlib/image_transforms/image_pyramid_abstract.h or an object with a + compatible interface + + REQUIREMENTS ON Feature_extractor_type + - Must be a type with an interface compatible with the + default_fhog_feature_extractor. + + INITIAL VALUE + - get_padding() == 1 + - get_cell_size() == 8 + - get_detection_window_width() == 64 + - get_detection_window_height() == 64 + - get_max_pyramid_levels() == 1000 + - get_min_pyramid_layer_width() == 64 + - get_min_pyramid_layer_height() == 64 + - get_nuclear_norm_regularization_strength() == 0 + + WHAT THIS OBJECT REPRESENTS + This object is a tool for running a fixed sized sliding window classifier + over an image pyramid. In particular, it slides a linear classifier over + a HOG pyramid as discussed in the paper: + Histograms of Oriented Gradients for Human Detection by Navneet Dalal + and Bill Triggs, CVPR 2005 + However, we augment the method slightly to use the version of HOG features + from: + Object Detection with Discriminatively Trained Part Based Models by + P. Felzenszwalb, R. Girshick, D. McAllester, D. Ramanan + IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. 32, No. 9, Sep. 2010 + Since these HOG features have been shown to give superior performance. + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be + protected by a mutex lock except for the case where you are copying the + configuration (via copy_configuration()) of a scan_fhog_pyramid object to + many other threads. In this case, it is safe to copy the configuration of + a shared object so long as no other operations are performed on it. + !*/ + + public: + typedef matrix feature_vector_type; + typedef Pyramid_type pyramid_type; + typedef Feature_extractor_type feature_extractor_type; + + scan_fhog_pyramid ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + explicit scan_fhog_pyramid ( + const feature_extractor_type& fe + ); + /*! + ensures + - this object is properly initialized + - #get_feature_extractor() == fe + !*/ + + template < + typename image_type + > + void load ( + const image_type& img + ); + /*! + requires + - image_type == is an implementation of array2d/array2d_kernel_abstract.h + - img contains some kind of pixel type. + (i.e. pixel_traits is defined) + ensures + - #is_loaded_with_image() == true + - This object is ready to run a classifier over img to detect object + locations. Call detect() to do this. + !*/ + + const feature_extractor_type& get_feature_extractor( + ) const; + /*! + ensures + - returns a const reference to the feature extractor used by this object. + !*/ + + bool is_loaded_with_image ( + ) const; + /*! + ensures + - returns true if this object has been loaded with an image to process and + false otherwise. + !*/ + + void copy_configuration ( + const scan_fhog_pyramid& item + ); + /*! + ensures + - Copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two scan_fhog_pyramid + objects S1 and S2, the following sequence of instructions should always + result in both of them having the exact same state: + S2.copy_configuration(S1); + S1.load(img); + S2.load(img); + !*/ + + void set_detection_window_size ( + unsigned long window_width, + unsigned long window_height + ); + /*! + requires + - window_width > 0 + - window_height > 0 + ensures + - When detect() is called, this object scans a window that is of the given + width and height (in pixels) over each layer in an image pyramid. This + means that the rectangle detections which come out of detect() will have + a width to height ratio approximately equal to window_width/window_height + and will be approximately window_width*window_height pixels in area or + larger. Therefore, the smallest object that can be detected is roughly + window_width by window_height pixels in size. + - #get_detection_window_width() == window_width + - #get_detection_window_height() == window_height + - Since we use a HOG feature representation, the detection procedure works + as follows: + Step 1. Make an image pyramid. + Step 2. Convert each layer of the image pyramid into a multi-planed HOG "image". + (the number of bands is given by get_feature_extractor().get_num_planes()) + Step 3. Scan a linear classifier over each HOG image in the pyramid. + Moreover, the HOG features quantize the input image into a grid of cells, + each cell being get_cell_size() by get_cell_size() pixels in size. So + when we scan the object detector over the pyramid we are scanning an + appropriately sized window over these smaller quantized HOG features. In + particular, the size of the window we scan over the HOG feature pyramid + is #get_fhog_window_width() by #get_fhog_window_height() HOG cells in + size. + - #is_loaded_with_image() == false + !*/ + + unsigned long get_detection_window_width ( + ) const; + /*! + ensures + - returns the width, in pixels, of the detection window that is scanned + over the image when detect() is called. + !*/ + + inline unsigned long get_detection_window_height ( + ) const; + /*! + ensures + - returns the height, in pixels, of the detection window that is scanned + over the image when detect() is called. + !*/ + + unsigned long get_fhog_window_width ( + ) const; + /*! + ensures + - Returns the width of the HOG scanning window in terms of HOG cell blocks. + Note that this is a function of get_detection_window_width(), get_cell_size(), + and get_padding() and is therefore not something you set directly. + - #get_fhog_window_width() is approximately equal to the number of HOG cells + that fit into get_detection_window_width() pixels plus 2*get_padding() + since we include additional padding around each window to add context. + !*/ + + unsigned long get_fhog_window_height ( + ) const; + /*! + ensures + - Returns the height of the HOG scanning window in terms of HOG cell blocks. + Note that this is a function of get_detection_window_height(), get_cell_size(), + and get_padding() and is therefore not something you set directly. + - #get_fhog_window_height() is approximately equal to the number of HOG cells + that fit into get_detection_window_height() pixels plus 2*get_padding() + since we include additional padding around each window to add context. + !*/ + + void set_padding ( + unsigned long new_padding + ); + /*! + ensures + - #get_padding() == new_padding + - #is_loaded_with_image() == false + !*/ + + unsigned long get_padding ( + ) const; + /*! + ensures + - The HOG windows scanned over the HOG pyramid can include additional HOG + cells outside the detection window. This can help add context and + improve detection accuracy. This function returns the number of extra + HOG cells added onto the border of the HOG windows which are scanned by + detect(). + !*/ + + unsigned long get_cell_size ( + ) const; + /*! + ensures + - Returns the size of the HOG cells. Each HOG cell is square and contains + get_cell_size()*get_cell_size() pixels. + !*/ + + void set_cell_size ( + unsigned long new_cell_size + ); + /*! + requires + - new_cell_size > 0 + ensures + - #get_cell_size() == new_cell_size + - #is_loaded_with_image() == false + !*/ + + inline long get_num_dimensions ( + ) const; + /*! + ensures + - returns get_fhog_window_width()*get_fhog_window_height()*get_feature_extractor().get_num_planes() + (i.e. The number of features is equal to the size of the HOG window times + the number of planes output by the feature extractor. ) + !*/ + + inline unsigned long get_num_detection_templates ( + ) const { return 1; } + /*! + ensures + - returns 1. Note that this function is here only for compatibility with + the scan_image_pyramid object. Notionally, its return value indicates + that a scan_fhog_pyramid object is always ready to detect objects once + an image has been loaded. + !*/ + + inline unsigned long get_num_movable_components_per_detection_template ( + ) const { return 0; } + /*! + ensures + - returns 0. Note that this function is here only for compatibility with + the scan_image_pyramid object. Its return value means that this object + does not support using movable part models. + !*/ + + unsigned long get_max_pyramid_levels ( + ) const; + /*! + ensures + - returns the maximum number of image pyramid levels this object will use. + Note that #get_max_pyramid_levels() == 1 indicates that no image pyramid + will be used at all. That is, only the original image will be processed + and no lower scale versions will be created. + !*/ + + void set_max_pyramid_levels ( + unsigned long max_levels + ); + /*! + requires + - max_levels > 0 + ensures + - #get_max_pyramid_levels() == max_levels + !*/ + + void set_min_pyramid_layer_size ( + unsigned long width, + unsigned long height + ); + /*! + requires + - width > 0 + - height > 0 + ensures + - #get_min_pyramid_layer_width() == width + - #get_min_pyramid_layer_height() == height + !*/ + + inline unsigned long get_min_pyramid_layer_width ( + ) const; + /*! + ensures + - returns the smallest allowable width of an image in the image pyramid. + All pyramids will always include the original input image, however, no + pyramid levels will be created which have a width smaller than the + value returned by this function. + !*/ + + inline unsigned long get_min_pyramid_layer_height ( + ) const; + /*! + ensures + - returns the smallest allowable height of an image in the image pyramid. + All pyramids will always include the original input image, however, no + pyramid levels will be created which have a height smaller than the + value returned by this function. + !*/ + + fhog_filterbank build_fhog_filterbank ( + const feature_vector_type& weights + ) const; + /*! + requires + - weights.size() >= get_num_dimensions() + ensures + - Creates and then returns a fhog_filterbank object FB such that: + - FB.get_num_dimensions() == get_num_dimensions() + - FB.get_filters() == the values in weights unpacked into get_feature_extractor().get_num_planes() filters. + - FB.num_separable_filters() == the number of separable filters necessary to + represent all the filters in FB.get_filters(). + !*/ + + class fhog_filterbank + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a HOG filter bank. That is, the classifier that is + slid over a HOG pyramid is a set of get_feature_extractor().get_num_planes() + linear filters, each get_fhog_window_width() rows by get_fhog_window_height() + columns in size. This object contains that set of filters. + !*/ + + public: + long get_num_dimensions( + ) const; + /*! + ensures + - Returns the total number of values in the filters. + !*/ + + const std::vector >& get_filters( + ) const; + /*! + ensures + - returns the set of HOG filters in this object. + !*/ + + unsigned long num_separable_filters( + ) const; + /*! + ensures + - returns the number of separable filters necessary to represent all + the filters in get_filters(). + !*/ + }; + + void detect ( + const fhog_filterbank& w, + std::vector >& dets, + const double thresh + ) const; + /*! + requires + - w.get_num_dimensions() == get_num_dimensions() + - is_loaded_with_image() == true + ensures + - Scans the HOG filter defined by w over the HOG pyramid that was populated + by the last call to load() and stores all object detections into #dets. + - for all valid i: + - #dets[i].second == The object box which produced this detection. This rectangle gives + the location of the detection. Note that the rectangle will have been converted back into + the original image input space. That is, if this detection was made at a low level in the + image pyramid then the object box will have been automatically mapped up the pyramid layers + to the original image space. Or in other words, if you plot #dets[i].second on top of the + image given to load() it will show up in the right place. + - #dets[i].first == The score for this detection. This value is equal to dot(w, feature vector + for this sliding window location). + - #dets[i].first >= thresh + - #dets will be sorted in descending order. (i.e. #dets[i].first >= #dets[j].first for all i, and j>i) + - Elements of w beyond index get_num_dimensions()-1 are ignored. I.e. only the first + get_num_dimensions() are used. + - Note that no form of non-max suppression is performed. If a window has a score >= thresh + then it is reported in #dets. + !*/ + + void detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const; + /*! + requires + - w.size() >= get_num_dimensions() + - is_loaded_with_image() == true + ensures + - performs: detect(build_fhog_filterbank(w), dets, thresh) + !*/ + + void get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const; + /*! + requires + - obj.num_parts() == 0 + - is_loaded_with_image() == true + - psi.size() >= get_num_dimensions() + (i.e. psi must have preallocated its memory before this function is called) + ensures + - This function allows you to determine the feature vector used for an + object detection output from detect(). Note that this vector is + added to psi. Note also that you can use get_full_object_detection() to + convert a rectangle from detect() into the needed full_object_detection. + - The dimensionality of the vector added to psi is get_num_dimensions(). This + means that elements of psi after psi(get_num_dimensions()-1) are not modified. + - Since scan_fhog_pyramid only searches a limited set of object locations, + not all possible rectangles can be output by detect(). So in the case + where obj.get_rect() could not arise from a call to detect(), this + function will map obj.get_rect() to the nearest possible rectangle and + then add the feature vector for the mapped rectangle into #psi. + - get_best_matching_rect(obj.get_rect()) == the rectangle obj.get_rect() + gets mapped to for feature extraction. + !*/ + + full_object_detection get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& w + ) const; + /*! + ensures + - returns full_object_detection(rect) + (This function is here only for compatibility with the scan_image_pyramid + object) + !*/ + + const rectangle get_best_matching_rect ( + const rectangle& rect + ) const; + /*! + ensures + - Since scan_fhog_pyramid only searches a limited set of object locations, + not all possible rectangles can be represented. Therefore, this function + allows you to supply a rectangle and obtain the nearest possible + candidate object location rectangle. + !*/ + + double get_nuclear_norm_regularization_strength ( + ) const; + /*! + ensures + - If the number of separable filters in a fhog_filterbank is small then the + filter bank can be scanned over an image much faster than a normal set of + filters. Therefore, this object provides the option to encourage + machine learning methods that learn a HOG filter bank (i.e. + structural_object_detection_trainer) to select filter banks that have + this beneficial property. In particular, the value returned by + get_nuclear_norm_regularization_strength() is a multiplier on a nuclear + norm regularizer which will encourage the selection of filters that use a + small number of separable components. Larger values encourage tend to + give a smaller number of separable filters. + - if (get_nuclear_norm_regularization_strength() == 0) then + - This feature is disabled + - else + - A nuclear norm regularizer will be added when + structural_object_detection_trainer is used to learn a HOG filter + bank. Note that this can make the training process take + significantly longer (but can result in faster object detectors). + !*/ + + void set_nuclear_norm_regularization_strength ( + double strength + ); + /*! + requires + - strength >= 0 + ensures + - #get_nuclear_norm_regularization_strength() == strength + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const scan_fhog_pyramid& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void deserialize ( + scan_fhog_pyramid& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type, + typename image_type + > + void evaluate_detectors ( + const std::vector>>& detectors, + const image_type& img, + std::vector& dets, + const double adjust_threshold = 0 + ); + /*! + requires + - image_type == is an implementation of array2d/array2d_kernel_abstract.h + - img contains some kind of pixel type. + (i.e. pixel_traits is defined) + ensures + - This function runs each of the provided object_detector objects over img and + stores the resulting detections into #dets. Importantly, this function is + faster than running each detector individually because it computes the HOG + features only once and then reuses them for each detector. However, it is + important to note that this speedup is only possible if all the detectors use + the same cell_size parameter that determines how HOG features are computed. + If different cell_size values are used then this function will not be any + faster than running the detectors individually. + - This function applies non-max suppression individually to the output of each + detector. Therefore, the output is the same as if you ran each detector + individually and then concatenated the results. + - To be precise, this function performs object detection on the given image and + stores the detected objects into #dets. In particular, we will have that: + - #dets is sorted such that the highest confidence detections come first. + E.g. element 0 is the best detection, element 1 the next best, and so on. + - #dets.size() == the number of detected objects. + - #dets[i].detection_confidence == The strength of the i-th detection. + Larger values indicate that the detector is more confident that #dets[i] + is a correct detection rather than being a false alarm. Moreover, the + detection_confidence is equal to the detection value output by the + scanner minus the threshold value stored at the end of the weight vector. + - #dets[i].rect == the bounding box for the i-th detection. + - The detection #dets[i].rect was produced by detectors[#dets[i].weight_index]. + - The detection threshold is adjusted by having adjust_threshold added to it. + Therefore, an adjust_threshold value > 0 makes detecting objects harder while + a negative value makes it easier. Moreover, the following will be true for + all valid i: + - #dets[i].detection_confidence >= adjust_threshold + This means that, for example, you can obtain the maximum possible number of + detections by setting adjust_threshold equal to negative infinity. + - This function is threadsafe in the sense that multiple threads can call + evaluate_detectors() with the same instances of detectors and img without + requiring a mutex lock. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type, + typename image_type + > + std::vector evaluate_detectors ( + const std::vector>>& detectors, + const image_type& img, + const double adjust_threshold = 0 + ); + /*! + requires + - image_type == is an implementation of array2d/array2d_kernel_abstract.h + - img contains some kind of pixel type. + (i.e. pixel_traits is defined) + ensures + - This function just calls the above evaluate_detectors() routine and copies + the output dets into a vector object and returns it. Therefore, + this function is provided for convenience. + - This function is threadsafe in the sense that multiple threads can call + evaluate_detectors() with the same instances of detectors and img without + requiring a mutex lock. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_fHOG_PYRAMID_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/scan_image.h b/ml/dlib/dlib/image_processing/scan_image.h new file mode 100644 index 000000000..1a9c46eda --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image.h @@ -0,0 +1,368 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SCAN_iMAGE_Hh_ +#define DLIB_SCAN_iMAGE_Hh_ + +#include +#include +#include "scan_image_abstract.h" +#include "../matrix.h" +#include "../algs.h" +#include "../rand.h" +#include "../array2d.h" +#include "../image_transforms/spatial_filtering.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + inline rectangle bounding_box_of_rects ( + const std::vector >& rects, + const point& position + ) + /*! + ensures + - returns the smallest rectangle that contains all the + rectangles in rects. That is, returns the rectangle that + contains translate_rect(rects[i].second,position) for all valid i. + !*/ + { + rectangle rect; + + for (unsigned long i = 0; i < rects.size(); ++i) + { + rect += translate_rect(rects[i].second,position); + } + + return rect; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + bool all_images_same_size ( + const image_array_type& images + ) + { + if (images.size() == 0) + return true; + + for (unsigned long i = 0; i < images.size(); ++i) + { + if (num_rows(images[0]) != num_rows(images[i]) || + num_columns(images[0]) != num_columns(images[i])) + return false; + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + double sum_of_rects_in_images ( + const image_array_type& images, + const std::vector >& rects, + const point& position + ) + { + DLIB_ASSERT(all_images_same_size(images), + "\t double sum_of_rects_in_images()" + << "\n\t Invalid arguments given to this function." + << "\n\t all_images_same_size(images): " << all_images_same_size(images) + ); +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < rects.size(); ++i) + { + DLIB_ASSERT(rects[i].first < images.size(), + "\t double sum_of_rects_in_images()" + << "\n\t rects["<::pixel_type pixel_type; + typedef typename promote::type ptype; + + ptype temp = 0; + + for (unsigned long i = 0; i < rects.size(); ++i) + { + const typename image_array_type::type& img = images[rects[i].first]; + const rectangle rect = get_rect(img).intersect(translate_rect(rects[i].second,position)); + temp += sum(matrix_cast(subm(mat(img), rect))); + } + + return static_cast(temp); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + double sum_of_rects_in_images_movable_parts ( + const image_array_type& images, + const rectangle& window, + const std::vector >& fixed_rects, + const std::vector >& movable_rects, + const point& position + ) + { + DLIB_ASSERT(all_images_same_size(images) && center(window) == point(0,0), + "\t double sum_of_rects_in_images_movable_parts()" + << "\n\t Invalid arguments given to this function." + << "\n\t all_images_same_size(images): " << all_images_same_size(images) + << "\n\t center(window): " << center(window) + ); +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < fixed_rects.size(); ++i) + { + DLIB_ASSERT(fixed_rects[i].first < images.size(), + "\t double sum_of_rects_in_images_movable_parts()" + << "\n\t fixed_rects["<::pixel_type pixel_type; + typedef typename promote::type ptype; + + ptype temp = 0; + + // compute TOTAL_FIXED part + for (unsigned long i = 0; i < fixed_rects.size(); ++i) + { + const typename image_array_type::type& img = images[fixed_rects[i].first]; + const rectangle rect = get_rect(img).intersect(translate_rect(fixed_rects[i].second,position)); + temp += sum(matrix_cast(subm(mat(img), rect))); + } + + if (images.size() > 0) + { + // compute TOTAL_MOVABLE part + array2d tempimg(images[0].nr(), images[0].nc()); + for (unsigned long i = 0; i < movable_rects.size(); ++i) + { + const typename image_array_type::type& img = images[movable_rects[i].first]; + + sum_filter_assign(img, tempimg, movable_rects[i].second); + + const rectangle rect = get_rect(tempimg).intersect(translate_rect(window,position)); + if (rect.is_empty() == false) + temp += std::max(0,max(matrix_cast(subm(mat(tempimg), rect)))); + } + } + + return static_cast(temp); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void find_points_above_thresh ( + std::vector >& dets, + const image_type& img_, + const double thresh, + const unsigned long max_dets + ) + { + const_image_view img(img_); + typedef typename image_traits::pixel_type ptype; + + dets.clear(); + if (max_dets == 0) + return; + + unsigned long count = 0; + dlib::rand rnd; + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + const ptype val = img[r][c]; + if (val >= thresh) + { + ++count; + + if (dets.size() < max_dets) + { + dets.push_back(std::make_pair(val, point(c,r))); + } + else + { + // The idea here is to cause us to randomly sample possible detection + // locations throughout the image rather than just stopping the detection + // procedure once we hit the max_dets limit. So this method will result + // in a random subsample of all the detections >= thresh being in dets + // at the end of scan_image(). + const unsigned long random_index = rnd.get_random_32bit_number()%count; + if (random_index < dets.size()) + { + dets[random_index] = std::make_pair(val, point(c,r)); + } + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + void scan_image ( + std::vector >& dets, + const image_array_type& images, + const std::vector >& rects, + const double thresh, + const unsigned long max_dets + ) + { + DLIB_ASSERT(images.size() > 0 && rects.size() > 0 && all_images_same_size(images), + "\t void scan_image()" + << "\n\t Invalid arguments given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t rects.size(): " << rects.size() + << "\n\t all_images_same_size(images): " << all_images_same_size(images) + ); +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < rects.size(); ++i) + { + DLIB_ASSERT(rects[i].first < images.size(), + "\t void scan_image()" + << "\n\t rects["<::pixel_type pixel_type; + typedef typename promote::type ptype; + + array2d accum(images[0].nr(), images[0].nc()); + assign_all_pixels(accum, 0); + + for (unsigned long i = 0; i < rects.size(); ++i) + sum_filter(images[rects[i].first], accum, rects[i].second); + + find_points_above_thresh(dets, accum, thresh, max_dets); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + void scan_image_movable_parts ( + std::vector >& dets, + const image_array_type& images, + const rectangle& window, + const std::vector >& fixed_rects, + const std::vector >& movable_rects, + const double thresh, + const unsigned long max_dets + ) + { + DLIB_ASSERT(images.size() > 0 && all_images_same_size(images) && + center(window) == point(0,0) && window.area() > 0, + "\t void scan_image_movable_parts()" + << "\n\t Invalid arguments given to this function." + << "\n\t all_images_same_size(images): " << all_images_same_size(images) + << "\n\t center(window): " << center(window) + << "\n\t window.area(): " << window.area() + << "\n\t images.size(): " << images.size() + ); +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < fixed_rects.size(); ++i) + { + DLIB_ASSERT(fixed_rects[i].first < images.size(), + "\t void scan_image_movable_parts()" + << "\n\t Invalid arguments given to this function." + << "\n\t fixed_rects["< 0, + "\t void scan_image_movable_parts()" + << "\n\t Invalid arguments given to this function." + << "\n\t movable_rects["<::pixel_type pixel_type; + typedef typename promote::type ptype; + + array2d accum(images[0].nr(), images[0].nc()); + assign_all_pixels(accum, 0); + + for (unsigned long i = 0; i < fixed_rects.size(); ++i) + sum_filter(images[fixed_rects[i].first], accum, fixed_rects[i].second); + + array2d temp(accum.nr(), accum.nc()); + for (unsigned long i = 0; i < movable_rects.size(); ++i) + { + const rectangle rect = movable_rects[i].second; + sum_filter_assign(images[movable_rects[i].first], temp, rect); + max_filter(temp, accum, window.width(), window.height(), 0); + } + + find_points_above_thresh(dets, accum, thresh, max_dets); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_iMAGE_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/scan_image_abstract.h b/ml/dlib/dlib/image_processing/scan_image_abstract.h new file mode 100644 index 000000000..fe2fc51ac --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image_abstract.h @@ -0,0 +1,227 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SCAN_iMAGE_ABSTRACT_Hh_ +#ifdef DLIB_SCAN_iMAGE_ABSTRACT_Hh_ + +#include +#include +#include "../algs.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + bool all_images_same_size ( + const image_array_type& images + ); + /*! + requires + - image_array_type == an implementation of array/array_kernel_abstract.h + - image_array_type::type == an image object that implements the interface + defined in dlib/image_processing/generic_image.h + ensures + - if (all elements of images have the same dimensions (i.e. + for all i and j: get_rect(images[i]) == get_rect(images[j]))) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + double sum_of_rects_in_images ( + const image_array_type& images, + const std::vector >& rects, + const point& position + ); + /*! + requires + - image_array_type == an implementation of array/array_kernel_abstract.h + - image_array_type::type == an image object that implements the interface + defined in dlib/image_processing/generic_image.h. Moreover, these objects must + contain a scalar pixel type (e.g. int rather than rgb_pixel) + - all_images_same_size(images) == true + - for all valid i: rects[i].first < images.size() + (i.e. all the rectangles must reference valid elements of images) + ensures + - returns the sum of the pixels inside the given rectangles. To be precise, + let RECT_SUM[i] = sum of pixels inside the rectangle translate_rect(rects[i].second, position) + from the image images[rects[i].first]. Then this function returns the + sum of RECT_SUM[i] for all the valid values of i. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + double sum_of_rects_in_images_movable_parts ( + const image_array_type& images, + const rectangle& window, + const std::vector >& fixed_rects, + const std::vector >& movable_rects, + const point& position + ); + /*! + requires + - image_array_type == an implementation of array/array_kernel_abstract.h + - image_array_type::type == an image object that implements the interface + defined in dlib/image_processing/generic_image.h. Moreover, these objects must + contain a scalar pixel type (e.g. int rather than rgb_pixel) + - all_images_same_size(images) == true + - center(window) == point(0,0) + - for all valid i: + - fixed_rects[i].first < images.size() + (i.e. all the rectangles must reference valid elements of images) + - for all valid i: + - movable_rects[i].first < images.size() + (i.e. all the rectangles must reference valid elements of images) + - center(movable_rects[i].second) == point(0,0) + ensures + - returns the sum of the pixels inside fixed_rects as well as the sum of the pixels + inside movable_rects when these latter rectangles are placed at their highest + scoring locations inside the given window. To be precise: + - let RECT_SUM(r,x) = sum of pixels inside the rectangle translate_rect(r.second, x) + from the image images[r.first]. + - let WIN_MAX(i) = The maximum value of RECT_SUM(movable_rects[i],X) when maximizing + over all the X such that translate_rect(window,position).contains(X) == true. + + - let TOTAL_FIXED == sum over all elements R in fixed_rects of: RECT_SUM(R,position) + - let TOTAL_MOVABLE == sum over all valid i of: max(WIN_MAX(i), 0) + + Then this function returns TOTAL_FIXED + TOTAL_MOVABLE. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void find_points_above_thresh ( + std::vector >& dets, + const image_type& img, + const double thresh, + const unsigned long max_dets + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h. Moreover, these it must contain a + scalar pixel type (e.g. int rather than rgb_pixel) + ensures + - #dets == a list of points from img which had pixel values >= thresh. + - Specifically, we have: + - #dets.size() <= max_dets + (note that dets is cleared before new detections are added by find_points_above_thresh()) + - for all valid i: + - #dets[i].first == img[#dets[i].second.y()][#dets[i].second.x()] + (i.e. the first field contains the value of the pixel at this detection location) + - #dets[i].first >= thresh + - if (there are more than max_dets locations that pass the above threshold test) then + - #dets == a random subsample of all the locations which passed the threshold + test. + - else + - #dets == all the points which passed the threshold test. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + void scan_image ( + std::vector >& dets, + const image_array_type& images, + const std::vector >& rects, + const double thresh, + const unsigned long max_dets + ); + /*! + requires + - image_array_type == an implementation of array/array_kernel_abstract.h + - image_array_type::type == an image object that implements the interface + defined in dlib/image_processing/generic_image.h. Moreover, these objects must + contain a scalar pixel type (e.g. int rather than rgb_pixel) + - images.size() > 0 + - rects.size() > 0 + - all_images_same_size(images) == true + - for all valid i: rects[i].first < images.size() + (i.e. all the rectangles must reference valid elements of images) + ensures + - slides the set of rectangles over the image space and reports the locations + which give a sum bigger than thresh. + - Specifically, we have: + - #dets.size() <= max_dets + (note that dets is cleared before new detections are added by scan_image()) + - for all valid i: + - #dets[i].first == sum_of_rects_in_images(images,rects,#dets[i].second) >= thresh + - if (there are more than max_dets locations that pass the threshold test) then + - #dets == a random subsample of all the locations which passed the threshold + test. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + void scan_image_movable_parts ( + std::vector >& dets, + const image_array_type& images, + const rectangle& window, + const std::vector >& fixed_rects, + const std::vector >& movable_rects, + const double thresh, + const unsigned long max_dets + ); + /*! + requires + - image_array_type == an implementation of array/array_kernel_abstract.h + - image_array_type::type == an image object that implements the interface + defined in dlib/image_processing/generic_image.h. Moreover, these objects must + contain a scalar pixel type (e.g. int rather than rgb_pixel) + - images.size() > 0 + - all_images_same_size(images) == true + - center(window) == point(0,0) + - window.area() > 0 + - for all valid i: + - fixed_rects[i].first < images.size() + (i.e. all the rectangles must reference valid elements of images) + - for all valid i: + - movable_rects[i].first < images.size() + (i.e. all the rectangles must reference valid elements of images) + - center(movable_rects[i].second) == point(0,0) + - movable_rects[i].second.area() > 0 + ensures + - Scans the given window over the images and reports the locations with a score bigger + than thresh. + - Specifically, we have: + - #dets.size() <= max_dets + (note that dets is cleared before new detections are added by scan_image_movable_parts()) + - for all valid i: + - #dets[i].first == sum_of_rects_in_images_movable_parts(images, + window, + fixed_rects, + movable_rects, + #dets[i].second) >= thresh + - if (there are more than max_dets locations that pass the above threshold test) then + - #dets == a random subsample of all the locations which passed the threshold + test. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_iMAGE_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/image_processing/scan_image_boxes.h b/ml/dlib/dlib/image_processing/scan_image_boxes.h new file mode 100644 index 000000000..f4549565c --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image_boxes.h @@ -0,0 +1,630 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SCAN_IMAGE_bOXES_Hh_ +#define DLIB_SCAN_IMAGE_bOXES_Hh_ + +#include "scan_image_boxes_abstract.h" +#include "../matrix.h" +#include "../geometry.h" +#include "../array2d.h" +#include +#include "../image_processing/full_object_detection.h" +#include "../image_transforms.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class default_box_generator + { + public: + template + void operator() ( + const image_type& img, + std::vector& rects + ) const + { + rects.clear(); + find_candidate_object_locations(img, rects); + } + }; + + inline void serialize(const default_box_generator&, std::ostream& ) {} + inline void deserialize(default_box_generator&, std::istream& ) {} + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator = default_box_generator + > + class scan_image_boxes : noncopyable + { + + public: + + typedef matrix feature_vector_type; + + typedef Feature_extractor_type feature_extractor_type; + typedef Box_generator box_generator; + + scan_image_boxes ( + ); + + template < + typename image_type + > + void load ( + const image_type& img + ); + + inline bool is_loaded_with_image ( + ) const; + + inline void copy_configuration( + const feature_extractor_type& fe + ); + + inline void copy_configuration( + const box_generator& bg + ); + + const box_generator& get_box_generator ( + ) const { return detect_boxes; } + + const Feature_extractor_type& get_feature_extractor ( + ) const { return feats; } + + inline void copy_configuration ( + const scan_image_boxes& item + ); + + inline long get_num_dimensions ( + ) const; + + unsigned long get_num_spatial_pyramid_levels ( + ) const; + + void set_num_spatial_pyramid_levels ( + unsigned long levels + ); + + void detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const; + + void get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const; + + full_object_detection get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& w + ) const; + + const rectangle get_best_matching_rect ( + const rectangle& rect + ) const; + /*! + requires + - is_loaded_with_image() == true + !*/ + + inline unsigned long get_num_detection_templates ( + ) const { return 1; } + + inline unsigned long get_num_movable_components_per_detection_template ( + ) const { return 0; } + + template + friend void serialize ( + const scan_image_boxes& item, + std::ostream& out + ); + + template + friend void deserialize ( + scan_image_boxes& item, + std::istream& in + ); + + private: + static bool compare_pair_rect ( + const std::pair& a, + const std::pair& b + ) + { + return a.first < b.first; + } + + void test_coordinate_transforms() + { + for (long x = -10; x <= 10; x += 10) + { + for (long y = -10; y <= 10; y += 10) + { + const rectangle rect = centered_rect(x,y,5,6); + rectangle a; + + a = feats.image_to_feat_space(rect); + if (a.width() > 10000000 || a.height() > 10000000 ) + { + DLIB_CASSERT(false, "The image_to_feat_space() routine is outputting rectangles of an implausibly " + << "\nlarge size. This means there is probably a bug in your feature extractor."); + } + a = feats.feat_to_image_space(rect); + if (a.width() > 10000000 || a.height() > 10000000 ) + { + DLIB_CASSERT(false, "The feat_to_image_space() routine is outputting rectangles of an implausibly " + << "\nlarge size. This means there is probably a bug in your feature extractor."); + } + } + } + + } + + static void add_grid_rects ( + std::vector& rects, + const rectangle& object_box, + unsigned int cells_x, + unsigned int cells_y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(cells_x > 0 && cells_y > 0, + "\t void add_grid_rects()" + << "\n\t The number of cells along a dimension can't be zero. " + << "\n\t cells_x: " << cells_x + << "\n\t cells_y: " << cells_y + ); + + const matrix_range_exp& x = linspace(object_box.left(), object_box.right(), cells_x+1); + const matrix_range_exp& y = linspace(object_box.top(), object_box.bottom(), cells_y+1); + + for (long j = 0; j+1 < y.size(); ++j) + { + for (long i = 0; i+1 < x.size(); ++i) + { + const dlib::vector tl(x(i),y(j)); + const dlib::vector br(x(i+1),y(j+1)); + rects.push_back(rectangle(tl,br)); + } + } + } + + void get_feature_extraction_regions ( + const rectangle& rect, + std::vector& regions + ) const + /*! + ensures + - #regions.size() is always the same number no matter what the input is. The + regions also have a consistent ordering. + - all the output rectangles are contained within rect. + !*/ + { + regions.clear(); + + for (unsigned int l = 1; l <= num_spatial_pyramid_levels; ++l) + { + const int cells = (int)std::pow(2.0, l-1.0); + add_grid_rects(regions, rect, cells, cells); + } + } + + unsigned int get_num_components_per_detection_template( + ) const + { + return (unsigned int)(std::pow(4.0,(double)num_spatial_pyramid_levels)-1)/3; + } + + feature_extractor_type feats; + std::vector search_rects; + bool loaded_with_image; + unsigned int num_spatial_pyramid_levels; + box_generator detect_boxes; + + const long box_sizedims; + const long box_maxsize; + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const scan_image_boxes& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + serialize(item.feats, out); + serialize(item.search_rects, out); + serialize(item.loaded_with_image, out); + serialize(item.num_spatial_pyramid_levels, out); + serialize(item.detect_boxes, out); + serialize(item.get_num_dimensions(), out); + } + +// ---------------------------------------------------------------------------------------- + + template + void deserialize ( + scan_image_boxes& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unsupported version found when deserializing a scan_image_boxes object."); + + deserialize(item.feats, in); + deserialize(item.search_rects, in); + deserialize(item.loaded_with_image, in); + deserialize(item.num_spatial_pyramid_levels, in); + deserialize(item.detect_boxes, in); + + // When developing some feature extractor, it's easy to accidentally change its + // number of dimensions and then try to deserialize data from an older version of + // your extractor into the current code. This check is here to catch that kind of + // user error. + long dims; + deserialize(dims, in); + if (item.get_num_dimensions() != dims) + throw serialization_error("Number of dimensions in serialized scan_image_boxes doesn't match the expected number."); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// scan_image_boxes member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + scan_image_boxes:: + scan_image_boxes ( + ) : + loaded_with_image(false), + num_spatial_pyramid_levels(3), + box_sizedims(20), + box_maxsize(1200) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + template < + typename image_type + > + void scan_image_boxes:: + load ( + const image_type& img + ) + { + feats.load(img); + detect_boxes(img, search_rects); + loaded_with_image = true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + bool scan_image_boxes:: + is_loaded_with_image ( + ) const + { + return loaded_with_image; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + void scan_image_boxes:: + copy_configuration( + const feature_extractor_type& fe + ) + { + test_coordinate_transforms(); + feats.copy_configuration(fe); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + void scan_image_boxes:: + copy_configuration( + const box_generator& bg + ) + { + detect_boxes = bg; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + void scan_image_boxes:: + copy_configuration ( + const scan_image_boxes& item + ) + { + feats.copy_configuration(item.feats); + detect_boxes = item.detect_boxes; + num_spatial_pyramid_levels = item.num_spatial_pyramid_levels; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + unsigned long scan_image_boxes:: + get_num_spatial_pyramid_levels ( + ) const + { + return num_spatial_pyramid_levels; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + void scan_image_boxes:: + set_num_spatial_pyramid_levels ( + unsigned long levels + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(levels > 0, + "\t void scan_image_boxes::set_num_spatial_pyramid_levels()" + << "\n\t Invalid inputs were given to this function " + << "\n\t levels: " << levels + << "\n\t this: " << this + ); + + + num_spatial_pyramid_levels = levels; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + long scan_image_boxes:: + get_num_dimensions ( + ) const + { + return feats.get_num_dimensions()*get_num_components_per_detection_template() + box_sizedims*2; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + void scan_image_boxes:: + detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_loaded_with_image() && + w.size() >= get_num_dimensions(), + "\t void scan_image_boxes::detect()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t w.size(): " << w.size() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t this: " << this + ); + + dets.clear(); + + array > saliency_images(get_num_components_per_detection_template()); + + array2d temp_img(feats.nr(), feats.nc()); + + // build saliency images + for (unsigned long i = 0; i < saliency_images.size(); ++i) + { + const unsigned long offset = 2*box_sizedims + feats.get_num_dimensions()*i; + + // make the basic saliency image for the i-th feature extraction region + for (long r = 0; r < feats.nr(); ++r) + { + for (long c = 0; c < feats.nc(); ++c) + { + const typename feature_extractor_type::descriptor_type& descriptor = feats(r,c); + + double sum = 0; + for (unsigned long k = 0; k < descriptor.size(); ++k) + { + sum += w(descriptor[k].first + offset)*descriptor[k].second; + } + temp_img[r][c] = sum; + } + } + + // now convert base saliency image into final integral image + saliency_images[i].load(temp_img); + } + + + // now search the saliency images + std::vector regions; + const rectangle bounds = get_rect(feats); + for (unsigned long i = 0; i < search_rects.size(); ++i) + { + const rectangle rect = feats.image_to_feat_space(search_rects[i]).intersect(bounds); + if (rect.is_empty()) + continue; + get_feature_extraction_regions(rect, regions); + double score = 0; + for (unsigned long k = 0; k < regions.size(); ++k) + { + score += saliency_images[k].get_sum_of_area(regions[k]); + } + const double width = search_rects[i].width(); + const double height = search_rects[i].height(); + + score += dot(linpiece(width, linspace(0, box_maxsize, box_sizedims+1)), rowm(w, range(0,box_sizedims-1))); + score += dot(linpiece(height, linspace(0, box_maxsize, box_sizedims+1)), rowm(w, range(box_sizedims,2*box_sizedims-1))); + + if (score >= thresh) + { + dets.push_back(std::make_pair(score, search_rects[i])); + } + } + + std::sort(dets.rbegin(), dets.rend(), compare_pair_rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + const rectangle scan_image_boxes:: + get_best_matching_rect ( + const rectangle& rect + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_loaded_with_image(), + "\t const rectangle scan_image_boxes::get_best_matching_rect()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t this: " << this + ); + + + double best_score = -1; + rectangle best_rect; + for (unsigned long i = 0; i < search_rects.size(); ++i) + { + const double score = (rect.intersect(search_rects[i])).area()/(double)(rect+search_rects[i]).area(); + if (score > best_score) + { + best_score = score; + best_rect = search_rects[i]; + } + } + return best_rect; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + full_object_detection scan_image_boxes:: + get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& /*w*/ + ) const + { + return full_object_detection(rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + void scan_image_boxes:: + get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_loaded_with_image() && + psi.size() >= get_num_dimensions() && + obj.num_parts() == 0, + "\t void scan_image_boxes::get_feature_vector()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t psi.size(): " << psi.size() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t obj.num_parts(): " << obj.num_parts() + << "\n\t this: " << this + ); + + + + const rectangle best_rect = get_best_matching_rect(obj.get_rect()); + const rectangle mapped_rect = feats.image_to_feat_space(best_rect).intersect(get_rect(feats)); + if (mapped_rect.is_empty()) + return; + + std::vector regions; + get_feature_extraction_regions(mapped_rect, regions); + + // pull features out of all the boxes in regions. + for (unsigned long j = 0; j < regions.size(); ++j) + { + const rectangle rect = regions[j]; + + const unsigned long template_region_id = j; + const unsigned long offset = box_sizedims*2 + feats.get_num_dimensions()*template_region_id; + for (long r = rect.top(); r <= rect.bottom(); ++r) + { + for (long c = rect.left(); c <= rect.right(); ++c) + { + const typename feature_extractor_type::descriptor_type& descriptor = feats(r,c); + for (unsigned long k = 0; k < descriptor.size(); ++k) + { + psi(descriptor[k].first + offset) += descriptor[k].second; + } + } + } + } + + const double width = best_rect.width(); + const double height = best_rect.height(); + set_rowm(psi, range(0,box_sizedims-1)) += linpiece(width, linspace(0, box_maxsize, box_sizedims+1)); + set_rowm(psi, range(box_sizedims,box_sizedims*2-1)) += linpiece(height, linspace(0, box_maxsize, box_sizedims+1)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_IMAGE_bOXES_Hh_ + + + diff --git a/ml/dlib/dlib/image_processing/scan_image_boxes_abstract.h b/ml/dlib/dlib/image_processing/scan_image_boxes_abstract.h new file mode 100644 index 000000000..e2f16aa76 --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image_boxes_abstract.h @@ -0,0 +1,394 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SCAN_IMAGE_bOXES_ABSTRACT_Hh_ +#ifdef DLIB_SCAN_IMAGE_bOXES_ABSTRACT_Hh_ + +#include "../matrix.h" +#include "../geometry.h" +#include "../image_processing.h" +#include "../array2d.h" +#include "full_object_detection_abstract.h" +#include "../image_transforms/segment_image_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class default_box_generator + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a function object that takes in an image and outputs a set of + candidate object locations. It is also the default box generator used by + the scan_image_boxes object defined below. + !*/ + + public: + + template + void operator() ( + const image_type& img, + std::vector& rects + ) const + /*! + ensures + - #rects == the set of candidate object locations which should be searched + inside img. That is, these are the rectangles which might contain + objects of interest within the given image. + !*/ + { + rects.clear(); + find_candidate_object_locations(img, rects); + } + }; + + inline void serialize (const default_box_generator&, std::ostream& ) {} + inline void deserialize( default_box_generator&, std::istream& ) {} + /*! + ensures + - provides serialization support. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator = default_box_generator + > + class scan_image_boxes : noncopyable + { + /*! + REQUIREMENTS ON Feature_extractor_type + - must be an object with an interface compatible with the hashed_feature_image + object defined in dlib/image_keypoint/hashed_feature_image_abstract.h or + with the nearest_neighbor_feature_image object defined in + dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h + + REQUIREMENTS ON Box_generator + - must be an object with an interface compatible with the + default_box_generator object defined at the top of this file. + + INITIAL VALUE + - get_num_spatial_pyramid_levels() == 3 + - is_loaded_with_image() == false + + WHAT THIS OBJECT REPRESENTS + This object is a tool for running a classifier over an image with the goal + of localizing each object present. The localization is in the form of the + bounding box around each object of interest. + + Unlike the scan_image_pyramid object which scans a fixed sized window over + an image pyramid, the scan_image_boxes tool allows you to define your own + list of "candidate object locations" which should be evaluated. This is + simply a list of rectangle objects which might contain objects of interest. + The scan_image_boxes object will then evaluate the classifier at each of + these locations and return the subset of rectangles which appear to have + objects in them. The candidate object location generation is provided by + the Box_generator that is passed in as a template argument. + + This object can also be understood as a general tool for implementing the + spatial pyramid models described in the paper: + Beyond Bags of Features: Spatial Pyramid Matching for Recognizing + Natural Scene Categories by Svetlana Lazebnik, Cordelia Schmid, + and Jean Ponce + + + The classifiers used by this object have three parts: + 1. The underlying feature extraction provided by Feature_extractor_type + objects, which associate a vector with each location in an image. + + 2. A rule for extracting a feature vector from a candidate object + location. In this object we use the spatial pyramid matching method. + This means we cut an object's detection window into a set of "feature + extraction regions" and extract a bag-of-words vector from each + before finally concatenating them to form the final feature vector + representing the entire object window. The set of feature extraction + regions can be configured by the user by calling + set_num_spatial_pyramid_levels(). To be a little more precise, the + feature vector for a candidate object window is defined as follows: + - Let N denote the number of feature extraction zones. + - Let M denote the dimensionality of the vectors output by + Feature_extractor_type objects. + - Let F(i) == the M dimensional vector which is the sum of all + vectors given by our Feature_extractor_type object inside the + i-th feature extraction zone. So this is notionally a + bag-of-words vector from the i-th zone. + - Then the feature vector for an object window is an M*N + dimensional vector [F(1) F(2) F(3) ... F(N)] (i.e. it is a + concatenation of the N vectors). This feature vector can be + thought of as a collection of N bags-of-words, each bag coming + from a spatial location determined by one of the feature + extraction zones. + + 3. A weight vector and a threshold value. The dot product between the + weight vector and the feature vector for a candidate object location + gives the score of the location. If this score is greater than the + threshold value then the candidate object location is output as a + detection. + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be + protected by a mutex lock except for the case where you are copying the + configuration (via copy_configuration()) of a scan_image_boxes object to + many other threads. In this case, it is safe to copy the configuration of + a shared object so long as no other operations are performed on it. + !*/ + + public: + + typedef matrix feature_vector_type; + + typedef Feature_extractor_type feature_extractor_type; + typedef Box_generator box_generator; + + scan_image_boxes ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template < + typename image_type + > + void load ( + const image_type& img + ); + /*! + requires + - image_type must be a type with the following properties: + - image_type objects can be loaded into Feature_extractor_type + objects via Feature_extractor_type::load(). + - image_type objects can be passed to the first argument of + Box_generator::operator() + ensures + - #is_loaded_with_image() == true + - This object is ready to run a classifier over img to detect object + locations. Call detect() to do this. + !*/ + + bool is_loaded_with_image ( + ) const; + /*! + ensures + - returns true if this object has been loaded with an image to process and + false otherwise. + !*/ + + const feature_extractor_type& get_feature_extractor ( + ) const; + /*! + ensures + - returns a const reference to the feature_extractor_type object used + internally for local feature extraction. + !*/ + + void copy_configuration( + const feature_extractor_type& fe + ); + /*! + ensures + - This function performs the equivalent of + get_feature_extractor().copy_configuration(fe) (i.e. this function allows + you to configure the parameters of the underlying feature extractor used + by a scan_image_boxes object) + !*/ + + void copy_configuration( + const box_generator& bg + ); + /*! + ensures + - #get_box_generator() == bg + (i.e. this function allows you to configure the parameters of the + underlying box generator used by a scan_image_boxes object) + !*/ + + const box_generator& get_box_generator ( + ) const; + /*! + ensures + - returns the box_generator used by this object to generate candidate + object locations. + !*/ + + void copy_configuration ( + const scan_image_boxes& item + ); + /*! + ensures + - Copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two scan_image_boxes + objects S1 and S2, the following sequence of instructions should always + result in both of them having the exact same state: + S2.copy_configuration(S1); + S1.load(img); + S2.load(img); + !*/ + + long get_num_dimensions ( + ) const; + /*! + ensures + - returns the number of dimensions in the feature vector for a candidate + object location. This value is the dimensionality of the underlying + feature vectors produced by Feature_extractor_type times the number of + feature extraction regions used. Note that the number of feature + extraction regions used is a function of + get_num_spatial_pyramid_levels(). + !*/ + + unsigned long get_num_spatial_pyramid_levels ( + ) const; + /*! + ensures + - returns the number of layers in the spatial pyramid. For example, if + this function returns 1 then it means we use a simple bag-of-words + representation over the whole object window. If it returns 2 then it + means the feature representation is the concatenation of 5 bag-of-words + vectors, one from the entire object window and 4 others from 4 different + parts of the object window. If it returns 3 then there are 1+4+16 + bag-of-words vectors concatenated together in the feature representation, + and so on. + !*/ + + void set_num_spatial_pyramid_levels ( + unsigned long levels + ); + /*! + requires + - levels > 0 + ensures + - #get_num_spatial_pyramid_levels() == levels + !*/ + + void detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const; + /*! + requires + - w.size() >= get_num_dimensions() + - is_loaded_with_image() == true + ensures + - Scans over all the candidate object locations as discussed in the WHAT + THIS OBJECT REPRESENTS section and stores all detections into #dets. + - for all valid i: + - #dets[i].second == The candidate object location which produced this + detection. This rectangle gives the location of the detection. + - #dets[i].first == The score for this detection. This value is equal + to dot(w, feature vector for this candidate object location). + - #dets[i].first >= thresh + - #dets will be sorted in descending order. + (i.e. #dets[i].first >= #dets[j].first for all i, and j>i) + - Elements of w beyond index get_num_dimensions()-1 are ignored. I.e. only + the first get_num_dimensions() are used. + - Note that no form of non-max suppression is performed. If a locations + has a score >= thresh then it is reported in #dets. + !*/ + + void get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const; + /*! + requires + - obj.num_parts() == 0 + - is_loaded_with_image() == true + - psi.size() >= get_num_dimensions() + (i.e. psi must have preallocated its memory before this function is called) + ensures + - This function allows you to determine the feature vector used for a + candidate object location output from detect(). Note that this vector is + added to psi. Note also that you must use get_full_object_detection() to + convert a rectangle from detect() into the needed full_object_detection. + - The dimensionality of the vector added to psi is get_num_dimensions(). This + means that elements of psi after psi(get_num_dimensions()-1) are not modified. + - Since scan_image_boxes only searches a limited set of object locations, + not all possible rectangles can be output by detect(). So in the case + where obj.get_rect() could not arise from a call to detect(), this + function will map obj.get_rect() to the nearest possible rectangle and + then add the feature vector for the mapped rectangle into #psi. + - get_best_matching_rect(obj.get_rect()) == the rectangle obj.get_rect() + gets mapped to for feature extraction. + !*/ + + full_object_detection get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& w + ) const; + /*! + ensures + - returns full_object_detection(rect) + (This function is here only for compatibility with the scan_image_pyramid + object) + !*/ + + const rectangle get_best_matching_rect ( + const rectangle& rect + ) const; + /*! + requires + - is_loaded_with_image() == true + ensures + - Since scan_image_boxes only searches a limited set of object locations, + not all possible rectangles can be represented. Therefore, this function + allows you to supply a rectangle and obtain the nearest possible + candidate object location rectangle. + !*/ + + unsigned long get_num_detection_templates ( + ) const { return 1; } + /*! + ensures + - returns 1. Note that this function is here only for compatibility with + the scan_image_pyramid object. Notionally, its return value indicates + that a scan_image_boxes object is always ready to detect objects once + an image has been loaded. + !*/ + + unsigned long get_num_movable_components_per_detection_template ( + ) const { return 0; } + /*! + ensures + - returns 0. Note that this function is here only for compatibility with + the scan_image_pyramid object. Its return value means that this object + does not support using movable part models. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type, + typename Box_generator + > + void serialize ( + const scan_image_boxes& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename Feature_extractor_type, + typename Box_generator + > + void deserialize ( + scan_image_boxes& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_IMAGE_bOXES_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_processing/scan_image_custom.h b/ml/dlib/dlib/image_processing/scan_image_custom.h new file mode 100644 index 000000000..29b969fca --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image_custom.h @@ -0,0 +1,401 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SCAN_IMAGE_CuSTOM_Hh_ +#define DLIB_SCAN_IMAGE_CuSTOM_Hh_ + +#include "scan_image_custom_abstract.h" +#include "../matrix.h" +#include "../geometry.h" +#include +#include "../image_processing/full_object_detection.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + class scan_image_custom : noncopyable + { + + public: + + typedef matrix feature_vector_type; + typedef Feature_extractor_type feature_extractor_type; + + scan_image_custom ( + ); + + template < + typename image_type + > + void load ( + const image_type& img + ); + + inline bool is_loaded_with_image ( + ) const; + + inline void copy_configuration( + const feature_extractor_type& fe + ); + + const Feature_extractor_type& get_feature_extractor ( + ) const { return feats; } + + inline void copy_configuration ( + const scan_image_custom& item + ); + + inline long get_num_dimensions ( + ) const; + + void detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const; + + void get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const; + + full_object_detection get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& w + ) const; + + const rectangle get_best_matching_rect ( + const rectangle& rect + ) const; + + inline unsigned long get_num_detection_templates ( + ) const { return 1; } + + inline unsigned long get_num_movable_components_per_detection_template ( + ) const { return 0; } + + template + friend void serialize ( + const scan_image_custom& item, + std::ostream& out + ); + + template + friend void deserialize ( + scan_image_custom& item, + std::istream& in + ); + + private: + static bool compare_pair_rect ( + const std::pair& a, + const std::pair& b + ) + { + return a.first < b.first; + } + + + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST( + has_compute_object_score, + double, + compute_object_score, + ( const matrix& w, const rectangle& obj) const + ); + + template + typename enable_if >::type compute_all_rect_scores ( + const fe_type& feats, + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const + { + for (unsigned long i = 0; i < search_rects.size(); ++i) + { + const double score = feats.compute_object_score(w, search_rects[i]); + if (score >= thresh) + { + dets.push_back(std::make_pair(score, search_rects[i])); + } + } + } + + template + typename disable_if >::type compute_all_rect_scores ( + const fe_type& feats, + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const + { + matrix psi(w.size()); + psi = 0; + double prev_dot = 0; + for (unsigned long i = 0; i < search_rects.size(); ++i) + { + // Reset these back to zero every so often to avoid the accumulation of + // rounding error. Note that the only reason we do this loop in this + // complex way is to avoid needing to zero the psi vector every iteration. + if ((i%500) == 499) + { + psi = 0; + prev_dot = 0; + } + + feats.get_feature_vector(search_rects[i], psi); + const double cur_dot = dot(psi, w); + const double score = cur_dot - prev_dot; + if (score >= thresh) + { + dets.push_back(std::make_pair(score, search_rects[i])); + } + prev_dot = cur_dot; + } + } + + + feature_extractor_type feats; + std::vector search_rects; + bool loaded_with_image; + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const scan_image_custom& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + serialize(item.feats, out); + serialize(item.search_rects, out); + serialize(item.loaded_with_image, out); + serialize(item.get_num_dimensions(), out); + } + +// ---------------------------------------------------------------------------------------- + + template + void deserialize ( + scan_image_custom& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unsupported version found when deserializing a scan_image_custom object."); + + deserialize(item.feats, in); + deserialize(item.search_rects, in); + deserialize(item.loaded_with_image, in); + + // When developing some feature extractor, it's easy to accidentally change its + // number of dimensions and then try to deserialize data from an older version of + // your extractor into the current code. This check is here to catch that kind of + // user error. + long dims; + deserialize(dims, in); + if (item.get_num_dimensions() != dims) + throw serialization_error("Number of dimensions in serialized scan_image_custom doesn't match the expected number."); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// scan_image_custom member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + scan_image_custom:: + scan_image_custom ( + ) : + loaded_with_image(false) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + template < + typename image_type + > + void scan_image_custom:: + load ( + const image_type& img + ) + { + feats.load(img, search_rects); + loaded_with_image = true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + bool scan_image_custom:: + is_loaded_with_image ( + ) const + { + return loaded_with_image; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + void scan_image_custom:: + copy_configuration( + const feature_extractor_type& fe + ) + { + feats.copy_configuration(fe); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + void scan_image_custom:: + copy_configuration ( + const scan_image_custom& item + ) + { + feats.copy_configuration(item.feats); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + long scan_image_custom:: + get_num_dimensions ( + ) const + { + return feats.get_num_dimensions(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + void scan_image_custom:: + detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_loaded_with_image() && + w.size() >= get_num_dimensions(), + "\t void scan_image_custom::detect()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t w.size(): " << w.size() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t this: " << this + ); + + dets.clear(); + compute_all_rect_scores(feats, w,dets,thresh); + std::sort(dets.rbegin(), dets.rend(), compare_pair_rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + const rectangle scan_image_custom:: + get_best_matching_rect ( + const rectangle& rect + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_loaded_with_image(), + "\t const rectangle scan_image_custom::get_best_matching_rect()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t this: " << this + ); + + + double best_score = -1; + rectangle best_rect; + for (unsigned long i = 0; i < search_rects.size(); ++i) + { + const double score = (rect.intersect(search_rects[i])).area()/(double)(rect+search_rects[i]).area(); + if (score > best_score) + { + best_score = score; + best_rect = search_rects[i]; + } + } + return best_rect; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + full_object_detection scan_image_custom:: + get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& /*w*/ + ) const + { + return full_object_detection(rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + void scan_image_custom:: + get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_loaded_with_image() && + psi.size() >= get_num_dimensions() && + obj.num_parts() == 0, + "\t void scan_image_custom::get_feature_vector()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t psi.size(): " << psi.size() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t obj.num_parts(): " << obj.num_parts() + << "\n\t this: " << this + ); + + + feats.get_feature_vector(get_best_matching_rect(obj.get_rect()), psi); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_IMAGE_CuSTOM_Hh_ + diff --git a/ml/dlib/dlib/image_processing/scan_image_custom_abstract.h b/ml/dlib/dlib/image_processing/scan_image_custom_abstract.h new file mode 100644 index 000000000..ca3ba402a --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image_custom_abstract.h @@ -0,0 +1,390 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SCAN_IMAGE_CuSTOM_ABSTRACT_Hh_ +#ifdef DLIB_SCAN_IMAGE_CuSTOM_ABSTRACT_Hh_ + +#include +#include "../matrix.h" +#include "../geometry.h" +#include "../image_processing/full_object_detection_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class example_feature_extractor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines the interface a feature extractor must implement if it + is to be used with the scan_image_custom object defined at the bottom of + this file. + + In this case, the purpose of a feature extractor is to associated a + complete feature vector with each rectangle in an image. In particular, + each rectangle is scored by taking the dot product between this feature + vector and a weight vector. If this score is greater than a threshold then + the rectangle is output as a detection. + !*/ + + public: + + template < + typename image_type + > + void load ( + const image_type& image, + std::vector& candidate_objects + ); + /*! + ensures + - Loads the given image into this feature extractor. This means that + subsequent calls to get_feature_vector() will return the feature vector + corresponding to locations in the image given to load(). + - #candidate_objects == a set of bounding boxes in the given image that + might contain objects of interest. These are the locations that will be + checked for the presents of objects when this feature extractor is used + with the scan_image_custom object. + + !*/ + + void copy_configuration ( + const feature_extractor& item + ); + /*! + ensures + - Copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two + feature extractor objects S1 and S2, the following sequence of + instructions should always result in both of them having the exact same + state: + S2.copy_configuration(S1); + S1.load(img, temp); + S2.load(img, temp); + !*/ + + unsigned long get_num_dimensions ( + ) const; + /*! + ensures + - returns the dimensionality of the feature vectors output by this object. + !*/ + + void get_feature_vector ( + const rectangle& obj, + matrix& psi + ) const; + /*! + requires + - psi.size() >= get_num_dimensions() + (i.e. psi must have preallocated its memory before this function is called) + ensures + - This function computes the feature vector associated with the given rectangle + in obj. This rectangle is interpreted as a bounding box within the last image + given to this->load() and a feature vector describing that bounding box is + output into psi. + - The feature vector is added into psi. That is, it does not overwrite the + previous contents of psi, but instead, it adds the vector to psi. + - The dimensionality of the vector added to psi is get_num_dimensions(). This + means that elements of psi after psi(get_num_dimensions()-1) are not modified. + - #psi.size() == psi.size() + (i.e. this function does not change the size of the psi vector) + !*/ + + double compute_object_score ( + const matrix& w, + const rectangle& obj + ) const; + /*! + requires + - w.size() >= get_num_dimensions() + ensures + - This function returns the dot product between the feature vector for + object box obj and the given w vector. That is, this function computes + the same number as the following code snippet: + matrix psi(w.size()); + psi = 0; + get_feature_vector(obj, psi); + return dot(psi, w); + The point of the compute_object_score() routine is to compute this dot + product in a much more efficient way than directly calling + get_feature_vector() and dot(). Therefore, compute_object_score() is an + optional function. If you can't think of a faster way to compute these + scores then do not implement compute_object_score() and the + scan_image_custom object will simply compute these scores for you. + However, it is often the case that there is something clever you can do + to make this computation faster. If that is the case, then you can + provide an implementation of this function with your feature extractor + and then scan_image_custom will use it instead of using the default + calculation method shown in the above code snippet. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + void serialize( + const feature_extractor& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize( + feature_extractor& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename Feature_extractor_type + > + class scan_image_custom : noncopyable + { + /*! + REQUIREMENTS ON Feature_extractor_type + - must be an object with an interface compatible with the + example_feature_extractor defined at the top of this file. + + INITIAL VALUE + - is_loaded_with_image() == false + + WHAT THIS OBJECT REPRESENTS + This object is a tool for running a classifier over an image with the goal + of localizing each object present. The localization is in the form of the + bounding box around each object of interest. + + Unlike the scan_image_pyramid and scan_image_boxes objects, this image + scanner delegates all the work of constructing the object feature vector to + its Feature_extractor_type template argument. That is, scan_image_custom + simply asks the supplied feature extractor what boxes in the image we + should investigate and then asks the feature extractor for the complete + feature vector for each box. That is, scan_image_custom does not apply any + kind of pyramiding or other higher level processing to the features coming + out of the feature extractor. That means that when you use + scan_image_custom it is completely up to you to define the feature vector + used with each image box. + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be + protected by a mutex lock except for the case where you are copying the + configuration (via copy_configuration()) of a scan_image_custom object to + many other threads. In this case, it is safe to copy the configuration of + a shared object so long as no other operations are performed on it. + !*/ + + public: + + typedef matrix feature_vector_type; + typedef Feature_extractor_type feature_extractor_type; + + scan_image_custom ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template < + typename image_type + > + void load ( + const image_type& img + ); + /*! + requires + - image_type must be a type with the following properties: + - image_type objects can be loaded into Feature_extractor_type + objects via Feature_extractor_type::load(). + ensures + - #is_loaded_with_image() == true + - Calls get_feature_extractor().load() on the given image. That is, we + will have loaded the image into the feature extractor in this + scan_image_custom object. We will also have stored the candidate + object locations generated by the feature extractor and will scan + over them when this->detect() is called. + - This object is ready to run a classifier over img to detect object + locations. Call detect() to do this. + !*/ + + bool is_loaded_with_image ( + ) const; + /*! + ensures + - returns true if this object has been loaded with an image to process and + false otherwise. + !*/ + + const feature_extractor_type& get_feature_extractor ( + ) const; + /*! + ensures + - returns a const reference to the feature_extractor_type object used + internally for local feature extraction. + !*/ + + void copy_configuration( + const feature_extractor_type& fe + ); + /*! + ensures + - This function performs the equivalent of + get_feature_extractor().copy_configuration(fe) (i.e. this function allows + you to configure the parameters of the underlying feature extractor used + by a scan_image_custom object) + !*/ + + void copy_configuration ( + const scan_image_custom& item + ); + /*! + ensures + - Copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two + scan_image_custom objects S1 and S2, the following sequence of + instructions should always result in both of them having the exact same + state: + S2.copy_configuration(S1); + S1.load(img); + S2.load(img); + !*/ + + long get_num_dimensions ( + ) const; + /*! + ensures + - returns the number of dimensions in the feature vector for a candidate + object location. That is, this function returns get_feature_extractor().get_num_dimensions(). + !*/ + + void detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const; + /*! + requires + - w.size() >= get_num_dimensions() + - is_loaded_with_image() == true + ensures + - Scans over all the candidate object locations produced by the feature + extractor during image loading and stores all detections into #dets. + - for all valid i: + - #dets[i].second == The candidate object location which produced this + detection. This rectangle gives the location of the detection. + - #dets[i].first == The score for this detection. This value is equal + to dot(w, feature vector for this candidate object location). + - #dets[i].first >= thresh + - #dets will be sorted in descending order. + (i.e. #dets[i].first >= #dets[j].first for all i, and j>i) + - Elements of w beyond index get_num_dimensions()-1 are ignored. I.e. only + the first get_num_dimensions() are used. + - Note that no form of non-max suppression is performed. If a locations + has a score >= thresh then it is reported in #dets. + !*/ + + void get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const; + /*! + requires + - obj.num_parts() == 0 + - is_loaded_with_image() == true + - psi.size() >= get_num_dimensions() + (i.e. psi must have preallocated its memory before this function is called) + ensures + - This function allows you to determine the feature vector used for a + candidate object location output from detect(). Note that this vector is + added to psi. Note also that you must use get_full_object_detection() to + convert a rectangle from detect() into the needed full_object_detection. + - The dimensionality of the vector added to psi is get_num_dimensions(). This + means that elements of psi after psi(get_num_dimensions()-1) are not modified. + - Since scan_image_custom only searches a limited set of object locations, + not all possible rectangles can be output by detect(). So in the case + where obj.get_rect() could not arise from a call to detect(), this + function will map obj.get_rect() to the nearest possible rectangle and + then add the feature vector for the mapped rectangle into #psi. + - get_best_matching_rect(obj.get_rect()) == the rectangle obj.get_rect() + gets mapped to for feature extraction. + !*/ + + full_object_detection get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& w + ) const; + /*! + ensures + - returns full_object_detection(rect) + (This function is here only for compatibility with the scan_image_pyramid + object) + !*/ + + const rectangle get_best_matching_rect ( + const rectangle& rect + ) const; + /*! + requires + - is_loaded_with_image() == true + ensures + - Since scan_image_custom only searches a limited set of object locations, + not all possible rectangles can be represented. Therefore, this function + allows you to supply a rectangle and obtain the nearest possible + candidate object location rectangle. + !*/ + + unsigned long get_num_detection_templates ( + ) const { return 1; } + /*! + ensures + - returns 1. Note that this function is here only for compatibility with + the scan_image_pyramid object. Notionally, its return value indicates + that a scan_image_custom object is always ready to detect objects once an + image has been loaded. + !*/ + + unsigned long get_num_movable_components_per_detection_template ( + ) const { return 0; } + /*! + ensures + - returns 0. Note that this function is here only for compatibility with + the scan_image_pyramid object. Its return value means that this object + does not support using movable part models. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const scan_image_custom& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template + void deserialize ( + scan_image_custom& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_IMAGE_CuSTOM_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_processing/scan_image_pyramid.h b/ml/dlib/dlib/image_processing/scan_image_pyramid.h new file mode 100644 index 000000000..455f1a649 --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image_pyramid.h @@ -0,0 +1,1101 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SCAN_IMaGE_PYRAMID_Hh_ +#define DLIB_SCAN_IMaGE_PYRAMID_Hh_ + +#include "scan_image_pyramid_abstract.h" +#include "../matrix.h" +#include "../geometry.h" +#include "scan_image.h" +#include "../array2d.h" +#include +#include "full_object_detection.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + class scan_image_pyramid : noncopyable + { + + public: + + typedef matrix feature_vector_type; + + typedef Pyramid_type pyramid_type; + typedef Feature_extractor_type feature_extractor_type; + + scan_image_pyramid ( + ); + + template < + typename image_type + > + void load ( + const image_type& img + ); + + inline bool is_loaded_with_image ( + ) const; + + inline void copy_configuration( + const feature_extractor_type& fe + ); + + inline void copy_configuration ( + const scan_image_pyramid& item + ); + + const Feature_extractor_type& get_feature_extractor ( + ) const { return feats_config; } + + void add_detection_template ( + const rectangle& object_box, + const std::vector& stationary_feature_extraction_regions, + const std::vector& movable_feature_extraction_regions + ); + + void add_detection_template ( + const rectangle& object_box, + const std::vector& stationary_feature_extraction_regions + ); + + inline unsigned long get_num_detection_templates ( + ) const; + + inline unsigned long get_num_movable_components_per_detection_template ( + ) const; + + inline unsigned long get_num_stationary_components_per_detection_template ( + ) const; + + inline unsigned long get_num_components_per_detection_template ( + ) const; + + inline long get_num_dimensions ( + ) const; + + unsigned long get_max_pyramid_levels ( + ) const; + + void set_max_pyramid_levels ( + unsigned long max_levels + ); + + inline unsigned long get_max_detections_per_template ( + ) const; + + void set_min_pyramid_layer_size ( + unsigned long width, + unsigned long height + ); + + inline unsigned long get_min_pyramid_layer_width ( + ) const; + + inline unsigned long get_min_pyramid_layer_height ( + ) const; + + void set_max_detections_per_template ( + unsigned long max_dets + ); + + void detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const; + + void get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const; + + full_object_detection get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& w + ) const; + + const rectangle get_best_matching_rect ( + const rectangle& rect + ) const; + + template + friend void serialize ( + const scan_image_pyramid& item, + std::ostream& out + ); + + template + friend void deserialize ( + scan_image_pyramid& item, + std::istream& in + ); + + private: + static bool compare_pair_rect ( + const std::pair& a, + const std::pair& b + ) + { + return a.first < b.first; + } + + struct detection_template + { + rectangle object_box; // always centered at (0,0) + std::vector rects; // template with respect to (0,0) + std::vector movable_rects; + }; + + friend void serialize(const detection_template& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.object_box, out); + serialize(item.rects, out); + serialize(item.movable_rects, out); + } + friend void deserialize(detection_template& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing a dlib::scan_image_pyramid::detection_template object."); + + deserialize(item.object_box, in); + deserialize(item.rects, in); + deserialize(item.movable_rects, in); + } + + void get_mapped_rect_and_metadata ( + const unsigned long number_pyramid_levels, + rectangle rect, + rectangle& mapped_rect, + detection_template& best_template, + rectangle& object_box, + unsigned long& best_level, + unsigned long& detection_template_idx + ) const; + + double get_match_score ( + rectangle r1, + rectangle r2 + ) const + { + // make the rectangles overlap as much as possible before computing the match score. + r1 = move_rect(r1, r2.tl_corner()); + return (r1.intersect(r2).area())/(double)(r1 + r2).area(); + } + + void test_coordinate_transforms() + { + for (long x = -10; x <= 10; x += 10) + { + for (long y = -10; y <= 10; y += 10) + { + const rectangle rect = centered_rect(x,y,5,6); + rectangle a; + + a = feats_config.image_to_feat_space(rect); + if (a.width() > 10000000 || a.height() > 10000000 ) + { + DLIB_CASSERT(false, "The image_to_feat_space() routine is outputting rectangles of an implausibly " + << "\nlarge size. This means there is probably a bug in your feature extractor."); + } + a = feats_config.feat_to_image_space(rect); + if (a.width() > 10000000 || a.height() > 10000000 ) + { + DLIB_CASSERT(false, "The feat_to_image_space() routine is outputting rectangles of an implausibly " + << "\nlarge size. This means there is probably a bug in your feature extractor."); + } + } + } + + } + + feature_extractor_type feats_config; // just here to hold configuration. use it to populate the feats elements. + array feats; + std::vector det_templates; + unsigned long max_dets_per_template; + unsigned long max_pyramid_levels; + unsigned long min_pyramid_layer_width; + unsigned long min_pyramid_layer_height; + + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const scan_image_pyramid& item, + std::ostream& out + ) + { + int version = 3; + serialize(version, out); + serialize(item.feats_config, out); + serialize(item.feats, out); + serialize(item.det_templates, out); + serialize(item.max_dets_per_template, out); + serialize(item.max_pyramid_levels, out); + serialize(item.min_pyramid_layer_width, out); + serialize(item.min_pyramid_layer_height, out); + serialize(item.get_num_dimensions(), out); + } + +// ---------------------------------------------------------------------------------------- + + template + void deserialize ( + scan_image_pyramid& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 3) + throw serialization_error("Unsupported version found when deserializing a scan_image_pyramid object."); + + deserialize(item.feats_config, in); + deserialize(item.feats, in); + deserialize(item.det_templates, in); + deserialize(item.max_dets_per_template, in); + deserialize(item.max_pyramid_levels, in); + deserialize(item.min_pyramid_layer_width, in); + deserialize(item.min_pyramid_layer_height, in); + + // When developing some feature extractor, it's easy to accidentally change its + // number of dimensions and then try to deserialize data from an older version of + // your extractor into the current code. This check is here to catch that kind of + // user error. + long dims; + deserialize(dims, in); + if (item.get_num_dimensions() != dims) + throw serialization_error("Number of dimensions in serialized scan_image_pyramid doesn't match the expected number."); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// scan_image_pyramid member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + scan_image_pyramid:: + scan_image_pyramid ( + ) : + max_dets_per_template(10000), + max_pyramid_levels(1000), + min_pyramid_layer_width(20), + min_pyramid_layer_height(20) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + template < + typename image_type + > + void scan_image_pyramid:: + load ( + const image_type& img + ) + { + unsigned long levels = 0; + rectangle rect = get_rect(img); + + // figure out how many pyramid levels we should be using based on the image size + pyramid_type pyr; + do + { + rect = pyr.rect_down(rect); + ++levels; + } while (rect.width() >= min_pyramid_layer_width && rect.height() >= min_pyramid_layer_height && + levels < max_pyramid_levels); + + if (feats.max_size() < levels) + feats.set_max_size(levels); + feats.set_size(levels); + + for (unsigned long i = 0; i < feats.size(); ++i) + feats[i].copy_configuration(feats_config); + + // build our feature pyramid + feats[0].load(img); + if (feats.size() > 1) + { + image_type temp1, temp2; + pyr(img, temp1); + feats[1].load(temp1); + swap(temp1,temp2); + + for (unsigned long i = 2; i < feats.size(); ++i) + { + pyr(temp2, temp1); + feats[i].load(temp1); + swap(temp1,temp2); + } + } + + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + unsigned long scan_image_pyramid:: + get_max_detections_per_template ( + ) const + { + return max_dets_per_template; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void scan_image_pyramid:: + set_max_detections_per_template ( + unsigned long max_dets + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_dets > 0 , + "\t void scan_image_pyramid::set_max_detections_per_template()" + << "\n\t The max number of possible detections can't be zero. " + << "\n\t max_dets: " << max_dets + << "\n\t this: " << this + ); + + max_dets_per_template = max_dets; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + bool scan_image_pyramid:: + is_loaded_with_image ( + ) const + { + return feats.size() != 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void scan_image_pyramid:: + copy_configuration( + const feature_extractor_type& fe + ) + { + test_coordinate_transforms(); + feats_config.copy_configuration(fe); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void scan_image_pyramid:: + copy_configuration ( + const scan_image_pyramid& item + ) + { + feats_config.copy_configuration(item.feats_config); + det_templates = item.det_templates; + max_dets_per_template = item.max_dets_per_template; + max_pyramid_levels = item.max_pyramid_levels; + min_pyramid_layer_width = item.min_pyramid_layer_width; + min_pyramid_layer_height = item.min_pyramid_layer_height; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void scan_image_pyramid:: + add_detection_template ( + const rectangle& object_box, + const std::vector& stationary_feature_extraction_regions, + const std::vector& movable_feature_extraction_regions + ) + { +#ifdef ENABLE_ASSERTS + // make sure requires clause is not broken + DLIB_ASSERT((get_num_detection_templates() == 0 || + (get_num_stationary_components_per_detection_template() == stationary_feature_extraction_regions.size() && + get_num_movable_components_per_detection_template() == movable_feature_extraction_regions.size())) && + center(object_box) == point(0,0), + "\t void scan_image_pyramid::add_detection_template()" + << "\n\t The number of rects in this new detection template doesn't match " + << "\n\t the number in previous detection templates." + << "\n\t get_num_stationary_components_per_detection_template(): " << get_num_stationary_components_per_detection_template() + << "\n\t stationary_feature_extraction_regions.size(): " << stationary_feature_extraction_regions.size() + << "\n\t get_num_movable_components_per_detection_template(): " << get_num_movable_components_per_detection_template() + << "\n\t movable_feature_extraction_regions.size(): " << movable_feature_extraction_regions.size() + << "\n\t this: " << this + ); + + for (unsigned long i = 0; i < movable_feature_extraction_regions.size(); ++i) + { + DLIB_ASSERT(center(movable_feature_extraction_regions[i]) == point(0,0), + "Invalid inputs were given to this function." + << "\n\t center(movable_feature_extraction_regions["< + void scan_image_pyramid:: + add_detection_template ( + const rectangle& object_box, + const std::vector& stationary_feature_extraction_regions + ) + { + // an empty set of movable feature regions + const std::vector movable_feature_extraction_regions; + add_detection_template(object_box, stationary_feature_extraction_regions, + movable_feature_extraction_regions); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + unsigned long scan_image_pyramid:: + get_num_detection_templates ( + ) const + { + return det_templates.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + unsigned long scan_image_pyramid:: + get_num_stationary_components_per_detection_template ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_num_detection_templates() > 0 , + "\t unsigned long scan_image_pyramid::get_num_stationary_components_per_detection_template()" + << "\n\t You need to give some detection templates before calling this function. " + << "\n\t get_num_detection_templates(): " << get_num_detection_templates() + << "\n\t this: " << this + ); + + return det_templates[0].rects.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + unsigned long scan_image_pyramid:: + get_num_movable_components_per_detection_template ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_num_detection_templates() > 0 , + "\t unsigned long scan_image_pyramid::get_num_movable_components_per_detection_template()" + << "\n\t You need to give some detection templates before calling this function. " + << "\n\t get_num_detection_templates(): " << get_num_detection_templates() + << "\n\t this: " << this + ); + + return det_templates[0].movable_rects.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + unsigned long scan_image_pyramid:: + get_num_components_per_detection_template ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_num_detection_templates() > 0 , + "\t unsigned long scan_image_pyramid::get_num_components_per_detection_template()" + << "\n\t You need to give some detection templates before calling this function. " + << "\n\t get_num_detection_templates(): " << get_num_detection_templates() + << "\n\t this: " << this + ); + + return get_num_movable_components_per_detection_template() + + get_num_stationary_components_per_detection_template(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + long scan_image_pyramid:: + get_num_dimensions ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_num_detection_templates() > 0 , + "\t long scan_image_pyramid::get_num_dimensions()" + << "\n\t You need to give some detection templates before calling this function. " + << "\n\t get_num_detection_templates(): " << get_num_detection_templates() + << "\n\t this: " << this + ); + + return feats_config.get_num_dimensions()*get_num_components_per_detection_template() + get_num_detection_templates(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + unsigned long scan_image_pyramid:: + get_max_pyramid_levels ( + ) const + { + return max_pyramid_levels; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void scan_image_pyramid:: + set_max_pyramid_levels ( + unsigned long max_levels + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_levels > 0 , + "\t void scan_image_pyramid::set_max_pyramid_levels()" + << "\n\t You can't have zero levels. " + << "\n\t max_levels: " << max_levels + << "\n\t this: " << this + ); + + max_pyramid_levels = max_levels; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void scan_image_pyramid:: + detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_num_detection_templates() > 0 && + is_loaded_with_image() && + w.size() >= get_num_dimensions(), + "\t void scan_image_pyramid::detect()" + << "\n\t Invalid inputs were given to this function " + << "\n\t get_num_detection_templates(): " << get_num_detection_templates() + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t w.size(): " << w.size() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t this: " << this + ); + + dets.clear(); + + array > saliency_images; + saliency_images.set_max_size(get_num_components_per_detection_template()); + saliency_images.set_size(get_num_components_per_detection_template()); + std::vector > stationary_region_rects(get_num_stationary_components_per_detection_template()); + std::vector > movable_region_rects(get_num_movable_components_per_detection_template()); + pyramid_type pyr; + std::vector > point_dets; + + // for all pyramid levels + for (unsigned long l = 0; l < feats.size(); ++l) + { + for (unsigned long i = 0; i < saliency_images.size(); ++i) + { + saliency_images[i].set_size(feats[l].nr(), feats[l].nc()); + const unsigned long offset = get_num_detection_templates() + feats_config.get_num_dimensions()*i; + + // build saliency images for pyramid level l + for (long r = 0; r < feats[l].nr(); ++r) + { + for (long c = 0; c < feats[l].nc(); ++c) + { + const typename feature_extractor_type::descriptor_type& descriptor = feats[l](r,c); + + double sum = 0; + for (unsigned long k = 0; k < descriptor.size(); ++k) + { + sum += w(descriptor[k].first + offset)*descriptor[k].second; + } + saliency_images[i][r][c] = sum; + } + } + } + + // now search the saliency images + for (unsigned long i = 0; i < det_templates.size(); ++i) + { + const point offset = -feats[l].image_to_feat_space(point(0,0)); + for (unsigned long j = 0; j < stationary_region_rects.size(); ++j) + { + stationary_region_rects[j] = std::make_pair(j, translate_rect(feats[l].image_to_feat_space(det_templates[i].rects[j]),offset)); + } + for (unsigned long j = 0; j < movable_region_rects.size(); ++j) + { + // Scale the size of the movable rectangle but make sure its center + // stays at point(0,0). + const rectangle temp = feats[l].image_to_feat_space(det_templates[i].movable_rects[j]); + movable_region_rects[j] = std::make_pair(j+stationary_region_rects.size(), + centered_rect(point(0,0),temp.width(), temp.height())); + } + + // Scale the object box into the feature extraction image, but keeping it + // centered at point(0,0). + rectangle scaled_object_box = feats[l].image_to_feat_space(det_templates[i].object_box); + scaled_object_box = centered_rect(point(0,0),scaled_object_box.width(), scaled_object_box.height()); + + // Each detection template gets its own special threshold in addition to + // the global detection threshold. This allows us to model the fact that + // some detection templates might be more prone to false alarming or since + // their size is different naturally require a larger or smaller threshold + // (since they integrate over a larger or smaller region of the image). + const double template_specific_thresh = w(i); + + scan_image_movable_parts(point_dets, saliency_images, scaled_object_box, + stationary_region_rects, movable_region_rects, + thresh+template_specific_thresh, max_dets_per_template); + + // convert all the point detections into rectangles at the original image scale and coordinate system + for (unsigned long j = 0; j < point_dets.size(); ++j) + { + const double score = point_dets[j].first-template_specific_thresh; + point p = point_dets[j].second; + p = feats[l].feat_to_image_space(p); + rectangle rect = translate_rect(det_templates[i].object_box, p); + rect = pyr.rect_up(rect, l); + + dets.push_back(std::make_pair(score, rect)); + } + } + } + + std::sort(dets.rbegin(), dets.rend(), compare_pair_rect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + const rectangle scan_image_pyramid:: + get_best_matching_rect ( + const rectangle& rect + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_num_detection_templates() > 0 , + "\t const rectangle scan_image_pyramid::get_best_matching_rect()" + << "\n\t Invalid inputs were given to this function " + << "\n\t get_num_detection_templates(): " << get_num_detection_templates() + << "\n\t this: " << this + ); + + rectangle mapped_rect, object_box; + detection_template best_template; + unsigned long best_level, junk; + get_mapped_rect_and_metadata(max_pyramid_levels, rect, mapped_rect, best_template, object_box, best_level, junk); + return mapped_rect; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void scan_image_pyramid:: + get_mapped_rect_and_metadata ( + const unsigned long number_pyramid_levels, + rectangle rect, + rectangle& mapped_rect, + detection_template& best_template, + rectangle& object_box, + unsigned long& best_level, + unsigned long& detection_template_idx + ) const + { + pyramid_type pyr; + // Figure out the pyramid level which best matches rect against one of our + // detection template object boxes. + best_level = 0; + double best_match_score = -1; + + + // Find the best matching detection template for rect + for (unsigned long l = 0; l < number_pyramid_levels; ++l) + { + const rectangle temp = pyr.rect_down(rect,l); + if (temp.area() <= 1) + break; + + // At this pyramid level, what matches best? + for (unsigned long t = 0; t < det_templates.size(); ++t) + { + const double match_score = get_match_score(det_templates[t].object_box, temp); + if (match_score > best_match_score) + { + best_match_score = match_score; + best_level = l; + best_template = det_templates[t]; + detection_template_idx = t; + } + } + } + + + // Now we translate best_template into the right spot (it should be centered at the location + // determined by rect) and convert it into the feature image coordinate system. + rect = pyr.rect_down(rect,best_level); + const point offset = -feats_config.image_to_feat_space(point(0,0)); + const point origin = feats_config.image_to_feat_space(center(rect)) + offset; + for (unsigned long k = 0; k < best_template.rects.size(); ++k) + { + rectangle temp = best_template.rects[k]; + temp = feats_config.image_to_feat_space(temp); + temp = translate_rect(temp, origin); + best_template.rects[k] = temp; + } + for (unsigned long k = 0; k < best_template.movable_rects.size(); ++k) + { + rectangle temp = best_template.movable_rects[k]; + temp = feats_config.image_to_feat_space(temp); + temp = centered_rect(point(0,0), temp.width(), temp.height()); + best_template.movable_rects[k] = temp; + } + + const rectangle scaled_object_box = feats_config.image_to_feat_space(best_template.object_box); + object_box = centered_rect(origin-offset, scaled_object_box.width(), scaled_object_box.height()); + + // The input rectangle was mapped to one of the detection templates. Reverse the process + // to figure out what the mapped rectangle is in the original input space. + mapped_rect = translate_rect(best_template.object_box, feats_config.feat_to_image_space(origin-offset)); + mapped_rect = pyr.rect_up(mapped_rect, best_level); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + full_object_detection scan_image_pyramid:: + get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& w + ) const + { + // fill in movable part positions. + + rectangle mapped_rect; + detection_template best_template; + unsigned long best_level, junk; + rectangle object_box; + get_mapped_rect_and_metadata(feats.size(), rect, mapped_rect, best_template, object_box, best_level, junk); + + Pyramid_type pyr; + + array2d saliency_image, sum_img; + + double total_temp_score = 0; + // convert into feature space. + object_box = object_box.intersect(get_rect(feats[best_level])); + + std::vector movable_parts; + movable_parts.reserve(get_num_movable_components_per_detection_template()); + for (unsigned long i = 0; i < get_num_movable_components_per_detection_template(); ++i) + { + // make the saliency_image for the ith movable part. + + const rectangle part_rect = best_template.movable_rects[i]; + const rectangle area = grow_rect(object_box, + part_rect.width()/2, + part_rect.height()/2).intersect(get_rect(feats[best_level])); + + saliency_image.set_size(area.height(), area.width()); + const unsigned long offset = get_num_detection_templates() + feats_config.get_num_dimensions()*(i+get_num_stationary_components_per_detection_template()); + + // build saliency image for pyramid level best_level + for (long r = area.top(); r <= area.bottom(); ++r) + { + for (long c = area.left(); c <= area.right(); ++c) + { + const typename feature_extractor_type::descriptor_type& descriptor = feats[best_level](r,c); + + double sum = 0; + for (unsigned long k = 0; k < descriptor.size(); ++k) + { + sum += w(descriptor[k].first + offset)*descriptor[k].second; + } + saliency_image[r-area.top()][c-area.left()] = sum; + } + } + + sum_img.set_size(saliency_image.nr(), saliency_image.nc()); + sum_filter_assign(saliency_image, sum_img, part_rect); + // Figure out where the maximizer is in sum_img. Note that we + // only look in the part of sum_img that corresponds to a location inside + // object_box. + rectangle valid_area = get_rect(sum_img); + valid_area.left() += object_box.left() - area.left(); + valid_area.top() += object_box.top() - area.top(); + valid_area.right() += object_box.right() - area.right(); + valid_area.bottom() += object_box.bottom() - area.bottom(); + double max_val = 0; + point max_loc; + for (long r = valid_area.top(); r <= valid_area.bottom(); ++r) + { + for (long c = valid_area.left(); c <= valid_area.right(); ++c) + { + if (sum_img[r][c] > max_val) + { + //if (object_box.contains(point(c,r) + area.tl_corner())) + { + max_loc = point(c,r); + max_val = sum_img[r][c]; + } + } + } + } + + if (max_val <= 0) + { + max_loc = OBJECT_PART_NOT_PRESENT; + } + else + { + total_temp_score += max_val; + // convert max_loc back into feature image space from our cropped image. + max_loc += area.tl_corner(); + + // now convert from feature space to image space. + max_loc = feats[best_level].feat_to_image_space(max_loc); + max_loc = pyr.point_up(max_loc, best_level); + max_loc = nearest_point(rect, max_loc); + } + + movable_parts.push_back(max_loc); + } + + return full_object_detection(rect, movable_parts); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void scan_image_pyramid:: + get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_num_detection_templates() > 0 && + is_loaded_with_image() && + psi.size() >= get_num_dimensions() && + obj.num_parts() == get_num_movable_components_per_detection_template(), + "\t void scan_image_pyramid::get_feature_vector()" + << "\n\t Invalid inputs were given to this function " + << "\n\t get_num_detection_templates(): " << get_num_detection_templates() + << "\n\t is_loaded_with_image(): " << is_loaded_with_image() + << "\n\t psi.size(): " << psi.size() + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t get_num_movable_components_per_detection_template(): " << get_num_movable_components_per_detection_template() + << "\n\t obj.num_parts(): " << obj.num_parts() + << "\n\t this: " << this + ); + DLIB_ASSERT(all_parts_in_rect(obj), + "\t void scan_image_pyramid::get_feature_vector()" + << "\n\t Invalid inputs were given to this function " + << "\n\t obj.get_rect(): " << obj.get_rect() + << "\n\t this: " << this + ); + + + + rectangle mapped_rect; + detection_template best_template; + unsigned long best_level, detection_template_idx; + rectangle object_box; + get_mapped_rect_and_metadata(feats.size(), obj.get_rect(), mapped_rect, best_template, object_box, best_level, detection_template_idx); + + psi(detection_template_idx) -= 1; + + Pyramid_type pyr; + + // put the movable rects at the places indicated by obj. + std::vector rects = best_template.rects; + for (unsigned long i = 0; i < obj.num_parts(); ++i) + { + if (obj.part(i) != OBJECT_PART_NOT_PRESENT) + { + // map from the original image to scaled feature space. + point loc = feats[best_level].image_to_feat_space(pyr.point_down(obj.part(i), best_level)); + // Make sure the movable part always stays within the object_box. + // Otherwise it would be at a place that the detect() function can never + // look. + loc = nearest_point(object_box, loc); + rects.push_back(translate_rect(best_template.movable_rects[i], loc)); + } + else + { + // add an empty rectangle since this part wasn't observed. + rects.push_back(rectangle()); + } + } + + // pull features out of all the boxes in rects. + for (unsigned long j = 0; j < rects.size(); ++j) + { + const rectangle rect = rects[j].intersect(get_rect(feats[best_level])); + const unsigned long template_region_id = j; + const unsigned long offset = get_num_detection_templates() + feats_config.get_num_dimensions()*template_region_id; + for (long r = rect.top(); r <= rect.bottom(); ++r) + { + for (long c = rect.left(); c <= rect.right(); ++c) + { + const typename feature_extractor_type::descriptor_type& descriptor = feats[best_level](r,c); + for (unsigned long k = 0; k < descriptor.size(); ++k) + { + psi(descriptor[k].first + offset) += descriptor[k].second; + } + } + } + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void scan_image_pyramid:: + set_min_pyramid_layer_size ( + unsigned long width, + unsigned long height + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(width > 0 && height > 0 , + "\t void scan_image_pyramid::set_min_pyramid_layer_size()" + << "\n\t These sizes can't be zero. " + << "\n\t width: " << width + << "\n\t height: " << height + << "\n\t this: " << this + ); + + min_pyramid_layer_width = width; + min_pyramid_layer_height = height; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + unsigned long scan_image_pyramid:: + get_min_pyramid_layer_width ( + ) const + { + return min_pyramid_layer_width; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + unsigned long scan_image_pyramid:: + get_min_pyramid_layer_height ( + ) const + { + return min_pyramid_layer_height; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_IMaGE_PYRAMID_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/scan_image_pyramid_abstract.h b/ml/dlib/dlib/image_processing/scan_image_pyramid_abstract.h new file mode 100644 index 000000000..e985a3f32 --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image_pyramid_abstract.h @@ -0,0 +1,495 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SCAN_IMaGE_PYRAMID_ABSTRACT_Hh_ +#ifdef DLIB_SCAN_IMaGE_PYRAMID_ABSTRACT_Hh_ + +#include "../matrix.h" +#include "../geometry.h" +#include "../image_processing.h" +#include "../array2d.h" +#include +#include "full_object_detection_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + class scan_image_pyramid : noncopyable + { + /*! + REQUIREMENTS ON Pyramid_type + - must be one of the pyramid_down objects defined in + dlib/image_transforms/image_pyramid_abstract.h or an object with + a compatible interface + + REQUIREMENTS ON Feature_extractor_type + - must be an object with an interface compatible with the hashed_feature_image + object defined in dlib/image_keypoint/hashed_feature_image_abstract.h or + with the nearest_neighbor_feature_image object defined in + dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h + + INITIAL VALUE + - get_num_detection_templates() == 0 + - is_loaded_with_image() == false + - get_max_detections_per_template() == 10000 + - get_max_pyramid_levels() == 1000 + - get_min_pyramid_layer_width() == 20 + - get_min_pyramid_layer_height() == 20 + + WHAT THIS OBJECT REPRESENTS + This object is a tool for running a sliding window classifier over + an image pyramid. This object can also be understood as a general + tool for implementing the spatial pyramid models described in the paper: + Beyond Bags of Features: Spatial Pyramid Matching for Recognizing + Natural Scene Categories by Svetlana Lazebnik, Cordelia Schmid, + and Jean Ponce + It also includes the ability to represent movable part models. + + + + + The sliding window classifiers used by this object have three parts: + 1. The underlying feature extraction provided by Feature_extractor_type + objects, which associate a vector with each location in an image. + + 2. A detection template. This is a rectangle which defines the shape of a + sliding window (i.e. the object_box), as well as a set of rectangular feature + extraction regions inside it. This set of regions defines the spatial + structure of the overall feature extraction within a sliding window. In + particular, each location of a sliding window has a feature vector + associated with it. This feature vector is defined as follows: + - Let N denote the number of feature extraction zones. + - Let M denote the dimensionality of the vectors output by Feature_extractor_type + objects. + - Let F(i) == the M dimensional vector which is the sum of all vectors + given by our Feature_extractor_type object inside the i-th feature extraction + zone. + - Then the feature vector for a sliding window is an M*N dimensional vector + [F(1) F(2) F(3) ... F(N)] (i.e. it is a concatenation of the N vectors). + This feature vector can be thought of as a collection of N "bags of features", + each bag coming from a spatial location determined by one of the rectangular + feature extraction zones. + + 3. A weight vector and a threshold value. The dot product between the weight + vector and the feature vector for a sliding window location gives the score + of the window. If this score is greater than the threshold value then the + window location is output as a detection. + + Finally, the sliding window classifiers described above are applied to every level of + an image pyramid. Moreover, some of the feature extraction zones are allowed to move + freely within the object box. This means that when we are sliding the classifier over + an image, some feature extraction zones are stationary (i.e. always in the same place + relative to the object box) while others are allowed to move anywhere within the object + box. In particular, the movable regions are placed at the locations that maximize the + score of the classifier. Note further that each of the movable feature extraction + zones must pass a threshold test for it to be included. That is, if the score that a + movable zone would contribute to the overall score for a sliding window location is not + positive then that zone is not included in the feature vector (i.e. its part of the + feature vector is set to zero. This way the length of the feature vector stays + constant). This movable region construction allows us to represent objects with parts + that move around relative to the object box. For example, a human has hands but they + aren't always in the same place relative to a person's bounding box. + + THREAD SAFETY + Concurrent access to an instance of this object is not safe and should be protected + by a mutex lock except for the case where you are copying the configuration + (via copy_configuration()) of a scan_image_pyramid object to many other threads. + In this case, it is safe to copy the configuration of a shared object so long + as no other operations are performed on it. + !*/ + public: + + typedef matrix feature_vector_type; + + typedef Pyramid_type pyramid_type; + typedef Feature_extractor_type feature_extractor_type; + + scan_image_pyramid ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template < + typename image_type + > + void load ( + const image_type& img + ); + /*! + requires + - image_type must be a type with the following properties: + - image_type is default constructable. + - image_type is swappable by the global swap() function. + - image_type logically represents some kind of image and therefore its + number of rows and columns can be queried via num_rows(img) and + num_columns(img) respectively. + - image_type objects can be loaded into Feature_extractor_type + objects via Feature_extractor_type::load(). + - image_type objects can be used with Pyramid_type. That is, + if pyr is an object of type Pyramid_type while img1 and img2 + are objects of image_type, then pyr(img1,img2) should be + a valid expression which downsamples img1 into img2. + ensures + - #is_loaded_with_image() == true + - This object is ready to run sliding window classifiers over img. Call + detect() to do this. + !*/ + + bool is_loaded_with_image ( + ) const; + /*! + ensures + - returns true if this object has been loaded with an image to process + and false otherwise. + !*/ + + const feature_extractor_type& get_feature_extractor ( + ) const; + /*! + ensures + - returns a const reference to the feature_extractor_type object used + internally for local feature extraction. + !*/ + + void copy_configuration( + const feature_extractor_type& fe + ); + /*! + ensures + - This function performs the equivalent of + get_feature_extractor().copy_configuration(fe) (i.e. this function allows + you to configure the parameters of the underlying feature extractor used + by a scan_image_pyramid object) + !*/ + + void copy_configuration ( + const scan_image_pyramid& item + ); + /*! + ensures + - copies all the state information of item into *this, except for state + information populated by load(). More precisely, given two scan_image_pyramid + objects S1 and S2, the following sequence of instructions should always + result in both of them having the exact same state. + S2.copy_configuration(S1); + S1.load(img); + S2.load(img); + !*/ + + void add_detection_template ( + const rectangle& object_box, + const std::vector& stationary_feature_extraction_regions, + const std::vector& movable_feature_extraction_regions + ); + /*! + requires + - center(object_box) == point(0,0) + - for all valid i: + - center(movable_feature_extraction_regions[i]) == point(0,0) + - if (get_num_detection_templates() > 0) then + - get_num_stationary_components_per_detection_template() == stationary_feature_extraction_regions.size() + - get_num_movable_components_per_detection_template() == movable_feature_extraction_regions.size() + (i.e. if you already have detection templates in this object, then + any new detection template must declare a consistent number of + feature extraction regions) + ensures + - Adds another detection template to this object. In particular, object_box + defines the size and shape of a sliding window while stationary_feature_extraction_regions + and movable_feature_extraction_regions defines the locations for feature extraction as + discussed in the WHAT THIS OBJECT REPRESENTS section above. Note also that the locations of + the stationary feature extraction regions are relative to the object_box. + - #get_num_detection_templates() == get_num_detection_templates() + 1 + - The order of rectangles in stationary_feature_extraction_regions and + movable_feature_extraction_regions matters. Recall that each rectangle + gets its own set of features. So given two different templates, their + i-th rectangles will both share the same part of the weight vector (i.e. the w + supplied to detect()). So there should be some reasonable correspondence + between the rectangle ordering in different detection templates. For, + example, different detection templates should place corresponding feature + extraction regions in roughly the same part of the object_box. + - #get_num_stationary_components_per_detection_template() = stationary_feature_extraction_regions.size() + - #get_num_movable_components_per_detection_template() = movable_feature_extraction_regions.size() + !*/ + + void add_detection_template ( + const rectangle& object_box, + const std::vector& stationary_feature_extraction_regions + ); + /*! + ensures + - calls add_detection_template(object_box, stationary_feature_extraction_regions, empty_list) + where empty_list is a vector of size 0. I.e. this function is just a convenience + routine for adding detection templates with no movable regions. + !*/ + + unsigned long get_num_detection_templates ( + ) const; + /*! + ensures + - returns the number of detection templates in this object + !*/ + + unsigned long get_num_stationary_components_per_detection_template ( + ) const; + /*! + requires + - get_num_detection_templates() > 0 + ensures + - A detection template is a rectangle which defines the shape of a sliding + window (the object_box), as well as a set of rectangles which define + feature extraction zones. This function returns the number of stationary + feature extraction zones in the detection templates used by this object. + !*/ + + unsigned long get_num_movable_components_per_detection_template ( + ) const; + /*! + requires + - get_num_detection_templates() > 0 + ensures + - A detection template is a rectangle which defines the shape of a sliding + window (the object_box), as well as a set of rectangles which define + feature extraction zones. This function returns the number of movable + feature extraction zones in the detection templates used by this object. + !*/ + + unsigned long get_num_components_per_detection_template ( + ) const; + /*! + requires + - get_num_detection_templates() > 0 + ensures + - returns the total number of feature extraction zones in the detection + templates used by this object. That is, returns the following: + - get_num_movable_components_per_detection_template() + + get_num_stationary_components_per_detection_template() + !*/ + + long get_num_dimensions ( + ) const; + /*! + requires + - get_num_detection_templates() > 0 + ensures + - returns the number of dimensions in the feature vector for a sliding window + location. This value is the dimensionality of the underlying feature vectors + produced by Feature_extractor_type times (get_num_stationary_components_per_detection_template() + + get_num_movable_components_per_detection_template()). + !*/ + + unsigned long get_max_pyramid_levels ( + ) const; + /*! + ensures + - returns the maximum number of image pyramid levels this object will use. + Note that #get_max_pyramid_levels() == 1 indicates that no image pyramid + will be used at all. That is, only the original image will be processed + and no lower scale versions will be created. + !*/ + + void set_max_pyramid_levels ( + unsigned long max_levels + ); + /*! + requires + - max_levels > 0 + ensures + - #get_max_pyramid_levels() == max_levels + !*/ + + void set_min_pyramid_layer_size ( + unsigned long width, + unsigned long height + ); + /*! + requires + - width > 0 + - height > 0 + ensures + - #get_min_pyramid_layer_width() == width + - #get_min_pyramid_layer_height() == height + !*/ + + inline unsigned long get_min_pyramid_layer_width ( + ) const; + /*! + ensures + - returns the smallest allowable width of an image in the image pyramid. + All pyramids will always include the original input image, however, no + pyramid levels will be created which have a width smaller than the + value returned by this function. + !*/ + + inline unsigned long get_min_pyramid_layer_height ( + ) const; + /*! + ensures + - returns the smallest allowable height of an image in the image pyramid. + All pyramids will always include the original input image, however, no + pyramid levels will be created which have a height smaller than the + value returned by this function. + !*/ + + unsigned long get_max_detections_per_template ( + ) const; + /*! + ensures + - For each image pyramid layer and detection template, this object scans a sliding + window classifier over an image and produces a number of detections. This + function returns a number which defines a hard upper limit on the number of + detections allowed by a single scan. This means that the total number of + possible detections produced by detect() is get_max_detections_per_template()* + get_num_detection_templates()*(number of image pyramid layers). Additionally, + if the maximum number of detections is reached during a scan then this object + will return a random subsample of all detections which are above the detection + threshold. + !*/ + + void set_max_detections_per_template ( + unsigned long max_dets + ); + /*! + requires + - max_dets > 0 + ensures + - #get_max_detections_per_template() == max_dets + !*/ + + void detect ( + const feature_vector_type& w, + std::vector >& dets, + const double thresh + ) const; + /*! + requires + - w.size() >= get_num_dimensions() + - is_loaded_with_image() == true + - get_num_detection_templates() > 0 + ensures + - Scans all the detection templates over all pyramid layers as discussed in the + WHAT THIS OBJECT REPRESENTS section and stores all detections into #dets. + - for all valid i: + - #dets[i].second == The object box which produced this detection. This rectangle gives + the location of the detection. Note that the rectangle will have been converted back into + the original image input space. That is, if this detection was made at a low level in the + image pyramid then the object box will have been automatically mapped up the pyramid layers + to the original image space. Or in other words, if you plot #dets[i].second on top of the + image given to load() it will show up in the right place. + - #dets[i].first == The score for this detection. This value is equal to dot(w, feature vector + for this sliding window location). + - #dets[i].first >= thresh + - #dets will be sorted in descending order. (i.e. #dets[i].first >= #dets[j].first for all i, and j>i) + - Elements of w beyond index get_num_dimensions()-1 are ignored. I.e. only the first + get_num_dimensions() are used. + - Note that no form of non-max suppression is performed. If a window has a score >= thresh + then it is reported in #dets (assuming the limit imposed by get_max_detections_per_template() hasn't + been reached). + !*/ + + const rectangle get_best_matching_rect ( + const rectangle& rect + ) const; + /*! + requires + - get_num_detection_templates() > 0 + ensures + - Since scan_image_pyramid is a sliding window classifier system, not all possible rectangles + can be represented. Therefore, this function allows you to supply a rectangle and obtain the + nearest possible sliding window rectangle. + !*/ + + void get_feature_vector ( + const full_object_detection& obj, + feature_vector_type& psi + ) const; + /*! + requires + - all_parts_in_rect(obj) == true + - obj.num_parts() == get_num_movable_components_per_detection_template() + - is_loaded_with_image() == true + - get_num_detection_templates() > 0 + - psi.size() >= get_num_dimensions() + (i.e. psi must have preallocated its memory before this function is called) + ensures + - This function allows you to determine the feature vector used for a + sliding window location. Note that this vector is added to psi. Note + also that you must use get_full_object_detection() to convert a rect from + detect() into the needed full_object_detection. + - The dimensionality of the vector added to psi is get_num_dimensions(). This + means that elements of psi after psi(get_num_dimensions()-1) are not modified. + - Since scan_image_pyramid is a sliding window classifier system, not all + possible rectangles can be output by detect(). So in the case where + obj.get_rect() could not arise from a call to detect(), this function + will map obj.get_rect() to the nearest possible object box and then add + the feature vector for the mapped rectangle into #psi. + - get_best_matching_rect(obj.get_rect()) == the rectangle obj.get_rect() + gets mapped to for feature extraction. + !*/ + + full_object_detection get_full_object_detection ( + const rectangle& rect, + const feature_vector_type& w + ) const; + /*! + requires + - w.size() >= get_num_dimensions() + - is_loaded_with_image() == true + - get_num_detection_templates() > 0 + ensures + - This function allows you to determine the full_object_detection + corresponding to a sliding window location. Note that the detect() + routine doesn't return the locations of the movable parts in a detected + object. Therefore, if you are using any movable parts in your model you + must use get_full_object_detection() to find out where the movable parts + were detected. To do this, you supply the w and detected rectangle. + Then the corresponding fully populated full_object_detection will be + returned. + - returns a full_object_detection, OBJ, such that: + - OBJ.get_rect() == rect + - OBJ.num_parts() == get_num_movable_components_per_detection_template() + - OBJ.part(i) == the location of the i-th movable part inside this detection, + or OBJECT_PART_NOT_PRESENT if the part was not found. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void serialize ( + const scan_image_pyramid& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename Pyramid_type, + typename Feature_extractor_type + > + void deserialize ( + scan_image_pyramid& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_IMaGE_PYRAMID_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/scan_image_pyramid_tools.h b/ml/dlib/dlib/image_processing/scan_image_pyramid_tools.h new file mode 100644 index 000000000..874b995b4 --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image_pyramid_tools.h @@ -0,0 +1,180 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SCAN_IMaGE_PYRAMID_TOOLS_Hh_ +#define DLIB_SCAN_IMaGE_PYRAMID_TOOLS_Hh_ + +#include "scan_image_pyramid_tools_abstract.h" +#include "../statistics.h" +#include +#include "../geometry.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline bool compare_first ( + const std::pair& a, + const std::pair& b + ) + { + return a.first < b.first; + } + } + + + template + std::vector determine_object_boxes ( + const image_scanner_type& scanner, + const std::vector& rects, + double min_match_score + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < min_match_score && min_match_score <= 1, + "\t std::vector determine_object_boxes()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t min_match_score: " << min_match_score + ); + + typename image_scanner_type::pyramid_type pyr; + + typedef std::list > list_type; + + unsigned long max_area = 0; + + // Copy rects into sorted_rects and sort them in order of increasing area. But + // only include the rectangles that aren't already obtainable by the scanner. + list_type sorted_rects; + for (unsigned long i = 0; i < rects.size(); ++i) + { + if (scanner.get_num_detection_templates() > 0) + { + rectangle temp = scanner.get_best_matching_rect(rects[i]); + const double match_score = (rects[i].intersect(temp).area())/(double)(rects[i] + temp).area(); + // skip this rectangle if it's already matched well enough. + if (match_score > min_match_score) + continue; + } + max_area = std::max(rects[i].area(), max_area); + sorted_rects.push_back(std::make_pair(rects[i].area(), rects[i])); + } + sorted_rects.sort(dlib::impl::compare_first); + + // Make sure this area value is comfortably larger than all the + // rectangles' areas. + max_area = 3*max_area + 100; + + std::vector object_boxes; + + while (sorted_rects.size() != 0) + { + rectangle cur = sorted_rects.front().second; + sorted_rects.pop_front(); + object_boxes.push_back(centered_rect(point(0,0), cur.width(), cur.height())); + + // Scale cur up the image pyramid and remove any rectangles which match. + // But also stop when cur gets large enough to not match anything. + for (unsigned long itr = 0; + itr < scanner.get_max_pyramid_levels() && cur.area() < max_area; + ++itr) + { + list_type::iterator i = sorted_rects.begin(); + while (i != sorted_rects.end()) + { + const rectangle temp = move_rect(i->second, cur.tl_corner()); + const double match_score = (cur.intersect(temp).area())/(double)(cur + temp).area(); + if (match_score > min_match_score) + { + i = sorted_rects.erase(i); + } + else + { + ++i; + } + } + + cur = pyr.rect_up(cur); + } + + } + + return object_boxes; + } + +// ---------------------------------------------------------------------------------------- + + template + std::vector determine_object_boxes ( + const image_scanner_type& scanner, + const std::vector >& rects, + double min_match_score + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < min_match_score && min_match_score <= 1, + "\t std::vector determine_object_boxes()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t min_match_score: " << min_match_score + ); + + std::vector temp; + for (unsigned long i = 0; i < rects.size(); ++i) + { + for (unsigned long j = 0; j < rects[i].size(); ++j) + { + temp.push_back(rects[i][j]); + } + } + + return determine_object_boxes(scanner, temp, min_match_score); + } + +// ---------------------------------------------------------------------------------------- + + template + void setup_grid_detection_templates ( + image_scanner_type& scanner, + const std::vector >& rects, + unsigned int cells_x, + unsigned int cells_y, + double min_match_score = 0.75 + ) + { + const std::vector& object_boxes = determine_object_boxes(scanner, rects, min_match_score); + for (unsigned long i = 0; i < object_boxes.size(); ++i) + { + scanner.add_detection_template(object_boxes[i], create_grid_detection_template(object_boxes[i], cells_x, cells_y)); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void setup_grid_detection_templates_verbose ( + image_scanner_type& scanner, + const std::vector >& rects, + unsigned int cells_x, + unsigned int cells_y, + double min_match_score = 0.75 + ) + { + const std::vector& object_boxes = determine_object_boxes(scanner, rects, min_match_score); + std::cout << "number of detection templates: "<< object_boxes.size() << std::endl; + for (unsigned long i = 0; i < object_boxes.size(); ++i) + { + std::cout << " object box " << i << ": width: " << object_boxes[i].width() + << " height: "<< object_boxes[i].height() << std::endl; + scanner.add_detection_template(object_boxes[i], create_grid_detection_template(object_boxes[i], cells_x, cells_y)); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_IMaGE_PYRAMID_TOOLS_Hh_ + diff --git a/ml/dlib/dlib/image_processing/scan_image_pyramid_tools_abstract.h b/ml/dlib/dlib/image_processing/scan_image_pyramid_tools_abstract.h new file mode 100644 index 000000000..83a572df7 --- /dev/null +++ b/ml/dlib/dlib/image_processing/scan_image_pyramid_tools_abstract.h @@ -0,0 +1,118 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SCAN_IMaGE_PYRAMID_TOOLS_ABSTRACT_Hh_ +#ifdef DLIB_SCAN_IMaGE_PYRAMID_TOOLS_ABSTRACT_Hh_ + +#include "scan_image_pyramid_abstract.h" +#include +#include "../geometry.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + std::vector determine_object_boxes ( + const image_scanner_type& scanner, + const std::vector& rects, + double min_match_score + ); + /*! + requires + - 0 < min_match_score <= 1 + - image_scanner_type == an implementation of the scan_image_pyramid + object defined in dlib/image_processing/scan_image_pyramid_tools_abstract.h + ensures + - returns a set of object boxes which, when used as detection templates with + the given scanner, can attain at least min_match_score alignment with every + element of rects. Note that the alignment between two rectangles A and B is + defined as: + (A.intersect(B).area())/(double)(A+B).area() + - Only elements of rects which are not already well matched by the scanner are + considered. That is, if the scanner already has some detection templates in + it then the contents of rects will be checked against those detection + templates and elements with a match better than min_match_score are ignore. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + std::vector determine_object_boxes ( + const image_scanner_type& scanner, + const std::vector >& rects, + double min_match_score + ); + /*! + requires + - 0 < min_match_score <= 1 + - image_scanner_type == an implementation of the scan_image_pyramid + object defined in dlib/image_processing/scan_image_pyramid_tools_abstract.h + ensures + - copies all rectangles in rects into a std::vector object, call it + R. Then this function returns determine_object_boxes(scanner,R,min_match_score). + That is, it just called the version of determine_object_boxes() defined above + and returns the results. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + void setup_grid_detection_templates ( + image_scanner_type& scanner, + const std::vector >& rects, + unsigned int cells_x, + unsigned int cells_y, + double min_match_score = 0.75 + ); + /*! + requires + - cells_x > 0 + - cells_y > 0 + - 0 < min_match_score <= 1 + - image_scanner_type == an implementation of the scan_image_pyramid + object defined in dlib/image_processing/scan_image_pyramid_tools_abstract.h + ensures + - uses determine_object_boxes(scanner,rects,min_match_score) to obtain a set of + object boxes and then adds them to the given scanner object as detection templates. + Also uses create_grid_detection_template(object_box, cells_x, cells_y) to create + each feature extraction region. Therefore, the detection templates will extract + features from a regular grid inside each object box. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + void setup_grid_detection_templates_verbose ( + image_scanner_type& scanner, + const std::vector >& rects, + unsigned int cells_x, + unsigned int cells_y, + double min_match_score = 0.75 + ); + /*! + requires + - cells_x > 0 + - cells_y > 0 + - 0 < min_match_score <= 1 + - image_scanner_type == an implementation of the scan_image_pyramid + object defined in dlib/image_processing/scan_image_pyramid_tools_abstract.h + ensures + - this function is identical to setup_grid_detection_templates() except + that it also outputs the selected detection templates to standard out. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SCAN_IMaGE_PYRAMID_TOOLS_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_processing/setup_hashed_features.h b/ml/dlib/dlib/image_processing/setup_hashed_features.h new file mode 100644 index 000000000..5b82cecb4 --- /dev/null +++ b/ml/dlib/dlib/image_processing/setup_hashed_features.h @@ -0,0 +1,219 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SETUP_HAShED_FEATURES_Hh_ +#define DLIB_SETUP_HAShED_FEATURES_Hh_ + +#include "setup_hashed_features_abstract.h" +#include "scan_image_pyramid.h" +#include "scan_image_boxes.h" +#include "../lsh.h" +#include "../statistics.h" +#include "../image_keypoint.h" +#include "../geometry.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class image_hash_construction_failure : public error + { + public: + image_hash_construction_failure( + const std::string& a + ): error(a) {} + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner + > + void use_uniform_feature_weights ( + image_scanner& scanner + ) + { + typename image_scanner::feature_extractor_type fe; + fe.copy_configuration(scanner.get_feature_extractor()); + fe.use_uniform_feature_weights(); + scanner.copy_configuration(fe); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner + > + void use_relative_feature_weights ( + image_scanner& scanner + ) + { + typename image_scanner::feature_extractor_type fe; + fe.copy_configuration(scanner.get_feature_extractor()); + fe.use_relative_feature_weights(); + scanner.copy_configuration(fe); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// stuff for scan_image_pyramid +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_array, + typename pyramid, + typename feature_extractor, + template class feature_image + > + void setup_hashed_features ( + scan_image_pyramid >& scanner, + const image_array& images, + const feature_extractor& fe, + int bits, + unsigned long num_samples = 200000 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < bits && bits <= 32 && + num_samples > 1 && + images.size() > 0, + "\t void setup_hashed_features()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t bits: " << bits + << "\n\t num_samples: " << num_samples + << "\n\t images.size(): " << images.size() + ); + + pyramid pyr; + + const random_subset_selector& samps = + randomly_sample_image_features(images, pyr, fe, num_samples); + + if (samps.size() <= 1) + throw dlib::image_hash_construction_failure("Images too small, not able to gather enough samples to make hash"); + + projection_hash phash = create_random_projection_hash(samps, bits); + + feature_image hfe; + hfe.copy_configuration(scanner.get_feature_extractor()); + hfe.set_hash(phash); + hfe.copy_configuration(fe); + scanner.copy_configuration(hfe); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array, + typename pyramid, + typename feature_extractor, + template class feature_image + > + void setup_hashed_features ( + scan_image_pyramid >& scanner, + const image_array& images, + int bits, + unsigned long num_samples = 200000 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < bits && bits <= 32 && + num_samples > 1 && + images.size() > 0, + "\t void setup_hashed_features()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t bits: " << bits + << "\n\t num_samples: " << num_samples + << "\n\t images.size(): " << images.size() + ); + + feature_extractor fe; + setup_hashed_features(scanner, images, fe, bits, num_samples); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// stuff for scan_image_boxes +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_array, + typename feature_extractor, + template class feature_image, + typename box_generator + > + void setup_hashed_features ( + scan_image_boxes,box_generator >& scanner, + const image_array& images, + const feature_extractor& fe, + int bits, + unsigned long num_samples = 200000 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < bits && bits <= 32 && + num_samples > 1 && + images.size() > 0, + "\t void setup_hashed_features()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t bits: " << bits + << "\n\t num_samples: " << num_samples + << "\n\t images.size(): " << images.size() + ); + + pyramid_disable pyr; + + const random_subset_selector& samps = + randomly_sample_image_features(images, pyr, fe, num_samples); + + if (samps.size() <= 1) + throw dlib::image_hash_construction_failure("Images too small, not able to gather enough samples to make hash"); + + projection_hash phash = create_random_projection_hash(samps, bits); + + feature_image hfe; + hfe.copy_configuration(scanner.get_feature_extractor()); + hfe.set_hash(phash); + hfe.copy_configuration(fe); + scanner.copy_configuration(hfe); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array, + typename feature_extractor, + template class feature_image, + typename box_generator + > + void setup_hashed_features ( + scan_image_boxes,box_generator>& scanner, + const image_array& images, + int bits, + unsigned long num_samples = 200000 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < bits && bits <= 32 && + num_samples > 1 && + images.size() > 0, + "\t void setup_hashed_features()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t bits: " << bits + << "\n\t num_samples: " << num_samples + << "\n\t images.size(): " << images.size() + ); + + feature_extractor fe; + setup_hashed_features(scanner, images, fe, bits, num_samples); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SETUP_HAShED_FEATURES_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/setup_hashed_features_abstract.h b/ml/dlib/dlib/image_processing/setup_hashed_features_abstract.h new file mode 100644 index 000000000..886411cd4 --- /dev/null +++ b/ml/dlib/dlib/image_processing/setup_hashed_features_abstract.h @@ -0,0 +1,210 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SETUP_HAShED_FEATURES_ABSTRACT_Hh_ +#ifdef DLIB_SETUP_HAShED_FEATURES_ABSTRACT_Hh_ + +#include "scan_image_pyramid_abstract.h" +#include "scan_image_boxes_abstract.h" +#include "../lsh/projection_hash_abstract.h" +#include "../image_keypoint/hashed_feature_image_abstract.h" +#include "../image_keypoint/binned_vector_feature_image_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class image_hash_construction_failure : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception object used by the routines in this file. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner + > + void use_uniform_feature_weights ( + image_scanner& scanner + ); + /*! + requires + - image_scanner should be either scan_image_pyramid or scan_image_boxes and + should use the hashed_feature_image as its local feature extractor. + ensures + - #scanner.get_feature_extractor().uses_uniform_feature_weights() == true + (i.e. Make the scanner's feature extractor use the uniform feature weighting + scheme) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner + > + void use_relative_feature_weights ( + image_scanner& scanner + ); + /*! + requires + - image_scanner should be either scan_image_pyramid or scan_image_boxes and + should use the hashed_feature_image as its local feature extractor. + ensures + - #scanner.get_feature_extractor().uses_uniform_feature_weights() == false + (i.e. Make the scanner's feature extractor use the relative feature weighting + scheme) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array, + typename pyramid, + typename feature_extractor + template class feature_image + > + void setup_hashed_features ( + scan_image_pyramid >& scanner, + const image_array& images, + const feature_extractor& fe, + int bits, + unsigned long num_samples = 200000 + ); + /*! + requires + - 0 < bits <= 32 + - num_samples > 1 + - images.size() > 0 + - it must be valid to pass images[0] into scanner.load(). + (also, image_array must be an implementation of dlib/array/array_kernel_abstract.h) + - feature_image == must be either hashed_feature_image, binned_vector_feature_image, + or a type with a compatible interface. + ensures + - Creates a projection_hash suitable for hashing the feature vectors produced by + fe and then configures scanner to use this hash function. + - The hash function will map vectors into integers in the range [0, pow(2,bits)) + - The hash function will be setup so that it hashes a random sample of num_samples + vectors from fe such that each bin ends up with roughly the same number of + elements in it. + throws + - image_hash_construction_failure + This exception is thrown if there is a problem creating the projection_hash. + This should only happen the images are so small they contain less than 2 + feature vectors. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array, + typename pyramid, + typename feature_extractor + template class feature_image + > + void setup_hashed_features ( + scan_image_pyramid >& scanner, + const image_array& images, + int bits, + unsigned long num_samples = 200000 + ); + /*! + requires + - 0 < bits <= 32 + - num_samples > 1 + - images.size() > 0 + - it must be valid to pass images[0] into scanner.load(). + (also, image_array must be an implementation of dlib/array/array_kernel_abstract.h) + - feature_image == must be either hashed_feature_image, binned_vector_feature_image, + or a type with a compatible interface. + ensures + - performs: setup_hashed_features(scanner, images, feature_extractor(), bits, num_samples) + throws + - image_hash_construction_failure + This exception is thrown if there is a problem creating the projection_hash. + This should only happen the images are so small they contain less than 2 + feature vectors. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_array, + typename feature_extractor, + template class feature_image + typename box_generator + > + void setup_hashed_features ( + scan_image_boxes,box_generator>& scanner, + const image_array& images, + const feature_extractor& fe, + int bits, + unsigned long num_samples = 200000 + ); + /*! + requires + - 0 < bits <= 32 + - num_samples > 1 + - images.size() > 0 + - it must be valid to pass images[0] into scanner.load(). + (also, image_array must be an implementation of dlib/array/array_kernel_abstract.h) + - feature_image == must be either hashed_feature_image, binned_vector_feature_image, + or a type with a compatible interface. + ensures + - Creates a projection_hash suitable for hashing the feature vectors produced by + fe and then configures scanner to use this hash function. + - The hash function will map vectors into integers in the range [0, pow(2,bits)) + - The hash function will be setup so that it hashes a random sample of num_samples + vectors from fe such that each bin ends up with roughly the same number of + elements in it. + throws + - image_hash_construction_failure + This exception is thrown if there is a problem creating the projection_hash. + This should only happen the images are so small they contain less than 2 + feature vectors. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array, + typename feature_extractor, + template class feature_image + typename box_generator + > + void setup_hashed_features ( + scan_image_boxes,box_generator>& scanner, + const image_array& images, + int bits, + unsigned long num_samples = 200000 + ); + /*! + requires + - 0 < bits <= 32 + - num_samples > 1 + - images.size() > 0 + - it must be valid to pass images[0] into scanner.load(). + (also, image_array must be an implementation of dlib/array/array_kernel_abstract.h) + - feature_image == must be either hashed_feature_image, binned_vector_feature_image, + or a type with a compatible interface. + ensures + - performs: setup_hashed_features(scanner, images, feature_extractor(), bits, num_samples) + throws + - image_hash_construction_failure + This exception is thrown if there is a problem creating the projection_hash. + This should only happen the images are so small they contain less than 2 + feature vectors. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SETUP_HAShED_FEATURES_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_processing/shape_predictor.h b/ml/dlib/dlib/image_processing/shape_predictor.h new file mode 100644 index 000000000..05e9a60fd --- /dev/null +++ b/ml/dlib/dlib/image_processing/shape_predictor.h @@ -0,0 +1,524 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SHAPE_PREDICToR_H_ +#define DLIB_SHAPE_PREDICToR_H_ + +#include "shape_predictor_abstract.h" +#include "full_object_detection.h" +#include "../algs.h" +#include "../matrix.h" +#include "../geometry.h" +#include "../pixel.h" +#include "../statistics.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + struct split_feature + { + unsigned long idx1; + unsigned long idx2; + float thresh; + + friend inline void serialize (const split_feature& item, std::ostream& out) + { + dlib::serialize(item.idx1, out); + dlib::serialize(item.idx2, out); + dlib::serialize(item.thresh, out); + } + friend inline void deserialize (split_feature& item, std::istream& in) + { + dlib::deserialize(item.idx1, in); + dlib::deserialize(item.idx2, in); + dlib::deserialize(item.thresh, in); + } + }; + + + // a tree is just a std::vector. We use this function to navigate the + // tree nodes + inline unsigned long left_child (unsigned long idx) { return 2*idx + 1; } + /*! + ensures + - returns the index of the left child of the binary tree node idx + !*/ + inline unsigned long right_child (unsigned long idx) { return 2*idx + 2; } + /*! + ensures + - returns the index of the left child of the binary tree node idx + !*/ + + struct regression_tree + { + std::vector splits; + std::vector > leaf_values; + + unsigned long num_leaves() const { return leaf_values.size(); } + + inline const matrix& operator()( + const std::vector& feature_pixel_values, + unsigned long& i + ) const + /*! + requires + - All the index values in splits are less than feature_pixel_values.size() + - leaf_values.size() is a power of 2. + (i.e. we require a tree with all the levels fully filled out. + - leaf_values.size() == splits.size()+1 + (i.e. there needs to be the right number of leaves given the number of splits in the tree) + ensures + - runs through the tree and returns the vector at the leaf we end up in. + - #i == the selected leaf node index. + !*/ + { + i = 0; + while (i < splits.size()) + { + if ((float)feature_pixel_values[splits[i].idx1] - (float)feature_pixel_values[splits[i].idx2] > splits[i].thresh) + i = left_child(i); + else + i = right_child(i); + } + i = i - splits.size(); + return leaf_values[i]; + } + + friend void serialize (const regression_tree& item, std::ostream& out) + { + dlib::serialize(item.splits, out); + dlib::serialize(item.leaf_values, out); + } + friend void deserialize (regression_tree& item, std::istream& in) + { + dlib::deserialize(item.splits, in); + dlib::deserialize(item.leaf_values, in); + } + }; + + // ------------------------------------------------------------------------------------ + + inline vector location ( + const matrix& shape, + unsigned long idx + ) + /*! + requires + - idx < shape.size()/2 + - shape.size()%2 == 0 + ensures + - returns the idx-th point from the shape vector. + !*/ + { + return vector(shape(idx*2), shape(idx*2+1)); + } + + // ------------------------------------------------------------------------------------ + + inline unsigned long nearest_shape_point ( + const matrix& shape, + const dlib::vector& pt + ) + { + // find the nearest part of the shape to this pixel + float best_dist = std::numeric_limits::infinity(); + const unsigned long num_shape_parts = shape.size()/2; + unsigned long best_idx = 0; + for (unsigned long j = 0; j < num_shape_parts; ++j) + { + const float dist = length_squared(location(shape,j)-pt); + if (dist < best_dist) + { + best_dist = dist; + best_idx = j; + } + } + return best_idx; + } + + // ------------------------------------------------------------------------------------ + + inline void create_shape_relative_encoding ( + const matrix& shape, + const std::vector >& pixel_coordinates, + std::vector& anchor_idx, + std::vector >& deltas + ) + /*! + requires + - shape.size()%2 == 0 + - shape.size() > 0 + ensures + - #anchor_idx.size() == pixel_coordinates.size() + - #deltas.size() == pixel_coordinates.size() + - for all valid i: + - pixel_coordinates[i] == location(shape,#anchor_idx[i]) + #deltas[i] + !*/ + { + anchor_idx.resize(pixel_coordinates.size()); + deltas.resize(pixel_coordinates.size()); + + + for (unsigned long i = 0; i < pixel_coordinates.size(); ++i) + { + anchor_idx[i] = nearest_shape_point(shape, pixel_coordinates[i]); + deltas[i] = pixel_coordinates[i] - location(shape,anchor_idx[i]); + } + } + + // ------------------------------------------------------------------------------------ + + inline point_transform_affine find_tform_between_shapes ( + const matrix& from_shape, + const matrix& to_shape + ) + { + DLIB_ASSERT(from_shape.size() == to_shape.size() && (from_shape.size()%2) == 0 && from_shape.size() > 0,""); + std::vector > from_points, to_points; + const unsigned long num = from_shape.size()/2; + from_points.reserve(num); + to_points.reserve(num); + if (num == 1) + { + // Just use an identity transform if there is only one landmark. + return point_transform_affine(); + } + + for (unsigned long i = 0; i < num; ++i) + { + from_points.push_back(location(from_shape,i)); + to_points.push_back(location(to_shape,i)); + } + return find_similarity_transform(from_points, to_points); + } + + // ------------------------------------------------------------------------------------ + + inline point_transform_affine normalizing_tform ( + const rectangle& rect + ) + /*! + ensures + - returns a transform that maps rect.tl_corner() to (0,0) and rect.br_corner() + to (1,1). + !*/ + { + std::vector > from_points, to_points; + from_points.push_back(rect.tl_corner()); to_points.push_back(point(0,0)); + from_points.push_back(rect.tr_corner()); to_points.push_back(point(1,0)); + from_points.push_back(rect.br_corner()); to_points.push_back(point(1,1)); + return find_affine_transform(from_points, to_points); + } + + // ------------------------------------------------------------------------------------ + + inline point_transform_affine unnormalizing_tform ( + const rectangle& rect + ) + /*! + ensures + - returns a transform that maps (0,0) to rect.tl_corner() and (1,1) to + rect.br_corner(). + !*/ + { + std::vector > from_points, to_points; + to_points.push_back(rect.tl_corner()); from_points.push_back(point(0,0)); + to_points.push_back(rect.tr_corner()); from_points.push_back(point(1,0)); + to_points.push_back(rect.br_corner()); from_points.push_back(point(1,1)); + return find_affine_transform(from_points, to_points); + } + + // ------------------------------------------------------------------------------------ + + template + void extract_feature_pixel_values ( + const image_type& img_, + const rectangle& rect, + const matrix& current_shape, + const matrix& reference_shape, + const std::vector& reference_pixel_anchor_idx, + const std::vector >& reference_pixel_deltas, + std::vector& feature_pixel_values + ) + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - reference_pixel_anchor_idx.size() == reference_pixel_deltas.size() + - current_shape.size() == reference_shape.size() + - reference_shape.size()%2 == 0 + - max(mat(reference_pixel_anchor_idx)) < reference_shape.size()/2 + ensures + - #feature_pixel_values.size() == reference_pixel_deltas.size() + - for all valid i: + - #feature_pixel_values[i] == the value of the pixel in img_ that + corresponds to the pixel identified by reference_pixel_anchor_idx[i] + and reference_pixel_deltas[i] when the pixel is located relative to + current_shape rather than reference_shape. + !*/ + { + const matrix tform = matrix_cast(find_tform_between_shapes(reference_shape, current_shape).get_m()); + const point_transform_affine tform_to_img = unnormalizing_tform(rect); + + const rectangle area = get_rect(img_); + + const_image_view img(img_); + feature_pixel_values.resize(reference_pixel_deltas.size()); + for (unsigned long i = 0; i < feature_pixel_values.size(); ++i) + { + // Compute the point in the current shape corresponding to the i-th pixel and + // then map it from the normalized shape space into pixel space. + point p = tform_to_img(tform*reference_pixel_deltas[i] + location(current_shape, reference_pixel_anchor_idx[i])); + if (area.contains(p)) + feature_pixel_values[i] = get_pixel_intensity(img[p.y()][p.x()]); + else + feature_pixel_values[i] = 0; + } + } + + } // end namespace impl + +// ---------------------------------------------------------------------------------------- + + class shape_predictor + { + public: + + + shape_predictor ( + ) + {} + + shape_predictor ( + const matrix& initial_shape_, + const std::vector >& forests_, + const std::vector > >& pixel_coordinates + ) : initial_shape(initial_shape_), forests(forests_) + /*! + requires + - initial_shape.size()%2 == 0 + - forests.size() == pixel_coordinates.size() == the number of cascades + - for all valid i: + - all the index values in forests[i] are less than pixel_coordinates[i].size() + - for all valid i and j: + - forests[i][j].leaf_values.size() is a power of 2. + (i.e. we require a tree with all the levels fully filled out. + - forests[i][j].leaf_values.size() == forests[i][j].splits.size()+1 + (i.e. there need to be the right number of leaves given the number of splits in the tree) + !*/ + { + anchor_idx.resize(pixel_coordinates.size()); + deltas.resize(pixel_coordinates.size()); + // Each cascade uses a different set of pixels for its features. We compute + // their representations relative to the initial shape now and save it. + for (unsigned long i = 0; i < pixel_coordinates.size(); ++i) + impl::create_shape_relative_encoding(initial_shape, pixel_coordinates[i], anchor_idx[i], deltas[i]); + } + + unsigned long num_parts ( + ) const + { + return initial_shape.size()/2; + } + + unsigned long num_features ( + ) const + { + unsigned long num = 0; + for (unsigned long iter = 0; iter < forests.size(); ++iter) + for (unsigned long i = 0; i < forests[iter].size(); ++i) + num += forests[iter][i].num_leaves(); + return num; + } + + template + full_object_detection operator()( + const image_type& img, + const rectangle& rect + ) const + { + using namespace impl; + matrix current_shape = initial_shape; + std::vector feature_pixel_values; + for (unsigned long iter = 0; iter < forests.size(); ++iter) + { + extract_feature_pixel_values(img, rect, current_shape, initial_shape, + anchor_idx[iter], deltas[iter], feature_pixel_values); + unsigned long leaf_idx; + // evaluate all the trees at this level of the cascade. + for (unsigned long i = 0; i < forests[iter].size(); ++i) + current_shape += forests[iter][i](feature_pixel_values, leaf_idx); + } + + // convert the current_shape into a full_object_detection + const point_transform_affine tform_to_img = unnormalizing_tform(rect); + std::vector parts(current_shape.size()/2); + for (unsigned long i = 0; i < parts.size(); ++i) + parts[i] = tform_to_img(location(current_shape, i)); + return full_object_detection(rect, parts); + } + + template + full_object_detection operator()( + const image_type& img, + const rectangle& rect, + std::vector >& feats + ) const + { + feats.clear(); + using namespace impl; + matrix current_shape = initial_shape; + std::vector feature_pixel_values; + unsigned long feat_offset = 0; + for (unsigned long iter = 0; iter < forests.size(); ++iter) + { + extract_feature_pixel_values(img, rect, current_shape, initial_shape, + anchor_idx[iter], deltas[iter], feature_pixel_values); + // evaluate all the trees at this level of the cascade. + for (unsigned long i = 0; i < forests[iter].size(); ++i) + { + unsigned long leaf_idx; + current_shape += forests[iter][i](feature_pixel_values, leaf_idx); + + feats.push_back(std::make_pair(feat_offset+leaf_idx, 1)); + feat_offset += forests[iter][i].num_leaves(); + } + } + + // convert the current_shape into a full_object_detection + const point_transform_affine tform_to_img = unnormalizing_tform(rect); + std::vector parts(current_shape.size()/2); + for (unsigned long i = 0; i < parts.size(); ++i) + parts[i] = tform_to_img(location(current_shape, i)); + return full_object_detection(rect, parts); + } + + friend void serialize (const shape_predictor& item, std::ostream& out); + + friend void deserialize (shape_predictor& item, std::istream& in); + + private: + matrix initial_shape; + std::vector > forests; + std::vector > anchor_idx; + std::vector > > deltas; + }; + + inline void serialize (const shape_predictor& item, std::ostream& out) + { + int version = 1; + dlib::serialize(version, out); + dlib::serialize(item.initial_shape, out); + dlib::serialize(item.forests, out); + dlib::serialize(item.anchor_idx, out); + dlib::serialize(item.deltas, out); + } + + inline void deserialize (shape_predictor& item, std::istream& in) + { + int version = 0; + dlib::deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::shape_predictor."); + dlib::deserialize(item.initial_shape, in); + dlib::deserialize(item.forests, in); + dlib::deserialize(item.anchor_idx, in); + dlib::deserialize(item.deltas, in); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_array + > + double test_shape_predictor ( + const shape_predictor& sp, + const image_array& images, + const std::vector >& objects, + const std::vector >& scales + ) + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + DLIB_CASSERT( images.size() == objects.size() , + "\t double test_shape_predictor()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + ); + for (unsigned long i = 0; i < objects.size(); ++i) + { + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + DLIB_CASSERT(objects[i][j].num_parts() == sp.num_parts(), + "\t double test_shape_predictor()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t objects["< rs; + for (unsigned long i = 0; i < objects.size(); ++i) + { + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + // Just use a scale of 1 (i.e. no scale at all) if the caller didn't supply + // any scales. + const double scale = scales.size()==0 ? 1 : scales[i][j]; + + full_object_detection det = sp(images[i], objects[i][j].get_rect()); + + for (unsigned long k = 0; k < det.num_parts(); ++k) + { + if (objects[i][j].part(k) != OBJECT_PART_NOT_PRESENT) + { + double score = length(det.part(k) - objects[i][j].part(k))/scale; + rs.add(score); + } + } + } + } + return rs.mean(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array + > + double test_shape_predictor ( + const shape_predictor& sp, + const image_array& images, + const std::vector >& objects + ) + { + std::vector > no_scales; + return test_shape_predictor(sp, images, objects, no_scales); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SHAPE_PREDICToR_H_ + diff --git a/ml/dlib/dlib/image_processing/shape_predictor_abstract.h b/ml/dlib/dlib/image_processing/shape_predictor_abstract.h new file mode 100644 index 000000000..718b4952e --- /dev/null +++ b/ml/dlib/dlib/image_processing/shape_predictor_abstract.h @@ -0,0 +1,195 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SHAPE_PREDICToR_ABSTRACT_H_ +#ifdef DLIB_SHAPE_PREDICToR_ABSTRACT_H_ + +#include "full_object_detection_abstract.h" +#include "../matrix.h" +#include "../geometry.h" +#include "../pixel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class shape_predictor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool that takes in an image region containing some object + and outputs a set of point locations that define the pose of the object. + The classic example of this is human face pose prediction, where you take + an image of a human face as input and are expected to identify the + locations of important facial landmarks such as the corners of the mouth + and eyes, tip of the nose, and so forth. + + To create useful instantiations of this object you need to use the + shape_predictor_trainer object defined in the + shape_predictor_trainer_abstract.h file to train a shape_predictor using a + set of training images, each annotated with shapes you want to predict. + + THREAD SAFETY + No synchronization is required when using this object. In particular, a + single instance of this object can be used from multiple threads at the + same time. + !*/ + + public: + + shape_predictor ( + ); + /*! + ensures + - #num_parts() == 0 + - #num_features() == 0 + !*/ + + unsigned long num_parts ( + ) const; + /*! + ensures + - returns the number of parts in the shapes predicted by this object. + !*/ + + unsigned long num_features ( + ) const; + /*! + ensures + - Returns the dimensionality of the feature vector output by operator(). + This number is the total number of trees in this object times the number + of leaves on each tree. + !*/ + + template + full_object_detection operator()( + const image_type& img, + const rectangle& rect, + std::vector >& feats + ) const; + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - T is some unsigned integral type (e.g. unsigned int). + - U is any scalar type capable of storing the value 1 (e.g. float). + ensures + - Runs the shape prediction algorithm on the part of the image contained in + the given bounding rectangle. So it will try and fit the shape model to + the contents of the given rectangle in the image. For example, if there + is a human face inside the rectangle and you use a face landmarking shape + model then this function will return the locations of the face landmarks + as the parts. So the return value is a full_object_detection DET such + that: + - DET.get_rect() == rect + - DET.num_parts() == num_parts() + - for all valid i: + - DET.part(i) == the location in img for the i-th part of the shape + predicted by this object. + - #feats == a sparse vector that records which leaf each tree used to make + the shape prediction. Moreover, it is an indicator vector, Therefore, + for all valid i: + - #feats[i].second == 1 + Further, #feats is a vector from the space of num_features() dimensional + vectors. The output shape positions can be represented as the dot + product between #feats and a weight vector. Therefore, #feats encodes + all the information from img that was used to predict the returned shape + object. + !*/ + + template + full_object_detection operator()( + const image_type& img, + const rectangle& rect + ) const; + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - Calling this function is equivalent to calling (*this)(img, rect, ignored) + where the 3d argument is discarded. + !*/ + + }; + + void serialize (const shape_predictor& item, std::ostream& out); + void deserialize (shape_predictor& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_array + > + double test_shape_predictor ( + const shape_predictor& sp, + const image_array& images, + const std::vector >& objects, + const std::vector >& scales + ); + /*! + requires + - image_array is a dlib::array of image objects where each image object + implements the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + - for all valid i and j: + - objects[i][j].num_parts() == sp.num_parts() + - if (scales.size() != 0) then + - There must be a scale value for each full_object_detection in objects. + That is, it must be the case that: + - scales.size() == objects.size() + - for all valid i: + - scales[i].size() == objects[i].size() + ensures + - Tests the given shape_predictor by running it on each of the given objects and + checking how well it recovers the part positions. In particular, for all + valid i and j we perform: + sp(images[i], objects[i][j].get_rect()) + and compare the result with the truth part positions in objects[i][j]. We + then return the average distance (measured in pixels) between a predicted + part location and its true position. + - Note that any parts in objects that are set to OBJECT_PART_NOT_PRESENT are + simply ignored. + - if (scales.size() != 0) then + - Each time we compute the distance between a predicted part location and + its true location in objects[i][j] we divide the distance by + scales[i][j]. Therefore, if you want the reported error to be the + average pixel distance then give an empty scales vector, but if you want + the returned value to be something else like the average distance + normalized by some feature of each object (e.g. the interocular distance) + then you can supply those normalizing values via scales. + !*/ + + template < + typename image_array + > + double test_shape_predictor ( + const shape_predictor& sp, + const image_array& images, + const std::vector >& objects + ); + /*! + requires + - image_array is a dlib::array of image objects where each image object + implements the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + - for all valid i and j: + - objects[i][j].num_parts() == sp.num_parts() + ensures + - returns test_shape_predictor(sp, images, objects, no_scales) where no_scales + is an empty vector. So this is just a convenience function for calling the + above test_shape_predictor() routine without a scales argument. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SHAPE_PREDICToR_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/image_processing/shape_predictor_trainer.h b/ml/dlib/dlib/image_processing/shape_predictor_trainer.h new file mode 100644 index 000000000..3090998f9 --- /dev/null +++ b/ml/dlib/dlib/image_processing/shape_predictor_trainer.h @@ -0,0 +1,852 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SHAPE_PREDICToR_TRAINER_H_ +#define DLIB_SHAPE_PREDICToR_TRAINER_H_ + +#include "shape_predictor_trainer_abstract.h" +#include "shape_predictor.h" +#include "../console_progress_indicator.h" +#include "../threads.h" +#include "../data_io/image_dataset_metadata.h" +#include "box_overlap_testing.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class shape_predictor_trainer + { + /*! + This thing really only works with unsigned char or rgb_pixel images (since we assume the threshold + should be in the range [-128,128]). + !*/ + public: + + enum padding_mode_t + { + bounding_box_relative, + landmark_relative + }; + + shape_predictor_trainer ( + ) + { + _cascade_depth = 10; + _tree_depth = 4; + _num_trees_per_cascade_level = 500; + _nu = 0.1; + _oversampling_amount = 20; + _feature_pool_size = 400; + _lambda = 0.1; + _num_test_splits = 20; + _feature_pool_region_padding = 0; + _verbose = false; + _num_threads = 0; + _padding_mode = landmark_relative; + } + + unsigned long get_cascade_depth ( + ) const { return _cascade_depth; } + + void set_cascade_depth ( + unsigned long depth + ) + { + DLIB_CASSERT(depth > 0, + "\t void shape_predictor_trainer::set_cascade_depth()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t depth: " << depth + ); + + _cascade_depth = depth; + } + + unsigned long get_tree_depth ( + ) const { return _tree_depth; } + + void set_tree_depth ( + unsigned long depth + ) + { + DLIB_CASSERT(depth > 0, + "\t void shape_predictor_trainer::set_tree_depth()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t depth: " << depth + ); + + _tree_depth = depth; + } + + unsigned long get_num_trees_per_cascade_level ( + ) const { return _num_trees_per_cascade_level; } + + void set_num_trees_per_cascade_level ( + unsigned long num + ) + { + DLIB_CASSERT( num > 0, + "\t void shape_predictor_trainer::set_num_trees_per_cascade_level()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t num: " << num + ); + _num_trees_per_cascade_level = num; + } + + double get_nu ( + ) const { return _nu; } + void set_nu ( + double nu + ) + { + DLIB_CASSERT(0 < nu && nu <= 1, + "\t void shape_predictor_trainer::set_nu()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t nu: " << nu + ); + + _nu = nu; + } + + std::string get_random_seed ( + ) const { return rnd.get_seed(); } + void set_random_seed ( + const std::string& seed + ) { rnd.set_seed(seed); } + + unsigned long get_oversampling_amount ( + ) const { return _oversampling_amount; } + void set_oversampling_amount ( + unsigned long amount + ) + { + DLIB_CASSERT(amount > 0, + "\t void shape_predictor_trainer::set_oversampling_amount()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t amount: " << amount + ); + + _oversampling_amount = amount; + } + + unsigned long get_feature_pool_size ( + ) const { return _feature_pool_size; } + void set_feature_pool_size ( + unsigned long size + ) + { + DLIB_CASSERT(size > 1, + "\t void shape_predictor_trainer::set_feature_pool_size()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t size: " << size + ); + + _feature_pool_size = size; + } + + double get_lambda ( + ) const { return _lambda; } + void set_lambda ( + double lambda + ) + { + DLIB_CASSERT(lambda > 0, + "\t void shape_predictor_trainer::set_lambda()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t lambda: " << lambda + ); + + _lambda = lambda; + } + + unsigned long get_num_test_splits ( + ) const { return _num_test_splits; } + void set_num_test_splits ( + unsigned long num + ) + { + DLIB_CASSERT(num > 0, + "\t void shape_predictor_trainer::set_num_test_splits()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t num: " << num + ); + + _num_test_splits = num; + } + + void set_padding_mode ( + padding_mode_t mode + ) + { + _padding_mode = mode; + } + + padding_mode_t get_padding_mode ( + ) const { return _padding_mode; } + + double get_feature_pool_region_padding ( + ) const { return _feature_pool_region_padding; } + void set_feature_pool_region_padding ( + double padding + ) + { + DLIB_CASSERT(padding > -0.5, + "\t void shape_predictor_trainer::set_feature_pool_region_padding()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t padding: " << padding + ); + + _feature_pool_region_padding = padding; + } + + void be_verbose ( + ) + { + _verbose = true; + } + + void be_quiet ( + ) + { + _verbose = false; + } + + unsigned long get_num_threads ( + ) const { return _num_threads; } + void set_num_threads ( + unsigned long num + ) + { + _num_threads = num; + } + + template + shape_predictor train ( + const image_array& images, + const std::vector >& objects + ) const + { + using namespace impl; + DLIB_CASSERT(images.size() == objects.size() && images.size() > 0, + "\t shape_predictor shape_predictor_trainer::train()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + ); + // make sure the objects agree on the number of parts and that there is at + // least one full_object_detection. + unsigned long num_parts = 0; + std::vector part_present; + for (unsigned long i = 0; i < objects.size(); ++i) + { + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + if (num_parts == 0) + { + num_parts = objects[i][j].num_parts(); + DLIB_CASSERT(objects[i][j].num_parts() != 0, + "\t shape_predictor shape_predictor_trainer::train()" + << "\n\t You can't give objects that don't have any parts to the trainer." + ); + part_present.resize(num_parts); + } + else + { + DLIB_CASSERT(objects[i][j].num_parts() == num_parts, + "\t shape_predictor shape_predictor_trainer::train()" + << "\n\t All the objects must agree on the number of parts. " + << "\n\t objects["< 1 ? _num_threads : 0); + + // determining the type of features used for this type of images + typedef typename std::remove_const::type>::type image_type; + typedef typename image_traits::pixel_type pixel_type; + typedef typename pixel_traits::basic_pixel_type feature_type; + + rnd.set_seed(get_random_seed()); + + std::vector> samples; + const matrix initial_shape = populate_training_sample_shapes(objects, samples); + const std::vector > > pixel_coordinates = randomly_sample_pixel_coordinates(initial_shape); + + unsigned long trees_fit_so_far = 0; + console_progress_indicator pbar(get_cascade_depth()*get_num_trees_per_cascade_level()); + if (_verbose) + std::cout << "Fitting trees..." << std::endl; + + std::vector > forests(get_cascade_depth()); + // Now start doing the actual training by filling in the forests + for (unsigned long cascade = 0; cascade < get_cascade_depth(); ++cascade) + { + // Each cascade uses a different set of pixels for its features. We compute + // their representations relative to the initial shape first. + std::vector anchor_idx; + std::vector > deltas; + create_shape_relative_encoding(initial_shape, pixel_coordinates[cascade], anchor_idx, deltas); + + // First compute the feature_pixel_values for each training sample at this + // level of the cascade. + parallel_for(tp, 0, samples.size(), [&](unsigned long i) + { + impl::extract_feature_pixel_values(images[samples[i].image_idx], samples[i].rect, + samples[i].current_shape, initial_shape, anchor_idx, + deltas, samples[i].feature_pixel_values); + }, 1); + + // Now start building the trees at this cascade level. + for (unsigned long i = 0; i < get_num_trees_per_cascade_level(); ++i) + { + forests[cascade].push_back(make_regression_tree(tp, samples, pixel_coordinates[cascade])); + + if (_verbose) + { + ++trees_fit_so_far; + pbar.print_status(trees_fit_so_far); + } + } + } + + if (_verbose) + std::cout << "Training complete " << std::endl; + + return shape_predictor(initial_shape, forests, pixel_coordinates); + } + + private: + + static void object_to_shape ( + const full_object_detection& obj, + matrix& shape, + matrix& present // a mask telling which elements of #shape are present. + ) + { + shape.set_size(obj.num_parts()*2); + present.set_size(obj.num_parts()*2); + const point_transform_affine tform_from_img = impl::normalizing_tform(obj.get_rect()); + for (unsigned long i = 0; i < obj.num_parts(); ++i) + { + if (obj.part(i) != OBJECT_PART_NOT_PRESENT) + { + vector p = tform_from_img(obj.part(i)); + shape(2*i) = p.x(); + shape(2*i+1) = p.y(); + present(2*i) = 1; + present(2*i+1) = 1; + + if (length(p) > 100) + { + std::cout << "Warning, one of your objects has parts that are way outside its bounding box! This is probably an error in your annotation." << std::endl; + } + } + else + { + shape(2*i) = 0; + shape(2*i+1) = 0; + present(2*i) = 0; + present(2*i+1) = 0; + } + } + } + + template + struct training_sample + { + /*! + + CONVENTION + - feature_pixel_values.size() == get_feature_pool_size() + - feature_pixel_values[j] == the value of the j-th feature pool + pixel when you look it up relative to the shape in current_shape. + + - target_shape == The truth shape. Stays constant during the whole + training process (except for the parts that are not present, those are + always equal to the current_shape values). + - present == 0/1 mask saying which parts of target_shape are present. + - rect == the position of the object in the image_idx-th image. All shape + coordinates are coded relative to this rectangle. + - diff_shape == temporary value for holding difference between current + shape and target shape + !*/ + + unsigned long image_idx; + rectangle rect; + matrix target_shape; + matrix present; + + matrix current_shape; + matrix diff_shape; + std::vector feature_pixel_values; + + void swap(training_sample& item) + { + std::swap(image_idx, item.image_idx); + std::swap(rect, item.rect); + target_shape.swap(item.target_shape); + present.swap(item.present); + current_shape.swap(item.current_shape); + diff_shape.swap(item.diff_shape); + feature_pixel_values.swap(item.feature_pixel_values); + } + }; + + template + impl::regression_tree make_regression_tree ( + thread_pool& tp, + std::vector>& samples, + const std::vector >& pixel_coordinates + ) const + { + using namespace impl; + std::deque > parts; + parts.push_back(std::make_pair(0, (unsigned long)samples.size())); + + impl::regression_tree tree; + + // walk the tree in breadth first order + const unsigned long num_split_nodes = static_cast(std::pow(2.0, (double)get_tree_depth())-1); + std::vector > sums(num_split_nodes*2+1); + if (tp.num_threads_in_pool() > 1) + { + // Here we need to calculate shape differences and store sum of differences into sums[0] + // to make it. I am splitting samples into blocks, each block will be processed by + // separate thread, and the sum of differences of each block is stored into separate + // place in block_sums + + const unsigned long num_workers = std::max(1UL, tp.num_threads_in_pool()); + const unsigned long num = samples.size(); + const unsigned long block_size = std::max(1UL, (num + num_workers - 1) / num_workers); + std::vector > block_sums(num_workers); + + parallel_for(tp, 0, num_workers, [&](unsigned long block) + { + const unsigned long block_begin = block * block_size; + const unsigned long block_end = std::min(num, block_begin + block_size); + for (unsigned long i = block_begin; i < block_end; ++i) + { + samples[i].diff_shape = samples[i].target_shape - samples[i].current_shape; + block_sums[block] += samples[i].diff_shape; + } + }, 1); + + // now calculate the total result from separate blocks + for (unsigned long i = 0; i < block_sums.size(); ++i) + sums[0] += block_sums[i]; + } + else + { + // synchronous implementation + for (unsigned long i = 0; i < samples.size(); ++i) + { + samples[i].diff_shape = samples[i].target_shape - samples[i].current_shape; + sums[0] += samples[i].diff_shape; + } + } + + for (unsigned long i = 0; i < num_split_nodes; ++i) + { + std::pair range = parts.front(); + parts.pop_front(); + + const impl::split_feature split = generate_split(tp, samples, range.first, + range.second, pixel_coordinates, sums[i], sums[left_child(i)], + sums[right_child(i)]); + tree.splits.push_back(split); + const unsigned long mid = partition_samples(split, samples, range.first, range.second); + + parts.push_back(std::make_pair(range.first, mid)); + parts.push_back(std::make_pair(mid, range.second)); + } + + // Now all the parts contain the ranges for the leaves so we can use them to + // compute the average leaf values. + matrix present_counts(samples[0].target_shape.size()); + tree.leaf_values.resize(parts.size()); + for (unsigned long i = 0; i < parts.size(); ++i) + { + // Get the present counts for each dimension so we can divide each + // dimension by the number of observations we have on it to find the mean + // displacement in each leaf. + present_counts = 0; + for (unsigned long j = parts[i].first; j < parts[i].second; ++j) + present_counts += samples[j].present; + present_counts = dlib::reciprocal(present_counts); + + if (parts[i].second != parts[i].first) + tree.leaf_values[i] = pointwise_multiply(present_counts,sums[num_split_nodes+i]*get_nu()); + else + tree.leaf_values[i] = zeros_matrix(samples[0].target_shape); + + // now adjust the current shape based on these predictions + parallel_for(tp, parts[i].first, parts[i].second, [&](unsigned long j) + { + samples[j].current_shape += tree.leaf_values[i]; + // For parts that aren't present in the training data, we just make + // sure that the target shape always matches and therefore gives zero + // error. So this makes the algorithm simply ignore non-present + // landmarks. + for (long k = 0; k < samples[j].present.size(); ++k) + { + // if this part is not present + if (samples[j].present(k) == 0) + samples[j].target_shape(k) = samples[j].current_shape(k); + } + }, 1); + } + + return tree; + } + + impl::split_feature randomly_generate_split_feature ( + const std::vector >& pixel_coordinates + ) const + { + const double lambda = get_lambda(); + impl::split_feature feat; + const size_t max_iters = get_feature_pool_size()*get_feature_pool_size(); + for (size_t i = 0; i < max_iters; ++i) + { + feat.idx1 = rnd.get_integer(get_feature_pool_size()); + feat.idx2 = rnd.get_integer(get_feature_pool_size()); + while (feat.idx1 == feat.idx2) + feat.idx2 = rnd.get_integer(get_feature_pool_size()); + const double dist = length(pixel_coordinates[feat.idx1]-pixel_coordinates[feat.idx2]); + const double accept_prob = std::exp(-dist/lambda); + if (accept_prob > rnd.get_random_double()) + break; + } + + feat.thresh = (rnd.get_random_double()*256 - 128)/2.0; + + return feat; + } + + template + impl::split_feature generate_split ( + thread_pool& tp, + const std::vector>& samples, + unsigned long begin, + unsigned long end, + const std::vector >& pixel_coordinates, + const matrix& sum, + matrix& left_sum, + matrix& right_sum + ) const + { + // generate a bunch of random splits and test them and return the best one. + + const unsigned long num_test_splits = get_num_test_splits(); + + // sample the random features we test in this function + std::vector feats; + feats.reserve(num_test_splits); + for (unsigned long i = 0; i < num_test_splits; ++i) + feats.push_back(randomly_generate_split_feature(pixel_coordinates)); + + std::vector > left_sums(num_test_splits); + std::vector left_cnt(num_test_splits); + + const unsigned long num_workers = std::max(1UL, tp.num_threads_in_pool()); + const unsigned long block_size = std::max(1UL, (num_test_splits + num_workers - 1) / num_workers); + + // now compute the sums of vectors that go left for each feature + parallel_for(tp, 0, num_workers, [&](unsigned long block) + { + const unsigned long block_begin = block * block_size; + const unsigned long block_end = std::min(block_begin + block_size, num_test_splits); + + for (unsigned long j = begin; j < end; ++j) + { + for (unsigned long i = block_begin; i < block_end; ++i) + { + if ((float)samples[j].feature_pixel_values[feats[i].idx1] - (float)samples[j].feature_pixel_values[feats[i].idx2] > feats[i].thresh) + { + left_sums[i] += samples[j].diff_shape; + ++left_cnt[i]; + } + } + } + + }, 1); + + // now figure out which feature is the best + double best_score = -1; + unsigned long best_feat = 0; + matrix temp; + for (unsigned long i = 0; i < num_test_splits; ++i) + { + // check how well the feature splits the space. + double score = 0; + unsigned long right_cnt = end-begin-left_cnt[i]; + if (left_cnt[i] != 0 && right_cnt != 0) + { + temp = sum - left_sums[i]; + score = dot(left_sums[i],left_sums[i])/left_cnt[i] + dot(temp,temp)/right_cnt; + if (score > best_score) + { + best_score = score; + best_feat = i; + } + } + } + + left_sums[best_feat].swap(left_sum); + if (left_sum.size() != 0) + { + right_sum = sum - left_sum; + } + else + { + right_sum = sum; + left_sum = zeros_matrix(sum); + } + return feats[best_feat]; + } + + template + unsigned long partition_samples ( + const impl::split_feature& split, + std::vector>& samples, + unsigned long begin, + unsigned long end + ) const + { + // splits samples based on split (sorta like in quick sort) and returns the mid + // point. make sure you return the mid in a way compatible with how we walk + // through the tree. + + unsigned long i = begin; + for (unsigned long j = begin; j < end; ++j) + { + if ((float)samples[j].feature_pixel_values[split.idx1] - (float)samples[j].feature_pixel_values[split.idx2] > split.thresh) + { + samples[i].swap(samples[j]); + ++i; + } + } + return i; + } + + + + template + matrix populate_training_sample_shapes( + const std::vector >& objects, + std::vector>& samples + ) const + { + samples.clear(); + matrix mean_shape; + matrix count; + // first fill out the target shapes + for (unsigned long i = 0; i < objects.size(); ++i) + { + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + training_sample sample; + sample.image_idx = i; + sample.rect = objects[i][j].get_rect(); + object_to_shape(objects[i][j], sample.target_shape, sample.present); + for (unsigned long itr = 0; itr < get_oversampling_amount(); ++itr) + samples.push_back(sample); + mean_shape += sample.target_shape; + count += sample.present; + } + } + + mean_shape = pointwise_multiply(mean_shape,reciprocal(count)); + + // now go pick random initial shapes + for (unsigned long i = 0; i < samples.size(); ++i) + { + if ((i%get_oversampling_amount()) == 0) + { + // The mean shape is what we really use as an initial shape so always + // include it in the training set as an example starting shape. + samples[i].current_shape = mean_shape; + } + else + { + samples[i].current_shape.set_size(0); + + matrix hits(mean_shape.size()); + hits = 0; + + int iter = 0; + // Pick a few samples at random and randomly average them together to + // make the initial shape. Note that we make sure we get at least one + // observation (i.e. non-OBJECT_PART_NOT_PRESENT) on each part + // location. + while(min(hits) == 0 || iter < 2) + { + ++iter; + const unsigned long rand_idx = rnd.get_random_32bit_number()%samples.size(); + const double alpha = rnd.get_random_double()+0.1; + samples[i].current_shape += alpha*samples[rand_idx].target_shape; + hits += alpha*samples[rand_idx].present; + } + samples[i].current_shape = pointwise_multiply(samples[i].current_shape, reciprocal(hits)); + } + + } + for (unsigned long i = 0; i < samples.size(); ++i) + { + for (long k = 0; k < samples[i].present.size(); ++k) + { + // if this part is not present + if (samples[i].present(k) == 0) + samples[i].target_shape(k) = samples[i].current_shape(k); + } + } + + + return mean_shape; + } + + + void randomly_sample_pixel_coordinates ( + std::vector >& pixel_coordinates, + const double min_x, + const double min_y, + const double max_x, + const double max_y + ) const + /*! + ensures + - #pixel_coordinates.size() == get_feature_pool_size() + - for all valid i: + - pixel_coordinates[i] == a point in the box defined by the min/max x/y arguments. + !*/ + { + pixel_coordinates.resize(get_feature_pool_size()); + for (unsigned long i = 0; i < get_feature_pool_size(); ++i) + { + pixel_coordinates[i].x() = rnd.get_random_double()*(max_x-min_x) + min_x; + pixel_coordinates[i].y() = rnd.get_random_double()*(max_y-min_y) + min_y; + } + } + + std::vector > > randomly_sample_pixel_coordinates ( + const matrix& initial_shape + ) const + { + const double padding = get_feature_pool_region_padding(); + // Figure out the bounds on the object shapes. We will sample uniformly + // from this box. + matrix temp = reshape(initial_shape, initial_shape.size()/2, 2); + double min_x = min(colm(temp,0)); + double min_y = min(colm(temp,1)); + double max_x = max(colm(temp,0)); + double max_y = max(colm(temp,1)); + + if (get_padding_mode() == bounding_box_relative) + { + min_x = std::min(0.0, min_x); + min_y = std::min(0.0, min_y); + max_x = std::max(1.0, max_x); + max_y = std::max(1.0, max_y); + } + + min_x -= padding; + min_y -= padding; + max_x += padding; + max_y += padding; + + std::vector > > pixel_coordinates; + pixel_coordinates.resize(get_cascade_depth()); + for (unsigned long i = 0; i < get_cascade_depth(); ++i) + randomly_sample_pixel_coordinates(pixel_coordinates[i], min_x, min_y, max_x, max_y); + return pixel_coordinates; + } + + + + mutable dlib::rand rnd; + + unsigned long _cascade_depth; + unsigned long _tree_depth; + unsigned long _num_trees_per_cascade_level; + double _nu; + unsigned long _oversampling_amount; + unsigned long _feature_pool_size; + double _lambda; + unsigned long _num_test_splits; + double _feature_pool_region_padding; + bool _verbose; + unsigned long _num_threads; + padding_mode_t _padding_mode; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename some_type_of_rectangle + > + image_dataset_metadata::dataset make_bounding_box_regression_training_data ( + const image_dataset_metadata::dataset& truth, + const std::vector>& detections + ) + { + DLIB_CASSERT(truth.images.size() == detections.size(), + "truth.images.size(): "<< truth.images.size() << + "\tdetections.size(): "<< detections.size() + ); + image_dataset_metadata::dataset result = truth; + + for (size_t i = 0; i < truth.images.size(); ++i) + { + result.images[i].boxes.clear(); + for (auto truth_box : truth.images[i].boxes) + { + if (truth_box.ignore) + continue; + + // Find the detection that best matches the current truth_box. + auto det = max_scoring_element(detections[i], [&truth_box](const rectangle& r) { return box_intersection_over_union(r, truth_box.rect); }); + if (det.second > 0.5) + { + // Remove any existing parts and replace them with the truth_box corners. + truth_box.parts.clear(); + auto b = truth_box.rect; + truth_box.parts["left"] = (b.tl_corner()+b.bl_corner())/2; + truth_box.parts["right"] = (b.tr_corner()+b.br_corner())/2; + truth_box.parts["top"] = (b.tl_corner()+b.tr_corner())/2; + truth_box.parts["bottom"] = (b.bl_corner()+b.br_corner())/2; + truth_box.parts["middle"] = center(b); + + // Now replace the bounding truth_box with the detector's bounding truth_box. + truth_box.rect = det.first; + + result.images[i].boxes.push_back(truth_box); + } + } + } + return result; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SHAPE_PREDICToR_TRAINER_H_ + diff --git a/ml/dlib/dlib/image_processing/shape_predictor_trainer_abstract.h b/ml/dlib/dlib/image_processing/shape_predictor_trainer_abstract.h new file mode 100644 index 000000000..278b97842 --- /dev/null +++ b/ml/dlib/dlib/image_processing/shape_predictor_trainer_abstract.h @@ -0,0 +1,418 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SHAPE_PREDICToR_TRAINER_ABSTRACT_H_ +#ifdef DLIB_SHAPE_PREDICToR_TRAINER_ABSTRACT_H_ + +#include "shape_predictor_abstract.h" +#include "../data_io/image_dataset_metadata.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class shape_predictor_trainer + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for training shape_predictors based on annotated training + images. Its implementation uses the algorithm described in: + One Millisecond Face Alignment with an Ensemble of Regression Trees + by Vahid Kazemi and Josephine Sullivan, CVPR 2014 + + !*/ + + public: + + shape_predictor_trainer ( + ); + /*! + ensures + - #get_cascade_depth() == 10 + - #get_tree_depth() == 4 + - #get_num_trees_per_cascade_level() == 500 + - #get_nu() == 0.1 + - #get_oversampling_amount() == 20 + - #get_feature_pool_size() == 400 + - #get_lambda() == 0.1 + - #get_num_test_splits() == 20 + - #get_feature_pool_region_padding() == 0 + - #get_random_seed() == "" + - #get_num_threads() == 0 + - #get_padding_mode() == landmark_relative + - This object will not be verbose + !*/ + + unsigned long get_cascade_depth ( + ) const; + /*! + ensures + - returns the number of cascades created when you train a model. This + means that the total number of trees in the learned model is equal to + get_cascade_depth()*get_num_trees_per_cascade_level(). + !*/ + + void set_cascade_depth ( + unsigned long depth + ); + /*! + requires + - depth > 0 + ensures + - #get_cascade_depth() == depth + !*/ + + unsigned long get_tree_depth ( + ) const; + /*! + ensures + - returns the depth of the trees used in the cascade. In particular, there + are pow(2,get_tree_depth()) leaves in each tree. + !*/ + + void set_tree_depth ( + unsigned long depth + ); + /*! + requires + - depth > 0 + ensures + - #get_tree_depth() == depth + !*/ + + unsigned long get_num_trees_per_cascade_level ( + ) const; + /*! + ensures + - returns the number of trees created for each cascade. This means that + the total number of trees in the learned model is equal to + get_cascade_depth()*get_num_trees_per_cascade_level(). + !*/ + + void set_num_trees_per_cascade_level ( + unsigned long num + ); + /*! + requires + - num > 0 + ensures + - #get_num_trees_per_cascade_level() == num + !*/ + + double get_nu ( + ) const; + /*! + ensures + - returns the regularization parameter. Larger values of this parameter + will cause the algorithm to fit the training data better but may also + cause overfitting. + !*/ + + void set_nu ( + double nu + ); + /*! + requires + - 0 < nu <= 1 + ensures + - #get_nu() == nu + !*/ + + std::string get_random_seed ( + ) const; + /*! + ensures + - returns the random seed used by the internal random number generator. + Since this algorithm is a random forest style algorithm it relies on a + random number generator for generating the trees. So each setting of the + random seed will produce slightly different outputs. + !*/ + + void set_random_seed ( + const std::string& seed + ); + /*! + ensures + - #get_random_seed() == seed + !*/ + + unsigned long get_oversampling_amount ( + ) const; + /*! + ensures + - You give annotated images to this object as training examples. You + can effectively increase the amount of training data by adding in each + training example multiple times but with a randomly selected deformation + applied to it. That is what this parameter controls. That is, if you + supply N training samples to train() then the algorithm runs internally + with N*get_oversampling_amount() training samples. So the bigger this + parameter the better (excepting that larger values make training take + longer). In terms of the Kazemi paper, this parameter is the number of + randomly selected initial starting points sampled for each training + example. + !*/ + + void set_oversampling_amount ( + unsigned long amount + ); + /*! + requires + - amount > 0 + ensures + - #get_oversampling_amount() == amount + !*/ + + unsigned long get_feature_pool_size ( + ) const; + /*! + ensures + - At each level of the cascade we randomly sample get_feature_pool_size() + pixels from the image. These pixels are used to generate features for + the random trees. So in general larger settings of this parameter give + better accuracy but make the algorithm run slower. + !*/ + + void set_feature_pool_size ( + unsigned long size + ); + /*! + requires + - size > 1 + ensures + - #get_feature_pool_size() == size + !*/ + + enum padding_mode_t + { + bounding_box_relative, + landmark_relative + }; + + padding_mode_t get_padding_mode ( + ) const; + /*! + ensures + - returns the current padding mode. See get_feature_pool_region_padding() + for a discussion of the modes. + !*/ + + void set_padding_mode ( + padding_mode_t mode + ); + /*! + ensures + - #get_padding_mode() == mode + !*/ + + double get_feature_pool_region_padding ( + ) const; + /*! + ensures + - This algorithm works by comparing the relative intensity of pairs of + pixels in the input image. To decide which pixels to look at, the + training algorithm randomly selects pixels from a box roughly centered + around the object of interest. We call this box the feature pool region + box. + + Each object of interest is defined by a full_object_detection, which + contains a bounding box and a list of landmarks. If + get_padding_mode()==landmark_relative then the feature pool region box is + the tightest box that contains the landmarks inside the + full_object_detection. In this mode the full_object_detection's bounding + box is ignored. Otherwise, if the padding mode is bounding_box_relative + then the feature pool region box is the tightest box that contains BOTH + the landmarks and the full_object_detection's bounding box. + + Additionally, you can adjust the size of the feature pool padding region + by setting get_feature_pool_region_padding() to some value. If + get_feature_pool_region_padding()==0 then the feature pool region box is + unmodified and defined exactly as stated above. However, you can expand + the size of the box by setting the padding > 0 or shrink it by setting it + to something < 0. + + To explain this precisely, for a padding of 0 we say that the pixels are + sampled from a box of size 1x1. The padding value is added to each side + of the box. So a padding of 0.5 would cause the algorithm to sample + pixels from a box that was 2x2, effectively multiplying the area pixels + are sampled from by 4. Similarly, setting the padding to -0.2 would + cause it to sample from a box 0.6x0.6 in size. + !*/ + + void set_feature_pool_region_padding ( + double padding + ); + /*! + requires + - padding > -0.5 + ensures + - #get_feature_pool_region_padding() == padding + !*/ + + double get_lambda ( + ) const; + /*! + ensures + - To decide how to split nodes in the regression trees the algorithm looks + at pairs of pixels in the image. These pixel pairs are sampled randomly + but with a preference for selecting pixels that are near each other. + get_lambda() controls this "nearness" preference. In particular, smaller + values of get_lambda() will make the algorithm prefer to select pixels + close together and larger values of get_lambda() will make it care less + about picking nearby pixel pairs. + + Note that this is the inverse of how it is defined in the Kazemi paper. + For this object, you should think of lambda as "the fraction of the + bounding box will we traverse to find a neighboring pixel". Nominally, + this is normalized between 0 and 1. So reasonable settings of lambda are + values in the range 0 < lambda < 1. + !*/ + + void set_lambda ( + double lambda + ); + /*! + requires + - lambda > 0 + ensures + - #get_lambda() == lambda + !*/ + + unsigned long get_num_test_splits ( + ) const; + /*! + ensures + - When generating the random trees we randomly sample get_num_test_splits() + possible split features at each node and pick the one that gives the best + split. Larger values of this parameter will usually give more accurate + outputs but take longer to train. + !*/ + + void set_num_test_splits ( + unsigned long num + ); + /*! + requires + - num > 0 + ensures + - #get_num_test_splits() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - When running training process, it is possible to make some parts of it parallel + using CPU threads with #parallel_for() extension and creating #thread_pool internally + When get_num_threads() == 0, trainer will not create threads and all processing will + be done in the calling thread + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + requires + - num >= 0 + ensures + - #get_num_threads() == num + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - This object will not print anything to standard out + !*/ + + template + shape_predictor train ( + const image_array& images, + const std::vector >& objects + ) const; + /*! + requires + - image_array is a dlib::array of image objects where each image object + implements the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + - images.size() > 0 + - for some i: objects[i].size() != 0 + (i.e. there has to be at least one full_object_detection in the training set) + - for all valid p, there must exist i and j such that: + objects[i][j].part(p) != OBJECT_PART_NOT_PRESENT. + (i.e. You can't define a part that is always set to OBJECT_PART_NOT_PRESENT.) + - for all valid i,j,k,l: + - objects[i][j].num_parts() == objects[k][l].num_parts() + (i.e. all objects must agree on the number of parts) + - objects[i][j].num_parts() > 0 + ensures + - This object will try to learn to predict the locations of an object's parts + based on the object bounding box (i.e. full_object_detection::get_rect()) + and the image pixels in that box. That is, we will try to learn a + shape_predictor, SP, such that: + SP(images[i], objects[i][j].get_rect()) == objects[i][j] + This learned SP object is then returned. + - Not all parts are required to be observed for all objects. So if you + have training instances with missing parts then set the part positions + equal to OBJECT_PART_NOT_PRESENT and this algorithm will basically ignore + those missing parts. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename some_type_of_rectangle + > + image_dataset_metadata::dataset make_bounding_box_regression_training_data ( + const image_dataset_metadata::dataset& truth, + const std::vector>& detections + ); + /*! + requires + - truth.images.size() == detections.size() + - some_type_of_rectangle == rectangle, drectangle, mmod_rect, or any other type + that is convertible to a rectangle. + ensures + - Suppose you have an object detector that can roughly locate objects in an + image. This means your detector draws boxes around objects, but these are + *rough* boxes in the sense that they aren't positioned super accurately. For + instance, HOG based detectors usually have a stride of 8 pixels. So the + positional accuracy is going to be, at best, +/-8 pixels. + + If you want to get better positional accuracy one easy thing to do is train a + shape_predictor to give you the location of the object's box. The + make_bounding_box_regression_training_data() routine helps you do this by + creating an appropriate training dataset. It does this by taking the dataset + you used to train your detector (given by the truth object), and combining + that with the output of your detector on each image in the training dataset + (given by the detections object). In particular, it will create a new + annotated dataset where each object box is one of the rectangles from + detections and that object has 5 part annotations. These annotations + identify the sides and middle of the truth rectangle corresponding to the + detection rectangle. You can then take the returned dataset and train a + shape_predictor on it. The resulting shape_predictor can then be used to do + bounding box regression. + + As an aside, the reason we create 5 part annotations in this way is because + it gives the best shape_predictor when trained. If instead you used the 4 + corners it wouldn't work as well, due to tedious vagaries of the shape_predictor + training process. + + - We assume that detections[i] contains object detections corresponding to + the image truth.images[i]. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SHAPE_PREDICToR_TRAINER_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/image_saver/dng_shared.h b/ml/dlib/dlib/image_saver/dng_shared.h new file mode 100644 index 000000000..d098851b3 --- /dev/null +++ b/ml/dlib/dlib/image_saver/dng_shared.h @@ -0,0 +1,288 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNG_SHAREd_ +#define DLIB_DNG_SHAREd_ + +#include "../pixel.h" +#include +#include "../uintn.h" + +namespace dlib +{ + + namespace dng_helpers_namespace + { + enum + { + grayscale = 1, + rgb, + hsi, + rgb_paeth, + rgb_alpha, + rgb_alpha_paeth, + grayscale_16bit, + grayscale_float + }; + + const unsigned long dng_magic_byte = 100; + + template + rgb_pixel predictor_rgb_paeth (const T& img, long row, long col) + /* + This is similar to the Paeth filter from the PNG image format. + */ + { + // a = left, b = above, c = upper left + rgb_pixel a(0,0,0), b(0,0,0), c(0,0,0); + + + const long c1 = col-1; + const long r1 = row-1; + + if (c1 >= 0) + assign_pixel(a, img[row][c1]); + else + assign_pixel(a,(unsigned char)0); + + if (c1 >= 0 && r1 >= 0) + assign_pixel(c, img[r1][c1]); + else + assign_pixel(c,(unsigned char)0); + + if (r1 >= 0) + assign_pixel(b, img[r1][col]); + else + assign_pixel(b,(unsigned char)0); + + + rgb_pixel p; + p.red = a.red + b.red - c.red; + p.green = a.green + b.green - c.green; + p.blue = a.blue + b.blue - c.blue; + + short pa = std::abs((short)p.red - (short)a.red) + + std::abs((short)p.green - (short)a.green) + + std::abs((short)p.blue - (short)a.blue); + short pb = std::abs((short)p.red - (short)b.red) + + std::abs((short)p.green - (short)b.green) + + std::abs((short)p.blue - (short)b.blue); + short pc = std::abs((short)p.red - (short)c.red) + + std::abs((short)p.green - (short)c.green) + + std::abs((short)p.blue - (short)c.blue); + + if (pa <= pb && pa <= pc) + return a; + else if (pb <= pc) + return b; + else + return c; + } + + + template + rgb_pixel predictor_rgb (const T& img, long row, long col) + { + // a = left, b = above, c = upper left + rgb_pixel a(0,0,0), b(0,0,0), c(0,0,0); + + + const long c1 = col-1; + const long r1 = row-1; + + if (c1 >= 0) + assign_pixel(a, img[row][c1]); + else + assign_pixel(a,(unsigned char)0); + + if (c1 >= 0 && r1 >= 0) + assign_pixel(c, img[r1][c1]); + else + assign_pixel(c,(unsigned char)0); + + if (r1 >= 0) + assign_pixel(b, img[r1][col]); + else + assign_pixel(b,(unsigned char)0); + + + rgb_pixel p; + p.red = a.red + b.red - c.red; + p.green = a.green + b.green - c.green; + p.blue = a.blue + b.blue - c.blue; + return p; + } + + template + rgb_alpha_pixel predictor_rgb_alpha_paeth (const T& img, long row, long col) + /* + This is similar to the Paeth filter from the PNG image format. + */ + { + // a = left, b = above, c = upper left + rgb_alpha_pixel a, b, c; + + + const long c1 = col-1; + const long r1 = row-1; + + if (c1 >= 0) + assign_pixel(a, img[row][c1]); + else + assign_pixel(a,(unsigned char)0); + + if (c1 >= 0 && r1 >= 0) + assign_pixel(c, img[r1][c1]); + else + assign_pixel(c,(unsigned char)0); + + if (r1 >= 0) + assign_pixel(b, img[r1][col]); + else + assign_pixel(b,(unsigned char)0); + + + rgb_alpha_pixel p; + p.red = a.red + b.red - c.red; + p.green = a.green + b.green - c.green; + p.blue = a.blue + b.blue - c.blue; + + short pa = std::abs((short)p.red - (short)a.red) + + std::abs((short)p.green - (short)a.green) + + std::abs((short)p.blue - (short)a.blue); + short pb = std::abs((short)p.red - (short)b.red) + + std::abs((short)p.green - (short)b.green) + + std::abs((short)p.blue - (short)b.blue); + short pc = std::abs((short)p.red - (short)c.red) + + std::abs((short)p.green - (short)c.green) + + std::abs((short)p.blue - (short)c.blue); + + if (pa <= pb && pa <= pc) + return a; + else if (pb <= pc) + return b; + else + return c; + } + + + template + rgb_alpha_pixel predictor_rgb_alpha (const T& img, long row, long col) + { + // a = left, b = above, c = upper left + rgb_alpha_pixel a, b, c; + + + const long c1 = col-1; + const long r1 = row-1; + + if (c1 >= 0) + assign_pixel(a, img[row][c1]); + else + assign_pixel(a,(unsigned char)0); + + if (c1 >= 0 && r1 >= 0) + assign_pixel(c, img[r1][c1]); + else + assign_pixel(c,(unsigned char)0); + + if (r1 >= 0) + assign_pixel(b, img[r1][col]); + else + assign_pixel(b,(unsigned char)0); + + + rgb_alpha_pixel p; + p.red = a.red + b.red - c.red; + p.green = a.green + b.green - c.green; + p.blue = a.blue + b.blue - c.blue; + p.alpha = a.alpha + b.alpha - c.alpha; + return p; + } + + + template + hsi_pixel predictor_hsi (const T& img, long row, long col) + { + // a = left, b = above, c = upper left + hsi_pixel a(0,0,0), b(0,0,0), c(0,0,0); + + + const long c1 = col-1; + const long r1 = row-1; + + if (c1 >= 0) + assign_pixel(a, img[row][c1]); + else + assign_pixel(a,(unsigned char)0); + + if (c1 >= 0 && r1 >= 0) + assign_pixel(c, img[r1][c1]); + else + assign_pixel(c,(unsigned char)0); + + if (r1 >= 0) + assign_pixel(b, img[r1][col]); + else + assign_pixel(b,(unsigned char)0); + + + hsi_pixel p; + p.h = a.h + b.h - c.h; + p.s = a.s + b.s - c.s; + p.i = a.i + b.i - c.i; + return p; + } + + template + unsigned char predictor_grayscale (const T& img, long row, long col) + { + // a = left, b = above, c = upper left + unsigned char a = 0, b = 0, c = 0; + + + const long c1 = col-1; + const long r1 = row-1; + + if (c1 >= 0) + assign_pixel(a, img[row][c1]); + + if (c1 >= 0 && r1 >= 0) + assign_pixel(c, img[r1][c1]); + + if (r1 >= 0) + assign_pixel(b, img[r1][col]); + + + unsigned char p = a + b - c; + return p; + } + + template + uint16 predictor_grayscale_16 (const T& img, long row, long col) + { + // a = left, b = above, c = upper left + uint16 a = 0, b = 0, c = 0; + + + const long c1 = col-1; + const long r1 = row-1; + + if (c1 >= 0) + assign_pixel(a, img[row][c1]); + + if (c1 >= 0 && r1 >= 0) + assign_pixel(c, img[r1][c1]); + + if (r1 >= 0) + assign_pixel(b, img[r1][col]); + + + uint16 p = a + b - c; + return p; + } + + } +} + +#endif // DLIB_DNG_SHAREd_ + diff --git a/ml/dlib/dlib/image_saver/image_saver.h b/ml/dlib/dlib/image_saver/image_saver.h new file mode 100644 index 000000000..43a2717af --- /dev/null +++ b/ml/dlib/dlib/image_saver/image_saver.h @@ -0,0 +1,688 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMAGE_SAVEr_ +#define DLIB_IMAGE_SAVEr_ + +#include "image_saver_abstract.h" +#include +#include +#include +#include "../algs.h" +#include "../pixel.h" +#include "../byte_orderer.h" +#include "../entropy_encoder.h" +#include "../entropy_encoder_model.h" +#include "dng_shared.h" +#include "../uintn.h" +#include "../dir_nav.h" +#include "../float_details.h" +#include "../vectorstream.h" +#include "../matrix/matrix_exp.h" +#include "../image_transforms/assign_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class image_save_error : public dlib::error { + public: image_save_error(const std::string& str) : error(EIMAGE_SAVE,str){} + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + bool grayscale = pixel_traits::pixel_type>::grayscale + > + struct save_bmp_helper; + + + template + struct save_bmp_helper + { + static void save_bmp ( + const image_type& image_, + std::ostream& out + ) + { + const_image_view image(image_); + // we are going to write out a 24bit color image. + byte_orderer::kernel_1a bo; + + out.write("BM",2); + + if (!out) + throw image_save_error("error writing image to output stream"); + + + unsigned long pad = 4 - (image.nc()*3)%4; + if (pad == 4) + pad = 0; + + unsigned long bfSize = 14 + 40 + (image.nc()*3 + pad)*image.nr(); + unsigned long bfReserved = 0; + unsigned long bfOffBits = 14 + 40; + unsigned long biSize = 40; + unsigned long biWidth = image.nc(); + unsigned long biHeight = image.nr(); + unsigned short biPlanes = 1; + unsigned short biBitCount = 24; + unsigned long biCompression = 0; + unsigned long biSizeImage = 0; + unsigned long biXPelsPerMeter = 0; + unsigned long biYPelsPerMeter = 0; + unsigned long biClrUsed = 0; + unsigned long biClrImportant = 0; + + bo.host_to_little(bfSize); + bo.host_to_little(bfOffBits); + bo.host_to_little(biSize); + bo.host_to_little(biWidth); + bo.host_to_little(biHeight); + bo.host_to_little(biPlanes); + bo.host_to_little(biBitCount); + + out.write((char*)&bfSize,4); + out.write((char*)&bfReserved,4); + out.write((char*)&bfOffBits,4); + out.write((char*)&biSize,4); + out.write((char*)&biWidth,4); + out.write((char*)&biHeight,4); + out.write((char*)&biPlanes,2); + out.write((char*)&biBitCount,2); + out.write((char*)&biCompression,4); + out.write((char*)&biSizeImage,4); + out.write((char*)&biXPelsPerMeter,4); + out.write((char*)&biYPelsPerMeter,4); + out.write((char*)&biClrUsed,4); + out.write((char*)&biClrImportant,4); + + + if (!out) + throw image_save_error("error writing image to output stream"); + + // now we write out the pixel data + for (long row = image.nr()-1; row >= 0; --row) + { + for (long col = 0; col < image.nc(); ++col) + { + rgb_pixel p; + p.red = 0; + p.green = 0; + p.blue = 0; + assign_pixel(p,image[row][col]); + out.write((char*)&p.blue,1); + out.write((char*)&p.green,1); + out.write((char*)&p.red,1); + } + + // write out some zeros so that this line is a multiple of 4 bytes + for (unsigned long i = 0; i < pad; ++i) + { + unsigned char p = 0; + out.write((char*)&p,1); + } + } + + if (!out) + throw image_save_error("error writing image to output stream"); + } + }; + + template + struct save_bmp_helper + { + static void save_bmp ( + const image_type& image_, + std::ostream& out + ) + { + const_image_view image(image_); + // we are going to write out an 8bit color image. + byte_orderer::kernel_1a bo; + + out.write("BM",2); + + if (!out) + throw image_save_error("error writing image to output stream"); + + unsigned long pad = 4 - image.nc()%4; + if (pad == 4) + pad = 0; + + unsigned long bfSize = 14 + 40 + (image.nc() + pad)*image.nr() + 256*4; + unsigned long bfReserved = 0; + unsigned long bfOffBits = 14 + 40 + 256*4; + unsigned long biSize = 40; + unsigned long biWidth = image.nc(); + unsigned long biHeight = image.nr(); + unsigned short biPlanes = 1; + unsigned short biBitCount = 8; + unsigned long biCompression = 0; + unsigned long biSizeImage = 0; + unsigned long biXPelsPerMeter = 0; + unsigned long biYPelsPerMeter = 0; + unsigned long biClrUsed = 0; + unsigned long biClrImportant = 0; + + bo.host_to_little(bfSize); + bo.host_to_little(bfOffBits); + bo.host_to_little(biSize); + bo.host_to_little(biWidth); + bo.host_to_little(biHeight); + bo.host_to_little(biPlanes); + bo.host_to_little(biBitCount); + + out.write((char*)&bfSize,4); + out.write((char*)&bfReserved,4); + out.write((char*)&bfOffBits,4); + out.write((char*)&biSize,4); + out.write((char*)&biWidth,4); + out.write((char*)&biHeight,4); + out.write((char*)&biPlanes,2); + out.write((char*)&biBitCount,2); + out.write((char*)&biCompression,4); + out.write((char*)&biSizeImage,4); + out.write((char*)&biXPelsPerMeter,4); + out.write((char*)&biYPelsPerMeter,4); + out.write((char*)&biClrUsed,4); + out.write((char*)&biClrImportant,4); + + + // write out the color palette + for (unsigned int i = 0; i <= 255; ++i) + { + unsigned char ch = static_cast(i); + out.write((char*)&ch,1); + out.write((char*)&ch,1); + out.write((char*)&ch,1); + ch = 0; + out.write((char*)&ch,1); + } + + if (!out) + throw image_save_error("error writing image to output stream"); + + // now we write out the pixel data + for (long row = image.nr()-1; row >= 0; --row) + { + for (long col = 0; col < image.nc(); ++col) + { + unsigned char p = 0; + assign_pixel(p,image[row][col]); + out.write((char*)&p,1); + } + + // write out some zeros so that this line is a multiple of 4 bytes + for (unsigned long i = 0; i < pad; ++i) + { + unsigned char p = 0; + out.write((char*)&p,1); + } + } + + if (!out) + throw image_save_error("error writing image to output stream"); + + } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + inline typename disable_if >::type save_bmp ( + const image_type& image, + std::ostream& out + ) + { + save_bmp_helper::save_bmp(image,out); + } + + template < + typename EXP + > + inline void save_bmp ( + const matrix_exp& image, + std::ostream& out + ) + { + array2d temp; + assign_image(temp, image); + save_bmp_helper >::save_bmp(temp,out); + } + +// ---------------------------------------------------------------------------------------- + + namespace dng_helpers_namespace + { + template < + typename image_type, + typename enabled = void + > + struct save_dng_helper; + + typedef entropy_encoder::kernel_2a encoder_type; + typedef entropy_encoder_model<256,encoder_type>::kernel_5a eem_type; + + typedef entropy_encoder_model<256,encoder_type>::kernel_4a eem_exp_type; + + template + struct save_dng_helper::pixel_type> >::type > + { + static void save_dng ( + const image_type& image_, + std::ostream& out + ) + { + const_image_view image(image_); + out.write("DNG",3); + unsigned long version = 1; + serialize(version,out); + unsigned long type = grayscale_float; + serialize(type,out); + serialize(image.nc(),out); + serialize(image.nr(),out); + + + // Write the compressed exponent data into expbuf. We will append it + // to the stream at the end of the loops. + std::vector expbuf; + expbuf.reserve(image.size()*2); + vectorstream outexp(expbuf); + encoder_type encoder; + encoder.set_stream(outexp); + + eem_exp_type eem_exp(encoder); + float_details prev; + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + float_details cur = image[r][c]; + int16 exp = cur.exponent-prev.exponent; + int64 man = cur.mantissa-prev.mantissa; + prev = cur; + + unsigned char ebyte1 = exp&0xFF; + unsigned char ebyte2 = exp>>8; + eem_exp.encode(ebyte1); + eem_exp.encode(ebyte2); + + serialize(man, out); + } + } + // write out the magic byte to mark the end of the compressed data. + eem_exp.encode(dng_magic_byte); + eem_exp.encode(dng_magic_byte); + eem_exp.encode(dng_magic_byte); + eem_exp.encode(dng_magic_byte); + + encoder.clear(); + serialize(expbuf, out); + } + }; + + + template + struct is_non_float_non8bit_grayscale + { + typedef typename image_traits::pixel_type pixel_type; + const static bool value = pixel_traits::grayscale && + sizeof(pixel_type) != 1 && + !is_float_type::value; + }; + + template + struct save_dng_helper >::type> + { + static void save_dng ( + const image_type& image_, + std::ostream& out + ) + { + const_image_view image(image_); + out.write("DNG",3); + unsigned long version = 1; + serialize(version,out); + unsigned long type = grayscale_16bit; + serialize(type,out); + serialize(image.nc(),out); + serialize(image.nr(),out); + + encoder_type encoder; + encoder.set_stream(out); + + eem_type eem(encoder); + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + uint16 cur; + assign_pixel(cur, image[r][c]); + cur -= predictor_grayscale_16(image,r,c); + unsigned char byte1 = cur&0xFF; + unsigned char byte2 = cur>>8; + eem.encode(byte2); + eem.encode(byte1); + } + } + // write out the magic byte to mark the end of the data + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + } + }; + + template + struct is_8bit_grayscale + { + typedef typename image_traits::pixel_type pixel_type; + const static bool value = pixel_traits::grayscale && sizeof(pixel_type) == 1; + }; + + template + struct save_dng_helper >::type> + { + static void save_dng ( + const image_type& image_, + std::ostream& out + ) + { + const_image_view image(image_); + out.write("DNG",3); + unsigned long version = 1; + serialize(version,out); + unsigned long type = grayscale; + serialize(type,out); + serialize(image.nc(),out); + serialize(image.nr(),out); + + encoder_type encoder; + encoder.set_stream(out); + + eem_type eem(encoder); + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + unsigned char cur; + assign_pixel(cur, image[r][c]); + cur -= predictor_grayscale(image,r,c); + eem.encode(cur); + } + } + // write out the magic byte to mark the end of the data + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + } + }; + + template + struct is_rgb_image + { + typedef typename image_traits::pixel_type pixel_type; + const static bool value = pixel_traits::rgb; + }; + + template + struct save_dng_helper >::type> + { + static void save_dng ( + const image_type& image_, + std::ostream& out + ) + { + const_image_view image(image_); + out.write("DNG",3); + unsigned long version = 1; + serialize(version,out); + + unsigned long type = rgb; + // if this is a small image then we will use a different predictor + if (image.size() < 4000) + type = rgb_paeth; + + serialize(type,out); + serialize(image.nc(),out); + serialize(image.nr(),out); + + encoder_type encoder; + encoder.set_stream(out); + + rgb_pixel pre, cur; + eem_type eem(encoder); + + if (type == rgb) + { + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + pre = predictor_rgb(image,r,c); + assign_pixel(cur, image[r][c]); + + eem.encode((unsigned char)(cur.red - pre.red)); + eem.encode((unsigned char)(cur.green - pre.green)); + eem.encode((unsigned char)(cur.blue - pre.blue)); + } + } + } + else + { + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + pre = predictor_rgb_paeth(image,r,c); + assign_pixel(cur, image[r][c]); + + eem.encode((unsigned char)(cur.red - pre.red)); + eem.encode((unsigned char)(cur.green - pre.green)); + eem.encode((unsigned char)(cur.blue - pre.blue)); + } + } + } + // write out the magic byte to mark the end of the data + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + } + }; + + template + struct is_rgb_alpha_image + { + typedef typename image_traits::pixel_type pixel_type; + const static bool value = pixel_traits::rgb_alpha; + }; + + template + struct save_dng_helper >::type> + { + static void save_dng ( + const image_type& image_, + std::ostream& out + ) + { + const_image_view image(image_); + out.write("DNG",3); + unsigned long version = 1; + serialize(version,out); + + unsigned long type = rgb_alpha; + // if this is a small image then we will use a different predictor + if (image.size() < 4000) + type = rgb_alpha_paeth; + + serialize(type,out); + serialize(image.nc(),out); + serialize(image.nr(),out); + + encoder_type encoder; + encoder.set_stream(out); + + rgb_alpha_pixel pre, cur; + eem_type eem(encoder); + + if (type == rgb_alpha) + { + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + pre = predictor_rgb_alpha(image,r,c); + assign_pixel(cur, image[r][c]); + + eem.encode((unsigned char)(cur.red - pre.red)); + eem.encode((unsigned char)(cur.green - pre.green)); + eem.encode((unsigned char)(cur.blue - pre.blue)); + eem.encode((unsigned char)(cur.alpha - pre.alpha)); + } + } + } + else + { + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + pre = predictor_rgb_alpha_paeth(image,r,c); + assign_pixel(cur, image[r][c]); + + eem.encode((unsigned char)(cur.red - pre.red)); + eem.encode((unsigned char)(cur.green - pre.green)); + eem.encode((unsigned char)(cur.blue - pre.blue)); + eem.encode((unsigned char)(cur.alpha - pre.alpha)); + } + } + } + // write out the magic byte to mark the end of the data + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + } + }; + + template + struct is_hsi_image + { + typedef typename image_traits::pixel_type pixel_type; + const static bool value = pixel_traits::hsi; + }; + + template + struct save_dng_helper >::type> + { + static void save_dng ( + const image_type& image_, + std::ostream& out + ) + { + const_image_view image(image_); + out.write("DNG",3); + unsigned long version = 1; + serialize(version,out); + unsigned long type = hsi; + serialize(type,out); + serialize(image.nc(),out); + serialize(image.nr(),out); + + encoder_type encoder; + encoder.set_stream(out); + + hsi_pixel pre, cur; + eem_type eem(encoder); + for (long r = 0; r < image.nr(); ++r) + { + for (long c = 0; c < image.nc(); ++c) + { + pre = predictor_hsi(image,r,c); + assign_pixel(cur, image[r][c]); + + eem.encode((unsigned char)(cur.h - pre.h)); + eem.encode((unsigned char)(cur.s - pre.s)); + eem.encode((unsigned char)(cur.i - pre.i)); + } + } + // write out the magic byte to mark the end of the data + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + eem.encode(dng_magic_byte); + } + }; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + inline typename disable_if >::type save_dng ( + const image_type& image, + std::ostream& out + ) + { + using namespace dng_helpers_namespace; + save_dng_helper::save_dng(image,out); + } + + template < + typename EXP + > + inline void save_dng ( + const matrix_exp& image, + std::ostream& out + ) + { + array2d temp; + assign_image(temp, image); + using namespace dng_helpers_namespace; + save_dng_helper >::save_dng(temp,out); + } + +// ---------------------------------------------------------------------------------------- + + template + void save_dng ( + const image_type& image, + const std::string& file_name + ) + { + std::ofstream fout(file_name.c_str(), std::ios::binary); + if (!fout) + throw image_save_error("Unable to open " + file_name + " for writing."); + save_dng(image, fout); + } + +// ---------------------------------------------------------------------------------------- + + template + void save_bmp ( + const image_type& image, + const std::string& file_name + ) + { + std::ofstream fout(file_name.c_str(), std::ios::binary); + if (!fout) + throw image_save_error("Unable to open " + file_name + " for writing."); + save_bmp(image, fout); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IMAGE_SAVEr_ + + + + diff --git a/ml/dlib/dlib/image_saver/image_saver_abstract.h b/ml/dlib/dlib/image_saver/image_saver_abstract.h new file mode 100644 index 000000000..82f91ed45 --- /dev/null +++ b/ml/dlib/dlib/image_saver/image_saver_abstract.h @@ -0,0 +1,129 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_IMAGE_SAVEr_ABSTRACT_ +#ifdef DLIB_IMAGE_SAVEr_ABSTRACT_ + +#include +#include "../algs.h" +#include "../pixel.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + class image_save_error : public dlib::error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an exception used to indicate a failure to save an image. + Its type member variable will be set to EIMAGE_SAVE. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void save_bmp ( + const image_type& image, + std::ostream& out + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or any kind of matrix expression. + ensures + - writes the image to the out stream in the Microsoft Windows BMP format. + - image[0][0] will be in the upper left corner of the image. + - image[image.nr()-1][image.nc()-1] will be in the lower right + corner of the image. + - This routine can save images containing any type of pixel. However, it + will convert all color pixels into rgb_pixel and grayscale pixels into + uint8 type before saving to disk. + throws + - image_save_error + This exception is thrown if there is an error that prevents us + from saving the image. + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void save_bmp ( + const image_type& image, + const std::string& file_name + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or any kind of matrix expression. + ensures + - opens the file indicated by file_name with an output file stream named fout + and performs: + save_bmp(image,fout); + !*/ + +// ---------------------------------------------------------------------------------------- + + /*! + dlib dng file format: + This is a file format I created for this library. It is a lossless + compressed image format that is similar to the PNG format but uses + the dlib PPM compression algorithms instead of the DEFLATE algorithm. + !*/ + + template < + typename image_type + > + void save_dng ( + const image_type& image, + std::ostream& out + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or any kind of matrix expression. + ensures + - writes the image to the out stream in the dlib dng format. + - image[0][0] will be in the upper left corner of the image. + - image[image.nr()-1][image.nc()-1] will be in the lower right + corner of the image. + - This routine can save images containing any type of pixel. However, the DNG + format can natively store only the following pixel types: rgb_pixel, + hsi_pixel, rgb_alpha_pixel, uint8, uint16, float, and double. + All other pixel types will be converted into one of these types as + appropriate before being saved to disk. + throws + - image_save_error + This exception is thrown if there is an error that prevents us + from saving the image. + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void save_dng ( + const image_type& image, + const std::string& file_name + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or any kind of matrix expression. + ensures + - opens the file indicated by file_name with an output file stream named fout + and performs: + save_dng(image,fout); + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IMAGE_SAVEr_ABSTRACT_ + + diff --git a/ml/dlib/dlib/image_saver/save_jpeg.cpp b/ml/dlib/dlib/image_saver/save_jpeg.cpp new file mode 100644 index 000000000..ef637fa7a --- /dev/null +++ b/ml/dlib/dlib/image_saver/save_jpeg.cpp @@ -0,0 +1,175 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_JPEG_SAVER_CPp_ +#define DLIB_JPEG_SAVER_CPp_ + +// only do anything with this file if DLIB_JPEG_SUPPORT is defined +#ifdef DLIB_JPEG_SUPPORT + +#include "../array2d.h" +#include "../pixel.h" +#include "save_jpeg.h" +#include +#include +#include +#include "image_saver.h" + +#ifdef DLIB_JPEG_STATIC +# include "../external/libjpeg/jpeglib.h" +#else +# include +#endif + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct jpeg_saver_error_mgr + { + jpeg_error_mgr pub; /* "public" fields */ + jmp_buf setjmp_buffer; /* for return to caller */ + }; + + void jpeg_saver_error_exit (j_common_ptr cinfo) + { + /* cinfo->err really points to a jpeg_saver_error_mgr struct, so coerce pointer */ + jpeg_saver_error_mgr* myerr = (jpeg_saver_error_mgr*) cinfo->err; + + /* Return control to the setjmp point */ + longjmp(myerr->setjmp_buffer, 1); + } + +// ---------------------------------------------------------------------------------------- + + void save_jpeg ( + const array2d& img, + const std::string& filename, + int quality + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(img.size() != 0, + "\t save_jpeg()" + << "\n\t You can't save an empty image as a JPEG." + ); + DLIB_CASSERT(0 <= quality && quality <= 100, + "\t save_jpeg()" + << "\n\t Invalid quality value." + << "\n\t quality: " << quality + ); + + FILE* outfile = fopen(filename.c_str(), "wb"); + if (!outfile) + throw image_save_error("Can't open file " + filename + " for writing."); + + jpeg_compress_struct cinfo; + + jpeg_saver_error_mgr jerr; + cinfo.err = jpeg_std_error(&jerr.pub); + jerr.pub.error_exit = jpeg_saver_error_exit; + /* Establish the setjmp return context for my_error_exit to use. */ + if (setjmp(jerr.setjmp_buffer)) + { + /* If we get here, the JPEG code has signaled an error. + * We need to clean up the JPEG object, close the input file, and return. + */ + jpeg_destroy_compress(&cinfo); + fclose(outfile); + throw image_save_error("save_jpeg: error while writing " + filename); + } + + jpeg_create_compress(&cinfo); + jpeg_stdio_dest(&cinfo, outfile); + + cinfo.image_width = img.nc(); + cinfo.image_height = img.nr(); + cinfo.input_components = 3; + cinfo.in_color_space = JCS_RGB; + jpeg_set_defaults(&cinfo); + jpeg_set_quality (&cinfo, quality, TRUE); + jpeg_start_compress(&cinfo, TRUE); + + // now write out the rows one at a time + while (cinfo.next_scanline < cinfo.image_height) { + JSAMPROW row_pointer = (JSAMPROW) &img[cinfo.next_scanline][0]; + jpeg_write_scanlines(&cinfo, &row_pointer, 1); + } + + jpeg_finish_compress(&cinfo); + jpeg_destroy_compress(&cinfo); + fclose( outfile ); + } + +// ---------------------------------------------------------------------------------------- + + void save_jpeg ( + const array2d& img, + const std::string& filename, + int quality + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(img.size() != 0, + "\t save_jpeg()" + << "\n\t You can't save an empty image as a JPEG." + ); + DLIB_CASSERT(0 <= quality && quality <= 100, + "\t save_jpeg()" + << "\n\t Invalid quality value." + << "\n\t quality: " << quality + ); + + + FILE* outfile = fopen(filename.c_str(), "wb"); + if (!outfile) + throw image_save_error("Can't open file " + filename + " for writing."); + + jpeg_compress_struct cinfo; + + jpeg_saver_error_mgr jerr; + cinfo.err = jpeg_std_error(&jerr.pub); + jerr.pub.error_exit = jpeg_saver_error_exit; + /* Establish the setjmp return context for my_error_exit to use. */ + if (setjmp(jerr.setjmp_buffer)) + { + /* If we get here, the JPEG code has signaled an error. + * We need to clean up the JPEG object, close the input file, and return. + */ + jpeg_destroy_compress(&cinfo); + fclose(outfile); + throw image_save_error("save_jpeg: error while writing " + filename); + } + + jpeg_create_compress(&cinfo); + jpeg_stdio_dest(&cinfo, outfile); + + cinfo.image_width = img.nc(); + cinfo.image_height = img.nr(); + cinfo.input_components = 1; + cinfo.in_color_space = JCS_GRAYSCALE; + jpeg_set_defaults(&cinfo); + jpeg_set_quality (&cinfo, quality, TRUE); + jpeg_start_compress(&cinfo, TRUE); + + // now write out the rows one at a time + while (cinfo.next_scanline < cinfo.image_height) { + JSAMPROW row_pointer = (JSAMPROW) &img[cinfo.next_scanline][0]; + jpeg_write_scanlines(&cinfo, &row_pointer, 1); + } + + jpeg_finish_compress(&cinfo); + jpeg_destroy_compress(&cinfo); + fclose( outfile ); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_JPEG_SUPPORT + +#endif // DLIB_JPEG_SAVER_CPp_ + + + diff --git a/ml/dlib/dlib/image_saver/save_jpeg.h b/ml/dlib/dlib/image_saver/save_jpeg.h new file mode 100644 index 000000000..fb1808c44 --- /dev/null +++ b/ml/dlib/dlib/image_saver/save_jpeg.h @@ -0,0 +1,82 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SAVE_JPEG_Hh_ +#define DLIB_SAVE_JPEG_Hh_ + +#include "save_jpeg_abstract.h" + +#include "../enable_if.h" +#include "../matrix.h" +#include "../array2d.h" +#include "../pixel.h" +#include "../image_processing/generic_image.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + void save_jpeg ( + const array2d& img, + const std::string& filename, + int quality = 75 + ); + +// ---------------------------------------------------------------------------------------- + + void save_jpeg ( + const array2d& img, + const std::string& filename, + int quality = 75 + ); + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + typename disable_if >::type save_jpeg( + const image_type& img, + const std::string& filename, + int quality = 75 + ) + { + // Convert any kind of grayscale image to an unsigned char image + if (pixel_traits::pixel_type>::grayscale) + { + array2d temp; + assign_image(temp, img); + save_jpeg(temp, filename, quality); + } + else + { + // This is some other kind of color image so just save it as an RGB image. + array2d temp; + assign_image(temp, img); + save_jpeg(temp, filename, quality); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + void save_jpeg( + const matrix_exp& img, + const std::string& file_name, + int quality = 75 + ) + { + array2d temp; + assign_image(temp, img); + save_jpeg(temp, file_name, quality); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SAVE_JPEG_Hh_ + diff --git a/ml/dlib/dlib/image_saver/save_jpeg_abstract.h b/ml/dlib/dlib/image_saver/save_jpeg_abstract.h new file mode 100644 index 000000000..f441339b9 --- /dev/null +++ b/ml/dlib/dlib/image_saver/save_jpeg_abstract.h @@ -0,0 +1,52 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SAVE_JPEG_ABSTRACT_Hh_ +#ifdef DLIB_SAVE_JPEG_ABSTRACT_Hh_ + +#include "../image_processing/generic_image.h" +#include "../pixel.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void save_jpeg ( + const image_type& img, + const std::string& filename, + int quality = 75 + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or a matrix expression + - image.size() != 0 + - 0 <= quality <= 100 + ensures + - writes the image to the file indicated by file_name in the JPEG format. + - image[0][0] will be in the upper left corner of the image. + - image[image.nr()-1][image.nc()-1] will be in the lower right corner of the + image. + - This routine can save images containing any type of pixel. However, + save_jpeg() can only natively store rgb_pixel and uint8 pixel types. All + other pixel types will be converted into one of these types as appropriate + before being saved to disk. + - The quality value determines how lossy the compression is. Larger quality + values result in larger output images but the images will look better. + throws + - image_save_error + This exception is thrown if there is an error that prevents us from saving + the image. + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SAVE_JPEG_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_saver/save_png.cpp b/ml/dlib/dlib/image_saver/save_png.cpp new file mode 100644 index 000000000..1c96b929c --- /dev/null +++ b/ml/dlib/dlib/image_saver/save_png.cpp @@ -0,0 +1,124 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SAVE_PnG_CPPh_ +#define DLIB_SAVE_PnG_CPPh_ + +// only do anything with this file if DLIB_PNG_SUPPORT is defined +#ifdef DLIB_PNG_SUPPORT + +#include "save_png.h" +#include +#include +#include "../byte_orderer.h" + +namespace dlib +{ + // Don't do anything when libpng calls us to tell us about an error. Just return to + // our own code and throw an exception (at the long jump target). + void png_reader_user_error_fn_silent(png_structp png_struct, png_const_charp ) + { + longjmp(png_jmpbuf(png_struct),1); + } + void png_reader_user_warning_fn_silent(png_structp , png_const_charp ) + { + } + + namespace impl + { + void impl_save_png ( + const std::string& file_name, + std::vector& row_pointers, + const long width, + const png_type type, + const int bit_depth + ) + { + + FILE *fp; + png_structp png_ptr; + png_infop info_ptr; + + /* Open the file */ + fp = fopen(file_name.c_str(), "wb"); + if (fp == NULL) + throw image_save_error("Unable to open " + file_name + " for writing."); + + /* Create and initialize the png_struct with the desired error handler + * functions. If you want to use the default stderr and longjump method, + * you can supply NULL for the last three parameters. We also check that + * the library version is compatible with the one used at compile time, + * in case we are using dynamically linked libraries. REQUIRED. + */ + png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, NULL, &png_reader_user_error_fn_silent, &png_reader_user_warning_fn_silent); + + if (png_ptr == NULL) + { + fclose(fp); + throw image_save_error("Error while writing PNG file " + file_name); + } + + /* Allocate/initialize the image information data. REQUIRED */ + info_ptr = png_create_info_struct(png_ptr); + if (info_ptr == NULL) + { + fclose(fp); + png_destroy_write_struct(&png_ptr, NULL); + throw image_save_error("Error while writing PNG file " + file_name); + } + + /* Set error handling. REQUIRED if you aren't supplying your own + * error handling functions in the png_create_write_struct() call. + */ + if (setjmp(png_jmpbuf(png_ptr))) + { + /* If we get here, we had a problem writing the file */ + fclose(fp); + png_destroy_write_struct(&png_ptr, &info_ptr); + throw image_save_error("Error while writing PNG file " + file_name); + } + + int color_type = 0; + switch(type) + { + case png_type_rgb: color_type = PNG_COLOR_TYPE_RGB; break; + case png_type_rgb_alpha: color_type = PNG_COLOR_TYPE_RGB_ALPHA; break; + case png_type_gray: color_type = PNG_COLOR_TYPE_GRAY; break; + default: + { + fclose(fp); + png_destroy_write_struct(&png_ptr, &info_ptr); + throw image_save_error("Invalid color type"); + } + } + + + /* Set up the output control if you are using standard C streams */ + png_init_io(png_ptr, fp); + + + int png_transforms = PNG_TRANSFORM_IDENTITY; + byte_orderer bo; + if (bo.host_is_little_endian()) + png_transforms |= PNG_TRANSFORM_SWAP_ENDIAN; + + const long height = row_pointers.size(); + + + png_set_IHDR(png_ptr, info_ptr, width, height, bit_depth, color_type, PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT, PNG_FILTER_TYPE_DEFAULT); + png_set_rows(png_ptr, info_ptr, &row_pointers[0]); + png_write_png(png_ptr, info_ptr, png_transforms, NULL); + + /* Clean up after the write, and free any memory allocated */ + png_destroy_write_struct(&png_ptr, &info_ptr); + + /* Close the file */ + fclose(fp); + } + } +} + +#endif // DLIB_PNG_SUPPORT + +#endif // DLIB_SAVE_PnG_CPPh_ + + diff --git a/ml/dlib/dlib/image_saver/save_png.h b/ml/dlib/dlib/image_saver/save_png.h new file mode 100644 index 000000000..cddf03ff6 --- /dev/null +++ b/ml/dlib/dlib/image_saver/save_png.h @@ -0,0 +1,162 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SAVE_PnG_Hh_ +#define DLIB_SAVE_PnG_Hh_ + +#include "save_png_abstract.h" +#include "image_saver.h" +#include "../array2d.h" +#include +#include +#include "../pixel.h" +#include "../matrix/matrix_exp.h" +#include "../image_transforms/assign_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + enum png_type + { + png_type_rgb, + png_type_rgb_alpha, + png_type_gray, + }; + + void impl_save_png ( + const std::string& file_name, + std::vector& row_pointers, + const long width, + const png_type type, + const int bit_depth + ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + typename disable_if >::type save_png( + const image_type& img_, + const std::string& file_name + ) + { + const_image_view img(img_); + + // make sure requires clause is not broken + DLIB_CASSERT(img.size() != 0, + "\t save_png()" + << "\n\t You can't save an empty image as a PNG" + ); + + +#ifndef DLIB_PNG_SUPPORT + /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + You are getting this error because you are trying to use save_png() + but you haven't defined DLIB_PNG_SUPPORT. You must do so to use + this function. You must also make sure you set your build environment + to link against the libpng library. + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/ + COMPILE_TIME_ASSERT(sizeof(image_type) == 0); +#else + std::vector row_pointers(img.nr()); + typedef typename image_traits::pixel_type pixel_type; + + if (is_same_type::value) + { + for (unsigned long i = 0; i < row_pointers.size(); ++i) + row_pointers[i] = (unsigned char*)(&img[i][0]); + + impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_rgb, 8); + } + else if (is_same_type::value) + { + for (unsigned long i = 0; i < row_pointers.size(); ++i) + row_pointers[i] = (unsigned char*)(&img[i][0]); + + impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_rgb_alpha, 8); + } + else if (pixel_traits::lab || pixel_traits::hsi || pixel_traits::rgb) + { + // convert from Lab or HSI to RGB (Or potentially RGB pixels that aren't laid out as R G B) + array2d temp_img; + assign_image(temp_img, img_); + for (unsigned long i = 0; i < row_pointers.size(); ++i) + row_pointers[i] = (unsigned char*)(&temp_img[i][0]); + + impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_rgb, 8); + } + else if (pixel_traits::rgb_alpha) + { + // convert from RGBA pixels that aren't laid out as R G B A + array2d temp_img; + assign_image(temp_img, img_); + for (unsigned long i = 0; i < row_pointers.size(); ++i) + row_pointers[i] = (unsigned char*)(&temp_img[i][0]); + + impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_rgb_alpha, 8); + } + else // this is supposed to be grayscale + { + DLIB_CASSERT(pixel_traits::grayscale, "impossible condition detected"); + + if (pixel_traits::is_unsigned && sizeof(pixel_type) == 1) + { + for (unsigned long i = 0; i < row_pointers.size(); ++i) + row_pointers[i] = (unsigned char*)(&img[i][0]); + + impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_gray, 8); + } + else if (pixel_traits::is_unsigned && sizeof(pixel_type) == 2) + { + for (unsigned long i = 0; i < row_pointers.size(); ++i) + row_pointers[i] = (unsigned char*)(&img[i][0]); + + impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_gray, 16); + } + else + { + // convert from whatever this is to 16bit grayscale + array2d temp_img; + assign_image(temp_img, img_); + for (unsigned long i = 0; i < row_pointers.size(); ++i) + row_pointers[i] = (unsigned char*)(&temp_img[i][0]); + + impl::impl_save_png(file_name, row_pointers, img.nc(), impl::png_type_gray, 16); + } + } + + +#endif + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + void save_png( + const matrix_exp& img, + const std::string& file_name + ) + { + array2d temp; + assign_image(temp, img); + save_png(temp, file_name); + } + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "save_png.cpp" +#endif + +#endif // DLIB_SAVE_PnG_Hh_ + diff --git a/ml/dlib/dlib/image_saver/save_png_abstract.h b/ml/dlib/dlib/image_saver/save_png_abstract.h new file mode 100644 index 000000000..ae495d1f2 --- /dev/null +++ b/ml/dlib/dlib/image_saver/save_png_abstract.h @@ -0,0 +1,50 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SAVE_PnG_ABSTRACT_ +#ifdef DLIB_SAVE_PnG_ABSTRACT_ + +#include "../pixel.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void save_png ( + const image_type& image, + const std::string& file_name + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or a matrix expression + - image.size() != 0 + ensures + - writes the image to the file indicated by file_name in the PNG (Portable Network Graphics) + format. + - image[0][0] will be in the upper left corner of the image. + - image[image.nr()-1][image.nc()-1] will be in the lower right + corner of the image. + - This routine can save images containing any type of pixel. However, save_png() can + only natively store the following pixel types: rgb_pixel, rgb_alpha_pixel, uint8, + and uint16. All other pixel types will be converted into one of these types as + appropriate before being saved to disk. + throws + - image_save_error + This exception is thrown if there is an error that prevents us from saving + the image. + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SAVE_PnG_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/image_transforms.h b/ml/dlib/dlib/image_transforms.h new file mode 100644 index 000000000..89b4e0db6 --- /dev/null +++ b/ml/dlib/dlib/image_transforms.h @@ -0,0 +1,31 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_IMAGE_TRANSFORMs_ +#define DLIB_IMAGE_TRANSFORMs_ + +#include "image_transforms/assign_image.h" +#include "image_transforms/equalize_histogram.h" +#include "image_transforms/morphological_operations.h" +#include "image_transforms/spatial_filtering.h" +#include "image_transforms/thresholding.h" +#include "image_transforms/edge_detector.h" +#include "image_transforms/draw.h" +#include "image_transforms/integral_image.h" +#include "image_transforms/image_pyramid.h" +#include "image_transforms/hough_transform.h" +#include "image_transforms/label_connected_blobs.h" +#include "image_transforms/colormaps.h" +#include "image_transforms/segment_image.h" +#include "image_transforms/interpolation.h" +#include "image_transforms/fhog.h" +#include "image_transforms/lbp.h" +#include "image_transforms/random_color_transform.h" +#include "image_transforms/random_cropper.h" + +#endif // DLIB_IMAGE_TRANSFORMs_ + diff --git a/ml/dlib/dlib/image_transforms/assign_image.h b/ml/dlib/dlib/image_transforms/assign_image.h new file mode 100644 index 000000000..c69878efa --- /dev/null +++ b/ml/dlib/dlib/image_transforms/assign_image.h @@ -0,0 +1,385 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ASSIGN_IMAGe_ +#define DLIB_ASSIGN_IMAGe_ + +#include "../pixel.h" +#include "assign_image_abstract.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename dest_image_type, + typename src_image_type + > + void impl_assign_image ( + image_view& dest, + const src_image_type& src + ) + { + dest.set_size(src.nr(),src.nc()); + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + assign_pixel(dest[r][c], src(r,c)); + } + } + } + + template < + typename dest_image_type, + typename src_image_type + > + void impl_assign_image ( + dest_image_type& dest_, + const src_image_type& src + ) + { + image_view dest(dest_); + impl_assign_image(dest, src); + } + + template < + typename dest_image_type, + typename src_image_type + > + void assign_image ( + dest_image_type& dest, + const src_image_type& src + ) + { + // check for the case where dest is the same object as src + if (is_same_object(dest,src)) + return; + + impl_assign_image(dest, mat(src)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename dest_image_type, + typename src_image_type + > + void impl_assign_image_scaled ( + image_view& dest, + const src_image_type& src, + const double thresh + ) + { + DLIB_ASSERT( thresh > 0, + "\tvoid assign_image_scaled()" + << "\n\t You have given an threshold value" + << "\n\t thresh: " << thresh + ); + + + typedef typename image_traits::pixel_type dest_pixel; + + // If the destination has a dynamic range big enough to contain the source image data then just do a + // regular assign_image() + if (pixel_traits::max() >= pixel_traits::max() && + pixel_traits::min() <= pixel_traits::min() ) + { + impl_assign_image(dest, src); + return; + } + + dest.set_size(src.nr(),src.nc()); + + if (src.size() == 0) + return; + + if (src.size() == 1) + { + impl_assign_image(dest, src); + return; + } + + // gather image statistics + running_stats rs; + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + rs.add(get_pixel_intensity(src(r,c))); + } + } + typedef typename pixel_traits::basic_pixel_type spix_type; + + if (std::numeric_limits::is_integer) + { + // If the destination has a dynamic range big enough to contain the source image data then just do a + // regular assign_image() + if (pixel_traits::max() >= rs.max() && + pixel_traits::min() <= rs.min() ) + { + impl_assign_image(dest, src); + return; + } + } + + // Figure out the range of pixel values based on image statistics. There might be some huge + // outliers so don't just pick the min and max values. + const double upper = std::min(rs.mean() + thresh*rs.stddev(), rs.max()); + const double lower = std::max(rs.mean() - thresh*rs.stddev(), rs.min()); + + + const double dest_min = pixel_traits::min(); + const double dest_max = pixel_traits::max(); + + const double scale = (upper!=lower)? ((dest_max - dest_min) / (upper - lower)) : 0; + + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + const double val = get_pixel_intensity(src(r,c)) - lower; + + assign_pixel(dest[r][c], scale*val + dest_min); + } + } + } + + template < + typename dest_image_type, + typename src_image_type + > + void impl_assign_image_scaled ( + dest_image_type& dest_, + const src_image_type& src, + const double thresh + ) + { + image_view dest(dest_); + impl_assign_image_scaled(dest, src, thresh); + } + + template < + typename dest_image_type, + typename src_image_type + > + void assign_image_scaled ( + dest_image_type& dest, + const src_image_type& src, + const double thresh = 4 + ) + { + // check for the case where dest is the same object as src + if (is_same_object(dest,src)) + return; + + impl_assign_image_scaled(dest, mat(src),thresh); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename dest_image_type, + typename src_pixel_type + > + void assign_all_pixels ( + image_view& dest_img, + const src_pixel_type& src_pixel + ) + { + for (long r = 0; r < dest_img.nr(); ++r) + { + for (long c = 0; c < dest_img.nc(); ++c) + { + assign_pixel(dest_img[r][c], src_pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename dest_image_type, + typename src_pixel_type + > + void assign_all_pixels ( + dest_image_type& dest_img_, + const src_pixel_type& src_pixel + ) + { + image_view dest_img(dest_img_); + assign_all_pixels(dest_img, src_pixel); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void assign_border_pixels ( + image_view& img, + long x_border_size, + long y_border_size, + const typename image_traits::pixel_type& p + ) + { + DLIB_ASSERT( x_border_size >= 0 && y_border_size >= 0, + "\tvoid assign_border_pixels(img, p, border_size)" + << "\n\tYou have given an invalid border_size" + << "\n\tx_border_size: " << x_border_size + << "\n\ty_border_size: " << y_border_size + ); + + y_border_size = std::min(y_border_size, img.nr()/2+1); + x_border_size = std::min(x_border_size, img.nc()/2+1); + + // assign the top border + for (long r = 0; r < y_border_size; ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = p; + } + } + + // assign the bottom border + for (long r = img.nr()-y_border_size; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = p; + } + } + + // now assign the two sides + for (long r = y_border_size; r < img.nr()-y_border_size; ++r) + { + // left border + for (long c = 0; c < x_border_size; ++c) + img[r][c] = p; + + // right border + for (long c = img.nc()-x_border_size; c < img.nc(); ++c) + img[r][c] = p; + } + } + + template < + typename image_type + > + void assign_border_pixels ( + image_type& img_, + long x_border_size, + long y_border_size, + const typename image_traits::pixel_type& p + ) + { + image_view img(img_); + assign_border_pixels(img, x_border_size, y_border_size, p); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void zero_border_pixels ( + image_type& img, + long x_border_size, + long y_border_size + ) + { + DLIB_ASSERT( x_border_size >= 0 && y_border_size >= 0, + "\tvoid zero_border_pixels(img, p, border_size)" + << "\n\tYou have given an invalid border_size" + << "\n\tx_border_size: " << x_border_size + << "\n\ty_border_size: " << y_border_size + ); + + typename image_traits::pixel_type zero_pixel; + assign_pixel_intensity(zero_pixel, 0); + assign_border_pixels(img, x_border_size, y_border_size, zero_pixel); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void zero_border_pixels ( + image_view& img, + long x_border_size, + long y_border_size + ) + { + DLIB_ASSERT( x_border_size >= 0 && y_border_size >= 0, + "\tvoid zero_border_pixels(img, p, border_size)" + << "\n\tYou have given an invalid border_size" + << "\n\tx_border_size: " << x_border_size + << "\n\ty_border_size: " << y_border_size + ); + + typename image_traits::pixel_type zero_pixel; + assign_pixel_intensity(zero_pixel, 0); + assign_border_pixels(img, x_border_size, y_border_size, zero_pixel); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void zero_border_pixels ( + image_view& img, + rectangle inside + ) + { + inside = inside.intersect(get_rect(img)); + if (inside.is_empty()) + { + assign_all_pixels(img, 0); + return; + } + + for (long r = 0; r < inside.top(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + assign_pixel(img[r][c], 0); + } + for (long r = inside.top(); r <= inside.bottom(); ++r) + { + for (long c = 0; c < inside.left(); ++c) + assign_pixel(img[r][c], 0); + for (long c = inside.right()+1; c < img.nc(); ++c) + assign_pixel(img[r][c], 0); + } + for (long r = inside.bottom()+1; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + assign_pixel(img[r][c], 0); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void zero_border_pixels ( + image_type& img_, + const rectangle& inside + ) + { + image_view img(img_); + zero_border_pixels(img, inside); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ASSIGN_IMAGe_ + + + diff --git a/ml/dlib/dlib/image_transforms/assign_image_abstract.h b/ml/dlib/dlib/image_transforms/assign_image_abstract.h new file mode 100644 index 000000000..5ba262ba5 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/assign_image_abstract.h @@ -0,0 +1,196 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ASSIGN_IMAGe_ABSTRACT +#ifdef DLIB_ASSIGN_IMAGe_ABSTRACT + +#include "../pixel.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename dest_image_type, + typename src_image_type + > + void assign_image ( + dest_image_type& dest_img, + const src_image_type& src_img + ); + /*! + requires + - src_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or any object convertible to a matrix + via mat(). + - dest_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or an image_view. + ensures + - #dest_img.nc() == src_img.nc() + - #dest_img.nr() == src_img.nr() + - for all valid r and c: + - performs assign_pixel(#dest_img[r][c],src_img[r][c]) + (i.e. copies the src image to dest image) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename dest_image_type, + typename src_image_type + > + void assign_image_scaled ( + dest_image_type& dest_img, + const src_image_type& src_img, + const double thresh = 4 + ); + /*! + requires + - src_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or any object convertible to a matrix + via mat(). + - dest_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or an image_view. + - thresh > 0 + ensures + - #dest_img.nc() == src_img.nc() + - #dest_img.nr() == src_img.nr() + - if (dest_img's pixels have a wide enough dynamic range to contain all the + pixels in src_img. (Note that dynamic range is determined by the min() and + max() pixel_traits properties)) then + - performs: assign_image(dest_img, src_img) + (i.e. in this case, no scaling is performed. Just a normal color space + conversion and copy ) + - else + - #dest_img will be converted to a grayscale image + - scales the contents of src_img into the dynamic range of dest_img and then + assigns the result into dest_img. The thresh parameter is used to filter + source pixel values which are outliers. These outliers will saturate + at the edge of the destination image's dynamic range. + - Specifically, for all valid r and c: + - scales get_pixel_intensity(src_img[r][c]) into the dynamic range + of the dest_img. This is done by computing the mean and standard + deviation of src_img. Call the mean M and the standard deviation + D. Then the scaling from src_img to dest_img is performed using + the following mapping: + let SRC_UPPER = min(M + thresh*D, max(mat(src_img))) + let SRC_LOWER = max(M - thresh*D, min(mat(src_img))) + let DEST_UPPER = pixel_traits::pixel_type>::max() + let DEST_LOWER = pixel_traits::pixel_type>::min() + + MAPPING: [SRC_LOWER, SRC_UPPER] -> [DEST_LOWER, DEST_UPPER] + + Where this mapping is a linear mapping of values from the left range + into the right range of values. Source pixel values outside the left + range are modified to be at the appropriate end of the range. + + The scaled pixel is then stored in dest_img[r][c]. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename dest_image_type, + typename src_pixel_type + > + void assign_all_pixels ( + dest_image_type& dest_img, + const src_pixel_type& src_pixel + ); + /*! + requires + - dest_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or an image_view. + - pixel_traits is defined + ensures + - #dest_img.nc() == dest_img.nc() + - #dest_img.nr() == dest_img.nr() + (i.e. the size of dest_img isn't changed by this function) + - for all valid r and c: + - performs assign_pixel(#dest_img[r][c],src_pixel) + (i.e. assigns the src pixel to every pixel in the dest image) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void assign_border_pixels ( + image_type& img, + long x_border_size, + long y_border_size, + const typename image_traits::pixel_type& p + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or an image_view + - x_border_size >= 0 + - y_border_size >= 0 + ensures + - #img.nc() == img.nc() + - #img.nr() == img.nr() + (i.e. the size of img isn't changed by this function) + - for all valid r such that r+y_border_size or r-y_border_size gives an invalid row + - for all valid c such that c+x_border_size or c-x_border_size gives an invalid column + - performs assign_pixel(#img[r][c],p) + (i.e. assigns the given pixel to every pixel in the border of img) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void zero_border_pixels ( + image_type& img, + long x_border_size, + long y_border_size + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or an image_view + - x_border_size >= 0 + - y_border_size >= 0 + ensures + - #img.nc() == img.nc() + - #img.nr() == img.nr() + (i.e. the size of img isn't changed by this function) + - for all valid r such that r+y_border_size or r-y_border_size gives an invalid row + - for all valid c such that c+x_border_size or c-x_border_size gives an invalid column + - performs assign_pixel(#img[r][c], 0 ) + (i.e. assigns 0 to every pixel in the border of img) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void zero_border_pixels ( + image_type& img, + rectangle inside + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or an image_view + ensures + - #img.nc() == img.nc() + - #img.nr() == img.nr() + (i.e. the size of img isn't changed by this function) + - All the pixels in img that are not contained inside the inside rectangle + given to this function are set to 0. That is, anything not "inside" is on + the border and set to 0. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ASSIGN_IMAGe_ABSTRACT + + diff --git a/ml/dlib/dlib/image_transforms/colormaps.h b/ml/dlib/dlib/image_transforms/colormaps.h new file mode 100644 index 000000000..813d1ff75 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/colormaps.h @@ -0,0 +1,269 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RANDOMLY_COlOR_IMAGE_Hh_ +#define DLIB_RANDOMLY_COlOR_IMAGE_Hh_ + +#include "colormaps_abstract.h" +#include "../hash.h" +#include "../pixel.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + struct op_randomly_color_image : does_not_alias + { + op_randomly_color_image( const T& img_) : img(img_){} + + const T& img; + + const static long cost = 7; + const static long NR = 0; + const static long NC = 0; + typedef rgb_pixel type; + typedef const rgb_pixel const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const + { + const unsigned long gray = get_pixel_intensity(mat(img)(r,c)); + if (gray != 0) + { + const uint32 h = murmur_hash3_2(gray,0); + rgb_pixel pix; + pix.red = static_cast(h)%200 + 55; + pix.green = static_cast(h>>8)%200 + 55; + pix.blue = static_cast(h>>16)%200 + 55; + return pix; + } + else + { + // keep black pixels black + return rgb_pixel(0,0,0); + } + } + + long nr () const { return num_rows(img); } + long nc () const { return num_columns(img); } + }; + + template < + typename image_type + > + const matrix_op > + randomly_color_image ( + const image_type& img + ) + { + typedef op_randomly_color_image op; + return matrix_op(op(img)); + } + +// ---------------------------------------------------------------------------------------- + + inline rgb_pixel colormap_heat ( + double value, + double min_val, + double max_val + ) + { + // scale the gray value into the range [0, 1] + const double gray = put_in_range(0, 1, (value - min_val)/(max_val-min_val)); + rgb_pixel pix(0,0,0); + + pix.red = static_cast(std::min(gray/0.4,1.0)*255 + 0.5); + + if (gray > 0.4) + { + pix.green = static_cast(std::min((gray-0.4)/0.4,1.0)*255 + 0.5); + } + if (gray > 0.8) + { + pix.blue = static_cast(std::min((gray-0.8)/0.2,1.0)*255 + 0.5); + } + + return pix; + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_heatmap : does_not_alias + { + op_heatmap( + const T& img_, + const double max_val_, + const double min_val_ + ) : img(img_), max_val(max_val_), min_val(min_val_){} + + const T& img; + + const double max_val; + const double min_val; + + const static long cost = 7; + const static long NR = 0; + const static long NC = 0; + typedef rgb_pixel type; + typedef const rgb_pixel const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const + { + return colormap_heat(get_pixel_intensity(mat(img)(r,c)), min_val, max_val); + } + + long nr () const { return num_rows(img); } + long nc () const { return num_columns(img); } + }; + + template < + typename image_type + > + const matrix_op > + heatmap ( + const image_type& img, + double max_val, + double min_val = 0 + ) + { + typedef op_heatmap op; + return matrix_op(op(img,max_val,min_val)); + } + + template < + typename image_type + > + const matrix_op > + heatmap ( + const image_type& img + ) + { + typedef op_heatmap op; + if (num_columns(img) * num_rows(img) != 0) + return matrix_op(op(img,max(mat(img)),min(mat(img)))); + else + return matrix_op(op(img,0,0)); + } + +// ---------------------------------------------------------------------------------------- + + inline rgb_pixel colormap_jet ( + double value, + double min_val, + double max_val + ) + { + // scale the gray value into the range [0, 8] + const double gray = 8*put_in_range(0, 1, (value - min_val)/(max_val-min_val)); + rgb_pixel pix; + // s is the slope of color change + const double s = 1.0/2.0; + + if (gray <= 1) + { + pix.red = 0; + pix.green = 0; + pix.blue = static_cast((gray+1)*s*255 + 0.5); + } + else if (gray <= 3) + { + pix.red = 0; + pix.green = static_cast((gray-1)*s*255 + 0.5); + pix.blue = 255; + } + else if (gray <= 5) + { + pix.red = static_cast((gray-3)*s*255 + 0.5); + pix.green = 255; + pix.blue = static_cast((5-gray)*s*255 + 0.5); + } + else if (gray <= 7) + { + pix.red = 255; + pix.green = static_cast((7-gray)*s*255 + 0.5); + pix.blue = 0; + } + else + { + pix.red = static_cast((9-gray)*s*255 + 0.5); + pix.green = 0; + pix.blue = 0; + } + + return pix; + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_jet : does_not_alias + { + op_jet( + const T& img_, + const double max_val_, + const double min_val_ + ) : img(img_), max_val(max_val_), min_val(min_val_){} + + const T& img; + + const double max_val; + const double min_val; + + const static long cost = 7; + const static long NR = 0; + const static long NC = 0; + typedef rgb_pixel type; + typedef const rgb_pixel const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const + { + return colormap_jet(get_pixel_intensity(mat(img)(r,c)), min_val, max_val); + } + + long nr () const { return num_rows(img); } + long nc () const { return num_columns(img); } + }; + + template < + typename image_type + > + const matrix_op > + jet ( + const image_type& img, + double max_val, + double min_val = 0 + ) + { + typedef op_jet op; + return matrix_op(op(img,max_val,min_val)); + } + + template < + typename image_type + > + const matrix_op > + jet ( + const image_type& img + ) + { + typedef op_jet op; + if (num_columns(img) * num_rows(img) != 0) + return matrix_op(op(img,max(mat(img)),min(mat(img)))); + else + return matrix_op(op(img,0,0)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANDOMLY_COlOR_IMAGE_Hh_ + diff --git a/ml/dlib/dlib/image_transforms/colormaps_abstract.h b/ml/dlib/dlib/image_transforms/colormaps_abstract.h new file mode 100644 index 000000000..41a7784ba --- /dev/null +++ b/ml/dlib/dlib/image_transforms/colormaps_abstract.h @@ -0,0 +1,152 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RANDOMLY_COlOR_IMAGE_ABSTRACT_Hh_ +#ifdef DLIB_RANDOMLY_COlOR_IMAGE_ABSTRACT_Hh_ + +#include "../hash.h" +#include "../pixel.h" +#include "../matrix.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + const matrix_exp randomly_color_image ( + const image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h, or something convertible to a matrix + via mat(). + ensures + - randomly generates a mapping from gray level pixel values + to the RGB pixel space and then uses this mapping to create + a colored version of img. Returns a matrix which represents + this colored version of img. + - black pixels in img will remain black in the output image. + - The returned matrix will have the same dimensions as img. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + rgb_pixel colormap_heat ( + double value, + double min_val, + double max_val + ); + /*! + requires + - min_val <= max_val + ensures + - Maps value to a color. In particular, we use a heatmap color scheme where + values <= min_val are black and larger values become more red, then yellow, + and then white as they approach max_val. + !*/ + + template < + typename image_type + > + const matrix_exp heatmap ( + const image_type& img, + double max_val, + double min_val = 0 + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h, or something convertible to a matrix + via mat(). + ensures + - Interprets img as a grayscale image and returns a new matrix which represents + a colored version of img. In particular, the colormap is defined by + out_color = colormap_heat(grayscale_pixel_value, min_val, max_val). + - The returned matrix will have the same dimensions as img. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + const matrix_exp heatmap ( + const image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h, or something convertible to a matrix + via mat(). + ensures + - returns heatmap(img, max(mat(img)), min(mat(img))) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + rgb_pixel colormap_jet ( + double value, + double min_val, + double max_val + ); + /*! + requires + - min_val <= max_val + ensures + - Maps value to a color. In particular, we use a jet color scheme where + values <= min_val are dark blue and larger values become light blue, then + yellow, and then finally red as they approach max_val. + !*/ + + template < + typename image_type + > + const matrix_exp jet ( + const image_type& img, + double max_val, + double min_val = 0 + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h, or something convertible to a matrix + via mat(). + ensures + - Interprets img as a grayscale image and returns a new matrix which represents + a colored version of img. In particular, the colormap is defined by + out_color = colormap_jet(grayscale_pixel_value, min_val, max_val). + - The returned matrix will have the same dimensions as img. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + const matrix_exp jet ( + const image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h, or something convertible to a matrix + via mat(). + ensures + - returns jet(img, max(mat(img)), min(mat(img))) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANDOMLY_COlOR_IMAGE_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_transforms/draw.h b/ml/dlib/dlib/image_transforms/draw.h new file mode 100644 index 000000000..66737b215 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/draw.h @@ -0,0 +1,396 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DRAW_IMAGe_ +#define DLIB_DRAW_IMAGe_ + +#include "draw_abstract.h" +#include "../algs.h" +#include "../pixel.h" +#include "../matrix.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void draw_line ( + long x1, + long y1, + long x2, + long y2, + image_type& c_, + const pixel_type& val + ) + { + image_view c(c_); + if (x1 == x2) + { + // make sure y1 comes before y2 + if (y1 > y2) + swap(y1,y2); + + if (x1 < 0 || x1 >= c.nc()) + return; + + + // this is a vertical line + for (long y = y1; y <= y2; ++y) + { + if (y < 0 || y >= c.nr()) + continue; + + assign_pixel(c[y][x1], val); + } + } + else if (y1 == y2) + { + + // make sure x1 comes before x2 + if (x1 > x2) + swap(x1,x2); + + if (y1 < 0 || y1 >= c.nr()) + return; + + // this is a horizontal line + for (long x = x1; x <= x2; ++x) + { + if (x < 0 || x >= c.nc()) + continue; + + assign_pixel(c[y1][x] , val); + } + } + else + { + // This part is a little more complicated because we are going to perform alpha + // blending so the diagonal lines look nice. + const rectangle valid_area = get_rect(c); + rgb_alpha_pixel alpha_pixel; + assign_pixel(alpha_pixel, val); + const unsigned char max_alpha = alpha_pixel.alpha; + + const long rise = (((long)y2) - ((long)y1)); + const long run = (((long)x2) - ((long)x1)); + if (std::abs(rise) < std::abs(run)) + { + const double slope = ((double)rise)/run; + + + double first, last; + + + if (x1 > x2) + { + first = std::max(x2,valid_area.left()); + last = std::min(x1,valid_area.right()); + } + else + { + first = std::max(x1,valid_area.left()); + last = std::min(x2,valid_area.right()); + } + + long y; + long x; + const double x1f = x1; + const double y1f = y1; + for (double i = first; i <= last; ++i) + { + const double dy = slope*(i-x1f) + y1f; + const double dx = i; + + y = static_cast(dy); + x = static_cast(dx); + + + if (y >= valid_area.top() && y <= valid_area.bottom()) + { + alpha_pixel.alpha = static_cast((1.0-(dy-y))*max_alpha); + assign_pixel(c[y][x], alpha_pixel); + } + if (y+1 >= valid_area.top() && y+1 <= valid_area.bottom()) + { + alpha_pixel.alpha = static_cast((dy-y)*max_alpha); + assign_pixel(c[y+1][x], alpha_pixel); + } + } + } + else + { + const double slope = ((double)run)/rise; + + + double first, last; + + + if (y1 > y2) + { + first = std::max(y2,valid_area.top()); + last = std::min(y1,valid_area.bottom()); + } + else + { + first = std::max(y1,valid_area.top()); + last = std::min(y2,valid_area.bottom()); + } + + long x; + long y; + const double x1f = x1; + const double y1f = y1; + for (double i = first; i <= last; ++i) + { + const double dx = slope*(i-y1f) + x1f; + const double dy = i; + + y = static_cast(dy); + x = static_cast(dx); + + if (x >= valid_area.left() && x <= valid_area.right()) + { + alpha_pixel.alpha = static_cast((1.0-(dx-x))*max_alpha); + assign_pixel(c[y][x], alpha_pixel); + } + if (x+1 >= valid_area.left() && x+1 <= valid_area.right()) + { + alpha_pixel.alpha = static_cast((dx-x)*max_alpha); + assign_pixel(c[y][x+1], alpha_pixel); + } + } + } + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void draw_line ( + image_type& c, + const point& p1, + const point& p2, + const pixel_type& val + ) + { + draw_line(p1.x(),p1.y(),p2.x(),p2.y(),c,val); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void draw_rectangle ( + image_type& c, + const rectangle& rect, + const pixel_type& val + ) + { + draw_line(c, rect.tl_corner(), rect.tr_corner(), val); + draw_line(c, rect.bl_corner(), rect.br_corner(), val); + draw_line(c, rect.tl_corner(), rect.bl_corner(), val); + draw_line(c, rect.tr_corner(), rect.br_corner(), val); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void draw_rectangle ( + image_type& c, + const rectangle& rect, + const pixel_type& val, + unsigned int thickness + ) + { + for (unsigned int i = 0; i < thickness; ++i) + { + if ((i%2)==0) + draw_rectangle(c,shrink_rect(rect,(i+1)/2),val); + else + draw_rectangle(c,grow_rect(rect,(i+1)/2),val); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void fill_rect ( + image_type& img_, + const rectangle& rect, + const pixel_type& pixel + ) + { + image_view img(img_); + rectangle area = rect.intersect(get_rect(img)); + + for (long r = area.top(); r <= area.bottom(); ++r) + { + for (long c = area.left(); c <= area.right(); ++c) + { + assign_pixel(img[r][c], pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + matrix::pixel_type> tile_images ( + const image_array_type& images + ) + { + typedef typename image_traits::pixel_type T; + + if (images.size() == 0) + return matrix(); + + const unsigned long size_nc = square_root(images.size()); + const unsigned long size_nr = (size_nc*(size_nc-1)>=images.size())? size_nc-1 : size_nc; + // Figure out the size we have to use for each chip in the big main image. We will + // use the largest dimensions seen across all the chips. + long nr = 0; + long nc = 0; + for (unsigned long i = 0; i < images.size(); ++i) + { + nr = std::max(num_rows(images[i]), nr); + nc = std::max(num_columns(images[i]), nc); + } + + matrix temp(size_nr*nr, size_nc*nc); + T background_color; + assign_pixel(background_color, 0); + temp = background_color; + unsigned long idx = 0; + for (unsigned long r = 0; r < size_nr; ++r) + { + for (unsigned long c = 0; c < size_nc; ++c) + { + if (idx < images.size()) + { + set_subm(temp, r*nr, c*nc, nr, nc) = mat(images[idx]); + } + ++idx; + } + } + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void draw_solid_circle ( + image_type& img_, + const dpoint& center_point, + double radius, + const pixel_type& pixel + ) + { + image_view img(img_); + using std::sqrt; + const rectangle valid_area(get_rect(img)); + const double x = center_point.x(); + const double y = center_point.y(); + const point cp(center_point); + if (radius > 1) + { + long first_x = static_cast(x - radius + 0.5); + long last_x = static_cast(x + radius + 0.5); + const double rs = radius*radius; + + // ensure that we only loop over the part of the x dimension that this + // image contains. + if (first_x < valid_area.left()) + first_x = valid_area.left(); + if (last_x > valid_area.right()) + last_x = valid_area.right(); + + long top, bottom; + + top = static_cast(sqrt(std::max(rs - (first_x-x-0.5)*(first_x-x-0.5),0.0))+0.5); + top += y; + long last = top; + + // draw the left half of the circle + long middle = std::min(cp.x()-1,last_x); + for (long i = first_x; i <= middle; ++i) + { + double a = i - x + 0.5; + // find the top of the arc + top = static_cast(sqrt(std::max(rs - a*a,0.0))+0.5); + top += y; + long temp = top; + + while(top >= last) + { + bottom = y - top + y; + draw_line(img_, point(i,top),point(i,bottom),pixel); + --top; + } + + last = temp; + } + + middle = std::max(cp.x(),first_x); + top = static_cast(sqrt(std::max(rs - (last_x-x+0.5)*(last_x-x+0.5),0.0))+0.5); + top += y; + last = top; + // draw the right half of the circle + for (long i = last_x; i >= middle; --i) + { + double a = i - x - 0.5; + // find the top of the arc + top = static_cast(sqrt(std::max(rs - a*a,0.0))+0.5); + top += y; + long temp = top; + + while(top >= last) + { + bottom = y - top + y; + draw_line(img_, point(i,top),point(i,bottom),pixel); + --top; + } + + last = temp; + } + } + else if (valid_area.contains(cp)) + { + // For circles smaller than a pixel we will just alpha blend them in proportion + // to how small they are. + rgb_alpha_pixel temp; + assign_pixel(temp, pixel); + temp.alpha = static_cast(255*radius + 0.5); + assign_pixel(img[cp.y()][cp.x()], temp); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DRAW_IMAGe_ + + + + diff --git a/ml/dlib/dlib/image_transforms/draw_abstract.h b/ml/dlib/dlib/image_transforms/draw_abstract.h new file mode 100644 index 000000000..6631f8d8f --- /dev/null +++ b/ml/dlib/dlib/image_transforms/draw_abstract.h @@ -0,0 +1,150 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DRAW_IMAGe_ABSTRACT +#ifdef DLIB_DRAW_IMAGe_ABSTRACT + +#include "../matrix.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void draw_line ( + image_type& img, + const point& p1, + const point& p2, + const pixel_type& val + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - #img.nr() == img.nr() && #img.nc() == img.nc() + (i.e. the dimensions of the input image are not changed) + - for all valid r and c that are on the line between point p1 and p2: + - performs assign_pixel(img[r][c], val) + (i.e. it draws the line from p1 to p2 onto the image) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void draw_line ( + long x1, + long y1, + long x2, + long y2, + image_type& img, + const pixel_type& val + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - performs draw_line(img, point(x1,y1), point(x2,y2), val) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void draw_rectangle ( + image_type& img, + const rectangle& rect, + const pixel_type& val, + unsigned int thickness = 1 + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits is defined + ensures + - Draws the given rectangle onto the image img. It does this by calling + draw_line() four times to draw the four sides of the rectangle. + - The rectangle is drawn with the color given by val. + - The drawn rectangle will have edges that are thickness pixels wide. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void draw_solid_circle ( + image_type& img, + const dpoint& center_point, + double radius, + const pixel_type& pixel + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits is defined + ensures + - Draws a fully filled in circle onto image that is centered at center_point + and has the given radius. The circle will be filled by assigning the given + pixel value to each element of the circle. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pixel_type + > + void fill_rect ( + image_type& img, + const rectangle& rect, + const pixel_type& pixel + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits is defined + ensures + - fills the area defined by rect in the given image with the given pixel value. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + matrix::pixel_type> tile_images ( + const image_array_type& images + ); + /*! + requires + - image_array_type is a dlib::array of image objects where each image object + implements the interface defined in dlib/image_processing/generic_image.h + ensures + - This function takes the given images and tiles them into a single large + square image and returns this new big tiled image. Therefore, it is a useful + method to visualize many small images at once. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DRAW_IMAGe_ABSTRACT + + + diff --git a/ml/dlib/dlib/image_transforms/edge_detector.h b/ml/dlib/dlib/image_transforms/edge_detector.h new file mode 100644 index 000000000..2fa898fed --- /dev/null +++ b/ml/dlib/dlib/image_transforms/edge_detector.h @@ -0,0 +1,302 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_EDGE_DETECTOr_ +#define DLIB_EDGE_DETECTOr_ + +#include "edge_detector_abstract.h" +#include "../pixel.h" +#include "../array2d.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + inline char edge_orientation ( + const T& x_, + const T& y_ + ) + { + + // if this is a perfectly horizontal gradient then return right away + if (x_ == 0) + { + return '|'; + } + else if (y_ == 0) // if this is a perfectly vertical gradient then return right away + { + return '-'; + } + + // Promote x so that when we multiply by 128 later we know overflow won't happen. + typedef typename promote::type type; + type x = x_; + type y = y_; + + if (x < 0) + { + x = -x; + if (y < 0) + { + y = -y; + x *= 128; + const type temp = x/y; + if (temp > 309) + return '-'; + else if (temp > 53) + return '/'; + else + return '|'; + } + else + { + x *= 128; + const type temp = x/y; + if (temp > 309) + return '-'; + else if (temp > 53) + return '\\'; + else + return '|'; + } + } + else + { + if (y < 0) + { + y = -y; + x *= 128; + + const type temp = x/y; + if (temp > 309) + return '-'; + else if (temp > 53) + return '\\'; + else + return '|'; + } + else + { + x *= 128; + + const type temp = x/y; + if (temp > 309) + return '-'; + else if (temp > 53) + return '/'; + else + return '|'; + } + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void sobel_edge_detector ( + const in_image_type& in_img_, + out_image_type& horz_, + out_image_type& vert_ + ) + { + typedef typename image_traits::pixel_type pixel_type; + COMPILE_TIME_ASSERT(pixel_traits::is_unsigned == false); + DLIB_ASSERT( !is_same_object(in_img_,horz_) && !is_same_object(in_img_,vert_) && + !is_same_object(horz_,vert_), + "\tvoid sobel_edge_detector(in_img_, horz_, vert_)" + << "\n\t You can't give the same image as more than one argument" + << "\n\t is_same_object(in_img_,horz_): " << is_same_object(in_img_,horz_) + << "\n\t is_same_object(in_img_,vert_): " << is_same_object(in_img_,vert_) + << "\n\t is_same_object(horz_,vert_): " << is_same_object(horz_,vert_) + ); + + + const int vert_filter[3][3] = {{-1,-2,-1}, + {0,0,0}, + {1,2,1}}; + const int horz_filter[3][3] = { {-1,0,1}, + {-2,0,2}, + {-1,0,1}}; + + const long M = 3; + const long N = 3; + + + const_image_view in_img(in_img_); + image_view horz(horz_); + image_view vert(vert_); + + horz.set_size(in_img.nr(),in_img.nc()); + vert.set_size(in_img.nr(),in_img.nc()); + + assign_border_pixels(horz,1,1,0); + assign_border_pixels(vert,1,1,0); + + // figure out the range that we should apply the filter to + const long first_row = M/2; + const long first_col = N/2; + const long last_row = in_img.nr() - M/2; + const long last_col = in_img.nc() - N/2; + + + // apply the filter to the image + for (long r = first_row; r < last_row; ++r) + { + for (long c = first_col; c < last_col; ++c) + { + typedef typename pixel_traits::pixel_type>::basic_pixel_type bp_type; + + typename promote::type p, horz_temp, vert_temp; + horz_temp = 0; + vert_temp = 0; + for (long m = 0; m < M; ++m) + { + for (long n = 0; n < N; ++n) + { + // pull out the current pixel and put it into p + p = get_pixel_intensity(in_img[r-M/2+m][c-N/2+n]); + + horz_temp += p*horz_filter[m][n]; + vert_temp += p*vert_filter[m][n]; + } + } + + assign_pixel(horz[r][c] , horz_temp); + assign_pixel(vert[r][c] , vert_temp); + + } + } + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + typename promote::type square (const T& a) + { + return static_cast(a)*static_cast(a); + } + } + + template < + typename in_image_type, + typename out_image_type + > + void suppress_non_maximum_edges ( + const in_image_type& horz_, + const in_image_type& vert_, + out_image_type& out_img_ + ) + { + const_image_view horz(horz_); + const_image_view vert(vert_); + image_view out_img(out_img_); + + COMPILE_TIME_ASSERT(is_signed_type::pixel_type>::value); + DLIB_ASSERT( horz.nr() == vert.nr() && horz.nc() == vert.nc(), + "\tvoid suppress_non_maximum_edges(horz, vert, out_img)" + << "\n\tYou have to give horz and vert gradient images that are the same size" + << "\n\thorz.nr(): " << horz.nr() + << "\n\thorz.nc(): " << horz.nc() + << "\n\tvert.nr(): " << vert.nr() + << "\n\tvert.nc(): " << vert.nc() + ); + DLIB_ASSERT( !is_same_object(out_img_,horz_) && !is_same_object(out_img_,vert_), + "\tvoid suppress_non_maximum_edges(horz_, vert_, out_img_)" + << "\n\t out_img can't be the same as one of the input images." + << "\n\t is_same_object(out_img_,horz_): " << is_same_object(out_img_,horz_) + << "\n\t is_same_object(out_img_,vert_): " << is_same_object(out_img_,vert_) + ); + + using std::min; + using std::abs; + + + // if there isn't any input image then don't do anything + if (horz.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(horz.nr(),horz.nc()); + + zero_border_pixels(out_img,1,1); + + // now do non maximum suppression while we copy the + const long M = 3; + const long N = 3; + + // figure out the range that we should apply the filter to + const long first_row = M/2; + const long first_col = N/2; + const long last_row = horz.nr() - M/2; + const long last_col = horz.nc() - N/2; + + + // apply the filter to the image + for (long r = first_row; r < last_row; ++r) + { + for (long c = first_col; c < last_col; ++c) + { + typedef typename promote::pixel_type>::type T; + const T y = horz[r][c]; + const T x = vert[r][c]; + + using impl::square; + + const T val = square(horz[r][c]) + square(vert[r][c]); + + const char ori = edge_orientation(x,y); + const unsigned char zero = 0; + switch (ori) + { + case '-': + if (square(horz[r-1][c])+square(vert[r-1][c]) > val || square(horz[r+1][c]) + square(vert[r+1][c]) > val) + assign_pixel(out_img[r][c] , zero); + else + assign_pixel(out_img[r][c] , std::sqrt((double)val)); + break; + + case '|': + if (square(horz[r][c-1]) + square(vert[r][c-1]) > val || square(horz[r][c+1]) + square(vert[r][c+1]) > val) + assign_pixel(out_img[r][c] , zero); + else + assign_pixel(out_img[r][c] , std::sqrt((double)val)); + break; + + case '/': + if (square(horz[r-1][c-1]) + square(vert[r-1][c-1]) > val || square(horz[r+1][c+1]) + square(vert[r+1][c+1]) > val) + assign_pixel(out_img[r][c] , zero); + else + assign_pixel(out_img[r][c] , std::sqrt((double)val)); + break; + + case '\\': + if (square(horz[r+1][c-1]) + square(vert[r+1][c-1]) > val || square(horz[r-1][c+1]) + square(vert[r-1][c+1]) > val) + assign_pixel(out_img[r][c] , zero); + else + assign_pixel(out_img[r][c] , std::sqrt((double)val)); + break; + + } + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_EDGE_DETECTOr_ + + + diff --git a/ml/dlib/dlib/image_transforms/edge_detector_abstract.h b/ml/dlib/dlib/image_transforms/edge_detector_abstract.h new file mode 100644 index 000000000..42c991665 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/edge_detector_abstract.h @@ -0,0 +1,112 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_EDGE_DETECTOr_ABSTRACT_ +#ifdef DLIB_EDGE_DETECTOr_ABSTRACT_ + +#include "../pixel.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + inline char edge_orientation ( + const T& x, + const T& y + ); + /*! + ensures + - returns the orientation of the line drawn from the origin to the point (x,y). + The orientation is represented pictorially using the four ascii + characters /,|,\, and -. + - if (the line is horizontal) then + returns '-' + - if (the line is vertical) then + returns '|' + - if (the line is diagonal with a positive slope) then + returns '/' + - if (the line is diagonal with a negative slope) then + returns '\\' + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void sobel_edge_detector ( + const in_image_type& in_img, + out_image_type& horz, + out_image_type& vert + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type must use signed grayscale pixels + - is_same_object(in_img,horz) == false + - is_same_object(in_img,vert) == false + - is_same_object(horz,vert) == false + ensures + - Applies the sobel edge detector to the given input image and stores the resulting + edge detections in the horz and vert images + - #horz.nr() == in_img.nr() + - #horz.nc() == in_img.nc() + - #vert.nr() == in_img.nr() + - #vert.nc() == in_img.nc() + - for all valid r and c: + - #horz[r][c] == the magnitude of the horizontal gradient at the point in_img[r][c] + - #vert[r][c] == the magnitude of the vertical gradient at the point in_img[r][c] + - edge_orientation(#vert[r][c], #horz[r][c]) == the edge direction at this point in + the image + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void suppress_non_maximum_edges ( + const in_image_type& horz, + const in_image_type& vert, + out_image_type& out_img + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - horz.nr() == vert.nr() + - horz.nc() == vert.nc() + - is_same_object(out_img, horz) == false + - is_same_object(out_img, vert) == false + - image_traits::pixel_type == A signed scalar type (e.g. int, double, etc.) + ensures + - #out_img.nr() = horz.nr() + - #out_img.nc() = horz.nc() + - let edge_strength(r,c) == sqrt(pow(horz[r][c],2) + pow(vert[r][c],2)) + (i.e. The Euclidean norm of the gradient) + - for all valid r and c: + - if (edge_strength(r,c) is at a maximum with respect to its 2 neighboring + pixels along the line given by edge_orientation(vert[r][c],horz[r][c])) then + - performs assign_pixel(#out_img[r][c], edge_strength(r,c)) + - else + - performs assign_pixel(#out_img[r][c], 0) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_EDGE_DETECTOr_ABSTRACT_ + + diff --git a/ml/dlib/dlib/image_transforms/equalize_histogram.h b/ml/dlib/dlib/image_transforms/equalize_histogram.h new file mode 100644 index 000000000..dd048759a --- /dev/null +++ b/ml/dlib/dlib/image_transforms/equalize_histogram.h @@ -0,0 +1,143 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_EQUALIZE_HISTOGRAm_ +#define DLIB_EQUALIZE_HISTOGRAm_ + +#include "../pixel.h" +#include "equalize_histogram_abstract.h" +#include +#include "../enable_if.h" +#include "../matrix.h" + +namespace dlib +{ + +// --------------------------------------------------------------------------------------- + + template < + typename in_image_type, + long R, + long C, + typename MM + > + void get_histogram ( + const in_image_type& in_img_, + matrix& hist + ) + { + typedef typename image_traits::pixel_type pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::is_unsigned == true ); + + typedef typename pixel_traits::basic_pixel_type in_image_basic_pixel_type; + COMPILE_TIME_ASSERT( sizeof(in_image_basic_pixel_type) <= 2); + + // make sure hist is the right size + if (R == 1) + hist.set_size(1,pixel_traits::max()+1); + else + hist.set_size(pixel_traits::max()+1,1); + + + set_all_elements(hist,0); + + const_image_view in_img(in_img_); + // compute the histogram + for (long r = 0; r < in_img.nr(); ++r) + { + for (long c = 0; c < in_img.nc(); ++c) + { + unsigned long p = get_pixel_intensity(in_img[r][c]); + ++hist(p); + } + } + } + +// --------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void equalize_histogram ( + const in_image_type& in_img_, + out_image_type& out_img_ + ) + { + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + COMPILE_TIME_ASSERT( pixel_traits::is_unsigned == true ); + COMPILE_TIME_ASSERT( pixel_traits::is_unsigned == true ); + + typedef typename pixel_traits::basic_pixel_type in_image_basic_pixel_type; + COMPILE_TIME_ASSERT( sizeof(in_image_basic_pixel_type) <= 2); + + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + unsigned long p; + + matrix histogram; + get_histogram(in_img_, histogram); + in_img = in_img_; + + double scale = pixel_traits::max(); + if (in_img.size() > histogram(0)) + scale /= in_img.size()-histogram(0); + else + scale = 0; + + // make the black pixels remain black in the output image + histogram(0) = 0; + + // compute the transform function + for (long i = 1; i < histogram.size(); ++i) + histogram(i) += histogram(i-1); + // scale so that it is in the range [0,pixel_traits::max()] + for (long i = 0; i < histogram.size(); ++i) + histogram(i) = static_cast(histogram(i)*scale); + + // now do the transform + for (long row = 0; row < in_img.nr(); ++row) + { + for (long col = 0; col < in_img.nc(); ++col) + { + p = histogram(get_pixel_intensity(in_img[row][col])); + assign_pixel(out_img[row][col], in_img[row][col]); + assign_pixel_intensity(out_img[row][col],p); + } + } + + } + + template < + typename image_type + > + void equalize_histogram ( + image_type& img + ) + { + equalize_histogram(img,img); + } + +// --------------------------------------------------------------------------------------- + +} + +#endif // DLIB_EQUALIZE_HISTOGRAm_ + + + diff --git a/ml/dlib/dlib/image_transforms/equalize_histogram_abstract.h b/ml/dlib/dlib/image_transforms/equalize_histogram_abstract.h new file mode 100644 index 000000000..2592aef1a --- /dev/null +++ b/ml/dlib/dlib/image_transforms/equalize_histogram_abstract.h @@ -0,0 +1,91 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_EQUALIZE_HISTOGRAm_ABSTRACT_ +#ifdef DLIB_EQUALIZE_HISTOGRAm_ABSTRACT_ + +#include "../pixel.h" +#include "../matrix.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// --------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void equalize_histogram ( + const in_image_type& in_img, + out_image_type& out_img + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - Let pixel_type be the type of pixel in either input or output images, then we + must have: + - pixel_traits::has_alpha == false + - pixel_traits::is_unsigned == true + - For the input image pixel type, we have the additional requirement that: + - pixel_traits::max() <= 65535 + ensures + - #out_img == the histogram equalized version of in_img + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + + template < + typename image_type + > + void equalize_histogram ( + image_type& img + ); + /*! + requires + - it is valid to call equalize_histogram(img,img) + ensures + - calls equalize_histogram(img,img); + !*/ + +// --------------------------------------------------------------------------------------- + + template < + typename in_image_type, + long R, + long C, + typename MM + > + void get_histogram ( + const in_image_type& in_img, + matrix& hist + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - Let pixel_type denote the type of pixel in in_img, then we must have: + - pixel_traits::is_unsigned == true + - pixel_traits::max() <= 65535 + - hist must be capable of representing a column vector of length + pixel_traits::max(). I.e. if R and C are nonzero + then they must be values that don't conflict with the previous sentence. + ensures + - #hist.size() == pixel_traits::max() + - #hist.nc() == 1 || #hist.nr() == 1 (i.e. hist is either a row or column vector) + - #hist == the histogram for in_img. I.e. it is the case that for all + valid i: + - hist(i) == the number of times a pixel with intensity i appears + in in_img + !*/ + +// --------------------------------------------------------------------------------------- + +} + +#endif // DLIB_EQUALIZE_HISTOGRAm_ABSTRACT_ + + diff --git a/ml/dlib/dlib/image_transforms/fhog.h b/ml/dlib/dlib/image_transforms/fhog.h new file mode 100644 index 000000000..d99973adf --- /dev/null +++ b/ml/dlib/dlib/image_transforms/fhog.h @@ -0,0 +1,1404 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_fHOG_Hh_ +#define DLIB_fHOG_Hh_ + +#include "fhog_abstract.h" +#include "../matrix.h" +#include "../array2d.h" +#include "../array.h" +#include "../geometry.h" +#include "assign_image.h" +#include "draw.h" +#include "interpolation.h" +#include "../simd.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl_fhog + { + template + inline typename dlib::enable_if_c::rgb>::type get_gradient ( + const int r, + const int c, + const image_type& img, + matrix& grad, + T& len + ) + { + matrix grad2, grad3; + // get the red gradient + grad(0) = (int)img[r][c+1].red-(int)img[r][c-1].red; + grad(1) = (int)img[r+1][c].red-(int)img[r-1][c].red; + len = length_squared(grad); + + // get the green gradient + grad2(0) = (int)img[r][c+1].green-(int)img[r][c-1].green; + grad2(1) = (int)img[r+1][c].green-(int)img[r-1][c].green; + T v2 = length_squared(grad2); + + // get the blue gradient + grad3(0) = (int)img[r][c+1].blue-(int)img[r][c-1].blue; + grad3(1) = (int)img[r+1][c].blue-(int)img[r-1][c].blue; + T v3 = length_squared(grad3); + + // pick color with strongest gradient + if (v2 > len) + { + len = v2; + grad = grad2; + } + if (v3 > len) + { + len = v3; + grad = grad3; + } + } + + template + inline typename dlib::enable_if_c::rgb>::type get_gradient ( + const int r, + const int c, + const image_type& img, + simd4f& grad_x, + simd4f& grad_y, + simd4f& len + ) + { + simd4i rleft((int)img[r][c-1].red, + (int)img[r][c].red, + (int)img[r][c+1].red, + (int)img[r][c+2].red); + simd4i rright((int)img[r][c+1].red, + (int)img[r][c+2].red, + (int)img[r][c+3].red, + (int)img[r][c+4].red); + simd4i rtop((int)img[r-1][c].red, + (int)img[r-1][c+1].red, + (int)img[r-1][c+2].red, + (int)img[r-1][c+3].red); + simd4i rbottom((int)img[r+1][c].red, + (int)img[r+1][c+1].red, + (int)img[r+1][c+2].red, + (int)img[r+1][c+3].red); + + simd4i gleft((int)img[r][c-1].green, + (int)img[r][c].green, + (int)img[r][c+1].green, + (int)img[r][c+2].green); + simd4i gright((int)img[r][c+1].green, + (int)img[r][c+2].green, + (int)img[r][c+3].green, + (int)img[r][c+4].green); + simd4i gtop((int)img[r-1][c].green, + (int)img[r-1][c+1].green, + (int)img[r-1][c+2].green, + (int)img[r-1][c+3].green); + simd4i gbottom((int)img[r+1][c].green, + (int)img[r+1][c+1].green, + (int)img[r+1][c+2].green, + (int)img[r+1][c+3].green); + + simd4i bleft((int)img[r][c-1].blue, + (int)img[r][c].blue, + (int)img[r][c+1].blue, + (int)img[r][c+2].blue); + simd4i bright((int)img[r][c+1].blue, + (int)img[r][c+2].blue, + (int)img[r][c+3].blue, + (int)img[r][c+4].blue); + simd4i btop((int)img[r-1][c].blue, + (int)img[r-1][c+1].blue, + (int)img[r-1][c+2].blue, + (int)img[r-1][c+3].blue); + simd4i bbottom((int)img[r+1][c].blue, + (int)img[r+1][c+1].blue, + (int)img[r+1][c+2].blue, + (int)img[r+1][c+3].blue); + + simd4i grad_x_red = rright-rleft; + simd4i grad_y_red = rbottom-rtop; + simd4i grad_x_green = gright-gleft; + simd4i grad_y_green = gbottom-gtop; + simd4i grad_x_blue = bright-bleft; + simd4i grad_y_blue = bbottom-btop; + + simd4i rlen = grad_x_red*grad_x_red + grad_y_red*grad_y_red; + simd4i glen = grad_x_green*grad_x_green + grad_y_green*grad_y_green; + simd4i blen = grad_x_blue*grad_x_blue + grad_y_blue*grad_y_blue; + + simd4i cmp = rlen>glen; + simd4i tgrad_x = select(cmp,grad_x_red,grad_x_green); + simd4i tgrad_y = select(cmp,grad_y_red,grad_y_green); + simd4i tlen = select(cmp,rlen,glen); + + cmp = tlen>blen; + grad_x = select(cmp,tgrad_x,grad_x_blue); + grad_y = select(cmp,tgrad_y,grad_y_blue); + len = select(cmp,tlen,blen); + } + + // ------------------------------------------------------------------------------------ + + template + inline typename dlib::enable_if_c::rgb>::type get_gradient( + const int r, + const int c, + const image_type& img, + simd8f& grad_x, + simd8f& grad_y, + simd8f& len + ) + { + simd8i rleft((int)img[r][c - 1].red, + (int)img[r][c].red, + (int)img[r][c + 1].red, + (int)img[r][c + 2].red, + (int)img[r][c + 3].red, + (int)img[r][c + 4].red, + (int)img[r][c + 5].red, + (int)img[r][c + 6].red); + simd8i rright((int)img[r][c + 1].red, + (int)img[r][c + 2].red, + (int)img[r][c + 3].red, + (int)img[r][c + 4].red, + (int)img[r][c + 5].red, + (int)img[r][c + 6].red, + (int)img[r][c + 7].red, + (int)img[r][c + 8].red); + simd8i rtop((int)img[r - 1][c].red, + (int)img[r - 1][c + 1].red, + (int)img[r - 1][c + 2].red, + (int)img[r - 1][c + 3].red, + (int)img[r - 1][c + 4].red, + (int)img[r - 1][c + 5].red, + (int)img[r - 1][c + 6].red, + (int)img[r - 1][c + 7].red); + simd8i rbottom((int)img[r + 1][c].red, + (int)img[r + 1][c + 1].red, + (int)img[r + 1][c + 2].red, + (int)img[r + 1][c + 3].red, + (int)img[r + 1][c + 4].red, + (int)img[r + 1][c + 5].red, + (int)img[r + 1][c + 6].red, + (int)img[r + 1][c + 7].red); + + simd8i gleft((int)img[r][c - 1].green, + (int)img[r][c].green, + (int)img[r][c + 1].green, + (int)img[r][c + 2].green, + (int)img[r][c + 3].green, + (int)img[r][c + 4].green, + (int)img[r][c + 5].green, + (int)img[r][c + 6].green); + simd8i gright((int)img[r][c + 1].green, + (int)img[r][c + 2].green, + (int)img[r][c + 3].green, + (int)img[r][c + 4].green, + (int)img[r][c + 5].green, + (int)img[r][c + 6].green, + (int)img[r][c + 7].green, + (int)img[r][c + 8].green); + simd8i gtop((int)img[r - 1][c].green, + (int)img[r - 1][c + 1].green, + (int)img[r - 1][c + 2].green, + (int)img[r - 1][c + 3].green, + (int)img[r - 1][c + 4].green, + (int)img[r - 1][c + 5].green, + (int)img[r - 1][c + 6].green, + (int)img[r - 1][c + 7].green); + simd8i gbottom((int)img[r + 1][c].green, + (int)img[r + 1][c + 1].green, + (int)img[r + 1][c + 2].green, + (int)img[r + 1][c + 3].green, + (int)img[r + 1][c + 4].green, + (int)img[r + 1][c + 5].green, + (int)img[r + 1][c + 6].green, + (int)img[r + 1][c + 7].green); + + simd8i bleft((int)img[r][c - 1].blue, + (int)img[r][c].blue, + (int)img[r][c + 1].blue, + (int)img[r][c + 2].blue, + (int)img[r][c + 3].blue, + (int)img[r][c + 4].blue, + (int)img[r][c + 5].blue, + (int)img[r][c + 6].blue); + simd8i bright((int)img[r][c + 1].blue, + (int)img[r][c + 2].blue, + (int)img[r][c + 3].blue, + (int)img[r][c + 4].blue, + (int)img[r][c + 5].blue, + (int)img[r][c + 6].blue, + (int)img[r][c + 7].blue, + (int)img[r][c + 8].blue); + simd8i btop((int)img[r - 1][c].blue, + (int)img[r - 1][c + 1].blue, + (int)img[r - 1][c + 2].blue, + (int)img[r - 1][c + 3].blue, + (int)img[r - 1][c + 4].blue, + (int)img[r - 1][c + 5].blue, + (int)img[r - 1][c + 6].blue, + (int)img[r - 1][c + 7].blue); + simd8i bbottom((int)img[r + 1][c].blue, + (int)img[r + 1][c + 1].blue, + (int)img[r + 1][c + 2].blue, + (int)img[r + 1][c + 3].blue, + (int)img[r + 1][c + 4].blue, + (int)img[r + 1][c + 5].blue, + (int)img[r + 1][c + 6].blue, + (int)img[r + 1][c + 7].blue); + + simd8i grad_x_red = rright - rleft; + simd8i grad_y_red = rbottom - rtop; + simd8i grad_x_green = gright - gleft; + simd8i grad_y_green = gbottom - gtop; + simd8i grad_x_blue = bright - bleft; + simd8i grad_y_blue = bbottom - btop; + + simd8i rlen = grad_x_red*grad_x_red + grad_y_red*grad_y_red; + simd8i glen = grad_x_green*grad_x_green + grad_y_green*grad_y_green; + simd8i blen = grad_x_blue*grad_x_blue + grad_y_blue*grad_y_blue; + + simd8i cmp = rlen > glen; + simd8i tgrad_x = select(cmp, grad_x_red, grad_x_green); + simd8i tgrad_y = select(cmp, grad_y_red, grad_y_green); + simd8i tlen = select(cmp, rlen, glen); + + cmp = tlen > blen; + grad_x = select(cmp, tgrad_x, grad_x_blue); + grad_y = select(cmp, tgrad_y, grad_y_blue); + len = select(cmp, tlen, blen); + } + + // ------------------------------------------------------------------------------------ + + template + inline typename dlib::disable_if_c::rgb>::type get_gradient ( + const int r, + const int c, + const image_type& img, + matrix& grad, + T& len + ) + { + grad(0) = (int)get_pixel_intensity(img[r][c+1])-(int)get_pixel_intensity(img[r][c-1]); + grad(1) = (int)get_pixel_intensity(img[r+1][c])-(int)get_pixel_intensity(img[r-1][c]); + len = length_squared(grad); + } + + template + inline typename dlib::disable_if_c::rgb>::type get_gradient ( + int r, + int c, + const image_type& img, + simd4f& grad_x, + simd4f& grad_y, + simd4f& len + ) + { + simd4i left((int)get_pixel_intensity(img[r][c-1]), + (int)get_pixel_intensity(img[r][c]), + (int)get_pixel_intensity(img[r][c+1]), + (int)get_pixel_intensity(img[r][c+2])); + simd4i right((int)get_pixel_intensity(img[r][c+1]), + (int)get_pixel_intensity(img[r][c+2]), + (int)get_pixel_intensity(img[r][c+3]), + (int)get_pixel_intensity(img[r][c+4])); + + simd4i top((int)get_pixel_intensity(img[r-1][c]), + (int)get_pixel_intensity(img[r-1][c+1]), + (int)get_pixel_intensity(img[r-1][c+2]), + (int)get_pixel_intensity(img[r-1][c+3])); + simd4i bottom((int)get_pixel_intensity(img[r+1][c]), + (int)get_pixel_intensity(img[r+1][c+1]), + (int)get_pixel_intensity(img[r+1][c+2]), + (int)get_pixel_intensity(img[r+1][c+3])); + + grad_x = right-left; + grad_y = bottom-top; + + len = (grad_x*grad_x + grad_y*grad_y); + } + + // ------------------------------------------------------------------------------------ + + template + inline typename dlib::disable_if_c::rgb>::type get_gradient( + int r, + int c, + const image_type& img, + simd8f& grad_x, + simd8f& grad_y, + simd8f& len + ) + { + simd8i left((int)get_pixel_intensity(img[r][c - 1]), + (int)get_pixel_intensity(img[r][c]), + (int)get_pixel_intensity(img[r][c + 1]), + (int)get_pixel_intensity(img[r][c + 2]), + (int)get_pixel_intensity(img[r][c + 3]), + (int)get_pixel_intensity(img[r][c + 4]), + (int)get_pixel_intensity(img[r][c + 5]), + (int)get_pixel_intensity(img[r][c + 6])); + simd8i right((int)get_pixel_intensity(img[r][c + 1]), + (int)get_pixel_intensity(img[r][c + 2]), + (int)get_pixel_intensity(img[r][c + 3]), + (int)get_pixel_intensity(img[r][c + 4]), + (int)get_pixel_intensity(img[r][c + 5]), + (int)get_pixel_intensity(img[r][c + 6]), + (int)get_pixel_intensity(img[r][c + 7]), + (int)get_pixel_intensity(img[r][c + 8])); + + simd8i top((int)get_pixel_intensity(img[r - 1][c]), + (int)get_pixel_intensity(img[r - 1][c + 1]), + (int)get_pixel_intensity(img[r - 1][c + 2]), + (int)get_pixel_intensity(img[r - 1][c + 3]), + (int)get_pixel_intensity(img[r - 1][c + 4]), + (int)get_pixel_intensity(img[r - 1][c + 5]), + (int)get_pixel_intensity(img[r - 1][c + 6]), + (int)get_pixel_intensity(img[r - 1][c + 7])); + simd8i bottom((int)get_pixel_intensity(img[r + 1][c]), + (int)get_pixel_intensity(img[r + 1][c + 1]), + (int)get_pixel_intensity(img[r + 1][c + 2]), + (int)get_pixel_intensity(img[r + 1][c + 3]), + (int)get_pixel_intensity(img[r + 1][c + 4]), + (int)get_pixel_intensity(img[r + 1][c + 5]), + (int)get_pixel_intensity(img[r + 1][c + 6]), + (int)get_pixel_intensity(img[r + 1][c + 7])); + + grad_x = right - left; + grad_y = bottom - top; + + len = (grad_x*grad_x + grad_y*grad_y); + } + + // ------------------------------------------------------------------------------------ + + template + inline void set_hog ( + dlib::array,mm2>& hog, + int o, + int x, + int y, + const float& value + ) + { + hog[o][y][x] = value; + } + + template + void init_hog ( + dlib::array,mm2>& hog, + int hog_nr, + int hog_nc, + int filter_rows_padding, + int filter_cols_padding + ) + { + const int num_hog_bands = 27+4; + hog.resize(num_hog_bands); + for (int i = 0; i < num_hog_bands; ++i) + { + hog[i].set_size(hog_nr+filter_rows_padding-1, hog_nc+filter_cols_padding-1); + rectangle rect = get_rect(hog[i]); + rect.top() += (filter_rows_padding-1)/2; + rect.left() += (filter_cols_padding-1)/2; + rect.right() -= filter_cols_padding/2; + rect.bottom() -= filter_rows_padding/2; + zero_border_pixels(hog[i],rect); + } + } + + template + void init_hog_zero_everything ( + dlib::array,mm2>& hog, + int hog_nr, + int hog_nc, + int filter_rows_padding, + int filter_cols_padding + ) + { + const int num_hog_bands = 27+4; + hog.resize(num_hog_bands); + for (int i = 0; i < num_hog_bands; ++i) + { + hog[i].set_size(hog_nr+filter_rows_padding-1, hog_nc+filter_cols_padding-1); + assign_all_pixels(hog[i], 0); + } + } + + // ------------------------------------------------------------------------------------ + + template + inline void set_hog ( + array2d,mm>& hog, + int o, + int x, + int y, + const float& value + ) + { + hog[y][x](o) = value; + } + + template + void init_hog ( + array2d,mm>& hog, + int hog_nr, + int hog_nc, + int filter_rows_padding, + int filter_cols_padding + ) + { + hog.set_size(hog_nr+filter_rows_padding-1, hog_nc+filter_cols_padding-1); + + // now zero out the border region + rectangle rect = get_rect(hog); + rect.top() += (filter_rows_padding-1)/2; + rect.left() += (filter_cols_padding-1)/2; + rect.right() -= filter_cols_padding/2; + rect.bottom() -= filter_rows_padding/2; + border_enumerator be(get_rect(hog),rect); + while (be.move_next()) + { + const point p = be.element(); + set_all_elements(hog[p.y()][p.x()], 0); + } + } + + template + void init_hog_zero_everything ( + array2d,mm>& hog, + int hog_nr, + int hog_nc, + int filter_rows_padding, + int filter_cols_padding + ) + { + hog.set_size(hog_nr+filter_rows_padding-1, hog_nc+filter_cols_padding-1); + + for (long r = 0; r < hog.nr(); ++r) + { + for (long c = 0; c < hog.nc(); ++c) + { + set_all_elements(hog[r][c], 0); + } + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename image_type, + typename out_type + > + void impl_extract_fhog_features_cell_size_1( + const image_type& img_, + out_type& hog, + int filter_rows_padding, + int filter_cols_padding + ) + { + const_image_view img(img_); + // make sure requires clause is not broken + DLIB_ASSERT( filter_rows_padding > 0 && + filter_cols_padding > 0 , + "\t void extract_fhog_features()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t filter_rows_padding: " << filter_rows_padding + << "\n\t filter_cols_padding: " << filter_cols_padding + ); + + /* + This function is an optimized version of impl_extract_fhog_features() for + the case where cell_size == 1. + */ + + + // unit vectors used to compute gradient orientation + matrix directions[9]; + directions[0] = 1.0000, 0.0000; + directions[1] = 0.9397, 0.3420; + directions[2] = 0.7660, 0.6428; + directions[3] = 0.500, 0.8660; + directions[4] = 0.1736, 0.9848; + directions[5] = -0.1736, 0.9848; + directions[6] = -0.5000, 0.8660; + directions[7] = -0.7660, 0.6428; + directions[8] = -0.9397, 0.3420; + + + + if (img.nr() <= 2 || img.nc() <= 2) + { + hog.clear(); + return; + } + + array2d angle(img.nr(), img.nc()); + + array2d norm(img.nr(), img.nc()); + zero_border_pixels(norm,1,1); + + // memory for HOG features + const long hog_nr = img.nr()-2; + const long hog_nc = img.nc()-2; + + const int padding_rows_offset = (filter_rows_padding-1)/2; + const int padding_cols_offset = (filter_cols_padding-1)/2; + init_hog_zero_everything(hog, hog_nr, hog_nc, filter_rows_padding, filter_cols_padding); + + + const int visible_nr = img.nr()-1; + const int visible_nc = img.nc()-1; + + // First populate the gradient histograms + for (int y = 1; y < visible_nr; y++) + { + int x; + for (x = 1; x < visible_nc - 7; x += 8) + { + // v will be the length of the gradient vectors. + simd8f grad_x, grad_y, v; + get_gradient(y, x, img, grad_x, grad_y, v); + + float _vv[8]; + v.store(_vv); + + // Now snap the gradient to one of 18 orientations + simd8f best_dot = 0; + simd8f best_o = 0; + for (int o = 0; o < 9; o++) + { + simd8f dot = grad_x*directions[o](0) + grad_y*directions[o](1); + simd8f_bool cmp = dot>best_dot; + best_dot = select(cmp, dot, best_dot); + dot *= -1; + best_o = select(cmp, o, best_o); + + cmp = dot > best_dot; + best_dot = select(cmp, dot, best_dot); + best_o = select(cmp, o + 9, best_o); + } + + int32 _best_o[8]; simd8i(best_o).store(_best_o); + + norm[y][x + 0] = _vv[0]; + norm[y][x + 1] = _vv[1]; + norm[y][x + 2] = _vv[2]; + norm[y][x + 3] = _vv[3]; + norm[y][x + 4] = _vv[4]; + norm[y][x + 5] = _vv[5]; + norm[y][x + 6] = _vv[6]; + norm[y][x + 7] = _vv[7]; + + angle[y][x + 0] = _best_o[0]; + angle[y][x + 1] = _best_o[1]; + angle[y][x + 2] = _best_o[2]; + angle[y][x + 3] = _best_o[3]; + angle[y][x + 4] = _best_o[4]; + angle[y][x + 5] = _best_o[5]; + angle[y][x + 6] = _best_o[6]; + angle[y][x + 7] = _best_o[7]; + } + // Now process the right columns that don't fit into simd registers. + for (; x < visible_nc; x++) + { + matrix grad; + float v; + get_gradient(y,x,img,grad,v); + + // snap to one of 18 orientations + float best_dot = 0; + int best_o = 0; + for (int o = 0; o < 9; o++) + { + const float dot = dlib::dot(directions[o], grad); + if (dot > best_dot) + { + best_dot = dot; + best_o = o; + } + else if (-dot > best_dot) + { + best_dot = -dot; + best_o = o+9; + } + } + + norm[y][x] = v; + angle[y][x] = best_o; + } + } + + const float eps = 0.0001; + // compute features + for (int y = 0; y < hog_nr; y++) + { + const int yy = y+padding_rows_offset; + for (int x = 0; x < hog_nc; x++) + { + const simd4f z1(norm[y+1][x+1], + norm[y][x+1], + norm[y+1][x], + norm[y][x]); + + const simd4f z2(norm[y+1][x+2], + norm[y][x+2], + norm[y+1][x+1], + norm[y][x+1]); + + const simd4f z3(norm[y+2][x+1], + norm[y+1][x+1], + norm[y+2][x], + norm[y+1][x]); + + const simd4f z4(norm[y+2][x+2], + norm[y+1][x+2], + norm[y+2][x+1], + norm[y+1][x+1]); + + const simd4f temp0 = std::sqrt(norm[y+1][x+1]); + const simd4f nn = 0.2*sqrt(z1+z2+z3+z4+eps); + const simd4f n = 0.1/nn; + + simd4f t = 0; + + const int xx = x+padding_cols_offset; + + simd4f h0 = min(temp0,nn)*n; + const float vv = sum(h0); + set_hog(hog,angle[y+1][x+1],xx,yy, vv); + t += h0; + + t *= 2*0.2357; + + // contrast-insensitive features + set_hog(hog,angle[y+1][x+1]%9+18,xx,yy, vv); + + + float temp[4]; + t.store(temp); + + // texture features + set_hog(hog,27,xx,yy, temp[0]); + set_hog(hog,28,xx,yy, temp[1]); + set_hog(hog,29,xx,yy, temp[2]); + set_hog(hog,30,xx,yy, temp[3]); + } + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename image_type, + typename out_type + > + void impl_extract_fhog_features( + const image_type& img_, + out_type& hog, + int cell_size, + int filter_rows_padding, + int filter_cols_padding + ) + { + const_image_view img(img_); + // make sure requires clause is not broken + DLIB_ASSERT( cell_size > 0 && + filter_rows_padding > 0 && + filter_cols_padding > 0 , + "\t void extract_fhog_features()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t cell_size: " << cell_size + << "\n\t filter_rows_padding: " << filter_rows_padding + << "\n\t filter_cols_padding: " << filter_cols_padding + ); + + /* + This function implements the HOG feature extraction method described in + the paper: + P. Felzenszwalb, R. Girshick, D. McAllester, D. Ramanan + Object Detection with Discriminatively Trained Part Based Models + IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. 32, No. 9, Sep. 2010 + + Moreover, this function is derived from the HOG feature extraction code + from the features.cc file in the voc-releaseX code (see + http://people.cs.uchicago.edu/~rbg/latent/) which is has the following + license (note that the code has been modified to work with grayscale and + color as well as planar and interlaced input and output formats): + + Copyright (C) 2011, 2012 Ross Girshick, Pedro Felzenszwalb + Copyright (C) 2008, 2009, 2010 Pedro Felzenszwalb, Ross Girshick + Copyright (C) 2007 Pedro Felzenszwalb, Deva Ramanan + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + + if (cell_size == 1) + { + impl_extract_fhog_features_cell_size_1(img_,hog,filter_rows_padding,filter_cols_padding); + return; + } + + // unit vectors used to compute gradient orientation + matrix directions[9]; + directions[0] = 1.0000, 0.0000; + directions[1] = 0.9397, 0.3420; + directions[2] = 0.7660, 0.6428; + directions[3] = 0.500, 0.8660; + directions[4] = 0.1736, 0.9848; + directions[5] = -0.1736, 0.9848; + directions[6] = -0.5000, 0.8660; + directions[7] = -0.7660, 0.6428; + directions[8] = -0.9397, 0.3420; + + + + // First we allocate memory for caching orientation histograms & their norms. + const int cells_nr = (int)((float)img.nr()/(float)cell_size + 0.5); + const int cells_nc = (int)((float)img.nc()/(float)cell_size + 0.5); + + if (cells_nr == 0 || cells_nc == 0) + { + hog.clear(); + return; + } + + // We give hist extra padding around the edges (1 cell all the way around the + // edge) so we can avoid needing to do boundary checks when indexing into it + // later on. So some statements assign to the boundary but those values are + // never used. + array2d > hist(cells_nr+2, cells_nc+2); + for (long r = 0; r < hist.nr(); ++r) + { + for (long c = 0; c < hist.nc(); ++c) + { + hist[r][c] = 0; + } + } + + array2d norm(cells_nr, cells_nc); + assign_all_pixels(norm, 0); + + // memory for HOG features + const int hog_nr = std::max(cells_nr-2, 0); + const int hog_nc = std::max(cells_nc-2, 0); + if (hog_nr == 0 || hog_nc == 0) + { + hog.clear(); + return; + } + const int padding_rows_offset = (filter_rows_padding-1)/2; + const int padding_cols_offset = (filter_cols_padding-1)/2; + init_hog(hog, hog_nr, hog_nc, filter_rows_padding, filter_cols_padding); + + const int visible_nr = std::min((long)cells_nr*cell_size,img.nr())-1; + const int visible_nc = std::min((long)cells_nc*cell_size,img.nc())-1; + + // First populate the gradient histograms + for (int y = 1; y < visible_nr; y++) + { + const float yp = ((float)y+0.5)/(float)cell_size - 0.5; + const int iyp = (int)std::floor(yp); + const float vy0 = yp - iyp; + const float vy1 = 1.0 - vy0; + int x; + for (x = 1; x < visible_nc - 7; x += 8) + { + simd8f xx(x, x + 1, x + 2, x + 3, x + 4, x + 5, x + 6, x + 7); + // v will be the length of the gradient vectors. + simd8f grad_x, grad_y, v; + get_gradient(y, x, img, grad_x, grad_y, v); + + // We will use bilinear interpolation to add into the histogram bins. + // So first we precompute the values needed to determine how much each + // pixel votes into each bin. + simd8f xp = (xx + 0.5) / (float)cell_size + 0.5; + simd8i ixp = simd8i(xp); + simd8f vx0 = xp - ixp; + simd8f vx1 = 1.0f - vx0; + + v = sqrt(v); + + // Now snap the gradient to one of 18 orientations + simd8f best_dot = 0; + simd8f best_o = 0; + for (int o = 0; o < 9; o++) + { + simd8f dot = grad_x*directions[o](0) + grad_y*directions[o](1); + simd8f_bool cmp = dot>best_dot; + best_dot = select(cmp, dot, best_dot); + dot *= -1; + best_o = select(cmp, o, best_o); + + cmp = dot > best_dot; + best_dot = select(cmp, dot, best_dot); + best_o = select(cmp, o + 9, best_o); + } + + + // Add the gradient magnitude, v, to 4 histograms around pixel using + // bilinear interpolation. + vx1 *= v; + vx0 *= v; + // The amounts for each bin + simd8f v11 = vy1*vx1; + simd8f v01 = vy0*vx1; + simd8f v10 = vy1*vx0; + simd8f v00 = vy0*vx0; + + int32 _best_o[8]; simd8i(best_o).store(_best_o); + int32 _ixp[8]; ixp.store(_ixp); + float _v11[8]; v11.store(_v11); + float _v01[8]; v01.store(_v01); + float _v10[8]; v10.store(_v10); + float _v00[8]; v00.store(_v00); + + hist[iyp + 1][_ixp[0]](_best_o[0]) += _v11[0]; + hist[iyp + 1 + 1][_ixp[0]](_best_o[0]) += _v01[0]; + hist[iyp + 1][_ixp[0] + 1](_best_o[0]) += _v10[0]; + hist[iyp + 1 + 1][_ixp[0] + 1](_best_o[0]) += _v00[0]; + + hist[iyp + 1][_ixp[1]](_best_o[1]) += _v11[1]; + hist[iyp + 1 + 1][_ixp[1]](_best_o[1]) += _v01[1]; + hist[iyp + 1][_ixp[1] + 1](_best_o[1]) += _v10[1]; + hist[iyp + 1 + 1][_ixp[1] + 1](_best_o[1]) += _v00[1]; + + hist[iyp + 1][_ixp[2]](_best_o[2]) += _v11[2]; + hist[iyp + 1 + 1][_ixp[2]](_best_o[2]) += _v01[2]; + hist[iyp + 1][_ixp[2] + 1](_best_o[2]) += _v10[2]; + hist[iyp + 1 + 1][_ixp[2] + 1](_best_o[2]) += _v00[2]; + + hist[iyp + 1][_ixp[3]](_best_o[3]) += _v11[3]; + hist[iyp + 1 + 1][_ixp[3]](_best_o[3]) += _v01[3]; + hist[iyp + 1][_ixp[3] + 1](_best_o[3]) += _v10[3]; + hist[iyp + 1 + 1][_ixp[3] + 1](_best_o[3]) += _v00[3]; + + hist[iyp + 1][_ixp[4]](_best_o[4]) += _v11[4]; + hist[iyp + 1 + 1][_ixp[4]](_best_o[4]) += _v01[4]; + hist[iyp + 1][_ixp[4] + 1](_best_o[4]) += _v10[4]; + hist[iyp + 1 + 1][_ixp[4] + 1](_best_o[4]) += _v00[4]; + + hist[iyp + 1][_ixp[5]](_best_o[5]) += _v11[5]; + hist[iyp + 1 + 1][_ixp[5]](_best_o[5]) += _v01[5]; + hist[iyp + 1][_ixp[5] + 1](_best_o[5]) += _v10[5]; + hist[iyp + 1 + 1][_ixp[5] + 1](_best_o[5]) += _v00[5]; + + hist[iyp + 1][_ixp[6]](_best_o[6]) += _v11[6]; + hist[iyp + 1 + 1][_ixp[6]](_best_o[6]) += _v01[6]; + hist[iyp + 1][_ixp[6] + 1](_best_o[6]) += _v10[6]; + hist[iyp + 1 + 1][_ixp[6] + 1](_best_o[6]) += _v00[6]; + + hist[iyp + 1][_ixp[7]](_best_o[7]) += _v11[7]; + hist[iyp + 1 + 1][_ixp[7]](_best_o[7]) += _v01[7]; + hist[iyp + 1][_ixp[7] + 1](_best_o[7]) += _v10[7]; + hist[iyp + 1 + 1][_ixp[7] + 1](_best_o[7]) += _v00[7]; + } + // Now process the right columns that don't fit into simd registers. + for (; x < visible_nc; x++) + { + matrix grad; + float v; + get_gradient(y,x,img,grad,v); + + // snap to one of 18 orientations + float best_dot = 0; + int best_o = 0; + for (int o = 0; o < 9; o++) + { + const float dot = dlib::dot(directions[o], grad); + if (dot > best_dot) + { + best_dot = dot; + best_o = o; + } + else if (-dot > best_dot) + { + best_dot = -dot; + best_o = o+9; + } + } + + v = std::sqrt(v); + // add to 4 histograms around pixel using bilinear interpolation + const float xp = ((double)x + 0.5) / (double)cell_size - 0.5; + const int ixp = (int)std::floor(xp); + const float vx0 = xp - ixp; + const float vx1 = 1.0 - vx0; + + hist[iyp+1][ixp+1](best_o) += vy1*vx1*v; + hist[iyp+1+1][ixp+1](best_o) += vy0*vx1*v; + hist[iyp+1][ixp+1+1](best_o) += vy1*vx0*v; + hist[iyp+1+1][ixp+1+1](best_o) += vy0*vx0*v; + } + } + + // compute energy in each block by summing over orientations + for (int r = 0; r < cells_nr; ++r) + { + for (int c = 0; c < cells_nc; ++c) + { + for (int o = 0; o < 9; o++) + { + norm[r][c] += (hist[r+1][c+1](o) + hist[r+1][c+1](o+9)) * (hist[r+1][c+1](o) + hist[r+1][c+1](o+9)); + } + } + } + + const float eps = 0.0001; + // compute features + for (int y = 0; y < hog_nr; y++) + { + const int yy = y+padding_rows_offset; + for (int x = 0; x < hog_nc; x++) + { + const simd4f z1(norm[y+1][x+1], + norm[y][x+1], + norm[y+1][x], + norm[y][x]); + + const simd4f z2(norm[y+1][x+2], + norm[y][x+2], + norm[y+1][x+1], + norm[y][x+1]); + + const simd4f z3(norm[y+2][x+1], + norm[y+1][x+1], + norm[y+2][x], + norm[y+1][x]); + + const simd4f z4(norm[y+2][x+2], + norm[y+1][x+2], + norm[y+2][x+1], + norm[y+1][x+1]); + + const simd4f nn = 0.2*sqrt(z1+z2+z3+z4+eps); + const simd4f n = 0.1/nn; + + simd4f t = 0; + + const int xx = x+padding_cols_offset; + + // contrast-sensitive features + for (int o = 0; o < 18; o+=3) + { + simd4f temp0(hist[y+1+1][x+1+1](o)); + simd4f temp1(hist[y+1+1][x+1+1](o+1)); + simd4f temp2(hist[y+1+1][x+1+1](o+2)); + simd4f h0 = min(temp0,nn)*n; + simd4f h1 = min(temp1,nn)*n; + simd4f h2 = min(temp2,nn)*n; + set_hog(hog,o,xx,yy, sum(h0)); + set_hog(hog,o+1,xx,yy, sum(h1)); + set_hog(hog,o+2,xx,yy, sum(h2)); + t += h0+h1+h2; + } + + t *= 2*0.2357; + + // contrast-insensitive features + for (int o = 0; o < 9; o+=3) + { + simd4f temp0 = hist[y+1+1][x+1+1](o) + hist[y+1+1][x+1+1](o+9); + simd4f temp1 = hist[y+1+1][x+1+1](o+1) + hist[y+1+1][x+1+1](o+9+1); + simd4f temp2 = hist[y+1+1][x+1+1](o+2) + hist[y+1+1][x+1+1](o+9+2); + simd4f h0 = min(temp0,nn)*n; + simd4f h1 = min(temp1,nn)*n; + simd4f h2 = min(temp2,nn)*n; + set_hog(hog,o+18,xx,yy, sum(h0)); + set_hog(hog,o+18+1,xx,yy, sum(h1)); + set_hog(hog,o+18+2,xx,yy, sum(h2)); + } + + + float temp[4]; + t.store(temp); + + // texture features + set_hog(hog,27,xx,yy, temp[0]); + set_hog(hog,28,xx,yy, temp[1]); + set_hog(hog,29,xx,yy, temp[2]); + set_hog(hog,30,xx,yy, temp[3]); + } + } + } + + // ------------------------------------------------------------------------------------ + + inline void create_fhog_bar_images ( + dlib::array >& mbars, + const long w + ) + { + const long bdims = 9; + // Make the oriented lines we use to draw on each HOG cell. + mbars.resize(bdims); + dlib::array > bars(bdims); + array2d temp(w,w); + for (unsigned long i = 0; i < bars.size(); ++i) + { + assign_all_pixels(temp, 0); + draw_line(temp, point(w/2,0), point(w/2,w-1), 255); + rotate_image(temp, bars[i], i*-pi/bars.size()); + + mbars[i] = subm(matrix_cast(mat(bars[i])), centered_rect(get_rect(bars[i]),w,w) ); + } + } + + } // end namespace impl_fhog + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T, + typename mm1, + typename mm2 + > + void extract_fhog_features( + const image_type& img, + dlib::array,mm2>& hog, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ) + { + impl_fhog::impl_extract_fhog_features(img, hog, cell_size, filter_rows_padding, filter_cols_padding); + // If the image is too small then the above function outputs an empty feature map. + // But to make things very uniform in usage we require the output to still have the + // 31 planes (but they are just empty). + if (hog.size() == 0) + hog.resize(31); + } + + template < + typename image_type, + typename T, + typename mm + > + void extract_fhog_features( + const image_type& img, + array2d,mm>& hog, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ) + { + impl_fhog::impl_extract_fhog_features(img, hog, cell_size, filter_rows_padding, filter_cols_padding); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T + > + void extract_fhog_features( + const image_type& img, + matrix& feats, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ) + { + dlib::array > hog; + extract_fhog_features(img, hog, cell_size, filter_rows_padding, filter_cols_padding); + feats.set_size(hog.size()*hog[0].size()); + for (unsigned long i = 0; i < hog.size(); ++i) + { + const long size = hog[i].size(); + set_rowm(feats, range(i*size, (i+1)*size-1)) = reshape_to_column_vector(mat(hog[i])); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + matrix extract_fhog_features( + const image_type& img, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ) + { + matrix feats; + extract_fhog_features(img, feats, cell_size, filter_rows_padding, filter_cols_padding); + return feats; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + inline point image_to_fhog ( + point p, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( cell_size > 0 && + filter_rows_padding > 0 && + filter_cols_padding > 0 , + "\t point image_to_fhog()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t cell_size: " << cell_size + << "\n\t filter_rows_padding: " << filter_rows_padding + << "\n\t filter_cols_padding: " << filter_cols_padding + ); + + // There is a one pixel border around the image. + p -= point(1,1); + // There is also a 1 "cell" border around the HOG image formation. + return p/cell_size - point(1,1) + point((filter_cols_padding-1)/2,(filter_rows_padding-1)/2); + } + +// ---------------------------------------------------------------------------------------- + + inline rectangle image_to_fhog ( + const rectangle& rect, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( cell_size > 0 && + filter_rows_padding > 0 && + filter_cols_padding > 0 , + "\t rectangle image_to_fhog()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t cell_size: " << cell_size + << "\n\t filter_rows_padding: " << filter_rows_padding + << "\n\t filter_cols_padding: " << filter_cols_padding + ); + + return rectangle(image_to_fhog(rect.tl_corner(),cell_size,filter_rows_padding,filter_cols_padding), + image_to_fhog(rect.br_corner(),cell_size,filter_rows_padding,filter_cols_padding)); + } + +// ---------------------------------------------------------------------------------------- + + inline point fhog_to_image ( + point p, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( cell_size > 0 && + filter_rows_padding > 0 && + filter_cols_padding > 0 , + "\t point fhog_to_image()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t cell_size: " << cell_size + << "\n\t filter_rows_padding: " << filter_rows_padding + << "\n\t filter_cols_padding: " << filter_cols_padding + ); + + // Convert to image space and then set to the center of the cell. + point offset; + + p = (p+point(1,1)-point((filter_cols_padding-1)/2,(filter_rows_padding-1)/2))*cell_size + point(1,1); + if (p.x() >= 0 && p.y() >= 0) offset = point(cell_size/2,cell_size/2); + if (p.x() < 0 && p.y() >= 0) offset = point(-cell_size/2,cell_size/2); + if (p.x() >= 0 && p.y() < 0) offset = point(cell_size/2,-cell_size/2); + if (p.x() < 0 && p.y() < 0) offset = point(-cell_size/2,-cell_size/2); + return p + offset; + } + +// ---------------------------------------------------------------------------------------- + + inline rectangle fhog_to_image ( + const rectangle& rect, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( cell_size > 0 && + filter_rows_padding > 0 && + filter_cols_padding > 0 , + "\t rectangle fhog_to_image()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t cell_size: " << cell_size + << "\n\t filter_rows_padding: " << filter_rows_padding + << "\n\t filter_cols_padding: " << filter_cols_padding + ); + + return rectangle(fhog_to_image(rect.tl_corner(),cell_size,filter_rows_padding,filter_cols_padding), + fhog_to_image(rect.br_corner(),cell_size,filter_rows_padding,filter_cols_padding)); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mm1, + typename mm2 + > + matrix draw_fhog( + const dlib::array,mm2>& hog, + const long cell_draw_size = 15, + const float min_response_threshold = 0.0 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( cell_draw_size > 0 && hog.size()==31, + "\t matrix draw_fhog()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t cell_draw_size: " << cell_draw_size + << "\n\t hog.size(): " << hog.size() + ); + + dlib::array > mbars; + impl_fhog::create_fhog_bar_images(mbars,cell_draw_size); + + // now draw the bars onto the HOG cells + matrix himg(hog[0].nr()*cell_draw_size, hog[0].nc()*cell_draw_size); + himg = 0; + for (unsigned long d = 0; d < mbars.size(); ++d) + { + for (long r = 0; r < himg.nr(); r+=cell_draw_size) + { + for (long c = 0; c < himg.nc(); c+=cell_draw_size) + { + const float val = hog[d][r/cell_draw_size][c/cell_draw_size] + + hog[d+mbars.size()][r/cell_draw_size][c/cell_draw_size] + + hog[d+mbars.size()*2][r/cell_draw_size][c/cell_draw_size]; + if (val > min_response_threshold) + { + set_subm(himg, r, c, cell_draw_size, cell_draw_size) += val*mbars[d%mbars.size()]; + } + } + } + } + + const float thresh = mean(himg) + 4 * stddev(himg); + if (thresh != 0) + return matrix_cast(upperbound(round(himg*255/thresh),255)); + else + return matrix_cast(himg); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + matrix draw_fhog ( + const std::vector >& hog, + const long cell_draw_size = 15, + const float min_response_threshold = 0.0 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( cell_draw_size > 0 && hog.size()==31, + "\t matrix draw_fhog()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t cell_draw_size: " << cell_draw_size + << "\n\t hog.size(): " << hog.size() + ); + + // Just convert the input into the right object and then call the above draw_fhog() + // function on it. + dlib::array > temp(hog.size()); + for (unsigned long i = 0; i < temp.size(); ++i) + { + temp[i].set_size(hog[i].nr(), hog[i].nc()); + for (long r = 0; r < hog[i].nr(); ++r) + { + for (long c = 0; c < hog[i].nc(); ++c) + { + temp[i][r][c] = hog[i](r,c); + } + } + } + return draw_fhog(temp,cell_draw_size, min_response_threshold); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mm + > + matrix draw_fhog( + const array2d,mm>& hog, + const long cell_draw_size = 15, + const float min_response_threshold = 0.0 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( cell_draw_size > 0, + "\t matrix draw_fhog()" + << "\n\t Invalid inputs were given to this function. " + << "\n\t cell_draw_size: " << cell_draw_size + ); + + dlib::array > mbars; + impl_fhog::create_fhog_bar_images(mbars,cell_draw_size); + + // now draw the bars onto the HOG cells + matrix himg(hog.nr()*cell_draw_size, hog.nc()*cell_draw_size); + himg = 0; + for (unsigned long d = 0; d < mbars.size(); ++d) + { + for (long r = 0; r < himg.nr(); r+=cell_draw_size) + { + for (long c = 0; c < himg.nc(); c+=cell_draw_size) + { + const float val = hog[r/cell_draw_size][c/cell_draw_size](d) + + hog[r/cell_draw_size][c/cell_draw_size](d+mbars.size()) + + hog[r/cell_draw_size][c/cell_draw_size](d+mbars.size()*2); + if (val > min_response_threshold) + { + set_subm(himg, r, c, cell_draw_size, cell_draw_size) += val*mbars[d%mbars.size()]; + } + } + } + } + + const float thresh = mean(himg) + 4 * stddev(himg); + if (thresh != 0) + return matrix_cast(upperbound(round(himg*255/thresh),255)); + else + return matrix_cast(himg); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_fHOG_Hh_ + diff --git a/ml/dlib/dlib/image_transforms/fhog_abstract.h b/ml/dlib/dlib/image_transforms/fhog_abstract.h new file mode 100644 index 000000000..f66c5d55a --- /dev/null +++ b/ml/dlib/dlib/image_transforms/fhog_abstract.h @@ -0,0 +1,346 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_fHOG_ABSTRACT_Hh_ +#ifdef DLIB_fHOG_ABSTRACT_Hh_ + +#include "../matrix/matrix_abstract.h" +#include "../array2d/array2d_kernel_abstract.h" +#include "../array/array_kernel_abstract.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T, + typename mm + > + void extract_fhog_features( + const image_type& img, + array2d,mm>& hog, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ); + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - T should be float or double + ensures + - This function implements the HOG feature extraction method described in + the paper: + Object Detection with Discriminatively Trained Part Based Models by + P. Felzenszwalb, R. Girshick, D. McAllester, D. Ramanan + IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. 32, No. 9, Sep. 2010 + This means that it takes an input image img and outputs Felzenszwalb's + 31 dimensional version of HOG features, which are stored into #hog. + - The input image is broken into cells that are cell_size by cell_size pixels + and within each cell we compute a 31 dimensional FHOG vector. This vector + describes the gradient structure within the cell. + - A common task is to convolve each channel of the hog image with a linear + filter. This is made more convenient if the contents of #hog includes extra + rows and columns of zero padding along the borders. This extra padding + allows for more efficient convolution code since the code does not need to + perform expensive boundary checking. Therefore, you can set + filter_rows_padding and filter_cols_padding to indicate the size of the + filter you wish to use and this function will ensure #hog has the appropriate + extra zero padding along the borders. In particular, it will include the + following extra padding: + - (filter_rows_padding-1)/2 extra rows of zeros on the top of #hog. + - (filter_cols_padding-1)/2 extra columns of zeros on the left of #hog. + - filter_rows_padding/2 extra rows of zeros on the bottom of #hog. + - filter_cols_padding/2 extra columns of zeros on the right of #hog. + Therefore, the extra padding is done such that functions like + spatially_filter_image() apply their filters to the entire content containing + area of a hog image (note that you should use the following planar version of + extract_fhog_features() instead of the interlaced version if you want to use + spatially_filter_image() on a hog image). + - #hog.nr() == max(round(img.nr()/(double)cell_size)-2,0) + filter_rows_padding-1. + - #hog.nc() == max(round(img.nc()/(double)cell_size)-2,0) + filter_cols_padding-1. + (i.e. Each output dimension is roughly 1/cell_size the original size but + there is a one cell_size border all around the image that is lost and then we + add on any additional padding that is requested.) + - for all valid r and c: + - #hog[r][c] == the FHOG vector describing the cell centered at the pixel location + fhog_to_image(point(c,r),cell_size,filter_rows_padding,filter_cols_padding) in img. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T, + typename mm1, + typename mm2 + > + void extract_fhog_features( + const image_type& img, + dlib::array,mm2>& hog, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ); + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - T should be float or double + ensures + - This function is identical to the above extract_fhog_features() routine + except that it outputs the results in a planar format rather than the + interlaced format used above. That is, each element of the hog vector is + placed into one of 31 images inside #hog. To be precise, if vhog is the + output of the above interlaced version of extract_fhog_features() then we + will have, for all valid r and c: + - #hog[i][r][c] == vhog[r][c](i) + (where 0 <= i < 31) + - #hog.size() == 31 + - for all valid i: + - #hog[i].nr() == hog[0].nr() + - #hog[i].nc() == hog[0].nc() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + matrix extract_fhog_features( + const image_type& img, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ); + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - This function calls the above extract_fhog_features() routine and simply + packages the entire output into a dlib::matrix. The matrix is constructed + using the planar version of extract_fhog_features() and then each output + plane is converted into a column vector and subsequently all 31 column + vectors are concatenated together and returned. + - Each plane is converted into a column vector using reshape_to_column_vector(), + and is therefore represented in row major order inside the returned vector. + - If H is the array> object output by the planar + extract_fhog_features() then the returned vector is composed by concatenating + H[0], then H[1], then H[2], and so on in ascending index order. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T + > + void extract_fhog_features( + const image_type& img, + matrix& feats, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ); + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - T is float, double, or long double + ensures + - This function is identical to the above version of extract_fhog_features() + that returns a matrix except that it returns the matrix here + through a reference argument instead of returning it by value. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline point image_to_fhog ( + point p, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ); + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + ensures + - When using extract_fhog_features(), each FHOG cell is extracted from a + certain region in the input image. image_to_fhog() returns the identity of + the FHOG cell containing the image pixel at location p. Or in other words, + let P == image_to_fhog(p) and hog be a FHOG feature map output by + extract_fhog_features(), then hog[P.y()][P.x()] == the FHOG vector/cell + containing the point p in the input image. Note that some image points + might not have corresponding feature locations. E.g. border points or points + outside the image. In these cases the returned point will be outside the + input image. + - Note that you should use the same values of cell_size, filter_rows_padding, + and filter_cols_padding that you used with extract_fhog_features(). + !*/ + +// ---------------------------------------------------------------------------------------- + + inline rectangle image_to_fhog ( + const rectangle& rect, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ); + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + ensures + - maps a rectangle from image space to fhog space. In particular this function returns: + rectangle(image_to_fhog(rect.tl_corner(),cell_size,filter_rows_padding,filter_cols_padding), + image_to_fhog(rect.br_corner(),cell_size,filter_rows_padding,filter_cols_padding)) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline point fhog_to_image ( + point p, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ); + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + ensures + - Maps a pixel in a FHOG image (produced by extract_fhog_features()) back to the + corresponding original input pixel. Note that since FHOG images are + spatially downsampled by aggregation into cells the mapping is not totally + invertible. Therefore, the returned location will be the center of the cell + in the original image that contained the FHOG vector at position p. Moreover, + cell_size, filter_rows_padding, and filter_cols_padding should be set to the + values used by the call to extract_fhog_features(). + - Mapping from fhog space to image space is an invertible transformation. That + is, for any point P we have P == image_to_fhog(fhog_to_image(P,cell_size,filter_rows_padding,filter_cols_padding), + cell_size,filter_rows_padding,filter_cols_padding). + !*/ + +// ---------------------------------------------------------------------------------------- + + inline rectangle fhog_to_image ( + const rectangle& rect, + int cell_size = 8, + int filter_rows_padding = 1, + int filter_cols_padding = 1 + ); + /*! + requires + - cell_size > 0 + - filter_rows_padding > 0 + - filter_cols_padding > 0 + ensures + - maps a rectangle from fhog space to image space. In particular this function returns: + rectangle(fhog_to_image(rect.tl_corner(),cell_size,filter_rows_padding,filter_cols_padding), + fhog_to_image(rect.br_corner(),cell_size,filter_rows_padding,filter_cols_padding)) + - Mapping from fhog space to image space is an invertible transformation. That + is, for any rectangle R we have R == image_to_fhog(fhog_to_image(R,cell_size,filter_rows_padding,filter_cols_padding), + cell_size,filter_rows_padding,filter_cols_padding). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mm1, + typename mm2 + > + matrix draw_fhog( + const dlib::array,mm2>& hog, + const long cell_draw_size = 15, + const float min_response_threshold = 0.0 + ); + /*! + requires + - cell_draw_size > 0 + - hog.size() == 31 + ensures + - Interprets hog as a FHOG feature map output by extract_fhog_features() and + converts it into an image suitable for display on the screen. In particular, + we draw all the hog cells into a grayscale image in a way that shows the + magnitude and orientation of the gradient energy in each cell. The result is + then returned. + - The size of the cells in the output image will be rendered as cell_draw_size + pixels wide and tall. + - HOG cells with a response value less than min_response_threshold are not + drawn. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + matrix draw_fhog ( + const std::vector >& hog, + const long cell_draw_size = 15, + const float min_response_threshold = 0.0 + ); + /*! + requires + - cell_draw_size > 0 + - hog.size() == 31 + ensures + - This function just converts the given hog object into an array> + and passes it to the above draw_fhog() routine and returns the results. + - HOG cells with a response value less than min_response_threshold are not + drawn. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mm + > + matrix draw_fhog( + const array2d,mm>& hog, + const long cell_draw_size = 15, + const float min_response_threshold = 0.0 + ); + /*! + requires + - cell_draw_size > 0 + ensures + - Interprets hog as a FHOG feature map output by extract_fhog_features() and + converts it into an image suitable for display on the screen. In particular, + we draw all the hog cells into a grayscale image in a way that shows the + magnitude and orientation of the gradient energy in each cell. The result is + then returned. + - The size of the cells in the output image will be rendered as cell_draw_size + pixels wide and tall. + - HOG cells with a response value less than min_response_threshold are not + drawn. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_fHOG_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_transforms/hough_transform.h b/ml/dlib/dlib/image_transforms/hough_transform.h new file mode 100644 index 000000000..477b4dc2b --- /dev/null +++ b/ml/dlib/dlib/image_transforms/hough_transform.h @@ -0,0 +1,358 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_HOUGH_tRANSFORM_Hh_ +#define DLIB_HOUGH_tRANSFORM_Hh_ + +#include "hough_transform_abstract.h" +#include "../image_processing/generic_image.h" +#include "../geometry.h" +#include "../algs.h" +#include "assign_image.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class hough_transform + { + + public: + explicit hough_transform ( + unsigned long size_ + ) : _size(size_) + { + DLIB_CASSERT(size_ > 0, + "\t hough_transform::hough_transform(size_)" + << "\n\t Invalid arguments given to this function." + ); + + even_size = _size - (_size%2); + + const point cent = center(rectangle(0,0,size_-1,size_-1)); + xcos_theta.set_size(size_, size_); + ysin_theta.set_size(size_, size_); + + std::vector cos_theta(size_), sin_theta(size_); + const double scale = 1<<16; + for (unsigned long t = 0; t < size_; ++t) + { + double theta = t*pi/even_size; + + cos_theta[t] = scale*std::cos(theta)/sqrt_2; + sin_theta[t] = scale*std::sin(theta)/sqrt_2; + } + const double offset = scale*even_size/4.0 + 0.5; + + for (unsigned long c = 0; c < size_; ++c) + { + const long x = c - cent.x(); + for (unsigned long t = 0; t < size_; ++t) + xcos_theta(c,t) = static_cast(x*cos_theta[t] + offset); + } + for (unsigned long r = 0; r < size_; ++r) + { + const long y = r - cent.y(); + for (unsigned long t = 0; t < size_; ++t) + ysin_theta(r,t) = static_cast(y*sin_theta[t] + offset); + } + } + + unsigned long size( + ) const { return _size; } + + long nr( + ) const { return _size; } + + long nc( + ) const { return _size; } + + std::pair get_line ( + const point& p + ) const + { + DLIB_ASSERT(rectangle(0,0,size()-1,size()-1).contains(p) == true, + "\t pair hough_transform::get_line(point)" + << "\n\t Invalid arguments given to this function." + << "\n\t p: " << p + << "\n\t size(): " << size() + ); + + // First we compute the radius measured in pixels from the center and the theta + // angle in radians. + typedef dlib::vector vect; + const rectangle box(0,0,size()-1,size()-1); + const vect cent = center(box); + double theta = p.x()-cent.x(); + double radius = p.y()-cent.y(); + theta = theta*pi/even_size; + radius = radius*sqrt_2 + 0.5; + + // now make a line segment on the line. + vect v1 = cent + vect(size()+1000,0) + vect(0,radius); + vect v2 = cent - vect(size()+1000,0) + vect(0,radius); + point p1 = rotate_point(cent, v1, theta); + point p2 = rotate_point(cent, v2, theta); + + clip_line_to_rectangle(box, p1, p2); + + return std::make_pair(p1,p2); + } + + template < + typename image_type + > + point get_best_hough_point ( + const point& p, + const image_type& himg_ + ) + { + const const_image_view himg(himg_); + + DLIB_ASSERT(himg.nr() == size() && himg.nc() == size() && + rectangle(0,0,size()-1,size()-1).contains(p) == true, + "\t point hough_transform::get_best_hough_point()" + << "\n\t Invalid arguments given to this function." + << "\n\t himg.nr(): " << himg.nr() + << "\n\t himg.nc(): " << himg.nc() + << "\n\t size(): " << size() + << "\n\t p: " << p + ); + + + typedef typename image_traits::pixel_type pixel_type; + COMPILE_TIME_ASSERT(pixel_traits::grayscale == true); + pixel_type best_val = std::numeric_limits::min(); + point best_point; + + + const long max_n8 = (himg.nc()/8)*8; + const long max_n4 = (himg.nc()/4)*4; + const long r = p.y(); + const long c = p.x(); + + const int32* ysin = &ysin_theta(r,0); + const int32* xcos = &xcos_theta(c,0); + long t = 0; + while(t < max_n8) + { + long rr0 = (*xcos++ + *ysin++)>>16; + long rr1 = (*xcos++ + *ysin++)>>16; + long rr2 = (*xcos++ + *ysin++)>>16; + long rr3 = (*xcos++ + *ysin++)>>16; + long rr4 = (*xcos++ + *ysin++)>>16; + long rr5 = (*xcos++ + *ysin++)>>16; + long rr6 = (*xcos++ + *ysin++)>>16; + long rr7 = (*xcos++ + *ysin++)>>16; + + if (himg[rr0][t++] > best_val) + { + best_val = himg[rr0][t-1]; + best_point.x() = t-1; + best_point.y() = rr0; + } + if (himg[rr1][t++] > best_val) + { + best_val = himg[rr1][t-1]; + best_point.x() = t-1; + best_point.y() = rr1; + } + if (himg[rr2][t++] > best_val) + { + best_val = himg[rr2][t-1]; + best_point.x() = t-1; + best_point.y() = rr2; + } + if (himg[rr3][t++] > best_val) + { + best_val = himg[rr3][t-1]; + best_point.x() = t-1; + best_point.y() = rr3; + } + if (himg[rr4][t++] > best_val) + { + best_val = himg[rr4][t-1]; + best_point.x() = t-1; + best_point.y() = rr4; + } + if (himg[rr5][t++] > best_val) + { + best_val = himg[rr5][t-1]; + best_point.x() = t-1; + best_point.y() = rr5; + } + if (himg[rr6][t++] > best_val) + { + best_val = himg[rr6][t-1]; + best_point.x() = t-1; + best_point.y() = rr6; + } + if (himg[rr7][t++] > best_val) + { + best_val = himg[rr7][t-1]; + best_point.x() = t-1; + best_point.y() = rr7; + } + } + while(t < max_n4) + { + long rr0 = (*xcos++ + *ysin++)>>16; + long rr1 = (*xcos++ + *ysin++)>>16; + long rr2 = (*xcos++ + *ysin++)>>16; + long rr3 = (*xcos++ + *ysin++)>>16; + if (himg[rr0][t++] > best_val) + { + best_val = himg[rr0][t-1]; + best_point.x() = t-1; + best_point.y() = rr0; + } + if (himg[rr1][t++] > best_val) + { + best_val = himg[rr1][t-1]; + best_point.x() = t-1; + best_point.y() = rr1; + } + if (himg[rr2][t++] > best_val) + { + best_val = himg[rr2][t-1]; + best_point.x() = t-1; + best_point.y() = rr2; + } + if (himg[rr3][t++] > best_val) + { + best_val = himg[rr3][t-1]; + best_point.x() = t-1; + best_point.y() = rr3; + } + } + while(t < himg.nc()) + { + long rr0 = (*xcos++ + *ysin++)>>16; + if (himg[rr0][t++] > best_val) + { + best_val = himg[rr0][t-1]; + best_point.x() = t-1; + best_point.y() = rr0; + } + } + + return best_point; + } + + template < + typename in_image_type, + typename out_image_type + > + void operator() ( + const in_image_type& img_, + const rectangle& box, + out_image_type& himg_ + ) const + { + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + + DLIB_CASSERT(box.width() == size() && box.height() == size(), + "\t hough_transform::hough_transform(size_)" + << "\n\t Invalid arguments given to this function." + << "\n\t box.width(): " << box.width() + << "\n\t box.height(): " << box.height() + << "\n\t size(): " << size() + ); + + COMPILE_TIME_ASSERT(pixel_traits::grayscale == true); + COMPILE_TIME_ASSERT(pixel_traits::grayscale == true); + + const_image_view img(img_); + image_view himg(himg_); + + himg.set_size(size(), size()); + assign_all_pixels(himg, 0); + + const rectangle area = box.intersect(get_rect(img)); + + const long max_n8 = (himg.nc()/8)*8; + const long max_n4 = (himg.nc()/4)*4; + for (long r = area.top(); r <= area.bottom(); ++r) + { + const int32* ysin_base = &ysin_theta(r-box.top(),0); + for (long c = area.left(); c <= area.right(); ++c) + { + const out_pixel_type val = static_cast(img[r][c]); + if (val != 0) + { + /* + // The code in this comment is equivalent to the more complex but + // faster code below. We keep this simple version of the Hough + // transform implementation here just to document what it's doing + // more clearly. + const point cent = center(box); + const long x = c - cent.x(); + const long y = r - cent.y(); + for (long t = 0; t < himg.nc(); ++t) + { + double theta = t*pi/even_size; + double radius = (x*std::cos(theta) + y*std::sin(theta))/sqrt_2 + even_size/2 + 0.5; + long rr = static_cast(radius); + himg[rr][t] += val; + } + continue; + */ + + // Run the speed optimized version of the code in the above + // comment. + const int32* ysin = ysin_base; + const int32* xcos = &xcos_theta(c-box.left(),0); + long t = 0; + while(t < max_n8) + { + long rr0 = (*xcos++ + *ysin++)>>16; + long rr1 = (*xcos++ + *ysin++)>>16; + long rr2 = (*xcos++ + *ysin++)>>16; + long rr3 = (*xcos++ + *ysin++)>>16; + long rr4 = (*xcos++ + *ysin++)>>16; + long rr5 = (*xcos++ + *ysin++)>>16; + long rr6 = (*xcos++ + *ysin++)>>16; + long rr7 = (*xcos++ + *ysin++)>>16; + + himg[rr0][t++] += val; + himg[rr1][t++] += val; + himg[rr2][t++] += val; + himg[rr3][t++] += val; + himg[rr4][t++] += val; + himg[rr5][t++] += val; + himg[rr6][t++] += val; + himg[rr7][t++] += val; + } + while(t < max_n4) + { + long rr0 = (*xcos++ + *ysin++)>>16; + long rr1 = (*xcos++ + *ysin++)>>16; + long rr2 = (*xcos++ + *ysin++)>>16; + long rr3 = (*xcos++ + *ysin++)>>16; + himg[rr0][t++] += val; + himg[rr1][t++] += val; + himg[rr2][t++] += val; + himg[rr3][t++] += val; + } + while(t < himg.nc()) + { + long rr0 = (*xcos++ + *ysin++)>>16; + himg[rr0][t++] += val; + } + } + } + } + } + + private: + + unsigned long _size; + unsigned long even_size; // equal to _size if _size is even, otherwise equal to _size-1. + matrix xcos_theta, ysin_theta; + }; +} + +#endif // DLIB_HOUGH_tRANSFORM_Hh_ + diff --git a/ml/dlib/dlib/image_transforms/hough_transform_abstract.h b/ml/dlib/dlib/image_transforms/hough_transform_abstract.h new file mode 100644 index 000000000..f0ff2b550 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/hough_transform_abstract.h @@ -0,0 +1,145 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_HOUGH_tRANSFORM_ABSTRACT_Hh_ +#ifdef DLIB_HOUGH_tRANSFORM_ABSTRACT_Hh_ + +#include "../geometry.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class hough_transform + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for computing the line finding version of the Hough + transform given some kind of edge detection image as input. It also allows + the edge pixels to be weighted such that higher weighted edge pixels + contribute correspondingly more to the output of the Hough transform, + allowing stronger edges to create correspondingly stronger line detections + in the final Hough transform. + + THREAD SAFETY + It is safe for multiple threads to make concurrent accesses to this object + without synchronization. + !*/ + + public: + + explicit hough_transform ( + unsigned long size_ + ); + /*! + requires + - size_ > 0 + ensures + - This object will compute Hough transforms that are size_ by size_ pixels. + This is in terms of both the Hough accumulator array size as well as the + input image size. + - #size() == size_ + !*/ + + unsigned long size( + ) const; + /*! + ensures + - returns the size of the Hough transforms generated by this object. In + particular, this object creates Hough transform images that are size() by + size() pixels in size. + !*/ + + long nr( + ) const; + /*! + ensures + - returns size() + !*/ + + long nc( + ) const; + /*! + ensures + - returns size() + !*/ + + std::pair get_line ( + const point& p + ) const; + /*! + requires + - rectangle(0,0,size()-1,size()-1).contains(p) == true + (i.e. p must be a point inside the Hough accumulator array) + ensures + - returns the line segment in the original image space corresponding + to Hough transform point p. + - The returned points are inside rectangle(0,0,size()-1,size()-1). + !*/ + + template < + typename image_type + > + point get_best_hough_point ( + const point& p, + const image_type& himg + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain grayscale pixels. + - himg.nr() == size() + - himg.nc() == size() + - rectangle(0,0,size()-1,size()-1).contains(p) == true + ensures + - This function interprets himg as a Hough image and p as a point in the + original image space. Given this, it finds the maximum scoring line that + passes though p. That is, it checks all the Hough accumulator bins in + himg corresponding to lines though p and returns the location with the + largest score. + - returns a point X such that get_rect(himg).contains(X) == true + !*/ + + template < + typename in_image_type, + typename out_image_type + > + void operator() ( + const in_image_type& img, + const rectangle& box, + out_image_type& himg + ) const; + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain grayscale pixels. + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain grayscale pixels. + - box.width() == size() + - box.height() == size() + ensures + - Computes the Hough transform of the part of img contained within box. + In particular, we do a grayscale version of the Hough transform where any + non-zero pixel in img is treated as a potential component of a line and + accumulated into the Hough accumulator #himg. However, rather than + adding 1 to each relevant accumulator bin we add the value of the pixel + in img to each Hough accumulator bin. This means that, if all the + pixels in img are 0 or 1 then this routine performs a normal Hough + transform. However, if some pixels have larger values then they will be + weighted correspondingly more in the resulting Hough transform. + - #himg.nr() == size() + - #himg.nc() == size() + - #himg is the Hough transform of the part of img contained in box. Each + point in #himg corresponds to a line in the input box. In particular, + the line for #himg[y][x] is given by get_line(point(x,y)). Also, when + viewing the #himg image, the x-axis gives the angle of the line and the + y-axis the distance of the line from the center of the box. + !*/ + + }; +} + +#endif // DLIB_HOUGH_tRANSFORM_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_transforms/image_pyramid.h b/ml/dlib/dlib/image_transforms/image_pyramid.h new file mode 100644 index 000000000..3efed30d8 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/image_pyramid.h @@ -0,0 +1,1238 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMAGE_PYRaMID_Hh_ +#define DLIB_IMAGE_PYRaMID_Hh_ + +#include "image_pyramid_abstract.h" +#include "../pixel.h" +#include "../array2d.h" +#include "../geometry.h" +#include "spatial_filtering.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class pyramid_disable : noncopyable + { + public: + + template + vector point_down ( + const vector& + ) const + { + return vector(0,0); + } + + template + vector point_up ( + const vector& + ) const + { + return vector(0,0); + } + + // ----------------------------- + + template + vector point_down ( + const vector& p, + unsigned int levels + ) const + { + if (levels == 0) + return p; + else + return vector(0,0); + } + + template + vector point_up ( + const vector& p, + unsigned int levels + ) const + { + if (levels == 0) + return p; + else + return vector(0,0); + } + + // ----------------------------- + + drectangle rect_up ( + const drectangle& rect + ) const + { + return drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner())); + } + + drectangle rect_up ( + const drectangle& rect, + unsigned int levels + ) const + { + return drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels)); + } + + // ----------------------------- + + drectangle rect_down ( + const drectangle& rect + ) const + { + return drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner())); + } + + drectangle rect_down ( + const drectangle& rect, + unsigned int levels + ) const + { + return drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels)); + } + + // ----------------------------- + + public: + + template < + typename in_image_type, + typename out_image_type + > + void operator() ( + // we do this #ifdef stuff to avoid compiler warnings about unused variables. +#ifdef ENABLE_ASSERTS + const in_image_type& original, +#else + const in_image_type& , +#endif + out_image_type& down + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_same_object(original, down) == false, + "\t void pyramid_disable::operator()" + << "\n\t is_same_object(original, down): " << is_same_object(original, down) + << "\n\t this: " << this + ); + + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + set_image_size(down, 0, 0); + } + + template < + typename image_type + > + void operator() ( + image_type& img + ) const + { + typedef typename image_traits::pixel_type pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + set_image_size(img, 0, 0); + } + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + class pyramid_down_2_1 : noncopyable + { + public: + + template + vector point_down ( + const vector& p + ) const + { + return p/2.0 - vector(1.25,0.75); + } + + template + vector point_up ( + const vector& p + ) const + { + return (p + vector(1.25,0.75))*2; + } + + // ----------------------------- + + template + vector point_down ( + const vector& p, + unsigned int levels + ) const + { + vector temp = p; + for (unsigned int i = 0; i < levels; ++i) + temp = point_down(temp); + return temp; + } + + template + vector point_up ( + const vector& p, + unsigned int levels + ) const + { + vector temp = p; + for (unsigned int i = 0; i < levels; ++i) + temp = point_up(temp); + return temp; + } + + // ----------------------------- + + drectangle rect_up ( + const drectangle& rect + ) const + { + return drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner())); + } + + drectangle rect_up ( + const drectangle& rect, + unsigned int levels + ) const + { + return drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels)); + } + + // ----------------------------- + + drectangle rect_down ( + const drectangle& rect + ) const + { + return drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner())); + } + + drectangle rect_down ( + const drectangle& rect, + unsigned int levels + ) const + { + return drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels)); + } + + // ----------------------------- + + private: + template + struct both_images_rgb + { + typedef typename image_traits::pixel_type T_pix; + typedef typename image_traits::pixel_type U_pix; + const static bool value = pixel_traits::rgb && pixel_traits::rgb; + }; + public: + + template < + typename in_image_type, + typename out_image_type + > + typename disable_if >::type operator() ( + const in_image_type& original_, + out_image_type& down_ + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(original_, down_) == false, + "\t void pyramid_down_2_1::operator()" + << "\n\t is_same_object(original_, down_): " << is_same_object(original_, down_) + << "\n\t this: " << this + ); + + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + const_image_view original(original_); + image_view down(down_); + + if (original.nr() <= 8 || original.nc() <= 8) + { + down.clear(); + return; + } + + typedef typename pixel_traits::basic_pixel_type bp_type; + typedef typename promote::type ptype; + array2d temp_img; + temp_img.set_size(original.nr(), (original.nc()-3)/2); + down.set_size((original.nr()-3)/2, (original.nc()-3)/2); + + + // This function applies a 5x5 Gaussian filter to the image. It + // does this by separating the filter into its horizontal and vertical + // components and then downsamples the image by dropping every other + // row and column. Note that we can do these things all together in + // one step. + + // apply row filter + for (long r = 0; r < temp_img.nr(); ++r) + { + long oc = 0; + for (long c = 0; c < temp_img.nc(); ++c) + { + ptype pix1; + ptype pix2; + ptype pix3; + ptype pix4; + ptype pix5; + + assign_pixel(pix1, original[r][oc]); + assign_pixel(pix2, original[r][oc+1]); + assign_pixel(pix3, original[r][oc+2]); + assign_pixel(pix4, original[r][oc+3]); + assign_pixel(pix5, original[r][oc+4]); + + pix2 *= 4; + pix3 *= 6; + pix4 *= 4; + + assign_pixel(temp_img[r][c], pix1 + pix2 + pix3 + pix4 + pix5); + oc += 2; + } + } + + + // apply column filter + long dr = 0; + for (long r = 2; r < temp_img.nr()-2; r += 2) + { + for (long c = 0; c < temp_img.nc(); ++c) + { + ptype temp = temp_img[r-2][c] + + temp_img[r-1][c]*4 + + temp_img[r ][c]*6 + + temp_img[r+1][c]*4 + + temp_img[r+2][c]; + + assign_pixel(down[dr][c],temp/256); + } + ++dr; + } + + } + + private: + struct rgbptype + { + uint16 red; + uint16 green; + uint16 blue; + }; + public: + // ------------------------------------------ + // OVERLOAD FOR RGB TO RGB IMAGES + // ------------------------------------------ + template < + typename in_image_type, + typename out_image_type + > + typename enable_if >::type operator() ( + const in_image_type& original_, + out_image_type& down_ + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(original_, down_) == false, + "\t void pyramid_down_2_1::operator()" + << "\n\t is_same_object(original_, down_): " << is_same_object(original_, down_) + << "\n\t this: " << this + ); + + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + const_image_view original(original_); + image_view down(down_); + + if (original.nr() <= 8 || original.nc() <= 8) + { + down.clear(); + return; + } + + array2d temp_img; + temp_img.set_size(original.nr(), (original.nc()-3)/2); + down.set_size((original.nr()-3)/2, (original.nc()-3)/2); + + + // This function applies a 5x5 Gaussian filter to the image. It + // does this by separating the filter into its horizontal and vertical + // components and then downsamples the image by dropping every other + // row and column. Note that we can do these things all together in + // one step. + + // apply row filter + for (long r = 0; r < temp_img.nr(); ++r) + { + long oc = 0; + for (long c = 0; c < temp_img.nc(); ++c) + { + rgbptype pix1; + rgbptype pix2; + rgbptype pix3; + rgbptype pix4; + rgbptype pix5; + + pix1.red = original[r][oc].red; + pix2.red = original[r][oc+1].red; + pix3.red = original[r][oc+2].red; + pix4.red = original[r][oc+3].red; + pix5.red = original[r][oc+4].red; + pix1.green = original[r][oc].green; + pix2.green = original[r][oc+1].green; + pix3.green = original[r][oc+2].green; + pix4.green = original[r][oc+3].green; + pix5.green = original[r][oc+4].green; + pix1.blue = original[r][oc].blue; + pix2.blue = original[r][oc+1].blue; + pix3.blue = original[r][oc+2].blue; + pix4.blue = original[r][oc+3].blue; + pix5.blue = original[r][oc+4].blue; + + pix2.red *= 4; + pix3.red *= 6; + pix4.red *= 4; + + pix2.green *= 4; + pix3.green *= 6; + pix4.green *= 4; + + pix2.blue *= 4; + pix3.blue *= 6; + pix4.blue *= 4; + + rgbptype temp; + temp.red = pix1.red + pix2.red + pix3.red + pix4.red + pix5.red; + temp.green = pix1.green + pix2.green + pix3.green + pix4.green + pix5.green; + temp.blue = pix1.blue + pix2.blue + pix3.blue + pix4.blue + pix5.blue; + + temp_img[r][c] = temp; + + oc += 2; + } + } + + + // apply column filter + long dr = 0; + for (long r = 2; r < temp_img.nr()-2; r += 2) + { + for (long c = 0; c < temp_img.nc(); ++c) + { + rgbptype temp; + temp.red = temp_img[r-2][c].red + + temp_img[r-1][c].red*4 + + temp_img[r ][c].red*6 + + temp_img[r+1][c].red*4 + + temp_img[r+2][c].red; + temp.green = temp_img[r-2][c].green + + temp_img[r-1][c].green*4 + + temp_img[r ][c].green*6 + + temp_img[r+1][c].green*4 + + temp_img[r+2][c].green; + temp.blue = temp_img[r-2][c].blue + + temp_img[r-1][c].blue*4 + + temp_img[r ][c].blue*6 + + temp_img[r+1][c].blue*4 + + temp_img[r+2][c].blue; + + down[dr][c].red = temp.red/256; + down[dr][c].green = temp.green/256; + down[dr][c].blue = temp.blue/256; + } + ++dr; + } + + } + + template < + typename image_type + > + void operator() ( + image_type& img + ) const + { + image_type temp; + (*this)(img, temp); + swap(temp, img); + } + + private: + + + }; + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + class pyramid_down_3_2 : noncopyable + { + public: + + template + vector point_down ( + const vector& p + ) const + { + const double ratio = 2.0/3.0; + return p*ratio - vector(1,1); + } + + template + vector point_up ( + const vector& p + ) const + { + const double ratio = 3.0/2.0; + return p*ratio + vector(ratio,ratio); + } + + // ----------------------------- + + template + vector point_down ( + const vector& p, + unsigned int levels + ) const + { + vector temp = p; + for (unsigned int i = 0; i < levels; ++i) + temp = point_down(temp); + return temp; + } + + template + vector point_up ( + const vector& p, + unsigned int levels + ) const + { + vector temp = p; + for (unsigned int i = 0; i < levels; ++i) + temp = point_up(temp); + return temp; + } + + // ----------------------------- + + drectangle rect_up ( + const drectangle& rect + ) const + { + return drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner())); + } + + drectangle rect_up ( + const drectangle& rect, + unsigned int levels + ) const + { + return drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels)); + } + + // ----------------------------- + + drectangle rect_down ( + const drectangle& rect + ) const + { + return drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner())); + } + + drectangle rect_down ( + const drectangle& rect, + unsigned int levels + ) const + { + return drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels)); + } + + // ----------------------------- + + private: + template + struct both_images_rgb + { + typedef typename image_traits::pixel_type T_pix; + typedef typename image_traits::pixel_type U_pix; + const static bool value = pixel_traits::rgb && pixel_traits::rgb; + }; + public: + + template < + typename in_image_type, + typename out_image_type + > + typename disable_if >::type operator() ( + const in_image_type& original_, + out_image_type& down_ + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_same_object(original_, down_) == false, + "\t void pyramid_down_3_2::operator()" + << "\n\t is_same_object(original_, down_): " << is_same_object(original_, down_) + << "\n\t this: " << this + ); + + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + const_image_view original(original_); + image_view down(down_); + + if (original.nr() <= 8 || original.nc() <= 8) + { + down.clear(); + return; + } + + const long size_in = 3; + const long size_out = 2; + + typedef typename pixel_traits::basic_pixel_type bp_type; + typedef typename promote::type ptype; + const long full_nr = size_out*((original.nr()-2)/size_in); + const long part_nr = (size_out*(original.nr()-2))/size_in; + const long full_nc = size_out*((original.nc()-2)/size_in); + const long part_nc = (size_out*(original.nc()-2))/size_in; + down.set_size(part_nr, part_nc); + + + long rr = 1; + long r; + for (r = 0; r < full_nr; r+=size_out) + { + long cc = 1; + long c; + for (c = 0; c < full_nc; c+=size_out) + { + ptype block[size_in][size_in]; + separable_3x3_filter_block_grayscale(block, original_, rr, cc, 2, 12, 2); + + // bi-linearly interpolate block + assign_pixel(down[r][c] , (block[0][0]*9 + block[1][0]*3 + block[0][1]*3 + block[1][1])/(16*256)); + assign_pixel(down[r][c+1] , (block[0][2]*9 + block[1][2]*3 + block[0][1]*3 + block[1][1])/(16*256)); + assign_pixel(down[r+1][c] , (block[2][0]*9 + block[1][0]*3 + block[2][1]*3 + block[1][1])/(16*256)); + assign_pixel(down[r+1][c+1] , (block[2][2]*9 + block[1][2]*3 + block[2][1]*3 + block[1][1])/(16*256)); + + cc += size_in; + } + if (part_nc - full_nc == 1) + { + ptype block[size_in][2]; + separable_3x3_filter_block_grayscale(block, original_, rr, cc, 2, 12, 2); + + // bi-linearly interpolate partial block + assign_pixel(down[r][c] , (block[0][0]*9 + block[1][0]*3 + block[0][1]*3 + block[1][1])/(16*256)); + assign_pixel(down[r+1][c] , (block[2][0]*9 + block[1][0]*3 + block[2][1]*3 + block[1][1])/(16*256)); + } + rr += size_in; + } + if (part_nr - full_nr == 1) + { + long cc = 1; + long c; + for (c = 0; c < full_nc; c+=size_out) + { + ptype block[2][size_in]; + separable_3x3_filter_block_grayscale(block, original_, rr, cc, 2, 12, 2); + + // bi-linearly interpolate partial block + assign_pixel(down[r][c] , (block[0][0]*9 + block[1][0]*3 + block[0][1]*3 + block[1][1])/(16*256)); + assign_pixel(down[r][c+1] , (block[0][2]*9 + block[1][2]*3 + block[0][1]*3 + block[1][1])/(16*256)); + + cc += size_in; + } + if (part_nc - full_nc == 1) + { + ptype block[2][2]; + separable_3x3_filter_block_grayscale(block, original_, rr, cc, 2, 12, 2); + + // bi-linearly interpolate partial block + assign_pixel(down[r][c] , (block[0][0]*9 + block[1][0]*3 + block[0][1]*3 + block[1][1])/(16*256)); + } + } + + } + + private: + struct rgbptype + { + uint32 red; + uint32 green; + uint32 blue; + }; + + public: + // ------------------------------------------ + // OVERLOAD FOR RGB TO RGB IMAGES + // ------------------------------------------ + template < + typename in_image_type, + typename out_image_type + > + typename enable_if >::type operator() ( + const in_image_type& original_, + out_image_type& down_ + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(original_, down_) == false, + "\t void pyramid_down_3_2::operator()" + << "\n\t is_same_object(original_, down_): " << is_same_object(original_, down_) + << "\n\t this: " << this + ); + + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + const_image_view original(original_); + image_view down(down_); + + if (original.nr() <= 8 || original.nc() <= 8) + { + down.clear(); + return; + } + + const long size_in = 3; + const long size_out = 2; + + const long full_nr = size_out*((original.nr()-2)/size_in); + const long part_nr = (size_out*(original.nr()-2))/size_in; + const long full_nc = size_out*((original.nc()-2)/size_in); + const long part_nc = (size_out*(original.nc()-2))/size_in; + down.set_size(part_nr, part_nc); + + + long rr = 1; + long r; + for (r = 0; r < full_nr; r+=size_out) + { + long cc = 1; + long c; + for (c = 0; c < full_nc; c+=size_out) + { + rgbptype block[size_in][size_in]; + separable_3x3_filter_block_rgb(block, original_, rr, cc, 2, 12, 2); + + // bi-linearly interpolate block + down[r][c].red = (block[0][0].red*9 + block[1][0].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256); + down[r][c].green = (block[0][0].green*9 + block[1][0].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256); + down[r][c].blue = (block[0][0].blue*9 + block[1][0].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256); + + down[r][c+1].red = (block[0][2].red*9 + block[1][2].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256); + down[r][c+1].green = (block[0][2].green*9 + block[1][2].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256); + down[r][c+1].blue = (block[0][2].blue*9 + block[1][2].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256); + + down[r+1][c].red = (block[2][0].red*9 + block[1][0].red*3 + block[2][1].red*3 + block[1][1].red)/(16*256); + down[r+1][c].green = (block[2][0].green*9 + block[1][0].green*3 + block[2][1].green*3 + block[1][1].green)/(16*256); + down[r+1][c].blue = (block[2][0].blue*9 + block[1][0].blue*3 + block[2][1].blue*3 + block[1][1].blue)/(16*256); + + down[r+1][c+1].red = (block[2][2].red*9 + block[1][2].red*3 + block[2][1].red*3 + block[1][1].red)/(16*256); + down[r+1][c+1].green = (block[2][2].green*9 + block[1][2].green*3 + block[2][1].green*3 + block[1][1].green)/(16*256); + down[r+1][c+1].blue = (block[2][2].blue*9 + block[1][2].blue*3 + block[2][1].blue*3 + block[1][1].blue)/(16*256); + + cc += size_in; + } + if (part_nc - full_nc == 1) + { + rgbptype block[size_in][2]; + separable_3x3_filter_block_rgb(block, original_, rr, cc, 2, 12, 2); + + // bi-linearly interpolate partial block + down[r][c].red = (block[0][0].red*9 + block[1][0].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256); + down[r][c].green = (block[0][0].green*9 + block[1][0].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256); + down[r][c].blue = (block[0][0].blue*9 + block[1][0].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256); + + down[r+1][c].red = (block[2][0].red*9 + block[1][0].red*3 + block[2][1].red*3 + block[1][1].red)/(16*256); + down[r+1][c].green = (block[2][0].green*9 + block[1][0].green*3 + block[2][1].green*3 + block[1][1].green)/(16*256); + down[r+1][c].blue = (block[2][0].blue*9 + block[1][0].blue*3 + block[2][1].blue*3 + block[1][1].blue)/(16*256); + } + rr += size_in; + } + if (part_nr - full_nr == 1) + { + long cc = 1; + long c; + for (c = 0; c < full_nc; c+=size_out) + { + rgbptype block[2][size_in]; + separable_3x3_filter_block_rgb(block, original_, rr, cc, 2, 12, 2); + + // bi-linearly interpolate partial block + down[r][c].red = (block[0][0].red*9 + block[1][0].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256); + down[r][c].green = (block[0][0].green*9 + block[1][0].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256); + down[r][c].blue = (block[0][0].blue*9 + block[1][0].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256); + + down[r][c+1].red = (block[0][2].red*9 + block[1][2].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256); + down[r][c+1].green = (block[0][2].green*9 + block[1][2].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256); + down[r][c+1].blue = (block[0][2].blue*9 + block[1][2].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256); + + cc += size_in; + } + if (part_nc - full_nc == 1) + { + rgbptype block[2][2]; + separable_3x3_filter_block_rgb(block, original_, rr, cc, 2, 12, 2); + + // bi-linearly interpolate partial block + down[r][c].red = (block[0][0].red*9 + block[1][0].red*3 + block[0][1].red*3 + block[1][1].red)/(16*256); + down[r][c].green = (block[0][0].green*9 + block[1][0].green*3 + block[0][1].green*3 + block[1][1].green)/(16*256); + down[r][c].blue = (block[0][0].blue*9 + block[1][0].blue*3 + block[0][1].blue*3 + block[1][1].blue)/(16*256); + } + } + } + + template < + typename image_type + > + void operator() ( + image_type& img + ) const + { + image_type temp; + (*this)(img, temp); + swap(temp, img); + } + private: + + + }; + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned int N + > + class pyramid_down : noncopyable + { + public: + + COMPILE_TIME_ASSERT(N > 0); + + template + vector point_down ( + const vector& p + ) const + { + const double ratio = (N-1.0)/N; + return (p - 0.3)*ratio; + } + + template + vector point_up ( + const vector& p + ) const + { + const double ratio = N/(N-1.0); + return p*ratio + 0.3; + } + + // ----------------------------- + + template + vector point_down ( + const vector& p, + unsigned int levels + ) const + { + vector temp = p; + for (unsigned int i = 0; i < levels; ++i) + temp = point_down(temp); + return temp; + } + + template + vector point_up ( + const vector& p, + unsigned int levels + ) const + { + vector temp = p; + for (unsigned int i = 0; i < levels; ++i) + temp = point_up(temp); + return temp; + } + + // ----------------------------- + + drectangle rect_up ( + const drectangle& rect + ) const + { + return drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner())); + } + + drectangle rect_up ( + const drectangle& rect, + unsigned int levels + ) const + { + return drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels)); + } + + // ----------------------------- + + drectangle rect_down ( + const drectangle& rect + ) const + { + return drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner())); + } + + drectangle rect_down ( + const drectangle& rect, + unsigned int levels + ) const + { + return drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels)); + } + + template < + typename in_image_type, + typename out_image_type + > + void operator() ( + const in_image_type& original, + out_image_type& down + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_same_object(original, down) == false, + "\t void pyramid_down::operator()" + << "\n\t is_same_object(original, down): " << is_same_object(original, down) + << "\n\t this: " << this + ); + + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + + set_image_size(down, ((N-1)*num_rows(original))/N+0.5, ((N-1)*num_columns(original))/N+0.5); + resize_image(original, down); + } + + template < + typename image_type + > + void operator() ( + image_type& img + ) const + { + image_type temp; + (*this)(img, temp); + swap(temp, img); + } + }; + + template <> + class pyramid_down<1> : public pyramid_disable {}; + + template <> + class pyramid_down<2> : public dlib::impl::pyramid_down_2_1 {}; + + template <> + class pyramid_down<3> : public dlib::impl::pyramid_down_3_2 {}; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + double pyramid_rate(const pyramid_down&) + { + return (N-1.0)/N; + } + +// ---------------------------------------------------------------------------------------- + + template + void find_pyramid_down_output_image_size( + const pyramid_down& pyr, + long& nr, + long& nc + ) + { + const double rate = pyramid_rate(pyr); + nr = std::floor(rate*nr); + nc = std::floor(rate*nc); + } + + inline void find_pyramid_down_output_image_size( + const pyramid_down<3>& /*pyr*/, + long& nr, + long& nc + ) + { + nr = 2*(nr-2)/3; + nc = 2*(nc-2)/3; + } + + inline void find_pyramid_down_output_image_size( + const pyramid_down<2>& /*pyr*/, + long& nr, + long& nc + ) + { + nr = (nr-3)/2; + nc = (nc-3)/2; + } + + inline void find_pyramid_down_output_image_size( + const pyramid_down<1>& /*pyr*/, + long& nr, + long& nc + ) + { + nr = 0; + nc = 0; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + void compute_tiled_image_pyramid_details ( + const pyramid_type& pyr, + long nr, + long nc, + const unsigned long padding, + const unsigned long outer_padding, + std::vector& rects, + long& pyramid_image_nr, + long& pyramid_image_nc + ) + { + rects.clear(); + if (nr*nc == 0) + { + pyramid_image_nr = 0; + pyramid_image_nc = 0; + return; + } + + const long min_height = 5; + rects.reserve(100); + rects.push_back(rectangle(nc,nr)); + // build the whole pyramid + while(true) + { + find_pyramid_down_output_image_size(pyr, nr, nc); + if (nr*nc == 0 || nr < min_height) + break; + rects.push_back(rectangle(nc,nr)); + } + + // figure out output image size + long total_height = 0; + for (auto&& i : rects) + total_height += i.height()+padding; + total_height -= padding*2; // don't add unnecessary padding to the very right side. + long height = 0; + long prev_width = 0; + for (auto&& i : rects) + { + // Figure out how far we go on the first column. We go until the next image can + // fit next to the previous one, which means we can double back for the second + // column of images. + if (i.width() <= rects[0].width()-prev_width-(long)padding && + (height-rects[0].height())*2 >= (total_height-rects[0].height())) + { + break; + } + height += i.height() + padding; + prev_width = i.width(); + } + height -= padding; // don't add unnecessary padding to the very right side. + + const long width = rects[0].width(); + pyramid_image_nr = height+outer_padding*2; + pyramid_image_nc = width+outer_padding*2; + + + long y = outer_padding; + size_t i = 0; + while(y < height+(long)outer_padding && i < rects.size()) + { + rects[i] = translate_rect(rects[i],point(outer_padding,y)); + DLIB_ASSERT(rectangle(pyramid_image_nc,pyramid_image_nr).contains(rects[i])); + y += rects[i].height()+padding; + ++i; + } + y -= padding; + while (i < rects.size()) + { + point p1(outer_padding+width-1,y-1); + point p2 = p1 - rects[i].br_corner(); + rectangle rect(p1,p2); + DLIB_ASSERT(rectangle(pyramid_image_nc,pyramid_image_nr).contains(rect)); + // don't keep going on the last row if it would intersect the original image. + if (!rects[0].intersect(rect).is_empty()) + break; + + rects[i] = rect; + y -= rects[i].height()+padding; + ++i; + } + + // Delete any extraneous rectangles if we broke out of the above loop early due to + // intersection with the original image. + rects.resize(i); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type, + typename image_type1, + typename image_type2 + > + void create_tiled_pyramid ( + const image_type1& img, + image_type2& out_img, + std::vector& rects, + const unsigned long padding = 10, + const unsigned long outer_padding = 0 + ) + { + DLIB_ASSERT(!is_same_object(img, out_img)); + + long out_nr, out_nc; + pyramid_type pyr; + impl::compute_tiled_image_pyramid_details(pyr, img.nr(), img.nc(), padding, outer_padding, rects, out_nr, out_nc); + + set_image_size(out_img, out_nr, out_nc); + assign_all_pixels(out_img, 0); + + if (rects.size() == 0) + return; + + // now build the image pyramid into out_img + auto si = sub_image(out_img, rects[0]); + assign_image(si, img); + for (size_t i = 1; i < rects.size(); ++i) + { + auto s1 = sub_image(out_img, rects[i-1]); + auto s2 = sub_image(out_img, rects[i]); + pyr(s1,s2); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type + > + dpoint image_to_tiled_pyramid ( + const std::vector& rects, + double scale, + dpoint p + ) + { + DLIB_CASSERT(rects.size() > 0); + DLIB_CASSERT(0 < scale && scale <= 1); + pyramid_type pyr; + // This scale factor maps this many levels down the pyramid + long pyramid_down_iter = static_cast(std::log(scale)/std::log(pyramid_rate(pyr))+0.5); + pyramid_down_iter = put_in_range(0, (long)rects.size()-1, pyramid_down_iter); + + return rects[pyramid_down_iter].tl_corner() + pyr.point_down(p, pyramid_down_iter); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type + > + drectangle image_to_tiled_pyramid ( + const std::vector& rects, + double scale, + drectangle r + ) + { + DLIB_ASSERT(rects.size() > 0); + DLIB_ASSERT(0 < scale && scale <= 1); + return drectangle(image_to_tiled_pyramid(rects, scale, r.tl_corner()), + image_to_tiled_pyramid(rects, scale, r.br_corner())); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type + > + dpoint tiled_pyramid_to_image ( + const std::vector& rects, + dpoint p + ) + { + DLIB_CASSERT(rects.size() > 0); + + size_t pyramid_down_iter = nearest_rect(rects, p); + + p -= rects[pyramid_down_iter].tl_corner(); + pyramid_type pyr; + return pyr.point_up(p, pyramid_down_iter); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type + > + drectangle tiled_pyramid_to_image ( + const std::vector& rects, + drectangle r + ) + { + DLIB_CASSERT(rects.size() > 0); + + size_t pyramid_down_iter = nearest_rect(rects, dcenter(r)); + + dpoint origin = rects[pyramid_down_iter].tl_corner(); + r = drectangle(r.tl_corner()-origin, r.br_corner()-origin); + pyramid_type pyr; + return pyr.rect_up(r, pyramid_down_iter); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IMAGE_PYRaMID_Hh_ + diff --git a/ml/dlib/dlib/image_transforms/image_pyramid_abstract.h b/ml/dlib/dlib/image_transforms/image_pyramid_abstract.h new file mode 100644 index 000000000..a61b275fd --- /dev/null +++ b/ml/dlib/dlib/image_transforms/image_pyramid_abstract.h @@ -0,0 +1,384 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_IMAGE_PYRaMID_ABSTRACT_Hh_ +#ifdef DLIB_IMAGE_PYRaMID_ABSTRACT_Hh_ + +#include "../pixel.h" +#include "../array2d.h" +#include "../geometry.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + + template < + unsigned int N + > + class pyramid_down : noncopyable + { + /*! + REQUIREMENTS ON N + N > 0 + + WHAT THIS OBJECT REPRESENTS + This is a simple functor to help create image pyramids. In particular, it + downsamples images at a ratio of N to N-1. + + Note that setting N to 1 means that this object functions like + pyramid_disable (defined at the bottom of this file). + + WARNING, when mapping rectangles from one layer of a pyramid + to another you might end up with rectangles which extend slightly + outside your images. This is because points on the border of an + image at a higher pyramid layer might correspond to points outside + images at lower layers. So just keep this in mind. Note also + that it's easy to deal with. Just say something like this: + rect = rect.intersect(get_rect(my_image)); // keep rect inside my_image + !*/ + public: + + template < + typename in_image_type, + typename out_image_type + > + void operator() ( + const in_image_type& original, + out_image_type& down + ) const; + /*! + requires + - is_same_object(original, down) == false + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - for both pixel types P in the input and output images, we require: + - pixel_traits

    ::has_alpha == false + ensures + - #down will contain an image that is roughly (N-1)/N times the size of the + original image. + - If both input and output images contain RGB pixels then the downsampled image will + be in color. Otherwise, the downsampling will be performed in a grayscale mode. + - The location of a point P in original image will show up at point point_down(P) + in the #down image. + - Note that some points on the border of the original image might correspond to + points outside the #down image. + !*/ + + template < + typename image_type + > + void operator() ( + image_type& img + ) const; + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits::pixel_type>::has_alpha == false + ensures + - This function downsamples the given image and stores the results in #img. + In particular, it is equivalent to performing: + (*this)(img, temp); + swap(img, temp); + !*/ + + // ------------------------------- + + template + vector point_down ( + const vector& p + ) const; + /*! + ensures + - interprets p as a point in a parent image and returns the + point in a downsampled image which corresponds to p. + - This function is the inverse of point_up(). I.e. for a point P: + point_down(point_up(P)) == P + !*/ + + template + vector point_up ( + const vector& p + ) const; + /*! + ensures + - interprets p as a point in a downsampled image and returns the + point in a parent image which corresponds to p. + - This function is the inverse of point_down(). I.e. for a point P: + point_up(point_down(P)) == P + !*/ + + drectangle rect_down ( + const drectangle& rect + ) const; + /*! + ensures + - returns drectangle(point_down(rect.tl_corner()), point_down(rect.br_corner())); + (i.e. maps rect into a downsampled) + !*/ + + drectangle rect_up ( + const drectangle& rect + ) const; + /*! + ensures + - returns drectangle(point_up(rect.tl_corner()), point_up(rect.br_corner())); + (i.e. maps rect into a parent image) + !*/ + + // ------------------------------- + + template + vector point_down ( + const vector& p, + unsigned int levels + ) const; + /*! + ensures + - applies point_down() to p levels times and returns the result. + (i.e. point_down(p,2) == point_down(point_down(p)), + point_down(p,1) == point_down(p), + point_down(p,0) == p, etc. ) + !*/ + + template + vector point_up ( + const vector& p, + unsigned int levels + ) const; + /*! + ensures + - applies point_up() to p levels times and returns the result. + (i.e. point_up(p,2) == point_up(point_up(p)), + point_up(p,1) == point_up(p), + point_up(p,0) == p, etc. ) + !*/ + + drectangle rect_down ( + const drectangle& rect, + unsigned int levels + ) const; + /*! + ensures + - returns drectangle(point_down(rect.tl_corner(),levels), point_down(rect.br_corner(),levels)); + (i.e. Basically applies rect_down() to rect levels times and returns the result.) + !*/ + + drectangle rect_up ( + const drectangle& rect, + unsigned int levels + ) const; + /*! + ensures + - returns drectangle(point_up(rect.tl_corner(),levels), point_up(rect.br_corner(),levels)); + (i.e. Basically applies rect_up() to rect levels times and returns the result.) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + class pyramid_disable : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a function object with an interface identical to pyramid_down (defined + at the top of this file) except that it downsamples images at a ratio of infinity + to 1. That means it always outputs images of size 0 regardless of the size + of the inputs. + + This is useful because it can be supplied to routines which take a pyramid_down + function object and it will essentially disable pyramid processing. This way, + a pyramid oriented function can be turned into a regular routine which processes + just the original undownsampled image. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + unsigned int N + > + double pyramid_rate( + const pyramid_down& pyr + ); + /*! + ensures + - returns (N-1.0)/N + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + unsigned int N + > + void find_pyramid_down_output_image_size( + const pyramid_down& pyr, + long& nr, + long& nc + ); + /*! + requires + - nr >= 0 + - nc >= 0 + ensures + - If pyr() were called on an image with nr by nc rows and columns, what would + be the size of the output image? This function finds the size of the output + image and stores it back into #nr and #nc. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type, + typename image_type1, + typename image_type2 + > + void create_tiled_pyramid ( + const image_type1& img, + image_type2& out_img, + std::vector& rects, + const unsigned long padding = 10, + const unsigned long outer_padding = 0 + ); + /*! + requires + - pyramid_type == one of the dlib::pyramid_down template instances defined above. + - is_same_object(img, out_img) == false + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - for both pixel types P in the input and output images, we require: + - pixel_traits

    ::has_alpha == false + ensures + - Creates an image pyramid from the input image img. The pyramid is made using + pyramid_type. The highest resolution image is img and then all further + pyramid levels are generated from pyramid_type's downsampling. The entire + resulting pyramid is packed into a single image and stored in out_img. + - When packing pyramid levels into out_img, there will be padding pixels of + space between each sub-image. There will also be outer_padding pixels of + padding around the edge of the image. All padding pixels have a value of 0. + - The resulting pyramid will be composed of #rects.size() images packed into + out_img. Moreover, #rects[i] is the location inside out_img of the i-th + pyramid level. + - #rects.size() > 0 + - #rects[0] == get_rect(img). I.e. the first rectangle is the highest + resolution pyramid layer. Subsequent elements of #rects correspond to + smaller and smaller pyramid layers inside out_img. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type + > + dpoint image_to_tiled_pyramid ( + const std::vector& rects, + double scale, + dpoint p + ); + /*! + requires + - pyramid_type == one of the dlib::pyramid_down template instances defined above. + - 0 < scale <= 1 + - rects.size() > 0 + ensures + - The function create_tiled_pyramid() converts an image, img, to a "tiled + pyramid" called out_img. It also outputs a vector of rectangles, rect, that + show where each pyramid layer appears in out_img. Therefore, + image_to_tiled_pyramid() allows you to map from coordinates in img (i.e. p) + to coordinates in the tiled pyramid out_img, when given the rects metadata. + + So given a point p in img, you can ask, what coordinate in out_img + corresponds to img[p.y()][p.x()] when things are scale times smaller? This + new coordinate is a location in out_img and is what is returned by this + function. + - A scale of 1 means we don't move anywhere in the pyramid scale space relative + to the input image while smaller values of scale mean we move down the + pyramid. + - Assumes pyramid_type is the pyramid class used to produce the tiled image. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type + > + drectangle image_to_tiled_pyramid ( + const std::vector& rects, + double scale, + drectangle r + ); + /*! + requires + - pyramid_type == one of the dlib::pyramid_down template instances defined above. + - 0 < scale <= 1 + - rects.size() > 0 + ensures + - This function maps from input image space to tiled pyramid coordinate space + just as the above image_to_tiled_pyramid() does, except it operates on + rectangle objects instead of points. + - Assumes pyramid_type is the pyramid class used to produce the tiled image. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type + > + dpoint tiled_pyramid_to_image ( + const std::vector& rects, + dpoint p + ); + /*! + requires + - pyramid_type == one of the dlib::pyramid_down template instances defined above. + - rects.size() > 0 + ensures + - This function maps from a coordinate in a tiled pyramid to the corresponding + input image coordinate. Therefore, it is essentially the inverse of + image_to_tiled_pyramid(). + - It should be noted that this function isn't always an inverse of + image_to_tiled_pyramid(). This is because you can ask + image_to_tiled_pyramid() for the coordinates of points outside the input + image and they will be mapped to somewhere that doesn't have an inverse. But + for points actually inside the image this function performs an approximate + inverse mapping. + - Assumes pyramid_type is the pyramid class used to produce the tiled image. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type + > + drectangle tiled_pyramid_to_image ( + const std::vector& rects, + drectangle r + ); + /*! + requires + - pyramid_type == one of the dlib::pyramid_down template instances defined above. + - rects.size() > 0 + ensures + - This function maps from a coordinate in a tiled pyramid to the corresponding + input image coordinate. Therefore, it is essentially the inverse of + image_to_tiled_pyramid(). + - It should be noted that this function isn't always an inverse of + image_to_tiled_pyramid(). This is because you can ask + image_to_tiled_pyramid() for the coordinates of points outside the input + image and they will be mapped to somewhere that doesn't have an inverse. But + for points actually inside the image this function performs an approximate + inverse mapping. + - Assumes pyramid_type is the pyramid class used to produce the tiled image. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IMAGE_PYRaMID_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_transforms/integral_image.h b/ml/dlib/dlib/image_transforms/integral_image.h new file mode 100644 index 000000000..2ae47d921 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/integral_image.h @@ -0,0 +1,190 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_INTEGRAL_IMAGE +#define DLIB_INTEGRAL_IMAGE + +#include "integral_image_abstract.h" + +#include "../algs.h" +#include "../assert.h" +#include "../geometry.h" +#include "../array2d.h" +#include "../matrix.h" +#include "../pixel.h" +#include "../noncopyable.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class integral_image_generic : noncopyable + { + public: + typedef T value_type; + + long nr() const { return int_img.nr(); } + long nc() const { return int_img.nc(); } + + template + void load ( + const image_type& img_ + ) + { + const_image_view img(img_); + T pixel; + int_img.set_size(img.nr(), img.nc()); + + // compute the first row of the integral image + T temp = 0; + for (long c = 0; c < img.nc(); ++c) + { + assign_pixel(pixel, img[0][c]); + temp += pixel; + int_img[0][c] = temp; + } + + // now compute the rest of the integral image + for (long r = 1; r < img.nr(); ++r) + { + temp = 0; + for (long c = 0; c < img.nc(); ++c) + { + assign_pixel(pixel, img[r][c]); + temp += pixel; + int_img[r][c] = temp + int_img[r-1][c]; + } + } + + } + + value_type get_sum_of_area ( + const rectangle& rect + ) const + { + DLIB_ASSERT(get_rect(*this).contains(rect) == true && rect.is_empty() == false, + "\tvalue_type get_sum_of_area(rect)" + << "\n\tYou have given a rectangle that goes outside the image" + << "\n\tthis: " << this + << "\n\trect.is_empty(): " << rect.is_empty() + << "\n\trect: " << rect + << "\n\tget_rect(*this): " << get_rect(*this) + ); + + T top_left = 0, top_right = 0, bottom_left = 0, bottom_right = 0; + + bottom_right = int_img[rect.bottom()][rect.right()]; + if (rect.left()-1 >= 0 && rect.top()-1 >= 0) + { + top_left = int_img[rect.top()-1][rect.left()-1]; + bottom_left = int_img[rect.bottom()][rect.left()-1]; + top_right = int_img[rect.top()-1][rect.right()]; + } + else if (rect.left()-1 >= 0) + { + bottom_left = int_img[rect.bottom()][rect.left()-1]; + } + else if (rect.top()-1 >= 0) + { + top_right = int_img[rect.top()-1][rect.right()]; + } + + return bottom_right - bottom_left - top_right + top_left; + } + + void swap(integral_image_generic& item) + { + int_img.swap(item.int_img); + } + + private: + + array2d int_img; + }; + + + template < + typename T + > + void swap ( + integral_image_generic& a, + integral_image_generic& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + typedef integral_image_generic integral_image; + +// ---------------------------------------------------------------------------------------- + + template + typename integral_image_type::value_type haar_x ( + const integral_image_type& img, + const point& p, + long width + ) + { + DLIB_ASSERT(get_rect(img).contains(centered_rect(p,width,width)) == true, + "\tlong haar_x(img,p,width)" + << "\n\tYou have given a point and with that goes outside the image" + << "\n\tget_rect(img): " << get_rect(img) + << "\n\tp: " << p + << "\n\twidth: " << width + ); + + rectangle left_rect; + left_rect.set_left ( p.x() - width / 2 ); + left_rect.set_top ( p.y() - width / 2 ); + left_rect.set_right ( p.x()-1 ); + left_rect.set_bottom ( left_rect.top() + width - 1 ); + + rectangle right_rect; + right_rect.set_left ( p.x() ); + right_rect.set_top ( left_rect.top() ); + right_rect.set_right ( left_rect.left() + width -1 ); + right_rect.set_bottom ( left_rect.bottom() ); + + return img.get_sum_of_area(right_rect) - img.get_sum_of_area(left_rect); + } + + // ---------------------------------------------------------------------------- + + template + typename integral_image_type::value_type haar_y ( + const integral_image_type& img, + const point& p, + long width + ) + { + DLIB_ASSERT(get_rect(img).contains(centered_rect(p,width,width)) == true, + "\tlong haar_y(img,p,width)" + << "\n\tYou have given a point and with that goes outside the image" + << "\n\tget_rect(img): " << get_rect(img) + << "\n\tp: " << p + << "\n\twidth: " << width + ); + + rectangle top_rect; + top_rect.set_left ( p.x() - width / 2 ); + top_rect.set_top ( p.y() - width / 2 ); + top_rect.set_right ( top_rect.left() + width - 1 ); + top_rect.set_bottom ( p.y()-1 ); + + rectangle bottom_rect; + bottom_rect.set_left ( top_rect.left() ); + bottom_rect.set_top ( p.y() ); + bottom_rect.set_right ( top_rect.right() ); + bottom_rect.set_bottom ( top_rect.top() + width - 1 ); + + return img.get_sum_of_area(bottom_rect) - img.get_sum_of_area(top_rect); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_INTEGRAL_IMAGE + diff --git a/ml/dlib/dlib/image_transforms/integral_image_abstract.h b/ml/dlib/dlib/image_transforms/integral_image_abstract.h new file mode 100644 index 000000000..583fa0375 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/integral_image_abstract.h @@ -0,0 +1,169 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_INTEGRAL_IMAGe_ABSTRACT_ +#ifdef DLIB_INTEGRAL_IMAGe_ABSTRACT_ + +#include "../geometry/rectangle_abstract.h" +#include "../array2d/array2d_kernel_abstract.h" +#include "../pixel.h" +#include "../noncopyable.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class integral_image_generic : noncopyable + { + /*! + REQUIREMENTS ON T + T should be a built in scalar type. Moreover, it should + be capable of storing sums of whatever kind of pixel + you will be dealing with. + + INITIAL VALUE + - nr() == 0 + - nc() == 0 + + WHAT THIS OBJECT REPRESENTS + This object is an alternate way of representing image data + that allows for very fast computations of sums of pixels in + rectangular regions. To use this object you load it with a + normal image and then you can use the get_sum_of_area() + function to compute sums of pixels in a given area in + constant time. + !*/ + public: + typedef T value_type; + + const long nr( + ) const; + /*! + ensures + - returns the number of rows in this integral image object + !*/ + + const long nc( + ) const; + /*! + ensures + - returns the number of columns in this integral image object + !*/ + + template + void load ( + const image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - Let P denote the type of pixel in img, then we require: + - pixel_traits

    ::has_alpha == false + ensures + - #nr() == img.nr() + - #nc() == img.nc() + - #*this will now contain an "integral image" representation of the + given input image. + !*/ + + value_type get_sum_of_area ( + const rectangle& rect + ) const; + /*! + requires + - rect.is_empty() == false + - get_rect(*this).contains(rect) == true + (i.e. rect must not be outside the integral image) + ensures + - Let O denote the image this integral image was generated from. + Then this function returns sum(subm(mat(O),rect)). + That is, this function returns the sum of the pixels in O that + are contained within the given rectangle. + !*/ + + void swap( + integral_image_generic& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < typename T > + void swap ( + integral_image_generic& a, + integral_image_generic& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + typedef integral_image_generic integral_image; + +// ---------------------------------------------------------------------------------------- + + template + typename integral_image_type::value_type haar_x ( + const integral_image_type& img, + const point& p, + long width + ) + /*! + requires + - get_rect(img).contains(centered_rect(p,width,width)) == true + - integral_image_type == a type that implements the integral_image_generic + interface defined above + ensures + - returns the response of a Haar wavelet centered at the point p + with the given width. The wavelet is oriented along the X axis + and has the following shape: + ----++++ + ----++++ + ----++++ + ----++++ + That is, the wavelet is square and computes the sum of pixels on the + right minus the sum of pixels on the left. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename integral_image_type::value_type haar_y ( + const integral_image_type& img, + const point& p, + long width + ) + /*! + requires + - get_rect(img).contains(centered_rect(p,width,width)) == true + - integral_image_type == a type that implements the integral_image_generic + interface defined above + ensures + - returns the response of a Haar wavelet centered at the point p + with the given width in the given image. The wavelet is oriented + along the Y axis and has the following shape: + -------- + -------- + ++++++++ + ++++++++ + That is, the wavelet is square and computes the sum of pixels on the + bottom minus the sum of pixels on the top. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_INTEGRAL_IMAGe_ABSTRACT_ + diff --git a/ml/dlib/dlib/image_transforms/interpolation.h b/ml/dlib/dlib/image_transforms/interpolation.h new file mode 100644 index 000000000..11c561e2d --- /dev/null +++ b/ml/dlib/dlib/image_transforms/interpolation.h @@ -0,0 +1,2193 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_INTERPOlATIONh_ +#define DLIB_INTERPOlATIONh_ + +#include "interpolation_abstract.h" +#include "../pixel.h" +#include "../matrix.h" +#include "assign_image.h" +#include "image_pyramid.h" +#include "../simd.h" +#include "../image_processing/full_object_detection.h" +#include +#include "../rand.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + struct sub_image_proxy + { + sub_image_proxy() = default; + + sub_image_proxy ( + T& img, + rectangle rect + ) + { + rect = rect.intersect(get_rect(img)); + typedef typename image_traits::pixel_type pixel_type; + + _nr = rect.height(); + _nc = rect.width(); + _width_step = width_step(img); + _data = (char*)image_data(img) + sizeof(pixel_type)*rect.left() + rect.top()*_width_step; + } + + void* _data = 0; + long _width_step = 0; + long _nr = 0; + long _nc = 0; + }; + + template + struct const_sub_image_proxy + { + const_sub_image_proxy() = default; + + const_sub_image_proxy ( + const T& img, + rectangle rect + ) + { + rect = rect.intersect(get_rect(img)); + typedef typename image_traits::pixel_type pixel_type; + + _nr = rect.height(); + _nc = rect.width(); + _width_step = width_step(img); + _data = (const char*)image_data(img) + sizeof(pixel_type)*rect.left() + rect.top()*_width_step; + } + + const void* _data = 0; + long _width_step = 0; + long _nr = 0; + long _nc = 0; + }; + + template + struct image_traits > + { + typedef typename image_traits::pixel_type pixel_type; + }; + template + struct image_traits > + { + typedef typename image_traits::pixel_type pixel_type; + }; + template + struct image_traits > + { + typedef typename image_traits::pixel_type pixel_type; + }; + template + struct image_traits > + { + typedef typename image_traits::pixel_type pixel_type; + }; + + template + inline long num_rows( const sub_image_proxy& img) { return img._nr; } + template + inline long num_columns( const sub_image_proxy& img) { return img._nc; } + + template + inline long num_rows( const const_sub_image_proxy& img) { return img._nr; } + template + inline long num_columns( const const_sub_image_proxy& img) { return img._nc; } + + template + inline void* image_data( sub_image_proxy& img) + { + return img._data; + } + template + inline const void* image_data( const sub_image_proxy& img) + { + return img._data; + } + + template + inline const void* image_data( const const_sub_image_proxy& img) + { + return img._data; + } + + template + inline long width_step( + const sub_image_proxy& img + ) { return img._width_step; } + + template + inline long width_step( + const const_sub_image_proxy& img + ) { return img._width_step; } + + template + void set_image_size(sub_image_proxy& img, long rows, long cols) + { + DLIB_CASSERT(img._nr == rows && img._nc == cols, "A sub_image can't be resized." + << "\n\t img._nr: "<< img._nr + << "\n\t img._nc: "<< img._nc + << "\n\t rows: "<< rows + << "\n\t cols: "<< cols + ); + } + + template < + typename image_type + > + sub_image_proxy sub_image ( + image_type& img, + const rectangle& rect + ) + { + return sub_image_proxy(img,rect); + } + + template < + typename image_type + > + const const_sub_image_proxy sub_image ( + const image_type& img, + const rectangle& rect + ) + { + return const_sub_image_proxy(img,rect); + } + + template + inline sub_image_proxy> sub_image ( + T* img, + long nr, + long nc, + long row_stride + ) + { + sub_image_proxy> tmp; + tmp._data = img; + tmp._nr = nr; + tmp._nc = nc; + tmp._width_step = row_stride*sizeof(T); + return tmp; + } + + template + inline const const_sub_image_proxy> sub_image ( + const T* img, + long nr, + long nc, + long row_stride + ) + { + const_sub_image_proxy> tmp; + tmp._data = img; + tmp._nr = nr; + tmp._nc = nc; + tmp._width_step = row_stride*sizeof(T); + return tmp; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class interpolate_nearest_neighbor + { + public: + + template + bool operator() ( + const image_view_type& img, + const dlib::point& p, + pixel_type& result + ) const + { + COMPILE_TIME_ASSERT(pixel_traits::has_alpha == false); + + if (get_rect(img).contains(p)) + { + assign_pixel(result, img[p.y()][p.x()]); + return true; + } + else + { + return false; + } + } + + }; + +// ---------------------------------------------------------------------------------------- + + class interpolate_bilinear + { + template + struct is_rgb_image + { + const static bool value = pixel_traits::rgb; + }; + + public: + + template + typename disable_if,bool>::type operator() ( + const image_view_type& img, + const dlib::vector& p, + pixel_type& result + ) const + { + COMPILE_TIME_ASSERT(pixel_traits::has_alpha == false); + + const long left = static_cast(std::floor(p.x())); + const long top = static_cast(std::floor(p.y())); + const long right = left+1; + const long bottom = top+1; + + + // if the interpolation goes outside img + if (!(left >= 0 && top >= 0 && right < img.nc() && bottom < img.nr())) + return false; + + const double lr_frac = p.x() - left; + const double tb_frac = p.y() - top; + + double tl = 0, tr = 0, bl = 0, br = 0; + + assign_pixel(tl, img[top][left]); + assign_pixel(tr, img[top][right]); + assign_pixel(bl, img[bottom][left]); + assign_pixel(br, img[bottom][right]); + + double temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) + + tb_frac*((1-lr_frac)*bl + lr_frac*br); + + assign_pixel(result, temp); + return true; + } + + template + typename enable_if,bool>::type operator() ( + const image_view_type& img, + const dlib::vector& p, + pixel_type& result + ) const + { + COMPILE_TIME_ASSERT(pixel_traits::has_alpha == false); + + const long left = static_cast(std::floor(p.x())); + const long top = static_cast(std::floor(p.y())); + const long right = left+1; + const long bottom = top+1; + + + // if the interpolation goes outside img + if (!(left >= 0 && top >= 0 && right < img.nc() && bottom < img.nr())) + return false; + + const double lr_frac = p.x() - left; + const double tb_frac = p.y() - top; + + double tl, tr, bl, br; + + tl = img[top][left].red; + tr = img[top][right].red; + bl = img[bottom][left].red; + br = img[bottom][right].red; + const double red = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) + + tb_frac*((1-lr_frac)*bl + lr_frac*br); + + tl = img[top][left].green; + tr = img[top][right].green; + bl = img[bottom][left].green; + br = img[bottom][right].green; + const double green = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) + + tb_frac*((1-lr_frac)*bl + lr_frac*br); + + tl = img[top][left].blue; + tr = img[top][right].blue; + bl = img[bottom][left].blue; + br = img[bottom][right].blue; + const double blue = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) + + tb_frac*((1-lr_frac)*bl + lr_frac*br); + + rgb_pixel temp; + assign_pixel(temp.red, red); + assign_pixel(temp.green, green); + assign_pixel(temp.blue, blue); + assign_pixel(result, temp); + return true; + } + }; + +// ---------------------------------------------------------------------------------------- + + class interpolate_quadratic + { + template + struct is_rgb_image + { + const static bool value = pixel_traits::rgb; + }; + + public: + + template + typename disable_if,bool>::type operator() ( + const image_view_type& img, + const dlib::vector& p, + pixel_type& result + ) const + { + COMPILE_TIME_ASSERT(pixel_traits::has_alpha == false); + + const point pp(p); + + // if the interpolation goes outside img + if (!get_rect(img).contains(grow_rect(pp,1))) + return false; + + const long r = pp.y(); + const long c = pp.x(); + + const double temp = interpolate(p-pp, + img[r-1][c-1], + img[r-1][c ], + img[r-1][c+1], + img[r ][c-1], + img[r ][c ], + img[r ][c+1], + img[r+1][c-1], + img[r+1][c ], + img[r+1][c+1]); + + assign_pixel(result, temp); + return true; + } + + template + typename enable_if,bool>::type operator() ( + const image_view_type& img, + const dlib::vector& p, + pixel_type& result + ) const + { + COMPILE_TIME_ASSERT(pixel_traits::has_alpha == false); + + const point pp(p); + + // if the interpolation goes outside img + if (!get_rect(img).contains(grow_rect(pp,1))) + return false; + + const long r = pp.y(); + const long c = pp.x(); + + const double red = interpolate(p-pp, + img[r-1][c-1].red, + img[r-1][c ].red, + img[r-1][c+1].red, + img[r ][c-1].red, + img[r ][c ].red, + img[r ][c+1].red, + img[r+1][c-1].red, + img[r+1][c ].red, + img[r+1][c+1].red); + const double green = interpolate(p-pp, + img[r-1][c-1].green, + img[r-1][c ].green, + img[r-1][c+1].green, + img[r ][c-1].green, + img[r ][c ].green, + img[r ][c+1].green, + img[r+1][c-1].green, + img[r+1][c ].green, + img[r+1][c+1].green); + const double blue = interpolate(p-pp, + img[r-1][c-1].blue, + img[r-1][c ].blue, + img[r-1][c+1].blue, + img[r ][c-1].blue, + img[r ][c ].blue, + img[r ][c+1].blue, + img[r+1][c-1].blue, + img[r+1][c ].blue, + img[r+1][c+1].blue); + + + rgb_pixel temp; + assign_pixel(temp.red, red); + assign_pixel(temp.green, green); + assign_pixel(temp.blue, blue); + assign_pixel(result, temp); + + return true; + } + + private: + + /* tl tm tr + ml mm mr + bl bm br + */ + // The above is the pixel layout in our little 3x3 neighborhood. interpolate() will + // fit a quadratic to these 9 pixels and then use that quadratic to find the interpolated + // value at point p. + inline double interpolate( + const dlib::vector& p, + double tl, double tm, double tr, + double ml, double mm, double mr, + double bl, double bm, double br + ) const + { + matrix w; + // x + w(0) = (tr + mr + br - tl - ml - bl)*0.16666666666; + // y + w(1) = (bl + bm + br - tl - tm - tr)*0.16666666666; + // x^2 + w(2) = (tl + tr + ml + mr + bl + br)*0.16666666666 - (tm + mm + bm)*0.333333333; + // x*y + w(3) = (tl - tr - bl + br)*0.25; + // y^2 + w(4) = (tl + tm + tr + bl + bm + br)*0.16666666666 - (ml + mm + mr)*0.333333333; + // 1 (constant term) + w(5) = (tm + ml + mr + bm)*0.222222222 - (tl + tr + bl + br)*0.11111111 + (mm)*0.55555556; + + const double x = p.x(); + const double y = p.y(); + + matrix z; + z = x, y, x*x, x*y, y*y, 1.0; + + return dot(w,z); + } + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class black_background + { + public: + template + void operator() ( pixel_type& p) const { assign_pixel(p, 0); } + }; + + class white_background + { + public: + template + void operator() ( pixel_type& p) const { assign_pixel(p, 255); } + }; + + class no_background + { + public: + template + void operator() ( pixel_type& ) const { } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type, + typename point_mapping_type, + typename background_type + > + void transform_image ( + const image_type1& in_img, + image_type2& out_img, + const interpolation_type& interp, + const point_mapping_type& map_point, + const background_type& set_background, + const rectangle& area + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( get_rect(out_img).contains(area) == true && + is_same_object(in_img, out_img) == false , + "\t void transform_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t get_rect(out_img).contains(area): " << get_rect(out_img).contains(area) + << "\n\t get_rect(out_img): " << get_rect(out_img) + << "\n\t area: " << area + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + const_image_view imgv(in_img); + image_view out_imgv(out_img); + + for (long r = area.top(); r <= area.bottom(); ++r) + { + for (long c = area.left(); c <= area.right(); ++c) + { + if (!interp(imgv, map_point(dlib::vector(c,r)), out_imgv[r][c])) + set_background(out_imgv[r][c]); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type, + typename point_mapping_type, + typename background_type + > + void transform_image ( + const image_type1& in_img, + image_type2& out_img, + const interpolation_type& interp, + const point_mapping_type& map_point, + const background_type& set_background + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t void transform_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + transform_image(in_img, out_img, interp, map_point, set_background, get_rect(out_img)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type, + typename point_mapping_type + > + void transform_image ( + const image_type1& in_img, + image_type2& out_img, + const interpolation_type& interp, + const point_mapping_type& map_point + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t void transform_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + + transform_image(in_img, out_img, interp, map_point, black_background(), get_rect(out_img)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type + > + point_transform_affine rotate_image ( + const image_type1& in_img, + image_type2& out_img, + double angle, + const interpolation_type& interp + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t point_transform_affine rotate_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + const rectangle rimg = get_rect(in_img); + + + // figure out bounding box for rotated rectangle + rectangle rect; + rect += rotate_point(center(rimg), rimg.tl_corner(), -angle); + rect += rotate_point(center(rimg), rimg.tr_corner(), -angle); + rect += rotate_point(center(rimg), rimg.bl_corner(), -angle); + rect += rotate_point(center(rimg), rimg.br_corner(), -angle); + set_image_size(out_img, rect.height(), rect.width()); + + const matrix R = rotation_matrix(angle); + + point_transform_affine trans = point_transform_affine(R, -R*dcenter(get_rect(out_img)) + dcenter(rimg)); + transform_image(in_img, out_img, interp, trans); + return inv(trans); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + point_transform_affine rotate_image ( + const image_type1& in_img, + image_type2& out_img, + double angle + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t point_transform_affine rotate_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + return rotate_image(in_img, out_img, angle, interpolate_quadratic()); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + class helper_resize_image + { + public: + helper_resize_image( + double x_scale_, + double y_scale_ + ): + x_scale(x_scale_), + y_scale(y_scale_) + {} + + dlib::vector operator() ( + const dlib::vector& p + ) const + { + return dlib::vector(p.x()*x_scale, p.y()*y_scale); + } + + private: + const double x_scale; + const double y_scale; + }; + } + + template < + typename image_type1, + typename image_type2, + typename interpolation_type + > + void resize_image ( + const image_type1& in_img, + image_type2& out_img, + const interpolation_type& interp + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t void resize_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + const double x_scale = (num_columns(in_img)-1)/(double)std::max((num_columns(out_img)-1),1); + const double y_scale = (num_rows(in_img)-1)/(double)std::max((num_rows(out_img)-1),1); + transform_image(in_img, out_img, interp, + dlib::impl::helper_resize_image(x_scale,y_scale)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct is_rgb_image { const static bool value = pixel_traits::pixel_type>::rgb; }; + template + struct is_grayscale_image { const static bool value = pixel_traits::pixel_type>::grayscale; }; + + // This is an optimized version of resize_image for the case where bilinear + // interpolation is used. + template < + typename image_type1, + typename image_type2 + > + typename disable_if_c<(is_rgb_image::value&&is_rgb_image::value) || + (is_grayscale_image::value&&is_grayscale_image::value)>::type + resize_image ( + const image_type1& in_img_, + image_type2& out_img_, + interpolate_bilinear + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img_, out_img_) == false , + "\t void resize_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img_, out_img_): " << is_same_object(in_img_, out_img_) + ); + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + if (out_img.size() == 0 || in_img.size() == 0) + return; + + + typedef typename image_traits::pixel_type T; + typedef typename image_traits::pixel_type U; + const double x_scale = (in_img.nc()-1)/(double)std::max((out_img.nc()-1),1); + const double y_scale = (in_img.nr()-1)/(double)std::max((out_img.nr()-1),1); + double y = -y_scale; + for (long r = 0; r < out_img.nr(); ++r) + { + y += y_scale; + const long top = static_cast(std::floor(y)); + const long bottom = std::min(top+1, in_img.nr()-1); + const double tb_frac = y - top; + double x = -x_scale; + if (pixel_traits::grayscale) + { + for (long c = 0; c < out_img.nc(); ++c) + { + x += x_scale; + const long left = static_cast(std::floor(x)); + const long right = std::min(left+1, in_img.nc()-1); + const double lr_frac = x - left; + + double tl = 0, tr = 0, bl = 0, br = 0; + + assign_pixel(tl, in_img[top][left]); + assign_pixel(tr, in_img[top][right]); + assign_pixel(bl, in_img[bottom][left]); + assign_pixel(br, in_img[bottom][right]); + + double temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) + + tb_frac*((1-lr_frac)*bl + lr_frac*br); + + assign_pixel(out_img[r][c], temp); + } + } + else + { + for (long c = 0; c < out_img.nc(); ++c) + { + x += x_scale; + const long left = static_cast(std::floor(x)); + const long right = std::min(left+1, in_img.nc()-1); + const double lr_frac = x - left; + + const T tl = in_img[top][left]; + const T tr = in_img[top][right]; + const T bl = in_img[bottom][left]; + const T br = in_img[bottom][right]; + + T temp; + assign_pixel(temp, 0); + vector_to_pixel(temp, + (1-tb_frac)*((1-lr_frac)*pixel_to_vector(tl) + lr_frac*pixel_to_vector(tr)) + + tb_frac*((1-lr_frac)*pixel_to_vector(bl) + lr_frac*pixel_to_vector(br))); + assign_pixel(out_img[r][c], temp); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + struct images_have_same_pixel_types + { + typedef typename image_traits::pixel_type ptype1; + typedef typename image_traits::pixel_type ptype2; + const static bool value = is_same_type::value; + }; + + template < + typename image_type, + typename image_type2 + > + typename enable_if_c::value && is_grayscale_image::value && images_have_same_pixel_types::value>::type + resize_image ( + const image_type& in_img_, + image_type2& out_img_, + interpolate_bilinear + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img_, out_img_) == false , + "\t void resize_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img_, out_img_): " << is_same_object(in_img_, out_img_) + ); + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + if (out_img.size() == 0 || in_img.size() == 0) + return; + + typedef typename image_traits::pixel_type T; + const double x_scale = (in_img.nc()-1)/(double)std::max((out_img.nc()-1),1); + const double y_scale = (in_img.nr()-1)/(double)std::max((out_img.nr()-1),1); + double y = -y_scale; + for (long r = 0; r < out_img.nr(); ++r) + { + y += y_scale; + const long top = static_cast(std::floor(y)); + const long bottom = std::min(top+1, in_img.nr()-1); + const double tb_frac = y - top; + double x = -4*x_scale; + + const simd4f _tb_frac = tb_frac; + const simd4f _inv_tb_frac = 1-tb_frac; + const simd4f _x_scale = 4*x_scale; + simd4f _x(x, x+x_scale, x+2*x_scale, x+3*x_scale); + long c = 0; + for (;; c+=4) + { + _x += _x_scale; + simd4i left = simd4i(_x); + + simd4f _lr_frac = _x-left; + simd4f _inv_lr_frac = 1-_lr_frac; + simd4i right = left+1; + + simd4f tlf = _inv_tb_frac*_inv_lr_frac; + simd4f trf = _inv_tb_frac*_lr_frac; + simd4f blf = _tb_frac*_inv_lr_frac; + simd4f brf = _tb_frac*_lr_frac; + + int32 fleft[4]; + int32 fright[4]; + left.store(fleft); + right.store(fright); + + if (fright[3] >= in_img.nc()) + break; + simd4f tl(in_img[top][fleft[0]], in_img[top][fleft[1]], in_img[top][fleft[2]], in_img[top][fleft[3]]); + simd4f tr(in_img[top][fright[0]], in_img[top][fright[1]], in_img[top][fright[2]], in_img[top][fright[3]]); + simd4f bl(in_img[bottom][fleft[0]], in_img[bottom][fleft[1]], in_img[bottom][fleft[2]], in_img[bottom][fleft[3]]); + simd4f br(in_img[bottom][fright[0]], in_img[bottom][fright[1]], in_img[bottom][fright[2]], in_img[bottom][fright[3]]); + + simd4f out = simd4f(tlf*tl + trf*tr + blf*bl + brf*br); + float fout[4]; + out.store(fout); + + out_img[r][c] = static_cast(fout[0]); + out_img[r][c+1] = static_cast(fout[1]); + out_img[r][c+2] = static_cast(fout[2]); + out_img[r][c+3] = static_cast(fout[3]); + } + x = -x_scale + c*x_scale; + for (; c < out_img.nc(); ++c) + { + x += x_scale; + const long left = static_cast(std::floor(x)); + const long right = std::min(left+1, in_img.nc()-1); + const float lr_frac = x - left; + + float tl = 0, tr = 0, bl = 0, br = 0; + + assign_pixel(tl, in_img[top][left]); + assign_pixel(tr, in_img[top][right]); + assign_pixel(bl, in_img[bottom][left]); + assign_pixel(br, in_img[bottom][right]); + + float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) + + tb_frac*((1-lr_frac)*bl + lr_frac*br); + + assign_pixel(out_img[r][c], temp); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + typename enable_if >::type resize_image ( + const image_type& in_img_, + image_type& out_img_, + interpolate_bilinear + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img_, out_img_) == false , + "\t void resize_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img_, out_img_): " << is_same_object(in_img_, out_img_) + ); + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + if (out_img.size() == 0 || in_img.size() == 0) + return; + + + typedef typename image_traits::pixel_type T; + const double x_scale = (in_img.nc()-1)/(double)std::max((out_img.nc()-1),1); + const double y_scale = (in_img.nr()-1)/(double)std::max((out_img.nr()-1),1); + double y = -y_scale; + for (long r = 0; r < out_img.nr(); ++r) + { + y += y_scale; + const long top = static_cast(std::floor(y)); + const long bottom = std::min(top+1, in_img.nr()-1); + const double tb_frac = y - top; + double x = -4*x_scale; + + const simd4f _tb_frac = tb_frac; + const simd4f _inv_tb_frac = 1-tb_frac; + const simd4f _x_scale = 4*x_scale; + simd4f _x(x, x+x_scale, x+2*x_scale, x+3*x_scale); + long c = 0; + for (;; c+=4) + { + _x += _x_scale; + simd4i left = simd4i(_x); + simd4f lr_frac = _x-left; + simd4f _inv_lr_frac = 1-lr_frac; + simd4i right = left+1; + + simd4f tlf = _inv_tb_frac*_inv_lr_frac; + simd4f trf = _inv_tb_frac*lr_frac; + simd4f blf = _tb_frac*_inv_lr_frac; + simd4f brf = _tb_frac*lr_frac; + + int32 fleft[4]; + int32 fright[4]; + left.store(fleft); + right.store(fright); + + if (fright[3] >= in_img.nc()) + break; + simd4f tl(in_img[top][fleft[0]].red, in_img[top][fleft[1]].red, in_img[top][fleft[2]].red, in_img[top][fleft[3]].red); + simd4f tr(in_img[top][fright[0]].red, in_img[top][fright[1]].red, in_img[top][fright[2]].red, in_img[top][fright[3]].red); + simd4f bl(in_img[bottom][fleft[0]].red, in_img[bottom][fleft[1]].red, in_img[bottom][fleft[2]].red, in_img[bottom][fleft[3]].red); + simd4f br(in_img[bottom][fright[0]].red, in_img[bottom][fright[1]].red, in_img[bottom][fright[2]].red, in_img[bottom][fright[3]].red); + + simd4i out = simd4i(tlf*tl + trf*tr + blf*bl + brf*br); + int32 fout[4]; + out.store(fout); + + out_img[r][c].red = static_cast(fout[0]); + out_img[r][c+1].red = static_cast(fout[1]); + out_img[r][c+2].red = static_cast(fout[2]); + out_img[r][c+3].red = static_cast(fout[3]); + + + tl = simd4f(in_img[top][fleft[0]].green, in_img[top][fleft[1]].green, in_img[top][fleft[2]].green, in_img[top][fleft[3]].green); + tr = simd4f(in_img[top][fright[0]].green, in_img[top][fright[1]].green, in_img[top][fright[2]].green, in_img[top][fright[3]].green); + bl = simd4f(in_img[bottom][fleft[0]].green, in_img[bottom][fleft[1]].green, in_img[bottom][fleft[2]].green, in_img[bottom][fleft[3]].green); + br = simd4f(in_img[bottom][fright[0]].green, in_img[bottom][fright[1]].green, in_img[bottom][fright[2]].green, in_img[bottom][fright[3]].green); + out = simd4i(tlf*tl + trf*tr + blf*bl + brf*br); + out.store(fout); + out_img[r][c].green = static_cast(fout[0]); + out_img[r][c+1].green = static_cast(fout[1]); + out_img[r][c+2].green = static_cast(fout[2]); + out_img[r][c+3].green = static_cast(fout[3]); + + + tl = simd4f(in_img[top][fleft[0]].blue, in_img[top][fleft[1]].blue, in_img[top][fleft[2]].blue, in_img[top][fleft[3]].blue); + tr = simd4f(in_img[top][fright[0]].blue, in_img[top][fright[1]].blue, in_img[top][fright[2]].blue, in_img[top][fright[3]].blue); + bl = simd4f(in_img[bottom][fleft[0]].blue, in_img[bottom][fleft[1]].blue, in_img[bottom][fleft[2]].blue, in_img[bottom][fleft[3]].blue); + br = simd4f(in_img[bottom][fright[0]].blue, in_img[bottom][fright[1]].blue, in_img[bottom][fright[2]].blue, in_img[bottom][fright[3]].blue); + out = simd4i(tlf*tl + trf*tr + blf*bl + brf*br); + out.store(fout); + out_img[r][c].blue = static_cast(fout[0]); + out_img[r][c+1].blue = static_cast(fout[1]); + out_img[r][c+2].blue = static_cast(fout[2]); + out_img[r][c+3].blue = static_cast(fout[3]); + } + x = -x_scale + c*x_scale; + for (; c < out_img.nc(); ++c) + { + x += x_scale; + const long left = static_cast(std::floor(x)); + const long right = std::min(left+1, in_img.nc()-1); + const double lr_frac = x - left; + + const T tl = in_img[top][left]; + const T tr = in_img[top][right]; + const T bl = in_img[bottom][left]; + const T br = in_img[bottom][right]; + + T temp; + assign_pixel(temp, 0); + vector_to_pixel(temp, + (1-tb_frac)*((1-lr_frac)*pixel_to_vector(tl) + lr_frac*pixel_to_vector(tr)) + + tb_frac*((1-lr_frac)*pixel_to_vector(bl) + lr_frac*pixel_to_vector(br))); + assign_pixel(out_img[r][c], temp); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void resize_image ( + const image_type1& in_img, + image_type2& out_img + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t void resize_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + resize_image(in_img, out_img, interpolate_bilinear()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void resize_image ( + double size_scale, + image_type& img + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( size_scale > 0 , + "\t void resize_image()" + << "\n\t Invalid inputs were given to this function." + << "\n\t size_scale: " << size_scale + ); + + image_type temp; + set_image_size(temp, std::round(size_scale*num_rows(img)), std::round(size_scale*num_columns(img))); + resize_image(img, temp); + swap(img, temp); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + point_transform_affine flip_image_left_right ( + const image_type1& in_img, + image_type2& out_img + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t void flip_image_left_right()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + assign_image(out_img, fliplr(mat(in_img))); + std::vector > from, to; + rectangle r = get_rect(in_img); + from.push_back(r.tl_corner()); to.push_back(r.tr_corner()); + from.push_back(r.bl_corner()); to.push_back(r.br_corner()); + from.push_back(r.tr_corner()); to.push_back(r.tl_corner()); + from.push_back(r.br_corner()); to.push_back(r.bl_corner()); + return find_affine_transform(from,to); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + point_transform_affine flip_image_left_right ( + image_type& img + ) + { + image_type temp; + auto tform = flip_image_left_right(img, temp); + swap(temp,img); + return tform; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void flip_image_up_down ( + const image_type1& in_img, + image_type2& out_img + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t void flip_image_up_down()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + assign_image(out_img, flipud(mat(in_img))); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline rectangle flip_rect_left_right ( + const rectangle& rect, + const rectangle& window + ) + { + rectangle temp; + temp.top() = rect.top(); + temp.bottom() = rect.bottom(); + + const long left_dist = rect.left()-window.left(); + + temp.right() = window.right()-left_dist; + temp.left() = temp.right()-rect.width()+1; + return temp; + } + + inline rectangle tform_object ( + const point_transform_affine& tran, + const rectangle& rect + ) + { + return centered_rect(tran(center(rect)), rect.width(), rect.height()); + } + + inline mmod_rect tform_object ( + const point_transform_affine& tran, + mmod_rect rect + ) + { + rect.rect = tform_object(tran, rect.rect); + return rect; + } + + inline full_object_detection tform_object( + const point_transform_affine& tran, + const full_object_detection& obj + ) + { + std::vector parts; + parts.reserve(obj.num_parts()); + for (unsigned long i = 0; i < obj.num_parts(); ++i) + { + if (obj.part(i) != OBJECT_PART_NOT_PRESENT) + parts.push_back(tran(obj.part(i))); + else + parts.push_back(OBJECT_PART_NOT_PRESENT); + } + return full_object_detection(tform_object(tran,obj.get_rect()), parts); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename T + > + void add_image_left_right_flips ( + image_array_type& images, + std::vector >& objects + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size(), + "\t void add_image_left_right_flips()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + ); + + typename image_array_type::value_type temp; + std::vector rects; + + const unsigned long num = images.size(); + for (unsigned long j = 0; j < num; ++j) + { + const point_transform_affine tran = flip_image_left_right(images[j], temp); + + rects.clear(); + for (unsigned long i = 0; i < objects[j].size(); ++i) + rects.push_back(impl::tform_object(tran, objects[j][i])); + + images.push_back(std::move(temp)); + objects.push_back(rects); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename T, + typename U + > + void add_image_left_right_flips ( + image_array_type& images, + std::vector >& objects, + std::vector >& objects2 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size() && + images.size() == objects2.size(), + "\t void add_image_left_right_flips()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + << "\n\t objects2.size(): " << objects2.size() + ); + + typename image_array_type::value_type temp; + std::vector rects; + std::vector rects2; + + const unsigned long num = images.size(); + for (unsigned long j = 0; j < num; ++j) + { + const point_transform_affine tran = flip_image_left_right(images[j], temp); + images.push_back(std::move(temp)); + + rects.clear(); + for (unsigned long i = 0; i < objects[j].size(); ++i) + rects.push_back(impl::tform_object(tran, objects[j][i])); + objects.push_back(rects); + + rects2.clear(); + for (unsigned long i = 0; i < objects2[j].size(); ++i) + rects2.push_back(impl::tform_object(tran, objects2[j][i])); + objects2.push_back(rects2); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void flip_image_dataset_left_right ( + image_array_type& images, + std::vector >& objects + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size(), + "\t void flip_image_dataset_left_right()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + ); + + typename image_array_type::value_type temp; + for (unsigned long i = 0; i < images.size(); ++i) + { + flip_image_left_right(images[i], temp); + swap(temp,images[i]); + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + objects[i][j] = impl::flip_rect_left_right(objects[i][j], get_rect(images[i])); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void flip_image_dataset_left_right ( + image_array_type& images, + std::vector >& objects, + std::vector >& objects2 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size() && + images.size() == objects2.size(), + "\t void flip_image_dataset_left_right()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + << "\n\t objects2.size(): " << objects2.size() + ); + + typename image_array_type::value_type temp; + for (unsigned long i = 0; i < images.size(); ++i) + { + flip_image_left_right(images[i], temp); + swap(temp, images[i]); + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + objects[i][j] = impl::flip_rect_left_right(objects[i][j], get_rect(images[i])); + } + for (unsigned long j = 0; j < objects2[i].size(); ++j) + { + objects2[i][j] = impl::flip_rect_left_right(objects2[i][j], get_rect(images[i])); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type, + typename image_array_type + > + void upsample_image_dataset ( + image_array_type& images, + std::vector >& objects, + unsigned long max_image_size = std::numeric_limits::max() + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size(), + "\t void upsample_image_dataset()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + ); + + typename image_array_type::value_type temp; + pyramid_type pyr; + for (unsigned long i = 0; i < images.size(); ++i) + { + const unsigned long img_size = num_rows(images[i])*num_columns(images[i]); + if (img_size <= max_image_size) + { + pyramid_up(images[i], temp, pyr); + swap(temp, images[i]); + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + objects[i][j] = pyr.rect_up(objects[i][j]); + } + } + } + } + + template < + typename pyramid_type, + typename image_array_type + > + void upsample_image_dataset ( + image_array_type& images, + std::vector>& objects, + unsigned long max_image_size = std::numeric_limits::max() + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size(), + "\t void upsample_image_dataset()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + ); + + typename image_array_type::value_type temp; + pyramid_type pyr; + for (unsigned long i = 0; i < images.size(); ++i) + { + const unsigned long img_size = num_rows(images[i])*num_columns(images[i]); + if (img_size <= max_image_size) + { + pyramid_up(images[i], temp, pyr); + swap(temp, images[i]); + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + objects[i][j].rect = pyr.rect_up(objects[i][j].rect); + } + } + } + } + + template < + typename pyramid_type, + typename image_array_type + > + void upsample_image_dataset ( + image_array_type& images, + std::vector >& objects, + std::vector >& objects2, + unsigned long max_image_size = std::numeric_limits::max() + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size() && + images.size() == objects2.size(), + "\t void upsample_image_dataset()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + << "\n\t objects2.size(): " << objects2.size() + ); + + typename image_array_type::value_type temp; + pyramid_type pyr; + for (unsigned long i = 0; i < images.size(); ++i) + { + const unsigned long img_size = num_rows(images[i])*num_columns(images[i]); + if (img_size <= max_image_size) + { + pyramid_up(images[i], temp, pyr); + swap(temp, images[i]); + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + objects[i][j] = pyr.rect_up(objects[i][j]); + } + for (unsigned long j = 0; j < objects2[i].size(); ++j) + { + objects2[i][j] = pyr.rect_up(objects2[i][j]); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void rotate_image_dataset ( + double angle, + image_array_type& images, + std::vector >& objects + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size(), + "\t void rotate_image_dataset()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + ); + + typename image_array_type::value_type temp; + for (unsigned long i = 0; i < images.size(); ++i) + { + const point_transform_affine tran = rotate_image(images[i], temp, angle); + swap(temp, images[i]); + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + const rectangle rect = objects[i][j]; + objects[i][j] = centered_rect(tran(center(rect)), rect.width(), rect.height()); + } + } + } + + template + void rotate_image_dataset ( + double angle, + image_array_type& images, + std::vector >& objects, + std::vector >& objects2 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( images.size() == objects.size() && + images.size() == objects2.size(), + "\t void rotate_image_dataset()" + << "\n\t Invalid inputs were given to this function." + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + << "\n\t objects2.size(): " << objects2.size() + ); + + typename image_array_type::value_type temp; + for (unsigned long i = 0; i < images.size(); ++i) + { + const point_transform_affine tran = rotate_image(images[i], temp, angle); + swap(temp, images[i]); + for (unsigned long j = 0; j < objects[i].size(); ++j) + { + const rectangle rect = objects[i][j]; + objects[i][j] = centered_rect(tran(center(rect)), rect.width(), rect.height()); + } + for (unsigned long j = 0; j < objects2[i].size(); ++j) + { + const rectangle rect = objects2[i][j]; + objects2[i][j] = centered_rect(tran(center(rect)), rect.width(), rect.height()); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename EXP, + typename T, + typename U + > + void add_image_rotations ( + const matrix_exp& angles, + image_array_type& images, + std::vector >& objects, + std::vector >& objects2 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_vector(angles) && angles.size() > 0 && + images.size() == objects.size() && + images.size() == objects2.size(), + "\t void add_image_rotations()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_vector(angles): " << is_vector(angles) + << "\n\t angles.size(): " << angles.size() + << "\n\t images.size(): " << images.size() + << "\n\t objects.size(): " << objects.size() + << "\n\t objects2.size(): " << objects2.size() + ); + + image_array_type new_images; + std::vector > new_objects; + std::vector > new_objects2; + + using namespace impl; + + std::vector objtemp; + std::vector objtemp2; + typename image_array_type::value_type temp; + for (long i = 0; i < angles.size(); ++i) + { + for (unsigned long j = 0; j < images.size(); ++j) + { + const point_transform_affine tran = rotate_image(images[j], temp, angles(i)); + new_images.push_back(std::move(temp)); + + objtemp.clear(); + for (unsigned long k = 0; k < objects[j].size(); ++k) + objtemp.push_back(tform_object(tran, objects[j][k])); + new_objects.push_back(objtemp); + + objtemp2.clear(); + for (unsigned long k = 0; k < objects2[j].size(); ++k) + objtemp2.push_back(tform_object(tran, objects2[j][k])); + new_objects2.push_back(objtemp2); + } + } + + new_images.swap(images); + new_objects.swap(objects); + new_objects2.swap(objects2); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename EXP, + typename T + > + void add_image_rotations ( + const matrix_exp& angles, + image_array_type& images, + std::vector >& objects + ) + { + std::vector > objects2(objects.size()); + add_image_rotations(angles, images, objects, objects2); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename pyramid_type, + typename interpolation_type + > + void pyramid_up ( + const image_type1& in_img, + image_type2& out_img, + const pyramid_type& pyr, + const interpolation_type& interp + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t void pyramid_up()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + if (image_size(in_img) == 0) + { + set_image_size(out_img, 0, 0); + return; + } + + rectangle rect = get_rect(in_img); + rectangle uprect = pyr.rect_up(rect); + if (uprect.is_empty()) + { + set_image_size(out_img, 0, 0); + return; + } + set_image_size(out_img, uprect.bottom()+1, uprect.right()+1); + + resize_image(in_img, out_img, interp); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename pyramid_type + > + void pyramid_up ( + const image_type1& in_img, + image_type2& out_img, + const pyramid_type& pyr + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_same_object(in_img, out_img) == false , + "\t void pyramid_up()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_same_object(in_img, out_img): " << is_same_object(in_img, out_img) + ); + + pyramid_up(in_img, out_img, pyr, interpolate_bilinear()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pyramid_type + > + void pyramid_up ( + image_type& img, + const pyramid_type& pyr + ) + { + image_type temp; + pyramid_up(img, temp, pyr); + swap(temp, img); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void pyramid_up ( + image_type& img + ) + { + pyramid_down<2> pyr; + pyramid_up(img, pyr); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct chip_dims + { + chip_dims ( + unsigned long rows_, + unsigned long cols_ + ) : rows(rows_), cols(cols_) { } + + unsigned long rows; + unsigned long cols; + }; + + struct chip_details + { + chip_details() : angle(0), rows(0), cols(0) {} + chip_details(const rectangle& rect_) : rect(rect_),angle(0), rows(rect_.height()), cols(rect_.width()) {} + chip_details(const drectangle& rect_) : rect(rect_),angle(0), + rows((unsigned long)(rect_.height()+0.5)), cols((unsigned long)(rect_.width()+0.5)) {} + chip_details(const drectangle& rect_, unsigned long size) : rect(rect_),angle(0) + { compute_dims_from_size(size); } + chip_details(const drectangle& rect_, unsigned long size, double angle_) : rect(rect_),angle(angle_) + { compute_dims_from_size(size); } + + chip_details(const drectangle& rect_, const chip_dims& dims) : + rect(rect_),angle(0),rows(dims.rows), cols(dims.cols) {} + chip_details(const drectangle& rect_, const chip_dims& dims, double angle_) : + rect(rect_),angle(angle_),rows(dims.rows), cols(dims.cols) {} + + template + chip_details( + const std::vector >& chip_points, + const std::vector >& img_points, + const chip_dims& dims + ) : + rows(dims.rows), cols(dims.cols) + { + DLIB_CASSERT( chip_points.size() == img_points.size() && chip_points.size() >= 2, + "\t chip_details::chip_details(chip_points,img_points,dims)" + << "\n\t Invalid inputs were given to this function." + << "\n\t chip_points.size(): " << chip_points.size() + << "\n\t img_points.size(): " << img_points.size() + ); + + const point_transform_affine tform = find_similarity_transform(chip_points,img_points); + dlib::vector p(1,0); + p = tform.get_m()*p; + + // There are only 3 things happening in a similarity transform. There is a + // rescaling, a rotation, and a translation. So here we pick out the scale and + // rotation parameters. + angle = std::atan2(p.y(),p.x()); + // Note that the translation and scale part are represented by the extraction + // rectangle. So here we build the appropriate rectangle. + const double scale = length(p); + rect = centered_drect(tform(point(dims.cols,dims.rows)/2.0), + dims.cols*scale, + dims.rows*scale); + } + + + drectangle rect; + double angle; + unsigned long rows; + unsigned long cols; + + inline unsigned long size() const + { + return rows*cols; + } + + private: + void compute_dims_from_size ( + unsigned long size + ) + { + const double relative_size = std::sqrt(size/(double)rect.area()); + rows = static_cast(rect.height()*relative_size + 0.5); + cols = static_cast(size/(double)rows + 0.5); + rows = std::max(1ul,rows); + cols = std::max(1ul,cols); + } + }; + +// ---------------------------------------------------------------------------------------- + + inline point_transform_affine get_mapping_to_chip ( + const chip_details& details + ) + { + std::vector > from, to; + point p1(0,0); + point p2(details.cols-1,0); + point p3(details.cols-1, details.rows-1); + to.push_back(p1); + from.push_back(rotate_point(center(details.rect),details.rect.tl_corner(),details.angle)); + to.push_back(p2); + from.push_back(rotate_point(center(details.rect),details.rect.tr_corner(),details.angle)); + to.push_back(p3); + from.push_back(rotate_point(center(details.rect),details.rect.br_corner(),details.angle)); + return find_affine_transform(from, to); + } + +// ---------------------------------------------------------------------------------------- + + inline full_object_detection map_det_to_chip( + const full_object_detection& det, + const chip_details& details + ) + { + point_transform_affine tform = get_mapping_to_chip(details); + full_object_detection res(det); + // map the parts + for (unsigned long l = 0; l < det.num_parts(); ++l) + { + if (det.part(l) != OBJECT_PART_NOT_PRESENT) + res.part(l) = tform(det.part(l)); + else + res.part(l) = OBJECT_PART_NOT_PRESENT; + } + // map the main rectangle + rectangle rect; + rect += tform(det.get_rect().tl_corner()); + rect += tform(det.get_rect().tr_corner()); + rect += tform(det.get_rect().bl_corner()); + rect += tform(det.get_rect().br_corner()); + res.get_rect() = rect; + return res; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename image_type1, + typename image_type2 + > + void basic_extract_image_chip ( + const image_type1& img, + const rectangle& location, + image_type2& chip + ) + /*! + ensures + - This function doesn't do any scaling or rotating. It just pulls out the + chip in the given rectangle. This also means the output image has the + same dimensions as the location rectangle. + !*/ + { + const_image_view vimg(img); + image_view vchip(chip); + + vchip.set_size(location.height(), location.width()); + + // location might go outside img so clip it + rectangle area = location.intersect(get_rect(img)); + + // find the part of the chip that corresponds to area in img. + rectangle chip_area = translate_rect(area, -location.tl_corner()); + + zero_border_pixels(chip, chip_area); + // now pull out the contents of area/chip_area. + for (long r = chip_area.top(), rr = area.top(); r <= chip_area.bottom(); ++r,++rr) + { + for (long c = chip_area.left(), cc = area.left(); c <= chip_area.right(); ++c,++cc) + { + assign_pixel(vchip[r][c], vimg[rr][cc]); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type + > + void extract_image_chips ( + const image_type1& img, + const std::vector& chip_locations, + dlib::array& chips, + const interpolation_type& interp + ) + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < chip_locations.size(); ++i) + { + DLIB_CASSERT(chip_locations[i].size() != 0 && + chip_locations[i].rect.is_empty() == false, + "\t void extract_image_chips()" + << "\n\t Invalid inputs were given to this function." + << "\n\t chip_locations["< pyr; + long max_depth = 0; + // If the chip is supposed to be much smaller than the source subwindow then you + // can't just extract it using bilinear interpolation since at a high enough + // downsampling amount it would effectively turn into nearest neighbor + // interpolation. So we use an image pyramid to make sure the interpolation is + // fast but also high quality. The first thing we do is figure out how deep the + // image pyramid needs to be. + rectangle bounding_box; + for (unsigned long i = 0; i < chip_locations.size(); ++i) + { + long depth = 0; + double grow = 2; + drectangle rect = pyr.rect_down(chip_locations[i].rect); + while (rect.area() > chip_locations[i].size()) + { + rect = pyr.rect_down(rect); + ++depth; + // We drop the image size by a factor of 2 each iteration and then assume a + // border of 2 pixels is needed to avoid any border effects of the crop. + grow = grow*2 + 2; + } + drectangle rot_rect; + const vector cent = center(chip_locations[i].rect); + rot_rect += rotate_point(cent,chip_locations[i].rect.tl_corner(),chip_locations[i].angle); + rot_rect += rotate_point(cent,chip_locations[i].rect.tr_corner(),chip_locations[i].angle); + rot_rect += rotate_point(cent,chip_locations[i].rect.bl_corner(),chip_locations[i].angle); + rot_rect += rotate_point(cent,chip_locations[i].rect.br_corner(),chip_locations[i].angle); + bounding_box += grow_rect(rot_rect, grow).intersect(get_rect(img)); + max_depth = std::max(depth,max_depth); + } + //std::cout << "max_depth: " << max_depth << std::endl; + //std::cout << "crop amount: " << bounding_box.area()/(double)get_rect(img).area() << std::endl; + + // now make an image pyramid + dlib::array::pixel_type> > levels(max_depth); + if (levels.size() != 0) + pyr(sub_image(img,bounding_box),levels[0]); + for (unsigned long i = 1; i < levels.size(); ++i) + pyr(levels[i-1],levels[i]); + + std::vector > from, to; + + // now pull out the chips + chips.resize(chip_locations.size()); + for (unsigned long i = 0; i < chips.size(); ++i) + { + // If the chip doesn't have any rotation or scaling then use the basic version + // of chip extraction that just does a fast copy. + if (chip_locations[i].angle == 0 && + chip_locations[i].rows == chip_locations[i].rect.height() && + chip_locations[i].cols == chip_locations[i].rect.width()) + { + impl::basic_extract_image_chip(img, chip_locations[i].rect, chips[i]); + } + else + { + set_image_size(chips[i], chip_locations[i].rows, chip_locations[i].cols); + + // figure out which level in the pyramid to use to extract the chip + int level = -1; + drectangle rect = translate_rect(chip_locations[i].rect, -bounding_box.tl_corner()); + while (pyr.rect_down(rect).area() > chip_locations[i].size()) + { + ++level; + rect = pyr.rect_down(rect); + } + + // find the appropriate transformation that maps from the chip to the input + // image + from.clear(); + to.clear(); + from.push_back(get_rect(chips[i]).tl_corner()); to.push_back(rotate_point(center(rect),rect.tl_corner(),chip_locations[i].angle)); + from.push_back(get_rect(chips[i]).tr_corner()); to.push_back(rotate_point(center(rect),rect.tr_corner(),chip_locations[i].angle)); + from.push_back(get_rect(chips[i]).bl_corner()); to.push_back(rotate_point(center(rect),rect.bl_corner(),chip_locations[i].angle)); + point_transform_affine trns = find_affine_transform(from,to); + + // now extract the actual chip + if (level == -1) + transform_image(sub_image(img,bounding_box),chips[i],interp,trns); + else + transform_image(levels[level],chips[i],interp,trns); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void extract_image_chips( + const image_type1& img, + const std::vector& chip_locations, + dlib::array& chips + ) + { + extract_image_chips(img, chip_locations, chips, interpolate_bilinear()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type + > + void extract_image_chip ( + const image_type1& img, + const chip_details& location, + image_type2& chip, + const interpolation_type& interp + ) + { + // If the chip doesn't have any rotation or scaling then use the basic version of + // chip extraction that just does a fast copy. + if (location.angle == 0 && + location.rows == location.rect.height() && + location.cols == location.rect.width()) + { + impl::basic_extract_image_chip(img, location.rect, chip); + } + else + { + std::vector chip_locations(1,location); + dlib::array chips; + extract_image_chips(img, chip_locations, chips, interp); + swap(chips[0], chip); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void extract_image_chip ( + const image_type1& img, + const chip_details& location, + image_type2& chip + ) + { + extract_image_chip(img, location, chip, interpolate_bilinear()); + } + +// ---------------------------------------------------------------------------------------- + + inline chip_details get_face_chip_details ( + const full_object_detection& det, + const unsigned long size = 200, + const double padding = 0.2 + ) + { + DLIB_CASSERT(det.num_parts() == 68 || det.num_parts() == 5, + "\t chip_details get_face_chip_details()" + << "\n\t You have to give either a 5 point or 68 point face landmarking output to this function. " + << "\n\t det.num_parts(): " << det.num_parts() + ); + DLIB_CASSERT(padding >= 0 && size > 0, + "\t chip_details get_face_chip_details()" + << "\n\t Invalid inputs were given to this function." + << "\n\t padding: " << padding + << "\n\t size: " << size + ); + + + std::vector from_points, to_points; + if (det.num_parts() == 5) + { + dpoint p0(0.8595674595992, 0.2134981538014); + dpoint p1(0.6460604764104, 0.2289674387677); + dpoint p2(0.1205750620789, 0.2137274526848); + dpoint p3(0.3340850613712, 0.2290642403242); + dpoint p4(0.4901123135679, 0.6277975316475); + + + p0 = (padding+p0)/(2*padding+1); + p1 = (padding+p1)/(2*padding+1); + p2 = (padding+p2)/(2*padding+1); + p3 = (padding+p3)/(2*padding+1); + p4 = (padding+p4)/(2*padding+1); + + from_points.push_back(p0*size); + to_points.push_back(det.part(0)); + + from_points.push_back(p1*size); + to_points.push_back(det.part(1)); + + from_points.push_back(p2*size); + to_points.push_back(det.part(2)); + + from_points.push_back(p3*size); + to_points.push_back(det.part(3)); + + from_points.push_back(p4*size); + to_points.push_back(det.part(4)); + } + else + { + // Average positions of face points 17-67 + const double mean_face_shape_x[] = { + 0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483, 0.799124, + 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127, 0.36688, 0.426036, + 0.490127, 0.554217, 0.613373, 0.121737, 0.187122, 0.265825, 0.334606, 0.260918, + 0.182743, 0.645647, 0.714428, 0.793132, 0.858516, 0.79751, 0.719335, 0.254149, + 0.340985, 0.428858, 0.490127, 0.551395, 0.639268, 0.726104, 0.642159, 0.556721, + 0.490127, 0.423532, 0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, + 0.553364, 0.490127, 0.42689 + }; + const double mean_face_shape_y[] = { + 0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891, + 0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625, 0.587326, + 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758, 0.179852, 0.231733, + 0.245099, 0.244077, 0.231733, 0.179852, 0.178758, 0.216423, 0.244077, 0.245099, + 0.780233, 0.745405, 0.727388, 0.742578, 0.727388, 0.745405, 0.780233, 0.864805, + 0.902192, 0.909281, 0.902192, 0.864805, 0.784792, 0.778746, 0.785343, 0.778746, + 0.784792, 0.824182, 0.831803, 0.824182 + }; + + COMPILE_TIME_ASSERT(sizeof(mean_face_shape_x)/sizeof(double) == 68-17); + + for (unsigned long i = 17; i < det.num_parts(); ++i) + { + // Ignore the lower lip + if ((55 <= i && i <= 59) || (65 <= i && i <= 67)) + continue; + // Ignore the eyebrows + if (17 <= i && i <= 26) + continue; + + dpoint p; + p.x() = (padding+mean_face_shape_x[i-17])/(2*padding+1); + p.y() = (padding+mean_face_shape_y[i-17])/(2*padding+1); + from_points.push_back(p*size); + to_points.push_back(det.part(i)); + } + } + + return chip_details(from_points, to_points, chip_dims(size,size)); + } + +// ---------------------------------------------------------------------------------------- + + inline std::vector get_face_chip_details ( + const std::vector& dets, + const unsigned long size = 200, + const double padding = 0.2 + ) + { + std::vector res; + res.reserve(dets.size()); + for (unsigned long i = 0; i < dets.size(); ++i) + res.push_back(get_face_chip_details(dets[i], size, padding)); + return res; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + image_type jitter_image( + const image_type& img, + dlib::rand& rnd + ) + { + DLIB_CASSERT(num_rows(img)*num_columns(img) != 0); + DLIB_CASSERT(num_rows(img)==num_columns(img)); + + const double max_rotation_degrees = 3; + const double min_object_height = 0.97; + const double max_object_height = 0.99999; + const double translate_amount = 0.02; + + + const auto rect = shrink_rect(get_rect(img),3); + + // perturb the location of the crop by a small fraction of the object's size. + const point rand_translate = dpoint(rnd.get_double_in_range(-translate_amount,translate_amount)*rect.width(), + rnd.get_double_in_range(-translate_amount,translate_amount)*rect.height()); + + // perturb the scale of the crop by a fraction of the object's size + const double rand_scale_perturb = rnd.get_double_in_range(min_object_height, max_object_height); + + const long box_size = rect.height()/rand_scale_perturb; + const auto crop_rect = centered_rect(center(rect)+rand_translate, box_size, box_size); + const double angle = rnd.get_double_in_range(-max_rotation_degrees, max_rotation_degrees)*pi/180; + image_type crop; + extract_image_chip(img, chip_details(crop_rect, chip_dims(img.nr(),img.nc()), angle), crop); + if (rnd.get_random_double() > 0.5) + flip_image_left_right(crop); + + return crop; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_INTERPOlATIONh_ + diff --git a/ml/dlib/dlib/image_transforms/interpolation_abstract.h b/ml/dlib/dlib/image_transforms/interpolation_abstract.h new file mode 100644 index 000000000..f2da2fb02 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/interpolation_abstract.h @@ -0,0 +1,1480 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_INTERPOlATION_ABSTRACT_ +#ifdef DLIB_INTERPOlATION_ABSTRACT_ + +#include "../pixel.h" +#include "../image_processing/full_object_detection_abstract.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class interpolate_nearest_neighbor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for performing nearest neighbor interpolation + on an image. + !*/ + + public: + + template < + typename image_view_type, + typename pixel_type + > + bool operator() ( + const image_view_type& img, + const dlib::point& p, + pixel_type& result + ) const; + /*! + requires + - image_view_type == an image_view or const_image_view object. + - pixel_traits::has_alpha == false + - pixel_traits is defined + ensures + - if (p is located inside img) then + - #result == img[p.y()][p.x()] + (This assignment is done using assign_pixel(#result, img[p.y()][p.x()]), + therefore any necessary color space conversion will be performed) + - returns true + - else + - returns false + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + class interpolate_bilinear + { + + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for performing bilinear interpolation + on an image. This is performed by looking at the 4 pixels + nearest to a point and deriving an interpolated value from them. + !*/ + + public: + + template < + typename T, + typename image_view_type, + typename pixel_type + > + bool operator() ( + const image_view_type& img, + const dlib::vector& p, + pixel_type& result + ) const; + /*! + requires + - image_view_type == an image_view or const_image_view object + - pixel_traits::has_alpha == false + - pixel_traits is defined + ensures + - if (there is an interpolatable image location at point p in img) then + - #result == the interpolated pixel value from img at point p. + - assign_pixel() will be used to write to #result, therefore any + necessary color space conversion will be performed. + - returns true + - if img contains RGB pixels then the interpolation will be in color. + Otherwise, the interpolation will be performed in a grayscale mode. + - else + - returns false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class interpolate_quadratic + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for performing quadratic interpolation + on an image. This is performed by looking at the 9 pixels + nearest to a point and deriving an interpolated value from them. + !*/ + + public: + + template < + typename T, + typename image_view_type, + typename pixel_type + > + bool operator() ( + const image_view_type& img, + const dlib::vector& p, + pixel_type& result + ) const; + /*! + requires + - image_view_type == an image_view or const_image_view object. + - pixel_traits::has_alpha == false + - pixel_traits is defined + ensures + - if (there is an interpolatable image location at point p in img) then + - #result == the interpolated pixel value from img at point p + - assign_pixel() will be used to write to #result, therefore any + necessary color space conversion will be performed. + - returns true + - if img contains RGB pixels then the interpolation will be in color. + Otherwise, the interpolation will be performed in a grayscale mode. + - else + - returns false + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class black_background + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a function object which simply sets a pixel + to have a black value. + !*/ + + public: + template + void operator() ( pixel_type& p) const { assign_pixel(p, 0); } + }; + +// ---------------------------------------------------------------------------------------- + + class white_background + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a function object which simply sets a pixel + to have a white value. + !*/ + + public: + template + void operator() ( pixel_type& p) const { assign_pixel(p, 255); } + }; + +// ---------------------------------------------------------------------------------------- + + class no_background + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a function object which does nothing. It is useful + when used with the transform_image() routine defined below + if no modification of uninterpolated output pixels is desired. + !*/ + public: + template + void operator() ( pixel_type& ) const { } + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type, + typename point_mapping_type, + typename background_type + > + void transform_image ( + const image_type1& in_img, + image_type2& out_img, + const interpolation_type& interp, + const point_mapping_type& map_point, + const background_type& set_background, + const rectangle& area + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear, + interpolate_quadratic, or a type with a compatible interface. + - map_point should be a function which takes dlib::vector objects and + returns dlib::vector objects. An example is point_transform_affine. + - set_background should be a function which can take a single argument of + type image_traits::pixel_type. Examples are black_background, + white_background, and no_background. + - get_rect(out_img).contains(area) == true + - is_same_object(in_img, out_img) == false + ensures + - The map_point function defines a mapping from pixels in out_img to pixels + in in_img. transform_image() uses this mapping, along with the supplied + interpolation routine interp, to fill the region of out_img defined by + area with an interpolated copy of in_img. + - This function does not change the size of out_img. + - Only pixels inside the region defined by area in out_img are modified. + - For all locations r and c such that area.contains(c,r) but have no corresponding + locations in in_img: + - set_background(out_img[r][c]) is invoked + (i.e. some parts of out_img might correspond to areas outside in_img and + therefore can't supply interpolated values. In these cases, these + pixels can be assigned a value by the supplied set_background() routine) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type, + typename point_mapping_type, + typename background_type + > + void transform_image ( + const image_type1& in_img, + image_type2& out_img, + const interpolation_type& interp, + const point_mapping_type& map_point, + const background_type& set_background + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear, + interpolate_quadratic, or a type with a compatible interface. + - map_point should be a function which takes dlib::vector objects and + returns dlib::vector objects. An example is point_transform_affine. + - set_background should be a function which can take a single argument of + type image_traits::pixel_type. Examples are black_background, white_background, + and no_background. + - is_same_object(in_img, out_img) == false + ensures + - performs: + transform_image(in_img, out_img, interp, map_point, set_background, get_rect(out_img)); + (i.e. runs transform_image() on the entire out_img) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type, + typename point_mapping_type + > + void transform_image ( + const image_type1& in_img, + image_type2& out_img, + const interpolation_type& interp, + const point_mapping_type& map_point + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear, + interpolate_quadratic, or a type with a compatible interface. + - map_point should be a function which takes dlib::vector objects and + returns dlib::vector objects. An example is point_transform_affine. + - is_same_object(in_img, out_img) == false + ensures + - performs: + transform_image(in_img, out_img, interp, map_point, black_background(), get_rect(out_img)); + (i.e. runs transform_image() on the entire out_img and sets non-interpolated + pixels to black) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type + > + point_transform_affine rotate_image ( + const image_type1& in_img, + image_type2& out_img, + double angle, + const interpolation_type& interp + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear, + interpolate_quadratic, or a type with a compatible interface. + - is_same_object(in_img, out_img) == false + ensures + - #out_img == a copy of in_img which has been rotated angle radians counter clockwise. + The rotation is performed with respect to the center of the image. + - Parts of #out_img which have no corresponding locations in in_img are set to black. + - uses the supplied interpolation routine interp to perform the necessary + pixel interpolation. + - returns a transformation object that maps points in in_img into their corresponding + location in #out_img. + !*/ + +// ---------------------------------------------------------------------------------------- + + + template < + typename image_type1, + typename image_type2 + > + point_transform_affine rotate_image ( + const image_type1& in_img, + image_type2& out_img, + double angle + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits::pixel_type>::has_alpha == false + - is_same_object(in_img, out_img) == false + ensures + - #out_img == a copy of in_img which has been rotated angle radians counter clockwise. + The rotation is performed with respect to the center of the image. + - Parts of #out_img which have no corresponding locations in in_img are set to black. + - uses the interpolate_quadratic object to perform the necessary pixel interpolation. + - returns a transformation object that maps points in in_img into their corresponding + location in #out_img. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type + > + void resize_image ( + const image_type1& in_img, + image_type2& out_img, + const interpolation_type& interp + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear, + interpolate_quadratic, or a type with a compatible interface. + - is_same_object(in_img, out_img) == false + ensures + - #out_img == A copy of in_img which has been stretched so that it + fits exactly into out_img. + - The size of out_img is not modified. I.e. + - #out_img.nr() == out_img.nr() + - #out_img.nc() == out_img.nc() + - uses the supplied interpolation routine interp to perform the necessary + pixel interpolation. + !*/ + +// ---------------------------------------------------------------------------------------- + + + template < + typename image_type1, + typename image_type2 + > + void resize_image ( + const image_type1& in_img, + image_type2& out_img + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits::pixel_type>::has_alpha == false + - is_same_object(in_img, out_img) == false + ensures + - #out_img == A copy of in_img which has been stretched so that it + fits exactly into out_img. + - The size of out_img is not modified. I.e. + - #out_img.nr() == out_img.nr() + - #out_img.nc() == out_img.nc() + - Uses the bilinear interpolation to perform the necessary pixel interpolation. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void resize_image ( + double size_scale, + image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits::pixel_type>::has_alpha == false + ensures + - Resizes img so that each of it's dimensions are size_scale times larger than img. + In particular, we will have: + - #img.nr() == std::round(size_scale*img.nr()) + - #img.nc() == std::round(size_scale*img.nc()) + - #img == a bilinearly interpolated copy of the input image. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + point_transform_affine flip_image_left_right ( + const image_type1& in_img, + image_type2& out_img + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - is_same_object(in_img, out_img) == false + ensures + - #out_img.nr() == in_img.nr() + - #out_img.nc() == in_img.nc() + - #out_img == a copy of in_img which has been flipped from left to right. + (i.e. it is flipped as if viewed though a mirror) + - returns a transformation object that maps points in in_img into their + corresponding location in #out_img. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + point_transform_affine flip_image_left_right ( + image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - This function is identical to the above version of flip_image_left_right() + except that it operates in-place. + - #img.nr() == img.nr() + - #img.nc() == img.nc() + - #img == a copy of img which has been flipped from left to right. + (i.e. it is flipped as if viewed though a mirror) + - returns a transformation object that maps points in img into their + corresponding location in #img. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename T + > + void add_image_left_right_flips ( + image_array_type& images, + std::vector >& objects + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - T == rectangle, full_object_detection, or mmod_rect + - images.size() == objects.size() + ensures + - This function computes all the left/right flips of the contents of images and + then appends them onto the end of the images array. It also finds the + left/right flips of the rectangles in objects and similarly appends them into + objects. That is, we assume objects[i] is the set of bounding boxes in + images[i] and we flip the bounding boxes so that they still bound the same + objects in the new flipped images. + - #images.size() == images.size()*2 + - #objects.size() == objects.size()*2 + - All the original elements of images and objects are left unmodified. That + is, this function only appends new elements to each of these containers. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename T, + typename U + > + void add_image_left_right_flips ( + image_array_type& images, + std::vector >& objects, + std::vector >& objects2 + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + - images.size() == objects2.size() + - T == rectangle, full_object_detection, or mmod_rect + - U == rectangle, full_object_detection, or mmod_rect + ensures + - This function computes all the left/right flips of the contents of images and + then appends them onto the end of the images array. It also finds the + left/right flips of the rectangles in objects and objects2 and similarly + appends them into objects and objects2 respectively. That is, we assume + objects[i] is the set of bounding boxes in images[i] and we flip the bounding + boxes so that they still bound the same objects in the new flipped images. + We similarly flip the boxes in objects2. + - #images.size() == images.size()*2 + - #objects.size() == objects.size()*2 + - #objects2.size() == objects2.size()*2 + - All the original elements of images, objects, and objects2 are left unmodified. + That is, this function only appends new elements to each of these containers. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename EXP, + typename T, + typename U + > + void add_image_rotations ( + const matrix_exp& angles, + image_array_type& images, + std::vector >& objects, + std::vector >& objects2 + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - is_vector(angles) == true + - angles.size() > 0 + - images.size() == objects.size() + - images.size() == objects2.size() + - T == rectangle, full_object_detection, or mmod_rect + - U == rectangle, full_object_detection, or mmod_rect + ensures + - This function computes angles.size() different rotations of all the given + images and then replaces the contents of images with those rotations of the + input dataset. We will also adjust the rectangles inside objects and + objects2 so that they still bound the same objects in the new rotated images. + That is, we assume objects[i] and objects2[i] are bounding boxes for things + in images[i]. So we will adjust the positions of the boxes in objects and + objects2 accordingly. + - The elements of angles are interpreted as angles in radians and we will + rotate the images around their center using the values in angles. Moreover, + the rotation is done counter clockwise. + - #images.size() == images.size()*angles.size() + - #objects.size() == objects.size()*angles.size() + - #objects2.size() == objects2.size()*angles.size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename EXP, + typename T + > + void add_image_rotations ( + const matrix_exp& angles, + image_array_type& images, + std::vector >& objects + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - is_vector(angles) == true + - angles.size() > 0 + - images.size() == objects.size() + - T == rectangle, full_object_detection, or mmod_rect + ensures + - This function is identical to the add_image_rotations() define above except + that it doesn't have objects2 as an argument. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + void flip_image_dataset_left_right ( + image_array_type& images, + std::vector >& objects + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + ensures + - This function replaces each image in images with the left/right flipped + version of the image. Therefore, #images[i] will contain the left/right + flipped version of images[i]. It also flips all the rectangles in objects so + that they still bound the same visual objects in each image. + - #images.size() == image.size() + - #objects.size() == objects.size() + - for all valid i: + #objects[i].size() == objects[i].size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + void flip_image_dataset_left_right ( + image_array_type& images, + std::vector >& objects, + std::vector >& objects2 + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + - images.size() == objects2.size() + ensures + - This function replaces each image in images with the left/right flipped + version of the image. Therefore, #images[i] will contain the left/right + flipped version of images[i]. It also flips all the rectangles in objects + and objects2 so that they still bound the same visual objects in each image. + - #images.size() == image.size() + - #objects.size() == objects.size() + - #objects2.size() == objects2.size() + - for all valid i: + #objects[i].size() == objects[i].size() + - for all valid i: + #objects2[i].size() == objects2[i].size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type, + typename image_array_type + > + void upsample_image_dataset ( + image_array_type& images, + std::vector >& objects, + unsigned long max_image_size = std::numeric_limits::max() + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + ensures + - This function replaces each image in images with an upsampled version of that + image. Each image is upsampled using pyramid_up() and the given + pyramid_type. Therefore, #images[i] will contain the larger upsampled + version of images[i]. It also adjusts all the rectangles in objects so that + they still bound the same visual objects in each image. + - Input images already containing more than max_image_size pixels are not upsampled. + - #images.size() == image.size() + - #objects.size() == objects.size() + - for all valid i: + #objects[i].size() == objects[i].size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type, + typename image_array_type + > + void upsample_image_dataset ( + image_array_type& images, + std::vector>& objects, + unsigned long max_image_size = std::numeric_limits::max() + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + ensures + - This function replaces each image in images with an upsampled version of that + image. Each image is upsampled using pyramid_up() and the given + pyramid_type. Therefore, #images[i] will contain the larger upsampled + version of images[i]. It also adjusts all the rectangles in objects so that + they still bound the same visual objects in each image. + - Input images already containing more than max_image_size pixels are not upsampled. + - #images.size() == image.size() + - #objects.size() == objects.size() + - for all valid i: + #objects[i].size() == objects[i].size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename pyramid_type, + typename image_array_type, + > + void upsample_image_dataset ( + image_array_type& images, + std::vector >& objects, + std::vector >& objects2, + unsigned long max_image_size = std::numeric_limits::max() + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + - images.size() == objects2.size() + ensures + - This function replaces each image in images with an upsampled version of that + image. Each image is upsampled using pyramid_up() and the given + pyramid_type. Therefore, #images[i] will contain the larger upsampled + version of images[i]. It also adjusts all the rectangles in objects and + objects2 so that they still bound the same visual objects in each image. + - Input images already containing more than max_image_size pixels are not upsampled. + - #images.size() == image.size() + - #objects.size() == objects.size() + - #objects2.size() == objects2.size() + - for all valid i: + #objects[i].size() == objects[i].size() + - for all valid i: + #objects2[i].size() == objects2[i].size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void rotate_image_dataset ( + double angle, + image_array_type& images, + std::vector >& objects + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + ensures + - This function replaces each image in images with a rotated version of that + image. In particular, each image is rotated using + rotate_image(original,rotated,angle). Therefore, the images are rotated + angle radians counter clockwise around their centers. That is, #images[i] + will contain the rotated version of images[i]. It also adjusts all + the rectangles in objects so that they still bound the same visual objects in + each image. + - All the rectangles will still have the same sizes and aspect ratios after + rotation. They will simply have had their positions adjusted so they still + fall on the same objects. + - #images.size() == image.size() + - #objects.size() == objects.size() + - for all valid i: + #objects[i].size() == objects[i].size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void rotate_image_dataset ( + double angle, + image_array_type& images, + std::vector >& objects, + std::vector >& objects2 + ); + /*! + requires + - image_array_type == a dlib::array or std::vector of image objects that each + implement the interface defined in dlib/image_processing/generic_image.h + - images.size() == objects.size() + - images.size() == objects2.size() + ensures + - This function replaces each image in images with a rotated version of that + image. In particular, each image is rotated using + rotate_image(original,rotated,angle). Therefore, the images are rotated + angle radians counter clockwise around their centers. That is, #images[i] + will contain the rotated version of images[i]. It also adjusts all + the rectangles in objects and objects2 so that they still bound the same + visual objects in each image. + - All the rectangles will still have the same sizes and aspect ratios after + rotation. They will simply have had their positions adjusted so they still + fall on the same objects. + - #images.size() == image.size() + - #objects.size() == objects.size() + - #objects2.size() == objects2.size() + - for all valid i: + #objects[i].size() == objects[i].size() + - for all valid i: + #objects2[i].size() == objects2[i].size() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void flip_image_up_down ( + const image_type1& in_img, + image_type2& out_img + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - is_same_object(in_img, out_img) == false + ensures + - #out_img.nr() == in_img.nr() + - #out_img.nc() == in_img.nc() + - #out_img == a copy of in_img which has been flipped upside down. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename pyramid_type, + typename interpolation_type + > + void pyramid_up ( + const image_type1& in_img, + image_type2& out_img, + const pyramid_type& pyr, + const interpolation_type& interp + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pyramid_type == a type compatible with the image pyramid objects defined + in dlib/image_transforms/image_pyramid_abstract.h + - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear, + interpolate_quadratic, or a type with a compatible interface. + - is_same_object(in_img, out_img) == false + ensures + - This function inverts the downsampling transformation performed by pyr(). + In particular, it attempts to make an image, out_img, which would result + in in_img when downsampled with pyr(). + - #out_img == An upsampled copy of in_img. In particular, downsampling + #out_img 1 time with pyr() should result in a final image which looks like + in_img. + - Uses the supplied interpolation routine interp to perform the necessary + pixel interpolation. + - Note that downsampling an image with pyr() and then upsampling it with + pyramid_up() will not necessarily result in a final image which is + the same size as the original. This is because the exact size of the + original image cannot be determined based on the downsampled image. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename pyramid_type + > + void pyramid_up ( + const image_type1& in_img, + image_type2& out_img, + const pyramid_type& pyr + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pyramid_type == a type compatible with the image pyramid objects defined + in dlib/image_transforms/image_pyramid_abstract.h + - is_same_object(in_img, out_img) == false + ensures + - performs: pyramid_up(in_img, out_img, pyr, interpolate_bilinear()); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename pyramid_type + > + void pyramid_up ( + image_type& img, + const pyramid_type& pyr + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pyramid_type == a type compatible with the image pyramid objects defined + in dlib/image_transforms/image_pyramid_abstract.h + ensures + - Performs an in-place version of pyramid_up() on the given image. In + particular, this function is equivalent to: + pyramid_up(img, temp, pyr); + temp.swap(img); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void pyramid_up ( + image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - performs: pyramid_up(img, pyramid_down<2>()); + (i.e. it upsamples the given image and doubles it in size.) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct chip_dims + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple tool for passing in a pair of row and column values to the + chip_details constructor. + !*/ + + chip_dims ( + unsigned long rows_, + unsigned long cols_ + ) : rows(rows_), cols(cols_) { } + + unsigned long rows; + unsigned long cols; + }; + +// ---------------------------------------------------------------------------------------- + + struct chip_details + { + /*! + WHAT THIS OBJECT REPRESENTS + This object describes where an image chip is to be extracted from within + another image. In particular, it specifies that the image chip is + contained within the rectangle this->rect and that prior to extraction the + image should be rotated counter-clockwise by this->angle radians. Finally, + the extracted chip should have this->rows rows and this->cols columns in it + regardless of the shape of this->rect. This means that the extracted chip + will be stretched to fit via bilinear interpolation when necessary. + !*/ + + chip_details( + ); + /*! + ensures + - #rect.is_empty() == true + - #size() == 0 + - #angle == 0 + - #rows == 0 + - #cols == 0 + !*/ + + chip_details( + const drectangle& rect_ + ); + /*! + ensures + - #rect == rect_ + - #size() == rect_.area() + - #angle == 0 + - #rows == rect_.height() + - #cols == rect_.width() + !*/ + + chip_details( + const rectangle& rect_ + ); + /*! + ensures + - #rect == rect_ + - #size() == rect_.area() + - #angle == 0 + - #rows == rect_.height() + - #cols == rect_.width() + !*/ + + chip_details( + const drectangle& rect_, + unsigned long size_ + ); + /*! + ensures + - #rect == rect_ + - #size() == size_ + - #angle == 0 + - #rows and #cols is set such that the total size of the chip is as close + to size_ as possible but still matches the aspect ratio of rect_. + - As long as size_ and the aspect ratio of of rect_ stays constant then + #rows and #cols will always have the same values. This means that, for + example, if you want all your chips to have the same dimensions then + ensure that size_ is always the same and also that rect_ always has the + same aspect ratio. Otherwise the calculated values of #rows and #cols + may be different for different chips. Alternatively, you can use the + chip_details constructor below that lets you specify the exact values for + rows and cols. + !*/ + + chip_details( + const drectangle& rect_, + unsigned long size_, + double angle_ + ); + /*! + ensures + - #rect == rect_ + - #size() == size_ + - #angle == angle_ + - #rows and #cols is set such that the total size of the chip is as close + to size_ as possible but still matches the aspect ratio of rect_. + - As long as size_ and the aspect ratio of of rect_ stays constant then + #rows and #cols will always have the same values. This means that, for + example, if you want all your chips to have the same dimensions then + ensure that size_ is always the same and also that rect_ always has the + same aspect ratio. Otherwise the calculated values of #rows and #cols + may be different for different chips. Alternatively, you can use the + chip_details constructor below that lets you specify the exact values for + rows and cols. + !*/ + + chip_details( + const drectangle& rect_, + const chip_dims& dims + ); + /*! + ensures + - #rect == rect_ + - #size() == dims.rows*dims.cols + - #angle == 0 + - #rows == dims.rows + - #cols == dims.cols + !*/ + + chip_details( + const drectangle& rect_, + const chip_dims& dims, + double angle_ + ); + /*! + ensures + - #rect == rect_ + - #size() == dims.rows*dims.cols + - #angle == angle_ + - #rows == dims.rows + - #cols == dims.cols + !*/ + + template + chip_details( + const std::vector >& chip_points, + const std::vector >& img_points, + const chip_dims& dims + ); + /*! + requires + - chip_points.size() == img_points.size() + - chip_points.size() >= 2 + ensures + - The chip will be extracted such that the pixel locations chip_points[i] + in the chip are mapped to img_points[i] in the original image by a + similarity transform. That is, if you know the pixelwize mapping you + want between the chip and the original image then you use this function + of chip_details constructor to define the mapping. + - #rows == dims.rows + - #cols == dims.cols + - #size() == dims.rows*dims.cols + - #rect and #angle are computed based on the given size of the output chip + (specified by dims) and the similarity transform between the chip and + image (specified by chip_points and img_points). + !*/ + + inline unsigned long size() const { return rows*cols; } + /*! + ensures + - returns the number of pixels in this chip. This is just rows*cols. + !*/ + + drectangle rect; + double angle; + unsigned long rows; + unsigned long cols; + }; + +// ---------------------------------------------------------------------------------------- + + point_transform_affine get_mapping_to_chip ( + const chip_details& details + ); + /*! + ensures + - returns a transformation that maps from the pixels in the original image + to the pixels in the cropped image defined by the given details object. + !*/ + +// ---------------------------------------------------------------------------------------- + + full_object_detection map_det_to_chip ( + const full_object_detection& det, + const chip_details& details + ); + /*! + ensures + - Maps the given detection into the pixel space of the image chip defined by + the given details object. That is, this function returns an object D such + that: + - D.get_rect() == a box that bounds the same thing in the image chip as + det.get_rect() bounds in the original image the chip is extracted from. + - for all valid i: + - D.part(i) == the location in the image chip corresponding to + det.part(i) in the original image. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type + > + void extract_image_chips ( + const image_type1& img, + const std::vector& chip_locations, + dlib::array& chips, + const interpolation_type& interp + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits::pixel_type>::has_alpha == false + - for all valid i: + - chip_locations[i].rect.is_empty() == false + - chip_locations[i].size() != 0 + - interpolation_type == interpolate_nearest_neighbor, interpolate_bilinear, + interpolate_quadratic, or a type with a compatible interface. + ensures + - This function extracts "chips" from an image. That is, it takes a list of + rectangular sub-windows (i.e. chips) within an image and extracts those + sub-windows, storing each into its own image. It also scales and rotates the + image chips according to the instructions inside each chip_details object. + It uses the interpolation method supplied as a parameter. + - #chips == the extracted image chips + - #chips.size() == chip_locations.size() + - for all valid i: + - #chips[i] == The image chip extracted from the position + chip_locations[i].rect in img. + - #chips[i].nr() == chip_locations[i].rows + - #chips[i].nc() == chip_locations[i].cols + - The image will have been rotated counter-clockwise by + chip_locations[i].angle radians, around the center of + chip_locations[i].rect, before the chip was extracted. + - Any pixels in an image chip that go outside img are set to 0 (i.e. black). + !*/ + + template < + typename image_type1, + typename image_type2 + > + void extract_image_chips ( + const image_type1& img, + const std::vector& chip_locations, + dlib::array& chips + ); + /*! + ensures + - This function is a simple convenience / compatibility wrapper that calls the + above-defined extract_image_chips() function using bilinear interpolation. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2, + typename interpolation_type + > + void extract_image_chip ( + const image_type1& img, + const chip_details& chip_location, + image_type2& chip, + const interpolation_type& interp + ); + /*! + ensures + - This function simply calls extract_image_chips() with a single chip location + and stores the single output chip into #chip. It uses the provided + interpolation method. + !*/ + + template < + typename image_type1, + typename image_type2 + > + void extract_image_chip ( + const image_type1& img, + const chip_details& chip_location, + image_type2& chip + ); + /*! + ensures + - This function is a simple convenience / compatibility wrapper that calls the + above-defined extract_image_chip() function using bilinear interpolation. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + struct sub_image_proxy + { + /*! + REQUIREMENTS ON image_type + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + + WHAT THIS OBJECT REPRESENTS + This is a lightweight image object for referencing a subwindow of an image. + It implements the generic image interface and can therefore be used with + any function that expects a generic image, excepting that you cannot change + the size of a sub_image_proxy. + + Note that it only stores a pointer to the image data given to its + constructor and therefore does not perform a copy. Moreover, this means + that an instance of this object becomes invalid after the underlying image + data it references is destroyed. + !*/ + sub_image_proxy ( + T& img, + const rectangle& rect + ); + /*! + ensures + - This object is an image that represents the part of img contained within + rect. If rect is larger than img then rect is cropped so that it does + not go outside img. + !*/ + }; + + template < + typename image_type + > + sub_image_proxy sub_image ( + image_type& img, + const rectangle& rect + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - returns sub_image_proxy(img,rect) + !*/ + + template + sub_image_proxy sub_image ( + T* img, + long nr, + long nc, + long row_stride + ); + /*! + requires + - img == a pointer to at least nr*row_stride T objects + - nr >= 0 + - nc >= 0 + - row_stride >= 0 + ensures + - This function returns an image that is just a thin wrapper around the given + pointer. It will have the dimensions defined by the supplied longs. To be + precise, this function returns an image object IMG such that: + - image_data(IMG) == img + - num_rows(IMG) == nr + - num_columns(IMG) == nc + - width_step(IMG) == row_stride*sizeof(T) + - IMG contains pixels of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + struct const_sub_image_proxy + { + /*! + REQUIREMENTS ON image_type + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + + WHAT THIS OBJECT REPRESENTS + This object is just like sub_image_proxy except that it does not allow the + pixel data to be modified. + !*/ + const_sub_image_proxy ( + const T& img, + const rectangle& rect + ); + /*! + ensures + - This object is an image that represents the part of img contained within + rect. If rect is larger than img then rect is cropped so that it does + not go outside img. + !*/ + }; + + template < + typename image_type + > + const const_sub_image_proxy sub_image ( + const image_type& img, + const rectangle& rect + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - returns const_sub_image_proxy(img,rect) + !*/ + + template + const const_sub_image_proxy sub_image ( + const T* img, + long nr, + long nc, + long row_stride + ); + /*! + requires + - img == a pointer to at least nr*row_stride T objects + - nr >= 0 + - nc >= 0 + - row_stride >= 0 + ensures + - This function returns an image that is just a thin wrapper around the given + pointer. It will have the dimensions defined by the supplied longs. To be + precise, this function returns an image object IMG such that: + - image_data(IMG) == img + - num_rows(IMG) == nr + - num_columns(IMG) == nc + - width_step(IMG) == row_stride*sizeof(T) + - IMG contains pixels of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + chip_details get_face_chip_details ( + const full_object_detection& det, + const unsigned long size = 200, + const double padding = 0.2 + ); + /*! + requires + - det.num_parts() == 68 || det.num_parts() == 5 + - size > 0 + - padding >= 0 + ensures + - This function assumes det contains a human face detection with face parts + annotated using the annotation scheme from the iBUG 300-W face landmark + dataset or a 5 point face annotation. Given these assumptions, it creates a + chip_details object that will extract a copy of the face that has been + rotated upright, centered, and scaled to a standard size when given to + extract_image_chip(). + - This function is specifically calibrated to work with one of these models: + - http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2 + - http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 + - The extracted chips will have size rows and columns in them. + - if padding == 0 then the chip will be closely cropped around the face. + Setting larger padding values will result a looser cropping. In particular, + a padding of 0.5 would double the width of the cropped area, a value of 1 + would triple it, and so forth. + - The 5 point face annotation scheme is assumed to be: + - det part 0 == left eye corner, outside part of eye. + - det part 1 == left eye corner, inside part of eye. + - det part 2 == right eye corner, outside part of eye. + - det part 3 == right eye corner, inside part of eye. + - det part 4 == immediately under the nose, right at the top of the philtrum. + !*/ + +// ---------------------------------------------------------------------------------------- + + std::vector get_face_chip_details ( + const std::vector& dets, + const unsigned long size = 200, + const double padding = 0.2 + ); + /*! + requires + - for all valid i: + - det[i].num_parts() == 68 + - size > 0 + - padding >= 0 + ensures + - This function is identical to the version of get_face_chip_details() defined + above except that it creates and returns an array of chip_details objects, + one for each input full_object_detection. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + image_type jitter_image( + const image_type& img, + dlib::rand& rnd + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - pixel_traits::pixel_type>::has_alpha == false + - img.size() > 0 + - img.nr() == img.nc() + ensures + - Randomly jitters the image a little bit and returns this new jittered image. + To be specific, the returned image has the same size as img and will look + generally similar. The difference is that the returned image will have been + slightly rotated, zoomed, and translated. There is also a 50% chance it will + be mirrored left to right. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_INTERPOlATION_ABSTRACT_ + diff --git a/ml/dlib/dlib/image_transforms/label_connected_blobs.h b/ml/dlib/dlib/image_transforms/label_connected_blobs.h new file mode 100644 index 000000000..c25346c76 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/label_connected_blobs.h @@ -0,0 +1,188 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LABEL_CONNeCTED_BLOBS_H_ +#define DLIB_LABEL_CONNeCTED_BLOBS_H_ + +#include "label_connected_blobs_abstract.h" +#include "../geometry.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct neighbors_8 + { + void operator() ( + const point& p, + std::vector& neighbors + ) const + { + neighbors.push_back(point(p.x()+1,p.y()+1)); + neighbors.push_back(point(p.x()+1,p.y() )); + neighbors.push_back(point(p.x()+1,p.y()-1)); + + neighbors.push_back(point(p.x(),p.y()+1)); + neighbors.push_back(point(p.x(),p.y()-1)); + + neighbors.push_back(point(p.x()-1,p.y()+1)); + neighbors.push_back(point(p.x()-1,p.y() )); + neighbors.push_back(point(p.x()-1,p.y()-1)); + } + }; + + struct neighbors_4 + { + void operator() ( + const point& p, + std::vector& neighbors + ) const + { + neighbors.push_back(point(p.x()+1,p.y())); + neighbors.push_back(point(p.x()-1,p.y())); + neighbors.push_back(point(p.x(),p.y()+1)); + neighbors.push_back(point(p.x(),p.y()-1)); + } + }; + +// ---------------------------------------------------------------------------------------- + + struct connected_if_both_not_zero + { + template + bool operator() ( + const image_type& img, + const point& a, + const point& b + ) const + { + return (img[a.y()][a.x()] != 0 && img[b.y()][b.x()] != 0); + } + }; + + struct connected_if_equal + { + template + bool operator() ( + const image_type& img, + const point& a, + const point& b + ) const + { + return (img[a.y()][a.x()] == img[b.y()][b.x()]); + } + }; + +// ---------------------------------------------------------------------------------------- + + struct zero_pixels_are_background + { + template + bool operator() ( + const image_type& img, + const point& p + ) const + { + return img[p.y()][p.x()] == 0; + } + + }; + + struct nothing_is_background + { + template + bool operator() ( + const image_type&, + const point& + ) const + { + return false; + } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename label_image_type, + typename background_functor_type, + typename neighbors_functor_type, + typename connected_functor_type + > + unsigned long label_connected_blobs ( + const image_type& img_, + const background_functor_type& is_background, + const neighbors_functor_type& get_neighbors, + const connected_functor_type& is_connected, + label_image_type& label_img_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_same_object(img_, label_img_) == false, + "\t unsigned long label_connected_blobs()" + << "\n\t The input image and output label image can't be the same object." + ); + + const_image_view img(img_); + image_view label_img(label_img_); + + std::stack neighbors; + label_img.set_size(img.nr(), img.nc()); + assign_all_pixels(label_img, 0); + unsigned long next = 1; + + if (img.size() == 0) + return 0; + + const rectangle area = get_rect(img); + + std::vector window; + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + // skip already labeled pixels or background pixels + if (label_img[r][c] != 0 || is_background(img,point(c,r))) + continue; + + label_img[r][c] = next; + + // label all the neighbors of this point + neighbors.push(point(c,r)); + while (neighbors.size() > 0) + { + const point p = neighbors.top(); + neighbors.pop(); + + window.clear(); + get_neighbors(p, window); + + for (unsigned long i = 0; i < window.size(); ++i) + { + if (area.contains(window[i]) && // point in image. + !is_background(img,window[i]) && // isn't background. + label_img[window[i].y()][window[i].x()] == 0 && // haven't already labeled it. + is_connected(img, p, window[i])) // it's connected. + { + label_img[window[i].y()][window[i].x()] = next; + neighbors.push(window[i]); + } + } + } + + ++next; + } + } + + return next; + } +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LABEL_CONNeCTED_BLOBS_H_ + diff --git a/ml/dlib/dlib/image_transforms/label_connected_blobs_abstract.h b/ml/dlib/dlib/image_transforms/label_connected_blobs_abstract.h new file mode 100644 index 000000000..5dc984000 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/label_connected_blobs_abstract.h @@ -0,0 +1,199 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LABEL_CONNeCTED_BLOBS_ABSTRACT_H_ +#ifdef DLIB_LABEL_CONNeCTED_BLOBS_ABSTRACT_H_ + +#include "../geometry.h" +#include +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct neighbors_8 + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a pixel neighborhood generating functor for + use with the label_connected_blobs() routine defined below. + !*/ + + void operator() ( + const point& p, + std::vector& neighbors + ) const; + /*! + ensures + - adds the 8 neighboring pixels surrounding p into neighbors + !*/ + }; + + struct neighbors_4 + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a pixel neighborhood generating functor for + use with the label_connected_blobs() routine defined below. + !*/ + + void operator() ( + const point& p, + std::vector& neighbors + ) const; + /*! + ensures + - adds the 4 neighboring pixels of p into neighbors. These + are the ones immediately to the left, top, right, and bottom. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + struct connected_if_both_not_zero + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a pixel connection testing functor for use + with the label_connected_blobs() routine defined below. + !*/ + + template + bool operator() ( + const image_view_type& img, + const point& a, + const point& b + ) const + { + return (img[a.y()][a.x()] != 0 && img[b.y()][b.x()] != 0); + } + }; + + struct connected_if_equal + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a pixel connection testing functor for use + with the label_connected_blobs() routine defined below. + !*/ + + template + bool operator() ( + const image_view_type& img, + const point& a, + const point& b + ) const + { + return (img[a.y()][a.x()] == img[b.y()][b.x()]); + } + }; + +// ---------------------------------------------------------------------------------------- + + struct zero_pixels_are_background + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a background testing functor for use + with the label_connected_blobs() routine defined below. + !*/ + + template + bool operator() ( + const image_view_type& img, + const point& p + ) const + { + return img[p.y()][p.x()] == 0; + } + + }; + + struct nothing_is_background + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a background testing functor for use + with the label_connected_blobs() routine defined below. + !*/ + + template + bool operator() ( + const image_view_type&, + const point& + ) const + { + return false; + } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename label_image_type, + typename background_functor_type, + typename neighbors_functor_type, + typename connected_functor_type + > + unsigned long label_connected_blobs ( + const image_type& img, + const background_functor_type& is_background, + const neighbors_functor_type& get_neighbors, + const connected_functor_type& is_connected, + label_image_type& label_img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - label_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain integer pixels. + - is_background(img, point(c,r)) is a legal expression that evaluates to a bool. + - is_connected(img, point(c,r), point(c2,r2)) is a legal expression that + evaluates to a bool. + - get_neighbors(point(c,r), neighbors) is a legal expression where neighbors + is of type std::vector. + - is_same_object(img, label_img) == false + ensures + - This function labels each of the connected blobs in img with a unique integer + label. + - An image can be thought of as a graph where pixels A and B are connected if + and only if the following two statements are satisfied: + - is_connected(img,A,B) == true + - get_neighbors(A, neighbors) results in neighbors containing B or + get_neighbors(B, neighbors) results in neighbors containing A. + Then this function can be understood as labeling all the connected components + of this pixel graph such that all pixels in a component get the same label while + pixels in different components get different labels. Note that there is a + special "background" component determined by is_background(). Any pixels which + are "background" always get a blob id of 0 regardless of any other considerations. + - #label_img.nr() == img.nr() + - #label_img.nc() == img.nc() + - for all valid r and c: + - #label_img[r][c] == the blob label number for pixel img[r][c]. + - #label_img[r][c] >= 0 + - if (is_background(img, point(c,r))) then + - #label_img[r][c] == 0 + - else + - #label_img[r][c] != 0 + - if (img.size() != 0) then + - returns max(mat(#label_img))+1 + (i.e. returns a number one greater than the maximum blob id number, + this is the number of blobs found.) + - else + - returns 0 + - blob labels are contiguous, therefore, the number returned by this function is + the number of blobs in the image (including the background blob). + - It is guaranteed that is_connected() and is_background() will never be + called with points outside the image. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LABEL_CONNeCTED_BLOBS_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/image_transforms/lbp.h b/ml/dlib/dlib/image_transforms/lbp.h new file mode 100644 index 000000000..b6bbac9cf --- /dev/null +++ b/ml/dlib/dlib/image_transforms/lbp.h @@ -0,0 +1,307 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LBP_Hh_ +#define DLIB_LBP_Hh_ + +#include "lbp_abstract.h" +#include "../image_processing/generic_image.h" +#include "assign_image.h" +#include "../pixel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename image_type2 + > + void make_uniform_lbp_image ( + const image_type& img_, + image_type2& lbp_ + ) + { + const static unsigned char uniform_lbps[] = { + 0, 1, 2, 3, 4, 58, 5, 6, 7, 58, 58, 58, 8, 58, 9, 10, 11, 58, 58, 58, 58, 58, + 58, 58, 12, 58, 58, 58, 13, 58, 14, 15, 16, 58, 58, 58, 58, 58, 58, 58, 58, 58, + 58, 58, 58, 58, 58, 58, 17, 58, 58, 58, 58, 58, 58, 58, 18, 58, 58, 58, 19, 58, + 20, 21, 22, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, + 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 23, 58, 58, 58, 58, 58, + 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 24, 58, 58, 58, 58, 58, 58, 58, 25, 58, + 58, 58, 26, 58, 27, 28, 29, 30, 58, 31, 58, 58, 58, 32, 58, 58, 58, 58, 58, 58, + 58, 33, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 34, 58, 58, + 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, + 58, 58, 58, 58, 58, 58, 58, 58, 58, 35, 36, 37, 58, 38, 58, 58, 58, 39, 58, 58, + 58, 58, 58, 58, 58, 40, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, + 58, 41, 42, 43, 58, 44, 58, 58, 58, 45, 58, 58, 58, 58, 58, 58, 58, 46, 47, 48, + 58, 49, 58, 58, 58, 50, 51, 52, 58, 53, 54, 55, 56, 57 + }; + + COMPILE_TIME_ASSERT(sizeof(uniform_lbps) == 256); + + const_image_view img(img_); + image_view lbp(lbp_); + + lbp.set_size(img.nr(), img.nc()); + + // set all the border pixels to the "non-uniform LBP value". + assign_border_pixels(lbp, 1, 1, 58); + + typedef typename image_traits::pixel_type pixel_type; + typedef typename pixel_traits::basic_pixel_type basic_pixel_type; + + for (long r = 1; r+1 < img.nr(); ++r) + { + for (long c = 1; c+1 < img.nc(); ++c) + { + const basic_pixel_type pix = get_pixel_intensity(img[r][c]); + unsigned char b1 = 0; + unsigned char b2 = 0; + unsigned char b3 = 0; + unsigned char b4 = 0; + unsigned char b5 = 0; + unsigned char b6 = 0; + unsigned char b7 = 0; + unsigned char b8 = 0; + + unsigned char x = 0; + if (get_pixel_intensity(img[r-1][c-1]) > pix) b1 = 0x80; + if (get_pixel_intensity(img[r-1][c ]) > pix) b2 = 0x40; + if (get_pixel_intensity(img[r-1][c+1]) > pix) b3 = 0x20; + x |= b1; + if (get_pixel_intensity(img[r ][c-1]) > pix) b4 = 0x10; + x |= b2; + if (get_pixel_intensity(img[r ][c+1]) > pix) b5 = 0x08; + x |= b3; + if (get_pixel_intensity(img[r+1][c-1]) > pix) b6 = 0x04; + x |= b4; + if (get_pixel_intensity(img[r+1][c ]) > pix) b7 = 0x02; + x |= b5; + if (get_pixel_intensity(img[r+1][c+1]) > pix) b8 = 0x01; + + x |= b6; + x |= b7; + x |= b8; + + lbp[r][c] = uniform_lbps[x]; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T + > + void extract_histogram_descriptors ( + const image_type& img_, + const point& loc, + std::vector& histograms, + const unsigned int cell_size = 10, + const unsigned int block_size = 4, + const unsigned int max_val = 58 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(cell_size >= 1 && block_size >= 1 && max_val < 256 && + (unsigned int)max(mat(img_)) <= max_val, + "\t void extract_histogram_descriptors()" + << "\n\t Invalid inputs were given to this function." + << "\n\t cell_size: " << cell_size + << "\n\t block_size: " << block_size + << "\n\t max_val: " << max_val + << "\n\t max(mat(img_)): " << max(mat(img_)) + ); + + typedef typename image_traits::pixel_type pixel_type; + COMPILE_TIME_ASSERT((is_same_type::value)); + + const_image_view img(img_); + + const rectangle area = get_rect(img); + const rectangle window = centered_rect(loc, block_size*cell_size, block_size*cell_size); + unsigned int cell_top = window.top(); + for (unsigned int br = 0; br < block_size; ++br) + { + unsigned int cell_left = window.left(); + for (unsigned int bc = 0; bc < block_size; ++bc) + { + // figure out the cell boundaries + rectangle cell(cell_left, cell_top, cell_left+cell_size-1, cell_top+cell_size-1); + cell = cell.intersect(area); + + // make the actual histogram for this cell + unsigned int hist[256] = {0}; + for (long r = cell.top(); r <= cell.bottom(); ++r) + { + for (long c = cell.left(); c <= cell.right(); ++c) + { + hist[img[r][c]]++; + } + } + + // copy histogram into the output. + histograms.insert(histograms.end(), hist, hist + max_val+1); + + cell_left += cell_size; + } + cell_top += cell_size; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T + > + void extract_uniform_lbp_descriptors ( + const image_type& img, + std::vector& feats, + const unsigned int cell_size = 10 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(cell_size >= 1, + "\t void extract_uniform_lbp_descriptors()" + << "\n\t Invalid inputs were given to this function." + << "\n\t cell_size: " << cell_size + ); + + feats.clear(); + array2d lbp; + make_uniform_lbp_image(img, lbp); + for (long r = 0; r < lbp.nr(); r+=cell_size) + { + for (long c = 0; c < lbp.nc(); c+=cell_size) + { + const rectangle cell = rectangle(c,r,c+cell_size-1,r+cell_size-1).intersect(get_rect(lbp)); + // make the actual histogram for this cell + unsigned int hist[59] = {0}; + for (long r = cell.top(); r <= cell.bottom(); ++r) + { + for (long c = cell.left(); c <= cell.right(); ++c) + { + hist[lbp[r][c]]++; + } + } + + // copy histogram into the output. + feats.insert(feats.end(), hist, hist + 59); + } + } + + for (unsigned long i = 0; i < feats.size(); ++i) + feats[i] = std::sqrt(feats[i]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T + > + void extract_highdim_face_lbp_descriptors ( + const image_type& img, + const full_object_detection& det, + std::vector& feats + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(det.num_parts() == 68, + "\t void extract_highdim_face_lbp_descriptors()" + << "\n\t Invalid inputs were given to this function." + << "\n\t det.num_parts(): " << det.num_parts() + ); + + const unsigned long num_scales = 5; + feats.clear(); + dlib::vector l, r; + double cnt = 0; + // Find the center of the left eye by averaging the points around + // the eye. + for (unsigned long i = 36; i <= 41; ++i) + { + l += det.part(i); + ++cnt; + } + l /= cnt; + + // Find the center of the right eye by averaging the points around + // the eye. + cnt = 0; + for (unsigned long i = 42; i <= 47; ++i) + { + r += det.part(i); + ++cnt; + } + r /= cnt; + + // We only do feature extraction from these face parts. These are things like the + // corners of the eyes and mouth and stuff like that. + std::vector parts; + parts.reserve(30); + parts.push_back(l); + parts.push_back(r); + parts.push_back(det.part(17)); + parts.push_back(det.part(21)); + parts.push_back(det.part(22)); + parts.push_back(det.part(26)); + parts.push_back(det.part(36)); + parts.push_back(det.part(39)); + parts.push_back(det.part(42)); + parts.push_back(det.part(45)); + parts.push_back(det.part(27)); + parts.push_back(det.part(28)); + parts.push_back(det.part(29)); + parts.push_back(det.part(30)); + parts.push_back(det.part(31)); + parts.push_back(det.part(35)); + parts.push_back(det.part(33)); + parts.push_back(det.part(48)); + parts.push_back(det.part(54)); + parts.push_back(det.part(51)); + parts.push_back(det.part(57)); + + array2d lbp; + make_uniform_lbp_image(img, lbp); + for (unsigned long i = 0; i < parts.size(); ++i) + extract_histogram_descriptors(lbp, parts[i], feats); + + if (num_scales > 1) + { + pyramid_down<4> pyr; + image_type img_temp; + pyr(img, img_temp); + unsigned long num_pyr_calls = 1; + + // now pull the features out at coarser scales + for (unsigned long iter = 1; iter < num_scales; ++iter) + { + // now do the feature extraction + make_uniform_lbp_image(img_temp, lbp); + for (unsigned long i = 0; i < parts.size(); ++i) + extract_histogram_descriptors(lbp, pyr.point_down(parts[i],num_pyr_calls), feats); + + if (iter+1 < num_scales) + { + pyr(img_temp); + ++num_pyr_calls; + } + } + } + + for (unsigned long i = 0; i < feats.size(); ++i) + feats[i] = std::sqrt(feats[i]); + + DLIB_ASSERT(feats.size() == 99120, feats.size()); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LBP_Hh_ + diff --git a/ml/dlib/dlib/image_transforms/lbp_abstract.h b/ml/dlib/dlib/image_transforms/lbp_abstract.h new file mode 100644 index 000000000..1a20082a2 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/lbp_abstract.h @@ -0,0 +1,139 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LBP_ABSTRACT_Hh_ +#ifdef DLIB_LBP_ABSTRACT_Hh_ + +#include "../image_processing/generic_image.h" +#include "../pixel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename image_type2 + > + void make_uniform_lbp_image ( + const image_type& img, + image_type2& lbp + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 should contain a grayscale pixel type such as unsigned char. + ensures + - #lbp.nr() == img.nr() + - #lbp.nc() == img.nc() + - This function extracts the uniform local-binary-pattern feature at every pixel + and stores it into #lbp. In particular, we have the following for all valid + r and c: + - #lbp[r][c] == the uniform LBP for the 3x3 pixel window centered on img[r][c]. + In particular, this is a value in the range 0 to 58 inclusive. + - We use the idea of uniform LBPs from the paper: + Face Description with Local Binary Patterns: Application to Face Recognition + by Ahonen, Hadid, and Pietikainen. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T + > + void extract_histogram_descriptors ( + const image_type& img, + const point& loc, + std::vector& histograms, + const unsigned int cell_size = 10, + const unsigned int block_size = 4, + const unsigned int max_val = 58 + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type contains unsigned char valued pixels. + - T is some scalar type like int or double + - All pixel values in img are <= max_val + - cell_size >= 1 + - block_size >= 1 + - max_val < 256 + ensures + - This function extracts histograms of pixel values from block_size*block_size + windows in the area in img immediately around img[loc.y()][loc.x()]. The + histograms are appended onto the end of #histograms. Each window is + cell_size pixels wide and tall. Moreover, the windows do not overlap. + - #histograms.size() == histograms.size() + block_size*block_size*(max_val+1) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T + > + void extract_uniform_lbp_descriptors ( + const image_type& img, + std::vector& feats, + const unsigned int cell_size = 10 + ); + /*! + requires + - cell_size >= 1 + - T is some scalar type like int or double + ensures + - Extracts histograms of uniform local-binary-patterns from img. The + histograms are from densely tiled windows that are cell_size pixels wide and + tall. The windows do not overlap and cover all of img. + - #feats.size() == 59*(number of windows that fit into img) + (i.e. #feats contains the LBP histograms) + - We will have taken the square root of all the histogram elements. That is, + #feats[i] is the square root of the number of LBPs that appeared in its + corresponding window. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type, + typename T + > + void extract_highdim_face_lbp_descriptors ( + const image_type& img, + const full_object_detection& det, + std::vector& feats + ); + /*! + requires + - T is some scalar type like int or double + - det.num_parts() == 68 + ensures + - This function extracts the high-dimensional LBP feature described in the + paper: + Blessing of Dimensionality: High-dimensional Feature and Its Efficient + Compression for Face Verification by Dong Chen, Xudong Cao, Fang Wen, and + Jian Sun + - #feats == the high-dimensional LBP descriptor. It is the concatenation of + many LBP histograms, each extracted from different scales and from different + windows around different face landmarks. We also take the square root of + each histogram element before storing it into #feats. + - #feats.size() == 99120 + - This function assumes img has already been aligned and normalized to a + standard size. + - This function assumes det contains a human face detection with face parts + annotated using the annotation scheme from the iBUG 300-W face landmark + dataset. This means that det.part(i) gives the locations of different face + landmarks according to the iBUG 300-W annotation scheme. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LBP_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_transforms/morphological_operations.h b/ml/dlib/dlib/image_transforms/morphological_operations.h new file mode 100644 index 000000000..a659e4bdc --- /dev/null +++ b/ml/dlib/dlib/image_transforms/morphological_operations.h @@ -0,0 +1,846 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MORPHOLOGICAL_OPERATIONs_ +#define DLIB_MORPHOLOGICAL_OPERATIONs_ + +#include "../pixel.h" +#include "thresholding.h" +#include "morphological_operations_abstract.h" +#include "assign_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace morphological_operations_helpers + { + template + bool is_binary_image ( + const image_type& img_ + ) + /*! + ensures + - returns true if img_ contains only on_pixel and off_pixel values. + - returns false otherwise + !*/ + { + const_image_view img(img_); + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + if (img[r][c] != on_pixel && img[r][c] != off_pixel) + { + return false; + } + } + } + return true; + } + + template < + long M, + long N + > + bool is_binary_image ( + const unsigned char (&structuring_element)[M][N] + ) + /*! + ensures + - returns true if structuring_element contains only on_pixel and off_pixel values. + - returns false otherwise + !*/ + { + for (long m = 0; m < M; ++m) + { + for (long n = 0; n < N; ++n) + { + if (structuring_element[m][n] != on_pixel && + structuring_element[m][n] != off_pixel) + { + return false; + } + } + } + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + long M, + long N + > + void binary_dilation ( + const in_image_type& in_img_, + out_image_type& out_img_, + const unsigned char (&structuring_element)[M][N] + ) + { + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + using namespace morphological_operations_helpers; + COMPILE_TIME_ASSERT(M%2 == 1); + COMPILE_TIME_ASSERT(N%2 == 1); + DLIB_ASSERT(is_same_object(in_img_,out_img_) == false, + "\tvoid binary_dilation()" + << "\n\tYou must give two different image objects" + ); + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + DLIB_ASSERT(is_binary_image(in_img_) , + "\tvoid binary_dilation()" + << "\n\tin_img must be a binary image" + ); + DLIB_ASSERT(is_binary_image(structuring_element) , + "\tvoid binary_dilation()" + << "\n\tthe structuring_element must be a binary image" + ); + + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + // apply the filter to the image + for (long r = 0; r < in_img.nr(); ++r) + { + for (long c = 0; c < in_img.nc(); ++c) + { + unsigned char out_pixel = off_pixel; + for (long m = 0; m < M && out_pixel == off_pixel; ++m) + { + for (long n = 0; n < N && out_pixel == off_pixel; ++n) + { + if (structuring_element[m][n] == on_pixel) + { + // if this pixel is inside the image then get it from the image + // but if it isn't just pretend it was an off_pixel value + if (r+m >= M/2 && c+n >= N/2 && + r+m-M/2 < in_img.nr() && c+n-N/2 < in_img.nc()) + { + out_pixel = in_img[r+m-M/2][c+n-N/2]; + } + } + } + } + assign_pixel(out_img[r][c], out_pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + long M, + long N + > + void binary_erosion ( + const in_image_type& in_img_, + out_image_type& out_img_, + const unsigned char (&structuring_element)[M][N] + ) + { + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + using namespace morphological_operations_helpers; + COMPILE_TIME_ASSERT(M%2 == 1); + COMPILE_TIME_ASSERT(N%2 == 1); + DLIB_ASSERT(is_same_object(in_img_,out_img_) == false, + "\tvoid binary_erosion()" + << "\n\tYou must give two different image objects" + ); + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + DLIB_ASSERT(is_binary_image(in_img_) , + "\tvoid binary_erosion()" + << "\n\tin_img must be a binary image" + ); + DLIB_ASSERT(is_binary_image(structuring_element) , + "\tvoid binary_erosion()" + << "\n\tthe structuring_element must be a binary image" + ); + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + // apply the filter to the image + for (long r = 0; r < in_img.nr(); ++r) + { + for (long c = 0; c < in_img.nc(); ++c) + { + unsigned char out_pixel = on_pixel; + for (long m = 0; m < M && out_pixel == on_pixel; ++m) + { + for (long n = 0; n < N && out_pixel == on_pixel; ++n) + { + if (structuring_element[m][n] == on_pixel) + { + // if this pixel is inside the image then get it from the image + // but if it isn't just pretend it was an off_pixel value + if (r+m >= M/2 && c+n >= N/2 && + r+m-M/2 < in_img.nr() && c+n-N/2 < in_img.nc()) + { + out_pixel = in_img[r+m-M/2][c+n-N/2]; + } + else + { + out_pixel = off_pixel; + } + } + } + } + assign_pixel(out_img[r][c], out_pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + long M, + long N + > + void binary_open ( + const in_image_type& in_img, + out_image_type& out_img, + const unsigned char (&structuring_element)[M][N], + const unsigned long iter = 1 + ) + { + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + using namespace morphological_operations_helpers; + COMPILE_TIME_ASSERT(M%2 == 1); + COMPILE_TIME_ASSERT(N%2 == 1); + DLIB_ASSERT(is_same_object(in_img,out_img) == false, + "\tvoid binary_open()" + << "\n\tYou must give two different image objects" + ); + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + DLIB_ASSERT(is_binary_image(in_img) , + "\tvoid binary_open()" + << "\n\tin_img must be a binary image" + ); + DLIB_ASSERT(is_binary_image(structuring_element) , + "\tvoid binary_open()" + << "\n\tthe structuring_element must be a binary image" + ); + + + // if there isn't any input image then don't do anything + if (num_rows(in_img)*num_columns(in_img) == 0) + { + set_image_size(out_img, 0,0); + return; + } + + set_image_size(out_img, num_rows(in_img), num_columns(in_img)); + + if (iter == 0) + { + // just copy the image over + assign_image(out_img, in_img); + } + else if (iter == 1) + { + in_image_type temp; + binary_erosion(in_img,temp,structuring_element); + binary_dilation(temp,out_img,structuring_element); + } + else + { + in_image_type temp1, temp2; + binary_erosion(in_img,temp1,structuring_element); + + // do the extra erosions + for (unsigned long i = 1; i < iter; ++i) + { + swap(temp1, temp2); + binary_erosion(temp2,temp1,structuring_element); + } + + // do the extra dilations + for (unsigned long i = 1; i < iter; ++i) + { + swap(temp1, temp2); + binary_dilation(temp2,temp1,structuring_element); + } + + binary_dilation(temp1,out_img,structuring_element); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + long M, + long N + > + void binary_close ( + const in_image_type& in_img, + out_image_type& out_img, + const unsigned char (&structuring_element)[M][N], + const unsigned long iter = 1 + ) + { + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + + using namespace morphological_operations_helpers; + COMPILE_TIME_ASSERT(M%2 == 1); + COMPILE_TIME_ASSERT(N%2 == 1); + DLIB_ASSERT(is_same_object(in_img,out_img) == false, + "\tvoid binary_close()" + << "\n\tYou must give two different image objects" + ); + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + DLIB_ASSERT(is_binary_image(in_img) , + "\tvoid binary_close()" + << "\n\tin_img must be a binary image" + ); + DLIB_ASSERT(is_binary_image(structuring_element) , + "\tvoid binary_close()" + << "\n\tthe structuring_element must be a binary image" + ); + + + // if there isn't any input image then don't do anything + if (num_rows(in_img)*num_columns(in_img) == 0) + { + set_image_size(out_img, 0,0); + return; + } + + set_image_size(out_img, num_rows(in_img), num_columns(in_img)); + + if (iter == 0) + { + // just copy the image over + assign_image(out_img, in_img); + } + else if (iter == 1) + { + in_image_type temp; + binary_dilation(in_img,temp,structuring_element); + binary_erosion(temp,out_img,structuring_element); + } + else + { + in_image_type temp1, temp2; + binary_dilation(in_img,temp1,structuring_element); + + // do the extra dilations + for (unsigned long i = 1; i < iter; ++i) + { + swap(temp1, temp2); + binary_dilation(temp2,temp1,structuring_element); + } + + // do the extra erosions + for (unsigned long i = 1; i < iter; ++i) + { + swap(temp1, temp2); + binary_erosion(temp2,temp1,structuring_element); + } + + binary_erosion(temp1,out_img,structuring_element); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type1, + typename in_image_type2, + typename out_image_type + > + void binary_intersection ( + const in_image_type1& in_img1_, + const in_image_type2& in_img2_, + out_image_type& out_img_ + ) + { + typedef typename image_traits::pixel_type in_pixel_type1; + typedef typename image_traits::pixel_type in_pixel_type2; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + using namespace morphological_operations_helpers; + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + DLIB_ASSERT(is_binary_image(in_img1_) , + "\tvoid binary_intersection()" + << "\n\tin_img1 must be a binary image" + ); + DLIB_ASSERT(is_binary_image(in_img2_) , + "\tvoid binary_intersection()" + << "\n\tin_img2 must be a binary image" + ); + + const_image_view in_img1(in_img1_); + const_image_view in_img2(in_img2_); + image_view out_img(out_img_); + + DLIB_ASSERT(in_img1.nc() == in_img2.nc(), + "\tvoid binary_intersection()" + << "\n\tin_img1 and in_img2 must have the same ncs." + << "\n\tin_img1.nc(): " << in_img1.nc() + << "\n\tin_img2.nc(): " << in_img2.nc() + ); + DLIB_ASSERT(in_img1.nr() == in_img2.nr(), + "\tvoid binary_intersection()" + << "\n\tin_img1 and in_img2 must have the same nrs." + << "\n\tin_img1.nr(): " << in_img1.nr() + << "\n\tin_img2.nr(): " << in_img2.nr() + ); + + + + // if there isn't any input image then don't do anything + if (in_img1.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(in_img1.nr(),in_img1.nc()); + + for (long r = 0; r < in_img1.nr(); ++r) + { + for (long c = 0; c < in_img1.nc(); ++c) + { + if (in_img1[r][c] == on_pixel && in_img2[r][c] == on_pixel) + assign_pixel(out_img[r][c], on_pixel); + else + assign_pixel(out_img[r][c], off_pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type1, + typename in_image_type2, + typename out_image_type + > + void binary_union ( + const in_image_type1& in_img1_, + const in_image_type2& in_img2_, + out_image_type& out_img_ + ) + { + typedef typename image_traits::pixel_type in_pixel_type1; + typedef typename image_traits::pixel_type in_pixel_type2; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + + using namespace morphological_operations_helpers; + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + DLIB_ASSERT(is_binary_image(in_img1_) , + "\tvoid binary_intersection()" + << "\n\tin_img1 must be a binary image" + ); + DLIB_ASSERT(is_binary_image(in_img2_) , + "\tvoid binary_intersection()" + << "\n\tin_img2 must be a binary image" + ); + + const_image_view in_img1(in_img1_); + const_image_view in_img2(in_img2_); + image_view out_img(out_img_); + + DLIB_ASSERT(in_img1.nc() == in_img2.nc(), + "\tvoid binary_intersection()" + << "\n\tin_img1 and in_img2 must have the same ncs." + << "\n\tin_img1.nc(): " << in_img1.nc() + << "\n\tin_img2.nc(): " << in_img2.nc() + ); + DLIB_ASSERT(in_img1.nr() == in_img2.nr(), + "\tvoid binary_intersection()" + << "\n\tin_img1 and in_img2 must have the same nrs." + << "\n\tin_img1.nr(): " << in_img1.nr() + << "\n\tin_img2.nr(): " << in_img2.nr() + ); + + + + // if there isn't any input image then don't do anything + if (in_img1.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(in_img1.nr(),in_img1.nc()); + + for (long r = 0; r < in_img1.nr(); ++r) + { + for (long c = 0; c < in_img1.nc(); ++c) + { + if (in_img1[r][c] == on_pixel || in_img2[r][c] == on_pixel) + assign_pixel(out_img[r][c], on_pixel); + else + assign_pixel(out_img[r][c], off_pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type1, + typename in_image_type2, + typename out_image_type + > + void binary_difference ( + const in_image_type1& in_img1_, + const in_image_type2& in_img2_, + out_image_type& out_img_ + ) + { + typedef typename image_traits::pixel_type in_pixel_type1; + typedef typename image_traits::pixel_type in_pixel_type2; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + using namespace morphological_operations_helpers; + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + DLIB_ASSERT(is_binary_image(in_img1_) , + "\tvoid binary_difference()" + << "\n\tin_img1 must be a binary image" + ); + DLIB_ASSERT(is_binary_image(in_img2_) , + "\tvoid binary_difference()" + << "\n\tin_img2 must be a binary image" + ); + + const_image_view in_img1(in_img1_); + const_image_view in_img2(in_img2_); + image_view out_img(out_img_); + + DLIB_ASSERT(in_img1.nc() == in_img2.nc(), + "\tvoid binary_difference()" + << "\n\tin_img1 and in_img2 must have the same ncs." + << "\n\tin_img1.nc(): " << in_img1.nc() + << "\n\tin_img2.nc(): " << in_img2.nc() + ); + DLIB_ASSERT(in_img1.nr() == in_img2.nr(), + "\tvoid binary_difference()" + << "\n\tin_img1 and in_img2 must have the same nrs." + << "\n\tin_img1.nr(): " << in_img1.nr() + << "\n\tin_img2.nr(): " << in_img2.nr() + ); + + + + // if there isn't any input image then don't do anything + if (in_img1.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(in_img1.nr(),in_img1.nc()); + + for (long r = 0; r < in_img1.nr(); ++r) + { + for (long c = 0; c < in_img1.nc(); ++c) + { + if (in_img1[r][c] == on_pixel && in_img2[r][c] == off_pixel) + assign_pixel(out_img[r][c], on_pixel); + else + assign_pixel(out_img[r][c], off_pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void binary_complement ( + const in_image_type& in_img_, + out_image_type& out_img_ + ) + { + typedef typename image_traits::pixel_type in_pixel_type; + typedef typename image_traits::pixel_type out_pixel_type; + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::has_alpha == false ); + + + using namespace morphological_operations_helpers; + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + DLIB_ASSERT(is_binary_image(in_img_) , + "\tvoid binary_complement()" + << "\n\tin_img must be a binary image" + ); + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + for (long r = 0; r < in_img.nr(); ++r) + { + for (long c = 0; c < in_img.nc(); ++c) + { + if (in_img[r][c] == on_pixel) + assign_pixel(out_img[r][c], off_pixel); + else + assign_pixel(out_img[r][c], on_pixel); + } + } + } + + template < + typename image_type + > + void binary_complement ( + image_type& img + ) + { + binary_complement(img,img); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline bool should_remove_pixel ( + const image_type& img, + long r, + long c, + int iter + ) + { + unsigned int p2 = img[r-1][c]; + unsigned int p3 = img[r-1][c+1]; + unsigned int p4 = img[r][c+1]; + unsigned int p5 = img[r+1][c+1]; + unsigned int p6 = img[r+1][c]; + unsigned int p7 = img[r+1][c-1]; + unsigned int p8 = img[r][c-1]; + unsigned int p9 = img[r-1][c-1]; + + int A = (p2 == 0 && p3 == 255) + (p3 == 0 && p4 == 255) + + (p4 == 0 && p5 == 255) + (p5 == 0 && p6 == 255) + + (p6 == 0 && p7 == 255) + (p7 == 0 && p8 == 255) + + (p8 == 0 && p9 == 255) + (p9 == 0 && p2 == 255); + int B = p2 + p3 + p4 + p5 + p6 + p7 + p8 + p9; + int m1 = iter == 0 ? (p2 * p4 * p6) : (p2 * p4 * p8); + int m2 = iter == 0 ? (p4 * p6 * p8) : (p2 * p6 * p8); + // Decide if we should remove the pixel img[r][c]. + return (A == 1 && (B >= 2*255 && B <= 6*255) && m1 == 0 && m2 == 0); + } + + template + inline void add_to_remove ( + std::vector& to_remove, + array2d& marker, + const image_type& img, + long r, + long c, + int iter + ) + { + if (marker[r][c]&&should_remove_pixel(img,r,c,iter)) + { + to_remove.push_back(point(c,r)); + marker[r][c] = 0; + } + } + + template + inline bool is_bw_border_pixel( + const image_type& img, + long r, + long c + ) + { + unsigned int p2 = img[r-1][c]; + unsigned int p3 = img[r-1][c+1]; + unsigned int p4 = img[r][c+1]; + unsigned int p5 = img[r+1][c+1]; + unsigned int p6 = img[r+1][c]; + unsigned int p7 = img[r+1][c-1]; + unsigned int p8 = img[r][c-1]; + unsigned int p9 = img[r-1][c-1]; + + int B = p2 + p3 + p4 + p5 + p6 + p7 + p8 + p9; + // If you are on but at least one of your neighbors isn't. + return B<8*255 && img[r][c]; + + } + + inline void add_if( + std::vector& to_check2, + const array2d& marker, + long c, + long r + ) + { + if (marker[r][c]) + to_check2.push_back(point(c,r)); + } + + } // end namespace impl + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void skeleton( + image_type& img_ + ) + { + /* + The implementation of this function is based on the paper + "A fast parallel algorithm for thinning digital patterns” by T.Y. Zhang and C.Y. Suen. + and also the excellent discussion of it at: + http://opencv-code.com/quick-tips/implementation-of-thinning-algorithm-in-opencv/ + */ + + typedef typename image_traits::pixel_type pixel_type; + + // This function only works on grayscale images + COMPILE_TIME_ASSERT(pixel_traits::grayscale); + + using namespace impl; + // Note that it's important to zero the border for 2 reasons. First, it allows + // thinning to being at the border of the image. But more importantly, it causes + // the mask to have a border of 0 pixels as well which we use later to avoid + // indexing outside the image inside add_to_remove(). + zero_border_pixels(img_,1,1); + image_view img(img_); + + // We use the marker to keep track of pixels we have committed to removing but + // haven't yet removed from img. + array2d marker(img.nr(), img.nc()); + assign_image(marker, img); + + + // Begin by making a list of the pixels on the borders of binary blobs. + std::vector to_remove, to_check, to_check2; + for (int r = 1; r < img.nr()-1; r++) + { + for (int c = 1; c < img.nc()-1; c++) + { + if (is_bw_border_pixel(img, r, c)) + { + to_check.push_back(point(c,r)); + } + } + } + + // Now start iteratively looking at the border pixels and removing them. + while(to_check.size() != 0) + { + for (int iter = 0; iter <= 1; ++iter) + { + // Check which pixels we should remove + to_remove.clear(); + for (unsigned long i = 0; i < to_check.size(); ++i) + { + long r = to_check[i].y(); + long c = to_check[i].x(); + add_to_remove(to_remove, marker, img, r, c, iter); + } + for (unsigned long i = 0; i < to_check2.size(); ++i) + { + long r = to_check2[i].y(); + long c = to_check2[i].x(); + add_to_remove(to_remove, marker, img, r, c, iter); + } + // Now remove those pixels. Also add their neighbors into the "to check" + // pixel list for the next iteration. + for (unsigned long i = 0; i < to_remove.size(); ++i) + { + long r = to_remove[i].y(); + long c = to_remove[i].x(); + // remove the pixel + img[r][c] = 0; + add_if(to_check2, marker, c-1, r-1); + add_if(to_check2, marker, c, r-1); + add_if(to_check2, marker, c+1, r-1); + add_if(to_check2, marker, c-1, r); + add_if(to_check2, marker, c+1, r); + add_if(to_check2, marker, c-1, r+1); + add_if(to_check2, marker, c, r+1); + add_if(to_check2, marker, c+1, r+1); + } + } + to_check.clear(); + to_check.swap(to_check2); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MORPHOLOGICAL_OPERATIONs_ + diff --git a/ml/dlib/dlib/image_transforms/morphological_operations_abstract.h b/ml/dlib/dlib/image_transforms/morphological_operations_abstract.h new file mode 100644 index 000000000..c69bdd1ca --- /dev/null +++ b/ml/dlib/dlib/image_transforms/morphological_operations_abstract.h @@ -0,0 +1,316 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MORPHOLOGICAL_OPERATIONs_ABSTRACT_ +#ifdef DLIB_MORPHOLOGICAL_OPERATIONs_ABSTRACT_ + +#include "../pixel.h" +#include "thresholding_abstract.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + long M, + long N + > + void binary_dilation ( + const in_image_type& in_img, + out_image_type& out_img, + const unsigned char (&structuring_element)[M][N] + ); + /*! + requires + - in_image_type and out_image_type are image objects that implement the + interface defined in dlib/image_processing/generic_image.h + - in_img must contain a grayscale pixel type. + - both in_img and out_img must contain pixels with no alpha channel. + (i.e. pixel_traits::has_alpha==false for their pixels) + - is_same_object(in_img,out_img) == false + - M % 2 == 1 (i.e. M must be odd) + - N % 2 == 1 (i.e. N must be odd) + - all pixels in in_img are set to either on_pixel or off_pixel + (i.e. it must be a binary image) + - all pixels in structuring_element are set to either on_pixel or off_pixel + (i.e. it must be a binary image) + ensures + - Does a binary dilation of in_img using the given structuring element and + stores the result in out_img. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + long M, + long N + > + void binary_erosion ( + const in_image_type& in_img, + out_image_type& out_img, + const unsigned char (&structuring_element)[M][N] + ); + /*! + requires + - in_image_type and out_image_type are image objects that implement the + interface defined in dlib/image_processing/generic_image.h + - in_img must contain a grayscale pixel type. + - both in_img and out_img must contain pixels with no alpha channel. + (i.e. pixel_traits::has_alpha==false for their pixels) + - is_same_object(in_img,out_img) == false + - M % 2 == 1 (i.e. M must be odd) + - N % 2 == 1 (i.e. N must be odd) + - all pixels in in_img are set to either on_pixel or off_pixel + (i.e. it must be a binary image) + - all pixels in structuring_element are set to either on_pixel or off_pixel + (i.e. it must be a binary image) + ensures + - Does a binary erosion of in_img using the given structuring element and + stores the result in out_img. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + long M, + long N + > + void binary_open ( + const in_image_type& in_img, + out_image_type& out_img, + const unsigned char (&structuring_element)[M][N], + const unsigned long iter = 1 + ); + /*! + requires + - in_image_type and out_image_type are image objects that implement the + interface defined in dlib/image_processing/generic_image.h + - in_img must contain a grayscale pixel type. + - both in_img and out_img must contain pixels with no alpha channel. + (i.e. pixel_traits::has_alpha==false for their pixels) + - is_same_object(in_img,out_img) == false + - M % 2 == 1 (i.e. M must be odd) + - N % 2 == 1 (i.e. N must be odd) + - all pixels in in_img are set to either on_pixel or off_pixel + (i.e. it must be a binary image) + - all pixels in structuring_element are set to either on_pixel or off_pixel + (i.e. it must be a binary image) + ensures + - Does a binary open of in_img using the given structuring element and + stores the result in out_img. Specifically, iter iterations of binary + erosion are applied and then iter iterations of binary dilation. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + long M, + long N + > + void binary_close ( + const in_image_type& in_img, + out_image_type& out_img, + const unsigned char (&structuring_element)[M][N], + const unsigned long iter = 1 + ); + /*! + requires + - in_image_type and out_image_type are image objects that implement the + interface defined in dlib/image_processing/generic_image.h + - in_img must contain a grayscale pixel type. + - both in_img and out_img must contain pixels with no alpha channel. + (i.e. pixel_traits::has_alpha==false for their pixels) + - is_same_object(in_img,out_img) == false + - M % 2 == 1 (i.e. M must be odd) + - N % 2 == 1 (i.e. N must be odd) + - all pixels in in_img are set to either on_pixel or off_pixel + (i.e. it must be a binary image) + - all pixels in structuring_element are set to either on_pixel or off_pixel + (i.e. it must be a binary image) + ensures + - Does a binary close of in_img using the given structuring element and + stores the result in out_img. Specifically, iter iterations of binary + dilation are applied and then iter iterations of binary erosion. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type1, + typename in_image_type2, + typename out_image_type + > + void binary_intersection ( + const in_image_type1& in_img1, + const in_image_type2& in_img2, + out_image_type& out_img + ); + /*! + requires + - in_image_type1, in_image_type2, and out_image_type are image objects that + implement the interface defined in dlib/image_processing/generic_image.h + - in_img1 and in_img2 must contain grayscale pixel types. + - in_img1, in_img2, and out_img must contain pixels with no alpha channel. + (i.e. pixel_traits::has_alpha==false for their pixels) + - all pixels in in_img1 and in_img2 are set to either on_pixel or off_pixel + (i.e. they must be binary images) + - in_img1.nc() == in_img2.nc() + - in_img1.nr() == in_img2.nr() + ensures + - #out_img == the binary intersection of in_img1 and in_img2. (i.e. All + the pixels that are set to on_pixel in both in_img1 and in_img2 will be set + to on_pixel in #out_img. All other pixels will be set to off_pixel) + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type1, + typename in_image_type2, + typename out_image_type + > + void binary_union ( + const in_image_type1& in_img1, + const in_image_type2& in_img2, + out_image_type& out_img + ); + /*! + requires + - in_image_type1, in_image_type2, and out_image_type are image objects that + implement the interface defined in dlib/image_processing/generic_image.h + - in_img1 and in_img2 must contain grayscale pixel types. + - in_img1, in_img2, and out_img must contain pixels with no alpha channel. + (i.e. pixel_traits::has_alpha==false for their pixels) + - all pixels in in_img1 and in_img2 are set to either on_pixel or off_pixel + (i.e. they must be binary images) + - in_img1.nc() == in_img2.nc() + - in_img1.nr() == in_img2.nr() + ensures + - #out_img == the binary union of in_img1 and in_img2. (i.e. All + the pixels that are set to on_pixel in in_img1 and/or in_img2 will be set + to on_pixel in #out_img. All other pixels will be set to off_pixel) + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type1, + typename in_image_type2, + typename out_image_type + > + void binary_difference ( + const in_image_type1& in_img1, + const in_image_type2& in_img2, + out_image_type& out_img + ); + /*! + requires + - in_image_type1, in_image_type2, and out_image_type are image objects that + implement the interface defined in dlib/image_processing/generic_image.h + - in_img1 and in_img2 must contain grayscale pixel types. + - in_img1, in_img2, and out_img must contain pixels with no alpha channel. + (i.e. pixel_traits::has_alpha==false for their pixels) + - all pixels in in_img1 and in_img2 are set to either on_pixel or off_pixel + (i.e. they must be binary images) + - in_img1.nc() == in_img2.nc() + - in_img1.nr() == in_img2.nr() + ensures + - #out_img == the binary difference of in_img1 and in_img2. (i.e. #out_img + will be a copy of in_img1 except that any pixels in in_img2 that are set to + on_pixel will be set to off_pixel) + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void binary_complement ( + const in_image_type& in_img, + out_image_type& out_img + ); + /*! + requires + - in_image_type and out_image_type are image objects that implement the + interface defined in dlib/image_processing/generic_image.h + - in_img must contain a grayscale pixel type. + - both in_img and out_img must contain pixels with no alpha channel. + (i.e. pixel_traits::has_alpha==false for their pixels) + - all pixels in in_img are set to either on_pixel or off_pixel + (i.e. it must be a binary image) + ensures + - #out_img == the binary complement of in_img. (i.e. For each pixel in + in_img, if it is on_pixel then it will be set to off_pixel in #out_img and + if it was off_pixel in in_img then it will be on_pixel in #out_img) + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + + template < + typename image_type + > + void binary_complement ( + image_type& img + ); + /*! + requires + - it must be valid to call binary_complement(img,img); + ensures + - calls binary_complement(img,img); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void skeleton( + image_type& img + ); + /*! + requires + - image_type is an object that implement the interface defined in + dlib/image_processing/generic_image.h + - img must contain a grayscale pixel type. + - all pixels in img are set to either on_pixel or off_pixel. + (i.e. it must be a binary image) + ensures + - This function computes the skeletonization of img and stores the result in + #img. That is, given a binary image, we progressively thin the binary blobs + (composed of on_pixel values) until only a single pixel wide skeleton of the + original blobs remains. + - #img.nc() == img.nc() + - #img.nr() == img.nr() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MORPHOLOGICAL_OPERATIONs_ABSTRACT_ + + diff --git a/ml/dlib/dlib/image_transforms/random_color_transform.h b/ml/dlib/dlib/image_transforms/random_color_transform.h new file mode 100644 index 000000000..7433da1f7 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/random_color_transform.h @@ -0,0 +1,157 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RANDOM_cOLOR_TRANSFORM_Hh_ +#define DLIB_RANDOM_cOLOR_TRANSFORM_Hh_ + +#include "random_color_transform_abstract.h" +#include "../image_processing/generic_image.h" +#include "../pixel.h" +#include "../rand.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class random_color_transform + { + public: + + random_color_transform ( + dlib::rand& rnd, + const double gamma_magnitude = 0.5, + const double color_magnitude = 0.2 + ) + { + // pick a random gamma correction factor. + double gamma = std::max(0.0, 1 + gamma_magnitude*(rnd.get_random_double()-0.5)); + + // pick a random color balancing scheme. + double red_scale = 1-rnd.get_random_double()*color_magnitude; + double green_scale = 1-rnd.get_random_double()*color_magnitude; + double blue_scale = 1-rnd.get_random_double()*color_magnitude; + const double m = 255*std::max(std::max(red_scale,green_scale),blue_scale); + red_scale /= m; + green_scale /= m; + blue_scale /= m; + + // Now compute a lookup table for all the color channels. The table tells us + // what the transform does. + table.resize(256*3); + unsigned long i = 0; + for (int k = 0; k < 256; ++k) + { + double v = 255*std::pow(k*red_scale, gamma); + table[i++] = (unsigned char)(v + 0.5); + } + for (int k = 0; k < 256; ++k) + { + double v = 255*std::pow(k*green_scale, gamma); + table[i++] = (unsigned char)(v + 0.5); + } + for (int k = 0; k < 256; ++k) + { + double v = 255*std::pow(k*blue_scale, gamma); + table[i++] = (unsigned char)(v + 0.5); + } + } + + rgb_pixel operator()(rgb_pixel p) const + { + p.red = table[(unsigned int)p.red]; + p.green = table[(unsigned int)p.green+256]; + p.blue = table[(unsigned int)p.blue+512]; + return p; + } + + private: + std::vector table; + }; + +// ---------------------------------------------------------------------------------------- + + template + void disturb_colors ( + image_type& img_, + dlib::rand& rnd, + const double gamma_magnitude = 0.5, + const double color_magnitude = 0.2 + ) + { + image_view img(img_); + random_color_transform tform(rnd, gamma_magnitude, color_magnitude); + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + rgb_pixel temp; + assign_pixel(temp, img[r][c]); + temp = tform(temp); + assign_pixel(img[r][c], temp); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void apply_random_color_offset ( + image_type& img_, + dlib::rand& rnd + ) + { + // Make a random color offset. This tform matrix came from looking at the + // covariance matrix of RGB values in a bunch of images. In particular, if you + // multiply Gaussian random vectors by tform it will result in vectors with the + // same covariance matrix as the original RGB data. Also, this color transform is + // what is suggested by the paper: + // Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet + // classification with deep convolutional neural networks." Advances in neural + // information processing systems. 2012. + // Except that we used the square root of the eigenvalues (which I'm pretty sure is + // what the authors intended). + matrix tform; + tform = -66.379, 25.094, 6.79698, + -68.0492, -0.302309, -13.9539, + -68.4907, -24.0199, 7.27653; + matrix v; + v = rnd.get_random_gaussian(),rnd.get_random_gaussian(),rnd.get_random_gaussian(); + v = round(tform*0.1*v); + const int roffset = v(0); + const int goffset = v(1); + const int boffset = v(2); + + // Make up lookup tables that apply the color mapping so we don't have to put a + // bunch of complicated conditional branches in the loop below. + unsigned char rtable[256]; + unsigned char gtable[256]; + unsigned char btable[256]; + for (int i = 0; i < 256; ++i) + { + rtable[i] = put_in_range(0, 255, i+roffset); + gtable[i] = put_in_range(0, 255, i+goffset); + btable[i] = put_in_range(0, 255, i+boffset); + } + + // now transform the image. + image_view img(img_); + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + rgb_pixel temp; + assign_pixel(temp, img[r][c]); + temp.red = rtable[temp.red]; + temp.green = gtable[temp.green]; + temp.blue = btable[temp.blue]; + assign_pixel(img[r][c], temp); + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANDOM_cOLOR_TRANSFORM_Hh_ + diff --git a/ml/dlib/dlib/image_transforms/random_color_transform_abstract.h b/ml/dlib/dlib/image_transforms/random_color_transform_abstract.h new file mode 100644 index 000000000..5826e16a6 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/random_color_transform_abstract.h @@ -0,0 +1,94 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RANDOM_cOLOR_TRANSFORM_ABSTRACT_Hh_ +#ifdef DLIB_RANDOM_cOLOR_TRANSFORM_ABSTRACT_Hh_ + +#include "../image_processing/generic_image.h" +#include "../pixel.h" +#include "../rand.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class random_color_transform + { + /*! + WHAT THIS OBJECT REPRESENTS + This object generates a random color balancing and gamma correction + transform. It then allows you to apply that specific transform to as many + rgb_pixel objects as you like. + !*/ + + public: + + random_color_transform ( + dlib::rand& rnd, + const double gamma_magnitude = 0.5, + const double color_magnitude = 0.2 + ); + /*! + requires + - 0 <= gamma_magnitude + - 0 <= color_magnitude <= 1 + ensures + - This constructor generates a random color transform which can be applied + by calling this object's operator() method. + - The color transform is a gamma correction and color rebalancing. If + gamma_magnitude == 0 and color_magnitude == 0 then the transform doesn't + change any colors at all. However, the larger these parameters the more + noticeable the resulting transform. + !*/ + + rgb_pixel operator()( + rgb_pixel p + ) const; + /*! + ensures + - returns the color transformed version of p. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + void disturb_colors ( + image_type& img, + dlib::rand& rnd, + const double gamma_magnitude = 0.5, + const double color_magnitude = 0.2 + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - Applies a random color transform to the given image. This is done by + creating a random_color_transform with the given parameters and then + transforming each pixel in the image with the resulting transform. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void apply_random_color_offset ( + image_type& img, + dlib::rand& rnd + ); + /*! + ensures + - Picks a random color offset vector and adds it to the given image. The offset + vector is selected using the method described in the paper: + Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet + classification with deep convolutional neural networks." Advances in neural + information processing systems. 2012. + In particular, we sample an RGB value from the typical distribution of RGB + values, assuming it has a Gaussian distribution, and then divide it by 10. + This sampled RGB vector is added to each pixel of img. + !*/ + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_RANDOM_cOLOR_TRANSFORM_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/image_transforms/random_cropper.h b/ml/dlib/dlib/image_transforms/random_cropper.h new file mode 100644 index 000000000..2c754b608 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/random_cropper.h @@ -0,0 +1,361 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RaNDOM_CROPPER_H_ +#define DLIB_RaNDOM_CROPPER_H_ + +#include "random_cropper_abstract.h" +#include "../threads.h" +#include +#include +#include "interpolation.h" +#include "../image_processing/full_object_detection.h" +#include "../rand.h" + +namespace dlib +{ + class random_cropper + { + chip_dims dims = chip_dims(300,300); + bool randomly_flip = true; + double max_rotation_degrees = 30; + long min_object_length_long_dim = 75; // cropped object will be at least this many pixels along its longest edge. + long min_object_length_short_dim = 30; // cropped object will be at least this many pixels along its shortest edge. + double max_object_size = 0.7; // cropped object will be at most this fraction of the size of the image. + double background_crops_fraction = 0.5; + double translate_amount = 0.10; + + std::mutex rnd_mutex; + dlib::rand rnd; + public: + + void set_seed ( + time_t seed + ) { rnd = dlib::rand(seed); } + + double get_translate_amount ( + ) const { return translate_amount; } + + void set_translate_amount ( + double value + ) + { + DLIB_CASSERT(0 <= value); + translate_amount = value; + } + + double get_background_crops_fraction ( + ) const { return background_crops_fraction; } + + void set_background_crops_fraction ( + double value + ) + { + DLIB_CASSERT(0 <= value && value <= 1); + background_crops_fraction = value; + } + + const chip_dims& get_chip_dims( + ) const { return dims; } + + void set_chip_dims ( + const chip_dims& dims_ + ) { dims = dims_; } + + void set_chip_dims ( + unsigned long rows, + unsigned long cols + ) { set_chip_dims(chip_dims(rows,cols)); } + + bool get_randomly_flip ( + ) const { return randomly_flip; } + + void set_randomly_flip ( + bool value + ) { randomly_flip = value; } + + double get_max_rotation_degrees ( + ) const { return max_rotation_degrees; } + void set_max_rotation_degrees ( + double value + ) { max_rotation_degrees = std::abs(value); } + + long get_min_object_length_long_dim ( + ) const { return min_object_length_long_dim; } + long get_min_object_length_short_dim ( + ) const { return min_object_length_short_dim; } + + void set_min_object_size ( + long long_dim, + long short_dim + ) + { + DLIB_CASSERT(0 < short_dim && short_dim <= long_dim); + min_object_length_long_dim = long_dim; + min_object_length_short_dim = short_dim; + } + + double get_max_object_size ( + ) const { return max_object_size; } + void set_max_object_size ( + double value + ) + { + DLIB_CASSERT(0 < value); + max_object_size = value; + } + + template < + typename array_type + > + void operator() ( + size_t num_crops, + const array_type& images, + const std::vector>& rects, + array_type& crops, + std::vector>& crop_rects + ) + { + DLIB_CASSERT(images.size() == rects.size()); + crops.clear(); + crop_rects.clear(); + append(num_crops, images, rects, crops, crop_rects); + } + + template < + typename array_type + > + void append ( + size_t num_crops, + const array_type& images, + const std::vector>& rects, + array_type& crops, + std::vector>& crop_rects + ) + { + DLIB_CASSERT(images.size() == rects.size()); + DLIB_CASSERT(crops.size() == crop_rects.size()); + auto original_size = crops.size(); + crops.resize(crops.size()+num_crops); + crop_rects.resize(crop_rects.size()+num_crops); + parallel_for(original_size, original_size+num_crops, [&](long i) { + (*this)(images, rects, crops[i], crop_rects[i]); + }); + } + + + template < + typename array_type, + typename image_type + > + void operator() ( + const array_type& images, + const std::vector>& rects, + image_type& crop, + std::vector& crop_rects + ) + { + DLIB_CASSERT(images.size() == rects.size()); + size_t idx; + { std::lock_guard lock(rnd_mutex); + idx = rnd.get_integer(images.size()); + } + (*this)(images[idx], rects[idx], crop, crop_rects); + } + + template < + typename image_type1 + > + image_type1 operator() ( + const image_type1& img + ) + { + image_type1 crop; + std::vector junk1, junk2; + (*this)(img, junk1, crop, junk2); + return crop; + } + + template < + typename image_type1, + typename image_type2 + > + void operator() ( + const image_type1& img, + const std::vector& rects, + image_type2& crop, + std::vector& crop_rects + ) + { + DLIB_CASSERT(num_rows(img)*num_columns(img) != 0); + chip_details crop_plan; + bool should_flip_crop; + make_crop_plan(img, rects, crop_plan, should_flip_crop); + + extract_image_chip(img, crop_plan, crop); + const rectangle_transform tform = get_mapping_to_chip(crop_plan); + + // copy rects into crop_rects and set ones that are outside the crop to ignore or + // drop entirely as appropriate. + crop_rects.clear(); + for (auto rect : rects) + { + // map to crop + rect.rect = tform(rect.rect); + + // if the rect is at least partly in the crop + if (get_rect(crop).intersect(rect.rect).area() != 0) + { + // set to ignore if not totally in the crop or if too small. + if (!get_rect(crop).contains(rect.rect) || + ((long)rect.rect.height() < min_object_length_long_dim && (long)rect.rect.width() < min_object_length_long_dim) || + ((long)rect.rect.height() < min_object_length_short_dim || (long)rect.rect.width() < min_object_length_short_dim)) + { + rect.ignore = true; + } + + crop_rects.push_back(rect); + } + } + + // Also randomly flip the image + if (should_flip_crop) + { + image_type2 temp; + flip_image_left_right(crop, temp); + swap(crop,temp); + for (auto&& rect : crop_rects) + rect.rect = impl::flip_rect_left_right(rect.rect, get_rect(crop)); + } + } + + private: + + template + void make_crop_plan ( + const image_type1& img, + const std::vector& rects, + chip_details& crop_plan, + bool& should_flip_crop + ) + { + std::lock_guard lock(rnd_mutex); + rectangle crop_rect; + if (has_non_ignored_box(rects) && rnd.get_random_double() >= background_crops_fraction) + { + auto rect = rects[randomly_pick_rect(rects)].rect; + + // perturb the location of the crop by a small fraction of the object's size. + const point rand_translate = dpoint(rnd.get_double_in_range(-translate_amount,translate_amount)*std::max(rect.height(),rect.width()), + rnd.get_double_in_range(-translate_amount,translate_amount)*std::max(rect.height(),rect.width())); + + // We are going to grow rect into the cropping rect. First, we grow it a + // little so that it has the desired minimum border around it. + drectangle drect = centered_drect(center(rect)+rand_translate, rect.width()/max_object_size, rect.height()/max_object_size); + + // Now make rect have the same aspect ratio as dims so that there won't be + // any funny stretching when we crop it. We do this by growing it along + // whichever dimension is too short. + const double target_aspect = dims.cols/(double)dims.rows; + if (drect.width()/drect.height() < target_aspect) + drect = centered_drect(drect, target_aspect*drect.height(), drect.height()); + else + drect = centered_drect(drect, drect.width(), drect.width()/target_aspect); + + // Now perturb the scale of the crop. We do this by shrinking it, but not + // so much that it gets smaller than the min object sizes require. + double current_width = dims.cols*rect.width()/drect.width(); + double current_height = dims.rows*rect.height()/drect.height(); + + // never make any dimension smaller than the short dim. + double min_scale1 = std::max(min_object_length_short_dim/current_width, min_object_length_short_dim/current_height); + // at least one dimension needs to be longer than the long dim. + double min_scale2 = std::min(min_object_length_long_dim/current_width, min_object_length_long_dim/current_height); + double min_scale = std::max(min_scale1, min_scale2); + + const double rand_scale_perturb = 1.0/rnd.get_double_in_range(min_scale, 1); + crop_rect = centered_drect(drect, drect.width()*rand_scale_perturb, drect.height()*rand_scale_perturb); + + } + else + { + crop_rect = make_random_cropping_rect(img); + } + should_flip_crop = randomly_flip && rnd.get_random_double() > 0.5; + const double angle = rnd.get_double_in_range(-max_rotation_degrees, max_rotation_degrees)*pi/180; + crop_plan = chip_details(crop_rect, dims, angle); + } + + bool has_non_ignored_box ( + const std::vector& rects + ) const + { + for (auto&& b : rects) + { + if (!b.ignore) + return true; + } + return false; + } + + size_t randomly_pick_rect ( + const std::vector& rects + ) + { + DLIB_CASSERT(has_non_ignored_box(rects)); + size_t idx = rnd.get_integer(rects.size()); + while(rects[idx].ignore) + idx = rnd.get_integer(rects.size()); + return idx; + } + + template + rectangle make_random_cropping_rect( + const image_type& img_ + ) + { + const_image_view img(img_); + // Figure out what rectangle we want to crop from the image. We are going to + // crop out an image of size this->dims, so we pick a random scale factor that + // lets this random box be either as big as it can be while still fitting in + // the image or as small as a 3x zoomed in box randomly somewhere in the image. + double mins = 1.0/3.0, maxs = std::min(img.nr()/(double)dims.rows, img.nc()/(double)dims.cols); + mins = std::min(mins, maxs); + auto scale = rnd.get_double_in_range(mins, maxs); + rectangle rect(scale*dims.cols, scale*dims.rows); + // randomly shift the box around + point offset(rnd.get_integer(1+img.nc()-rect.width()), + rnd.get_integer(1+img.nr()-rect.height())); + return move_rect(rect, offset); + } + + + + }; + +// ---------------------------------------------------------------------------------------- + + inline std::ostream& operator<< ( + std::ostream& out, + const random_cropper& item + ) + { + using std::endl; + out << "random_cropper details: " << endl; + out << " chip_dims.rows: " << item.get_chip_dims().rows << endl; + out << " chip_dims.cols: " << item.get_chip_dims().cols << endl; + out << " randomly_flip: " << std::boolalpha << item.get_randomly_flip() << endl; + out << " max_rotation_degrees: " << item.get_max_rotation_degrees() << endl; + out << " min_object_length_long_dim: " << item.get_min_object_length_long_dim() << endl; + out << " min_object_length_short_dim: " << item.get_min_object_length_short_dim() << endl; + out << " max_object_size: " << item.get_max_object_size() << endl; + out << " background_crops_fraction: " << item.get_background_crops_fraction() << endl; + out << " translate_amount: " << item.get_translate_amount() << endl; + return out; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RaNDOM_CROPPER_H_ + diff --git a/ml/dlib/dlib/image_transforms/random_cropper_abstract.h b/ml/dlib/dlib/image_transforms/random_cropper_abstract.h new file mode 100644 index 000000000..7603a1c47 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/random_cropper_abstract.h @@ -0,0 +1,346 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RaNDOM_CROPPER_ABSTRACT_H_ +#ifdef DLIB_RaNDOM_CROPPER_ABSTRACT_H_ + +#include "../threads.h" +#include +#include +#include "interpolation.h" +#include "../image_processing/full_object_detection.h" +#include "../rand.h" + +namespace dlib +{ + class random_cropper + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for extracting random crops of objects from a set of + images. The crops are randomly jittered in scale, translation, and + rotation but more or less centered on objects specified by mmod_rect + objects. + + THREAD SAFETY + It is safe for multiple threads to make concurrent calls to this object's + operator() methods. + !*/ + + public: + + random_cropper ( + ); + /*! + ensures + - #get_chip_dims() == chip_dims(300,300) + - #get_randomly_flip() == true + - #get_max_rotation_degrees() == 30 + - #get_min_object_length_long_dim() == 70 + - #get_min_object_length_short_dim() == 30 + - #get_max_object_size() == 0.7 + - #get_background_crops_fraction() == 0.5 + - #get_translate_amount() == 0.1 + !*/ + + void set_seed ( + time_t seed + ); + /*! + ensures + - Seeds the internal random number generator with the given seed. + !*/ + + double get_translate_amount ( + ) const; + /*! + ensures + - When a box is cropped out, it will be randomly translated prior to + cropping by #get_translate_amount()*(the box's height) up or down and + #get_translate_amount()*(the box's width) left or right. + !*/ + + void set_translate_amount ( + double value + ); + /*! + requires + - value >= 0 + ensures + - #get_translate_amount() == value + !*/ + + double get_background_crops_fraction ( + ) const; + /*! + ensures + - When making random crops, get_background_crops_fraction() fraction of + them will be from random background rather than being centered on some + object in the dataset. + !*/ + + void set_background_crops_fraction ( + double value + ); + /*! + requires + - 0 <= value <= 1 + ensures + - #get_background_crops_fraction() == value + !*/ + + const chip_dims& get_chip_dims( + ) const; + /*! + ensures + - returns the dimensions of image chips produced by this object. + !*/ + + void set_chip_dims ( + const chip_dims& dims + ); + /*! + ensures + - #get_chip_dims() == dims + !*/ + + void set_chip_dims ( + unsigned long rows, + unsigned long cols + ); + /*! + ensures + - #get_chip_dims() == chip_dims(rows,cols) + !*/ + + bool get_randomly_flip ( + ) const; + /*! + ensures + - if this object will randomly mirror chips left to right. + !*/ + + void set_randomly_flip ( + bool value + ); + /*! + ensures + - #get_randomly_flip() == value + !*/ + + double get_max_rotation_degrees ( + ) const; + /*! + ensures + - When extracting an image chip, this object will pick a random rotation + in the range [-get_max_rotation_degrees(), get_max_rotation_degrees()] + and rotate the chip by that amount. + !*/ + + void set_max_rotation_degrees ( + double value + ); + /*! + ensures + - #get_max_rotation_degrees() == std::abs(value) + !*/ + + long get_min_object_length_long_dim ( + ) const; + /*! + ensures + - When a chip is extracted around an object, the chip will be sized so that + the longest edge of the object (i.e. either its height or width, + whichever is longer) is at least #get_min_object_length_long_dim() pixels + in length. When we say "object" here we are referring specifically to + the rectangle in the mmod_rect output by the cropper. + !*/ + + long get_min_object_length_short_dim ( + ) const; + /*! + ensures + - When a chip is extracted around an object, the chip will be sized so that + the shortest edge of the object (i.e. either its height or width, + whichever is shorter) is at least #get_min_object_length_short_dim() + pixels in length. When we say "object" here we are referring + specifically to the rectangle in the mmod_rect output by the cropper. + !*/ + + void set_min_object_size ( + long long_dim, + long short_dim + ); + /*! + requires + - 0 < short_dim <= long_dim + ensures + - #get_min_object_length_short_dim() == short_dim + - #get_min_object_length_long_dim() == long_dim + !*/ + + double get_max_object_size ( + ) const; + /*! + ensures + - When a chip is extracted around an object, the chip will be sized so that + both the object's height and width are at most get_max_object_size() * + the chip's height and width, respectively. E.g. if the chip is 640x480 + pixels in size then the object will be at most 480*get_max_object_size() + pixels tall and 640*get_max_object_size() pixels wide. + !*/ + + void set_max_object_size ( + double value + ); + /*! + requires + - 0 < value + ensures + - #get_max_object_size() == value + !*/ + + template < + typename array_type + > + void append ( + size_t num_crops, + const array_type& images, + const std::vector>& rects, + array_type& crops, + std::vector>& crop_rects + ); + /*! + requires + - images.size() == rects.size() + - crops.size() == crop_rects.size() + - for all valid i: + - images[i].size() != 0 + - array_type is a type with an interface compatible with dlib::array or + std::vector and it must in turn contain image objects that implement the + interface defined in dlib/image_processing/generic_image.h + ensures + - Randomly extracts num_crops chips from images and appends them to the end + of crops. We also copy the object metadata for each extracted crop and + store it into #crop_rects. In particular, calling this function is the + same as making multiple calls to the version of operator() below that + outputs a single crop, except that append() will use multiple CPU cores + to do the processing and is therefore faster. + - #crops.size() == crops.size()+num_crops + - #crop_rects.size() == crop_rects.size()+num_crops + !*/ + + template < + typename array_type + > + void operator() ( + size_t num_crops, + const array_type& images, + const std::vector>& rects, + array_type& crops, + std::vector>& crop_rects + ); + /*! + requires + - images.size() == rects.size() + - for all valid i: + - images[i].size() != 0 + - array_type is a type with an interface compatible with dlib::array or + std::vector and it must in turn contain image objects that implement the + interface defined in dlib/image_processing/generic_image.h + ensures + - Randomly extracts num_crops chips from images. We also copy the object + metadata for each extracted crop and store it into #crop_rects. In + particular, calling this function is the same as invoking the version of + operator() below multiple times, except that this version of operator() + will use multiple CPU cores to do the processing and is therefore faster. + - #crops.size() == num_crops + - #crop_rects.size() == num_crops + !*/ + + template < + typename array_type, + typename image_type + > + void operator() ( + const array_type& images, + const std::vector>& rects, + image_type& crop, + std::vector& crop_rects + ); + /*! + requires + - images.size() == rects.size() + - for all valid i: + - images[i].size() != 0 + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - array_type is a type with an interface compatible with dlib::array or + std::vector and it must in turn contain image objects that implement the + interface defined in dlib/image_processing/generic_image.h + ensures + - Selects a random image and creates a random crop from it. Specifically, + we pick a random index IDX < images.size() and then execute + (*this)(images[IDX],rects[IDX],crop,crop_rects) + !*/ + + template < + typename image_type1, + typename image_type2 + > + void operator() ( + const image_type1& img, + const std::vector& rects, + image_type2& crop, + std::vector& crop_rects + ); + /*! + requires + - img.size() != 0 + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - Extracts a random crop from img and copies over the mmod_rect objects in + rects to #crop_rects if they are contained inside the crop. Moreover, + rectangles are marked as ignore if they aren't completely contained + inside the crop. + - #crop_rects.size() <= rects.size() + !*/ + + template < + typename image_type1 + > + image_type1 operator() ( + const image_type1& img + ); + /*! + requires + - img.size() != 0 + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + ensures + - This function simply calls (*this)(img, junk1, crop, junk2) and returns + crop. Therefore it is simply a convenience function for extracting a + random background patch. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out, + const random_cropper& item + ); + /*! + ensures + - Prints the state of all the parameters of item to out. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RaNDOM_CROPPER_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/image_transforms/segment_image.h b/ml/dlib/dlib/image_transforms/segment_image.h new file mode 100644 index 000000000..3b57e4801 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/segment_image.h @@ -0,0 +1,730 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEGMENT_ImAGE_Hh_ +#define DLIB_SEGMENT_ImAGE_Hh_ + +#include "segment_image_abstract.h" +#include "../algs.h" +#include +#include "../geometry.h" +#include "../disjoint_subsets.h" +#include "../set.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline T edge_diff_uint( + const T& a, + const T& b + ) + { + if (a > b) + return a - b; + else + return b - a; + } + + // ---------------------------------------- + + template + struct edge_diff_funct + { + typedef double diff_type; + + template + double operator()( + const pixel_type& a, + const pixel_type& b + ) const + { + return length(pixel_to_vector(a) - pixel_to_vector(b)); + } + }; + + template <> + struct edge_diff_funct + { + typedef uint8 diff_type; + uint8 operator()( const uint8& a, const uint8& b) const { return edge_diff_uint(a,b); } + }; + + template <> + struct edge_diff_funct + { + typedef uint16 diff_type; + uint16 operator()( const uint16& a, const uint16& b) const { return edge_diff_uint(a,b); } + }; + + template <> + struct edge_diff_funct + { + typedef uint32 diff_type; + uint32 operator()( const uint32& a, const uint32& b) const { return edge_diff_uint(a,b); } + }; + + template <> + struct edge_diff_funct + { + typedef double diff_type; + double operator()( const double& a, const double& b) const { return std::abs(a-b); } + }; + + template + struct edge_diff_funct >::type> + { + typedef double diff_type; + double operator()( + const T& a, + const T& b + ) const + { + return length(a-b); + } + }; + + // ------------------------------------------------------------------------------------ + + template + struct graph_image_segmentation_data_T + { + graph_image_segmentation_data_T() : component_size(1), internal_diff(0) {} + unsigned long component_size; + T internal_diff; + }; + + // ------------------------------------------------------------------------------------ + + template + struct segment_image_edge_data_T + { + segment_image_edge_data_T (){} + + segment_image_edge_data_T ( + const rectangle& rect, + const point& p1, + const point& p2, + const T& diff_ + ) : + idx1(p1.y()*rect.width() + p1.x()), + idx2(p2.y()*rect.width() + p2.x()), + diff(diff_) + {} + + bool operator<(const segment_image_edge_data_T& item) const + { return diff < item.diff; } + + unsigned long idx1; + unsigned long idx2; + T diff; + }; + + // ------------------------------------------------------------------------------------ + + template + struct uint8_or_uint16_pixels + { + typedef typename image_view_type::pixel_type pixel_type; + const static bool value = is_same_type::value || + is_same_type::value; + }; + + // This is an overload of get_pixel_edges() that is optimized to segment images + // with 8bit or 16bit pixels very quickly. We do this by using a radix sort + // instead of quicksort. + template + typename enable_if >::type + get_pixel_edges ( + const in_image_type& in_img, + std::vector >& sorted_edges + ) + { + typedef typename in_image_type::pixel_type ptype; + typedef T diff_type; + std::vector counts(std::numeric_limits::max()+1, 0); + + edge_diff_funct edge_diff; + + border_enumerator be(get_rect(in_img), 1); + // we are going to do a radix sort on the edge weights. So the first step + // is to accumulate them into count. + const rectangle area = get_rect(in_img); + while (be.move_next()) + { + const long r = be.element().y(); + const long c = be.element().x(); + const ptype pix = in_img[r][c]; + if (area.contains(c-1,r)) counts[edge_diff(pix, in_img[r ][c-1])] += 1; + if (area.contains(c+1,r)) counts[edge_diff(pix, in_img[r ][c+1])] += 1; + if (area.contains(c ,r-1)) counts[edge_diff(pix, in_img[r-1][c ])] += 1; + if (area.contains(c ,r+1)) counts[edge_diff(pix, in_img[r+1][c ])] += 1; + } + for (long r = 1; r+1 < in_img.nr(); ++r) + { + for (long c = 1; c+1 < in_img.nc(); ++c) + { + const ptype pix = in_img[r][c]; + counts[edge_diff(pix, in_img[r-1][c+1])] += 1; + counts[edge_diff(pix, in_img[r ][c+1])] += 1; + counts[edge_diff(pix, in_img[r+1][c ])] += 1; + counts[edge_diff(pix, in_img[r+1][c+1])] += 1; + } + } + + const unsigned long num_edges = shrink_rect(area,1).area()*4 + in_img.nr()*2*3 - 4 + (in_img.nc()-2)*2*3; + typedef segment_image_edge_data_T segment_image_edge_data; + sorted_edges.resize(num_edges); + + // integrate counts. The idea is to have sorted_edges[counts[i]] be the location that edges + // with an edge_diff of i go. So counts[0] == 0, counts[1] == number of 0 edge diff edges, etc. + unsigned long prev = counts[0]; + for (unsigned long i = 1; i < counts.size(); ++i) + { + const unsigned long temp = counts[i]; + counts[i] += counts[i-1]; + counts[i-1] -= prev; + prev = temp; + } + counts[counts.size()-1] -= prev; + + + // now build a sorted list of all the edges + be.reset(); + while(be.move_next()) + { + const point p = be.element(); + const long r = p.y(); + const long c = p.x(); + const ptype pix = in_img[r][c]; + if (area.contains(c-1,r)) + { + const diff_type diff = edge_diff(pix, in_img[r ][c-1]); + sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c-1,r),diff); + } + + if (area.contains(c+1,r)) + { + const diff_type diff = edge_diff(pix, in_img[r ][c+1]); + sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c+1,r),diff); + } + + if (area.contains(c ,r-1)) + { + const diff_type diff = edge_diff(pix, in_img[r-1][c ]); + sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c ,r-1),diff); + } + + if (area.contains(c ,r+1)) + { + const diff_type diff = edge_diff(pix, in_img[r+1][c ]); + sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c ,r+1),diff); + } + } + // same thing as the above loop but now we do it on the interior of the image and therefore + // don't have to include the boundary checking if statements used above. + for (long r = 1; r+1 < in_img.nr(); ++r) + { + for (long c = 1; c+1 < in_img.nc(); ++c) + { + const point p(c,r); + const ptype pix = in_img[r][c]; + diff_type diff; + + diff = edge_diff(pix, in_img[r ][c+1]); + sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c+1,r),diff); + diff = edge_diff(pix, in_img[r-1][c+1]); + sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c+1,r-1),diff); + diff = edge_diff(pix, in_img[r+1][c+1]); + sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c+1,r+1),diff); + diff = edge_diff(pix, in_img[r+1][c ]); + sorted_edges[counts[diff]++] = segment_image_edge_data(area,p,point(c ,r+1),diff); + } + } + } + + // ---------------------------------------------------------------------------------------- + + // This is the general purpose version of get_pixel_edges(). It handles all pixel types. + template + typename disable_if >::type + get_pixel_edges ( + const in_image_type& in_img, + std::vector >& sorted_edges + ) + { + const rectangle area = get_rect(in_img); + sorted_edges.reserve(area.area()*4); + + typedef typename in_image_type::pixel_type ptype; + edge_diff_funct edge_diff; + typedef T diff_type; + typedef segment_image_edge_data_T segment_image_edge_data; + + border_enumerator be(get_rect(in_img), 1); + + // now build a sorted list of all the edges + be.reset(); + while(be.move_next()) + { + const point p = be.element(); + const long r = p.y(); + const long c = p.x(); + const ptype& pix = in_img[r][c]; + if (area.contains(c-1,r)) + { + const diff_type diff = edge_diff(pix, in_img[r ][c-1]); + sorted_edges.push_back(segment_image_edge_data(area,p,point(c-1,r),diff)); + } + + if (area.contains(c+1,r)) + { + const diff_type diff = edge_diff(pix, in_img[r ][c+1]); + sorted_edges.push_back(segment_image_edge_data(area,p,point(c+1,r),diff)); + } + + if (area.contains(c ,r-1)) + { + const diff_type diff = edge_diff(pix, in_img[r-1][c ]); + sorted_edges.push_back( segment_image_edge_data(area,p,point(c ,r-1),diff)); + } + if (area.contains(c ,r+1)) + { + const diff_type diff = edge_diff(pix, in_img[r+1][c ]); + sorted_edges.push_back( segment_image_edge_data(area,p,point(c ,r+1),diff)); + } + } + // same thing as the above loop but now we do it on the interior of the image and therefore + // don't have to include the boundary checking if statements used above. + for (long r = 1; r+1 < in_img.nr(); ++r) + { + for (long c = 1; c+1 < in_img.nc(); ++c) + { + const point p(c,r); + const ptype& pix = in_img[r][c]; + diff_type diff; + + diff = edge_diff(pix, in_img[r ][c+1]); + sorted_edges.push_back( segment_image_edge_data(area,p,point(c+1,r),diff)); + diff = edge_diff(pix, in_img[r+1][c+1]); + sorted_edges.push_back( segment_image_edge_data(area,p,point(c+1,r+1),diff)); + diff = edge_diff(pix, in_img[r+1][c ]); + sorted_edges.push_back( segment_image_edge_data(area,p,point(c ,r+1),diff)); + diff = edge_diff(pix, in_img[r-1][c+1]); + sorted_edges.push_back( segment_image_edge_data(area,p,point(c+1,r-1),diff)); + } + } + + std::sort(sorted_edges.begin(), sorted_edges.end()); + + } + + // ------------------------------------------------------------------------------------ + + } // end of namespace impl + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void segment_image ( + const in_image_type& in_img_, + out_image_type& out_img_, + const double k = 200, + const unsigned long min_size = 10 + ) + { + using namespace dlib::impl; + typedef typename image_traits::pixel_type ptype; + typedef typename edge_diff_funct::diff_type diff_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_same_object(in_img_, out_img_) == false, + "\t void segment_image()" + << "\n\t The input images can't be the same object." + ); + + COMPILE_TIME_ASSERT(is_unsigned_type::pixel_type>::value); + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + out_img.set_size(in_img.nr(), in_img.nc()); + // don't bother doing anything if the image is too small + if (in_img.nr() < 2 || in_img.nc() < 2) + { + assign_all_pixels(out_img,0); + return; + } + + disjoint_subsets sets; + sets.set_size(in_img.size()); + + std::vector > sorted_edges; + get_pixel_edges(in_img, sorted_edges); + + std::vector > data(in_img.size()); + + // now start connecting blobs together to make a minimum spanning tree. + for (unsigned long i = 0; i < sorted_edges.size(); ++i) + { + const unsigned long idx1 = sorted_edges[i].idx1; + const unsigned long idx2 = sorted_edges[i].idx2; + + unsigned long set1 = sets.find_set(idx1); + unsigned long set2 = sets.find_set(idx2); + if (set1 != set2) + { + const diff_type diff = sorted_edges[i].diff; + const diff_type tau1 = static_cast(k/data[set1].component_size); + const diff_type tau2 = static_cast(k/data[set2].component_size); + + const diff_type mint = std::min(data[set1].internal_diff + tau1, + data[set2].internal_diff + tau2); + if (diff <= mint) + { + const unsigned long new_set = sets.merge_sets(set1, set2); + data[new_set].component_size = data[set1].component_size + data[set2].component_size; + data[new_set].internal_diff = diff; + } + } + } + + // now merge any really small blobs + if (min_size != 0) + { + for (unsigned long i = 0; i < sorted_edges.size(); ++i) + { + const unsigned long idx1 = sorted_edges[i].idx1; + const unsigned long idx2 = sorted_edges[i].idx2; + + unsigned long set1 = sets.find_set(idx1); + unsigned long set2 = sets.find_set(idx2); + if (set1 != set2 && (data[set1].component_size < min_size || data[set2].component_size < min_size)) + { + const unsigned long new_set = sets.merge_sets(set1, set2); + data[new_set].component_size = data[set1].component_size + data[set2].component_size; + //data[new_set].internal_diff = sorted_edges[i].diff; + } + } + } + + unsigned long idx = 0; + for (long r = 0; r < out_img.nr(); ++r) + { + for (long c = 0; c < out_img.nc(); ++c) + { + out_img[r][c] = sets.find_set(idx++); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Candidate object location generation code. +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + struct edge_data + { + double edge_diff; + unsigned long set1; + unsigned long set2; + bool operator<(const edge_data& item) const + { + return edge_diff < item.edge_diff; + } + }; + + template < + typename in_image_type, + typename diff_type + > + void find_basic_candidate_object_locations ( + const in_image_type& in_img, + const std::vector >& sorted_edges, + std::vector& out_rects, + std::vector& edges, + const double k, + const unsigned long min_size + ) + { + using namespace dlib::impl; + + std::vector > rejected_edges; + rejected_edges.reserve(sorted_edges.size()); + + out_rects.clear(); + edges.clear(); + + // don't bother doing anything if the image is too small + if (in_img.nr() < 2 || in_img.nc() < 2) + { + return; + } + + disjoint_subsets sets; + sets.set_size(in_img.size()); + + + std::vector > data(in_img.size()); + + + + std::pair last_blob_edge(std::numeric_limits::max(), + std::numeric_limits::max());; + // now start connecting blobs together to make a minimum spanning tree. + for (unsigned long i = 0; i < sorted_edges.size(); ++i) + { + const unsigned long idx1 = sorted_edges[i].idx1; + const unsigned long idx2 = sorted_edges[i].idx2; + + unsigned long set1 = sets.find_set(idx1); + unsigned long set2 = sets.find_set(idx2); + if (set1 != set2) + { + const diff_type diff = sorted_edges[i].diff; + const diff_type tau1 = static_cast(k/data[set1].component_size); + const diff_type tau2 = static_cast(k/data[set2].component_size); + + const diff_type mint = std::min(data[set1].internal_diff + tau1, + data[set2].internal_diff + tau2); + if (diff <= mint) + { + const unsigned long new_set = sets.merge_sets(set1, set2); + data[new_set].component_size = data[set1].component_size + data[set2].component_size; + data[new_set].internal_diff = diff; + } + else + { + // Don't bother keeping multiple edges from the same pair of blobs, we + // only need one for what we will do later. + if (std::make_pair(set1,set2) != last_blob_edge) + { + segment_image_edge_data_T temp = sorted_edges[i]; + temp.idx1 = set1; + temp.idx2 = set2; + rejected_edges.push_back(temp); + last_blob_edge = std::make_pair(set1,set2); + } + } + } + } + + + // merge small blobs + for (unsigned long i = 0; i < rejected_edges.size(); ++i) + { + const unsigned long idx1 = rejected_edges[i].idx1; + const unsigned long idx2 = rejected_edges[i].idx2; + + unsigned long set1 = sets.find_set(idx1); + unsigned long set2 = sets.find_set(idx2); + rejected_edges[i].idx1 = set1; + rejected_edges[i].idx2 = set2; + if (set1 != set2 && (data[set1].component_size < min_size || data[set2].component_size < min_size)) + { + const unsigned long new_set = sets.merge_sets(set1, set2); + data[new_set].component_size = data[set1].component_size + data[set2].component_size; + data[new_set].internal_diff = rejected_edges[i].diff; + } + } + + // find bounding boxes of each blob + std::map boxes; + std::map box_id_map; + unsigned long idx = 0; + for (long r = 0; r < in_img.nr(); ++r) + { + for (long c = 0; c < in_img.nc(); ++c) + { + const unsigned long id = sets.find_set(idx++); + // Accumulate the current point into its box and if it is the first point + // in the box then also record the id number for this box. + if ((boxes[id] += point(c,r)).area() == 1) + box_id_map[id] = boxes.size()-1; + } + } + + // copy boxes into out_rects + out_rects.resize(boxes.size()); + for (std::map::iterator i = boxes.begin(); i != boxes.end(); ++i) + { + out_rects[box_id_map[i->first]] = i->second; + } + + // Now find the edges between the boxes + typedef dlib::memory_manager::kernel_2c mm_type; + dlib::set, mm_type>::kernel_1a neighbors_final; + for (unsigned long i = 0; i < rejected_edges.size(); ++i) + { + const unsigned long idx1 = rejected_edges[i].idx1; + const unsigned long idx2 = rejected_edges[i].idx2; + + unsigned long set1 = sets.find_set(idx1); + unsigned long set2 = sets.find_set(idx2); + if (set1 != set2) + { + std::pair p = std::make_pair(set1,set2); + if (!neighbors_final.is_member(p)) + { + neighbors_final.add(p); + + edge_data temp; + const diff_type mint = std::min(data[set1].internal_diff , + data[set2].internal_diff ); + temp.edge_diff = rejected_edges[i].diff - mint; + temp.set1 = box_id_map[set1]; + temp.set2 = box_id_map[set2]; + edges.push_back(temp); + } + } + } + + std::sort(edges.begin(), edges.end()); + } + } // end namespace impl + +// ---------------------------------------------------------------------------------------- + + template + void remove_duplicates ( + std::vector& rects + ) + { + std::sort(rects.begin(), rects.end(), std::less()); + unsigned long num_unique = 1; + for (unsigned long i = 1; i < rects.size(); ++i) + { + if (rects[i] != rects[i-1]) + { + rects[num_unique++] = rects[i]; + } + } + if (rects.size() != 0) + rects.resize(num_unique); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename EXP + > + void find_candidate_object_locations ( + const in_image_type& in_img_, + std::vector& rects, + const matrix_exp& kvals, + const unsigned long min_size = 20, + const unsigned long max_merging_iterations = 50 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(kvals) && kvals.size() > 0, + "\t void find_candidate_object_locations()" + << "\n\t Invalid inputs were given to this function." + << "\n\t is_vector(kvals): " << is_vector(kvals) + << "\n\t kvals.size(): " << kvals.size() + ); + + typedef dlib::memory_manager::kernel_2c mm_type; + typedef dlib::set::kernel_1a set_of_rects; + + using namespace dlib::impl; + typedef typename image_traits::pixel_type ptype; + typedef typename edge_diff_funct::diff_type diff_type; + + const_image_view in_img(in_img_); + + // don't bother doing anything if the image is too small + if (in_img.nr() < 2 || in_img.nc() < 2) + { + return; + } + + std::vector edges; + std::vector working_rects; + std::vector > sorted_edges; + get_pixel_edges(in_img, sorted_edges); + + disjoint_subsets sets; + + for (long j = 0; j < kvals.size(); ++j) + { + const double k = kvals(j); + + find_basic_candidate_object_locations(in_img, sorted_edges, working_rects, edges, k, min_size); + rects.insert(rects.end(), working_rects.begin(), working_rects.end()); + + + // Now iteratively merge all the rectangles we have and record the results. + // Note that, unlike what is described in the paper + // Segmentation as Selective Search for Object Recognition" by Koen E. A. van de Sande, et al. + // we don't use any kind of histogram/SIFT like thing to order the edges + // between the blobs. Here we simply order by the pixel difference value. + // Additionally, note that we keep progressively merging boxes in the outer + // loop rather than performing just a single iteration as indicated in the + // paper. + set_of_rects detected_rects; + bool did_merge = true; + for (unsigned long iter = 0; did_merge && iter < max_merging_iterations; ++iter) + { + did_merge = false; + sets.clear(); + sets.set_size(working_rects.size()); + + // recursively merge neighboring blobs until we have merged everything + for (unsigned long i = 0; i < edges.size(); ++i) + { + edge_data temp = edges[i]; + + temp.set1 = sets.find_set(temp.set1); + temp.set2 = sets.find_set(temp.set2); + if (temp.set1 != temp.set2) + { + rectangle merged_rect = working_rects[temp.set1] + working_rects[temp.set2]; + // Skip merging this pair of blobs if it was merged in a previous + // iteration. Doing this lets us consider other possible blob + // merges. + if (!detected_rects.is_member(merged_rect)) + { + const unsigned long new_set = sets.merge_sets(temp.set1, temp.set2); + rects.push_back(merged_rect); + working_rects[new_set] = merged_rect; + did_merge = true; + detected_rects.add(merged_rect); + } + } + } + } + } + + remove_duplicates(rects); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type + > + void find_candidate_object_locations ( + const in_image_type& in_img, + std::vector& rects + ) + { + find_candidate_object_locations(in_img, rects, linspace(50, 200, 3)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEGMENT_ImAGE_Hh_ + diff --git a/ml/dlib/dlib/image_transforms/segment_image_abstract.h b/ml/dlib/dlib/image_transforms/segment_image_abstract.h new file mode 100644 index 000000000..af1af46a1 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/segment_image_abstract.h @@ -0,0 +1,126 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SEGMENT_ImAGE_ABSTRACT_Hh_ +#ifdef DLIB_SEGMENT_ImAGE_ABSTRACT_Hh_ + +#include +#include "../matrix.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void segment_image ( + const in_image_type& in_img, + out_image_type& out_img, + const double k = 200, + const unsigned long min_size = 10 + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - in_image_type can contain any pixel type with a pixel_traits specialization + or a dlib matrix object representing a row or column vector. + - out_image_type must contain an unsigned integer pixel type. + - is_same_object(in_img, out_img) == false + ensures + - Attempts to segment in_img into regions which have some visual consistency to + them. In particular, this function implements the algorithm described in the + paper: Efficient Graph-Based Image Segmentation by Felzenszwalb and Huttenlocher. + - #out_img.nr() == in_img.nr() + - #out_img.nc() == in_img.nc() + - for all valid r and c: + - #out_img[r][c] == an integer value indicating the identity of the segment + containing the pixel in_img[r][c]. + - The k parameter is a measure used to influence how large the segment regions + will be. Larger k generally results in larger segments being produced. For + a deeper discussion of the k parameter you should consult the above + referenced paper. + - min_size is a lower bound on the size of the output segments. That is, it is + guaranteed that all output segments will have at least min_size pixels in + them (unless the whole image contains fewer than min_size pixels, in this + case the entire image will be put into a single segment). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename EXP + > + void find_candidate_object_locations ( + const in_image_type& in_img, + std::vector& rects, + const matrix_exp& kvals = linspace(50, 200, 3), + const unsigned long min_size = 20, + const unsigned long max_merging_iterations = 50 + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - is_vector(kvals) == true + - kvals.size() > 0 + ensures + - This function takes an input image and generates a set of candidate + rectangles which are expected to bound any objects in the image. It does + this by running a version of the segment_image() routine on the image and + then reports rectangles containing each of the segments as well as rectangles + containing unions of adjacent segments. The basic idea is described in the + paper: + Segmentation as Selective Search for Object Recognition by Koen E. A. van de Sande, et al. + Note that this function deviates from what is described in the paper slightly. + See the code for details. + - The basic segmentation is performed kvals.size() times, each time with the k + parameter (see segment_image() and the Felzenszwalb paper for details on k) + set to a different value from kvals. + - When doing the basic segmentations prior to any box merging, we discard all + rectangles that have an area < min_size. Therefore, all outputs and + subsequent merged rectangles are built out of rectangles that contain at + least min_size pixels. Note that setting min_size to a smaller value than + you might otherwise be interested in using can be useful since it allows a + larger number of possible merged boxes to be created. + - There are max_merging_iterations rounds of neighboring blob merging. + Therefore, this parameter has some effect on the number of output rectangles + you get, with larger values of the parameter giving more output rectangles. + - This function appends the output rectangles into #rects. This means that any + rectangles in rects before this function was called will still be in there + after it terminates. Note further that #rects will not contain any duplicate + rectangles. That is, for all valid i and j where i != j it will be true + that: + - #rects[i] != rects[j] + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename alloc + > + void remove_duplicates ( + std::vector& rects + ); + /*! + ensures + - This function finds any duplicate rectangles in rects and removes the extra + instances. This way, the result is that rects contains only unique rectangle + instances. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEGMENT_ImAGE_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/image_transforms/spatial_filtering.h b/ml/dlib/dlib/image_transforms/spatial_filtering.h new file mode 100644 index 000000000..91dcae321 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/spatial_filtering.h @@ -0,0 +1,1580 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SPATIAL_FILTERINg_H_ +#define DLIB_SPATIAL_FILTERINg_H_ + +#include "../pixel.h" +#include "spatial_filtering_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include "../array2d.h" +#include "../matrix.h" +#include "../geometry/border_enumerator.h" +#include "../simd.h" +#include +#include "assign_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename in_image_type, + typename out_image_type, + typename EXP, + typename T + > + rectangle grayscale_spatially_filter_image ( + const in_image_type& in_img_, + out_image_type& out_img_, + const matrix_exp& filter_, + T scale, + bool use_abs, + bool add_to + ) + { + const_temp_matrix filter(filter_); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + + DLIB_ASSERT(scale != 0 && filter.size() != 0, + "\trectangle spatially_filter_image()" + << "\n\t You can't give a scale of zero or an empty filter." + << "\n\t scale: "<< scale + << "\n\t filter.nr(): "<< filter.nr() + << "\n\t filter.nc(): "<< filter.nc() + ); + DLIB_ASSERT(is_same_object(in_img_, out_img_) == false, + "\trectangle spatially_filter_image()" + << "\n\tYou must give two different image objects" + ); + + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return rectangle(); + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + + // figure out the range that we should apply the filter to + const long first_row = filter.nr()/2; + const long first_col = filter.nc()/2; + const long last_row = in_img.nr() - ((filter.nr()-1)/2); + const long last_col = in_img.nc() - ((filter.nc()-1)/2); + + const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1); + if (!add_to) + zero_border_pixels(out_img_, non_border); + + // apply the filter to the image + for (long r = first_row; r < last_row; ++r) + { + for (long c = first_col; c < last_col; ++c) + { + typedef typename EXP::type ptype; + ptype p; + ptype temp = 0; + for (long m = 0; m < filter.nr(); ++m) + { + for (long n = 0; n < filter.nc(); ++n) + { + // pull out the current pixel and put it into p + p = get_pixel_intensity(in_img[r-first_row+m][c-first_col+n]); + temp += p*filter(m,n); + } + } + + temp /= scale; + + if (use_abs && temp < 0) + { + temp = -temp; + } + + // save this pixel to the output image + if (add_to == false) + { + assign_pixel(out_img[r][c], temp); + } + else + { + assign_pixel(out_img[r][c], temp + out_img[r][c]); + } + } + } + + return non_border; + } + + // ------------------------------------------------------------------------------------ + + template < + typename in_image_type, + typename out_image_type, + typename EXP + > + rectangle float_spatially_filter_image ( + const in_image_type& in_img_, + out_image_type& out_img_, + const matrix_exp& filter_, + bool add_to + ) + { + + const_temp_matrix filter(filter_); + DLIB_ASSERT(filter.size() != 0, + "\trectangle spatially_filter_image()" + << "\n\t You can't give an empty filter." + << "\n\t filter.nr(): "<< filter.nr() + << "\n\t filter.nc(): "<< filter.nc() + ); + DLIB_ASSERT(is_same_object(in_img_, out_img_) == false, + "\trectangle spatially_filter_image()" + << "\n\tYou must give two different image objects" + ); + + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return rectangle(); + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + + // figure out the range that we should apply the filter to + const long first_row = filter.nr()/2; + const long first_col = filter.nc()/2; + const long last_row = in_img.nr() - ((filter.nr()-1)/2); + const long last_col = in_img.nc() - ((filter.nc()-1)/2); + + const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1); + if (!add_to) + zero_border_pixels(out_img_, non_border); + + // apply the filter to the image + for (long r = first_row; r < last_row; ++r) + { + long c = first_col; + for (; c < last_col-7; c+=8) + { + simd8f p,p2,p3; + simd8f temp = 0, temp2=0, temp3=0; + for (long m = 0; m < filter.nr(); ++m) + { + long n = 0; + for (; n < filter.nc()-2; n+=3) + { + // pull out the current pixel and put it into p + p.load(&in_img[r-first_row+m][c-first_col+n]); + p2.load(&in_img[r-first_row+m][c-first_col+n+1]); + p3.load(&in_img[r-first_row+m][c-first_col+n+2]); + temp += p*filter(m,n); + temp2 += p2*filter(m,n+1); + temp3 += p3*filter(m,n+2); + } + for (; n < filter.nc(); ++n) + { + // pull out the current pixel and put it into p + p.load(&in_img[r-first_row+m][c-first_col+n]); + temp += p*filter(m,n); + } + } + temp += temp2+temp3; + + // save this pixel to the output image + if (add_to == false) + { + temp.store(&out_img[r][c]); + } + else + { + p.load(&out_img[r][c]); + temp += p; + temp.store(&out_img[r][c]); + } + } + for (; c < last_col; ++c) + { + float p; + float temp = 0; + for (long m = 0; m < filter.nr(); ++m) + { + for (long n = 0; n < filter.nc(); ++n) + { + // pull out the current pixel and put it into p + p = in_img[r-first_row+m][c-first_col+n]; + temp += p*filter(m,n); + } + } + + // save this pixel to the output image + if (add_to == false) + { + out_img[r][c] = temp; + } + else + { + out_img[r][c] += temp; + } + } + } + + return non_border; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP + > + struct is_float_filtering2 + { + const static bool value = is_same_type::pixel_type,float>::value && + is_same_type::pixel_type,float>::value && + is_same_type::value; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP, + typename T + > + typename enable_if_c::pixel_type>::grayscale && + is_float_filtering2::value,rectangle>::type + spatially_filter_image ( + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& filter, + T scale, + bool use_abs = false, + bool add_to = false + ) + { + if (use_abs == false) + { + if (scale == 1) + return impl::float_spatially_filter_image(in_img, out_img, filter, add_to); + else + return impl::float_spatially_filter_image(in_img, out_img, filter/scale, add_to); + } + else + { + return impl::grayscale_spatially_filter_image(in_img, out_img, filter, scale, true, add_to); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP, + typename T + > + typename enable_if_c::pixel_type>::grayscale && + !is_float_filtering2::value,rectangle>::type + spatially_filter_image ( + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& filter, + T scale, + bool use_abs = false, + bool add_to = false + ) + { + return impl::grayscale_spatially_filter_image(in_img,out_img,filter,scale,use_abs,add_to); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP, + typename T + > + typename disable_if_c::pixel_type>::grayscale,rectangle>::type + spatially_filter_image ( + const in_image_type& in_img_, + out_image_type& out_img_, + const matrix_exp& filter_, + T scale + ) + { + const_temp_matrix filter(filter_); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + + DLIB_ASSERT(scale != 0 && filter.size() != 0, + "\trectangle spatially_filter_image()" + << "\n\t You can't give a scale of zero or an empty filter." + << "\n\t scale: "<< scale + << "\n\t filter.nr(): "<< filter.nr() + << "\n\t filter.nc(): "<< filter.nc() + ); + DLIB_ASSERT(is_same_object(in_img_, out_img_) == false, + "\trectangle spatially_filter_image()" + << "\n\tYou must give two different image objects" + ); + + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return rectangle(); + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + + // figure out the range that we should apply the filter to + const long first_row = filter.nr()/2; + const long first_col = filter.nc()/2; + const long last_row = in_img.nr() - ((filter.nr()-1)/2); + const long last_col = in_img.nc() - ((filter.nc()-1)/2); + + const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1); + zero_border_pixels(out_img, non_border); + + // apply the filter to the image + for (long r = first_row; r < last_row; ++r) + { + for (long c = first_col; c < last_col; ++c) + { + typedef typename image_traits::pixel_type pixel_type; + typedef matrix::num,1> ptype; + ptype p; + ptype temp; + temp = 0; + for (long m = 0; m < filter.nr(); ++m) + { + for (long n = 0; n < filter.nc(); ++n) + { + // pull out the current pixel and put it into p + p = pixel_to_vector(in_img[r-first_row+m][c-first_col+n]); + temp += p*filter(m,n); + } + } + + temp /= scale; + + pixel_type pp; + vector_to_pixel(pp, temp); + assign_pixel(out_img[r][c], pp); + } + } + + return non_border; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP + > + rectangle spatially_filter_image ( + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& filter + ) + { + return spatially_filter_image(in_img,out_img,filter,1); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2, + typename T + > + rectangle grayscale_spatially_filter_image_separable ( + const in_image_type& in_img_, + out_image_type& out_img_, + const matrix_exp& _row_filter, + const matrix_exp& _col_filter, + T scale, + bool use_abs, + bool add_to + ) + { + const_temp_matrix row_filter(_row_filter); + const_temp_matrix col_filter(_col_filter); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + + DLIB_ASSERT(scale != 0 && row_filter.size() != 0 && col_filter.size() != 0 && + is_vector(row_filter) && + is_vector(col_filter), + "\trectangle spatially_filter_image_separable()" + << "\n\t Invalid inputs were given to this function." + << "\n\t scale: "<< scale + << "\n\t row_filter.size(): "<< row_filter.size() + << "\n\t col_filter.size(): "<< col_filter.size() + << "\n\t is_vector(row_filter): "<< is_vector(row_filter) + << "\n\t is_vector(col_filter): "<< is_vector(col_filter) + ); + DLIB_ASSERT(is_same_object(in_img_, out_img_) == false, + "\trectangle spatially_filter_image_separable()" + << "\n\tYou must give two different image objects" + ); + + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return rectangle(); + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + + // figure out the range that we should apply the filter to + const long first_row = col_filter.size()/2; + const long first_col = row_filter.size()/2; + const long last_row = in_img.nr() - ((col_filter.size()-1)/2); + const long last_col = in_img.nc() - ((row_filter.size()-1)/2); + + const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1); + if (!add_to) + zero_border_pixels(out_img, non_border); + + typedef typename EXP1::type ptype; + + array2d temp_img; + temp_img.set_size(in_img.nr(), in_img.nc()); + + // apply the row filter + for (long r = 0; r < in_img.nr(); ++r) + { + for (long c = first_col; c < last_col; ++c) + { + ptype p; + ptype temp = 0; + for (long n = 0; n < row_filter.size(); ++n) + { + // pull out the current pixel and put it into p + p = get_pixel_intensity(in_img[r][c-first_col+n]); + temp += p*row_filter(n); + } + temp_img[r][c] = temp; + } + } + + // apply the column filter + for (long r = first_row; r < last_row; ++r) + { + for (long c = first_col; c < last_col; ++c) + { + ptype temp = 0; + for (long m = 0; m < col_filter.size(); ++m) + { + temp += temp_img[r-first_row+m][c]*col_filter(m); + } + + temp /= scale; + + if (use_abs && temp < 0) + { + temp = -temp; + } + + // save this pixel to the output image + if (add_to == false) + { + assign_pixel(out_img[r][c], temp); + } + else + { + assign_pixel(out_img[r][c], temp + out_img[r][c]); + } + } + } + return non_border; + } + + } // namespace impl + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2 + > + struct is_float_filtering + { + const static bool value = is_same_type::pixel_type,float>::value && + is_same_type::pixel_type,float>::value && + is_same_type::value && + is_same_type::value; + }; + +// ---------------------------------------------------------------------------------------- + + // This overload is optimized to use SIMD instructions when filtering float images with + // float filters. + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2 + > + rectangle float_spatially_filter_image_separable ( + const in_image_type& in_img_, + out_image_type& out_img_, + const matrix_exp& _row_filter, + const matrix_exp& _col_filter, + out_image_type& scratch_, + bool add_to = false + ) + { + // You can only use this function with images and filters containing float + // variables. + COMPILE_TIME_ASSERT((is_float_filtering::value == true)); + + + const_temp_matrix row_filter(_row_filter); + const_temp_matrix col_filter(_col_filter); + DLIB_ASSERT(row_filter.size() != 0 && col_filter.size() != 0 && + is_vector(row_filter) && + is_vector(col_filter), + "\trectangle float_spatially_filter_image_separable()" + << "\n\t Invalid inputs were given to this function." + << "\n\t row_filter.size(): "<< row_filter.size() + << "\n\t col_filter.size(): "<< col_filter.size() + << "\n\t is_vector(row_filter): "<< is_vector(row_filter) + << "\n\t is_vector(col_filter): "<< is_vector(col_filter) + ); + DLIB_ASSERT(is_same_object(in_img_, out_img_) == false, + "\trectangle float_spatially_filter_image_separable()" + << "\n\tYou must give two different image objects" + ); + + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return rectangle(); + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + // figure out the range that we should apply the filter to + const long first_row = col_filter.size()/2; + const long first_col = row_filter.size()/2; + const long last_row = in_img.nr() - ((col_filter.size()-1)/2); + const long last_col = in_img.nc() - ((row_filter.size()-1)/2); + + const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1); + if (!add_to) + zero_border_pixels(out_img, non_border); + + image_view scratch(scratch_); + scratch.set_size(in_img.nr(), in_img.nc()); + + // apply the row filter + for (long r = 0; r < in_img.nr(); ++r) + { + long c = first_col; + for (; c < last_col-7; c+=8) + { + simd8f p,p2,p3, temp = 0, temp2=0, temp3=0; + long n = 0; + for (; n < row_filter.size()-2; n+=3) + { + // pull out the current pixel and put it into p + p.load(&in_img[r][c-first_col+n]); + p2.load(&in_img[r][c-first_col+n+1]); + p3.load(&in_img[r][c-first_col+n+2]); + temp += p*row_filter(n); + temp2 += p2*row_filter(n+1); + temp3 += p3*row_filter(n+2); + } + for (; n < row_filter.size(); ++n) + { + // pull out the current pixel and put it into p + p.load(&in_img[r][c-first_col+n]); + temp += p*row_filter(n); + } + temp += temp2 + temp3; + temp.store(&scratch[r][c]); + } + for (; c < last_col; ++c) + { + float p; + float temp = 0; + for (long n = 0; n < row_filter.size(); ++n) + { + // pull out the current pixel and put it into p + p = in_img[r][c-first_col+n]; + temp += p*row_filter(n); + } + scratch[r][c] = temp; + } + } + + // apply the column filter + for (long r = first_row; r < last_row; ++r) + { + long c = first_col; + for (; c < last_col-7; c+=8) + { + simd8f p, p2, p3, temp = 0, temp2 = 0, temp3 = 0; + long m = 0; + for (; m < col_filter.size()-2; m+=3) + { + p.load(&scratch[r-first_row+m][c]); + p2.load(&scratch[r-first_row+m+1][c]); + p3.load(&scratch[r-first_row+m+2][c]); + temp += p*col_filter(m); + temp2 += p2*col_filter(m+1); + temp3 += p3*col_filter(m+2); + } + for (; m < col_filter.size(); ++m) + { + p.load(&scratch[r-first_row+m][c]); + temp += p*col_filter(m); + } + temp += temp2+temp3; + + // save this pixel to the output image + if (add_to == false) + { + temp.store(&out_img[r][c]); + } + else + { + p.load(&out_img[r][c]); + temp += p; + temp.store(&out_img[r][c]); + } + } + for (; c < last_col; ++c) + { + float temp = 0; + for (long m = 0; m < col_filter.size(); ++m) + { + temp += scratch[r-first_row+m][c]*col_filter(m); + } + + // save this pixel to the output image + if (add_to == false) + { + out_img[r][c] = temp; + } + else + { + out_img[r][c] += temp; + } + } + } + return non_border; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2, + typename T + > + typename enable_if_c::pixel_type>::grayscale && + is_float_filtering::value,rectangle>::type + spatially_filter_image_separable ( + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& row_filter, + const matrix_exp& col_filter, + T scale, + bool use_abs = false, + bool add_to = false + ) + { + if (use_abs == false) + { + out_image_type scratch; + if (scale == 1) + return float_spatially_filter_image_separable(in_img, out_img, row_filter, col_filter, scratch, add_to); + else + return float_spatially_filter_image_separable(in_img, out_img, row_filter/scale, col_filter, scratch, add_to); + } + else + { + return impl::grayscale_spatially_filter_image_separable(in_img, out_img, row_filter, col_filter, scale, true, add_to); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2, + typename T + > + typename enable_if_c::pixel_type>::grayscale && + !is_float_filtering::value,rectangle>::type + spatially_filter_image_separable ( + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& row_filter, + const matrix_exp& col_filter, + T scale, + bool use_abs = false, + bool add_to = false + ) + { + return impl::grayscale_spatially_filter_image_separable(in_img,out_img, row_filter, col_filter, scale, use_abs, add_to); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2, + typename T + > + typename disable_if_c::pixel_type>::grayscale,rectangle>::type + spatially_filter_image_separable ( + const in_image_type& in_img_, + out_image_type& out_img_, + const matrix_exp& _row_filter, + const matrix_exp& _col_filter, + T scale + ) + { + const_temp_matrix row_filter(_row_filter); + const_temp_matrix col_filter(_col_filter); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + + DLIB_ASSERT(scale != 0 && row_filter.size() != 0 && col_filter.size() != 0 && + is_vector(row_filter) && + is_vector(col_filter), + "\trectangle spatially_filter_image_separable()" + << "\n\t Invalid inputs were given to this function." + << "\n\t scale: "<< scale + << "\n\t row_filter.size(): "<< row_filter.size() + << "\n\t col_filter.size(): "<< col_filter.size() + << "\n\t is_vector(row_filter): "<< is_vector(row_filter) + << "\n\t is_vector(col_filter): "<< is_vector(col_filter) + ); + DLIB_ASSERT(is_same_object(in_img_, out_img_) == false, + "\trectangle spatially_filter_image_separable()" + << "\n\tYou must give two different image objects" + ); + + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return rectangle(); + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + + // figure out the range that we should apply the filter to + const long first_row = col_filter.size()/2; + const long first_col = row_filter.size()/2; + const long last_row = in_img.nr() - ((col_filter.size()-1)/2); + const long last_col = in_img.nc() - ((row_filter.size()-1)/2); + + const rectangle non_border = rectangle(first_col, first_row, last_col-1, last_row-1); + zero_border_pixels(out_img, non_border); + + typedef typename image_traits::pixel_type pixel_type; + typedef matrix::num,1> ptype; + + array2d temp_img; + temp_img.set_size(in_img.nr(), in_img.nc()); + + // apply the row filter + for (long r = 0; r < in_img.nr(); ++r) + { + for (long c = first_col; c < last_col; ++c) + { + ptype p; + ptype temp; + temp = 0; + for (long n = 0; n < row_filter.size(); ++n) + { + // pull out the current pixel and put it into p + p = pixel_to_vector(in_img[r][c-first_col+n]); + temp += p*row_filter(n); + } + temp_img[r][c] = temp; + } + } + + // apply the column filter + for (long r = first_row; r < last_row; ++r) + { + for (long c = first_col; c < last_col; ++c) + { + ptype temp; + temp = 0; + for (long m = 0; m < col_filter.size(); ++m) + { + temp += temp_img[r-first_row+m][c]*col_filter(m); + } + + temp /= scale; + + + // save this pixel to the output image + pixel_type p; + vector_to_pixel(p, temp); + assign_pixel(out_img[r][c], p); + } + } + return non_border; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2 + > + rectangle spatially_filter_image_separable ( + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& row_filter, + const matrix_exp& col_filter + ) + { + return spatially_filter_image_separable(in_img,out_img,row_filter,col_filter,1); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2, + typename T + > + rectangle spatially_filter_image_separable_down ( + const unsigned long downsample, + const in_image_type& in_img_, + out_image_type& out_img_, + const matrix_exp& row_filter, + const matrix_exp& col_filter, + T scale, + bool use_abs = false, + bool add_to = false + ) + { + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::grayscale == true ); + + DLIB_ASSERT(downsample > 0 && + scale != 0 && + row_filter.size()%2 == 1 && + col_filter.size()%2 == 1 && + is_vector(row_filter) && + is_vector(col_filter), + "\trectangle spatially_filter_image_separable_down()" + << "\n\t Invalid inputs were given to this function." + << "\n\t downsample: "<< downsample + << "\n\t scale: "<< scale + << "\n\t row_filter.size(): "<< row_filter.size() + << "\n\t col_filter.size(): "<< col_filter.size() + << "\n\t is_vector(row_filter): "<< is_vector(row_filter) + << "\n\t is_vector(col_filter): "<< is_vector(col_filter) + ); + DLIB_ASSERT(is_same_object(in_img_, out_img_) == false, + "\trectangle spatially_filter_image_separable_down()" + << "\n\tYou must give two different image objects" + ); + + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return rectangle(); + } + + out_img.set_size((long)(std::ceil((double)in_img.nr()/downsample)), + (long)(std::ceil((double)in_img.nc()/downsample))); + + const double col_border = std::floor(col_filter.size()/2.0); + const double row_border = std::floor(row_filter.size()/2.0); + + // figure out the range that we should apply the filter to + const long first_row = (long)std::ceil(col_border/downsample); + const long first_col = (long)std::ceil(row_border/downsample); + const long last_row = (long)std::ceil((in_img.nr() - col_border)/downsample) - 1; + const long last_col = (long)std::ceil((in_img.nc() - row_border)/downsample) - 1; + + // zero border pixels + const rectangle non_border = rectangle(first_col, first_row, last_col, last_row); + zero_border_pixels(out_img,non_border); + + typedef typename EXP1::type ptype; + + array2d temp_img; + temp_img.set_size(in_img.nr(), out_img.nc()); + + // apply the row filter + for (long r = 0; r < temp_img.nr(); ++r) + { + for (long c = non_border.left(); c <= non_border.right(); ++c) + { + ptype p; + ptype temp = 0; + for (long n = 0; n < row_filter.size(); ++n) + { + // pull out the current pixel and put it into p + p = get_pixel_intensity(in_img[r][c*downsample-row_filter.size()/2+n]); + temp += p*row_filter(n); + } + temp_img[r][c] = temp; + } + } + + // apply the column filter + for (long r = non_border.top(); r <= non_border.bottom(); ++r) + { + for (long c = non_border.left(); c <= non_border.right(); ++c) + { + ptype temp = 0; + for (long m = 0; m < col_filter.size(); ++m) + { + temp += temp_img[r*downsample-col_filter.size()/2+m][c]*col_filter(m); + } + + temp /= scale; + + if (use_abs && temp < 0) + { + temp = -temp; + } + + // save this pixel to the output image + if (add_to == false) + { + assign_pixel(out_img[r][c], temp); + } + else + { + assign_pixel(out_img[r][c], temp + out_img[r][c]); + } + } + } + + return non_border; + } + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2 + > + rectangle spatially_filter_image_separable_down ( + const unsigned long downsample, + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& row_filter, + const matrix_exp& col_filter + ) + { + return spatially_filter_image_separable_down(downsample,in_img,out_img,row_filter,col_filter,1); + } + +// ---------------------------------------------------------------------------------------- + + template < + long NR, + long NC, + typename T, + typename U, + typename in_image_type + > + inline void separable_3x3_filter_block_grayscale ( + T (&block)[NR][NC], + const in_image_type& img_, + const long& r, + const long& c, + const U& fe1, // separable filter end + const U& fm, // separable filter middle + const U& fe2 // separable filter end 2 + ) + { + const_image_view img(img_); + // make sure requires clause is not broken + DLIB_ASSERT(shrink_rect(get_rect(img),1).contains(c,r) && + shrink_rect(get_rect(img),1).contains(c+NC-1,r+NR-1), + "\t void separable_3x3_filter_block_grayscale()" + << "\n\t The sub-window doesn't fit inside the given image." + << "\n\t get_rect(img): " << get_rect(img) + << "\n\t (c,r): " << point(c,r) + << "\n\t (c+NC-1,r+NR-1): " << point(c+NC-1,r+NR-1) + ); + + + T row_filt[NR+2][NC]; + for (long rr = 0; rr < NR+2; ++rr) + { + for (long cc = 0; cc < NC; ++cc) + { + row_filt[rr][cc] = get_pixel_intensity(img[r+rr-1][c+cc-1])*fe1 + + get_pixel_intensity(img[r+rr-1][c+cc])*fm + + get_pixel_intensity(img[r+rr-1][c+cc+1])*fe2; + } + } + + for (long rr = 0; rr < NR; ++rr) + { + for (long cc = 0; cc < NC; ++cc) + { + block[rr][cc] = (row_filt[rr][cc]*fe1 + + row_filt[rr+1][cc]*fm + + row_filt[rr+2][cc]*fe2); + } + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + long NR, + long NC, + typename T, + typename U, + typename in_image_type + > + inline void separable_3x3_filter_block_rgb ( + T (&block)[NR][NC], + const in_image_type& img_, + const long& r, + const long& c, + const U& fe1, // separable filter end + const U& fm, // separable filter middle + const U& fe2 // separable filter end 2 + ) + { + const_image_view img(img_); + // make sure requires clause is not broken + DLIB_ASSERT(shrink_rect(get_rect(img),1).contains(c,r) && + shrink_rect(get_rect(img),1).contains(c+NC-1,r+NR-1), + "\t void separable_3x3_filter_block_rgb()" + << "\n\t The sub-window doesn't fit inside the given image." + << "\n\t get_rect(img): " << get_rect(img) + << "\n\t (c,r): " << point(c,r) + << "\n\t (c+NC-1,r+NR-1): " << point(c+NC-1,r+NR-1) + ); + + T row_filt[NR+2][NC]; + for (long rr = 0; rr < NR+2; ++rr) + { + for (long cc = 0; cc < NC; ++cc) + { + row_filt[rr][cc].red = img[r+rr-1][c+cc-1].red*fe1 + img[r+rr-1][c+cc].red*fm + img[r+rr-1][c+cc+1].red*fe2; + row_filt[rr][cc].green = img[r+rr-1][c+cc-1].green*fe1 + img[r+rr-1][c+cc].green*fm + img[r+rr-1][c+cc+1].green*fe2; + row_filt[rr][cc].blue = img[r+rr-1][c+cc-1].blue*fe1 + img[r+rr-1][c+cc].blue*fm + img[r+rr-1][c+cc+1].blue*fe2; + } + } + + for (long rr = 0; rr < NR; ++rr) + { + for (long cc = 0; cc < NC; ++cc) + { + block[rr][cc].red = row_filt[rr][cc].red*fe1 + row_filt[rr+1][cc].red*fm + row_filt[rr+2][cc].red*fe2; + block[rr][cc].green = row_filt[rr][cc].green*fe1 + row_filt[rr+1][cc].green*fm + row_filt[rr+2][cc].green*fe2; + block[rr][cc].blue = row_filt[rr][cc].blue*fe1 + row_filt[rr+1][cc].blue*fm + row_filt[rr+2][cc].blue*fe2; + } + } + + } + +// ---------------------------------------------------------------------------------------- + + inline double gaussian ( + double x, + double sigma + ) + { + DLIB_ASSERT(sigma > 0, + "\tdouble gaussian(x)" + << "\n\t sigma must be bigger than 0" + << "\n\t sigma: " << sigma + ); + const double sqrt_2_pi = 2.5066282746310002416123552393401041626930; + return 1.0/(sigma*sqrt_2_pi) * std::exp( -(x*x)/(2*sigma*sigma)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + matrix create_gaussian_filter ( + double sigma, + int max_size + ) + { + DLIB_ASSERT(sigma > 0 && max_size > 0 && (max_size%2)==1, + "\t matrix create_gaussian_filter()" + << "\n\t Invalid inputs were given to this function." + << "\n\t sigma: " << sigma + << "\n\t max_size: " << max_size + ); + + // Adjust the size so that the ratio of the gaussian values isn't huge. + // This only matters when T is an integer type. However, we do it for + // all types so that the behavior of this function is always relatively + // the same. + while (gaussian(0,sigma)/gaussian(max_size/2,sigma) > 50) + --max_size; + + + matrix f(max_size); + for (long i = 0; i < f.size(); ++i) + { + f(i) = gaussian(i-max_size/2, sigma); + } + + if (is_float_type::value == false) + { + f /= f(0); + return matrix_cast(round(f)); + } + else + { + return matrix_cast(f); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + rectangle gaussian_blur ( + const in_image_type& in_img, + out_image_type& out_img, + double sigma = 1, + int max_size = 1001 + ) + { + DLIB_ASSERT(sigma > 0 && max_size > 0 && (max_size%2)==1 && + is_same_object(in_img, out_img) == false, + "\t void gaussian_blur()" + << "\n\t Invalid inputs were given to this function." + << "\n\t sigma: " << sigma + << "\n\t max_size: " << max_size + << "\n\t is_same_object(in_img,out_img): " << is_same_object(in_img,out_img) + ); + + if (sigma < 18) + { + typedef typename pixel_traits::pixel_type>::basic_pixel_type type; + typedef typename promote::type ptype; + const matrix& filt = create_gaussian_filter(sigma, max_size); + ptype scale = sum(filt); + scale = scale*scale; + return spatially_filter_image_separable(in_img, out_img, filt, filt, scale); + } + else + { + // For large sigma we need to use a type with a lot of precision to avoid + // numerical problems. So we use double here. + typedef double ptype; + const matrix& filt = create_gaussian_filter(sigma, max_size); + ptype scale = sum(filt); + scale = scale*scale; + return spatially_filter_image_separable(in_img, out_img, filt, filt, scale); + } + + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + bool add_to, + typename image_type1, + typename image_type2 + > + void sum_filter ( + const image_type1& img_, + image_type2& out_, + const rectangle& rect + ) + { + const_image_view img(img_); + image_view out(out_); + DLIB_ASSERT(img.nr() == out.nr() && + img.nc() == out.nc() && + is_same_object(img_,out_) == false, + "\t void sum_filter()" + << "\n\t Invalid arguments given to this function." + << "\n\t img.nr(): " << img.nr() + << "\n\t img.nc(): " << img.nc() + << "\n\t out.nr(): " << out.nr() + << "\n\t out.nc(): " << out.nc() + << "\n\t is_same_object(img_,out_): " << is_same_object(img_,out_) + ); + + typedef typename image_traits::pixel_type pixel_type; + typedef typename promote::type ptype; + + std::vector column_sum; + column_sum.resize(img.nc() + rect.width(),0); + + const long top = -1 + rect.top(); + const long bottom = -1 + rect.bottom(); + long left = rect.left()-1; + + // initialize column_sum at row -1 + for (unsigned long j = 0; j < column_sum.size(); ++j) + { + rectangle strip(left,top,left,bottom); + strip = strip.intersect(get_rect(img)); + if (!strip.is_empty()) + { + column_sum[j] = sum(matrix_cast(subm(mat(img),strip))); + } + + ++left; + } + + + const rectangle area = get_rect(img); + + // Save width to avoid computing it over and over. + const long width = rect.width(); + + + // Now do the bulk of the filtering work. + for (long r = 0; r < img.nr(); ++r) + { + // set to sum at point(-1,r). i.e. should be equal to sum(mat(img), translate_rect(rect, point(-1,r))) + // We compute it's value in the next loop. + ptype cur_sum = 0; + + // Update the first part of column_sum since we only work on the c+width part of column_sum + // in the main loop. + const long top = r + rect.top() - 1; + const long bottom = r + rect.bottom(); + for (long k = 0; k < width; ++k) + { + const long right = k-width + rect.right(); + + const ptype br_corner = area.contains(right,bottom) ? img[bottom][right] : 0; + const ptype tr_corner = area.contains(right,top) ? img[top][right] : 0; + // update the sum in this column now that we are on the next row + column_sum[k] = column_sum[k] + br_corner - tr_corner; + cur_sum += column_sum[k]; + } + + for (long c = 0; c < img.nc(); ++c) + { + const long top = r + rect.top() - 1; + const long bottom = r + rect.bottom(); + const long right = c + rect.right(); + + const ptype br_corner = area.contains(right,bottom) ? img[bottom][right] : 0; + const ptype tr_corner = area.contains(right,top) ? img[top][right] : 0; + + // update the sum in this column now that we are on the next row + column_sum[c+width] = column_sum[c+width] + br_corner - tr_corner; + + // add in the new right side of the rect and subtract the old right side. + cur_sum = cur_sum + column_sum[c+width] - column_sum[c]; + + if (add_to) + out[r][c] += static_cast::pixel_type>(cur_sum); + else + out[r][c] = static_cast::pixel_type>(cur_sum); + } + } + } + } + + template < + typename image_type1, + typename image_type2 + > + void sum_filter ( + const image_type1& img, + image_type2& out, + const rectangle& rect + ) + { + impl::sum_filter(img,out,rect); + } + + template < + typename image_type1, + typename image_type2 + > + void sum_filter_assign ( + const image_type1& img, + image_type2& out, + const rectangle& rect + ) + { + set_image_size(out, num_rows(img), num_columns(img)); + impl::sum_filter(img,out,rect); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + class fast_deque + { + /* + This is a fast and minimal implementation of std::deque for + use with the max_filter. + + This object assumes that no more than max_size elements + will ever be pushed into it at a time. + */ + public: + + explicit fast_deque(unsigned long max_size) + { + // find a power of two that upper bounds max_size + mask = 2; + while (mask < max_size) + mask *= 2; + + clear(); + + data.resize(mask); + --mask; // make into bit mask + } + + void clear() + { + first = 1; + last = 0; + size = 0; + } + + bool empty() const + { + return size == 0; + } + + void pop_back() + { + last = (last-1)&mask; + --size; + } + + void push_back(const T& item) + { + last = (last+1)&mask; + ++size; + data[last] = item; + } + + void pop_front() + { + first = (first+1)&mask; + --size; + } + + const T& front() const + { + return data[first]; + } + + const T& back() const + { + return data[last]; + } + + private: + + std::vector data; + unsigned long mask; + unsigned long first; + unsigned long last; + unsigned long size; + }; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void max_filter ( + image_type1& img_, + image_type2& out_, + const long width, + const long height, + const typename image_traits::pixel_type& thresh + ) + { + image_view img(img_); + image_view out(out_); + DLIB_ASSERT( width > 0 && + height > 0 && + out.nr() == img.nr() && + out.nc() == img.nc() && + is_same_object(img_,out_) == false, + "\t void max_filter()" + << "\n\t Invalid arguments given to this function." + << "\n\t img.nr(): " << img.nr() + << "\n\t img.nc(): " << img.nc() + << "\n\t out.nr(): " << out.nr() + << "\n\t out.nc(): " << out.nc() + << "\n\t width: " << width + << "\n\t height: " << height + << "\n\t is_same_object(img_,out_): " << is_same_object(img_,out_) + ); + + typedef typename image_traits::pixel_type pixel_type; + + + dlib::impl::fast_deque > Q(std::max(width,height)); + + const long last_col = std::max(img.nc(), ((width-1)/2)); + const long last_row = std::max(img.nr(), ((height-1)/2)); + + // run max filter along rows of img + for (long r = 0; r < img.nr(); ++r) + { + Q.clear(); + for (long c = 0; c < (width-1)/2 && c < img.nc(); ++c) + { + while (!Q.empty() && img[r][c] >= Q.back().second) + Q.pop_back(); + Q.push_back(std::make_pair(c,img[r][c])); + } + + for (long c = (width-1)/2; c < img.nc(); ++c) + { + while (!Q.empty() && img[r][c] >= Q.back().second) + Q.pop_back(); + while (!Q.empty() && Q.front().first <= c-width) + Q.pop_front(); + Q.push_back(std::make_pair(c,img[r][c])); + + img[r][c-((width-1)/2)] = Q.front().second; + } + + for (long c = last_col; c < img.nc() + ((width-1)/2); ++c) + { + while (!Q.empty() && Q.front().first <= c-width) + Q.pop_front(); + + img[r][c-((width-1)/2)] = Q.front().second; + } + } + + // run max filter along columns of img. Store result in out. + for (long cc = 0; cc < img.nc(); ++cc) + { + Q.clear(); + for (long rr = 0; rr < (height-1)/2 && rr < img.nr(); ++rr) + { + while (!Q.empty() && img[rr][cc] >= Q.back().second) + Q.pop_back(); + Q.push_back(std::make_pair(rr,img[rr][cc])); + } + + for (long rr = (height-1)/2; rr < img.nr(); ++rr) + { + while (!Q.empty() && img[rr][cc] >= Q.back().second) + Q.pop_back(); + while (!Q.empty() && Q.front().first <= rr-height) + Q.pop_front(); + Q.push_back(std::make_pair(rr,img[rr][cc])); + + out[rr-((height-1)/2)][cc] += std::max(Q.front().second, thresh); + } + + for (long rr = last_row; rr < img.nr() + ((height-1)/2); ++rr) + { + while (!Q.empty() && Q.front().first <= rr-height) + Q.pop_front(); + + out[rr-((height-1)/2)][cc] += std::max(Q.front().second, thresh); + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SPATIAL_FILTERINg_H_ + + diff --git a/ml/dlib/dlib/image_transforms/spatial_filtering_abstract.h b/ml/dlib/dlib/image_transforms/spatial_filtering_abstract.h new file mode 100644 index 000000000..5e200aa9a --- /dev/null +++ b/ml/dlib/dlib/image_transforms/spatial_filtering_abstract.h @@ -0,0 +1,487 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SPATIAL_FILTERINg_ABSTRACT_ +#ifdef DLIB_SPATIAL_FILTERINg_ABSTRACT_ + +#include "../pixel.h" +#include "../matrix.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP, + typename T + > + rectangle spatially_filter_image ( + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& filter, + T scale = 1, + bool use_abs = false, + bool add_to = false + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - in_img and out_img do not contain pixels with an alpha channel. That is, + pixel_traits::has_alpha is false for the pixels in these objects. + - is_same_object(in_img, out_img) == false + - T must be some scalar type + - filter.size() != 0 + - scale != 0 + - if (in_img doesn't contain grayscale pixels) then + - use_abs == false && add_to == false + (i.e. You can only use the use_abs and add_to options with grayscale images) + ensures + - Applies the given spatial filter to in_img and stores the result in out_img (i.e. + cross-correlates in_img with filter). Also divides each resulting pixel by scale. + - The intermediate filter computations will be carried out using variables of type EXP::type. + This is whatever scalar type is used inside the filter matrix. + - Pixel values are stored into out_img using the assign_pixel() function and therefore + any applicable color space conversion or value saturation is performed. Note that if + add_to is true then the filtered output value will be added to out_img rather than + overwriting the original value. + - if (in_img doesn't contain grayscale pixels) then + - The filter is applied to each color channel independently. + - if (use_abs == true) then + - pixel values after filtering that are < 0 are converted to their absolute values. + - The filter is applied such that it's centered over the pixel it writes its + output into. For centering purposes, we consider the center element of the + filter to be filter(filter.nr()/2,filter.nc()/2). This means that the filter + that writes its output to a pixel at location point(c,r) and is W by H (width + by height) pixels in size operates on exactly the pixels in the rectangle + centered_rect(point(c,r),W,H) within in_img. + - Pixels close enough to the edge of in_img to not have the filter still fit + inside the image are always set to zero. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + - returns a rectangle which indicates what pixels in #out_img are considered + non-border pixels and therefore contain output from the filter. + - if (use_abs == false && all images and filers contain float types) then + - This function will use SIMD instructions and is particularly fast. So if + you can use this form of the function it can give a decent speed boost. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2, + typename T + > + rectangle spatially_filter_image_separable ( + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& row_filter, + const matrix_exp& col_filter, + T scale = 1, + bool use_abs = false, + bool add_to = false + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - in_img and out_img do not contain pixels with an alpha channel. That is, + pixel_traits::has_alpha is false for the pixels in these objects. + - is_same_object(in_img, out_img) == false + - T must be some scalar type + - scale != 0 + - row_filter.size() != 0 + - col_filter.size() != 0 + - is_vector(row_filter) == true + - is_vector(col_filter) == true + - if (in_img doesn't contain grayscale pixels) then + - use_abs == false && add_to == false + (i.e. You can only use the use_abs and add_to options with grayscale images) + ensures + - Applies the given separable spatial filter to in_img and stores the result in out_img. + Also divides each resulting pixel by scale. Calling this function has the same + effect as calling the regular spatially_filter_image() routine with a filter, + FILT, defined as follows: + - FILT(r,c) == col_filter(r)*row_filter(c) + - The intermediate filter computations will be carried out using variables of type EXP1::type. + This is whatever scalar type is used inside the row_filter matrix. + - Pixel values are stored into out_img using the assign_pixel() function and therefore + any applicable color space conversion or value saturation is performed. Note that if + add_to is true then the filtered output value will be added to out_img rather than + overwriting the original value. + - if (in_img doesn't contain grayscale pixels) then + - The filter is applied to each color channel independently. + - if (use_abs == true) then + - pixel values after filtering that are < 0 are converted to their absolute values + - The filter is applied such that it's centered over the pixel it writes its + output into. For centering purposes, we consider the center element of the + filter to be FILT(col_filter.size()/2,row_filter.size()/2). This means that + the filter that writes its output to a pixel at location point(c,r) and is W + by H (width by height) pixels in size operates on exactly the pixels in the + rectangle centered_rect(point(c,r),W,H) within in_img. + - Pixels close enough to the edge of in_img to not have the filter still fit + inside the image are always set to zero. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + - returns a rectangle which indicates what pixels in #out_img are considered + non-border pixels and therefore contain output from the filter. + - if (use_abs == false && all images and filers contain float types) then + - This function will use SIMD instructions and is particularly fast. So if + you can use this form of the function it can give a decent speed boost. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2 + > + rectangle float_spatially_filter_image_separable ( + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& row_filter, + const matrix_exp& col_filter, + out_image_type& scratch, + bool add_to = false + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - in_img, out_img, row_filter, and col_filter must all contain float type elements. + - is_same_object(in_img, out_img) == false + - row_filter.size() != 0 + - col_filter.size() != 0 + - is_vector(row_filter) == true + - is_vector(col_filter) == true + ensures + - This function is identical to the above spatially_filter_image_separable() + function except that it can only be invoked on float images with float + filters. In fact, spatially_filter_image_separable() invokes + float_spatially_filter_image_separable() in those cases. So why is + float_spatially_filter_image_separable() in the public API? The reason is + because the separable filtering routines internally allocate an image each + time they are called. If you want to avoid this memory allocation then you + can call float_spatially_filter_image_separable() and provide the scratch + image as input. This allows you to reuse the same scratch image for many + calls to float_spatially_filter_image_separable() and thereby avoid having it + allocated and freed for each call. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2, + typename T + > + rectangle spatially_filter_image_separable_down ( + const unsigned long downsample, + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& row_filter, + const matrix_exp& col_filter, + T scale = 1, + bool use_abs = false, + bool add_to = false + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - in_img and out_img do not contain pixels with an alpha channel. That is, + pixel_traits::has_alpha is false for the pixels in these objects. + - out_img contains grayscale pixels. + - is_same_object(in_img, out_img) == false + - T must be some scalar type + - scale != 0 + - is_vector(row_filter) == true + - is_vector(col_filter) == true + - row_filter.size() % 2 == 1 (i.e. must be odd) + - col_filter.size() % 2 == 1 (i.e. must be odd) + - downsample > 0 + ensures + - This function is equivalent to calling + spatially_filter_image_separable(in_img,out_img,row_filter,col_filter,scale,use_abs,add_to) + and then downsampling the output image by a factor of downsample. Therefore, + we will have that: + - #out_img.nr() == ceil((double)in_img.nr()/downsample) + - #out_img.nc() == ceil((double)in_img.nc()/downsample) + - #out_img[r][c] == filtered pixel corresponding to in_img[r*downsample][c*downsample] + - returns a rectangle which indicates what pixels in #out_img are considered + non-border pixels and therefore contain output from the filter. + - Note that the first row and column of non-zero padded data are the following + - first_row == ceil(floor(col_filter.size()/2.0)/downsample) + - first_col == ceil(floor(row_filter.size()/2.0)/downsample) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + long NR, + long NC, + typename T, + typename U, + typename in_image_type + > + inline void separable_3x3_filter_block_grayscale ( + T (&block)[NR][NC], + const in_image_type& img, + const long& r, + const long& c, + const U& fe1, + const U& fm, + const U& fe2 + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - T and U should be scalar types + - shrink_rect(get_rect(img),1).contains(c,r) + - shrink_rect(get_rect(img),1).contains(c+NC-1,r+NR-1) + ensures + - Filters the image in the sub-window of img defined by a rectangle + with its upper left corner at (c,r) and lower right at (c+NC-1,r+NR-1). + - The output of the filter is stored in #block. Note that img will be + interpreted as a grayscale image. + - The filter used is defined by the separable filter [fe1 fm fe2]. So the + spatial filter is thus: + fe1*fe1 fe1*fm fe2*fe1 + fe1*fm fm*fm fe2*fm + fe1*fe2 fe2*fm fe2*fe2 + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + long NR, + long NC, + typename T, + typename U, + typename in_image_type + > + inline void separable_3x3_filter_block_rgb ( + T (&block)[NR][NC], + const in_image_type& img, + const long& r, + const long& c, + const U& fe1, + const U& fm, + const U& fe2 + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - img must contain RGB pixels, that is pixel_traits::rgb == true for the pixels + in img. + - T should be a struct with .red .green and .blue members. + - U should be a scalar type + - shrink_rect(get_rect(img),1).contains(c,r) + - shrink_rect(get_rect(img),1).contains(c+NC-1,r+NR-1) + ensures + - Filters the image in the sub-window of img defined by a rectangle + with its upper left corner at (c,r) and lower right at (c+NC-1,r+NR-1). + - The output of the filter is stored in #block. Note that the filter is applied + to each color component independently. + - The filter used is defined by the separable filter [fe1 fm fe2]. So the + spatial filter is thus: + fe1*fe1 fe1*fm fe2*fe1 + fe1*fm fm*fm fe2*fm + fe1*fe2 fe2*fm fe2*fe2 + !*/ + +// ---------------------------------------------------------------------------------------- + + inline double gaussian ( + double x, + double sigma + ); + /*! + requires + - sigma > 0 + ensures + - computes and returns the value of a 1D Gaussian function with mean 0 + and standard deviation sigma at the given x value. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + matrix create_gaussian_filter ( + double sigma, + int size + ); + /*! + requires + - sigma > 0 + - size > 0 + - size is an odd number + ensures + - returns a separable Gaussian filter F such that: + - is_vector(F) == true + - F.size() == size + - F is suitable for use with the spatially_filter_image_separable() routine + and its use with this function corresponds to running a Gaussian filter + of sigma width over an image. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + rectangle gaussian_blur ( + const in_image_type& in_img, + out_image_type& out_img, + double sigma = 1, + int max_size = 1001 + ); + /*! + requires + - in_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - out_image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h + - in_img and out_img do not contain pixels with an alpha channel. That is, + pixel_traits::has_alpha is false for the pixels in these objects. + - is_same_object(in_img, out_img) == false + - sigma > 0 + - max_size > 0 + - max_size is an odd number + ensures + - Filters in_img with a Gaussian filter of sigma width. The actual spatial filter will + be applied to pixel blocks that are at most max_size wide and max_size tall (note that + this function will automatically select a smaller block size as appropriate). The + results are stored into #out_img. + - Pixel values are stored into out_img using the assign_pixel() function and therefore + any applicable color space conversion or value saturation is performed. + - if (in_img doesn't contain grayscale pixels) then + - The filter is applied to each color channel independently. + - Pixels close enough to the edge of in_img to not have the filter still fit + inside the image are set to zero. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + - returns a rectangle which indicates what pixels in #out_img are considered + non-border pixels and therefore contain output from the filter. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void sum_filter ( + const image_type1& img, + image_type2& out, + const rectangle& rect + ); + /*! + requires + - out.nr() == img.nr() + - out.nc() == img.nc() + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain grayscale pixels. + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain grayscale pixels. + - is_same_object(img,out) == false + ensures + - for all valid r and c: + - let SUM(r,c) == sum of pixels from img which are inside the rectangle + translate_rect(rect, point(c,r)). + - #out[r][c] == out[r][c] + SUM(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void sum_filter_assign ( + const image_type1& img, + image_type2& out, + const rectangle& rect + ); + /*! + requires + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain grayscale pixels. + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain grayscale pixels. + - is_same_object(img,out) == false + ensures + - #out.nr() == img.nr() + - #out.nc() == img.nc() + - for all valid r and c: + - let SUM(r,c) == sum of pixels from img which are inside the rectangle + translate_rect(rect, point(c,r)). + - #out[r][c] == SUM(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void max_filter ( + image_type1& img, + image_type2& out, + const long width, + const long height, + const typename image_traits::pixel_type& thresh + ); + /*! + requires + - out.nr() == img.nr() + - out.nc() == img.nc() + - image_type1 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain grayscale pixels. + - image_type2 == an image object that implements the interface defined in + dlib/image_processing/generic_image.h and it must contain grayscale pixels. + - is_same_object(img,out) == false + - width > 0 && height > 0 + ensures + - for all valid r and c: + - let MAX(r,c) == maximum of pixels from img which are inside the rectangle + centered_rect(point(c,r), width, height) + - if (MAX(r,c) >= thresh) + - #out[r][c] == out[r][c] + MAX(r,c) + - else + - #out[r][c] == out[r][c] + thresh + - Does not change the size of img. + - Uses img as scratch space. Therefore, the pixel values in img will have + been modified by this function. That is, max_filter() destroys the contents + of img. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SPATIAL_FILTERINg_ABSTRACT_ + diff --git a/ml/dlib/dlib/image_transforms/thresholding.h b/ml/dlib/dlib/image_transforms/thresholding.h new file mode 100644 index 000000000..e4fb02c4a --- /dev/null +++ b/ml/dlib/dlib/image_transforms/thresholding.h @@ -0,0 +1,340 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THRESHOLDINg_ +#define DLIB_THRESHOLDINg_ + +#include "../pixel.h" +#include "thresholding_abstract.h" +#include "equalize_histogram.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + const unsigned char on_pixel = 255; + const unsigned char off_pixel = 0; + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void threshold_image ( + const in_image_type& in_img_, + out_image_type& out_img_, + typename pixel_traits::pixel_type>::basic_pixel_type thresh + ) + { + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + + COMPILE_TIME_ASSERT(pixel_traits::pixel_type>::grayscale); + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(in_img.nr(),in_img.nc()); + + for (long r = 0; r < in_img.nr(); ++r) + { + for (long c = 0; c < in_img.nc(); ++c) + { + if (get_pixel_intensity(in_img[r][c]) >= thresh) + assign_pixel(out_img[r][c], on_pixel); + else + assign_pixel(out_img[r][c], off_pixel); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + void threshold_image ( + image_type& img, + typename pixel_traits::pixel_type>::basic_pixel_type thresh + ) + { + threshold_image(img,img,thresh); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void auto_threshold_image ( + const in_image_type& in_img_, + out_image_type& out_img_ + ) + { + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::is_unsigned == true ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::is_unsigned == true ); + + COMPILE_TIME_ASSERT(pixel_traits::pixel_type>::grayscale); + + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (image_size(in_img_) == 0) + { + out_img.clear(); + return; + } + + unsigned long thresh; + // find the threshold we should use + matrix hist; + get_histogram(in_img_,hist); + + const_image_view in_img(in_img_); + + // Start our two means (a and b) out at the ends of the histogram + long a = 0; + long b = hist.size()-1; + bool moved_a = true; + bool moved_b = true; + while (moved_a || moved_b) + { + moved_a = false; + moved_b = false; + + // catch the degenerate case where the histogram is empty + if (a >= b) + break; + + if (hist(a) == 0) + { + ++a; + moved_a = true; + } + + if (hist(b) == 0) + { + --b; + moved_b = true; + } + } + + // now do k-means clustering with k = 2 on the histogram. + moved_a = true; + moved_b = true; + while (moved_a || moved_b) + { + moved_a = false; + moved_b = false; + + int64 a_hits = 0; + int64 b_hits = 0; + int64 a_mass = 0; + int64 b_mass = 0; + + for (long i = 0; i < hist.size(); ++i) + { + // if i is closer to a + if (std::abs(i-a) < std::abs(i-b)) + { + a_mass += hist(i)*i; + a_hits += hist(i); + } + else // if i is closer to b + { + b_mass += hist(i)*i; + b_hits += hist(i); + } + } + + long new_a = (a_mass + a_hits/2)/a_hits; + long new_b = (b_mass + b_hits/2)/b_hits; + + if (new_a != a) + { + moved_a = true; + a = new_a; + } + + if (new_b != b) + { + moved_b = true; + b = new_b; + } + } + + // put the threshold between the two means we found + thresh = (a + b)/2; + + // now actually apply the threshold + threshold_image(in_img_,out_img_,thresh); + } + + template < + typename image_type + > + void auto_threshold_image ( + image_type& img + ) + { + auto_threshold_image(img,img); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void hysteresis_threshold ( + const in_image_type& in_img_, + out_image_type& out_img_, + typename pixel_traits::pixel_type>::basic_pixel_type lower_thresh, + typename pixel_traits::pixel_type>::basic_pixel_type upper_thresh + ) + { + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + COMPILE_TIME_ASSERT( pixel_traits::pixel_type>::has_alpha == false ); + + COMPILE_TIME_ASSERT(pixel_traits::pixel_type>::grayscale); + + DLIB_ASSERT( lower_thresh <= upper_thresh && is_same_object(in_img_, out_img_) == false, + "\tvoid hysteresis_threshold(in_img_, out_img_, lower_thresh, upper_thresh)" + << "\n\tYou can't use an upper_thresh that is less than your lower_thresh" + << "\n\tlower_thresh: " << lower_thresh + << "\n\tupper_thresh: " << upper_thresh + << "\n\tis_same_object(in_img_,out_img_): " << is_same_object(in_img_,out_img_) + ); + + const_image_view in_img(in_img_); + image_view out_img(out_img_); + + // if there isn't any input image then don't do anything + if (in_img.size() == 0) + { + out_img.clear(); + return; + } + + out_img.set_size(in_img.nr(),in_img.nc()); + assign_all_pixels(out_img, off_pixel); + + const long size = 1000; + long rstack[size]; + long cstack[size]; + + // now do the thresholding + for (long r = 0; r < in_img.nr(); ++r) + { + for (long c = 0; c < in_img.nc(); ++c) + { + typename pixel_traits::pixel_type>::basic_pixel_type p; + assign_pixel(p,in_img[r][c]); + if (p >= upper_thresh) + { + // now do line following for pixels >= lower_thresh. + // set the stack position to 0. + long pos = 1; + rstack[0] = r; + cstack[0] = c; + + while (pos > 0) + { + --pos; + const long r = rstack[pos]; + const long c = cstack[pos]; + + // This is the base case of our recursion. We want to stop if we hit a + // pixel we have already visited. + if (out_img[r][c] == on_pixel) + continue; + + out_img[r][c] = on_pixel; + + // put the neighbors of this pixel on the stack if they are bright enough + if (r-1 >= 0) + { + if (pos < size && get_pixel_intensity(in_img[r-1][c]) >= lower_thresh) + { + rstack[pos] = r-1; + cstack[pos] = c; + ++pos; + } + if (pos < size && c-1 >= 0 && get_pixel_intensity(in_img[r-1][c-1]) >= lower_thresh) + { + rstack[pos] = r-1; + cstack[pos] = c-1; + ++pos; + } + if (pos < size && c+1 < in_img.nc() && get_pixel_intensity(in_img[r-1][c+1]) >= lower_thresh) + { + rstack[pos] = r-1; + cstack[pos] = c+1; + ++pos; + } + } + + if (pos < size && c-1 >= 0 && get_pixel_intensity(in_img[r][c-1]) >= lower_thresh) + { + rstack[pos] = r; + cstack[pos] = c-1; + ++pos; + } + if (pos < size && c+1 < in_img.nc() && get_pixel_intensity(in_img[r][c+1]) >= lower_thresh) + { + rstack[pos] = r; + cstack[pos] = c+1; + ++pos; + } + + if (r+1 < in_img.nr()) + { + if (pos < size && get_pixel_intensity(in_img[r+1][c]) >= lower_thresh) + { + rstack[pos] = r+1; + cstack[pos] = c; + ++pos; + } + if (pos < size && c-1 >= 0 && get_pixel_intensity(in_img[r+1][c-1]) >= lower_thresh) + { + rstack[pos] = r+1; + cstack[pos] = c-1; + ++pos; + } + if (pos < size && c+1 < in_img.nc() && get_pixel_intensity(in_img[r+1][c+1]) >= lower_thresh) + { + rstack[pos] = r+1; + cstack[pos] = c+1; + ++pos; + } + } + + } // end while (pos >= 0) + + } + else + { + out_img[r][c] = off_pixel; + } + + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THRESHOLDINg_ + diff --git a/ml/dlib/dlib/image_transforms/thresholding_abstract.h b/ml/dlib/dlib/image_transforms/thresholding_abstract.h new file mode 100644 index 000000000..e7c1e8826 --- /dev/null +++ b/ml/dlib/dlib/image_transforms/thresholding_abstract.h @@ -0,0 +1,139 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_THRESHOLDINg_ABSTRACT_ +#ifdef DLIB_THRESHOLDINg_ABSTRACT_ + +#include "../pixel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + const unsigned char on_pixel = 255; + const unsigned char off_pixel = 0; + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void threshold_image ( + const in_image_type& in_img, + out_image_type& out_img, + typename pixel_traits::pixel_type>::basic_pixel_type thresh + ); + /*! + requires + - in_image_type == is an implementation of array2d/array2d_kernel_abstract.h + - out_image_type == is an implementation of array2d/array2d_kernel_abstract.h + - pixel_traits::pixel_type>::grayscale == true + - pixel_traits::pixel_type>::has_alpha == false + - pixel_traits::pixel_type>::has_alpha == false + ensures + - #out_img == the thresholded version of in_img (in_img is converted to a grayscale + intensity image if it is color). Pixels in in_img with grayscale values >= thresh + have an output value of on_pixel and all others have a value of off_pixel. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + + template < + typename image_type + > + void threshold_image ( + image_type& img, + typename pixel_traits::pixel_type>::basic_pixel_type thresh + ); + /*! + requires + - it is valid to call threshold_image(img,img,thresh); + ensures + - calls threshold_image(img,img,thresh); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void auto_threshold_image ( + const in_image_type& in_img, + out_image_type& out_img + ); + /*! + requires + - in_image_type == is an implementation of array2d/array2d_kernel_abstract.h + - out_image_type == is an implementation of array2d/array2d_kernel_abstract.h + - pixel_traits::pixel_type>::max() <= 65535 + - pixel_traits::pixel_type>::has_alpha == false + - pixel_traits::pixel_type>::is_unsigned == true + - pixel_traits::pixel_type>::grayscale == true + - pixel_traits::pixel_type>::has_alpha == false + - pixel_traits::pixel_type>::is_unsigned == true + ensures + - #out_img == the thresholded version of in_img (in_img is converted to a grayscale + intensity image if it is color). Pixels in in_img with grayscale values >= thresh + have an output value of on_pixel and all others have a value of off_pixel. + - The thresh value used is determined by performing a k-means clustering + on the input image histogram with a k of 2. The point between the two + means found is used as the thresh value. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + + template < + typename image_type + > + void auto_threshold_image ( + image_type& img + ); + /*! + requires + - it is valid to call auto_threshold_image(img,img); + ensures + - calls auto_threshold_image(img,img); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void hysteresis_threshold ( + const in_image_type& in_img, + out_image_type& out_img, + typename pixel_traits::pixel_type>::basic_pixel_type lower_thresh, + typename pixel_traits::pixel_type>::basic_pixel_type upper_thresh + ); + /*! + requires + - in_image_type == is an implementation of array2d/array2d_kernel_abstract.h + - out_image_type == is an implementation of array2d/array2d_kernel_abstract.h + - pixel_traits::pixel_type>::grayscale == true + - pixel_traits::pixel_type>::has_alpha == false + - pixel_traits::pixel_type>::has_alpha == false + - lower_thresh <= upper_thresh + - is_same_object(in_img, out_img) == false + ensures + - #out_img == the hysteresis thresholded version of in_img (in_img is converted to a + grayscale intensity image if it is color). Pixels in in_img with grayscale + values >= upper_thresh have an output value of on_pixel and all others have a + value of off_pixel unless they are >= lower_thresh and are connected to a pixel + with a value >= upper_thresh, in which case they have a value of on_pixel. Here + pixels are connected if there is a path between them composed of pixels that + would receive an output of on_pixel. + - #out_img.nc() == in_img.nc() + - #out_img.nr() == in_img.nr() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THRESHOLDINg_ABSTRACT_ + + diff --git a/ml/dlib/dlib/interfaces/cmd_line_parser_option.h b/ml/dlib/dlib/interfaces/cmd_line_parser_option.h new file mode 100644 index 000000000..797dcd2e6 --- /dev/null +++ b/ml/dlib/dlib/interfaces/cmd_line_parser_option.h @@ -0,0 +1,107 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_OPTIOn_ +#define DLIB_CMD_LINE_PARSER_OPTIOn_ + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT + > + class cmd_line_parser_option + { + /*! + POINTERS AND REFERENCES TO INTERNAL DATA + None of the functions in cmd_line_parser_option will invalidate + pointers or references to internal data when called. + + WHAT THIS OBJECT REPRESENTS + This object represents a command line option. + !*/ + + public: + + typedef charT char_type; + typedef std::basic_string string_type; + + virtual ~cmd_line_parser_option ( + ) = 0; + + virtual const string_type& name ( + ) const = 0; + /*! + ensures + - returns the name of this option + !*/ + + virtual const string_type& group_name ( + ) const = 0; + /*! + ensures + - returns the name of the group this option is in. If no group was set for + this option then this function returns "". + !*/ + + virtual const string_type& description ( + ) const = 0; + /*! + ensures + - returns the description for this option + !*/ + + virtual unsigned long number_of_arguments( + ) const = 0; + /*! + ensures + - returns the number of arguments for this option + !*/ + + virtual unsigned long count( + ) const = 0; + /*! + ensures + - returns the number of times this option appears on the command line. + !*/ + + virtual const string_type& argument ( + unsigned long arg = 0, + unsigned long N = 0 + ) const = 0; + /*! + requires + - arg < number_of_arguments() + - N < count() + ensures + - returns the arg-th argument to the Nth occurrence of this + option on the command line. + !*/ + + inline operator bool ( + ) const { return count() > 0; } + /*! + ensures + - returns true if this option appears on the command line at all + !*/ + + protected: + + // restricted functions + cmd_line_parser_option& operator=(const cmd_line_parser_option&){return *this;} + + }; + + // destructor does nothing + template < typename charT > + cmd_line_parser_option::~cmd_line_parser_option() {} + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_OPTIOn_ + diff --git a/ml/dlib/dlib/interfaces/enumerable.h b/ml/dlib/dlib/interfaces/enumerable.h new file mode 100644 index 000000000..e8f5ae78c --- /dev/null +++ b/ml/dlib/dlib/interfaces/enumerable.h @@ -0,0 +1,130 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ENUMERABLe_INTERFACE_ +#define DLIB_ENUMERABLe_INTERFACE_ + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class enumerable + { + /*! + POINTERS AND REFERENCES TO INTERNAL DATA + - if (at_start()) then + - all pointers and references to data returned via element() are + invalid. + - calling move_next() or reset() invalidates pointers and references to + data returned via element() and only data returned via element(). + - calling at_start(), current_element_valid(), size(), or element() + does NOT invalidate pointers or references to any internal data. + + INITIAL VALUE + current_element_valid() == false + at_start() == true + + WHAT THIS OBJECT REPRESENTS + This object represent an interface for iterating through the + elements in a container. It starts out one before the first element + in the container. + + + EXAMPLE: The following loops though all elements in the container + and prints them to cout. + + container.reset(); + while(container.move_next()) { + cout << container.element(); + } + !*/ + + public: + typedef T type; + + inline virtual ~enumerable( + ) = 0; + + virtual bool at_start ( + ) const = 0; + /*! + ensures + - returns true if *this represents one position before the first element + in the container (this would also make the current element invalid) + else returns false + !*/ + + virtual void reset ( + ) const = 0; + /*! + ensures + - #current_element_valid() == false + - #at_start() == true + !*/ + + virtual bool current_element_valid ( + ) const = 0; + /*! + ensures + - returns true if we are currently at a valid element else + returns false + !*/ + + virtual const T& element ( + ) const = 0; + /*! + requires + - current_element_valid() == true + ensures + - returns a const reference to the current element + !*/ + + virtual T& element ( + ) = 0; + /*! + requires + - current_element_valid() == true + ensures + - returns a non-const reference to the current element + !*/ + + virtual bool move_next ( + ) const = 0; + /*! + ensures + - moves to the next element. i.e. #element() will now + return the next element in the container + - the return value will be equal to #current_element_valid() + - #at_start() == false + + - returns true if there is another element + - returns false if there are no more elements in the container + !*/ + + virtual size_t size ( + ) const = 0; + /*! + ensures + - returns the number of elements in *this + !*/ + + protected: + + // restricted functions + enumerable& operator=(const enumerable&) {return *this;} // no assignment operator + + }; + + // destructor does nothing + template + enumerable::~enumerable() {} + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ENUMERABLe_INTERFACE_ + diff --git a/ml/dlib/dlib/interfaces/map_pair.h b/ml/dlib/dlib/interfaces/map_pair.h new file mode 100644 index 000000000..64310152e --- /dev/null +++ b/ml/dlib/dlib/interfaces/map_pair.h @@ -0,0 +1,74 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MAP_PAIr_INTERFACE_ +#define DLIB_MAP_PAIr_INTERFACE_ + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T1, + typename T2 + > + class map_pair + { + /*! + POINTERS AND REFERENCES TO INTERNAL DATA + None of the functions in map_pair will invalidate + pointers or references to internal data when called. + + WHAT THIS OBJECT REPRESENTS + this object is used to return the key/value pair used in the + map and hash_map containers when using the enumerable interface. + + note that the enumerable interface is defined in + interfaces/enumerable.h + !*/ + + public: + typedef T1 key_type; + typedef T2 value_type; + + virtual ~map_pair( + )=0; + + virtual const T1& key( + ) const =0; + /*! + ensures + - returns a const reference to the key + !*/ + + virtual const T2& value( + ) const =0; + /*! + ensures + - returns a const reference to the value associated with key + !*/ + + virtual T2& value( + )=0; + /*! + ensures + - returns a non-const reference to the value associated with key + !*/ + + protected: + + // restricted functions + map_pair& operator=(const map_pair&) {return *this;} // no assignment operator + + }; + + // destructor does nothing + template + map_pair::~map_pair () {} + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MAP_PAIr_INTERFACE_ + diff --git a/ml/dlib/dlib/interfaces/remover.h b/ml/dlib/dlib/interfaces/remover.h new file mode 100644 index 000000000..f2098cba6 --- /dev/null +++ b/ml/dlib/dlib/interfaces/remover.h @@ -0,0 +1,220 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_REMOVER_KERNEl_INTERFACE_ +#define DLIB_REMOVER_KERNEl_INTERFACE_ + +#include + + +namespace dlib +{ + + template < + typename T + > + class remover + { + + /*! + REQUIREMENTS ON T + T is swappable by a global swap() and + T must have a default constructor + + POINTERS AND REFERENCES TO INTERNAL DATA + The size() function does not invalidate pointers or + references to internal data. All other functions have no such + guarantee. + + WHAT THIS OBJECT REPRESENTS + This object represents some generalized interface for removing + single items from container classes. + !*/ + + + public: + typedef T type; + + virtual ~remover( + ); + /*! + ensures + - all resources associated with *this have been released. + !*/ + + virtual void remove_any ( + T& item + ) = 0; + /*! + requires + - size() != 0 + ensures + - #size() == size() - 1 + - removes an element from *this and swaps it into item. + - if (*this implements the enumerable interface) then + - #at_start() == true + !*/ + + virtual size_t size ( + ) const = 0; + /*! + ensures + - returns the number of elements in *this + !*/ + + protected: + + // restricted functions + remover& operator=(const remover&) {return *this;} // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + class asc_remover : public remover + { + /*! + REQUIREMENTS ON T + T is swappable by a global swap() and + T must have a default constructor and + T must be comparable by compare where compare is a functor compatible with std::less + + WHAT THIS OBJECT REPRESENTS + This object represents the same thing as remover except + that remove_any() will remove elements in ascending order + according to the compare functor. + !*/ + public: + typedef compare compare_type; + + protected: + // restricted functions + asc_remover& operator=(const asc_remover&) {return *this;} // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range + > + class pair_remover + { + + /*! + REQUIREMENTS ON domain + domain is swappable by a global swap() and + domain must have a default constructor + + REQUIREMENTS ON range + range is swappable by a global swap() and + range must have a default constructor + + POINTERS AND REFERENCES TO INTERNAL DATA + The size() function does not invalidate pointers or + references to internal data. All other functions have no such + guarantee. + + WHAT THIS OBJECT REPRESENTS + This object represents some generalized interface for removing + pairs from container classes which enforce some kind of pairing on + the elements that they contain. + !*/ + + public: + typedef domain domain_type; + typedef range range_type; + + virtual ~pair_remover( + ); + /*! + ensures + - all resources associated with *this have been released. + !*/ + + virtual void remove_any ( + domain& d, + range& r + ) = 0; + /*! + requires + - &d != &r (i.e. d and r cannot be the same variable) + - size() != 0 + ensures + - #size() == size() - 1 + - removes an element from the domain of *this and swaps it + into d. + - removes the element in *this's range that is associated + with #d and swaps it into r. + - if (*this implements the enumerable interface) then + - #at_start() == true + !*/ + + virtual size_t size ( + ) const = 0; + /*! + ensures + - returns the number of elements in *this + !*/ + + + protected: + + // restricted functions + pair_remover& operator=(const pair_remover&) {return *this;} // assignment operator + + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + class asc_pair_remover : public pair_remover + { + /*! + REQUIREMENTS ON domain + domain is swappable by a global swap() and + domain must have a default constructor and + domain must be comparable by compare where compare is a functor compatible with std::less + + REQUIREMENTS ON range + range is swappable by a global swap() and + range must have a default constructor + + WHAT THIS OBJECT REPRESENTS + This object represents the same thing as pair_remover except + that remove_any() will remove domain elements in ascending + order according to the compare functor. + !*/ + public: + typedef compare compare_type; + + protected: + // restricted functions + asc_pair_remover& operator=(const asc_pair_remover&) {return *this;} // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + // destructor does nothing + template + remover::~remover() {} + + // destructor does nothing + template + pair_remover::~pair_remover() {} + + +// ---------------------------------------------------------------------------------------- + + +} + +#endif // DLIB_REMOVER_KERNEl_INTERFACE_ + diff --git a/ml/dlib/dlib/iomanip b/ml/dlib/dlib/iomanip new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/iomanip @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/iosfwd b/ml/dlib/dlib/iosfwd new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/iosfwd @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/iosockstream.h b/ml/dlib/dlib/iosockstream.h new file mode 100644 index 000000000..da1b9b505 --- /dev/null +++ b/ml/dlib/dlib/iosockstream.h @@ -0,0 +1,11 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IOSOCkSTREAM_H_h_ +#define DLIB_IOSOCkSTREAM_H_h_ + +#include "iosockstream/iosockstream.h" + + +#endif // DLIB_IOSOCkSTREAM_H_h_ + + diff --git a/ml/dlib/dlib/iosockstream/iosockstream.h b/ml/dlib/dlib/iosockstream/iosockstream.h new file mode 100644 index 000000000..e49d2e37f --- /dev/null +++ b/ml/dlib/dlib/iosockstream/iosockstream.h @@ -0,0 +1,171 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IOSOCKSTrEAM_Hh_ +#define DLIB_IOSOCKSTrEAM_Hh_ + +#include "iosockstream_abstract.h" + +#include +#include + +#include "../sockstreambuf.h" +#include "../timeout.h" + +#ifdef _MSC_VER +// Disable the warning about inheriting from std::iostream 'via dominance' since this warning is a warning about +// visual studio conforming to the standard and is ignorable. +// See http://connect.microsoft.com/VisualStudio/feedback/details/733720/inheriting-from-std-fstream-produces-c4250-warning +// for further details if interested. +#pragma warning(disable : 4250) +#endif // _MSC_VER + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class iosockstream : public std::iostream + { + public: + + iosockstream( + ) : + std::iostream(0) + { + } + + iosockstream( + const network_address& addr + ) : + std::iostream(0) + { + open(addr); + } + + iosockstream( + const network_address& addr, + unsigned long timeout + ) : + std::iostream(0) + { + open(addr, timeout); + } + + ~iosockstream() + { + close(); + } + + void open ( + const network_address& addr + ) + { + auto_mutex lock(class_mutex); + close(); + con.reset(connect(addr)); + buf.reset(new sockstreambuf(con.get())); + // Note that we use the sockstreambuf's ability to autoflush instead of + // telling the iostream::tie() function to tie the stream to itself even though + // that should work fine. The reason we do it this way is because there is a + // bug in visual studio 2012 that causes a program to crash when a stream is + // tied to itself and then used. See + // http://connect.microsoft.com/VisualStudio/feedback/details/772293/tying-a-c-iostream-object-to-itself-causes-a-stack-overflow-in-visual-studio-2012 + // for further details. + buf->flush_output_on_read(); + rdbuf(buf.get()); + clear(); + } + + void open ( + const network_address& addr, + unsigned long timeout + ) + { + auto_mutex lock(class_mutex); + close(timeout); + con.reset(connect(addr.host_address, addr.port, timeout)); + buf.reset(new sockstreambuf(con.get())); + buf->flush_output_on_read(); + rdbuf(buf.get()); + clear(); + } + + void close( + unsigned long timeout = 10000 + ) + { + auto_mutex lock(class_mutex); + rdbuf(0); + try + { + if (buf) + { + dlib::timeout t(*con,&connection::shutdown,timeout); + + // This will flush the sockstreambuf and also destroy it. + buf.reset(); + + if(con->shutdown_outgoing()) + { + // there was an error so just close it now and return + con->shutdown(); + } + else + { + char junk[100]; + // wait for the other end to close their side + while (con->read(junk,sizeof(junk)) > 0); + } + } + } + catch (...) + { + con.reset(); + throw; + } + con.reset(); + } + + void terminate_connection_after_timeout ( + unsigned long timeout + ) + { + auto_mutex lock(class_mutex); + if (con) + { + con_timeout.reset(new dlib::timeout(*this,&iosockstream::terminate_connection,timeout,con)); + } + } + + void shutdown ( + ) + { + auto_mutex lock(class_mutex); + if (con) + con->shutdown(); + } + + private: + + void terminate_connection( + std::shared_ptr thecon + ) + { + thecon->shutdown(); + } + + std::unique_ptr con_timeout; + rmutex class_mutex; + std::shared_ptr con; + std::unique_ptr buf; + + }; + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_IOSOCKSTrEAM_Hh_ + + diff --git a/ml/dlib/dlib/iosockstream/iosockstream_abstract.h b/ml/dlib/dlib/iosockstream/iosockstream_abstract.h new file mode 100644 index 000000000..2328f426e --- /dev/null +++ b/ml/dlib/dlib/iosockstream/iosockstream_abstract.h @@ -0,0 +1,171 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_IOSOCKSTrEAM_ABSTRACT_Hh_ +#ifdef DLIB_IOSOCKSTrEAM_ABSTRACT_Hh_ + +#include "../sockstreambuf/sockstreambuf_abstract.h" + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class iosockstream : public std::iostream + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an iostream object that reads/writes from a TCP network connection. + + Note that any attempt to read from this stream will automatically flush the + stream's output buffers. + + THREAD SAFETY + It is not safe for multiple threads to make concurrent accesses to the same + instance of this object (except for calls to shutdown() which are always + threadsafe). Therefore, you should mutex lock an instance of this object + if you need to touch it from multiple threads. + !*/ + + public: + + iosockstream( + ); + /*! + ensures + - #good() == false + !*/ + + iosockstream( + const network_address& addr + ); + /*! + ensures + - Attempts to connect to the given network address. + - Calling this constructor is equivalent to calling the default constructor + and then invoking open(addr). + - #good() == true + throws + - dlib::socket_error + This exception is thrown if there is some problem that prevents us from + creating the connection. + !*/ + + iosockstream( + const network_address& addr, + unsigned long timeout + ); + /*! + ensures + - Attempts to connect to the given network address. + - Calling this constructor is equivalent to calling the default constructor + and then invoking open(addr, timeout). + - #good() == true + throws + - dlib::socket_error + This exception is thrown if there is some problem that prevents us from + creating the connection or if timeout milliseconds elapses before the + connect is successful. + !*/ + + ~iosockstream( + ); + /*! + ensures + - Invokes close() before destructing the stream. Therefore, any open + connection will be gracefully closed using the default timeout time. + This also means any data in the stream will be flushed to the connection. + !*/ + + void open ( + const network_address& addr + ); + /*! + ensures + - This object will attempt to create a TCP connection with the remote host + indicated by addr. + - Any previous connection in this iosockstream is closed by calling close() + before we make any new connection. + - #good() == true + (i.e. the error flags are reset by calling open()) + throws + - dlib::socket_error + This exception is thrown if there is some problem that prevents us from + creating the connection. + !*/ + + void open ( + const network_address& addr, + unsigned long timeout + ); + /*! + ensures + - This object will attempt to create a TCP connection with the remote host + indicated by addr. + - Any previous connection in this iosockstream is closed by calling close() + before we make any new connection. + - #good() == true + (i.e. the error flags are reset by calling open()) + throws + - dlib::socket_error + This exception is thrown if there is some problem that prevents us from + creating the connection or if timeout milliseconds elapses before the + connect is successful. + !*/ + + void close( + unsigned long timeout = 10000 + ); + /*! + ensures + - #good() == false + - if (there is an active TCP connection) then + - Flushes any data buffered in the output part of the stream + to the connection. + - Performs a proper graceful close (i.e. like dlib::close_gracefully()). + - Will only wait timeout milliseconds for the buffer flush and graceful + close to finish before the connection is terminated forcefully. + Therefore, close() will only block for at most timeout milliseconds. + !*/ + + void terminate_connection_after_timeout ( + unsigned long timeout + ); + /*! + ensures + - if (there is an active TCP connection) then + - Any operations on this TCP connection will return error or + end-of-file once timeout milliseconds have elapsed from this call to + terminate_connection_after_timeout(). This is true unless another + call to terminate_connection_after_timeout() is made which gives a + new time. In this case, the previous call is forgotten and the + timeout is reset. + - This timeout only applies to the current TCP connection. That is, if + the iosockstream is closed and a new connection is established, any + previous timeouts setup by terminate_connection_after_timeout() do + not apply. + - else + - This function has no effect on this object. + !*/ + + void shutdown ( + ); + /*! + ensures + - Immediately closes the TCP connection and causes all I/O operations on + this object to return an error. + - It is safe to call this function from any thread, therefore, you can use + it to signal when you want a connection to terminate from another thread. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_IOSOCKSTrEAM_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/iostream b/ml/dlib/dlib/iostream new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/iostream @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/is_kind.h b/ml/dlib/dlib/is_kind.h new file mode 100644 index 000000000..e8dcb6320 --- /dev/null +++ b/ml/dlib/dlib/is_kind.h @@ -0,0 +1,162 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IS_KINd_H_ +#define DLIB_IS_KINd_H_ + +#include + +namespace dlib +{ + /*! + This file contains a set of templates that enable you to determine if + a given type implements an abstract interface defined in one of the + dlib *_abstract.h files. + !*/ + +// ---------------------------------------------------------------------------------------- + + struct default_is_kind_value { static const bool value = false; }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_graph : public default_is_kind_value + { + /*! + - if (T is an implementation of graph/graph_kernel_abstract.h) then + - is_graph::value == true + - else + - is_graph::value == false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_directed_graph : public default_is_kind_value + { + /*! + - if (T is an implementation of directed_graph/directed_graph_kernel_abstract.h) then + - is_directed_graph::value == true + - else + - is_directed_graph::value == false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_matrix : public default_is_kind_value + { + /*! + - if (T is some kind of matrix expression from the matrix/matrix_exp_abstract.h component) then + - is_matrix::value == true + - else + - is_matrix::value == false + !*/ + + // Don't set the helper to anything. Just let it be void. + ASSERT_ARE_SAME_TYPE(helper,void); + }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_array2d : public default_is_kind_value + { + /*! + - if (T is an implementation of array2d/array2d_kernel_abstract.h) then + - is_array2d::value == true + - else + - is_array2d::value == false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_array : public default_is_kind_value + { + /*! + - if (T is an implementation of array/array_kernel_abstract.h) then + - is_array::value == true + - else + - is_array::value == false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_std_vector : public default_is_kind_value + { + /*! + - if (T is an implementation of the standard C++ std::vector object) then + - is_std_vector::value == true + - else + - is_std_vector::value == false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_pair : public default_is_kind_value + { + /*! + - if (T is a std::pair object) then + - is_std_vector::value == true + - else + - is_std_vector::value == false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_rand : public default_is_kind_value + { + /*! + - if (T is an implementation of rand/rand_kernel_abstract.h) then + - is_rand::value == true + - else + - is_rand::value == false + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_config_reader : public default_is_kind_value + { + /*! + - if (T is an implementation of config_reader/config_reader_kernel_abstract.h) then + - is_config_reader::value == true + - else + - is_config_reader::value == false + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Implementation details +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + struct is_std_vector > { const static bool value = true; }; + template struct is_std_vector { const static bool value = is_std_vector::value; }; + template struct is_std_vector{ const static bool value = is_std_vector::value; }; + template struct is_std_vector { const static bool value = is_std_vector::value; }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_pair > { const static bool value = true; }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IS_KINd_H_ + diff --git a/ml/dlib/dlib/istream b/ml/dlib/dlib/istream new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/istream @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/java/CMakeLists.txt b/ml/dlib/dlib/java/CMakeLists.txt new file mode 100644 index 000000000..4d66a513c --- /dev/null +++ b/ml/dlib/dlib/java/CMakeLists.txt @@ -0,0 +1,32 @@ + +cmake_minimum_required (VERSION 2.8.12) +project (myproject) +set(java_package_name net.dlib) +set(source_files + ) + +include(../cmake_utils/release_build_by_default) + +include_directories( + . + ) + +# Additional dependencies +#add_subdirectory(../../dlib dlib_build) +#set(additional_link_libraries dlib::dlib) + +# Tell swig to put the output files (the shared library and .jar) into the local folder. +set(install_target_output_folder .) + +# Alternatively, instead of using install_target_output_folder, you can tell +# cmake to output the shared library, java source files, and the jar to +# separate output folders. These commands would put them into folders thelib, +# thesrc, and thejar, respectively. +#set(install_shared_library_output_folder thelib) +#set(install_java_source_output_folder thesrc) +#set(install_jar_output_folder thejar) + + +include(cmake_swig_jni) + + diff --git a/ml/dlib/dlib/java/cmake_swig_jni b/ml/dlib/dlib/java/cmake_swig_jni new file mode 100644 index 000000000..d74dd60ec --- /dev/null +++ b/ml/dlib/dlib/java/cmake_swig_jni @@ -0,0 +1,265 @@ +# This file is used to create SWIG based JNI interfaces to C++ code. You use +# it by defining some CMake variables and then include(cmake_swig_jni). You +# would make a CMakeLists.txt file that looks like the following: +# +# cmake_minimum_required (VERSION 2.8.12) +# project (example) +# set(java_package_name "org.mycompany") +# set(source_files +# your_cpp_source.cpp +# more_cpp_source.cpp +# ) +# +# ### We might need to link our code to some other C++ library like dlib. You +# ### can do that by setting additional_link_libraries. Here is an example of +# ### linking to dlib: +# include(../../dlib/dlib/cmake) +# set(additional_link_libraries dlib::dlib) +# +# ### Tell swig to put the output files into the parent folder of your CMakeLists.txt +# ### file when you run make install. +# set(install_target_output_folder ..) +# include(cmake_swig_jni) +# +# ### Alternatively, instead of using install_target_output_folder, you can tell +# ### cmake to output the shared library, java source files, and the jar to +# ### separate output folders. These commands would put them into folders +# ### thelib, thesrc, and thejar, respectively. +# # set(install_shared_library_output_folder thelib) +# # set(install_java_source_output_folder thesrc) +# # set(install_jar_output_folder thejar) + + + + + + +################################################################################ +################################################################################ +# IMPLEMENTATION DETAILS +################################################################################ +################################################################################ + +cmake_minimum_required (VERSION 2.8.12) + +include(${CMAKE_CURRENT_LIST_DIR}/../cmake_utils/use_cpp_11.cmake) + +# This block of code tries to figure out what the JAVA_HOME environment +# variable should be by looking at the folder that contains the java +# executable. +if (NOT DEFINED ENV{JAVA_HOME}) + message(STATUS "JAVA_HOME environment variable not set, trying to guess it...") + find_program(JAVA_EXECUTABLE java) + # Resolve symbolic links, hopefully this will give us a path in the proper + # java home directory. + get_filename_component(JAVA_EXECUTABLE ${JAVA_EXECUTABLE} REALPATH) + # Pick out the parent directories + get_filename_component(JAVA_PATH1 ${JAVA_EXECUTABLE} PATH) + get_filename_component(JAVA_PATH2 ${JAVA_PATH1} PATH) + get_filename_component(JAVA_PATH3 ${JAVA_PATH2} PATH) + # and search them for include/jni.h. If we find that then we probably have + # a good java home candidate. + find_path(AUTO_JAVA_HOME include/jni.h + PATHS + ${JAVA_PATH1} + ${JAVA_PATH2} + ${JAVA_PATH3} + "C:/Program Files/Java/jdk*" + "C:/Program Files (x86)/Java/jdk*" + ) + + if (AUTO_JAVA_HOME) + set(ENV{JAVA_HOME} ${AUTO_JAVA_HOME}) + message(STATUS "Using JAVA_HOME OF " ${AUTO_JAVA_HOME}) + else() + message(FATAL_ERROR "Couldn't find a folder for JAVA_HOME. You must set the JAVA_HOME environment variable before running CMake.") + endif() +endif() + +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib") +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE "${CMAKE_CURRENT_BINARY_DIR}/lib") + +find_package(SWIG REQUIRED) +find_package(Java REQUIRED) +find_package(JNI REQUIRED) +include(UseSWIG) + +macro (add_global_switch def_name ) + if (NOT CMAKE_CXX_FLAGS MATCHES "${def_name}") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${def_name}" + CACHE STRING "Flags used by the compiler during all C++ builds." + FORCE) + endif () + if (NOT CMAKE_C_FLAGS MATCHES "${def_name}") + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${def_name}" + CACHE STRING "Flags used by the compiler during all C builds." + FORCE) + endif () +endmacro() + +# SWIG doesn't work if optimizations are enabled and strict aliasing is not +# turned off. This is a little wonky but it's how SWIG is. +if (CMAKE_COMPILER_IS_GNUCXX) + add_definitions(-fno-strict-aliasing) +endif() +if (UNIX) + # we need to make sure all the code is compiled with -fPIC. In particular, + # it's important that all the code for the whole project is, not just the + # stuff immediately compiled by us in this cmake file. So we add -fPIC to + # the top level cmake flags variables. + add_global_switch(-fPIC) +endif() + +set(dlib_root_path ${CMAKE_CURRENT_LIST_DIR}/../../) + +string(REGEX REPLACE "\\." "/" package_path ${java_package_name}) +string(REGEX REPLACE "\\..*" "" package_root_name ${java_package_name}) + +include_directories(${dlib_root_path}) + +set(CMAKE_SWIG_FLAGS -package ${java_package_name} -I${dlib_root_path}) +set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/lib/java_src/${package_path}) + +set(output_library_name ${PROJECT_NAME}) + +# Create the swig.i interface file that swig will run on. We do it here in +# the cmake script because this lets us automatically include the correct +# output library name into the call to System.loadLibrary(). +FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/swig.i + " + // Put the global functions in our api into a java class called global. + %module global + + %{ + #include + #include + static JavaVM *cached_jvm = 0; + + JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *jvm, void *reserved) { + cached_jvm = jvm; + return JNI_VERSION_1_6; + } + + static JNIEnv * JNI_GetEnv() { + JNIEnv *env; + jint rc = cached_jvm->GetEnv((void **)&env, JNI_VERSION_1_6); + if (rc == JNI_EDETACHED) + throw std::runtime_error(\"current thread not attached\"); + if (rc == JNI_EVERSION) + throw std::runtime_error(\"jni version not supported\"); + return env; + } + + #include \"swig_api.h\" + %} + + // Convert all C++ exceptions into java.lang.Exception + %exception { + try { + $action + } catch(std::exception& e) { + jclass clazz = jenv->FindClass(\"java/lang/Exception\"); + jenv->ThrowNew(clazz, e.what()); + return $null; + } + } + + %pragma(java) jniclasscode=%{ + static { System.loadLibrary(\"${output_library_name}\"); } + %} + + %include \"swig_api.h\" + " +) + +# There is a bug in CMake's Swig scripts that causes the build to fail if the +# binary folder doesn't contain a folder with the same name as the binary dir. +# So we make a subfolder of the same name to avoid that bug. +get_filename_component(binary_dir_name "${CMAKE_CURRENT_BINARY_DIR}" NAME) +FILE(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${binary_dir_name}") + +set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/swig.i PROPERTIES CPLUSPLUS ON) +swig_add_module(${output_library_name} java ${CMAKE_CURRENT_BINARY_DIR}/swig.i ${source_files}) +enable_cpp11_for_target(${output_library_name}) + +include_directories(${JNI_INCLUDE_DIRS}) +swig_link_libraries(${output_library_name} ${additional_link_libraries}) + +# Things to delete when "make clean" is run. +set(clean_files + ${CMAKE_CURRENT_BINARY_DIR}/intermediate_files_compiled + ${CMAKE_CURRENT_BINARY_DIR}/lib/java_src + ) +set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES "${clean_files}") + +# Compile the java files into a jar file and stick it in the lib folder. Also, one problem +# with this cmake setup is that it doesn't know that modifications to swig_api.h mean that +# swig.i is invalidated and thus swig needs to be rerun. So here we also touch swig.i +# every time we build to make it always out of date and force swig to run on each build, +# thus avoiding the stale swig outputs problem that would otherwise irritate people who +# modify something and attempt to rebuild. +add_custom_command(TARGET ${output_library_name} + POST_BUILD + COMMAND cmake -E echo "compiling Java files..." + COMMAND cmake -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/intermediate_files_compiled" + COMMAND ${Java_JAVAC_EXECUTABLE} ${CMAKE_SWIG_OUTDIR}/*.java -d "${CMAKE_CURRENT_BINARY_DIR}/intermediate_files_compiled" + COMMAND cmake -E echo "Making jar file..." + COMMAND ${Java_JAR_EXECUTABLE} cvf "${CMAKE_CURRENT_BINARY_DIR}/lib/${PROJECT_NAME}.jar" -C "${CMAKE_CURRENT_BINARY_DIR}/intermediate_files_compiled" ${package_root_name} + COMMAND cmake -E touch swig.i + ) + + +# Determine the path to our CMakeLists.txt file. +# There is either a bug (or break in compatability maybe) between versions +# of cmake that cause the or expression in this regular expression to be +# necessary. +string(REGEX REPLACE "(cmake_swig_jni|CMakeLists.txt)$" "" base_path ${CMAKE_PARENT_LIST_FILE}) + +#if the including cmake script set the install_target_output_folder variable +#then make it so we install the compiled library and jar into that folder +if (install_target_output_folder) + # The directory we will write the output files to. + set(install_dir "${base_path}${install_target_output_folder}") + set(CMAKE_INSTALL_PREFIX "${install_dir}") + set(CMAKE_INSTALL_SYSTEM_RUNTIME_DESTINATION "${install_dir}") + install(TARGETS ${output_library_name} + DESTINATION "${install_dir}" + ) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/lib/${PROJECT_NAME}.jar + DESTINATION "${install_dir}" + ) +endif() + +if (install_shared_library_output_folder) + set(install_dir "${base_path}${install_shared_library_output_folder}") + install(TARGETS ${output_library_name} + DESTINATION "${install_dir}" + ) +endif() + +if (install_java_source_output_folder) + set(install_dir "${base_path}${install_java_source_output_folder}") + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib/java_src/${package_root_name} + DESTINATION "${install_dir}" + ) +endif() + +if (install_jar_output_folder) + set(install_dir "${base_path}${install_jar_output_folder}") + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/lib/${PROJECT_NAME}.jar + DESTINATION "${install_dir}" + ) +endif() + + +# Copy any system libraries to the output folder. This really only matters on +# windows where it's good to have the visual studio runtime show up in the lib +# folder so that you don't forget to include it in your binary distribution. +INCLUDE(InstallRequiredSystemLibraries) +foreach (file_i ${CMAKE_INSTALL_SYSTEM_RUNTIME_LIBS}) + add_custom_command(TARGET ${output_library_name} + POST_BUILD + COMMAND cmake -E copy ${file_i} "${CMAKE_CURRENT_BINARY_DIR}/lib/" + ) +endforeach() + diff --git a/ml/dlib/dlib/java/java_array.h b/ml/dlib/dlib/java/java_array.h new file mode 100644 index 000000000..6c4d5f03d --- /dev/null +++ b/ml/dlib/dlib/java/java_array.h @@ -0,0 +1,605 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SWIG_JAVA_ARRAY_H_ +#define DLIB_SWIG_JAVA_ARRAY_H_ + + +/* + + This file defines three special classes: array, array_view, and array_view_crit. An + array is a simple opaque handle to a java array, like a double[] array. The array_view + and array_view_crit objects allow you to access the contents of an array. The + interfaces of these objects is shown below, but for an example use, suppose you had an + array of int in java and you wanted to pass it to C++. You could create a C++ function + like this: + + void my_function(const array_view& array); + + and then within java you could call it with code like this: + + int[] array = new int[100]; + my_function(array); + + and it will work just like you would expect. The array_view will usually result in + the JVM doing a copy in the background. However, you can also declare your function + like this: + + void my_function(const array_view_crit& array); + + and still call it the same way in java, however, using array_view_crit will usually + not result in any copying, and is therefore very fast. array_view_crit uses the JNI + routine GetPrimitiveArrayCritical() to get a lock on the java memory underlying the + array. So it will probably prevent the garbage collector from running while your + function is executing. The JNI documentation is somewhat vague on the limitations of + GetPrimitiveArrayCritical(), saying only that you shouldn't hold the lock on the array + for "an extended period" or call back into the JVM. Deciding whether or not this + matters in your application is left as an exercise for the reader. + + + There are two ways you can declare your methods if they take an array_view or + array_view_crit. Taking a const reference or a non-const reference. E.g. + void my_function(const array_view& array); + void my_function(array_view& array); + You can't declare them to be by value. The non-const version allows you to modify the + contents of the array and the modifications will be visible to java, as you would + expect. You can also make functions that take array objects directly, but that's only + useful if you want to store the array handle somewhere, like in a member of a long + lived class. You can also write functions that return arrays back to java. E.g. + array make_an_array(size_t s) + { + array arr(s); + array_view aview(arr); + // Use aview to put data into the array and generally do something useful. + ... + return arr; + } + This would create an array and return it as a java int[] array. + + + You can also of course use functions taking many arguments, as is normally the case + with SWIG. Finally, these classes work with the following primitive types: + - int16_t + - int32_t + - int64_t + - char (corresponding to java byte) + - float + - double + + + + +namespace java +{ + template + class array + { + /!* + WHAT THIS OBJECT REPRESENTS + This is a handle to a java array. I.e. a reference to an array instance in + java like a double[] or int[]. It doesn't do anything other than tell you + the size of the array and allow you to hold a reference to it. + + To access the array contents, you need to create an array_view or + array_view_crit from the array. + *!/ + public: + array(); + /!* + ensures + - #size() == 0 + - this array is a null reference, i.e. it doesn't reference any array. + *!/ + + explicit array(size_t new_size); + /!* + ensures + - #size() == new_size + - Allocates a new java array. + - This array is a reference to the newly allocated java array object. + *!/ + + size_t size() const; + /!* + ensures + - returns the number of elements in this java array. + *!/ + + void swap(array& item); + /!* + ensures + - swaps the state of *this and item. + *!/ + + array(const array& item); + array& operator= (const array& item) + array(array&& item); + array& operator= (array&& item); + /!* + ensures + - The array is copyable, assignable, and movable. All copies will + reference the same underlying array. So the copies are shallow, as is + normally the case with java reference semantics. + *!/ + }; + + + + template + class array_view + { + /!* + WHAT THIS OBJECT REPRESENTS + This is a view into a java array object. It allows you to access the + values stored in an array and modify them if you want to. + + You should only create array_view objects locally in a function since an + array_view is only valid as long as the array it references exists. So + don't store array_view objects in the member area of a class or globally. + *!/ + + public: + array_view(); + /!* + ensures + - #size() == 0 + - #data() == nullptr + *!/ + + array_view(const array& arr, bool might_be_modified=true); + /!* + ensures + - #size() == arr.size() + - #data() == a pointer to the beginning of the array data referenced by arr. + - When you get a view on a java array, sometimes the JVM will actually + give you a pointer to a copy of the original array. You therefore have + to tell the JVM if you modified the array when you are done using it. If + you say you modified it then the JVM will perform another copy from your + memory buffer back into the JVM. The state of might_be_modified controls + if we do this. So if you are going to modify the array via this + array_view you should set might_be_modified==true. + *!/ + + size_t size() const; + /!* + ensures + - returns the number of elements in this java array. + *!/ + + T* data(); + const T* data() const; + /!* + ensures + - returns a pointer to the beginning of the array. Or nullptr if this is a + handle to null, rather than an actual array instance. + *!/ + + T* begin(); + T* end(); + const T* begin() const; + const T* end() const; + /!* + ensures + - returns iterators to the start and one-past-the-end of the array, as is + the convention for iterator ranges in C++. + *!/ + + T& operator[](size_t i); + const T& operator[](size_t i) const; + /!* + ensures + - returns data()[i] + *!/ + + private: + // this object is non-copyable. + array_view(const array_view&); + array_view& operator=(const array_view&); + }; + + + template + class array_view_crit + { + /!* + WHAT THIS OBJECT REPRESENTS + This is just like an array_view and has an identical interface. The only + difference is that we use the JNI call GetPrimitiveArrayCritical() to get a + critical lock on the array's memory. Therefore, using array_view_crit is + usually faster than array_view since it avoids any unnecessary copying back + and forth between the JVM. + + However, this critical lock can block the JVM's garbage collector from + running. So don't create long lived array_view_crit objects. + *!/ + }; + +} +*/ + + + + + + + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION DETAILS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + + + + + + + +namespace java +{ + +template +class array_view_base +{ +public: + array_view_base() = default; + + size_t size() const { return sz; } + T* data() { return pdata; } + const T* data() const { return pdata; } + + T* begin() { return pdata; } + T* end() { return pdata+sz; } + const T* begin() const { return pdata; } + const T* end() const { return pdata+sz; } + + T& operator[](size_t i) { return pdata[i]; } + const T& operator[](size_t i) const { return pdata[i]; } + +protected: + T* pdata = nullptr; + size_t sz = 0; + +private: + // this object is non-copyable + array_view_base(const array_view_base&); + array_view_base& operator=(const array_view_base&); + +}; + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +template +struct find_java_array_type; + +template <> struct find_java_array_type { typedef jshortArray type; }; +template <> struct find_java_array_type { typedef jintArray type; }; +template <> struct find_java_array_type { typedef jlongArray type; }; +template <> struct find_java_array_type { typedef jbyteArray type; }; +template <> struct find_java_array_type { typedef jfloatArray type; }; +template <> struct find_java_array_type { typedef jdoubleArray type; }; + +jshortArray create_java_array(int16_t, size_t size) { return JNI_GetEnv()->NewShortArray(size); } +jintArray create_java_array(int32_t, size_t size) { return JNI_GetEnv()->NewIntArray(size); } +jlongArray create_java_array(int64_t, size_t size) { return JNI_GetEnv()->NewLongArray(size); } +jbyteArray create_java_array(char, size_t size) { return JNI_GetEnv()->NewByteArray(size); } +jfloatArray create_java_array(float, size_t size) { return JNI_GetEnv()->NewFloatArray(size); } +jdoubleArray create_java_array(double , size_t size) { return JNI_GetEnv()->NewDoubleArray(size); } + +template +class array +{ +public: + + typedef typename find_java_array_type::type java_type; + + array() {} + + explicit array(size_t size) + { + ref = create_java_array(T(),size); + is_global_ref = false; + } + + array(java_type ref_) + { + if (ref_) + { + ref = (java_type)JNI_GetEnv()->NewGlobalRef(ref_); + is_global_ref = true; + } + } + +#ifndef SWIG + array(array&& item) + { + ref = item.ref; + is_global_ref = item.is_global_ref; + item.ref = NULL; + item.is_global_ref = false; + } + array& operator= (array&& item) + { + array(std::move(item)).swap(*this); + return *this; + } +#endif + + ~array() + { + if (ref) + { + // Don't delete the reference if it's a local reference, since the only reason + // we will normally be using array object's that contain local references + // is because we plan on returning the newly constructed array back to the JVM, + // which automatically frees local references using the normal JVM garbage + // collection scheme. + if (is_global_ref) + JNI_GetEnv()->DeleteGlobalRef(ref); + + ref = NULL; + is_global_ref = false; + } + } + + size_t size() const + { + if (ref) + return JNI_GetEnv()->GetArrayLength(ref); + else + return 0; + } + + array(const array& item) + { + array(item.ref).swap(*this); + } + + array& operator= (const array& item) + { + array(item).swap(*this); + return *this; + } + + operator java_type() const { return ref;} + + void swap(array& item) + { + std::swap(ref, item.ref); + std::swap(is_global_ref, item.is_global_ref); + } + +private: + java_type ref = NULL; + bool is_global_ref = false; +}; + +#ifdef SWIG +// Tell SWIG to not use it's SwigValueWrapper stuff on array objects since they aren't +// needed and it causes superfluous construction and destruction of array objects. +%feature("novaluewrapper") array; +%template() array; +%feature("novaluewrapper") array; +%template() array; +%feature("novaluewrapper") array; +%template() array; +%feature("novaluewrapper") array; +%template() array; +%feature("novaluewrapper") array; +%template() array; +%feature("novaluewrapper") array; +%template() array; +#endif + +#ifdef SWIG +%define tostring(token) + #token +%enddef + +%define define_javaObjectRef_converion(type, java_type) + // Define array conversions for non-const arrays + %typemap(jtype) (array) "java_type[]" + %typemap(jstype) (array) "java_type[]" + %typemap(jni) (array) tostring(j##java_type##Array) + %typemap(javain) (array) "$javainput" + %typemap(in) (array) { $1 = java::array($input); } + %typemap(javaout) (array) {return $jnicall; } + %typemap(out) (array) {jresult = result;} + + %typemap(jtype) (array&) "java_type[]" + %typemap(jstype) (array&) "java_type[]" + %typemap(jni) (array&) tostring(j##java_type##Array) + %typemap(javain) (array&) "$javainput" + %typemap(arginit) (array&) { $1 = &temp$argnum; } + %typemap(in) (array&) (java::array temp) { *($1) = java::array($input); } + + %typemap(jtype) (const array&) "java_type[]" + %typemap(jstype) (const array&) "java_type[]" + %typemap(jni) (const array&) tostring(j##java_type##Array) + %typemap(javain) (const array&) "$javainput" + %typemap(arginit) (const array&) { $1 = &temp$argnum; } + %typemap(in) (const array&) (java::array temp) { *($1) = java::array($input); } +%enddef +define_javaObjectRef_converion(int16_t,short) +define_javaObjectRef_converion(int32_t,int) +define_javaObjectRef_converion(int64_t,long) +define_javaObjectRef_converion(char,byte) +define_javaObjectRef_converion(float,float) +define_javaObjectRef_converion(double,double) + +#endif +// ---------------------------------------------------------------------------------------- + +template class array_view; + +#define JAVA_ARRAY_CLASS_SPEC(ctype, type, Type) \ +template <> class array_view : public array_view_base \ +{ \ +public: \ + ~array_view() { clear(); } \ + array_view() {} \ + array_view(const array& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} \ + void reset(JNIEnv* jenv_, j##type##Array arr, bool might_be_modified_) { \ + clear(); \ + jenv = jenv_; \ + oldArr = arr; \ + if (arr) { \ + pdata = (ctype*)jenv->Get##Type##ArrayElements(arr, 0); \ + sz = jenv->GetArrayLength(arr); \ + } \ + might_be_modified = might_be_modified_; \ + } \ +private: \ + void clear() { \ + if (pdata) { \ + jenv->Release##Type##ArrayElements(oldArr, (j##type*)pdata, might_be_modified?0:JNI_ABORT); \ + pdata = nullptr; \ + sz = 0; \ + } \ + } \ + JNIEnv* jenv = nullptr; \ + j##type##Array oldArr; \ + bool might_be_modified; \ +}; + +JAVA_ARRAY_CLASS_SPEC(int16_t,short, Short) +JAVA_ARRAY_CLASS_SPEC(int32_t,int, Int) +JAVA_ARRAY_CLASS_SPEC(int64_t,long, Long) +JAVA_ARRAY_CLASS_SPEC(char,byte, Byte) +JAVA_ARRAY_CLASS_SPEC(float,float, Float) +JAVA_ARRAY_CLASS_SPEC(double,double, Double) + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + +template +class array_view_crit_base +{ +public: + array_view_crit_base() = default; + + size_t size() const { return sz; } + T* data() { return pdata; } + const T* data() const { return pdata; } + + T* begin() { return pdata; } + T* end() { return pdata+sz; } + const T* begin() const { return pdata; } + const T* end() const { return pdata+sz; } + T& operator[](size_t i) { return pdata[i]; } + const T& operator[](size_t i) const { return pdata[i]; } + + ~array_view_crit_base() { clear(); } + + void reset(JNIEnv* jenv_, JARR arr, bool might_be_modified_) + { + clear(); + jenv = jenv_; + oldArr = arr; + if (arr) + { + pdata = (T*)jenv->GetPrimitiveArrayCritical(arr, 0); + sz = jenv->GetArrayLength(arr); + } + might_be_modified = might_be_modified_; + } + +private: + + void clear() + { + if (pdata) { + jenv->ReleasePrimitiveArrayCritical(oldArr, pdata, might_be_modified?0:JNI_ABORT); + pdata = nullptr; + sz = 0; + } + } + + // this object is non-copyable + array_view_crit_base(const array_view_crit_base&); + array_view_crit_base& operator=(const array_view_crit_base&); + + T* pdata = nullptr; + size_t sz = 0; + JNIEnv* jenv = nullptr; + JARR oldArr; + bool might_be_modified; +}; + +template class array_view_crit; + +template <> class array_view_crit : public array_view_crit_base { public: array_view_crit(){} array_view_crit(const array& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} }; +template <> class array_view_crit : public array_view_crit_base { public: array_view_crit(){} array_view_crit(const array& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} }; +template <> class array_view_crit : public array_view_crit_base { public: array_view_crit(){} array_view_crit(const array& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} }; +template <> class array_view_crit : public array_view_crit_base { public: array_view_crit(){} array_view_crit(const array& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} }; +template <> class array_view_crit : public array_view_crit_base { public: array_view_crit(){} array_view_crit(const array& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} }; +template <> class array_view_crit : public array_view_crit_base { public: array_view_crit(){} array_view_crit(const array& arr, bool might_be_modified_=true){reset(JNI_GetEnv(),arr,might_be_modified_);} }; + +// ---------------------------------------------------------------------------------------- + +// Define SWIG typemaps so SWIG will know what to do with the array_view and array_view_crit +// objects. +#ifdef SWIG +%define define_array_converion(type, java_type) + // Define array conversions for non-const arrays + %typemap(jtype) (array_view&) "java_type[]" + %typemap(jstype) (array_view&) "java_type[]" + %typemap(jni) (array_view&) tostring(j##java_type##Array) + %typemap(javain) (array_view&) "$javainput" + %typemap(arginit) (array_view&) { $1 = &temp$argnum; } + %typemap(in) (array_view&) (java::array_view temp) { $1->reset(jenv, $input, true); } + + %typemap(jtype) (const array_view&) "java_type[]" + %typemap(jstype) (const array_view&) "java_type[]" + %typemap(jni) (const array_view&) tostring(j##java_type##Array) + %typemap(javain) (const array_view&) "$javainput" + %typemap(arginit) (const array_view&) { $1 = &temp$argnum; } + %typemap(in) (const array_view&) (java::array_view temp) { $1->reset(jenv, $input, false); } +%enddef +define_array_converion(int16_t,short) +define_array_converion(int32_t,int) +define_array_converion(int64_t,long) +define_array_converion(char,byte) +define_array_converion(float,float) +define_array_converion(double,double) + + + +%define define_array_crit_converion(type, java_type) + // Define array conversions for non-const arrays + %typemap(jtype) (array_view_crit&) "java_type[]" + %typemap(jstype) (array_view_crit&) "java_type[]" + %typemap(jni) (array_view_crit&) tostring(j##java_type##Array) + %typemap(javain) (array_view_crit&) "$javainput" + %typemap(arginit) (array_view_crit&) { $1 = &temp$argnum; } + %typemap(in) (array_view_crit&) (java::array_view_crit temp) { $1->reset(jenv, $input, true); } + + %typemap(jtype) (const array_view_crit&) "java_type[]" + %typemap(jstype) (const array_view_crit&) "java_type[]" + %typemap(jni) (const array_view_crit&) tostring(j##java_type##Array) + %typemap(javain) (const array_view_crit&) "$javainput" + %typemap(arginit) (const array_view_crit&) { $1 = &temp$argnum; } + %typemap(in) (const array_view_crit&) (java::array_view_crit temp) { $1->reset(jenv, $input, false); } +%enddef +define_array_crit_converion(int16_t,short) +define_array_crit_converion(int32_t,int) +define_array_crit_converion(int64_t,long) +define_array_crit_converion(char,byte) +define_array_crit_converion(float,float) +define_array_crit_converion(double,double) + +#endif // SWIG + +} + +#endif // DLIB_SWIG_JAVA_ARRAY_H_ + diff --git a/ml/dlib/dlib/java/run_test.sh b/ml/dlib/dlib/java/run_test.sh new file mode 100755 index 000000000..192ea6c8b --- /dev/null +++ b/ml/dlib/dlib/java/run_test.sh @@ -0,0 +1,17 @@ + +# build the jar and shared library of C++ code needed by the JVM +mkdir build +cd build +cmake .. +cmake --build . --config Release --target install +cd .. + + +# setup paths so the JVM can find our jar and shared library. +export LD_LIBRARY_PATH=. +export DYLD_LIBRARY_PATH=. +export CLASSPATH=myproject.jar:. + +# Now compile and run our java test that calls our C++ code. +javac swig_test.java +java swig_test diff --git a/ml/dlib/dlib/java/swig_api.h b/ml/dlib/dlib/java/swig_api.h new file mode 100644 index 000000000..e807c6c8f --- /dev/null +++ b/ml/dlib/dlib/java/swig_api.h @@ -0,0 +1,126 @@ +#ifndef EXAMPLE_SWIG_ApI_H_ +#define EXAMPLE_SWIG_ApI_H_ + +// This file is essentially a small unit test for the swig cmake scripts and the java array +// classes. All it does it define a few simple functions for writing to and summing +// arrays. The swig_test.java file then calls these C++ functions and checks if they work +// correctly. + + + +// Let's use java_array.h, a tool for efficiently binding java native arrays to C++ +// function arguments. You do this by putting this pair of include statements in your +// swig_api.h file. Then after that you can use the java::array, java::array_view, and +// java::array_view_crit classes. +#include +#ifdef SWIG +%include +#endif + + +using namespace java; + + +// SWIG can't expose templated functions to java. We declare these here as helper +// functions to make the non-templated routines swig will expose easier to write. You can +// see these java exposed methods below (i.e. sum(), sum_crit(), assign(), and +// assign_crit()). +template +T tsum(const array_view_crit& arr) +{ + T s = 0; + for (auto& v : arr) + s += v; + return s; +} +template +T tsum(const array_view& arr) +{ + T s = 0; + for (auto& v : arr) + s += v; + return s; +} +template +void tassign(T& arr) +{ + for (size_t i = 0; i < arr.size(); ++i) + arr[i] = i; +} + +// ---------------------------------------------------------------------------------------- + +// Now write some functions SWIG will expose to java. SWIG will automatically expose +// pretty much any non-template C++ code to java. So just by defining these functions here +// we expose them to java. +// +// All global C++ functions will appear in java as static member functions of class called +// "global", which is where these sum and assign routines will appear. You can see +// examples of java code that calls them in swig_test.java. + +inline int sum_crit(const array_view_crit& arr) { return tsum(arr); } +inline int sum(const array_view& arr) { return tsum(arr); } +inline void assign_crit(array_view_crit& arr) { tassign(arr); } +inline void assign(array_view& arr) { tassign(arr); } + + +inline int sum_crit(const array_view_crit& arr) { return tsum(arr); } +inline int sum(const array_view& arr) { return tsum(arr); } +inline void assign_crit(array_view_crit& arr) { tassign(arr); } +inline void assign(array_view& arr) { tassign(arr); } + + +inline int sum_crit(const array_view_crit& arr) { return tsum(arr); } +inline int sum(const array_view& arr) { return tsum(arr); } +inline void assign_crit(array_view_crit& arr) { tassign(arr); } +inline void assign(array_view& arr) { tassign(arr); } + + +inline int sum_crit(const array_view_crit& arr) { return tsum(arr); } +inline int sum(const array_view& arr) { return tsum(arr); } +inline void assign_crit(array_view_crit& arr) { tassign(arr); } +inline void assign(array_view& arr) { tassign(arr); } + + + +inline double sum_crit(const array_view_crit& arr) { return tsum(arr); } +inline double sum(const array_view& arr) { return tsum(arr); } +inline void assign_crit(array_view_crit& arr) { tassign(arr); } +inline void assign(array_view& arr) { tassign(arr); } + + +inline float sum_crit(array arr) +{ + array_view_crit a(arr); + return tsum(a); +} +inline float sum(const array& arr) +{ + array_view a(arr); + return tsum(a); +} +inline void assign_crit(array_view_crit& arr) { tassign(arr); } +inline void assign(array& arr) +{ + array_view a(arr); + tassign(a); +} + +array make_an_array(size_t s) +{ + array arr(s); + array_view_crit a(arr); + + for (size_t i = 0; i < a.size(); ++i) + a[i] = i; + + return arr; +} + + +// ---------------------------------------------------------------------------------------- + + +#endif // EXAMPLE_SWIG_ApI_H_ + + diff --git a/ml/dlib/dlib/java/swig_test.java b/ml/dlib/dlib/java/swig_test.java new file mode 100644 index 000000000..e75edb913 --- /dev/null +++ b/ml/dlib/dlib/java/swig_test.java @@ -0,0 +1,254 @@ + +/* + + This file tests all the ways of using jvector and jvector_crit. + +*/ + + +import net.dlib.*; + +public class swig_test +{ + public static int sum(long[] arr) + { + int s = 0; + for (int i = 0; i < arr.length; ++i) + s += arr[i]; + return s; + } + public static void zero(long[] arr) + { + for (int i = 0; i < arr.length; ++i) + arr[i] = 0; + } + + public static int sum(byte[] arr) + { + int s = 0; + for (int i = 0; i < arr.length; ++i) + s += arr[i]; + return s; + } + public static void zero(byte[] arr) + { + for (int i = 0; i < arr.length; ++i) + arr[i] = 0; + } + public static int sum(short[] arr) + { + int s = 0; + for (int i = 0; i < arr.length; ++i) + s += arr[i]; + return s; + } + public static void zero(short[] arr) + { + for (int i = 0; i < arr.length; ++i) + arr[i] = 0; + } + + public static int sum(int[] arr) + { + int s = 0; + for (int i = 0; i < arr.length; ++i) + s += arr[i]; + return s; + } + public static void zero(int[] arr) + { + for (int i = 0; i < arr.length; ++i) + arr[i] = 0; + } + + public static void assertIs28(int val) + { + if (val != 28) + { + throw new RuntimeException("Test failed " + val); + } + } + + public static void assertIsEqual(int val1, int val2) + { + if (val1 != val2) + { + throw new RuntimeException("Test failed " + val1 + " should be equal to " + val2); + } + } + + public static double sum(double[] arr) + { + double s = 0; + for (int i = 0; i < arr.length; ++i) + s += arr[i]; + return s; + } + public static void zero(double[] arr) + { + for (int i = 0; i < arr.length; ++i) + arr[i] = 0; + } + + public static void assertIs28(double val) + { + if (val != 28) + { + throw new RuntimeException("Test failed " + val); + } + } + + public static float sum(float[] arr) + { + float s = 0; + for (int i = 0; i < arr.length; ++i) + s += arr[i]; + return s; + } + public static void zero(float[] arr) + { + for (int i = 0; i < arr.length; ++i) + arr[i] = 0; + } + + public static void assertIs28(float val) + { + if (val != 28) + { + throw new RuntimeException("Test failed " + val); + } + } + + public static void main(String[] args) + { + { + float[] arr = new float[8]; + + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + } + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + assertIs28(global.sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + assertIs28(global.sum_crit(arr)); + } + } + { + double[] arr = new double[8]; + + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + } + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + assertIs28(global.sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + assertIs28(global.sum_crit(arr)); + } + } + { + byte[] arr = new byte[8]; + + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + } + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + assertIs28(global.sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + assertIs28(global.sum_crit(arr)); + } + } + { + long[] arr = new long[8]; + + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + } + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + assertIs28(global.sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + assertIs28(global.sum_crit(arr)); + } + } + { + short[] arr = new short[8]; + + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + } + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + assertIs28(global.sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + assertIs28(global.sum_crit(arr)); + } + } + { + int[] arr = new int[8]; + + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + } + for (int round = 0; round < 100; ++round) + { + zero(arr); global.assign(arr); + assertIs28(sum(arr)); + assertIs28(global.sum(arr)); + zero(arr); global.assign_crit(arr); + assertIs28(sum(arr)); + assertIs28(global.sum_crit(arr)); + } + } + { + int[] a = global.make_an_array(4); + for (int i = 0; i < a.length; ++i) + { + assertIsEqual(a[i], i); + } + } + + System.out.println("\n\n ALL TESTS COMPLETED SUCCESSFULLY\n"); + } +} diff --git a/ml/dlib/dlib/linker.h b/ml/dlib/dlib/linker.h new file mode 100644 index 000000000..e4f9b5daf --- /dev/null +++ b/ml/dlib/dlib/linker.h @@ -0,0 +1,9 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LINKEr_ +#define DLIB_LINKEr_ + +#include "linker/linker_kernel_1.h" + +#endif // DLIB_LINKEr_ + diff --git a/ml/dlib/dlib/linker/linker_kernel_1.cpp b/ml/dlib/dlib/linker/linker_kernel_1.cpp new file mode 100644 index 000000000..e76009b37 --- /dev/null +++ b/ml/dlib/dlib/linker/linker_kernel_1.cpp @@ -0,0 +1,357 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LINKER_KERNEL_1_CPp_ +#define DLIB_LINKER_KERNEL_1_CPp_ +#include "linker_kernel_1.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + linker:: + linker ( + ) : + running(false), + running_signaler(running_mutex), + A(0), + B(0), + service_connection_running_signaler(service_connection_running_mutex) + { + } + +// ---------------------------------------------------------------------------------------- + + linker:: + linker ( + connection& a, + connection& b + ) : + running(false), + running_signaler(running_mutex), + A(0), + B(0), + service_connection_running_signaler(service_connection_running_mutex) + { + link(a,b); + } + +// ---------------------------------------------------------------------------------------- + + linker:: + ~linker ( + ) + { + clear(); + } + +// ---------------------------------------------------------------------------------------- + + void linker:: + clear ( + ) + { + + // shutdown the connections + cons_mutex.lock(); + if (A != 0 ) + { + A->shutdown(); + A = 0; + } + if (B != 0) + { + B->shutdown(); + B = 0; + } + cons_mutex.unlock(); + + + // wait for the other threads to signal that they have ended + running_mutex.lock(); + while (running == true) + { + running_signaler.wait(); + } + running_mutex.unlock(); + + } + +// ---------------------------------------------------------------------------------------- + + bool linker:: + is_running ( + ) const + { + running_mutex.lock(); + bool temp = running; + running_mutex.unlock(); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + void linker:: + link ( + connection& a, + connection& b + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( + this->is_running() == false , + "\tvoid linker::link" + << "\n\tis_running() == " << this->is_running() + << "\n\tthis: " << this + ); + + running_mutex.lock(); + running = true; + running_mutex.unlock(); + + cons_mutex.lock(); + A = &a; + B = &b; + cons_mutex.unlock(); + + + + service_connection_running_mutex.lock(); + service_connection_running = true; + service_connection_running_mutex.unlock(); + + service_connection_error_mutex.lock(); + service_connection_error = false; + service_connection_error_mutex.unlock(); + + // if we fail to make the thread + if (!create_new_thread(service_connection,this)) + { + a.shutdown(); + b.shutdown(); + + service_connection_running_mutex.lock(); + service_connection_running = false; + service_connection_running_mutex.unlock(); + + cons_mutex.lock(); + A = 0; + B = 0; + cons_mutex.unlock(); + + running_mutex.lock(); + running = false; + running_mutex.unlock(); + + + + throw dlib::thread_error ( + ECREATE_THREAD, + "failed to make new thread in linker::link()" + ); + } + + + + // forward data from a to b + char buf[200]; + int status; + bool error = false; // becomes true if one of the connections returns an error + while (true) + { + status = a.read(buf,sizeof(buf)); + // if there was an error reading from the socket + if (status == OTHER_ERROR) + { + error = true; + break; + } + else if (status == SHUTDOWN) + { + b.shutdown(); + } + + if (status <= 0) + { + // if a has closed normally + if (status == 0) + b.shutdown_outgoing(); + break; + } + + status = b.write(buf,status); + // if there was an error writing to the socket then break + if (status == OTHER_ERROR) + { + error = true; + break; + } + + if (status <= 0) + break; + } + + + // if there was an error then shutdown both connections + if (error) + { + a.shutdown(); + b.shutdown(); + } + + + + + // wait for the other thread to end + service_connection_running_mutex.lock(); + while(service_connection_running) + { + service_connection_running_signaler.wait(); + } + service_connection_running_mutex.unlock(); + + + // make sure connections are shutdown + a.shutdown(); + b.shutdown(); + + + // both threads have ended so the connections are no longer needed + cons_mutex.lock(); + A = 0; + B = 0; + cons_mutex.unlock(); + + + // if service_connection terminated due to an error then set error to true + service_connection_error_mutex.lock(); + if (service_connection_error) + error = true; + service_connection_error_mutex.unlock(); + + + // if we are ending because of an error + if (error) + { + + // signal that the link() function is ending + running_mutex.lock(); + running = false; + running_signaler.broadcast(); + running_mutex.unlock(); + + // throw the exception for this error + throw dlib::socket_error ( + ECONNECTION, + "a connection returned an error in linker::link()" + ); + + } + + // signal that the link() function is ending + running_mutex.lock(); + running = false; + running_signaler.broadcast(); + running_mutex.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + void linker:: + service_connection ( + void* param + ) + { + linker& p = *static_cast(param); + + p.cons_mutex.lock(); + // if the connections are gone for whatever reason then return + if (p.A == 0 || p.B == 0) + { + // signal that this function is ending + p.service_connection_running_mutex.lock(); + p.service_connection_running = false; + p.service_connection_running_signaler.broadcast(); + p.service_connection_running_mutex.unlock(); + return; + } + connection& a = *p.A; + connection& b = *p.B; + p.cons_mutex.unlock(); + + + + // forward data from b to a + char buf[200]; + int status; + bool error = false; + while (true) + { + status = b.read(buf,sizeof(buf)); + // if there was an error reading from the socket + if (status == OTHER_ERROR) + { + error = true; + break; + } + else if (status == SHUTDOWN) + { + a.shutdown(); + } + + + if (status <= 0) + { + // if b has closed normally + if (status == 0) + a.shutdown_outgoing(); + break; + } + + + status = a.write(buf,status); + // if there was an error writing to the socket then break + if (status == OTHER_ERROR) + { + error = true; + break; + } + + if (status <= 0) + break; + } + + + // if there was an error then shutdown both connections + if (error) + { + a.shutdown(); + b.shutdown(); + } + + + // if there was an error then signal that + if (error) + { + p.service_connection_error_mutex.lock(); + p.service_connection_error = true; + p.service_connection_error_mutex.unlock(); + } + + // signal that this function is ending + p.service_connection_running_mutex.lock(); + p.service_connection_running = false; + p.service_connection_running_signaler.broadcast(); + p.service_connection_running_mutex.unlock(); + + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_LINKER_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/linker/linker_kernel_1.h b/ml/dlib/dlib/linker/linker_kernel_1.h new file mode 100644 index 000000000..b101026b2 --- /dev/null +++ b/ml/dlib/dlib/linker/linker_kernel_1.h @@ -0,0 +1,141 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LINKER_KERNEl_1_ +#define DLIB_LINKER_KERNEl_1_ + +#include "linker_kernel_abstract.h" +#include "../threads.h" +#include "../sockets.h" +#include "../algs.h" + + +namespace dlib +{ + + class linker + { + + /*! + INITIAL VALUE + running == false + A == 0 + B == 0 + running_mutex == a mutex + running_signaler == a signaler associated with running_mutex + cons_mutex == a mutex + service_connection_running == false + service_connection_running_mutex == a mutex + service_connection_running_signaler == a signaler associated with + service_connection_running_mutex + + service_connection_error == false + service_connection_error_mutex == a mutex + + + + CONVENTION + running == is_running() + running_mutex == a mutex for running + running_signaler == a signaler for signaling when + running becomes false and is associated with + running_mutex + cons_mutex == a mutex for A and B + + service_connection_running == true when service_connection() is + running or is about to run else + false + service_connection_running_mutex == a mutex for service_connection_running + service_connection_running_signaler == a signaler associated with + service_connection_running_mutex + + if (running) then + A == address of a from link() + B == address of b from link() + else + A == 0 + B == 0 + + service_connection_error == service_connection uses this bool + to indicate if it terminated due to + an error or not + service_connection_error_mutex == a mutex for service_connection_error + + + !*/ + + public: + + // These two typedefs are here for backwards compatibility with previous + // versions of dlib. + typedef linker kernel_1a; + typedef linker kernel_1a_c; + + linker( + ); + + linker ( + connection& a, + connection& b + ); + + virtual ~linker( + ); + + void clear( + ); + + bool is_running( + ) const; + + void link ( + connection& a, + connection& b + ); + + + private: + + static void service_connection ( + void* param + ); + /*! + requires + param == pointer to a linker object + ensures + waits for data from b and forwards it to a and + if (b closes normally or is shutdown()) service_connection ends and + if (b closes normally) then a.shutdown_outgoing() is called and + if (a or b returns an error) then a and b are shutdown() + !*/ + + + // data members + bool running; + mutex running_mutex; + signaler running_signaler; + connection* A; + connection* B; + mutex cons_mutex; + + bool service_connection_running; + mutex service_connection_running_mutex; + signaler service_connection_running_signaler; + + bool service_connection_error; + mutex service_connection_error_mutex; + + // restricted functions + linker(linker&); // copy constructor + linker& operator=(linker&); // assignment operator + }; + + + +} + +#ifdef NO_MAKEFILE +#include "linker_kernel_1.cpp" +#endif + +#endif // DLIB_LINKER_KERNEl_1_ + diff --git a/ml/dlib/dlib/linker/linker_kernel_abstract.h b/ml/dlib/dlib/linker/linker_kernel_abstract.h new file mode 100644 index 000000000..cef0901e4 --- /dev/null +++ b/ml/dlib/dlib/linker/linker_kernel_abstract.h @@ -0,0 +1,141 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LINKER_KERNEl_ABSTRACT_ +#ifdef DLIB_LINKER_KERNEl_ABSTRACT_ + +#include "../threads/threads_kernel_abstract.h" +#include "../sockets/sockets_kernel_abstract.h" + +namespace dlib +{ + + class linker + { + + /*! + INITIAL VALUE + is_running() == false + + + WHAT THIS OBJECT REPRESENTS + This object represents something that takes two connections and lets + them talk to each other. i.e. any incoming data from one connection is + passed unaltered to the other and vice versa. + + note that linker objects are not swappable. + + Also note that when one connection is closed shutdown_outgoing() + is called on the other to signal that no more data will be sent + in that direction on the connection. + (i.e. the FIN packet is effectively also forwarded by the linker object) + + THREAD SAFETY + all member functions are thread-safe. + + !*/ + + public: + + linker( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + linker ( + connection& a, + connection& b + ); + /*! + ensures + - #*this is properly initialized + - immediately invokes link(a,b); + (i.e. using this constructor is the same as creating a linker with + the default constructor and then immediately invoking link() on it) + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~linker( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + - if (is_running()) then + - the two connections being linked will be shutdown() + throws + - std::bad_alloc + if this exception is thrown then the linker object is unusable + until clear() is called and succeeds and + if is_running() then the connections will STILL be shutdown() + even though an exception is being thrown + !*/ + + bool is_running( + ) const; + /*! + ensures + - returns true if link() is running else + - returns false if link() is not running or has released all its + resources and is about to terminate + throws + - std::bad_alloc + !*/ + + + void link ( + connection& a, + connection& b + ); + /*! + requires + - is_running() == false + ensures + - all incoming data from connection a will be forwarded to b + - all incoming data from connection b will be forwarded to a + - #a and #b will have been shutdown() + - link() will block until both of the connections have ended + or an error occurs + throws + - std::bad_alloc + link() may throw this exception and if it does then the object + will be unusable until clear() is called and succeeds and + connections a and b will be shutdown() + - dlib::socket_error + link() will throw a this exception if one of the connections + returns an error value (being shutdown is not an error). + If this happens then the linker object will be cleared and + have its initial value. note that if this happens then the + connections being linked will be shutdown() + - dlib::thread_error + link() will throw a this exception if there is a problem + creating new threads. Or it may throw this exception if there + is a problem creating threading objects. If this happens + then the linker object will be cleared and have its initial value. + note that if this happens then the connections being linked will + be shutdown(). + !*/ + + private: + + // restricted functions + linker(linker&); // copy constructor + linker& operator=(linker&); // assignment operator + }; + +} + +#endif // DLIB_LINKER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/locale b/ml/dlib/dlib/locale new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/locale @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/logger.h b/ml/dlib/dlib/logger.h new file mode 100644 index 000000000..e49fad8fd --- /dev/null +++ b/ml/dlib/dlib/logger.h @@ -0,0 +1,11 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LOGGEr_ +#define DLIB_LOGGEr_ + +#include "logger/logger_kernel_1.h" +#include "logger/extra_logger_headers.h" +#include "logger/logger_config_file.h" + +#endif // DLIB_LOGGEr_ + diff --git a/ml/dlib/dlib/logger/extra_logger_headers.cpp b/ml/dlib/dlib/logger/extra_logger_headers.cpp new file mode 100644 index 000000000..becc1ab2b --- /dev/null +++ b/ml/dlib/dlib/logger/extra_logger_headers.cpp @@ -0,0 +1,40 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_EXTRA_LOGGER_HEADERs_CPP_ +#define DLIB_EXTRA_LOGGER_HEADERs_CPP_ + +#include "extra_logger_headers.h" +#include +#include + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + void print_datetime_logger_header ( + std::ostream& out, + const std::string& logger_name, + const log_level& l, + const uint64 thread_id + ) + { + using namespace std; + char* buf; + + time_t t = time(0); + buf = ctime(&t); + // remove the trailing '\n' + size_t size = strlen(buf); + buf[size-1] = '\0'; + + out << l.name << " (" << buf << ") [" << thread_id << "] " << logger_name << ": "; + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_EXTRA_LOGGER_HEADERs_CPP_ + + diff --git a/ml/dlib/dlib/logger/extra_logger_headers.h b/ml/dlib/dlib/logger/extra_logger_headers.h new file mode 100644 index 000000000..6eb24d84c --- /dev/null +++ b/ml/dlib/dlib/logger/extra_logger_headers.h @@ -0,0 +1,41 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_EXTRA_LOGGER_HEADERs_ +#define DLIB_EXTRA_LOGGER_HEADERs_ + +#include "logger_kernel_abstract.h" +#include "logger_kernel_1.h" +#include +#include +#include "../uintn.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + void print_datetime_logger_header ( + std::ostream& out, + const std::string& logger_name, + const log_level& l, + const uint64 thread_id + ); + /*! + requires + - is not called more than once at a time (i.e. is not called from multiple + threads at the same time). + ensures + - let DATE be the current date and time (e.g. Thu Aug 31 16:41:52 2006). + - prints a string to out in the form: "l.name (DATE) [thread_id] logger_name:" + !*/ + +} + +// ---------------------------------------------------------------------------------------- + +#ifdef NO_MAKEFILE +#include "extra_logger_headers.cpp" +#endif + +#endif // DLIB_EXTRA_LOGGER_HEADERs_ + diff --git a/ml/dlib/dlib/logger/logger_config_file.cpp b/ml/dlib/dlib/logger/logger_config_file.cpp new file mode 100644 index 000000000..108f66c8c --- /dev/null +++ b/ml/dlib/dlib/logger/logger_config_file.cpp @@ -0,0 +1,214 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LOGGER_CONFIg_FILE_CPP +#define DLIB_LOGGER_CONFIg_FILE_CPP + +#include "logger_config_file.h" +#include +#include "../config_reader.h" +#include +#include +#include "../error.h" +#include "../map.h" +#include "../string.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + namespace logger_config_file_helpers + { + +// ---------------------------------------------------------------------------------------- + + std::ostream& get_file_stream ( + const std::string& file_name + ) + { + using namespace std; + static dlib::mutex m; + auto_mutex M(m); + static dlib::map::kernel_1a_c file_map; + + if (file_map.is_in_domain(file_name) == false) + { + // We won't ever delete this output stream. It should be around for the + // entire life of the program so just let the OS take care of it. + ostream* fout = new ofstream(file_name.c_str()); + if (!(*fout)) + { + delete fout; + throw error("logger_config: unable to open output file " + file_name); + } + + // add this file to our file map + string temp(file_name); + file_map.add(temp,fout); + } + + return *file_map[file_name]; + } + +// ---------------------------------------------------------------------------------------- + + log_level string_to_log_level ( + const std::string& level + ) + { + using namespace std; + if (level == "LALL" || level == "ALL" || level == "all") + return LALL; + else if (level == "LNONE" || level == "NONE" || level == "none") + return LNONE; + else if (level == "LTRACE" || level == "TRACE" || level == "trace") + return LTRACE; + else if (level == "LDEBUG" || level == "DEBUG" || level == "debug") + return LDEBUG; + else if (level == "LINFO" || level == "INFO" || level == "info") + return LINFO; + else if (level == "LWARN" || level == "WARN" || level == "warn") + return LWARN; + else if (level == "LERROR" || level == "ERROR" || level == "error") + return LERROR; + else if (level == "LFATAL" || level == "FATAL" || level == "fatal") + return LFATAL; + else + { + const int priority = string_cast(level); + return log_level(priority,"CONFIG_FILE_DEFINED"); + } + } + +// ---------------------------------------------------------------------------------------- + + void configure_sub_blocks ( + const config_reader& cr, + const std::string& name + ) + { + using namespace std; + + logger dlog(name.c_str()); + + if (cr.is_key_defined("logging_level")) + { + dlog.set_level(string_to_log_level(cr["logging_level"])); + } + + if (cr.is_key_defined("output")) + { + string output = cr["output"]; + if (output == "cout") + dlog.set_output_stream(cout); + else if (output == "cerr") + dlog.set_output_stream(cerr); + else if (output == "clog") + dlog.set_output_stream(clog); + else + { + istringstream sin(output); + string one, two, three; + sin >> one; + sin >> two; + sin >> three; + if (one == "file" && three.size() == 0) + dlog.set_output_stream(get_file_stream(two)); + else + throw error("logger_config: invalid argument to output option: " + output); + } + + } // if (cr.is_key_defined("output")) + + // now configure all the sub-blocks + std_vector_c blocks; + cr.get_blocks(blocks); + for (unsigned long i = 0; i < blocks.size(); ++i) + { + configure_sub_blocks(cr.block(blocks[i]), name + "." + blocks[i]); + } + + } + +// ---------------------------------------------------------------------------------------- + + } // namespace + +// ---------------------------------------------------------------------------------------- + + void configure_loggers_from_file ( + const std::string& file_name + ) + { + std::ifstream fin(file_name.c_str()); + + if (!fin) + throw logger_config_file_error("logger_config: unable to open config file " + file_name); + + config_reader temp(fin); + configure_loggers_from_file(temp); + } + +// ---------------------------------------------------------------------------------------- + + void configure_loggers_from_file ( + const config_reader& main_cr + ) + { + using namespace logger_config_file_helpers; + using namespace std; + + if (main_cr.is_block_defined("logger_config")) + { + const config_reader& cr = main_cr.block("logger_config"); + + if (cr.is_key_defined("logging_level")) + { + set_all_logging_levels(string_to_log_level(cr["logging_level"])); + } + + if (cr.is_key_defined("output")) + { + string output = cr["output"]; + if (output == "cout") + set_all_logging_output_streams(cout); + else if (output == "cerr") + set_all_logging_output_streams(cerr); + else if (output == "clog") + set_all_logging_output_streams(clog); + else + { + istringstream sin(output); + string one, two, three; + sin >> one; + sin >> two; + sin >> three; + if (one == "file" && three.size() == 0) + set_all_logging_output_streams(get_file_stream(two)); + else + throw logger_config_file_error("logger_config: invalid argument to output option: " + output); + } + + } // if (cr.is_key_defined("output")) + + // now configure all the sub-blocks + std_vector_c blocks; + cr.get_blocks(blocks); + for (unsigned long i = 0; i < blocks.size(); ++i) + { + configure_sub_blocks(cr.block(blocks[i]), blocks[i]); + } + + } + } + +// ---------------------------------------------------------------------------------------- + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LOGGER_CONFIg_FILE_CPP + + + diff --git a/ml/dlib/dlib/logger/logger_config_file.h b/ml/dlib/dlib/logger/logger_config_file.h new file mode 100644 index 000000000..b0a030f80 --- /dev/null +++ b/ml/dlib/dlib/logger/logger_config_file.h @@ -0,0 +1,135 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LOGGER_CONFIg_FILE_ +#define DLIB_LOGGER_CONFIg_FILE_ + +#include "logger_kernel_abstract.h" +#include "logger_kernel_1.h" +#include +#include "../config_reader.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + class logger_config_file_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception class used by the configure_loggers_from_file() + function defined below. + !*/ + public: + logger_config_file_error(const std::string& s):error(s){} + }; + + void configure_loggers_from_file ( + const std::string& file_name + ); + /*! + ensures + - configures the loggers with the contents of the file_name file + throws + - dlib::logger_config_file_error + this exception is thrown if there is a problem reading the config file + !*/ + + void configure_loggers_from_file ( + const config_reader& cr + ); + /*! + ensures + - configures the loggers with the contents of cr. This function is just like + the above version that reads from a file except that it reads from an in-memory + config_reader instead. + throws + - dlib::logger_config_file_error + this exception is thrown if there is a problem reading the config file + !*/ + +// ---------------------------------------------------------------------------------------- + + /*! + # ----------------------------------------------- + # ------------- EXAMPLE CONFIG FILE ------------- + # ----------------------------------------------- + + # The overall format of the config file is the same as the one defined by + # the config_reader component of this library. + + # This line is a comment line + + # The config file always has a block named logger_config. This is where all the + # config data for the loggers reside. + logger_config + { + # This sets all loggers to the level LINFO since it is just inside the + # logger_config block + logging_level = info + + # Alternatively we could specify a user defined logging level by + # supplying a priority number. The following line would specify + # that only logging levels at or above 100 are printed. (note that + # you would have to comment out the logging_level statement above + # to avoid a conflict). + # logging_level = 100 + + parent_logger + { + # This sets all loggers named "parent_logger" or children of + # loggers with that name to not log at all (i.e. to logging level + # LNONE). + logging_level = none + } + + + parent_logger2 + { + # set loggers named "parent_logger2" and its children loggers + # to write their output to a file named out.txt + output = file out.txt + + child_logger + { + # Set loggers named "parent_logger2.child_logger" and children of loggers + # with this name to logging level LALL + logging_level = all + + # Note that this logger will also log to out.txt because that is what + # its parent does and we haven't overridden it here with something else. + # if we wanted this logger to write to cout instead we could uncomment + # the following line: + # output = cout + } + } + } + + # So in summary, all logger config stuff goes inside a block named logger_config. Then + # inside that block all blocks must be the names of loggers. There are only two keys, + # logging_level and output. + # + # The valid values of logging_level are: + # "LALL", "LNONE", "LTRACE", "LDEBUG", "LINFO", "LWARN", "LERROR", "LFATAL", + # "ALL", "NONE", "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", + # "all", "none", "trace", "debug", "info", "warn", "error", "fatal", or + # any integral value + # + # The valid values of output are: + # "cout", "cerr", "clog", or a string of the form "file some_file_name" + # which causes the output to be logged to the specified file. + # + !*/ + + +} + +// ---------------------------------------------------------------------------------------- + +#ifdef NO_MAKEFILE +#include "logger_config_file.cpp" +#endif + +#endif // DLIB_LOGGER_CONFIg_FILE_ + + + diff --git a/ml/dlib/dlib/logger/logger_kernel_1.cpp b/ml/dlib/dlib/logger/logger_kernel_1.cpp new file mode 100644 index 000000000..093cd29a8 --- /dev/null +++ b/ml/dlib/dlib/logger/logger_kernel_1.cpp @@ -0,0 +1,498 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LOGGER_KERNEL_1_CPp_ +#define DLIB_LOGGER_KERNEL_1_CPp_ + +#include "logger_kernel_1.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + void set_all_logging_output_streams ( + std::ostream& out_ + ) + { + logger::global_data& gd = logger::get_global_data(); + auto_mutex M(gd.m); + gd.loggers.reset(); + while (gd.loggers.move_next()) + { + gd.loggers.element()->out.rdbuf(out_.rdbuf()); + gd.loggers.element()->hook.clear(); + } + + gd.set_output_stream("",out_); + + // set the default hook to be an empty member function pointer + logger::hook_mfp hook; + gd.set_output_hook("",hook); + } + + void set_all_logging_levels ( + const log_level& new_level + ) + { + logger::global_data& gd = logger::get_global_data(); + auto_mutex M(gd.m); + gd.loggers.reset(); + while (gd.loggers.move_next()) + { + gd.loggers.element()->cur_level = new_level; + } + + gd.set_level("",new_level); + } + + void set_all_logging_headers ( + const print_header_type& new_header + ) + { + logger::global_data& gd = logger::get_global_data(); + auto_mutex M(gd.m); + gd.loggers.reset(); + while (gd.loggers.move_next()) + { + gd.loggers.element()->print_header = new_header; + } + + gd.set_logger_header("",new_header); + } + +// ---------------------------------------------------------------------------------------- + + namespace logger_helper_stuff + { + class helper + { + public: + helper() + { + std::ostringstream sout; + print_default_logger_header(sout,"some_name",LDEBUG,0); + } + }; + // do this to make sure all the static members of print_default_logger_header get + // initialized when the program turns on. + static helper a; + // make a logger to make extra sure the static global_data object gets + // initialized before any threads start up. Also do this so that there is always + // at least one logger so that the global data won't be deleted until the + // program is terminating. + static logger log("dlib"); + } + +// ---------------------------------------------------------------------------------------- + + void print_default_logger_header ( + std::ostream& out, + const std::string& logger_name, + const log_level& l, + const uint64 thread_id + ) + { + using namespace std; + static timestamper ts; + static const uint64 first_time = ts.get_timestamp(); + + const uint64 cur_time = (ts.get_timestamp() - first_time)/1000; + streamsize old_width = out.width(); out.width(5); + out << cur_time << " " << l.name; + out.width(old_width); + + out << " [" << thread_id << "] " << logger_name << ": "; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// global_data stuff +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + logger::global_data:: + ~global_data ( + ) + { + unregister_thread_end_handler(*this,&global_data::thread_end_handler); + } + +// ---------------------------------------------------------------------------------------- + + logger::global_data:: + global_data( + ) : + next_thread_name(1) + { + // make sure the main program thread always has id 0. Since there is + // a global logger object declared in this file we should expect that + // the global_data object will be initialized in the main program thread + // so if we call get_thread_id() now we should get the main thread id. + thread_id_type main_id = get_thread_id(); + uint64 id_zero = 0; + thread_names.add(main_id,id_zero); + + // set up the defaults + auto_flush_table.val = true; + streambuf_table.val = std::cout.rdbuf(); + header_table.val = print_default_logger_header; + + // also allocate an initial buffer for hook based logging + hookbuf.buffer.reserve(1000); + } + + logger::global_data::level_container:: + level_container ( + ) : val(300,"ERROR") {} + +// ---------------------------------------------------------------------------------------- + + template + const T& search_tables ( + const T& c, + const std::string& name + ) + { + if (c.table.size() == 0 || name.size() == 0) + return c; + + const std::string::size_type pos = name.find_first_of("."); + const std::string first = name.substr(0,pos); + std::string last; + if (pos != std::string::npos) + last = name.substr(pos+1); + + if (c.table.is_in_domain(first)) + { + return search_tables(*c.table[first], last); + } + else + { + return c; + } + } + +// ---------------------------------------------------------------------------------------- + + template + void assign_tables ( + T& c, + const std::string& name, + const U& val + ) + { + if (name.size() == 0) + { + c.val = val; + c.table.clear(); + return; + } + + const std::string::size_type pos = name.find_first_of("."); + std::string first = name.substr(0,pos); + std::string last; + if (pos != std::string::npos) + last = name.substr(pos+1); + + if (c.table.is_in_domain(first)) + { + assign_tables(*c.table[first], last, val); + } + else + { + std::unique_ptr temp (new T); + temp->val = c.val; + assign_tables(*temp, last, val); + c.table.add(first,temp); + } + } + +// ---------------------------------------------------------------------------------------- + + const log_level logger::global_data:: + level ( + const std::string& name + ) const + { + auto_mutex M(m); + return search_tables(level_table, name).val; + } + +// ---------------------------------------------------------------------------------------- + + void logger::global_data:: + set_level ( + const std::string& name, + const log_level& new_level + ) + { + auto_mutex M(m); + assign_tables(level_table, name, new_level); + } + +// ---------------------------------------------------------------------------------------- + + bool logger::global_data:: + auto_flush ( + const std::string& name + ) const + { + auto_mutex M(m); + return search_tables(auto_flush_table, name).val; + } + +// ---------------------------------------------------------------------------------------- + + void logger::global_data:: + set_auto_flush ( + const std::string& name, + bool enabled + ) + { + auto_mutex M(m); + assign_tables(auto_flush_table, name, enabled); + } + +// ---------------------------------------------------------------------------------------- + + std::streambuf* logger::global_data:: + output_streambuf ( + const std::string& name + ) + { + auto_mutex M(m); + return search_tables(streambuf_table, name).val; + } + +// ---------------------------------------------------------------------------------------- + + void logger::global_data:: + set_output_stream ( + const std::string& name, + std::ostream& out_ + ) + { + auto_mutex M(m); + assign_tables( streambuf_table, name, out_.rdbuf()); + } + +// ---------------------------------------------------------------------------------------- + + void logger::global_data:: + set_output_stream ( + const std::string& name, + std::streambuf& buf + ) + { + auto_mutex M(m); + assign_tables( streambuf_table, name, &buf); + } + +// ---------------------------------------------------------------------------------------- + + logger::hook_mfp logger::global_data:: + output_hook ( + const std::string& name + ) + { + auto_mutex M(m); + return search_tables(hook_table, name).val; + } + +// ---------------------------------------------------------------------------------------- + + void logger::global_data:: + set_output_hook ( + const std::string& name, + const hook_mfp& hook + ) + { + auto_mutex M(m); + assign_tables( hook_table, name, hook); + } + +// ---------------------------------------------------------------------------------------- + + print_header_type logger::global_data:: + logger_header ( + const std::string& name + ) + { + auto_mutex M(m); + return search_tables(header_table, name).val; + } + +// ---------------------------------------------------------------------------------------- + + void logger::global_data:: + set_logger_header ( + const std::string& name, + print_header_type ph + ) + { + auto_mutex M(m); + assign_tables(header_table, name, ph); + } + +// ---------------------------------------------------------------------------------------- + + logger::global_data& logger::get_global_data() + { + // Allocate the global_data on the heap rather than on the stack because + // we want to guard against the case where this static object would be destroyed + // during program termination BEFORE all logger objects are destroyed. + static global_data* gd = new global_data; + return *gd; + } + +// ---------------------------------------------------------------------------------------- + + void logger::global_data:: + thread_end_handler ( + ) + { + auto_mutex M(m); + thread_id_type id = get_thread_id(); + thread_id_type junkd; + uint64 junkr; + thread_names.remove(id,junkd,junkr); + } + +// ---------------------------------------------------------------------------------------- + + uint64 logger::global_data:: + get_thread_name ( + ) + { + thread_id_type id = get_thread_id(); + uint64 thread_name; + if (thread_names.is_in_domain(id)) + { + thread_name = thread_names[id]; + } + else + { + if (is_dlib_thread(id)) + register_thread_end_handler(*this,&global_data::thread_end_handler); + thread_name = next_thread_name; + thread_names.add(id,thread_name); + thread_name = next_thread_name; + ++next_thread_name; + } + return thread_name; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// logger_stream stuff +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void logger::logger_stream:: + print_header_and_stuff ( + ) + { + if (!been_used) + { + log.gd.m.lock(); + + // Check if the output hook is setup. If it isn't then we print the logger + // header like normal. Otherwise we need to remember to clear out the output + // stringstream we always write to. + if (log.hook.is_set() == false) + { + log.logger_header()(log.out,log.name(),l,log.gd.get_thread_name()); + } + else + { + // Make sure the hook buffer doesn't have any old data in it before we start + // logging a new message into it. + log.gd.hookbuf.buffer.resize(0); + } + been_used = true; + } + } + +// ---------------------------------------------------------------------------------------- + + void logger::logger_stream:: + print_end_of_line ( + ) + { + auto_unlock M(log.gd.m); + + if (log.hook.is_set() == false) + { + if (log.auto_flush_enabled) + log.out << std::endl; + else + log.out << "\n"; + } + else + { + // Make sure the buffer is a proper C-string + log.gd.hookbuf.buffer.push_back('\0'); + // call the output hook with all the info regarding this log message. + log.hook(log.name(), l, log.gd.get_thread_name(), &log.gd.hookbuf.buffer[0]); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// logger stuff +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + logger:: + logger ( + const std::string& name_ + ) : + gd(get_global_data()), + logger_name(name_), + out(gd.output_streambuf(logger_name)), + cur_level(gd.level(logger_name)) + { + DLIB_ASSERT(name_[0] != '\0', + "\tlogger::logger()" + << "\n\tYou can't make a logger with an empty name" + << "\n\tthis: " << this + ); + + auto_mutex M(gd.m); + logger* temp = this; + gd.loggers.add(temp); + + // load the appropriate settings + print_header = gd.logger_header(logger_name); + auto_flush_enabled = gd.auto_flush(logger_name); + hook = gd.output_hook(logger_name); + } + +// ---------------------------------------------------------------------------------------- + + logger:: + ~logger ( + ) + { + gd.m.lock(); + gd.loggers.destroy(this); + // if this was the last logger then delete the global data + if (gd.loggers.size() == 0) + { + gd.m.unlock(); + delete &gd; + } + else + { + gd.m.unlock(); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LOGGER_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/logger/logger_kernel_1.h b/ml/dlib/dlib/logger/logger_kernel_1.h new file mode 100644 index 000000000..528bd6f67 --- /dev/null +++ b/ml/dlib/dlib/logger/logger_kernel_1.h @@ -0,0 +1,687 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LOGGER_KERNEl_1_ +#define DLIB_LOGGER_KERNEl_1_ + +#include +#include +#include +#include +#include + +#include "../threads.h" +#include "../misc_api.h" +#include "../set.h" +#include "logger_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include "../uintn.h" +#include "../map.h" +#include "../member_function_pointer.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class log_level + { + public: + log_level( + int priority_, + const char* name_ + ) : + priority(priority_) + { + strncpy(name,name_,19); + name[19] = '\0'; + } + + bool operator< (const log_level& rhs) const { return priority < rhs.priority; } + bool operator<=(const log_level& rhs) const { return priority <= rhs.priority; } + bool operator> (const log_level& rhs) const { return priority > rhs.priority; } + bool operator>=(const log_level& rhs) const { return priority >= rhs.priority; } + + int priority; + char name[20]; + }; + + inline std::ostream& operator<< (std::ostream& out, const log_level& item) + { + out << item.name; + return out; + } + + const log_level LALL (std::numeric_limits::min(),"ALL"); + const log_level LNONE (std::numeric_limits::max(),"NONE"); + const log_level LTRACE(-100,"TRACE"); + const log_level LDEBUG(0 ,"DEBUG"); + const log_level LINFO (100,"INFO "); + const log_level LWARN (200,"WARN "); + const log_level LERROR(300,"ERROR"); + const log_level LFATAL(400,"FATAL"); + +// ---------------------------------------------------------------------------------------- + + void set_all_logging_output_streams ( + std::ostream& out + ); + + void set_all_logging_levels ( + const log_level& new_level + ); + + typedef void (*print_header_type)( + std::ostream& out, + const std::string& logger_name, + const log_level& l, + const uint64 thread_id + ); + + void set_all_logging_headers ( + const print_header_type& new_header + ); + +// ---------------------------------------------------------------------------------------- + + void print_default_logger_header ( + std::ostream& out, + const std::string& logger_name, + const log_level& l, + const uint64 thread_id + ); + + template < + typename T + > + void set_all_logging_output_hooks ( + T& object, + void (T::*hook_)(const std::string& logger_name, + const log_level& l, + const uint64 thread_id, + const char* message_to_log) + ); + + template < + typename T + > + void set_all_logging_output_hooks ( + T& object + ) + { + set_all_logging_output_hooks(object, &T::log); + } + +// ---------------------------------------------------------------------------------------- + + class logger + { + /*! + INITIAL VALUE + - print_header == print_default_logger_header + - out.rdbuf() == std::cout.rdbuf() + - cur_level == LERROR + - auto_flush_enabled == true + - hook.is_set() == false + + CONVENTION + - print_header == logger_header() + - if (hook.is_set() == false) then + - out.rdbuf() == output_streambuf() + - else + - out.rdbuf() == &gd.hookbuf + - output_streambuf() == 0 + + - cur_level == level() + - logger_name == name() + - auto_flush_enabled == auto_flush() + + - logger::gd::loggers == a set containing all currently existing loggers. + - logger::gd::m == the mutex used to lock everything in the logger + - logger::gd::thread_names == a map of thread ids to thread names. + - logger::gd::next_thread_name == the next thread name that will be given out + to a thread when we find that it isn't already in thread_names. + !*/ + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + class logger_stream + { + /*! + INITIAL VALUE + - been_used == false + + CONVENTION + - enabled == is_enabled() + - if (been_used) then + - logger::gd::m is locked + - someone has used the << operator to write something to the + output stream. + !*/ + public: + logger_stream ( + const log_level& l_, + logger& log_ + ) : + l(l_), + log(log_), + been_used(false), + enabled (l.priority >= log.cur_level.priority) + {} + + inline ~logger_stream( + ) + { + if (!been_used) + { + return; + } + else + { + print_end_of_line(); + } + } + + bool is_enabled ( + ) const { return enabled; } + + template + inline logger_stream& operator << ( + const T& item + ) + { + if (!enabled) + { + return *this; + } + else + { + print_header_and_stuff(); + log.out << item; + return *this; + } + } + + private: + + void print_header_and_stuff ( + ); + /*! + ensures + - if (!been_used) then + - prints the logger header + - locks log.gd.m + - #been_used == true + !*/ + + void print_end_of_line ( + ); + /*! + ensures + - prints a newline to log.out + - unlocks log.gd.m + !*/ + + const log_level& l; + logger& log; + bool been_used; + const bool enabled; + }; // end of class logger_stream + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + friend class logger_stream; + public: + + typedef member_function_pointer hook_mfp; + + logger ( + const std::string& name_ + ); + + virtual ~logger ( + ); + + const std::string& name ( + ) const { return logger_name; } + + logger_stream operator << ( + const log_level& l + ) const { return logger_stream(l,const_cast(*this)); } + + bool is_child_of ( + const logger& log + ) const + { + return (name().find(log.name() + ".") == 0) || (log.name() == name()); + } + + const log_level level ( + ) const + { + auto_mutex M(gd.m); + return log_level(cur_level); + }; + + void set_level ( + const log_level& new_level + ) + { + auto_mutex M(gd.m); + gd.loggers.reset(); + while (gd.loggers.move_next()) + { + if (gd.loggers.element()->is_child_of(*this)) + gd.loggers.element()->cur_level = new_level; + } + + gd.set_level(logger_name, new_level); + } + + bool auto_flush ( + ) const + { + auto_mutex M(gd.m); + return auto_flush_enabled; + }; + + void set_auto_flush ( + bool enabled + ) + { + auto_mutex M(gd.m); + gd.loggers.reset(); + while (gd.loggers.move_next()) + { + if (gd.loggers.element()->is_child_of(*this)) + gd.loggers.element()->auto_flush_enabled = enabled; + } + + gd.set_auto_flush(logger_name, enabled); + } + + std::streambuf* output_streambuf ( + ) + { + auto_mutex M(gd.m); + + // if there is an output hook set then we are supposed to return 0. + if (hook) + return 0; + else + return out.rdbuf(); + } + + template < + typename T + > + void set_output_hook ( + T& object, + void (T::*hook_)(const std::string& logger_name, + const log_level& l, + const uint64 thread_id, + const char* message_to_log) + ) + { + auto_mutex M(gd.m); + hook.set(object, hook_); + + gd.loggers.reset(); + while (gd.loggers.move_next()) + { + if (gd.loggers.element()->is_child_of(*this)) + { + gd.loggers.element()->out.rdbuf(&gd.hookbuf); + gd.loggers.element()->hook = hook; + } + } + + gd.set_output_hook(logger_name, hook); + gd.set_output_stream(logger_name, gd.hookbuf); + } + + void set_output_stream ( + std::ostream& out_ + ) + { + auto_mutex M(gd.m); + gd.loggers.reset(); + while (gd.loggers.move_next()) + { + if (gd.loggers.element()->is_child_of(*this)) + { + gd.loggers.element()->out.rdbuf(out_.rdbuf()); + gd.loggers.element()->hook.clear(); + } + } + + gd.set_output_stream(logger_name, out_); + + hook.clear(); + gd.set_output_hook(logger_name, hook); + } + + print_header_type logger_header ( + ) const { return print_header; } + + void set_logger_header ( + print_header_type ph + ) + { + auto_mutex M(gd.m); + gd.loggers.reset(); + while (gd.loggers.move_next()) + { + if (gd.loggers.element()->is_child_of(*this)) + gd.loggers.element()->print_header = ph; + } + + gd.set_logger_header(logger_name, ph); + } + + private: + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + struct global_data + { + rmutex m; + set::kernel_1b loggers; + map::kernel_1b thread_names; + uint64 next_thread_name; + + // Make a very simple streambuf that writes characters into a std::vector. We can + // use this as the output target for hooks. The reason we don't just use a std::ostringstream + // instead is that this way we can be guaranteed that logging doesn't perform memory allocations. + // This is because a std::vector never frees memory. I.e. its capacity() doesn't go down when + // you resize it back to 0. It just stays the same. + class hook_streambuf : public std::streambuf + { + public: + std::vector buffer; + int_type overflow ( int_type c) + { + if (c != EOF) buffer.push_back(static_cast(c)); + return c; + } + + std::streamsize xsputn ( const char* s, std::streamsize num) + { + buffer.insert(buffer.end(), s, s+num); + return num; + } + }; + + hook_streambuf hookbuf; + + global_data ( + ); + + ~global_data( + ); + + uint64 get_thread_name ( + ); + /*! + requires + - m is locked + ensures + - returns a unique id for the calling thread. also makes the number + small and nice unlike what you get from get_thread_id() + !*/ + + void thread_end_handler ( + ); + /*! + ensures + - removes the terminated thread from thread_names + !*/ + + struct level_container + { + level_container (); + + log_level val; + map >::kernel_1b_c table; + } level_table; + + const log_level level ( + const std::string& name + ) const; + /*! + ensures + - returns the level loggers with the given name are supposed + to have + !*/ + + void set_level ( + const std::string& name, + const log_level& new_level + ); + /*! + ensures + - for all children C of name: + - #level(C) == new_level + - if name == "" then + - for all loggers L: + - #level(L) == new_level + !*/ + + struct auto_flush_container + { + bool val; + map >::kernel_1b_c table; + } auto_flush_table; + + bool auto_flush ( + const std::string& name + ) const; + /*! + ensures + - returns the auto_flush value loggers with the given name are supposed + to have + !*/ + + void set_auto_flush ( + const std::string& name, + bool enabled + ); + /*! + ensures + - for all children C of name: + - #auto_flush_enabled(C) == enabled + - if name == "" then + - for all loggers L: + - #auto_flush_enabled(L) == enabled + !*/ + + struct output_streambuf_container + { + std::streambuf* val; + map >::kernel_1b_c table; + } streambuf_table; + + std::streambuf* output_streambuf ( + const std::string& name + ); + /*! + ensures + - returns the streambuf loggers with the given name are supposed + to have + !*/ + + void set_output_stream ( + const std::string& name, + std::ostream& out_ + ); + /*! + ensures + - for all children C of name: + - #output_streambuf(C) == out_.rdbuf() + - if name == "" then + - for all loggers L: + - #output_streambuf(L) == out_.rdbuf() + !*/ + + void set_output_stream ( + const std::string& name, + std::streambuf& buf + ); + /*! + ensures + - for all children C of name: + - #output_streambuf(C) == &buf + - if name == "" then + - for all loggers L: + - #output_streambuf(L) == &buf + !*/ + + struct output_hook_container + { + hook_mfp val; + map >::kernel_1b_c table; + } hook_table; + + hook_mfp output_hook ( + const std::string& name + ); + /*! + ensures + - returns the hook loggers with the given name are supposed + to have + !*/ + + void set_output_hook ( + const std::string& name, + const hook_mfp& hook + ); + /*! + ensures + - for all children C of name: + - #output_hook(C) == hook + - if name == "" then + - for all loggers L: + - #output_hook(L) == hook + !*/ + + struct logger_header_container + { + print_header_type val; + map >::kernel_1b_c table; + } header_table; + + print_header_type logger_header ( + const std::string& name + ); + /*! + ensures + - returns the header function loggers with the given name are supposed + to have + !*/ + + void set_logger_header ( + const std::string& name, + print_header_type ph + ); + /*! + ensures + - for all children C of name: + - #logger_header(C) == ph + - if name == "" then + - for all loggers L: + - #logger_header(L) == ph + !*/ + + }; // end of struct global_data + + static global_data& get_global_data(); + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + friend void set_all_logging_levels ( + const log_level& new_level + ); + + friend void set_all_logging_headers ( + const print_header_type& new_header + ); + + friend void set_all_logging_output_streams ( + std::ostream& out + ); + + template < + typename T + > + friend void set_all_logging_output_hooks ( + T& object, + void (T::*hook_)(const std::string& logger_name, + const log_level& l, + const uint64 thread_id, + const char* message_to_log) + ) + { + logger::hook_mfp hook; + + // There is a bug in one of the versions (but not all apparently) of + // Visual studio 2005 that causes it to error out if isn't in the + // following line of code. However, there is also a bug in gcc-3.3 + // that causes it to error out if is present. So this works around + // this problem. +#if defined(_MSC_VER) && _MSC_VER == 1400 + hook.set(object, hook_); +#else + hook.set(object, hook_); +#endif + + logger::global_data& gd = logger::get_global_data(); + auto_mutex M(gd.m); + gd.loggers.reset(); + while (gd.loggers.move_next()) + { + gd.loggers.element()->out.rdbuf(&gd.hookbuf); + gd.loggers.element()->hook = hook; + } + + gd.set_output_stream("",gd.hookbuf); + gd.set_output_hook("",hook); + } + + // ------------------------------------------------------------------------------------ + + global_data& gd; + + const std::string logger_name; + + print_header_type print_header; + bool auto_flush_enabled; + std::ostream out; + log_level cur_level; + + hook_mfp hook; + + + // restricted functions + logger(const logger&); // copy constructor + logger& operator=(const logger&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + + + + +} + +#ifdef NO_MAKEFILE +#include "logger_kernel_1.cpp" +#endif + +#endif // DLIB_LOGGER_KERNEl_1_ + diff --git a/ml/dlib/dlib/logger/logger_kernel_abstract.h b/ml/dlib/dlib/logger/logger_kernel_abstract.h new file mode 100644 index 000000000..b6a4367a2 --- /dev/null +++ b/ml/dlib/dlib/logger/logger_kernel_abstract.h @@ -0,0 +1,429 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LOGGER_KERNEl_ABSTRACT_ +#ifdef DLIB_LOGGER_KERNEl_ABSTRACT_ + +#include "../threads.h" +#include +#include +#include +#include "../uintn.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class log_level + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple named level to log at. It contains a numeric + priority and a name to use in the logging messages. + !*/ + public: + log_level( + int priority_, + const char* name_ + ); + /*! + ensures + - #priority = priority_ + - the first 19 characters of name_ are copied into name and name + is null terminated. + !*/ + + bool operator< (const log_level& rhs) const { return priority < rhs.priority; } + bool operator<=(const log_level& rhs) const { return priority <= rhs.priority; } + bool operator> (const log_level& rhs) const { return priority > rhs.priority; } + bool operator>=(const log_level& rhs) const { return priority >= rhs.priority; } + + int priority; + char name[20]; + }; + + inline std::ostream& operator<< (std::ostream& out, const log_level& item); + /*! + ensures + - performs out << item.name + - returns out + !*/ + +// ---------------------------------------------------------------------------------------- + + const log_level LALL (std::numeric_limits::min(),"ALL"); + const log_level LNONE (std::numeric_limits::max(),"NONE"); + const log_level LTRACE(-100,"TRACE"); + const log_level LDEBUG(0 ,"DEBUG"); + const log_level LINFO (100 ,"INFO "); + const log_level LWARN (200 ,"WARN "); + const log_level LERROR(300 ,"ERROR"); + const log_level LFATAL(400 ,"FATAL"); + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void set_all_logging_output_streams ( + std::ostream& out + ); + /*! + ensures + - for all loggers L (even loggers not yet constructed): + - #L.output_streambuf() == out.rdbuf() + - Removes any previous output hook from L. So now the logger + L will write all its messages to the given output stream. + throws + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + + typedef void (*print_header_type)( + std::ostream& out, + const std::string& logger_name, + const log_level& l, + const uint64 thread_id + ); + + void set_all_logging_headers ( + const print_header_type& new_header + ); + /*! + ensures + - for all loggers L (even loggers not yet constructed): + - #L.logger_header() == new_header + throws + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void set_all_logging_output_hooks ( + T& object, + void (T::*hook)(const std::string& logger_name, + const log_level& l, + const uint64 thread_id, + const char* message_to_log) + ); + /*! + ensures + - for all loggers L (even loggers not yet constructed): + - #L.output_streambuf() == 0 + - performs the equivalent to calling L.set_output_hook(object, hook); + (i.e. sets all loggers so that they will use the given hook function) + throws + - std::bad_alloc + !*/ + + template < + typename T + > + void set_all_logging_output_hooks ( + T& object + ); + /*! + ensures + - calls set_all_logging_output_hooks(object, &T::log); + !*/ + +// ---------------------------------------------------------------------------------------- + + void set_all_logging_levels ( + const log_level& new_level + ); + /*! + ensures + - for all loggers L (even loggers not yet constructed): + - #L.level() == new_level + throws + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void print_default_logger_header ( + std::ostream& out, + const std::string& logger_name, + const log_level& l, + const uint64 thread_id + ); + /*! + requires + - is not called more than once at a time (i.e. is not called from multiple + threads at the same time). + ensures + - let MS be the number of milliseconds since program start. + - prints a string to out in the form: "MS l.name [thread_id] logger_name:" + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class logger + { + /*! + INITIAL VALUE + - name() == a user supplied value given to the constructor + - The values of level(), output_streambuf(), logger_header(), and + auto_flush() are inherited from the parent of this logger. + + WHAT THIS OBJECT REPRESENTS + This object represents a logging output stream in the style of the log4j + logger available for Java. + + Additionally, the logger doesn't perform any memory allocations during + each logging action. It just writes directly into the user supplied output + stream. Alternatively, if you use a logging output hook no memory allocations + are performed either. Logging just goes straight into a memory buffer + which gets passed to the user supplied logging hook. + + DEFAULTS + If the user hasn't specified values for the four inherited values level(), + output_streambuf(), logger_header(), or auto_flush() then the default + values will be used. The defaults are as follows: + - level() == LERROR + - output_streambuf() == std::cout.rdbuf() (i.e. the default is to log + to standard output). + - logger_header() == print_default_logger_header + - auto_flush() == true + + THREAD SAFETY + All methods of this class are thread safe. Note that it is safe to + chain calls to operator << such as: + log << LINFO << "message " << variable << " more message"; + The logger ensures that the entire statement executes atomically so the + message won't be broken up by other loggers in other threads. + !*/ + + class logger_stream + { + public: + + bool is_enabled ( + ) const; + /*! + ensures + - returns true if this logger stream will print out items + given to it by the << operator. returns false otherwise. + !*/ + + template + logger_stream& operator << ( + const T& item + ); + /*! + ensures + - if (is_enabled()) then + - writes item to this output stream + - returns *this + !*/ + }; + + public: + + logger ( + const std::string& name_ + ); + /*! + requires + - name_ != "" + ensures + - #*this is properly initialized + - #name() == name_ + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~logger ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + const std::string& name ( + ) const; + /*! + ensures + - returns the name of this logger + !*/ + + logger_stream operator << ( + const log_level& l + ) const; + /*! + ensures + - if (l.priority >= level().priority) then + - returns a logger_stream with is_enabled() == true. I.e. this + returned stream will write its output to the I/O destination + used by this logger object. + - else + - returns a logger stream with is_enabled() == false + throws + - std::bad_alloc + !*/ + + bool is_child_of ( + const logger& log + ) const; + /*! + ensures + - if ( (name().find(log.name() + ".") == 0) || (log.name() == name()) ) then + - returns true + (i.e. if log.name() + "." is a prefix of name() or if both *this and log + have the same name then return true) + - else + - returns false + !*/ + + const log_level level ( + ) const; + /*! + ensures + - returns the current log level of this logger. + !*/ + + void set_level ( + const log_level& new_level + ); + /*! + ensures + - for all loggers L such that L.is_child_of(*this) == true: + - #L.level() == new_level + throws + - std::bad_alloc + !*/ + + bool auto_flush ( + ); + /*! + ensures + - returns true if the output stream is flushed after every logged message. + returns false otherwise. (Note that flushing only does anything if + the logger is set to use an output stream rather than a hook) + !*/ + + void set_auto_flush ( + bool enabled + ); + /*! + ensures + - for all loggers L such that L.is_child_of(*this) == true: + - #L.auto_flush() == enabled + throws + - std::bad_alloc + !*/ + + + template < + typename T + > + void set_output_hook ( + T& object, + void (T::*hook)(const std::string& logger_name, + const log_level& l, + const uint64 thread_id, + const char* message_to_log) + ); + /*! + requires + - hook is a valid pointer to a member function in T + ensures + - for all loggers L such that L.is_child_of(*this) == true: + - #L.output_streambuf() == 0 + - #L will not send its log messages to an ostream object anymore. Instead + it will call the given hook member function (i.e. (object.*hook)(name,l,id,msg) ) + for each message that needs to be logged. + - The arguments to the hook function have the following meanings: + - logger_name == The name of the logger that is printing the log message. + - l == The level of the logger that is printing the log message. + - thread_id == A number that uniquely identifies the thread trying to log + the message. Note that this number is unique among all threads, past and + present. Also note that this id is not the same one returned by + get_thread_id(). + - message_to_log == the actual text of the message the user is giving to + the logger object to log. + - All hook functions will also only be called one at a time. This means + that hook functions don't need to be thread safe. + !*/ + + std::streambuf* output_streambuf ( + ); + /*! + ensures + - if (an output hook isn't set) then + - returns the output stream buffer that this logger writes all + messages to. + - else + - returns 0 + !*/ + + void set_output_stream ( + std::ostream& out + ); + /*! + ensures + - for all loggers L such that L.is_child_of(*this) == true: + - #L.output_streambuf() == out.rdbuf() + - Removes any previous output hook from L. So now the logger + L will write all its messages to the given output stream. + throws + - std::bad_alloc + !*/ + + print_header_type logger_header ( + ) const; + /*! + ensures + - returns the function that is called to print the header information + onto each logged message. The arguments to the function have the following + meanings: + - out == The output stream this function writes the header to. + - logger_name == The name of the logger that is printing the log message. + - l == The level of the logger that is printing the log message. + - thread_id == A number that uniquely identifies the thread trying to log + the message. Note that this number is unique among all threads, past and + present. Also note that this id is not the same one returned by + get_thread_id(). + - This logger_header function will also only be called once at a time. This means + the logger_header function doesn't need to be thread safe. + - the logger_header function is only used when output_streambuf() != 0 + !*/ + + void set_logger_header ( + print_header_type print_header + ); + /*! + ensures + - for all loggers L such that L.is_child_of(*this) == true: + - #L.logger_header() == print_header + throws + - std::bad_alloc + !*/ + + private: + + // restricted functions + logger(const logger&); // copy constructor + logger& operator=(const logger&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LOGGER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/lsh.h b/ml/dlib/dlib/lsh.h new file mode 100644 index 000000000..28f4b9bc4 --- /dev/null +++ b/ml/dlib/dlib/lsh.h @@ -0,0 +1,14 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LSh_ +#define DLIB_LSh_ + + +#include "lsh/projection_hash.h" +#include "lsh/create_random_projection_hash.h" +#include "lsh/hashes.h" + + +#endif // DLIB_LSh_ + + diff --git a/ml/dlib/dlib/lsh/create_random_projection_hash.h b/ml/dlib/dlib/lsh/create_random_projection_hash.h new file mode 100644 index 000000000..b3aecd9ec --- /dev/null +++ b/ml/dlib/dlib/lsh/create_random_projection_hash.h @@ -0,0 +1,232 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CREATE_RANDOM_PROJECTION_HAsH_Hh_ +#define DLIB_CREATE_RANDOM_PROJECTION_HAsH_Hh_ + +#include "create_random_projection_hash_abstract.h" +#include "projection_hash.h" +#include "../matrix.h" +#include "../rand.h" +#include "../statistics.h" +#include "../svm.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + projection_hash create_random_projection_hash ( + const vector_type& v, + const int bits, + dlib::rand& rnd + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < bits && bits <= 32 && + v.size() > 1, + "\t projection_hash create_random_projection_hash()" + << "\n\t Invalid arguments were given to this function." + << "\n\t bits: " << bits + << "\n\t v.size(): " << v.size() + ); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < v.size(); ++i) + { + DLIB_ASSERT(v[0].size() == v[i].size() && v[i].size() > 0 && is_col_vector(v[i]), + "\t projection_hash create_random_projection_hash()" + << "\n\t Invalid arguments were given to this function." + << "\n\t m(0).size(): " << v[0].size() + << "\n\t m("< > rc; + for (unsigned long i = 0; i < v.size(); ++i) + rc.add(matrix_cast(v[i])); + + // compute a whitening matrix + matrix whiten = trans(chol(pinv(rc.covariance()))); + + + // hashes + std::vector h(v.size(),0); + + std::vector vals(v.size(),0); + + // number of hits for each hash value + std::vector counts; + + std::vector temp; + + // build a random projection matrix + matrix proj(bits, v[0].size()); + for (long r = 0; r < proj.nr(); ++r) + for (long c = 0; c < proj.nc(); ++c) + proj(r,c) = rnd.get_random_gaussian(); + + // merge whitening matrix with projection matrix + proj = proj*whiten; + + matrix offset(bits); + + + // figure out what the offset values should be + for (int itr = 0; itr < offset.size(); ++itr) + { + counts.assign(static_cast(std::pow(2.0,bits)), 0); + // count the popularity of each hash value + for (unsigned long i = 0; i < h.size(); ++i) + { + h[i] <<= 1; + counts[h[i]] += 1; + } + + const unsigned long max_h = index_of_max(mat(counts)); + + temp.clear(); + for (unsigned long i = 0; i < v.size(); ++i) + { + vals[i] = dot(rowm(proj,itr), matrix_cast(v[i])); + if (h[i] == max_h) + temp.push_back(vals[i]); + } + + // split down the middle + std::sort(temp.begin(), temp.end()); + const double split = temp[temp.size()/2]; + offset(itr) = -split; + + for (unsigned long i = 0; i < vals.size(); ++i) + { + if (vals[i] - split > 0) + h[i] |= 1; + } + } + + + return projection_hash(proj, offset); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + projection_hash create_random_projection_hash ( + const vector_type& v, + const int bits + ) + { + dlib::rand rnd; + return create_random_projection_hash(v,bits,rnd); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + projection_hash create_max_margin_projection_hash ( + const vector_type& v, + const int bits, + const double C, + dlib::rand& rnd + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < bits && bits <= 32 && + v.size() > 1, + "\t projection_hash create_max_margin_projection_hash()" + << "\n\t Invalid arguments were given to this function." + << "\n\t bits: " << bits + << "\n\t v.size(): " << v.size() + ); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < v.size(); ++i) + { + DLIB_ASSERT(v[0].size() == v[i].size() && v[i].size() > 0 && is_col_vector(v[i]), + "\t projection_hash create_max_margin_projection_hash()" + << "\n\t Invalid arguments were given to this function." + << "\n\t m(0).size(): " << v[0].size() + << "\n\t m("< > rc; + for (unsigned long i = 0; i < v.size(); ++i) + rc.add(matrix_cast(v[i])); + + // compute a whitening matrix + matrix whiten = trans(chol(pinv(rc.covariance()))); + const matrix meanval = whiten*rc.mean(); + + + + typedef matrix sample_type; + random_subset_selector training_samples; + random_subset_selector training_labels; + // We set this up to use enough samples to cover the vector space used by elements + // of v. + training_samples.set_max_size(v[0].size()*10); + training_labels.set_max_size(v[0].size()*10); + + matrix proj(bits, v[0].size()); + matrix offset(bits); + + // learn the random planes and put them into proj and offset. + for (int itr = 0; itr < offset.size(); ++itr) + { + training_samples.make_empty(); + training_labels.make_empty(); + // pick random training data and give each sample a random label. + for (unsigned long i = 0; i < v.size(); ++i) + { + training_samples.add(whiten*v[i]-meanval); + if (rnd.get_random_double() > 0.5) + training_labels.add(+1); + else + training_labels.add(-1); + } + + svm_c_linear_dcd_trainer > trainer; + trainer.set_c(C); + decision_function > df = trainer.train(training_samples, training_labels); + offset(itr) = -df.b; + set_rowm(proj,itr) = trans(df.basis_vectors(0)); + } + + + return projection_hash(proj*whiten, offset-proj*meanval); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + projection_hash create_max_margin_projection_hash ( + const vector_type& v, + const int bits, + const double C = 10 + ) + { + dlib::rand rnd; + return create_max_margin_projection_hash(v,bits,C,rnd); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CREATE_RANDOM_PROJECTION_HAsH_Hh_ + diff --git a/ml/dlib/dlib/lsh/create_random_projection_hash_abstract.h b/ml/dlib/dlib/lsh/create_random_projection_hash_abstract.h new file mode 100644 index 000000000..cff55b9a5 --- /dev/null +++ b/ml/dlib/dlib/lsh/create_random_projection_hash_abstract.h @@ -0,0 +1,148 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CREATE_RANDOM_PROJECTION_HAsH_ABSTRACT_Hh_ +#ifdef DLIB_CREATE_RANDOM_PROJECTION_HAsH_ABSTRACT_Hh_ + +#include "projection_hash_abstract.h" +#include "../rand.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + projection_hash create_random_projection_hash ( + const vector_type& v, + const int bits, + dlib::rand& rnd + ); + /*! + requires + - 0 < bits <= 32 + - v.size() > 1 + - vector_type == a std::vector or compatible type containing dlib::matrix + objects, each representing a column vector of the same size. + - for all valid i, j: + - is_col_vector(v[i]) == true + - v[i].size() > 0 + - v[i].size() == v[j].size() + - i.e. v contains only column vectors and all the column vectors + have the same non-zero length + - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface + ensures + - returns a hash function H such that: + - H.num_hash_bins() == pow(2,bits) + - H will be setup so that it hashes the contents of v such that each bin + ends up with roughly the same number of elements in it. This is + accomplished by picking random hyperplanes passing though the data. In + particular, each plane normal vector is filled with Gaussian random + numbers and we also perform basic centering to ensure the plane passes + though the data. + - This function uses the supplied random number generator, rnd, to drive part + of it's processing. Therefore, giving different random number generators + will produce different outputs. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + projection_hash create_random_projection_hash ( + const vector_type& v, + const int bits + ); + /*! + requires + - 0 < bits <= 32 + - v.size() > 1 + - vector_type == a std::vector or compatible type containing dlib::matrix + objects, each representing a column vector of the same size. + - for all valid i, j: + - is_col_vector(v[i]) == true + - v[i].size() > 0 + - v[i].size() == v[j].size() + - i.e. v contains only column vectors and all the column vectors + have the same non-zero length + ensures + - returns create_random_projection_hash(v,bits,dlib::rand()) + (i.e. calls the above function with a default initialized random number generator) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + projection_hash create_max_margin_projection_hash ( + const vector_type& v, + const int bits, + const double C, + dlib::rand& rnd + ); + /*! + requires + - 0 < bits <= 32 + - v.size() > 1 + - vector_type == a std::vector or compatible type containing dlib::matrix + objects, each representing a column vector of the same size. + - for all valid i, j: + - is_col_vector(v[i]) == true + - v[i].size() > 0 + - v[i].size() == v[j].size() + - i.e. v contains only column vectors and all the column vectors + have the same non-zero length + - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface + ensures + - returns a hash function H such that: + - H.num_hash_bins() == pow(2,bits) + - H will be setup so that it hashes the contents of v such that + each bin ends up with roughly the same number of elements + in it. This is accomplished using a variation on the random hyperplane + generation technique from the paper: + Random Maximum Margin Hashing by Alexis Joly and Olivier Buisson + In particular, we use the svm_c_linear_dcd_trainer to generate planes. + We train it on randomly selected and randomly labeled points from v. + The C SVM parameter is set to the given C argument. + - This function uses the supplied random number generator, rnd, to drive part + of it's processing. Therefore, giving different random number generators + will produce different outputs. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + projection_hash create_max_margin_projection_hash ( + const vector_type& v, + const int bits, + const double C = 10 + ); + /*! + requires + - 0 < bits <= 32 + - v.size() > 1 + - vector_type == a std::vector or compatible type containing dlib::matrix + objects, each representing a column vector of the same size. + - for all valid i, j: + - is_col_vector(v[i]) == true + - v[i].size() > 0 + - v[i].size() == v[j].size() + - i.e. v contains only column vectors and all the column vectors + have the same non-zero length + ensures + - returns create_max_margin_projection_hash(v,bits,C,dlib::rand()) + (i.e. calls the above function with a default initialized random number generator) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CREATE_RANDOM_PROJECTION_HAsH_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/lsh/hashes.h b/ml/dlib/dlib/lsh/hashes.h new file mode 100644 index 000000000..35053ce4e --- /dev/null +++ b/ml/dlib/dlib/lsh/hashes.h @@ -0,0 +1,219 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LSH_HAShES_Hh_ +#define DLIB_LSH_HAShES_Hh_ + +#include "hashes_abstract.h" +#include "../hash.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class hash_similar_angles_64 + { + public: + hash_similar_angles_64 ( + ) : seed(0) {} + + hash_similar_angles_64 ( + const uint64 seed_ + ) : seed(seed_) {} + + uint64 get_seed ( + ) const { return seed; } + + + typedef uint64 result_type; + + template < + typename sparse_vector_type + > + typename disable_if,uint64>::type operator() ( + const sparse_vector_type& v + ) const + { + typedef typename sparse_vector_type::value_type::second_type scalar_type; + + uint64 temp = 0; + for (int i = 0; i < 64; ++i) + { + // compute the dot product between v and a Gaussian random vector. + scalar_type val = 0; + for (typename sparse_vector_type::const_iterator j = v.begin(); j != v.end(); ++j) + val += j->second*gaussian_random_hash(j->first, i, seed); + + if (val > 0) + temp |= 1; + temp <<= 1; + } + return temp; + } + + template + uint64 operator() ( + const matrix_exp& v + ) const + { + typedef typename EXP::type T; + uint64 temp = 0; + for (unsigned long i = 0; i < 64; ++i) + { + if (dot(matrix_cast(gaussian_randm(v.size(),1,i+seed*64)), v) > 0) + temp |= 1; + temp <<= 1; + } + return temp; + } + + unsigned int distance ( + const result_type& a, + const result_type& b + ) const + { + return hamming_distance(a,b); + } + + private: + const uint64 seed; + }; + +// ---------------------------------------------------------------------------------------- + + class hash_similar_angles_128 + { + public: + hash_similar_angles_128 ( + ) : seed(0),hasher1(0), hasher2(1) {} + + hash_similar_angles_128 ( + const uint64 seed_ + ) : seed(seed_),hasher1(2*seed),hasher2(2*seed+1) {} + + uint64 get_seed ( + ) const { return seed; } + + typedef std::pair result_type; + + template < + typename vector_type + > + result_type operator() ( + const vector_type& v + ) const + { + return std::make_pair(hasher1(v), hasher2(v)); + } + + unsigned int distance ( + const result_type& a, + const result_type& b + ) const + { + return hamming_distance(a.first,b.first) + + hamming_distance(a.second,b.second); + } + + private: + const uint64 seed; + hash_similar_angles_64 hasher1; + hash_similar_angles_64 hasher2; + + }; + +// ---------------------------------------------------------------------------------------- + + class hash_similar_angles_256 + { + public: + hash_similar_angles_256 ( + ) : seed(0), hasher1(0), hasher2(1) {} + + hash_similar_angles_256 ( + const uint64 seed_ + ) : seed(seed_),hasher1(2*seed),hasher2(2*seed+1) {} + + uint64 get_seed ( + ) const { return seed; } + + typedef std::pair hash128_type; + typedef std::pair result_type; + + template < + typename vector_type + > + result_type operator() ( + const vector_type& v + ) const + { + return std::make_pair(hasher1(v), hasher2(v)); + } + + unsigned int distance ( + const result_type& a, + const result_type& b + ) const + { + return hasher1.distance(a.first,b.first) + + hasher1.distance(a.second,b.second); + } + + private: + const uint64 seed; + hash_similar_angles_128 hasher1; + hash_similar_angles_128 hasher2; + + }; + +// ---------------------------------------------------------------------------------------- + + class hash_similar_angles_512 + { + public: + hash_similar_angles_512 ( + ) : seed(0), hasher1(0), hasher2(1) {} + + hash_similar_angles_512 ( + const uint64 seed_ + ) : seed(seed_),hasher1(2*seed),hasher2(2*seed+1) {} + + uint64 get_seed ( + ) const { return seed; } + + + typedef hash_similar_angles_256::result_type hash256_type; + typedef std::pair result_type; + + template < + typename vector_type + > + result_type operator() ( + const vector_type& v + ) const + { + return std::make_pair(hasher1(v), hasher2(v)); + } + + unsigned int distance ( + const result_type& a, + const result_type& b + ) const + { + return hasher1.distance(a.first,b.first) + + hasher1.distance(a.second,b.second); + } + + private: + const uint64 seed; + hash_similar_angles_256 hasher1; + hash_similar_angles_256 hasher2; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LSH_HAShES_Hh_ + diff --git a/ml/dlib/dlib/lsh/hashes_abstract.h b/ml/dlib/dlib/lsh/hashes_abstract.h new file mode 100644 index 000000000..27f8ddb69 --- /dev/null +++ b/ml/dlib/dlib/lsh/hashes_abstract.h @@ -0,0 +1,286 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LSH_HAShES_ABSTRACT_Hh_ +#ifdef DLIB_LSH_HAShES_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class hash_similar_angles_64 + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for computing locality sensitive hashes that give + vectors with small angles between each other similar hash values. In + particular, this object creates 64 random planes which pass though the + origin and uses them to create a 64bit hash. To compute the hash for a new + vector, this object checks which side of each plane the vector falls on and + records this information into a 64bit integer. + !*/ + + public: + + hash_similar_angles_64 ( + ); + /*! + ensures + - #get_seed() == 0 + !*/ + + hash_similar_angles_64 ( + const uint64 seed + ); + /*! + ensures + - #get_seed() == seed + !*/ + + uint64 get_seed ( + ) const; + /*! + ensures + - returns the random seed used to generate the random planes used for + hashing. + !*/ + + typedef uint64 result_type; + + template + result_type perator() ( + const vector_type& v + ) const; + /*! + requires + - v is an unsorted sparse vector or a dlib matrix representing either a + column or row vector. + ensures + - returns a 64 bit hash of the input vector v. The bits in the hash record + which side of each random plane v falls on. + + !*/ + + unsigned int distance ( + const result_type& a, + const result_type& b + ) const; + /*! + ensures + - returns the Hamming distance between the two hashes given to this + function. That is, we return the number of bits in a and b which differ. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + struct hash_similar_angles_128 + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for computing locality sensitive hashes that give + vectors with small angles between each other similar hash values. In + particular, this object creates 128 random planes which pass though the + origin and uses them to create a 128bit hash. To compute the hash for a new + vector, this object checks which side of each plane the vector falls on and + records this information into a 128bit integer. + !*/ + + public: + + hash_similar_angles_128 ( + ); + /*! + ensures + - #get_seed() == 0 + !*/ + + hash_similar_angles_128 ( + const uint64 seed + ); + /*! + ensures + - #get_seed() == seed + !*/ + + uint64 get_seed ( + ) const; + /*! + ensures + - returns the random seed used to generate the random planes used for + hashing. + !*/ + + typedef std::pair result_type; + + template + result_type perator() ( + const vector_type& v + ) const; + /*! + requires + - v is an unsorted sparse vector or a dlib matrix representing either a + column or row vector. + ensures + - returns a 128 bit hash of the input vector v. The bits in the hash record + which side of each random plane v falls on. + + !*/ + + unsigned int distance ( + const result_type& a, + const result_type& b + ) const; + /*! + ensures + - returns the Hamming distance between the two hashes given to this + function. That is, we return the number of bits in a and b which differ. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + struct hash_similar_angles_256 + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for computing locality sensitive hashes that give + vectors with small angles between each other similar hash values. In + particular, this object creates 256 random planes which pass though the + origin and uses them to create a 256bit hash. To compute the hash for a new + vector, this object checks which side of each plane the vector falls on and + records this information into a 256bit integer. + !*/ + + public: + + hash_similar_angles_256 ( + ); + /*! + ensures + - #get_seed() == 0 + !*/ + + hash_similar_angles_256 ( + const uint64 seed + ); + /*! + ensures + - #get_seed() == seed + !*/ + + uint64 get_seed ( + ) const; + /*! + ensures + - returns the random seed used to generate the random planes used for + hashing. + !*/ + + typedef std::pair hash128_type; + typedef std::pair result_type; + + template + result_type perator() ( + const vector_type& v + ) const; + /*! + requires + - v is an unsorted sparse vector or a dlib matrix representing either a + column or row vector. + ensures + - returns a 256 bit hash of the input vector v. The bits in the hash record + which side of each random plane v falls on. + + !*/ + + unsigned int distance ( + const result_type& a, + const result_type& b + ) const; + /*! + ensures + - returns the Hamming distance between the two hashes given to this + function. That is, we return the number of bits in a and b which differ. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + struct hash_similar_angles_512 + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for computing locality sensitive hashes that give + vectors with small angles between each other similar hash values. In + particular, this object creates 512 random planes which pass though the + origin and uses them to create a 512bit hash. To compute the hash for a new + vector, this object checks which side of each plane the vector falls on and + records this information into a 512bit integer. + !*/ + + public: + + hash_similar_angles_512 ( + ); + /*! + ensures + - #get_seed() == 0 + !*/ + + hash_similar_angles_512 ( + const uint64 seed + ); + /*! + ensures + - #get_seed() == seed + !*/ + + uint64 get_seed ( + ) const; + /*! + ensures + - returns the random seed used to generate the random planes used for + hashing. + !*/ + + typedef hash_similar_angles_256::result_type hash256_type; + typedef std::pair result_type; + + template + result_type perator() ( + const vector_type& v + ) const; + /*! + requires + - v is an unsorted sparse vector or a dlib matrix representing either a + column or row vector. + ensures + - returns a 512 bit hash of the input vector v. The bits in the hash record + which side of each random plane v falls on. + + !*/ + + unsigned int distance ( + const result_type& a, + const result_type& b + ) const; + /*! + ensures + - returns the Hamming distance between the two hashes given to this + function. That is, we return the number of bits in a and b which differ. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LSH_HAShES_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/lsh/projection_hash.h b/ml/dlib/dlib/lsh/projection_hash.h new file mode 100644 index 000000000..16de0ba11 --- /dev/null +++ b/ml/dlib/dlib/lsh/projection_hash.h @@ -0,0 +1,118 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PROJECTION_HASh_Hh_ +#define DLIB_PROJECTION_HASh_Hh_ + +#include "projection_hash_abstract.h" +#include "../matrix.h" +#include "../rand.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class projection_hash + { + public: + + projection_hash() {} + + template + projection_hash( + const matrix_exp& proj_, + const matrix_exp& offset_ + ) : proj(proj_), offset(offset_) + { + // make sure requires clause is not broken + DLIB_ASSERT(proj.nr() == offset.nr(), + "\t projection_hash::projection_hash()" + << "\n\t Invalid arguments were given to this function." + << "\n\t proj.nr(): " << proj.nr() + << "\n\t offset.nr(): " << offset.nr() + ); + + } + + const matrix& get_projection_matrix ( + ) const { return proj; } + + const matrix& get_offset_matrix ( + ) const { return offset; } + + unsigned long num_hash_bins ( + ) const + { + return static_cast(std::pow(2.0, (double)offset.size())); + } + + template + unsigned long operator() ( + const matrix_exp& v + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(v) && + v.size() == get_projection_matrix().nc() && + v.size() > 0, + "\t unsigned long projection_hash::operator()(v)" + << "\n\t Invalid arguments were given to this function." + << "\n\t is_col_vector(v): " << is_col_vector(v) + << "\n\t get_projection_matrix().nc(): " << get_projection_matrix().nc() + << "\n\t v.size(): " << v.size() + ); + + return do_hash(proj*matrix_cast(v) + offset); + } + + private: + + template + unsigned long do_hash ( + const matrix_exp& v + ) const + { + unsigned long h = 0; + for (long i = 0; i < v.size(); ++i) + { + h <<= 1; + if (v(i) > 0) + h |= 1; + } + return h; + } + + matrix proj; + matrix offset; + }; + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const projection_hash& item, + std::ostream& out + ) + { + serialize(item.get_projection_matrix(), out); + serialize(item.get_offset_matrix(), out); + } + + inline void deserialize ( + projection_hash& item, + std::istream& in + ) + { + matrix proj; + matrix offset; + deserialize(proj, in); + deserialize(offset, in); + item = projection_hash(proj, offset); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_PROJECTION_HASh_Hh_ + diff --git a/ml/dlib/dlib/lsh/projection_hash_abstract.h b/ml/dlib/dlib/lsh/projection_hash_abstract.h new file mode 100644 index 000000000..abe78d10c --- /dev/null +++ b/ml/dlib/dlib/lsh/projection_hash_abstract.h @@ -0,0 +1,119 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_PROJECTION_HASh_ABSTRACT_Hh_ +#ifdef DLIB_PROJECTION_HASh_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class projection_hash + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool for hashing elements of a vector space into the integers. + It is intended to represent locality sensitive hashing functions such as + the popular random projection hashing method. + + In particular, it represents hash functions of the form: + hash bit 0 = sign(rowm(P*v + O,0)) + hash bit 1 = sign(rowm(P*v + O,1)) + hash bit 2 = sign(rowm(P*v + O,2)) + ... + Where v is the vector to be hashed. The parameters of the projection + hash are the P and O matrices. + + THREAD SAFETY + The const members of this object can be called concurrently from multiple + threads, however, any operation that modifies the state of an instance of + this object must serialize access to that instance. + !*/ + public: + + projection_hash( + ); + /*! + ensures + - #get_projection_matrix().size() == 0 + - #get_offset_matrix().size() == 0 + !*/ + + template + projection_hash( + const matrix_exp& proj, + const matrix_exp& offset + ); + /*! + requires + - proj.nr() == offset.nr() + ensures + - #get_projection_matrix() == proj + - #get_offset_matrix() == offset + !*/ + + const matrix& get_projection_matrix ( + ) const; + /*! + ensures + - returns the P matrix discussed above in the WHAT THIS OBJECT REPRESENTS + section. + !*/ + + const matrix& get_offset_matrix ( + ) const; + /*! + ensures + - returns the O matrix discussed above in the WHAT THIS OBJECT REPRESENTS + section. + !*/ + + unsigned long num_hash_bins ( + ) const; + /*! + ensures + - returns the number of possible outputs from this hashing function. + - Specifically, returns: std::pow(2, get_offset_matrix().size()) + !*/ + + template + unsigned long operator() ( + const matrix_exp& v + ) const; + /*! + requires + - is_col_vector(v) == true + - v.size() == get_projection_matrix().nc() + - v.size() > 0 + ensures + - hashes v into the range [0, num_hash_bins()) using the method + discussed in the WHAT THIS OBJECT REPRESENTS section. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const projection_hash& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + projection_hash& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_PROJECTION_HASh_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/lz77_buffer.h b/ml/dlib/dlib/lz77_buffer.h new file mode 100644 index 000000000..b7364ad9a --- /dev/null +++ b/ml/dlib/dlib/lz77_buffer.h @@ -0,0 +1,47 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LZ77_BUFFEr_ +#define DLIB_LZ77_BUFFEr_ + + +#include "lz77_buffer/lz77_buffer_kernel_1.h" +#include "lz77_buffer/lz77_buffer_kernel_2.h" +#include "lz77_buffer/lz77_buffer_kernel_c.h" + +#include "sliding_buffer.h" + + +namespace dlib +{ + + + class lz77_buffer + { + + lz77_buffer() {} + + typedef sliding_buffer::kernel_1a sb1; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef lz77_buffer_kernel_1 + kernel_1a; + typedef lz77_buffer_kernel_c + kernel_1a_c; + + + // kernel_2a + typedef lz77_buffer_kernel_2 + kernel_2a; + typedef lz77_buffer_kernel_c + kernel_2a_c; + + + }; +} + +#endif // DLIB_LZ77_BUFFEr_ + diff --git a/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_1.h b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_1.h new file mode 100644 index 000000000..b93f0628d --- /dev/null +++ b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_1.h @@ -0,0 +1,263 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LZ77_BUFFER_KERNEl_1_ +#define DLIB_LZ77_BUFFER_KERNEl_1_ + +#include "lz77_buffer_kernel_abstract.h" +#include "../algs.h" + + + +namespace dlib +{ + + template < + typename sliding_buffer + > + class lz77_buffer_kernel_1 + { + /*! + REQUIREMENTS ON sliding_buffer + sliding_buffer must be an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h + + INITIAL VALUE + history_limit == defined by constructor arguments + lookahead_limit == defined by constructor arguments + history_size == 0 + lookahead_size == 0 + buffer.size() == history_limit + lookahead_limit + + + CONVENTION + history_limit == get_history_buffer_limit() + lookahead_limit == get_lookahead_buffer_limit() + history_size == get_history_buffer_size() + lookahead_limit == get_lookahead_buffer_size() + + buffer.size() == history_limit + lookahead_limit + + lookahead_buffer(i) == buffer[lookahead_limit-1-i] + history_buffer(i) == buffer[lookahead_limit+i] + !*/ + + public: + + lz77_buffer_kernel_1 ( + unsigned long total_limit_, + unsigned long lookahead_limit_ + ); + + virtual ~lz77_buffer_kernel_1 ( + ) {} + + void clear( + ); + + void add ( + unsigned char symbol + ); + + void find_match ( + unsigned long& index, + unsigned long& length, + unsigned long min_match_length + ); + + inline unsigned long get_history_buffer_limit ( + ) const { return history_limit; } + + inline unsigned long get_lookahead_buffer_limit ( + ) const { return lookahead_limit; } + + inline unsigned long get_history_buffer_size ( + ) const { return history_size; } + + inline unsigned long get_lookahead_buffer_size ( + ) const { return lookahead_size; } + + inline unsigned char lookahead_buffer ( + unsigned long index + ) const { return buffer[lookahead_limit-1-index]; } + + inline unsigned char history_buffer ( + unsigned long index + ) const { return buffer[lookahead_limit+index]; } + + + inline void shift_buffers ( + unsigned long N + ) { shift_buffer(N); } + + private: + + + inline void shift_buffer ( + unsigned long N + ) + /*! + requires + - N <= lookahead_size + ensuers + - #lookahead_size == lookahead_size - N + - if (history_size+N < history_limit) then + - #history_size == history_size+N + - else + - #history_size == history_limit + - for all i where 0 <= i < N: + #history_buffer(N-1-i) == lookahead_buffer(i) + - for all i where 0 <= i < #history_size-N: + #history_buffer(N+i) == history_buffer(i) + - for all i where 0 <= i < #lookahead_size + #lookahead_buffer(i) == lookahead_buffer(N+i) + !*/ + { + unsigned long temp = history_size+N; + buffer.rotate_left(N); + lookahead_size -= N; + if (temp < history_limit) + history_size = temp; + else + history_size = history_limit; + } + + + // member data + sliding_buffer buffer; + unsigned long lookahead_limit; + unsigned long history_limit; + + + unsigned long lookahead_size; + unsigned long history_size; + + + // restricted functions + lz77_buffer_kernel_1(lz77_buffer_kernel_1&); // copy constructor + lz77_buffer_kernel_1& operator=(lz77_buffer_kernel_1&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + lz77_buffer_kernel_1:: + lz77_buffer_kernel_1 ( + unsigned long total_limit_, + unsigned long lookahead_limit_ + ) : + lookahead_size(0), + history_size(0) + { + buffer.set_size(total_limit_); + lookahead_limit = lookahead_limit_; + history_limit = buffer.size() - lookahead_limit_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + void lz77_buffer_kernel_1:: + clear( + ) + { + lookahead_size = 0; + history_size = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + void lz77_buffer_kernel_1:: + add ( + unsigned char symbol + ) + { + if (lookahead_size == lookahead_limit) + { + shift_buffer(1); + } + buffer[lookahead_limit-1-lookahead_size] = symbol; + ++lookahead_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + void lz77_buffer_kernel_1:: + find_match ( + unsigned long& index, + unsigned long& length, + unsigned long min_match_length + ) + { + unsigned long hpos = history_size; // current position in the history buffer + unsigned long lpos = 0; // current position in the lookahead buffer + + unsigned long match_length = 0; // the length of the longest match we find + unsigned long match_index = 0; // the index of the longest match we find + + // try to find a match + while (hpos != 0) + { + --hpos; + // if we are finding a match + if (history_buffer(hpos) == lookahead_buffer(lpos)) + { + ++lpos; + // if we have found a match that is as long as the lookahead buffer + // then we are done + if (lpos == lookahead_size) + break; + } + // else if we found the end of a match + else if (lpos > 0) + { + // if this match is longer than the last match we saw + if (lpos > match_length) + { + match_length = lpos; + match_index = hpos + lpos; + } + lpos = 0; + } + } // while (hpos != 0) + + // if we found a match at the end of the loop that is greater than + // the match in match_index + if (lpos > match_length) + { + match_length = lpos; + match_index = hpos + lpos - 1; + } + + + // if we found a match that was long enough then report it + if (match_length >= min_match_length) + { + shift_buffer(match_length); + index = match_index; + length = match_length; + } + else + { + length = 0; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LZ77_BUFFER_KERNEl_1_ + diff --git a/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_2.h b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_2.h new file mode 100644 index 000000000..f8d332784 --- /dev/null +++ b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_2.h @@ -0,0 +1,504 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LZ77_BUFFER_KERNEl_2_ +#define DLIB_LZ77_BUFFER_KERNEl_2_ + +#include "lz77_buffer_kernel_abstract.h" +#include "../algs.h" + + + +namespace dlib +{ + + template < + typename sliding_buffer + > + class lz77_buffer_kernel_2 + { + /*! + REQUIREMENTS ON sliding_buffer + sliding_buffer must be an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h + and must be instantiated to contain unsigned char data + + INITIAL VALUE + history_limit == defined by constructor arguments + lookahead_limit == defined by constructor arguments + history_size == 0 + lookahead_size == 0 + buffer.size() == history_limit + lookahead_limit + buffer[i] == 0 for all valid i + + nodes == an array of history_limit-3 nodes + id_table == an array of buffer.size() pointers + hash_table == an array of buffer.size() pointers and all are set to 0 + mask == buffer.size() - 1 + next_free_node == 0 + + + CONVENTION + history_limit == get_history_buffer_limit() + lookahead_limit == get_lookahead_buffer_limit() + history_size == get_history_buffer_size() + lookahead_limit == get_lookahead_buffer_size() + + buffer.size() == history_limit + lookahead_limit + + lookahead_buffer(i) == buffer[lookahead_limit-1-i] + history_buffer(i) == buffer[lookahead_limit+i] + + + hash_table[hash(a,b,c,d)] points to the head of a linked list. + Each node in this linked list tells the location in the buffer + of a string that begins with abcd or a string who's first four + letters have the same hash. The linked list is terminated by a + node with a null next pointer. + + hash_table[i] == 0 if there is no linked list for this element of the hash + table. + + each node in the hash table is allocated from the array nodes. + When adding a node to hash_table: + if (if all nodes aren't already in the hash_table) then + { + the next node to use is nodes[next_free_node]. + } + else + { + recycle nodes from the hash_table itself. This works because + when we add new nodes we also have to remove nodes. + } + + if (there is a node defined with an id of i) then + { + if (id_table[i] != 0) then + id_table[i]->next->id == i + else + hash_table[some_hash]->id == i + } + !*/ + + public: + + lz77_buffer_kernel_2 ( + unsigned long total_limit_, + unsigned long lookahead_limit_ + ); + + virtual ~lz77_buffer_kernel_2 ( + ); + + void clear( + ); + + void add ( + unsigned char symbol + ); + + void find_match ( + unsigned long& index, + unsigned long& length, + unsigned long min_match_length + ); + + inline unsigned long get_history_buffer_limit ( + ) const { return history_limit; } + + inline unsigned long get_lookahead_buffer_limit ( + ) const { return lookahead_limit; } + + inline unsigned long get_history_buffer_size ( + ) const { return history_size; } + + inline unsigned long get_lookahead_buffer_size ( + ) const { return lookahead_size; } + + inline unsigned char lookahead_buffer ( + unsigned long index + ) const { return buffer[lookahead_limit-1-index]; } + + inline unsigned char history_buffer ( + unsigned long index + ) const { return buffer[lookahead_limit+index]; } + + + inline void shift_buffers ( + unsigned long N + ) { shift_buffer(N); } + + private: + + inline unsigned long hash ( + unsigned char a, + unsigned char b, + unsigned char c, + unsigned char d + ) const + /*! + ensures + - returns a hash of the 4 arguments and the hash is in the range + !*/ + { + unsigned long B = b << 3; + unsigned long C = c << 6; + unsigned long D = d << 9; + + unsigned long temp = a + B; + temp += C; + temp += D; + + return (temp&mask); /**/ + } + + void shift_buffer ( + unsigned long N + ); + /*! + requires + - N <= lookahead_size + ensuers + - #lookahead_size == lookahead_size - N + - if (history_size+N < history_limit) then + - #history_size == history_size+N + - else + - #history_size == history_limit + - for all i where 0 <= i < N: + #history_buffer(N-1-i) == lookahead_buffer(i) + - for all i where 0 <= i < #history_size-N: + #history_buffer(N+i) == history_buffer(i) + - for all i where 0 <= i < #lookahead_size + #lookahead_buffer(i) == lookahead_buffer(N+i) + !*/ + + + + // member data + sliding_buffer buffer; + unsigned long lookahead_limit; + unsigned long history_limit; + + struct node + { + unsigned long id; + node* next; + }; + + node** hash_table; + node* nodes; + node** id_table; + unsigned long next_free_node; + unsigned long mask; + + unsigned long lookahead_size; + unsigned long history_size; + + + // restricted functions + lz77_buffer_kernel_2(lz77_buffer_kernel_2&); // copy constructor + lz77_buffer_kernel_2& operator=(lz77_buffer_kernel_2&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + lz77_buffer_kernel_2:: + lz77_buffer_kernel_2 ( + unsigned long total_limit_, + unsigned long lookahead_limit_ + ) : + lookahead_size(0), + history_size(0) + { + buffer.set_size(total_limit_); + lookahead_limit = lookahead_limit_; + history_limit = buffer.size() - lookahead_limit_; + + nodes = new node[history_limit-3]; + + try { id_table = new node*[buffer.size()]; } + catch (...) { delete [] nodes; throw; } + + try { hash_table = new node*[buffer.size()]; } + catch (...) { delete [] id_table; delete [] nodes; throw; } + + mask = buffer.size()-1; + next_free_node = 0; + + + node** start = hash_table; + node** end = hash_table + buffer.size(); + while (start != end) + { + *start = 0; + ++start; + } + + for (unsigned long i = 0; i < buffer.size(); ++i) + buffer[i] = 0; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + lz77_buffer_kernel_2:: + ~lz77_buffer_kernel_2 ( + ) + { + delete [] nodes; + delete [] hash_table; + delete [] id_table; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + void lz77_buffer_kernel_2:: + clear( + ) + { + lookahead_size = 0; + history_size = 0; + next_free_node = 0; + + node** start = hash_table; + node** end = hash_table + buffer.size(); + while (start != end) + { + *start = 0; + ++start; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + void lz77_buffer_kernel_2:: + shift_buffer ( + unsigned long N + ) + { + unsigned long old_history_size = history_size; + unsigned long temp = history_size+N; + unsigned long new_nodes; // the number of nodes to pull from the nodes array + unsigned long recycled_nodes; // the number of nodes to pull from hash_table + lookahead_size -= N; + if (temp <= history_limit) + { + if (history_size <= 3) + { + if ((3-history_size) >= N) + new_nodes = 0; + else + new_nodes = N - (3-history_size); + } + else + { + new_nodes = N; + } + + recycled_nodes = 0; + history_size = temp; + } + else + { + if (history_size != history_limit) + { + new_nodes = history_limit - history_size; + recycled_nodes = temp - history_limit; + history_size = history_limit; + } + else + { + new_nodes = 0; + recycled_nodes = N; + } + } + + unsigned long i = lookahead_limit + 2; + + // if there are any "new" nodes to add to the hash table + if (new_nodes != 0) + { + unsigned long stop = i - new_nodes; + for (; i > stop; --i) + { + nodes[next_free_node].next = 0; + nodes[next_free_node].id = buffer.get_element_id(i); + id_table[nodes[next_free_node].id] = 0; + + unsigned long new_hash = hash(buffer[i],buffer[i-1],buffer[i-2],buffer[i-3]); + + if (hash_table[new_hash] != 0) + id_table[hash_table[new_hash]->id] = &nodes[next_free_node]; + nodes[next_free_node].next = hash_table[new_hash]; + hash_table[new_hash] = &nodes[next_free_node]; + + ++next_free_node; + } + } // if (new_nodes != 0) + + + + unsigned long stop = i - recycled_nodes; + unsigned long old = old_history_size-1+lookahead_limit; + for (; i > stop; --i) + { + // find the next node to recycle in hash_table + node* recycled_node; + + + unsigned long old_id = buffer.get_element_id(old); + + // find the node with id old_id + if (id_table[old_id] == 0) + { + unsigned long old_hash = hash(buffer[old],buffer[old-1],buffer[old-2],buffer[old-3]); + recycled_node = hash_table[old_hash]; + + // fill the gap left by removing this node + hash_table[old_hash] = recycled_node->next; + } + else + { + recycled_node = id_table[old_id]->next; + + // fill the gap left by removing this node + id_table[old_id]->next = recycled_node->next; + } + + --old; + + + + + + + recycled_node->next = 0; + recycled_node->id = buffer.get_element_id(i); + id_table[recycled_node->id] = 0; + + unsigned long new_hash = hash(buffer[i],buffer[i-1],buffer[i-2],buffer[i-3]); + + if (hash_table[new_hash] != 0) + id_table[hash_table[new_hash]->id] = recycled_node; + + recycled_node->next = hash_table[new_hash]; + hash_table[new_hash] = recycled_node; + + } // for (; i > stop; --i) + + + + + buffer.rotate_left(N); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + void lz77_buffer_kernel_2:: + add ( + unsigned char symbol + ) + { + if (lookahead_size == lookahead_limit) + { + shift_buffer(1); + } + buffer[lookahead_limit-1-lookahead_size] = symbol; + ++lookahead_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sliding_buffer + > + void lz77_buffer_kernel_2:: + find_match ( + unsigned long& index, + unsigned long& length, + unsigned long min_match_length + ) + { + unsigned long match_length = 0; // the length of the longest match we find + unsigned long match_index = 0; // the index of the longest match we find + + + const unsigned long hash_value = hash(lookahead_buffer(0), + lookahead_buffer(1), + lookahead_buffer(2), + lookahead_buffer(3) + ); + + + + node* temp = hash_table[hash_value]; + while (temp != 0) + { + // current position in the history buffer + unsigned long hpos = buffer.get_element_index(temp->id)-lookahead_limit; + // current position in the lookahead buffer + unsigned long lpos = 0; + + // find length of this match + while (history_buffer(hpos) == lookahead_buffer(lpos)) + { + ++lpos; + if (hpos == 0) + break; + --hpos; + if (lpos == lookahead_size) + break; + } + + if (lpos > match_length) + { + match_length = lpos; + match_index = buffer.get_element_index(temp->id)-lookahead_limit; + // if this is the longest possible match then stop looking + if (lpos == lookahead_limit) + break; + } + + + temp = temp->next; + } // while (temp != 0) + + + + + // if we found a match that was long enough then report it + if (match_length >= min_match_length) + { + shift_buffer(match_length); + index = match_index; + length = match_length; + } + else + { + length = 0; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LZ77_BUFFER_KERNEl_2_ + diff --git a/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_abstract.h b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_abstract.h new file mode 100644 index 000000000..942b4e3c2 --- /dev/null +++ b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_abstract.h @@ -0,0 +1,210 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LZ77_BUFFER_KERNEl_ABSTRACT_ +#ifdef DLIB_LZ77_BUFFER_KERNEl_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + class lz77_buffer + { + /*! + INITIAL VALUE + get_history_buffer_limit() == defined by constructor arguments + get_lookahead_buffer_limit() == defined by constructor arguments + get_history_buffer_size() == 0 + get_lookahead_buffer_size() == 0 + + + WHAT THIS OBJECT REPRESENTS + This object represents a pair of buffers (history and lookahead buffers) + used during lz77 style compression. + + It's main function is to search the history buffer for long strings which + match the contents (or a part of the contents) of the lookahead buffer. + + + HISTORY AND LOOKAHEAD BUFFERS + The buffers have the following structure: + | history buffer | lookahead buffer | <-- contents of buffers + | ...9876543210 | 0123456789... | <-- index numbers + + So this means that history_buffer(0) == 'r', history_buffer(1) == 'e' + and so on. And lookahead_buffer(0) == 'l', lookahead_buffer(1) == 'o' + and so on. + + + What shift_buffers() does in english: + This function just means that the buffers have their contents shifted + left by N elements and that elements shifted out of the lookahead buffer + go into the history buffer. An example will make it clearer. + + Suppose that we have the following buffers before we apply shift_buffers() + history_buffer() == "hey" and + lookahead_buffer() == "lookahead buffer" + And in the same format as the above diagram it would be + | hey | lookahead buffer | <-- contents of buffers + | 210 | 0123456789... | <-- index numbers + + Applying shift_buffers(4) will give + lookahead_buffer() == "ahead buffer" + history_buffer() == "heylook" or "eylook" or "ylook" or "look" + + You might be wondering why the history_buffer can resize itself in + such a nondeterministic way. It is just to allow a lot of freedom in the + implementations of this object. + !*/ + + public: + + lz77_buffer ( + unsigned long total_limit, + unsigned long lookahead_limit + ); + /*! + requires + - 6 < total_limit < 32 + - 15 < lookahead_limit <= 2^(total_limit-2) + ensures + - #*this is properly initialized + - #get_history_buffer_limit() == 2^total_limit - lookahead_limit + - #get_lookahead_buffer_limit() == lookahead_limit + throws + - std::bad_alloc + !*/ + + virtual ~lz77_buffer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void shift_buffers ( + unsigned long N + ); + /*! + requires + - N <= get_lookahead_buffer_size() + ensures + - #get_lookahead_buffer_size() == get_lookahead_buffer_size() - N + - #get_history_buffer_size() >= N + - #get_history_buffer_size() <= get_history_buffer_size()+N + - #get_history_buffer_size() <= get_history_buffer_limit() + - for all i where 0 <= i < N: + #history_buffer(N-1-i) == lookahead_buffer(i) + - for all i where 0 <= i < #get_history_buffer_size()-N: + #history_buffer(N+i) == history_buffer(i) + - for all i where 0 <= i < #get_lookahead_buffer_size() + #lookahead_buffer(i) == lookahead_buffer(N+i) + !*/ + + void add ( + unsigned char symbol + ); + /*! + ensures + - if (get_lookahead_buffer_size() == get_lookahead_buffer_limit()) then + - performs shift_buffers(1) + - #lookahead_buffer(get_lookahead_buffer_limit()-1) == symbol + - #get_lookahead_buffer_size() == get_lookahead_buffer_size() + - else + - #lookahead_buffer(get_lookahead_buffer_size()) == symbol + - #get_lookahead_buffer_size() == get_lookahead_buffer_size() + 1 + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void find_match ( + unsigned long& index, + unsigned long& length, + unsigned long min_match_length + ); + /*! + ensures + - if (#length != 0) then + - #length >= min_match_length + - for all i where 0 <= i < #length: + history_buffer(#index-i) == lookahead_buffer(i) + - performs shift_buffers(#length) + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + unsigned long get_history_buffer_limit ( + ) const; + /*! + ensures + - returns the max number of symbols that can fit in the history buffer + !*/ + + unsigned long get_lookahead_buffer_limit ( + ) const; + /*! + ensures + - returns the max number of symbols that can fit in the lookahead buffer + !*/ + + unsigned long get_history_buffer_size ( + ) const; + /*! + ensures + - returns the number of symbols currently in the history buffer + !*/ + + unsigned long get_lookahead_buffer_size ( + ) const; + /*! + ensures + - returns the number of symbols currently in the lookahead buffer + !*/ + + unsigned char lookahead_buffer ( + unsigned long index + ) const; + /*! + requires + - index < get_lookahead_buffer_size() + ensures + - returns the symbol in the lookahead buffer at location index + !*/ + + unsigned char history_buffer ( + unsigned long index + ) const; + /*! + requires + - index < get_history_buffer_size() + ensures + - returns the symbol in the history buffer at location index + !*/ + + + private: + + // restricted functions + lz77_buffer(lz77_buffer&); // copy constructor + lz77_buffer& operator=(lz77_buffer&); // assignment operator + + }; +} + +#endif // DLIB_LZ77_BUFFER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_c.h b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_c.h new file mode 100644 index 000000000..704763ad1 --- /dev/null +++ b/ml/dlib/dlib/lz77_buffer/lz77_buffer_kernel_c.h @@ -0,0 +1,169 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LZ77_BUFFER_KERNEl_C_ +#define DLIB_LZ77_BUFFER_KERNEl_C_ + +#include "lz77_buffer_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename lz77_base + > + class lz77_buffer_kernel_c : public lz77_base + { + + public: + lz77_buffer_kernel_c ( + unsigned long total_limit, + unsigned long lookahead_limit + ); + + unsigned char lookahead_buffer ( + unsigned long index + ) const; + + unsigned char history_buffer ( + unsigned long index + ) const; + + void shift_buffers ( + unsigned long N + ); + + + + unsigned long make_safe ( + unsigned long total_limit, + unsigned long lookahead_limit + ) + /*! + ensures + - if ( 6 < total_limit < 32 && + 15 < lookahead_limit <= 2^(total_limit-2) + ) then + - returns total_limit + - else + - throws due to failed CASSERT + !*/ + { + unsigned long exp_size = (total_limit!=0)?total_limit-2:0; + unsigned long two_pow_total_limit_minus_2 = 1; + while (exp_size != 0) + { + --exp_size; + two_pow_total_limit_minus_2 <<= 1; + } + + // make sure requires clause is not broken + DLIB_CASSERT( 6 < total_limit && total_limit < 32 && + 15 < lookahead_limit && lookahead_limit <= two_pow_total_limit_minus_2, + "\tlz77_buffer::lz77_buffer(unsigned long,unsigned long)" + << "\n\ttotal_limit must be in the range 7 to 31 and \n\tlookahead_limit in the range 15 to 2^(total_limit-2)" + << "\n\tthis: " << this + << "\n\ttotal_limit: " << total_limit + << "\n\tlookahead_limit: " << lookahead_limit + ); + + return total_limit; + } + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename lz77_base + > + void lz77_buffer_kernel_c:: + shift_buffers ( + unsigned long N + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( N <= this->get_lookahead_buffer_size(), + "\tvoid lz77_buffer::shift_buffers(unsigned long)" + << "\n\tN must be <= the number of chars in the lookahead buffer" + << "\n\tthis: " << this + << "\n\tget_lookahead_buffer_size(): " << this->get_lookahead_buffer_size() + << "\n\tN: " << N + ); + + // call the real function + lz77_base::shift_buffers(N); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename lz77_base + > + unsigned char lz77_buffer_kernel_c:: + history_buffer ( + unsigned long index + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( index < this->get_history_buffer_size(), + "\tunsigned char lz77_buffer::history_buffer(unsigned long) const" + << "\n\tindex must be in the range 0 to get_history_buffer_size()-1" + << "\n\tthis: " << this + << "\n\tget_history_buffer_size(): " << this->get_history_buffer_size() + << "\n\tindex: " << index + ); + + // call the real function + return lz77_base::history_buffer(index); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename lz77_base + > + unsigned char lz77_buffer_kernel_c:: + lookahead_buffer ( + unsigned long index + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( index < this->get_lookahead_buffer_size(), + "\tunsigned char lz77_buffer::lookahead_buffer(unsigned long) const" + << "\n\tindex must be in the range 0 to get_lookahead_buffer_size()-1" + << "\n\tthis: " << this + << "\n\tget_lookahead_buffer_size(): " << this->get_lookahead_buffer_size() + << "\n\tindex: " << index + ); + + // call the real function + return lz77_base::lookahead_buffer(index); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename lz77_base + > + lz77_buffer_kernel_c:: + lz77_buffer_kernel_c ( + unsigned long total_limit, + unsigned long lookahead_limit + ) : + lz77_base(make_safe(total_limit,lookahead_limit),lookahead_limit) + { + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LZ77_BUFFER_KERNEl_C_ + diff --git a/ml/dlib/dlib/lzp_buffer.h b/ml/dlib/dlib/lzp_buffer.h new file mode 100644 index 000000000..090a53de3 --- /dev/null +++ b/ml/dlib/dlib/lzp_buffer.h @@ -0,0 +1,46 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LZP_BUFFEr_ +#define DLIB_LZP_BUFFEr_ + + +#include "lzp_buffer/lzp_buffer_kernel_1.h" +#include "lzp_buffer/lzp_buffer_kernel_2.h" +#include "lzp_buffer/lzp_buffer_kernel_c.h" + +#include "sliding_buffer.h" + + +namespace dlib +{ + + + class lzp_buffer + { + + lzp_buffer() {} + + typedef sliding_buffer::kernel_1a sb1; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef lzp_buffer_kernel_1 + kernel_1a; + typedef lzp_buffer_kernel_c + kernel_1a_c; + + // kernel_2a + typedef lzp_buffer_kernel_2 + kernel_2a; + typedef lzp_buffer_kernel_c + kernel_2a_c; + + + }; +} + +#endif // DLIB_LZP_BUFFEr_ + diff --git a/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_1.h b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_1.h new file mode 100644 index 000000000..f24d74eee --- /dev/null +++ b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_1.h @@ -0,0 +1,236 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LZP_BUFFER_KERNEl_1_ +#define DLIB_LZP_BUFFER_KERNEl_1_ + +#include "../algs.h" +#include "lzp_buffer_kernel_abstract.h" + +namespace dlib +{ + + template < + typename sbuf + > + class lzp_buffer_kernel_1 + { + /*! + REQUIREMENTS ON sbuf + sbuf is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h + T == unsigned char + + INITIAL VALUE + - buffer.size() == the size as defined by the constructor + - table_size == the number of elements in the table array + - for all i: buffer[i] == 0 + - for all i: table[i] == buffer.size() + + CONVENTION + - table_size == the number of elements in the table array + - size() == buffer.size() + - operator[](i) == buffer[i] + + - if (table[hash()] != buffer.size()) then + - buffer.get_element_index(table[hash()]) == the index we will + predict for the current context + - else + - there is no prediction for the current context + + - last_element == buffer.size()-1 + + + This is LZP with just an order-3 model without context confirmation. + + !*/ + + public: + + explicit lzp_buffer_kernel_1 ( + unsigned long buffer_size + ); + + virtual ~lzp_buffer_kernel_1 ( + ); + + void clear( + ); + + inline void add ( + unsigned char symbol + ); + + inline unsigned long predict_match ( + unsigned long& index + ); + + inline size_t size ( + ) const; + + inline unsigned char operator[] ( + unsigned long index + ) const; + + private: + + inline unsigned long hash ( + ) const + /*! + ensures + - returns a hash computed from the current context. This hash + is always in the range for table. + !*/ + { + unsigned long temp = buffer[0]; + temp <<= 16; + unsigned long temp2 = buffer[1]; + temp2 <<= 8; + unsigned long temp3 = buffer[2]; + temp = temp|temp2|temp3; + + temp = ((temp>>11)^temp)&0xFFFF; + + return temp; + } + + sbuf buffer; + const unsigned long table_size; + unsigned long* const table; + unsigned long last_element; + + // restricted functions + lzp_buffer_kernel_1(const lzp_buffer_kernel_1&); // copy constructor + lzp_buffer_kernel_1& operator=(const lzp_buffer_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + lzp_buffer_kernel_1:: + lzp_buffer_kernel_1 ( + unsigned long buffer_size + ) : + table_size(65536), + table(new unsigned long[table_size]) + { + buffer.set_size(buffer_size); + + for (unsigned long i = 0; i < buffer.size(); ++i) + buffer[i] = 0; + + for (unsigned long i = 0; i < table_size; ++i) + table[i] = buffer.size(); + + last_element = buffer.size()-1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + lzp_buffer_kernel_1:: + ~lzp_buffer_kernel_1 ( + ) + { + delete [] table; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + void lzp_buffer_kernel_1:: + clear( + ) + { + for (unsigned long i = 0; i < buffer.size(); ++i) + buffer[i] = 0; + + for (unsigned long i = 0; i < table_size; ++i) + table[i] = buffer.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + void lzp_buffer_kernel_1:: + add ( + unsigned char symbol + ) + { + buffer.rotate_left(1); + buffer[0] = symbol; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + unsigned long lzp_buffer_kernel_1:: + predict_match ( + unsigned long& index + ) + { + const unsigned long i = hash(); + + if (table[i] != buffer.size()) + { + index = buffer.get_element_index(table[i]); + + if (index > 20) + { + // update the prediction for this context + table[i] = buffer.get_element_id(last_element); + } + return 3; + } + else + { + // update the prediction for this context + table[i] = buffer.get_element_id(last_element); + return 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + size_t lzp_buffer_kernel_1:: + size ( + ) const + { + return buffer.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + unsigned char lzp_buffer_kernel_1:: + operator[] ( + unsigned long index + ) const + { + return buffer[index]; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LZP_BUFFER_KERNEl_1_ + diff --git a/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_2.h b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_2.h new file mode 100644 index 000000000..47c0443f1 --- /dev/null +++ b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_2.h @@ -0,0 +1,319 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LZP_BUFFER_KERNEl_2_ +#define DLIB_LZP_BUFFER_KERNEl_2_ + +#include "../algs.h" +#include "lzp_buffer_kernel_abstract.h" +#include + +namespace dlib +{ + + template < + typename sbuf + > + class lzp_buffer_kernel_2 + { + /*! + REQUIREMENTS ON sbuf + sbuf is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h + T == unsigned char + + INITIAL VALUE + - buffer.size() == the size as defined by the constructor + - table_size == the number of elements in the table3 and table4 arrays + - for all i: buffer[i] == 0 + - for all i: table3[i] == buffer.size() + - for all i: table4[i] == buffer.size() + + CONVENTION + - table_size == the number of elements in the table3 and table4 arrays + - size() == buffer.size() + - operator[](i) == buffer[i] + + + + - last_element == buffer.size()-1 + + + This is LZP with an order-5-4-3 model with context confirmation. + To save memory the order5 and order3 predictions exist in the same + table, that is, table3. + + !*/ + + public: + + explicit lzp_buffer_kernel_2 ( + unsigned long buffer_size + ); + + virtual ~lzp_buffer_kernel_2 ( + ); + + void clear( + ); + + inline void add ( + unsigned char symbol + ); + + inline unsigned long predict_match ( + unsigned long& index + ); + + inline size_t size ( + ) const; + + inline unsigned char operator[] ( + unsigned long index + ) const; + + private: + + inline bool verify ( + unsigned long index + ) const + /*! + ensures + - returns true if buffer[index]'s context matches the current context + !*/ + { + if (index+3 < buffer.size()) + { + if (buffer[0] != buffer[index+1]) + return false; + if (buffer[1] != buffer[index+2]) + return false; + if (buffer[2] != buffer[index+3]) + return false; + return true; + } + else + { + // just call this a match + return true; + } + } + + + sbuf buffer; + unsigned long* table3; + unsigned long* table4; + unsigned long last_element; + const unsigned long table_size; + + // restricted functions + lzp_buffer_kernel_2(const lzp_buffer_kernel_2&); // copy constructor + lzp_buffer_kernel_2& operator=(const lzp_buffer_kernel_2&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + lzp_buffer_kernel_2:: + lzp_buffer_kernel_2 ( + unsigned long buffer_size + ) : + table3(0), + table4(0), + table_size(65536) + { + buffer.set_size(buffer_size); + + table3 = new (std::nothrow) unsigned long[table_size]; + table4 = new (std::nothrow) unsigned long[table_size]; + + if (!table3 || !table4) + { + if (!table3) + delete [] table3; + if (!table4) + delete [] table4; + + throw std::bad_alloc(); + } + + + + for (unsigned long i = 0; i < buffer.size(); ++i) + buffer[i] = 0; + + for (unsigned long i = 0; i < table_size; ++i) + { + table3[i] = buffer.size(); + table4[i] = buffer.size(); + } + + last_element = buffer.size()-1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + lzp_buffer_kernel_2:: + ~lzp_buffer_kernel_2 ( + ) + { + delete [] table3; + delete [] table4; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + void lzp_buffer_kernel_2:: + clear( + ) + { + for (unsigned long i = 0; i < buffer.size(); ++i) + buffer[i] = 0; + + for (unsigned long i = 0; i < table_size; ++i) + { + table3[i] = buffer.size(); + table4[i] = buffer.size(); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + void lzp_buffer_kernel_2:: + add ( + unsigned char symbol + ) + { + buffer.rotate_left(1); + buffer[0] = symbol; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + unsigned long lzp_buffer_kernel_2:: + predict_match ( + unsigned long& index + ) + { + unsigned long temp1 = buffer[0]; + unsigned long temp2 = buffer[1]; + temp2 <<= 8; + unsigned long temp3 = buffer[2]; + temp3 <<= 16; + unsigned long temp4 = buffer[3]; + temp4 <<= 24; + unsigned long temp5 = buffer[4]; + temp5 <<= 12; + + unsigned long context1 = temp1|temp2|temp3; + unsigned long context2 = context1|temp4; + + + const unsigned long i5 = ((temp5|(context2>>20))^context2)&0xFFFF; + const unsigned long i4 = ((context2>>15)^context2)&0xFFFF; + const unsigned long i3 = ((context1>>11)^context1)&0xFFFF; + + + + // check the 5-order context's prediction + if (table3[i5] != buffer.size() && + verify(buffer.get_element_index(table3[i5])) ) + { + index = buffer.get_element_index(table3[i5]); + if (index > 20) + { + // update the prediction for this context + table3[i3] = buffer.get_element_id(last_element); + table4[i4] = table3[i3]; + table3[i5] = table3[i3]; + } + return 5; + } + // check the 4-order context's prediction + else if (table4[i4] != buffer.size() && + verify(buffer.get_element_index(table4[i4])) ) + { + index = buffer.get_element_index(table4[i4]); + if (index > 20) + { + // update the prediction for this context + table3[i3] = buffer.get_element_id(last_element); + table4[i4] = table3[i3]; + table3[i5] = table3[i3]; + } + return 4; + } + // check the 3-order context's prediction + else if (table3[i3] != buffer.size() && + verify(buffer.get_element_index(table3[i3]))) + { + index = buffer.get_element_index(table3[i3]); + + if (index > 20) + { + // update the prediction for this context + table3[i3] = buffer.get_element_id(last_element); + table4[i4] = table3[i3]; + table3[i5] = table3[i3]; + } + return 3; + } + else + { + // update the prediction for this context + table3[i3] = buffer.get_element_id(last_element); + table4[i4] = table3[i3]; + table3[i5] = table3[i3]; + + return 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + size_t lzp_buffer_kernel_2:: + size ( + ) const + { + return buffer.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sbuf + > + unsigned char lzp_buffer_kernel_2:: + operator[] ( + unsigned long index + ) const + { + return buffer[index]; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LZP_BUFFER_KERNEl_2_ + diff --git a/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_abstract.h b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_abstract.h new file mode 100644 index 000000000..df8b8c80f --- /dev/null +++ b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_abstract.h @@ -0,0 +1,130 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LZP_BUFFER_KERNEl_ABSTRACT_ +#ifdef DLIB_LZP_BUFFER_KERNEl_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + class lzp_buffer + { + /*! + INITIAL VALUE + size() == some value defined by the constructor argument + Initially this object is at some predefined empty or ground state. + for all i: (*this)[i] == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents some varation on the LZP algorithm + described by Charles Bloom in his paper "LZP: a new data + compression algorithm" + + The LZP algorithm is a lot like lz77 except there is no need to pass + the location of matches in the history buffer to the decoder because + LZP uses the data it has already seen to predict the location of the + next match. + + NOTE + The add() and predict_match() functions must be called in the same + order by the coder and decoder. If they aren't the state of the + lzp_buffer objects in the coder and decoder may differ and the decoder + won't be able to correctly decode the data stream. + !*/ + + public: + + explicit lzp_buffer ( + unsigned long buffer_size + ); + /*! + requires + - 10 < buffer_size < 32 + ensures + - #*this is properly initialized + - #size() == 2^buffer_size + throws + - std::bad_alloc + !*/ + + virtual ~lzp_buffer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void add ( + unsigned char symbol + ); + /*! + ensures + - shifts everything in the history buffer left 1. + (i.e. #(*this)[i+1] == (*this)[i]) + - #(*this)[0] == symbol + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + unsigned long predict_match ( + unsigned long& index + ); + /*! + ensures + - updates the prediction for the current context. + (the current context is the last few symbols seen. i.e. (*this)[0], + (*this)[1], etc.) + - if (*this can generate a prediction) then + - #index == the predicted location of a match in the history buffer. + (i.e. (*this)[#index] is the first symbol of the predicted match) + - returns the order this prediction came from + - else + - returns 0 + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns the size of the history buffer + !*/ + + unsigned char operator[] ( + unsigned long index + ) const; + /*! + requires + - index < size() + ensures + - returns the symbol at the given index in the history buffer + !*/ + + private: + + // restricted functions + lzp_buffer(const lzp_buffer&); // copy constructor + lzp_buffer& operator=(const lzp_buffer&); // assignment operator + + }; +} + +#endif // DLIB_LZP_BUFFER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_c.h b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_c.h new file mode 100644 index 000000000..2b2de2f1d --- /dev/null +++ b/ml/dlib/dlib/lzp_buffer/lzp_buffer_kernel_c.h @@ -0,0 +1,101 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LZP_BUFFER_KERNEl_C_ +#define DLIB_LZP_BUFFER_KERNEl_C_ + +#include "lzp_buffer_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename lzp_base + > + class lzp_buffer_kernel_c : public lzp_base + { + + public: + lzp_buffer_kernel_c ( + unsigned long buffer_size + ); + + + unsigned char operator[] ( + unsigned long index + ) const; + + + unsigned long make_safe ( + unsigned long buffer_size + ) + /*! + ensures + - if ( 10 < buffer_size < 32) then + - returns buffer_size + - else + - throws due to failed CASSERT + !*/ + { + + // make sure requires clause is not broken + DLIB_CASSERT( 10 < buffer_size && buffer_size < 32, + "\tlzp_buffer::lzp_buffer(unsigned long)" + << "\n\tbuffer_size must be in the range 11 to 31." + << "\n\tthis: " << this + << "\n\tbuffer_size: " << buffer_size + ); + + return buffer_size; + } + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename lzp_base + > + unsigned char lzp_buffer_kernel_c:: + operator[] ( + unsigned long index + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( index < this->size(), + "\tunsigned char lzp_buffer::operator[](unsigned long) const" + << "\n\tindex must be in the range 0 to size()()-1" + << "\n\tthis: " << this + << "\n\tsize(): " << this->size() + << "\n\tindex: " << index + ); + + // call the real function + return lzp_base::operator[](index); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename lzp_base + > + lzp_buffer_kernel_c:: + lzp_buffer_kernel_c ( + unsigned long buffer_size + ) : + lzp_base(make_safe(buffer_size)) + { + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LZP_BUFFER_KERNEl_C_ + diff --git a/ml/dlib/dlib/manifold_regularization.h b/ml/dlib/dlib/manifold_regularization.h new file mode 100644 index 000000000..a7222fd4a --- /dev/null +++ b/ml/dlib/dlib/manifold_regularization.h @@ -0,0 +1,13 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MANIFOLD_REGULARIzATION_HEADER +#define DLIB_MANIFOLD_REGULARIzATION_HEADER + +#include "graph_utils/edge_list_graphs.h" +#include "manifold_regularization/linear_manifold_regularizer.h" +#include "graph_utils/function_objects.h" + +#endif // DLIB_MANIFOLD_REGULARIzATION_HEADER + + + diff --git a/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer.h b/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer.h new file mode 100644 index 000000000..95b8b1128 --- /dev/null +++ b/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer.h @@ -0,0 +1,328 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LINEAR_MANIFOLD_ReGULARIZER_Hh_ +#define DLIB_LINEAR_MANIFOLD_ReGULARIZER_Hh_ + +#include "linear_manifold_regularizer_abstract.h" +#include +#include +#include "../serialize.h" +#include "../matrix.h" + +namespace dlib +{ + namespace impl + { + class undirected_adjacency_list + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is simply a tool for turning a vector of sample_pair objects + into an adjacency list with floating point weights on each edge. + !*/ + public: + + undirected_adjacency_list ( + ) + { + _size = 0; + sum_edge_weights = 0; + } + + struct neighbor + { + neighbor(unsigned long idx, double w):index(idx), weight(w) {} + neighbor():index(0), weight(0) {} + + unsigned long index; + double weight; + }; + + typedef std::vector::const_iterator const_iterator; + + size_t size ( + ) const + /*! + ensures + - returns the number of vertices in this graph + !*/ + { + return _size; + } + + const_iterator begin( + unsigned long idx + ) const + /*! + requires + - idx < size() + ensures + - returns an iterator that points to the first neighbor of + the idx'th vertex. + !*/ + { + return blocks[idx]; + } + + const_iterator end( + unsigned long idx + ) const + /*! + requires + - idx < size() + ensures + - returns an iterator that points one past the last neighbor + of the idx'th vertex. + !*/ + { + return blocks[idx+1]; + } + + + template + void build ( + const vector_type& edges, + const weight_function_type& weight_funct + ) + /*! + requires + - vector_type == a type with an interface compatible with std::vector and + it must in turn contain objects with an interface compatible with dlib::sample_pair + - edges.size() > 0 + - contains_duplicate_pairs(edges) == false + - weight_funct(edges[i]) must be a valid expression that evaluates to a + floating point number >= 0 + ensures + - #size() == one greater than the max index in edges. + - builds the adjacency list so that it contains all the given edges. + - The weight in each neighbor is set to the output of the weight_funct() + for the associated edge. + !*/ + { + + + // Figure out how many neighbors each sample ultimately has. We do this so + // we will know how much space to allocate in the data vector. + std::vector num_neighbors; + num_neighbors.reserve(edges.size()); + + for (unsigned long i = 0; i < edges.size(); ++i) + { + // make sure num_neighbors is always big enough + const unsigned long min_size = std::max(edges[i].index1(), edges[i].index2())+1; + if (num_neighbors.size() < min_size) + num_neighbors.resize(min_size, 0); + + num_neighbors[edges[i].index1()] += 1; + num_neighbors[edges[i].index2()] += 1; + } + + _size = num_neighbors.size(); + + // Now setup the iterators in blocks. Also setup a version of blocks that holds + // non-const iterators so we can use it below when we populate data. + std::vector::iterator> mutable_blocks; + data.resize(edges.size()*2); // each edge will show up twice + blocks.resize(_size + 1); + blocks[0] = data.begin(); + mutable_blocks.resize(_size + 1); + mutable_blocks[0] = data.begin(); + for (unsigned long i = 0; i < num_neighbors.size(); ++i) + { + blocks[i+1] = blocks[i] + num_neighbors[i]; + mutable_blocks[i+1] = mutable_blocks[i] + num_neighbors[i]; + } + + sum_edge_weights = 0; + // finally, put the edges into data + for (unsigned long i = 0; i < edges.size(); ++i) + { + const double weight = weight_funct(edges[i]); + sum_edge_weights += weight; + + // make sure requires clause is not broken + DLIB_ASSERT(weight >= 0, + "\t void linear_manifold_regularizer::build()" + << "\n\t You supplied a weight_funct() that generated a negative weight." + << "\n\t weight: " << weight + ); + + *mutable_blocks[edges[i].index1()]++ = neighbor(edges[i].index2(), weight); + *mutable_blocks[edges[i].index2()]++ = neighbor(edges[i].index1(), weight); + } + + } + + double sum_of_edge_weights ( + ) const + { + return sum_edge_weights; + } + + private: + + /*! + INITIAL VALUE + - _size == 0 + - data.size() == 0 + - blocks.size() == 0 + - sum_edge_weights == 0 + + CONVENTION + - size() == _size + - blocks.size() == _size + 1 + - sum_of_edge_weights() == sum_edge_weights + - blocks == a vector of iterators that point into data. + For all valid i: + - The iterator range [blocks[i], blocks[i+1]) contains all the edges + for the i'th node in the graph + !*/ + + std::vector data; + std::vector blocks; + unsigned long _size; + + double sum_edge_weights; + }; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class linear_manifold_regularizer + { + + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + + + template < + typename vector_type1, + typename vector_type2, + typename weight_function_type + > + void build ( + const vector_type1& samples, + const vector_type2& edges, + const weight_function_type& weight_funct + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(edges.size() > 0 && + contains_duplicate_pairs(edges) == false && + max_index_plus_one(edges) <= samples.size(), + "\t void linear_manifold_regularizer::build()" + << "\n\t Invalid inputs were given to this function." + << "\n\t edges.size(): " << edges.size() + << "\n\t samples.size(): " << samples.size() + << "\n\t contains_duplicate_pairs(edges): " << contains_duplicate_pairs(edges) + << "\n\t max_index_plus_one(edges): " << max_index_plus_one(edges) + ); + + + impl::undirected_adjacency_list graph; + graph.build(edges, weight_funct); + + sum_edge_weights = graph.sum_of_edge_weights(); + + make_mr_matrix(samples, graph); + } + + long dimensionality ( + ) const { return reg_mat.nr(); } + + general_matrix get_transformation_matrix ( + scalar_type intrinsic_regularization_strength + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(intrinsic_regularization_strength >= 0, + "\t matrix linear_manifold_regularizer::get_transformation_matrix()" + << "\n\t This value must not be negative" + << "\n\t intrinsic_regularization_strength: " << intrinsic_regularization_strength + ); + + if (dimensionality() == 0) + return general_matrix(); + + + // This isn't how it's defined in the referenced paper but normalizing these kinds of + // sums is typical of most machine learning algorithms. Moreover, doing this makes + // the argument to this function more invariant to the size of the edge set. So it + // should make it easier for the user. + intrinsic_regularization_strength /= sum_edge_weights; + + return inv_lower_triangular(chol(identity_matrix(reg_mat.nr()) + intrinsic_regularization_strength*reg_mat)); + } + + private: + + template + void make_mr_matrix ( + const vector_type& samples, + const impl::undirected_adjacency_list& graph + ) + /*! + requires + - samples.size() == graph.size() + ensures + - computes trans(X)*lap(graph)*X where X is the data matrix + (i.e. the matrix that contains all the samples in its rows) + and lap(graph) is the laplacian matrix of the graph. The + resulting matrix is stored in reg_mat. + !*/ + { + const unsigned long dims = samples[0].size(); + reg_mat.set_size(dims,dims); + reg_mat = 0; + + + typename impl::undirected_adjacency_list::const_iterator beg, end; + + // loop over the columns of the X matrix + for (unsigned long d = 0; d < dims; ++d) + { + // loop down the row of X + for (unsigned long i = 0; i < graph.size(); ++i) + { + beg = graph.begin(i); + end = graph.end(i); + + // if this node in the graph has any neighbors + if (beg != end) + { + double weight_sum = 0; + double val = 0; + for (; beg != end; ++beg) + { + val -= beg->weight * samples[beg->index](d); + weight_sum += beg->weight; + } + val += weight_sum * samples[i](d); + + for (unsigned long j = 0; j < dims; ++j) + { + reg_mat(d,j) += val*samples[i](j); + } + } + } + } + + } + + general_matrix reg_mat; + double sum_edge_weights; + }; + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LINEAR_MANIFOLD_ReGULARIZER_Hh_ + diff --git a/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer_abstract.h b/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer_abstract.h new file mode 100644 index 000000000..9a9b579c9 --- /dev/null +++ b/ml/dlib/dlib/manifold_regularization/linear_manifold_regularizer_abstract.h @@ -0,0 +1,137 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LINEAR_MANIFOLD_ReGULARIZER_ABSTRACT_Hh_ +#ifdef DLIB_LINEAR_MANIFOLD_ReGULARIZER_ABSTRACT_Hh_ + +#include +#include +#include "../serialize.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class linear_manifold_regularizer + { + /*! + REQUIREMENTS ON matrix_type + Must be some type of dlib::matrix. + + INITIAL VALUE + - dimensionality() == 0 + + WHAT THIS OBJECT REPRESENTS + Many learning algorithms attempt to minimize a function that, at a high + level, looks like this: + f(w) == complexity + training_set_error + + The idea is to find the set of parameters, w, that gives low error on + your training data but also is not "complex" according to some particular + measure of complexity. This strategy of penalizing complexity is + usually called regularization. + + In the above setting, all the training data consists of labeled samples. + However, it would be nice to be able to benefit from unlabeled data. + The idea of manifold regularization is to extract useful information from + unlabeled data by first defining which data samples are "close" to each other + (perhaps by using their 3 nearest neighbors) and then adding a term to + the above function that penalizes any decision rule which produces + different outputs on data samples which we have designated as being close. + + It turns out that it is possible to transform these manifold regularized + learning problems into the normal form shown above by applying a certain kind + of preprocessing to all our data samples. Once this is done we can use a + normal learning algorithm, such as the svm_c_linear_trainer, on just the + labeled data samples and obtain the same output as the manifold regularized + learner would have produced. + + The linear_manifold_regularizer is a tool for creating this preprocessing + transformation. In particular, the transformation is linear. That is, it + is just a matrix you multiply with all your samples. For a more detailed + discussion of this topic you should consult the following paper. In + particular, see section 4.2. This object computes the inverse T matrix + described in that section. + + Linear Manifold Regularization for Large Scale Semi-supervised Learning + by Vikas Sindhwani, Partha Niyogi, and Mikhail Belkin + !*/ + + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + + template < + typename vector_type1, + typename vector_type2, + typename weight_function_type + > + void build ( + const vector_type1& samples, + const vector_type2& edges, + const weight_function_type& weight_funct + ); + /*! + requires + - vector_type1 == a type with an interface compatible with std::vector and it must + in turn contain dlib::matrix objects. + - vector_type2 == a type with an interface compatible with std::vector and + it must in turn contain objects with an interface compatible with dlib::sample_pair + - edges.size() > 0 + - contains_duplicate_pairs(edges) == false + - max_index_plus_one(edges) <= samples.size() + - weight_funct(edges[i]) must be a valid expression that evaluates to a + floating point number >= 0 + ensures + - #dimensionality() == samples[0].size() + - This function sets up the transformation matrix describe above. The manifold + regularization is done assuming that the samples are meant to be "close" + according to the graph defined by the given edges. I.e: + - for all valid i: samples[edges[i].index1()] is close to samples[edges[i].index2()]. + How much we care about these two samples having similar outputs according + to the learned rule is given by weight_funct(edges[i]). Bigger weights mean + we care more. + !*/ + + long dimensionality ( + ) const; + /*! + ensures + - returns the number of rows and columns in the transformation matrix + produced by this object. + !*/ + + general_matrix get_transformation_matrix ( + scalar_type intrinsic_regularization_strength + ) const; + /*! + requires + - intrinsic_regularization_strength >= 0 + ensures + - returns a matrix that represents the preprocessing transformation described above. + - You must choose how important the manifold regularizer is relative to the basic + "don't be complex" regularizer described above. The intrinsic_regularization_strength + is the parameter that controls this trade-off. A large value of + intrinsic_regularization_strength means that more emphasis should be placed on + finding decision rules which produce the same output on similar samples. On + the other hand, a small value would mean that we don't care much about the + manifold regularizer. For example, using 0 will cause this function to return the + identity matrix. + - The returned matrix will have dimensionality() rows and columns. + !*/ + + }; + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LINEAR_MANIFOLD_ReGULARIZER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/map.h b/ml/dlib/dlib/map.h new file mode 100644 index 000000000..12036380e --- /dev/null +++ b/ml/dlib/dlib/map.h @@ -0,0 +1,59 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MAp_ +#define DLIB_MAp_ + +#include "map/map_kernel_1.h" +#include "map/map_kernel_c.h" + +#include "binary_search_tree.h" + + +#include "algs.h" +#include + + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class map + { + map() {} + + + // a typedef for the binary search tree used by kernel_2 + typedef typename binary_search_tree::kernel_1a + binary_search_tree_1; + + // a typedef for the binary search tree used by kernel_2 + typedef typename binary_search_tree::kernel_2a + binary_search_tree_2; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef map_kernel_1 + kernel_1a; + typedef map_kernel_c + kernel_1a_c; + + // kernel_1b + typedef map_kernel_1 + kernel_1b; + typedef map_kernel_c + kernel_1b_c; + + + }; +} + +#endif // DLIB_MAp_ + diff --git a/ml/dlib/dlib/map/map_kernel_1.h b/ml/dlib/dlib/map/map_kernel_1.h new file mode 100644 index 000000000..1c79d179f --- /dev/null +++ b/ml/dlib/dlib/map/map_kernel_1.h @@ -0,0 +1,436 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MAP_KERNEl_1_ +#define DLIB_MAP_KERNEl_1_ + +#include "map_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/remover.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager = default_memory_manager + > + class map_kernel_1 : public enumerable >, + public asc_pair_remover + { + + /*! + REQUIREMENTS ON BST_BASE + bst_base is instantiated with domain and range and + implements binary_search_tree/binary_search_tree_kernel_abstract.h + + INITIAL VALUE + bst has its initial value + + CONVENTION + bst.size() == the number of elements in the map and + the elements in map are stored in bst_base + !*/ + + public: + + typedef domain domain_type; + typedef range range_type; + typedef typename bst_base::compare_type compare_type; + typedef mem_manager mem_manager_type; + + map_kernel_1( + ) + {} + + virtual ~map_kernel_1( + ) + {} + + inline void clear( + ); + + inline void add ( + domain& d, + range& r + ); + + inline bool is_in_domain ( + const domain& d + ) const; + + inline void remove_any ( + domain& d, + range& r + ); + + inline void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + inline void destroy ( + const domain& d + ); + + inline range& operator[] ( + const domain& d + ); + + inline const range& operator[] ( + const domain& d + ) const; + + inline void swap ( + map_kernel_1& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + inline bool current_element_valid ( + ) const; + + inline const map_pair& element ( + ) const; + + inline map_pair& element ( + ); + + inline bool move_next ( + ) const; + + + private: + + bst_base bst; + + // restricted functions + map_kernel_1(map_kernel_1&); + map_kernel_1& operator= ( map_kernel_1&); + }; + + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + inline void swap ( + map_kernel_1& a, + map_kernel_1& b + ) { a.swap(b); } + + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + void deserialize ( + map_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type map_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + void map_kernel_1:: + clear ( + ) + { + bst.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + void map_kernel_1:: + add( + domain& d, + range& r + ) + { + // try to add pair to bst_base + bst.add(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + bool map_kernel_1:: + is_in_domain( + const domain& d + ) const + { + return (bst[d] != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + void map_kernel_1:: + remove_any( + domain& d, + range& r + ) + { + bst.remove_any(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + void map_kernel_1:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + bst.remove(d,d_copy,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + void map_kernel_1:: + destroy ( + const domain& d + ) + { + bst.destroy(d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + range& map_kernel_1:: + operator[]( + const domain& d + ) + { + return *bst[d]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + const range& map_kernel_1:: + operator[]( + const domain& d + ) const + { + return *bst[d]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + size_t map_kernel_1:: + size ( + ) const + { + return bst.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + void map_kernel_1:: + swap ( + map_kernel_1& item + ) + { + bst.swap(item.bst); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + bool map_kernel_1:: + at_start ( + ) const + { + return bst.at_start(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + void map_kernel_1:: + reset ( + ) const + { + bst.reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + bool map_kernel_1:: + current_element_valid ( + ) const + { + return bst.current_element_valid(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + const map_pair& map_kernel_1:: + element ( + ) const + { + return bst.element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + map_pair& map_kernel_1:: + element ( + ) + { + return bst.element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename bst_base, + typename mem_manager + > + bool map_kernel_1:: + move_next ( + ) const + { + return bst.move_next(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MAP_KERNEl_1_ + diff --git a/ml/dlib/dlib/map/map_kernel_abstract.h b/ml/dlib/dlib/map/map_kernel_abstract.h new file mode 100644 index 000000000..1e07e5e56 --- /dev/null +++ b/ml/dlib/dlib/map/map_kernel_abstract.h @@ -0,0 +1,235 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MAP_KERNEl_ABSTRACT_ +#ifdef DLIB_MAP_KERNEl_ABSTRACT_ + +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class map : public enumerable >, + public asc_pair_remover + { + + /*! + REQUIREMENTS ON domain + domain must be comparable by compare where compare is a functor compatible with std::less and + domain is swappable by a global swap() and + domain must have a default constructor + + REQUIREMENTS ON range + range is swappable by a global swap() and + range must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + swap(), is_in_domain(), and operator[] functions do not invalidate + pointers or references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the domain (and each associated + range element) elements in ascending order according to the compare functor. + (i.e. the elements are enumerated in sorted order) + + WHAT THIS OBJECT REPRESENTS + map contains items of type domain and range + + This object is similar an array. It maps items of type domain on to + items of type range. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + map( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + !*/ + + virtual ~map( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void add ( + domain& d, + range& r + ); + /*! + requires + - &d != &r (i.e. d and r cannot be the same variable) + - is_in_domain(d) == false + ensures + - #is_in_domain(d) == true + - #operator[](d) == r + - #d and #r have initial values for their types + - #size() == size() + 1 + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if add() throws then it has no effect + !*/ + + bool is_in_domain ( + const domain& d + ) const; + /*! + ensures + - returns whether or not an element equivalent to d is in the + domain of *this + !*/ + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - &d != &r (i.e. d and r cannot be the same variable) + - &d != &d_copy (i.e. d and d_copy cannot be the same variable) + - &r != &d_copy (i.e. r and d_copy cannot be the same variable) + - is_in_domain(d) == true + ensures + - #is_in_domain(d) == false + - #d_copy is equivalent to d + - the element in the range of *this associated with #d_copy has been + swapped into #r + - #size() == size() - 1 + - #at_start() == true + !*/ + + void destroy ( + const domain& d + ); + /*! + requires + - is_in_domain(d) == true + ensures + - #is_in_domain(d) == false + - #size() == size() - 1 + - #at_start() == true + !*/ + + range& operator[] ( + const domain& d + ); + /*! + requires + - is_in_domain(d) == true + ensures + - returns a non-const reference to the element in the range of *this + associated with the element equivalent to d + !*/ + + const range& operator[] ( + const domain& d + ) const; + /*! + requires + - is_in_domain(d) == true + ensures + - returns a const reference to the element in the range of *this + associated with the element equivalent to d + !*/ + + void swap ( + map& item + ); + /*! + ensures + - swaps *this and item + !*/ + + + private: + + // restricted functions + map(map&); // copy constructor + map& operator=(map&); // assignment operator + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + map& a, + map& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + map& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_MAP_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/map/map_kernel_c.h b/ml/dlib/dlib/map/map_kernel_c.h new file mode 100644 index 000000000..dfdbd4632 --- /dev/null +++ b/ml/dlib/dlib/map/map_kernel_c.h @@ -0,0 +1,248 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MAP_KERNEl_C_ +#define DLIB_MAP_KERNEl_C_ + +#include "map_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include "../interfaces/map_pair.h" + +namespace dlib +{ + + template < + typename map_base + > + class map_kernel_c : public map_base + { + + typedef typename map_base::domain_type domain; + typedef typename map_base::range_type range; + + public: + void add ( + domain& d, + range& r + ); + + void remove_any ( + domain& d, + range& r + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + range& operator[] ( + const domain& d + ); + + const range& operator[] ( + const domain& d + ) const; + + const map_pair& element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst map_pair& map::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return map_base::element(); + } + + map_pair& element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tmap_pair& map::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return map_base::element(); + } + + }; + + template < + typename map_base + > + inline void swap ( + map_kernel_c& a, + map_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename map_base + > + void map_kernel_c:: + add ( + domain& d, + range& r + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (!this->is_in_domain(d)) && + (static_cast(&d) != static_cast(&r)), + "\tvoid map::add" + << "\n\tdomain element being added must not already be in the map" + << "\n\tand d and r must not be the same variable" + << "\n\tis_in_domain(d): " << (this->is_in_domain(d) ? "true" : "false") + << "\n\tthis: " << this + << "\n\t&d: " << static_cast(&d) + << "\n\t&r: " << static_cast(&r) + ); + + // call the real function + map_base::add(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_base + > + void map_kernel_c:: + remove_any ( + domain& d, + range& r + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (this->size() > 0) && + (static_cast(&d) != static_cast(&r)), + "\tvoid map::remove_any" + << "\n\tsize() must be greater than zero if something is going to be removed" + << "\n\tand d and r must not be the same variable." + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + << "\n\t&d: " << static_cast(&d) + << "\n\t&r: " << static_cast(&r) + ); + + // call the real function + map_base::remove_any(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_base + > + void map_kernel_c:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (this->is_in_domain(d)) && + (static_cast(&d) != static_cast(&r)) && + (static_cast(&r) != static_cast(&d_copy)) && + (static_cast(&d) != static_cast(&d_copy)), + "\tvoid map::remove" + << "\n\tcan't remove something that isn't in the map or if the paremeters actually" + << "\n\tare the same variable. Either way can't remove." + << "\n\tis_in_domain(d): " << (this->is_in_domain(d) ? "true" : "false") + << "\n\tthis: " << this + << "\n\t&d: " << static_cast(&d) + << "\n\t&r: " << static_cast(&r) + << "\n\t&d_copy: " << static_cast(&d_copy) + ); + + // call the real function + map_base::remove(d,d_copy,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_base + > + void map_kernel_c:: + destroy ( + const domain& d + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->is_in_domain(d), + "\tvoid map::destroy" + << "\n\tcan't remove something that isn't in the map" + << "\n\tthis: " << this + << "\n\t&d: " << static_cast(&d) + ); + + // call the real function + map_base::destroy(d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_base + > + typename map_base::range_type& map_kernel_c:: + operator[] ( + const domain& d + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->is_in_domain(d), + "\trange& map::operator[]" + << "\n\td must be in the domain of the map" + << "\n\tthis: " << this + ); + + // call the real function + return map_base::operator[](d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_base + > + const typename map_base::range_type& map_kernel_c:: + operator[] ( + const domain& d + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->is_in_domain(d), + "\tconst range& map::operator[]" + << "\n\td must be in the domain of the map" + << "\n\tthis: " << this + ); + + // call the real function + return map_base::operator[](d); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MAP_KERNEl_C_ + diff --git a/ml/dlib/dlib/matlab/CMakeLists.txt b/ml/dlib/dlib/matlab/CMakeLists.txt new file mode 100644 index 000000000..b9a0beab9 --- /dev/null +++ b/ml/dlib/dlib/matlab/CMakeLists.txt @@ -0,0 +1,22 @@ + +cmake_minimum_required(VERSION 2.8.12) + +PROJECT(mex_functions) + +include(cmake_mex_wrapper) + +add_subdirectory(.. dlib_build) + + +# You can tell cmake where to put the mex files when you run 'make install' by +# setting this variable. The path is relative to this CMakeLists.txt file. +set(install_target_output_folder .) + +# Compile the example_mex_function.cpp file and link it to dlib. Note +# that you can give a list of things to link to here. E.g. +# add_mex_function(some_other_mex_function pthread dlib fftw) +add_mex_function(example_mex_function dlib::dlib) +add_mex_function(example_mex_callback dlib::dlib) +add_mex_function(example_mex_struct dlib::dlib) +add_mex_function(example_mex_class dlib::dlib) + diff --git a/ml/dlib/dlib/matlab/README.txt b/ml/dlib/dlib/matlab/README.txt new file mode 100644 index 000000000..d571e2330 --- /dev/null +++ b/ml/dlib/dlib/matlab/README.txt @@ -0,0 +1,20 @@ +This folder contains a set of tools which make it easy to create MATLAB mex +functions. To understand how they work, you should read the +example_mex_function.cpp, example_mex_struct.cpp, and example_mex_callback.cpp examples. + +To compile them, you can use CMake. In particular, from this folder execute +these commands: + + mkdir build + cd build + cmake .. + cmake --build . --config release --target install + +That should build the mex files on any platform. + +Note that on windows you will probably need to tell CMake to use a 64bit +version of visual studio. You can do this by using a command like: + cmake -G "Visual Studio 10 Win64" .. +instead of + cmake .. + diff --git a/ml/dlib/dlib/matlab/call_matlab.h b/ml/dlib/dlib/matlab/call_matlab.h new file mode 100644 index 000000000..cc06a6812 --- /dev/null +++ b/ml/dlib/dlib/matlab/call_matlab.h @@ -0,0 +1,852 @@ +// Copyright (C) 2012 Massachusetts Institute of Technology, Lincoln Laboratory +// License: Boost Software License See LICENSE.txt for the full license. +// Authors: Davis E. King (davis@dlib.net) +#ifndef MIT_LL_CALL_MATLAB_H__ +#define MIT_LL_CALL_MATLAB_H__ + +#include +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + +struct invalid_args_exception : error +{ + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown when the mex wrapper tries to convert a matlab + object into a C++ object but for whatever reason can't (usually because the + types don't match). + !*/ + invalid_args_exception(const std::string& msg_): error(msg_) {} + invalid_args_exception(const std::ostringstream& msg_): error(msg_.str()) {} +}; + +// ---------------------------------------------------------------------------------------- + +void check_for_matlab_ctrl_c(); +/*! + ensures + - If the user of MATLAB has pressed ctrl+c then this function will throw an + exception. +!*/ + +// ---------------------------------------------------------------------------------------- + +class matlab_object +{ + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple wrapper around matlab's generic mxArray, which is the + thing that is matlab's "anything object". So a matlab_object can be used as an + argument to a mex_function() that can bind to any matlab object at all. It can + also bind to "nothing" and so is inherently also an optional argument when + present in a mex_funciton(). + !*/ +public: + matlab_object() : handle(0),should_free(false),arg_idx(0) {} + matlab_object(const matlab_object&) = delete; + + ~matlab_object(); + + + // Check if a matlab object is bound to this object. + bool is_empty() const { return handle==0; } + operator bool() const { return handle!=0; } + + // Convert from MATLAB to C++, throw invalid_args_exception if not possible. + template operator T() const; + template void get(T& item) const; + + // Convert from a C++ object to MATLAB + template matlab_object& operator= (const T& new_val); + + + template bool try_get(T& item) const + { + try { get(item); return true; } + catch(invalid_args_exception&) { return false; } + } + + const void* get_handle() const { return handle; } + /*! + ensures + - returns a pointer to the mxArray object. Might be NULL. + !*/ + + + matlab_object& operator=(const matlab_object&) = delete; + + // Users shouldn't call these functions + const void* release_object_to_matlab() { const void* temp=handle; handle = 0; return temp; } + void set_object_handle(int arg_idx_, const void* sh) { DLIB_CASSERT(!handle); handle = sh; arg_idx=arg_idx_; } +private: + + const void* handle; + bool should_free; + int arg_idx; +}; + +// ---------------------------------------------------------------------------------------- + +class matlab_struct +{ + /*! + WHAT THIS OBJECT REPRESENTS + This object lets you interface with MATLAB structs from C++. For example, + given a MATLAB struct named mystruct, you could access it's fields like this: + MATLAB way: mystruct.field + C++ way: mystruct["field"] + MATLAB way: mystruct.field.subfield + C++ way: mystruct["field"]["subfield"] + + To get the values as C++ types you do something like this: + int val = mystruct["field"]; + or + int val; + mystruct["field"].get(val); + + See also example_mex_struct.cpp for an example that uses this part of the API. + !*/ + + class sub; +public: + matlab_struct() : struct_handle(0),should_free(false),arg_idx(0) {} + matlab_struct(const matlab_struct&) = delete; + ~matlab_struct(); + + const sub operator[] (const std::string& name) const; + sub operator[] (const std::string& name); + bool has_field(const std::string& name) const; + + const void* release_struct_to_matlab() { const void* temp=struct_handle; struct_handle = 0; return temp; } + void set_struct_handle(int arg_idx_, const void* sh) { DLIB_CASSERT(!struct_handle); struct_handle = sh; arg_idx=arg_idx_; } +private: + + class sub + { + public: + sub() : struct_handle(0), field_idx(-1) {} + + template operator T() const; + template void get(T& item) const; + template sub& operator= (const T& new_val); + const sub operator[] (const std::string& name) const; + sub operator[] (const std::string& name); + bool has_field(const std::string& name) const; + private: + friend class matlab_struct; + const void* struct_handle; + int field_idx; + sub& operator=(const sub&); + }; + const void* struct_handle; + bool should_free; + int arg_idx; + matlab_struct& operator=(const matlab_struct&); +}; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +template +struct output_decorator +{ + output_decorator(T& item_):item(item_){} + T& item; +}; + +template +output_decorator returns(T& item) { return output_decorator(item); } +/*! + ensures + - decorates item as an output type. This stuff is used by the call_matlab() + functions to tell if an argument is an input to the function or is supposed + to be bound to one of the return arguments. +!*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +struct function_handle +{ + /*! + WHAT THIS OBJECT REPRESENTS + This type is used to represent function handles passed from MATLAB into a + mex function. You can call the function referenced by the handle by + saying: + call_matlab(my_handle); + !*/ + + // These two lines are just implementation details, ignore them. + function_handle():h(0){} + void* const h; +}; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +void call_matlab ( + const std::string& function_name +); +/*! + ensures + - Calls MATLAB's function of the given name +!*/ + +// ---------------------------------------------------------------------------------------- + +void call_matlab ( + const function_handle& funct +); +/*! + ensures + - Calls MATLAB's function represented by the handle funct +!*/ + +// ---------------------------------------------------------------------------------------- + +template < + typename T1 + > +void call_matlab ( + const std::string& function_name, + const T1& A1 +); +/*! + ensures + - calls MATLAB's function of the given name. + - if (A1 is not decorated as an output by returns()) then + - A1 is passed as an argument into the MATLAB function + - else + - A1 is treated as the first return value from the MATLAB function. +!*/ + +template < + typename T1 + > +void call_matlab ( + const function_handle& funct, + const T1& A1 +) { call_matlab("feval", funct, A1); } +/*! + ensures + - Calls MATLAB's function represented by the handle funct + - if (A1 is not decorated as an output by returns()) then + - A1 is passed as an argument into the MATLAB function + - else + - A1 is treated as the first return value from the MATLAB function. +!*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +/* + The rest of this file is just overloads of call_matlab() for up to 10 arguments (or + just 9 arguments if function_handle is used). They all do the same thing as the above + version of call_matlab(). Generally, any argument not decorated by returns() is an + input to the MATLAB function. On the other hand, all arguments decorated by returns() + are treated as outputs. +*/ +// ---------------------------------------------------------------------------------------- + +template < + typename T1, + typename T2 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, + typename T2, + typename T3 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, + typename T2, + typename T3, + typename T4 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& A12 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15, typename T16 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15, typename T16, typename T17 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15, typename T16, typename T17, typename T18 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17, + const T18& A18 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15, typename T16, typename T17, typename T18, typename T19 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17, + const T18& A18, const T19& A19 +); + +// ---------------------------------------------------------------------------------------- + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15, typename T16, typename T17, typename T18, typename T19, + typename T20 + > +void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17, + const T18& A18, const T19& A19, const T20& A20 +); + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +template < + typename T1, + typename T2 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, + const T2& A2 +) +{ + call_matlab("feval", funct, A1, A2); +} + +template < + typename T1, + typename T2, + typename T3 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, + const T2& A2, + const T3& A3 +) +{ + call_matlab("feval", funct, A1, A2, A3); +} + +template < + typename T1, + typename T2, + typename T3, + typename T4 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4); +} + +template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5); +} + +template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6); +} + +template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7); +} + +template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15, typename T16 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15, typename T16, typename T17 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15, typename T16, typename T17, typename T18 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17, + const T18& A18 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18); +} + +template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename + T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, + typename T14, typename T15, typename T16, typename T17, typename T18, typename T19 + > +void call_matlab ( + const function_handle& funct, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const T12& + A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const T17& A17, + const T18& A18, const T19& A19 +) +{ + call_matlab("feval", funct, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19); +} + +// ---------------------------------------------------------------------------------------- + +// We define this function here so that, if you write some code that has check_for_matlab_ctrl_c() +// sprinkled throughout it you can still compile that code outside the mex wrapper +// environment and these calls will simply be no-ops. +#ifndef MATLAB_MEX_FILE +inline void check_for_matlab_ctrl_c() {} +#endif + +} + +#endif // MIT_LL_CALL_MATLAB_H__ + diff --git a/ml/dlib/dlib/matlab/cmake_mex_wrapper b/ml/dlib/dlib/matlab/cmake_mex_wrapper new file mode 100644 index 000000000..67729ff70 --- /dev/null +++ b/ml/dlib/dlib/matlab/cmake_mex_wrapper @@ -0,0 +1,103 @@ +# This file figures out where MATLAB is and then defines a macro, add_mex_function(name) +# which when called instructs CMake to build a mex file from a file called name.cpp. Note +# that additional library dependencies can be added like this: add_mex_function(name lib1 dlib libetc). +# That is, just add more libraries after the name and they will be build into the mex file. + +cmake_minimum_required(VERSION 2.8.12) + +set(BUILDING_MATLAB_MEX_FILE true) +set(CMAKE_POSITION_INDEPENDENT_CODE True) + +# Trying to use cuda with matlab hasn't worked well, so just disable it. +SET(DLIB_USE_CUDA OFF CACHE BOOL "" FORCE) + +# Find MATLAB's include directory and needed libraries +find_program(MATLAB_EXECUTABLE matlab PATHS + "C:/Program Files/MATLAB/*/bin" + "C:/Program Files (x86)/MATLAB/*/bin" + ) +# Resolve symbolic links to try and get the real path to the MATLAB executable +get_filename_component(MATLAB_EXECUTABLE ${MATLAB_EXECUTABLE} REALPATH) +# Now get MATLAB root directory +get_filename_component(MATLAB_HOME ${MATLAB_EXECUTABLE} PATH) +get_filename_component(MATLAB_HOME ${MATLAB_HOME} PATH) +set(MATLAB_LIB_FOLDERS + "${MATLAB_HOME}/extern/lib/win64/microsoft" + "${MATLAB_HOME}/bin/glnxa64" + ) +# If there is a MATLAB_HOME environment variable then look there as well. +if (DEFINED ENV{MATLAB_HOME}) + set(MATLAB_LIB_FOLDERS + "$ENV{MATLAB_HOME}/extern/lib/win64/microsoft" + "$ENV{MATLAB_HOME}/bin/glnxa64" + ${MATLAB_LIB_FOLDERS} + ) +endif() +# Find the MATLAB libraries that need to get linked into the mex file +if (WIN32) + find_library(MATLAB_MEX_LIBRARY libmex PATHS ${MATLAB_LIB_FOLDERS} ) + find_library(MATLAB_MX_LIBRARY libmx PATHS ${MATLAB_LIB_FOLDERS} ) + find_library(MATLAB_ENG_LIBRARY libeng PATHS ${MATLAB_LIB_FOLDERS} ) +else() + find_library(MATLAB_MEX_LIBRARY mex PATHS ${MATLAB_LIB_FOLDERS} ) + find_library(MATLAB_MX_LIBRARY mx PATHS ${MATLAB_LIB_FOLDERS} ) + find_library(MATLAB_ENG_LIBRARY eng PATHS ${MATLAB_LIB_FOLDERS} ) +endif() +set(MATLAB_LIBRARIES ${MATLAB_MEX_LIBRARY} ${MATLAB_MX_LIBRARY} ${MATLAB_ENG_LIBRARY}) +# Figure out the path to MATLAB's mex.h so we can add it to the include search path. +find_path(mex_header_path mex.h + PATHS "$ENV{MATLAB_HOME}/extern/include" + "${MATLAB_HOME}/extern/include" + ) +INCLUDE_DIRECTORIES(${mex_header_path}) + +# Determine the path to cmake_mex_wrapper file so we can add it to the include search path.. +string(REGEX REPLACE "cmake_mex_wrapper$" "" dlib_matlab_binding_path ${CMAKE_CURRENT_LIST_FILE}) +INCLUDE_DIRECTORIES("${dlib_matlab_binding_path}") +# Also add dlib to the include search path +INCLUDE_DIRECTORIES(${dlib_matlab_binding_path}/../..) + +add_definitions(-DMATLAB_MEX_FILE) + +# Determine the path to our CMakeLists.txt file. This is the file that +# includeded the one you are reading right now. So here we make it so that +# when you run the install target it will copy the compiled mex files into the +# same folder as the parent CMakeLists.txt file. +string(REGEX REPLACE "CMakeLists.txt$" "" install_dir ${CMAKE_PARENT_LIST_FILE}) +set(CMAKE_INSTALL_PREFIX "${install_dir}") +set(CMAKE_INSTALL_SYSTEM_RUNTIME_DESTINATION "${install_dir}") +INCLUDE(InstallRequiredSystemLibraries) + + +MACRO(add_mex_function name ) + ADD_LIBRARY(${name} MODULE ${name}.cpp ) + target_compile_definitions(${name} PRIVATE -DMEX_FILENAME=${name}) + if (UNIX) + # Doing this prevents our mex function from exporting any symbols + # other than mexFunction(). This sometimes doesn't matter but sometimes + # avoids causing errors or otherwise bad behavior in MATLAB. + if (DEFINED ENV{MATLAB_HOME}) + set_target_properties(${name} PROPERTIES LINK_FLAGS "-Wl,--version-script,$ENV{MATLAB_HOME}/extern/lib/glnxa64/mexFunction.map") + else() + set_target_properties(${name} PROPERTIES LINK_FLAGS "-Wl,--version-script,${MATLAB_HOME}/extern/lib/glnxa64/mexFunction.map") + endif() + endif() + + # Change the output file extension to a mex extension. + if (WIN32) + set_target_properties(${name} PROPERTIES SUFFIX ".mexw64") + elseif(APPLE) + set_target_properties(${name} PROPERTIES SUFFIX ".mexmaci64") + else() + set_target_properties(${name} PROPERTIES SUFFIX ".mexa64") + endif() + set_target_properties(${name} PROPERTIES PREFIX "") + TARGET_LINK_LIBRARIES(${name} ${MATLAB_LIBRARIES} ${ARGN}) + if (install_target_output_folder) + install(TARGETS ${name} DESTINATION "${install_target_output_folder}") + else() + install(TARGETS ${name} DESTINATION "${install_dir}") + endif() +ENDMACRO() + + diff --git a/ml/dlib/dlib/matlab/example.m b/ml/dlib/dlib/matlab/example.m new file mode 100644 index 000000000..8ed47346b --- /dev/null +++ b/ml/dlib/dlib/matlab/example.m @@ -0,0 +1,16 @@ +% This example calls the three mex functions defined in this folder. As you +% can see, you call them just like you would normal MATLAB functions. + +x = magic(3) +y = 2*magic(3) + +[out1, out2] = example_mex_function(x,y, 12345) + +z = example_mex_callback(x, @(a)a+a) + + +input = {} +input.val = 2 +input.stuff = 'some string' +output = example_mex_struct(input) + diff --git a/ml/dlib/dlib/matlab/example_mex_callback.cpp b/ml/dlib/dlib/matlab/example_mex_callback.cpp new file mode 100644 index 000000000..a5a25dda1 --- /dev/null +++ b/ml/dlib/dlib/matlab/example_mex_callback.cpp @@ -0,0 +1,52 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt + +#include "call_matlab.h" +#include "dlib/matrix.h" + +using namespace dlib; +using namespace std; + +/* + This mex function takes a MATLAB function handle, calls it, and + returns the results. + + For example, you can call this function in MATLAB like so: + A = magic(3) + y = example_mex_callback(A, @(x)x+x) + + This will result in y containing the value 2*A. +*/ + +void mex_function ( + const matrix& A, + const function_handle& f, + matrix& result +) +{ + // The f argument to this function is a function handle passed from MATLAB. To + // call it we use the following syntax: + call_matlab(f, A, returns(result)); + // This is equivalent to result = f(A). Therefore, the returns(variable) syntax + // is used to indicate which variables are outputs of the function. + + + + + // Another thing we can do is call MATLAB functions based on their string name + // rather than a function_handle. Here is an example of calling eigs(). + matrix m(2,2); + m = 1,2, + 3,4; + matrix v,d; + + // This is equivalent to [v,d] = eigs(m); + call_matlab("eigs", m, returns(v), returns(d)); + cout << "eigenvectors: \n" << v << endl; + cout << "eigenvalues: \n" << d << endl; +} + + + +// #including this brings in all the mex boiler plate needed by MATLAB. +#include "mex_wrapper.cpp" + diff --git a/ml/dlib/dlib/matlab/example_mex_class.cpp b/ml/dlib/dlib/matlab/example_mex_class.cpp new file mode 100644 index 000000000..b4242721b --- /dev/null +++ b/ml/dlib/dlib/matlab/example_mex_class.cpp @@ -0,0 +1,72 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + This mex file will create a MATLAB function called example_mex_class. If you call it + with no arguments it will output the MATLAB .m code to create a MATLAB wrapper class. + Paste that code into a .m file. Then you will be able to work with this C++ class + directly in MATLAB. +*/ + +#include +#include + + +using namespace std; +using namespace dlib; + +class example_class +{ +public: + + // The class must have a default constructor. It's also the only kind of constructor + // you can call from MATLAB. + example_class() + { + xx.set_size(3,2); + xx = 1; + } + + // The rest of the member functions that you want to bind have to return void and + // generally have the same syntax limitations as regular mex funcitons. + void do_stuff(const matrix_colmajor& x) + { + cout << "in do_stuff" << endl; + cout << x << endl; + xx = x; + } + + void do_other_stuff(int x) + { + cout << "in do_other_stuff" << endl; + cout << "x: " << x << endl; + } + + void print_state() + { + cout << xx << endl; + } + + // saveobj() and load_obj() are special functions. If you provide these then you will + // be able to save() and load() your objects using MATLAB's built in object + // serialization. + void saveobj(matrix_colmajor& state) + { + // save this object's state to state. + state = xx; + } + void load_obj(const matrix_colmajor& state) + { + xx = state; + } + +private: + matrix_colmajor xx; +}; + +// Just tell the mex wrapper the name of your class and list the methods you want to bind. +#define MEX_CLASS_NAME example_class +#define MEX_CLASS_METHODS do_stuff, do_other_stuff, print_state, saveobj, load_obj + + +#include "mex_wrapper.cpp" + + diff --git a/ml/dlib/dlib/matlab/example_mex_function.cpp b/ml/dlib/dlib/matlab/example_mex_function.cpp new file mode 100644 index 000000000..49d2e35fa --- /dev/null +++ b/ml/dlib/dlib/matlab/example_mex_function.cpp @@ -0,0 +1,84 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt + +#include "dlib/matrix.h" +using namespace dlib; +using namespace std; + + +/*! + This file defines a function callable from MATLAB once you mex it. + + It computes the same thing as the following MATLAB function: + + function [A, B] = example_mex_function(x, y, some_number) + A = x+y; + B = sum(sum(x+y)); + disp(['some_number: ' num2str(some_number)]) + end + + + VALID INPUT AND OUTPUT ARGUMENTS + The mex wrapper can handle the following kinds of input and output arguments: + - Types corresponding to a MATLAB matrix + - a dlib::matrix containing any kind of scalar value. + - a dlib::array2d containing any kind of scalar value. + - a dlib::vector containing any kind of scalar value. + - a dlib::point + - matrix_colmajor or fmatrix_colmajor + These are just typedefs for matrix containing double or float and using a + column major memory layout. However, they have the special distinction + of being fast to use in mex files since they sit directly on top of + MATLAB's built in matrices. That is, while other types of arguments copy + a MATLAB object into themselves, the matrix_colmajor and fmatrix_colmajor + do no such copy and are effectively zero overhead methods for working on + MATLAB's matrices. + + - RGB color images + - dlib::array2d can be used to represent + MATLAB uint8 MxNx3 images. + + - Types corresponding to a MATLAB scalar + - any kind of scalar value, e.g. double, int, etc. + + - Types corresponding to a MATLAB string + - std::string + + - Types corresponding to a MATLAB cell array + - a std::vector or dlib::array containing any of the above + types of objects or std::vector or dlib::array objects. + + - matlab_struct and matlab_object. These are special types defined in the + call_matlab.h file and correspond to matlab structs and arbitrary matlab + objects respectively. +!*/ + + +// You can also define default values for your input arguments. So +// here we say that if the user in MATLAB doesn't provide the "some_number" +// then it will get a value of 3.141. +#define ARG_5_DEFAULT 3.141 + +// Make a function named mex_function() and put your code inside it. +// Note that the return type should be void. Use non-const reference +// arguments to return outputs. Finally, mex_function() must have no +// more than 20 arguments. +void mex_function ( + const matrix_colmajor& x, + const matrix_colmajor& y, + matrix_colmajor& out1, + double& out2, + double some_number +) +{ + out1 = x + y; + out2 = sum(x+y); + + // we can also use cout to print things as usual: + cout << "some_number: "<< some_number << endl; +} + + + +// #including this brings in all the mex boiler plate needed by MATLAB. +#include "mex_wrapper.cpp" + diff --git a/ml/dlib/dlib/matlab/example_mex_struct.cpp b/ml/dlib/dlib/matlab/example_mex_struct.cpp new file mode 100644 index 000000000..a948fbff9 --- /dev/null +++ b/ml/dlib/dlib/matlab/example_mex_struct.cpp @@ -0,0 +1,55 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt + +#include "call_matlab.h" +#include "dlib/matrix.h" +using namespace dlib; +using namespace std; + + +/* + This mex function takes a MATLAB struct, prints a few of its fields, + and then returns a new struct. + + For example, you can call this function in MATLAB like so: + input = {} + input.val = 2 + input.stuff = 'some string' + output = example_mex_struct(input) + + output.number + output.number2 + output.sub.stuff + output.sub.some_matrix +*/ + + +void mex_function ( + const matlab_struct& input, + matlab_struct& output +) +{ + int val = input["val"]; + string stuff = input["stuff"]; + + if (input.has_field("val2")) + { + string val2 = input["val2"]; + cout << "The optional val2 field was set to: " << val2 << endl; + } + + cout << "val: "<< val << endl; + cout << "stuff: " << stuff << endl; + + output["number"] = 999; + + output["number2"] = 1000; + output["sub"]["stuff"] = "some other string"; + matrix m = randm(2,2); + output["sub"]["some_matrix"] = m; +} + + + +// #including this brings in all the mex boiler plate needed by MATLAB. +#include "mex_wrapper.cpp" + diff --git a/ml/dlib/dlib/matlab/mex_wrapper.cpp b/ml/dlib/dlib/matlab/mex_wrapper.cpp new file mode 100644 index 000000000..30c7e12ac --- /dev/null +++ b/ml/dlib/dlib/matlab/mex_wrapper.cpp @@ -0,0 +1,5144 @@ +// Copyright (C) 2012 Massachusetts Institute of Technology, Lincoln Laboratory +// License: Boost Software License See LICENSE.txt for the full license. +// Authors: Davis E. King (davis@dlib.net) +/* + READ THIS FIRST + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + \############/ + \##########/ + \########/ + \######/ + \####/ + \##/ + \/ + + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + See example_mex_function.cpp for a discussion of how to use the mex wrapper. + + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + /\ + /##\ + /####\ + /######\ + /########\ + /##########\ + /############\ + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + ###### + READ THIS FIRST +*/ + +// Copyright (C) 2012 Massachusetts Institute of Technology, Lincoln Laboratory +// License: Boost Software License See LICENSE.txt for the full license. +// Authors: Davis E. King (davis@dlib.net) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// BEGIN IMPLEMENTATION DETAILS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +#include "../matrix.h" +#include "../array2d.h" +#include "../array.h" +#include "../image_transforms.h" +#include "../is_kind.h" +#include "../string.h" +#include "../any.h" // for sig_traits +#include "../hash.h" +#include +#include + +#if defined(_MSC_VER) +#define DLL_EXPORT_SYM __declspec(dllexport) +#endif +#include "mex.h" +#include +#include "call_matlab.h" + +// ---------------------------------------------------------------------------------------- + +#ifdef ARG_1_DEFAULT +#define ELSE_ASSIGN_ARG_1 else A1 = ARG_1_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_1 +#endif + +#ifdef ARG_2_DEFAULT +#define ELSE_ASSIGN_ARG_2 else A2 = ARG_2_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_2 +#endif + +#ifdef ARG_3_DEFAULT +#define ELSE_ASSIGN_ARG_3 else A3 = ARG_3_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_3 +#endif + +#ifdef ARG_4_DEFAULT +#define ELSE_ASSIGN_ARG_4 else A4 = ARG_4_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_4 +#endif + +#ifdef ARG_5_DEFAULT +#define ELSE_ASSIGN_ARG_5 else A5 = ARG_5_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_5 +#endif + +#ifdef ARG_6_DEFAULT +#define ELSE_ASSIGN_ARG_6 else A6 = ARG_6_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_6 +#endif + +#ifdef ARG_7_DEFAULT +#define ELSE_ASSIGN_ARG_7 else A7 = ARG_7_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_7 +#endif + +#ifdef ARG_8_DEFAULT +#define ELSE_ASSIGN_ARG_8 else A8 = ARG_8_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_8 +#endif + +#ifdef ARG_9_DEFAULT +#define ELSE_ASSIGN_ARG_9 else A9 = ARG_9_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_9 +#endif + +#ifdef ARG_10_DEFAULT +#define ELSE_ASSIGN_ARG_10 else A10 = ARG_10_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_10 +#endif + +#ifdef ARG_11_DEFAULT +#define ELSE_ASSIGN_ARG_11 else A11 = ARG_11_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_11 +#endif + +#ifdef ARG_12_DEFAULT +#define ELSE_ASSIGN_ARG_12 else A12 = ARG_12_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_12 +#endif + +#ifdef ARG_13_DEFAULT +#define ELSE_ASSIGN_ARG_13 else A13 = ARG_13_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_13 +#endif + +#ifdef ARG_14_DEFAULT +#define ELSE_ASSIGN_ARG_14 else A14 = ARG_14_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_14 +#endif + +#ifdef ARG_15_DEFAULT +#define ELSE_ASSIGN_ARG_15 else A15 = ARG_15_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_15 +#endif + +#ifdef ARG_16_DEFAULT +#define ELSE_ASSIGN_ARG_16 else A16 = ARG_16_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_16 +#endif + +#ifdef ARG_17_DEFAULT +#define ELSE_ASSIGN_ARG_17 else A17 = ARG_17_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_17 +#endif + +#ifdef ARG_18_DEFAULT +#define ELSE_ASSIGN_ARG_18 else A18 = ARG_18_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_18 +#endif + +#ifdef ARG_19_DEFAULT +#define ELSE_ASSIGN_ARG_19 else A19 = ARG_19_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_19 +#endif + +#ifdef ARG_20_DEFAULT +#define ELSE_ASSIGN_ARG_20 else A20 = ARG_20_DEFAULT; +#else +#define ELSE_ASSIGN_ARG_20 +#endif + +// ---------------------------------------------------------------------------------------- + +namespace mex_binding +{ + using namespace dlib; + + template + struct is_input_type + { + const static unsigned long value = (!is_same_type::value && (!is_reference_type::value || is_const_type::value )) ? 1 : 0; + }; + template + struct is_output_type + { + const static unsigned long value = (!is_same_type::value && is_reference_type::value && !is_const_type::value) ? 1 : 0; + }; + + + template + struct funct_traits + { + const static unsigned long num_inputs = is_input_type::arg1_type>::value + + is_input_type::arg2_type>::value + + is_input_type::arg3_type>::value + + is_input_type::arg4_type>::value + + is_input_type::arg5_type>::value + + is_input_type::arg6_type>::value + + is_input_type::arg7_type>::value + + is_input_type::arg8_type>::value + + is_input_type::arg9_type>::value + + is_input_type::arg10_type>::value + + is_input_type::arg11_type>::value + + is_input_type::arg12_type>::value + + is_input_type::arg13_type>::value + + is_input_type::arg14_type>::value + + is_input_type::arg15_type>::value + + is_input_type::arg16_type>::value + + is_input_type::arg17_type>::value + + is_input_type::arg18_type>::value + + is_input_type::arg19_type>::value + + is_input_type::arg20_type>::value; + + const static unsigned long num_outputs= is_output_type::arg1_type>::value + + is_output_type::arg2_type>::value + + is_output_type::arg3_type>::value + + is_output_type::arg4_type>::value + + is_output_type::arg5_type>::value + + is_output_type::arg6_type>::value + + is_output_type::arg7_type>::value + + is_output_type::arg8_type>::value + + is_output_type::arg9_type>::value + + is_output_type::arg10_type>::value + + is_output_type::arg11_type>::value + + is_output_type::arg12_type>::value + + is_output_type::arg13_type>::value + + is_output_type::arg14_type>::value + + is_output_type::arg15_type>::value + + is_output_type::arg16_type>::value + + is_output_type::arg17_type>::value + + is_output_type::arg18_type>::value + + is_output_type::arg19_type>::value + + is_output_type::arg20_type>::value; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct is_array_type + { + // true if T is std::vector or dlib::array + const static bool value = is_std_vector::value || dlib::is_array::value; + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename enabled = void + > + struct inner_type + { + typedef T type; + }; + + template < typename T> + struct inner_type::value || is_array2d::value || dlib::is_array::value >::type> + { + typedef typename T::type type; + }; + + template < typename T> + struct inner_type >::type> + { + typedef typename T::value_type type; + }; + + +// ------------------------------------------------------- + + struct user_hit_ctrl_c {}; + +// ------------------------------------------------------- + + template + void validate_and_populate_arg ( + long arg_idx, + const mxArray *prhs, + T& arg + ); + +// ------------------------------------------------------- + + template + struct is_column_major_matrix : public default_is_kind_value {}; + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + struct is_column_major_matrix > + { static const bool value = true; }; + +// ------------------------------------------------------- + + string escape_percent(const string& str) + { + string temp; + for(auto c : str) + { + if (c != '%') + { + temp += c; + } + else + { + temp += c; + temp += c; + } + } + return temp; + } + + string escape_percent(const std::ostringstream& sout) + { + return escape_percent(sout.str()); + } + +// ------------------------------------------------------- + + template < + typename matrix_type + > + typename dlib::enable_if_c::value || is_array2d::value >::type + clear_mat ( + matrix_type& m + ) + { + m.set_size(0,0); + } + + template < + typename matrix_type + > + typename dlib::disable_if_c::value || is_array2d::value >::type + clear_mat ( + matrix_type& + ) + { + } + +// ------------------------------------------------------- + + template < + typename matrix_type, + typename EXP + > + typename dlib::enable_if_c::value && is_same_type::type,typename EXP::type>::value >::type + assign_mat ( + const long arg_idx, + matrix_type& m, + const matrix_exp& src + ) + { + if (matrix_type::NR != 0 && matrix_type::NR != src.nc()) + { + std::ostringstream sout; + sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NR << " rows but got one with " << src.nc(); + throw invalid_args_exception(sout); + } + if (matrix_type::NC != 0 && matrix_type::NC != src.nr()) + { + std::ostringstream sout; + sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NC << " columns but got one with " << src.nr(); + throw invalid_args_exception(sout); + } + + + m = trans(src); + } + + template < + typename matrix_type, + typename EXP + > + typename dlib::enable_if_c::value && is_same_type::type,typename EXP::type>::value >::type + assign_mat ( + const long arg_idx, + matrix_type& m, + const matrix_exp& src + ) + { + assign_image(m , trans(src)); + } + + template < + typename matrix_type, + typename EXP + > + typename disable_if_c<(is_array2d::value || is_matrix::value) && + is_same_type::type,typename EXP::type>::value >::type + assign_mat ( + const long arg_idx, + matrix_type& , + const matrix_exp& + ) + { + std::ostringstream sout; + sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; + throw invalid_args_exception(sout); + } + + +// ------------------------------------------------------- + + template < + typename T, + typename U + > + typename dlib::enable_if_c::value || is_same_type::value >::type + assign_scalar ( + const long arg_idx, + T& dest, + const U& src + ) + { + if (is_signed_type::value && src < 0 && is_unsigned_type::value) + { + std::ostringstream sout; + sout << "Error, input argument " << arg_idx+1 << " must be a non-negative number."; + throw invalid_args_exception(sout); + } + else + { + dest = src; + } + } + + template < + typename T, + typename U + > + typename dlib::disable_if_c::value || is_same_type::value >::type + assign_scalar ( + const long arg_idx, + T& , + const U& + ) + { + std::ostringstream sout; + sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; + throw invalid_args_exception(sout); + } + + +// ------------------------------------------------------- + + void assign_function_handle ( + const long arg_idx, + function_handle& dest, + const mxArray* src + ) + { + const_cast(dest.h) = (void*)src; + } + + template < + typename T + > + void assign_function_handle ( + const long arg_idx, + T& , + const mxArray* + ) + { + std::ostringstream sout; + sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; + throw invalid_args_exception(sout); + } + + +// ------------------------------------------------------- + + template < + typename T + > + typename dlib::enable_if >::type + assign_std_vector ( + const long arg_idx, + T& dest, + const mxArray* src + ) + { + const long nr = mxGetM(src); + const long nc = mxGetN(src); + + typedef typename inner_type::type type; + + if (!mxIsCell(src)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a cell array"; + throw invalid_args_exception(sout); + } + if (nr != 1 && nc != 1) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a cell array with exactly 1 row or 1 column (i.e. a row or column vector)"; + throw invalid_args_exception(sout); + } + + const long size = nr*nc; + dest.resize(size); + + for (unsigned long i = 0; i < dest.size(); ++i) + { + try + { + validate_and_populate_arg(i, mxGetCell(src, i), dest[i]); + } + catch (invalid_args_exception& e) + { + std::ostringstream sout; + sout << "Error in argument " << arg_idx+1 << ": element " << i+1 << " of cell array not the expected type.\n"; + sout << "\t" << e.what(); + throw invalid_args_exception(sout); + } + } + + } + + template < + typename T + > + typename disable_if >::type + assign_std_vector ( + const long arg_idx, + T& , + const mxArray* + ) + { + std::ostringstream sout; + sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; + throw invalid_args_exception(sout); + } + +// ------------------------------------------------------- + + template + void assign_image ( + const long arg_idx, + T&, + const dlib::uint8* data, + long nr, + long nc + ) + { + std::ostringstream sout; + sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; + throw invalid_args_exception(sout); + } + + template + void assign_image( + const long , + array2d& img, + const dlib::uint8* data, + long nr, + long nc + ) + { + img.set_size(nr, nc); + for (long c = 0; c < img.nc(); ++c) + for (long r = 0; r < img.nr(); ++r) + img[r][c].red = *data++; + for (long c = 0; c < img.nc(); ++c) + for (long r = 0; r < img.nr(); ++r) + img[r][c].green = *data++; + for (long c = 0; c < img.nc(); ++c) + for (long r = 0; r < img.nr(); ++r) + img[r][c].blue = *data++; + } + +// ------------------------------------------------------- + + template + void call_private_set_mxArray(T&, mxArray*) {} + void call_private_set_mxArray(matrix_colmajor& item, mxArray* m) { item._private_set_mxArray(m); } + void call_private_set_mxArray(fmatrix_colmajor& item, mxArray* m) { item._private_set_mxArray(m); } + +// ------------------------------------------------------- + + template + void validate_and_populate_arg ( + long arg_idx, + const mxArray *prhs, + T& arg + ) + { + using namespace mex_binding; + if (is_built_in_scalar_type::value || is_same_type::value) + { + if( !(mxIsDouble(prhs) || mxIsSingle(prhs) || mxIsLogical(prhs) ) || + mxIsComplex(prhs) || + mxGetNumberOfElements(prhs)!=1 ) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a scalar"; + throw invalid_args_exception(sout); + } + + assign_scalar(arg_idx, arg , mxGetScalar(prhs)); + } + else if (is_matrix::value || is_array2d::value) + { + if (prhs == NULL) + { + clear_mat(arg); + return; + } + + typedef typename inner_type::type type; + + const int num_dims = mxGetNumberOfDimensions(prhs); + const long nr = mxGetM(prhs); + const long nc = mxGetN(prhs); + + if (is_same_type::value) + { + if (!(num_dims == 3 && mxGetDimensions(prhs)[2] == 3 && mxIsUint8(prhs))) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a 3-D NxMx3 image matrix of uint8"; + throw invalid_args_exception(sout); + } + + const long rows = mxGetDimensions(prhs)[0]; + const long cols = mxGetDimensions(prhs)[1]; + assign_image(arg_idx, arg , (const dlib::uint8*)mxGetData(prhs), rows, cols); + return; + } + + if (num_dims != 2) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a 2-D matrix (got a " << num_dims << "-D matrix)"; + throw invalid_args_exception(sout); + } + + + if (is_same_type::value) + { + if (!mxIsDouble(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of doubles"; + throw invalid_args_exception(sout); + } + if (is_column_major_matrix::value) + call_private_set_mxArray(arg, (mxArray*)prhs); + else + assign_mat(arg_idx, arg , pointer_to_matrix(mxGetPr(prhs), nc, nr)); + } + else if (is_same_type::value) + { + if (!mxIsSingle(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of single/float"; + throw invalid_args_exception(sout); + } + + if (is_column_major_matrix::value) + call_private_set_mxArray(arg,(mxArray*)prhs); + else + assign_mat(arg_idx, arg , pointer_to_matrix((const float*)mxGetData(prhs), nc, nr)); + } + else if (is_same_type::value) + { + if (!mxIsLogical(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of logical elements."; + throw invalid_args_exception(sout); + } + DLIB_CASSERT(sizeof(mxLogical) == sizeof(bool),"logical matrices are not supported by the mex wrapper when mxLogical isn't a bool."); + + assign_mat(arg_idx, arg , pointer_to_matrix((const bool*)mxGetData(prhs), nc, nr)); + } + else if (is_same_type::value) + { + if (!mxIsUint8(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of uint8"; + throw invalid_args_exception(sout); + } + + assign_mat(arg_idx, arg , pointer_to_matrix((const dlib::uint8*)mxGetData(prhs), nc, nr)); + } + else if (is_same_type::value) + { + if (!mxIsInt8(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of int8"; + throw invalid_args_exception(sout); + } + + assign_mat(arg_idx, arg , pointer_to_matrix((const dlib::int8*)mxGetData(prhs), nc, nr)); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(short) == sizeof(dlib::int16))) + { + if (!mxIsInt16(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of int16"; + throw invalid_args_exception(sout); + } + + assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr)); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(unsigned short) == sizeof(dlib::uint16))) + { + if (!mxIsUint16(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of uint16"; + throw invalid_args_exception(sout); + } + + assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr)); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(int) == sizeof(dlib::int32)) || + (is_same_type::value && sizeof(long) == sizeof(dlib::int32))) + { + if (!mxIsInt32(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of int32"; + throw invalid_args_exception(sout); + } + + assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr)); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(unsigned int) == sizeof(dlib::uint32)) || + (is_same_type::value && sizeof(unsigned long) == sizeof(dlib::uint32))) + { + if (!mxIsUint32(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of uint32"; + throw invalid_args_exception(sout); + } + + assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr)); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(unsigned int) == sizeof(dlib::uint64)) || + (is_same_type::value && sizeof(unsigned long) == sizeof(dlib::uint64))) + { + if (!mxIsUint64(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of uint64"; + throw invalid_args_exception(sout); + } + + assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr)); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(int) == sizeof(dlib::int64)) || + (is_same_type::value && sizeof(long) == sizeof(dlib::int64))) + { + if (!mxIsInt64(prhs) || mxIsComplex(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a matrix of int64"; + throw invalid_args_exception(sout); + } + + assign_mat(arg_idx, arg , pointer_to_matrix((const type*)mxGetData(prhs), nc, nr)); + } + else + { + throw invalid_args_exception("mex_function uses unsupported matrix type"); + } + } + else if (is_array_type::value) + { + assign_std_vector(arg_idx, arg, prhs); + + } + else if (is_same_type::value) + { + if (!mxIsClass(prhs, "function_handle")) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a function handle."; + throw invalid_args_exception(sout); + } + assign_function_handle(arg_idx, arg, prhs); + } + else + { + throw invalid_args_exception("mex_function uses unsupported input argument type"); + } + } + + void validate_and_populate_arg( + long arg_idx, + const mxArray *prhs, + matlab_struct& arg + ) + { + if (!mxIsStruct(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a struct"; + throw invalid_args_exception(sout); + } + + arg.set_struct_handle(arg_idx, prhs); + } + + + void validate_and_populate_arg( + long arg_idx, + const mxArray *prhs, + matlab_object& arg + ) + { + arg.set_object_handle(arg_idx, prhs); + } + + + void validate_and_populate_arg( + long arg_idx, + const mxArray *prhs, + std::string& arg + ) + { + if (!mxIsChar(prhs)) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " must be a char string"; + throw invalid_args_exception(sout); + } + + const long nr = mxGetM(prhs); + const long nc = mxGetN(prhs); + const long size = nr*nc; + arg.resize(size+1); + if (mxGetString(prhs, &arg[0], arg.size())) + { + std::ostringstream sout; + sout << "Input argument " << arg_idx+1 << " encountered an error while calling mxGetString()"; + throw invalid_args_exception(sout); + } + arg.resize(size); + } + +// ---------------------------------------------------------------------------------------- + + template + typename dlib::enable_if >::type assign_image_to_matlab ( + dlib::uint8* mat, + const matrix_exp& item + ) + { + for (long c = 0; c < item.nc(); ++c) + for (long r = 0; r < item.nr(); ++r) + *mat++ = item(r,c).red; + for (long c = 0; c < item.nc(); ++c) + for (long r = 0; r < item.nr(); ++r) + *mat++ = item(r,c).green; + for (long c = 0; c < item.nc(); ++c) + for (long r = 0; r < item.nr(); ++r) + *mat++ = item(r,c).blue; + } + + template + typename disable_if >::type assign_image_to_matlab ( + T* mat, + const matrix_exp& + ) + { + mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", + "mex_function uses unsupported output image argument type"); + } + + template + typename dlib::enable_if >::type assign_to_matlab( + mxArray*& plhs, + const T& item + ) + { + typedef typename T::type type; + + type* mat = 0; + + if (is_same_type::value) + { + plhs = mxCreateDoubleMatrix(item.nr(), + item.nc(), + mxREAL); + + mat = (type*)mxGetPr(plhs); + } + else if (is_same_type::value ) + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxSINGLE_CLASS, + mxREAL); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value ) + { + plhs = mxCreateLogicalMatrix(item.nr(), + item.nc()); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value ) + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxUINT8_CLASS, + mxREAL); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value ) + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxINT8_CLASS, + mxREAL); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(short) == sizeof(dlib::int16))) + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxINT16_CLASS, + mxREAL); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(unsigned short) == sizeof(dlib::uint16))) + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxUINT16_CLASS, + mxREAL); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(long) == sizeof(dlib::int32))) + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxINT32_CLASS, + mxREAL); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(unsigned long) == sizeof(dlib::uint32))) + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxUINT32_CLASS, + mxREAL); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(unsigned long) == sizeof(dlib::uint64))) + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxUINT64_CLASS, + mxREAL); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value || + (is_same_type::value && sizeof(long) == sizeof(dlib::int64))) + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxINT64_CLASS, + mxREAL); + + mat = (type*)mxGetData(plhs); + } + else if (is_same_type::value) + { + mwSize dims[3] = {(mwSize)item.nr(), (mwSize)item.nc(), 3}; + plhs = mxCreateNumericArray(3, dims, mxUINT8_CLASS, mxREAL); + + assign_image_to_matlab((dlib::uint8*)mxGetData(plhs), item); + return; + } + else + { + mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", + "mex_function uses unsupported output argument type"); + } + + + const_temp_matrix m(item); + + for (long c = 0; c < m.nc(); ++c) + { + for ( long r = 0; r < m.nr(); ++r) + { + *mat++ = m(r,c); + } + } + } + + void assign_to_matlab( + mxArray*& plhs, + matrix_colmajor& item + ) + { + if(item._private_is_owned_by_matlab()) + { + // Don't need to do a copy if it's this kind of matrix since we can just + // pull the underlying mxArray out directly and thus avoid a copy. + plhs = item._private_release_mxArray(); + // If there isn't anything there because the matrix is empty then set it to an + // empty matrix. + if (!plhs) + plhs = mxCreateDoubleMatrix(item.nr(), + item.nc(), + mxREAL); + } + else + { + plhs = mxCreateDoubleMatrix(item.nr(), + item.nc(), + mxREAL); + if (item.size() != 0) + memcpy(mxGetPr(plhs), &item(0,0), item.size()*sizeof(double)); + } + } + + void assign_to_matlab( + mxArray*& plhs, + fmatrix_colmajor& item + ) + { + if(item._private_is_owned_by_matlab()) + { + // Don't need to do a copy if it's this kind of matrix since we can just + // pull the underlying mxArray out directly and thus avoid a copy. + plhs = item._private_release_mxArray(); + // If there isn't anything there because the matrix is empty then set it to an + // empty matrix. + if (!plhs) + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxSINGLE_CLASS, + mxREAL); + } + else + { + plhs = mxCreateNumericMatrix(item.nr(), + item.nc(), + mxSINGLE_CLASS, + mxREAL); + if (item.size() != 0) + memcpy(mxGetPr(plhs), &item(0,0), item.size()*sizeof(float)); + } + } + + void assign_to_matlab( + mxArray*& plhs, + matlab_struct& item + ) + { + plhs = (mxArray*)item.release_struct_to_matlab(); + } + + void assign_to_matlab( + mxArray*& plhs, + matlab_object& item + ) + { + plhs = (mxArray*)item.release_object_to_matlab(); + } + + void assign_to_matlab( + mxArray*& plhs, + const std::string& item + ) + { + plhs = mxCreateString(item.c_str()); + } + + template + void assign_to_matlab( + mxArray*& plhs, + const array2d& item + ) + { + assign_to_matlab(plhs,array_to_matrix(item)); + } + + template + typename dlib::disable_if_c::value || is_array_type::value || + is_same_type::value>::type assign_to_matlab( + mxArray*& plhs, + const T& item + ) + { + plhs = mxCreateDoubleScalar(item); + } + + + void assign_to_matlab ( + mxArray*& plhs, + const char* str + ) + { + assign_to_matlab(plhs, std::string(str)); + } + + void assign_to_matlab( + mxArray*& plhs, + const function_handle& h + ) + { + } + + template + typename dlib::enable_if >::type assign_to_matlab( + mxArray*& plhs, + const T& item + ) + { + mwSize dims[1] = {item.size()}; + plhs = mxCreateCellArray(1,dims); + for (unsigned long i = 0; i < item.size(); ++i) + { + mxArray* next = 0; + assign_to_matlab(next, item[i]); + mxSetCell(plhs, i, next); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void mark_owned_by_matlab (const T&){} + + void mark_owned_by_matlab(matrix_colmajor& item) { item._private_mark_owned_by_matlab(); } + void mark_owned_by_matlab(fmatrix_colmajor& item) { item._private_mark_owned_by_matlab(); } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long num_args + > + struct call_mex_function_helper; + + template <> + struct call_mex_function_helper<0> + { + template + void callit( + const funct& f, + int , mxArray **, + int , const mxArray ** + ) const + { + f(); + } + }; + + template <> + struct call_mex_function_helper<1> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + + typename basic_type::type A1; + + mark_owned_by_matlab(A1); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + + f(A1); + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + } + }; + + template <> + struct call_mex_function_helper<2> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + + typename basic_type::type A1; + typename basic_type::type A2; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + + f(A1,A2); + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + } + }; + + template <> + struct call_mex_function_helper<3> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + + f(A1,A2,A3); + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + } + }; + + template <> + struct call_mex_function_helper<4> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + + f(A1,A2,A3,A4); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + } + }; + + template <> + struct call_mex_function_helper<5> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + + f(A1,A2,A3,A4,A5); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + } + }; + + + template <> + struct call_mex_function_helper<6> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + + f(A1,A2,A3,A4,A5,A6); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + } + }; + + + template <> + struct call_mex_function_helper<7> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + + f(A1,A2,A3,A4,A5,A6,A7); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + } + }; + + + template <> + struct call_mex_function_helper<8> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + + f(A1,A2,A3,A4,A5,A6,A7,A8); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + } + }; + + + template <> + struct call_mex_function_helper<9> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + } + }; + + + + template <> + struct call_mex_function_helper<10> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + } + }; + + template <> + struct call_mex_function_helper<11> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + } + }; + + template <> + struct call_mex_function_helper<12> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + typedef typename sig_traits::arg12_type arg12_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + typename basic_type::type A12; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + mark_owned_by_matlab(A12); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A12); ++i;} + } + }; + + template <> + struct call_mex_function_helper<13> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + typedef typename sig_traits::arg12_type arg12_type; + typedef typename sig_traits::arg13_type arg13_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + typename basic_type::type A12; + typename basic_type::type A13; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + mark_owned_by_matlab(A12); + mark_owned_by_matlab(A13); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A12); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A13); ++i;} + } + }; + + template <> + struct call_mex_function_helper<14> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + typedef typename sig_traits::arg12_type arg12_type; + typedef typename sig_traits::arg13_type arg13_type; + typedef typename sig_traits::arg14_type arg14_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + typename basic_type::type A12; + typename basic_type::type A13; + typename basic_type::type A14; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + mark_owned_by_matlab(A12); + mark_owned_by_matlab(A13); + mark_owned_by_matlab(A14); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A12); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A13); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A14); ++i;} + } + }; + + template <> + struct call_mex_function_helper<15> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + typedef typename sig_traits::arg12_type arg12_type; + typedef typename sig_traits::arg13_type arg13_type; + typedef typename sig_traits::arg14_type arg14_type; + typedef typename sig_traits::arg15_type arg15_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + typename basic_type::type A12; + typename basic_type::type A13; + typename basic_type::type A14; + typename basic_type::type A15; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + mark_owned_by_matlab(A12); + mark_owned_by_matlab(A13); + mark_owned_by_matlab(A14); + mark_owned_by_matlab(A15); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A12); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A13); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A14); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A15); ++i;} + } + }; + + template <> + struct call_mex_function_helper<16> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + typedef typename sig_traits::arg12_type arg12_type; + typedef typename sig_traits::arg13_type arg13_type; + typedef typename sig_traits::arg14_type arg14_type; + typedef typename sig_traits::arg15_type arg15_type; + typedef typename sig_traits::arg16_type arg16_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + typename basic_type::type A12; + typename basic_type::type A13; + typename basic_type::type A14; + typename basic_type::type A15; + typename basic_type::type A16; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + mark_owned_by_matlab(A12); + mark_owned_by_matlab(A13); + mark_owned_by_matlab(A14); + mark_owned_by_matlab(A15); + mark_owned_by_matlab(A16); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A12); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A13); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A14); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A15); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A16); ++i;} + } + }; + + template <> + struct call_mex_function_helper<17> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + typedef typename sig_traits::arg12_type arg12_type; + typedef typename sig_traits::arg13_type arg13_type; + typedef typename sig_traits::arg14_type arg14_type; + typedef typename sig_traits::arg15_type arg15_type; + typedef typename sig_traits::arg16_type arg16_type; + typedef typename sig_traits::arg17_type arg17_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + typename basic_type::type A12; + typename basic_type::type A13; + typename basic_type::type A14; + typename basic_type::type A15; + typename basic_type::type A16; + typename basic_type::type A17; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + mark_owned_by_matlab(A12); + mark_owned_by_matlab(A13); + mark_owned_by_matlab(A14); + mark_owned_by_matlab(A15); + mark_owned_by_matlab(A16); + mark_owned_by_matlab(A17); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A17); ++i;} ELSE_ASSIGN_ARG_17; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A12); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A13); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A14); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A15); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A16); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A17); ++i;} + } + }; + + template <> + struct call_mex_function_helper<18> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + typedef typename sig_traits::arg12_type arg12_type; + typedef typename sig_traits::arg13_type arg13_type; + typedef typename sig_traits::arg14_type arg14_type; + typedef typename sig_traits::arg15_type arg15_type; + typedef typename sig_traits::arg16_type arg16_type; + typedef typename sig_traits::arg17_type arg17_type; + typedef typename sig_traits::arg18_type arg18_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + typename basic_type::type A12; + typename basic_type::type A13; + typename basic_type::type A14; + typename basic_type::type A15; + typename basic_type::type A16; + typename basic_type::type A17; + typename basic_type::type A18; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + mark_owned_by_matlab(A12); + mark_owned_by_matlab(A13); + mark_owned_by_matlab(A14); + mark_owned_by_matlab(A15); + mark_owned_by_matlab(A16); + mark_owned_by_matlab(A17); + mark_owned_by_matlab(A18); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A17); ++i;} ELSE_ASSIGN_ARG_17; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A18); ++i;} ELSE_ASSIGN_ARG_18; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A12); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A13); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A14); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A15); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A16); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A17); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A18); ++i;} + } + }; + + template <> + struct call_mex_function_helper<19> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + typedef typename sig_traits::arg12_type arg12_type; + typedef typename sig_traits::arg13_type arg13_type; + typedef typename sig_traits::arg14_type arg14_type; + typedef typename sig_traits::arg15_type arg15_type; + typedef typename sig_traits::arg16_type arg16_type; + typedef typename sig_traits::arg17_type arg17_type; + typedef typename sig_traits::arg18_type arg18_type; + typedef typename sig_traits::arg19_type arg19_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + typename basic_type::type A12; + typename basic_type::type A13; + typename basic_type::type A14; + typename basic_type::type A15; + typename basic_type::type A16; + typename basic_type::type A17; + typename basic_type::type A18; + typename basic_type::type A19; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + mark_owned_by_matlab(A12); + mark_owned_by_matlab(A13); + mark_owned_by_matlab(A14); + mark_owned_by_matlab(A15); + mark_owned_by_matlab(A16); + mark_owned_by_matlab(A17); + mark_owned_by_matlab(A18); + mark_owned_by_matlab(A19); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A17); ++i;} ELSE_ASSIGN_ARG_17; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A18); ++i;} ELSE_ASSIGN_ARG_18; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A19); ++i;} ELSE_ASSIGN_ARG_19; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18,A19); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A12); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A13); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A14); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A15); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A16); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A17); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A18); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A19); ++i;} + } + }; + + template <> + struct call_mex_function_helper<20> + { + template + void callit( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) const + { + typedef typename sig_traits::arg1_type arg1_type; + typedef typename sig_traits::arg2_type arg2_type; + typedef typename sig_traits::arg3_type arg3_type; + typedef typename sig_traits::arg4_type arg4_type; + typedef typename sig_traits::arg5_type arg5_type; + typedef typename sig_traits::arg6_type arg6_type; + typedef typename sig_traits::arg7_type arg7_type; + typedef typename sig_traits::arg8_type arg8_type; + typedef typename sig_traits::arg9_type arg9_type; + typedef typename sig_traits::arg10_type arg10_type; + typedef typename sig_traits::arg11_type arg11_type; + typedef typename sig_traits::arg12_type arg12_type; + typedef typename sig_traits::arg13_type arg13_type; + typedef typename sig_traits::arg14_type arg14_type; + typedef typename sig_traits::arg15_type arg15_type; + typedef typename sig_traits::arg16_type arg16_type; + typedef typename sig_traits::arg17_type arg17_type; + typedef typename sig_traits::arg18_type arg18_type; + typedef typename sig_traits::arg19_type arg19_type; + typedef typename sig_traits::arg20_type arg20_type; + + typename basic_type::type A1; + typename basic_type::type A2; + typename basic_type::type A3; + typename basic_type::type A4; + typename basic_type::type A5; + typename basic_type::type A6; + typename basic_type::type A7; + typename basic_type::type A8; + typename basic_type::type A9; + typename basic_type::type A10; + typename basic_type::type A11; + typename basic_type::type A12; + typename basic_type::type A13; + typename basic_type::type A14; + typename basic_type::type A15; + typename basic_type::type A16; + typename basic_type::type A17; + typename basic_type::type A18; + typename basic_type::type A19; + typename basic_type::type A20; + + mark_owned_by_matlab(A1); + mark_owned_by_matlab(A2); + mark_owned_by_matlab(A3); + mark_owned_by_matlab(A4); + mark_owned_by_matlab(A5); + mark_owned_by_matlab(A6); + mark_owned_by_matlab(A7); + mark_owned_by_matlab(A8); + mark_owned_by_matlab(A9); + mark_owned_by_matlab(A10); + mark_owned_by_matlab(A11); + mark_owned_by_matlab(A12); + mark_owned_by_matlab(A13); + mark_owned_by_matlab(A14); + mark_owned_by_matlab(A15); + mark_owned_by_matlab(A16); + mark_owned_by_matlab(A17); + mark_owned_by_matlab(A18); + mark_owned_by_matlab(A19); + mark_owned_by_matlab(A20); + + int i = 0; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A3); ++i;} ELSE_ASSIGN_ARG_3; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A4); ++i;} ELSE_ASSIGN_ARG_4; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A5); ++i;} ELSE_ASSIGN_ARG_5; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A6); ++i;} ELSE_ASSIGN_ARG_6; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A7); ++i;} ELSE_ASSIGN_ARG_7; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A8); ++i;} ELSE_ASSIGN_ARG_8; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A9); ++i;} ELSE_ASSIGN_ARG_9; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A10); ++i;} ELSE_ASSIGN_ARG_10; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A11); ++i;} ELSE_ASSIGN_ARG_11; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A12); ++i;} ELSE_ASSIGN_ARG_12; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A13); ++i;} ELSE_ASSIGN_ARG_13; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A14); ++i;} ELSE_ASSIGN_ARG_14; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A15); ++i;} ELSE_ASSIGN_ARG_15; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A16); ++i;} ELSE_ASSIGN_ARG_16; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A17); ++i;} ELSE_ASSIGN_ARG_17; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A18); ++i;} ELSE_ASSIGN_ARG_18; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A19); ++i;} ELSE_ASSIGN_ARG_19; + if (i < nrhs && is_input_type::value) {validate_and_populate_arg(i,prhs[i],A20); ++i;} ELSE_ASSIGN_ARG_20; + + f(A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,A16,A17,A18,A19,A20); + + + i = 0; + if (is_output_type::value) {assign_to_matlab(plhs[i],A1); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A2); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A3); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A4); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A5); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A6); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A7); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A8); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A9); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A10); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A11); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A12); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A13); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A14); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A15); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A16); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A17); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A18); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A19); ++i;} + if (is_output_type::value) {assign_to_matlab(plhs[i],A20); ++i;} + } + }; + +// ---------------------------------------------------------------------------------------- + + template struct is_matlab_object { const static bool value = false; }; + template <> struct is_matlab_object { const static bool value = true; }; + template <> struct is_matlab_object { const static bool value = true; }; + template <> struct is_matlab_object { const static bool value = true; }; + template <> struct is_matlab_object { const static bool value = true; }; + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + void call_mex_function ( + const funct& f, + int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[] + ) + { + const long expected_nrhs = funct_traits::num_inputs; + const long expected_nlhs = funct_traits::num_outputs; + const long expected_args = expected_nrhs + expected_nlhs; + + long defaulted_args = 0; + + #ifdef ARG_1_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg1_type>::value); + #ifndef ARG_2_DEFAULT + // You can't define a default for argument 1 if you don't define one for argument 2 also. + COMPILE_TIME_ASSERT(expected_args < 2); + #endif + COMPILE_TIME_ASSERT(1 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_2_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg2_type>::value); + #ifndef ARG_3_DEFAULT + // You can't define a default for argument 2 if you don't define one for argument 3 also. + COMPILE_TIME_ASSERT(expected_args < 3); + #endif + COMPILE_TIME_ASSERT(2 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_3_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg3_type>::value); + #ifndef ARG_4_DEFAULT + // You can't define a default for argument 3 if you don't define one for argument 4 also. + COMPILE_TIME_ASSERT(expected_args < 4); + #endif + COMPILE_TIME_ASSERT(3 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_4_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg4_type>::value); + #ifndef ARG_5_DEFAULT + // You can't define a default for argument 4 if you don't define one for argument 5 also. + COMPILE_TIME_ASSERT(expected_args < 5); + #endif + COMPILE_TIME_ASSERT(4 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_5_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg5_type>::value); + #ifndef ARG_6_DEFAULT + // You can't define a default for argument 5 if you don't define one for argument 6 also. + COMPILE_TIME_ASSERT(expected_args < 6); + #endif + COMPILE_TIME_ASSERT(5 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_6_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg6_type>::value); + #ifndef ARG_7_DEFAULT + // You can't define a default for argument 6 if you don't define one for argument 7 also. + COMPILE_TIME_ASSERT(expected_args < 7); + #endif + COMPILE_TIME_ASSERT(6 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_7_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg7_type>::value); + #ifndef ARG_8_DEFAULT + // You can't define a default for argument 7 if you don't define one for argument 8 also. + COMPILE_TIME_ASSERT(expected_args < 8); + #endif + COMPILE_TIME_ASSERT(7 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_8_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg8_type>::value); + #ifndef ARG_9_DEFAULT + // You can't define a default for argument 8 if you don't define one for argument 9 also. + COMPILE_TIME_ASSERT(expected_args < 9); + #endif + COMPILE_TIME_ASSERT(8 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_9_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg9_type>::value); + #ifndef ARG_10_DEFAULT + // You can't define a default for argument 9 if you don't define one for argument 10 also. + COMPILE_TIME_ASSERT(expected_args < 10); + #endif + COMPILE_TIME_ASSERT(9 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_10_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg10_type>::value); + #ifndef ARG_11_DEFAULT + // You can't define a default for argument 10 if you don't define one for argument 11 also. + COMPILE_TIME_ASSERT(expected_args < 11); + #endif + COMPILE_TIME_ASSERT(10 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_11_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg11_type>::value); + #ifndef ARG_12_DEFAULT + // You can't define a default for argument 11 if you don't define one for argument 12 also. + COMPILE_TIME_ASSERT(expected_args < 12); + #endif + COMPILE_TIME_ASSERT(11 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_12_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg12_type>::value); + #ifndef ARG_13_DEFAULT + // You can't define a default for argument 12 if you don't define one for argument 13 also. + COMPILE_TIME_ASSERT(expected_args < 13); + #endif + COMPILE_TIME_ASSERT(12 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_13_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg13_type>::value); + #ifndef ARG_14_DEFAULT + // You can't define a default for argument 13 if you don't define one for argument 14 also. + COMPILE_TIME_ASSERT(expected_args < 14); + #endif + COMPILE_TIME_ASSERT(13 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_14_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg14_type>::value); + #ifndef ARG_15_DEFAULT + // You can't define a default for argument 14 if you don't define one for argument 15 also. + COMPILE_TIME_ASSERT(expected_args < 15); + #endif + COMPILE_TIME_ASSERT(14 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_15_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg15_type>::value); + #ifndef ARG_16_DEFAULT + // You can't define a default for argument 15 if you don't define one for argument 16 also. + COMPILE_TIME_ASSERT(expected_args < 16); + #endif + COMPILE_TIME_ASSERT(15 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_16_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg16_type>::value); + #ifndef ARG_17_DEFAULT + // You can't define a default for argument 16 if you don't define one for argument 17 also. + COMPILE_TIME_ASSERT(expected_args < 17); + #endif + COMPILE_TIME_ASSERT(16 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_17_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg17_type>::value); + #ifndef ARG_18_DEFAULT + // You can't define a default for argument 17 if you don't define one for argument 18 also. + COMPILE_TIME_ASSERT(expected_args < 18); + #endif + COMPILE_TIME_ASSERT(17 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_18_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg18_type>::value); + #ifndef ARG_19_DEFAULT + // You can't define a default for argument 18 if you don't define one for argument 19 also. + COMPILE_TIME_ASSERT(expected_args < 19); + #endif + COMPILE_TIME_ASSERT(18 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_19_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg19_type>::value); + #ifndef ARG_20_DEFAULT + // You can't define a default for argument 19 if you don't define one for argument 20 also. + COMPILE_TIME_ASSERT(expected_args < 20); + #endif + COMPILE_TIME_ASSERT(19 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + #ifdef ARG_20_DEFAULT + ++defaulted_args; + // You can only set an argument's default value if it is an input argument. + COMPILE_TIME_ASSERT(is_input_type::arg20_type>::value); + COMPILE_TIME_ASSERT(20 <= expected_args); // You can't define a default for an argument that doesn't exist. + #endif + + + // Arguments with type matlab_object are optional in both input and output. + int num_optional_inputs = 0; + int num_optional_outputs = 0; + if (is_matlab_object::arg20_type>::value) if (is_input_type::arg20_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg19_type>::value) if (is_input_type::arg19_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg18_type>::value) if (is_input_type::arg18_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg17_type>::value) if (is_input_type::arg17_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg16_type>::value) if (is_input_type::arg16_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg15_type>::value) if (is_input_type::arg15_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg14_type>::value) if (is_input_type::arg14_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg13_type>::value) if (is_input_type::arg13_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg12_type>::value) if (is_input_type::arg12_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg11_type>::value) if (is_input_type::arg11_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg10_type>::value) if (is_input_type::arg10_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg9_type>::value) if (is_input_type::arg9_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg8_type>::value) if (is_input_type::arg8_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg7_type>::value) if (is_input_type::arg7_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg6_type>::value) if (is_input_type::arg6_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg5_type>::value) if (is_input_type::arg5_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg4_type>::value) if (is_input_type::arg4_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg3_type>::value) if (is_input_type::arg3_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg2_type>::value) if (is_input_type::arg2_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + if (is_matlab_object::arg1_type>::value) if (is_input_type::arg1_type>::value) ++num_optional_inputs; else ++num_optional_outputs; + + + /* check for proper number of arguments */ + if(nrhs > expected_nrhs || nrhs < expected_nrhs - defaulted_args - num_optional_inputs) + { + std::ostringstream sout; + sout << "Expected between " << expected_nrhs-defaulted_args - num_optional_inputs + << " and " << expected_nrhs << " input arguments, got " << nrhs << "."; + + mexErrMsgIdAndTxt("mex_function:nrhs", + escape_percent(sout).c_str()); + } + + if (nlhs > expected_nlhs) + { + std::ostringstream sout; + sout << "Expected at most " << expected_nlhs << " output arguments, got " << nlhs << "."; + + mexErrMsgIdAndTxt("mex_function:nlhs", + escape_percent(sout).c_str()); + } + + call_mex_function_helper::num_args> helper; + helper.callit(f, nlhs, plhs, nrhs, prhs); + + } + +// ---------------------------------------------------------------------------------------- + + class mex_streambuf : public std::streambuf + { + + public: + mex_streambuf ( + ) + { + buf.resize(1000); + setp(&buf[0], &buf[0] + buf.size()-2); + + // make cout send data to mex_streambuf + oldbuf = std::cout.rdbuf(this); + } + + ~mex_streambuf() + { + // put cout back to the way we found it before running our mex function. + std::cout.rdbuf(oldbuf); + } + + + protected: + + + int sync ( + ) + { + int num = static_cast(pptr()-pbase()); + if (num != 0) + { + check_for_matlab_ctrl_c(); + + buf[num] = 0; // null terminate the string + mexPrintf("%s",&buf[0]); + mexEvalString("drawnow"); // flush print to screen + pbump(-num); + } + return 0; + } + + int_type overflow ( + int_type c + ) + { + if (c != EOF) + { + *pptr() = c; + pbump(1); + } + sync(); + return c; + } + + private: + std::vector buf; + std::streambuf* oldbuf; + + }; + + class mex_warn_streambuf : public std::streambuf + { + + public: + mex_warn_streambuf ( + ) + { + buf.resize(1000); + setp(&buf[0], &buf[0] + buf.size()-2); + + // make cout send data to mex_warn_streambuf + oldbuf = std::cerr.rdbuf(this); + } + + ~mex_warn_streambuf() + { + // put cerr back to the way we found it before running our mex function. + std::cerr.rdbuf(oldbuf); + } + + protected: + + + int sync ( + ) + { + int num = static_cast(pptr()-pbase()); + if (num != 0) + { + check_for_matlab_ctrl_c(); + + buf[num] = 0; // null terminate the string + mexWarnMsgTxt(&buf[0]); + mexEvalString("drawnow"); // flush print to screen + pbump(-num); + } + return 0; + } + + int_type overflow ( + int_type c + ) + { + if (c != EOF) + { + *pptr() = c; + pbump(1); + } + sync(); + return c; + } + + private: + std::vector buf; + std::streambuf* oldbuf; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + void setup_input_args ( + mxArray*& array, + const T& item, + int& nrhs + ) + { + assign_to_matlab(array, item); + ++nrhs; + } + + void setup_input_args ( + mxArray*& array, + const function_handle& item, + int& nrhs + ) + { + array = static_cast(item.h); + ++nrhs; + } + + template + void setup_input_args ( + mxArray*& array, + const output_decorator& item, + int& nrhs + ) + { + } + + template + void setup_output_args ( + const std::string& function_name, + mxArray* array, + const T& item, + int& nrhs + ) + { + } + + template + void setup_output_args ( + const std::string& function_name, + mxArray* array, + const output_decorator& item, + int& i + ) + { + try + { + validate_and_populate_arg(i,array,const_cast(item.item)); + ++i; + } + catch (invalid_args_exception& e) + { + throw dlib::error("Error occurred calling MATLAB function '" + function_name + "' from mex file. \n" + "The MATLAB function didn't return what we expected it to. \nIn particular, return" + string(e.what())); + } + } + + void call_matlab_for_real ( + int nlhs, + mxArray* plhs[], + int nrhs, + mxArray* prhs[], + const std::string& function_name + ) + { + int status = mexCallMATLAB(nlhs, plhs, nrhs, prhs, function_name.c_str()); + if (status) + { + throw dlib::error("Error, an exception was thrown when we tried to call the MATLAB function '" + function_name + "'."); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + void call_matlab ( + const std::string& function_name + ) + { + using namespace mex_binding; + + call_matlab_for_real(0,NULL,0,NULL, function_name); + } + + template + void free_callback_resources ( + int nlhs, + mxArray* plhs[], + int nrhs, + mxArray* prhs[] + ) + { + // free resources + for (int i = 0; i < nlhs; ++i) + mxDestroyArray(plhs[i]); + + for (int i = 0; i < nrhs; ++i) + { + // don't call mxDestroyArray() on function handles (which should only ever be in prhs[0]) + if (i == 0 && dlib::is_same_type::value) + continue; + mxDestroyArray(prhs[i]); + } + } + + template < + typename T1 + > + void call_matlab ( + const std::string& function_name, + const T1& A1 + ) + { + using namespace mex_binding; + const int num_args = 1; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, + typename T2 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2 + ) + { + using namespace mex_binding; + const int num_args = 2; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, + typename T2, + typename T3 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3 + ) + { + using namespace mex_binding; + const int num_args = 3; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + + template < + typename T1, + typename T2, + typename T3, + typename T4 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4 + ) + { + using namespace mex_binding; + const int num_args = 4; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5 + ) + { + using namespace mex_binding; + const int num_args = 5; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6 + ) + { + using namespace mex_binding; + const int num_args = 6; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7 + ) + { + using namespace mex_binding; + const int num_args = 7; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8 + ) + { + using namespace mex_binding; + const int num_args = 8; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8, + const T9& A9 + ) + { + using namespace mex_binding; + const int num_args = 9; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9, + typename T10 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8, + const T9& A9, + const T10& A10 + ) + { + using namespace mex_binding; + const int num_args = 10; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9, + typename T10, + typename T11 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8, + const T9& A9, + const T10& A10, + const T11& A11 + ) + { + using namespace mex_binding; + const int num_args = 11; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9, + typename T10, + typename T11, + typename T12 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8, + const T9& A9, + const T10& A10, + const T11& A11, + const T12& A12 + ) + { + using namespace mex_binding; + const int num_args = 12; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + setup_input_args(prhs[nrhs], A12, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + setup_output_args(function_name, plhs[i], A12, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9, + typename T10, + typename T11, + typename T12, + typename T13 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8, + const T9& A9, + const T10& A10, + const T11& A11, + const T12& A12, + const T13& A13 + ) + { + using namespace mex_binding; + const int num_args = 13; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + setup_input_args(prhs[nrhs], A12, nrhs); + setup_input_args(prhs[nrhs], A13, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + setup_output_args(function_name, plhs[i], A12, i); + setup_output_args(function_name, plhs[i], A13, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9, + typename T10, + typename T11, + typename T12, + typename T13, + typename T14 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8, + const T9& A9, + const T10& A10, + const T11& A11, + const T12& A12, + const T13& A13, + const T14& A14 + ) + { + using namespace mex_binding; + const int num_args = 14; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + setup_input_args(prhs[nrhs], A12, nrhs); + setup_input_args(prhs[nrhs], A13, nrhs); + setup_input_args(prhs[nrhs], A14, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + setup_output_args(function_name, plhs[i], A12, i); + setup_output_args(function_name, plhs[i], A13, i); + setup_output_args(function_name, plhs[i], A14, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9, + typename T10, + typename T11, + typename T12, + typename T13, + typename T14, + typename T15 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8, + const T9& A9, + const T10& A10, + const T11& A11, + const T12& A12, + const T13& A13, + const T14& A14, + const T15& A15 + ) + { + using namespace mex_binding; + const int num_args = 15; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + setup_input_args(prhs[nrhs], A12, nrhs); + setup_input_args(prhs[nrhs], A13, nrhs); + setup_input_args(prhs[nrhs], A14, nrhs); + setup_input_args(prhs[nrhs], A15, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + setup_output_args(function_name, plhs[i], A12, i); + setup_output_args(function_name, plhs[i], A13, i); + setup_output_args(function_name, plhs[i], A14, i); + setup_output_args(function_name, plhs[i], A15, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9, + typename T10, + typename T11, + typename T12, + typename T13, + typename T14, + typename T15, + typename T16 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8, + const T9& A9, + const T10& A10, + const T11& A11, + const T12& A12, + const T13& A13, + const T14& A14, + const T15& A15, + const T16& A16 + ) + { + using namespace mex_binding; + const int num_args = 16; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + setup_input_args(prhs[nrhs], A12, nrhs); + setup_input_args(prhs[nrhs], A13, nrhs); + setup_input_args(prhs[nrhs], A14, nrhs); + setup_input_args(prhs[nrhs], A15, nrhs); + setup_input_args(prhs[nrhs], A16, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + setup_output_args(function_name, plhs[i], A12, i); + setup_output_args(function_name, plhs[i], A13, i); + setup_output_args(function_name, plhs[i], A14, i); + setup_output_args(function_name, plhs[i], A15, i); + setup_output_args(function_name, plhs[i], A16, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9, + typename T10, + typename T11, + typename T12, + typename T13, + typename T14, + typename T15, + typename T16, + typename T17 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, + const T2& A2, + const T3& A3, + const T4& A4, + const T5& A5, + const T6& A6, + const T7& A7, + const T8& A8, + const T9& A9, + const T10& A10, + const T11& A11, + const T12& A12, + const T13& A13, + const T14& A14, + const T15& A15, + const T16& A16, + const T17& A17 + ) + { + using namespace mex_binding; + const int num_args = 17; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + setup_input_args(prhs[nrhs], A12, nrhs); + setup_input_args(prhs[nrhs], A13, nrhs); + setup_input_args(prhs[nrhs], A14, nrhs); + setup_input_args(prhs[nrhs], A15, nrhs); + setup_input_args(prhs[nrhs], A16, nrhs); + setup_input_args(prhs[nrhs], A17, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + setup_output_args(function_name, plhs[i], A12, i); + setup_output_args(function_name, plhs[i], A13, i); + setup_output_args(function_name, plhs[i], A14, i); + setup_output_args(function_name, plhs[i], A15, i); + setup_output_args(function_name, plhs[i], A16, i); + setup_output_args(function_name, plhs[i], A17, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, + typename T7, typename T8, typename T9, typename T10, typename T11, typename T12, + typename T13, typename T14, typename T15, typename T16, typename T17, typename T18 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const + T12& A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const + T17& A17, const T18& A18 + ) + { + using namespace mex_binding; + const int num_args = 18; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + setup_input_args(prhs[nrhs], A12, nrhs); + setup_input_args(prhs[nrhs], A13, nrhs); + setup_input_args(prhs[nrhs], A14, nrhs); + setup_input_args(prhs[nrhs], A15, nrhs); + setup_input_args(prhs[nrhs], A16, nrhs); + setup_input_args(prhs[nrhs], A17, nrhs); + setup_input_args(prhs[nrhs], A18, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + setup_output_args(function_name, plhs[i], A12, i); + setup_output_args(function_name, plhs[i], A13, i); + setup_output_args(function_name, plhs[i], A14, i); + setup_output_args(function_name, plhs[i], A15, i); + setup_output_args(function_name, plhs[i], A16, i); + setup_output_args(function_name, plhs[i], A17, i); + setup_output_args(function_name, plhs[i], A18, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, + typename T7, typename T8, typename T9, typename T10, typename T11, typename T12, + typename T13, typename T14, typename T15, typename T16, typename T17, typename T18, + typename T19 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const + T12& A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const + T17& A17, const T18& A18, const T19& A19 + ) + { + using namespace mex_binding; + const int num_args = 19; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + setup_input_args(prhs[nrhs], A12, nrhs); + setup_input_args(prhs[nrhs], A13, nrhs); + setup_input_args(prhs[nrhs], A14, nrhs); + setup_input_args(prhs[nrhs], A15, nrhs); + setup_input_args(prhs[nrhs], A16, nrhs); + setup_input_args(prhs[nrhs], A17, nrhs); + setup_input_args(prhs[nrhs], A18, nrhs); + setup_input_args(prhs[nrhs], A19, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + setup_output_args(function_name, plhs[i], A12, i); + setup_output_args(function_name, plhs[i], A13, i); + setup_output_args(function_name, plhs[i], A14, i); + setup_output_args(function_name, plhs[i], A15, i); + setup_output_args(function_name, plhs[i], A16, i); + setup_output_args(function_name, plhs[i], A17, i); + setup_output_args(function_name, plhs[i], A18, i); + setup_output_args(function_name, plhs[i], A19, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + + template < + typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, + typename T7, typename T8, typename T9, typename T10, typename T11, typename T12, + typename T13, typename T14, typename T15, typename T16, typename T17, typename T18, + typename T19, typename T20 + > + void call_matlab ( + const std::string& function_name, + const T1& A1, const T2& A2, const T3& A3, const T4& A4, const T5& A5, const T6& A6, + const T7& A7, const T8& A8, const T9& A9, const T10& A10, const T11& A11, const + T12& A12, const T13& A13, const T14& A14, const T15& A15, const T16& A16, const + T17& A17, const T18& A18, const T19& A19, const T20& A20 + ) + { + using namespace mex_binding; + const int num_args = 20; + mxArray* plhs[num_args] = {0}; + mxArray* prhs[num_args] = {0}; + + int nrhs = 0; + setup_input_args(prhs[nrhs], A1, nrhs); + setup_input_args(prhs[nrhs], A2, nrhs); + setup_input_args(prhs[nrhs], A3, nrhs); + setup_input_args(prhs[nrhs], A4, nrhs); + setup_input_args(prhs[nrhs], A5, nrhs); + setup_input_args(prhs[nrhs], A6, nrhs); + setup_input_args(prhs[nrhs], A7, nrhs); + setup_input_args(prhs[nrhs], A8, nrhs); + setup_input_args(prhs[nrhs], A9, nrhs); + setup_input_args(prhs[nrhs], A10, nrhs); + setup_input_args(prhs[nrhs], A11, nrhs); + setup_input_args(prhs[nrhs], A12, nrhs); + setup_input_args(prhs[nrhs], A13, nrhs); + setup_input_args(prhs[nrhs], A14, nrhs); + setup_input_args(prhs[nrhs], A15, nrhs); + setup_input_args(prhs[nrhs], A16, nrhs); + setup_input_args(prhs[nrhs], A17, nrhs); + setup_input_args(prhs[nrhs], A18, nrhs); + setup_input_args(prhs[nrhs], A19, nrhs); + setup_input_args(prhs[nrhs], A20, nrhs); + + const int nlhs = num_args - nrhs; + call_matlab_for_real(nlhs,plhs,nrhs,prhs, function_name); + + int i = 0; + setup_output_args(function_name, plhs[i], A1, i); + setup_output_args(function_name, plhs[i], A2, i); + setup_output_args(function_name, plhs[i], A3, i); + setup_output_args(function_name, plhs[i], A4, i); + setup_output_args(function_name, plhs[i], A5, i); + setup_output_args(function_name, plhs[i], A6, i); + setup_output_args(function_name, plhs[i], A7, i); + setup_output_args(function_name, plhs[i], A8, i); + setup_output_args(function_name, plhs[i], A9, i); + setup_output_args(function_name, plhs[i], A10, i); + setup_output_args(function_name, plhs[i], A11, i); + setup_output_args(function_name, plhs[i], A12, i); + setup_output_args(function_name, plhs[i], A13, i); + setup_output_args(function_name, plhs[i], A14, i); + setup_output_args(function_name, plhs[i], A15, i); + setup_output_args(function_name, plhs[i], A16, i); + setup_output_args(function_name, plhs[i], A17, i); + setup_output_args(function_name, plhs[i], A18, i); + setup_output_args(function_name, plhs[i], A19, i); + setup_output_args(function_name, plhs[i], A20, i); + + free_callback_resources(nlhs,plhs,nrhs,prhs); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + matlab_object::~matlab_object( + ) + { + if (handle && should_free) + { + mxDestroyArray((mxArray*)handle); + handle = 0; + } + } + + template + matlab_object:: + operator T( + ) const + { + T item; + get(item); + return item; + } + + template + void matlab_object:: + get( + T& item + ) const + { + if (handle == 0) + throw dlib::invalid_args_exception("An attempt was made to access an empty matlab_object."); + + mex_binding::validate_and_populate_arg(arg_idx,(mxArray*)handle,item); + } + + template + matlab_object& matlab_object:: + operator= ( + const T& new_val + ) + { + mxArray* item; + mex_binding::assign_to_matlab(item, new_val); + handle = item; + should_free = true; + return *this; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + matlab_struct::sub::operator T() const + { + T item; + get(item); + return item; + } + + template + void matlab_struct::sub::get(T& item) const + { + if (struct_handle == 0) + throw dlib::error("Attempt to access data in an empty struct."); + + mxArray* temp = mxGetFieldByNumber((const mxArray*)struct_handle, 0, field_idx); + if (temp == 0) + throw dlib::error("Attempt to access data in an empty struct."); + + try + { + mex_binding::validate_and_populate_arg(0,temp,item); + } + catch(mex_binding::invalid_args_exception& e) + { + std::ostringstream sout; + sout << "Struct field '" << mxGetFieldNameByNumber((const mxArray*)struct_handle, field_idx) << "' can't be interpreted as the requested type." + << endl << e.what(); + throw dlib::error(sout.str()); + } + } + + const matlab_struct::sub matlab_struct:: + operator[] (const std::string& name) const + { + if (struct_handle == 0) + throw dlib::error("Struct does not have a field named '" + name + "'."); + + matlab_struct::sub temp; + temp.struct_handle = struct_handle; + temp.field_idx = mxGetFieldNumber((const mxArray*)struct_handle, name.c_str()); + if (temp.field_idx == -1 ) + throw dlib::error("Struct does not have a field named '" + name + "'."); + return temp; + } + + matlab_struct::sub matlab_struct:: + operator[] (const std::string& name) + { + if (struct_handle == 0) + { + // We make a struct from scratch and mark that we will free it unless it gets + // written back to matlab by assign_to_matlab(). + mwSize dims[1] = {1}; + const char* name_str = name.c_str(); + struct_handle = mxCreateStructArray(1, dims, 1, &name_str); + should_free = true; + if (struct_handle == 0) + throw dlib::error("Error creating struct from within mex function."); + } + + + matlab_struct::sub temp; + temp.struct_handle = struct_handle; + if ((temp.field_idx=mxGetFieldNumber((mxArray*)struct_handle, name.c_str())) == -1) + { + if ((temp.field_idx=mxAddField((mxArray*)struct_handle, name.c_str())) == -1) + { + throw dlib::error("Unable to add field '"+name + "' to struct."); + } + } + return temp; + } + + const matlab_struct::sub matlab_struct::sub:: + operator[] (const std::string& name) const + { + if (struct_handle == 0) + throw dlib::error("Struct does not have a field named '" + name + "'."); + + matlab_struct::sub temp; + temp.struct_handle = mxGetFieldByNumber((const mxArray*)struct_handle, 0, field_idx); + if (temp.struct_handle == 0) + throw dlib::error("Failure to get struct field while calling mxGetFieldByNumber()"); + + if (!mxIsStruct((const mxArray*)temp.struct_handle)) + throw dlib::error("Struct sub-field element '"+name+"' is not another struct."); + + temp.field_idx = mxGetFieldNumber((const mxArray*)temp.struct_handle, name.c_str()); + if (temp.field_idx == -1 ) + throw dlib::error("Struct does not have a field named '" + name + "'."); + return temp; + } + + matlab_struct::sub matlab_struct::sub:: + operator[] (const std::string& name) + { + if (struct_handle == 0) + throw dlib::error("Struct does not have a field named '" + name + "'."); + + matlab_struct::sub temp; + temp.struct_handle = mxGetFieldByNumber((const mxArray*)struct_handle, 0, field_idx); + // We are replacing this field with a struct if it exists and isn't already a struct + if (temp.struct_handle != 0 && !mxIsStruct((const mxArray*)temp.struct_handle)) + { + mxDestroyArray((mxArray*)temp.struct_handle); + temp.struct_handle = 0; + } + if (temp.struct_handle == 0) + { + mwSize dims[1] = {1}; + temp.struct_handle = mxCreateStructArray(1, dims, 0, 0); + if (temp.struct_handle == 0) + throw dlib::error("Failure to create new sub-struct field"); + mxSetFieldByNumber((mxArray*)struct_handle, 0, field_idx, (mxArray*)temp.struct_handle); + } + + + if ((temp.field_idx=mxGetFieldNumber((mxArray*)temp.struct_handle, name.c_str())) == -1) + { + if ((temp.field_idx=mxAddField((mxArray*)temp.struct_handle, name.c_str())) == -1) + { + throw dlib::error("Unable to add field '"+name + "' to struct."); + } + } + return temp; + } + + bool matlab_struct::has_field ( + const std::string& name + ) const + { + if (struct_handle == 0) + return false; + return mxGetFieldNumber((const mxArray*)struct_handle, name.c_str()) != -1; + } + + bool matlab_struct::sub::has_field ( + const std::string& name + ) const + { + if (struct_handle == 0) + return false; + mxArray* temp = mxGetFieldByNumber((const mxArray*)struct_handle, 0, field_idx); + if (temp == 0 || !mxIsStruct(temp)) + return false; + return mxGetFieldNumber(temp, name.c_str()) != -1; + } + + template + matlab_struct::sub& matlab_struct::sub::operator= ( + const T& new_val + ) + { + // Delete anything in the field before we overwrite it + mxArray* item = mxGetFieldByNumber((mxArray*)struct_handle, 0, field_idx); + if (item != 0) + { + mxDestroyArray((mxArray*)item); + item = 0; + } + + // Now set the field + mex_binding::assign_to_matlab(item, new_val); + mxSetFieldByNumber((mxArray*)struct_handle, 0, field_idx, item); + + return *this; + } + + matlab_struct:: + ~matlab_struct ( + ) + { + if (struct_handle && should_free) + { + mxDestroyArray((mxArray*)struct_handle); + struct_handle = 0; + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void call_matlab ( + const function_handle& funct + ) + { + call_matlab("feval", funct); + } + + extern "C" bool utIsInterruptPending(); + void check_for_matlab_ctrl_c( + ) + { + if (utIsInterruptPending()) + throw mex_binding::user_hit_ctrl_c(); + } +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +#ifdef MEX_CLASS_NAME +template +class mex_class_wrapper +{ +public: + mex_class_wrapper(T& obj_, mfp_type mfp_) : obj(obj_), mfp(mfp_) {} + + template + void operator()(Args&&... args) const + { + (obj.*mfp)(std::forward(args)...); + } + + mfp_type mfp; + T& obj; +}; + +template +mex_class_wrapper wrap_mex_class(T& obj, mfp_type mfp) { return mex_class_wrapper(obj, mfp); } + +namespace dlib +{ + template + struct sig_traits> + : public sig_traits + {}; + + template ::value> + struct tuple_element_default_void + { + typedef void type; + }; + + template + struct tuple_element_default_void + { + typedef typename std::tuple_element::type type; + }; + + template + struct sig_traits + { + enum { num_args = sizeof...(Args) }; + + typedef return_type result_type; + + template + struct arg + { + typedef typename tuple_element_default_void>::type type; + }; + + // These are here because that's how things are defined in sig_traits (since it is + // older than C++11, along with most of the other code in this file) + typedef typename arg<1>::type arg1_type; + typedef typename arg<2>::type arg2_type; + typedef typename arg<3>::type arg3_type; + typedef typename arg<4>::type arg4_type; + typedef typename arg<5>::type arg5_type; + typedef typename arg<6>::type arg6_type; + typedef typename arg<7>::type arg7_type; + typedef typename arg<8>::type arg8_type; + typedef typename arg<9>::type arg9_type; + typedef typename arg<10>::type arg10_type; + typedef typename arg<11>::type arg11_type; + typedef typename arg<12>::type arg12_type; + typedef typename arg<13>::type arg13_type; + typedef typename arg<14>::type arg14_type; + typedef typename arg<15>::type arg15_type; + typedef typename arg<16>::type arg16_type; + typedef typename arg<17>::type arg17_type; + typedef typename arg<18>::type arg18_type; + typedef typename arg<19>::type arg19_type; + typedef typename arg<20>::type arg20_type; + }; + + template + struct sig_traits + { + enum { num_args = sizeof...(Args) }; + + typedef return_type result_type; + + template + struct arg + { + typedef typename tuple_element_default_void>::type type; + }; + + // These are here because that's how things are defined in sig_traits (since it is + // older than C++11, along with most of the other code in this file) + typedef typename arg<1>::type arg1_type; + typedef typename arg<2>::type arg2_type; + typedef typename arg<3>::type arg3_type; + typedef typename arg<4>::type arg4_type; + typedef typename arg<5>::type arg5_type; + typedef typename arg<6>::type arg6_type; + typedef typename arg<7>::type arg7_type; + typedef typename arg<8>::type arg8_type; + typedef typename arg<9>::type arg9_type; + typedef typename arg<10>::type arg10_type; + typedef typename arg<11>::type arg11_type; + typedef typename arg<12>::type arg12_type; + typedef typename arg<13>::type arg13_type; + typedef typename arg<14>::type arg14_type; + typedef typename arg<15>::type arg15_type; + typedef typename arg<16>::type arg16_type; + typedef typename arg<17>::type arg17_type; + typedef typename arg<18>::type arg18_type; + typedef typename arg<19>::type arg19_type; + typedef typename arg<20>::type arg20_type; + }; +} + +// ---------------------------------------------------------------------------------------- + + +template +struct visit_impl +{ + template + static void visit(T& tup, size_t idx, F fun) + { + if (idx == I - 1) fun(std::get(tup)); + else visit_impl::visit(tup, idx, fun); + } +}; + +template <> +struct visit_impl<0> +{ + template + static void visit(T& tup, size_t idx, F fun) { DLIB_CASSERT(false,"this should never happen"); } +}; + +template +void visit_at(std::tuple const& tup, size_t idx, F fun) +{ + visit_impl::visit(tup, idx, fun); +} + +template +void visit_at(std::tuple& tup, size_t idx, F fun) +{ + visit_impl::visit(tup, idx, fun); +} + +class mex_class_dispatch +{ +public: + mex_class_dispatch( + MEX_CLASS_NAME* ptr_, + int nlhs_, + mxArray** plhs_, + int nrhs_, + const mxArray** prhs_ + ) : + ptr(ptr_), + nlhs(nlhs_), + plhs(plhs_), + nrhs(nrhs_), + prhs(prhs_) + {} + + template + void operator() (const funct& mfp) + { + mex_binding::call_mex_function(wrap_mex_class(*ptr,mfp), nlhs, plhs, nrhs, prhs); + } + +private: + MEX_CLASS_NAME* ptr; + int nlhs; + mxArray** plhs; + int nrhs; + const mxArray** prhs; +}; + +class class_factory_type : dlib::noncopyable +{ + /*! + WHAT THIS OBJECT REPRESENTS + This is a container class for all the MEX_CLASS_NAME objects we create. It allows + us to track what we have created and make sure the MATLAB user doesn't do any + double frees or use any stale pointers. + + It also helps us deal with the problem that would otherwise arise when a mex file + is unloaded from MATLAB when there are still active pointers to MEX_CLASS_NAME objects + in MATLAB, since we will be able to detect stale pointers. + !*/ +public: + + class_factory_type() + { + seed = (uint64)time(0); + } + + ~class_factory_type() + { + for (auto i : object_table) + delete i.second; + } + + template + uint64 create(T&& ...args) + { + MEX_CLASS_NAME* item = new MEX_CLASS_NAME(std::forward(args)...); + uint64 id = (uint64)item; + // Now generate a unique id that incorporates our seed value. The point of doing + // this is to avoid any chance that a mex file will get unloaded and then reloaded + // and start constructing objects with the same addresses, while old stale objects + // at those addresses are still stored in matlab, which would then call into the + // mex file and make things go crazy. So here we try to generate ID numbers that + // are globally unique. + uint64 i = 0; + id = murmur_hash3_128bit_3(id, seed, ++i).first; + // very unlikely but make sure there aren't any hash collisions. + while(object_table.count(id) != 0) + id = murmur_hash3_128bit_3(id, seed, ++i).first; + + object_table[id] = item; + return id; + } + + void free(uint64 item) + { + if (object_table.count(item) == 0) + { + throw dlib::error("An attempt to deallocate a mex class object with an invalid pointer was detected."); + } + + delete object_table[item]; + object_table.erase(item); + } + + MEX_CLASS_NAME* access(uint64 item) // convert numeric ID to pointer to object that can be used. + { + if (object_table.count(item) == 0) + { + throw dlib::error("An attempt to access a mex class object with an invalid pointer was detected."); + } + + return object_table[item]; + } + +private: + + std::map object_table; + uint64 seed; +} class_factory; + +// ---------------------------------------------------------------------------------------- + +// Make a FOREACH macro +#define FE_1(WHAT, X) WHAT(X) +#define FE_2(WHAT, X, ...) WHAT(X),FE_1(WHAT, __VA_ARGS__) +#define FE_3(WHAT, X, ...) WHAT(X),FE_2(WHAT, __VA_ARGS__) +#define FE_4(WHAT, X, ...) WHAT(X),FE_3(WHAT, __VA_ARGS__) +#define FE_5(WHAT, X, ...) WHAT(X),FE_4(WHAT, __VA_ARGS__) +#define FE_6(WHAT, X, ...) WHAT(X),FE_5(WHAT, __VA_ARGS__) +#define FE_7(WHAT, X, ...) WHAT(X),FE_6(WHAT, __VA_ARGS__) +#define FE_8(WHAT, X, ...) WHAT(X),FE_7(WHAT, __VA_ARGS__) +#define FE_9(WHAT, X, ...) WHAT(X),FE_8(WHAT, __VA_ARGS__) +#define FE_10(WHAT, X, ...) WHAT(X),FE_9(WHAT, __VA_ARGS__) +#define FE_11(WHAT, X, ...) WHAT(X),FE_10(WHAT, __VA_ARGS__) +#define FE_12(WHAT, X, ...) WHAT(X),FE_11(WHAT, __VA_ARGS__) +#define FE_13(WHAT, X, ...) WHAT(X),FE_12(WHAT, __VA_ARGS__) +#define FE_14(WHAT, X, ...) WHAT(X),FE_13(WHAT, __VA_ARGS__) +#define FE_15(WHAT, X, ...) WHAT(X),FE_14(WHAT, __VA_ARGS__) +#define FE_16(WHAT, X, ...) WHAT(X),FE_15(WHAT, __VA_ARGS__) +#define FE_17(WHAT, X, ...) WHAT(X),FE_16(WHAT, __VA_ARGS__) +#define FE_18(WHAT, X, ...) WHAT(X),FE_17(WHAT, __VA_ARGS__) +#define FE_19(WHAT, X, ...) WHAT(X),FE_18(WHAT, __VA_ARGS__) +#define FE_20(WHAT, X, ...) WHAT(X),FE_19(WHAT, __VA_ARGS__) +//... repeat as needed +#define GET_MACRO(_1,_2,_3,_4,_5,_6,_7,_8,_9,_10,_11,_12,_13,_14,_15,_16,_17,_18,_19,_20,NAME,...) NAME +#define FOR_EACH(action,...) GET_MACRO(__VA_ARGS__,FE_20,FE_19,FE_18,FE_17,FE_16,FE_15,FE_14,FE_13,FE_12,FE_11,FE_10,FE_9,FE_8,FE_7,FE_6,FE_5,FE_4,FE_3,FE_2,FE_1)(action,__VA_ARGS__) +#define MEX_CLASS_ANNOTATE(x) &MEX_CLASS_NAME::x + +// Now make a tuple containing all the member function pointers to our MEX_CLASS_NAME +auto mex_class_methods = std::make_tuple(FOR_EACH(MEX_CLASS_ANNOTATE, MEX_CLASS_METHODS)); + + +#endif // MEX_CLASS_NAME + +// ---------------------------------------------------------------------------------------- + +bool is_string(const mxArray* arr, const char* str) +{ + if (mxIsChar(arr)) + { + char ch[20]; + DLIB_CASSERT(mxGetString(arr, ch, sizeof(ch))==0, "Unable to retrieve string"); + ch[sizeof(ch)-1] = 0;// ensure NULL termination regardless of what MATLAB does. + return strcmp(str,ch)==0; + } + return false; +} + +// ---------------------------------------------------------------------------------------- + +/* The gateway function called by MATLAB*/ +void mexFunction( int nlhs, mxArray *plhs[], + int nrhs, const mxArray *prhs[]) +{ + // Only remap cout and cerr if we aren't using octave since octave already does this. +#if !defined(OCTAVE_IMPORT) && !defined(OCTAVE_API) + // make it so cout prints to mexPrintf() + mex_binding::mex_streambuf sb; + // make it so cerr prints to mexWarnMsgTxt() + mex_binding::mex_warn_streambuf wsb; +#endif + + try + { +#ifdef MEX_CLASS_NAME + if (nrhs == 0) + { + #define DEF2STR(x) DEF2STR2((x)) + #define DEF2STR2(x) #x + + string classname = trim(string(DEF2STR(MEX_CLASS_NAME)), " \t()"); + std::vector methods = split(trim(string(DEF2STR(MEX_CLASS_METHODS)), " \t()"), " \t,"); + + string mex_filename = trim(string(DEF2STR(MEX_FILENAME))," \t()"); + bool has_load_obj = false; + size_t load_obj_idx = 0; + + cout << "classdef " << classname << " < handle\n" + << " properties (Access = private)\n" + << " cpp_ptr\n" + << " end\n" + << "\n" + << " methods\n" + << " function this = "<::value; + DLIB_CASSERT(1 <= funct_idx && funct_idx <= num_registered_functions, "Invalid function index provided."); + + MEX_CLASS_NAME* ptr = class_factory.access(ptr_int); + + // we used the first two arguments to decide what function to call. So adjust nrhs + // and prhs so the member function never sees them. + mex_class_dispatch dispatch(ptr, nlhs, plhs, nrhs-2, prhs+2); + // now invoke the member function, subtract 1 to convert to 0 indexing. + visit_at(mex_class_methods, funct_idx-1, dispatch); + } + } +#else + mex_binding::call_mex_function(mex_function, nlhs, plhs, nrhs, prhs); +#endif + } + catch (mex_binding::invalid_args_exception& e) + { + mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", + mex_binding::escape_percent(e.what()).c_str()); + } + catch (mex_binding::user_hit_ctrl_c& ) + { + // do nothing, just return to matlab + } + catch (std::exception& e) + { + mexErrMsgIdAndTxt("mex_function:error", + mex_binding::escape_percent(e.what()).c_str()); + } + + cout << flush; + cerr << flush; +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/dlib/matlab/subprocess_stream.cpp b/ml/dlib/dlib/matlab/subprocess_stream.cpp new file mode 100644 index 000000000..4d4d53af0 --- /dev/null +++ b/ml/dlib/dlib/matlab/subprocess_stream.cpp @@ -0,0 +1,537 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "subprocess_stream.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include "call_matlab.h" + +using namespace std; + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + void make_fd_non_blocking(int fd) + { + int flags = fcntl(fd, F_GETFL, 0); + fcntl(fd, F_SETFL, flags | O_NONBLOCK); + } + +// ---------------------------------------------------------------------------------------- + + // Block until fd is ready to read, while also echoing whatever is in fd_printf to + // cout. + int read_echoing_select(int fd, int fd_printf) + { + // run until fd has data ready + while(fd_printf >= 0) + { + fd_set rfds; + int retval; + + while(true) + { + FD_ZERO(&rfds); + FD_SET(fd, &rfds); + FD_SET(fd_printf, &rfds); + + // select times out every second just so we can check for matlab ctrl+c. + struct timeval tv; + tv.tv_sec = 1; + tv.tv_usec = 0; + + try{check_for_matlab_ctrl_c();} catch(...) { return 1; } + retval = select(std::max(fd,fd_printf)+1, &rfds, NULL, NULL, &tv); + try{check_for_matlab_ctrl_c();} catch(...) { return 1; } + if (retval == 0) // keep going if it was just a timeout. + continue; + else if (retval == -1 && errno == EINTR) + continue; + + break; + } + + if (retval == -1) + { + return 1; + } + else + { + if (FD_ISSET(fd,&rfds)) + { + return 0; + } + else + { + char buf[1024]; + int num = read(fd_printf,buf, sizeof(buf)-1); + if (num == -1) + return 1; + if (num > 0) + { + buf[num] = 0; + cout << buf << flush; + } + } + } + } + return 0; + } + + int write_echoing_select(int fd, int fd_printf) + { + // run until fd has data ready + while(fd_printf >= 0) + { + fd_set rfds, wfds; + int retval; + while(true) + { + FD_ZERO(&rfds); + FD_ZERO(&wfds); + FD_SET(fd, &wfds); + FD_SET(fd_printf, &rfds); + + // select times out every second just so we can check for matlab ctrl+c. + struct timeval tv; + tv.tv_sec = 1; + tv.tv_usec = 0; + + try{check_for_matlab_ctrl_c();} catch(...) { return 1; } + retval = select(std::max(fd,fd_printf)+1, &rfds, &wfds, NULL, &tv); + try{check_for_matlab_ctrl_c();} catch(...) { return 1; } + if (retval == 0) // keep going if it was just a timeout. + continue; + else if (retval == -1 && errno == EINTR) + continue; + + break; + } + + if (retval == -1) + { + return 1; + } + else + { + if (FD_ISSET(fd,&wfds)) + { + return 0; + } + else + { + char buf[1024]; + int num = read(fd_printf,buf, sizeof(buf)-1); + if (num == -1) + return 1; + if (num > 0) + { + buf[num] = 0; + cout << buf << flush; + } + } + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + class filestreambuf : public std::streambuf + { + /*! + INITIAL VALUE + - fd == the file descriptor we read from. + - in_buffer == an array of in_buffer_size bytes + - out_buffer == an array of out_buffer_size bytes + + CONVENTION + - in_buffer == the input buffer used by this streambuf + - out_buffer == the output buffer used by this streambuf + - max_putback == the maximum number of chars to have in the put back buffer. + !*/ + + public: + + filestreambuf ( + int fd_, + int fd_printf_ + ) : + fd(fd_), + fd_printf(fd_printf_), + out_buffer(0), + in_buffer(0) + { + init(); + } + + virtual ~filestreambuf ( + ) + { + sync(); + delete [] out_buffer; + delete [] in_buffer; + } + + int sync ( + ) + { + if (flush_out_buffer() == EOF) + { + // an error occurred + return -1; + } + return 0; + } + protected: + + void init ( + ) + { + try + { + out_buffer = new char[out_buffer_size]; + in_buffer = new char[in_buffer_size]; + } + catch (...) + { + if (out_buffer) delete [] out_buffer; + throw; + } + setp(out_buffer, out_buffer + (out_buffer_size-1)); + setg(in_buffer+max_putback, + in_buffer+max_putback, + in_buffer+max_putback); + } + + int flush_out_buffer ( + ) + { + int num = static_cast(pptr()-pbase()); + const int num_written = num; + char* buf = out_buffer; + while(num != 0) + { + if(write_echoing_select(fd, fd_printf)) + return EOF; + int status = write(fd,buf,num); + if (status < 0) + { + // the write was not successful so return EOF + return EOF; + } + num -= status; + buf += status; + } + pbump(-num_written); + return num_written; + } + + // output functions + int_type overflow ( + int_type c + ) + { + if (c != EOF) + { + *pptr() = c; + pbump(1); + } + if (flush_out_buffer() == EOF) + { + // an error occurred + return EOF; + } + return c; + } + + + std::streamsize xsputn ( + const char* s, + std::streamsize num + ) + { + // Add a sanity check here + DLIB_ASSERT(num >= 0, + "\tstd::streamsize filestreambuf::xsputn" + << "\n\tThe number of bytes to write can't be negative" + << "\n\tnum: " << num + << "\n\tthis: " << this + ); + + std::streamsize space_left = static_cast(epptr()-pptr()); + if (num <= space_left) + { + std::memcpy(pptr(),s,static_cast(num)); + pbump(static_cast(num)); + return num; + } + else + { + std::memcpy(pptr(),s,static_cast(space_left)); + s += space_left; + pbump(space_left); + std::streamsize num_left = num - space_left; + + if (flush_out_buffer() == EOF) + { + // the write was not successful so return that 0 bytes were written + return 0; + } + + if (num_left < out_buffer_size) + { + std::memcpy(pptr(),s,static_cast(num_left)); + pbump(num_left); + return num; + } + else + { + while(num_left != 0) + { + if(write_echoing_select(fd, fd_printf)) + return EOF; + int status = write(fd,s,num_left); + if (status < 0) + { + // the write was not successful so return that 0 bytes were written + return 0; + } + num_left -= status; + s += status; + } + return num; + } + } + } + + // input functions + int_type underflow( + ) + { + if (gptr() < egptr()) + { + return static_cast(*gptr()); + } + + int num_put_back = static_cast(gptr() - eback()); + if (num_put_back > max_putback) + { + num_put_back = max_putback; + } + + // copy the putback characters into the putback end of the in_buffer + std::memmove(in_buffer+(max_putback-num_put_back), gptr()-num_put_back, num_put_back); + + + if (read_echoing_select(fd, fd_printf)) + return EOF; + int num = read(fd,in_buffer+max_putback, in_buffer_size-max_putback); + if (num <= 0) + { + // an error occurred or the connection is over which is EOF + return EOF; + } + + // reset in_buffer pointers + setg (in_buffer+(max_putback-num_put_back), + in_buffer+max_putback, + in_buffer+max_putback+num); + + return static_cast(*gptr()); + } + + std::streamsize xsgetn ( + char_type* s, + std::streamsize n + ) + { + std::streamsize temp = n; + while (n > 0) + { + int num = static_cast(egptr() - gptr()); + if (num >= n) + { + // copy data from our buffer + std::memcpy(s, gptr(), static_cast(n)); + gbump(static_cast(n)); + return temp; + } + + // read more data into our buffer + if (num == 0) + { + if (underflow() == EOF) + break; + continue; + } + + // copy all the data from our buffer + std::memcpy(s, gptr(), num); + n -= num; + gbump(num); + s += num; + } + return temp-n; + } + + private: + + // member data + int fd; + int fd_printf; + static const std::streamsize max_putback = 4; + static const std::streamsize out_buffer_size = 10000; + static const std::streamsize in_buffer_size = 10000; + char* out_buffer; + char* in_buffer; + + }; + + namespace impl + { + int get_data_fd() + { + char* env_fd = getenv("DLIB_SUBPROCESS_DATA_FD"); + DLIB_CASSERT(env_fd != 0,""); + return atoi(env_fd); + } + + std::iostream& get_data_iostream() + { + static filestreambuf dbuff(get_data_fd(), -1); + static iostream out(&dbuff); + return out; + } + } + +// ---------------------------------------------------------------------------------------- + + subprocess_stream:: + subprocess_stream(const char* program_name) : stderr(NULL), iosub(NULL) + { + if (access(program_name, F_OK)) + throw dlib::error("Error: '" + std::string(program_name) + "' file does not exist."); + if (access(program_name, X_OK)) + throw dlib::error("Error: '" + std::string(program_name) + "' file is not executable."); + + child_pid = fork(); + if (child_pid == -1) + throw dlib::error("Failed to start child process"); + + if (child_pid == 0) + { + // In child process + dup2(stdout_pipe.child_fd(), STDOUT_FILENO); + dup2(stderr_pipe.child_fd(), STDERR_FILENO); + stdout_pipe.close(); + stderr_pipe.close(); + + char* argv[] = {(char*)program_name, nullptr}; + char* cudadevs = getenv("CUDA_VISIBLE_DEVICES"); + if (cudadevs) + { + std::ostringstream sout; + sout << "DLIB_SUBPROCESS_DATA_FD="<sync(); ::close(data_pipe.parent_fd()); } + +// ---------------------------------------------------------------------------------------- + +} + + diff --git a/ml/dlib/dlib/matlab/subprocess_stream.h b/ml/dlib/dlib/matlab/subprocess_stream.h new file mode 100644 index 000000000..b00904c12 --- /dev/null +++ b/ml/dlib/dlib/matlab/subprocess_stream.h @@ -0,0 +1,223 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SUBPROCeSS_STREAM_H_ +#define DLIB_SUBPROCeSS_STREAM_H_ + +#include +#include +#include +#include +#include +#include +#include + + +namespace dlib +{ + +// -------------------------------------------------------------------------------------- + + // Call dlib's serialize and deserialize by default. The point of this version of + // serialize is to do something fast that normally we wouldn't do, like directly copy + // memory. This is safe since this is an interprocess communication happening the same + // machine. + template void interprocess_serialize ( const T& item, std::ostream& out) { serialize(item, out); } + template void interprocess_deserialize (T& item, std::istream& in) { deserialize(item, in); } + + // But have overloads for direct memory copies for some types since this is faster than + // their default serialization. + template + void interprocess_serialize(const dlib::matrix& item, std::ostream& out) + { + dlib::serialize(item.nr(), out); + dlib::serialize(item.nc(), out); + if (item.size() != 0) + out.write((const char*)&item(0,0), sizeof(T)*item.size()); + if (!out) + throw dlib::serialization_error("Error writing matrix to interprocess iostream."); + } + + template + void interprocess_deserialize(dlib::matrix& item, std::istream& in) + { + long nr, nc; + dlib::deserialize(nr, in); + dlib::deserialize(nc, in); + item.set_size(nr,nc); + if (item.size() != 0) + in.read((char*)&item(0,0), sizeof(T)*item.size()); + if (!in) + throw dlib::serialization_error("Error reading matrix from interprocess iostream."); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl{ std::iostream& get_data_iostream(); } + + inline void send_to_parent_process() {impl::get_data_iostream().flush();} + template + void send_to_parent_process(U&& arg1, T&& ...args) + /*! + ensures + - sends all the arguments to send_to_parent_process() to the parent process by + serializing them with interprocess_serialize(). + !*/ + { + interprocess_serialize(arg1, impl::get_data_iostream()); + send_to_parent_process(std::forward(args)...); + if (!impl::get_data_iostream()) + throw dlib::error("Error sending object to parent process."); + } + + inline void receive_from_parent_process() {} + template + void receive_from_parent_process(U&& arg1, T&& ...args) + /*! + ensures + - receives all the arguments to receive_from_parent_process() from the parent + process by deserializing them from interprocess_serialize(). + !*/ + { + interprocess_deserialize(arg1, impl::get_data_iostream()); + receive_from_parent_process(std::forward(args)...); + if (!impl::get_data_iostream()) + throw dlib::error("Error receiving object from parent process."); + } + + +// ---------------------------------------------------------------------------------------- + + class filestreambuf; + + class subprocess_stream : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool for spawning a subprocess and communicating with it. Here + is an example: + + subprocess_stream s("/usr/bin/some_program"); + s.send(obj1, obj2, obj3); + s.receive(obj4, obj5); + s.wait(); // wait for sub process to terminate + + Then in the sub process you would have: + + receive_from_parent_process(obj1, obj2, obj3); + // do stuff + cout << "echo this text to parent cout" << endl; + send_to_parent_process(obj4, obj5); + + + Additionally, if the sub process writes to its standard out then that will + be echoed to std::cout in the parent process. Writing to std::cerr or + returning a non-zero value from main will also be noted by the parent + process and an appropriate exception will be thrown. + !*/ + + public: + + explicit subprocess_stream( + const char* program_name + ); + /*! + ensures + - spawns a sub process by executing the file with the given program_name. + !*/ + + ~subprocess_stream( + ); + /*! + ensures + - calls wait(). Note that the destructor never throws even though wait() can. + If an exception is thrown by wait() it is just logged to std::cerr. + !*/ + + void wait( + ); + /*! + ensures + - closes the input stream to the child process and then waits for the child + to terminate. + - If the child returns an error (by returning != 0 from its main) or + outputs to its standard error then wait() throws a dlib::error() with the + standard error output in it. + !*/ + + int get_child_pid() const { return child_pid; } + /*! + ensures + - returns the PID of the child process + !*/ + + template + void send(U&& arg1, T&& ...args) + /*! + ensures + - sends all the arguments to send() to the subprocess by serializing them + with interprocess_serialize(). + !*/ + { + interprocess_serialize(arg1, iosub); + send(std::forward(args)...); + if (!iosub) + { + std::ostringstream sout; + sout << stderr.rdbuf(); + throw dlib::error("Error sending object to child process.\n" + sout.str()); + } + } + void send() {iosub.flush();} + + template + void receive(U&& arg1, T&& ...args) + /*! + ensures + - receives all the arguments to receive() to the subprocess by deserializing + them with interprocess_deserialize(). + !*/ + { + interprocess_deserialize(arg1, iosub); + receive(std::forward(args)...); + if (!iosub) + { + std::ostringstream sout; + sout << stderr.rdbuf(); + throw dlib::error("Error receiving object from child process.\n" + sout.str() ); + } + } + void receive() {} + + + private: + + void send_eof(); + + class cpipe : noncopyable + { + private: + int fd[2]; + public: + cpipe() { if (socketpair(AF_LOCAL, SOCK_STREAM, 0, fd)) throw dlib::error("Failed to create pipe"); } + ~cpipe() { close(); } + int parent_fd() const { return fd[0]; } + int child_fd() const { return fd[1]; } + void close() { ::close(fd[0]); ::close(fd[1]); } + }; + + cpipe data_pipe; + cpipe stdout_pipe; + cpipe stderr_pipe; + bool wait_called = false; + std::unique_ptr inout_buf; + std::unique_ptr err_buf; + int child_pid = -1; + std::istream stderr; + std::iostream iosub; + }; +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_SUBPROCeSS_STREAM_H_ + diff --git a/ml/dlib/dlib/matrix.h b/ml/dlib/dlib/matrix.h new file mode 100644 index 000000000..d2ae69afb --- /dev/null +++ b/ml/dlib/dlib/matrix.h @@ -0,0 +1,24 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_HEADER +#define DLIB_MATRIx_HEADER + +#include "matrix/matrix.h" +#include "matrix/matrix_utilities.h" +#include "matrix/matrix_subexp.h" +#include "matrix/matrix_math_functions.h" +#include "matrix/matrix_assign.h" +#include "matrix/matrix_la.h" +#include "matrix/symmetric_matrix_cache.h" +#include "matrix/matrix_conv.h" +#include "matrix/matrix_read_from_istream.h" +#include "matrix/matrix_fft.h" +#include "matrix/matrix_generic_image.h" + +#ifdef DLIB_USE_BLAS +#include "matrix/matrix_blas_bindings.h" +#endif + +#endif // DLIB_MATRIx_HEADER + + diff --git a/ml/dlib/dlib/matrix/cblas_constants.h b/ml/dlib/dlib/matrix/cblas_constants.h new file mode 100644 index 000000000..6ff89f141 --- /dev/null +++ b/ml/dlib/dlib/matrix/cblas_constants.h @@ -0,0 +1,22 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CBLAS_CONSTAnTS_Hh_ +#define DLIB_CBLAS_CONSTAnTS_Hh_ + +#ifndef CBLAS_H +namespace dlib +{ + namespace blas_bindings + { + enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102}; + enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; + enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; + enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; + enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; + + } +} +#endif // if not CBLAS_H + +#endif // DLIB_CBLAS_CONSTAnTS_Hh_ + diff --git a/ml/dlib/dlib/matrix/lapack/fortran_id.h b/ml/dlib/dlib/matrix/lapack/fortran_id.h new file mode 100644 index 000000000..8027ea34f --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/fortran_id.h @@ -0,0 +1,62 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOOST_NUMERIC_BINDINGS_TRAITS_FORTRAN_H +#define DLIB_BOOST_NUMERIC_BINDINGS_TRAITS_FORTRAN_H + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// FORTRAN BINDING STUFF FROM BOOST +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// Permission to copy, use, modify, sell and +// distribute this software is granted provided this copyright notice appears +// in all copies. This software is provided "as is" without express or implied +// warranty, and with no claim as to its suitability for any purpose. +// Copyright (C) 2002, 2003 Si-Lab b.v.b.a., Toon Knapen and Kresimir Fresl + + +// First we need to know what the conventions for linking +// C with Fortran is on this platform/toolset +#if defined(LAPACK_FORCE_UNDERSCORE) +#define DLIB_BIND_FORTRAN_LOWERCASE_UNDERSCORE +#elif defined(LAPACK_FORCE_NOUNDERSCORE) +#define DLIB_BIND_FORTRAN_LOWERCASE +#elif defined(__GNUC__) || defined(__ICC) || defined(__sgi) || defined(__COMO__) || defined(__KCC) +#define DLIB_BIND_FORTRAN_LOWERCASE_UNDERSCORE +#elif defined(__IBMCPP__) || defined(_MSC_VER) || defined(__BORLANDC__) +#define DLIB_BIND_FORTRAN_LOWERCASE +#else +#error do not know how to link with fortran for the given platform +#endif + +// Next we define macros to convert our symbols to +// the current convention +#if defined(DLIB_BIND_FORTRAN_LOWERCASE_UNDERSCORE) +#define DLIB_FORTRAN_ID( id ) id##_ +#elif defined(DLIB_BIND_FORTRAN_LOWERCASE) +#define DLIB_FORTRAN_ID( id ) id +#else +#error do not know how to bind to fortran calling convention +#endif + + + +namespace dlib +{ + namespace lapack + { + // stuff from f2c used to define what exactly is an integer in fortran +#if (defined(__alpha__) || defined(__sparc64__) || defined(__x86_64__) || defined(__ia64__)) && !defined(MATLAB_MEX_FILE) + typedef int integer; + typedef unsigned int uinteger; +#else + typedef long int integer; + typedef unsigned long int uinteger; +#endif + + } +} + +#endif // DLIB_BOOST_NUMERIC_BINDINGS_TRAITS_FORTRAN_H + diff --git a/ml/dlib/dlib/matrix/lapack/gees.h b/ml/dlib/dlib/matrix/lapack/gees.h new file mode 100644 index 000000000..a8ee63ff1 --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/gees.h @@ -0,0 +1,264 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_ES_Hh_ +#define DLIB_LAPACk_ES_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { +#if defined(__alpha__) || defined(__sparc64__) || defined(__x86_64__) || defined(__ia64__) + typedef int logical; +#else + typedef long int logical; +#endif + typedef logical (*L_fp)(...); + + extern "C" + { + void DLIB_FORTRAN_ID(dgees) (char *jobvs, char *sort, L_fp select, integer *n, + double *a, integer *lda, integer *sdim, double *wr, + double *wi, double *vs, integer *ldvs, double *work, + integer *lwork, logical *bwork, integer *info); + + void DLIB_FORTRAN_ID(sgees) (char *jobvs, char *sort, L_fp select, integer *n, + float *a, integer *lda, integer *sdim, float *wr, + float *wi, float *vs, integer *ldvs, float *work, + integer *lwork, logical *bwork, integer *info); + + } + + inline int gees (char jobvs, integer n, + double *a, integer lda, double *wr, + double *wi, double *vs, integer ldvs, double *work, + integer lwork) + { + // No sorting allowed + integer info = 0; + char sort = 'N'; + L_fp fnil = 0; + logical bwork = 0; + integer sdim = 0; + DLIB_FORTRAN_ID(dgees)(&jobvs, &sort, fnil, &n, + a, &lda, &sdim, wr, + wi, vs, &ldvs, work, + &lwork, &bwork, &info); + return info; + } + + + inline int gees (char jobvs, integer n, + float *a, integer lda, float *wr, + float *wi, float *vs, integer ldvs, float *work, + integer lwork) + { + // No sorting allowed + integer info = 0; + char sort = 'N'; + L_fp fnil = 0; + logical bwork = 0; + integer sdim = 0; + DLIB_FORTRAN_ID(sgees)(&jobvs, &sort, fnil, &n, + a, &lda, &sdim, wr, + wi, vs, &ldvs, work, + &lwork, &bwork, &info); + return info; + } + + + } + + // ------------------------------------------------------------------------------------ + +/* -- LAPACK driver routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ +/* .. Function Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DGEES computes for an N-by-N real nonsymmetric matrix A, the */ +/* eigenvalues, the real Schur form T, and, optionally, the matrix of */ +/* Schur vectors Z. This gives the Schur factorization A = Z*T*(Z**T). */ + +/* Optionally, it also orders the eigenvalues on the diagonal of the */ +/* real Schur form so that selected eigenvalues are at the top left. */ +/* The leading columns of Z then form an orthonormal basis for the */ +/* invariant subspace corresponding to the selected eigenvalues. */ + +/* A matrix is in real Schur form if it is upper quasi-triangular with */ +/* 1-by-1 and 2-by-2 blocks. 2-by-2 blocks will be standardized in the */ +/* form */ +/* [ a b ] */ +/* [ c a ] */ + +/* where b*c < 0. The eigenvalues of such a block are a +- sqrt(bc). */ + +/* Arguments */ +/* ========= */ + +/* JOBVS (input) CHARACTER*1 */ +/* = 'N': Schur vectors are not computed; */ +/* = 'V': Schur vectors are computed. */ + +/* SORT (input) CHARACTER*1 */ +/* Specifies whether or not to order the eigenvalues on the */ +/* diagonal of the Schur form. */ +/* = 'N': Eigenvalues are not ordered; */ +/* = 'S': Eigenvalues are ordered (see SELECT). */ + +/* SELECT (external procedure) LOGICAL FUNCTION of two DOUBLE PRECISION arguments */ +/* SELECT must be declared EXTERNAL in the calling subroutine. */ +/* If SORT = 'S', SELECT is used to select eigenvalues to sort */ +/* to the top left of the Schur form. */ +/* If SORT = 'N', SELECT is not referenced. */ +/* An eigenvalue WR(j)+sqrt(-1)*WI(j) is selected if */ +/* SELECT(WR(j),WI(j)) is true; i.e., if either one of a complex */ +/* conjugate pair of eigenvalues is selected, then both complex */ +/* eigenvalues are selected. */ +/* Note that a selected complex eigenvalue may no longer */ +/* satisfy SELECT(WR(j),WI(j)) = .TRUE. after ordering, since */ +/* ordering may change the value of complex eigenvalues */ +/* (especially if the eigenvalue is ill-conditioned); in this */ +/* case INFO is set to N+2 (see INFO below). */ + +/* N (input) INTEGER */ +/* The order of the matrix A. N >= 0. */ + +/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */ +/* On entry, the N-by-N matrix A. */ +/* On exit, A has been overwritten by its real Schur form T. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,N). */ + +/* SDIM (output) INTEGER */ +/* If SORT = 'N', SDIM = 0. */ +/* If SORT = 'S', SDIM = number of eigenvalues (after sorting) */ +/* for which SELECT is true. (Complex conjugate */ +/* pairs for which SELECT is true for either */ +/* eigenvalue count as 2.) */ + +/* WR (output) DOUBLE PRECISION array, dimension (N) */ +/* WI (output) DOUBLE PRECISION array, dimension (N) */ +/* WR and WI contain the real and imaginary parts, */ +/* respectively, of the computed eigenvalues in the same order */ +/* that they appear on the diagonal of the output Schur form T. */ +/* Complex conjugate pairs of eigenvalues will appear */ +/* consecutively with the eigenvalue having the positive */ +/* imaginary part first. */ + +/* VS (output) DOUBLE PRECISION array, dimension (LDVS,N) */ +/* If JOBVS = 'V', VS contains the orthogonal matrix Z of Schur */ +/* vectors. */ +/* If JOBVS = 'N', VS is not referenced. */ + +/* LDVS (input) INTEGER */ +/* The leading dimension of the array VS. LDVS >= 1; if */ +/* JOBVS = 'V', LDVS >= N. */ + +/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */ +/* On exit, if INFO = 0, WORK(1) contains the optimal LWORK. */ + +/* LWORK (input) INTEGER */ +/* The dimension of the array WORK. LWORK >= max(1,3*N). */ +/* For good performance, LWORK must generally be larger. */ + +/* If LWORK = -1, then a workspace query is assumed; the routine */ +/* only calculates the optimal size of the WORK array, returns */ +/* this value as the first entry of the WORK array, and no error */ +/* message related to LWORK is issued by XERBLA. */ + +/* BWORK (workspace) LOGICAL array, dimension (N) */ +/* Not referenced if SORT = 'N'. */ + +/* INFO (output) INTEGER */ +/* = 0: successful exit */ +/* < 0: if INFO = -i, the i-th argument had an illegal value. */ +/* > 0: if INFO = i, and i is */ +/* <= N: the QR algorithm failed to compute all the */ +/* eigenvalues; elements 1:ILO-1 and i+1:N of WR and WI */ +/* contain those eigenvalues which have converged; if */ +/* JOBVS = 'V', VS contains the matrix which reduces A */ +/* to its partially converged Schur form. */ +/* = N+1: the eigenvalues could not be reordered because some */ +/* eigenvalues were too close to separate (the problem */ +/* is very ill-conditioned); */ +/* = N+2: after reordering, roundoff changed values of some */ +/* complex eigenvalues so that leading eigenvalues in */ +/* the Schur form no longer satisfy SELECT=.TRUE. This */ +/* could also be caused by underflow due to scaling. */ + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, long NR3, long NR4, + long NC1, long NC2, long NC3, long NC4, + typename MM, + typename layout + > + int gees ( + const char jobz, + matrix& a, + matrix& wr, + matrix& wi, + matrix& vs + ) + { + matrix work; + + const long n = a.nr(); + + wr.set_size(n,1); + wi.set_size(n,1); + + if (jobz == 'V') + vs.set_size(n,n); + else + vs.set_size(NR4?NR4:1, NC4?NC4:1); + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::gees(jobz, n, + &a(0,0), a.nr(), &wr(0,0), + &wi(0,0), &vs(0,0), vs.nr(), &work_size, + -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual decomposition + info = binding::gees(jobz, n, + &a(0,0), a.nr(), &wr(0,0), + &wi(0,0), &vs(0,0), vs.nr(), &work(0,0), + work.size()); + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_ES_Hh_ + diff --git a/ml/dlib/dlib/matrix/lapack/geev.h b/ml/dlib/dlib/matrix/lapack/geev.h new file mode 100644 index 000000000..d8fdc4af5 --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/geev.h @@ -0,0 +1,234 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_GEEV_Hh_ +#define DLIB_LAPACk_GEEV_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dgeev) (char *jobvl, char *jobvr, integer *n, double * a, + integer *lda, double *wr, double *wi, double *vl, + integer *ldvl, double *vr, integer *ldvr, double *work, + integer *lwork, integer *info); + + void DLIB_FORTRAN_ID(sgeev) (char *jobvl, char *jobvr, integer *n, float * a, + integer *lda, float *wr, float *wi, float *vl, + integer *ldvl, float *vr, integer *ldvr, float *work, + integer *lwork, integer *info); + + } + + inline int geev (char jobvl, char jobvr, integer n, double *a, + integer lda, double *wr, double *wi, double *vl, + integer ldvl, double *vr, integer ldvr, double *work, + integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(dgeev)(&jobvl, &jobvr, &n, a, + &lda, wr, wi, vl, + &ldvl, vr, &ldvr, work, + &lwork, &info); + return info; + } + + inline int geev (char jobvl, char jobvr, integer n, float *a, + integer lda, float *wr, float *wi, float *vl, + integer ldvl, float *vr, integer ldvr, float *work, + integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(sgeev)(&jobvl, &jobvr, &n, a, + &lda, wr, wi, vl, + &ldvl, vr, &ldvr, work, + &lwork, &info); + return info; + } + + + } + + // ------------------------------------------------------------------------------------ + +/* -- LAPACK driver routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DGEEV computes for an N-by-N real nonsymmetric matrix A, the */ +/* eigenvalues and, optionally, the left and/or right eigenvectors. */ + +/* The right eigenvector v(j) of A satisfies */ +/* A * v(j) = lambda(j) * v(j) */ +/* where lambda(j) is its eigenvalue. */ +/* The left eigenvector u(j) of A satisfies */ +/* u(j)**H * A = lambda(j) * u(j)**H */ +/* where u(j)**H denotes the conjugate transpose of u(j). */ + +/* The computed eigenvectors are normalized to have Euclidean norm */ +/* equal to 1 and largest component real. */ + +/* Arguments */ +/* ========= */ + +/* JOBVL (input) CHARACTER*1 */ +/* = 'N': left eigenvectors of A are not computed; */ +/* = 'V': left eigenvectors of A are computed. */ + +/* JOBVR (input) CHARACTER*1 */ +/* = 'N': right eigenvectors of A are not computed; */ +/* = 'V': right eigenvectors of A are computed. */ + +/* N (input) INTEGER */ +/* The order of the matrix A. N >= 0. */ + +/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */ +/* On entry, the N-by-N matrix A. */ +/* On exit, A has been overwritten. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,N). */ + +/* WR (output) DOUBLE PRECISION array, dimension (N) */ +/* WI (output) DOUBLE PRECISION array, dimension (N) */ +/* WR and WI contain the real and imaginary parts, */ +/* respectively, of the computed eigenvalues. Complex */ +/* conjugate pairs of eigenvalues appear consecutively */ +/* with the eigenvalue having the positive imaginary part */ +/* first. */ + +/* VL (output) DOUBLE PRECISION array, dimension (LDVL,N) */ +/* If JOBVL = 'V', the left eigenvectors u(j) are stored one */ +/* after another in the columns of VL, in the same order */ +/* as their eigenvalues. */ +/* If JOBVL = 'N', VL is not referenced. */ +/* If the j-th eigenvalue is real, then u(j) = VL(:,j), */ +/* the j-th column of VL. */ +/* If the j-th and (j+1)-st eigenvalues form a complex */ +/* conjugate pair, then u(j) = VL(:,j) + i*VL(:,j+1) and */ +/* u(j+1) = VL(:,j) - i*VL(:,j+1). */ + +/* LDVL (input) INTEGER */ +/* The leading dimension of the array VL. LDVL >= 1; if */ +/* JOBVL = 'V', LDVL >= N. */ + +/* VR (output) DOUBLE PRECISION array, dimension (LDVR,N) */ +/* If JOBVR = 'V', the right eigenvectors v(j) are stored one */ +/* after another in the columns of VR, in the same order */ +/* as their eigenvalues. */ +/* If JOBVR = 'N', VR is not referenced. */ +/* If the j-th eigenvalue is real, then v(j) = VR(:,j), */ +/* the j-th column of VR. */ +/* If the j-th and (j+1)-st eigenvalues form a complex */ +/* conjugate pair, then v(j) = VR(:,j) + i*VR(:,j+1) and */ +/* v(j+1) = VR(:,j) - i*VR(:,j+1). */ + +/* LDVR (input) INTEGER */ +/* The leading dimension of the array VR. LDVR >= 1; if */ +/* JOBVR = 'V', LDVR >= N. */ + +/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */ +/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */ + +/* LWORK (input) INTEGER */ +/* The dimension of the array WORK. LWORK >= max(1,3*N), and */ +/* if JOBVL = 'V' or JOBVR = 'V', LWORK >= 4*N. For good */ +/* performance, LWORK must generally be larger. */ + +/* If LWORK = -1, then a workspace query is assumed; the routine */ +/* only calculates the optimal size of the WORK array, returns */ +/* this value as the first entry of the WORK array, and no error */ +/* message related to LWORK is issued by XERBLA. */ + +/* INFO (output) INTEGER */ +/* = 0: successful exit */ +/* < 0: if INFO = -i, the i-th argument had an illegal value. */ +/* > 0: if INFO = i, the QR algorithm failed to compute all the */ +/* eigenvalues, and no eigenvectors have been computed; */ +/* elements i+1:N of WR and WI contain eigenvalues which */ +/* have converged. */ + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, long NR3, long NR4, long NR5, + long NC1, long NC2, long NC3, long NC4, long NC5, + typename MM, + typename layout + > + int geev ( + const char jobvl, + const char jobvr, + matrix& a, + matrix& wr, + matrix& wi, + matrix& vl, + matrix& vr + ) + { + matrix work; + + const long n = a.nr(); + + wr.set_size(n,1); + wi.set_size(n,1); + + if (jobvl == 'V') + vl.set_size(n,n); + else + vl.set_size(NR4?NR4:1, NC4?NC4:1); + + if (jobvr == 'V') + vr.set_size(n,n); + else + vr.set_size(NR5?NR5:1, NC5?NC5:1); + + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::geev(jobvl, jobvr, n, &a(0,0), + a.nr(), &wr(0,0), &wi(0,0), &vl(0,0), + vl.nr(), &vr(0,0), vr.nr(), &work_size, + -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual decomposition + info = binding::geev(jobvl, jobvr, n, &a(0,0), + a.nr(), &wr(0,0), &wi(0,0), &vl(0,0), + vl.nr(), &vr(0,0), vr.nr(), &work(0,0), + work.size()); + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_GEEV_Hh_ + + diff --git a/ml/dlib/dlib/matrix/lapack/geqrf.h b/ml/dlib/dlib/matrix/lapack/geqrf.h new file mode 100644 index 000000000..c1f8fc050 --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/geqrf.h @@ -0,0 +1,168 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_GEQRF_Hh_ +#define DLIB_LAPACk_GEQRF_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dgeqrf) (integer *m, integer *n, double *a, integer * + lda, double *tau, double *work, integer *lwork, + integer *info); + + void DLIB_FORTRAN_ID(sgeqrf) (integer *m, integer *n, float *a, integer * + lda, float *tau, float *work, integer *lwork, + integer *info); + } + + inline int geqrf (integer m, integer n, double *a, integer lda, + double *tau, double *work, integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(dgeqrf)(&m, &n, a, &lda, + tau, work, &lwork, &info); + return info; + } + + inline int geqrf (integer m, integer n, float *a, integer lda, + float *tau, float *work, integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(sgeqrf)(&m, &n, a, &lda, + tau, work, &lwork, &info); + return info; + } + + + } + + // ------------------------------------------------------------------------------------ + +/* -- LAPACK routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DGEQRF computes a QR factorization of a real M-by-N matrix A: */ +/* A = Q * R. */ + +/* Arguments */ +/* ========= */ + +/* M (input) INTEGER */ +/* The number of rows of the matrix A. M >= 0. */ + +/* N (input) INTEGER */ +/* The number of columns of the matrix A. N >= 0. */ + +/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */ +/* On entry, the M-by-N matrix A. */ +/* On exit, the elements on and above the diagonal of the array */ +/* contain the min(M,N)-by-N upper trapezoidal matrix R (R is */ +/* upper triangular if m >= n); the elements below the diagonal, */ +/* with the array TAU, represent the orthogonal matrix Q as a */ +/* product of min(m,n) elementary reflectors (see Further */ +/* Details). */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,M). */ + +/* TAU (output) DOUBLE PRECISION array, dimension (min(M,N)) */ +/* The scalar factors of the elementary reflectors (see Further */ +/* Details). */ + +/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */ +/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */ + +/* LWORK (input) INTEGER */ +/* The dimension of the array WORK. LWORK >= max(1,N). */ +/* For optimum performance LWORK >= N*NB, where NB is */ +/* the optimal blocksize. */ + +/* If LWORK = -1, then a workspace query is assumed; the routine */ +/* only calculates the optimal size of the WORK array, returns */ +/* this value as the first entry of the WORK array, and no error */ +/* message related to LWORK is issued by XERBLA. */ + +/* INFO (output) INTEGER */ +/* = 0: successful exit */ +/* < 0: if INFO = -i, the i-th argument had an illegal value */ + +/* Further Details */ +/* =============== */ + +/* The matrix Q is represented as a product of elementary reflectors */ + +/* Q = H(1) H(2) . . . H(k), where k = min(m,n). */ + +/* Each H(i) has the form */ + +/* H(i) = I - tau * v * v' */ + +/* where tau is a real scalar, and v is a real vector with */ +/* v(1:i-1) = 0 and v(i) = 1; v(i+1:m) is stored on exit in A(i+1:m,i), */ +/* and tau in TAU(i). */ + + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, + long NC1, long NC2, + typename MM + > + int geqrf ( + matrix& a, + matrix& tau + ) + { + matrix work; + + tau.set_size(std::min(a.nr(), a.nc()), 1); + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::geqrf(a.nr(), a.nc(), &a(0,0), a.nr(), + &tau(0,0), &work_size, -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual decomposition + info = binding::geqrf(a.nr(), a.nc(), &a(0,0), a.nr(), + &tau(0,0), &work(0,0), work.size()); + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_GEQRF_Hh_ + + + diff --git a/ml/dlib/dlib/matrix/lapack/gesdd.h b/ml/dlib/dlib/matrix/lapack/gesdd.h new file mode 100644 index 000000000..e6b4d26e1 --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/gesdd.h @@ -0,0 +1,364 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_SDD_Hh_ +#define DLIB_LAPACk_SDD_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dgesdd) (char const* jobz, + const integer* m, const integer* n, double* a, const integer* lda, + double* s, double* u, const integer* ldu, + double* vt, const integer* ldvt, + double* work, const integer* lwork, integer* iwork, integer* info); + + void DLIB_FORTRAN_ID(sgesdd) (char const* jobz, + const integer* m, const integer* n, float* a, const integer* lda, + float* s, float* u, const integer* ldu, + float* vt, const integer* ldvt, + float* work, const integer* lwork, integer* iwork, integer* info); + + } + + inline integer gesdd (const char jobz, + const integer m, const integer n, double* a, const integer lda, + double* s, double* u, const integer ldu, + double* vt, const integer ldvt, + double* work, const integer lwork, integer* iwork) + { + integer info = 0; + DLIB_FORTRAN_ID(dgesdd)(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info); + return info; + } + + inline integer gesdd (const char jobz, + const integer m, const integer n, float* a, const integer lda, + float* s, float* u, const integer ldu, + float* vt, const integer ldvt, + float* work, const integer lwork, integer* iwork) + { + integer info = 0; + DLIB_FORTRAN_ID(sgesdd)(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info); + return info; + } + } + + // ------------------------------------------------------------------------------------ + +/* -- LAPACK driver routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DGESDD computes the singular value decomposition (SVD) of a real */ +/* M-by-N matrix A, optionally computing the left and right singular */ +/* vectors. If singular vectors are desired, it uses a */ +/* divide-and-conquer algorithm. */ + +/* The SVD is written */ + +/* A = U * SIGMA * transpose(V) */ + +/* where SIGMA is an M-by-N matrix which is zero except for its */ +/* min(m,n) diagonal elements, U is an M-by-M orthogonal matrix, and */ +/* V is an N-by-N orthogonal matrix. The diagonal elements of SIGMA */ +/* are the singular values of A; they are real and non-negative, and */ +/* are returned in descending order. The first min(m,n) columns of */ +/* U and V are the left and right singular vectors of A. */ + +/* Note that the routine returns VT = V**T, not V. */ + +/* The divide and conquer algorithm makes very mild assumptions about */ +/* floating point arithmetic. It will work on machines with a guard */ +/* digit in add/subtract, or on those binary machines without guard */ +/* digits which subtract like the Cray X-MP, Cray Y-MP, Cray C-90, or */ +/* Cray-2. It could conceivably fail on hexadecimal or decimal machines */ +/* without guard digits, but we know of none. */ + +/* Arguments */ +/* ========= */ + +/* JOBZ (input) CHARACTER*1 */ +/* Specifies options for computing all or part of the matrix U: */ +/* = 'A': all M columns of U and all N rows of V**T are */ +/* returned in the arrays U and VT; */ +/* = 'S': the first min(M,N) columns of U and the first */ +/* min(M,N) rows of V**T are returned in the arrays U */ +/* and VT; */ +/* = 'O': If M >= N, the first N columns of U are overwritten */ +/* on the array A and all rows of V**T are returned in */ +/* the array VT; */ +/* otherwise, all columns of U are returned in the */ +/* array U and the first M rows of V**T are overwritten */ +/* in the array A; */ +/* = 'N': no columns of U or rows of V**T are computed. */ + +/* M (input) INTEGER */ +/* The number of rows of the input matrix A. M >= 0. */ + +/* N (input) INTEGER */ +/* The number of columns of the input matrix A. N >= 0. */ + +/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */ +/* On entry, the M-by-N matrix A. */ +/* On exit, */ +/* if JOBZ = 'O', A is overwritten with the first N columns */ +/* of U (the left singular vectors, stored */ +/* columnwise) if M >= N; */ +/* A is overwritten with the first M rows */ +/* of V**T (the right singular vectors, stored */ +/* rowwise) otherwise. */ +/* if JOBZ .ne. 'O', the contents of A are destroyed. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,M). */ + +/* S (output) DOUBLE PRECISION array, dimension (min(M,N)) */ +/* The singular values of A, sorted so that S(i) >= S(i+1). */ + +/* U (output) DOUBLE PRECISION array, dimension (LDU,UCOL) */ +/* UCOL = M if JOBZ = 'A' or JOBZ = 'O' and M < N; */ +/* UCOL = min(M,N) if JOBZ = 'S'. */ +/* If JOBZ = 'A' or JOBZ = 'O' and M < N, U contains the M-by-M */ +/* orthogonal matrix U; */ +/* if JOBZ = 'S', U contains the first min(M,N) columns of U */ +/* (the left singular vectors, stored columnwise); */ +/* if JOBZ = 'O' and M >= N, or JOBZ = 'N', U is not referenced. */ + +/* LDU (input) INTEGER */ +/* The leading dimension of the array U. LDU >= 1; if */ +/* JOBZ = 'S' or 'A' or JOBZ = 'O' and M < N, LDU >= M. */ + +/* VT (output) DOUBLE PRECISION array, dimension (LDVT,N) */ +/* If JOBZ = 'A' or JOBZ = 'O' and M >= N, VT contains the */ +/* N-by-N orthogonal matrix V**T; */ +/* if JOBZ = 'S', VT contains the first min(M,N) rows of */ +/* V**T (the right singular vectors, stored rowwise); */ +/* if JOBZ = 'O' and M < N, or JOBZ = 'N', VT is not referenced. */ + +/* LDVT (input) INTEGER */ +/* The leading dimension of the array VT. LDVT >= 1; if */ +/* JOBZ = 'A' or JOBZ = 'O' and M >= N, LDVT >= N; */ +/* if JOBZ = 'S', LDVT >= min(M,N). */ + +/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */ +/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK; */ + +/* LWORK (input) INTEGER */ +/* The dimension of the array WORK. LWORK >= 1. */ +/* If JOBZ = 'N', */ +/* LWORK >= 3*min(M,N) + max(max(M,N),7*min(M,N)). */ +/* If JOBZ = 'O', */ +/* LWORK >= 3*min(M,N)*min(M,N) + */ +/* max(max(M,N),5*min(M,N)*min(M,N)+4*min(M,N)). */ +/* If JOBZ = 'S' or 'A' */ +/* LWORK >= 3*min(M,N)*min(M,N) + */ +/* max(max(M,N),4*min(M,N)*min(M,N)+4*min(M,N)). */ +/* For good performance, LWORK should generally be larger. */ +/* If LWORK = -1 but other input arguments are legal, WORK(1) */ +/* returns the optimal LWORK. */ + +/* IWORK (workspace) INTEGER array, dimension (8*min(M,N)) */ + +/* INFO (output) INTEGER */ +/* = 0: successful exit. */ +/* < 0: if INFO = -i, the i-th argument had an illegal value. */ +/* > 0: DBDSDC did not converge, updating process failed. */ + +/* Further Details */ +/* =============== */ + +/* Based on contributions by */ +/* Ming Gu and Huan Ren, Computer Science Division, University of */ +/* California at Berkeley, USA */ + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, long NR3, long NR4, + long NC1, long NC2, long NC3, long NC4, + typename MM + > + int gesdd ( + const char jobz, + matrix& a, + matrix& s, + matrix& u, + matrix& vt + ) + { + matrix work; + matrix iwork; + + const long m = a.nr(); + const long n = a.nc(); + s.set_size(std::min(m,n), 1); + + // make sure the iwork memory block is big enough + if (iwork.size() < 8*std::min(m,n)) + iwork.set_size(8*std::min(m,n), 1); + + if (jobz == 'A') + { + u.set_size(m,m); + vt.set_size(n,n); + } + else if (jobz == 'S') + { + u.set_size(m, std::min(m,n)); + vt.set_size(std::min(m,n), n); + } + else if (jobz == 'O') + { + DLIB_CASSERT(false, "jobz == 'O' not supported"); + } + else + { + u.set_size(NR3?NR3:1, NC3?NC3:1); + vt.set_size(NR4?NR4:1, NC4?NC4:1); + } + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::gesdd(jobz, a.nr(), a.nc(), &a(0,0), a.nr(), + &s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(), + &work_size, -1, &iwork(0,0)); + + if (info != 0) + return info; + + // There is a bug in an older version of LAPACK in Debian etch + // that causes the gesdd to return the wrong value for work_size + // when jobz == 'N'. So verify the value of work_size. + if (jobz == 'N') + { + using std::min; + using std::max; + const T min_work_size = 3*min(m,n) + max(max(m,n),7*min(m,n)); + if (work_size < min_work_size) + work_size = min_work_size; + } + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual SVD + info = binding::gesdd(jobz, a.nr(), a.nc(), &a(0,0), a.nr(), + &s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(), + &work(0,0), work.size(), &iwork(0,0)); + + return info; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, long NR3, long NR4, + long NC1, long NC2, long NC3, long NC4, + typename MM + > + int gesdd ( + const char jobz, + matrix& a, + matrix& s, + matrix& u_, + matrix& vt_ + ) + { + matrix work; + matrix iwork; + + // Row major order matrices are transposed from LAPACK's point of view. + matrix& u = vt_; + matrix& vt = u_; + + + const long m = a.nc(); + const long n = a.nr(); + s.set_size(std::min(m,n), 1); + + // make sure the iwork memory block is big enough + if (iwork.size() < 8*std::min(m,n)) + iwork.set_size(8*std::min(m,n), 1); + + if (jobz == 'A') + { + u.set_size(m,m); + vt.set_size(n,n); + } + else if (jobz == 'S') + { + u.set_size(std::min(m,n), m); + vt.set_size(n, std::min(m,n)); + } + else if (jobz == 'O') + { + DLIB_CASSERT(false, "jobz == 'O' not supported"); + } + else + { + u.set_size(NR4?NR4:1, NC4?NC4:1); + vt.set_size(NR3?NR3:1, NC3?NC3:1); + } + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::gesdd(jobz, m, n, &a(0,0), a.nc(), + &s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(), + &work_size, -1, &iwork(0,0)); + + if (info != 0) + return info; + + // There is a bug in an older version of LAPACK in Debian etch + // that causes the gesdd to return the wrong value for work_size + // when jobz == 'N'. So verify the value of work_size. + if (jobz == 'N') + { + using std::min; + using std::max; + const T min_work_size = 3*min(m,n) + max(max(m,n),7*min(m,n)); + if (work_size < min_work_size) + work_size = min_work_size; + } + + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual SVD + info = binding::gesdd(jobz, m, n, &a(0,0), a.nc(), + &s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(), + &work(0,0), work.size(), &iwork(0,0)); + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_SDD_Hh_ + + diff --git a/ml/dlib/dlib/matrix/lapack/gesvd.h b/ml/dlib/dlib/matrix/lapack/gesvd.h new file mode 100644 index 000000000..e00654db6 --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/gesvd.h @@ -0,0 +1,323 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_SVD_Hh_ +#define DLIB_LAPACk_SVD_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dgesvd) (const char* jobu, const char* jobvt, + const integer* m, const integer* n, double* a, const integer* lda, + double* s, double* u, const integer* ldu, + double* vt, const integer* ldvt, + double* work, const integer* lwork, integer* info); + + void DLIB_FORTRAN_ID(sgesvd) (const char* jobu, const char* jobvt, + const integer* m, const integer* n, float* a, const integer* lda, + float* s, float* u, const integer* ldu, + float* vt, const integer* ldvt, + float* work, const integer* lwork, integer* info); + + } + + inline integer gesvd (const char jobu, const char jobvt, + const integer m, const integer n, double* a, const integer lda, + double* s, double* u, const integer ldu, + double* vt, const integer ldvt, + double* work, const integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(dgesvd)(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, &info); + return info; + } + + inline integer gesvd (const char jobu, const char jobvt, + const integer m, const integer n, float* a, const integer lda, + float* s, float* u, const integer ldu, + float* vt, const integer ldvt, + float* work, const integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(sgesvd)(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, &info); + return info; + } + + } + + // ------------------------------------------------------------------------------------ + +/* -- LAPACK driver routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DGESVD computes the singular value decomposition (SVD) of a real */ +/* M-by-N matrix A, optionally computing the left and/or right singular */ +/* vectors. The SVD is written */ + +/* A = U * SIGMA * transpose(V) */ + +/* where SIGMA is an M-by-N matrix which is zero except for its */ +/* min(m,n) diagonal elements, U is an M-by-M orthogonal matrix, and */ +/* V is an N-by-N orthogonal matrix. The diagonal elements of SIGMA */ +/* are the singular values of A; they are real and non-negative, and */ +/* are returned in descending order. The first min(m,n) columns of */ +/* U and V are the left and right singular vectors of A. */ + +/* Note that the routine returns V**T, not V. */ + +/* Arguments */ +/* ========= */ + +/* JOBU (input) CHARACTER*1 */ +/* Specifies options for computing all or part of the matrix U: */ +/* = 'A': all M columns of U are returned in array U: */ +/* = 'S': the first min(m,n) columns of U (the left singular */ +/* vectors) are returned in the array U; */ +/* = 'O': the first min(m,n) columns of U (the left singular */ +/* vectors) are overwritten on the array A; */ +/* = 'N': no columns of U (no left singular vectors) are */ +/* computed. */ + +/* JOBVT (input) CHARACTER*1 */ +/* Specifies options for computing all or part of the matrix */ +/* V**T: */ +/* = 'A': all N rows of V**T are returned in the array VT; */ +/* = 'S': the first min(m,n) rows of V**T (the right singular */ +/* vectors) are returned in the array VT; */ +/* = 'O': the first min(m,n) rows of V**T (the right singular */ +/* vectors) are overwritten on the array A; */ +/* = 'N': no rows of V**T (no right singular vectors) are */ +/* computed. */ + +/* JOBVT and JOBU cannot both be 'O'. */ + +/* M (input) INTEGER */ +/* The number of rows of the input matrix A. M >= 0. */ + +/* N (input) INTEGER */ +/* The number of columns of the input matrix A. N >= 0. */ + +/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */ +/* On entry, the M-by-N matrix A. */ +/* On exit, */ +/* if JOBU = 'O', A is overwritten with the first min(m,n) */ +/* columns of U (the left singular vectors, */ +/* stored columnwise); */ +/* if JOBVT = 'O', A is overwritten with the first min(m,n) */ +/* rows of V**T (the right singular vectors, */ +/* stored rowwise); */ +/* if JOBU .ne. 'O' and JOBVT .ne. 'O', the contents of A */ +/* are destroyed. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,M). */ + +/* S (output) DOUBLE PRECISION array, dimension (min(M,N)) */ +/* The singular values of A, sorted so that S(i) >= S(i+1). */ + +/* U (output) DOUBLE PRECISION array, dimension (LDU,UCOL) */ +/* (LDU,M) if JOBU = 'A' or (LDU,min(M,N)) if JOBU = 'S'. */ +/* If JOBU = 'A', U contains the M-by-M orthogonal matrix U; */ +/* if JOBU = 'S', U contains the first min(m,n) columns of U */ +/* (the left singular vectors, stored columnwise); */ +/* if JOBU = 'N' or 'O', U is not referenced. */ + +/* LDU (input) INTEGER */ +/* The leading dimension of the array U. LDU >= 1; if */ +/* JOBU = 'S' or 'A', LDU >= M. */ + +/* VT (output) DOUBLE PRECISION array, dimension (LDVT,N) */ +/* If JOBVT = 'A', VT contains the N-by-N orthogonal matrix */ +/* V**T; */ +/* if JOBVT = 'S', VT contains the first min(m,n) rows of */ +/* V**T (the right singular vectors, stored rowwise); */ +/* if JOBVT = 'N' or 'O', VT is not referenced. */ + +/* LDVT (input) INTEGER */ +/* The leading dimension of the array VT. LDVT >= 1; if */ +/* JOBVT = 'A', LDVT >= N; if JOBVT = 'S', LDVT >= min(M,N). */ + +/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */ +/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK; */ +/* if INFO > 0, WORK(2:MIN(M,N)) contains the unconverged */ +/* superdiagonal elements of an upper bidiagonal matrix B */ +/* whose diagonal is in S (not necessarily sorted). B */ +/* satisfies A = U * B * VT, so it has the same singular values */ +/* as A, and singular vectors related by U and VT. */ + +/* LWORK (input) INTEGER */ +/* The dimension of the array WORK. */ +/* LWORK >= MAX(1,3*MIN(M,N)+MAX(M,N),5*MIN(M,N)). */ +/* For good performance, LWORK should generally be larger. */ + +/* If LWORK = -1, then a workspace query is assumed; the routine */ +/* only calculates the optimal size of the WORK array, returns */ +/* this value as the first entry of the WORK array, and no error */ +/* message related to LWORK is issued by XERBLA. */ + +/* INFO (output) INTEGER */ +/* = 0: successful exit. */ +/* < 0: if INFO = -i, the i-th argument had an illegal value. */ +/* > 0: if DBDSQR did not converge, INFO specifies how many */ +/* superdiagonals of an intermediate bidiagonal form B */ +/* did not converge to zero. See the description of WORK */ +/* above for details. */ + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, long NR3, long NR4, + long NC1, long NC2, long NC3, long NC4, + typename MM + > + int gesvd ( + const char jobu, + const char jobvt, + matrix& a, + matrix& s, + matrix& u, + matrix& vt + ) + { + matrix work; + + const long m = a.nr(); + const long n = a.nc(); + s.set_size(std::min(m,n), 1); + + if (jobu == 'A') + u.set_size(m,m); + else if (jobu == 'S') + u.set_size(m, std::min(m,n)); + else + u.set_size(NR3?NR3:1, NC3?NC3:1); + + if (jobvt == 'A') + vt.set_size(n,n); + else if (jobvt == 'S') + vt.set_size(std::min(m,n), n); + else + vt.set_size(NR4?NR4:1, NC4?NC4:1); + + + if (jobu == 'O' || jobvt == 'O') + { + DLIB_CASSERT(false, "job == 'O' not supported"); + } + + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::gesvd(jobu, jobvt, a.nr(), a.nc(), &a(0,0), a.nr(), + &s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(), + &work_size, -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual SVD + info = binding::gesvd(jobu, jobvt, a.nr(), a.nc(), &a(0,0), a.nr(), + &s(0,0), &u(0,0), u.nr(), &vt(0,0), vt.nr(), + &work(0,0), work.size()); + + return info; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, long NR3, long NR4, + long NC1, long NC2, long NC3, long NC4, + typename MM + > + int gesvd ( + char jobu, + char jobvt, + matrix& a, + matrix& s, + matrix& u_, + matrix& vt_ + ) + { + matrix work; + + // Row major order matrices are transposed from LAPACK's point of view. + matrix& u = vt_; + matrix& vt = u_; + std::swap(jobu, jobvt); + + const long m = a.nc(); + const long n = a.nr(); + s.set_size(std::min(m,n), 1); + + if (jobu == 'A') + u.set_size(m,m); + else if (jobu == 'S') + u.set_size(std::min(m,n), m); + else + u.set_size(NR4?NR4:1, NC4?NC4:1); + + if (jobvt == 'A') + vt.set_size(n,n); + else if (jobvt == 'S') + vt.set_size(n, std::min(m,n)); + else + vt.set_size(NR3?NR3:1, NC3?NC3:1); + + if (jobu == 'O' || jobvt == 'O') + { + DLIB_CASSERT(false, "job == 'O' not supported"); + } + + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::gesvd(jobu, jobvt, m, n, &a(0,0), a.nc(), + &s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(), + &work_size, -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual SVD + info = binding::gesvd(jobu, jobvt, m, n, &a(0,0), a.nc(), + &s(0,0), &u(0,0), u.nc(), &vt(0,0), vt.nc(), + &work(0,0), work.size()); + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_SVD_Hh_ + diff --git a/ml/dlib/dlib/matrix/lapack/getrf.h b/ml/dlib/dlib/matrix/lapack/getrf.h new file mode 100644 index 000000000..a1f0b139d --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/getrf.h @@ -0,0 +1,132 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_GETRF_Hh_ +#define DLIB_LAPACk_GETRF_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dgetrf) (integer* m, integer *n, double *a, + integer* lda, integer *ipiv, integer *info); + + void DLIB_FORTRAN_ID(sgetrf) (integer* m, integer *n, float *a, + integer* lda, integer *ipiv, integer *info); + + } + + inline int getrf (integer m, integer n, double *a, + integer lda, integer *ipiv) + { + integer info = 0; + DLIB_FORTRAN_ID(dgetrf)(&m, &n, a, &lda, ipiv, &info); + return info; + } + + inline int getrf (integer m, integer n, float *a, + integer lda, integer *ipiv) + { + integer info = 0; + DLIB_FORTRAN_ID(sgetrf)(&m, &n, a, &lda, ipiv, &info); + return info; + } + + + } + + // ------------------------------------------------------------------------------------ + + +/* -- LAPACK routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DGETRF computes an LU factorization of a general M-by-N matrix A */ +/* using partial pivoting with row interchanges. */ + +/* The factorization has the form */ +/* A = P * L * U */ +/* where P is a permutation matrix, L is lower triangular with unit */ +/* diagonal elements (lower trapezoidal if m > n), and U is upper */ +/* triangular (upper trapezoidal if m < n). */ + +/* This is the right-looking Level 3 BLAS version of the algorithm. */ + +/* Arguments */ +/* ========= */ + +/* M (input) INTEGER */ +/* The number of rows of the matrix A. M >= 0. */ + +/* N (input) INTEGER */ +/* The number of columns of the matrix A. N >= 0. */ + +/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */ +/* On entry, the M-by-N matrix to be factored. */ +/* On exit, the factors L and U from the factorization */ +/* A = P*L*U; the unit diagonal elements of L are not stored. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,M). */ + +/* IPIV (output) INTEGER array, dimension (min(M,N)) */ +/* The pivot indices; for 1 <= i <= min(M,N), row i of the */ +/* matrix was interchanged with row IPIV(i). */ + +/* INFO (output) INTEGER */ +/* = 0: successful exit */ +/* < 0: if INFO = -i, the i-th argument had an illegal value */ +/* > 0: if INFO = i, U(i,i) is exactly zero. The factorization */ +/* has been completed, but the factor U is exactly */ +/* singular, and division by zero will occur if it is used */ +/* to solve a system of equations. */ + + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, + long NC1, long NC2, + typename MM, + typename layout + > + int getrf ( + matrix& a, + matrix& ipiv + ) + { + const long m = a.nr(); + const long n = a.nc(); + + ipiv.set_size(std::min(m,n), 1); + + // compute the actual decomposition + return binding::getrf(m, n, &a(0,0), a.nr(), &ipiv(0,0)); + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_GETRF_Hh_ + diff --git a/ml/dlib/dlib/matrix/lapack/ormqr.h b/ml/dlib/dlib/matrix/lapack/ormqr.h new file mode 100644 index 000000000..ab66ff4d2 --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/ormqr.h @@ -0,0 +1,224 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_ORMQR_Hh_ +#define DLIB_LAPACk_ORMQR_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dormqr) (char *side, char *trans, integer *m, integer *n, + integer *k, const double *a, integer *lda, const double *tau, + double * c_, integer *ldc, double *work, integer *lwork, + integer *info); + + void DLIB_FORTRAN_ID(sormqr) (char *side, char *trans, integer *m, integer *n, + integer *k, const float *a, integer *lda, const float *tau, + float * c_, integer *ldc, float *work, integer *lwork, + integer *info); + + } + + inline int ormqr (char side, char trans, integer m, integer n, + integer k, const double *a, integer lda, const double *tau, + double *c_, integer ldc, double *work, integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(dormqr)(&side, &trans, &m, &n, + &k, a, &lda, tau, + c_, &ldc, work, &lwork, &info); + return info; + } + + inline int ormqr (char side, char trans, integer m, integer n, + integer k, const float *a, integer lda, const float *tau, + float *c_, integer ldc, float *work, integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(sormqr)(&side, &trans, &m, &n, + &k, a, &lda, tau, + c_, &ldc, work, &lwork, &info); + return info; + } + + + + } + + // ------------------------------------------------------------------------------------ + +/* -- LAPACK routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DORMQR overwrites the general real M-by-N matrix C with */ + +/* SIDE = 'L' SIDE = 'R' */ +/* TRANS = 'N': Q * C C * Q */ +/* TRANS = 'T': Q**T * C C * Q**T */ + +/* where Q is a real orthogonal matrix defined as the product of k */ +/* elementary reflectors */ + +/* Q = H(1) H(2) . . . H(k) */ + +/* as returned by DGEQRF. Q is of order M if SIDE = 'L' and of order N */ +/* if SIDE = 'R'. */ + +/* Arguments */ +/* ========= */ + +/* SIDE (input) CHARACTER*1 */ +/* = 'L': apply Q or Q**T from the Left; */ +/* = 'R': apply Q or Q**T from the Right. */ + +/* TRANS (input) CHARACTER*1 */ +/* = 'N': No transpose, apply Q; */ +/* = 'T': Transpose, apply Q**T. */ + +/* M (input) INTEGER */ +/* The number of rows of the matrix C. M >= 0. */ + +/* N (input) INTEGER */ +/* The number of columns of the matrix C. N >= 0. */ + +/* K (input) INTEGER */ +/* The number of elementary reflectors whose product defines */ +/* the matrix Q. */ +/* If SIDE = 'L', M >= K >= 0; */ +/* if SIDE = 'R', N >= K >= 0. */ + +/* A (input) DOUBLE PRECISION array, dimension (LDA,K) */ +/* The i-th column must contain the vector which defines the */ +/* elementary reflector H(i), for i = 1,2,...,k, as returned by */ +/* DGEQRF in the first k columns of its array argument A. */ +/* A is modified by the routine but restored on exit. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. */ +/* If SIDE = 'L', LDA >= max(1,M); */ +/* if SIDE = 'R', LDA >= max(1,N). */ + +/* TAU (input) DOUBLE PRECISION array, dimension (K) */ +/* TAU(i) must contain the scalar factor of the elementary */ +/* reflector H(i), as returned by DGEQRF. */ + +/* C (input/output) DOUBLE PRECISION array, dimension (LDC,N) */ +/* On entry, the M-by-N matrix C. */ +/* On exit, C is overwritten by Q*C or Q**T*C or C*Q**T or C*Q. */ + +/* LDC (input) INTEGER */ +/* The leading dimension of the array C. LDC >= max(1,M). */ + +/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */ +/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */ + +/* LWORK (input) INTEGER */ +/* The dimension of the array WORK. */ +/* If SIDE = 'L', LWORK >= max(1,N); */ +/* if SIDE = 'R', LWORK >= max(1,M). */ +/* For optimum performance LWORK >= N*NB if SIDE = 'L', and */ +/* LWORK >= M*NB if SIDE = 'R', where NB is the optimal */ +/* blocksize. */ + +/* If LWORK = -1, then a workspace query is assumed; the routine */ +/* only calculates the optimal size of the WORK array, returns */ +/* this value as the first entry of the WORK array, and no error */ +/* message related to LWORK is issued by XERBLA. */ + +/* INFO (output) INTEGER */ +/* = 0: successful exit */ +/* < 0: if INFO = -i, the i-th argument had an illegal value */ + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, long NR3, + long NC1, long NC2, long NC3, + typename MM, + typename C_LAYOUT + > + int ormqr ( + char side, + char trans, + const matrix& a, + const matrix& tau, + matrix& c + ) + { + long m = c.nr(); + long n = c.nc(); + const long k = a.nc(); + long ldc; + if (is_same_type::value) + { + ldc = c.nr(); + } + else + { + // Since lapack expects c to be in column major layout we have to + // do something to make this work. Since a row major layout matrix + // will look just like a transposed C we can just swap a few things around. + + ldc = c.nc(); + swap(m,n); + + if (side == 'L') + side = 'R'; + else + side = 'L'; + + if (trans == 'T') + trans = 'N'; + else + trans = 'T'; + } + + matrix work; + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::ormqr(side, trans, m, n, + k, &a(0,0), a.nr(), &tau(0,0), + &c(0,0), ldc, &work_size, -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual result + info = binding::ormqr(side, trans, m, n, + k, &a(0,0), a.nr(), &tau(0,0), + &c(0,0), ldc, &work(0,0), work.size()); + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_ORMQR_Hh_ + diff --git a/ml/dlib/dlib/matrix/lapack/pbtrf.h b/ml/dlib/dlib/matrix/lapack/pbtrf.h new file mode 100644 index 000000000..23bcc127b --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/pbtrf.h @@ -0,0 +1,178 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_BDC_Hh_ +#define DLIB_LAPACk_BDC_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dpbtrf) (const char* uplo, const integer* n, const integer* kd, + double* ab, const integer* ldab, integer* info); + + void DLIB_FORTRAN_ID(spbtrf) (const char* uplo, const integer* n, const integer* kd, + float* ab, const integer* ldab, integer* info); + + } + + inline integer pbtrf (const char uplo, const integer n, const integer kd, + double* ab, const integer ldab) + { + integer info = 0; + DLIB_FORTRAN_ID(dpbtrf)(&uplo, &n, &kd, ab, &ldab, &info); + return info; + } + + inline integer pbtrf (const char uplo, const integer n, const integer kd, + float* ab, const integer ldab) + { + integer info = 0; + DLIB_FORTRAN_ID(spbtrf)(&uplo, &n, &kd, ab, &ldab, &info); + return info; + } + } + + // ------------------------------------------------------------------------------------ +/* DPBTRF(l) LAPACK routine (version 1.1) DPBTRF(l) + +NAME + DPBTRF - compute the Cholesky factorization of a real symmetric positive + definite band matrix A + +SYNOPSIS + + SUBROUTINE DPBTRF( UPLO, N, KD, AB, LDAB, INFO ) + + CHARACTER UPLO + + INTEGER INFO, KD, LDAB, N + + DOUBLE PRECISION AB( LDAB, * ) + +PURPOSE + DPBTRF computes the Cholesky factorization of a real symmetric positive + definite band matrix A. + + The factorization has the form + A = U**T * U, if UPLO = 'U', or + A = L * L**T, if UPLO = 'L', + where U is an upper triangular matrix and L is lower triangular. + +ARGUMENTS + + UPLO (input) CHARACTER*1 + = 'U': Upper triangle of A is stored; + = 'L': Lower triangle of A is stored. + + N (input) INTEGER + The order of the matrix A. N >= 0. + + KD (input) INTEGER + The number of superdiagonals of the matrix A if UPLO = 'U', or the + number of subdiagonals if UPLO = 'L'. KD >= 0. + + AB (input/output) DOUBLE PRECISION array, dimension (LDAB,N) + On entry, the upper or lower triangle of the symmetric band matrix + A, stored in the first KD+1 rows of the array. The j-th column of + A is stored in the j-th column of the array AB as follows: if UPLO + = 'U', AB(kd+1+i-j,j) = A(i,j) for max(1,j-kd)<=i<=j; if UPLO = + 'L', AB(1+i-j,j) = A(i,j) for j<=i<=min(n,j+kd). + + On exit, if INFO = 0, the triangular factor U or L from the Chole- + sky factorization A = U**T*U or A = L*L**T of the band matrix A, in + the same storage format as A. + + LDAB (input) INTEGER + The leading dimension of the array AB. LDAB >= KD+1. + + INFO (output) INTEGER + = 0: successful exit + < 0: if INFO = -i, the i-th argument had an illegal value + > 0: if INFO = i, the leading minor of order i is not positive + definite, and the factorization could not be completed. + +FURTHER DETAILS + The band storage scheme is illustrated by the following example, when N = + 6, KD = 2, and UPLO = 'U': + + On entry: On exit: + + * * a13 a24 a35 a46 * * u13 u24 u35 u46 + * a12 a23 a34 a45 a56 * u12 u23 u34 u45 u56 + a11 a22 a33 a44 a55 a66 u11 u22 u33 u44 u55 u66 + + Similarly, if UPLO = 'L' the format of A is as follows: + + On entry: On exit: + + a11 a22 a33 a44 a55 a66 l11 l22 l33 l44 l55 l66 + a21 a32 a43 a54 a65 * l21 l32 l43 l54 l65 * + a31 a42 a53 a64 * * l31 l42 l53 l64 * * + + Array elements marked * are not used by the routine. + + Contributed by + Peter Mayes and Giuseppe Radicati, IBM ECSEC, Rome, March 23, 1989 */ + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NC1, + typename MM + > + int pbtrf ( + char uplo, matrix& ab + ) + { + const long ldab = ab.nr(); + const long n = ab.nc(); + const long kd = ldab - 1; // assume fully packed + + int info = binding::pbtrf(uplo, n, kd, &ab(0,0), ldab); + + return info; + } + + // ------------------------------------------------------------------------------------ + + + template < + typename T, + long NR1, long NC1, + typename MM + > + int pbtrf ( + char uplo, matrix& ab + ) + { + const long ldab = ab.nr(); + const long n = ab.nc(); + const long kd = ldab - 1; // assume fully packed + + matrix abt = trans(ab); + + int info = binding::pbtrf(uplo, n, kd, &abt(0,0), ldab); + + ab = trans(abt); + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_BDC_Hh_ + + diff --git a/ml/dlib/dlib/matrix/lapack/potrf.h b/ml/dlib/dlib/matrix/lapack/potrf.h new file mode 100644 index 000000000..b9d6a7cc8 --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/potrf.h @@ -0,0 +1,174 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_POTRF_Hh_ +#define DLIB_LAPACk_POTRF_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dpotrf) (char *uplo, integer *n, double *a, + integer* lda, integer *info); + + void DLIB_FORTRAN_ID(spotrf) (char *uplo, integer *n, float *a, + integer* lda, integer *info); + + } + + inline int potrf (char uplo, integer n, double *a, integer lda) + { + integer info = 0; + DLIB_FORTRAN_ID(dpotrf)(&uplo, &n, a, &lda, &info); + return info; + } + + inline int potrf (char uplo, integer n, float *a, integer lda) + { + integer info = 0; + DLIB_FORTRAN_ID(spotrf)(&uplo, &n, a, &lda, &info); + return info; + } + + } + + // ------------------------------------------------------------------------------------ + +/* -- LAPACK routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DPOTRF computes the Cholesky factorization of a real symmetric */ +/* positive definite matrix A. */ + +/* The factorization has the form */ +/* A = U**T * U, if UPLO = 'U', or */ +/* A = L * L**T, if UPLO = 'L', */ +/* where U is an upper triangular matrix and L is lower triangular. */ + +/* This is the block version of the algorithm, calling Level 3 BLAS. */ + +/* Arguments */ +/* ========= */ + +/* UPLO (input) CHARACTER*1 */ +/* = 'U': Upper triangle of A is stored; */ +/* = 'L': Lower triangle of A is stored. */ + +/* N (input) INTEGER */ +/* The order of the matrix A. N >= 0. */ + +/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */ +/* On entry, the symmetric matrix A. If UPLO = 'U', the leading */ +/* N-by-N upper triangular part of A contains the upper */ +/* triangular part of the matrix A, and the strictly lower */ +/* triangular part of A is not referenced. If UPLO = 'L', the */ +/* leading N-by-N lower triangular part of A contains the lower */ +/* triangular part of the matrix A, and the strictly upper */ +/* triangular part of A is not referenced. */ + +/* On exit, if INFO = 0, the factor U or L from the Cholesky */ +/* factorization A = U**T*U or A = L*L**T. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,N). */ + +/* INFO (output) INTEGER */ +/* = 0: successful exit */ +/* < 0: if INFO = -i, the i-th argument had an illegal value */ +/* > 0: if INFO = i, the leading minor of order i is not */ +/* positive definite, and the factorization could not be */ +/* completed. */ + + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, + long NC1, + typename MM + > + int potrf ( + char uplo, + matrix& a + ) + { + // compute the actual decomposition + int info = binding::potrf(uplo, a.nr(), &a(0,0), a.nr()); + + // If it fails part way though the factorization then make sure + // the end of the matrix gets properly initialized with zeros. + if (info > 0) + { + if (uplo == 'L') + set_colm(a, range(info-1, a.nc()-1)) = 0; + else + set_rowm(a, range(info-1, a.nr()-1)) = 0; + } + + return info; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, + long NC1, + typename MM + > + int potrf ( + char uplo, + matrix& a + ) + { + // since we are working on a row major order matrix we need to ask + // LAPACK for the transpose of whatever the user asked for. + + if (uplo == 'L') + uplo = 'U'; + else + uplo = 'L'; + + // compute the actual decomposition + int info = binding::potrf(uplo, a.nr(), &a(0,0), a.nr()); + + // If it fails part way though the factorization then make sure + // the end of the matrix gets properly initialized with zeros. + if (info > 0) + { + if (uplo == 'U') + set_colm(a, range(info-1, a.nc()-1)) = 0; + else + set_rowm(a, range(info-1, a.nr()-1)) = 0; + } + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_POTRF_Hh_ + + diff --git a/ml/dlib/dlib/matrix/lapack/syev.h b/ml/dlib/dlib/matrix/lapack/syev.h new file mode 100644 index 000000000..0c9fd251a --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/syev.h @@ -0,0 +1,218 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_EV_Hh_ +#define DLIB_LAPACk_EV_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dsyev) (char *jobz, char *uplo, integer *n, double *a, + integer *lda, double *w, double *work, integer *lwork, + integer *info); + + void DLIB_FORTRAN_ID(ssyev) (char *jobz, char *uplo, integer *n, float *a, + integer *lda, float *w, float *work, integer *lwork, + integer *info); + + } + + inline int syev (char jobz, char uplo, integer n, double *a, + integer lda, double *w, double *work, integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(dsyev)(&jobz, &uplo, &n, a, + &lda, w, work, &lwork, &info); + return info; + } + + inline int syev (char jobz, char uplo, integer n, float *a, + integer lda, float *w, float *work, integer lwork) + { + integer info = 0; + DLIB_FORTRAN_ID(ssyev)(&jobz, &uplo, &n, a, + &lda, w, work, &lwork, &info); + return info; + } + + + } + + // ------------------------------------------------------------------------------------ + +/* -- LAPACK driver routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DSYEV computes all eigenvalues and, optionally, eigenvectors of a */ +/* real symmetric matrix A. */ + +/* Arguments */ +/* ========= */ + +/* JOBZ (input) CHARACTER*1 */ +/* = 'N': Compute eigenvalues only; */ +/* = 'V': Compute eigenvalues and eigenvectors. */ + +/* UPLO (input) CHARACTER*1 */ +/* = 'U': Upper triangle of A is stored; */ +/* = 'L': Lower triangle of A is stored. */ + +/* N (input) INTEGER */ +/* The order of the matrix A. N >= 0. */ + +/* A (input/output) DOUBLE PRECISION array, dimension (LDA, N) */ +/* On entry, the symmetric matrix A. If UPLO = 'U', the */ +/* leading N-by-N upper triangular part of A contains the */ +/* upper triangular part of the matrix A. If UPLO = 'L', */ +/* the leading N-by-N lower triangular part of A contains */ +/* the lower triangular part of the matrix A. */ +/* On exit, if JOBZ = 'V', then if INFO = 0, A contains the */ +/* orthonormal eigenvectors of the matrix A. */ +/* If JOBZ = 'N', then on exit the lower triangle (if UPLO='L') */ +/* or the upper triangle (if UPLO='U') of A, including the */ +/* diagonal, is destroyed. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,N). */ + +/* W (output) DOUBLE PRECISION array, dimension (N) */ +/* If INFO = 0, the eigenvalues in ascending order. */ + +/* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) */ +/* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. */ + +/* LWORK (input) INTEGER */ +/* The length of the array WORK. LWORK >= max(1,3*N-1). */ +/* For optimal efficiency, LWORK >= (NB+2)*N, */ +/* where NB is the blocksize for DSYTRD returned by ILAENV. */ + +/* If LWORK = -1, then a workspace query is assumed; the routine */ +/* only calculates the optimal size of the WORK array, returns */ +/* this value as the first entry of the WORK array, and no error */ +/* message related to LWORK is issued by XERBLA. */ + +/* INFO (output) INTEGER */ +/* = 0: successful exit */ +/* < 0: if INFO = -i, the i-th argument had an illegal value */ +/* > 0: if INFO = i, the algorithm failed to converge; i */ +/* off-diagonal elements of an intermediate tridiagonal */ +/* form did not converge to zero. */ + + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, + long NC1, long NC2, + typename MM + > + int syev ( + const char jobz, + const char uplo, + matrix& a, + matrix& w + ) + { + matrix work; + + const long n = a.nr(); + + w.set_size(n,1); + + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::syev(jobz, uplo, n, &a(0,0), + a.nr(), &w(0,0), &work_size, -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual decomposition + info = binding::syev(jobz, uplo, n, &a(0,0), + a.nr(), &w(0,0), &work(0,0), work.size()); + + return info; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, + long NC1, long NC2, + typename MM + > + int syev ( + char jobz, + char uplo, + matrix& a, + matrix& w + ) + { + matrix work; + + if (uplo == 'L') + uplo = 'U'; + else + uplo = 'L'; + + const long n = a.nr(); + + w.set_size(n,1); + + + // figure out how big the workspace needs to be. + T work_size = 1; + int info = binding::syev(jobz, uplo, n, &a(0,0), + a.nc(), &w(0,0), &work_size, -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + + // compute the actual decomposition + info = binding::syev(jobz, uplo, n, &a(0,0), + a.nc(), &w(0,0), &work(0,0), work.size()); + + + a = trans(a); + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_EV_Hh_ + + + + diff --git a/ml/dlib/dlib/matrix/lapack/syevr.h b/ml/dlib/dlib/matrix/lapack/syevr.h new file mode 100644 index 000000000..65190b3d8 --- /dev/null +++ b/ml/dlib/dlib/matrix/lapack/syevr.h @@ -0,0 +1,445 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LAPACk_EVR_Hh_ +#define DLIB_LAPACk_EVR_Hh_ + +#include "fortran_id.h" +#include "../matrix.h" + +namespace dlib +{ + namespace lapack + { + namespace binding + { + extern "C" + { + void DLIB_FORTRAN_ID(dsyevr) (char *jobz, char *range, char *uplo, integer *n, + double *a, integer *lda, double *vl, double *vu, integer * il, + integer *iu, double *abstol, integer *m, double *w, + double *z_, integer *ldz, integer *isuppz, double *work, + integer *lwork, integer *iwork, integer *liwork, integer *info); + + void DLIB_FORTRAN_ID(ssyevr) (char *jobz, char *range, char *uplo, integer *n, + float *a, integer *lda, float *vl, float *vu, integer * il, + integer *iu, float *abstol, integer *m, float *w, + float *z_, integer *ldz, integer *isuppz, float *work, + integer *lwork, integer *iwork, integer *liwork, integer *info); + } + + inline int syevr (char jobz, char range, char uplo, integer n, + double* a, integer lda, double vl, double vu, integer il, + integer iu, double abstol, integer *m, double *w, + double *z, integer ldz, integer *isuppz, double *work, + integer lwork, integer *iwork, integer liwork) + { + integer info = 0; + DLIB_FORTRAN_ID(dsyevr)(&jobz, &range, &uplo, &n, + a, &lda, &vl, &vu, &il, + &iu, &abstol, m, w, + z, &ldz, isuppz, work, + &lwork, iwork, &liwork, &info); + return info; + } + + inline int syevr (char jobz, char range, char uplo, integer n, + float* a, integer lda, float vl, float vu, integer il, + integer iu, float abstol, integer *m, float *w, + float *z, integer ldz, integer *isuppz, float *work, + integer lwork, integer *iwork, integer liwork) + { + integer info = 0; + DLIB_FORTRAN_ID(ssyevr)(&jobz, &range, &uplo, &n, + a, &lda, &vl, &vu, &il, + &iu, &abstol, m, w, + z, &ldz, isuppz, work, + &lwork, iwork, &liwork, &info); + return info; + } + + } + + // ------------------------------------------------------------------------------------ + + /* + +* -- LAPACK driver routine (version 3.1) -- +* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. +* November 2006 +* +* .. Scalar Arguments .. + CHARACTER JOBZ, RANGE, UPLO + INTEGER IL, INFO, IU, LDA, LDZ, LIWORK, LWORK, M, N + DOUBLE PRECISION ABSTOL, VL, VU +* .. +* .. Array Arguments .. + INTEGER ISUPPZ( * ), IWORK( * ) + DOUBLE PRECISION A( LDA, * ), W( * ), WORK( * ), Z( LDZ, * ) +* .. +* +* Purpose +* ======= +* +* DSYEVR computes selected eigenvalues and, optionally, eigenvectors +* of a real symmetric matrix A. Eigenvalues and eigenvectors can be +* selected by specifying either a range of values or a range of +* indices for the desired eigenvalues. +* +* DSYEVR first reduces the matrix A to tridiagonal form T with a call +* to DSYTRD. Then, whenever possible, DSYEVR calls DSTEMR to compute +* the eigenspectrum using Relatively Robust Representations. DSTEMR +* computes eigenvalues by the dqds algorithm, while orthogonal +* eigenvectors are computed from various "good" L D L^T representations +* (also known as Relatively Robust Representations). Gram-Schmidt +* orthogonalization is avoided as far as possible. More specifically, +* the various steps of the algorithm are as follows. +* +* For each unreduced block (submatrix) of T, +* (a) Compute T - sigma I = L D L^T, so that L and D +* define all the wanted eigenvalues to high relative accuracy. +* This means that small relative changes in the entries of D and L +* cause only small relative changes in the eigenvalues and +* eigenvectors. The standard (unfactored) representation of the +* tridiagonal matrix T does not have this property in general. +* (b) Compute the eigenvalues to suitable accuracy. +* If the eigenvectors are desired, the algorithm attains full +* accuracy of the computed eigenvalues only right before +* the corresponding vectors have to be computed, see steps c) and d). +* (c) For each cluster of close eigenvalues, select a new +* shift close to the cluster, find a new factorization, and refine +* the shifted eigenvalues to suitable accuracy. +* (d) For each eigenvalue with a large enough relative separation compute +* the corresponding eigenvector by forming a rank revealing twisted +* factorization. Go back to (c) for any clusters that remain. +* +* The desired accuracy of the output can be specified by the input +* parameter ABSTOL. +* +* For more details, see DSTEMR's documentation and: +* - Inderjit S. Dhillon and Beresford N. Parlett: "Multiple representations +* to compute orthogonal eigenvectors of symmetric tridiagonal matrices," +* Linear Algebra and its Applications, 387(1), pp. 1-28, August 2004. +* - Inderjit Dhillon and Beresford Parlett: "Orthogonal Eigenvectors and +* Relative Gaps," SIAM Journal on Matrix Analysis and Applications, Vol. 25, +* 2004. Also LAPACK Working Note 154. +* - Inderjit Dhillon: "A new O(n^2) algorithm for the symmetric +* tridiagonal eigenvalue/eigenvector problem", +* Computer Science Division Technical Report No. UCB/CSD-97-971, +* UC Berkeley, May 1997. +* +* +* Note 1 : DSYEVR calls DSTEMR when the full spectrum is requested +* on machines which conform to the ieee-754 floating point standard. +* DSYEVR calls DSTEBZ and SSTEIN on non-ieee machines and +* when partial spectrum requests are made. +* +* Normal execution of DSTEMR may create NaNs and infinities and +* hence may abort due to a floating point exception in environments +* which do not handle NaNs and infinities in the ieee standard default +* manner. +* +* Arguments +* ========= +* +* JOBZ (input) CHARACTER*1 +* = 'N': Compute eigenvalues only; +* = 'V': Compute eigenvalues and eigenvectors. +* +* RANGE (input) CHARACTER*1 +* = 'A': all eigenvalues will be found. +* = 'V': all eigenvalues in the half-open interval (VL,VU] +* will be found. +* = 'I': the IL-th through IU-th eigenvalues will be found. +********** For RANGE = 'V' or 'I' and IU - IL < N - 1, DSTEBZ and +********** DSTEIN are called +* +* UPLO (input) CHARACTER*1 +* = 'U': Upper triangle of A is stored; +* = 'L': Lower triangle of A is stored. +* +* N (input) INTEGER +* The order of the matrix A. N >= 0. +* +* A (input/output) DOUBLE PRECISION array, dimension (LDA, N) +* On entry, the symmetric matrix A. If UPLO = 'U', the +* leading N-by-N upper triangular part of A contains the +* upper triangular part of the matrix A. If UPLO = 'L', +* the leading N-by-N lower triangular part of A contains +* the lower triangular part of the matrix A. +* On exit, the lower triangle (if UPLO='L') or the upper +* triangle (if UPLO='U') of A, including the diagonal, is +* destroyed. +* +* LDA (input) INTEGER +* The leading dimension of the array A. LDA >= max(1,N). +* +* VL (input) DOUBLE PRECISION +* VU (input) DOUBLE PRECISION +* If RANGE='V', the lower and upper bounds of the interval to +* be searched for eigenvalues. VL < VU. +* Not referenced if RANGE = 'A' or 'I'. +* +* IL (input) INTEGER +* IU (input) INTEGER +* If RANGE='I', the indices (in ascending order) of the +* smallest and largest eigenvalues to be returned. +* 1 <= IL <= IU <= N, if N > 0; IL = 1 and IU = 0 if N = 0. +* Not referenced if RANGE = 'A' or 'V'. +* +* ABSTOL (input) DOUBLE PRECISION +* The absolute error tolerance for the eigenvalues. +* An approximate eigenvalue is accepted as converged +* when it is determined to lie in an interval [a,b] +* of width less than or equal to +* +* ABSTOL + EPS * max( |a|,|b| ) , +* +* where EPS is the machine precision. If ABSTOL is less than +* or equal to zero, then EPS*|T| will be used in its place, +* where |T| is the 1-norm of the tridiagonal matrix obtained +* by reducing A to tridiagonal form. +* +* See "Computing Small Singular Values of Bidiagonal Matrices +* with Guaranteed High Relative Accuracy," by Demmel and +* Kahan, LAPACK Working Note #3. +* +* If high relative accuracy is important, set ABSTOL to +* DLAMCH( 'Safe minimum' ). Doing so will guarantee that +* eigenvalues are computed to high relative accuracy when +* possible in future releases. The current code does not +* make any guarantees about high relative accuracy, but +* future releases will. See J. Barlow and J. Demmel, +* "Computing Accurate Eigensystems of Scaled Diagonally +* Dominant Matrices", LAPACK Working Note #7, for a discussion +* of which matrices define their eigenvalues to high relative +* accuracy. +* +* M (output) INTEGER +* The total number of eigenvalues found. 0 <= M <= N. +* If RANGE = 'A', M = N, and if RANGE = 'I', M = IU-IL+1. +* +* W (output) DOUBLE PRECISION array, dimension (N) +* The first M elements contain the selected eigenvalues in +* ascending order. +* +* Z (output) DOUBLE PRECISION array, dimension (LDZ, max(1,M)) +* If JOBZ = 'V', then if INFO = 0, the first M columns of Z +* contain the orthonormal eigenvectors of the matrix A +* corresponding to the selected eigenvalues, with the i-th +* column of Z holding the eigenvector associated with W(i). +* If JOBZ = 'N', then Z is not referenced. +* Note: the user must ensure that at least max(1,M) columns are +* supplied in the array Z; if RANGE = 'V', the exact value of M +* is not known in advance and an upper bound must be used. +* Supplying N columns is always safe. +* +* LDZ (input) INTEGER +* The leading dimension of the array Z. LDZ >= 1, and if +* JOBZ = 'V', LDZ >= max(1,N). +* +* ISUPPZ (output) INTEGER array, dimension ( 2*max(1,M) ) +* The support of the eigenvectors in Z, i.e., the indices +* indicating the nonzero elements in Z. The i-th eigenvector +* is nonzero only in elements ISUPPZ( 2*i-1 ) through +* ISUPPZ( 2*i ). +********** Implemented only for RANGE = 'A' or 'I' and IU - IL = N - 1 +* +* WORK (workspace/output) DOUBLE PRECISION array, dimension (MAX(1,LWORK)) +* On exit, if INFO = 0, WORK(1) returns the optimal LWORK. +* +* LWORK (input) INTEGER +* The dimension of the array WORK. LWORK >= max(1,26*N). +* For optimal efficiency, LWORK >= (NB+6)*N, +* where NB is the max of the blocksize for DSYTRD and DORMTR +* returned by ILAENV. +* +* If LWORK = -1, then a workspace query is assumed; the routine +* only calculates the optimal size of the WORK array, returns +* this value as the first entry of the WORK array, and no error +* message related to LWORK is issued by XERBLA. +* +* IWORK (workspace/output) INTEGER array, dimension (MAX(1,LIWORK)) +* On exit, if INFO = 0, IWORK(1) returns the optimal LWORK. +* +* LIWORK (input) INTEGER +* The dimension of the array IWORK. LIWORK >= max(1,10*N). +* +* If LIWORK = -1, then a workspace query is assumed; the +* routine only calculates the optimal size of the IWORK array, +* returns this value as the first entry of the IWORK array, and +* no error message related to LIWORK is issued by XERBLA. +* +* INFO (output) INTEGER +* = 0: successful exit +* < 0: if INFO = -i, the i-th argument had an illegal value +* > 0: Internal error +* +* Further Details +* =============== +* +* Based on contributions by +* Inderjit Dhillon, IBM Almaden, USA +* Osni Marques, LBNL/NERSC, USA +* Ken Stanley, Computer Science Division, University of +* California at Berkeley, USA +* Jason Riedy, Computer Science Division, University of +* California at Berkeley, USA +* +* ===================================================================== + + */ + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, long NR3, long NR4, + long NC1, long NC2, long NC3, long NC4, + typename MM + > + int syevr ( + const char jobz, + const char range, + const char uplo, + matrix& a, + const double vl, + const double vu, + const integer il, + const integer iu, + const double abstol, + integer& num_eigenvalues_found, + matrix& w, + matrix& z, + matrix& isuppz + ) + { + matrix work; + matrix iwork; + + const long n = a.nr(); + + w.set_size(n,1); + + isuppz.set_size(2*n, 1); + + if (jobz == 'V') + { + z.set_size(n,n); + } + else + { + z.set_size(NR3?NR3:1, NC3?NC3:1); + } + + // figure out how big the workspace needs to be. + T work_size = 1; + integer iwork_size = 1; + int info = binding::syevr(jobz, range, uplo, n, &a(0,0), + a.nr(), vl, vu, il, iu, abstol, &num_eigenvalues_found, + &w(0,0), &z(0,0), z.nr(), &isuppz(0,0), &work_size, -1, + &iwork_size, -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + if (iwork.size() < iwork_size) + iwork.set_size(iwork_size, 1); + + // compute the actual decomposition + info = binding::syevr(jobz, range, uplo, n, &a(0,0), + a.nr(), vl, vu, il, iu, abstol, &num_eigenvalues_found, + &w(0,0), &z(0,0), z.nr(), &isuppz(0,0), &work(0,0), work.size(), + &iwork(0,0), iwork.size()); + + + return info; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, long NR3, long NR4, + long NC1, long NC2, long NC3, long NC4, + typename MM + > + int syevr ( + const char jobz, + const char range, + char uplo, + matrix& a, + const double vl, + const double vu, + const integer il, + const integer iu, + const double abstol, + integer& num_eigenvalues_found, + matrix& w, + matrix& z, + matrix& isuppz + ) + { + matrix work; + matrix iwork; + + if (uplo == 'L') + uplo = 'U'; + else + uplo = 'L'; + + const long n = a.nr(); + + w.set_size(n,1); + + isuppz.set_size(2*n, 1); + + if (jobz == 'V') + { + z.set_size(n,n); + } + else + { + z.set_size(NR3?NR3:1, NC3?NC3:1); + } + + // figure out how big the workspace needs to be. + T work_size = 1; + integer iwork_size = 1; + int info = binding::syevr(jobz, range, uplo, n, &a(0,0), + a.nc(), vl, vu, il, iu, abstol, &num_eigenvalues_found, + &w(0,0), &z(0,0), z.nc(), &isuppz(0,0), &work_size, -1, + &iwork_size, -1); + + if (info != 0) + return info; + + if (work.size() < work_size) + work.set_size(static_cast(work_size), 1); + if (iwork.size() < iwork_size) + iwork.set_size(iwork_size, 1); + + // compute the actual decomposition + info = binding::syevr(jobz, range, uplo, n, &a(0,0), + a.nc(), vl, vu, il, iu, abstol, &num_eigenvalues_found, + &w(0,0), &z(0,0), z.nc(), &isuppz(0,0), &work(0,0), work.size(), + &iwork(0,0), iwork.size()); + + z = trans(z); + + return info; + } + + // ------------------------------------------------------------------------------------ + + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_LAPACk_EVR_Hh_ + + + diff --git a/ml/dlib/dlib/matrix/matrix.h b/ml/dlib/dlib/matrix/matrix.h new file mode 100644 index 000000000..b16635879 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix.h @@ -0,0 +1,2162 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_ +#define DLIB_MATRIx_ + +#include "matrix_exp.h" +#include "matrix_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "../enable_if.h" +#include +#include +#include "../memory_manager.h" +#include "../is_kind.h" +#include "matrix_data_layout.h" +#include "matrix_assign_fwd.h" +#include "matrix_op.h" +#include +#ifdef DLIB_HAS_INITIALIZER_LISTS +#include +#endif + +#ifdef MATLAB_MEX_FILE +#include +#endif + +#ifdef _MSC_VER +// Disable the following warnings for Visual Studio + +// This warning is: +// "warning C4355: 'this' : used in base member initializer list" +// Which we get from this code but it is not an error so I'm turning this +// warning off and then turning it back on at the end of the file. +#pragma warning(disable : 4355) + +// "warning C4723: potential divide by 0" - This warning is triggered in +// matrix(const std::initializer_list& l) where the compiler can see that +// matrix<> was templated in a way making NR ending up 0, but division by 0 at runtime +// is not possible because the division operation is inside "if (NR!=0)" block. +#pragma warning(disable : 4723) + +#endif + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + // This template will perform the needed loop for element multiplication using whichever + // dimension is provided as a compile time constant (if one is at all). + template < + typename LHS, + typename RHS, + long lhs_nc = LHS::NC, + long rhs_nr = RHS::NR + > + struct matrix_multiply_helper + { + typedef typename LHS::type type; + template + inline const static type eval ( + const RHS_& rhs, + const LHS_& lhs, + const long r, + const long c + ) + { + type temp = lhs(r,0)*rhs(0,c); + for (long i = 1; i < rhs.nr(); ++i) + { + temp += lhs(r,i)*rhs(i,c); + } + return temp; + } + }; + + template < + typename LHS, + typename RHS, + long lhs_nc + > + struct matrix_multiply_helper + { + typedef typename LHS::type type; + template + inline const static type eval ( + const RHS_& rhs, + const LHS_& lhs, + const long r, + const long c + ) + { + type temp = lhs(r,0)*rhs(0,c); + for (long i = 1; i < lhs.nc(); ++i) + { + temp += lhs(r,i)*rhs(i,c); + } + return temp; + } + }; + + template + class matrix_multiply_exp; + + template + struct matrix_traits > + { + typedef typename LHS::type type; + typedef typename LHS::type const_ret_type; + typedef typename LHS::mem_manager_type mem_manager_type; + typedef typename LHS::layout_type layout_type; + const static long NR = LHS::NR; + const static long NC = RHS::NC; + +#ifdef DLIB_USE_BLAS + // if there are BLAS functions to be called then we want to make sure we + // always evaluate any complex expressions so that the BLAS bindings can happen. + const static bool lhs_is_costly = (LHS::cost > 2)&&(RHS::NC != 1 || LHS::cost >= 10000); + const static bool rhs_is_costly = (RHS::cost > 2)&&(LHS::NR != 1 || RHS::cost >= 10000); +#else + const static bool lhs_is_costly = (LHS::cost > 4)&&(RHS::NC != 1); + const static bool rhs_is_costly = (RHS::cost > 4)&&(LHS::NR != 1); +#endif + + // Note that if we decide that one of the matrices is too costly we will evaluate it + // into a temporary. Doing this resets its cost back to 1. + const static long lhs_cost = ((lhs_is_costly==true)? 1 : (LHS::cost)); + const static long rhs_cost = ((rhs_is_costly==true)? 1 : (RHS::cost)); + + // The cost of evaluating an element of a matrix multiply is the cost of evaluating elements from + // RHS and LHS times the number of rows/columns in the RHS/LHS matrix. If we don't know the matrix + // dimensions then just assume it is really large. + const static long cost = ((tmax::value!=0)? ((lhs_cost+rhs_cost)*tmax::value):(10000)); + }; + + template struct conditional_matrix_temp { typedef typename T::matrix_type type; }; + template struct conditional_matrix_temp { typedef T& type; }; + + template < + typename LHS, + typename RHS + > + class matrix_multiply_exp : public matrix_exp > + { + /*! + REQUIREMENTS ON LHS AND RHS + - must be matrix_exp objects. + !*/ + public: + + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + typedef typename matrix_traits::layout_type layout_type; + + + const static bool lhs_is_costly = matrix_traits::lhs_is_costly; + const static bool rhs_is_costly = matrix_traits::rhs_is_costly; + const static bool either_is_costly = lhs_is_costly || rhs_is_costly; + const static bool both_are_costly = lhs_is_costly && rhs_is_costly; + + typedef typename conditional_matrix_temp::type LHS_ref_type; + typedef typename conditional_matrix_temp::type RHS_ref_type; + + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of objects. + template + matrix_multiply_exp (T1,T2); + + inline matrix_multiply_exp ( + const LHS& lhs_, + const RHS& rhs_ + ) : + lhs(lhs_), + rhs(rhs_) + { + // You are trying to multiply two incompatible matrices together. The number of columns + // in the matrix on the left must match the number of rows in the matrix on the right. + COMPILE_TIME_ASSERT(LHS::NC == RHS::NR || LHS::NC*RHS::NR == 0); + DLIB_ASSERT(lhs.nc() == rhs.nr() && lhs.size() > 0 && rhs.size() > 0, + "\tconst matrix_exp operator*(const matrix_exp& lhs, const matrix_exp& rhs)" + << "\n\tYou are trying to multiply two incompatible matrices together" + << "\n\tlhs.nr(): " << lhs.nr() + << "\n\tlhs.nc(): " << lhs.nc() + << "\n\trhs.nr(): " << rhs.nr() + << "\n\trhs.nc(): " << rhs.nc() + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + // You can't multiply matrices together if they don't both contain the same type of elements. + COMPILE_TIME_ASSERT((is_same_type::value == true)); + } + + inline const type operator() ( + const long r, + const long c + ) const + { + return matrix_multiply_helper::eval(rhs,lhs,r,c); + } + + inline const type operator() ( long i ) const + { return matrix_exp::operator()(i); } + + long nr ( + ) const { return lhs.nr(); } + + long nc ( + ) const { return rhs.nc(); } + + template + bool aliases ( + const matrix_exp& item + ) const { return lhs.aliases(item) || rhs.aliases(item); } + + template + bool destructively_aliases ( + const matrix_exp& item + ) const { return aliases(item); } + + LHS_ref_type lhs; + RHS_ref_type rhs; + }; + + template < typename EXP1, typename EXP2 > + inline const matrix_multiply_exp operator* ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + return matrix_multiply_exp(m1.ref(), m2.ref()); + } + + template + class matrix_mul_scal_exp; + + // ------------------------- + + // Now we declare some overloads that cause any scalar multiplications to percolate + // up and outside of any matrix multiplies. Note that we are using the non-reference containing + // mode of the matrix_mul_scal_exp object since we are passing in locally constructed matrix_multiply_exp + // objects. So the matrix_mul_scal_exp object will contain copies of matrix_multiply_exp objects + // rather than references to them. This could result in extra matrix copies if the matrix_multiply_exp + // decided it should evaluate any of its arguments. So we also try to not apply this percolating operation + // if the matrix_multiply_exp would contain a fully evaluated copy of the original matrix_mul_scal_exp + // expression. + // + // Also, the reason we want to apply this transformation in the first place is because it (1) makes + // the expressions going into matrix multiply expressions simpler and (2) it makes it a lot more + // straightforward to bind BLAS calls to matrix expressions involving scalar multiplies. + template < typename EXP1, typename EXP2 > + inline const typename disable_if_c< matrix_multiply_exp, matrix_mul_scal_exp >::both_are_costly , + matrix_mul_scal_exp,false> >::type operator* ( + const matrix_mul_scal_exp& m1, + const matrix_mul_scal_exp& m2 + ) + { + typedef matrix_multiply_exp exp1; + typedef matrix_mul_scal_exp exp2; + return exp2(exp1(m1.m, m2.m), m1.s*m2.s); + } + + template < typename EXP1, typename EXP2 > + inline const typename disable_if_c< matrix_multiply_exp, EXP2 >::lhs_is_costly , + matrix_mul_scal_exp,false> >::type operator* ( + const matrix_mul_scal_exp& m1, + const matrix_exp& m2 + ) + { + typedef matrix_multiply_exp exp1; + typedef matrix_mul_scal_exp exp2; + return exp2(exp1(m1.m, m2.ref()), m1.s); + } + + template < typename EXP1, typename EXP2 > + inline const typename disable_if_c< matrix_multiply_exp >::rhs_is_costly , + matrix_mul_scal_exp,false> >::type operator* ( + const matrix_exp& m1, + const matrix_mul_scal_exp& m2 + ) + { + typedef matrix_multiply_exp exp1; + typedef matrix_mul_scal_exp exp2; + return exp2(exp1(m1.ref(), m2.m), m2.s); + } + +// ---------------------------------------------------------------------------------------- + + template + class matrix_add_exp; + + template + struct matrix_traits > + { + typedef typename LHS::type type; + typedef typename LHS::type const_ret_type; + typedef typename LHS::mem_manager_type mem_manager_type; + typedef typename LHS::layout_type layout_type; + const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR; + const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC; + const static long cost = LHS::cost+RHS::cost+1; + }; + + template < + typename LHS, + typename RHS + > + class matrix_add_exp : public matrix_exp > + { + /*! + REQUIREMENTS ON LHS AND RHS + - must be matrix_exp objects. + !*/ + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + typedef typename matrix_traits::layout_type layout_type; + + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of objects. + template + matrix_add_exp (T1,T2); + + matrix_add_exp ( + const LHS& lhs_, + const RHS& rhs_ + ) : + lhs(lhs_), + rhs(rhs_) + { + // You can only add matrices together if they both have the same number of rows and columns. + COMPILE_TIME_ASSERT(LHS::NR == RHS::NR || LHS::NR == 0 || RHS::NR == 0); + COMPILE_TIME_ASSERT(LHS::NC == RHS::NC || LHS::NC == 0 || RHS::NC == 0); + DLIB_ASSERT(lhs.nc() == rhs.nc() && + lhs.nr() == rhs.nr(), + "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)" + << "\n\tYou are trying to add two incompatible matrices together" + << "\n\tlhs.nr(): " << lhs.nr() + << "\n\tlhs.nc(): " << lhs.nc() + << "\n\trhs.nr(): " << rhs.nr() + << "\n\trhs.nc(): " << rhs.nc() + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + // You can only add matrices together if they both contain the same types of elements. + COMPILE_TIME_ASSERT((is_same_type::value == true)); + } + + const type operator() ( + long r, + long c + ) const { return lhs(r,c) + rhs(r,c); } + + inline const type operator() ( long i ) const + { return matrix_exp::operator()(i); } + + template + bool aliases ( + const matrix_exp& item + ) const { return lhs.aliases(item) || rhs.aliases(item); } + + template + bool destructively_aliases ( + const matrix_exp& item + ) const { return lhs.destructively_aliases(item) || rhs.destructively_aliases(item); } + + long nr ( + ) const { return lhs.nr(); } + + long nc ( + ) const { return lhs.nc(); } + + const LHS& lhs; + const RHS& rhs; + }; + + template < + typename EXP1, + typename EXP2 + > + inline const matrix_add_exp operator+ ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + return matrix_add_exp(m1.ref(),m2.ref()); + } + +// ---------------------------------------------------------------------------------------- + + template + class matrix_subtract_exp; + + template + struct matrix_traits > + { + typedef typename LHS::type type; + typedef typename LHS::type const_ret_type; + typedef typename LHS::mem_manager_type mem_manager_type; + typedef typename LHS::layout_type layout_type; + const static long NR = (RHS::NR > LHS::NR) ? RHS::NR : LHS::NR; + const static long NC = (RHS::NC > LHS::NC) ? RHS::NC : LHS::NC; + const static long cost = LHS::cost+RHS::cost+1; + }; + + template < + typename LHS, + typename RHS + > + class matrix_subtract_exp : public matrix_exp > + { + /*! + REQUIREMENTS ON LHS AND RHS + - must be matrix_exp objects. + !*/ + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + typedef typename matrix_traits::layout_type layout_type; + + + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of objects. + template + matrix_subtract_exp (T1,T2); + + matrix_subtract_exp ( + const LHS& lhs_, + const RHS& rhs_ + ) : + lhs(lhs_), + rhs(rhs_) + { + // You can only subtract one matrix from another if they both have the same number of rows and columns. + COMPILE_TIME_ASSERT(LHS::NR == RHS::NR || LHS::NR == 0 || RHS::NR == 0); + COMPILE_TIME_ASSERT(LHS::NC == RHS::NC || LHS::NC == 0 || RHS::NC == 0); + DLIB_ASSERT(lhs.nc() == rhs.nc() && + lhs.nr() == rhs.nr(), + "\tconst matrix_exp operator-(const matrix_exp& lhs, const matrix_exp& rhs)" + << "\n\tYou are trying to subtract two incompatible matrices" + << "\n\tlhs.nr(): " << lhs.nr() + << "\n\tlhs.nc(): " << lhs.nc() + << "\n\trhs.nr(): " << rhs.nr() + << "\n\trhs.nc(): " << rhs.nc() + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + // You can only subtract one matrix from another if they both contain elements of the same type. + COMPILE_TIME_ASSERT((is_same_type::value == true)); + } + + const type operator() ( + long r, + long c + ) const { return lhs(r,c) - rhs(r,c); } + + inline const type operator() ( long i ) const + { return matrix_exp::operator()(i); } + + template + bool aliases ( + const matrix_exp& item + ) const { return lhs.aliases(item) || rhs.aliases(item); } + + template + bool destructively_aliases ( + const matrix_exp& item + ) const { return lhs.destructively_aliases(item) || rhs.destructively_aliases(item); } + + long nr ( + ) const { return lhs.nr(); } + + long nc ( + ) const { return lhs.nc(); } + + const LHS& lhs; + const RHS& rhs; + }; + + template < + typename EXP1, + typename EXP2 + > + inline const matrix_subtract_exp operator- ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + return matrix_subtract_exp(m1.ref(),m2.ref()); + } + +// ---------------------------------------------------------------------------------------- + + template + class matrix_div_scal_exp; + + template + struct matrix_traits > + { + typedef typename M::type type; + typedef typename M::type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const static long NR = M::NR; + const static long NC = M::NC; + const static long cost = M::cost+1; + }; + + template < + typename M + > + class matrix_div_scal_exp : public matrix_exp > + { + /*! + REQUIREMENTS ON M + - must be a matrix_exp object. + !*/ + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + typedef typename matrix_traits::layout_type layout_type; + + + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of objects. + template + matrix_div_scal_exp (T1, const type&); + + matrix_div_scal_exp ( + const M& m_, + const type& s_ + ) : + m(m_), + s(s_) + {} + + const type operator() ( + long r, + long c + ) const { return m(r,c)/s; } + + inline const type operator() ( long i ) const + { return matrix_exp::operator()(i); } + + template + bool aliases ( + const matrix_exp& item + ) const { return m.aliases(item); } + + template + bool destructively_aliases ( + const matrix_exp& item + ) const { return m.destructively_aliases(item); } + + long nr ( + ) const { return m.nr(); } + + long nc ( + ) const { return m.nc(); } + + const M& m; + const type s; + }; + + template < + typename EXP, + typename S + > + inline const typename enable_if_c::is_integer, matrix_div_scal_exp >::type operator/ ( + const matrix_exp& m, + const S& s + ) + { + return matrix_div_scal_exp(m.ref(),static_cast(s)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct matrix_traits > + { + typedef typename M::type type; + typedef typename M::type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const static long NR = M::NR; + const static long NC = M::NC; + const static long cost = M::cost+1; + }; + + template struct conditional_reference { typedef T type; }; + template struct conditional_reference { typedef T& type; }; + + + template < + typename M, + bool use_reference + > + class matrix_mul_scal_exp : public matrix_exp > + { + /*! + REQUIREMENTS ON M + - must be a matrix_exp object. + + !*/ + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + typedef typename matrix_traits::layout_type layout_type; + + // You aren't allowed to multiply a matrix of matrices by a scalar. + COMPILE_TIME_ASSERT(is_matrix::value == false); + + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of objects. + template + matrix_mul_scal_exp (T1, const type&); + + matrix_mul_scal_exp ( + const M& m_, + const type& s_ + ) : + m(m_), + s(s_) + {} + + const type operator() ( + long r, + long c + ) const { return m(r,c)*s; } + + inline const type operator() ( long i ) const + { return matrix_exp::operator()(i); } + + template + bool aliases ( + const matrix_exp& item + ) const { return m.aliases(item); } + + template + bool destructively_aliases ( + const matrix_exp& item + ) const { return m.destructively_aliases(item); } + + long nr ( + ) const { return m.nr(); } + + long nc ( + ) const { return m.nc(); } + + typedef typename conditional_reference::type M_ref_type; + + M_ref_type m; + const type s; + }; + + template < + typename EXP, + typename S + > + inline typename disable_if, const matrix_mul_scal_exp >::type operator* ( + const matrix_exp& m, + const S& s + ) + { + typedef typename EXP::type type; + return matrix_mul_scal_exp(m.ref(),static_cast(s)); + } + + template < + typename EXP, + typename S, + bool B + > + inline typename disable_if, const matrix_mul_scal_exp >::type operator* ( + const matrix_mul_scal_exp& m, + const S& s + ) + { + typedef typename EXP::type type; + return matrix_mul_scal_exp(m.m,static_cast(s)*m.s); + } + + template < + typename EXP, + typename S + > + inline typename disable_if, const matrix_mul_scal_exp >::type operator* ( + const S& s, + const matrix_exp& m + ) + { + typedef typename EXP::type type; + return matrix_mul_scal_exp(m.ref(),static_cast(s)); + } + + template < + typename EXP, + typename S, + bool B + > + inline typename disable_if, const matrix_mul_scal_exp >::type operator* ( + const S& s, + const matrix_mul_scal_exp& m + ) + { + typedef typename EXP::type type; + return matrix_mul_scal_exp(m.m,static_cast(s)*m.s); + } + + template < + typename EXP , + typename S + > + inline const typename disable_if_c::is_integer, matrix_mul_scal_exp >::type operator/ ( + const matrix_exp& m, + const S& s + ) + { + typedef typename EXP::type type; + const type one = 1; + return matrix_mul_scal_exp(m.ref(),one/static_cast(s)); + } + + template < + typename EXP, + bool B, + typename S + > + inline const typename disable_if_c::is_integer, matrix_mul_scal_exp >::type operator/ ( + const matrix_mul_scal_exp& m, + const S& s + ) + { + typedef typename EXP::type type; + return matrix_mul_scal_exp(m.m,m.s/static_cast(s)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_s_div_m : basic_op_m + { + typedef typename M::type type; + + op_s_div_m( const M& m_, const type& s_) : basic_op_m(m_), s(s_){} + + const type s; + + const static long cost = M::cost+1; + typedef const typename M::type const_ret_type; + const_ret_type apply (long r, long c) const + { + return s/this->m(r,c); + } + }; + + template < + typename EXP, + typename S + > + const typename disable_if, matrix_op > >::type operator/ ( + const S& val, + const matrix_exp& m + ) + { + typedef typename EXP::type type; + + typedef op_s_div_m op; + return matrix_op(op(m.ref(), static_cast(val))); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + inline const matrix_mul_scal_exp operator- ( + const matrix_exp& m + ) + { + return matrix_mul_scal_exp(m.ref(),-1); + } + + template < + typename EXP, + bool B + > + inline const matrix_mul_scal_exp operator- ( + const matrix_mul_scal_exp& m + ) + { + return matrix_mul_scal_exp(m.m,-1*m.s); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_add_scalar : basic_op_m + { + typedef typename M::type type; + + op_add_scalar( const M& m_, const type& s_) : basic_op_m(m_), s(s_){} + + const type s; + + const static long cost = M::cost+1; + typedef const typename M::type const_ret_type; + const_ret_type apply (long r, long c) const + { + return this->m(r,c) + s; + } + }; + + template < + typename EXP, + typename T + > + const typename disable_if, matrix_op > >::type operator+ ( + const matrix_exp& m, + const T& val + ) + { + typedef typename EXP::type type; + + typedef op_add_scalar op; + return matrix_op(op(m.ref(), static_cast(val))); + } + + template < + typename EXP, + typename T + > + const typename disable_if, matrix_op > >::type operator+ ( + const T& val, + const matrix_exp& m + ) + { + typedef typename EXP::type type; + + typedef op_add_scalar op; + return matrix_op(op(m.ref(), static_cast(val))); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_subl_scalar : basic_op_m + { + typedef typename M::type type; + + op_subl_scalar( const M& m_, const type& s_) : basic_op_m(m_), s(s_){} + + const type s; + + const static long cost = M::cost+1; + typedef const typename M::type const_ret_type; + const_ret_type apply (long r, long c) const + { + return s - this->m(r,c) ; + } + }; + + template < + typename EXP, + typename T + > + const typename disable_if, matrix_op > >::type operator- ( + const T& val, + const matrix_exp& m + ) + { + typedef typename EXP::type type; + + typedef op_subl_scalar op; + return matrix_op(op(m.ref(), static_cast(val))); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_subr_scalar : basic_op_m + { + typedef typename M::type type; + + op_subr_scalar( const M& m_, const type& s_) : basic_op_m(m_), s(s_){} + + const type s; + + const static long cost = M::cost+1; + typedef const typename M::type const_ret_type; + const_ret_type apply (long r, long c) const + { + return this->m(r,c) - s; + } + }; + + template < + typename EXP, + typename T + > + const typename disable_if, matrix_op > >::type operator- ( + const matrix_exp& m, + const T& val + ) + { + typedef typename EXP::type type; + + typedef op_subr_scalar op; + return matrix_op(op(m.ref(), static_cast(val))); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2 + > + bool operator== ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + if (m1.nr() == m2.nr() && m1.nc() == m2.nc()) + { + for (long r = 0; r < m1.nr(); ++r) + { + for (long c = 0; c < m1.nc(); ++c) + { + if (m1(r,c) != m2(r,c)) + return false; + } + } + return true; + } + return false; + } + + template < + typename EXP1, + typename EXP2 + > + inline bool operator!= ( + const matrix_exp& m1, + const matrix_exp& m2 + ) { return !(m1 == m2); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + struct op_pointer_to_mat; + template + struct op_pointer_to_col_vect; + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager, + typename layout + > + struct matrix_traits > + { + typedef T type; + typedef const T& const_ret_type; + typedef mem_manager mem_manager_type; + typedef layout layout_type; + const static long NR = num_rows; + const static long NC = num_cols; + const static long cost = 1; + + }; + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager, + typename layout + > + class matrix : public matrix_exp > + { + + COMPILE_TIME_ASSERT(num_rows >= 0 && num_cols >= 0); + + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + typedef typename matrix_traits::layout_type layout_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + typedef T* iterator; + typedef const T* const_iterator; + + matrix () + { + } + + explicit matrix ( + long length + ) + { + // This object you are trying to call matrix(length) on is not a column or + // row vector. + COMPILE_TIME_ASSERT(NR == 1 || NC == 1); + DLIB_ASSERT( length >= 0, + "\tmatrix::matrix(length)" + << "\n\tlength must be at least 0" + << "\n\tlength: " << length + << "\n\tNR: " << NR + << "\n\tNC: " << NC + << "\n\tthis: " << this + ); + + if (NR == 1) + { + DLIB_ASSERT(NC == 0 || NC == length, + "\tmatrix::matrix(length)" + << "\n\tSince this is a statically sized matrix length must equal NC" + << "\n\tlength: " << length + << "\n\tNR: " << NR + << "\n\tNC: " << NC + << "\n\tthis: " << this + ); + + data.set_size(1,length); + } + else + { + DLIB_ASSERT(NR == 0 || NR == length, + "\tvoid matrix::set_size(length)" + << "\n\tSince this is a statically sized matrix length must equal NR" + << "\n\tlength: " << length + << "\n\tNR: " << NR + << "\n\tNC: " << NC + << "\n\tthis: " << this + ); + + data.set_size(length,1); + } + } + + matrix ( + long rows, + long cols + ) + { + DLIB_ASSERT( (NR == 0 || NR == rows) && ( NC == 0 || NC == cols) && + rows >= 0 && cols >= 0, + "\tvoid matrix::matrix(rows, cols)" + << "\n\tYou have supplied conflicting matrix dimensions" + << "\n\trows: " << rows + << "\n\tcols: " << cols + << "\n\tNR: " << NR + << "\n\tNC: " << NC + ); + data.set_size(rows,cols); + } + + template + matrix ( + const matrix_exp& m + ) + { + // You get an error on this line if the matrix m contains a type that isn't + // the same as the type contained in the target matrix. + COMPILE_TIME_ASSERT((is_same_type::value == true) || + (is_matrix::value == true)); + + // The matrix you are trying to assign m to is a statically sized matrix and + // m's dimensions don't match that of *this. + COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0); + COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0); + DLIB_ASSERT((NR == 0 || NR == m.nr()) && (NC == 0 || NC == m.nc()), + "\tmatrix& matrix::matrix(const matrix_exp& m)" + << "\n\tYou are trying to assign a dynamically sized matrix to a statically sized matrix with the wrong size" + << "\n\tNR: " << NR + << "\n\tNC: " << NC + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tthis: " << this + ); + + data.set_size(m.nr(),m.nc()); + + matrix_assign(*this, m); + } + + matrix ( + const matrix& m + ) : matrix_exp(*this) + { + data.set_size(m.nr(),m.nc()); + matrix_assign(*this, m); + } + +#ifdef DLIB_HAS_INITIALIZER_LISTS + matrix(const std::initializer_list& l) + { + if (NR*NC != 0) + { + DLIB_ASSERT(l.size() == NR*NC, + "\t matrix::matrix(const std::initializer_list& l)" + << "\n\t You are trying to initialize a statically sized matrix with a list that doesn't have a matching size." + << "\n\t l.size(): "<< l.size() + << "\n\t NR*NC: "<< NR*NC); + + data.set_size(NR, NC); + } + else if (NR!=0) + { + DLIB_ASSERT(l.size()%NR == 0, + "\t matrix::matrix(const std::initializer_list& l)" + << "\n\t You are trying to initialize a statically sized matrix with a list that doesn't have a compatible size." + << "\n\t l.size(): "<< l.size() + << "\n\t NR: "<< NR); + + if (l.size() != 0) + data.set_size(NR, l.size()/NR); + } + else if (NC!=0) + { + DLIB_ASSERT(l.size()%NC == 0, + "\t matrix::matrix(const std::initializer_list& l)" + << "\n\t You are trying to initialize a statically sized matrix with a list that doesn't have a compatible size." + << "\n\t l.size(): "<< l.size() + << "\n\t NC: "<< NC); + + if (l.size() != 0) + data.set_size(l.size()/NC, NC); + } + else if (l.size() != 0) + { + data.set_size(l.size(),1); + } + + if (l.size() != 0) + { + T* d = &data(0,0); + for (auto&& v : l) + *d++ = v; + } + + } + + matrix& operator=(const std::initializer_list& l) + { + matrix temp(l); + temp.swap(*this); + return *this; + } +#endif // DLIB_HAS_INITIALIZER_LISTS + +#ifdef DLIB_HAS_RVALUE_REFERENCES + matrix(matrix&& item) + { + #ifdef MATLAB_MEX_FILE + // You can't move memory around when compiled in a matlab mex file and the + // different locations have different ownership settings. + if (data._private_is_owned_by_matlab() == item.data._private_is_owned_by_matlab()) + { + swap(item); + } + else + { + data.set_size(item.nr(),item.nc()); + matrix_assign(*this, item); + } + #else + swap(item); + #endif + } + + matrix& operator= ( + matrix&& rhs + ) + { + #ifdef MATLAB_MEX_FILE + // You can't move memory around when compiled in a matlab mex file and the + // different locations have different ownership settings. + if (data._private_is_owned_by_matlab() == rhs.data._private_is_owned_by_matlab()) + { + swap(rhs); + } + else + { + data.set_size(rhs.nr(),rhs.nc()); + matrix_assign(*this, rhs); + } + #else + swap(rhs); + #endif + return *this; + } +#endif // DLIB_HAS_RVALUE_REFERENCES + + template + explicit matrix ( + U (&array)[len] + ) + { + COMPILE_TIME_ASSERT(NR*NC == len && len > 0); + size_t idx = 0; + for (long r = 0; r < NR; ++r) + { + for (long c = 0; c < NC; ++c) + { + data(r,c) = static_cast(array[idx]); + ++idx; + } + } + } + + T& operator() ( + long r, + long c + ) + { + DLIB_ASSERT(r < nr() && c < nc() && + r >= 0 && c >= 0, + "\tT& matrix::operator(r,c)" + << "\n\tYou must give a valid row and column" + << "\n\tr: " << r + << "\n\tc: " << c + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tthis: " << this + ); + return data(r,c); + } + + const T& operator() ( + long r, + long c + ) const + { + DLIB_ASSERT(r < nr() && c < nc() && + r >= 0 && c >= 0, + "\tconst T& matrix::operator(r,c)" + << "\n\tYou must give a valid row and column" + << "\n\tr: " << r + << "\n\tc: " << c + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tthis: " << this + ); + return data(r,c); + } + + T& operator() ( + long i + ) + { + // You can only use this operator on column vectors. + COMPILE_TIME_ASSERT(NC == 1 || NC == 0 || NR == 1 || NR == 0); + DLIB_ASSERT(nc() == 1 || nr() == 1, + "\tconst type matrix::operator(i)" + << "\n\tYou can only use this operator on column or row vectors" + << "\n\ti: " << i + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tthis: " << this + ); + DLIB_ASSERT( 0 <= i && i < size(), + "\tconst type matrix::operator(i)" + << "\n\tYou must give a valid row/column number" + << "\n\ti: " << i + << "\n\tsize(): " << size() + << "\n\tthis: " << this + ); + return data(i); + } + + const T& operator() ( + long i + ) const + { + // You can only use this operator on column vectors. + COMPILE_TIME_ASSERT(NC == 1 || NC == 0 || NR == 1 || NR == 0); + DLIB_ASSERT(nc() == 1 || nr() == 1, + "\tconst type matrix::operator(i)" + << "\n\tYou can only use this operator on column or row vectors" + << "\n\ti: " << i + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tthis: " << this + ); + DLIB_ASSERT( 0 <= i && i < size(), + "\tconst type matrix::operator(i)" + << "\n\tYou must give a valid row/column number" + << "\n\ti: " << i + << "\n\tsize(): " << size() + << "\n\tthis: " << this + ); + return data(i); + } + + inline operator const type ( + ) const + { + COMPILE_TIME_ASSERT(NC == 1 || NC == 0); + COMPILE_TIME_ASSERT(NR == 1 || NR == 0); + DLIB_ASSERT( nr() == 1 && nc() == 1 , + "\tmatrix::operator const type" + << "\n\tYou can only attempt to implicit convert a matrix to a scalar if" + << "\n\tthe matrix is a 1x1 matrix" + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tthis: " << this + ); + return data(0); + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray( + mxArray* mem + ) + { + data._private_set_mxArray(mem); + } + + mxArray* _private_release_mxArray( + ) + { + return data._private_release_mxArray(); + } + + void _private_mark_owned_by_matlab() + { + data._private_mark_owned_by_matlab(); + } + + bool _private_is_owned_by_matlab() + { + return data._private_is_owned_by_matlab(); + } +#endif + + void set_size ( + long rows, + long cols + ) + { + DLIB_ASSERT( (NR == 0 || NR == rows) && ( NC == 0 || NC == cols) && + rows >= 0 && cols >= 0, + "\tvoid matrix::set_size(rows, cols)" + << "\n\tYou have supplied conflicting matrix dimensions" + << "\n\trows: " << rows + << "\n\tcols: " << cols + << "\n\tNR: " << NR + << "\n\tNC: " << NC + << "\n\tthis: " << this + ); + if (nr() != rows || nc() != cols) + data.set_size(rows,cols); + } + + void set_size ( + long length + ) + { + // This object you are trying to call set_size(length) on is not a column or + // row vector. + COMPILE_TIME_ASSERT(NR == 1 || NC == 1); + DLIB_ASSERT( length >= 0, + "\tvoid matrix::set_size(length)" + << "\n\tlength must be at least 0" + << "\n\tlength: " << length + << "\n\tNR: " << NR + << "\n\tNC: " << NC + << "\n\tthis: " << this + ); + + if (NR == 1) + { + DLIB_ASSERT(NC == 0 || NC == length, + "\tvoid matrix::set_size(length)" + << "\n\tSince this is a statically sized matrix length must equal NC" + << "\n\tlength: " << length + << "\n\tNR: " << NR + << "\n\tNC: " << NC + << "\n\tthis: " << this + ); + + if (nc() != length) + data.set_size(1,length); + } + else + { + DLIB_ASSERT(NR == 0 || NR == length, + "\tvoid matrix::set_size(length)" + << "\n\tSince this is a statically sized matrix length must equal NR" + << "\n\tlength: " << length + << "\n\tNR: " << NR + << "\n\tNC: " << NC + << "\n\tthis: " << this + ); + + if (nr() != length) + data.set_size(length,1); + } + } + + long nr ( + ) const { return data.nr(); } + + long nc ( + ) const { return data.nc(); } + + long size ( + ) const { return data.nr()*data.nc(); } + + template + matrix& operator= ( + U (&array)[len] + ) + { + COMPILE_TIME_ASSERT(NR*NC == len && len > 0); + size_t idx = 0; + for (long r = 0; r < NR; ++r) + { + for (long c = 0; c < NC; ++c) + { + data(r,c) = static_cast(array[idx]); + ++idx; + } + } + return *this; + } + + template + matrix& operator= ( + const matrix_exp& m + ) + { + // You get an error on this line if the matrix you are trying to + // assign m to is a statically sized matrix and m's dimensions don't + // match that of *this. + COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0); + COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0); + DLIB_ASSERT((NR == 0 || nr() == m.nr()) && + (NC == 0 || nc() == m.nc()), + "\tmatrix& matrix::operator=(const matrix_exp& m)" + << "\n\tYou are trying to assign a dynamically sized matrix to a statically sized matrix with the wrong size" + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tthis: " << this + ); + + // You get an error on this line if the matrix m contains a type that isn't + // the same as the type contained in the target matrix. + COMPILE_TIME_ASSERT((is_same_type::value == true) || + (is_matrix::value == true)); + if (m.destructively_aliases(*this) == false) + { + // This if statement is seemingly unnecessary since set_size() contains this + // exact same if statement. However, structuring the code this way causes + // gcc to handle the way it inlines this function in a much more favorable way. + if (data.nr() == m.nr() && data.nc() == m.nc()) + { + matrix_assign(*this, m); + } + else + { + set_size(m.nr(),m.nc()); + matrix_assign(*this, m); + } + } + else + { + // we have to use a temporary matrix object here because + // *this is aliased inside the matrix_exp m somewhere. + matrix temp; + temp.set_size(m.nr(),m.nc()); + matrix_assign(temp, m); + temp.swap(*this); + } + return *this; + } + + template + matrix& operator += ( + const matrix_exp& m + ) + { + // The matrix you are trying to assign m to is a statically sized matrix and + // m's dimensions don't match that of *this. + COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0); + COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0); + COMPILE_TIME_ASSERT((is_same_type::value == true)); + if (nr() == m.nr() && nc() == m.nc()) + { + if (m.destructively_aliases(*this) == false) + { + matrix_assign(*this, *this + m); + } + else + { + // we have to use a temporary matrix object here because + // this->data is aliased inside the matrix_exp m somewhere. + matrix temp; + temp.set_size(m.nr(),m.nc()); + matrix_assign(temp, *this + m); + temp.swap(*this); + } + } + else + { + DLIB_ASSERT(size() == 0, + "\t const matrix::operator+=(m)" + << "\n\t You are trying to add two matrices that have incompatible dimensions."); + *this = m; + } + return *this; + } + + + template + matrix& operator -= ( + const matrix_exp& m + ) + { + // The matrix you are trying to assign m to is a statically sized matrix and + // m's dimensions don't match that of *this. + COMPILE_TIME_ASSERT(EXP::NR == NR || NR == 0 || EXP::NR == 0); + COMPILE_TIME_ASSERT(EXP::NC == NC || NC == 0 || EXP::NC == 0); + COMPILE_TIME_ASSERT((is_same_type::value == true)); + if (nr() == m.nr() && nc() == m.nc()) + { + if (m.destructively_aliases(*this) == false) + { + matrix_assign(*this, *this - m); + } + else + { + // we have to use a temporary matrix object here because + // this->data is aliased inside the matrix_exp m somewhere. + matrix temp; + temp.set_size(m.nr(),m.nc()); + matrix_assign(temp, *this - m); + temp.swap(*this); + } + } + else + { + DLIB_ASSERT(size() == 0, + "\t const matrix::operator-=(m)" + << "\n\t You are trying to subtract two matrices that have incompatible dimensions."); + *this = -m; + } + return *this; + } + + template + matrix& operator *= ( + const matrix_exp& m + ) + { + *this = *this * m; + return *this; + } + + matrix& operator += ( + const matrix& m + ) + { + const long size = m.nr()*m.nc(); + if (nr() == m.nr() && nc() == m.nc()) + { + for (long i = 0; i < size; ++i) + data(i) += m.data(i); + } + else + { + DLIB_ASSERT(this->size() == 0, + "\t const matrix::operator+=(m)" + << "\n\t You are trying to add two matrices that have incompatible dimensions."); + + set_size(m.nr(), m.nc()); + for (long i = 0; i < size; ++i) + data(i) = m.data(i); + } + return *this; + } + + matrix& operator -= ( + const matrix& m + ) + { + const long size = m.nr()*m.nc(); + if (nr() == m.nr() && nc() == m.nc()) + { + for (long i = 0; i < size; ++i) + data(i) -= m.data(i); + } + else + { + DLIB_ASSERT(this->size() == 0, + "\t const matrix::operator-=(m)" + << "\n\t You are trying to subtract two matrices that have incompatible dimensions."); + set_size(m.nr(), m.nc()); + for (long i = 0; i < size; ++i) + data(i) = -m.data(i); + } + return *this; + } + + matrix& operator += ( + const T val + ) + { + const long size = nr()*nc(); + for (long i = 0; i < size; ++i) + data(i) += val; + + return *this; + } + + matrix& operator -= ( + const T val + ) + { + const long size = nr()*nc(); + for (long i = 0; i < size; ++i) + data(i) -= val; + + return *this; + } + + matrix& operator *= ( + const T a + ) + { + *this = *this * a; + return *this; + } + + matrix& operator /= ( + const T a + ) + { + *this = *this / a; + return *this; + } + + matrix& operator= ( + const matrix& m + ) + { + if (this != &m) + { + set_size(m.nr(),m.nc()); + const long size = m.nr()*m.nc(); + for (long i = 0; i < size; ++i) + data(i) = m.data(i); + } + return *this; + } + + void swap ( + matrix& item + ) + { + data.swap(item.data); + } + + template + bool aliases ( + const matrix_exp& + ) const { return false; } + + bool aliases ( + const matrix_exp >& item + ) const { return (this == &item); } + + template + bool destructively_aliases ( + const matrix_exp& + ) const { return false; } + + // These two aliases() routines are defined in matrix_mat.h + bool aliases ( + const matrix_exp > >& item + ) const; + bool aliases ( + const matrix_exp > >& item + ) const; + + iterator begin() + { + if (size() != 0) + return &data(0,0); + else + return 0; + } + + iterator end() + { + if (size() != 0) + return &data(0,0)+size(); + else + return 0; + } + + const_iterator begin() const + { + if (size() != 0) + return &data(0,0); + else + return 0; + } + + const_iterator end() const + { + if (size() != 0) + return &data(0,0)+size(); + else + return 0; + } + + private: + struct literal_assign_helper + { + /* + This struct is a helper struct returned by the operator<<() function below. It is + used primarily to enable us to put DLIB_CASSERT statements on the usage of the + operator<< form of matrix assignment. + */ + + literal_assign_helper(const literal_assign_helper& item) : m(item.m), r(item.r), c(item.c), has_been_used(false) {} + explicit literal_assign_helper(matrix* m_): m(m_), r(0), c(0),has_been_used(false) {next();} + ~literal_assign_helper() noexcept(false) + { + DLIB_CASSERT(!has_been_used || r == m->nr(), + "You have used the matrix comma based assignment incorrectly by failing to\n" + "supply a full set of values for every element of a matrix object.\n"); + } + + const literal_assign_helper& operator, ( + const T& val + ) const + { + DLIB_CASSERT(r < m->nr() && c < m->nc(), + "You have used the matrix comma based assignment incorrectly by attempting to\n" << + "supply more values than there are elements in the matrix object being assigned to.\n\n" << + "Did you forget to call set_size()?" + << "\n\t r: " << r + << "\n\t c: " << c + << "\n\t m->nr(): " << m->nr() + << "\n\t m->nc(): " << m->nc()); + (*m)(r,c) = val; + next(); + has_been_used = true; + return *this; + } + + private: + + friend class matrix; + + void next ( + ) const + { + ++c; + if (c == m->nc()) + { + c = 0; + ++r; + } + } + + matrix* m; + mutable long r; + mutable long c; + mutable bool has_been_used; + }; + + public: + + matrix& operator = ( + const literal_assign_helper& val + ) + { + *this = *val.m; + return *this; + } + + const literal_assign_helper operator = ( + const T& val + ) + { + // assign the given value to every spot in this matrix + const long size = nr()*nc(); + for (long i = 0; i < size; ++i) + data(i) = val; + + // Now return the literal_assign_helper so that the user + // can use the overloaded comma notation to initialize + // the matrix if they want to. + return literal_assign_helper(this); + } + + private: + + + typename layout::template layout data; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename mm, + typename l + > + void swap( + matrix& a, + matrix& b + ) { a.swap(b); } + + template < + typename T, + long NR, + long NC, + typename mm, + typename l + > + void serialize ( + const matrix& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + for (long r = 0; r < item.nr(); ++r) + { + for (long c = 0; c < item.nc(); ++c) + { + serialize(item(r,c),out); + } + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing dlib::matrix"); + } + } + + template < + typename T, + long NR, + long NC, + typename mm, + typename l + > + void deserialize ( + matrix& item, + std::istream& in + ) + { + try + { + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + + if (NR != 0 && nr != NR) + throw serialization_error("Error while deserializing a dlib::matrix. Invalid rows"); + if (NC != 0 && nc != NC) + throw serialization_error("Error while deserializing a dlib::matrix. Invalid columns"); + + item.set_size(nr,nc); + for (long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + deserialize(item(r,c),in); + } + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing a dlib::matrix"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename mm, + typename l + > + void serialize ( + const ramdump_t>& item_, + std::ostream& out + ) + { + auto& item = item_.item; + serialize(item.nr(), out); + serialize(item.nc(), out); + if (item.size() != 0) + out.write((char*)&item(0,0), sizeof(item(0,0))*item.size()); + } + + template < + typename T, + long NR, + long NC, + typename mm, + typename l + > + void deserialize ( + ramdump_t>&& item_, + std::istream& in + ) + { + auto& item = item_.item; + long nr, nc; + deserialize(nr, in); + deserialize(nc, in); + item.set_size(nr,nc); + if (item.size() != 0) + in.read((char*)&item(0,0), sizeof(item(0,0))*item.size()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + std::ostream& operator<< ( + std::ostream& out, + const matrix_exp& m + ) + { + using namespace std; + const streamsize old = out.width(); + + // first figure out how wide we should make each field + string::size_type w = 0; + ostringstream sout; + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + sout << m(r,c); + w = std::max(sout.str().size(),w); + sout.str(""); + } + } + + // now actually print it + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + out.width(static_cast(w)); + out << m(r,c) << " "; + } + out << "\n"; + } + out.width(old); + return out; + } + + /* + template < + typename T, + long NR, + long NC, + typename MM, + typename L + > + std::istream& operator>> ( + std::istream& in, + matrix& m + ); + + This function is defined inside the matrix_read_from_istream.h file. + */ + +// ---------------------------------------------------------------------------------------- + + class print_matrix_as_csv_helper + { + /*! + This object is used to define an io manipulator for matrix expressions. + In particular, this code allows you to write statements like: + cout << csv << yourmatrix; + and have it print the matrix with commas separating each element. + !*/ + public: + print_matrix_as_csv_helper (std::ostream& out_) : out(out_) {} + + template + std::ostream& operator<< ( + const matrix_exp& m + ) + { + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + if (c+1 == m.nc()) + out << m(r,c) << "\n"; + else + out << m(r,c) << ", "; + } + } + return out; + } + + private: + std::ostream& out; + }; + + class print_matrix_as_csv {}; + const print_matrix_as_csv csv = print_matrix_as_csv(); + inline print_matrix_as_csv_helper operator<< ( + std::ostream& out, + const print_matrix_as_csv& + ) + { + return print_matrix_as_csv_helper(out); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class const_temp_matrix; + + template < + typename EXP + > + struct matrix_traits > + { + typedef typename EXP::type type; + typedef typename EXP::const_ret_type const_ret_type; + typedef typename EXP::mem_manager_type mem_manager_type; + typedef typename EXP::layout_type layout_type; + const static long NR = EXP::NR; + const static long NC = EXP::NC; + const static long cost = 1; + }; + + template + class const_temp_matrix : public matrix_exp >, noncopyable + { + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + typedef typename matrix_traits::layout_type layout_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + + const_temp_matrix ( + const matrix_exp& item + ) : + ref_(item.ref()) + {} + const_temp_matrix ( + const EXP& item + ) : + ref_(item) + {} + + const_ret_type operator() ( + long r, + long c + ) const { return ref_(r,c); } + + const_ret_type operator() ( long i ) const + { return ref_(i); } + + template + bool aliases ( + const matrix_exp& item + ) const { return ref_.aliases(item); } + + template + bool destructively_aliases ( + const matrix_exp& item + ) const { return ref_.destructively_aliases(item); } + + long nr ( + ) const { return ref_.nr(); } + + long nc ( + ) const { return ref_.nc(); } + + private: + + typename conditional_matrix_temp::type ref_; + }; + +// ---------------------------------------------------------------------------------------- + + typedef matrix matrix_colmajor; + typedef matrix fmatrix_colmajor; + +} + +#ifdef _MSC_VER +// put warnings back to their default settings +#pragma warning(default : 4355) +#pragma warning(default : 4723) +#endif + +#endif // DLIB_MATRIx_ + diff --git a/ml/dlib/dlib/matrix/matrix_abstract.h b/ml/dlib/dlib/matrix/matrix_abstract.h new file mode 100644 index 000000000..0d05ce981 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_abstract.h @@ -0,0 +1,857 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MATRIx_ABSTRACT_ +#ifdef DLIB_MATRIx_ABSTRACT_ + +#include "matrix_exp_abstract.h" +#include "../serialize.h" +#include "../algs.h" +#include "matrix_data_layout_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /* + Note that these operator prototypes are not correct C++ (the real versions, which + you can see in the implementation are really complex and so probably would + distract/confuse people if shown here). Think of this as just a list of the + operators available to you and what they do. + */ + + const matrix_exp operator* ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1.nc() == m2.nr() + - m1.size() > 0 && m2.size() > 0 + (you can't multiply any sort of empty matrices together) + - m1 and m2 both contain elements of the same type + ensures + - returns the result of doing the matrix multiplication m1*m2. The resulting + matrix will have m1.nr() rows and m2.nc() columns. + !*/ + + const matrix_exp operator+ ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1.nr() == m2.nr() + - m1.nc() == m2.nc() + - m1 and m2 both contain elements of the same type + ensures + - returns a matrix R such that for all valid r and c: + R(r,c) == m1(r,c) + m2(r,c) + (i.e. returns the result of doing a pairwise addition of the matrices m1 and m2.) + The resulting matrix will have the same dimensions as the originals. + !*/ + + const matrix_exp operator- ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1.nr() == m2.nr() + - m1.nc() == m2.nc() + - m1 and m2 both contain elements of the same type + ensures + - returns a matrix R such that for all valid r and c: + R(r,c) == m1(r,c) - m2(r,c) + (i.e. returns the result of doing a pairwise subtraction of the matrices m1 and m2.) + The resulting matrix will have the same dimensions as the originals. + !*/ + + template + const matrix_exp operator* ( + const matrix_exp& m, + const T& value + ); + /*! + ensures + - returns the result of multiplying all the elements of matrix m by the given + scalar value. The resulting matrix will have the same dimensions as m. + !*/ + + template + const matrix_exp operator* ( + const T& value, + const matrix_exp& m + ); + /*! + ensures + - returns the result of multiplying all the elements of matrix m by the given + scalar value. The resulting matrix will have the same dimensions as m. + !*/ + + const matrix_exp operator- ( + const matrix_exp& m + ); + /*! + ensures + - returns -1*m + !*/ + + template + const matrix_exp operator/ ( + const matrix_exp& m, + const T& value + ); + /*! + ensures + - returns the result of dividing all the elements of matrix m by the given + scalar value. The resulting matrix will have the same dimensions as m. + !*/ + + template + const matrix_exp operator/ ( + const T& value, + const matrix_exp& m + ); + /*! + ensures + - returns the result of dividing the given scalar value by all the elements + of matrix m. The resulting matrix will have the same dimensions as m. + !*/ + + template + const matrix_exp operator+ ( + const matrix_exp& m, + const T& value + ); + /*! + ensures + - returns the result of adding value to all the elements of matrix m. + The resulting matrix will have the same dimensions as m. + !*/ + + template + const matrix_exp operator+ ( + const T& value, + const matrix_exp& m + ); + /*! + ensures + - returns the result of adding value to all the elements of matrix m. + The resulting matrix will have the same dimensions as m. + !*/ + + template + const matrix_exp operator- ( + const matrix_exp& m, + const T& value + ); + /*! + ensures + - returns the result of subtracting value from all the elements of matrix m. + The resulting matrix will have the same dimensions as m. + !*/ + + template + const matrix_exp operator- ( + const T& value, + const matrix_exp& m + ); + /*! + ensures + - Returns a matrix M such that: + - M has the same dimensions as m + - M contains the same type of element as m + - for all valid r and c: + - M(r,c) == value - m(r,c) + !*/ + + bool operator== ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + ensures + - if (m1.nr() == m2.nr() && m1.nc() == m2.nc() && + for all valid r and c: m1(r,c) == m2(r,c) ) then + - returns true + - else + - returns false + !*/ + + bool operator!= ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + ensures + - returns !(m1 == m2) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long num_rows = 0, + long num_cols = 0, + typename mem_manager = default_memory_manager, + typename layout = row_major_layout + > + class matrix : public matrix_exp > + { + /*! + REQUIREMENTS ON num_rows and num_cols + both must be bigger than or equal to 0 + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + REQUIREMENTS ON layout + Must be either row_major_layout or column_major_layout + + INITIAL VALUE + - if (num_rows > 0) then + - nr() == num_rows + - else + - nr() == 0 + + - if (num_cols > 0) then + - nc() == num_cols + - else + - nc() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a matrix of nr() rows and nc() columns. This object + is also a matrix_exp. Thus it can be used in all of the above + global operators. + + The number of rows and columns of this object are determined by the template + arguments num_rows and num_cols. If num_rows or num_cols are 0 then + the matrix starts out empty (i.e. nr() == 0 and nc() == 0) and you may change + its size via the set_size() member function. + + Setting num_rows or num_cols to something other than 0 causes that dimension + to have a fixed size. Setting a fixed size at compile time is useful because + any errors related to operating on matrices with incompatible dimensions will + be detected at compile time. It also allows the compiler to perform loop + unrolling which can result in substantially faster code. + + Also note that the elements of this matrix are laid out in memory by the layout + object supplied as a template argument to this class. The row_major_layout + sets elements down contiguously in memory and in row major order. Additionally, + all memory allocations are performed using the memory manager object supplied as + a template argument to this class. + !*/ + + public: + typedef T type; + typedef mem_manager mem_manager_type; + typedef layout layout_type; + const static long NR = num_rows; + const static long NC = num_cols; + const static long cost = 1; + typedef T* iterator; + typedef const T* const_iterator; + + matrix ( + ); + /*! + ensures + - #*this is properly initialized + - #aliases(*this) == true + - #ref().aliases(*this) == true + !*/ + + explicit matrix ( + long length + ); + /*! + requires + - NR == 1 || NC == 1 (i.e. this must be a column or row vector) + - length >= 0 + - if (NR == 1 && NC > 0) then + - length == NC + - if (NC == 1 && NR > 0) then + - length == NR + ensures + - #*this is properly initialized + - #aliases(*this) == true + - #ref().aliases(*this) == true + - if (NR == 1) then + - #nr() == 1 + - #nc() == length + - else + - #nr() == length + - #nc() == 1 + !*/ + + matrix ( + long rows, + long cols + ); + /*! + requires + - rows == NR || NR == 0 + - cols == NC || NC == 0 + - rows >= 0 && cols >= 0 + ensures + - #*this is properly initialized + - #aliases(*this) == true + - #ref().aliases(*this) == true + - #nr() == rows + - #nc() == cols + !*/ + + template + matrix ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == T + (i.e. m contains the same type as *this does) + - if (NR != 0) then NR == m.nr() + - if (NC != 0) then NC == m.nc() + ensures + - #*this == m + - #aliases(*this) == true + - #ref().aliases(*this) == true + !*/ + + template + explicit matrix ( + U (&array)[len] + ); + /*! + requires + - NR != 0 && NC != 0 (i.e. you can only use this constructor on statically sized matrices) + - len == nr()*nc() (i.e. the array you give here must be the right size) + ensures + - for all valid r and c: + #(*this)(r,c) == array[r*nc() + c] + (i.e. initializes this matrix with the contents of the given array) + - #aliases(*this) == true + - #ref().aliases(*this) == true + !*/ + + matrix( + const std::initializer_list& l + ); + /*! + requires + - This matrix is capable of having a size() == l.size(). Therefore, if + NR*NC != 0 then l.size() must equal NR*NC. Alternatively, if NR or NC is + != 0 then l.size() must be a multiple of the non-zero NR or NC. + ensures + - #size() == l.size() + - The contents of l are enumerated and read into the matrix in row major order. + - if (NR != 0) then + - #nr() == NR + - #nc() == l.size()/NR + - if (NC != 0) then + - #nr() == l.size()/NC + - #nc() == NC + - if (NR*NC==0) then + - #nr() == l.size() + - #nc() == 1 + - #aliases(*this) == true + - #ref().aliases(*this) == true + !*/ + + T& operator() ( + long r, + long c + ); + /*! + requires + - 0 <= r < nr() + - 0 <= c < nc() + ensures + - returns a reference to the value at the given row and column in + this matrix. + !*/ + + const T& operator() ( + long r, + long c + ) const; + /*! + requires + - 0 <= r < nr() + - 0 <= c < nc() + ensures + - returns a const reference to the value at the given row and column in + this matrix. + !*/ + + T& operator() ( + long i + ); + /*! + requires + - nc() == 1 || nr() == 1 (i.e. this must be a column or row vector) + - 0 <= i < size() + ensures + - if (nc() == 1) then + - returns a reference to (*this)(i,0) + - else + - returns a reference to (*this)(0,i) + !*/ + + const T& operator() ( + long i + ) const; + /*! + requires + - nc() == 1 || nr() == 1 (i.e. this must be a column or row vector) + - 0 <= i < size() + ensures + - if (nc() == 1) then + - returns a reference to (*this)(i,0) + - else + - returns a reference to (*this)(0,i) + !*/ + + operator const type ( + ) const; + /*! + requires + - nr() == 1 + - nc() == 1 + ensures + - returns (*this)(0,0) + !*/ + + long nr( + ) const; + /*! + ensures + - returns the number of rows in this matrix + !*/ + + long nc( + ) const; + /*! + ensures + - returns the number of columns in this matrix + !*/ + + long size ( + ) const; + /*! + ensures + - returns nr()*nc() + !*/ + + void set_size ( + long rows, + long cols + ); + /*! + requires + - rows == NR || NR == 0 + - cols == NC || NC == 0 + - rows >= 0 && cols >= 0 + ensures + - #nr() == rows + - #nc() == cols + !*/ + + void set_size ( + long length + ); + /*! + requires + - NR == 1 || NC == 1 (i.e. this must be a column or row vector) + - length >= 0 + - if (NR == 1 && NC > 0) then + - length == NC + - if (NC == 1 && NR > 0) then + - length == NR + ensures + - if (NR == 1) then + - #nr() == 1 + - #nc() == length + - else + - #nr() == length + - #nc() == 1 + !*/ + + template + matrix& operator= ( + U (&array)[len] + ); + /*! + requires + - len == nr()*nc() (i.e. the array you give here must be the right size) + ensures + - for all valid r and c: + #(*this)(r,c) == array[r*nc() + c] + (i.e. loads this matrix with the contents of the given array) + - returns *this + !*/ + + matrix& operator=( + const std::initializer_list& l + ); + /*! + requires + - This matrix is capable of having a size() == l.size(). Therefore, if + NR*NC != 0 then l.size() must equal NR*NC. Alternatively, if NR or NC is + != 0 then l.size() must be a multiple of the non-zero NR or NC. + ensures + - Assigns the contents of l to *this by performing: matrix(l).swap(*this) + - returns *this + !*/ + + template + matrix& operator= ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == T + (i.e. m contains the same type as *this does) + - if (NR != 0) then NR == m.nr() + - if (NC != 0) then NC == m.nc() + ensures + - copies the given matrix expression m to *this + - returns *this + !*/ + + template + matrix& operator += ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == T + - One of the following is true: + - nr() == m.nr() && nc() == m.nc() + - size() == 0 + (i.e. this matrix must have matching dimensions or it must be empty) + ensures + - if (nr() == m.nr() && nc() == m.nc()) then + - #(*this) == *this + m + - else + - #(*this) == m + (i.e. if the dimensions don't match then this function performs a + normal assignment) + - returns *this + !*/ + + template + matrix& operator -= ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == T + - One of the following is true: + - nr() == m.nr() && nc() == m.nc() + - size() == 0 + (i.e. this matrix must have matching dimensions or it must be empty) + ensures + - if (nr() == m.nr() && nc() == m.nc()) then + - #(*this) == *this - m + - else + - #(*this) == -m + - returns *this + !*/ + + template + matrix& operator *= ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == T + (i.e. m must contain the same type of element as *this) + - nc() == m.nr() + - size() > 0 && m.size() > 0 + (you can't multiply any sort of empty matrices together) + ensures + - #(*this) == *this * m + - returns *this + !*/ + + matrix& operator *= ( + const T& a + ); + /*! + ensures + - #(*this) == *this * a + - returns *this + !*/ + + matrix& operator /= ( + const T& a + ); + /*! + ensures + - #(*this) == *this / a + - returns *this + !*/ + + matrix& operator += ( + const T& a + ); + /*! + ensures + - #(*this) == *this + a + - returns *this + !*/ + + matrix& operator -= ( + const T& a + ); + /*! + ensures + - #(*this) == *this - a + - returns *this + !*/ + + const literal_assign_helper operator = ( + const T& val + ); + /*! + This function is somewhat different than all the others defined in this file. + The purpose of this function is to enable you to easily initialize a matrix object. + For example: + matrix m(2,3); + m = 1,2,3, + 4,5,6; + + The above code creates a matrix m with 2 rows and 3 columns and sets it so that + it contains the matrix | 1 2 3 | + | 4 5 6 | + + You can also use this function to assign to all elements of a matrix. So + saying m = 3; would assign all elements of m equal to 3. + + Note that to use this method of assignment it is required that you supply + exactly m.size() or 1 values so that the matrix is fully initialized. Supplying + fewer or more than that is an error that will cause a dlib::fatal_error to be + thrown. + + Note also that using an expression of the form m = scalar; when m.size() == 0 + is legal but has no effect on m. + !*/ + + void swap ( + matrix& item + ); + /*! + ensures + - swaps *this and item + !*/ + + iterator begin( + ); + /*! + ensures + - returns a random access iterator pointing to the first element in this + matrix. + - The iterator will iterate over the elements of the matrix in row major + order if layout is row_major_layout or in column major order if layout is + column_major_layout. + !*/ + + iterator end( + ); + /*! + ensures + - returns a random access iterator pointing to one past the end of the last + element in this matrix. + !*/ + + const_iterator begin( + ) const; + /*! + ensures + - returns a random access iterator pointing to the first element in this + matrix. + - The iterator will iterate over the elements of the matrix in row major + order if layout is row_major_layout or in column major order if layout is + column_major_layout. + !*/ + + const_iterator end( + ) const; + /*! + ensures + - returns a random access iterator pointing to one past the end of the last + element in this matrix. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + /*!A matrix_colmajor + This is just a typedef of the matrix object that uses column major layout. + !*/ + typedef matrix matrix_colmajor; + + /*!A fmatrix_colmajor + This is just a typedef of the matrix object that uses column major layout. + !*/ + typedef matrix fmatrix_colmajor; + +// ---------------------------------------------------------------------------------------- +template < + typename T, + long NR, + long NC, + typename mm, + typename l + > + void swap( + matrix& a, + matrix& b + ) { a.swap(b); } + /*! + Provides a global swap function + !*/ + + template < + typename T, + long NR, + long NC, + typename mm, + typename l + > + void serialize ( + const matrix& item, + std::ostream& out + ); + /*! + Provides serialization support. Note that the serialization formats used by the + dlib::matrix and dlib::array2d objects are compatible. That means you can load the + serialized data from one into another and it will work properly. + !*/ + + template < + typename T, + long NR, + long NC, + typename mm, + typename l + > + void deserialize ( + matrix& item, + std::istream& in + ); + /*! + Provides deserialization support + !*/ + + template < + typename EXP + > + std::ostream& operator<< ( + std::ostream& out, + const matrix_exp& m + ); + /*! + ensures + - writes m to the given out stream in a form suitable for human consumption. + - returns out + !*/ + + template < + typename T, + long NR, + long NC, + typename MM, + typename L + > + std::istream& operator>> ( + std::istream& in, + matrix& m + ); + /*! + ensures + - Tries to read a matrix from the given input stream and store it into #m. + - The format expected is the text format output by the above operator<<(). + That is, the format should be a grid of text such as: + 2 3 4 + 5 2 6 + - The separation between numbers can be any number of whitespace characters or + commas. + - The matrix data is assumed to end upon the first blank line or end-of-file, + whichever comes first. This means you can create an input stream with + multiple matrices in it by separating them with empty lines. + - returns in. + - If there was a formatting error or something which prevents the input data + from being parsed into a matrix then #in.fail() == true. + !*/ + + /*!A csv + This object is used to define an io manipulator for matrix expressions. In + particular, you can write statements like: + cout << csv << yourmatrix; + and have it print the matrix with commas separating each element. + !*/ + some_undefined_iomnaip_type csv; + +// ---------------------------------------------------------------------------------------- + + template + class const_temp_matrix : public matrix_exp >, noncopyable + { + /*! + REQUIREMENTS ON EXP + - must be an object that inherits publicly from matrix_exp. + + WHAT THIS OBJECT REPRESENTS + This object represents a copy of a matrix expression. The twist + is that it only actually makes a copy of its input matrix expression + if that matrix expression is costly to evaluate. If it has + low cost then this object just stores a reference. + + This class is useful in cases where you write a function that + takes a matrix_exp object as input and you want to do some + intensive computation that looks at each element of that matrix_exp + many times. If the input matrix_exp has a high cost then you want + to store it into a temporary matrix. But if it has low cost then + it is faster if you just use a reference to it. The const_temp_matrix + makes doing this easy. + !*/ + public: + + const_temp_matrix ( + const matrix_exp& item + ); + /*! + ensures + - #*this == item + - if (EXP::cost <= 1) then + - this const_temp_matrix stores a reference to the item matrix + - else + - this const_temp_matrix creates a temporary matrix and copies + item into it + !*/ + + const_temp_matrix ( + const EXP& item + ); + /*! + ensures + - #*this == item + - if (EXP::cost <= 1) then + - this const_temp_matrix stores a reference to the item matrix + - else + - this const_temp_matrix creates a temporary matrix and copies + item into it + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_ABSTRACT_ + diff --git a/ml/dlib/dlib/matrix/matrix_assign.h b/ml/dlib/dlib/matrix/matrix_assign.h new file mode 100644 index 000000000..da53050b1 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_assign.h @@ -0,0 +1,978 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_ASSIGn_ +#define DLIB_MATRIx_ASSIGn_ + +#include "matrix.h" +#include "matrix_utilities.h" +#include "matrix_subexp.h" +#include "../enable_if.h" +#include "matrix_assign_fwd.h" +#include "matrix_default_mul.h" +#include "matrix_conj_trans.h" +#include "matrix_mat.h" + +namespace dlib +{ + /* + This file contains some templates that are used inside the matrix_blas_bindings.h + file to bind various matrix expressions to optimized code for carrying them out. + */ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace blas_bindings + { + + // ------------------------------------------------------------------------------------ + + template + void zero_matrix ( + T& m + ) + { + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = 0; + } + } + } + + // ------------------------------------------------------------------------------------ + + // This template struct is used to tell us if a matrix expression contains a matrix multiply. + template + struct has_matrix_multiply + { + const static bool value = false; + }; + + template + struct has_matrix_multiply > + { const static bool value = true; }; + + template + struct has_matrix_multiply > + { const static bool value = has_matrix_multiply::value || has_matrix_multiply::value; }; + + template + struct has_matrix_multiply > + { const static bool value = has_matrix_multiply::value || has_matrix_multiply::value; }; + + template + struct has_matrix_multiply > + { const static bool value = true; }; + + template + struct has_matrix_multiply > + { const static bool value = has_matrix_multiply::value; }; + + template + struct has_matrix_multiply > + { const static bool value = has_matrix_multiply::value; }; + + template + struct has_matrix_multiply > + { const static bool value = has_matrix_multiply::value; }; + + template + struct has_matrix_multiply > + { const static bool value = has_matrix_multiply::value; }; + + template + struct has_matrix_multiply > + { const static bool value = has_matrix_multiply::value; }; + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + const int unknown_matrix = 0; + const int general_matrix = 1; + const int row_matrix = 2; + const int column_matrix = 3; + + // ------------------------------------------------------------------------------------ + + template + struct matrix_type_id + { + const static int value = unknown_matrix; + }; + + template + struct matrix_type_id > + { + const static int value = general_matrix; + }; + + template + struct matrix_type_id > + { + const static int value = column_matrix; + }; + + template + struct matrix_type_id > + { + const static int value = column_matrix; + }; + + template + struct matrix_type_id > + { + const static int value = row_matrix; + }; + + // ------------------------------------------------------------------------------------ + + template + struct matrix_type_id > > > + { + const static int value = column_matrix; + }; + + template + struct matrix_type_id > > > + { + const static int value = row_matrix; + }; + + template + struct matrix_type_id > > > + { + const static int value = column_matrix; + }; + + template + struct matrix_type_id > > > + { + const static int value = row_matrix; + }; + + template + struct matrix_type_id > > > + { + const static int value = general_matrix; + }; + + template < typename T, typename MM > + struct matrix_type_id > > > + { const static int value = general_matrix; }; + + template < typename T, typename MM > + struct matrix_type_id > > > + { const static int value = column_matrix; }; + + template < typename value_type, typename alloc > + struct matrix_type_id > > > + { const static int value = column_matrix; }; + + template < typename value_type, typename alloc > + struct matrix_type_id > > > + { const static int value = column_matrix; }; + + template < typename T > + struct matrix_type_id > > + { const static int value = column_matrix; }; + template < typename T > + struct matrix_type_id > > + { const static int value = general_matrix; }; + + // ------------------------------------------------------------------------------------ + + template + struct same_matrix + { + const static int T_id = matrix_type_id::value; + const static int U_id = matrix_type_id::value; + // The check for unknown_matrix is here so that we can be sure that matrix types + // other than the ones specifically enumerated above never get pushed into + // any of the BLAS bindings. So saying they are never the same as anything + // else prevents them from matching any of the BLAS bindings. + const static bool value = (T_id == U_id) && (T_id != unknown_matrix); + }; + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + // This template struct is used to tell us if two matrix expressions both contain the same + // sequence of operators, expressions. It also only has a value of true if the T expression + // contains only matrices with the given layout. + template + struct same_exp + { + const static bool value = (is_same_type::value || + same_matrix::value) && + is_same_type::value; + + }; + + // Used only below. They help strip off the const and & qualifiers that can show up + // in the LHS_ref_type and RHS_ref_type typedefs. + template struct noref{ typedef T type;}; + template struct noref{ typedef T type;}; + template struct noref{ typedef T type;}; + template struct noref{ typedef T type;}; + + template + struct same_exp, matrix_multiply_exp,layout > + { + // The reason this case is more complex than the others is because the matrix_multiply_exp + // will use a temporary matrix instead of Tlhs or Trhs in the event that one of these + // types corresponds to an expensive expression. So we have to use the type that really + // gets used. The following typedefs are here to pick out that true type. + typedef typename matrix_multiply_exp::LHS_ref_type T_LHS_ref_type; + typedef typename matrix_multiply_exp::RHS_ref_type T_RHS_ref_type; + typedef typename noref::type T_lhs_type; + typedef typename noref::type T_rhs_type; + + typedef typename matrix_multiply_exp::LHS_ref_type U_LHS_ref_type; + typedef typename matrix_multiply_exp::RHS_ref_type U_RHS_ref_type; + typedef typename noref::type U_lhs_type; + typedef typename noref::type U_rhs_type; + + const static bool value = same_exp::value && + same_exp::value; + }; + + template + struct same_exp, matrix_add_exp, layout > + { const static bool value = same_exp::value && same_exp::value; }; + + template + struct same_exp, matrix_subtract_exp, layout > + { const static bool value = same_exp::value && same_exp::value; }; + + template + struct same_exp, matrix_mul_scal_exp, layout > + { const static bool value = same_exp::value; }; + + template + struct same_exp, matrix_div_scal_exp, layout > + { const static bool value = same_exp::value; }; + + template + struct same_exp >, matrix_op >, layout > + { const static bool value = same_exp::value; }; + + template + struct same_exp >, matrix_op >, layout > + { const static bool value = same_exp::value; }; + + template + struct same_exp >, matrix_op >, layout > + { const static bool value = same_exp::value; }; + + // ------------------------------------------------------------------------------------ + + struct yes_type + { + char ch; + }; + struct no_type + { + yes_type a, b; + }; + + // This is a helper that is used below to apply the same_exp template to matrix expressions. + template + typename enable_if,yes_type>::type test(U); + template + typename disable_if,no_type>::type test(U); + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp, + typename enabled = void + > + struct matrix_assign_blas_helper + { + // We are in the default version of the blas helper so this + // means there wasn't any more specific overload. So just + // let the default matrix assignment happen. + template + static void assign ( + dest_exp& dest, + const EXP& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + if (transpose == false) + matrix_assign_default(dest,src,alpha,add_to); + else + matrix_assign_default(dest,trans(src),alpha,add_to); + } + + // If we know this is a matrix multiply then apply the + // default dlib matrix multiply to speed things up a bit more + // than the above default function would. + template + static void assign ( + dest_exp& dest, + const matrix_multiply_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + // At some point I need to improve the default (i.e. non BLAS) matrix + // multiplication algorithm... + + if (alpha == static_cast(1)) + { + if (add_to == false) + { + zero_matrix(dest); + } + + if (transpose == false) + default_matrix_multiply(dest, src.lhs, src.rhs); + else + default_matrix_multiply(dest, trans(src.rhs), trans(src.lhs)); + } + else + { + if (add_to) + { + typename dest_exp::matrix_type temp(dest.nr(),dest.nc()); + zero_matrix(temp); + + if (transpose == false) + default_matrix_multiply(temp, src.lhs, src.rhs); + else + default_matrix_multiply(temp, trans(src.rhs), trans(src.lhs)); + + matrix_assign_default(dest,temp, alpha,true); + } + else + { + zero_matrix(dest); + + if (transpose == false) + default_matrix_multiply(dest, src.lhs, src.rhs); + else + default_matrix_multiply(dest, trans(src.rhs), trans(src.lhs)); + + matrix_assign_default(dest,dest, alpha, false); + } + } + } + }; + + // This is a macro to help us add overloads for the matrix_assign_blas_helper template. + // Using this macro it is easy to add overloads for arbitrary matrix expressions. +#define DLIB_ADD_BLAS_BINDING(src_expression) \ + template struct BOOST_JOIN(blas,__LINE__) \ + { const static bool value = sizeof(yes_type) == sizeof(test(src_expression)); }; \ + \ + template < typename dest_exp, typename src_exp > \ + struct matrix_assign_blas_helper >::type > { \ + static void assign ( \ + dest_exp& dest, \ + const src_exp& src, \ + typename src_exp::type alpha, \ + bool add_to, \ + bool DLIB_NO_WARN_UNUSED transpose \ + ) { \ + DLIB_NO_WARN_UNUSED typedef typename dest_exp::type T; + +#define DLIB_END_BLAS_BINDING }}; + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + // ------------------- Forward Declarations ------------------- + + template < + typename dest_exp, + typename src_exp + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const src_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + template < + typename dest_exp, + typename src_exp, typename src_exp2 + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_add_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + template < + typename dest_exp, + typename src_exp, bool Sb + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_mul_scal_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + template < + typename dest_exp, + typename src_exp + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_op >& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + template < + typename dest_exp, + typename src_exp, typename src_exp2 + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_subtract_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix& dest, + const src_exp& src + ); + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix& dest, + const matrix_add_exp ,src_exp>& src + ); + /*! + This function catches the expressions of the form: + M = M + exp; + and converts them into the appropriate matrix_assign_blas() call. + This is an important case to catch because it is the expression used + to represent the += matrix operator. + !*/ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix& dest, + const matrix_add_exp >& src + ); + /*! + This function catches the expressions of the form: + M = exp + M; + and converts them into the appropriate matrix_assign_blas() call. + This is an important case to catch because it is the expression used + to represent the += matrix operator. + !*/ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix& dest, + const matrix_subtract_exp ,src_exp>& src + ); + /*! + This function catches the expressions of the form: + M = M - exp; + and converts them into the appropriate matrix_assign_blas() call. + This is an important case to catch because it is the expression used + to represent the -= matrix operator. + !*/ + + + // End of forward declarations for overloaded matrix_assign_blas functions + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const src_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + matrix_assign_blas_helper::assign(dest,src,alpha,add_to, transpose); + } + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp, typename src_exp2 + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_add_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + if (has_matrix_multiply::value || has_matrix_multiply::value) + { + matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to, transpose); + matrix_assign_blas_proxy(dest, src.rhs, alpha, true, transpose); + } + else + { + if (transpose == false) + matrix_assign_default(dest, src, alpha, add_to); + else + matrix_assign_default(dest, trans(src), alpha, add_to); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp, bool Sb + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_mul_scal_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + matrix_assign_blas_proxy(dest, src.m, alpha*src.s, add_to, transpose); + } + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_op >& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + matrix_assign_blas_proxy(dest, src.op.m, alpha, add_to, !transpose); + } + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp, typename src_exp2 + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_subtract_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + + if (has_matrix_multiply::value || has_matrix_multiply::value) + { + matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to, transpose); + matrix_assign_blas_proxy(dest, src.rhs, -alpha, true, transpose); + } + else + { + if (transpose == false) + matrix_assign_default(dest, src, alpha, add_to); + else + matrix_assign_default(dest, trans(src), alpha, add_to); + } + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + // Once we get into this function it means that we are dealing with a matrix of float, + // double, complex, or complex and the src_exp contains at least one + // matrix multiply. + + template < + typename T, long NR, long NC, typename MM, typename L, + long NR2, long NC2, bool Sb + > + void matrix_assign_blas ( + matrix& dest, + const matrix_mul_scal_exp,Sb>& src + ) + { + // It's ok that we don't check for aliasing in this case because there isn't + // any complex unrolling of successive + or - operators in this expression. + matrix_assign_blas_proxy(dest,src.m,src.s,false, false); + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix& dest, + const src_exp& src + ) + { + if (src.aliases(dest)) + { + matrix temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + temp.swap(dest); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + assignable_sub_matrix& dest, + const src_exp& src + ) + { + if (src.aliases(dest.m)) + { + matrix temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + matrix_assign_default(dest,temp); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + typename src_exp + > + void matrix_assign_blas ( + assignable_ptr_matrix& dest, + const src_exp& src + ) + { + if (src.aliases(mat(dest.ptr,dest.height,dest.width))) + { + matrix temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + matrix_assign_default(dest,temp); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + assignable_row_matrix& dest, + const src_exp& src + ) + { + if (src.aliases(dest.m)) + { + matrix temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + matrix_assign_default(dest,temp); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + assignable_col_matrix& dest, + const src_exp& src + ) + { + if (src.aliases(dest.m)) + { + matrix temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + matrix_assign_default(dest,temp); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix& dest, + const matrix_add_exp ,src_exp>& src + ) + { + if (src.rhs.aliases(dest) == false) + { + if (&src.lhs != &dest) + { + dest = src.lhs; + } + + matrix_assign_blas_proxy(dest, src.rhs, 1, true, false); + } + else + { + matrix temp(src.lhs); + matrix_assign_blas_proxy(temp, src.rhs, 1, true, false); + temp.swap(dest); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix& dest, + const matrix_add_exp >& src + ) + { + // Just switch around the left and right hand sides of the incoming + // add expression and pass it back into matrix_assign_blas() so that + // the above function will be called. + typedef matrix_add_exp ,src_exp> swapped_add_exp; + matrix_assign_blas(dest, swapped_add_exp(src.rhs, src.lhs)); + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix& dest, + const matrix_subtract_exp ,src_exp>& src + ) + { + if (src.rhs.aliases(dest) == false) + { + if (&src.lhs != &dest) + { + dest = src.lhs; + } + + matrix_assign_blas_proxy(dest, src.rhs, -1, true, false); + } + else + { + matrix temp(src.lhs); + matrix_assign_blas_proxy(temp, src.rhs, -1, true, false); + temp.swap(dest); + } + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + } // end of namespace blas_bindings + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + inline typename enable_if_c<(is_same_type::value || + is_same_type::value || + is_same_type >::value || + is_same_type >::value) && + blas_bindings::has_matrix_multiply::value + >::type matrix_assign_big ( + matrix& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + inline typename enable_if_c<(is_same_type::value || + is_same_type::value || + is_same_type >::value || + is_same_type >::value) && + blas_bindings::has_matrix_multiply::value + >::type matrix_assign_big ( + assignable_sub_matrix& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename src_exp + > + inline typename enable_if_c<(is_same_type::value || + is_same_type::value || + is_same_type >::value || + is_same_type >::value) && + blas_bindings::has_matrix_multiply::value + >::type matrix_assign_big ( + assignable_ptr_matrix& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + inline typename enable_if_c<(is_same_type::value || + is_same_type::value || + is_same_type >::value || + is_same_type >::value) && + blas_bindings::has_matrix_multiply::value + >::type matrix_assign_big ( + assignable_row_matrix& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + inline typename enable_if_c<(is_same_type::value || + is_same_type::value || + is_same_type >::value || + is_same_type >::value) && + blas_bindings::has_matrix_multiply::value + >::type matrix_assign_big ( + assignable_col_matrix& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_ASSIGn_ + diff --git a/ml/dlib/dlib/matrix/matrix_assign_fwd.h b/ml/dlib/dlib/matrix/matrix_assign_fwd.h new file mode 100644 index 000000000..7d29baf0a --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_assign_fwd.h @@ -0,0 +1,413 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_ASSIGn_FWD_ +#define DLIB_MATRIx_ASSIGn_FWD_ + +// GCC 4.8 gives false alarms about some variables being uninitialized. Disable these +// false warnings. +#if defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4)) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif + +#include "../enable_if.h" +#include "matrix_data_layout.h" +#include "../algs.h" + +namespace dlib +{ + + /* + The point of the matrix_assign() functions is to contain all the various + optimizations that help the matrix assign a matrix_exp to an actual matrix + object quickly. + */ + +// ---------------------------------------------------------------------------------------- + + namespace ma + { + // This template here controls how big a compile time sized matrix needs + // to be for it to get passed into the optimized versions of the + // matrix_assign() function. So small matrices are evaluated with a simple + // loop like the ones in this file and bigger matrices may get sent to BLAS + // routines or some other kind of optimized thing. + template < typename EXP, typename enable = void > + struct is_small_matrix { static const bool value = false; }; + template < typename EXP > + struct is_small_matrix=1 && EXP::NC>=1 && + EXP::NR<=17 && EXP::NC<=17 && (EXP::cost <= 70)>::type> { static const bool value = true; }; + + // I wouldn't use this mul object to do the multiply but visual studio 7.1 wouldn't + // compile otherwise. + template + struct mul { const static long value = a*b; }; + + template < typename EXP, typename enable = void > + struct is_very_small_matrix { static const bool value = false; }; + template < typename EXP > + struct is_very_small_matrix=1 && EXP::NC>=1 && + (mul::value <= 16) && (EXP::cost <= 70)>::type> { static const bool value = true; }; + + + template < typename EXP, typename enable = void > + struct has_column_major_layout { static const bool value = false; }; + template < typename EXP > + struct has_column_major_layout >::type > + { static const bool value = true; }; + + + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + class matrix_exp; + +// ---------------------------------------------------------------------------------------- + + template + inline typename disable_if >::type + matrix_assign_default ( + EXP1& dest, + const EXP2& src + ) + /*! + requires + - src.destructively_aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + ensures + - #dest == src + !*/ + { + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + dest(r,c) = src(r,c); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + inline typename enable_if >::type + matrix_assign_default ( + EXP1& dest, + const EXP2& src + ) + /*! + requires + - src.destructively_aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + ensures + - #dest == src + !*/ + { + for (long c = 0; c < src.nc(); ++c) + { + for (long r = 0; r < src.nr(); ++r) + { + dest(r,c) = src(r,c); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + inline typename disable_if >::type + matrix_assign_default ( + EXP1& dest, + const EXP2& src, + typename EXP2::type alpha, + bool add_to + ) + /*! + requires + - src.destructively_aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + ensures + - if (add_to == false) then + - #dest == alpha*src + - else + - #dest == dest + alpha*src + !*/ + { + if (add_to) + { + if (alpha == static_cast(1)) + { + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + dest(r,c) += src(r,c); + } + } + } + else if (alpha == static_cast(-1)) + { + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + dest(r,c) -= src(r,c); + } + } + } + else + { + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + dest(r,c) += alpha*src(r,c); + } + } + } + } + else + { + if (alpha == static_cast(1)) + { + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + dest(r,c) = src(r,c); + } + } + } + else + { + for (long r = 0; r < src.nr(); ++r) + { + for (long c = 0; c < src.nc(); ++c) + { + dest(r,c) = alpha*src(r,c); + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + inline typename enable_if >::type + matrix_assign_default ( + EXP1& dest, + const EXP2& src, + typename EXP2::type alpha, + bool add_to + ) + /*! + requires + - src.destructively_aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + ensures + - if (add_to == false) then + - #dest == alpha*src + - else + - #dest == dest + alpha*src + !*/ + { + if (add_to) + { + if (alpha == static_cast(1)) + { + for (long c = 0; c < src.nc(); ++c) + { + for (long r = 0; r < src.nr(); ++r) + { + dest(r,c) += src(r,c); + } + } + } + else if (alpha == static_cast(-1)) + { + for (long c = 0; c < src.nc(); ++c) + { + for (long r = 0; r < src.nr(); ++r) + { + dest(r,c) -= src(r,c); + } + } + } + else + { + for (long c = 0; c < src.nc(); ++c) + { + for (long r = 0; r < src.nr(); ++r) + { + dest(r,c) += alpha*src(r,c); + } + } + } + } + else + { + if (alpha == static_cast(1)) + { + for (long c = 0; c < src.nc(); ++c) + { + for (long r = 0; r < src.nr(); ++r) + { + dest(r,c) = src(r,c); + } + } + } + else + { + for (long c = 0; c < src.nc(); ++c) + { + for (long r = 0; r < src.nr(); ++r) + { + dest(r,c) = alpha*src(r,c); + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_dest_type, + typename src_exp + > + void matrix_assign_big ( + matrix_dest_type& dest, + const matrix_exp& src + ) + { + matrix_assign_default(dest,src); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_dest_type, + typename src_exp + > + inline typename disable_if >::type matrix_assign ( + matrix_dest_type& dest, + const matrix_exp& src + ) + /*! + requires + - src.destructively_aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + ensures + - #dest == src + !*/ + { + // Call src.ref() here so that the derived type of the matrix_exp shows + // up so we can overload matrix_assign_big() based on various matrix expression + // types. + matrix_assign_big(dest,src.ref()); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// this code is here to perform an unrolled version of the matrix_assign() function + template < typename DEST, typename SRC, long NR, long NC, + long R = 0, long C = 0, bool base_case = (R==NR) > + struct matrix_unroll_helper + { + inline static void go ( DEST& dest, const SRC& src) + { + dest(R,C) = src(R,C); + matrix_unroll_helper::go(dest,src); + } + }; + + template < typename DEST, typename SRC, long NR, long NC, long R, long C > + struct matrix_unroll_helper + { inline static void go ( DEST& , const SRC& ) {} }; + + template + inline void matrix_assign_unrolled ( + DEST& dest, + const SRC& src + ) + /*! + requires + - src.destructively_aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + ensures + - #dest == src + !*/ + { + COMPILE_TIME_ASSERT(SRC::NR*SRC::NC != 0); + matrix_unroll_helper::go(dest,src); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_dest_type, + typename src_exp + > + inline typename enable_if_c::value && ma::is_very_small_matrix::value==false >::type matrix_assign ( + matrix_dest_type& dest, + const matrix_exp& src + ) + /*! + requires + - src.destructively_aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + ensures + - #dest == src + !*/ + { + matrix_assign_default(dest,src.ref()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_dest_type, + typename src_exp + > + inline typename enable_if_c::value && ma::is_very_small_matrix::value==true >::type matrix_assign ( + matrix_dest_type& dest, + const matrix_exp& src + ) + /*! + requires + - src.destructively_aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + ensures + - #dest == src + !*/ + { + matrix_assign_unrolled(dest,src.ref()); + } + +// ---------------------------------------------------------------------------------------- + +} + +#if defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4)) +#pragma GCC diagnostic pop +#endif + +#endif // DLIB_MATRIx_ASSIGn_FWD_ + + diff --git a/ml/dlib/dlib/matrix/matrix_blas_bindings.h b/ml/dlib/dlib/matrix/matrix_blas_bindings.h new file mode 100644 index 000000000..b65e29cdd --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_blas_bindings.h @@ -0,0 +1,1637 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_BLAS_BINDINGS_ +#define DLIB_MATRIx_BLAS_BINDINGS_ + +#ifndef DLIB_USE_BLAS +#error "DLIB_USE_BLAS should be defined if you want to use the BLAS bindings" +#endif + +#include "matrix_assign.h" +#include "matrix_conj_trans.h" +#include "cblas_constants.h" + +//#include +//using namespace std; + +namespace dlib +{ + + + namespace blas_bindings + { + +#ifdef DLIB_TEST_BLAS_BINDINGS + int& counter_gemm(); + int& counter_gemv(); + int& counter_ger(); + int& counter_dot(); + int& counter_axpy(); + int& counter_scal(); + + #define DLIB_TEST_BLAS_BINDING_GEMM ++counter_gemm(); + #define DLIB_TEST_BLAS_BINDING_GEMV ++counter_gemv(); + #define DLIB_TEST_BLAS_BINDING_GER ++counter_ger(); + #define DLIB_TEST_BLAS_BINDING_DOT ++counter_dot(); + #define DLIB_TEST_BLAS_BINDING_AXPY ++counter_axpy(); + #define DLIB_TEST_BLAS_BINDING_SCAL ++counter_scal(); +#else + #define DLIB_TEST_BLAS_BINDING_GEMM + #define DLIB_TEST_BLAS_BINDING_GEMV + #define DLIB_TEST_BLAS_BINDING_GER + #define DLIB_TEST_BLAS_BINDING_DOT + #define DLIB_TEST_BLAS_BINDING_AXPY + #define DLIB_TEST_BLAS_BINDING_SCAL +#endif + +#ifndef CBLAS_H + extern "C" + { + // Here we declare the prototypes for the CBLAS calls used by the BLAS bindings below + + void cblas_saxpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY); + void cblas_daxpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY); + void cblas_caxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); + void cblas_zaxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); + + void cblas_sscal(const int N, const float alpha, float *X, const int incX); + void cblas_dscal(const int N, const double alpha, double *X, const int incX); + void cblas_cscal(const int N, const void *alpha, void *X, const int incX); + void cblas_zscal(const int N, const void *alpha, void *X, const int incX); + + void cblas_sgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const float alpha, const float *A, + const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc); + void cblas_dgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const double alpha, const double *A, + const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); + void cblas_cgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); + void cblas_zgemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); + void cblas_sgemv(const CBLAS_ORDER order, + const CBLAS_TRANSPOSE TransA, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY); + void cblas_dgemv(const CBLAS_ORDER order, + const CBLAS_TRANSPOSE TransA, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY); + void cblas_cgemv(const CBLAS_ORDER order, + const CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); + void cblas_zgemv(const CBLAS_ORDER order, + const CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); + void cblas_sger(const CBLAS_ORDER order, const int M, const int N, + const float alpha, const float *X, const int incX, + const float *Y, const int incY, float *A, const int lda); + void cblas_dger(const CBLAS_ORDER order, const int M, const int N, + const double alpha, const double *X, const int incX, + const double *Y, const int incY, double *A, const int lda); + void cblas_cgerc(const CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); + void cblas_zgerc(const CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); + float cblas_sdot(const int N, const float *X, const int incX, + const float *Y, const int incY); + double cblas_ddot(const int N, const double *X, const int incX, + const double *Y, const int incY); + void cblas_cdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); + void cblas_zdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); + void cblas_cdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + void cblas_zdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + void cblas_cgeru(const CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); + void cblas_zgeru(const CBLAS_ORDER order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); + } +#endif // if not CBLAS_H + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + inline void cblas_axpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_AXPY; + cblas_saxpy(N, alpha, X, incX, Y, incY); + } + + inline void cblas_axpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_AXPY; + cblas_daxpy(N, alpha, X, incX, Y, incY); + } + + inline void cblas_axpy(const int N, const std::complex& alpha, const std::complex *X, + const int incX, std::complex *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_AXPY; + cblas_caxpy(N, &alpha, X, incX, Y, incY); + } + + inline void cblas_axpy(const int N, const std::complex& alpha, const std::complex *X, + const int incX, std::complex *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_AXPY; + cblas_zaxpy(N, &alpha, X, incX, Y, incY); + } + + // ---------------------------------------------------------------------------------------- + + inline void cblas_scal(const int N, const float alpha, float *X) + { + DLIB_TEST_BLAS_BINDING_SCAL; + cblas_sscal(N, alpha, X, 1); + } + + inline void cblas_scal(const int N, const double alpha, double *X) + { + DLIB_TEST_BLAS_BINDING_SCAL; + cblas_dscal(N, alpha, X, 1); + } + + inline void cblas_scal(const int N, const std::complex& alpha, std::complex *X) + { + DLIB_TEST_BLAS_BINDING_SCAL; + cblas_cscal(N, &alpha, X, 1); + } + + inline void cblas_scal(const int N, const std::complex& alpha, std::complex *X) + { + DLIB_TEST_BLAS_BINDING_SCAL; + cblas_zscal(N, &alpha, X, 1); + } + + // ---------------------------------------------------------------------------------------- + + inline void cblas_gemm( const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const float alpha, const float *A, + const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc) + { + DLIB_TEST_BLAS_BINDING_GEMM; + cblas_sgemm( Order, TransA, TransB, M, N, + K, alpha, A, lda, B, ldb, beta, C, ldc); + } + + inline void cblas_gemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const double alpha, const double *A, + const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc) + { + DLIB_TEST_BLAS_BINDING_GEMM; + cblas_dgemm( Order, TransA, TransB, M, N, + K, alpha, A, lda, B, ldb, beta, C, ldc); + } + + inline void cblas_gemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const std::complex& alpha, const std::complex *A, + const int lda, const std::complex *B, const int ldb, + const std::complex& beta, std::complex *C, const int ldc) + { + DLIB_TEST_BLAS_BINDING_GEMM; + cblas_cgemm( Order, TransA, TransB, M, N, + K, &alpha, A, lda, B, ldb, &beta, C, ldc); + } + + inline void cblas_gemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const std::complex& alpha, const std::complex *A, + const int lda, const std::complex *B, const int ldb, + const std::complex& beta, std::complex *C, const int ldc) + { + DLIB_TEST_BLAS_BINDING_GEMM; + cblas_zgemm( Order, TransA, TransB, M, N, + K, &alpha, A, lda, B, ldb, &beta, C, ldc); + } + + // ---------------------------------------------------------------------------------------- + + inline void cblas_gemv(const CBLAS_ORDER order, + const CBLAS_TRANSPOSE TransA, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_GEMV; + cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } + + inline void cblas_gemv(const CBLAS_ORDER order, + const CBLAS_TRANSPOSE TransA, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_GEMV; + cblas_dgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } + + inline void cblas_gemv(const CBLAS_ORDER order, + const CBLAS_TRANSPOSE TransA, const int M, const int N, + const std::complex& alpha, const std::complex *A, const int lda, + const std::complex *X, const int incX, const std::complex& beta, + std::complex *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_GEMV; + cblas_cgemv(order, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); + } + + inline void cblas_gemv(const CBLAS_ORDER order, + const CBLAS_TRANSPOSE TransA, const int M, const int N, + const std::complex& alpha, const std::complex *A, const int lda, + const std::complex *X, const int incX, const std::complex& beta, + std::complex *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_GEMV; + cblas_zgemv(order, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); + } + + // ---------------------------------------------------------------------------------------- + + inline void cblas_ger(const CBLAS_ORDER order, const int M, const int N, + const std::complex& alpha, const std::complex *X, const int incX, + const std::complex *Y, const int incY, std::complex *A, const int lda) + { + DLIB_TEST_BLAS_BINDING_GER; + cblas_cgeru (order, M, N, &alpha, X, incX, Y, incY, A, lda); + } + + inline void cblas_ger(const CBLAS_ORDER order, const int M, const int N, + const std::complex& alpha, const std::complex *X, const int incX, + const std::complex *Y, const int incY, std::complex *A, const int lda) + { + DLIB_TEST_BLAS_BINDING_GER; + cblas_zgeru (order, M, N, &alpha, X, incX, Y, incY, A, lda); + } + + inline void cblas_ger(const CBLAS_ORDER order, const int M, const int N, + const float alpha, const float *X, const int incX, + const float *Y, const int incY, float *A, const int lda) + { + DLIB_TEST_BLAS_BINDING_GER; + cblas_sger (order, M, N, alpha, X, incX, Y, incY, A, lda); + } + + inline void cblas_ger(const CBLAS_ORDER order, const int M, const int N, + const double alpha, const double *X, const int incX, + const double *Y, const int incY, double *A, const int lda) + { + DLIB_TEST_BLAS_BINDING_GER; + cblas_dger (order, M, N, alpha, X, incX, Y, incY, A, lda); + } + + // ---------------------------------------------------------------------------------------- + + inline void cblas_gerc(const CBLAS_ORDER order, const int M, const int N, + const std::complex& alpha, const std::complex *X, const int incX, + const std::complex *Y, const int incY, std::complex *A, const int lda) + { + DLIB_TEST_BLAS_BINDING_GER; + cblas_cgerc (order, M, N, &alpha, X, incX, Y, incY, A, lda); + } + + inline void cblas_gerc(const CBLAS_ORDER order, const int M, const int N, + const std::complex& alpha, const std::complex *X, const int incX, + const std::complex *Y, const int incY, std::complex *A, const int lda) + { + DLIB_TEST_BLAS_BINDING_GER; + cblas_zgerc (order, M, N, &alpha, X, incX, Y, incY, A, lda); + } + + // ---------------------------------------------------------------------------------------- + + inline float cblas_dot(const int N, const float *X, const int incX, + const float *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_DOT; + return cblas_sdot(N, X, incX, Y, incY); + } + + inline double cblas_dot(const int N, const double *X, const int incX, + const double *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_DOT; + return cblas_ddot(N, X, incX, Y, incY); + } + + inline std::complex cblas_dot(const int N, const std::complex *X, const int incX, + const std::complex *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_DOT; + std::complex result; + cblas_cdotu_sub(N, X, incX, Y, incY, &result); + return result; + } + + inline std::complex cblas_dot(const int N, const std::complex *X, const int incX, + const std::complex *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_DOT; + std::complex result; + cblas_zdotu_sub(N, X, incX, Y, incY, &result); + return result; + } + + // ---------------------------------------------------------------------------------------- + + inline std::complex cblas_dotc(const int N, const std::complex *X, const int incX, + const std::complex *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_DOT; + std::complex result; + cblas_cdotc_sub(N, X, incX, Y, incY, &result); + return result; + } + + inline std::complex cblas_dotc(const int N, const std::complex *X, const int incX, + const std::complex *Y, const int incY) + { + DLIB_TEST_BLAS_BINDING_DOT; + std::complex result; + cblas_zdotc_sub(N, X, incX, Y, incY, &result); + return result; + } + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // Helpers for determining the data pointer, LDA, and incX arguments to BLAS functions. + + template + int get_ld (const matrix& m) { return m.nc(); } + + template + int get_ld (const matrix& m) { return m.nr(); } + + + template + int get_ld (const matrix_op > >& m) { return m.op.m.nc(); } + + template + int get_ld (const matrix_op > >& m) { return m.op.m.nr(); } + + template + int get_ld (const assignable_sub_matrix& m) { return m.m.nc(); } + + template + int get_ld (const assignable_sub_matrix& m) { return m.m.nr(); } + + template + int get_ld (const assignable_col_matrix& m) { return m.m.nc(); } + + template + int get_ld (const assignable_col_matrix& m) { return m.m.nr(); } + + template + int get_ld (const assignable_row_matrix& m) { return m.m.nc(); } + + template + int get_ld (const assignable_row_matrix& m) { return m.m.nr(); } + + template + int get_ld (const assignable_ptr_matrix& m) { return m.nc(); } + + template + int get_ld (const matrix_op > >& m) { return m.nc(); } + template + int get_ld (const matrix_op > >& m) { return m.nc(); } + template < typename value_type, typename alloc > + int get_ld (const matrix_op > >& m) { return m.nc(); } + template < typename value_type, typename alloc > + int get_ld (const matrix_op > >& m) { return m.nc(); } + template + int get_ld (const matrix_op >& m) { return m.nc(); } + template + int get_ld (const matrix_op >& m) { return m.op.stride; } + + // -------- + + // get_inc() returns the offset from one element to another. If an object has a + // non-uniform offset between elements then returns 0 (e.g. a subm() view could + // have a non-uniform offset between elements). + + template + int get_inc (const matrix_op > >& ) { return 1; } + template + int get_inc (const matrix_op > >& ) { return 1; } + template < typename value_type, typename alloc > + int get_inc (const matrix_op > >& ) { return 1; } + template < typename value_type, typename alloc > + int get_inc (const matrix_op > >& ) { return 1; } + template + int get_inc (const matrix_op >& ) { return 1; } + template + int get_inc (const matrix_op >& m) { return m.op.stride==m.op.cols ? 1 : 0; } + + template + int get_inc (const matrix& ) { return 1; } + + template + int get_inc (const matrix_op > >& m) + { + // if the sub-view doesn't cover all the columns then it can't have a uniform + // layout. + if (m.nc() < m.op.m.nc()) + return 0; + else + return 1; + } + + template + int get_inc (const matrix_op > >& m) + { + if (m.nr() < m.op.m.nr()) + return 0; + else + return 1; + } + + template + int get_inc (const assignable_sub_matrix& m) + { + if (m.nc() < m.m.nc()) + return 0; + else + return 1; + } + template + int get_inc (const assignable_sub_matrix& m) + { + if (m.nr() < m.m.nr()) + return 0; + else + return 1; + } + + template + int get_inc (const assignable_ptr_matrix& ) { return 1; } + + template + int get_inc(const matrix_op > >& m) + { + return m.op.m.nc(); + } + + template + int get_inc(const matrix_op > >& ) + { + return 1; + } + + template + int get_inc(const matrix_op > >& m) + { + return m.op.m.nc(); + } + + template + int get_inc(const matrix_op > >& ) + { + return 1; + } + + + + template + int get_inc(const matrix_op > >& ) + { + return 1; + } + + template + int get_inc(const matrix_op > >& m) + { + return m.op.m.nr(); + } + + template + int get_inc(const matrix_op > >& ) + { + return 1; + } + + template + int get_inc(const matrix_op > >& m) + { + return m.op.m.nr(); + } + + + + template + int get_inc(const assignable_row_matrix& ) + { + return 1; + } + + template + int get_inc(const assignable_row_matrix& m) + { + return m.m.nr(); + } + + template + int get_inc(const assignable_col_matrix& m) + { + return m.m.nc(); + } + + template + int get_inc(const assignable_col_matrix& ) + { + return 1; + } + + // -------- + + template + const T* get_ptr (const matrix& m) { return &m(0,0); } + + template + T* get_ptr (matrix& m) { return &m(0,0); } + + template + const T* get_ptr (const matrix_op > >& m) { return &m.op.m(m.op.r_,m.op.c_); } + + template + const T* get_ptr (const matrix_op > >& m) { return &m.op.m(0,m.op.col); } + + template + const T* get_ptr (const matrix_op > >& m) { return &m.op.m(m.op.row,0); } + + template + const T* get_ptr (const matrix_op > >& m) { return &m.op.m(0,m.op.col); } + + template + const T* get_ptr (const matrix_op > >& m) { return &m.op.m(m.op.row,0); } + + + template + T* get_ptr (assignable_col_matrix& m) { return &m(0,0); } + + template + T* get_ptr (assignable_row_matrix& m) { return &m(0,0); } + + template + T* get_ptr (assignable_sub_matrix& m) { return &m(0,0); } + + template + T* get_ptr (assignable_ptr_matrix& m) { return m.ptr; } + + template + const T* get_ptr (const matrix_op > >& m) { return &m.op.array[0][0]; } + template + const T* get_ptr (const matrix_op > >& m) { return &m.op.vect[0]; } + template < typename T, typename alloc > + const T* get_ptr (const matrix_op > >& m) { return &m.op.vect[0]; } + template < typename T, typename alloc > + const T* get_ptr (const matrix_op > >& m) { return &m.op.vect[0]; } + template + const T* get_ptr (const matrix_op >& m) { return m.op.ptr; } + template + const T* get_ptr (const matrix_op >& m) { return m.op.ptr; } + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + // Here we declare some matrix objects for use in the DLIB_ADD_BLAS_BINDING macro. These + // extern declarations don't actually correspond to any real matrix objects. They are + // simply here so we can build matrix expressions with the DLIB_ADD_BLAS_BINDING marco. + + + // Note that the fact that these are double matrices isn't important, it is just a placeholder in this case. + extern matrix m; // general matrix + extern matrix rv; // general row vector + extern matrix cv; // general column vector + extern const double s; + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // AXPY/SCAL overloads + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + DLIB_ADD_BLAS_BINDING(m) + { + + const int N = static_cast(src.size()); + if (transpose == false && N != 0) + { + if (add_to) + { + if (get_inc(src) && get_inc(dest)) + cblas_axpy(N, alpha, get_ptr(src), get_inc(src), get_ptr(dest), get_inc(dest)); + else + matrix_assign_default(dest, src, alpha, add_to); + } + else + { + if (get_ptr(src) == get_ptr(dest)) + cblas_scal(N, alpha, get_ptr(dest)); + else + matrix_assign_default(dest, src, alpha, add_to); + } + } + else + { + matrix_assign_default(dest, trans(src), alpha, add_to); + } + + } DLIB_END_BLAS_BINDING + + DLIB_ADD_BLAS_BINDING(rv) + { + + const int N = static_cast(src.size()); + if (transpose == false && N != 0) + { + if (add_to) + { + if (get_inc(src) && get_inc(dest)) + cblas_axpy(N, alpha, get_ptr(src), get_inc(src), get_ptr(dest), get_inc(dest)); + else + matrix_assign_default(dest, src, alpha, add_to); + } + else + { + if (get_ptr(src) == get_ptr(dest)) + cblas_scal(N, alpha, get_ptr(dest)); + else + matrix_assign_default(dest, src, alpha, add_to); + } + } + else + { + matrix_assign_default(dest, trans(src), alpha, add_to); + } + + } DLIB_END_BLAS_BINDING + + DLIB_ADD_BLAS_BINDING(cv) + { + + const int N = static_cast(src.size()); + if (transpose == false && N != 0) + { + if (add_to) + { + if (get_inc(src) && get_inc(dest)) + cblas_axpy(N, alpha, get_ptr(src), get_inc(src), get_ptr(dest), get_inc(dest)); + else + matrix_assign_default(dest, src, alpha, add_to); + } + else + { + if (get_ptr(src) == get_ptr(dest)) + cblas_scal(N, alpha, get_ptr(dest)); + else + matrix_assign_default(dest, src, alpha, add_to); + } + } + else + { + matrix_assign_default(dest, trans(src), alpha, add_to); + } + + } DLIB_END_BLAS_BINDING + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // GEMM overloads + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + DLIB_ADD_BLAS_BINDING(m*m) + { + //cout << "BLAS GEMM: m*m" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(src.nr()); + const int N = static_cast(src.nc()); + const int K = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs); + const int lda = get_ld(src.lhs); + const T* B = get_ptr(src.rhs); + const int ldb = get_ld(src.rhs); + + const T beta = static_cast(add_to?1:0); + T* C = get_ptr(dest); + const int ldc = get_ld(dest); + + if (transpose == false) + cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + else + cblas_gemm(Order, CblasTrans, CblasTrans, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(m)*m) + { + //cout << "BLAS GEMM: trans(m)*m" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasTrans; + const CBLAS_TRANSPOSE TransB = CblasNoTrans; + const int M = static_cast(src.nr()); + const int N = static_cast(src.nc()); + const int K = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* B = get_ptr(src.rhs); + const int ldb = get_ld(src.rhs); + + const T beta = static_cast(add_to?1:0); + T* C = get_ptr(dest); + const int ldc = get_ld(dest); + + if (transpose == false) + cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + else + cblas_gemm(Order, TransA, TransB, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(m*trans(m)) + { + //cout << "BLAS GEMM: m*trans(m)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasNoTrans; + const CBLAS_TRANSPOSE TransB = CblasTrans; + const int M = static_cast(src.nr()); + const int N = static_cast(src.nc()); + const int K = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs); + const int lda = get_ld(src.lhs); + const T* B = get_ptr(src.rhs.op.m); + const int ldb = get_ld(src.rhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* C = get_ptr(dest); + const int ldc = get_ld(dest); + + if (transpose == false) + cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + else + cblas_gemm(Order, TransA, TransB, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(m)*trans(m)) + { + //cout << "BLAS GEMM: trans(m)*trans(m)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(src.nr()); + const int N = static_cast(src.nc()); + const int K = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* B = get_ptr(src.rhs.op.m); + const int ldb = get_ld(src.rhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* C = get_ptr(dest); + const int ldc = get_ld(dest); + + if (transpose == false) + cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + else + cblas_gemm(Order, CblasNoTrans, CblasNoTrans, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + // -------------------------------------- + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(conj(m))*m) + { + //cout << "BLAS GEMM: trans(conj(m))*m" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasConjTrans; + const CBLAS_TRANSPOSE TransB = CblasNoTrans; + const int M = static_cast(src.nr()); + const int N = static_cast(src.nc()); + const int K = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* B = get_ptr(src.rhs); + const int ldb = get_ld(src.rhs); + + const T beta = static_cast(add_to?1:0); + T* C = get_ptr(dest); + const int ldc = get_ld(dest); + + if (transpose == false) + cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + else + matrix_assign_default(dest, trans(src), alpha, add_to); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(m)) + { + //cout << "BLAS GEMM: trans(conj(m))*trans(m)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasConjTrans; + const CBLAS_TRANSPOSE TransB = CblasTrans; + const int M = static_cast(src.nr()); + const int N = static_cast(src.nc()); + const int K = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* B = get_ptr(src.rhs.op.m); + const int ldb = get_ld(src.rhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* C = get_ptr(dest); + const int ldc = get_ld(dest); + + if (transpose == false) + cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + else + matrix_assign_default(dest, trans(src), alpha, add_to); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(m*trans(conj(m))) + { + //cout << "BLAS GEMM: m*trans(conj(m))" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasNoTrans; + const CBLAS_TRANSPOSE TransB = CblasConjTrans; + const int M = static_cast(src.nr()); + const int N = static_cast(src.nc()); + const int K = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs); + const int lda = get_ld(src.lhs); + const T* B = get_ptr(src.rhs.op.m); + const int ldb = get_ld(src.rhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* C = get_ptr(dest); + const int ldc = get_ld(dest); + + if (transpose == false) + cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + else + matrix_assign_default(dest, trans(src), alpha, add_to); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(m)*trans(conj(m))) + { + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasTrans; + const CBLAS_TRANSPOSE TransB = CblasConjTrans; + const int M = static_cast(src.nr()); + const int N = static_cast(src.nc()); + const int K = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* B = get_ptr(src.rhs.op.m); + const int ldb = get_ld(src.rhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* C = get_ptr(dest); + const int ldc = get_ld(dest); + + if (transpose == false) + cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + else + matrix_assign_default(dest, trans(src), alpha, add_to); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(conj(m))) + { + //cout << "BLAS GEMM: trans(conj(m))*trans(conj(m))" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasConjTrans; + const CBLAS_TRANSPOSE TransB = CblasConjTrans; + const int M = static_cast(src.nr()); + const int N = static_cast(src.nc()); + const int K = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* B = get_ptr(src.rhs.op.m); + const int ldb = get_ld(src.rhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* C = get_ptr(dest); + const int ldc = get_ld(dest); + + if (transpose == false) + cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + else + matrix_assign_default(dest, trans(src), alpha, add_to); + } DLIB_END_BLAS_BINDING + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // GEMV overloads + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + DLIB_ADD_BLAS_BINDING(m*cv) + { + //cout << "BLAS GEMV: m*cv" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasNoTrans; + const int M = static_cast(src.lhs.nr()); + const int N = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs); + const int lda = get_ld(src.lhs); + const T* X = get_ptr(src.rhs); + const int incX = get_inc(src.rhs); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(rv*m) + { + // Note that rv*m is the same as trans(m)*trans(rv) + + //cout << "BLAS GEMV: rv*m" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasTrans; + const int M = static_cast(src.rhs.nr()); + const int N = static_cast(src.rhs.nc()); + const T* A = get_ptr(src.rhs); + const int lda = get_ld(src.rhs); + const T* X = get_ptr(src.lhs); + const int incX = get_inc(src.lhs); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(cv)*m) + { + // Note that trans(cv)*m is the same as trans(m)*cv + + //cout << "BLAS GEMV: trans(cv)*m" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasTrans; + const int M = static_cast(src.rhs.nr()); + const int N = static_cast(src.rhs.nc()); + const T* A = get_ptr(src.rhs); + const int lda = get_ld(src.rhs); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(m*trans(rv)) + { + //cout << "BLAS GEMV: m*trans(rv)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasNoTrans; + const int M = static_cast(src.lhs.nr()); + const int N = static_cast(src.lhs.nc()); + const T* A = get_ptr(src.lhs); + const int lda = get_ld(src.lhs); + const T* X = get_ptr(src.rhs.op.m); + const int incX = get_inc(src.rhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + // -------------------------------------- + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(m)*cv) + { + //cout << "BLAS GEMV: trans(m)*cv" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasTrans; + const int M = static_cast(src.lhs.op.m.nr()); + const int N = static_cast(src.lhs.op.m.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* X = get_ptr(src.rhs); + const int incX = get_inc(src.rhs); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(rv*trans(m)) + { + // Note that rv*trans(m) is the same as m*trans(rv) + + //cout << "BLAS GEMV: rv*trans(m)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasNoTrans; + const int M = static_cast(src.rhs.op.m.nr()); + const int N = static_cast(src.rhs.op.m.nc()); + const T* A = get_ptr(src.rhs.op.m); + const int lda = get_ld(src.rhs.op.m); + const T* X = get_ptr(src.lhs); + const int incX = get_inc(src.lhs); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(cv)*trans(m)) + { + // Note that trans(cv)*trans(m) is the same as m*cv + + //cout << "BLAS GEMV: trans(cv)*trans(m)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasNoTrans; + const int M = static_cast(src.rhs.op.m.nr()); + const int N = static_cast(src.rhs.op.m.nc()); + const T* A = get_ptr(src.rhs.op.m); + const int lda = get_ld(src.rhs.op.m); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(m)*trans(rv)) + { + //cout << "BLAS GEMV: trans(m)*trans(rv)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasTrans; + const int M = static_cast(src.lhs.op.m.nr()); + const int N = static_cast(src.lhs.op.m.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* X = get_ptr(src.rhs.op.m); + const int incX = get_inc(src.rhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + // -------------------------------------- + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(cv)*conj(m)) + { + // Note that trans(cv)*conj(m) == conj(trans(m))*cv + //cout << "BLAS GEMV: trans(cv)*conj(m)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasConjTrans; + const int M = static_cast(src.rhs.op.m.nr()); + const int N = static_cast(src.rhs.op.m.nc()); + const T* A = get_ptr(src.rhs.op.m); + const int lda = get_ld(src.rhs.op.m); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(rv*conj(m)) + { + // Note that rv*conj(m) == conj(trans(m))*cv + //cout << "BLAS GEMV: rv*conj(m)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasConjTrans; + const int M = static_cast(src.rhs.op.m.nr()); + const int N = static_cast(src.rhs.op.m.nc()); + const T* A = get_ptr(src.rhs.op.m); + const int lda = get_ld(src.rhs.op.m); + const T* X = get_ptr(src.lhs); + const int incX = get_inc(src.lhs); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(conj(m))*cv) + { + //cout << "BLAS GEMV: trans(conj(m))*cv" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasConjTrans; + const int M = static_cast(src.lhs.op.m.nr()); + const int N = static_cast(src.lhs.op.m.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* X = get_ptr(src.rhs); + const int incX = get_inc(src.rhs); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(rv)) + { + //cout << "BLAS GEMV: trans(conj(m))*trans(rv)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const CBLAS_TRANSPOSE TransA = CblasConjTrans; + const int M = static_cast(src.lhs.op.m.nr()); + const int N = static_cast(src.lhs.op.m.nc()); + const T* A = get_ptr(src.lhs.op.m); + const int lda = get_ld(src.lhs.op.m); + const T* X = get_ptr(src.rhs.op.m); + const int incX = get_inc(src.rhs.op.m); + + const T beta = static_cast(add_to?1:0); + T* Y = get_ptr(dest); + const int incY = get_inc(dest); + + cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); + } DLIB_END_BLAS_BINDING + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // GER overloads + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + DLIB_ADD_BLAS_BINDING(cv*rv) + { + //cout << "BLAS GER: cv*rv" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(dest.nr()); + const int N = static_cast(dest.nc()); + const T* X = get_ptr(src.lhs); + const int incX = get_inc(src.lhs); + const T* Y = get_ptr(src.rhs); + const int incY = get_inc(src.rhs); + + if (add_to == false) + zero_matrix(dest); + + T* A = get_ptr(dest); + const int lda = get_ld(dest); + + if (transpose == false) + cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda); + else + cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(rv)*rv) + { + //cout << "BLAS GER: trans(rv)*rv" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(dest.nr()); + const int N = static_cast(dest.nc()); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + const T* Y = get_ptr(src.rhs); + const int incY = get_inc(src.rhs); + + if (add_to == false) + zero_matrix(dest); + + T* A = get_ptr(dest); + const int lda = get_ld(dest); + + if (transpose == false) + cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda); + else + cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(cv*trans(cv)) + { + //cout << "BLAS GER: cv*trans(cv)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(dest.nr()); + const int N = static_cast(dest.nc()); + const T* X = get_ptr(src.lhs); + const int incX = get_inc(src.lhs); + const T* Y = get_ptr(src.rhs.op.m); + const int incY = get_inc(src.rhs.op.m); + + if (add_to == false) + zero_matrix(dest); + + T* A = get_ptr(dest); + const int lda = get_ld(dest); + + if (transpose == false) + cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda); + else + cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda); + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(rv)*trans(cv)) + { + //cout << "BLAS GER: trans(rv)*trans(cv)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(dest.nr()); + const int N = static_cast(dest.nc()); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + const T* Y = get_ptr(src.rhs.op.m); + const int incY = get_inc(src.rhs.op.m); + + if (add_to == false) + zero_matrix(dest); + + T* A = get_ptr(dest); + const int lda = get_ld(dest); + + if (transpose == false) + cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda); + else + cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda); + } DLIB_END_BLAS_BINDING + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // GERC overloads + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + /* + DLIB_ADD_BLAS_BINDING(cv*conj(rv)) + { + //cout << "BLAS GERC: cv*conj(rv)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(dest.nr()); + const int N = static_cast(dest.nc()); + const T* X = get_ptr(src.lhs); + const int incX = get_inc(src.lhs); + const T* Y = get_ptr(src.rhs.op.m); + const int incY = get_inc(src.rhs.op.m); + + if (add_to == false) + zero_matrix(dest); + + T* A = get_ptr(dest); + const int lda = get_ld(dest); + + if (transpose == false) + cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda); + else + cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda); + } DLIB_END_BLAS_BINDING + */ + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(cv*conj(trans(cv))) + { + //cout << "BLAS GERC: cv*conj(trans(cv))" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(dest.nr()); + const int N = static_cast(dest.nc()); + const T* X = get_ptr(src.lhs); + const int incX = get_inc(src.lhs); + const T* Y = get_ptr(src.rhs.op.m); + const int incY = get_inc(src.rhs.op.m); + + + if (transpose == false) + { + T* A = get_ptr(dest); + const int lda = get_ld(dest); + + if (add_to == false) + zero_matrix(dest); + + cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda); + } + else + { + matrix_assign_default(dest,trans(src),alpha,add_to); + //cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda); + } + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(rv)*conj(trans(cv))) + { + //cout << "BLAS GERC: trans(rv)*conj(trans(cv))" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(dest.nr()); + const int N = static_cast(dest.nc()); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + const T* Y = get_ptr(src.rhs.op.m); + const int incY = get_inc(src.rhs.op.m); + + + if (transpose == false) + { + if (add_to == false) + zero_matrix(dest); + + T* A = get_ptr(dest); + const int lda = get_ld(dest); + + cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda); + } + else + { + matrix_assign_default(dest,trans(src),alpha,add_to); + //cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda); + } + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + /* + DLIB_ADD_BLAS_BINDING(trans(rv)*conj(rv)) + { + //cout << "BLAS GERC: trans(rv)*conj(rv)" << endl; + const bool is_row_major_order = is_same_type::value; + const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; + const int M = static_cast(dest.nr()); + const int N = static_cast(dest.nc()); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + const T* Y = get_ptr(src.rhs.op.m); + const int incY = get_inc(src.rhs.op.m); + + if (add_to == false) + zero_matrix(dest); + + T* A = get_ptr(dest); + const int lda = get_ld(dest); + + if (transpose == false) + cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda); + else + cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda); + } DLIB_END_BLAS_BINDING + */ + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // DOT overloads + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + DLIB_ADD_BLAS_BINDING(rv*cv) + { + //cout << "BLAS DOT: rv*cv" << endl; + const int N = static_cast(src.lhs.size()); + const T* X = get_ptr(src.lhs); + const int incX = get_inc(src.lhs); + const T* Y = get_ptr(src.rhs); + const int incY = get_inc(src.rhs); + + if (add_to == false) + dest(0) = alpha*cblas_dot(N, X, incX, Y, incY); + else + dest(0) += alpha*cblas_dot(N, X, incX, Y, incY); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(cv)*cv) + { + //cout << "BLAS DOT: trans(cv)*cv" << endl; + const int N = static_cast(src.lhs.size()); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + const T* Y = get_ptr(src.rhs); + const int incY = get_inc(src.rhs); + + if (add_to == false) + dest(0) = alpha*cblas_dot(N, X, incX, Y, incY); + else + dest(0) += alpha*cblas_dot(N, X, incX, Y, incY); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(rv*trans(rv)) + { + //cout << "BLAS DOT: rv*trans(rv)" << endl; + const int N = static_cast(src.lhs.size()); + const T* X = get_ptr(src.lhs); + const int incX = get_inc(src.lhs); + const T* Y = get_ptr(src.rhs.op.m); + const int incY = get_inc(src.rhs.op.m); + + if (add_to == false) + dest(0) = alpha*cblas_dot(N, X, incX, Y, incY); + else + dest(0) += alpha*cblas_dot(N, X, incX, Y, incY); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(cv)*trans(rv)) + { + //cout << "BLAS DOT: trans(cv)*trans(rv)" << endl; + const int N = static_cast(src.lhs.op.m.size()); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + const T* Y = get_ptr(src.rhs.op.m); + const int incY = get_inc(src.rhs.op.m); + + if (add_to == false) + dest(0) = alpha*cblas_dot(N, X, incX, Y, incY); + else + dest(0) += alpha*cblas_dot(N, X, incX, Y, incY); + + } DLIB_END_BLAS_BINDING + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // DOTC overloads + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + DLIB_ADD_BLAS_BINDING(conj(rv)*cv) + { + //cout << "BLAS DOTC: conj(rv)*cv" << endl; + const int N = static_cast(src.lhs.op.m.size()); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + const T* Y = get_ptr(src.rhs); + const int incY = get_inc(src.rhs); + + if (add_to == false) + dest(0) = alpha*cblas_dotc(N, X, incX, Y, incY); + else + dest(0) += alpha*cblas_dotc(N, X, incX, Y, incY); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(conj(trans(cv))*cv) + { + //cout << "BLAS DOTC: conj(trans(cv))*cv" << endl; + const int N = static_cast(src.lhs.op.m.size()); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + const T* Y = get_ptr(src.rhs); + const int incY = get_inc(src.rhs); + + if (add_to == false) + dest(0) = alpha*cblas_dotc(N, X, incX, Y, incY); + else + dest(0) += alpha*cblas_dotc(N, X, incX, Y, incY); + + } DLIB_END_BLAS_BINDING + + // -------------------------------------- + + DLIB_ADD_BLAS_BINDING(trans(conj(cv))*trans(rv)) + { + //cout << "BLAS DOTC: trans(conj(cv))*trans(rv)" << endl; + const int N = static_cast(src.lhs.op.m.size()); + const T* X = get_ptr(src.lhs.op.m); + const int incX = get_inc(src.lhs.op.m); + const T* Y = get_ptr(src.rhs.op.m); + const int incY = get_inc(src.rhs.op.m); + + if (add_to == false) + dest(0) = alpha*cblas_dotc(N, X, incX, Y, incY); + else + dest(0) += alpha*cblas_dotc(N, X, incX, Y, incY); + + } DLIB_END_BLAS_BINDING + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_BLAS_BINDINGS_ + diff --git a/ml/dlib/dlib/matrix/matrix_cholesky.h b/ml/dlib/dlib/matrix/matrix_cholesky.h new file mode 100644 index 000000000..fc1140692 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_cholesky.h @@ -0,0 +1,231 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +// This code was adapted from code from the JAMA part of NIST's TNT library. +// See: http://math.nist.gov/tnt/ +#ifndef DLIB_MATRIX_CHOLESKY_DECOMPOSITION_H +#define DLIB_MATRIX_CHOLESKY_DECOMPOSITION_H + +#include "matrix.h" +#include "matrix_utilities.h" +#include "matrix_subexp.h" +#include + +#ifdef DLIB_USE_LAPACK +#include "lapack/potrf.h" +#endif + +#include "matrix_trsm.h" + +namespace dlib +{ + + template < + typename matrix_exp_type + > + class cholesky_decomposition + { + + public: + + const static long NR = matrix_exp_type::NR; + const static long NC = matrix_exp_type::NC; + typedef typename matrix_exp_type::type type; + typedef typename matrix_exp_type::mem_manager_type mem_manager_type; + typedef typename matrix_exp_type::layout_type layout_type; + + typedef matrix matrix_type; + typedef matrix column_vector_type; + + // You have supplied an invalid type of matrix_exp_type. You have + // to use this object with matrices that contain float or double type data. + COMPILE_TIME_ASSERT((is_same_type::value || + is_same_type::value )); + + + + template + cholesky_decomposition( + const matrix_exp& A + ); + + bool is_spd( + ) const; + + const matrix_type& get_l( + ) const; + + template + const typename EXP::matrix_type solve ( + const matrix_exp& B + ) const; + + private: + + matrix_type L_; // lower triangular factor + bool isspd; // true if matrix to be factored was SPD + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + bool cholesky_decomposition:: + is_spd( + ) const + { + return isspd; + } + +// ---------------------------------------------------------------------------------------- + + template + const typename cholesky_decomposition::matrix_type& cholesky_decomposition:: + get_l( + ) const + { + return L_; + } + +// ---------------------------------------------------------------------------------------- + + template + template + cholesky_decomposition:: + cholesky_decomposition( + const matrix_exp& A_ + ) + { + using std::sqrt; + COMPILE_TIME_ASSERT((is_same_type::value)); + + // make sure requires clause is not broken + DLIB_ASSERT(A_.nr() == A_.nc() && A_.size() > 0, + "\tcholesky_decomposition::cholesky_decomposition(A_)" + << "\n\tYou can only use this on square matrices" + << "\n\tA_.nr(): " << A_.nr() + << "\n\tA_.nc(): " << A_.nc() + << "\n\tA_.size(): " << A_.size() + << "\n\tthis: " << this + ); + +#ifdef DLIB_USE_LAPACK + L_ = A_; + const type eps = max(abs(diag(L_)))*std::sqrt(std::numeric_limits::epsilon())/100; + + // check if the matrix is actually symmetric + bool is_symmetric = true; + for (long r = 0; r < L_.nr() && is_symmetric; ++r) + { + for (long c = r+1; c < L_.nc() && is_symmetric; ++c) + { + // this is approximately doing: is_symmetric = is_symmetric && ( L_(k,j) == L_(j,k)) + is_symmetric = is_symmetric && (std::abs(L_(r,c) - L_(c,r)) < eps ); + } + } + + // now compute the actual cholesky decomposition + int info = lapack::potrf('L', L_); + + // check if it's really SPD + if (info == 0 && is_symmetric && min(abs(diag(L_))) > eps*100) + isspd = true; + else + isspd = false; + + L_ = lowerm(L_); +#else + const_temp_matrix A(A_); + + + isspd = true; + + const long n = A.nc(); + L_.set_size(n,n); + + const type eps = max(abs(diag(A)))*std::sqrt(std::numeric_limits::epsilon())/100; + + // Main loop. + for (long j = 0; j < n; j++) + { + type d(0.0); + for (long k = 0; k < j; k++) + { + type s(0.0); + for (long i = 0; i < k; i++) + { + s += L_(k,i)*L_(j,i); + } + + // if L_(k,k) != 0 + if (std::abs(L_(k,k)) > eps) + { + s = (A(j,k) - s)/L_(k,k); + } + else + { + s = (A(j,k) - s); + isspd = false; + } + + L_(j,k) = s; + + d = d + s*s; + + // this is approximately doing: isspd = isspd && ( A(k,j) == A(j,k)) + isspd = isspd && (std::abs(A(k,j) - A(j,k)) < eps ); + } + d = A(j,j) - d; + isspd = isspd && (d > eps); + L_(j,j) = sqrt(d > 0.0 ? d : 0.0); + for (long k = j+1; k < n; k++) + { + L_(j,k) = 0.0; + } + } +#endif + } + +// ---------------------------------------------------------------------------------------- + + template + template + const typename EXP::matrix_type cholesky_decomposition:: + solve( + const matrix_exp& B + ) const + { + COMPILE_TIME_ASSERT((is_same_type::value)); + + // make sure requires clause is not broken + DLIB_ASSERT(L_.nr() == B.nr(), + "\tconst matrix cholesky_decomposition::solve(B)" + << "\n\tInvalid arguments were given to this function." + << "\n\tL_.nr(): " << L_.nr() + << "\n\tB.nr(): " << B.nr() + << "\n\tthis: " << this + ); + + matrix X(B); + + using namespace blas_bindings; + // Solve L*y = b; + triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, L_, X); + // Solve L'*X = Y; + triangular_solver(CblasLeft, CblasLower, CblasTrans, CblasNonUnit, L_, X); + return X; + } + +// ---------------------------------------------------------------------------------------- + + + +} + +#endif // DLIB_MATRIX_CHOLESKY_DECOMPOSITION_H + + + + diff --git a/ml/dlib/dlib/matrix/matrix_conj_trans.h b/ml/dlib/dlib/matrix/matrix_conj_trans.h new file mode 100644 index 000000000..3c319ccaf --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_conj_trans.h @@ -0,0 +1,71 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_CONJ_TRANS_FUNCTIONS +#define DLIB_MATRIx_CONJ_TRANS_FUNCTIONS + +#include "matrix_utilities.h" +#include "matrix_math_functions.h" +#include "matrix.h" +#include "../algs.h" +#include +#include +#include + + +namespace dlib +{ + /*! + The point of the two functions defined in this file is to make statements + of the form conj(trans(m)) and trans(conj(m)) look the same so that it is + easier to map them to BLAS functions later on. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + struct op_conj_trans + { + op_conj_trans( const M& m_) : m(m_){} + const M& m; + + const static long cost = M::cost+1; + const static long NR = M::NC; + const static long NC = M::NR; + typedef typename M::type type; + typedef typename M::type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply (long r, long c) const { return std::conj(m(c,r)); } + + long nr () const { return m.nc(); } + long nc () const { return m.nr(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template + const matrix_op > trans ( + const matrix_op >& m + ) + { + typedef op_conj_trans op; + return matrix_op(op(m.op.m)); + } + + template + const matrix_op > conj ( + const matrix_op >& m + ) + { + typedef op_conj_trans op; + return matrix_op(op(m.op.m)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_CONJ_TRANS_FUNCTIONS + + diff --git a/ml/dlib/dlib/matrix/matrix_conv.h b/ml/dlib/dlib/matrix/matrix_conv.h new file mode 100644 index 000000000..b90c388bc --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_conv.h @@ -0,0 +1,358 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_CONV_Hh_ +#define DLIB_MATRIx_CONV_Hh_ + +#include "matrix_conv_abstract.h" +#include "matrix.h" +#include "matrix_fft.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + const T& conj(const T& item) { return item; } + template + std::complex conj(const std::complex& item) { return std::conj(item); } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + struct op_conv + { + op_conv( const M1& m1_, const M2& m2_) : + m1(m1_), + m2(m2_), + nr_(m1.nr()+m2.nr()-1), + nc_(m1.nc()+m2.nc()-1) + { + if (nr_ < 0 || m1.size() == 0 || m2.size() == 0) + nr_ = 0; + if (nc_ < 0 || m1.size() == 0 || m2.size() == 0) + nc_ = 0; + } + + const M1& m1; + const M2& m2; + long nr_; + long nc_; + + const static long cost = (M1::cost+M2::cost)*10; + const static long NR = (M1::NR*M2::NR==0) ? (0) : (M1::NR+M2::NR-1); + const static long NC = (M1::NC*M2::NC==0) ? (0) : (M1::NC+M2::NC-1); + typedef typename M1::type type; + typedef type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + + const_ret_type apply (long r, long c) const + { + type temp = 0; + + const long min_rr = std::max(r-m2.nr()+1, 0); + const long max_rr = std::min(m1.nr()-1, r); + + const long min_cc = std::max(c-m2.nc()+1, 0); + const long max_cc = std::min(m1.nc()-1, c); + + for (long rr = min_rr; rr <= max_rr; ++rr) + { + for (long cc = min_cc; cc <= max_cc; ++cc) + { + if (flip_m2) + temp += m1(rr,cc)*dlib::impl::conj(m2(m2.nr()-r+rr-1, m2.nc()-c+cc-1)); + else + temp += m1(rr,cc)*m2(r-rr,c-cc); + } + } + + return temp; + } + + long nr () const { return nr_; } + long nc () const { return nc_; } + + template bool aliases ( const matrix_exp& item) const { return m1.aliases(item) || m2.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m1.aliases(item) || m2.aliases(item); } + + }; + + template < + typename M1, + typename M2 + > + const matrix_op > conv ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + typedef op_conv op; + return matrix_op(op(m1.ref(),m2.ref())); + } + + template < + typename M1, + typename M2 + > + const matrix_op > xcorr ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + typedef op_conv op; + return matrix_op(op(m1.ref(),m2.ref())); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline size_t bounding_power_of_two ( + size_t n + ) + { + size_t s = 1; + for (unsigned int i = 0; i < sizeof(s)*8 && s < n; ++i) + s <<= 1; + return s; + } + } + + template < + typename EXP1, + typename EXP2 + > + typename EXP1::matrix_type xcorr_fft( + const matrix_exp& u, + const matrix_exp& v + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + using T = typename EXP1::type; + COMPILE_TIME_ASSERT((is_same_type::value || is_same_type::value || is_same_type::value )); + + const long pad_nr = impl::bounding_power_of_two(u.nr() + v.nr() - 1); + const long pad_nc = impl::bounding_power_of_two(u.nc() + v.nc() - 1); + + matrix> U(pad_nr, pad_nc), V(pad_nr,pad_nc); + + U = 0; + V = 0; + set_subm(U,U.nr()-u.nr(),U.nc()-u.nc(),u.nr(),u.nc()) = u; + set_subm(V,get_rect(v)) = v; + + fft_inplace(U); + fft_inplace(V); + + return subm(real(ifft(pointwise_multiply(U, conj(V)))), + U.nr()-u.nr()-v.nr()+1, + U.nc()-u.nc()-v.nc()+1, + u.nr()+v.nr()-1, + u.nc()+v.nc()-1 + ); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + struct op_conv_same + { + op_conv_same( const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_),nr_(m1.nr()),nc_(m1.nc()) + { + if (m1.size() == 0 || m2.size() == 0) + nr_ = 0; + if (m1.size() == 0 || m2.size() == 0) + nc_ = 0; + } + + const M1& m1; + const M2& m2; + long nr_; + long nc_; + + const static long cost = (M1::cost+M2::cost)*10; + const static long NR = M1::NR; + const static long NC = M1::NC; + typedef typename M1::type type; + typedef type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + + const_ret_type apply (long r, long c) const + { + r += m2.nr()/2; + c += m2.nc()/2; + + type temp = 0; + + const long min_rr = std::max(r-m2.nr()+1, 0); + const long max_rr = std::min(m1.nr()-1, r); + + const long min_cc = std::max(c-m2.nc()+1, 0); + const long max_cc = std::min(m1.nc()-1, c); + + for (long rr = min_rr; rr <= max_rr; ++rr) + { + for (long cc = min_cc; cc <= max_cc; ++cc) + { + if (flip_m2) + temp += m1(rr,cc)*dlib::impl::conj(m2(m2.nr()-r+rr-1, m2.nc()-c+cc-1)); + else + temp += m1(rr,cc)*m2(r-rr,c-cc); + } + } + + return temp; + } + + long nr () const { return nr_; } + long nc () const { return nc_; } + + template bool aliases ( const matrix_exp& item) const { return m1.aliases(item) || m2.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m1.aliases(item) || m2.aliases(item); } + + }; + + template < + typename M1, + typename M2 + > + const matrix_op > conv_same ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + typedef op_conv_same op; + return matrix_op(op(m1.ref(),m2.ref())); + } + + template < + typename M1, + typename M2 + > + const matrix_op > xcorr_same ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + typedef op_conv_same op; + return matrix_op(op(m1.ref(),m2.ref())); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + struct op_conv_valid + { + op_conv_valid( const M1& m1_, const M2& m2_) : + m1(m1_),m2(m2_), + nr_(m1.nr()-m2.nr()+1), + nc_(m1.nc()-m2.nc()+1) + { + if (nr_ < 0 || nc_ <= 0 || m1.size() == 0 || m2.size() == 0) + nr_ = 0; + if (nc_ < 0 || nr_ <= 0 || m1.size() == 0 || m2.size() == 0) + nc_ = 0; + } + + const M1& m1; + const M2& m2; + long nr_; + long nc_; + + const static long cost = (M1::cost+M2::cost)*10; + const static long NR = (M1::NR*M2::NR==0) ? (0) : (M1::NR-M2::NR+1); + const static long NC = (M1::NC*M2::NC==0) ? (0) : (M1::NC-M2::NC+1); + typedef typename M1::type type; + typedef type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + + const_ret_type apply (long r, long c) const + { + r += m2.nr()-1; + c += m2.nc()-1; + + type temp = 0; + + const long min_rr = std::max(r-m2.nr()+1, 0); + const long max_rr = std::min(m1.nr()-1, r); + + const long min_cc = std::max(c-m2.nc()+1, 0); + const long max_cc = std::min(m1.nc()-1, c); + + for (long rr = min_rr; rr <= max_rr; ++rr) + { + for (long cc = min_cc; cc <= max_cc; ++cc) + { + if (flip_m2) + temp += m1(rr,cc)*dlib::impl::conj(m2(m2.nr()-r+rr-1, m2.nc()-c+cc-1)); + else + temp += m1(rr,cc)*m2(r-rr,c-cc); + } + } + + return temp; + } + + long nr () const { return nr_; } + long nc () const { return nc_; } + + template bool aliases ( const matrix_exp& item) const { return m1.aliases(item) || m2.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m1.aliases(item) || m2.aliases(item); } + + }; + + template < + typename M1, + typename M2 + > + const matrix_op > conv_valid ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + typedef op_conv_valid op; + return matrix_op(op(m1.ref(),m2.ref())); + } + + template < + typename M1, + typename M2 + > + const matrix_op > xcorr_valid ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + typedef op_conv_valid op; + return matrix_op(op(m1.ref(),m2.ref())); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_CONV_Hh_ + diff --git a/ml/dlib/dlib/matrix/matrix_conv_abstract.h b/ml/dlib/dlib/matrix/matrix_conv_abstract.h new file mode 100644 index 000000000..b342f2668 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_conv_abstract.h @@ -0,0 +1,158 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MATRIx_CONV_ABSTRACT_Hh_ +#ifdef DLIB_MATRIx_CONV_ABSTRACT_Hh_ + +#include "matrix_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp conv ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1 and m2 both contain elements of the same type + ensures + - returns a matrix R such that: + - R is the convolution of m1 with m2. In particular, this function is + equivalent to performing the following in matlab: R = conv2(m1,m2). + - R::type == the same type that was in m1 and m2. + - R.nr() == m1.nr()+m2.nr()-1 + - R.nc() == m1.nc()+m2.nc()-1 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp xcorr ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1 and m2 both contain elements of the same type + ensures + - returns a matrix R such that: + - R is the cross-correlation of m1 with m2. In particular, this + function returns conv(m1,flip(m2)) if the matrices contain real + elements and conv(m1,flip(conj(m2))) if they are complex. + - R::type == the same type that was in m1 and m2. + - R.nr() == m1.nr()+m2.nr()-1 + - R.nc() == m1.nc()+m2.nc()-1 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp xcorr_fft ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1 and m2 both contain elements of the same type + - m1 and m2 contain real or complex values and must be double, float, or long + double valued. (e.g. not integers) + ensures + - This function is identical to xcorr() except that it uses a fast Fourier + transform to do the convolution and is therefore much faster when both m1 and + m2 are large. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp conv_same ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1 and m2 both contain elements of the same type + ensures + - returns a matrix R such that: + - R is the convolution of m1 with m2. In particular, this function is + equivalent to performing the following in matlab: R = conv2(m1,m2,'same'). + In particular, this means the result will have the same dimensions as m1 and will + contain the central part of the full convolution. Therefore, conv_same(m1,m2) is + equivalent to subm(conv(m1,m2), m2.nr()/2, m2.nc()/2, m1.nr(), m1.nc()). + - R::type == the same type that was in m1 and m2. + - R.nr() == m1.nr() + - R.nc() == m1.nc() + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp xcorr_same ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1 and m2 both contain elements of the same type + ensures + - returns a matrix R such that: + - R is the cross-correlation of m1 with m2. In particular, this + function returns conv_same(m1,flip(m2)) if the matrices contain real + elements and conv_same(m1,flip(conj(m2))) if they are complex. + - R::type == the same type that was in m1 and m2. + - R.nr() == m1.nr() + - R.nc() == m1.nc() + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp conv_valid ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1 and m2 both contain elements of the same type + ensures + - returns a matrix R such that: + - R is the convolution of m1 with m2. In particular, this function is + equivalent to performing the following in matlab: R = conv2(m1,m2,'valid'). + In particular, this means only elements of the convolution which don't require + zero padding are included in the result. + - R::type == the same type that was in m1 and m2. + - if (m1 has larger dimensions than m2) then + - R.nr() == m1.nr()-m2.nr()+1 + - R.nc() == m1.nc()-m2.nc()+1 + - else + - R.nr() == 0 + - R.nc() == 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp xcorr_valid ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - m1 and m2 both contain elements of the same type + ensures + - returns a matrix R such that: + - R is the cross-correlation of m1 with m2. In particular, this + function returns conv_valid(m1,flip(m2)) if the matrices contain real + elements and conv_valid(m1,flip(conj(m2))) if they are complex. + - R::type == the same type that was in m1 and m2. + - if (m1 has larger dimensions than m2) then + - R.nr() == m1.nr()-m2.nr()+1 + - R.nc() == m1.nc()-m2.nc()+1 + - else + - R.nr() == 0 + - R.nc() == 0 + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_CONV_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/matrix/matrix_data_layout.h b/ml/dlib/dlib/matrix/matrix_data_layout.h new file mode 100644 index 000000000..22891c228 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_data_layout.h @@ -0,0 +1,1271 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_DATA_LAYOUT_ +#define DLIB_MATRIx_DATA_LAYOUT_ + +#include "../algs.h" +#include "matrix_fwd.h" +#include "matrix_data_layout_abstract.h" +#ifdef MATLAB_MEX_FILE +#include +#endif + +// GCC 4.8 gives false alarms about some matrix operations going out of bounds. Disable +// these false warnings. +#if defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4)) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" +#endif + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*! + A matrix layout object is any object that contains a templated class called "layout" + with an interface identical to one below: + (Note that all the template arguments are just the template arguments from the dlib::matrix + object and the member functions are defined identically to the ones with the same + signatures inside the matrix object.) + + struct matrix_layout + { + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout + { + public: + + T& operator() ( + long r, + long c + ); + + const T& operator() ( + long r, + long c + ); + + T& operator() ( + long i + ); + + const T& operator() ( + long i + ) const; + + void swap( + layout& item + ); + + long nr ( + ) const; + + long nc ( + ) const; + + void set_size ( + long nr_, + long nc_ + ); + }; + }; + !*/ + +// ---------------------------------------------------------------------------------------- + + struct row_major_layout + { + // if a matrix is bigger than this many bytes then don't put it on the stack + const static size_t max_stack_based_size = 256; + + // this is a hack to avoid a compile time error in visual studio 8. I would just + // use sizeof(T) and be done with it but that won't compile. The idea here + // is to avoid using the stack allocation of the layout object if it + // is going to contain another matrix and also avoid asking for the sizeof() + // the contained matrix. + template + struct get_sizeof_helper + { + const static std::size_t val = sizeof(T); + }; + + template + struct get_sizeof_helper > + { + const static std::size_t val = 1000000; + }; + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager, + int val = static_switch < + // when the sizes are all non zero and small + (num_rows*num_cols*get_sizeof_helper::val <= max_stack_based_size) && (num_rows != 0 && num_cols != 0), + // when the sizes are all non zero and big + (num_rows*num_cols*get_sizeof_helper::val > max_stack_based_size) && (num_rows != 0 && num_cols != 0), + num_rows == 0 && num_cols != 0, + num_rows != 0 && num_cols == 0, + num_rows == 0 && num_cols == 0 + >::value + > + class layout ; + /*! + WHAT THIS OBJECT REPRESENTS + This object represents the actual allocation of space for a matrix. + Small matrices allocate all their data on the stack and bigger ones + use a memory_manager to get their memory. + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when the sizes are all non zero and small + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout() {} + + T& operator() ( + long r, + long c + ) { return *(data+r*num_cols + c); } + + const T& operator() ( + long r, + long c + ) const { return *(data+r*num_cols + c); } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + for (long r = 0; r < num_rows; ++r) + { + for (long c = 0; c < num_cols; ++c) + { + exchange((*this)(r,c),item(r,c)); + } + } + } + + long nr ( + ) const { return num_rows; } + + long nc ( + ) const { return num_cols; } + + void set_size ( + long , + long + ) + { + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + + private: + T data[num_rows*num_cols]; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when the sizes are all non zero and big + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ) { data = pool.allocate_array(num_rows*num_cols); } + + ~layout () + { pool.deallocate_array(data); } + + T& operator() ( + long r, + long c + ) { return data[r*num_cols + c]; } + + const T& operator() ( + long r, + long c + ) const { return data[r*num_cols + c]; } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + std::swap(item.data,data); + pool.swap(item.pool); + } + + long nr ( + ) const { return num_rows; } + + long nc ( + ) const { return num_cols; } + + void set_size ( + long , + long + ) + { + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + + private: + + T* data; + typename mem_manager::template rebind::other pool; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when num_rows == 0 && num_cols != 0, + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ):data(0), nr_(0) { } + + ~layout () + { + if (data) + pool.deallocate_array(data); + } + + T& operator() ( + long r, + long c + ) { return data[r*num_cols + c]; } + + const T& operator() ( + long r, + long c + ) const { return data[r*num_cols + c]; } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + std::swap(item.data,data); + std::swap(item.nr_,nr_); + pool.swap(item.pool); + } + + long nr ( + ) const { return nr_; } + + long nc ( + ) const { return num_cols; } + + void set_size ( + long nr, + long nc + ) + { + if (data) + { + pool.deallocate_array(data); + } + data = pool.allocate_array(nr*nc); + nr_ = nr; + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + + private: + + T* data; + long nr_; + typename mem_manager::template rebind::other pool; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when num_rows != 0 && num_cols == 0 + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ):data(0), nc_(0) { } + + ~layout () + { + if (data) + { + pool.deallocate_array(data); + } + } + + T& operator() ( + long r, + long c + ) { return data[r*nc_ + c]; } + + const T& operator() ( + long r, + long c + ) const { return data[r*nc_ + c]; } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + std::swap(item.data,data); + std::swap(item.nc_,nc_); + pool.swap(item.pool); + } + + long nr ( + ) const { return num_rows; } + + long nc ( + ) const { return nc_; } + + void set_size ( + long nr, + long nc + ) + { + if (data) + { + pool.deallocate_array(data); + } + data = pool.allocate_array(nr*nc); + nc_ = nc; + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + + private: + + T* data; + long nc_; + typename mem_manager::template rebind::other pool; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when num_rows == 0 && num_cols == 0 + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ):data(0), nr_(0), nc_(0) { } + + ~layout () + { + if (data) + { + pool.deallocate_array(data); + } + } + + T& operator() ( + long r, + long c + ) { return data[r*nc_ + c]; } + + const T& operator() ( + long r, + long c + ) const { return data[r*nc_ + c]; } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + std::swap(item.data,data); + std::swap(item.nc_,nc_); + std::swap(item.nr_,nr_); + pool.swap(item.pool); + } + + long nr ( + ) const { return nr_; } + + long nc ( + ) const { return nc_; } + + void set_size ( + long nr, + long nc + ) + { + if (data) + { + pool.deallocate_array(data); + } + data = pool.allocate_array(nr*nc); + nr_ = nr; + nc_ = nc; + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + private: + T* data; + long nr_; + long nc_; + typename mem_manager::template rebind::other pool; + }; + + }; + +// ---------------------------------------------------------------------------------------- + + struct column_major_layout + { + // if a matrix is bigger than this many bytes then don't put it on the stack + const static size_t max_stack_based_size = 256; + + + // this is a hack to avoid a compile time error in visual studio 8. I would just + // use sizeof(T) and be done with it but that won't compile. The idea here + // is to avoid using the stack allocation of the layout object if it + // is going to contain another matrix and also avoid asking for the sizeof() + // the contained matrix. + template + struct get_sizeof_helper + { + const static std::size_t val = sizeof(T); + }; + + template + struct get_sizeof_helper > + { + const static std::size_t val = 1000000; + }; + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager, + int val = static_switch < + // when the sizes are all non zero and small + (num_rows*num_cols*get_sizeof_helper::val <= max_stack_based_size) && (num_rows != 0 && num_cols != 0), + // when the sizes are all non zero and big + (num_rows*num_cols*get_sizeof_helper::val > max_stack_based_size) && (num_rows != 0 && num_cols != 0), + num_rows == 0 && num_cols != 0, + num_rows != 0 && num_cols == 0, + num_rows == 0 && num_cols == 0 + >::value + > + class layout ; + /*! + WHAT THIS OBJECT REPRESENTS + This object represents the actual allocation of space for a matrix. + Small matrices allocate all their data on the stack and bigger ones + use a memory_manager to get their memory. + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when the sizes are all non zero and small + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout() {} + + T& operator() ( + long r, + long c + ) { return *(data+c*num_rows + r); } + + const T& operator() ( + long r, + long c + ) const { return *(data+c*num_rows + r); } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + for (long r = 0; r < num_rows; ++r) + { + for (long c = 0; c < num_cols; ++c) + { + exchange((*this)(r,c),item(r,c)); + } + } + } + + long nr ( + ) const { return num_rows; } + + long nc ( + ) const { return num_cols; } + + void set_size ( + long, + long + ) + { + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + + private: + T data[num_cols*num_rows]; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when the sizes are all non zero and big + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ) { data = pool.allocate_array(num_rows*num_cols); } + + ~layout () + { pool.deallocate_array(data); } + + T& operator() ( + long r, + long c + ) { return data[c*num_rows + r]; } + + const T& operator() ( + long r, + long c + ) const { return data[c*num_rows + r]; } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + std::swap(item.data,data); + pool.swap(item.pool); + } + + long nr ( + ) const { return num_rows; } + + long nc ( + ) const { return num_cols; } + + void set_size ( + long , + long + ) + { + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + + private: + + T* data; + typename mem_manager::template rebind::other pool; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when num_rows == 0 && num_cols != 0, + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ):data(0), nr_(0) { } + + ~layout () + { + if (data) + pool.deallocate_array(data); + } + + T& operator() ( + long r, + long c + ) { return data[c*nr_ + r]; } + + const T& operator() ( + long r, + long c + ) const { return data[c*nr_ + r]; } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + std::swap(item.data,data); + std::swap(item.nr_,nr_); + pool.swap(item.pool); + } + + long nr ( + ) const { return nr_; } + + long nc ( + ) const { return num_cols; } + + void set_size ( + long nr, + long nc + ) + { + if (data) + { + pool.deallocate_array(data); + } + data = pool.allocate_array(nr*nc); + nr_ = nr; + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + + private: + + T* data; + long nr_; + typename mem_manager::template rebind::other pool; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when num_rows != 0 && num_cols == 0 + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ):data(0), nc_(0) { } + + ~layout () + { + if (data) + { + pool.deallocate_array(data); + } + } + + T& operator() ( + long r, + long c + ) { return data[c*num_rows + r]; } + + const T& operator() ( + long r, + long c + ) const { return data[c*num_rows + r]; } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + std::swap(item.data,data); + std::swap(item.nc_,nc_); + pool.swap(item.pool); + } + + long nr ( + ) const { return num_rows; } + + long nc ( + ) const { return nc_; } + + void set_size ( + long nr, + long nc + ) + { + if (data) + { + pool.deallocate_array(data); + } + data = pool.allocate_array(nr*nc); + nc_ = nc; + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + + private: + + T* data; + long nc_; + typename mem_manager::template rebind::other pool; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long num_rows, + long num_cols, + typename mem_manager + > + class layout : noncopyable // when num_rows == 0 && num_cols == 0 + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ):data(0), nr_(0), nc_(0) { } + + ~layout () + { + if (data) + { + pool.deallocate_array(data); + } + } + + T& operator() ( + long r, + long c + ) { return data[c*nr_ + r]; } + + const T& operator() ( + long r, + long c + ) const { return data[c*nr_ + r]; } + + T& operator() ( + long i + ) { return data[i]; } + + const T& operator() ( + long i + ) const { return data[i]; } + + void swap( + layout& item + ) + { + std::swap(item.data,data); + std::swap(item.nc_,nc_); + std::swap(item.nr_,nr_); + pool.swap(item.pool); + } + +#ifdef MATLAB_MEX_FILE + void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); } + mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); } + void _private_mark_owned_by_matlab() {DLIB_CASSERT(false, "This function should never be called."); } + bool _private_is_owned_by_matlab() const { return false; } +#endif + + long nr ( + ) const { return nr_; } + + long nc ( + ) const { return nc_; } + + void set_size ( + long nr, + long nc + ) + { + if (data) + { + pool.deallocate_array(data); + } + data = pool.allocate_array(nr*nc); + nr_ = nr; + nc_ = nc; + } + + private: + T* data; + long nr_; + long nc_; + typename mem_manager::template rebind::other pool; + }; + +#ifdef MATLAB_MEX_FILE + template < + long num_rows, + long num_cols + > + class layout : noncopyable // when num_rows == 0 && num_cols == 0 + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ): data(0), nr_(0), nc_(0), owned_by_matlab(false),set_by_private_set_mxArray(false),mem(0) { } + + ~layout () + { + if (owned_by_matlab) + { + if (!set_by_private_set_mxArray && mem) + { + mxDestroyArray(mem); + mem = 0; + data = 0; + } + } + else if (data) + { + delete [] data; + data = 0; + } + } + + double& operator() ( + long r, + long c + ) { return data[c*nr_ + r]; } + + const double& operator() ( + long r, + long c + ) const { return data[c*nr_ + r]; } + + double& operator() ( + long i + ) { return data[i]; } + + const double& operator() ( + long i + ) const { return data[i]; } + + void _private_set_mxArray ( + mxArray* mem_ + ) + { + DLIB_CASSERT(mem == 0 && data == 0,"You can't call this function on an already allocated matrix."); + // We don't own the pointer, so make note of that so we won't try to free + // it. + set_by_private_set_mxArray = true; + owned_by_matlab = true; + mem = mem_; + data = mxGetPr(mem); + nr_ = mxGetM(mem); + nc_ = mxGetN(mem); + } + + mxArray* _private_release_mxArray() + { + DLIB_CASSERT(owned_by_matlab,""); + mxArray* temp = mem; + mem = 0; + set_by_private_set_mxArray = false; + data = 0; + nr_ = 0; + nc_ = 0; + return temp; + } + + void _private_mark_owned_by_matlab() + { + DLIB_CASSERT(mem == 0 && data == 0,"You can't say a matrix should be owned by matlab after it's been allocated."); + owned_by_matlab = true; + } + bool _private_is_owned_by_matlab() const + { + return owned_by_matlab; + } + + void swap( + layout& item + ) + { + std::swap(item.owned_by_matlab,owned_by_matlab); + std::swap(item.set_by_private_set_mxArray,set_by_private_set_mxArray); + std::swap(item.mem,mem); + std::swap(item.data,data); + std::swap(item.nc_,nc_); + std::swap(item.nr_,nr_); + } + + long nr ( + ) const { return nr_; } + + long nc ( + ) const { return nc_; } + + void set_size ( + long nr, + long nc + ) + { + if (owned_by_matlab) + { + if (!set_by_private_set_mxArray && mem) + { + mxDestroyArray(mem); + mem = 0; + data = 0; + } + set_by_private_set_mxArray = false; + + mem = mxCreateDoubleMatrix(nr, nc, mxREAL); + if (mem == 0) + throw std::bad_alloc(); + data = mxGetPr(mem); + } + else + { + if (data) + delete [] data; + data = new double[nr*nc]; + } + nr_ = nr; + nc_ = nc; + } + + private: + double* data; + long nr_; + long nc_; + bool owned_by_matlab; + bool set_by_private_set_mxArray; + mxArray* mem; + }; + + template < + long num_rows, + long num_cols + > + class layout : noncopyable // when num_rows == 0 && num_cols == 0 + { + public: + const static long NR = num_rows; + const static long NC = num_cols; + + layout ( + ): data(0), nr_(0), nc_(0), owned_by_matlab(false),set_by_private_set_mxArray(false),mem(0) { } + + ~layout () + { + if (owned_by_matlab) + { + if (!set_by_private_set_mxArray && mem) + { + mxDestroyArray(mem); + mem = 0; + data = 0; + } + } + else if (data) + { + delete [] data; + data = 0; + } + } + + float& operator() ( + long r, + long c + ) { return data[c*nr_ + r]; } + + const float& operator() ( + long r, + long c + ) const { return data[c*nr_ + r]; } + + float& operator() ( + long i + ) { return data[i]; } + + const float& operator() ( + long i + ) const { return data[i]; } + + void _private_set_mxArray ( + mxArray* mem_ + ) + { + DLIB_CASSERT(mem == 0 && data == 0,"You can't call this function on an already allocated matrix."); + // We don't own the pointer, so make note of that so we won't try to free + // it. + set_by_private_set_mxArray = true; + owned_by_matlab = true; + mem = mem_; + data = (float*)mxGetData(mem); + nr_ = mxGetM(mem); + nc_ = mxGetN(mem); + } + + mxArray* _private_release_mxArray() + { + DLIB_CASSERT(owned_by_matlab,""); + mxArray* temp = mem; + mem = 0; + set_by_private_set_mxArray = false; + data = 0; + nr_ = 0; + nc_ = 0; + return temp; + } + + void _private_mark_owned_by_matlab() + { + DLIB_CASSERT(mem == 0 && data == 0,"You can't say a matrix should be owned by matlab after it's been allocated."); + owned_by_matlab = true; + } + bool _private_is_owned_by_matlab() const + { + return owned_by_matlab; + } + + void swap( + layout& item + ) + { + std::swap(item.owned_by_matlab,owned_by_matlab); + std::swap(item.set_by_private_set_mxArray,set_by_private_set_mxArray); + std::swap(item.mem,mem); + std::swap(item.data,data); + std::swap(item.nc_,nc_); + std::swap(item.nr_,nr_); + } + + long nr ( + ) const { return nr_; } + + long nc ( + ) const { return nc_; } + + void set_size ( + long nr, + long nc + ) + { + if (owned_by_matlab) + { + if (!set_by_private_set_mxArray && mem) + { + mxDestroyArray(mem); + mem = 0; + data = 0; + } + set_by_private_set_mxArray = false; + + mem = mxCreateNumericMatrix(nr, nc, mxSINGLE_CLASS, mxREAL); + if (mem == 0) + throw std::bad_alloc(); + data = (float*)mxGetData(mem); + } + else + { + if (data) + delete [] data; + data = new float[nr*nc]; + } + nr_ = nr; + nc_ = nc; + } + + private: + float* data; + long nr_; + long nc_; + bool owned_by_matlab; + bool set_by_private_set_mxArray; + mxArray* mem; + }; +#endif + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#if defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4)) +#pragma GCC diagnostic pop +#endif + +#endif // DLIB_MATRIx_DATA_LAYOUT_ + diff --git a/ml/dlib/dlib/matrix/matrix_data_layout_abstract.h b/ml/dlib/dlib/matrix/matrix_data_layout_abstract.h new file mode 100644 index 000000000..c3fa02be2 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_data_layout_abstract.h @@ -0,0 +1,40 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MATRIx_DATA_LAYOUT_ABSTRACT_ +#ifdef DLIB_MATRIx_DATA_LAYOUT_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct row_major_layout + { + /*! + This is the default matrix layout. Any matrix object that uses this + layout will be laid out in memory in row major order. Additionally, + all elements are contiguous (e.g. there isn't any padding at the ends of + rows or anything like that) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + struct column_major_layout + { + /*! + Any matrix object that uses this layout will be laid out in memory in + column major order. Additionally, all elements are contiguous (e.g. + there isn't any padding at the ends of rows or anything like that) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_DATA_LAYOUT_ABSTRACT_ + + diff --git a/ml/dlib/dlib/matrix/matrix_default_mul.h b/ml/dlib/dlib/matrix/matrix_default_mul.h new file mode 100644 index 000000000..493c641a8 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_default_mul.h @@ -0,0 +1,134 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_DEFAULT_MULTIPLY_ +#define DLIB_MATRIx_DEFAULT_MULTIPLY_ + +#include "../geometry/rectangle.h" +#include "matrix.h" +#include "matrix_utilities.h" +#include "../enable_if.h" + +namespace dlib +{ + +// ------------------------------------------------------------------------------------ + + namespace ma + { + template < typename EXP, typename enable = void > + struct matrix_is_vector { static const bool value = false; }; + template < typename EXP > + struct matrix_is_vector::type > { static const bool value = true; }; + } + +// ------------------------------------------------------------------------------------ + + /*! This file defines the default_matrix_multiply() function. It is a function + that conforms to the following definition: + + template < + typename matrix_dest_type, + typename EXP1, + typename EXP2 + > + void default_matrix_multiply ( + matrix_dest_type& dest, + const EXP1& lhs, + const EXP2& rhs + ); + requires + - (lhs*rhs).destructively_aliases(dest) == false + - dest.nr() == (lhs*rhs).nr() + - dest.nc() == (lhs*rhs).nc() + ensures + - #dest == dest + lhs*rhs + !*/ + +// ------------------------------------------------------------------------------------ + + template < + typename matrix_dest_type, + typename EXP1, + typename EXP2 + > + typename enable_if_c::value == true || ma::matrix_is_vector::value == true>::type + default_matrix_multiply ( + matrix_dest_type& dest, + const EXP1& lhs, + const EXP2& rhs + ) + { + matrix_assign_default(dest, lhs*rhs, 1, true); + } + +// ------------------------------------------------------------------------------------ + + template < + typename matrix_dest_type, + typename EXP1, + typename EXP2 + > + typename enable_if_c::value == false && ma::matrix_is_vector::value == false>::type + default_matrix_multiply ( + matrix_dest_type& dest, + const EXP1& lhs, + const EXP2& rhs + ) + { + const long bs = 90; + + // if the matrices are small enough then just use the simple multiply algorithm + if (lhs.nc() <= 2 || rhs.nc() <= 2 || lhs.nr() <= 2 || rhs.nr() <= 2 || (lhs.size() <= bs*10 && rhs.size() <= bs*10) ) + { + matrix_assign_default(dest, lhs*rhs, 1, true); + } + else + { + // if the lhs and rhs matrices are big enough we should use a cache friendly + // algorithm that computes the matrix multiply in blocks. + + + // Loop over all the blocks in the lhs matrix + for (long r = 0; r < lhs.nr(); r+=bs) + { + for (long c = 0; c < lhs.nc(); c+=bs) + { + // make a rect for the block from lhs + rectangle lhs_block(c, r, std::min(c+bs-1,lhs.nc()-1), std::min(r+bs-1,lhs.nr()-1)); + + // now loop over all the rhs blocks we have to multiply with the current lhs block + for (long i = 0; i < rhs.nc(); i += bs) + { + // make a rect for the block from rhs + rectangle rhs_block(i, c, std::min(i+bs-1,rhs.nc()-1), std::min(c+bs-1,rhs.nr()-1)); + + // make a target rect in res + rectangle res_block(rhs_block.left(),lhs_block.top(), rhs_block.right(), lhs_block.bottom()); + + // This loop is optimized assuming that the data is laid out in + // row major order in memory. + for (long r = lhs_block.top(); r <= lhs_block.bottom(); ++r) + { + for (long c = lhs_block.left(); c<= lhs_block.right(); ++c) + { + const typename EXP2::type temp = lhs(r,c); + for (long i = rhs_block.left(); i <= rhs_block.right(); ++i) + { + dest(r,i) += rhs(c,i)*temp; + } + } + } + } + } + } + } + + + } + +// ------------------------------------------------------------------------------------ + +} + +#endif // DLIB_MATRIx_DEFAULT_MULTIPLY_ + diff --git a/ml/dlib/dlib/matrix/matrix_eigenvalue.h b/ml/dlib/dlib/matrix/matrix_eigenvalue.h new file mode 100644 index 000000000..3dc47e105 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_eigenvalue.h @@ -0,0 +1,1379 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +// This code was adapted from code from the JAMA part of NIST's TNT library. +// See: http://math.nist.gov/tnt/ +#ifndef DLIB_MATRIX_EIGENVALUE_DECOMPOSITION_H +#define DLIB_MATRIX_EIGENVALUE_DECOMPOSITION_H + +#include "matrix.h" +#include "matrix_utilities.h" +#include "matrix_subexp.h" +#include +#include +#include + +#ifdef DLIB_USE_LAPACK +#include "lapack/geev.h" +#include "lapack/syev.h" +#include "lapack/syevr.h" +#endif + +#define DLIB_LAPACK_EIGENVALUE_DECOMP_SIZE_THRESH 4 + +namespace dlib +{ + + template < + typename matrix_exp_type + > + class eigenvalue_decomposition + { + + public: + + const static long NR = matrix_exp_type::NR; + const static long NC = matrix_exp_type::NC; + typedef typename matrix_exp_type::type type; + typedef typename matrix_exp_type::mem_manager_type mem_manager_type; + typedef typename matrix_exp_type::layout_type layout_type; + + typedef typename matrix_exp_type::matrix_type matrix_type; + typedef matrix column_vector_type; + + typedef matrix,0,0,mem_manager_type,layout_type> complex_matrix_type; + typedef matrix,NR,1,mem_manager_type,layout_type> complex_column_vector_type; + + + // You have supplied an invalid type of matrix_exp_type. You have + // to use this object with matrices that contain float or double type data. + COMPILE_TIME_ASSERT((is_same_type::value || + is_same_type::value )); + + + template + eigenvalue_decomposition( + const matrix_exp& A + ); + + template + eigenvalue_decomposition( + const matrix_op >& A + ); + + long dim ( + ) const; + + const complex_column_vector_type get_eigenvalues ( + ) const; + + const column_vector_type& get_real_eigenvalues ( + ) const; + + const column_vector_type& get_imag_eigenvalues ( + ) const; + + const complex_matrix_type get_v ( + ) const; + + const complex_matrix_type get_d ( + ) const; + + const matrix_type& get_pseudo_v ( + ) const; + + const matrix_type get_pseudo_d ( + ) const; + + private: + + /** Row and column dimension (square matrix). */ + long n; + + bool issymmetric; + + /** Arrays for internal storage of eigenvalues. */ + + column_vector_type d; /* real part */ + column_vector_type e; /* img part */ + + /** Array for internal storage of eigenvectors. */ + matrix_type V; + + /** Array for internal storage of nonsymmetric Hessenberg form. + @serial internal storage of nonsymmetric Hessenberg form. + */ + matrix_type H; + + + /** Working storage for nonsymmetric algorithm. + @serial working storage for nonsymmetric algorithm. + */ + column_vector_type ort; + + // Symmetric Householder reduction to tridiagonal form. + void tred2(); + + + // Symmetric tridiagonal QL algorithm. + void tql2 (); + + + // Nonsymmetric reduction to Hessenberg form. + void orthes (); + + + // Complex scalar division. + type cdivr, cdivi; + void cdiv_(type xr, type xi, type yr, type yi); + + + // Nonsymmetric reduction from Hessenberg to real Schur form. + void hqr2 (); + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Public member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + template + eigenvalue_decomposition:: + eigenvalue_decomposition( + const matrix_exp& A_ + ) + { + COMPILE_TIME_ASSERT((is_same_type::value)); + + + const_temp_matrix A(A_); + + // make sure requires clause is not broken + DLIB_ASSERT(A.nr() == A.nc() && A.size() > 0, + "\teigenvalue_decomposition::eigenvalue_decomposition(A)" + << "\n\tYou can only use this on square matrices" + << "\n\tA.nr(): " << A.nr() + << "\n\tA.nc(): " << A.nc() + << "\n\tA.size(): " << A.size() + << "\n\tthis: " << this + ); + + + n = A.nc(); + V.set_size(n,n); + d.set_size(n); + e.set_size(n); + + + issymmetric = true; + for (long j = 0; (j < n) && issymmetric; j++) + { + for (long i = 0; (i < n) && issymmetric; i++) + { + issymmetric = (A(i,j) == A(j,i)); + } + } + + if (issymmetric) + { + V = A; + +#ifdef DLIB_USE_LAPACK + if (A.nr() > DLIB_LAPACK_EIGENVALUE_DECOMP_SIZE_THRESH) + { + e = 0; + + // We could compute the result using syev() + //lapack::syev('V', 'L', V, d); + + // Instead, we use syevr because its faster and maybe more stable. + matrix_type tempA(A); + matrix isupz; + + lapack::integer temp; + lapack::syevr('V','A','L',tempA,0,0,0,0,-1,temp,d,V,isupz); + return; + } +#endif + // Tridiagonalize. + tred2(); + + // Diagonalize. + tql2(); + + } + else + { + +#ifdef DLIB_USE_LAPACK + if (A.nr() > DLIB_LAPACK_EIGENVALUE_DECOMP_SIZE_THRESH) + { + matrix temp, vl, vr; + temp = A; + lapack::geev('N', 'V', temp, d, e, vl, vr); + V = vr; + return; + } +#endif + H = A; + + ort.set_size(n); + + // Reduce to Hessenberg form. + orthes(); + + // Reduce Hessenberg to real Schur form. + hqr2(); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template + eigenvalue_decomposition:: + eigenvalue_decomposition( + const matrix_op >& A + ) + { + COMPILE_TIME_ASSERT((is_same_type::value)); + + + // make sure requires clause is not broken + DLIB_ASSERT(A.nr() == A.nc() && A.size() > 0, + "\teigenvalue_decomposition::eigenvalue_decomposition(A)" + << "\n\tYou can only use this on square matrices" + << "\n\tA.nr(): " << A.nr() + << "\n\tA.nc(): " << A.nc() + << "\n\tA.size(): " << A.size() + << "\n\tthis: " << this + ); + + + n = A.nc(); + V.set_size(n,n); + d.set_size(n); + e.set_size(n); + + + V = A; + +#ifdef DLIB_USE_LAPACK + if (A.nr() > DLIB_LAPACK_EIGENVALUE_DECOMP_SIZE_THRESH) + { + e = 0; + + // We could compute the result using syev() + //lapack::syev('V', 'L', V, d); + + // Instead, we use syevr because its faster and maybe more stable. + matrix_type tempA(A); + matrix isupz; + + lapack::integer temp; + lapack::syevr('V','A','L',tempA,0,0,0,0,-1,temp,d,V,isupz); + return; + } +#endif + // Tridiagonalize. + tred2(); + + // Diagonalize. + tql2(); + + } + +// ---------------------------------------------------------------------------------------- + + template + const typename eigenvalue_decomposition::matrix_type& eigenvalue_decomposition:: + get_pseudo_v ( + ) const + { + return V; + } + +// ---------------------------------------------------------------------------------------- + + template + long eigenvalue_decomposition:: + dim ( + ) const + { + return V.nr(); + } + +// ---------------------------------------------------------------------------------------- + + template + const typename eigenvalue_decomposition::complex_column_vector_type eigenvalue_decomposition:: + get_eigenvalues ( + ) const + { + return complex_matrix(get_real_eigenvalues(), get_imag_eigenvalues()); + } + +// ---------------------------------------------------------------------------------------- + + template + const typename eigenvalue_decomposition::column_vector_type& eigenvalue_decomposition:: + get_real_eigenvalues ( + ) const + { + return d; + } + +// ---------------------------------------------------------------------------------------- + + template + const typename eigenvalue_decomposition::column_vector_type& eigenvalue_decomposition:: + get_imag_eigenvalues ( + ) const + { + return e; + } + +// ---------------------------------------------------------------------------------------- + + template + const typename eigenvalue_decomposition::complex_matrix_type eigenvalue_decomposition:: + get_d ( + ) const + { + return diagm(complex_matrix(get_real_eigenvalues(), get_imag_eigenvalues())); + } + +// ---------------------------------------------------------------------------------------- + + template + const typename eigenvalue_decomposition::complex_matrix_type eigenvalue_decomposition:: + get_v ( + ) const + { + complex_matrix_type CV(n,n); + + for (long i = 0; i < n; i++) + { + if (e(i) > 0) + { + set_colm(CV,i) = complex_matrix(colm(V,i), colm(V,i+1)); + } + else if (e(i) < 0) + { + set_colm(CV,i) = complex_matrix(colm(V,i), colm(V,i-1)); + } + else + { + set_colm(CV,i) = complex_matrix(colm(V,i), uniform_matrix(n,1,0)); + } + } + + return CV; + } + +// ---------------------------------------------------------------------------------------- + + template + const typename eigenvalue_decomposition::matrix_type eigenvalue_decomposition:: + get_pseudo_d ( + ) const + { + matrix_type D(n,n); + + for (long i = 0; i < n; i++) + { + for (long j = 0; j < n; j++) + { + D(i,j) = 0.0; + } + D(i,i) = d(i); + if (e(i) > 0) + { + D(i,i+1) = e(i); + } + else if (e(i) < 0) + { + D(i,i-1) = e(i); + } + } + + return D; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Private member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// Symmetric Householder reduction to tridiagonal form. + template + void eigenvalue_decomposition:: + tred2() + { + using std::abs; + using std::sqrt; + + // This is derived from the Algol procedures tred2 by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (long j = 0; j < n; j++) + { + d(j) = V(n-1,j); + } + + // Householder reduction to tridiagonal form. + + for (long i = n-1; i > 0; i--) + { + + // Scale to avoid under/overflow. + + type scale = 0.0; + type h = 0.0; + for (long k = 0; k < i; k++) + { + scale = scale + abs(d(k)); + } + if (scale == 0.0) + { + e(i) = d(i-1); + for (long j = 0; j < i; j++) + { + d(j) = V(i-1,j); + V(i,j) = 0.0; + V(j,i) = 0.0; + } + } + else + { + + // Generate Householder vector. + + for (long k = 0; k < i; k++) + { + d(k) /= scale; + h += d(k) * d(k); + } + type f = d(i-1); + type g = sqrt(h); + if (f > 0) + { + g = -g; + } + e(i) = scale * g; + h = h - f * g; + d(i-1) = f - g; + for (long j = 0; j < i; j++) + { + e(j) = 0.0; + } + + // Apply similarity transformation to remaining columns. + + for (long j = 0; j < i; j++) + { + f = d(j); + V(j,i) = f; + g = e(j) + V(j,j) * f; + for (long k = j+1; k <= i-1; k++) + { + g += V(k,j) * d(k); + e(k) += V(k,j) * f; + } + e(j) = g; + } + f = 0.0; + for (long j = 0; j < i; j++) + { + e(j) /= h; + f += e(j) * d(j); + } + type hh = f / (h + h); + for (long j = 0; j < i; j++) + { + e(j) -= hh * d(j); + } + for (long j = 0; j < i; j++) + { + f = d(j); + g = e(j); + for (long k = j; k <= i-1; k++) + { + V(k,j) -= (f * e(k) + g * d(k)); + } + d(j) = V(i-1,j); + V(i,j) = 0.0; + } + } + d(i) = h; + } + + // Accumulate transformations. + + for (long i = 0; i < n-1; i++) + { + V(n-1,i) = V(i,i); + V(i,i) = 1.0; + type h = d(i+1); + if (h != 0.0) + { + for (long k = 0; k <= i; k++) + { + d(k) = V(k,i+1) / h; + } + for (long j = 0; j <= i; j++) + { + type g = 0.0; + for (long k = 0; k <= i; k++) + { + g += V(k,i+1) * V(k,j); + } + for (long k = 0; k <= i; k++) + { + V(k,j) -= g * d(k); + } + } + } + for (long k = 0; k <= i; k++) + { + V(k,i+1) = 0.0; + } + } + for (long j = 0; j < n; j++) + { + d(j) = V(n-1,j); + V(n-1,j) = 0.0; + } + V(n-1,n-1) = 1.0; + e(0) = 0.0; + } + +// ---------------------------------------------------------------------------------------- + + template + void eigenvalue_decomposition:: + tql2 () + { + using std::pow; + using std::min; + using std::max; + using std::abs; + + // This is derived from the Algol procedures tql2, by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (long i = 1; i < n; i++) + { + e(i-1) = e(i); + } + e(n-1) = 0.0; + + type f = 0.0; + type tst1 = 0.0; + const type eps = std::numeric_limits::epsilon(); + for (long l = 0; l < n; l++) + { + + // Find small subdiagonal element + + tst1 = max(tst1,abs(d(l)) + abs(e(l))); + long m = l; + + // Original while-loop from Java code + while (m < n) + { + if (abs(e(m)) <= eps*tst1) + { + break; + } + m++; + } + if (m == n) + --m; + + + // If m == l, d(l) is an eigenvalue, + // otherwise, iterate. + + if (m > l) + { + long iter = 0; + do + { + iter = iter + 1; // (Could check iteration count here.) + + // Compute implicit shift + + type g = d(l); + type p = (d(l+1) - g) / (2.0 * e(l)); + type r = hypot(p,(type)1.0); + if (p < 0) + { + r = -r; + } + d(l) = e(l) / (p + r); + d(l+1) = e(l) * (p + r); + type dl1 = d(l+1); + type h = g - d(l); + for (long i = l+2; i < n; i++) + { + d(i) -= h; + } + f = f + h; + + // Implicit QL transformation. + + p = d(m); + type c = 1.0; + type c2 = c; + type c3 = c; + type el1 = e(l+1); + type s = 0.0; + type s2 = 0.0; + for (long i = m-1; i >= l; i--) + { + c3 = c2; + c2 = c; + s2 = s; + g = c * e(i); + h = c * p; + r = hypot(p,e(i)); + e(i+1) = s * r; + s = e(i) / r; + c = p / r; + p = c * d(i) - s * g; + d(i+1) = h + s * (c * g + s * d(i)); + + // Accumulate transformation. + + for (long k = 0; k < n; k++) + { + h = V(k,i+1); + V(k,i+1) = s * V(k,i) + c * h; + V(k,i) = c * V(k,i) - s * h; + } + } + p = -s * s2 * c3 * el1 * e(l) / dl1; + e(l) = s * p; + d(l) = c * p; + + // Check for convergence. + + } while (abs(e(l)) > eps*tst1); + } + d(l) = d(l) + f; + e(l) = 0.0; + } + + /* + The code to sort the eigenvalues and eigenvectors + has been removed from here since, in the non-symmetric case, + we can't sort the eigenvalues in a meaningful way. If we left this + code in here then the user might supply what they thought was a symmetric + matrix but was actually slightly non-symmetric due to rounding error + and then they would end up in the non-symmetric eigenvalue solver + where the eigenvalues don't end up getting sorted. So to avoid + any possible user confusion I'm just removing this. + */ + } + +// ---------------------------------------------------------------------------------------- + + template + void eigenvalue_decomposition:: + orthes () + { + using std::abs; + using std::sqrt; + + // This is derived from the Algol procedures orthes and ortran, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutines in EISPACK. + + long low = 0; + long high = n-1; + + for (long m = low+1; m <= high-1; m++) + { + + // Scale column. + + type scale = 0.0; + for (long i = m; i <= high; i++) + { + scale = scale + abs(H(i,m-1)); + } + if (scale != 0.0) + { + + // Compute Householder transformation. + + type h = 0.0; + for (long i = high; i >= m; i--) + { + ort(i) = H(i,m-1)/scale; + h += ort(i) * ort(i); + } + type g = sqrt(h); + if (ort(m) > 0) + { + g = -g; + } + h = h - ort(m) * g; + ort(m) = ort(m) - g; + + // Apply Householder similarity transformation + // H = (I-u*u'/h)*H*(I-u*u')/h) + + for (long j = m; j < n; j++) + { + type f = 0.0; + for (long i = high; i >= m; i--) + { + f += ort(i)*H(i,j); + } + f = f/h; + for (long i = m; i <= high; i++) + { + H(i,j) -= f*ort(i); + } + } + + for (long i = 0; i <= high; i++) + { + type f = 0.0; + for (long j = high; j >= m; j--) + { + f += ort(j)*H(i,j); + } + f = f/h; + for (long j = m; j <= high; j++) + { + H(i,j) -= f*ort(j); + } + } + ort(m) = scale*ort(m); + H(m,m-1) = scale*g; + } + } + + // Accumulate transformations (Algol's ortran). + + for (long i = 0; i < n; i++) + { + for (long j = 0; j < n; j++) + { + V(i,j) = (i == j ? 1.0 : 0.0); + } + } + + for (long m = high-1; m >= low+1; m--) + { + if (H(m,m-1) != 0.0) + { + for (long i = m+1; i <= high; i++) + { + ort(i) = H(i,m-1); + } + for (long j = m; j <= high; j++) + { + type g = 0.0; + for (long i = m; i <= high; i++) + { + g += ort(i) * V(i,j); + } + // Double division avoids possible underflow + g = (g / ort(m)) / H(m,m-1); + for (long i = m; i <= high; i++) + { + V(i,j) += g * ort(i); + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void eigenvalue_decomposition:: + cdiv_(type xr, type xi, type yr, type yi) + { + using std::abs; + type r,d; + if (abs(yr) > abs(yi)) + { + r = yi/yr; + d = yr + r*yi; + cdivr = (xr + r*xi)/d; + cdivi = (xi - r*xr)/d; + } + else + { + r = yr/yi; + d = yi + r*yr; + cdivr = (r*xr + xi)/d; + cdivi = (r*xi - xr)/d; + } + } + +// ---------------------------------------------------------------------------------------- + + template + void eigenvalue_decomposition:: + hqr2 () + { + using std::pow; + using std::min; + using std::max; + using std::abs; + using std::sqrt; + + // This is derived from the Algol procedure hqr2, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + // Initialize + + long nn = this->n; + long n = nn-1; + long low = 0; + long high = nn-1; + const type eps = std::numeric_limits::epsilon(); + type exshift = 0.0; + type p=0,q=0,r=0,s=0,z=0,t,w,x,y; + + // Store roots isolated by balanc and compute matrix norm + + type norm = 0.0; + for (long i = 0; i < nn; i++) + { + if ((i < low) || (i > high)) + { + d(i) = H(i,i); + e(i) = 0.0; + } + for (long j = max(i-1,0L); j < nn; j++) + { + norm = norm + abs(H(i,j)); + } + } + + // Outer loop over eigenvalue index + + long iter = 0; + while (n >= low) + { + + // Look for single small sub-diagonal element + + long l = n; + while (l > low) + { + s = abs(H(l-1,l-1)) + abs(H(l,l)); + if (s == 0.0) + { + s = norm; + } + if (abs(H(l,l-1)) < eps * s) + { + break; + } + l--; + } + + // Check for convergence + // One root found + + if (l == n) + { + H(n,n) = H(n,n) + exshift; + d(n) = H(n,n); + e(n) = 0.0; + n--; + iter = 0; + + // Two roots found + + } + else if (l == n-1) + { + w = H(n,n-1) * H(n-1,n); + p = (H(n-1,n-1) - H(n,n)) / 2.0; + q = p * p + w; + z = sqrt(abs(q)); + H(n,n) = H(n,n) + exshift; + H(n-1,n-1) = H(n-1,n-1) + exshift; + x = H(n,n); + + // type pair + + if (q >= 0) + { + if (p >= 0) + { + z = p + z; + } + else + { + z = p - z; + } + d(n-1) = x + z; + d(n) = d(n-1); + if (z != 0.0) + { + d(n) = x - w / z; + } + e(n-1) = 0.0; + e(n) = 0.0; + x = H(n,n-1); + s = abs(x) + abs(z); + p = x / s; + q = z / s; + r = sqrt(p * p+q * q); + p = p / r; + q = q / r; + + // Row modification + + for (long j = n-1; j < nn; j++) + { + z = H(n-1,j); + H(n-1,j) = q * z + p * H(n,j); + H(n,j) = q * H(n,j) - p * z; + } + + // Column modification + + for (long i = 0; i <= n; i++) + { + z = H(i,n-1); + H(i,n-1) = q * z + p * H(i,n); + H(i,n) = q * H(i,n) - p * z; + } + + // Accumulate transformations + + for (long i = low; i <= high; i++) + { + z = V(i,n-1); + V(i,n-1) = q * z + p * V(i,n); + V(i,n) = q * V(i,n) - p * z; + } + + // Complex pair + + } + else + { + d(n-1) = x + p; + d(n) = x + p; + e(n-1) = z; + e(n) = -z; + } + n = n - 2; + iter = 0; + + // No convergence yet + + } + else + { + + // Form shift + + x = H(n,n); + y = 0.0; + w = 0.0; + if (l < n) + { + y = H(n-1,n-1); + w = H(n,n-1) * H(n-1,n); + } + + // Wilkinson's original ad hoc shift + + if (iter == 10) + { + exshift += x; + for (long i = low; i <= n; i++) + { + H(i,i) -= x; + } + s = abs(H(n,n-1)) + abs(H(n-1,n-2)); + x = y = 0.75 * s; + w = -0.4375 * s * s; + } + + // MATLAB's new ad hoc shift + + if (iter == 30) + { + s = (y - x) / 2.0; + s = s * s + w; + if (s > 0) + { + s = sqrt(s); + if (y < x) + { + s = -s; + } + s = x - w / ((y - x) / 2.0 + s); + for (long i = low; i <= n; i++) + { + H(i,i) -= s; + } + exshift += s; + x = y = w = 0.964; + } + } + + iter = iter + 1; // (Could check iteration count here.) + + // Look for two consecutive small sub-diagonal elements + + long m = n-2; + while (m >= l) + { + z = H(m,m); + r = x - z; + s = y - z; + p = (r * s - w) / H(m+1,m) + H(m,m+1); + q = H(m+1,m+1) - z - r - s; + r = H(m+2,m+1); + s = abs(p) + abs(q) + abs(r); + p = p / s; + q = q / s; + r = r / s; + if (m == l) + { + break; + } + if (abs(H(m,m-1)) * (abs(q) + abs(r)) < + eps * (abs(p) * (abs(H(m-1,m-1)) + abs(z) + + abs(H(m+1,m+1))))) + { + break; + } + m--; + } + + for (long i = m+2; i <= n; i++) + { + H(i,i-2) = 0.0; + if (i > m+2) + { + H(i,i-3) = 0.0; + } + } + + // Double QR step involving rows l:n and columns m:n + + for (long k = m; k <= n-1; k++) + { + long notlast = (k != n-1); + if (k != m) + { + p = H(k,k-1); + q = H(k+1,k-1); + r = (notlast ? H(k+2,k-1) : 0.0); + x = abs(p) + abs(q) + abs(r); + if (x != 0.0) + { + p = p / x; + q = q / x; + r = r / x; + } + } + if (x == 0.0) + { + break; + } + s = sqrt(p * p + q * q + r * r); + if (p < 0) + { + s = -s; + } + if (s != 0) + { + if (k != m) + { + H(k,k-1) = -s * x; + } + else if (l != m) + { + H(k,k-1) = -H(k,k-1); + } + p = p + s; + x = p / s; + y = q / s; + z = r / s; + q = q / p; + r = r / p; + + // Row modification + + for (long j = k; j < nn; j++) + { + p = H(k,j) + q * H(k+1,j); + if (notlast) + { + p = p + r * H(k+2,j); + H(k+2,j) = H(k+2,j) - p * z; + } + H(k,j) = H(k,j) - p * x; + H(k+1,j) = H(k+1,j) - p * y; + } + + // Column modification + + for (long i = 0; i <= min(n,k+3); i++) + { + p = x * H(i,k) + y * H(i,k+1); + if (notlast) + { + p = p + z * H(i,k+2); + H(i,k+2) = H(i,k+2) - p * r; + } + H(i,k) = H(i,k) - p; + H(i,k+1) = H(i,k+1) - p * q; + } + + // Accumulate transformations + + for (long i = low; i <= high; i++) + { + p = x * V(i,k) + y * V(i,k+1); + if (notlast) + { + p = p + z * V(i,k+2); + V(i,k+2) = V(i,k+2) - p * r; + } + V(i,k) = V(i,k) - p; + V(i,k+1) = V(i,k+1) - p * q; + } + } // (s != 0) + } // k loop + } // check convergence + } // while (n >= low) + + // Backsubstitute to find vectors of upper triangular form + + if (norm == 0.0) + { + return; + } + + for (n = nn-1; n >= 0; n--) + { + p = d(n); + q = e(n); + + // Real vector + + if (q == 0) + { + long l = n; + H(n,n) = 1.0; + for (long i = n-1; i >= 0; i--) + { + w = H(i,i) - p; + r = 0.0; + for (long j = l; j <= n; j++) + { + r = r + H(i,j) * H(j,n); + } + if (e(i) < 0.0) + { + z = w; + s = r; + } + else + { + l = i; + if (e(i) == 0.0) + { + if (w != 0.0) + { + H(i,n) = -r / w; + } + else + { + H(i,n) = -r / (eps * norm); + } + + // Solve real equations + + } + else + { + x = H(i,i+1); + y = H(i+1,i); + q = (d(i) - p) * (d(i) - p) + e(i) * e(i); + t = (x * s - z * r) / q; + H(i,n) = t; + if (abs(x) > abs(z)) + { + H(i+1,n) = (-r - w * t) / x; + } + else + { + H(i+1,n) = (-s - y * t) / z; + } + } + + // Overflow control + + t = abs(H(i,n)); + if ((eps * t) * t > 1) + { + for (long j = i; j <= n; j++) + { + H(j,n) = H(j,n) / t; + } + } + } + } + + // Complex vector + + } + else if (q < 0) + { + long l = n-1; + + // Last vector component imaginary so matrix is triangular + + if (abs(H(n,n-1)) > abs(H(n-1,n))) + { + H(n-1,n-1) = q / H(n,n-1); + H(n-1,n) = -(H(n,n) - p) / H(n,n-1); + } + else + { + cdiv_(0.0,-H(n-1,n),H(n-1,n-1)-p,q); + H(n-1,n-1) = cdivr; + H(n-1,n) = cdivi; + } + H(n,n-1) = 0.0; + H(n,n) = 1.0; + for (long i = n-2; i >= 0; i--) + { + type ra,sa,vr,vi; + ra = 0.0; + sa = 0.0; + for (long j = l; j <= n; j++) + { + ra = ra + H(i,j) * H(j,n-1); + sa = sa + H(i,j) * H(j,n); + } + w = H(i,i) - p; + + if (e(i) < 0.0) + { + z = w; + r = ra; + s = sa; + } + else + { + l = i; + if (e(i) == 0) + { + cdiv_(-ra,-sa,w,q); + H(i,n-1) = cdivr; + H(i,n) = cdivi; + } + else + { + + // Solve complex equations + + x = H(i,i+1); + y = H(i+1,i); + vr = (d(i) - p) * (d(i) - p) + e(i) * e(i) - q * q; + vi = (d(i) - p) * 2.0 * q; + if ((vr == 0.0) && (vi == 0.0)) + { + vr = eps * norm * (abs(w) + abs(q) + + abs(x) + abs(y) + abs(z)); + } + cdiv_(x*r-z*ra+q*sa,x*s-z*sa-q*ra,vr,vi); + H(i,n-1) = cdivr; + H(i,n) = cdivi; + if (abs(x) > (abs(z) + abs(q))) + { + H(i+1,n-1) = (-ra - w * H(i,n-1) + q * H(i,n)) / x; + H(i+1,n) = (-sa - w * H(i,n) - q * H(i,n-1)) / x; + } + else + { + cdiv_(-r-y*H(i,n-1),-s-y*H(i,n),z,q); + H(i+1,n-1) = cdivr; + H(i+1,n) = cdivi; + } + } + + // Overflow control + + t = max(abs(H(i,n-1)),abs(H(i,n))); + if ((eps * t) * t > 1) + { + for (long j = i; j <= n; j++) + { + H(j,n-1) = H(j,n-1) / t; + H(j,n) = H(j,n) / t; + } + } + } + } + } + } + + // Vectors of isolated roots + + for (long i = 0; i < nn; i++) + { + if (i < low || i > high) + { + for (long j = i; j < nn; j++) + { + V(i,j) = H(i,j); + } + } + } + + // Back transformation to get eigenvectors of original matrix + + for (long j = nn-1; j >= low; j--) + { + for (long i = low; i <= high; i++) + { + z = 0.0; + for (long k = low; k <= min(j,high); k++) + { + z = z + V(i,k) * H(k,j); + } + V(i,j) = z; + } + } + } + +// ---------------------------------------------------------------------------------------- + + +} + +#endif // DLIB_MATRIX_EIGENVALUE_DECOMPOSITION_H + + + + diff --git a/ml/dlib/dlib/matrix/matrix_exp.h b/ml/dlib/dlib/matrix/matrix_exp.h new file mode 100644 index 000000000..c0afb54c0 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_exp.h @@ -0,0 +1,271 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_EXP_h_ +#define DLIB_MATRIx_EXP_h_ + +#include "../algs.h" +#include "../is_kind.h" +#include "matrix_fwd.h" +#include "matrix_exp_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + // We want to return the compile time constant if our NR and NC dimensions + // aren't zero but if they are then we want to call ref_.nx() and return + // the correct values. + template < typename exp_type, long NR > + struct get_nr_helper + { + static inline long get(const exp_type&) { return NR; } + }; + + template < typename exp_type > + struct get_nr_helper + { + static inline long get(const exp_type& m) { return m.nr(); } + }; + + template < typename exp_type, long NC > + struct get_nc_helper + { + static inline long get(const exp_type&) { return NC; } + }; + + template < typename exp_type > + struct get_nc_helper + { + static inline long get(const exp_type& m) { return m.nc(); } + }; + + template + struct matrix_traits + { + typedef typename EXP::type type; + typedef typename EXP::const_ret_type const_ret_type; + typedef typename EXP::mem_manager_type mem_manager_type; + typedef typename EXP::layout_type layout_type; + const static long NR = EXP::NR; + const static long NC = EXP::NC; + const static long cost = EXP::cost; + }; + +// ---------------------------------------------------------------------------------------- + + template class matrix_exp; + template + class matrix_exp_iterator : public std::iterator::type> + { + friend class matrix_exp; + matrix_exp_iterator(const EXP& m_, long r_, long c_) + { + r = r_; + c = c_; + nc = m_.nc(); + m = &m_; + } + + public: + + matrix_exp_iterator() : r(0), c(0), nc(0), m(0) {} + + typedef typename matrix_traits::type type; + typedef type value_type; + typedef typename matrix_traits::const_ret_type const_ret_type; + + + bool operator == ( const matrix_exp_iterator& itr) const + { return r == itr.r && c == itr.c; } + + bool operator != ( const matrix_exp_iterator& itr) const + { return !(*this == itr); } + + matrix_exp_iterator& operator++() + { + ++c; + if (c==nc) + { + c = 0; + ++r; + } + return *this; + } + + matrix_exp_iterator operator++(int) + { + matrix_exp_iterator temp(*this); + ++(*this); + return temp; + } + + const_ret_type operator* () const { return (*m)(r,c); } + + private: + long r, c; + long nc; + const EXP* m; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + class matrix_exp + { + /*! + REQUIREMENTS ON EXP + EXP should be something convertible to a matrix_exp. That is, + it should inherit from matrix_exp + !*/ + + public: + typedef typename matrix_traits::type type; + typedef type value_type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + typedef typename matrix_traits::layout_type layout_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + + typedef matrix matrix_type; + typedef EXP exp_type; + typedef matrix_exp_iterator iterator; + typedef matrix_exp_iterator const_iterator; + + inline const_ret_type operator() ( + long r, + long c + ) const + { + DLIB_ASSERT(r < nr() && c < nc() && r >= 0 && c >= 0, + "\tconst type matrix_exp::operator(r,c)" + << "\n\tYou must give a valid row and column" + << "\n\tr: " << r + << "\n\tc: " << c + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tthis: " << this + ); + return ref()(r,c); + } + + const_ret_type operator() ( + long i + ) const + { + COMPILE_TIME_ASSERT(NC == 1 || NC == 0 || NR == 1 || NR == 0); + DLIB_ASSERT(nc() == 1 || nr() == 1, + "\tconst type matrix_exp::operator(i)" + << "\n\tYou can only use this operator on column or row vectors" + << "\n\ti: " << i + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tthis: " << this + ); + DLIB_ASSERT( ((nc() == 1 && i < nr()) || (nr() == 1 && i < nc())) && i >= 0, + "\tconst type matrix_exp::operator(i)" + << "\n\tYou must give a valid row/column number" + << "\n\ti: " << i + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tthis: " << this + ); + if (nc() == 1) + return ref()(i,0); + else + return ref()(0,i); + } + + long size ( + ) const { return nr()*nc(); } + + long nr ( + ) const { return get_nr_helper::get(ref()); } + + long nc ( + ) const { return get_nc_helper::get(ref()); } + + template + bool aliases ( + const matrix_exp& item + ) const { return ref().aliases(item); } + + template + bool destructively_aliases ( + const matrix_exp& item + ) const { return ref().destructively_aliases(item); } + + inline const exp_type& ref ( + ) const { return *static_cast(this); } + + inline operator const type ( + ) const + { + COMPILE_TIME_ASSERT(NC == 1 || NC == 0); + COMPILE_TIME_ASSERT(NR == 1 || NR == 0); + DLIB_ASSERT(nr() == 1 && nc() == 1, + "\tmatrix_exp::operator const type() const" + << "\n\tYou can only use this operator on a 1x1 matrix" + << "\n\tnr(): " << nr() + << "\n\tnc(): " << nc() + << "\n\tthis: " << this + ); + + // Put the expression contained in this matrix_exp into + // a temporary 1x1 matrix so that the expression will encounter + // all the overloads of matrix_assign() and have the chance to + // go through any applicable optimizations. + matrix temp(ref()); + return temp(0); + } + + const_iterator begin() const { return matrix_exp_iterator(ref(),0,0); } + const_iterator end() const { return matrix_exp_iterator(ref(),nr(),0); } + + protected: + matrix_exp() {} + matrix_exp(const matrix_exp& ) {} + + private: + + matrix_exp& operator= (const matrix_exp&); + }; + +// ---------------------------------------------------------------------------------------- + + // something is a matrix if it is convertible to a matrix_exp object + template + struct is_matrix& > >::type > + { static const bool value = true; }; + /* + is_matrix::value == 1 if T is a matrix type else 0 + */ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + class matrix_diag_exp : public matrix_exp + { + /*! + This is a matrix expression type used to represent diagonal matrices. + That is, square matrices with all off diagonal elements equal to 0. + !*/ + + protected: + matrix_diag_exp() {} + matrix_diag_exp(const matrix_diag_exp& item ):matrix_exp(item) {} + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_EXP_h_ + diff --git a/ml/dlib/dlib/matrix/matrix_exp_abstract.h b/ml/dlib/dlib/matrix/matrix_exp_abstract.h new file mode 100644 index 000000000..14ad143c2 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_exp_abstract.h @@ -0,0 +1,210 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MATRIx_EXP_ABSTRACT_ +#ifdef DLIB_MATRIx_EXP_ABSTRACT_ + +#include "matrix_fwd.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + class matrix_exp + { + /*! + REQUIREMENTS ON EXP + - must be an object that inherits publicly from matrix_exp (this class). + + WHAT THIS OBJECT REPRESENTS + This object represents an expression that evaluates to a matrix + of nr() rows and nc() columns. + + The reason for having an object that represents an expression is that it + allows us to use the "expression templates" technique to eliminate the + temporary matrix objects that would normally be returned from expressions + such as M = A+B+C+D; Normally each invocation of the + operator would + construct and return a temporary matrix object but using this technique we + can avoid creating all of these temporary objects and receive a large + speed boost. + + Note that every time you invoke operator() on this object it recomputes + its result which may not be what you want to do. For example, if you + are going to be accessing the same element over and over it might + be faster to assign the matrix_exp to a temporary matrix and then + use that temporary. + + + const_ret_type typedef (defined below) + The purpose of the const_ret_type typedef is to allow matrix expressions + to return their elements by reference when appropriate. So const_ret_type + should be one of the following types: + - const type + - const type& + !*/ + + public: + typedef typename EXP::type type; + typedef type value_type; // Redefined for compatibility with the STL + typedef typename EXP::const_ret_type const_ret_type; + typedef typename EXP::mem_manager_type mem_manager_type; + typedef typename EXP::layout_type layout_type; + const static long cost = EXP::cost; + const static long NR = EXP::NR; + const static long NC = EXP::NC; + typedef matrix matrix_type; + typedef EXP exp_type; + typedef matrix_exp_iterator iterator; + typedef matrix_exp_iterator const_iterator; + + const_ret_type operator() ( + long r, + long c + ) const; + /*! + requires + - 0 <= r < nr() + - 0 <= c < nc() + ensures + - returns ref()(r,c) + (i.e. returns the value at the given row and column that would be in + the matrix represented by this matrix expression) + !*/ + + const_ret_type operator() ( + long i + ) const; + /*! + requires + - nc() == 1 || nr() == 1 (i.e. this must be a column or row vector) + - if (nc() == 1) then + - 0 <= i < nr() + - else + - 0 <= i < nc() + ensures + - if (nc() == 1) then + - returns (*this)(i,0) + - else + - returns (*this)(0,i) + !*/ + + operator const type ( + ) const; + /*! + requires + - nr() == 1 + - nc() == 1 + ensures + - returns (*this)(0,0) + !*/ + + long nr ( + ) const; + /*! + ensures + - returns the number of rows in this matrix expression. + !*/ + + long nc ( + ) const; + /*! + ensures + - returns the number of columns in this matrix expression. + !*/ + + long size ( + ) const; + /*! + ensures + - returns nr()*nc() + !*/ + + template + bool aliases ( + const matrix_exp& item + ) const; + /*! + ensures + - if (A change to the state of item could cause a change to the state of *this + matrix_exp object. ) then + - returns true + - This happens when this matrix_exp contains item in some way. + - else + - returns false + !*/ + + template + bool destructively_aliases ( + const matrix_exp& item + ) const; + /*! + ensures + - if (aliases(item)) then + - if (nr() != item.nr() || nc() != item.nc() + - returns true + (i.e. if this expression has different dimensions than item then + we have destructive aliasing) + + - returns true if the following assignment would evaluate incorrectly: + for (long r = 0; r < nr(); ++r) + for (long c = 0; c < nc(); ++c) + item(r,c) = (*this)(r,c) + - That is, if this matrix expression aliases item in such a way that a modification + to element item(r,c) causes a change in the value of something other than + (*this)(r,c) then this function returns true. + + - returns false if none of the above conditions say we should return true + - else + - returns false + !*/ + + inline const exp_type& ref ( + ) const; + /*! + ensures + - returns a reference to the expression contained in *this. + (i.e. returns *static_cast(this) ) + !*/ + + const_iterator begin( + ) const; + /*! + ensures + - returns a forward access iterator pointing to the first element in this + matrix expression. + - Since matrix_exp objects represent immutable views of a matrix, the + returned iterator does not allow the user to modify the matrix + expression's elements. + - The iterator will iterate over the elements of the matrix in row major + order. + !*/ + + const_iterator end( + ) const; + /*! + ensures + - returns a forward access iterator pointing to one past the end of the + last element in this matrix expression. + !*/ + + protected: + + // Only derived classes of matrix_exp may call the matrix_exp constructors. + matrix_exp(const matrix_exp&); + matrix_exp(); + + private: + // no one may ever use the assignment operator on a matrix_exp + matrix_exp& operator= (const matrix_exp&); + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_EXP_ABSTRACT_ + + diff --git a/ml/dlib/dlib/matrix/matrix_expressions.h b/ml/dlib/dlib/matrix/matrix_expressions.h new file mode 100644 index 000000000..9f057d076 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_expressions.h @@ -0,0 +1,280 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_EXPRESSIONS_H_ +#define DLIB_MATRIx_EXPRESSIONS_H_ + +#include "matrix_fwd.h" + +#ifdef _MSC_VER +// This #pragma directive is also located in the algs.h file but for whatever +// reason visual studio 9 just ignores it when it is only there. + +// this is to disable the "'this' : used in base member initializer list" +// warning you get from some of the GUI objects since all the objects +// require that their parent class be passed into their constructor. +// In this case though it is totally safe so it is ok to disable this warning. +#pragma warning(disable : 4355) +#endif + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Helper templates for making operators used by expression objects +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class matrix_range_exp; + + template + struct matrix_traits > + { + typedef T type; + typedef const T const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + const static long NR = 1; + const static long NC = 0; + const static long cost = 1; + }; + + template + class matrix_range_exp : public matrix_exp > + { + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + typedef typename matrix_traits::layout_type layout_type; + + + matrix_range_exp ( + T start_, + T end_ + ) + { + start = start_; + if (start_ <= end_) + inc = 1; + else + inc = -1; + nc_ = std::abs(end_ - start_) + 1; + } + matrix_range_exp ( + T start_, + T inc_, + T end_ + ) + { + start = start_; + nc_ = std::abs(end_ - start_)/inc_ + 1; + if (start_ <= end_) + inc = inc_; + else + inc = -inc_; + } + + matrix_range_exp ( + T start_, + T end_, + long num, + bool + ) + { + start = start_; + nc_ = num; + if (num > 1) + { + inc = (end_-start_)/(num-1); + } + else + { + inc = 0; + start = end_; + } + + } + + const_ret_type operator() ( + long, + long c + ) const { return start + c*inc; } + + const_ret_type operator() ( + long c + ) const { return start + c*inc; } + + template + bool aliases ( + const matrix_exp& + ) const { return false; } + + template + bool destructively_aliases ( + const matrix_exp& + ) const { return false; } + + long nr ( + ) const { return NR; } + + long nc ( + ) const { return nc_; } + + long nc_; + T start; + T inc; + }; + +// ---------------------------------------------------------------------------------------- + + template + class matrix_log_range_exp; + + template + struct matrix_traits > + { + typedef T type; + typedef const T const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + const static long NR = 1; + const static long NC = 0; + const static long cost = 1; + }; + + template + class matrix_log_range_exp : public matrix_exp > + { + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + typedef typename matrix_traits::layout_type layout_type; + + + matrix_log_range_exp ( + T start_, + T end_, + long num + ) + { + start = start_; + nc_ = num; + if (num > 1) + { + inc = (end_-start_)/(num-1); + } + else + { + inc = 0; + start = end_; + } + + } + + const_ret_type operator() ( + long, + long c + ) const { return std::pow((T)10,start + c*inc); } + + const_ret_type operator() ( + long c + ) const { return std::pow((T)10,start + c*inc); } + + template + bool aliases ( + const matrix_exp& + ) const { return false; } + + template + bool destructively_aliases ( + const matrix_exp& + ) const { return false; } + + long nr ( + ) const { return NR; } + + long nc ( + ) const { return nc_; } + + long nc_; + T start; + T inc; + }; + +// ---------------------------------------------------------------------------------------- + + template + class matrix_range_static_exp; + + template + struct matrix_traits > + { + typedef long type; + typedef const long const_ret_type; + typedef default_memory_manager mem_manager_type; + const static long NR = 1; + const static long NC = tabs<(end - start)>::value/inc_ + 1; + const static long cost = 1; + typedef row_major_layout layout_type; + }; + + template + class matrix_range_static_exp : public matrix_exp > + { + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + typedef typename matrix_traits::layout_type layout_type; + + const static long inc = (start <= end_)?inc_:-inc_; + + + matrix_range_static_exp ( + ) {} + + const_ret_type operator() ( + long , + long c + ) const { return start + c*inc; } + + const_ret_type operator() ( + long c + ) const { return start + c*inc; } + + template + bool aliases ( + const matrix_exp& + ) const { return false; } + + template + bool destructively_aliases ( + const matrix_exp& + ) const { return false; } + + long nr ( + ) const { return NR; } + + long nc ( + ) const { return NC; } + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_EXPRESSIONS_H_ + diff --git a/ml/dlib/dlib/matrix/matrix_fft.h b/ml/dlib/dlib/matrix/matrix_fft.h new file mode 100644 index 000000000..fbca6d344 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_fft.h @@ -0,0 +1,846 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FFt_Hh_ +#define DLIB_FFt_Hh_ + +#include "matrix_fft_abstract.h" +#include "matrix_utilities.h" +#include "../hash.h" +#include "../algs.h" + +#ifdef DLIB_USE_MKL_FFT +#include +#endif + +// No using FFTW until it becomes thread safe! +#if 0 +#ifdef DLIB_USE_FFTW +#include +#endif // DLIB_USE_FFTW +#endif + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline bool is_power_of_two ( + const unsigned long& value + ) + { + if (value == 0) + return true; + else + return count_bits(value) == 1; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + // ------------------------------------------------------------------------------------ + + /* + The next few functions related to doing FFTs are derived from Stefan + Gustavson's (stegu@itn.liu.se) public domain 2D Fourier transformation code. + The code has a long history, originally a FORTRAN implementation published in: + Programming for Digital Signal Processing, IEEE Press 1979, Section 1, by G. D. + Bergland and M. T. Dolan. In 2003 it was cleaned up and turned into modern C + by Steven Gustavson. Davis King then rewrote it in modern C++ in 2014 and also + changed the transform so that the outputs are identical to those given from FFTW. + */ + + // ------------------------------------------------------------------------------------ + + /* Get binary log of integer argument - exact if n is a power of 2 */ + inline long fastlog2(long n) + { + long log = -1; + while(n) { + log++; + n >>= 1; + } + return log ; + } + + // ------------------------------------------------------------------------------------ + + /* Radix-2 iteration subroutine */ + template + void R2TX(int nthpo, std::complex *c0, std::complex *c1) + { + for(int k=0; k temp = c0[k] + c1[k]; + c1[k] = c0[k] - c1[k]; + c0[k] = temp; + } + } + + // ------------------------------------------------------------------------------------ + + /* Radix-4 iteration subroutine */ + template + void R4TX(int nthpo, std::complex *c0, std::complex *c1, + std::complex *c2, std::complex *c3) + { + for(int k=0;k t1, t2, t3, t4; + t1 = c0[k] + c2[k]; + t2 = c0[k] - c2[k]; + t3 = c1[k] + c3[k]; + t4 = c1[k] - c3[k]; + + c0[k] = t1 + t3; + c1[k] = t1 - t3; + c2[k] = std::complex(t2.real()-t4.imag(), t2.imag()+t4.real()); + c3[k] = std::complex(t2.real()+t4.imag(), t2.imag()-t4.real()); + } + } + + // ------------------------------------------------------------------------------------ + + template + class twiddles + { + /*! + The point of this object is to cache the twiddle values so we don't + recompute them over and over inside R8TX(). + !*/ + public: + + twiddles() + { + data.resize(64); + } + + const std::complex* get_twiddles ( + int p + ) + /*! + requires + - 0 <= p <= 64 + ensures + - returns a pointer to the twiddle factors needed by R8TX if nxtlt == 2^p + !*/ + { + // Compute the twiddle factors for this p value if we haven't done so + // already. + if (data[p].size() == 0) + { + const int nxtlt = 0x1 << p; + data[p].reserve(nxtlt*7); + const T twopi = 6.2831853071795865; /* 2.0 * pi */ + const T scale = twopi/(nxtlt*8.0); + std::complex cs[7]; + for (int j = 0; j < nxtlt; ++j) + { + const T arg = j*scale; + cs[0] = std::complex(std::cos(arg),std::sin(arg)); + cs[1] = cs[0]*cs[0]; + cs[2] = cs[1]*cs[0]; + cs[3] = cs[1]*cs[1]; + cs[4] = cs[2]*cs[1]; + cs[5] = cs[2]*cs[2]; + cs[6] = cs[3]*cs[2]; + data[p].insert(data[p].end(), cs, cs+7); + } + } + + return &data[p][0]; + } + + private: + std::vector > > data; + }; + + // ---------------------------------------------------------------------------------------- + + /* Radix-8 iteration subroutine */ + template + void R8TX(int nxtlt, int nthpo, int length, const std::complex* cs, + std::complex *cc0, std::complex *cc1, std::complex *cc2, std::complex *cc3, + std::complex *cc4, std::complex *cc5, std::complex *cc6, std::complex *cc7) + { + const T irt2 = 0.707106781186548; /* 1.0/sqrt(2.0) */ + + for(int j=0; j a0, a1, a2, a3, a4, a5, a6, a7; + std::complex b0, b1, b2, b3, b4, b5, b6, b7; + a0 = cc0[k] + cc4[k]; + a1 = cc1[k] + cc5[k]; + a2 = cc2[k] + cc6[k]; + a3 = cc3[k] + cc7[k]; + a4 = cc0[k] - cc4[k]; + a5 = cc1[k] - cc5[k]; + a6 = cc2[k] - cc6[k]; + a7 = cc3[k] - cc7[k]; + + b0 = a0 + a2; + b1 = a1 + a3; + b2 = a0 - a2; + b3 = a1 - a3; + + b4 = std::complex(a4.real()-a6.imag(), a4.imag()+a6.real()); + b5 = std::complex(a5.real()-a7.imag(), a5.imag()+a7.real()); + b6 = std::complex(a4.real()+a6.imag(), a4.imag()-a6.real()); + b7 = std::complex(a5.real()+a7.imag(), a5.imag()-a7.real()); + + const std::complex tmp0(-b3.imag(), b3.real()); + const std::complex tmp1(irt2*(b5.real()-b5.imag()), irt2*(b5.real()+b5.imag())); + const std::complex tmp2(-irt2*(b7.real()+b7.imag()), irt2*(b7.real()-b7.imag())); + + cc0[k] = b0 + b1; + cc1[k] = b0 - b1; + cc2[k] = b2 + tmp0; + cc3[k] = b2 - tmp0; + cc4[k] = b4 + tmp1; + cc5[k] = b4 - tmp1; + cc6[k] = b6 + tmp2; + cc7[k] = b6 - tmp2; + if(j>0) + { + cc1[k] *= cs[3]; + cc2[k] *= cs[1]; + cc3[k] *= cs[5]; + cc4[k] *= cs[0]; + cc5[k] *= cs[4]; + cc6[k] *= cs[2]; + cc7[k] *= cs[6]; + } + } + + cs += 7; + } + } + + // ------------------------------------------------------------------------------------ + + template + void fft1d_inplace(matrix,NR,NC,MM,layout>& data, bool do_backward_fft, twiddles& cs) + /*! + requires + - is_vector(data) == true + - is_power_of_two(data.size()) == true + ensures + - This routine replaces the input std::complex vector by its finite + discrete complex fourier transform if do_backward_fft==true. It replaces + the input std::complex vector by its finite discrete complex + inverse fourier transform if do_backward_fft==false. + + The implementation is a radix-2 FFT, but with faster shortcuts for + radix-4 and radix-8. It performs as many radix-8 iterations as possible, + and then finishes with a radix-2 or -4 iteration if needed. + !*/ + { + COMPILE_TIME_ASSERT((is_same_type::value || is_same_type::value || is_same_type::value )); + + if (data.size() == 0) + return; + + std::complex* const b = &data(0); + int L[16],L1,L2,L3,L4,L5,L6,L7,L8,L9,L10,L11,L12,L13,L14,L15; + int j1,j2,j3,j4,j5,j6,j7,j8,j9,j10,j11,j12,j13,j14; + int j, ij, ji; + int n2pow, n8pow, nthpo, ipass, nxtlt, length; + + n2pow = fastlog2(data.size()); + nthpo = data.size(); + + n8pow = n2pow/3; + + if(n8pow) + { + /* Radix 8 iterations */ + for(ipass=1;ipass<=n8pow;ipass++) + { + const int p = n2pow - 3*ipass; + nxtlt = 0x1 << p; + length = 8*nxtlt; + R8TX(nxtlt, nthpo, length, cs.get_twiddles(p), + b, b+nxtlt, b+2*nxtlt, b+3*nxtlt, + b+4*nxtlt, b+5*nxtlt, b+6*nxtlt, b+7*nxtlt); + } + } + + if(n2pow%3 == 1) + { + /* A final radix 2 iteration is needed */ + R2TX(nthpo, b, b+1); + } + + if(n2pow%3 == 2) + { + /* A final radix 4 iteration is needed */ + R4TX(nthpo, b, b+1, b+2, b+3); + } + + for(j=1;j<=15;j++) + { + L[j] = 1; + if(j-n2pow <= 0) L[j] = 0x1 << (n2pow + 1 - j); + } + + L15=L[1];L14=L[2];L13=L[3];L12=L[4];L11=L[5];L10=L[6];L9=L[7]; + L8=L[8];L7=L[9];L6=L[10];L5=L[11];L4=L[12];L3=L[13];L2=L[14];L1=L[15]; + + ij = 0; + + for(j1=0;j1 + void fft2d_inplace( + matrix,NR,NC,MM,L>& data, + bool do_backward_fft + ) + { + if (data.size() == 0) + return; + + matrix > buff; + twiddles cs; + + // Compute transform row by row + for(long r=0; r >(rowm(data,r)); + fft1d_inplace(buff, do_backward_fft, cs); + set_rowm(data,r) = matrix_cast >(buff); + } + + // Compute transform column by column + for(long c=0; c >(colm(data,c)); + fft1d_inplace(buff, do_backward_fft, cs); + set_colm(data,c) = matrix_cast >(buff); + } + } + + // ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename T + > + void fft2d( + const matrix_exp& data, + matrix >& data_out, + bool do_backward_fft + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t matrix fft(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + if (data.size() == 0) + return; + + matrix > buff; + data_out.set_size(data.nr(), data.nc()); + twiddles cs; + + // Compute transform row by row + for(long r=0; r >(rowm(data,r)); + fft1d_inplace(buff, do_backward_fft, cs); + set_rowm(data_out,r) = matrix_cast >(buff); + } + + // Compute transform column by column + for(long c=0; c >(colm(data_out,c)); + fft1d_inplace(buff, do_backward_fft, cs); + set_colm(data_out,c) = matrix_cast >(buff); + } + } + + // ------------------------------------------------------------------------------------ + + } // end namespace impl + +// ---------------------------------------------------------------------------------------- + + template + matrix fft (const matrix_exp& data) + { + // You have to give a complex matrix + COMPILE_TIME_ASSERT(is_complex::value); + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t matrix fft(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + if (data.nr() == 1 || data.nc() == 1) + { + matrix temp(data); + impl::twiddles cs; + impl::fft1d_inplace(temp, false, cs); + return temp; + } + else + { + matrix temp; + impl::fft2d(data, temp, false); + return temp; + } + } + + template + matrix ifft (const matrix_exp& data) + { + // You have to give a complex matrix + COMPILE_TIME_ASSERT(is_complex::value); + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t matrix ifft(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + matrix temp; + if (data.size() == 0) + return temp; + + if (data.nr() == 1 || data.nc() == 1) + { + temp = data; + impl::twiddles cs; + impl::fft1d_inplace(temp, true, cs); + } + else + { + impl::fft2d(data, temp, true); + } + temp /= data.size(); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < typename T, long NR, long NC, typename MM, typename L > + typename enable_if_c::type fft_inplace (matrix,NR,NC,MM,L>& data) + // Note that we don't divide the outputs by data.size() so this isn't quite the inverse. + { + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t void fft_inplace(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + impl::twiddles cs; + impl::fft1d_inplace(data, false, cs); + } + + template < typename T, long NR, long NC, typename MM, typename L > + typename disable_if_c::type fft_inplace (matrix,NR,NC,MM,L>& data) + // Note that we don't divide the outputs by data.size() so this isn't quite the inverse. + { + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t void fft_inplace(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + impl::fft2d_inplace(data, false); + } + +// ---------------------------------------------------------------------------------------- + + template < typename T, long NR, long NC, typename MM, typename L > + typename enable_if_c::type ifft_inplace (matrix,NR,NC,MM,L>& data) + { + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t void ifft_inplace(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + impl::twiddles cs; + impl::fft1d_inplace(data, true, cs); + } + + template < typename T, long NR, long NC, typename MM, typename L > + typename disable_if_c::type ifft_inplace (matrix,NR,NC,MM,L>& data) + { + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t void ifft_inplace(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + impl::fft2d_inplace(data, true); + } + +// ---------------------------------------------------------------------------------------- + + /* + I'm disabling any use of the FFTW bindings because FFTW is, as of this writing, not + threadsafe as a library. This means that if multiple threads were to make + concurrent calls to these fft routines then the program could crash. If at some + point FFTW is fixed I'll turn these bindings back on. + + See https://github.com/FFTW/fftw3/issues/16 + */ +#if 0 +#ifdef DLIB_USE_FFTW + + template + matrix,NR,NC,MM,L> call_fftw_fft( + const matrix,NR,NC,MM,L>& data + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t matrix fft(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + if (data.size() == 0) + return data; + + matrix,NR,NC,MM,L> m2(data.nr(),data.nc()); + fftw_complex *in, *out; + fftw_plan p; + in = (fftw_complex*)&data(0,0); + out = (fftw_complex*)&m2(0,0); + if (data.nr() == 1 || data.nc() == 1) + p = fftw_plan_dft_1d(data.size(), in, out, FFTW_FORWARD, FFTW_ESTIMATE); + else + p = fftw_plan_dft_2d(data.nr(), data.nc(), in, out, FFTW_FORWARD, FFTW_ESTIMATE); + fftw_execute(p); + fftw_destroy_plan(p); + return m2; + } + + template + matrix,NR,NC,MM,L> call_fftw_ifft( + const matrix,NR,NC,MM,L>& data + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t matrix ifft(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + if (data.size() == 0) + return data; + + matrix,NR,NC,MM,L> m2(data.nr(),data.nc()); + fftw_complex *in, *out; + fftw_plan p; + in = (fftw_complex*)&data(0,0); + out = (fftw_complex*)&m2(0,0); + if (data.nr() == 1 || data.nc() == 1) + p = fftw_plan_dft_1d(data.size(), in, out, FFTW_BACKWARD, FFTW_ESTIMATE); + else + p = fftw_plan_dft_2d(data.nr(), data.nc(), in, out, FFTW_BACKWARD, FFTW_ESTIMATE); + fftw_execute(p); + fftw_destroy_plan(p); + return m2; + } + +// ---------------------------------------------------------------------------------------- + +// call FFTW for these cases: + inline matrix,0,1> fft (const matrix,0,1>& data) {return call_fftw_fft(data);} + inline matrix,0,1> ifft(const matrix,0,1>& data) {return call_fftw_ifft(data)/data.size();} + inline matrix,1,0> fft (const matrix,1,0>& data) {return call_fftw_fft(data);} + inline matrix,1,0> ifft(const matrix,1,0>& data) {return call_fftw_ifft(data)/data.size();} + inline matrix > fft (const matrix >& data) {return call_fftw_fft(data);} + inline matrix > ifft(const matrix >& data) {return call_fftw_ifft(data)/data.size();} + + inline void fft_inplace (matrix,0,1>& data) {data = call_fftw_fft(data);} + inline void ifft_inplace(matrix,0,1>& data) {data = call_fftw_ifft(data);} + inline void fft_inplace (matrix,1,0>& data) {data = call_fftw_fft(data);} + inline void ifft_inplace(matrix,1,0>& data) {data = call_fftw_ifft(data);} + inline void fft_inplace (matrix >& data) {data = call_fftw_fft(data);} + inline void ifft_inplace(matrix >& data) {data = call_fftw_ifft(data);} + +#endif // DLIB_USE_FFTW +#endif // end of #if 0 + +// ---------------------------------------------------------------------------------------- + +#ifdef DLIB_USE_MKL_FFT + +#define DLIB_DFTI_CHECK_STATUS(s) \ + if((s) != 0 && !DftiErrorClass((s), DFTI_NO_ERROR)) \ + { \ + throw dlib::error(DftiErrorMessage((s))); \ + } + + template < long NR, long NC, typename MM, typename L > + matrix,NR,NC,MM,L> call_mkl_fft( + const matrix,NR,NC,MM,L>& data, + bool do_backward_fft) + { + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t matrix fft(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + if (data.size() == 0) + return data; + + DFTI_DESCRIPTOR_HANDLE h; + MKL_LONG status; + + if (data.nr() == 1 || data.nc() == 1) + { + status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 1, data.size()); + DLIB_DFTI_CHECK_STATUS(status); + } + else + { + MKL_LONG size[2]; + size[0] = data.nr(); + size[1] = data.nc(); + + status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 2, size); + DLIB_DFTI_CHECK_STATUS(status); + + MKL_LONG strides[3]; + strides[0] = 0; + strides[1] = size[1]; + strides[2] = 1; + + status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides); + DLIB_DFTI_CHECK_STATUS(status); + status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides); + DLIB_DFTI_CHECK_STATUS(status); + } + + status = DftiSetValue(h, DFTI_PLACEMENT, DFTI_NOT_INPLACE); + DLIB_DFTI_CHECK_STATUS(status); + + // Unless we use sequential mode, the fft results are not correct. + status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1); + DLIB_DFTI_CHECK_STATUS(status); + + status = DftiCommitDescriptor(h); + DLIB_DFTI_CHECK_STATUS(status); + + matrix,NR,NC,MM,L> out(data.nr(), data.nc()); + + if (do_backward_fft) + status = DftiComputeBackward(h, (void *)(&data(0, 0)), &out(0,0)); + else + status = DftiComputeForward(h, (void *)(&data(0, 0)), &out(0,0)); + DLIB_DFTI_CHECK_STATUS(status); + + status = DftiFreeDescriptor(&h); + DLIB_DFTI_CHECK_STATUS(status); + + return out; + } + + template < long NR, long NC, typename MM, typename L > + void call_mkl_fft_inplace( + matrix,NR,NC,MM,L>& data, + bool do_backward_fft + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), + "\t void ifft_inplace(data)" + << "\n\t The number of rows and columns must be powers of two." + << "\n\t data.nr(): "<< data.nr() + << "\n\t data.nc(): "<< data.nc() + << "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) + << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) + ); + + if (data.size() == 0) + return; + + DFTI_DESCRIPTOR_HANDLE h; + MKL_LONG status; + + if (data.nr() == 1 || data.nc() == 1) + { + status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 1, data.size()); + DLIB_DFTI_CHECK_STATUS(status); + } + else + { + MKL_LONG size[2]; + size[0] = data.nr(); + size[1] = data.nc(); + + status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 2, size); + DLIB_DFTI_CHECK_STATUS(status); + + MKL_LONG strides[3]; + strides[0] = 0; + strides[1] = size[1]; + strides[2] = 1; + + status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides); + DLIB_DFTI_CHECK_STATUS(status); + } + + // Unless we use sequential mode, the fft results are not correct. + status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1); + DLIB_DFTI_CHECK_STATUS(status); + + status = DftiCommitDescriptor(h); + DLIB_DFTI_CHECK_STATUS(status); + + if (do_backward_fft) + status = DftiComputeBackward(h, &data(0, 0)); + else + status = DftiComputeForward(h, &data(0, 0)); + DLIB_DFTI_CHECK_STATUS(status); + + status = DftiFreeDescriptor(&h); + DLIB_DFTI_CHECK_STATUS(status); + + return; + } + +// ---------------------------------------------------------------------------------------- + + // Call the MKL DFTI implementation in these cases + + inline matrix,0,1> fft (const matrix,0,1>& data) + { + return call_mkl_fft(data, false); + } + inline matrix,0,1> ifft(const matrix,0,1>& data) + { + return call_mkl_fft(data, true) / data.size(); + } + inline matrix,1,0> fft (const matrix,1,0>& data) + { + return call_mkl_fft(data, false); + } + inline matrix,1,0> ifft(const matrix,1,0>& data) + { + return call_mkl_fft(data, true) / data.size(); + } + inline matrix > fft (const matrix >& data) + { + return call_mkl_fft(data, false); + } + inline matrix > ifft(const matrix >& data) + { + return call_mkl_fft(data, true) / data.size(); + } + + inline void fft_inplace (matrix,0,1>& data) + { + call_mkl_fft_inplace(data, false); + } + inline void ifft_inplace(matrix,0,1>& data) + { + call_mkl_fft_inplace(data, true); + } + inline void fft_inplace (matrix,1,0>& data) + { + call_mkl_fft_inplace(data, false); + } + inline void ifft_inplace(matrix,1,0>& data) + { + call_mkl_fft_inplace(data, true); + } + + inline void fft_inplace (matrix >& data) + { + call_mkl_fft_inplace(data, false); + } + inline void ifft_inplace(matrix >& data) + { + call_mkl_fft_inplace(data, true); + } + +#endif // DLIB_USE_MKL_FFT + +// ---------------------------------------------------------------------------------------- +} + +#endif // DLIB_FFt_Hh_ + diff --git a/ml/dlib/dlib/matrix/matrix_fft_abstract.h b/ml/dlib/dlib/matrix/matrix_fft_abstract.h new file mode 100644 index 000000000..25cdfcaee --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_fft_abstract.h @@ -0,0 +1,118 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FFt_ABSTRACT_Hh_ +#ifdef DLIB_FFt_ABSTRACT_Hh_ + +#include "matrix_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + bool is_power_of_two ( + const unsigned long& value + ); + /*! + ensures + - returns true if value contains a power of two and false otherwise. As a + special case, we also consider 0 to be a power of two. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename EXP::matrix_type fft ( + const matrix_exp& data + ); + /*! + requires + - data contains elements of type std::complex<> that itself contains double, float, or long double. + - is_power_of_two(data.nr()) == true + - is_power_of_two(data.nc()) == true + ensures + - Computes the 1 or 2 dimensional discrete Fourier transform of the given data + matrix and returns it. In particular, we return a matrix D such that: + - D.nr() == data.nr() + - D.nc() == data.nc() + - D(0,0) == the DC term of the Fourier transform. + - starting with D(0,0), D contains progressively higher frequency components + of the input data. + - ifft(D) == D + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename EXP::matrix_type ifft ( + const matrix_exp& data + ); + /*! + requires + - data contains elements of type std::complex<> that itself contains double, float, or long double. + - is_power_of_two(data.nr()) == true + - is_power_of_two(data.nc()) == true + ensures + - Computes the 1 or 2 dimensional inverse discrete Fourier transform of the + given data vector and returns it. In particular, we return a matrix D such + that: + - D.nr() == data.nr() + - D.nc() == data.nc() + - fft(D) == data + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename MM, + typename L + > + void fft_inplace ( + matrix,NR,NC,MM,L>& data + ); + /*! + requires + - data contains elements of type std::complex<> that itself contains double, float, or long double. + - is_power_of_two(data.nr()) == true + - is_power_of_two(data.nc()) == true + ensures + - This function is identical to fft() except that it does the FFT in-place. + That is, after this function executes we will have: + - #data == fft(data) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename MM, + typename L + > + void ifft_inplace ( + matrix,NR,NC,MM,L>& data + ); + /*! + requires + - data contains elements of type std::complex<> that itself contains double, float, or long double. + - is_power_of_two(data.nr()) == true + - is_power_of_two(data.nc()) == true + ensures + - This function is identical to ifft() except that it does the inverse FFT + in-place. That is, after this function executes we will have: + - #data == ifft(data)*data.size() + - Note that the output needs to be divided by data.size() to complete the + inverse transformation. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FFt_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/matrix/matrix_fwd.h b/ml/dlib/dlib/matrix/matrix_fwd.h new file mode 100644 index 000000000..1f40a17a8 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_fwd.h @@ -0,0 +1,31 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_FWD +#define DLIB_MATRIx_FWD + +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct row_major_layout; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long num_rows = 0, + long num_cols = 0, + typename mem_manager = default_memory_manager, + typename layout = row_major_layout + > + class matrix; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_FWD + diff --git a/ml/dlib/dlib/matrix/matrix_generic_image.h b/ml/dlib/dlib/matrix/matrix_generic_image.h new file mode 100644 index 000000000..0455af205 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_generic_image.h @@ -0,0 +1,110 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIX_GENERIC_iMAGE_Hh_ +#define DLIB_MATRIX_GENERIC_iMAGE_Hh_ + +#include "matrix.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + template < + typename T, + long NR, + long NC, + typename MM + > + struct image_traits > + { + typedef T pixel_type; + }; + + template < + typename T, + long NR, + long NC, + typename MM + > + struct image_traits > + { + typedef T pixel_type; + }; + + template < + typename T, + long NR, + long NC, + typename MM + > + inline long num_rows( const matrix& img) { return img.nr(); } + + template < + typename T, + long NR, + long NC, + typename MM + > + inline long num_columns( const matrix& img) { return img.nc(); } + + template < + typename T, + long NR, + long NC, + typename MM + > + inline void set_image_size( + matrix& img, + long rows, + long cols + ) { img.set_size(rows,cols); } + + template < + typename T, + long NR, + long NC, + typename MM + > + inline void* image_data( + matrix& img + ) + { + if (img.size() != 0) + return &img(0,0); + else + return 0; + } + + template < + typename T, + long NR, + long NC, + typename MM + > + inline const void* image_data( + const matrix& img + ) + { + if (img.size() != 0) + return &img(0,0); + else + return 0; + } + + template < + typename T, + long NR, + long NC, + typename MM + > + inline long width_step( + const matrix& img + ) + { + return img.nc()*sizeof(T); + } + +} + +#endif // DLIB_MATRIX_GENERIC_iMAGE_Hh_ + + diff --git a/ml/dlib/dlib/matrix/matrix_la.h b/ml/dlib/dlib/matrix/matrix_la.h new file mode 100644 index 000000000..35b5b42e2 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_la.h @@ -0,0 +1,1807 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_LA_FUNCTS_ +#define DLIB_MATRIx_LA_FUNCTS_ + +#include "matrix_la_abstract.h" +#include "matrix_utilities.h" +#include "../sparse_vector.h" +#include "../optimization/optimization_line_search.h" + +// The 4 decomposition objects described in the matrix_la_abstract.h file are +// actually implemented in the following 4 files. +#include "matrix_lu.h" +#include "matrix_qr.h" +#include "matrix_cholesky.h" +#include "matrix_eigenvalue.h" + +#ifdef DLIB_USE_LAPACK +#include "lapack/potrf.h" +#include "lapack/pbtrf.h" +#include "lapack/gesdd.h" +#include "lapack/gesvd.h" +#endif + +#include "../threads.h" + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + enum svd_u_mode + { + SVD_NO_U, + SVD_SKINNY_U, + SVD_FULL_U + }; + + template < + typename EXP, + long qN, long qX, + long uM, long uN, + long vM, long vN, + typename MM1, + typename MM2, + typename MM3, + typename L1 + > + long svd4 ( + svd_u_mode u_mode, + bool withv, + const matrix_exp& a, + matrix& u, + matrix& q, + matrix& v + ) + { + /* + Singular value decomposition. Translated to 'C' from the + original Algol code in "Handbook for Automatic Computation, + vol. II, Linear Algebra", Springer-Verlag. Note that this + published algorithm is considered to be the best and numerically + stable approach to computing the real-valued svd and is referenced + repeatedly in ieee journal papers, etc where the svd is used. + + This is almost an exact translation from the original, except that + an iteration counter is added to prevent stalls. This corresponds + to similar changes in other translations. + + Returns an error code = 0, if no errors and 'k' if a failure to + converge at the 'kth' singular value. + + USAGE: given the singular value decomposition a = u * diagm(q) * trans(v) for an m*n + matrix a with m >= n ... + After the svd call u is an m x m matrix which is columnwise + orthogonal. q will be an n element vector consisting of singular values + and v an n x n orthogonal matrix. eps and tol are tolerance constants. + Suitable values are eps=1e-16 and tol=(1e-300)/eps if T == double. + + If u_mode == SVD_NO_U then u won't be computed and similarly if withv == false + then v won't be computed. If u_mode == SVD_SKINNY_U then u will be m x n instead of m x m. + */ + + + DLIB_ASSERT(a.nr() >= a.nc(), + "\tconst matrix_exp svd4()" + << "\n\tYou have given an invalidly sized matrix" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + ); + + + typedef typename EXP::type T; + +#ifdef DLIB_USE_LAPACK + matrix temp(a), vtemp; + + char jobu = 'A'; + char jobvt = 'A'; + if (u_mode == SVD_NO_U) + jobu = 'N'; + else if (u_mode == SVD_SKINNY_U) + jobu = 'S'; + if (withv == false) + jobvt = 'N'; + + int info; + if (jobu == jobvt) + { + info = lapack::gesdd(jobu, temp, q, u, vtemp); + } + else + { + info = lapack::gesvd(jobu, jobvt, temp, q, u, vtemp); + } + + // pad q with zeros if it isn't the length we want + if (q.nr() < a.nc()) + q = join_cols(q, zeros_matrix(a.nc()-q.nr(),1)); + + if (withv) + v = trans(vtemp); + + return info; +#else + using std::abs; + using std::sqrt; + + T eps = std::numeric_limits::epsilon(); + T tol = std::numeric_limits::min()/eps; + + const long m = a.nr(); + const long n = a.nc(); + long i, j, k, l = 0, l1, iter, retval; + T c, f, g, h, s, x, y, z; + + matrix e(n,1); + q.set_size(n,1); + if (u_mode == SVD_FULL_U) + u.set_size(m,m); + else + u.set_size(m,n); + retval = 0; + + if (withv) + { + v.set_size(n,n); + } + + /* Copy 'a' to 'u' */ + for (i=0; i x) + x = y; + } /* end i */ + + /* accumulation of right-hand transformations */ + if (withv) + { + for (i=n-1; i>=0; i--) + { + if (g != 0.0) + { + h = u(i,i+1) * g; + + for (j=l; j=0; i--) + { + l = i + 1; + g = q(i); + + for (j=l; j=0; k--) + { + iter = 0; + +test_f_splitting: + + for (l=k; l>=0; l--) + { + if (abs(e(l)) <= eps) + goto test_f_convergence; + + if (abs(q(l-1)) <= eps) + goto cancellation; + } /* end l */ + + /* cancellation of e(l) if l > 0 */ + +cancellation: + + c = 0.0; + s = 1.0; + l1 = l - 1; + + for (i=l; i<=k; i++) + { + f = s * e(i); + e(i) *= c; + + if (abs(f) <= eps) + goto test_f_convergence; + + g = q(i); + h = q(i) = sqrt(f*f + g*g); + c = g / h; + s = -f / h; + + if (u_mode != SVD_NO_U) + { + for (j=0; j 300) + { + retval = k; + break; + } + x = q(l); + y = q(k-1); + g = e(k-1); + h = e(k); + f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2 * h * y); + g = sqrt(f * f + 1.0); + f = ((x - z) * (x + z) + h * (y / ((f < 0)?(f - g) : (f + g)) - h)) / x; + + /* next QR transformation */ + c = s = 1.0; + + for (i=l+1; i<=k; i++) + { + g = e(i); + y = q(i); + h = s * g; + g *= c; + e(i-1) = z = sqrt(f * f + h * h); + c = f / z; + s = h / z; + f = x * c + g * s; + g = -x * s + g * c; + h = y * s; + y *= c; + + if (withv) + { + for (j=0;j + long svd2 ( + bool withu, + bool withv, + const matrix_exp& a, + matrix& u, + matrix& q, + matrix& v + ) + { + const long NR = matrix_exp::NR; + const long NC = matrix_exp::NC; + + // make sure the output matrices have valid dimensions if they are statically dimensioned + COMPILE_TIME_ASSERT(qX == 0 || qX == 1); + COMPILE_TIME_ASSERT(NR == 0 || uM == 0 || NR == uM); + COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN); + + DLIB_ASSERT(a.nr() >= a.nc(), + "\tconst matrix_exp svd4()" + << "\n\tYou have given an invalidly sized matrix" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + ); + + if (withu) + return svd4(SVD_FULL_U, withv, a,u,q,v); + else + return svd4(SVD_NO_U, withv, a,u,q,v); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename MM + > + void orthogonalize ( + matrix& m + ) + { + // We don't really need to use this temporary, but doing it this way runs a lot + // faster. + matrix temp; + qr_decomposition>(m).get_q(temp); + m = temp; + } + + template < + typename T, + long NR, + long NC, + typename MM + > + void orthogonalize ( + matrix& m + ) + { + qr_decomposition>(m).get_q(m); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long Anr, long Anc, + typename MM, + typename L + > + void find_matrix_range ( + const matrix& A, + unsigned long l, + matrix& Q, + unsigned long q + ) + /*! + requires + - A.nr() >= l + ensures + - #Q.nr() == A.nr() + - #Q.nc() == l + - #Q == an orthonormal matrix whose range approximates the range of the + matrix A. + - This function implements the randomized subspace iteration defined + in the algorithm 4.4 box of the paper: + Finding Structure with Randomness: Probabilistic Algorithms for + Constructing Approximate Matrix Decompositions by Halko et al. + - q defines the number of extra subspace iterations this algorithm will + perform. Often q == 0 is fine, but performing more iterations can lead to a + more accurate approximation of the range of A if A has slowly decaying + singular values. In these cases, using a q of 1 or 2 is good. + !*/ + { + DLIB_ASSERT(A.nr() >= (long)l, "Invalid inputs were given to this function."); + Q = A*matrix_cast(gaussian_randm(A.nc(), l)); + + orthogonalize(Q); + + // Do some extra iterations of the power method to make sure we get Q into the + // span of the most important singular vectors of A. + if (q != 0) + { + for (unsigned long itr = 0; itr < q; ++itr) + { + Q = trans(A)*Q; + orthogonalize(Q); + + Q = A*Q; + orthogonalize(Q); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long Anr, long Anc, + long Unr, long Unc, + long Wnr, long Wnc, + long Vnr, long Vnc, + typename MM, + typename L + > + void svd_fast ( + const matrix& A, + matrix& u, + matrix& w, + matrix& v, + unsigned long l, + unsigned long q = 1 + ) + { + const unsigned long k = std::min(l, std::min(A.nr(),A.nc())); + + DLIB_ASSERT(l > 0 && A.size() > 0, + "\t void svd_fast()" + << "\n\t Invalid inputs were given to this function." + << "\n\t l: " << l + << "\n\t A.size(): " << A.size() + ); + + matrix Q; + find_matrix_range(A, k, Q, q); + + // Compute trans(B) = trans(Q)*A. The reason we store B transposed + // is so that when we take its SVD later using svd3() it doesn't consume + // a whole lot of RAM. That is, we make sure the square matrix coming out + // of svd3() has size lxl rather than the potentially much larger nxn. + matrix B = trans(A)*Q; + svd3(B, v,w,u); + u = Q*u; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sparse_vector_type, + typename T, + typename MM, + typename L + > + void find_matrix_range ( + const std::vector& A, + unsigned long l, + matrix& Q, + unsigned long q + ) + /*! + requires + - A.size() >= l + ensures + - #Q.nr() == A.size() + - #Q.nc() == l + - #Q == an orthonormal matrix whose range approximates the range of the + matrix A. In this case, we interpret A as a matrix of A.size() rows, + where each row is defined by a sparse vector. + - This function implements the randomized subspace iteration defined + in the algorithm 4.4 box of the paper: + Finding Structure with Randomness: Probabilistic Algorithms for + Constructing Approximate Matrix Decompositions by Halko et al. + - q defines the number of extra subspace iterations this algorithm will + perform. Often q == 0 is fine, but performing more iterations can lead to a + more accurate approximation of the range of A if A has slowly decaying + singular values. In these cases, using a q of 1 or 2 is good. + !*/ + { + DLIB_ASSERT(A.size() >= l, "Invalid inputs were given to this function."); + Q.set_size(A.size(), l); + + // Compute Q = A*gaussian_randm() + parallel_for(0, Q.nr(), [&](long r) + { + for (long c = 0; c < Q.nc(); ++c) + { + Q(r,c) = dot(A[r], gaussian_randm(std::numeric_limits::max(), 1, c)); + } + }); + + orthogonalize(Q); + + // Do some extra iterations of the power method to make sure we get Q into the + // span of the most important singular vectors of A. + if (q != 0) + { + dlib::mutex mut; + const unsigned long n = max_index_plus_one(A); + for (unsigned long itr = 0; itr < q; ++itr) + { + matrix Z; + // Compute Z = trans(A)*Q + parallel_for_blocked(0, A.size(), [&](long begin, long end) + { + matrix Zlocal(n,l); + Zlocal = 0; + for (long m = begin; m < end; ++m) + { + for (unsigned long r = 0; r < l; ++r) + { + for (auto& i : A[m]) + { + const auto c = i.first; + const auto val = i.second; + + Zlocal(c,r) += Q(m,r)*val; + } + } + } + auto_mutex lock(mut); + Z += Zlocal; + },1); + + Q.set_size(0,0); // free RAM + orthogonalize(Z); + + // Compute Q = A*Z + Q.set_size(A.size(), l); + parallel_for(0, Q.nr(), [&](long r) + { + for (long c = 0; c < Q.nc(); ++c) + { + Q(r,c) = dot(A[r], colm(Z,c)); + } + }); + + Z.set_size(0,0); // free RAM + orthogonalize(Q); + } + } + } + +// ---------------------------------------------------------------------------------------- + + namespace simpl + { + template < + typename sparse_vector_type, + typename T, + long Unr, long Unc, + long Wnr, long Wnc, + long Vnr, long Vnc, + typename MM, + typename L + > + void svd_fast ( + bool compute_u, + const std::vector& A, + matrix& u, + matrix& w, + matrix& v, + unsigned long l, + unsigned long q + ) + { + const long n = max_index_plus_one(A); + const unsigned long k = std::min(l, std::min(A.size(),n)); + + DLIB_ASSERT(l > 0 && A.size() > 0 && n > 0, + "\t void svd_fast()" + << "\n\t Invalid inputs were given to this function." + << "\n\t l: " << l + << "\n\t n (i.e. max_index_plus_one(A)): " << n + << "\n\t A.size(): " << A.size() + ); + + matrix Q; + find_matrix_range(A, k, Q, q); + + // Compute trans(B) = trans(Q)*A. The reason we store B transposed + // is so that when we take its SVD later using svd3() it doesn't consume + // a whole lot of RAM. That is, we make sure the square matrix coming out + // of svd3() has size lxl rather than the potentially much larger nxn. + matrix B; + dlib::mutex mut; + parallel_for_blocked(0, A.size(), [&](long begin, long end) + { + matrix Blocal(n,k); + Blocal = 0; + for (long m = begin; m < end; ++m) + { + for (unsigned long r = 0; r < k; ++r) + { + for (auto& i : A[m]) + { + const auto c = i.first; + const auto val = i.second; + + Blocal(c,r) += Q(m,r)*val; + } + } + } + auto_mutex lock(mut); + B += Blocal; + },1); + + svd3(B, v,w,u); + if (compute_u) + u = Q*u; + } + } + + template < + typename sparse_vector_type, + typename T, + long Unr, long Unc, + long Wnr, long Wnc, + long Vnr, long Vnc, + typename MM, + typename L + > + void svd_fast ( + const std::vector& A, + matrix& u, + matrix& w, + matrix& v, + unsigned long l, + unsigned long q = 1 + ) + { + simpl::svd_fast(true, A,u,w,v,l,q); + } + + template < + typename sparse_vector_type, + typename T, + long Wnr, long Wnc, + long Vnr, long Vnc, + typename MM, + typename L + > + void svd_fast ( + const std::vector& A, + matrix& w, + matrix& v, + unsigned long l, + unsigned long q = 1 + ) + { + matrix u; + simpl::svd_fast(false, A,u,w,v,l,q); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + long N + > + struct inv_helper + { + static const typename matrix_exp::matrix_type inv ( + const matrix_exp& m + ) + { + // you can't invert a non-square matrix + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC || + matrix_exp::NR == 0 || + matrix_exp::NC == 0); + DLIB_ASSERT(m.nr() == m.nc(), + "\tconst matrix_exp::type inv(const matrix_exp& m)" + << "\n\tYou can only apply inv() to a square matrix" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + typedef typename matrix_exp::type type; + + lu_decomposition lu(m); + return lu.solve(identity_matrix(m.nr())); + } + }; + + template < + typename EXP + > + struct inv_helper + { + static const typename matrix_exp::matrix_type inv ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC); + typedef typename matrix_exp::type type; + + matrix a; + // if m is invertible + if (m(0) != 0) + a(0) = 1/m(0); + else + a(0) = 1; + return a; + } + }; + + template < + typename EXP + > + struct inv_helper + { + static const typename matrix_exp::matrix_type inv ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC); + typedef typename matrix_exp::type type; + + matrix a; + type d = det(m); + if (d != 0) + { + d = static_cast(1.0/d); + a(0,0) = m(1,1)*d; + a(0,1) = m(0,1)*-d; + a(1,0) = m(1,0)*-d; + a(1,1) = m(0,0)*d; + } + else + { + // Matrix isn't invertible so just return the identity matrix. + a = identity_matrix(); + } + return a; + } + }; + + template < + typename EXP + > + struct inv_helper + { + static const typename matrix_exp::matrix_type inv ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC); + typedef typename matrix_exp::type type; + + matrix ret; + type de = det(m); + if (de != 0) + { + de = static_cast(1.0/de); + const type a = m(0,0); + const type b = m(0,1); + const type c = m(0,2); + const type d = m(1,0); + const type e = m(1,1); + const type f = m(1,2); + const type g = m(2,0); + const type h = m(2,1); + const type i = m(2,2); + + ret(0,0) = (e*i - f*h)*de; + ret(1,0) = (f*g - d*i)*de; + ret(2,0) = (d*h - e*g)*de; + + ret(0,1) = (c*h - b*i)*de; + ret(1,1) = (a*i - c*g)*de; + ret(2,1) = (b*g - a*h)*de; + + ret(0,2) = (b*f - c*e)*de; + ret(1,2) = (c*d - a*f)*de; + ret(2,2) = (a*e - b*d)*de; + } + else + { + ret = identity_matrix(); + } + + return ret; + } + }; + + template < + typename EXP + > + struct inv_helper + { + static const typename matrix_exp::matrix_type inv ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC); + typedef typename matrix_exp::type type; + + matrix ret; + type de = det(m); + if (de != 0) + { + de = static_cast(1.0/de); + ret(0,0) = det(removerc<0,0>(m)); + ret(0,1) = -det(removerc<0,1>(m)); + ret(0,2) = det(removerc<0,2>(m)); + ret(0,3) = -det(removerc<0,3>(m)); + + ret(1,0) = -det(removerc<1,0>(m)); + ret(1,1) = det(removerc<1,1>(m)); + ret(1,2) = -det(removerc<1,2>(m)); + ret(1,3) = det(removerc<1,3>(m)); + + ret(2,0) = det(removerc<2,0>(m)); + ret(2,1) = -det(removerc<2,1>(m)); + ret(2,2) = det(removerc<2,2>(m)); + ret(2,3) = -det(removerc<2,3>(m)); + + ret(3,0) = -det(removerc<3,0>(m)); + ret(3,1) = det(removerc<3,1>(m)); + ret(3,2) = -det(removerc<3,2>(m)); + ret(3,3) = det(removerc<3,3>(m)); + + return trans(ret)*de; + } + else + { + return identity_matrix(); + } + } + }; + + template < + typename EXP + > + inline const typename matrix_exp::matrix_type inv ( + const matrix_exp& m + ) { return inv_helper::NR>::inv(m); } + +// ---------------------------------------------------------------------------------------- + + template + struct op_diag_inv + { + template + op_diag_inv( const matrix_exp& m_) : m(m_){} + + + const static long cost = 1; + const static long NR = ((M::NC!=0)&&(M::NR!=0))? (tmax::value) : (0); + const static long NC = NR; + typedef typename M::type type; + typedef const type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + + // hold the matrix by value + const matrix m; + + const_ret_type apply ( long r, long c) const + { + if (r==c) + return m(r); + else + return 0; + } + + long nr () const { return m.size(); } + long nc () const { return m.size(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_diag_op > inv ( + const matrix_diag_exp& m + ) + { + typedef op_diag_inv op; + return matrix_diag_op(op(reciprocal(diag(m)))); + } + + template < + typename EXP + > + const matrix_diag_op > pinv ( + const matrix_diag_exp& m + ) + { + typedef op_diag_inv op; + return matrix_diag_op(op(reciprocal(diag(m)))); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix_diag_op > pinv ( + const matrix_diag_exp& m, + double tol + ) + { + DLIB_ASSERT(tol >= 0, + "\tconst matrix_exp::type pinv(const matrix_exp& m)" + << "\n\t tol can't be negative" + << "\n\t tol: "< op; + return matrix_diag_op(op(reciprocal(round_zeros(diag(m),tol)))); + } + +// ---------------------------------------------------------------------------------------- + + template + const typename matrix_exp::matrix_type inv_lower_triangular ( + const matrix_exp& A + ) + { + DLIB_ASSERT(A.nr() == A.nc(), + "\tconst matrix inv_lower_triangular(const matrix_exp& A)" + << "\n\tA must be a square matrix" + << "\n\tA.nr(): " << A.nr() + << "\n\tA.nc(): " << A.nc() + ); + + typedef typename matrix_exp::matrix_type matrix_type; + + matrix_type m(A); + + for(long c = 0; c < m.nc(); ++c) + { + if( m(c,c) == 0 ) + { + // there isn't an inverse so just give up + return m; + } + + // compute m(c,c) + m(c,c) = 1/m(c,c); + + // compute the values in column c that are below m(c,c). + // We do this by just doing the same thing we do for upper triangular + // matrices because we take the transpose of m which turns m into an + // upper triangular matrix. + for(long r = 0; r < c; ++r) + { + const long n = c-r; + m(c,r) = -m(c,c)*subm(trans(m),r,r,1,n)*subm(trans(m),r,c,n,1); + } + } + + return m; + + } + +// ---------------------------------------------------------------------------------------- + + template + const typename matrix_exp::matrix_type inv_upper_triangular ( + const matrix_exp& A + ) + { + DLIB_ASSERT(A.nr() == A.nc(), + "\tconst matrix inv_upper_triangular(const matrix_exp& A)" + << "\n\tA must be a square matrix" + << "\n\tA.nr(): " << A.nr() + << "\n\tA.nc(): " << A.nc() + ); + + typedef typename matrix_exp::matrix_type matrix_type; + + matrix_type m(A); + + for(long c = 0; c < m.nc(); ++c) + { + if( m(c,c) == 0 ) + { + // there isn't an inverse so just give up + return m; + } + + // compute m(c,c) + m(c,c) = 1/m(c,c); + + // compute the values in column c that are above m(c,c) + for(long r = 0; r < c; ++r) + { + const long n = c-r; + m(r,c) = -m(c,c)*subm(m,r,r,1,n)*subm(m,r,c,n,1); + } + } + + return m; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + inline const typename matrix_exp::matrix_type chol ( + const matrix_exp& A + ) + { + DLIB_ASSERT(A.nr() == A.nc(), + "\tconst matrix chol(const matrix_exp& A)" + << "\n\tYou can only apply the chol to a square matrix" + << "\n\tA.nr(): " << A.nr() + << "\n\tA.nc(): " << A.nc() + ); + typename matrix_exp::matrix_type L(A.nr(),A.nc()); + + typedef typename EXP::type T; + + bool banded = false; + long bandwidth = 0; + + if (A.nr() > 4) // Only test for banded matrix if matrix is big enough + { + // Detect if matrix is banded and, if so, matrix bandwidth + banded = true; + for (long r = 0; r < A.nr(); ++r) + for (long c = (r + bandwidth + 1); c < A.nc(); ++c) + if (A(r, c) != 0) + { + bandwidth = c - r; + if (bandwidth > A.nr() / 2) + { + banded = false; + goto escape_banded_detection; + } + } + } +escape_banded_detection: + + if (banded) + { + // Store in compact form - use column major for LAPACK + matrix B(bandwidth + 1, A.nc()); + set_all_elements(B, 0); + + for (long r = 0; r < A.nr(); ++r) + for (long c = r; c < std::min(r + bandwidth + 1, A.nc()); ++c) + B(c - r, r) = A(r, c); + +#ifdef DLIB_USE_LAPACK + + lapack::pbtrf('L', B); + +#else + + // Peform compact Cholesky + for (long k = 0; k < A.nr(); ++k) + { + long last = std::min(k + bandwidth, A.nr() - 1) - k; + for (long j = 1; j <= last; ++j) + { + long i = k + j; + for (long c = 0; c <= (last - j); ++c) + B(c, i) -= B(j, k) / B(0, k) * B(c + j, k); + } + T norm = std::sqrt(B(0, k)); + for (long i = 0; i <= bandwidth; ++i) + B(i, k) /= norm; + } + for (long c = A.nc() - bandwidth + 1; c < A.nc(); ++c) + B(bandwidth, c) = 0; + +#endif + + // Unpack lower triangular area + set_all_elements(L, 0); + for (long c = 0; c < A.nc(); ++c) + for (long i = 0; i <= bandwidth; ++i) + { + long ind = c + i; + if (ind < A.nc()) + L(ind, c) = B(i, c); + } + + return L; + } + +#ifdef DLIB_USE_LAPACK + // Only call LAPACK if the matrix is big enough. Otherwise, + // our own code is faster, especially for statically dimensioned + // matrices. + if (A.nr() > 4) + { + L = A; + lapack::potrf('L', L); + // mask out upper triangular area + return lowerm(L); + } +#endif + set_all_elements(L,0); + + // do nothing if the matrix is empty + if (A.size() == 0) + return L; + + const T eps = std::numeric_limits::epsilon(); + + // compute the upper left corner + if (A(0,0) > 0) + L(0,0) = std::sqrt(A(0,0)); + + // compute the first column + for (long r = 1; r < A.nr(); ++r) + { + // if (L(0,0) > 0) + if (L(0,0) > eps*std::abs(A(r,0))) + L(r,0) = A(r,0)/L(0,0); + else + return L; + } + + // now compute all the other columns + for (long c = 1; c < A.nc(); ++c) + { + // compute the diagonal element + T temp = A(c,c); + for (long i = 0; i < c; ++i) + { + temp -= L(c,i)*L(c,i); + } + if (temp > 0) + L(c,c) = std::sqrt(temp); + + // compute the non diagonal elements + for (long r = c+1; r < A.nr(); ++r) + { + temp = A(r,c); + for (long i = 0; i < c; ++i) + { + temp -= L(r,i)*L(c,i); + } + + // if (L(c,c) > 0) + if (L(c,c) > eps*std::abs(temp)) + L(r,c) = temp/L(c,c); + else + return L; + } + } + + return L; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + long uNR, + long uNC, + long wN, + long vN, + long wX, + typename MM1, + typename MM2, + typename MM3, + typename L1 + > + inline void svd3 ( + const matrix_exp& m, + matrix::type, uNR, uNC,MM1,L1>& u, + matrix::type, wN, wX,MM2,L1>& w, + matrix::type, vN, vN,MM3,L1>& v + ) + { + typedef typename matrix_exp::type T; + const long NR = matrix_exp::NR; + const long NC = matrix_exp::NC; + + // make sure the output matrices have valid dimensions if they are statically dimensioned + COMPILE_TIME_ASSERT(NR == 0 || uNR == 0 || NR == uNR); + COMPILE_TIME_ASSERT(NC == 0 || uNC == 0 || NC == uNC); + COMPILE_TIME_ASSERT(NC == 0 || wN == 0 || NC == wN); + COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN); + COMPILE_TIME_ASSERT(wX == 0 || wX == 1); + +#ifdef DLIB_USE_LAPACK + // use LAPACK but only if it isn't a really small matrix we are taking the SVD of. + if (NR*NC == 0 || NR*NC > 3*3) + { + matrix::type, uNR, uNC,MM1,L1> temp(m); + lapack::gesvd('S','A', temp, w, u, v); + v = trans(v); + // if u isn't the size we want then pad it (and v) with zeros + if (u.nc() < m.nc()) + { + w = join_cols(w, zeros_matrix(m.nc()-u.nc(),1)); + u = join_rows(u, zeros_matrix(u.nr(), m.nc()-u.nc())); + } + return; + } +#endif + if (m.nr() >= m.nc()) + { + svd4(SVD_SKINNY_U,true, m, u,w,v); + } + else + { + svd4(SVD_FULL_U,true, trans(m), v,w,u); + + // if u isn't the size we want then pad it (and v) with zeros + if (u.nc() < m.nc()) + { + w = join_cols(w, zeros_matrix(m.nc()-u.nc(),1)); + u = join_rows(u, zeros_matrix(u.nr(), m.nc()-u.nc())); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix pinv_helper ( + const matrix_exp& m, + double tol + ) + /*! + ensures + - computes the results of pinv(m) but does so using a method that is fastest + when m.nc() <= m.nr(). So if m.nc() > m.nr() then it is best to use + trans(pinv_helper(trans(m))) to compute pinv(m). + !*/ + { + typename matrix_exp::matrix_type u; + typedef typename EXP::mem_manager_type MM1; + typedef typename EXP::layout_type layout_type; + matrix v; + + typedef typename matrix_exp::type T; + + matrix::NC,1,MM1, layout_type> w; + + svd3(m, u,w,v); + + const double machine_eps = std::numeric_limits::epsilon(); + // compute a reasonable epsilon below which we round to zero before doing the + // reciprocal. Unless a non-zero tol is given then we just use tol*max(w). + const double eps = (tol!=0) ? tol*max(w) : machine_eps*std::max(m.nr(),m.nc())*max(w); + + // now compute the pseudoinverse + return tmp(scale_columns(v,reciprocal(round_zeros(w,eps))))*trans(u); + } + + template < + typename EXP + > + const matrix pinv ( + const matrix_exp& m, + double tol = 0 + ) + { + DLIB_ASSERT(tol >= 0, + "\tconst matrix_exp::type pinv(const matrix_exp& m)" + << "\n\t tol can't be negative" + << "\n\t tol: "< m.nr()) + return trans(pinv_helper(trans(m),tol)); + else + return pinv_helper(m,tol); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + long uNR, + long uNC, + long wN, + long vN, + typename MM1, + typename MM2, + typename MM3, + typename L1 + > + inline void svd ( + const matrix_exp& m, + matrix::type, uNR, uNC,MM1,L1>& u, + matrix::type, wN, wN,MM2,L1>& w, + matrix::type, vN, vN,MM3,L1>& v + ) + { + typedef typename matrix_exp::type T; + const long NR = matrix_exp::NR; + const long NC = matrix_exp::NC; + + // make sure the output matrices have valid dimensions if they are statically dimensioned + COMPILE_TIME_ASSERT(NR == 0 || uNR == 0 || NR == uNR); + COMPILE_TIME_ASSERT(NC == 0 || uNC == 0 || NC == uNC); + COMPILE_TIME_ASSERT(NC == 0 || wN == 0 || NC == wN); + COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN); + + matrix::NC,1,MM1, L1> W; + svd3(m,u,W,v); + w = diagm(W); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const typename matrix_exp::type trace ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC || + matrix_exp::NR == 0 || + matrix_exp::NC == 0 + ); + DLIB_ASSERT(m.nr() == m.nc(), + "\tconst matrix_exp::type trace(const matrix_exp& m)" + << "\n\tYou can only apply trace() to a square matrix" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + return sum(diag(m)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + long N = EXP::NR + > + struct det_helper + { + static const typename matrix_exp::type det ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC || + matrix_exp::NR == 0 || + matrix_exp::NC == 0 + ); + DLIB_ASSERT(m.nr() == m.nc(), + "\tconst matrix_exp::type det(const matrix_exp& m)" + << "\n\tYou can only apply det() to a square matrix" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + + return lu_decomposition(m).det(); + } + }; + + template < + typename EXP + > + struct det_helper + { + static const typename matrix_exp::type det ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC); + + return m(0); + } + }; + + template < + typename EXP + > + struct det_helper + { + static const typename matrix_exp::type det ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC); + + return m(0,0)*m(1,1) - m(0,1)*m(1,0); + } + }; + + template < + typename EXP + > + struct det_helper + { + static const typename matrix_exp::type det ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC); + typedef typename matrix_exp::type type; + + type temp = m(0,0)*(m(1,1)*m(2,2) - m(1,2)*m(2,1)) - + m(0,1)*(m(1,0)*m(2,2) - m(1,2)*m(2,0)) + + m(0,2)*(m(1,0)*m(2,1) - m(1,1)*m(2,0)); + return temp; + } + }; + + + template < + typename EXP + > + inline const typename matrix_exp::type det ( + const matrix_exp& m + ) { return det_helper::det(m); } + + + template < + typename EXP + > + struct det_helper + { + static const typename matrix_exp::type det ( + const matrix_exp& m + ) + { + COMPILE_TIME_ASSERT(matrix_exp::NR == matrix_exp::NC); + typedef typename matrix_exp::type type; + + type temp = m(0,0)*(dlib::det(removerc<0,0>(m))) - + m(0,1)*(dlib::det(removerc<0,1>(m))) + + m(0,2)*(dlib::det(removerc<0,2>(m))) - + m(0,3)*(dlib::det(removerc<0,3>(m))); + return temp; + } + }; + +// ---------------------------------------------------------------------------------------- + + template + const matrix real_eigenvalues ( + const matrix_exp& m + ) + { + // You can only use this function with matrices that contain float or double values + COMPILE_TIME_ASSERT((is_same_type::value || + is_same_type::value)); + + DLIB_ASSERT(m.nr() == m.nc(), + "\tconst matrix real_eigenvalues()" + << "\n\tYou have given an invalidly sized matrix" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + + if (m.nr() == 2) + { + typedef typename EXP::type T; + const T m00 = m(0,0); + const T m01 = m(0,1); + const T m10 = m(1,0); + const T m11 = m(1,1); + + const T b = -(m00 + m11); + const T c = m00*m11 - m01*m10; + matrix v(2); + + + T disc = b*b - 4*c; + if (disc >= 0) + disc = std::sqrt(disc); + else + disc = 0; + + v(0) = (-b + disc)/2; + v(1) = (-b - disc)/2; + return v; + } + else + { + // Call .ref() so that the symmetric matrix overload can take effect if m + // has the appropriate type. + return eigenvalue_decomposition(m.ref()).get_real_eigenvalues(); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + dlib::vector max_point_interpolated ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.size() > 0, + "\tdlib::vector point max_point_interpolated(const matrix_exp& m)" + << "\n\tm can't be empty" + << "\n\tm.size(): " << m.size() + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + const point p = max_point(m); + + // If this is a column vector then just do interpolation along a line. + if (m.nc()==1) + { + const long pos = p.y(); + if (0 < pos && pos+1 < m.nr()) + { + double v1 = dlib::impl::magnitude(m(pos-1)); + double v2 = dlib::impl::magnitude(m(pos)); + double v3 = dlib::impl::magnitude(m(pos+1)); + double y = lagrange_poly_min_extrap(pos-1,pos,pos+1, -v1, -v2, -v3); + return vector(0,y); + } + } + // If this is a row vector then just do interpolation along a line. + if (m.nr()==1) + { + const long pos = p.x(); + if (0 < pos && pos+1 < m.nc()) + { + double v1 = dlib::impl::magnitude(m(pos-1)); + double v2 = dlib::impl::magnitude(m(pos)); + double v3 = dlib::impl::magnitude(m(pos+1)); + double x = lagrange_poly_min_extrap(pos-1,pos,pos+1, -v1, -v2, -v3); + return vector(x,0); + } + } + + + // If it's on the border then just return the regular max point. + if (shrink_rect(get_rect(m),1).contains(p) == false) + return p; + + //matrix A(9,6); + //matrix G(9); + + matrix pix; + long i = 0; + for (long r = -1; r <= +1; ++r) + { + for (long c = -1; c <= +1; ++c) + { + pix(i) = dlib::impl::magnitude(m(p.y()+r,p.y()+c)); + /* + A(i,0) = c*c; + A(i,1) = c*r; + A(i,2) = r*r; + A(i,3) = c; + A(i,4) = r; + A(i,5) = 1; + G(i) = std::exp(-1*(r*r+c*c)/2.0); // Use a gaussian windowing function around p. + */ + ++i; + } + } + + // This bit of code is how we generated the derivative_filters matrix below. + //A = diagm(G)*A; + //std::cout << std::setprecision(20) << inv(trans(A)*A)*trans(A)*diagm(G) << std::endl; exit(1); + + const double m10 = 0.10597077880854270659; + const double m21 = 0.21194155761708535768; + const double m28 = 0.28805844238291455905; + const double m57 = 0.57611688476582878504; + // So this derivative_filters finds the parameters of the quadratic surface that best fits + // the 3x3 region around p. Then we find the maximizer of that surface within that + // small region and return that as the maximum location. + const double derivative_filters[] = { + // xx + m10,-m21,m10, + m28,-m57,m28, + m10,-m21,m10, + + // xy + 0.25 ,0,-0.25, + 0 ,0, 0, + -0.25,0,0.25, + + // yy + m10, m28, m10, + -m21,-m57,-m21, + m10, m28, m10, + + // x + -m10,0,m10, + -m28,0,m28, + -m10,0,m10, + + // y + -m10,-m28,-m10, + 0, 0, 0, + m10, m28, m10 + }; + const matrix filt(derivative_filters); + // Now w contains the parameters of the quadratic surface + const matrix w = filt*pix; + + + // Now newton step to the max point on the surface + matrix H; + matrix g; + H = 2*w(0), w(1), + w(1), 2*w(2); + g = w(3), + w(4); + const dlib::vector delta = -inv(H)*g; + + // if delta isn't in an ascent direction then just use the normal max point. + if (dot(delta, g) < 0) + return p; + else + return vector(p)+dlib::clamp(delta, -1, 1); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_LA_FUNCTS_ + + diff --git a/ml/dlib/dlib/matrix/matrix_la_abstract.h b/ml/dlib/dlib/matrix/matrix_la_abstract.h new file mode 100644 index 000000000..df6a5fd33 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_la_abstract.h @@ -0,0 +1,1005 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MATRIx_LA_FUNCTS_ABSTRACT_ +#ifdef DLIB_MATRIx_LA_FUNCTS_ABSTRACT_ + +#include "matrix_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Global linear algebra functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + const matrix_exp::matrix_type inv ( + const matrix_exp& m + ); + /*! + requires + - m is a square matrix + ensures + - returns the inverse of m + (Note that if m is singular or so close to being singular that there + is a lot of numerical error then the returned matrix will be bogus. + You can check by seeing if m*inv(m) is an identity matrix) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix pinv ( + const matrix_exp& m, + double tol = 0 + ); + /*! + requires + - tol >= 0 + ensures + - returns the Moore-Penrose pseudoinverse of m. + - The returned matrix has m.nc() rows and m.nr() columns. + - if (tol == 0) then + - singular values less than max(m.nr(),m.nc()) times the machine epsilon + times the largest singular value are ignored. + - else + - singular values less than tol*max(singular value in m) are ignored. + !*/ + +// ---------------------------------------------------------------------------------------- + + void svd ( + const matrix_exp& m, + matrix& u, + matrix& w, + matrix& v + ); + /*! + ensures + - computes the singular value decomposition of m + - m == #u*#w*trans(#v) + - trans(#u)*#u == identity matrix + - trans(#v)*#v == identity matrix + - diag(#w) == the singular values of the matrix m in no + particular order. All non-diagonal elements of #w are + set to 0. + - #u.nr() == m.nr() + - #u.nc() == m.nc() + - #w.nr() == m.nc() + - #w.nc() == m.nc() + - #v.nr() == m.nc() + - #v.nc() == m.nc() + - if DLIB_USE_LAPACK is #defined then the xGESVD routine + from LAPACK is used to compute the SVD. + !*/ + +// ---------------------------------------------------------------------------------------- + + long svd2 ( + bool withu, + bool withv, + const matrix_exp& m, + matrix& u, + matrix& w, + matrix& v + ); + /*! + requires + - m.nr() >= m.nc() + ensures + - computes the singular value decomposition of matrix m + - m == subm(#u,get_rect(m))*diagm(#w)*trans(#v) + - trans(#u)*#u == identity matrix + - trans(#v)*#v == identity matrix + - #w == the singular values of the matrix m in no + particular order. + - #u.nr() == m.nr() + - #u.nc() == m.nr() + - #w.nr() == m.nc() + - #w.nc() == 1 + - #v.nr() == m.nc() + - #v.nc() == m.nc() + - if (widthu == false) then + - ignore the above regarding #u, it isn't computed and its + output state is undefined. + - if (widthv == false) then + - ignore the above regarding #v, it isn't computed and its + output state is undefined. + - returns an error code of 0, if no errors and 'k' if we fail to + converge at the 'kth' singular value. + - if (DLIB_USE_LAPACK is #defined) then + - if (withu == withv) then + - the xGESDD routine from LAPACK is used to compute the SVD. + - else + - the xGESVD routine from LAPACK is used to compute the SVD. + !*/ + +// ---------------------------------------------------------------------------------------- + + void svd3 ( + const matrix_exp& m, + matrix& u, + matrix& w, + matrix& v + ); + /*! + ensures + - computes the singular value decomposition of m + - m == #u*diagm(#w)*trans(#v) + - trans(#u)*#u == identity matrix + - trans(#v)*#v == identity matrix + - #w == the singular values of the matrix m in no + particular order. + - #u.nr() == m.nr() + - #u.nc() == m.nc() + - #w.nr() == m.nc() + - #w.nc() == 1 + - #v.nr() == m.nc() + - #v.nc() == m.nc() + - if DLIB_USE_LAPACK is #defined then the xGESVD routine + from LAPACK is used to compute the SVD. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void svd_fast ( + const matrix& A, + matrix& u, + matrix& w, + matrix& v, + unsigned long l, + unsigned long q = 1 + ); + /*! + requires + - l > 0 + - A.size() > 0 + (i.e. A can't be an empty matrix) + ensures + - computes the singular value decomposition of A. + - Lets define some constants we use to document the behavior of svd_fast(): + - Let m = A.nr() + - Let n = A.nc() + - Let k = min(l, min(m,n)) + - Therefore, A represents an m by n matrix and svd_fast() is designed + to find a rank-k representation of it. + - if (the rank of A is <= k) then + - A == #u*diagm(#w)*trans(#v) + - else + - A is approximated by #u*diagm(#w)*trans(#v) + (i.e. In this case A can't be represented with a rank-k matrix, so the + matrix you get by trying to reconstruct A from the output of the SVD is + not exactly the same.) + - trans(#u)*#u == identity matrix + - trans(#v)*#v == identity matrix + - #w == the top k singular values of the matrix A (in no particular order). + - #u.nr() == m + - #u.nc() == k + - #w.nr() == k + - #w.nc() == 1 + - #v.nr() == n + - #v.nc() == k + - This function implements the randomized subspace iteration defined in the + algorithm 4.4 and 5.1 boxes of the paper: + Finding Structure with Randomness: Probabilistic Algorithms for + Constructing Approximate Matrix Decompositions by Halko et al. + Therefore, it is very fast and suitable for use with very large matrices. + Moreover, q is the number of subspace iterations performed. Larger + values of q might increase the accuracy of the solution but the default + value should be good for many problems. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sparse_vector_type, + typename T + > + void svd_fast ( + const std::vector& A, + matrix& u, + matrix& w, + matrix& v, + unsigned long l, + unsigned long q = 1 + ); + /*! + requires + - A contains a set of sparse vectors. See dlib/svm/sparse_vector_abstract.h + for a definition of what constitutes a sparse vector. + - l > 0 + - max_index_plus_one(A) > 0 + (i.e. A can't be an empty matrix) + ensures + - computes the singular value decomposition of A. In this case, we interpret A + as a matrix of A.size() rows, where each row is defined by a sparse vector. + - Lets define some constants we use to document the behavior of svd_fast(): + - Let m = A.size() + - Let n = max_index_plus_one(A) + - Let k = min(l, min(m,n)) + - Therefore, A represents an m by n matrix and svd_fast() is designed + to find a rank-k representation of it. + - if (the rank of A is <= k) then + - A == #u*diagm(#w)*trans(#v) + - else + - A is approximated by #u*diagm(#w)*trans(#v) + (i.e. In this case A can't be represented with a rank-k matrix, so the + matrix you get by trying to reconstruct A from the output of the SVD is + not exactly the same.) + - trans(#u)*#u == identity matrix + - trans(#v)*#v == identity matrix + - #w == the top k singular values of the matrix A (in no particular order). + - #u.nr() == m + - #u.nc() == k + - #w.nr() == k + - #w.nc() == 1 + - #v.nr() == n + - #v.nc() == k + - This function implements the randomized subspace iteration defined in the + algorithm 4.4 and 5.1 boxes of the paper: + Finding Structure with Randomness: Probabilistic Algorithms for + Constructing Approximate Matrix Decompositions by Halko et al. + Therefore, it is very fast and suitable for use with very large matrices. + Moreover, q is the number of subspace iterations performed. Larger + values of q might increase the accuracy of the solution but the default + value should be good for many problems. + !*/ + + template < + typename sparse_vector_type, + typename T + > + void svd_fast ( + const std::vector& A, + matrix& w, + matrix& v, + unsigned long l, + unsigned long q = 1 + ); + /*! + This function is identical to the above svd_fast() except it doesn't compute u. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename MM, + typename L + > + void orthogonalize ( + matrix& m + ); + /*! + requires + - m.nr() >= m.nc() + - m.size() > 0 + ensures + - #m == an orthogonal matrix with the same dimensions as m. In particular, + the columns of #m have the same span as the columns of m. + - trans(#m)*#m == identity matrix + - This function is just shorthand for computing the QR decomposition of m + and then storing the Q factor into #m. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix real_eigenvalues ( + const matrix_exp& m + ); + /*! + requires + - m.nr() == m.nc() + - matrix_exp::type == float or double + ensures + - returns a matrix E such that: + - E.nr() == m.nr() + - E.nc() == 1 + - E contains the real part of all eigenvalues of the matrix m. + (note that the eigenvalues are not sorted) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type det ( + const matrix_exp& m + ); + /*! + requires + - m is a square matrix + ensures + - returns the determinant of m + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type trace ( + const matrix_exp& m + ); + /*! + requires + - m is a square matrix + ensures + - returns the trace of m + (i.e. returns sum(diag(m))) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::matrix_type chol ( + const matrix_exp& A + ); + /*! + requires + - A is a square matrix + ensures + - if (A has a Cholesky Decomposition) then + - returns the decomposition of A. That is, returns a matrix L + such that L*trans(L) == A. L will also be lower triangular. + - else + - returns a matrix with the same dimensions as A but it + will have a bogus value. I.e. it won't be a decomposition. + In this case the algorithm returns a partial decomposition. + - You can tell when chol fails by looking at the lower right + element of the returned matrix. If it is 0 then it means + A does not have a cholesky decomposition. + + - If DLIB_USE_LAPACK is defined then the LAPACK routine xPOTRF + is used to compute the cholesky decomposition. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::matrix_type inv_lower_triangular ( + const matrix_exp& A + ); + /*! + requires + - A is a square matrix + ensures + - if (A is lower triangular) then + - returns the inverse of A. + - else + - returns a matrix with the same dimensions as A but it + will have a bogus value. I.e. it won't be an inverse. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::matrix_type inv_upper_triangular ( + const matrix_exp& A + ); + /*! + requires + - A is a square matrix + ensures + - if (A is upper triangular) then + - returns the inverse of A. + - else + - returns a matrix with the same dimensions as A but it + will have a bogus value. I.e. it won't be an inverse. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Matrix decomposition classes +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_exp_type + > + class lu_decomposition + { + /*! + REQUIREMENTS ON matrix_exp_type + must be some kind of matrix expression as defined in the + dlib/matrix/matrix_abstract.h file. (e.g. a dlib::matrix object) + The matrix type must also contain float or double values. + + WHAT THIS OBJECT REPRESENTS + This object represents something that can compute an LU + decomposition of a real valued matrix. That is, for any + matrix A it computes matrices L, U, and a pivot vector P such + that rowm(A,P) == L*U. + + The LU decomposition with pivoting always exists, even if the matrix is + singular, so the constructor will never fail. The primary use of the + LU decomposition is in the solution of square systems of simultaneous + linear equations. This will fail if is_singular() returns true (or + if A is very nearly singular). + + If DLIB_USE_LAPACK is defined then the LAPACK routine xGETRF + is used to compute the LU decomposition. + !*/ + + public: + + const static long NR = matrix_exp_type::NR; + const static long NC = matrix_exp_type::NC; + typedef typename matrix_exp_type::type type; + typedef typename matrix_exp_type::mem_manager_type mem_manager_type; + typedef typename matrix_exp_type::layout_type layout_type; + + typedef matrix matrix_type; + typedef matrix column_vector_type; + typedef matrix pivot_column_vector_type; + + template + lu_decomposition ( + const matrix_exp &A + ); + /*! + requires + - EXP::type == lu_decomposition::type + - A.size() > 0 + ensures + - #nr() == A.nr() + - #nc() == A.nc() + - #is_square() == (A.nr() == A.nc()) + - computes the LU factorization of the given A matrix. + !*/ + + bool is_square ( + ) const; + /*! + ensures + - if (the input A matrix was a square matrix) then + - returns true + - else + - returns false + !*/ + + bool is_singular ( + ) const; + /*! + requires + - is_square() == true + ensures + - if (the input A matrix is singular) then + - returns true + - else + - returns false + !*/ + + long nr( + ) const; + /*! + ensures + - returns the number of rows in the input matrix + !*/ + + long nc( + ) const; + /*! + ensures + - returns the number of columns in the input matrix + !*/ + + const matrix_type get_l ( + ) const; + /*! + ensures + - returns the lower triangular L factor of the LU factorization. + - L.nr() == nr() + - L.nc() == min(nr(),nc()) + !*/ + + const matrix_type get_u ( + ) const; + /*! + ensures + - returns the upper triangular U factor of the LU factorization. + - U.nr() == min(nr(),nc()) + - U.nc() == nc() + !*/ + + const pivot_column_vector_type& get_pivot ( + ) const; + /*! + ensures + - returns the pivot permutation vector. That is, + if A is the input matrix then this function + returns a vector P such that: + - rowm(A,P) == get_l()*get_u() + - P.nr() == A.nr() + !*/ + + type det ( + ) const; + /*! + requires + - is_square() == true + ensures + - computes and returns the determinant of the input + matrix using LU factors. + !*/ + + template + const matrix_type solve ( + const matrix_exp &B + ) const; + /*! + requires + - EXP::type == lu_decomposition::type + - is_square() == true + - B.nr() == nr() + ensures + - Let A denote the input matrix to this class's constructor. + Then this function solves A*X == B for X and returns X. + - Note that if A is singular (or very close to singular) then + the X returned by this function won't fit A*X == B very well (if at all). + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_exp_type + > + class cholesky_decomposition + { + /*! + REQUIREMENTS ON matrix_exp_type + must be some kind of matrix expression as defined in the + dlib/matrix/matrix_abstract.h file. (e.g. a dlib::matrix object) + The matrix type must also contain float or double values. + + WHAT THIS OBJECT REPRESENTS + This object represents something that can compute a cholesky + decomposition of a real valued matrix. That is, for any + symmetric, positive definite matrix A, it computes a lower + triangular matrix L such that A == L*trans(L). + + If the matrix is not symmetric or positive definite, the function + computes only a partial decomposition. This can be tested with + the is_spd() flag. + + If DLIB_USE_LAPACK is defined then the LAPACK routine xPOTRF + is used to compute the cholesky decomposition. + !*/ + + public: + + const static long NR = matrix_exp_type::NR; + const static long NC = matrix_exp_type::NC; + typedef typename matrix_exp_type::type type; + typedef typename matrix_exp_type::mem_manager_type mem_manager_type; + typedef typename matrix_exp_type::layout_type layout_type; + + typedef typename matrix_exp_type::matrix_type matrix_type; + typedef matrix column_vector_type; + + template + cholesky_decomposition( + const matrix_exp& A + ); + /*! + requires + - EXP::type == cholesky_decomposition::type + - A.size() > 0 + - A.nr() == A.nc() + (i.e. A must be a square matrix) + ensures + - if (A is symmetric positive-definite) then + - #is_spd() == true + - Constructs a lower triangular matrix L, such that L*trans(L) == A. + and #get_l() == L + - else + - #is_spd() == false + !*/ + + bool is_spd( + ) const; + /*! + ensures + - if (the input matrix was symmetric positive-definite) then + - returns true + - else + - returns false + !*/ + + const matrix_type& get_l( + ) const; + /*! + ensures + - returns the lower triangular factor, L, such that L*trans(L) == A + (where A is the input matrix to this class's constructor) + - Note that if A is not symmetric positive definite or positive semi-definite + then the equation L*trans(L) == A won't hold. + !*/ + + template + const matrix solve ( + const matrix_exp& B + ) const; + /*! + requires + - EXP::type == cholesky_decomposition::type + - B.nr() == get_l().nr() + (i.e. the number of rows in B must match the number of rows in the + input matrix A) + ensures + - Let A denote the input matrix to this class's constructor. Then + this function solves A*X = B for X and returns X. + - Note that if is_spd() == false or A was really close to being + non-SPD then the solver will fail to find an accurate solution. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_exp_type + > + class qr_decomposition + { + /*! + REQUIREMENTS ON matrix_exp_type + must be some kind of matrix expression as defined in the + dlib/matrix/matrix_abstract.h file. (e.g. a dlib::matrix object) + The matrix type must also contain float or double values. + + WHAT THIS OBJECT REPRESENTS + This object represents something that can compute a classical + QR decomposition of an m-by-n real valued matrix A with m >= n. + + The QR decomposition is an m-by-n orthogonal matrix Q and an + n-by-n upper triangular matrix R so that A == Q*R. The QR decomposition + always exists, even if the matrix does not have full rank, so the + constructor will never fail. The primary use of the QR decomposition + is in the least squares solution of non-square systems of simultaneous + linear equations. This will fail if is_full_rank() returns false or + A is very nearly not full rank. + + The Q and R factors can be retrieved via the get_q() and get_r() + methods. Furthermore, a solve() method is provided to find the + least squares solution of Ax=b using the QR factors. + + If DLIB_USE_LAPACK is #defined then the xGEQRF routine + from LAPACK is used to compute the QR decomposition. + !*/ + + public: + + const static long NR = matrix_exp_type::NR; + const static long NC = matrix_exp_type::NC; + typedef typename matrix_exp_type::type type; + typedef typename matrix_exp_type::mem_manager_type mem_manager_type; + typedef typename matrix_exp_type::layout_type layout_type; + + typedef matrix matrix_type; + + template + qr_decomposition( + const matrix_exp& A + ); + /*! + requires + - EXP::type == qr_decomposition::type + - A.nr() >= A.nc() + - A.size() > 0 + ensures + - #nr() == A.nr() + - #nc() == A.nc() + - computes the QR decomposition of the given A matrix. + !*/ + + bool is_full_rank( + ) const; + /*! + ensures + - if (the input A matrix had full rank) then + - returns true + - else + - returns false + !*/ + + long nr( + ) const; + /*! + ensures + - returns the number of rows in the input matrix + !*/ + + long nc( + ) const; + /*! + ensures + - returns the number of columns in the input matrix + !*/ + + const matrix_type get_r ( + ) const; + /*! + ensures + - returns a matrix R such that: + - R is the upper triangular factor, R, of the QR factorization + - get_q()*R == input matrix A + - R.nr() == nc() + - R.nc() == nc() + !*/ + + const matrix_type get_q ( + ) const; + /*! + ensures + - returns a matrix Q such that: + - Q is the economy-sized orthogonal factor Q from the QR + factorization. + - trans(Q)*Q == identity matrix + - Q*get_r() == input matrix A + - Q.nr() == nr() + - Q.nc() == nc() + !*/ + + void get_q ( + matrix_type& Q + ) const; + /*! + ensures + - #Q == get_q() + - This function exists to allow a user to get the Q matrix without the + overhead of returning a matrix by value. + !*/ + + template + const matrix_type solve ( + const matrix_exp& B + ) const; + /*! + requires + - EXP::type == qr_decomposition::type + - B.nr() == nr() + ensures + - Let A denote the input matrix to this class's constructor. + Then this function finds the least squares solution to the equation A*X = B + and returns X. X has the following properties: + - X is the matrix that minimizes the two norm of A*X-B. That is, it + minimizes sum(squared(A*X - B)). + - X.nr() == nc() + - X.nc() == B.nc() + - Note that this function will fail to output a good solution if is_full_rank() == false + or the A matrix is close to not being full rank. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_exp_type + > + class eigenvalue_decomposition + { + /*! + REQUIREMENTS ON matrix_exp_type + must be some kind of matrix expression as defined in the + dlib/matrix/matrix_abstract.h file. (e.g. a dlib::matrix object) + The matrix type must also contain float or double values. + + WHAT THIS OBJECT REPRESENTS + This object represents something that can compute an eigenvalue + decomposition of a real valued matrix. So it gives + you the set of eigenvalues and eigenvectors for a matrix. + + Let A denote the input matrix to this object's constructor. Then + what this object does is it finds two matrices, D and V, such that + - A*V == V*D + Where V is a square matrix that contains all the eigenvectors + of the A matrix (each column of V is an eigenvector) and + D is a diagonal matrix containing the eigenvalues of A. + + + It is important to note that if A is symmetric or non-symmetric you + get somewhat different results. If A is a symmetric matrix (i.e. A == trans(A)) + then: + - All the eigenvalues and eigenvectors of A are real numbers. + - Because of this there isn't really any point in using the + part of this class's interface that returns complex matrices. + All you need are the get_real_eigenvalues() and + get_pseudo_v() functions. + - V*trans(V) should be equal to the identity matrix. That is, all the + eigenvectors in V should be orthonormal. + - So A == V*D*trans(V) + - If DLIB_USE_LAPACK is #defined then this object uses the xSYEVR LAPACK + routine. + + On the other hand, if A is not symmetric then: + - Some of the eigenvalues and eigenvectors might be complex numbers. + - An eigenvalue is complex if and only if its corresponding eigenvector + is complex. So you can check for this case by just checking + get_imag_eigenvalues() to see if any values are non-zero. You don't + have to check the V matrix as well. + - V*trans(V) won't be equal to the identity matrix but it is usually + invertible. So A == V*D*inv(V) is usually a valid statement but + A == V*D*trans(V) won't be. + - If DLIB_USE_LAPACK is #defined then this object uses the xGEEV LAPACK + routine. + !*/ + + public: + + const static long NR = matrix_exp_type::NR; + const static long NC = matrix_exp_type::NC; + typedef typename matrix_exp_type::type type; + typedef typename matrix_exp_type::mem_manager_type mem_manager_type; + typedef typename matrix_exp_type::layout_type layout_type; + + typedef typename matrix_exp_type::matrix_type matrix_type; + typedef matrix column_vector_type; + + typedef matrix,0,0,mem_manager_type,layout_type> complex_matrix_type; + typedef matrix,NR,1,mem_manager_type,layout_type> complex_column_vector_type; + + + template + eigenvalue_decomposition( + const matrix_exp& A + ); + /*! + requires + - A.nr() == A.nc() + - A.size() > 0 + - EXP::type == eigenvalue_decomposition::type + ensures + - #dim() == A.nr() + - computes the eigenvalue decomposition of A. + - #get_eigenvalues() == the eigenvalues of A + - #get_v() == all the eigenvectors of A + !*/ + + template + eigenvalue_decomposition( + const matrix_op >& A + ); + /*! + requires + - A.nr() == A.nc() + - A.size() > 0 + - EXP::type == eigenvalue_decomposition::type + ensures + - #dim() == A.nr() + - computes the eigenvalue decomposition of the symmetric matrix A. Does so + using a method optimized for symmetric matrices. + - #get_eigenvalues() == the eigenvalues of A + - #get_v() == all the eigenvectors of A + - moreover, since A is symmetric there won't be any imaginary eigenvalues. So + we will have: + - #get_imag_eigenvalues() == 0 + - #get_real_eigenvalues() == the eigenvalues of A + - #get_pseudo_v() == all the eigenvectors of A + - diagm(#get_real_eigenvalues()) == #get_pseudo_d() + + Note that the symmetric matrix operator is created by the + dlib::make_symmetric() function. This function simply reflects + the lower triangular part of a square matrix into the upper triangular + part to create a symmetric matrix. It can also be used to denote that a + matrix is already symmetric using the C++ type system. + !*/ + + long dim ( + ) const; + /*! + ensures + - dim() == the number of rows/columns in the input matrix A + !*/ + + const complex_column_vector_type get_eigenvalues ( + ) const; + /*! + ensures + - returns diag(get_d()). That is, returns a + vector that contains the eigenvalues of the input + matrix. + - the returned vector has dim() rows + - the eigenvalues are not sorted in any particular way + !*/ + + const column_vector_type& get_real_eigenvalues ( + ) const; + /*! + ensures + - returns the real parts of the eigenvalues. That is, + returns real(get_eigenvalues()) + - the returned vector has dim() rows + - the eigenvalues are not sorted in any particular way + !*/ + + const column_vector_type& get_imag_eigenvalues ( + ) const; + /*! + ensures + - returns the imaginary parts of the eigenvalues. That is, + returns imag(get_eigenvalues()) + - the returned vector has dim() rows + - the eigenvalues are not sorted in any particular way + !*/ + + const complex_matrix_type get_v ( + ) const; + /*! + ensures + - returns the eigenvector matrix V that is + dim() rows by dim() columns + - Each column in V is one of the eigenvectors of the input + matrix + !*/ + + const complex_matrix_type get_d ( + ) const; + /*! + ensures + - returns a matrix D such that: + - D.nr() == dim() + - D.nc() == dim() + - diag(D) == get_eigenvalues() + (i.e. the diagonal of D contains all the eigenvalues in the input matrix) + - all off diagonal elements of D are set to 0 + !*/ + + const matrix_type& get_pseudo_v ( + ) const; + /*! + ensures + - returns a matrix that is dim() rows by dim() columns + - Let A denote the input matrix given to this object's constructor. + - if (A has any imaginary eigenvalues) then + - returns the pseudo-eigenvector matrix V + - The matrix V returned by this function is structured such that: + - A*V == V*get_pseudo_d() + - else + - returns the eigenvector matrix V with A's eigenvectors as + the columns of V + - A*V == V*diagm(get_real_eigenvalues()) + !*/ + + const matrix_type get_pseudo_d ( + ) const; + /*! + ensures + - The returned matrix is dim() rows by dim() columns + - Computes and returns the block diagonal eigenvalue matrix. + If the original matrix A is not symmetric, then the eigenvalue + matrix D is block diagonal with the real eigenvalues in 1-by-1 + blocks and any complex eigenvalues, + a + i*b, in 2-by-2 blocks, (a, b; -b, a). That is, if the complex + eigenvalues look like + + u + iv . . . . . + . u - iv . . . . + . . a + ib . . . + . . . a - ib . . + . . . . x . + . . . . . y + + Then D looks like + + u v . . . . + -v u . . . . + . . a b . . + . . -b a . . + . . . . x . + . . . . . y + + This keeps V (The V you get from get_pseudo_v()) a real matrix in both + symmetric and non-symmetric cases, and A*V = V*D. + - the eigenvalues are not sorted in any particular way + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_LA_FUNCTS_ABSTRACT_ + diff --git a/ml/dlib/dlib/matrix/matrix_lu.h b/ml/dlib/dlib/matrix/matrix_lu.h new file mode 100644 index 000000000..3e49cd653 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_lu.h @@ -0,0 +1,361 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +// This code was adapted from code from the JAMA part of NIST's TNT library. +// See: http://math.nist.gov/tnt/ +#ifndef DLIB_MATRIX_LU_DECOMPOSITION_H +#define DLIB_MATRIX_LU_DECOMPOSITION_H + +#include "matrix.h" +#include "matrix_utilities.h" +#include "matrix_subexp.h" +#include "matrix_trsm.h" +#include + +#ifdef DLIB_USE_LAPACK +#include "lapack/getrf.h" +#endif + + +namespace dlib +{ + + template < + typename matrix_exp_type + > + class lu_decomposition + { + public: + + const static long NR = matrix_exp_type::NR; + const static long NC = matrix_exp_type::NC; + typedef typename matrix_exp_type::type type; + typedef typename matrix_exp_type::mem_manager_type mem_manager_type; + typedef typename matrix_exp_type::layout_type layout_type; + + typedef matrix matrix_type; + typedef matrix column_vector_type; + typedef matrix pivot_column_vector_type; + + // You have supplied an invalid type of matrix_exp_type. You have + // to use this object with matrices that contain float or double type data. + COMPILE_TIME_ASSERT((is_same_type::value || + is_same_type::value )); + + template + lu_decomposition ( + const matrix_exp &A + ); + + bool is_square ( + ) const; + + bool is_singular ( + ) const; + + long nr( + ) const; + + long nc( + ) const; + + const matrix_type get_l ( + ) const; + + const matrix_type get_u ( + ) const; + + const pivot_column_vector_type& get_pivot ( + ) const; + + type det ( + ) const; + + template + const matrix_type solve ( + const matrix_exp &B + ) const; + + private: + + /* Array for internal storage of decomposition. */ + matrix LU; + long m, n, pivsign; + pivot_column_vector_type piv; + + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Public member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + template + lu_decomposition:: + lu_decomposition ( + const matrix_exp& A + ) : + LU(A), + m(A.nr()), + n(A.nc()) + { + using namespace std; + using std::abs; + + COMPILE_TIME_ASSERT((is_same_type::value)); + + // make sure requires clause is not broken + DLIB_ASSERT(A.size() > 0, + "\tlu_decomposition::lu_decomposition(A)" + << "\n\tInvalid inputs were given to this function" + << "\n\tA.size(): " << A.size() + << "\n\tthis: " << this + ); + +#ifdef DLIB_USE_LAPACK + matrix piv_temp; + lapack::getrf(LU, piv_temp); + + pivsign = 1; + + // Turn the piv_temp vector into a more useful form. This way we will have the identity + // rowm(A,piv) == L*U. The permutation vector that comes out of LAPACK is somewhat + // different. + piv = trans(range(0,m-1)); + for (long i = 0; i < piv_temp.size(); ++i) + { + // -1 because FORTRAN is indexed starting with 1 instead of 0 + if (piv(piv_temp(i)-1) != piv(i)) + { + std::swap(piv(i), piv(piv_temp(i)-1)); + pivsign = -pivsign; + } + } + +#else + + // Use a "left-looking", dot-product, Crout/Doolittle algorithm. + + + piv = trans(range(0,m-1)); + pivsign = 1; + + column_vector_type LUcolj(m); + + // Outer loop. + for (long j = 0; j < n; j++) + { + + // Make a copy of the j-th column to localize references. + LUcolj = colm(LU,j); + + // Apply previous transformations. + for (long i = 0; i < m; i++) + { + // Most of the time is spent in the following dot product. + const long kmax = std::min(i,j); + type s; + if (kmax > 0) + s = rowm(LU,i, kmax)*colm(LUcolj,0,kmax); + else + s = 0; + + LU(i,j) = LUcolj(i) -= s; + } + + // Find pivot and exchange if necessary. + long p = j; + for (long i = j+1; i < m; i++) + { + if (abs(LUcolj(i)) > abs(LUcolj(p))) + { + p = i; + } + } + if (p != j) + { + long k=0; + for (k = 0; k < n; k++) + { + type t = LU(p,k); + LU(p,k) = LU(j,k); + LU(j,k) = t; + } + k = piv(p); + piv(p) = piv(j); + piv(j) = k; + pivsign = -pivsign; + } + + // Compute multipliers. + if ((j < m) && (LU(j,j) != 0.0)) + { + for (long i = j+1; i < m; i++) + { + LU(i,j) /= LU(j,j); + } + } + } + +#endif + } + +// ---------------------------------------------------------------------------------------- + + template + bool lu_decomposition:: + is_square ( + ) const + { + return m == n; + } + +// ---------------------------------------------------------------------------------------- + + template + long lu_decomposition:: + nr ( + ) const + { + return m; + } + +// ---------------------------------------------------------------------------------------- + + template + long lu_decomposition:: + nc ( + ) const + { + return n; + } + +// ---------------------------------------------------------------------------------------- + + template + bool lu_decomposition:: + is_singular ( + ) const + { + /* Is the matrix singular? + if upper triangular factor U (and hence A) is singular, false otherwise. + */ + // make sure requires clause is not broken + DLIB_ASSERT(is_square() == true, + "\tbool lu_decomposition::is_singular()" + << "\n\tYou can only use this on square matrices" + << "\n\tthis: " << this + ); + + type max_val, min_val; + find_min_and_max (abs(diag(LU)), min_val, max_val); + type eps = max_val; + if (eps != 0) + eps *= std::sqrt(std::numeric_limits::epsilon())/10; + else + eps = 1; // there is no max so just use 1 + + return min_val < eps; + } + +// ---------------------------------------------------------------------------------------- + + template + const typename lu_decomposition::matrix_type lu_decomposition:: + get_l ( + ) const + { + if (LU.nr() >= LU.nc()) + return lowerm(LU,1.0); + else + return lowerm(subm(LU,0,0,m,m), 1.0); + } + +// ---------------------------------------------------------------------------------------- + + template + const typename lu_decomposition::matrix_type lu_decomposition:: + get_u ( + ) const + { + if (LU.nr() >= LU.nc()) + return upperm(subm(LU,0,0,n,n)); + else + return upperm(LU); + } + +// ---------------------------------------------------------------------------------------- + + template + const typename lu_decomposition::pivot_column_vector_type& lu_decomposition:: + get_pivot ( + ) const + { + return piv; + } + +// ---------------------------------------------------------------------------------------- + + template + typename lu_decomposition::type lu_decomposition:: + det ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_square() == true, + "\ttype lu_decomposition::det()" + << "\n\tYou can only use this on square matrices" + << "\n\tthis: " << this + ); + + // Check if it is singular and if it is just return 0. + // We want to do this because a prod() operation can easily + // overcome a single diagonal element that is effectively 0 when + // LU is a big enough matrix. + if (is_singular()) + return 0; + + return prod(diag(LU))*static_cast(pivsign); + } + +// ---------------------------------------------------------------------------------------- + + template + template + const typename lu_decomposition::matrix_type lu_decomposition:: + solve ( + const matrix_exp &B + ) const + { + COMPILE_TIME_ASSERT((is_same_type::value)); + + // make sure requires clause is not broken + DLIB_ASSERT(is_square() == true && B.nr() == nr(), + "\ttype lu_decomposition::solve()" + << "\n\tInvalid arguments to this function" + << "\n\tis_square(): " << (is_square()? "true":"false" ) + << "\n\tB.nr(): " << B.nr() + << "\n\tnr(): " << nr() + << "\n\tthis: " << this + ); + + // Copy right hand side with pivoting + matrix X(rowm(B, piv)); + + using namespace blas_bindings; + // Solve L*Y = B(piv,:) + triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, X); + // Solve U*X = Y; + triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, X); + return X; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIX_LU_DECOMPOSITION_H + + diff --git a/ml/dlib/dlib/matrix/matrix_mat.h b/ml/dlib/dlib/matrix/matrix_mat.h new file mode 100644 index 000000000..803d7d999 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_mat.h @@ -0,0 +1,733 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_MAT_Hh_ +#define DLIB_MATRIx_MAT_Hh_ + +#include "matrix_mat_abstract.h" +#include "../stl_checked.h" +#include +#include "matrix_op.h" +#include "../array2d.h" +#include "../array.h" +#include "../image_processing/generic_image.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix_exp& mat ( + const matrix_exp& m + ) + { + return m; + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_image_to_mat : does_not_alias + { + op_image_to_mat( const image_type& img) : imgview(img){} + + const_image_view imgview; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 0; + typedef pixel_type type; + typedef const pixel_type& const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const { return imgview[r][c]; } + + long nr () const { return imgview.nr(); } + long nc () const { return imgview.nc(); } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > // The reason we disable this if it is a matrix is because this matrix_op claims + // to not alias any matrix. But obviously that would be a problem if we let it + // take a matrix. + const typename disable_if,matrix_op::pixel_type> > >::type mat ( + const image_type& img + ) + { + typedef op_image_to_mat::pixel_type> op; + return matrix_op(op(img)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_image_view_to_mat : does_not_alias + { + op_image_view_to_mat( const image_view& img) : imgview(img){} + + typedef typename image_traits::pixel_type pixel_type; + + const image_view& imgview; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 0; + typedef pixel_type type; + typedef const pixel_type& const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const { return imgview[r][c]; } + + long nr () const { return imgview.nr(); } + long nc () const { return imgview.nc(); } + }; + + template < + typename image_type + > + const matrix_op > mat ( + const image_view& img + ) + { + typedef op_image_view_to_mat op; + return matrix_op(op(img)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_const_image_view_to_mat : does_not_alias + { + op_const_image_view_to_mat( const const_image_view& img) : imgview(img){} + + typedef typename image_traits::pixel_type pixel_type; + + const const_image_view& imgview; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 0; + typedef pixel_type type; + typedef const pixel_type& const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const { return imgview[r][c]; } + + long nr () const { return imgview.nr(); } + long nc () const { return imgview.nc(); } + }; + + template < + typename image_type + > + const matrix_op > mat ( + const const_image_view& img + ) + { + typedef op_const_image_view_to_mat op; + return matrix_op(op(img)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_array_to_mat : does_not_alias + { + op_array_to_mat( const T& vect_) : vect(vect_){} + + const T& vect; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 1; + typedef typename T::type type; + typedef const typename T::type& const_ret_type; + typedef typename T::mem_manager_type mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long ) const { return vect[r]; } + + long nr () const { return vect.size(); } + long nc () const { return 1; } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename MM + > + const matrix_op > > mat ( + const array& m + ) + { + typedef op_array_to_mat > op; + return matrix_op(op(m)); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + struct not_bool { typedef U type; }; + template <> + struct not_bool { typedef bool type; }; + } + + template + struct op_std_vect_to_mat : does_not_alias + { + op_std_vect_to_mat( const T& vect_) : vect(vect_){} + + const T& vect; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 1; + typedef typename T::value_type type; + // Since std::vector returns a proxy for bool types we need to make sure we don't + // return an element by reference if it is a bool type. + typedef typename impl::not_bool::type const_ret_type; + + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long ) const { return vect[r]; } + + long nr () const { return vect.size(); } + long nc () const { return 1; } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename value_type, + typename alloc + > + const matrix_op > > mat ( + const std::vector& vector + ) + { + typedef op_std_vect_to_mat > op; + return matrix_op(op(vector)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename value_type, + typename alloc + > + const matrix_op > > mat ( + const std_vector_c& vector + ) + { + typedef op_std_vect_to_mat > op; + return matrix_op(op(vector)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_pointer_to_mat; + + template + struct op_pointer_to_col_vect + { + op_pointer_to_col_vect( + const T* ptr_, + const long size_ + ) : ptr(ptr_), size(size_){} + + const T* ptr; + const long size; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 1; + typedef T type; + typedef const T& const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long ) const { return ptr[r]; } + + long nr () const { return size; } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& ) const { return false; } + template bool destructively_aliases ( const matrix_exp& ) const { return false; } + + template + bool aliases ( + const matrix_exp >& item + ) const + { + if (item.size() == 0) + return false; + else + return (ptr == &item(0,0)); + } + + inline bool aliases ( + const matrix_exp > >& item + ) const; + + bool aliases ( + const matrix_exp > >& item + ) const + { + return item.ref().op.ptr == ptr; + } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > mat ( + const T* ptr, + long nr + ) + { + DLIB_ASSERT(nr >= 0 , + "\tconst matrix_exp mat(ptr, nr)" + << "\n\t nr must be >= 0" + << "\n\t nr: " << nr + ); + typedef op_pointer_to_col_vect op; + return matrix_op(op(ptr, nr)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_pointer_to_mat + { + op_pointer_to_mat( + const T* ptr_, + const long nr_, + const long nc_ + ) : ptr(ptr_), rows(nr_), cols(nc_), stride(nc_){} + + op_pointer_to_mat( + const T* ptr_, + const long nr_, + const long nc_, + const long stride_ + ) : ptr(ptr_), rows(nr_), cols(nc_), stride(stride_){} + + const T* ptr; + const long rows; + const long cols; + const long stride; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 0; + typedef T type; + typedef const T& const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c) const { return ptr[r*stride + c]; } + + long nr () const { return rows; } + long nc () const { return cols; } + + template bool aliases ( const matrix_exp& ) const { return false; } + template bool destructively_aliases ( const matrix_exp& ) const { return false; } + + template + bool aliases ( + const matrix_exp >& item + ) const + { + if (item.size() == 0) + return false; + else + return (ptr == &item(0,0)); + } + + bool aliases ( + const matrix_exp > >& item + ) const + { + return item.ref().op.ptr == ptr; + } + + bool aliases ( + const matrix_exp > >& item + ) const + { + return item.ref().op.ptr == ptr; + } + }; + + template + bool op_pointer_to_col_vect:: + aliases ( + const matrix_exp > >& item + ) const + { + return item.ref().op.ptr == ptr; + } + + template + bool matrix::aliases ( + const matrix_exp > >& item + ) const + { + if (size() != 0) + return item.ref().op.ptr == &data(0,0); + else + return false; + } + + template + bool matrix::aliases ( + const matrix_exp > >& item + ) const + { + if (size() != 0) + return item.ref().op.ptr == &data(0,0); + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > mat ( + const T* ptr, + long nr, + long nc + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0 , + "\tconst matrix_exp mat(ptr, nr, nc)" + << "\n\t nr and nc must be >= 0" + << "\n\t nr: " << nr + << "\n\t nc: " << nc + ); + typedef op_pointer_to_mat op; + return matrix_op(op(ptr,nr,nc)); + } + + template < + typename T + > + const matrix_op > mat ( + const T* ptr, + long nr, + long nc, + long stride + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0 && stride > 0 , + "\tconst matrix_exp mat(ptr, nr, nc, stride)" + << "\n\t nr and nc must be >= 0 while stride > 0" + << "\n\t nr: " << nr + << "\n\t nc: " << nc + << "\n\t stride: " << stride + ); + typedef op_pointer_to_mat op; + return matrix_op(op(ptr,nr,nc,stride)); + } + +// ---------------------------------------------------------------------------------------- + +} + +namespace arma +{ + template class Mat; +} +namespace dlib +{ + template + struct op_arma_Mat_to_mat : does_not_alias + { + op_arma_Mat_to_mat( const T& array_) : array(array_){} + + const T& array; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 0; + typedef typename T::elem_type type; + typedef typename T::elem_type const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const { return array(r,c); } + + long nr () const { return array.n_rows; } + long nc () const { return array.n_cols; } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > > mat ( + const ::arma::Mat& array + ) + { + typedef op_arma_Mat_to_mat< ::arma::Mat > op; + return matrix_op(op(array)); + } +} + +namespace Eigen +{ + template + class Matrix; +} + +namespace dlib +{ + template + struct op_eigen_Matrix_to_mat : does_not_alias + { + op_eigen_Matrix_to_mat( const T& array_) : m(array_){} + + const T& m; + + const static long cost = 1; + const static long NR = (_Rows > 0) ? _Rows : 0; + const static long NC = (_Cols > 0) ? _Cols : 0; + typedef typename T::Scalar type; + typedef typename T::Scalar const_ret_type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const { return m(r,c); } + + long nr () const { return m.rows(); } + long nc () const { return m.cols(); } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename _Scalar, int _Rows, int _Cols, int _Options, int _MaxRows, int _MaxCols + > + const matrix_op,_Rows,_Cols > > mat ( + const ::Eigen::Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols>& m + ) + { + typedef op_eigen_Matrix_to_mat< ::Eigen::Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols>,_Rows,_Cols > op; + return matrix_op(op(m)); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// DEPRECATED FUNCTIONS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// vector_to_matrix(), array_to_matrix(), pointer_to_matrix(), and +// pointer_to_column_vector() have been deprecated in favor of the more uniform mat() +// function. But they are here for backwards compatibility. + + template < + typename vector_type + > + const typename disable_if, matrix_op > >::type + vector_to_matrix ( + const vector_type& vector + ) + { + typedef op_array_to_mat op; + return matrix_op(op(vector)); + } + + template < + typename vector_type + > + const typename enable_if,vector_type>::type& vector_to_matrix ( + const vector_type& vector + ) + /*! + This overload catches the case where the argument to this function is + already a matrix. + !*/ + { + return vector; + } + + template < + typename value_type, + typename alloc + > + const matrix_op > > vector_to_matrix ( + const std::vector& vector + ) + { + typedef op_std_vect_to_mat > op; + return matrix_op(op(vector)); + } + + template < + typename value_type, + typename alloc + > + const matrix_op > > vector_to_matrix ( + const std_vector_c& vector + ) + { + typedef op_std_vect_to_mat > op; + return matrix_op(op(vector)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type + > + const typename enable_if,array_type>::type& + array_to_matrix ( + const array_type& array + ) + { + return array; + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_array2d_to_mat : does_not_alias + { + op_array2d_to_mat( const T& array_) : array(array_){} + + const T& array; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 0; + typedef typename T::type type; + typedef const typename T::type& const_ret_type; + typedef typename T::mem_manager_type mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const { return array[r][c]; } + + long nr () const { return array.nr(); } + long nc () const { return array.nc(); } + }; + + // Note that we have this version of mat() because it's slightly faster executing + // than the general one that handles any generic image. This is because it avoids + // calling image_data() which for array2d involves a single if statement but this + // version here has no if statement in its construction. + template < typename T, typename MM > + const matrix_op > > mat ( + const array2d& array + ) + { + typedef op_array2d_to_mat > op; + return matrix_op(op(array)); + } + + template < + typename array_type + > + const typename disable_if,matrix_op > >::type + array_to_matrix ( + const array_type& array + ) + { + typedef op_array2d_to_mat op; + return matrix_op(op(array)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > pointer_to_matrix ( + const T* ptr, + long nr, + long nc + ) + { + DLIB_ASSERT(nr > 0 && nc > 0 , + "\tconst matrix_exp pointer_to_matrix(ptr, nr, nc)" + << "\n\t nr and nc must be bigger than 0" + << "\n\t nr: " << nr + << "\n\t nc: " << nc + ); + typedef op_pointer_to_mat op; + return matrix_op(op(ptr,nr,nc)); + } + + template < + typename T + > + const matrix_op > pointer_to_column_vector ( + const T* ptr, + long nr + ) + { + DLIB_ASSERT(nr > 0 , + "\tconst matrix_exp pointer_to_column_vector(ptr, nr)" + << "\n\t nr must be bigger than 0" + << "\n\t nr: " << nr + ); + typedef op_pointer_to_col_vect op; + return matrix_op(op(ptr, nr)); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + inline matrix mat ( + double value + ) + { + matrix temp; + temp(0) = value; + return temp; + } + + inline matrix mat ( + float value + ) + { + matrix temp; + temp(0) = value; + return temp; + } + + inline matrix mat ( + long double value + ) + { + matrix temp; + temp(0) = value; + return temp; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_MAT_Hh_ + + diff --git a/ml/dlib/dlib/matrix/matrix_mat_abstract.h b/ml/dlib/dlib/matrix/matrix_mat_abstract.h new file mode 100644 index 000000000..7026f60a1 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_mat_abstract.h @@ -0,0 +1,243 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MATRIx_MAT_ABSTRACT_Hh_ +#ifdef DLIB_MATRIx_MAT_ABSTRACT_Hh_ + +#include "matrix_abstract.h" +#inclue +#include "../array/array_kernel_abstract.h" +#include "../array2d/array2d_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix_exp& mat ( + const matrix_exp& m + ); + /*! + ensures + - returns m + (i.e. this function just returns the input matrix without any modifications) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + const matrix_exp mat ( + const image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or image_type is a image_view or + const_image_view object. + ensures + - This function converts any kind of generic image object into a dlib::matrix + expression. Therefore, it is capable of converting objects like dlib::array2d + of dlib::cv_image. + - returns a matrix R such that: + - R.nr() == array.nr() + - R.nc() == array.nc() + - for all valid r and c: + R(r, c) == array[r][c] + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename MM + > + const matrix_exp mat ( + const array& m + ); + /*! + ensures + - returns a matrix R such that: + - is_col_vector(R) == true + - R.size() == m.size() + - for all valid r: + R(r) == m[r] + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename value_type, + typename alloc + > + const matrix_exp mat ( + const std::vector& vector + ); + /*! + ensures + - returns a matrix R such that: + - is_col_vector(R) == true + - R.size() == vector.size() + - for all valid r: + R(r) == vector[r] + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename value_type, + typename alloc + > + const matrix_exp mat ( + const std_vector_c& vector + ); + /*! + ensures + - returns a matrix R such that: + - is_col_vector(R) == true + - R.size() == vector.size() + - for all valid r: + R(r) == vector[r] + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp mat ( + const T* ptr, + long nr + ); + /*! + requires + - nr >= 0 + - ptr == a pointer to at least nr T objects (or the NULL pointer if nr==0) + ensures + - returns a matrix M such that: + - M.nr() == nr + - m.nc() == 1 + - for all valid i: + M(i) == ptr[i] + - Note that the returned matrix doesn't take "ownership" of + the pointer and thus will not delete or free it. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp mat ( + const T* ptr, + long nr, + long nc + ); + /*! + requires + - nr >= 0 + - nc >= 0 + - ptr == a pointer to at least nr*nc T objects (or the NULL pointer if nr*nc==0) + ensures + - returns a matrix M such that: + - M.nr() == nr + - m.nc() == nc + - for all valid r and c: + M(r,c) == ptr[r*nc + c] + (i.e. the pointer is interpreted as a matrix laid out in memory + in row major order) + - Note that the returned matrix doesn't take "ownership" of + the pointer and thus will not delete or free it. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp mat ( + const T* ptr, + long nr, + long nc, + long stride + ); + /*! + requires + - nr >= 0 + - nc >= 0 + - stride > 0 + - ptr == a pointer to at least (nr-1)*stride+nc T objects (or the NULL pointer if nr*nc==0) + ensures + - returns a matrix M such that: + - M.nr() == nr + - m.nc() == nc + - for all valid r and c: + M(r,c) == ptr[r*stride + c] + (i.e. the pointer is interpreted as a matrix laid out in memory + in row major order, with a row stride of the given stride amount.) + - Note that the returned matrix doesn't take "ownership" of + the pointer and thus will not delete or free it. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp mat ( + const ::arma::Mat& m + ); + /*! + ensures + - Converts a matrix from the Armadillo library into a dlib matrix. + - returns a matrix R such that: + - R.nr() == m.n_rows + - R.nc() == m.n_cols + - for all valid r: + R(r,c) == m(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename _Scalar, + int _Rows, + int _Cols, + int _Options, + int _MaxRows, + int _MaxCols + > + const matrix_exp mat ( + const ::Eigen::Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols>& m + ); + /*! + ensures + - Converts a matrix from the Eigen library into a dlib matrix. + - returns a matrix R such that: + - R.nr() == m.rows() + - R.nc() == m.cols() + - for all valid r: + R(r,c) == m(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- + + matrix mat (double value); + matrix mat (float value); + matrix mat (long double value); + /*! + ensures + - Converts a scalar into a matrix containing just that scalar and returns the + results. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_MAT_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/matrix/matrix_math_functions.h b/ml/dlib/dlib/matrix/matrix_math_functions.h new file mode 100644 index 000000000..d1db3ed14 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_math_functions.h @@ -0,0 +1,448 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_MATH_FUNCTIONS +#define DLIB_MATRIx_MATH_FUNCTIONS + +#include "matrix_math_functions_abstract.h" +#include "matrix_op.h" +#include "matrix_utilities.h" +#include "matrix.h" +#include "../algs.h" +#include +#include +#include + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + DLIB_DEFINE_FUNCTION_M(op_sqrt, sqrt, std::sqrt ,7); + DLIB_DEFINE_FUNCTION_M(op_log, log, std::log ,7); + DLIB_DEFINE_FUNCTION_M(op_log10, log10, std::log10 ,7); + DLIB_DEFINE_FUNCTION_M(op_exp, exp, std::exp ,7); + + DLIB_DEFINE_FUNCTION_M(op_conj, conj, std::conj ,2); + + DLIB_DEFINE_FUNCTION_M(op_ceil, ceil, std::ceil ,7); + DLIB_DEFINE_FUNCTION_M(op_floor, floor, std::floor ,7); + + DLIB_DEFINE_FUNCTION_M(op_sin, sin, std::sin ,7); + DLIB_DEFINE_FUNCTION_M(op_cos, cos, std::cos ,7); + DLIB_DEFINE_FUNCTION_M(op_tan, tan, std::tan ,7); + DLIB_DEFINE_FUNCTION_M(op_sinh, sinh, std::sinh ,7); + DLIB_DEFINE_FUNCTION_M(op_cosh, cosh, std::cosh ,7); + DLIB_DEFINE_FUNCTION_M(op_tanh, tanh, std::tanh ,7); + DLIB_DEFINE_FUNCTION_M(op_asin, asin, std::asin ,7); + DLIB_DEFINE_FUNCTION_M(op_acos, acos, std::acos ,7); + DLIB_DEFINE_FUNCTION_M(op_atan, atan, std::atan ,7); + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline type sigmoid (const type& val) + { + return static_cast(1/(1 + std::exp(-val))); + } + + template + inline type round_zeros_eps (const type& val, const S& eps) + { + // you can only round matrices that contain built in scalar types like double, long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + if (val >= eps || val <= -eps) + return val; + else + return 0; + } + + template + inline type round_zeros (const type& val) + { + // you can only round matrices that contain built in scalar types like double, long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + const type eps = 10*std::numeric_limits::epsilon(); + if (val >= eps || val <= -eps) + return val; + else + return 0; + } + + template + inline type squared (const type& val) + { + return val*val; + } + + template + inline type sign (const type& val) + { + if (val >= 0) + return +1; + else + return -1; + } + + template + type cubed (const type& val) + { + return val*val*val; + } + + template + inline type pow1 (const type& val, const S& s) + { + // you can only call pow() on matrices that contain floats, doubles or long doubles. + COMPILE_TIME_ASSERT(( + is_same_type::value == true || + is_same_type::value == true || + is_same_type::value == true + )); + + return std::pow(val,static_cast(s)); + } + + template + inline type pow2 (const S& s, const type& val) + { + // you can only call pow() on matrices that contain floats, doubles or long doubles. + COMPILE_TIME_ASSERT(( + is_same_type::value == true || + is_same_type::value == true || + is_same_type::value == true + )); + + return std::pow(static_cast(s),val); + } + + template + inline type reciprocal (const type& val) + { + // you can only compute reciprocal matrices that contain floats, doubles or long doubles. + COMPILE_TIME_ASSERT(( + is_same_type::value == true || + is_same_type::value == true || + is_same_type::value == true || + is_same_type >::value == true || + is_same_type >::value == true || + is_same_type >::value == true + )); + + if (val != static_cast(0)) + return static_cast((type)1.0/val); + else + return 0; + } + + template + inline type reciprocal_max (const type& val) + { + // you can only compute reciprocal_max matrices that contain floats, doubles or long doubles. + COMPILE_TIME_ASSERT(( + is_same_type::value == true || + is_same_type::value == true || + is_same_type::value == true + )); + + if (val != static_cast(0)) + return static_cast((type)1.0/val); + else + return std::numeric_limits::max(); + } + + } + + DLIB_DEFINE_FUNCTION_M(op_sigmoid, sigmoid, impl::sigmoid, 7); + DLIB_DEFINE_FUNCTION_MS(op_round_zeros, round_zeros, impl::round_zeros_eps, 7); + DLIB_DEFINE_FUNCTION_M(op_round_zeros2, round_zeros, impl::round_zeros, 7); + DLIB_DEFINE_FUNCTION_M(op_cubed, cubed, impl::cubed, 7); + DLIB_DEFINE_FUNCTION_M(op_squared, squared, impl::squared, 6); + DLIB_DEFINE_FUNCTION_M(op_sign, sign, impl::sign, 6); + DLIB_DEFINE_FUNCTION_MS(op_pow1, pow, impl::pow1, 7); + DLIB_DEFINE_FUNCTION_SM(op_pow2, pow, impl::pow2, 7); + DLIB_DEFINE_FUNCTION_M(op_reciprocal, reciprocal, impl::reciprocal, 6); + DLIB_DEFINE_FUNCTION_M(op_reciprocal_max, reciprocal_max, impl::reciprocal_max, 6); + +// ---------------------------------------------------------------------------------------- + + template + struct op_round : basic_op_m + { + op_round( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost+7; + typedef typename M::type type; + typedef const typename M::type const_ret_type; + const_ret_type apply (long r, long c) const + { + return static_cast(std::floor(this->m(r,c)+0.5)); + } + }; + + template + struct op_round::is_integer>::type > + : basic_op_m + { + op_round( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + const_ret_type apply (long r, long c) const + { + return this->m(r,c); + } + }; + + template < + typename EXP + > + const matrix_op > round ( + const matrix_exp& m + ) + { + // you can only round matrices that contain built in scalar types like double, long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_round op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_normalize : basic_op_m + { + typedef typename M::type type; + + op_normalize( const M& m_, const type& s_) : basic_op_m(m_), s(s_){} + + const type s; + + const static long cost = M::cost+5; + typedef const typename M::type const_ret_type; + const_ret_type apply (long r, long c) const + { + return this->m(r,c)*s; + } + }; + + template < + typename EXP + > + const matrix_op > normalize ( + const matrix_exp& m + ) + { + // you can only compute normalized matrices that contain floats, doubles or long doubles. + COMPILE_TIME_ASSERT(( + is_same_type::value == true || + is_same_type::value == true || + is_same_type::value == true + )); + + + typedef op_normalize op; + typename EXP::type temp = std::sqrt(sum(squared(m))); + if (temp != 0.0) + temp = 1.0/temp; + + return matrix_op(op(m.ref(),temp)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_abs : basic_op_m + { + op_abs( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost+7; + typedef typename M::type type; + typedef const typename M::type const_ret_type; + const_ret_type apply ( long r, long c) const + { + return static_cast(std::abs(this->m(r,c))); + } + }; + + template + struct op_abs > : basic_op_m + { + op_abs( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost; + typedef T type; + typedef const T const_ret_type; + const_ret_type apply ( long r, long c) const + { + return static_cast(std::abs(this->m(r,c))); + } + }; + + template < + typename EXP + > + const matrix_op > abs ( + const matrix_exp& m + ) + { + typedef op_abs op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_complex_matrix : basic_op_m + { + op_complex_matrix( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost+1; + typedef std::complex type; + typedef const std::complex const_ret_type; + const_ret_type apply ( long r, long c) const + { + return type(this->m(r,c)); + } + }; + + template < + typename EXP + > + const matrix_op > complex_matrix ( + const matrix_exp& m + ) + { + typedef op_complex_matrix op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_complex_matrix2 : basic_op_mm + { + op_complex_matrix2( const M1& m1_, const M2& m2_) : basic_op_mm(m1_,m2_){} + + const static long cost = M1::cost+M2::cost+1; + typedef std::complex type; + typedef const std::complex const_ret_type; + + const_ret_type apply ( long r, long c) const + { return type(this->m1(r,c), this->m2(r,c)); } + }; + + template < + typename EXP1, + typename EXP2 + > + const matrix_op > complex_matrix ( + const matrix_exp& real_part, + const matrix_exp& imag_part + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0); + + DLIB_ASSERT(real_part.nr() == imag_part.nr() && + real_part.nc() == imag_part.nc(), + "\tconst matrix_exp::type complex_matrix(real_part, imag_part)" + << "\n\tYou can only make a complex matrix from two equally sized matrices" + << "\n\treal_part.nr(): " << real_part.nr() + << "\n\treal_part.nc(): " << real_part.nc() + << "\n\timag_part.nr(): " << imag_part.nr() + << "\n\timag_part.nc(): " << imag_part.nc() + ); + + typedef op_complex_matrix2 op; + return matrix_op(op(real_part.ref(),imag_part.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_norm : basic_op_m + { + op_norm( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost+6; + typedef typename M::type::value_type type; + typedef const typename M::type::value_type const_ret_type; + const_ret_type apply ( long r, long c) const + { return std::norm(this->m(r,c)); } + }; + + template < + typename EXP + > + const matrix_op > norm ( + const matrix_exp& m + ) + { + typedef op_norm op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_real : basic_op_m + { + op_real( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost; + typedef typename M::type::value_type type; + typedef const typename M::type::value_type const_ret_type; + const_ret_type apply ( long r, long c) const + { return std::real(this->m(r,c)); } + }; + + template < + typename EXP + > + const matrix_op > real ( + const matrix_exp& m + ) + { + typedef op_real op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_imag : basic_op_m + { + op_imag( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost; + typedef typename M::type::value_type type; + typedef const typename M::type::value_type const_ret_type; + const_ret_type apply (long r, long c) const + { return std::imag(this->m(r,c)); } + }; + + template < + typename EXP + > + const matrix_op > imag ( + const matrix_exp& m + ) + { + typedef op_imag op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_MATH_FUNCTIONS + diff --git a/ml/dlib/dlib/matrix/matrix_math_functions_abstract.h b/ml/dlib/dlib/matrix/matrix_math_functions_abstract.h new file mode 100644 index 000000000..09210270d --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_math_functions_abstract.h @@ -0,0 +1,595 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MATRIx_MATH_FUNCTIONS_ABSTRACT_ +#ifdef DLIB_MATRIx_MATH_FUNCTIONS_ABSTRACT_ + +#include "matrix_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Exponential Functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + const matrix_exp exp ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::exp(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp log10 ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::log10(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp log ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::log(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp sqrt ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == sqrt(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp pow ( + const matrix_exp& m, + const T& e + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == pow(m(r,c),e) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp pow ( + const T& b, + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == pow(b, m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp squared ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == m(r,c)*m(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp cubed ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == m(r,c)*m(r,c)*m(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Miscellaneous +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + const matrix_exp sign ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix that tells the sign of each element in m. In particular: + returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + - if (m(r,c) >= 0) then + - R(r,c) == +1 + - else + - R(r,c) == -1 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp sigmoid ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == 1/(1 + exp(-m(r,c))) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp abs ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - if (m contains std::complex objects) then + - R::type == T + - else + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::abs(m(r,c)) + (note that if m is complex then std::abs(val) performs sqrt(std::norm(val)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp reciprocal ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, long double, std::complex, + std::complex, or std::complex + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + - if (m(r,c) != 0) then + - R(r,c) == 1.0/m(r,c) + - else + - R(r,c) == 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp reciprocal_max ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + - if (m(r,c) != 0) then + - R(r,c) == 1.0/m(r,c) + - else + - R(r,c) == std::numeric_limits::max() + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp normalize ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - if (sqrt(sum(squared(m))) != 0) then + - returns m/sqrt(sum(squared(m))) + - else + - returns m + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Rounding numbers one way or another +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + const matrix_exp round ( + const matrix_exp& m + ); + /*! + requires + - is_built_in_scalar_type::value == true + (i.e. m must contain a type like int, float, double, long, etc.) + ensures + - if (m contains integers) then + - returns m unmodified + - else + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == m(r,c) rounded to the nearest integral value + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp ceil ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::ceil(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp floor ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::floor(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp round_zeros ( + const matrix_exp& m + ); + /*! + requires + - is_built_in_scalar_type::value == true + (i.e. m must contain a type like int, float, double, long, etc.) + ensures + - if (m contains integers) then + - returns m unmodified + - else + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - let eps == 10*std::numeric_limits::epsilon() + - for all valid r and c: + - if (abs(m(r,c)) >= eps) then + - R(r,c) == m(r,c) + - else + - R(r,c) == 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp round_zeros ( + const matrix_exp& m, + matrix_exp::type eps + ); + /*! + requires + - is_built_in_scalar_type::value == true + (i.e. m must contain a type like int, float, double, long, etc.) + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + - if (abs(m(r,c)) >= eps) then + - R(r,c) == m(r,c) + - else + - R(r,c) == 0 + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Complex number utility functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + const matrix_exp conj ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == std::complex + ensures + - returns a matrix R such that: + - R::type == std::complex + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::conj(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp norm ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == std::complex + ensures + - returns a matrix R such that: + - R::type == T + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::norm(m(r,c)) + (note that std::norm(val) == val.real()*val.real() + val.imag()*val.imag()) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp imag ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == std::complex + ensures + - returns a matrix R such that: + - R::type == T + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::imag(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp real ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == std::complex + ensures + - returns a matrix R such that: + - R::type == T + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::real(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp complex_matrix ( + const matrix_exp& real_part + ); + /*! + ensures + - returns a matrix R such that: + - R::type == std::complex where T is whatever type real_part used. + - R has the same dimensions as real_part. + - for all valid r and c: + R(r,c) == std::complex(real_part(r,c), 0) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp complex_matrix ( + const matrix_exp& real_part, + const matrix_exp& imag_part + ); + /*! + requires + - real_part.nr() == imag_part.nr() + - real_part.nc() == imag_part.nc() + - real_part and imag_part both contain the same type of element + ensures + - returns a matrix R such that: + - R::type == std::complex where T is whatever type real_part and imag_part used. + - R has the same dimensions as real_part and imag_part + - for all valid r and c: + R(r,c) == std::complex(real_part(r,c),imag_part(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Trigonometric Functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + const matrix_exp sin ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::sin(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp cos ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::cos(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp tan ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::tan(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp asin ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::asin(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp acos ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::acos(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp atan ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::atan(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp sinh ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::sinh(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp cosh ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::cosh(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp tanh ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == float, double, or long double + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R(r,c) == std::tanh(m(r,c)) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_MATH_FUNCTIONS_ABSTRACT_ + diff --git a/ml/dlib/dlib/matrix/matrix_op.h b/ml/dlib/dlib/matrix/matrix_op.h new file mode 100644 index 000000000..524a775eb --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_op.h @@ -0,0 +1,479 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_OP_H_ +#define DLIB_MATRIx_OP_H_ + +#include "matrix_exp.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + class matrix_op; + + template < typename OP > + struct matrix_traits > + { + typedef typename OP::type type; + typedef typename OP::const_ret_type const_ret_type; + typedef typename OP::mem_manager_type mem_manager_type; + typedef typename OP::layout_type layout_type; + const static long NR = OP::NR; + const static long NC = OP::NC; + const static long cost = OP::cost; + }; + + template < + typename OP + > + class matrix_op : public matrix_exp > + { + /*! + WHAT THIS OBJECT REPRESENTS + The matrix_op is simply a tool for reducing the amount of boilerplate + you need to write when creating matrix expressions. + !*/ + + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + typedef typename matrix_traits::layout_type layout_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + + private: + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of object. + template + matrix_op (T1); + public: + + matrix_op ( + const OP& op_ + ) : + op(op_) + {} + + const_ret_type operator() ( + long r, + long c + ) const { return op.apply(r,c); } + + const_ret_type operator() ( long i ) const + { return matrix_exp::operator()(i); } + + template + bool aliases ( + const matrix_exp& item + ) const { return op.aliases(item); } + + template + bool destructively_aliases ( + const matrix_exp& item + ) const { return op.destructively_aliases(item); } + + long nr ( + ) const { return op.nr(); } + + long nc ( + ) const { return op.nc(); } + + + const OP op; + }; + +// ---------------------------------------------------------------------------------------- + + template + class matrix_diag_op; + + template < typename OP > + struct matrix_traits > + { + typedef typename OP::type type; + typedef typename OP::const_ret_type const_ret_type; + typedef typename OP::mem_manager_type mem_manager_type; + typedef typename OP::layout_type layout_type; + const static long NR = OP::NR; + const static long NC = OP::NC; + const static long cost = OP::cost; + }; + + template < + typename OP + > + class matrix_diag_op : public matrix_diag_exp > + { + /*! + WHAT THIS OBJECT REPRESENTS + The matrix_diag_op is simply a tool for reducing the amount of boilerplate + you need to write when creating matrix expressions. + !*/ + + public: + typedef typename matrix_traits::type type; + typedef typename matrix_traits::const_ret_type const_ret_type; + typedef typename matrix_traits::mem_manager_type mem_manager_type; + typedef typename matrix_traits::layout_type layout_type; + const static long NR = matrix_traits::NR; + const static long NC = matrix_traits::NC; + const static long cost = matrix_traits::cost; + + private: + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of object. + template + matrix_diag_op (T1); + public: + + matrix_diag_op ( + const OP& op_ + ) : + op(op_) + {} + + const_ret_type operator() ( + long r, + long c + ) const { return op.apply(r,c); } + + const_ret_type operator() ( long i ) const + { return matrix_exp::operator()(i); } + + template + bool aliases ( + const matrix_exp& item + ) const { return op.aliases(item); } + + template + bool destructively_aliases ( + const matrix_exp& item + ) const { return op.destructively_aliases(item); } + + long nr ( + ) const { return op.nr(); } + + long nc ( + ) const { return op.nc(); } + + + const OP op; + }; + +// ---------------------------------------------------------------------------------------- + + struct does_not_alias + { + /*! + This is a partial implementation of a matrix operator that never aliases + another expression. + !*/ + + template bool aliases ( const U& ) const { return false; } + template bool destructively_aliases ( const U& ) const { return false; } + }; + +// ---------------------------------------------------------------------------------------- + + template + struct basic_op_m + { + /*! + This is a partial implementation of a matrix operator that preserves + the dimensions of its argument and doesn't have destructive aliasing. + !*/ + + private: + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of object. + template + basic_op_m (T1); + public: + + basic_op_m( + const M& m_ + ) : m(m_){} + + const M& m; + + const static long NR = M::NR; + const static long NC = M::NC; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + long nr () const { return m.nr(); } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m.destructively_aliases(item); } + + }; + +// ---------------------------------------------------------------------------------------- + + template + struct basic_op_mm + { + /*! + This is a partial implementation of a matrix operator that preserves + the dimensions of its arguments and doesn't have destructive aliasing. + !*/ + + private: + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of object. + template + basic_op_mm (T1, T2); + public: + + basic_op_mm( + const M1& m1_, + const M2& m2_ + ) : m1(m1_), m2(m2_){} + + const M1& m1; + const M2& m2; + + const static long NR = M1::NR; + const static long NC = M1::NC; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + + long nr () const { return m1.nr(); } + long nc () const { return m1.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.destructively_aliases(item) || m2.destructively_aliases(item); } + + }; + +// ---------------------------------------------------------------------------------------- + + template + struct basic_op_mmm + { + /*! + This is a partial implementation of a matrix operator that preserves + the dimensions of its arguments and doesn't have destructive aliasing. + !*/ + + private: + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of object. + template + basic_op_mmm (T1, T2, T3); + public: + + basic_op_mmm( + const M1& m1_, + const M2& m2_, + const M3& m3_ + ) : m1(m1_), m2(m2_), m3(m3_){} + + const M1& m1; + const M2& m2; + const M3& m3; + + const static long NR = M1::NR; + const static long NC = M1::NC; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + + long nr () const { return m1.nr(); } + long nc () const { return m1.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item) || m3.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.destructively_aliases(item) || m2.destructively_aliases(item) || + m3.destructively_aliases(item);} + + }; + +// ---------------------------------------------------------------------------------------- + + template + struct basic_op_mmmm + { + /*! + This is a partial implementation of a matrix operator that preserves + the dimensions of its arguments and doesn't have destructive aliasing. + !*/ + + private: + // This constructor exists simply for the purpose of causing a compile time error if + // someone tries to create an instance of this object with the wrong kind of object. + template + basic_op_mmmm (T1, T2, T3, T4); + public: + + basic_op_mmmm( + const M1& m1_, + const M2& m2_, + const M3& m3_, + const M4& m4_ + ) : m1(m1_), m2(m2_), m3(m3_), m4(m4_){} + + const M1& m1; + const M2& m2; + const M3& m3; + const M4& m4; + + const static long NR = M1::NR; + const static long NC = M1::NC; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + + long nr () const { return m1.nr(); } + long nc () const { return m1.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item) || m3.aliases(item) || m4.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.destructively_aliases(item) || m2.destructively_aliases(item) || + m3.destructively_aliases(item) || m4.destructively_aliases(item);} + + }; + +// ---------------------------------------------------------------------------------------- + +#define DLIB_DEFINE_OP_M(op_name, function, extra_cost) \ + template \ + struct op_name \ + { \ + op_name( \ + const M& m_ \ + ) : m(m_){} \ + \ + const M& m; \ + \ + const static long cost = M::cost+(extra_cost); \ + const static long NR = M::NR; \ + const static long NC = M::NC; \ + typedef typename M::type type; \ + typedef const typename M::type const_ret_type; \ + typedef typename M::mem_manager_type mem_manager_type; \ + typedef typename M::layout_type layout_type; \ + \ + const_ret_type apply (long r, long c) const { return function(m(r,c)); } \ + \ + long nr () const { return m.nr(); } \ + long nc () const { return m.nc(); } \ + \ + template bool aliases ( const matrix_exp& item) const \ + { return m.aliases(item); } \ + template bool destructively_aliases ( const matrix_exp& item) const \ + { return m.destructively_aliases(item); } \ + \ + } + +#define DLIB_DEFINE_FUNCTION_M(op_name, name, function, extra_cost) \ + DLIB_DEFINE_OP_M(op_name, function, extra_cost); \ + template < typename M > \ + const matrix_op > name ( const matrix_exp& m) \ + { \ + typedef op_name op; \ + return matrix_op(op(m.ref())); \ + } + +// ---------------------------------------------------------------------------------------- + +#define DLIB_DEFINE_OP_MS(op_name, function, extra_cost) \ + template \ + struct op_name \ + { \ + op_name( \ + const M& m_, \ + const S& s_ \ + ) : m(m_), s(s_){} \ + \ + const M& m; \ + const S s; \ + \ + const static long cost = M::cost+(extra_cost); \ + const static long NR = M::NR; \ + const static long NC = M::NC; \ + typedef typename M::type type; \ + typedef const typename M::type const_ret_type; \ + typedef typename M::mem_manager_type mem_manager_type; \ + typedef typename M::layout_type layout_type; \ + \ + const_ret_type apply (long r, long c) const { return function(m(r,c), s); } \ + \ + long nr () const { return m.nr(); } \ + long nc () const { return m.nc(); } \ + \ + template bool aliases ( const matrix_exp& item) const \ + { return m.aliases(item); } \ + template bool destructively_aliases ( const matrix_exp& item) const \ + { return m.destructively_aliases(item); } \ + \ + } + +#define DLIB_DEFINE_FUNCTION_MS(op_name, name, function, extra_cost) \ + DLIB_DEFINE_OP_MS(op_name, function, extra_cost); \ + template < typename M, typename S > \ + const matrix_op > name ( const matrix_exp& m, const S& s) \ + { \ + typedef op_name op; \ + return matrix_op(op(m.ref(), s)); \ + } + +// ---------------------------------------------------------------------------------------- + +#define DLIB_DEFINE_OP_SM(op_name, function, extra_cost) \ + template \ + struct op_name \ + { \ + op_name( \ + const S& s_, \ + const M& m_ \ + ) : m(m_), s(s_){} \ + \ + const M& m; \ + const S s; \ + \ + const static long cost = M::cost+(extra_cost); \ + const static long NR = M::NR; \ + const static long NC = M::NC; \ + typedef typename M::type type; \ + typedef const typename M::type const_ret_type; \ + typedef typename M::mem_manager_type mem_manager_type; \ + typedef typename M::layout_type layout_type; \ + \ + const_ret_type apply (long r, long c) const { return function(s, m(r,c)); } \ + \ + long nr () const { return m.nr(); } \ + long nc () const { return m.nc(); } \ + \ + template bool aliases ( const matrix_exp& item) const \ + { return m.aliases(item); } \ + template bool destructively_aliases ( const matrix_exp& item) const \ + { return m.destructively_aliases(item); } \ + \ + } + +#define DLIB_DEFINE_FUNCTION_SM(op_name, name, function, extra_cost) \ + DLIB_DEFINE_OP_SM(op_name, function, extra_cost); \ + template < typename S, typename M > \ + const matrix_op > name (const S& s, const matrix_exp& m) \ + { \ + typedef op_name op; \ + return matrix_op(op(s, m.ref())); \ + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_OP_H_ + diff --git a/ml/dlib/dlib/matrix/matrix_qr.h b/ml/dlib/dlib/matrix/matrix_qr.h new file mode 100644 index 000000000..086d481f1 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_qr.h @@ -0,0 +1,466 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +// This code was adapted from code from the JAMA part of NIST's TNT library. +// See: http://math.nist.gov/tnt/ +#ifndef DLIB_MATRIX_QR_DECOMPOSITION_H +#define DLIB_MATRIX_QR_DECOMPOSITION_H + +#include "matrix.h" +#include "matrix_utilities.h" +#include "matrix_subexp.h" + +#ifdef DLIB_USE_LAPACK +#include "lapack/geqrf.h" +#include "lapack/ormqr.h" +#endif + +#include "matrix_trsm.h" + +namespace dlib +{ + + template < + typename matrix_exp_type + > + class qr_decomposition + { + + public: + + const static long NR = matrix_exp_type::NR; + const static long NC = matrix_exp_type::NC; + typedef typename matrix_exp_type::type type; + typedef typename matrix_exp_type::mem_manager_type mem_manager_type; + typedef typename matrix_exp_type::layout_type layout_type; + + typedef matrix matrix_type; + + // You have supplied an invalid type of matrix_exp_type. You have + // to use this object with matrices that contain float or double type data. + COMPILE_TIME_ASSERT((is_same_type::value || + is_same_type::value )); + + + + template + qr_decomposition( + const matrix_exp& A + ); + + bool is_full_rank( + ) const; + + long nr( + ) const; + + long nc( + ) const; + + const matrix_type get_r ( + ) const; + + const matrix_type get_q ( + ) const; + + template + void get_q ( + matrix& Q + ) const; + + template + const matrix_type solve ( + const matrix_exp& B + ) const; + + private: + +#ifndef DLIB_USE_LAPACK + template + const matrix_type solve_mat ( + const matrix_exp& B + ) const; + + template + const matrix_type solve_vect ( + const matrix_exp& B + ) const; +#endif + + + /** Array for internal storage of decomposition. + @serial internal array storage. + */ + matrix QR_; + + /** Row and column dimensions. + @serial column dimension. + @serial row dimension. + */ + long m, n; + + /** Array for internal storage of diagonal of R. + @serial diagonal of R. + */ + typedef matrix column_vector_type; + column_vector_type tau; + column_vector_type Rdiag; + + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + template + qr_decomposition:: + qr_decomposition( + const matrix_exp& A + ) + { + COMPILE_TIME_ASSERT((is_same_type::value)); + + // make sure requires clause is not broken + DLIB_ASSERT(A.nr() >= A.nc() && A.size() > 0, + "\tqr_decomposition::qr_decomposition(A)" + << "\n\tInvalid inputs were given to this function" + << "\n\tA.nr(): " << A.nr() + << "\n\tA.nc(): " << A.nc() + << "\n\tA.size(): " << A.size() + << "\n\tthis: " << this + ); + + + QR_ = A; + m = A.nr(); + n = A.nc(); + +#ifdef DLIB_USE_LAPACK + + lapack::geqrf(QR_, tau); + Rdiag = diag(QR_); + +#else + Rdiag.set_size(n); + long i=0, j=0, k=0; + + // Main loop. + for (k = 0; k < n; k++) + { + // Compute 2-norm of k-th column without under/overflow. + type nrm = 0; + for (i = k; i < m; i++) + { + nrm = hypot(nrm,QR_(i,k)); + } + + if (nrm != 0.0) + { + // Form k-th Householder vector. + if (QR_(k,k) < 0) + { + nrm = -nrm; + } + for (i = k; i < m; i++) + { + QR_(i,k) /= nrm; + } + QR_(k,k) += 1.0; + + // Apply transformation to remaining columns. + for (j = k+1; j < n; j++) + { + type s = 0.0; + for (i = k; i < m; i++) + { + s += QR_(i,k)*QR_(i,j); + } + s = -s/QR_(k,k); + for (i = k; i < m; i++) + { + QR_(i,j) += s*QR_(i,k); + } + } + } + Rdiag(k) = -nrm; + } +#endif + } + +// ---------------------------------------------------------------------------------------- + + template + long qr_decomposition:: + nr ( + ) const + { + return m; + } + +// ---------------------------------------------------------------------------------------- + + template + long qr_decomposition:: + nc ( + ) const + { + return n; + } + +// ---------------------------------------------------------------------------------------- + + template + bool qr_decomposition:: + is_full_rank( + ) const + { + type eps = max(abs(Rdiag)); + if (eps != 0) + eps *= std::sqrt(std::numeric_limits::epsilon())/100; + else + eps = 1; // there is no max so just use 1 + + // check if any of the elements of Rdiag are effectively 0 + return min(abs(Rdiag)) > eps; + } + +// ---------------------------------------------------------------------------------------- + + template + const typename qr_decomposition::matrix_type qr_decomposition:: + get_r( + ) const + { + matrix_type R(n,n); + for (long i = 0; i < n; i++) + { + for (long j = 0; j < n; j++) + { + if (i < j) + { + R(i,j) = QR_(i,j); + } + else if (i == j) + { + R(i,j) = Rdiag(i); + } + else + { + R(i,j) = 0.0; + } + } + } + return R; + } + +// ---------------------------------------------------------------------------------------- + + template + const typename qr_decomposition::matrix_type qr_decomposition:: + get_q( + ) const + { + matrix_type Q; + get_q(Q); + return Q; + } + +// ---------------------------------------------------------------------------------------- + + template + template + void qr_decomposition:: + get_q( + matrix& X + ) const + { +#ifdef DLIB_USE_LAPACK + // Take only the first n columns of an identity matrix. This way + // X ends up being an m by n matrix. + X = colm(identity_matrix(m), range(0,n-1)); + + // Compute Y = Q*X + lapack::ormqr('L','N', QR_, tau, X); + +#else + long i=0, j=0, k=0; + + X.set_size(m,n); + for (k = n-1; k >= 0; k--) + { + for (i = 0; i < m; i++) + { + X(i,k) = 0.0; + } + X(k,k) = 1.0; + for (j = k; j < n; j++) + { + if (QR_(k,k) != 0) + { + type s = 0.0; + for (i = k; i < m; i++) + { + s += QR_(i,k)*X(i,j); + } + s = -s/QR_(k,k); + for (i = k; i < m; i++) + { + X(i,j) += s*QR_(i,k); + } + } + } + } +#endif + } + +// ---------------------------------------------------------------------------------------- + + template + template + const typename qr_decomposition::matrix_type qr_decomposition:: + solve( + const matrix_exp& B + ) const + { + COMPILE_TIME_ASSERT((is_same_type::value)); + + // make sure requires clause is not broken + DLIB_ASSERT(B.nr() == nr(), + "\tconst matrix_type qr_decomposition::solve(B)" + << "\n\tInvalid inputs were given to this function" + << "\n\tB.nr(): " << B.nr() + << "\n\tnr(): " << nr() + << "\n\tthis: " << this + ); + +#ifdef DLIB_USE_LAPACK + + using namespace blas_bindings; + matrix X(B); + // Compute Y = transpose(Q)*B + lapack::ormqr('L','T',QR_, tau, X); + // Solve R*X = Y; + triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, QR_, X, n); + + /* return n x nx portion of X */ + return subm(X,0,0,n,B.nc()); + +#else + // just call the right version of the solve function + if (B.nc() == 1) + return solve_vect(B); + else + return solve_mat(B); +#endif + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Private member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +#ifndef DLIB_USE_LAPACK + + template + template + const typename qr_decomposition::matrix_type qr_decomposition:: + solve_vect( + const matrix_exp& B + ) const + { + + column_vector_type x(B); + + // Compute Y = transpose(Q)*B + for (long k = 0; k < n; k++) + { + type s = 0.0; + for (long i = k; i < m; i++) + { + s += QR_(i,k)*x(i); + } + s = -s/QR_(k,k); + for (long i = k; i < m; i++) + { + x(i) += s*QR_(i,k); + } + } + // Solve R*X = Y; + for (long k = n-1; k >= 0; k--) + { + x(k) /= Rdiag(k); + for (long i = 0; i < k; i++) + { + x(i) -= x(k)*QR_(i,k); + } + } + + + /* return n x 1 portion of x */ + return colm(x,0,n); + } + +// ---------------------------------------------------------------------------------------- + + template + template + const typename qr_decomposition::matrix_type qr_decomposition:: + solve_mat( + const matrix_exp& B + ) const + { + const long nx = B.nc(); + matrix_type X(B); + long i=0, j=0, k=0; + + // Compute Y = transpose(Q)*B + for (k = 0; k < n; k++) + { + for (j = 0; j < nx; j++) + { + type s = 0.0; + for (i = k; i < m; i++) + { + s += QR_(i,k)*X(i,j); + } + s = -s/QR_(k,k); + for (i = k; i < m; i++) + { + X(i,j) += s*QR_(i,k); + } + } + } + // Solve R*X = Y; + for (k = n-1; k >= 0; k--) + { + for (j = 0; j < nx; j++) + { + X(k,j) /= Rdiag(k); + } + for (i = 0; i < k; i++) + { + for (j = 0; j < nx; j++) + { + X(i,j) -= X(k,j)*QR_(i,k); + } + } + } + + /* return n x nx portion of X */ + return subm(X,0,0,n,nx); + } + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_USE_LAPACK not defined + +} + +#endif // DLIB_MATRIX_QR_DECOMPOSITION_H + + + diff --git a/ml/dlib/dlib/matrix/matrix_read_from_istream.h b/ml/dlib/dlib/matrix/matrix_read_from_istream.h new file mode 100644 index 000000000..3aced3584 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_read_from_istream.h @@ -0,0 +1,108 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_READ_FROM_ISTREAM_H_h_ +#define DLIB_MATRIx_READ_FROM_ISTREAM_H_h_ + +#include "matrix.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline bool next_is_whitespace ( + std::istream& in + ) + { + return in.peek() == '\n' || + in.peek() == ' ' || + in.peek() == ',' || + in.peek() == '\t' || + in.peek() == '\r'; + } + } + + template + std::istream& operator>> ( + std::istream& in, + matrix& m + ) + { + using namespace dlib::impl; + long num_rows = 0; + std::vector buf; + buf.reserve(100); + + // eat any leading whitespace + while (next_is_whitespace(in)) + in.get(); + + bool at_start_of_line = true; + bool stop = false; + while(!stop && in.peek() != EOF) + { + T temp; + in >> temp; + if (!in) + return in; + + buf.push_back(temp); + if (at_start_of_line) + { + at_start_of_line = false; + ++num_rows; + } + + // Eat next block of whitespace but also note if we hit the start of the next + // line. + while (next_is_whitespace(in)) + { + if (at_start_of_line && in.peek() == '\n') + { + stop = true; + break; + } + + if (in.get() == '\n') + at_start_of_line = true; + } + } + + // It's an error for there to not be any matrix data in the input stream + if (num_rows == 0) + { + in.clear(in.rdstate() | std::ios::failbit); + return in; + } + + const long num_cols = buf.size()/num_rows; + // It's also an error if the sizes don't make sense. + if (num_rows*num_cols != (long)buf.size() || + (NR != 0 && NR != num_rows) || + (NC != 0 && NC != num_cols)) + { + in.clear(in.rdstate() | std::ios::failbit); + return in; + } + + + m = reshape(mat(buf),num_rows, buf.size()/num_rows); + + if (in.eof()) + { + // Clear the eof and fail bits since this is caused by peeking at the EOF. + // But in the current case, we have successfully read the matrix. + in.clear(in.rdstate() & (~(std::ios::eofbit | std::ios::failbit))); + } + return in; + } +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_MATRIx_READ_FROM_ISTREAM_H_h_ + diff --git a/ml/dlib/dlib/matrix/matrix_subexp.h b/ml/dlib/dlib/matrix/matrix_subexp.h new file mode 100644 index 000000000..668e57496 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_subexp.h @@ -0,0 +1,1566 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_SUBEXP_ +#define DLIB_MATRIx_SUBEXP_ + +#include "matrix_subexp_abstract.h" +#include "matrix_op.h" +#include "matrix.h" +#include "../geometry/rectangle.h" +#include "matrix_expressions.h" +#include "matrix_mat.h" + + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + const matrix_range_static_exp range ( + ) + { + COMPILE_TIME_ASSERT(inc > 0); + return matrix_range_static_exp(); + } + + template + const matrix_range_static_exp range ( + ) + { + return matrix_range_static_exp(); + } + + inline const matrix_range_exp range ( + long start, + long end + ) + { + return matrix_range_exp(start,end); + } + + inline const matrix_range_exp range ( + long start, + long inc, + long end + ) + { + DLIB_ASSERT(inc > 0, + "\tconst matrix_exp range(start, inc, end)" + << "\n\tInvalid inputs to this function" + << "\n\tstart: " << start + << "\n\tinc: " << inc + << "\n\tend: " << end + ); + + return matrix_range_exp(start,inc,end); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_subm + { + op_subm ( + const M& m_x, + const long& r_x, + const long& c_x, + const long& nr_x, + const long& nc_x + ) : m(m_x), r_(r_x), c_(c_x), nr_(nr_x), nc_(nc_x) { } + + const M& m; + const long r_; + const long c_; + const long nr_; + const long nc_; + + const static long cost = M::cost+1; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const static long NR = 0; + const static long NC = 0; + + const_ret_type apply ( long r, long c) const { return m(r+r_,c+c_); } + + long nr () const { return nr_; } + long nc () const { return nc_; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > subm ( + const matrix_exp& m, + long r, + long c, + long nr, + long nc + ) + { + DLIB_ASSERT(r >= 0 && c >= 0 && nr >= 0 && nc >= 0 && r+nr <= m.nr() && c+nc <= m.nc(), + "\tconst matrix_exp subm(const matrix_exp& m, r, c, nr, nc)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tr: " << r + << "\n\tc: " << c + << "\n\tnr: " << nr + << "\n\tnc: " << nc + ); + + typedef op_subm op; + return matrix_op(op(m.ref(),r,c,nr,nc)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix_op > subm_clipped ( + const matrix_exp& m, + long r, + long c, + long nr, + long nc + ) + { + rectangle box(c,r,c+nc-1,r+nr-1); + box = box.intersect(get_rect(m)); + typedef op_subm op; + return matrix_op(op(m.ref(),box.top(),box.left(),box.height(),box.width())); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix_op > subm ( + const matrix_exp& m, + const rectangle& rect + ) + { + DLIB_ASSERT(get_rect(m).contains(rect) == true, + "\tconst matrix_exp subm(const matrix_exp& m, const rectangle& rect)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\trect.left(): " << rect.left() + << "\n\trect.top(): " << rect.top() + << "\n\trect.right(): " << rect.right() + << "\n\trect.bottom(): " << rect.bottom() + ); + + typedef op_subm op; + return matrix_op(op(m.ref(),rect.top(),rect.left(),rect.height(),rect.width())); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix_op > subm_clipped ( + const matrix_exp& m, + rectangle rect + ) + { + rect = rect.intersect(get_rect(m)); + + typedef op_subm op; + return matrix_op(op(m.ref(),rect.top(),rect.left(),rect.height(),rect.width())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_subm_range + { + op_subm_range( const M1& m1_, const M2& rows_, const M3& cols_) : + m1(m1_), rows(rows_), cols(cols_) {} + const M1& m1; + const M2& rows; + const M3& cols; + + const static long cost = M1::cost+M2::cost+M3::cost; + typedef typename M1::type type; + typedef typename M1::const_ret_type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + const static long NR = M2::NC*M2::NR; + const static long NC = M3::NC*M3::NR; + + const_ret_type apply ( long r, long c) const { return m1(rows(r),cols(c)); } + + long nr () const { return rows.size(); } + long nc () const { return cols.size(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || rows.aliases(item) || cols.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.aliases(item) || rows.aliases(item) || cols.aliases(item); } + }; + + template < + typename EXP, + typename EXPr, + typename EXPc + > + const matrix_op > subm ( + const matrix_exp& m, + const matrix_exp& rows, + const matrix_exp& cols + ) + { + // the rows and cols matrices must contain integer elements + COMPILE_TIME_ASSERT(std::numeric_limits::is_integer); + COMPILE_TIME_ASSERT(std::numeric_limits::is_integer); + + DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && 0 <= min(cols) && max(cols) < m.nc() && + (rows.nr() == 1 || rows.nc() == 1) && (cols.nr() == 1 || cols.nc() == 1), + "\tconst matrix_exp subm(const matrix_exp& m, const matrix_exp& rows, const matrix_exp& cols)" + << "\n\tYou have given invalid arguments to this function" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tmin(rows): " << min(rows) + << "\n\tmax(rows): " << max(rows) + << "\n\tmin(cols): " << min(cols) + << "\n\tmax(cols): " << max(cols) + << "\n\trows.nr(): " << rows.nr() + << "\n\trows.nc(): " << rows.nc() + << "\n\tcols.nr(): " << cols.nr() + << "\n\tcols.nc(): " << cols.nc() + ); + + typedef op_subm_range op; + return matrix_op(op(m.ref(),rows.ref(),cols.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_rowm + { + op_rowm(const M& m_, const long& row_) : m(m_), row(row_) {} + const M& m; + const long row; + + const static long cost = M::cost; + const static long NR = 1; + const static long NC = M::NC; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long, long c) const { return m(row,c); } + + long nr () const { return 1; } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > rowm ( + const matrix_exp& m, + long row + ) + { + DLIB_ASSERT(row >= 0 && row < m.nr(), + "\tconst matrix_exp rowm(const matrix_exp& m, row)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\trow: " << row + ); + + typedef op_rowm op; + return matrix_op(op(m.ref(),row)); + } + + template + struct rowm_exp + { + typedef matrix_op > type; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct op_rowm2 + { + op_rowm2(const M& m_, const long& row_, const long& len) : m(m_), row(row_), length(len) {} + const M& m; + const long row; + const long length; + + const static long cost = M::cost; + const static long NR = 1; + const static long NC = 0; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long , long c) const { return m(row,c); } + + long nr () const { return 1; } + long nc () const { return length; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > rowm ( + const matrix_exp& m, + long row, + long length + ) + { + DLIB_ASSERT(row >= 0 && row < m.nr() && + length >= 0 && length <= m.nc(), + "\tconst matrix_exp rowm(const matrix_exp& m, row, length)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\trow: " << row + << "\n\tlength: " << length + ); + + typedef op_rowm2 op; + return matrix_op(op(m.ref(), row, length)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_rowm_range + { + op_rowm_range( const M1& m1_, const M2& rows_) : m1(m1_), rows(rows_) {} + const M1& m1; + const M2& rows; + + const static long cost = M1::cost+M2::cost; + typedef typename M1::type type; + typedef typename M1::const_ret_type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + const static long NR = M2::NC*M2::NR; + const static long NC = M1::NC; + + const_ret_type apply ( long r, long c) const { return m1(rows(r),c); } + + long nr () const { return rows.size(); } + long nc () const { return m1.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || rows.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.aliases(item) || rows.aliases(item); } + }; + + template < + typename EXP1, + typename EXP2 + > + const matrix_op > rowm ( + const matrix_exp& m, + const matrix_exp& rows + ) + { + // the rows matrix must contain integer elements + COMPILE_TIME_ASSERT(std::numeric_limits::is_integer); + + DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && (rows.nr() == 1 || rows.nc() == 1), + "\tconst matrix_exp rowm(const matrix_exp& m, const matrix_exp& rows)" + << "\n\tYou have given invalid arguments to this function" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tmin(rows): " << min(rows) + << "\n\tmax(rows): " << max(rows) + << "\n\trows.nr(): " << rows.nr() + << "\n\trows.nc(): " << rows.nc() + ); + + typedef op_rowm_range op; + return matrix_op(op(m.ref(),rows.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_colm + { + op_colm(const M& m_, const long& col_) : m(m_), col(col_) {} + const M& m; + const long col; + + const static long cost = M::cost; + const static long NR = M::NR; + const static long NC = 1; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long) const { return m(r,col); } + + long nr () const { return m.nr(); } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > colm ( + const matrix_exp& m, + long col + ) + { + DLIB_ASSERT(col >= 0 && col < m.nc(), + "\tconst matrix_exp colm(const matrix_exp& m, row)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tcol: " << col + ); + + typedef op_colm op; + return matrix_op(op(m.ref(),col)); + } + + template + struct colm_exp + { + typedef matrix_op > type; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct op_colm2 + { + op_colm2(const M& m_, const long& col_, const long& len) : m(m_), col(col_), length(len) {} + const M& m; + const long col; + const long length; + + const static long cost = M::cost; + const static long NR = 0; + const static long NC = 1; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long ) const { return m(r,col); } + + long nr () const { return length; } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > colm ( + const matrix_exp& m, + long col, + long length + ) + { + DLIB_ASSERT(col >= 0 && col < m.nc() && + length >= 0 && length <= m.nr(), + "\tconst matrix_exp colm(const matrix_exp& m, col, length)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tcol: " << col + << "\n\tlength: " << length + ); + + typedef op_colm2 op; + return matrix_op(op(m.ref(),col, length)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_colm_range + { + op_colm_range( const M1& m1_, const M2& cols_) : m1(m1_), cols(cols_) {} + const M1& m1; + const M2& cols; + + typedef typename M1::type type; + typedef typename M1::const_ret_type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + const static long NR = M1::NR; + const static long NC = M2::NC*M2::NR; + const static long cost = M1::cost+M2::cost; + + const_ret_type apply (long r, long c) const { return m1(r,cols(c)); } + + long nr () const { return m1.nr(); } + long nc () const { return cols.size(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || cols.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.aliases(item) || cols.aliases(item); } + }; + + template < + typename EXP1, + typename EXP2 + > + const matrix_op > colm ( + const matrix_exp& m, + const matrix_exp& cols + ) + { + // the rows matrix must contain integer elements + COMPILE_TIME_ASSERT(std::numeric_limits::is_integer); + + DLIB_ASSERT(0 <= min(cols) && max(cols) < m.nc() && (cols.nr() == 1 || cols.nc() == 1), + "\tconst matrix_exp colm(const matrix_exp& m, const matrix_exp& cols)" + << "\n\tYou have given invalid arguments to this function" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tmin(cols): " << min(cols) + << "\n\tmax(cols): " << max(cols) + << "\n\tcols.nr(): " << cols.nr() + << "\n\tcols.nc(): " << cols.nc() + ); + + typedef op_colm_range op; + return matrix_op(op(m.ref(),cols.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + class assignable_ptr_matrix + { + public: + typedef T type; + typedef row_major_layout layout_type; + typedef matrix matrix_type; + + assignable_ptr_matrix( + T* ptr_, + long nr_, + long nc_ + ) : ptr(ptr_), height(nr_), width(nc_){} + + T& operator() ( + long r, + long c + ) + { + return ptr[r*width + c]; + } + + const T& operator() ( + long r, + long c + ) const + { + return ptr[r*width + c]; + } + + long nr() const { return height; } + long nc() const { return width; } + + template + assignable_ptr_matrix& operator= ( + const matrix_exp& exp + ) + { + // You can only assign to a set_ptrm() expression with a source matrix that + // contains the same type of elements as the target (i.e. you can't mix double + // and float types). + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + DLIB_ASSERT( exp.nr() == height && exp.nc() == width, + "\tassignable_matrix_expression set_ptrm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\twidth (target matrix): " << width + << "\n\theight (target matrix): " << height + ); + + if (exp.destructively_aliases(mat(ptr,height,width)) == false) + { + matrix_assign(*this, exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to ptr to + // avoid aliasing issues during the copy + this->operator=(tmp(exp)); + } + + return *this; + } + + template + assignable_ptr_matrix& operator+= ( + const matrix_exp& exp + ) + { + // You can only assign to a set_ptrm() expression with a source matrix that + // contains the same type of elements as the target (i.e. you can't mix double + // and float types). + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + DLIB_ASSERT( exp.nr() == height && exp.nc() == width, + "\tassignable_matrix_expression set_ptrm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\twidth (target matrix): " << width + << "\n\theight (target matrix): " << height + ); + + if (exp.destructively_aliases(mat(ptr,height,width)) == false) + { + matrix_assign(*this, mat(ptr,height,width)+exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to ptr to + // avoid aliasing issues during the copy + this->operator+=(tmp(exp)); + } + + return *this; + } + + template + assignable_ptr_matrix& operator-= ( + const matrix_exp& exp + ) + { + // You can only assign to a set_ptrm() expression with a source matrix that + // contains the same type of elements as the target (i.e. you can't mix double + // and float types). + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + DLIB_ASSERT( exp.nr() == height && exp.nc() == width, + "\tassignable_matrix_expression set_ptrm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\twidth (target matrix): " << width + << "\n\theight (target matrix): " << height + ); + + if (exp.destructively_aliases(mat(ptr,height,width)) == false) + { + matrix_assign(*this, mat(ptr,height,width)-exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to ptr to + // avoid aliasing issues during the copy + this->operator-=(tmp(exp)); + } + + return *this; + } + + assignable_ptr_matrix& operator= ( + const T& value + ) + { + const long size = width*height; + for (long i = 0; i < size; ++i) + ptr[i] = value; + + return *this; + } + + assignable_ptr_matrix& operator+= ( + const T& value + ) + { + const long size = width*height; + for (long i = 0; i < size; ++i) + ptr[i] += value; + + return *this; + } + + assignable_ptr_matrix& operator-= ( + const T& value + ) + { + const long size = width*height; + for (long i = 0; i < size; ++i) + ptr[i] -= value; + + return *this; + } + + + T* ptr; + const long height; + const long width; + }; + + + template + assignable_ptr_matrix set_ptrm ( + T* ptr, + long nr, + long nc = 1 + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0, + "\t assignable_matrix_expression set_ptrm(T* ptr, long nr, long nc)" + << "\n\t The dimensions can't be negative." + << "\n\t nr: " << nr + << "\n\t nc: " << nc + ); + + + return assignable_ptr_matrix(ptr,nr,nc); + } + +// ---------------------------------------------------------------------------------------- + + template + class assignable_sub_matrix + { + public: + typedef T type; + typedef l layout_type; + typedef matrix matrix_type; + + assignable_sub_matrix( + matrix& m_, + long top_, + long left_, + long height_, + long width_ + ) : m(m_), left(left_), top(top_), width(width_), height(height_) {} + + T& operator() ( + long r, + long c + ) + { + return m(r+top,c+left); + } + + const T& operator() ( + long r, + long c + ) const + { + return m(r+top,c+left); + } + + long nr() const { return height; } + long nc() const { return width; } + + template + assignable_sub_matrix& operator= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nr() == height && exp.nc() == width, + "\tassignable_matrix_expression set_subm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\twidth (target matrix): " << width + << "\n\theight (target matrix): " << height + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator=(tmp(exp)); + } + + return *this; + } + + template + assignable_sub_matrix& operator+= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nr() == height && exp.nc() == width, + "\tassignable_matrix_expression set_subm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\twidth (target matrix): " << width + << "\n\theight (target matrix): " << height + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, subm(m,top,left,height,width)+exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator+=(tmp(exp)); + } + + return *this; + } + + template + assignable_sub_matrix& operator-= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nr() == height && exp.nc() == width, + "\tassignable_matrix_expression set_subm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\twidth (target matrix): " << width + << "\n\theight (target matrix): " << height + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, subm(m,top,left,height,width)-exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator-=(tmp(exp)); + } + + return *this; + } + + assignable_sub_matrix& operator= ( + const T& value + ) + { + const long bottom = top+height-1; + const long right = left+width-1; + for (long r = top; r <= bottom; ++r) + { + for (long c = left; c <= right; ++c) + { + m(r,c) = value; + } + } + + return *this; + } + + assignable_sub_matrix& operator+= ( + const T& value + ) + { + const long bottom = top+height-1; + const long right = left+width-1; + for (long r = top; r <= bottom; ++r) + { + for (long c = left; c <= right; ++c) + { + m(r,c) += value; + } + } + + return *this; + } + + assignable_sub_matrix& operator-= ( + const T& value + ) + { + const long bottom = top+height-1; + const long right = left+width-1; + for (long r = top; r <= bottom; ++r) + { + for (long c = left; c <= right; ++c) + { + m(r,c) -= value; + } + } + + return *this; + } + + + matrix& m; + const long left, top, width, height; + }; + + + template + assignable_sub_matrix set_subm ( + matrix& m, + const rectangle& rect + ) + { + DLIB_ASSERT(get_rect(m).contains(rect) == true, + "\tassignable_matrix_expression set_subm(matrix& m, const rectangle& rect)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\trect.left(): " << rect.left() + << "\n\trect.top(): " << rect.top() + << "\n\trect.right(): " << rect.right() + << "\n\trect.bottom(): " << rect.bottom() + ); + + + return assignable_sub_matrix(m,rect.top(), rect.left(), rect.height(), rect.width()); + } + + + template + assignable_sub_matrix set_subm ( + matrix& m, + long r, + long c, + long nr, + long nc + ) + { + DLIB_ASSERT(r >= 0 && c >= 0 && nr >= 0 && nc >= 0 && r+nr <= m.nr() && c+nc <= m.nc(), + "\tassignable_matrix_expression set_subm(matrix& m, r, c, nr, nc)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tr: " << r + << "\n\tc: " << c + << "\n\tnr: " << nr + << "\n\tnc: " << nc + ); + + return assignable_sub_matrix(m,r,c, nr, nc); + } + +// ---------------------------------------------------------------------------------------- + + template + class assignable_sub_range_matrix + { + public: + typedef T type; + typedef l layout_type; + typedef matrix matrix_type; + + assignable_sub_range_matrix( + matrix& m_, + const EXPr& rows_, + const EXPc& cols_ + ) : m(m_), rows(rows_), cols(cols_) {} + + T& operator() ( + long r, + long c + ) + { + return m(rows(r),cols(c)); + } + + long nr() const { return rows.size(); } + long nc() const { return cols.size(); } + + + template + assignable_sub_range_matrix& operator= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nr() == rows.size() && exp.nc() == cols.size(), + "\tassignable_matrix_expression set_subm(matrix& m, const matrix_exp rows, const matrix_exp cols)" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\trows.size() (target matrix): " << rows.size() + << "\n\tcols.size() (target matrix): " << cols.size() + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator=(tmp(exp)); + } + + return *this; + } + + template + assignable_sub_range_matrix& operator+= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nr() == rows.size() && exp.nc() == cols.size(), + "\tassignable_matrix_expression set_subm(matrix& m, const matrix_exp rows, const matrix_exp cols)" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\trows.size() (target matrix): " << rows.size() + << "\n\tcols.size() (target matrix): " << cols.size() + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, subm(m,rows,cols)+exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator+=(tmp(exp)); + } + + return *this; + } + + template + assignable_sub_range_matrix& operator-= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nr() == rows.size() && exp.nc() == cols.size(), + "\tassignable_matrix_expression set_subm(matrix& m, const matrix_exp rows, const matrix_exp cols)" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\trows.size() (target matrix): " << rows.size() + << "\n\tcols.size() (target matrix): " << cols.size() + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, subm(m,rows,cols)-exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator-=(tmp(exp)); + } + + return *this; + } + + assignable_sub_range_matrix& operator= ( + const T& value + ) + { + for (long r = 0; r < rows.size(); ++r) + { + for (long c = 0; c < cols.size(); ++c) + { + m(rows(r),cols(c)) = value; + } + } + + return *this; + } + + assignable_sub_range_matrix& operator+= ( + const T& value + ) + { + for (long r = 0; r < rows.size(); ++r) + { + for (long c = 0; c < cols.size(); ++c) + { + m(rows(r),cols(c)) += value; + } + } + + return *this; + } + + assignable_sub_range_matrix& operator-= ( + const T& value + ) + { + for (long r = 0; r < rows.size(); ++r) + { + for (long c = 0; c < cols.size(); ++c) + { + m(rows(r),cols(c)) -= value; + } + } + + return *this; + } + + private: + + matrix& m; + const EXPr rows; + const EXPc cols; + }; + + template + assignable_sub_range_matrix set_subm ( + matrix& m, + const matrix_exp& rows, + const matrix_exp& cols + ) + { + DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && 0 <= min(cols) && max(cols) < m.nc() && + (rows.nr() == 1 || rows.nc() == 1) && (cols.nr() == 1 || cols.nc() == 1), + "\tassignable_matrix_expression set_subm(matrix& m, const matrix_exp& rows, const matrix_exp& cols)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tmin(rows): " << min(rows) + << "\n\tmax(rows): " << max(rows) + << "\n\tmin(cols): " << min(cols) + << "\n\tmax(cols): " << max(cols) + << "\n\trows.nr(): " << rows.nr() + << "\n\trows.nc(): " << rows.nc() + << "\n\tcols.nr(): " << cols.nr() + << "\n\tcols.nc(): " << cols.nc() + ); + + return assignable_sub_range_matrix(m,rows.ref(),cols.ref()); + } + +// ---------------------------------------------------------------------------------------- + + template + assignable_sub_range_matrix > set_rowm ( + matrix& m, + const matrix_exp& rows + ) + { + DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && (rows.nr() == 1 || rows.nc() == 1), + "\tassignable_matrix_expression set_rowm(matrix& m, const matrix_exp& rows)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tmin(rows): " << min(rows) + << "\n\tmax(rows): " << max(rows) + << "\n\trows.nr(): " << rows.nr() + << "\n\trows.nc(): " << rows.nc() + ); + + return assignable_sub_range_matrix >(m,rows.ref(),range(0,m.nc()-1)); + } + +// ---------------------------------------------------------------------------------------- + + template + assignable_sub_range_matrix,EXPc > set_colm ( + matrix& m, + const matrix_exp& cols + ) + { + DLIB_ASSERT(0 <= min(cols) && max(cols) < m.nc() && (cols.nr() == 1 || cols.nc() == 1), + "\tassignable_matrix_expression set_colm(matrix& m, const matrix_exp& cols)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tmin(cols): " << min(cols) + << "\n\tmax(cols): " << max(cols) + << "\n\tcols.nr(): " << cols.nr() + << "\n\tcols.nc(): " << cols.nc() + ); + + return assignable_sub_range_matrix,EXPc >(m,range(0,m.nr()-1),cols.ref()); + } + +// ---------------------------------------------------------------------------------------- + + template + class assignable_col_matrix + { + public: + typedef T type; + typedef l layout_type; + typedef matrix matrix_type; + + assignable_col_matrix( + matrix& m_, + const long col_ + ) : m(m_), col(col_) {} + + T& operator() ( + long r, + long + ) + { + return m(r,col); + } + + const T& operator() ( + long r, + long + ) const + { + return m(r,col); + } + + long nr() const { return m.nr(); } + long nc() const { return 1; } + + template + assignable_col_matrix& operator= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nc() == 1 && exp.nr() == m.nr(), + "\tassignable_matrix_expression set_colm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\tm.nr() (target matrix): " << m.nr() + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator=(tmp(exp)); + } + + return *this; + } + + template + assignable_col_matrix& operator+= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nc() == 1 && exp.nr() == m.nr(), + "\tassignable_matrix_expression set_colm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\tm.nr() (target matrix): " << m.nr() + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, colm(m,col)+exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator+=(tmp(exp)); + } + + return *this; + } + + template + assignable_col_matrix& operator-= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nc() == 1 && exp.nr() == m.nr(), + "\tassignable_matrix_expression set_colm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\tm.nr() (target matrix): " << m.nr() + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, colm(m,col)-exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator-=(tmp(exp)); + } + + return *this; + } + + assignable_col_matrix& operator= ( + const T& value + ) + { + for (long i = 0; i < m.nr(); ++i) + { + m(i,col) = value; + } + + return *this; + } + + assignable_col_matrix& operator+= ( + const T& value + ) + { + for (long i = 0; i < m.nr(); ++i) + { + m(i,col) += value; + } + + return *this; + } + + assignable_col_matrix& operator-= ( + const T& value + ) + { + for (long i = 0; i < m.nr(); ++i) + { + m(i,col) -= value; + } + + return *this; + } + + + matrix& m; + const long col; + }; + + + template + assignable_col_matrix set_colm ( + matrix& m, + const long col + ) + { + DLIB_ASSERT(col >= 0 && col < m.nc(), + "\tassignable_matrix_expression set_colm(matrix& m, col)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tcol: " << col + ); + + + return assignable_col_matrix(m,col); + } + +// ---------------------------------------------------------------------------------------- + + + template + class assignable_row_matrix + { + public: + typedef T type; + typedef l layout_type; + typedef matrix matrix_type; + + assignable_row_matrix( + matrix& m_, + const long row_ + ) : m(m_), row(row_) {} + + + T& operator() ( + long , + long c + ) + { + return m(row,c); + } + + const T& operator() ( + long , + long c + ) const + { + return m(row,c); + } + + long nr() const { return 1; } + long nc() const { return m.nc(); } + + + template + assignable_row_matrix& operator= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nr() == 1 && exp.nc() == m.nc(), + "\tassignable_matrix_expression set_rowm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\tm.nc() (target matrix): " << m.nc() + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator=(tmp(exp)); + } + + return *this; + } + + template + assignable_row_matrix& operator+= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nr() == 1 && exp.nc() == m.nc(), + "\tassignable_matrix_expression set_rowm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\tm.nc() (target matrix): " << m.nc() + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, rowm(m,row)+exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator+=(tmp(exp)); + } + + return *this; + } + + template + assignable_row_matrix& operator-= ( + const matrix_exp& exp + ) + { + DLIB_ASSERT( exp.nr() == 1 && exp.nc() == m.nc(), + "\tassignable_matrix_expression set_rowm()" + << "\n\tYou have tried to assign to this object using a matrix that isn't the right size" + << "\n\texp.nr() (source matrix): " << exp.nr() + << "\n\texp.nc() (source matrix): " << exp.nc() + << "\n\tm.nc() (target matrix): " << m.nc() + ); + + if (exp.destructively_aliases(m) == false) + { + matrix_assign(*this, rowm(m,row)-exp); + } + else + { + // make a temporary copy of the matrix we are going to assign to m to + // avoid aliasing issues during the copy + this->operator-=(tmp(exp)); + } + + return *this; + } + + assignable_row_matrix& operator= ( + const T& value + ) + { + for (long i = 0; i < m.nc(); ++i) + { + m(row,i) = value; + } + + return *this; + } + + assignable_row_matrix& operator+= ( + const T& value + ) + { + for (long i = 0; i < m.nc(); ++i) + { + m(row,i) += value; + } + + return *this; + } + + assignable_row_matrix& operator-= ( + const T& value + ) + { + for (long i = 0; i < m.nc(); ++i) + { + m(row,i) -= value; + } + + return *this; + } + + + matrix& m; + const long row; + }; + + + template + assignable_row_matrix set_rowm ( + matrix& m, + const long row + ) + { + DLIB_ASSERT(row >= 0 && row < m.nr(), + "\tassignable_matrix_expression set_rowm(matrix& m, row)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\trow: " << row + ); + + + return assignable_row_matrix(m,row); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_SUBEXP_ + diff --git a/ml/dlib/dlib/matrix/matrix_subexp_abstract.h b/ml/dlib/dlib/matrix/matrix_subexp_abstract.h new file mode 100644 index 000000000..2665d1b99 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_subexp_abstract.h @@ -0,0 +1,570 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MATRIx_SUBEXP_ABSTRACT_ +#ifdef DLIB_MATRIx_SUBEXP_ABSTRACT_ + +#include "matrix_abstract.h" +#include "../geometry/rectangle.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp range ( + ); + /*! + requires + - inc > 0 + ensures + - returns a matrix R such that: + - R::type == long + - R.nr() == 1 + - R.nc() == abs(end - start)/inc + 1 + - if (start <= end) then + - R(i) == start + i*inc + - else + - R(i) == start - i*inc + !*/ + + template + const matrix_exp range ( + ) { return range(); } + + const matrix_exp range ( + long start, + long inc, + long end + ); + /*! + requires + - inc > 0 + ensures + - returns a matrix R such that: + - R::type == long + - R.nr() == 1 + - R.nc() == abs(end - start)/inc + 1 + - if (start <= end) then + - R(i) == start + i*inc + - else + - R(i) == start - i*inc + !*/ + + const matrix_exp range ( + long start, + long end + ) { return range(start,1,end); } + +// ---------------------------------------------------------------------------------------- + + const matrix_exp subm ( + const matrix_exp& m, + const matrix_exp& rows, + const matrix_exp& cols, + ); + /*! + requires + - rows and cols contain integral elements (e.g. int, long) + - 0 <= min(rows) && max(rows) < m.nr() + - 0 <= min(cols) && max(cols) < m.nc() + - rows.nr() == 1 || rows.nc() == 1 + - cols.nr() == 1 || cols.nc() == 1 + (i.e. rows and cols must be vectors) + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R.nr() == rows.size() + - R.nc() == cols.size() + - for all valid r and c: + R(r,c) == m(rows(r),cols(c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp subm ( + const matrix_exp& m, + long row, + long col, + long nr, + long nc + ); + /*! + requires + - row >= 0 + - col >= 0 + - nr >= 0 + - nc >= 0 + - row + nr <= m.nr() + - col + nc <= m.nc() + ensures + - returns a matrix R such that: + - R.nr() == nr + - R.nc() == nc + - for all valid r and c: + R(r, c) == m(r+row,c+col) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp subm ( + const matrix_exp& m, + const rectangle& rect + ); + /*! + requires + - get_rect(m).contains(rect) == true + (i.e. rect is a region inside the matrix m) + ensures + - returns a matrix R such that: + - R.nr() == rect.height() + - R.nc() == rect.width() + - for all valid r and c: + R(r, c) == m(r+rect.top(), c+rect.left()) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp subm_clipped ( + const matrix_exp& m, + long row, + long col, + long nr, + long nc + ); + /*! + ensures + - This function is just like subm() except that it will automatically clip the + indicated sub matrix window so that it does not extend outside m. + In particular: + - Let box = rectangle(col,row,col+nc-1,row+nr-1) + (i.e. the box that contains the indicated sub matrix) + - Let box_clipped = box.intersect(get_rect(m)) + - Then this function returns a matrix R such that: + - R.nr() == box_clipped.height() + - R.nc() == box_clipped.width() + - for all valid r and c: + R(r, c) == m(r+box_clipped.top(),c+box_clipped.left()) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp subm_clipped ( + const matrix_exp& m, + const rectangle& rect + ); + /*! + ensures + - Let box_clipped == rect.intersect(get_rect(m)) + - returns a matrix R such that: + - R.nr() == box_clipped.height() + - R.nc() == box_clipped.width() + - for all valid r and c: + R(r, c) == m(r+box_clipped.top(), c+box_clipped.left()) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp rowm ( + const matrix_exp& m, + long row + ); + /*! + requires + - 0 <= row < m.nr() + ensures + - returns a matrix R such that: + - R.nr() == 1 + - R.nc() == m.nc() + - for all valid i: + R(i) == m(row,i) + !*/ + + template + struct rowm_exp + { + /*! + WHAT THIS OBJECT REPRESENTS + This struct allows you to determine the type of matrix expression + object returned from the rowm(m,row) function. An example makes its + use clear: + + template + void do_something( const matrix_exp& mat) + { + // r is a matrix expression that aliases mat. + typename rowm_exp::type r = rowm(mat,0); + + // Print the first row of mat. So we see that by using + // rowm_exp we can save the object returned by rowm() in + // a local variable. + cout << r << endl; + + // Note that you can only save the return value of rowm() to + // a local variable if the argument to rowm() has a lifetime + // beyond the rowm() expression. The example shown above is + // OK but the following would result in undefined behavior: + typename rowm_exp::type bad = rowm(mat + mat,0); + } + !*/ + typedef type_of_expression_returned_by_rowm type; + }; + +// ---------------------------------------------------------------------------------------- + + const matrix_exp rowm ( + const matrix_exp& m, + long row, + long length + ); + /*! + requires + - 0 <= row < m.nr() + - 0 <= length <= m.nc() + ensures + - returns a matrix R such that: + - R.nr() == 1 + - R.nc() == length + - for all valid i: + R(i) == m(row,i) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp rowm ( + const matrix_exp& m, + const matrix_exp& rows + ); + /*! + requires + - rows contains integral elements (e.g. int, long) + - 0 <= min(rows) && max(rows) < m.nr() + - rows.nr() == 1 || rows.nc() == 1 + (i.e. rows must be a vector) + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R.nr() == rows.size() + - R.nc() == m.nc() + - for all valid r and c: + R(r,c) == m(rows(r),c) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp colm ( + const matrix_exp& m, + long col + ); + /*! + requires + - 0 <= col < m.nc() + ensures + - returns a matrix R such that: + - R.nr() == m.nr() + - R.nc() == 1 + - for all valid i: + R(i) == m(i,col) + !*/ + + template + struct colm_exp + { + /*! + WHAT THIS OBJECT REPRESENTS + This struct allows you to determine the type of matrix expression + object returned from the colm(m,col) function. An example makes its + use clear: + + template + void do_something( const matrix_exp& mat) + { + // c is a matrix expression that aliases mat. + typename colm_exp::type c = colm(mat,0); + + // Print the first column of mat. So we see that by using + // colm_exp we can save the object returned by colm() in + // a local variable. + cout << c << endl; + + // Note that you can only save the return value of colm() to + // a local variable if the argument to colm() has a lifetime + // beyond the colm() expression. The example shown above is + // OK but the following would result in undefined behavior: + typename colm_exp::type bad = colm(mat + mat,0); + } + !*/ + typedef type_of_expression_returned_by_colm type; + }; + +// ---------------------------------------------------------------------------------------- + + const matrix_exp colm ( + const matrix_exp& m, + long col, + long length + ); + /*! + requires + - 0 <= col < m.nc() + - 0 <= length <= m.nr() + ensures + - returns a matrix R such that: + - R.nr() == length + - R.nc() == 1 + - for all valid i: + R(i) == m(i,col) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp colm ( + const matrix_exp& m, + const matrix_exp& cols + ); + /*! + requires + - cols contains integral elements (e.g. int, long) + - 0 <= min(cols) && max(cols) < m.nc() + - cols.nr() == 1 || cols.nc() == 1 + (i.e. cols must be a vector) + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R.nr() == m.nr() + - R.nc() == cols.size() + - for all valid r and c: + R(r,c) == m(r,cols(c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + assignable_matrix_expression set_ptrm ( + T* ptr, + long nr, + long nc = 1 + ); + /*! + requires + - ptr == a pointer to nr*nc elements of type T + - nr >= 0 + - nc >= 0 + ensures + - statements of the following form: + - set_ptrm(ptr,nr,nc) = some_matrix; + result in it being the case that: + - mat(ptr,nr,nc) == some_matrix. + + - statements of the following form: + - set_ptrm(ptr,nr,nc) = scalar_value; + result in it being the case that: + - mat(ptr,nr,nc) == uniform_matrix(nr,nc,scalar_value). + + - In addition to the normal assignment statements using the = symbol, you may + also use the usual += and -= versions of the assignment operator. In these + cases, they have their usual effect. + !*/ + +// ---------------------------------------------------------------------------------------- + + assignable_matrix_expression set_subm ( + matrix& m, + long row, + long col, + long nr, + long nc + ); + /*! + requires + - row >= 0 + - col >= 0 + - nr >= 0 + - nc >= 0 + - row + nr <= m.nr() + - col + nc <= m.nc() + ensures + - statements of the following form: + - set_subm(m,row,col,nr,nc) = some_matrix; + result in it being the case that: + - subm(m,row,col,nr,nc) == some_matrix. + + - statements of the following form: + - set_subm(m,row,col,nr,nc) = scalar_value; + result in it being the case that: + - subm(m,row,col,nr,nc) == uniform_matrix(nr,nc,scalar_value). + + - In addition to the normal assignment statements using the = symbol, you may + also use the usual += and -= versions of the assignment operator. In these + cases, they have their usual effect. + !*/ + +// ---------------------------------------------------------------------------------------- + + assignable_matrix_expression set_subm ( + matrix& m, + const rectangle& rect + ); + /*! + requires + - get_rect(m).contains(rect) == true + (i.e. rect is a region inside the matrix m) + ensures + - statements of the following form: + - set_subm(m,rect) = some_matrix; + result in it being the case that: + - subm(m,rect) == some_matrix. + + - statements of the following form: + - set_subm(m,rect) = scalar_value; + result in it being the case that: + - subm(m,rect) == uniform_matrix(nr,nc,scalar_value). + + - In addition to the normal assignment statements using the = symbol, you may + also use the usual += and -= versions of the assignment operator. In these + cases, they have their usual effect. + !*/ + +// ---------------------------------------------------------------------------------------- + + assignable_matrix_expression set_subm ( + matrix& m, + const matrix_exp& rows, + const matrix_exp& cols + ); + /*! + requires + - rows and cols contain integral elements (e.g. int, long) + - 0 <= min(rows) && max(rows) < m.nr() + - 0 <= min(cols) && max(cols) < m.nc() + - rows.nr() == 1 || rows.nc() == 1 + - cols.nr() == 1 || cols.nc() == 1 + (i.e. rows and cols must be vectors) + ensures + - statements of the following form: + - set_subm(m,rows,cols) = some_matrix; + result in it being the case that: + - subm(m,rows,cols) == some_matrix. + + - statements of the following form: + - set_subm(m,rows,cols) = scalar_value; + result in it being the case that: + - subm(m,rows,cols) == uniform_matrix(nr,nc,scalar_value). + + - In addition to the normal assignment statements using the = symbol, you may + also use the usual += and -= versions of the assignment operator. In these + cases, they have their usual effect. + !*/ + +// ---------------------------------------------------------------------------------------- + + assignable_matrix_expression set_rowm ( + matrix& m, + long row + ); + /*! + requires + - 0 <= row < m.nr() + ensures + - statements of the following form: + - set_rowm(m,row) = some_matrix; + result in it being the case that: + - rowm(m,row) == some_matrix. + + - statements of the following form: + - set_rowm(m,row) = scalar_value; + result in it being the case that: + - rowm(m,row) == uniform_matrix(1,nc,scalar_value). + + - In addition to the normal assignment statements using the = symbol, you may + also use the usual += and -= versions of the assignment operator. In these + cases, they have their usual effect. + !*/ + +// ---------------------------------------------------------------------------------------- + + assignable_matrix_expression set_rowm ( + matrix& m, + const matrix_exp& rows + ); + /*! + requires + - rows contains integral elements (e.g. int, long) + - 0 <= min(rows) && max(rows) < m.nr() + - rows.nr() == 1 || rows.nc() == 1 + (i.e. rows must be a vector) + ensures + - statements of the following form: + - set_rowm(m,rows) = some_matrix; + result in it being the case that: + - rowm(m,rows) == some_matrix. + + - statements of the following form: + - set_rowm(m,rows) = scalar_value; + result in it being the case that: + - rowm(m,rows) == uniform_matrix(nr,nc,scalar_value). + + - In addition to the normal assignment statements using the = symbol, you may + also use the usual += and -= versions of the assignment operator. In these + cases, they have their usual effect. + !*/ + +// ---------------------------------------------------------------------------------------- + + assignable_matrix_expression set_colm ( + matrix& m, + long col + ); + /*! + requires + - 0 <= col < m.nr() + ensures + - statements of the following form: + - set_colm(m,col) = some_matrix; + result in it being the case that: + - colm(m,col) == some_matrix. + + - statements of the following form: + - set_colm(m,col) = scalar_value; + result in it being the case that: + - colm(m,col) == uniform_matrix(nr,1,scalar_value). + + - In addition to the normal assignment statements using the = symbol, you may + also use the usual += and -= versions of the assignment operator. In these + cases, they have their usual effect. + !*/ + +// ---------------------------------------------------------------------------------------- + + assignable_matrix_expression set_colm ( + matrix& m, + const matrix_exp& cols + ); + /*! + requires + - cols contains integral elements (e.g. int, long) + - 0 <= min(cols) && max(cols) < m.nc() + - cols.nr() == 1 || cols.nc() == 1 + (i.e. cols must be a vector) + ensures + - statements of the following form: + - set_colm(m,cols) = some_matrix; + result in it being the case that: + - colm(m,cols) == some_matrix. + + - statements of the following form: + - set_colm(m,cols) = scalar_value; + result in it being the case that: + - colm(m,cols) == uniform_matrix(nr,nc,scalar_value). + + - In addition to the normal assignment statements using the = symbol, you may + also use the usual += and -= versions of the assignment operator. In these + cases, they have their usual effect. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_SUBEXP_ABSTRACT_ + diff --git a/ml/dlib/dlib/matrix/matrix_trsm.h b/ml/dlib/dlib/matrix/matrix_trsm.h new file mode 100644 index 000000000..ef5ec5ed9 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_trsm.h @@ -0,0 +1,654 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRiX_TRSM_Hh_ +#define DLIB_MATRiX_TRSM_Hh_ +#include "lapack/fortran_id.h" +#include "cblas_constants.h" + +namespace dlib +{ + namespace blas_bindings + { +#ifdef DLIB_USE_BLAS +#ifndef CBLAS_H + extern "C" + { + void cblas_strsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side, + const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb); + + void cblas_dtrsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side, + const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb); + } +#endif // if not CBLAS_H +#endif // if DLIB_USE_BLAS + + // ------------------------------------------------------------------------------------ + +/* Purpose */ +/* ======= */ + +/* DTRSM solves one of the matrix equations */ + +/* op( A )*X = alpha*B, or X*op( A ) = alpha*B, */ + +/* where alpha is a scalar, X and B are m by n matrices, A is a unit, or */ +/* non-unit, upper or lower triangular matrix and op( A ) is one of */ + +/* op( A ) = A or op( A ) = A'. */ + +/* The matrix X is overwritten on B. */ + +/* Arguments */ +/* ========== */ + +/* SIDE - CHARACTER*1. */ +/* On entry, SIDE specifies whether op( A ) appears on the left */ +/* or right of X as follows: */ + +/* SIDE = 'L' or 'l' op( A )*X = alpha*B. */ + +/* SIDE = 'R' or 'r' X*op( A ) = alpha*B. */ + +/* Unchanged on exit. */ + +/* UPLO - CHARACTER*1. */ +/* On entry, UPLO specifies whether the matrix A is an upper or */ +/* lower triangular matrix as follows: */ + +/* UPLO = 'U' or 'u' A is an upper triangular matrix. */ + +/* UPLO = 'L' or 'l' A is a lower triangular matrix. */ + +/* Unchanged on exit. */ + +/* TRANSA - CHARACTER*1. */ +/* On entry, TRANSA specifies the form of op( A ) to be used in */ +/* the matrix multiplication as follows: */ + +/* TRANSA = 'N' or 'n' op( A ) = A. */ + +/* TRANSA = 'T' or 't' op( A ) = A'. */ + +/* TRANSA = 'C' or 'c' op( A ) = A'. */ + +/* Unchanged on exit. */ + +/* DIAG - CHARACTER*1. */ +/* On entry, DIAG specifies whether or not A is unit triangular */ +/* as follows: */ + +/* DIAG = 'U' or 'u' A is assumed to be unit triangular. */ + +/* DIAG = 'N' or 'n' A is not assumed to be unit */ +/* triangular. */ + +/* Unchanged on exit. */ + +/* M - INTEGER. */ +/* On entry, M specifies the number of rows of B. M must be at */ +/* least zero. */ +/* Unchanged on exit. */ + +/* N - INTEGER. */ +/* On entry, N specifies the number of columns of B. N must be */ +/* at least zero. */ +/* Unchanged on exit. */ + +/* ALPHA - DOUBLE PRECISION. */ +/* On entry, ALPHA specifies the scalar alpha. When alpha is */ +/* zero then A is not referenced and B need not be set before */ +/* entry. */ +/* Unchanged on exit. */ + +/* A - DOUBLE PRECISION array of DIMENSION ( LDA, k ), where k is m */ +/* when SIDE = 'L' or 'l' and is n when SIDE = 'R' or 'r'. */ +/* Before entry with UPLO = 'U' or 'u', the leading k by k */ +/* upper triangular part of the array A must contain the upper */ +/* triangular matrix and the strictly lower triangular part of */ +/* A is not referenced. */ +/* Before entry with UPLO = 'L' or 'l', the leading k by k */ +/* lower triangular part of the array A must contain the lower */ +/* triangular matrix and the strictly upper triangular part of */ +/* A is not referenced. */ +/* Note that when DIAG = 'U' or 'u', the diagonal elements of */ +/* A are not referenced either, but are assumed to be unity. */ +/* Unchanged on exit. */ + +/* LDA - INTEGER. */ +/* On entry, LDA specifies the first dimension of A as declared */ +/* in the calling (sub) program. When SIDE = 'L' or 'l' then */ +/* LDA must be at least max( 1, m ), when SIDE = 'R' or 'r' */ +/* then LDA must be at least max( 1, n ). */ +/* Unchanged on exit. */ + +/* B - DOUBLE PRECISION array of DIMENSION ( LDB, n ). */ +/* Before entry, the leading m by n part of the array B must */ +/* contain the right-hand side matrix B, and on exit is */ +/* overwritten by the solution matrix X. */ + +/* LDB - INTEGER. */ +/* On entry, LDB specifies the first dimension of B as declared */ +/* in the calling (sub) program. LDB must be at least */ +/* max( 1, m ). */ +/* Unchanged on exit. */ + + +/* Level 3 Blas routine. */ + + +/* -- Written on 8-February-1989. */ +/* Jack Dongarra, Argonne National Laboratory. */ +/* Iain Duff, AERE Harwell. */ +/* Jeremy Du Croz, Numerical Algorithms Group Ltd. */ +/* Sven Hammarling, Numerical Algorithms Group Ltd. */ + + template + void local_trsm( + const CBLAS_ORDER Order, + CBLAS_SIDE Side, + CBLAS_UPLO Uplo, + const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, + long m, + long n, + T alpha, + const T *a, + long lda, + T *b, + long ldb + ) + /*! + This is a copy of the dtrsm routine from the netlib.org BLAS which was run though + f2c and converted into this form for use when a BLAS library is not available. + !*/ + { + if (Order == CblasRowMajor) + { + // since row major ordering looks like transposition to FORTRAN we need to flip a + // few things. + if (Side == CblasLeft) + Side = CblasRight; + else + Side = CblasLeft; + + if (Uplo == CblasUpper) + Uplo = CblasLower; + else + Uplo = CblasUpper; + + std::swap(m,n); + } + + /* System generated locals */ + long a_dim1, a_offset, b_dim1, b_offset, i__1, i__2, i__3; + + /* Local variables */ + long i__, j, k, info; + T temp; + bool lside; + long nrowa; + bool upper; + bool nounit; + + /* Parameter adjustments */ + a_dim1 = lda; + a_offset = 1 + a_dim1; + a -= a_offset; + b_dim1 = ldb; + b_offset = 1 + b_dim1; + b -= b_offset; + + /* Function Body */ + lside = (Side == CblasLeft); + if (lside) + { + nrowa = m; + } else + { + nrowa = n; + } + nounit = (Diag == CblasNonUnit); + upper = (Uplo == CblasUpper); + + info = 0; + if (! lside && ! (Side == CblasRight)) { + info = 1; + } else if (! upper && !(Uplo == CblasLower) ) { + info = 2; + } else if (!(TransA == CblasNoTrans) && + !(TransA == CblasTrans) && + !(TransA == CblasConjTrans)) { + info = 3; + } else if (!(Diag == CblasUnit) && + !(Diag == CblasNonUnit) ) { + info = 4; + } else if (m < 0) { + info = 5; + } else if (n < 0) { + info = 6; + } else if (lda < std::max(1,nrowa)) { + info = 9; + } else if (ldb < std::max(1,m)) { + info = 11; + } + DLIB_CASSERT( info == 0, "Invalid inputs given to local_trsm"); + + /* Quick return if possible. */ + + if (m == 0 || n == 0) { + return; + } + + /* And when alpha.eq.zero. */ + + if (alpha == 0.) { + i__1 = n; + for (j = 1; j <= i__1; ++j) { + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] = 0.; + /* L10: */ + } + /* L20: */ + } + return; + } + + /* Start the operations. */ + + if (lside) { + if (TransA == CblasNoTrans) { + + /* Form B := alpha*inv( A )*B. */ + + if (upper) { + i__1 = n; + for (j = 1; j <= i__1; ++j) { + if (alpha != 1.) { + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1] + ; + /* L30: */ + } + } + for (k = m; k >= 1; --k) { + if (b[k + j * b_dim1] != 0.) { + if (nounit) { + b[k + j * b_dim1] /= a[k + k * a_dim1]; + } + i__2 = k - 1; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] -= b[k + j * b_dim1] * a[ + i__ + k * a_dim1]; + /* L40: */ + } + } + /* L50: */ + } + /* L60: */ + } + } else { + i__1 = n; + for (j = 1; j <= i__1; ++j) { + if (alpha != 1.) { + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1] + ; + /* L70: */ + } + } + i__2 = m; + for (k = 1; k <= i__2; ++k) { + if (b[k + j * b_dim1] != 0.) { + if (nounit) { + b[k + j * b_dim1] /= a[k + k * a_dim1]; + } + i__3 = m; + for (i__ = k + 1; i__ <= i__3; ++i__) { + b[i__ + j * b_dim1] -= b[k + j * b_dim1] * a[ + i__ + k * a_dim1]; + /* L80: */ + } + } + /* L90: */ + } + /* L100: */ + } + } + } else { + + /* Form B := alpha*inv( A' )*B. */ + + if (upper) { + i__1 = n; + for (j = 1; j <= i__1; ++j) { + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + temp = alpha * b[i__ + j * b_dim1]; + i__3 = i__ - 1; + for (k = 1; k <= i__3; ++k) { + temp -= a[k + i__ * a_dim1] * b[k + j * b_dim1]; + /* L110: */ + } + if (nounit) { + temp /= a[i__ + i__ * a_dim1]; + } + b[i__ + j * b_dim1] = temp; + /* L120: */ + } + /* L130: */ + } + } else { + i__1 = n; + for (j = 1; j <= i__1; ++j) { + for (i__ = m; i__ >= 1; --i__) { + temp = alpha * b[i__ + j * b_dim1]; + i__2 = m; + for (k = i__ + 1; k <= i__2; ++k) { + temp -= a[k + i__ * a_dim1] * b[k + j * b_dim1]; + /* L140: */ + } + if (nounit) { + temp /= a[i__ + i__ * a_dim1]; + } + b[i__ + j * b_dim1] = temp; + /* L150: */ + } + /* L160: */ + } + } + } + } else { + if (TransA == CblasNoTrans) { + + /* Form B := alpha*B*inv( A ). */ + + if (upper) { + i__1 = n; + for (j = 1; j <= i__1; ++j) { + if (alpha != 1.) { + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1] + ; + /* L170: */ + } + } + i__2 = j - 1; + for (k = 1; k <= i__2; ++k) { + if (a[k + j * a_dim1] != 0.) { + i__3 = m; + for (i__ = 1; i__ <= i__3; ++i__) { + b[i__ + j * b_dim1] -= a[k + j * a_dim1] * b[ + i__ + k * b_dim1]; + /* L180: */ + } + } + /* L190: */ + } + if (nounit) { + temp = 1. / a[j + j * a_dim1]; + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] = temp * b[i__ + j * b_dim1]; + /* L200: */ + } + } + /* L210: */ + } + } else { + for (j = n; j >= 1; --j) { + if (alpha != 1.) { + i__1 = m; + for (i__ = 1; i__ <= i__1; ++i__) { + b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1] + ; + /* L220: */ + } + } + i__1 = n; + for (k = j + 1; k <= i__1; ++k) { + if (a[k + j * a_dim1] != 0.) { + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] -= a[k + j * a_dim1] * b[ + i__ + k * b_dim1]; + /* L230: */ + } + } + /* L240: */ + } + if (nounit) { + temp = 1. / a[j + j * a_dim1]; + i__1 = m; + for (i__ = 1; i__ <= i__1; ++i__) { + b[i__ + j * b_dim1] = temp * b[i__ + j * b_dim1]; + /* L250: */ + } + } + /* L260: */ + } + } + } else { + + /* Form B := alpha*B*inv( A' ). */ + + if (upper) { + for (k = n; k >= 1; --k) { + if (nounit) { + temp = 1. / a[k + k * a_dim1]; + i__1 = m; + for (i__ = 1; i__ <= i__1; ++i__) { + b[i__ + k * b_dim1] = temp * b[i__ + k * b_dim1]; + /* L270: */ + } + } + i__1 = k - 1; + for (j = 1; j <= i__1; ++j) { + if (a[j + k * a_dim1] != 0.) { + temp = a[j + k * a_dim1]; + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] -= temp * b[i__ + k * + b_dim1]; + /* L280: */ + } + } + /* L290: */ + } + if (alpha != 1.) { + i__1 = m; + for (i__ = 1; i__ <= i__1; ++i__) { + b[i__ + k * b_dim1] = alpha * b[i__ + k * b_dim1] + ; + /* L300: */ + } + } + /* L310: */ + } + } else { + i__1 = n; + for (k = 1; k <= i__1; ++k) { + if (nounit) { + temp = 1. / a[k + k * a_dim1]; + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + k * b_dim1] = temp * b[i__ + k * b_dim1]; + /* L320: */ + } + } + i__2 = n; + for (j = k + 1; j <= i__2; ++j) { + if (a[j + k * a_dim1] != 0.) { + temp = a[j + k * a_dim1]; + i__3 = m; + for (i__ = 1; i__ <= i__3; ++i__) { + b[i__ + j * b_dim1] -= temp * b[i__ + k * + b_dim1]; + /* L330: */ + } + } + /* L340: */ + } + if (alpha != 1.) { + i__2 = m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + k * b_dim1] = alpha * b[i__ + k * b_dim1] + ; + /* L350: */ + } + } + /* L360: */ + } + } + } + } + } + + // ------------------------------------------------------------------------------------ + + inline void cblas_trsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side, + const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb) + { +#ifdef DLIB_USE_BLAS + if (M > 4) + { + cblas_strsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb); + return; + } +#endif + local_trsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb); + } + + inline void cblas_trsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side, + const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb) + { +#ifdef DLIB_USE_BLAS + if (M > 4) + { + cblas_dtrsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb); + return; + } +#endif + local_trsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb); + } + + inline void cblas_trsm(const CBLAS_ORDER Order, const CBLAS_SIDE Side, + const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, const int M, const int N, + const long double alpha, const long double *A, const int lda, + long double *B, const int ldb) + { + local_trsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb); + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, + long NC1, long NC2, + typename MM + > + inline void triangular_solver ( + const CBLAS_SIDE Side, + const CBLAS_UPLO Uplo, + const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, + const matrix& A, + const T alpha, + matrix& B + ) + { + cblas_trsm(CblasRowMajor, Side, Uplo, TransA, Diag, B.nr(), B.nc(), + alpha, &A(0,0), A.nc(), &B(0,0), B.nc()); + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, + long NC1, long NC2, + typename MM + > + inline void triangular_solver ( + const CBLAS_SIDE Side, + const CBLAS_UPLO Uplo, + const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, + const matrix& A, + const T alpha, + matrix& B + ) + { + cblas_trsm(CblasColMajor, Side, Uplo, TransA, Diag, B.nr(), B.nc(), + alpha, &A(0,0), A.nr(), &B(0,0), B.nr()); + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, + long NC1, long NC2, + typename MM + > + inline void triangular_solver ( + const CBLAS_SIDE Side, + const CBLAS_UPLO Uplo, + const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, + const matrix& A, + matrix& B, + long rows_of_B + ) + { + const T alpha = 1; + cblas_trsm(CblasColMajor, Side, Uplo, TransA, Diag, rows_of_B, B.nc(), + alpha, &A(0,0), A.nr(), &B(0,0), B.nr()); + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + long NR1, long NR2, + long NC1, long NC2, + typename MM, + typename layout + > + inline void triangular_solver ( + const CBLAS_SIDE Side, + const CBLAS_UPLO Uplo, + const CBLAS_TRANSPOSE TransA, + const CBLAS_DIAG Diag, + const matrix& A, + matrix& B + ) + { + const T alpha = 1; + triangular_solver(Side, Uplo, TransA, Diag, A, alpha, B); + } + + // ------------------------------------------------------------------------------------ + + } +} + +#endif // DLIB_MATRiX_TRSM_Hh_ + diff --git a/ml/dlib/dlib/matrix/matrix_utilities.h b/ml/dlib/dlib/matrix/matrix_utilities.h new file mode 100644 index 000000000..0c5091a4b --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_utilities.h @@ -0,0 +1,4544 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_UTILITIES_ +#define DLIB_MATRIx_UTILITIES_ + +#include "matrix_utilities_abstract.h" +#include "matrix.h" +#include +#include +#include +#include "../pixel.h" +#include "../stl_checked.h" +#include +#include +#include "../std_allocator.h" +#include "matrix_expressions.h" +#include "matrix_math_functions.h" +#include "matrix_op.h" +#include "../general_hash/random_hashing.h" +#include "matrix_mat.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*!A is_complex + This is a template that can be used to determine if a type is a specialization + of the std::complex template class. + + For example: + is_complex::value == false + is_complex >::value == true + !*/ + + template + struct is_complex { static const bool value = false; }; + + template + struct is_complex > { static const bool value = true; }; + template + struct is_complex& > { static const bool value = true; }; + template + struct is_complex& > { static const bool value = true; }; + template + struct is_complex > { static const bool value = true; }; + +// ---------------------------------------------------------------------------------------- + + template + inline bool is_row_vector ( + const matrix_exp& m + ) { return m.nr() == 1; } + + template + inline bool is_col_vector ( + const matrix_exp& m + ) { return m.nc() == 1; } + + template + inline bool is_vector ( + const matrix_exp& m + ) { return is_row_vector(m) || is_col_vector(m); } + +// ---------------------------------------------------------------------------------------- + + template + inline bool is_finite ( + const matrix_exp& m + ) + { + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + if (!is_finite(m(r,c))) + return false; + } + } + return true; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + const T& magnitude (const T& item) { return item; } + template + T magnitude (const std::complex& item) { return std::norm(item); } + } + + template < + typename EXP + > + void find_min_and_max ( + const matrix_exp& m, + typename EXP::type& min_val, + typename EXP::type& max_val + ) + { + DLIB_ASSERT(m.size() > 0, + "\ttype find_min_and_max(const matrix_exp& m, min_val, max_val)" + << "\n\tYou can't ask for the min and max of an empty matrix" + << "\n\tm.size(): " << m.size() + ); + typedef typename matrix_exp::type type; + + min_val = m(0,0); + max_val = min_val; + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + type temp = m(r,c); + if (dlib::impl::magnitude(temp) > dlib::impl::magnitude(max_val)) + max_val = temp; + if (dlib::impl::magnitude(temp) < dlib::impl::magnitude(min_val)) + min_val = temp; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + point max_point ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.size() > 0, + "\tpoint max_point(const matrix_exp& m)" + << "\n\tm can't be empty" + << "\n\tm.size(): " << m.size() + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + typedef typename matrix_exp::type type; + + point best_point(0,0); + type val = m(0,0); + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + type temp = m(r,c); + if (dlib::impl::magnitude(temp) > dlib::impl::magnitude(val)) + { + val = temp; + best_point = point(c,r); + } + } + } + return best_point; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + point min_point ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.size() > 0, + "\tpoint min_point(const matrix_exp& m)" + << "\n\tm can't be empty" + << "\n\tm.size(): " << m.size() + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + typedef typename matrix_exp::type type; + + point best_point(0,0); + type val = m(0,0); + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + type temp = m(r,c); + if (dlib::impl::magnitude(temp) < dlib::impl::magnitude(val)) + { + val = temp; + best_point = point(c,r); + } + } + } + return best_point; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + long index_of_max ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.size() > 0 && is_vector(m) == true, + "\tlong index_of_max(const matrix_exp& m)" + << "\n\tm must be a row or column matrix" + << "\n\tm.size(): " << m.size() + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + typedef typename matrix_exp::type type; + + type val = m(0); + long best_idx = 0; + for (long i = 1; i < m.size(); ++i) + { + type temp = m(i); + if (dlib::impl::magnitude(temp) > dlib::impl::magnitude(val)) + { + val = temp; + best_idx = i; + } + } + return best_idx; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + long index_of_min ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.size() > 0 && is_vector(m), + "\tlong index_of_min(const matrix_exp& m)" + << "\n\tm must be a row or column matrix" + << "\n\tm.size(): " << m.size() + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + typedef typename matrix_exp::type type; + + type val = m(0); + long best_idx = 0; + for (long i = 1; i < m.size(); ++i) + { + type temp = m(i); + if (dlib::impl::magnitude(temp) < dlib::impl::magnitude(val)) + { + val = temp; + best_idx = i; + } + } + return best_idx; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const typename matrix_exp::type max ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.size() > 0, + "\ttype max(const matrix_exp& m)" + << "\n\tYou can't ask for the max() of an empty matrix" + << "\n\tm.size(): " << m.size() + ); + typedef typename matrix_exp::type type; + + type val = m(0,0); + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + type temp = m(r,c); + if (dlib::impl::magnitude(temp) > dlib::impl::magnitude(val)) + val = temp; + } + } + return val; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const typename matrix_exp::type min ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.size() > 0, + "\ttype min(const matrix_exp& m)" + << "\n\tYou can't ask for the min() of an empty matrix" + << "\n\tm.size(): " << m.size() + ); + typedef typename matrix_exp::type type; + + type val = m(0,0); + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + type temp = m(r,c); + if (dlib::impl::magnitude(temp) < dlib::impl::magnitude(val)) + val = temp; + } + } + return val; + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_binary_min : basic_op_mm + { + op_binary_min( const M1& m1_, const M2& m2_) : basic_op_mm(m1_,m2_){} + + typedef typename M1::type type; + typedef const type const_ret_type; + const static long cost = M1::cost + M2::cost + 1; + + const_ret_type apply ( long r, long c) const + { return std::min(this->m1(r,c),this->m2(r,c)); } + }; + + template < + typename EXP1, + typename EXP2 + > + inline const matrix_op > min_pointwise ( + const matrix_exp& a, + const matrix_exp& b + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0); + DLIB_ASSERT(a.nr() == b.nr() && + a.nc() == b.nc(), + "\t const matrix_exp min_pointwise(const matrix_exp& a, const matrix_exp& b)" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nr(): " << b.nr() + << "\n\tb.nc(): " << b.nc() + ); + typedef op_binary_min op; + return matrix_op(op(a.ref(),b.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_min_pointwise3 : basic_op_mmm + { + op_min_pointwise3( const M1& m1_, const M2& m2_, const M3& m3_) : + basic_op_mmm(m1_,m2_,m3_){} + + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + const static long cost = M1::cost + M2::cost + M3::cost + 2; + + const_ret_type apply (long r, long c) const + { return std::min(this->m1(r,c),std::min(this->m2(r,c),this->m3(r,c))); } + }; + + template < + typename EXP1, + typename EXP2, + typename EXP3 + > + inline const matrix_op > + min_pointwise ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp& c + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NR == 0 || EXP2::NC == 0); + COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0); + COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0); + DLIB_ASSERT(a.nr() == b.nr() && + a.nc() == b.nc() && + b.nr() == c.nr() && + b.nc() == c.nc(), + "\tconst matrix_exp min_pointwise(a,b,c)" + << "\n\tYou can only make a do a pointwise min between equally sized matrices" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nr(): " << b.nr() + << "\n\tb.nc(): " << b.nc() + << "\n\tc.nr(): " << c.nr() + << "\n\tc.nc(): " << c.nc() + ); + + typedef op_min_pointwise3 op; + return matrix_op(op(a.ref(),b.ref(),c.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_binary_max : basic_op_mm + { + op_binary_max( const M1& m1_, const M2& m2_) : basic_op_mm(m1_,m2_){} + + typedef typename M1::type type; + typedef const type const_ret_type; + const static long cost = M1::cost + M2::cost + 1; + + const_ret_type apply ( long r, long c) const + { return std::max(this->m1(r,c),this->m2(r,c)); } + }; + + template < + typename EXP1, + typename EXP2 + > + inline const matrix_op > max_pointwise ( + const matrix_exp& a, + const matrix_exp& b + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0); + DLIB_ASSERT(a.nr() == b.nr() && + a.nc() == b.nc(), + "\t const matrix_exp max_pointwise(const matrix_exp& a, const matrix_exp& b)" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nr(): " << b.nr() + << "\n\tb.nc(): " << b.nc() + ); + typedef op_binary_max op; + return matrix_op(op(a.ref(),b.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_max_pointwise3 : basic_op_mmm + { + op_max_pointwise3( const M1& m1_, const M2& m2_, const M3& m3_) : + basic_op_mmm(m1_,m2_,m3_){} + + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + const static long cost = M1::cost + M2::cost + M3::cost + 2; + + const_ret_type apply (long r, long c) const + { return std::max(this->m1(r,c),std::max(this->m2(r,c),this->m3(r,c))); } + }; + + template < + typename EXP1, + typename EXP2, + typename EXP3 + > + inline const matrix_op > + max_pointwise ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp& c + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NR == 0 || EXP2::NC == 0); + COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0); + COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0); + DLIB_ASSERT(a.nr() == b.nr() && + a.nc() == b.nc() && + b.nr() == c.nr() && + b.nc() == c.nc(), + "\tconst matrix_exp max_pointwise(a,b,c)" + << "\n\tYou can only make a do a pointwise max between equally sized matrices" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nr(): " << b.nr() + << "\n\tb.nc(): " << b.nc() + << "\n\tc.nr(): " << c.nr() + << "\n\tc.nc(): " << c.nc() + ); + + typedef op_max_pointwise3 op; + return matrix_op(op(a.ref(),b.ref(),c.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + typename enable_if_c::is_integer, double>::type length ( + const matrix_exp& m + ) + { + DLIB_ASSERT(is_vector(m) == true, + "\ttype length(const matrix_exp& m)" + << "\n\tm must be a row or column vector" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + + return std::sqrt(static_cast(sum(squared(m)))); + } + + template < + typename EXP + > + typename disable_if_c::is_integer, const typename EXP::type>::type length ( + const matrix_exp& m + ) + { + DLIB_ASSERT(is_vector(m) == true, + "\ttype length(const matrix_exp& m)" + << "\n\tm must be a row or column vector" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + return std::sqrt(sum(squared(m))); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const typename matrix_exp::type length_squared ( + const matrix_exp& m + ) + { + DLIB_ASSERT(is_vector(m) == true, + "\ttype length_squared(const matrix_exp& m)" + << "\n\tm must be a row or column vector" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + return sum(squared(m)); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + struct op_trans + { + op_trans( const M& m_) : m(m_){} + + const M& m; + + const static long cost = M::cost; + const static long NR = M::NC; + const static long NC = M::NR; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + const_ret_type apply (long r, long c) const { return m(c,r); } + + long nr () const { return m.nc(); } + long nc () const { return m.nr(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + + }; + + template < + typename M + > + const matrix_op > trans ( + const matrix_exp& m + ) + { + typedef op_trans op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + +// don't to anything at all for diagonal matrices + template < + typename M + > + const matrix_diag_exp& trans ( + const matrix_diag_exp& m + ) + { + return m; + } + +// ---------------------------------------------------------------------------------------- + +// I introduced this struct because it avoids an inane compiler warning from gcc + template + struct is_not_ct_vector{ static const bool value = (EXP::NR != 1 && EXP::NC != 1); }; + + template < + typename EXP1, + typename EXP2 + > + typename enable_if_c<(is_not_ct_vector::value) || (is_not_ct_vector::value), + typename EXP1::type>::type + dot ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + // You are getting an error on this line because you are trying to + // compute the dot product between two matrices that aren't both vectors (i.e. + // they aren't column or row matrices). + COMPILE_TIME_ASSERT(EXP1::NR*EXP1::NC == 0 || + EXP2::NR*EXP2::NC == 0); + + DLIB_ASSERT(is_vector(m1) && is_vector(m2) && m1.size() == m2.size() && + m1.size() > 0, + "\t type dot(const matrix_exp& m1, const matrix_exp& m2)" + << "\n\t You can only compute the dot product between non-empty vectors of equal length." + << "\n\t is_vector(m1): " << is_vector(m1) + << "\n\t is_vector(m2): " << is_vector(m2) + << "\n\t m1.size(): " << m1.size() + << "\n\t m2.size(): " << m2.size() + ); + + if (is_col_vector(m1) && is_col_vector(m2)) return (trans(m1)*m2)(0); + if (is_col_vector(m1) && is_row_vector(m2)) return (m2*m1)(0); + if (is_row_vector(m1) && is_col_vector(m2)) return (m1*m2)(0); + + //if (is_row_vector(m1) && is_row_vector(m2)) + return (m1*trans(m2))(0); + } + + template < typename EXP1, typename EXP2 > + typename enable_if_c::type + dot ( const matrix_exp& m1, const matrix_exp& m2) + { + DLIB_ASSERT(m1.size() == m2.size(), + "\t type dot(const matrix_exp& m1, const matrix_exp& m2)" + << "\n\t You can only compute the dot product between vectors of equal length" + << "\n\t m1.size(): " << m1.size() + << "\n\t m2.size(): " << m2.size() + ); + + return m1*trans(m2); + } + + template < typename EXP1, typename EXP2 > + typename enable_if_c::type + dot ( const matrix_exp& m1, const matrix_exp& m2) + { + DLIB_ASSERT(m1.size() == m2.size(), + "\t type dot(const matrix_exp& m1, const matrix_exp& m2)" + << "\n\t You can only compute the dot product between vectors of equal length" + << "\n\t m1.size(): " << m1.size() + << "\n\t m2.size(): " << m2.size() + ); + + return m1*m2; + } + + template < typename EXP1, typename EXP2 > + typename enable_if_c::type + dot ( const matrix_exp& m1, const matrix_exp& m2) + { + DLIB_ASSERT(m1.size() == m2.size(), + "\t type dot(const matrix_exp& m1, const matrix_exp& m2)" + << "\n\t You can only compute the dot product between vectors of equal length" + << "\n\t m1.size(): " << m1.size() + << "\n\t m2.size(): " << m2.size() + ); + + return m2*m1; + } + + template < typename EXP1, typename EXP2 > + typename enable_if_c::type + dot ( const matrix_exp& m1, const matrix_exp& m2) + { + DLIB_ASSERT(m1.size() == m2.size(), + "\t type dot(const matrix_exp& m1, const matrix_exp& m2)" + << "\n\t You can only compute the dot product between vectors of equal length" + << "\n\t m1.size(): " << m1.size() + << "\n\t m2.size(): " << m2.size() + ); + + return trans(m1)*m2; + } + + template < typename EXP1, typename EXP2 > + typename enable_if_c<(EXP1::NC*EXP1::NR == 1) || (EXP2::NC*EXP2::NR == 1), typename EXP1::type>::type + dot ( const matrix_exp& m1, const matrix_exp& m2) + { + DLIB_ASSERT(m1.size() == m2.size(), + "\t type dot(const matrix_exp& m1, const matrix_exp& m2)" + << "\n\t You can only compute the dot product between vectors of equal length" + << "\n\t m1.size(): " << m1.size() + << "\n\t m2.size(): " << m2.size() + ); + + return m1(0)*m2(0); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_removerc + { + op_removerc( const M& m_) : m(m_){} + + const M& m; + + const static long cost = M::cost+2; + const static long NR = (M::NR==0) ? 0 : (M::NR - 1); + const static long NC = (M::NC==0) ? 0 : (M::NC - 1); + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply (long r, long c) const + { + if (r < R) + { + if (c < C) + return m(r,c); + else + return m(r,c+1); + } + else + { + if (c < C) + return m(r+1,c); + else + return m(r+1,c+1); + } + } + + long nr () const { return m.nr() - 1; } + long nc () const { return m.nc() - 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template + struct op_removerc2 + { + op_removerc2( const M& m_, const long R_, const long C_) : m(m_), R(R_), C(C_){} + const M& m; + const long R; + const long C; + + const static long cost = M::cost+2; + const static long NR = (M::NR==0) ? 0 : (M::NR - 1); + const static long NC = (M::NC==0) ? 0 : (M::NC - 1); + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply (long r, long c) const + { + if (r < R) + { + if (c < C) + return m(r,c); + else + return m(r,c+1); + } + else + { + if (c < C) + return m(r+1,c); + else + return m(r+1,c+1); + } + } + + long nr () const { return m.nr() - 1; } + long nc () const { return m.nc() - 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + long R, + long C, + typename EXP + > + const matrix_op > removerc ( + const matrix_exp& m + ) + { + // you can't remove a row from a matrix with only one row + COMPILE_TIME_ASSERT((EXP::NR > R && R >= 0) || EXP::NR == 0); + // you can't remove a column from a matrix with only one column + COMPILE_TIME_ASSERT((EXP::NC > C && C >= 0) || EXP::NR == 0); + DLIB_ASSERT(m.nr() > R && R >= 0 && m.nc() > C && C >= 0, + "\tconst matrix_exp removerc(const matrix_exp& m)" + << "\n\tYou can't remove a row/column from a matrix if it doesn't have that row/column" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tR: " << R + << "\n\tC: " << C + ); + typedef op_removerc op; + return matrix_op(op(m.ref())); + } + + template < + typename EXP + > + const matrix_op > removerc ( + const matrix_exp& m, + long R, + long C + ) + { + DLIB_ASSERT(m.nr() > R && R >= 0 && m.nc() > C && C >= 0, + "\tconst matrix_exp removerc(const matrix_exp& m,R,C)" + << "\n\tYou can't remove a row/column from a matrix if it doesn't have that row/column" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tR: " << R + << "\n\tC: " << C + ); + typedef op_removerc2 op; + return matrix_op(op(m.ref(),R,C)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_remove_col + { + op_remove_col( const M& m_) : m(m_){} + const M& m; + + const static long cost = M::cost+2; + const static long NR = M::NR; + const static long NC = (M::NC==0) ? 0 : (M::NC - 1); + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long c) const + { + if (c < C) + { + return m(r,c); + } + else + { + return m(r,c+1); + } + } + + long nr () const { return m.nr(); } + long nc () const { return m.nc() - 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template + struct op_remove_col2 + { + op_remove_col2( const M& m_, const long C_) : m(m_), C(C_){} + const M& m; + const long C; + + const static long cost = M::cost+2; + const static long NR = M::NR; + const static long NC = (M::NC==0) ? 0 : (M::NC - 1); + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long c) const + { + if (c < C) + { + return m(r,c); + } + else + { + return m(r,c+1); + } + } + + long nr () const { return m.nr(); } + long nc () const { return m.nc() - 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + long C, + typename EXP + > + const matrix_op > remove_col ( + const matrix_exp& m + ) + { + // You can't remove the given column from the matrix because the matrix doesn't + // have a column with that index. + COMPILE_TIME_ASSERT((EXP::NC > C && C >= 0) || EXP::NC == 0); + DLIB_ASSERT(m.nc() > C && C >= 0 , + "\tconst matrix_exp remove_col(const matrix_exp& m)" + << "\n\tYou can't remove a col from a matrix if it doesn't have it" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tC: " << C + ); + typedef op_remove_col op; + return matrix_op(op(m.ref())); + } + + template < + typename EXP + > + const matrix_op > remove_col ( + const matrix_exp& m, + long C + ) + { + DLIB_ASSERT(m.nc() > C && C >= 0 , + "\tconst matrix_exp remove_col(const matrix_exp& m,C)" + << "\n\tYou can't remove a col from a matrix if it doesn't have it" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tC: " << C + ); + typedef op_remove_col2 op; + return matrix_op(op(m.ref(),C)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_remove_row + { + op_remove_row( const M& m_) : m(m_){} + const M& m; + + const static long cost = M::cost+2; + const static long NR = (M::NR==0) ? 0 : (M::NR - 1); + const static long NC = M::NC; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long c) const + { + if (r < R) + { + return m(r,c); + } + else + { + return m(r+1,c); + } + } + + long nr () const { return m.nr() - 1; } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template + struct op_remove_row2 + { + op_remove_row2( const M& m_, const long R_) : m(m_), R(R_){} + const M& m; + const long R; + + const static long cost = M::cost+2; + const static long NR = (M::NR==0) ? 0 : (M::NR - 1); + const static long NC = M::NC; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long c) const + { + if (r < R) + { + return m(r,c); + } + else + { + return m(r+1,c); + } + } + + long nr () const { return m.nr() - 1; } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + long R, + typename EXP + > + const matrix_op > remove_row ( + const matrix_exp& m + ) + { + // You can't remove the given row from the matrix because the matrix doesn't + // have a row with that index. + COMPILE_TIME_ASSERT((EXP::NR > R && R >= 0) || EXP::NR == 0); + DLIB_ASSERT(m.nr() > R && R >= 0, + "\tconst matrix_exp remove_row(const matrix_exp& m)" + << "\n\tYou can't remove a row from a matrix if it doesn't have it" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tR: " << R + ); + typedef op_remove_row op; + return matrix_op(op(m.ref())); + } + + template < + typename EXP + > + const matrix_op > remove_row ( + const matrix_exp& m, + long R + ) + { + DLIB_ASSERT(m.nr() > R && R >= 0, + "\tconst matrix_exp remove_row(const matrix_exp& m, long R)" + << "\n\tYou can't remove a row from a matrix if it doesn't have it" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tR: " << R + ); + typedef op_remove_row2 op; + return matrix_op(op(m.ref(),R)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_diagm + { + op_diagm( const M& m_) : m(m_){} + const M& m; + + const static long cost = M::cost+2; + const static long N = M::NC*M::NR; + const static long NR = N; + const static long NC = N; + typedef typename M::type type; + typedef const typename M::type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long c) const + { + if (r==c) + return m(r); + else + return 0; + } + + long nr () const { return (m.nr()>m.nc())? m.nr():m.nc(); } + long nc () const { return (m.nr()>m.nc())? m.nr():m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_diag_op > diagm ( + const matrix_exp& m + ) + { + // You can only make a diagonal matrix out of a row or column vector + COMPILE_TIME_ASSERT(EXP::NR == 0 || EXP::NR == 1 || EXP::NC == 1 || EXP::NC == 0); + DLIB_ASSERT(is_vector(m), + "\tconst matrix_exp diagm(const matrix_exp& m)" + << "\n\tYou can only apply diagm() to a row or column matrix" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + typedef op_diagm op; + return matrix_diag_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_diagm_mult : basic_op_mm + { + op_diagm_mult( const M1& m1_, const M2& m2_) : basic_op_mm(m1_,m2_){} + + typedef typename M1::type type; + typedef const type const_ret_type; + const static long cost = M1::cost + M2::cost + 1; + + const_ret_type apply ( long r, long c) const + { + if (r == c) + return this->m1(r,c)*this->m2(r,c); + else + return 0; + } + }; + + template < + typename EXP1, + typename EXP2 + > + inline const matrix_diag_op > operator* ( + const matrix_diag_exp& a, + const matrix_diag_exp& b + ) + { + COMPILE_TIME_ASSERT((is_same_type::value)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0); + DLIB_ASSERT(a.nr() == b.nr() && + a.nc() == b.nc(), + "\tconst matrix_exp operator(const matrix_diag_exp& a, const matrix_diag_exp& b)" + << "\n\tYou can only multiply diagonal matrices together if they are the same size" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nr(): " << b.nr() + << "\n\tb.nc(): " << b.nc() + ); + typedef op_diagm_mult op; + return matrix_diag_op(op(a.ref(),b.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_diag + { + op_diag( const M& m_) : m(m_){} + const M& m; + + const static long cost = M::cost; + const static long NR = tmin::value; + const static long NC = 1; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long ) const { return m(r,r); } + + long nr () const { return std::min(m.nc(),m.nr()); } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > diag ( + const matrix_exp& m + ) + { + typedef op_diag op; + return matrix_op(op(m.ref())); + } + + template + struct diag_exp + { + typedef matrix_op > type; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct op_cast + { + op_cast( const M& m_) : m(m_){} + const M& m; + + const static long cost = M::cost+2; + const static long NR = M::NR; + const static long NC = M::NC; + typedef target_type type; + typedef const target_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long c) const { return static_cast(m(r,c)); } + + long nr () const { return m.nr(); } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.destructively_aliases(item); } + }; + + template < + typename target_type, + typename EXP + > + const matrix_op > matrix_cast ( + const matrix_exp& m + ) + { + typedef op_cast op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline type lessthan(const type& val, const S& s) + { + if (val < s) + return 1; + else + return 0; + } + + } + DLIB_DEFINE_OP_MS(op_lessthan, impl::lessthan, 1); + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator< ( + const matrix_exp& m, + const S& s + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_lessthan op; + return matrix_op(op(m.ref(),s)); + } + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator> ( + const S& s, + const matrix_exp& m + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_lessthan op; + return matrix_op(op(m.ref(),s)); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline type lessthan_eq(const type& val, const S& s) + { + if (val <= s) + return 1; + else + return 0; + } + + } + DLIB_DEFINE_OP_MS(op_lessthan_eq, impl::lessthan_eq, 1); + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator<= ( + const matrix_exp& m, + const S& s + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_lessthan_eq op; + return matrix_op(op(m.ref(),s)); + } + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator>= ( + const S& s, + const matrix_exp& m + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_lessthan_eq op; + return matrix_op(op(m.ref(),s)); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline type greaterthan(const type& val, const S& s) + { + if (val > s) + return 1; + else + return 0; + } + + } + DLIB_DEFINE_OP_MS(op_greaterthan, impl::greaterthan, 1); + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator> ( + const matrix_exp& m, + const S& s + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_greaterthan op; + return matrix_op(op(m.ref(),s)); + } + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator< ( + const S& s, + const matrix_exp& m + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_greaterthan op; + return matrix_op(op(m.ref(),s)); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline type greaterthan_eq(const type& val, const S& s) + { + if (val >= s) + return 1; + else + return 0; + } + + } + DLIB_DEFINE_OP_MS(op_greaterthan_eq, impl::greaterthan_eq, 1); + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator>= ( + const matrix_exp& m, + const S& s + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_greaterthan_eq op; + return matrix_op(op(m.ref(),s)); + } + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator<= ( + const S& s, + const matrix_exp& m + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_greaterthan_eq op; + return matrix_op(op(m.ref(),s)); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline type equal_to(const type& val, const S& s) + { + if (val == s) + return 1; + else + return 0; + } + + } + DLIB_DEFINE_OP_MS(op_equal_to, impl::equal_to, 1); + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator== ( + const matrix_exp& m, + const S& s + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT( is_built_in_scalar_type::value); + + typedef op_equal_to op; + return matrix_op(op(m.ref(),s)); + } + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator== ( + const S& s, + const matrix_exp& m + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT( is_built_in_scalar_type::value); + + typedef op_equal_to op; + return matrix_op(op(m.ref(),s)); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline type not_equal_to(const type& val, const S& s) + { + if (val != s) + return 1; + else + return 0; + } + + } + DLIB_DEFINE_OP_MS(op_not_equal_to, impl::not_equal_to, 1); + + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator!= ( + const matrix_exp& m, + const S& s + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_not_equal_to op; + return matrix_op(op(m.ref(),s)); + } + + template < + typename EXP, + typename S + > + const typename enable_if, matrix_op > >::type operator!= ( + const S& s, + const matrix_exp& m + ) + { + // you can only use this relational operator with the built in scalar types like + // long, float, etc. + COMPILE_TIME_ASSERT(is_built_in_scalar_type::value); + + typedef op_not_equal_to op; + return matrix_op(op(m.ref(),s)); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename MM, + typename U, + typename L + > + typename disable_if,void>::type set_all_elements ( + matrix& m, + const U& value + ) + { + // The value you are trying to assign to each element of the m matrix + // doesn't have the appropriate type. + COMPILE_TIME_ASSERT(is_matrix::value == is_matrix::value); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = static_cast(value); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename MM, + typename U, + typename L + > + typename enable_if,void>::type set_all_elements ( + matrix& m, + const U& value + ) + { + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = value; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + inline const typename matrix_exp::matrix_type tmp ( + const matrix_exp& m + ) + { + return typename matrix_exp::matrix_type (m); + } + +// ---------------------------------------------------------------------------------------- + + template + constexpr bool is_row_major ( + const matrix_exp& + ) + { + return is_same_type::value; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const typename lazy_disable_if, EXP>::type sum ( + const matrix_exp& m + ) + { + typedef typename matrix_exp::type type; + + type val = 0; + if (is_row_major(m)) + { + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + val += m(r,c); + } + } + } + else + { + for (long c = 0; c < m.nc(); ++c) + { + for (long r = 0; r < m.nr(); ++r) + { + val += m(r,c); + } + } + } + return val; + } + + template < + typename EXP + > + const typename lazy_enable_if, EXP>::type sum ( + const matrix_exp& m + ) + { + typedef typename matrix_exp::type type; + + type val; + if (m.size() > 0) + val.set_size(m(0,0).nr(),m(0,0).nc()); + set_all_elements(val,0); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + val += m(r,c); + } + } + return val; + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_sumr + { + op_sumr(const M& m_) : m(m_) {} + const M& m; + + const static long cost = M::cost+10; + const static long NR = 1; + const static long NC = M::NC; + typedef typename M::type type; + typedef const typename M::type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long , long c) const + { + type temp = m(0,c); + for (long r = 1; r < m.nr(); ++r) + temp += m(r,c); + return temp; + } + + long nr () const { return 1; } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > sum_rows ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.size() > 0 , + "\tconst matrix_exp sum_rows(m)" + << "\n\t The matrix can't be empty" + << "\n\t m.size(): " << m.size() + ); + typedef op_sumr op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_sumc + { + op_sumc(const M& m_) : m(m_) {} + const M& m; + + const static long cost = M::cost + 10; + const static long NR = M::NR; + const static long NC = 1; + typedef typename M::type type; + typedef const typename M::type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long ) const + { + type temp = m(r,0); + for (long c = 1; c < m.nc(); ++c) + temp += m(r,c); + return temp; + } + + long nr () const { return m.nr(); } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > sum_cols ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.size() > 0 , + "\tconst matrix_exp sum_cols(m)" + << "\n\t The matrix can't be empty" + << "\n\t m.size(): " << m.size() + ); + typedef op_sumc op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + inline const typename disable_if, typename matrix_exp::type>::type mean ( + const matrix_exp& m + ) + { + return sum(m)/(m.nr()*m.nc()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + inline const typename enable_if, typename matrix_exp::type>::type mean ( + const matrix_exp& m + ) + { + typedef typename EXP::type::value_type type; + return sum(m)/(type)(m.nr()*m.nc()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const typename matrix_exp::type variance ( + const matrix_exp& m + ) + { + using std::pow; + using dlib::pow; + const typename matrix_exp::type avg = mean(m); + + typedef typename matrix_exp::type type; + + type val; + val = 0; + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + val += pow(m(r,c) - avg,2); + } + } + + if (m.nr() * m.nc() <= 1) + { + return val; + } + else + { + // Note, for some reason, in gcc 4.1 performing this division using a + // double instead of a long value avoids a segmentation fault. That is, + // using 1.0 instead of 1 does the trick. + return val/(m.nr()*m.nc() - 1.0); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const typename matrix_exp::type stddev ( + const matrix_exp& m + ) + { + using std::sqrt; + using dlib::sqrt; + return sqrt(variance(m)); + } + +// ---------------------------------------------------------------------------------------- + +// this is a workaround for a bug in visual studio 7.1 + template + struct visual_studio_sucks_cov_helper + { + typedef typename EXP::type inner_type; + typedef matrix type; + }; + + template < + typename EXP + > + const typename visual_studio_sucks_cov_helper::type covariance ( + const matrix_exp& m + ) + { + // perform static checks to make sure m is a column vector + COMPILE_TIME_ASSERT(EXP::NR == 0 || EXP::NR > 1); + COMPILE_TIME_ASSERT(EXP::NC == 1 || EXP::NC == 0); + + // perform static checks to make sure the matrices contained in m are column vectors + COMPILE_TIME_ASSERT(EXP::type::NC == 1 || EXP::type::NC == 0 ); + + DLIB_ASSERT(m.size() > 1 && is_col_vector(m), + "\tconst matrix covariance(const matrix_exp& m)" + << "\n\tYou can only apply covariance() to a column matrix" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); +#ifdef ENABLE_ASSERTS + for (long i = 0; i < m.nr(); ++i) + { + DLIB_ASSERT(m(0).size() == m(i).size() && m(i).size() > 0 && is_col_vector(m(i)), + "\tconst matrix covariance(const matrix_exp& m)" + << "\n\tYou can only apply covariance() to a column matrix of column matrices" + << "\n\tm(0).size(): " << m(0).size() + << "\n\tm(i).size(): " << m(i).size() + << "\n\tis_col_vector(m(i)): " << (is_col_vector(m(i)) ? "true" : "false") + << "\n\ti: " << i + ); + } +#endif + + // now perform the actual calculation of the covariance matrix. + typename visual_studio_sucks_cov_helper::type cov(m(0).nr(),m(0).nr()); + set_all_elements(cov,0); + + const typename EXP::type avg = mean(m); + + for (long r = 0; r < m.nr(); ++r) + { + cov += (m(r) - avg)*trans(m(r) - avg); + } + + cov *= 1.0 / (m.nr() - 1.0); + return cov; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const typename matrix_exp::type prod ( + const matrix_exp& m + ) + { + typedef typename matrix_exp::type type; + + type val = 1; + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + val *= m(r,c); + } + } + return val; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct op_uniform_matrix_3 : does_not_alias + { + op_uniform_matrix_3(const long& rows_, const long& cols_, const T& val_ ) : + rows(rows_), cols(cols_), val(val_) {} + + const long rows; + const long cols; + const T val; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 0; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + typedef T type; + typedef const T& const_ret_type; + const_ret_type apply (long, long ) const { return val; } + + long nr() const { return rows; } + long nc() const { return cols; } + }; + + template < + typename T + > + const matrix_op > uniform_matrix ( + long nr, + long nc, + const T& val + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0, + "\tconst matrix_exp uniform_matrix(nr, nc, val)" + << "\n\tnr and nc have to be bigger than 0" + << "\n\tnr: " << nr + << "\n\tnc: " << nc + ); + typedef op_uniform_matrix_3 op; + return matrix_op(op(nr, nc, val)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > zeros_matrix ( + long nr, + long nc + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0, + "\tconst matrix_exp zeros_matrix(nr, nc)" + << "\n\tnr and nc have to be >= 0" + << "\n\tnr: " << nr + << "\n\tnc: " << nc + ); + typedef op_uniform_matrix_3 op; + return matrix_op(op(nr, nc, 0)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix_op > zeros_matrix ( + const matrix_exp& mat + ) + { + DLIB_ASSERT(mat.nr() >= 0 && mat.nc() >= 0, + "\tconst matrix_exp zeros_matrix(mat)" + << "\n\t nr and nc have to be >= 0" + << "\n\t mat.nr(): " << mat.nr() + << "\n\t mat.nc(): " << mat.nc() + ); + typedef typename EXP::type T; + typedef op_uniform_matrix_3 op; + return matrix_op(op(mat.nr(), mat.nc(), 0)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > ones_matrix ( + long nr, + long nc + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0, + "\tconst matrix_exp ones_matrix(nr, nc)" + << "\n\tnr and nc have to be >= 0" + << "\n\tnr: " << nr + << "\n\tnc: " << nc + ); + typedef op_uniform_matrix_3 op; + return matrix_op(op(nr, nc, 1)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + const matrix_op > ones_matrix ( + const matrix_exp& mat + ) + { + DLIB_ASSERT(mat.nr() >= 0 && mat.nc() >= 0, + "\tconst matrix_exp ones_matrix(mat)" + << "\n\t nr and nc have to be >= 0" + << "\n\t mat.nr(): " << mat.nr() + << "\n\t mat.nc(): " << mat.nc() + ); + typedef typename EXP::type T; + typedef op_uniform_matrix_3 op; + return matrix_op(op(mat.nr(), mat.nc(), 1)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR_, + long NC_ + > + struct op_uniform_matrix_2 : does_not_alias + { + op_uniform_matrix_2( const T& val_ ) : val(val_) {} + const T val; + + const static long cost = 1; + const static long NR = NR_; + const static long NC = NC_; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + typedef T type; + typedef const T& const_ret_type; + + const_ret_type apply (long , long ) const { return val; } + + long nr() const { return NR; } + long nc() const { return NC; } + }; + + template < + typename T, + long NR, + long NC + > + const matrix_op > uniform_matrix ( + const T& val + ) + { + COMPILE_TIME_ASSERT(NR > 0 && NC > 0); + + typedef op_uniform_matrix_2 op; + return matrix_op(op(val)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR_, + long NC_, + T val + > + struct op_uniform_matrix : does_not_alias + { + const static long cost = 1; + const static long NR = NR_; + const static long NC = NC_; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + typedef T type; + typedef const T const_ret_type; + const_ret_type apply ( long , long ) const { return val; } + + long nr() const { return NR; } + long nc() const { return NC; } + }; + + template < + typename T, + long NR, + long NC, + T val + > + const matrix_op > uniform_matrix ( + ) + { + COMPILE_TIME_ASSERT(NR > 0 && NC > 0); + typedef op_uniform_matrix op; + return matrix_op(op()); + } + +// ---------------------------------------------------------------------------------------- + + struct op_gaussian_randm : does_not_alias + { + op_gaussian_randm ( + long nr_, + long nc_, + unsigned long seed_ + ) :_nr(nr_), _nc(nc_), seed(seed_){} + + const long _nr; + const long _nc; + const unsigned long seed; + + const static long cost = 100; + const static long NR = 0; + const static long NC = 0; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + typedef double type; + typedef double const_ret_type; + const_ret_type apply ( long r, long c) const { return gaussian_random_hash(r,c,seed); } + + long nr() const { return _nr; } + long nc() const { return _nc; } + }; + + inline const matrix_op gaussian_randm ( + long nr, + long nc, + unsigned long seed = 0 + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0, + "\tmatrix_exp gaussian_randm(nr, nc, seed)" + << "\n\tInvalid inputs to this function" + << "\n\tnr: " << nr + << "\n\tnc: " << nc + ); + + typedef op_gaussian_randm op; + return matrix_op(op(nr,nc,seed)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_add_diag + { + op_add_diag( const M& m_, const typename M::type& value_) : m(m_), value(value_){} + const M& m; + const typename M::type value; + + const static long cost = M::cost+1; + const static long NR = M::NR; + const static long NC = M::NC; + typedef typename M::type type; + typedef const typename M::type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long c) const + { + if (r==c) + return m(r,c)+value; + else + return m(r,c); + } + + long nr () const { return m.nr(); } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.destructively_aliases(item); } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct op_identity_matrix_2 : does_not_alias + { + op_identity_matrix_2(const long& size_) : size(size_) {} + + const long size; + + const static long cost = 1; + const static long NR = 0; + const static long NC = 0; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + typedef T type; + typedef const T const_ret_type; + const_ret_type apply (long r, long c) const { return static_cast(r == c); } + + long nr() const { return size; } + long nc() const { return size; } + }; + + template < + typename T, + typename U + > + const matrix_diag_op > identity_matrix ( + const U& size + ) + { + // the size argument must be some scalar value, not a matrix! + COMPILE_TIME_ASSERT(is_matrix::value == false); + + DLIB_ASSERT(size > 0, + "\tconst matrix_exp identity_matrix(size)" + << "\n\tsize must be bigger than 0" + << "\n\tsize: " << size + ); + typedef op_identity_matrix_2 op; + return matrix_diag_op(op(size)); + } + + template < + typename EXP + > + const matrix_diag_op > identity_matrix ( + const matrix_exp& mat + ) + { + DLIB_ASSERT(mat.nr() == mat.nc(), + "\tconst matrix_exp identity_matrix(mat)" + << "\n\t mat must be a square matrix." + << "\n\t mat.nr(): " << mat.nr() + << "\n\t mat.nc(): " << mat.nc() + ); + typedef typename EXP::type T; + typedef op_identity_matrix_2 op; + return matrix_diag_op(op(mat.nr())); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename T + > + const matrix_op > operator+ ( + const matrix_exp& lhs, + const matrix_exp > >& DLIB_IF_ASSERT(rhs) + ) + { + // both matrices must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // You can only add matrices together if they both have the same number of rows and columns. + DLIB_ASSERT(lhs.nc() == rhs.nc() && + lhs.nr() == rhs.nr(), + "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)" + << "\n\tYou are trying to add two incompatible matrices together" + << "\n\tlhs.nr(): " << lhs.nr() + << "\n\tlhs.nc(): " << lhs.nc() + << "\n\trhs.nr(): " << rhs.nr() + << "\n\trhs.nc(): " << rhs.nc() + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + + typedef op_add_diag op; + return matrix_op(op(lhs.ref(),1)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename T + > + const matrix_op > operator+ ( + const matrix_exp > >& DLIB_IF_ASSERT(lhs), + const matrix_exp& rhs + ) + { + // both matrices must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // You can only add matrices together if they both have the same number of rows and columns. + DLIB_ASSERT(lhs.nc() == rhs.nc() && + lhs.nr() == rhs.nr(), + "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)" + << "\n\tYou are trying to add two incompatible matrices together" + << "\n\tlhs.nr(): " << lhs.nr() + << "\n\tlhs.nc(): " << lhs.nc() + << "\n\trhs.nr(): " << rhs.nr() + << "\n\trhs.nc(): " << rhs.nc() + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + + typedef op_add_diag op; + return matrix_op(op(rhs.ref(),1)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long N + > + struct op_const_diag_matrix : does_not_alias + { + op_const_diag_matrix(const long& size_, const T& value_) : size(size_),value(value_) {} + + const long size; + const T value; + + const static long cost = 1; + const static long NR = N; + const static long NC = N; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + typedef T type; + typedef const T const_ret_type; + const_ret_type apply (long r, long c) const + { + if (r == c) + return value; + else + return 0; + } + + long nr() const { return size; } + long nc() const { return size; } + }; + + template < + typename T, + typename U + > + const typename disable_if, matrix_diag_op > >::type operator* ( + const matrix_exp > >& m, + const U& value + ) + { + typedef op_const_diag_matrix op; + return matrix_diag_op(op(m.nr(), value)); + } + + template < + typename T, + typename U + > + const typename disable_if, matrix_diag_op > >::type operator* ( + const U& value, + const matrix_exp > >& m + ) + { + typedef op_const_diag_matrix op; + return matrix_diag_op(op(m.nr(), value)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename T, + long N + > + const matrix_op > operator+ ( + const matrix_exp& lhs, + const matrix_exp > >& rhs + ) + { + // both matrices must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // You can only add matrices together if they both have the same number of rows and columns. + DLIB_ASSERT(lhs.nc() == rhs.nc() && + lhs.nr() == rhs.nr(), + "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)" + << "\n\tYou are trying to add two incompatible matrices together" + << "\n\tlhs.nr(): " << lhs.nr() + << "\n\tlhs.nc(): " << lhs.nc() + << "\n\trhs.nr(): " << rhs.nr() + << "\n\trhs.nc(): " << rhs.nc() + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + + typedef op_add_diag op; + return matrix_op(op(lhs.ref(),rhs.ref().op.value)); + } + + template < + typename EXP, + typename T, + long N + > + const matrix_op > operator+ ( + const matrix_exp > >& lhs, + const matrix_exp& rhs + ) + { + // both matrices must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // You can only add matrices together if they both have the same number of rows and columns. + DLIB_ASSERT(lhs.nc() == rhs.nc() && + lhs.nr() == rhs.nr(), + "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)" + << "\n\tYou are trying to add two incompatible matrices together" + << "\n\tlhs.nr(): " << lhs.nr() + << "\n\tlhs.nc(): " << lhs.nc() + << "\n\trhs.nr(): " << rhs.nr() + << "\n\trhs.nc(): " << rhs.nc() + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + + typedef op_add_diag op; + return matrix_op(op(rhs.ref(),lhs.ref().op.value)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long N + > + struct op_identity_matrix : does_not_alias + { + const static long cost = 1; + const static long NR = N; + const static long NC = N; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + typedef T type; + typedef const T const_ret_type; + const_ret_type apply ( long r, long c) const { return static_cast(r == c); } + + long nr () const { return NR; } + long nc () const { return NC; } + }; + + template < + typename T, + long N + > + const matrix_diag_op > identity_matrix ( + ) + { + COMPILE_TIME_ASSERT(N > 0); + + typedef op_identity_matrix op; + return matrix_diag_op(op()); + } + + template < + typename T, + typename U, + long N + > + const typename disable_if, matrix_diag_op > >::type operator* ( + const matrix_exp > >& m, + const U& value + ) + { + typedef op_const_diag_matrix op; + return matrix_diag_op(op(m.nr(), value)); + } + + template < + typename T, + typename U, + long N + > + const typename disable_if, matrix_diag_op > >::type operator* ( + const U& value, + const matrix_exp > >& m + ) + { + typedef op_const_diag_matrix op; + return matrix_diag_op(op(m.nr(), value)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename T, + long N + > + const matrix_op > operator+ ( + const matrix_exp > >& DLIB_IF_ASSERT(lhs), + const matrix_exp& rhs + ) + { + // both matrices must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // You can only add matrices together if they both have the same number of rows and columns. + DLIB_ASSERT(lhs.nc() == rhs.nc() && + lhs.nr() == rhs.nr(), + "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)" + << "\n\tYou are trying to add two incompatible matrices together" + << "\n\tlhs.nr(): " << lhs.nr() + << "\n\tlhs.nc(): " << lhs.nc() + << "\n\trhs.nr(): " << rhs.nr() + << "\n\trhs.nc(): " << rhs.nc() + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + + typedef op_add_diag op; + return matrix_op(op(rhs.ref(),1)); + } + + template < + typename EXP, + typename T, + long N + > + const matrix_op > operator+ ( + const matrix_exp& lhs, + const matrix_exp > >& DLIB_IF_ASSERT(rhs) + ) + { + // both matrices must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // You can only add matrices together if they both have the same number of rows and columns. + DLIB_ASSERT(lhs.nc() == rhs.nc() && + lhs.nr() == rhs.nr(), + "\tconst matrix_exp operator+(const matrix_exp& lhs, const matrix_exp& rhs)" + << "\n\tYou are trying to add two incompatible matrices together" + << "\n\tlhs.nr(): " << lhs.nr() + << "\n\tlhs.nc(): " << lhs.nc() + << "\n\trhs.nr(): " << rhs.nr() + << "\n\trhs.nc(): " << rhs.nc() + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + + typedef op_add_diag op; + return matrix_op(op(lhs.ref(),1)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_rotate + { + op_rotate(const M& m_) : m(m_) {} + const M& m; + + const static long cost = M::cost + 2; + const static long NR = M::NR; + const static long NC = M::NC; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + const_ret_type apply ( long r, long c) const { return m((r+R)%m.nr(),(c+C)%m.nc()); } + + long nr () const { return m.nr(); } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + long R, + long C, + typename EXP + > + const matrix_op > rotate ( + const matrix_exp& m + ) + { + typedef op_rotate op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + // A template to tell me if two types can be multiplied together in a sensible way. Here + // I'm saying it is ok if they are both the same type or one is the complex version of the other. + template struct compatible { static const bool value = false; typedef T type; }; + template struct compatible { static const bool value = true; typedef T type; }; + template struct compatible,T> { static const bool value = true; typedef std::complex type; }; + template struct compatible > { static const bool value = true; typedef std::complex type; }; + } + + + template + struct op_pointwise_multiply : basic_op_mm + { + op_pointwise_multiply( const M1& m1_, const M2& m2_) : basic_op_mm(m1_,m2_){} + + typedef typename impl::compatible::type type; + typedef const type const_ret_type; + const static long cost = M1::cost + M2::cost + 1; + + const_ret_type apply ( long r, long c) const + { return this->m1(r,c)*this->m2(r,c); } + }; + + template < + typename EXP1, + typename EXP2 + > + inline const matrix_op > pointwise_multiply ( + const matrix_exp& a, + const matrix_exp& b + ) + { + COMPILE_TIME_ASSERT((impl::compatible::value == true)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0); + DLIB_ASSERT(a.nr() == b.nr() && + a.nc() == b.nc(), + "\tconst matrix_exp pointwise_multiply(const matrix_exp& a, const matrix_exp& b)" + << "\n\tYou can only make a do a pointwise multiply with two equally sized matrices" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nr(): " << b.nr() + << "\n\tb.nc(): " << b.nc() + ); + typedef op_pointwise_multiply op; + return matrix_op(op(a.ref(),b.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_pointwise_multiply3 : basic_op_mmm + { + op_pointwise_multiply3( const M1& m1_, const M2& m2_, const M3& m3_) : + basic_op_mmm(m1_,m2_,m3_){} + + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + const static long cost = M1::cost + M2::cost + M3::cost + 2; + + const_ret_type apply (long r, long c) const + { return this->m1(r,c)*this->m2(r,c)*this->m3(r,c); } + }; + + template < + typename EXP1, + typename EXP2, + typename EXP3 + > + inline const matrix_op > + pointwise_multiply ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp& c + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NR == 0 || EXP2::NC == 0); + COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0); + COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0); + DLIB_ASSERT(a.nr() == b.nr() && + a.nc() == b.nc() && + b.nr() == c.nr() && + b.nc() == c.nc(), + "\tconst matrix_exp pointwise_multiply(a,b,c)" + << "\n\tYou can only make a do a pointwise multiply between equally sized matrices" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nr(): " << b.nr() + << "\n\tb.nc(): " << b.nc() + << "\n\tc.nr(): " << c.nr() + << "\n\tc.nc(): " << c.nc() + ); + + typedef op_pointwise_multiply3 op; + return matrix_op(op(a.ref(),b.ref(),c.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_pointwise_multiply4 : basic_op_mmmm + { + op_pointwise_multiply4( const M1& m1_, const M2& m2_, const M3& m3_, const M4& m4_) : + basic_op_mmmm(m1_,m2_,m3_,m4_){} + + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + const static long cost = M1::cost + M2::cost + M3::cost + M4::cost + 3; + + const_ret_type apply (long r, long c) const + { return this->m1(r,c)*this->m2(r,c)*this->m3(r,c)*this->m4(r,c); } + }; + + + template < + typename EXP1, + typename EXP2, + typename EXP3, + typename EXP4 + > + inline const matrix_op > pointwise_multiply ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp& c, + const matrix_exp& d + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0 ); + COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0); + COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0); + COMPILE_TIME_ASSERT(EXP3::NR == EXP4::NR || EXP3::NR == 0 || EXP4::NR == 0); + COMPILE_TIME_ASSERT(EXP3::NC == EXP4::NC || EXP3::NC == 0 || EXP4::NC == 0); + DLIB_ASSERT(a.nr() == b.nr() && + a.nc() == b.nc() && + b.nr() == c.nr() && + b.nc() == c.nc() && + c.nr() == d.nr() && + c.nc() == d.nc(), + "\tconst matrix_exp pointwise_multiply(a,b,c,d)" + << "\n\tYou can only make a do a pointwise multiply between equally sized matrices" + << "\n\ta.nr(): " << a.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nr(): " << b.nr() + << "\n\tb.nc(): " << b.nc() + << "\n\tc.nr(): " << c.nr() + << "\n\tc.nc(): " << c.nc() + << "\n\td.nr(): " << d.nr() + << "\n\td.nc(): " << d.nc() + ); + + typedef op_pointwise_multiply4 op; + return matrix_op(op(a.ref(),b.ref(),c.ref(),d.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename P, + int type = static_switch< + pixel_traits

    ::grayscale, + pixel_traits

    ::rgb, + pixel_traits

    ::hsi, + pixel_traits

    ::rgb_alpha, + pixel_traits

    ::lab + >::value + > + struct pixel_to_vector_helper; + + template + struct pixel_to_vector_helper + { + template + static void assign ( + M& m, + const P& pixel + ) + { + m(0) = static_cast(pixel); + } + }; + + template + struct pixel_to_vector_helper + { + template + static void assign ( + M& m, + const P& pixel + ) + { + m(0) = static_cast(pixel.red); + m(1) = static_cast(pixel.green); + m(2) = static_cast(pixel.blue); + } + }; + + template + struct pixel_to_vector_helper + { + template + static void assign ( + M& m, + const P& pixel + ) + { + m(0) = static_cast(pixel.h); + m(1) = static_cast(pixel.s); + m(2) = static_cast(pixel.i); + } + }; + + template + struct pixel_to_vector_helper + { + template + static void assign ( + M& m, + const P& pixel + ) + { + m(0) = static_cast(pixel.red); + m(1) = static_cast(pixel.green); + m(2) = static_cast(pixel.blue); + m(3) = static_cast(pixel.alpha); + } + }; + + template + struct pixel_to_vector_helper + { + template + static void assign ( + M& m, + const P& pixel + ) + { + m(0) = static_cast(pixel.l); + m(1) = static_cast(pixel.a); + m(2) = static_cast(pixel.b); + } + }; + + + template < + typename T, + typename P + > + inline const matrix::num,1> pixel_to_vector ( + const P& pixel + ) + { + COMPILE_TIME_ASSERT(pixel_traits

    ::num > 0); + matrix::num,1> m; + pixel_to_vector_helper

    ::assign(m,pixel); + return m; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename P, + int type = static_switch< + pixel_traits

    ::grayscale, + pixel_traits

    ::rgb, + pixel_traits

    ::hsi, + pixel_traits

    ::rgb_alpha, + pixel_traits

    ::lab + >::value + > + struct vector_to_pixel_helper; + + template + struct vector_to_pixel_helper + { + template + static void assign ( + P& pixel, + const M& m + ) + { + pixel = static_cast(m(0)); + } + }; + + template + struct vector_to_pixel_helper + { + template + static void assign ( + P& pixel, + const M& m + ) + { + pixel.red = static_cast(m(0)); + pixel.green = static_cast(m(1)); + pixel.blue = static_cast(m(2)); + } + }; + + template + struct vector_to_pixel_helper + { + template + static void assign ( + P& pixel, + const M& m + ) + { + pixel.h = static_cast(m(0)); + pixel.s = static_cast(m(1)); + pixel.i = static_cast(m(2)); + } + }; + + template + struct vector_to_pixel_helper + { + template + static void assign ( + P& pixel, + const M& m + ) + { + pixel.red = static_cast(m(0)); + pixel.green = static_cast(m(1)); + pixel.blue = static_cast(m(2)); + pixel.alpha = static_cast(m(3)); + } + }; + + template + struct vector_to_pixel_helper + { + template + static void assign ( + P& pixel, + const M& m + ) + { + pixel.l = static_cast(m(0)); + pixel.a = static_cast(m(1)); + pixel.b = static_cast(m(2)); + } + }; + + template < + typename P, + typename EXP + > + inline void vector_to_pixel ( + P& pixel, + const matrix_exp& vector + ) + { + COMPILE_TIME_ASSERT(pixel_traits

    ::num == matrix_exp::NR); + COMPILE_TIME_ASSERT(matrix_exp::NC == 1); + vector_to_pixel_helper

    ::assign(pixel,vector); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_clamp : basic_op_m + { + op_clamp( const M& m_) : basic_op_m(m_){} + + typedef typename M::type type; + typedef const typename M::type const_ret_type; + const static long cost = M::cost + 2; + + const_ret_type apply ( long r, long c) const + { + const type temp = this->m(r,c); + if (temp > static_cast(upper)) + return static_cast(upper); + else if (temp < static_cast(lower)) + return static_cast(lower); + else + return temp; + } + }; + + template < + long l, + long u, + typename EXP + > + const matrix_op > clamp ( + const matrix_exp& m + ) + { + typedef op_clamp op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_clamp2 : basic_op_m + { + typedef typename M::type type; + + op_clamp2( const M& m_, const type& l, const type& u) : + basic_op_m(m_), lower(l), upper(u){} + + const type& lower; + const type& upper; + + typedef const typename M::type const_ret_type; + const static long cost = M::cost + 2; + + const_ret_type apply ( long r, long c) const + { + const type temp = this->m(r,c); + if (temp > upper) + return upper; + else if (temp < lower) + return lower; + else + return temp; + } + }; + + template < + typename EXP + > + const matrix_op > clamp ( + const matrix_exp& m, + const typename EXP::type& lower, + const typename EXP::type& upper + ) + { + typedef op_clamp2 op; + return matrix_op(op(m.ref(),lower, upper)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_clamp_m : basic_op_mmm + { + op_clamp_m( const M1& m1_, const M2& m2_, const M3& m3_) : + basic_op_mmm(m1_,m2_,m3_){} + + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + const static long cost = M1::cost + M2::cost + M3::cost + 2; + + const_ret_type apply (long r, long c) const + { + const type val = this->m1(r,c); + const type lower = this->m2(r,c); + const type upper = this->m3(r,c); + if (val <= upper) + { + if (lower <= val) + return val; + else + return lower; + } + else + { + return upper; + } + } + }; + + template < + typename EXP1, + typename EXP2, + typename EXP3 + > + const matrix_op > + clamp ( + const matrix_exp& m, + const matrix_exp& lower, + const matrix_exp& upper + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT((is_same_type::value == true)); + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0); + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NR == 0 || EXP2::NC == 0); + COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0); + COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0); + DLIB_ASSERT(m.nr() == lower.nr() && + m.nc() == lower.nc() && + m.nr() == upper.nr() && + m.nc() == upper.nc(), + "\tconst matrix_exp clamp(m,lower,upper)" + << "\n\t Invalid inputs were given to this function." + << "\n\t m.nr(): " << m.nr() + << "\n\t m.nc(): " << m.nc() + << "\n\t lower.nr(): " << lower.nr() + << "\n\t lower.nc(): " << lower.nc() + << "\n\t upper.nr(): " << upper.nr() + << "\n\t upper.nc(): " << upper.nc() + ); + + typedef op_clamp_m op; + return matrix_op(op(m.ref(),lower.ref(),upper.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_lowerbound : basic_op_m + { + typedef typename M::type type; + + op_lowerbound( const M& m_, const type& thresh_) : + basic_op_m(m_), thresh(thresh_){} + + const type& thresh; + + typedef const typename M::type const_ret_type; + const static long cost = M::cost + 2; + + const_ret_type apply ( long r, long c) const + { + const type temp = this->m(r,c); + if (temp >= thresh) + return temp; + else + return thresh; + } + }; + + template < + typename EXP + > + const matrix_op > lowerbound ( + const matrix_exp& m, + const typename EXP::type& thresh + ) + { + typedef op_lowerbound op; + return matrix_op(op(m.ref(), thresh)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_upperbound : basic_op_m + { + typedef typename M::type type; + + op_upperbound( const M& m_, const type& thresh_) : + basic_op_m(m_), thresh(thresh_){} + + const type& thresh; + + typedef const typename M::type const_ret_type; + const static long cost = M::cost + 2; + + const_ret_type apply ( long r, long c) const + { + const type temp = this->m(r,c); + if (temp <= thresh) + return temp; + else + return thresh; + } + }; + + template < + typename EXP + > + const matrix_op > upperbound ( + const matrix_exp& m, + const typename EXP::type& thresh + ) + { + typedef op_upperbound op; + return matrix_op(op(m.ref(), thresh)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_reshape + { + op_reshape(const M& m_, const long& rows_, const long& cols_) : m(m_),rows(rows_),cols(cols_) {} + const M& m; + const long rows; + const long cols; + + const static long cost = M::cost+2; + const static long NR = 0; + const static long NC = 0; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + const_ret_type apply ( long r, long c) const + { + const long idx = r*cols + c; + return m(idx/m.nc(), idx%m.nc()); + } + + long nr () const { return rows; } + long nc () const { return cols; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > reshape ( + const matrix_exp& m, + const long& rows, + const long& cols + ) + { + DLIB_ASSERT(m.size() == rows*cols && rows > 0 && cols > 0, + "\tconst matrix_exp reshape(m, rows, cols)" + << "\n\t The size of m must match the dimensions you want to reshape it into." + << "\n\t m.size(): " << m.size() + << "\n\t rows*cols: " << rows*cols + << "\n\t rows: " << rows + << "\n\t cols: " << cols + ); + + typedef op_reshape op; + return matrix_op(op(m.ref(), rows, cols)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2 + > + typename disable_if,bool>::type equal ( + const matrix_exp& a, + const matrix_exp& b, + const typename EXP1::type eps = 100*std::numeric_limits::epsilon() + ) + { + // check if the dimensions don't match + if (a.nr() != b.nr() || a.nc() != b.nc()) + return false; + + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + if (std::abs(a(r,c)-b(r,c)) > eps) + return false; + } + } + + // no non-equal points found so we return true + return true; + } + + template < + typename EXP1, + typename EXP2 + > + typename enable_if,bool>::type equal ( + const matrix_exp& a, + const matrix_exp& b, + const typename EXP1::type::value_type eps = 100*std::numeric_limits::epsilon() + ) + { + // check if the dimensions don't match + if (a.nr() != b.nr() || a.nc() != b.nc()) + return false; + + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + if (std::abs(real(a(r,c)-b(r,c))) > eps || + std::abs(imag(a(r,c)-b(r,c))) > eps) + return false; + } + } + + // no non-equal points found so we return true + return true; + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_scale_columns + { + op_scale_columns(const M1& m1_, const M2& m2_) : m1(m1_), m2(m2_) {} + const M1& m1; + const M2& m2; + + const static long cost = M1::cost + M2::cost + 1; + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + const static long NR = M1::NR; + const static long NC = M1::NC; + + const_ret_type apply ( long r, long c) const { return m1(r,c)*m2(c); } + + long nr () const { return m1.nr(); } + long nc () const { return m1.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item) ; } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.destructively_aliases(item) || m2.aliases(item); } + }; + + template < + typename EXP1, + typename EXP2 + > + const matrix_op > scale_columns ( + const matrix_exp& m, + const matrix_exp& v + ) + { + // Both arguments to this function must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + // The v argument must be a row or column vector. + COMPILE_TIME_ASSERT((EXP2::NC == 1 || EXP2::NC == 0) || (EXP2::NR == 1 || EXP2::NR == 0)); + + // figure out the compile time known length of v + const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax::value); + + // the length of v must match the number of columns in m + COMPILE_TIME_ASSERT(EXP1::NC == v_len || EXP1::NC == 0 || v_len == 0); + + DLIB_ASSERT(is_vector(v) == true && v.size() == m.nc(), + "\tconst matrix_exp scale_columns(m, v)" + << "\n\tv must be a row or column vector and its length must match the number of columns in m" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tv.nr(): " << v.nr() + << "\n\tv.nc(): " << v.nc() + ); + typedef op_scale_columns op; + return matrix_op(op(m.ref(),v.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_scale_columns_diag + { + op_scale_columns_diag(const M1& m1_, const M2& m2_) : m1(m1_), m2(m2_) {} + const M1& m1; + const M2& m2; + + const static long cost = M1::cost + M2::cost + 1; + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + const static long NR = M1::NR; + const static long NC = M1::NC; + + const_ret_type apply ( long r, long c) const { return m1(r,c)*m2(c,c); } + + long nr () const { return m1.nr(); } + long nc () const { return m1.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item) ; } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.destructively_aliases(item) || m2.aliases(item); } + }; + +// turn expressions of the form mat*diagonal_matrix into scale_columns(mat, d) + template < + typename EXP1, + typename EXP2 + > + const matrix_op > operator* ( + const matrix_exp& m, + const matrix_diag_exp& d + ) + { + // Both arguments to this function must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // figure out the compile time known length of d + const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax::value); + + // the length of d must match the number of columns in m + COMPILE_TIME_ASSERT(EXP1::NC == v_len || EXP1::NC == 0 || v_len == 0); + + DLIB_ASSERT(m.nc() == d.nr(), + "\tconst matrix_exp operator*(m, d)" + << "\n\tmatrix dimensions don't match" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\td.nr(): " << d.nr() + << "\n\td.nc(): " << d.nc() + ); + typedef op_scale_columns_diag op; + return matrix_op(op(m.ref(),d.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_scale_rows + { + op_scale_rows(const M1& m1_, const M2& m2_) : m1(m1_), m2(m2_) {} + const M1& m1; + const M2& m2; + + const static long cost = M1::cost + M2::cost + 1; + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + const static long NR = M1::NR; + const static long NC = M1::NC; + + const_ret_type apply ( long r, long c) const { return m1(r,c)*m2(r); } + + long nr () const { return m1.nr(); } + long nc () const { return m1.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item) ; } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.destructively_aliases(item) || m2.aliases(item); } + }; + + template < + typename EXP1, + typename EXP2 + > + const matrix_op > scale_rows ( + const matrix_exp& m, + const matrix_exp& v + ) + { + // Both arguments to this function must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + // The v argument must be a row or column vector. + COMPILE_TIME_ASSERT((EXP2::NC == 1 || EXP2::NC == 0) || (EXP2::NR == 1 || EXP2::NR == 0)); + + // figure out the compile time known length of v + const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax::value); + + // the length of v must match the number of rows in m + COMPILE_TIME_ASSERT(EXP1::NR == v_len || EXP1::NR == 0 || v_len == 0); + + DLIB_ASSERT(is_vector(v) == true && v.size() == m.nr(), + "\tconst matrix_exp scale_rows(m, v)" + << "\n\tv must be a row or column vector and its length must match the number of rows in m" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tv.nr(): " << v.nr() + << "\n\tv.nc(): " << v.nc() + ); + typedef op_scale_rows op; + return matrix_op(op(m.ref(),v.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_scale_rows_diag + { + op_scale_rows_diag(const M1& m1_, const M2& m2_) : m1(m1_), m2(m2_) {} + const M1& m1; + const M2& m2; + + const static long cost = M1::cost + M2::cost + 1; + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + const static long NR = M1::NR; + const static long NC = M1::NC; + + const_ret_type apply ( long r, long c) const { return m1(r,c)*m2(r,r); } + + long nr () const { return m1.nr(); } + long nc () const { return m1.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item) ; } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.destructively_aliases(item) || m2.aliases(item); } + }; + +// turn expressions of the form diagonal_matrix*mat into scale_rows(mat, d) + template < + typename EXP1, + typename EXP2 + > + const matrix_op > operator* ( + const matrix_diag_exp& d, + const matrix_exp& m + ) + { + // Both arguments to this function must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // figure out the compile time known length of d + const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax::value); + + // the length of d must match the number of rows in m + COMPILE_TIME_ASSERT(EXP1::NR == v_len || EXP1::NR == 0 || v_len == 0); + + DLIB_ASSERT(d.nc() == m.nr(), + "\tconst matrix_exp operator*(d, m)" + << "\n\tThe dimensions of the d and m matrices don't match." + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\td.nr(): " << d.nr() + << "\n\td.nc(): " << d.nc() + ); + typedef op_scale_rows_diag op; + return matrix_op(op(m.ref(),d.ref())); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + /* + The idea here is to catch expressions of the form d*M*d where d is diagonal and M + is some square matrix and turn them into something equivalent to + pointwise_multiply(diag(d)*trans(diag(d)), M). + + The reason for this is that doing it this way is more numerically stable. In particular, + doing 2 matrix multiplies as suggested by d*M*d could result in an asymmetric matrix even + if M is symmetric to begin with. + */ + + template + struct op_diag_m_diag + { + // This operator represents M1*M2*M3 where M1 and M3 are diagonal + + op_diag_m_diag(const M1& m1_, const M2& m2_, const M3& m3_) : m1(m1_), m2(m2_), m3(m3_) {} + const M1& m1; + const M2& m2; + const M3& m3; + + const static long cost = M1::cost + M2::cost + M3::cost + 1; + typedef typename M2::type type; + typedef const typename M2::type const_ret_type; + typedef typename M2::mem_manager_type mem_manager_type; + typedef typename M2::layout_type layout_type; + const static long NR = M2::NR; + const static long NC = M2::NC; + + const_ret_type apply ( long r, long c) const { return (m1(r,r)*m3(c,c))*m2(r,c); } + + long nr () const { return m2.nr(); } + long nc () const { return m2.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item) || m3.aliases(item) ; } + template bool destructively_aliases ( const matrix_exp& item) const + { return m2.destructively_aliases(item) || m1.aliases(item) || m3.aliases(item) ; } + }; + + // catch d*(M*d) = EXP1*EXP2*EXP3 + template < + typename EXP1, + typename EXP2, + typename EXP3 + > + const matrix_op > operator* ( + const matrix_diag_exp& d, + const matrix_exp > >& m + ) + { + // Both arguments to this function must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // figure out the compile time known length of d + const long v_len = ((EXP1::NR)*(EXP1::NC) == 0)? 0 : (tmax::value); + + // the length of d must match the number of rows in m + COMPILE_TIME_ASSERT(EXP2::NR == v_len || EXP2::NR == 0 || v_len == 0); + + DLIB_ASSERT(d.nc() == m.nr(), + "\tconst matrix_exp operator*(d, m)" + << "\n\tmatrix dimensions don't match" + << "\n\td.nr(): " << d.nr() + << "\n\td.nc(): " << d.nc() + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + ); + typedef op_diag_m_diag op; + return matrix_op(op(d.ref(), m.ref().op.m1, m.ref().op.m2)); + } + + // catch (d*M)*d = EXP1*EXP2*EXP3 + template < + typename EXP1, + typename EXP2, + typename EXP3 + > + const matrix_op > operator* ( + const matrix_exp > >& m, + const matrix_diag_exp& d + ) + { + // Both arguments to this function must contain the same type of element + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + // figure out the compile time known length of d + const long v_len = ((EXP3::NR)*(EXP3::NC) == 0)? 0 : (tmax::value); + + // the length of d must match the number of columns in m + COMPILE_TIME_ASSERT(EXP2::NC == v_len || EXP2::NC == 0 || v_len == 0); + + DLIB_ASSERT(m.nc() == d.nr(), + "\tconst matrix_exp operator*(m, d)" + << "\n\tmatrix dimensions don't match" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\td.nr(): " << d.nr() + << "\n\td.nc(): " << d.nc() + ); + typedef op_diag_m_diag op; + return matrix_op(op(m.ref().op.m2, m.ref().op.m1, d.ref())); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct sort_columns_sort_helper + { + template + bool operator() ( + const T& item1, + const T& item2 + ) const + { + return item1.first < item2.first; + } + }; + + template < + typename T, long NR, long NC, typename mm, typename l1, + long NR2, long NC2, typename mm2, typename l2 + > + void sort_columns ( + matrix& m, + matrix& v + ) + { + COMPILE_TIME_ASSERT(NC2 == 1 || NC2 == 0); + COMPILE_TIME_ASSERT(NC == NR2 || NC == 0 || NR2 == 0); + + DLIB_ASSERT(is_col_vector(v) == true && v.size() == m.nc(), + "\tconst matrix_exp sort_columns(m, v)" + << "\n\tv must be a column vector and its length must match the number of columns in m" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tv.nr(): " << v.nr() + << "\n\tv.nc(): " << v.nc() + ); + + + + // Now we have to sort the given vectors in the m matrix according + // to how big their corresponding v(column index) values are. + typedef std::pair > col_pair; + typedef std_allocator alloc; + std::vector colvalues; + col_pair p; + for (long r = 0; r < v.nr(); ++r) + { + p.first = v(r); + p.second = colm(m,r); + colvalues.push_back(p); + } + std::sort(colvalues.begin(), colvalues.end(), sort_columns_sort_helper()); + + for (long i = 0; i < v.nr(); ++i) + { + v(i) = colvalues[i].first; + set_colm(m,i) = colvalues[i].second; + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, long NR, long NC, typename mm, typename l1, + long NR2, long NC2, typename mm2, typename l2 + > + void rsort_columns ( + matrix& m, + matrix& v + ) + { + COMPILE_TIME_ASSERT(NC2 == 1 || NC2 == 0); + COMPILE_TIME_ASSERT(NC == NR2 || NC == 0 || NR2 == 0); + + DLIB_ASSERT(is_col_vector(v) == true && v.size() == m.nc(), + "\tconst matrix_exp rsort_columns(m, v)" + << "\n\tv must be a column vector and its length must match the number of columns in m" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tv.nr(): " << v.nr() + << "\n\tv.nc(): " << v.nc() + ); + + + + // Now we have to sort the given vectors in the m matrix according + // to how big their corresponding v(column index) values are. + typedef std::pair > col_pair; + typedef std_allocator alloc; + std::vector colvalues; + col_pair p; + for (long r = 0; r < v.nr(); ++r) + { + p.first = v(r); + p.second = colm(m,r); + colvalues.push_back(p); + } + std::sort(colvalues.rbegin(), colvalues.rend(), sort_columns_sort_helper()); + + for (long i = 0; i < v.nr(); ++i) + { + v(i) = colvalues[i].first; + set_colm(m,i) = colvalues[i].second; + } + + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_tensor_product + { + op_tensor_product(const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_) {} + const M1& m1; + const M2& m2; + + const static long cost = M1::cost + M2::cost + 1; + const static long NR = M1::NR*M2::NR; + const static long NC = M1::NC*M2::NC; + typedef typename M1::type type; + typedef const typename M1::type const_ret_type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + + const_ret_type apply ( long r, long c) const + { + return m1(r/m2.nr(),c/m2.nc())*m2(r%m2.nr(),c%m2.nc()); + } + + long nr () const { return m1.nr()*m2.nr(); } + long nc () const { return m1.nc()*m2.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item); } + }; + + template < + typename EXP1, + typename EXP2 + > + inline const matrix_op > tensor_product ( + const matrix_exp& a, + const matrix_exp& b + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + typedef op_tensor_product op; + return matrix_op(op(a.ref(),b.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_make_symmetric : basic_op_m + { + op_make_symmetric ( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost+1; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + const_ret_type apply ( long r, long c) const + { + if (r >= c) + return this->m(r,c); + else + return this->m(c,r); + } + }; + + template < + typename EXP + > + const matrix_op > make_symmetric ( + const matrix_exp& m + ) + { + DLIB_ASSERT(m.nr() == m.nc(), + "\tconst matrix make_symmetric(m)" + << "\n\t m must be a square matrix" + << "\n\t m.nr(): " << m.nr() + << "\n\t m.nc(): " << m.nc() + ); + + typedef op_make_symmetric op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_lowerm : basic_op_m + { + op_lowerm( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost+2; + typedef typename M::type type; + typedef const typename M::type const_ret_type; + const_ret_type apply ( long r, long c) const + { + if (r >= c) + return this->m(r,c); + else + return 0; + } + }; + + template + struct op_lowerm_s : basic_op_m + { + typedef typename M::type type; + op_lowerm_s( const M& m_, const type& s_) : basic_op_m(m_), s(s_){} + + const type s; + + const static long cost = M::cost+2; + typedef const typename M::type const_ret_type; + const_ret_type apply ( long r, long c) const + { + if (r > c) + return this->m(r,c); + else if (r==c) + return s; + else + return 0; + } + }; + + template < + typename EXP + > + const matrix_op > lowerm ( + const matrix_exp& m + ) + { + typedef op_lowerm op; + return matrix_op(op(m.ref())); + } + + template < + typename EXP + > + const matrix_op > lowerm ( + const matrix_exp& m, + typename EXP::type s + ) + { + typedef op_lowerm_s op; + return matrix_op(op(m.ref(),s)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_upperm : basic_op_m + { + op_upperm( const M& m_) : basic_op_m(m_){} + + const static long cost = M::cost+2; + typedef typename M::type type; + typedef const typename M::type const_ret_type; + const_ret_type apply ( long r, long c) const + { + if (r <= c) + return this->m(r,c); + else + return 0; + } + }; + + template + struct op_upperm_s : basic_op_m + { + typedef typename M::type type; + op_upperm_s( const M& m_, const type& s_) : basic_op_m(m_), s(s_){} + + const type s; + + const static long cost = M::cost+2; + typedef const typename M::type const_ret_type; + const_ret_type apply ( long r, long c) const + { + if (r < c) + return this->m(r,c); + else if (r==c) + return s; + else + return 0; + } + }; + + template < + typename EXP + > + const matrix_op > upperm ( + const matrix_exp& m + ) + { + typedef op_upperm op; + return matrix_op(op(m.ref())); + } + + template < + typename EXP + > + const matrix_op > upperm ( + const matrix_exp& m, + typename EXP::type s + ) + { + typedef op_upperm_s op; + return matrix_op(op(m.ref(),s)); + } + +// ---------------------------------------------------------------------------------------- + + template + inline const matrix randm( + long nr, + long nc, + rand_gen& rnd + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0, + "\tconst matrix randm(nr, nc, rnd)" + << "\n\tInvalid inputs to this function" + << "\n\tnr: " << nr + << "\n\tnc: " << nc + ); + + matrix m(nr,nc); + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = rnd.get_random_double(); + } + } + + return m; + } + +// ---------------------------------------------------------------------------------------- + + inline const matrix randm( + long nr, + long nc + ) + { + DLIB_ASSERT(nr >= 0 && nc >= 0, + "\tconst matrix randm(nr, nc)" + << "\n\tInvalid inputs to this function" + << "\n\tnr: " << nr + << "\n\tnc: " << nc + ); + + matrix m(nr,nc); + // make a double that contains RAND_MAX + the smallest number that still + // makes the resulting double slightly bigger than static_cast(RAND_MAX) + double max_val = RAND_MAX; + max_val += std::numeric_limits::epsilon()*RAND_MAX; + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = std::rand()/max_val; + } + } + + return m; + } + +// ---------------------------------------------------------------------------------------- + + inline const matrix_range_exp linspace ( + double start, + double end, + long num + ) + { + DLIB_ASSERT(num >= 0, + "\tconst matrix_exp linspace(start, end, num)" + << "\n\tInvalid inputs to this function" + << "\n\tstart: " << start + << "\n\tend: " << end + << "\n\tnum: " << num + ); + + return matrix_range_exp(start,end,num,false); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_linpiece + { + op_linpiece(const double val_, const M& joints_) : joints(joints_), val(val_){} + + const M& joints; + const double val; + + const static long cost = 10; + + const static long NR = (M::NR*M::NC==0) ? (0) : (M::NR*M::NC-1); + const static long NC = 1; + typedef typename M::type type; + typedef default_memory_manager mem_manager_type; + typedef row_major_layout layout_type; + + typedef type const_ret_type; + const_ret_type apply (long i, long ) const + { + if (joints(i) < val) + return std::min(val,joints(i+1)) - joints(i); + else + return 0; + } + + long nr () const { return joints.size()-1; } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& item) const { return joints.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return joints.aliases(item); } + }; + + template < typename EXP > + const matrix_op > linpiece ( + const double val, + const matrix_exp& joints + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(joints) && joints.size() >= 2, + "\t matrix_exp linpiece()" + << "\n\t Invalid inputs were given to this function " + << "\n\t is_vector(joints): " << is_vector(joints) + << "\n\t joints.size(): " << joints.size() + ); +#ifdef ENABLE_ASSERTS + for (long i = 1; i < joints.size(); ++i) + { + DLIB_ASSERT(joints(i-1) < joints(i), + "\t matrix_exp linpiece()" + << "\n\t Invalid inputs were given to this function " + << "\n\t joints("< op; + return matrix_op(op(val,joints.ref())); + } + +// ---------------------------------------------------------------------------------------- + + inline const matrix_log_range_exp logspace ( + double start, + double end, + long num + ) + { + DLIB_ASSERT(num >= 0, + "\tconst matrix_exp logspace(start, end, num)" + << "\n\tInvalid inputs to this function" + << "\n\tstart: " << start + << "\n\tend: " << end + << "\n\tnum: " << num + ); + + return matrix_log_range_exp(start,end,num); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_cart_prod + { + op_cart_prod(const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_) {} + const M1& m1; + const M2& m2; + + const static long cost = M1::cost+M2::cost+1; + typedef typename M1::type type; + typedef const typename M1::const_ret_type const_ret_type; + + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + const static long NR = M1::NR+M2::NR; + const static long NC = M1::NC*M2::NC; + + const_ret_type apply ( long r, long c) const + { + if (r < m1.nr()) + return m1(r, c/m2.nc()); + else + return m2(r-m1.nr(), c%m2.nc()); + } + + long nr () const { return m1.nr() + m2.nr(); } + long nc () const { return m1.nc() * m2.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item); } + }; + + template < + typename EXP1, + typename EXP2 + > + const matrix_op > cartesian_product ( + const matrix_exp& a, + const matrix_exp& b + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + + typedef op_cart_prod op; + return matrix_op(op(a.ref(),b.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_mat_to_vect + { + op_mat_to_vect(const M& m_) : m(m_) {} + const M& m; + + const static long cost = M::cost+2; + const static long NR = M::NC*M::NR; + const static long NC = 1; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + const_ret_type apply ( long r, long ) const { return m(r/m.nc(), r%m.nc()); } + + long nr () const { return m.size(); } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > reshape_to_column_vector ( + const matrix_exp& m + ) + { + typedef op_mat_to_vect op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR_, + long NC_, + typename MM + > + struct op_mat_to_vect2 + { + typedef matrix M; + op_mat_to_vect2(const M& m_) : m(m_) {} + const M& m; + + const static long cost = M::cost+2; + const static long NR = M::NC*M::NR; + const static long NC = 1; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + const_ret_type apply ( long r, long ) const { return (&m(0,0))[r]; } + + long nr () const { return m.size(); } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename T, + long NR, + long NC, + typename MM + > + const matrix_op > reshape_to_column_vector ( + const matrix& m + ) + { + typedef op_mat_to_vect2 op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_join_rows + { + op_join_rows(const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_),_nr(std::max(m1.nr(),m2.nr())) {} + const M1& m1; + const M2& m2; + const long _nr; + + template + struct type_selector; + template + struct type_selector { typedef T type; }; + template + struct type_selector { typedef U type; }; + + // If both const_ret_types are references then we should use them as the const_ret_type type + // but otherwise we should use the normal type. + typedef typename M1::const_ret_type T1; + typedef typename M1::type T2; + typedef typename M2::const_ret_type T3; + typedef typename type_selector::value && is_reference_type::value>::type const_ret_type; + + const static long cost = M1::cost + M2::cost + 1; + const static long NR = tmax::value; + const static long NC = (M1::NC*M2::NC != 0)? (M1::NC+M2::NC) : (0); + typedef typename M1::type type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + + const_ret_type apply (long r, long c) const + { + if (c < m1.nc()) + return m1(r,c); + else + return m2(r,c-m1.nc()); + } + + long nr () const { return _nr; } + long nc () const { return m1.nc()+m2.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item); } + }; + + template < + typename EXP1, + typename EXP2 + > + inline const matrix_op > join_rows ( + const matrix_exp& a, + const matrix_exp& b + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + // You are getting an error on this line because you are trying to join two matrices that + // don't have the same number of rows + COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || (EXP1::NR*EXP2::NR == 0)); + + DLIB_ASSERT(a.nr() == b.nr() || a.size() == 0 || b.size() == 0, + "\tconst matrix_exp join_rows(const matrix_exp& a, const matrix_exp& b)" + << "\n\tYou can only use join_rows() if both matrices have the same number of rows" + << "\n\ta.nr(): " << a.nr() + << "\n\tb.nr(): " << b.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nc(): " << b.nc() + ); + + typedef op_join_rows op; + return matrix_op(op(a.ref(),b.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_join_cols + { + op_join_cols(const M1& m1_, const M2& m2_) : m1(m1_),m2(m2_),_nc(std::max(m1.nc(),m2.nc())) {} + const M1& m1; + const M2& m2; + const long _nc; + + template + struct type_selector; + template + struct type_selector { typedef T type; }; + template + struct type_selector { typedef U type; }; + + // If both const_ret_types are references then we should use them as the const_ret_type type + // but otherwise we should use the normal type. + typedef typename M1::const_ret_type T1; + typedef typename M1::type T2; + typedef typename M2::const_ret_type T3; + typedef typename type_selector::value && is_reference_type::value>::type const_ret_type; + + + + const static long cost = M1::cost + M2::cost + 1; + const static long NC = tmax::value; + const static long NR = (M1::NR*M2::NR != 0)? (M1::NR+M2::NR) : (0); + typedef typename M1::type type; + typedef typename M1::mem_manager_type mem_manager_type; + typedef typename M1::layout_type layout_type; + + const_ret_type apply ( long r, long c) const + { + if (r < m1.nr()) + return m1(r,c); + else + return m2(r-m1.nr(),c); + } + + long nr () const { return m1.nr()+m2.nr(); } + long nc () const { return _nc; } + + + template bool aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m1.aliases(item) || m2.aliases(item); } + }; + + template < + typename EXP1, + typename EXP2 + > + inline const matrix_op > join_cols ( + const matrix_exp& a, + const matrix_exp& b + ) + { + COMPILE_TIME_ASSERT((is_same_type::value == true)); + // You are getting an error on this line because you are trying to join two matrices that + // don't have the same number of columns + COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || (EXP1::NC*EXP2::NC == 0)); + + DLIB_ASSERT(a.nc() == b.nc() || a.size() == 0 || b.size() == 0, + "\tconst matrix_exp join_cols(const matrix_exp& a, const matrix_exp& b)" + << "\n\tYou can only use join_cols() if both matrices have the same number of columns" + << "\n\ta.nr(): " << a.nr() + << "\n\tb.nr(): " << b.nr() + << "\n\ta.nc(): " << a.nc() + << "\n\tb.nc(): " << b.nc() + ); + + typedef op_join_cols op; + return matrix_op(op(a.ref(),b.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_fliplr + { + op_fliplr( const M& m_) : m(m_){} + + const M& m; + + const static long cost = M::cost; + const static long NR = M::NR; + const static long NC = M::NC; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + const_ret_type apply (long r, long c) const { return m(r,m.nc()-c-1); } + + long nr () const { return m.nr(); } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + + }; + + template < + typename M + > + const matrix_op > fliplr ( + const matrix_exp& m + ) + { + typedef op_fliplr op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_flipud + { + op_flipud( const M& m_) : m(m_){} + + const M& m; + + const static long cost = M::cost; + const static long NR = M::NR; + const static long NC = M::NC; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + const_ret_type apply (long r, long c) const { return m(m.nr()-r-1,c); } + + long nr () const { return m.nr(); } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + + }; + + template < + typename M + > + const matrix_op > flipud ( + const matrix_exp& m + ) + { + typedef op_flipud op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_flip + { + op_flip( const M& m_) : m(m_){} + + const M& m; + + const static long cost = M::cost; + const static long NR = M::NR; + const static long NC = M::NC; + typedef typename M::type type; + typedef typename M::const_ret_type const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + const_ret_type apply (long r, long c) const { return m(m.nr()-r-1, m.nc()-c-1); } + + long nr () const { return m.nr(); } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + + }; + + template < + typename M + > + const matrix_op > flip ( + const matrix_exp& m + ) + { + typedef op_flip op; + return matrix_op(op(m.ref())); + } + +// ---------------------------------------------------------------------------------------- + + template + uint32 hash ( + const matrix& item, + uint32 seed = 0 + ) + { + DLIB_ASSERT_HAS_STANDARD_LAYOUT(T); + + if (item.size() == 0) + return 0; + else + return murmur_hash3(&item(0,0), sizeof(T)*item.size(), seed); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_UTILITIES_ + diff --git a/ml/dlib/dlib/matrix/matrix_utilities_abstract.h b/ml/dlib/dlib/matrix/matrix_utilities_abstract.h new file mode 100644 index 000000000..ad4c91167 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_utilities_abstract.h @@ -0,0 +1,1874 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MATRIx_UTILITIES_ABSTRACT_ +#ifdef DLIB_MATRIx_UTILITIES_ABSTRACT_ + +#include "matrix_abstract.h" +#include +#include "../pixel.h" +#include "../geometry/rectangle.h" +#inclue + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Simple matrix utilities +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + constexpr bool is_row_major ( + const matrix_exp& + ); + /*! + ensures + - returns true if and only if the given matrix expression uses the row_major_layout. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp diag ( + const matrix_exp& m + ); + /*! + ensures + - returns a column vector R that contains the elements from the diagonal + of m in the order R(0)==m(0,0), R(1)==m(1,1), R(2)==m(2,2) and so on. + !*/ + + template + struct diag_exp + { + /*! + WHAT THIS OBJECT REPRESENTS + This struct allows you to determine the type of matrix expression + object returned from the diag() function. An example makes its + use clear: + + template + void do_something( const matrix_exp& mat) + { + // d is a matrix expression that aliases mat. + typename diag_exp::type d = diag(mat); + + // Print the diagonal of mat. So we see that by using + // diag_exp we can save the object returned by diag() in + // a local variable. + cout << d << endl; + + // Note that you can only save the return value of diag() to + // a local variable if the argument to diag() has a lifetime + // beyond the diag() expression. The example shown above is + // OK but the following would result in undefined behavior: + typename diag_exp::type bad = diag(mat + mat); + } + !*/ + typedef type_of_expression_returned_by_diag type; + }; + +// ---------------------------------------------------------------------------------------- + + const matrix_exp diagm ( + const matrix_exp& m + ); + /*! + requires + - is_vector(m) == true + (i.e. m is a row or column matrix) + ensures + - returns a square matrix M such that: + - diag(M) == m + - non diagonal elements of M are 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp trans ( + const matrix_exp& m + ); + /*! + ensures + - returns the transpose of the matrix m + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_type::type dot ( + const matrix_exp& m1, + const matrix_exp& m2 + ); + /*! + requires + - is_vector(m1) == true + - is_vector(m2) == true + - m1.size() == m2.size() + - m1.size() > 0 + ensures + - returns the dot product between m1 and m2. That is, this function + computes and returns the sum, for all i, of m1(i)*m2(i). + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp lowerm ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix M such that: + - M::type == the same type that was in m + - M has the same dimensions as m + - M is the lower triangular part of m. That is: + - if (r >= c) then + - M(r,c) == m(r,c) + - else + - M(r,c) == 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp lowerm ( + const matrix_exp& m, + const matrix_exp::type scalar_value + ); + /*! + ensures + - returns a matrix M such that: + - M::type == the same type that was in m + - M has the same dimensions as m + - M is the lower triangular part of m except that the diagonal has + been set to scalar_value. That is: + - if (r > c) then + - M(r,c) == m(r,c) + - else if (r == c) then + - M(r,c) == scalar_value + - else + - M(r,c) == 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp upperm ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix M such that: + - M::type == the same type that was in m + - M has the same dimensions as m + - M is the upper triangular part of m. That is: + - if (r <= c) then + - M(r,c) == m(r,c) + - else + - M(r,c) == 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp upperm ( + const matrix_exp& m, + const matrix_exp::type scalar_value + ); + /*! + ensures + - returns a matrix M such that: + - M::type == the same type that was in m + - M has the same dimensions as m + - M is the upper triangular part of m except that the diagonal has + been set to scalar_value. That is: + - if (r < c) then + - M(r,c) == m(r,c) + - else if (r == c) then + - M(r,c) == scalar_value + - else + - M(r,c) == 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp make_symmetric ( + const matrix_exp& m + ); + /*! + requires + - m.nr() == m.nc() + (i.e. m must be a square matrix) + ensures + - returns a matrix M such that: + - M::type == the same type that was in m + - M has the same dimensions as m + - M is a symmetric matrix, that is, M == trans(M) and + it is constructed from the lower triangular part of m. Specifically, + we have: + - lowerm(M) == lowerm(m) + - upperm(M) == trans(lowerm(m)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + T val + > + const matrix_exp uniform_matrix ( + ); + /*! + requires + - NR > 0 && NC > 0 + ensures + - returns an NR by NC matrix with elements of type T and all set to val. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC + > + const matrix_exp uniform_matrix ( + const T& val + ); + /*! + requires + - NR > 0 && NC > 0 + ensures + - returns an NR by NC matrix with elements of type T and all set to val. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp uniform_matrix ( + long nr, + long nc, + const T& val + ); + /*! + requires + - nr >= 0 && nc >= 0 + ensures + - returns an nr by nc matrix with elements of type T and all set to val. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp ones_matrix ( + const matrix_exp& mat + ); + /*! + requires + - mat.nr() >= 0 && mat.nc() >= 0 + ensures + - Let T denote the type of element in mat. Then this function + returns uniform_matrix(mat.nr(), mat.nc(), 1) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp ones_matrix ( + long nr, + long nc + ); + /*! + requires + - nr >= 0 && nc >= 0 + ensures + - returns uniform_matrix(nr, nc, 1) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp zeros_matrix ( + const matrix_exp& mat + ); + /*! + requires + - mat.nr() >= 0 && mat.nc() >= 0 + ensures + - Let T denote the type of element in mat. Then this function + returns uniform_matrix(mat.nr(), mat.nc(), 0) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp zeros_matrix ( + long nr, + long nc + ); + /*! + requires + - nr >= 0 && nc >= 0 + ensures + - returns uniform_matrix(nr, nc, 0) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp identity_matrix ( + const matrix_exp& mat + ); + /*! + requires + - mat.nr() == mat.nc() + ensures + - returns an identity matrix with the same dimensions as mat and + containing the same type of elements as mat. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp identity_matrix ( + long N + ); + /*! + requires + - N > 0 + ensures + - returns an N by N identity matrix with elements of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long N + > + const matrix_exp identity_matrix ( + ); + /*! + requires + - N > 0 + ensures + - returns an N by N identity matrix with elements of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp linspace ( + double start, + double end, + long num + ); + /*! + requires + - num >= 0 + ensures + - returns a matrix M such that: + - M::type == double + - is_row_vector(M) == true + - M.size() == num + - M == a row vector with num linearly spaced values beginning with start + and stopping with end. + - M(num-1) == end + - if (num > 1) then + - M(0) == start + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp logspace ( + double start, + double end, + long num + ); + /*! + requires + - num >= 0 + ensures + - returns a matrix M such that: + - M::type == double + - is_row_vector(M) == true + - M.size() == num + - M == a row vector with num logarithmically spaced values beginning with + 10^start and stopping with 10^end. + (i.e. M == pow(10, linspace(start, end, num))) + - M(num-1) == 10^end + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp linpiece ( + const double val, + const matrix_exp& joints + ); + /*! + requires + - is_vector(joints) == true + - joints.size() >= 2 + - for all valid i < j: + - joints(i) < joints(j) + ensures + - linpiece() is useful for creating piecewise linear functions of val. For + example, if w is a parameter vector then you can represent a piecewise linear + function of val as: f(val) = dot(w, linpiece(val, linspace(0,100,5))). In + this case, f(val) is piecewise linear on the intervals [0,25], [25,50], + [50,75], [75,100]. Moreover, w(i) defines the derivative of f(val) in the + i-th interval. Finally, outside the interval [0,100] f(val) has a derivative + of zero and f(0) == 0. + - To be precise, this function returns a column vector L such that: + - L.size() == joints.size()-1 + - is_col_vector(L) == true + - L contains the same type of elements as joints. + - for all valid i: + - if (joints(i) < val) + - L(i) == min(val,joints(i+1)) - joints(i) + - else + - L(i) == 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + long R, + long C + > + const matrix_exp rotate ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + R( (r+R)%m.nr() , (c+C)%m.nc() ) == m(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp fliplr ( + const matrix_exp& m + ); + /*! + ensures + - flips the matrix m from left to right and returns the result. + I.e. reverses the order of the columns. + - returns a matrix M such that: + - M::type == the same type that was in m + - M has the same dimensions as m + - for all valid r and c: + M(r,c) == m(r, m.nc()-c-1) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp flipud ( + const matrix_exp& m + ); + /*! + ensures + - flips the matrix m from up to down and returns the result. + I.e. reverses the order of the rows. + - returns a matrix M such that: + - M::type == the same type that was in m + - M has the same dimensions as m + - for all valid r and c: + M(r,c) == m(m.nr()-r-1, c) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp flip ( + const matrix_exp& m + ); + /*! + ensures + - flips the matrix m from up to down and left to right and returns the + result. I.e. returns flipud(fliplr(m)). + - returns a matrix M such that: + - M::type == the same type that was in m + - M has the same dimensions as m + - for all valid r and c: + M(r,c) == m(m.nr()-r-1, m.nc()-c-1) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp reshape ( + const matrix_exp& m, + long rows, + long cols + ); + /*! + requires + - m.size() == rows*cols + - rows > 0 + - cols > 0 + ensures + - returns a matrix M such that: + - M.nr() == rows + - M.nc() == cols + - M.size() == m.size() + - for all valid r and c: + - let IDX = r*cols + c + - M(r,c) == m(IDX/m.nc(), IDX%m.nc()) + + - i.e. The matrix m is reshaped into a new matrix of rows by cols + dimension. Additionally, the elements of m are laid into M in row major + order. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp reshape_to_column_vector ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix M such that: + - is_col_vector(M) == true + - M.size() == m.size() + - for all valid r and c: + - m(r,c) == M(r*m.nc() + c) + + - i.e. The matrix m is reshaped into a column vector. Note that + the elements are pulled out in row major order. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + long R, + long C + > + const matrix_exp removerc ( + const matrix_exp& m + ); + /*! + requires + - m.nr() > R >= 0 + - m.nc() > C >= 0 + ensures + - returns a matrix M such that: + - M.nr() == m.nr() - 1 + - M.nc() == m.nc() - 1 + - M == m with its R row and C column removed + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp removerc ( + const matrix_exp& m, + long R, + long C + ); + /*! + requires + - m.nr() > R >= 0 + - m.nc() > C >= 0 + ensures + - returns a matrix M such that: + - M.nr() == m.nr() - 1 + - M.nc() == m.nc() - 1 + - M == m with its R row and C column removed + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + long R + > + const matrix_exp remove_row ( + const matrix_exp& m + ); + /*! + requires + - m.nr() > R >= 0 + ensures + - returns a matrix M such that: + - M.nr() == m.nr() - 1 + - M.nc() == m.nc() + - M == m with its R row removed + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp remove_row ( + const matrix_exp& m, + long R + ); + /*! + requires + - m.nr() > R >= 0 + ensures + - returns a matrix M such that: + - M.nr() == m.nr() - 1 + - M.nc() == m.nc() + - M == m with its R row removed + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + long C + > + const matrix_exp remove_col ( + const matrix_exp& m + ); + /*! + requires + - m.nc() > C >= 0 + ensures + - returns a matrix M such that: + - M.nr() == m.nr() + - M.nc() == m.nc() - 1 + - M == m with its C column removed + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp remove_col ( + const matrix_exp& m, + long C + ); + /*! + requires + - m.nc() > C >= 0 + ensures + - returns a matrix M such that: + - M.nr() == m.nr() + - M.nc() == m.nc() - 1 + - M == m with its C column removed + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename target_type + > + const matrix_exp matrix_cast ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R where for all valid r and c: + R(r,c) == static_cast(m(r,c)) + also, R has the same dimensions as m. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename MM, + typename U, + typename L + > + void set_all_elements ( + matrix& m, + U value + ); + /*! + ensures + - for all valid r and c: + m(r,c) == value + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::matrix_type tmp ( + const matrix_exp& m + ); + /*! + ensures + - returns a temporary matrix object that is a copy of m. + (This allows you to easily force a matrix_exp to fully evaluate) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + long NR, + long NC, + typename MM, + typename L + > + uint32 hash ( + const matrix& item, + uint32 seed = 0 + ); + /*! + requires + - T is a standard layout type (e.g. a POD type like int, float, + or a simple struct). + ensures + - returns a 32bit hash of the data stored in item. + - Each value of seed results in a different hash function being used. + (e.g. hash(item,0) should generally not be equal to hash(item,1)) + - uses the murmur_hash3() routine to compute the actual hash. + - Note that if the memory layout of the elements in item change between + hardware platforms then hash() will give different outputs. If you want + hash() to always give the same output for the same input then you must + ensure that elements of item always have the same layout in memory. + Typically this means using fixed width types and performing byte swapping + to account for endianness before passing item to hash(). + !*/ + +// ---------------------------------------------------------------------------------------- + + // if matrix_exp contains non-complex types (e.g. float, double) + bool equal ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp::type epsilon = 100*std::numeric_limits::epsilon() + ); + /*! + ensures + - if (a and b don't have the same dimensions) then + - returns false + - else if (there exists an r and c such that abs(a(r,c)-b(r,c)) > epsilon) then + - returns false + - else + - returns true + !*/ + +// ---------------------------------------------------------------------------------------- + + // if matrix_exp contains std::complex types + bool equal ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp::type::value_type epsilon = 100*std::numeric_limits::epsilon() + ); + /*! + ensures + - if (a and b don't have the same dimensions) then + - returns false + - else if (there exists an r and c such that abs(real(a(r,c)-b(r,c))) > epsilon + or abs(imag(a(r,c)-b(r,c))) > epsilon) then + - returns false + - else + - returns true + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp pointwise_multiply ( + const matrix_exp& a, + const matrix_exp& b + ); + /*! + requires + - a.nr() == b.nr() + - a.nc() == b.nc() + - a and b both contain the same type of element (one or both + can also be of type std::complex so long as the underlying type + in them is the same) + ensures + - returns a matrix R such that: + - R::type == the same type that was in a and b. + - R has the same dimensions as a and b. + - for all valid r and c: + R(r,c) == a(r,c) * b(r,c) + !*/ + + const matrix_exp pointwise_multiply ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp& c + ); + /*! + performs pointwise_multiply(a,pointwise_multiply(b,c)); + !*/ + + const matrix_exp pointwise_multiply ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp& c, + const matrix_exp& d + ); + /*! + performs pointwise_multiply(pointwise_multiply(a,b),pointwise_multiply(c,d)); + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp join_rows ( + const matrix_exp& a, + const matrix_exp& b + ); + /*! + requires + - a.nr() == b.nr() || a.size() == 0 || b.size() == 0 + - a and b both contain the same type of element + ensures + - This function joins two matrices together by concatenating their rows. + - returns a matrix R such that: + - R::type == the same type that was in a and b. + - R.nr() == a.nr() == b.nr() + - R.nc() == a.nc() + b.nc() + - for all valid r and c: + - if (c < a.nc()) then + - R(r,c) == a(r,c) + - else + - R(r,c) == b(r, c-a.nc()) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp join_cols ( + const matrix_exp& a, + const matrix_exp& b + ); + /*! + requires + - a.nc() == b.nc() || a.size() == 0 || b.size() == 0 + - a and b both contain the same type of element + ensures + - This function joins two matrices together by concatenating their columns. + - returns a matrix R such that: + - R::type == the same type that was in a and b. + - R.nr() == a.nr() + b.nr() + - R.nc() == a.nc() == b.nc() + - for all valid r and c: + - if (r < a.nr()) then + - R(r,c) == a(r,c) + - else + - R(r,c) == b(r-a.nr(), c) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp tensor_product ( + const matrix_exp& a, + const matrix_exp& b + ); + /*! + requires + - a and b both contain the same type of element + ensures + - returns a matrix R such that: + - R::type == the same type that was in a and b. + - R.nr() == a.nr() * b.nr() + - R.nc() == a.nc() * b.nc() + - for all valid r and c: + R(r,c) == a(r/b.nr(), c/b.nc()) * b(r%b.nr(), c%b.nc()) + - I.e. R is the tensor product of matrix a with matrix b + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp cartesian_product ( + const matrix_exp& A, + const matrix_exp& B + ); + /*! + requires + - A and B both contain the same type of element + ensures + - Think of A and B as sets of column vectors. Then this function + returns a matrix that contains a set of column vectors that is + the Cartesian product of the sets A and B. That is, the resulting + matrix contains every possible combination of vectors from both A and + B. + - returns a matrix R such that: + - R::type == the same type that was in A and B. + - R.nr() == A.nr() + B.nr() + - R.nc() == A.nc() * B.nc() + - Each column of R is the concatenation of a column vector + from A with a column vector from B. + - for all valid r and c: + - if (r < A.nr()) then + - R(r,c) == A(r, c/B.nc()) + - else + - R(r,c) == B(r-A.nr(), c%B.nc()) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp scale_columns ( + const matrix_exp& m, + const matrix_exp& v + ); + /*! + requires + - is_vector(v) == true + - v.size() == m.nc() + - m and v both contain the same type of element + ensures + - returns a matrix R such that: + - R::type == the same type that was in m and v. + - R has the same dimensions as m. + - for all valid r and c: + R(r,c) == m(r,c) * v(c) + - i.e. R is the result of multiplying each of m's columns by + the corresponding scalar in v. + + - Note that this function is identical to the expression m*diagm(v). + That is, the * operator is overloaded for this case and will invoke + scale_columns() automatically as appropriate. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp scale_rows ( + const matrix_exp& m, + const matrix_exp& v + ); + /*! + requires + - is_vector(v) == true + - v.size() == m.nr() + - m and v both contain the same type of element + ensures + - returns a matrix R such that: + - R::type == the same type that was in m and v. + - R has the same dimensions as m. + - for all valid r and c: + R(r,c) == m(r,c) * v(r) + - i.e. R is the result of multiplying each of m's rows by + the corresponding scalar in v. + + - Note that this function is identical to the expression diagm(v)*m. + That is, the * operator is overloaded for this case and will invoke + scale_rows() automatically as appropriate. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void sort_columns ( + matrix& m, + matrix& v + ); + /*! + requires + - is_col_vector(v) == true + - v.size() == m.nc() + - m and v both contain the same type of element + ensures + - the dimensions for m and v are not changed + - sorts the columns of m according to the values in v. + i.e. + - #v == the contents of v but in sorted order according to + operator<. So smaller elements come first. + - Let #v(new(i)) == v(i) (i.e. new(i) is the index element i moved to) + - colm(#m,new(i)) == colm(m,i) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void rsort_columns ( + matrix& m, + matrix& v + ); + /*! + requires + - is_col_vector(v) == true + - v.size() == m.nc() + - m and v both contain the same type of element + ensures + - the dimensions for m and v are not changed + - sorts the columns of m according to the values in v. + i.e. + - #v == the contents of v but in sorted order according to + operator>. So larger elements come first. + - Let #v(new(i)) == v(i) (i.e. new(i) is the index element i moved to) + - colm(#m,new(i)) == colm(m,i) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type length_squared ( + const matrix_exp& m + ); + /*! + requires + - is_vector(m) == true + ensures + - returns sum(squared(m)) + (i.e. returns the square of the length of the vector m) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type length ( + const matrix_exp& m + ); + /*! + requires + - is_vector(m) == true + ensures + - returns sqrt(sum(squared(m))) + (i.e. returns the length of the vector m) + - if (m contains integer valued elements) then + - The return type is a double that represents the length. Therefore, the + return value of length() is always represented using a floating point + type. + !*/ + +// ---------------------------------------------------------------------------------------- + + bool is_row_vector ( + const matrix_exp& m + ); + /*! + ensures + - if (m.nr() == 1) then + - return true + - else + - returns false + !*/ + + bool is_col_vector ( + const matrix_exp& m + ); + /*! + ensures + - if (m.nc() == 1) then + - return true + - else + - returns false + !*/ + + bool is_vector ( + const matrix_exp& m + ); + /*! + ensures + - if (is_row_vector(m) || is_col_vector(m)) then + - return true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + bool is_finite ( + const matrix_exp& m + ); + /*! + ensures + - returns true if all the values in m are finite values and also not any kind + of NaN value. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Thresholding relational operators +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator< ( + const matrix_exp& m, + const S& s + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (m(r,c) < s) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator< ( + const S& s, + const matrix_exp& m + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (s < m(r,c)) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator<= ( + const matrix_exp& m, + const S& s + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (m(r,c) <= s) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator<= ( + const S& s, + const matrix_exp& m + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (s <= m(r,c)) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator> ( + const matrix_exp& m, + const S& s + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (m(r,c) > s) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator> ( + const S& s, + const matrix_exp& m + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (s > m(r,c)) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator>= ( + const matrix_exp& m, + const S& s + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (m(r,c) >= s) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator>= ( + const S& s, + const matrix_exp& m + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (s >= m(r,c)) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator== ( + const matrix_exp& m, + const S& s + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (m(r,c) == s) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator== ( + const S& s, + const matrix_exp& m + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (s == m(r,c)) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator!= ( + const matrix_exp& m, + const S& s + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (m(r,c) != s) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix_exp operator!= ( + const S& s, + const matrix_exp& m + ); + /*! + requires + - is_built_in_scalar_type::value == true + - is_built_in_scalar_type::value == true + ensures + - returns a matrix R such that: + - R::type == the same type that was in m. + - R has the same dimensions as m. + - for all valid r and c: + - if (s != m(r,c)) then + - R(r,c) == 1 + - else + - R(r,c) == 0 + - i.e. R is a binary matrix of all 1s or 0s. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Statistics +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type min ( + const matrix_exp& m + ); + /*! + requires + - m.size() > 0 + ensures + - returns the value of the smallest element of m. If m contains complex + elements then the element returned is the one with the smallest norm + according to std::norm(). + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp min_pointwise ( + const matrix_exp& a, + const matrix_exp& b + ); + /*! + requires + - a.nr() == b.nr() + - a.nc() == b.nc() + - a and b both contain the same type of element + ensures + - returns a matrix R such that: + - R::type == the same type that was in a and b. + - R has the same dimensions as a and b. + - for all valid r and c: + R(r,c) == std::min(a(r,c), b(r,c)) + !*/ + + const matrix_exp min_pointwise ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp& c + ); + /*! + performs min_pointwise(a,min_pointwise(b,c)); + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type max ( + const matrix_exp& m + ); + /*! + requires + - m.size() > 0 + ensures + - returns the value of the biggest element of m. If m contains complex + elements then the element returned is the one with the largest norm + according to std::norm(). + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp max_pointwise ( + const matrix_exp& a, + const matrix_exp& b + ); + /*! + requires + - a.nr() == b.nr() + - a.nc() == b.nc() + - a and b both contain the same type of element + ensures + - returns a matrix R such that: + - R::type == the same type that was in a and b. + - R has the same dimensions as a and b. + - for all valid r and c: + R(r,c) == std::max(a(r,c), b(r,c)) + !*/ + + const matrix_exp max_pointwise ( + const matrix_exp& a, + const matrix_exp& b, + const matrix_exp& c + ); + /*! + performs max_pointwise(a,max_pointwise(b,c)); + !*/ + +// ---------------------------------------------------------------------------------------- + + void find_min_and_max ( + const matrix_exp& m, + matrix_exp::type& min_val, + matrix_exp::type& max_val + ); + /*! + requires + - m.size() > 0 + ensures + - #min_val == min(m) + - #max_val == max(m) + - This function computes both the min and max in just one pass + over the elements of the matrix m. + !*/ + +// ---------------------------------------------------------------------------------------- + + long index_of_max ( + const matrix_exp& m + ); + /*! + requires + - is_vector(m) == true + - m.size() > 0 + ensures + - returns the index of the largest element in m. + (i.e. m(index_of_max(m)) == max(m)) + !*/ + +// ---------------------------------------------------------------------------------------- + + long index_of_min ( + const matrix_exp& m + ); + /*! + requires + - is_vector(m) == true + - m.size() > 0 + ensures + - returns the index of the smallest element in m. + (i.e. m(index_of_min(m)) == min(m)) + !*/ + +// ---------------------------------------------------------------------------------------- + + point max_point ( + const matrix_exp& m + ); + /*! + requires + - m.size() > 0 + ensures + - returns the location of the maximum element of the array, that is, if the + returned point is P then it will be the case that: m(P.y(),P.x()) == max(m). + !*/ + +// ---------------------------------------------------------------------------------------- + + dlib::vector max_point_interpolated ( + const matrix_exp& m + ); + /*! + requires + - m.size() > 0 + ensures + - Like max_point(), this function finds the location in m with the largest + value. However, we additionally use some quadratic interpolation to find the + location of the maximum point with sub-pixel accuracy. Therefore, the + returned point is equal to max_point(m) + some small sub-pixel delta. + !*/ + +// ---------------------------------------------------------------------------------------- + + point min_point ( + const matrix_exp& m + ); + /*! + requires + - m.size() > 0 + ensures + - returns the location of the minimum element of the array, that is, if the + returned point is P then it will be the case that: m(P.y(),P.x()) == min(m). + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type sum ( + const matrix_exp& m + ); + /*! + ensures + - returns the sum of all elements in m + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp sum_rows ( + const matrix_exp& m + ); + /*! + requires + - m.size() > 0 + ensures + - returns a row matrix that contains the sum of all the rows in m. + - returns a matrix M such that + - M::type == the same type that was in m + - M.nr() == 1 + - M.nc() == m.nc() + - for all valid i: + - M(i) == sum(colm(m,i)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp sum_cols ( + const matrix_exp& m + ); + /*! + requires + - m.size() > 0 + ensures + - returns a column matrix that contains the sum of all the columns in m. + - returns a matrix M such that + - M::type == the same type that was in m + - M.nr() == m.nr() + - M.nc() == 1 + - for all valid i: + - M(i) == sum(rowm(m,i)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type prod ( + const matrix_exp& m + ); + /*! + ensures + - returns the results of multiplying all elements of m together. + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type mean ( + const matrix_exp& m + ); + /*! + ensures + - returns the mean of all elements in m. + (i.e. returns sum(m)/(m.nr()*m.nc())) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type variance ( + const matrix_exp& m + ); + /*! + ensures + - returns the unbiased sample variance of all elements in m + (i.e. 1.0/(m.nr()*m.nc() - 1)*(sum of all pow(m(i,j) - mean(m),2))) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp::type stddev ( + const matrix_exp& m + ); + /*! + ensures + - returns sqrt(variance(m)) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix covariance ( + const matrix_exp& m + ); + /*! + requires + - matrix_exp::type == a dlib::matrix object + - is_col_vector(m) == true + - m.size() > 1 + - for all valid i, j: + - is_col_vector(m(i)) == true + - m(i).size() > 0 + - m(i).size() == m(j).size() + - i.e. m contains only column vectors and all the column vectors + have the same non-zero length + ensures + - returns the unbiased sample covariance matrix for the set of samples + in m. + (i.e. 1.0/(m.nr()-1)*(sum of all (m(i) - mean(m))*trans(m(i) - mean(m)))) + - the returned matrix will contain elements of type matrix_exp::type::type. + - the returned matrix will have m(0).nr() rows and columns. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + const matrix randm( + long nr, + long nc, + rand_gen& rnd + ); + /*! + requires + - nr >= 0 + - nc >= 0 + - rand_gen == an object that implements the rand/rand_float_abstract.h interface + ensures + - generates a random matrix using the given rnd random number generator + - returns a matrix M such that + - M::type == double + - M.nr() == nr + - M.nc() == nc + - for all valid i, j: + - M(i,j) == a random number such that 0 <= M(i,j) < 1 + !*/ + +// ---------------------------------------------------------------------------------------- + + inline const matrix randm( + long nr, + long nc + ); + /*! + requires + - nr >= 0 + - nc >= 0 + ensures + - generates a random matrix using std::rand() + - returns a matrix M such that + - M::type == double + - M.nr() == nr + - M.nc() == nc + - for all valid i, j: + - M(i,j) == a random number such that 0 <= M(i,j) < 1 + !*/ + +// ---------------------------------------------------------------------------------------- + + inline const matrix_exp gaussian_randm ( + long nr, + long nc, + unsigned long seed = 0 + ); + /*! + requires + - nr >= 0 + - nc >= 0 + ensures + - returns a matrix with its values filled with 0 mean unit variance Gaussian + random numbers. + - Each setting of the seed results in a different random matrix. + - The returned matrix is lazily evaluated using the expression templates + technique. This means that the returned matrix doesn't take up any memory + and is only an expression template. The values themselves are computed on + demand using the gaussian_random_hash() routine. + - returns a matrix M such that + - M::type == double + - M.nr() == nr + - M.nc() == nc + - for all valid i, j: + - M(i,j) == gaussian_random_hash(i,j,seed) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Pixel and Image Utilities +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename P + > + const matrix::num,1> pixel_to_vector ( + const P& pixel + ); + /*! + requires + - pixel_traits

    must be defined + ensures + - returns a matrix M such that: + - M::type == T + - M::NC == 1 + - M::NR == pixel_traits

    ::num + - if (pixel_traits

    ::grayscale) then + - M(0) == pixel + - if (pixel_traits

    ::rgb) then + - M(0) == pixel.red + - M(1) == pixel.green + - M(2) == pixel.blue + - if (pixel_traits

    ::hsi) then + - M(0) == pixel.h + - M(1) == pixel.s + - M(2) == pixel.i + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename P + > + void vector_to_pixel ( + P& pixel, + const matrix_exp& vector + ); + /*! + requires + - vector::NR == pixel_traits

    ::num + - vector::NC == 1 + (i.e. you have to use a statically dimensioned vector) + ensures + - if (pixel_traits

    ::grayscale) then + - pixel == M(0) + - if (pixel_traits

    ::rgb) then + - pixel.red == M(0) + - pixel.green == M(1) + - pixel.blue == M(2) + - if (pixel_traits

    ::hsi) then + - pixel.h == M(0) + - pixel.s == M(1) + - pixel.i == M(2) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + long lower, + long upper + > + const matrix_exp clamp ( + const matrix_exp& m + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + - if (m(r,c) > upper) then + - R(r,c) == upper + - else if (m(r,c) < lower) then + - R(r,c) == lower + - else + - R(r,c) == m(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp clamp ( + const matrix_exp& m, + const matrix_exp::type& lower, + const matrix_exp::type& upper + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + - if (m(r,c) > upper) then + - R(r,c) == upper + - else if (m(r,c) < lower) then + - R(r,c) == lower + - else + - R(r,c) == m(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp clamp ( + const matrix_exp& m, + const matrix_exp& lower, + const matrix_exp& upper + ); + /*! + requires + - m.nr() == lower.nr() + - m.nc() == lower.nc() + - m.nr() == upper.nr() + - m.nc() == upper.nc() + - m, lower, and upper all contain the same type of elements. + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + - if (m(r,c) > upper(r,c)) then + - R(r,c) == upper(r,c) + - else if (m(r,c) < lower(r,c)) then + - R(r,c) == lower(r,c) + - else + - R(r,c) == m(r,c) + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp lowerbound ( + const matrix_exp& m, + const matrix_exp::type& thresh + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + - if (m(r,c) >= thresh) then + - R(r,c) == m(r,c) + - else + - R(r,c) == thresh + !*/ + +// ---------------------------------------------------------------------------------------- + + const matrix_exp upperbound ( + const matrix_exp& m, + const matrix_exp::type& thresh + ); + /*! + ensures + - returns a matrix R such that: + - R::type == the same type that was in m + - R has the same dimensions as m + - for all valid r and c: + - if (m(r,c) <= thresh) then + - R(r,c) == m(r,c) + - else + - R(r,c) == thresh + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_UTILITIES_ABSTRACT_ + diff --git a/ml/dlib/dlib/matrix/symmetric_matrix_cache.h b/ml/dlib/dlib/matrix/symmetric_matrix_cache.h new file mode 100644 index 000000000..bff268aef --- /dev/null +++ b/ml/dlib/dlib/matrix/symmetric_matrix_cache.h @@ -0,0 +1,464 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SYMMETRIC_MATRIX_CAcHE_Hh_ +#define DLIB_SYMMETRIC_MATRIX_CAcHE_Hh_ + +#include "symmetric_matrix_cache_abstract.h" +#include +#include "../matrix.h" +#include "../algs.h" +#include "../array.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + struct op_symm_cache : basic_op_m + { + inline op_symm_cache( + const M& m_, + long max_size_megabytes_ + ) : + basic_op_m(m_), + max_size_megabytes(max_size_megabytes_), + is_initialized(false) + { + lookup.assign(this->m.nr(), -1); + + diag_cache = matrix_cast(dlib::diag(m_)); + } + + op_symm_cache ( + const op_symm_cache& item + ) : + basic_op_m(item.m), + diag_cache(item.diag_cache), + max_size_megabytes(item.max_size_megabytes), + is_initialized(false) + { + lookup.assign(this->m.nr(), -1); + } + + typedef cache_element_type type; + typedef const cache_element_type& const_ret_type; + const static long cost = M::cost + 3; + + inline const_ret_type apply ( long r, long c) const + { + if (lookup[c] != -1) + { + return cache[lookup[c]](r); + } + else if (r == c) + { + return diag_cache(r); + } + else if (lookup[r] != -1) + { + // the matrix is symmetric so this is legit + return cache[lookup[r]](c); + } + else + { + add_col_to_cache(c); + return cache[lookup[c]](r); + } + } + + inline std::pair col(long i) const + /*! + requires + - 0 <= i < nc() + ensures + - returns a pair P such that: + - P.first == a pointer to the first element of the ith column + - P.second == a pointer to the integer used to count the number of + outstanding references to the ith column. + !*/ + { + if (is_cached(i) == false) + add_col_to_cache(i); + + // find where this column is in the cache + long idx = lookup[i]; + if (idx == next) + { + // if this column was the next to be replaced + // then make sure that doesn't happen + next = (next + 1)%cache.size(); + } + + return std::make_pair(&cache[idx](0), &references[idx]); + } + + const type* diag() const { init(); return &diag_cache(0); } + + long* diag_ref_count() const + { + return &diag_reference_count; + } + + private: + inline bool is_cached ( + long r + ) const + { + return (lookup[r] != -1); + } + + inline void init() const + { + if (is_initialized == false) + { + // figure out how many columns of the matrix we can have + // with the given amount of memory. + long max_size = (max_size_megabytes*1024*1024)/(this->m.nr()*sizeof(type)); + // don't let it be 0 or 1 + if (max_size <= 1) + max_size = 2; + + const long size = std::min(max_size,this->m.nr()); + + diag_reference_count = 0; + + references.set_max_size(this->m.nr()); + references.set_size(size); + for (unsigned long i = 0; i < references.size(); ++i) + references[i] = 0; + + cache.set_max_size(this->m.nr()); + cache.set_size(size); + + rlookup.assign(size,-1); + next = 0; + + is_initialized = true; + } + } + + void make_sure_next_is_unreferenced ( + ) const + { + if (references[next] != 0) + { + // find an unreferenced element of the cache + unsigned long i; + for (i = 1; i < references.size(); ++i) + { + const unsigned long idx = (next+i)%references.size(); + if (references[idx] == 0) + { + next = idx; + break; + } + } + + // if all elements of the cache are referenced then make the cache bigger + // and use the new element. + if (references[next] != 0) + { + cache.resize(cache.size()+1); + + next = references.size(); + references.resize(references.size()+1); + references[next] = 0; + + rlookup.push_back(-1); + } + } + } + + inline void add_col_to_cache( + long c + ) const + { + init(); + make_sure_next_is_unreferenced(); + + // if the lookup table is pointing to cache[next] then clear lookup[next] + if (rlookup[next] != -1) + lookup[rlookup[next]] = -1; + + // make the lookup table so that it says c is now cached at the spot indicated by next + lookup[c] = next; + rlookup[next] = c; + + // compute this column in the matrix and store it in the cache + cache[next] = matrix_cast(colm(this->m,c)); + + next = (next + 1)%cache.size(); + } + + /*! + INITIAL VALUE + - for all valid x: + - lookup(x) == -1 + + - diag_cache == the diagonal of the original matrix + - is_initialized == false + - max_size_megabytes == the max_size_megabytes from symmetric_matrix_cache() + + CONVENTION + - diag_cache == the diagonal of the original matrix + - lookup.size() == diag_cache.size() + + - if (is_initialized) then + - if (lookup[c] != -1) then + - cache[lookup[c]] == the cached column c of the matrix + - rlookup[lookup[c]] == c + + - if (rlookup[x] != -1) then + - lookup[rlookup[x]] == x + - cache[x] == the cached column rlookup[x] of the matrix + + - next == the next element in the cache table to use to cache something + - references[i] == the number of outstanding references to cache element cache[i] + + - diag_reference_count == the number of outstanding references to diag_cache. + (this isn't really needed. It's just here so that we can reuse the matrix + expression from colm() to implement diag()) + !*/ + + + mutable array > cache; + mutable array references; + matrix diag_cache; + mutable std::vector lookup; + mutable std::vector rlookup; + mutable long next; + + const long max_size_megabytes; + mutable bool is_initialized; + mutable long diag_reference_count; + + }; + + template < + typename cache_element_type, + typename EXP + > + const matrix_op > symmetric_matrix_cache ( + const matrix_exp& m, + long max_size_megabytes + ) + { + // Don't check that m is symmetric since doing so would be extremely onerous for the + // kinds of matrices intended for use with the symmetric_matrix_cache. Check everything + // else though. + DLIB_ASSERT(m.size() > 0 && m.nr() == m.nc() && max_size_megabytes >= 0, + "\tconst matrix_exp symmetric_matrix_cache(const matrix_exp& m, max_size_megabytes)" + << "\n\t You have given invalid arguments to this function" + << "\n\t m.nr(): " << m.nr() + << "\n\t m.nc(): " << m.nc() + << "\n\t m.size(): " << m.size() + << "\n\t max_size_megabytes: " << max_size_megabytes + ); + + typedef op_symm_cache op; + return matrix_op(op(m.ref(), max_size_megabytes)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_colm_symm_cache + { + typedef cache_element_type type; + + op_colm_symm_cache( + const M& m_, + const type* data_, + long* ref_count_ + ) : + m(m_), + data(data_), + ref_count(ref_count_) + { + *ref_count += 1; + } + + op_colm_symm_cache ( + const op_colm_symm_cache& item + ) : + m(item.m), + data(item.data), + ref_count(item.ref_count) + { + *ref_count += 1; + } + + ~op_colm_symm_cache( + ) + { + *ref_count -= 1; + } + + const M& m; + + const type* const data; + long* const ref_count; + + const static long cost = M::cost; + const static long NR = M::NR; + const static long NC = 1; + typedef const type& const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + inline const_ret_type apply ( long r, long) const { return data[r]; } + + long nr () const { return m.nr(); } + long nc () const { return 1; } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP, + typename cache_element_type + > + inline const matrix_op > colm ( + const matrix_exp > >& m, + long col + ) + { + DLIB_ASSERT(col >= 0 && col < m.nc(), + "\tconst matrix_exp colm(const matrix_exp& m, row)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\tcol: " << col + ); + + std::pair p = m.ref().op.col(col); + + typedef op_colm_symm_cache op; + return matrix_op(op(m.ref().op.m, + p.first, + p.second)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename cache_element_type + > + inline const matrix_op > diag ( + const matrix_exp > >& m + ) + { + typedef op_colm_symm_cache op; + return matrix_op(op(m.ref().op.m, + m.ref().op.diag(), + m.ref().op.diag_ref_count())); + } + +// ---------------------------------------------------------------------------------------- + + template + struct op_rowm_symm_cache + { + typedef cache_element_type type; + + op_rowm_symm_cache( + const M& m_, + const type* data_, + long* ref_count_ + ) : + m(m_), + data(data_), + ref_count(ref_count_) + { + *ref_count += 1; + } + + op_rowm_symm_cache ( + const op_rowm_symm_cache& item + ) : + m(item.m), + data(item.data), + ref_count(item.ref_count) + { + *ref_count += 1; + } + + ~op_rowm_symm_cache( + ) + { + *ref_count -= 1; + } + + const M& m; + + const type* const data; + long* const ref_count; + + const static long cost = M::cost; + const static long NR = 1; + const static long NC = M::NC; + typedef const type& const_ret_type; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + inline const_ret_type apply ( long , long c) const { return data[c]; } + + long nr () const { return 1; } + long nc () const { return m.nc(); } + + template bool aliases ( const matrix_exp& item) const { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const { return m.aliases(item); } + }; + + template < + typename EXP, + typename cache_element_type + > + inline const matrix_op > rowm ( + const matrix_exp > >& m, + long row + ) + { + DLIB_ASSERT(row >= 0 && row < m.nr(), + "\tconst matrix_exp rowm(const matrix_exp& m, row)" + << "\n\tYou have specified invalid sub matrix dimensions" + << "\n\tm.nr(): " << m.nr() + << "\n\tm.nc(): " << m.nc() + << "\n\trow: " << row + ); + + std::pair p = m.ref().op.col(row); + + typedef op_rowm_symm_cache op; + return matrix_op(op(m.ref().op.m, + p.first, + p.second)); + } + +// ---------------------------------------------------------------------------------------- + + template + struct colm_exp > > + { + typedef matrix_op > type; + }; + + template + struct rowm_exp > > + { + typedef matrix_op > type; + }; + + template + struct diag_exp > > + { + typedef matrix_op > type; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SYMMETRIC_MATRIX_CAcHE_Hh_ + diff --git a/ml/dlib/dlib/matrix/symmetric_matrix_cache_abstract.h b/ml/dlib/dlib/matrix/symmetric_matrix_cache_abstract.h new file mode 100644 index 000000000..6a41ad282 --- /dev/null +++ b/ml/dlib/dlib/matrix/symmetric_matrix_cache_abstract.h @@ -0,0 +1,63 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#define DLIB_SYMMETRIC_MATRIX_CAcHE_ABSTRACT_Hh_ +#ifndef DLIB_SYMMETRIC_MATRIX_CAcHE_ABSTRACT_Hh_ + +#include "matrix_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename cache_element_type + > + const matrix_exp symmetric_matrix_cache ( + const matrix_exp& m, + long max_size_megabytes + ); + /*! + requires + - m.size() > 0 + - m.nr() == m.nc() + - max_size_megabytes >= 0 + ensures + - This function assumes that m is symmetric. If m is not symmetric then it won't + crash but you will get incorrect results. + - This method creates a matrix expression which internally caches the elements + of m so that they can be accessed quickly. It is useful if m is some kind of + complex matrix expression which is both very large and expensive to evaluate. + An example would be a kernel_matrix() expression with an expensive kernel and + a large number of samples. Such an expression would result in a huge matrix, + potentially too big to store in memory. The symmetric_matrix_cache() then makes + it easy to store just the parts of a matrix expression which are accessed most + often in memory. The specific details are defined below. + - returns a matrix M such that + - M == m + (i.e. M represents the same matrix as m) + - M will cache elements of m and hold them internally so they can be quickly + accessed. In particular, M will attempt to allocate no more than + max_size_megabytes megabytes of memory for the purposes of caching + elements of m. When an element of the matrix is accessed it is either + retrieved from the cache, or if this is not possible, then an entire + column of m is loaded into a part of the cache which hasn't been used + recently and the needed element returned. + - diag(m) is always loaded into the cache and is stored separately from + the cached columns. That means accesses to the diagonal elements of m + are always fast. + - M will store the cached elements of m as cache_element_type objects. + Typically, cache_element_type will be float or double. + - To avoid repeated cache lookups, the following operations are optimized for + use with the symmetric_matrix_cache(): + - diag(M), rowm(M,row_idx), colm(M,col_idx) + These methods will perform only one cache lookup operation for an + entire row/column/diagonal worth of data. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SYMMETRIC_MATRIX_CAcHE_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/md5.h b/ml/dlib/dlib/md5.h new file mode 100644 index 000000000..e62930366 --- /dev/null +++ b/ml/dlib/dlib/md5.h @@ -0,0 +1,3 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include "md5/md5_kernel_1.h" diff --git a/ml/dlib/dlib/md5/md5_kernel_1.cpp b/ml/dlib/dlib/md5/md5_kernel_1.cpp new file mode 100644 index 000000000..f073f9256 --- /dev/null +++ b/ml/dlib/dlib/md5/md5_kernel_1.cpp @@ -0,0 +1,617 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MD5_KERNEL_1_CPp_ +#define DLIB_MD5_KERNEL_1_CPp_ +#include "md5_kernel_1.h" +#include "../uintn.h" + +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace md5_stuff + { + + inline uint32 F ( + uint32 x, + uint32 y, + uint32 z + ) + { + return ( (x&y) | ((~x)&z) ); + } + + // ------------------------------------------------------------------------------------ + + inline uint32 G ( + uint32 x, + uint32 y, + uint32 z + ) + { + return ( (x&z) | (y&(~z)) ); + } + + // ------------------------------------------------------------------------------------ + + inline uint32 H ( + uint32 x, + uint32 y, + uint32 z + ) + { + return ( x^y^z ); + } + + // ------------------------------------------------------------------------------------ + + inline uint32 I ( + uint32 x, + uint32 y, + uint32 z + ) + { + return ( y ^ (x|(~z)) ); + } + + // ------------------------------------------------------------------------------------ + + inline uint32 rotate_left ( + uint32 x, + uint32 n + ) + { + return ( (x<>(32-n)) ); + } + + // ------------------------------------------------------------------------------------ + + inline void FF ( + uint32& a, + uint32 b, + uint32 c, + uint32 d, + uint32 x, + uint32 s, + uint32 ac + ) + { + a += F(b, c, d) + x + ac; + a = rotate_left(a, s); + a += b; + } + + // ------------------------------------------------------------------------------------ + + inline void GG ( + uint32& a, + uint32 b, + uint32 c, + uint32 d, + uint32 x, + uint32 s, + uint32 ac + ) + { + a += G(b, c, d) + x + ac; + a = rotate_left(a, s); + a += b; + } + + // ------------------------------------------------------------------------------------ + + inline void HH ( + uint32& a, + uint32 b, + uint32 c, + uint32 d, + uint32 x, + uint32 s, + uint32 ac + ) + { + a += H(b, c, d) + x + ac; + a = rotate_left(a, s); + a += b; + } + + // ------------------------------------------------------------------------------------ + + inline void II ( + uint32& a, + uint32 b, + uint32 c, + uint32 d, + uint32 x, + uint32 s, + uint32 ac + ) + { + a += I(b, c, d) + x + ac; + a = rotate_left(a, s); + a += b; + } + + // ------------------------------------------------------------------------------------ + + void scramble_block ( + uint32& a, + uint32& b, + uint32& c, + uint32& d, + uint32* x + ) + { + const uint32 S11 = 7; + const uint32 S12 = 12; + const uint32 S13 = 17; + const uint32 S14 = 22; + const uint32 S21 = 5; + const uint32 S22 = 9; + const uint32 S23 = 14; + const uint32 S24 = 20; + const uint32 S31 = 4; + const uint32 S32 = 11; + const uint32 S33 = 16; + const uint32 S34 = 23; + const uint32 S41 = 6; + const uint32 S42 = 10; + const uint32 S43 = 15; + const uint32 S44 = 21; + + + // round 1 + FF (a, b, c, d, x[ 0], S11, 0xd76aa478); // 1 + FF (d, a, b, c, x[ 1], S12, 0xe8c7b756); // 2 + FF (c, d, a, b, x[ 2], S13, 0x242070db); // 3 + FF (b, c, d, a, x[ 3], S14, 0xc1bdceee); // 4 + FF (a, b, c, d, x[ 4], S11, 0xf57c0faf); // 5 + FF (d, a, b, c, x[ 5], S12, 0x4787c62a); // 6 + FF (c, d, a, b, x[ 6], S13, 0xa8304613); // 7 + FF (b, c, d, a, x[ 7], S14, 0xfd469501); // 8 + FF (a, b, c, d, x[ 8], S11, 0x698098d8); // 9 + FF (d, a, b, c, x[ 9], S12, 0x8b44f7af); // 10 + FF (c, d, a, b, x[10], S13, 0xffff5bb1); // 11 + FF (b, c, d, a, x[11], S14, 0x895cd7be); // 12 + FF (a, b, c, d, x[12], S11, 0x6b901122); // 13 + FF (d, a, b, c, x[13], S12, 0xfd987193); // 14 + FF (c, d, a, b, x[14], S13, 0xa679438e); // 15 + FF (b, c, d, a, x[15], S14, 0x49b40821); // 16 + + // Round 2 + GG (a, b, c, d, x[ 1], S21, 0xf61e2562); // 17 + GG (d, a, b, c, x[ 6], S22, 0xc040b340); // 18 + GG (c, d, a, b, x[11], S23, 0x265e5a51); // 19 + GG (b, c, d, a, x[ 0], S24, 0xe9b6c7aa); // 20 + GG (a, b, c, d, x[ 5], S21, 0xd62f105d); // 21 + GG (d, a, b, c, x[10], S22, 0x2441453); // 22 + GG (c, d, a, b, x[15], S23, 0xd8a1e681); // 23 + GG (b, c, d, a, x[ 4], S24, 0xe7d3fbc8); // 24 + GG (a, b, c, d, x[ 9], S21, 0x21e1cde6); // 25 + GG (d, a, b, c, x[14], S22, 0xc33707d6); // 26 + GG (c, d, a, b, x[ 3], S23, 0xf4d50d87); // 27 + GG (b, c, d, a, x[ 8], S24, 0x455a14ed); // 28 + GG (a, b, c, d, x[13], S21, 0xa9e3e905); // 29 + GG (d, a, b, c, x[ 2], S22, 0xfcefa3f8); // 30 + GG (c, d, a, b, x[ 7], S23, 0x676f02d9); // 31 + GG (b, c, d, a, x[12], S24, 0x8d2a4c8a); // 32 + + // Round 3 + HH (a, b, c, d, x[ 5], S31, 0xfffa3942); // 33 + HH (d, a, b, c, x[ 8], S32, 0x8771f681); // 34 + HH (c, d, a, b, x[11], S33, 0x6d9d6122); // 35 + HH (b, c, d, a, x[14], S34, 0xfde5380c); // 36 + HH (a, b, c, d, x[ 1], S31, 0xa4beea44); // 37 + HH (d, a, b, c, x[ 4], S32, 0x4bdecfa9); // 38 + HH (c, d, a, b, x[ 7], S33, 0xf6bb4b60); // 39 + HH (b, c, d, a, x[10], S34, 0xbebfbc70); // 40 + HH (a, b, c, d, x[13], S31, 0x289b7ec6); // 41 + HH (d, a, b, c, x[ 0], S32, 0xeaa127fa); // 42 + HH (c, d, a, b, x[ 3], S33, 0xd4ef3085); // 43 + HH (b, c, d, a, x[ 6], S34, 0x4881d05); // 44 + HH (a, b, c, d, x[ 9], S31, 0xd9d4d039); // 45 + HH (d, a, b, c, x[12], S32, 0xe6db99e5); // 46 + HH (c, d, a, b, x[15], S33, 0x1fa27cf8); // 47 + HH (b, c, d, a, x[ 2], S34, 0xc4ac5665); // 48 + + // Round 4 + II (a, b, c, d, x[ 0], S41, 0xf4292244); // 49 + II (d, a, b, c, x[ 7], S42, 0x432aff97); // 50 + II (c, d, a, b, x[14], S43, 0xab9423a7); // 51 + II (b, c, d, a, x[ 5], S44, 0xfc93a039); // 52 + II (a, b, c, d, x[12], S41, 0x655b59c3); // 53 + II (d, a, b, c, x[ 3], S42, 0x8f0ccc92); // 54 + II (c, d, a, b, x[10], S43, 0xffeff47d); // 55 + II (b, c, d, a, x[ 1], S44, 0x85845dd1); // 56 + II (a, b, c, d, x[ 8], S41, 0x6fa87e4f); // 57 + II (d, a, b, c, x[15], S42, 0xfe2ce6e0); // 58 + II (c, d, a, b, x[ 6], S43, 0xa3014314); // 59 + II (b, c, d, a, x[13], S44, 0x4e0811a1); // 60 + II (a, b, c, d, x[ 4], S41, 0xf7537e82); // 61 + II (d, a, b, c, x[11], S42, 0xbd3af235); // 62 + II (c, d, a, b, x[ 2], S43, 0x2ad7d2bb); // 63 + II (b, c, d, a, x[ 9], S44, 0xeb86d391); // 64 + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + + const std::string md5 ( + const std::string& input + ) + { + unsigned char output[16]; + md5 ( + reinterpret_cast(input.data()), + static_cast(input.size()), + output + ); + + + std::stringstream temp; + for (int i = 0; i < 16; ++i) + { + temp.fill('0'); + temp.width(2); + temp << std::hex << static_cast(output[i]); + } + + return temp.str(); + } + +// ---------------------------------------------------------------------------------------- + + void md5 ( + const unsigned char* input, + unsigned long len, + unsigned char* output + ) + { + using namespace md5_stuff; + + + + + // make a temp version of input with enough space for padding and len appended + unsigned long extra_len = 64-len%64; + if (extra_len <= 8) + extra_len += 64; + unsigned char* temp = new unsigned char[extra_len + len]; + + // number of 16 word blocks + const unsigned long N = (extra_len + len)/64; + + const unsigned char* input2 = input; + unsigned char* temp2 = temp; + unsigned char* end = temp+len; + + // copy input into temp + while (temp2 != end) + { + *temp2 = *input2; + ++temp2; + ++input2; + } + + // pad temp + end += extra_len-8; + *temp2 = static_cast(0x80); + ++temp2; + while (temp2 != end) + { + *temp2 = 0; + ++temp2; + } + + // make len the number of bits in the original message + // but first multiply len by 8 and since len is only 32 bits the number might + // overflow so we will carry out the multiplication manually and end up with + // the result in the base 65536 number with three digits + // result = low + high*65536 + upper*65536*65536 + unsigned long low = len & 0xFFFF; + unsigned long high = len >> 16; + unsigned long upper; + unsigned long tmp; + tmp = low * 8; + low = tmp & 0xFFFF; + tmp = high * 8 + (tmp>>16); + high = tmp & 0xFFFF; + upper = tmp >> 16; + + + // append the length + *temp2 = static_cast(low&0xFF); + ++temp2; + *temp2 = static_cast((low>>8)&0xFF); + ++temp2; + *temp2 = static_cast((high)&0xFF); + ++temp2; + *temp2 = static_cast((high>>8)&0xFF); + ++temp2; + *temp2 = static_cast((upper)&0xFF);; + ++temp2; + *temp2 = static_cast((upper>>8)&0xFF);; + ++temp2; + *temp2 = 0; + ++temp2; + *temp2 = 0; + + + uint32 a = 0x67452301; + uint32 b = 0xefcdab89; + uint32 c = 0x98badcfe; + uint32 d = 0x10325476; + + + // an array of 16 words + uint32 x[16]; + + for (unsigned long i = 0; i < N; ++i) + { + + // copy a block of 16 words from m into x + for (unsigned long j = 0; j < 16; ++j) + { + x[j] = ( + (static_cast(temp[4*(j + 16*i) + 3]) << 24) | + (static_cast(temp[4*(j + 16*i) + 2]) << 16) | + (static_cast(temp[4*(j + 16*i) + 1]) << 8 ) | + (static_cast(temp[4*(j + 16*i) ]) ) + ); + } + + uint32 aa = a; + uint32 bb = b; + uint32 cc = c; + uint32 dd = d; + + + scramble_block(a,b,c,d,x); + + + a = a + aa; + b = b + bb; + c = c + cc; + d = d + dd; + + } + + + // put a, b, c, and d into output + output[0] = static_cast((a) &0xFF); + output[1] = static_cast((a>>8) &0xFF); + output[2] = static_cast((a>>16)&0xFF); + output[3] = static_cast((a>>24)&0xFF); + + output[4] = static_cast((b) &0xFF); + output[5] = static_cast((b>>8) &0xFF); + output[6] = static_cast((b>>16)&0xFF); + output[7] = static_cast((b>>24)&0xFF); + + output[8] = static_cast((c) &0xFF); + output[9] = static_cast((c>>8) &0xFF); + output[10] = static_cast((c>>16)&0xFF); + output[11] = static_cast((c>>24)&0xFF); + + output[12] = static_cast((d) &0xFF); + output[13] = static_cast((d>>8) &0xFF); + output[14] = static_cast((d>>16)&0xFF); + output[15] = static_cast((d>>24)&0xFF); + + delete [] temp; + } + +// ---------------------------------------------------------------------------------------- + + const std::string md5 ( + std::istream& input + ) + { + unsigned char output[16]; + md5 ( + input, + output + ); + + + std::stringstream temp; + for (int i = 0; i < 16; ++i) + { + temp.fill('0'); + temp.width(2); + temp << std::hex << static_cast(output[i]); + } + + return temp.str(); + } + +// ---------------------------------------------------------------------------------------- + + void md5 ( + std::istream& input, + unsigned char* output + ) + { + using namespace md5_stuff; + + + + + uint32 a = 0x67452301; + uint32 b = 0xefcdab89; + uint32 c = 0x98badcfe; + uint32 d = 0x10325476; + + + + unsigned long len = 0; + + // an array of 16 words + uint32 x[16]; + unsigned char temp[64]; + + + + bool write_length = false; + bool at_end = false; + std::streambuf& inputbuf = *input.rdbuf(); + while(!at_end) + { + int num = inputbuf.sgetn(reinterpret_cast(temp),64); + len += num; + + // if we hit the end of the stream then pad and add length + if (num < 64) + { + at_end = true; + unsigned char* temp2 = temp; + unsigned char* end; + if (num < 56) + end = temp+56; + else + end = temp+64; + + temp2 += num; + + // apply padding + *temp2 = 0x80; + ++temp2; + while (temp2 != end) + { + *temp2 = 0; + ++temp2; + } + + + if (num < 56) + { + write_length = true; + // make len the number of bits in the original message + // but first multiply len by 8 and since len is only 32 bits the number might + // overflow so we will carry out the multiplication manually and end up with + // the result in the base 65536 number with three digits + // result = low + high*65536 + upper*65536*65536 + unsigned long low = len & 0xFFFF; + unsigned long high = len >> 16; + unsigned long upper; + unsigned long tmp; + tmp = low * 8; + low = tmp & 0xFFFF; + tmp = high * 8 + (tmp>>16); + high = tmp & 0xFFFF; + upper = tmp >> 16; + + + // append the length + *temp2 = static_cast(low&0xFF); + ++temp2; + *temp2 = static_cast((low>>8)&0xFF); + ++temp2; + *temp2 = static_cast((high)&0xFF); + ++temp2; + *temp2 = static_cast((high>>8)&0xFF); + ++temp2; + *temp2 = static_cast((upper)&0xFF);; + ++temp2; + *temp2 = static_cast((upper>>8)&0xFF);; + ++temp2; + *temp2 = 0; + ++temp2; + *temp2 = 0; + } + + + } + + + // copy a block of 16 words from m into x + for (unsigned long i = 0; i < 16; ++i) + { + x[i] = ( + (static_cast(temp[4*i + 3]) << 24) | + (static_cast(temp[4*i + 2]) << 16) | + (static_cast(temp[4*i + 1]) << 8 ) | + (static_cast(temp[4*i ]) ) + ); + } + + + uint32 aa = a; + uint32 bb = b; + uint32 cc = c; + uint32 dd = d; + + + scramble_block(a,b,c,d,x); + + + a = a + aa; + b = b + bb; + c = c + cc; + d = d + dd; + + } + + if (!write_length) + { + uint64 temp = len*8; + + uint32 aa = a; + uint32 bb = b; + uint32 cc = c; + uint32 dd = d; + + std::memset(x, 0, sizeof(x)); + x[15] = (temp>>32); + x[14] = (temp&0xFFFFFFFF); + + scramble_block(a,b,c,d,x); + + + a = a + aa; + b = b + bb; + c = c + cc; + d = d + dd; + + } + + + // put a, b, c, and d into output + output[0] = static_cast((a) &0xFF); + output[1] = static_cast((a>>8) &0xFF); + output[2] = static_cast((a>>16)&0xFF); + output[3] = static_cast((a>>24)&0xFF); + + output[4] = static_cast((b) &0xFF); + output[5] = static_cast((b>>8) &0xFF); + output[6] = static_cast((b>>16)&0xFF); + output[7] = static_cast((b>>24)&0xFF); + + output[8] = static_cast((c) &0xFF); + output[9] = static_cast((c>>8) &0xFF); + output[10] = static_cast((c>>16)&0xFF); + output[11] = static_cast((c>>24)&0xFF); + + output[12] = static_cast((d) &0xFF); + output[13] = static_cast((d>>8) &0xFF); + output[14] = static_cast((d>>16)&0xFF); + output[15] = static_cast((d>>24)&0xFF); + + input.clear(std::ios::eofbit); + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_MD5_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/md5/md5_kernel_1.h b/ml/dlib/dlib/md5/md5_kernel_1.h new file mode 100644 index 000000000..7031d21ef --- /dev/null +++ b/ml/dlib/dlib/md5/md5_kernel_1.h @@ -0,0 +1,50 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MD5_KERNEl_1_ +#define DLIB_MD5_KERNEl_1_ + +#include "md5_kernel_abstract.h" +#include +#include +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + const std::string md5 ( + const std::string& input + ); + +// ---------------------------------------------------------------------------------------- + + void md5 ( + const unsigned char* input, + unsigned long len, + unsigned char* output + ); + +// ---------------------------------------------------------------------------------------- + + const std::string md5 ( + std::istream& input + ); + +// ---------------------------------------------------------------------------------------- + + void md5 ( + std::istream& input, + unsigned char* output + ); + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "md5_kernel_1.cpp" +#endif + +#endif // DLIB_MD5_KERNEl_1_ + diff --git a/ml/dlib/dlib/md5/md5_kernel_abstract.h b/ml/dlib/dlib/md5/md5_kernel_abstract.h new file mode 100644 index 000000000..b7d265d72 --- /dev/null +++ b/ml/dlib/dlib/md5/md5_kernel_abstract.h @@ -0,0 +1,83 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MD5_KERNEl_ABSTRACT_ +#ifdef DLIB_MD5_KERNEl_ABSTRACT_ + +#include +#include + +namespace dlib +{ + + /*! + NOTE: + This is the RSA Data Security, Inc. MD5 Message-Digest Algorithm + as described in rfc1321 + + For the functions which return a unsigned char*. The array contains + the 16 bytes of the digest and are in the correct order. + i.e. output[0], output[1], output[2], ... + !*/ + +// ---------------------------------------------------------------------------------------- + + const std::string md5 ( + const std::string& input + ); + /*! + ensures + - returns the md5 digest of input as a hexadecimal string + !*/ + +// ---------------------------------------------------------------------------------------- + + void md5 ( + const unsigned char* input, + unsigned long len, + unsigned char* output + ); + /*! + requires + - input == pointer to len bytes + - output == pointer to 16 bytes + - input != output + ensures + - #output == the md5 digest of input. + !*/ + +// ---------------------------------------------------------------------------------------- + + const std::string md5 ( + std::istream& input + ); + /*! + requires + - input.fail() == false + ensures + - returns the md5 digest of input as a hexadecimal string + - #input.eof() == true + - #input.fail() == false + !*/ + +// ---------------------------------------------------------------------------------------- + + void md5 ( + std::istream& input + unsigned char* output + ); + /*! + requires + - input.fail() == false + - output == pointer to 16 bytes + ensures + - #output == the md5 digest of input + - #input.eof() == true + - #input.fail() == false + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MD5_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/member_function_pointer.h b/ml/dlib/dlib/member_function_pointer.h new file mode 100644 index 000000000..3dd72f596 --- /dev/null +++ b/ml/dlib/dlib/member_function_pointer.h @@ -0,0 +1,10 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMBER_FUNCTION_POINTEr_ +#define DLIB_MEMBER_FUNCTION_POINTEr_ + +#include "member_function_pointer/member_function_pointer_kernel_1.h" +#include "member_function_pointer/make_mfp.h" + +#endif // DLIB_MEMBER_FUNCTION_POINTEr_ + diff --git a/ml/dlib/dlib/member_function_pointer/make_mfp.h b/ml/dlib/dlib/member_function_pointer/make_mfp.h new file mode 100644 index 000000000..fff9b27ea --- /dev/null +++ b/ml/dlib/dlib/member_function_pointer/make_mfp.h @@ -0,0 +1,179 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MAKE_MFp_H_ +#define DLIB_MAKE_MFp_H_ + +#include "member_function_pointer_kernel_1.h" +#include "make_mfp_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + member_function_pointer<> make_mfp ( + T& object, + void (T::*cb)() + ) + { + member_function_pointer<> temp; + temp.set(object, cb); + return temp; + } + + template < + typename T + > + member_function_pointer<> make_mfp ( + const T& object, + void (T::*cb)()const + ) + { + member_function_pointer<> temp; + temp.set(object, cb); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename A1 + > + member_function_pointer make_mfp ( + T& object, + void (T::*cb)(A1) + ) + { + member_function_pointer temp; + temp.set(object, cb); + return temp; + } + + template < + typename T, + typename A1 + > + member_function_pointer make_mfp ( + const T& object, + void (T::*cb)(A1)const + ) + { + member_function_pointer temp; + temp.set(object, cb); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename A1, + typename A2 + > + member_function_pointer make_mfp ( + T& object, + void (T::*cb)(A1,A2) + ) + { + member_function_pointer temp; + temp.set(object, cb); + return temp; + } + + template < + typename T, + typename A1, + typename A2 + > + member_function_pointer make_mfp ( + const T& object, + void (T::*cb)(A1,A2)const + ) + { + member_function_pointer temp; + temp.set(object, cb); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename A1, + typename A2, + typename A3 + > + member_function_pointer make_mfp ( + T& object, + void (T::*cb)(A1,A2,A3) + ) + { + member_function_pointer temp; + temp.set(object, cb); + return temp; + } + + template < + typename T, + typename A1, + typename A2, + typename A3 + > + member_function_pointer make_mfp ( + const T& object, + void (T::*cb)(A1,A2,A3)const + ) + { + member_function_pointer temp; + temp.set(object, cb); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename A1, + typename A2, + typename A3, + typename A4 + > + member_function_pointer make_mfp ( + T& object, + void (T::*cb)(A1,A2,A3,A4) + ) + { + member_function_pointer temp; + temp.set(object, cb); + return temp; + } + + template < + typename T, + typename A1, + typename A2, + typename A3, + typename A4 + > + member_function_pointer make_mfp ( + const T& object, + void (T::*cb)(A1,A2,A3,A4)const + ) + { + member_function_pointer temp; + temp.set(object, cb); + return temp; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MAKE_MFp_H_ + + + diff --git a/ml/dlib/dlib/member_function_pointer/make_mfp_abstract.h b/ml/dlib/dlib/member_function_pointer/make_mfp_abstract.h new file mode 100644 index 000000000..5074ca9a7 --- /dev/null +++ b/ml/dlib/dlib/member_function_pointer/make_mfp_abstract.h @@ -0,0 +1,207 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MAKE_MFp_ABSTRACT_ +#ifdef DLIB_MAKE_MFp_ABSTRACT_ + +#include "member_function_pointer_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + member_function_pointer<> make_mfp ( + T& object, + void (T::*cb)() + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP() will call (object.*cb)() + !*/ + + template < + typename T + > + member_function_pointer<> make_mfp ( + const T& object, + void (T::*cb)()const + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP() will call (object.*cb)() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename A1 + > + member_function_pointer make_mfp ( + T& object, + void (T::*cb)(A1 a1) + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP(a1) will call (object.*cb)(a1) + !*/ + + template < + typename T, + typename A1 + > + member_function_pointer make_mfp ( + const T& object, + void (T::*cb)(A1 a1)const + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP(a1) will call (object.*cb)(a1) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename A1, + typename A2 + > + member_function_pointer make_mfp ( + T& object, + void (T::*cb)(A1 a1, A2 a2) + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP(a1,a2) will call (object.*cb)(a1,a2) + !*/ + + template < + typename T, + typename A1, + typename A2 + > + member_function_pointer make_mfp ( + const T& object, + void (T::*cb)(A1 a1, A2 a2)const + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP(a1,a2) will call (object.*cb)(a1,a2) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename A1, + typename A2, + typename A3 + > + member_function_pointer make_mfp ( + T& object, + void (T::*cb)(A1 a1, A2 a2, A3 a3) + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP(a1,a2,a3) will call (object.*cb)(a1,a2,a3) + !*/ + + template < + typename T, + typename A1, + typename A2, + typename A3 + > + member_function_pointer make_mfp ( + const T& object, + void (T::*cb)(A1 a1, A2 a2, A3 a3)const + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP(a1,a2,a3) will call (object.*cb)(a1,a2,a3) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename A1, + typename A2, + typename A3, + typename A4 + > + member_function_pointer make_mfp ( + T& object, + void (T::*cb)(A1 a1, A2 a2, A3 a3, A4 a4) + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP(a1,a2,a3,a4) will call (object.*cb)(a1,a2,a3,a4) + !*/ + + template < + typename T, + typename A1, + typename A2, + typename A3, + typename A4 + > + member_function_pointer make_mfp ( + const T& object, + void (T::*cb)(A1 a1, A2 a2, A3 a3, A4 a4)const + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - returns a member function pointer object MFP such that: + - MFP.is_set() == true + - calls to MFP(a1,a2,a3,a4) will call (object.*cb)(a1,a2,a3,a4) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MAKE_MFp_ABSTRACT_ + + diff --git a/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_1.h b/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_1.h new file mode 100644 index 000000000..6cf5630b4 --- /dev/null +++ b/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_1.h @@ -0,0 +1,498 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMBER_FUNCTION_POINTER_KERNEl_1_ +#define DLIB_MEMBER_FUNCTION_POINTER_KERNEl_1_ + +#include "../algs.h" +#include "member_function_pointer_kernel_abstract.h" +#include "../enable_if.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1 = void, + typename PARAM2 = void, + typename PARAM3 = void, + typename PARAM4 = void + > + class member_function_pointer; + +// ---------------------------------------------------------------------------------------- + +#define DLIB_MFP_SC DLIB_ASSERT(cb != 0, \ + "\tvoid member_function_pointer::set" \ + << "\n\tthe member function pointer can't be null" \ + << "\n\tthis: " << this ); + + +#define DLIB_MFP_OC DLIB_ASSERT(this->is_set() == true , \ + "\tvoid member_function_pointer::operator()" \ + << "\n\tYou must call set() before you can use this function" \ + << "\n\tthis: " << this); + +// ---------------------------------------------------------------------------------------- + + template + class mfp_kernel_1_base_class + { + /* + All member function pointer classes inherit from this class. This + is where most of the things in a member function pointer are defined. + + The reason for the num_args template argument to this class is to prevent + any sort of implicit casting between derived member function pointer classes + that take different numbers of arguments. + */ + protected: + enum mfp_type { mfp_nonconst, mfp_const, mfp_null}; + + class mp_base_base + { + public: + mp_base_base(void* ptr, mfp_type type_) : o(ptr),type(type_) {} + virtual ~mp_base_base(){} + virtual void clone(void* ptr) const = 0; + virtual bool is_same (const mp_base_base* item) const = 0; + bool is_set () const { return o!=0; } + + void* const o; + const mfp_type type; + }; + + template + class mp_null : public mp_base_base + { + public: + typedef void (T::*mfp_pointer_type)() ; + + mp_null (void* , mfp_pointer_type ) : mp_base_base(0,mfp_null), callback(0) {} + mp_null () : mp_base_base(0,mfp_null), callback(0) {} + + const mfp_pointer_type callback; + }; + + template + class mp_impl_T : public mp_impl + { + /* + This class supplies the implementations clone() and is_same() for any + classes that inherit from mp_base_base. It does this in a very + roundabout way... + */ + + public: + typedef typename mp_impl::mfp_pointer_type mfp_pointer_type; + + mp_impl_T() : mp_impl(0,0) {} + mp_impl_T(void* ptr, mfp_pointer_type cb) : mp_impl(ptr,cb) {} + + template + void safe_clone(stack_based_memory_block& buf) + { + // This is here just to validate the assumption that our block of memory we have made + // in mp_memory is the right size to store the data for this object. If you + // get a compiler error on this line then email me :) + COMPILE_TIME_ASSERT(sizeof(mp_impl_T) <= mem_size); + clone(buf.get()); + } + + void clone (void* ptr) const { new(ptr) mp_impl_T(this->o,this->callback); } + bool is_same (const mp_base_base* item) const + { + if (item->o == 0 && this->o == 0) + { + return true; + } + else if (item->o == this->o && this->type == item->type) + { + const mp_impl* i = reinterpret_cast(item); + return (i->callback == this->callback); + } + return false; + } + }; + + struct dummy_base { virtual void nonnull() {}; virtual ~dummy_base(){}; int a; }; + struct dummy : virtual public dummy_base{ void nonnull() {}; }; + + typedef mp_impl_T > mp_null_impl; + public: + + mfp_kernel_1_base_class ( + const mfp_kernel_1_base_class& item + ) { item.mp()->clone(mp_memory.get()); } + + mfp_kernel_1_base_class ( + ) { mp_null_impl().safe_clone(mp_memory); } + + bool operator == ( + const mfp_kernel_1_base_class& item + ) const { return mp()->is_same(item.mp()); } + + bool operator != ( + const mfp_kernel_1_base_class& item + ) const { return !(*this == item); } + + mfp_kernel_1_base_class& operator= ( + const mfp_kernel_1_base_class& item + ) { mfp_kernel_1_base_class(item).swap(*this); return *this; } + + ~mfp_kernel_1_base_class ( + ) { destroy_mp_memory(); } + + void clear( + ) { mfp_kernel_1_base_class().swap(*this); } + + bool is_set ( + ) const { return mp()->is_set(); } + + private: + typedef void (dummy::*safe_bool)(); + + public: + operator safe_bool () const { return is_set() ? &dummy::nonnull : 0; } + bool operator!() const { return !is_set(); } + + void swap ( + mfp_kernel_1_base_class& item + ) + { + // make a temp copy of item + mfp_kernel_1_base_class temp(item); + + // destory the stuff in item + item.destroy_mp_memory(); + // copy *this into item + mp()->clone(item.mp_memory.get()); + + // destory the stuff in this + destroy_mp_memory(); + // copy temp into *this + temp.mp()->clone(mp_memory.get()); + } + + protected: + + // The reason for adding 1 here is because visual studio 2003 will sometimes + // try to compile this code with sizeof(mp_null_impl) == 0 (which is a bug in visual studio). + // Fortunately, no actual real instances of this template seem to end up with that screwed up + // value so everything works fine if we just add 1 so that this degenerate case doesn't cause + // trouble. Note that we know it all works fine because safe_clone() checks the size of this + // memory block whenever the member function pointer is used. + stack_based_memory_block mp_memory; + + void destroy_mp_memory ( + ) + { + // Honestly this probably doesn't even do anything but I'm putting + // it here just for good measure. + mp()->~mp_base_base(); + } + + mp_base_base* mp () { return static_cast(mp_memory.get()); } + const mp_base_base* mp () const { return static_cast(mp_memory.get()); } + + }; + +// ---------------------------------------------------------------------------------------- + + template <> + class member_function_pointer : public mfp_kernel_1_base_class<0> + { + class mp_base : public mp_base_base { + public: + mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {} + virtual void call() const = 0; + }; + + template + class mp_impl : public mp_base { + public: + typedef void (T::*mfp_pointer_type)() ; + void call () const { (static_cast(this->o)->*callback)(); } + + mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {} + const mfp_pointer_type callback; + }; + + template + class mp_impl_const : public mp_base { + public: + typedef void ((T::*mfp_pointer_type)()const); + void call () const { (static_cast(this->o)->*callback)(); } + + mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {} + const mfp_pointer_type callback; + }; + + public: + typedef void param1_type; + typedef void param2_type; + typedef void param3_type; + typedef void param4_type; + + // These two typedefs are here for backwards compatibility with previous versions + // of dlib. + typedef member_function_pointer kernel_1a; + typedef member_function_pointer kernel_1a_c; + + + void operator() () const { DLIB_MFP_OC; static_cast(mp_memory.get())->call(); } + + // the reason for putting disable_if on this function is that it avoids an overload + // resolution bug in visual studio. + template typename disable_if,void>::type + set(T& object, typename mp_impl::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >(&object,cb).safe_clone(mp_memory); } + + template void set(const T& object, typename mp_impl_const::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >((void*)&object,cb).safe_clone(mp_memory); } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1 + > + class member_function_pointer : public mfp_kernel_1_base_class<1> + { + class mp_base : public mp_base_base { + public: + mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {} + virtual void call(PARAM1) const = 0; + }; + + template + class mp_impl : public mp_base { + public: + typedef void (T::*mfp_pointer_type)(PARAM1) ; + void call (PARAM1 p1) const { (static_cast(this->o)->*callback)(p1); } + + mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {} + const mfp_pointer_type callback; + }; + + template + class mp_impl_const : public mp_base { + public: + typedef void ((T::*mfp_pointer_type)(PARAM1)const); + void call (PARAM1 p1) const { (static_cast(this->o)->*callback)(p1); } + + mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {} + const mfp_pointer_type callback; + }; + + public: + typedef PARAM1 param1_type; + typedef void param2_type; + typedef void param3_type; + typedef void param4_type; + + // These two typedefs are here for backwards compatibility with previous versions + // of dlib. + typedef member_function_pointer kernel_1a; + typedef member_function_pointer kernel_1a_c; + + + void operator() (PARAM1 p1) const { DLIB_MFP_OC; static_cast(mp_memory.get())->call(p1); } + + // the reason for putting disable_if on this function is that it avoids an overload + // resolution bug in visual studio. + template typename disable_if,void>::type + set(T& object, typename mp_impl::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >(&object,cb).safe_clone(mp_memory); } + + template void set(const T& object, typename mp_impl_const::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >((void*)&object,cb).safe_clone(mp_memory); } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1, + typename PARAM2 + > + class member_function_pointer : public mfp_kernel_1_base_class<2> + { + class mp_base : public mp_base_base { + public: + mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {} + virtual void call(PARAM1,PARAM2) const = 0; + }; + + template + class mp_impl : public mp_base { + public: + typedef void (T::*mfp_pointer_type)(PARAM1,PARAM2) ; + void call (PARAM1 p1, PARAM2 p2) const { (static_cast(this->o)->*callback)(p1,p2); } + + mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {} + const mfp_pointer_type callback; + }; + + template + class mp_impl_const : public mp_base { + public: + typedef void ((T::*mfp_pointer_type)(PARAM1,PARAM2)const); + void call (PARAM1 p1, PARAM2 p2) const { (static_cast(this->o)->*callback)(p1,p2); } + + mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {} + const mfp_pointer_type callback; + }; + + public: + typedef PARAM1 param1_type; + typedef PARAM2 param2_type; + typedef void param3_type; + typedef void param4_type; + + // These two typedefs are here for backwards compatibility with previous versions + // of dlib. + typedef member_function_pointer kernel_1a; + typedef member_function_pointer kernel_1a_c; + + void operator() (PARAM1 p1, PARAM2 p2) const { DLIB_MFP_OC; static_cast(mp_memory.get())->call(p1,p2); } + + // the reason for putting disable_if on this function is that it avoids an overload + // resolution bug in visual studio. + template typename disable_if,void>::type + set(T& object, typename mp_impl::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >(&object,cb).safe_clone(mp_memory); } + + template void set(const T& object, typename mp_impl_const::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >((void*)&object,cb).safe_clone(mp_memory); } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1, + typename PARAM2, + typename PARAM3 + > + class member_function_pointer : public mfp_kernel_1_base_class<3> + { + class mp_base : public mp_base_base { + public: + mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {} + virtual void call(PARAM1,PARAM2,PARAM3) const = 0; + }; + + template + class mp_impl : public mp_base { + public: + typedef void (T::*mfp_pointer_type)(PARAM1,PARAM2,PARAM3) ; + void call (PARAM1 p1, PARAM2 p2, PARAM3 p3) const { (static_cast(this->o)->*callback)(p1,p2,p3); } + + mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {} + const mfp_pointer_type callback; + }; + + template + class mp_impl_const : public mp_base { + public: + typedef void ((T::*mfp_pointer_type)(PARAM1,PARAM2,PARAM3)const); + void call (PARAM1 p1, PARAM2 p2, PARAM3 p3) const { (static_cast(this->o)->*callback)(p1,p2,p3); } + + mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {} + const mfp_pointer_type callback; + }; + + public: + typedef PARAM1 param1_type; + typedef PARAM2 param2_type; + typedef PARAM3 param3_type; + typedef void param4_type; + + // These two typedefs are here for backwards compatibility with previous versions + // of dlib. + typedef member_function_pointer kernel_1a; + typedef member_function_pointer kernel_1a_c; + + void operator() (PARAM1 p1, PARAM2 p2, PARAM3 p3) const { DLIB_MFP_OC; static_cast(mp_memory.get())->call(p1,p2,p3); } + + // the reason for putting disable_if on this function is that it avoids an overload + // resolution bug in visual studio. + template typename disable_if,void>::type + set(T& object, typename mp_impl::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >(&object,cb).safe_clone(mp_memory); } + + template void set(const T& object, typename mp_impl_const::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >((void*)&object,cb).safe_clone(mp_memory); } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1, + typename PARAM2, + typename PARAM3, + typename PARAM4 + > + class member_function_pointer : public mfp_kernel_1_base_class<4> + { + class mp_base : public mp_base_base { + public: + mp_base(void* ptr, mfp_type type_) : mp_base_base(ptr,type_) {} + virtual void call(PARAM1,PARAM2,PARAM3,PARAM4) const = 0; + }; + + template + class mp_impl : public mp_base { + public: + typedef void (T::*mfp_pointer_type)(PARAM1,PARAM2,PARAM3, PARAM4) ; + void call (PARAM1 p1, PARAM2 p2, PARAM3 p3, PARAM4 p4) const { (static_cast(this->o)->*callback)(p1,p2,p3,p4); } + + mp_impl ( void* object, mfp_pointer_type cb) : mp_base(object, mfp_nonconst), callback(cb) {} + const mfp_pointer_type callback; + }; + + template + class mp_impl_const : public mp_base { + public: + typedef void ((T::*mfp_pointer_type)(PARAM1,PARAM2,PARAM3,PARAM4)const); + void call (PARAM1 p1, PARAM2 p2, PARAM3 p3, PARAM4 p4) const { (static_cast(this->o)->*callback)(p1,p2,p3,p4); } + + mp_impl_const ( void* object, mfp_pointer_type cb) : mp_base(object,mfp_const), callback(cb) {} + const mfp_pointer_type callback; + }; + + public: + typedef PARAM1 param1_type; + typedef PARAM2 param2_type; + typedef PARAM3 param3_type; + typedef PARAM4 param4_type; + + // These two typedefs are here for backwards compatibility with previous versions + // of dlib. + typedef member_function_pointer kernel_1a; + typedef member_function_pointer kernel_1a_c; + + void operator() (PARAM1 p1, PARAM2 p2, PARAM3 p3, PARAM4 p4) const + { DLIB_MFP_OC; static_cast(mp_memory.get())->call(p1,p2,p3,p4); } + + // the reason for putting disable_if on this function is that it avoids an overload + // resolution bug in visual studio. + template typename disable_if,void>::type + set(T& object, typename mp_impl::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >(&object,cb).safe_clone(mp_memory); } + + template void set(const T& object, typename mp_impl_const::mfp_pointer_type cb) + { DLIB_MFP_SC; destroy_mp_memory(); mp_impl_T >((void*)&object,cb).safe_clone(mp_memory); } + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MEMBER_FUNCTION_POINTER_KERNEl_1_ + diff --git a/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_abstract.h b/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_abstract.h new file mode 100644 index 000000000..e152972ee --- /dev/null +++ b/ml/dlib/dlib/member_function_pointer/member_function_pointer_kernel_abstract.h @@ -0,0 +1,483 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MEMBER_FUNCTION_POINTER_KERNEl_ABSTRACT_ +#ifdef DLIB_MEMBER_FUNCTION_POINTER_KERNEl_ABSTRACT_ + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1 = void, + typename PARAM2 = void, + typename PARAM3 = void, + typename PARAM4 = void + > + class member_function_pointer; + +// ---------------------------------------------------------------------------------------- + + template <> + class member_function_pointer + { + /*! + INITIAL VALUE + is_set() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a member function pointer. It is useful because + instances of this object can be created without needing to know the type + of object whose member function we will be calling. + + There are five template specializations of this object. The first + represents a pointer to a member function taking no parameters, the + second represents a pointer to a member function taking one parameter, + the third to one taking two parameters, and so on. + + You specify the parameters to your member function pointer by filling in + the PARAM template parameters. For example: + + To use a pointer to a function with no parameters you would say: + member_function_pointer<> my_pointer; + To use a pointer to a function that takes a single int you would say: + member_function_pointer my_pointer; + To use a pointer to a function that takes an int and then a reference + to a string you would say: + member_function_pointer my_pointer; + + Also note that the formal comments are only present for the first + template specialization. They are all exactly the same except for the + number of parameters each takes in its member function pointer. + !*/ + + public: + typedef void param1_type; + typedef void param2_type; + typedef void param3_type; + typedef void param4_type; + + member_function_pointer ( + ); + /*! + ensures + - #*this is properly initialized + !*/ + + member_function_pointer( + const member_function_pointer& item + ); + /*! + ensures + - *this == item + !*/ + + ~member_function_pointer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + member_function_pointer& operator=( + const member_function_pointer& item + ); + /*! + ensures + - *this == item + !*/ + + bool operator == ( + const member_function_pointer& item + ) const; + /*! + ensures + - if (is_set() == false && item.is_set() == false) then + - returns true + - else if (both *this and item point to the same member function + in the same object instance) then + - returns true + - else + - returns false + !*/ + + bool operator != ( + const member_function_pointer& item + ) const; + /*! + ensures + - returns !(*this == item) + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + !*/ + + bool is_set ( + ) const; + /*! + ensures + - if (this->set() has been called) then + - returns true + - else + - returns false + !*/ + + template < + typename T + > + void set ( + T& object, + void (T::*cb)() + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*cb)() + !*/ + + template < + typename T + > + void set ( + const T& object, + void (T::*cb)()const + ); + /*! + requires + - cb == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*cb)() + !*/ + + operator some_undefined_pointer_type ( + ) const; + /*! + ensures + - if (is_set()) then + - returns a non 0 value + - else + - returns a 0 value + !*/ + + bool operator! ( + ) const; + /*! + ensures + - returns !is_set() + !*/ + + void operator () ( + ) const; + /*! + requires + - is_set() == true + ensures + - calls the member function on the object specified by the last + call to this->set() + throws + - any exception thrown by the member function specified by + the previous call to this->set(). + If any of these exceptions are thrown then the call to this + function will have no effect on *this. + !*/ + + void swap ( + member_function_pointer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1 + > + class member_function_pointer + { + public: + typedef PARAM1 param1_type; + typedef void param2_type; + typedef void param3_type; + typedef void param4_type; + + member_function_pointer (); + + member_function_pointer( + const member_function_pointer& item + ); + + ~member_function_pointer ( + ); + + member_function_pointer& operator=( + const member_function_pointer& item + ); + + bool operator == ( + const member_function_pointer& item + ) const; + + bool operator != ( + const member_function_pointer& item + ) const; + + void clear(); + + bool is_set () const; + + template + void set ( + T& object, + void (T::*cb)(PARAM1) + ); + + template + void set ( + const T& object, + void (T::*cb)(PARAM1)const + ); + + operator some_undefined_pointer_type ( + ) const; + + bool operator! ( + ) const; + + void operator () ( + PARAM1 param1 + ) const; + + void swap ( + member_function_pointer& item + ); + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1, + typename PARAM2 + > + class member_function_pointer + { + public: + typedef PARAM1 param1_type; + typedef PARAM2 param2_type; + typedef void param3_type; + typedef void param4_type; + + member_function_pointer (); + + member_function_pointer( + const member_function_pointer& item + ); + + ~member_function_pointer ( + ); + + member_function_pointer& operator=( + const member_function_pointer& item + ); + + bool operator == ( + const member_function_pointer& item + ) const; + + bool operator != ( + const member_function_pointer& item + ) const; + + void clear(); + + bool is_set () const; + + template + void set ( + T& object, + void (T::*cb)(PARAM1,PARAM2) + ); + + template + void set ( + const T& object, + void (T::*cb)(PARAM1,PARAM2)const + ); + + operator some_undefined_pointer_type ( + ) const; + + bool operator! ( + ) const; + + void operator () ( + PARAM1 param1, + PARAM2 param2 + ) const; + + void swap ( + member_function_pointer& item + ); + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1, + typename PARAM2, + typename PARAM3 + > + class member_function_pointer + { + public: + typedef PARAM1 param1_type; + typedef PARAM2 param2_type; + typedef PARAM3 param3_type; + typedef void param4_type; + + member_function_pointer (); + + member_function_pointer( + const member_function_pointer& item + ); + + ~member_function_pointer ( + ); + + member_function_pointer& operator=( + const member_function_pointer& item + ); + + bool operator == ( + const member_function_pointer& item + ) const; + + bool operator != ( + const member_function_pointer& item + ) const; + + void clear(); + + bool is_set () const; + + template + void set ( + T& object, + void (T::*cb)(PARAM1,PARAM2,PARAM3) + ); + + template + void set ( + const T& object, + void (T::*cb)(PARAM1,PARAM2,PARAM3)const + ); + + operator some_undefined_pointer_type ( + ) const; + + bool operator! ( + ) const; + + void operator () ( + PARAM1 param1, + PARAM2 param2, + PARAM2 param3 + ) const; + + void swap ( + member_function_pointer& item + ); + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename PARAM1, + typename PARAM2, + typename PARAM3, + typename PARAM4 + > + class member_function_pointer + { + public: + typedef PARAM1 param1_type; + typedef PARAM2 param2_type; + typedef PARAM3 param3_type; + typedef PARAM4 param4_type; + + member_function_pointer (); + + member_function_pointer( + const member_function_pointer& item + ); + + ~member_function_pointer ( + ); + + member_function_pointer& operator=( + const member_function_pointer& item + ); + + bool operator == ( + const member_function_pointer& item + ) const; + + bool operator != ( + const member_function_pointer& item + ) const; + + void clear(); + + bool is_set () const; + + template + void set ( + T& object, + void (T::*cb)(PARAM1,PARAM2,PARAM3,PARAM4) + ); + + template + void set ( + const T& object, + void (T::*cb)(PARAM1,PARAM2,PARAM3,PARAM4)const + ); + + operator some_undefined_pointer_type ( + ) const; + + bool operator! ( + ) const; + + void operator () ( + PARAM1 param1, + PARAM2 param2, + PARAM2 param3, + PARAM2 param4 + ) const; + + void swap ( + member_function_pointer& item + ); + + }; + +// ---------------------------------------------------------------------------------------- + + +} + +#endif // DLIB_MEMBER_FUNCTION_POINTER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/memory_manager.h b/ml/dlib/dlib/memory_manager.h new file mode 100644 index 000000000..5b5283255 --- /dev/null +++ b/ml/dlib/dlib/memory_manager.h @@ -0,0 +1,73 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMORY_MANAGEr_ +#define DLIB_MEMORY_MANAGEr_ + +#include "memory_manager/memory_manager_kernel_1.h" +#include "memory_manager/memory_manager_kernel_2.h" +#include "memory_manager/memory_manager_kernel_3.h" + + + +namespace dlib +{ + + template < + typename T + > + class memory_manager + { + memory_manager() {} + + + public: + + //----------- kernels --------------- + + // kernel_1 + typedef memory_manager_kernel_1 + kernel_1a; + typedef memory_manager_kernel_1 + kernel_1b; + typedef memory_manager_kernel_1 + kernel_1c; + typedef memory_manager_kernel_1 + kernel_1d; + typedef memory_manager_kernel_1 + kernel_1e; + typedef memory_manager_kernel_1 + kernel_1f; + + // kernel_2 + typedef memory_manager_kernel_2 + kernel_2a; + typedef memory_manager_kernel_2 + kernel_2b; + typedef memory_manager_kernel_2 + kernel_2c; + typedef memory_manager_kernel_2 + kernel_2d; + typedef memory_manager_kernel_2 + kernel_2e; + + + // kernel_3 + typedef memory_manager_kernel_3 + kernel_3a; + typedef memory_manager_kernel_3 + kernel_3b; + typedef memory_manager_kernel_3 + kernel_3c; + typedef memory_manager_kernel_3 + kernel_3d; + typedef memory_manager_kernel_3 + kernel_3e; + + + + + }; +} + +#endif // DLIB_MEMORY_MANAGEr_ + diff --git a/ml/dlib/dlib/memory_manager/memory_manager_kernel_1.h b/ml/dlib/dlib/memory_manager/memory_manager_kernel_1.h new file mode 100644 index 000000000..557a19fca --- /dev/null +++ b/ml/dlib/dlib/memory_manager/memory_manager_kernel_1.h @@ -0,0 +1,305 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMORY_MANAGER_KERNEl_1_ +#define DLIB_MEMORY_MANAGER_KERNEl_1_ + +#include "../algs.h" +#include "memory_manager_kernel_abstract.h" +#include "../assert.h" +#include + + +namespace dlib +{ + + template < + typename T, + unsigned long max_pool_size + > + class memory_manager_kernel_1 + { + /*! + INITIAL VALUE + allocations == 0 + next == 0 + pool_size == 0 + + REQUIREMENTS ON max_pool_size + max_pool_size is the maximum number of nodes we will keep in our linked list at once. + So you can put any value in for this argument. + + CONVENTION + This memory manager implementation allocates T objects one at a time when there are + allocation requests. Then when there is a deallocate request the returning T object + is place into a list of free blocks if that list has less than max_pool_size + blocks in it. subsequent allocation requests will be serviced by drawing from the + free list whenever it isn't empty. + + + allocations == get_number_of_allocations() + + - if (next != 0) then + - next == the next pointer to return from allocate() + and next == pointer to the first node in a linked list. each node + is one item in the memory pool. + - the last node in the linked list has next set to 0 + - pool_size == the number of nodes in the linked list + - pool_size <= max_pool_size + - else + - we need to call new to get the next pointer to return from allocate() + + !*/ + + union node + { + node* next; + char item[sizeof(T)]; + }; + + public: + + typedef T type; + + template + struct rebind { + typedef memory_manager_kernel_1 other; + }; + + + memory_manager_kernel_1( + ) : + allocations(0), + next(0), + pool_size(0) + { + } + + virtual ~memory_manager_kernel_1( + ) + { + + while (next != 0) + { + node* temp = next; + next = next->next; + ::operator delete ( static_cast(temp)); + } + } + + unsigned long get_number_of_allocations ( + ) const { return allocations; } + + T* allocate_array ( + unsigned long size + ) + { + T* temp = new T[size]; + ++allocations; + return temp; + } + + void deallocate_array ( + T* item + ) + { + --allocations; + delete [] item; + } + + T* allocate ( + ) + { + T* temp; + if (next != 0) + { + temp = reinterpret_cast(next); + + node* n = next->next; + + try + { + // construct this new T object with placement new. + new (static_cast(temp))T(); + } + catch (...) + { + next->next = n; + throw; + } + + next = n; + + --pool_size; + } + else + { + temp = static_cast(::operator new(sizeof(node))); + try + { + // construct this new T object with placement new. + new (static_cast(temp))T(); + } + catch (...) + { + // construction of the new object threw so delete the block of memory + ::operator delete ( static_cast(temp)); + throw; + } + } + + ++allocations; + return temp; + } + + void deallocate ( + T* item + ) + { + --allocations; + item->~T(); + + if (pool_size >= max_pool_size) + { + ::operator delete ( static_cast(item)); + return; + } + + // add this memory chunk into our linked list. + node* temp = reinterpret_cast(item); + temp->next = next; + next = temp; + ++pool_size; + } + + void swap ( + memory_manager_kernel_1& item + ) + { + exchange(allocations,item.allocations); + exchange(next,item.next); + exchange(pool_size,item.pool_size); + } + + private: + + // data members + unsigned long allocations; + node* next; + unsigned long pool_size; + + // restricted functions + memory_manager_kernel_1(memory_manager_kernel_1&); // copy constructor + memory_manager_kernel_1& operator=(memory_manager_kernel_1&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class memory_manager_kernel_1 + { + /*! + INITIAL VALUE + allocations == 0 + + CONVENTION + This memory manager just calls new and delete directly so it doesn't + really do anything. + + allocations == get_number_of_allocations() + !*/ + + public: + + typedef T type; + + template + struct rebind { + typedef memory_manager_kernel_1 other; + }; + + + memory_manager_kernel_1( + ) : + allocations(0) + { + } + + virtual ~memory_manager_kernel_1( + ) + { + } + + unsigned long get_number_of_allocations ( + ) const { return allocations; } + + T* allocate_array ( + unsigned long size + ) + { + T* temp = new T[size]; + ++allocations; + return temp; + } + + void deallocate_array ( + T* item + ) + { + --allocations; + delete [] item; + } + + T* allocate ( + ) + { + T* temp = new T; + ++allocations; + return temp; + } + + void deallocate ( + T* item + ) + { + delete item; + --allocations; + } + + void swap ( + memory_manager_kernel_1& item + ) + { + exchange(allocations,item.allocations); + } + + private: + + // data members + unsigned long allocations; + + // restricted functions + memory_manager_kernel_1(memory_manager_kernel_1&); // copy constructor + memory_manager_kernel_1& operator=(memory_manager_kernel_1&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long max_pool_size + > + inline void swap ( + memory_manager_kernel_1& a, + memory_manager_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MEMORY_MANAGER_KERNEl_1_ + + + diff --git a/ml/dlib/dlib/memory_manager/memory_manager_kernel_2.h b/ml/dlib/dlib/memory_manager/memory_manager_kernel_2.h new file mode 100644 index 000000000..8f026bc49 --- /dev/null +++ b/ml/dlib/dlib/memory_manager/memory_manager_kernel_2.h @@ -0,0 +1,253 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMORY_MANAGER_KERNEl_2_ +#define DLIB_MEMORY_MANAGER_KERNEl_2_ + +#include "../algs.h" +#include "memory_manager_kernel_abstract.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename T, + unsigned long chunk_size + > + class memory_manager_kernel_2 + { + /*! + INITIAL VALUE + allocations == 0 + next == 0 + first_chunk == 0 + + REQUIREMENTS ON chunk_size + chunk_size is the number of items of type T we will allocate at a time. so + it must be > 0. + + CONVENTION + This memory manager implementation allocates memory in blocks of chunk_size*sizeof(T) + bytes. All the sizeof(T) subblocks are kept in a linked list of free memory blocks + and are given out whenever an allocation request occurs. Also, memory is not freed + until this object is destructed. + + Note that array allocations are not memory managed. + + + + allocations == get_number_of_allocations() + + - if (next != 0) then + - next == the next pointer to return from allocate() + and next == pointer to the first node in a linked list. each node + is one item in the memory pool. + - the last node in the linked list has next set to 0 + - else + - we need to call new to get the next pointer to return from allocate() + + + - if (first_chunk != 0) then + - first_chunk == the first node in a linked list that contains pointers + to all the chunks we have ever allocated. The last link in the list + has its next pointer set to 0. + !*/ + + union node + { + node* next; + char item[sizeof(T)]; + }; + + struct chunk_node + { + node* chunk; + chunk_node* next; + }; + + public: + + typedef T type; + + template + struct rebind { + typedef memory_manager_kernel_2 other; + }; + + + memory_manager_kernel_2( + ) : + allocations(0), + next(0), + first_chunk(0) + { + // You FOOL! You can't have a zero chunk_size. + COMPILE_TIME_ASSERT(chunk_size > 0); + } + + virtual ~memory_manager_kernel_2( + ) + { + if (allocations == 0) + { + while (first_chunk != 0) + { + chunk_node* temp = first_chunk; + first_chunk = first_chunk->next; + // delete the memory chunk + ::operator delete ( static_cast(temp->chunk)); + // delete the chunk_node + delete temp; + } + } + } + + unsigned long get_number_of_allocations ( + ) const { return allocations; } + + T* allocate_array ( + unsigned long size + ) + { + T* temp = new T[size]; + ++allocations; + return temp; + } + + void deallocate_array ( + T* item + ) + { + --allocations; + delete [] item; + } + + T* allocate ( + ) + { + T* temp = 0; + if (next != 0) + { + temp = reinterpret_cast(next); + node* n = next->next; + + try + { + // construct this new T object with placement new. + new (static_cast(temp))T(); + } + catch (...) + { + next->next = n; + throw; + } + + next = n; + } + else + { + // the linked list is empty so we need to allocate some more memory + node* block = 0; + block = static_cast(::operator new (sizeof(node)*chunk_size)); + + // the first part of this block can be our new object + temp = reinterpret_cast(block); + + try + { + // construct this new T object with placement new. + new (static_cast(temp))T(); + } + catch (...) + { + // construction of the new object threw so delete the block of memory + ::operator delete ( static_cast(block)); + throw; + } + + // allocate a new chunk_node + chunk_node* chunk; + try {chunk = new chunk_node; } + catch (...) + { + temp->~T(); + ::operator delete ( static_cast(block)); + throw; + } + + // add this block into the chunk list + chunk->chunk = block; + chunk->next = first_chunk; + first_chunk = chunk; + + + ++block; + // now add the rest of the block into the linked list of free nodes. + for (unsigned long i = 0; i < chunk_size-1; ++i) + { + block->next = next; + next = block; + ++block; + } + + } + + + ++allocations; + return temp; + } + + void deallocate ( + T* item + ) + { + --allocations; + item->~T(); + + // add this memory into our linked list. + node* temp = reinterpret_cast(item); + temp->next = next; + next = temp; + } + + void swap ( + memory_manager_kernel_2& item + ) + { + exchange(allocations,item.allocations); + exchange(next,item.next); + exchange(first_chunk,item.first_chunk); + } + + private: + + // data members + unsigned long allocations; + node* next; + + chunk_node* first_chunk; + + + + + // restricted functions + memory_manager_kernel_2(memory_manager_kernel_2&); // copy constructor + memory_manager_kernel_2& operator=(memory_manager_kernel_2&); // assignment operator + }; + + template < + typename T, + unsigned long chunk_size + > + inline void swap ( + memory_manager_kernel_2& a, + memory_manager_kernel_2& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MEMORY_MANAGER_KERNEl_2_ + diff --git a/ml/dlib/dlib/memory_manager/memory_manager_kernel_3.h b/ml/dlib/dlib/memory_manager/memory_manager_kernel_3.h new file mode 100644 index 000000000..1f9229772 --- /dev/null +++ b/ml/dlib/dlib/memory_manager/memory_manager_kernel_3.h @@ -0,0 +1,385 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMORY_MANAGER_KERNEl_3_ +#define DLIB_MEMORY_MANAGER_KERNEl_3_ + +#include "../algs.h" +#include "memory_manager_kernel_abstract.h" +#include "../assert.h" +#include +#include "memory_manager_kernel_2.h" +#include "../binary_search_tree/binary_search_tree_kernel_2.h" + + +namespace dlib +{ + + template < + typename T, + unsigned long chunk_size + > + class memory_manager_kernel_3 + { + /*! + INITIAL VALUE + allocations == 0 + next == 0 + first_chunk == 0 + bst_of_arrays == 0 + + REQUIREMENTS ON chunk_size + chunk_size is the number of items of type T we will allocate at a time. so + it must be > 0. + + CONVENTION + This memory manager implementation allocates memory in blocks of chunk_size*sizeof(T) + bytes. All the sizeof(T) subblocks are kept in a linked list of free memory blocks + and are given out whenever an allocation request occurs. Also, memory is not freed + until this object is destructed. + + + + allocations == get_number_of_allocations() + + - if (next != 0) then + - next == the next pointer to return from allocate() + and next == pointer to the first node in a linked list. each node + is one item in the memory pool. + - the last node in the linked list has next set to 0 + - else + - we need to call new to get the next pointer to return from allocate() + + - if (arrays != 0) then + - someone has called allocate_array() + - (*arrays)[size] == an array of size bytes of memory + + - if (first_chunk != 0) then + - first_chunk == the first node in a linked list that contains pointers + to all the chunks we have ever allocated. The last link in the list + has its next pointer set to 0. + !*/ + + union node + { + node* next; + char item[sizeof(T)]; + }; + + struct chunk_node + { + node* chunk; + chunk_node* next; + }; + + + typedef binary_search_tree_kernel_2< + size_t, + char*, + memory_manager_kernel_2 + > bst_of_arrays; + + public: + + typedef T type; + + template + struct rebind { + typedef memory_manager_kernel_3 other; + }; + + + memory_manager_kernel_3( + ) : + allocations(0), + next(0), + first_chunk(0), + arrays(0) + { + // You FOOL! You can't have a zero chunk_size. + COMPILE_TIME_ASSERT(chunk_size > 0); + } + + virtual ~memory_manager_kernel_3( + ) + { + if (allocations == 0) + { + while (first_chunk != 0) + { + chunk_node* temp = first_chunk; + first_chunk = first_chunk->next; + // delete the memory chunk + ::operator delete ( static_cast(temp->chunk)); + // delete the chunk_node + delete temp; + } + } + + if (arrays) + { + arrays->reset(); + while (arrays->move_next()) + { + ::operator delete (arrays->element().value()); + } + delete arrays; + } + } + + unsigned long get_number_of_allocations ( + ) const { return allocations; } + + T* allocate_array ( + unsigned long size + ) + { + size_t block_size = sizeof(T)*size + sizeof(size_t)*2; + + // make sure we have initialized the arrays object. + if (arrays == 0) + { + arrays = new bst_of_arrays; + } + + char* temp; + + // see if we have a suitable block of memory already. + arrays->position_enumerator(block_size); + if (arrays->current_element_valid()) + { + // we have a suitable block of memory already so use that one. + arrays->remove_current_element(block_size,temp); + } + else + { + temp = static_cast(::operator new(block_size)); + } + + reinterpret_cast(temp)[0] = block_size; + reinterpret_cast(temp)[1] = size; + temp += sizeof(size_t)*2; + + try + { + initialize_array(reinterpret_cast(temp),size); + } + catch (...) + { + // something was thrown while we were initializing the array so + // stick our memory block into arrays and rethrow the exception + temp -= sizeof(size_t)*2; + arrays->add(block_size,temp); + throw; + } + + ++allocations; + return reinterpret_cast(temp); + } + + void deallocate_array ( + T* item + ) + { + char* temp = reinterpret_cast(item); + temp -= sizeof(size_t)*2; + size_t block_size = reinterpret_cast(temp)[0]; + size_t size = reinterpret_cast(temp)[1]; + + deinitialize_array(item,size); + + arrays->add(block_size,temp); + + --allocations; + } + + T* allocate ( + ) + { + T* temp; + if (next != 0) + { + temp = reinterpret_cast(next); + node* n = next->next; + + try + { + // construct this new T object with placement new. + new (static_cast(temp))T(); + } + catch (...) + { + next->next = n; + throw; + } + + next = n; + } + else + { + // the linked list is empty so we need to allocate some more memory + node* block = static_cast(::operator new (sizeof(node)*chunk_size)); + + // the first part of this block can be our new object + temp = reinterpret_cast(block); + + try + { + // construct this new T object with placement new. + new (static_cast(temp))T(); + } + catch (...) + { + // construction of the new object threw so delete the block of memory + ::operator delete ( static_cast(block)); + throw; + } + + // allocate a new chunk_node + chunk_node* chunk; + try {chunk = new chunk_node; } + catch (...) + { + temp->~T(); + ::operator delete ( static_cast(block)); + throw; + } + + // add this block into the chunk list + chunk->chunk = block; + chunk->next = first_chunk; + first_chunk = chunk; + + + ++block; + // now add the rest of the block into the linked list of free nodes. + for (unsigned long i = 0; i < chunk_size-1; ++i) + { + block->next = next; + next = block; + ++block; + } + + } + + + ++allocations; + return temp; + } + + void deallocate ( + T* item + ) + { + --allocations; + item->~T(); + + // add this memory into our linked list. + node* temp = reinterpret_cast(item); + temp->next = next; + next = temp; + } + + void swap ( + memory_manager_kernel_3& item + ) + { + exchange(allocations,item.allocations); + exchange(next,item.next); + exchange(first_chunk,item.first_chunk); + exchange(arrays,item.arrays); + } + + private: + + // data members + unsigned long allocations; + node* next; + + chunk_node* first_chunk; + bst_of_arrays* arrays; + + + void initialize_array ( + T* array, + size_t size + ) const + { + size_t i; + try + { + for (i = 0; i < size; ++i) + { + // construct this new T object with placement new. + new (static_cast(array+i))T(); + } + } + catch (...) + { + // Catch any exceptions thrown during the construction process + // and then destruct any T objects that actually were successfully + // constructed. + for (size_t j = 0; j < i; ++j) + { + array[i].~T(); + } + throw; + } + } + + void deinitialize_array ( + T* array, + size_t size + ) const + { + for (size_t i = 0; i < size; ++i) + { + array[i].~T(); + } + } + + // don't do any initialization for the built in types + void initialize_array(unsigned char*, size_t) {} + void deinitialize_array(unsigned char*, size_t) {} + void initialize_array(signed char*, size_t) {} + void deinitialize_array(signed char*, size_t) {} + void initialize_array(char*, size_t) {} + void deinitialize_array(char*, size_t) {} + void initialize_array(int*, size_t) {} + void deinitialize_array(int*, size_t) {} + void initialize_array(unsigned int*, size_t) {} + void deinitialize_array(unsigned int*, size_t) {} + void initialize_array(unsigned long*, size_t) {} + void deinitialize_array(unsigned long*, size_t) {} + void initialize_array(long*, size_t) {} + void deinitialize_array(long*, size_t) {} + void initialize_array(float*, size_t) {} + void deinitialize_array(float*, size_t) {} + void initialize_array(double*, size_t) {} + void deinitialize_array(double*, size_t) {} + void initialize_array(short*, size_t) {} + void deinitialize_array(short*, size_t) {} + void initialize_array(unsigned short*, size_t) {} + void deinitialize_array(unsigned short*, size_t) {} + + + + // restricted functions + memory_manager_kernel_3(memory_manager_kernel_3&); // copy constructor + memory_manager_kernel_3& operator=(memory_manager_kernel_3&); // assignment operator + }; + + template < + typename T, + unsigned long chunk_size + > + inline void swap ( + memory_manager_kernel_3& a, + memory_manager_kernel_3& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MEMORY_MANAGER_KERNEl_3_ + diff --git a/ml/dlib/dlib/memory_manager/memory_manager_kernel_abstract.h b/ml/dlib/dlib/memory_manager/memory_manager_kernel_abstract.h new file mode 100644 index 000000000..8439caf85 --- /dev/null +++ b/ml/dlib/dlib/memory_manager/memory_manager_kernel_abstract.h @@ -0,0 +1,146 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MEMORY_MANAGER_KERNEl_ABSTRACT_ +#ifdef DLIB_MEMORY_MANAGER_KERNEl_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + template < + typename T + > + class memory_manager + { + /*! + REQUIREMENTS ON T + T must have a default constructor. + + INITIAL VALUE + get_number_of_allocations() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents some kind of memory manager or memory pool. + !*/ + + public: + + typedef T type; + + template + struct rebind { + typedef memory_manager other; + }; + + memory_manager( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~memory_manager( + ); + /*! + ensures + - if (get_number_of_allocations() == 0) then + - all resources associated with *this have been released. + - else + - The memory still allocated will not be deleted and this + causes a memory leak. + !*/ + + unsigned long get_number_of_allocations ( + ) const; + /*! + ensures + - returns the current number of outstanding allocations + !*/ + + T* allocate ( + ); + /*! + ensures + - allocates a new object of type T and returns a pointer to it. + - #get_number_of_allocations() == get_number_of_allocations() + 1 + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then the call to allocate() + has no effect on #*this. + !*/ + + void deallocate ( + T* item + ); + /*! + requires + - item == is a pointer to memory that was obtained from a call to + this->allocate(). (i.e. you can't deallocate a pointer you + got from a different memory_manager instance.) + - the memory pointed to by item hasn't already been deallocated. + ensures + - deallocates the object pointed to by item + - #get_number_of_allocations() == get_number_of_allocations() - 1 + !*/ + + T* allocate_array ( + unsigned long size + ); + /*! + ensures + - allocates a new array of size objects of type T and returns a + pointer to it. + - #get_number_of_allocations() == get_number_of_allocations() + 1 + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then the call to allocate() + has no effect on #*this. + !*/ + + void deallocate_array ( + T* item + ); + /*! + requires + - item == is a pointer to memory that was obtained from a call to + this->allocate_array(). (i.e. you can't deallocate a pointer you + got from a different memory_manager instance and it must be an + array.) + - the memory pointed to by item hasn't already been deallocated. + ensures + - deallocates the array pointed to by item + - #get_number_of_allocations() == get_number_of_allocations() - 1 + !*/ + + void swap ( + memory_manager& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + memory_manager(memory_manager&); // copy constructor + memory_manager& operator=(memory_manager&); // assignment operator + }; + + template < + typename T + > + inline void swap ( + memory_manager& a, + memory_manager& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_MEMORY_MANAGER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/memory_manager_global.h b/ml/dlib/dlib/memory_manager_global.h new file mode 100644 index 000000000..05f439fdc --- /dev/null +++ b/ml/dlib/dlib/memory_manager_global.h @@ -0,0 +1,38 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMORY_MANAGER_GLOBAl_ +#define DLIB_MEMORY_MANAGER_GLOBAl_ + +#include "memory_manager_global/memory_manager_global_kernel_1.h" +#include "memory_manager.h" + + + +namespace dlib +{ + + template < + typename T, + typename factory + > + class memory_manager_global + { + memory_manager_global() {} + + + public: + + //----------- kernels --------------- + + // kernel_1 + typedef memory_manager_global_kernel_1 + kernel_1a; + + + + + }; +} + +#endif // DLIB_MEMORY_MANAGER_GLOBAl_ + diff --git a/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_1.h b/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_1.h new file mode 100644 index 000000000..4cf418d24 --- /dev/null +++ b/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_1.h @@ -0,0 +1,113 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMORY_MANAGER_GLOBAl_1_ +#define DLIB_MEMORY_MANAGER_GLOBAl_1_ + +#include "../algs.h" +#include "../memory_manager/memory_manager_kernel_abstract.h" +#include "memory_manager_global_kernel_abstract.h" + +namespace dlib +{ + template < + typename T, + typename factory + > + class memory_manager_global_kernel_1 + { + /*! + INITIAL VALUE + - *global_mm == get_global_memory_manager() + + CONVENTION + - global_mm->get_number_of_allocations() == get_number_of_allocations() + - *global_mm == get_global_memory_manager() + !*/ + + public: + + typedef typename factory::template return_type::type mm_global_type; + + typedef T type; + + template + struct rebind { + typedef memory_manager_global_kernel_1 other; + }; + + memory_manager_global_kernel_1( + ) : + global_mm(factory::template get_instance()) + {} + + virtual ~memory_manager_global_kernel_1( + ) {} + + unsigned long get_number_of_allocations ( + ) const { return global_mm->get_number_of_allocations(); } + + mm_global_type& get_global_memory_manager ( + ) { return *global_mm; } + + T* allocate ( + ) + { + return global_mm->allocate(); + } + + void deallocate ( + T* item + ) + { + global_mm->deallocate(item); + } + + T* allocate_array ( + unsigned long size + ) + { + return global_mm->allocate_array(size); + } + + void deallocate_array ( + T* item + ) + { + global_mm->deallocate_array(item); + } + + void swap ( + memory_manager_global_kernel_1& item + ) + { + exchange(item.global_mm, global_mm); + } + + private: + + mm_global_type* global_mm; + + + // restricted functions + memory_manager_global_kernel_1(memory_manager_global_kernel_1&); // copy constructor + memory_manager_global_kernel_1& operator=(memory_manager_global_kernel_1&); // assignment operator + }; + + template < + typename T, + typename factory + > + inline void swap ( + memory_manager_global_kernel_1& a, + memory_manager_global_kernel_1& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_MEMORY_MANAGER_GLOBAl_1_ + + + diff --git a/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_abstract.h b/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_abstract.h new file mode 100644 index 000000000..a35e38b51 --- /dev/null +++ b/ml/dlib/dlib/memory_manager_global/memory_manager_global_kernel_abstract.h @@ -0,0 +1,181 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MEMORY_MANAGER_GLOBAl_ABSTRACT_ +#ifdef DLIB_MEMORY_MANAGER_GLOBAl_ABSTRACT_ + +#include "../algs.h" +#include "../memory_manager/memory_manager_kernel_abstract.h" + +namespace dlib +{ + template < + typename T, + typename factory + > + class memory_manager_global + { + /*! + REQUIREMENTS ON T + T must have a default constructor. + + REQUIREMENTS ON factory + factory must be defined as follows: + struct factory + { + template + struct return_type { + typedef typename memory_manager_type type; + }; + + template + static typename return_type::type* get_instance ( + ); + / *! + ensures + - returns a pointer to an instance of a memory_manager object + where memory_manager_type implements the interface defined + by dlib/memory_manager/memory_manager_kernel_abstract.h + !* / + }; + + WHAT THIS OBJECT REPRESENTS + This object represents some kind of global memory manager or memory pool. + It is identical to the memory_manager object except that it gets all of + its allocations from a global instance of a memory_manager object which + is provided by the factory object's static member get_instance(). + + THREAD SAFETY + This object is, by itself, threadsafe. However, if you want to use this + object in multiple threads then you must ensure that your factory is + threadsafe. This means its factory::get_instance() method should be + threadsafe and the memory_manager object it returns must also be threadsafe. + !*/ + + public: + + typedef typename factory::template return_type::type mm_global_type; + + typedef T type; + + template + struct rebind { + typedef memory_manager_global other; + }; + + memory_manager_global( + ); + /*! + ensures + - #*this is properly initialized + - #get_global_memory_manager() == the memory manager that was + returned by a call to factory::get_instance() + throws + - std::bad_alloc + !*/ + + virtual ~memory_manager_global( + ); + /*! + ensures + - This destructor has no effect on the global memory_manager + get_global_memory_manager(). + !*/ + + unsigned long get_number_of_allocations ( + ) const; + /*! + ensures + - returns get_global_memory_manager().get_number_of_allocations() + !*/ + + mm_global_type& get_global_memory_manager ( + ); + /*! + ensures + - returns a reference to the global memory manager instance being + used by *this. + !*/ + + T* allocate ( + ); + /*! + ensures + - #get_number_of_allocations() == get_number_of_allocations() + 1 + - returns get_global_memory_manager().allocate() + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then the call to allocate() + has no effect on #*this. + !*/ + + void deallocate ( + T* item + ); + /*! + requires + - item == is a pointer to memory that was obtained from a call to + the get_global_memory_manager() object's allocate() method. + - the memory pointed to by item hasn't already been deallocated. + ensures + - calls get_global_memory_manager().deallocate(item) + - #get_number_of_allocations() == get_number_of_allocations() - 1 + !*/ + + T* allocate_array ( + unsigned long size + ); + /*! + ensures + - #get_number_of_allocations() == get_number_of_allocations() + 1 + - returns get_global_memory_manager().allocate_array() + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then the call to allocate_array() + has no effect on #*this. + !*/ + + void deallocate_array ( + T* item + ); + /*! + requires + - item == is a pointer to memory that was obtained from a call to + the get_global_memory_manager() object's allocate_array() method. + - the memory pointed to by item hasn't already been deallocated. + ensures + - calls get_global_memory_manager().deallocate_array(item) + - #get_number_of_allocations() == get_number_of_allocations() - 1 + !*/ + + void swap ( + memory_manager_global& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + memory_manager_global(memory_manager_global&); // copy constructor + memory_manager_global& operator=(memory_manager_global&); // assignment operator + }; + + template < + typename T, + typename factory + > + inline void swap ( + memory_manager_global& a, + memory_manager_global& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_MEMORY_MANAGER_GLOBAl_ABSTRACT_ + + diff --git a/ml/dlib/dlib/memory_manager_stateless.h b/ml/dlib/dlib/memory_manager_stateless.h new file mode 100644 index 000000000..32bc7e9c0 --- /dev/null +++ b/ml/dlib/dlib/memory_manager_stateless.h @@ -0,0 +1,72 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMORY_MANAGER_STATELESs_ +#define DLIB_MEMORY_MANAGER_STATELESs_ + +#include "memory_manager_stateless/memory_manager_stateless_kernel_1.h" +#include "memory_manager_stateless/memory_manager_stateless_kernel_2.h" +#include "memory_manager.h" + + + +namespace dlib +{ + + template < + typename T + > + class memory_manager_stateless + { + memory_manager_stateless() {} + + + public: + + //----------- kernels --------------- + + // kernel_1 + typedef memory_manager_stateless_kernel_1 + kernel_1a; + + // kernel_2 + typedef memory_manager_stateless_kernel_2::kernel_1a> + kernel_2_1a; + typedef memory_manager_stateless_kernel_2::kernel_1b> + kernel_2_1b; + typedef memory_manager_stateless_kernel_2::kernel_1c> + kernel_2_1c; + typedef memory_manager_stateless_kernel_2::kernel_1d> + kernel_2_1d; + typedef memory_manager_stateless_kernel_2::kernel_1e> + kernel_2_1e; + typedef memory_manager_stateless_kernel_2::kernel_1f> + kernel_2_1f; + + typedef memory_manager_stateless_kernel_2::kernel_2a> + kernel_2_2a; + typedef memory_manager_stateless_kernel_2::kernel_2b> + kernel_2_2b; + typedef memory_manager_stateless_kernel_2::kernel_2c> + kernel_2_2c; + typedef memory_manager_stateless_kernel_2::kernel_2d> + kernel_2_2d; + typedef memory_manager_stateless_kernel_2::kernel_2e> + kernel_2_2e; + + typedef memory_manager_stateless_kernel_2::kernel_3a> + kernel_2_3a; + typedef memory_manager_stateless_kernel_2::kernel_3b> + kernel_2_3b; + typedef memory_manager_stateless_kernel_2::kernel_3c> + kernel_2_3c; + typedef memory_manager_stateless_kernel_2::kernel_3d> + kernel_2_3d; + typedef memory_manager_stateless_kernel_2::kernel_3e> + kernel_2_3e; + + + }; +} + +#endif // DLIB_MEMORY_MANAGER_STATELESs_ + diff --git a/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_1.h b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_1.h new file mode 100644 index 000000000..0d5794d54 --- /dev/null +++ b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_1.h @@ -0,0 +1,86 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMORY_MANAGER_STATELESs_1_ +#define DLIB_MEMORY_MANAGER_STATELESs_1_ + +#include "memory_manager_stateless_kernel_abstract.h" + +namespace dlib +{ + template < + typename T + > + class memory_manager_stateless_kernel_1 + { + /*! + this implementation just calls new and delete directly + !*/ + + public: + + typedef T type; + const static bool is_stateless = true; + + template + struct rebind { + typedef memory_manager_stateless_kernel_1 other; + }; + + memory_manager_stateless_kernel_1( + ) + {} + + virtual ~memory_manager_stateless_kernel_1( + ) {} + + T* allocate ( + ) + { + return new T; + } + + void deallocate ( + T* item + ) + { + delete item; + } + + T* allocate_array ( + unsigned long size + ) + { + return new T[size]; + } + + void deallocate_array ( + T* item + ) + { + delete [] item; + } + + void swap (memory_manager_stateless_kernel_1&) + {} + + private: + + // restricted functions + memory_manager_stateless_kernel_1(memory_manager_stateless_kernel_1&); // copy constructor + memory_manager_stateless_kernel_1& operator=(memory_manager_stateless_kernel_1&); // assignment operator + }; + + template < + typename T + > + inline void swap ( + memory_manager_stateless_kernel_1& a, + memory_manager_stateless_kernel_1& b + ) { a.swap(b); } + +} + +#endif // DLIB_MEMORY_MANAGER_STATELESs_1_ + + + diff --git a/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_2.h b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_2.h new file mode 100644 index 000000000..7c4bf76f5 --- /dev/null +++ b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_2.h @@ -0,0 +1,119 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MEMORY_MANAGER_STATELESs_2_ +#define DLIB_MEMORY_MANAGER_STATELESs_2_ + +#include "../algs.h" +#include "memory_manager_stateless_kernel_abstract.h" +#include "../threads.h" + +namespace dlib +{ + template < + typename T, + typename mem_manager + > + class memory_manager_stateless_kernel_2 + { + /*! + REQUIREMENTS ON mem_manager + mem_manager must be an implementation of memory_manager/memory_manager_kernel_abstract.h + + CONVENTION + this object has a single global instance of mem_manager + !*/ + + public: + + typedef T type; + const static bool is_stateless = true; + + template + struct rebind { + typedef memory_manager_stateless_kernel_2 other; + }; + + memory_manager_stateless_kernel_2( + ) + { + // call this just to make sure the mutex is is initialized before + // multiple threads start calling the member functions. + global_mutex(); + } + + virtual ~memory_manager_stateless_kernel_2( + ) {} + + T* allocate ( + ) + { + auto_mutex M(global_mutex()); + return global_mm().allocate(); + } + + void deallocate ( + T* item + ) + { + auto_mutex M(global_mutex()); + return global_mm().deallocate(item); + } + + T* allocate_array ( + unsigned long size + ) + { + auto_mutex M(global_mutex()); + return global_mm().allocate_array(size); + } + + void deallocate_array ( + T* item + ) + { + auto_mutex M(global_mutex()); + return global_mm().deallocate_array(item); + } + + void swap (memory_manager_stateless_kernel_2&) + {} + + private: + + static mutex& global_mutex ( + ) + { + static mutex lock; + return lock; + } + + typedef typename mem_manager::template rebind::other rebound_mm_type; + + static rebound_mm_type& global_mm ( + ) + { + static rebound_mm_type mm; + return mm; + } + + // restricted functions + memory_manager_stateless_kernel_2(memory_manager_stateless_kernel_2&); // copy constructor + memory_manager_stateless_kernel_2& operator=(memory_manager_stateless_kernel_2&); // assignment operator + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + memory_manager_stateless_kernel_2& a, + memory_manager_stateless_kernel_2& b + ) { a.swap(b); } + +} + +#endif // DLIB_MEMORY_MANAGER_STATELESs_2_ + + + + diff --git a/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h new file mode 100644 index 000000000..2c5b1e73c --- /dev/null +++ b/ml/dlib/dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h @@ -0,0 +1,142 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MEMORY_MANAGER_STATELESs_ABSTRACT_ +#ifdef DLIB_MEMORY_MANAGER_STATELESs_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + template < + typename T + > + class memory_manager_stateless + { + /*! + REQUIREMENTS ON T + T must have a default constructor. + + WHAT THIS OBJECT REPRESENTS + This object represents some kind of stateless memory manager or memory pool. + Stateless means that all instances (instances of the same kernel implementation that is) + of this object are identical and can be used interchangeably. Note that + implementations are allowed to have some shared global state such as a + global memory pool. + + THREAD SAFETY + This object is thread safe. You may access it from any thread at any time + without synchronizing access. + !*/ + + public: + + typedef T type; + const static bool is_stateless = true; + + template + struct rebind { + typedef memory_manager_stateless other; + }; + + memory_manager_stateless( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~memory_manager_stateless( + ); + /*! + ensures + - frees any resources used by *this but has no effect on any shared global + resources used by the implementation. + !*/ + + T* allocate ( + ); + /*! + ensures + - allocates a new object of type T and returns a pointer to it. + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then the call to allocate() + has no effect on #*this. + !*/ + + void deallocate ( + T* item + ); + /*! + requires + - item == is a pointer to memory that was obtained from a call to + allocate(). (i.e. The pointer you are deallocating must have + come from the same implementation of memory_manager_stateless + that is trying to deallocate it.) + - the memory pointed to by item hasn't already been deallocated. + ensures + - deallocates the object pointed to by item + !*/ + + T* allocate_array ( + unsigned long size + ); + /*! + ensures + - allocates a new array of size objects of type T and returns a + pointer to it. + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then the call to allocate() + has no effect on #*this. + !*/ + + void deallocate_array ( + T* item + ); + /*! + requires + - item == is a pointer to memory that was obtained from a call to + allocate_array(). (i.e. The pointer you are deallocating must have + come from the same implementation of memory_manager_stateless + that is trying to deallocate it.) + - the memory pointed to by item hasn't already been deallocated. + ensures + - deallocates the array pointed to by item + !*/ + + void swap ( + memory_manager_stateless& item + ); + /*! + ensures + - this function has no effect on *this or item. It is just provided + to make this object's interface more compatable with the other + memory managers. + !*/ + + private: + + // restricted functions + memory_manager_stateless(memory_manager_stateless&); // copy constructor + memory_manager_stateless& operator=(memory_manager_stateless&); // assignment operator + }; + + template < + typename T + > + inline void swap ( + memory_manager_stateless& a, + memory_manager_stateless& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_MEMORY_MANAGER_STATELESs_ABSTRACT_ + + diff --git a/ml/dlib/dlib/metaprogramming.h b/ml/dlib/dlib/metaprogramming.h new file mode 100644 index 000000000..bc63041ff --- /dev/null +++ b/ml/dlib/dlib/metaprogramming.h @@ -0,0 +1,71 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_METApROGRAMMING_Hh_ +#define DLIB_METApROGRAMMING_Hh_ + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + struct compile_time_integer_list + { + /*! + WHAT THIS OBJECT REPRESENTS + The point of this type is to, as the name suggests, hold a compile time list of integers. + As an example, here is something simple you could do with it: + + template + void print_compile_time_ints ( + compile_time_integer_list + ) + { + print(ints...); + } + + int main() + { + print_compile_time_ints(compile_time_integer_list<0,4,9>()); + } + + Which just calls: print(0,4,9); + + This is a simple example, but this kind of thing is useful in larger and + more complex template metaprogramming constructs. + !*/ + + template + struct push_back + { + typedef compile_time_integer_list type; + }; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct make_compile_time_integer_range + { + /*! + WHAT THIS OBJECT REPRESENTS + This object makes a compile_time_integer_list containing the integers in the range [1,max] inclusive. + For example: + make_compile_time_integer_range<4>::type + evaluates to: + compile_time_integer_list<1,2,3,4> + !*/ + + typedef typename make_compile_time_integer_range::type::template push_back::type type; + }; + // base case + template <> struct make_compile_time_integer_range<0> { typedef compile_time_integer_list<> type; }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_METApROGRAMMING_Hh_ + + diff --git a/ml/dlib/dlib/misc_api.h b/ml/dlib/dlib/misc_api.h new file mode 100644 index 000000000..acf87be7c --- /dev/null +++ b/ml/dlib/dlib/misc_api.h @@ -0,0 +1,20 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifndef DLIB_MISC_APi_ +#define DLIB_MISC_APi_ + +#include "platform.h" + +#ifdef WIN32 +#include "misc_api/windows.h" +#endif + +#ifndef WIN32 +#include "misc_api/posix.h" +#endif + +#include "misc_api/misc_api_shared.h" + +#endif // DLIB_MISC_APi_ + diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_1.cpp b/ml/dlib/dlib/misc_api/misc_api_kernel_1.cpp new file mode 100644 index 000000000..f17d850e1 --- /dev/null +++ b/ml/dlib/dlib/misc_api/misc_api_kernel_1.cpp @@ -0,0 +1,149 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MISC_API_KERNEL_1_CPp_ +#define DLIB_MISC_API_KERNEL_1_CPp_ + +#include "../platform.h" +#include "../threads.h" + +#ifdef WIN32 + +#include "misc_api_kernel_1.h" + +#include "../windows_magic.h" +#include +#include + +// tell visual studio to link to the library needed to call timeGetTime() +#ifdef _MSC_VER +#pragma comment (lib, "winmm.lib") +#endif + +#ifdef __BORLANDC__ +// Apparently the borland compiler doesn't define this. +#define INVALID_FILE_ATTRIBUTES ((DWORD)-1) +#endif + +namespace dlib +{ +// ---------------------------------------------------------------------------------------- + + void sleep ( + unsigned long milliseconds + ) + { + ::Sleep(milliseconds); + } + +// ---------------------------------------------------------------------------------------- + + namespace + { + mutex& cwd_mutex() + { + static mutex m; + return m; + } + // Make sure the above mutex gets constructed before main() + // starts. This way we can be pretty sure it will be constructed + // before any threads could possibly call set_current_dir() or + // get_current_dir() simultaneously. + struct construct_cwd_mutex + { + construct_cwd_mutex() + { + cwd_mutex(); + } + } oaimvweoinvwe; + } + + std::string get_current_dir ( + ) + { + // need to lock a mutex here because getting and setting the + // current working directory is not thread safe on windows. + auto_mutex lock(cwd_mutex()); + char buf[1024]; + if (GetCurrentDirectoryA(sizeof(buf),buf) == 0) + { + return std::string(); + } + else + { + return std::string(buf); + } + } + +// ---------------------------------------------------------------------------------------- + + void set_current_dir ( + const std::string& new_dir + ) + { + // need to lock a mutex here because getting and setting the + // current working directory is not thread safe on windows. + auto_mutex lock(cwd_mutex()); + if (SetCurrentDirectoryA(new_dir.c_str()) == 0) + { + throw set_current_dir_error("Error changing current dir to '" + new_dir + "'"); + } + } + +// ---------------------------------------------------------------------------------------- + + uint64 timestamper:: + get_timestamp ( + ) const + { + unsigned long temp = timeGetTime(); + if (temp >= last_time) + { + last_time = temp; + return (offset + temp)*1000; + } + else + { + last_time = temp; + + // there was overflow since the last call so we need to make the offset + // bigger to account for that + offset += dword_max; + return (offset + temp)*1000; + } + } + +// ---------------------------------------------------------------------------------------- + + void create_directory ( + const std::string& dir + ) + { + if (CreateDirectoryA(dir.c_str(),0) == 0) + { + // an error has occurred + if (GetLastError() == ERROR_ALREADY_EXISTS) + { + // make sure this is actually a directory + DWORD attribs = GetFileAttributesA(dir.c_str()); + if (attribs == INVALID_FILE_ATTRIBUTES || + (attribs&FILE_ATTRIBUTE_DIRECTORY) == 0) + { + // it isn't a directory + throw dir_create_error(dir); + } + } + else + { + throw dir_create_error(dir); + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // WIN32 + +#endif // DLIB_MISC_API_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_1.h b/ml/dlib/dlib/misc_api/misc_api_kernel_1.h new file mode 100644 index 000000000..a500e992a --- /dev/null +++ b/ml/dlib/dlib/misc_api/misc_api_kernel_1.h @@ -0,0 +1,110 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MISC_API_KERNEl_1_ +#define DLIB_MISC_API_KERNEl_1_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + + +#include "misc_api_kernel_abstract.h" +#include "../algs.h" +#include +#include "../uintn.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + void sleep ( + unsigned long milliseconds + ); + +// ---------------------------------------------------------------------------------------- + + std::string get_current_dir ( + ); + +// ---------------------------------------------------------------------------------------- + + class set_current_dir_error : public error + { + public: + set_current_dir_error( + const std::string& a + ): error(a) {} + }; + + void set_current_dir ( + const std::string& new_dir + ); + +// ---------------------------------------------------------------------------------------- + + class timestamper + { + /*! + INITIAL VALUE + - last_time == 0 + - offset == 0 + - dword_max == 2^32 + + CONVENTION + - last_time == the time returned by GetTickCount() the last time we + called it. + - offset == the number of microseconds we should add to the result of + GetTickCount() so that it is correct. + - dword_max == 2^32. + This is the number of values representable by a DWORD. + !*/ + + mutable unsigned long last_time; + mutable uint64 offset; + mutable uint64 dword_max; + + public: + timestamper( + ) : + last_time(0), + offset(0) + { + dword_max = 0xFFFFFFFF; + ++dword_max; + } + + uint64 get_timestamp ( + ) const; + }; + +// ---------------------------------------------------------------------------------------- + + class dir_create_error : public error + { + public: + dir_create_error( + const std::string& dir_name + ) : + error(EDIR_CREATE,"Error creating directory '" + dir_name + "'."), + name(dir_name) + {} + + ~dir_create_error() throw() {} + const std::string name; + }; + + void create_directory ( + const std::string& dir + ); + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "misc_api_kernel_1.cpp" +#endif + +#endif // DLIB_MISC_API_KERNEl_1_ + diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_2.cpp b/ml/dlib/dlib/misc_api/misc_api_kernel_2.cpp new file mode 100644 index 000000000..e6dc772da --- /dev/null +++ b/ml/dlib/dlib/misc_api/misc_api_kernel_2.cpp @@ -0,0 +1,123 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MISC_API_KERNEL_2_CPp_ +#define DLIB_MISC_API_KERNEL_2_CPp_ +#include "../platform.h" + +#ifdef POSIX + +#include +#include "misc_api_kernel_2.h" +#include +#include +#include +#include + +namespace dlib +{ +// ---------------------------------------------------------------------------------------- + + void sleep ( + unsigned long milliseconds + ) + { + // in HP-UX you can only usleep for less than a second +#ifdef HPUX + if (milliseconds >= 1000) + { + ::sleep(milliseconds/1000); + unsigned long remaining = milliseconds%1000; + if (remaining > 0) + ::usleep(remaining*1000); + } + else + { + ::usleep(milliseconds*1000); + } +#else + ::usleep(milliseconds*1000); +#endif + } + +// ---------------------------------------------------------------------------------------- + + std::string get_current_dir ( + ) + { + char buf[1024]; + if (getcwd(buf,sizeof(buf)) == 0) + { + return std::string(); + } + else + { + return std::string(buf); + } + } + +// ---------------------------------------------------------------------------------------- + + void set_current_dir ( + const std::string& new_dir + ) + { + if (chdir(new_dir.c_str())) + { + throw set_current_dir_error("Error changing current dir to '" + new_dir + "'"); + } + } + +// ---------------------------------------------------------------------------------------- + + uint64 timestamper:: + get_timestamp ( + ) const + { + uint64 ts; + timeval curtime; + gettimeofday(&curtime,0); + + ts = curtime.tv_sec; + ts *= 1000000; + ts += curtime.tv_usec; + return ts; + } + +// ---------------------------------------------------------------------------------------- + + void create_directory ( + const std::string& dir + ) + { + if (mkdir(dir.c_str(),0777)) + { + // an error has occurred + if (errno == EEXIST) + { + struct stat buffer; + // now check that this is actually a valid directory + if (::stat(dir.c_str(),&buffer)) + { + // the directory was not found + throw dir_create_error(dir); + } + else if (S_ISDIR(buffer.st_mode) == 0) + { + // It is not a directory + throw dir_create_error(dir); + } + } + else + { + throw dir_create_error(dir); + } + } + } + +// ---------------------------------------------------------------------------------------- +} + +#endif // POSIX + +#endif // DLIB_MISC_API_KERNEL_2_CPp_ + diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_2.h b/ml/dlib/dlib/misc_api/misc_api_kernel_2.h new file mode 100644 index 000000000..86e8a7f5b --- /dev/null +++ b/ml/dlib/dlib/misc_api/misc_api_kernel_2.h @@ -0,0 +1,81 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MISC_API_KERNEl_2_ +#define DLIB_MISC_API_KERNEl_2_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + + +#include "misc_api_kernel_abstract.h" +#include "../algs.h" +#include +#include "../uintn.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + void sleep ( + unsigned long milliseconds + ); + +// ---------------------------------------------------------------------------------------- + + std::string get_current_dir ( + ); + +// ---------------------------------------------------------------------------------------- + + class set_current_dir_error : public error + { + public: + set_current_dir_error( + const std::string& a + ): error(a) {} + }; + + void set_current_dir ( + const std::string& new_dir + ); + +// ---------------------------------------------------------------------------------------- + + class timestamper + { + public: + uint64 get_timestamp ( + ) const; + }; + +// ---------------------------------------------------------------------------------------- + + class dir_create_error : public error + { + public: + dir_create_error( + const std::string& dir_name + ) : + error(EDIR_CREATE,"Error creating directory '" + dir_name + "'."), + name(dir_name) + {} + const std::string& name; + }; + + + void create_directory ( + const std::string& dir + ); + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "misc_api_kernel_2.cpp" +#endif + +#endif // DLIB_MISC_API_KERNEl_2_ + diff --git a/ml/dlib/dlib/misc_api/misc_api_kernel_abstract.h b/ml/dlib/dlib/misc_api/misc_api_kernel_abstract.h new file mode 100644 index 000000000..47749b91b --- /dev/null +++ b/ml/dlib/dlib/misc_api/misc_api_kernel_abstract.h @@ -0,0 +1,159 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MISC_API_KERNEl_ABSTRACT_ +#ifdef DLIB_MISC_API_KERNEl_ABSTRACT_ + +#include +#include "../uintn.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*! + GENERAL COMMENTS + This file just contains miscellaneous api functions + !*/ + +// ---------------------------------------------------------------------------------------- + + void sleep ( + unsigned long milliseconds + ); + /*! + ensures + - causes the calling thread to sleep for the given number of + milliseconds. + !*/ + +// ---------------------------------------------------------------------------------------- + + std::string get_current_dir ( + ); + /*! + ensures + - if (no errors occur) then + - returns the path to the current working directory + - else + - returns "" + throws + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + + class set_current_dir_error : public error; + + void set_current_dir ( + const std::string& new_dir + ); + /*! + ensures + - sets the current working directory to new_dir + throws + - std::bad_alloc + - set_current_dir_error + This exception is thrown if there is an error when attempting + to change the current working directory. + !*/ + +// ---------------------------------------------------------------------------------------- + + class locally_change_current_dir : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a RAII tool for safely switching the current directory + to a new directory and then automatically switching back to the original + directory upon this object's destruction. + !*/ + public: + explicit locally_change_current_dir ( + const std::string& new_dir + ); + /*! + ensures + - calls set_current_dir(new_dir) + - #old_dir() == The value of get_current_dir() prior to switching to new_dir. + !*/ + + const std::string& old_dir ( + ) const; + /*! + ensures + - returns the directory we switch back to once this object is destructed. + !*/ + + ~locally_change_current_dir( + ); + /*! + ensures + - if (revert() hasn't already been called) then + - calls set_current_dir(old_dir()) + !*/ + + void revert ( + ); + /*! + ensures + - if (revert() hasn't already been called) then + - calls set_current_dir(old_dir()) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class dir_create_error : public error { + public: + const std::string name + }; + + void create_directory ( + const std::string& dir + ); + /*! + ensures + - if (dir does not already exist) then + - creates the given directory. + - else + - the call to create_directory() has no effect. + throws + - dir_create_error + This exception is thrown if we were unable to create the requested + directory and it didn't already exist. The type member of the exception + will bet set to EDIR_CREATE and the name member will be set to dir. + !*/ + +// ---------------------------------------------------------------------------------------- + + class timestamper + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a timer that is capable of returning + timestamps. + + Note that the time is measured in microseconds but you are not + guaranteed to have that level of resolution. The actual resolution + is implementation dependent. + !*/ + + public: + uint64 get_timestamp ( + ) const; + /*! + ensures + - returns a timestamp that measures the time in microseconds since an + arbitrary point in the past. Note that this arbitrary point remains + the same between all calls to get_timestamp(). + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MISC_API_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/misc_api/misc_api_shared.h b/ml/dlib/dlib/misc_api/misc_api_shared.h new file mode 100644 index 000000000..6b84dd64c --- /dev/null +++ b/ml/dlib/dlib/misc_api/misc_api_shared.h @@ -0,0 +1,57 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MISC_API_ShARED_Hh_ +#define DLIB_MISC_API_ShARED_Hh_ + +#include +#include "../noncopyable.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class locally_change_current_dir : noncopyable + { + public: + explicit locally_change_current_dir ( + const std::string& new_dir + ) + { + reverted = false; + _old_dir = get_current_dir(); + set_current_dir(new_dir); + } + + ~locally_change_current_dir() + { + revert(); + } + + const std::string& old_dir ( + ) const + { + return _old_dir; + } + + void revert ( + ) + { + if (!reverted) + { + set_current_dir(_old_dir); + reverted = true; + } + } + + private: + bool reverted; + std::string _old_dir; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MISC_API_ShARED_Hh_ + diff --git a/ml/dlib/dlib/misc_api/posix.h b/ml/dlib/dlib/misc_api/posix.h new file mode 100644 index 000000000..1dbb38031 --- /dev/null +++ b/ml/dlib/dlib/misc_api/posix.h @@ -0,0 +1,6 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MISC_API_KERNEl_1_ +#include "misc_api_kernel_2.h" +#endif + diff --git a/ml/dlib/dlib/misc_api/windows.h b/ml/dlib/dlib/misc_api/windows.h new file mode 100644 index 000000000..0817c2aaf --- /dev/null +++ b/ml/dlib/dlib/misc_api/windows.h @@ -0,0 +1,6 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MISC_API_KERNEl_2_ +#include "misc_api_kernel_1.h" +#endif + diff --git a/ml/dlib/dlib/mlp.h b/ml/dlib/dlib/mlp.h new file mode 100644 index 000000000..d287847c8 --- /dev/null +++ b/ml/dlib/dlib/mlp.h @@ -0,0 +1,30 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MLp_ +#define DLIB_MLp_ + +#include "mlp/mlp_kernel_1.h" +#include "mlp/mlp_kernel_c.h" + +namespace dlib +{ + + class mlp + { + mlp() {} + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef mlp_kernel_1 + kernel_1a; + typedef mlp_kernel_c + kernel_1a_c; + + }; +} + +#endif // DLIB_MLp_ + diff --git a/ml/dlib/dlib/mlp/mlp_kernel_1.h b/ml/dlib/dlib/mlp/mlp_kernel_1.h new file mode 100644 index 000000000..d420eea9c --- /dev/null +++ b/ml/dlib/dlib/mlp/mlp_kernel_1.h @@ -0,0 +1,394 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MLp_KERNEL_1_ +#define DLIB_MLp_KERNEL_1_ + +#include "../algs.h" +#include "../serialize.h" +#include "../matrix.h" +#include "../rand.h" +#include "mlp_kernel_abstract.h" +#include +#include + +namespace dlib +{ + + class mlp_kernel_1 : noncopyable + { + /*! + INITIAL VALUE + The network is initially initialized with random weights + + CONVENTION + - input_layer_nodes() == input_nodes + - first_hidden_layer_nodes() == first_hidden_nodes + - second_hidden_layer_nodes() == second_hidden_nodes + - output_layer_nodes() == output_nodes + - get_alpha == alpha + - get_momentum() == momentum + + + - if (second_hidden_nodes == 0) then + - for all i and j: + - w1(i,j) == the weight on the link from node i in the first hidden layer + to input node j + - w3(i,j) == the weight on the link from node i in the output layer + to first hidden layer node j + - for all i and j: + - w1m == the momentum terms for w1 from the previous update + - w3m == the momentum terms for w3 from the previous update + - else + - for all i and j: + - w1(i,j) == the weight on the link from node i in the first hidden layer + to input node j + - w2(i,j) == the weight on the link from node i in the second hidden layer + to first hidden layer node j + - w3(i,j) == the weight on the link from node i in the output layer + to second hidden layer node j + - for all i and j: + - w1m == the momentum terms for w1 from the previous update + - w2m == the momentum terms for w2 from the previous update + - w3m == the momentum terms for w3 from the previous update + !*/ + + public: + + mlp_kernel_1 ( + long nodes_in_input_layer, + long nodes_in_first_hidden_layer, + long nodes_in_second_hidden_layer = 0, + long nodes_in_output_layer = 1, + double alpha_ = 0.1, + double momentum_ = 0.8 + ) : + input_nodes(nodes_in_input_layer), + first_hidden_nodes(nodes_in_first_hidden_layer), + second_hidden_nodes(nodes_in_second_hidden_layer), + output_nodes(nodes_in_output_layer), + alpha(alpha_), + momentum(momentum_) + { + + // seed the random number generator + std::ostringstream sout; + sout << time(0); + rand_nums.set_seed(sout.str()); + + w1.set_size(first_hidden_nodes+1, input_nodes+1); + w1m.set_size(first_hidden_nodes+1, input_nodes+1); + z.set_size(input_nodes+1,1); + + if (second_hidden_nodes != 0) + { + w2.set_size(second_hidden_nodes+1, first_hidden_nodes+1); + w3.set_size(output_nodes, second_hidden_nodes+1); + + w2m.set_size(second_hidden_nodes+1, first_hidden_nodes+1); + w3m.set_size(output_nodes, second_hidden_nodes+1); + } + else + { + w3.set_size(output_nodes, first_hidden_nodes+1); + + w3m.set_size(output_nodes, first_hidden_nodes+1); + } + + reset(); + } + + virtual ~mlp_kernel_1 ( + ) {} + + void reset ( + ) + { + // randomize the weights for the first layer + for (long r = 0; r < w1.nr(); ++r) + for (long c = 0; c < w1.nc(); ++c) + w1(r,c) = rand_nums.get_random_double(); + + // randomize the weights for the second layer + for (long r = 0; r < w2.nr(); ++r) + for (long c = 0; c < w2.nc(); ++c) + w2(r,c) = rand_nums.get_random_double(); + + // randomize the weights for the third layer + for (long r = 0; r < w3.nr(); ++r) + for (long c = 0; c < w3.nc(); ++c) + w3(r,c) = rand_nums.get_random_double(); + + // zero all the momentum terms + set_all_elements(w1m,0); + set_all_elements(w2m,0); + set_all_elements(w3m,0); + } + + long input_layer_nodes ( + ) const { return input_nodes; } + + long first_hidden_layer_nodes ( + ) const { return first_hidden_nodes; } + + long second_hidden_layer_nodes ( + ) const { return second_hidden_nodes; } + + long output_layer_nodes ( + ) const { return output_nodes; } + + double get_alpha ( + ) const { return alpha; } + + double get_momentum ( + ) const { return momentum; } + + template + const matrix operator() ( + const matrix_exp& in + ) const + { + for (long i = 0; i < in.nr(); ++i) + z(i) = in(i); + // insert the bias + z(z.nr()-1) = -1; + + tmp1 = sigmoid(w1*z); + // insert the bias + tmp1(tmp1.nr()-1) = -1; + + if (second_hidden_nodes == 0) + { + return sigmoid(w3*tmp1); + } + else + { + tmp2 = sigmoid(w2*tmp1); + // insert the bias + tmp2(tmp2.nr()-1) = -1; + + return sigmoid(w3*tmp2); + } + } + + template + void train ( + const matrix_exp& example_in, + const matrix_exp& example_out + ) + { + for (long i = 0; i < example_in.nr(); ++i) + z(i) = example_in(i); + // insert the bias + z(z.nr()-1) = -1; + + tmp1 = sigmoid(w1*z); + // insert the bias + tmp1(tmp1.nr()-1) = -1; + + + if (second_hidden_nodes == 0) + { + o = sigmoid(w3*tmp1); + + // now compute the errors and propagate them backwards though the network + e3 = pointwise_multiply(example_out-o, uniform_matrix(output_nodes,1,1.0)-o, o); + e1 = pointwise_multiply(tmp1, uniform_matrix(first_hidden_nodes+1,1,1.0) - tmp1, trans(w3)*e3 ); + + // compute the new weight updates + w3m = alpha * e3*trans(tmp1) + w3m*momentum; + w1m = alpha * e1*trans(z) + w1m*momentum; + + // now update the weights + w1 += w1m; + w3 += w3m; + } + else + { + tmp2 = sigmoid(w2*tmp1); + // insert the bias + tmp2(tmp2.nr()-1) = -1; + + o = sigmoid(w3*tmp2); + + + // now compute the errors and propagate them backwards though the network + e3 = pointwise_multiply(example_out-o, uniform_matrix(output_nodes,1,1.0)-o, o); + e2 = pointwise_multiply(tmp2, uniform_matrix(second_hidden_nodes+1,1,1.0) - tmp2, trans(w3)*e3 ); + e1 = pointwise_multiply(tmp1, uniform_matrix(first_hidden_nodes+1,1,1.0) - tmp1, trans(w2)*e2 ); + + // compute the new weight updates + w3m = alpha * e3*trans(tmp2) + w3m*momentum; + w2m = alpha * e2*trans(tmp1) + w2m*momentum; + w1m = alpha * e1*trans(z) + w1m*momentum; + + // now update the weights + w1 += w1m; + w2 += w2m; + w3 += w3m; + } + } + + template + void train ( + const matrix_exp& example_in, + double example_out + ) + { + matrix e_out; + e_out(0) = example_out; + train(example_in,e_out); + } + + double get_average_change ( + ) const + { + // sum up all the weight changes + double delta = sum(abs(w1m)) + sum(abs(w2m)) + sum(abs(w3m)); + + // divide by the number of weights + delta /= w1m.nr()*w1m.nc() + + w2m.nr()*w2m.nc() + + w3m.nr()*w3m.nc(); + + return delta; + } + + void swap ( + mlp_kernel_1& item + ) + { + exchange(input_nodes, item.input_nodes); + exchange(first_hidden_nodes, item.first_hidden_nodes); + exchange(second_hidden_nodes, item.second_hidden_nodes); + exchange(output_nodes, item.output_nodes); + exchange(alpha, item.alpha); + exchange(momentum, item.momentum); + + w1.swap(item.w1); + w2.swap(item.w2); + w3.swap(item.w3); + + w1m.swap(item.w1m); + w2m.swap(item.w2m); + w3m.swap(item.w3m); + + // even swap the temporary matrices because this may ultimately result in + // fewer calls to new and delete. + e1.swap(item.e1); + e2.swap(item.e2); + e3.swap(item.e3); + z.swap(item.z); + tmp1.swap(item.tmp1); + tmp2.swap(item.tmp2); + o.swap(item.o); + } + + + friend void serialize ( + const mlp_kernel_1& item, + std::ostream& out + ); + + friend void deserialize ( + mlp_kernel_1& item, + std::istream& in + ); + + private: + + long input_nodes; + long first_hidden_nodes; + long second_hidden_nodes; + long output_nodes; + double alpha; + double momentum; + + matrix w1; + matrix w2; + matrix w3; + + matrix w1m; + matrix w2m; + matrix w3m; + + + rand rand_nums; + + // temporary storage + mutable matrix e1, e2, e3; + mutable matrix z, tmp1, tmp2, o; + }; + + inline void swap ( + mlp_kernel_1& a, + mlp_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const mlp_kernel_1& item, + std::ostream& out + ) + { + try + { + serialize(item.input_nodes, out); + serialize(item.first_hidden_nodes, out); + serialize(item.second_hidden_nodes, out); + serialize(item.output_nodes, out); + serialize(item.alpha, out); + serialize(item.momentum, out); + + serialize(item.w1, out); + serialize(item.w2, out); + serialize(item.w3, out); + + serialize(item.w1m, out); + serialize(item.w2m, out); + serialize(item.w3m, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type mlp_kernel_1"); + } + } + + inline void deserialize ( + mlp_kernel_1& item, + std::istream& in + ) + { + try + { + deserialize(item.input_nodes, in); + deserialize(item.first_hidden_nodes, in); + deserialize(item.second_hidden_nodes, in); + deserialize(item.output_nodes, in); + deserialize(item.alpha, in); + deserialize(item.momentum, in); + + deserialize(item.w1, in); + deserialize(item.w2, in); + deserialize(item.w3, in); + + deserialize(item.w1m, in); + deserialize(item.w2m, in); + deserialize(item.w3m, in); + + item.z.set_size(item.input_nodes+1,1); + } + catch (serialization_error& e) + { + // give item a reasonable value since the deserialization failed + mlp_kernel_1(1,1).swap(item); + throw serialization_error(e.info + "\n while deserializing object of type mlp_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MLp_KERNEL_1_ + diff --git a/ml/dlib/dlib/mlp/mlp_kernel_abstract.h b/ml/dlib/dlib/mlp/mlp_kernel_abstract.h new file mode 100644 index 000000000..cbb473a87 --- /dev/null +++ b/ml/dlib/dlib/mlp/mlp_kernel_abstract.h @@ -0,0 +1,225 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MLp_ABSTRACT_ +#ifdef DLIB_MLp_ABSTRACT_ + +#include "../algs.h" +#include "../serialize.h" +#include "../matrix/matrix_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class mlp : noncopyable + { + /*! + INITIAL VALUE + The network is initially initialized with random weights + + WHAT THIS OBJECT REPRESENTS + This object represents a multilayer layer perceptron network that is + trained using the back propagation algorithm. The training algorithm also + incorporates the momentum method. That is, each round of back propagation + training also adds a fraction of the previous update. This fraction + is controlled by the momentum term set in the constructor. + + The activation function used at each node is the sigmoid function. I.e. + sigmoid(x) = 1/(1 + pow(e,-x)). Thus the output of the network is + always in the range [0,1] + !*/ + + public: + + mlp ( + long nodes_in_input_layer, + long nodes_in_first_hidden_layer, + long nodes_in_second_hidden_layer = 0, + long nodes_in_output_layer = 1, + double alpha = 0.1, + double momentum = 0.8 + ); + /*! + requires + - nodes_in_input_layer > 0 + - nodes_in_first_hidden_layer > 0 + - nodes_in_second_hidden_layer >= 0 + - nodes_in_output_layer > 0 + ensures + - #*this is properly initialized + - #input_layer_nodes() == nodes_in_input_layer + - #first_hidden_layer_nodes() == nodes_in_first_hidden_layer + - #second_hidden_layer_nodes() == nodes_in_second_hidden_layer + - #output_layer_nodes() == nodes_in_output_layer + - #get_alpha() == alpha + - #get_momentum() == momentum + throws + - std::bad_alloc + if this is thrown the mlp will be unusable but + will not leak memory + !*/ + + virtual ~mlp ( + ); + /*! + ensures + - all resources associated with #*this have been released + !*/ + + void reset ( + ) const; + /*! + ensures + - reinitialize the network with random weights + !*/ + + long input_layer_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the input layer + !*/ + + long first_hidden_layer_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the first hidden layer. This is + the hidden layer that is directly connected to the input layer. + !*/ + + long second_hidden_layer_nodes ( + ) const; + /*! + ensures + - if (this network has a second hidden layer) then + - returns the number of nodes in the second hidden layer. This is + the hidden layer that is directly connected to the output layer. + - else + - returns 0 + !*/ + + long output_layer_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the output layer + !*/ + + double get_alpha ( + ) const; + /*! + ensures + - returns the back propagation learning rate used by this object. + !*/ + + double get_momentum ( + ) const; + /*! + ensures + - returns the momentum term used by this object during back propagation + training. The momentum is is the fraction of a previous update to + carry forward to the next call to train() + !*/ + + template + const matrix operator() ( + const matrix_exp& in + ) const; + /*! + requires + - in.nr() == input_layer_nodes() + - in.nc() == 1 + - EXP::type == double + ensures + - returns the output of the network when it is given the + input in. The output's elements are always in the range + of 0.0 to 1.0 + !*/ + + template + void train ( + const matrix_exp& example_in, + const matrix_exp& example_out + ); + /*! + requires + - example_in.nr() == input_layer_nodes() + - example_in.nc() == 1 + - example_out.nr() == output_layer_nodes() + - example_out.nc() == 1 + - max(example_out) <= 1.0 && min(example_out) >= 0.0 + - EXP1::type == double + - EXP2::type == double + ensures + - trains the network that the correct output when given example_in + should be example_out. + !*/ + + template + void train ( + const matrix_exp& example_in, + double example_out + ); + /*! + requires + - example_in.nr() == input_layer_nodes() + - example_in.nc() == 1 + - output_layer_nodes() == 1 + - example_out <= 1.0 && example_out >= 0.0 + - EXP::type == double + ensures + - trains the network that the correct output when given example_in + should be example_out. + !*/ + + double get_average_change ( + ) const; + /*! + ensures + - returns the average change in the node weights in the + neural network during the last call to train() + !*/ + + void swap ( + mlp& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + mlp& a, + mlp& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + void serialize ( + const mlp& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + mlp& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MLp_ABSTRACT_ + + diff --git a/ml/dlib/dlib/mlp/mlp_kernel_c.h b/ml/dlib/dlib/mlp/mlp_kernel_c.h new file mode 100644 index 000000000..bd3438fd9 --- /dev/null +++ b/ml/dlib/dlib/mlp/mlp_kernel_c.h @@ -0,0 +1,151 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MLP_KERNEl_C_ +#define DLIB_MLP_KERNEl_C_ + +#include "mlp_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + + template < + typename mlp_base // is an implementation of mlp_kernel_abstract.h + > + class mlp_kernel_c : public mlp_base + { + long verify_constructor_args ( + long nodes_in_input_layer, + long nodes_in_first_hidden_layer, + long nodes_in_second_hidden_layer, + long nodes_in_output_layer + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(nodes_in_input_layer > 0 && + nodes_in_first_hidden_layer > 0 && + nodes_in_second_hidden_layer >= 0 && + nodes_in_output_layer > 0, + "\tconst mlp::constructor()" + << "\n\tinvalid constructor arguments" + << "\n\tnodes_in_input_layer: " << nodes_in_input_layer + << "\n\tnodes_in_first_hidden_layer: " << nodes_in_first_hidden_layer + << "\n\tnodes_in_second_hidden_layer: " << nodes_in_second_hidden_layer + << "\n\tnodes_in_output_layer: " << nodes_in_output_layer + ); + + return nodes_in_input_layer; + } + + public: + + mlp_kernel_c ( + long nodes_in_input_layer, + long nodes_in_first_hidden_layer, + long nodes_in_second_hidden_layer = 0, + long nodes_in_output_layer = 1, + double alpha = 0.1, + double momentum = 0.8 + ) : mlp_base( verify_constructor_args( + nodes_in_input_layer, + nodes_in_input_layer, + nodes_in_second_hidden_layer, + nodes_in_output_layer), + nodes_in_first_hidden_layer, + nodes_in_second_hidden_layer, + nodes_in_output_layer, + alpha, + momentum) + { + } + + template + const matrix operator() ( + const matrix_exp& in + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(in.nr() == this->input_layer_nodes() && + in.nc() == 1, + "\tconst matrix mlp::operator()(matrix_exp)" + << "\n\tthe input matrix dimensions are not correct" + << "\n\tin.nr(): " << in.nr() + << "\n\tin.nc(): " << in.nc() + << "\n\tinput_layer_nodes(): " << this->input_layer_nodes() + << "\n\tthis: " << this + ); + + return mlp_base::operator()(in); + } + + template + void train ( + const matrix_exp& example_in, + const matrix_exp& example_out + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(example_in.nr() == this->input_layer_nodes() && + example_in.nc() == 1 && + example_out.nr() == this->output_layer_nodes() && + example_out.nc() == 1 && + max(example_out) <= 1.0 && min(example_out) >= 0.0, + "\tvoid mlp::train(matrix_exp, matrix_exp)" + << "\n\tthe training example dimensions are not correct" + << "\n\texample_in.nr(): " << example_in.nr() + << "\n\texample_in.nc(): " << example_in.nc() + << "\n\texample_out.nr(): " << example_out.nr() + << "\n\texample_out.nc(): " << example_out.nc() + << "\n\tmax(example_out): " << max(example_out) + << "\n\tmin(example_out): " << min(example_out) + << "\n\tinput_layer_nodes(): " << this->input_layer_nodes() + << "\n\toutput_layer_nodes(): " << this->output_layer_nodes() + << "\n\tthis: " << this + ); + + mlp_base::train(example_in,example_out); + } + + template + void train ( + const matrix_exp& example_in, + double example_out + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(example_in.nr() == this->input_layer_nodes() && + example_in.nc() == 1 && + this->output_layer_nodes() == 1 && + example_out <= 1.0 && example_out >= 0.0, + "\tvoid mlp::train(matrix_exp, double)" + << "\n\tthe training example dimensions are not correct" + << "\n\texample_in.nr(): " << example_in.nr() + << "\n\texample_in.nc(): " << example_in.nc() + << "\n\texample_out: " << example_out + << "\n\tinput_layer_nodes(): " << this->input_layer_nodes() + << "\n\toutput_layer_nodes(): " << this->output_layer_nodes() + << "\n\tthis: " << this + ); + + mlp_base::train(example_in,example_out); + } + + }; + + template < + typename mlp_base + > + inline void swap ( + mlp_kernel_c& a, + mlp_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MLP_KERNEl_C_ + + diff --git a/ml/dlib/dlib/noncopyable.h b/ml/dlib/dlib/noncopyable.h new file mode 100644 index 000000000..20b9866e7 --- /dev/null +++ b/ml/dlib/dlib/noncopyable.h @@ -0,0 +1,32 @@ +// (C) Copyright Beman Dawes 1999-2003. Distributed under the Boost +// Software License, Version 1.0. (See accompanying file +// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// Contributed by Dave Abrahams +// See http://www.boost.org/libs/utility for documentation. + +#ifndef DLIB_BOOST_NONCOPYABLE_HPP_INCLUDED +#define DLIB_BOOST_NONCOPYABLE_HPP_INCLUDED + + +namespace dlib +{ + class noncopyable + { + /*! + This class makes it easier to declare a class as non-copyable. + If you want to make an object that can't be copied just inherit + from this object. + !*/ + + protected: + noncopyable() = default; + ~noncopyable() = default; + private: // emphasize the following members are private + noncopyable(const noncopyable&); + const noncopyable& operator=(const noncopyable&); + + }; +} + +#endif // DLIB_BOOST_NONCOPYABLE_HPP_INCLUDED + diff --git a/ml/dlib/dlib/numeric_constants.h b/ml/dlib/dlib/numeric_constants.h new file mode 100644 index 000000000..05f26319f --- /dev/null +++ b/ml/dlib/dlib/numeric_constants.h @@ -0,0 +1,53 @@ +//Copyright (C) 2013 Steve Taylor (steve98654@gmail.com), Davis E. King +//License: Boost Software License. See LICENSE.txt for full license. +#ifndef DLIB_NUMERIC_CONSTANTs_H_ +#define DLIB_NUMERIC_CONSTANTs_H_ + +namespace dlib +{ + + // pi -- Pi + const double pi = 3.1415926535897932385; + + // e -- Euler's Constant + const double e = 2.7182818284590452354; + + // sqrt_2 -- The square root of 2 + const double sqrt_2 = 1.4142135623730950488; + + // sqrt_3 -- The square root of 3 + const double sqrt_3 = 1.7320508075688772935; + + // log10_2 -- The logarithm base 10 of two + const double log10_2 = 0.30102999566398119521; + + // light_spd -- The speed of light in vacuum in meters per second + const double light_spd = 2.99792458e8; + + // newton_G -- Newton's gravitational constant (in metric units of m^3/(kg*s^2)) + const double newton_G = 6.67384e-11; + + // planck_cst -- Planck's constant (in units of Joules * seconds) + const double planck_cst = 6.62606957e-34; + + // golden_ratio -- The Golden Ratio + const double golden_ratio = 1.6180339887498948482; + + // euler_gamma -- The Euler Mascheroni Constant + const double euler_gamma = 0.5772156649015328606065; + + // catalan -- Catalan's Constant + const double catalan = 0.91596559417721901505; + + // glaisher -- Glaisher Kinkelin constant + const double glaisher = 1.2824271291006226369; + + // khinchin -- Khinchin's constant + const double khinchin = 2.6854520010653064453; + + // apery -- Apery's constant + const double apery = 1.2020569031595942854; +} + +#endif //DLIB_NUMERIC_CONSTANTs_H_ + diff --git a/ml/dlib/dlib/numerical_integration.h b/ml/dlib/dlib/numerical_integration.h new file mode 100644 index 000000000..6676c9892 --- /dev/null +++ b/ml/dlib/dlib/numerical_integration.h @@ -0,0 +1,8 @@ +// Copyright (C) 2013 Steve Taylor (steve98654@gmail.com) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_HEADER +#define DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_HEADER + +#include "numerical_integration/integrate_function_adapt_simpson.h" + +#endif // DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_HEADER diff --git a/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson.h b/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson.h new file mode 100644 index 000000000..c30e21c59 --- /dev/null +++ b/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson.h @@ -0,0 +1,93 @@ +// Copyright (C) 2013 Steve Taylor (steve98654@gmail.com) +// License: Boost Software License See LICENSE.txt for full license +#ifndef DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSONh_ +#define DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSONh_ + +#include "integrate_function_adapt_simpson_abstract.h" +#include "../assert.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + template + T impl_adapt_simp_stop(const funct& f, T a, T b, T fa, T fm, T fb, T is, int cnt) + { + const int maxint = 500; + + T m = (a + b)/2.0; + T h = (b - a)/4.0; + T fml = f(a + h); + T fmr = f(b - h); + T i1 = h/1.5*(fa+4.0*fm+fb); + T i2 = h/3.0*(fa+4.0*(fml+fmr)+2.0*fm+fb); + i1 = (16.0*i2 - i1)/15.0; + T Q = 0; + + if ((std::abs(i1-i2) <= std::abs(is)) || (m <= a) || (b <= m)) + { + Q = i1; + } + else + { + if(cnt < maxint) + { + cnt = cnt + 1; + + Q = impl_adapt_simp_stop(f,a,m,fa,fml,fm,is,cnt) + + impl_adapt_simp_stop(f,m,b,fm,fmr,fb,is,cnt); + } + } + + return Q; + } + +// ---------------------------------------------------------------------------------------- + + template + T integrate_function_adapt_simp( + const funct& f, + T a, + T b, + T tol = 1e-10 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(b > a && tol > 0, + "\t T integrate_function_adapt_simp()" + << "\n\t Invalid arguments were given to this function." + << "\n\t a: " << a + << "\n\t b: " << b + << "\n\t tol: " << tol + ); + + T eps = std::numeric_limits::epsilon(); + if(tol < eps) + { + tol = eps; + } + + const T ba = b-a; + const T fa = f(a); + const T fb = f(b); + const T fm = f((a+b)/2); + + T is = ba/8*(fa+fb+fm+ f(a + 0.9501*ba) + f(a + 0.2311*ba) + f(a + 0.6068*ba) + + f(a + 0.4860*ba) + f(a + 0.8913*ba)); + + if(is == 0) + { + is = b-a; + } + + is = is*tol; + + int cnt = 0; + + return impl_adapt_simp_stop(f, a, b, fa, fm, fb, is, cnt); + } +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSONh_ diff --git a/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson_abstract.h b/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson_abstract.h new file mode 100644 index 000000000..90badcf14 --- /dev/null +++ b/ml/dlib/dlib/numerical_integration/integrate_function_adapt_simpson_abstract.h @@ -0,0 +1,34 @@ +// Copyright (C) 2013 Steve Taylor (steve98654@gmail.com) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_ABSTRACTh_ +#ifdef DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_ABSTRACTh_ + +namespace dlib +{ + + template + T integrate_function_adapt_simp( + const funct& f, + T a, + T b, + T tol = 1e-10 + ); + /*! + requires + - b > a + - tol > 0 + - T should be either float, double, or long double + - The expression f(a) should be a valid expression that evaluates to a T. + I.e. f() should be a real valued function of a single variable. + ensures + - returns an approximation of the integral of f over the domain [a,b] using the + adaptive Simpson method outlined in Gander, W. and W. Gautshi, "Adaptive + Quadrature -- Revisited" BIT, Vol. 40, (2000), pp.84-101 + - tol is a tolerance parameter that determines the overall accuracy of + approximated integral. We suggest a default value of 1e-10 for tol. + !*/ + +} + +#endif // DLIB_INTEGRATE_FUNCTION_ADAPT_SIMPSON_ABSTRACTh_ + diff --git a/ml/dlib/dlib/opencv.h b/ml/dlib/dlib/opencv.h new file mode 100644 index 000000000..c48a6ec5a --- /dev/null +++ b/ml/dlib/dlib/opencv.h @@ -0,0 +1,17 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_OPEnCV_HEADER +#define DLIB_OPEnCV_HEADER + +#include "opencv/cv_image.h" +#include "opencv/to_open_cv.h" + +#endif // DLIB_OPEnCV_HEADER + + + + diff --git a/ml/dlib/dlib/opencv/cv_image.h b/ml/dlib/dlib/opencv/cv_image.h new file mode 100644 index 000000000..5f224d003 --- /dev/null +++ b/ml/dlib/dlib/opencv/cv_image.h @@ -0,0 +1,225 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CvIMAGE_H_ +#define DLIB_CvIMAGE_H_ + +#include +#include +#include "cv_image_abstract.h" +#include "../algs.h" +#include "../pixel.h" +#include "../matrix/matrix_mat.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + + template < + typename pixel_type + > + class cv_image + { + public: + typedef pixel_type type; + typedef default_memory_manager mem_manager_type; + + cv_image (const cv::Mat img) + { + DLIB_CASSERT(img.depth() == cv::DataType::basic_pixel_type>::depth && + img.channels() == pixel_traits::num, + "The pixel type you gave doesn't match pixel used by the open cv Mat object." + << "\n\t img.depth(): " << img.depth() + << "\n\t img.cv::DataType::basic_pixel_type>::depth: " + << cv::DataType::basic_pixel_type>::depth + << "\n\t img.channels(): " << img.channels() + << "\n\t img.pixel_traits::num: " << pixel_traits::num + ); + IplImage temp = img; + init(&temp); + } + + cv_image (const IplImage img) + { + init(&img); + } + + cv_image (const IplImage* img) + { + init(img); + } + + cv_image() : _data(0), _widthStep(0), _nr(0), _nc(0) {} + + size_t size () const { return static_cast(_nr*_nc); } + + inline pixel_type* operator[](const long row ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= row && row < nr(), + "\tpixel_type* cv_image::operator[](row)" + << "\n\t you have asked for an out of bounds row " + << "\n\t row: " << row + << "\n\t nr(): " << nr() + << "\n\t this: " << this + ); + + return reinterpret_cast( _data + _widthStep*row); + } + + inline const pixel_type* operator[](const long row ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= row && row < nr(), + "\tconst pixel_type* cv_image::operator[](row)" + << "\n\t you have asked for an out of bounds row " + << "\n\t row: " << row + << "\n\t nr(): " << nr() + << "\n\t this: " << this + ); + + return reinterpret_cast( _data + _widthStep*row); + } + + inline const pixel_type& operator()(const long row, const long column) const + { + DLIB_ASSERT(0<= column && column < nc(), + "\tcont pixel_type& cv_image::operator()(const long rown const long column)" + << "\n\t you have asked for an out of bounds column " + << "\n\t column: " << column + << "\n\t nc(): " << nc() + << "\n\t this: " << this + ); + + return (*this)[row][column]; + } + + inline pixel_type& operator()(const long row, const long column) + { + DLIB_ASSERT(0<= column && column < nc(), + "\tcont pixel_type& cv_image::operator()(const long rown const long column)" + << "\n\t you have asked for an out of bounds column " + << "\n\t column: " << column + << "\n\t nc(): " << nc() + << "\n\t this: " << this + ); + + return (*this)[row][column]; + } + + long nr() const { return _nr; } + long nc() const { return _nc; } + long width_step() const { return _widthStep; } + + cv_image& operator=( const cv_image& item) + { + _data = item._data; + _widthStep = item._widthStep; + _nr = item._nr; + _nc = item._nc; + return *this; + } + + cv_image& operator=( const IplImage* img) + { + init(img); + return *this; + } + + cv_image& operator=( const IplImage img) + { + init(&img); + return *this; + } + + cv_image& operator=( const cv::Mat img) + { + IplImage temp = img; + init(&temp); + return *this; + } + + private: + + void init (const IplImage* img) + { + DLIB_CASSERT( img->dataOrder == 0, "Only interleaved color channels are supported with cv_image"); + DLIB_CASSERT((img->depth&0xFF)/8*img->nChannels == sizeof(pixel_type), + "The pixel type you gave doesn't match the size of pixel used by the open cv image struct"); + + _data = img->imageData; + _widthStep = img->widthStep; + _nr = img->height; + _nc = img->width; + + } + + char* _data; + long _widthStep; + long _nr; + long _nc; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > > mat ( + const cv_image& m + ) + { + typedef op_array2d_to_mat > op; + return matrix_op(op(m)); + } + +// ---------------------------------------------------------------------------------------- + +// Define the global functions that make cv_image a proper "generic image" according to +// ../image_processing/generic_image.h + template + struct image_traits > + { + typedef T pixel_type; + }; + + template + inline long num_rows( const cv_image& img) { return img.nr(); } + template + inline long num_columns( const cv_image& img) { return img.nc(); } + + template + inline void* image_data( + cv_image& img + ) + { + if (img.size() != 0) + return &img[0][0]; + else + return 0; + } + + template + inline const void* image_data( + const cv_image& img + ) + { + if (img.size() != 0) + return &img[0][0]; + else + return 0; + } + + template + inline long width_step( + const cv_image& img + ) + { + return img.width_step(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CvIMAGE_H_ + diff --git a/ml/dlib/dlib/opencv/cv_image_abstract.h b/ml/dlib/dlib/opencv/cv_image_abstract.h new file mode 100644 index 000000000..6fbc56b43 --- /dev/null +++ b/ml/dlib/dlib/opencv/cv_image_abstract.h @@ -0,0 +1,280 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPENCV_IMAGE_AbSTRACT_H_ +#ifdef DLIB_OPENCV_IMAGE_AbSTRACT_H_ + +#include +#include +#include "../algs.h" +#include "../pixel.h" + +namespace dlib +{ + + template < + typename pixel_type + > + class cv_image + { + /*! + REQUIREMENTS ON pixel_type + pixel_type just needs to be something that matches the pixel memory + layout of whatever OpenCV image you are going to use with this object. + For example, you might use unsigned char or bgr_pixel depending + on what you needed. + + WHAT THIS OBJECT REPRESENTS + This object is meant to be used as a simple wrapper around the OpenCV + IplImage struct or Mat object. Using this class template you can turn + an OpenCV image into something that looks like a normal dlib style + image object. + + So you should be able to use cv_image objects with many of the image + processing functions in dlib as well as the GUI tools for displaying + images on the screen. + + Note that this object does NOT take ownership of the image data you + give to it. This means it is up to you to make sure the OpenCV image + is properly freed at some point. This also means that an instance of + this object can only be used as long as the OpenCV image it references + remains valid, since a cv_image just points to the OpenCV image's + memory directly. + !*/ + + public: + typedef pixel_type type; + typedef default_memory_manager mem_manager_type; + + cv_image ( + const IplImage* img + ); + /*! + requires + - img->dataOrder == 0 + (i.e. Only interleaved color channels are supported with cv_image) + - (img->depth&0xFF)/8*img->nChannels == sizeof(pixel_type) + (i.e. The size of the pixel_type needs to match the size of the pixels + inside the OpenCV image) + ensures + - #nr() == img->height + #nc() == img->width + - using the operator[] on this object you will be able to access the pixels + inside this OpenCV image. + !*/ + + cv_image ( + const IplImage img + ); + /*! + requires + - img.dataOrder == 0 + (i.e. Only interleaved color channels are supported with cv_image) + - (img.depth&0xFF)/8*img.nChannels == sizeof(pixel_type) + (i.e. The size of the pixel_type needs to match the size of the pixels + inside the OpenCV image) + ensures + - #nr() == img.height + #nc() == img.width + - using the operator[] on this object you will be able to access the pixels + inside this OpenCV image. + !*/ + + cv_image ( + const cv::Mat img + ); + /*! + requires + - img.depth() == cv::DataType::basic_pixel_type>::depth + (i.e. The pixel_type template argument needs to match the type of pixel + used inside the OpenCV image) + - img.channels() == pixel_traits::num + (i.e. the number of channels in the pixel_type needs to match the number of + channels in the OpenCV image) + ensures + - #nr() == img.rows + - #nc() == img.cols + - using the operator[] on this object you will be able to access the pixels + inside this OpenCV image. + !*/ + + cv_image( + ); + /*! + ensures + - #nr() == 0 + - #nc() == 0 + !*/ + + ~cv_image ( + ); + /*! + ensures + - This function does nothing. e.g. It doesn't delete the OpenCV + image used by this cv_image object + !*/ + + long nr( + ) const; + /*! + ensures + - returns the number of rows in this image + !*/ + + long nc( + ) const; + /*! + ensures + - returns the number of columns in this image + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns nr()*nc() + (i.e. returns the number of pixels in this image) + !*/ + + inline pixel_type* operator[] ( + const long row + ); + /*! + requires + - 0 <= row < nr() + ensures + - returns a pointer to the first pixel in the given row + of this image + !*/ + + inline const pixel_type* operator[] ( + const long row + ) const; + /*! + requires + - 0 <= row < nr() + ensures + - returns a pointer to the first pixel in the given row + of this image + !*/ + + inline const pixel_type& operator()( + const long row, const long column + ) const + /*! + requires + - 0 <= row < nr() + - 0 <= column < nc() + ensures + - returns a const reference to the pixel at coordinates (row, column) + of this image + !*/ + + inline pixel_type& operator()( + const long row, const long column + ) + /*! + requires + - 0 <= row < nr() + - 0 <= column < nc() + ensures + - returns a reference to the pixel at coordinates (row, column) + of this image + !*/ + + cv_image& operator= ( + const cv_image& item + ); + /*! + ensures + - #*this is an identical copy of item + - returns #*this + !*/ + + cv_image& operator=( + const IplImage* img + ); + /*! + requires + - img->dataOrder == 0 + (i.e. Only interleaved color channels are supported with cv_image) + - (img->depth&0xFF)/8*img->nChannels == sizeof(pixel_type) + (i.e. The size of the pixel_type needs to match the size of the pixels + inside the OpenCV image) + ensures + - #nr() == img->height + #nc() == img->width + - using the operator[] on this object you will be able to access the pixels + inside this OpenCV image. + - returns #*this + !*/ + + cv_image& operator=( + const IplImage img + ); + /*! + requires + - img->dataOrder == 0 + (i.e. Only interleaved color channels are supported with cv_image) + - (img->depth&0xFF)/8*img->nChannels == sizeof(pixel_type) + (i.e. The size of the pixel_type needs to match the size of the pixels + inside the OpenCV image) + ensures + - #nr() == img->height + #nc() == img->width + - using the operator[] on this object you will be able to access the pixels + inside this OpenCV image. + - returns #*this + !*/ + + cv_image& operator=( + const cv::Mat img + ); + /*! + requires + - img.depth() == cv::DataType::basic_pixel_type>::depth + (i.e. The pixel_type template argument needs to match the type of pixel + used inside the OpenCV image) + - img.channels() == pixel_traits::num + (i.e. the number of channels in the pixel_type needs to match the number of + channels in the OpenCV image) + ensures + - #nr() == img.rows + - #nc() == img.cols + - using the operator[] on this object you will be able to access the pixels + inside this OpenCV image. + - returns #*this + !*/ + + long width_step ( + ) const; + /*! + ensures + - returns the size of one row of the image, in bytes. + More precisely, return a number N such that: + (char*)&item[0][0] + N == (char*)&item[1][0]. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp mat ( + const cv_image& img + ); + /*! + ensures + - returns a matrix R such that: + - R.nr() == img.nr() + - R.nc() == img.nc() + - for all valid r and c: + R(r, c) == img[r][c] + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPENCV_IMAGE_AbSTRACT_H_ + diff --git a/ml/dlib/dlib/opencv/to_open_cv.h b/ml/dlib/dlib/opencv/to_open_cv.h new file mode 100644 index 000000000..02b7bf6fc --- /dev/null +++ b/ml/dlib/dlib/opencv/to_open_cv.h @@ -0,0 +1,46 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TO_OPEN_Cv_Hh_ +#define DLIB_TO_OPEN_Cv_Hh_ + +#include +#include "to_open_cv_abstract.h" +#include "../pixel.h" +#include "../matrix/matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type + > + cv::Mat toMat ( + image_type& img + ) + { + if (image_size(img) == 0) + return cv::Mat(); + + typedef typename image_traits::pixel_type type; + typedef typename pixel_traits::basic_pixel_type basic_pixel_type; + if (pixel_traits::num == 1) + { + return cv::Mat(num_rows(img), num_columns(img), cv::DataType::type, image_data(img), width_step(img)); + } + else + { + int depth = sizeof(typename pixel_traits::basic_pixel_type)*8; + int channels = pixel_traits::num; + int thetype = CV_MAKETYPE(depth, channels); + return cv::Mat(num_rows(img), num_columns(img), thetype, image_data(img), width_step(img)); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TO_OPEN_Cv_Hh_ + diff --git a/ml/dlib/dlib/opencv/to_open_cv_abstract.h b/ml/dlib/dlib/opencv/to_open_cv_abstract.h new file mode 100644 index 000000000..43307e02f --- /dev/null +++ b/ml/dlib/dlib/opencv/to_open_cv_abstract.h @@ -0,0 +1,34 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_TO_OPEN_Cv_ABSTRACTh_ +#ifdef DLIB_TO_OPEN_Cv_ABSTRACTh_ + +#include +#include "../pixel.h" + +namespace dlib +{ + template < + typename image_type + > + cv::Mat toMat ( + image_type& img + ); + /*! + requires + - image_type == an image object that implements the interface defined in + dlib/image_processing/generic_image.h or a dlib::matrix object which uses a + row_major_layout. + - pixel_traits is defined for the contents of img. + ensures + - returns an OpenCV Mat object which represents the same image as img. This + is done by setting up the Mat object to point to the same memory as img. + Therefore, the returned Mat object is valid only as long as pointers + to the pixels in img remain valid. + !*/ +} + +#endif // DLIB_TO_OPEN_Cv_ABSTRACTh_ + + + diff --git a/ml/dlib/dlib/optimization.h b/ml/dlib/dlib/optimization.h new file mode 100644 index 000000000..260eacc1c --- /dev/null +++ b/ml/dlib/dlib/optimization.h @@ -0,0 +1,24 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATIOn_HEADER +#define DLIB_OPTIMIZATIOn_HEADER + +#include "optimization/optimization.h" +#include "optimization/optimization_bobyqa.h" +#include "optimization/optimization_solve_qp_using_smo.h" +#include "optimization/optimization_solve_qp2_using_smo.h" +#include "optimization/optimization_solve_qp3_using_smo.h" +#include "optimization/optimization_oca.h" +#include "optimization/optimization_trust_region.h" +#include "optimization/optimization_least_squares.h" +#include "optimization/max_cost_assignment.h" +#include "optimization/max_sum_submatrix.h" +#include "optimization/find_max_factor_graph_nmplp.h" +#include "optimization/find_max_factor_graph_viterbi.h" +#include "optimization/find_max_parse_cky.h" +#include "optimization/isotonic_regression.h" + +#endif // DLIB_OPTIMIZATIOn_HEADER + + + diff --git a/ml/dlib/dlib/optimization/elastic_net.h b/ml/dlib/dlib/optimization/elastic_net.h new file mode 100644 index 000000000..6c4b6d0b4 --- /dev/null +++ b/ml/dlib/dlib/optimization/elastic_net.h @@ -0,0 +1,389 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ElASTIC_NET_Hh_ +#define DLIB_ElASTIC_NET_Hh_ + +#include "../matrix.h" +#include "elastic_net_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class elastic_net + { + public: + + template + explicit elastic_net( + const matrix_exp& XX + ) : eps(1e-5), max_iterations(50000), verbose(false) + { + // make sure requires clause is not broken + DLIB_ASSERT(XX.size() > 0 && + XX.nr() == XX.nc(), + "\t elastic_net::elastic_net(XX)" + << " \n\t XX must be a non-empty square matrix." + << " \n\t XX.nr(): " << XX.nr() + << " \n\t XX.nc(): " << XX.nc() + << " \n\t this: " << this + ); + + + // If the number of columns in X is big and in particular bigger than the number of + // rows then we can get rid of them by doing some SVD magic. Doing this doesn't + // make the final results of anything change but makes all the matrices have + // dimensions that are X.nr() in size, which can be much smaller. + matrix s; + svd3(XX,u,eig_vals,eig_vects); + s = sqrt(eig_vals); + X = eig_vects*diagm(s); + u = eig_vects*inv(diagm(s)); + + + + samples.resize(X.nr()*2); + + for (size_t i = 0; i < samples.size(); ++i) + index.push_back(i); + active_size = index.size(); + + + // setup the training samples used in the SVM optimizer below + for (size_t i = 0; i < samples.size(); ++i) + { + auto& x = samples[i]; + const long idx = i/2; + if (i%2 == 0) + x.label = +1; + else + x.label = -1; + + x.r = idx%X.nr(); + } + } + + template + elastic_net( + const matrix_exp& XX, + const matrix_exp& XY + ) : elastic_net(XX) + { + // make sure requires clause is not broken + DLIB_ASSERT(XX.size() > 0 && + XX.nr() == XX.nc() && + is_col_vector(XY) && + XX.nc() == XY.size() , + "\t elastic_net::elastic_net(XX,XY)" + << " \n\t Invalid inputs were given to this function." + << " \n\t XX.size(): " << XX.size() + << " \n\t is_col_vector(XY): " << is_col_vector(XY) + << " \n\t XX.nr(): " << XX.nr() + << " \n\t XX.nc(): " << XX.nc() + << " \n\t XY.size(): " << XY.size() + << " \n\t this: " << this + ); + + set_xy(XY); + } + + long size ( + ) const { return u.nr(); } + + template + void set_xy( + const matrix_exp& XY + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(XY) && + XY.size() == size(), + "\t void elastic_net::set_y(Y)" + << " \n\t Invalid inputs were given to this function." + << " \n\t is_col_vector(XY): " << is_col_vector(XY) + << " \n\t size(): " << size() + << " \n\t XY.size(): " << XY.size() + << " \n\t this: " << this + ); + + Y = trans(u)*XY; + // We can use the ynorm after it has been projected because the only place Y + // appears in the algorithm is in terms of dot products with w and x vectors. + // But those vectors are always in the span of X and therefore we only see the + // part of the norm of Y that is in the span of X (and hence u since u and X + // have the same span by construction) + ynorm = length_squared(Y); + xdoty = X*Y; + eig_vects_xdoty = trans(eig_vects)*xdoty; + + w.set_size(Y.size()); + // zero out any memory of previous solutions + alpha.assign(X.nr()*2, 0); + } + + bool have_target_values ( + ) const { return Y.size() != 0; } + + void set_epsilon( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void elastic_net::set_epsilon()" + << " \n\t eps_ must be greater than 0" + << " \n\t eps_: " << eps_ + << " \n\t this: " << this + ); + + eps = eps_; + } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + double get_epsilon ( + ) const { return eps; } + + matrix operator() ( + double ridge_lambda, + double lasso_budget = std::numeric_limits::infinity() + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(have_target_values() && + ridge_lambda > 0 && + lasso_budget > 0 , + "\t matrix elastic_net::operator()()" + << " \n\t Invalid inputs were given to this function." + << " \n\t have_target_values(): " << have_target_values() + << " \n\t ridge_lambda: " << ridge_lambda + << " \n\t lasso_budget: " << lasso_budget + << " \n\t this: " << this + ); + + + // First check if lasso_budget is so big that it isn't even active. We do this + // by doing just ridge regression and checking the result. + matrix betas = eig_vects*tmp(inv(diagm(eig_vals + ridge_lambda))*eig_vects_xdoty); + if (sum(abs(betas)) <= lasso_budget) + return betas; + + + // Set w back to 0. We will compute the w corresponding to what is currently + // in alpha layer on. This way w and alpha are always in sync. + w = 0; + wy_mult = 0; + wdoty = 0; + + + // return dot(w,x) + auto dot = [&](const matrix& w, const en_sample2& x) + { + const double xmul = -x.label*(1/lasso_budget); + // Do the base dot product but don't forget to add in the -(1/t)*y part from the svm reduction paper + double val = rowm(X,x.r)*w + xmul*wdoty + wy_mult*xdoty(x.r) + xmul*wy_mult*ynorm; + + return val; + }; + + + // perform w += scale*x; + auto add_to = [&](matrix& w, double scale, const en_sample2& x) + { + const double xmul = -x.label*(1/lasso_budget); + wy_mult += scale*xmul; + wdoty += scale*xdoty(x.r); + w += scale*trans(rowm(X,x.r)); + + }; + + const double Dii = ridge_lambda; + + // setup the training samples used in the SVM optimizer below + for (size_t i = 0; i < samples.size(); ++i) + { + auto& x = samples[i]; + + const double xmul = -x.label*(1/lasso_budget); + x.xdotx = xmul*xmul*ynorm; + for (long c = 0; c < X.nc(); ++c) + x.xdotx += std::pow(X(x.r,c)+xmul*Y(c), 2.0) - std::pow(xmul*Y(c),2.0); + + // compute the correct w given whatever might be in alpha. + if (alpha[i] != 0) + add_to(w, x.label*alpha[i], samples[i]); + } + + + // Now run the optimizer + double PG_max_prev = std::numeric_limits::infinity(); + double PG_min_prev = -std::numeric_limits::infinity(); + + + unsigned int iter; + for (iter = 0; iter < max_iterations; ++iter) + { + // randomly shuffle the indices + for (unsigned long i = 0; i < active_size; ++i) + { + // pick a random index >= i + const long j = i + rnd.get_random_32bit_number()%(active_size-i); + std::swap(index[i], index[j]); + } + + double PG_max = -std::numeric_limits::infinity(); + double PG_min = std::numeric_limits::infinity(); + for (size_t ii = 0; ii < active_size; ++ii) + { + const auto i = index[ii]; + const auto& x = samples[i]; + double G = x.label*dot(w, x) - 1 + Dii*alpha[i]; + + double PG = 0; + if (alpha[i] == 0) + { + if (G > PG_max_prev) + { + // shrink the active set of training examples + --active_size; + std::swap(index[ii], index[active_size]); + --ii; + continue; + } + + if (G < 0) + PG = G; + } + else + { + PG = G; + } + + if (PG > PG_max) + PG_max = PG; + if (PG < PG_min) + PG_min = PG; + + // if PG != 0 + if (std::abs(PG) > 1e-12) + { + const double alpha_old = alpha[i]; + alpha[i] = std::max(alpha[i] - G/(x.xdotx+Dii), (double)0.0); + const double delta = (alpha[i]-alpha_old)*x.label; + add_to(w, delta, x); + } + } + + if (verbose) + { + using namespace std; + cout << "gap: " << PG_max - PG_min << endl; + cout << "active_size: " << active_size << endl; + cout << "iter: " << iter << endl; + cout << endl; + } + + if (PG_max - PG_min <= eps) + { + // stop if we are within eps tolerance and the last iteration + // was over all the samples + if (active_size == index.size()) + break; + + // Turn off shrinking on the next iteration. We will stop if the + // tolerance is still <= eps when shrinking is off. + active_size = index.size(); + PG_max_prev = std::numeric_limits::infinity(); + PG_min_prev = -std::numeric_limits::infinity(); + } + else + { + PG_max_prev = PG_max; + PG_min_prev = PG_min; + if (PG_max_prev <= 0) + PG_max_prev = std::numeric_limits::infinity(); + if (PG_min_prev >= 0) + PG_min_prev = -std::numeric_limits::infinity(); + } + + + // recalculate wdoty every so often to avoid drift. + if (iter%100 == 0) + wdoty = dlib::dot(Y, w); + } + + + betas.set_size(alpha.size()/2); + for (long i = 0; i < betas.size(); ++i) + betas(i) = lasso_budget*(alpha[2*i] - alpha[2*i+1]); + betas /= sum(mat(alpha)); + return betas; + } + + + private: + + struct en_sample2 + { + // X location + long r; + + + double label; + + double xdotx; + }; + + std::vector samples; + std::vector alpha; + double ynorm; + matrix X; + matrix Y; + matrix xdoty; + double wdoty; + double wy_mult; // logically, the real w is what is in the w vector + wy_mult*Y + matrix w; + std::vector index; + unsigned long active_size; + + matrix eig_vects_xdoty; + matrix eig_vals; + matrix eig_vects; + matrix u; + + dlib::rand rnd; + + + double eps; + unsigned long max_iterations; + bool verbose; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ElASTIC_NET_Hh_ + + diff --git a/ml/dlib/dlib/optimization/elastic_net_abstract.h b/ml/dlib/dlib/optimization/elastic_net_abstract.h new file mode 100644 index 000000000..8ae69e37e --- /dev/null +++ b/ml/dlib/dlib/optimization/elastic_net_abstract.h @@ -0,0 +1,190 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ElASTIC_NET_ABSTRACT_Hh_ +#ifdef DLIB_ElASTIC_NET_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class elastic_net + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for solving the following optimization problem: + + min_w: length_squared(X*w - Y) + ridge_lambda*length_squared(w) + such that: sum(abs(w)) <= lasso_budget + + That is, it solves the elastic net optimization problem. This object also + has the special property that you can quickly obtain different solutions + for different settings of ridge_lambda, lasso_budget, and target Y values. + + This is because a large amount of work is precomputed in the constructor. + The solver will also remember the previous solution and will use that to + warm start subsequent invocations. Therefore, you can efficiently get + solutions for a wide range of regularization parameters. + + + The particular algorithm used to solve it is described in the paper: + Zhou, Quan, et al. "A reduction of the elastic net to support vector + machines with an application to gpu computing." arXiv preprint + arXiv:1409.1976 (2014). APA + + And for the SVM solver sub-component we use the algorithm from: + Hsieh, Cho-Jui, et al. "A dual coordinate descent method for large-scale + linear SVM." Proceedings of the 25th international conference on Machine + learning. ACM, 2008. + !*/ + + public: + + template + explicit elastic_net( + const matrix_exp& XX + ); + /*! + requires + - XX.size() != 0 + - XX.nr() == XX.nc() + ensures + - #get_epsilon() == 1e-5 + - #get_max_iterations() == 50000 + - This object will not be verbose unless be_verbose() is called. + - #size() == XX.nc() + - #have_target_values() == false + - We interpret XX as trans(X)*X where X is as defined in the objective + function discussed above in WHAT THIS OBJECT REPRESENTS. + !*/ + + template + elastic_net( + const matrix_exp& XX, + const matrix_exp& XY + ); + /*! + requires + - XX.size() != 0 + - XX.nr() == XX.nc() + - is_col_vector(XY) + - XX.nc() == Y.size() + ensures + - constructs this object by calling the elastic_net(XX) constructor and + then calling this->set_xy(XY). + - #have_target_values() == true + - We interpret XX as trans(X)*X where X is as defined in the objective + function discussed above in WHAT THIS OBJECT REPRESENTS. Similarly, XY + should be trans(X)*Y. + !*/ + + long size ( + ) const; + /*! + ensures + - returns the dimensionality of the data loaded into this object. That is, + how many elements are in the optimal w vector? This function returns + that number. + !*/ + + bool have_target_values ( + ) const; + /*! + ensures + - returns true if set_xy() has been called and false otherwise. + !*/ + + template + void set_xy( + const matrix_exp& XY + ); + /*! + requires + - is_col_vector(Y) + - Y.size() == size() + ensures + - #have_target_values() == true + - Sets the target values of the regression. Note that we expect the given + matrix, XY, to be equal to trans(X)*Y, where X and Y have the definitions + discussed above in WHAT THIS OBJECT REPRESENTS. + !*/ + + void set_epsilon( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when the solver should stop. + Smaller values may result in a more accurate solution but take longer to + execute. + !*/ + + unsigned long get_max_iterations ( + ) const; + /*! + ensures + - returns the maximum number of iterations the optimizer is allowed to run + before it is required to stop and return a result. + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out. + !*/ + + + matrix operator() ( + double ridge_lambda, + double lasso_budget = std::numeric_limits::infinity() + ); + /*! + requires + - have_target_values() == true + - ridge_lambda > 0 + - lasso_budget > 0 + ensures + - Solves the optimization problem described in the WHAT THIS OBJECT + REPRESENTS section above and returns the optimal w. + - The returned vector has size() elements. + - if (lasso_budget == infinity) then + - The lasso constraint is ignored + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ElASTIC_NET_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp.h b/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp.h new file mode 100644 index 000000000..3dd7fd56f --- /dev/null +++ b/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp.h @@ -0,0 +1,337 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_Hh_ +#define DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_Hh_ + +#include "find_max_factor_graph_nmplp_abstract.h" +#include +#include +#include "../matrix.h" +#include "../hash.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + class simple_hash_map + { + public: + + simple_hash_map( + ) : + scan_dist(6) + { + data.resize(5000); + } + + void insert ( + const unsigned long a, + const unsigned long b, + const unsigned long value + ) + /*! + requires + - a != std::numeric_limits::max() + ensures + - #(*this)(a,b) == value + !*/ + { + const uint32 h = murmur_hash3_2(a,b)%(data.size()-scan_dist); + + const unsigned long empty_bucket = std::numeric_limits::max(); + + for (uint32 i = 0; i < scan_dist; ++i) + { + if (data[i+h].key1 == empty_bucket) + { + data[i+h].key1 = a; + data[i+h].key2 = b; + data[i+h].value = value; + return; + } + } + + // if we get this far it means the hash table is filling up. So double its size. + std::vector new_data; + new_data.resize(data.size()*2); + new_data.swap(data); + for (uint32 i = 0; i < new_data.size(); ++i) + { + if (new_data[i].key1 != empty_bucket) + { + insert(new_data[i].key1, new_data[i].key2, new_data[i].value); + } + } + + insert(a,b,value); + } + + unsigned long operator() ( + const unsigned long a, + const unsigned long b + ) const + /*! + requires + - this->insert(a,b,some_value) has been called + ensures + - returns the value stored at key (a,b) + !*/ + { + DLIB_ASSERT(a != b, "An invalid map_problem was given to find_max_factor_graph_nmplp()." + << "\nNode " << a << " is listed as being a neighbor with itself, which is illegal."); + + uint32 h = murmur_hash3_2(a,b)%(data.size()-scan_dist); + + + for (unsigned long i = 0; i < scan_dist; ++i) + { + if (data[h].key1 == a && data[h].key2 == b) + { + return data[h].value; + } + ++h; + } + + + // this should never happen (since this function requires (a,b) to be in the hash table + DLIB_ASSERT(false, "An invalid map_problem was given to find_max_factor_graph_nmplp()." + << "\nThe nodes in the map_problem are inconsistent because node "<::max()) {} + unsigned long key1; + unsigned long key2; + unsigned long value; + }; + + std::vector data; + const unsigned int scan_dist; + }; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_problem + > + void find_max_factor_graph_nmplp ( + const map_problem& prob, + std::vector& map_assignment, + unsigned long max_iter, + double eps + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( eps > 0, + "\t void find_max_factor_graph_nmplp()" + << "\n\t eps must be greater than zero" + << "\n\t eps: " << eps + ); + + /* + This function is an implementation of the NMPLP algorithm introduced in the + following papers: + Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008) + by Amir Globerson and Tommi Jaakkola + + Introduction to dual decomposition for inference (2011) + by David Sontag, Amir Globerson, and Tommi Jaakkola + + In particular, this function implements the star MPLP update equations shown as + equation 1.20 from the paper Introduction to dual decomposition for inference + (the method was called NMPLP in the first paper). It should also be noted that + the original description of the NMPLP in the first paper had an error in the + equations and the second paper contains corrected equations, which is what this + function uses. + */ + + typedef typename map_problem::node_iterator node_iterator; + typedef typename map_problem::neighbor_iterator neighbor_iterator; + + map_assignment.resize(prob.number_of_nodes()); + + + if (prob.number_of_nodes() == 0) + return; + + + std::vector delta_elements; + delta_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3); + + impl::simple_hash_map delta_idx; + + + + // Initialize delta to zero and fill up the hash table with the appropriate values + // so we can index into delta later on. + for (node_iterator i = prob.begin(); i != prob.end(); ++i) + { + const unsigned long id_i = prob.node_id(i); + + for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j) + { + const unsigned long id_j = prob.node_id(j); + delta_idx.insert(id_i, id_j, delta_elements.size()); + + const unsigned long num_states_xj = prob.num_states(j); + for (unsigned long xj = 0; xj < num_states_xj; ++xj) + delta_elements.push_back(0); + } + } + + + std::vector gamma_i; + std::vector > gamma_ji; + std::vector > delta_to_j_no_i; + // These arrays will end up with a length equal to the maximum number of neighbors + // of any node in the graph. So reserve a bigish number of slots so that we are + // very unlikely to need to preform an expensive reallocation during the + // optimization. + gamma_ji.reserve(10000); + delta_to_j_no_i.reserve(10000); + + + double max_change = eps + 1; + // Now do the main body of the optimization. + unsigned long iter; + for (iter = 0; iter < max_iter && max_change > eps; ++iter) + { + max_change = -std::numeric_limits::infinity(); + + for (node_iterator i = prob.begin(); i != prob.end(); ++i) + { + const unsigned long id_i = prob.node_id(i); + const unsigned long num_states_xi = prob.num_states(i); + gamma_i.assign(num_states_xi, 0); + + double num_neighbors = 0; + + unsigned int jcnt = 0; + // first we fill in the gamma vectors + for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j) + { + // Make sure these arrays are big enough to hold all the neighbor + // information. + if (jcnt >= gamma_ji.size()) + { + gamma_ji.resize(gamma_ji.size()+1); + delta_to_j_no_i.resize(delta_to_j_no_i.size()+1); + } + + ++num_neighbors; + const unsigned long id_j = prob.node_id(j); + const unsigned long num_states_xj = prob.num_states(j); + + gamma_ji[jcnt].assign(num_states_xi, -std::numeric_limits::infinity()); + delta_to_j_no_i[jcnt].assign(num_states_xj, 0); + + // compute delta_j^{-i} and store it in delta_to_j_no_i[jcnt] + for (neighbor_iterator k = prob.begin(j); k != prob.end(j); ++k) + { + const unsigned long id_k = prob.node_id(k); + if (id_k==id_i) + continue; + const double* const delta_kj = &delta_elements[delta_idx(id_k,id_j)]; + for (unsigned long xj = 0; xj < num_states_xj; ++xj) + { + delta_to_j_no_i[jcnt][xj] += delta_kj[xj]; + } + } + + // now compute gamma values + for (unsigned long xi = 0; xi < num_states_xi; ++xi) + { + for (unsigned long xj = 0; xj < num_states_xj; ++xj) + { + gamma_ji[jcnt][xi] = std::max(gamma_ji[jcnt][xi], prob.factor_value(i,j,xi,xj) + delta_to_j_no_i[jcnt][xj]); + } + gamma_i[xi] += gamma_ji[jcnt][xi]; + } + ++jcnt; + } + + // now update the delta values + jcnt = 0; + for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j) + { + const unsigned long id_j = prob.node_id(j); + const unsigned long num_states_xj = prob.num_states(j); + + // messages from j to i + double* const delta_ji = &delta_elements[delta_idx(id_j,id_i)]; + + // messages from i to j + double* const delta_ij = &delta_elements[delta_idx(id_i,id_j)]; + + for (unsigned long xj = 0; xj < num_states_xj; ++xj) + { + double best_val = -std::numeric_limits::infinity(); + + for (unsigned long xi = 0; xi < num_states_xi; ++xi) + { + double val = prob.factor_value(i,j,xi,xj) + 2/(num_neighbors+1)*gamma_i[xi] -gamma_ji[jcnt][xi]; + if (val > best_val) + best_val = val; + } + best_val = -0.5*delta_to_j_no_i[jcnt][xj] + 0.5*best_val; + + if (std::abs(delta_ij[xj] - best_val) > max_change) + max_change = std::abs(delta_ij[xj] - best_val); + + delta_ij[xj] = best_val; + } + + for (unsigned long xi = 0; xi < num_states_xi; ++xi) + { + double new_val = -1/(num_neighbors+1)*gamma_i[xi] + gamma_ji[jcnt][xi]; + if (std::abs(delta_ji[xi] - new_val) > max_change) + max_change = std::abs(delta_ji[xi] - new_val); + delta_ji[xi] = new_val; + } + ++jcnt; + } + } + } + + + // now decode the "beliefs" + std::vector b; + for (node_iterator i = prob.begin(); i != prob.end(); ++i) + { + const unsigned long id_i = prob.node_id(i); + b.assign(prob.num_states(i), 0); + + for (neighbor_iterator k = prob.begin(i); k != prob.end(i); ++k) + { + const unsigned long id_k = prob.node_id(k); + + for (unsigned long xi = 0; xi < b.size(); ++xi) + { + const double* const delta_ki = &delta_elements[delta_idx(id_k,id_i)]; + b[xi] += delta_ki[xi]; + } + } + + map_assignment[id_i] = index_of_max(mat(b)); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_Hh_ + diff --git a/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp_abstract.h b/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp_abstract.h new file mode 100644 index 000000000..3dd9aead0 --- /dev/null +++ b/ml/dlib/dlib/optimization/find_max_factor_graph_nmplp_abstract.h @@ -0,0 +1,365 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_ABSTRACT_Hh_ +#ifdef DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_ABSTRACT_Hh_ + +#include + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class map_problem + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a factor graph or graphical model. In + particular, this object defines the interface a MAP problem on + a factor graph must implement if it is to be solved using the + find_max_factor_graph_nmplp() routine defined at the bottom of this file. + + Note that there is no dlib::map_problem object. What you are + looking at here is simply the interface definition for a map problem. + You must implement your own version of this object for the problem + you wish to solve and then pass it to the find_max_factor_graph_nmplp() routine. + + + Note also that a factor graph should not have any nodes which are + neighbors with themselves. Additionally, the graph is undirected. This + mean that if A is a neighbor of B then B must be a neighbor of A for + the map problem to be valid. + + + Finally, note that the "neighbor" relationship between nodes means the + following: Two nodes are neighbors if and only if there is a potential + function (implemented by the factor_value() method) which operates on + the nodes. + !*/ + + public: + + class node_iterator + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple forward iterator for iterating over + the nodes/variables in this factor graph. + + Note that you can't dereference the iterator and + obtain a value. That is, the iterator is opaque to + the user. It is used only as an argument to the other + methods defined in this interface. + !*/ + + public: + node_iterator( + ); + /*! + ensures + - constructs an iterator in an undefined state + !*/ + + node_iterator( + const node_iterator& item + ); + /*! + ensures + - #*this is a copy of item + !*/ + + node_iterator& operator= ( + const node_iterator& item + ); + /*! + ensures + - #*this is a copy of item + - returns #*this + !*/ + + bool operator== ( + const node_iterator& item + ) const; + /*! + ensures + - returns true if *this and item both reference + the same node in the factor graph and false + otherwise. + !*/ + + bool operator!= ( + const node_iterator& item + ) const; + /*! + ensures + - returns false if *this and item both reference + the same node in the factor graph and true + otherwise. + !*/ + + node_iterator& operator++( + ); + /*! + ensures + - advances *this to the next node in the factor graph. + - returns a reference to the updated *this + (i.e. this is the ++object form of the increment operator) + !*/ + }; + + class neighbor_iterator + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple forward iterator for iterating over + the nodes/variables in this factor graph. This version + of the iterator is used for iterating over the neighbors + of another node in the graph. + !*/ + + public: + neighbor_iterator( + ); + /*! + ensures + - constructs an iterator in an undefined state + !*/ + + neighbor_iterator( + const neighbor_iterator& item + ); + /*! + ensures + - #*this is a copy of item + !*/ + + neighbor_iterator& operator= ( + const neighbor_iterator& item + ); + /*! + ensures + - #*this is a copy of item + - returns #*this + !*/ + + bool operator== ( + const neighbor_iterator& item + ) const; + /*! + ensures + - returns true if *this and item both reference + the same node in the factor graph and false + otherwise. + !*/ + + bool operator!= ( + const neighbor_iterator& item + ) const; + /*! + ensures + - returns false if *this and item both reference + the same node in the factor graph and true + otherwise. + !*/ + + neighbor_iterator& operator++( + ); + /*! + ensures + - advances *this to the next node in the factor graph. + - returns a reference to the updated *this + (i.e. this is the ++object form of the increment operator) + !*/ + }; + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the factor graph. Or in other words, + returns the number of variables in the MAP problem. + !*/ + + node_iterator begin( + ) const; + /*! + ensures + - returns an iterator to the first node in the graph. If no such + node exists then returns end(). + !*/ + + node_iterator end( + ) const; + /*! + ensures + - returns an iterator to one past the last node in the graph. + !*/ + + neighbor_iterator begin( + const node_iterator& it + ) const; + /*! + requires + - it == a valid iterator (i.e. it must be in the range [begin(), end())) + ensures + - returns an iterator to the first neighboring node of the node + referenced by it. If no such node exists then returns end(it). + !*/ + + neighbor_iterator begin( + const neighbor_iterator& it + ) const; + /*! + requires + - it == a valid iterator. (i.e. it must be in the range + [begin(i), end(i)) where i is some valid iterator. ) + ensures + - returns an iterator to the first neighboring node of the node + referenced by it. If no such node exists then returns end(it). + !*/ + + neighbor_iterator end( + const node_iterator& it + ) const; + /*! + requires + - it == a valid iterator (i.e. it must be in the range [begin(), end())) + ensures + - returns an iterator to one past the last neighboring node of the node + referenced by it. + !*/ + + neighbor_iterator end( + const neighbor_iterator& it + ) const; + /*! + requires + - it == a valid iterator. (i.e. it must be in the range + [begin(i), end(i)) where i is some valid iterator. ) + ensures + - returns an iterator to one past the last neighboring node of the node + referenced by it. + !*/ + + unsigned long node_id ( + const node_iterator& it + ) const; + /*! + requires + - it == a valid iterator (i.e. it must be in the range [begin(), end())) + ensures + - returns a number ID such that: + - 0 <= ID < number_of_nodes() + - ID == a number which uniquely identifies the node pointed to by it. + !*/ + + unsigned long node_id ( + const neighbor_iterator& it + ) const; + /*! + requires + - it == a valid iterator. (i.e. it must be in the range + [begin(i), end(i)) where i is some valid iterator. ) + ensures + - returns a number ID such that: + - 0 <= ID < number_of_nodes() + - ID == a number which uniquely identifies the node pointed to by it. + !*/ + + unsigned long num_states ( + const node_iterator& it + ) const; + /*! + requires + - it == a valid iterator (i.e. it must be in the range [begin(), end())) + ensures + - returns the number of states attainable by the node/variable referenced by it. + !*/ + + unsigned long num_states ( + const neighbor_iterator& it + ) const; + /*! + requires + - it == a valid iterator. (i.e. it must be in the range + [begin(i), end(i)) where i is some valid iterator. ) + ensures + - returns the number of states attainable by the node/variable referenced by it. + !*/ + + // The next four functions all have the same contract. + double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const; + double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const; + double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const; + double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const; + /*! + requires + - it1 == a valid iterator + - it2 == a valid iterator + - 0 <= s1 < num_states(it1) + - 0 <= s2 < num_states(it2) + - it1 and it2 reference nodes which are neighbors in the factor graph + ensures + - returns the value of the factor/potential function for the given pair of + nodes, defined by it1 and it2, for the case where they take on the values + s1 and s2 respectively. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename map_problem + > + void find_max_factor_graph_nmplp ( + const map_problem& prob, + std::vector& map_assignment, + unsigned long max_iter, + double eps + ); + /*! + requires + - for all valid i: prob.num_states(i) >= 2 + - map_problem == an object with an interface compatible with the map_problem + object defined at the top of this file. + - eps > 0 + ensures + - This function is a tool for approximately solving the given MAP problem in a graphical + model or factor graph with pairwise potential functions. That is, it attempts + to solve a certain kind of optimization problem which can be defined as follows: + maximize: f(X) + where X is a set of integer valued variables and f(X) can be written as the + sum of functions which each involve only two variables from X. In reference + to the prob object, the nodes in prob represent the variables in X and the + functions which are summed are represented by prob.factor_value(). + - #map_assignment == the result of the optimization. + - #map_assignment.size() == prob.number_of_nodes() + - for all valid i: + - #map_assignment[prob.node_id(i)] < prob.num_states(i) + - #map_assignment[prob.node_id(i)] == The approximate MAP assignment for node/variable i. + - eps controls the stopping condition, smaller values of eps lead to more accurate + solutions of the relaxed linear program but may take more iterations. Note that + the algorithm will never execute more than max_iter iterations regardless of + the setting of eps. + - If the graph is tree-structured then this routine always gives the exact solution + to the MAP problem. However, for graphs with cycles, the solution may be approximate. + + + - This function is an implementation of the NMPLP algorithm introduced in the + following papers: + Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008) + by Amir Globerson and Tommi Jaakkola + + Introduction to dual decomposition for inference (2011) + by David Sontag, Amir Globerson, and Tommi Jaakkola + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_MAX_FACTOR_GRAPH_nMPLP_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi.h b/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi.h new file mode 100644 index 000000000..f7cbcb8d9 --- /dev/null +++ b/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi.h @@ -0,0 +1,232 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_Hh_ +#define DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_Hh_ + +#include "find_max_factor_graph_viterbi_abstract.h" +#include +#include "../matrix.h" +#include "../array2d.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + struct viterbi_data + { + viterbi_data() :val(-std::numeric_limits::infinity()), back_index(0) {} + double val; + unsigned long back_index; + }; + + template + inline bool advance_state( + matrix& node_states, + unsigned long num_states + ) + /*! + ensures + - advances node_states to the next state by adding 1 + to node_states(node_states.size()-1) and carrying any + rollover (modulo num_states). Stores the result into #node_states. + - if (#node_states is all zeros) then + - returns false + - else + - returns true + !*/ + { + for (long i = node_states.size()-1; i >= 0; --i) + { + node_states(i) += 1; + if (node_states(i) < num_states) + return true; + + node_states(i) = 0; + } + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_problem + > + void find_max_factor_graph_viterbi ( + const map_problem& prob, + std::vector& map_assignment + ) + { + using namespace dlib::impl; + const unsigned long order = prob.order(); + const unsigned long num_states = prob.num_states(); + + + DLIB_ASSERT(prob.num_states() > 0, + "\t void find_max_factor_graph_viterbi()" + << "\n\t The nodes in a factor graph have to be able to take on more than 0 states." + ); + DLIB_ASSERT(std::pow(num_states,(double)order) < std::numeric_limits::max(), + "\t void find_max_factor_graph_viterbi()" + << "\n\t The order is way too large for this algorithm to handle." + << "\n\t order: " << order + << "\n\t num_states: " << num_states + << "\n\t std::pow(num_states,order): " << std::pow(num_states,(double)order) + << "\n\t std::numeric_limits::max(): " << std::numeric_limits::max() + ); + + if (prob.number_of_nodes() == 0) + { + map_assignment.clear(); + return; + } + + if (order == 0) + { + map_assignment.resize(prob.number_of_nodes()); + for (unsigned long i = 0; i < map_assignment.size(); ++i) + { + matrix node_state; + unsigned long best_state = 0; + double best_val = -std::numeric_limits::infinity(); + for (unsigned long s = 0; s < num_states; ++s) + { + node_state(0) = s; + const double temp = prob.factor_value(i,node_state); + if (temp > best_val) + { + best_val = temp; + best_state = s; + } + } + map_assignment[i] = best_state; + } + return; + } + + + const unsigned long trellis_size = static_cast(std::pow(num_states,(double)order)); + unsigned long init_ring_size = 1; + + array2d trellis; + trellis.set_size(prob.number_of_nodes(), trellis_size); + + + for (unsigned long node = 0; node < prob.number_of_nodes(); ++node) + { + + if (node < order) + { + matrix node_states; + node_states.set_size(std::min(node, order) + 1); + node_states = 0; + + unsigned long idx = 0; + if (node == 0) + { + do + { + trellis[node][idx].val = prob.factor_value(node,node_states); + ++idx; + } while(advance_state(node_states,num_states)); + } + else + { + init_ring_size *= num_states; + do + { + const unsigned long back_index = idx%init_ring_size; + trellis[node][idx].val = prob.factor_value(node,node_states) + trellis[node-1][back_index].val; + trellis[node][idx].back_index = back_index; + ++idx; + } while(advance_state(node_states,num_states)); + + } + } + else if (order == 1) + { + /* + WHAT'S THE DEAL WITH THIS PREPROCESSOR MACRO? + Well, if we can declare the dimensions of node_states as a compile + time constant then this function runs significantly faster. So this macro + is here to let us do that. It just lets us avoid replicating this code + block in the following if statements for different order sizes. + */ +#define DLIB_FMFGV_WORK \ + node_states = 0; \ + unsigned long count = 0; \ + for (unsigned long i = 0; i < trellis_size; ++i) \ + { \ + unsigned long back_index = 0; \ + double best_score = -std::numeric_limits::infinity(); \ + for (unsigned long s = 0; s < num_states; ++s) \ + { \ + const double temp = prob.factor_value(node,node_states) + trellis[node-1][count%trellis_size].val; \ + if (temp > best_score) \ + { \ + best_score = temp; \ + back_index = count%trellis_size; \ + } \ + advance_state(node_states,num_states); \ + ++count; \ + } \ + trellis[node][i].val = best_score; \ + trellis[node][i].back_index = back_index; \ + } + + matrix node_states; + DLIB_FMFGV_WORK + } + else if (order == 2) + { + matrix node_states; + DLIB_FMFGV_WORK + } + else if (order == 3) + { + matrix node_states; + DLIB_FMFGV_WORK + } + else + { + // The general case, here we don't define the size of node_states at compile time. + matrix node_states(order+1); + DLIB_FMFGV_WORK + } + } + + + map_assignment.resize(prob.number_of_nodes()); + // Figure out which state of the last node has the biggest value. + unsigned long back_index = 0; + double best_val = -std::numeric_limits::infinity(); + for (long i = 0; i < trellis.nc(); ++i) + { + if (trellis[trellis.nr()-1][i].val > best_val) + { + best_val = trellis[trellis.nr()-1][i].val; + back_index = i; + } + } + // Follow the back links to find the decoding. + for (long node = map_assignment.size()-1; node >= 0; --node) + { + map_assignment[node] = back_index/init_ring_size; + back_index = trellis[node][back_index].back_index; + if (node < (long)order) + init_ring_size /= num_states; + } + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_Hh_ + + diff --git a/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi_abstract.h b/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi_abstract.h new file mode 100644 index 000000000..c19e4c7eb --- /dev/null +++ b/ml/dlib/dlib/optimization/find_max_factor_graph_viterbi_abstract.h @@ -0,0 +1,131 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_ABSTRACT_Hh_ +#ifdef DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_ABSTRACT_Hh_ + +#include +#include "../matrix.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class map_problem + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a chain-structured factor graph or graphical + model. In particular, this object defines the interface a MAP problem + on a factor graph must implement if it is to be solved using the + find_max_factor_graph_viterbi() routine defined at the bottom of this file. + + Note that there is no dlib::map_problem object. What you are looking + at here is simply the interface definition for a map problem. You must + implement your own version of this object for the problem you wish to + solve and then pass it to the find_max_factor_graph_viterbi() routine. + !*/ + + public: + + unsigned long order ( + ) const; + /*! + ensures + - returns the order of this model. The order has the following interpretation: + This model can represent a high order Markov chain. If order()==1 then map_problem + represents a basic chain-structured graph where nodes only depend on their immediate + neighbors. However, high order Markov models can also be used by setting order() > 1. + !*/ + + unsigned long num_states ( + ) const; + /*! + ensures + - returns the number of states attainable by each variable/node in the graph. + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the factor graph. Or in other words, + returns the number of variables in the MAP problem. + !*/ + + template < + typename EXP + > + double factor_value ( + unsigned long node_id, + const matrix_exp& node_states + ) const; + /*! + requires + - EXP::type == unsigned long + (i.e. node_states contains unsigned longs) + - node_id < number_of_nodes() + - node_states.size() == min(node_id, order()) + 1 + - is_vector(node_states) == true + - max(node_states) < num_states() + ensures + - In a chain-structured graph, each node has a potential function associated with + it. The potential function operates on the variable given by the node as well + as the order() previous variables. Therefore, factor_value() returns the value + of the factor/potential function associated with node node_id where the following + nodes take on the values defined below: + - node_states(0) == the value of the node with ID node_id + - node_states(i) == the value of the node with ID node_id-i + - It is ok for this function to return a value of -std::numeric_limits::infinity(). + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename map_problem + > + void find_max_factor_graph_viterbi ( + const map_problem& prob, + std::vector& map_assignment + ); + /*! + requires + - prob.num_states() > 0 + - std::pow(prob.num_states(), prob.order()) < std::numeric_limits::max() + (i.e. The Viterbi algorithm is exponential in the order of the map problem. So don't + make order too large.) + - map_problem == an object with an interface compatible with the map_problem + object defined at the top of this file. + ensures + - This function is a tool for exactly solving the MAP problem in a chain-structured + graphical model or factor graph. That is, it attempts to solve a certain kind of + optimization problem which can be defined as follows: + - Let X denote a set of prob.number_of_nodes() integer valued variables, each taking + a value in the range [0, prob.num_states()). + - Let X(i) = the ith variable in X. + - Let F(i) = factor_value_i(X(i), X(i-1), ..., X(i-prob.order())) + (This is the value returned by prob.factor_value(i, node_states). Note that + each factor's value function operates on at most prob.order()+1 variables. + Moreover, the variables are adjacent and hence the graph is "chain-structured".) + + Then this function finds the assignments to the X variables which + maximizes: sum over all valid i: F(i) + + - #map_assignment == the result of the optimization. + - #map_assignment.size() == prob.number_of_nodes() + - for all valid i: + - #map_assignment[i] < prob.num_states() + - #map_assignment[i] == The MAP assignment for node/variable i. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_MAX_FACTOR_GRAPH_VITERBi_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/optimization/find_max_parse_cky.h b/ml/dlib/dlib/optimization/find_max_parse_cky.h new file mode 100644 index 000000000..79614792a --- /dev/null +++ b/ml/dlib/dlib/optimization/find_max_parse_cky.h @@ -0,0 +1,414 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_FIND_MAX_PaRSE_CKY_Hh_ +#define DLIB_FIND_MAX_PaRSE_CKY_Hh_ + +#include "find_max_parse_cky_abstract.h" +#include +#include +#include +#include "../serialize.h" +#include "../array2d.h" + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------- + + template + struct constituent + { + unsigned long begin, end, k; + T left_tag; + T right_tag; + }; + + template + void serialize( + const constituent& item, + std::ostream& out + ) + { + serialize(item.begin, out); + serialize(item.end, out); + serialize(item.k, out); + serialize(item.left_tag, out); + serialize(item.right_tag, out); + } + + template + void deserialize( + constituent& item, + std::istream& in + ) + { + deserialize(item.begin, in); + deserialize(item.end, in); + deserialize(item.k, in); + deserialize(item.left_tag, in); + deserialize(item.right_tag, in); + } + +// ----------------------------------------------------------------------------------------- + + const unsigned long END_OF_TREE = 0xFFFFFFFF; + +// ----------------------------------------------------------------------------------------- + + template + struct parse_tree_element + { + constituent c; + T tag; // id for the constituent corresponding to this level of the tree + + unsigned long left; + unsigned long right; + double score; + }; + + template + void serialize ( + const parse_tree_element& item, + std::ostream& out + ) + { + serialize(item.c, out); + serialize(item.tag, out); + serialize(item.left, out); + serialize(item.right, out); + serialize(item.score, out); + } + + template + void deserialize ( + parse_tree_element& item, + std::istream& in + ) + { + deserialize(item.c, in); + deserialize(item.tag, in); + deserialize(item.left, in); + deserialize(item.right, in); + deserialize(item.score, in); + } + +// ----------------------------------------------------------------------------------------- + + namespace impl + { + template + unsigned long fill_parse_tree( + std::vector >& parse_tree, + const T& tag, + const array2d > >& back, + long r, long c + ) + /*! + requires + - back[r][c].size() == 0 || back[r][c].count(tag) != 0 + !*/ + { + // base case of the recursion + if (back[r][c].size() == 0) + { + return END_OF_TREE; + } + + const unsigned long idx = parse_tree.size(); + const parse_tree_element& item = back[r][c].find(tag)->second; + parse_tree.push_back(item); + + const long k = item.c.k; + const unsigned long idx_left = fill_parse_tree(parse_tree, item.c.left_tag, back, r, k-1); + const unsigned long idx_right = fill_parse_tree(parse_tree, item.c.right_tag, back, k, c); + parse_tree[idx].left = idx_left; + parse_tree[idx].right = idx_right; + return idx; + } + } + + template + void find_max_parse_cky ( + const std::vector& sequence, + const production_rule_function& production_rules, + std::vector >& parse_tree + ) + { + parse_tree.clear(); + if (sequence.size() == 0) + return; + + array2d > table(sequence.size(), sequence.size()); + array2d > > back(sequence.size(), sequence.size()); + typedef typename std::map::iterator itr; + typedef typename std::map >::iterator itr_b; + + for (long r = 0; r < table.nr(); ++r) + table[r][r][sequence[r]] = 0; + + std::vector > possible_tags; + + for (long r = table.nr()-2; r >= 0; --r) + { + for (long c = r+1; c < table.nc(); ++c) + { + for (long k = r; k < c; ++k) + { + for (itr i = table[k+1][c].begin(); i != table[k+1][c].end(); ++i) + { + for (itr j = table[r][k].begin(); j != table[r][k].end(); ++j) + { + constituent con; + con.begin = r; + con.end = c+1; + con.k = k+1; + con.left_tag = j->first; + con.right_tag = i->first; + possible_tags.clear(); + production_rules(sequence, con, possible_tags); + for (unsigned long m = 0; m < possible_tags.size(); ++m) + { + const double score = possible_tags[m].second + i->second + j->second; + itr match = table[r][c].find(possible_tags[m].first); + if (match == table[r][c].end() || score > match->second) + { + table[r][c][possible_tags[m].first] = score; + parse_tree_element item; + item.c = con; + item.score = score; + item.tag = possible_tags[m].first; + item.left = END_OF_TREE; + item.right = END_OF_TREE; + back[r][c][possible_tags[m].first] = item; + } + } + } + } + } + } + } + + + // now use back pointers to build the parse trees + const long r = 0; + const long c = back.nc()-1; + if (back[r][c].size() != 0) + { + + // find the max scoring element in back[r][c] + itr_b max_i = back[r][c].begin(); + itr_b i = max_i; + ++i; + for (; i != back[r][c].end(); ++i) + { + if (i->second.score > max_i->second.score) + max_i = i; + } + + parse_tree.reserve(c); + impl::fill_parse_tree(parse_tree, max_i->second.tag, back, r, c); + } + } + +// ----------------------------------------------------------------------------------------- + + class parse_tree_to_string_error : public error + { + public: + parse_tree_to_string_error(const std::string& str): error(str) {} + }; + + namespace impl + { + template + typename enable_if_c::type conditional_print( + const T& item, + std::ostream& out + ) { out << item << " "; } + + template + typename disable_if_c::type conditional_print( + const T& , + std::ostream& + ) { } + + template + void print_parse_tree_helper ( + const std::vector >& tree, + const std::vector& words, + unsigned long i, + const T& tag_to_skip, + std::ostream& out + ) + { + if (!skip_tag || tree[i].tag != tag_to_skip) + out << "["; + + bool left_recurse = false; + + // Only print if we are supposed to. Doing it this funny way avoids compiler + // errors in parse_tree_to_string() for the case where tag isn't + // printable. + if (!skip_tag || tree[i].tag != tag_to_skip) + conditional_print(tree[i].tag, out); + + if (tree[i].left < tree.size()) + { + left_recurse = true; + print_parse_tree_helper(tree, words, tree[i].left, tag_to_skip, out); + } + else + { + if ((tree[i].c.begin) < words.size()) + { + out << words[tree[i].c.begin] << " "; + } + else + { + std::ostringstream sout; + sout << "Parse tree refers to element " << tree[i].c.begin + << " of sequence which is only of size " << words.size() << "."; + throw parse_tree_to_string_error(sout.str()); + } + } + + if (left_recurse == true) + out << " "; + + if (tree[i].right < tree.size()) + { + print_parse_tree_helper(tree, words, tree[i].right, tag_to_skip, out); + } + else + { + if (tree[i].c.k < words.size()) + { + out << words[tree[i].c.k]; + } + else + { + std::ostringstream sout; + sout << "Parse tree refers to element " << tree[i].c.k + << " of sequence which is only of size " << words.size() << "."; + throw parse_tree_to_string_error(sout.str()); + } + } + + + if (!skip_tag || tree[i].tag != tag_to_skip) + out << "]"; + } + } + +// ----------------------------------------------------------------------------------------- + + template + std::string parse_tree_to_string ( + const std::vector >& tree, + const std::vector& words, + const unsigned long root_idx = 0 + ) + { + if (root_idx >= tree.size()) + return ""; + + std::ostringstream sout; + impl::print_parse_tree_helper(tree, words, root_idx, tree[root_idx].tag, sout); + return sout.str(); + } + +// ----------------------------------------------------------------------------------------- + + template + std::string parse_tree_to_string_tagged ( + const std::vector >& tree, + const std::vector& words, + const unsigned long root_idx = 0 + ) + { + if (root_idx >= tree.size()) + return ""; + + std::ostringstream sout; + impl::print_parse_tree_helper(tree, words, root_idx, tree[root_idx].tag, sout); + return sout.str(); + } + +// ----------------------------------------------------------------------------------------- + + template + std::string parse_trees_to_string ( + const std::vector >& tree, + const std::vector& words, + const T& tag_to_skip + ) + { + if (tree.size() == 0) + return ""; + + std::ostringstream sout; + impl::print_parse_tree_helper(tree, words, 0, tag_to_skip, sout); + return sout.str(); + } + +// ----------------------------------------------------------------------------------------- + + template + std::string parse_trees_to_string_tagged ( + const std::vector >& tree, + const std::vector& words, + const T& tag_to_skip + ) + { + if (tree.size() == 0) + return ""; + + std::ostringstream sout; + impl::print_parse_tree_helper(tree, words, 0, tag_to_skip, sout); + return sout.str(); + } + +// ----------------------------------------------------------------------------------------- + + namespace impl + { + template + void helper_find_trees_without_tag ( + const std::vector >& tree, + const T& tag, + std::vector& tree_roots, + unsigned long idx + ) + { + if (idx < tree.size()) + { + if (tree[idx].tag != tag) + { + tree_roots.push_back(idx); + } + else + { + helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].left); + helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].right); + } + } + } + } + + template + void find_trees_not_rooted_with_tag ( + const std::vector >& tree, + const T& tag, + std::vector& tree_roots + ) + { + tree_roots.clear(); + impl::helper_find_trees_without_tag(tree, tag, tree_roots, 0); + } + +// ----------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_MAX_PaRSE_CKY_Hh_ + diff --git a/ml/dlib/dlib/optimization/find_max_parse_cky_abstract.h b/ml/dlib/dlib/optimization/find_max_parse_cky_abstract.h new file mode 100644 index 000000000..52ffc787b --- /dev/null +++ b/ml/dlib/dlib/optimization/find_max_parse_cky_abstract.h @@ -0,0 +1,388 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_FIND_MAX_PARsE_CKY_ABSTRACT_Hh_ +#ifdef DLIB_FIND_MAX_PARsE_CKY_ABSTRACT_Hh_ + +#include +#include +#include "../algs.h" +#include "../serialize.h" + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------- + + template < + typename T + > + struct constituent + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents the linguistic idea of a constituent, that is, a + group of words that functions as a single unit. In particular, it + represents a combination of two constituents into a new constituent. + + Additionally, a constituent object represents a range of words relative to + some std::vector of words. The range is from [begin, end) (i.e. including + begin but not including end, so using the normal C++ iterator notation). + Moreover, a constituent is always composed of two parts, each having a tag. + Therefore, the left part is composed of the words in the range [begin,k) + and has tag left_tag while the right part of the constituent contains the + words in the range [k,end) and has the tag right_tag. + + The tags are user defined objects of type T. In general, they are used to + represent syntactic categories such as noun phrase, verb phrase, etc. + !*/ + + unsigned long begin, end, k; + T left_tag; + T right_tag; + }; + + template < + typename T + > + void serialize( + const constituent& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename T + > + void deserialize( + constituent& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ----------------------------------------------------------------------------------------- + + /*!A END_OF_TREE is used to indicate that parse_tree_element::left or + parse_tree_element::right doesn't point to another subtree. + !*/ + const unsigned long END_OF_TREE = 0xFFFFFFFF; + +// ----------------------------------------------------------------------------------------- + + template < + typename T + > + struct parse_tree_element + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is used to represent a node in a binary parse tree. An entire + parse tree is represented by a std::vector of parse_tree_element objects. + We follow the convention that the first element of this vector is always + the root of the entire tree. + + The fields of this object have the following interpretations: + - c == the constituent spanned by this node in the parse tree. + Therefore, the node spans the words in the range [c.begin, c.end). + - tag == the syntactic category of this node in the parse tree. + - score == the score or log likelihood for this parse tree. In + general, this is the sum of scores of all the production rules used + to build the tree rooted at the current node. + - let PT denote the vector of parse_tree_elements that defines an + entire parse tree. Then we have: + - if (left != END_OF_TREE) then + - PT[left] == the left sub-tree of the current node. + - PT[left] spans the words [c.begin, c.k) + - PT[left].tag == c.left_tag + - else + - there is no left sub-tree + + - if (right != END_OF_TREE) then + - PT[right] == the right sub-tree of the current node. + - PT[right] spans the words [c.k, c.end) + - PT[right].tag == c.right_tag + - else + - there is no right sub-tree + !*/ + + constituent c; + T tag; + double score; + + unsigned long left; + unsigned long right; + }; + + template < + typename T + > + void serialize ( + const parse_tree_element& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename T + > + void deserialize ( + parse_tree_element& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ----------------------------------------------------------------------------------------- +// ----------------------------------------------------------------------------------------- + + void example_production_rule_function ( + const std::vector& words, + const constituent& c, + std::vector >& possible_tags + ) + /*! + requires + - 0 <= c.begin < c.k < c.end <= words.size() + - possible_tags.size() == 0 + ensures + - Finds all the syntactic categories that can be used to label c and puts those + categories, along with their scores, into possible_tags. Or in other words, + this function determines which production rules can be used to turn the left + and right sub-constituents in c into a single constituent. The contents of c + have the following interpretations: + - The left sub-constituent has syntactic category c.left_tag + - for all i such that c.begin <= i < c.k: + - words[i] is part of the left sub-constituent. + - The right sub-constituent has syntactic category c.right_tag + - for all i such that c.k <= i < c.end: + - words[i] is part of the right sub-constituent. + + - Note that example_production_rule_function() is not a real function. It is + here just to show you how to define production rule producing functions for + use with the find_max_parse_cky() routine defined below. + !*/ + + template < + typename T, + typename production_rule_function + > + void find_max_parse_cky ( + const std::vector& words, + const production_rule_function& production_rules, + std::vector >& parse_tree + ); + /*! + requires + - production_rule_function == a function or function object with the same + interface as example_production_rule_function defined above. + - It must be possible to store T objects in a std::map. + ensures + - Uses the CKY algorithm to find the most probable/highest scoring binary parse + tree of the given vector of words. + - if (#parse_tree.size() == 0) then + - There is no parse tree, using the given production_rules, that can cover + the given word sequence. + - else + - #parse_tree == the highest scoring parse tree that covers all the + elements of words. + - #parse_tree[0] == the root node of the parse tree. + - #parse_tree[0].score == the score of the parse tree. This is the sum of + the scores of all production rules used to construct the tree. + - #parse_tree[0].begin == 0 + - #parse_tree[0].end == words.size() + - This function uses production_rules() to find out what the allowed production + rules are. That is, production_rules() defines all properties of the grammar + used by find_max_parse_cky(). + !*/ + +// ----------------------------------------------------------------------------------------- +// ----------------------------------------------------------------------------------------- + + class parse_tree_to_string_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown by parse_tree_to_string() and + parse_tree_to_string_tagged() if the inputs are discovered to be invalid. + !*/ + }; + +// ----------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + std::string parse_tree_to_string ( + const std::vector >& tree, + const std::vector& words, + const unsigned long root_idx = 0 + ); + /*! + requires + - It must be possible to print U objects to an ostream using operator<< + (typically, U would be something like std::string) + ensures + - Interprets tree as a parse tree defined over the given sequence of words. + - returns a bracketed string that represents the parse tree over the words. + For example, suppose the following parse tree is input: + + /\ + / \ + /\ \ + / \ \ + the dog ran + + Then the output would be the string "[[the dog] ran]" + - Only the sub-tree rooted at tree[root_idx] will be output. If root_idx >= + tree.size() then the empty string is returned. + throws + - parse_tree_to_string_error + This exception is thrown if an invalid tree is detected. This might happen + if the tree refers to elements of words that don't exist because words is + shorted than it is supposed to be. + !*/ + +// ----------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + std::string parse_tree_to_string_tagged ( + const std::vector >& tree, + const std::vector& words, + const unsigned long root_idx = 0 + ); + /*! + requires + - It must be possible to print T objects to an ostream using operator<< + - It must be possible to print U objects to an ostream using operator<< + (typically, U would be something like std::string) + ensures + - This function does the same thing as parse_tree_to_string() except that it + also includes the parse_tree_element::tag object in the output. Therefore, + the tag of each bracket will be included as the first token inside the + bracket. For example, suppose the following parse tree is input (where tags + are shown at the vertices): + + S + /\ + NP \ + /\ \ + / \ \ + the dog ran + + Then the output would be the string "[S [NP the dog] ran]" + - Only the sub-tree rooted at tree[root_idx] will be output. If root_idx >= + tree.size() then the empty string is returned. + throws + - parse_tree_to_string_error + This exception is thrown if an invalid tree is detected. This might happen + if the tree refers to elements of words that don't exist because words is + shorted than it is supposed to be. + !*/ + +// ----------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + std::string parse_trees_to_string ( + const std::vector >& tree, + const std::vector& words, + const T& tag_to_skip + ); + /*! + requires + - It must be possible to print U objects to an ostream using operator<< + (typically, U would be something like std::string) + ensures + - This function behaves just like parse_tree_to_string() except that it will + not print the brackets (i.e. []) for the top most parts of the tree which + have tags equal to tag_to_skip. It will however print all the words. + Therefore, this function only includes brackets on the subtrees which begin + with a tag other than tag_to_skip. + throws + - parse_tree_to_string_error + This exception is thrown if an invalid tree is detected. This might happen + if the tree refers to elements of words that don't exist because words is + shorted than it is supposed to be. + !*/ + +// ----------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + std::string parse_trees_to_string_tagged ( + const std::vector >& tree, + const std::vector& words, + const T& tag_to_skip + ); + /*! + requires + - It must be possible to print T objects to an ostream using operator<< + - It must be possible to print U objects to an ostream using operator<< + (typically, U would be something like std::string) + ensures + - This function behaves just like parse_tree_to_string_tagged() except that it + will not print the brackets (i.e. []) for the top most parts of the tree + which have tags equal to tag_to_skip. It will however print all the words. + Therefore, this function only includes brackets on the subtrees which begin + with a tag other than tag_to_skip. + throws + - parse_tree_to_string_error + This exception is thrown if an invalid tree is detected. This might happen + if the tree refers to elements of words that don't exist because words is + shorted than it is supposed to be. + !*/ + +// ----------------------------------------------------------------------------------------- + + template < + typename T + > + void find_trees_not_rooted_with_tag ( + const std::vector >& tree, + const T& tag, + std::vector& tree_roots + ); + /*! + requires + - objects of type T must be comparable using operator== + ensures + - Finds all the largest non-overlapping trees in tree that are not rooted with + the given tag. + - find_trees_not_rooted_with_tag() is useful when you want to cut a parse tree + into a bunch of sub-trees and you know that the top level of the tree is all + composed of the same kind of tag. So if you want to just "slice off" the top + of the tree where this tag lives then this function is useful for doing that. + - #tree_roots.size() == the number of sub-trees found. + - for all valid i: + - tree[#tree_roots[i]].tag != tag + - To make the operation of this function clearer, here are a few examples of + what it will do: + - if (tree[0].tag != tag) then + - #tree_roots.size() == 0 + - #tree_roots[0] == 0 + - else if (tree[0].tag == tag but its immediate children's tags are not equal to tag) then + - #tree_roots.size() == 2 + - #tree_roots[0] == tree[0].left + - #tree_roots[1] == tree[0].right + !*/ + +// ----------------------------------------------------------------------------------------- + +} + +#endif // DLIB_FIND_MAX_PARsE_CKY_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/optimization/find_optimal_parameters.h b/ml/dlib/dlib/optimization/find_optimal_parameters.h new file mode 100644 index 000000000..0884778c9 --- /dev/null +++ b/ml/dlib/dlib/optimization/find_optimal_parameters.h @@ -0,0 +1,117 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_fIND_OPTIMAL_PARAMETERS_Hh_ +#define DLIB_fIND_OPTIMAL_PARAMETERS_Hh_ + +#include "../matrix.h" +#include "find_optimal_parameters_abstract.h" +#include "optimization_bobyqa.h" +#include "optimization_line_search.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + double find_optimal_parameters ( + double initial_search_radius, + double eps, + const unsigned int max_f_evals, + matrix& x, + const matrix& x_lower, + const matrix& x_upper, + const funct& f + ) + { + DLIB_CASSERT(x.size() == x_lower.size() && x_lower.size() == x_upper.size() && x.size() > 0, + "\t double find_optimal_parameters()" + << "\n\t x.size(): " << x.size() + << "\n\t x_lower.size(): " << x_lower.size() + << "\n\t x_upper.size(): " << x_upper.size() + ); + + // check the requirements. Also split the assert up so that the error message isn't huge. + DLIB_CASSERT(max_f_evals > 1 && eps > 0 && initial_search_radius > eps, + "\t double find_optimal_parameters()" + << "\n\t Invalid arguments have been given to this function" + << "\n\t initial_search_radius: " << initial_search_radius + << "\n\t eps: " << eps + << "\n\t max_f_evals: " << max_f_evals + ); + + DLIB_CASSERT( min(x_upper - x_lower) > 0 && + min(x - x_lower) >= 0 && min(x_upper - x) >= 0, + "\t double find_optimal_parameters()" + << "\n\t The bounds constraints have to make sense and also contain the starting point." + << "\n\t min(x_upper - x_lower): " << min(x_upper - x_lower) + << "\n\t min(x - x_lower) >= 0 && min(x_upper - x) >= 0: " << (min(x - x_lower) >= 0 && min(x_upper - x) >= 0) + ); + + // if the search radius is too big then shrink it so it fits inside the bounds. + if (initial_search_radius*2 >= min(x_upper-x_lower)) + initial_search_radius = 0.5*min(x_upper-x_lower)*0.99; + + + double objective_val = std::numeric_limits::infinity(); + size_t num_iter_used = 0; + if (x.size() == 1) + { + // BOBYQA requires x to have at least 2 variables in it. So we can't call it in + // this case. Instead we call find_min_single_variable(). + matrix temp(1); + auto ff = [&](const double& xx) + { + temp = xx; + double obj = f(temp); + ++num_iter_used; + // keep track of the best x. + if (obj < objective_val) + { + objective_val = obj; + x = temp; + } + return obj; + }; + try + { + double dx = x(0); + find_min_single_variable(ff, dx, x_lower(0), x_upper(0), eps, max_f_evals, initial_search_radius); + } catch (optimize_single_variable_failure& ) + { + } + } + else + { + auto ff = [&](const matrix& xx) + { + double obj = f(xx); + ++num_iter_used; + // keep track of the best x. + if (obj < objective_val) + { + objective_val = obj; + x = xx; + } + return obj; + }; + try + { + matrix start_x = x; + find_min_bobyqa(ff, start_x, 2*x.size()+1, x_lower, x_upper, initial_search_radius, eps, max_f_evals); + } catch (bobyqa_failure& ) + { + } + } + + return objective_val; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_fIND_OPTIMAL_PARAMETERS_Hh_ + diff --git a/ml/dlib/dlib/optimization/find_optimal_parameters_abstract.h b/ml/dlib/dlib/optimization/find_optimal_parameters_abstract.h new file mode 100644 index 000000000..96dcee89b --- /dev/null +++ b/ml/dlib/dlib/optimization/find_optimal_parameters_abstract.h @@ -0,0 +1,58 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_fIND_OPTIMAL_PARAMETERS_ABSTRACT_Hh_ +#ifdef DLIB_fIND_OPTIMAL_PARAMETERS_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + double find_optimal_parameters ( + double initial_search_radius, + double eps, + const unsigned int max_f_evals, + matrix& x, + const matrix& x_lower, + const matrix& x_upper, + const funct& f + ); + /*! + requires + - f(x) must be a valid expression that evaluates to a double + - x.size() == x_lower.size() == x_upper.size() + - x.size() > 0 + - 0 < eps < initial_search_radius + - max_f_evals > 1 + - min(x_upper - x_lower) > 0 + - min(x - x_lower) >= 0 && min(x_upper - x) >= 0 + (i.e. the given x should be within the bounds defined by x_lower and x_upper) + ensures + - Performs a constrained minimization of the function f() starting from + the initial point x. + - This function does not require derivatives of f(). Instead, it uses + derivative free methods to find the best setting of x. In particular, it + will begin by searching within a sphere of radius initial_search_radius + around x and will continue searching until either f() has been called + max_f_evals times or the search area has been shrunk to less than eps radius. + - #x == the value of x (within the bounds defined by x_lower and x_upper) that + was found to minimize f(). More precisely, it will always be true that: + - min(#x - x_lower) >= 0 && min(x_upper - #x) >= 0 + - returns f(#x). + throws + - No exception is thrown for executing max_f_evals iterations. This function + will simply output the best x it has seen if it runs out of iterations. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_fIND_OPTIMAL_PARAMETERS_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/optimization/isotonic_regression.h b/ml/dlib/dlib/optimization/isotonic_regression.h new file mode 100644 index 000000000..e89e70b48 --- /dev/null +++ b/ml/dlib/dlib/optimization/isotonic_regression.h @@ -0,0 +1,169 @@ +// Copyright (C) 2018 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ISOTONIC_ReGRESSION_H_ +#define DLIB_ISOTONIC_ReGRESSION_H_ + +#include "isotonic_regression_abstract.h" +#include +#include + +namespace dlib +{ + class isotonic_regression + { + public: + + template < + typename const_iterator, + typename iterator + > + void operator() ( + const_iterator begin, + const_iterator end, + iterator obegin + ) + { + do_isotonic_regression(begin, end); + + // unpack blocks to output + for (auto& block : blocks) + { + for (size_t k = 0; k < block.num; ++k) + set_val(*obegin++, block.avg); + } + + blocks.clear(); + } + + void operator() ( + std::vector& vect + ) { (*this)(vect.begin(), vect.end(), vect.begin()); } + + template + void operator() ( + std::vector>& vect + ) { (*this)(vect.begin(), vect.end(), vect.begin()); } + + + template < + typename const_iterator, + typename iterator + > + void fit_with_linear_output_interpolation ( + const_iterator begin, + const_iterator end, + iterator obegin + ) + { + do_isotonic_regression(begin, end); + + // Unpack blocks to output, but here instead of producing the step function + // output we linearly interpolate. Note that this actually fits the data less + // than the step-function, but in many applications might be closer to what you + // really when when using isotonic_regression than the step function. + for (size_t i = 0; i < blocks.size(); ++i) + { + auto& block = blocks[i]; + + double prev = (blocks.front().avg + block.avg)/2; + if (i > 0) + prev = (blocks[i-1].avg+block.avg)/2; + + double next = (blocks.back().avg + block.avg)/2; + if (i+1 < blocks.size()) + next = (blocks[i+1].avg+block.avg)/2; + + for (size_t k = 0; k < block.num; ++k) + { + const auto mid = block.num/2.0; + if (k < mid) + { + const double alpha = k/mid; + set_val(*obegin++, (1-alpha)*prev + alpha*block.avg); + } + else + { + const double alpha = k/mid-1; + set_val(*obegin++, alpha*next + (1-alpha)*block.avg); + } + } + } + + blocks.clear(); + } + + void fit_with_linear_output_interpolation ( + std::vector& vect + ) { fit_with_linear_output_interpolation(vect.begin(), vect.end(), vect.begin()); } + + template + void fit_with_linear_output_interpolation ( + std::vector>& vect + ) { fit_with_linear_output_interpolation(vect.begin(), vect.end(), vect.begin()); } + + private: + + template < + typename const_iterator + > + void do_isotonic_regression ( + const_iterator begin, + const_iterator end + ) + { + blocks.clear(); + + // Do the actual isotonic regression. The output is a step-function and is + // stored in the vector of blocks. + for (auto i = begin; i != end; ++i) + { + blocks.emplace_back(get_val(*i)); + while (blocks.size() > 1 && prev_block().avg > current_block().avg) + { + // merge the last two blocks. + prev_block() = prev_block() + current_block(); + blocks.pop_back(); + } + } + } + + + template + static double get_val(const T& v) { return v;} + + template + static double get_val(const std::pair& v) { return v.second;} + + template + static void set_val(T& v, double val) { v = val;} + + template + static void set_val(std::pair& v, double val) { v.second = val;} + + + + struct block_t + { + block_t(double val) : num(1), avg(val) {} + block_t(size_t n, double val) : num(n), avg(val) {} + + size_t num; + double avg; + + inline block_t operator+(const block_t& rhs) const + { + return block_t(num+rhs.num, + (num*avg + rhs.num*rhs.avg)/(num+rhs.num)); + } + }; + + inline block_t& prev_block() { return blocks[blocks.size()-2]; } + inline block_t& current_block() { return blocks.back(); } + + std::vector blocks; + }; +} + +#endif // DLIB_ISOTONIC_ReGRESSION_H_ + + diff --git a/ml/dlib/dlib/optimization/isotonic_regression_abstract.h b/ml/dlib/dlib/optimization/isotonic_regression_abstract.h new file mode 100644 index 000000000..b00334bef --- /dev/null +++ b/ml/dlib/dlib/optimization/isotonic_regression_abstract.h @@ -0,0 +1,128 @@ +// Copyright (C) 2018 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ISOTONIC_ReGRESSION_ABSTRACT_H_ +#ifdef DLIB_ISOTONIC_ReGRESSION_ABSTRACT_H_ + +#include +#include + +namespace dlib +{ + class isotonic_regression + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for performing 1-D isotonic regression. That is, it + finds the least squares fit of a non-parametric curve to some user supplied + data, subject to the constraint that the fitted curve is non-decreasing. + + This is done using the fast O(n) pool adjacent violators algorithm. + !*/ + + public: + + template < + typename const_iterator, + typename iterator + > + void operator() ( + const_iterator begin, + const_iterator end, + iterator obegin + ); + /*! + requires + - [begin,end) is an iterator range of float or doubles or a range of + std::pair or std::pair where T an be anything. + - obegin points to an iterator range at least std::distance(begin,end). + - obegin points to an iterator range of objects of type float, double, std::pair, or std::pair. + ensures + - Given the range of real values stored in [begin,end), this method performs isotonic regression + on this data and writes the results to obegin. To be specific: + - let IN refer to the input values stored in the iterator range [begin,end). + - let OUT refer to the output values stored in the iterator range [obegin, obegin+std::distance(begin,end)). + - This function populates OUT with values such that the sum_i of + (IN[i]-OUT[i])^2 is minimized, subject to the constraint that + OUT[i] <= OUT[i+1], i.e. that OUT is monotonic. + - It is OK for [begin,end) to overlap with the range pointed to by obegin. + That is, this function can run in-place. + - Note that when the inputs or outputs are std::pairs this algorithm only + looks at the .second field of the pair. It therefore still treats these + iterator ranges as ranges of reals since it only looks at the .second + field, which is a real number. The .first field is entirely ignored. + !*/ + + void operator() ( + std::vector& vect + ) { (*this)(vect.begin(), vect.end(), vect.begin()); } + /*! + ensures + - performs in-place isotonic regression. Therefore, #vect will contain the + isotonic regression of vect. + - #vect.size() == vect.size() + !*/ + + template + void operator() ( + std::vector>& vect + ) { (*this)(vect.begin(), vect.end(), vect.begin()); } + /*! + ensures + - performs in-place isotonic regression. Therefore, #vect will contain the + isotonic regression of vect. + - #vect.size() == vect.size() + !*/ + + + template < + typename const_iterator, + typename iterator + > + void fit_with_linear_output_interpolation ( + const_iterator begin, + const_iterator end, + iterator obegin + ); + /*! + requires + - [begin,end) is an iterator range of float or doubles or a range of + std::pair or std::pair where T an be anything. + - obegin points to an iterator range at least std::distance(begin,end). + - obegin points to an iterator range of objects of type float, double, std::pair, or std::pair. + ensures + - This function behaves just like (*this)(begin,end,obegin) except that the + output is interpolated. To explain, not that the optimal output of + isotonic regression is a step function. However, in many applications + that isn't really what you want. You want something smoother. So + fit_with_linear_output_interpolation() does isotonic regression and then + linearly interpolates the step function into a piecewise linear function. + !*/ + + void fit_with_linear_output_interpolation ( + std::vector& vect + ) { fit_with_linear_output_interpolation(vect.begin(), vect.end(), vect.begin()); } + /*! + ensures + - performs in-place isotonic regression. Therefore, #vect will contain the + isotonic regression of vect. + - #vect.size() == vect.size() + !*/ + + template + void fit_with_linear_output_interpolation ( + std::vector>& vect + ) { fit_with_linear_output_interpolation(vect.begin(), vect.end(), vect.begin()); } + /*! + ensures + - performs in-place isotonic regression. Therefore, #vect will contain the + isotonic regression of vect. + - #vect.size() == vect.size() + !*/ + + }; +} + +#endif // DLIB_ISOTONIC_ReGRESSION_ABSTRACT_H_ + + + diff --git a/ml/dlib/dlib/optimization/max_cost_assignment.h b/ml/dlib/dlib/optimization/max_cost_assignment.h new file mode 100644 index 000000000..db6c6f0d7 --- /dev/null +++ b/ml/dlib/dlib/optimization/max_cost_assignment.h @@ -0,0 +1,288 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MAX_COST_ASSIgNMENT_Hh_ +#define DLIB_MAX_COST_ASSIgNMENT_Hh_ + +#include "max_cost_assignment_abstract.h" +#include "../matrix.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + typename EXP::type assignment_cost ( + const matrix_exp& cost, + const std::vector& assignment + ) + { + DLIB_ASSERT(cost.nr() == cost.nc(), + "\t type assignment_cost(cost,assignment)" + << "\n\t cost.nr(): " << cost.nr() + << "\n\t cost.nc(): " << cost.nc() + ); +#ifdef ENABLE_ASSERTS + // can't call max on an empty vector. So put an if here to guard against it. + if (assignment.size() > 0) + { + DLIB_ASSERT(0 <= min(mat(assignment)) && max(mat(assignment)) < cost.nr(), + "\t type assignment_cost(cost,assignment)" + << "\n\t cost.nr(): " << cost.nr() + << "\n\t cost.nc(): " << cost.nc() + << "\n\t min(assignment): " << min(mat(assignment)) + << "\n\t max(assignment): " << max(mat(assignment)) + ); + } +#endif + + typename EXP::type temp = 0; + for (unsigned long i = 0; i < assignment.size(); ++i) + { + temp += cost(i, assignment[i]); + } + return temp; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline void compute_slack( + const long x, + std::vector& slack, + std::vector& slackx, + const matrix_exp& cost, + const std::vector& lx, + const std::vector& ly + ) + { + for (long y = 0; y < cost.nc(); ++y) + { + if (lx[x] + ly[y] - cost(x,y) < slack[y]) + { + slack[y] = lx[x] + ly[y] - cost(x,y); + slackx[y] = x; + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + std::vector max_cost_assignment ( + const matrix_exp& cost_ + ) + { + const_temp_matrix cost(cost_); + typedef typename EXP::type type; + // This algorithm only works if the elements of the cost matrix can be reliably + // compared using operator==. However, comparing for equality with floating point + // numbers is not a stable operation. So you need to use an integer cost matrix. + COMPILE_TIME_ASSERT(std::numeric_limits::is_integer); + DLIB_ASSERT(cost.nr() == cost.nc(), + "\t std::vector max_cost_assignment(cost)" + << "\n\t cost.nr(): " << cost.nr() + << "\n\t cost.nc(): " << cost.nc() + ); + + using namespace dlib::impl; + /* + I based the implementation of this algorithm on the description of the + Hungarian algorithm on the following websites: + http://www.math.uwo.ca/~mdawes/courses/344/kuhn-munkres.pdf + http://www.topcoder.com/tc?module=Static&d1=tutorials&d2=hungarianAlgorithm + + Note that this is the fast O(n^3) version of the algorithm. + */ + + if (cost.size() == 0) + return std::vector(); + + std::vector lx, ly; + std::vector xy; + std::vector yx; + std::vector S, T; + std::vector slack; + std::vector slackx; + std::vector aug_path; + + + + + // Initially, nothing is matched. + xy.assign(cost.nc(), -1); + yx.assign(cost.nc(), -1); + /* + We maintain the following invariant: + Vertex x is matched to vertex xy[x] and + vertex y is matched to vertex yx[y]. + + A value of -1 means a vertex isn't matched to anything. Moreover, + x corresponds to rows of the cost matrix and y corresponds to the + columns of the cost matrix. So we are matching X to Y. + */ + + + // Create an initial feasible labeling. Moreover, in the following + // code we will always have: + // for all valid x and y: lx[x] + ly[y] >= cost(x,y) + lx.resize(cost.nc()); + ly.assign(cost.nc(),0); + for (long x = 0; x < cost.nr(); ++x) + lx[x] = max(rowm(cost,x)); + + // Now grow the match set by picking edges from the equality subgraph until + // we have a complete matching. + for (long match_size = 0; match_size < cost.nc(); ++match_size) + { + std::deque q; + + // Empty out the S and T sets + S.assign(cost.nc(), false); + T.assign(cost.nc(), false); + + // clear out old slack values + slack.assign(cost.nc(), std::numeric_limits::max()); + slackx.resize(cost.nc()); + /* + slack and slackx are maintained such that we always + have the following (once they get initialized by compute_slack() below): + - for all y: + - let x == slackx[y] + - slack[y] == lx[x] + ly[y] - cost(x,y) + */ + + aug_path.assign(cost.nc(), -1); + + for (long x = 0; x < cost.nc(); ++x) + { + // If x is not matched to anything + if (xy[x] == -1) + { + q.push_back(x); + S[x] = true; + + compute_slack(x, slack, slackx, cost, lx, ly); + break; + } + } + + + long x_start = 0; + long y_start = 0; + + // Find an augmenting path. + bool found_augmenting_path = false; + while (!found_augmenting_path) + { + while (q.size() > 0 && !found_augmenting_path) + { + const long x = q.front(); + q.pop_front(); + for (long y = 0; y < cost.nc(); ++y) + { + if (cost(x,y) == lx[x] + ly[y] && !T[y]) + { + // if vertex y isn't matched with anything + if (yx[y] == -1) + { + y_start = y; + x_start = x; + found_augmenting_path = true; + break; + } + + T[y] = true; + q.push_back(yx[y]); + + aug_path[yx[y]] = x; + S[yx[y]] = true; + compute_slack(yx[y], slack, slackx, cost, lx, ly); + } + } + } + + if (found_augmenting_path) + break; + + + // Since we didn't find an augmenting path we need to improve the + // feasible labeling stored in lx and ly. We also need to keep the + // slack updated accordingly. + type delta = std::numeric_limits::max(); + for (unsigned long i = 0; i < T.size(); ++i) + { + if (!T[i]) + delta = std::min(delta, slack[i]); + } + for (unsigned long i = 0; i < T.size(); ++i) + { + if (S[i]) + lx[i] -= delta; + + if (T[i]) + ly[i] += delta; + else + slack[i] -= delta; + } + + + + q.clear(); + for (long y = 0; y < cost.nc(); ++y) + { + if (!T[y] && slack[y] == 0) + { + // if vertex y isn't matched with anything + if (yx[y] == -1) + { + x_start = slackx[y]; + y_start = y; + found_augmenting_path = true; + break; + } + else + { + T[y] = true; + if (!S[yx[y]]) + { + q.push_back(yx[y]); + + aug_path[yx[y]] = slackx[y]; + S[yx[y]] = true; + compute_slack(yx[y], slack, slackx, cost, lx, ly); + } + } + } + } + } // end while (!found_augmenting_path) + + // Flip the edges along the augmenting path. This means we will add one more + // item to our matching. + for (long cx = x_start, cy = y_start, ty; + cx != -1; + cx = aug_path[cx], cy = ty) + { + ty = xy[cx]; + yx[cy] = cx; + xy[cx] = cy; + } + + } + + + return xy; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MAX_COST_ASSIgNMENT_Hh_ + + diff --git a/ml/dlib/dlib/optimization/max_cost_assignment_abstract.h b/ml/dlib/dlib/optimization/max_cost_assignment_abstract.h new file mode 100644 index 000000000..bbdb0abfb --- /dev/null +++ b/ml/dlib/dlib/optimization/max_cost_assignment_abstract.h @@ -0,0 +1,63 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MAX_COST_ASSIgNMENT_ABSTRACT_Hh_ +#ifdef DLIB_MAX_COST_ASSIgNMENT_ABSTRACT_Hh_ + +#include "../matrix.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + typename EXP::type assignment_cost ( + const matrix_exp& cost, + const std::vector& assignment + ); + /*! + requires + - cost.nr() == cost.nc() + - for all valid i: + - 0 <= assignment[i] < cost.nr() + ensures + - Interprets cost as a cost assignment matrix. That is, cost(i,j) + represents the cost of assigning i to j. + - Interprets assignment as a particular set of assignments. That is, + i is assigned to assignment[i]. + - returns the cost of the given assignment. That is, returns + a number which is: + sum over i: cost(i,assignment[i]) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + std::vector max_cost_assignment ( + const matrix_exp& cost + ); + /*! + requires + - EXP::type == some integer type (e.g. int) + (i.e. cost must contain integers rather than floats or doubles) + - cost.nr() == cost.nc() + ensures + - Finds and returns the solution to the following optimization problem: + + Maximize: f(A) == assignment_cost(cost, A) + Subject to the following constraints: + - The elements of A are unique. That is, there aren't any + elements of A which are equal. + - A.size() == cost.nr() + + - This function implements the O(N^3) version of the Hungarian algorithm + where N is the number of rows in the cost matrix. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MAX_COST_ASSIgNMENT_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/optimization/max_sum_submatrix.h b/ml/dlib/dlib/optimization/max_sum_submatrix.h new file mode 100644 index 000000000..1986cc26b --- /dev/null +++ b/ml/dlib/dlib/optimization/max_sum_submatrix.h @@ -0,0 +1,285 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MAX_SUM_SUBMaTRIX_Hh_ +#define DLIB_MAX_SUM_SUBMaTRIX_Hh_ + +#include "max_sum_submatrix_abstract.h" +#include "../matrix.h" +#include +#include +#include "../geometry.h" + +namespace dlib +{ + namespace impl + { + + // ------------------------------------------------------------------------------------ + + template + struct range_set + { + int top_min; + int top_max; + int bottom_min; + int bottom_max; + T weight; + + bool operator<(const range_set& item) const { return weight < item.weight; } + }; + + // ------------------------------------------------------------------------------------ + + template + bool is_terminal_set ( + const range_set& item + ) + { + return (item.top_min >= item.top_max && + item.bottom_min >= item.bottom_max); + } + + // ------------------------------------------------------------------------------------ + + template + void split ( + const range_set& rset, + range_set& a, + range_set& b + ) + { + if (rset.top_max - rset.top_min > rset.bottom_max - rset.bottom_min) + { + // split top + const int middle = (rset.top_max + rset.top_min)/2; + a.top_min = rset.top_min; + a.top_max = middle; + b.top_min = middle+1; + b.top_max = rset.top_max; + + a.bottom_min = rset.bottom_min; + a.bottom_max = rset.bottom_max; + b.bottom_min = rset.bottom_min; + b.bottom_max = rset.bottom_max; + } + else + { + // split bottom + const int middle = (rset.bottom_max + rset.bottom_min)/2; + a.bottom_min = rset.bottom_min; + a.bottom_max = middle; + b.bottom_min = middle+1; + b.bottom_max = rset.bottom_max; + + a.top_min = rset.top_min; + a.top_max = rset.top_max; + b.top_min = rset.top_min; + b.top_max = rset.top_max; + } + } + + // ------------------------------------------------------------------------------------ + + template + void find_best_column_range ( + const matrix_exp& sum_pos, + const matrix_exp& sum_neg, + const range_set& row_range, + T& weight, + int& left, + int& right + ) + { + left = 0; + right = -1; + weight = 0; + T cur_sum = 0; + int cur_pos = 0; + for (long c = 0; c < sum_pos.nc(); ++c) + { + // compute the value for the current column + T temp = sum_pos(row_range.bottom_max+1,c) - sum_pos(row_range.top_min,c); + if (row_range.top_max <= row_range.bottom_min) + temp += sum_neg(row_range.bottom_min+1,c) - sum_neg(row_range.top_max,c); + + + cur_sum += temp; + if (cur_sum > weight) + { + left = cur_pos; + right = c; + weight = cur_sum; + } + + if (cur_sum <= 0) + { + cur_sum = 0; + cur_pos = c+1; + } + + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + std::vector max_sum_submatrix( + const matrix_exp& mat, + unsigned long max_rects, + double thresh_ = 0 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(thresh_ >= 0 && mat.size() > 0, + "\t std::vector max_sum_submatrix()" + << "\n\t Invalid arguments were given to this function." + << "\n\t mat.size(): " << mat.size() + << "\n\t thresh_: " << thresh_ + ); + + /* + This function is basically an implementation of the efficient subwindow search (I-ESS) + algorithm presented in the following paper: + Efficient Algorithms for Subwindow Search in Object Detection and Localization + by Senjian An, Patrick Peursum, Wanquan Liu and Svetha Venkatesh + In CVPR 2009 + + */ + + + if (max_rects == 0) + return std::vector(); + + using namespace dlib::impl; + typedef typename EXP::type element_type; + typedef typename promote::type scalar_type; + + const scalar_type thresh = static_cast(thresh_); + + + matrix sum_pos; + matrix sum_neg; + sum_pos.set_size(mat.nr()+1, mat.nc()); + sum_neg.set_size(mat.nr()+1, mat.nc()); + // integrate over the rows. + for (long c = 0; c < mat.nc(); ++c) + { + sum_pos(0,c) = 0; + sum_neg(0,c) = 0; + } + for (long r = 0; r < mat.nr(); ++r) + { + for (long c = 0; c < mat.nc(); ++c) + { + if (mat(r,c) > 0) + { + sum_pos(r+1,c) = mat(r,c) + sum_pos(r,c); + sum_neg(r+1,c) = sum_neg(r,c); + } + else + { + sum_pos(r+1,c) = sum_pos(r,c); + sum_neg(r+1,c) = mat(r,c) + sum_neg(r,c); + } + } + } + + std::priority_queue > q; + + // the range_sets will represent ranges of columns + range_set universe_set; + universe_set.bottom_min = 0; + universe_set.top_min = 0; + universe_set.bottom_max = mat.nr()-1; + universe_set.top_max = mat.nr()-1; + universe_set.weight = sum(rowm(dlib::mat(sum_pos),mat.nr())); + + q.push(universe_set); + + std::vector results; + std::vector temp_pos(mat.nc()); + std::vector temp_neg(mat.nc()); + + while (q.size() > 0) + { + if (is_terminal_set(q.top())) + { + int left, right; + scalar_type weight; + find_best_column_range(sum_pos, sum_neg, q.top(), weight, left, right); + + rectangle rect(left, q.top().top_min, + right, q.top().bottom_min); + + if (weight <= thresh) + break; + + results.push_back(rect); + + if (results.size() >= max_rects) + break; + + q = std::priority_queue >(); + // We are going to blank out the weights we just used. So adjust the sum images appropriately. + for (long c = rect.left(); c <= rect.right(); ++c) + { + temp_pos[c] = sum_pos(rect.bottom()+1,c) - sum_pos(rect.top(),c); + temp_neg[c] = sum_neg(rect.bottom()+1,c) - sum_neg(rect.top(),c); + } + // blank out the area inside the rectangle + for (long r = rect.top(); r <= rect.bottom(); ++r) + { + for (long c = rect.left(); c <= rect.right(); ++c) + { + sum_pos(r+1,c) = sum_pos(r,c); + sum_neg(r+1,c) = sum_neg(r,c); + } + } + // account for the area below the rectangle + for (long r = rect.bottom()+2; r < sum_pos.nr(); ++r) + { + for (long c = rect.left(); c <= rect.right(); ++c) + { + sum_pos(r,c) -= temp_pos[c]; + sum_neg(r,c) -= temp_neg[c]; + } + } + + + universe_set.weight = sum(rowm(dlib::mat(sum_pos),mat.nr())); + if (universe_set.weight <= thresh) + break; + + q.push(universe_set); + continue; + } + + range_set a, b; + split(q.top(), a,b); + q.pop(); + + // these variables are not used at this point in the algorithm. + int a_left, a_right; + int b_left, b_right; + + find_best_column_range(sum_pos, sum_neg, a, a.weight, a_left, a_right); + find_best_column_range(sum_pos, sum_neg, b, b.weight, b_left, b_right); + + if (a.weight > thresh) + q.push(a); + if (b.weight > thresh) + q.push(b); + + } + + + return results; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MAX_SUM_SUBMaTRIX_Hh_ + diff --git a/ml/dlib/dlib/optimization/max_sum_submatrix_abstract.h b/ml/dlib/dlib/optimization/max_sum_submatrix_abstract.h new file mode 100644 index 000000000..6714dd7fe --- /dev/null +++ b/ml/dlib/dlib/optimization/max_sum_submatrix_abstract.h @@ -0,0 +1,49 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MAX_SUM_SUBMaTRIX_ABSTRACT_Hh_ +#ifdef DLIB_MAX_SUM_SUBMaTRIX_ABSTRACT_Hh_ + +#include "../matrix.h" +#include +#include "../geometry.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + std::vector max_sum_submatrix( + const matrix_exp& mat, + unsigned long max_rects, + double thresh = 0 + ); + /*! + requires + - thresh >= 0 + - mat.size() != 0 + ensures + - This function finds the submatrix within mat which has the largest sum. It then + zeros out that submatrix and repeats the process until no more maximal submatrices can + be found. The std::vector returned will be ordered so that the rectangles with the + largest sum come first. + - Each submatrix must have a sum greater than thresh. If no such submatrix exists then + the algorithm terminates and returns an empty std::vector. + - At most max_rects rectangles are returned. + + - This function is basically an implementation of the efficient subwindow search (I-ESS) + algorithm presented in the following paper: + Efficient Algorithms for Subwindow Search in Object Detection and Localization + by Senjian An, Patrick Peursum, Wanquan Liu and Svetha Venkatesh + In CVPR 2009 + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MAX_SUM_SUBMaTRIX_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/optimization/optimization.h b/ml/dlib/dlib/optimization/optimization.h new file mode 100644 index 000000000..561d64376 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization.h @@ -0,0 +1,714 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATIOn_H_ +#define DLIB_OPTIMIZATIOn_H_ + +#include +#include +#include "optimization_abstract.h" +#include "optimization_search_strategies.h" +#include "optimization_stop_strategies.h" +#include "optimization_line_search.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Functions that transform other functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class central_differences + { + public: + central_differences(const funct& f_, double eps_ = 1e-7) : f(f_), eps(eps_){} + + template + typename T::matrix_type operator()(const T& x) const + { + // T must be some sort of dlib matrix + COMPILE_TIME_ASSERT(is_matrix::value); + + typename T::matrix_type der(x.size()); + typename T::matrix_type e(x); + for (long i = 0; i < x.size(); ++i) + { + const double old_val = e(i); + + e(i) += eps; + const double delta_plus = f(e); + e(i) = old_val - eps; + const double delta_minus = f(e); + + der(i) = (delta_plus - delta_minus)/((old_val+eps)-(old_val-eps)); + + // and finally restore the old value of this element + e(i) = old_val; + } + + return der; + } + + template + typename U::matrix_type operator()(const T& item, const U& x) const + { + // U must be some sort of dlib matrix + COMPILE_TIME_ASSERT(is_matrix::value); + + typename U::matrix_type der(x.size()); + typename U::matrix_type e(x); + for (long i = 0; i < x.size(); ++i) + { + const double old_val = e(i); + + e(i) += eps; + const double delta_plus = f(item,e); + e(i) = old_val - eps; + const double delta_minus = f(item,e); + + der(i) = (delta_plus - delta_minus)/((old_val+eps)-(old_val-eps)); + + // and finally restore the old value of this element + e(i) = old_val; + } + + return der; + } + + + double operator()(const double& x) const + { + return (f(x+eps)-f(x-eps))/((x+eps)-(x-eps)); + } + + private: + const funct& f; + const double eps; + }; + + template + const central_differences derivative(const funct& f) { return central_differences(f); } + template + const central_differences derivative(const funct& f, double eps) + { + DLIB_ASSERT ( + eps > 0, + "\tcentral_differences derivative(f,eps)" + << "\n\tYou must give an epsilon > 0" + << "\n\teps: " << eps + ); + return central_differences(f,eps); + } + +// ---------------------------------------------------------------------------------------- + + template + struct clamped_function_object + { + clamped_function_object( + const funct& f_, + const matrix_exp& x_lower_, + const matrix_exp& x_upper_ + ) : f(f_), x_lower(x_lower_), x_upper(x_upper_) + { + } + + template + double operator() ( + const T& x + ) const + { + return f(clamp(x,x_lower,x_upper)); + } + + const funct& f; + const matrix_exp& x_lower; + const matrix_exp& x_upper; + }; + + template + clamped_function_object clamp_function( + const funct& f, + const matrix_exp& x_lower, + const matrix_exp& x_upper + ) { return clamped_function_object(f,x_lower,x_upper); } + +// ---------------------------------------------------------------------------------------- + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Functions that perform unconstrained optimization +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T + > + double find_min ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + double min_f + ) + { + COMPILE_TIME_ASSERT(is_matrix::value); + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + DLIB_CASSERT ( + is_col_vector(x), + "\tdouble find_min()" + << "\n\tYou have to supply column vectors to this function" + << "\n\tx.nc(): " << x.nc() + ); + + + T g, s; + + double f_value = f(x); + g = der(x); + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + + while(stop_strategy.should_continue_search(x, f_value, g) && f_value > min_f) + { + s = search_strategy.get_next_direction(x, f_value, g); + + double alpha = line_search( + make_line_search_function(f,x,s, f_value), + f_value, + make_line_search_function(der,x,s, g), + dot(g,s), // compute initial gradient for the line search + search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f, + search_strategy.get_max_line_search_iterations()); + + // Take the search step indicated by the above line search + x += alpha*s; + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + } + + return f_value; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T + > + double find_max ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + double max_f + ) + { + COMPILE_TIME_ASSERT(is_matrix::value); + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + DLIB_CASSERT ( + is_col_vector(x), + "\tdouble find_max()" + << "\n\tYou have to supply column vectors to this function" + << "\n\tx.nc(): " << x.nc() + ); + + T g, s; + + // This function is basically just a copy of find_min() but with - put in the right places + // to flip things around so that it ends up looking for the max rather than the min. + + double f_value = -f(x); + g = -der(x); + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + + while(stop_strategy.should_continue_search(x, f_value, g) && f_value > -max_f) + { + s = search_strategy.get_next_direction(x, f_value, g); + + double alpha = line_search( + negate_function(make_line_search_function(f,x,s, f_value)), + f_value, + negate_function(make_line_search_function(der,x,s, g)), + dot(g,s), // compute initial gradient for the line search + search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), -max_f, + search_strategy.get_max_line_search_iterations() + ); + + // Take the search step indicated by the above line search + x += alpha*s; + + // Don't forget to negate these outputs from the line search since they are + // from the unnegated versions of f() and der() + g *= -1; + f_value *= -1; + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + } + + return -f_value; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename T + > + double find_min_using_approximate_derivatives ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + T& x, + double min_f, + double derivative_eps = 1e-7 + ) + { + COMPILE_TIME_ASSERT(is_matrix::value); + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + DLIB_CASSERT ( + is_col_vector(x) && derivative_eps > 0, + "\tdouble find_min_using_approximate_derivatives()" + << "\n\tYou have to supply column vectors to this function" + << "\n\tx.nc(): " << x.nc() + << "\n\tderivative_eps: " << derivative_eps + ); + + T g, s; + + double f_value = f(x); + g = derivative(f,derivative_eps)(x); + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + + while(stop_strategy.should_continue_search(x, f_value, g) && f_value > min_f) + { + s = search_strategy.get_next_direction(x, f_value, g); + + double alpha = line_search( + make_line_search_function(f,x,s,f_value), + f_value, + derivative(make_line_search_function(f,x,s),derivative_eps), + dot(g,s), // Sometimes the following line is a better way of determining the initial gradient. + //derivative(make_line_search_function(f,x,s),derivative_eps)(0), + search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f, + search_strategy.get_max_line_search_iterations() + ); + + // Take the search step indicated by the above line search + x += alpha*s; + + g = derivative(f,derivative_eps)(x); + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + } + + return f_value; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename T + > + double find_max_using_approximate_derivatives ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + T& x, + double max_f, + double derivative_eps = 1e-7 + ) + { + COMPILE_TIME_ASSERT(is_matrix::value); + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + DLIB_CASSERT ( + is_col_vector(x) && derivative_eps > 0, + "\tdouble find_max_using_approximate_derivatives()" + << "\n\tYou have to supply column vectors to this function" + << "\n\tx.nc(): " << x.nc() + << "\n\tderivative_eps: " << derivative_eps + ); + + // Just negate the necessary things and call the find_min version of this function. + return -find_min_using_approximate_derivatives( + search_strategy, + stop_strategy, + negate_function(f), + x, + -max_f, + derivative_eps + ); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Functions for box constrained optimization +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + T zero_bounded_variables ( + const double eps, + T vect, + const T& x, + const T& gradient, + const U& x_lower, + const V& x_upper + ) + { + for (long i = 0; i < gradient.size(); ++i) + { + const double tol = eps*std::abs(x(i)); + // if x(i) is an active bound constraint + if (x_lower(i)+tol >= x(i) && gradient(i) > 0) + vect(i) = 0; + else if (x_upper(i)-tol <= x(i) && gradient(i) < 0) + vect(i) = 0; + } + return vect; + } + +// ---------------------------------------------------------------------------------------- + + template + T gap_step_assign_bounded_variables ( + const double eps, + T vect, + const T& x, + const T& gradient, + const U& x_lower, + const V& x_upper + ) + { + for (long i = 0; i < gradient.size(); ++i) + { + const double tol = eps*std::abs(x(i)); + // If x(i) is an active bound constraint then we should set its search + // direction such that a single step along the direction either does nothing or + // closes the gap of size tol before hitting the bound exactly. + if (x_lower(i)+tol >= x(i) && gradient(i) > 0) + vect(i) = x_lower(i)-x(i); + else if (x_upper(i)-tol <= x(i) && gradient(i) < 0) + vect(i) = x_upper(i)-x(i); + } + return vect; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T, + typename EXP1, + typename EXP2 + > + double find_min_box_constrained ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + const matrix_exp& x_lower, + const matrix_exp& x_upper + ) + { + /* + The implementation of this function is more or less based on the discussion in + the paper Projected Newton-type Methods in Machine Learning by Mark Schmidt, et al. + */ + + // make sure the requires clause is not violated + COMPILE_TIME_ASSERT(is_matrix::value); + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + DLIB_CASSERT ( + is_col_vector(x) && is_col_vector(x_lower) && is_col_vector(x_upper) && + x.size() == x_lower.size() && x.size() == x_upper.size(), + "\tdouble find_min_box_constrained()" + << "\n\t The inputs to this function must be equal length column vectors." + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper) + << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper) + << "\n\t x.size(): " << x.size() + << "\n\t x_lower.size(): " << x_lower.size() + << "\n\t x_upper.size(): " << x_upper.size() + ); + DLIB_ASSERT ( + min(x_upper-x_lower) >= 0, + "\tdouble find_min_box_constrained()" + << "\n\t You have to supply proper box constraints to this function." + << "\n\r min(x_upper-x_lower): " << min(x_upper-x_lower) + ); + + + T g, s; + double f_value = f(x); + g = der(x); + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + + // gap_eps determines how close we have to get to a bound constraint before we + // start basically dropping it from the optimization and consider it to be an + // active constraint. + const double gap_eps = 1e-8; + + double last_alpha = 1; + while(stop_strategy.should_continue_search(x, f_value, g)) + { + s = search_strategy.get_next_direction(x, f_value, zero_bounded_variables(gap_eps, g, x, g, x_lower, x_upper)); + s = gap_step_assign_bounded_variables(gap_eps, s, x, g, x_lower, x_upper); + + double alpha = backtracking_line_search( + make_line_search_function(clamp_function(f,x_lower,x_upper), x, s, f_value), + f_value, + dot(g,s), // compute gradient for the line search + last_alpha, + search_strategy.get_wolfe_rho(), + search_strategy.get_max_line_search_iterations()); + + // Do a trust region style thing for alpha. The idea is that if we take a + // small step then we are likely to take another small step. So we reuse the + // alpha from the last iteration unless the line search didn't shrink alpha at + // all, in that case, we start with a bigger alpha next time. + if (alpha == last_alpha) + last_alpha = std::min(last_alpha*10,1.0); + else + last_alpha = alpha; + + // Take the search step indicated by the above line search + x = dlib::clamp(x + alpha*s, x_lower, x_upper); + g = der(x); + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + } + + return f_value; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T + > + double find_min_box_constrained ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + double x_lower, + double x_upper + ) + { + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + typedef typename T::type scalar_type; + return find_min_box_constrained(search_strategy, + stop_strategy, + f, + der, + x, + uniform_matrix(x.size(),1,x_lower), + uniform_matrix(x.size(),1,x_upper) ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T, + typename EXP1, + typename EXP2 + > + double find_max_box_constrained ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + const matrix_exp& x_lower, + const matrix_exp& x_upper + ) + { + // make sure the requires clause is not violated + COMPILE_TIME_ASSERT(is_matrix::value); + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + DLIB_CASSERT ( + is_col_vector(x) && is_col_vector(x_lower) && is_col_vector(x_upper) && + x.size() == x_lower.size() && x.size() == x_upper.size(), + "\tdouble find_max_box_constrained()" + << "\n\t The inputs to this function must be equal length column vectors." + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper) + << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper) + << "\n\t x.size(): " << x.size() + << "\n\t x_lower.size(): " << x_lower.size() + << "\n\t x_upper.size(): " << x_upper.size() + ); + DLIB_ASSERT ( + min(x_upper-x_lower) >= 0, + "\tdouble find_max_box_constrained()" + << "\n\t You have to supply proper box constraints to this function." + << "\n\r min(x_upper-x_lower): " << min(x_upper-x_lower) + ); + + // This function is basically just a copy of find_min_box_constrained() but with - put + // in the right places to flip things around so that it ends up looking for the max + // rather than the min. + + T g, s; + double f_value = -f(x); + g = -der(x); + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + + // gap_eps determines how close we have to get to a bound constraint before we + // start basically dropping it from the optimization and consider it to be an + // active constraint. + const double gap_eps = 1e-8; + + double last_alpha = 1; + while(stop_strategy.should_continue_search(x, f_value, g)) + { + s = search_strategy.get_next_direction(x, f_value, zero_bounded_variables(gap_eps, g, x, g, x_lower, x_upper)); + s = gap_step_assign_bounded_variables(gap_eps, s, x, g, x_lower, x_upper); + + double alpha = backtracking_line_search( + negate_function(make_line_search_function(clamp_function(f,x_lower,x_upper), x, s, f_value)), + f_value, + dot(g,s), // compute gradient for the line search + last_alpha, + search_strategy.get_wolfe_rho(), + search_strategy.get_max_line_search_iterations()); + + // Do a trust region style thing for alpha. The idea is that if we take a + // small step then we are likely to take another small step. So we reuse the + // alpha from the last iteration unless the line search didn't shrink alpha at + // all, in that case, we start with a bigger alpha next time. + if (alpha == last_alpha) + last_alpha = std::min(last_alpha*10,1.0); + else + last_alpha = alpha; + + // Take the search step indicated by the above line search + x = dlib::clamp(x + alpha*s, x_lower, x_upper); + g = -der(x); + + // Don't forget to negate the output from the line search since it is from the + // unnegated version of f() + f_value *= -1; + + if (!is_finite(f_value)) + throw error("The objective function generated non-finite outputs"); + if (!is_finite(g)) + throw error("The objective function generated non-finite outputs"); + } + + return -f_value; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T + > + double find_max_box_constrained ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + double x_lower, + double x_upper + ) + { + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + typedef typename T::type scalar_type; + return find_max_box_constrained(search_strategy, + stop_strategy, + f, + der, + x, + uniform_matrix(x.size(),1,x_lower), + uniform_matrix(x.size(),1,x_upper) ); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_H_ + diff --git a/ml/dlib/dlib/optimization/optimization_abstract.h b/ml/dlib/dlib/optimization/optimization_abstract.h new file mode 100644 index 000000000..f3c42c2b4 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_abstract.h @@ -0,0 +1,468 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATIOn_ABSTRACT_ +#ifdef DLIB_OPTIMIZATIOn_ABSTRACT_ + +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "optimization_search_strategies_abstract.h" +#include "optimization_stop_strategies_abstract.h" +#include "optimization_line_search_abstract.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Functions that transform other functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + class central_differences; + /*! + This is a function object that represents the derivative of some other + function. + + Note that if funct is a function of a double then the derivative of + funct is just a double but if funct is a function of a dlib::matrix (i.e. a + function of many variables) then its derivative is a gradient vector (a column + vector in particular). + !*/ + + template < + typename funct + > + const central_differences derivative( + const funct& f, + double eps + ); + /*! + requires + - f == a function that returns a scalar + - f must have one of the following forms: + - double f(double) + - double f(dlib::matrix) (where the matrix is a column vector) + - double f(T, dlib::matrix) (where the matrix is a column vector. In + this case the derivative of f is taken with respect to the second argument.) + - eps > 0 + ensures + - returns a function that represents the derivative of the function f. It + is approximated numerically by: + (f(x+eps)-f(x-eps))/(2*eps) + !*/ + + template < + typename funct + > + const central_differences derivative( + const funct& f + ); + /*! + ensures + - returns derivative(f, 1e-7) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct, + typename EXP1, + typename EXP2 + > + clamped_function_object clamp_function ( + const funct& f, + const matrix_exp& x_lower, + const matrix_exp& x_upper + ); + /*! + requires + - f == a function that takes a matrix and returns a scalar value. Moreover, f + must be capable of taking in matrices with the same dimensions as x_lower and + x_upper. So f(x_lower) must be a valid expression that evaluates to a scalar + value. + - x_lower.nr() == x_upper.nr() && x_lower.nc() == x_upper.nc() + (i.e. x_lower and x_upper must have the same dimensions) + - x_lower and x_upper must contain the same type of elements. + ensures + - returns a function object that represents the function g(x) where + g(x) == f(clamp(x,x_lower,x_upper)) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Functions that perform unconstrained optimization +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T + > + double find_min ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + double min_f + ); + /*! + requires + - search_strategy == an object that defines a search strategy such as one + of the objects from dlib/optimization/optimization_search_strategies_abstract.h + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - f(x) must be a valid expression that evaluates to a double + - der(x) must be a valid expression that evaluates to the derivative of f() at x. + - is_col_vector(x) == true + ensures + - Performs an unconstrained minimization of the function f() using the given + search_strategy and starting from the initial point x. + - The function is optimized until stop_strategy decides that an acceptable + point has been found or f(#x) < min_f. + - #x == the value of x that was found to minimize f() + - returns f(#x). + - When this function makes calls to f() and der() it always does so by + first calling f() and then calling der(). That is, these two functions + are always called in pairs with f() being called first and then der() + being called second. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T + > + double find_max ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + double max_f + ); + /*! + requires + - search_strategy == an object that defines a search strategy such as one + of the objects from dlib/optimization/optimization_search_strategies_abstract.h + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - f(x) must be a valid expression that evaluates to a double + - der(x) must be a valid expression that evaluates to the derivative of f() at x. + - is_col_vector(x) == true + ensures + - Performs an unconstrained maximization of the function f() using the given + search_strategy and starting from the initial point x. + - The function is optimized until stop_strategy decides that an acceptable + point has been found or f(#x) > max_f. + - #x == the value of x that was found to maximize f() + - returns f(#x). + - When this function makes calls to f() and der() it always does so by + first calling f() and then calling der(). That is, these two functions + are always called in pairs with f() being called first and then der() + being called second. + - Note that this function solves the maximization problem by converting it + into a minimization problem. Therefore, the values of f and its derivative + reported to the stopping strategy will be negated. That is, stop_strategy + will see -f() and -der(). All this really means is that the status messages + from a stopping strategy in verbose mode will display a negated objective + value. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename T + > + double find_min_using_approximate_derivatives ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + T& x, + double min_f, + double derivative_eps = 1e-7 + ); + /*! + requires + - search_strategy == an object that defines a search strategy such as one + of the objects from dlib/optimization/optimization_search_strategies_abstract.h + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - f(x) must be a valid expression that evaluates to a double + - is_col_vector(x) == true + - derivative_eps > 0 + ensures + - Performs an unconstrained minimization of the function f() using the given + search_strategy and starting from the initial point x. + - The function is optimized until stop_strategy decides that an acceptable + point has been found or f(#x) < min_f. + - #x == the value of x that was found to minimize f() + - returns f(#x). + - Uses the dlib::derivative(f,derivative_eps) function to compute gradient + information. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename T + > + double find_max_using_approximate_derivatives ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + T& x, + double max_f, + double derivative_eps = 1e-7 + ); + /*! + requires + - search_strategy == an object that defines a search strategy such as one + of the objects from dlib/optimization/optimization_search_strategies_abstract.h + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - f(x) must be a valid expression that evaluates to a double + - is_col_vector(x) == true + - derivative_eps > 0 + ensures + - Performs an unconstrained maximization of the function f() using the given + search_strategy and starting from the initial point x. + - The function is optimized until stop_strategy decides that an acceptable + point has been found or f(#x) > max_f. + - #x == the value of x that was found to maximize f() + - returns f(#x). + - Uses the dlib::derivative(f,derivative_eps) function to compute gradient + information. + - Note that this function solves the maximization problem by converting it + into a minimization problem. Therefore, the values of f and its derivative + reported to the stopping strategy will be negated. That is, stop_strategy + will see -f() and -der(). All this really means is that the status messages + from a stopping strategy in verbose mode will display a negated objective + value. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Functions that perform box constrained optimization +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T, + typename EXP1, + typename EXP2 + > + double find_min_box_constrained ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + const matrix_exp& x_lower, + const matrix_exp& x_upper + ); + /*! + requires + - search_strategy == an object that defines a search strategy such as one + of the objects from dlib/optimization/optimization_search_strategies_abstract.h + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - f(x) must be a valid expression that evaluates to a double + - der(x) must be a valid expression that evaluates to the derivative of f() at x. + - is_col_vector(x) == true + - is_col_vector(x_lower) == true + - is_col_vector(x_upper) == true + - x.size() == x_lower.size() == x_upper.size() + (i.e. x, x_lower, and x_upper need to all be column vectors of the same dimensionality) + - min(x_upper-x_lower) >= 0 + (i.e. x_upper must contain upper bounds relative to x_lower) + ensures + - Performs a box constrained minimization of the function f() using the given + search_strategy and starting from the initial point x. That is, we try to + find the x value that minimizes f(x) but is also within the box constraints + specified by x_lower and x_upper. That is, we ensure that #x satisfies: + - min(#x - x_lower) >= 0 && min(x_upper - #x) >= 0 + - This function uses a backtracking line search along with a gradient projection + step to handle the box constraints. + - The function is optimized until stop_strategy decides that an acceptable + point has been found. + - #x == the value of x that was found to minimize f() within the given box + constraints. + - returns f(#x). + - The last call to f() will be made with f(#x). + - When calling f() and der(), the input passed to them will always be inside + the box constraints defined by x_lower and x_upper. + - When calling der(x), it will always be the case that the last call to f() was + made with the same x value. This means that you can reuse any intermediate + results from the previous call to f(x) inside der(x) rather than recomputing + them inside der(x). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T + > + double find_min_box_constrained ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + const double x_lower, + const double x_upper + ); + /*! + requires + - search_strategy == an object that defines a search strategy such as one + of the objects from dlib/optimization/optimization_search_strategies_abstract.h + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - f(x) must be a valid expression that evaluates to a double + - der(x) must be a valid expression that evaluates to the derivative of f() at x. + - is_col_vector(x) == true + - x_lower < x_upper + ensures + - This function is identical to find_min_box_constrained() as defined above + except that it takes x_lower and x_upper as doubles rather than column + vectors. In this case, all variables have the same lower bound of x_lower + and similarly have the same upper bound of x_upper. Therefore, this is just + a convenience function for calling find_max_box_constrained() when all + variables have the same bound constraints. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T, + typename EXP1, + typename EXP2 + > + double find_max_box_constrained ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + const matrix_exp& x_lower, + const matrix_exp& x_upper + ); + /*! + requires + - search_strategy == an object that defines a search strategy such as one + of the objects from dlib/optimization/optimization_search_strategies_abstract.h + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - f(x) must be a valid expression that evaluates to a double + - der(x) must be a valid expression that evaluates to the derivative of f() at x. + - is_col_vector(x) == true + - is_col_vector(x_lower) == true + - is_col_vector(x_upper) == true + - x.size() == x_lower.size() == x_upper.size() + (i.e. x, x_lower, and x_upper need to all be column vectors of the same dimensionality) + - min(x_upper-x_lower) >= 0 + (i.e. x_upper must contain upper bounds relative to x_lower) + ensures + - Performs a box constrained maximization of the function f() using the given + search_strategy and starting from the initial point x. That is, we try to + find the x value that maximizes f(x) but is also within the box constraints + specified by x_lower and x_upper. That is, we ensure that #x satisfies: + - min(#x - x_lower) >= 0 && min(x_upper - #x) >= 0 + - This function uses a backtracking line search along with a gradient projection + step to handle the box constraints. + - The function is optimized until stop_strategy decides that an acceptable + point has been found. + - #x == the value of x that was found to maximize f() within the given box + constraints. + - returns f(#x). + - The last call to f() will be made with f(#x). + - When calling f() and der(), the input passed to them will always be inside + the box constraints defined by x_lower and x_upper. + - When calling der(x), it will always be the case that the last call to f() was + made with the same x value. This means that you can reuse any intermediate + results from the previous call to f(x) inside der(x) rather than recomputing + them inside der(x). + - Note that this function solves the maximization problem by converting it + into a minimization problem. Therefore, the values of f and its derivative + reported to the stopping strategy will be negated. That is, stop_strategy + will see -f() and -der(). All this really means is that the status messages + from a stopping strategy in verbose mode will display a negated objective + value. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename search_strategy_type, + typename stop_strategy_type, + typename funct, + typename funct_der, + typename T + > + double find_max_box_constrained ( + search_strategy_type search_strategy, + stop_strategy_type stop_strategy, + const funct& f, + const funct_der& der, + T& x, + const double x_lower, + const double x_upper + ); + /*! + requires + - search_strategy == an object that defines a search strategy such as one + of the objects from dlib/optimization/optimization_search_strategies_abstract.h + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - f(x) must be a valid expression that evaluates to a double + - der(x) must be a valid expression that evaluates to the derivative of f() at x. + - is_col_vector(x) == true + - x_lower < x_upper + ensures + - This function is identical to find_max_box_constrained() as defined above + except that it takes x_lower and x_upper as doubles rather than column + vectors. In this case, all variables have the same lower bound of x_lower + and similarly have the same upper bound of x_upper. Therefore, this is just + a convenience function for calling find_max_box_constrained() when all + variables have the same bound constraints. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/optimization/optimization_bobyqa.h b/ml/dlib/dlib/optimization/optimization_bobyqa.h new file mode 100644 index 000000000..6fbc40c06 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_bobyqa.h @@ -0,0 +1,3423 @@ +// Copyright (C) 2009 M.J.D. Powell, Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATIOn_BOBYQA_Hh_ +#define DLIB_OPTIMIZATIOn_BOBYQA_Hh_ + +/* + The code in this file is derived from Powell's BOBYQA Fortran code. + It was created by running f2c on the original Fortran code and then + massaging the resulting C code into what you can see below. + + + The following paper, published in 2009 by Powell, describes the + detailed workings of the BOBYQA algorithm. + + The BOBYQA algorithm for bound constrained optimization + without derivatives by M.J.D. Powell +*/ + +#include +#include +#include + +#include "../matrix.h" +#include "optimization_bobyqa_abstract.h" +#include "optimization.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + class bobyqa_failure : public error { + public: bobyqa_failure(const std::string& s):error(s){} + }; + +// ---------------------------------------------------------------------------------------- + + class bobyqa_implementation + { + typedef long integer; + typedef double doublereal; + + public: + + template < + typename funct, + typename T, + typename U + > + double find_min ( + const funct& f, + T& x, + long npt, + const U& xl_, + const U& xu_, + const double rhobeg, + const double rhoend, + const long max_f_evals + ) const + { + const unsigned long n = x.size(); + const unsigned long w_size = (npt+5)*(npt+n)+3*n*(n+5)/2; + std::unique_ptr w(new doublereal[w_size]); + + // make these temporary matrices becuse U might be some + // kind of matrix_exp that doesn't support taking the address + // of the first element. + matrix xl(xl_); + matrix xu(xu_); + + + return bobyqa_ (f, + x.size(), + npt, + &x(0), + &xl(0), + &xu(0), + rhobeg, + rhoend, + max_f_evals, + w.get() ); + } + + private: + + + template + doublereal bobyqa_( + const funct& calfun, + const integer n, + const integer npt, + doublereal *x, + const doublereal *xl, + const doublereal *xu, + const doublereal rhobeg, + const doublereal rhoend, + const integer maxfun, + doublereal *w + ) const + { + + /* System generated locals */ + integer i__1; + doublereal d__1, d__2; + + /* Local variables */ + integer j, id_, np, iw, igo, ihq, ixb, ixa, ifv, isl, jsl, ipq, ivl, ixn, ixo, ixp, isu, jsu, ndim; + doublereal temp, zero; + integer ibmat, izmat; + + + /* This subroutine seeks the least value of a function of many variables, */ + /* by applying a trust region method that forms quadratic models by */ + /* interpolation. There is usually some freedom in the interpolation */ + /* conditions, which is taken up by minimizing the Frobenius norm of */ + /* the change to the second derivative of the model, beginning with the */ + /* zero matrix. The values of the variables are constrained by upper and */ + /* lower bounds. The arguments of the subroutine are as follows. */ + + /* N must be set to the number of variables and must be at least two. */ + /* NPT is the number of interpolation conditions. Its value must be in */ + /* the interval [N+2,(N+1)(N+2)/2]. Choices that exceed 2*N+1 are not */ + /* recommended. */ + /* Initial values of the variables must be set in X(1),X(2),...,X(N). They */ + /* will be changed to the values that give the least calculated F. */ + /* For I=1,2,...,N, XL(I) and XU(I) must provide the lower and upper */ + /* bounds, respectively, on X(I). The construction of quadratic models */ + /* requires XL(I) to be strictly less than XU(I) for each I. Further, */ + /* the contribution to a model from changes to the I-th variable is */ + /* damaged severely by rounding errors if XU(I)-XL(I) is too small. */ + /* RHOBEG and RHOEND must be set to the initial and final values of a trust */ + /* region radius, so both must be positive with RHOEND no greater than */ + /* RHOBEG. Typically, RHOBEG should be about one tenth of the greatest */ + /* expected change to a variable, while RHOEND should indicate the */ + /* accuracy that is required in the final values of the variables. An */ + /* error return occurs if any of the differences XU(I)-XL(I), I=1,...,N, */ + /* is less than 2*RHOBEG. */ + /* MAXFUN must be set to an upper bound on the number of calls of CALFUN. */ + /* The array W will be used for working space. Its length must be at least */ + /* (NPT+5)*(NPT+N)+3*N*(N+5)/2. */ + + /* Parameter adjustments */ + --w; + --xu; + --xl; + --x; + + /* Function Body */ + np = n + 1; + + /* Return if the value of NPT is unacceptable. */ + if (npt < n + 2 || npt > (n + 2) * np / 2) { + throw bobyqa_failure("Return from BOBYQA because NPT is not in the required interval"); + //goto L40; + } + + /* Partition the working space array, so that different parts of it can */ + /* be treated separately during the calculation of BOBYQB. The partition */ + /* requires the first (NPT+2)*(NPT+N)+3*N*(N+5)/2 elements of W plus the */ + /* space that is taken by the last array in the argument list of BOBYQB. */ + + ndim = npt + n; + ixb = 1; + ixp = ixb + n; + ifv = ixp + n * npt; + ixo = ifv + npt; + igo = ixo + n; + ihq = igo + n; + ipq = ihq + n * np / 2; + ibmat = ipq + npt; + izmat = ibmat + ndim * n; + isl = izmat + npt * (npt - np); + isu = isl + n; + ixn = isu + n; + ixa = ixn + n; + id_ = ixa + n; + ivl = id_ + n; + iw = ivl + ndim; + + /* Return if there is insufficient space between the bounds. Modify the */ + /* initial X if necessary in order to avoid conflicts between the bounds */ + /* and the construction of the first quadratic model. The lower and upper */ + /* bounds on moves from the updated X are set now, in the ISL and ISU */ + /* partitions of W, in order to provide useful and exact information about */ + /* components of X that become within distance RHOBEG from their bounds. */ + + zero = 0.; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + temp = xu[j] - xl[j]; + if (temp < rhobeg + rhobeg) { + throw bobyqa_failure("Return from BOBYQA because one of the differences in x_lower and x_upper is less than 2*rho_begin"); + //goto L40; + } + jsl = isl + j - 1; + jsu = jsl + n; + w[jsl] = xl[j] - x[j]; + w[jsu] = xu[j] - x[j]; + if (w[jsl] >= -(rhobeg)) { + if (w[jsl] >= zero) { + x[j] = xl[j]; + w[jsl] = zero; + w[jsu] = temp; + } else { + x[j] = xl[j] + rhobeg; + w[jsl] = -(rhobeg); + /* Computing MAX */ + d__1 = xu[j] - x[j]; + w[jsu] = std::max(d__1,rhobeg); + } + } else if (w[jsu] <= rhobeg) { + if (w[jsu] <= zero) { + x[j] = xu[j]; + w[jsl] = -temp; + w[jsu] = zero; + } else { + x[j] = xu[j] - rhobeg; + /* Computing MIN */ + d__1 = xl[j] - x[j], d__2 = -(rhobeg); + w[jsl] = std::min(d__1,d__2); + w[jsu] = rhobeg; + } + } + /* L30: */ + } + + /* Make the call of BOBYQB. */ + + return bobyqb_(calfun, n, npt, &x[1], &xl[1], &xu[1], rhobeg, rhoend, maxfun, &w[ + ixb], &w[ixp], &w[ifv], &w[ixo], &w[igo], &w[ihq], &w[ipq], &w[ + ibmat], &w[izmat], ndim, &w[isl], &w[isu], &w[ixn], &w[ixa], &w[ + id_], &w[ivl], &w[iw]); + //L40: + ; + } /* bobyqa_ */ + + // ---------------------------------------------------------------------------------------- + + template + doublereal bobyqb_( + const funct& calfun, + const integer n, + const integer npt, + doublereal *x, + const doublereal *xl, + const doublereal *xu, + const doublereal rhobeg, + const doublereal rhoend, + const integer maxfun, + doublereal *xbase, + doublereal *xpt, + doublereal *fval, + doublereal *xopt, + doublereal *gopt, + doublereal *hq, + doublereal *pq, + doublereal *bmat, + doublereal *zmat, + const integer ndim, + doublereal *sl, + doublereal *su, + doublereal *xnew, + doublereal *xalt, + doublereal *d__, + doublereal *vlag, + doublereal *w + ) const + { + /* System generated locals */ + integer xpt_dim1, xpt_offset, bmat_dim1, bmat_offset, zmat_dim1, + zmat_offset, i__1, i__2, i__3; + doublereal d__1, d__2, d__3, d__4; + + /* Local variables */ + doublereal f = 0; + integer i__, j, k, ih, nf, jj, nh, ip, jp; + doublereal dx; + integer np; + doublereal den = 0, one = 0, ten = 0, dsq = 0, rho = 0, sum = 0, two = 0, diff = 0, half = 0, beta = 0, gisq = 0; + integer knew = 0; + doublereal temp, suma, sumb, bsum, fopt; + integer kopt = 0, nptm; + doublereal zero, curv; + integer ksav; + doublereal gqsq = 0, dist = 0, sumw = 0, sumz = 0, diffa = 0, diffb = 0, diffc = 0, hdiag = 0; + integer kbase; + doublereal alpha = 0, delta = 0, adelt = 0, denom = 0, fsave = 0, bdtol = 0, delsq = 0; + integer nresc, nfsav; + doublereal ratio = 0, dnorm = 0, vquad = 0, pqold = 0, tenth = 0; + integer itest; + doublereal sumpq, scaden; + doublereal errbig, cauchy, fracsq, biglsq, densav; + doublereal bdtest; + doublereal crvmin, frhosq; + doublereal distsq; + integer ntrits; + doublereal xoptsq; + + + + /* The arguments N, NPT, X, XL, XU, RHOBEG, RHOEND, IPRINT and MAXFUN */ + /* are identical to the corresponding arguments in SUBROUTINE BOBYQA. */ + /* XBASE holds a shift of origin that should reduce the contributions */ + /* from rounding errors to values of the model and Lagrange functions. */ + /* XPT is a two-dimensional array that holds the coordinates of the */ + /* interpolation points relative to XBASE. */ + /* FVAL holds the values of F at the interpolation points. */ + /* XOPT is set to the displacement from XBASE of the trust region centre. */ + /* GOPT holds the gradient of the quadratic model at XBASE+XOPT. */ + /* HQ holds the explicit second derivatives of the quadratic model. */ + /* PQ contains the parameters of the implicit second derivatives of the */ + /* quadratic model. */ + /* BMAT holds the last N columns of H. */ + /* ZMAT holds the factorization of the leading NPT by NPT submatrix of H, */ + /* this factorization being ZMAT times ZMAT^T, which provides both the */ + /* correct rank and positive semi-definiteness. */ + /* NDIM is the first dimension of BMAT and has the value NPT+N. */ + /* SL and SU hold the differences XL-XBASE and XU-XBASE, respectively. */ + /* All the components of every XOPT are going to satisfy the bounds */ + /* SL(I) .LEQ. XOPT(I) .LEQ. SU(I), with appropriate equalities when */ + /* XOPT is on a constraint boundary. */ + /* XNEW is chosen by SUBROUTINE TRSBOX or ALTMOV. Usually XBASE+XNEW is the */ + /* vector of variables for the next call of CALFUN. XNEW also satisfies */ + /* the SL and SU constraints in the way that has just been mentioned. */ + /* XALT is an alternative to XNEW, chosen by ALTMOV, that may replace XNEW */ + /* in order to increase the denominator in the updating of UPDATE. */ + /* D is reserved for a trial step from XOPT, which is usually XNEW-XOPT. */ + /* VLAG contains the values of the Lagrange functions at a new point X. */ + /* They are part of a product that requires VLAG to be of length NDIM. */ + /* W is a one-dimensional array that is used for working space. Its length */ + /* must be at least 3*NDIM = 3*(NPT+N). */ + + /* Set some constants. */ + + /* Parameter adjustments */ + zmat_dim1 = npt; + zmat_offset = 1 + zmat_dim1; + zmat -= zmat_offset; + xpt_dim1 = npt; + xpt_offset = 1 + xpt_dim1; + xpt -= xpt_offset; + --x; + --xl; + --xu; + --xbase; + --fval; + --xopt; + --gopt; + --hq; + --pq; + bmat_dim1 = ndim; + bmat_offset = 1 + bmat_dim1; + bmat -= bmat_offset; + --sl; + --su; + --xnew; + --xalt; + --d__; + --vlag; + --w; + + /* Function Body */ + half = .5; + one = 1.; + ten = 10.; + tenth = .1; + two = 2.; + zero = 0.; + np = n + 1; + nptm = npt - np; + nh = n * np / 2; + + /* The call of PRELIM sets the elements of XBASE, XPT, FVAL, GOPT, HQ, PQ, */ + /* BMAT and ZMAT for the first iteration, with the corresponding values of */ + /* of NF and KOPT, which are the number of calls of CALFUN so far and the */ + /* index of the interpolation point at the trust region centre. Then the */ + /* initial XOPT is set too. The branch to label 720 occurs if MAXFUN is */ + /* less than NPT. GOPT will be updated if KOPT is different from KBASE. */ + + prelim_(calfun, n, npt, &x[1], &xl[1], &xu[1], rhobeg, maxfun, &xbase[1], + &xpt[xpt_offset], &fval[1], &gopt[1], &hq[1], &pq[1], &bmat[bmat_offset], + &zmat[zmat_offset], ndim, &sl[1], &su[1], nf, kopt); + xoptsq = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + xopt[i__] = xpt[kopt + i__ * xpt_dim1]; + /* L10: */ + /* Computing 2nd power */ + d__1 = xopt[i__]; + xoptsq += d__1 * d__1; + } + fsave = fval[1]; + if (nf < npt) { + throw bobyqa_failure("Return from BOBYQA because the objective function has been called max_f_evals times."); + //goto L720; + } + kbase = 1; + + /* Complete the settings that are required for the iterative procedure. */ + + rho = rhobeg; + delta = rho; + nresc = nf; + ntrits = 0; + diffa = zero; + diffb = zero; + itest = 0; + nfsav = nf; + + /* Update GOPT if necessary before the first iteration and after each */ + /* call of RESCUE that makes a call of CALFUN. */ + +L20: + if (kopt != kbase) { + ih = 0; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + i__2 = j; + for (i__ = 1; i__ <= i__2; ++i__) { + ++ih; + if (i__ < j) { + gopt[j] += hq[ih] * xopt[i__]; + } + /* L30: */ + gopt[i__] += hq[ih] * xopt[j]; + } + } + if (nf > npt) { + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + temp = zero; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + /* L40: */ + temp += xpt[k + j * xpt_dim1] * xopt[j]; + } + temp = pq[k] * temp; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + /* L50: */ + gopt[i__] += temp * xpt[k + i__ * xpt_dim1]; + } + } + } + } + + /* Generate the next point in the trust region that provides a small value */ + /* of the quadratic model subject to the constraints on the variables. */ + /* The integer NTRITS is set to the number "trust region" iterations that */ + /* have occurred since the last "alternative" iteration. If the length */ + /* of XNEW-XOPT is less than HALF*RHO, however, then there is a branch to */ + /* label 650 or 680 with NTRITS=-1, instead of calculating F at XNEW. */ + +L60: + trsbox_(n, npt, &xpt[xpt_offset], &xopt[1], &gopt[1], &hq[1], &pq[1], &sl[1], + &su[1], delta, &xnew[1], &d__[1], &w[1], &w[np], &w[np + n], + &w[np + (n << 1)], &w[np + n * 3], &dsq, &crvmin); + /* Computing MIN */ + d__1 = delta, d__2 = std::sqrt(dsq); + dnorm = std::min(d__1,d__2); + if (dnorm < half * rho) { + ntrits = -1; + /* Computing 2nd power */ + d__1 = ten * rho; + distsq = d__1 * d__1; + if (nf <= nfsav + 2) { + goto L650; + } + + /* The following choice between labels 650 and 680 depends on whether or */ + /* not our work with the current RHO seems to be complete. Either RHO is */ + /* decreased or termination occurs if the errors in the quadratic model at */ + /* the last three interpolation points compare favourably with predictions */ + /* of likely improvements to the model within distance HALF*RHO of XOPT. */ + + /* Computing MAX */ + d__1 = std::max(diffa,diffb); + errbig = std::max(d__1,diffc); + frhosq = rho * .125 * rho; + if (crvmin > zero && errbig > frhosq * crvmin) { + goto L650; + } + bdtol = errbig / rho; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + bdtest = bdtol; + if (xnew[j] == sl[j]) { + bdtest = w[j]; + } + if (xnew[j] == su[j]) { + bdtest = -w[j]; + } + if (bdtest < bdtol) { + curv = hq[(j + j * j) / 2]; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + /* L70: */ + /* Computing 2nd power */ + d__1 = xpt[k + j * xpt_dim1]; + curv += pq[k] * (d__1 * d__1); + } + bdtest += half * curv * rho; + if (bdtest < bdtol) { + goto L650; + } + } + /* L80: */ + } + goto L680; + } + ++ntrits; + + /* Severe cancellation is likely to occur if XOPT is too far from XBASE. */ + /* If the following test holds, then XBASE is shifted so that XOPT becomes */ + /* zero. The appropriate changes are made to BMAT and to the second */ + /* derivatives of the current model, beginning with the changes to BMAT */ + /* that do not depend on ZMAT. VLAG is used temporarily for working space. */ + +L90: + if (dsq <= xoptsq * .001) { + fracsq = xoptsq * .25; + sumpq = zero; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + sumpq += pq[k]; + sum = -half * xoptsq; + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + /* L100: */ + sum += xpt[k + i__ * xpt_dim1] * xopt[i__]; + } + w[npt + k] = sum; + temp = fracsq - half * sum; + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + w[i__] = bmat[k + i__ * bmat_dim1]; + vlag[i__] = sum * xpt[k + i__ * xpt_dim1] + temp * xopt[i__]; + ip = npt + i__; + i__3 = i__; + for (j = 1; j <= i__3; ++j) { + /* L110: */ + bmat[ip + j * bmat_dim1] = bmat[ip + j * bmat_dim1] + w[ + i__] * vlag[j] + vlag[i__] * w[j]; + } + } + } + + /* Then the revisions of BMAT that depend on ZMAT are calculated. */ + + i__3 = nptm; + for (jj = 1; jj <= i__3; ++jj) { + sumz = zero; + sumw = zero; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + sumz += zmat[k + jj * zmat_dim1]; + vlag[k] = w[npt + k] * zmat[k + jj * zmat_dim1]; + /* L120: */ + sumw += vlag[k]; + } + i__2 = n; + for (j = 1; j <= i__2; ++j) { + sum = (fracsq * sumz - half * sumw) * xopt[j]; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L130: */ + sum += vlag[k] * xpt[k + j * xpt_dim1]; + } + w[j] = sum; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L140: */ + bmat[k + j * bmat_dim1] += sum * zmat[k + jj * zmat_dim1]; + } + } + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + ip = i__ + npt; + temp = w[i__]; + i__2 = i__; + for (j = 1; j <= i__2; ++j) { + /* L150: */ + bmat[ip + j * bmat_dim1] += temp * w[j]; + } + } + } + + /* The following instructions complete the shift, including the changes */ + /* to the second derivative parameters of the quadratic model. */ + + ih = 0; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + w[j] = -half * sumpq * xopt[j]; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + w[j] += pq[k] * xpt[k + j * xpt_dim1]; + /* L160: */ + xpt[k + j * xpt_dim1] -= xopt[j]; + } + i__1 = j; + for (i__ = 1; i__ <= i__1; ++i__) { + ++ih; + hq[ih] = hq[ih] + w[i__] * xopt[j] + xopt[i__] * w[j]; + /* L170: */ + bmat[npt + i__ + j * bmat_dim1] = bmat[npt + j + i__ * + bmat_dim1]; + } + } + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + xbase[i__] += xopt[i__]; + xnew[i__] -= xopt[i__]; + sl[i__] -= xopt[i__]; + su[i__] -= xopt[i__]; + /* L180: */ + xopt[i__] = zero; + } + xoptsq = zero; + } + if (ntrits == 0) { + goto L210; + } + goto L230; + + /* XBASE is also moved to XOPT by a call of RESCUE. This calculation is */ + /* more expensive than the previous shift, because new matrices BMAT and */ + /* ZMAT are generated from scratch, which may include the replacement of */ + /* interpolation points whose positions seem to be causing near linear */ + /* dependence in the interpolation conditions. Therefore RESCUE is called */ + /* only if rounding errors have reduced by at least a factor of two the */ + /* denominator of the formula for updating the H matrix. It provides a */ + /* useful safeguard, but is not invoked in most applications of BOBYQA. */ + +L190: + nfsav = nf; + kbase = kopt; + rescue_(calfun, n, npt, &xl[1], &xu[1], maxfun, &xbase[1], &xpt[ + xpt_offset], &fval[1], &xopt[1], &gopt[1], &hq[1], &pq[1], &bmat[ + bmat_offset], &zmat[zmat_offset], ndim, &sl[1], &su[1], nf, delta, + kopt, &vlag[1], &w[1], &w[n + np], &w[ndim + np]); + + /* XOPT is updated now in case the branch below to label 720 is taken. */ + /* Any updating of GOPT occurs after the branch below to label 20, which */ + /* leads to a trust region iteration as does the branch to label 60. */ + + xoptsq = zero; + if (kopt != kbase) { + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + xopt[i__] = xpt[kopt + i__ * xpt_dim1]; + /* L200: */ + /* Computing 2nd power */ + d__1 = xopt[i__]; + xoptsq += d__1 * d__1; + } + } + if (nf < 0) { + nf = maxfun; + throw bobyqa_failure("Return from BOBYQA because the objective function has been called max_f_evals times."); + //goto L720; + } + nresc = nf; + if (nfsav < nf) { + nfsav = nf; + goto L20; + } + if (ntrits > 0) { + goto L60; + } + + /* Pick two alternative vectors of variables, relative to XBASE, that */ + /* are suitable as new positions of the KNEW-th interpolation point. */ + /* Firstly, XNEW is set to the point on a line through XOPT and another */ + /* interpolation point that minimizes the predicted value of the next */ + /* denominator, subject to ||XNEW - XOPT|| .LEQ. ADELT and to the SL */ + /* and SU bounds. Secondly, XALT is set to the best feasible point on */ + /* a constrained version of the Cauchy step of the KNEW-th Lagrange */ + /* function, the corresponding value of the square of this function */ + /* being returned in CAUCHY. The choice between these alternatives is */ + /* going to be made when the denominator is calculated. */ + +L210: + altmov_(n, npt, &xpt[xpt_offset], &xopt[1], &bmat[bmat_offset], &zmat[zmat_offset], + ndim, &sl[1], &su[1], kopt, knew, adelt, &xnew[1], + &xalt[1], alpha, cauchy, &w[1], &w[np], &w[ndim + 1]); + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + /* L220: */ + d__[i__] = xnew[i__] - xopt[i__]; + } + + /* Calculate VLAG and BETA for the current choice of D. The scalar */ + /* product of D with XPT(K,.) is going to be held in W(NPT+K) for */ + /* use when VQUAD is calculated. */ + +L230: + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + suma = zero; + sumb = zero; + sum = zero; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + suma += xpt[k + j * xpt_dim1] * d__[j]; + sumb += xpt[k + j * xpt_dim1] * xopt[j]; + /* L240: */ + sum += bmat[k + j * bmat_dim1] * d__[j]; + } + w[k] = suma * (half * suma + sumb); + vlag[k] = sum; + /* L250: */ + w[npt + k] = suma; + } + beta = zero; + i__1 = nptm; + for (jj = 1; jj <= i__1; ++jj) { + sum = zero; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + /* L260: */ + sum += zmat[k + jj * zmat_dim1] * w[k]; + } + beta -= sum * sum; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + /* L270: */ + vlag[k] += sum * zmat[k + jj * zmat_dim1]; + } + } + dsq = zero; + bsum = zero; + dx = zero; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + /* Computing 2nd power */ + d__1 = d__[j]; + dsq += d__1 * d__1; + sum = zero; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L280: */ + sum += w[k] * bmat[k + j * bmat_dim1]; + } + bsum += sum * d__[j]; + jp = npt + j; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + /* L290: */ + sum += bmat[jp + i__ * bmat_dim1] * d__[i__]; + } + vlag[jp] = sum; + bsum += sum * d__[j]; + /* L300: */ + dx += d__[j] * xopt[j]; + } + beta = dx * dx + dsq * (xoptsq + dx + dx + half * dsq) + beta - bsum; + vlag[kopt] += one; + + /* If NTRITS is zero, the denominator may be increased by replacing */ + /* the step D of ALTMOV by a Cauchy step. Then RESCUE may be called if */ + /* rounding errors have damaged the chosen denominator. */ + + if (ntrits == 0) { + /* Computing 2nd power */ + d__1 = vlag[knew]; + denom = d__1 * d__1 + alpha * beta; + if (denom < cauchy && cauchy > zero) { + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + xnew[i__] = xalt[i__]; + /* L310: */ + d__[i__] = xnew[i__] - xopt[i__]; + } + cauchy = zero; + goto L230; + } + /* Computing 2nd power */ + d__1 = vlag[knew]; + if (denom <= half * (d__1 * d__1)) { + if (nf > nresc) { + goto L190; + } + throw bobyqa_failure("Return from BOBYQA because of much cancellation in a denominator."); + //goto L720; + } + + /* Alternatively, if NTRITS is positive, then set KNEW to the index of */ + /* the next interpolation point to be deleted to make room for a trust */ + /* region step. Again RESCUE may be called if rounding errors have damaged */ + /* the chosen denominator, which is the reason for attempting to select */ + /* KNEW before calculating the next value of the objective function. */ + + } else { + delsq = delta * delta; + scaden = zero; + biglsq = zero; + knew = 0; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + if (k == kopt) { + goto L350; + } + hdiag = zero; + i__1 = nptm; + for (jj = 1; jj <= i__1; ++jj) { + /* L330: */ + /* Computing 2nd power */ + d__1 = zmat[k + jj * zmat_dim1]; + hdiag += d__1 * d__1; + } + /* Computing 2nd power */ + d__1 = vlag[k]; + den = beta * hdiag + d__1 * d__1; + distsq = zero; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + /* L340: */ + /* Computing 2nd power */ + d__1 = xpt[k + j * xpt_dim1] - xopt[j]; + distsq += d__1 * d__1; + } + /* Computing MAX */ + /* Computing 2nd power */ + d__3 = distsq / delsq; + d__1 = one, d__2 = d__3 * d__3; + temp = std::max(d__1,d__2); + if (temp * den > scaden) { + scaden = temp * den; + knew = k; + denom = den; + } + /* Computing MAX */ + /* Computing 2nd power */ + d__3 = vlag[k]; + d__1 = biglsq, d__2 = temp * (d__3 * d__3); + biglsq = std::max(d__1,d__2); +L350: + ; + } + if (scaden <= half * biglsq) { + if (nf > nresc) { + goto L190; + } + throw bobyqa_failure("Return from BOBYQA because of much cancellation in a denominator."); + //goto L720; + } + } + + /* Put the variables for the next calculation of the objective function */ + /* in XNEW, with any adjustments for the bounds. */ + + + /* Calculate the value of the objective function at XBASE+XNEW, unless */ + /* the limit on the number of calculations of F has been reached. */ + +L360: + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + /* Computing MIN */ + /* Computing MAX */ + d__3 = xl[i__], d__4 = xbase[i__] + xnew[i__]; + d__1 = std::max(d__3,d__4), d__2 = xu[i__]; + x[i__] = std::min(d__1,d__2); + if (xnew[i__] == sl[i__]) { + x[i__] = xl[i__]; + } + if (xnew[i__] == su[i__]) { + x[i__] = xu[i__]; + } + /* L380: */ + } + if (nf >= maxfun) { + throw bobyqa_failure("Return from BOBYQA because the objective function has been called max_f_evals times."); + //goto L720; + } + ++nf; + f = calfun(mat(&x[1], n)); + if (ntrits == -1) { + fsave = f; + goto L720; + } + + /* Use the quadratic model to predict the change in F due to the step D, */ + /* and set DIFF to the error of this prediction. */ + + fopt = fval[kopt]; + vquad = zero; + ih = 0; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + vquad += d__[j] * gopt[j]; + i__1 = j; + for (i__ = 1; i__ <= i__1; ++i__) { + ++ih; + temp = d__[i__] * d__[j]; + if (i__ == j) { + temp = half * temp; + } + /* L410: */ + vquad += hq[ih] * temp; + } + } + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L420: */ + /* Computing 2nd power */ + d__1 = w[npt + k]; + vquad += half * pq[k] * (d__1 * d__1); + } + diff = f - fopt - vquad; + diffc = diffb; + diffb = diffa; + diffa = std::abs(diff); + if (dnorm > rho) { + nfsav = nf; + } + + /* Pick the next value of DELTA after a trust region step. */ + + if (ntrits > 0) { + if (vquad >= zero) { + throw bobyqa_failure("Return from BOBYQA because a trust region step has failed to reduce Q."); + //goto L720; + } + ratio = (f - fopt) / vquad; + if (ratio <= tenth) { + /* Computing MIN */ + d__1 = half * delta; + delta = std::min(d__1,dnorm); + } else if (ratio <= .7) { + /* Computing MAX */ + d__1 = half * delta; + delta = std::max(d__1,dnorm); + } else { + /* Computing MAX */ + d__1 = half * delta, d__2 = dnorm + dnorm; + delta = std::max(d__1,d__2); + } + if (delta <= rho * 1.5) { + delta = rho; + } + + /* Recalculate KNEW and DENOM if the new F is less than FOPT. */ + + if (f < fopt) { + ksav = knew; + densav = denom; + delsq = delta * delta; + scaden = zero; + biglsq = zero; + knew = 0; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + hdiag = zero; + i__2 = nptm; + for (jj = 1; jj <= i__2; ++jj) { + /* L440: */ + /* Computing 2nd power */ + d__1 = zmat[k + jj * zmat_dim1]; + hdiag += d__1 * d__1; + } + /* Computing 2nd power */ + d__1 = vlag[k]; + den = beta * hdiag + d__1 * d__1; + distsq = zero; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + /* L450: */ + /* Computing 2nd power */ + d__1 = xpt[k + j * xpt_dim1] - xnew[j]; + distsq += d__1 * d__1; + } + /* Computing MAX */ + /* Computing 2nd power */ + d__3 = distsq / delsq; + d__1 = one, d__2 = d__3 * d__3; + temp = std::max(d__1,d__2); + if (temp * den > scaden) { + scaden = temp * den; + knew = k; + denom = den; + } + /* L460: */ + /* Computing MAX */ + /* Computing 2nd power */ + d__3 = vlag[k]; + d__1 = biglsq, d__2 = temp * (d__3 * d__3); + biglsq = std::max(d__1,d__2); + } + if (scaden <= half * biglsq) { + knew = ksav; + denom = densav; + } + } + } + + /* Update BMAT and ZMAT, so that the KNEW-th interpolation point can be */ + /* moved. Also update the second derivative terms of the model. */ + + update_(n, npt, &bmat[bmat_offset], &zmat[zmat_offset], ndim, &vlag[1], + beta, denom, knew, &w[1]); + ih = 0; + pqold = pq[knew]; + pq[knew] = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + temp = pqold * xpt[knew + i__ * xpt_dim1]; + i__2 = i__; + for (j = 1; j <= i__2; ++j) { + ++ih; + /* L470: */ + hq[ih] += temp * xpt[knew + j * xpt_dim1]; + } + } + i__2 = nptm; + for (jj = 1; jj <= i__2; ++jj) { + temp = diff * zmat[knew + jj * zmat_dim1]; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L480: */ + pq[k] += temp * zmat[k + jj * zmat_dim1]; + } + } + + /* Include the new interpolation point, and make the changes to GOPT at */ + /* the old XOPT that are caused by the updating of the quadratic model. */ + + fval[knew] = f; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + xpt[knew + i__ * xpt_dim1] = xnew[i__]; + /* L490: */ + w[i__] = bmat[knew + i__ * bmat_dim1]; + } + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + suma = zero; + i__2 = nptm; + for (jj = 1; jj <= i__2; ++jj) { + /* L500: */ + suma += zmat[knew + jj * zmat_dim1] * zmat[k + jj * zmat_dim1]; + } + sumb = zero; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + /* L510: */ + sumb += xpt[k + j * xpt_dim1] * xopt[j]; + } + temp = suma * sumb; + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + /* L520: */ + w[i__] += temp * xpt[k + i__ * xpt_dim1]; + } + } + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + /* L530: */ + gopt[i__] += diff * w[i__]; + } + + /* Update XOPT, GOPT and KOPT if the new calculated F is less than FOPT. */ + + if (f < fopt) { + kopt = knew; + xoptsq = zero; + ih = 0; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + xopt[j] = xnew[j]; + /* Computing 2nd power */ + d__1 = xopt[j]; + xoptsq += d__1 * d__1; + i__1 = j; + for (i__ = 1; i__ <= i__1; ++i__) { + ++ih; + if (i__ < j) { + gopt[j] += hq[ih] * d__[i__]; + } + /* L540: */ + gopt[i__] += hq[ih] * d__[j]; + } + } + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + temp = zero; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + /* L550: */ + temp += xpt[k + j * xpt_dim1] * d__[j]; + } + temp = pq[k] * temp; + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + /* L560: */ + gopt[i__] += temp * xpt[k + i__ * xpt_dim1]; + } + } + } + + /* Calculate the parameters of the least Frobenius norm interpolant to */ + /* the current data, the gradient of this interpolant at XOPT being put */ + /* into VLAG(NPT+I), I=1,2,...,N. */ + + if (ntrits > 0) { + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + vlag[k] = fval[k] - fval[kopt]; + /* L570: */ + w[k] = zero; + } + i__2 = nptm; + for (j = 1; j <= i__2; ++j) { + sum = zero; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L580: */ + sum += zmat[k + j * zmat_dim1] * vlag[k]; + } + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L590: */ + w[k] += sum * zmat[k + j * zmat_dim1]; + } + } + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + sum = zero; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + /* L600: */ + sum += xpt[k + j * xpt_dim1] * xopt[j]; + } + w[k + npt] = w[k]; + /* L610: */ + w[k] = sum * w[k]; + } + gqsq = zero; + gisq = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + sum = zero; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + /* L620: */ + sum = sum + bmat[k + i__ * bmat_dim1] * vlag[k] + xpt[k + i__ + * xpt_dim1] * w[k]; + } + if (xopt[i__] == sl[i__]) { + /* Computing MIN */ + d__2 = zero, d__3 = gopt[i__]; + /* Computing 2nd power */ + d__1 = std::min(d__2,d__3); + gqsq += d__1 * d__1; + /* Computing 2nd power */ + d__1 = std::min(zero,sum); + gisq += d__1 * d__1; + } else if (xopt[i__] == su[i__]) { + /* Computing MAX */ + d__2 = zero, d__3 = gopt[i__]; + /* Computing 2nd power */ + d__1 = std::max(d__2,d__3); + gqsq += d__1 * d__1; + /* Computing 2nd power */ + d__1 = std::max(zero,sum); + gisq += d__1 * d__1; + } else { + /* Computing 2nd power */ + d__1 = gopt[i__]; + gqsq += d__1 * d__1; + gisq += sum * sum; + } + /* L630: */ + vlag[npt + i__] = sum; + } + + /* Test whether to replace the new quadratic model by the least Frobenius */ + /* norm interpolant, making the replacement if the test is satisfied. */ + + ++itest; + if (gqsq < ten * gisq) { + itest = 0; + } + if (itest >= 3) { + i__1 = std::max(npt,nh); + for (i__ = 1; i__ <= i__1; ++i__) { + if (i__ <= n) { + gopt[i__] = vlag[npt + i__]; + } + if (i__ <= npt) { + pq[i__] = w[npt + i__]; + } + if (i__ <= nh) { + hq[i__] = zero; + } + itest = 0; + /* L640: */ + } + } + } + + /* If a trust region step has provided a sufficient decrease in F, then */ + /* branch for another trust region calculation. The case NTRITS=0 occurs */ + /* when the new interpolation point was reached by an alternative step. */ + + if (ntrits == 0) { + goto L60; + } + if (f <= fopt + tenth * vquad) { + goto L60; + } + + /* Alternatively, find out if the interpolation points are close enough */ + /* to the best point so far. */ + + /* Computing MAX */ + /* Computing 2nd power */ + d__3 = two * delta; + /* Computing 2nd power */ + d__4 = ten * rho; + d__1 = d__3 * d__3, d__2 = d__4 * d__4; + distsq = std::max(d__1,d__2); +L650: + knew = 0; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + sum = zero; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + /* L660: */ + /* Computing 2nd power */ + d__1 = xpt[k + j * xpt_dim1] - xopt[j]; + sum += d__1 * d__1; + } + if (sum > distsq) { + knew = k; + distsq = sum; + } + /* L670: */ + } + + /* If KNEW is positive, then ALTMOV finds alternative new positions for */ + /* the KNEW-th interpolation point within distance ADELT of XOPT. It is */ + /* reached via label 90. Otherwise, there is a branch to label 60 for */ + /* another trust region iteration, unless the calculations with the */ + /* current RHO are complete. */ + + if (knew > 0) { + dist = std::sqrt(distsq); + if (ntrits == -1) { + /* Computing MIN */ + d__1 = tenth * delta, d__2 = half * dist; + delta = std::min(d__1,d__2); + if (delta <= rho * 1.5) { + delta = rho; + } + } + ntrits = 0; + /* Computing MAX */ + /* Computing MIN */ + d__2 = tenth * dist; + d__1 = std::min(d__2,delta); + adelt = std::max(d__1,rho); + dsq = adelt * adelt; + goto L90; + } + if (ntrits == -1) { + goto L680; + } + if (ratio > zero) { + goto L60; + } + if (std::max(delta,dnorm) > rho) { + goto L60; + } + + /* The calculations with the current value of RHO are complete. Pick the */ + /* next values of RHO and DELTA. */ + +L680: + if (rho > rhoend) { + delta = half * rho; + ratio = rho / rhoend; + if (ratio <= 16.) { + rho = rhoend; + } else if (ratio <= 250.) { + rho = std::sqrt(ratio) * rhoend; + } else { + rho = tenth * rho; + } + delta = std::max(delta,rho); + ntrits = 0; + nfsav = nf; + goto L60; + } + + /* Return from the calculation, after another Newton-Raphson step, if */ + /* it is too short to have been tried before. */ + + if (ntrits == -1) { + goto L360; + } +L720: + if (fval[kopt] <= fsave) { + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + /* Computing MIN */ + /* Computing MAX */ + d__3 = xl[i__], d__4 = xbase[i__] + xopt[i__]; + d__1 = std::max(d__3,d__4), d__2 = xu[i__]; + x[i__] = std::min(d__1,d__2); + if (xopt[i__] == sl[i__]) { + x[i__] = xl[i__]; + } + if (xopt[i__] == su[i__]) { + x[i__] = xu[i__]; + } + /* L730: */ + } + f = fval[kopt]; + } + + return f; + } /* bobyqb_ */ + + // ---------------------------------------------------------------------------------------- + + void altmov_( + const integer n, + const integer npt, + const doublereal *xpt, + const doublereal *xopt, + const doublereal *bmat, + const doublereal *zmat, + const integer ndim, + const doublereal *sl, + const doublereal *su, + const integer kopt, + const integer knew, + const doublereal adelt, + doublereal *xnew, + doublereal *xalt, + doublereal& alpha, + doublereal& cauchy, + doublereal *glag, + doublereal *hcol, + doublereal *w + ) const + { + /* System generated locals */ + integer xpt_dim1, xpt_offset, bmat_dim1, bmat_offset, zmat_dim1, + zmat_offset, i__1, i__2; + doublereal d__1, d__2, d__3, d__4; + + + /* Local variables */ + integer i__, j, k; + doublereal ha, gw, one, diff, half; + integer ilbd, isbd; + doublereal slbd; + integer iubd; + doublereal vlag, subd, temp; + integer ksav = 0; + doublereal step = 0, zero = 0, curv = 0; + integer iflag; + doublereal scale = 0, csave = 0, tempa = 0, tempb = 0, tempd = 0, const__ = 0, sumin = 0, + ggfree = 0; + integer ibdsav = 0; + doublereal dderiv = 0, bigstp = 0, predsq = 0, presav = 0, distsq = 0, stpsav = 0, wfixsq = 0, wsqsav = 0; + + + /* The arguments N, NPT, XPT, XOPT, BMAT, ZMAT, NDIM, SL and SU all have */ + /* the same meanings as the corresponding arguments of BOBYQB. */ + /* KOPT is the index of the optimal interpolation point. */ + /* KNEW is the index of the interpolation point that is going to be moved. */ + /* ADELT is the current trust region bound. */ + /* XNEW will be set to a suitable new position for the interpolation point */ + /* XPT(KNEW,.). Specifically, it satisfies the SL, SU and trust region */ + /* bounds and it should provide a large denominator in the next call of */ + /* UPDATE. The step XNEW-XOPT from XOPT is restricted to moves along the */ + /* straight lines through XOPT and another interpolation point. */ + /* XALT also provides a large value of the modulus of the KNEW-th Lagrange */ + /* function subject to the constraints that have been mentioned, its main */ + /* difference from XNEW being that XALT-XOPT is a constrained version of */ + /* the Cauchy step within the trust region. An exception is that XALT is */ + /* not calculated if all components of GLAG (see below) are zero. */ + /* ALPHA will be set to the KNEW-th diagonal element of the H matrix. */ + /* CAUCHY will be set to the square of the KNEW-th Lagrange function at */ + /* the step XALT-XOPT from XOPT for the vector XALT that is returned, */ + /* except that CAUCHY is set to zero if XALT is not calculated. */ + /* GLAG is a working space vector of length N for the gradient of the */ + /* KNEW-th Lagrange function at XOPT. */ + /* HCOL is a working space vector of length NPT for the second derivative */ + /* coefficients of the KNEW-th Lagrange function. */ + /* W is a working space vector of length 2N that is going to hold the */ + /* constrained Cauchy step from XOPT of the Lagrange function, followed */ + /* by the downhill version of XALT when the uphill step is calculated. */ + + /* Set the first NPT components of W to the leading elements of the */ + /* KNEW-th column of the H matrix. */ + + /* Parameter adjustments */ + zmat_dim1 = npt; + zmat_offset = 1 + zmat_dim1; + zmat -= zmat_offset; + xpt_dim1 = npt; + xpt_offset = 1 + xpt_dim1; + xpt -= xpt_offset; + --xopt; + bmat_dim1 = ndim; + bmat_offset = 1 + bmat_dim1; + bmat -= bmat_offset; + --sl; + --su; + --xnew; + --xalt; + --glag; + --hcol; + --w; + + /* Function Body */ + half = .5; + one = 1.; + zero = 0.; + const__ = one + std::sqrt(2.); + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L10: */ + hcol[k] = zero; + } + i__1 = npt - n - 1; + for (j = 1; j <= i__1; ++j) { + temp = zmat[knew + j * zmat_dim1]; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + /* L20: */ + hcol[k] += temp * zmat[k + j * zmat_dim1]; + } + } + alpha = hcol[knew]; + ha = half * alpha; + + /* Calculate the gradient of the KNEW-th Lagrange function at XOPT. */ + + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + /* L30: */ + glag[i__] = bmat[knew + i__ * bmat_dim1]; + } + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + temp = zero; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + /* L40: */ + temp += xpt[k + j * xpt_dim1] * xopt[j]; + } + temp = hcol[k] * temp; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + /* L50: */ + glag[i__] += temp * xpt[k + i__ * xpt_dim1]; + } + } + + /* Search for a large denominator along the straight lines through XOPT */ + /* and another interpolation point. SLBD and SUBD will be lower and upper */ + /* bounds on the step along each of these lines in turn. PREDSQ will be */ + /* set to the square of the predicted denominator for each line. PRESAV */ + /* will be set to the largest admissible value of PREDSQ that occurs. */ + + presav = zero; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + if (k == kopt) { + goto L80; + } + dderiv = zero; + distsq = zero; + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + temp = xpt[k + i__ * xpt_dim1] - xopt[i__]; + dderiv += glag[i__] * temp; + /* L60: */ + distsq += temp * temp; + } + subd = adelt / std::sqrt(distsq); + slbd = -subd; + ilbd = 0; + iubd = 0; + sumin = std::min(one,subd); + + /* Revise SLBD and SUBD if necessary because of the bounds in SL and SU. */ + + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + temp = xpt[k + i__ * xpt_dim1] - xopt[i__]; + if (temp > zero) { + if (slbd * temp < sl[i__] - xopt[i__]) { + slbd = (sl[i__] - xopt[i__]) / temp; + ilbd = -i__; + } + if (subd * temp > su[i__] - xopt[i__]) { + /* Computing MAX */ + d__1 = sumin, d__2 = (su[i__] - xopt[i__]) / temp; + subd = std::max(d__1,d__2); + iubd = i__; + } + } else if (temp < zero) { + if (slbd * temp > su[i__] - xopt[i__]) { + slbd = (su[i__] - xopt[i__]) / temp; + ilbd = i__; + } + if (subd * temp < sl[i__] - xopt[i__]) { + /* Computing MAX */ + d__1 = sumin, d__2 = (sl[i__] - xopt[i__]) / temp; + subd = std::max(d__1,d__2); + iubd = -i__; + } + } + /* L70: */ + } + + /* Seek a large modulus of the KNEW-th Lagrange function when the index */ + /* of the other interpolation point on the line through XOPT is KNEW. */ + + if (k == knew) { + diff = dderiv - one; + step = slbd; + vlag = slbd * (dderiv - slbd * diff); + isbd = ilbd; + temp = subd * (dderiv - subd * diff); + if (std::abs(temp) > std::abs(vlag)) { + step = subd; + vlag = temp; + isbd = iubd; + } + tempd = half * dderiv; + tempa = tempd - diff * slbd; + tempb = tempd - diff * subd; + if (tempa * tempb < zero) { + temp = tempd * tempd / diff; + if (std::abs(temp) > std::abs(vlag)) { + step = tempd / diff; + vlag = temp; + isbd = 0; + } + } + + /* Search along each of the other lines through XOPT and another point. */ + + } else { + step = slbd; + vlag = slbd * (one - slbd); + isbd = ilbd; + temp = subd * (one - subd); + if (std::abs(temp) > std::abs(vlag)) { + step = subd; + vlag = temp; + isbd = iubd; + } + if (subd > half) { + if (std::abs(vlag) < .25) { + step = half; + vlag = .25; + isbd = 0; + } + } + vlag *= dderiv; + } + + /* Calculate PREDSQ for the current line search and maintain PRESAV. */ + + temp = step * (one - step) * distsq; + predsq = vlag * vlag * (vlag * vlag + ha * temp * temp); + if (predsq > presav) { + presav = predsq; + ksav = k; + stpsav = step; + ibdsav = isbd; + } +L80: + ; + } + + /* Construct XNEW in a way that satisfies the bound constraints exactly. */ + + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + temp = xopt[i__] + stpsav * (xpt[ksav + i__ * xpt_dim1] - xopt[i__]); + /* L90: */ + /* Computing MAX */ + /* Computing MIN */ + d__3 = su[i__]; + d__1 = sl[i__], d__2 = std::min(d__3,temp); + xnew[i__] = std::max(d__1,d__2); + } + if (ibdsav < 0) { + xnew[-ibdsav] = sl[-ibdsav]; + } + if (ibdsav > 0) { + xnew[ibdsav] = su[ibdsav]; + } + + /* Prepare for the iterative method that assembles the constrained Cauchy */ + /* step in W. The sum of squares of the fixed components of W is formed in */ + /* WFIXSQ, and the free components of W are set to BIGSTP. */ + + bigstp = adelt + adelt; + iflag = 0; +L100: + wfixsq = zero; + ggfree = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + w[i__] = zero; + /* Computing MIN */ + d__1 = xopt[i__] - sl[i__], d__2 = glag[i__]; + tempa = std::min(d__1,d__2); + /* Computing MAX */ + d__1 = xopt[i__] - su[i__], d__2 = glag[i__]; + tempb = std::max(d__1,d__2); + if (tempa > zero || tempb < zero) { + w[i__] = bigstp; + /* Computing 2nd power */ + d__1 = glag[i__]; + ggfree += d__1 * d__1; + } + /* L110: */ + } + if (ggfree == zero) { + cauchy = zero; + goto L200; + } + + /* Investigate whether more components of W can be fixed. */ + +L120: + temp = adelt * adelt - wfixsq; + if (temp > zero) { + wsqsav = wfixsq; + step = std::sqrt(temp / ggfree); + ggfree = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + if (w[i__] == bigstp) { + temp = xopt[i__] - step * glag[i__]; + if (temp <= sl[i__]) { + w[i__] = sl[i__] - xopt[i__]; + /* Computing 2nd power */ + d__1 = w[i__]; + wfixsq += d__1 * d__1; + } else if (temp >= su[i__]) { + w[i__] = su[i__] - xopt[i__]; + /* Computing 2nd power */ + d__1 = w[i__]; + wfixsq += d__1 * d__1; + } else { + /* Computing 2nd power */ + d__1 = glag[i__]; + ggfree += d__1 * d__1; + } + } + /* L130: */ + } + if (wfixsq > wsqsav && ggfree > zero) { + goto L120; + } + } + + /* Set the remaining free components of W and all components of XALT, */ + /* except that W may be scaled later. */ + + gw = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + if (w[i__] == bigstp) { + w[i__] = -step * glag[i__]; + /* Computing MAX */ + /* Computing MIN */ + d__3 = su[i__], d__4 = xopt[i__] + w[i__]; + d__1 = sl[i__], d__2 = std::min(d__3,d__4); + xalt[i__] = std::max(d__1,d__2); + } else if (w[i__] == zero) { + xalt[i__] = xopt[i__]; + } else if (glag[i__] > zero) { + xalt[i__] = sl[i__]; + } else { + xalt[i__] = su[i__]; + } + /* L140: */ + gw += glag[i__] * w[i__]; + } + + /* Set CURV to the curvature of the KNEW-th Lagrange function along W. */ + /* Scale W by a factor less than one if that can reduce the modulus of */ + /* the Lagrange function at XOPT+W. Set CAUCHY to the final value of */ + /* the square of this function. */ + + curv = zero; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + temp = zero; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + /* L150: */ + temp += xpt[k + j * xpt_dim1] * w[j]; + } + /* L160: */ + curv += hcol[k] * temp * temp; + } + if (iflag == 1) { + curv = -curv; + } + if (curv > -gw && curv < -const__ * gw) { + scale = -gw / curv; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + temp = xopt[i__] + scale * w[i__]; + /* L170: */ + /* Computing MAX */ + /* Computing MIN */ + d__3 = su[i__]; + d__1 = sl[i__], d__2 = std::min(d__3,temp); + xalt[i__] = std::max(d__1,d__2); + } + /* Computing 2nd power */ + d__1 = half * gw * scale; + cauchy = d__1 * d__1; + } else { + /* Computing 2nd power */ + d__1 = gw + half * curv; + cauchy = d__1 * d__1; + } + + /* If IFLAG is zero, then XALT is calculated as before after reversing */ + /* the sign of GLAG. Thus two XALT vectors become available. The one that */ + /* is chosen is the one that gives the larger value of CAUCHY. */ + + if (iflag == 0) { + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + glag[i__] = -glag[i__]; + /* L180: */ + w[n + i__] = xalt[i__]; + } + csave = cauchy; + iflag = 1; + goto L100; + } + if (csave > cauchy) { + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + /* L190: */ + xalt[i__] = w[n + i__]; + } + cauchy = csave; + } +L200: + ; + } /* altmov_ */ + + // ---------------------------------------------------------------------------------------- + + template + void prelim_( + const funct& calfun, + const integer n, + const integer npt, + doublereal *x, + const doublereal *xl, + const doublereal *xu, + const doublereal rhobeg, + const integer maxfun, + doublereal *xbase, + doublereal *xpt, + doublereal *fval, + doublereal *gopt, + doublereal *hq, + doublereal *pq, + doublereal *bmat, + doublereal *zmat, + const integer ndim, + const doublereal *sl, + const doublereal *su, + integer& nf, + integer& kopt + ) const + { + /* System generated locals */ + integer xpt_dim1, xpt_offset, bmat_dim1, bmat_offset, zmat_dim1, + zmat_offset, i__1, i__2; + doublereal d__1, d__2, d__3, d__4; + + + /* Local variables */ + doublereal f; + integer i__, j, k, ih, np, nfm; + doublereal one; + integer nfx = 0, ipt = 0, jpt = 0; + doublereal two = 0, fbeg = 0, diff = 0, half = 0, temp = 0, zero = 0, recip = 0, stepa = 0, stepb = 0; + integer itemp; + doublereal rhosq; + + + + /* The arguments N, NPT, X, XL, XU, RHOBEG, IPRINT and MAXFUN are the */ + /* same as the corresponding arguments in SUBROUTINE BOBYQA. */ + /* The arguments XBASE, XPT, FVAL, HQ, PQ, BMAT, ZMAT, NDIM, SL and SU */ + /* are the same as the corresponding arguments in BOBYQB, the elements */ + /* of SL and SU being set in BOBYQA. */ + /* GOPT is usually the gradient of the quadratic model at XOPT+XBASE, but */ + /* it is set by PRELIM to the gradient of the quadratic model at XBASE. */ + /* If XOPT is nonzero, BOBYQB will change it to its usual value later. */ + /* NF is maintaned as the number of calls of CALFUN so far. */ + /* KOPT will be such that the least calculated value of F so far is at */ + /* the point XPT(KOPT,.)+XBASE in the space of the variables. */ + + /* SUBROUTINE PRELIM sets the elements of XBASE, XPT, FVAL, GOPT, HQ, PQ, */ + /* BMAT and ZMAT for the first iteration, and it maintains the values of */ + /* NF and KOPT. The vector X is also changed by PRELIM. */ + + /* Set some constants. */ + + /* Parameter adjustments */ + zmat_dim1 = npt; + zmat_offset = 1 + zmat_dim1; + zmat -= zmat_offset; + xpt_dim1 = npt; + xpt_offset = 1 + xpt_dim1; + xpt -= xpt_offset; + --x; + --xl; + --xu; + --xbase; + --fval; + --gopt; + --hq; + --pq; + bmat_dim1 = ndim; + bmat_offset = 1 + bmat_dim1; + bmat -= bmat_offset; + --sl; + --su; + + /* Function Body */ + half = .5; + one = 1.; + two = 2.; + zero = 0.; + rhosq = rhobeg * rhobeg; + recip = one / rhosq; + np = n + 1; + + /* Set XBASE to the initial vector of variables, and set the initial */ + /* elements of XPT, BMAT, HQ, PQ and ZMAT to zero. */ + + i__1 = n; + for (j = 1; j <= i__1; ++j) { + xbase[j] = x[j]; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + /* L10: */ + xpt[k + j * xpt_dim1] = zero; + } + i__2 = ndim; + for (i__ = 1; i__ <= i__2; ++i__) { + /* L20: */ + bmat[i__ + j * bmat_dim1] = zero; + } + } + i__2 = n * np / 2; + for (ih = 1; ih <= i__2; ++ih) { + /* L30: */ + hq[ih] = zero; + } + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + pq[k] = zero; + i__1 = npt - np; + for (j = 1; j <= i__1; ++j) { + /* L40: */ + zmat[k + j * zmat_dim1] = zero; + } + } + + /* Begin the initialization procedure. NF becomes one more than the number */ + /* of function values so far. The coordinates of the displacement of the */ + /* next initial interpolation point from XBASE are set in XPT(NF+1,.). */ + + nf = 0; +L50: + nfm = nf; + nfx = nf - n; + ++(nf); + if (nfm <= n << 1) { + if (nfm >= 1 && nfm <= n) { + stepa = rhobeg; + if (su[nfm] == zero) { + stepa = -stepa; + } + xpt[nf + nfm * xpt_dim1] = stepa; + } else if (nfm > n) { + stepa = xpt[nf - n + nfx * xpt_dim1]; + stepb = -(rhobeg); + if (sl[nfx] == zero) { + /* Computing MIN */ + d__1 = two * rhobeg, d__2 = su[nfx]; + stepb = std::min(d__1,d__2); + } + if (su[nfx] == zero) { + /* Computing MAX */ + d__1 = -two * rhobeg, d__2 = sl[nfx]; + stepb = std::max(d__1,d__2); + } + xpt[nf + nfx * xpt_dim1] = stepb; + } + } else { + itemp = (nfm - np) / n; + jpt = nfm - itemp * n - n; + ipt = jpt + itemp; + if (ipt > n) { + itemp = jpt; + jpt = ipt - n; + ipt = itemp; + } + xpt[nf + ipt * xpt_dim1] = xpt[ipt + 1 + ipt * xpt_dim1]; + xpt[nf + jpt * xpt_dim1] = xpt[jpt + 1 + jpt * xpt_dim1]; + } + + /* Calculate the next value of F. The least function value so far and */ + /* its index are required. */ + + i__1 = n; + for (j = 1; j <= i__1; ++j) { + /* Computing MIN */ + /* Computing MAX */ + d__3 = xl[j], d__4 = xbase[j] + xpt[nf + j * xpt_dim1]; + d__1 = std::max(d__3,d__4), d__2 = xu[j]; + x[j] = std::min(d__1,d__2); + if (xpt[nf + j * xpt_dim1] == sl[j]) { + x[j] = xl[j]; + } + if (xpt[nf + j * xpt_dim1] == su[j]) { + x[j] = xu[j]; + } + /* L60: */ + } + f = calfun(mat(&x[1],n)); + fval[nf] = f; + if (nf == 1) { + fbeg = f; + kopt = 1; + } else if (f < fval[kopt]) { + kopt = nf; + } + + /* Set the nonzero initial elements of BMAT and the quadratic model in the */ + /* cases when NF is at most 2*N+1. If NF exceeds N+1, then the positions */ + /* of the NF-th and (NF-N)-th interpolation points may be switched, in */ + /* order that the function value at the first of them contributes to the */ + /* off-diagonal second derivative terms of the initial quadratic model. */ + + if (nf <= (n << 1) + 1) { + if (nf >= 2 && nf <= n + 1) { + gopt[nfm] = (f - fbeg) / stepa; + if (npt < nf + n) { + bmat[nfm * bmat_dim1 + 1] = -one / stepa; + bmat[nf + nfm * bmat_dim1] = one / stepa; + bmat[npt + nfm + nfm * bmat_dim1] = -half * rhosq; + } + } else if (nf >= n + 2) { + ih = nfx * (nfx + 1) / 2; + temp = (f - fbeg) / stepb; + diff = stepb - stepa; + hq[ih] = two * (temp - gopt[nfx]) / diff; + gopt[nfx] = (gopt[nfx] * stepb - temp * stepa) / diff; + if (stepa * stepb < zero) { + if (f < fval[nf - n]) { + fval[nf] = fval[nf - n]; + fval[nf - n] = f; + if (kopt == nf) { + kopt = nf - n; + } + xpt[nf - n + nfx * xpt_dim1] = stepb; + xpt[nf + nfx * xpt_dim1] = stepa; + } + } + bmat[nfx * bmat_dim1 + 1] = -(stepa + stepb) / (stepa * stepb); + bmat[nf + nfx * bmat_dim1] = -half / xpt[nf - n + nfx * + xpt_dim1]; + bmat[nf - n + nfx * bmat_dim1] = -bmat[nfx * bmat_dim1 + 1] - + bmat[nf + nfx * bmat_dim1]; + zmat[nfx * zmat_dim1 + 1] = std::sqrt(two) / (stepa * stepb); + zmat[nf + nfx * zmat_dim1] = std::sqrt(half) / rhosq; + zmat[nf - n + nfx * zmat_dim1] = -zmat[nfx * zmat_dim1 + 1] - + zmat[nf + nfx * zmat_dim1]; + } + + /* Set the off-diagonal second derivatives of the Lagrange functions and */ + /* the initial quadratic model. */ + + } else { + ih = ipt * (ipt - 1) / 2 + jpt; + zmat[nfx * zmat_dim1 + 1] = recip; + zmat[nf + nfx * zmat_dim1] = recip; + zmat[ipt + 1 + nfx * zmat_dim1] = -recip; + zmat[jpt + 1 + nfx * zmat_dim1] = -recip; + temp = xpt[nf + ipt * xpt_dim1] * xpt[nf + jpt * xpt_dim1]; + hq[ih] = (fbeg - fval[ipt + 1] - fval[jpt + 1] + f) / temp; + } + if (nf < npt && nf < maxfun) { + goto L50; + } + + } /* prelim_ */ + + // ---------------------------------------------------------------------------------------- + + template + void rescue_ ( + const funct& calfun, + const integer n, + const integer npt, + const doublereal *xl, + const doublereal *xu, + const integer maxfun, + doublereal *xbase, + doublereal *xpt, + doublereal *fval, + doublereal *xopt, + doublereal *gopt, + doublereal *hq, + doublereal *pq, + doublereal *bmat, + doublereal *zmat, + const integer ndim, + doublereal *sl, + doublereal *su, + integer& nf, + const doublereal delta, + integer& kopt, + doublereal *vlag, + doublereal * ptsaux, + doublereal *ptsid, + doublereal *w + ) const + { + /* System generated locals */ + integer xpt_dim1, xpt_offset, bmat_dim1, bmat_offset, zmat_dim1, + zmat_offset, i__1, i__2, i__3; + doublereal d__1, d__2, d__3, d__4; + + + /* Local variables */ + doublereal f; + integer i__, j, k, ih, jp, ip, iq, np, iw; + doublereal xp = 0, xq = 0, den = 0; + integer ihp = 0; + doublereal one; + integer ihq, jpn, kpt; + doublereal sum = 0, diff = 0, half = 0, beta = 0; + integer kold; + doublereal winc; + integer nrem, knew; + doublereal temp, bsum; + integer nptm; + doublereal zero = 0, hdiag = 0, fbase = 0, sfrac = 0, denom = 0, vquad = 0, sumpq = 0; + doublereal dsqmin, distsq, vlmxsq; + + + + /* The arguments N, NPT, XL, XU, IPRINT, MAXFUN, XBASE, XPT, FVAL, XOPT, */ + /* GOPT, HQ, PQ, BMAT, ZMAT, NDIM, SL and SU have the same meanings as */ + /* the corresponding arguments of BOBYQB on the entry to RESCUE. */ + /* NF is maintained as the number of calls of CALFUN so far, except that */ + /* NF is set to -1 if the value of MAXFUN prevents further progress. */ + /* KOPT is maintained so that FVAL(KOPT) is the least calculated function */ + /* value. Its correct value must be given on entry. It is updated if a */ + /* new least function value is found, but the corresponding changes to */ + /* XOPT and GOPT have to be made later by the calling program. */ + /* DELTA is the current trust region radius. */ + /* VLAG is a working space vector that will be used for the values of the */ + /* provisional Lagrange functions at each of the interpolation points. */ + /* They are part of a product that requires VLAG to be of length NDIM. */ + /* PTSAUX is also a working space array. For J=1,2,...,N, PTSAUX(1,J) and */ + /* PTSAUX(2,J) specify the two positions of provisional interpolation */ + /* points when a nonzero step is taken along e_J (the J-th coordinate */ + /* direction) through XBASE+XOPT, as specified below. Usually these */ + /* steps have length DELTA, but other lengths are chosen if necessary */ + /* in order to satisfy the given bounds on the variables. */ + /* PTSID is also a working space array. It has NPT components that denote */ + /* provisional new positions of the original interpolation points, in */ + /* case changes are needed to restore the linear independence of the */ + /* interpolation conditions. The K-th point is a candidate for change */ + /* if and only if PTSID(K) is nonzero. In this case let p and q be the */ + /* integer parts of PTSID(K) and (PTSID(K)-p) multiplied by N+1. If p */ + /* and q are both positive, the step from XBASE+XOPT to the new K-th */ + /* interpolation point is PTSAUX(1,p)*e_p + PTSAUX(1,q)*e_q. Otherwise */ + /* the step is PTSAUX(1,p)*e_p or PTSAUX(2,q)*e_q in the cases q=0 or */ + /* p=0, respectively. */ + /* The first NDIM+NPT elements of the array W are used for working space. */ + /* The final elements of BMAT and ZMAT are set in a well-conditioned way */ + /* to the values that are appropriate for the new interpolation points. */ + /* The elements of GOPT, HQ and PQ are also revised to the values that are */ + /* appropriate to the final quadratic model. */ + + /* Set some constants. */ + + /* Parameter adjustments */ + zmat_dim1 = npt; + zmat_offset = 1 + zmat_dim1; + zmat -= zmat_offset; + xpt_dim1 = npt; + xpt_offset = 1 + xpt_dim1; + xpt -= xpt_offset; + --xl; + --xu; + --xbase; + --fval; + --xopt; + --gopt; + --hq; + --pq; + bmat_dim1 = ndim; + bmat_offset = 1 + bmat_dim1; + bmat -= bmat_offset; + --sl; + --su; + --vlag; + ptsaux -= 3; + --ptsid; + --w; + + /* Function Body */ + half = .5; + one = 1.; + zero = 0.; + np = n + 1; + sfrac = half / (doublereal) np; + nptm = npt - np; + + /* Shift the interpolation points so that XOPT becomes the origin, and set */ + /* the elements of ZMAT to zero. The value of SUMPQ is required in the */ + /* updating of HQ below. The squares of the distances from XOPT to the */ + /* other interpolation points are set at the end of W. Increments of WINC */ + /* may be added later to these squares to balance the consideration of */ + /* the choice of point that is going to become current. */ + + sumpq = zero; + winc = zero; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + distsq = zero; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + xpt[k + j * xpt_dim1] -= xopt[j]; + /* L10: */ + /* Computing 2nd power */ + d__1 = xpt[k + j * xpt_dim1]; + distsq += d__1 * d__1; + } + sumpq += pq[k]; + w[ndim + k] = distsq; + winc = std::max(winc,distsq); + i__2 = nptm; + for (j = 1; j <= i__2; ++j) { + /* L20: */ + zmat[k + j * zmat_dim1] = zero; + } + } + + /* Update HQ so that HQ and PQ define the second derivatives of the model */ + /* after XBASE has been shifted to the trust region centre. */ + + ih = 0; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + w[j] = half * sumpq * xopt[j]; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L30: */ + w[j] += pq[k] * xpt[k + j * xpt_dim1]; + } + i__1 = j; + for (i__ = 1; i__ <= i__1; ++i__) { + ++ih; + /* L40: */ + hq[ih] = hq[ih] + w[i__] * xopt[j] + w[j] * xopt[i__]; + } + } + + /* Shift XBASE, SL, SU and XOPT. Set the elements of BMAT to zero, and */ + /* also set the elements of PTSAUX. */ + + i__1 = n; + for (j = 1; j <= i__1; ++j) { + xbase[j] += xopt[j]; + sl[j] -= xopt[j]; + su[j] -= xopt[j]; + xopt[j] = zero; + /* Computing MIN */ + d__1 = delta, d__2 = su[j]; + ptsaux[(j << 1) + 1] = std::min(d__1,d__2); + /* Computing MAX */ + d__1 = -(delta), d__2 = sl[j]; + ptsaux[(j << 1) + 2] = std::max(d__1,d__2); + if (ptsaux[(j << 1) + 1] + ptsaux[(j << 1) + 2] < zero) { + temp = ptsaux[(j << 1) + 1]; + ptsaux[(j << 1) + 1] = ptsaux[(j << 1) + 2]; + ptsaux[(j << 1) + 2] = temp; + } + if ((d__2 = ptsaux[(j << 1) + 2], std::abs(d__2)) < half * (d__1 = ptsaux[( + j << 1) + 1], std::abs(d__1))) { + ptsaux[(j << 1) + 2] = half * ptsaux[(j << 1) + 1]; + } + i__2 = ndim; + for (i__ = 1; i__ <= i__2; ++i__) { + /* L50: */ + bmat[i__ + j * bmat_dim1] = zero; + } + } + fbase = fval[kopt]; + + /* Set the identifiers of the artificial interpolation points that are */ + /* along a coordinate direction from XOPT, and set the corresponding */ + /* nonzero elements of BMAT and ZMAT. */ + + ptsid[1] = sfrac; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + jp = j + 1; + jpn = jp + n; + ptsid[jp] = (doublereal) j + sfrac; + if (jpn <= npt) { + ptsid[jpn] = (doublereal) j / (doublereal) np + sfrac; + temp = one / (ptsaux[(j << 1) + 1] - ptsaux[(j << 1) + 2]); + bmat[jp + j * bmat_dim1] = -temp + one / ptsaux[(j << 1) + 1]; + bmat[jpn + j * bmat_dim1] = temp + one / ptsaux[(j << 1) + 2]; + bmat[j * bmat_dim1 + 1] = -bmat[jp + j * bmat_dim1] - bmat[jpn + + j * bmat_dim1]; + zmat[j * zmat_dim1 + 1] = std::sqrt(2.) / (d__1 = ptsaux[(j << 1) + 1] + * ptsaux[(j << 1) + 2], std::abs(d__1)); + zmat[jp + j * zmat_dim1] = zmat[j * zmat_dim1 + 1] * ptsaux[(j << + 1) + 2] * temp; + zmat[jpn + j * zmat_dim1] = -zmat[j * zmat_dim1 + 1] * ptsaux[(j + << 1) + 1] * temp; + } else { + bmat[j * bmat_dim1 + 1] = -one / ptsaux[(j << 1) + 1]; + bmat[jp + j * bmat_dim1] = one / ptsaux[(j << 1) + 1]; + /* Computing 2nd power */ + d__1 = ptsaux[(j << 1) + 1]; + bmat[j + npt + j * bmat_dim1] = -half * (d__1 * d__1); + } + /* L60: */ + } + + /* Set any remaining identifiers with their nonzero elements of ZMAT. */ + + if (npt >= n + np) { + i__2 = npt; + for (k = np << 1; k <= i__2; ++k) { + iw = (integer) (((doublereal) (k - np) - half) / (doublereal) (n) + ); + ip = k - np - iw * n; + iq = ip + iw; + if (iq > n) { + iq -= n; + } + ptsid[k] = (doublereal) ip + (doublereal) iq / (doublereal) np + + sfrac; + temp = one / (ptsaux[(ip << 1) + 1] * ptsaux[(iq << 1) + 1]); + zmat[(k - np) * zmat_dim1 + 1] = temp; + zmat[ip + 1 + (k - np) * zmat_dim1] = -temp; + zmat[iq + 1 + (k - np) * zmat_dim1] = -temp; + /* L70: */ + zmat[k + (k - np) * zmat_dim1] = temp; + } + } + nrem = npt; + kold = 1; + knew = kopt; + + /* Reorder the provisional points in the way that exchanges PTSID(KOLD) */ + /* with PTSID(KNEW). */ + +L80: + i__2 = n; + for (j = 1; j <= i__2; ++j) { + temp = bmat[kold + j * bmat_dim1]; + bmat[kold + j * bmat_dim1] = bmat[knew + j * bmat_dim1]; + /* L90: */ + bmat[knew + j * bmat_dim1] = temp; + } + i__2 = nptm; + for (j = 1; j <= i__2; ++j) { + temp = zmat[kold + j * zmat_dim1]; + zmat[kold + j * zmat_dim1] = zmat[knew + j * zmat_dim1]; + /* L100: */ + zmat[knew + j * zmat_dim1] = temp; + } + ptsid[kold] = ptsid[knew]; + ptsid[knew] = zero; + w[ndim + knew] = zero; + --nrem; + if (knew != kopt) { + temp = vlag[kold]; + vlag[kold] = vlag[knew]; + vlag[knew] = temp; + + /* Update the BMAT and ZMAT matrices so that the status of the KNEW-th */ + /* interpolation point can be changed from provisional to original. The */ + /* branch to label 350 occurs if all the original points are reinstated. */ + /* The nonnegative values of W(NDIM+K) are required in the search below. */ + + update_(n, npt, &bmat[bmat_offset], &zmat[zmat_offset], ndim, &vlag[1], + beta, denom, knew, &w[1]); + if (nrem == 0) { + goto L350; + } + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + /* L110: */ + w[ndim + k] = (d__1 = w[ndim + k], std::abs(d__1)); + } + } + + /* Pick the index KNEW of an original interpolation point that has not */ + /* yet replaced one of the provisional interpolation points, giving */ + /* attention to the closeness to XOPT and to previous tries with KNEW. */ + +L120: + dsqmin = zero; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + if (w[ndim + k] > zero) { + if (dsqmin == zero || w[ndim + k] < dsqmin) { + knew = k; + dsqmin = w[ndim + k]; + } + } + /* L130: */ + } + if (dsqmin == zero) { + goto L260; + } + + /* Form the W-vector of the chosen original interpolation point. */ + + i__2 = n; + for (j = 1; j <= i__2; ++j) { + /* L140: */ + w[npt + j] = xpt[knew + j * xpt_dim1]; + } + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + sum = zero; + if (k == kopt) { + } else if (ptsid[k] == zero) { + i__1 = n; + for (j = 1; j <= i__1; ++j) { + /* L150: */ + sum += w[npt + j] * xpt[k + j * xpt_dim1]; + } + } else { + ip = (integer) ptsid[k]; + if (ip > 0) { + sum = w[npt + ip] * ptsaux[(ip << 1) + 1]; + } + iq = (integer) ((doublereal) np * ptsid[k] - (doublereal) (ip * + np)); + if (iq > 0) { + iw = 1; + if (ip == 0) { + iw = 2; + } + sum += w[npt + iq] * ptsaux[iw + (iq << 1)]; + } + } + /* L160: */ + w[k] = half * sum * sum; + } + + /* Calculate VLAG and BETA for the required updating of the H matrix if */ + /* XPT(KNEW,.) is reinstated in the set of interpolation points. */ + + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + sum = zero; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + /* L170: */ + sum += bmat[k + j * bmat_dim1] * w[npt + j]; + } + /* L180: */ + vlag[k] = sum; + } + beta = zero; + i__2 = nptm; + for (j = 1; j <= i__2; ++j) { + sum = zero; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L190: */ + sum += zmat[k + j * zmat_dim1] * w[k]; + } + beta -= sum * sum; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + /* L200: */ + vlag[k] += sum * zmat[k + j * zmat_dim1]; + } + } + bsum = zero; + distsq = zero; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + sum = zero; + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + /* L210: */ + sum += bmat[k + j * bmat_dim1] * w[k]; + } + jp = j + npt; + bsum += sum * w[jp]; + i__2 = ndim; + for (ip = npt + 1; ip <= i__2; ++ip) { + /* L220: */ + sum += bmat[ip + j * bmat_dim1] * w[ip]; + } + bsum += sum * w[jp]; + vlag[jp] = sum; + /* L230: */ + /* Computing 2nd power */ + d__1 = xpt[knew + j * xpt_dim1]; + distsq += d__1 * d__1; + } + beta = half * distsq * distsq + beta - bsum; + vlag[kopt] += one; + + /* KOLD is set to the index of the provisional interpolation point that is */ + /* going to be deleted to make way for the KNEW-th original interpolation */ + /* point. The choice of KOLD is governed by the avoidance of a small value */ + /* of the denominator in the updating calculation of UPDATE. */ + + denom = zero; + vlmxsq = zero; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + if (ptsid[k] != zero) { + hdiag = zero; + i__2 = nptm; + for (j = 1; j <= i__2; ++j) { + /* L240: */ + /* Computing 2nd power */ + d__1 = zmat[k + j * zmat_dim1]; + hdiag += d__1 * d__1; + } + /* Computing 2nd power */ + d__1 = vlag[k]; + den = beta * hdiag + d__1 * d__1; + if (den > denom) { + kold = k; + denom = den; + } + } + /* L250: */ + /* Computing MAX */ + /* Computing 2nd power */ + d__3 = vlag[k]; + d__1 = vlmxsq, d__2 = d__3 * d__3; + vlmxsq = std::max(d__1,d__2); + } + if (denom <= vlmxsq * .01) { + w[ndim + knew] = -w[ndim + knew] - winc; + goto L120; + } + goto L80; + + /* When label 260 is reached, all the final positions of the interpolation */ + /* points have been chosen although any changes have not been included yet */ + /* in XPT. Also the final BMAT and ZMAT matrices are complete, but, apart */ + /* from the shift of XBASE, the updating of the quadratic model remains to */ + /* be done. The following cycle through the new interpolation points begins */ + /* by putting the new point in XPT(KPT,.) and by setting PQ(KPT) to zero, */ + /* except that a RETURN occurs if MAXFUN prohibits another value of F. */ + +L260: + i__1 = npt; + for (kpt = 1; kpt <= i__1; ++kpt) { + if (ptsid[kpt] == zero) { + goto L340; + } + if (nf >= maxfun) { + nf = -1; + goto L350; + } + ih = 0; + i__2 = n; + for (j = 1; j <= i__2; ++j) { + w[j] = xpt[kpt + j * xpt_dim1]; + xpt[kpt + j * xpt_dim1] = zero; + temp = pq[kpt] * w[j]; + i__3 = j; + for (i__ = 1; i__ <= i__3; ++i__) { + ++ih; + /* L270: */ + hq[ih] += temp * w[i__]; + } + } + pq[kpt] = zero; + ip = (integer) ptsid[kpt]; + iq = (integer) ((doublereal) np * ptsid[kpt] - (doublereal) (ip * np)) + ; + if (ip > 0) { + xp = ptsaux[(ip << 1) + 1]; + xpt[kpt + ip * xpt_dim1] = xp; + } + if (iq > 0) { + xq = ptsaux[(iq << 1) + 1]; + if (ip == 0) { + xq = ptsaux[(iq << 1) + 2]; + } + xpt[kpt + iq * xpt_dim1] = xq; + } + + /* Set VQUAD to the value of the current model at the new point. */ + + vquad = fbase; + if (ip > 0) { + ihp = (ip + ip * ip) / 2; + vquad += xp * (gopt[ip] + half * xp * hq[ihp]); + } + if (iq > 0) { + ihq = (iq + iq * iq) / 2; + vquad += xq * (gopt[iq] + half * xq * hq[ihq]); + if (ip > 0) { + iw = std::max(ihp,ihq) - (i__3 = ip - iq, std::abs(i__3)); + vquad += xp * xq * hq[iw]; + } + } + i__3 = npt; + for (k = 1; k <= i__3; ++k) { + temp = zero; + if (ip > 0) { + temp += xp * xpt[k + ip * xpt_dim1]; + } + if (iq > 0) { + temp += xq * xpt[k + iq * xpt_dim1]; + } + /* L280: */ + vquad += half * pq[k] * temp * temp; + } + + /* Calculate F at the new interpolation point, and set DIFF to the factor */ + /* that is going to multiply the KPT-th Lagrange function when the model */ + /* is updated to provide interpolation to the new function value. */ + + i__3 = n; + for (i__ = 1; i__ <= i__3; ++i__) { + /* Computing MIN */ + /* Computing MAX */ + d__3 = xl[i__], d__4 = xbase[i__] + xpt[kpt + i__ * xpt_dim1]; + d__1 = std::max(d__3,d__4), d__2 = xu[i__]; + w[i__] = std::min(d__1,d__2); + if (xpt[kpt + i__ * xpt_dim1] == sl[i__]) { + w[i__] = xl[i__]; + } + if (xpt[kpt + i__ * xpt_dim1] == su[i__]) { + w[i__] = xu[i__]; + } + /* L290: */ + } + ++(nf); + f = calfun(mat(&w[1],n)); + fval[kpt] = f; + if (f < fval[kopt]) { + kopt = kpt; + } + diff = f - vquad; + + /* Update the quadratic model. The RETURN from the subroutine occurs when */ + /* all the new interpolation points are included in the model. */ + + i__3 = n; + for (i__ = 1; i__ <= i__3; ++i__) { + /* L310: */ + gopt[i__] += diff * bmat[kpt + i__ * bmat_dim1]; + } + i__3 = npt; + for (k = 1; k <= i__3; ++k) { + sum = zero; + i__2 = nptm; + for (j = 1; j <= i__2; ++j) { + /* L320: */ + sum += zmat[k + j * zmat_dim1] * zmat[kpt + j * zmat_dim1]; + } + temp = diff * sum; + if (ptsid[k] == zero) { + pq[k] += temp; + } else { + ip = (integer) ptsid[k]; + iq = (integer) ((doublereal) np * ptsid[k] - (doublereal) (ip + * np)); + ihq = (iq * iq + iq) / 2; + if (ip == 0) { + /* Computing 2nd power */ + d__1 = ptsaux[(iq << 1) + 2]; + hq[ihq] += temp * (d__1 * d__1); + } else { + ihp = (ip * ip + ip) / 2; + /* Computing 2nd power */ + d__1 = ptsaux[(ip << 1) + 1]; + hq[ihp] += temp * (d__1 * d__1); + if (iq > 0) { + /* Computing 2nd power */ + d__1 = ptsaux[(iq << 1) + 1]; + hq[ihq] += temp * (d__1 * d__1); + iw = std::max(ihp,ihq) - (i__2 = iq - ip, std::abs(i__2)); + hq[iw] += temp * ptsaux[(ip << 1) + 1] * ptsaux[(iq << + 1) + 1]; + } + } + } + /* L330: */ + } + ptsid[kpt] = zero; +L340: + ; + } +L350: + ; + } /* rescue_ */ + + // ---------------------------------------------------------------------------------------- + + void trsbox_( + const integer n, + const integer npt, + const doublereal *xpt, + const doublereal *xopt, + const doublereal *gopt, + const doublereal *hq, + const doublereal *pq, + const doublereal *sl, + const doublereal *su, + const doublereal delta, + doublereal *xnew, + doublereal *d__, + doublereal *gnew, + doublereal *xbdi, + doublereal *s, + doublereal *hs, + doublereal *hred, + doublereal *dsq, + doublereal *crvmin + ) const + { + /* System generated locals */ + integer xpt_dim1, xpt_offset, i__1, i__2; + doublereal d__1, d__2, d__3, d__4; + + /* Local variables */ + integer i__, j, k, ih; + doublereal ds; + integer iu; + doublereal dhd, dhs, cth, one, shs, sth, ssq, half, beta, sdec, blen; + integer iact = 0, nact = 0; + doublereal angt, qred; + integer isav; + doublereal temp = 0, zero = 0, xsav = 0, xsum = 0, angbd = 0, dredg = 0, sredg = 0; + integer iterc; + doublereal resid = 0, delsq = 0, ggsav = 0, tempa = 0, tempb = 0, + redmax = 0, dredsq = 0, redsav = 0, onemin = 0, gredsq = 0, rednew = 0; + integer itcsav = 0; + doublereal rdprev = 0, rdnext = 0, stplen = 0, stepsq = 0; + integer itermax = 0; + + + /* The arguments N, NPT, XPT, XOPT, GOPT, HQ, PQ, SL and SU have the same */ + /* meanings as the corresponding arguments of BOBYQB. */ + /* DELTA is the trust region radius for the present calculation, which */ + /* seeks a small value of the quadratic model within distance DELTA of */ + /* XOPT subject to the bounds on the variables. */ + /* XNEW will be set to a new vector of variables that is approximately */ + /* the one that minimizes the quadratic model within the trust region */ + /* subject to the SL and SU constraints on the variables. It satisfies */ + /* as equations the bounds that become active during the calculation. */ + /* D is the calculated trial step from XOPT, generated iteratively from an */ + /* initial value of zero. Thus XNEW is XOPT+D after the final iteration. */ + /* GNEW holds the gradient of the quadratic model at XOPT+D. It is updated */ + /* when D is updated. */ + /* XBDI is a working space vector. For I=1,2,...,N, the element XBDI(I) is */ + /* set to -1.0, 0.0, or 1.0, the value being nonzero if and only if the */ + /* I-th variable has become fixed at a bound, the bound being SL(I) or */ + /* SU(I) in the case XBDI(I)=-1.0 or XBDI(I)=1.0, respectively. This */ + /* information is accumulated during the construction of XNEW. */ + /* The arrays S, HS and HRED are also used for working space. They hold the */ + /* current search direction, and the changes in the gradient of Q along S */ + /* and the reduced D, respectively, where the reduced D is the same as D, */ + /* except that the components of the fixed variables are zero. */ + /* DSQ will be set to the square of the length of XNEW-XOPT. */ + /* CRVMIN is set to zero if D reaches the trust region boundary. Otherwise */ + /* it is set to the least curvature of H that occurs in the conjugate */ + /* gradient searches that are not restricted by any constraints. The */ + /* value CRVMIN=-1.0D0 is set, however, if all of these searches are */ + /* constrained. */ + + /* A version of the truncated conjugate gradient is applied. If a line */ + /* search is restricted by a constraint, then the procedure is restarted, */ + /* the values of the variables that are at their bounds being fixed. If */ + /* the trust region boundary is reached, then further changes may be made */ + /* to D, each one being in the two dimensional space that is spanned */ + /* by the current D and the gradient of Q at XOPT+D, staying on the trust */ + /* region boundary. Termination occurs when the reduction in Q seems to */ + /* be close to the greatest reduction that can be achieved. */ + + /* Set some constants. */ + + /* Parameter adjustments */ + xpt_dim1 = npt; + xpt_offset = 1 + xpt_dim1; + xpt -= xpt_offset; + --xopt; + --gopt; + --hq; + --pq; + --sl; + --su; + --xnew; + --d__; + --gnew; + --xbdi; + --s; + --hs; + --hred; + + /* Function Body */ + half = .5; + one = 1.; + onemin = -1.; + zero = 0.; + + /* The sign of GOPT(I) gives the sign of the change to the I-th variable */ + /* that will reduce Q from its value at XOPT. Thus XBDI(I) shows whether */ + /* or not to fix the I-th variable at one of its bounds initially, with */ + /* NACT being set to the number of fixed variables. D and GNEW are also */ + /* set for the first iteration. DELSQ is the upper bound on the sum of */ + /* squares of the free variables. QRED is the reduction in Q so far. */ + + iterc = 0; + nact = 0; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + xbdi[i__] = zero; + if (xopt[i__] <= sl[i__]) { + if (gopt[i__] >= zero) { + xbdi[i__] = onemin; + } + } else if (xopt[i__] >= su[i__]) { + if (gopt[i__] <= zero) { + xbdi[i__] = one; + } + } + if (xbdi[i__] != zero) { + ++nact; + } + d__[i__] = zero; + /* L10: */ + gnew[i__] = gopt[i__]; + } + delsq = delta * delta; + qred = zero; + *crvmin = onemin; + + /* Set the next search direction of the conjugate gradient method. It is */ + /* the steepest descent direction initially and when the iterations are */ + /* restarted because a variable has just been fixed by a bound, and of */ + /* course the components of the fixed variables are zero. ITERMAX is an */ + /* upper bound on the indices of the conjugate gradient iterations. */ + +L20: + beta = zero; +L30: + stepsq = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + if (xbdi[i__] != zero) { + s[i__] = zero; + } else if (beta == zero) { + s[i__] = -gnew[i__]; + } else { + s[i__] = beta * s[i__] - gnew[i__]; + } + /* L40: */ + /* Computing 2nd power */ + d__1 = s[i__]; + stepsq += d__1 * d__1; + } + if (stepsq == zero) { + goto L190; + } + if (beta == zero) { + gredsq = stepsq; + itermax = iterc + n - nact; + } + if (gredsq * delsq <= qred * 1e-4 * qred) { + goto L190; + } + + /* Multiply the search direction by the second derivative matrix of Q and */ + /* calculate some scalars for the choice of steplength. Then set BLEN to */ + /* the length of the the step to the trust region boundary and STPLEN to */ + /* the steplength, ignoring the simple bounds. */ + + goto L210; +L50: + resid = delsq; + ds = zero; + shs = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + if (xbdi[i__] == zero) { + /* Computing 2nd power */ + d__1 = d__[i__]; + resid -= d__1 * d__1; + ds += s[i__] * d__[i__]; + shs += s[i__] * hs[i__]; + } + /* L60: */ + } + if (resid <= zero) { + goto L90; + } + temp = std::sqrt(stepsq * resid + ds * ds); + if (ds < zero) { + blen = (temp - ds) / stepsq; + } else { + blen = resid / (temp + ds); + } + stplen = blen; + if (shs > zero) { + /* Computing MIN */ + d__1 = blen, d__2 = gredsq / shs; + stplen = std::min(d__1,d__2); + } + + /* Reduce STPLEN if necessary in order to preserve the simple bounds, */ + /* letting IACT be the index of the new constrained variable. */ + + iact = 0; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + if (s[i__] != zero) { + xsum = xopt[i__] + d__[i__]; + if (s[i__] > zero) { + temp = (su[i__] - xsum) / s[i__]; + } else { + temp = (sl[i__] - xsum) / s[i__]; + } + if (temp < stplen) { + stplen = temp; + iact = i__; + } + } + /* L70: */ + } + + /* Update CRVMIN, GNEW and D. Set SDEC to the decrease that occurs in Q. */ + + sdec = zero; + if (stplen > zero) { + ++iterc; + temp = shs / stepsq; + if (iact == 0 && temp > zero) { + *crvmin = std::min(*crvmin,temp); + if (*crvmin == onemin) { + *crvmin = temp; + } + } + ggsav = gredsq; + gredsq = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + gnew[i__] += stplen * hs[i__]; + if (xbdi[i__] == zero) { + /* Computing 2nd power */ + d__1 = gnew[i__]; + gredsq += d__1 * d__1; + } + /* L80: */ + d__[i__] += stplen * s[i__]; + } + /* Computing MAX */ + d__1 = stplen * (ggsav - half * stplen * shs); + sdec = std::max(d__1,zero); + qred += sdec; + } + + /* Restart the conjugate gradient method if it has hit a new bound. */ + + if (iact > 0) { + ++nact; + xbdi[iact] = one; + if (s[iact] < zero) { + xbdi[iact] = onemin; + } + /* Computing 2nd power */ + d__1 = d__[iact]; + delsq -= d__1 * d__1; + if (delsq <= zero) { + goto L90; + } + goto L20; + } + + /* If STPLEN is less than BLEN, then either apply another conjugate */ + /* gradient iteration or RETURN. */ + + if (stplen < blen) { + if (iterc == itermax) { + goto L190; + } + if (sdec <= qred * .01) { + goto L190; + } + beta = gredsq / ggsav; + goto L30; + } +L90: + *crvmin = zero; + + /* Prepare for the alternative iteration by calculating some scalars */ + /* and by multiplying the reduced D by the second derivative matrix of */ + /* Q, where S holds the reduced D in the call of GGMULT. */ + +L100: + if (nact >= n - 1) { + goto L190; + } + dredsq = zero; + dredg = zero; + gredsq = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + if (xbdi[i__] == zero) { + /* Computing 2nd power */ + d__1 = d__[i__]; + dredsq += d__1 * d__1; + dredg += d__[i__] * gnew[i__]; + /* Computing 2nd power */ + d__1 = gnew[i__]; + gredsq += d__1 * d__1; + s[i__] = d__[i__]; + } else { + s[i__] = zero; + } + /* L110: */ + } + itcsav = iterc; + goto L210; + + /* Let the search direction S be a linear combination of the reduced D */ + /* and the reduced G that is orthogonal to the reduced D. */ + +L120: + ++iterc; + temp = gredsq * dredsq - dredg * dredg; + if (temp <= qred * 1e-4 * qred) { + goto L190; + } + temp = std::sqrt(temp); + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + if (xbdi[i__] == zero) { + s[i__] = (dredg * d__[i__] - dredsq * gnew[i__]) / temp; + } else { + s[i__] = zero; + } + /* L130: */ + } + sredg = -temp; + + /* By considering the simple bounds on the variables, calculate an upper */ + /* bound on the tangent of half the angle of the alternative iteration, */ + /* namely ANGBD, except that, if already a free variable has reached a */ + /* bound, there is a branch back to label 100 after fixing that variable. */ + + angbd = one; + iact = 0; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + if (xbdi[i__] == zero) { + tempa = xopt[i__] + d__[i__] - sl[i__]; + tempb = su[i__] - xopt[i__] - d__[i__]; + if (tempa <= zero) { + ++nact; + xbdi[i__] = onemin; + goto L100; + } else if (tempb <= zero) { + ++nact; + xbdi[i__] = one; + goto L100; + } + /* Computing 2nd power */ + d__1 = d__[i__]; + /* Computing 2nd power */ + d__2 = s[i__]; + ssq = d__1 * d__1 + d__2 * d__2; + /* Computing 2nd power */ + d__1 = xopt[i__] - sl[i__]; + temp = ssq - d__1 * d__1; + if (temp > zero) { + temp = std::sqrt(temp) - s[i__]; + if (angbd * temp > tempa) { + angbd = tempa / temp; + iact = i__; + xsav = onemin; + } + } + /* Computing 2nd power */ + d__1 = su[i__] - xopt[i__]; + temp = ssq - d__1 * d__1; + if (temp > zero) { + temp = std::sqrt(temp) + s[i__]; + if (angbd * temp > tempb) { + angbd = tempb / temp; + iact = i__; + xsav = one; + } + } + } + /* L140: */ + } + + /* Calculate HHD and some curvatures for the alternative iteration. */ + + goto L210; +L150: + shs = zero; + dhs = zero; + dhd = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + if (xbdi[i__] == zero) { + shs += s[i__] * hs[i__]; + dhs += d__[i__] * hs[i__]; + dhd += d__[i__] * hred[i__]; + } + /* L160: */ + } + + /* Seek the greatest reduction in Q for a range of equally spaced values */ + /* of ANGT in [0,ANGBD], where ANGT is the tangent of half the angle of */ + /* the alternative iteration. */ + + redmax = zero; + isav = 0; + redsav = zero; + iu = (integer) (angbd * 17. + 3.1); + i__1 = iu; + for (i__ = 1; i__ <= i__1; ++i__) { + angt = angbd * (doublereal) i__ / (doublereal) iu; + sth = (angt + angt) / (one + angt * angt); + temp = shs + angt * (angt * dhd - dhs - dhs); + rednew = sth * (angt * dredg - sredg - half * sth * temp); + if (rednew > redmax) { + redmax = rednew; + isav = i__; + rdprev = redsav; + } else if (i__ == isav + 1) { + rdnext = rednew; + } + /* L170: */ + redsav = rednew; + } + + /* Return if the reduction is zero. Otherwise, set the sine and cosine */ + /* of the angle of the alternative iteration, and calculate SDEC. */ + + if (isav == 0) { + goto L190; + } + if (isav < iu) { + temp = (rdnext - rdprev) / (redmax + redmax - rdprev - rdnext); + angt = angbd * ((doublereal) isav + half * temp) / (doublereal) iu; + } + cth = (one - angt * angt) / (one + angt * angt); + sth = (angt + angt) / (one + angt * angt); + temp = shs + angt * (angt * dhd - dhs - dhs); + sdec = sth * (angt * dredg - sredg - half * sth * temp); + if (sdec <= zero) { + goto L190; + } + + /* Update GNEW, D and HRED. If the angle of the alternative iteration */ + /* is restricted by a bound on a free variable, that variable is fixed */ + /* at the bound. */ + + dredg = zero; + gredsq = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + gnew[i__] = gnew[i__] + (cth - one) * hred[i__] + sth * hs[i__]; + if (xbdi[i__] == zero) { + d__[i__] = cth * d__[i__] + sth * s[i__]; + dredg += d__[i__] * gnew[i__]; + /* Computing 2nd power */ + d__1 = gnew[i__]; + gredsq += d__1 * d__1; + } + /* L180: */ + hred[i__] = cth * hred[i__] + sth * hs[i__]; + } + qred += sdec; + if (iact > 0 && isav == iu) { + ++nact; + xbdi[iact] = xsav; + goto L100; + } + + /* If SDEC is sufficiently small, then RETURN after setting XNEW to */ + /* XOPT+D, giving careful attention to the bounds. */ + + if (sdec > qred * .01) { + goto L120; + } +L190: + *dsq = zero; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + /* Computing MAX */ + /* Computing MIN */ + d__3 = xopt[i__] + d__[i__], d__4 = su[i__]; + d__1 = std::min(d__3,d__4), d__2 = sl[i__]; + xnew[i__] = std::max(d__1,d__2); + if (xbdi[i__] == onemin) { + xnew[i__] = sl[i__]; + } + if (xbdi[i__] == one) { + xnew[i__] = su[i__]; + } + d__[i__] = xnew[i__] - xopt[i__]; + /* L200: */ + /* Computing 2nd power */ + d__1 = d__[i__]; + *dsq += d__1 * d__1; + } + return; + /* The following instructions multiply the current S-vector by the second */ + /* derivative matrix of the quadratic model, putting the product in HS. */ + /* They are reached from three different parts of the software above and */ + /* they can be regarded as an external subroutine. */ + +L210: + ih = 0; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + hs[j] = zero; + i__2 = j; + for (i__ = 1; i__ <= i__2; ++i__) { + ++ih; + if (i__ < j) { + hs[j] += hq[ih] * s[i__]; + } + /* L220: */ + hs[i__] += hq[ih] * s[j]; + } + } + i__2 = npt; + for (k = 1; k <= i__2; ++k) { + if (pq[k] != zero) { + temp = zero; + i__1 = n; + for (j = 1; j <= i__1; ++j) { + /* L230: */ + temp += xpt[k + j * xpt_dim1] * s[j]; + } + temp *= pq[k]; + i__1 = n; + for (i__ = 1; i__ <= i__1; ++i__) { + /* L240: */ + hs[i__] += temp * xpt[k + i__ * xpt_dim1]; + } + } + /* L250: */ + } + if (*crvmin != zero) { + goto L50; + } + if (iterc > itcsav) { + goto L150; + } + i__2 = n; + for (i__ = 1; i__ <= i__2; ++i__) { + /* L260: */ + hred[i__] = hs[i__]; + } + goto L120; + } /* trsbox_ */ + + // ---------------------------------------------------------------------------------------- + + void update_( + const integer n, + const integer npt, + doublereal *bmat, + doublereal *zmat, + const integer ndim, + doublereal *vlag, + const doublereal beta, + const doublereal denom, + const integer knew, + doublereal *w + ) const + { + /* System generated locals */ + integer bmat_dim1, bmat_offset, zmat_dim1, zmat_offset, i__1, i__2; + doublereal d__1, d__2, d__3; + + /* Local variables */ + integer i__, j, k, jp; + doublereal one, tau, temp; + integer nptm; + doublereal zero, alpha, tempa, tempb, ztest; + + + /* The arrays BMAT and ZMAT are updated, as required by the new position */ + /* of the interpolation point that has the index KNEW. The vector VLAG has */ + /* N+NPT components, set on entry to the first NPT and last N components */ + /* of the product Hw in equation (4.11) of the Powell (2006) paper on */ + /* NEWUOA. Further, BETA is set on entry to the value of the parameter */ + /* with that name, and DENOM is set to the denominator of the updating */ + /* formula. Elements of ZMAT may be treated as zero if their moduli are */ + /* at most ZTEST. The first NDIM elements of W are used for working space. */ + + /* Set some constants. */ + + /* Parameter adjustments */ + zmat_dim1 = npt; + zmat_offset = 1 + zmat_dim1; + zmat -= zmat_offset; + bmat_dim1 = ndim; + bmat_offset = 1 + bmat_dim1; + bmat -= bmat_offset; + --vlag; + --w; + + /* Function Body */ + one = 1.; + zero = 0.; + nptm = npt - n - 1; + ztest = zero; + i__1 = npt; + for (k = 1; k <= i__1; ++k) { + i__2 = nptm; + for (j = 1; j <= i__2; ++j) { + /* L10: */ + /* Computing MAX */ + d__2 = ztest, d__3 = (d__1 = zmat[k + j * zmat_dim1], std::abs(d__1)); + ztest = std::max(d__2,d__3); + } + } + ztest *= 1e-20; + + /* Apply the rotations that put zeros in the KNEW-th row of ZMAT. */ + + i__2 = nptm; + for (j = 2; j <= i__2; ++j) { + if ((d__1 = zmat[knew + j * zmat_dim1], std::abs(d__1)) > ztest) { + /* Computing 2nd power */ + d__1 = zmat[knew + zmat_dim1]; + /* Computing 2nd power */ + d__2 = zmat[knew + j * zmat_dim1]; + temp = std::sqrt(d__1 * d__1 + d__2 * d__2); + tempa = zmat[knew + zmat_dim1] / temp; + tempb = zmat[knew + j * zmat_dim1] / temp; + i__1 = npt; + for (i__ = 1; i__ <= i__1; ++i__) { + temp = tempa * zmat[i__ + zmat_dim1] + tempb * zmat[i__ + j * + zmat_dim1]; + zmat[i__ + j * zmat_dim1] = tempa * zmat[i__ + j * zmat_dim1] + - tempb * zmat[i__ + zmat_dim1]; + /* L20: */ + zmat[i__ + zmat_dim1] = temp; + } + } + zmat[knew + j * zmat_dim1] = zero; + /* L30: */ + } + + /* Put the first NPT components of the KNEW-th column of HLAG into W, */ + /* and calculate the parameters of the updating formula. */ + + i__2 = npt; + for (i__ = 1; i__ <= i__2; ++i__) { + w[i__] = zmat[knew + zmat_dim1] * zmat[i__ + zmat_dim1]; + /* L40: */ + } + alpha = w[knew]; + tau = vlag[knew]; + vlag[knew] -= one; + + /* Complete the updating of ZMAT. */ + + temp = std::sqrt(denom); + tempb = zmat[knew + zmat_dim1] / temp; + tempa = tau / temp; + i__2 = npt; + for (i__ = 1; i__ <= i__2; ++i__) { + /* L50: */ + zmat[i__ + zmat_dim1] = tempa * zmat[i__ + zmat_dim1] - tempb * vlag[ + i__]; + } + + /* Finally, update the matrix BMAT. */ + + i__2 = n; + for (j = 1; j <= i__2; ++j) { + jp = npt + j; + w[jp] = bmat[knew + j * bmat_dim1]; + tempa = (alpha * vlag[jp] - tau * w[jp]) / denom; + tempb = (-(beta) * w[jp] - tau * vlag[jp]) / denom; + i__1 = jp; + for (i__ = 1; i__ <= i__1; ++i__) { + bmat[i__ + j * bmat_dim1] = bmat[i__ + j * bmat_dim1] + tempa * + vlag[i__] + tempb * w[i__]; + if (i__ > npt) { + bmat[jp + (i__ - npt) * bmat_dim1] = bmat[i__ + j * + bmat_dim1]; + } + /* L60: */ + } + } + } /* update_ */ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename funct, + typename T, + typename U + > + double find_min_bobyqa ( + const funct& f, + T& x, + long npt, + const U& x_lower, + const U& x_upper, + const double rho_begin, + const double rho_end, + const long max_f_evals + ) + { + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + // check the requirements. Also split the assert up so that the error message isn't huge. + DLIB_CASSERT(is_col_vector(x) && is_col_vector(x_lower) && is_col_vector(x_upper) && + x.size() == x_lower.size() && x_lower.size() == x_upper.size() && + x.size() > 1 && max_f_evals > 1, + "\tdouble find_min_bobyqa()" + << "\n\t Invalid arguments have been given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t is_col_vector(x_lower): " << is_col_vector(x_lower) + << "\n\t is_col_vector(x_upper): " << is_col_vector(x_upper) + << "\n\t x.size(): " << x.size() + << "\n\t x_lower.size(): " << x_lower.size() + << "\n\t x_upper.size(): " << x_upper.size() + << "\n\t max_f_evals: " << max_f_evals + ); + + DLIB_CASSERT(x.size() + 2 <= npt && npt <= (x.size()+1)*(x.size()+2)/2 && + 0 < rho_end && rho_end < rho_begin && + min(x_upper - x_lower) > 2*rho_begin && + min(x - x_lower) >= 0 && min(x_upper - x) >= 0, + "\tdouble find_min_bobyqa()" + << "\n\t Invalid arguments have been given to this function" + << "\n\t ntp in valid range: " << (x.size() + 2 <= npt && npt <= (x.size()+1)*(x.size()+2)/2) + << "\n\t npt: " << npt + << "\n\t rho_begin: " << rho_begin + << "\n\t rho_end: " << rho_end + << "\n\t min(x_upper - x_lower) > 2*rho_begin: " << (min(x_upper - x_lower) > 2*rho_begin) + << "\n\t min(x - x_lower) >= 0 && min(x_upper - x) >= 0: " << (min(x - x_lower) >= 0 && min(x_upper - x) >= 0) + ); + + + bobyqa_implementation impl; + return impl.find_min(f, x, npt, x_lower, x_upper, rho_begin, rho_end, max_f_evals); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct, + typename T, + typename U + > + double find_max_bobyqa ( + const funct& f, + T& x, + long npt, + const U& x_lower, + const U& x_upper, + const double rho_begin, + const double rho_end, + const long max_f_evals + ) + { + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + return -find_min_bobyqa(negate_function(f), x, npt, x_lower, x_upper, rho_begin, rho_end, max_f_evals); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_BOBYQA_Hh_ + diff --git a/ml/dlib/dlib/optimization/optimization_bobyqa_abstract.h b/ml/dlib/dlib/optimization/optimization_bobyqa_abstract.h new file mode 100644 index 000000000..46f9436af --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_bobyqa_abstract.h @@ -0,0 +1,120 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATIOn_BOBYQA_ABSTRACT_Hh_ +#ifdef DLIB_OPTIMIZATIOn_BOBYQA_ABSTRACT_Hh_ + +#include "../matrix.h" + +// ---------------------------------------------------------------------------------------- + +/* + This file defines the dlib interface to the BOBYQA software developed by M.J.D Powell. + BOBYQA is a method for optimizing a function in the absence of derivative information. + Powell described it as a method that seeks the least value of a function of many + variables, by applying a trust region method that forms quadratic models by + interpolation. There is usually some freedom in the interpolation conditions, + which is taken up by minimizing the Frobenius norm of the change to the second + derivative of the model, beginning with the zero matrix. The values of the variables + are constrained by upper and lower bounds. + + + The following paper, published in 2009 by Powell, describes the + detailed working of the BOBYQA algorithm. + + The BOBYQA algorithm for bound constrained optimization + without derivatives by M.J.D. Powell +*/ + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + class bobyqa_failure : public error; + /*! + This is the exception class used by the functions defined in this file. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct, + typename T, + typename U + > + double find_min_bobyqa ( + const funct& f, + T& x, + long npt, + const U& x_lower, + const U& x_upper, + const double rho_begin, + const double rho_end, + const long max_f_evals + ); + /*! + requires + - f(x) must be a valid expression that evaluates to a double + - is_col_vector(x) == true + - is_col_vector(x_lower) == true + - is_col_vector(x_upper) == true + - x.size() == x_lower.size() == x_upper.size() + - x.size() > 1 + - x.size() + 2 <= npt <= (x.size()+1)*(x.size()+2)/2 + - 0 < rho_end < rho_begin + - min(x_upper - x_lower) > 2*rho_begin + (i.e. the lower and upper bounds on each x element must be larger than 2*rho_begin) + - min(x - x_lower) >= 0 && min(x_upper - x) >= 0 + (i.e. the given x should be within the bounds defined by x_lower and x_upper) + - max_f_evals > 1 + ensures + - Performs a constrained minimization of the function f() starting from + the initial point x. + - The BOBYQA algorithm uses a number of interpolating points to perform its + work. The npt argument controls how many points get used. Typically, + a good value to use is 2*x.size()+1. + - #x == the value of x (within the bounds defined by x_lower and x_upper) that + was found to minimize f(). More precisely: + - min(#x - x_lower) >= 0 && min(x_upper - #x) >= 0 + - returns f(#x). + - rho_begin and rho_end are used as the initial and final values of a trust + region radius. Typically, rho_begin should be about one tenth of the greatest + expected change to a variable, while rho_end should indicate the accuracy that + is required in the final values of the variables. + throws + - bobyqa_failure + This exception is thrown if the algorithm is unable to make progress towards + solving the problem. This may occur because the algorithm detects excessive + numerical errors or because max_f_evals of f() have occurred without reaching + convergence. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct, + typename T, + typename U + > + double find_max_bobyqa ( + const funct& f, + T& x, + long npt, + const U& x_lower, + const U& x_upper, + const double rho_begin, + const double rho_end, + const long max_f_evals + ); + /*! + This function is identical to the find_min_bobyqa() routine defined above + except that it negates the f() function before performing optimization. + Thus this function will attempt to find the maximizer of f() rather than + the minimizer. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_BOBYQA_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/optimization/optimization_least_squares.h b/ml/dlib/dlib/optimization/optimization_least_squares.h new file mode 100644 index 000000000..6d12a919d --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_least_squares.h @@ -0,0 +1,345 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATION_LEAST_SQuARES_H_h_ +#define DLIB_OPTIMIZATION_LEAST_SQuARES_H_h_ + +#include "../matrix.h" +#include "optimization_trust_region.h" +#include "optimization_least_squares_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename column_vector_type, + typename funct_type, + typename funct_der_type, + typename vector_type + > + class least_squares_function_model + { + public: + least_squares_function_model ( + const funct_type& f_, + const funct_der_type& der_, + const vector_type& list_ + ) : f(f_), der(der_), list(list_) + { + S = 0; + last_f = 0; + last_f2 = 0; + + r.set_size(list.size(),1); + } + + const funct_type& f; + const funct_der_type& der; + const vector_type& list; + + typedef typename column_vector_type::type type; + typedef typename column_vector_type::mem_manager_type mem_manager_type; + typedef typename column_vector_type::layout_type layout_type; + const static long NR = column_vector_type::NR; + + typedef column_vector_type column_vector; + typedef matrix general_matrix; + + + type operator() ( + const column_vector& x + ) const + { + type result = 0; + for (long i = 0; i < list.size(); ++i) + { + const type temp = f(list(i), x); + // save the residual for later + r(i) = temp; + result += temp*temp; + } + + last_f = 0.5*result; + return 0.5*result; + } + + void get_derivative_and_hessian ( + const column_vector& x, + column_vector& d, + general_matrix& h + ) const + { + J.set_size(list.size(), x.size()); + + // compute the Jacobian + for (long i = 0; i < list.size(); ++i) + { + set_rowm(J,i) = trans(der(list(i), x)); + } + + // Compute the Levenberg-Marquardt gradient and hessian + d = trans(J)*r; + h = trans(J)*J; + + if (S.size() == 0) + { + S.set_size(x.size(), x.size()); + S = 0; + } + + // If this isn't the first iteration then check if using + // a quasi-newton update helps things. + if (last_r.size() != 0) + { + + s = x - last_x; + y = d - last_d; + yy = d - trans(last_J)*r; + + const type ys = trans(y)*s; + vtemp = yy - S*s; + const type temp2 = std::abs(trans(s)*S*s); + type scale = (temp2 != 0) ? std::min(1, std::abs(dot(s,yy))/temp2) : 1; + + if (ys != 0) + S = scale*S + (vtemp*trans(y) + y*trans(vtemp))/(ys) - dot(vtemp,s)/ys/ys*y*trans(y); + else + S *= scale; + + // check how well both the models fit the last change we saw in f() + const type measured_delta = last_f2 - last_f; + s = -s; + const type h_predicted_delta = 0.5*trans(s)*h*s + trans(d)*s; + const type s_predicted_delta = 0.5*trans(s)*(h+S)*s + trans(d)*s; + + const type h_error = std::abs((h_predicted_delta/measured_delta) - 1); + const type s_error = std::abs((s_predicted_delta/measured_delta) - 1); + + if (s_error < h_error && h_error > 0.01) + { + h += make_symmetric(S); + } + else if (s_error > 10) + { + S = 0; + } + + // put r into last_r + r.swap(last_r); + } + else + { + // put r into last_r + last_r = r; + } + + J.swap(last_J); + last_x = x; + last_d = d; + + last_f2 = last_f; + } + + mutable type last_f; // value of function we saw in last operator() + mutable type last_f2; // value of last_f we saw in get_derivative_and_hessian() + mutable matrix r; + mutable column_vector vtemp; + mutable column_vector s, y, yy; + + mutable general_matrix S; + mutable column_vector last_x; + mutable column_vector last_d; + mutable matrix last_r; + mutable matrix last_J; + mutable matrix J; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename column_vector_type, + typename funct_type, + typename funct_der_type, + typename vector_type + > + least_squares_function_model least_squares_model ( + const funct_type& f, + const funct_der_type& der, + const vector_type& list + ) + { + return least_squares_function_model(f,der,list); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stop_strategy_type, + typename funct_type, + typename funct_der_type, + typename vector_type, + typename T + > + double solve_least_squares ( + stop_strategy_type stop_strategy, + const funct_type& f, + const funct_der_type& der, + const vector_type& list, + T& x, + double radius = 1 + ) + { + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(mat(list)) && list.size() > 0 && + is_col_vector(x) && radius > 0, + "\t double solve_least_squares()" + << "\n\t invalid arguments were given to this function" + << "\n\t is_vector(list): " << is_vector(mat(list)) + << "\n\t list.size(): " << list.size() + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t radius: " << radius + ); + + return find_min_trust_region(stop_strategy, + least_squares_model(f, der, mat(list)), + x, + radius); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename column_vector_type, + typename funct_type, + typename funct_der_type, + typename vector_type + > + class least_squares_lm_function_model + { + public: + least_squares_lm_function_model ( + const funct_type& f_, + const funct_der_type& der_, + const vector_type& list_ + ) : f(f_), der(der_), list(list_) + { + r.set_size(list.size(),1); + } + + const funct_type& f; + const funct_der_type& der; + const vector_type& list; + + typedef typename column_vector_type::type type; + typedef typename column_vector_type::mem_manager_type mem_manager_type; + typedef typename column_vector_type::layout_type layout_type; + const static long NR = column_vector_type::NR; + + typedef column_vector_type column_vector; + typedef matrix general_matrix; + + mutable matrix r; + mutable column_vector vtemp; + + type operator() ( + const column_vector& x + ) const + { + type result = 0; + for (long i = 0; i < list.size(); ++i) + { + const type temp = f(list(i), x); + // save the residual for later + r(i) = temp; + result += temp*temp; + } + + return 0.5*result; + } + + void get_derivative_and_hessian ( + const column_vector& x, + column_vector& d, + general_matrix& h + ) const + { + d = 0; + h = 0; + for (long i = 0; i < list.size(); ++i) + { + vtemp = der(list(i), x); + d += r(i)*vtemp; + h += vtemp*trans(vtemp); + } + } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename column_vector_type, + typename funct_type, + typename funct_der_type, + typename vector_type + > + least_squares_lm_function_model least_squares_lm_model ( + const funct_type& f, + const funct_der_type& der, + const vector_type& list + ) + { + return least_squares_lm_function_model(f,der,list); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stop_strategy_type, + typename funct_type, + typename funct_der_type, + typename vector_type, + typename T + > + double solve_least_squares_lm ( + stop_strategy_type stop_strategy, + const funct_type& f, + const funct_der_type& der, + const vector_type& list, + T& x, + double radius = 1 + ) + { + // The starting point (i.e. x) must be a column vector. + COMPILE_TIME_ASSERT(T::NC <= 1); + + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(mat(list)) && list.size() > 0 && + is_col_vector(x) && radius > 0, + "\t double solve_least_squares_lm()" + << "\n\t invalid arguments were given to this function" + << "\n\t is_vector(list): " << is_vector(mat(list)) + << "\n\t list.size(): " << list.size() + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t radius: " << radius + ); + + return find_min_trust_region(stop_strategy, + least_squares_lm_model(f, der, mat(list)), + x, + radius); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATION_LEAST_SQuARES_H_h_ + + diff --git a/ml/dlib/dlib/optimization/optimization_least_squares_abstract.h b/ml/dlib/dlib/optimization/optimization_least_squares_abstract.h new file mode 100644 index 000000000..aaffb2221 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_least_squares_abstract.h @@ -0,0 +1,112 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATIOn_LEAST_SQUARES_ABSTRACT_ +#ifdef DLIB_OPTIMIZATIOn_LEAST_SQUARES_ABSTRACT_ + +#include "../matrix/matrix_abstract.h" +#include "optimization_trust_region_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename stop_strategy_type, + typename funct_type, + typename funct_der_type, + typename vector_type, + typename T + > + double solve_least_squares ( + stop_strategy_type stop_strategy, + const funct_type& f, + const funct_der_type& der, + const vector_type& list, + T& x, + double radius = 1 + ); + /*! + requires + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - list == a matrix or something convertible to a matrix via mat() + such as a std::vector. + - is_vector(list) == true + - list.size() > 0 + - is_col_vector(x) == true + - radius > 0 + - for all valid i: + - f(list(i),x) must be a valid expression that evaluates to a floating point value. + - der(list(i),x) must be a valid expression that evaluates to the derivative of f(list(i),x) + with respect to x. This derivative must take the form of a column vector. + ensures + - This function performs an unconstrained minimization of the least squares + function g(x) defined by: + - g(x) = sum over all i: 0.5*pow( f(list(i),x), 2 ) + - This method combines the Levenberg-Marquardt method with a quasi-newton method + for approximating the second order terms of the hessian and is appropriate for + large residual problems (i.e. problems where the f() function isn't driven to 0). + In particular, it uses the method of Dennis, Gay, and Welsch as described in + Numerical Optimization by Nocedal and Wright (second edition). + - Since this is a trust region algorithm, the radius parameter defines the initial + size of the trust region. + - The function is optimized until stop_strategy decides that an acceptable + point has been found or the trust region subproblem fails to make progress. + - #x == the value of x that was found to minimize g() + - returns g(#x). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename stop_strategy_type, + typename funct_type, + typename funct_der_type, + typename vector_type, + typename T + > + double solve_least_squares_lm ( + stop_strategy_type stop_strategy, + const funct_type& f, + const funct_der_type& der, + const vector_type& list, + T& x, + double radius = 1 + ); + /*! + requires + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - list == a matrix or something convertible to a matrix via mat() + such as a std::vector. + - is_vector(list) == true + - list.size() > 0 + - is_col_vector(x) == true + - radius > 0 + - for all valid i: + - f(list(i),x) must be a valid expression that evaluates to a floating point value. + - der(list(i),x) must be a valid expression that evaluates to the derivative of f(list(i),x) + with respect to x. This derivative must take the form of a column vector. + ensures + - This function performs an unconstrained minimization of the least squares + function g(x) defined by: + - g(x) = sum over all i: 0.5*pow( f(list(i),x), 2 ) + - This method implements a plain Levenberg-Marquardt approach for approximating + the hessian of g(). Therefore, it is most appropriate for small residual problems + (i.e. problems where f() goes to 0 at the solution). + - Since this is a trust region algorithm, the radius parameter defines the initial + size of the trust region. + - The function is optimized until stop_strategy decides that an acceptable + point has been found or the trust region subproblem fails to make progress. + - #x == the value of x that was found to minimize g() + - returns g(#x). + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_LEAST_SQUARES_ABSTRACT_ + + diff --git a/ml/dlib/dlib/optimization/optimization_line_search.h b/ml/dlib/dlib/optimization/optimization_line_search.h new file mode 100644 index 000000000..a91e3df84 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_line_search.h @@ -0,0 +1,888 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATIOn_LINE_SEARCH_H_ +#define DLIB_OPTIMIZATIOn_LINE_SEARCH_H_ + +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "optimization_line_search_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + class line_search_funct + { + public: + line_search_funct(const funct& f_, const T& start_, const T& direction_) + : f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(0) + {} + + line_search_funct(const funct& f_, const T& start_, const T& direction_, T& r) + : f(f_),start(start_), direction(direction_), matrix_r(&r), scalar_r(0) + {} + + line_search_funct(const funct& f_, const T& start_, const T& direction_, double& r) + : f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(&r) + {} + + double operator()(const double& x) const + { + return get_value(f(start + x*direction)); + } + + private: + + double get_value (const double& r) const + { + // save a copy of this value for later + if (scalar_r) + *scalar_r = r; + + return r; + } + + template + double get_value (const U& r) const + { + // U should be a matrix type + COMPILE_TIME_ASSERT(is_matrix::value); + + // save a copy of this value for later + if (matrix_r) + *matrix_r = r; + + return dot(r,direction); + } + + const funct& f; + const T& start; + const T& direction; + T* matrix_r; + double* scalar_r; + }; + + template + const line_search_funct make_line_search_function(const funct& f, const T& start, const T& direction) + { + COMPILE_TIME_ASSERT(is_matrix::value); + DLIB_ASSERT ( + is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(), + "\tline_search_funct make_line_search_function(f,start,direction)" + << "\n\tYou have to supply column vectors to this function" + << "\n\tstart.nc(): " << start.nc() + << "\n\tdirection.nc(): " << direction.nc() + << "\n\tstart.nr(): " << start.nr() + << "\n\tdirection.nr(): " << direction.nr() + ); + return line_search_funct(f,start,direction); + } + +// ---------------------------------------------------------------------------------------- + + template + const line_search_funct make_line_search_function(const funct& f, const T& start, const T& direction, double& f_out) + { + COMPILE_TIME_ASSERT(is_matrix::value); + DLIB_ASSERT ( + is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(), + "\tline_search_funct make_line_search_function(f,start,direction)" + << "\n\tYou have to supply column vectors to this function" + << "\n\tstart.nc(): " << start.nc() + << "\n\tdirection.nc(): " << direction.nc() + << "\n\tstart.nr(): " << start.nr() + << "\n\tdirection.nr(): " << direction.nr() + ); + return line_search_funct(f,start,direction, f_out); + } + +// ---------------------------------------------------------------------------------------- + + template + const line_search_funct make_line_search_function(const funct& f, const T& start, const T& direction, T& grad_out) + { + COMPILE_TIME_ASSERT(is_matrix::value); + DLIB_ASSERT ( + is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(), + "\tline_search_funct make_line_search_function(f,start,direction)" + << "\n\tYou have to supply column vectors to this function" + << "\n\tstart.nc(): " << start.nc() + << "\n\tdirection.nc(): " << direction.nc() + << "\n\tstart.nr(): " << start.nr() + << "\n\tdirection.nr(): " << direction.nr() + ); + return line_search_funct(f,start,direction,grad_out); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + inline double poly_min_extrap ( + double f0, + double d0, + double f1, + double d1, + double limit = 1 + ) + { + const double n = 3*(f1 - f0) - 2*d0 - d1; + const double e = d0 + d1 - 2*(f1 - f0); + + + // find the minimum of the derivative of the polynomial + + double temp = std::max(n*n - 3*e*d0,0.0); + + if (temp < 0) + return 0.5; + + temp = std::sqrt(temp); + + if (std::abs(e) <= std::numeric_limits::epsilon()) + return 0.5; + + // figure out the two possible min values + double x1 = (temp - n)/(3*e); + double x2 = -(temp + n)/(3*e); + + // compute the value of the interpolating polynomial at these two points + double y1 = f0 + d0*x1 + n*x1*x1 + e*x1*x1*x1; + double y2 = f0 + d0*x2 + n*x2*x2 + e*x2*x2*x2; + + // pick the best point + double x; + if (y1 < y2) + x = x1; + else + x = x2; + + // now make sure the minimum is within the allowed range of [0,limit] + return put_in_range(0,limit,x); + } + +// ---------------------------------------------------------------------------------------- + + inline double poly_min_extrap ( + double f0, + double d0, + double f1 + ) + { + const double temp = 2*(f1 - f0 - d0); + if (std::abs(temp) <= d0*std::numeric_limits::epsilon()) + return 0.5; + + const double alpha = -d0/temp; + + // now make sure the minimum is within the allowed range of (0,1) + return put_in_range(0,1,alpha); + } + +// ---------------------------------------------------------------------------------------- + + inline double poly_min_extrap ( + double f0, + double d0, + double x1, + double f_x1, + double x2, + double f_x2 + ) + { + DLIB_ASSERT(0 < x1 && x1 < x2,"Invalid inputs were given to this function"); + // The contents of this function follow the equations described on page 58 of the + // book Numerical Optimization by Nocedal and Wright, second edition. + matrix m; + matrix v; + + const double aa2 = x2*x2; + const double aa1 = x1*x1; + m = aa2, -aa1, + -aa2*x2, aa1*x1; + v = f_x1 - f0 - d0*x1, + f_x2 - f0 - d0*x2; + + + double temp = aa2*aa1*(x1-x2); + + // just take a guess if this happens + if (temp == 0 || std::fpclassify(temp) == FP_SUBNORMAL) + { + return x1/2.0; + } + + matrix temp2; + temp2 = m*v/temp; + const double a = temp2(0); + const double b = temp2(1); + + temp = b*b - 3*a*d0; + if (temp < 0 || a == 0) + { + // This is probably a line so just pick the lowest point + if (f0 < f_x2) + return 0; + else + return x2; + } + temp = (-b + std::sqrt(temp))/(3*a); + return put_in_range(0, x2, temp); + } + +// ---------------------------------------------------------------------------------------- + + inline double lagrange_poly_min_extrap ( + double p1, + double p2, + double p3, + double f1, + double f2, + double f3 + ) + { + DLIB_ASSERT(p1 < p2 && p2 < p3 && f1 >= f2 && f2 <= f3, + " p1: " << p1 + << " p2: " << p2 + << " p3: " << p3 + << " f1: " << f1 + << " f2: " << f2 + << " f3: " << f3); + + // This formula is out of the book Nonlinear Optimization by Andrzej Ruszczynski. See section 5.2. + double temp1 = f1*(p3*p3 - p2*p2) + f2*(p1*p1 - p3*p3) + f3*(p2*p2 - p1*p1); + double temp2 = 2*(f1*(p3 - p2) + f2*(p1 - p3) + f3*(p2 - p1) ); + + if (temp2 == 0) + { + return p2; + } + + const double result = temp1/temp2; + + // do a final sanity check to make sure the result is in the right range + if (p1 <= result && result <= p3) + { + return result; + } + else + { + return std::min(std::max(p1,result),p3); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct, + typename funct_der + > + double line_search ( + const funct& f, + const double f0, + const funct_der& der, + const double d0, + double rho, + double sigma, + double min_f, + unsigned long max_iter + ) + { + DLIB_ASSERT ( + 0 < rho && rho < sigma && sigma < 1 && max_iter > 0, + "\tdouble line_search()" + << "\n\tYou have given invalid arguments to this function" + << "\n\t sigma: " << sigma + << "\n\t rho: " << rho + << "\n\t max_iter: " << max_iter + ); + + // The bracketing phase of this function is implemented according to block 2.6.2 from + // the book Practical Methods of Optimization by R. Fletcher. The sectioning + // phase is an implementation of 2.6.4 from the same book. + + // 1 <= tau1a < tau1b. Controls the alpha jump size during the bracketing phase of + // the search. + const double tau1a = 1.4; + const double tau1b = 9; + + // it must be the case that 0 < tau2 < tau3 <= 1/2 for the algorithm to function + // correctly but the specific values of tau2 and tau3 aren't super important. + const double tau2 = 1.0/10.0; + const double tau3 = 1.0/2.0; + + + // Stop right away and return a step size of 0 if the gradient is 0 at the starting point + if (std::abs(d0) <= std::abs(f0)*std::numeric_limits::epsilon()) + return 0; + + // Stop right away if the current value is good enough according to min_f + if (f0 <= min_f) + return 0; + + // Figure out a reasonable upper bound on how large alpha can get. + const double mu = (min_f-f0)/(rho*d0); + + + double alpha = 1; + if (mu < 0) + alpha = -alpha; + alpha = put_in_range(0, 0.65*mu, alpha); + + + double last_alpha = 0; + double last_val = f0; + double last_val_der = d0; + + // The bracketing stage will find a range of points [a,b] + // that contains a reasonable solution to the line search + double a, b; + + // These variables will hold the values and derivatives of f(a) and f(b) + double a_val, b_val, a_val_der, b_val_der; + + // This thresh value represents the Wolfe curvature condition + const double thresh = std::abs(sigma*d0); + + unsigned long itr = 0; + // do the bracketing stage to find the bracket range [a,b] + while (true) + { + ++itr; + const double val = f(alpha); + const double val_der = der(alpha); + + // we are done with the line search since we found a value smaller + // than the minimum f value + if (val <= min_f) + return alpha; + + if (val > f0 + rho*alpha*d0 || val >= last_val) + { + a_val = last_val; + a_val_der = last_val_der; + b_val = val; + b_val_der = val_der; + + a = last_alpha; + b = alpha; + break; + } + + if (std::abs(val_der) <= thresh) + return alpha; + + // if we are stuck not making progress then quit with the current alpha + if (last_alpha == alpha || itr >= max_iter) + return alpha; + + if (val_der >= 0) + { + a_val = val; + a_val_der = val_der; + b_val = last_val; + b_val_der = last_val_der; + + a = alpha; + b = last_alpha; + break; + } + + + + const double temp = alpha; + // Pick a larger range [first, last]. We will pick the next alpha in that + // range. + double first, last; + if (mu > 0) + { + first = std::min(mu, alpha + tau1a*(alpha - last_alpha)); + last = std::min(mu, alpha + tau1b*(alpha - last_alpha)); + } + else + { + first = std::max(mu, alpha + tau1a*(alpha - last_alpha)); + last = std::max(mu, alpha + tau1b*(alpha - last_alpha)); + } + + + + // pick a point between first and last by doing some kind of interpolation + if (last_alpha < alpha) + alpha = last_alpha + (alpha-last_alpha)*poly_min_extrap(last_val, last_val_der, val, val_der, 1e10); + else + alpha = alpha + (last_alpha-alpha)*poly_min_extrap(val, val_der, last_val, last_val_der, 1e10); + + alpha = put_in_range(first,last,alpha); + + last_alpha = temp; + + last_val = val; + last_val_der = val_der; + + } + + + // Now do the sectioning phase from 2.6.4 + while (true) + { + ++itr; + double first = a + tau2*(b-a); + double last = b - tau3*(b-a); + + // use interpolation to pick alpha between first and last + alpha = a + (b-a)*poly_min_extrap(a_val, a_val_der, b_val, b_val_der); + alpha = put_in_range(first,last,alpha); + + const double val = f(alpha); + const double val_der = der(alpha); + + // we are done with the line search since we found a value smaller + // than the minimum f value or we ran out of iterations. + if (val <= min_f || itr >= max_iter) + return alpha; + + // stop if the interval gets so small that it isn't shrinking any more due to rounding error + if (a == first || b == last) + { + return b; + } + + // If alpha has basically become zero then just stop. Think of it like this, + // if we take the largest possible alpha step will the objective function + // change at all? If not then there isn't any point looking for a better + // alpha. + const double max_possible_alpha = std::max(std::abs(a),std::abs(b)); + if (std::abs(max_possible_alpha*d0) <= std::abs(f0)*std::numeric_limits::epsilon()) + return alpha; + + + if (val > f0 + rho*alpha*d0 || val >= a_val) + { + b = alpha; + b_val = val; + b_val_der = val_der; + } + else + { + if (std::abs(val_der) <= thresh) + return alpha; + + if ( (b-a)*val_der >= 0) + { + b = a; + b_val = a_val; + b_val_der = a_val_der; + } + + a = alpha; + a_val = val; + a_val_der = val_der; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + double backtracking_line_search ( + const funct& f, + double f0, + double d0, + double alpha, + double rho, + unsigned long max_iter + ) + { + DLIB_ASSERT ( + 0 < rho && rho < 1 && max_iter > 0, + "\tdouble backtracking_line_search()" + << "\n\tYou have given invalid arguments to this function" + << "\n\t rho: " << rho + << "\n\t max_iter: " << max_iter + ); + + // make sure alpha is going in the right direction. That is, it should be opposite + // the direction of the gradient. + if ((d0 > 0 && alpha > 0) || + (d0 < 0 && alpha < 0)) + { + alpha *= -1; + } + + bool have_prev_alpha = false; + double prev_alpha = 0; + double prev_val = 0; + unsigned long iter = 0; + while (true) + { + ++iter; + const double val = f(alpha); + if (val <= f0 + alpha*rho*d0 || iter >= max_iter) + { + return alpha; + } + else + { + // Interpolate a new alpha. We also make sure the step by which we + // reduce alpha is not super small. + double step; + if (!have_prev_alpha) + { + if (d0 < 0) + step = alpha*put_in_range(0.1,0.9, poly_min_extrap(f0, d0, val)); + else + step = alpha*put_in_range(0.1,0.9, poly_min_extrap(f0, -d0, val)); + have_prev_alpha = true; + } + else + { + if (d0 < 0) + step = put_in_range(0.1*alpha,0.9*alpha, poly_min_extrap(f0, d0, alpha, val, prev_alpha, prev_val)); + else + step = put_in_range(0.1*alpha,0.9*alpha, -poly_min_extrap(f0, -d0, -alpha, val, -prev_alpha, prev_val)); + } + + prev_alpha = alpha; + prev_val = val; + + alpha = step; + } + } + } + +// ---------------------------------------------------------------------------------------- + + class optimize_single_variable_failure : public error { + public: optimize_single_variable_failure(const std::string& s):error(s){} + }; + +// ---------------------------------------------------------------------------------------- + + template + double find_min_single_variable ( + const funct& f, + double& starting_point, + const double begin = -1e200, + const double end = 1e200, + const double eps = 1e-3, + const long max_iter = 100, + const double initial_search_radius = 1 + ) + { + DLIB_CASSERT( eps > 0 && + max_iter > 1 && + begin <= starting_point && starting_point <= end && + initial_search_radius > 0, + "eps: " << eps + << "\n max_iter: "<< max_iter + << "\n begin: "<< begin + << "\n end: "<< end + << "\n starting_point: "<< starting_point + << "\n initial_search_radius: "<< initial_search_radius + ); + + double search_radius = initial_search_radius; + + double p1=0, p2=0, p3=0, f1=0, f2=0, f3=0; + long f_evals = 1; + + if (begin == end) + { + return f(starting_point); + } + + using std::abs; + using std::min; + using std::max; + + // find three bracketing points such that f1 > f2 < f3. Do this by generating a sequence + // of points expanding away from 0. Also note that, in the following code, it is always the + // case that p1 < p2 < p3. + + + + // The first thing we do is get a starting set of 3 points that are inside the [begin,end] bounds + p1 = max(starting_point-search_radius, begin); + p3 = min(starting_point+search_radius, end); + f1 = f(p1); + f3 = f(p3); + + if (starting_point == p1 || starting_point == p3) + { + p2 = (p1+p3)/2; + f2 = f(p2); + } + else + { + p2 = starting_point; + f2 = f(starting_point); + } + + f_evals += 2; + + // Now we have 3 points on the function. Start looking for a bracketing set such that + // f1 > f2 < f3 is the case. + while ( !(f1 > f2 && f2 < f3)) + { + // check for hitting max_iter or if the interval is now too small + if (f_evals >= max_iter) + { + throw optimize_single_variable_failure( + "The max number of iterations of single variable optimization have been reached\n" + "without converging."); + } + if (p3-p1 < eps) + { + if (f1 < min(f2,f3)) + { + starting_point = p1; + return f1; + } + + if (f2 < min(f1,f3)) + { + starting_point = p2; + return f2; + } + + starting_point = p3; + return f3; + } + + // If the left most points are identical in function value then expand out the + // left a bit, unless it's already at bound or we would drop that left most + // point anyway because it's bad. + if (f1==f2 && f1 f2 < f3 and p1 < p2 < p3 + const double tau = 0.1; + while( f_evals < max_iter && p3-p1 > eps) + { + double p_min = lagrange_poly_min_extrap(p1,p2,p3, f1,f2,f3); + + + // make sure p_min isn't too close to the three points we already have + if (p_min < p2) + { + const double min_dist = (p2-p1)*tau; + if (abs(p1-p_min) < min_dist) + { + p_min = p1 + min_dist; + } + else if (abs(p2-p_min) < min_dist) + { + p_min = p2 - min_dist; + } + } + else + { + const double min_dist = (p3-p2)*tau; + if (abs(p2-p_min) < min_dist) + { + p_min = p2 + min_dist; + } + else if (abs(p3-p_min) < min_dist) + { + p_min = p3 - min_dist; + } + } + + // make sure one side of the bracket isn't super huge compared to the other + // side. If it is then contract it. + const double bracket_ratio = abs(p1-p2)/abs(p2-p3); + if ( !( bracket_ratio < 10 && bracket_ratio > 0.1) ) + { + // Force p_min to be on a reasonable side. But only if lagrange_poly_min_extrap() + // didn't put it on a good side already. + if (bracket_ratio > 1 && p_min > p2) + p_min = (p1+p2)/2; + else if (p_min < p2) + p_min = (p2+p3)/2; + } + + + const double f_min = f(p_min); + + + // Remove one of the endpoints of our bracket depending on where the new point falls. + if (p_min < p2) + { + if (f1 > f_min && f_min < f2) + { + p3 = p2; + f3 = f2; + p2 = p_min; + f2 = f_min; + } + else + { + p1 = p_min; + f1 = f_min; + } + } + else + { + if (f2 > f_min && f_min < f3) + { + p1 = p2; + f1 = f2; + p2 = p_min; + f2 = f_min; + } + else + { + p3 = p_min; + f3 = f_min; + } + } + + + ++f_evals; + } + + if (f_evals >= max_iter) + { + throw optimize_single_variable_failure( + "The max number of iterations of single variable optimization have been reached\n" + "without converging."); + } + + starting_point = p2; + return f2; + } + +// ---------------------------------------------------------------------------------------- + + template + class negate_function_object + { + public: + negate_function_object(const funct& f_) : f(f_){} + + template + double operator()(const T& x) const + { + return -f(x); + } + + private: + const funct& f; + }; + + template + const negate_function_object negate_function(const funct& f) { return negate_function_object(f); } + +// ---------------------------------------------------------------------------------------- + + template + double find_max_single_variable ( + const funct& f, + double& starting_point, + const double begin = -1e200, + const double end = 1e200, + const double eps = 1e-3, + const long max_iter = 100, + const double initial_search_radius = 1 + ) + { + return -find_min_single_variable(negate_function(f), starting_point, begin, end, eps, max_iter, initial_search_radius); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_LINE_SEARCH_H_ + diff --git a/ml/dlib/dlib/optimization/optimization_line_search_abstract.h b/ml/dlib/dlib/optimization/optimization_line_search_abstract.h new file mode 100644 index 000000000..2aa221da4 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_line_search_abstract.h @@ -0,0 +1,361 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATIOn_ABSTRACT_ +#ifdef DLIB_OPTIMIZATIOn_ABSTRACT_ + +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct, + typename T + > + class line_search_funct; + /*! + This object is a function object that represents a line search function. + + Moreover, it represents a function with the signature: + double l(double x) + !*/ + + template < + typename funct, + typename T + > + const line_search_funct make_line_search_function ( + const funct& f, + const T& start, + const T& direction + ); + /*! + requires + - is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size() + (i.e. start and direction should be column vectors of the same size) + - f must return either a double or a column vector the same length as start + - f(start + 1.5*direction) should be a valid expression + ensures + - if (f returns a double) then + - returns a line search function that computes l(x) == f(start + x*direction) + - else + - returns a line search function that computes l(x) == dot(f(start + x*direction),direction). + That is, we assume f is the derivative of some other function and that what + f returns is a gradient vector. + So the following two expressions both create the derivative of l(x): + - derivative(make_line_search_function(funct,start,direction)) + - make_line_search_function(derivative(funct),start,direction) + !*/ + + template < + typename funct, + typename T + > + const line_search_funct make_line_search_function ( + const funct& f, + const T& start, + const T& direction, + double& f_out + ); + /*! + This function is identical to the above three argument version of make_line_search_function() + except that, if f() outputs a double, every time f() is evaluated its output is also stored + into f_out. + !*/ + + template < + typename funct, + typename T + > + const line_search_funct make_line_search_function ( + const funct& f, + const T& start, + const T& direction, + T& gradient_out + ); + /*! + This function is identical to the above three argument version of make_line_search_function() + except that, if f() outputs a column vector, every time f() is evaluated its output is also + stored into gradient_out. + !*/ + +// ---------------------------------------------------------------------------------------- + + inline double poly_min_extrap ( + double f0, + double d0, + double f1, + double d1, + double limit = 1 + ); + /*! + ensures + - let c(x) be a 3rd degree polynomial such that: + - c(0) == f0 + - c(1) == f1 + - derivative of c(x) at x==0 is d0 + - derivative of c(x) at x==1 is d1 + - returns the point in the range [0,limit] that minimizes the polynomial c(x) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline double poly_min_extrap ( + double f0, + double d0, + double f1 + ); + /*! + ensures + - let c(x) be a 2nd degree polynomial such that: + - c(0) == f0 + - c(1) == f1 + - derivative of c(x) at x==0 is d0 + - returns the point in the range [0,1] that minimizes the polynomial c(x) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline double poly_min_extrap ( + double f0, + double d0, + double x1, + double f_x1, + double x2, + double f_x2 + ) + /*! + requires + - 0 < x1 < x2 + ensures + - let f(x) be a 3rd degree polynomial such that: + - f(0) == f0 + - derivative of f(x) at x==0 is d0 + - f(x1) == f_x1 + - f(x2) == f_x2 + - returns the point in the range [0,x2] that minimizes the polynomial f(x) + !*/ + +// ---------------------------------------------------------------------------------------- + + inline double lagrange_poly_min_extrap ( + double p1, + double p2, + double p3, + double f1, + double f2, + double f3 + ); + /*! + requires + - f1 >= f2 <= f3 + - p1 < p2 < p3 + ensures + - let c(x) be the second order Lagrange polynomial that interpolates the + points p1, p2, and p3 where c(p1)==f1, c(p2)==f2, and c(p3)==f3 + - this function returns the point in the range [p1,p3] that minimizes + the polynomial c(x) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct, + typename funct_der + > + double line_search ( + const funct& f, + const double f0, + const funct_der& der, + const double d0, + double rho, + double sigma, + double min_f, + unsigned long max_iter + ) + /*! + requires + - 0 < rho < sigma < 1 + - f and der are scalar functions of scalars + (e.g. line_search_funct objects) + - der is the derivative of f + - f0 == f(0) + - d0 == der(0) + - max_iter > 0 + ensures + - Performs a line search and uses the strong Wolfe conditions to decide when + the search can stop. + - rho == the parameter of the Wolfe sufficient decrease condition + - sigma == the parameter of the Wolfe curvature condition + - max_iter == the maximum number of iterations allowable. After this + many evaluations of f() line_search() is guaranteed to terminate. + - returns a value alpha such that f(alpha) is significantly closer to + the minimum of f than f(0). + - It is assumed that the minimum possible value of f(x) is min_f. So if + an alpha is found such that f(alpha) <= min_f then the search stops + immediately. + - This function is also optimized for the case where der(0) is negative. I.e. + positive values of the argument to f() should be in a descent direction. + - When this function makes calls to f() and der() it always does so by + first calling f() and then calling der(). That is, these two functions + are always called in pairs with f() being called first and then der() + being called second. + !*/ + + /* + A good discussion of the Wolfe conditions and line search algorithms in + general can be found in the book Practical Methods of Optimization by R. Fletcher + and also in the more recent book Numerical Optimization by Nocedal and Wright. + */ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + double backtracking_line_search ( + const funct& f, + double f0, + double d0, + double alpha, + double rho, + unsigned long max_iter + ); + /*! + requires + - 0 < rho < 1 + - f is a scalar function of scalars + (e.g. a line_search_funct object) + - f0 == f(0) + - d0 == the derivative of f() at f(0). + - max_iter > 0 + ensures + - Performs a backtracking line search and uses the Armijo sufficient decrease + rule to decide when the search can stop. + - rho == the parameter of the sufficient decrease condition. + - max_iter == the maximum number of iterations allowable. After this many + evaluations of f() backtracking_line_search() is guaranteed to terminate. + - The line search starts with the input alpha value and then backtracks until + it finds a good enough alpha value. Once found, it returns the alpha value + such that f(alpha) is significantly closer to the minimum of f than f(0). + - The returned value of alpha will always be the last value of alpha which was + passed to f(). That is, it will always be the case that the last call to f() + made by backtracking_line_search() was f(alpha) where alpha is the return + value from this function. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + const negate_function_object negate_function( + const funct& f + ); + /*! + requires + - f == a function that returns a scalar + ensures + - returns a function that represents the negation of f. That is, + the returned function object represents g(x) == -f(x) + !*/ + +// ---------------------------------------------------------------------------------------- + + class optimize_single_variable_failure : public error; + /*! + This is the exception class used by the functions defined below. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + double find_min_single_variable ( + const funct& f, + double& starting_point, + const double begin = -1e200, + const double end = 1e200, + const double eps = 1e-3, + const long max_iter = 100, + const double initial_search_radius = 1 + ) + /*! + requires + - eps > 0 + - max_iter > 1 + - begin <= starting_point <= end + - f must be a function of a double that returns a double + (e.g. f(starting_point) should be a valid expression that evaluates to a double) + - initial_search_radius > 0 + ensures + - Finds a point P such that: + - P is a local minimum of the function f(). + - begin <= P <= end + - Evaluates f() no more than max_iter times + - Stops searching when the window around the minimum point is smaller than eps. + The search will begin with the given starting_point and expand out to the + left and right by initial_search_radius sized steps. So if you think the + minimum is likely to be found within X distance from the starting_point then + set initial_search_radius to X. + - #starting_point == P + - returns f(P) + throws + - optimize_single_variable_failure + This exception is thrown if max_iter iterations are performed without + determining the min point to the requested accuracy of eps. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct + > + double find_max_single_variable ( + const funct& f, + double& starting_point, + const double begin = -1e200, + const double end = 1e200, + const double eps = 1e-3, + const long max_iter = 100, + const double initial_search_radius = 1 + ) + /*! + requires + - eps > 0 + - max_iter > 1 + - begin <= starting_point <= end + - f must be a function of a double that returns a double + (e.g. f(starting_point) should be a valid expression that evaluates to a double) + - initial_search_radius > 0 + ensures + - Finds a point P such that: + - P is a local maximum of the function f(). + - begin <= P <= end + - Evaluates f() no more than max_iter times + - Stops searching when the window around the maximum point is smaller than eps. + The search will begin with the given starting_point and expand out to the + left and right by initial_search_radius sized steps. So if you think the + maximum is likely to be found within X distance from the starting_point then + set initial_search_radius to X. + - #starting_point == P + - returns f(P) + throws + - optimize_single_variable_failure + This exception is thrown if max_iter iterations are performed without + determining the max point to the requested accuracy of eps. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_ABSTRACT_ + diff --git a/ml/dlib/dlib/optimization/optimization_oca.h b/ml/dlib/dlib/optimization/optimization_oca.h new file mode 100644 index 000000000..4ca9cd7ae --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_oca.h @@ -0,0 +1,407 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATIoN_OCA_Hh_ +#define DLIB_OPTIMIZATIoN_OCA_Hh_ + +#include "optimization_oca_abstract.h" + +#include "../matrix.h" +#include "optimization_solve_qp_using_smo.h" +#include +#include "../sequence.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + template + class oca_problem + { + public: + typedef typename matrix_type::type scalar_type; + + virtual ~oca_problem() {} + + virtual bool risk_has_lower_bound ( + scalar_type& + ) const { return false; } + + virtual bool optimization_status ( + scalar_type , + scalar_type , + scalar_type , + scalar_type , + unsigned long, + unsigned long + ) const = 0; + + virtual scalar_type get_c ( + ) const = 0; + + virtual long get_num_dimensions ( + ) const = 0; + + virtual void get_risk ( + matrix_type& current_solution, + scalar_type& risk_value, + matrix_type& risk_subgradient + ) const = 0; + + }; + +// ---------------------------------------------------------------------------------------- + + class oca + { + public: + + oca () + { + sub_eps = 1e-2; + sub_max_iter = 50000; + + inactive_thresh = 20; + } + + void set_subproblem_epsilon ( + double eps_ + ) { sub_eps = eps_; } + + double get_subproblem_epsilon ( + ) const { return sub_eps; } + + void set_subproblem_max_iterations ( + unsigned long sub_max_iter_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(sub_max_iter_ > 0, + "\t void oca::set_subproblem_max_iterations" + << "\n\t max iterations must be greater than 0" + << "\n\t sub_max_iter_: " << sub_max_iter_ + << "\n\t this: " << this + ); + + sub_max_iter = sub_max_iter_; + } + + unsigned long get_subproblem_max_iterations ( + ) const { return sub_max_iter; } + + void set_inactive_plane_threshold ( + unsigned long inactive_thresh_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(inactive_thresh_ > 0, + "\t void oca::set_inactive_plane_threshold" + << "\n\t inactive threshold must be greater than 0" + << "\n\t inactive_thresh_: " << inactive_thresh_ + << "\n\t this: " << this + ); + + inactive_thresh = inactive_thresh_; + } + + unsigned long get_inactive_plane_threshold ( + ) const { return inactive_thresh; } + + template < + typename matrix_type + > + typename matrix_type::type operator() ( + const oca_problem& problem, + matrix_type& w, + unsigned long num_nonnegative = 0, + unsigned long force_weight_to_1 = std::numeric_limits::max() + ) const + { + matrix_type empty_prior; + return oca_impl(problem, w, empty_prior, false, num_nonnegative, force_weight_to_1, 0); + } + + template < + typename matrix_type + > + typename matrix_type::type solve_with_elastic_net ( + const oca_problem& problem, + matrix_type& w, + double lasso_lambda, + unsigned long force_weight_to_1 = std::numeric_limits::max() + ) const + { + matrix_type empty_prior; + return oca_impl(problem, w, empty_prior, false, 0, force_weight_to_1, lasso_lambda); + } + + template < + typename matrix_type + > + typename matrix_type::type operator() ( + const oca_problem& problem, + matrix_type& w, + const matrix_type& prior + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(prior) && prior.size() == problem.get_num_dimensions(), + "\t scalar_type oca::operator()" + << "\n\t The prior vector does not have the correct dimensions." + << "\n\t is_col_vector(prior): " << is_col_vector(prior) + << "\n\t prior.size(): " << prior.size() + << "\n\t problem.get_num_dimensions(): " << problem.get_num_dimensions() + << "\n\t this: " << this + ); + // disable the force weight to 1 option for this mode. We also disable the + // non-negative constraints. + unsigned long force_weight_to_1 = std::numeric_limits::max(); + return oca_impl(problem, w, prior, true, 0, force_weight_to_1, 0); + } + + private: + + template < + typename matrix_type + > + typename matrix_type::type oca_impl ( + const oca_problem& problem, + matrix_type& w, + const matrix_type& prior, + bool have_prior, + unsigned long num_nonnegative, + unsigned long force_weight_to_1, + const double lasso_lambda + ) const + { + const unsigned long num_dims = problem.get_num_dimensions(); + + // make sure requires clause is not broken + DLIB_ASSERT(problem.get_c() > 0 && + problem.get_num_dimensions() > 0 && + 0 <= lasso_lambda && lasso_lambda < 1, + "\t scalar_type oca::operator()" + << "\n\t The oca_problem is invalid" + << "\n\t problem.get_c(): " << problem.get_c() + << "\n\t problem.get_num_dimensions(): " << num_dims + << "\n\t lasso_lambda: " << lasso_lambda + << "\n\t this: " << this + ); + if (have_prior) + { + DLIB_ASSERT(lasso_lambda == 0, "Solver doesn't support using a prior with lasso."); + DLIB_ASSERT(num_nonnegative == 0, "Solver doesn't support using a prior with non-negative constraints."); + } + else if (lasso_lambda != 0) + { + DLIB_ASSERT(num_nonnegative == 0, "Solver doesn't support using lasso with non-negative constraints."); + } + + const double ridge_lambda = 1-lasso_lambda; + + if (num_nonnegative > num_dims) + num_nonnegative = num_dims; + + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef matrix_type vect_type; + + const scalar_type C = problem.get_c(); + + typename sequence::kernel_2a planes; + std::vector bs, miss_count; + + vect_type new_plane, alpha, btemp; + + w.set_size(num_dims, 1); + w = 0; + + // The current objective value. Note also that w always contains + // the current solution. + scalar_type cur_obj = std::numeric_limits::max(); + + // This will hold the cutting plane objective value. This value is + // a lower bound on the true optimal objective value. + scalar_type cp_obj = 0; + + matrix K, Ktmp; + matrix lambda, d; + if (lasso_lambda != 0) + d.set_size(num_dims); + else + d.set_size(num_nonnegative); + d = lasso_lambda*ones_matrix(d); + + scalar_type R_lower_bound; + if (problem.risk_has_lower_bound(R_lower_bound)) + { + // The flat lower bounding plane is always good to have if we know + // what it is. + bs.push_back(R_lower_bound); + new_plane = zeros_matrix(w); + planes.add(0, new_plane); + alpha = uniform_matrix(1,1, C); + miss_count.push_back(0); + + K.set_size(1,1); + K(0,0) = 0; + } + + const double prior_norm = have_prior ? 0.5*dot(prior,prior) : 0; + + unsigned long counter = 0; + while (true) + { + + // add the next cutting plane + scalar_type cur_risk; + if (force_weight_to_1 < (unsigned long)w.size()) + w(force_weight_to_1) = 1; + + problem.get_risk(w, cur_risk, new_plane); + + if (force_weight_to_1 < (unsigned long)w.size()) + { + // We basically arrange for the w(force_weight_to_1) element and all + // subsequent elements of w to not be involved in the optimization at + // all. An easy way to do this is to just make sure the elements of w + // corresponding elements in the subgradient are always set to zero + // while we run the cutting plane algorithm. The only time + // w(force_weight_to_1) is 1 is when we pass it to the oca_problem. + set_rowm(w, range(force_weight_to_1, w.size()-1)) = 0; + set_rowm(new_plane, range(force_weight_to_1, new_plane.size()-1)) = 0; + } + + if (have_prior) + bs.push_back(cur_risk - dot(w,new_plane) + dot(prior,new_plane)); + else + bs.push_back(cur_risk - dot(w,new_plane)); + planes.add(planes.size(), new_plane); + miss_count.push_back(0); + + // If alpha is empty then initialize it (we must always have sum(alpha) == C). + // But otherwise, just append a zero. + if (alpha.size() == 0) + alpha = uniform_matrix(1,1, C); + else + alpha = join_cols(alpha,zeros_matrix(1,1)); + + const scalar_type wnorm = 0.5*ridge_lambda*trans(w)*w + lasso_lambda*sum(abs(w)); + const double prior_part = have_prior? dot(w,prior) : 0; + cur_obj = wnorm + C*cur_risk + prior_norm-prior_part; + + // report current status + const scalar_type risk_gap = cur_risk - (cp_obj-wnorm+prior_part-prior_norm)/C; + if (counter > 0 && problem.optimization_status(cur_obj, cur_obj - cp_obj, + cur_risk, risk_gap, planes.size(), counter)) + { + break; + } + + // compute kernel matrix for all the planes + K.swap(Ktmp); + K.set_size(planes.size(), planes.size()); + // copy over the old K matrix + set_subm(K, 0,0, Ktmp.nr(), Ktmp.nc()) = Ktmp; + + // now add the new row and column to K + for (unsigned long c = 0; c < planes.size(); ++c) + { + K(c, Ktmp.nc()) = dot(planes[c], planes[planes.size()-1]); + K(Ktmp.nc(), c) = K(c,Ktmp.nc()); + } + + + // solve the cutting plane subproblem for the next w. We solve it to an + // accuracy that is related to how big the error gap is. Also, we multiply + // by ridge_lambda because the objective function for the QP we solve was + // implicitly scaled by ridge_lambda. That is, we want to ask the QP + // solver to solve the problem until the duality gap is 0.1 times smaller + // than what it is now. So the factor of ridge_lambda is necessary to make + // this happen. + scalar_type eps = std::min(sub_eps, 0.1*ridge_lambda*(cur_obj-cp_obj)); + // just a sanity check + if (eps < 1e-16) + eps = 1e-16; + // Note that we warm start this optimization by using the alpha from the last + // iteration as the starting point. + if (lasso_lambda != 0) + { + // copy planes into a matrix so we can call solve_qp4_using_smo() + matrix planes_mat(num_dims,planes.size()); + for (unsigned long i = 0; i < planes.size(); ++i) + set_colm(planes_mat,i) = planes[i]; + + btemp = ridge_lambda*mat(bs) - trans(planes_mat)*d; + solve_qp4_using_smo(planes_mat, K, btemp, d, alpha, lambda, eps, sub_max_iter, (scalar_type)(2*lasso_lambda)); + } + else if (num_nonnegative != 0) + { + // copy planes into a matrix so we can call solve_qp4_using_smo() + matrix planes_mat(num_nonnegative,planes.size()); + for (unsigned long i = 0; i < planes.size(); ++i) + set_colm(planes_mat,i) = colm(planes[i],0,num_nonnegative); + + solve_qp4_using_smo(planes_mat, K, mat(bs), d, alpha, lambda, eps, sub_max_iter); + } + else + { + solve_qp_using_smo(K, mat(bs), alpha, eps, sub_max_iter); + } + + // construct the w that minimized the subproblem. + w = -alpha(0)*planes[0]; + for (unsigned long i = 1; i < planes.size(); ++i) + w -= alpha(i)*planes[i]; + if (lasso_lambda != 0) + w = (lambda-d+w)/ridge_lambda; + else if (num_nonnegative != 0) // threshold the first num_nonnegative w elements if necessary. + set_rowm(w,range(0,num_nonnegative-1)) = lowerbound(rowm(w,range(0,num_nonnegative-1)),0); + + for (long i = 0; i < alpha.size(); ++i) + { + if (alpha(i) != 0) + miss_count[i] = 0; + else + miss_count[i] += 1; + } + + // Compute the lower bound on the true objective given to us by the cutting + // plane subproblem. + cp_obj = -0.5*ridge_lambda*trans(w)*w + trans(alpha)*mat(bs); + if (have_prior) + w += prior; + + // If it has been a while since a cutting plane was an active constraint then + // we should throw it away. + while (max(mat(miss_count)) >= inactive_thresh) + { + const long idx = index_of_max(mat(miss_count)); + bs.erase(bs.begin()+idx); + miss_count.erase(miss_count.begin()+idx); + K = removerc(K, idx, idx); + alpha = remove_row(alpha,idx); + planes.remove(idx, new_plane); + } + + ++counter; + } + + if (force_weight_to_1 < (unsigned long)w.size()) + w(force_weight_to_1) = 1; + + return cur_obj; + } + + double sub_eps; + + unsigned long sub_max_iter; + + unsigned long inactive_thresh; + }; +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_OPTIMIZATIoN_OCA_Hh_ + diff --git a/ml/dlib/dlib/optimization/optimization_oca_abstract.h b/ml/dlib/dlib/optimization/optimization_oca_abstract.h new file mode 100644 index 000000000..859dbdcfc --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_oca_abstract.h @@ -0,0 +1,334 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATION_OCA_ABsTRACT_Hh_ +#ifdef DLIB_OPTIMIZATION_OCA_ABsTRACT_Hh_ + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + template + class oca_problem + { + /*! + REQUIREMENTS ON matrix_type + - matrix_type == a dlib::matrix capable of storing column vectors + + WHAT THIS OBJECT REPRESENTS + This object is the interface used to define the optimization + problems solved by the oca optimizer defined later in this file. + + OCA solves optimization problems with the following form: + Minimize: f(w) == 0.5*length_squared(w) + C*R(w) + + Where R(w) is a user-supplied convex function and C > 0. Optionally, + there can also be non-negativity constraints on some or all of the + elements of w. + + Or it can alternatively solve: + Minimize: f(w) == 0.5*length_squared(w-prior) + C*R(w) + + Where prior is a user supplied vector and R(w) has the same + interpretation as above. + + Or it can use the elastic net regularizer: + Minimize: f(w) == 0.5*(1-lasso_lambda)*length_squared(w) + lasso_lambda*sum(abs(w)) + C*R(w) + + Where lasso_lambda is a number in the range [0, 1) and controls + trade-off between doing L1 and L2 regularization. R(w) has the same + interpretation as above. + + + Note that the stopping condition must be provided by the user + in the form of the optimization_status() function. + !*/ + + public: + + typedef typename matrix_type::type scalar_type; + + virtual ~oca_problem() {} + + virtual bool risk_has_lower_bound ( + scalar_type& lower_bound + ) const { return false; } + /*! + ensures + - if (R(w) >= a constant for all values of w) then + - returns true + - #lower_bound == the constant that lower bounds R(w) + - else + - returns false + !*/ + + virtual bool optimization_status ( + scalar_type current_objective_value, + scalar_type current_error_gap, + scalar_type current_risk_value, + scalar_type current_risk_gap, + unsigned long num_cutting_planes, + unsigned long num_iterations + ) const = 0; + /*! + requires + - This function is called by the OCA optimizer each iteration. + - current_objective_value == the current value of the objective function f(w) + - current_error_gap == The bound on how much lower the objective function + can drop before we reach the optimal point. At the optimal solution the + error gap is equal to 0. + - current_risk_value == the current value of the R(w) term of the objective function. + - current_risk_gap == the bound on how much lower the risk term can go. At the optimal + solution the risk gap is zero. + - num_cutting_planes == the number of cutting planes the algorithm is currently using. + - num_iterations == A count of the total number of iterations that have executed + since we started running the optimization. + ensures + - If it is appropriate to terminate the optimization then this function returns true + and false otherwise. + !*/ + + virtual scalar_type get_c ( + ) const = 0; + /*! + ensures + - returns the C parameter + !*/ + + virtual long get_num_dimensions ( + ) const = 0; + /*! + ensures + - returns the number of free variables in this optimization problem + !*/ + + virtual void get_risk ( + matrix_type& current_solution, + scalar_type& risk_value, + matrix_type& risk_subgradient + ) const = 0; + /*! + requires + - is_col_vector(current_solution) == true + - current_solution.size() == get_num_dimensions() + ensures + - #current_solution will be set to one of the following: + - current_solution (i.e. it won't be modified at all) + - The result of a line search passing through current_solution. + - #risk_value == R(#current_solution) + - #risk_subgradient == an element of the subgradient of R() at the + point #current_solution + - Note that #risk_value and #risk_subgradient are NOT multiplied by get_c() + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + class oca + { + /*! + INITIAL VALUE + - get_subproblem_epsilon() == 1e-2 + - get_subproblem_max_iterations() == 50000 + - get_inactive_plane_threshold() == 20 + + WHAT THIS OBJECT REPRESENTS + This object is a tool for solving the optimization problem defined above + by the oca_problem abstract class. + + For reference, OCA solves optimization problems with the following form: + Minimize: f(w) == 0.5*length_squared(w) + C*R(w) + + Where R(w) is a user-supplied convex function and C > 0. Optionally, + this object can also add non-negativity constraints to some or all + of the elements of w. + + Or it can alternatively solve: + Minimize: f(w) == 0.5*length_squared(w-prior) + C*R(w) + + Where prior is a user supplied vector and R(w) has the same + interpretation as above. + + Or it can use the elastic net regularizer: + Minimize: f(w) == 0.5*(1-lasso_lambda)*length_squared(w) + lasso_lambda*sum(abs(w)) + C*R(w) + + Where lasso_lambda is a number in the range [0, 1) and controls + trade-off between doing L1 and L2 regularization. R(w) has the same + interpretation as above. + + + For a detailed discussion you should consult the following papers + from the Journal of Machine Learning Research: + Optimized Cutting Plane Algorithm for Large-Scale Risk Minimization + Vojtech Franc, Soren Sonnenburg; 10(Oct):2157--2192, 2009. + + Bundle Methods for Regularized Risk Minimization + Choon Hui Teo, S.V.N. Vishwanthan, Alex J. Smola, Quoc V. Le; 11(Jan):311-365, 2010. + !*/ + public: + + oca ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template < + typename matrix_type + > + typename matrix_type::type operator() ( + const oca_problem& problem, + matrix_type& w, + unsigned long num_nonnegative = 0, + unsigned long force_weight_to_1 = std::numeric_limits::max() + ) const; + /*! + requires + - problem.get_c() > 0 + - problem.get_num_dimensions() > 0 + ensures + - solves the given oca problem and stores the solution in #w. In particular, + this function solves: + Minimize: f(w) == 0.5*length_squared(w) + C*R(w) + - The optimization algorithm runs until problem.optimization_status() + indicates it is time to stop. + - returns the objective value at the solution #w + - if (num_nonnegative != 0) then + - Adds the constraint that #w(i) >= 0 for all i < num_nonnegative. + That is, the first num_nonnegative elements of #w will always be + non-negative. This includes the copies of w passed to get_risk() + in the form of the current_solution vector as well as the final + output of this function. + - if (force_weight_to_1 < problem.get_num_dimensions()) then + - The optimizer enforces the following constraints: + - #w(force_weight_to_1) == 1 + - for all i > force_weight_to_1: + - #w(i) == 0 + - That is, the element in the weight vector at the index indicated + by force_weight_to_1 will have a value of 1 upon completion of + this function, while all subsequent elements of w will have + values of 0. + !*/ + + template < + typename matrix_type + > + typename matrix_type::type operator() ( + const oca_problem& problem, + matrix_type& w, + const matrix_type& prior + ) const; + /*! + requires + - problem.get_c() > 0 + - problem.get_num_dimensions() > 0 + - is_col_vector(prior) == true + - prior.size() == problem.get_num_dimensions() + ensures + - solves the given oca problem and stores the solution in #w. + - In this mode, we solve a version of the problem with a different + regularizer. In particular, this function solves: + Minimize: f(w) == 0.5*length_squared(w-prior) + C*R(w) + - The optimization algorithm runs until problem.optimization_status() + indicates it is time to stop. + - returns the objective value at the solution #w + !*/ + + template < + typename matrix_type + > + typename matrix_type::type solve_with_elastic_net ( + const oca_problem& problem, + matrix_type& w, + scalar_type lasso_lambda, + unsigned long force_weight_to_1 = std::numeric_limits::max() + ) const; + /*! + requires + - problem.get_c() > 0 + - problem.get_num_dimensions() > 0 + - 0 <= lasso_lambda < 1 + ensures + - Solves the given oca problem and stores the solution in #w, but uses an + elastic net regularizer instead of the normal L2 regularizer. In + particular, this function solves: + Minimize: f(w) == 0.5*(1-lasso_lambda)*length_squared(w) + lasso_lambda*sum(abs(w)) + C*R(w) + - The optimization algorithm runs until problem.optimization_status() + indicates it is time to stop. + - returns the objective value at the solution #w + - if (force_weight_to_1 < problem.get_num_dimensions()) then + - The optimizer enforces the following constraints: + - #w(force_weight_to_1) == 1 + - for all i > force_weight_to_1: + - #w(i) == 0 + - That is, the element in the weight vector at the index indicated + by force_weight_to_1 will have a value of 1 upon completion of + this function, while all subsequent elements of w will have + values of 0. + !*/ + + void set_subproblem_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_subproblem_epsilon() == eps + !*/ + + double get_subproblem_epsilon ( + ) const; + /*! + ensures + - returns the accuracy used in solving the quadratic programming + subproblem that is part of the overall OCA algorithm. + !*/ + + void set_subproblem_max_iterations ( + unsigned long sub_max_iter + ); + /*! + requires + - sub_max_iter > 0 + ensures + - #get_subproblem_max_iterations() == sub_max_iter + !*/ + + unsigned long get_subproblem_max_iterations ( + ) const; + /*! + ensures + - returns the maximum number of iterations this object will perform + while attempting to solve each quadratic programming subproblem. + !*/ + + void set_inactive_plane_threshold ( + unsigned long inactive_thresh + ); + /*! + requires + - inactive_thresh > 0 + ensures + - #get_inactive_plane_threshold() == inactive_thresh + !*/ + + unsigned long get_inactive_plane_threshold ( + ) const; + /*! + ensures + - As OCA runs it builds up a set of cutting planes. Typically + cutting planes become inactive after a certain point and can then + be removed. This function returns the number of iterations of + inactivity required before a cutting plane is removed. + !*/ + + }; +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_OPTIMIZATION_OCA_ABsTRACT_Hh_ + + diff --git a/ml/dlib/dlib/optimization/optimization_search_strategies.h b/ml/dlib/dlib/optimization/optimization_search_strategies.h new file mode 100644 index 000000000..eda637fa1 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_search_strategies.h @@ -0,0 +1,324 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_H_ +#define DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_H_ + +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "optimization_search_strategies_abstract.h" +#include "../sequence.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class cg_search_strategy + { + public: + cg_search_strategy() : been_used(false) {} + + double get_wolfe_rho ( + ) const { return 0.001; } + + double get_wolfe_sigma ( + ) const { return 0.01; } + + unsigned long get_max_line_search_iterations ( + ) const { return 100; } + + template + const matrix& get_next_direction ( + const T& , + const double , + const T& funct_derivative + ) + { + if (been_used == false) + { + been_used = true; + prev_direction = -funct_derivative; + } + else + { + // Use the Polak-Ribiere (4.1.12) conjugate gradient described by Fletcher on page 83 + const double temp = trans(prev_derivative)*prev_derivative; + // If this value hits zero then just use the direction of steepest descent. + if (std::abs(temp) < std::numeric_limits::epsilon()) + { + prev_derivative = funct_derivative; + prev_direction = -funct_derivative; + return prev_direction; + } + + double b = trans(funct_derivative-prev_derivative)*funct_derivative/(temp); + prev_direction = -funct_derivative + b*prev_direction; + + } + + prev_derivative = funct_derivative; + return prev_direction; + } + + private: + bool been_used; + matrix prev_derivative; + matrix prev_direction; + }; + +// ---------------------------------------------------------------------------------------- + + class bfgs_search_strategy + { + public: + bfgs_search_strategy() : been_used(false), been_used_twice(false) {} + + double get_wolfe_rho ( + ) const { return 0.01; } + + double get_wolfe_sigma ( + ) const { return 0.9; } + + unsigned long get_max_line_search_iterations ( + ) const { return 100; } + + template + const matrix& get_next_direction ( + const T& x, + const double , + const T& funct_derivative + ) + { + if (been_used == false) + { + been_used = true; + H = identity_matrix(x.size()); + } + else + { + // update H with the BFGS formula from (3.2.12) on page 55 of Fletcher + delta = (x-prev_x); + gamma = funct_derivative-prev_derivative; + + double dg = dot(delta,gamma); + + // Try to set the initial value of the H matrix to something reasonable if we are still + // in the early stages of figuring out what it is. This formula below is what is suggested + // in the book Numerical Optimization by Nocedal and Wright in the chapter on Quasi-Newton methods. + if (been_used_twice == false) + { + double gg = trans(gamma)*gamma; + if (std::abs(gg) > std::numeric_limits::epsilon()) + { + const double temp = put_in_range(0.01, 100, dg/gg); + H = diagm(uniform_matrix(x.size(),1, temp)); + been_used_twice = true; + } + } + + Hg = H*gamma; + gH = trans(trans(gamma)*H); + double gHg = trans(gamma)*H*gamma; + if (gHg < std::numeric_limits::infinity() && dg < std::numeric_limits::infinity() && + dg != 0) + { + H += (1 + gHg/dg)*delta*trans(delta)/(dg) - (delta*trans(gH) + Hg*trans(delta))/(dg); + } + else + { + H = identity_matrix(H.nr()); + been_used_twice = false; + } + } + + prev_x = x; + prev_direction = -H*funct_derivative; + prev_derivative = funct_derivative; + return prev_direction; + } + + private: + bool been_used; + bool been_used_twice; + matrix prev_x; + matrix prev_derivative; + matrix prev_direction; + matrix H; + matrix delta, gamma, Hg, gH; + }; + +// ---------------------------------------------------------------------------------------- + + class lbfgs_search_strategy + { + public: + explicit lbfgs_search_strategy(unsigned long max_size_) : max_size(max_size_), been_used(false) + { + DLIB_ASSERT ( + max_size > 0, + "\t lbfgs_search_strategy(max_size)" + << "\n\t max_size can't be zero" + ); + } + + lbfgs_search_strategy(const lbfgs_search_strategy& item) + { + max_size = item.max_size; + been_used = item.been_used; + prev_x = item.prev_x; + prev_derivative = item.prev_derivative; + prev_direction = item.prev_direction; + alpha = item.alpha; + dh_temp = item.dh_temp; + } + + double get_wolfe_rho ( + ) const { return 0.01; } + + double get_wolfe_sigma ( + ) const { return 0.9; } + + unsigned long get_max_line_search_iterations ( + ) const { return 100; } + + template + const matrix& get_next_direction ( + const T& x, + const double , + const T& funct_derivative + ) + { + prev_direction = -funct_derivative; + + if (been_used == false) + { + been_used = true; + } + else + { + // add an element into the stored data sequence + dh_temp.s = x - prev_x; + dh_temp.y = funct_derivative - prev_derivative; + double temp = dot(dh_temp.s, dh_temp.y); + // only accept this bit of data if temp isn't zero + if (std::abs(temp) > std::numeric_limits::epsilon()) + { + dh_temp.rho = 1/temp; + data.add(data.size(), dh_temp); + } + else + { + data.clear(); + } + + if (data.size() > 0) + { + // This block of code is from algorithm 7.4 in the Nocedal book. + + alpha.resize(data.size()); + for (unsigned long i = data.size()-1; i < data.size(); --i) + { + alpha[i] = data[i].rho*dot(data[i].s, prev_direction); + prev_direction -= alpha[i]*data[i].y; + } + + // Take a guess at what the first H matrix should be. This formula below is what is suggested + // in the book Numerical Optimization by Nocedal and Wright in the chapter on Large Scale + // Unconstrained Optimization (in the L-BFGS section). + double H_0 = 1.0/data[data.size()-1].rho/dot(data[data.size()-1].y, data[data.size()-1].y); + H_0 = put_in_range(0.001, 1000.0, H_0); + prev_direction *= H_0; + + for (unsigned long i = 0; i < data.size(); ++i) + { + double beta = data[i].rho*dot(data[i].y, prev_direction); + prev_direction += data[i].s * (alpha[i] - beta); + } + } + + } + + if (data.size() > max_size) + { + // remove the oldest element in the data sequence + data.remove(0, dh_temp); + } + + prev_x = x; + prev_derivative = funct_derivative; + return prev_direction; + } + + private: + + struct data_helper + { + matrix s; + matrix y; + double rho; + + friend void swap(data_helper& a, data_helper& b) + { + a.s.swap(b.s); + a.y.swap(b.y); + std::swap(a.rho, b.rho); + } + }; + sequence::kernel_2a data; + + unsigned long max_size; + bool been_used; + matrix prev_x; + matrix prev_derivative; + matrix prev_direction; + std::vector alpha; + + data_helper dh_temp; + }; + +// ---------------------------------------------------------------------------------------- + + template + class newton_search_strategy_obj + { + public: + explicit newton_search_strategy_obj( + const hessian_funct& hess + ) : hessian(hess) {} + + double get_wolfe_rho ( + ) const { return 0.01; } + + double get_wolfe_sigma ( + ) const { return 0.9; } + + unsigned long get_max_line_search_iterations ( + ) const { return 100; } + + template + const matrix get_next_direction ( + const T& x, + const double , + const T& funct_derivative + ) + { + return -inv(hessian(x))*funct_derivative; + } + + private: + hessian_funct hessian; + }; + + template + newton_search_strategy_obj newton_search_strategy ( + hessian_funct hessian + ) { return newton_search_strategy_obj(hessian); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_H_ + diff --git a/ml/dlib/dlib/optimization/optimization_search_strategies_abstract.h b/ml/dlib/dlib/optimization/optimization_search_strategies_abstract.h new file mode 100644 index 000000000..44b894bfc --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_search_strategies_abstract.h @@ -0,0 +1,330 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_ABSTRACT_ +#ifdef DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_ABSTRACT_ + +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" + + +namespace dlib +{ + /* + A good discussion of the search strategies in this file can be found in the + following book: Numerical Optimization by Nocedal and Wright. + */ + +// ---------------------------------------------------------------------------------------- + + class cg_search_strategy + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a strategy for determining which direction + a line search should be carried out along. This particular object + is an implementation of the Polak-Ribiere conjugate gradient method + for determining this direction. + + This method uses an amount of memory that is linear in the number + of variables to be optimized. So it is capable of handling problems + with a very large number of variables. However, it is generally + not as good as the L-BFGS algorithm (which is defined below in + the lbfgs_search_strategy class). + !*/ + + public: + cg_search_strategy( + ); + /*! + ensures + - This object is properly initialized and ready to generate + search directions. + !*/ + + double get_wolfe_rho ( + ) const; + /*! + ensures + - returns the value of the Wolfe rho parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + double get_wolfe_sigma ( + ) const; + /*! + ensures + - returns the value of the Wolfe sigma parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + unsigned long get_max_line_search_iterations ( + ) const; + /*! + ensures + - returns the value of the max iterations parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + template + const matrix& get_next_direction ( + const T& x, + const double funct_value, + const T& funct_derivative + ); + /*! + requires + - this function is only called once per search iteration + - for some objective function f(): + - x == the search point for the current iteration + - funct_value == f(x) + - funct_derivative == derivative(f)(x) + ensures + - Assuming that a line search is going to be conducted starting from the point x, + this function returns the direction in which the search should proceed. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + class bfgs_search_strategy + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a strategy for determining which direction + a line search should be carried out along. This particular object + is an implementation of the BFGS quasi-newton method for determining + this direction. + + This method uses an amount of memory that is quadratic in the number + of variables to be optimized. It is generally very effective but + if your problem has a very large number of variables then it isn't + appropriate. Instead You should try the lbfgs_search_strategy. + !*/ + + public: + bfgs_search_strategy( + ); + /*! + ensures + - This object is properly initialized and ready to generate + search directions. + !*/ + + double get_wolfe_rho ( + ) const; + /*! + ensures + - returns the value of the Wolfe rho parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + double get_wolfe_sigma ( + ) const; + /*! + ensures + - returns the value of the Wolfe sigma parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + unsigned long get_max_line_search_iterations ( + ) const; + /*! + ensures + - returns the value of the max iterations parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + template + const matrix& get_next_direction ( + const T& x, + const double funct_value, + const T& funct_derivative + ); + /*! + requires + - this function is only called once per search iteration + - for some objective function f(): + - x == the search point for the current iteration + - funct_value == f(x) + - funct_derivative == derivative(f)(x) + ensures + - Assuming that a line search is going to be conducted starting from the point x, + this function returns the direction in which the search should proceed. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class lbfgs_search_strategy + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a strategy for determining which direction + a line search should be carried out along. This particular object + is an implementation of the L-BFGS quasi-newton method for determining + this direction. + + This method uses an amount of memory that is linear in the number + of variables to be optimized. This makes it an excellent method + to use when an optimization problem has a large number of variables. + !*/ + public: + explicit lbfgs_search_strategy( + unsigned long max_size + ); + /*! + requires + - max_size > 0 + ensures + - This object is properly initialized and ready to generate + search directions. + - L-BFGS works by remembering a certain number of position and gradient + pairs. It uses this remembered information to compute search directions. + The max_size argument determines how many of these pairs will be remembered. + Typically, using between 3 and 30 pairs performs well for many problems. + !*/ + + double get_wolfe_rho ( + ) const; + /*! + ensures + - returns the value of the Wolfe rho parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + double get_wolfe_sigma ( + ) const; + /*! + ensures + - returns the value of the Wolfe sigma parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + unsigned long get_max_line_search_iterations ( + ) const; + /*! + ensures + - returns the value of the max iterations parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + template + const matrix& get_next_direction ( + const T& x, + const double funct_value, + const T& funct_derivative + ); + /*! + requires + - this function is only called once per search iteration + - for some objective function f(): + - x == the search point for the current iteration + - funct_value == f(x) + - funct_derivative == derivative(f)(x) + ensures + - Assuming that a line search is going to be conducted starting from the point x, + this function returns the direction in which the search should proceed. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename hessian_funct + > + class newton_search_strategy_obj + { + /*! + REQUIREMENTS ON hessian_funct + Objects of hessian_funct type must be function objects which + take a single argument and return a dlib::matrix of doubles. The + single argument must be a dlib::matrix capable of representing + column vectors of doubles. hessian_funct must also be copy + constructable. + + WHAT THIS OBJECT REPRESENTS + This object represents a strategy for determining which direction + a line search should be carried out along. This particular object + is an implementation of the newton method for determining this + direction. That is, it uses the following formula to determine + the direction: + search_direction = -inv(hessian(x))*derivative + !*/ + public: + explicit newton_search_strategy_obj( + const hessian_funct& hess + ); + /*! + ensures + - This object is properly initialized and ready to generate + search directions. + - hess will be used by this object to generate the needed hessian + matrices every time get_next_direction() is called. + !*/ + + double get_wolfe_rho ( + ) const; + /*! + ensures + - returns the value of the Wolfe rho parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + double get_wolfe_sigma ( + ) const; + /*! + ensures + - returns the value of the Wolfe sigma parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + unsigned long get_max_line_search_iterations ( + ) const; + /*! + ensures + - returns the value of the max iterations parameter that should be used when + this search strategy is used with the line_search() function. + !*/ + + template + const matrix get_next_direction ( + const T& x, + const double funct_value, + const T& funct_derivative + ); + /*! + requires + - for some objective function f(): + - x == the search point for the current iteration + - funct_value == f(x) + - funct_derivative == derivative(f)(x) + ensures + - Assuming that a line search is going to be conducted starting from the + point x, this function returns the direction in which the search should + proceed. + - In particular, the search direction will be given by: + - search_direction = -inv(hessian(x))*funct_derivative + !*/ + + }; + + template + newton_search_strategy_obj newton_search_strategy ( + hessian_funct hessian + ) { return newton_search_strategy_obj(hessian); } + /*! + ensures + - constructs and returns a newton_search_strategy_obj. + This function is just a helper to make the syntax for creating + these objects a little simpler. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_SEARCH_STRATEGIES_ABSTRACT_ + diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo.h b/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo.h new file mode 100644 index 000000000..88cad0cf3 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo.h @@ -0,0 +1,468 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOLVE_QP2_USING_SMo_Hh_ +#define DLIB_SOLVE_QP2_USING_SMo_Hh_ + +#include "optimization_solve_qp2_using_smo_abstract.h" +#include +#include +#include +#include "../matrix.h" +#include "../algs.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class invalid_nu_error : public dlib::error + { + public: + invalid_nu_error(const std::string& msg, double nu_) : dlib::error(msg), nu(nu_) {}; + const double nu; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename T::type maximum_nu_impl ( + const T& y + ) + { + typedef typename T::type scalar_type; + // make sure requires clause is not broken + DLIB_ASSERT(y.size() > 1 && is_col_vector(y), + "\ttypedef T::type maximum_nu(y)" + << "\n\ty should be a column vector with more than one entry" + << "\n\ty.nr(): " << y.nr() + << "\n\ty.nc(): " << y.nc() + ); + + long pos_count = 0; + long neg_count = 0; + for (long r = 0; r < y.nr(); ++r) + { + if (y(r) == 1.0) + { + ++pos_count; + } + else if (y(r) == -1.0) + { + ++neg_count; + } + else + { + // make sure requires clause is not broken + DLIB_ASSERT(y(r) == -1.0 || y(r) == 1.0, + "\ttypedef T::type maximum_nu(y)" + << "\n\ty should contain only 1 and 0 entries" + << "\n\tr: " << r + << "\n\ty(r): " << y(r) + ); + } + } + return static_cast(2.0*(scalar_type)std::min(pos_count,neg_count)/(scalar_type)y.nr()); + } + + template < + typename T + > + typename T::type maximum_nu ( + const T& y + ) + { + return maximum_nu_impl(mat(y)); + } + + template < + typename T + > + typename T::value_type maximum_nu ( + const T& y + ) + { + return maximum_nu_impl(mat(y)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class solve_qp2_using_smo + { + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + + template < + typename EXP1, + typename EXP2, + long NR + > + unsigned long operator() ( + const matrix_exp& Q, + const matrix_exp& y, + const scalar_type nu, + matrix& alpha, + scalar_type eps + ) + { + DLIB_ASSERT(Q.nr() == Q.nc() && y.size() == Q.nr() && y.size() > 1 && is_col_vector(y) && + sum((y == +1) + (y == -1)) == y.size() && + 0 < nu && nu <= 1 && + eps > 0, + "\t void solve_qp2_using_smo::operator()" + << "\n\t invalid arguments were given to this function" + << "\n\t Q.nr(): " << Q.nr() + << "\n\t Q.nc(): " << Q.nc() + << "\n\t is_col_vector(y): " << is_col_vector(y) + << "\n\t y.size(): " << y.size() + << "\n\t sum((y == +1) + (y == -1)): " << sum((y == +1) + (y == -1)) + << "\n\t nu: " << nu + << "\n\t eps: " << eps + ); + + alpha.set_size(Q.nr(),1); + df.set_size(Q.nr()); + + // now initialize alpha + set_initial_alpha(y, nu, alpha); + + const scalar_type tau = 1e-12; + + typedef typename colm_exp::type col_type; + + set_all_elements(df, 0); + // initialize df. Compute df = Q*alpha + for (long r = 0; r < df.nr(); ++r) + { + if (alpha(r) != 0) + { + df += alpha(r)*matrix_cast(colm(Q,r)); + } + } + + unsigned long count = 0; + + // now perform the actual optimization of alpha + long i=0, j=0; + while (find_working_group(y,alpha,Q,df,tau,eps,i,j)) + { + ++count; + const scalar_type old_alpha_i = alpha(i); + const scalar_type old_alpha_j = alpha(j); + + optimize_working_pair(alpha,Q,df,tau,i,j); + + // update the df vector now that we have modified alpha(i) and alpha(j) + scalar_type delta_alpha_i = alpha(i) - old_alpha_i; + scalar_type delta_alpha_j = alpha(j) - old_alpha_j; + + col_type Q_i = colm(Q,i); + col_type Q_j = colm(Q,j); + for(long k = 0; k < df.nr(); ++k) + df(k) += Q_i(k)*delta_alpha_i + Q_j(k)*delta_alpha_j; + } + + return count; + } + + const column_matrix& get_gradient ( + ) const { return df; } + + private: + + // ------------------------------------------------------------------------------------- + + template < + typename scalar_type, + typename scalar_vector_type, + typename scalar_vector_type2 + > + inline void set_initial_alpha ( + const scalar_vector_type& y, + const scalar_type nu, + scalar_vector_type2& alpha + ) const + { + set_all_elements(alpha,0); + const scalar_type l = y.nr(); + scalar_type temp = nu*l/2; + long num = (long)std::floor(temp); + long num_total = (long)std::ceil(temp); + + bool has_slack = false; + int count = 0; + for (int i = 0; i < alpha.nr(); ++i) + { + if (y(i) == 1) + { + if (count < num) + { + ++count; + alpha(i) = 1; + } + else + { + has_slack = true; + if (num_total > num) + { + ++count; + alpha(i) = temp - std::floor(temp); + } + break; + } + } + } + + if (count != num_total || has_slack == false) + { + std::ostringstream sout; + sout << "Invalid nu of " << nu << ". It is required that: 0 < nu < " << 2*(scalar_type)count/y.nr(); + throw invalid_nu_error(sout.str(),nu); + } + + has_slack = false; + count = 0; + for (int i = 0; i < alpha.nr(); ++i) + { + if (y(i) == -1) + { + if (count < num) + { + ++count; + alpha(i) = 1; + } + else + { + has_slack = true; + if (num_total > num) + { + ++count; + alpha(i) = temp - std::floor(temp); + } + break; + } + } + } + + if (count != num_total || has_slack == false) + { + std::ostringstream sout; + sout << "Invalid nu of " << nu << ". It is required that: 0 < nu < " << 2*(scalar_type)count/y.nr(); + throw invalid_nu_error(sout.str(),nu); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename scalar_vector_type, + typename scalar_type, + typename EXP, + typename U, typename V + > + inline bool find_working_group ( + const V& y, + const U& alpha, + const matrix_exp& Q, + const scalar_vector_type& df, + const scalar_type tau, + const scalar_type eps, + long& i_out, + long& j_out + ) const + { + using namespace std; + long ip = 0; + long jp = 0; + long in = 0; + long jn = 0; + + + typedef typename colm_exp::type col_type; + typedef typename diag_exp::type diag_type; + + scalar_type ip_val = -numeric_limits::infinity(); + scalar_type jp_val = numeric_limits::infinity(); + scalar_type in_val = -numeric_limits::infinity(); + scalar_type jn_val = numeric_limits::infinity(); + + // loop over the alphas and find the maximum ip and in indices. + for (long i = 0; i < alpha.nr(); ++i) + { + if (y(i) == 1) + { + if (alpha(i) < 1.0) + { + if (-df(i) > ip_val) + { + ip_val = -df(i); + ip = i; + } + } + } + else + { + if (alpha(i) > 0.0) + { + if (df(i) > in_val) + { + in_val = df(i); + in = i; + } + } + } + } + + scalar_type Mp = numeric_limits::infinity(); + scalar_type Mn = numeric_limits::infinity(); + + // Pick out the columns and diagonal of Q we need below. Doing + // it this way is faster if Q is actually a symmetric_matrix_cache + // object. + col_type Q_ip = colm(Q,ip); + col_type Q_in = colm(Q,in); + diag_type Q_diag = diag(Q); + + + + // now we need to find the minimum jp and jn indices + for (long j = 0; j < alpha.nr(); ++j) + { + if (y(j) == 1) + { + if (alpha(j) > 0.0) + { + scalar_type b = ip_val + df(j); + if (-df(j) < Mp) + Mp = -df(j); + + if (b > 0) + { + scalar_type a = Q_ip(ip) + Q_diag(j) - 2*Q_ip(j); + if (a <= 0) + a = tau; + scalar_type temp = -b*b/a; + if (temp < jp_val) + { + jp_val = temp; + jp = j; + } + } + } + } + else + { + if (alpha(j) < 1.0) + { + scalar_type b = in_val - df(j); + if (df(j) < Mn) + Mn = df(j); + + if (b > 0) + { + scalar_type a = Q_in(in) + Q_diag(j) - 2*Q_in(j); + if (a <= 0) + a = tau; + scalar_type temp = -b*b/a; + if (temp < jn_val) + { + jn_val = temp; + jn = j; + } + } + } + } + } + + // if we are at the optimal point then return false so the caller knows + // to stop optimizing + if (std::max(ip_val - Mp, in_val - Mn) < eps) + return false; + + if (jp_val < jn_val) + { + i_out = ip; + j_out = jp; + } + else + { + i_out = in; + j_out = jn; + } + + return true; + } + + // ------------------------------------------------------------------------------------ + + template < + typename EXP, + typename T, typename U + > + inline void optimize_working_pair ( + T& alpha, + const matrix_exp& Q, + const U& df, + const scalar_type tau, + const long i, + const long j + ) const + { + scalar_type quad_coef = Q(i,i)+Q(j,j)-2*Q(j,i); + if (quad_coef <= 0) + quad_coef = tau; + scalar_type delta = (df(i)-df(j))/quad_coef; + scalar_type sum = alpha(i) + alpha(j); + alpha(i) -= delta; + alpha(j) += delta; + + if(sum > 1) + { + if(alpha(i) > 1) + { + alpha(i) = 1; + alpha(j) = sum - 1; + } + else if(alpha(j) > 1) + { + alpha(j) = 1; + alpha(i) = sum - 1; + } + } + else + { + if(alpha(j) < 0) + { + alpha(j) = 0; + alpha(i) = sum; + } + else if(alpha(i) < 0) + { + alpha(i) = 0; + alpha(j) = sum; + } + } + } + + // ------------------------------------------------------------------------------------ + + column_matrix df; // gradient of f(alpha) + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SOLVE_QP2_USING_SMo_Hh_ + + diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo_abstract.h b/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo_abstract.h new file mode 100644 index 000000000..962541c25 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_solve_qp2_using_smo_abstract.h @@ -0,0 +1,150 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATION_SOLVE_QP2_USING_SMO_ABSTRACT_H_ +#ifdef DLIB_OPTIMIZATION_SOLVE_QP2_USING_SMO_ABSTRACT_H_ + +#include "../matrix/matrix_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class invalid_nu_error : public dlib::error + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is an exception class used to indicate that a + value of nu given to the solve_qp2_using_smo object is incompatible + with the constraints of the quadratic program. + + this->nu will be set to the invalid value of nu used. + !*/ + + public: + invalid_nu_error(const std::string& msg, double nu_) : dlib::error(msg), nu(nu_) {}; + const double nu; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename T::type maximum_nu ( + const T& y + ); + /*! + requires + - T == a matrix object or an object convertible to a matrix via mat() + - is_col_vector(y) == true + - y.size() > 1 + - sum((y == +1) + (y == -1)) == y.size() + (i.e. all elements of y must be equal to +1 or -1) + ensures + - returns the maximum valid nu that can be used with solve_qp2_using_smo and + the given y vector. + (i.e. 2.0*min(sum(y == +1), sum(y == -1))/y.size()) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class solve_qp2_using_smo + { + /*! + REQUIREMENTS ON matrix_type + Must be some type of dlib::matrix. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for solving the following quadratic programming + problem using the sequential minimal optimization algorithm: + + Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + subject to the following constraints: + - sum(alpha) == nu*y.size() + - 0 <= min(alpha) && max(alpha) <= 1 + - trans(y)*alpha == 0 + + Where all elements of y must be equal to +1 or -1 and f is convex. + This means that Q should be symmetric and positive-semidefinite. + + + This object implements the strategy used by the LIBSVM tool. The following papers + can be consulted for additional details: + - Chang and Lin, Training {nu}-Support Vector Classifiers: Theory and Algorithms + - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector + machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm + !*/ + + public: + + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + template < + typename EXP1, + typename EXP2, + long NR + > + unsigned long operator() ( + const matrix_exp& Q, + const matrix_exp& y, + const scalar_type nu, + matrix& alpha, + scalar_type eps + ); + /*! + requires + - Q.nr() == Q.nc() + - is_col_vector(y) == true + - y.size() == Q.nr() + - y.size() > 1 + - sum((y == +1) + (y == -1)) == y.size() + (i.e. all elements of y must be equal to +1 or -1) + - alpha must be capable of representing a vector of size y.size() elements + - 0 < nu <= 1 + - eps > 0 + ensures + - This function solves the quadratic program defined in this class's main comment. + - The solution to the quadratic program will be stored in #alpha. + - #alpha.size() == y.size() + - This function uses an implementation of the sequential minimal optimization + algorithm. It runs until the KKT violation is less than eps. So eps controls + how accurate the solution is and smaller values result in better solutions. + (a reasonable eps is usually about 1e-3) + - #get_gradient() == Q*(#alpha) + (i.e. stores the gradient of f() at #alpha in get_gradient()) + - returns the number of iterations performed. + throws + - invalid_nu_error + This exception is thrown if nu >= maximum_nu(y). + (some values of nu cause the constraints to become impossible to satisfy. + If this is detected then an exception is thrown). + !*/ + + const column_matrix& get_gradient ( + ) const; + /*! + ensures + - returns the gradient vector at the solution of the last problem solved + by this object. If no problem has been solved then returns an empty + vector. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATION_SOLVE_QP2_USING_SMO_ABSTRACT_H_ + + + diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo.h b/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo.h new file mode 100644 index 000000000..617ecc408 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo.h @@ -0,0 +1,455 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOLVE_QP3_USING_SMo_Hh_ +#define DLIB_SOLVE_QP3_USING_SMo_Hh_ + +#include "optimization_solve_qp3_using_smo_abstract.h" +#include +#include +#include +#include "../matrix.h" +#include "../algs.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class invalid_qp3_error : public dlib::error + { + + public: + invalid_qp3_error( + const std::string& msg, + double B_, + double Cp_, + double Cn_ + ) : + dlib::error(msg), + B(B_), + Cp(Cp_), + Cn(Cn_) + {}; + + const double B; + const double Cp; + const double Cn; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class solve_qp3_using_smo + { + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + + template < + typename EXP1, + typename EXP2, + typename EXP3, + long NR + > + unsigned long operator() ( + const matrix_exp& Q, + const matrix_exp& p, + const matrix_exp& y, + const scalar_type B, + const scalar_type Cp, + const scalar_type Cn, + matrix& alpha, + scalar_type eps + ) + { + DLIB_ASSERT(Q.nr() == Q.nc() && y.size() == Q.nr() && p.size() == y.size() && + y.size() > 0 && is_col_vector(y) && is_col_vector(p) && + sum((y == +1) + (y == -1)) == y.size() && + Cp > 0 && Cn > 0 && + eps > 0, + "\t void solve_qp3_using_smo::operator()" + << "\n\t invalid arguments were given to this function" + << "\n\t Q.nr(): " << Q.nr() + << "\n\t Q.nc(): " << Q.nc() + << "\n\t is_col_vector(p): " << is_col_vector(p) + << "\n\t p.size(): " << p.size() + << "\n\t is_col_vector(y): " << is_col_vector(y) + << "\n\t y.size(): " << y.size() + << "\n\t sum((y == +1) + (y == -1)): " << sum((y == +1) + (y == -1)) + << "\n\t Cp: " << Cp + << "\n\t Cn: " << Cn + << "\n\t eps: " << eps + ); + + + + set_initial_alpha(y, B, Cp, Cn, alpha); + + + const scalar_type tau = 1e-12; + + typedef typename colm_exp::type col_type; + + // initialize df. Compute df = Q*alpha + p + df = p; + for (long r = 0; r < df.nr(); ++r) + { + if (alpha(r) != 0) + { + df += alpha(r)*matrix_cast(colm(Q,r)); + } + } + + unsigned long count = 0; + // now perform the actual optimization of alpha + long i=0, j=0; + while (find_working_group(y,alpha,Q,df,Cp,Cn,tau,eps,i,j)) + { + ++count; + const scalar_type old_alpha_i = alpha(i); + const scalar_type old_alpha_j = alpha(j); + + optimize_working_pair(alpha,Q,y,df,tau,i,j, Cp, Cn ); + + // update the df vector now that we have modified alpha(i) and alpha(j) + scalar_type delta_alpha_i = alpha(i) - old_alpha_i; + scalar_type delta_alpha_j = alpha(j) - old_alpha_j; + + col_type Q_i = colm(Q,i); + col_type Q_j = colm(Q,j); + for(long k = 0; k < df.nr(); ++k) + df(k) += Q_i(k)*delta_alpha_i + Q_j(k)*delta_alpha_j; + } + + return count; + } + + const column_matrix& get_gradient ( + ) const { return df; } + + private: + + // ------------------------------------------------------------------------------------- + + template < + typename scalar_type, + typename scalar_vector_type, + typename scalar_vector_type2 + > + inline void set_initial_alpha ( + const scalar_vector_type& y, + const scalar_type B, + const scalar_type Cp, + const scalar_type Cn, + scalar_vector_type2& alpha + ) const + { + alpha.set_size(y.size()); + + set_all_elements(alpha,0); + + // It's easy in the B == 0 case + if (B == 0) + return; + + const scalar_type C = (B > 0)? Cp : Cn; + + scalar_type temp = std::abs(B)/C; + long num = (long)std::floor(temp); + long num_total = (long)std::ceil(temp); + + const scalar_type B_sign = (B > 0)? 1 : -1; + + long count = 0; + for (long i = 0; i < alpha.nr(); ++i) + { + if (y(i) == B_sign) + { + if (count < num) + { + ++count; + alpha(i) = C; + } + else + { + if (count < num_total) + { + ++count; + alpha(i) = C*(temp - std::floor(temp)); + } + break; + } + } + } + + if (count != num_total) + { + std::ostringstream sout; + sout << "Invalid QP3 constraint parameters of B: " << B << ", Cp: " << Cp << ", Cn: "<< Cn; + throw invalid_qp3_error(sout.str(),B,Cp,Cn); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename scalar_vector_type, + typename scalar_type, + typename EXP, + typename U, typename V + > + inline bool find_working_group ( + const V& y, + const U& alpha, + const matrix_exp& Q, + const scalar_vector_type& df, + const scalar_type Cp, + const scalar_type Cn, + const scalar_type tau, + const scalar_type eps, + long& i_out, + long& j_out + ) const + { + using namespace std; + + long ip = 0; + long jp = 0; + + + typedef typename colm_exp::type col_type; + typedef typename diag_exp::type diag_type; + + scalar_type ip_val = -numeric_limits::infinity(); + scalar_type jp_val = numeric_limits::infinity(); + + // loop over the alphas and find the maximum ip and in indices. + for (long i = 0; i < alpha.nr(); ++i) + { + if (y(i) == 1) + { + if (alpha(i) < Cp) + { + if (-df(i) > ip_val) + { + ip_val = -df(i); + ip = i; + } + } + } + else + { + if (alpha(i) > 0.0) + { + if (df(i) > ip_val) + { + ip_val = df(i); + ip = i; + } + } + } + } + + scalar_type Mp = -numeric_limits::infinity(); + + // Pick out the column and diagonal of Q we need below. Doing + // it this way is faster if Q is actually a symmetric_matrix_cache + // object. + col_type Q_ip = colm(Q,ip); + diag_type Q_diag = diag(Q); + + + + // now we need to find the minimum jp indices + for (long j = 0; j < alpha.nr(); ++j) + { + if (y(j) == 1) + { + if (alpha(j) > 0.0) + { + scalar_type b = ip_val + df(j); + if (df(j) > Mp) + Mp = df(j); + + if (b > 0) + { + scalar_type a = Q_ip(ip) + Q_diag(j) - 2*y(ip)*Q_ip(j); + if (a <= 0) + a = tau; + scalar_type temp = -b*b/a; + if (temp < jp_val) + { + jp_val = temp; + jp = j; + } + } + } + } + else + { + if (alpha(j) < Cn) + { + scalar_type b = ip_val - df(j); + if (-df(j) > Mp) + Mp = -df(j); + + if (b > 0) + { + scalar_type a = Q_ip(ip) + Q_diag(j) + 2*y(ip)*Q_ip(j); + if (a <= 0) + a = tau; + scalar_type temp = -b*b/a; + if (temp < jp_val) + { + jp_val = temp; + jp = j; + } + } + } + } + } + + // if we are at the optimal point then return false so the caller knows + // to stop optimizing + if (Mp + ip_val < eps) + return false; + + + i_out = ip; + j_out = jp; + + return true; + } + + // ------------------------------------------------------------------------------------ + + template < + typename EXP, + typename EXP2, + typename T, typename U + > + inline void optimize_working_pair ( + T& alpha, + const matrix_exp& Q, + const matrix_exp& y, + const U& df, + const scalar_type& tau, + const long i, + const long j, + const scalar_type& Cp, + const scalar_type& Cn + ) const + { + const scalar_type Ci = (y(i) > 0 )? Cp : Cn; + const scalar_type Cj = (y(j) > 0 )? Cp : Cn; + + if (y(i) != y(j)) + { + scalar_type quad_coef = Q(i,i)+Q(j,j)+2*Q(j,i); + if (quad_coef <= 0) + quad_coef = tau; + scalar_type delta = (-df(i)-df(j))/quad_coef; + scalar_type diff = alpha(i) - alpha(j); + alpha(i) += delta; + alpha(j) += delta; + + if (diff > 0) + { + if (alpha(j) < 0) + { + alpha(j) = 0; + alpha(i) = diff; + } + } + else + { + if (alpha(i) < 0) + { + alpha(i) = 0; + alpha(j) = -diff; + } + } + + if (diff > Ci - Cj) + { + if (alpha(i) > Ci) + { + alpha(i) = Ci; + alpha(j) = Ci - diff; + } + } + else + { + if (alpha(j) > Cj) + { + alpha(j) = Cj; + alpha(i) = Cj + diff; + } + } + } + else + { + scalar_type quad_coef = Q(i,i)+Q(j,j)-2*Q(j,i); + if (quad_coef <= 0) + quad_coef = tau; + scalar_type delta = (df(i)-df(j))/quad_coef; + scalar_type sum = alpha(i) + alpha(j); + alpha(i) -= delta; + alpha(j) += delta; + + if(sum > Ci) + { + if(alpha(i) > Ci) + { + alpha(i) = Ci; + alpha(j) = sum - Ci; + } + } + else + { + if(alpha(j) < 0) + { + alpha(j) = 0; + alpha(i) = sum; + } + } + + if(sum > Cj) + { + if(alpha(j) > Cj) + { + alpha(j) = Cj; + alpha(i) = sum - Cj; + } + } + else + { + if(alpha(i) < 0) + { + alpha(i) = 0; + alpha(j) = sum; + } + } + + } + } + + // ------------------------------------------------------------------------------------ + + column_matrix df; // gradient of f(alpha) + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SOLVE_QP3_USING_SMo_Hh_ + + diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo_abstract.h b/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo_abstract.h new file mode 100644 index 000000000..8efd7215b --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_solve_qp3_using_smo_abstract.h @@ -0,0 +1,139 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATION_SOLVE_QP3_USING_SMO_ABSTRACT_H_ +#ifdef DLIB_OPTIMIZATION_SOLVE_QP3_USING_SMO_ABSTRACT_H_ + +#include "../matrix/matrix_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class invalid_qp3_error : public dlib::error + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is an exception class used to indicate that the values + of B, Cp, and Cn given to the solve_qp3_using_smo object are incompatible + with the constraints of the quadratic program. + + this->B, this->Cp, and this->Cn will be set to the invalid values used. + !*/ + + public: + invalid_qp3_error( const std::string& msg, double B_, double Cp_, double Cn_) : + dlib::error(msg), B(B_), Cp(Cp_), Cn(Cn_) {}; + + const double B; + const double Cp; + const double Cn; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class solve_qp3_using_smo + { + /*! + REQUIREMENTS ON matrix_type + Must be some type of dlib::matrix. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for solving the following quadratic programming + problem using the sequential minimal optimization algorithm: + + Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + trans(p)*alpha + subject to the following constraints: + - for all i such that y(i) == +1: 0 <= alpha(i) <= Cp + - for all i such that y(i) == -1: 0 <= alpha(i) <= Cn + - trans(y)*alpha == B + + Where all elements of y must be equal to +1 or -1 and f is convex. + This means that Q should be symmetric and positive-semidefinite. + + + This object implements the strategy used by the LIBSVM tool. The following papers + can be consulted for additional details: + - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector + machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm + - Working Set Selection Using Second Order Information for Training Support Vector Machines by + Fan, Chen, and Lin. In the Journal of Machine Learning Research 2005. + !*/ + + public: + + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + template < + typename EXP1, + typename EXP2, + typename EXP3, + long NR + > + unsigned long operator() ( + const matrix_exp& Q, + const matrix_exp& p, + const matrix_exp& y, + const scalar_type B, + const scalar_type Cp, + const scalar_type Cn, + matrix& alpha, + scalar_type eps + ); + /*! + requires + - Q.nr() == Q.nc() + - is_col_vector(y) == true + - is_col_vector(p) == true + - p.size() == y.size() == Q.nr() + - y.size() > 0 + - sum((y == +1) + (y == -1)) == y.size() + (i.e. all elements of y must be equal to +1 or -1) + - alpha must be capable of representing a vector of size y.size() elements + - Cp > 0 + - Cn > 0 + - eps > 0 + ensures + - This function solves the quadratic program defined in this class's main comment. + - The solution to the quadratic program will be stored in #alpha. + - #alpha.size() == y.size() + - This function uses an implementation of the sequential minimal optimization + algorithm. It runs until the KKT violation is less than eps. So eps controls + how accurate the solution is and smaller values result in better solutions. + (a reasonable eps is usually about 1e-3) + - #get_gradient() == Q*(#alpha) + (i.e. stores the gradient of f() at #alpha in get_gradient()) + - returns the number of iterations performed. + throws + - invalid_qp3_error + This exception is thrown if the given parameters cause the constraints + of the quadratic programming problem to be impossible to satisfy. + !*/ + + const column_matrix& get_gradient ( + ) const; + /*! + ensures + - returns the gradient vector at the solution of the last problem solved + by this object. If no problem has been solved then returns an empty + vector. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATION_SOLVE_QP3_USING_SMO_ABSTRACT_H_ + + + diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo.h b/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo.h new file mode 100644 index 000000000..b9cc74df3 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo.h @@ -0,0 +1,937 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_Hh_ +#define DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_Hh_ + +#include "optimization_solve_qp_using_smo_abstract.h" +#include "../matrix.h" +#include +#include "../unordered_pair.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /* + The algorithm defined in the solve_qp_using_smo() function below can be + derived by using an important theorem from the theory of constrained optimization. + This theorem tells us that any optimal point of a constrained function must + satisfy what are called the KKT conditions (also sometimes called just the KT + conditions, especially in older literature). A very good book to consult + regarding this topic is Practical Methods of Optimization (second edition) by + R. Fletcher. Below I will try to explain the general idea of how this is + applied. + + Let e == ones_matrix(alpha.size(),1) + + First, note that the function below solves the following quadratic program. + Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha - trans(alpha)*b + subject to the following constraints: + - trans(e)*alpha == C (i.e. the sum of alpha values doesn't change) + - min(alpha) >= 0 (i.e. all alpha values are nonnegative) + Where f is convex. This means that Q should be positive-semidefinite. + + + To get from this problem formulation to the algorithm below we have to + consider the KKT conditions. They tell us that any solution to the above + problem must satisfy the following 5 conditions: + 1. trans(e)*alpha == C + 2. min(alpha) >= 0 + + 3. Let L(alpha, x, y) == f(alpha) - trans(x)*alpha - y*(trans(e)*alpha - C) + Where x is a vector of length alpha.size() and y is a single scalar. + Then the derivative of L with respect to alpha must == 0 + So we get the following as our 3rd condition: + f'(alpha) - x - y*e == 0 + + 4. min(x) >= 0 (i.e. all x values are nonnegative) + 5. pointwise_multiply(x, alpha) == 0 + (i.e. only one member of each x(i) and alpha(i) pair can be non-zero) + + + From 3 we can easily obtain this rule: + for all i: f'(alpha)(i) - x(i) == y + + If we then consider 4 and 5 we see that we can infer that the following + must also be the case: + - if (alpha(i) > 0) then + - x(i) == 0 + - f'(alpha)(i) == y + - else + - x(i) == some nonnegative number + - f'(alpha)(i) >= y + + + The important thing to take away is the final rule. It tells us that at the + optimal solution all elements of the gradient of f have the same value if + their corresponding alpha is non-zero. It also tells us that all the other + gradient values are bigger than y. We can use this information to help us + pick which alpha variables to optimize at each iteration. + */ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NR, long NC, typename MM, typename L + > + unsigned long solve_qp_using_smo ( + const matrix_exp& Q, + const matrix_exp& b, + matrix& alpha, + T eps, + unsigned long max_iter + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(Q.nr() == Q.nc() && + is_col_vector(b) && + is_col_vector(alpha) && + b.size() == alpha.size() && + b.size() == Q.nr() && + alpha.size() > 0 && + min(alpha) >= 0 && + eps > 0 && + max_iter > 0, + "\t unsigned long solve_qp_using_smo()" + << "\n\t Invalid arguments were given to this function" + << "\n\t Q.nr(): " << Q.nr() + << "\n\t Q.nc(): " << Q.nc() + << "\n\t is_col_vector(b): " << is_col_vector(b) + << "\n\t is_col_vector(alpha): " << is_col_vector(alpha) + << "\n\t b.size(): " << b.size() + << "\n\t alpha.size(): " << alpha.size() + << "\n\t Q.nr(): " << Q.nr() + << "\n\t min(alpha): " << min(alpha) + << "\n\t eps: " << eps + << "\n\t max_iter: " << max_iter + ); + + const T C = sum(alpha); + + // Compute f'(alpha) (i.e. the gradient of f(alpha)) for the current alpha. + matrix df = Q*alpha - b; + + const T tau = 1000*std::numeric_limits::epsilon(); + + T big, little; + unsigned long iter = 0; + for (; iter < max_iter; ++iter) + { + // Find the two elements of df that satisfy the following: + // - little_idx == index_of_min(df) + // - big_idx == the index of the largest element in df such that alpha(big_idx) > 0 + // These two indices will tell us which two alpha values are most in violation of the KKT + // optimality conditions. + big = -std::numeric_limits::max(); + long big_idx = 0; + little = std::numeric_limits::max(); + long little_idx = 0; + for (long i = 0; i < df.nr(); ++i) + { + if (df(i) > big && alpha(i) > 0) + { + big = df(i); + big_idx = i; + } + if (df(i) < little) + { + little = df(i); + little_idx = i; + } + } + + // Check if the KKT conditions are still violated and stop if so. + //if (alpha(little_idx) > 0 && (big - little) < eps) + // break; + + // Check how big the duality gap is and stop when it goes below eps. + // The duality gap is the gap between the objective value of the function + // we are optimizing and the value of its primal form. This value is always + // greater than or equal to the distance to the optimum solution so it is a + // good way to decide if we should stop. See the book referenced above for + // more information. In particular, see the part about the Wolfe Dual. + if (trans(alpha)*df - C*little < eps) + break; + + + // Save these values, we will need them later. + const T old_alpha_big = alpha(big_idx); + const T old_alpha_little = alpha(little_idx); + + + // Now optimize the two variables we just picked. + T quad_coef = Q(big_idx,big_idx) + Q(little_idx,little_idx) - 2*Q(big_idx, little_idx); + if (quad_coef <= tau) + quad_coef = tau; + const T delta = (big - little)/quad_coef; + alpha(big_idx) -= delta; + alpha(little_idx) += delta; + + // Make sure alpha stays feasible. That is, make sure the updated alpha doesn't + // violate the non-negativity constraint. + if (alpha(big_idx) < 0) + { + // Since an alpha can't be negative we will just set it to 0 and shift all the + // weight to the other alpha. + alpha(big_idx) = 0; + alpha(little_idx) = old_alpha_big + old_alpha_little; + } + + // Every 300 iterations + if ((iter%300) == 299) + { + // Perform this form of the update every so often because doing so can help + // avoid the buildup of numerical errors you get with the alternate update + // below. + df = Q*alpha - b; + } + else + { + // Now update the gradient. We will perform the equivalent of: df = Q*alpha - b; + const T delta_alpha_big = alpha(big_idx) - old_alpha_big; + const T delta_alpha_little = alpha(little_idx) - old_alpha_little; + + for(long k = 0; k < df.nr(); ++k) + df(k) += Q(big_idx,k)*delta_alpha_big + Q(little_idx,k)*delta_alpha_little;; + } + } + + /* + using namespace std; + cout << "SMO: " << endl; + cout << " duality gap: "<< trans(alpha)*df - C*min(df) << endl; + cout << " KKT gap: "<< big-little << endl; + cout << " iter: "<< iter+1 << endl; + cout << " eps: "<< eps << endl; + */ + + return iter+1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename EXP3, + typename EXP4, + typename T, long NR, long NC, typename MM, typename L, + long NR2, long NC2 + > + unsigned long solve_qp4_using_smo ( + const matrix_exp& A, + const matrix_exp& Q, + const matrix_exp& b, + const matrix_exp& d, + matrix& alpha, + matrix& lambda, + T eps, + unsigned long max_iter, + T max_lambda = std::numeric_limits::infinity() + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(A.nc() == alpha.size() && + Q.nr() == Q.nc() && + is_col_vector(b) && + is_col_vector(alpha) && + b.size() == alpha.size() && + b.size() == Q.nr() && + alpha.size() > 0 && + min(alpha) >= 0 && + eps > 0 && + max_iter > 0, + "\t void solve_qp4_using_smo()" + << "\n\t Invalid arguments were given to this function" + << "\n\t A.nc(): " << A.nc() + << "\n\t Q.nr(): " << Q.nr() + << "\n\t Q.nc(): " << Q.nc() + << "\n\t is_col_vector(b): " << is_col_vector(b) + << "\n\t is_col_vector(alpha): " << is_col_vector(alpha) + << "\n\t b.size(): " << b.size() + << "\n\t alpha.size(): " << alpha.size() + << "\n\t Q.nr(): " << Q.nr() + << "\n\t min(alpha): " << min(alpha) + << "\n\t eps: " << eps + << "\n\t max_iter: " << max_iter + ); + DLIB_ASSERT(is_col_vector(d) == true && + max_lambda >= 0 && + d.size() == A.nr(), + "\t void solve_qp4_using_smo()" + << "\n\t Invalid arguments were given to this function" + << "\n\t A.nr(): " << A.nr() + << "\n\t d.size(): " << d.size() + << "\n\t max_lambda: " << max_lambda + ); + + const T C = sum(alpha); + + /* + For this optimization problem, it is the case that the optimal + value of lambda is given by a simple closed form expression if we + know the optimal alpha. So what we will do is to just optimize + alpha and every now and then we will update lambda with its optimal + value. Therefore, we use essentially the same method as the + solve_qp_using_smo() routine. + */ + + const bool d_is_zero = d==zeros_matrix(d); + + // compute optimal lambda for current alpha + if (d_is_zero) + lambda = A*alpha; + else + lambda = A*alpha + d; + lambda = dlib::clamp(lambda, 0, max_lambda); + + // Compute f'(alpha) (i.e. the gradient of f(alpha) with respect to alpha) for the current alpha. + matrix df = Q*alpha - b - trans(A)*lambda; + + const T tau = 1000*std::numeric_limits::epsilon(); + + T big, little; + unsigned long iter = 0; + for (; iter < max_iter; ++iter) + { + // Find the two elements of df that satisfy the following: + // - little_idx == index_of_min(df) + // - big_idx == the index of the largest element in df such that alpha(big_idx) > 0 + // These two indices will tell us which two alpha values are most in violation of the KKT + // optimality conditions. + big = -std::numeric_limits::max(); + long big_idx = 0; + little = std::numeric_limits::max(); + long little_idx = 0; + for (long i = 0; i < df.nr(); ++i) + { + if (df(i) > big && alpha(i) > 0) + { + big = df(i); + big_idx = i; + } + if (df(i) < little) + { + little = df(i); + little_idx = i; + } + } + + // Check how big the duality gap is and stop when it goes below eps. + // The duality gap is the gap between the objective value of the function + // we are optimizing and the value of its primal form. This value is always + // greater than or equal to the distance to the optimum solution so it is a + // good way to decide if we should stop. + if (trans(alpha)*df - C*little < eps) + { + // compute optimal lambda and recheck the duality gap to make + // sure we have really converged. + if (d_is_zero) + lambda = A*alpha; + else + lambda = A*alpha + d; + lambda = dlib::clamp(lambda, 0, max_lambda); + df = Q*alpha - b - trans(A)*lambda; + + if (trans(alpha)*df - C*min(df) < eps) + break; + else + continue; + } + + + // Save these values, we will need them later. + const T old_alpha_big = alpha(big_idx); + const T old_alpha_little = alpha(little_idx); + + + // Now optimize the two variables we just picked. + T quad_coef = Q(big_idx,big_idx) + Q(little_idx,little_idx) - 2*Q(big_idx, little_idx); + if (quad_coef <= tau) + quad_coef = tau; + const T delta = (big - little)/quad_coef; + alpha(big_idx) -= delta; + alpha(little_idx) += delta; + + // Make sure alpha stays feasible. That is, make sure the updated alpha doesn't + // violate the non-negativity constraint. + if (alpha(big_idx) < 0) + { + // Since an alpha can't be negative we will just set it to 0 and shift all the + // weight to the other alpha. + alpha(big_idx) = 0; + alpha(little_idx) = old_alpha_big + old_alpha_little; + } + + + // Every 300 iterations + if ((iter%300) == 299) + { + // compute the optimal lambda for the current alpha + if (d_is_zero) + lambda = A*alpha; + else + lambda = A*alpha + d; + lambda = dlib::clamp(lambda, 0, max_lambda); + + // Perform this form of the update every so often because doing so can help + // avoid the buildup of numerical errors you get with the alternate update + // below. + df = Q*alpha - b - trans(A)*lambda; + } + else + { + // Now update the gradient. We will perform the equivalent of: df = Q*alpha - b; + const T delta_alpha_big = alpha(big_idx) - old_alpha_big; + const T delta_alpha_little = alpha(little_idx) - old_alpha_little; + + for(long k = 0; k < df.nr(); ++k) + df(k) += Q(big_idx,k)*delta_alpha_big + Q(little_idx,k)*delta_alpha_little;; + } + } + + /* + using namespace std; + cout << "SMO: " << endl; + cout << " duality gap: "<< trans(alpha)*df - C*min(df) << endl; + cout << " KKT gap: "<< big-little << endl; + cout << " iter: "<< iter+1 << endl; + cout << " eps: "<< eps << endl; + */ + + + return iter+1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NR, long NC, typename MM, typename L + > + unsigned long solve_qp_box_constrained ( + const matrix_exp& Q, + const matrix_exp& b, + matrix& alpha, + const matrix& lower, + const matrix& upper, + T eps, + unsigned long max_iter + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(Q.nr() == Q.nc() && + alpha.size() == lower.size() && + alpha.size() == upper.size() && + is_col_vector(b) && + is_col_vector(alpha) && + is_col_vector(lower) && + is_col_vector(upper) && + b.size() == alpha.size() && + b.size() == Q.nr() && + alpha.size() > 0 && + 0 <= min(alpha-lower) && + 0 <= max(upper-alpha) && + eps > 0 && + max_iter > 0, + "\t unsigned long solve_qp_box_constrained()" + << "\n\t Invalid arguments were given to this function" + << "\n\t Q.nr(): " << Q.nr() + << "\n\t Q.nc(): " << Q.nc() + << "\n\t is_col_vector(b): " << is_col_vector(b) + << "\n\t is_col_vector(alpha): " << is_col_vector(alpha) + << "\n\t is_col_vector(lower): " << is_col_vector(lower) + << "\n\t is_col_vector(upper): " << is_col_vector(upper) + << "\n\t b.size(): " << b.size() + << "\n\t alpha.size(): " << alpha.size() + << "\n\t lower.size(): " << lower.size() + << "\n\t upper.size(): " << upper.size() + << "\n\t Q.nr(): " << Q.nr() + << "\n\t min(alpha-lower): " << min(alpha-lower) + << "\n\t max(upper-alpha): " << max(upper-alpha) + << "\n\t eps: " << eps + << "\n\t max_iter: " << max_iter + ); + + + // Compute f'(alpha) (i.e. the gradient of f(alpha)) for the current alpha. + matrix df = Q*alpha + b; + matrix QQ = reciprocal_max(diag(Q)); + + // First we use a coordinate descent method to initialize alpha. + double max_df = 0; + for (long iter = 0; iter < alpha.size()*2; ++iter) + { + max_df = 0; + long best_r =0; + // find the best alpha to optimize. + for (long r = 0; r < Q.nr(); ++r) + { + if (alpha(r) <= lower(r) && df(r) > 0) + ;//alpha(r) = lower(r); + else if (alpha(r) >= upper(r) && df(r) < 0) + ;//alpha(r) = upper(r); + else if (std::abs(df(r)) > max_df) + { + best_r = r; + max_df = std::abs(df(r)); + } + } + + // now optimize alpha(best_r) + const long r = best_r; + const T old_alpha = alpha(r); + alpha(r) = -(df(r)-Q(r,r)*alpha(r))*QQ(r); + if (alpha(r) < lower(r)) + alpha(r) = lower(r); + else if (alpha(r) > upper(r)) + alpha(r) = upper(r); + + const T delta = old_alpha-alpha(r); + + // Now update the gradient. We will perform the equivalent of: df = Q*alpha + b; + for(long k = 0; k < df.nr(); ++k) + df(k) -= Q(r,k)*delta; + } + //cout << "max_df: " << max_df << endl; + //cout << "objective value: " << 0.5*trans(alpha)*Q*alpha + trans(b)*alpha << endl; + + + + // Now do the main iteration block of this solver. The coordinate descent method + // we used above can improve the objective rapidly in the beginning. However, + // Nesterov's method has more rapid convergence once it gets going so this is what + // we use for the main iteration. + matrix v, v_old; + v = alpha; + // We need to get an upper bound on the Lipschitz constant for this QP. Since that + // is just the max eigenvalue of Q we can do it using Gershgorin disks. + const T lipschitz_bound = max(diag(Q) + (sum_cols(abs(Q)) - abs(diag(Q)))); + double lambda = 0; + unsigned long iter; + for (iter = 0; iter < max_iter; ++iter) + { + const double next_lambda = (1 + std::sqrt(1+4*lambda*lambda))/2; + const double gamma = (1-lambda)/next_lambda; + lambda = next_lambda; + + v_old = v; + + df = Q*alpha + b; + // now take a projected gradient step using Nesterov's method. + v = clamp(alpha - 1.0/lipschitz_bound * df, lower, upper); + alpha = dlib::clamp((1-gamma)*v + gamma*v_old, lower, upper); + + + // check for convergence every 10 iterations + if (iter%10 == 0) + { + max_df = 0; + for (long r = 0; r < Q.nr(); ++r) + { + if (alpha(r) <= lower(r) && df(r) > 0) + ;//alpha(r) = lower(r); + else if (alpha(r) >= upper(r) && df(r) < 0) + ;//alpha(r) = upper(r); + else if (std::abs(df(r)) > max_df) + max_df = std::abs(df(r)); + } + if (max_df < eps) + break; + } + } + + //cout << "max_df: " << max_df << endl; + //cout << "objective value: " << 0.5*trans(alpha)*Q*alpha + trans(b)*alpha << endl; + return iter+1; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + // Check if each vector in Q_offdiag is actually a constant times the 1s vector. + template < + typename T, long NR, long NC, typename MM, typename L + > + bool has_uniform_offdiag_vectors( + const std::map, matrix>& Q_offdiag + ) + { + for (auto& x : Q_offdiag) + { + auto ref = x.second(0); + for (auto& y : x.second) + if (ref != y) + return false; + } + return true; + } + + template < + typename T, long NR, long NC, typename MM, typename L + > + matrix compact_offdiag( + const size_t& num_blocks, + const std::map, matrix>& Q_offdiag + ) + { + matrix temp; + // we can only compact the offdiag information if they are uniform vectors + if (!has_uniform_offdiag_vectors(Q_offdiag)) + return temp; + + temp.set_size(num_blocks, num_blocks); + temp = 0; + + for (auto& x : Q_offdiag) + { + long r = x.first.first; + long c = x.first.second; + temp(r,c) = x.second(0); + temp(c,r) = x.second(0); + } + + return temp; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, long NR, long NC, typename MM, typename L + > + unsigned long solve_qp_box_constrained_blockdiag ( + const std::vector>& Q_blocks, + const std::vector>& bs, + const std::map, matrix>& Q_offdiag, + std::vector>& alphas, + const std::vector>& lowers, + const std::vector>& uppers, + T eps, + unsigned long max_iter + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(Q_blocks.size() > 0); + DLIB_CASSERT(Q_blocks.size() == bs.size() && + Q_blocks.size() == alphas.size() && + Q_blocks.size() == lowers.size() && + Q_blocks.size() == uppers.size(), + "Q_blocks.size(): "<< Q_blocks.size() << "\n" << + "bs.size(): "<< bs.size() << "\n" << + "alphas.size(): "<< alphas.size() << "\n" << + "lowers.size(): "<< lowers.size() << "\n" << + "uppers.size(): "<< uppers.size() << "\n" + ); + for (auto& Q : Q_blocks) + { + DLIB_CASSERT(Q.nr() == Q.nc(), "All the matrices in Q_blocks have the same dimensions."); + DLIB_CASSERT(Q.size() > 0, "All the matrices in Q_blocks must be non-empty and have the same dimensions."); + DLIB_CASSERT(Q.nr() == Q_blocks[0].nr() && Q.nc() == Q_blocks[0].nc(), "All the matrices in Q_blocks have the same dimensions."); + } +#ifdef ENABLE_ASSERTS + for (size_t i = 0; i < alphas.size(); ++i) + { + DLIB_CASSERT(is_col_vector(bs[i]) && bs[i].size() == Q_blocks[0].nr(), + "is_col_vector(bs["< 0 && max_iter > 0, "eps: " << eps << "\nmax_iter: "<< max_iter); +#endif // ENABLE_ASSERTS + + + const auto offdiag_compact = impl::compact_offdiag(Q_blocks.size(), Q_offdiag); + matrix temp, alphas_compact; + + // Compute f'(alpha) (i.e. the gradient of f(alpha)) for the current alpha. + std::vector> df;// = Q*alpha + b; + auto compute_df = [&]() + { + df.resize(Q_blocks.size()); + for (size_t i = 0; i < df.size(); ++i) + df[i] = Q_blocks[i]*alphas[i] + bs[i]; + + + // Don't forget to include the Q_offdiag terms in the computation. Note that + // we have two options for how we can compute this part. If Q_offdiag is + // uniform and can be compacted into a simple matrix and there are a lot of off + // diagonal entries then it's faster to do it as a matrix multiply. Otherwise + // we do the more general computation. + if (offdiag_compact.size() != 0 && Q_offdiag.size() > Q_blocks.size()*5) + { + // Do it as a matrix multiply (with a bit of data shuffling) + alphas_compact.set_size(alphas[0].size(), offdiag_compact.nr()); + for (long c = 0; c < alphas_compact.nc(); ++c) + set_colm(alphas_compact,c) = alphas[c]; + temp = alphas_compact*offdiag_compact; + for (size_t i = 0; i < df.size(); ++i) + df[i] += colm(temp,i); + } + else + { + // Do the fully general computation that allows for non-uniform values in + // the off diagonal vectors. + for (auto& p : Q_offdiag) + { + long r = p.first.first; + long c = p.first.second; + df[r] += pointwise_multiply(p.second, alphas[c]); + if (r != c) + df[c] += pointwise_multiply(p.second, alphas[r]); + } + } + }; + compute_df(); + + + + std::vector> Q_diag, Q_ggd; + std::vector> QQ;// = reciprocal_max(diag(Q)); + QQ.resize(Q_blocks.size()); + Q_diag.resize(Q_blocks.size()); + Q_ggd.resize(Q_blocks.size()); + + // We need to get an upper bound on the Lipschitz constant for this QP. Since that + // is just the max eigenvalue of Q we can do it using Gershgorin disks. + //const T lipschitz_bound = max(diag(Q) + (sum_cols(abs(Q)) - abs(diag(Q)))); + for (size_t i = 0; i < QQ.size(); ++i) + { + auto f = Q_offdiag.find(make_unordered_pair(i,i)); + if (f != Q_offdiag.end()) + Q_diag[i] = diag(Q_blocks[i]) + f->second; + else + Q_diag[i] = diag(Q_blocks[i]); + QQ[i] = reciprocal_max(Q_diag[i]); + + Q_ggd[i] = Q_diag[i] + (sum_cols(abs(Q_blocks[i]))-abs(diag(Q_blocks[i]))); + } + for (auto& p : Q_offdiag) + { + long r = p.first.first; + long c = p.first.second; + if (r != c) + { + Q_ggd[r] += abs(p.second); + Q_ggd[c] += abs(p.second); + } + } + T lipschitz_bound = -std::numeric_limits::infinity(); + for (auto& x : Q_ggd) + lipschitz_bound = std::max(lipschitz_bound, max(x)); + + + const long num_variables = alphas.size()*alphas[0].size(); + + // First we use a coordinate descent method to initialize alpha. + double max_df = 0; + for (long iter = 0; iter < num_variables*2; ++iter) + { + max_df = 0; + long best_r =0; + size_t best_r2 =0; + // find the best alpha to optimize. + for (size_t r2 = 0; r2 < alphas.size(); ++r2) + { + auto& alpha = alphas[r2]; + auto& df_ = df[r2]; + auto& lower = lowers[r2]; + auto& upper = uppers[r2]; + for (long r = 0; r < alpha.nr(); ++r) + { + if (alpha(r) <= lower(r) && df_(r) > 0) + ;//alpha(r) = lower(r); + else if (alpha(r) >= upper(r) && df_(r) < 0) + ;//alpha(r) = upper(r); + else if (std::abs(df_(r)) > max_df) + { + best_r = r; + best_r2 = r2; + max_df = std::abs(df_(r)); + } + } + } + + // now optimize alphas[best_r2](best_r) + const long r = best_r; + auto& alpha = alphas[best_r2]; + auto& lower = lowers[best_r2]; + auto& upper = uppers[best_r2]; + auto& df_ = df[best_r2]; + const T old_alpha = alpha(r); + alpha(r) = -(df_(r)-Q_diag[best_r2](r)*alpha(r))*QQ[best_r2](r); + if (alpha(r) < lower(r)) + alpha(r) = lower(r); + else if (alpha(r) > upper(r)) + alpha(r) = upper(r); + + const T delta = old_alpha-alpha(r); + + // Now update the gradient. We will perform the equivalent of: df = Q*alpha + + // b; except we only need to compute one column of the matrix multiply because + // only one element of alpha changed. + auto& Q = Q_blocks[best_r2]; + for(long k = 0; k < df_.nr(); ++k) + df_(k) -= Q(r,k)*delta; + for(size_t j = 0; j < Q_blocks.size(); ++j) + { + auto f = Q_offdiag.find(make_unordered_pair(best_r2, j)); + if (f != Q_offdiag.end()) + df[j](r) -= f->second(r)*delta; + } + } + + + + + std::vector> v(alphas), v_old(alphas.size()); + double lambda = 0; + unsigned long iter; + // Now do the main iteration block of this solver. The coordinate descent method + // we used above can improve the objective rapidly in the beginning. However, + // Nesterov's method has more rapid convergence once it gets going so this is what + // we use for the main iteration. + for (iter = 0; iter < max_iter; ++iter) + { + const double next_lambda = (1 + std::sqrt(1+4*lambda*lambda))/2; + const double gamma = (1-lambda)/next_lambda; + lambda = next_lambda; + + v_old.swap(v); + + //df = Q*alpha + b; + compute_df(); + + // now take a projected gradient step using Nesterov's method. + for (size_t j = 0; j < alphas.size(); ++j) + { + v[j] = clamp(alphas[j] - 1.0/lipschitz_bound * df[j], lowers[j], uppers[j]); + alphas[j] = clamp((1-gamma)*v[j] + gamma*v_old[j], lowers[j], uppers[j]); + } + + + // check for convergence every 10 iterations + if (iter%10 == 0) + { + max_df = 0; + for (size_t r2 = 0; r2 < alphas.size(); ++r2) + { + auto& alpha = alphas[r2]; + auto& df_ = df[r2]; + auto& lower = lowers[r2]; + auto& upper = uppers[r2]; + for (long r = 0; r < alpha.nr(); ++r) + { + if (alpha(r) <= lower(r) && df_(r) > 0) + ;//alpha(r) = lower(r); + else if (alpha(r) >= upper(r) && df_(r) < 0) + ;//alpha(r) = upper(r); + else if (std::abs(df_(r)) > max_df) + max_df = std::abs(df_(r)); + } + } + if (max_df < eps) + break; + } + } + + return iter+1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NRa, long NRb + > + unsigned long find_gap_between_convex_hulls ( + const matrix_exp& A, + const matrix_exp& B, + matrix& cA, + matrix& cB, + const double eps, + const unsigned long max_iter = 1000 + ) + { + DLIB_CASSERT(A.size() != 0); + DLIB_CASSERT(B.size() != 0); + DLIB_CASSERT(A.nr() == B.nr(), "The dimensionality of the points in both convex hull sets must match"); + DLIB_CASSERT(eps > 0); + DLIB_CASSERT(max_iter > 0); + + cA.set_size(A.nc()); + cB.set_size(B.nc()); + + // initialize to the centroids of A and B respectively. + cA = 1.0/cA.size(); + cB = 1.0/cB.size(); + + + matrix AA, BB, AB, ABb, ABa; + + AA = trans(A)*A; + BB = trans(B)*B; + AB = trans(A)*B; + + unsigned long iter = 0; + for (iter = 0; iter < max_iter; ++iter) + { + // find the convex combination of A that is nearest to B*cB + ABb = AB*cB; + const auto smo_iter1 = solve_qp_using_smo(AA, ABb, cA, eps, cA.size()); + + // now find the convex combination of B that is nearest to A*cA + ABa = trans(AB)*cA; + const auto smo_iter2 = solve_qp_using_smo(BB, ABa, cB, eps, cB.size()); + + // stop if the QP solvers failed to improve + if (smo_iter1 == 1 && smo_iter2 == 1) + break; + } + + + return iter+1; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_Hh_ + diff --git a/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo_abstract.h b/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo_abstract.h new file mode 100644 index 000000000..5e7d5ec3f --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_solve_qp_using_smo_abstract.h @@ -0,0 +1,282 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_ABSTRACT_Hh_ +#ifdef DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_ABSTRACT_Hh_ + +#include "../matrix.h" +#include +#include "../unordered_pair.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NR, long NC, typename MM, typename L + > + unsigned long solve_qp_using_smo ( + const matrix_exp& Q, + const matrix_exp& b, + matrix& alpha, + T eps, + unsigned long max_iter + ); + /*! + requires + - Q.nr() == Q.nc() + - is_col_vector(b) == true + - is_col_vector(alpha) == true + - b.size() == alpha.size() == Q.nr() + - alpha.size() > 0 + - min(alpha) >= 0 + - eps > 0 + - max_iter > 0 + ensures + - Let C == sum(alpha) (i.e. C is the sum of the alpha values you + supply to this function) + - This function solves the following quadratic program: + Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha - trans(alpha)*b + subject to the following constraints: + - sum(alpha) == C (i.e. the sum of alpha values doesn't change) + - min(alpha) >= 0 (i.e. all alpha values are nonnegative) + Where f is convex. This means that Q should be positive-semidefinite. + - The solution to the above QP will be stored in #alpha. + - This function uses a simple implementation of the sequential minimal + optimization algorithm. It starts the algorithm with the given alpha + and it works on the problem until the duality gap (i.e. how far away + we are from the optimum solution) is less than eps. So eps controls + how accurate the solution is and smaller values result in better solutions. + - At most max_iter iterations of optimization will be performed. + - returns the number of iterations performed. If this method fails to + converge to eps accuracy then the number returned will be max_iter+1. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename EXP3, + typename T, long NR, long NC, typename MM, typename L, + long NR2, long NC2 + > + unsigned long solve_qp4_using_smo ( + const matrix_exp& A, + const matrix_exp& Q, + const matrix_exp& b, + const matrix_exp& d, + matrix& alpha, + matrix& lambda, + T eps, + unsigned long max_iter, + T max_lambda = std::numeric_limits::infinity() + ); + /*! + requires + - A.nc() == alpha.size() + - Q.nr() == Q.nc() + - is_col_vector(b) == true + - is_col_vector(d) == true + - is_col_vector(alpha) == true + - b.size() == alpha.size() == Q.nr() + - d.size() == A.nr() + - alpha.size() > 0 + - min(alpha) >= 0 + - eps > 0 + - max_iter > 0 + - max_lambda >= 0 + ensures + - Let C == sum(alpha) (i.e. C is the sum of the alpha values you + supply to this function) + - This function solves the following quadratic program: + Minimize: f(alpha,lambda) == 0.5*trans(alpha)*Q*alpha - trans(alpha)*b + + 0.5*trans(lambda)*lambda - trans(lambda)*A*alpha - trans(lambda)*d + subject to the following constraints: + - sum(alpha) == C (i.e. the sum of alpha values doesn't change) + - min(alpha) >= 0 (i.e. all alpha values are nonnegative) + - min(lambda) >= 0 (i.e. all lambda values are nonnegative) + - max(lambda) <= max_lambda (i.e. all lambda values are less than max_lambda) + Where f is convex. This means that Q should be positive-semidefinite. + - If you don't want an upper limit on lambda then max_lambda can be set to + infinity. + - The solution to the above QP will be stored in #alpha and #lambda. + - This function uses a simple implementation of the sequential minimal + optimization algorithm. It starts the algorithm with the given alpha + and it works on the problem until the duality gap (i.e. how far away + we are from the optimum solution) is less than eps. So eps controls + how accurate the solution is and smaller values result in better solutions. + The initial value of lambda is ignored since the optimal lambda can be + obtained via a simple closed form expression given alpha. + - At most max_iter iterations of optimization will be performed. + - returns the number of iterations performed. If this method fails to + converge to eps accuracy then the number returned will be max_iter+1. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NR, long NC, typename MM, typename L + > + unsigned long solve_qp_box_constrained ( + const matrix_exp& Q, + const matrix_exp& b, + matrix& alpha, + const matrix& lower, + const matrix& upper, + T eps, + unsigned long max_iter + ); + /*! + requires + - Q.nr() == Q.nc() + - alpha.size() == lower.size() == upper.size() + - is_col_vector(b) == true + - is_col_vector(alpha) == true + - is_col_vector(lower) == true + - is_col_vector(upper) == true + - b.size() == alpha.size() == Q.nr() + - alpha.size() > 0 + - 0 <= min(alpha-lower) + - 0 <= max(upper-alpha) + - eps > 0 + - max_iter > 0 + ensures + - This function solves the following quadratic program: + Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + trans(b)*alpha + subject to the following box constraints on alpha: + - 0 <= min(alpha-lower) + - 0 <= max(upper-alpha) + Where f is convex. This means that Q should be positive-semidefinite. + - The solution to the above QP will be stored in #alpha. + - This function uses a combination of a SMO algorithm along with Nesterov's + method as the main iteration of the solver. It starts the algorithm with the + given alpha and it works on the problem until the derivative of f(alpha) is + smaller than eps for each element of alpha or the alpha value is at a box + constraint. So eps controls how accurate the solution is and smaller values + result in better solutions. + - At most max_iter iterations of optimization will be performed. + - returns the number of iterations performed. If this method fails to + converge to eps accuracy then the number returned will be max_iter+1. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, long NR, long NC, typename MM, typename L + > + unsigned long solve_qp_box_constrained_blockdiag ( + const std::vector>& Q_blocks, + const std::vector>& bs, + const std::map, matrix>& Q_offdiag, + std::vector>& alphas, + const std::vector>& lowers, + const std::vector>& uppers, + T eps, + unsigned long max_iter + ); + /*! + requires + - Q_blocks.size() > 0 + - Q_blocks.size() == bs.size() == alphas.size() == lowers.size() == uppers.size() + - All the matrices in Q_blocks have the same dimensions. Moreover, they are + non-empty square matrices. + - All the matrices in bs, Q_offdiag, alphas, lowers, and uppers have the same + dimensions. Moreover, they are all column vectors. + - Q_blocks[0].nr() == alphas[0].size() + (i.e. the dimensionality of the column vectors in alphas must match the + dimensionality of the square matrices in Q_blocks.) + - for all valid i: + - 0 <= min(alphas[i]-lowers[i]) + - 0 <= max(uppers[i]-alphas[i]) + - eps > 0 + - max_iter > 0 + ensures + - This function solves the same QP as solve_qp_box_constrained(), except it is + optimized for the case where the Q matrix has a certain sparsity structure. + To be precise: + - Let Q1 be a block diagonal matrix with the elements of Q_blocks placed + along its diagonal, and in the order contained in Q_blocks. + - Let Q2 be a matrix with the same size as Q1, except instead of being block diagonal, it + is block structured into Q_blocks.nr() by Q_blocks.nc() blocks. If we let (r,c) be the + coordinate of each block then each block contains the matrix + diagm(Q_offdiag[make_unordered_pair(r,c)]) or the zero matrix if Q_offdiag has no entry + for the coordinate (r,c). + - Let Q == Q1+Q2 + - Let b == the concatenation of all the vectors in bs into one big vector. + - Let alpha == the concatenation of all the vectors in alphas into one big vector. + - Let lower == the concatenation of all the vectors in lowers into one big vector. + - Let upper == the concatenation of all the vectors in uppers into one big vector. + - Then this function solves the following quadratic program: + Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + trans(b)*alpha + subject to the following box constraints on alpha: + - 0 <= min(alpha-lower) + - 0 <= max(upper-alpha) + Where f is convex. This means that Q should be positive-semidefinite. + - More specifically, this function is identical to + solve_qp_box_constrained(Q, b, alpha, lower, upper, eps, max_iter), + except that it runs faster since it avoids unnecessary computation by + taking advantage of the sparsity structure in the QP. + - The solution to the above QP will be stored in #alphas. + - This function uses a combination of a SMO algorithm along with Nesterov's + method as the main iteration of the solver. It starts the algorithm with the + given alpha and it works on the problem until the derivative of f(alpha) is + smaller than eps for each element of alpha or the alpha value is at a box + constraint. So eps controls how accurate the solution is and smaller values + result in better solutions. + - At most max_iter iterations of optimization will be performed. + - returns the number of iterations performed. If this method fails to + converge to eps accuracy then the number returned will be max_iter+1. + !*/ +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NRa, long NRb + > + unsigned long find_gap_between_convex_hulls ( + const matrix_exp& A, + const matrix_exp& B, + matrix& cA, + matrix& cB, + const double eps, + const unsigned long max_iter = 1000 + ); + /*! + requires + - A.nr() == B.nr() + - A.size() != 0 + - B.size() != 0 + - eps > 0 + - max_iter > 0 + ensures + - If you think of A and B as sets of column vectors, then we can identify the + convex sets hullA and hullB, which are the convex hulls of A and B + respectively. This function finds the pair of points in hullA and hullB that + are nearest to each other. To be precise, this function solves the following + quadratic program: + Minimize: f(cA,cB) == length_squared(A*cA - B*cB) + subject to the following constraints on cA and cB: + - is_col_vector(cA) == true && cA.size() == A.nc() + - is_col_vector(cB) == true && cB.size() == B.nc() + - sum(cA) == 1 && min(cA) >= 0 + - sum(cB) == 1 && min(cB) >= 0 + - This function uses an iterative block coordinate descent algorithm to solve + the QP. It runs until either max_iter iterations have been performed or the + QP is solved to at least eps accuracy. + - returns the number of iterations performed. If this method fails to + converge to eps accuracy then the number returned will be max_iter+1. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATION_SOLVE_QP_UsING_SMO_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/optimization/optimization_stop_strategies.h b/ml/dlib/dlib/optimization/optimization_stop_strategies.h new file mode 100644 index 000000000..a0243cacf --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_stop_strategies.h @@ -0,0 +1,173 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATIOn_STOP_STRATEGIES_H_ +#define DLIB_OPTIMIZATIOn_STOP_STRATEGIES_H_ + +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "optimization_stop_strategies_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class objective_delta_stop_strategy + { + public: + explicit objective_delta_stop_strategy ( + double min_delta = 1e-7 + ) : _verbose(false), _been_used(false), _min_delta(min_delta), _max_iter(0), _cur_iter(0), _prev_funct_value(0) + { + DLIB_ASSERT ( + min_delta >= 0, + "\t objective_delta_stop_strategy(min_delta)" + << "\n\t min_delta can't be negative" + << "\n\t min_delta: " << min_delta + ); + } + + objective_delta_stop_strategy ( + double min_delta, + unsigned long max_iter + ) : _verbose(false), _been_used(false), _min_delta(min_delta), _max_iter(max_iter), _cur_iter(0), _prev_funct_value(0) + { + DLIB_ASSERT ( + min_delta >= 0 && max_iter > 0, + "\t objective_delta_stop_strategy(min_delta, max_iter)" + << "\n\t min_delta can't be negative and max_iter can't be 0" + << "\n\t min_delta: " << min_delta + << "\n\t max_iter: " << max_iter + ); + } + + objective_delta_stop_strategy& be_verbose( + ) + { + _verbose = true; + return *this; + } + + template + bool should_continue_search ( + const T& , + const double funct_value, + const T& + ) + { + if (_verbose) + { + using namespace std; + cout << "iteration: " << _cur_iter << " objective: " << funct_value << endl; + } + + ++_cur_iter; + if (_been_used) + { + // Check if we have hit the max allowable number of iterations. (but only + // check if _max_iter is enabled (i.e. not 0)). + if (_max_iter != 0 && _cur_iter > _max_iter) + return false; + + // check if the function change was too small + if (std::abs(funct_value - _prev_funct_value) < _min_delta) + return false; + } + + _been_used = true; + _prev_funct_value = funct_value; + return true; + } + + private: + bool _verbose; + + bool _been_used; + double _min_delta; + unsigned long _max_iter; + unsigned long _cur_iter; + double _prev_funct_value; + }; + +// ---------------------------------------------------------------------------------------- + + class gradient_norm_stop_strategy + { + public: + explicit gradient_norm_stop_strategy ( + double min_norm = 1e-7 + ) : _verbose(false), _min_norm(min_norm), _max_iter(0), _cur_iter(0) + { + DLIB_ASSERT ( + min_norm >= 0, + "\t gradient_norm_stop_strategy(min_norm)" + << "\n\t min_norm can't be negative" + << "\n\t min_norm: " << min_norm + ); + } + + gradient_norm_stop_strategy ( + double min_norm, + unsigned long max_iter + ) : _verbose(false), _min_norm(min_norm), _max_iter(max_iter), _cur_iter(0) + { + DLIB_ASSERT ( + min_norm >= 0 && max_iter > 0, + "\t gradient_norm_stop_strategy(min_norm, max_iter)" + << "\n\t min_norm can't be negative and max_iter can't be 0" + << "\n\t min_norm: " << min_norm + << "\n\t max_iter: " << max_iter + ); + } + + gradient_norm_stop_strategy& be_verbose( + ) + { + _verbose = true; + return *this; + } + + template + bool should_continue_search ( + const T& , + const double funct_value, + const T& funct_derivative + ) + { + if (_verbose) + { + using namespace std; + cout << "iteration: " << _cur_iter << " objective: " << funct_value << " gradient norm: " << length(funct_derivative) << endl; + } + + ++_cur_iter; + + // Check if we have hit the max allowable number of iterations. (but only + // check if _max_iter is enabled (i.e. not 0)). + if (_max_iter != 0 && _cur_iter > _max_iter) + return false; + + // check if the gradient norm is too small + if (length(funct_derivative) < _min_norm) + return false; + + return true; + } + + private: + bool _verbose; + + double _min_norm; + unsigned long _max_iter; + unsigned long _cur_iter; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_STOP_STRATEGIES_H_ + diff --git a/ml/dlib/dlib/optimization/optimization_stop_strategies_abstract.h b/ml/dlib/dlib/optimization/optimization_stop_strategies_abstract.h new file mode 100644 index 000000000..6a999f8d9 --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_stop_strategies_abstract.h @@ -0,0 +1,157 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATIOn_STOP_STRATEGIES_ABSTRACT_ +#ifdef DLIB_OPTIMIZATIOn_STOP_STRATEGIES_ABSTRACT_ + +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class objective_delta_stop_strategy + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a strategy for deciding if an optimization + algorithm should terminate. This particular object looks at the + change in the objective function from one iteration to the next and + bases its decision on how large this change is. If the change + is below a user given threshold then the search stops. + !*/ + + public: + explicit objective_delta_stop_strategy ( + double min_delta = 1e-7 + ); + /*! + requires + - min_delta >= 0 + ensures + - This stop strategy object will only consider a search to be complete + if a change in an objective function from one iteration to the next + is less than min_delta. + !*/ + + objective_delta_stop_strategy ( + double min_delta, + unsigned long max_iter + ); + /*! + requires + - min_delta >= 0 + - max_iter > 0 + ensures + - This stop strategy object will only consider a search to be complete + if a change in an objective function from one iteration to the next + is less than min_delta or more than max_iter iterations has been + executed. + !*/ + + objective_delta_stop_strategy& be_verbose( + ); + /*! + ensures + - causes this object to print status messages to standard out + every time should_continue_search() is called. + - returns *this + !*/ + + template + bool should_continue_search ( + const T& x, + const double funct_value, + const T& funct_derivative + ); + /*! + requires + - this function is only called once per search iteration + - for some objective function f(): + - x == the search point for the current iteration + - funct_value == f(x) + - funct_derivative == derivative(f)(x) + ensures + - returns true if the point x doest not satisfy the stopping condition and + false otherwise. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + class gradient_norm_stop_strategy + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a strategy for deciding if an optimization + algorithm should terminate. This particular object looks at the + norm (i.e. the length) of the current gradient vector and stops + if it is smaller than a user given threshold. + !*/ + + public: + explicit gradient_norm_stop_strategy ( + double min_norm = 1e-7 + ); + /*! + requires + - min_norm >= 0 + ensures + - This stop strategy object will only consider a search to be complete + if the current gradient norm is less than min_norm + !*/ + + gradient_norm_stop_strategy ( + double min_norm, + unsigned long max_iter + ); + /*! + requires + - min_norm >= 0 + - max_iter > 0 + ensures + - This stop strategy object will only consider a search to be complete + if the current gradient norm is less than min_norm or more than + max_iter iterations has been executed. + !*/ + + gradient_norm_stop_strategy& be_verbose( + ); + /*! + ensures + - causes this object to print status messages to standard out + every time should_continue_search() is called. + - returns *this + !*/ + + template + bool should_continue_search ( + const T& x, + const double funct_value, + const T& funct_derivative + ); + /*! + requires + - this function is only called once per search iteration + - for some objective function f(): + - x == the search point for the current iteration + - funct_value == f(x) + - funct_derivative == derivative(f)(x) + ensures + - returns true if the point x doest not satisfy the stopping condition and + false otherwise. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATIOn_STOP_STRATEGIES_ABSTRACT_ + diff --git a/ml/dlib/dlib/optimization/optimization_trust_region.h b/ml/dlib/dlib/optimization/optimization_trust_region.h new file mode 100644 index 000000000..5f0ad897f --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_trust_region.h @@ -0,0 +1,564 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATION_TRUST_REGIoN_Hh_ +#define DLIB_OPTIMIZATION_TRUST_REGIoN_Hh_ + +#include "../matrix.h" +#include "optimization_trust_region_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NR, long NC, typename MM, typename L + > + unsigned long solve_trust_region_subproblem ( + const matrix_exp& B, + const matrix_exp& g, + const typename EXP1::type radius, + matrix& p, + double eps, + unsigned long max_iter + ) + { + /* + This is an implementation of algorithm 4.3(Trust Region Subproblem) + from the book Numerical Optimization by Nocedal and Wright. Some of + the details are also from Practical Methods of Optimization by Fletcher. + */ + + // make sure requires clause is not broken + DLIB_ASSERT(B.nr() == B.nc() && is_col_vector(g) && g.size() == B.nr(), + "\t unsigned long solve_trust_region_subproblem()" + << "\n\t invalid arguments were given to this function" + << "\n\t B.nr(): " << B.nr() + << "\n\t B.nc(): " << B.nc() + << "\n\t is_col_vector(g): " << is_col_vector(g) + << "\n\t g.size(): " << g.size() + ); + DLIB_ASSERT(radius > 0 && eps > 0 && max_iter > 0, + "\t unsigned long solve_trust_region_subproblem()" + << "\n\t invalid arguments were given to this function" + << "\n\t radius: " << radius + << "\n\t eps: " << eps + << "\n\t max_iter: " << max_iter + ); + + + const_temp_matrix BB(B); + const_temp_matrix gg(g); + + p.set_size(g.nr(),g.nc()); + p = 0; + + + const T numeric_eps = max(diag(abs(BB)))*std::numeric_limits::epsilon(); + + matrix R; + + T lambda = 0; + + // We need to put a bracket around lambda. It can't go below 0. We + // can get an upper bound using Gershgorin disks. + // This number is a lower bound on the eigenvalues in BB + const T BB_min_eigenvalue = min(diag(BB) - (sum_cols(abs(BB)) - abs(diag(BB)))); + + const T g_norm = length(gg); + + T lambda_min = 0; + T lambda_max = put_in_range(0, + std::numeric_limits::max(), + g_norm/radius - BB_min_eigenvalue); + + + // If we can tell that the minimum is at 0 then don't do anything. Just return the answer. + if (g_norm < numeric_eps && BB_min_eigenvalue > numeric_eps) + { + return 0; + } + + + // how much lambda has changed recently + T lambda_delta = 0; + + for (unsigned long i = 0; i < max_iter; ++i) + { + R = chol(BB + lambda*identity_matrix(BB.nr())); + + // if the cholesky decomposition doesn't exist. + if (R(R.nr()-1, R.nc()-1) <= 0) + { + // If B is indefinite and g is equal to 0 then we should + // quit this loop and go right to the eigenvalue decomposition method. + if (g_norm <= numeric_eps) + break; + + // narrow the bracket on lambda. Obviously the current lambda is + // too small. + lambda_min = lambda; + + // jump towards the max value. Eventually there will + // be a lambda that results in a cholesky decomposition. + const T alpha = 0.10; + lambda = (1-alpha)*lambda + alpha*lambda_max; + continue; + } + + using namespace blas_bindings; + + p = -gg; + // Solve RR'*p = -g for p. + // Solve R*q = -g for q where q = R'*p. + if (R.nr() == 2) + { + p(0) = p(0)/R(0,0); + p(1) = (p(1)-R(1,0)*p(0))/R(1,1); + } + else + { + triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, R, p); + } + const T q_norm = length(p); + + // Solve R'*p = q for p. + if (R.nr() == 2) + { + p(1) = p(1)/R(1,1); + p(0) = (p(0)-R(1,0)*p(1))/R(0,0); + } + else + { + triangular_solver(CblasLeft, CblasLower, CblasTrans, CblasNonUnit, R, p); + } + const T p_norm = length(p); + + // check if we are done. + if (lambda == 0) + { + if (p_norm < radius) + { + // i will always be 0 in this case. So we return 1. + return i+1; + } + } + else + { + // if we are close enough to the solution then terminate + if (std::abs(p_norm - radius)/radius < eps) + return i+1; + } + + // shrink our bracket on lambda + if (p_norm < radius) + lambda_max = lambda; + else + lambda_min = lambda; + + + if (p_norm <= radius*std::numeric_limits::epsilon()) + { + const T alpha = 0.01; + lambda = (1-alpha)*lambda_min + alpha*lambda_max; + continue; + } + + const T old_lambda = lambda; + + // figure out which lambda to try next + lambda = lambda + std::pow(q_norm/p_norm,2)*(p_norm - radius)/radius; + + // make sure the chosen lambda is within our bracket (but not exactly at either end). + const T gap = (lambda_max-lambda_min)*0.01; + lambda = put_in_range(lambda_min+gap, lambda_max-gap, lambda); + + // Keep track of how much lambda is thrashing around inside the search bracket. If it + // keeps moving around a whole lot then cut the search bracket in half. + lambda_delta += std::abs(lambda - old_lambda); + if (lambda_delta > 3*(lambda_max-lambda_min)) + { + lambda = (lambda_min+lambda_max)/2; + lambda_delta = 0; + } + } // end for loop + + + // We are probably in the "hard case". Use an eigenvalue decomposition to sort things out. + // Either that or the eps was just set too tight and really we are already done. + eigenvalue_decomposition ed(make_symmetric(BB)); + + matrix ev = ed.get_real_eigenvalues(); + const long min_eig_idx = index_of_min(ev); + + + ev -= min(ev); + // zero out any values which are basically zero + ev = pointwise_multiply(ev, ev > max(abs(ev))*std::numeric_limits::epsilon()); + ev = reciprocal(ev); + + + // figure out part of what p should be assuming we are in the hard case. + matrix p_hard; + p_hard = trans(ed.get_pseudo_v())*gg; + p_hard = diagm(ev)*p_hard; + p_hard = ed.get_pseudo_v()*p_hard; + + + // If we really are in the hard case then this if will be true. Otherwise, the p + // we found in the "easy case" loop up top is the best answer. + if (length(p_hard) < radius && length(p_hard) >= length(p)) + { + // adjust the length of p_hard by adding a component along the eigenvector associated with + // the smallest eigenvalue. We want to make it the case that length(p) == radius. + const T tau = std::sqrt(radius*radius - length_squared(p_hard)); + p = p_hard + tau*colm(ed.get_pseudo_v(),min_eig_idx); + + + // if we have to do an eigenvalue decomposition then say we did all the iterations + return max_iter; + } + + // if we get this far it means we didn't converge to eps accuracy. + return max_iter+1; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename EXP1, + typename EXP2, + typename EXP3 + > + bool bounds_violated ( + const matrix_exp& v, + const matrix_exp& l, + const matrix_exp& u + ) + { + DLIB_ASSERT(v.nr() == l.nr() && v.nr() == u.nr()); + DLIB_ASSERT(v.nc() == l.nc() && v.nc() == u.nc()); + for (long r = 0; r < v.nr(); ++r) + { + for (long c = 0; c < v.nc(); c++) + { + if (!(l(r,c) <= v(r,c) && v(r,c) <= u(r,c))) + return true; + } + } + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NR, long NC, typename MM, typename L, + typename EXP3 + > + void solve_trust_region_subproblem_bounded ( + const matrix_exp& B_, + const matrix_exp& g_, + const typename EXP1::type radius_, + matrix& p_, + double eps, + unsigned long max_iter, + const matrix_exp& lower_, + const matrix_exp& upper_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(B_.nr() == B_.nc() && is_col_vector(g_) && g_.size() == B_.nr(), + "\t unsigned long solve_trust_region_subproblem_bounded()" + << "\n\t invalid arguments were given to this function" + << "\n\t B_.nr(): " << B_.nr() + << "\n\t B_.nc(): " << B_.nc() + << "\n\t is_col_vector(g_): " << is_col_vector(g_) + << "\n\t g_.size(): " << g_.size() + ); + DLIB_ASSERT(radius_ > 0 && eps > 0 && max_iter > 0, + "\t unsigned long solve_trust_region_subproblem_bounded()" + << "\n\t invalid arguments were given to this function" + << "\n\t radius_: " << radius_ + << "\n\t eps: " << eps + << "\n\t max_iter: " << max_iter + ); + DLIB_ASSERT(is_col_vector(lower_) && lower_.size() == g_.size()); + DLIB_ASSERT(is_col_vector(upper_) && upper_.size() == g_.size()); + DLIB_ASSERT(max(upper_-lower_) >= 0); + // make sure the problem is feasible. That is, there should be a point inside the + // lower and upper bounds that has a norm <= radius_ + DLIB_ASSERT(length(clamp(zeros_matrix(lower_),lower_,upper_)) <= radius_, + "The lower and upper bounds are incompatible with the radius since there is no point within the bounds with a norm less than the radius."); + + // We are going to solve this by greedily finding the most violated bound constraint, + // locking that variable to its constrained value, removing it from the problem, + // and then resolving. We do that until no more constraint violations are present. + + solve_trust_region_subproblem(B_,g_,radius_,p_,eps,max_iter); + + + // just stop here if all the bounds are satisfied. + if (!impl::bounds_violated(p_, lower_, upper_)) + return; + + matrix B = matrix_cast(B_); + matrix g = matrix_cast(g_); + double radius = radius_; + matrix p = matrix_cast(p_); + matrix lower = matrix_cast(lower_); + matrix upper = matrix_cast(upper_); + + // keep a table that tells us how to map any reduced QP back to the original QP + std::vector idxs(g.size()); + for (size_t i = 0; i < idxs.size(); ++i) + idxs[i] = i; + + + // while we haven't found a p that satisfies the bounds constraints + while(impl::bounds_violated(p, lower, upper) ) + { + // Find the most violated variable and fix its value to a constant (the bound + // value). + long most_violated_idx = 0; + double max_violation = 0; + double bounded_value = 0; + for (long i = 0; i < lower.size(); ++i) + { + if (!(lower(i) <= p(i) && p(i) <= upper(i))) + { + if (lower(i)-p(i) > max_violation) + { + max_violation = lower(i)-p(i); + most_violated_idx = i; + bounded_value = lower(i); + } + else if (p(i)-upper(i) > max_violation) + { + max_violation = p(i)-upper(i); + most_violated_idx = i; + bounded_value = upper(i); + } + } + } + + // assign this variable to its final value. + p_(idxs[most_violated_idx]) = bounded_value; + + + // now reduce the QP by removing the variable p_(idxs[most_violated_idx]). + idxs.erase(idxs.begin()+most_violated_idx); + // we are out of variables to remove since everything is at bounds. + if (idxs.size() == 0) + break; + + lower = remove_row(lower,most_violated_idx); + upper = remove_row(upper,most_violated_idx); + g += colm(B,most_violated_idx)*bounded_value; + g = remove_row(g,most_violated_idx); + p = remove_row(p,most_violated_idx); + B = removerc(B,most_violated_idx, most_violated_idx); + + // Removing a variable changes the radius, so we have to subtract the bounded + // value from the radius so as to not change the effective radius for the whole + // problem. + double squared_radius = radius*radius - bounded_value*bounded_value; + if (squared_radius <= 0) + { + p = 0; + break; + } + radius = std::sqrt(squared_radius); + + + solve_trust_region_subproblem(B,g,radius,p,eps,max_iter); + } + + // assign the non-bound-constrained variables to their final values + for (size_t i = 0; i < idxs.size(); ++i) + p_(idxs[i]) = p(i); + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename stop_strategy_type, + typename funct_model + > + double find_min_trust_region ( + stop_strategy_type stop_strategy, + const funct_model& model, + typename funct_model::column_vector& x, + double radius = 1 + ) + { + /* + This is an implementation of algorithm 4.1(Trust Region) + from the book Numerical Optimization by Nocedal and Wright. + */ + + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(x) && radius > 0, + "\t double find_min_trust_region()" + << "\n\t invalid arguments were given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t radius: " << radius + ); + + const double initial_radius = radius; + + typedef typename funct_model::column_vector T; + typedef typename T::type type; + + typename funct_model::general_matrix h; + typename funct_model::column_vector g, p, d; + type f_value = model(x); + + model.get_derivative_and_hessian(x,g,h); + + DLIB_ASSERT(is_finite(x), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(h), "The objective function generated non-finite outputs"); + + // Sometimes the loop below won't modify x because the trust region step failed. + // This bool tells us when we are in that case. + bool stale_x = false; + + while(stale_x || stop_strategy.should_continue_search(x, f_value, g)) + { + const unsigned long iter = solve_trust_region_subproblem(h, + g, + radius, + p, + 0.1, + 20); + + + const type new_f_value = model(x+p); + const type predicted_improvement = -0.5*trans(p)*h*p - trans(g)*p; + const type measured_improvement = (f_value - new_f_value); + + // If the sub-problem can't find a way to improve then stop. This only happens when p is essentially 0. + if (std::abs(predicted_improvement) <= std::abs(measured_improvement)*std::numeric_limits::epsilon()) + break; + + // predicted_improvement shouldn't be negative but it might be if something went + // wrong in the trust region solver. So put abs() here to guard against that. This + // way the sign of rho is determined only by the sign of measured_improvement. + const type rho = measured_improvement/std::abs(predicted_improvement); + + + if (!is_finite(rho)) + break; + + if (rho < 0.25) + { + radius *= 0.25; + + // something has gone horribly wrong if the radius has shrunk to zero. So just + // give up if that happens. + if (radius <= initial_radius*std::numeric_limits::epsilon()) + break; + } + else + { + // if rho > 0.75 and we are being checked by the radius + if (rho > 0.75 && iter > 1) + { + radius = std::min(1000, 2*radius); + } + } + + if (rho > 0) + { + x = x + p; + f_value = new_f_value; + model.get_derivative_and_hessian(x,g,h); + stale_x = false; + } + else + { + stale_x = true; + } + + DLIB_ASSERT(is_finite(x), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(g), "The objective function generated non-finite outputs"); + DLIB_ASSERT(is_finite(h), "The objective function generated non-finite outputs"); + } + + return f_value; + } + +// ---------------------------------------------------------------------------------------- + + template + struct negate_tr_model + { + negate_tr_model( const funct_model& m) : model(m) {} + + const funct_model& model; + + typedef typename funct_model::column_vector column_vector; + typedef typename funct_model::general_matrix general_matrix; + + template + typename T::type operator() (const T& x) const + { + return -model(x); + } + + template + void get_derivative_and_hessian ( + const T& x, + T& d, + U& h + ) const + { + model.get_derivative_and_hessian(x,d,h); + d = -d; + h = -h; + } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename stop_strategy_type, + typename funct_model + > + double find_max_trust_region ( + stop_strategy_type stop_strategy, + const funct_model& model, + typename funct_model::column_vector& x, + double radius = 1 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(x) && radius > 0, + "\t double find_max_trust_region()" + << "\n\t invalid arguments were given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t radius: " << radius + ); + + return -find_min_trust_region(stop_strategy, + negate_tr_model(model), + x, + radius); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATION_TRUST_REGIoN_Hh_ + diff --git a/ml/dlib/dlib/optimization/optimization_trust_region_abstract.h b/ml/dlib/dlib/optimization/optimization_trust_region_abstract.h new file mode 100644 index 000000000..303ee746d --- /dev/null +++ b/ml/dlib/dlib/optimization/optimization_trust_region_abstract.h @@ -0,0 +1,233 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_OPTIMIZATION_TRUST_REGIoN_H_ABSTRACTh_ +#ifdef DLIB_OPTIMIZATION_TRUST_REGIoN_H_ABSTRACTh_ + +#include "../matrix/matrix_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NR, long NC, typename MM, typename L + > + unsigned long solve_trust_region_subproblem ( + const matrix_exp& B, + const matrix_exp& g, + const typename EXP1::type radius, + matrix& p, + double eps, + unsigned long max_iter + ); + /*! + requires + - B == trans(B) + (i.e. B should be a symmetric matrix) + - B.nr() == B.nc() + - is_col_vector(g) == true + - g.size() == B.nr() + - p is capable of containing a column vector the size of g + (i.e. p = g; should be a legal expression) + - radius > 0 + - eps > 0 + - max_iter > 0 + ensures + - This function solves the following optimization problem: + Minimize: f(p) == 0.5*trans(p)*B*p + trans(g)*p + subject to the following constraint: + - length(p) <= radius + - returns the number of iterations performed. If this method fails to converge + to eps accuracy then the number returned will be max_iter+1. + - if (this function didn't terminate due to hitting the max_iter iteration limit) then + - if this function returns 0 or 1 then we are not hitting the radius bound Otherwise, + the radius constraint is active and std::abs(length(#p)-radius)/radius <= eps. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP1, + typename EXP2, + typename T, long NR, long NC, typename MM, typename L, + typename EXP3 + > + void solve_trust_region_subproblem_bounded ( + const matrix_exp& B, + const matrix_exp& g, + const typename EXP1::type radius, + matrix& p, + double eps, + unsigned long max_iter, + const matrix_exp& lower, + const matrix_exp& upper + ); + /*! + requires + - B == trans(B) + (i.e. B should be a symmetric matrix) + - B.nr() == B.nc() + - is_col_vector(g) == true + - is_col_vector(lower) == true + - is_col_vector(upper) == true + - g.size() == B.nr() + - lower.size() == B.nr() + - upper.size() == B.nr() + - p is capable of containing a column vector the size of g + (i.e. p = g; should be a legal expression) + - radius > 0 + - eps > 0 + - max_iter > 0 + - min(upper-lower) >= 0 + - length(clamp(zeros_matrix(lower),lower,upper)) <= radius + (i.e. the lower and upper bounds can't exclude all points with the desired radius.) + ensures + - This function solves the following optimization problem: + Minimize: f(p) == 0.5*trans(p)*B*p + trans(g)*p + subject to the following constraints: + - length(p) <= radius + - lower(i) <= p(i) <= upper(i), for all i + - Solves the problem to eps accuracy. We do this by greedily finding the most + violated bound constraint, locking that variable to its constrained value, removing + it from the problem, and then resolving. We do that until no more constraint + violations are present. Each time we just call solve_trust_region_subproblem() + to get the solution and pass eps and max_iter directly to these calls to + solve_trust_region_subproblem(). + !*/ + +// ---------------------------------------------------------------------------------------- + + class function_model + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines the interface for a function model + used by the trust-region optimizers defined below. + + In particular, this object represents a function f() and + its associated derivative and hessian. + + !*/ + + public: + + // Define the type used to represent column vectors + typedef matrix column_vector; + // Define the type used to represent the hessian matrix + typedef matrix general_matrix; + + double operator() ( + const column_vector& x + ) const; + /*! + ensures + - returns f(x) + (i.e. evaluates this model at the given point and returns the value) + !*/ + + void get_derivative_and_hessian ( + const column_vector& x, + column_vector& d, + general_matrix& h + ) const; + /*! + ensures + - #d == the derivative of f() at x + - #h == the hessian matrix of f() at x + - is_col_vector(#d) == true + - #d.size() == x.size() + - #h.nr() == #h.nc() == x.size() + - #h == trans(#h) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename stop_strategy_type, + typename funct_model + > + double find_min_trust_region ( + stop_strategy_type stop_strategy, + const funct_model& model, + typename funct_model::column_vector& x, + double radius = 1 + ); + /*! + requires + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - is_col_vector(x) == true + - radius > 0 + - model must be an object with an interface as defined by the function_model + example object shown above. + ensures + - Performs an unconstrained minimization of the function defined by model + starting from the initial point x. This function uses a trust region + algorithm to perform the minimization. The radius parameter defines + the initial size of the trust region. + - The function is optimized until stop_strategy decides that an acceptable + point has been found or the trust region subproblem fails to make progress. + - #x == the value of x that was found to minimize model() + - returns model(#x). + - When this function makes calls to model.get_derivative_and_hessian() it always + does so by first calling model() and then calling model.get_derivative_and_hessian(). + That is, any call to model.get_derivative_and_hessian(val) will always be + preceded by a call to model(val) with the same value. This way you can reuse + any redundant computations performed by model() and model.get_derivative_and_hessian() + as appropriate. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename stop_strategy_type, + typename funct_model + > + double find_max_trust_region ( + stop_strategy_type stop_strategy, + const funct_model& model, + typename funct_model::column_vector& x, + double radius = 1 + ); + /*! + requires + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - is_col_vector(x) == true + - radius > 0 + - model must be an object with an interface as defined by the function_model + example object shown above. + ensures + - Performs an unconstrained maximization of the function defined by model + starting from the initial point x. This function uses a trust region + algorithm to perform the maximization. The radius parameter defines + the initial size of the trust region. + - The function is optimized until stop_strategy decides that an acceptable + point has been found or the trust region subproblem fails to make progress. + - #x == the value of x that was found to maximize model() + - returns model(#x). + - When this function makes calls to model.get_derivative_and_hessian() it always + does so by first calling model() and then calling model.get_derivative_and_hessian(). + That is, any call to model.get_derivative_and_hessian(val) will always be + preceded by a call to model(val) with the same value. This way you can reuse + any redundant computations performed by model() and model.get_derivative_and_hessian() + as appropriate. + - Note that this function solves the maximization problem by converting it + into a minimization problem. Therefore, the values of model() and its derivative + reported to the stopping strategy will be negated. That is, stop_strategy + will see -model() and -derivative. All this really means is that the status + messages from a stopping strategy in verbose mode will display a negated objective + value. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_OPTIMIZATION_TRUST_REGIoN_H_ABSTRACTh_ + + diff --git a/ml/dlib/dlib/ostream b/ml/dlib/dlib/ostream new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/ostream @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/pipe.h b/ml/dlib/dlib/pipe.h new file mode 100644 index 000000000..023b985bf --- /dev/null +++ b/ml/dlib/dlib/pipe.h @@ -0,0 +1,10 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PIPe_ +#define DLIB_PIPe_ + +#include "pipe/pipe_kernel_1.h" + + +#endif // DLIB_PIPe_ + diff --git a/ml/dlib/dlib/pipe/pipe_kernel_1.h b/ml/dlib/dlib/pipe/pipe_kernel_1.h new file mode 100644 index 000000000..543754121 --- /dev/null +++ b/ml/dlib/dlib/pipe/pipe_kernel_1.h @@ -0,0 +1,756 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PIPE_KERNEl_1_ +#define DLIB_PIPE_KERNEl_1_ + +#include "../algs.h" +#include "../threads.h" +#include "pipe_kernel_abstract.h" + +namespace dlib +{ + + template < + typename T + > + class pipe + { + /*! + INITIAL VALUE + - pipe_size == 0 + - pipe_max_size == defined by constructor + - enabled == true + - data == a pointer to an array of ((pipe_max_size>0)?pipe_max_size:1) T objects. + - dequeue_waiters == 0 + - enqueue_waiters == 0 + - first == 1 + - last == 1 + - unblock_sig_waiters == 0 + + CONVENTION + - size() == pipe_size + - max_size() == pipe_max_size + - is_enabled() == enabled + + - m == the mutex used to lock access to all the members of this class + + - dequeue_waiters == the number of threads blocked on calls to dequeue() + - enqueue_waiters == the number of threads blocked on calls to enqueue() and + wait_until_empty() + - unblock_sig_waiters == the number of threads blocked on calls to + wait_for_num_blocked_dequeues() and the destructor. (i.e. the number of + blocking calls to unblock_sig.wait()) + + - dequeue_sig == the signaler that threads blocked on calls to dequeue() wait on + - enqueue_sig == the signaler that threads blocked on calls to enqueue() + or wait_until_empty() wait on. + - unblock_sig == the signaler that is signaled when a thread stops blocking on a call + to enqueue() or dequeue(). It is also signaled when a dequeue that will probably + block is called. The destructor and wait_for_num_blocked_dequeues are the only + things that will wait on this signaler. + + - if (pipe_size > 0) then + - data[first] == the next item to dequeue + - data[last] == the item most recently added via enqueue, so the last to dequeue. + - else if (pipe_max_size == 0) + - if (first == 0 && last == 0) then + - data[0] == the next item to dequeue + - else if (first == 0 && last == 1) then + - data[0] has been taken out already by a dequeue + !*/ + + public: + // this is here for backwards compatibility with older versions of dlib. + typedef pipe kernel_1a; + + typedef T type; + + explicit pipe ( + size_t maximum_size + ); + + virtual ~pipe ( + ); + + void empty ( + ); + + void wait_until_empty ( + ) const; + + void wait_for_num_blocked_dequeues ( + unsigned long num + )const; + + void enable ( + ); + + void disable ( + ); + + bool is_enqueue_enabled ( + ) const; + + void disable_enqueue ( + ); + + void enable_enqueue ( + ); + + bool is_dequeue_enabled ( + ) const; + + void disable_dequeue ( + ); + + void enable_dequeue ( + ); + + bool is_enabled ( + ) const; + + size_t max_size ( + ) const; + + size_t size ( + ) const; + + bool enqueue ( + T& item + ); + + bool enqueue ( + T&& item + ) { return enqueue(item); } + + bool dequeue ( + T& item + ); + + bool enqueue_or_timeout ( + T& item, + unsigned long timeout + ); + + bool enqueue_or_timeout ( + T&& item, + unsigned long timeout + ) { return enqueue_or_timeout(item,timeout); } + + bool dequeue_or_timeout ( + T& item, + unsigned long timeout + ); + + private: + + size_t pipe_size; + const size_t pipe_max_size; + bool enabled; + + T* const data; + + size_t first; + size_t last; + + mutex m; + signaler dequeue_sig; + signaler enqueue_sig; + signaler unblock_sig; + + unsigned long dequeue_waiters; + mutable unsigned long enqueue_waiters; + mutable unsigned long unblock_sig_waiters; + bool enqueue_enabled; + bool dequeue_enabled; + + // restricted functions + pipe(const pipe&); // copy constructor + pipe& operator=(const pipe&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + pipe:: + pipe ( + size_t maximum_size + ) : + pipe_size(0), + pipe_max_size(maximum_size), + enabled(true), + data(new T[(maximum_size>0) ? maximum_size : 1]), + first(1), + last(1), + dequeue_sig(m), + enqueue_sig(m), + unblock_sig(m), + dequeue_waiters(0), + enqueue_waiters(0), + unblock_sig_waiters(0), + enqueue_enabled(true), + dequeue_enabled(true) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + pipe:: + ~pipe ( + ) + { + auto_mutex M(m); + ++unblock_sig_waiters; + + // first make sure no one is blocked on any calls to enqueue() or dequeue() + enabled = false; + dequeue_sig.broadcast(); + enqueue_sig.broadcast(); + unblock_sig.broadcast(); + + // wait for all threads to unblock + while (dequeue_waiters > 0 || enqueue_waiters > 0 || unblock_sig_waiters > 1) + unblock_sig.wait(); + + delete [] data; + --unblock_sig_waiters; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void pipe:: + empty ( + ) + { + auto_mutex M(m); + pipe_size = 0; + + // let any calls to enqueue() know that the pipe is now empty + if (enqueue_waiters > 0) + enqueue_sig.broadcast(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void pipe:: + wait_until_empty ( + ) const + { + auto_mutex M(m); + // this function is sort of like a call to enqueue so treat it like that + ++enqueue_waiters; + + while (pipe_size > 0 && enabled && dequeue_enabled ) + enqueue_sig.wait(); + + // let the destructor know we are ending if it is blocked waiting + if (enabled == false) + unblock_sig.broadcast(); + + --enqueue_waiters; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void pipe:: + enable ( + ) + { + auto_mutex M(m); + enabled = true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void pipe:: + disable ( + ) + { + auto_mutex M(m); + enabled = false; + dequeue_sig.broadcast(); + enqueue_sig.broadcast(); + unblock_sig.broadcast(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool pipe:: + is_enabled ( + ) const + { + auto_mutex M(m); + return enabled; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + size_t pipe:: + max_size ( + ) const + { + auto_mutex M(m); + return pipe_max_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + size_t pipe:: + size ( + ) const + { + auto_mutex M(m); + return pipe_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool pipe:: + enqueue ( + T& item + ) + { + auto_mutex M(m); + ++enqueue_waiters; + + // wait until there is room or we are disabled + while (pipe_size == pipe_max_size && enabled && enqueue_enabled && + !(pipe_max_size == 0 && first == 1) ) + enqueue_sig.wait(); + + if (enabled == false || enqueue_enabled == false) + { + --enqueue_waiters; + // let the destructor know we are unblocking + unblock_sig.broadcast(); + return false; + } + + // set the appropriate values for first and last + if (pipe_size == 0) + { + first = 0; + last = 0; + } + else + { + last = (last+1)%pipe_max_size; + } + + + exchange(item,data[last]); + + // wake up a call to dequeue() if there are any currently blocked + if (dequeue_waiters > 0) + dequeue_sig.signal(); + + if (pipe_max_size > 0) + { + ++pipe_size; + } + else + { + // wait for a dequeue to take the item out + while (last == 0 && enabled && enqueue_enabled) + enqueue_sig.wait(); + + if (last == 0 && (enabled == false || enqueue_enabled == false)) + { + last = 1; + first = 1; + + // no one dequeued this object to put it back into item + exchange(item,data[0]); + + --enqueue_waiters; + // let the destructor know we are unblocking + if (unblock_sig_waiters > 0) + unblock_sig.broadcast(); + return false; + } + + last = 1; + first = 1; + + // tell any waiting calls to enqueue() that one of them can proceed + if (enqueue_waiters > 1) + enqueue_sig.broadcast(); + + // let the destructor know we are unblocking + if (enabled == false && unblock_sig_waiters > 0) + unblock_sig.broadcast(); + } + + --enqueue_waiters; + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool pipe:: + dequeue ( + T& item + ) + { + auto_mutex M(m); + ++dequeue_waiters; + + if (pipe_size == 0) + { + // notify wait_for_num_blocked_dequeues() + if (unblock_sig_waiters > 0) + unblock_sig.broadcast(); + + // notify any blocked enqueue_or_timeout() calls + if (enqueue_waiters > 0) + enqueue_sig.broadcast(); + } + + // wait until there is something in the pipe or we are disabled + while (pipe_size == 0 && enabled && dequeue_enabled && + !(pipe_max_size == 0 && first == 0 && last == 0) ) + dequeue_sig.wait(); + + if (enabled == false || dequeue_enabled == false) + { + --dequeue_waiters; + // let the destructor know we are unblocking + unblock_sig.broadcast(); + return false; + } + + exchange(item,data[first]); + + if (pipe_max_size > 0) + { + // set the appropriate values for first + first = (first+1)%pipe_max_size; + + --pipe_size; + } + else + { + // let the enqueue waiting on us know that we took the + // item out already. + last = 1; + } + + // wake up a call to enqueue() if there are any currently blocked + if (enqueue_waiters > 0) + enqueue_sig.broadcast(); + + --dequeue_waiters; + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool pipe:: + enqueue_or_timeout ( + T& item, + unsigned long timeout + ) + { + auto_mutex M(m); + ++enqueue_waiters; + + // wait until there is room or we are disabled or + // we run out of time. + bool timed_out = false; + while (pipe_size == pipe_max_size && enabled && enqueue_enabled && + !(pipe_max_size == 0 && dequeue_waiters > 0 && first == 1) ) + { + if (timeout == 0 || enqueue_sig.wait_or_timeout(timeout) == false) + { + timed_out = true; + break; + } + } + + if (enabled == false || timed_out || enqueue_enabled == false) + { + --enqueue_waiters; + // let the destructor know we are unblocking + unblock_sig.broadcast(); + return false; + } + + // set the appropriate values for first and last + if (pipe_size == 0) + { + first = 0; + last = 0; + } + else + { + last = (last+1)%pipe_max_size; + } + + + exchange(item,data[last]); + + // wake up a call to dequeue() if there are any currently blocked + if (dequeue_waiters > 0) + dequeue_sig.signal(); + + if (pipe_max_size > 0) + { + ++pipe_size; + } + else + { + // wait for a dequeue to take the item out + while (last == 0 && enabled && enqueue_enabled) + enqueue_sig.wait(); + + if (last == 0 && (enabled == false || enqueue_enabled == false)) + { + last = 1; + first = 1; + + // no one dequeued this object to put it back into item + exchange(item,data[0]); + + --enqueue_waiters; + // let the destructor know we are unblocking + if (unblock_sig_waiters > 0) + unblock_sig.broadcast(); + return false; + } + + last = 1; + first = 1; + + // tell any waiting calls to enqueue() that one of them can proceed + if (enqueue_waiters > 1) + enqueue_sig.broadcast(); + + // let the destructor know we are unblocking + if (enabled == false && unblock_sig_waiters > 0) + unblock_sig.broadcast(); + } + + --enqueue_waiters; + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool pipe:: + dequeue_or_timeout ( + T& item, + unsigned long timeout + ) + { + auto_mutex M(m); + ++dequeue_waiters; + + if (pipe_size == 0) + { + // notify wait_for_num_blocked_dequeues() + if (unblock_sig_waiters > 0) + unblock_sig.broadcast(); + + // notify any blocked enqueue_or_timeout() calls + if (enqueue_waiters > 0) + enqueue_sig.broadcast(); + } + + bool timed_out = false; + // wait until there is something in the pipe or we are disabled or we timeout. + while (pipe_size == 0 && enabled && dequeue_enabled && + !(pipe_max_size == 0 && first == 0 && last == 0) ) + { + if (timeout == 0 || dequeue_sig.wait_or_timeout(timeout) == false) + { + timed_out = true; + break; + } + } + + if (enabled == false || timed_out || dequeue_enabled == false) + { + --dequeue_waiters; + // let the destructor know we are unblocking + unblock_sig.broadcast(); + return false; + } + + exchange(item,data[first]); + + if (pipe_max_size > 0) + { + // set the appropriate values for first + first = (first+1)%pipe_max_size; + + --pipe_size; + } + else + { + // let the enqueue waiting on us know that we took the + // item out already. + last = 1; + } + + // wake up a call to enqueue() if there are any currently blocked + if (enqueue_waiters > 0) + enqueue_sig.broadcast(); + + --dequeue_waiters; + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void pipe:: + wait_for_num_blocked_dequeues ( + unsigned long num + )const + { + auto_mutex M(m); + ++unblock_sig_waiters; + + while ( (dequeue_waiters < num || pipe_size != 0) && enabled && dequeue_enabled) + unblock_sig.wait(); + + // let the destructor know we are ending if it is blocked waiting + if (enabled == false) + unblock_sig.broadcast(); + + --unblock_sig_waiters; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool pipe:: + is_enqueue_enabled ( + ) const + { + auto_mutex M(m); + return enqueue_enabled; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void pipe:: + disable_enqueue ( + ) + { + auto_mutex M(m); + enqueue_enabled = false; + enqueue_sig.broadcast(); + } + + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void pipe:: + enable_enqueue ( + ) + { + auto_mutex M(m); + enqueue_enabled = true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool pipe:: + is_dequeue_enabled ( + ) const + { + auto_mutex M(m); + return dequeue_enabled; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void pipe:: + disable_dequeue ( + ) + { + auto_mutex M(m); + dequeue_enabled = false; + dequeue_sig.broadcast(); + } + + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void pipe:: + enable_dequeue ( + ) + { + auto_mutex M(m); + dequeue_enabled = true; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_PIPE_KERNEl_1_ + diff --git a/ml/dlib/dlib/pipe/pipe_kernel_abstract.h b/ml/dlib/dlib/pipe/pipe_kernel_abstract.h new file mode 100644 index 000000000..91b2205e7 --- /dev/null +++ b/ml/dlib/dlib/pipe/pipe_kernel_abstract.h @@ -0,0 +1,323 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_PIPE_KERNEl_ABSTRACT_ +#ifdef DLIB_PIPE_KERNEl_ABSTRACT_ + +#include "../threads.h" + +namespace dlib +{ + + template < + typename T + > + class pipe + { + /*! + REQUIREMENTS ON T + T must be swappable by a global swap() + T must have a default constructor + + INITIAL VALUE + size() == 0 + is_enabled() == true + is_enqueue_enabled() == true + is_dequeue_enabled() == true + + WHAT THIS OBJECT REPRESENTS + This is a first in first out queue with a fixed maximum size containing + items of type T. It is suitable for passing objects between threads. + + THREAD SAFETY + All methods of this class are thread safe. You may call them from any + thread and any number of threads my call them at once. + !*/ + + public: + + typedef T type; + + explicit pipe ( + size_t maximum_size + ); + /*! + ensures + - #*this is properly initialized + - #max_size() == maximum_size + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~pipe ( + ); + /*! + ensures + - any resources associated with *this have been released + - disables (i.e. sets is_enabled() == false) this object so that + all calls currently blocking on it will return immediately. + !*/ + + void enable ( + ); + /*! + ensures + - #is_enabled() == true + !*/ + + void disable ( + ); + /*! + ensures + - #is_enabled() == false + - causes all current and future calls to enqueue(), dequeue(), + enqueue_or_timeout() and dequeue_or_timeout() to not block but + to return false immediately until enable() is called. + - causes all current and future calls to wait_until_empty() and + wait_for_num_blocked_dequeues() to not block but return + immediately until enable() is called. + !*/ + + bool is_enabled ( + ) const; + /*! + ensures + - returns true if this pipe is currently enabled, false otherwise. + !*/ + + void empty ( + ); + /*! + ensures + - #size() == 0 + !*/ + + void wait_until_empty ( + ) const; + /*! + ensures + - blocks until one of the following is the case: + - size() == 0 + - is_enabled() == false + - is_dequeue_enabled() == false + !*/ + + void wait_for_num_blocked_dequeues ( + unsigned long num + ) const; + /*! + ensures + - blocks until one of the following is the case: + - size() == 0 and the number of threads blocked on calls + to dequeue() and dequeue_or_timeout() is greater than + or equal to num. + - is_enabled() == false + - is_dequeue_enabled() == false + !*/ + + bool is_enqueue_enabled ( + ) const; + /*! + ensures + - returns true if the enqueue() and enqueue_or_timeout() functions are + currently enabled, returns false otherwise. (note that the higher + level is_enabled() function can overrule this one. So if + is_enabled() == false then enqueue functions are still disabled even + if is_enqueue_enabled() returns true. But if is_enqueue_enabled() == false + then enqueue functions are always disabled no matter the state of + is_enabled()) + !*/ + + void disable_enqueue ( + ); + /*! + ensures + - #is_enqueue_enabled() == false + - causes all current and future calls to enqueue() and + enqueue_or_timeout() to not block but to return false + immediately until enable_enqueue() is called. + !*/ + + void enable_enqueue ( + ); + /*! + ensures + - #is_enqueue_enabled() == true + !*/ + + bool is_dequeue_enabled ( + ) const; + /*! + ensures + - returns true if the dequeue() and dequeue_or_timeout() functions are + currently enabled, returns false otherwise. (note that the higher + level is_enabled() function can overrule this one. So if + is_enabled() == false then dequeue functions are still disabled even + if is_dequeue_enabled() returns true. But if is_dequeue_enabled() == false + then dequeue functions are always disabled no matter the state of + is_enabled()) + !*/ + + void disable_dequeue ( + ); + /*! + ensures + - #is_dequeue_enabled() == false + - causes all current and future calls to dequeue() and + dequeue_or_timeout() to not block but to return false + immediately until enable_dequeue() is called. + !*/ + + void enable_dequeue ( + ); + /*! + ensures + - #is_dequeue_enabled() == true + !*/ + + size_t max_size ( + ) const; + /*! + ensures + - returns the maximum number of objects of type T that this + pipe can contain. + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns the number of objects of type T that this + object currently contains. + !*/ + + bool enqueue ( + T& item + ); + /*! + ensures + - if (size() == max_size()) then + - this call to enqueue() blocks until one of the following is the case: + - there is room in the pipe for another item + - max_size() == 0 and another thread is trying to dequeue from this + pipe and we can pass our item object directly to that thread. + - someone calls disable() + - someone calls disable_enqueue() + - else + - this call does not block. + - if (this call to enqueue() returns true) then + - #is_enabled() == true + - #is_enqueue_enabled() == true + - if (max_size() == 0) then + - using global swap, item was passed directly to a + thread attempting to dequeue from this pipe + - else + - using global swap, item was added into this pipe. + - #item is in an undefined but valid state for its type + - else + - item was NOT added into the pipe + - #item == item (i.e. the value of item is unchanged) + !*/ + + bool enqueue (T&& item) { return enqueue(item); } + /*! + enable enqueueing from rvalues + !*/ + + bool enqueue_or_timeout ( + T& item, + unsigned long timeout + ); + /*! + ensures + - if (size() == max_size() && timeout > 0) then + - this call to enqueue_or_timeout() blocks until one of the following is the case: + - there is room in the pipe to add another item + - max_size() == 0 and another thread is trying to dequeue from this pipe + and we can pass our item object directly to that thread. + - someone calls disable() + - someone calls disable_enqueue() + - timeout milliseconds passes + - else + - this call does not block. + - if (this call to enqueue() returns true) then + - #is_enabled() == true + - #is_enqueue_enabled() == true + - if (max_size() == 0) then + - using global swap, item was passed directly to a + thread attempting to dequeue from this pipe + - else + - using global swap, item was added into this pipe. + - #item is in an undefined but valid state for its type + - else + - item was NOT added into the pipe + - #item == item (i.e. the value of item is unchanged) + !*/ + + bool enqueue_or_timeout (T&& item, unsigned long timeout) { return enqueue_or_timeout(item,timeout); } + /*! + enable enqueueing from rvalues + !*/ + + bool dequeue ( + T& item + ); + /*! + ensures + - if (size() == 0) then + - this call to dequeue() blocks until one of the following is the case: + - there is something in the pipe we can dequeue + - max_size() == 0 and another thread is trying to enqueue an item + onto this pipe and we can receive our item directly from that thread. + - someone calls disable() + - someone calls disable_dequeue() + - else + - this call does not block. + - if (this call to dequeue() returns true) then + - #is_enabled() == true + - #is_dequeue_enabled() == true + - the oldest item that was enqueued into this pipe has been + swapped into #item. + - else + - nothing was dequeued from this pipe. + - #item == item (i.e. the value of item is unchanged) + !*/ + + bool dequeue_or_timeout ( + T& item, + unsigned long timeout + ); + /*! + ensures + - if (size() == 0 && timeout > 0) then + - this call to dequeue_or_timeout() blocks until one of the following is the case: + - there is something in the pipe we can dequeue + - max_size() == 0 and another thread is trying to enqueue an item onto this + pipe and we can receive our item directly from that thread. + - someone calls disable() + - someone calls disable_dequeue() + - timeout milliseconds passes + - else + - this call does not block. + - if (this call to dequeue_or_timeout() returns true) then + - #is_enabled() == true + - #is_dequeue_enabled() == true + - the oldest item that was enqueued into this pipe has been + swapped into #item. + - else + - nothing was dequeued from this pipe. + - #item == item (i.e. the value of item is unchanged) + !*/ + + private: + + // restricted functions + pipe(const pipe&); // copy constructor + pipe& operator=(const pipe&); // assignment operator + + }; + +} + +#endif // DLIB_PIPE_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/pixel.h b/ml/dlib/dlib/pixel.h new file mode 100644 index 000000000..50ead2c34 --- /dev/null +++ b/ml/dlib/dlib/pixel.h @@ -0,0 +1,1649 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PIXEl_ +#define DLIB_PIXEl_ + +#include +#include "serialize.h" +#include +#include "algs.h" +#include "uintn.h" +#include +#include +#include "enable_if.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*! + This file contains definitions of pixel objects and related classes and + functionality. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + struct pixel_traits; + /*! + WHAT THIS OBJECT REPRESENTS + As the name implies, this is a traits class for pixel types. + It defines the properties of a pixel. + + This traits class will define the following public static members: + - bool grayscale + - bool rgb + - bool rgb_alpha + - bool hsi + - bool lab + + - bool has_alpha + + - long num + + - basic_pixel_type + - basic_pixel_type min() + - basic_pixel_type max() + - is_unsigned + + The above public constants are subject to the following constraints: + - only one of grayscale, rgb, rgb_alpha, hsi or lab is true + - if (rgb == true) then + - The type T will be a struct with 3 public members of type + unsigned char named "red" "green" and "blue". + - This type of pixel represents the RGB color space. + - num == 3 + - has_alpha == false + - basic_pixel_type == unsigned char + - min() == 0 + - max() == 255 + - is_unsigned == true + - if (rgb_alpha == true) then + - The type T will be a struct with 4 public members of type + unsigned char named "red" "green" "blue" and "alpha". + - This type of pixel represents the RGB color space with + an alpha channel where an alpha of 0 represents a pixel + that is totally transparent and 255 represents a pixel + with maximum opacity. + - num == 4 + - has_alpha == true + - basic_pixel_type == unsigned char + - min() == 0 + - max() == 255 + - is_unsigned == true + - else if (hsi == true) then + - The type T will be a struct with 3 public members of type + unsigned char named "h" "s" and "i". + - This type of pixel represents the HSI color space. + - num == 3 + - has_alpha == false + - basic_pixel_type == unsigned char + - min() == 0 + - max() == 255 + - is_unsigned == true + - else if (lab == true) then + - The type T will be a struct with 3 public members of type + unsigned char named "l" "a" and "b". + - This type of pixel represents the Lab color space. + - num == 3 + - has_alpha == false + - basic_pixel_type == unsigned char + - min() == 0 + - max() == 255 + - is_unsigned == true + - else + - grayscale == true + - This type of pixel represents a grayscale color space. T + will be some kind of basic scalar type such as unsigned int. + - num == 1 + - has_alpha == false + - basic_pixel_type == T + - min() == the minimum obtainable value of objects of type T + - max() == the maximum obtainable value of objects of type T + - is_unsigned is true if min() == 0 and false otherwise + !*/ + +// ---------------------------------------------------------------------------------------- + + struct rgb_pixel + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple struct that represents an RGB colored graphical pixel. + !*/ + + rgb_pixel ( + ) {} + + rgb_pixel ( + unsigned char red_, + unsigned char green_, + unsigned char blue_ + ) : red(red_), green(green_), blue(blue_) {} + + unsigned char red; + unsigned char green; + unsigned char blue; + }; + +// ---------------------------------------------------------------------------------------- + + struct bgr_pixel + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple struct that represents an BGR colored graphical pixel. + (the reason it exists in addition to the rgb_pixel is so you can lay + it down on top of a memory region that organizes its color data in the + BGR format and still be able to read it) + !*/ + + bgr_pixel ( + ) {} + + bgr_pixel ( + unsigned char blue_, + unsigned char green_, + unsigned char red_ + ) : blue(blue_), green(green_), red(red_) {} + + unsigned char blue; + unsigned char green; + unsigned char red; + }; + +// ---------------------------------------------------------------------------------------- + + struct rgb_alpha_pixel + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple struct that represents an RGB colored graphical pixel + with an alpha channel. + !*/ + + rgb_alpha_pixel ( + ) {} + + rgb_alpha_pixel ( + unsigned char red_, + unsigned char green_, + unsigned char blue_, + unsigned char alpha_ + ) : red(red_), green(green_), blue(blue_), alpha(alpha_) {} + + unsigned char red; + unsigned char green; + unsigned char blue; + unsigned char alpha; + }; + +// ---------------------------------------------------------------------------------------- + + struct hsi_pixel + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple struct that represents an HSI colored graphical pixel. + !*/ + + hsi_pixel ( + ) {} + + hsi_pixel ( + unsigned char h_, + unsigned char s_, + unsigned char i_ + ) : h(h_), s(s_), i(i_) {} + + unsigned char h; + unsigned char s; + unsigned char i; + }; + // ---------------------------------------------------------------------------------------- + + struct lab_pixel + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple struct that represents an Lab colored graphical pixel. + !*/ + + lab_pixel ( + ) {} + + lab_pixel ( + unsigned char l_, + unsigned char a_, + unsigned char b_ + ) : l(l_), a(a_), b(b_) {} + + unsigned char l; + unsigned char a; + unsigned char b; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename P1, + typename P2 + > + inline void assign_pixel ( + P1& dest, + const P2& src + ); + /*! + requires + - pixel_traits must be defined + - pixel_traits must be defined + ensures + - if (P1 and P2 are the same type of pixel) then + - simply copies the value of src into dest. In other words, + dest will be identical to src after this function returns. + - else if (P1 and P2 are not the same type of pixel) then + - assigns pixel src to pixel dest and does any necessary color space + conversions. + - When converting from a grayscale color space with more than 255 values the + pixel intensity is saturated at pixel_traits::max() or pixel_traits::min() + as appropriate. + - if (the dest pixel has an alpha channel and the src pixel doesn't) then + - #dest.alpha == 255 + - else if (the src pixel has an alpha channel but the dest pixel doesn't) then + - #dest == the original dest value blended with the src value according + to the alpha channel in src. + (i.e. #dest == src*(alpha/255) + dest*(1-alpha/255)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename P + > + inline typename pixel_traits

    ::basic_pixel_type get_pixel_intensity ( + const P& src + ); + /*! + requires + - pixel_traits

    must be defined + ensures + - if (pixel_traits

    ::grayscale == true) then + - returns src + - else + - converts src to grayscale and returns the resulting value. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename P, + typename T + > + inline void assign_pixel_intensity ( + P& dest, + const T& new_intensity + ); + /*! + requires + - pixel_traits

    must be defined + - pixel_traits must be defined + ensures + - This function changes the intensity of the dest pixel. So if the pixel in + question is a grayscale pixel then it simply assigns that pixel with the + value of get_pixel_intensity(new_intensity). However, if the pixel is not + a grayscale pixel then it converts the pixel to the HSI color space and sets + the I channel to the given intensity and then converts this HSI value back to + the original pixel's color space. + - Note that we don't necessarily have #get_pixel_intensity(dest) == get_pixel_intensity(new_intensity) + due to vagaries of how converting to and from HSI works out. + - if (the dest pixel has an alpha channel) then + - #dest.alpha == dest.alpha + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const rgb_pixel& item, + std::ostream& out + ); + /*! + provides serialization support for the rgb_pixel struct + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + rgb_pixel& item, + std::istream& in + ); + /*! + provides deserialization support for the rgb_pixel struct + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const bgr_pixel& item, + std::ostream& out + ); + /*! + provides serialization support for the bgr_pixel struct + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + bgr_pixel& item, + std::istream& in + ); + /*! + provides deserialization support for the bgr_pixel struct + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const rgb_alpha_pixel& item, + std::ostream& out + ); + /*! + provides serialization support for the rgb_alpha_pixel struct + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + rgb_alpha_pixel& item, + std::istream& in + ); + /*! + provides deserialization support for the rgb_alpha_pixel struct + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const hsi_pixel& item, + std::ostream& out + ); + /*! + provides serialization support for the hsi_pixel struct + !*/ + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const lab_pixel& item, + std::ostream& out + ); + /*! + provides serialization support for the lab_pixel struct + !*/ + + +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + hsi_pixel& item, + std::istream& in + ); + /*! + provides deserialization support for the hsi_pixel struct + !*/ +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + lab_pixel& item, + std::istream& in + ); + /*! + provides deserialization support for the lab_pixel struct + !*/ + +// ---------------------------------------------------------------------------------------- + + template <> + struct pixel_traits + { + constexpr static bool rgb = true; + constexpr static bool rgb_alpha = false; + constexpr static bool grayscale = false; + constexpr static bool hsi = false; + constexpr static bool lab = false; + enum { num = 3}; + typedef unsigned char basic_pixel_type; + static basic_pixel_type min() { return 0;} + static basic_pixel_type max() { return 255;} + constexpr static bool is_unsigned = true; + constexpr static bool has_alpha = false; + }; + +// ---------------------------------------------------------------------------------------- + + template <> + struct pixel_traits + { + constexpr static bool rgb = true; + constexpr static bool rgb_alpha = false; + constexpr static bool grayscale = false; + constexpr static bool hsi = false; + constexpr static bool lab = false; + constexpr static long num = 3; + typedef unsigned char basic_pixel_type; + static basic_pixel_type min() { return 0;} + static basic_pixel_type max() { return 255;} + constexpr static bool is_unsigned = true; + constexpr static bool has_alpha = false; + }; + +// ---------------------------------------------------------------------------------------- + + template <> + struct pixel_traits + { + constexpr static bool rgb = false; + constexpr static bool rgb_alpha = true; + constexpr static bool grayscale = false; + constexpr static bool hsi = false; + constexpr static bool lab = false; + constexpr static long num = 4; + typedef unsigned char basic_pixel_type; + static basic_pixel_type min() { return 0;} + static basic_pixel_type max() { return 255;} + constexpr static bool is_unsigned = true; + constexpr static bool has_alpha = true; + }; + +// ---------------------------------------------------------------------------------------- + + + template <> + struct pixel_traits + { + constexpr static bool rgb = false; + constexpr static bool rgb_alpha = false; + constexpr static bool grayscale = false; + constexpr static bool hsi = true; + constexpr static bool lab = false; + constexpr static long num = 3; + typedef unsigned char basic_pixel_type; + static basic_pixel_type min() { return 0;} + static basic_pixel_type max() { return 255;} + constexpr static bool is_unsigned = true; + constexpr static bool has_alpha = false; + }; + +// ---------------------------------------------------------------------------------------- + + + template <> + struct pixel_traits + { + constexpr static bool rgb = false; + constexpr static bool rgb_alpha = false; + constexpr static bool grayscale = false; + constexpr static bool hsi = false; + constexpr static bool lab = true; + constexpr static long num = 3; + typedef unsigned char basic_pixel_type; + static basic_pixel_type min() { return 0;} + static basic_pixel_type max() { return 255;} + constexpr static bool is_unsigned = true; + constexpr static bool has_alpha = false; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct grayscale_pixel_traits + { + constexpr static bool rgb = false; + constexpr static bool rgb_alpha = false; + constexpr static bool grayscale = true; + constexpr static bool hsi = false; + constexpr static bool lab = false; + constexpr static long num = 1; + constexpr static bool has_alpha = false; + typedef T basic_pixel_type; + static basic_pixel_type min() { return std::numeric_limits::min();} + static basic_pixel_type max() { return std::numeric_limits::max();} + constexpr static bool is_unsigned = is_unsigned_type::value; + }; + + template <> struct pixel_traits : public grayscale_pixel_traits {}; + template <> struct pixel_traits : public grayscale_pixel_traits {}; + template <> struct pixel_traits : public grayscale_pixel_traits {}; + template <> struct pixel_traits : public grayscale_pixel_traits {}; + + template <> struct pixel_traits : public grayscale_pixel_traits {}; + template <> struct pixel_traits : public grayscale_pixel_traits {}; + template <> struct pixel_traits : public grayscale_pixel_traits {}; + template <> struct pixel_traits : public grayscale_pixel_traits {}; + template <> struct pixel_traits : public grayscale_pixel_traits {}; + + template <> struct pixel_traits : public grayscale_pixel_traits {}; + template <> struct pixel_traits : public grayscale_pixel_traits {}; + +// ---------------------------------------------------------------------------------------- + + template + struct float_grayscale_pixel_traits + { + constexpr static bool rgb = false; + constexpr static bool rgb_alpha = false; + constexpr static bool grayscale = true; + constexpr static bool hsi = false; + constexpr static bool lab = false; + constexpr static long num = 1; + constexpr static bool has_alpha = false; + typedef T basic_pixel_type; + static basic_pixel_type min() { return -std::numeric_limits::max();} + static basic_pixel_type max() { return std::numeric_limits::max();} + constexpr static bool is_unsigned = false; + }; + + template <> struct pixel_traits : public float_grayscale_pixel_traits {}; + template <> struct pixel_traits : public float_grayscale_pixel_traits {}; + template <> struct pixel_traits : public float_grayscale_pixel_traits {}; + + // These are here mainly so you can easily copy images into complex arrays. This is + // useful when you want to do a FFT on an image or some similar operation. + template <> struct pixel_traits > : public float_grayscale_pixel_traits {}; + template <> struct pixel_traits > : public float_grayscale_pixel_traits {}; + template <> struct pixel_traits > : public float_grayscale_pixel_traits {}; + +// ---------------------------------------------------------------------------------------- + + // The following is a bunch of conversion stuff for the assign_pixel function. + + namespace assign_pixel_helpers + { + + // ----------------------------- + // all the same kind + + template < typename P > + typename enable_if_c::grayscale>::type + assign(P& dest, const P& src) + { + dest = src; + } + + // ----------------------------- + + template + typename unsigned_type::type make_unsigned ( + const T& val + ) { return static_cast::type>(val); } + + inline float make_unsigned(const float& val) { return val; } + inline double make_unsigned(const double& val) { return val; } + inline long double make_unsigned(const long double& val) { return val; } + + + template + typename enable_if_c::is_unsigned == pixel_traits

    ::is_unsigned, bool>::type less_or_equal_to_max ( + const P& p + ) + /*! + ensures + - returns true if p is <= max value of T + !*/ + { + return p <= pixel_traits::max(); + } + + template + typename enable_if_c::is_unsigned && !pixel_traits

    ::is_unsigned, bool>::type less_or_equal_to_max ( + const P& p + ) + { + if (p <= 0) + return true; + else if (make_unsigned(p) <= pixel_traits::max()) + return true; + else + return false; + } + + template + typename enable_if_c::is_unsigned && pixel_traits

    ::is_unsigned, bool>::type less_or_equal_to_max ( + const P& p + ) + { + return p <= make_unsigned(pixel_traits::max()); + } + + // ----------------------------- + + template + typename enable_if_c::is_unsigned, bool >::type greater_or_equal_to_min ( + const P& + ) { return true; } + /*! + ensures + - returns true if p is >= min value of T + !*/ + + template + typename enable_if_c::is_unsigned && pixel_traits::is_unsigned, bool >::type greater_or_equal_to_min ( + const P& p + ) + { + return p >= 0; + } + + template + typename enable_if_c::is_unsigned && !pixel_traits::is_unsigned, bool >::type greater_or_equal_to_min ( + const P& p + ) + { + return p >= pixel_traits::min(); + } + // ----------------------------- + + template < typename P1, typename P2 > + typename enable_if_c::grayscale && pixel_traits::grayscale>::type + assign(P1& dest, const P2& src) + { + /* + The reason for these weird comparison functions is to avoid getting compiler + warnings about comparing signed types to unsigned and stuff like that. + */ + + if (less_or_equal_to_max(src)) + if (greater_or_equal_to_min(src)) + dest = static_cast(src); + else + dest = pixel_traits::min(); + else + dest = pixel_traits::max(); + } + + // ----------------------------- + // ----------------------------- + // ----------------------------- + + template < typename P1, typename P2 > + typename enable_if_c::rgb && pixel_traits::rgb>::type + assign(P1& dest, const P2& src) + { + dest.red = src.red; + dest.green = src.green; + dest.blue = src.blue; + } + + template < typename P1, typename P2 > + typename enable_if_c::rgb_alpha && pixel_traits::rgb_alpha>::type + assign(P1& dest, const P2& src) + { + dest.red = src.red; + dest.green = src.green; + dest.blue = src.blue; + dest.alpha = src.alpha; + } + + template < typename P1, typename P2 > + typename enable_if_c::hsi && pixel_traits::hsi>::type + assign(P1& dest, const P2& src) + { + dest.h = src.h; + dest.s = src.s; + dest.i = src.i; + } + + template < typename P1, typename P2 > + typename enable_if_c::lab && pixel_traits::lab>::type + assign(P1& dest, const P2& src) + { + dest.l = src.l; + dest.a = src.a; + dest.b = src.b; + } + + // ----------------------------- + // dest is a grayscale + + template < typename P1, typename P2 > + typename enable_if_c::grayscale && pixel_traits::rgb>::type + assign(P1& dest, const P2& src) + { + const unsigned int temp = ((static_cast(src.red) + + static_cast(src.green) + + static_cast(src.blue))/3); + assign_pixel(dest, temp); + } + + template < typename P1, typename P2 > + typename enable_if_c::grayscale && pixel_traits::rgb_alpha>::type + assign(P1& dest, const P2& src) + { + + const unsigned char avg = static_cast((static_cast(src.red) + + static_cast(src.green) + + static_cast(src.blue))/3); + + if (src.alpha == 255) + { + assign_pixel(dest, avg); + } + else + { + // perform this assignment using fixed point arithmetic: + // dest = src*(alpha/255) + dest*(1 - alpha/255); + // dest = src*(alpha/255) + dest*1 - dest*(alpha/255); + // dest = dest*1 + src*(alpha/255) - dest*(alpha/255); + // dest = dest*1 + (src - dest)*(alpha/255); + // dest += (src - dest)*(alpha/255); + + int temp = avg; + // copy dest into dest_copy using assign_pixel to avoid potential + // warnings about implicit float to int warnings. + int dest_copy; + assign_pixel(dest_copy, dest); + + temp -= dest_copy; + + temp *= src.alpha; + + temp /= 255; + + assign_pixel(dest, temp+dest_copy); + } + } + + template < typename P1, typename P2 > + typename enable_if_c::grayscale && pixel_traits::hsi>::type + assign(P1& dest, const P2& src) + { + assign_pixel(dest, src.i); + } + + template < typename P1, typename P2 > + typename enable_if_c::grayscale && pixel_traits::lab>::type + assign(P1& dest, const P2& src) + { + assign_pixel(dest, src.l); + } + + + // ----------------------------- + + struct HSL + { + double h; + double s; + double l; + }; + + struct COLOUR + { + double r; + double g; + double b; + }; + + /* + I found this excellent bit of code for dealing with HSL spaces at + http://local.wasp.uwa.edu.au/~pbourke/colour/hsl/ + */ + /* + Calculate HSL from RGB + Hue is in degrees + Lightness is between 0 and 1 + Saturation is between 0 and 1 + */ + inline HSL RGB2HSL(COLOUR c1) + { + double themin,themax,delta; + HSL c2; + using namespace std; + + themin = std::min(c1.r,std::min(c1.g,c1.b)); + themax = std::max(c1.r,std::max(c1.g,c1.b)); + delta = themax - themin; + c2.l = (themin + themax) / 2; + c2.s = 0; + if (c2.l > 0 && c2.l < 1) + c2.s = delta / (c2.l < 0.5 ? (2*c2.l) : (2-2*c2.l)); + c2.h = 0; + if (delta > 0) { + if (themax == c1.r && themax != c1.g) + c2.h += (c1.g - c1.b) / delta; + if (themax == c1.g && themax != c1.b) + c2.h += (2 + (c1.b - c1.r) / delta); + if (themax == c1.b && themax != c1.r) + c2.h += (4 + (c1.r - c1.g) / delta); + c2.h *= 60; + } + return(c2); + } + + /* + Calculate RGB from HSL, reverse of RGB2HSL() + Hue is in degrees + Lightness is between 0 and 1 + Saturation is between 0 and 1 + */ + inline COLOUR HSL2RGB(HSL c1) + { + COLOUR c2,sat,ctmp; + using namespace std; + + if (c1.h < 120) { + sat.r = (120 - c1.h) / 60.0; + sat.g = c1.h / 60.0; + sat.b = 0; + } else if (c1.h < 240) { + sat.r = 0; + sat.g = (240 - c1.h) / 60.0; + sat.b = (c1.h - 120) / 60.0; + } else { + sat.r = (c1.h - 240) / 60.0; + sat.g = 0; + sat.b = (360 - c1.h) / 60.0; + } + sat.r = std::min(sat.r,1.0); + sat.g = std::min(sat.g,1.0); + sat.b = std::min(sat.b,1.0); + + ctmp.r = 2 * c1.s * sat.r + (1 - c1.s); + ctmp.g = 2 * c1.s * sat.g + (1 - c1.s); + ctmp.b = 2 * c1.s * sat.b + (1 - c1.s); + + if (c1.l < 0.5) { + c2.r = c1.l * ctmp.r; + c2.g = c1.l * ctmp.g; + c2.b = c1.l * ctmp.b; + } else { + c2.r = (1 - c1.l) * ctmp.r + 2 * c1.l - 1; + c2.g = (1 - c1.l) * ctmp.g + 2 * c1.l - 1; + c2.b = (1 - c1.l) * ctmp.b + 2 * c1.l - 1; + } + + return(c2); + } + + // ----------------------------- + + struct Lab + { + double l; + double a; + double b; + }; + /* + Calculate Lab from RGB + L is between 0 and 100 + a is between -128 and 127 + b is between -128 and 127 + RGB is between 0.0 and 1.0 + */ + inline Lab RGB2Lab(COLOUR c1) + { + Lab c2; + using namespace std; + + double var_R = c1.r; + double var_G = c1.g; + double var_B = c1.b; + + if (var_R > 0.04045) { + var_R = pow(((var_R + 0.055) / 1.055), 2.4); + } else { + var_R = var_R / 12.92; + } + + if (var_G > 0.04045) { + var_G = pow(((var_G + 0.055) / 1.055), 2.4); + } else { + var_G = var_G / 12.92; + } + + if (var_B > 0.04045) { + var_B = pow(((var_B + 0.055) / 1.055), 2.4); + } else { + var_B = var_B / 12.92; + } + + var_R = var_R * 100; + var_G = var_G * 100; + var_B = var_B * 100; + +//Observer. = 2°, Illuminant = D65 + double X = var_R * 0.4124 + var_G * 0.3576 + var_B * 0.1805; + double Y = var_R * 0.2126 + var_G * 0.7152 + var_B * 0.0722; + double Z = var_R * 0.0193 + var_G * 0.1192 + var_B * 0.9505; + + double var_X = X / 95.047; + double var_Y = Y / 100.000; + double var_Z = Z / 108.883; + + if (var_X > 0.008856) { + var_X = pow(var_X, (1.0 / 3)); + } + else { + var_X = (7.787 * var_X) + (16.0 / 116); + } + + if (var_Y > 0.008856) { + var_Y = pow(var_Y, (1.0 / 3)); + } + else { + var_Y = (7.787 * var_Y) + (16.0 / 116); + } + + if (var_Z > 0.008856) { + var_Z = pow(var_Z, (1.0 / 3)); + } + else { + var_Z = (7.787 * var_Z) + (16.0 / 116); + } + + //clamping + c2.l = max(0.0, (116.0 * var_Y) - 16); + c2.a = max(-128.0, min(127.0, 500.0 * (var_X - var_Y))); + c2.b = max(-128.0, min(127.0, 200.0 * (var_Y - var_Z))); + + return c2; + } + + /* + Calculate RGB from Lab, reverse of RGB2LAb() + L is between 0 and 100 + a is between -128 and 127 + b is between -128 and 127 + RGB is between 0.0 and 1.0 + */ + inline COLOUR Lab2RGB(Lab c1) { + COLOUR c2; + using namespace std; + + double var_Y = (c1.l + 16) / 116.0; + double var_X = (c1.a / 500.0) + var_Y; + double var_Z = var_Y - (c1.b / 200); + + if (pow(var_Y, 3) > 0.008856) { + var_Y = pow(var_Y, 3); + } else { + var_Y = (var_Y - 16.0 / 116) / 7.787; + } + + if (pow(var_X, 3) > 0.008856) { + var_X = pow(var_X, 3); + } else { + var_X = (var_X - 16.0 / 116) / 7.787; + } + + if (pow(var_Z, 3) > 0.008856) { + var_Z = pow(var_Z, 3); + } else { + var_Z = (var_Z - 16.0 / 116) / 7.787; + } + + double X = var_X * 95.047; + double Y = var_Y * 100.000; + double Z = var_Z * 108.883; + + var_X = X / 100.0; + var_Y = Y / 100.0; + var_Z = Z / 100.0; + + double var_R = var_X * 3.2406 + var_Y * -1.5372 + var_Z * -0.4986; + double var_G = var_X * -0.9689 + var_Y * 1.8758 + var_Z * 0.0415; + double var_B = var_X * 0.0557 + var_Y * -0.2040 + var_Z * 1.0570; + + if (var_R > 0.0031308) { + var_R = 1.055 * pow(var_R, (1 / 2.4)) - 0.055; + } else { + var_R = 12.92 * var_R; + } + + if (var_G > 0.0031308) { + var_G = 1.055 * pow(var_G, (1 / 2.4)) - 0.055; + } else { + var_G = 12.92 * var_G; + } + + if (var_B > 0.0031308) { + var_B = 1.055 * pow(var_B, (1 / 2.4)) - 0.055; + } else { + var_B = 12.92 * var_B; + } + + // clamping + c2.r = max(0.0, min(1.0, var_R)); + c2.g = max(0.0, min(1.0, var_G)); + c2.b = max(0.0, min(1.0, var_B)); + + return (c2); + } + + + // ----------------------------- + // dest is a color rgb_pixel + + template < typename P1 > + typename enable_if_c::rgb>::type + assign(P1& dest, const unsigned char& src) + { + dest.red = src; + dest.green = src; + dest.blue = src; + } + + template < typename P1, typename P2 > + typename enable_if_c::rgb && pixel_traits::grayscale>::type + assign(P1& dest, const P2& src) + { + unsigned char p; + assign_pixel(p, src); + dest.red = p; + dest.green = p; + dest.blue = p; + } + + template < typename P1, typename P2 > + typename enable_if_c::rgb && pixel_traits::rgb_alpha>::type + assign(P1& dest, const P2& src) + { + if (src.alpha == 255) + { + dest.red = src.red; + dest.green = src.green; + dest.blue = src.blue; + } + else + { + // perform this assignment using fixed point arithmetic: + // dest = src*(alpha/255) + dest*(1 - alpha/255); + // dest = src*(alpha/255) + dest*1 - dest*(alpha/255); + // dest = dest*1 + src*(alpha/255) - dest*(alpha/255); + // dest = dest*1 + (src - dest)*(alpha/255); + // dest += (src - dest)*(alpha/255); + + unsigned int temp_r = src.red; + unsigned int temp_g = src.green; + unsigned int temp_b = src.blue; + + temp_r -= dest.red; + temp_g -= dest.green; + temp_b -= dest.blue; + + temp_r *= src.alpha; + temp_g *= src.alpha; + temp_b *= src.alpha; + + temp_r >>= 8; + temp_g >>= 8; + temp_b >>= 8; + + dest.red += static_cast(temp_r&0xFF); + dest.green += static_cast(temp_g&0xFF); + dest.blue += static_cast(temp_b&0xFF); + } + } + + template < typename P1, typename P2 > + typename enable_if_c::rgb && pixel_traits::hsi>::type + assign(P1& dest, const P2& src) + { + COLOUR c; + HSL h; + h.h = src.h; + h.h = h.h/255.0*360; + h.s = src.s/255.0; + h.l = src.i/255.0; + c = HSL2RGB(h); + + dest.red = static_cast(c.r*255.0 + 0.5); + dest.green = static_cast(c.g*255.0 + 0.5); + dest.blue = static_cast(c.b*255.0 + 0.5); + } + + template < typename P1, typename P2 > + typename enable_if_c::rgb && pixel_traits::lab>::type + assign(P1& dest, const P2& src) + { + COLOUR c; + Lab l; + l.l = (src.l/255.0)*100; + l.a = (src.a-128.0); + l.b = (src.b-128.0); + c = Lab2RGB(l); + + dest.red = static_cast(c.r*255.0 + 0.5); + dest.green = static_cast(c.g*255.0 + 0.5); + dest.blue = static_cast(c.b*255.0 + 0.5); + } + + + // ----------------------------- + // dest is a color rgb_alpha_pixel + + template < typename P1 > + typename enable_if_c::rgb_alpha>::type + assign(P1& dest, const unsigned char& src) + { + dest.red = src; + dest.green = src; + dest.blue = src; + dest.alpha = 255; + } + + + template < typename P1, typename P2 > + typename enable_if_c::rgb_alpha && pixel_traits::grayscale>::type + assign(P1& dest, const P2& src) + { + unsigned char p; + assign_pixel(p, src); + + dest.red = p; + dest.green = p; + dest.blue = p; + dest.alpha = 255; + } + + template < typename P1, typename P2 > + typename enable_if_c::rgb_alpha && pixel_traits::rgb>::type + assign(P1& dest, const P2& src) + { + dest.red = src.red; + dest.green = src.green; + dest.blue = src.blue; + dest.alpha = 255; + } + + template < typename P1, typename P2 > + typename enable_if_c::rgb_alpha && pixel_traits::hsi>::type + assign(P1& dest, const P2& src) + { + COLOUR c; + HSL h; + h.h = src.h; + h.h = h.h/255.0*360; + h.s = src.s/255.0; + h.l = src.i/255.0; + c = HSL2RGB(h); + + dest.red = static_cast(c.r*255.0 + 0.5); + dest.green = static_cast(c.g*255.0 + 0.5); + dest.blue = static_cast(c.b*255.0 + 0.5); + dest.alpha = 255; + } + + template < typename P1, typename P2 > + typename enable_if_c::rgb_alpha && pixel_traits::lab>::type + assign(P1& dest, const P2& src) + { + COLOUR c; + Lab l; + l.l = (src.l/255.0)*100; + l.a = (src.a-128.0); + l.b = (src.b-128.0); + c = Lab2RGB(l); + + dest.red = static_cast(c.r * 255 + 0.5); + dest.green = static_cast(c.g * 255 + 0.5); + dest.blue = static_cast(c.b * 255 + 0.5); + dest.alpha = 255; + } + // ----------------------------- + // dest is an hsi pixel + + template < typename P1> + typename enable_if_c::hsi>::type + assign(P1& dest, const unsigned char& src) + { + dest.h = 0; + dest.s = 0; + dest.i = src; + } + + + template < typename P1, typename P2 > + typename enable_if_c::hsi && pixel_traits::grayscale>::type + assign(P1& dest, const P2& src) + { + dest.h = 0; + dest.s = 0; + assign_pixel(dest.i, src); + } + + template < typename P1, typename P2 > + typename enable_if_c::hsi && pixel_traits::rgb>::type + assign(P1& dest, const P2& src) + { + COLOUR c1; + HSL c2; + c1.r = src.red/255.0; + c1.g = src.green/255.0; + c1.b = src.blue/255.0; + c2 = RGB2HSL(c1); + + dest.h = static_cast(c2.h/360.0*255.0 + 0.5); + dest.s = static_cast(c2.s*255.0 + 0.5); + dest.i = static_cast(c2.l*255.0 + 0.5); + } + + template < typename P1, typename P2 > + typename enable_if_c::hsi && pixel_traits::rgb_alpha>::type + assign(P1& dest, const P2& src) + { + rgb_pixel temp; + // convert target hsi pixel to rgb + assign_pixel_helpers::assign(temp,dest); + + // now assign the rgb_alpha value to our temp rgb pixel + assign_pixel_helpers::assign(temp,src); + + // now we can just go assign the new rgb value to the + // hsi pixel + assign_pixel_helpers::assign(dest,temp); + } + + template < typename P1, typename P2 > + typename enable_if_c::hsi && pixel_traits::lab>::type + assign(P1& dest, const P2& src) + { + rgb_pixel temp; + // convert lab value to our temp rgb pixel + assign_pixel_helpers::assign(temp,src); + // now we can just go assign the new rgb value to the + // hsi pixel + assign_pixel_helpers::assign(dest,temp); + } + + // ----------------------------- + // dest is an lab pixel + template < typename P1> + typename enable_if_c::lab>::type + assign(P1& dest, const unsigned char& src) + { + dest.a = 128; + dest.b = 128; + dest.l = src; + } + + + template < typename P1, typename P2 > + typename enable_if_c::lab && pixel_traits::grayscale>::type + assign(P1& dest, const P2& src) + { + dest.a = 128; + dest.b = 128; + assign_pixel(dest.l, src); + } + + template < typename P1, typename P2 > + typename enable_if_c::lab && pixel_traits::rgb>::type + assign(P1& dest, const P2& src) + { + COLOUR c1; + Lab c2; + c1.r = src.red / 255.0; + c1.g = src.green / 255.0; + c1.b = src.blue / 255.0; + c2 = RGB2Lab(c1); + + dest.l = static_cast((c2.l / 100) * 255 + 0.5); + dest.a = static_cast(c2.a + 128 + 0.5); + dest.b = static_cast(c2.b + 128 + 0.5); + } + + template < typename P1, typename P2 > + typename enable_if_c::lab && pixel_traits::rgb_alpha>::type + assign(P1& dest, const P2& src) + { + rgb_pixel temp; + // convert target lab pixel to rgb + assign_pixel_helpers::assign(temp,dest); + + // now assign the rgb_alpha value to our temp rgb pixel + assign_pixel_helpers::assign(temp,src); + + // now we can just go assign the new rgb value to the + // lab pixel + assign_pixel_helpers::assign(dest,temp); + } + + template < typename P1, typename P2 > + typename enable_if_c::lab && pixel_traits::hsi>::type + assign(P1& dest, const P2& src) + { + rgb_pixel temp; + + // convert hsi value to our temp rgb pixel + assign_pixel_helpers::assign(temp,src); + + // now we can just go assign the new rgb value to the + // lab pixel + assign_pixel_helpers::assign(dest,temp); + } + } + + // ----------------------------- + + template < typename P1, typename P2 > + inline void assign_pixel ( + P1& dest, + const P2& src + ) { assign_pixel_helpers::assign(dest,src); } + +// ---------------------------------------------------------------------------------------- + + template < + typename P, + typename T + > + inline typename enable_if_c::grayscale>::type assign_pixel_intensity_helper ( + P& dest, + const T& new_intensity + ) + { + assign_pixel(dest, new_intensity); + } + + template < + typename P, + typename T + > + inline typename enable_if_c::grayscale == false && + pixel_traits

    ::has_alpha>::type assign_pixel_intensity_helper ( + P& dest, + const T& new_intensity + ) + { + hsi_pixel p; + const unsigned long old_alpha = dest.alpha; + dest.alpha = 255; + rgb_pixel temp; + assign_pixel(temp, dest); // put dest into an rgb_pixel to avoid the somewhat complicated assign_pixel(hsi,rgb_alpha). + assign_pixel(p,temp); + assign_pixel(p.i, new_intensity); + assign_pixel(dest,p); + dest.alpha = old_alpha; + } + + template < + typename P, + typename T + > + inline typename enable_if_c::grayscale == false && + pixel_traits

    ::has_alpha == false>::type assign_pixel_intensity_helper ( + P& dest, + const T& new_intensity + ) + { + hsi_pixel p; + assign_pixel(p,dest); + assign_pixel(p.i, new_intensity); + assign_pixel(dest,p); + } + + template < + typename P, + typename T + > + inline void assign_pixel_intensity ( + P& dest, + const T& new_intensity + ) + { + assign_pixel_intensity_helper(dest, new_intensity); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename P + > + inline typename enable_if_c::grayscale, P>::type get_pixel_intensity_helper ( + const P& src + ) + { + return src; + } + + template < + typename P + > + inline typename enable_if_c::grayscale == false&& + pixel_traits

    ::has_alpha, + typename pixel_traits

    ::basic_pixel_type>::type get_pixel_intensity_helper ( + const P& src + ) + { + P temp = src; + temp.alpha = 255; + typename pixel_traits

    ::basic_pixel_type p; + assign_pixel(p,temp); + return p; + } + + template < + typename P + > + inline typename enable_if_c::grayscale == false&& + pixel_traits

    ::has_alpha == false, + typename pixel_traits

    ::basic_pixel_type>::type get_pixel_intensity_helper ( + const P& src + ) + { + typename pixel_traits

    ::basic_pixel_type p; + assign_pixel(p,src); + return p; + } + + template < + typename P + > + inline typename pixel_traits

    ::basic_pixel_type get_pixel_intensity ( + const P& src + ) + { + return get_pixel_intensity_helper(src); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const rgb_alpha_pixel& item, + std::ostream& out + ) + { + try + { + serialize(item.red,out); + serialize(item.green,out); + serialize(item.blue,out); + serialize(item.alpha,out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type rgb_alpha_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + rgb_alpha_pixel& item, + std::istream& in + ) + { + try + { + deserialize(item.red,in); + deserialize(item.green,in); + deserialize(item.blue,in); + deserialize(item.alpha,in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type rgb_alpha_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const rgb_pixel& item, + std::ostream& out + ) + { + try + { + serialize(item.red,out); + serialize(item.green,out); + serialize(item.blue,out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type rgb_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + rgb_pixel& item, + std::istream& in + ) + { + try + { + deserialize(item.red,in); + deserialize(item.green,in); + deserialize(item.blue,in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type rgb_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const bgr_pixel& item, + std::ostream& out + ) + { + try + { + serialize(item.blue,out); + serialize(item.green,out); + serialize(item.red,out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type bgr_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + bgr_pixel& item, + std::istream& in + ) + { + try + { + deserialize(item.blue,in); + deserialize(item.green,in); + deserialize(item.red,in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type bgr_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const hsi_pixel& item, + std::ostream& out + ) + { + try + { + serialize(item.h,out); + serialize(item.s,out); + serialize(item.i,out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type hsi_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + hsi_pixel& item, + std::istream& in + ) + { + try + { + deserialize(item.h,in); + deserialize(item.s,in); + deserialize(item.i,in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type hsi_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const lab_pixel& item, + std::ostream& out + ) + { + try + { + serialize(item.l,out); + serialize(item.a,out); + serialize(item.b,out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type lab_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void deserialize ( + lab_pixel& item, + std::istream& in + ) + { + try + { + deserialize(item.l,in); + deserialize(item.a,in); + deserialize(item.b,in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type lab_pixel"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_PIXEl_ + diff --git a/ml/dlib/dlib/platform.h b/ml/dlib/dlib/platform.h new file mode 100644 index 000000000..f3000a6cd --- /dev/null +++ b/ml/dlib/dlib/platform.h @@ -0,0 +1,65 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_PLATFORm_ +#define DLIB_PLATFORm_ + + +/*! + This file ensures that: + - if (we are compiling under a posix platform) then + - POSIX will be defined + - if (this is also Mac OS X) then + - MACOSX will be defined + - if (this is also HP-UX) then + - HPUX will be defined + - if (we are compiling under an MS Windows platform) then + - WIN32 will be defined +!*/ + + +/* + A good reference for this sort of information is + http://predef.sourceforge.net/ +*/ + +// Define WIN32 if this is MS Windows +#ifndef WIN32 + #if defined( _MSC_VER) || defined(__BORLANDC__) || defined(_WIN32) || defined(__WIN32__) || defined(__TOS_WIN__) + #define WIN32 + #endif +#endif + +#ifndef WIN32 + // since this is the only other platform the library currently supports + // just assume it is POSIX if it isn't WIN32 + #ifndef POSIX + #define POSIX + #endif + + #ifndef HPUX + #if defined(__hpux ) || defined(hpux) || defined (_hpux) + #define HPUX + #endif + #endif + + #ifndef MACOSX + #ifdef __MACOSX__ + #define MACOSX + #endif + #ifdef __APPLE__ + #define MACOSX + #endif + #endif + +#endif + + + + +#endif // DLIB_PLATFORm_ + diff --git a/ml/dlib/dlib/python.h b/ml/dlib/dlib/python.h new file mode 100644 index 000000000..07b9c0707 --- /dev/null +++ b/ml/dlib/dlib/python.h @@ -0,0 +1,14 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYTHoN_TOP_ +#define DLIB_PYTHoN_TOP_ + +#include "python/pybind_utils.h" +#include "python/pyassert.h" +#include "python/serialize_pickle.h" +#include "python/numpy.h" +#include "python/numpy_image.h" + +#endif // DLIB_PYTHoN_TOP_ + + diff --git a/ml/dlib/dlib/python/numpy.h b/ml/dlib/dlib/python/numpy.h new file mode 100644 index 000000000..9b2c1a01c --- /dev/null +++ b/ml/dlib/dlib/python/numpy.h @@ -0,0 +1,214 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYTHON_NuMPY_Hh_ +#define DLIB_PYTHON_NuMPY_Hh_ + +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +// ---------------------------------------------------------------------------------------- + +template +void validate_numpy_array_type ( + const py::object& obj +) +{ + const char ch = obj.attr("dtype").attr("char").cast(); + + using T = typename dlib::pixel_traits::basic_pixel_type; + + if (dlib::is_same_type::value) + { + if (ch != 'd') + throw dlib::error("Expected numpy.ndarray of float64"); + } + else if (dlib::is_same_type::value) + { + if (ch != 'f') + throw dlib::error("Expected numpy.ndarray of float32"); + } + else if (dlib::is_same_type::value) + { + if (ch != 'h') + throw dlib::error("Expected numpy.ndarray of int16"); + } + else if (dlib::is_same_type::value) + { + if (ch != 'H') + throw dlib::error("Expected numpy.ndarray of uint16"); + } + else if (dlib::is_same_type::value) + { + if (ch != 'i') + throw dlib::error("Expected numpy.ndarray of int32"); + } + else if (dlib::is_same_type::value) + { + if (ch != 'I') + throw dlib::error("Expected numpy.ndarray of uint32"); + } + else if (dlib::is_same_type::value) + { + if (ch != 'B') + throw dlib::error("Expected numpy.ndarray of uint8"); + } + else if (dlib::is_same_type::value) + { + if (ch != 'b') + throw dlib::error("Expected numpy.ndarray of int8"); + } + else + { + throw dlib::error("validate_numpy_array_type() called with unsupported type."); + } +} + +// ---------------------------------------------------------------------------------------- + +template +void get_numpy_ndarray_shape ( + const py::object& obj, + long (&shape)[dims] +) +/*! + ensures + - stores the shape of the array into #shape. + - the dimension of the given numpy array is not greater than #dims. +!*/ +{ + Py_buffer pybuf; + if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES )) + throw dlib::error("Expected numpy.ndarray with shape set."); + + try + { + + if (pybuf.ndim > dims) + throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions."); + + for (int i = 0; i < dims; ++i) + { + if (i < pybuf.ndim) + shape[i] = pybuf.shape[i]; + else + shape[i] = 1; + } + } + catch(...) + { + PyBuffer_Release(&pybuf); + throw; + } + PyBuffer_Release(&pybuf); +} + +// ---------------------------------------------------------------------------------------- + +template +void get_numpy_ndarray_parts ( + py::object& obj, + T*& data, + dlib::array& contig_buf, + long (&shape)[dims] +) +/*! + ensures + - extracts the pointer to the data from the given numpy ndarray. Stores the shape + of the array into #shape. + - the dimension of the given numpy array is not greater than #dims. + - #shape[#dims-1] == pixel_traits::num when #dims is greater than 2 +!*/ +{ + Py_buffer pybuf; + if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES | PyBUF_WRITABLE )) + throw dlib::error("Expected writable numpy.ndarray with shape set."); + + try + { + validate_numpy_array_type(obj); + + if (pybuf.ndim > dims) + throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions."); + get_numpy_ndarray_shape(obj, shape); + + if (dlib::pixel_traits::num > 1 && dlib::pixel_traits::num != shape[dims-1]) + throw dlib::error("Expected numpy.ndarray with " + dlib::cast_to_string(dlib::pixel_traits::num) + " channels."); + + if (PyBuffer_IsContiguous(&pybuf, 'C')) + data = (T*)pybuf.buf; + else + { + contig_buf.resize(pybuf.len); + if (PyBuffer_ToContiguous(&contig_buf[0], &pybuf, pybuf.len, 'C')) + throw dlib::error("Can't copy numpy.ndarray to a contiguous buffer."); + data = &contig_buf[0]; + } + } + catch(...) + { + PyBuffer_Release(&pybuf); + throw; + } + PyBuffer_Release(&pybuf); +} + +// ---------------------------------------------------------------------------------------- + +template +void get_numpy_ndarray_parts ( + const py::object& obj, + const T*& data, + dlib::array& contig_buf, + long (&shape)[dims] +) +/*! + ensures + - extracts the pointer to the data from the given numpy ndarray. Stores the shape + of the array into #shape. + - the dimension of the given numpy array is not greater than #dims. + - #shape[#dims-1] == pixel_traits::num when #dims is greater than 2 +!*/ +{ + Py_buffer pybuf; + if (PyObject_GetBuffer(obj.ptr(), &pybuf, PyBUF_STRIDES )) + throw dlib::error("Expected numpy.ndarray with shape set."); + + try + { + validate_numpy_array_type(obj); + + if (pybuf.ndim > dims) + throw dlib::error("Expected array with " + dlib::cast_to_string(dims) + " dimensions."); + get_numpy_ndarray_shape(obj, shape); + + if (dlib::pixel_traits::num > 1 && dlib::pixel_traits::num != shape[dims-1]) + throw dlib::error("Expected numpy.ndarray with " + dlib::cast_to_string(dlib::pixel_traits::num) + " channels."); + + if (PyBuffer_IsContiguous(&pybuf, 'C')) + data = (const T*)pybuf.buf; + else + { + contig_buf.resize(pybuf.len); + if (PyBuffer_ToContiguous(&contig_buf[0], &pybuf, pybuf.len, 'C')) + throw dlib::error("Can't copy numpy.ndarray to a contiguous buffer."); + data = &contig_buf[0]; + } + } + catch(...) + { + PyBuffer_Release(&pybuf); + throw; + } + PyBuffer_Release(&pybuf); +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_PYTHON_NuMPY_Hh_ + diff --git a/ml/dlib/dlib/python/numpy_image.h b/ml/dlib/dlib/python/numpy_image.h new file mode 100644 index 000000000..49ea80317 --- /dev/null +++ b/ml/dlib/dlib/python/numpy_image.h @@ -0,0 +1,129 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYTHON_NuMPY_IMAGE_Hh_ +#define DLIB_PYTHON_NuMPY_IMAGE_Hh_ + +#include "numpy.h" +#include +#include +#include + + +// ---------------------------------------------------------------------------------------- + +class numpy_gray_image +{ +public: + + numpy_gray_image() : _data(0), _nr(0), _nc(0) {} + numpy_gray_image (py::object& img) + { + long shape[2]; + get_numpy_ndarray_parts(img, _data, _contig_buf, shape); + _nr = shape[0]; + _nc = shape[1]; + } + + friend inline long num_rows(const numpy_gray_image& img) { return img._nr; } + friend inline long num_columns(const numpy_gray_image& img) { return img._nc; } + friend inline void* image_data(numpy_gray_image& img) { return img._data; } + friend inline const void* image_data(const numpy_gray_image& img) { return img._data; } + friend inline long width_step(const numpy_gray_image& img) { return img._nc*sizeof(unsigned char); } + +private: + + unsigned char* _data; + dlib::array _contig_buf; + long _nr; + long _nc; +}; + +namespace dlib +{ + template <> + struct image_traits + { + typedef unsigned char pixel_type; + }; +} + +// ---------------------------------------------------------------------------------------- + +inline bool is_gray_python_image (py::object& img) +{ + try + { + long shape[2]; + get_numpy_ndarray_shape(img, shape); + return true; + } + catch (dlib::error&) + { + return false; + } +} + +// ---------------------------------------------------------------------------------------- + +class numpy_rgb_image +{ +public: + + numpy_rgb_image() : _data(0), _nr(0), _nc(0) {} + numpy_rgb_image (py::object& img) + { + long shape[3]; + get_numpy_ndarray_parts(img, _data, _contig_buf, shape); + _nr = shape[0]; + _nc = shape[1]; + if (shape[2] != 3) + throw dlib::error("Error, python object is not a three band image and therefore can't be a RGB image."); + } + + friend inline long num_rows(const numpy_rgb_image& img) { return img._nr; } + friend inline long num_columns(const numpy_rgb_image& img) { return img._nc; } + friend inline void* image_data(numpy_rgb_image& img) { return img._data; } + friend inline const void* image_data(const numpy_rgb_image& img) { return img._data; } + friend inline long width_step(const numpy_rgb_image& img) { return img._nc*sizeof(dlib::rgb_pixel); } + + +private: + + dlib::rgb_pixel* _data; + dlib::array _contig_buf; + long _nr; + long _nc; +}; + +namespace dlib +{ + template <> + struct image_traits + { + typedef rgb_pixel pixel_type; + }; +} + +// ---------------------------------------------------------------------------------------- + + +inline bool is_rgb_python_image (py::object& img) +{ + try + { + long shape[3]; + get_numpy_ndarray_shape(img, shape); + if (shape[2] == 3) + return true; + return false; + } + catch (dlib::error&) + { + return false; + } +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_PYTHON_NuMPY_IMAGE_Hh_ + diff --git a/ml/dlib/dlib/python/pyassert.h b/ml/dlib/dlib/python/pyassert.h new file mode 100644 index 000000000..80939f501 --- /dev/null +++ b/ml/dlib/dlib/python/pyassert.h @@ -0,0 +1,17 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYaSSERT_Hh_ +#define DLIB_PYaSSERT_Hh_ + +#include + +#define pyassert(_exp,_message) \ + {if ( !(_exp) ) \ + { \ + namespace py = pybind11; \ + PyErr_SetString( PyExc_ValueError, _message ); \ + throw py::error_already_set(); \ + }} + +#endif // DLIB_PYaSSERT_Hh_ + diff --git a/ml/dlib/dlib/python/pybind_utils.h b/ml/dlib/dlib/python/pybind_utils.h new file mode 100644 index 000000000..7f94cf32d --- /dev/null +++ b/ml/dlib/dlib/python/pybind_utils.h @@ -0,0 +1,82 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PYBIND_UtILS_Hh_ +#define DLIB_PYBIND_UtILS_Hh_ + +#include +#include +#include +#include + +namespace py = pybind11; + +template +std::vector python_list_to_vector ( + const py::list& obj +) +/*! + ensures + - converts a python object into a std::vector and returns it. +!*/ +{ + std::vector vect(len(obj)); + for (unsigned long i = 0; i < vect.size(); ++i) + { + vect[i] = obj[i].cast(); + } + return vect; +} + +template +py::list vector_to_python_list ( + const std::vector& vect +) +/*! + ensures + - converts a std::vector into a python list object. +!*/ +{ + py::list obj; + for (unsigned long i = 0; i < vect.size(); ++i) + obj.append(vect[i]); + return obj; +} + +template +void extend_vector_with_python_list ( + std::vector &v, + const py::list &l +) +/*! + ensures + - appends items from a python list to the end of std::vector. +!*/ +{ + for (const auto &item : l) + v.push_back(item.cast()); +} + +// ---------------------------------------------------------------------------------------- + +template +std::shared_ptr load_object_from_file ( + const std::string& filename +) +/*! + ensures + - deserializes an object of type T from the given file and returns it. +!*/ +{ + std::ifstream fin(filename.c_str(), std::ios::binary); + if (!fin) + throw dlib::error("Unable to open " + filename); + auto obj = std::make_shared(); + deserialize(*obj, fin); + return obj; +} + +// ---------------------------------------------------------------------------------------- + + +#endif // DLIB_PYBIND_UtILS_Hh_ + diff --git a/ml/dlib/dlib/python/serialize_pickle.h b/ml/dlib/dlib/python/serialize_pickle.h new file mode 100644 index 000000000..2dc44c322 --- /dev/null +++ b/ml/dlib/dlib/python/serialize_pickle.h @@ -0,0 +1,66 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERIALIZE_PiCKLE_Hh_ +#define DLIB_SERIALIZE_PiCKLE_Hh_ + +#include +#include +#include +#include + +template +py::tuple getstate(const T& item) +{ + using namespace dlib; + std::vector buf; + buf.reserve(5000); + vectorstream sout(buf); + serialize(item, sout); + return py::make_tuple(py::handle( + PyBytes_FromStringAndSize(buf.size()?&buf[0]:0, buf.size()))); +} + +template +T setstate(py::tuple state) +{ + using namespace dlib; + if (len(state) != 1) + { + PyErr_SetObject(PyExc_ValueError, + py::str("expected 1-item tuple in call to __setstate__; got {}").format(state).ptr() + ); + throw py::error_already_set(); + } + + // We used to serialize by converting to a str but the boost.python routines for + // doing this don't work in Python 3. You end up getting an error about invalid + // UTF-8 encodings. So instead we access the python C interface directly and use + // bytes objects. However, we keep the deserialization code that worked with str + // for backwards compatibility with previously pickled files. + T item; + py::object obj = state[0]; + if (py::isinstance(obj)) + { + py::str data = state[0].cast(); + std::string temp = data; + std::istringstream sin(temp); + deserialize(item, sin); + } + else if(PyBytes_Check(py::object(state[0]).ptr())) + { + py::object obj = state[0]; + char* data = PyBytes_AsString(obj.ptr()); + unsigned long num = PyBytes_Size(obj.ptr()); + std::istringstream sin(std::string(data, num)); + deserialize(item, sin); + } + else + { + throw error("Unable to unpickle, error in input file."); + } + + return item; +} + +#endif // DLIB_SERIALIZE_PiCKLE_Hh_ + diff --git a/ml/dlib/dlib/quantum_computing.h b/ml/dlib/dlib/quantum_computing.h new file mode 100644 index 000000000..11af76197 --- /dev/null +++ b/ml/dlib/dlib/quantum_computing.h @@ -0,0 +1,12 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_QUANTUM_COMPUTINg_H_ +#define DLIB_QUANTUM_COMPUTINg_H_ + +#include "quantum_computing/quantum_computing.h" + +#endif // DLIB_QUANTUM_COMPUTINg_H_ + + + + diff --git a/ml/dlib/dlib/quantum_computing/quantum_computing.h b/ml/dlib/dlib/quantum_computing/quantum_computing.h new file mode 100644 index 000000000..afa2e40e7 --- /dev/null +++ b/ml/dlib/dlib/quantum_computing/quantum_computing.h @@ -0,0 +1,863 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_QUANTUM_COMPUTINg_1_ +#define DLIB_QUANTUM_COMPUTINg_1_ + +#include +#include +#include "../matrix.h" +#include "../rand.h" +#include "../enable_if.h" +#include "../algs.h" +#include "quantum_computing_abstract.h" + +namespace dlib +{ + + template + struct gate_traits {}; + + namespace qc_helpers + { + + // ------------------------------------------------------------------------------------ + + // This is a template to compute the value of 2^n at compile time + template + struct exp_2_n + { + COMPILE_TIME_ASSERT(0 <= n && n <= 30); + static const long value = exp_2_n::value*2; + }; + + template <> + struct exp_2_n<0> + { + static const long value = 1; + }; + + // ------------------------------------------------------------------------------------ + + } + + typedef std::complex qc_scalar_type; + +// ---------------------------------------------------------------------------------------- + + class quantum_register + { + public: + + quantum_register() + { + set_num_bits(1); + } + + int num_bits ( + ) const + { + return num_bits_in_register; + } + + void set_num_bits ( + int num_bits + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(1 <= num_bits && num_bits <= 30, + "\tvoid quantum_register::set_num_bits()" + << "\n\tinvalid arguments to this function" + << "\n\tnum_bits: " << num_bits + << "\n\tthis: " << this + ); + + num_bits_in_register = num_bits; + + unsigned long size = 1; + for (int i = 0; i < num_bits; ++i) + size *= 2; + + state.set_size(size); + + zero_all_bits(); + } + + void zero_all_bits() + { + set_all_elements(state,0); + state(0) = 1; + } + + void append ( + const quantum_register& reg + ) + { + num_bits_in_register += reg.num_bits_in_register; + state = tensor_product(state, reg.state); + } + + template + bool measure_bit ( + int bit, + rand_type& rnd + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(0 <= bit && bit < num_bits(), + "\tbool quantum_register::measure_bit()" + << "\n\tinvalid arguments to this function" + << "\n\tbit: " << bit + << "\n\tnum_bits(): " << num_bits() + << "\n\tthis: " << this + ); + + const bool value = (rnd.get_random_double() < probability_of_bit(bit)); + + // Next we set all the states where this bit doesn't have the given value to 0 + + // But first make a mask that selects our bit + unsigned long mask = 1; + for (int i = 0; i < bit; ++i) + mask <<= 1; + + // loop over all the elements in the state vector and zero out those that + // conflict with the measurement we just made. + for (long r = 0; r < state.nr(); ++r) + { + const unsigned long field = r; + // if this state indicates that the bit should be set and it isn't + if ((field & mask) && !value) + { + state(r) = 0; + } + // else if this state indicates that the bit should not be set and it is + else if (!(field & mask) && value) + { + state(r) = 0; + } + } + + // normalize the state + state = state/(std::sqrt(sum(norm(state)))); + + return value; + } + + template + bool measure_and_remove_bit ( + int bit, + rand_type& rnd + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(0 <= bit && bit < num_bits() && num_bits() > 0, + "\tbool quantum_register::measure_and_remove_bit()" + << "\n\tinvalid arguments to this function" + << "\n\tbit: " << bit + << "\n\tnum_bits(): " << num_bits() + << "\n\tthis: " << this + ); + + + const bool value = (rnd.get_random_double() < probability_of_bit(bit)); + quantum_register temp; + temp.set_num_bits(num_bits()-1); + + + // Next we set all the states where this bit doesn't have the given value to 0 + + // But first make a mask that selects our bit + unsigned long mask = 1; + for (int i = 0; i < bit; ++i) + mask <<= 1; + + long count = 0; + for (long r = 0; r < state.nr(); ++r) + { + const unsigned long field = r; + // if this basis vector is one that matches the measured state then keep it + if (((field & mask) != 0) == value) + { + temp.state(count) = state(r); + ++count; + } + } + + // normalize the state + temp.state = temp.state/std::sqrt(sum(norm(temp.state))); + + temp.swap(*this); + + return value; + } + + double probability_of_bit ( + int bit + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(0 <= bit && bit < num_bits(), + "\tdouble quantum_register::probability_of_bit()" + << "\n\tinvalid arguments to this function" + << "\n\tbit: " << bit + << "\n\tnum_bits(): " << num_bits() + << "\n\tthis: " << this + ); + + + // make a mask that selects our bit + unsigned long mask = 1; + for (int i = 0; i < bit; ++i) + mask <<= 1; + + // now find the total probability of all the states that have the given bit set + double prob = 0; + for (long r = 0; r < state.nr(); ++r) + { + const unsigned long field = r; + if (field & mask) + { + prob += std::norm(state(r)); + } + } + + + return prob; + } + + const matrix& state_vector() const { return state; } + matrix& state_vector() { return state; } + + void swap ( + quantum_register& item + ) + { + exchange(num_bits_in_register, item.num_bits_in_register); + state.swap(item.state); + } + + private: + + int num_bits_in_register; + matrix state; + }; + + inline void swap ( + quantum_register& a, + quantum_register& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template + class gate_exp + { + public: + static const long num_bits = gate_traits::num_bits; + static const long dims = gate_traits::dims; + + gate_exp(T& exp_) : exp(exp_) {} + + const qc_scalar_type operator() (long r, long c) const { return exp(r,c); } + + const matrix mat ( + ) const + { + matrix m; + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = exp(r,c); + } + } + return m; + } + + void apply_gate_to (quantum_register& reg) const + { + // make sure requires clause is not broken + DLIB_CASSERT(reg.num_bits() == num_bits, + "\tvoid gate_exp::apply_gate_to()" + << "\n\tinvalid arguments to this function" + << "\n\treg.num_bits(): " << reg.num_bits() + << "\n\tnum_bits: " << num_bits + << "\n\tthis: " << this + ); + + + quantum_register temp(reg); + + + // check if any of the elements of the register are 1 and if so then + // we don't have to do the full matrix multiply. Or check if only a small number are non-zero. + long non_zero_elements = 0; + for (long r = 0; r < dims; ++r) + { + if (reg.state_vector()(r) != qc_scalar_type(0)) + ++non_zero_elements; + + reg.state_vector()(r) = 0; + } + + + if (non_zero_elements > 3) + { + // do a full matrix multiply to compute the output state + for (long r = 0; r < dims; ++r) + { + reg.state_vector()(r) = compute_state_element(temp.state_vector(),r); + } + } + else + { + // do a matrix multiply but only use the columns in the gate matrix + // that correspond to the non-zero register elements + for (long r = 0; r < dims; ++r) + { + if (temp.state_vector()(r) != qc_scalar_type(0)) + { + for (long i = 0; i < dims; ++i) + { + reg.state_vector()(i) += temp.state_vector()(r)*exp(i,r); + } + } + } + } + } + + template + qc_scalar_type compute_state_element ( + const matrix_exp& reg, + long row_idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 && + 0 <= row_idx && row_idx < dims, + "\tqc_scalar_type gate_exp::compute_state_element(reg,row_idx)" + << "\n\tinvalid arguments to this function" + << "\n\treg.nr(): " << reg.nr() + << "\n\treg.nc(): " << reg.nc() + << "\n\tdims: " << dims + << "\n\trow_idx: " << row_idx + << "\n\tthis: " << this + ); + + + return this->exp.compute_state_element(reg,row_idx); + } + + const T& ref() const { return exp; } + + private: + T& exp; + }; + +// ---------------------------------------------------------------------------------------- + + + template + class composite_gate; + + template + struct gate_traits > + { + static const long num_bits = T::num_bits + U::num_bits; + static const long dims = qc_helpers::exp_2_n::value; + }; + + template + class composite_gate : public gate_exp > + { + public: + + typedef T lhs_type; + typedef U rhs_type; + + composite_gate(const composite_gate& g) : gate_exp(*this), lhs(g.lhs), rhs(g.rhs) {} + + composite_gate( + const gate_exp& lhs_, + const gate_exp& rhs_ + ) : gate_exp(*this), lhs(lhs_.ref()), rhs(rhs_.ref()) {} + + + + static const long num_bits = gate_traits::num_bits; + static const long dims = gate_traits::dims; + + const qc_scalar_type operator() (long r, long c) const { return lhs(r/U::dims,c/U::dims)*rhs(r%U::dims, c%U::dims); } + + template + qc_scalar_type compute_state_element ( + const matrix_exp& reg, + long row_idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 && + 0 <= row_idx && row_idx < dims, + "\tqc_scalar_type composite_gate::compute_state_element(reg,row_idx)" + << "\n\tinvalid arguments to this function" + << "\n\treg.nr(): " << reg.nr() + << "\n\treg.nc(): " << reg.nc() + << "\n\tdims: " << dims + << "\n\trow_idx: " << row_idx + << "\n\tthis: " << this + ); + + + qc_scalar_type result = 0; + for (long c = 0; c < T::dims; ++c) + { + if (lhs(row_idx/U::dims,c) != qc_scalar_type(0)) + { + result += lhs(row_idx/U::dims,c) * rhs.compute_state_element(subm(reg,c*U::dims,0,U::dims,1), row_idx%U::dims); + } + } + + return result; + } + + + const T lhs; + const U rhs; + }; + +// ---------------------------------------------------------------------------------------- + + template + class gate; + template + struct gate_traits > + { + static const long num_bits = bits; + static const long dims = qc_helpers::exp_2_n::value; + }; + +// ---------------------------------------------------------------------------------------- + + template + class gate : public gate_exp > + { + public: + gate() : gate_exp(*this) { set_all_elements(data,0); } + gate(const gate& g) :gate_exp(*this), data(g.data) {} + + template + explicit gate(const gate_exp& g) : gate_exp(*this) + { + COMPILE_TIME_ASSERT(T::num_bits == num_bits); + for (long r = 0; r < dims; ++r) + { + for (long c = 0; c < dims; ++c) + { + data(r,c) = g(r,c); + } + } + } + + static const long num_bits = gate_traits::num_bits; + static const long dims = gate_traits::dims; + + const qc_scalar_type& operator() (long r, long c) const { return data(r,c); } + qc_scalar_type& operator() (long r, long c) { return data(r,c); } + + template + qc_scalar_type compute_state_element ( + const matrix_exp& reg, + long row_idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 && + 0 <= row_idx && row_idx < dims, + "\tqc_scalar_type gate::compute_state_element(reg,row_idx)" + << "\n\tinvalid arguments to this function" + << "\n\treg.nr(): " << reg.nr() + << "\n\treg.nc(): " << reg.nc() + << "\n\tdims: " << dims + << "\n\trow_idx: " << row_idx + << "\n\tthis: " << this + ); + + + return (data*reg)(row_idx); + } + + private: + + matrix data; + }; + +// ---------------------------------------------------------------------------------------- + + namespace qc_helpers + { + // This is the maximum number of bits used for cached sub-matrices in composite_gate expressions + const int qc_block_chunking_size = 8; + + template + struct is_composite_gate { const static bool value = false; }; + template + struct is_composite_gate > { const static bool value = true; }; + + + // These overloads all deal with intelligently composing chains of composite_gate expressions + // such that the resulting expression has the form: + // (gate_exp,(gate_exp,(gate_exp,(gate_exp())))) + // and each gate_exp contains a cached gate matrix for a gate of at most qc_block_chunking_size bits. + // This facilitates the optimizations in the compute_state_element() function. + template + struct combine_gates; + + // This is a base case of this recursive template. It takes care of converting small composite_gates into + // cached gate objects. + template + struct combine_gates::type > + { + typedef composite_gate,V> result_type; + + static const result_type eval ( + const composite_gate& lhs, + const gate_exp& rhs + ) + { + typedef gate gate_type; + return composite_gate(gate_type(lhs), rhs); + } + }; + + // this is the recursive step of this template + template + struct combine_gates::value == true)>::type > + { + typedef typename combine_gates::result_type inner_type; + typedef composite_gate result_type; + + static const result_type eval ( + const composite_gate& lhs, + const gate_exp& rhs + ) + { + return composite_gate(lhs.lhs, combine_gates::eval(lhs.rhs,rhs)); + } + + }; + + // This is a base case of this recursive template. It takes care of adding new gates when the left + // hand side is too big to just turn it into a cached gate object. + template + struct combine_gates qc_block_chunking_size && + is_composite_gate::value == false)>::type > + { + typedef composite_gate > result_type; + + static const result_type eval ( + const composite_gate& lhs, + const gate_exp& rhs + ) + { + return result_type(lhs.lhs, composite_gate(lhs.rhs, rhs)); + } + + }; + + } + + template + const composite_gate operator, ( + const gate_exp& lhs, + const gate_exp& rhs + ) + { + return composite_gate(lhs,rhs); + } + + template + const typename qc_helpers::combine_gates::result_type operator, ( + const composite_gate& lhs, + const gate_exp& rhs + ) + { + return qc_helpers::combine_gates::eval(lhs,rhs); + } + + // If you are getting an error here then it means that you are trying to combine a gate expression + // with an integer somewhere (and that is an error). + template void operator, ( const gate_exp&, int) { COMPILE_TIME_ASSERT(sizeof(T) > 100000000); } + template void operator, ( int, const gate_exp&) { COMPILE_TIME_ASSERT(sizeof(T) > 100000000); } + +// ---------------------------------------------------------------------------------------- + + namespace quantum_gates + { + template + class cnot; + + template + class toffoli; + } + + template + struct gate_traits > + { + static const long num_bits = tabs::value+1; + static const long dims = qc_helpers::exp_2_n::value; + }; + + template + struct gate_traits > + { + static const long num_bits = tmax::value, + tabs::value>::value+1; + static const long dims = qc_helpers::exp_2_n::value; + }; + + +// ---------------------------------------------------------------------------------------- + + namespace quantum_gates + { + + inline const gate<1> hadamard( + ) + { + gate<1> h; + h(0,0) = std::sqrt(1/2.0); + h(0,1) = std::sqrt(1/2.0); + h(1,0) = std::sqrt(1/2.0); + h(1,1) = -std::sqrt(1/2.0); + return h; + } + + // ------------------------------------------------------------------------------------ + + inline const gate<1> x( + ) + { + gate<1> x; + x(0,1) = 1; + x(1,0) = 1; + return x; + } + + // ------------------------------------------------------------------------------------ + + inline const gate<1> y( + ) + { + gate<1> x; + qc_scalar_type i(0,1); + x(0,1) = -i; + x(1,0) = i; + return x; + } + + // ------------------------------------------------------------------------------------ + + inline const gate<1> z( + ) + { + gate<1> z; + z(0,0) = 1; + z(1,1) = -1; + return z; + } + + // ------------------------------------------------------------------------------------ + + inline const gate<1> noop( + ) + { + gate<1> i; + i(0,0) = 1; + i(1,1) = 1; + return i; + } + + // ------------------------------------------------------------------------------------ + + template + class cnot : public gate_exp > + { + public: + COMPILE_TIME_ASSERT(control_bit != target_bit); + + cnot() : gate_exp(*this) + { + const int min_bit = std::min(control_bit, target_bit); + + control_mask = 1; + target_mask = 1; + + // make the masks so that their only on bit corresponds to the given control_bit and target_bit bits + for (int i = 0; i < control_bit-min_bit; ++i) + control_mask <<= 1; + for (int i = 0; i < target_bit-min_bit; ++i) + target_mask <<= 1; + } + + static const long num_bits = gate_traits::num_bits; + static const long dims = gate_traits::dims; + + const qc_scalar_type operator() (long r, long c) const + { + unsigned long output; + // if the input control bit is set + if (control_mask&c) + { + output = c^target_mask; + } + else + { + output = c; + } + + if ((unsigned long)r == output) + return 1; + else + return 0; + } + + template + qc_scalar_type compute_state_element ( + const matrix_exp& reg, + long row_idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 && + 0 <= row_idx && row_idx < dims, + "\tqc_scalar_type cnot::compute_state_element(reg,row_idx)" + << "\n\tinvalid arguments to this function" + << "\n\treg.nr(): " << reg.nr() + << "\n\treg.nc(): " << reg.nc() + << "\n\tdims: " << dims + << "\n\trow_idx: " << row_idx + << "\n\tthis: " << this + ); + + + unsigned long output = row_idx; + // if the input control bit is set + if (control_mask&output) + { + output = output^target_mask; + } + + return reg(output); + } + + private: + + unsigned long control_mask; + unsigned long target_mask; + + + }; + + // ------------------------------------------------------------------------------------ + + template + class toffoli : public gate_exp > + { + public: + COMPILE_TIME_ASSERT(control_bit1 != target_bit && control_bit2 != target_bit && control_bit1 != control_bit2); + COMPILE_TIME_ASSERT((control_bit1 < target_bit && control_bit2 < target_bit) ||(control_bit1 > target_bit && control_bit2 > target_bit) ); + + toffoli() : gate_exp(*this) + { + const int min_bit = std::min(std::min(control_bit1, control_bit2), target_bit); + + control1_mask = 1; + control2_mask = 1; + target_mask = 1; + + // make the masks so that their only on bit corresponds to the given control_bit1 and target_bit bits + for (int i = 0; i < control_bit1-min_bit; ++i) + control1_mask <<= 1; + for (int i = 0; i < control_bit2-min_bit; ++i) + control2_mask <<= 1; + for (int i = 0; i < target_bit-min_bit; ++i) + target_mask <<= 1; + } + + static const long num_bits = gate_traits::num_bits; + static const long dims = gate_traits::dims; + + const qc_scalar_type operator() (long r, long c) const + { + unsigned long output; + // if the input control bits are set + if ((control1_mask&c) && (control2_mask&c)) + { + output = c^target_mask; + } + else + { + output = c; + } + + if ((unsigned long)r == output) + return 1; + else + return 0; + } + + template + qc_scalar_type compute_state_element ( + const matrix_exp& reg, + long row_idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(reg.nr() == dims && reg.nc() == 1 && + 0 <= row_idx && row_idx < dims, + "\tqc_scalar_type toffoli::compute_state_element(reg,row_idx)" + << "\n\tinvalid arguments to this function" + << "\n\treg.nr(): " << reg.nr() + << "\n\treg.nc(): " << reg.nc() + << "\n\tdims: " << dims + << "\n\trow_idx: " << row_idx + << "\n\tthis: " << this + ); + + + unsigned long output; + // if the input control bits are set + if ((control1_mask&row_idx) && (control2_mask&row_idx)) + { + output = row_idx^target_mask; + } + else + { + output = row_idx; + } + + return reg(output); + + } + + private: + + unsigned long control1_mask; + unsigned long control2_mask; + unsigned long target_mask; + + + }; + + + // ------------------------------------------------------------------------------------ + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_QUANTUM_COMPUTINg_1_ + + diff --git a/ml/dlib/dlib/quantum_computing/quantum_computing_abstract.h b/ml/dlib/dlib/quantum_computing/quantum_computing_abstract.h new file mode 100644 index 000000000..bcc65af23 --- /dev/null +++ b/ml/dlib/dlib/quantum_computing/quantum_computing_abstract.h @@ -0,0 +1,590 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_QUANTUM_COMPUTINg_ABSTRACT_ +#ifdef DLIB_QUANTUM_COMPUTINg_ABSTRACT_ + +#include +#include "../matrix.h" +#include "../rand.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + typedef std::complex qc_scalar_type; + +// ---------------------------------------------------------------------------------------- + + class quantum_register + { + /*! + INITIAL VALUE + - num_bits() == 1 + - state_vector().nr() == 2 + - state_vector().nc() == 1 + - state_vector()(0) == 1 + - state_vector()(1) == 0 + - probability_of_bit(0) == 0 + + - i.e. This register represents a single quantum bit and it is + completely in the 0 state. + + WHAT THIS OBJECT REPRESENTS + This object represents a set of quantum bits. + !*/ + + public: + + quantum_register( + ); + /*! + ensures + - this object is properly initialized + !*/ + + int num_bits ( + ) const; + /*! + ensures + - returns the number of quantum bits in this register + !*/ + + void set_num_bits ( + int new_num_bits + ); + /*! + requires + - 1 <= new_num_bits <= 30 + ensures + - #num_bits() == new_num_bits + - #state_vector().nr() == 2^new_num_bits + (i.e. the size of the state_vector is exponential in the number of bits in a register) + - for all valid i: + - probability_of_bit(i) == 0 + !*/ + + void zero_all_bits( + ); + /*! + ensures + - for all valid i: + - probability_of_bit(i) == 0 + !*/ + + void append ( + const quantum_register& reg + ); + /*! + ensures + - #num_bits() == num_bits() + reg.num_bits() + - #this->state_vector() == tensor_product(this->state_vector(), reg.state_vector()) + - The original bits in *this become the high order bits of the resulting + register and all the bits in reg end up as the low order bits in the + resulting register. + !*/ + + double probability_of_bit ( + int bit + ) const; + /*! + requires + - 0 <= bit < num_bits() + ensures + - returns the probability of measuring the given bit and it being in the 1 state. + - The returned value is also equal to the sum of norm(state_vector()(i)) for all + i where the bit'th bit in i is set to 1. (note that the lowest order bit is bit 0) + !*/ + + template + bool measure_bit ( + int bit, + rand_type& rnd + ); + /*! + requires + - 0 <= bit < num_bits() + - rand_type == an implementation of dlib/rand/rand_float_abstract.h + ensures + - measures the given bit in this register. Let R denote the boolean + result of the measurement, where true means the bit was measured to + have value 1 and false means it had a value of 0. + - if (R == true) then + - returns true + - #probability_of_bit(bit) == 1 + - else + - returns false + - #probability_of_bit(bit) == 0 + !*/ + + template + bool measure_and_remove_bit ( + int bit, + rand_type& rnd + ); + /*! + requires + - num_bits() > 1 + - 0 <= bit < num_bits() + - rand_type == an implementation of dlib/rand/rand_float_abstract.h + ensures + - measures the given bit in this register. Let R denote the boolean + result of the measurement, where true means the bit was measured to + have value 1 and false means it had a value of 0. + - #num_bits() == num_bits() - 1 + - removes the bit that was measured from this register. + - if (R == true) then + - returns true + - else + - returns false + !*/ + + const matrix& state_vector( + ) const; + /*! + ensures + - returns a const reference to the state vector that describes the state of + the quantum bits in this register. + !*/ + + matrix& state_vector( + ); + /*! + ensures + - returns a non-const reference to the state vector that describes the state of + the quantum bits in this register. + !*/ + + void swap ( + quantum_register& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + quantum_register& a, + quantum_register& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + template + class gate_exp + { + /*! + REQUIREMENTS ON T + T must be some object that inherits from gate_exp and implements its own + version of operator() and compute_state_element(). + + WHAT THIS OBJECT REPRESENTS + This object represents an expression that evaluates to a quantum gate + that operates on T::num_bits qubits. + + This object makes it easy to create new types of gate objects. All + you need to do is inherit from gate_exp in the proper way and + then you can use your new gate objects in conjunction with all the + others. + !*/ + + public: + + static const long num_bits = T::num_bits; + static const long dims = T::dims; + + gate_exp( + T& exp + ); + /*! + ensures + - #&ref() == &exp + !*/ + + const qc_scalar_type operator() ( + long r, + long c + ) const; + /*! + requires + - 0 <= r < dims + - 0 <= c < dims + ensures + - returns ref()(r,c) + !*/ + + void apply_gate_to ( + quantum_register& reg + ) const; + /*! + requires + - reg.num_bits() == num_bits + ensures + - applies this quantum gate to the given quantum register + - Let M represent the matrix for this quantum gate, then + #reg().state_vector() = M*reg().state_vector() + !*/ + + template + qc_scalar_type compute_state_element ( + const matrix_exp& reg, + long row_idx + ) const; + /*! + requires + - reg.nr() == dims + - reg.nc() == 1 + - 0 <= row_idx < dims + ensures + - Let M represent the matrix for this gate, then + this function returns rowm(M*reg, row_idx) + (i.e. returns the row_idx row of what you get when you apply this + gate to the given column vector in reg) + - This function works by calling ref().compute_state_element(reg,row_idx) + !*/ + + const T& ref( + ); + /*! + ensures + - returns a reference to the subexpression contained in this object + !*/ + + const matrix mat ( + ) const; + /*! + ensures + - returns a dense matrix object that contains the matrix for this gate + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template + class composite_gate : public gate_exp > + { + /*! + REQUIREMENTS ON T AND U + Both must be gate expressions that inherit from gate_exp + + WHAT THIS OBJECT REPRESENTS + This object represents a quantum gate that is the tensor product of + two other quantum gates. + + + As an example, suppose you have 3 registers, reg_high, reg_low, and reg_all. Also + suppose that reg_all is what you get when you append reg_high and reg_low, + so reg_all.state_vector() == tensor_product(reg_high.state_vector(),reg_low.state_vector()). + + Then applying a composite gate to reg_all would give you the same thing as + applying the lhs gate to reg_high and the rhs gate to reg_low and then appending + the two resulting registers. So the lhs gate of a composite_gate applies to + the high order bits of a regitser and the rhs gate applies to the lower order bits. + !*/ + public: + + composite_gate ( + const composite_gate& g + ); + /*! + ensures + - *this is a copy of g + !*/ + + composite_gate( + const gate_exp& lhs_, + const gate_exp& rhs_ + ): + /*! + ensures + - #lhs == lhs_.ref() + - #rhs == rhs_.ref() + - #num_bits == T::num_bits + U::num_bits + - #dims == 2^num_bits + - #&ref() == this + !*/ + + const qc_scalar_type operator() ( + long r, + long c + ) const; + /*! + requires + - 0 <= r < dims + - 0 <= c < dims + ensures + - Let M denote the tensor product of lhs with rhs, then this function + returns M(r,c) + (i.e. returns lhs(r/U::dims,c/U::dims)*rhs(r%U::dims, c%U::dims)) + !*/ + + template + qc_scalar_type compute_state_element ( + const matrix_exp& reg, + long row_idx + ) const; + /*! + requires + - reg.nr() == dims + - reg.nc() == 1 + - 0 <= row_idx < dims + ensures + - Let M represent the matrix for this gate, then this function + returns rowm(M*reg, row_idx) + (i.e. returns the row_idx row of what you get when you apply this + gate to the given column vector in reg) + - This function works by calling rhs.compute_state_element() and using elements + of the matrix in lhs. + !*/ + + static const long num_bits; + static const long dims; + + const T lhs; + const U rhs; + }; + +// ---------------------------------------------------------------------------------------- + + template + class gate : public gate_exp > + { + /*! + REQUIREMENTS ON bits + 0 < bits <= 30 + + WHAT THIS OBJECT REPRESENTS + This object represents a quantum gate that operates on bits qubits. + It stores its gate matrix explicitly in a dense in-memory matrix. + !*/ + + public: + gate( + ); + /*! + ensures + - num_bits == bits + - dims == 2^bits + - #&ref() == this + - for all valid r and c: + #(*this)(r,c) == 0 + !*/ + + gate ( + const gate& g + ); + /*! + ensures + - *this is a copy of g + !*/ + + template + explicit gate( + const gate_exp& g + ); + /*! + requires + - T::num_bits == num_bits + ensures + - num_bits == bits + - dims == 2^bits + - #&ref() == this + - for all valid r and c: + #(*this)(r,c) == g(r,c) + !*/ + + const qc_scalar_type& operator() ( + long r, + long c + ) const; + /*! + requires + - 0 <= r < dims + - 0 <= c < dims + ensures + - Let M denote the matrix for this gate, then this function + returns a const reference to M(r,c) + !*/ + + qc_scalar_type& operator() ( + long r, + long c + ); + /*! + requires + - 0 <= r < dims + - 0 <= c < dims + ensures + - Let M denote the matrix for this gate, then this function + returns a non-const reference to M(r,c) + !*/ + + template + qc_scalar_type compute_state_element ( + const matrix_exp& reg, + long row_idx + ) const; + /*! + requires + - reg.nr() == dims + - reg.nc() == 1 + - 0 <= row_idx < dims + ensures + - Let M represent the matrix for this gate, then this function + returns rowm(M*reg, row_idx) + (i.e. returns the row_idx row of what you get when you apply this + gate to the given column vector in reg) + !*/ + + static const long num_bits; + static const long dims; + + }; + +// ---------------------------------------------------------------------------------------- + + template + const composite_gate operator, ( + const gate_exp& lhs, + const gate_exp& rhs + ) { return composite_gate(lhs,rhs); } + /*! + ensures + - returns a composite_gate that represents the tensor product of the lhs + gate with the rhs gate. + !*/ + +// ---------------------------------------------------------------------------------------- + + namespace quantum_gates + { + + inline const gate<1> hadamard( + ); + /*! + ensures + - returns the Hadamard gate. + (i.e. A gate with a matrix of + |1, 1| + 1/sqrt(2) * |1,-1| ) + !*/ + + inline const gate<1> x( + ); + /*! + ensures + - returns the not gate. + (i.e. A gate with a matrix of + |0, 1| + |1, 0| ) + !*/ + + inline const gate<1> y( + ); + /*! + ensures + - returns the y gate. + (i.e. A gate with a matrix of + |0,-i| + |i, 0| ) + !*/ + + inline const gate<1> z( + ); + /*! + ensures + - returns the z gate. + (i.e. A gate with a matrix of + |1, 0| + |0,-1| ) + !*/ + + inline const gate<1> noop( + ); + /*! + ensures + - returns the no-op or identity gate. + (i.e. A gate with a matrix of + |1, 0| + |0, 1| ) + !*/ + + template < + int control_bit, + int target_bit + > + class cnot : public gate_exp > + { + /*! + REQUIREMENTS ON control_bit AND target_bit + - control_bit != target_bit + + WHAT THIS OBJECT REPRESENTS + This object represents the controlled-not quantum gate. It is a gate that + operates on abs(control_bit-target_bit)+1 qubits. + + In terms of the computational basis vectors, this gate maps input + vectors to output vectors in the following way: + - if (the input vector corresponds to a state where the control_bit + qubit is 1) then + - this gate outputs the computational basis vector that + corresponds to the state where the target_bit has been flipped + with respect to the input vector + - else + - this gate outputs the input vector unmodified + + !*/ + }; + + template < + int control_bit1, + int control_bit2, + int target_bit + > + class toffoli : public gate_exp > + { + /*! + REQUIREMENTS ON control_bit1, control_bit2, AND target_bit + - all the arguments denote different bits, i.e.: + - control_bit1 != target_bit + - control_bit2 != target_bit + - control_bit1 != control_bit2 + - The target bit can't be in-between the control bits, i.e.: + - (control_bit1 < target_bit && control_bit2 < target_bit) || + (control_bit1 > target_bit && control_bit2 > target_bit) + + WHAT THIS OBJECT REPRESENTS + This object represents the toffoli variant of a controlled-not quantum gate. + It is a gate that operates on max(abs(control_bit2-target_bit),abs(control_bit1-target_bit))+1 + qubits. + + In terms of the computational basis vectors, this gate maps input + vectors to output vectors in the following way: + - if (the input vector corresponds to a state where the control_bit1 and + control_bit2 qubits are 1) then + - this gate outputs the computational basis vector that + corresponds to the state where the target_bit has been flipped + with respect to the input vector + - else + - this gate outputs the input vector unmodified + + !*/ + }; + + // ------------------------------------------------------------------------------------ + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_QUANTUM_COMPUTINg_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/queue.h b/ml/dlib/dlib/queue.h new file mode 100644 index 000000000..b0f331dc9 --- /dev/null +++ b/ml/dlib/dlib/queue.h @@ -0,0 +1,84 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_QUEUe_ +#define DLIB_QUEUe_ + +#include "queue/queue_kernel_1.h" +#include "queue/queue_kernel_2.h" +#include "queue/queue_kernel_c.h" + +#include "queue/queue_sort_1.h" + + +#include "algs.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class queue + { + queue() {} + public: + + + //----------- kernels --------------- + + // kernel_1a + typedef queue_kernel_1 + kernel_1a; + typedef queue_kernel_c + kernel_1a_c; + + + // kernel_2a + typedef queue_kernel_2 + kernel_2a; + typedef queue_kernel_c + kernel_2a_c; + + + // kernel_2b + typedef queue_kernel_2 + kernel_2b; + typedef queue_kernel_c + kernel_2b_c; + + + + + //---------- extensions ------------ + + // sort_1 extend kernel_1a + typedef queue_sort_1 + sort_1a; + typedef queue_sort_1 + sort_1a_c; + + + // sort_1 extend kernel_2a + typedef queue_sort_1 + sort_1b; + typedef queue_sort_1 + sort_1b_c; + + + + // sort_1 extend kernel_2b + typedef queue_sort_1 + sort_1c; + typedef queue_sort_1 + sort_1c_c; + + + + + + }; +} + +#endif // DLIB_QUEUe_ + diff --git a/ml/dlib/dlib/queue/queue_kernel_1.h b/ml/dlib/dlib/queue/queue_kernel_1.h new file mode 100644 index 000000000..e59bf2659 --- /dev/null +++ b/ml/dlib/dlib/queue/queue_kernel_1.h @@ -0,0 +1,554 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_QUEUE_KERNEl_1_ +#define DLIB_QUEUE_KERNEl_1_ + +#include "queue_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class queue_kernel_1 : public enumerable, + public remover + { + + /*! + INITIAL VALUE + queue_size == 0 + current_element == 0 + at_start_ == true + + CONVENTION + queue_size == the number of elements in the queue + at_start() == at_start_ + current_element_valid() == (current_element != 0) + element() == current_element->item + + if (queue_size > 0) + { + in points to the last element to be inserted into the queue + out points to the next element to be dequeued + + each node points to the node inserted after it except for the most + recently inserted node + + current_element == 0 + } + + !*/ + + + struct node + { + node* last; + + T item; + }; + + + public: + + typedef T type; + typedef mem_manager mem_manager_type; + + queue_kernel_1 ( + ) : + in(0), + out(0), + queue_size(0), + current_element(0), + at_start_(true) + { + } + + virtual ~queue_kernel_1 ( + ); + + inline void clear( + ); + + void enqueue ( + T& item + ); + + void dequeue ( + T& item + ); + + void cat ( + queue_kernel_1& item + ); + + T& current ( + ); + + const T& current ( + ) const; + + void swap ( + queue_kernel_1& item + ); + + // functions from the remover interface + inline void remove_any ( + T& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + inline const T& element ( + ) const; + + inline T& element ( + ); + + bool move_next ( + ) const; + + private: + + void delete_nodes ( + node* start, + unsigned long length + ); + /*! + requires + - start points to a node in a singly linked list + - start->last points to the next node in the list + - there are at least length nodes in the list begining with start + ensures + - length nodes have been deleted starting with the node pointed + to by start + !*/ + + // data members + + node* in; + node* out; + unsigned long queue_size; + mutable node* current_element; + mutable bool at_start_; + + // restricted functions + queue_kernel_1(queue_kernel_1&); // copy constructor + queue_kernel_1& operator=(queue_kernel_1&); // assignment operator + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + queue_kernel_1& a, + queue_kernel_1& b + ) { a.swap(b); } + + template < + typename T, + typename mem_manager + > + void deserialize ( + queue_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + T temp; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(temp,in); + item.enqueue(temp); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type queue_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + queue_kernel_1:: + ~queue_kernel_1 ( + ) + { + delete_nodes(out,queue_size); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void queue_kernel_1:: + clear ( + ) + { + delete_nodes(out,queue_size); + queue_size = 0; + + // put the enumerator at the start + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void queue_kernel_1:: + enqueue ( + T& item + ) + { + // make new node + node* temp = new node; + + // swap item into new node + exchange(item,temp->item); + + if (queue_size == 0) + out = temp; + else + in->last = temp; + + // make in point to the new node + in = temp; + + ++queue_size; + + // put the enumerator at the start + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void queue_kernel_1:: + dequeue ( + T& item + ) + { + // swap out into item + exchange(item,out->item); + + --queue_size; + + if (queue_size == 0) + { + delete out; + } + else + { + node* temp = out; + + // move out pointer to the next element in the queue + out = out->last; + + // delete old node + delete temp; + } + + // put the enumerator at the start + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void queue_kernel_1:: + cat ( + queue_kernel_1& item + ) + { + if (item.queue_size > 0) + { + if (queue_size > 0) + { + in->last = item.out; + } + else + { + out = item.out; + } + + + in = item.in; + queue_size += item.queue_size; + item.queue_size = 0; + } + + // put the enumerator at the start + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& queue_kernel_1:: + current ( + ) + { + return out->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& queue_kernel_1:: + current ( + ) const + { + return out->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void queue_kernel_1:: + swap ( + queue_kernel_1& item + ) + { + node* in_temp = in; + node* out_temp = out; + unsigned long queue_size_temp = queue_size; + node* current_element_temp = current_element; + bool at_start_temp = at_start_; + + in = item.in; + out = item.out; + queue_size = item.queue_size; + current_element = item.current_element; + at_start_ = item.at_start_; + + item.in = in_temp; + item.out = out_temp; + item.queue_size = queue_size_temp; + item.current_element = current_element_temp; + item.at_start_ = at_start_temp; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool queue_kernel_1:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t queue_kernel_1:: + size ( + ) const + { + return queue_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void queue_kernel_1:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool queue_kernel_1:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& queue_kernel_1:: + element ( + ) const + { + return current_element->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& queue_kernel_1:: + element ( + ) + { + return current_element->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool queue_kernel_1:: + move_next ( + ) const + { + if (at_start_) + { + at_start_ = false; + // if the queue is empty then there is nothing to do + if (queue_size == 0) + { + return false; + } + else + { + current_element = out; + return true; + } + } + else + { + // if we are at the last element then the enumeration has finished + if (current_element == in || current_element == 0) + { + current_element = 0; + return false; + } + else + { + current_element = current_element->last; + return true; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // remover function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void queue_kernel_1:: + remove_any ( + T& item + ) + { + dequeue(item); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void queue_kernel_1:: + delete_nodes ( + node* start, + unsigned long length + ) + { + node* temp; + while (length) + { + temp = start->last; + delete start; + start = temp; + --length; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_QUEUE_KERNEl_1_ + diff --git a/ml/dlib/dlib/queue/queue_kernel_2.h b/ml/dlib/dlib/queue/queue_kernel_2.h new file mode 100644 index 000000000..8e4536ae9 --- /dev/null +++ b/ml/dlib/dlib/queue/queue_kernel_2.h @@ -0,0 +1,600 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_QUEUE_KERNEl_2_ +#define DLIB_QUEUE_KERNEl_2_ + +#include "queue_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename T, + unsigned long block_size, + typename mem_manager = default_memory_manager + > + class queue_kernel_2 : public enumerable, + public remover + { + + /*! + REQUIREMENTS ON block_size + 0 < block_size < 2000000000 + + INITIAL VALUE + queue_size == 0 + current_element == 0 + at_start_ == true + + CONVENTION + queue_size == the number of elements in the queue + at_start() == at_start_ + current_element_valid() == (current_element != 0) + if (current_element_valid()) then + element() == current_element->item[current_element_pos] + + if (queue_size > 0) + { + in->item[in_pos] == the spot where we will put the next item added + into the queue + out->item[out_pos] == current() + + when enqueuing elements inside each node item[0] is filled first, then + item[1], then item[2], etc. + + + each node points to the node inserted after it except for the most + recently inserted node. + } + + !*/ + + + struct node + { + node* next; + + T item[block_size]; + }; + + + public: + + typedef T type; + typedef mem_manager mem_manager_type; + + queue_kernel_2 ( + ) : + in(0), + out(0), + queue_size(0), + current_element(0), + at_start_(true) + { + } + + virtual ~queue_kernel_2 ( + ); + + inline void clear( + ); + + void enqueue ( + T& item + ); + + void dequeue ( + T& item + ); + + void cat ( + queue_kernel_2& item + ); + + T& current ( + ); + + const T& current ( + ) const; + + void swap ( + queue_kernel_2& item + ); + + // functions from the remover interface + inline void remove_any ( + T& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + inline const T& element ( + ) const; + + inline T& element ( + ); + + bool move_next ( + ) const; + + private: + + void delete_nodes ( + node* start, + node* end + ); + /*! + requires + - start points to a node in a singly linked list + - start->next points to the next node in the list + - by following the next pointers you eventually hit the node pointed + to by end + ensures + - calls delete on the start node, the end node, and all nodes in between + !*/ + + // data members + + typename mem_manager::template rebind::other pool; + + node* in; + node* out; + size_t queue_size; + size_t in_pos; + size_t out_pos; + + + mutable node* current_element; + mutable size_t current_element_pos; + mutable bool at_start_; + + // restricted functions + queue_kernel_2(queue_kernel_2&); // copy constructor + queue_kernel_2& operator=(queue_kernel_2&); // assignment operator + + }; + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + inline void swap ( + queue_kernel_2& a, + queue_kernel_2& b + ) { a.swap(b); } + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + void deserialize ( + queue_kernel_2& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + T temp; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(temp,in); + item.enqueue(temp); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type queue_kernel_2"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + queue_kernel_2:: + ~queue_kernel_2 ( + ) + { + COMPILE_TIME_ASSERT(0 < block_size && block_size < (unsigned long)(2000000000)); + + if (queue_size > 0) + delete_nodes(out,in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + void queue_kernel_2:: + clear ( + ) + { + if (queue_size > 0) + { + delete_nodes(out,in); + queue_size = 0; + } + + // put the enumerator at the start + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + void queue_kernel_2:: + enqueue ( + T& item + ) + { + if (queue_size == 0) + { + out = in = pool.allocate(); + in_pos = 0; + out_pos = 0; + } + else if (in_pos >= block_size) + { + in->next = pool.allocate(); + in_pos = 0; + in = in->next; + } + + exchange(item,in->item[in_pos]); + ++in_pos; + + ++queue_size; + + // put the enumerator at the start + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + void queue_kernel_2:: + dequeue ( + T& item + ) + { + // swap out into item + exchange(item,out->item[out_pos]); + + ++out_pos; + --queue_size; + + // if this was the last element in this node then remove this node + if (out_pos == block_size) + { + out_pos = 0; + node* temp = out; + out = out->next; + pool.deallocate(temp); + } + else if (queue_size == 0) + { + pool.deallocate(out); + } + + // put the enumerator at the start + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + void queue_kernel_2:: + cat ( + queue_kernel_2& item + ) + { + if (queue_size > 0) + { + T temp; + assign_zero_if_built_in_scalar_type(temp); + while (item.size() > 0) + { + item.dequeue(temp); + enqueue(temp); + } + } + else + { + in = item.in; + out = item.out; + out_pos = item.out_pos; + in_pos = item.in_pos; + + queue_size = item.queue_size; + item.queue_size = 0; + + // put the enumerator at the start + reset(); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + T& queue_kernel_2:: + current ( + ) + { + return out->item[out_pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + const T& queue_kernel_2:: + current ( + ) const + { + return out->item[out_pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + void queue_kernel_2:: + swap ( + queue_kernel_2& item + ) + { + exchange(in,item.in); + exchange(out,item.out); + exchange(queue_size,item.queue_size); + exchange(in_pos,item.in_pos); + exchange(out_pos,item.out_pos); + exchange(current_element,item.current_element); + exchange(current_element_pos,item.current_element_pos); + exchange(at_start_,item.at_start_); + pool.swap(item.pool); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + size_t queue_kernel_2:: + size ( + ) const + { + return queue_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + bool queue_kernel_2:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + void queue_kernel_2:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + bool queue_kernel_2:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + const T& queue_kernel_2:: + element ( + ) const + { + return current_element->item[current_element_pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + T& queue_kernel_2:: + element ( + ) + { + return current_element->item[current_element_pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + bool queue_kernel_2:: + move_next ( + ) const + { + if (at_start_) + { + at_start_ = false; + // if the queue is empty then there is nothing to do + if (queue_size == 0) + { + return false; + } + else + { + current_element = out; + current_element_pos = out_pos; + return true; + } + } + else if (current_element == 0) + { + return false; + } + else + { + ++current_element_pos; + // if we are at the last element then the enumeration has finished + if (current_element == in && current_element_pos == in_pos ) + { + current_element = 0; + return false; + } + else if (current_element_pos == block_size) + { + current_element_pos = 0; + current_element = current_element->next; + } + + return true; + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // remover function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + void queue_kernel_2:: + remove_any ( + T& item + ) + { + dequeue(item); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + unsigned long block_size, + typename mem_manager + > + void queue_kernel_2:: + delete_nodes ( + node* start, + node* end + ) + { + node* temp; + while (start != end) + { + temp = start; + start = start->next; + pool.deallocate(temp); + } + pool.deallocate(start); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_QUEUE_KERNEl_2_ + diff --git a/ml/dlib/dlib/queue/queue_kernel_abstract.h b/ml/dlib/dlib/queue/queue_kernel_abstract.h new file mode 100644 index 000000000..4fd4c7dd1 --- /dev/null +++ b/ml/dlib/dlib/queue/queue_kernel_abstract.h @@ -0,0 +1,196 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_QUEUE_KERNEl_ABSTRACT_ +#ifdef DLIB_QUEUE_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class queue : public enumerable, + public remover + { + + /*! + REQUIREMENTS ON T + T must be swappable by a global swap() + T must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + swap() and current() functions do not invalidate pointers or + references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements in the queue in the + same order they would be removed by repeated calls to dequeue(). + (e.g. current() would be the first element enumerated) + + WHAT THIS OBJECT REPRESENTS + This is a first in first out queue containing items of type T + + e.g. if the queue is {b,c,d,e} and then 'a' is enqueued + the queue becomes {a,b,c,d,e} and then calling dequeue takes e out + making the queue {a,b,c,d} + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + !*/ + + public: + + typedef T type; + typedef mem_manager mem_manager_type; + + queue ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + virtual ~queue ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void enqueue ( + T& item + ); + /*! + ensures + - item is now at the left end of #*this + - #item has an initial value for its type + - #size() == size() + 1 + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + if enqueue() throws then it has no effect + !*/ + + void dequeue ( + T& item + ); + /*! + requires + - size() != 0 + ensures + - #size() == size() - 1 + - the far right element of *this has been removed and swapped + into #item + - #at_start() == true + !*/ + + void cat ( + queue& item + ); + /*! + ensures + - item has been concatenated onto the left end of *this. + i.e. item.current() is attached onto the left end of *this and + the left most element in item will also be the left most item + in #*this + - #size() == size() + item.size() + - #item has its initial value + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + if cat() throws then the state of #item and *this is undefined + until clear() is successfully called on them. + !*/ + + T& current ( + ); + /*! + requires + - size() != 0 + ensures + - returns a const reference to the next element to be dequeued. + i.e. the right most element. + !*/ + + const T& current ( + ) const; + /*! + requires + - size() != 0 + ensures + - returns a non-const reference to the next element to be dequeued. + i.e. the right most element. + !*/ + + void swap ( + queue& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + queue(queue&); // copy constructor + queue& operator=(queue&); // assignment operator + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + queue& a, + queue& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename mem_manager + > + void deserialize ( + queue& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_QUEUE_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/queue/queue_kernel_c.h b/ml/dlib/dlib/queue/queue_kernel_c.h new file mode 100644 index 000000000..554c9aecd --- /dev/null +++ b/ml/dlib/dlib/queue/queue_kernel_c.h @@ -0,0 +1,187 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_QUEUE_KERNEl_C_ +#define DLIB_QUEUE_KERNEl_C_ + +#include "queue_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + + template < + typename queue_base // is an implementation of queue_kernel_abstract.h + > + class queue_kernel_c : public queue_base + { + typedef typename queue_base::type T; + + public: + + void dequeue ( + T& item + ); + + T& current ( + ); + + const T& current ( + ) const; + + const T& element ( + ) const; + + T& element ( + ); + + void remove_any ( + T& item + ); + + }; + + template < + typename queue_base + > + inline void swap ( + queue_kernel_c& a, + queue_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename queue_base + > + void queue_kernel_c:: + dequeue ( + T& item + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(this->size() != 0, + "\tvoid queue::dequeue" + << "\n\tsize of queue should not be zero" + << "\n\tthis: " << this + ); + + // call the real function + queue_base::dequeue(item); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_base + > + const typename queue_base::type& queue_kernel_c:: + current ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->size() != 0, + "\tconst T& queue::current" + << "\n\tsize of queue should not be zero" + << "\n\tthis: " << this + ); + + // call the real function + return queue_base::current(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_base + > + typename queue_base::type& queue_kernel_c:: + current ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->size() != 0, + "\tT& queue::current" + << "\n\tsize of queue should not be zero" + << "\n\tthis: " << this + ); + + // call the real function + return queue_base::current(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_base + > + const typename queue_base::type& queue_kernel_c:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst T& queue::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return queue_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_base + > + typename queue_base::type& queue_kernel_c:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tT& queue::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return queue_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename queue_base + > + void queue_kernel_c:: + remove_any ( + T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (this->size() > 0), + "\tvoid queue::remove_any" + << "\n\tsize() must be greater than zero if something is going to be removed" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + // call the real function + queue_base::remove_any(item); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_QUEUE_KERNEl_C_ + diff --git a/ml/dlib/dlib/queue/queue_sort_1.h b/ml/dlib/dlib/queue/queue_sort_1.h new file mode 100644 index 000000000..bc20dcb82 --- /dev/null +++ b/ml/dlib/dlib/queue/queue_sort_1.h @@ -0,0 +1,165 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_QUEUE_SORt_1_ +#define DLIB_QUEUE_SORt_1_ + +#include "queue_sort_abstract.h" +#include "../algs.h" +#include +#include "../sort.h" + +namespace dlib +{ + + template < + typename queue_base + > + class queue_sort_1 : public queue_base + { + typedef typename queue_base::type T; + + public: + + /*! + This implementation uses the QuickSort algorithm and + when the quicksort depth goes too high it uses the dlib::qsort_array() + function on the data. + !*/ + + void sort ( + ); + + template + void sort ( + const compare_type& compare + ) + { + if (this->size() > 1) + { + sort_this_queue(*this,0,compare); + } + } + + private: + + template + void sort_this_queue ( + queue_base& queue, + long depth, + const compare_type& compare + ) + /*! + ensures + each element in the queue is < the element behind it according + to compare + !*/ + { + if (queue.size() <= 1) + { + // already sorted + } + else if (queue.size() <= 29) + { + T vect[29]; + const unsigned long size = queue.size(); + for (unsigned long i = 0; i < size; ++i) + { + queue.dequeue(vect[i]); + } + isort_array(vect,0,size-1,compare); + for (unsigned long i = 0; i < size; ++i) + { + queue.enqueue(vect[i]); + } + } + else if (depth > 50) + { + std::vector vect(queue.size()); + for (unsigned long i = 0; i < vect.size(); ++i) + { + queue.dequeue(vect[i]); + } + hsort_array(vect,0,vect.size()-1,compare); + for (unsigned long i = 0; i < vect.size(); ++i) + { + queue.enqueue(vect[i]); + } + } + else + { + queue_base left, right; + T partition_element; + T temp; + // do this just to avoid a compiler warning + assign_zero_if_built_in_scalar_type(temp); + assign_zero_if_built_in_scalar_type(partition_element); + + queue.dequeue(partition_element); + + // partition queue into left and right + while (queue.size() > 0) + { + queue.dequeue(temp); + if (compare(temp , partition_element)) + { + left.enqueue(temp); + } + else + { + right.enqueue(temp); + } + } + + + long ratio; + if (left.size() > right.size()) + ratio = left.size()/(right.size()+1); // add 1 so we can't divide by zero + else + ratio = right.size()/(left.size()+1); + + sort_this_queue(left,ratio+depth,compare); + sort_this_queue(right,ratio+depth,compare); + + // combine the two queues + left.swap(queue); + queue.enqueue(partition_element); + queue.cat(right); + } + } + + + }; + + template < + typename queue_base + > + inline void swap ( + queue_sort_1& a, + queue_sort_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename queue_base + > + void queue_sort_1:: + sort ( + ) + { + if (this->size() > 1) + { + sort_this_queue(*this,0,std::less()); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_QUEUE_SORt_1_ + diff --git a/ml/dlib/dlib/queue/queue_sort_abstract.h b/ml/dlib/dlib/queue/queue_sort_abstract.h new file mode 100644 index 000000000..54c06f430 --- /dev/null +++ b/ml/dlib/dlib/queue/queue_sort_abstract.h @@ -0,0 +1,74 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_QUEUE_SORt_ABSTRACT_ +#ifdef DLIB_QUEUE_SORt_ABSTRACT_ + + +#include "queue_kernel_abstract.h" + +namespace dlib +{ + + template < + typename queue_base + > + class queue_sort : public queue_base + { + + /*! + REQUIREMENTS ON QUEUE_BASE + - is an implementation of queue/queue_kernel_abstract.h + - queue_base::type must be a type with that is comparable via operator< + + POINTERS AND REFERENCES TO INTERNAL DATA + sort() may invalidate pointers and references to internal data. + + WHAT THIS EXTENSION DOES FOR QUEUE + This gives a queue the ability to sort its contents by calling sort(). + !*/ + + + public: + + void sort ( + ); + /*! + ensures + - for all elements in #*this the ith element is <= the i+1 element + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + data may be lost if sort() throws + !*/ + + template + void sort ( + const compare_type& compare + ); + /*! + ensures + - for all elements in #*this the ith element is <= the i+1 element + - uses compare(a,b) as the < operator. So if compare(a,b) == true + then a comes before b in the resulting ordering. + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + data may be lost if sort() throws + !*/ + }; + + template < + template queue_base + > + inline void swap ( + queue_sort& a, + queue_sort& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_QUEUE_SORt_ABSTRACT_ + diff --git a/ml/dlib/dlib/rand.h b/ml/dlib/dlib/rand.h new file mode 100644 index 000000000..5e0146f0d --- /dev/null +++ b/ml/dlib/dlib/rand.h @@ -0,0 +1,9 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RANd_ +#define DLIB_RANd_ + +#include "rand/rand_kernel_1.h" + +#endif // DLIB_RANd_ + diff --git a/ml/dlib/dlib/rand/mersenne_twister.h b/ml/dlib/dlib/rand/mersenne_twister.h new file mode 100644 index 000000000..66ee87f54 --- /dev/null +++ b/ml/dlib/dlib/rand/mersenne_twister.h @@ -0,0 +1,210 @@ +/* boost random/mersenne_twister.hpp header file + * + * Copyright Jens Maurer 2000-2001 + * Distributed under the Boost Software License, Version 1.0. (See + * accompanying file LICENSE_1_0.txt or copy at + * http://www.boost.org/LICENSE_1_0.txt) + * + * See http://www.boost.org for most recent version including documentation. + * + * $Id: mersenne_twister.hpp,v 1.20 2005/07/21 22:04:31 jmaurer Exp $ + * + * Revision history + * 2001-02-18 moved to individual header files +*/ + +#ifndef DLIB_BOOST_RANDOM_MERSENNE_TWISTER_HPP +#define DLIB_BOOST_RANDOM_MERSENNE_TWISTER_HPP + +#include +#include // std::copy +#include +#include "../uintn.h" +#include "../serialize.h" + +namespace dlib +{ + namespace random_helpers + { + + // ------------------------------------------------------------------------------------ + + // http://www.math.keio.ac.jp/matumoto/emt.html + template< + class UIntType, + int w, + int n, + int m, + int r, + UIntType a, + int u, + int s, + UIntType b, + int t, + UIntType c, + int l, + UIntType val + > + class mersenne_twister + { + public: + typedef UIntType result_type; + const static int word_size = w; + const static int state_size = n; + const static int shift_size = m; + const static int mask_bits = r; + const static UIntType parameter_a = a; + const static int output_u = u; + const static int output_s = s; + const static UIntType output_b = b; + const static int output_t = t; + const static UIntType output_c = c; + const static int output_l = l; + + const static bool has_fixed_range = false; + + mersenne_twister() { seed(); } + + explicit mersenne_twister(UIntType value) { seed(value); } + + void seed () { seed(UIntType(5489)); } + + // compiler-generated copy ctor and assignment operator are fine + + void seed(UIntType value) + { + // New seeding algorithm from + // http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/emt19937ar.html + // In the previous versions, MSBs of the seed affected only MSBs of the + // state x[]. + const UIntType mask = ~0u; + x[0] = value & mask; + for (i = 1; i < n; i++) { + // See Knuth "The Art of Computer Programming" Vol. 2, 3rd ed., page 106 + x[i] = (1812433253UL * (x[i-1] ^ (x[i-1] >> (w-2))) + i) & mask; + } + } + + result_type min() const { return 0; } + result_type max() const + { + // avoid "left shift count >= with of type" warning + result_type res = 0; + for(int i = 0; i < w; ++i) + res |= (1u << i); + return res; + } + + result_type operator()(); + + friend void serialize( + const mersenne_twister& item, + std::ostream& out + ) + { + dlib::serialize(item.x, out); + dlib::serialize(item.i, out); + } + + friend void deserialize( + mersenne_twister& item, + std::istream& in + ) + { + dlib::deserialize(item.x, in); + dlib::deserialize(item.i, in); + } + + private: + + void twist(int block); + + // state representation: next output is o(x(i)) + // x[0] ... x[k] x[k+1] ... x[n-1] x[n] ... x[2*n-1] represents + // x(i-k) ... x(i) x(i+1) ... x(i-k+n-1) x(i-k-n) ... x[i(i-k-1)] + // The goal is to always have x(i-n) ... x(i-1) available for + // operator== and save/restore. + + UIntType x[2*n]; + int i; + }; + + // ------------------------------------------------------------------------------------ + + template< + class UIntType, int w, int n, int m, int r, UIntType a, int u, + int s, UIntType b, int t, UIntType c, int l, UIntType val + > + void mersenne_twister::twist( + int block + ) + { + const UIntType upper_mask = (~0u) << r; + const UIntType lower_mask = ~upper_mask; + + if(block == 0) { + for(int j = n; j < 2*n; j++) { + UIntType y = (x[j-n] & upper_mask) | (x[j-(n-1)] & lower_mask); + x[j] = x[j-(n-m)] ^ (y >> 1) ^ (y&1 ? a : 0); + } + } else if (block == 1) { + // split loop to avoid costly modulo operations + { // extra scope for MSVC brokenness w.r.t. for scope + for(int j = 0; j < n-m; j++) { + UIntType y = (x[j+n] & upper_mask) | (x[j+n+1] & lower_mask); + x[j] = x[j+n+m] ^ (y >> 1) ^ (y&1 ? a : 0); + } + } + + for(int j = n-m; j < n-1; j++) { + UIntType y = (x[j+n] & upper_mask) | (x[j+n+1] & lower_mask); + x[j] = x[j-(n-m)] ^ (y >> 1) ^ (y&1 ? a : 0); + } + // last iteration + UIntType y = (x[2*n-1] & upper_mask) | (x[0] & lower_mask); + x[n-1] = x[m-1] ^ (y >> 1) ^ (y&1 ? a : 0); + i = 0; + } + } + + // ------------------------------------------------------------------------------------ + + template< + class UIntType, int w, int n, int m, int r, UIntType a, int u, + int s, UIntType b, int t, UIntType c, int l, UIntType val + > + inline typename mersenne_twister::result_type + mersenne_twister::operator()( + ) + { + if(i == n) + twist(0); + else if(i >= 2*n) + twist(1); + // Step 4 + UIntType z = x[i]; + ++i; + z ^= (z >> u); + z ^= ((z << s) & b); + z ^= ((z << t) & c); + z ^= (z >> l); + return z; + } + + // ------------------------------------------------------------------------------------ + + } // namespace random + + + typedef random_helpers::mersenne_twister mt11213b; + + // validation by experiment from mt19937.c + typedef random_helpers::mersenne_twister mt19937; + +} // namespace dlib + + +#endif // DLIB_BOOST_RANDOM_MERSENNE_TWISTER_HPP + diff --git a/ml/dlib/dlib/rand/rand_kernel_1.h b/ml/dlib/dlib/rand/rand_kernel_1.h new file mode 100644 index 000000000..a1847be24 --- /dev/null +++ b/ml/dlib/dlib/rand/rand_kernel_1.h @@ -0,0 +1,354 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RAND_KERNEl_1_ +#define DLIB_RAND_KERNEl_1_ + +#include +#include "../algs.h" +#include "rand_kernel_abstract.h" +#include "mersenne_twister.h" +#include "../is_kind.h" +#include +#include "../serialize.h" +#include "../string.h" + +namespace dlib +{ + + + class rand + { + + /*! + INITIAL VALUE + - seed == "" + + CONVENTION + - the random numbers come from the boost mersenne_twister code + - get_seed() == seed + !*/ + + public: + + // These typedefs are here for backwards compatibility with older versions of dlib. + typedef rand kernel_1a; + typedef rand float_1a; + + rand( + ) + { + init(); + } + + rand ( + time_t seed_value + ) + { + init(); + set_seed(cast_to_string(seed_value)); + } + + rand ( + const std::string& seed_value + ) + { + init(); + set_seed(seed_value); + } + + virtual ~rand( + ) + {} + + void clear( + ) + { + mt.seed(); + seed.clear(); + + has_gaussian = false; + next_gaussian = 0; + + // prime the generator a bit + for (int i = 0; i < 10000; ++i) + mt(); + } + + const std::string& get_seed ( + ) + { + return seed; + } + + void set_seed ( + const std::string& value + ) + { + seed = value; + + // make sure we do the seeding so that using a seed of "" gives the same + // state as calling this->clear() + if (value.size() != 0) + { + uint32 s = 0; + for (std::string::size_type i = 0; i < seed.size(); ++i) + { + s = (s*37) + static_cast(seed[i]); + } + mt.seed(s); + } + else + { + mt.seed(); + } + + // prime the generator a bit + for (int i = 0; i < 10000; ++i) + mt(); + + + has_gaussian = false; + next_gaussian = 0; + } + + unsigned char get_random_8bit_number ( + ) + { + return static_cast(mt()); + } + + uint16 get_random_16bit_number ( + ) + { + return static_cast(mt()); + } + + inline uint32 get_random_32bit_number ( + ) + { + return mt(); + } + + inline uint64 get_random_64bit_number ( + ) + { + const uint64 a = get_random_32bit_number(); + const uint64 b = get_random_32bit_number(); + return (a<<32)|b; + } + + double get_double_in_range ( + double begin, + double end + ) + { + DLIB_ASSERT(begin <= end); + return begin + get_random_double()*(end-begin); + } + + long long get_integer_in_range( + long long begin, + long long end + ) + { + DLIB_ASSERT(begin <= end); + if (begin == end) + return begin; + + auto r = get_random_64bit_number(); + const auto limit = std::numeric_limits::max(); + const auto range = end-begin; + // Use rejection sampling to remove the biased sampling you would get with + // the naive get_random_64bit_number()%range sampling. + while(r >= (limit/range)*range) + r = get_random_64bit_number(); + + return begin + static_cast(r%range); + } + + long long get_integer( + long long end + ) + { + DLIB_ASSERT(end >= 0); + + return get_integer_in_range(0,end); + } + + double get_random_double ( + ) + { + uint32 temp; + + temp = rand::get_random_32bit_number(); + temp &= 0xFFFFFF; + + double val = static_cast(temp); + + val *= 0x1000000; + + temp = rand::get_random_32bit_number(); + temp &= 0xFFFFFF; + + val += temp; + + val /= max_val; + + if (val < 1.0) + { + return val; + } + else + { + // return a value slightly less than 1.0 + return 1.0 - std::numeric_limits::epsilon(); + } + } + + float get_random_float ( + ) + { + uint32 temp; + + temp = rand::get_random_32bit_number(); + temp &= 0xFFFFFF; + + const float scale = 1.0/0x1000000; + + const float val = static_cast(temp)*scale; + if (val < 1.0f) + { + return val; + } + else + { + // return a value slightly less than 1.0 + return 1.0f - std::numeric_limits::epsilon(); + } + } + + double get_random_gaussian ( + ) + { + if (has_gaussian) + { + has_gaussian = false; + return next_gaussian; + } + + double x1, x2, w; + + const double rndmax = std::numeric_limits::max(); + + // Generate a pair of Gaussian random numbers using the Box-Muller transformation. + do + { + const double rnd1 = get_random_32bit_number()/rndmax; + const double rnd2 = get_random_32bit_number()/rndmax; + + x1 = 2.0 * rnd1 - 1.0; + x2 = 2.0 * rnd2 - 1.0; + w = x1 * x1 + x2 * x2; + } while ( w >= 1.0 ); + + w = std::sqrt( (-2.0 * std::log( w ) ) / w ); + next_gaussian = x2 * w; + has_gaussian = true; + return x1 * w; + } + + void swap ( + rand& item + ) + { + exchange(mt,item.mt); + exchange(seed, item.seed); + exchange(has_gaussian, item.has_gaussian); + exchange(next_gaussian, item.next_gaussian); + } + + friend void serialize( + const rand& item, + std::ostream& out + ); + + friend void deserialize( + rand& item, + std::istream& in + ); + + private: + + void init() + { + // prime the generator a bit + for (int i = 0; i < 10000; ++i) + mt(); + + max_val = 0xFFFFFF; + max_val *= 0x1000000; + max_val += 0xFFFFFF; + max_val += 0.05; + + + has_gaussian = false; + next_gaussian = 0; + } + + mt19937 mt; + + std::string seed; + + + double max_val; + bool has_gaussian; + double next_gaussian; + }; + + + inline void swap ( + rand& a, + rand& b + ) { a.swap(b); } + + + template <> + struct is_rand + { + static const bool value = true; + }; + + inline void serialize( + const rand& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + + serialize(item.mt, out); + serialize(item.seed, out); + serialize(item.has_gaussian, out); + serialize(item.next_gaussian, out); + } + + inline void deserialize( + rand& item, + std::istream& in + ) + { + int version; + deserialize(version, in); + if (version != 1) + throw serialization_error("Error deserializing object of type rand: unexpected version."); + + deserialize(item.mt, in); + deserialize(item.seed, in); + deserialize(item.has_gaussian, in); + deserialize(item.next_gaussian, in); + } +} + +#endif // DLIB_RAND_KERNEl_1_ + + diff --git a/ml/dlib/dlib/rand/rand_kernel_abstract.h b/ml/dlib/dlib/rand/rand_kernel_abstract.h new file mode 100644 index 000000000..67af27a9e --- /dev/null +++ b/ml/dlib/dlib/rand/rand_kernel_abstract.h @@ -0,0 +1,218 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RAND_KERNEl_ABSTRACT_ +#ifdef DLIB_RAND_KERNEl_ABSTRACT_ + +#include +#include "../uintn.h" + +namespace dlib +{ + + + class rand + { + + /*! + INITIAL VALUE + get_seed() == "" + + + WHAT THIS OBJECT REPRESENTS + This object represents a pseudorandom number generator. + !*/ + + public: + + + rand( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + rand ( + time_t seed_value + ); + /*! + ensures + - #*this is properly initialized + - #get_seed() == cast_to_string(seed_value) + - This version of the constructor is equivalent to using + the default constructor and then calling set_seed(cast_to_string(seed_value)) + throws + - std::bad_alloc + !*/ + + rand ( + const std::string& seed_value + ); + /*! + ensures + - #*this is properly initialized + - #get_seed() == seed_value + - This version of the constructor is equivalent to using + the default constructor and then calling set_seed(seed_value) + throws + - std::bad_alloc + !*/ + + virtual ~rand( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + const std::string& get_seed ( + ); + /*! + ensures + - returns the string currently being used as the random seed. + !*/ + + void set_seed ( + const std::string& value + ); + /*! + ensures + - #get_seed() == value + !*/ + + unsigned char get_random_8bit_number ( + ); + /*! + ensures + - returns a pseudorandom number in the range 0 to 255 + !*/ + + uint16 get_random_16bit_number ( + ); + /*! + ensures + - returns a pseudorandom number in the range 0 to 2^16-1 + !*/ + + uint32 get_random_32bit_number ( + ); + /*! + ensures + - returns a pseudorandom number in the range 0 to 2^32-1 + !*/ + + uint64 get_random_64bit_number ( + ); + /*! + ensures + - returns a pseudorandom number in the range 0 to 2^64-1 + !*/ + + float get_random_float ( + ); + /*! + ensures + - returns a random float number N where: 0.0 <= N < 1.0. + !*/ + + double get_random_double ( + ); + /*! + ensures + - returns a random double number N where: 0.0 <= N < 1.0. + !*/ + + double get_double_in_range ( + double begin, + double end + ); + /*! + requires + - begin <= end + ensures + - if (begin < end) then + - returns a random double number N where: begin <= N < end. + - else + - returns begin + !*/ + + long long get_integer_in_range( + long long begin, + long long end + ); + /*! + requires + - begin <= end + ensures + - returns a random integer selected from the range: begin <= N < end + The integer is selected uniformly at random. + !*/ + + long long get_integer( + long long end + ); + /*! + requires + - 0 <= end + ensures + - returns get_integer_in_range(0,end) + !*/ + + double get_random_gaussian ( + ); + /*! + ensures + - returns a random number sampled from a Gaussian distribution + with mean 0 and standard deviation 1. + !*/ + + void swap ( + rand& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + rand& a, + rand& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + void serialize ( + const rand& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + rand& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_RAND_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/random_forest.h b/ml/dlib/dlib/random_forest.h new file mode 100644 index 000000000..082f36703 --- /dev/null +++ b/ml/dlib/dlib/random_forest.h @@ -0,0 +1,10 @@ +// Copyright (C) 2018 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RANDOM_FOReST_H_ +#define DLIB_RANDOM_FOReST_H_ + +#include "random_forest/random_forest_regression.h" + +#endif // DLIB_RANDOM_FOReST_H_ + + diff --git a/ml/dlib/dlib/random_forest/random_forest_regression.h b/ml/dlib/dlib/random_forest/random_forest_regression.h new file mode 100644 index 000000000..a61f7a1a2 --- /dev/null +++ b/ml/dlib/dlib/random_forest/random_forest_regression.h @@ -0,0 +1,738 @@ +// Copyright (C) 2018 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RANdOM_FOREST_REGRESSION_H_ +#define DLIB_RANdOM_FOREST_REGRESSION_H_ + +#include "random_forest_regression_abstract.h" +#include +#include "../matrix.h" +#include +#include "../threads.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class dense_feature_extractor + { + + public: + typedef uint32_t feature; + typedef matrix sample_type; + + dense_feature_extractor( + ) = default; + + void setup ( + const std::vector& x, + const std::vector& y + ) + { + DLIB_CASSERT(x.size() > 0); + DLIB_CASSERT(x.size() == y.size()); + for (auto& el : x) + DLIB_CASSERT(el.size() == x[0].size(), "All the vectors in a training set have to have the same dimensionality."); + + DLIB_CASSERT(x[0].size() != 0, "The vectors can't be empty."); + + num_feats = x[0].size(); + } + + + void get_random_features ( + dlib::rand& rnd, + size_t num, + std::vector& feats + ) const + { + DLIB_ASSERT(max_num_feats() != 0); + num = std::min(num, num_feats); + + feats.clear(); + for (size_t i = 0; i < num_feats; ++i) + feats.push_back(i); + + // now pick num features at random + for (size_t i = 0; i < num; ++i) + { + auto idx = rnd.get_integer_in_range(i,num_feats); + std::swap(feats[i], feats[idx]); + } + feats.resize(num); + } + + double extract_feature_value ( + const sample_type& item, + const feature& f + ) const + { + DLIB_ASSERT(max_num_feats() != 0); + return item(f); + } + + size_t max_num_feats ( + ) const + { + return num_feats; + } + + friend void serialize(const dense_feature_extractor& item, std::ostream& out) + { + serialize("dense_feature_extractor", out); + serialize(item.num_feats, out); + } + + friend void deserialize(dense_feature_extractor& item, std::istream& in) + { + check_serialized_version("dense_feature_extractor", in); + deserialize(item.num_feats, in); + } + + private: + size_t num_feats = 0; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + struct internal_tree_node + { + uint32_t left; + uint32_t right; + float split_threshold; + typename feature_extractor::feature split_feature; + }; + + template + void serialize(const internal_tree_node& item, std::ostream& out) + { + serialize(item.left, out); + serialize(item.right, out); + serialize(item.split_threshold, out); + serialize(item.split_feature, out); + } + + template + void deserialize(internal_tree_node& item, std::istream& in) + { + deserialize(item.left, in); + deserialize(item.right, in); + deserialize(item.split_threshold, in); + deserialize(item.split_feature, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor = dense_feature_extractor + > + class random_forest_regression_function + { + + public: + + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::sample_type sample_type; + + random_forest_regression_function( + ) = default; + + random_forest_regression_function ( + feature_extractor_type&& fe_, + std::vector>>&& trees_, + std::vector>&& leaves_ + ) : + fe(std::move(fe_)), + trees(std::move(trees_)), + leaves(std::move(leaves_)) + { + DLIB_ASSERT(trees.size() > 0); + DLIB_ASSERT(trees.size() == leaves.size(), "Every set of tree nodes has to have leaves"); +#ifdef ENABLE_ASSERTS + for (size_t i = 0; i < trees.size(); ++i) + { + DLIB_ASSERT(trees[i].size() > 0, "A tree can't have 0 leaves."); + for (auto& node : trees[i]) + { + DLIB_ASSERT(trees[i].size()+leaves[i].size() > node.left, "left node index in tree is too big. There is no associated tree node or leaf."); + DLIB_ASSERT(trees[i].size()+leaves[i].size() > node.right, "right node index in tree is too big. There is no associated tree node or leaf."); + } + } +#endif + } + + size_t get_num_trees( + ) const + { + return trees.size(); + } + + const std::vector>>& get_internal_tree_nodes ( + ) const { return trees; } + + const std::vector>& get_tree_leaves ( + ) const { return leaves; } + + const feature_extractor_type& get_feature_extractor ( + ) const { return fe; } + + double operator() ( + const sample_type& x + ) const + { + DLIB_ASSERT(get_num_trees() > 0); + + double accum = 0; + + for (size_t i = 0; i < trees.size(); ++i) + { + auto& tree = trees[i]; + // walk the tree to the leaf + uint32_t idx = 0; + while(idx < tree.size()) + { + auto feature_value = fe.extract_feature_value(x, tree[idx].split_feature); + if (feature_value < tree[idx].split_threshold) + idx = tree[idx].left; + else + idx = tree[idx].right; + } + // compute leaf index + accum += leaves[i][idx-tree.size()]; + } + + return accum/trees.size(); + } + + friend void serialize(const random_forest_regression_function& item, std::ostream& out) + { + serialize("random_forest_regression_function", out); + serialize(item.fe, out); + serialize(item.trees, out); + serialize(item.leaves, out); + } + + friend void deserialize(random_forest_regression_function& item, std::istream& in) + { + check_serialized_version("random_forest_regression_function", in); + deserialize(item.fe, in); + deserialize(item.trees, in); + deserialize(item.leaves, in); + } + + private: + + /*! + CONVENTION + - trees.size() == leaves.size() + - Any .left or .right index in trees that is larger than the number of + nodes in the tree references a leaf. Moreover, the index of the leaf is + computed by subtracting the number of nodes in the tree. + !*/ + + feature_extractor_type fe; + + // internal nodes of trees + std::vector>> trees; + // leaves of trees + std::vector> leaves; + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor = dense_feature_extractor + > + class random_forest_regression_trainer + { + public: + typedef feature_extractor feature_extractor_type; + typedef random_forest_regression_function trained_function_type; + typedef typename feature_extractor::sample_type sample_type; + + + random_forest_regression_trainer ( + ) = default; + + const feature_extractor_type& get_feature_extractor ( + ) const + { + return fe_; + } + + void set_feature_extractor ( + const feature_extractor_type& feat_extractor + ) + { + fe_ = feat_extractor; + } + + void set_seed ( + const std::string& seed + ) + { + random_seed = seed; + } + + const std::string& get_random_seed ( + ) const + { + return random_seed; + } + + size_t get_num_trees ( + ) const + { + return num_trees; + } + + void set_num_trees ( + size_t num + ) + { + DLIB_CASSERT(num > 0); + num_trees = num; + } + + void set_feature_subsampling_fraction ( + double frac + ) + { + DLIB_CASSERT(0 < frac && frac <= 1); + feature_subsampling_frac = frac; + } + + double get_feature_subsampling_frac( + ) const + { + return feature_subsampling_frac; + } + + void set_min_samples_per_leaf ( + size_t num + ) + { + DLIB_ASSERT(num > 0); + min_samples_per_leaf = num; + } + + size_t get_min_samples_per_leaf( + ) const + { + return min_samples_per_leaf; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + trained_function_type train ( + const std::vector& x, + const std::vector& y + ) const + { + std::vector junk; + return do_train(x,y,junk,false); + } + + trained_function_type train ( + const std::vector& x, + const std::vector& y, + std::vector& oob_values + ) const + { + return do_train(x,y,oob_values,true); + } + + private: + + trained_function_type do_train ( + const std::vector& x, + const std::vector& y, + std::vector& oob_values, + bool compute_oob_values + ) const + { + DLIB_CASSERT(x.size() == y.size()); + DLIB_CASSERT(x.size() > 0); + + feature_extractor_type fe = fe_; + fe.setup(x,y); + + DLIB_CASSERT(fe.max_num_feats() != 0); + + std::vector>> all_trees(num_trees); + std::vector> all_leaves(num_trees); + + const double sumy = sum(mat(y)); + + const size_t feats_per_node = std::max(1.0,std::round(fe.max_num_feats()*feature_subsampling_frac)); + + // Each tree couldn't have more than this many interior nodes. It might + // end up having less though. We need to know this value because the way + // we mark a left or right pointer in a tree as pointing to a leaf is by + // making its index larger than the number of interior nodes in the tree. + // But we don't know the tree's size before we finish building it. So we + // will use max_num_nodes as a proxy during tree construction and then go + // back and fix it once a tree's size is known. + const uint32_t max_num_nodes = y.size(); + + std::vector oob_hits; + + if (compute_oob_values) + { + oob_values.resize(y.size()); + oob_hits.resize(y.size()); + } + + + std::mutex m; + + // Calling build_tree(i) creates the ith tree and stores the results in + // all_trees and all_leaves. + auto build_tree = [&](long i) + { + dlib::rand rnd(random_seed + std::to_string(i)); + auto& tree = all_trees[i]; + auto& leaves = all_leaves[i]; + + // Check if there are fewer than min_samples_per_leaf and if so then + // don't make any tree. Just average the things and be done. + if (y.size() <= min_samples_per_leaf) + { + leaves.push_back(sumy/y.size()); + return; + } + + + // pick a random bootstrap of the data. + std::vector> idxs(y.size()); + for (auto& idx : idxs) + idx = std::make_pair(0,rnd.get_integer(y.size())); + + // We are going to use ranges_to_process as a stack that tracks which + // range of samples we are going to split next. + std::vector ranges_to_process; + // start with the root of the tree, i.e. the entire range of training + // samples. + ranges_to_process.emplace_back(sumy,0,y.size()); + // push an unpopulated root node into the tree. We will populate it + // when we process its corresponding range. + tree.emplace_back(); + + std::vector feats; + + while(ranges_to_process.size() > 0) + { + // Grab the next range/node to process. + const auto range = ranges_to_process.back(); + ranges_to_process.pop_back(); + + + // Get the split features we will consider at this node. + fe.get_random_features(rnd, feats_per_node, feats); + // Then find the best split + auto best_split = find_best_split_among_feats(fe, range, feats, x, y, idxs); + + range_t left_split(best_split.left_sum, range.begin, best_split.split_idx); + range_t right_split(best_split.right_sum, best_split.split_idx, range.end); + + DLIB_ASSERT(left_split.begin < left_split.end); + DLIB_ASSERT(right_split.begin < right_split.end); + + // Now that we know the split we can populate the parent node we popped + // from ranges_to_process. + tree[range.tree_idx].split_threshold = best_split.split_threshold; + tree[range.tree_idx].split_feature = best_split.split_feature; + + // If the left split is big enough to make a new interior leaf + // node. We also stop splitting if all the samples went into this node. + // This could happen if the features are all uniform so there just + // isn't any way to split them anymore. + if (left_split.size() > min_samples_per_leaf && right_split.size() != 0) + { + // allocate an interior leaf node for it. + left_split.tree_idx = tree.size(); + tree.emplace_back(); + // set the pointer in the parent node to the newly allocated + // node. + tree[range.tree_idx].left = left_split.tree_idx; + + ranges_to_process.emplace_back(left_split); + } + else + { + // Add to leaves. Don't forget to set the pointer in the + // parent node to the newly allocated leaf node. + tree[range.tree_idx].left = leaves.size() + max_num_nodes; + leaves.emplace_back(left_split.avg()); + } + + + // If the right split is big enough to make a new interior leaf + // node. We also stop splitting if all the samples went into this node. + // This could happen if the features are all uniform so there just + // isn't any way to split them anymore. + if (right_split.size() > min_samples_per_leaf && left_split.size() != 0) + { + // allocate an interior leaf node for it. + right_split.tree_idx = tree.size(); + tree.emplace_back(); + // set the pointer in the parent node to the newly allocated + // node. + tree[range.tree_idx].right = right_split.tree_idx; + + ranges_to_process.emplace_back(right_split); + } + else + { + // Add to leaves. Don't forget to set the pointer in the + // parent node to the newly allocated leaf node. + tree[range.tree_idx].right = leaves.size() + max_num_nodes; + leaves.emplace_back(right_split.avg()); + } + } // end while (still building tree) + + // Fix the leaf pointers in the tree now that we know the correct + // tree.size() value. + DLIB_CASSERT(max_num_nodes >= tree.size()); + const auto offset = max_num_nodes - tree.size(); + for (auto& n : tree) + { + if (n.left >= max_num_nodes) + n.left -= offset; + if (n.right >= max_num_nodes) + n.right -= offset; + } + + + if (compute_oob_values) + { + std::sort(idxs.begin(), idxs.end(), + [](const std::pair& a, const std::pair& b) {return a.second lock(m); + + size_t j = 0; + for (size_t i = 0; i < oob_values.size(); ++i) + { + // check if i is in idxs + while(j < idxs.size() && i > idxs[j].second) + ++j; + + // i isn't in idxs so it's an oob sample and we should process it. + if (j == idxs.size() || idxs[j].second != i) + { + oob_hits[i]++; + + // walk the tree to find the leaf value for this oob sample + uint32_t idx = 0; + while(idx < tree.size()) + { + auto feature_value = fe.extract_feature_value(x[i], tree[idx].split_feature); + if (feature_value < tree[idx].split_threshold) + idx = tree[idx].left; + else + idx = tree[idx].right; + } + oob_values[i] += leaves[idx-tree.size()]; + } + } + } + }; + + if (verbose) + parallel_for_verbose(0, num_trees, build_tree); + else + parallel_for(0, num_trees, build_tree); + + + if (compute_oob_values) + { + double meanval = 0; + double cnt = 0; + for (size_t i = 0; i < oob_values.size(); ++i) + { + if (oob_hits[i] != 0) + { + oob_values[i] /= oob_hits[i]; + meanval += oob_values[i]; + ++cnt; + } + } + + // If there are some elements that didn't get hits, we set their oob values + // to the mean oob value. + if (cnt != 0) + { + const double typical_value = meanval/cnt; + for (size_t i = 0; i < oob_values.size(); ++i) + { + if (oob_hits[i] == 0) + oob_values[i] = typical_value; + } + } + } + + return trained_function_type(std::move(fe), std::move(all_trees), std::move(all_leaves)); + } + + struct range_t + { + range_t( + double sumy, + uint32_t begin, + uint32_t end + ) : sumy(sumy), begin(begin), end(end), tree_idx(0) {} + + double sumy; + uint32_t begin; + uint32_t end; + + // Every range object corresponds to an entry in a tree. This tells you the + // tree node that owns the range. + uint32_t tree_idx; + + uint32_t size() const { return end-begin; } + double avg() const { return sumy/size(); } + }; + + struct best_split_details + { + double score = -std::numeric_limits::infinity(); + double left_sum; + double right_sum; + uint32_t split_idx; + double split_threshold; + typename feature_extractor::feature split_feature; + + bool operator < (const best_split_details& rhs) const + { + return score < rhs.score; + } + }; + + static best_split_details find_best_split ( + const range_t& range, + const std::vector& y, + const std::vector>& idxs + ) + /*! + requires + - max(mat(idxs)) < y.size() + - range.sumy == sum of y[idxs[j].second] for all valid j in range [range.begin, range.end). + ensures + - finds a threshold T such that there exists an i satisfying the following: + - y[idxs[j].second] < T for all j <= i + - y[idxs[j].second] > T for all j > i + Therefore, the threshold T partitions the contents of y into two groups, + relative to the ordering established by idxs. Moreover the partitioning + of y values into two groups has the additional requirement that it is + optimal in the sense that the sum of the squared deviations from each + partition's mean is minimized. + !*/ + { + + size_t best_i = range.begin; + double best_score = -1; + double left_sum = 0; + double best_left_sum = y[idxs[range.begin].second]; + const auto size = range.size(); + size_t left_size = 0; + for (size_t i = range.begin; i+1 < range.end; ++i) + { + ++left_size; + left_sum += y[idxs[i].second]; + + // Don't split here because the next element has the same feature value so + // we can't *really* split here. + if (idxs[i].first==idxs[i+1].first) + continue; + + const double right_sum = range.sumy-left_sum; + + const double score = left_sum*left_sum/left_size + right_sum*right_sum/(size-left_size); + + if (score > best_score) + { + best_score = score; + best_i = i; + best_left_sum = left_sum; + } + } + + best_split_details result; + result.score = best_score; + result.left_sum = best_left_sum; + result.right_sum = range.sumy-best_left_sum; + result.split_idx = best_i+1; // one past the end of the left range + result.split_threshold = (idxs[best_i].first+idxs[best_i+1].first)/2; + + return result; + } + + + static best_split_details find_best_split_among_feats( + const feature_extractor& fe, + const range_t& range, + const std::vector& feats, + const std::vector& x, + const std::vector& y, + std::vector>& idxs + ) + { + auto compare_first = [](const std::pair& a, const std::pair& b) { return a.first +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class dense_feature_extractor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for extracting features from objects. In particular, + it is designed to be used with the random forest regression tools discussed + below. + + This particular feature extract does almost nothing since it works on + vectors in R^n and simply selects elements from each vector. However, the + tools below are templated and allow you to design your own feature extractors + that operate on whatever object types you create. So for example, maybe + you want to perform regression on images rather than vectors. Moreover, + your feature extraction could be more complex. Maybe you are selecting + differences between pairs of pixels in an image or doing something + involving geometric transforms during feature extraction. Any of these + kinds of more complex feature extraction patterns can be realized with the + random forest tools by implementing your own feature extractor object and + using it with the random forest objects. + + Therefore, you should consider this dense_feature_extractor as an example + that documents the interface as well as the simple default extractor for + use with dense vectors. + + + THREAD SAFETY + It is safe to call const members of this object from multiple threads. ANY + USER DEFINED FEATURE EXTRACTORS MUST ALSO MEET THIS GUARONTEE AS WELL SINCE + IT IS ASSUMED BY THE RANDOM FOREST TRAINING ROUTINES. + !*/ + + public: + typedef uint32_t feature; + typedef matrix sample_type; + + dense_feature_extractor( + ); + /*! + ensures + - #max_num_feats() == 0 + !*/ + + void setup ( + const std::vector& x, + const std::vector& y + ); + /*! + requires + - x.size() == y.size() + - x.size() > 0 + - x[0].size() > 0 + - all the vectors in x have the same dimensionality. + ensures + - Configures this feature extractor to work on the given training data. + For dense feature extractors all we do is record the dimensionality of + the training vectors. + - #max_num_feats() == x[0].size() + (In general, setup() sets max_num_feats() to some non-zero value so that + the other methods of this object can then be called. The point of setup() + is to allow a feature extractor to gather whatever statistics it needs from + training data. That is, more complex feature extraction strategies my + themselves be trained from data.) + !*/ + + void get_random_features ( + dlib::rand& rnd, + size_t num, + std::vector& feats + ) const; + /*! + requires + - max_num_feats() != 0 + ensures + - #feats.size() == min(num, max_num_feats()) + - This function randomly identifies num features and stores them into feats. + These feature objects can then be used with extract_feature_value() to + obtain a value from any particular sample_type object. This value is the + "feature value" used by a decision tree algorithm to deice how to split + and traverse trees. + - The above two conditions define the behavior of get_random_features() in + general. For this specific implementation of the feature extraction interface + this function selects num integer values from the range [0, max_num_feats()), + without replacement. These values are stored into feats. + !*/ + + double extract_feature_value ( + const sample_type& item, + const feature& f + ) const; + /*! + requires + - #max_num_feats() != 0 + - f was produced from a call to get_random_features(). + ensures + - Extracts the feature value corresponding to f. For this simple dense + feature extractor this simply means returning item(f). But in general + you can design feature extractors that do something more complex. + !*/ + + size_t max_num_feats ( + ) const; + /*! + ensures + - returns the number of distinct features this object might extract. That is, + a feature extractor essentially defines a mapping from sample_type objects to + vectors in R^max_num_feats(). + !*/ + }; + + void serialize(const dense_feature_extractor& item, std::ostream& out); + void deserialize(dense_feature_extractor& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + struct internal_tree_node + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is an internal node in a regression tree. See the code of + random_forest_regression_function to see how it is used to create a tree. + !*/ + + uint32_t left; + uint32_t right; + float split_threshold; + typename feature_extractor::feature split_feature; + }; + + template + void serialize(const internal_tree_node& item, std::ostream& out); + template + void deserialize(internal_tree_node& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor = dense_feature_extractor + > + class random_forest_regression_function + { + /*! + REQUIREMENTS ON feature_extractor + feature_extractor must be dense_feature_extractor or a type with a + compatible interface. + + WHAT THIS OBJECT REPRESENTS + This object represents a regression forest. This is a collection of + decision trees that take an object as input and each vote on a real value + to associate with the object. The final real value output is the average + of all the votes from each of the trees. + !*/ + + public: + + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::sample_type sample_type; + + random_forest_regression_function( + ); + /*! + ensures + - #num_trees() == 0 + !*/ + + random_forest_regression_function ( + feature_extractor_type&& fe_, + std::vector>>&& trees_, + std::vector>&& leaves_ + ); + /*! + requires + - trees.size() > 0 + - trees.size() = leaves.size() + - for all valid i: + - leaves[i].size() > 0 + - trees[i].size()+leaves[i].size() > the maximal left or right index values in trees[i]. + (i.e. each left or right value must index to some existing internal tree node or leaf node). + ensures + - #get_internal_tree_nodes() == trees_ + - #get_tree_leaves() == leaves_ + - #get_feature_extractor() == fe_ + !*/ + + size_t get_num_trees( + ) const; + /*! + ensures + - returns the number of trees in this regression forest. + !*/ + + const std::vector>>& get_internal_tree_nodes ( + ) const; + /*! + ensures + - returns the internal tree nodes that define the regression trees. + - get_internal_tree_nodes().size() == get_num_trees() + !*/ + + const std::vector>& get_tree_leaves ( + ) const; + /*! + ensures + - returns the tree leaves that define the regression trees. + - get_tree_leaves().size() == get_num_trees() + !*/ + + const feature_extractor_type& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used by the trees. + !*/ + + double operator() ( + const sample_type& x + ) const; + /*! + requires + - get_num_trees() > 0 + ensures + - Maps x to a real value and returns the value. To do this, we find the + get_num_trees() leaf values associated with x and then return the average + of these leaf values. + !*/ + }; + + void serialize(const random_forest_regression_function& item, std::ostream& out); + void deserialize(random_forest_regression_function& item, std::istream& in); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor = dense_feature_extractor + > + class random_forest_regression_trainer + { + /*! + REQUIREMENTS ON feature_extractor + feature_extractor must be dense_feature_extractor or a type with a + compatible interface. + + WHAT THIS OBJECT REPRESENTS + This object implements Breiman's classic random forest regression + algorithm. The algorithm learns to map objects, nominally vectors in R^n, + into the reals. It essentially optimizes the mean squared error by fitting + a bunch of decision trees, each of which vote on the output value of the + regressor. The final prediction is obtained by averaging all the + predictions. + + For more information on the algorithm see: + Breiman, Leo. "Random forests." Machine learning 45.1 (2001): 5-32. + !*/ + + public: + typedef feature_extractor feature_extractor_type; + typedef random_forest_regression_function trained_function_type; + typedef typename feature_extractor::sample_type sample_type; + + + random_forest_regression_trainer ( + ); + /*! + ensures + - #get_min_samples_per_leaf() == 5 + - #get_num_trees() == 1000 + - #get_feature_subsampling_frac() == 1.0/3.0 + - #get_feature_extractor() == a default initialized feature extractor. + - #get_random_seed() == "" + - this object is not verbose. + !*/ + + const feature_extractor_type& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used when train() is invoked. + !*/ + + void set_feature_extractor ( + const feature_extractor_type& feat_extractor + ); + /*! + ensures + - #get_feature_extractor() == feat_extractor + !*/ + + void set_seed ( + const std::string& seed + ); + /*! + ensures + - #get_random_seed() == seed + !*/ + + const std::string& get_random_seed ( + ) const; + /*! + ensures + - A central part of this algorithm is random selection of both training + samples and features. This function returns the seed used to initialized + the random number generator used for these random selections. + !*/ + + size_t get_num_trees ( + ) const; + /*! + ensures + - Random forests built by this object will contain get_num_trees() trees. + !*/ + + void set_num_trees ( + size_t num + ); + /*! + requires + - num > 0 + ensures + - #get_num_trees() == num + !*/ + + void set_feature_subsampling_fraction ( + double frac + ); + /*! + requires + - 0 < frac <= 1 + ensures + - #get_feature_subsampling_frac() == frac + !*/ + + double get_feature_subsampling_frac( + ) const; + /*! + ensures + - When we build trees, at each node we don't look at all the available + features. We consider only get_feature_subsampling_frac() fraction of + them, selected at random. + !*/ + + void set_min_samples_per_leaf ( + size_t num + ); + /*! + requires + - num > 0 + ensures + - #get_min_samples_per_leaf() == num + !*/ + + size_t get_min_samples_per_leaf( + ) const; + /*! + ensures + - When building trees, each leaf node in a tree will contain at least + get_min_samples_per_leaf() samples. This means that the output votes of + each tree are averages of at least get_min_samples_per_leaf() y values. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that the + progress of training can be tracked.. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + random_forest_regression_function train ( + const std::vector& x, + const std::vector& y, + std::vector& oob_values + ) const; + /*! + requires + - x.size() == y.size() + - x.size() > 0 + - Running following code: + auto fe = get_feature_extractor() + fe.setup(x,y); + Must be valid and result in fe.max_num_feats() != 0 + ensures + - This function fits a regression forest to the given training data. The + goal being to regress x to y in the mean squared sense. It therefore + fits regression trees and returns the resulting random_forest_regression_function + RF, which will have the following properties: + - RF.get_num_trees() == get_num_trees() + - for all valid i: + - RF(x[i]) should output a value close to y[i] + - RF.get_feature_extractor() will be a copy of this->get_feature_extractor() + that has been configured by a call the feature extractor's setup() routine. + To run the algorithm we need to use a feature extractor. We obtain a + valid feature extractor by making a copy of get_feature_extractor(), then + invoking setup(x,y) on it. This feature extractor is what is used to fit + the trees and is also the feature extractor stored in the returned random + forest. + - #oob_values.size() == y.size() + - for all valid i: + - #oob_values[i] == the "out of bag" prediction for y[i]. It is + calculated by computing the average output from trees not trained on + y[i]. This is similar to a leave-one-out cross-validation prediction + of y[i] and can be used to estimate the generalization error of the + regression forest. + - Training uses all the available CPU cores. + !*/ + + random_forest_regression_function train ( + const std::vector& x, + const std::vector& y + ) const; + /*! + requires + - x.size() == y.size() + - x.size() > 0 + - Running following code: + auto fe = get_feature_extractor() + fe.setup(x,y); + Must be valid and result in fe.max_num_feats() != 0 + ensures + - This function is identical to train(x,y,oob_values) except that the + oob_values are not calculated. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANdOM_FOREST_REGRESION_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/ref.h b/ml/dlib/dlib/ref.h new file mode 100644 index 000000000..53a37fbf8 --- /dev/null +++ b/ml/dlib/dlib/ref.h @@ -0,0 +1,84 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_REFERENCE_WRAPpER_H_ +#define DLIB_REFERENCE_WRAPpER_H_ + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template< + typename T + > + class reference_wrapper + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple object that just holds a reference to another object. + It is useful because it can serve as a kind of "copyable reference". + !*/ + + public: + typedef T type; + + explicit reference_wrapper(T& o) : obj(&o) {} + + operator T&() const { return *obj; } + T& get() const { return *obj; } + + private: + T* obj; + }; + +// ---------------------------------------------------------------------------------------- + + template + reference_wrapper ref( + T& obj + ) { return reference_wrapper(obj); } + /*! + ensures + - returns a reference_wrapper that contains a reference to obj. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + reference_wrapper ref( + reference_wrapper obj + ) { return obj; } + /*! + ensures + - returns the given reference_wrapper object without modification + !*/ + +// ---------------------------------------------------------------------------------------- + + template + reference_wrapper cref( + const T& obj + ) { return reference_wrapper(obj); } + /*! + ensures + - returns a reference_wrapper that contains a constant reference to obj. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + reference_wrapper cref( + reference_wrapper obj + ) { return cref(obj.get()); } + /*! + ensures + - converts the given reference_wrapper into a reference_wrapper that contains a + constant reference. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_REFERENCE_WRAPpER_H_ + diff --git a/ml/dlib/dlib/reference_counter.h b/ml/dlib/dlib/reference_counter.h new file mode 100644 index 000000000..3e4b1ee20 --- /dev/null +++ b/ml/dlib/dlib/reference_counter.h @@ -0,0 +1,31 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_REFERENCE_COUNTEr_ +#define DLIB_REFERENCE_COUNTEr_ + +#include "reference_counter/reference_counter_kernel_1.h" +#include "algs.h" + +namespace dlib +{ + + template < + typename T, + typename copy = copy_functor + > + class reference_counter + { + reference_counter() {} + public: + + //----------- kernels --------------- + + // kernel_1a + typedef reference_counter_kernel_1 + kernel_1a; + + }; +} + +#endif // DLIB_REFERENCE_COUNTEr_ + diff --git a/ml/dlib/dlib/reference_counter/reference_counter_kernel_1.h b/ml/dlib/dlib/reference_counter/reference_counter_kernel_1.h new file mode 100644 index 000000000..64e465550 --- /dev/null +++ b/ml/dlib/dlib/reference_counter/reference_counter_kernel_1.h @@ -0,0 +1,298 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_REFERENCE_COUNTER_KERNEl_1_ +#define DLIB_REFERENCE_COUNTER_KERNEl_1_ + +#include "reference_counter_kernel_abstract.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename T, + typename copy = copy_functor + > + class reference_counter_kernel_1 + { + + /*! + INITIAL VALUE + *data = item of type T with its initial value + *count = 1 + + CONVENTION + *data = pointer to item of type T + *count = number of references to *data + + if clear() threw an exception then count = 0 and data is not a + valid pointer + !*/ + + public: + + typedef T type; + + + reference_counter_kernel_1 ( + ); + + inline reference_counter_kernel_1 ( + const reference_counter_kernel_1& item + ); + + virtual ~reference_counter_kernel_1 ( + ); + + void clear ( + ); + + T& modify ( + ); + + inline const T& access ( + ) const; + + inline reference_counter_kernel_1& operator= ( + const reference_counter_kernel_1& rhs + ); + + inline void swap ( + reference_counter_kernel_1& item + ); + + + private: + + T* data; + unsigned long* count; + mutable copy copy_item; + }; + + template < + typename T, + typename copy + > + inline void swap ( + reference_counter_kernel_1& a, + reference_counter_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename copy + > + reference_counter_kernel_1:: + reference_counter_kernel_1 ( + ) + { + data = new T; + try { count = new unsigned long; } + catch (...) { delete data; throw; } + + *count = 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename copy + > + reference_counter_kernel_1:: + reference_counter_kernel_1 ( + const reference_counter_kernel_1& item + ) : + data(item.data), + count(item.count) + { + ++(*count); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename copy + > + reference_counter_kernel_1:: + ~reference_counter_kernel_1 ( + ) + { + if (*count > 1) + { + // if there are other references to this data + --(*count); + } + else + { + // if there are no other references to this data + delete count; + delete data; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename copy + > + void reference_counter_kernel_1:: + clear ( + ) + { + // if an exception was thrown last time clear() was called then do this + if (count == 0) + { + data = new T; + try { count = new unsigned long; } + catch (...) { delete data; throw; } + + *count = 1; + } + // if there are other references to the data then do this + else if (*count > 1) + { + --(*count); + + try { data = new T; } + catch (...) { count = 0; throw; } + + try { count = new unsigned long; } + catch (...) { delete data; count = 0; throw; } + + *count = 1; + + } + else + { + // if there are no other references to this data + *count = 1; + delete data; + try { data = new T; } catch (...) { delete count; count = 0; throw; } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename copy + > + T& reference_counter_kernel_1:: + modify ( + ) + { + // if this is not the only reference then make a new copy + if ( *count > 1 ) + { + T& old_data = *data; + unsigned long& old_count = *count; + + + // get memory for the new copy + try { data = new T; } + catch (...) { data = &old_data; throw; } + + try { count = new unsigned long; } + catch (...) {delete data; data = &old_data; count = &old_count; throw;} + + // decrement the number of references to old_data + --(old_count); + + *count = 1; + + // make a copy of the old data + try { copy_item(old_data,*data); } + catch (...) + { delete data; delete count; data = &old_data; count = &old_count; } + + } + + return *data; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename copy + > + const T& reference_counter_kernel_1:: + access ( + ) const + { + return *data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename copy + > + reference_counter_kernel_1& reference_counter_kernel_1:: + operator= ( + const reference_counter_kernel_1& rhs + ) + { + if (this == &rhs) + return *this; + + // delete the current data if this is the last reference to it + if (*count > 1) + { + // if there are other references to this data + --(*count); + } + else + { + // if there are no other references to this data + delete count; + delete data; + } + + // copy the pointers + count = (rhs.count); + data = (rhs.data); + ++(*count); + + return *this; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename copy + > + void reference_counter_kernel_1:: + swap ( + reference_counter_kernel_1& item + ) + { + T* data_temp = data; + unsigned long* count_temp = count; + + data = item.data; + count = item.count; + + item.data = data_temp; + item.count = count_temp; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_REFERENCE_COUNTER_KERNEl_1_ + diff --git a/ml/dlib/dlib/reference_counter/reference_counter_kernel_abstract.h b/ml/dlib/dlib/reference_counter/reference_counter_kernel_abstract.h new file mode 100644 index 000000000..e46127387 --- /dev/null +++ b/ml/dlib/dlib/reference_counter/reference_counter_kernel_abstract.h @@ -0,0 +1,141 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_REFERENCE_COUNTER_KERNEl_ABSTRACT_ +#ifdef DLIB_REFERENCE_COUNTER_KERNEl_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + template < + typename T, + typename copy = copy_functor + > + class reference_counter + { + + /*! + REQUIREMENTS ON T + T must have a default constructor + + REQUIREMENTS ON copy + it should be a function object that copies an object of type T. and + it must have a default constructor and + operator() should be overloaded as + void operator()(const T& source, T& destination); + copy may throw any exception + + POINTERS AND REFERENCES TO INTERNAL DATA + swap() and access() functions do not invalidate pointers or + references to internal data. + All other functions have no such guarantee + + + INITIAL VALUE + reference_counter contains one object of type T and + this object of type T has its initial value + + WHAT THIS OBJECT REPRESENTS + This object represents a container for an object of type T and + provides reference counting capabilities for the object it contains + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + !*/ + + public: + + typedef T type; + + reference_counter ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + reference_counter ( + const reference_counter& item + ); + /*! + ensures + - #access() == item.access() + !*/ + + virtual ~reference_counter ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear ( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + T& modify ( + ); + /*! + ensures + - returns a non-const reference to the item contained in *this + - the item is ok to modify. i.e. there are no other references to it + throws + - std::bad_alloc or any exception thrown by T's constructor + modify() may throw this exception if there are other references + to the item and there is not enough memory to copy it. If modify() + throws then it has no effect. + !*/ + + const T& access ( + ) const; + /*! + ensures + - returns a const reference to the item contained in *this + - there may be other references to to the item + !*/ + + reference_counter& operator= ( + const reference_counter& rhs + ); + /*! + ensures + - #access() == rhs.access() + !*/ + + void swap ( + reference_counter& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + template < + typename T, + typename copy + > + inline void swap ( + reference_counter& a, + reference_counter& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_REFERENCE_COUNTER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/revision.h.in b/ml/dlib/dlib/revision.h.in new file mode 100644 index 000000000..e6922c17c --- /dev/null +++ b/ml/dlib/dlib/revision.h.in @@ -0,0 +1,6 @@ +#ifndef DLIB_REVISION_H +#define DLIB_MAJOR_VERSION @CPACK_PACKAGE_VERSION_MAJOR@ +#define DLIB_MINOR_VERSION @CPACK_PACKAGE_VERSION_MINOR@ +#define DLIB_PATCH_VERSION @CPACK_PACKAGE_VERSION_PATCH@ +#endif + diff --git a/ml/dlib/dlib/sequence.h b/ml/dlib/dlib/sequence.h new file mode 100644 index 000000000..1223c30bb --- /dev/null +++ b/ml/dlib/dlib/sequence.h @@ -0,0 +1,83 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEQUENCe_ +#define DLIB_SEQUENCe_ + +#include "sequence/sequence_kernel_1.h" +#include "sequence/sequence_kernel_2.h" +#include "sequence/sequence_kernel_c.h" + +#include "sequence/sequence_compare_1.h" +#include "sequence/sequence_sort_1.h" +#include "sequence/sequence_sort_2.h" +#include "algs.h" + + + + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class sequence + { + + sequence() {} + public: + + //----------- kernels --------------- + + // kernel_1a + typedef sequence_kernel_1 + kernel_1a; + typedef sequence_kernel_c + kernel_1a_c; + + // kernel_2a + typedef sequence_kernel_2 + kernel_2a; + typedef sequence_kernel_c + kernel_2a_c; + + + //---------- extensions ------------ + + // compare_1 extend kernel_1a + typedef sequence_compare_1 + compare_1a; + typedef sequence_compare_1 + compare_1a_c; + + // compare_1 extend kernel_2a + typedef sequence_compare_1 + compare_1b; + typedef sequence_compare_1 + compare_1b_c; + + + + // sort_1 extend kernel_2a + typedef sequence_sort_1 + sort_1a; + typedef sequence_sort_1 + sort_1a_c; + + // sort_2 extend kernel_1a + typedef sequence_sort_2 + sort_2a; + typedef sequence_sort_2 + sort_2a_c; + + + + + + + }; +} + +#endif // DLIB_SEQUENCe_ + diff --git a/ml/dlib/dlib/sequence/sequence_compare_1.h b/ml/dlib/dlib/sequence/sequence_compare_1.h new file mode 100644 index 000000000..9bc3c773d --- /dev/null +++ b/ml/dlib/dlib/sequence/sequence_compare_1.h @@ -0,0 +1,102 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEQUENCE_COMPARe_1_ +#define DLIB_SEQUENCE_COMPARe_1_ + +#include "sequence_compare_abstract.h" + +#include "../algs.h" + + +namespace dlib +{ + + template < + typename seq_base + > + class sequence_compare_1 : public seq_base + { + typedef typename seq_base::type T; + + public: + + bool operator< ( + const sequence_compare_1& rhs + ) const; + + bool operator== ( + const sequence_compare_1& rhs + ) const; + + }; + + + template < + typename seq_base + > + inline void swap ( + sequence_compare_1& a, + sequence_compare_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + bool sequence_compare_1:: + operator< ( + const sequence_compare_1& rhs + ) const + { + unsigned int length; + if (this->size() < rhs.size()) + length = this->size(); + else + length = rhs.size(); + + for (unsigned long i = 0; i < length; ++i) + { + if ((*this)[i] < rhs[i]) + return true; + else if ( !((*this)[i] == rhs[i]) ) + return false; + } + // they are equal so far + if (this->size() < rhs.size()) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + bool sequence_compare_1:: + operator== ( + const sequence_compare_1& rhs + ) const + { + if (this->size() != rhs.size()) + return false; + + for (unsigned long i = 0; i < this->size(); ++i) + { + if (!((*this)[i] == rhs[i])) + return false; + } + return true; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEQUENCE_COMPARe_1_ + diff --git a/ml/dlib/dlib/sequence/sequence_compare_abstract.h b/ml/dlib/dlib/sequence/sequence_compare_abstract.h new file mode 100644 index 000000000..261703f45 --- /dev/null +++ b/ml/dlib/dlib/sequence/sequence_compare_abstract.h @@ -0,0 +1,75 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SEQUENCE_COMPARe_ABSTRACT_ +#ifdef DLIB_SEQUENCE_COMPARe_ABSTRACT_ + +#include "sequence_kernel_abstract.h" + +#include "../algs.h" + + +namespace dlib +{ + + template < + typename seq_base + > + class sequence_compare : public seq_base + { + + /*! + REQUIREMENTS ON T + T must implement operator< for its type and + T must implement operator== for its type + + REQUIREMENTS ON SEQUENCE_BASE + must be an implementation of sequence/sequence_kernel_abstract.h + + + POINTERS AND REFERENCES TO INTERNAL DATA + operator== and operator< do not invalidate pointers or references to + data members + + WHAT THIS EXTENSION DOES FOR sequence + This gives a sequence the ability to compare itself to other + sequences using the < and == operators. + !*/ + + public: + + bool operator< ( + const sequence_compare& rhs + ) const; + /*! + ensures + - returns true if there exists an integer j such that 0 <= j < size() + and for all integers i such that 0 <= i < j where it is true that + (*this)[i] <= rhs[i] and (*this)[j] < rhs[j] + - returns false if there is no j that will satisfy the above conditions + !*/ + + bool operator== ( + const sequence_compare& rhs + ) const; + /*! + ensures + - returns true if for all i: (*this)[i] == rhs[i] else returns false + !*/ + + }; + + template < + typename seq_base + > + inline void swap ( + sequence_compare& a, + sequence_compare& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_SEQUENCE_COMPARe_ABSTRACT_ + diff --git a/ml/dlib/dlib/sequence/sequence_kernel_1.h b/ml/dlib/dlib/sequence/sequence_kernel_1.h new file mode 100644 index 000000000..9e1e26f1b --- /dev/null +++ b/ml/dlib/dlib/sequence/sequence_kernel_1.h @@ -0,0 +1,1340 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEQUENCE_KERNEl_1_ +#define DLIB_SEQUENCE_KERNEl_1_ + +#include "sequence_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class sequence_kernel_1 : public enumerable, + public remover + { + + /*! + INITIAL VALUE + - tree_root == 0 + - tree_size == 0 + - at_start_ == true + - current_element == 0 + - stack == array of 50 node pointers + - stack_pos == 0 + + CONVENTION + + - if (tree_size > 0) + - tree_root == pointer to the root node of the binary search tree + - else + - tree_root == 0 + + + + - stack[stack_pos-1] == pop() + + - current_element_valid() == (current_element != 0) + + - at_start_ == at_start() + - if (current_element != 0 && current_element != tree_root) then + - stack[stack_pos-1] == the parent of the node pointed to by current_element + + - if (current_element_valid()) then + - element() == current_element->item + + + + - tree_size == size() + - (*this)[i] == return_reference(i) + + + - for all nodes: + - left_size == the number of elements in the left subtree. + - left points to the left subtree or 0 if there is no left subtree. + - right points to the right subtree or 0 if there is no right subtree. + + - all elements in a left subtree have a position in the sequence < that + of the root of the current tree. + + - all elements in a right subtree have a position in the + sequence > that of the root of the current tree. + + - item is the sequence element for that node. + - balance: + - balance == 0 if both subtrees have the same height + - balance == -1 if the left subtree has a height that is + greater than the height of the right subtree by 1 + - balance == 1 if the right subtree has a height that is + greater than the height of the left subtree by 1 + - for all subtrees: + - the height of the left and right subtrees differ by at most one + + !*/ + + + class node + { + public: + node* left; + node* right; + unsigned long left_size; + T item; + signed char balance; + }; + + + + public: + + typedef T type; + typedef mem_manager mem_manager_type; + + sequence_kernel_1 ( + ) : + tree_root(0), + tree_size(0), + stack(ppool.allocate_array(50)), + current_element(0), + at_start_(true), + stack_pos(0) + {} + + virtual ~sequence_kernel_1 ( + ); + + inline void clear ( + ); + + void add ( + unsigned long pos, + T& item + ); + + void remove ( + unsigned long pos, + T& item + ); + + void cat ( + sequence_kernel_1& item + ); + + const T& operator[] ( + unsigned long pos + ) const; + + T& operator[] ( + unsigned long pos + ); + + inline void swap ( + sequence_kernel_1& item + ); + + // functions from the remover interface + inline void remove_any ( + T& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const T& element ( + ) const; + + T& element ( + ); + + bool move_next ( + ) const; + + + private: + + void delete_nodes ( + node* t + ); + /*! + requires + - t == a pointer to a valid node + ensures + - deletes t and all its sub nodes. + !*/ + + inline void rotate_left ( + node*& t + ); + /*! + requires + - t->balance == 2 + - t->right->balance == 0 or 1 + ensures + - #t is still a binary search tree + - #t->balance is between 1 and -1 + - #t now has a height smaller by 1 if #t->balance == 0 + !*/ + + inline void rotate_right ( + node*& t + ); + /*! + requires + - t->balance == -2 + - t->left->balance == 0 or -1 + ensures + - #t is still a binary search tree + - #t->balance is between 1 and -1 + - #t now has a height smaller by 1 if #t->balance == 0 + + !*/ + + inline void double_rotate_right ( + node*& t + ); + /*! + requires + - #t->balance == -2 + - #t->left->balance == 1 + ensures + - #t is still a binary search tree + - #t now has a balance of 0 + - #t now has a height smaller by 1 + !*/ + + inline void double_rotate_left ( + node*& t + ); + /*! + requires + - #t->balance == 2 + - #t->right->balance == -1 + ensures + - #t is still a binary search tree + - #t now has a balance of 0 and + - #t now has a height smaller by 1 + !*/ + + bool remove_least_element_in_tree ( + node*& t, + T& item + ); + /*! + requires + - t != 0 (i.e. there must be something in the tree to remove) + ensures + - the least node in t has been removed + - the least node element in t has been put into #item + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool add_to_tree ( + node*& t, + unsigned long pos, + T& item + ); + /*! + requires + - pos <= the number of items in the tree + ensures + - item has been added to #t + - #return_reference(pos) == item + - the convention is still satisfied + - #item has an initial value for its type + - returns false if the height of the tree has not changed + - returns true if the height of the tree has grown by one + !*/ + + bool remove_from_tree ( + node*& t, + unsigned long pos, + T& item + ); + /*! + requires + - there is an item in the tree associated with pos + ensures + - the element in the tree associated with pos has been removed + and put into #item + - the convention is still satisfied + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + const T& return_reference ( + const node* t, + unsigned long pos + ) const; + /*! + requires + - there is an item in the tree associated with pos + ensures + - returns a const reference to the item in the tree associated with pos + !*/ + + T& return_reference ( + node* t, + unsigned long pos + ); + /*! + requires + - there is an item in the tree associated with pos + ensures + - returns a non-const reference to the item in the tree associated + with pos + !*/ + + inline bool keep_node_balanced ( + node*& t + ); + /*! + requires + - t != 0 + ensures + - if (t->balance is < 1 or > 1) then + - keep_node_balanced() will ensure that t->balance == 0, -1, or 1 + - returns true if it made the tree one height shorter + - returns false if it didn't change the height + !*/ + + void push ( + node* n + ) const { stack[stack_pos] = n; ++stack_pos; } + /*! + ensures + - pushes n onto the stack + !*/ + + + node* pop ( + ) const { --stack_pos; return stack[stack_pos]; } + /*! + ensures + - pops the top of the stack and returns it + !*/ + + // data members + typename mem_manager::template rebind::other pool; + typename mem_manager::template rebind::other ppool; + + node* tree_root; + unsigned long tree_size; + + mutable node** stack; + mutable node* current_element; + mutable bool at_start_; + mutable unsigned char stack_pos; + + // restricted functions + sequence_kernel_1(sequence_kernel_1&); // copy constructor + sequence_kernel_1& operator=(sequence_kernel_1&); // assignment operator + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + sequence_kernel_1& a, + sequence_kernel_1& b + ) { a.swap(b); } + + template < + typename T, + typename mem_manager + > + void deserialize ( + sequence_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + T temp; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(temp,in); + item.add(i,temp); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type sequence_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + sequence_kernel_1:: + ~sequence_kernel_1 ( + ) + { + ppool.deallocate_array(stack); + if (tree_size > 0) + { + delete_nodes(tree_root); + } + } + +// ---------------------------------------------------------------------------------------- + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + swap ( + sequence_kernel_1& item + ) + { + exchange(stack,item.stack); + exchange(stack_pos,item.stack_pos); + + pool.swap(item.pool); + ppool.swap(item.ppool); + + node* tree_root_temp = item.tree_root; + unsigned long tree_size_temp = item.tree_size; + node* current_element_temp = item.current_element; + bool at_start_temp = item.at_start_; + + item.tree_root = tree_root; + item.tree_size = tree_size; + item.current_element = current_element; + item.at_start_ = at_start_; + + tree_root = tree_root_temp; + tree_size = tree_size_temp; + current_element = current_element_temp; + at_start_ = at_start_temp; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t sequence_kernel_1:: + size ( + ) const + { + return tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& sequence_kernel_1:: + operator[] ( + unsigned long pos + ) const + { + return return_reference(tree_root,pos); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& sequence_kernel_1:: + operator[] ( + unsigned long pos + ) + { + return return_reference(tree_root,pos); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + add ( + unsigned long pos, + T& item + ) + { + add_to_tree(tree_root,pos,item); + ++tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + remove ( + unsigned long pos, + T& item + ) + { + remove_from_tree(tree_root,pos,item); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + cat ( + sequence_kernel_1& item + ) + { + for (unsigned long i = 0; i < item.tree_size; ++i) + { + add_to_tree( + tree_root, + tree_size, + return_reference(item.tree_root,i) + ); + + ++tree_size; + } + + item.clear(); + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + clear ( + ) + { + if (tree_size > 0) + { + delete_nodes(tree_root); + tree_root = 0; + tree_size = 0; + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_1:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + stack_pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_1:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& sequence_kernel_1:: + element ( + ) const + { + return current_element->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& sequence_kernel_1:: + element ( + ) + { + return current_element->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_1:: + move_next ( + ) const + { + // if we haven't started iterating yet + if (at_start_) + { + at_start_ = false; + if (tree_size == 0) + { + return false; + } + else + { + // find the first element in the tree + current_element = tree_root; + node* temp = current_element->left; + while (temp != 0) + { + push(current_element); + current_element = temp; + temp = current_element->left; + } + return true; + } + } + else + { + if (current_element == 0) + { + return false; + } + else + { + node* temp; + bool went_up; // true if we went up the tree from a child node to parent + bool from_left = false; // true if we went up and were coming from a left child node + // find the next element in the tree + if (current_element->right != 0) + { + // go right and down + temp = current_element; + push(current_element); + current_element = temp->right; + went_up = false; + } + else + { + // go up to the parent if we can + if (current_element == tree_root) + { + // in this case we have iterated over all the element of the tree + current_element = 0; + return false; + } + went_up = true; + node* parent = pop(); + + + from_left = (parent->left == current_element); + // go up to parent + current_element = parent; + } + + + while (true) + { + if (went_up) + { + if (from_left) + { + // in this case we have found the next node + break; + } + else + { + if (current_element == tree_root) + { + // in this case we have iterated over all the elements + // in the tree + current_element = 0; + return false; + } + // we should go up + node* parent = pop(); + from_left = (parent->left == current_element); + current_element = parent; + } + } + else + { + // we just went down to a child node + if (current_element->left != 0) + { + // go left + went_up = false; + temp = current_element; + push(current_element); + current_element = temp->left; + } + else + { + // if there is no left child then we have found the next node + break; + } + } + } + + return true; + } + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // remover function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + remove_any ( + T& item + ) + { + remove(0,item); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + rotate_left ( + node*& t + ) + { + + // set the new balance numbers + if (t->right->balance == 1) + { + t->balance = 0; + t->right->balance = 0; + } + else + { + t->balance = 1; + t->right->balance = -1; + } + + // perform the rotation + node* temp = t->right; + t->right = temp->left; + temp->left = t; + t = temp; + + + // set left_size to its correct value + t->left_size += t->left->left_size + 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + rotate_right ( + node*& t + ) + { + // set the new balance numbers + if (t->left->balance == -1) + { + t->balance = 0; + t->left->balance = 0; + } + else + { + t->balance = -1; + t->left->balance = 1; + } + + // preform the rotation + node* temp = t->left; + t->left = temp->right; + temp->right = t; + t = temp; + + + // set left_size to its correct value + t->right->left_size -= t->left_size + 1; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + double_rotate_right ( + node*& t + ) + { + + node* temp = t; + t = t->left->right; + + temp->left->right = t->left; + t->left = temp->left; + + temp->left = t->right; + t->right = temp; + + if (t->balance < 0) + { + t->left->balance = 0; + t->right->balance = 1; + } + else if (t->balance > 0) + { + t->left->balance = -1; + t->right->balance = 0; + } + else + { + t->left->balance = 0; + t->right->balance = 0; + } + t->balance = 0; + + + // set left_size to its correct value + t->left_size += t->left->left_size + 1; + t->right->left_size -= t->left_size + 1; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + double_rotate_left ( + node*& t + ) + { + node* temp = t; + t = t->right->left; + + temp->right->left = t->right; + t->right = temp->right; + + temp->right = t->left; + t->left = temp; + + if (t->balance < 0) + { + t->left->balance = 0; + t->right->balance = 1; + } + else if (t->balance > 0) + { + t->left->balance = -1; + t->right->balance = 0; + } + else + { + t->left->balance = 0; + t->right->balance = 0; + } + + t->balance = 0; + + // set left_size to its correct value + t->right->left_size -= t->left_size + 1; + t->left_size += t->left->left_size + 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_1:: + remove_least_element_in_tree ( + node*& t, + T& item + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if the left tree is an empty tree + if ( tree.left == 0) + { + // swap nodes element into item + exchange(tree.item,item); + + // plug hole left by removing this node + t = tree.right; + + // delete the node that was just removed + tree.right = 0; + delete_nodes(&tree); + + // return that the height of this part of the tree has decreased + return true; + } + else + { + // subtract one from the left size + --tree.left_size; + + // keep going left + + // if remove made the tree one height shorter + if ( remove_least_element_in_tree(tree.left,item) ) + { + // if this caused the current tree to strink then report that + if ( tree.balance == -1) + { + ++tree.balance; + return true; + } + else + { + ++tree.balance; + return keep_node_balanced(t); + } + } + + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_1:: + add_to_tree ( + node*& t, + unsigned long pos, + T& item + ) + { + // if found place to add + if (t == 0) + { + // create a node to add new item into + t = pool.allocate(); + + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + + // set left and right pointers to 0 to indicate that there are no + // left or right subtrees + tree.left = 0; + tree.right = 0; + tree.balance = 0; + tree.left_size = 0; + + // put item into t + exchange(item,tree.item); + + // indicate that the height of this tree has increased + return true; + } + else // keep looking for a place to add the new item + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + signed char old_balance = tree.balance; + + // add the new item to whatever subtree it should go into + if ( pos < tree.left_size + 1 ) + { + tree.balance -= add_to_tree(tree.left,pos,item); + ++tree.left_size; + } + else + tree.balance += add_to_tree(tree.right,pos - tree.left_size - 1,item); + + + // if the tree was balanced to start with + if (old_balance == 0) + { + // if its not balanced anymore then it grew in height + if (tree.balance != 0) + return true; + else + return false; + } + else + { + // if the tree is now balanced then it didn't grow + if (tree.balance == 0) + { + return false; + } + else + { + // if the tree needs to be balanced + if (tree.balance != old_balance) + { + return !keep_node_balanced(t); + } + // if there has been no change in the heights + else + { + return false; + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_1:: + remove_from_tree ( + node*& t, + unsigned long pos, + T& item + ) + { + + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if item is on the left + if (pos < tree.left_size) + { + // adjust the left size + --tree.left_size; + + // if the left side of the tree has the greatest height + if (tree.balance == -1) + { + tree.balance += remove_from_tree(tree.left,pos,item); + return !tree.balance; + } + else + { + tree.balance += remove_from_tree(tree.left,pos,item); + return keep_node_balanced(t); + } + + } + // if item is found + else if (pos == tree.left_size) + { + // if there is no left node + if (tree.left == 0) + { + // swap nodes element into item + exchange(tree.item,item); + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + tree.right = 0; + delete_nodes(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + // swap nodes element into item + exchange(tree.item,item); + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + tree.left = 0; + delete_nodes(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,item)) + { + // adjust the tree height + --tree.balance; + + // put the element into item copy and also plug the + // hole with the smallest element from the right. + exchange(item,tree.item); + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + // put the element into item copy and also plug the + // hole with the smallest element from the right. + exchange(item,tree.item); + + return false; + } + + } + } + // if item is on the right + else + { + + // if the right side of the tree has the greatest height + if (tree.balance == 1) + { + tree.balance -= remove_from_tree(tree.right,pos - tree.left_size - 1,item); + return !tree.balance; + } + else + { + tree.balance -= remove_from_tree(tree.right,pos - tree.left_size - 1,item); + return keep_node_balanced(t); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& sequence_kernel_1:: + return_reference ( + node* t, + unsigned long pos + ) + { + while (true) + { + // if we have found the node + if (pos == t->left_size) + return t->item; + + if (pos < t->left_size) + { + // go left + t = t->left; + } + else + { + // go right + pos -= t->left_size+1; + t = t->right; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& sequence_kernel_1:: + return_reference ( + const node* t, + unsigned long pos + ) const + { + while (true) + { + // if we have found the node + if (pos == t->left_size) + return t->item; + + if (pos < t->left_size) + { + // go left + t = t->left; + } + else + { + // go right + pos -= t->left_size+1; + t = t->right; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_1:: + keep_node_balanced ( + node*& t + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if tree does not need to be balanced then return false + if (tree.balance == 0) + return false; + + + // if tree needs to be rotated left + if (tree.balance == 2) + { + if (tree.right->balance >= 0) + rotate_left(t); + else + double_rotate_left(t); + } + // else if the tree needs to be rotated right + else if (tree.balance == -2) + { + if (tree.left->balance <= 0) + rotate_right(t); + else + double_rotate_right(t); + } + + + if (t->balance == 0) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_1:: + delete_nodes ( + node* t + ) + { + if (t->left) + delete_nodes(t->left); + if (t->right) + delete_nodes(t->right); + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- +} + +#endif // DLIB_SEQUENCE_KERNEl_1_ + diff --git a/ml/dlib/dlib/sequence/sequence_kernel_2.h b/ml/dlib/dlib/sequence/sequence_kernel_2.h new file mode 100644 index 000000000..f15c50e93 --- /dev/null +++ b/ml/dlib/dlib/sequence/sequence_kernel_2.h @@ -0,0 +1,682 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEQUENCE_KERNEl_2_ +#define DLIB_SEQUENCE_KERNEl_2_ + +#include "sequence_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" + +namespace dlib +{ + + + template < + typename T, + typename mem_manager = default_memory_manager + > + class sequence_kernel_2 : public enumerable, + public remover + { + /*! + INITIAL VALUE + sequence_size == 0 + at_start_ == true + current_enumeration_node == 0 + + CONVENTION + sequence_size == the number of elements in the sequence + + at_start_ == at_start() + (current_enumeration_node!=0) == current_element_valid() + if (current_enumeration_node!=0) then + current_enumeration_node->item == element() + current_enumeration_pos == the position of the node pointed to by + current_enumeration_node + + if ( sequence_size > 0 ) + { + current_node == pointer to a node in the linked list and + current_node->right->right->... eventually == current_node and + current_node->left->left->... eventually == current_node and + current_pos == the position in the sequence of + current_node->item + } + + !*/ + + struct node { + T item; + node* right; + node* left; + }; + + public: + + typedef T type; + typedef mem_manager mem_manager_type; + + sequence_kernel_2 ( + ) : + sequence_size(0), + at_start_(true), + current_enumeration_node(0) + {} + + virtual ~sequence_kernel_2 ( + ); + + inline void clear ( + ); + + void add ( + unsigned long pos, + T& item + ); + + void remove ( + unsigned long pos, + T& item + ); + + void cat ( + sequence_kernel_2& item + ); + + const T& operator[] ( + unsigned long pos + ) const; + + T& operator[] ( + unsigned long pos + ); + + void swap ( + sequence_kernel_2& item + ); + + // functions from the remover interface + inline void remove_any ( + T& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const T& element ( + ) const; + + T& element ( + ); + + bool move_next ( + ) const; + + private: + + void delete_nodes ( + node* current_node, + unsigned long sequence_size + ); + /*! + requires + CONVENTION IS CORRECT + ensures + all memory associated with the ring of nodes has been freed + !*/ + + void move_to_pos ( + node*& current_node, + unsigned long& current_pos, + unsigned long pos, + unsigned long size + ) const; + /*! + requires + everything in the CONVENTION is correct and + there is a node corresponding to pos in the CONVENTION and + 0 <= pos < size + ensures + current_pos == pos and + current_node->item is the item in the sequence associated with + position pos + !*/ + + // data members + unsigned long sequence_size; + mutable node* current_node; + mutable unsigned long current_pos; + mutable bool at_start_; + mutable node* current_enumeration_node; + mutable unsigned long current_enumeration_pos; + + // restricted functions + sequence_kernel_2(sequence_kernel_2&); // copy constructor + sequence_kernel_2& operator=(sequence_kernel_2&); // assignment operator + + }; + + + template < + typename T, + typename mem_manager + > + inline void swap ( + sequence_kernel_2& a, + sequence_kernel_2& b + ) { a.swap(b); } + + template < + typename T, + typename mem_manager + > + void deserialize ( + sequence_kernel_2& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + T temp; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(temp,in); + item.add(i,temp); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type sequence_kernel_2"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + sequence_kernel_2:: + ~sequence_kernel_2 ( + ) + { + delete_nodes(current_node,sequence_size); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_2:: + clear ( + ) + { + if (sequence_size != 0) + { + delete_nodes(current_node,sequence_size); + sequence_size = 0; + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_2:: + add ( + unsigned long pos, + T& item + ) + { + // make new node and swap item into it + node* new_node = new node; + exchange(item,new_node->item); + + if (sequence_size > 0) + { + if (pos == sequence_size) + { + move_to_pos(current_node,current_pos,pos-1,sequence_size); + + node& n_node = *new_node; + node& c_node = *current_node; + + // make new node point to the nodes to its left and right + n_node.right = c_node.right; + n_node.left = current_node; + + // make the left node point back to new_node + c_node.right->left = new_node; + + // make the right node point back to new_node + c_node.right = new_node; + current_pos = pos; + + } + else + { + move_to_pos(current_node,current_pos,pos,sequence_size); + + node& n_node = *new_node; + node& c_node = *current_node; + + // make new node point to the nodes to its left and right + n_node.right = current_node; + n_node.left = c_node.left; + + // make the left node point back to new_node + c_node.left->right = new_node; + + // make the right node point back to new_node + c_node.left = new_node; + } + + } + else + { + current_pos = 0; + new_node->left = new_node; + new_node->right = new_node; + } + + // make the new node the current node + current_node = new_node; + + ++sequence_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_2:: + remove ( + unsigned long pos, + T& item + ) + { + move_to_pos(current_node,current_pos,pos,sequence_size); + node& c_node = *current_node; + exchange(c_node.item,item); + + node* temp = current_node; + + // close up gap left by remove + c_node.left->right = c_node.right; + c_node.right->left = c_node.left; + + current_node = c_node.right; + + --sequence_size; + + delete temp; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& sequence_kernel_2:: + operator[] ( + unsigned long pos + ) const + { + move_to_pos(current_node,current_pos,pos,sequence_size); + return current_node->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_2:: + cat ( + sequence_kernel_2& item + ) + { + if (item.sequence_size > 0) + { + if (sequence_size > 0) + { + // move both sequences to a convenient location + move_to_pos(current_node,current_pos,0,sequence_size); + item.move_to_pos ( + item.current_node, + item.current_pos, + item.sequence_size-1, + item.sequence_size + ); + + // make copies of poitners + node& item_right = *item.current_node->right; + node& left = *current_node->left; + + + item.current_node->right = current_node; + current_node->left = item.current_node; + + left.right = &item_right; + item_right.left = &left; + + // set sizes + sequence_size += item.sequence_size; + item.sequence_size = 0; + } + else + { + // *this is empty so just swap + item.swap(*this); + } + } + item.clear(); + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& sequence_kernel_2:: + operator[] ( + unsigned long pos + ) + { + move_to_pos(current_node,current_pos,pos,sequence_size); + return current_node->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t sequence_kernel_2:: + size ( + ) const + { + return sequence_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_2:: + swap ( + sequence_kernel_2& item + ) + { + unsigned long sequence_size_temp = item.sequence_size; + node* current_node_temp = item.current_node; + unsigned long current_pos_temp = item.current_pos; + bool at_start_temp = item.at_start_; + node* current_enumeration_node_temp = item.current_enumeration_node; + unsigned long current_enumeration_pos_temp = item.current_enumeration_pos; + + item.sequence_size = sequence_size; + item.current_node = current_node; + item.current_pos = current_pos; + item.at_start_ = at_start_; + item.current_enumeration_node = current_enumeration_node; + item.current_enumeration_pos = current_enumeration_pos; + + sequence_size = sequence_size_temp; + current_node = current_node_temp; + current_pos = current_pos_temp; + at_start_ = at_start_temp; + current_enumeration_node = current_enumeration_node_temp; + current_enumeration_pos = current_enumeration_pos_temp; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_2:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_2:: + reset ( + ) const + { + at_start_ = true; + current_enumeration_node = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_2:: + current_element_valid ( + ) const + { + return (current_enumeration_node!=0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& sequence_kernel_2:: + element ( + ) const + { + return current_enumeration_node->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& sequence_kernel_2:: + element ( + ) + { + return current_enumeration_node->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool sequence_kernel_2:: + move_next ( + ) const + { + if (at_start_ && sequence_size>0) + { + move_to_pos(current_node,current_pos,0,sequence_size); + current_enumeration_node = current_node; + current_enumeration_pos = 0; + } + else if (current_enumeration_node!=0) + { + ++current_enumeration_pos; + if (current_enumeration_posright; + } + else + { + // we have reached the end of the sequence + current_enumeration_node = 0; + } + } + + at_start_ = false; + return (current_enumeration_node!=0); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // remover function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_2:: + remove_any ( + T& item + ) + { + remove(0,item); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_2:: + delete_nodes ( + node* current_node, + unsigned long sequence_size + ) + { + node* temp; + while (sequence_size) + { + temp = current_node->right; + delete current_node; + current_node = temp; + --sequence_size; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void sequence_kernel_2:: + move_to_pos ( + node*& current_node, + unsigned long& current_pos, + unsigned long pos, + unsigned long size + ) const + { + if ( current_pos > pos) + { + // number of hops in each direction needed to reach pos + unsigned long right = size + pos - current_pos; + unsigned long left = current_pos - pos; + current_pos = pos; + + if (left < right) + { + // move left to position pos + for (; left > 0; --left) + current_node = current_node->left; + } + else + { + // move left to position pos + for (; right > 0; --right) + current_node = current_node->right; + } + } + else if (current_pos != pos) + { + // number of hops in each direction needed to reach pos + unsigned long right = pos - current_pos; + unsigned long left = size - pos + current_pos; + current_pos = pos; + + if (left < right) + { + // move left to position pos + for (; left > 0; --left) + current_node = current_node->left; + } + else + { + // move left to position pos + for (; right > 0; --right) + current_node = current_node->right; + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEQUENCE_KERNEl_2_ + diff --git a/ml/dlib/dlib/sequence/sequence_kernel_abstract.h b/ml/dlib/dlib/sequence/sequence_kernel_abstract.h new file mode 100644 index 000000000..8a0bdc5b3 --- /dev/null +++ b/ml/dlib/dlib/sequence/sequence_kernel_abstract.h @@ -0,0 +1,199 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SEQUENCE_KERNEl_ABSTRACT_ +#ifdef DLIB_SEQUENCE_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" + +namespace dlib +{ + template < + typename T, + typename mem_manager = default_memory_manager + > + class sequence : public enumerable, + public remover + { + + /*! + REQUIREMENTS ON T + T must be swappable by a global swap() and + T must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + swap() and operator[] functions do not invalidate pointers or + references to internal data. + All other functions have no such guarantees. + + ENUMERATION ORDER + The enumerator will iterate over the elements in the sequence from + the 0th element to the (size()-1)th element. + + INITIAL VALUE + size() == 0 + + WHAT THIS OBJECT REPRESENTS + sequence contains items of type T + + This object represents an ordered sequence of items, each item is + associated with an integer value. + The items are numbered from 0 to size()-1 + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + !*/ + + public: + + typedef T type; + typedef mem_manager mem_manager_type; + + sequence ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + virtual ~sequence ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear ( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void add ( + unsigned long pos, + T& item + ); + /*! + requires + - pos <= size() + ensures + - #size() == size() + 1 + - #item has an initial value for its type + - #operator[](pos) == item + i.e. item has been inserted into *this between the elements which + were previously at position pos-1 and pos + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + if add() throws then it has no effect + !*/ + + void remove ( + unsigned long pos, + T& item + ); + /*! + requires + - pos < size() + ensures + - #size() == size() - 1 + - the element at the position pos in *this has been removed and + swapped into #item + - #at_start() == true + !*/ + + void cat ( + sequence& item + ); + /*! + requires + - &item != this (i.e. you can't concatenate a sequence onto itself) + ensures + - item has been concatenated onto the end of *this + i.e. item[0] becomes (#*this)[size()], item[1] + becomes (#*this)[size()+1], etc. + - #size() == size() + item.size() + - #item has its initial value + - #at_start() == true + !*/ + + const T& operator[] ( + unsigned long pos + ) const; + /*! + requires + - pos < size() + ensures + - returns a const reference to the element at position pos + !*/ + + T& operator[] ( + unsigned long pos + ); + /*! + requires + - pos < size() + ensures + - returns a non-const reference to the element at position pos + !*/ + + void swap ( + sequence& item + ); + /*! + ensures + - swaps *this and item + !*/ + + + private: + + // restricted functions + sequence(sequence&); // copy constructor + sequence& operator=(sequence&); // assignment operator + + }; + + + template < + typename T, + typename mem_manager + > + inline void swap ( + sequence& a, + sequence& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename mem_manager + > + void deserialize ( + sequence& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_SEQUENCE_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/sequence/sequence_kernel_c.h b/ml/dlib/dlib/sequence/sequence_kernel_c.h new file mode 100644 index 000000000..c565b54f0 --- /dev/null +++ b/ml/dlib/dlib/sequence/sequence_kernel_c.h @@ -0,0 +1,253 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEQUENCE_KERNEl_C_ +#define DLIB_SEQUENCE_KERNEl_C_ + +#include "sequence_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename seq_base + > + class sequence_kernel_c : public seq_base + { + typedef typename seq_base::type T; + public: + + + void add ( + unsigned long pos, + T& item + ); + + void remove ( + unsigned long pos, + T& item + ); + + const T& operator[] ( + unsigned long pos + ) const; + + T& operator[] ( + unsigned long pos + ); + + void cat ( + sequence_kernel_c& item + ); + + const T& element ( + ) const; + + T& element ( + ); + + void remove_any ( + T& item + ); + + }; + + + template < + typename seq_base + > + inline void swap ( + sequence_kernel_c& a, + sequence_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + void sequence_kernel_c:: + add( + unsigned long pos, + T& item + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( pos <= this->size() ), + "\tvoid sequence::add" + << "\n\tpos must be >= 0 and <= size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + // call the real function + seq_base::add(pos,item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + void sequence_kernel_c:: + cat ( + sequence_kernel_c& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(&item != this, + "\tvoid sequence::cat" + << "\n\tyou can't concatenate a sequence onto itself" + << "\n\t&item: " << &item + << "\n\tthis: " << this + ); + + // call the real function + seq_base::cat(item); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + void sequence_kernel_c:: + remove ( + unsigned long pos, + T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( pos < this->size() ), + "\tvoid sequence::remove" + << "\n\tpos must be >= 0 and < size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + // call the real function + seq_base::remove(pos,item); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + const typename seq_base::type& sequence_kernel_c:: + operator[] ( + unsigned long pos + ) const + { + + // make sure requires clause is not broken + DLIB_CASSERT(( pos < this->size() ), + "\tconst T& sequence::operator[]" + << "\n\tpos must be >= 0 and < size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + // call the real function + return seq_base::operator[](pos); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + typename seq_base::type& sequence_kernel_c:: + operator[] ( + unsigned long pos + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( pos < this->size() ), + "\tT& sequence::operator[]" + << "\n\tpos must be >= 0 and < size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + // call the real function + return seq_base::operator[](pos); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + const typename seq_base::type& sequence_kernel_c:: + element ( + ) const + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst T& sequence::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return seq_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + typename seq_base::type& sequence_kernel_c:: + element ( + ) + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tT& sequence::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return seq_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + void sequence_kernel_c:: + remove_any ( + T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (this->size() > 0), + "\tvoid sequence::remove_any" + << "\n\tsize() must be greater than zero if something is going to be removed" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + // call the real function + seq_base::remove_any(item); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEQUENCE_KERNEl_C_ + diff --git a/ml/dlib/dlib/sequence/sequence_sort_1.h b/ml/dlib/dlib/sequence/sequence_sort_1.h new file mode 100644 index 000000000..2dd258e50 --- /dev/null +++ b/ml/dlib/dlib/sequence/sequence_sort_1.h @@ -0,0 +1,182 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEQUENCE_SORt_1_ +#define DLIB_SEQUENCE_SORt_1_ + +#include "sequence_sort_abstract.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename seq_base + > + class sequence_sort_1 : public seq_base + { + typedef typename seq_base::type T; + + public: + + /*! + this is a median of three version of the QuickSort algorithm and + it sorts sequences of less than 30 elements with a selection sort + !*/ + + void sort ( + ); + + private: + + void sort_this_sequence ( + seq_base& sequence + ); + /*! + ensures + - each element in the sequence is < the element behind it + !*/ + + void selection_sort ( + seq_base& sequence + ); + /*! + ensures + - sequence is sorted with a selection_sort + !*/ + + + }; + + + template < + typename seq_base + > + inline void swap ( + sequence_sort_1& a, + sequence_sort_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + void sequence_sort_1:: + sort ( + ) + { + if (this->size() > 1) + { + sort_this_sequence(*this); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + void sequence_sort_1:: + sort_this_sequence ( + seq_base& sequence + ) + { + if (sequence.size() < 30) + { + selection_sort(sequence); + } + else + { + seq_base left, right; + T partition_element; + + sequence.remove(0,partition_element); + + dlib::median ( + partition_element, + sequence[sequence.size()-1], + sequence[(sequence.size()-1)/2] + ); + + // partition sequence into left and right + T temp; + while (sequence.size() > 0) + { + sequence.remove(0,temp); + if (temp < partition_element) + { + left.add(0,temp); + } + else + { + right.add(0,temp); + } + } + + sort_this_sequence(left); + sort_this_sequence(right); + + // combine left and right into sequence + left.swap(sequence); + sequence.add(sequence.size(),partition_element); + sequence.cat(right); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + void sequence_sort_1:: + selection_sort ( + seq_base& sequence + ) + { + if (sequence.size() > 2) + { + T temp[29]; + unsigned long ssize = sequence.size(); + + for (unsigned long i = 0; i < ssize; ++i) + sequence.remove(0,temp[i]); + + unsigned long smallest; + for (unsigned long i = 0; i < ssize - 1; ++i) + { + // find smallest element and swap into i + smallest = i; + for (unsigned long j = i+1; j < ssize; ++j) + { + if (temp[j] < temp[smallest]) + smallest = j; + } + exchange(temp[smallest],temp[i]); + } + + for (unsigned long i = 0; i < ssize; ++i) + sequence.add(i,temp[i]); + } + else if (sequence.size() == 2) + { + if (sequence[1] < sequence[0]) + { + exchange(sequence[0],sequence[1]); + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEQUENCE_SORt_1_ + diff --git a/ml/dlib/dlib/sequence/sequence_sort_2.h b/ml/dlib/dlib/sequence/sequence_sort_2.h new file mode 100644 index 000000000..558d4e0d5 --- /dev/null +++ b/ml/dlib/dlib/sequence/sequence_sort_2.h @@ -0,0 +1,65 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEQUENCE_SORt_2_ +#define DLIB_SEQUENCE_SORt_2_ + +#include "sequence_sort_abstract.h" +#include "../algs.h" +#include "../sort.h" + +namespace dlib +{ + + template < + typename seq_base + > + class sequence_sort_2 : public seq_base + { + typedef typename seq_base::type T; + + public: + + /*! + this is a version of the QuickSort algorithm + this uses the dlib::qsort_array function + !*/ + + void sort ( + ); + + + }; + + template < + typename seq_base + > + inline void swap ( + sequence_sort_2& a, + sequence_sort_2& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename seq_base + > + void sequence_sort_2:: + sort ( + ) + { + if (this->size() > 1) + { + dlib::qsort_array(*this,0,this->size()-1); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEQUENCE_SORt_2_ + diff --git a/ml/dlib/dlib/sequence/sequence_sort_abstract.h b/ml/dlib/dlib/sequence/sequence_sort_abstract.h new file mode 100644 index 000000000..fe0a91220 --- /dev/null +++ b/ml/dlib/dlib/sequence/sequence_sort_abstract.h @@ -0,0 +1,65 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SEQUENCE_SORt_ABSTRACT_ +#ifdef DLIB_SEQUENCE_SORt_ABSTRACT_ + +#include "sequence_kernel_abstract.h" + +namespace dlib +{ + + template < + typename seq_base + > + class sequence_sort : public seq_base + { + + /*! + REQUIREMENTS ON T + T must implement operator< for its type + + REQUIREMENTS ON seq_base + must be an implementation of sequence/sequence_kernel_abstract.h + + + + POINTERS AND REFERENCES TO INTERNAL DATA + sort() may invalidate pointers and references to data members. + + WHAT THIS EXTENSION DOES FOR sequence + this gives a sequence the ability to sort its contents by calling sort() + !*/ + + + public: + + void sort ( + ); + /*! + ensures + - for all elements in #*this the ith element is <= the i+1 element + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + data may be lost if sort() throws + !*/ + + + }; + + + template < + typename seq_base + > + inline void swap ( + sequence_sort& a, + sequence_sort& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_SEQUENCE_SORt_ABSTRACT_ + diff --git a/ml/dlib/dlib/serialize.h b/ml/dlib/dlib/serialize.h new file mode 100644 index 000000000..f21bdaaff --- /dev/null +++ b/ml/dlib/dlib/serialize.h @@ -0,0 +1,1779 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERIALIZe_ +#define DLIB_SERIALIZe_ + +/*! + There are two global functions in the dlib namespace that provide serialization and + deserialization support. Their signatures and specifications are as follows: + + void serialize ( + const serializable_type& item, + std::ostream& out + ); + /!* + ensures + - writes the state of item to the output stream out + - if (serializable_type implements the enumerable interface) then + - item.at_start() == true + throws + - serialization_error + This exception is thrown if there is some problem which prevents + us from successfully writing item to the output stream. + - any other exception + *!/ + + void deserialize ( + serializable_type& item, + std::istream& in + ); + /!* + ensures + - #item == a deserialized copy of the serializable_type that was + in the input stream in. + - Reads all the bytes associated with the serialized serializable_type + contained inside the input stream and no more. This means you + can serialize multiple objects to an output stream and then read + them all back in, one after another, using deserialize(). + - if (serializable_type implements the enumerable interface) then + - item.at_start() == true + throws + - serialization_error + This exception is thrown if there is some problem which prevents + us from successfully deserializing item from the input stream. + If this exception is thrown then item will have an initial value + for its type. + - any other exception + *!/ + + For convenience, you can also serialize to a file using this syntax: + serialize("your_file.dat") << some_object << another_object; + + That overwrites the contents of your_file.dat with the serialized data from some_object + and another_object. Then to recall the objects from the file you can do: + deserialize("your_file.dat") >> some_object >> another_object; + + Finally, you can chain as many objects together using the << and >> operators as you + like. + + + This file provides serialization support to the following object types: + - The C++ base types (NOT including pointer types) + - std::string + - std::wstring + - std::vector + - std::array + - std::deque + - std::map + - std::set + - std::pair + - std::complex + - dlib::uint64 + - dlib::int64 + - float_details + - enumerable where T is a serializable type + - map_pair where D and R are both serializable types. + - C style arrays of serializable types + - Google protocol buffer objects. + + This file provides deserialization support to the following object types: + - The C++ base types (NOT including pointer types) + - std::string + - std::wstring + - std::vector + - std::array + - std::deque + - std::map + - std::set + - std::pair + - std::complex + - dlib::uint64 + - dlib::int64 + - float_details + - C style arrays of serializable types + - Google protocol buffer objects. + + Support for deserialization of objects which implement the enumerable or + map_pair interfaces is the responsibility of those objects. + + Note that you can deserialize an integer value to any integral type (except for a + char type) if its value will fit into the target integer type. I.e. the types + short, int, long, unsigned short, unsigned int, unsigned long, and dlib::uint64 + can all receive serialized data from each other so long as the actual serialized + value fits within the receiving integral type's range. + + Also note that for any container to be serializable the type of object it contains + must be serializable. + + FILE STREAMS + If you are serializing to and from file streams it is important to + remember to set the file streams to binary mode using the std::ios::binary + flag. + + + INTEGRAL SERIALIZATION FORMAT + All C++ integral types (except the char types) are serialized to the following + format: + The first byte is a control byte. It tells you if the serialized number is + positive or negative and also tells you how many of the following bytes are + part of the number. The absolute value of the number is stored in little + endian byte order and follows the control byte. + + The control byte: + The high order bit of the control byte is a flag that tells you if the + encoded number is negative or not. It is set to 1 when the number is + negative and 0 otherwise. + The 4 low order bits of the control byte represent an unsigned number + and tells you how many of the following bytes are part of the encoded + number. + + bool SERIALIZATION FORMAT + A bool value is serialized as the single byte character '1' or '0' in ASCII. + Where '1' indicates true and '0' indicates false. + + FLOATING POINT SERIALIZATION FORMAT + To serialize a floating point value we convert it into a float_details object and + then serialize the exponent and mantissa values using dlib's integral serialization + format. Therefore, the output is first the exponent and then the mantissa. Note that + the mantissa is a signed integer (i.e. there is not a separate sign bit). +!*/ + + +#include "algs.h" +#include "assert.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "uintn.h" +#include "interfaces/enumerable.h" +#include "interfaces/map_pair.h" +#include "enable_if.h" +#include "unicode.h" +#include "byte_orderer.h" +#include "float_details.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class serialization_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception object. It is thrown if serialization or + deserialization fails. + !*/ + + public: + serialization_error(const std::string& e):error(e) {} + }; + + + void check_serialized_version( + const std::string& expected_version, + std::istream& in + ); + /*! + ensures + - Deserializes a string from in and if it doesn't match expected_version we + throw serialization_error. + !*/ + +// ---------------------------------------------------------------------------------------- + + /*!A ramdump information !*/ + template + struct ramdump_t + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a type decoration used to indicate that serialization should be + done by simply dumping the memory of some object to disk as fast as + possible without any sort of conversions. This means that the data written + will be "non-portable" in the sense that the format output by a RAM dump + may depend on things like the endianness of your CPU or settings of certain + compiler switches. + + You use this object like this: + serialize("yourfile.dat") << ramdump(yourobject); + deserialize("yourfile.dat") >> ramdump(yourobject); + or + serialize(ramdump(yourobject), out); + deserialize(ramdump(yourobject), in); + + Also, not all objects have a ramdump mode. If you try to use ramdump on an + object that does not define a serialization dump for ramdump you will get a + compiler error. + !*/ + ramdump_t(T& item_) : item(item_) {} + T& item; + }; + + // This function just makes a ramdump that wraps an object. + template + ramdump_t::type> ramdump(T&& item) + { + return ramdump_t::type>(item); + } + + + template < + typename T + > + void serialize ( + const ramdump_t& item_, + std::ostream& out + ) + { + // Move the const from inside the ramdump_t template to outside so we can bind + // against a serialize() call that takes just a const ramdump_t. Doing this + // saves you from needing to write multiple overloads of serialize() to handle + // these different const placement cases. + const auto temp = ramdump(const_cast(item_.item)); + serialize(temp, out); + } + +// ---------------------------------------------------------------------------------------- + + namespace ser_helper + { + + template < + typename T + > + typename enable_if_c::is_signed,bool>::type pack_int ( + T item, + std::ostream& out + ) + /*! + requires + - T is a signed integral type + ensures + - if (no problems occur serializing item) then + - writes item to out + - returns false + - else + - returns true + !*/ + { + COMPILE_TIME_ASSERT(sizeof(T) <= 8); + unsigned char buf[9]; + unsigned char size = sizeof(T); + unsigned char neg; + if (item < 0) + { + neg = 0x80; + item *= -1; + } + else + { + neg = 0; + } + + for (unsigned char i = 1; i <= sizeof(T); ++i) + { + buf[i] = static_cast(item&0xFF); + item >>= 8; + if (item == 0) { size = i; break; } + } + + std::streambuf* sbuf = out.rdbuf(); + buf[0] = size|neg; + if (sbuf->sputn(reinterpret_cast(buf),size+1) != size+1) + { + out.setstate(std::ios::eofbit | std::ios::badbit); + return true; + } + + return false; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + typename enable_if_c::is_signed,bool>::type unpack_int ( + T& item, + std::istream& in + ) + /*! + requires + - T is a signed integral type + ensures + - if (there are no problems deserializing item) then + - returns false + - #item == the value stored in in + - else + - returns true + + !*/ + { + COMPILE_TIME_ASSERT(sizeof(T) <= 8); + + + unsigned char buf[8]; + unsigned char size; + bool is_negative; + + std::streambuf* sbuf = in.rdbuf(); + + item = 0; + int ch = sbuf->sbumpc(); + if (ch != EOF) + { + size = static_cast(ch); + } + else + { + in.setstate(std::ios::badbit); + return true; + } + + if (size&0x80) + is_negative = true; + else + is_negative = false; + size &= 0x0F; + + // check if the serialized object is too big + if (size > (unsigned long)tmin::value || size == 0) + { + return true; + } + + if (sbuf->sgetn(reinterpret_cast(&buf),size) != size) + { + in.setstate(std::ios::badbit); + return true; + } + + + for (unsigned char i = size-1; true; --i) + { + item <<= 8; + item |= buf[i]; + if (i == 0) + break; + } + + if (is_negative) + item *= -1; + + + return false; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + typename disable_if_c::is_signed,bool>::type pack_int ( + T item, + std::ostream& out + ) + /*! + requires + - T is an unsigned integral type + ensures + - if (no problems occur serializing item) then + - writes item to out + - returns false + - else + - returns true + !*/ + { + COMPILE_TIME_ASSERT(sizeof(T) <= 8); + unsigned char buf[9]; + unsigned char size = sizeof(T); + + for (unsigned char i = 1; i <= sizeof(T); ++i) + { + buf[i] = static_cast(item&0xFF); + item >>= 8; + if (item == 0) { size = i; break; } + } + + std::streambuf* sbuf = out.rdbuf(); + buf[0] = size; + if (sbuf->sputn(reinterpret_cast(buf),size+1) != size+1) + { + out.setstate(std::ios::eofbit | std::ios::badbit); + return true; + } + + return false; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + typename disable_if_c::is_signed,bool>::type unpack_int ( + T& item, + std::istream& in + ) + /*! + requires + - T is an unsigned integral type + ensures + - if (there are no problems deserializing item) then + - returns false + - #item == the value stored in in + - else + - returns true + + !*/ + { + COMPILE_TIME_ASSERT(sizeof(T) <= 8); + + unsigned char buf[8]; + unsigned char size; + + item = 0; + + std::streambuf* sbuf = in.rdbuf(); + int ch = sbuf->sbumpc(); + if (ch != EOF) + { + size = static_cast(ch); + } + else + { + in.setstate(std::ios::badbit); + return true; + } + + + // mask out the 3 reserved bits + size &= 0x8F; + + // check if an error occurred + if (size > (unsigned long)tmin::value || size == 0) + return true; + + + if (sbuf->sgetn(reinterpret_cast(&buf),size) != size) + { + in.setstate(std::ios::badbit); + return true; + } + + for (unsigned char i = size-1; true; --i) + { + item <<= 8; + item |= buf[i]; + if (i == 0) + break; + } + + return false; + } + + } + +// ---------------------------------------------------------------------------------------- + + #define USE_DEFAULT_INT_SERIALIZATION_FOR(T) \ + inline void serialize (const T& item, std::ostream& out) \ + { if (ser_helper::pack_int(item,out)) throw serialization_error("Error serializing object of type " + std::string(#T)); } \ + inline void deserialize (T& item, std::istream& in) \ + { if (ser_helper::unpack_int(item,in)) throw serialization_error("Error deserializing object of type " + std::string(#T)); } + + template + inline bool pack_byte ( + const T& ch, + std::ostream& out + ) + { + std::streambuf* sbuf = out.rdbuf(); + return (sbuf->sputc((char)ch) == EOF); + } + + template + inline bool unpack_byte ( + T& ch, + std::istream& in + ) + { + std::streambuf* sbuf = in.rdbuf(); + int temp = sbuf->sbumpc(); + if (temp != EOF) + { + ch = static_cast(temp); + return false; + } + else + { + return true; + } + } + + #define USE_DEFAULT_BYTE_SERIALIZATION_FOR(T) \ + inline void serialize (const T& item,std::ostream& out) \ + { if (pack_byte(item,out)) throw serialization_error("Error serializing object of type " + std::string(#T)); } \ + inline void deserialize (T& item, std::istream& in) \ + { if (unpack_byte(item,in)) throw serialization_error("Error deserializing object of type " + std::string(#T)); } + +// ---------------------------------------------------------------------------------------- + + USE_DEFAULT_INT_SERIALIZATION_FOR(short) + USE_DEFAULT_INT_SERIALIZATION_FOR(int) + USE_DEFAULT_INT_SERIALIZATION_FOR(long) + USE_DEFAULT_INT_SERIALIZATION_FOR(unsigned short) + USE_DEFAULT_INT_SERIALIZATION_FOR(unsigned int) + USE_DEFAULT_INT_SERIALIZATION_FOR(unsigned long) + USE_DEFAULT_INT_SERIALIZATION_FOR(uint64) + USE_DEFAULT_INT_SERIALIZATION_FOR(int64) + + USE_DEFAULT_BYTE_SERIALIZATION_FOR(char) + USE_DEFAULT_BYTE_SERIALIZATION_FOR(signed char) + USE_DEFAULT_BYTE_SERIALIZATION_FOR(unsigned char) + + // Don't define serialization for wchar_t when using visual studio and + // _NATIVE_WCHAR_T_DEFINED isn't defined since if it isn't they improperly set + // wchar_t to be a typedef rather than its own type as required by the C++ + // standard. +#if !defined(_MSC_VER) || _NATIVE_WCHAR_T_DEFINED + USE_DEFAULT_INT_SERIALIZATION_FOR(wchar_t) +#endif + +// ---------------------------------------------------------------------------------------- + + inline void serialize( + const float_details& item, + std::ostream& out + ) + { + serialize(item.mantissa, out); + serialize(item.exponent, out); + } + + inline void deserialize( + float_details& item, + std::istream& in + ) + { + deserialize(item.mantissa, in); + deserialize(item.exponent, in); + } + +// ---------------------------------------------------------------------------------------- + + template + inline void serialize_floating_point ( + const T& item, + std::ostream& out + ) + { + try + { + float_details temp = item; + serialize(temp, out); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing a floating point number."); } + } + + template + inline bool old_deserialize_floating_point ( + T& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in.flags(static_cast(0)); + std::streamsize ss = in.precision(35); + if (in.peek() == 'i') + { + item = std::numeric_limits::infinity(); + in.get(); + in.get(); + in.get(); + } + else if (in.peek() == 'n') + { + item = -std::numeric_limits::infinity(); + in.get(); + in.get(); + in.get(); + in.get(); + } + else if (in.peek() == 'N') + { + item = std::numeric_limits::quiet_NaN(); + in.get(); + in.get(); + in.get(); + } + else + { + in >> item; + } + in.flags(oldflags); + in.precision(ss); + return (in.get() != ' '); + } + + template + inline void deserialize_floating_point ( + T& item, + std::istream& in + ) + { + // check if the serialized data uses the older ASCII based format. We can check + // this easily because the new format starts with the integer control byte which + // always has 0 bits in the positions corresponding to the bitmask 0x70. Moreover, + // since the previous format used ASCII numbers we know that no valid bytes can + // have bit values of one in the positions indicated 0x70. So this test looks at + // the first byte and checks if the serialized data uses the old format or the new + // format. + if ((in.rdbuf()->sgetc()&0x70) == 0) + { + try + { + // Use the fast and compact binary serialization format. + float_details temp; + deserialize(temp, in); + item = temp; + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing a floating point number."); } + } + else + { + if (old_deserialize_floating_point(item, in)) + throw serialization_error("Error deserializing a floating point number."); + } + } + + inline void serialize ( const float& item, std::ostream& out) + { + serialize_floating_point(item,out); + } + + inline void deserialize (float& item, std::istream& in) + { + deserialize_floating_point(item,in); + } + + inline void serialize ( const double& item, std::ostream& out) + { + serialize_floating_point(item,out); + } + + inline void deserialize (double& item, std::istream& in) + { + deserialize_floating_point(item,in); + } + + inline void serialize ( const long double& item, std::ostream& out) + { + serialize_floating_point(item,out); + } + + inline void deserialize ( long double& item, std::istream& in) + { + deserialize_floating_point(item,in); + } + +// ---------------------------------------------------------------------------------------- +// prototypes + + template + void serialize ( + const std::map& item, + std::ostream& out + ); + + template + void deserialize ( + std::map& item, + std::istream& in + ); + + template + void serialize ( + const std::set& item, + std::ostream& out + ); + + template + void deserialize ( + std::set& item, + std::istream& in + ); + + template + void serialize ( + const std::vector& item, + std::ostream& out + ); + + template + void deserialize ( + std::vector& item, + std::istream& in + ); + + template + void serialize ( + const std::deque& item, + std::ostream& out + ); + + template + void deserialize ( + std::deque& item, + std::istream& in + ); + + inline void serialize ( + const std::string& item, + std::ostream& out + ); + + inline void deserialize ( + std::string& item, + std::istream& in + ); + + inline void serialize ( + const std::wstring& item, + std::ostream& out + ); + + inline void deserialize ( + std::wstring& item, + std::istream& in + ); + + inline void serialize ( + const ustring& item, + std::ostream& out + ); + + inline void deserialize ( + ustring& item, + std::istream& in + ); + + template < + typename T + > + inline void serialize ( + const enumerable& item, + std::ostream& out + ); + + template < + typename domain, + typename range + > + inline void serialize ( + const map_pair& item, + std::ostream& out + ); + + template < + typename T, + size_t length + > + inline void serialize ( + const T (&array)[length], + std::ostream& out + ); + + template < + typename T, + size_t length + > + inline void deserialize ( + T (&array)[length], + std::istream& in + ); + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + bool item, + std::ostream& out + ) + { + if (item) + out << '1'; + else + out << '0'; + + if (!out) + throw serialization_error("Error serializing object of type bool"); + } + + inline void deserialize ( + bool& item, + std::istream& in + ) + { + int ch = in.get(); + if (ch != EOF) + { + if (ch == '1') + item = true; + else if (ch == '0') + item = false; + else + throw serialization_error("Error deserializing object of type bool"); + } + else + { + throw serialization_error("Error deserializing object of type bool"); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std::pair& item, + std::ostream& out + ) + { + try + { + serialize(item.first,out); + serialize(item.second,out); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std::pair"); } + } + + template + void deserialize ( + std::pair& item, + std::istream& in + ) + { + try + { + deserialize(item.first,in); + deserialize(item.second,in); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std::pair"); } + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std::map& item, + std::ostream& out + ) + { + try + { + const unsigned long size = static_cast(item.size()); + + serialize(size,out); + typename std::map::const_iterator i; + for (i = item.begin(); i != item.end(); ++i) + { + serialize(i->first,out); + serialize(i->second,out); + } + + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std::map"); } + } + + template + void deserialize ( + std::map& item, + std::istream& in + ) + { + try + { + item.clear(); + + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item[d] = r; + } + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std::map"); } + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std::set& item, + std::ostream& out + ) + { + try + { + const unsigned long size = static_cast(item.size()); + + serialize(size,out); + typename std::set::const_iterator i; + for (i = item.begin(); i != item.end(); ++i) + { + serialize(*i,out); + } + + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std::set"); } + } + + template + void deserialize ( + std::set& item, + std::istream& in + ) + { + try + { + item.clear(); + + unsigned long size; + deserialize(size,in); + domain d; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + item.insert(d); + } + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std::set"); } + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std::vector& item, + std::ostream& out + ) + { + std::vector temp(item.size()); + for (unsigned long i = 0; i < item.size(); ++i) + { + if (item[i]) + temp[i] = '1'; + else + temp[i] = '0'; + } + serialize(temp, out); + } + + template + void deserialize ( + std::vector& item, + std::istream& in + ) + { + std::vector temp; + deserialize(temp, in); + item.resize(temp.size()); + for (unsigned long i = 0; i < temp.size(); ++i) + { + if (temp[i] == '1') + item[i] = true; + else + item[i] = false; + } + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std::vector& item, + std::ostream& out + ) + { + try + { + const unsigned long size = static_cast(item.size()); + + serialize(size,out); + for (unsigned long i = 0; i < item.size(); ++i) + serialize(item[i],out); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std::vector"); } + } + + template + void deserialize ( + std::vector& item, + std::istream& in + ) + { + try + { + unsigned long size; + deserialize(size,in); + item.resize(size); + for (unsigned long i = 0; i < size; ++i) + deserialize(item[i],in); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std::vector"); } + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std::vector& item, + std::ostream& out + ) + { + try + { + const unsigned long size = static_cast(item.size()); + serialize(size,out); + if (item.size() != 0) + out.write(&item[0], item.size()); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std::vector"); } + } + + template + void deserialize ( + std::vector& item, + std::istream& in + ) + { + try + { + unsigned long size; + deserialize(size,in); + item.resize(size); + if (item.size() != 0) + in.read(&item[0], item.size()); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std::vector"); } + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std::vector& item, + std::ostream& out + ) + { + try + { + const unsigned long size = static_cast(item.size()); + serialize(size,out); + if (item.size() != 0) + out.write((char*)&item[0], item.size()); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std::vector"); } + } + + template + void deserialize ( + std::vector& item, + std::istream& in + ) + { + try + { + unsigned long size; + deserialize(size,in); + item.resize(size); + if (item.size() != 0) + in.read((char*)&item[0], item.size()); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std::vector"); } + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std::deque& item, + std::ostream& out + ) + { + try + { + const unsigned long size = static_cast(item.size()); + + serialize(size,out); + for (unsigned long i = 0; i < item.size(); ++i) + serialize(item[i],out); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std::deque"); } + } + + template + void deserialize ( + std::deque& item, + std::istream& in + ) + { + try + { + unsigned long size; + deserialize(size,in); + item.resize(size); + for (unsigned long i = 0; i < size; ++i) + deserialize(item[i],in); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std::deque"); } + } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const std::string& item, + std::ostream& out + ) + { + const unsigned long size = static_cast(item.size()); + try{ serialize(size,out); } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std::string"); } + + out.write(item.c_str(),size); + if (!out) throw serialization_error("Error serializing object of type std::string"); + } + + inline void deserialize ( + std::string& item, + std::istream& in + ) + { + unsigned long size; + try { deserialize(size,in); } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std::string"); } + + item.resize(size); + if (size != 0) + { + in.read(&item[0],size); + if (!in) throw serialization_error("Error deserializing object of type std::string"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const std::wstring& item, + std::ostream& out + ) + { + const unsigned long size = static_cast(item.size()); + try{ serialize(size,out); } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std::wstring"); } + + for (unsigned long i = 0; i < item.size(); ++i) + serialize(item[i], out); + if (!out) throw serialization_error("Error serializing object of type std::wstring"); + } + + inline void deserialize ( + std::wstring& item, + std::istream& in + ) + { + unsigned long size; + try { deserialize(size,in); } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std::wstring"); } + + item.resize(size); + for (unsigned long i = 0; i < item.size(); ++i) + deserialize(item[i],in); + + if (!in) throw serialization_error("Error deserializing object of type std::wstring"); + } + +// ---------------------------------------------------------------------------------------- + + inline void serialize ( + const ustring& item, + std::ostream& out + ) + { + const unsigned long size = static_cast(item.size()); + try{ serialize(size,out); } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type ustring"); } + + for (unsigned long i = 0; i < item.size(); ++i) + serialize(item[i], out); + if (!out) throw serialization_error("Error serializing object of type ustring"); + } + + inline void deserialize ( + ustring& item, + std::istream& in + ) + { + unsigned long size; + try { deserialize(size,in); } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type ustring"); } + + item.resize(size); + for (unsigned long i = 0; i < item.size(); ++i) + deserialize(item[i],in); + + if (!in) throw serialization_error("Error deserializing object of type ustring"); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + inline void serialize ( + const enumerable& item, + std::ostream& out + ) + { + try + { + item.reset(); + serialize(item.size(),out); + while (item.move_next()) + serialize(item.element(),out); + item.reset(); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type enumerable"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range + > + inline void serialize ( + const map_pair& item, + std::ostream& out + ) + { + try + { + serialize(item.key(),out); + serialize(item.value(),out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type map_pair"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + size_t length + > + inline void serialize ( + const T (&array)[length], + std::ostream& out + ) + { + try + { + serialize(length,out); + for (size_t i = 0; i < length; ++i) + serialize(array[i],out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing a C style array"); + } + } + + template < + size_t length + > + inline void serialize ( + const char (&array)[length], + std::ostream& out + ) + { + if (length != 0 && array[length-1] == '\0') + { + // If this is a null terminated string then don't serialize the trailing null. + // We do this so that the serialization format for C-strings is the same as + // std::string. + serialize(length-1, out); + out.write(array, length-1); + if (!out) + throw serialization_error("Error serializing a C-style string"); + } + else + { + try + { + serialize(length,out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing a C style array"); + } + if (length != 0) + out.write(array, length); + if (!out) + throw serialization_error("Error serializing a C-style string"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + size_t length + > + inline void deserialize ( + T (&array)[length], + std::istream& in + ) + { + size_t size; + try + { + deserialize(size,in); + if (size == length) + { + for (size_t i = 0; i < length; ++i) + deserialize(array[i],in); + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing a C style array"); + } + + if (size != length) + throw serialization_error("Error deserializing a C style array, lengths do not match"); + } + + template < + size_t length + > + inline void deserialize ( + char (&array)[length], + std::istream& in + ) + { + size_t size; + try + { + deserialize(size,in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing a C style array"); + } + + if (size == length) + { + in.read(array, size); + if (!in) + throw serialization_error("Error deserializing a C-style array"); + } + else if (size+1 == length) + { + // In this case we are deserializing a C-style array so we need to add the null + // terminator. + in.read(array, size); + array[size] = '\0'; + if (!in) + throw serialization_error("Error deserializing a C-style string"); + } + else + { + throw serialization_error("Error deserializing a C style array, lengths do not match"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + size_t N + > + inline void serialize ( + const std::array& array, + std::ostream& out + ) + { + typedef T c_array_type[N]; + serialize(*(const c_array_type*)array.data(), out); + } + + template < + typename T, + size_t N + > + inline void deserialize ( + std::array& array, + std::istream& in + ) + { + typedef T c_array_type[N]; + deserialize(*(c_array_type*)array.data(), in); + } + + template < + typename T + > + inline void serialize ( + const std::array& /*array*/, + std::ostream& out + ) + { + size_t N = 0; + serialize(N, out); + } + + template < + typename T + > + inline void deserialize ( + std::array& /*array*/, + std::istream& in + ) + { + size_t N; + deserialize(N, in); + if (N != 0) + { + std::ostringstream sout; + sout << "Expected std::array of size 0 but found a size of " << N; + throw serialization_error(sout.str()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + inline void serialize ( + const std::complex& item, + std::ostream& out + ) + { + try + { + serialize(item.real(),out); + serialize(item.imag(),out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing an object of type std::complex"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + inline void deserialize ( + std::complex& item, + std::istream& in + ) + { + try + { + T real, imag; + deserialize(real,in); + deserialize(imag,in); + item = std::complex(real,imag); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing an object of type std::complex"); + } + } + +// ---------------------------------------------------------------------------------------- + + class proxy_serialize + { + public: + explicit proxy_serialize ( + const std::string& filename + ) + { + fout.reset(new std::ofstream(filename.c_str(), std::ios::binary)); + if (!(*fout)) + throw serialization_error("Unable to open " + filename + " for writing."); + } + + template + inline proxy_serialize& operator<<(const T& item) + { + serialize(item, *fout); + return *this; + } + + private: + std::shared_ptr fout; + }; + + class proxy_deserialize + { + public: + explicit proxy_deserialize ( + const std::string& filename + ) : filename(filename) + { + fin.reset(new std::ifstream(filename.c_str(), std::ios::binary)); + if (!(*fin)) + throw serialization_error("Unable to open " + filename + " for reading."); + + // read the file header into a buffer and then seek back to the start of the + // file. + fin->read(file_header,4); + fin->clear(); + fin->seekg(0); + } + + template + inline proxy_deserialize& operator>>(T& item) + { + return doit(item); + } + + template + inline proxy_deserialize& operator>>(ramdump_t&& item) + { + return doit(std::move(item)); + } + + private: + template + inline proxy_deserialize& doit(T&& item) + { + try + { + if (fin->peek() == EOF) + throw serialization_error("No more objects were in the file!"); + deserialize(std::forward(item), *fin); + } + catch (serialization_error& e) + { + std::string suffix; + if (looks_like_a_compressed_file()) + suffix = "\n *** THIS LOOKS LIKE A COMPRESSED FILE. DID YOU FORGET TO DECOMPRESS IT? *** \n"; + + if (objects_read == 0) + { + throw serialization_error("An error occurred while trying to read the first" + " object from the file " + filename + ".\nERROR: " + e.info + "\n" + suffix); + } + else if (objects_read == 1) + { + throw serialization_error("An error occurred while trying to read the second" + " object from the file " + filename + + ".\nERROR: " + e.info + "\n" + suffix); + } + else if (objects_read == 2) + { + throw serialization_error("An error occurred while trying to read the third" + " object from the file " + filename + + ".\nERROR: " + e.info + "\n" + suffix); + } + else + { + throw serialization_error("An error occurred while trying to read the " + + std::to_string(objects_read+1) + "th object from the file " + filename + + ".\nERROR: " + e.info + "\n" + suffix); + } + } + ++objects_read; + return *this; + } + + int objects_read = 0; + std::string filename; + std::shared_ptr fin; + + // We don't need to look at the file header. However, it's here because people + // keep posting questions to the dlib forums asking why they get file load errors. + // Then it turns out that the problem is they have a compressed file that NEEDS TO + // BE DECOMPRESSED by bzip2 or whatever and the reason they are getting + // deserialization errors is because they didn't decompress the file. So we are + // going to check if this file looks like a compressed file and if so then emit an + // error message telling them to unzip the file. :( + char file_header[4] = {0,0,0,0}; + + bool looks_like_a_compressed_file( + ) const + { + if (file_header[0] == 'B' && file_header[1] == 'Z' && file_header[2] == 'h' && + ('0' <= file_header[3] && file_header[3] <= '9') ) + { + return true; + } + + return false; + } + }; + + inline proxy_serialize serialize(const std::string& filename) + { return proxy_serialize(filename); } + inline proxy_deserialize deserialize(const std::string& filename) + { return proxy_deserialize(filename); } + +// ---------------------------------------------------------------------------------------- + +} + +// forward declare the MessageLite object so we can reference it below. +namespace google +{ + namespace protobuf + { + class MessageLite; + } +} + +namespace dlib +{ + + /*!A is_protocol_buffer + This is a template that tells you if a type is a Google protocol buffer object. + !*/ + + template + struct is_protocol_buffer + { + static const bool value = false; + }; + + template + struct is_protocol_buffer >::type > + { + static const bool value = true; + }; + + template + typename enable_if >::type serialize(const T& item, std::ostream& out) + { + // Note that Google protocol buffer messages are not self delimiting + // (see https://developers.google.com/protocol-buffers/docs/techniques) + // This means they don't record their length or where they end, so we have + // to record this information ourselves. So we save the size as a little endian 32bit + // integer prefixed onto the front of the message. + + byte_orderer bo; + + // serialize into temp string + std::string temp; + if (!item.SerializeToString(&temp)) + throw dlib::serialization_error("Error while serializing a Google Protocol Buffer object."); + if (temp.size() > std::numeric_limits::max()) + throw dlib::serialization_error("Error while serializing a Google Protocol Buffer object, message too large."); + + // write temp to the output stream + uint32 size = temp.size(); + bo.host_to_little(size); + out.write((char*)&size, sizeof(size)); + out.write(temp.c_str(), temp.size()); + } + + template + typename enable_if >::type deserialize(T& item, std::istream& in) + { + // Note that Google protocol buffer messages are not self delimiting + // (see https://developers.google.com/protocol-buffers/docs/techniques) + // This means they don't record their length or where they end, so we have + // to record this information ourselves. So we save the size as a little endian 32bit + // integer prefixed onto the front of the message. + + byte_orderer bo; + + uint32 size = 0; + // read the size + in.read((char*)&size, sizeof(size)); + bo.little_to_host(size); + if (!in || size == 0) + throw dlib::serialization_error("Error while deserializing a Google Protocol Buffer object."); + + // read the bytes into temp + std::string temp; + temp.resize(size); + in.read(&temp[0], size); + + // parse temp into item + if (!in || !item.ParseFromString(temp)) + { + throw dlib::serialization_error("Error while deserializing a Google Protocol Buffer object."); + } + } + +// ---------------------------------------------------------------------------------------- + + inline void check_serialized_version(const std::string& expected_version, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != expected_version) + { + throw serialization_error("Unexpected version '"+version+ + "' found while deserializing object. Expected version to be '"+expected_version+"'."); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SERIALIZe_ + diff --git a/ml/dlib/dlib/server.h b/ml/dlib/dlib/server.h new file mode 100644 index 000000000..d346941ab --- /dev/null +++ b/ml/dlib/dlib/server.h @@ -0,0 +1,12 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERVEr_ +#define DLIB_SERVEr_ + +#include "server/server_kernel.h" +#include "server/server_iostream.h" +#include "server/server_http.h" + + +#endif // DLIB_SERVEr_ + diff --git a/ml/dlib/dlib/server/server_http.cpp b/ml/dlib/dlib/server/server_http.cpp new file mode 100644 index 000000000..9e3051a43 --- /dev/null +++ b/ml/dlib/dlib/server/server_http.cpp @@ -0,0 +1,409 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERVER_HTTP_CPp_ +#define DLIB_SERVER_HTTP_CPp_ + +#include "server_http.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace http_impl + { + inline unsigned char to_hex( unsigned char x ) + { + return x + (x > 9 ? ('A'-10) : '0'); + } + + const std::string urlencode( const std::string& s ) + { + std::ostringstream os; + + for ( std::string::const_iterator ci = s.begin(); ci != s.end(); ++ci ) + { + if ( (*ci >= 'a' && *ci <= 'z') || + (*ci >= 'A' && *ci <= 'Z') || + (*ci >= '0' && *ci <= '9') ) + { // allowed + os << *ci; + } + else if ( *ci == ' ') + { + os << '+'; + } + else + { + os << '%' << to_hex(*ci >> 4) << to_hex(*ci % 16); + } + } + + return os.str(); + } + + inline unsigned char from_hex ( + unsigned char ch + ) + { + if (ch <= '9' && ch >= '0') + ch -= '0'; + else if (ch <= 'f' && ch >= 'a') + ch -= 'a' - 10; + else if (ch <= 'F' && ch >= 'A') + ch -= 'A' - 10; + else + ch = 0; + return ch; + } + + const std::string urldecode ( + const std::string& str + ) + { + using namespace std; + string result; + string::size_type i; + for (i = 0; i < str.size(); ++i) + { + if (str[i] == '+') + { + result += ' '; + } + else if (str[i] == '%' && str.size() > i+2) + { + const unsigned char ch1 = from_hex(str[i+1]); + const unsigned char ch2 = from_hex(str[i+2]); + const unsigned char ch = (ch1 << 4) | ch2; + result += ch; + i += 2; + } + else + { + result += str[i]; + } + } + return result; + } + + void parse_url( + std::string word, + key_value_map& queries + ) + /*! + Parses the query string of a URL. word should be the stuff that comes + after the ? in the query URL. + !*/ + { + std::string::size_type pos; + + for (pos = 0; pos < word.size(); ++pos) + { + if (word[pos] == '&') + word[pos] = ' '; + } + + std::istringstream sin(word); + sin >> word; + while (sin) + { + pos = word.find_first_of("="); + if (pos != std::string::npos) + { + std::string key = urldecode(word.substr(0,pos)); + std::string value = urldecode(word.substr(pos+1)); + + queries[key] = value; + } + sin >> word; + } + } + + void read_with_limit( + std::istream& in, + std::string& buffer, + int delim = '\n' + ) + { + using namespace std; + const size_t max = 64*1024; + buffer.clear(); + buffer.reserve(300); + + while (in.peek() != delim && in.peek() != '\n' && in.peek() != EOF && buffer.size() < max) + { + buffer += (char)in.get(); + } + + // if we quit the loop because the data is longer than expected or we hit EOF + if (in.peek() == EOF) + throw http_parse_error("HTTP field from client terminated incorrectly", 414); + if (buffer.size() == max) + throw http_parse_error("HTTP field from client is too long", 414); + + in.get(); + // eat any remaining whitespace + if (delim == ' ') + { + while (in.peek() == ' ') + in.get(); + } + } + } + +// ---------------------------------------------------------------------------------------- + + unsigned long parse_http_request ( + std::istream& in, + incoming_things& incoming, + unsigned long max_content_length + ) + { + using namespace std; + using namespace http_impl; + read_with_limit(in, incoming.request_type, ' '); + + // get the path + read_with_limit(in, incoming.path, ' '); + + // Get the HTTP/1.1 - Ignore for now... + read_with_limit(in, incoming.protocol); + + key_value_map_ci& incoming_headers = incoming.headers; + key_value_map& cookies = incoming.cookies; + std::string& path = incoming.path; + std::string& content_type = incoming.content_type; + unsigned long content_length = 0; + + string line; + read_with_limit(in, line); + string first_part_of_header; + string::size_type position_of_double_point; + // now loop over all the incoming_headers + while (line != "\r") + { + position_of_double_point = line.find_first_of(':'); + if ( position_of_double_point != string::npos ) + { + first_part_of_header = dlib::trim(line.substr(0, position_of_double_point)); + + if ( !incoming_headers[first_part_of_header].empty() ) + incoming_headers[ first_part_of_header ] += " "; + incoming_headers[first_part_of_header] += dlib::trim(line.substr(position_of_double_point+1)); + + // look for Content-Type: + if (line.size() > 14 && strings_equal_ignore_case(line, "Content-Type:", 13)) + { + content_type = line.substr(14); + if (content_type[content_type.size()-1] == '\r') + content_type.erase(content_type.size()-1); + } + // look for Content-Length: + else if (line.size() > 16 && strings_equal_ignore_case(line, "Content-Length:", 15)) + { + istringstream sin(line.substr(16)); + sin >> content_length; + if (!sin) + { + throw http_parse_error("Invalid Content-Length of '" + line.substr(16) + "'", 411); + } + + if (content_length > max_content_length) + { + std::ostringstream sout; + sout << "Content-Length of post back is too large. It must be less than " << max_content_length; + throw http_parse_error(sout.str(), 413); + } + } + // look for any cookies + else if (line.size() > 6 && strings_equal_ignore_case(line, "Cookie:", 7)) + { + string::size_type pos = 6; + string key, value; + bool seen_key_start = false; + bool seen_equal_sign = false; + while (pos + 1 < line.size()) + { + ++pos; + // ignore whitespace between cookies + if (!seen_key_start && line[pos] == ' ') + continue; + + seen_key_start = true; + if (!seen_equal_sign) + { + if (line[pos] == '=') + { + seen_equal_sign = true; + } + else + { + key += line[pos]; + } + } + else + { + if (line[pos] == ';') + { + cookies[urldecode(key)] = urldecode(value); + seen_equal_sign = false; + seen_key_start = false; + key.clear(); + value.clear(); + } + else + { + value += line[pos]; + } + } + } + if (key.size() > 0) + { + cookies[urldecode(key)] = urldecode(value); + key.clear(); + value.clear(); + } + } + } // no ':' in it! + read_with_limit(in, line); + } // while (line != "\r") + + + // If there is data being posted back to us as a query string then + // pick out the queries using parse_url. + if ((strings_equal_ignore_case(incoming.request_type, "POST") || + strings_equal_ignore_case(incoming.request_type, "PUT")) && + strings_equal_ignore_case(left_substr(content_type,";"), "application/x-www-form-urlencoded")) + { + if (content_length > 0) + { + incoming.body.resize(content_length); + in.read(&incoming.body[0],content_length); + } + parse_url(incoming.body, incoming.queries); + } + + string::size_type pos = path.find_first_of("?"); + if (pos != string::npos) + { + parse_url(path.substr(pos+1), incoming.queries); + } + + + if (!in) + throw http_parse_error("Error parsing HTTP request", 500); + + return content_length; + } + +// ---------------------------------------------------------------------------------------- + + void read_body ( + std::istream& in, + incoming_things& incoming + ) + { + // if the body hasn't already been loaded and there is data to load + if (incoming.body.size() == 0 && + incoming.headers.count("Content-Length") != 0) + { + const unsigned long content_length = string_cast(incoming.headers["Content-Length"]); + + incoming.body.resize(content_length); + if (content_length > 0) + { + in.read(&incoming.body[0],content_length); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void write_http_response ( + std::ostream& out, + outgoing_things outgoing, + const std::string& result + ) + { + using namespace http_impl; + key_value_map& new_cookies = outgoing.cookies; + key_value_map_ci& response_headers = outgoing.headers; + + // only send this header if the user hasn't told us to send another kind + bool has_content_type = false, has_location = false; + for(key_value_map_ci::const_iterator ci = response_headers.begin(); ci != response_headers.end(); ++ci ) + { + if ( !has_content_type && strings_equal_ignore_case(ci->first , "content-type") ) + { + has_content_type = true; + } + else if ( !has_location && strings_equal_ignore_case(ci->first , "location") ) + { + has_location = true; + } + } + + if ( has_location ) + { + outgoing.http_return = 302; + } + + if ( !has_content_type ) + { + response_headers["Content-Type"] = "text/html"; + } + + response_headers["Content-Length"] = cast_to_string(result.size()); + + out << "HTTP/1.0 " << outgoing.http_return << " " << outgoing.http_return_status << "\r\n"; + + // Set any new headers + for(key_value_map_ci::const_iterator ci = response_headers.begin(); ci != response_headers.end(); ++ci ) + { + out << ci->first << ": " << ci->second << "\r\n"; + } + + // set any cookies + for(key_value_map::const_iterator ci = new_cookies.begin(); ci != new_cookies.end(); ++ci ) + { + out << "Set-Cookie: " << urlencode(ci->first) << '=' << urlencode(ci->second) << "\r\n"; + } + out << "\r\n" << result; + } + +// ---------------------------------------------------------------------------------------- + + void write_http_response ( + std::ostream& out, + const http_parse_error& e + ) + { + outgoing_things outgoing; + outgoing.http_return = e.http_error_code; + outgoing.http_return_status = e.what(); + write_http_response(out, outgoing, std::string("Error processing request: ") + e.what()); + } + +// ---------------------------------------------------------------------------------------- + + void write_http_response ( + std::ostream& out, + const std::exception& e + ) + { + outgoing_things outgoing; + outgoing.http_return = 500; + outgoing.http_return_status = e.what(); + write_http_response(out, outgoing, std::string("Error processing request: ") + e.what()); + } + +// ---------------------------------------------------------------------------------------- + + const logger server_http::dlog("dlib.server_http"); + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SERVER_HTTP_CPp_ + diff --git a/ml/dlib/dlib/server/server_http.h b/ml/dlib/dlib/server/server_http.h new file mode 100644 index 000000000..4e95f679f --- /dev/null +++ b/ml/dlib/dlib/server/server_http.h @@ -0,0 +1,242 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net), Steven Van Ingelgem +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERVER_HTTp_1_ +#define DLIB_SERVER_HTTp_1_ + + +#include "server_http_abstract.h" +#include +#include +#include +#include +#include +#include "../logger.h" +#include "../string.h" +#include "server_iostream.h" + +#ifdef __INTEL_COMPILER +// ignore the bogus warning about hiding on_connect() +#pragma warning (disable: 1125) +#endif + +#if _MSC_VER +# pragma warning( disable: 4503 ) +#endif // _MSC_VER + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class http_parse_error : public error + { + public: + http_parse_error(const std::string& str, int http_error_code_): + error(str),http_error_code(http_error_code_) {} + + const int http_error_code; + }; + +// ---------------------------------------------------------------------------------------- + + template > + class constmap : public std::map + { + public: + const Value& operator[](const Key& k) const + { + static const Value dummy = Value(); + + typename std::map::const_iterator ci = std::map::find(k); + + if ( ci == this->end() ) + return dummy; + else + return ci->second; + } + + Value& operator[](const Key& k) + { + return std::map::operator [](k); + } + }; + + + class less_case_insensitive + { + public: + bool operator()(const std::string& a, const std::string& b) const + { + unsigned long i = 0; + while (i < a.size() && i < b.size()) + { + const int cha = std::tolower(a[i]); + const int chb = std::tolower(b[i]); + if (cha < chb) + return true; + else if (cha > chb) + return false; + ++i; + } + if (a.size() < b.size()) + return true; + else + return false; + } + }; + typedef constmap< std::string, std::string, less_case_insensitive > key_value_map_ci; + typedef constmap< std::string, std::string > key_value_map; + + struct incoming_things + { + incoming_things ( + const std::string& foreign_ip_, + const std::string& local_ip_, + unsigned short foreign_port_, + unsigned short local_port_ + ): + foreign_ip(foreign_ip_), + foreign_port(foreign_port_), + local_ip(local_ip_), + local_port(local_port_) + {} + + + std::string path; + std::string request_type; + std::string content_type; + std::string protocol; + std::string body; + + key_value_map queries; + key_value_map cookies; + key_value_map_ci headers; + + std::string foreign_ip; + unsigned short foreign_port; + std::string local_ip; + unsigned short local_port; + }; + + struct outgoing_things + { + outgoing_things() : http_return(200), http_return_status("OK") { } + + key_value_map cookies; + key_value_map_ci headers; + unsigned short http_return; + std::string http_return_status; + }; + +// ---------------------------------------------------------------------------------------- + + unsigned long parse_http_request ( + std::istream& in, + incoming_things& incoming, + unsigned long max_content_length + ); + + void read_body ( + std::istream& in, + incoming_things& incoming + ); + + void write_http_response ( + std::ostream& out, + outgoing_things outgoing, + const std::string& result + ); + + void write_http_response ( + std::ostream& out, + const http_parse_error& e + ); + + void write_http_response ( + std::ostream& out, + const std::exception& e + ); + +// ---------------------------------------------------------------------------------------- + + class server_http : public server_iostream + { + + public: + + server_http() + { + max_content_length = 10*1024*1024; // 10MB + } + + unsigned long get_max_content_length ( + ) const + { + auto_mutex lock(http_class_mutex); + return max_content_length; + } + + void set_max_content_length ( + unsigned long max_length + ) + { + auto_mutex lock(http_class_mutex); + max_content_length = max_length; + } + + + private: + virtual const std::string on_request ( + const incoming_things& incoming, + outgoing_things& outgoing + ) = 0; + + + virtual void on_connect ( + std::istream& in, + std::ostream& out, + const std::string& foreign_ip, + const std::string& local_ip, + unsigned short foreign_port, + unsigned short local_port, + uint64 + ) + { + try + { + incoming_things incoming(foreign_ip, local_ip, foreign_port, local_port); + outgoing_things outgoing; + + parse_http_request(in, incoming, get_max_content_length()); + read_body(in, incoming); + const std::string& result = on_request(incoming, outgoing); + write_http_response(out, outgoing, result); + } + catch (http_parse_error& e) + { + dlog << LERROR << "Error processing request from: " << foreign_ip << " - " << e.what(); + write_http_response(out, e); + } + catch (std::exception& e) + { + dlog << LERROR << "Error processing request from: " << foreign_ip << " - " << e.what(); + write_http_response(out, e); + } + } + + mutex http_class_mutex; + unsigned long max_content_length; + const static logger dlog; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "server_http.cpp" +#endif + +#endif // DLIB_SERVER_HTTp_1_ + diff --git a/ml/dlib/dlib/server/server_http_abstract.h b/ml/dlib/dlib/server/server_http_abstract.h new file mode 100644 index 000000000..0bebfb61d --- /dev/null +++ b/ml/dlib/dlib/server/server_http_abstract.h @@ -0,0 +1,390 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net), Steven Van Ingelgem +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SERVER_HTTp_ABSTRACT_ +#ifdef DLIB_SERVER_HTTp_ABSTRACT_ + +#include "server_iostream_abstract.h" +#include +#include +#include + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------- + + template < + typename Key, + typename Value, + typename Comparer = std::less + > + class constmap : public std::map + { + /*! + WHAT THIS OBJECT REPRESENTS + This is simply an extension to the std::map that allows you + to use the operator[] accessor with a constant map. + !*/ + public: + + const Value& operator[]( + const Key& k + ) const; + /*! + ensures + - if (this->find(k) != this->end()) then + - This map contains the given key + - return the value associated with the given key + - else + - return a default initialized Value object + !*/ + + Value& operator[]( + const Key& k + ) { return std::map::operator [](k); } + /*! + ensures + - This function does the same thing as the normal std::map operator[] + function. + - if (this->find(k) != this->end()) then + - This map contains the given key + - return the value associated with the given key + - else + - Adds a new entry into the map that is associated with the + given key. The new entry will be default initialized and + this function returns a reference to it. + !*/ + }; + + typedef constmap key_value_map; + // This version of key_value_map treats the key string as being case-insensitive. + // For example, a key string of "Content-Type" would access the same element as a key + // of "content-type". + typedef constmap key_value_map_ci; + +// ----------------------------------------------------------------------------------------- + + struct incoming_things + { + /*! + WHAT THIS OBJECT REPRESENTS + This object contains all the various bits of information that describe a + HTTP request that comes into a web server. + + For a detailed discussion of the fields of this object, see the + server_http::on_request() method defined later in this file. + !*/ + + incoming_things ( + const std::string& foreign_ip_, + const std::string& local_ip_, + unsigned short foreign_port_, + unsigned short local_port_ + ); + /*! + ensures + - #foreign_ip = foreign_ip_ + - #foreign_port = foreign_port_ + - #local_ip = local_ip_ + - #local_port = local_port_ + !*/ + + std::string path; + std::string request_type; + std::string content_type; + std::string protocol; + std::string body; + + key_value_map queries; + key_value_map cookies; + key_value_map_ci headers; + + std::string foreign_ip; + unsigned short foreign_port; + std::string local_ip; + unsigned short local_port; + }; + + struct outgoing_things + { + /*! + WHAT THIS OBJECT REPRESENTS + This object contains all the various bits of information that describe a + HTTP response from a web server. + + For a detailed discussion of the fields of this object, see the + server_http::on_request() method defined later in this file. + !*/ + + outgoing_things( + ); + /*! + ensures + - #http_return == 200 + - #http_return_status == "OK" + !*/ + + key_value_map cookies; + key_value_map_ci headers; + unsigned short http_return; + std::string http_return_status; + }; + +// ----------------------------------------------------------------------------------------- + + class http_parse_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an exception thrown by the parse_http_request() routine if + there is a problem. + !*/ + }; + +// ----------------------------------------------------------------------------------------- + + unsigned long parse_http_request ( + std::istream& in, + incoming_things& incoming, + unsigned long max_content_length + ); + /*! + ensures + - Attempts to read a HTTP GET, POST, or PUT request from the given input + stream. + - Reads all headers of the request and puts them into #incoming. In particular, + this function populates the following fields: + - #incoming.path + - #incoming.request_type + - #incoming.content_type + - #incoming.protocol + - #incoming.queries + - #incoming.cookies + - #incoming.headers + - This function also populates the #incoming.body field if and only if the + Content-Type field is equal to "application/x-www-form-urlencoded". + Otherwise, the content is not read from the stream. + throws + - http_parse_error + This exception is thrown if the Content-Length coming from the web + browser is greater than max_content_length or if any other problem + is detected with the request. + !*/ + + void read_body ( + std::istream& in, + incoming_things& incoming + ); + /*! + requires + - parse_http_request(in,incoming,max_content_length) has already been called + and therefore populated the fields of incoming. + ensures + - if (incoming.body has already been populated with the content of an HTTP + request) then + - this function does nothing + - else + - reads the body of the HTTP request into #incoming.body. + !*/ + + void write_http_response ( + std::ostream& out, + outgoing_things outgoing, + const std::string& result + ); + /*! + ensures + - Writes an HTTP response, defined by the data in outgoing, to the given output + stream. + - The result variable is written out as the content of the response. + !*/ + + void write_http_response ( + std::ostream& out, + const http_parse_error& e + ); + /*! + ensures + - Writes an HTTP error response based on the information in the exception + object e. + !*/ + + void write_http_response ( + std::ostream& out, + const std::exception& e + ); + /*! + ensures + - Writes an HTTP error response based on the information in the exception + object e. + !*/ + +// ----------------------------------------------------------------------------------------- +// ----------------------------------------------------------------------------------------- + + class server_http : public server_iostream + { + + /*! + WHAT THIS EXTENSION DOES FOR server_iostream + This extension turns the server object into a simple HTTP server. It only + handles HTTP GET, PUT and POST requests and each incoming request triggers + the on_request() callback. + + COOKIE STRINGS + The strings returned in the cookies key_value_map should be of the following form: + key: cookie_name + value: cookie contents; expires=Fri, 31-Dec-2010 23:59:59 GMT; path=/; domain=.example.net + + You don't have to supply all the extra cookie arguments. So if you just want to + set a cookie that will expire when the client's browser is closed you can + just say something like incoming.cookies["cookie_name"] = "cookie contents"; + + HTTP HEADERS + The HTTP headers in the incoming.headers and outgoing.headers are the name/value pairs + of HTTP headers. For example, the HTTP header "Content-Type: text/html" would be + encoded such that outgoing.headers["Content-Type"] == "text/html". + + Also note that if you wish to change the content type of your response to the + client you may do so by setting the "Content-Type" header to whatever you like. + However, setting this field manually is not necessary as it will default to + "text/html" if you don't explicitly set it to something. + !*/ + + public: + + server_http ( + ); + /*! + ensures + - #get_max_content_length() == 10*1024*1024 + !*/ + + unsigned long get_max_content_length ( + ) const; + /*! + ensures + - returns the max allowable content length, in bytes, of the post back to + the web server. If a client attempts to send more data than this then an + error number 413 is returned back to the client and the request is not + processed by the web server. + !*/ + + void set_max_content_length ( + unsigned long max_length + ); + /*! + ensures + - #get_max_content_length() == max_length + !*/ + + private: + + virtual const std::string on_request ( + const incoming_things& incoming, + outgoing_things& outgoing + ) = 0; + /*! + requires + - on_request() is called when there is an HTTP GET or POST request to be serviced + - on_request() is run in its own thread + - is_running() == true + - the number of current on_request() functions running < get_max_connection() + - in incoming: + - incoming.path == the path being requested by this request + - incoming.request_type == the type of request, GET or POST + - incoming.content_type == the content type associated with this request + - incoming.protocol == The protocol being used by the web browser (e.g. HTTP/1.1) + - incoming.body == a string that contains the data that was posted back to the + web server by the client (e.g. The string has the length specified by the + Content-Length header). + - incoming.body.size() < get_max_content_length() + - incoming.queries == a map that contains all the key/value pairs in the query + string of this request. The key and value strings of the query string will + have been decoded back into their original form before being sent to this + function (i.e. '+' decoded back to ' ' and "%hh" into its corresponding + ascii value. So the URL-encoding is decoded automatically) + - incoming.cookies == The set of cookies that came from the client along with + this request. The cookies will have been decoded back to normal form + from the URL-encoding. + - incoming.headers == a map that contains all the incoming HTTP headers + from the client web browser. + - incoming.foreign_ip == the foreign ip address for this request + - incoming.foreign_port == the foreign port number for this request + - incoming.local_ip == the IP of the local interface this request is coming in on + - incoming.local_port == the local port number this request is coming in on + - in outgoing: + - outgoing.cookies.size() == 0 + - outgoing.headers.size() == 0 + - outgoing.http_return == 200 + - outgoing.http_return_status == "OK" + ensures + - This function returns the HTML page to be displayed as the response to this request. + - this function will not call clear() + - #outgoing.cookies == a set of new cookies to pass back to the client along + with the result of this request. (Note that URL-encoding is automatically applied + so you don't have to do it yourself) + - #outgoing.headers == a set of additional headers you wish to appear in the + HTTP response to this request. (This may be empty, the minimum needed headers + will be added automatically if you don't set them) + - outgoing.http_return and outgoing.http_return_status may be set to override the + default HTTP return code of 200 OK + throws + - throws only exceptions derived from std::exception. If an exception is thrown + then the error string from the exception is returned to the web browser. + !*/ + + + // ----------------------------------------------------------------------- + // Implementation Notes + // ----------------------------------------------------------------------- + + virtual void on_connect ( + std::istream& in, + std::ostream& out, + const std::string& foreign_ip, + const std::string& local_ip, + unsigned short foreign_port, + unsigned short local_port, + uint64 + ) + /*! + on_connect() is the function defined by server_iostream which is overloaded by + server_http. In particular, the server_http's implementation is shown below. + In it you can see how the server_http parses the incoming http request, gets a + response by calling on_request(), and sends it back using the helper routines + defined at the top of this file. + + Therefore, if you want to modify the behavior of the HTTP server, for example, + to do some more complex data streaming requiring direct access to the + iostreams, then you can do so by defining your own on_connect() routine. In + particular, the default implementation shown below is a good starting point. + !*/ + { + try + { + incoming_things incoming(foreign_ip, local_ip, foreign_port, local_port); + outgoing_things outgoing; + + parse_http_request(in, incoming, get_max_content_length()); + read_body(in, incoming); + const std::string& result = on_request(incoming, outgoing); + write_http_response(out, outgoing, result); + } + catch (http_parse_error& e) + { + write_http_response(out, e); + } + catch (std::exception& e) + { + write_http_response(out, e); + } + } + }; + +} + +#endif // DLIB_SERVER_HTTp_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/server/server_iostream.cpp b/ml/dlib/dlib/server/server_iostream.cpp new file mode 100644 index 000000000..0fd49b67c --- /dev/null +++ b/ml/dlib/dlib/server/server_iostream.cpp @@ -0,0 +1,14 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERVER_IOSTREAM_CPp_ +#define DLIB_SERVER_IOSTREAM_CPp_ + +#include "server_iostream.h" + +namespace dlib +{ + const logger server_iostream::_dLog("dlib.server_iostream"); +} + +#endif // DLIB_SERVER_IOSTREAM_CPp_ + diff --git a/ml/dlib/dlib/server/server_iostream.h b/ml/dlib/dlib/server/server_iostream.h new file mode 100644 index 000000000..eed349015 --- /dev/null +++ b/ml/dlib/dlib/server/server_iostream.h @@ -0,0 +1,155 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERVER_IOSTREAm_1_ +#define DLIB_SERVER_IOSTREAm_1_ + +#include +#include "server_iostream_abstract.h" +#include "../logger.h" +#include "../uintn.h" +#include "server_kernel.h" +#include "../sockstreambuf.h" +#include "../map.h" + + +namespace dlib +{ + + class server_iostream : public server + { + + /*! + INITIAL VALUE + - next_id == 0 + - con_map.size() == 0 + + CONVENTION + - next_id == the id of the next connection + - for all current connections + - con_map[id] == the connection object with the given id + - m == the mutex that protects the members of this object + !*/ + + typedef map::kernel_2a>::kernel_1b id_map; + + public: + server_iostream( + ) : + next_id(0) + {} + + ~server_iostream( + ) + { + server::clear(); + } + + protected: + + void shutdown_connection ( + uint64 id + ) + { + auto_mutex M(m); + if (con_map.is_in_domain(id)) + { + con_map[id]->shutdown(); + } + } + + private: + + virtual void on_connect ( + std::istream& in, + std::ostream& out, + const std::string& foreign_ip, + const std::string& local_ip, + unsigned short foreign_port, + unsigned short local_port, + uint64 connection_id + )=0; + + void on_connect ( + connection& con + ) + { + bool my_fault = true; + uint64 this_con_id=0; + try + { + sockstreambuf buf(&con); + std::istream in(&buf); + std::ostream out(&buf); + in.tie(&out); + + // add this connection to the con_map + { + auto_mutex M(m); + this_con_id = next_id; + connection* this_con = &con; + con_map.add(this_con_id,this_con); + this_con_id = next_id; + ++next_id; + } + + my_fault = false; + on_connect( + in, + out, + con.get_foreign_ip(), + con.get_local_ip(), + con.get_foreign_port(), + con.get_local_port(), + this_con_id + ); + + // remove this connection from the con_map + { + auto_mutex M(m); + connection* this_con; + uint64 junk; + con_map.remove(this_con_id,junk,this_con); + } + + } + catch (std::bad_alloc&) + { + // make sure we remove this connection from the con_map + { + auto_mutex M(m); + if (con_map.is_in_domain(this_con_id)) + { + connection* this_con; + uint64 junk; + con_map.remove(this_con_id,junk,this_con); + } + } + + _dLog << LERROR << "We ran out of memory in server_iostream::on_connect()"; + // if this is an escaped exception from on_connect then let it fly! + // Seriously though, this way it is obvious to the user that something bad happened + // since they probably won't have the dlib logger enabled. + if (!my_fault) + throw; + } + } + + uint64 next_id; + id_map con_map; + const static logger _dLog; + mutex m; + + + }; + + +} + +#ifdef NO_MAKEFILE +#include "server_iostream.cpp" +#endif + +#endif // DLIB_SERVER_IOSTREAm_1_ + + + diff --git a/ml/dlib/dlib/server/server_iostream_abstract.h b/ml/dlib/dlib/server/server_iostream_abstract.h new file mode 100644 index 000000000..86cd460bf --- /dev/null +++ b/ml/dlib/dlib/server/server_iostream_abstract.h @@ -0,0 +1,84 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SERVER_IOSTREAm_ABSTRACT_ +#ifdef DLIB_SERVER_IOSTREAm_ABSTRACT_ + + +#include "server_kernel_abstract.h" +#include +#include +#include "../uintn.h" + +namespace dlib +{ + + class server_iostream : public server + { + + /*! + WHAT THIS EXTENSION DOES FOR server + This extension redefines the on_connect() function so that + instead of giving you a connection object you get an istream + and ostream object. + + THREAD SAFETY + Note that in on_connect() the input stream in is tied to the output stream + out. This means that when you read from in it will modify out and thus + it is not safe to touch in and out concurrently from different threads + unless you untie them (which you do by saying in.tie(0);) + !*/ + + protected: + + void shutdown_connection ( + uint64 id + ); + /*! + ensures + - if (there is a connection currently being serviced with the given id) then + - the specified connection is shutdown. (i.e. connection::shutdown() is + called on it so the iostreams operating on it will return EOF) + !*/ + + private: + + virtual void on_connect ( + std::istream& in, + std::ostream& out, + const std::string& foreign_ip, + const std::string& local_ip, + unsigned short foreign_port, + unsigned short local_port, + uint64 connection_id + )=0; + /*! + requires + - on_connect() is called when there is a new TCP connection that needs + to be serviced. + - in == the input stream that reads data from the new connection + - out == the output stream that writes data to the new connection + - in.tie() == &out (i.e. when you read from in it automatically calls out.flush()) + - foreign_ip == the foreign ip address for this connection + - foreign_port == the foreign port number for this connection + - local_ip == the IP of the local interface this connection is using + - local_port == the local port number for this connection + - on_connect() is run in its own thread + - is_running() == true + - the number of current connections < get_max_connection() + - connection_id == an integer that uniquely identifies this connection. + It can be used by shutdown_connection() to terminate this connection. + ensures + - when the iostreams hit EOF on_connect() will terminate. + (because this is how clear() signals you the server is shutting down) + - this function will not call clear() + throws + - does not throw any exceptions + !*/ + + }; + +} + +#endif // DLIB_SERVER_IOSTREAm_ABSTRACT_ + + diff --git a/ml/dlib/dlib/server/server_kernel.cpp b/ml/dlib/dlib/server/server_kernel.cpp new file mode 100644 index 000000000..9e8130e78 --- /dev/null +++ b/ml/dlib/dlib/server/server_kernel.cpp @@ -0,0 +1,595 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERVER_KERNEL_CPp_ +#define DLIB_SERVER_KERNEL_CPp_ + +#include "server_kernel.h" +#include "../string.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + server:: + server ( + ) : + listening_port(0), + running(false), + shutting_down(false), + running_signaler(running_mutex), + thread_count(0), + thread_count_signaler(thread_count_mutex), + max_connections(1000), + thread_count_zero(thread_count_mutex), + graceful_close_timeout(500) + { + } + +// ---------------------------------------------------------------------------------------- + + server:: + ~server ( + ) + { + clear(); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long server:: + get_graceful_close_timeout ( + ) const + { + auto_mutex lock(max_connections_mutex); + return graceful_close_timeout; + } + +// ---------------------------------------------------------------------------------------- + + void server:: + set_graceful_close_timeout ( + unsigned long timeout + ) + { + auto_mutex lock(max_connections_mutex); + graceful_close_timeout = timeout; + } + +// ---------------------------------------------------------------------------------------- + + + int server:: + get_max_connections ( + ) const + { + max_connections_mutex.lock(); + int temp = max_connections; + max_connections_mutex.unlock(); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + void server:: + set_max_connections ( + int max + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( + max >= 0 , + "\tvoid server::set_max_connections" + << "\n\tmax == " << max + << "\n\tthis: " << this + ); + + max_connections_mutex.lock(); + max_connections = max; + max_connections_mutex.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + void server:: + clear ( + ) + { + // signal that we are shutting down + shutting_down_mutex.lock(); + shutting_down = true; + shutting_down_mutex.unlock(); + + + + max_connections_mutex.lock(); + listening_port_mutex.lock(); + listening_ip_mutex.lock(); + listening_ip = ""; + listening_port = 0; + max_connections = 1000; + graceful_close_timeout = 500; + listening_port_mutex.unlock(); + listening_ip_mutex.unlock(); + max_connections_mutex.unlock(); + + + // tell all the connections to shut down + cons_mutex.lock(); + connection* temp; + while (cons.size() > 0) + { + cons.remove_any(temp); + temp->shutdown(); + } + cons_mutex.unlock(); + + + // wait for all the connections to shut down + thread_count_mutex.lock(); + while (thread_count > 0) + { + thread_count_zero.wait(); + } + thread_count_mutex.unlock(); + + + + + // wait for the listener to close + running_mutex.lock(); + while (running == true) + { + running_signaler.wait(); + } + running_mutex.unlock(); + + + + // signal that the shutdown is complete + shutting_down_mutex.lock(); + shutting_down = false; + shutting_down_mutex.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + void server:: + start_async_helper ( + ) + { + try + { + start_accepting_connections(); + } + catch (std::exception& e) + { + sdlog << LERROR << e.what(); + } + } + +// ---------------------------------------------------------------------------------------- + + void server:: + start_async ( + ) + { + auto_mutex lock(running_mutex); + if (running) + return; + + // Any exceptions likely to be thrown by the server are going to be + // thrown when trying to bind the port. So calling this here rather + // than in the thread we are about to make will cause start_async() + // to report errors back to the user in a very straight forward way. + open_listening_socket(); + + async_start_thread.reset(new thread_function(make_mfp(*this,&server::start_async_helper))); + } + +// ---------------------------------------------------------------------------------------- + + void server:: + open_listening_socket ( + ) + { + if (!sock) + { + int status = create_listener(sock,listening_port,listening_ip); + const int port_used = listening_port; + + // if there was an error then clear this object + if (status < 0) + { + max_connections_mutex.lock(); + listening_port_mutex.lock(); + listening_ip_mutex.lock(); + listening_ip = ""; + listening_port = 0; + max_connections = 1000; + graceful_close_timeout = 500; + listening_port_mutex.unlock(); + listening_ip_mutex.unlock(); + max_connections_mutex.unlock(); + } + + + + // throw an exception for the error + if (status == PORTINUSE) + { + throw dlib::socket_error( + EPORT_IN_USE, + "error occurred in server::start()\nport " + cast_to_string(port_used) + " already in use" + ); + } + else if (status == OTHER_ERROR) + { + throw dlib::socket_error( + "error occurred in server::start()\nunable to create listener" + ); + } + } + + running_mutex.lock(); + running = true; + running_mutex.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + void server:: + start ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( + this->is_running() == false, + "\tvoid server::start" + << "\n\tis_running() == " << this->is_running() + << "\n\tthis: " << this + ); + + start_accepting_connections(); + + } + +// ---------------------------------------------------------------------------------------- + + void server:: + start_accepting_connections ( + ) + { + open_listening_socket(); + + // determine the listening port + bool port_assigned = false; + listening_port_mutex.lock(); + if (listening_port == 0) + { + port_assigned = true; + listening_port = sock->get_listening_port(); + } + listening_port_mutex.unlock(); + if (port_assigned) + on_listening_port_assigned(); + + + + int status = 0; + + connection* client; + bool exit = false; + while ( true ) + { + + + // accept the next connection + status = sock->accept(client,1000); + + + // if there was an error then quit the loop + if (status == OTHER_ERROR) + { + break; + } + + shutting_down_mutex.lock(); + // if we are shutting down then signal that we should quit the loop + exit = shutting_down; + shutting_down_mutex.unlock(); + + + // if we should be shutting down + if (exit) + { + // if a connection was opened then close it + if (status == 0) + delete client; + break; + } + + + + // if the accept timed out + if (status == TIMEOUT) + { + continue; + } + + + + + + // add this new connection to cons + cons_mutex.lock(); + connection* client_temp = client; + try{cons.add(client_temp);} + catch(...) + { + sock.reset(); + delete client; + cons_mutex.unlock(); + + // signal that we are not running start() anymore + running_mutex.lock(); + running = false; + running_signaler.broadcast(); + running_mutex.unlock(); + + + clear(); + throw; + } + cons_mutex.unlock(); + + + // make a param structure + param* temp = 0; + try{ + temp = new param ( + *this, + *client, + get_graceful_close_timeout() + ); + } catch (...) + { + sock.reset(); + delete client; + running_mutex.lock(); + running = false; + running_signaler.broadcast(); + running_mutex.unlock(); + clear(); + throw; + } + + + // if create_new_thread failed + if (!create_new_thread(service_connection,temp)) + { + delete temp; + // close the listening socket + sock.reset(); + + // close the new connection and remove it from cons + cons_mutex.lock(); + connection* ctemp; + if (cons.is_member(client)) + { + cons.remove(client,ctemp); + } + delete client; + cons_mutex.unlock(); + + + // signal that the listener has closed + running_mutex.lock(); + running = false; + running_signaler.broadcast(); + running_mutex.unlock(); + + // make sure the object is cleared + clear(); + + // throw the exception + throw dlib::thread_error( + ECREATE_THREAD, + "error occurred in server::start()\nunable to start thread" + ); + } + // if we made the new thread then update thread_count + else + { + // increment the thread count + thread_count_mutex.lock(); + ++thread_count; + if (thread_count == 0) + thread_count_zero.broadcast(); + thread_count_mutex.unlock(); + } + + + + + // check if we have hit the maximum allowed number of connections + max_connections_mutex.lock(); + // if max_connections is zero or the loop is ending then skip this + if (max_connections != 0) + { + // wait for thread_count to be less than max_connections + thread_count_mutex.lock(); + while (thread_count >= max_connections) + { + max_connections_mutex.unlock(); + thread_count_signaler.wait(); + max_connections_mutex.lock(); + + // if we are shutting down the quit the loop + shutting_down_mutex.lock(); + exit = shutting_down; + shutting_down_mutex.unlock(); + if (exit) + break; + } + thread_count_mutex.unlock(); + } + max_connections_mutex.unlock(); + + if (exit) + { + break; + } + } //while ( true ) + + + // close the socket + sock.reset(); + + // signal that the listener has closed + running_mutex.lock(); + running = false; + running_signaler.broadcast(); + running_mutex.unlock(); + + // if there was an error with accept then throw an exception + if (status == OTHER_ERROR) + { + // make sure the object is cleared + clear(); + + // throw the exception + throw dlib::socket_error( + "error occurred in server::start()\nlistening socket returned error" + ); + } + } + +// ---------------------------------------------------------------------------------------- + + bool server:: + is_running ( + ) const + { + running_mutex.lock(); + bool temp = running; + running_mutex.unlock(); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + const std::string server:: + get_listening_ip ( + ) const + { + listening_ip_mutex.lock(); + std::string ip(listening_ip); + listening_ip_mutex.unlock(); + return ip; + } + +// ---------------------------------------------------------------------------------------- + + int server:: + get_listening_port ( + ) const + { + listening_port_mutex.lock(); + int port = listening_port; + listening_port_mutex.unlock(); + return port; + } + +// ---------------------------------------------------------------------------------------- + + void server:: + set_listening_port ( + int port + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( + ( port >= 0 && + this->is_running() == false ), + "\tvoid server::set_listening_port" + << "\n\tport == " << port + << "\n\tis_running() == " << this->is_running() + << "\n\tthis: " << this + ); + + listening_port_mutex.lock(); + listening_port = port; + listening_port_mutex.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + void server:: + set_listening_ip ( + const std::string& ip + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( + ( ( is_ip_address(ip) || ip == "" ) && + this->is_running() == false ), + "\tvoid server::set_listening_ip" + << "\n\tip == " << ip + << "\n\tis_running() == " << this->is_running() + << "\n\tthis: " << this + ); + + listening_ip_mutex.lock(); + listening_ip = ip; + listening_ip_mutex.unlock(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // static member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + const logger server::sdlog("dlib.server"); + + void server:: + service_connection( + void* item + ) + { + param& p = *static_cast(item); + + + p.the_server.on_connect(p.new_connection); + + + // remove this connection from cons and close it + p.the_server.cons_mutex.lock(); + connection* temp; + if (p.the_server.cons.is_member(&p.new_connection)) + p.the_server.cons.remove(&p.new_connection,temp); + p.the_server.cons_mutex.unlock(); + + try{ close_gracefully(&p.new_connection, p.graceful_close_timeout); } + catch (...) { sdlog << LERROR << "close_gracefully() threw"; } + + // decrement the thread count and signal if it is now zero + p.the_server.thread_count_mutex.lock(); + --p.the_server.thread_count; + p.the_server.thread_count_signaler.broadcast(); + if (p.the_server.thread_count == 0) + p.the_server.thread_count_zero.broadcast(); + p.the_server.thread_count_mutex.unlock(); + + delete &p; + + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SERVER_KERNEL_CPp_ + diff --git a/ml/dlib/dlib/server/server_kernel.h b/ml/dlib/dlib/server/server_kernel.h new file mode 100644 index 000000000..4232ff343 --- /dev/null +++ b/ml/dlib/dlib/server/server_kernel.h @@ -0,0 +1,234 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SERVER_KERNEL_1_ +#define DLIB_SERVER_KERNEL_1_ + +#include "server_kernel_abstract.h" + +#include +#include + +#include "../threads.h" +#include "../sockets.h" +#include "../algs.h" +#include "../set.h" +#include "../logger.h" + + +namespace dlib +{ + + // These forward declarations are here so we can use them in the typedefs in the server + // class. The reason for this is for backwards compatibility with previous versions of + // dlib. + class server_http; + class server_iostream; + + class server + { + + + /*! + INITIAL VALUE + listening_port == 0 + listening_ip == "" + running == false + shutting_down == false + cons.size() == 0 + listening_port_mutex == a mutex + listening_ip_mutex == a mutex + running_mutex == a mutex + running_signaler == a signaler associated with running_mutex + shutting_down_mutex == a mutex + cons_mutex == a mutex + thread_count == 0 + thread_count_mutex == a mutex + thread_count_signaler == a signaler associated with thread_count_mutex + thread_count_zero == a signaler associated with thread_count_mutex + max_connections == 1000 + max_connections_mutex == a mutex for max_connections and graceful_close_timeout + graceful_close_timeout == 500 + + CONVENTION + listening_port == get_listening_port() + listening_ip == get_listening_ip() + running == is_running() + shutting_down == true while clear() is running. this + bool is used to tell the thread blocked on + accept that it should terminate + cons == a set containing all open connections + listening_port_mutex == a mutex for listening_port + listening_ip_mutex == a mutex for listening_ip + running_mutex == a mutex for running + running_signaler == a signaler for running and + is associated with running_mutex. it is + used to signal when running is false + shutting_down_mutex == a mutex for shutting_down + cons_mutex == a mutex for cons + thread_count == the number of threads currently running + thread_count_mutex == a mutex for thread_count + thread_count_signaler == a signaler for thread_count and + is associated with thread_count_mutex. it + is used to signal when thread_count is + decremented + thread_count_zero == a signaler for thread_count and + is associated with thread_count_mutex. it + is used to signal when thread_count becomes + zero + max_connections == get_max_connections() + max_connections_mutex == a mutex for max_connections + !*/ + + + typedef set::kernel_1a set_of_connections; + + // this structure is used to pass parameters to new threads + struct param + { + param ( + server& server_, + connection& new_connection_, + unsigned long graceful_close_timeout_ + ) : + the_server(server_), + new_connection(new_connection_), + graceful_close_timeout(graceful_close_timeout_) + {} + + server& the_server; + connection& new_connection; + unsigned long graceful_close_timeout; + }; + + + + public: + + // These typedefs are here for backward compatibility with previous versions of dlib + typedef server kernel_1a; + typedef server kernel_1a_c; + typedef server_iostream iostream_1a; + typedef server_iostream iostream_1a_c; + typedef server_http http_1a; + typedef server_http http_1a_c; + + server( + ); + + virtual ~server( + ); + + void clear( + ); + + void start ( + ); + + bool is_running ( + ) const; + + const std::string get_listening_ip ( + ) const; + + int get_listening_port ( + ) const; + + void set_listening_port ( + int port + ); + + void set_listening_ip ( + const std::string& ip + ); + + void set_max_connections ( + int max + ); + + int get_max_connections ( + ) const; + + void start_async ( + ); + + void set_graceful_close_timeout ( + unsigned long timeout + ); + + unsigned long get_graceful_close_timeout ( + ) const; + + private: + + void start_async_helper ( + ); + + void start_accepting_connections ( + ); + + void open_listening_socket ( + ); + + virtual void on_connect ( + connection& new_connection + )=0; + + virtual void on_listening_port_assigned ( + ) {} + + const static logger sdlog; + + static void service_connection( + void* item + ); + /*! + requires + item is a pointer to a param struct + ensures + services the new connection + will take care of closing the connection and + adding the connection to cons when it first starts and + remove the connection from cons and signal that it has + done so when it ends + !*/ + + // data members + int listening_port; + std::string listening_ip; + bool running; + bool shutting_down; + set_of_connections cons; + mutex listening_port_mutex; + mutex listening_ip_mutex; + rmutex running_mutex; + rsignaler running_signaler; + mutex shutting_down_mutex; + mutex cons_mutex; + int thread_count; + mutex thread_count_mutex; + signaler thread_count_signaler; + int max_connections; + mutex max_connections_mutex; + signaler thread_count_zero; + std::unique_ptr async_start_thread; + std::unique_ptr sock; + unsigned long graceful_close_timeout; + + + // restricted functions + server(server&); + server& operator= ( + server& + ); + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "server_kernel.cpp" +#endif + +#endif // DLIB_SERVER_KERNEL_1_ + diff --git a/ml/dlib/dlib/server/server_kernel_abstract.h b/ml/dlib/dlib/server/server_kernel_abstract.h new file mode 100644 index 000000000..f7860d26c --- /dev/null +++ b/ml/dlib/dlib/server/server_kernel_abstract.h @@ -0,0 +1,310 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SERVER_KERNEL_ABSTRACT_ +#ifdef DLIB_SERVER_KERNEL_ABSTRACT_ + +#include "../threads/threads_kernel_abstract.h" +#include "../sockets/sockets_kernel_abstract.h" +#include + + +namespace dlib +{ + class server + { + + /*! + INITIAL VALUE + get_listening_ip() == "" + get_listening_port() == 0 + is_running() == false + get_max_connections() == 1000 + get_graceful_close_timeout() == 500 + + + CALLBACK FUNCTIONS + on_connect(): + To use this object inherit from it and define the pure virtual function + on_connect. Inside this function is where you will handle each new + connection. Note that the connection object passed to on_connect() should + NOT be closed, just let the function end and it will be gracefully closed + for you. Also note that each call to on_connect() is run in its own + thread. Also note that on_connect() should NOT throw any exceptions, + all exceptions must be dealt with inside on_connect() and cannot be + allowed to leave. + + on_listening_port_assigned(): + This function is called to let the client know that the operating + system has assigned a port number to the listening port. This + happens if a port number of zero was given. Note that this + function does not need to be defined. If you don't care then + don't define it and it will do nothing. Note also that this function + is NOT called in its own thread. Thus, making it block might hang the + server. + + WHAT THIS OBJECT REPRESENTS + This object represents a server that listens on a port and spawns new + threads to handle each new connection. + + Note that the clear() function does not return until all calls to + on_connect() have finished and the start() function has been shutdown. + Also note that when clear() is called all open connection objects + will be shutdown(). + + A note about get_max_connections(): when the maximum number of connections + has been reached accept() will simply not be called until the number of + open connections drops below get_max_connections(). This means connections + will just wait to be serviced, rather than being outright refused. + + THREAD SAFETY + All member functions are thread-safe. + !*/ + + public: + + server( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~server( + ); + /*! + requires + - is not called from any of server's callbacks + ensures + - all resources associated with *this have been released + !*/ + + void clear( + ); + /*! + requires + - is not called from any of server's callbacks + ensures + - #*this has its initial value + - all open connection objects passed to on_connect() are shutdown() + - blocks until all calls to on_connect() have finished + - blocks until the start() function has released all its resources + throws + - std::bad_alloc + if this exception is thrown then the server object is unusable + until clear() is called and succeeds + !*/ + + void start ( + ); + /*! + requires + - is_running() == false + ensures + - starts listening on the port and ip specified by get_listening_ip() + and #get_listening_port() for new connections. + - if (get_listening_port() == 0) then + - a port to listen on will be automatically selected + - #get_listening_port() == the selected port being used + - if (get_listening_ip() == "" ) then + - all local IPs will be listened on + - blocks until clear() is called or an error occurs + throws + - dlib::socket_error + start() will throw this exception if there is some problem binding + ports and/or starting the server or if there is a problem + accepting new connections while it's running. + If this happens then + - All open connection objects passed to on_connect() are shutdown() + and the exception will not be thrown until all on_connect() calls + have terminated. + - The server will be cleared and returned to its initial value. + - dlib::thread_error + start() will throw this exception if there is a problem + creating new threads. Or it may throw this exception if there + is a problem creating threading objects. + If this happens then + - All open connection objects passed to on_connect() are shutdown() + and the exception will not be thrown until all on_connect() calls + have terminated. + - The server will be cleared and returned to its initial value. + - std::bad_alloc + start() may throw this exception and if it does then the object + will be unusable until clear() is called and succeeds + !*/ + + void start_async ( + ); + /*! + ensures + - starts listening on the port and ip specified by get_listening_ip() + and #get_listening_port() for new connections. + - if (get_listening_port() == 0) then + - a port to listen on will be automatically selected + - #get_listening_port() == the selected port being used + - if (get_listening_ip() == "" ) then + - all local IPs will be listened on + - does NOT block. That is, this function will return right away and + the server will run on a background thread until clear() or this + object's destructor is called (or until some kind of fatal error + occurs). + - if an error occurs in the background thread while the server is + running then it will shut itself down, set is_running() to false, and + log the error to a dlib::logger object. + - calling start_async() on a running server has no effect. + throws + - dlib::socket_error + start_async() will throw this exception if there is some problem binding + ports and/or starting the server. + If this happens then + - The server will be cleared and returned to its initial value. + !*/ + + bool is_running ( + ) const; + /*! + ensures + - returns true if start() is running + - returns false if start() is not running or has released all + its resources and is about to terminate + throws + - std::bad_alloc + !*/ + + int get_max_connections ( + ) const; + /*! + ensures + - returns the maximum number of connections the server will accept + at a time. + - returns 0 if the server will accept any number of connections + throws + - std::bad_alloc + !*/ + + + const std::string get_listening_ip ( + ) const; + /*! + ensures + - returns the local ip to listen for new connections on + - returns "" if ALL local ips are to be listened on + throws + - std::bad_alloc + !*/ + + int get_listening_port ( + ) const; + /*! + ensures + - returns the local port number to listen for new connections on + - returns 0 if the local port number has not yet been set + throws + - std::bad_alloc + !*/ + + void set_listening_port ( + int port + ); + /*! + requires + - port >= 0 + - is_running() == false + ensures + - #get_listening_port() == port + throws + - std::bad_alloc + !*/ + + void set_listening_ip ( + const std::string& ip + ); + /*! + requires + - is_ip_address(ip) == true or ip == "" + - is_running() == false + ensures + - #get_listening_ip() == ip + throws + - std::bad_alloc + !*/ + + void set_max_connections ( + int max + ); + /*! + requires + - max >= 0 + ensures + - #get_max_connections() == max + throws + - std::bad_alloc + !*/ + + void set_graceful_close_timeout ( + unsigned long timeout + ); + /*! + ensures + - #get_graceful_close_timeout() == timeout + !*/ + + unsigned long get_graceful_close_timeout ( + ) const; + /*! + ensures + - When on_connect() terminates, it will close the connection using + close_gracefully(). This is done so that any data still in the + operating system's output buffers gets a chance to be properly + transmitted to the remote host. Part of this involves waiting for + the remote host to close their end of the connection. Therefore, + get_graceful_close_timeout() returns the timeout, in milliseconds, + that we wait for the remote host to close their end of the + connection. This is the timeout value given to close_gracefully(). + !*/ + + private: + + virtual void on_connect ( + connection& new_connection + )=0; + /*! + requires + - on_connect() is run in its own thread + - is_running() == true + - the number of current connections < get_max_connection() + - new_connection == the new connection to the server which is + to be serviced by this call to on_connect() + ensures + - when new_connection is shutdown() on_connect() will terminate + - this function will not call clear() + throws + - does not throw any exceptions + !*/ + + // do nothing by default + virtual void on_listening_port_assigned ( + ) {} + /*! + requires + - is called if a listening port of zero was specified and + an actual port number has just been assigned to the server + ensures + - this function will not block + - this function will not call clear() + throws + - does not throw any exceptions + !*/ + + + // restricted functions + server(server&); // copy constructor + server& operator=(server&); // assignment operator + }; + +} + +#endif // DLIB_SERVER_KERNEL_ABSTRACT_ + diff --git a/ml/dlib/dlib/set.h b/ml/dlib/dlib/set.h new file mode 100644 index 000000000..bec6932a2 --- /dev/null +++ b/ml/dlib/dlib/set.h @@ -0,0 +1,74 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEt_ +#define DLIB_SEt_ + +#include "set/set_kernel_1.h" +#include "set/set_kernel_c.h" + + + +#include "binary_search_tree.h" + +#include "set/set_compare_1.h" + +#include "algs.h" +#include + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class set + { + set() {} + + + + + + typedef typename binary_search_tree::kernel_1a + binary_search_tree_1; + + typedef typename binary_search_tree::kernel_2a + binary_search_tree_2; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef set_kernel_1 + kernel_1a; + typedef set_kernel_c + kernel_1a_c; + + // kernel_1b + typedef set_kernel_1 + kernel_1b; + typedef set_kernel_c + kernel_1b_c; + + + //---------- extensions ------------ + + // compare extensions + typedef set_compare_1 + compare_1a; + typedef set_compare_1 + compare_1a_c; + + typedef set_compare_1 + compare_1b; + typedef set_compare_1 + compare_1b_c; + + }; +} + +#endif // DLIB_SEt_ + diff --git a/ml/dlib/dlib/set/set_compare_1.h b/ml/dlib/dlib/set/set_compare_1.h new file mode 100644 index 000000000..d4b5e76ca --- /dev/null +++ b/ml/dlib/dlib/set/set_compare_1.h @@ -0,0 +1,122 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SET_COMPARe_1_ +#define DLIB_SET_COMPARe_1_ + +#include "set_compare_abstract.h" + +#include "../algs.h" + + + +namespace dlib +{ + + template < + typename set_base + > + class set_compare_1 : public set_base + { + + public: + + bool operator< ( + const set_compare_1& rhs + ) const; + + bool operator== ( + const set_compare_1& rhs + ) const; + + }; + + + template < + typename set_base + > + inline void swap ( + set_compare_1& a, + set_compare_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + bool set_compare_1:: + operator< ( + const set_compare_1& rhs + ) const + { + bool result = false; + if (set_base::size() < rhs.size()) + result = true; + + if (set_base::size() == rhs.size()) + { + rhs.reset(); + set_base::reset(); + while (rhs.move_next()) + { + set_base::move_next(); + if (set_base::element() < rhs.element()) + { + result = true; + break; + } + else if (rhs.element() < set_base::element()) + { + break; + } + } + } + + set_base::reset(); + rhs.reset(); + + return result; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + bool set_compare_1:: + operator== ( + const set_compare_1& rhs + ) const + { + bool result = true; + if (set_base::size() != rhs.size()) + result = false; + + + rhs.reset(); + set_base::reset(); + while (rhs.move_next() && set_base::move_next()) + { + if (!(rhs.element() == set_base::element())) + { + result = false; + break; + } + } + + set_base::reset(); + rhs.reset(); + + return result; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SET_COMPARe_1_ + diff --git a/ml/dlib/dlib/set/set_compare_abstract.h b/ml/dlib/dlib/set/set_compare_abstract.h new file mode 100644 index 000000000..ff06218f8 --- /dev/null +++ b/ml/dlib/dlib/set/set_compare_abstract.h @@ -0,0 +1,96 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SET_COMPARe_ABSTRACT_ +#ifdef DLIB_SET_COMPARe_ABSTRACT_ + +#include "set_kernel_abstract.h" + +#include "../algs.h" + + +namespace dlib +{ + + template < + typename set_base + > + class set_compare : public set_base + { + + /*! + REQUIREMENTS ON set_base + must be an implementation of set/set_kernel_abstract.h + + POINTERS AND REFERENCES TO INTERNAL DATA + operator== and operator< invalidate pointers or references to + data members. + + WHAT THIS EXTENSION DOES FOR set + This gives a set the ability to compare itself to other + sets using the < and == operators. + + The < operator is conceptually weird for sets. It is useful + though because it allows you to make sets of sets since + sets require that their containing type implement operator<. + + Also note that it is the case that for any two sets a and b + if (a rhs.size()) then + - returns false + - else + - returns true if there exists an integer j such that 0 <= j < size() + and for all integers i such that 0 <= i < j where it is true that + (*this)[i] == rhs[i] and (*this)[j] < rhs[j] + - returns false if there is no j that will satisfy the above conditions. + !*/ + + bool operator== ( + const set_compare& rhs + ) const; + /*! + ensures + - #at_start() == true + - returns true if *this and rhs contain the same elements. + returns false otherwise. + !*/ + }; + + + template < + typename set_base + > + inline void swap ( + set_compare& a, + set_compare& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_SET_COMPARe_ABSTRACT_ + diff --git a/ml/dlib/dlib/set/set_kernel_1.h b/ml/dlib/dlib/set/set_kernel_1.h new file mode 100644 index 000000000..9df96e671 --- /dev/null +++ b/ml/dlib/dlib/set/set_kernel_1.h @@ -0,0 +1,372 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SET_KERNEl_1_ +#define DLIB_SET_KERNEl_1_ + +#include "set_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename T, + typename bst_base, + typename mem_manager = default_memory_manager + > + class set_kernel_1 : public enumerable, + public asc_remover + { + + /*! + REQUIREMENTS ON bst_base + bst_base is instantiated with and + implements binray_search_tree/binary_search_tree_kernel_abstract.h + + INITIAL VALUE + bst has its initial value + + CONVENTION + bst.size() == the number of elements in the set and + the elements in the set are stored in bst + !*/ + + + public: + + typedef T type; + typedef typename bst_base::compare_type compare_type; + typedef mem_manager mem_manager_type; + + set_kernel_1( + ) + { + } + + virtual ~set_kernel_1( + ) + {} + + inline void clear( + ); + + inline void add ( + T& item + ); + + inline bool is_member ( + const T& item + ) const; + + inline void remove ( + const T& item, + T& item_copy + ); + + inline void destroy ( + const T& item + ); + + inline void swap ( + set_kernel_1& item + ); + + // functions from the remover interface + inline void remove_any ( + T& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + inline bool current_element_valid ( + ) const; + + inline const T& element ( + ) const; + + + inline const T& element ( + ); + + inline bool move_next ( + ) const; + + + private: + + bst_base bst; + char junk; + + // restricted functions + set_kernel_1(set_kernel_1&); + set_kernel_1& operator=(set_kernel_1&); + + }; + + template < + typename T, + typename bst_base, + typename mem_manager + > + inline void swap ( + set_kernel_1& a, + set_kernel_1& b + ) { a.swap(b); } + + template < + typename T, + typename bst_base, + typename mem_manager + > + void deserialize ( + set_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + T temp; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(temp,in); + item.add(temp); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type set_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + void set_kernel_1:: + clear ( + ) + { + bst.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + void set_kernel_1:: + add ( + T& item + ) + { + bst.add(item,junk); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + bool set_kernel_1:: + is_member( + const T& item + ) const + { + return (bst[item] != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + void set_kernel_1:: + remove_any ( + T& item + ) + { + bst.remove_any(item,junk); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + void set_kernel_1:: + remove( + const T& item, + T& item_copy + ) + { + bst.remove(item,item_copy,junk); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + void set_kernel_1:: + destroy( + const T& item + ) + { + bst.destroy(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + size_t set_kernel_1:: + size ( + ) const + { + return bst.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + void set_kernel_1:: + swap ( + set_kernel_1& item + ) + { + bst.swap(item.bst); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + bool set_kernel_1:: + at_start ( + ) const + { + return bst.at_start(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + void set_kernel_1:: + reset ( + ) const + { + bst.reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + bool set_kernel_1:: + current_element_valid ( + ) const + { + return bst.current_element_valid(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + const T& set_kernel_1:: + element ( + ) const + { + return bst.element().key(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + const T& set_kernel_1:: + element ( + ) + { + return bst.element().key(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename bst_base, + typename mem_manager + > + bool set_kernel_1:: + move_next ( + ) const + { + return bst.move_next(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SET_KERNEl_1_ + diff --git a/ml/dlib/dlib/set/set_kernel_abstract.h b/ml/dlib/dlib/set/set_kernel_abstract.h new file mode 100644 index 000000000..72a0a98b3 --- /dev/null +++ b/ml/dlib/dlib/set/set_kernel_abstract.h @@ -0,0 +1,192 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SET_KERNEl_ABSTRACT_ +#ifdef DLIB_SET_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" +#include + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class set : public enumerable, + public asc_remover + { + + /*! + REQUIREMENTS ON T + T must be comparable by compare where compare is a functor compatible with std::less and + T must be swappable by a global swap() and + T must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + swap() and is_member() functions do not invalidate pointers + or references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements in the set in + ascending order according to the compare functor. + (i.e. the elements are enumerated in sorted order) + + WHAT THIS OBJECT REPRESENTS + set contains items of type T + + This object represents an unaddressed collection of items. + Every element in a set is unique. + + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + public: + + typedef T type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + set( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + virtual ~set( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void add ( + T& item + ); + /*! + requires + - is_member(item) == false + ensures + - #is_member(item) == true + - #item has an initial value for its type + - #size() == size() + 1 + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + if add() throws then it has no effect + !*/ + + bool is_member ( + const T& item + ) const; + /*! + ensures + - returns whether or not there is an element in *this equivalent to + item + !*/ + + void remove ( + const T& item, + T& item_copy + ); + /*! + requires + - is_member(item) == true + - &item != &item_copy (i.e. item and item_copy cannot be the same + variable) + ensures + - #is_member(item) == false + - the element in *this equivalent to item has been removed and + swapped into #item_copy + - #size() == size() - 1 + - #at_start() == true + !*/ + + void destroy ( + const T& item + ); + /*! + requires + - is_member(item) == true + ensures + - #is_member(item) == false + - #size() == size() - 1 + - #at_start() == true + !*/ + + void swap ( + set& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + set(set&); // copy constructor + set& operator=(set&); // assignment operator + + }; + + template < + typename T, + typename mem_manager, + typename compare + > + inline void swap ( + set& a, + set& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename mem_manager, + typename compare + > + void deserialize ( + set& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_SET_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/set/set_kernel_c.h b/ml/dlib/dlib/set/set_kernel_c.h new file mode 100644 index 000000000..679f9fae4 --- /dev/null +++ b/ml/dlib/dlib/set/set_kernel_c.h @@ -0,0 +1,194 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SET_KERNEl_C_ +#define DLIB_SET_KERNEl_C_ + +#include "set_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename set_base + > + class set_kernel_c : public set_base + { + typedef typename set_base::type T; + public: + + void add ( + T& item + ); + + void remove_any ( + T& item + ); + + void remove ( + const T& item, + T& item_copy + ); + + void destroy ( + const T& item + ); + + const T& element ( + ); + + const T& element ( + ) const; + }; + + + template < + typename set_base + > + inline void swap ( + set_kernel_c& a, + set_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + void set_kernel_c:: + add( + T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !this->is_member(item), + "\tvoid set::add" + << "\n\titem being added must not already be in the set" + << "\n\tthis: " << this + ); + + // call the real function + set_base::add(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + void set_kernel_c:: + remove ( + const T& item, + T& item_copy + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->is_member(item) && + (static_cast(&item) != static_cast(&item_copy)), + "\tvoid set::remove" + << "\n\titem should be in the set if it's going to be removed" + << "\n\tthis: " << this + << "\n\t&item: " << &item + << "\n\t&item_copy: " << &item_copy + << "\n\tis_member(item): " << (this->is_member(item)?"true":"false") + ); + + // call the real function + set_base::remove(item,item_copy); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + void set_kernel_c:: + destroy ( + const T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->is_member(item), + "\tvoid set::destroy" + << "\n\titem should be in the set if it's going to be removed" + << "\n\tthis: " << this + << "\n\t&item: " << &item + ); + + // call the real function + set_base::destroy(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + void set_kernel_c:: + remove_any ( + T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->size() != 0, + "\tvoid set::remove_any" + << "\n\tsize must be greater than zero if an item is to be removed" + << "\n\tthis: " << this + ); + + // call the real function + set_base::remove_any(item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + const typename set_base::type& set_kernel_c:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst T& set::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return set_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + const typename set_base::type& set_kernel_c:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst T& set::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return set_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + +} + +#endif // DLIB_SET_KERNEl_C_ + diff --git a/ml/dlib/dlib/set_utils.h b/ml/dlib/dlib/set_utils.h new file mode 100644 index 000000000..a4ee5d3da --- /dev/null +++ b/ml/dlib/dlib/set_utils.h @@ -0,0 +1,11 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SET_UTILs_H_ +#define DLIB_SET_UTILs_H_ + +#include "set_utils/set_utils.h" + +#endif // DLIB_SET_UTILs_H_ + + + diff --git a/ml/dlib/dlib/set_utils/set_utils.h b/ml/dlib/dlib/set_utils/set_utils.h new file mode 100644 index 000000000..c47b5ccd7 --- /dev/null +++ b/ml/dlib/dlib/set_utils/set_utils.h @@ -0,0 +1,246 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SET_UTILs_ +#define DLIB_SET_UTILs_ + +#include "../algs.h" +#include "set_utils_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + unsigned long set_intersection_size ( + const T& a, + const U& b + ) + { + if (is_same_object(a,b)) + return a.size(); + + unsigned long num = 0; + + if (a.size() < b.size()) + { + a.reset(); + while (a.move_next()) + { + if (b.is_member(a.element())) + ++num; + } + } + else + { + b.reset(); + while (b.move_next()) + { + if (a.is_member(b.element())) + ++num; + } + } + + return num; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V + > + void set_union ( + const T& a, + const U& b, + V& u + ) + { + typedef typename T::type type; + if (is_same_object(a,u) || is_same_object(b,u)) + { + V local_u; + type temp; + a.reset(); + while (a.move_next()) + { + temp = a.element(); + local_u.add(temp); + } + + b.reset(); + while (b.move_next()) + { + if (a.is_member(b.element()) == false) + { + temp = b.element(); + local_u.add(temp); + } + } + + local_u.swap(u); + } + else + { + u.clear(); + + type temp; + a.reset(); + while (a.move_next()) + { + temp = a.element(); + u.add(temp); + } + + b.reset(); + while (b.move_next()) + { + if (a.is_member(b.element()) == false) + { + temp = b.element(); + u.add(temp); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V + > + void set_intersection ( + const T& a, + const U& b, + V& i + ) + { + typedef typename T::type type; + if (is_same_object(a,i) || is_same_object(b,i)) + { + V local_i; + + type temp; + + if (a.size() < b.size()) + { + a.reset(); + while (a.move_next()) + { + if (b.is_member(a.element())) + { + temp = a.element(); + local_i.add(temp); + } + } + } + else + { + b.reset(); + while (b.move_next()) + { + if (a.is_member(b.element())) + { + temp = b.element(); + local_i.add(temp); + } + } + } + + local_i.swap(i); + } + else + { + i.clear(); + type temp; + + if (a.size() < b.size()) + { + a.reset(); + while (a.move_next()) + { + if (b.is_member(a.element())) + { + temp = a.element(); + i.add(temp); + } + } + } + else + { + b.reset(); + while (b.move_next()) + { + if (a.is_member(b.element())) + { + temp = b.element(); + i.add(temp); + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V + > + void set_difference ( + const T& a, + const U& b, + V& d + ) + { + typedef typename T::type type; + if (is_same_object(a,d) || is_same_object(b,d)) + { + V local_d; + + type temp; + + a.reset(); + while (a.move_next()) + { + if (b.is_member(a.element()) == false) + { + temp = a.element(); + local_d.add(temp); + } + } + + local_d.swap(d); + } + else + { + d.clear(); + type temp; + + a.reset(); + while (a.move_next()) + { + if (b.is_member(a.element()) == false) + { + temp = a.element(); + d.add(temp); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SET_UTILs_ + + + diff --git a/ml/dlib/dlib/set_utils/set_utils_abstract.h b/ml/dlib/dlib/set_utils/set_utils_abstract.h new file mode 100644 index 000000000..e1eba6afc --- /dev/null +++ b/ml/dlib/dlib/set_utils/set_utils_abstract.h @@ -0,0 +1,98 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SET_UTILs_ABSTRACT_ +#ifdef DLIB_SET_UTILs_ABSTRACT_ + +#include "../set.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + unsigned long set_intersection_size ( + const T& a, + const U& b + ); + /*! + requires + - T and U must both be implementations of set/set_kernel_abstract.h + ensures + - returns the number of elements that are in both set a and b + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V + > + void set_union ( + const T& a, + const U& b, + V& u + ); + /*! + requires + - T, U, and V must all be implementations of set/set_kernel_abstract.h + - the types of objects contained in these sets must be copyable + ensures + - #u == the union of a and b. That is, u contains all elements + of a and all the elements of b. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V + > + void set_intersection ( + const T& a, + const U& b, + V& i + ); + /*! + requires + - T, U, and V must all be implementations of set/set_kernel_abstract.h + - the types of objects contained in these sets must be copyable + ensures + - #i == the intersection of a and b. That is, i contains all elements + of a that are also in b. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V + > + void set_difference ( + const T& a, + const U& b, + V& d + ); + /*! + requires + - T, U, and V must all be implementations of set/set_kernel_abstract.h + - the types of objects contained in these sets must be copyable + ensures + - #d == the difference of a and b. That is, d contains all elements + of a that are NOT in b. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SET_UTILs_ABSTRACT_ + + diff --git a/ml/dlib/dlib/simd.h b/ml/dlib/dlib/simd.h new file mode 100644 index 000000000..93df6702d --- /dev/null +++ b/ml/dlib/dlib/simd.h @@ -0,0 +1,12 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SIMd_Hh_ +#define DLIB_SIMd_Hh_ + +#include "simd/simd4f.h" +#include "simd/simd4i.h" +#include "simd/simd8f.h" +#include "simd/simd8i.h" + +#endif // DLIB_SIMd_Hh_ + diff --git a/ml/dlib/dlib/simd/simd4f.h b/ml/dlib/dlib/simd/simd4f.h new file mode 100644 index 000000000..2bfadd23f --- /dev/null +++ b/ml/dlib/dlib/simd/simd4f.h @@ -0,0 +1,685 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_sIMD4F_Hh_ +#define DLIB_sIMD4F_Hh_ + +#include "simd_check.h" +#include "simd4i.h" +#include +#include + +namespace dlib +{ + +#ifdef DLIB_HAVE_SSE2 + class simd4f + { + public: + typedef float type; + + inline simd4f() {} + inline simd4f(float f) { x = _mm_set1_ps(f); } + inline simd4f(float r0, float r1, float r2, float r3) { x = _mm_setr_ps(r0,r1,r2,r3); } + inline simd4f(const __m128& val):x(val) {} + inline simd4f(const simd4i& val):x(_mm_cvtepi32_ps(val)) {} + + inline simd4f& operator=(const simd4i& val) + { + x = simd4f(val); + return *this; + } + + inline simd4f& operator=(const float& val) + { + x = simd4f(val); + return *this; + } + + inline simd4f& operator=(const __m128& val) + { + x = val; + return *this; + } + + inline operator __m128() const { return x; } + + // truncate to 32bit integers + inline operator __m128i() const { return _mm_cvttps_epi32(x); } + + inline void load_aligned(const type* ptr) { x = _mm_load_ps(ptr); } + inline void store_aligned(type* ptr) const { _mm_store_ps(ptr, x); } + inline void load(const type* ptr) { x = _mm_loadu_ps(ptr); } + inline void store(type* ptr) const { _mm_storeu_ps(ptr, x); } + + inline unsigned int size() const { return 4; } + inline float operator[](unsigned int idx) const + { + float temp[4]; + store(temp); + return temp[idx]; + } + + private: + __m128 x; + }; + + class simd4f_bool + { + public: + typedef float type; + + inline simd4f_bool() {} + inline simd4f_bool(const __m128& val):x(val) {} + + inline simd4f_bool& operator=(const __m128& val) + { + x = val; + return *this; + } + + inline operator __m128() const { return x; } + + + private: + __m128 x; + }; + +#elif defined(DLIB_HAVE_VSX) + + class simd4f + { + typedef union { + vector float v; + float x[4]; + } v4f; + + v4f x; + + public: + inline simd4f() : x{0,0,0,0} {} + inline simd4f(const simd4f& v) : x(v.x) { } + inline simd4f(const vector float& v) : x{v} { } + + inline simd4f(const simd4i& v) { + x.x[0]=v[0]; x.x[1]=v[1]; x.x[2]=v[2]; x.x[3]=v[3]; + } + + + inline simd4f(float f) : x{f,f,f,f} { } + inline simd4f(float r0, float r1, float r2, float r3) + : x{r0,r1,r2,r3} { } + + inline simd4f& operator=(const simd4f& v) { x = v.x; return *this; } + inline simd4f& operator=(const float& v) { *this = simd4f(v); return *this; } + + inline vector float operator() () const { return x.v; } + inline float operator[](unsigned int idx) const { return x.x[idx]; } + + inline void load_aligned(const float* ptr) { x.v = vec_ld(0, ptr); } + inline void store_aligned(float* ptr) const { vec_st(x.v, 0, ptr); } + inline void load(const float* ptr) { x.v = vec_vsx_ld(0, ptr); } + inline void store(float* ptr) const { vec_vsx_st(x.v, 0, ptr); } + + + // truncate to 32bit integers + inline operator simd4i::rawarray() const + { + simd4i::rawarray temp; + temp.v.x[0] = x.x[0]; + temp.v.x[1] = x.x[1]; + temp.v.x[2] = x.x[2]; + temp.v.x[3] = x.x[3]; + return temp; + } + }; + + typedef simd4i simd4f_bool; + +#elif defined(DLIB_HAVE_NEON) + + class simd4f + { + public: + typedef float type; + + inline simd4f() {} + inline simd4f(float f) { x = vdupq_n_f32(f); } + inline simd4f(float r0, float r1, float r2, float r3) + { + float __attribute__ ((aligned (16))) data[4] = { r0, r1, r2, r3 }; + x = vld1q_f32(data); + } + inline simd4f(const float32x4_t& val):x(val) {} + inline simd4f(const simd4i& val):x(vcvtq_f32_s32(val)) {} + + inline simd4f& operator=(const simd4i& val) + { + x = simd4f(val); + return *this; + } + + inline simd4f& operator=(const float& val) + { + x = simd4f(val); + return *this; + } + + inline simd4f& operator=(const float32x4_t& val) + { + x = val; + return *this; + } + + inline operator float32x4_t() const { return x; } + + // truncate to 32bit integers + inline operator int32x4_t() const { return vcvtq_s32_f32(x); } + + inline void load_aligned(const type* ptr) { x = vld1q_f32(ptr); } + inline void store_aligned(type* ptr) const { vst1q_f32(ptr, x); } + inline void load(const type* ptr) { x = vld1q_f32(ptr); } + inline void store(type* ptr) const { vst1q_f32(ptr, x); } + + inline unsigned int size() const { return 4; } + inline float operator[](unsigned int idx) const + { + float temp[4]; + store(temp); + return temp[idx]; + } + + private: + float32x4_t x; + }; + + + typedef simd4i simd4f_bool; +#else + class simd4f + { + public: + typedef float type; + + inline simd4f() {} + inline simd4f(float f) { x[0]=f; x[1]=f; x[2]=f; x[3]=f; } + inline simd4f(float r0, float r1, float r2, float r3) { x[0]=r0; x[1]=r1; x[2]=r2; x[3]=r3;} + inline simd4f(const simd4i& val) { x[0]=val[0]; x[1]=val[1]; x[2]=val[2]; x[3]=val[3];} + + // truncate to 32bit integers + inline operator simd4i::rawarray() const + { + simd4i::rawarray temp; + temp.a[0] = (int32)x[0]; + temp.a[1] = (int32)x[1]; + temp.a[2] = (int32)x[2]; + temp.a[3] = (int32)x[3]; + return temp; + } + + inline simd4f& operator=(const float& val) + { + *this = simd4f(val); + return *this; + } + + inline simd4f& operator=(const simd4i& val) + { + x[0] = val[0]; + x[1] = val[1]; + x[2] = val[2]; + x[3] = val[3]; + return *this; + } + + + inline void load_aligned(const type* ptr) + { + x[0] = ptr[0]; + x[1] = ptr[1]; + x[2] = ptr[2]; + x[3] = ptr[3]; + } + + inline void store_aligned(type* ptr) const + { + ptr[0] = x[0]; + ptr[1] = x[1]; + ptr[2] = x[2]; + ptr[3] = x[3]; + } + + inline void load(const type* ptr) + { + x[0] = ptr[0]; + x[1] = ptr[1]; + x[2] = ptr[2]; + x[3] = ptr[3]; + } + + inline void store(type* ptr) const + { + ptr[0] = x[0]; + ptr[1] = x[1]; + ptr[2] = x[2]; + ptr[3] = x[3]; + } + + inline unsigned int size() const { return 4; } + inline float operator[](unsigned int idx) const { return x[idx]; } + + private: + float x[4]; + }; + + + class simd4f_bool + { + public: + typedef float type; + + inline simd4f_bool() {} + inline simd4f_bool(bool r0, bool r1, bool r2, bool r3) { x[0]=r0; x[1]=r1; x[2]=r2; x[3]=r3;} + + inline bool operator[](unsigned int idx) const { return x[idx]; } + private: + bool x[4]; + }; + +#endif + +// ---------------------------------------------------------------------------------------- + + inline std::ostream& operator<<(std::ostream& out, const simd4f& item) + { + float temp[4]; + item.store(temp); + out << "(" << temp[0] << ", " << temp[1] << ", " << temp[2] << ", " << temp[3] << ")"; + return out; + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f operator+ (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_add_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_add(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vaddq_f32(lhs, rhs); +#else + return simd4f(lhs[0]+rhs[0], + lhs[1]+rhs[1], + lhs[2]+rhs[2], + lhs[3]+rhs[3]); +#endif + } + inline simd4f& operator+= (simd4f& lhs, const simd4f& rhs) + { lhs = lhs + rhs; return lhs; } + +// ---------------------------------------------------------------------------------------- + + inline simd4f operator- (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_sub_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_sub(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vsubq_f32(lhs, rhs); +#else + return simd4f(lhs[0]-rhs[0], + lhs[1]-rhs[1], + lhs[2]-rhs[2], + lhs[3]-rhs[3]); +#endif + } + inline simd4f& operator-= (simd4f& lhs, const simd4f& rhs) + { lhs = lhs - rhs; return lhs; } + +// ---------------------------------------------------------------------------------------- + + inline simd4f operator* (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_mul_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_mul(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vmulq_f32(lhs, rhs); +#else + return simd4f(lhs[0]*rhs[0], + lhs[1]*rhs[1], + lhs[2]*rhs[2], + lhs[3]*rhs[3]); +#endif + } + inline simd4f& operator*= (simd4f& lhs, const simd4f& rhs) + { lhs = lhs * rhs; return lhs; } + +// ---------------------------------------------------------------------------------------- + + inline simd4f operator/ (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_div_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_div(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + float32x4_t reciprocal = vrecpeq_f32(rhs); + reciprocal = vmulq_f32(vrecpsq_f32(rhs, reciprocal), reciprocal); + reciprocal = vmulq_f32(vrecpsq_f32(rhs, reciprocal), reciprocal); + float32x4_t result = vmulq_f32(lhs,reciprocal); + return result; +#else + return simd4f(lhs[0]/rhs[0], + lhs[1]/rhs[1], + lhs[2]/rhs[2], + lhs[3]/rhs[3]); +#endif + } + inline simd4f& operator/= (simd4f& lhs, const simd4f& rhs) + { lhs = lhs / rhs; return lhs; } + +// ---------------------------------------------------------------------------------------- + + inline simd4f_bool operator== (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_cmpeq_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_cmpeq(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return (int32x4_t)vceqq_f32(lhs, rhs); +#else + return simd4f_bool(lhs[0]==rhs[0], + lhs[1]==rhs[1], + lhs[2]==rhs[2], + lhs[3]==rhs[3]); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f_bool operator!= (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_cmpneq_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) || defined(DLIB_HAVE_NEON) + return ~(lhs==rhs); // simd4f_bool is simd4i typedef, can use ~ +#else + return simd4f_bool(lhs[0]!=rhs[0], + lhs[1]!=rhs[1], + lhs[2]!=rhs[2], + lhs[3]!=rhs[3]); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f_bool operator< (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_cmplt_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_cmplt(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return (int32x4_t)vcltq_f32(lhs, rhs); +#else + return simd4f_bool(lhs[0] (const simd4f& lhs, const simd4f& rhs) + { + return rhs < lhs; + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f_bool operator<= (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_cmple_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_cmple(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return (int32x4_t)vcleq_f32(lhs, rhs); +#else + return simd4f_bool(lhs[0]<=rhs[0], + lhs[1]<=rhs[1], + lhs[2]<=rhs[2], + lhs[3]<=rhs[3]); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f_bool operator>= (const simd4f& lhs, const simd4f& rhs) + { + return rhs <= lhs; + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f min (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_min_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_min(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vminq_f32(lhs, rhs); +#else + return simd4f(std::min(lhs[0],rhs[0]), + std::min(lhs[1],rhs[1]), + std::min(lhs[2],rhs[2]), + std::min(lhs[3],rhs[3])); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f max (const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_max_ps(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_max(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vmaxq_f32(lhs, rhs); +#else + return simd4f(std::max(lhs[0],rhs[0]), + std::max(lhs[1],rhs[1]), + std::max(lhs[2],rhs[2]), + std::max(lhs[3],rhs[3])); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f reciprocal (const simd4f& item) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_rcp_ps(item); +#elif defined(DLIB_HAVE_VSX) + return vec_re(item()); +#elif defined(DLIB_HAVE_NEON) + float32x4_t estimate = vrecpeq_f32(item); + estimate = vmulq_f32(vrecpsq_f32(estimate , item), estimate ); + estimate = vmulq_f32(vrecpsq_f32(estimate , item), estimate ); + return estimate ; +#else + return simd4f(1.0f/item[0], + 1.0f/item[1], + 1.0f/item[2], + 1.0f/item[3]); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f reciprocal_sqrt (const simd4f& item) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_rsqrt_ps(item); +#elif defined(DLIB_HAVE_VSX) + return vec_rsqrt(item()); +#elif defined(DLIB_HAVE_NEON) + float32x4_t estimate = vrsqrteq_f32(item); + simd4f estimate2 = vmulq_f32(estimate, item); + estimate = vmulq_f32(estimate, vrsqrtsq_f32(estimate2, estimate)); + return estimate; +#else + return simd4f(1.0f/std::sqrt(item[0]), + 1.0f/std::sqrt(item[1]), + 1.0f/std::sqrt(item[2]), + 1.0f/std::sqrt(item[3])); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline float dot(const simd4f& lhs, const simd4f& rhs); + inline float sum(const simd4f& item) + { +#ifdef DLIB_HAVE_SSE41 + return dot(simd4f(1), item); +#elif defined(DLIB_HAVE_SSE3) + simd4f temp = _mm_hadd_ps(item,item); + return _mm_cvtss_f32(_mm_hadd_ps(temp,temp)); +#elif defined(DLIB_HAVE_SSE2) && (!defined(_MSC_VER) || _MSC_VER!=1400) + simd4f temp = _mm_add_ps(item,_mm_movehl_ps(item,item)); + simd4f temp2 = _mm_shuffle_ps(temp,temp,1); + return _mm_cvtss_f32(_mm_add_ss(temp,temp2)); +#elif defined(DLIB_HAVE_NEON) + float32x2_t r = vadd_f32(vget_high_f32(item), vget_low_f32(item)); + return vget_lane_f32(vpadd_f32(r, r), 0); +#else + return item[0]+item[1]+item[2]+item[3]; +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline float dot(const simd4f& lhs, const simd4f& rhs) + { +#ifdef DLIB_HAVE_SSE41 + return _mm_cvtss_f32(_mm_dp_ps(lhs, rhs, 0xff)); +#else + return sum(lhs*rhs); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f sqrt(const simd4f& item) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_sqrt_ps(item); +#elif defined(DLIB_HAVE_VSX) + return vec_sqrt(item()); +#elif defined(DLIB_HAVE_NEON) + float32x4_t q_step_0 = vrsqrteq_f32(item); + float32x4_t q_step_parm0 = vmulq_f32(item, q_step_0); + float32x4_t q_step_result0 = vrsqrtsq_f32(q_step_parm0, q_step_0); + float32x4_t q_step_1 = vmulq_f32(q_step_0, q_step_result0); + float32x4_t q_step_parm1 = vmulq_f32(item, q_step_1); + float32x4_t q_step_result1 = vrsqrtsq_f32(q_step_parm1, q_step_1); + float32x4_t q_step_2 = vmulq_f32(q_step_1, q_step_result1); + float32x4_t res3 = vmulq_f32(item, q_step_2); + + // normalize sqrt(0)=0 + uint32x4_t zcomp = vceqq_f32(vdupq_n_f32(0), item); + float32x4_t rcorr = vbslq_f32(zcomp, item, res3); + return rcorr; +#else + return simd4f(std::sqrt(item[0]), + std::sqrt(item[1]), + std::sqrt(item[2]), + std::sqrt(item[3])); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f ceil(const simd4f& item) + { +#ifdef DLIB_HAVE_SSE41 + return _mm_ceil_ps(item); +#elif defined(DLIB_HAVE_SSE2) || defined(DLIB_HAVE_NEON) + float temp[4]; + item.store(temp); + temp[0] = std::ceil(temp[0]); + temp[1] = std::ceil(temp[1]); + temp[2] = std::ceil(temp[2]); + temp[3] = std::ceil(temp[3]); + simd4f temp2; + temp2.load(temp); + return temp2; +#elif defined(DLIB_HAVE_VSX) + return vec_ceil(item()); +#else + return simd4f(std::ceil(item[0]), + std::ceil(item[1]), + std::ceil(item[2]), + std::ceil(item[3])); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4f floor(const simd4f& item) + { +#ifdef DLIB_HAVE_SSE41 + return _mm_floor_ps(item); +#elif defined(DLIB_HAVE_SSE2) || defined(DLIB_HAVE_NEON) + float temp[4]; + item.store(temp); + temp[0] = std::floor(temp[0]); + temp[1] = std::floor(temp[1]); + temp[2] = std::floor(temp[2]); + temp[3] = std::floor(temp[3]); + simd4f temp2; + temp2.load(temp); + return temp2; +#elif defined(DLIB_HAVE_VSX) + return vec_floor(item()); +#else + return simd4f(std::floor(item[0]), + std::floor(item[1]), + std::floor(item[2]), + std::floor(item[3])); +#endif + } + +// ---------------------------------------------------------------------------------------- + + // perform cmp ? a : b + inline simd4f select(const simd4f_bool& cmp, const simd4f& a, const simd4f& b) + { +#ifdef DLIB_HAVE_SSE41 + return _mm_blendv_ps(b,a,cmp); +#elif defined(DLIB_HAVE_SSE2) + return _mm_or_ps(_mm_and_ps(cmp,a) , _mm_andnot_ps(cmp,b)); +#elif defined(DLIB_HAVE_NEON) + return vbslq_f32(cmp, a, b); +#else + return simd4f(cmp[0]?a[0]:b[0], + cmp[1]?a[1]:b[1], + cmp[2]?a[2]:b[2], + cmp[3]?a[3]:b[3]); +#endif + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_sIMD4F_Hh_ + diff --git a/ml/dlib/dlib/simd/simd4i.h b/ml/dlib/dlib/simd/simd4i.h new file mode 100644 index 000000000..ea33f14a8 --- /dev/null +++ b/ml/dlib/dlib/simd/simd4i.h @@ -0,0 +1,566 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_sIMD4I_Hh_ +#define DLIB_sIMD4I_Hh_ + +#include "simd_check.h" +#include "../uintn.h" + +namespace dlib +{ + +#ifdef DLIB_HAVE_SSE2 + class simd4i + { + public: + typedef int32 type; + + inline simd4i() {} + inline simd4i(int32 f) { x = _mm_set1_epi32(f); } + inline simd4i(int32 r0, int32 r1, int32 r2, int32 r3) { x = _mm_setr_epi32(r0,r1,r2,r3); } + inline simd4i(const __m128i& val):x(val) {} + + inline simd4i& operator=(const __m128i& val) + { + x = val; + return *this; + } + + inline operator __m128i() const { return x; } + + inline void load_aligned(const type* ptr) { x = _mm_load_si128((const __m128i*)ptr); } + inline void store_aligned(type* ptr) const { _mm_store_si128((__m128i*)ptr, x); } + inline void load(const type* ptr) { x = _mm_loadu_si128((const __m128i*)ptr); } + inline void store(type* ptr) const { _mm_storeu_si128((__m128i*)ptr, x); } + + inline unsigned int size() const { return 4; } + inline int32 operator[](unsigned int idx) const + { + int32 temp[4]; + store(temp); + return temp[idx]; + } + + private: + __m128i x; + }; + +#elif defined(DLIB_HAVE_VSX) + + class simd4i + { + typedef union { + vector signed int v; + vector bool int b; + signed int x[4]; + } v4i; + + v4i x; + + public: + inline simd4i() : x{0,0,0,0} { } + inline simd4i(const simd4i& v) : x(v.x) { } + inline simd4i(const vector int& v) : x{v} { } + inline simd4i(const vector bool int& b) { x.b=b; } + + inline simd4i(int32 f) : x{f,f,f,f} { } + inline simd4i(int32 r0, int32 r1, int32 r2, int32 r3) + : x{r0,r1,r2,r3} { } + + inline simd4i& operator=(const simd4i& v) { x = v.x; return *this; } + inline simd4i& operator=(const int32& v) { *this = simd4i(v); return *this; } + + inline vector signed int operator() () const { return x.v; } + inline int32 operator[](unsigned int idx) const { return x.x[idx]; } + + inline vector bool int to_bool() const { return x.b; } + + // intrinsics now seem to use xxpermdi automatically now + inline void load_aligned(const int32* ptr) { x.v = vec_ld(0, ptr); } + inline void store_aligned(int32* ptr) const { vec_st(x.v, 0, ptr); } + inline void load(const int32* ptr) { x.v = vec_vsx_ld(0, ptr); } + inline void store(int32* ptr) const { vec_vsx_st(x.v, 0, ptr); } + + + struct rawarray + { + v4i v; + }; + inline simd4i(const rawarray& a) : x{a.v} { } + + }; + +#elif defined(DLIB_HAVE_NEON) + + class simd4i + { + public: + typedef int32 type; + + inline simd4i() {} + inline simd4i(int32 f) { x = vdupq_n_s32(f); } + inline simd4i(int32 r0, int32 r1, int32 r2, int32 r3) + { + int32 __attribute__((aligned(16))) data[4] = { r0, r1, r2, r3 }; + x = vld1q_s32(data); + } + inline simd4i(const int32x4_t& val):x(val) {} + + inline simd4i& operator=(const int32x4_t& val) + { + x = val; + return *this; + } + + inline operator int32x4_t() const { return x; } + inline operator uint32x4_t() const { return (uint32x4_t)x; } + + inline void load_aligned(const type* ptr) { x = vld1q_s32(ptr); } + inline void store_aligned(type* ptr) const { vst1q_s32(ptr, x); } + inline void load(const type* ptr) { x = vld1q_s32(ptr); } + inline void store(type* ptr) const { vst1q_s32(ptr, x); } + + inline unsigned int size() const { return 4; } + inline int32 operator[](unsigned int idx) const + { + int32 temp[4]; + store(temp); + return temp[idx]; + } + + private: + int32x4_t x; + }; + +#else + + class simd4i + { + public: + typedef int32 type; + + inline simd4i() {} + inline simd4i(int32 f) { x[0]=f; x[1]=f; x[2]=f; x[3]=f; } + inline simd4i(int32 r0, int32 r1, int32 r2, int32 r3) { x[0]=r0; x[1]=r1; x[2]=r2; x[3]=r3;} + + struct rawarray + { + int32 a[4]; + }; + inline simd4i(const rawarray& a) { x[0]=a.a[0]; x[1]=a.a[1]; x[2]=a.a[2]; x[3]=a.a[3]; } + + inline void load_aligned(const type* ptr) + { + x[0] = ptr[0]; + x[1] = ptr[1]; + x[2] = ptr[2]; + x[3] = ptr[3]; + } + + inline void store_aligned(type* ptr) const + { + ptr[0] = x[0]; + ptr[1] = x[1]; + ptr[2] = x[2]; + ptr[3] = x[3]; + } + + inline void load(const type* ptr) + { + x[0] = ptr[0]; + x[1] = ptr[1]; + x[2] = ptr[2]; + x[3] = ptr[3]; + } + + inline void store(type* ptr) const + { + ptr[0] = x[0]; + ptr[1] = x[1]; + ptr[2] = x[2]; + ptr[3] = x[3]; + } + + inline unsigned int size() const { return 4; } + inline int32 operator[](unsigned int idx) const { return x[idx]; } + + private: + int32 x[4]; + }; +#endif + +// ---------------------------------------------------------------------------------------- + + inline std::ostream& operator<<(std::ostream& out, const simd4i& item) + { + int32 temp[4]; + item.store(temp); + out << "(" << temp[0] << ", " << temp[1] << ", " << temp[2] << ", " << temp[3] << ")"; + return out; + } + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator+ (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_add_epi32(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_add(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vaddq_s32(lhs, rhs); +#else + return simd4i(lhs[0]+rhs[0], + lhs[1]+rhs[1], + lhs[2]+rhs[2], + lhs[3]+rhs[3]); +#endif + } + inline simd4i& operator+= (simd4i& lhs, const simd4i& rhs) + { return lhs = lhs + rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator- (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_sub_epi32(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_sub(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vsubq_s32(lhs, rhs); +#else + return simd4i(lhs[0]-rhs[0], + lhs[1]-rhs[1], + lhs[2]-rhs[2], + lhs[3]-rhs[3]); +#endif + } + inline simd4i& operator-= (simd4i& lhs, const simd4i& rhs) + { return lhs = lhs - rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator* (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE41 + return _mm_mullo_epi32(lhs, rhs); +#elif defined(DLIB_HAVE_SSE2) + int32 _lhs[4]; lhs.store(_lhs); + int32 _rhs[4]; rhs.store(_rhs); + return simd4i(_lhs[0]*_rhs[0], + _lhs[1]*_rhs[1], + _lhs[2]*_rhs[2], + _lhs[3]*_rhs[3]); +#elif defined(DLIB_HAVE_VSX) + vector int a = lhs(), b = rhs(); + asm("vmuluwm %0, %0, %1\n\t" : "+&v" (a) : "v" (b) ); + return simd4i(a); +#elif defined(DLIB_HAVE_NEON) + return vmulq_s32(lhs, rhs); +#else + return simd4i(lhs[0]*rhs[0], + lhs[1]*rhs[1], + lhs[2]*rhs[2], + lhs[3]*rhs[3]); +#endif + } + inline simd4i& operator*= (simd4i& lhs, const simd4i& rhs) + { return lhs = lhs * rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator& (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_and_si128(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_and(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vandq_s32(lhs, rhs); +#else + return simd4i(lhs[0]&rhs[0], + lhs[1]&rhs[1], + lhs[2]&rhs[2], + lhs[3]&rhs[3]); +#endif + } + inline simd4i& operator&= (simd4i& lhs, const simd4i& rhs) + { return lhs = lhs & rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator| (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_or_si128(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_or(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vorrq_s32(lhs, rhs); +#else + return simd4i(lhs[0]|rhs[0], + lhs[1]|rhs[1], + lhs[2]|rhs[2], + lhs[3]|rhs[3]); +#endif + } + inline simd4i& operator|= (simd4i& lhs, const simd4i& rhs) + { return lhs = lhs | rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator^ (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_xor_si128(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_xor(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return veorq_s32(lhs, rhs); +#else + return simd4i(lhs[0]^rhs[0], + lhs[1]^rhs[1], + lhs[2]^rhs[2], + lhs[3]^rhs[3]); +#endif + } + inline simd4i& operator^= (simd4i& lhs, const simd4i& rhs) + { return lhs = lhs ^ rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator~ (const simd4i& lhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_xor_si128(lhs, _mm_set1_epi32(0xFFFFFFFF)); +#elif defined(DLIB_HAVE_VSX) + return vec_xor(lhs(), vec_splats(~0)); +#elif defined(DLIB_HAVE_NEON) + return vmvnq_s32(lhs); +#else + return simd4i(~lhs[0], + ~lhs[1], + ~lhs[2], + ~lhs[3]); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator<< (const simd4i& lhs, const int& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_sll_epi32(lhs,_mm_cvtsi32_si128(rhs)); +#elif defined(DLIB_HAVE_VSX) + return vec_sl(lhs(), vec_splats((uint32_t)rhs)); +#elif defined(DLIB_HAVE_NEON) + return vshlq_s32(lhs, simd4i(rhs)); +#else + return simd4i(lhs[0]<> (const simd4i& lhs, const int& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_sra_epi32(lhs,_mm_cvtsi32_si128(rhs)); +#elif defined(DLIB_HAVE_VSX) + return vec_sr(lhs(), vec_splats((uint32_t)rhs)); +#elif defined(DLIB_HAVE_NEON) + int32 _lhs[4]; lhs.store(_lhs); + return simd4i(_lhs[0]>>rhs, + _lhs[1]>>rhs, + _lhs[2]>>rhs, + _lhs[3]>>rhs); +#else + return simd4i(lhs[0]>>rhs, + lhs[1]>>rhs, + lhs[2]>>rhs, + lhs[3]>>rhs); +#endif + } + inline simd4i& operator>>= (simd4i& lhs, const int& rhs) + { return lhs = lhs >> rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator== (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_cmpeq_epi32(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_cmpeq(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return (int32x4_t)vceqq_s32(lhs,rhs); +#else + return simd4i(lhs[0]==rhs[0] ? 0xFFFFFFFF : 0, + lhs[1]==rhs[1] ? 0xFFFFFFFF : 0, + lhs[2]==rhs[2] ? 0xFFFFFFFF : 0, + lhs[3]==rhs[3] ? 0xFFFFFFFF : 0); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator!= (const simd4i& lhs, const simd4i& rhs) + { +#if defined(DLIB_HAVE_SSE2) || defined(DLIB_HAVE_VSX) || defined(DLIB_HAVE_NEON) + return ~(lhs==rhs); +#else + return simd4i(lhs[0]!=rhs[0] ? 0xFFFFFFFF : 0, + lhs[1]!=rhs[1] ? 0xFFFFFFFF : 0, + lhs[2]!=rhs[2] ? 0xFFFFFFFF : 0, + lhs[3]!=rhs[3] ? 0xFFFFFFFF : 0); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator< (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return _mm_cmplt_epi32(lhs, rhs); +#elif defined(DLIB_HAVE_VSX) + return vec_cmplt(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return (int32x4_t)vcltq_s32(lhs, rhs); +#else + return simd4i(lhs[0] (const simd4i& lhs, const simd4i& rhs) + { + return rhs < lhs; + } + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator<= (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE2 + return ~(lhs > rhs); +#elif defined(DLIB_HAVE_NEON) + return (int32x4_t)vcleq_s32(lhs, rhs); +#else + return simd4i(lhs[0]<=rhs[0] ? 0xFFFFFFFF : 0, + lhs[1]<=rhs[1] ? 0xFFFFFFFF : 0, + lhs[2]<=rhs[2] ? 0xFFFFFFFF : 0, + lhs[3]<=rhs[3] ? 0xFFFFFFFF : 0); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4i operator>= (const simd4i& lhs, const simd4i& rhs) + { + return rhs <= lhs; + } + +// ---------------------------------------------------------------------------------------- + + inline simd4i min (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE41 + return _mm_min_epi32(lhs, rhs); +#elif defined(DLIB_HAVE_SSE2) + int32 _lhs[4]; lhs.store(_lhs); + int32 _rhs[4]; rhs.store(_rhs); + return simd4i(std::min(_lhs[0],_rhs[0]), + std::min(_lhs[1],_rhs[1]), + std::min(_lhs[2],_rhs[2]), + std::min(_lhs[3],_rhs[3])); +#elif defined(DLIB_HAVE_VSX) + return vec_min(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return (int32x4_t)vminq_s32(lhs, rhs); +#else + return simd4i(std::min(lhs[0],rhs[0]), + std::min(lhs[1],rhs[1]), + std::min(lhs[2],rhs[2]), + std::min(lhs[3],rhs[3])); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd4i max (const simd4i& lhs, const simd4i& rhs) + { +#ifdef DLIB_HAVE_SSE41 + return _mm_max_epi32(lhs, rhs); +#elif defined(DLIB_HAVE_SSE2) + int32 _lhs[4]; lhs.store(_lhs); + int32 _rhs[4]; rhs.store(_rhs); + return simd4i(std::max(_lhs[0],_rhs[0]), + std::max(_lhs[1],_rhs[1]), + std::max(_lhs[2],_rhs[2]), + std::max(_lhs[3],_rhs[3])); +#elif defined(DLIB_HAVE_VSX) + return vec_max(lhs(), rhs()); +#elif defined(DLIB_HAVE_NEON) + return vmaxq_s32(lhs, rhs); +#else + return simd4i(std::max(lhs[0],rhs[0]), + std::max(lhs[1],rhs[1]), + std::max(lhs[2],rhs[2]), + std::max(lhs[3],rhs[3])); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline int32 sum(const simd4i& item) + { +#ifdef DLIB_HAVE_SSE3 + simd4i temp = _mm_hadd_epi32(item,item); + temp = _mm_hadd_epi32(temp,temp); + return _mm_cvtsi128_si32(temp); +#elif defined(DLIB_HAVE_SSE2) + int32 temp[4]; + item.store(temp); + return temp[0]+temp[1]+temp[2]+temp[3]; +#elif defined(DLIB_HAVE_NEON) + int32x2_t r = vadd_s32(vget_high_s32(item), vget_low_s32(item)); + return vget_lane_s32(vpadd_s32(r, r), 0); +#else + return item[0]+item[1]+item[2]+item[3]; +#endif + } + +// ---------------------------------------------------------------------------------------- + + // perform cmp ? a : b + inline simd4i select(const simd4i& cmp, const simd4i& a, const simd4i& b) + { +#ifdef DLIB_HAVE_SSE41 + return _mm_blendv_epi8(b,a,cmp); +#elif defined(DLIB_HAVE_SSE2) + return ((cmp&a) | _mm_andnot_si128(cmp,b)); +#elif defined(DLIB_HAVE_VSX) + return vec_sel(b(), a(), cmp.to_bool()); +#elif defined(DLIB_HAVE_NEON) + return vbslq_s32(cmp, a, b); +#else + return ((cmp&a) | (~cmp&b)); +#endif + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_sIMD4I_Hh_ + diff --git a/ml/dlib/dlib/simd/simd8f.h b/ml/dlib/dlib/simd/simd8f.h new file mode 100644 index 000000000..628ba74ee --- /dev/null +++ b/ml/dlib/dlib/simd/simd8f.h @@ -0,0 +1,402 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_sIMD8F_Hh_ +#define DLIB_sIMD8F_Hh_ + +#include "simd_check.h" +#include "simd4f.h" +#include "simd8i.h" + +namespace dlib +{ +#ifdef DLIB_HAVE_AVX + class simd8f + { + public: + typedef float type; + + inline simd8f() {} + inline simd8f(const simd4f& low, const simd4f& high) + { + x = _mm256_insertf128_ps(_mm256_castps128_ps256(low),high,1); + } + inline simd8f(float f) { x = _mm256_set1_ps(f); } + inline simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) + { x = _mm256_setr_ps(r0,r1,r2,r3,r4,r5,r6,r7); } + + inline simd8f(const simd8i& val):x(_mm256_cvtepi32_ps(val)) {} + inline simd8f(const __m256& val):x(val) {} + inline simd8f& operator=(const __m256& val) + { + x = val; + return *this; + } + inline operator __m256() const { return x; } + + // truncate to 32bit integers + inline operator __m256i() const { return _mm256_cvttps_epi32(x); } + + inline void load_aligned(const type* ptr) { x = _mm256_load_ps(ptr); } + inline void store_aligned(type* ptr) const { _mm256_store_ps(ptr, x); } + inline void load(const type* ptr) { x = _mm256_loadu_ps(ptr); } + inline void store(type* ptr) const { _mm256_storeu_ps(ptr, x); } + + inline simd8f& operator=(const simd8i& rhs) { *this = simd8f(rhs); return *this; } + inline simd8f& operator=(const float& val) + { + x = simd8f(val); + return *this; + } + + inline unsigned int size() const { return 8; } + inline float operator[](unsigned int idx) const + { + float temp[8]; + store(temp); + return temp[idx]; + } + + inline simd4f low() const { return _mm256_castps256_ps128(x); } + inline simd4f high() const { return _mm256_extractf128_ps(x,1); } + + private: + __m256 x; + }; + + + class simd8f_bool + { + public: + typedef float type; + + inline simd8f_bool() {} + inline simd8f_bool(const __m256& val):x(val) {} + inline simd8f_bool(const simd4f_bool& low, const simd4f_bool& high) + { + x = _mm256_insertf128_ps(_mm256_castps128_ps256(low),high,1); + } + + inline simd8f_bool& operator=(const __m256& val) + { + x = val; + return *this; + } + + inline operator __m256() const { return x; } + + + private: + __m256 x; + }; + +#else + class simd8f + { + public: + typedef float type; + + inline simd8f() {} + inline simd8f(const simd4f& low_, const simd4f& high_): _low(low_),_high(high_){} + inline simd8f(float f) :_low(f),_high(f) {} + inline simd8f(float r0, float r1, float r2, float r3, float r4, float r5, float r6, float r7) : + _low(r0,r1,r2,r3), _high(r4,r5,r6,r7) {} + inline simd8f(const simd8i& val) : _low(val.low()), _high(val.high()) { } + + // truncate to 32bit integers + inline operator simd8i::rawarray() const + { + simd8i::rawarray temp; + temp.low = simd4i(_low); + temp.high = simd4i(_high); + return temp; + } + + inline void load_aligned(const type* ptr) { _low.load_aligned(ptr); _high.load_aligned(ptr+4); } + inline void store_aligned(type* ptr) const { _low.store_aligned(ptr); _high.store_aligned(ptr+4); } + inline void load(const type* ptr) { _low.load(ptr); _high.load(ptr+4); } + inline void store(type* ptr) const { _low.store(ptr); _high.store(ptr+4); } + + inline unsigned int size() const { return 8; } + inline float operator[](unsigned int idx) const + { + if (idx < 4) + return _low[idx]; + else + return _high[idx-4]; + } + + inline const simd4f& low() const { return _low; } + inline const simd4f& high() const { return _high; } + + private: + simd4f _low, _high; + }; + + class simd8f_bool + { + public: + typedef float type; + + inline simd8f_bool() {} + inline simd8f_bool(const simd4f_bool& low_, const simd4f_bool& high_): _low(low_),_high(high_){} + + + inline const simd4f_bool& low() const { return _low; } + inline const simd4f_bool& high() const { return _high; } + private: + simd4f_bool _low,_high; + }; +#endif + +// ---------------------------------------------------------------------------------------- + + inline std::ostream& operator<<(std::ostream& out, const simd8f& item) + { + float temp[8]; + item.store(temp); + out << "(" << temp[0] << ", " << temp[1] << ", " << temp[2] << ", " << temp[3] << ", " + << temp[4] << ", " << temp[5] << ", " << temp[6] << ", " << temp[7] << ")"; + return out; + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f operator+ (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_add_ps(lhs, rhs); +#else + return simd8f(lhs.low()+rhs.low(), + lhs.high()+rhs.high()); +#endif + } + inline simd8f& operator+= (simd8f& lhs, const simd8f& rhs) + { lhs = lhs + rhs; return lhs; } + +// ---------------------------------------------------------------------------------------- + + inline simd8f operator- (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_sub_ps(lhs, rhs); +#else + return simd8f(lhs.low()-rhs.low(), + lhs.high()-rhs.high()); +#endif + } + inline simd8f& operator-= (simd8f& lhs, const simd8f& rhs) + { lhs = lhs - rhs; return lhs; } + +// ---------------------------------------------------------------------------------------- + + inline simd8f operator* (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_mul_ps(lhs, rhs); +#else + return simd8f(lhs.low()*rhs.low(), + lhs.high()*rhs.high()); +#endif + } + inline simd8f& operator*= (simd8f& lhs, const simd8f& rhs) + { lhs = lhs * rhs; return lhs; } + +// ---------------------------------------------------------------------------------------- + + inline simd8f operator/ (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_div_ps(lhs, rhs); +#else + return simd8f(lhs.low()/rhs.low(), + lhs.high()/rhs.high()); +#endif + } + inline simd8f& operator/= (simd8f& lhs, const simd8f& rhs) + { lhs = lhs / rhs; return lhs; } + +// ---------------------------------------------------------------------------------------- + + inline simd8f_bool operator== (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_cmp_ps(lhs, rhs, 0); +#else + return simd8f_bool(lhs.low() ==rhs.low(), + lhs.high()==rhs.high()); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f_bool operator!= (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_cmp_ps(lhs, rhs, 4); +#else + return simd8f_bool(lhs.low() !=rhs.low(), + lhs.high()!=rhs.high()); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f_bool operator< (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_cmp_ps(lhs, rhs, 1); +#else + return simd8f_bool(lhs.low() (const simd8f& lhs, const simd8f& rhs) + { + return rhs < lhs; + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f_bool operator<= (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_cmp_ps(lhs, rhs, 2); +#else + return simd8f_bool(lhs.low() <=rhs.low(), + lhs.high()<=rhs.high()); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f_bool operator>= (const simd8f& lhs, const simd8f& rhs) + { + return rhs <= lhs; + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f min (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_min_ps(lhs, rhs); +#else + return simd8f(min(lhs.low(), rhs.low()), + min(lhs.high(),rhs.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f max (const simd8f& lhs, const simd8f& rhs) + { +#ifdef DLIB_HAVE_AVX + return _mm256_max_ps(lhs, rhs); +#else + return simd8f(max(lhs.low(), rhs.low()), + max(lhs.high(),rhs.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f reciprocal (const simd8f& item) + { +#ifdef DLIB_HAVE_AVX + return _mm256_rcp_ps(item); +#else + return simd8f(reciprocal(item.low()), + reciprocal(item.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f reciprocal_sqrt (const simd8f& item) + { +#ifdef DLIB_HAVE_AVX + return _mm256_rsqrt_ps(item); +#else + return simd8f(reciprocal_sqrt(item.low()), + reciprocal_sqrt(item.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline float sum(const simd8f& item) + { +#ifdef DLIB_HAVE_AVX + simd8f temp = _mm256_hadd_ps(item,item); + simd8f temp2 = _mm256_hadd_ps(temp,temp); + return _mm_cvtss_f32(_mm_add_ss(_mm256_castps256_ps128(temp2),_mm256_extractf128_ps(temp2,1))); +#else + return sum(item.low()+item.high()); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline float dot(const simd8f& lhs, const simd8f& rhs) + { + return sum(lhs*rhs); + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f sqrt(const simd8f& item) + { +#ifdef DLIB_HAVE_AVX + return _mm256_sqrt_ps(item); +#else + return simd8f(sqrt(item.low()), + sqrt(item.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f ceil(const simd8f& item) + { +#ifdef DLIB_HAVE_AVX + return _mm256_ceil_ps(item); +#else + return simd8f(ceil(item.low()), + ceil(item.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8f floor(const simd8f& item) + { +#ifdef DLIB_HAVE_AVX + return _mm256_floor_ps(item); +#else + return simd8f(floor(item.low()), + floor(item.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + + // perform cmp ? a : b + inline simd8f select(const simd8f_bool& cmp, const simd8f& a, const simd8f& b) + { +#ifdef DLIB_HAVE_AVX + return _mm256_blendv_ps(b,a,cmp); +#else + return simd8f(select(cmp.low(), a.low(), b.low()), + select(cmp.high(), a.high(), b.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_sIMD8F_Hh_ + diff --git a/ml/dlib/dlib/simd/simd8i.h b/ml/dlib/dlib/simd/simd8i.h new file mode 100644 index 000000000..18c06ec7e --- /dev/null +++ b/ml/dlib/dlib/simd/simd8i.h @@ -0,0 +1,339 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_sIMD8I_Hh_ +#define DLIB_sIMD8I_Hh_ + +#include "simd_check.h" +#include "../uintn.h" + +namespace dlib +{ + +#ifdef DLIB_HAVE_AVX + class simd8i + { + public: + typedef int32 type; + + inline simd8i() {} + inline simd8i(int32 f) { x = _mm256_set1_epi32(f); } + inline simd8i(int32 r0, int32 r1, int32 r2, int32 r3, + int32 r4, int32 r5, int32 r6, int32 r7 ) + { x = _mm256_setr_epi32(r0,r1,r2,r3,r4,r5,r6,r7); } + + inline simd8i(const __m256i& val):x(val) {} + + inline simd8i(const simd4i& low, const simd4i& high) + { + x = _mm256_insertf128_si256(_mm256_castsi128_si256(low),high,1); + } + + inline simd8i& operator=(const __m256i& val) + { + x = val; + return *this; + } + + inline operator __m256i() const { return x; } + + inline void load_aligned(const type* ptr) { x = _mm256_load_si256((const __m256i*)ptr); } + inline void store_aligned(type* ptr) const { _mm256_store_si256((__m256i*)ptr, x); } + inline void load(const type* ptr) { x = _mm256_loadu_si256((const __m256i*)ptr); } + inline void store(type* ptr) const { _mm256_storeu_si256((__m256i*)ptr, x); } + + inline simd4i low() const { return _mm256_castsi256_si128(x); } + inline simd4i high() const { return _mm256_extractf128_si256(x,1); } + + inline unsigned int size() const { return 8; } + inline int32 operator[](unsigned int idx) const + { + int32 temp[8]; + store(temp); + return temp[idx]; + } + + private: + __m256i x; + }; +#else + class simd8i + { + public: + typedef int32 type; + + inline simd8i() {} + inline simd8i(const simd4i& low_, const simd4i& high_): _low(low_),_high(high_){} + inline simd8i(int32 f) :_low(f),_high(f) {} + inline simd8i(int32 r0, int32 r1, int32 r2, int32 r3, int32 r4, int32 r5, int32 r6, int32 r7) : + _low(r0,r1,r2,r3), _high(r4,r5,r6,r7) {} + + struct rawarray + { + simd4i low, high; + }; + inline simd8i(const rawarray& a) + { + _low = a.low; + _high = a.high; + } + + inline void load_aligned(const type* ptr) { _low.load_aligned(ptr); _high.load_aligned(ptr+4); } + inline void store_aligned(type* ptr) const { _low.store_aligned(ptr); _high.store_aligned(ptr+4); } + inline void load(const type* ptr) { _low.load(ptr); _high.load(ptr+4); } + inline void store(type* ptr) const { _low.store(ptr); _high.store(ptr+4); } + + inline unsigned int size() const { return 8; } + inline int32 operator[](unsigned int idx) const + { + if (idx < 4) + return _low[idx]; + else + return _high[idx-4]; + } + + inline const simd4i& low() const { return _low; } + inline const simd4i& high() const { return _high; } + + private: + simd4i _low, _high; + }; + +#endif + +// ---------------------------------------------------------------------------------------- + + inline std::ostream& operator<<(std::ostream& out, const simd8i& item) + { + int32 temp[8]; + item.store(temp); + out << "(" << temp[0] << ", " << temp[1] << ", " << temp[2] << ", " << temp[3] << ", " + << temp[4] << ", " << temp[5] << ", " << temp[6] << ", " << temp[7] << ")"; + return out; + } + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator+ (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_add_epi32(lhs, rhs); +#else + return simd8i(lhs.low()+rhs.low(), + lhs.high()+rhs.high()); +#endif + } + inline simd8i& operator+= (simd8i& lhs, const simd8i& rhs) + { return lhs = lhs + rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator- (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_sub_epi32(lhs, rhs); +#else + return simd8i(lhs.low()-rhs.low(), + lhs.high()-rhs.high()); +#endif + } + inline simd8i& operator-= (simd8i& lhs, const simd8i& rhs) + { return lhs = lhs - rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator* (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_mullo_epi32(lhs, rhs); +#else + return simd8i(lhs.low()*rhs.low(), + lhs.high()*rhs.high()); +#endif + } + inline simd8i& operator*= (simd8i& lhs, const simd8i& rhs) + { return lhs = lhs * rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator& (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_and_si256(lhs, rhs); +#else + return simd8i(lhs.low()&rhs.low(), + lhs.high()&rhs.high()); +#endif + } + inline simd8i& operator&= (simd8i& lhs, const simd8i& rhs) + { return lhs = lhs & rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator| (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_or_si256(lhs, rhs); +#else + return simd8i(lhs.low()|rhs.low(), + lhs.high()|rhs.high()); +#endif + } + inline simd8i& operator|= (simd8i& lhs, const simd8i& rhs) + { return lhs = lhs | rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator^ (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_xor_si256(lhs, rhs); +#else + return simd8i(lhs.low()^rhs.low(), + lhs.high()^rhs.high()); +#endif + } + inline simd8i& operator^= (simd8i& lhs, const simd8i& rhs) + { return lhs = lhs ^ rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator~ (const simd8i& lhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_xor_si256(lhs, _mm256_set1_epi32(0xFFFFFFFF)); +#else + return simd8i(~lhs.low(), ~lhs.high()); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator<< (const simd8i& lhs, const int& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_sll_epi32(lhs,_mm_cvtsi32_si128(rhs)); +#else + return simd8i(lhs.low()<> (const simd8i& lhs, const int& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_sra_epi32(lhs,_mm_cvtsi32_si128(rhs)); +#else + return simd8i(lhs.low()>>rhs, + lhs.high()>>rhs); +#endif + } + inline simd8i& operator>>= (simd8i& lhs, const int& rhs) + { return lhs = lhs >> rhs; return lhs;} + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator== (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_cmpeq_epi32(lhs, rhs); +#else + return simd8i(lhs.low()==rhs.low(), + lhs.high()==rhs.high()); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator!= (const simd8i& lhs, const simd8i& rhs) + { + return ~(lhs==rhs); + } + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator> (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_cmpgt_epi32(lhs, rhs); +#else + return simd8i(lhs.low()>rhs.low(), + lhs.high()>rhs.high()); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator< (const simd8i& lhs, const simd8i& rhs) + { + return rhs > lhs; + } + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator<= (const simd8i& lhs, const simd8i& rhs) + { + return ~(lhs > rhs); + } + +// ---------------------------------------------------------------------------------------- + + inline simd8i operator>= (const simd8i& lhs, const simd8i& rhs) + { + return rhs <= lhs; + } + +// ---------------------------------------------------------------------------------------- + + inline simd8i min (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_min_epi32(lhs, rhs); +#else + return simd8i(min(lhs.low(),rhs.low()), + min(lhs.high(),rhs.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline simd8i max (const simd8i& lhs, const simd8i& rhs) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_max_epi32(lhs, rhs); +#else + return simd8i(max(lhs.low(),rhs.low()), + max(lhs.high(),rhs.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + + inline int32 sum(const simd8i& item) + { + return sum(item.low()+item.high()); + } + +// ---------------------------------------------------------------------------------------- + + // perform cmp ? a : b + inline simd8i select(const simd8i& cmp, const simd8i& a, const simd8i& b) + { +#ifdef DLIB_HAVE_AVX2 + return _mm256_blendv_epi8(b,a,cmp); +#else + return simd8i(select(cmp.low(), a.low(), b.low()), + select(cmp.high(), a.high(), b.high())); +#endif + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_sIMD8I_Hh_ + + diff --git a/ml/dlib/dlib/simd/simd_check.h b/ml/dlib/dlib/simd/simd_check.h new file mode 100644 index 000000000..c4ca0c3b8 --- /dev/null +++ b/ml/dlib/dlib/simd/simd_check.h @@ -0,0 +1,177 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SIMd_CHECK_Hh_ +#define DLIB_SIMd_CHECK_Hh_ + +#include +#include + +//#define DLIB_DO_NOT_USE_SIMD + +// figure out which SIMD instructions we can use. +#ifndef DLIB_DO_NOT_USE_SIMD + #if defined(_MSC_VER) + #ifdef __AVX__ + #ifndef DLIB_HAVE_SSE2 + #define DLIB_HAVE_SSE2 + #endif + #ifndef DLIB_HAVE_SSE3 + #define DLIB_HAVE_SSE3 + #endif + #ifndef DLIB_HAVE_SSE41 + #define DLIB_HAVE_SSE41 + #endif + #ifndef DLIB_HAVE_AVX + #define DLIB_HAVE_AVX + #endif + #endif + #if (defined( _M_X64) || defined(_M_IX86_FP) && _M_IX86_FP >= 2) && !defined(DLIB_HAVE_SSE2) + #define DLIB_HAVE_SSE2 + #endif + #else + #ifdef __SSE2__ + #ifndef DLIB_HAVE_SSE2 + #define DLIB_HAVE_SSE2 + #endif + #endif + #ifdef __SSSE3__ + #ifndef DLIB_HAVE_SSE3 + #define DLIB_HAVE_SSE3 + #endif + #endif + #ifdef __SSE4_1__ + #ifndef DLIB_HAVE_SSE41 + #define DLIB_HAVE_SSE41 + #endif + #endif + #ifdef __AVX__ + #ifndef DLIB_HAVE_AVX + #define DLIB_HAVE_AVX + #endif + #endif + #ifdef __AVX2__ + #ifndef DLIB_HAVE_AVX2 + #define DLIB_HAVE_AVX2 + #endif + #endif + #ifdef __ALTIVEC__ + #ifndef DLIB_HAVE_ALTIVEC + #define DLIB_HAVE_ALTIVEC + #endif + #endif + #ifdef __VSX__ + #ifndef DLIB_HAVE_VSX + #define DLIB_HAVE_VSX + #endif + #endif + #ifdef __VEC__ // __VEC__ = 10206 + #ifndef DLIB_HAVE_POWER_VEC // vector and vec_ intrinsics + #define DLIB_HAVE_POWER_VEC + #endif + #endif + #ifdef __ARM_NEON + #ifndef DLIB_HAVE_NEON + #define DLIB_HAVE_NEON + #endif + #endif + #endif +#endif + + +// ---------------------------------------------------------------------------------------- + + +#ifdef DLIB_HAVE_ALTIVEC +#include +#endif + +#ifdef DLIB_HAVE_SSE2 + #include + #include + #include +#endif +#ifdef DLIB_HAVE_SSE3 + #include // SSE3 + #include +#endif +#ifdef DLIB_HAVE_SSE41 + #include // SSE4 +#endif +#ifdef DLIB_HAVE_AVX + #include // AVX +#endif +#ifdef DLIB_HAVE_AVX2 + #include // AVX +// #include +#endif +#ifdef DLIB_HAVE_NEON + #include // ARM NEON +#endif + +// ---------------------------------------------------------------------------------------- +// Define functions to check, at runtime, what instructions are available + +#if defined(_MSC_VER) && (defined(_M_I86) || defined(_M_IX86) || defined(_M_X64) || defined(_M_AMD64) ) + #include + + inline std::array cpuid(int function_id) + { + std::array info; + // Load EAX, EBX, ECX, EDX into info + __cpuid((int*)info.data(), function_id); + return info; + } + +#elif (defined(__GNUC__) || defined(__clang__)) && (defined(__i386__) || defined(__i686__) || defined(__amd64__) || defined(__x86_64__)) + #include + + inline std::array cpuid(int function_id) + { + std::array info; + // Load EAX, EBX, ECX, EDX into info + __cpuid(function_id, info[0], info[1], info[2], info[3]); + return info; + } + +#else + + inline std::array cpuid(int) + { + return std::array{}; + } + +#endif + + inline bool cpu_has_sse2_instructions() { return 0!=(cpuid(1)[3]&(1<<26)); } + inline bool cpu_has_sse3_instructions() { return 0!=(cpuid(1)[2]&(1<<0)); } + inline bool cpu_has_sse41_instructions() { return 0!=(cpuid(1)[2]&(1<<19)); } + inline bool cpu_has_sse42_instructions() { return 0!=(cpuid(1)[2]&(1<<20)); } + inline bool cpu_has_avx_instructions() { return 0!=(cpuid(1)[2]&(1<<28)); } + inline bool cpu_has_avx2_instructions() { return 0!=(cpuid(7)[1]&(1<<5)); } + inline bool cpu_has_avx512_instructions() { return 0!=(cpuid(7)[1]&(1<<16)); } + + inline void warn_about_unavailable_but_used_cpu_instructions() + { +#if defined(DLIB_HAVE_AVX2) + if (!cpu_has_avx2_instructions()) + std::cerr << "Dlib was compiled to use AVX2 instructions, but these aren't available on your machine." << std::endl; +#elif defined(DLIB_HAVE_AVX) + if (!cpu_has_avx_instructions()) + std::cerr << "Dlib was compiled to use AVX instructions, but these aren't available on your machine." << std::endl; +#elif defined(DLIB_HAVE_SSE41) + if (!cpu_has_sse41_instructions()) + std::cerr << "Dlib was compiled to use SSE41 instructions, but these aren't available on your machine." << std::endl; +#elif defined(DLIB_HAVE_SSE3) + if (!cpu_has_sse3_instructions()) + std::cerr << "Dlib was compiled to use SSE3 instructions, but these aren't available on your machine." << std::endl; +#elif defined(DLIB_HAVE_SSE2) + if (!cpu_has_sse2_instructions()) + std::cerr << "Dlib was compiled to use SSE2 instructions, but these aren't available on your machine." << std::endl; +#endif + } + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_SIMd_CHECK_Hh_ + + diff --git a/ml/dlib/dlib/sliding_buffer.h b/ml/dlib/dlib/sliding_buffer.h new file mode 100644 index 000000000..fb89e1b00 --- /dev/null +++ b/ml/dlib/dlib/sliding_buffer.h @@ -0,0 +1,38 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SLIDING_BUFFEr_ +#define DLIB_SLIDING_BUFFEr_ + + +#include "sliding_buffer/sliding_buffer_kernel_1.h" +#include "sliding_buffer/sliding_buffer_kernel_c.h" +#include "sliding_buffer/circular_buffer.h" + + + +namespace dlib +{ + + template < + typename T + > + class sliding_buffer + { + + sliding_buffer() {} + public: + + //----------- kernels --------------- + + // kernel_1a + typedef sliding_buffer_kernel_1 + kernel_1a; + typedef sliding_buffer_kernel_c + kernel_1a_c; + + + }; +} + +#endif // DLIB_SLIDING_BUFFEr_ + diff --git a/ml/dlib/dlib/sliding_buffer/circular_buffer.h b/ml/dlib/dlib/sliding_buffer/circular_buffer.h new file mode 100644 index 000000000..4fcc922d6 --- /dev/null +++ b/ml/dlib/dlib/sliding_buffer/circular_buffer.h @@ -0,0 +1,235 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CIRCULAR_BuFFER_Hh_ +#define DLIB_CIRCULAR_BuFFER_Hh_ + +#include "circular_buffer_abstract.h" +#include +#include "../algs.h" +#include "../serialize.h" +#include "../matrix/matrix_mat.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class circular_buffer + { + public: + typedef default_memory_manager mem_manager_type; + typedef T value_type; + typedef T type; + + circular_buffer() + { + } + + explicit circular_buffer(unsigned long s) + { + resize(s); + } + + void clear ( + ) + { + offset = 0; + data.clear(); + } + + T& operator[] ( unsigned long i) + { + DLIB_ASSERT(i < size(), + "\t T& circular_buffer::operator[](i)" + << "\n\t You have supplied an invalid index" + << "\n\t this: " << this + << "\n\t i: " << i + << "\n\t size(): " << size() + ); + return data[(i+offset)%data.size()]; + } + + const T& operator[] ( unsigned long i) const + { + DLIB_ASSERT(i < size(), + "\t const T& circular_buffer::operator[](i)" + << "\n\t You have supplied an invalid index" + << "\n\t this: " << this + << "\n\t i: " << i + << "\n\t size(): " << size() + ); + return data[(i+offset)%data.size()]; + } + + void resize(unsigned long size) + { + offset = 0; + data.resize(size); + } + + void assign( + unsigned long size, + const T& value + ) + { + offset = 0; + data.assign(size,value); + } + + unsigned long size() const { return data.size(); } + + void push_front(const T& value) + { + if (data.size() != 0) + { + offset = (offset - 1 + data.size())%data.size(); + data[offset] = value; + } + } + + void push_back(const T& value) + { + if (data.size() != 0) + { + data[offset] = value; + offset = (offset + 1 + data.size())%data.size(); + } + } + + T& front( + ) + { + DLIB_CASSERT(size() > 0, + "\t T& circular_buffer::front()" + << "\n\t You can't call front() on an empty circular_buffer" + << "\n\t this: " << this + ); + return (*this)[0]; + } + + const T& front( + ) const + { + DLIB_CASSERT(size() > 0, + "\t const T& circular_buffer::front()" + << "\n\t You can't call front() on an empty circular_buffer" + << "\n\t this: " << this + ); + return (*this)[0]; + } + + T& back( + ) + { + DLIB_CASSERT(size() > 0, + "\t T& circular_buffer::back()" + << "\n\t You can't call back() on an empty circular_buffer" + << "\n\t this: " << this + ); + return (*this)[size()-1]; + } + + const T& back( + ) const + { + DLIB_CASSERT(size() > 0, + "\t const T& circular_buffer::back()" + << "\n\t You can't call back() on an empty circular_buffer" + << "\n\t this: " << this + ); + return (*this)[size()-1]; + } + + void swap( circular_buffer& item) + { + std::swap(item.offset, offset); + data.swap(item.data); + } + + + private: + std::vector data; + + unsigned long offset = 0; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void swap ( + circular_buffer& a, + circular_buffer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void serialize ( + const circular_buffer& item, + std::ostream& out + ) + { + try + { + serialize(item.size(),out); + for (unsigned long i = 0; i < item.size(); ++i) + serialize(item[i],out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type circular_buffer"); + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void deserialize ( + circular_buffer& item, + std::istream& in + ) + { + try + { + unsigned long size; + deserialize(size,in); + item.resize(size); + for (unsigned long i = 0; i < size; ++i) + deserialize(item[i],in); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type circular_buffer"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > > mat ( + const circular_buffer& m + ) + { + typedef op_array_to_mat > op; + return matrix_op(op(m)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CIRCULAR_BuFFER_Hh_ + diff --git a/ml/dlib/dlib/sliding_buffer/circular_buffer_abstract.h b/ml/dlib/dlib/sliding_buffer/circular_buffer_abstract.h new file mode 100644 index 000000000..dc9f35c7a --- /dev/null +++ b/ml/dlib/dlib/sliding_buffer/circular_buffer_abstract.h @@ -0,0 +1,257 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CIRCULAR_BuFFER_ABSTRACT_Hh_ +#ifdef DLIB_CIRCULAR_BuFFER_ABSTRACT_Hh_ + +#include "../algs.h" +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class circular_buffer + { + /*! + REQUIREMENTS ON T + T must have a default constructor and be copyable. + + POINTERS AND REFERENCES TO INTERNAL DATA + swap(), size(), front(), back(), and operator[] functions do + not invalidate pointers or references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + - size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object is a circular buffer of objects of type T. This means + that when objects are pushed onto one of its ends it does not grow + in size. Instead, it shifts all elements over one to make room for + the new element and the element at the opposing end falls off the + buffer and is lost. + !*/ + + public: + typedef default_memory_manager mem_manager_type; + typedef T value_type; + typedef T type; + + circular_buffer( + ); + /*! + ensures + - #size() == 0 + - this object is properly initialized + !*/ + + explicit circular_buffer( + unsigned long s + ); + /*! + ensures + - #size() == s + - this object is properly initialized + !*/ + + void clear ( + ); + /*! + ensures + - this object has its initial value + - #size() == 0 + !*/ + + T& operator[] ( + unsigned long i + ) const; + /*! + requires + - i < size() + ensures + - returns a non-const reference to the i-th element of this circular buffer + !*/ + + const T& operator[] ( + unsigned long i + ) const; + /*! + requires + - i < size() + ensures + - returns a const reference to the i-th element of this circular buffer + !*/ + + void resize( + unsigned long new_size + ); + /*! + ensures + - #size() == new_size + !*/ + + void assign( + unsigned long new_size, + const T& value + ); + /*! + ensures + - #size() == new_size + - for all valid i: + - (*this)[i] == value + !*/ + + unsigned long size( + ) const; + /*! + ensures + - returns the number of elements in this circular buffer + !*/ + + T& front( + ); + /*! + requires + - size() > 0 + ensures + - returns a reference to (*this)[0] + !*/ + + const T& front( + ) const; + /*! + requires + - size() > 0 + ensures + - returns a const reference to (*this)[0] + !*/ + + T& back( + ); + /*! + requires + - size() > 0 + ensures + - returns a reference to (*this)[size()-1] + !*/ + + const T& back( + ) const; + /*! + requires + - size() > 0 + ensures + - returns a const reference to (*this)[size()-1] + !*/ + + void push_front( + const T& value + ); + /*! + ensures + - #size() == size() + (i.e. the size of this object does not change) + - if (size() != 0) then + - #front() == value + - all items are shifted over such that, + - #(*this)[1] == (*this)[0] + - #(*this)[2] == (*this)[1] + - #(*this)[3] == (*this)[2] + - etc. + - back() is shifted out of the circular buffer + - else + - This function has no effect on this object + !*/ + + void push_back( + const T& value + ); + /*! + ensures + - #size() == size() + (i.e. the size of this object does not change) + - if (size() != 0) then + - #back() == value + - all items are shifted over such that, + - front() is shifted out of the circular buffer + - #(*this)[0] == (*this)[1] + - #(*this)[1] == (*this)[2] + - #(*this)[2] == (*this)[3] + - etc. + - else + - This function has no effect on this object + !*/ + + void swap ( + circular_buffer& item + ); + /*! + ensures + - swaps *this with item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void swap ( + circular_buffer& a, + circular_buffer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T + > + void serialize ( + const circular_buffer& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename T + > + void deserialize ( + circular_buffer& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp mat ( + const circular_buffer& m + ); + /*! + ensures + - returns a matrix R such that: + - is_col_vector(R) == true + - R.size() == m.size() + - for all valid r: + R(r) == m[r] + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CIRCULAR_BuFFER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_1.h b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_1.h new file mode 100644 index 000000000..d3e6cc4b4 --- /dev/null +++ b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_1.h @@ -0,0 +1,227 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SLIDING_BUFFER_KERNEl_1_ +#define DLIB_SLIDING_BUFFER_KERNEl_1_ + +#include "sliding_buffer_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename T + > + class sliding_buffer_kernel_1 : public enumerable + { + /*! + INITIAL VALUE + - buffer_size == 0 + - buffer == 0 + - buffer_start == 0 + - current == 0 + - at_start_ == true + + CONVENTION + - buffer_size == size() + + - element() == (*this)[current] + - current_element_valid() == (current < buffer_size) && at_start_ == false + - at_start() == at_start_ + + - if (buffer_size != 0) then + - buffer[(buffer_start+i)&(mask)] == operator[](i) + - mask == buffer_size-1 + - else + - buffer == 0 + - buffer_size == 0 + !*/ + + public: + + typedef T type; + + sliding_buffer_kernel_1 ( + ) : + buffer_start(0), + buffer_size(0), + buffer(0), + current(0), + at_start_(true) + {} + + virtual ~sliding_buffer_kernel_1 ( + ) { if (buffer) delete [] buffer; } + + void clear( + ) + { + buffer_size = 0; + if (buffer) delete [] buffer; + buffer = 0; + at_start_ = true; + current = 0; + } + + void set_size ( + unsigned long exp_size + ) + { + at_start_ = true; + if (buffer) delete [] buffer; + buffer_size = 1; + while (exp_size != 0) + { + --exp_size; + buffer_size <<= 1; + } + mask = buffer_size-1; + try { buffer = new T[buffer_size]; } + catch (...) { buffer = 0; buffer_size = 0; throw; } + } + + size_t size ( + ) const { return buffer_size; } + + void rotate_left ( + unsigned long amount + ) { buffer_start = ((buffer_start-amount)&mask); at_start_ = true; } + + void rotate_right ( + unsigned long amount + ) { buffer_start = ((buffer_start+amount)&mask); at_start_ = true;} + + const T& operator[] ( + unsigned long index + ) const { return buffer[(buffer_start+index)&mask]; } + + T& operator[] ( + unsigned long index + ) { return buffer[(buffer_start+index)&mask]; } + + unsigned long get_element_id( + unsigned long index + ) const { return ((buffer_start+index)&mask); } + + unsigned long get_element_index ( + unsigned long element_id + ) const { return ((element_id-buffer_start)&mask);} + + void swap ( + sliding_buffer_kernel_1& item + ) + { + exchange(buffer_start,item.buffer_start); + exchange(buffer_size,item.buffer_size); + exchange(buffer,item.buffer); + exchange(mask,item.mask); + exchange(current,item.current); + exchange(at_start_,item.at_start_); + } + + + bool at_start ( + ) const { return at_start_; } + + void reset ( + ) const { at_start_ = true; } + + bool current_element_valid ( + ) const { return (current < buffer_size) && (at_start_ == false); } + + const T& element ( + ) const { return (*this)[current]; } + + T& element ( + ) { return (*this)[current]; } + + bool move_next ( + ) const + { + if (at_start_ == false) + { + if (current+1 < buffer_size) + { + ++current; + return true; + } + else + { + current = buffer_size; + return false; + } + } + else + { + at_start_ = false; + current = 0; + return (buffer_size != 0); + } + } + + + private: + + // data members + unsigned long buffer_start; + unsigned long buffer_size; + T* buffer; + unsigned long mask; + + + mutable unsigned long current; + mutable bool at_start_; + + // restricted functions + sliding_buffer_kernel_1(sliding_buffer_kernel_1&); // copy constructor + sliding_buffer_kernel_1& operator=(sliding_buffer_kernel_1&); // assignment operator + + }; + + template < + typename T + > + inline void swap ( + sliding_buffer_kernel_1& a, + sliding_buffer_kernel_1& b + ) { a.swap(b); } + + template < + typename T + > + void deserialize ( + sliding_buffer_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + if (size > 0) + { + int count = 0; + while (size != 1) + { + size /= 2; + ++count; + } + item.set_size(count); + + for (unsigned long i = 0; i < item.size(); ++i) + deserialize(item[i],in); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type sliding_buffer_kernel_1"); + } + } +} + +#endif // DLIB_SLIDING_BUFFER_KERNEl_1_ + diff --git a/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_abstract.h b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_abstract.h new file mode 100644 index 000000000..687a9e878 --- /dev/null +++ b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_abstract.h @@ -0,0 +1,205 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SLIDING_BUFFER_KERNEl_ABSTRACT_ +#ifdef DLIB_SLIDING_BUFFER_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename T + > + class sliding_buffer : public enumerable + { + /*! + REQUIREMENTS ON T + T must have a default constructor + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements of the sliding_buffer in the + order (*this)[0], (*this)[1], (*this)[2], ... + + WHAT THIS OBJECT REPRESENTS + This object represents an array of T objects. The main + feature of this object is its ability to rotate its contents + left or right. An example will make it clear. + + suppose we have the following buffer (assuming T is a char): + "some data!" <-- the data in the buffer + 9876543210 <-- the index numbers associated with each character + + applying rotate_left(2) to this buffer would give us + "me data!so" + 9876543210 + + if instead of calling rotate_left we call rotate_right(3) instead we would have + "ta!some da" + 9876543210 + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + !*/ + + public: + + typedef T type; + + sliding_buffer ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor. + !*/ + + virtual ~sliding_buffer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor. + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void set_size ( + unsigned long exp_size + ); + /*! + requires + - 0 < exp_size < 32 + ensures + - #size() == 2^exp_size + - the value of all elements in the buffer are undefined + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor. + if this exception is thrown then #size() == 0 + !*/ + + void rotate_left ( + unsigned long amount + ); + /*! + ensures + - for all i where 0 <= i < size(): + (#*this)[i] == (*this)[(i-amount)&(size()-1)] + i.e. rotates the contents of *this left by amount spaces + - #at_start() == true + !*/ + + void rotate_right ( + unsigned long amount + ); + /*! + ensures + - for all i where 0 <= i < size(): + (#*this)[i] == (*this)[(i+amount)&(size()-1)] + i.e. rotates the contents of *this right by amount spaces + - #at_start() == true + !*/ + + unsigned long get_element_id ( + unsigned long index + ) const; + /*! + requires + - index < size() + ensures + - returns an element id number that uniquely references the element at + the given index. (you can use this id to locate the new position of + an element after the buffer has been rotated) + - returned value is < size() + !*/ + + unsigned long get_element_index ( + unsigned long element_id + ) const; + /*! + require + - element_id < size() + ensures + - returns the index of the element with the given element_id. + ( (*this)[get_element_index(element_id)] will always refer to the same element + no matter where it has been rotated to) + - returned value is < size() + !*/ + + const T& operator[] ( + unsigned long index + ) const; + /*! + requires + - index < size() + ensures + - returns a const reference to the element in *this at position index + !*/ + + T& operator[] ( + unsigned long index + ); + /*! + requires + - index < size() + ensures + - returns a reference to the element in *this at position index + !*/ + + void swap ( + sliding_buffer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + sliding_buffer(sliding_buffer&); // copy constructor + sliding_buffer& operator=(sliding_buffer&); // assignment operator + + }; + + template < + typename T + > + void swap ( + sliding_buffer& a, + sliding_buffer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T + > + void deserialize ( + sliding_buffer& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_SLIDING_BUFFER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_c.h b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_c.h new file mode 100644 index 000000000..a7330e4b5 --- /dev/null +++ b/ml/dlib/dlib/sliding_buffer/sliding_buffer_kernel_c.h @@ -0,0 +1,222 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SLIDING_BUFFER_KERNEl_C_ +#define DLIB_SLIDING_BUFFER_KERNEl_C_ + +#include "sliding_buffer_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename sb_base + > + class sliding_buffer_kernel_c : public sb_base + { + typedef typename sb_base::type T; + + public: + void set_size ( + unsigned long exp_size + ); + + const T& operator[] ( + unsigned long index + ) const; + + T& operator[] ( + unsigned long index + ); + + unsigned long get_element_id ( + unsigned long index + ) const; + + unsigned long get_element_index ( + unsigned long element_id + ) const; + + const T& element ( + ) const; + + T& element ( + ); + + + }; + + template < + typename sb_base + > + inline void swap ( + sliding_buffer_kernel_c& a, + sliding_buffer_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename sb_base + > + void sliding_buffer_kernel_c:: + set_size ( + unsigned long exp_size + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( 0 < exp_size && exp_size < 32, + "\tvoid sliding_buffer::set_size(unsigned long)" + << "\n\texp_size must be some number between 1 and 31" + << "\n\tthis: " << this + << "\n\texp_size: " << exp_size + ); + + // call the real function + sb_base::set_size(exp_size); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sb_base + > + unsigned long sliding_buffer_kernel_c:: + get_element_id ( + unsigned long index + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( index < this->size(), + "\tunsigned long sliding_buffer::get_element_id(unsigned long) const" + << "\n\tindex must be in the range 0 to size()-1" + << "\n\tthis: " << this + << "\n\tsize(): " << this->size() + << "\n\tindex: " << index + ); + + // call the real function + return sb_base::get_element_id(index); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sb_base + > + unsigned long sliding_buffer_kernel_c:: + get_element_index ( + unsigned long element_id + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( element_id < this->size(), + "\tunsigned long sliding_buffer::get_element_index(unsigned long) const" + << "\n\tid must be in the range 0 to size()-1" + << "\n\tthis: " << this + << "\n\tsize(): " << this->size() + << "\n\tid: " << element_id + ); + + // call the real function + return sb_base::get_element_index(element_id); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sb_base + > + const typename sb_base::type& sliding_buffer_kernel_c:: + operator[] ( + unsigned long index + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( index < this->size(), + "\tconst T& sliding_buffer::operator[](unsigned long) const" + << "\n\tindex must be in the range 0 to size()-1" + << "\n\tthis: " << this + << "\n\tsize(): " << this->size() + << "\n\tindex: " << index + ); + + // call the real function + return sb_base::operator[](index); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sb_base + > + typename sb_base::type& sliding_buffer_kernel_c:: + operator[] ( + unsigned long index + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( index < this->size(), + "\tT& sliding_buffer::operator[](unsigned long)" + << "\n\tindex must be in the range 0 to size()-1" + << "\n\tthis: " << this + << "\n\tsize(): " << this->size() + << "\n\tindex: " << index + ); + + // call the real function + return sb_base::operator[](index); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sb_base + > + const typename sb_base::type& sliding_buffer_kernel_c:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst T& sliding_buffer::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return sb_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sb_base + > + typename sb_base::type& sliding_buffer_kernel_c:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tT& sliding_buffer::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return sb_base::element(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SLIDING_BUFFER_KERNEl_C_ + diff --git a/ml/dlib/dlib/smart_pointers.h b/ml/dlib/dlib/smart_pointers.h new file mode 100644 index 000000000..905e88b15 --- /dev/null +++ b/ml/dlib/dlib/smart_pointers.h @@ -0,0 +1,22 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SMART_POINTERs_H_ +#define DLIB_SMART_POINTERs_H_ + +// This is legacy smart pointer code that will likely stop working under default compiler +// flags when C++17 becomes the default standard in compilers. Please consider migrating +// your code to new smart pointers from C++ standard library. +#if (defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))) || \ + (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4))) +#pragma GCC warning "smart_pointers.h is included. This code will fail to compile under C++17" +#endif + +#include + +#include "smart_pointers/shared_ptr.h" +#include "smart_pointers/weak_ptr.h" +#include "smart_pointers/scoped_ptr.h" + +#endif // DLIB_SMART_POINTERs_H_ + + diff --git a/ml/dlib/dlib/smart_pointers/scoped_ptr.h b/ml/dlib/dlib/smart_pointers/scoped_ptr.h new file mode 100644 index 000000000..dd890f330 --- /dev/null +++ b/ml/dlib/dlib/smart_pointers/scoped_ptr.h @@ -0,0 +1,16 @@ +#ifndef DLIB_SCOPED_PTr_H_ +#define DLIB_SCOPED_PTr_H_ + +#include + +namespace dlib { + // Template alias for compatibility with clients using old dlib::scoped_ptr + // Old scoped_ptr implementation is removed completely + // This alias may fail in some reference deduction cases + + template > + using scoped_ptr = std::unique_ptr; + +} + +#endif diff --git a/ml/dlib/dlib/smart_pointers/shared_ptr.h b/ml/dlib/dlib/smart_pointers/shared_ptr.h new file mode 100644 index 000000000..15f7a4919 --- /dev/null +++ b/ml/dlib/dlib/smart_pointers/shared_ptr.h @@ -0,0 +1,492 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SHARED_PTr_ +#define DLIB_SHARED_PTr_ + +#include +#include +#include +#include // for the exceptions +#include "../algs.h" +#include "shared_ptr_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bad_weak_ptr: public std::exception {}; + +// ---------------------------------------------------------------------------------------- + + template class weak_ptr; + +// ---------------------------------------------------------------------------------------- + + struct shared_ptr_deleter + { + virtual void del(const void* p) = 0; + virtual ~shared_ptr_deleter() {} + + virtual void* get_deleter_void(const std::type_info& t) const = 0; + /*! + ensures + - if (the deleter in this object has typeid() == t) then + - returns a pointer to the deleter + - else + - return 0 + !*/ + }; + + struct shared_ptr_node; + struct weak_ptr_node + { + weak_ptr_node ( + shared_ptr_node* sn + ) : + ref_count(1), + shared_node(sn) + { + DLIB_ASSERT(sn != 0,""); + } + + long ref_count; + shared_ptr_node* shared_node; + }; + + struct shared_ptr_node + { + shared_ptr_node( + ) : + ref_count(1), + del(0), + weak_node(0) + {} + + long ref_count; + shared_ptr_deleter* del; + weak_ptr_node* weak_node; + }; + + struct shared_ptr_static_cast {}; + struct shared_ptr_const_cast {}; + struct shared_ptr_dynamic_cast {}; + +// ---------------------------------------------------------------------------------------- + + template + class shared_ptr + { + /*! + CONVENTION + - get() == data + - unique() == (shared_node != 0) && (shared_node->ref_count == 1) + - if (shared_node != 0) then + - use_count() == shared_node->ref_count + - get() == a valid pointer + - if (we are supposed to use the deleter) then + - shared_node->del == the deleter to use + - else + - shared_node->del == 0 + - else + - use_count() == 0 + - get() == 0 + + + - if (there are any weak_ptrs that reference this->data) then + - shared_node->weak_node->ref_count == the number of referencing weak_ptrs + - else + - shared_node->weak_node == 0 + !*/ + + template + struct deleter_template : public shared_ptr_deleter + { + deleter_template(const D& d_) : d(d_) {} + void del(const void* p) { d((T*)p); } + D d; + + void* get_deleter_void(const std::type_info& t) const + { + if (typeid(D) == t) + return (void*)&d; + else + return 0; + } + }; + + struct default_deleter : public shared_ptr_deleter + { + void del(const void* p) { delete ((T*)p); } + + void* get_deleter_void(const std::type_info&) const + { + return 0; + } + }; + + public: + + typedef T element_type; + + shared_ptr( + ) : data(0), shared_node(0) {} + + template + explicit shared_ptr( + Y* p + ) : data(p) + { + DLIB_ASSERT(p != 0, + "\tshared_ptr::shared_ptr(p)" + << "\n\tp can't be null" + << "\n\tthis: " << this + ); + try + { + shared_node = new shared_ptr_node; + shared_node->del = new default_deleter; + } + catch (...) + { + delete p; + throw; + } + } + + template + shared_ptr( + Y* p, + const D& d + ) : + data(p) + { + DLIB_ASSERT(p != 0, + "\tshared_ptr::shared_ptr(p,d)" + << "\n\tp can't be null" + << "\n\tthis: " << this + ); + try + { + shared_node = 0; + shared_node = new shared_ptr_node; + shared_node->del = new deleter_template(d); + } + catch (...) + { + if (shared_node) delete shared_node; + d(p); + throw; + } + } + + ~shared_ptr() + { + if ( shared_node != 0) + { + if (shared_node->ref_count == 1) + { + // delete the data in the appropriate way + shared_node->del->del(data); + delete shared_node->del; + + // notify any weak_ptrs that the data has now expired + if (shared_node->weak_node) + shared_node->weak_node->shared_node = 0; + + // finally delete the shared_node + delete shared_node; + } + else + { + shared_node->ref_count -= 1; + } + } + } + + shared_ptr( + const shared_ptr& r + ) + { + data = r.data; + shared_node = r.shared_node; + if (shared_node) + shared_node->ref_count += 1; + } + + template + shared_ptr( + const shared_ptr& r, + const shared_ptr_static_cast& + ) + { + data = static_cast(r.data); + if (data != 0) + { + shared_node = r.shared_node; + shared_node->ref_count += 1; + } + else + { + shared_node = 0; + } + } + + template + shared_ptr( + const shared_ptr& r, + const shared_ptr_const_cast& + ) + { + data = const_cast(r.data); + if (data != 0) + { + shared_node = r.shared_node; + shared_node->ref_count += 1; + } + else + { + shared_node = 0; + } + } + + template + shared_ptr( + const shared_ptr& r, + const shared_ptr_dynamic_cast& + ) + { + data = dynamic_cast(r.data); + if (data != 0) + { + shared_node = r.shared_node; + shared_node->ref_count += 1; + } + else + { + shared_node = 0; + } + } + + template + shared_ptr( + const shared_ptr& r + ) + { + data = r.data; + shared_node = r.shared_node; + if (shared_node) + shared_node->ref_count += 1; + } + + + template + explicit shared_ptr( + const weak_ptr& r + ) + { + if (r.expired()) + throw bad_weak_ptr(); + + data = r.data; + shared_node = r.weak_node->shared_node; + shared_node->ref_count += 1; + } + + shared_ptr& operator= ( + const shared_ptr& r + ) + { + shared_ptr(r).swap(*this); + return *this; + } + + template + shared_ptr& operator= ( + const shared_ptr& r + ) + { + shared_ptr(r).swap(*this); + return *this; + } + + void reset() + { + shared_ptr().swap(*this); + } + + template + void reset(Y* p) + { + DLIB_ASSERT(p != 0, + "\tshared_ptr::reset(p)" + << "\n\tp can't be null" + << "\n\tthis: " << this + ); + + shared_ptr(p).swap(*this); + } + + template + void reset( + Y* p, + const D& d + ) + { + DLIB_ASSERT(p != 0, + "\tshared_ptr::reset(p,d)" + << "\n\tp can't be null" + << "\n\tthis: " << this + ); + + shared_ptr(p,d).swap(*this); + } + + T& operator*( + ) const + { + DLIB_ASSERT(get() != 0, + "\tshared_ptr::operator*()" + << "\n\tget() can't be null if you are going to dereference it" + << "\n\tthis: " << this + ); + + return *data; + } + + T* operator->( + ) const + { + DLIB_ASSERT(get() != 0, + "\tshared_ptr::operator->()" + << "\n\tget() can't be null" + << "\n\tthis: " << this + ); + + return data; + } + + T* get() const { return data; } + + bool unique() const + { + return use_count() == 1; + } + + long use_count() const + { + if (shared_node != 0) + return shared_node->ref_count; + else + return 0; + } + + operator bool( + ) const { return get() != 0; } + + void swap(shared_ptr& b) + { + std::swap(data, b.data); + std::swap(shared_node, b.shared_node); + } + + template + D* _get_deleter( + ) const + { + if (shared_node && shared_node->del) + return static_cast(shared_node->del->get_deleter_void(typeid(D))); + else + return 0; + } + + template + bool _private_less ( + const shared_ptr& rhs + ) const + { + return shared_node < rhs.shared_node; + } + + private: + + template friend class shared_ptr; + template friend class weak_ptr; + + T* data; + shared_ptr_node* shared_node; + }; + +// ---------------------------------------------------------------------------------------- + + template + bool operator== ( + const shared_ptr& a, + const shared_ptr& b + ) { return a.get() == b.get(); } + + template + bool operator!= ( + const shared_ptr& a, + const shared_ptr& b + ) { return a.get() != b.get(); } + + template + bool operator< ( + const shared_ptr& a, + const shared_ptr& b + ) + { + return a._private_less(b); + } + + template + void swap( + shared_ptr& a, + shared_ptr& b + ) { a.swap(b); } + + template + shared_ptr static_pointer_cast( + const shared_ptr& r + ) + { + return shared_ptr(r, shared_ptr_static_cast()); + } + + template + shared_ptr const_pointer_cast( + shared_ptr const & r + ) + { + return shared_ptr(r, shared_ptr_const_cast()); + } + + template + shared_ptr dynamic_pointer_cast( + const shared_ptr& r + ) + { + return shared_ptr(r, shared_ptr_dynamic_cast()); + } + + template + std::basic_ostream & operator<< (std::basic_ostream & os, shared_ptr const & p) + { + os << p.get(); + return os; + } + + template + D* get_deleter(const shared_ptr& p) + { + return p.template _get_deleter(); + } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_SHARED_PTr_ + diff --git a/ml/dlib/dlib/smart_pointers/shared_ptr_abstract.h b/ml/dlib/dlib/smart_pointers/shared_ptr_abstract.h new file mode 100644 index 000000000..9fc12c8e6 --- /dev/null +++ b/ml/dlib/dlib/smart_pointers/shared_ptr_abstract.h @@ -0,0 +1,374 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SHARED_PTr_ABSTRACT_ +#ifdef DLIB_SHARED_PTr_ABSTRACT_ + +#include "weak_ptr_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bad_weak_ptr: public std::exception {} + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class shared_ptr + { + /*! + INITIAL VALUE + defined by constructors + + WHAT THIS OBJECT REPRESENTS + This object represents a reference counted smart pointer. Each shared_ptr + contains a pointer to some object and when the last shared_ptr that points + to the object is destructed or reset() then the object is guaranteed to be + deleted. + + This is an implementation of the std::tr1::shared_ptr template from the + document ISO/IEC PDTR 19768, Proposed Draft Technical Report on C++ + Library Extensions. The only deviation from that document is that this + shared_ptr is declared inside the dlib namespace rather than std::tr1. + + THREAD SAFETY + This object is not thread safe. Especially so since it is + reference counted. So you should take care to not have two shared_ptr + objects in different threads that point to the same object. + + If you want a thread safe version of this object you should use the + dlib::shared_ptr_thread_safe object instead. + !*/ + + public: + + typedef T element_type; + + shared_ptr( + ); + /*! + ensures + - #get() == 0 + - #use_count() == 0 + !*/ + + template + explicit shared_ptr( + Y* p + ); + /*! + requires + - p is convertible to a T* type pointer + - p can be deleted by calling "delete p;" and doing so will not throw exceptions + - p != 0 + ensures + - #get() == p + - #use_count() == 1 + - #*this object owns the pointer p + throws + - std::bad_alloc + if this exception is thrown then "delete p;" is called + !*/ + + template + shared_ptr( + Y* p, + const D& d + ); + /*! + requires + - p is convertible to a T* type pointer + - D is copy constructable (and the copy constructor of D doesn't throw) + - p can be deleted by calling "d(p);" and doing so will not throw exceptions + - p != 0 + ensures + - #get() == p + - #use_count() == 1 + - #*this object owns the pointer p + throws + - std::bad_alloc + if this exception is thrown then "d(p);" is called + !*/ + + shared_ptr( + const shared_ptr& r + ); + /*! + ensures + - #get() == #r.get() + - #use_count() == #r.use_count() + - If r is empty, constructs an empty shared_ptr object; otherwise, constructs + a shared_ptr object that shares ownership with r. + !*/ + + template + shared_ptr( + const shared_ptr& r + ); + /*! + requires + - Y* is convertible to T* + ensures + - #get() == #r.get() + - #use_count() == #r.use_count() + - If r is empty, constructs an empty shared_ptr object; otherwise, constructs + a shared_ptr object that shares ownership with r. + !*/ + + template + explicit shared_ptr( + const weak_ptr& r + ); + /*! + requires + - Y* is convertible to T* + ensures + - #get() == #r.get() + - #use_count() == #r.use_count() + - If r is empty, constructs an empty shared_ptr object; otherwise, constructs + a shared_ptr object that shares ownership with r. + throws + - bad_weak_ptr + this exception is thrown if r.expired() == true + !*/ + + ~shared_ptr( + ); + /*! + ensures + - if (use_count() > 1) + - this object destroys itself but otherwise has no effect (i.e. + the pointer get() is still valid and shared between the remaining + shared_ptr objects) + - else if (use_count() == 1) + - deletes the pointer get() by calling delete (or using the deleter passed + to the constructor if one was passed in) + - else + - in this case get() == 0 so there is nothing to do so nothing occurs + !*/ + + shared_ptr& operator= ( + const shared_ptr& r + ); + /*! + ensures + - equivalent to shared_ptr(r).swap(*this). + - returns #*this + !*/ + + template + shared_ptr& operator= ( + const shared_ptr& r + ); + /*! + requires + - Y* is convertible to T* + ensures + - equivalent to shared_ptr(r).swap(*this). + - returns #*this + !*/ + + void reset( + ); + /*! + ensures + - equivalent to shared_ptr().swap(*this) + !*/ + + template + void reset( + Y* p + ); + /*! + requires + - p is convertible to a T* type pointer + - p can be deleted by calling "delete p;" and doing so will not throw exceptions + - p != 0 + ensures + - equivalent to shared_ptr(p).swap(*this) + !*/ + + template + void reset( + Y* p, + const D& d + ); + /*! + requires + - p is convertible to a T* type pointer + - D is copy constructable (and the copy constructor of D doesn't throw) + - p can be deleted by calling "d(p);" and doing so will not throw exceptions + - p != 0 + ensures + - equivalent to shared_ptr(p,d).swap(*this) + !*/ + + T* get( + ) const; + /*! + ensures + - returns the stored pointer + !*/ + + T& operator*( + ) const; + /*! + requires + - get() != 0 + ensures + - returns a reference to *get() + !*/ + + T* operator->( + ) const; + /*! + requires + - get() != 0 + ensures + - returns get() + !*/ + + bool unique( + ) const; + /*! + ensures + - returns (use_count() == 1) + !*/ + + long use_count( + ) const; + /*! + ensures + - The number of shared_ptr objects, *this included, that share ownership with *this, or 0 when *this + is empty. + !*/ + + operator bool( + ) const; + /*! + ensures + - returns (get() != 0) + !*/ + + void swap( + shared_ptr& b + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + bool operator== ( + const shared_ptr& a, + const shared_ptr& b + ); + /*! + ensures + - returns a.get() == b.get() + !*/ + + template + bool operator!= ( + const shared_ptr& a, + const shared_ptr& b + ) { return a.get() != b.get(); } + /*! + ensures + - returns a.get() != b.get() + !*/ + + template + bool operator< ( + const shared_ptr& a, + const shared_ptr& b + ); + /*! + ensures + - Defines an operator< on shared_ptr types appropriate for use in the associative + containers. + !*/ + + template + void swap( + shared_ptr& a, + shared_ptr& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template + shared_ptr static_pointer_cast( + const shared_ptr& r + ); + /*! + - if (r.get() == 0) then + - returns shared_ptr() + - else + - returns a shared_ptr object that stores static_cast(r.get()) and shares + ownership with r. + !*/ + + template + shared_ptr const_pointer_cast( + const shared_ptr& r + ); + /*! + - if (r.get() == 0) then + - returns shared_ptr() + - else + - returns a shared_ptr object that stores const_cast(r.get()) and shares + ownership with r. + !*/ + + template + shared_ptr dynamic_pointer_cast( + const shared_ptr& r + ); + /*! + ensures + - if (dynamic_cast(r.get()) returns a nonzero value) then + - returns a shared_ptr object that stores a copy of + dynamic_cast(r.get()) and shares ownership with r + - else + - returns an empty shared_ptr object. + !*/ + + template + std::basic_ostream & operator<< ( + std::basic_ostream & os, + const shared_ptr& p + ); + /*! + ensures + - performs os << p.get() + - returns os + !*/ + + template + D* get_deleter( + const shared_ptr& p + ); + /*! + ensures + - if (*this owns a deleter d of type cv-unqualified D) then + - returns &d + - else + - returns 0 + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SHARED_PTr_ABSTRACT_ + diff --git a/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe.h b/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe.h new file mode 100644 index 000000000..31bda5651 --- /dev/null +++ b/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe.h @@ -0,0 +1,462 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SHARED_THREAD_SAFE_PTr_ +#define DLIB_SHARED_THREAD_SAFE_PTr_ + +#include +#include +#include +#include // for the exceptions +#include "../algs.h" +#include "shared_ptr_thread_safe_abstract.h" +#include "../threads/threads_kernel.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct shared_ptr_thread_safe_deleter + { + virtual void del(const void* p) = 0; + virtual ~shared_ptr_thread_safe_deleter() {} + + virtual void* get_deleter_void(const std::type_info& t) const = 0; + /*! + ensures + - if (the deleter in this object has typeid() == t) then + - returns a pointer to the deleter + - else + - return 0 + !*/ + }; + + struct shared_ptr_thread_safe_node + { + shared_ptr_thread_safe_node( + ) : + ref_count(1), + del(0) + {} + + dlib::mutex m; + long ref_count; + shared_ptr_thread_safe_deleter* del; + }; + + struct shared_ptr_ts_static_cast {}; + struct shared_ptr_ts_const_cast {}; + struct shared_ptr_ts_dynamic_cast {}; + +// ---------------------------------------------------------------------------------------- + + template + class shared_ptr_thread_safe + { + /*! + CONVENTION + - get() == data + - unique() == (shared_node != 0) && (shared_node->ref_count == 1) + - if (shared_node != 0) then + - use_count() == shared_node->ref_count + - get() == a valid pointer + - if (we are supposed to use the deleter) then + - shared_node->del == the deleter to use + - else + - shared_node->del == 0 + - else + - use_count() == 0 + - get() == 0 + + !*/ + + template + struct deleter_template : public shared_ptr_thread_safe_deleter + { + deleter_template(const D& d_) : d(d_) {} + void del(const void* p) { d((T*)p); } + D d; + + void* get_deleter_void(const std::type_info& t) const + { + if (typeid(D) == t) + return (void*)&d; + else + return 0; + } + }; + + public: + + typedef T element_type; + + shared_ptr_thread_safe( + ) : data(0), shared_node(0) {} + + template + explicit shared_ptr_thread_safe( + Y* p + ) : data(p) + { + DLIB_ASSERT(p != 0, + "\tshared_ptr::shared_ptr_thread_safe(p)" + << "\n\tp can't be null" + << "\n\tthis: " << this + ); + try + { + shared_node = new shared_ptr_thread_safe_node; + } + catch (...) + { + delete p; + throw; + } + } + + template + shared_ptr_thread_safe( + Y* p, + const D& d + ) : + data(p) + { + DLIB_ASSERT(p != 0, + "\tshared_ptr::shared_ptr_thread_safe(p,d)" + << "\n\tp can't be null" + << "\n\tthis: " << this + ); + try + { + shared_node = 0; + shared_node = new shared_ptr_thread_safe_node; + shared_node->del = new deleter_template(d); + } + catch (...) + { + if (shared_node) delete shared_node; + d(p); + throw; + } + } + + ~shared_ptr_thread_safe() + { + if ( shared_node != 0) + { + shared_node->m.lock(); + if (shared_node->ref_count == 1) + { + // delete the data in the appropriate way + if (shared_node->del) + { + shared_node->del->del(data); + + shared_node->m.unlock(); + delete shared_node->del; + } + else + { + shared_node->m.unlock(); + delete data; + } + + // finally delete the shared_node + delete shared_node; + } + else + { + shared_node->ref_count -= 1; + shared_node->m.unlock(); + } + } + } + + shared_ptr_thread_safe( + const shared_ptr_thread_safe& r + ) + { + data = r.data; + shared_node = r.shared_node; + if (shared_node) + { + auto_mutex M(shared_node->m); + shared_node->ref_count += 1; + } + } + + template + shared_ptr_thread_safe( + const shared_ptr_thread_safe& r, + const shared_ptr_ts_static_cast& + ) + { + data = static_cast(r.data); + if (data != 0) + { + shared_node = r.shared_node; + auto_mutex M(shared_node->m); + shared_node->ref_count += 1; + } + else + { + shared_node = 0; + } + } + + template + shared_ptr_thread_safe( + const shared_ptr_thread_safe& r, + const shared_ptr_ts_const_cast& + ) + { + data = const_cast(r.data); + if (data != 0) + { + shared_node = r.shared_node; + auto_mutex M(shared_node->m); + shared_node->ref_count += 1; + } + else + { + shared_node = 0; + } + } + + template + shared_ptr_thread_safe( + const shared_ptr_thread_safe& r, + const shared_ptr_ts_dynamic_cast& + ) + { + data = dynamic_cast(r.data); + if (data != 0) + { + shared_node = r.shared_node; + auto_mutex M(shared_node->m); + shared_node->ref_count += 1; + } + else + { + shared_node = 0; + } + } + + template + shared_ptr_thread_safe( + const shared_ptr_thread_safe& r + ) + { + data = r.data; + shared_node = r.shared_node; + if (shared_node) + { + auto_mutex M(shared_node->m); + shared_node->ref_count += 1; + } + } + + shared_ptr_thread_safe& operator= ( + const shared_ptr_thread_safe& r + ) + { + shared_ptr_thread_safe(r).swap(*this); + return *this; + } + + template + shared_ptr_thread_safe& operator= ( + const shared_ptr_thread_safe& r + ) + { + shared_ptr_thread_safe(r).swap(*this); + return *this; + } + + void reset() + { + shared_ptr_thread_safe().swap(*this); + } + + template + void reset(Y* p) + { + DLIB_ASSERT(p != 0, + "\tshared_ptr::reset(p)" + << "\n\tp can't be null" + << "\n\tthis: " << this + ); + + shared_ptr_thread_safe(p).swap(*this); + } + + template + void reset( + Y* p, + const D& d + ) + { + DLIB_ASSERT(p != 0, + "\tshared_ptr::reset(p,d)" + << "\n\tp can't be null" + << "\n\tthis: " << this + ); + + shared_ptr_thread_safe(p,d).swap(*this); + } + + T& operator*( + ) const + { + DLIB_ASSERT(get() != 0, + "\tshared_ptr::operator*()" + << "\n\tget() can't be null if you are going to dereference it" + << "\n\tthis: " << this + ); + + return *data; + } + + T* operator->( + ) const + { + DLIB_ASSERT(get() != 0, + "\tshared_ptr::operator->()" + << "\n\tget() can't be null" + << "\n\tthis: " << this + ); + + return data; + } + + T* get() const { return data; } + + bool unique() const + { + return use_count() == 1; + } + + long use_count() const + { + if (shared_node != 0) + { + auto_mutex M(shared_node->m); + return shared_node->ref_count; + } + else + { + return 0; + } + } + + operator bool( + ) const { return get() != 0; } + + void swap(shared_ptr_thread_safe& b) + { + std::swap(data, b.data); + std::swap(shared_node, b.shared_node); + } + + template + D* _get_deleter( + ) const + { + if (shared_node) + { + auto_mutex M(shared_node->m); + if (shared_node->del) + return static_cast(shared_node->del->get_deleter_void(typeid(D))); + } + return 0; + } + + template + bool _private_less ( + const shared_ptr_thread_safe& rhs + ) const + { + return shared_node < rhs.shared_node; + } + + private: + + template friend class shared_ptr_thread_safe; + + T* data; + shared_ptr_thread_safe_node* shared_node; + }; + +// ---------------------------------------------------------------------------------------- + + template + bool operator== ( + const shared_ptr_thread_safe& a, + const shared_ptr_thread_safe& b + ) { return a.get() == b.get(); } + + template + bool operator!= ( + const shared_ptr_thread_safe& a, + const shared_ptr_thread_safe& b + ) { return a.get() != b.get(); } + + template + bool operator< ( + const shared_ptr_thread_safe& a, + const shared_ptr_thread_safe& b + ) + { + return a._private_less(b); + } + + template + void swap( + shared_ptr_thread_safe& a, + shared_ptr_thread_safe& b + ) { a.swap(b); } + + template + shared_ptr_thread_safe static_pointer_cast( + const shared_ptr_thread_safe& r + ) + { + return shared_ptr_thread_safe(r, shared_ptr_ts_static_cast()); + } + + template + shared_ptr_thread_safe const_pointer_cast( + shared_ptr_thread_safe const & r + ) + { + return shared_ptr_thread_safe(r, shared_ptr_ts_const_cast()); + } + + template + shared_ptr_thread_safe dynamic_pointer_cast( + const shared_ptr_thread_safe& r + ) + { + return shared_ptr_thread_safe(r, shared_ptr_ts_dynamic_cast()); + } + + template + std::basic_ostream & operator<< (std::basic_ostream & os, shared_ptr_thread_safe const & p) + { + os << p.get(); + return os; + } + + template + D* get_deleter(const shared_ptr_thread_safe& p) + { + return p.template _get_deleter(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SHARED_THREAD_SAFE_PTr_ + diff --git a/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe_abstract.h b/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe_abstract.h new file mode 100644 index 000000000..472a00464 --- /dev/null +++ b/ml/dlib/dlib/smart_pointers/shared_ptr_thread_safe_abstract.h @@ -0,0 +1,352 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SHARED_PTr_THREAD_SAFE_ABSTRACT_ +#ifdef DLIB_SHARED_PTr_THREAD_SAFE_ABSTRACT_ + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class shared_ptr_thread_safe + { + /*! + INITIAL VALUE + defined by constructors + + WHAT THIS OBJECT REPRESENTS + This object represents a reference counted smart pointer. Each shared_ptr_thread_safe + contains a pointer to some object and when the last shared_ptr_thread_safe that points + to the object is destructed or reset() then the object is guaranteed to be + deleted. + + This is an implementation of the std::tr1::shared_ptr template from the + document ISO/IEC PDTR 19768, Proposed Draft Technical Report on C++ + Library Extensions. The only deviation from that document is that this + shared_ptr_thread_safe is declared inside the dlib namespace rather than std::tr1, + this one is explicitly thread safe, and there isn't a corresponding weak_ptr. + + THREAD SAFETY + This is a version of the shared_ptr object that can be used to share pointers + across more than one thread. Note however, that individual instances of this object + must still have access to them serialized by a mutex lock if they are to be modified + by more than one thread. But if you have two different shared_ptr_thread_safe objects + that both point to the same thing from different threads then you are safe. + !*/ + + public: + + typedef T element_type; + + shared_ptr_thread_safe( + ); + /*! + ensures + - #get() == 0 + - #use_count() == 0 + !*/ + + template + explicit shared_ptr_thread_safe( + Y* p + ); + /*! + requires + - p is convertible to a T* type pointer + - p can be deleted by calling "delete p;" and doing so will not throw exceptions + - p != 0 + ensures + - #get() == p + - #use_count() == 1 + - #*this object owns the pointer p + throws + - std::bad_alloc + if this exception is thrown then "delete p;" is called + !*/ + + template + shared_ptr_thread_safe( + Y* p, + const D& d + ); + /*! + requires + - p is convertible to a T* type pointer + - D is copy constructable (and the copy constructor of D doesn't throw) + - p can be deleted by calling "d(p);" and doing so will not throw exceptions + - p != 0 + ensures + - #get() == p + - #use_count() == 1 + - #*this object owns the pointer p + throws + - std::bad_alloc + if this exception is thrown then "d(p);" is called + !*/ + + shared_ptr_thread_safe( + const shared_ptr_thread_safe& r + ); + /*! + ensures + - #get() == #r.get() + - #use_count() == #r.use_count() + - If r is empty, constructs an empty shared_ptr_thread_safe object; otherwise, constructs + a shared_ptr_thread_safe object that shares ownership with r. + !*/ + + template + shared_ptr_thread_safe( + const shared_ptr_thread_safe& r + ); + /*! + requires + - Y* is convertible to T* + ensures + - #get() == #r.get() + - #use_count() == #r.use_count() + - If r is empty, constructs an empty shared_ptr_thread_safe object; otherwise, constructs + a shared_ptr_thread_safe object that shares ownership with r. + !*/ + + ~shared_ptr_thread_safe( + ); + /*! + ensures + - if (use_count() > 1) + - this object destroys itself but otherwise has no effect (i.e. + the pointer get() is still valid and shared between the remaining + shared_ptr_thread_safe objects) + - else if (use_count() == 1) + - deletes the pointer get() by calling delete (or using the deleter passed + to the constructor if one was passed in) + - else + - in this case get() == 0 so there is nothing to do so nothing occurs + !*/ + + shared_ptr_thread_safe& operator= ( + const shared_ptr_thread_safe& r + ); + /*! + ensures + - equivalent to shared_ptr_thread_safe(r).swap(*this). + - returns #*this + !*/ + + template + shared_ptr_thread_safe& operator= ( + const shared_ptr_thread_safe& r + ); + /*! + requires + - Y* is convertible to T* + ensures + - equivalent to shared_ptr_thread_safe(r).swap(*this). + - returns #*this + !*/ + + void reset( + ); + /*! + ensures + - equivalent to shared_ptr_thread_safe().swap(*this) + !*/ + + template + void reset( + Y* p + ); + /*! + requires + - p is convertible to a T* type pointer + - p can be deleted by calling "delete p;" and doing so will not throw exceptions + - p != 0 + ensures + - equivalent to shared_ptr_thread_safe(p).swap(*this) + !*/ + + template + void reset( + Y* p, + const D& d + ); + /*! + requires + - p is convertible to a T* type pointer + - D is copy constructable (and the copy constructor of D doesn't throw) + - p can be deleted by calling "d(p);" and doing so will not throw exceptions + - p != 0 + ensures + - equivalent to shared_ptr_thread_safe(p,d).swap(*this) + !*/ + + T* get( + ) const; + /*! + ensures + - returns the stored pointer + !*/ + + T& operator*( + ) const; + /*! + requires + - get() != 0 + ensures + - returns a reference to *get() + !*/ + + T* operator->( + ) const; + /*! + requires + - get() != 0 + ensures + - returns get() + !*/ + + bool unique( + ) const; + /*! + ensures + - returns (use_count() == 1) + !*/ + + long use_count( + ) const; + /*! + ensures + - The number of shared_ptr_thread_safe objects, *this included, that share ownership with *this, or 0 when *this + is empty. + !*/ + + operator bool( + ) const; + /*! + ensures + - returns (get() != 0) + !*/ + + void swap( + shared_ptr_thread_safe& b + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + bool operator== ( + const shared_ptr_thread_safe& a, + const shared_ptr_thread_safe& b + ); + /*! + ensures + - returns a.get() == b.get() + !*/ + + template + bool operator!= ( + const shared_ptr_thread_safe& a, + const shared_ptr_thread_safe& b + ) { return a.get() != b.get(); } + /*! + ensures + - returns a.get() != b.get() + !*/ + + template + bool operator< ( + const shared_ptr_thread_safe& a, + const shared_ptr_thread_safe& b + ); + /*! + ensures + - Defines an operator< on shared_ptr_thread_safe types appropriate for use in the associative + containers. + !*/ + + template + void swap( + shared_ptr_thread_safe& a, + shared_ptr_thread_safe& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template + shared_ptr_thread_safe static_pointer_cast( + const shared_ptr_thread_safe& r + ); + /*! + - if (r.get() == 0) then + - returns shared_ptr_thread_safe() + - else + - returns a shared_ptr_thread_safe object that stores static_cast(r.get()) and shares + ownership with r. + !*/ + + template + shared_ptr_thread_safe const_pointer_cast( + const shared_ptr_thread_safe& r + ); + /*! + - if (r.get() == 0) then + - returns shared_ptr_thread_safe() + - else + - returns a shared_ptr_thread_safe object that stores const_cast(r.get()) and shares + ownership with r. + !*/ + + template + shared_ptr_thread_safe dynamic_pointer_cast( + const shared_ptr_thread_safe& r + ); + /*! + ensures + - if (dynamic_cast(r.get()) returns a nonzero value) then + - returns a shared_ptr_thread_safe object that stores a copy of + dynamic_cast(r.get()) and shares ownership with r + - else + - returns an empty shared_ptr_thread_safe object. + !*/ + + template + std::basic_ostream & operator<< ( + std::basic_ostream & os, + const shared_ptr_thread_safe& p + ); + /*! + ensures + - performs os << p.get() + - returns os + !*/ + + template + D* get_deleter( + const shared_ptr_thread_safe& p + ); + /*! + ensures + - if (*this owns a deleter d of type cv-unqualified D) then + - returns &d + - else + - returns 0 + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SHARED_PTr_THREAD_SAFE_ABSTRACT_ + diff --git a/ml/dlib/dlib/smart_pointers/weak_ptr.h b/ml/dlib/dlib/smart_pointers/weak_ptr.h new file mode 100644 index 000000000..7e3405678 --- /dev/null +++ b/ml/dlib/dlib/smart_pointers/weak_ptr.h @@ -0,0 +1,225 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_WEAK_PTr_ +#define DLIB_WEAK_PTr_ + +#include +#include +#include "shared_ptr.h" +#include "../algs.h" +#include "weak_ptr_abstract.h" + +namespace dlib { + + template < + typename T + > + class weak_ptr + { + + /*! + CONVENTION + - if (weak_node != 0) then + - data == valid pointer to shared data + - weak_node->ref_count == the number of weak_ptrs that reference this->data + - else + - data == 0 + + - expired() == ((weak_node == 0) || (weak_node->shared_node == 0)) + - if (expired() == false) then + - use_count() == weak_node->shared_node->ref_count + - else + - use_count() == 0 + !*/ + + public: + typedef T element_type; + + weak_ptr( + ) : data(0), weak_node(0) + { + } + + template + weak_ptr( + const shared_ptr& r + ) + { + data = r.data; + if (r.shared_node) + { + if (r.shared_node->weak_node) + { + weak_node = r.shared_node->weak_node; + weak_node->ref_count += 1; + } + else + { + weak_node = new weak_ptr_node(r.shared_node); + r.shared_node->weak_node = weak_node; + } + } + else + { + weak_node = 0; + } + } + + weak_ptr( + const weak_ptr& r + ) + { + data = r.data; + weak_node = r.weak_node; + if (weak_node) + weak_node->ref_count += 1; + } + + template + weak_ptr( + const weak_ptr& r + ) + { + data = r.data; + weak_node = r.weak_node; + if (weak_node) + weak_node->ref_count += 1; + } + + ~weak_ptr( + ) + { + if (weak_node) + { + // make note that this weak_ptr is being destroyed + weak_node->ref_count -= 1; + + // if this is the last weak_ptr then we should clean up our stuff + if (weak_node->ref_count == 0) + { + if (expired() == false) + weak_node->shared_node->weak_node = 0; + delete weak_node; + } + } + } + + weak_ptr& operator= ( + const weak_ptr& r + ) + { + weak_ptr(r).swap(*this); + return *this; + } + + template + weak_ptr& operator= ( + const weak_ptr& r + ) + { + weak_ptr(r).swap(*this); + return *this; + } + + template + weak_ptr& operator=( + const shared_ptr& r + ) + { + weak_ptr(r).swap(*this); + return *this; + } + + long use_count( + ) const + { + if (expired()) + return 0; + else + return weak_node->shared_node->ref_count; + } + + bool expired() const { return weak_node == 0 || weak_node->shared_node == 0; } + + shared_ptr lock( + ) const + { + if (expired()) + return shared_ptr(); + else + return shared_ptr(*this); + } + + void reset( + ) + { + weak_ptr().swap(*this); + } + + void swap( + weak_ptr& b + ) + { + std::swap(data, b.data); + std::swap(weak_node, b.weak_node); + } + + template + bool _private_less ( + const weak_ptr& rhs + ) const + { + if (expired()) + { + if (rhs.expired()) + { + return false; + } + else + { + return true; + } + } + else + { + if (rhs.expired()) + { + return false; + } + else + { + // in this case they have both not expired so lets + // compare the shared_node pointers + return (weak_node->shared_node) < (rhs.weak_node->shared_node); + } + } + } + + private: + + template friend class shared_ptr; + template friend class weak_ptr; + + T* data; + weak_ptr_node* weak_node; + }; + + template + bool operator< ( + const weak_ptr& a, + const weak_ptr& b + ) + { + return a._private_less(b); + } + + template + void swap( + weak_ptr& a, + weak_ptr & b + ) { a.swap(b); } +} + +#endif // DLIB_WEAK_PTr_ + + diff --git a/ml/dlib/dlib/smart_pointers/weak_ptr_abstract.h b/ml/dlib/dlib/smart_pointers/weak_ptr_abstract.h new file mode 100644 index 000000000..549684e3a --- /dev/null +++ b/ml/dlib/dlib/smart_pointers/weak_ptr_abstract.h @@ -0,0 +1,193 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_WEAK_PTr_ABSTRACT_ +#ifdef DLIB_WEAK_PTr_ABSTRACT_ + +#include "shared_ptr_abstract.h" + +namespace dlib { + + template < + typename T + > + class weak_ptr + { + + /*! + INITIAL VALUE + defined by constructor + + WHAT THIS OBJECT REPRESENTS + The weak_ptr class template stores a weak reference to an object that is + already managed by a shared_ptr. To access the object, a weak_ptr can + be converted to a shared_ptr using the member function lock(). + + This is an implementation of the std::tr1::weak_ptr template from the + document ISO/IEC PDTR 19768, Proposed Draft Technical Report on C++ + Library Extensions. The only deviation from that document is that this + shared_ptr is declared inside the dlib namespace rather than std::tr1. + !*/ + + public: + typedef T element_type; + + weak_ptr( + ); + /*! + ensures + - #use_count() == 0 + - creates an empty weak_ptr + !*/ + + template + weak_ptr( + const shared_ptr& r + ); + /*! + requires + - Y* must be convertible to T* + ensures + - if (r is empty) then + - constructs an empty weak_ptr object + - else + - constructs a weak_ptr object that shares ownership with r and + stores a copy of the pointer stored in r. + - #use_count() == #r.use_count() + !*/ + + weak_ptr( + const weak_ptr& r + ); + /*! + ensures + - if (r is empty) then + - constructs an empty weak_ptr object + - else + - constructs a weak_ptr object that shares ownership with r and + stores a copy of the pointer stored in r. + - #use_count() == #r.use_count() + !*/ + + template + weak_ptr( + const weak_ptr& r + ); + /*! + requires + - Y* must be convertible to T* + ensures + - if (r is empty) then + - constructs an empty weak_ptr object + - else + - constructs a weak_ptr object that shares ownership with r and + stores a copy of the pointer stored in r. + - #use_count() == #r.use_count() + !*/ + + ~weak_ptr( + ); + /*! + ensures + - destroys this weak_ptr object but has no effect on the object its + stored pointer points to. + !*/ + + weak_ptr& operator= ( + const weak_ptr& r + ); + /*! + ensures + - equivalent to weak_ptr(r).swap(*this) + !*/ + + template + weak_ptr& operator= ( + const weak_ptr& r + ); + /*! + requires + - Y* must be convertible to T* + ensures + - equivalent to weak_ptr(r).swap(*this) + !*/ + + template + weak_ptr& operator=( + const shared_ptr& r + ); + /*! + requires + - Y* must be convertible to T* + ensures + - equivalent to weak_ptr(r).swap(*this) + !*/ + + long use_count( + ) const; + /*! + ensures + - if (*this is empty) then + - returns 0 + - else + - returns the number of shared_ptr instances that share ownership + with *this + !*/ + + bool expired( + ) const; + /*! + ensures + - returns (use_count() == 0) + !*/ + + shared_ptr lock( + ) const; + /*! + ensures + - if (expired()) then + - returns shared_ptr() + - else + - returns shared_ptr(*this) + !*/ + + void reset( + ); + /*! + ensures + - equivalent to weak_ptr().swap(*this) + !*/ + + void swap( + weak_ptr& b + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + template + bool operator< ( + const weak_ptr& a, + const weak_ptr& b + ); + /*! + ensures + - Defines an operator< on shared_ptr types appropriate for use in the associative + containers. + !*/ + + template + void swap( + weak_ptr& a, + weak_ptr & b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ +} + +#endif // DLIB_WEAK_PTr_ABSTRACT_ + + diff --git a/ml/dlib/dlib/smart_pointers_thread_safe.h b/ml/dlib/dlib/smart_pointers_thread_safe.h new file mode 100644 index 000000000..e00141f08 --- /dev/null +++ b/ml/dlib/dlib/smart_pointers_thread_safe.h @@ -0,0 +1,21 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SMART_POINTERs_THREAD_SAFE_H_ +#define DLIB_SMART_POINTERs_THREAD_SAFE_H_ + +// This is legacy smart pointer code that will likely to stop working under default +// compiler flags when C++17 becomes the default standard in the compilers. +// Please consider migrating your code to contemporary smart pointers from C++ +// standard library. The warning below will help to detect if the deprecated code +// was included from library's clients. +#if (defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 8) || (__GNUC__ > 4))) || \ + (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4))) +#pragma GCC warning "smart_pointers_thread_safe.h is included which will fail to compile under C++17" +#endif + +#include "smart_pointers/shared_ptr_thread_safe.h" + +#endif // DLIB_SMART_POINTERs_THREAD_SAFE_H_ + + + diff --git a/ml/dlib/dlib/sockets.h b/ml/dlib/dlib/sockets.h new file mode 100644 index 000000000..e253587ed --- /dev/null +++ b/ml/dlib/dlib/sockets.h @@ -0,0 +1,20 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETs_ +#define DLIB_SOCKETs_ + +#include "platform.h" + + +#ifdef WIN32 +#include "sockets/windows.h" +#endif + +#ifndef WIN32 +#include "sockets/posix.h" +#endif + +#include "sockets/sockets_extensions.h" + +#endif // DLIB_SOCKETs_ + diff --git a/ml/dlib/dlib/sockets/posix.h b/ml/dlib/dlib/sockets/posix.h new file mode 100644 index 000000000..d736a20d4 --- /dev/null +++ b/ml/dlib/dlib/sockets/posix.h @@ -0,0 +1,6 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETS_KERNEl_1_ +#include "sockets_kernel_2.h" +#endif + diff --git a/ml/dlib/dlib/sockets/sockets_extensions.cpp b/ml/dlib/dlib/sockets/sockets_extensions.cpp new file mode 100644 index 000000000..be08c1998 --- /dev/null +++ b/ml/dlib/dlib/sockets/sockets_extensions.cpp @@ -0,0 +1,341 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETS_EXTENSIONs_CPP +#define DLIB_SOCKETS_EXTENSIONs_CPP + +#include +#include +#include "../sockets.h" +#include "../error.h" +#include "sockets_extensions.h" +#include "../timer.h" +#include "../algs.h" +#include "../timeout.h" +#include "../misc_api.h" +#include "../serialize.h" +#include "../string.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + network_address:: + network_address( + const std::string& full_address + ) + { + std::istringstream sin(full_address); + sin >> *this; + if (!sin || sin.peek() != EOF) + throw invalid_network_address("invalid network address: " + full_address); + } + +// ---------------------------------------------------------------------------------------- + + void serialize( + const network_address& item, + std::ostream& out + ) + { + serialize(item.host_address, out); + serialize(item.port, out); + } + +// ---------------------------------------------------------------------------------------- + + void deserialize( + network_address& item, + std::istream& in + ) + { + deserialize(item.host_address, in); + deserialize(item.port, in); + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out, + const network_address& item + ) + { + out << item.host_address << ":" << item.port; + return out; + } + +// ---------------------------------------------------------------------------------------- + + std::istream& operator>> ( + std::istream& in, + network_address& item + ) + { + std::string temp; + in >> temp; + + std::string::size_type pos = temp.find_last_of(":"); + if (pos == std::string::npos) + { + in.setstate(std::ios::badbit); + return in; + } + + item.host_address = temp.substr(0, pos); + try + { + item.port = sa = temp.substr(pos+1); + } catch (std::exception& ) + { + in.setstate(std::ios::badbit); + return in; + } + + + return in; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const std::string& host_or_ip, + unsigned short port + ) + { + std::string ip; + connection* con; + if (is_ip_address(host_or_ip)) + { + ip = host_or_ip; + } + else + { + if( hostname_to_ip(host_or_ip,ip)) + throw socket_error(ERESOLVE,"unable to resolve '" + host_or_ip + "' in connect()"); + } + + if(create_connection(con,port,ip)) + { + std::ostringstream sout; + sout << "unable to connect to '" << host_or_ip << ":" << port << "'"; + throw socket_error(sout.str()); + } + + return con; + } + +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const network_address& addr + ) + { + return connect(addr.host_address, addr.port); + } + +// ---------------------------------------------------------------------------------------- + + namespace connect_timeout_helpers + { + mutex connect_mutex; + signaler connect_signaler(connect_mutex); + timestamper ts; + long outstanding_connects = 0; + + struct thread_data + { + std::string host_or_ip; + unsigned short port; + connection* con; + bool connect_ended; + bool error_occurred; + }; + + void thread(void* param) + { + thread_data p = *static_cast(param); + try + { + p.con = connect(p.host_or_ip, p.port); + } + catch (...) + { + p.error_occurred = true; + } + + auto_mutex M(connect_mutex); + // report the results back to the connect() call that spawned this + // thread. + static_cast(param)->con = p.con; + static_cast(param)->error_occurred = p.error_occurred; + connect_signaler.broadcast(); + + // wait for the call to connect() that spawned this thread to terminate + // before we delete the thread_data struct. + while (static_cast(param)->connect_ended == false) + connect_signaler.wait(); + + connect_signaler.broadcast(); + --outstanding_connects; + delete static_cast(param); + } + } + + connection* connect ( + const std::string& host_or_ip, + unsigned short port, + unsigned long timeout + ) + { + using namespace connect_timeout_helpers; + + auto_mutex M(connect_mutex); + + const uint64 end_time = ts.get_timestamp() + timeout*1000; + + + // wait until there are less than 100 outstanding connections + while (outstanding_connects > 100) + { + uint64 cur_time = ts.get_timestamp(); + if (end_time > cur_time) + { + timeout = static_cast((end_time - cur_time)/1000); + } + else + { + throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out"); + } + + connect_signaler.wait_or_timeout(timeout); + } + + + thread_data* data = new thread_data; + data->host_or_ip = host_or_ip.c_str(); + data->port = port; + data->con = 0; + data->connect_ended = false; + data->error_occurred = false; + + + if (create_new_thread(thread, data) == false) + { + delete data; + throw socket_error("unable to connect to '" + host_or_ip); + } + + ++outstanding_connects; + + // wait until we have a connection object + while (data->con == 0) + { + uint64 cur_time = ts.get_timestamp(); + if (end_time > cur_time && data->error_occurred == false) + { + timeout = static_cast((end_time - cur_time)/1000); + } + else + { + // let the thread know that it should terminate + data->connect_ended = true; + connect_signaler.broadcast(); + if (data->error_occurred) + throw socket_error("unable to connect to '" + host_or_ip); + else + throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out"); + } + + connect_signaler.wait_or_timeout(timeout); + } + + // let the thread know that it should terminate + data->connect_ended = true; + connect_signaler.broadcast(); + return data->con; + } + +// ---------------------------------------------------------------------------------------- + + bool is_ip_address ( + std::string ip + ) + { + for (std::string::size_type i = 0; i < ip.size(); ++i) + { + if (ip[i] == '.') + ip[i] = ' '; + } + std::istringstream sin(ip); + + bool bad = false; + int num; + for (int i = 0; i < 4; ++i) + { + sin >> num; + if (!sin || num < 0 || num > 255) + { + bad = true; + break; + } + } + + if (sin.get() != EOF) + bad = true; + + return !bad; + } + +// ---------------------------------------------------------------------------------------- + + void close_gracefully ( + connection* con, + unsigned long timeout + ) + { + std::unique_ptr ptr(con); + close_gracefully(ptr,timeout); + } + +// ---------------------------------------------------------------------------------------- + + void close_gracefully ( + std::unique_ptr& con, + unsigned long timeout + ) + { + if (!con) + return; + + if(con->shutdown_outgoing()) + { + // there was an error so just close it now and return + con.reset(); + return; + } + + try + { + dlib::timeout t(*con,&connection::shutdown,timeout); + + char junk[100]; + // wait for the other end to close their side + while (con->read(junk,sizeof(junk)) > 0) ; + } + catch (...) + { + con.reset(); + throw; + } + + con.reset(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SOCKETS_EXTENSIONs_CPP + + diff --git a/ml/dlib/dlib/sockets/sockets_extensions.h b/ml/dlib/dlib/sockets/sockets_extensions.h new file mode 100644 index 000000000..9faa34e01 --- /dev/null +++ b/ml/dlib/dlib/sockets/sockets_extensions.h @@ -0,0 +1,151 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETS_EXTENSIONs_ +#define DLIB_SOCKETS_EXTENSIONs_ + +#include +#include +#include + +#include "../sockets.h" +#include "../smart_pointers/scoped_ptr.h" +#include "sockets_extensions_abstract.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class invalid_network_address : public dlib::error + { + public: + invalid_network_address(const std::string& msg) : dlib::error(msg) {}; + }; + +// ---------------------------------------------------------------------------------------- + + struct network_address + { + network_address() : port(0){} + + network_address( + const std::string& full_address + ); + + network_address ( + const char* full_address + ) + { + *this = network_address(std::string(full_address)); + } + + network_address( + const std::string& host_address_, + const unsigned short port_ + ) : host_address(host_address_), port(port_) {} + + std::string host_address; + unsigned short port; + }; + +// ---------------------------------------------------------------------------------------- + + inline bool operator < ( + const network_address& a, + const network_address& b + ) + { + if (a.host_address < b.host_address) + return true; + else if (a.host_address > b.host_address) + return false; + else if (a.port < b.port) + return true; + else + return false; + } + + inline bool operator== ( + const network_address& a, + const network_address& b + ) { return a.host_address == b.host_address && a.port == b.port; } + + inline bool operator != ( + const network_address& a, + const network_address& b + ) { return !(a == b); } + +// ---------------------------------------------------------------------------------------- + + void serialize( + const network_address& item, + std::ostream& out + ); + + void deserialize( + network_address& item, + std::istream& in + ); + + std::ostream& operator<< ( + std::ostream& out, + const network_address& item + ); + + std::istream& operator>> ( + std::istream& in, + network_address& item + ); + +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const std::string& host_or_ip, + unsigned short port + ); + +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const network_address& addr + ); + +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const std::string& host_or_ip, + unsigned short port, + unsigned long timeout + ); + +// ---------------------------------------------------------------------------------------- + + bool is_ip_address ( + std::string ip + ); + +// ---------------------------------------------------------------------------------------- + + void close_gracefully ( + connection* con, + unsigned long timeout = 500 + ); + +// ---------------------------------------------------------------------------------------- + + void close_gracefully ( + std::unique_ptr& con, + unsigned long timeout = 500 + ); + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "sockets_extensions.cpp" +#endif + +#endif // DLIB_SOCKETS_EXTENSIONs_ + diff --git a/ml/dlib/dlib/sockets/sockets_extensions_abstract.h b/ml/dlib/dlib/sockets/sockets_extensions_abstract.h new file mode 100644 index 000000000..194c22ab2 --- /dev/null +++ b/ml/dlib/dlib/sockets/sockets_extensions_abstract.h @@ -0,0 +1,300 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SOCKETS_EXTENSIONs_ABSTRACT_ +#ifdef DLIB_SOCKETS_EXTENSIONs_ABSTRACT_ + +#include +#include + +#include "sockets_kernel_abstract.h" +#include "../error.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class invalid_network_address : public dlib::error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown by network_address's constructor if the + input is invalid. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + struct network_address + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is simply a container for two things: + - A host machine address which is either an IP address or DNS name + for a machine. + - A port number. + + Together, these things define a machine and port on that machine. + !*/ + + network_address( + ); + /*! + ensures + - host_address == "" + - #port == 0 + !*/ + + network_address( + const std::string& full_address + ); + /*! + ensures + - interprets full_address as a network address of the form: + host_address:port + and assigns each part into #host_address and #port. For example, + network_address("localhost:80") would result in a network_address + object where host_address was "localhost" and port was 80. + throws + - invalid_network_address + This exception is thrown if the full_address string can't be + interpreted as a valid network address. + !*/ + + network_address ( + const char* full_address + ); + /*! + requires + - full_address == a valid pointer to a null terminated string + ensures + - Invoking this constructor is equivalent to performing + network_address(std::string(full_address)) + !*/ + + network_address( + const std::string& host_address_, + const unsigned short port_ + ); + /*! + ensures + - #host_address == host_address_ + - #port == port_ + !*/ + + + std::string host_address; + unsigned short port; + }; + +// ---------------------------------------------------------------------------------------- + + inline bool operator < ( + const network_address& a, + const network_address& b + ); + /*! + ensures + - provides a total ordering over network_address objects so you can use them in + the standard associative containers. The ordering is defined such that if + you sorted network addresses they would sort first on the host_address string + and then, for network_address objects with equal host_address, they would + sort on the port number + !*/ + + inline bool operator== ( + const network_address& a, + const network_address& b + ); + /*! + ensures + - returns true if a and b contain exactly the same address and false otherwise. + That is, the following must be true for this function to return true: + - a.host_address == b.host_address + - a.port == b.port + Note that this means that two addresses which are logically equivalent but + written differently will not compare equal. For example, suppose example.com + has the IP address 10.1.1.1. Then network_address("10.1.1.1:80") and + network_address("example.com:80") really refer to the same network resource + but will nevertheless not compare equal since. + !*/ + + inline bool operator != ( + const network_address& a, + const network_address& b + ); + /*! + ensures + - returns !(a == b) + !*/ + +// ---------------------------------------------------------------------------------------- + + void serialize( + const network_address& item, + std::ostream& out + ); + /*! + ensures + - provides serialization support + !*/ + + void deserialize( + network_address& item, + std::istream& in + ); + /*! + ensures + - provides deserialization support + !*/ + + std::ostream& operator<< ( + std::ostream& out, + const network_address& item + ); + /*! + ensures + - writes the given network_address to the output stream. The format is the + host_address, then a colon, then the port number. So for example: + cout << network_address("localhost", 80); + would print: + localhost:80 + - returns #out + !*/ + + std::istream& operator>> ( + std::istream& in, + network_address& item + ); + /*! + ensures + - reads a network_address from the given input stream. The expected format is + the same as the one used to print them by the above operator<<() routine. + - returns #in + - if (there is an error reading the network_address) then + - #in.good() == false + !*/ + +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const std::string& host_or_ip, + unsigned short port + ); + /*! + ensures + - returns a connection object that is connected to the given host at the + given port + throws + - dlib::socket_error + This exception is thrown if there is some problem that prevents us from + creating the connection + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const network_address& addr + ); + /*! + ensures + - returns connect(addr.host_address, addr_port); + !*/ + +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const std::string& host_or_ip, + unsigned short port, + unsigned long timeout + ); + /*! + ensures + - returns a connection object that is connected to the given host at the + given port. + - blocks for at most timeout milliseconds + throws + - dlib::socket_error + This exception is thrown if there is some problem that prevents us from + creating the connection or if timeout milliseconds elapses before the + connect is successful. + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + + + bool is_ip_address ( + std::string ip + ); + /*! + ensures + - if (ip is a valid ip address) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + void close_gracefully ( + connection* con, + unsigned long timeout = 500 + ); + /*! + requires + - con == a valid pointer to a connection object or 0 + ensures + - This function does nothing if con == 0, otherwise it performs the following: + - performs a graceful close of the given connection and if it takes longer + than timeout milliseconds to complete then forces the connection closed. + - Specifically, a graceful close means that the outgoing part of con is + closed (a FIN is sent) and then we wait for the other end to to close + their end of the connection. This way any data still on its way to + the other end of the connection will be received properly. + - This function will block until the graceful close is completed or we + timeout. + - calls "delete con;". Thus con is no longer a valid pointer after this + function has finished. + throws + - std::bad_alloc or dlib::thread_error + If either of these exceptions are thrown con will still be closed via + "delete con;" + !*/ + +// ---------------------------------------------------------------------------------------- + + void close_gracefully ( + std::unique_ptr& con, + unsigned long timeout = 500 + ); + /*! + requires + - con == a valid pointer to a connection object or con.get() == 0 + ensures + - This function does nothing if con.get() == 0, otherwise it performs the + following: + - performs a graceful close of the given connection and if it takes longer + than timeout milliseconds to complete then forces the connection closed. + - Specifically, a graceful close means that the outgoing part of con is + closed (a FIN is sent) and then we wait for the other end to to close + their end of the connection. This way any data still on its way to + the other end of the connection will be received properly. + - This function will block until the graceful close is completed or we + timeout. + - #con.get() == 0. Thus con is no longer a valid pointer after this + function has finished. + throws + - std::bad_alloc or dlib::thread_error + If either of these exceptions are thrown con will still be closed and + deleted (i.e. #con.get() == 0). + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SOCKETS_EXTENSIONs_ABSTRACT_ + + diff --git a/ml/dlib/dlib/sockets/sockets_kernel_1.cpp b/ml/dlib/dlib/sockets/sockets_kernel_1.cpp new file mode 100644 index 000000000..55e39569f --- /dev/null +++ b/ml/dlib/dlib/sockets/sockets_kernel_1.cpp @@ -0,0 +1,979 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net), Miguel Grinberg +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETS_KERNEL_1_CPp_ +#define DLIB_SOCKETS_KERNEL_1_CPp_ +#include "../platform.h" + +#ifdef WIN32 + +#include + +#ifndef _WINSOCKAPI_ +#define _WINSOCKAPI_ /* Prevent inclusion of winsock.h in windows.h */ +#endif + +#include "../windows_magic.h" + +#include "sockets_kernel_1.h" + +#include + +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif + + +// tell visual studio to link to the libraries we need if we are +// in fact using visual studio +#ifdef _MSC_VER +#pragma comment (lib, "ws2_32.lib") +#endif + +#include "../assert.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class SOCKET_container + { + /*! + This object is just a wrapper around the SOCKET type. It exists + so that we can #include the windows.h and Winsock2.h header files + in this cpp file and not at all in the header file. + !*/ + public: + SOCKET_container ( + SOCKET s = INVALID_SOCKET + ) : val(s) {} + + SOCKET val; + operator SOCKET&() { return val; } + + SOCKET_container& operator= ( + const SOCKET& s + ) { val = s; return *this; } + + bool operator== ( + const SOCKET& s + ) const { return s == val; } + }; + +// ---------------------------------------------------------------------------------------- +// stuff to ensure that WSAStartup() is always called before any sockets stuff is needed + + namespace sockets_kernel_1_mutex + { + mutex startup_lock; + } + + class sockets_startupdown + { + public: + sockets_startupdown(); + ~sockets_startupdown() { WSACleanup( ); } + + }; + sockets_startupdown::sockets_startupdown ( + ) + { + WSADATA wsaData; + WSAStartup (MAKEWORD(2,0), &wsaData); + } + + void sockets_startup() + { + // mutex crap to make this function thread-safe + sockets_kernel_1_mutex::startup_lock.lock(); + static sockets_startupdown a; + sockets_kernel_1_mutex::startup_lock.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + // lookup functions + + int + get_local_hostname ( + std::string& hostname + ) + { + // ensure that WSAStartup has been called and WSACleanup will eventually + // be called when program ends + sockets_startup(); + + try + { + + char temp[NI_MAXHOST]; + if (gethostname(temp,NI_MAXHOST) == SOCKET_ERROR ) + { + return OTHER_ERROR; + } + + hostname = temp; + } + catch (...) + { + return OTHER_ERROR; + } + + return 0; + } + +// ----------------- + + int + hostname_to_ip ( + const std::string& hostname, + std::string& ip, + int n + ) + { + // ensure that WSAStartup has been called and WSACleanup will eventually + // be called when program ends + sockets_startup(); + + try + { + // lock this mutex since gethostbyname isn't really thread safe + auto_mutex M(sockets_kernel_1_mutex::startup_lock); + + // if no hostname was given then return error + if ( hostname.empty()) + return OTHER_ERROR; + + hostent* address; + address = gethostbyname(hostname.c_str()); + + if (address == 0) + { + return OTHER_ERROR; + } + + // find the nth address + in_addr* addr = reinterpret_cast(address->h_addr_list[0]); + for (int i = 1; i <= n; ++i) + { + addr = reinterpret_cast(address->h_addr_list[i]); + + // if there is no nth address then return error + if (addr == 0) + return OTHER_ERROR; + } + + char* resolved_ip = inet_ntoa(*addr); + + // check if inet_ntoa returned an error + if (resolved_ip == NULL) + { + return OTHER_ERROR; + } + + ip.assign(resolved_ip); + + } + catch(...) + { + return OTHER_ERROR; + } + + return 0; + } + +// ----------------- + + int + ip_to_hostname ( + const std::string& ip, + std::string& hostname + ) + { + // ensure that WSAStartup has been called and WSACleanup will eventually + // be called when program ends + sockets_startup(); + + try + { + // lock this mutex since gethostbyaddr isn't really thread safe + auto_mutex M(sockets_kernel_1_mutex::startup_lock); + + // if no ip was given then return error + if (ip.empty()) + return OTHER_ERROR; + + hostent* address; + unsigned long ipnum = inet_addr(ip.c_str()); + + // if inet_addr couldn't convert ip then return an error + if (ipnum == INADDR_NONE) + { + return OTHER_ERROR; + } + address = gethostbyaddr(reinterpret_cast(&ipnum),4,AF_INET); + + // check if gethostbyaddr returned an error + if (address == 0) + { + return OTHER_ERROR; + } + hostname.assign(address->h_name); + + } + catch (...) + { + return OTHER_ERROR; + } + return 0; + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // connection object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + connection:: + connection( + SOCKET_container sock, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port, + const std::string& local_ip + ) : + user_data(0), + connection_socket(*(new SOCKET_container())), + connection_foreign_port(foreign_port), + connection_foreign_ip(foreign_ip), + connection_local_port(local_port), + connection_local_ip(local_ip), + sd(false), + sdo(false), + sdr(0) + { + connection_socket = sock; + } + +// ---------------------------------------------------------------------------------------- + + connection:: + ~connection ( + ) + { + if (connection_socket != INVALID_SOCKET) + closesocket(connection_socket); + delete &connection_socket; + } + +// ---------------------------------------------------------------------------------------- + + int connection:: + disable_nagle() + { + int flag = 1; + int status = setsockopt( connection_socket, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(flag) ); + + if (status == SOCKET_ERROR) + return OTHER_ERROR; + else + return 0; + } + +// ---------------------------------------------------------------------------------------- + + long connection:: + write ( + const char* buf, + long num + ) + { + const long old_num = num; + long status; + const long max_send_length = 1024*1024*100; + while (num > 0) + { + // Make sure to cap the max value num can take on so that if it is + // really large (it might be big on 64bit platforms) so that the OS + // can't possibly get upset about it being large. + const long length = std::min(max_send_length, num); + if ( (status = send(connection_socket,buf,length,0)) == SOCKET_ERROR) + { + if (sdo_called()) + return SHUTDOWN; + else + return OTHER_ERROR; + } + num -= status; + buf += status; + } + return old_num; + } + +// ---------------------------------------------------------------------------------------- + + long connection:: + read ( + char* buf, + long num + ) + { + const long max_recv_length = 1024*1024*100; + // Make sure to cap the max value num can take on so that if it is + // really large (it might be big on 64bit platforms) so that the OS + // can't possibly get upset about it being large. + const long length = std::min(max_recv_length, num); + long status = recv(connection_socket,buf,length,0); + if (status == SOCKET_ERROR) + { + // if this error is the result of a shutdown call then return SHUTDOWN + if (sd_called()) + return SHUTDOWN; + else + return OTHER_ERROR; + } + else if (status == 0 && sd_called()) + { + return SHUTDOWN; + } + return status; + } + +// ---------------------------------------------------------------------------------------- + + long connection:: + read ( + char* buf, + long num, + unsigned long timeout + ) + { + if (readable(timeout) == false) + return TIMEOUT; + + const long max_recv_length = 1024*1024*100; + // Make sure to cap the max value num can take on so that if it is + // really large (it might be big on 64bit platforms) so that the OS + // can't possibly get upset about it being large. + const long length = std::min(max_recv_length, num); + long status = recv(connection_socket,buf,length,0); + if (status == SOCKET_ERROR) + { + // if this error is the result of a shutdown call then return SHUTDOWN + if (sd_called()) + return SHUTDOWN; + else + return OTHER_ERROR; + } + else if (status == 0 && sd_called()) + { + return SHUTDOWN; + } + return status; + } + +// ---------------------------------------------------------------------------------------- + + bool connection:: + readable ( + unsigned long timeout + ) const + { + fd_set read_set; + // initialize read_set + FD_ZERO(&read_set); + + // add the listening socket to read_set + FD_SET(connection_socket, &read_set); + + // setup a timeval structure + timeval time_to_wait; + time_to_wait.tv_sec = static_cast(timeout/1000); + time_to_wait.tv_usec = static_cast((timeout%1000)*1000); + + // wait on select + int status = select(0,&read_set,0,0,&time_to_wait); + + // if select timed out or there was an error + if (status <= 0) + return false; + + // data is ready to be read + return true; + } + +// ---------------------------------------------------------------------------------------- + + int connection:: + shutdown_outgoing ( + ) + { + sd_mutex.lock(); + if (sdo || sd) + { + sd_mutex.unlock(); + return sdr; + } + sdo = true; + sdr = ::shutdown(connection_socket,SD_SEND); + + // convert -1 error code into the OTHER_ERROR error code + if (sdr == -1) + sdr = OTHER_ERROR; + + int temp = sdr; + + sd_mutex.unlock(); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + int connection:: + shutdown ( + ) + { + sd_mutex.lock(); + if (sd) + { + sd_mutex.unlock(); + return sdr; + } + sd = true; + SOCKET stemp = connection_socket; + connection_socket = INVALID_SOCKET; + sdr = closesocket(stemp); + + // convert SOCKET_ERROR error code into the OTHER_ERROR error code + if (sdr == SOCKET_ERROR) + sdr = OTHER_ERROR; + + int temp = sdr; + + sd_mutex.unlock(); + return temp; + } + +// ---------------------------------------------------------------------------------------- + + connection::socket_descriptor_type connection:: + get_socket_descriptor ( + ) const + { + return connection_socket.val; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // listener object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + listener:: + listener( + SOCKET_container sock, + unsigned short port, + const std::string& ip + ) : + listening_socket(*(new SOCKET_container)), + listening_port(port), + listening_ip(ip), + inaddr_any(listening_ip.empty()) + { + listening_socket = sock; + } + +// ---------------------------------------------------------------------------------------- + + listener:: + ~listener ( + ) + { + closesocket(listening_socket); + delete &listening_socket; + } + +// ---------------------------------------------------------------------------------------- + + int listener:: + accept ( + std::unique_ptr& new_connection, + unsigned long timeout + ) + { + new_connection.reset(0); + connection* con; + int status = this->accept(con, timeout); + + if (status == 0) + new_connection.reset(con); + + return status; + } + +// ---------------------------------------------------------------------------------------- + + int listener:: + accept ( + connection*& new_connection, + unsigned long timeout + ) + { + SOCKET incoming; + sockaddr_in incomingAddr; + int length = sizeof(sockaddr_in); + + // implement timeout with select if timeout is > 0 + if (timeout > 0) + { + fd_set read_set; + // initialize read_set + FD_ZERO(&read_set); + + // add the listening socket to read_set + FD_SET(listening_socket, &read_set); + + // setup a timeval structure + timeval time_to_wait; + time_to_wait.tv_sec = static_cast(timeout/1000); + time_to_wait.tv_usec = static_cast((timeout%1000)*1000); + + + // wait on select + int status = select(0,&read_set,0,0,&time_to_wait); + + // if select timed out + if (status == 0) + return TIMEOUT; + + // if select returned an error + if (status == SOCKET_ERROR) + return OTHER_ERROR; + + } + + + // call accept to get a new connection + incoming=::accept(listening_socket,reinterpret_cast(&incomingAddr),&length); + + // if there was an error return OTHER_ERROR + if ( incoming == INVALID_SOCKET ) + return OTHER_ERROR; + + + // get the port of the foreign host into foreign_port + int foreign_port = ntohs(incomingAddr.sin_port); + + // get the IP of the foreign host into foreign_ip + std::string foreign_ip; + { + char* foreign_ip_temp = inet_ntoa(incomingAddr.sin_addr); + + // check if inet_ntoa() returned an error + if (foreign_ip_temp == NULL) + { + closesocket(incoming); + return OTHER_ERROR; + } + + foreign_ip.assign(foreign_ip_temp); + } + + + // get the local ip + std::string local_ip; + if (inaddr_any == true) + { + sockaddr_in local_info; + length = sizeof(sockaddr_in); + // get the local sockaddr_in structure associated with this new connection + if ( getsockname ( + incoming, + reinterpret_cast(&local_info), + &length + ) == SOCKET_ERROR + ) + { // an error occurred + closesocket(incoming); + return OTHER_ERROR; + } + char* temp = inet_ntoa(local_info.sin_addr); + + // check if inet_ntoa() returned an error + if (temp == NULL) + { + closesocket(incoming); + return OTHER_ERROR; + } + local_ip.assign(temp); + } + else + { + local_ip = listening_ip; + } + + + // set the SO_OOBINLINE option + int flag_value = 1; + if (setsockopt(incoming,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast(&flag_value),sizeof(int)) == SOCKET_ERROR ) + { + closesocket(incoming); + return OTHER_ERROR; + } + + + // make a new connection object for this new connection + try + { + new_connection = new connection ( + incoming, + foreign_port, + foreign_ip, + listening_port, + local_ip + ); + } + catch (...) { closesocket(incoming); return OTHER_ERROR; } + + return 0; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // socket creation functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + int create_listener ( + std::unique_ptr& new_listener, + unsigned short port, + const std::string& ip + ) + { + new_listener.reset(); + listener* temp; + int status = create_listener(temp,port,ip); + + if (status == 0) + new_listener.reset(temp); + + return status; + } + + int create_listener ( + listener*& new_listener, + unsigned short port, + const std::string& ip + ) + { + // ensure that WSAStartup has been called and WSACleanup will eventually + // be called when program ends + sockets_startup(); + + sockaddr_in sa; // local socket structure + ZeroMemory(&sa,sizeof(sockaddr_in)); // initialize sa + + SOCKET sock = socket (AF_INET, SOCK_STREAM, 0); // get a new socket + + // if socket() returned an error then return OTHER_ERROR + if (sock == INVALID_SOCKET ) + { + return OTHER_ERROR; + } + + // set the local socket structure + sa.sin_family = AF_INET; + sa.sin_port = htons(port); + if (ip.empty()) + { + // if the listener should listen on any IP + sa.sin_addr.S_un.S_addr = htons(INADDR_ANY); + } + else + { + // if there is a specific ip to listen on + sa.sin_addr.S_un.S_addr = inet_addr(ip.c_str()); + // if inet_addr couldn't convert the ip then return an error + if ( sa.sin_addr.S_un.S_addr == INADDR_NONE ) + { + closesocket(sock); + return OTHER_ERROR; + } + } + + // set the SO_REUSEADDR option + int flag_value = 1; + setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,reinterpret_cast(&flag_value),sizeof(int)); + + // bind the new socket to the requested port and ip + if (bind(sock,reinterpret_cast(&sa),sizeof(sockaddr_in))==SOCKET_ERROR) + { + const int err = WSAGetLastError(); + // if there was an error + closesocket(sock); + + // if the port is already bound then return PORTINUSE + if (err == WSAEADDRINUSE) + return PORTINUSE; + else + return OTHER_ERROR; + } + + + // tell the new socket to listen + if ( listen(sock,SOMAXCONN) == SOCKET_ERROR) + { + const int err = WSAGetLastError(); + // if there was an error return OTHER_ERROR + closesocket(sock); + + // if the port is already bound then return PORTINUSE + if (err == WSAEADDRINUSE) + return PORTINUSE; + else + return OTHER_ERROR; + } + + // determine the port used if necessary + if (port == 0) + { + sockaddr_in local_info; + int length = sizeof(sockaddr_in); + if ( getsockname ( + sock, + reinterpret_cast(&local_info), + &length + ) == SOCKET_ERROR + ) + { + closesocket(sock); + return OTHER_ERROR; + } + port = ntohs(local_info.sin_port); + } + + + // initialize a listener object on the heap with the new socket + try { new_listener = new listener(sock,port,ip); } + catch(...) { closesocket(sock); return OTHER_ERROR; } + + return 0; + } + +// ---------------------------------------------------------------------------------------- + + int create_connection ( + std::unique_ptr& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port, + const std::string& local_ip + ) + { + new_connection.reset(); + connection* temp; + int status = create_connection(temp,foreign_port, foreign_ip, local_port, local_ip); + + if (status == 0) + new_connection.reset(temp); + + return status; + } + + int create_connection ( + connection*& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port, + const std::string& local_ip + ) + { + // ensure that WSAStartup has been called and WSACleanup + // will eventually be called when program ends + sockets_startup(); + + + sockaddr_in local_sa; // local socket structure + sockaddr_in foreign_sa; // foreign socket structure + ZeroMemory(&local_sa,sizeof(sockaddr_in)); // initialize local_sa + ZeroMemory(&foreign_sa,sizeof(sockaddr_in)); // initialize foreign_sa + + int length; + + SOCKET sock = socket (AF_INET, SOCK_STREAM, 0); // get a new socket + + // if socket() returned an error then return OTHER_ERROR + if (sock == INVALID_SOCKET ) + { + return OTHER_ERROR; + } + + // set the foreign socket structure + foreign_sa.sin_family = AF_INET; + foreign_sa.sin_port = htons(foreign_port); + foreign_sa.sin_addr.S_un.S_addr = inet_addr(foreign_ip.c_str()); + + // if inet_addr couldn't convert the ip then return an error + if ( foreign_sa.sin_addr.S_un.S_addr == INADDR_NONE ) + { + closesocket(sock); + return OTHER_ERROR; + } + + + // set up the local socket structure + local_sa.sin_family = AF_INET; + + // set the local ip + if (local_ip.empty()) + { + // if the listener should listen on any IP + local_sa.sin_addr.S_un.S_addr = htons(INADDR_ANY); + } + else + { + // if there is a specific ip to listen on + local_sa.sin_addr.S_un.S_addr = inet_addr(local_ip.c_str()); + + // if inet_addr couldn't convert the ip then return an error + if (local_sa.sin_addr.S_un.S_addr == INADDR_NONE) + { + closesocket(sock); + return OTHER_ERROR; + } + } + + // set the local port + local_sa.sin_port = htons(local_port); + + + + // bind the new socket to the requested local port and local ip + if ( bind ( + sock, + reinterpret_cast(&local_sa), + sizeof(sockaddr_in) + ) == SOCKET_ERROR + ) + { + const int err = WSAGetLastError(); + // if there was an error + closesocket(sock); + + // if the port is already bound then return PORTINUSE + if (err == WSAEADDRINUSE) + return PORTINUSE; + else + return OTHER_ERROR; + } + + // connect the socket + if (connect ( + sock, + reinterpret_cast(&foreign_sa), + sizeof(sockaddr_in) + ) == SOCKET_ERROR + ) + { + const int err = WSAGetLastError(); + closesocket(sock); + // if the port is already bound then return PORTINUSE + if (err == WSAEADDRINUSE) + return PORTINUSE; + else + return OTHER_ERROR; + } + + + + // determine the local port and IP and store them in used_local_ip + // and used_local_port + int used_local_port; + std::string used_local_ip; + sockaddr_in local_info; + if (local_port == 0) + { + length = sizeof(sockaddr_in); + if (getsockname ( + sock, + reinterpret_cast(&local_info), + &length + ) == SOCKET_ERROR + ) + { + closesocket(sock); + return OTHER_ERROR; + } + used_local_port = ntohs(local_info.sin_port); + } + else + { + used_local_port = local_port; + } + + // determine real local ip + if (local_ip.empty()) + { + // if local_port is not 0 then we must fill the local_info structure + if (local_port != 0) + { + length = sizeof(sockaddr_in); + if ( getsockname ( + sock, + reinterpret_cast(&local_info), + &length + ) == SOCKET_ERROR + ) + { + closesocket(sock); + return OTHER_ERROR; + } + } + char* temp = inet_ntoa(local_info.sin_addr); + + // check if inet_ntoa returned an error + if (temp == NULL) + { + closesocket(sock); + return OTHER_ERROR; + } + used_local_ip.assign(temp); + } + else + { + used_local_ip = local_ip; + } + + // set the SO_OOBINLINE option + int flag_value = 1; + if (setsockopt(sock,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast(&flag_value),sizeof(int)) == SOCKET_ERROR ) + { + closesocket(sock); + return OTHER_ERROR; + } + + // initialize a connection object on the heap with the new socket + try + { + new_connection = new connection ( + sock, + foreign_port, + foreign_ip, + used_local_port, + used_local_ip + ); + } + catch(...) {closesocket(sock); return OTHER_ERROR; } + + return 0; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // WIN32 + +#endif // DLIB_SOCKETS_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/sockets/sockets_kernel_1.h b/ml/dlib/dlib/sockets/sockets_kernel_1.h new file mode 100644 index 000000000..5fb73ecd6 --- /dev/null +++ b/ml/dlib/dlib/sockets/sockets_kernel_1.h @@ -0,0 +1,351 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net), Miguel Grinberg +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETS_KERNEl_1_ +#define DLIB_SOCKETS_KERNEl_1_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + +#include "sockets_kernel_abstract.h" + +#include +#include + +#include "../algs.h" +#include "../threads.h" +#include "../uintn.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + // forward declarations + class socket_factory; + class listener; + class SOCKET_container; + +// ---------------------------------------------------------------------------------------- + + // lookup functions + + int + get_local_hostname ( + std::string& hostname + ); + +// ----------------- + + int + hostname_to_ip ( + const std::string& hostname, + std::string& ip, + int n = 0 + ); + +// ----------------- + + int + ip_to_hostname ( + const std::string& ip, + std::string& hostname + ); + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // connection object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class connection + { + /*! + INITIAL_VALUE + - sd == false + - sdo == false + - sdr == 0 + + CONVENTION + - connection_socket == the socket handle for this connection. + - connection_foreign_port == the port that foreign host is using for + this connection. + - connection_foreign_ip == a string containing the IP address of the + foreign host. + - connection_local_port == the port that the local host is using for + this connection. + - connection_local_ip == a string containing the IP address of the + local interface being used by this connection. + + - sd == if shutdown() has been called then true else false. + - sdo == if shutdown_outgoing() has been called then true else false. + - sdr == the return value of shutdown() if it has been called. if it + hasn't been called then 0. + + !*/ + + friend class listener; // make listener a friend of connection + // make create_connection a friend of connection + friend int create_connection ( + connection*& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port, + const std::string& local_ip + ); + + public: + + ~connection ( + ); + + void* user_data; + + long write ( + const char* buf, + long num + ); + + long read ( + char* buf, + long num + ); + + long read ( + char* buf, + long num, + unsigned long timeout + ); + + unsigned short get_local_port ( + ) const { return connection_local_port; } + + unsigned short get_foreign_port ( + ) const { return connection_foreign_port; } + + const std::string& get_local_ip ( + ) const { return connection_local_ip; } + + const std::string& get_foreign_ip ( + ) const { return connection_foreign_ip; } + + int shutdown_outgoing ( + ); + + int shutdown ( + ); + + // I would use SOCKET here but I don't want to include the windows + // header files since they bring in a bunch of unpleasantness. So + // I'm doing this instead which should ultimately be the same type + // as the SOCKET win the windows API. + typedef unsigned_type::type socket_descriptor_type; + + int disable_nagle( + ); + + socket_descriptor_type get_socket_descriptor ( + ) const; + + private: + + bool readable ( + unsigned long timeout + ) const; + /*! + requires + - timeout < 2000000 + ensures + - returns true if a read call on this connection will not block. + - returns false if a read call on this connection will block or if + there was an error. + !*/ + + bool sd_called ( + )const + /*! + ensures + - returns true if shutdown() has been called else + returns false + !*/ + { + sd_mutex.lock(); + bool temp = sd; + sd_mutex.unlock(); + return temp; + } + + bool sdo_called ( + )const + /*! + ensures + - returns true if shutdown_outgoing() or shutdown() has been called + else returns false + !*/ + { + sd_mutex.lock(); + bool temp = false; + if (sdo || sd) + temp = true; + sd_mutex.unlock(); + return temp; + } + + + // data members + SOCKET_container& connection_socket; + const unsigned short connection_foreign_port; + const std::string connection_foreign_ip; + const unsigned short connection_local_port; + const std::string connection_local_ip; + + bool sd; // called shutdown + bool sdo; // called shutdown_outgoing + int sdr; // return value for shutdown + mutex sd_mutex; // a lock for the three above vars + + + connection( + SOCKET_container sock, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port, + const std::string& local_ip + ); + /*! + requires + sock is a socket handle and + sock is the handle for the connection between foreign_ip:foreign_port + and local_ip:local_port + ensures + *this is initialized correctly with the above parameters + !*/ + + + // restricted functions + connection(connection&); // copy constructor + connection& operator=(connection&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // listener object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class listener + { + /*! + CONVENTION + if (inaddr_any == false) + { + listening_ip == a string containing the address the listener is + listening on + } + else + { + the listener is listening on all interfaces + } + + listening_port == the port the listener is listening on + listening_socket == the listening socket handle for this object + !*/ + + // make the create_listener a friend of listener + friend int create_listener ( + listener*& new_listener, + unsigned short port, + const std::string& ip + ); + + public: + + ~listener ( + ); + + int accept ( + connection*& new_connection, + unsigned long timeout = 0 + ); + + int accept ( + std::unique_ptr& new_connection, + unsigned long timeout = 0 + ); + + unsigned short get_listening_port ( + ) { return listening_port; } + + const std::string& get_listening_ip ( + ) { return listening_ip; } + + private: + + // data members + SOCKET_container& listening_socket; + const unsigned short listening_port; + const std::string listening_ip; + const bool inaddr_any; + + listener( + SOCKET_container sock, + unsigned short port, + const std::string& ip + ); + /*! + requires + sock is a socket handle and + sock is listening on the port and ip(may be "") indicated in the + above parameters + ensures + *this is initialized correctly with the above parameters + !*/ + + + // restricted functions + listener(listener&); // copy constructor + listener& operator=(listener&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + int create_listener ( + listener*& new_listener, + unsigned short port, + const std::string& ip = "" + ); + + int create_connection ( + connection*& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port = 0, + const std::string& local_ip = "" + ); + + int create_listener ( + std::unique_ptr& new_listener, + unsigned short port, + const std::string& ip = "" + ); + + int create_connection ( + std::unique_ptr& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port = 0, + const std::string& local_ip = "" + ); + +// ---------------------------------------------------------------------------------------- + + +} + +#ifdef NO_MAKEFILE +#include "sockets_kernel_1.cpp" +#endif + +#endif // DLIB_SOCKETS_KERNEl_1_ + diff --git a/ml/dlib/dlib/sockets/sockets_kernel_2.cpp b/ml/dlib/dlib/sockets/sockets_kernel_2.cpp new file mode 100644 index 000000000..ac7408ff3 --- /dev/null +++ b/ml/dlib/dlib/sockets/sockets_kernel_2.cpp @@ -0,0 +1,1109 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net), Miguel Grinberg +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETS_KERNEL_2_CPp_ +#define DLIB_SOCKETS_KERNEL_2_CPp_ + +#include "../platform.h" + +#ifdef POSIX + + +#include "sockets_kernel_2.h" +#include +#include "../set.h" +#include + + + +namespace dlib +{ +// ---------------------------------------------------------------------------------------- + +#ifdef HPUX + typedef int dsocklen_t; +#else + typedef socklen_t dsocklen_t; +#endif + +// ---------------------------------------------------------------------------------------- +// stuff to ensure that the signal SIGPIPE is ignored before any connections are made +// so that when a connection object is shutdown the program won't end on a broken pipe + + namespace sockets_kernel_2_mutex + { + mutex startup_lock; + } + + + void sockets_startup() + { + // mutex crap to make this function thread safe + sockets_kernel_2_mutex::startup_lock.lock(); + static bool init = false; + if (init == false) + { + init = true; + signal( SIGPIPE, SIG_IGN); + } + sockets_kernel_2_mutex::startup_lock.unlock(); + } + + +// ---------------------------------------------------------------------------------------- + + // lookup functions + + int + get_local_hostname ( + std::string& hostname + ) + { + try + { + char temp[MAXHOSTNAMELEN]; + + if (gethostname(temp,MAXHOSTNAMELEN) == -1) + { + return OTHER_ERROR; + } + // ensure that NUL is at the end of the string + temp[MAXHOSTNAMELEN-1] = '\0'; + + hostname = temp; + } + catch (...) + { + return OTHER_ERROR; + } + + return 0; + } + +// ----------------- + +// cygwin currently doesn't support the getaddrinfo stuff +#ifndef __CYGWIN__ + + int + hostname_to_ip ( + const std::string& hostname, + std::string& ip, + int n + ) + { + try + { + set::kernel_1a sos; + + if (hostname.empty()) + return OTHER_ERROR; + + addrinfo* result = 0; + if (getaddrinfo(hostname.c_str(),0,0,&result)) + { + return OTHER_ERROR; + } + addrinfo* result_orig = result; + + // loop over all the addrinfo structures and add them to the set. the reason for doing + // this dumb crap is because different platforms return all kinds of weird garbage. many + // return the same ip multiple times, etc. + while (result != 0) + { + char temp[16]; + inet_ntop ( + AF_INET, + &((reinterpret_cast(result->ai_addr))->sin_addr), + temp,16 + ); + + result = result->ai_next; + + ip.assign(temp); + if (sos.is_member(ip) == false && ip != "0.0.0.0") + sos.add(ip); + } + + freeaddrinfo(result_orig); + + // now return the nth unique ip address + int i = 0; + while (sos.move_next()) + { + if (i == n) + { + ip = sos.element(); + return 0; + } + ++i; + } + + return OTHER_ERROR; + } + catch (...) + { + return OTHER_ERROR; + } + return 0; + } + + +// ----------------- + + int + ip_to_hostname ( + const std::string& ip, + std::string& hostname + ) + { + + try + { + + if (ip.empty()) + return OTHER_ERROR; + + sockaddr_in sa; + sa.sin_family = AF_INET; + inet_pton(AF_INET,ip.c_str(),&sa.sin_addr); + + char temp[NI_MAXHOST]; + if ( getnameinfo ( + reinterpret_cast(&sa),sizeof(sockaddr_in), + temp, + NI_MAXHOST, + 0, + 0, + NI_NAMEREQD + ) + ) + { + return OTHER_ERROR; + } + + hostname.assign(temp); + + } + catch (...) + { + return OTHER_ERROR; + } + return 0; + } +#else + int + hostname_to_ip ( + const std::string& hostname, + std::string& ip, + int n + ) + { + try + { + // lock this mutex since gethostbyname isn't really thread safe + auto_mutex M(sockets_kernel_2_mutex::startup_lock); + + // if no hostname was given then return error + if ( hostname.empty()) + return OTHER_ERROR; + + hostent* address; + address = gethostbyname(hostname.c_str()); + + if (address == 0) + { + return OTHER_ERROR; + } + + // find the nth address + in_addr* addr = reinterpret_cast(address->h_addr_list[0]); + for (int i = 1; i <= n; ++i) + { + addr = reinterpret_cast(address->h_addr_list[i]); + + // if there is no nth address then return error + if (addr == 0) + return OTHER_ERROR; + } + + char* resolved_ip = inet_ntoa(*addr); + + // check if inet_ntoa returned an error + if (resolved_ip == NULL) + { + return OTHER_ERROR; + } + + ip.assign(resolved_ip); + + } + catch(...) + { + return OTHER_ERROR; + } + + return 0; + } + +// ----------------- + + int + ip_to_hostname ( + const std::string& ip, + std::string& hostname + ) + { + try + { + // lock this mutex since gethostbyaddr isn't really thread safe + auto_mutex M(sockets_kernel_2_mutex::startup_lock); + + // if no ip was given then return error + if (ip.empty()) + return OTHER_ERROR; + + hostent* address; + unsigned long ipnum = inet_addr(ip.c_str()); + + // if inet_addr couldn't convert ip then return an error + if (ipnum == INADDR_NONE) + { + return OTHER_ERROR; + } + address = gethostbyaddr(reinterpret_cast(&ipnum),4,AF_INET); + + // check if gethostbyaddr returned an error + if (address == 0) + { + return OTHER_ERROR; + } + hostname.assign(address->h_name); + + } + catch (...) + { + return OTHER_ERROR; + } + return 0; + + } + +#endif // __CYGWIN__ + +// ---------------------------------------------------------------------------------------- + + connection:: + connection( + int sock, + int foreign_port, + const std::string& foreign_ip, + int local_port, + const std::string& local_ip + ) : + connection_socket(sock), + connection_foreign_port(foreign_port), + connection_foreign_ip(foreign_ip), + connection_local_port(local_port), + connection_local_ip(local_ip), + sd(false), + sdo(false), + sdr(0) + {} + +// ---------------------------------------------------------------------------------------- + + int connection:: + disable_nagle() + { + int flag = 1; + if(setsockopt( connection_socket, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(flag) )) + { + return OTHER_ERROR; + } + + return 0; + } + +// ---------------------------------------------------------------------------------------- + + long connection:: + write ( + const char* buf, + long num + ) + { + const long old_num = num; + long status; + const long max_send_length = 1024*1024*100; + while (num > 0) + { + // Make sure to cap the max value num can take on so that if it is + // really large (it might be big on 64bit platforms) so that the OS + // can't possibly get upset about it being large. + const long length = std::min(max_send_length, num); + if ( (status = ::send(connection_socket,buf,length,0)) <=0) + { + // if send was interupted by a signal then restart it + if (errno == EINTR) + { + continue; + } + else + { + // check if shutdown or shutdown_outgoing have been called + if (sdo_called()) + return SHUTDOWN; + else + return OTHER_ERROR; + } + } + num -= status; + buf += status; + } + return old_num; + } + +// ---------------------------------------------------------------------------------------- + + long connection:: + read ( + char* buf, + long num + ) + { + long status; + const long max_recv_length = 1024*1024*100; + while (true) + { + // Make sure to cap the max value num can take on so that if it is + // really large (it might be big on 64bit platforms) so that the OS + // can't possibly get upset about it being large. + const long length = std::min(max_recv_length, num); + status = recv(connection_socket,buf,length,0); + if (status == -1) + { + // if recv was interupted then try again + if (errno == EINTR) + continue; + else + { + if (sd_called()) + return SHUTDOWN; + else + return OTHER_ERROR; + } + } + else if (status == 0 && sd_called()) + { + return SHUTDOWN; + } + + return status; + } // while (true) + } +// ---------------------------------------------------------------------------------------- + + long connection:: + read ( + char* buf, + long num, + unsigned long timeout + ) + { + long status; + const long max_recv_length = 1024*1024*100; + + if (readable(timeout) == false) + return TIMEOUT; + + // Make sure to cap the max value num can take on so that if it is + // really large (it might be big on 64bit platforms) so that the OS + // can't possibly get upset about it being large. + const long length = std::min(max_recv_length, num); + status = recv(connection_socket,buf,length,0); + if (status == -1) + { + // if recv was interupted then call this a timeout + if (errno == EINTR) + { + return TIMEOUT; + } + else + { + if (sd_called()) + return SHUTDOWN; + else + return OTHER_ERROR; + } + } + else if (status == 0 && sd_called()) + { + return SHUTDOWN; + } + + return status; + } + +// ---------------------------------------------------------------------------------------- + + bool connection:: + readable ( + unsigned long timeout + ) const + { + fd_set read_set; + // initialize read_set + FD_ZERO(&read_set); + + // add the listening socket to read_set + FD_SET(connection_socket, &read_set); + + // setup a timeval structure + timeval time_to_wait; + time_to_wait.tv_sec = static_cast(timeout/1000); + time_to_wait.tv_usec = static_cast((timeout%1000)*1000); + + // wait on select + int status = select(connection_socket+1,&read_set,0,0,&time_to_wait); + + // if select timed out or there was an error + if (status <= 0) + return false; + + // socket is ready to be read + return true; + } + +// ---------------------------------------------------------------------------------------- + + connection:: + ~connection ( + ) + { + while (true) + { + int status = ::close(connection_socket); + if (status == -1 && errno == EINTR) + continue; + break; + } + } + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // listener object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + listener:: + listener( + int sock, + int port, + const std::string& ip + ) : + listening_socket(sock), + listening_port(port), + listening_ip(ip), + inaddr_any(listening_ip.empty()) + {} + +// ---------------------------------------------------------------------------------------- + + listener:: + ~listener ( + ) + { + while (true) + { + int status = ::close(listening_socket); + if (status == -1 && errno == EINTR) + continue; + break; + } + } + +// ---------------------------------------------------------------------------------------- + + int listener:: + accept ( + std::unique_ptr& new_connection, + unsigned long timeout + ) + { + new_connection.reset(0); + connection* con; + int status = this->accept(con, timeout); + + if (status == 0) + new_connection.reset(con); + + return status; + } + +// ---------------------------------------------------------------------------------------- + + int listener:: + accept ( + connection*& new_connection, + unsigned long timeout + ) + { + int incoming; + sockaddr_in incomingAddr; + dsocklen_t length = sizeof(sockaddr_in); + + // implement timeout with select if timeout is > 0 + if (timeout > 0) + { + + fd_set read_set; + // initialize read_set + FD_ZERO(&read_set); + + // add the listening socket to read_set + FD_SET(listening_socket, &read_set); + + timeval time_to_wait; + + + // loop on select so if its interupted then we can start it again + while (true) + { + + // setup a timeval structure + time_to_wait.tv_sec = static_cast(timeout/1000); + time_to_wait.tv_usec = static_cast((timeout%1000)*1000); + + // wait on select + int status = select(listening_socket+1,&read_set,0,0,&time_to_wait); + + // if select timed out + if (status == 0) + return TIMEOUT; + + // if select returned an error + if (status == -1) + { + // if select was interupted or the connection was aborted + // then go back to select + if (errno == EINTR || + errno == ECONNABORTED || +#ifdef EPROTO + errno == EPROTO || +#endif + errno == ECONNRESET + ) + { + continue; + } + else + { + return OTHER_ERROR; + } + } + + // accept the new connection + incoming=::accept ( + listening_socket, + reinterpret_cast(&incomingAddr), + &length + ); + + // if there was an error return OTHER_ERROR + if ( incoming == -1 ) + { + // if accept was interupted then go back to accept + if (errno == EINTR || + errno == ECONNABORTED || +#ifdef EPROTO + errno == EPROTO || +#endif + errno == ECONNRESET + ) + { + continue; + } + else + { + return OTHER_ERROR; + } + } + + // if there were no errors then quit loop + break; + + } + + } + // else if there is no time out then just go into accept + else + { + while (true) + { + // call accept to get a new connection + incoming=::accept ( + listening_socket, + reinterpret_cast(&incomingAddr), + &length + ); + + // if there was an error return OTHER_ERROR + if ( incoming == -1 ) + { + // if accept was interupted then go back to accept + if (errno == EINTR || + errno == ECONNABORTED || +#ifdef EPROTO + errno == EPROTO || +#endif + errno == ECONNRESET + ) + { + continue; + } + else + { + return OTHER_ERROR; + } + } + break; + } + + } + + + // get the port of the foreign host into foreign_port + int foreign_port = ntohs(incomingAddr.sin_port); + + // get the IP of the foreign host into foreign_ip + char foreign_ip[16]; + inet_ntop(AF_INET,&incomingAddr.sin_addr,foreign_ip,16); + + + + // get the local ip for this connection into local_ip + char temp_local_ip[16]; + std::string local_ip; + if (inaddr_any == true) + { + sockaddr_in local_info; + length = sizeof(sockaddr_in); + // get the local sockaddr_in structure associated with this new connection + if ( getsockname ( + incoming, + reinterpret_cast(&local_info), + &length + ) == -1 + ) + { // an error occurred + while (true) + { + int status = ::close(incoming); + if (status == -1 && errno == EINTR) + continue; + break; + } + return OTHER_ERROR; + } + local_ip = const_cast ( + inet_ntop(AF_INET,&local_info.sin_addr,temp_local_ip,16) + ); + } + else + { + local_ip = listening_ip; + } + + + + // set the SO_OOBINLINE option + int flag_value = 1; + if (setsockopt(incoming,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast(&flag_value),sizeof(int))) + { + while (true) + { + int status = ::close(incoming); + if (status == -1 && errno == EINTR) + continue; + break; + } + return OTHER_ERROR; + } + + + + // make a new connection object for this new connection + try + { + new_connection = new connection ( + incoming, + foreign_port, + foreign_ip, + listening_port, + local_ip + ); + } + catch (...) + { + while (true) + { + int status = ::close(incoming); + if (status == -1 && errno == EINTR) + continue; + break; + } + return OTHER_ERROR; + } + + return 0; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // socket creation functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + static void + close_socket ( + int sock + ) + /*! + requires + - sock == a socket + ensures + - sock has been closed + !*/ + { + while (true) + { + int status = ::close(sock); + if (status == -1 && errno == EINTR) + continue; + break; + } + } + +// ---------------------------------------------------------------------------------------- + + int create_listener ( + std::unique_ptr& new_listener, + unsigned short port, + const std::string& ip + ) + { + new_listener.reset(); + listener* temp; + int status = create_listener(temp,port,ip); + + if (status == 0) + new_listener.reset(temp); + + return status; + } + + int create_listener ( + listener*& new_listener, + unsigned short port, + const std::string& ip + ) + { + sockets_startup(); + + + sockaddr_in sa; // local socket structure + memset(&sa,'\0',sizeof(sockaddr_in)); // initialize sa + + + int sock = socket (AF_INET, SOCK_STREAM, 0); // get a new socket + + // if socket() returned an error then return OTHER_ERROR + if (sock == -1) + { + return OTHER_ERROR; + } + + // set the local socket structure + sa.sin_family = AF_INET; + sa.sin_port = htons(port); + if (ip.empty()) + { + // if the listener should listen on any IP + sa.sin_addr.s_addr = htons(INADDR_ANY); + } + else + { + // if there is a specific ip to listen on + sa.sin_addr.s_addr = inet_addr(ip.c_str()); + + // if inet_addr couldn't convert the ip then return an error + if ( sa.sin_addr.s_addr == ( in_addr_t)(-1)) + { + close_socket(sock); + return OTHER_ERROR; + } + } + + // set the SO_REUSEADDR option + int flag_value = 1; + if (setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,reinterpret_cast(&flag_value),sizeof(int))) + { + close_socket(sock); + return OTHER_ERROR; + } + + + // bind the new socket to the requested port and ip + if (bind(sock,reinterpret_cast(&sa),sizeof(sockaddr_in)) == -1) + { // if there was an error + close_socket(sock); + + // if the port is already bound then return PORTINUSE + if (errno == EADDRINUSE) + return PORTINUSE; + else + return OTHER_ERROR; + } + + + // tell the new socket to listen + if ( listen(sock,SOMAXCONN) == -1) + { + // if there was an error return OTHER_ERROR + close_socket(sock); + + // if the port is already bound then return PORTINUSE + if (errno == EADDRINUSE) + return PORTINUSE; + else + return OTHER_ERROR; + } + + // determine the used local port if necessary + if (port == 0) + { + sockaddr_in local_info; + dsocklen_t length = sizeof(sockaddr_in); + if ( getsockname( + sock, + reinterpret_cast(&local_info), + &length + ) == -1) + { + close_socket(sock); + return OTHER_ERROR; + } + port = ntohs(local_info.sin_port); + } + + // initialize a listener object on the heap with the new socket + try { new_listener = new listener(sock,port,ip); } + catch(...) { close_socket(sock); return OTHER_ERROR; } + + return 0; + } + +// ---------------------------------------------------------------------------------------- + + int create_connection ( + std::unique_ptr& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port, + const std::string& local_ip + ) + { + new_connection.reset(); + connection* temp; + int status = create_connection(temp,foreign_port, foreign_ip, local_port, local_ip); + + if (status == 0) + new_connection.reset(temp); + + return status; + } + + int + create_connection ( + connection*& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port, + const std::string& local_ip + ) + { + sockets_startup(); + + sockaddr_in local_sa; // local socket structure + sockaddr_in foreign_sa; // foreign socket structure + memset(&local_sa,'\0',sizeof(sockaddr_in)); // initialize local_sa + memset(&foreign_sa,'\0',sizeof(sockaddr_in)); // initialize foreign_sa + + dsocklen_t length; + + int sock = socket (AF_INET, SOCK_STREAM, 0); // get a new socket + + // if socket() returned an error then return OTHER_ERROR + if (sock == -1 ) + { + return OTHER_ERROR; + } + + // set the foreign socket structure + foreign_sa.sin_family = AF_INET; + foreign_sa.sin_port = htons(foreign_port); + foreign_sa.sin_addr.s_addr = inet_addr(foreign_ip.c_str()); + + // if inet_addr couldn't convert the ip then return an error + if ( foreign_sa.sin_addr.s_addr == ( in_addr_t)(-1)) + { + close_socket(sock); + return OTHER_ERROR; + } + + + // set up the local socket structure + local_sa.sin_family = AF_INET; + + // set the local port + local_sa.sin_port = htons(local_port); + + // set the local ip + if (local_ip.empty()) + { + // if the listener should listen on any IP + local_sa.sin_addr.s_addr = htons(INADDR_ANY); + } + else + { + // if there is a specific ip to listen on + local_sa.sin_addr.s_addr = inet_addr(local_ip.c_str()); + + // if inet_addr couldn't convert the ip then return an error + if ( local_sa.sin_addr.s_addr == ( in_addr_t)(-1)) + { + close_socket(sock); + return OTHER_ERROR; + } + } + + + + + + // bind the new socket to the requested local port and local ip + if ( bind(sock,reinterpret_cast(&local_sa),sizeof(sockaddr_in)) == -1) + { // if there was an error + close_socket(sock); + + // if the port is already bound then return PORTINUSE + if (errno == EADDRINUSE) + return PORTINUSE; + else + return OTHER_ERROR; + } + + // connect the socket + if ( connect ( + sock, + reinterpret_cast(&foreign_sa), + sizeof(sockaddr_in) + ) == -1 + ) + { + close_socket(sock); + // if the port is already bound then return PORTINUSE + if (errno == EADDRINUSE) + return PORTINUSE; + else + return OTHER_ERROR; + } + + + // determine the local port and IP and store them in used_local_ip + // and used_local_port + int used_local_port; + char temp_used_local_ip[16]; + std::string used_local_ip; + sockaddr_in local_info; + + // determine the port + if (local_port == 0) + { + length = sizeof(sockaddr_in); + if ( getsockname( + sock, + reinterpret_cast(&local_info), + &length + ) == -1) + { + close_socket(sock); + return OTHER_ERROR; + } + used_local_port = ntohs(local_info.sin_port); + } + else + { + used_local_port = local_port; + } + + // determine the ip + if (local_ip.empty()) + { + // if local_port is not 0 then we must fill the local_info structure + if (local_port != 0) + { + length = sizeof(sockaddr_in); + if ( getsockname ( + sock, + reinterpret_cast(&local_info), + &length + ) == -1 + ) + { + close_socket(sock); + return OTHER_ERROR; + } + } + used_local_ip = inet_ntop(AF_INET,&local_info.sin_addr,temp_used_local_ip,16); + } + else + { + used_local_ip = local_ip; + } + + + // set the SO_OOBINLINE option + int flag_value = 1; + if (setsockopt(sock,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast(&flag_value),sizeof(int))) + { + close_socket(sock); + return OTHER_ERROR; + } + + + // initialize a connection object on the heap with the new socket + try + { + new_connection = new connection ( + sock, + foreign_port, + foreign_ip, + used_local_port, + used_local_ip + ); + } + catch(...) {close_socket(sock); return OTHER_ERROR; } + + return 0; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // POSIX + +#endif // DLIB_SOCKETS_KERNEL_2_CPp_ + diff --git a/ml/dlib/dlib/sockets/sockets_kernel_2.h b/ml/dlib/dlib/sockets/sockets_kernel_2.h new file mode 100644 index 000000000..f3bc94ec0 --- /dev/null +++ b/ml/dlib/dlib/sockets/sockets_kernel_2.h @@ -0,0 +1,396 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net), Miguel Grinberg +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETS_KERNEl_2_ +#define DLIB_SOCKETS_KERNEl_2_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + +#include "../platform.h" + +#include "sockets_kernel_abstract.h" + +#define _BSD_SOCKLEN_T_ + +#include +#include +#include + +#include +#include +#include + +#ifndef HPUX +#include +#endif +#include +#include +#include +#include +#include +#include + +#include + +#include "../threads.h" +#include "../algs.h" + + + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + // forward declarations + class socket_factory; + class listener; + +// ---------------------------------------------------------------------------------------- + + // lookup functions + + int + get_local_hostname ( + std::string& hostname + ); + +// ----------------- + + int + hostname_to_ip ( + const std::string& hostname, + std::string& ip, + int n = 0 + ); + +// ----------------- + + int + ip_to_hostname ( + const std::string& ip, + std::string& hostname + ); + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // connection object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class connection + { + /*! + INITIAL_VALUE + sd == false + sdo == false + sdr == 0 + + + CONVENTION + connection_socket == the socket handle for this connection. + connection_foreign_port == the port that foreign host is using for + this connection + connection_foreign_ip == a string containing the IP address of the + foreign host + connection_local_port == the port that the local host is using for + this connection + connection_local_ip == a string containing the IP address of the + local interface being used by this connection + + sd == if shutdown() has been called then true + else false + sdo == if shutdown_outgoing() has been called then true + else false + sdr == the return value of shutdown() if it has been + called. if it hasn't been called then 0 + + + !*/ + + friend class listener; // make listener a friend of connection + // make create_connection a friend of connection + friend int create_connection ( + connection*& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port, + const std::string& local_ip + ); + + public: + + ~connection(); + + void* user_data; + + long write ( + const char* buf, + long num + ); + + long read ( + char* buf, + long num + ); + + long read ( + char* buf, + long num, + unsigned long timeout + ); + + int get_local_port ( + ) const { return connection_local_port; } + + int get_foreign_port ( + ) const { return connection_foreign_port; } + + const std::string& get_local_ip ( + ) const { return connection_local_ip; } + + const std::string& get_foreign_ip ( + ) const { return connection_foreign_ip; } + + int shutdown_outgoing ( + ) + { + sd_mutex.lock(); + if (sdo || sd) + { + sd_mutex.unlock(); + return sdr; + } + sdo = true; + sdr = ::shutdown(connection_socket,SHUT_WR); + int temp = sdr; + sd_mutex.unlock(); + return temp; + } + + int shutdown ( + ) + { + sd_mutex.lock(); + if (sd) + { + sd_mutex.unlock(); + return sdr; + } + sd = true; + sdr = ::shutdown(connection_socket,SHUT_RDWR); + int temp = sdr; + sd_mutex.unlock(); + return temp; + } + + int disable_nagle( + ); + + typedef int socket_descriptor_type; + + socket_descriptor_type get_socket_descriptor ( + ) const { return connection_socket; } + + private: + + bool readable ( + unsigned long timeout + ) const; + /*! + requires + - timeout < 2000000 + ensures + - returns true if a read call on this connection will not block. + - returns false if a read call on this connection will block or if + there was an error. + !*/ + + bool sd_called ( + )const + /*! + ensures + - returns true if shutdown() has been called else + - returns false + !*/ + { + sd_mutex.lock(); + bool temp = sd; + sd_mutex.unlock(); + return temp; + } + + bool sdo_called ( + )const + /*! + ensures + - returns true if shutdown_outgoing() or shutdown() has been called + else returns false + !*/ + { + sd_mutex.lock(); + bool temp = false; + if (sdo || sd) + temp = true; + sd_mutex.unlock(); + return temp; + } + + + // data members + int connection_socket; + const int connection_foreign_port; + const std::string connection_foreign_ip; + const int connection_local_port; + const std::string connection_local_ip; + + bool sd; // called shutdown + bool sdo; // called shutdown_outgoing + int sdr; // return value for shutdown + mutex sd_mutex; // a lock for the three above vars + + connection( + int sock, + int foreign_port, + const std::string& foreign_ip, + int local_port, + const std::string& local_ip + ); + /*! + requires + - sock is a socket handle + - sock is the handle for the connection between foreign_ip:foreign_port + and local_ip:local_port + ensures + - *this is initialized correctly with the above parameters + !*/ + + + // restricted functions + connection(); + connection(connection&); // copy constructor + connection& operator=(connection&); // assignement opertor + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // listener object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class listener + { + /*! + CONVENTION + if (inaddr_any == false) + { + listening_ip == a string containing the address the listener is + listening on + } + else + { + the listener is listening on all interfaces + } + + listening_port == the port the listener is listening on + listening_socket == the listening socket handle for this object + !*/ + + // make the create_listener a friend of listener + friend int create_listener ( + listener*& new_listener, + unsigned short port, + const std::string& ip + ); + + public: + + ~listener(); + + int accept ( + connection*& new_connection, + unsigned long timeout = 0 + ); + + int accept ( + std::unique_ptr& new_connection, + unsigned long timeout = 0 + ); + + int get_listening_port ( + ) const { return listening_port; } + + const std::string& get_listening_ip ( + ) const { return listening_ip; } + + private: + + // data members + int listening_socket; + const int listening_port; + const std::string listening_ip; + const bool inaddr_any; + + listener( + int sock, + int port, + const std::string& ip + ); + /*! + requires + - sock is a socket handle + - sock is listening on the port and ip(may be "") indicated in the above + parameters + ensures + - *this is initialized correctly with the above parameters + !*/ + + + // restricted functions + listener(); + listener(listener&); // copy constructor + listener& operator=(listener&); // assignement opertor + }; + +// ---------------------------------------------------------------------------------------- + + int create_listener ( + listener*& new_listener, + unsigned short port, + const std::string& ip = "" + ); + + int create_connection ( + connection*& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port = 0, + const std::string& local_ip = "" + ); + + int create_listener ( + std::unique_ptr& new_listener, + unsigned short port, + const std::string& ip = "" + ); + + int create_connection ( + std::unique_ptr& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port = 0, + const std::string& local_ip = "" + ); + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "sockets_kernel_2.cpp" +#endif + +#endif // DLIB_SOCKETS_KERNEl_2_ + diff --git a/ml/dlib/dlib/sockets/sockets_kernel_abstract.h b/ml/dlib/dlib/sockets/sockets_kernel_abstract.h new file mode 100644 index 000000000..d4571acad --- /dev/null +++ b/ml/dlib/dlib/sockets/sockets_kernel_abstract.h @@ -0,0 +1,495 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SOCKETS_KERNEl_ABSTRACT_ +#ifdef DLIB_SOCKETS_KERNEl_ABSTRACT_ + +#include +#include "../threads.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*! + GENERAL COMMENTS: + Nothing in here will throw exceptions. + + All ip address strings in this file refer to IPv4 addresses. For + example "192.168.1.1" + + Timeouts: + All timeout values are measured in milliseconds but you are not + guaranteed to have that level of resolution. The actual resolution + is implementation defined. + + GENERAL WARNING + Don't call any of these functions or make any of these objects + before main() has been entered. + + EXCEPTIONS + Unless specified otherwise, nothing in this file throws exceptions. + !*/ + +// ---------------------------------------------------------------------------------------- + + // LOOKUP FUNCTIONS + + // all lookup functions are thread-safe + + int get_local_hostname ( + std::string& hostname + ); + /*! + ensures + - if (#get_local_hostname() == 0) then + - #hostname == a string containing the hostname of the local computer + + - returns 0 upon success + - returns OTHER_ERROR upon failure and in this case #hostname's value + is undefined + !*/ + +// ----------------- + + int hostname_to_ip ( + const std::string& hostname, + std::string& ip, + int n = 0 + ); + /*! + requires + - n >= 0 + ensures + - if (#hostname_to_ip() == 0) then + - #ip == string containing the nth ip address associated with the hostname + + - returns 0 upon success + - returns OTHER_ERROR upon failure + !*/ + +// ----------------- + + int ip_to_hostname ( + const std::string& ip, + std::string& hostname + ); + /*! + ensures + - if (#ip_to_hostname() == 0) then + - #hostname == string containing the hostname associated with ip + + - returns 0 upon success + - returns OTHER_ERROR upon failure + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // + // socket creation functions + // + // The following functions are guaranteed to be thread-safe + // +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + int create_listener ( + listener*& new_listener, + unsigned short port, + const std::string& ip = "" + ); + /*! + requires + - 0 <= port <= 65535 + ensures + - if (#create_listener() == 0) then + - #new_listener == a pointer to a listener object that is listening on + the specified port and ip for an incoming connection + - if (ip == "") then + - the new listener will be listening on all interfaces + - if (port == 0) then + - the operating system will assign a free port to listen on + + + - returns 0 if create_listener was successful + - returns PORTINUSE if the specified local port was already in use + - returns OTHER_ERROR if some other error occurred + !*/ + + int create_listener ( + std::unique_ptr& new_listener, + unsigned short port, + const std::string& ip = "" + ); + /*! + This function is just an overload of the above function but it gives you a + std::unique_ptr smart pointer instead of a C pointer. + !*/ + + int create_connection ( + connection*& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port = 0, + const std::string& local_ip = "" + ); + /*! + requires + - 0 < foreign_port <= 65535 + - 0 <= local_port <= 65535 + ensures + - if (#create_connection() == 0) then + - #new_connection == a pointer to a connection object that is connected + to foreign_ip on port foreign_port and is using the local interface + local_ip and local port local_port + - #new_connection->user_data == 0 + - if (local_ip == "") then + - the operating system will chose this for you + - if (local_port == 0) then + - the operating system will chose this for you + + - returns 0 if create_connection was successful + - returns PORTINUSE if the specified local port was already in use + - returns OTHER_ERROR if some other error occurred + !*/ + + int create_connection ( + std::unique_ptr& new_connection, + unsigned short foreign_port, + const std::string& foreign_ip, + unsigned short local_port = 0, + const std::string& local_ip = "" + ); + /*! + This function is just an overload of the above function but it gives you a + std::unique_ptr smart pointer instead of a C pointer. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // connection object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class connection + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a TCP connection. + + Instances of this class can only be created by using the + create_connection function or listener class defined below. + + NOTE: + A connection object must ALWAYS be closed (delete the pointer to the + connection) or it will cause a resource leak. + + Note also that all errors indicated by a return code of OTHER_ERROR + are fatal so if one occurs the connection should just be closed. + + CLOSING A CONNECTION + Note that if ~connection() or shutdown() is called before the remote client + has received all sent data it is possible that the data will be lost. To + avoid this you should call the close_gracefully() function to close your + connections (unless you actually do want to immediately dispose of a + connection and don't care about the data). + (example: close_gracefully(con); // close con gracefully but force it closed + // if it takes more than 500 milliseconds.) + + THREAD SAFETY + - It is always safe to call shutdown() or shutdown_outgoing(). + - you may NOT call any function more than once at a time (except the + shutdown functions). + - do not call read() more than once at a time + - do not call write() more than once at a time + - You can safely call shutdown or shutdown_outgoing in conjunction with + the read/write functions. + This is helpful if you want to unblock another thread that is + blocking on a read/write operation. Shutting down the connection + will cause the read/write functions to return a value of SHUTDOWN. + + OUT-OF-BAND DATA: + All out-of-band data will be put inline into the normal data stream. + This means that you can read any out-of-band data via calls to read(). + (i.e. the SO_OOBINLINE socket option will be set) + !*/ + + public: + + ~connection ( + ); + /*! + requires + - no other threads are using this connection object + ensures + - closes the connection (this is an abrupt non-graceful close) + - frees the resources used by this object + !*/ + + void* user_data; + /*! + This pointer is provided so that the client programmer may easily associate + some data with a connection object. You can really do whatever you want + with it. Initially user_data is 0. + !*/ + + long write ( + const char* buf, + long num + ); + /*! + requires + - num > 0 + - buf points to an array of at least num bytes + ensures + - will block until ONE of the following occurs: + - num bytes from buf have been written to the connection + - an error has occurred + - the outgoing channel of the connection has been shutdown locally + + - returns num if write succeeded + - returns OTHER_ERROR if there was an error (this could be due to a + connection close) + - returns SHUTDOWN if the outgoing channel of the connection has been + shutdown locally + !*/ + + long read ( + char* buf, + long num + ); + /*! + requires + - num > 0 + - buf points to an array of at least num bytes + ensures + - read() will not read more than num bytes of data into #buf + - read blocks until ONE of the following happens: + - there is some data available and it has been written into #buf + - the remote end of the connection is closed + - an error has occurred + - the connection has been shutdown locally + + - returns the number of bytes read into #buf if there was any data. + - returns 0 if the connection has ended/terminated and there is no more data. + - returns OTHER_ERROR if there was an error. + - returns SHUTDOWN if the connection has been shutdown locally + !*/ + + long read ( + char* buf, + long num, + unsigned long timeout + ); + /*! + requires + - num > 0 + - buf points to an array of at least num bytes + - timeout < 2000000 + ensures + - read() will not read more than num bytes of data into #buf + - if (timeout > 0) then read() blocks until ONE of the following happens: + - there is some data available and it has been written into #buf + - the remote end of the connection is closed + - an error has occurred + - the connection has been shutdown locally + - timeout milliseconds has elapsed + - else + - read() does not block + + - returns the number of bytes read into #buf if there was any data. + - returns 0 if the connection has ended/terminated and there is no more data. + - returns TIMEOUT if timeout milliseconds elapsed before we got any data. + - returns OTHER_ERROR if there was an error. + - returns SHUTDOWN if the connection has been shutdown locally + !*/ + + unsigned short get_local_port ( + ) const; + /*! + ensures + - returns the local port number for this connection + !*/ + + unsigned short get_foreign_port ( + ) const; + /*! + ensures + - returns the foreign port number for this connection + !*/ + + const std::string& get_local_ip ( + ) const; + /*! + ensures + - returns the IP of the local interface this connection is using + !*/ + + const std::string& get_foreign_ip ( + ) const; + /*! + ensures + - returns the IP of the foreign host for this connection + !*/ + + int shutdown ( + ); + /*! + ensures + - if (#shutdown() == 0 && connection was still open) then + - terminates the connection but does not free the resources for the + connection object + + - any read() or write() calls on this connection will return immediately + with the code SHUTDOWN. + + - returns 0 upon success + - returns OTHER_ERROR if there was an error + !*/ + + int shutdown_outgoing ( + ); + /*! + ensures + - if (#shutdown_outgoing() == 0 && outgoing channel was still open) then + - sends a FIN to indicate that no more data will be sent on this + connection but leaves the receive half of the connection open to + receive more data from the other host + + - any calls to write() will return immediately with the code SHUTDOWN. + + - returns 0 upon success + - returns OTHER_ERROR if there was an error + !*/ + + int disable_nagle( + ); + /*! + ensures + - Sets the TCP_NODELAY socket option to disable Nagle's algorithm. + This can sometimes reduce transmission latency, however, in almost + all normal cases you don't want to mess with this as the default + setting is usually appropriate. + + - returns 0 upon success + - returns OTHER_ERROR if there was an error + !*/ + + typedef platform_specific_type socket_descriptor_type; + socket_descriptor_type get_socket_descriptor ( + ) const; + /*! + ensures + - returns the underlying socket descriptor for this connection + object. The reason you might want access to this is to + pass it to some other library that requires a socket file + descriptor. However, if you do this then you probably shouldn't + use the dlib::connection read() and write() anymore since + whatever you are doing with the socket descriptor is probably + doing I/O with the socket. + !*/ + + private: + // restricted functions + connection(); + connection(connection&); // copy constructor + connection& operator=(connection&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // listener object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class listener + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a TCP socket waiting for incoming connections. + Calling accept returns a pointer to any new incoming connections on its + port. + + Instances of this class can only be created by using the + create_listener function defined below. + + NOTE: + A listener object must ALWAYS be closed (delete the pointer to it) or + it will cause a resource leak. + + Note also that all errors indicated by a return code of OTHER_ERROR + are fatal so if one occurs the listener should be closed. + + THREAD SAFETY + None of the functions in this object are guaranteed to be thread-safe. + This means that you must serialize all access to this object. + !*/ + + public: + + ~listener ( + ); + /*! + requires + - no other threads are using this listener object + ensures + - closes the listener + - frees the resources used by this object + !*/ + + int accept ( + connection*& new_connection, + unsigned long timeout = 0 + ); + /*! + requires + - timeout < 2000000 + ensures + - blocks until a new connection is ready or timeout milliseconds have + elapsed. + - #new_connection == a pointer to the new connection object + - #new_connection->user_data == 0 + - if (timeout == 0) then + - the timeout argument is ignored + + - returns 0 if accept() was successful + - returns TIMEOUT if timeout milliseconds have elapsed + - returns OTHER_ERROR if an error has occurred + !*/ + + int accept ( + std::unique_ptr& new_connection, + unsigned long timeout = 0 + ); + /*! + This function is just an overload of the above function but it gives you a + std::unique_ptr smart pointer instead of a C pointer. + !*/ + + unsigned short get_listening_port ( + ) const; + /*! + ensures + - returns the port number that this object is listening on + !*/ + + const std::string& get_listening_ip ( + ) const; + /*! + ensures + - returns a string containing the IP (e.g. "127.0.0.1") of the + interface this object is listening on + - returns "" if it is accepting connections on all interfaces + !*/ + + private: + // restricted functions + listener(); + listener(listener&); // copy constructor + listener& operator=(listener&); // assignment operator + }; +} + +#endif // DLIB_SOCKETS_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/sockets/windows.h b/ml/dlib/dlib/sockets/windows.h new file mode 100644 index 000000000..85b7fd8d8 --- /dev/null +++ b/ml/dlib/dlib/sockets/windows.h @@ -0,0 +1,6 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETS_KERNEl_2_ +#include "sockets_kernel_1.h" +#endif + diff --git a/ml/dlib/dlib/sockstreambuf.h b/ml/dlib/dlib/sockstreambuf.h new file mode 100644 index 000000000..41b0e7a9e --- /dev/null +++ b/ml/dlib/dlib/sockstreambuf.h @@ -0,0 +1,11 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKSTREAMBUf_H_h_ +#define DLIB_SOCKSTREAMBUf_H_h_ + +#include "sockstreambuf/sockstreambuf.h" +#include "sockstreambuf/sockstreambuf_unbuffered.h" + + +#endif // DLIB_SOCKSTREAMBUf_H_h_ + diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf.cpp b/ml/dlib/dlib/sockstreambuf/sockstreambuf.cpp new file mode 100644 index 000000000..e328e4259 --- /dev/null +++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf.cpp @@ -0,0 +1,177 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKStREAMBUF_CPp_ +#define DLIB_SOCKStREAMBUF_CPp_ + +#include "sockstreambuf.h" +#include "../assert.h" + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + // output functions +// ---------------------------------------------------------------------------------------- + + sockstreambuf::int_type sockstreambuf:: + overflow ( + int_type c + ) + { + if (c != EOF) + { + *pptr() = c; + pbump(1); + } + if (flush_out_buffer() == EOF) + { + // an error occurred + return EOF; + } + return c; + } + +// ---------------------------------------------------------------------------------------- + + std::streamsize sockstreambuf:: + xsputn ( + const char* s, + std::streamsize num + ) + { + // Add a sanity check here + DLIB_ASSERT(num >= 0, + "\tstd::streamsize sockstreambuf::xsputn" + << "\n\tThe number of bytes to write can't be negative" + << "\n\tnum: " << num + << "\n\tthis: " << this + ); + + std::streamsize space_left = static_cast(epptr()-pptr()); + if (num <= space_left) + { + std::memcpy(pptr(),s,static_cast(num)); + pbump(static_cast(num)); + return num; + } + else + { + std::memcpy(pptr(),s,static_cast(space_left)); + s += space_left; + pbump(space_left); + std::streamsize num_left = num - space_left; + + if (flush_out_buffer() == EOF) + { + // the write was not successful so return that 0 bytes were written + return 0; + } + + if (num_left < out_buffer_size) + { + std::memcpy(pptr(),s,static_cast(num_left)); + pbump(num_left); + return num; + } + else + { + if (con.write(s,num_left) != num_left) + { + // the write was not successful so return that 0 bytes were written + return 0; + } + return num; + } + } + } + +// ---------------------------------------------------------------------------------------- + // input functions +// ---------------------------------------------------------------------------------------- + + sockstreambuf::int_type sockstreambuf:: + underflow( + ) + { + if (gptr() < egptr()) + { + return static_cast(*gptr()); + } + + int num_put_back = static_cast(gptr() - eback()); + if (num_put_back > max_putback) + { + num_put_back = max_putback; + } + + // copy the putback characters into the putback end of the in_buffer + std::memmove(in_buffer+(max_putback-num_put_back), gptr()-num_put_back, num_put_back); + + if (flushes_output_on_read()) + { + if (flush_out_buffer() == EOF) + { + // an error occurred + return EOF; + } + } + + int num = con.read(in_buffer+max_putback, in_buffer_size-max_putback); + if (num <= 0) + { + // an error occurred or the connection is over which is EOF + return EOF; + } + + // reset in_buffer pointers + setg (in_buffer+(max_putback-num_put_back), + in_buffer+max_putback, + in_buffer+max_putback+num); + + return static_cast(*gptr()); + } + +// ---------------------------------------------------------------------------------------- + + std::streamsize sockstreambuf:: + xsgetn ( + char_type* s, + std::streamsize n + ) + { + std::streamsize temp = n; + while (n > 0) + { + int num = static_cast(egptr() - gptr()); + if (num >= n) + { + // copy data from our buffer + std::memcpy(s, gptr(), static_cast(n)); + gbump(static_cast(n)); + return temp; + } + + // read more data into our buffer + if (num == 0) + { + if (underflow() == EOF) + break; + continue; + } + + // copy all the data from our buffer + std::memcpy(s, gptr(), num); + n -= num; + gbump(num); + s += num; + } + return temp-n; + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_SOCKStREAMBUF_CPp_ + diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf.h b/ml/dlib/dlib/sockstreambuf/sockstreambuf.h new file mode 100644 index 000000000..f5b450e78 --- /dev/null +++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf.h @@ -0,0 +1,172 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKStREAMBUF_Hh_ +#define DLIB_SOCKStREAMBUF_Hh_ + +#include +#include +#include "../sockets.h" +#include "sockstreambuf_abstract.h" +#include "sockstreambuf_unbuffered.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class sockstreambuf : public std::streambuf + { + /*! + INITIAL VALUE + - con == a connection + - in_buffer == an array of in_buffer_size bytes + - out_buffer == an array of out_buffer_size bytes + + CONVENTION + - in_buffer == the input buffer used by this streambuf + - out_buffer == the output buffer used by this streambuf + - max_putback == the maximum number of chars to have in the put back buffer. + !*/ + + public: + + // These typedefs are here for backwards compatibility with previous versions of + // dlib. + typedef sockstreambuf_unbuffered kernel_1a; + typedef sockstreambuf kernel_2a; + + sockstreambuf ( + connection* con_ + ) : + con(*con_), + out_buffer(0), + in_buffer(0), + autoflush(false) + { + init(); + } + + sockstreambuf ( + const std::unique_ptr& con_ + ) : + con(*con_), + out_buffer(0), + in_buffer(0), + autoflush(false) + { + init(); + } + + virtual ~sockstreambuf ( + ) + { + sync(); + delete [] out_buffer; + delete [] in_buffer; + } + + connection* get_connection ( + ) { return &con; } + + void flush_output_on_read() + { + autoflush = true; + } + + bool flushes_output_on_read() const + { + return autoflush; + } + + void do_not_flush_output_on_read() + { + autoflush = false; + } + + protected: + + void init ( + ) + { + try + { + out_buffer = new char[out_buffer_size]; + in_buffer = new char[in_buffer_size]; + } + catch (...) + { + if (out_buffer) delete [] out_buffer; + throw; + } + setp(out_buffer, out_buffer + (out_buffer_size-1)); + setg(in_buffer+max_putback, + in_buffer+max_putback, + in_buffer+max_putback); + } + + int flush_out_buffer ( + ) + { + int num = static_cast(pptr()-pbase()); + if (con.write(out_buffer,num) != num) + { + // the write was not successful so return EOF + return EOF; + } + pbump(-num); + return num; + } + + // output functions + int_type overflow ( + int_type c + ); + + int sync ( + ) + { + if (flush_out_buffer() == EOF) + { + // an error occurred + return -1; + } + return 0; + } + + std::streamsize xsputn ( + const char* s, + std::streamsize num + ); + + // input functions + int_type underflow( + ); + + std::streamsize xsgetn ( + char_type* s, + std::streamsize n + ); + + private: + + // member data + connection& con; + static const std::streamsize max_putback = 4; + static const std::streamsize out_buffer_size = 10000; + static const std::streamsize in_buffer_size = 10000; + char* out_buffer; + char* in_buffer; + bool autoflush; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "sockstreambuf.cpp" +#endif + +#endif // DLIB_SOCKStREAMBUF_Hh_ + diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf_abstract.h b/ml/dlib/dlib/sockstreambuf/sockstreambuf_abstract.h new file mode 100644 index 000000000..12be84193 --- /dev/null +++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf_abstract.h @@ -0,0 +1,127 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SOCKSTREAMBUF_ABSTRACT_ +#ifdef DLIB_SOCKSTREAMBUF_ABSTRACT_ + +#include +#include +#include + +#include "../sockets/sockets_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class sockstreambuf : public std::streambuf + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a stream buffer capable of writing to and + reading from TCP connections. + + NOTE: + For a sockstreambuf EOF is when the connection has closed or otherwise + returned some kind of error. + + Also note that any data written to the streambuf may be buffered + internally. So if you need to ensure that data is actually sent then you + should flush the stream. + + A read operation is guaranteed to block until the number of bytes + requested has arrived on the connection. It will never keep blocking + once enough data has arrived. + + THREADING + Generally speaking, this object has the same kind of threading + restrictions as a connection object. those being: + - Do not try to write to a sockstreambuf from more than one thread. + - Do not try to read from a sockstreambuf from more than one thread. + - You may call shutdown() on the connection object and this will + cause any reading or writing calls to end. To the sockstreambuf it + will appear the same as hitting EOF. (note that EOF for a sockstreambuf + means that the connection has closed) + - It is safe to read from and write to the sockstreambuf at the same time + from different threads so long as flushes_output_on_read()==false. + - It is not safe to try to putback a char and read from the stream from + different threads + !*/ + public: + sockstreambuf ( + connection* con + ); + /*! + requires + - con == a valid connection object + ensures + - *this will read from and write to con + - #flushes_output_on_read() == false + throws + - std::bad_alloc + !*/ + + sockstreambuf ( + const std::unique_ptr& con + ); + /*! + requires + - con == a valid connection object + ensures + - *this will read from and write to con + - #flushes_output_on_read() == false + throws + - std::bad_alloc + !*/ + + ~sockstreambuf ( + ); + /*! + requires + - get_connection() object has not been deleted + ensures + - sockstream buffer is destructed but the connection object will + NOT be closed. + - Any buffered data is flushed to the connection. + !*/ + + connection* get_connection ( + ); + /*! + ensures + - returns a pointer to the connection object which this buffer + reads from and writes to + !*/ + + void flush_output_on_read ( + ); + /*! + ensures + - #flushes_output_on_read() == true + !*/ + + bool flushes_output_on_read ( + ) const; + /*! + ensures + - This function returns true if this object will flush its output buffer + to the network socket before performing any network read. + - if (flushes_output_on_read() == true) + - It is not safe to make concurrent read and write calls to this object. + !*/ + + void do_not_flush_output_on_read ( + ); + /*! + ensures + - #flushes_output_on_read() == false + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SOCKSTREAMBUF_ABSTRACT_ + diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.cpp b/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.cpp new file mode 100644 index 000000000..4dcefc17b --- /dev/null +++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.cpp @@ -0,0 +1,168 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKSTrEAMBUF_UNBUFFERED_CPp_ +#define DLIB_SOCKSTrEAMBUF_UNBUFFERED_CPp_ + +#include "sockstreambuf_unbuffered.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + // output functions +// ---------------------------------------------------------------------------------------- + + sockstreambuf_unbuffered::int_type sockstreambuf_unbuffered:: + overflow ( + int_type c + ) + { + if (c != EOF) + { + char temp = static_cast(c); + if (con.write(&temp,1) != 1) + { + // if the write was not successful + return EOF; + } + } + return c; + } + +// ---------------------------------------------------------------------------------------- + + std::streamsize sockstreambuf_unbuffered:: + xsputn ( + const char* s, + std::streamsize num + ) + { + if (con.write(s,static_cast(num)) != num) + { + // the write was not successful so return that 0 bytes were written + return 0; + } + return num; + } + +// ---------------------------------------------------------------------------------------- + // input functions +// ---------------------------------------------------------------------------------------- + + sockstreambuf_unbuffered::int_type sockstreambuf_unbuffered:: + underflow( + ) + { + if (lastread_next) + { + return lastread; + } + else if (peek != EOF) + { + return peek; + } + else + { + char temp; + if (con.read(&temp,1) != 1) + { + // some error occurred + return EOF; + } + peek = static_cast(temp); + return peek; + } + } + +// ---------------------------------------------------------------------------------------- + + sockstreambuf_unbuffered::int_type sockstreambuf_unbuffered:: + uflow( + ) + { + if (lastread_next) + { + lastread_next = false; + return lastread; + } + else if (peek != EOF) + { + lastread = peek; + peek = EOF; + return lastread; + } + else + { + char temp; + if (con.read(&temp,1) != 1) + { + // some error occurred + return EOF; + } + lastread = static_cast(temp); + return lastread; + } + } + +// ---------------------------------------------------------------------------------------- + + sockstreambuf_unbuffered::int_type sockstreambuf_unbuffered:: + pbackfail( + int_type c + ) + { + // if they are trying to push back a character that they didn't read last + // that is an error + if (c != EOF && c != lastread) + return EOF; + + // if they are trying to push back a second character then thats an error + if (lastread_next) + return EOF; + + lastread_next = true; + return 1; + } + +// ---------------------------------------------------------------------------------------- + + std::streamsize sockstreambuf_unbuffered:: + xsgetn ( + char_type* s, + std::streamsize n + ) + { + std::streamsize temp = n; + if (lastread_next && n > 0) + { + *s = lastread; + lastread_next = false; + ++s; + --n; + } + if (peek != EOF && n > 0) + { + *s = peek; + peek = EOF; + ++s; + --n; + } + + while (n>0) + { + int status = con.read(s,static_cast(n)); + if (status < 1) + break; + n -= status; + s += status; + } + + return temp-n; + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_SOCKSTrEAMBUF_UNBUFFERED_CPp_ + diff --git a/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.h b/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.h new file mode 100644 index 000000000..8aa5992db --- /dev/null +++ b/ml/dlib/dlib/sockstreambuf/sockstreambuf_unbuffered.h @@ -0,0 +1,118 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKSTrEAMBUF_UNBUFFERED_Hh_ +#define DLIB_SOCKSTrEAMBUF_UNBUFFERED_Hh_ + +#include +#include +#include "../sockets.h" +#include "sockstreambuf_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class sockstreambuf_unbuffered : public std::streambuf + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the interface defined in + sockstreambuf_abstract.h except that it doesn't do any kind of buffering at + all. It just writes data directly to a connection. However, note that we + don't implement the flushes_output_on_read() routine as this object always + flushes immediately (since it isn't buffers. Moreover, it should be + pointed out that this object is deprecated and only present for backwards + compatibility with previous versions of dlib. So you really should use the + sockstreambuf object instead. + + INITIAL VALUE + con == a connection + lastread_next == false + peek == EOF + + CONVENTION + if (peek != EOF) then + peek == the last character read from the connection object and + is used to store the char in the event the user peeks by + calling sgetc() + if (lastread != EOF) then + lastread == the last character read and consumed by the user + + if (lastread_next) then + the next character to be returned to the user is lastread because + the user put it back into the buffer + + !*/ + + public: + + + sockstreambuf_unbuffered ( + connection* con_ + ) : + con(*con_), + peek(EOF), + lastread_next(false) + {} + + sockstreambuf_unbuffered ( + const std::unique_ptr& con_ + ) : + con(*con_), + peek(EOF), + lastread_next(false) + {} + + connection* get_connection ( + ) { return &con; } + + + protected: + + // output functions + int_type overflow ( + int_type c + ); + + std::streamsize xsputn ( + const char* s, + std::streamsize num + ); + + // input functions + int_type underflow( + ); + + int_type uflow( + ); + + int_type pbackfail( + int_type c + ); + + std::streamsize xsgetn ( + char_type* s, + std::streamsize n + ); + + private: + + // member data + connection& con; + int_type peek; + int_type lastread; + bool lastread_next; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "sockstreambuf_unbuffered.cpp" +#endif + +#endif // DLIB_SOCKSTrEAMBUF_UNBUFFERED_Hh_ + diff --git a/ml/dlib/dlib/sort.h b/ml/dlib/dlib/sort.h new file mode 100644 index 000000000..c798372f3 --- /dev/null +++ b/ml/dlib/dlib/sort.h @@ -0,0 +1,490 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SORt_ +#define DLIB_SORt_ + +#include "algs.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + inline void qsort_array ( + T& array, + unsigned long left, + unsigned long right, + const compare& comp + ); + /*! + requires + - T implements operator[] + - the items in array must be comparable by comp where comp is a function + object with the same syntax as std::less<> + - the items in array must be swappable by a global swap() + - left and right are within the bounds of array + i.e. array[left] and array[right] are valid elements + - left <= right + ensures + - for all elements in #array between and including left and right the + ith element is < the i+1 element + - sorts using a quick sort algorithm + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void hsort_array ( + T& array, + unsigned long left, + unsigned long right, + const compare& comp + ); + /*! + requires + - T implements operator[] + - the items in array must be comparable by comp where comp is a function + object with the same syntax as std::less<> + - the items in array must be swappable by a global swap() + - left and right are within the bounds of array + i.e. array[left] and array[right] are valid elements + - left <= right + ensures + - for all elements in #array between and including left and right the + ith element is < the i+1 element + - sorts using a heapsort algorithm + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void isort_array ( + T& array, + unsigned long left, + unsigned long right, + const compare& comp + ); + /*! + requires + - T implements operator[] + - the items in array must be comparable by comp where comp is a function + object with the same syntax as std::less<> + - the items in array must be swappable by a global swap() + - left and right are within the bounds of array + i.e. array[left] and array[right] are valid elements + - left <= right + ensures + - for all elements in #array between and including left and right the + ith element is < the i+1 element + - sorts using an insertion sort algorithm + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + inline void qsort_array ( + T& array, + unsigned long left, + unsigned long right + ); + /*! + requires + - T implements operator[] + - the items in array must be comparable by std::less + - the items in array must be swappable by a global swap() + - left and right are within the bounds of array + i.e. array[left] and array[right] are valid elements + - left <= right + ensures + - for all elements in #array between and including left and right the + ith element is < the i+1 element + - sorts using a quick sort algorithm + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void hsort_array ( + T& array, + unsigned long left, + unsigned long right + ); + /*! + requires + - T implements operator[] + - the items in array must be comparable by std::less + - the items in array must be swappable by a global swap() + - left and right are within the bounds of array + i.e. array[left] and array[right] are valid elements + - left <= right + ensures + - for all elements in #array between and including left and right the + ith element is < the i+1 element + - sorts using a heapsort algorithm + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void isort_array ( + T& array, + unsigned long left, + unsigned long right + ); + /*! + requires + - T implements operator[] + - the items in array must be comparable by std::less + - the items in array must be swappable by a global swap() + - left and right are within the bounds of array + i.e. array[left] and array[right] are valid elements + - left <= right + ensures + - for all elements in #array between and including left and right the + ith element is < the i+1 element + - sorts using an insertion sort algorithm + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION DETAILS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace sort_helpers + { + template + inline const std::less comp (const T&) + { + return std::less(); + } + + template < + typename T, + typename Y, + typename compare + > + inline unsigned long qsort_partition ( + T& array, + Y& pivot, + const unsigned long left, + const unsigned long right, + const compare& comp + ) + /*! + requires + - &pivot == &array[right] + - T implements operator[] + - the items in array must be comparable by comp + - left and right are within the bounts of the array + - left < right + ensures + - returns a number called partition_element such that: + - left <= partition_element <= right + - all elements in #array < #array[partition_element] have + indices >= left and < partition_element + - all elements in #array > #array[partition_element] have + indices > partition_element and <= right + !*/ + { + DLIB_ASSERT (&pivot == &array[right] && left < right, + "\tunsigned long qsort_partition()" + << "\n\t&pivot: " << &pivot + << "\n\t&array[right]: " << &array[right] + << "\n\tleft: " << left + << "\n\tright: " << right ); + + exchange(array[(right-left)/2 +left],pivot); + + unsigned long i = left; + for (unsigned long j = left; j < right; ++j) + { + if (comp(array[j] , pivot)) + { + swap(array[i],array[j]); + ++i; + } + } + exchange(array[i],pivot); + + return i; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void qsort_array_main ( + T& array, + const unsigned long left, + const unsigned long right, + unsigned long depth_check, + const compare& comp + ) + /*! + requires + - T implements operator[] + - the items in array must be comparable by comp + - the items in array must be swappable by a global swap() + - left and right are within the bounds of array + i.e. array[left] and array[right] are valid elements + ensures + - for all elements in #array between and including left and right the + ith element is < the i+1 element + - will only recurse about as deep as log(depth_check) calls + - sorts using a quick sort algorithm + !*/ + { + if ( left < right) + { + if (right-left < 30 || depth_check == 0) + { + hsort_array(array,left,right,comp); + } + else + { + // The idea here is to only let quick sort go about log(N) + // calls deep before it kicks into something else. + depth_check >>= 1; + depth_check += (depth_check>>4); + + unsigned long partition_element = + qsort_partition(array,array[right],left,right,comp); + + if (partition_element > 0) + qsort_array_main(array,left,partition_element-1,depth_check,comp); + qsort_array_main(array,partition_element+1,right,depth_check,comp); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void heapify ( + T& array, + const unsigned long start, + const unsigned long end, + unsigned long i, + const compare& comp + ) + /*! + requires + - T implements operator[] + - the items in array must be comparable by comp + - the items in array must be swappable by a global swap() + - start, end, and i are within the bounds of array + i.e. array[start], array[end], and array[i] are valid elements + - start <= i <= end + - array[i/2] is a max heap + - array[i/2+1] is a max heap + - start and end specify the range of the array we are working with. + ensures + - array[i] is now a max heap + !*/ + { + DLIB_ASSERT (start <= i && i <= end, + "\tvoid heapify()" + << "\n\tstart: " << start + << "\n\tend: " << end + << "\n\ti: " << i ); + + bool keep_going = true; + unsigned long left; + unsigned long right; + unsigned long largest; + while (keep_going) + { + keep_going = false; + left = (i<<1)+1-start; + right = left+1; + + if (left <= end && comp(array[i] , array[left])) + largest = left; + else + largest = i; + + if (right <= end && comp(array[largest] , array[right])) + largest = right; + + if (largest != i) + { + exchange(array[i],array[largest]); + i = largest; + keep_going = true; + } + } + } + +// ---------------------------------------------------------------------------------------- + } +// ---------------------------------------------------------------------------------------- + + + template < + typename T + > + inline void qsort_array ( + T& array, + unsigned long left, + unsigned long right + ) + { + using namespace sort_helpers; + qsort_array(array,left,right,comp(array[left])); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void hsort_array ( + T& array, + unsigned long left, + unsigned long right + ) + { + using namespace sort_helpers; + hsort_array(array,left,right,comp(array[left])); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void isort_array ( + T& array, + unsigned long left, + unsigned long right + ) + { + using namespace sort_helpers; + isort_array(array,left,right,comp(array[left])); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void isort_array ( + T& array, + const unsigned long left, + const unsigned long right, + const compare& comp + ) + { + DLIB_ASSERT (left <= right, + "\tvoid isort_array()" + << "\n\tleft: " << left + << "\n\tright: " << right ); + using namespace sort_helpers; + + unsigned long pos; + for (unsigned long i = left+1; i <= right; ++i) + { + // everything from left to i-1 is sorted. + pos = i; + for (unsigned long j = i-1; comp(array[pos] , array[j]); --j) + { + exchange(array[pos],array[j]); + pos = j; + + if (j == left) + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void qsort_array ( + T& array, + const unsigned long left, + const unsigned long right, + const compare& comp + ) + { + DLIB_ASSERT (left <= right, + "\tvoid qsort_array()" + << "\n\tleft: " << left + << "\n\tright: " << right ); + + sort_helpers::qsort_array_main(array,left,right,right-left,comp); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void hsort_array ( + T& array, + const unsigned long left, + const unsigned long right, + const compare& comp + ) + { + DLIB_ASSERT (left <= right, + "\tvoid hsort_array()" + << "\n\tleft: " << left + << "\n\tright: " << right ); + + if (right-left < 30) + { + isort_array(array,left,right,comp); + return; + } + + // turn array into a max heap + for (unsigned long i = left+((right-left)>>1);; --i) + { + sort_helpers::heapify(array,left,right,i,comp); + if (i == left) + break; + } + + // now sort the array + for (unsigned long i = right; i > left;) + { + exchange(array[i],array[left]); + sort_helpers::heapify(array,left,--i,left,comp); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SORt_ + diff --git a/ml/dlib/dlib/sparse_vector.h b/ml/dlib/dlib/sparse_vector.h new file mode 100644 index 000000000..4be6a3adc --- /dev/null +++ b/ml/dlib/dlib/sparse_vector.h @@ -0,0 +1,10 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SPaRSE_VECTOR_Hh_ +#define DLIB_SPaRSE_VECTOR_Hh_ + +#include "svm/sparse_vector.h" + +#endif // DLIB_SPaRSE_VECTOR_Hh_ + + diff --git a/ml/dlib/dlib/sqlite.h b/ml/dlib/dlib/sqlite.h new file mode 100644 index 000000000..b5aadbed2 --- /dev/null +++ b/ml/dlib/dlib/sqlite.h @@ -0,0 +1,11 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SQLiTE_HEADER +#define DLIB_SQLiTE_HEADER + +#include "sqlite/sqlite_tools.h" + +#endif // DLIB_SVm_HEADER + + + diff --git a/ml/dlib/dlib/sqlite/sqlite.h b/ml/dlib/dlib/sqlite/sqlite.h new file mode 100644 index 000000000..7eefbb21a --- /dev/null +++ b/ml/dlib/dlib/sqlite/sqlite.h @@ -0,0 +1,625 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SQLiTE_H_ +#define DLIB_SQLiTE_H_ + +#include "sqlite_abstract.h" + +#include +#include +#include +#include + +#include "../algs.h" +#include +#include "../serialize.h" + + +// -------------------------------------------------------------------------------------------- + +namespace dlib +{ + +// -------------------------------------------------------------------------------------------- + + struct sqlite_error : public error + { + sqlite_error(const std::string& message): error(message) {} + }; + +// -------------------------------------------------------------------------------------------- + + namespace impl + { + struct db_deleter + { + void operator()( + sqlite3* db + )const + { + sqlite3_close(db); + } + }; + } + +// -------------------------------------------------------------------------------------------- + + class database : noncopyable + { + public: + database( + ) + { + } + + database ( + const std::string& file + ) + { + open(file); + } + + bool is_open ( + ) const + { + return db.get() != 0; + } + + void open ( + const std::string& file + ) + { + filename = file; + sqlite3* ptr = 0; + int status = sqlite3_open(file.c_str(), &ptr); + db.reset(ptr, impl::db_deleter()); + if (status != SQLITE_OK) + { + throw sqlite_error(sqlite3_errmsg(db.get())); + } + } + + const std::string& get_database_filename ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_open() == true, + "\t std::string database::get_database_filename()" + << "\n\t The database must be opened before calling this routine." + << "\n\t this: " << this + ); + + return filename; + } + + inline void exec ( + const std::string& sql_statement + ); + + int64 last_insert_rowid ( + ) const + { + return sqlite3_last_insert_rowid(db.get()); + } + + private: + + friend class statement; + + std::string filename; + std::shared_ptr db; + }; + +// -------------------------------------------------------------------------------------------- + + class statement : noncopyable + { + public: + statement ( + database& db_, + const std::string sql_statement + ) : + needs_reset(false), + step_status(SQLITE_DONE), + at_first_step(true), + db(db_.db), + stmt(0), + sql_string(sql_statement) + { + // make sure requires clause is not broken + DLIB_ASSERT(db_.is_open() == true, + "\t statement::statement()" + << "\n\t The database must be opened before calling this routine." + << "\n\t this: " << this + ); + + int status = sqlite3_prepare_v2(db.get(), + sql_string.c_str(), + sql_string.size()+1, + &stmt, + NULL); + + if (status != SQLITE_OK) + { + sqlite3_finalize(stmt); + throw sqlite_error(sqlite3_errmsg(db.get())); + } + if (stmt == 0) + { + throw sqlite_error("Invalid SQL statement"); + } + } + + ~statement( + ) + { + sqlite3_finalize(stmt); + } + + void exec( + ) + { + reset(); + + step_status = sqlite3_step(stmt); + needs_reset = true; + if (step_status != SQLITE_DONE && step_status != SQLITE_ROW) + { + if (step_status == SQLITE_ERROR) + throw sqlite_error(sqlite3_errmsg(db.get())); + else if (step_status == SQLITE_BUSY) + throw sqlite_error("statement::exec() failed. SQLITE_BUSY returned"); + else + throw sqlite_error("statement::exec() failed."); + } + } + + bool move_next ( + ) + { + if (step_status == SQLITE_ROW) + { + if (at_first_step) + { + at_first_step = false; + return true; + } + else + { + step_status = sqlite3_step(stmt); + if (step_status == SQLITE_DONE) + { + return false; + } + else if (step_status == SQLITE_ROW) + { + return true; + } + else + { + throw sqlite_error(sqlite3_errmsg(db.get())); + } + } + } + else + { + return false; + } + } + + unsigned long get_num_columns( + ) const + { + if( (at_first_step==false) && (step_status==SQLITE_ROW)) + { + return sqlite3_column_count(stmt); + } + else + { + return 0; + } + } + + const std::string& get_sql_string ( + ) const + { + return sql_string; + } + + template + typename enable_if_c::is_integer>::type get_column ( + unsigned long idx, + T& item + ) const + { + // unsigned ints won't fit into int all the time so put those into 64bit ints. + if (sizeof(T) < sizeof(int) || (sizeof(T)==sizeof(int) && is_signed_type::value)) + item = get_column_as_int(idx); + else + item = get_column_as_int64(idx); + } + + void get_column(unsigned long idx, std::string& item) const { item = get_column_as_text(idx); } + void get_column(unsigned long idx, float& item ) const { item = get_column_as_double(idx); } + void get_column(unsigned long idx, double& item ) const { item = get_column_as_double(idx); } + void get_column(unsigned long idx, long double& item) const { item = get_column_as_double(idx); } + + template + typename disable_if_c::is_integer>::type get_column ( + unsigned long idx, + T& item + ) const + { + get_column_as_object(idx, item); + } + + const std::vector get_column_as_blob ( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < get_num_columns(), + "\t std::vector statement::get_column_as_blob()" + << "\n\t Invalid column index." + << "\n\t idx: " << idx + << "\n\t this: " << this + ); + + const char* data = static_cast(sqlite3_column_blob(stmt, idx)); + const int size = sqlite3_column_bytes(stmt, idx); + + return std::vector(data, data+size); + } + + template + void get_column_as_object ( + unsigned long idx, + T& item + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < get_num_columns(), + "\t void statement::get_column_as_object()" + << "\n\t Invalid column index." + << "\n\t idx: " << idx + << "\n\t this: " << this + ); + + const char* data = static_cast(sqlite3_column_blob(stmt, idx)); + const int size = sqlite3_column_bytes(stmt, idx); + std::istringstream sin(std::string(data,size)); + deserialize(item, sin); + } + + const std::string get_column_as_text ( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < get_num_columns(), + "\t std::string statement::get_column_as_text()" + << "\n\t Invalid column index." + << "\n\t idx: " << idx + << "\n\t this: " << this + ); + + const char* data = reinterpret_cast(sqlite3_column_text(stmt, idx)); + if (data != 0) + return std::string(data); + else + return std::string(); + } + + double get_column_as_double ( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < get_num_columns(), + "\t double statement::get_column_as_double()" + << "\n\t Invalid column index." + << "\n\t idx: " << idx + << "\n\t this: " << this + ); + + return sqlite3_column_double(stmt, idx); + } + + int get_column_as_int ( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < get_num_columns(), + "\t int statement::get_column_as_int()" + << "\n\t Invalid column index." + << "\n\t idx: " << idx + << "\n\t this: " << this + ); + + return sqlite3_column_int(stmt, idx); + } + + int64 get_column_as_int64 ( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < get_num_columns(), + "\t int64 statement::get_column_as_int64()" + << "\n\t Invalid column index." + << "\n\t idx: " << idx + << "\n\t this: " << this + ); + + return sqlite3_column_int64(stmt, idx); + } + + const std::string get_column_name ( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < get_num_columns(), + "\t std::string statement::get_column_name()" + << "\n\t Invalid column index." + << "\n\t idx: " << idx + << "\n\t this: " << this + ); + + return std::string(sqlite3_column_name(stmt,idx)); + } + + unsigned long get_max_parameter_id ( + ) const + { + return sqlite3_limit(db.get(), SQLITE_LIMIT_VARIABLE_NUMBER, -1); + } + + unsigned long get_parameter_id ( + const std::string& name + ) const + { + return sqlite3_bind_parameter_index(stmt, name.c_str()); + } + + template + typename enable_if_c::is_integer>::type bind ( + unsigned long idx, + const T& item + ) + { + // unsigned ints won't fit into int all the time so put those into 64bit ints. + if (sizeof(T) < sizeof(int) || (sizeof(T)==sizeof(int) && is_signed_type::value)) + bind_int(idx, item); + else + bind_int64(idx, item); + } + + void bind(unsigned long idx, const std::string& item) { bind_text(idx, item); } + void bind(unsigned long idx, const float& item ) { bind_double(idx, item); } + void bind(unsigned long idx, const double& item ) { bind_double(idx, item); } + void bind(unsigned long idx, const long double& item) { bind_double(idx, item); } + + template + typename disable_if_c::is_integer>::type bind ( + unsigned long idx, + const T& item + ) + { + bind_object(idx, item); + } + + void bind_blob ( + unsigned long parameter_id, + const std::vector& item + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), + "\t void statement::bind_blob()" + << "\n\t Invalid parameter id." + << "\n\t parameter_id: " << parameter_id + << "\n\t get_max_parameter_id(): " << get_max_parameter_id() + << "\n\t this: " << this + ); + + reset(); + int status = sqlite3_bind_blob(stmt, parameter_id, &item[0], item.size(), SQLITE_TRANSIENT); + + if (status != SQLITE_OK) + { + throw sqlite_error(sqlite3_errmsg(db.get())); + } + } + + template + void bind_object ( + unsigned long parameter_id, + const T& item + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), + "\t void statement::bind_object()" + << "\n\t Invalid parameter id." + << "\n\t parameter_id: " << parameter_id + << "\n\t get_max_parameter_id(): " << get_max_parameter_id() + << "\n\t this: " << this + ); + + reset(); + std::ostringstream sout; + serialize(item, sout); + const std::string& str = sout.str(); + int status = sqlite3_bind_blob(stmt, parameter_id, str.data(), str.size(), SQLITE_TRANSIENT); + + if (status != SQLITE_OK) + { + throw sqlite_error(sqlite3_errmsg(db.get())); + } + } + + void bind_double ( + unsigned long parameter_id, + const double& item + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), + "\t void statement::bind_double()" + << "\n\t Invalid parameter id." + << "\n\t parameter_id: " << parameter_id + << "\n\t get_max_parameter_id(): " << get_max_parameter_id() + << "\n\t this: " << this + ); + + reset(); + int status = sqlite3_bind_double(stmt, parameter_id, item); + + if (status != SQLITE_OK) + { + throw sqlite_error(sqlite3_errmsg(db.get())); + } + } + + void bind_int ( + unsigned long parameter_id, + const int& item + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), + "\t void statement::bind_int()" + << "\n\t Invalid parameter id." + << "\n\t parameter_id: " << parameter_id + << "\n\t get_max_parameter_id(): " << get_max_parameter_id() + << "\n\t this: " << this + ); + + reset(); + int status = sqlite3_bind_int(stmt, parameter_id, item); + + if (status != SQLITE_OK) + { + throw sqlite_error(sqlite3_errmsg(db.get())); + } + } + + void bind_int64 ( + unsigned long parameter_id, + const int64& item + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), + "\t void statement::bind_int64()" + << "\n\t Invalid parameter id." + << "\n\t parameter_id: " << parameter_id + << "\n\t get_max_parameter_id(): " << get_max_parameter_id() + << "\n\t this: " << this + ); + + reset(); + int status = sqlite3_bind_int64(stmt, parameter_id, item); + + if (status != SQLITE_OK) + { + throw sqlite_error(sqlite3_errmsg(db.get())); + } + } + + void bind_null ( + unsigned long parameter_id + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), + "\t void statement::bind_null()" + << "\n\t Invalid parameter id." + << "\n\t parameter_id: " << parameter_id + << "\n\t get_max_parameter_id(): " << get_max_parameter_id() + << "\n\t this: " << this + ); + + reset(); + int status = sqlite3_bind_null(stmt, parameter_id); + + if (status != SQLITE_OK) + { + throw sqlite_error(sqlite3_errmsg(db.get())); + } + } + + void bind_text ( + unsigned long parameter_id, + const std::string& item + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), + "\t void statement::bind_text()" + << "\n\t Invalid parameter id." + << "\n\t parameter_id: " << parameter_id + << "\n\t get_max_parameter_id(): " << get_max_parameter_id() + << "\n\t this: " << this + ); + + reset(); + int status = sqlite3_bind_text(stmt, parameter_id, item.c_str(), -1, SQLITE_TRANSIENT); + + if (status != SQLITE_OK) + { + throw sqlite_error(sqlite3_errmsg(db.get())); + } + } + + private: + + void reset() + { + if (needs_reset) + { + if (sqlite3_reset(stmt) != SQLITE_OK) + { + step_status = SQLITE_DONE; + throw sqlite_error(sqlite3_errmsg(db.get())); + } + needs_reset = false; + step_status = SQLITE_DONE; + at_first_step = true; + } + } + + bool needs_reset; // true if sqlite3_step() has been called more recently than sqlite3_reset() + int step_status; + bool at_first_step; + + std::shared_ptr db; + sqlite3_stmt* stmt; + std::string sql_string; + }; + +// -------------------------------------------------------------------------------------------- + + void database:: + exec ( + const std::string& sql_statement + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_open() == true, + "\t void database::exec()" + << "\n\t The database must be opened before calling this routine." + << "\n\t this: " << this + ); + + statement(*this, sql_statement).exec(); + } + +// -------------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SQLiTE_H_ + diff --git a/ml/dlib/dlib/sqlite/sqlite_abstract.h b/ml/dlib/dlib/sqlite/sqlite_abstract.h new file mode 100644 index 000000000..7372162d8 --- /dev/null +++ b/ml/dlib/dlib/sqlite/sqlite_abstract.h @@ -0,0 +1,506 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SQLiTE_ABSTRACT_H_ +#ifdef DLIB_SQLiTE_ABSTRACT_H_ + + +#include +#include +#include "../algs.h" +#include +#include "../smart_pointers.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct sqlite_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception object used by the SQLite tools to indicate + that an error has occurred. Any of the functions defined in this + file might throw this exception. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class database : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a C++ wrapper around a SQLite database connection + handle and therefore represents a SQLite database file. + + Note that this wrapper is targeted at SQLite Version 3. + + Note also that whenever SQLite indicates an error has occurred + this object will throw the sqlite_error exception. + !*/ + + public: + database( + ); + /*! + ensures + - #is_open() == false + !*/ + + database ( + const std::string& file + ); + /*! + ensures + - opens the indicated database file or creates a new + database with the given name if one doesn't already exist. + - #get_database_filename() == file + - #is_open() == true + !*/ + + ~database ( + ); + /*! + ensures + - safely disposes of any SQLite database connection. If + any statement objects still exist which reference this database + then the SQLite database connection won't be fully closed + until those statement objects are also destroyed. This allows + for any destruction order between database and statement objects. + !*/ + + void open ( + const std::string& file + ); + /*! + ensures + - opens the indicated database file or creates a new + database with the given name if one doesn't already exist. + - #get_database_filename() == file + - #is_open() == true + - safely disposes of any previous SQLite database connection. If + any statement objects still exist which reference this database + then the SQLite database connection won't be fully closed + until those statement objects are also destroyed. + !*/ + + bool is_open ( + ) const; + /*! + ensures + - if (this object has an open connection to a SQLite database) then + - returns true + - else + - returns false + !*/ + + const std::string& get_database_filename ( + ) const; + /*! + requires + - is_open() == true + ensures + - returns the name of the SQLite database file this object + currently has open. + !*/ + + void exec ( + const std::string& sql_statement + ); + /*! + requires + - is_open() == true + ensures + - executes the supplied SQL statement against this database + !*/ + + int64 last_insert_rowid ( + ) const; + /*! + requires + - is_open() == true + ensures + - Each element in a database table has a rowid which uniquely identifies + it. Therefore, this routine returns the rowid of the most recent + successful INSERT into the database via this database instance. + - If an INSERT has not been performed on the current database instance then + the return value is 0. This is true even if the database is not empty. + - See the sqlite documentation for the full details on how this function + behaves: http://www.sqlite.org/c3ref/last_insert_rowid.html + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class statement : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a SQL statement which can be executed + against a database object. In particular, this object is a + C++ wrapper around a SQLite prepared statement. + + + Note that whenever SQLite indicates an error has occurred this + object will throw the sqlite_error exception. + + BINDABLE SQL PARAMETERS + Sometimes you want to execute a bunch of very similar SQL statements. + For example, you might need to execute many insert statements where each + statement changes only the value of a field. Since it is somewhat + costly to construct a statement object for each SQL operation, SQLite + supports defining bindable parameters for a statement object. This allows + you to reuse the same statement object. + + Therefore, in SQL statements used with SQLite, wherever it is valid to + include a string literal, one can use a parameter in one of the following + forms: + + ? + ?NNN + :AAA + $AAA + @AAA + + In the examples above, NNN is an integer value and AAA is an identifier. A + parameter initially has a value of NULL. You can use the bind_*() routines + to attach values to the parameters. Each call to a bind_*() routine overrides + prior bindings on the same parameter. + + Each SQL parameter has a numeric ID which is used to reference it when invoking + a bind_*() routine. The leftmost SQL parameter in a statement has an index of 1, + the next parameter has an index of 2, and so on, except when the following rules + apply. When the same named SQL parameter is used more than once, second and + subsequent occurrences have the same index as the first occurrence. The index + for named parameters can be looked up using the get_parameter_id() method if desired. + The index for "?NNN" parameters is the value of NNN. The NNN value must be between + 1 and get_max_parameter_id(). + !*/ + + public: + statement ( + database& db, + const std::string sql_statement + ); + /*! + requires + - db.is_open() == true + ensures + - The given SQL statement can be executed against the given + database by calling exec(). + - #get_sql_string() == sql_statement + !*/ + + ~statement( + ); + /*! + ensures + - any resources associated with this object have been freed. + !*/ + + const std::string& get_sql_string ( + ) const; + /*! + ensures + - returns a copy of the SQL statement used to create this statement object. + !*/ + + void exec( + ); + /*! + ensures + - #get_num_columns() == 0 + - executes the SQL statement get_sql_string() against the database + given to this object's constructor. + - If this was a select statement then you can obtain the resulting + rows by calling move_next() and using the get_column_as_*() member + functions. + !*/ + + // ---------------------------- + + bool move_next ( + ); + /*! + ensures + - if (there is a result row for this query) then + - #get_num_columns() == the number of columns in the result row. + - The get_column_as_*() routines can be used to access the elements + of the row data. + - returns true + - else + - returns false + - #get_num_columns() == 0 + !*/ + + unsigned long get_num_columns( + ) const; + /*! + ensures + - returns the number of columns of data available via the get_column_as_*() + routines. + !*/ + + template < + typename T + > + void get_column ( + unsigned long idx, + T& item + ) const; + /*! + requires + - idx < get_num_columns() + ensures + - This function automatically selects how to extract the column data based + on the type of item given. In particular: + - if (T is a 32bit or smaller built in integer type) then + - #item == get_column_as_int(idx) + - else if (T is a 64bit built in integer type) then + - #item == get_column_as_int64(idx) + - else if (T is float, double, or long double) then + - #item == get_column_as_double(idx) + - else if (T is std::string) then + - #item == get_column_as_text(idx) + - else + - invokes: get_column_as_object(idx, item) + !*/ + + const std::vector get_column_as_blob ( + unsigned long idx + ) const; + /*! + requires + - idx < get_num_columns() + ensures + - returns the contents of the idx-th column as a binary BLOB. + !*/ + + template < + typename T + > + void get_column_as_object ( + unsigned long idx, + T& item + ) const; + /*! + requires + - idx < get_num_columns() + - item is deserializable + (i.e. Calling deserialize(item, some_input_stream) reads an item + of type T from the some_input_stream stream) + ensures + - gets the contents of the idx-th column as a binary BLOB and then + deserializes it into item. + !*/ + + const std::string get_column_as_text ( + unsigned long idx + ) const; + /*! + requires + - idx < get_num_columns() + ensures + - returns the contents of the idx-th column as a text string. + !*/ + + double get_column_as_double ( + unsigned long idx + ) const; + /*! + requires + - idx < get_num_columns() + ensures + - returns the contents of the idx-th column as a double. + !*/ + + int get_column_as_int ( + unsigned long idx + ) const; + /*! + requires + - idx < get_num_columns() + ensures + - returns the contents of the idx-th column as an int. + !*/ + + int64 get_column_as_int64 ( + unsigned long idx + ) const; + /*! + requires + - idx < get_num_columns() + ensures + - returns the contents of the idx-th column as a 64bit int. + !*/ + + const std::string get_column_name ( + unsigned long idx + ) const; + /*! + requires + - idx < get_num_columns() + ensures + - returns the name of the idx-th column. In particular: + The name of a result column is the value of the "AS" clause for + that column, if there is an AS clause. If there is no AS clause + then the name of the column is unspecified and may change from + one release of SQLite to the next. + !*/ + + // ---------------------------- + + unsigned long get_max_parameter_id ( + ) const; + /*! + ensures + - returns the max parameter ID value which can be used with the + bind_() member functions defined below. + - In SQLite, the default value of this limit is usually 999. + !*/ + + unsigned long get_parameter_id ( + const std::string& name + ) const; + /*! + ensures + - if (This SQL statement contains a SQL parameter with the given name) then + - returns the parameter_id number which can be used in the bind_*() + member functions defined below. + - else + - returns 0 + !*/ + + template < + typename T + > + void bind ( + unsigned long parameter_id, + const T& item + ) const; + /*! + requires + - 1 <= parameter_id <= get_max_parameter_id() + ensures + - This function automatically selects how to bind item to a statement based + on the type of item given. In particular: + - if (T is a 32bit or smaller built in integer type) then + - invokes: bind_int(parameter_id, item) + - else if (T is a 64bit built in integer type) then + - invokes: bind_int64(parameter_id, item) + - else if (T is float, double, or long double) then + - invokes: bind_double(parameter_id, item) + - else if (T is std::string) then + - invokes: bind_text(parameter_id, item) + - else + - invokes: bind_object(parameter_id, item) + !*/ + + void bind_blob ( + unsigned long parameter_id, + const std::vector& item + ); + /*! + requires + - 1 <= parameter_id <= get_max_parameter_id() + ensures + - #get_num_columns() == 0 + - binds the value of item into the SQL parameter indicated by + parameter_id. + !*/ + + template < + typename T + > + void bind_object ( + unsigned long parameter_id, + const T& item + ); + /*! + requires + - 1 <= parameter_id <= get_max_parameter_id() + - item is serializable + (i.e. Calling serialize(item, some_output_stream) writes an item + of type T to the some_output_stream stream) + ensures + - #get_num_columns() == 0 + - binds the value of item into the SQL parameter indicated by + parameter_id. This is performed by serializing item and then + binding it as a binary BLOB. + !*/ + + void bind_double ( + unsigned long parameter_id, + const double& item + ); + /*! + requires + - 1 <= parameter_id <= get_max_parameter_id() + ensures + - #get_num_columns() == 0 + - binds the value of item into the SQL parameter indicated by + parameter_id. + !*/ + + void bind_int ( + unsigned long parameter_id, + const int& item + ); + /*! + requires + - 1 <= parameter_id <= get_max_parameter_id() + ensures + - #get_num_columns() == 0 + - binds the value of item into the SQL parameter indicated by + parameter_id. + !*/ + + void bind_int64 ( + unsigned long parameter_id, + const int64& item + ); + /*! + requires + - 1 <= parameter_id <= get_max_parameter_id() + ensures + - #get_num_columns() == 0 + - binds the value of item into the SQL parameter indicated by + parameter_id. + !*/ + + void bind_null ( + unsigned long parameter_id + ); + /*! + requires + - 1 <= parameter_id <= get_max_parameter_id() + ensures + - #get_num_columns() == 0 + - binds a NULL to the SQL parameter indicated by parameter_id. + !*/ + + void bind_text ( + unsigned long parameter_id, + const std::string& item + ); + /*! + requires + - 1 <= parameter_id <= get_max_parameter_id() + ensures + - #get_num_columns() == 0 + - binds the value of item into the SQL parameter indicated by + parameter_id. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SQLiTE_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/sqlite/sqlite_tools.h b/ml/dlib/dlib/sqlite/sqlite_tools.h new file mode 100644 index 000000000..062a6b2c8 --- /dev/null +++ b/ml/dlib/dlib/sqlite/sqlite_tools.h @@ -0,0 +1,189 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SQLiTE_TOOLS_H_ +#define DLIB_SQLiTE_TOOLS_H_ + + +#include "sqlite_tools_abstract.h" +#include "sqlite.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + class transaction : noncopyable + { + public: + transaction ( + database& db_ + ) : + db(db_), + committed(false) + { + db.exec("begin transaction"); + } + + void commit () + { + if (!committed) + { + committed = true; + db.exec("commit"); + } + } + + ~transaction() throw (std::exception) + { + if (!committed) + db.exec("rollback"); + } + + private: + database& db; + bool committed; + + }; + +// ---------------------------------------------------------------------------------------- + + + template < + typename T + > + void query_object ( + database& db, + const std::string& query, + T& item + ) + { + statement st(db, query); + st.exec(); + if (st.move_next() && st.get_num_columns() == 1) + { + st.get_column_as_object(0,item); + if (st.move_next()) + throw sqlite_error("query doesn't result in exactly 1 element"); + } + else + { + throw sqlite_error("query doesn't result in exactly 1 element"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline std::string query_text ( + database& db, + const std::string& query + ) + { + statement st(db, query); + st.exec(); + if (st.move_next() && st.get_num_columns() == 1) + { + const std::string& temp = st.get_column_as_text(0); + if (st.move_next()) + throw sqlite_error("query doesn't result in exactly 1 element"); + return temp; + } + else + { + throw sqlite_error("query doesn't result in exactly 1 element"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline double query_double ( + database& db, + const std::string& query + ) + { + statement st(db, query); + st.exec(); + if (st.move_next() && st.get_num_columns() == 1) + { + double temp = st.get_column_as_double(0); + if (st.move_next()) + throw sqlite_error("query doesn't result in exactly 1 element"); + return temp; + } + else + { + throw sqlite_error("query doesn't result in exactly 1 element"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline int query_int ( + database& db, + const std::string& query + ) + { + statement st(db, query); + st.exec(); + if (st.move_next() && st.get_num_columns() == 1) + { + int temp = st.get_column_as_int(0); + if (st.move_next()) + throw sqlite_error("query doesn't result in exactly 1 element"); + return temp; + } + else + { + throw sqlite_error("query doesn't result in exactly 1 element"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline int64 query_int64 ( + database& db, + const std::string& query + ) + { + statement st(db, query); + st.exec(); + if (st.move_next() && st.get_num_columns() == 1) + { + int64 temp = st.get_column_as_int64(0); + if (st.move_next()) + throw sqlite_error("query doesn't result in exactly 1 element"); + return temp; + } + else + { + throw sqlite_error("query doesn't result in exactly 1 element"); + } + } + +// ---------------------------------------------------------------------------------------- + + inline const std::vector query_blob ( + database& db, + const std::string& query + ) + { + statement st(db, query); + st.exec(); + if (st.move_next() && st.get_num_columns() == 1) + { + const std::vector& temp = st.get_column_as_blob(0); + if (st.move_next()) + throw sqlite_error("query doesn't result in exactly 1 element"); + return temp; + } + else + { + throw sqlite_error("query doesn't result in exactly 1 element"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SQLiTE_TOOLS_H_ + diff --git a/ml/dlib/dlib/sqlite/sqlite_tools_abstract.h b/ml/dlib/dlib/sqlite/sqlite_tools_abstract.h new file mode 100644 index 000000000..c13a09265 --- /dev/null +++ b/ml/dlib/dlib/sqlite/sqlite_tools_abstract.h @@ -0,0 +1,164 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SQLiTE_TOOLS_ABSTRACT_H_ +#ifdef DLIB_SQLiTE_TOOLS_ABSTRACT_H_ + + +#include "sqlite_abstract.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + class transaction : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for creating exception safe + database transactions. + !*/ + + public: + transaction ( + database& db + ); + /*! + ensures + - Begins a database transaction which will be rolled back + if commit() isn't called eventually. + - In particular, performs: db.exec("begin transaction"); + !*/ + + void commit ( + ); + /*! + ensures + - if (commit() hasn't already been called) then + - Commits all changes made during this database transaction. + - In particular, performs: db.exec("commit"); + - else + - does nothing + !*/ + + ~transaction( + ); + /*! + ensures + - if (commit() was never called) then + - rolls back any changes made to the database during this transaction. + - In particular, performs: db.exec("rollback"); + - else + - does nothing + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void query_object ( + database& db, + const std::string& query, + T& item + ); + /*! + ensures + - executes the given SQL query against db. If the query results in a + single row and column being returned then the data in the column is + interpreted as a binary BLOB and deserialized into item. + throws + - sqlite_error or serialization_error if an error occurs which prevents + this operation from succeeding. + + !*/ + +// ---------------------------------------------------------------------------------------- + + std::string query_text ( + database& db, + const std::string& query + ); + /*! + ensures + - executes the given SQL query against db. If the query results in a + single row and column being returned then the data in the column is + converted to text and returned. + throws + - sqlite_error if an error occurs which prevents this operation from + succeeding. + !*/ + +// ---------------------------------------------------------------------------------------- + + double query_double ( + database& db, + const std::string& query + ); + /*! + ensures + - executes the given SQL query against db. If the query results in a + single row and column being returned then the data in the column is + converted to a double and returned. + throws + - sqlite_error if an error occurs which prevents this operation from + succeeding. + !*/ + +// ---------------------------------------------------------------------------------------- + + int query_int ( + database& db, + const std::string& query + ); + /*! + ensures + - executes the given SQL query against db. If the query results in a + single row and column being returned then the data in the column is + converted to an int and returned. + throws + - sqlite_error if an error occurs which prevents this operation from + succeeding. + !*/ + +// ---------------------------------------------------------------------------------------- + + int64 query_int64 ( + database& db, + const std::string& query + ); + /*! + ensures + - executes the given SQL query against db. If the query results in a + single row and column being returned then the data in the column is + converted to an int64 and returned. + throws + - sqlite_error if an error occurs which prevents this operation from + succeeding. + !*/ + +// ---------------------------------------------------------------------------------------- + + const std::vector query_blob ( + database& db, + const std::string& query + ); + /*! + ensures + - executes the given SQL query against db. If the query results in a + single row and column being returned then the data in the column is + returned as a binary BLOB. + throws + - sqlite_error if an error occurs which prevents this operation from + succeeding. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SQLiTE_TOOLS_H_ + + diff --git a/ml/dlib/dlib/sstream b/ml/dlib/dlib/sstream new file mode 100644 index 000000000..eb0e59e41 --- /dev/null +++ b/ml/dlib/dlib/sstream @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/stack.h b/ml/dlib/dlib/stack.h new file mode 100644 index 000000000..58f2c5a6c --- /dev/null +++ b/ml/dlib/dlib/stack.h @@ -0,0 +1,34 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STACk_ +#define DLIB_STACk_ + +#include "stack/stack_kernel_1.h" +#include "stack/stack_kernel_c.h" +#include "algs.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class stack + { + stack() {} + public: + + //----------- kernels --------------- + + // kernel_1a + typedef stack_kernel_1 + kernel_1a; + typedef stack_kernel_c + kernel_1a_c; + + }; +} + +#endif // DLIB_STACk_ + diff --git a/ml/dlib/dlib/stack/stack_kernel_1.h b/ml/dlib/dlib/stack/stack_kernel_1.h new file mode 100644 index 000000000..427d65183 --- /dev/null +++ b/ml/dlib/dlib/stack/stack_kernel_1.h @@ -0,0 +1,504 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STACK_KERNEl_1_ +#define DLIB_STACK_KERNEl_1_ + +#include "stack_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class stack_kernel_1 : public enumerable, + public remover + { + + /*! + INITIAL VALUE + stack_size == 0 + top == 0 + current_element == 0 + _at_start == true + + + CONVENTION + at_start() == _at_start + current_element_valid() == (current_element != 0) + if (current_element != 0) then + element() == current_element->item + + stack_size == the number of elements in the stack. + Each node points to the next node to be poped off the stack. + The last node in the list has its next pointer is set to 0. + + if (size == 0) + { + top == 0 + } + else + { + top == pointer to the last element added to the stack + } + !*/ + + struct node + { + node* next; + T item; + }; + + public: + + typedef T type; + typedef mem_manager mem_manager_type; + + stack_kernel_1( + ): + top(0), + stack_size(0), + current_element(0), + _at_start(true) + {} + + virtual ~stack_kernel_1( + ); + + inline void clear( + ); + + inline void push( + T& item + ); + + void pop( + T& item + ); + + T& current( + ); + + const T& current( + ) const; + + inline void swap ( + stack_kernel_1& item + ); + + // functions from the remover interface + inline void remove_any ( + T& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + inline const T& element ( + ) const; + + inline T& element ( + ); + + bool move_next ( + ) const; + + private: + + void delete_elements_in_stack( + node*& top + ); + /*! + requires + - top points to the top of the stack + ensures + - all memory has been freed + - #top = 0 + !*/ + + + // data members + typename mem_manager::template rebind::other pool; + node* top; + unsigned long stack_size; + mutable node* current_element; + mutable bool _at_start; + + + // restricted functions + stack_kernel_1(stack_kernel_1&); // copy constructor + stack_kernel_1& operator=(stack_kernel_1&); // assignment operator + + }; + + + template < + typename T, + typename mem_manager + > + inline void swap ( + stack_kernel_1& a, + stack_kernel_1& b + ) { a.swap(b); } + + template < + typename T, + typename mem_manager + > + void deserialize ( + stack_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + T temp = T(); + stack_kernel_1 temp_stack; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(temp,in); + temp_stack.push(temp); + } + while (temp_stack.size() > 0) + { + temp_stack.pop(temp); + item.push(temp); + } + } + catch (serialization_error e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type stack_kernel_1"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + stack_kernel_1:: + ~stack_kernel_1( + ) + { + delete_elements_in_stack(top); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void stack_kernel_1:: + clear( + ) + { + if (stack_size != 0) + { + delete_elements_in_stack(top); + stack_size = 0; + } + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& stack_kernel_1:: + current( + ) + { + return top->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& stack_kernel_1:: + current( + ) const + { + return top->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void stack_kernel_1:: + swap( + stack_kernel_1& item + ) + { + pool.swap(item.pool); + + // declare temp variables + node* top_temp; + unsigned long stack_size_temp; + + // swap stack_size variables + stack_size_temp = item.stack_size; + item.stack_size = stack_size; + stack_size = stack_size_temp; + + // swap top pointers + top_temp = item.top; + item.top = top; + top = top_temp; + + exchange(current_element,item.current_element); + exchange(_at_start,item._at_start); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void stack_kernel_1:: + push( + T& item + ) + { + // allocate memory for new node + node* new_node = pool.allocate(); + + // swap item into new_node + exchange(new_node->item,item); + + // put new_node into stack + new_node->next = top; + top = new_node; + ++stack_size; + + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void stack_kernel_1:: + pop( + T& item + ) + { + node* old_node = top; + top = top->next; + + // swap the item from the stack into item + exchange(old_node->item,item); + + // free the memory + pool.deallocate(old_node); + --stack_size; + + reset(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void stack_kernel_1:: + delete_elements_in_stack( + node*& top + ) + { + node* temp; + while (top != 0) + { + temp = top->next; + pool.deallocate(top); + top = temp; + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t stack_kernel_1:: + size ( + ) const + { + return stack_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool stack_kernel_1:: + at_start ( + ) const + { + return _at_start; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void stack_kernel_1:: + reset ( + ) const + { + _at_start = true; + current_element = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool stack_kernel_1:: + current_element_valid ( + ) const + { + return current_element != 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& stack_kernel_1:: + element ( + ) const + { + return current_element->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& stack_kernel_1:: + element ( + ) + { + return current_element->item; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool stack_kernel_1:: + move_next ( + ) const + { + if (!_at_start) + { + if (current_element) + { + current_element = current_element->next; + if (current_element) + return true; + else + return false; + } + else + { + return false; + } + } + else + { + _at_start = false; + if (stack_size) + { + current_element = top; + return true; + } + else + { + return false; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // remover function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void stack_kernel_1:: + remove_any ( + T& item + ) + { + pop(item); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STACK_KERNEl_1_ + diff --git a/ml/dlib/dlib/stack/stack_kernel_abstract.h b/ml/dlib/dlib/stack/stack_kernel_abstract.h new file mode 100644 index 000000000..d86cc0629 --- /dev/null +++ b/ml/dlib/dlib/stack/stack_kernel_abstract.h @@ -0,0 +1,180 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STACK_KERNEl_ABSTRACT_ +#ifdef DLIB_STACK_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class stack : public enumerable, + public remover + { + + /*! + REQUIREMENTS ON T + T must be swappable by a global swap() and + T must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + swap() and current() functions do not invalidate pointers + or references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements in the stack in the + same order they would be removed in by repeated calls to pop(). + (e.g. current() would be the first element enumerated) + + WHAT THIS OBJECT REPRESENTS + This is a last in first out stack containing items of type T. + + e.g. if the stack is {b,c,d,e} then a is put in + the stack becomes {a,b,c,d,e} and then pop takes a back out + returning the stack to {b,c,d,e} + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + !*/ + + public: + + typedef T type; + typedef mem_manager mem_manager_type; + + stack ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + virtual ~stack ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void push ( + T& item + ); + /*! + ensures + - item has been swapped onto the top of the stack + - #current() == item + - #item has an initial value for its type + - #size() == size() + 1 + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + if push() throws then it has no effect + !*/ + + void pop ( + T& item + ); + /*! + requires + - size() != 0 + ensures + - #size() == size() - 1 + - #item == current() + i.e. the top element of *this has been removed and swapped + into #item + - #at_start() == true + !*/ + + T& current ( + ); + /*! + requires + - size() != 0 + ensures + - returns a const reference to the element at the top of *this + !*/ + + const T& current ( + ) const; + /*! + requires + - size() != 0 + ensures + - returns a non-const reference to the element at the top of *this + !*/ + + void swap ( + stack& item + ); + /*! + ensures + - swaps *this and item + !*/ + + + private: + + // restricted functions + stack(stack&); // copy constructor + stack& operator=(stack&); // assignment operator + + }; + + + template < + typename T, + typename mem_manager + > + inline void swap ( + stack& a, + stack& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename mem_manager + > + void deserialize ( + stack& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_STACK_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/stack/stack_kernel_c.h b/ml/dlib/dlib/stack/stack_kernel_c.h new file mode 100644 index 000000000..ec8642a40 --- /dev/null +++ b/ml/dlib/dlib/stack/stack_kernel_c.h @@ -0,0 +1,189 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STACK_KERNEl_C_ +#define DLIB_STACK_KERNEl_C_ + +#include "stack_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename stack_base + > + class stack_kernel_c : public stack_base + { + typedef typename stack_base::type T; + public: + void pop( + T& item + ); + + T& current( + ); + + const T& current( + ) const; + + const T& element( + ) const; + + T& element( + ); + + void remove_any ( + T& item + ); + + }; + + + template < + typename stack_base + > + inline void swap ( + stack_kernel_c& a, + stack_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename stack_base + > + void stack_kernel_c:: + pop( + T& item + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(this->size() != 0, + "\tvoid stack::pop" + << "\n\tsize of stack should not be zero" + << "\n\tthis: " << this + ); + + // call the real function + stack_base::pop(item); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack_base + > + const typename stack_base::type& stack_kernel_c:: + current( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->size() != 0, + "\tconst T& stack::current" + << "\n\tsize of stack should not be zero" + << "\n\tthis: " << this + ); + + // call the real function + return stack_base::current(); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack_base + > + typename stack_base::type& stack_kernel_c:: + current( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->size() != 0, + "\tT& stack::current" + << "\n\tsize of stack should not be zero" + << "\n\tthis: " << this + ); + + // call the real function + return stack_base::current(); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack_base + > + typename stack_base::type& stack_kernel_c:: + element( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid(), + "\tT& stack::element" + << "\n\tThe current element must be valid if you are to access it." + << "\n\tthis: " << this + ); + + // call the real function + return stack_base::element(); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack_base + > + const typename stack_base::type& stack_kernel_c:: + element( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid(), + "\tconst T& stack::element" + << "\n\tThe current element must be valid if you are to access it." + << "\n\tthis: " << this + ); + + // call the real function + return stack_base::element(); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename stack_base + > + void stack_kernel_c:: + remove_any ( + T& item + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (this->size() > 0), + "\tvoid stack::remove_any" + << "\n\tsize() must be greater than zero if something is going to be removed" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + // call the real function + stack_base::remove_any(item); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STACK_KERNEl_C_ + diff --git a/ml/dlib/dlib/stack_trace.cpp b/ml/dlib/dlib/stack_trace.cpp new file mode 100644 index 000000000..0a6ff8ee6 --- /dev/null +++ b/ml/dlib/dlib/stack_trace.cpp @@ -0,0 +1,91 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STACK_TRACE_CPp_ +#define DLIB_STACK_TRACE_CPp_ + +#if defined(DLIB_ENABLE_STACK_TRACE) && !defined(NO_MAKEFILE) + +#include +#include +#include "stack_trace.h" +#include "stack.h" +#include "memory_manager.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace + { + struct stack_tracer_data + { + stack_tracer_data( + ) : funct_name(0), + file_name(0), + line_number(0) + {} + const char* funct_name; + const char* file_name; + int line_number; + }; + + using stack_tracer_stack_type = stack::kernel_2a>::kernel_1a; + + stack_tracer_stack_type& get_dlib_stack_trace_stack() + { + thread_local stack_tracer_stack_type a; + return a; + } + } + +// ---------------------------------------------------------------------------------------- + + stack_tracer:: + stack_tracer ( + const char* funct_name, + const char* file_name, + const int line_number + ) + { + stack_tracer_data data; + data.funct_name = funct_name; + data.file_name = file_name; + data.line_number = line_number; + + // pop the info onto the function stack trace + get_dlib_stack_trace_stack().push(data); + } + +// ---------------------------------------------------------------------------------------- + + stack_tracer:: + ~stack_tracer() + { + stack_tracer_data temp; + get_dlib_stack_trace_stack().pop(temp); + } + +// ---------------------------------------------------------------------------------------- + + const std::string get_stack_trace() + { + std::ostringstream sout; + auto& stack = get_dlib_stack_trace_stack(); + stack.reset(); + while (stack.move_next()) + { + stack_tracer_data data = stack.element(); + sout << data.file_name << ":" << data.line_number << "\n " << data.funct_name << "\n"; + } + return sout.str(); + } + +// ---------------------------------------------------------------------------------------- + +} +#endif + +#endif // DLIB_STACK_TRACE_CPp_ + + diff --git a/ml/dlib/dlib/stack_trace.h b/ml/dlib/dlib/stack_trace.h new file mode 100644 index 000000000..aacbeb782 --- /dev/null +++ b/ml/dlib/dlib/stack_trace.h @@ -0,0 +1,118 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STACK_TRACe_ +#define DLIB_STACK_TRACe_ + +/*! + This file defines 3 things. Two of them are preprocessor macros that + enable you to tag functions with the dlib stack trace watcher. The + third thing is a function named get_stack_trace() which returns the + current stack trace in std::string form. + + To enable the stack trace you must #define DLIB_ENABLE_STACK_TRACE. + When this #define isn't set then the 3 things described above + still exist but they don't do anything. + + Also note that when the stack trace is enabled it changes the DLIB_ASSERT + and DLIB_CASSERT macros so that they print stack traces when + an assert fails. + + See the following example program for details: + + #include + #include + + void funct2() + { + // put this macro at the top of each function you would + // like to appear in stack traces + DLIB_STACK_TRACE; + + // you may print the current stack trace as follows. + std::cout << dlib::get_stack_trace() << endl; + } + + void funct() + { + // This alternate form of DLIB_STACK_TRACE allows you to specify + // the string used to name the current function. The other form + // will usually output an appropriate function name automatically + // so this may not be needed. + DLIB_STACK_TRACE_NAMED("funct"); + funct2(); + } + + int main() + { + funct(); + } +!*/ + + +#include +#include "assert.h" + +// only setup the stack trace stuff if the asserts are enabled (which happens in debug mode +// basically). Also, this stuff doesn't work if you use NO_MAKEFILE +#if defined(DLIB_ENABLE_STACK_TRACE) +#ifdef NO_MAKEFILE +#error "You can't use the dlib stack trace stuff and NO_MAKEFILE at the same time" +#endif + +namespace dlib +{ + const std::string get_stack_trace(); +} + +// redefine the DLIB_CASSERT macro to include the stack trace +#undef DLIBM_CASSERT +#define DLIBM_CASSERT(_exp,_message) \ + {if ( !(_exp) ) \ + { \ + std::ostringstream dlib_o_out; \ + dlib_o_out << "\n\nError occurred at line " << __LINE__ << ".\n"; \ + dlib_o_out << "Error occurred in file " << __FILE__ << ".\n"; \ + dlib_o_out << "Error occurred in function " << DLIB_FUNCTION_NAME << ".\n\n"; \ + dlib_o_out << "Failing expression was " << #_exp << ".\n"; \ + dlib_o_out << _message << "\n\n"; \ + dlib_o_out << "Stack Trace: \n" << dlib::get_stack_trace() << "\n"; \ + dlib_assert_breakpoint(); \ + throw dlib::fatal_error(dlib::EBROKEN_ASSERT,dlib_o_out.str()); \ + }} + + + +namespace dlib +{ + + class stack_tracer + { + public: + stack_tracer ( + const char* funct_name, + const char* file_name, + const int line_number + ); + + ~stack_tracer(); + + }; +} + +#define DLIB_STACK_TRACE_NAMED(x) dlib::stack_tracer dlib_stack_tracer_object(x,__FILE__,__LINE__) +#define DLIB_STACK_TRACE dlib::stack_tracer dlib_stack_tracer_object(DLIB_FUNCTION_NAME,__FILE__,__LINE__) + +#else // don't do anything if ENABLE_ASSERTS isn't defined +#define DLIB_STACK_TRACE_NAMED(x) +#define DLIB_STACK_TRACE + +namespace dlib +{ + inline const std::string get_stack_trace() { return std::string("stack trace not enabled");} +} + +#endif + + +#endif // DLIB_STACK_TRACe_ + diff --git a/ml/dlib/dlib/static_map.h b/ml/dlib/dlib/static_map.h new file mode 100644 index 000000000..f1fcadab9 --- /dev/null +++ b/ml/dlib/dlib/static_map.h @@ -0,0 +1,43 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STATIC_MAp_ +#define DLIB_STATIC_MAp_ + +#include "static_map/static_map_kernel_1.h" +#include "static_map/static_map_kernel_c.h" + +#include + + +namespace dlib +{ + + template < + typename domain, + typename range, + typename compare = std::less + > + class static_map + { + static_map() {} + + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef static_map_kernel_1 + kernel_1a; + typedef static_map_kernel_c + kernel_1a_c; + + + + + + }; +} + +#endif // DLIB_STATIC_MAp_ + diff --git a/ml/dlib/dlib/static_map/static_map_kernel_1.h b/ml/dlib/dlib/static_map/static_map_kernel_1.h new file mode 100644 index 000000000..a7b627ae6 --- /dev/null +++ b/ml/dlib/dlib/static_map/static_map_kernel_1.h @@ -0,0 +1,756 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STATIC_MAP_KERNEl_1_ +#define DLIB_STATIC_MAP_KERNEl_1_ + +#include "static_map_kernel_abstract.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../algs.h" +#include "../serialize.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename compare = std::less + > + class static_map_kernel_1 : public enumerable > + { + + /*! + INITIAL VALUE + - map_size == 0 + - d == 0 + - r == 0 + - mp.d = 0; + - at_start_ == true + + + CONVENTION + - size() == map_size + - if (size() > 0) then + - d == pointer to an array containing all the domain elements + - r == pointer to an array containing all the range elements + - for every i: operator[](d[i]) == r[i] + - d is sorted according to operator< + - else + - d == 0 + - r == 0 + + - current_element_valid() == (mp.d != 0) + - at_start() == (at_start_) + - if (current_element_valid()) then + - element() == mp + !*/ + + class mpair : public map_pair + { + public: + const domain* d; + range* r; + + const domain& key( + ) const { return *d; } + + const range& value( + ) const { return *r; } + + range& value( + ) { return *r; } + }; + + + // I would define this outside the class but Borland 5.5 has some problems + // with non-inline templated friend functions. + friend void deserialize ( + static_map_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + item.map_size = size; + item.d = new domain[size]; + item.r = new range[size]; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(item.d[i],in); + deserialize(item.r[i],in); + } + } + catch (serialization_error& e) + { + item.map_size = 0; + if (item.d) + { + delete [] item.d; + item.d = 0; + } + if (item.r) + { + delete [] item.r; + item.r = 0; + } + + throw serialization_error(e.info + "\n while deserializing object of type static_map_kernel_1"); + } + catch (...) + { + item.map_size = 0; + if (item.d) + { + delete [] item.d; + item.d = 0; + } + if (item.r) + { + delete [] item.r; + item.r = 0; + } + + throw; + } + } + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + + static_map_kernel_1( + ); + + virtual ~static_map_kernel_1( + ); + + void clear ( + ); + + void load ( + pair_remover& source + ); + + void load ( + asc_pair_remover& source + ); + + inline const range* operator[] ( + const domain& d + ) const; + + inline range* operator[] ( + const domain& d + ); + + inline void swap ( + static_map_kernel_1& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + inline bool current_element_valid ( + ) const; + + inline const map_pair& element ( + ) const; + + inline map_pair& element ( + ); + + inline bool move_next ( + ) const; + + + private: + + bool binary_search ( + const domain& item, + unsigned long& pos + ) const; + /*! + ensures + - if (there is an item in d equivalent to item) then + - returns true + - d[#pos] is equivalent item + - else + - returns false + !*/ + + void sort_arrays ( + unsigned long left, + unsigned long right + ); + /*! + requires + - left and right are within the bounts of the array + ensures + - everything in the convention is still true and d[left] though + d[right] is sorted according to operator< + !*/ + + void qsort_partition ( + unsigned long& partition_element, + const unsigned long left, + const unsigned long right + ); + /*! + requires + - left < right + - left and right are within the bounts of the array + ensures + - the convention is still true + - left <= #partition_element <= right + - all elements in #d < #d[#partition_element] have + indices >= left and < #partition_element + - all elements in #d >= #d[#partition_element] have + indices >= #partition_element and <= right + !*/ + + unsigned long median ( + unsigned long one, + unsigned long two, + unsigned long three + ); + /*! + requires + - one, two, and three are valid indexes into d + ensures + - returns the median of d[one], d[two], and d[three] + !*/ + + + + + // data members + unsigned long map_size; + domain* d; + range* r; + mutable mpair mp; + mutable bool at_start_; + compare comp; + + // restricted functions + static_map_kernel_1(static_map_kernel_1&); // copy constructor + static_map_kernel_1& operator=(static_map_kernel_1&); // assignment operator + }; + + template < + typename domain, + typename range, + typename compare + > + inline void swap ( + static_map_kernel_1& a, + static_map_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + static_map_kernel_1:: + static_map_kernel_1( + ) : + map_size(0), + d(0), + r(0), + at_start_(true) + { + mp.d = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + static_map_kernel_1:: + ~static_map_kernel_1( + ) + { + if (map_size > 0) + { + delete [] d; + delete [] r; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + void static_map_kernel_1:: + clear ( + ) + { + if (map_size > 0) + { + map_size = 0; + delete [] d; + delete [] r; + d = 0; + r = 0; + } + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + void static_map_kernel_1:: + load ( + pair_remover& source + ) + { + if (source.size() > 0) + { + domain* old_d = d; + d = new domain[source.size()]; + try { r = new range[source.size()]; } + catch (...) { delete [] d; d = old_d; throw; } + + map_size = source.size(); + + for (unsigned long i = 0; source.size() > 0; ++i) + source.remove_any(d[i],r[i]); + + sort_arrays(0,map_size-1); + } + else + { + clear(); + } + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + void static_map_kernel_1:: + load ( + asc_pair_remover& source + ) + { + if (source.size() > 0) + { + domain* old_d = d; + d = new domain[source.size()]; + try { r = new range[source.size()]; } + catch (...) { delete [] d; d = old_d; throw; } + + map_size = source.size(); + + for (unsigned long i = 0; source.size() > 0; ++i) + source.remove_any(d[i],r[i]); + } + else + { + clear(); + } + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + const range* static_map_kernel_1:: + operator[] ( + const domain& d_item + ) const + { + unsigned long pos; + if (binary_search(d_item,pos)) + return r+pos; + else + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + range* static_map_kernel_1:: + operator[] ( + const domain& d_item + ) + { + unsigned long pos; + if (binary_search(d_item,pos)) + return r+pos; + else + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + size_t static_map_kernel_1:: + size ( + ) const + { + return map_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + void static_map_kernel_1:: + swap ( + static_map_kernel_1& item + ) + { + exchange(map_size,item.map_size); + exchange(d,item.d); + exchange(r,item.r); + exchange(mp,item.mp); + exchange(at_start_,item.at_start_); + exchange(comp,item.comp); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + bool static_map_kernel_1:: + at_start ( + ) const + { + return (at_start_); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + void static_map_kernel_1:: + reset ( + ) const + { + mp.d = 0; + at_start_ = true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + bool static_map_kernel_1:: + current_element_valid ( + ) const + { + return (mp.d != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + const map_pair& static_map_kernel_1:: + element ( + ) const + { + return mp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + map_pair& static_map_kernel_1:: + element ( + ) + { + return mp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + bool static_map_kernel_1:: + move_next ( + ) const + { + // if at_start() && size() > 0 + if (at_start_ && map_size > 0) + { + at_start_ = false; + mp.r = r; + mp.d = d; + return true; + } + // else if current_element_valid() + else if (mp.d != 0) + { + ++mp.d; + ++mp.r; + if (static_cast(mp.d - d) < map_size) + { + return true; + } + else + { + mp.d = 0; + return false; + } + } + else + { + at_start_ = false; + return false; + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + bool static_map_kernel_1:: + binary_search ( + const domain& item, + unsigned long& pos + ) const + { + unsigned long high = map_size; + unsigned long low = 0; + unsigned long p = map_size; + unsigned long idx; + while (p > 0) + { + p = (high-low)>>1; + idx = p+low; + if (comp(item , d[idx])) + { + high = idx; + } + else if (comp(d[idx] , item)) + { + low = idx; + } + else + { + pos = idx; + return true; + } + } + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + void static_map_kernel_1:: + sort_arrays ( + unsigned long left, + unsigned long right + ) + { + if ( left < right) + { + unsigned long partition_element; + qsort_partition(partition_element,left,right); + + if (partition_element > 0) + sort_arrays(left,partition_element-1); + sort_arrays(partition_element+1,right); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + void static_map_kernel_1:: + qsort_partition ( + unsigned long& partition_element, + const unsigned long left, + const unsigned long right + ) + { + partition_element = right; + + unsigned long med = median(partition_element,left,((right-left)>>1) +left); + exchange(d[partition_element],d[med]); + exchange(r[partition_element],r[med]); + + unsigned long right_scan = right-1; + unsigned long left_scan = left; + + while (true) + { + // find an element to the left of partition_element that needs to be moved + while ( comp( d[left_scan] , d[partition_element]) ) + { + ++left_scan; + } + + // find an element to the right of partition_element that needs to be moved + while ( + !(comp (d[right_scan] , d[partition_element])) && + (right_scan > left_scan) + ) + { + --right_scan; + } + if (left_scan >= right_scan) + break; + + exchange(d[left_scan],d[right_scan]); + exchange(r[left_scan],r[right_scan]); + + } + exchange(d[left_scan],d[partition_element]); + exchange(r[left_scan],r[partition_element]); + partition_element = left_scan; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename compare + > + unsigned long static_map_kernel_1:: + median ( + unsigned long one, + unsigned long two, + unsigned long three + ) + { + if ( comp( d[one] , d[two]) ) + { + // one < two + if ( comp( d[two] , d[three]) ) + { + // one < two < three : two + return two; + } + else + { + // one < two >= three + if (comp( d[one] , d[three])) + { + // three + return three; + } + } + + } + else + { + // one >= two + if ( comp(d[three] , d[one] )) + { + // three <= one >= two + if ( comp(d[three] , d[two]) ) + { + // two + return two; + } + else + { + // three + return three; + } + } + } + return one; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STATIC_MAP_KERNEl_1_ + diff --git a/ml/dlib/dlib/static_map/static_map_kernel_abstract.h b/ml/dlib/dlib/static_map/static_map_kernel_abstract.h new file mode 100644 index 000000000..0821367b5 --- /dev/null +++ b/ml/dlib/dlib/static_map/static_map_kernel_abstract.h @@ -0,0 +1,181 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STATIC_MAP_KERNEl_ABSTRACT_ +#ifdef DLIB_STATIC_MAP_KERNEl_ABSTRACT_ + +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename compare = std::less + > + class static_map : public enumerable > + { + + /*! + REQUIREMENTS ON domain + domain must be comparable by compare where compare is a functor compatible with std::less and + domain is swappable by a global swap() and + domain must have a default constructor + + REQUIREMENTS ON range + range is swappable by a global swap() and + range must have a default constructor + + POINTERS AND REFERENCES TO INTERNAL DATA + Only the destructor and load_from() will invalidate pointers or + references to internal data. + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the domain (and each associated + range element) elements in ascending order according to the compare functor. + (i.e. the elements are enumerated in sorted order) + + WHAT THIS OBJECT REPRESENTS + static_map contains items of type domain and range + + This object is similar an array. It maps items of type domain on to + items of type range. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + NOTE + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + + static_map ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + !*/ + + virtual ~static_map( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + If this exception is thrown then #*this is unusable + until clear() is called and succeeds. + !*/ + + void load ( + pair_remover& source + ); + /*! + ensures + - #size() == source.size() + - #source.size() == 0 + - all the pairs in source are removed and placed into #*this + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + If this exception is thrown then the call to load() will have + no effect on #*this. + !*/ + + const range* operator[] ( + const domain& d + ) const; + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + range* operator[] ( + const domain& d + ); + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + void swap ( + static_map& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + static_map(static_map&); // copy constructor + static_map& operator=(static_map&); // assignment operator + }; + + template < + typename domain, + typename range, + typename compare + > + inline void swap ( + static_map& a, + static_map& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename domain, + typename range, + typename compare + > + void deserialize ( + static_map& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_STATIC_MAP_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/static_map/static_map_kernel_c.h b/ml/dlib/dlib/static_map/static_map_kernel_c.h new file mode 100644 index 000000000..576d79374 --- /dev/null +++ b/ml/dlib/dlib/static_map/static_map_kernel_c.h @@ -0,0 +1,89 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STATIC_MAP_KERNEl_C_ +#define DLIB_STATIC_MAP_KERNEl_C_ + +#include "static_map_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/remover.h" + +namespace dlib +{ + + template < + typename map_base + > + class static_map_kernel_c : public map_base + { + typedef typename map_base::domain_type domain; + typedef typename map_base::range_type range; + + public: + const map_pair& element ( + ) const; + + map_pair& element ( + ); + + }; + + template < + typename map_base + > + inline void swap ( + static_map_kernel_c& a, + static_map_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename map_base + > + const map_pair& static_map_kernel_c:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst map_pair& static_map::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return map_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_base + > + map_pair& static_map_kernel_c:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tmap_pair& static_map::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return map_base::element(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STATIC_MAP_KERNEl_C_ + diff --git a/ml/dlib/dlib/static_set.h b/ml/dlib/dlib/static_set.h new file mode 100644 index 000000000..47ecbafe4 --- /dev/null +++ b/ml/dlib/dlib/static_set.h @@ -0,0 +1,49 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STATIC_SEt_ +#define DLIB_STATIC_SEt_ + +#include "static_set/static_set_kernel_1.h" +#include "static_set/static_set_kernel_c.h" +#include "static_set/static_set_compare_1.h" + +#include + +namespace dlib +{ + + template < + typename T, + typename compare = std::less + > + class static_set + { + static_set() {} + + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef static_set_kernel_1 + kernel_1a; + typedef static_set_kernel_c + kernel_1a_c; + + + //----------- extensions ------------- + + typedef static_set_compare_1 + compare_1a; + typedef static_set_compare_1 + compare_1a_c; + + + + + }; +} + +#endif // DLIB_STATIC_SEt_ + diff --git a/ml/dlib/dlib/static_set/static_set_compare_1.h b/ml/dlib/dlib/static_set/static_set_compare_1.h new file mode 100644 index 000000000..b5271e1d4 --- /dev/null +++ b/ml/dlib/dlib/static_set/static_set_compare_1.h @@ -0,0 +1,122 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STATIC_SET_COMPARe_1_ +#define DLIB_STATIC_SET_COMPARe_1_ + +#include "static_set_compare_abstract.h" + +#include "../algs.h" + + + +namespace dlib +{ + + template < + typename static_set_base + > + class static_set_compare_1 : public static_set_base + { + + public: + + bool operator< ( + const static_set_compare_1& rhs + ) const; + + bool operator== ( + const static_set_compare_1& rhs + ) const; + + }; + + + template < + typename static_set_base + > + inline void swap ( + static_set_compare_1& a, + static_set_compare_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename static_set_base + > + bool static_set_compare_1:: + operator< ( + const static_set_compare_1& rhs + ) const + { + bool result = false; + if (static_set_base::size() < rhs.size()) + result = true; + + if (static_set_base::size() == rhs.size()) + { + rhs.reset(); + static_set_base::reset(); + while (rhs.move_next()) + { + static_set_base::move_next(); + if (static_set_base::element() < rhs.element()) + { + result = true; + break; + } + else if (rhs.element() < static_set_base::element()) + { + break; + } + } + } + + static_set_base::reset(); + rhs.reset(); + + return result; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename static_set_base + > + bool static_set_compare_1:: + operator== ( + const static_set_compare_1& rhs + ) const + { + bool result = true; + if (static_set_base::size() != rhs.size()) + result = false; + + + rhs.reset(); + static_set_base::reset(); + while (rhs.move_next() && static_set_base::move_next()) + { + if (!(rhs.element() == static_set_base::element())) + { + result = false; + break; + } + } + + static_set_base::reset(); + rhs.reset(); + + return result; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STATIC_SET_COMPARe_1_ + diff --git a/ml/dlib/dlib/static_set/static_set_compare_abstract.h b/ml/dlib/dlib/static_set/static_set_compare_abstract.h new file mode 100644 index 000000000..356354e69 --- /dev/null +++ b/ml/dlib/dlib/static_set/static_set_compare_abstract.h @@ -0,0 +1,93 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STATIC_SET_COMPARe_ABSTRACT_ +#ifdef DLIB_STATIC_SET_COMPARe_ABSTRACT_ + +#include "static_set_kernel_abstract.h" + +#include "../algs.h" + + +namespace dlib +{ + + template < + typename static_set_base + > + class static_set_compare : public static_set_base + { + + /*! + REQUIREMENTS ON static_set_base + must an implementation of static_set/static_set_kernel_abstract.h + + POINTERS AND REFERENCES TO INTERNAL DATA + operator== and operator< invalidate pointers or references to + data members. + + WHAT THIS EXTENSION DOES FOR static_set + This gives a static_set the ability to compare itself to other + static_sets using the < and == operators. + + The < operator is conceptually weird for sets. It is useful + though because it allows you to make sets of sets since + sets require that their containing type implement operator<. + + Also note that it is the case that for any two sets a and b + if (a rhs.size()) then + - returns false + - else + - returns true if there exists an integer j such that 0 <= j < size() + and for all integers i such that 0 <= i < j where it is true that + (*this)[i] == rhs[i] and (*this)[j] < rhs[j] + - returns false if there is no j that will satisfy the above conditions. + !*/ + + bool operator== ( + const static_set_compare& rhs + ) const; + /*! + ensures + - #at_start() == true + - returns true if *this and rhs contain the same elements. + returns false otherwise. + !*/ + }; + + + template < + typename static_set_base + > + inline void swap ( + static_set_compare& a, + static_set_compare& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_STATIC_SET_COMPARe_ABSTRACT_ + diff --git a/ml/dlib/dlib/static_set/static_set_kernel_1.h b/ml/dlib/dlib/static_set/static_set_kernel_1.h new file mode 100644 index 000000000..7a1f166fc --- /dev/null +++ b/ml/dlib/dlib/static_set/static_set_kernel_1.h @@ -0,0 +1,446 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STATIC_SET_KERNEl_1_ +#define DLIB_STATIC_SET_KERNEl_1_ + +#include "static_set_kernel_abstract.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../algs.h" +#include "../sort.h" +#include "../serialize.h" +#include + +namespace dlib +{ + + template < + typename T, + typename compare = std::less + > + class static_set_kernel_1 : public enumerable + { + + /*! + INITIAL VALUE + - set_size == 0 + - d == 0 + - at_start_ == true + - cur == 0 + + CONVENTION + - size() == set_size + - if (set_size > 0) then + - d == pointer to an array containing all the elements of the set + - d is sorted according to operator< + - else + - d == 0 + + - current_element_valid() == (cur != 0) + - at_start() == (at_start_) + - if (current_element_valid()) then + - element() == *cur + !*/ + + // I would define this outside the class but Borland 5.5 has some problems + // with non-inline templated friend functions. + friend void deserialize ( + static_set_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + size_t size; + deserialize(size,in); + item.set_size = size; + item.d = new T[size]; + for (size_t i = 0; i < size; ++i) + { + deserialize(item.d[i],in); + } + } + catch (serialization_error e) + { + item.set_size = 0; + if (item.d) + { + delete [] item.d; + item.d = 0; + } + + throw serialization_error(e.info + "\n while deserializing object of type static_set_kernel_1"); + } + catch (...) + { + item.set_size = 0; + if (item.d) + { + delete [] item.d; + item.d = 0; + } + + throw; + } + } + + public: + + typedef T type; + typedef compare compare_type; + + static_set_kernel_1( + ); + + virtual ~static_set_kernel_1( + ); + + void clear ( + ); + + void load ( + remover& source + ); + + void load ( + asc_remover& source + ); + + bool is_member ( + const T& item + ) const; + + inline void swap ( + static_set_kernel_1& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + inline bool current_element_valid ( + ) const; + + inline const T& element ( + ) const; + + inline const T& element ( + ); + + inline bool move_next ( + ) const; + + + private: + + + // data members + size_t set_size; + T* d; + mutable T* cur; + mutable bool at_start_; + + // restricted functions + static_set_kernel_1(static_set_kernel_1&); // copy constructor + static_set_kernel_1& operator=(static_set_kernel_1&); // assignment operator + }; + + template < + typename T, + typename compare + > + inline void swap ( + static_set_kernel_1& a, + static_set_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + static_set_kernel_1:: + static_set_kernel_1( + ) : + set_size(0), + d(0), + cur(0), + at_start_(true) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + static_set_kernel_1:: + ~static_set_kernel_1( + ) + { + if (set_size > 0) + delete [] d; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void static_set_kernel_1:: + clear( + ) + { + if (set_size > 0) + { + set_size = 0; + delete [] d; + d = 0; + } + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void static_set_kernel_1:: + load ( + remover& source + ) + { + if (source.size() > 0) + { + d = new T[source.size()]; + + set_size = source.size(); + + for (size_t i = 0; source.size() > 0; ++i) + source.remove_any(d[i]); + + compare comp; + qsort_array(d,0,set_size-1,comp); + } + else + { + clear(); + } + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void static_set_kernel_1:: + load ( + asc_remover& source + ) + { + if (source.size() > 0) + { + d = new T[source.size()]; + + set_size = source.size(); + + for (size_t i = 0; source.size() > 0; ++i) + source.remove_any(d[i]); + } + else + { + clear(); + } + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + bool static_set_kernel_1:: + is_member ( + const T& item + ) const + { + size_t high = set_size; + size_t low = 0; + size_t p = set_size; + size_t idx; + while (p > 0) + { + p = (high-low)>>1; + idx = p+low; + if (item < d[idx]) + high = idx; + else if (d[idx] < item) + low = idx; + else + return true; + } + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + size_t static_set_kernel_1:: + size ( + ) const + { + return set_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void static_set_kernel_1:: + swap ( + static_set_kernel_1& item + ) + { + exchange(set_size,item.set_size); + exchange(d,item.d); + exchange(cur,item.cur); + exchange(at_start_,item.at_start_); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + bool static_set_kernel_1:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + void static_set_kernel_1:: + reset ( + ) const + { + at_start_ = true; + cur = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + bool static_set_kernel_1:: + current_element_valid ( + ) const + { + return (cur != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + const T& static_set_kernel_1:: + element ( + ) const + { + return *cur; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + const T& static_set_kernel_1:: + element ( + ) + { + return *cur; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename compare + > + bool static_set_kernel_1:: + move_next ( + ) const + { + // if at_start() && size() > 0 + if (at_start_ && set_size > 0) + { + at_start_ = false; + cur = d; + return true; + } + // else if current_element_valid() + else if (cur != 0) + { + ++cur; + if (static_cast(cur - d) < set_size) + { + return true; + } + else + { + cur = 0; + return false; + } + } + else + { + at_start_ = false; + return false; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STATIC_SET_KERNEl_1_ + diff --git a/ml/dlib/dlib/static_set/static_set_kernel_abstract.h b/ml/dlib/dlib/static_set/static_set_kernel_abstract.h new file mode 100644 index 000000000..a2efd1aaa --- /dev/null +++ b/ml/dlib/dlib/static_set/static_set_kernel_abstract.h @@ -0,0 +1,154 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STATIC_SET_KERNEl_ABSTRACT_ +#ifdef DLIB_STATIC_SET_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include + +namespace dlib +{ + + template < + typename T, + typename compare = std::less + > + class static_set : public enumerable + { + + /*! + REQUIREMENTS ON T + T must be comparable by compare where compare is a functor compatible with std::less and + T is swappable by a global swap() and + T must have a default constructor + + POINTERS AND REFERENCES TO INTERNAL DATA + Only the destructor will invalidate pointers or references + to internal data. + + INITIAL VALUE + size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements in the set in + ascending order according to the compare functor. + (i.e. the elements are enumerated in sorted order) + + WHAT THIS OBJECT REPRESENTS + static_set contains items of type T + + This object represents an unaddressed collection of items. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + NOTE + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + public: + + typedef T type; + typedef compare compare_type; + + static_set ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor. + !*/ + + virtual ~static_set( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then #*this is unusable + until clear() is called and succeeds. + !*/ + + void load ( + remover& source + ); + /*! + ensures + - #size() == source.size() + - #source.size() == 0 + - all the elements in source are removed and placed into #*this + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor. + If this exception is thrown then the call to load() will have + no effect on #*this. + !*/ + + bool is_member ( + const T& item + ) const; + /*! + ensures + - if (there is an item in *this equivalent to item) then + - returns true + - else + - returns false + !*/ + + void swap ( + static_set& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + static_set(static_set&); // copy constructor + static_set& operator=(static_set&); // assignment operator + }; + + template < + typename T, + typename compare + > + inline void swap ( + static_set& a, + static_set& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename compare + > + void deserialize ( + static_set& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_STATIC_SET_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/static_set/static_set_kernel_c.h b/ml/dlib/dlib/static_set/static_set_kernel_c.h new file mode 100644 index 000000000..1280c9c89 --- /dev/null +++ b/ml/dlib/dlib/static_set/static_set_kernel_c.h @@ -0,0 +1,88 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STATIC_SET_KERNEl_C_ +#define DLIB_STATIC_SET_KERNEl_C_ + +#include "static_set_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include "../interfaces/remover.h" + +namespace dlib +{ + + template < + typename set_base + > + class static_set_kernel_c : public set_base + { + typedef typename set_base::type T; + public: + + const T& element ( + ); + + const T& element ( + ) const; + }; + + + template < + typename set_base + > + inline void swap ( + static_set_kernel_c& a, + static_set_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + const typename set_base::type& static_set_kernel_c:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst T& static_set::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return set_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename set_base + > + const typename set_base::type& static_set_kernel_c:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst T& static_set::element" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return set_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + +} + +#endif // DLIB_STATIC_SET_KERNEl_C_ + diff --git a/ml/dlib/dlib/statistics.h b/ml/dlib/dlib/statistics.h new file mode 100644 index 000000000..45785c635 --- /dev/null +++ b/ml/dlib/dlib/statistics.h @@ -0,0 +1,19 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STATISTICs_H_ +#define DLIB_STATISTICs_H_ + +#include "statistics/statistics.h" +#include "statistics/dpca.h" +#include "statistics/random_subset_selector.h" +#include "statistics/image_feature_sampling.h" +#include "statistics/sammon.h" +#include "statistics/cca.h" +#include "statistics/average_precision.h" +#include "statistics/vector_normalizer_frobmetric.h" +#include "statistics/lda.h" + +#endif // DLIB_STATISTICs_H_ + + + diff --git a/ml/dlib/dlib/statistics/average_precision.h b/ml/dlib/dlib/statistics/average_precision.h new file mode 100644 index 000000000..6c2e7e0b1 --- /dev/null +++ b/ml/dlib/dlib/statistics/average_precision.h @@ -0,0 +1,66 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AVERAGE_PREcISION_Hh_ +#define DLIB_AVERAGE_PREcISION_Hh_ + +#include "average_precision_abstract.h" +#include + + +namespace dlib +{ + namespace impl + { + inline bool get_bool_part ( + const bool& b + ) { return b; } + + template + bool get_bool_part(const std::pair& item) { return item.second; } + } + +// ---------------------------------------------------------------------------------------- + + template + double average_precision ( + const std::vector& items, + unsigned long missing_relevant_items = 0 + ) + { + using namespace dlib::impl; + double relevant_count = 0; + // find the precision values + std::vector precision; + for (unsigned long i = 0; i < items.size(); ++i) + { + if (get_bool_part(items[i])) + { + ++relevant_count; + precision.push_back(relevant_count / (i+1)); + } + } + + double precision_sum = 0; + double max_val = 0; + // now sum over the interpolated precision values + for (std::vector::reverse_iterator i = precision.rbegin(); i != precision.rend(); ++i) + { + max_val = std::max(max_val, *i); + precision_sum += max_val; + } + + + relevant_count += missing_relevant_items; + + if (relevant_count != 0) + return precision_sum/relevant_count; + else + return 1; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AVERAGE_PREcISION_Hh_ + diff --git a/ml/dlib/dlib/statistics/average_precision_abstract.h b/ml/dlib/dlib/statistics/average_precision_abstract.h new file mode 100644 index 000000000..76c2c702a --- /dev/null +++ b/ml/dlib/dlib/statistics/average_precision_abstract.h @@ -0,0 +1,67 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AVERAGE_PREcISION_ABSTRACT_Hh_ +#ifdef DLIB_AVERAGE_PREcISION_ABSTRACT_Hh_ + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename alloc + > + double average_precision ( + const std::vector& items, + unsigned long missing_relevant_items = 0 + ); + /*! + ensures + - Interprets items as a list of relevant and non-relevant items in a response + from an information retrieval system. In particular, items with a true value + are relevant and false items are non-relevant. This function then returns + the average precision of the ranking of the given items. For, example, the + ranking [true, true, true, true, false] would have an average precision of 1. + On the other hand, the ranking of [true false false true] would have an + average precision of 0.75 (because the first true has a precision of 1 and + the second true has a precision of 0.5, giving an average of 0.75). + - As a special case, if item contains no true elements then the average + precision is considered to be 1. + - Note that we use the interpolated precision. That is, the interpolated + precision at a recall value r is set to the maximum precision obtained at any + higher recall value. Or in other words, we interpolate the precision/recall + curve so that precision is monotonically decreasing. Therefore, the average + precision value returned by this function is the area under this interpolated + precision/recall curve. + - This function will add in missing_relevant_items number of items with a + precision of zero into the average value returned. For example, the average + precision of the ranking [true, true] if there are 2 missing relevant items + is 0.5. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double average_precision ( + const std::vector,alloc>& items, + unsigned long missing_relevant_items = 0 + ); + /*! + ensures + - this function is equivalent to copying the bool values from items into a + std::vector and then calling the above average_precision() routine on + it and returning the result. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AVERAGE_PREcISION_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/statistics/cca.h b/ml/dlib/dlib/statistics/cca.h new file mode 100644 index 000000000..21300ea8f --- /dev/null +++ b/ml/dlib/dlib/statistics/cca.h @@ -0,0 +1,186 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CCA_hh_ +#define DLIB_CCA_hh_ + +#include "cca_abstract.h" +#include "../algs.h" +#include "../matrix.h" +#include "../sparse_vector.h" +#include "random_subset_selector.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + matrix compute_correlations ( + const matrix_exp& L, + const matrix_exp& R + ) + { + DLIB_ASSERT( L.size() > 0 && R.size() > 0 && L.nr() == R.nr(), + "\t matrix compute_correlations()" + << "\n\t Invalid inputs were given to this function." + << "\n\t L.size(): " << L.size() + << "\n\t R.size(): " << R.size() + << "\n\t L.nr(): " << L.nr() + << "\n\t R.nr(): " << R.nr() + ); + + typedef typename T::type type; + matrix A, B, C; + A = diag(trans(R)*L); + B = sqrt(diag(trans(L)*L)); + C = sqrt(diag(trans(R)*R)); + A = pointwise_multiply(A , reciprocal(pointwise_multiply(B,C))); + return A; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type, + typename T + > + matrix impl_cca ( + const matrix_type& L, + const matrix_type& R, + matrix& Ltrans, + matrix& Rtrans, + unsigned long num_correlations, + unsigned long extra_rank, + unsigned long q, + unsigned long num_output_correlations, + double regularization + ) + { + matrix Ul, Vl; + matrix Ur, Vr; + matrix U, V; + matrix Dr, Dl, D; + + + // Note that we add a few more singular vectors in because it helps improve the + // final results if we run this part with a little higher rank than the final SVD. + svd_fast(L, Ul, Dl, Vl, num_correlations+extra_rank, q); + svd_fast(R, Ur, Dr, Vr, num_correlations+extra_rank, q); + + // Zero out singular values that are essentially zero so they don't cause numerical + // difficulties in the code below. + const double eps = std::numeric_limits::epsilon()*std::max(max(Dr),max(Dl))*100; + Dl = round_zeros(Dl+regularization,eps); + Dr = round_zeros(Dr+regularization,eps); + + // This matrix is really small so we can do a normal full SVD on it. Note that we + // also throw away the columns of Ul and Ur corresponding to zero singular values. + svd3(diagm(Dl>0)*tmp(trans(Ul)*Ur)*diagm(Dr>0), U, D, V); + + // now throw away extra columns of the transformations. We do this in a way + // that keeps the directions that have the highest correlations. + matrix temp = D; + rsort_columns(U, temp); + rsort_columns(V, D); + U = colm(U, range(0, num_output_correlations-1)); + V = colm(V, range(0, num_output_correlations-1)); + D = rowm(D, range(0, num_output_correlations-1)); + + Ltrans = Vl*inv(diagm(Dl))*U; + Rtrans = Vr*inv(diagm(Dr))*V; + + // Note that the D matrix contains the correlation values for the transformed + // vectors. However, when the L and R matrices have rank higher than + // num_correlations+extra_rank then the values in D become only approximate. + return D; + } + +// ---------------------------------------------------------------------------------------- + + template + matrix cca ( + const matrix& L, + const matrix& R, + matrix& Ltrans, + matrix& Rtrans, + unsigned long num_correlations, + unsigned long extra_rank = 5, + unsigned long q = 2, + double regularization = 0 + ) + { + DLIB_ASSERT( num_correlations > 0 && L.size() > 0 && R.size() > 0 && L.nr() == R.nr() && + regularization >= 0, + "\t matrix cca()" + << "\n\t Invalid inputs were given to this function." + << "\n\t num_correlations: " << num_correlations + << "\n\t regularization: " << regularization + << "\n\t L.size(): " << L.size() + << "\n\t R.size(): " << R.size() + << "\n\t L.nr(): " << L.nr() + << "\n\t R.nr(): " << R.nr() + ); + + using std::min; + const unsigned long n = min(num_correlations, (unsigned long)min(R.nr(),min(L.nc(), R.nc()))); + return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, n, regularization); + } + +// ---------------------------------------------------------------------------------------- + + template + matrix cca ( + const std::vector& L, + const std::vector& R, + matrix& Ltrans, + matrix& Rtrans, + unsigned long num_correlations, + unsigned long extra_rank = 5, + unsigned long q = 2, + double regularization = 0 + ) + { + DLIB_ASSERT( num_correlations > 0 && L.size() == R.size() && + max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0 && + regularization >= 0, + "\t matrix cca()" + << "\n\t Invalid inputs were given to this function." + << "\n\t num_correlations: " << num_correlations + << "\n\t regularization: " << regularization + << "\n\t L.size(): " << L.size() + << "\n\t R.size(): " << R.size() + << "\n\t max_index_plus_one(L): " << max_index_plus_one(L) + << "\n\t max_index_plus_one(R): " << max_index_plus_one(R) + ); + + using std::min; + const unsigned long n = min(max_index_plus_one(L), max_index_plus_one(R)); + const unsigned long num_output_correlations = min(num_correlations, std::min(R.size(),n)); + return impl_cca(L,R,Ltrans, Rtrans, num_correlations, extra_rank, q, num_output_correlations, regularization); + } + +// ---------------------------------------------------------------------------------------- + + template + matrix cca ( + const random_subset_selector& L, + const random_subset_selector& R, + matrix& Ltrans, + matrix& Rtrans, + unsigned long num_correlations, + unsigned long extra_rank = 5, + unsigned long q = 2 + ) + { + return cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CCA_hh_ + + diff --git a/ml/dlib/dlib/statistics/cca_abstract.h b/ml/dlib/dlib/statistics/cca_abstract.h new file mode 100644 index 000000000..8e0b4e742 --- /dev/null +++ b/ml/dlib/dlib/statistics/cca_abstract.h @@ -0,0 +1,191 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CCA_AbSTRACT_Hh_ +#ifdef DLIB_CCA_AbSTRACT_Hh_ + +#include "../matrix/matrix_la_abstract.h" +#include "random_subset_selector_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + matrix compute_correlations ( + const matrix_exp& L, + const matrix_exp& R + ); + /*! + requires + - L.size() > 0 + - R.size() > 0 + - L.nr() == R.nr() + ensures + - This function treats L and R as sequences of paired row vectors. It + then computes the correlation values between the elements of these + row vectors. In particular, we return a vector COR such that: + - COR.size() == L.nc() + - for all valid i: + - COR(i) == the correlation coefficient between the following sequence + of paired numbers: (L(k,i), R(k,i)) for k: 0 <= k < L.nr(). + Therefore, COR(i) is a value between -1 and 1 inclusive where 1 + indicates perfect correlation and -1 perfect anti-correlation. Note + that this function assumes the input data vectors have been centered + (i.e. made to have zero mean). If this is not the case then it will + report inaccurate results. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + matrix cca ( + const matrix& L, + const matrix& R, + matrix& Ltrans, + matrix& Rtrans, + unsigned long num_correlations, + unsigned long extra_rank = 5, + unsigned long q = 2, + double regularization = 0 + ); + /*! + requires + - num_correlations > 0 + - L.size() > 0 + - R.size() > 0 + - L.nr() == R.nr() + - regularization >= 0 + ensures + - This function performs a canonical correlation analysis between the row + vectors in L and R. That is, it finds two transformation matrices, Ltrans + and Rtrans, such that row vectors in the transformed matrices L*Ltrans and + R*Rtrans are as correlated as possible. That is, we try to find two transforms + such that the correlation values returned by compute_correlations(L*Ltrans, R*Rtrans) + would be maximized. + - Let N == min(num_correlations, min(R.nr(),min(L.nc(),R.nc()))) + (This is the actual number of elements in the transformed vectors. + Therefore, note that you can't get more outputs than there are rows or + columns in the input matrices.) + - #Ltrans.nr() == L.nc() + - #Ltrans.nc() == N + - #Rtrans.nr() == R.nc() + - #Rtrans.nc() == N + - This function assumes the data vectors in L and R have already been centered + (i.e. we assume the vectors have zero means). However, in many cases it is + fine to use uncentered data with cca(). But if it is important for your + problem then you should center your data before passing it to cca(). + - This function works with reduced rank approximations of the L and R matrices. + This makes it fast when working with large matrices. In particular, we use + the svd_fast() routine to find reduced rank representations of the input + matrices by calling it as follows: svd_fast(L, U,D,V, num_correlations+extra_rank, q) + and similarly for R. This means that you can use the extra_rank and q + arguments to cca() to influence the accuracy of the reduced rank + approximation. However, the default values should work fine for most + problems. + - returns an estimate of compute_correlations(L*#Ltrans, R*#Rtrans). The + returned vector should exactly match the output of compute_correlations() + when the reduced rank approximation to L and R is accurate and regularization + is set to 0. However, if this is not the case then the return value of this + function will deviate from compute_correlations(L*#Ltrans, R*#Rtrans). This + deviation can be used to check if the reduced rank approximation is working + or you need to increase extra_rank. + - The dimensions of the output vectors produced by L*#Ltrans or R*#Rtrans are + ordered such that the dimensions with the highest correlations come first. + That is, after applying the transforms produced by cca() to a set of vectors + you will find that dimension 0 has the highest correlation, then dimension 1 + has the next highest, and so on. This also means that the list of numbers + returned from cca() will always be listed in decreasing order. + - This function performs the ridge regression version of Canonical Correlation + Analysis when regularization is set to a value > 0. In particular, larger + values indicate the solution should be more heavily regularized. This can be + useful when the dimensionality of the data is larger than the number of + samples. + - A good discussion of CCA can be found in the paper "Canonical Correlation + Analysis" by David Weenink. In particular, this function is implemented + using equations 29 and 30 from his paper. We also use the idea of doing CCA + on a reduced rank approximation of L and R as suggested by Paramveer S. + Dhillon in his paper "Two Step CCA: A new spectral method for estimating + vector models of words". + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sparse_vector_type, + typename T + > + matrix cca ( + const std::vector& L, + const std::vector& R, + matrix& Ltrans, + matrix& Rtrans, + unsigned long num_correlations, + unsigned long extra_rank = 5, + unsigned long q = 2, + double regularization = 0 + ); + /*! + requires + - num_correlations > 0 + - L.size() == R.size() + - max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0 + (i.e. L and R can't represent empty matrices) + - L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h + for a definition of sparse vector) + - regularization >= 0 + ensures + - This is just an overload of the cca() function defined above. Except in this + case we take a sparse representation of the input L and R matrices rather than + dense matrices. Therefore, in this case, we interpret L and R as matrices + with L.size() rows, where each row is defined by a sparse vector. So this + function does exactly the same thing as the above cca(). + - Note that you can apply the output transforms to a sparse vector with the + following code: + sparse_matrix_vector_multiply(trans(Ltrans), your_sparse_vector) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sparse_vector_type, + typename Rand_type, + typename T + > + matrix cca ( + const random_subset_selector& L, + const random_subset_selector& R, + matrix& Ltrans, + matrix& Rtrans, + unsigned long num_correlations, + unsigned long extra_rank = 5, + unsigned long q = 2, + double regularization = 0 + ); + /*! + requires + - num_correlations > 0 + - L.size() == R.size() + - max_index_plus_one(L) > 0 && max_index_plus_one(R) > 0 + (i.e. L and R can't represent empty matrices) + - L and R must contain sparse vectors (see the top of dlib/svm/sparse_vector_abstract.h + for a definition of sparse vector) + - regularization >= 0 + ensures + - returns cca(L.to_std_vector(), R.to_std_vector(), Ltrans, Rtrans, num_correlations, extra_rank, q) + (i.e. this is just a convenience function for calling the cca() routine when + your sparse vectors are contained inside a random_subset_selector rather than + a std::vector) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CCA_AbSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/statistics/dpca.h b/ml/dlib/dlib/statistics/dpca.h new file mode 100644 index 000000000..cae784682 --- /dev/null +++ b/ml/dlib/dlib/statistics/dpca.h @@ -0,0 +1,541 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DPCA_h_ +#define DLIB_DPCA_h_ + +#include "dpca_abstract.h" +#include +#include +#include "../algs.h" +#include "../matrix.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class discriminant_pca + { + /*! + INITIAL VALUE + - vect_size == 0 + - total_count == 0 + - between_count == 0 + - within_count == 0 + - between_weight == 1 + - within_weight == 1 + + CONVENTION + - vect_size == in_vector_size() + - total_count == the number of times add_to_total_variance() has been called. + - within_count == the number of times add_to_within_class_variance() has been called. + - between_count == the number of times add_to_between_class_variance() has been called. + - between_weight == between_class_weight() + - within_weight == within_class_weight() + + - if (total_count != 0) + - total_sum == the sum of all vectors given to add_to_total_variance() + - the covariance of all the elements given to add_to_total_variance() is given + by: + - let avg == total_sum/total_count + - covariance == total_cov/total_count - avg*trans(avg) + - if (within_count != 0) + - within_cov/within_count == the normalized within class scatter matrix + - if (between_count != 0) + - between_cov/between_count == the normalized between class scatter matrix + !*/ + + public: + + struct discriminant_pca_error : public error + { + discriminant_pca_error(const std::string& message): error(message) {} + }; + + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + discriminant_pca ( + ) + { + clear(); + } + + void clear( + ) + { + total_count = 0; + between_count = 0; + within_count = 0; + + vect_size = 0; + + + between_weight = 1; + within_weight = 1; + + + total_sum.set_size(0); + between_cov.set_size(0,0); + total_cov.set_size(0,0); + within_cov.set_size(0,0); + } + + long in_vector_size ( + ) const + { + return vect_size; + } + + void set_within_class_weight ( + scalar_type weight + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(weight >= 0, + "\t void discriminant_pca::set_within_class_weight()" + << "\n\t You can't use negative weight values" + << "\n\t weight: " << weight + << "\n\t this: " << this + ); + + within_weight = weight; + } + + scalar_type within_class_weight ( + ) const + { + return within_weight; + } + + void set_between_class_weight ( + scalar_type weight + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(weight >= 0, + "\t void discriminant_pca::set_between_class_weight()" + << "\n\t You can't use negative weight values" + << "\n\t weight: " << weight + << "\n\t this: " << this + ); + + between_weight = weight; + } + + scalar_type between_class_weight ( + ) const + { + return between_weight; + } + + template + void add_to_within_class_variance( + const matrix_exp& x, + const matrix_exp& y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(x) && is_col_vector(y) && + x.size() == y.size() && + (in_vector_size() == 0 || x.size() == in_vector_size()), + "\t void discriminant_pca::add_to_within_class_variance()" + << "\n\t Invalid inputs were given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t is_col_vector(y): " << is_col_vector(y) + << "\n\t x.size(): " << x.size() + << "\n\t y.size(): " << y.size() + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t this: " << this + ); + + vect_size = x.size(); + if (within_count == 0) + { + within_cov = (x-y)*trans(x-y); + } + else + { + within_cov += (x-y)*trans(x-y); + } + ++within_count; + } + + template + void add_to_between_class_variance( + const matrix_exp& x, + const matrix_exp& y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(x) && is_col_vector(y) && + x.size() == y.size() && + (in_vector_size() == 0 || x.size() == in_vector_size()), + "\t void discriminant_pca::add_to_between_class_variance()" + << "\n\t Invalid inputs were given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t is_col_vector(y): " << is_col_vector(y) + << "\n\t x.size(): " << x.size() + << "\n\t y.size(): " << y.size() + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t this: " << this + ); + + vect_size = x.size(); + if (between_count == 0) + { + between_cov = (x-y)*trans(x-y); + } + else + { + between_cov += (x-y)*trans(x-y); + } + ++between_count; + } + + template + void add_to_total_variance( + const matrix_exp& x + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(x) && (in_vector_size() == 0 || x.size() == in_vector_size()), + "\t void discriminant_pca::add_to_total_variance()" + << "\n\t Invalid inputs were given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t x.size(): " << x.size() + << "\n\t this: " << this + ); + + vect_size = x.size(); + if (total_count == 0) + { + total_cov = x*trans(x); + total_sum = x; + } + else + { + total_cov += x*trans(x); + total_sum += x; + } + ++total_count; + } + + const general_matrix dpca_matrix ( + const double eps = 0.99 + ) const + { + general_matrix dpca_mat; + general_matrix eigenvalues; + dpca_matrix(dpca_mat, eigenvalues, eps); + return dpca_mat; + } + + const general_matrix dpca_matrix_of_size ( + const long num_rows + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < num_rows && num_rows <= in_vector_size(), + "\t general_matrix discriminant_pca::dpca_matrix_of_size()" + << "\n\t Invalid inputs were given to this function" + << "\n\t num_rows: " << num_rows + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t this: " << this + ); + + general_matrix dpca_mat; + general_matrix eigenvalues; + dpca_matrix_of_size(dpca_mat, eigenvalues, num_rows); + return dpca_mat; + } + + void dpca_matrix ( + general_matrix& dpca_mat, + general_matrix& eigenvalues, + const double eps = 0.99 + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < eps && eps <= 1 && in_vector_size() != 0, + "\t void discriminant_pca::dpca_matrix()" + << "\n\t Invalid inputs were given to this function" + << "\n\t eps: " << eps + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t this: " << this + ); + + compute_dpca_matrix(dpca_mat, eigenvalues, eps, 0); + } + + void dpca_matrix_of_size ( + general_matrix& dpca_mat, + general_matrix& eigenvalues, + const long num_rows + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < num_rows && num_rows <= in_vector_size(), + "\t general_matrix discriminant_pca::dpca_matrix_of_size()" + << "\n\t Invalid inputs were given to this function" + << "\n\t num_rows: " << num_rows + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t this: " << this + ); + + compute_dpca_matrix(dpca_mat, eigenvalues, 1, num_rows); + } + + void swap ( + discriminant_pca& item + ) + { + using std::swap; + swap(total_cov, item.total_cov); + swap(total_sum, item.total_sum); + swap(total_count, item.total_count); + swap(vect_size, item.vect_size); + swap(between_cov, item.between_cov); + + swap(between_count, item.between_count); + swap(between_weight, item.between_weight); + swap(within_cov, item.within_cov); + swap(within_count, item.within_count); + swap(within_weight, item.within_weight); + } + + friend void deserialize ( + discriminant_pca& item, + std::istream& in + ) + { + deserialize( item.total_cov, in); + deserialize( item.total_sum, in); + deserialize( item.total_count, in); + deserialize( item.vect_size, in); + deserialize( item.between_cov, in); + deserialize( item.between_count, in); + deserialize( item.between_weight, in); + deserialize( item.within_cov, in); + deserialize( item.within_count, in); + deserialize( item.within_weight, in); + } + + friend void serialize ( + const discriminant_pca& item, + std::ostream& out + ) + { + serialize( item.total_cov, out); + serialize( item.total_sum, out); + serialize( item.total_count, out); + serialize( item.vect_size, out); + serialize( item.between_cov, out); + serialize( item.between_count, out); + serialize( item.between_weight, out); + serialize( item.within_cov, out); + serialize( item.within_count, out); + serialize( item.within_weight, out); + } + + discriminant_pca operator+ ( + const discriminant_pca& item + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT((in_vector_size() == 0 || item.in_vector_size() == 0 || in_vector_size() == item.in_vector_size()) && + between_class_weight() == item.between_class_weight() && + within_class_weight() == item.within_class_weight(), + "\t discriminant_pca discriminant_pca::operator+()" + << "\n\t The two discriminant_pca objects being added must have compatible parameters" + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t item.in_vector_size(): " << item.in_vector_size() + << "\n\t between_class_weight(): " << between_class_weight() + << "\n\t item.between_class_weight(): " << item.between_class_weight() + << "\n\t within_class_weight(): " << within_class_weight() + << "\n\t item.within_class_weight(): " << item.within_class_weight() + << "\n\t this: " << this + ); + + discriminant_pca temp(item); + + // We need to make sure to ignore empty matrices. That's what these if statements + // are for. + + if (total_count != 0 && temp.total_count != 0) + { + temp.total_cov += total_cov; + temp.total_sum += total_sum; + temp.total_count += total_count; + } + else if (total_count != 0) + { + temp.total_cov = total_cov; + temp.total_sum = total_sum; + temp.total_count = total_count; + } + + if (between_count != 0 && temp.between_count != 0) + { + temp.between_cov += between_cov; + temp.between_count += between_count; + } + else if (between_count != 0) + { + temp.between_cov = between_cov; + temp.between_count = between_count; + } + + if (within_count != 0 && temp.within_count != 0) + { + temp.within_cov += within_cov; + temp.within_count += within_count; + } + else if (within_count != 0) + { + temp.within_cov = within_cov; + temp.within_count = within_count; + } + + return temp; + } + + discriminant_pca& operator+= ( + const discriminant_pca& rhs + ) + { + (*this + rhs).swap(*this); + return *this; + } + + private: + + void compute_dpca_matrix ( + general_matrix& dpca_mat, + general_matrix& eigenvalues, + const double eps, + long num_rows + ) const + { + general_matrix cov; + + // now combine the three measures of variance into a single matrix by using the + // within_weight and between_weight weights. + cov = get_total_covariance_matrix(); + if (within_count != 0) + cov -= within_weight*within_cov/within_count; + if (between_count != 0) + cov += between_weight*between_cov/between_count; + + + eigenvalue_decomposition eig(make_symmetric(cov)); + + eigenvalues = eig.get_real_eigenvalues(); + dpca_mat = eig.get_pseudo_v(); + + // sort the eigenvalues and eigenvectors so that the biggest eigenvalues come first + rsort_columns(dpca_mat, eigenvalues); + + long num_vectors = 0; + if (num_rows == 0) + { + // Some of the eigenvalues might be negative. So first lets zero those out + // so they won't get considered. + eigenvalues = pointwise_multiply(eigenvalues > 0, eigenvalues); + // figure out how many eigenvectors we want in our dpca matrix + const double thresh = sum(eigenvalues)*eps; + double total = 0; + for (long r = 0; r < eigenvalues.size() && total < thresh; ++r) + { + // Don't even think about looking at eigenvalues that are 0. If we go this + // far then we have all we need. + if (eigenvalues(r) == 0) + break; + + ++num_vectors; + total += eigenvalues(r); + } + + if (num_vectors == 0) + throw discriminant_pca_error("While performing discriminant_pca, all eigenvalues were negative or 0"); + } + else + { + num_vectors = num_rows; + } + + + // So now we know we want to use num_vectors of the first eigenvectors. So + // pull those out and discard the rest. + dpca_mat = trans(colm(dpca_mat,range(0,num_vectors-1))); + + // also clip off the eigenvalues we aren't using + eigenvalues = rowm(eigenvalues, range(0,num_vectors-1)); + + } + + general_matrix get_total_covariance_matrix ( + ) const + /*! + ensures + - returns the covariance matrix of all the data given to the add_to_total_variance() + !*/ + { + // if we don't even know the dimensionality of the vectors we are dealing + // with then just return an empty matrix + if (vect_size == 0) + return general_matrix(); + + // we know the vector size but we have zero total covariance. + if (total_count == 0) + { + general_matrix temp(vect_size,vect_size); + temp = 0; + return temp; + } + + // In this case we actually have something to make a total covariance matrix out of. + // So do that. + column_matrix avg = total_sum/total_count; + + return total_cov/total_count - avg*trans(avg); + } + + general_matrix total_cov; + column_matrix total_sum; + scalar_type total_count; + + long vect_size; + + general_matrix between_cov; + scalar_type between_count; + scalar_type between_weight; + + general_matrix within_cov; + scalar_type within_count; + scalar_type within_weight; + }; + + template < + typename matrix_type + > + inline void swap ( + discriminant_pca& a, + discriminant_pca& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DPCA_h_ + + diff --git a/ml/dlib/dlib/statistics/dpca_abstract.h b/ml/dlib/dlib/statistics/dpca_abstract.h new file mode 100644 index 000000000..d9eef635b --- /dev/null +++ b/ml/dlib/dlib/statistics/dpca_abstract.h @@ -0,0 +1,365 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DPCA_ABSTRaCT_ +#ifdef DLIB_DPCA_ABSTRaCT_ + +#include +#include +#include "../matrix/matrix_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class discriminant_pca + { + /*! + REQUIREMENTS ON matrix_type + Must be some type of dlib::matrix. + + INITIAL VALUE + - in_vector_size() == 0 + - between_class_weight() == 1 + - within_class_weight() == 1 + + WHAT THIS OBJECT REPRESENTS + This object implements the Discriminant PCA technique described in the paper: + A New Discriminant Principal Component Analysis Method with Partial Supervision (2009) + by Dan Sun and Daoqiang Zhang + + This algorithm is basically a straightforward generalization of the classical PCA + technique to handle partially labeled data. It is useful if you want to learn a linear + dimensionality reduction rule using a bunch of data that is partially labeled. + + It functions by estimating three different scatter matrices. The first is the total scatter + matrix St (i.e. the total data covariance matrix), the second is the between class scatter + matrix Sb (basically a measure of the variance between data of different classes) and the + third is the within class scatter matrix Sw (a measure of the variance of data within the + same classes). + + Once these three matrices are estimated they are combined according to the following equation: + S = St + a*Sb - b*Sw + Where a and b are user supplied weights. Then the largest eigenvalues of the S matrix are + computed and their associated eigenvectors are returned as the output of this algorithm. + That is, the desired linear dimensionality reduction is given by the matrix with these + eigenvectors stored in its rows. + + Note that if a and b are set to 0 (or no labeled data is provided) then the output transformation + matrix is the same as the one produced by the classical PCA algorithm. + !*/ + + public: + + struct discriminant_pca_error : public error; + /*! + This exception is thrown if there is some error that prevents us from creating + a DPCA matrix. + !*/ + + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + discriminant_pca ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + !*/ + + long in_vector_size ( + ) const; + /*! + ensures + - if (this object has been presented with any input vectors) then + - returns the dimension of the column vectors used with this object + - else + - returns 0 + !*/ + + void set_within_class_weight ( + scalar_type weight + ); + /*! + requires + - weight >= 0 + ensures + - #within_class_weight() == weight + !*/ + + scalar_type within_class_weight ( + ) const; + /*! + ensures + - returns the weight used when combining the within class scatter matrix with + the other scatter matrices. + !*/ + + void set_between_class_weight ( + scalar_type weight + ); + /*! + requires + - weight >= 0 + ensures + - #between_class_weight() == weight + !*/ + + scalar_type between_class_weight ( + ) const; + /*! + ensures + - returns the weight used when combining the between class scatter matrix with + the other scatter matrices. + !*/ + + void add_to_within_class_variance( + const matrix_exp& x, + const matrix_exp& y + ); + /*! + requires + - is_col_vector(x) == true + - is_col_vector(y) == true + - x.size() == y.size() + - if (in_vector_size() != 0) then + - x.size() == y.size() == in_vector_size() + ensures + - #in_vector_size() == x.size() + - Adds (x-y)*trans(x-y) to the within class scatter matrix. + (i.e. the direction given by (x-y) is recorded as being a direction associated + with within class variance and is therefore unimportant and will be weighted + less in the final dimensionality reduction) + !*/ + + void add_to_between_class_variance( + const matrix_exp& x, + const matrix_exp& y + ); + /*! + requires + - is_col_vector(x) == true + - is_col_vector(y) == true + - x.size() == y.size() + - if (in_vector_size() != 0) then + - x.size() == y.size() == in_vector_size() + ensures + - #in_vector_size() == x.size() + - Adds (x-y)*trans(x-y) to the between class scatter matrix. + (i.e. the direction given by (x-y) is recorded as being a direction associated + with between class variance and is therefore important and will be weighted + higher in the final dimensionality reduction) + !*/ + + void add_to_total_variance( + const matrix_exp& x + ); + /*! + requires + - is_col_vector(x) == true + - if (in_vector_size() != 0) then + - x.size() == in_vector_size() + ensures + - #in_vector_size() == x.size() + - let M denote the centroid (or mean) of all the data. Then this function + Adds (x-M)*trans(x-M) to the total scatter matrix. + (i.e. the direction given by (x-M) is recorded as being a direction associated + with unlabeled variance and is therefore of default importance and will be weighted + as described in the discriminant_pca class description.) + !*/ + + const general_matrix dpca_matrix ( + const double eps = 0.99 + ) const; + /*! + requires + - 0 < eps <= 1 + - in_vector_size() != 0 + (i.e. you have to have given this object some data) + ensures + - computes and returns the matrix MAT given by dpca_matrix(MAT,eigen,eps). + That is, this function returns the dpca_matrix computed by the function + defined below. + - Note that MAT is the desired linear transformation matrix. That is, + multiplying a vector by MAT performs the desired linear dimensionality reduction. + throws + - discriminant_pca_error + This exception is thrown if we are unable to create the dpca_matrix for some + reason. For example, if only within class examples have been given or + within_class_weight() is very large then all eigenvalues will be negative and + that prevents this algorithm from working properly. + !*/ + + void dpca_matrix ( + general_matrix& dpca_mat, + general_matrix& eigenvalues, + const double eps = 0.99 + ) const; + /*! + requires + - 0 < eps <= 1 + - in_vector_size() != 0 + (i.e. you have to have given this object some data) + ensures + - is_col_vector(#eigenvalues) == true + - #dpca_mat.nr() == eigenvalues.size() + - #dpca_mat.nc() == in_vector_size() + - rowm(#dpca_mat,i) represents the ith eigenvector of the S matrix described + in the class description and its eigenvalue is given by eigenvalues(i). + - all values in #eigenvalues are > 0. Moreover, the eigenvalues are in + sorted order with the largest eigenvalue stored at eigenvalues(0). + - (#dpca_mat)*trans(#dpca_mat) == identity_matrix. + (i.e. the rows of the dpca_matrix are all unit length vectors and are mutually + orthogonal) + - Note that #dpca_mat is the desired linear transformation matrix. That is, + multiplying a vector by #dpca_mat performs the desired linear dimensionality + reduction. + - sum(#eigenvalues) will be equal to about eps times the total sum of all + positive eigenvalues in the S matrix described in this class's description. + This means that eps is a number that controls how "lossy" the dimensionality + reduction will be. Large values of eps result in more output dimensions + while smaller values result in fewer. + throws + - discriminant_pca_error + This exception is thrown if we are unable to create the dpca_matrix for some + reason. For example, if only within class examples have been given or + within_class_weight() is very large then all eigenvalues will be negative and + that prevents this algorithm from working properly. + !*/ + + const general_matrix dpca_matrix_of_size ( + const long num_rows + ); + /*! + requires + - 0 < num_rows <= in_vector_size() + ensures + - computes and returns the matrix MAT given by dpca_matrix_of_size(MAT,eigen,num_rows). + That is, this function returns the dpca_matrix computed by the function + defined below. + - Note that MAT is the desired linear transformation matrix. That is, + multiplying a vector by MAT performs the desired linear dimensionality + reduction to num_rows dimensions. + !*/ + + void dpca_matrix_of_size ( + general_matrix& dpca_mat, + general_matrix& eigenvalues, + const long num_rows + ); + /*! + requires + - 0 < num_rows <= in_vector_size() + ensures + - is_col_vector(#eigenvalues) == true + - #dpca_mat.nr() == eigenvalues.size() + - #dpca_mat.nr() == num_rows + - #dpca_mat.nc() == in_vector_size() + - rowm(#dpca_mat,i) represents the ith eigenvector of the S matrix described + in the class description and its eigenvalue is given by eigenvalues(i). + - The values in #eigenvalues might be positive or negative. Additionally, the + eigenvalues are in sorted order with the largest eigenvalue stored at + eigenvalues(0). + - (#dpca_mat)*trans(#dpca_mat) == identity_matrix. + (i.e. the rows of the dpca_matrix are all unit length vectors and are mutually + orthogonal) + - Note that #dpca_mat is the desired linear transformation matrix. That is, + multiplying a vector by #dpca_mat performs the desired linear dimensionality + reduction to num_rows dimensions. + !*/ + + discriminant_pca operator+ ( + const discriminant_pca& item + ) const; + /*! + requires + - in_vector_size() == 0 || item.in_vector_size() == 0 || in_vector_size() == item.in_vector_size() + (i.e. the in_vector_size() of *this and item must match or one must be zero) + - between_class_weight() == item.between_class_weight() + - within_class_weight() == item.within_class_weight() + ensures + - returns a new discriminant_pca object that represents the combination of all + the measurements given to *this and item. That is, this function returns a + discriminant_pca object, R, that is equivalent to what you would obtain if all + modifying calls (e.g. the add_to_*() functions) to *this and item had instead + been done to R. + !*/ + + discriminant_pca& operator+= ( + const discriminant_pca& rhs + ); + /*! + requires + - in_vector_size() == 0 || rhs.in_vector_size() == 0 || in_vector_size() == rhs.in_vector_size() + (i.e. the in_vector_size() of *this and rhs must match or one must be zero) + - between_class_weight() == rhs.between_class_weight() + - within_class_weight() == rhs.within_class_weight() + ensures + - #*this == *item + rhs + - returns #*this + !*/ + + void swap ( + discriminant_pca& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + inline void swap ( + discriminant_pca& a, + discriminant_pca& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename matrix_type, + > + void deserialize ( + discriminant_pca& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + template < + typename matrix_type, + > + void serialize ( + const discriminant_pca& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DPCA_ABSTRaCT_ + diff --git a/ml/dlib/dlib/statistics/image_feature_sampling.h b/ml/dlib/dlib/statistics/image_feature_sampling.h new file mode 100644 index 000000000..f04f9926e --- /dev/null +++ b/ml/dlib/dlib/statistics/image_feature_sampling.h @@ -0,0 +1,82 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IMAGE_FEATURE_SaMPLING_Hh_ +#define DLIB_IMAGE_FEATURE_SaMPLING_Hh_ + +#include "image_feature_sampling_abstract.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename feature_extractor_type, + typename pyramid_type + > + random_subset_selector randomly_sample_image_features ( + const image_array_type& images, + const pyramid_type& pyr, + const feature_extractor_type& fe_, + unsigned long num + ) + { + feature_extractor_type fe; + fe.copy_configuration(fe_); + random_subset_selector basis; + basis.set_max_size(num); + + typedef typename image_array_type::type image_type; + image_type temp_img, temp_img2; + + for (unsigned long i = 0; i < images.size(); ++i) + { + bool at_pyramid_top = true; + while (true) + { + if (at_pyramid_top) + fe.load(images[i]); + else + fe.load(temp_img); + + if (fe.size() == 0) + break; + + for (long r = 0; r < fe.nr(); ++r) + { + for (long c = 0; c < fe.nc(); ++c) + { + if (basis.next_add_accepts()) + { + basis.add(fe(r,c)); + } + else + { + basis.add(); + } + } + } + + if (at_pyramid_top) + { + at_pyramid_top = false; + pyr(images[i], temp_img); + } + else + { + pyr(temp_img, temp_img2); + swap(temp_img2,temp_img); + } + } + } + return basis; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IMAGE_FEATURE_SaMPLING_Hh_ + diff --git a/ml/dlib/dlib/statistics/image_feature_sampling_abstract.h b/ml/dlib/dlib/statistics/image_feature_sampling_abstract.h new file mode 100644 index 000000000..b51ef5423 --- /dev/null +++ b/ml/dlib/dlib/statistics/image_feature_sampling_abstract.h @@ -0,0 +1,45 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_IMAGE_FEATURE_SaMPLING_ABSTRACT_Hh_ +#ifdef DLIB_IMAGE_FEATURE_SaMPLING_ABSTRACT_Hh_ + +#include "random_subset_selector_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename feature_extractor_type, + typename pyramid_type + > + random_subset_selector randomly_sample_image_features ( + const image_array_type& images, + const pyramid_type& pyr, + const feature_extractor_type& fe, + unsigned long num + ); + /*! + requires + - pyramid_type == a type compatible with the image pyramid objects defined + in dlib/image_transforms/image_pyramid_abstract.h + - feature_extractor_type == a local image feature extractor type such as the + dlib::hog_image + - image_array_type == an implementation of dlib/array/array_kernel_abstract.h + and it must contain image objects which can be passed to pyr() and fe.load() + and are swappable by global swap(). + ensures + - creates an image pyramid for each image in images and performs feature + extraction on each pyramid level. Then selects a random subsample of at + most num local feature vectors and returns it. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_IMAGE_FEATURE_SaMPLING_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/statistics/lda.h b/ml/dlib/dlib/statistics/lda.h new file mode 100644 index 000000000..38de3fd1e --- /dev/null +++ b/ml/dlib/dlib/statistics/lda.h @@ -0,0 +1,237 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LDA_Hh_ +#define DLIB_LDA_Hh_ + +#include "lda_abstract.h" +#include "../algs.h" +#include +#include "../matrix.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + inline std::map make_class_labels( + const std::vector& row_labels + ) + { + std::map class_labels; + for (unsigned long i = 0; i < row_labels.size(); ++i) + { + const unsigned long next = class_labels.size(); + if (class_labels.count(row_labels[i]) == 0) + class_labels[row_labels[i]] = next; + } + return class_labels; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + matrix center_matrix ( + matrix& X + ) + { + matrix mean; + for (long r = 0; r < X.nr(); ++r) + mean += rowm(X,r); + mean /= X.nr(); + + for (long r = 0; r < X.nr(); ++r) + set_rowm(X,r) -= mean; + + return trans(mean); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void compute_lda_transform ( + matrix& X, + matrix& mean, + const std::vector& row_labels, + unsigned long lda_dims = 500, + unsigned long extra_pca_dims = 200 + ) + { + std::map class_labels = impl::make_class_labels(row_labels); + // LDA can only give out at most class_labels.size()-1 dimensions so don't try to + // compute more than that. + lda_dims = std::min(lda_dims, class_labels.size()-1); + + // make sure requires clause is not broken + DLIB_CASSERT(class_labels.size() > 1, + "\t void compute_lda_transform()" + << "\n\t You can't call this function if the number of distinct class labels is less than 2." + ); + DLIB_CASSERT(X.size() != 0 && (long)row_labels.size() == X.nr() && lda_dims != 0, + "\t void compute_lda_transform()" + << "\n\t Invalid inputs were given to this function." + << "\n\t X.size(): " << X.size() + << "\n\t row_labels.size(): " << row_labels.size() + << "\n\t lda_dims: " << lda_dims + ); + + + mean = impl::center_matrix(X); + // Do PCA to reduce dims + matrix pu,pw,pv; + svd_fast(X, pu, pw, pv, lda_dims+extra_pca_dims, 4); + pu.set_size(0,0); // free RAM, we don't need pu. + X = X*pv; + + + matrix class_means(class_labels.size(), X.nc()); + class_means = 0; + matrix class_counts(class_labels.size()); + class_counts = 0; + + // First compute the means of each class + for (unsigned long i = 0; i < row_labels.size(); ++i) + { + const unsigned long class_idx = class_labels[row_labels[i]]; + set_rowm(class_means,class_idx) += rowm(X,i); + class_counts(class_idx)++; + } + class_means = inv(diagm(class_counts))*class_means; + // subtract means from the data + for (unsigned long i = 0; i < row_labels.size(); ++i) + { + const unsigned long class_idx = class_labels[row_labels[i]]; + set_rowm(X,i) -= rowm(class_means,class_idx); + } + + // Note that we are using the formulas from the paper Using Discriminant + // Eigenfeatures for Image Retrieval by Swets and Weng. + matrix Sw = trans(X)*X; + matrix Sb = trans(class_means)*class_means; + matrix A, H; + matrix W; + svd3(Sw, A, W, H); + W = sqrt(W); + W = reciprocal(lowerbound(W,max(W)*1e-5)); + A = trans(H*diagm(W))*Sb*H*diagm(W); + matrix v,s,u; + svd3(A, v, s, u); + matrix tform = H*diagm(W)*u; + // pick out only the number of dimensions we are supposed to for the output, unless + // we should just keep them all, then don't do anything. + if ((long)lda_dims <= tform.nc()) + { + rsort_columns(tform, s); + tform = colm(tform, range(0, lda_dims-1)); + } + + X = trans(pv*tform); + mean = X*mean; + } + +// ---------------------------------------------------------------------------------------- + + inline std::pair equal_error_rate ( + const std::vector& low_vals, + const std::vector& high_vals + ) + { + std::vector > temp; + temp.reserve(low_vals.size()+high_vals.size()); + for (unsigned long i = 0; i < low_vals.size(); ++i) + temp.push_back(std::make_pair(low_vals[i], -1)); + for (unsigned long i = 0; i < high_vals.size(); ++i) + temp.push_back(std::make_pair(high_vals[i], +1)); + + std::sort(temp.begin(), temp.end()); + + if (temp.size() == 0) + return std::make_pair(0,0); + + double thresh = temp[0].first; + + unsigned long num_low_wrong = low_vals.size(); + unsigned long num_high_wrong = 0; + double low_error = num_low_wrong/(double)low_vals.size(); + double high_error = num_high_wrong/(double)high_vals.size(); + for (unsigned long i = 0; i < temp.size() && high_error < low_error; ++i) + { + thresh = temp[i].first; + if (temp[i].second > 0) + { + num_high_wrong++; + high_error = num_high_wrong/(double)high_vals.size(); + } + else + { + num_low_wrong--; + low_error = num_low_wrong/(double)low_vals.size(); + } + } + + return std::make_pair((low_error+high_error)/2, thresh); + } + +// ---------------------------------------------------------------------------------------- + + struct roc_point + { + double true_positive_rate; + double false_positive_rate; + double detection_threshold; + }; + + inline std::vector compute_roc_curve ( + const std::vector& true_detections, + const std::vector& false_detections + ) + { + DLIB_CASSERT(true_detections.size() != 0); + DLIB_CASSERT(false_detections.size() != 0); + + std::vector > temp; + temp.reserve(true_detections.size()+false_detections.size()); + for (unsigned long i = 0; i < true_detections.size(); ++i) + temp.push_back(std::make_pair(true_detections[i], +1)); + for (unsigned long i = 0; i < false_detections.size(); ++i) + temp.push_back(std::make_pair(false_detections[i], -1)); + + std::sort(temp.rbegin(), temp.rend()); + + + std::vector roc_curve; + roc_curve.reserve(temp.size()); + + double num_false_included = 0; + double num_true_included = 0; + for (unsigned long i = 0; i < temp.size(); ++i) + { + if (temp[i].second > 0) + num_true_included++; + else + num_false_included++; + + roc_point p; + p.true_positive_rate = num_true_included/true_detections.size(); + p.false_positive_rate = num_false_included/false_detections.size(); + p.detection_threshold = temp[i].first; + roc_curve.push_back(p); + } + + return roc_curve; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LDA_Hh_ + diff --git a/ml/dlib/dlib/statistics/lda_abstract.h b/ml/dlib/dlib/statistics/lda_abstract.h new file mode 100644 index 000000000..ab9fd7a32 --- /dev/null +++ b/ml/dlib/dlib/statistics/lda_abstract.h @@ -0,0 +1,118 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LDA_ABSTRACT_Hh_ +#ifdef DLIB_LDA_ABSTRACT_Hh_ + +#include +#include "../matrix.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void compute_lda_transform ( + matrix& X, + matrix& M, + const std::vector& row_labels, + unsigned long lda_dims = 500, + unsigned long extra_pca_dims = 200 + ); + /*! + requires + - X.size() != 0 + - row_labels.size() == X.nr() + - The number of distinct values in row_labels > 1 + - lda_dims != 0 + ensures + - We interpret X as a collection X.nr() of input vectors, where each row of X + is one of the vectors. + - We interpret row_labels[i] as the label of the vector rowm(X,i). + - This function performs the dimensionality reducing version of linear + discriminant analysis. That is, you give it a set of labeled vectors and it + returns a linear transform that maps the input vectors into a new space that + is good for distinguishing between the different classes. In particular, + this function finds matrices Z and M such that: + - Given an input vector x, Z*x-M, is the transformed version of x. That is, + Z*x-M maps x into a space where x vectors that share the same class label + are near each other. + - Z*x-M results in the transformed vectors having zero expected mean. + - Z.nr() <= lda_dims + (it might be less than lda_dims if there are not enough distinct class + labels to support lda_dims dimensions). + - Z.nc() == X.nc() + - We overwrite the input matrix X and store Z in it. Therefore, the + outputs of this function are in X and M. + - In order to deal with very high dimensional inputs, we perform PCA internally + to map the input vectors into a space of at most lda_dims+extra_pca_dims + prior to performing LDA. + !*/ + +// ---------------------------------------------------------------------------------------- + + std::pair equal_error_rate ( + const std::vector& low_vals, + const std::vector& high_vals + ); + /*! + ensures + - This function finds a threshold T that best separates the elements of + low_vals from high_vals by selecting the threshold with equal error rate. In + particular, we try to pick a threshold T such that: + - for all valid i: + - high_vals[i] >= T + - for all valid i: + - low_vals[i] < T + Where the best T is determined such that the fraction of low_vals >= T is the + same as the fraction of high_vals < T. + - Let ERR == the equal error rate. I.e. the fraction of times low_vals >= T + and high_vals < T. Note that 0 <= ERR <= 1. + - returns make_pair(ERR,T) + !*/ + +// ---------------------------------------------------------------------------------------- + + struct roc_point + { + double true_positive_rate; + double false_positive_rate; + double detection_threshold; + }; + + std::vector compute_roc_curve ( + const std::vector& true_detections, + const std::vector& false_detections + ); + /*! + requires + - true_detections.size() != 0 + - false_detections.size() != 0 + ensures + - This function computes the ROC curve (receiver operating characteristic) + curve of the given data. Therefore, we interpret true_detections as + containing detection scores for a bunch of true detections and + false_detections as detection scores from a bunch of false detections. A + perfect detector would always give higher scores to true detections than to + false detections, resulting in a true positive rate of 1 and a false positive + rate of 0, for some appropriate detection threshold. + - Returns an array, ROC, such that: + - ROC.size() == true_detections.size()+false_detections.size() + - for all valid i: + - If you were to accept all detections with a score >= ROC[i].detection_threshold + then you would obtain a true positive rate of ROC[i].true_positive_rate and a + false positive rate of ROC[i].false_positive_rate. + - ROC is ordered such that low detection rates come first. That is, the + curve is swept from a high detection threshold to a low threshold. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LDA_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/statistics/random_subset_selector.h b/ml/dlib/dlib/statistics/random_subset_selector.h new file mode 100644 index 000000000..17492363d --- /dev/null +++ b/ml/dlib/dlib/statistics/random_subset_selector.h @@ -0,0 +1,372 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RANDOM_SUBSeT_SELECTOR_H_ +#define DLIB_RANDOM_SUBSeT_SELECTOR_H_ + +#include "random_subset_selector_abstract.h" +#include "../rand.h" +#include +#include "../algs.h" +#include "../string.h" +#include "../serialize.h" +#include "../matrix/matrix_mat.h" +#include + +namespace dlib +{ + template < + typename T, + typename Rand_type = dlib::rand + > + class random_subset_selector + { + /*! + INITIAL VALUE + - _max_size == 0 + - items.size() == 0 + - count == 0 + - _next_add_accepts == false + + CONVENTION + - count == the number of times add() has been called since the last + time this object was empty. + - items.size() == size() + - max_size() == _max_size + - next_add_accepts() == _next_add_accepts + !*/ + public: + typedef T type; + typedef T value_type; + typedef default_memory_manager mem_manager_type; + typedef Rand_type rand_type; + + typedef typename std::vector::iterator iterator; + typedef typename std::vector::const_iterator const_iterator; + + + random_subset_selector ( + ) + { + _max_size = 0; + make_empty(); + } + + void set_seed(const std::string& value) + { + rnd.set_seed(value); + } + + void make_empty ( + ) + { + items.resize(0); + count = 0; + update_next_add_accepts(); + } + + const std::vector& to_std_vector( + ) const { return items; } + + size_t size ( + ) const + { + return items.size(); + } + + void set_max_size ( + unsigned long new_max_size + ) + { + items.reserve(new_max_size); + make_empty(); + _max_size = new_max_size; + update_next_add_accepts(); + } + + unsigned long max_size ( + ) const + { + return _max_size; + } + + T& operator[] ( + unsigned long idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < size(), + "\tvoid random_subset_selector::operator[]()" + << "\n\t idx is out of range" + << "\n\t idx: " << idx + << "\n\t size(): " << size() + << "\n\t this: " << this + ); + + return items[idx]; + } + + const T& operator[] ( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(idx < size(), + "\tvoid random_subset_selector::operator[]()" + << "\n\t idx is out of range" + << "\n\t idx: " << idx + << "\n\t size(): " << size() + << "\n\t this: " << this + ); + + return items[idx]; + } + + iterator begin() { return items.begin(); } + const_iterator begin() const { return items.begin(); } + iterator end() { return items.end(); } + const_iterator end() const { return items.end(); } + + bool next_add_accepts ( + ) const + { + return _next_add_accepts; + } + + void add ( + const T& new_item + ) + { + if (items.size() < _max_size) + { + items.push_back(new_item); + // swap into a random place + exchange(items[rnd.get_random_32bit_number()%items.size()], items.back()); + } + else if (_next_add_accepts) + { + // pick a random element of items and replace it. + items[rnd.get_random_32bit_number()%items.size()] = new_item; + } + + update_next_add_accepts(); + ++count; + } + + void add ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(next_add_accepts() == false, + "\tvoid random_subset_selector::add()" + << "\n\t You should be calling the version of add() that takes an argument" + << "\n\t this: " << this + ); + + update_next_add_accepts(); + ++count; + } + + void swap ( + random_subset_selector& a + ) + { + items.swap(a.items); + std::swap(_max_size, a._max_size); + std::swap(count, a.count); + rnd.swap(a.rnd); + std::swap(_next_add_accepts, a._next_add_accepts); + } + + template + friend void serialize ( + const random_subset_selector& item, + std::ostream& out + ); + + template + friend void deserialize ( + random_subset_selector& item, + std::istream& in + ); + + private: + + void update_next_add_accepts ( + ) + { + if (items.size() < _max_size) + { + _next_add_accepts = true; + } + else if (_max_size == 0) + { + _next_add_accepts = false; + } + else + { + // At this point each element of items has had an equal chance of being in this object. + // In particular, the probability that each arrived here is currently items.size()/count. + // We need to be able to say that, after this function ends, the probability of any + // particular object ending up in items is items.size()/(count+1). So this means that + // we should decide to add a new item into items with this probability. Also, if we do + // so then we pick one of the current items and replace it at random with the new item. + + // Make me a random 64 bit number. This might seem excessive but I want this object + // to be able to handle an effectively infinite number of calls to add(). So count + // might get very large and we need to deal with that properly. + const unsigned long num1 = rnd.get_random_32bit_number(); + const unsigned long num2 = rnd.get_random_32bit_number(); + uint64 num = num1; + num <<= 32; + num |= num2; + + num %= (count+1); + + _next_add_accepts = (num < items.size()); + } + + } + + std::vector items; + unsigned long _max_size; + uint64 count; + + rand_type rnd; + + bool _next_add_accepts; + + }; + + template < + typename T, + typename rand_type + > + void swap ( + random_subset_selector& a, + random_subset_selector& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const random_subset_selector& item, + std::ostream& out + ) + { + serialize(item.items, out); + serialize(item._max_size, out); + serialize(item.count, out); + serialize(item.rnd, out); + serialize(item._next_add_accepts, out); + } + + template + void deserialize ( + random_subset_selector& item, + std::istream& in + ) + { + deserialize(item.items, in); + deserialize(item._max_size, in); + deserialize(item.count, in); + deserialize(item.rnd, in); + deserialize(item._next_add_accepts, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + random_subset_selector randomly_subsample ( + const std::vector& samples, + unsigned long num + ) + { + random_subset_selector subset; + subset.set_max_size(num); + for (unsigned long i = 0; i < samples.size(); ++i) + subset.add(samples[i]); + return subset; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc, + typename U + > + random_subset_selector randomly_subsample ( + const std::vector& samples, + unsigned long num, + const U& random_seed + ) + { + random_subset_selector subset; + subset.set_seed(cast_to_string(random_seed)); + subset.set_max_size(num); + for (unsigned long i = 0; i < samples.size(); ++i) + subset.add(samples[i]); + return subset; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + random_subset_selector randomly_subsample ( + const random_subset_selector& samples, + unsigned long num + ) + { + random_subset_selector subset; + subset.set_max_size(num); + for (unsigned long i = 0; i < samples.size(); ++i) + subset.add(samples[i]); + return subset; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + random_subset_selector randomly_subsample ( + const random_subset_selector& samples, + unsigned long num, + const U& random_seed + ) + { + random_subset_selector subset; + subset.set_seed(cast_to_string(random_seed)); + subset.set_max_size(num); + for (unsigned long i = 0; i < samples.size(); ++i) + subset.add(samples[i]); + return subset; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > > mat ( + const random_subset_selector& m + ) + { + typedef op_array_to_mat > op; + return matrix_op(op(m)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANDOM_SUBSeT_SELECTOR_H_ + + diff --git a/ml/dlib/dlib/statistics/random_subset_selector_abstract.h b/ml/dlib/dlib/statistics/random_subset_selector_abstract.h new file mode 100644 index 000000000..96f8b545d --- /dev/null +++ b/ml/dlib/dlib/statistics/random_subset_selector_abstract.h @@ -0,0 +1,388 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RANDOM_SUBSeT_SELECTOR_ABSTRACT_H_ +#ifdef DLIB_RANDOM_SUBSeT_SELECTOR_ABSTRACT_H_ + +#include +#include "../rand/rand_kernel_abstract.h" +#include "../algs.h" +#include "../string.h" + +namespace dlib +{ + template < + typename T, + typename Rand_type = dlib::rand + > + class random_subset_selector + { + /*! + REQUIREMENTS ON T + T must be a copyable type + + REQUIREMENTS ON Rand_type + must be an implementation of dlib/rand/rand_kernel_abstract.h + + INITIAL VALUE + - size() == 0 + - max_size() == 0 + - next_add_accepts() == false + + WHAT THIS OBJECT REPRESENTS + This object is a tool to help you select a random subset of a large body of data. + In particular, it is useful when the body of data is too large to fit into memory. + + So for example, suppose you have 1000000 data samples and you want to select a + random subset of size 1000. Then you could do that as follows: + + random_subset_selector rand_subset; + rand_subset.set_max_size(1000) + for (int i = 0; i < 1000000; ++i) + rand_subset.add( get_next_data_sample()); + + + At the end of the for loop you will have your random subset of 1000 samples. And by + random I mean that each of the 1000000 data samples has an equal chance of ending + up in the rand_subset object. + + + Note that the above example calls get_next_data_sample() for each data sample. This + may be inefficient since most of the data samples are just ignored. An alternative + method that doesn't require you to load each sample can also be used. Consider the + following: + + random_subset_selector rand_subset; + rand_subset.set_max_size(1000) + for (int i = 0; i < 1000000; ++i) + if (rand_subset.next_add_accepts()) + rand_subset.add(get_data_sample(i)); + else + rand_subset.add() + + In the above example we only actually fetch the data sample into memory if we + know that the rand_subset would include it into the random subset. Otherwise, + we can just call the empty add(). + + + Finally, note that the random_subset_selector uses a deterministic pseudo-random + number generator under the hood. Moreover, the default constructor always seeds + the random number generator in the same way. So unless you call set_seed() + each instance of the random_subset_selector will function identically. + !*/ + public: + typedef T type; + typedef T value_type; + typedef default_memory_manager mem_manager_type; + typedef Rand_type rand_type; + + typedef typename std::vector::iterator iterator; + typedef typename std::vector::const_iterator const_iterator; + + random_subset_selector ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void set_seed( + const std::string& value + ); + /*! + ensures + - sets the seed of the random number generator that is embedded in + this object to the given value. + !*/ + + void make_empty ( + ); + /*! + ensures + - #size() == 0 + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns the number of items of type T currently contained in this object + !*/ + + void set_max_size ( + unsigned long new_max_size + ); + /*! + ensures + - #max_size() == new_max_size + - #size() == 0 + !*/ + + unsigned long max_size ( + ) const; + /*! + ensures + - returns the maximum allowable size for this object + !*/ + + T& operator[] ( + unsigned long idx + ); + /*! + requires + - idx < size() + ensures + - returns a non-const reference to the idx'th element of this object + !*/ + + const T& operator[] ( + unsigned long idx + ) const; + /*! + requires + - idx < size() + ensures + - returns a const reference to the idx'th element of this object + !*/ + + bool next_add_accepts ( + ) const; + /*! + ensures + - if (the next call to add(item) will result in item being included + into *this) then + - returns true + - Note that the next item will always be accepted if size() < max_size(). + - else + - returns false + - Note that the next item will never be accepted if max_size() == 0. + !*/ + + void add ( + const T& new_item + ); + /*! + ensures + - if (next_add_accepts()) then + - places new_item into *this object at a random location + - if (size() < max_size()) then + - #size() == size() + 1 + - #next_add_accepts() == The updated information about the acceptance + of the next call to add() + !*/ + + void add ( + ); + /*! + requires + - next_add_accepts() == false + ensures + - This function does nothing but update the value of #next_add_accepts() + !*/ + + iterator begin( + ); + /*! + ensures + - if (size() > 0) then + - returns an iterator referring to the first element in + this container. + - else + - returns end() + !*/ + + const_iterator begin( + ) const; + /*! + ensures + - if (size() > 0) then + - returns a const_iterator referring to the first element in + this container. + - else + - returns end() + !*/ + + iterator end( + ); + /*! + ensures + - returns an iterator that represents one past the end of + this container + !*/ + + const_iterator end( + ) const; + /*! + ensures + - returns an iterator that represents one past the end of + this container + !*/ + + const std::vector& to_std_vector( + ) const; + /*! + ensures + - returns a const reference to the underlying std::vector that contains + all elements in this object. That is, this function returns a vector, V, + which has the following properties: + - V.size() == this->size() + - V.begin() == this->begin() + - V.end() == this->end() + !*/ + + void swap ( + random_subset_selector& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + template < + typename T, + typename rand_type + > + void swap ( + random_subset_selector& a, + random_subset_selector& b + ) { a.swap(b); } + /*! + provides global swap support + !*/ + + template < + typename T, + typename rand_type + > + void serialize ( + const random_subset_selector& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename T, + typename rand_type + > + void deserialize ( + random_subset_selector& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + random_subset_selector randomly_subsample ( + const std::vector& samples, + unsigned long num + ); + /*! + ensures + - returns a random subset R such that: + - R contains a random subset of the given samples + - R.size() == min(num, samples.size()) + - R.max_size() == num + - The random number generator used by this function will always be + initialized in the same way. I.e. this function will always pick + the same random subsample if called multiple times. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc, + typename U + > + random_subset_selector randomly_subsample ( + const std::vector& samples, + unsigned long num, + const U& random_seed + ); + /*! + requires + - random_seed must be convertible to a string by dlib::cast_to_string() + ensures + - returns a random subset R such that: + - R contains a random subset of the given samples + - R.size() == min(num, samples.size()) + - R.max_size() == num + - The given random_seed will be used to initialize the random number + generator used by this function. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + random_subset_selector randomly_subsample ( + const random_subset_selector& samples, + unsigned long num + ); + /*! + ensures + - returns a random subset R such that: + - R contains a random subset of the given samples + - R.size() == min(num, samples.size()) + - R.max_size() == num + - The random number generator used by this function will always be + initialized in the same way. I.e. this function will always pick + the same random subsample if called multiple times. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + random_subset_selector randomly_subsample ( + const random_subset_selector& samples, + unsigned long num, + const U& random_seed + ); + /*! + requires + - random_seed must be convertible to a string by dlib::cast_to_string() + ensures + - returns a random subset R such that: + - R contains a random subset of the given samples + - R.size() == min(num, samples.size()) + - R.max_size() == num + - The given random_seed will be used to initialize the random number + generator used by this function. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_exp mat ( + const random_subset_selector& m + ); + /*! + ensures + - returns a matrix R such that: + - is_col_vector(R) == true + - R.size() == m.size() + - for all valid r: + R(r) == m[r] + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANDOM_SUBSeT_SELECTOR_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/statistics/running_gradient.h b/ml/dlib/dlib/statistics/running_gradient.h new file mode 100644 index 000000000..d3f1b3ddf --- /dev/null +++ b/ml/dlib/dlib/statistics/running_gradient.h @@ -0,0 +1,370 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RuNNING_GRADIENT_Hh_ +#define DLIB_RuNNING_GRADIENT_Hh_ + +#include "running_gradient_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include +#include "../matrix.h" +#include + + +namespace dlib +{ + class running_gradient + { + public: + + running_gradient ( + ) + { + clear(); + } + + void clear( + ) + { + n = 0; + R = identity_matrix(2)*1e6; + w = 0; + residual_squared = 0; + } + + double current_n ( + ) const + { + return n; + } + + void add( + double y + ) + { + matrix x; + x = n, 1; + + // Do recursive least squares computations + const double temp = 1 + trans(x)*R*x; + matrix tmp = R*x; + R = R - (tmp*trans(tmp))/temp; + // R should always be symmetric. This line improves numeric stability of this algorithm. + R = 0.5*(R + trans(R)); + w = w + R*x*(y - trans(x)*w); + + // Also, recursively keep track of the residual error between the given value + // and what our linear predictor outputs. + residual_squared = residual_squared + std::pow((y - trans(x)*w),2.0)*temp; + + ++n; + } + + double gradient ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\t double running_gradient::gradient()" + << "\n\t You must add more values into this object before calling this function." + << "\n\t this: " << this + ); + + return w(0); + } + + double intercept ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 0, + "\t double running_gradient::intercept()" + << "\n\t You must add more values into this object before calling this function." + << "\n\t this: " << this + ); + + return w(1); + } + double standard_error ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 2, + "\t double running_gradient::standard_error()" + << "\n\t You must add more values into this object before calling this function." + << "\n\t this: " << this + ); + + + const double s = residual_squared/(n-2); + const double adjust = 12.0/(std::pow(current_n(),3.0) - current_n()); + return std::sqrt(s*adjust); + } + + double probability_gradient_less_than ( + double thresh + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 2, + "\t double running_gradient::probability_gradient_less_than()" + << "\n\t You must add more values into this object before calling this function." + << "\n\t this: " << this + ); + + return normal_cdf(thresh, gradient(), standard_error()); + } + + double probability_gradient_greater_than ( + double thresh + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 2, + "\t double running_gradient::probability_gradient_greater_than()" + << "\n\t You must add more values into this object before calling this function." + << "\n\t this: " << this + ); + + return 1-probability_gradient_less_than(thresh); + } + + friend void serialize (const running_gradient& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.n, out); + serialize(item.R, out); + serialize(item.w, out); + serialize(item.residual_squared, out); + } + + friend void deserialize (running_gradient& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::running_gradient."); + deserialize(item.n, in); + deserialize(item.R, in); + deserialize(item.w, in); + deserialize(item.residual_squared, in); + } + + private: + + static double normal_cdf(double value, double mean, double stddev) + { + if (stddev == 0) + { + if (value < mean) + return 0; + else if (value > mean) + return 1; + else + return 0.5; + } + value = (value-mean)/stddev; + return 0.5 * std::erfc(-value / std::sqrt(2.0)); + } + + double n; + matrix R; + matrix w; + double residual_squared; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + double probability_gradient_less_than ( + const T& container, + double thresh + ) + { + running_gradient g; + for(auto&& v : container) + g.add(v); + + // make sure requires clause is not broken + DLIB_ASSERT(g.current_n() > 2, + "\t double probability_gradient_less_than()" + << "\n\t You need more than 2 elements in the given container to call this function." + ); + return g.probability_gradient_less_than(thresh); + } + + template < + typename T + > + double probability_gradient_greater_than ( + const T& container, + double thresh + ) + { + running_gradient g; + for(auto&& v : container) + g.add(v); + + // make sure requires clause is not broken + DLIB_ASSERT(g.current_n() > 2, + "\t double probability_gradient_greater_than()" + << "\n\t You need more than 2 elements in the given container to call this function." + ); + return g.probability_gradient_greater_than(thresh); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + double find_upper_quantile ( + const T& container_, + double quantile + ) + { + DLIB_CASSERT(0 <= quantile && quantile <= 1.0); + + // copy container into a std::vector + std::vector container(container_.begin(), container_.end()); + + DLIB_CASSERT(container.size() > 0); + + size_t idx_upper = std::round((container.size()-1)*(1-quantile)); + + std::nth_element(container.begin(), container.begin()+idx_upper, container.end()); + auto upper_q = *(container.begin()+idx_upper); + return upper_q; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + size_t count_steps_without_decrease ( + const T& container, + double probability_of_decrease = 0.51 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0.5 < probability_of_decrease && probability_of_decrease < 1, + "\t size_t count_steps_without_decrease()" + << "\n\t probability_of_decrease: "<< probability_of_decrease + ); + + running_gradient g; + size_t count = 0; + size_t j = 0; + for (auto i = container.rbegin(); i != container.rend(); ++i) + { + ++j; + g.add(*i); + if (g.current_n() > 2) + { + // Note that this only looks backwards because we are looping over the + // container backwards. So here we are really checking if the gradient isn't + // decreasing. + double prob_decreasing = g.probability_gradient_greater_than(0); + // If we aren't confident things are decreasing. + if (prob_decreasing < probability_of_decrease) + count = j; + } + } + return count; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + size_t count_steps_without_decrease_robust ( + const T& container, + double probability_of_decrease = 0.51, + double quantile_discard = 0.10 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= quantile_discard && quantile_discard <= 1); + DLIB_ASSERT(0.5 < probability_of_decrease && probability_of_decrease < 1, + "\t size_t count_steps_without_decrease_robust()" + << "\n\t probability_of_decrease: "<< probability_of_decrease + ); + + if (container.size() == 0) + return 0; + + const auto quantile_thresh = find_upper_quantile(container, quantile_discard); + + running_gradient g; + size_t count = 0; + size_t j = 0; + for (auto i = container.rbegin(); i != container.rend(); ++i) + { + ++j; + // ignore values that are too large + if (*i <= quantile_thresh) + g.add(*i); + + if (g.current_n() > 2) + { + // Note that this only looks backwards because we are looping over the + // container backwards. So here we are really checking if the gradient isn't + // decreasing. + double prob_decreasing = g.probability_gradient_greater_than(0); + // If we aren't confident things are decreasing. + if (prob_decreasing < probability_of_decrease) + count = j; + } + } + return count; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + size_t count_steps_without_increase ( + const T& container, + double probability_of_increase = 0.51 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0.5 < probability_of_increase && probability_of_increase < 1, + "\t size_t count_steps_without_increase()" + << "\n\t probability_of_increase: "<< probability_of_increase + ); + + running_gradient g; + size_t count = 0; + size_t j = 0; + for (auto i = container.rbegin(); i != container.rend(); ++i) + { + ++j; + g.add(*i); + if (g.current_n() > 2) + { + // Note that this only looks backwards because we are looping over the + // container backwards. So here we are really checking if the gradient isn't + // increasing. + double prob_increasing = g.probability_gradient_less_than(0); + // If we aren't confident things are increasing. + if (prob_increasing < probability_of_increase) + count = j; + } + } + return count; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RuNNING_GRADIENT_Hh_ + + diff --git a/ml/dlib/dlib/statistics/running_gradient_abstract.h b/ml/dlib/dlib/statistics/running_gradient_abstract.h new file mode 100644 index 000000000..a42e1c152 --- /dev/null +++ b/ml/dlib/dlib/statistics/running_gradient_abstract.h @@ -0,0 +1,276 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RuNNING_GRADIENT_ABSTRACT_Hh_ +#ifdef DLIB_RuNNING_GRADIENT_ABSTRACT_Hh_ + + +namespace dlib +{ + class running_gradient + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for estimating if a noisy sequence of numbers is + trending up or down and by how much. It does this by finding the least + squares fit of a line to the data and then allows you to perform a + statistical test on the slope of that line. + !*/ + + public: + + running_gradient ( + ); + /*! + ensures + - #current_n() == 0 + !*/ + + void clear( + ); + /*! + ensures + - #current_n() == 0 + - this object has its initial value + - clears all memory of any previous data points + !*/ + + double current_n ( + ) const; + /*! + ensures + - returns the number of values given to this object by add(). + !*/ + + void add( + double y + ); + /*! + ensures + - Updates the gradient() and standard_error() estimates in this object + based on the new y value. + - #current_n() == current_n() + 1 + !*/ + + double gradient ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - If we consider the values given to add() as time series data, we can + estimate the rate-of-change of those values. That is, how much, + typically, do those values change from sample to sample? The gradient() + function returns the current estimate. It does this by finding the least + squares fit of a line to the data given to add() and returning the slope + of this line. + !*/ + + double intercept ( + ) const; + /*! + requires + - current_n() > 0 + ensures + - This class fits a line to the time series data given to add(). This + function returns the intercept of that line while gradient() returns the + slope of that line. This means that, for example, the next point that + add() will see, as predicted by this best fit line, is the value + intercept() + current_n()*gradient(). + !*/ + + double standard_error ( + ) const; + /*! + requires + - current_n() > 2 + ensures + - returns the standard deviation of the estimate of gradient(). + !*/ + + double probability_gradient_less_than ( + double thresh + ) const; + /*! + requires + - current_n() > 2 + ensures + - If we can assume the values given to add() are linearly related to each + other and corrupted by Gaussian additive noise then our estimate of + gradient() is a random variable with a mean value of gradient() and a + standard deviation of standard_error(). This lets us compute the + probability that the true gradient of the data is less than thresh, which + is what this function returns. + !*/ + + double probability_gradient_greater_than ( + double thresh + ) const; + /*! + requires + - current_n() > 2 + ensures + - returns 1-probability_gradient_less_than(thresh) + !*/ + + }; + + void serialize ( + const running_gradient& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + running_gradient& item, + std::istream& in + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + double probability_gradient_less_than ( + const T& container, + double thresh + ); + /*! + requires + - container must be a container of double values that can be enumerated with a + range based for loop. + - The container must contain more than 2 elements. + ensures + - Puts all the elements of container into a running_gradient object, R, and + then returns R.probability_gradient_less_than(thresh). + !*/ + + template < + typename T + > + double probability_gradient_greater_than ( + const T& container, + double thresh + ); + /*! + requires + - container must be a container of double values that can be enumerated with a + range based for loop. + - The container must contain more than 2 elements. + ensures + - Puts all the elements of container into a running_gradient object, R, and + then returns R.probability_gradient_greater_than(thresh). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + size_t count_steps_without_decrease ( + const T& container, + double probability_of_decrease = 0.51 + ); + /*! + requires + - container must be a container of double values that can be enumerated with + .rbegin() and .rend(). + - 0.5 < probability_of_decrease < 1 + ensures + - If you think of the contents of container as a potentially noisy time series, + then this function returns a count of how long the time series has gone + without noticeably decreasing in value. It does this by adding the + elements into a running_gradient object and counting how many elements, + starting with container.back(), that you need to examine before you are + confident that the series has been decreasing in value. Here, "confident of + decrease" means that the probability of decrease is >= probability_of_decrease. + - Setting probability_of_decrease to 0.51 means we count until we see even a + small hint of decrease, whereas a larger value of 0.99 would return a larger + count since it keeps going until it is nearly certain the time series is + decreasing. + - The max possible output from this function is container.size(). + !*/ + + template < + typename T + > + size_t count_steps_without_decrease_robust ( + const T& container, + double probability_of_decrease = 0.51, + double quantile_discard = 0.10 + ); + /*! + requires + - container must be a container of double values that can be enumerated with + .begin() and .end() as well as .rbegin() and .rend(). + - 0.5 < probability_of_decrease < 1 + - 0 <= quantile_discard <= 1 + ensures + - This function behaves just like + count_steps_without_decrease(container,probability_of_decrease) except that + it ignores values in container that are in the upper quantile_discard + quantile. So for example, if the quantile discard is 0.1 then the 10% + largest values in container are ignored. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + size_t count_steps_without_increase ( + const T& container, + double probability_of_increase = 0.51 + ); + /*! + requires + - container must be a container of double values that can be enumerated with + .rbegin() and .rend(). + - 0.5 < probability_of_increase < 1 + ensures + - If you think of the contents of container as a potentially noisy time series, + then this function returns a count of how long the time series has gone + without noticeably increasing in value. It does this by adding the + elements into a running_gradient object and counting how many elements, + starting with container.back(), that you need to examine before you are + confident that the series has been increasing in value. Here, "confident of + increase" means that the probability of increase is >= probability_of_increase. + - Setting probability_of_increase to 0.51 means we count until we see even a + small hint of increase, whereas a larger value of 0.99 would return a larger + count since it keeps going until it is nearly certain the time series is + increasing. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + double find_upper_quantile ( + const T& container, + double quantile + ); + /*! + requires + - container must be a container of double values that can be enumerated with + .begin() and .end(). + - 0 <= quantile <= 1 + - container.size() > 0 + ensures + - Finds and returns the value such that quantile percent of the values in + container are greater than it. For example, 0.5 would find the median value + in container while 0.1 would find the value that lower bounded the 10% + largest values in container. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RuNNING_GRADIENT_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/statistics/sammon.h b/ml/dlib/dlib/statistics/sammon.h new file mode 100644 index 000000000..1a3eb72a1 --- /dev/null +++ b/ml/dlib/dlib/statistics/sammon.h @@ -0,0 +1,269 @@ +// Copyright (C) 2012 Emanuele Cesena (emanuele.cesena@gmail.com), Davis E. King +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SAMMoN_Hh_ +#define DLIB_SAMMoN_Hh_ + +#include "sammon_abstract.h" +#include "../matrix.h" +#include "../algs.h" +#include "dpca.h" +#include + +namespace dlib +{ + + class sammon_projection + { + + public: + + // ------------------------------------------------------------------------------------ + + template + std::vector > operator() ( + const std::vector& data, + const long num_dims + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(num_dims > 0, + "\t std::vector > sammon_projection::operator()" + << "\n\t Invalid inputs were given to this function." + << "\n\t num_dims: " << num_dims + ); + std::vector > result; // projections + if (data.size() == 0) + { + return result; + } + +#ifdef ENABLE_ASSERTS + DLIB_ASSERT(0 < num_dims && num_dims <= data[0].size(), + "\t std::vector > sammon_projection::operator()" + << "\n\t Invalid inputs were given to this function." + << "\n\t data.size(): " << data.size() + << "\n\t num_dims: " << num_dims + << "\n\t data[0].size(): " << data[0].size() + ); + for (unsigned long i = 0; i < data.size(); ++i) + { + DLIB_ASSERT(is_col_vector(data[i]) && data[i].size() == data[0].size(), + "\t std::vector > sammon_projection::operator()" + << "\n\t Invalid inputs were given to this function." + << "\n\t data["< + void operator() ( + const std::vector& data, + const long num_dims, + std::vector >& result, + double &err, + const unsigned long num_iters = 1000, + const double err_delta = 1.0e-9 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(num_dims > 0 && num_iters > 0 && err_delta > 0.0, + "\t std::vector > sammon_projection::operator()" + << "\n\t Invalid inputs were given to this function." + << "\n\t data.size(): " << data.size() + << "\n\t num_dims: " << num_dims + << "\n\t num_iters: " << num_iters + << "\n\t err_delta: " << err_delta + ); + if (data.size() == 0) + { + result.clear(); + err = 0; + return; + } + +#ifdef ENABLE_ASSERTS + DLIB_ASSERT(0 < num_dims && num_dims <= data[0].size(), + "\t std::vector > sammon_projection::operator()" + << "\n\t Invalid inputs were given to this function." + << "\n\t data.size(): " << data.size() + << "\n\t num_dims: " << num_dims + << "\n\t data[0].size(): " << data[0].size() + ); + for (unsigned long i = 0; i < data.size(); ++i) + { + DLIB_ASSERT(is_col_vector(data[i]) && data[i].size() == data[0].size(), + "\t std::vector > sammon_projection::operator()" + << "\n\t Invalid inputs were given to this function." + << "\n\t data["<& dist, // relative distances (output) + matrix& data, // input data (matrix whose columns are the input vectors) + double eps_ratio = 1.0e-7 // to compute the minimum distance eps + ) + /*! + requires + - dist.nc() == comb( data.nc(), 2 ), preallocated + - eps_ratio > 0 + ensures + - dist[k] == lenght(data[i] - data[j]) for k = j(j-1)/2 + i + !*/ + { + const long N = data.nc(); // num of points + double eps; // minimum distance, forced to avoid vectors collision + // computed at runtime as eps_ration * mean(vectors distances) + for (int k = 0, i = 1; i < N; ++i) + for (int j = 0; j < i; ++j) + dist(k++) = length(colm(data, i) - colm(data, j)); + + eps = eps_ratio * mean(dist); + dist = lowerbound(dist, eps); + } + + // ---------------------------------------------------------------------------------------- + + template + void do_sammon_projection( + const std::vector& data, // input data + unsigned long num_dims, // dimension of the reduced space + std::vector >& result, // projections (output) + double &err, // error (output) + unsigned long num_iters = 1000, // max num of iterations: stop condition + const double err_delta = 1.0e-9 // delta error: stop condition + ) + /*! + requires + - matrix_type should be a kind of dlib::matrix + - num_dims > 0 + - num_iters > 0 + - err_delta > 0 + ensures + - result == a set of matrix objects that represent + the Sammon's projections of data vectors. + - err == the estimated error done in the projection, with the extra + property that err(at previous iteration) - err < err_delta + !*/ + { + // other params + const double mf = 0.3; // magic factor + + matrix mdata; // input data as matrix + matrix projs; // projected vectors, i.e. output data as matrix + + // std::vector -> matrix + mdata.set_size(data[0].size(), data.size()); + for (unsigned int i = 0; i < data.size(); i++) + set_colm(mdata, i) = data[i]; + + const long N = mdata.nc(); // num of points + const long d = num_dims; // size of the reduced space + const long nd = N * (N - 1) / 2; // num of pairs of points = size of the distances vectors + + matrix dsij, inv_dsij; // d*_ij: pair-wise distances in the input space (and inverses) + dsij.set_size(nd, 1); + inv_dsij.set_size(nd, 1); + double ic; // 1.0 / sum of dsij + + matrix dij; // d_ij: pair-wise distances in the reduced space + dij.set_size(nd, 1); + + matrix dE, dE2, dtemp; // matrices representing error partial derivatives + dE.set_size(d, N); + dE2.set_size(d, N); + dtemp.set_size(d, N); + + matrix inv_dij, alpha; // utility vectors used to compute the partial derivatives + inv_dij.set_size(N, 1); // inv_dij is 1.0/dij, but we only need it column-wise + alpha.set_size(N, 1); // (slightly wasting a bit of computation) + // alpha = 1.0/dij - 1.0/dsij, again column-wise + + + // initialize projs with PCA + discriminant_pca > dpca; + for (int i = 0; i < mdata.nc(); ++i) + { + dpca.add_to_total_variance(colm(mdata, i)); + } + matrix mat = dpca.dpca_matrix_of_size(num_dims); + projs = mat * mdata; + + // compute dsij, inv_dsij and ic + compute_relative_distances(dsij, mdata); + inv_dsij = 1.0 / dsij; + ic = 1.0 / sum(dsij); + + // compute dij and err + compute_relative_distances(dij, projs); + err = ic * sum(pointwise_multiply(squared(dij - dsij), inv_dsij)); + + // start iterating + while (num_iters--) + { + // compute dE, dE2 progressively column by column + for (int p = 0; p < N; ++p) + { + // compute + // - alpha_p, the column vector with 1/d_pj - 1/d*_pj + // - dtemp, the matrix with the p-th column repeated all along + //TODO: optimize constructions + for (int i = 0; i < N; ++i) + { + int pos = (i < p) ? p * (p - 1) / 2 + i : i * (i - 1) / 2 + p; + inv_dij(i) = (i == p) ? 0.0 : 1.0 / dij(pos); + alpha(i) = (i == p) ? 0.0 : inv_dij(i) - inv_dsij(pos); + set_colm(dtemp, i) = colm(projs, p); + } + + dtemp -= projs; + set_colm(dE, p) = dtemp * alpha; + + double sum_alpha = sum(alpha); + set_colm(dE2, p) = abs( sum_alpha + squared(dtemp) * cubed(inv_dij) ); + } + + + // compute the update projections + projs += pointwise_multiply(dE, mf * reciprocal(dE2)); + + // compute new dij and error + compute_relative_distances(dij, projs); + double err_new = ic * sum( pointwise_multiply(squared(dij - dsij), inv_dsij) ); + if (err - err_new < err_delta) + break; + err = err_new; + } + + // matrix -> std::vector + result.clear(); + for (int i = 0; i < projs.nc(); ++i) + result.push_back(colm(projs, i)); + } + + }; + +} // namespace dlib + +#endif // DLIB_SAMMoN_Hh_ + diff --git a/ml/dlib/dlib/statistics/sammon_abstract.h b/ml/dlib/dlib/statistics/sammon_abstract.h new file mode 100644 index 000000000..0e009729c --- /dev/null +++ b/ml/dlib/dlib/statistics/sammon_abstract.h @@ -0,0 +1,117 @@ +// Copyright (C) 2012 Emanuele Cesena (emanuele.cesena@gmail.com), Davis E. King +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SAMMoN_ABSTRACT_Hh_ +#ifdef DLIB_SAMMoN_ABSTRACT_Hh_ + +#include "../matrix/matrix_abstract.h" +#include + +namespace dlib +{ + + class sammon_projection + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a function object that computes the Sammon projection of a set + of N points in a L-dimensional vector space onto a d-dimensional space + (d < L), according to the paper: + A Nonlinear Mapping for Data Structure Analysis (1969) by J.W. Sammon + + The current implementation is a vectorized version of the original algorithm. + !*/ + + public: + + sammon_projection( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template + std::vector > operator() ( + const std::vector& data, + const long num_dims + ); + /*! + requires + - num_dims > 0 + - matrix_type should be a kind of dlib::matrix of doubles capable + of representing column vectors. + - for all valid i: + - is_col_vector(data[i]) == true + - data[0].size() == data[i].size() + (i.e. all the vectors in data must have the same dimensionality) + - if (data.size() != 0) then + - 0 < num_dims <= data[0].size() + (i.e. you can't project into a higher dimension than the input data, + only to a lower dimension.) + ensures + - This routine computes Sammon's dimensionality reduction method based on the + given input data. It will attempt to project the contents of data into a + num_dims dimensional space that preserves relative distances between the + input data points. + - This function returns a std::vector, OUT, such that: + - OUT == a set of column vectors that represent the Sammon projection of + the input data vectors. + - OUT.size() == data.size() + - for all valid i: + - OUT[i].size() == num_dims + - OUT[i] == the Sammon projection of the input vector data[i] + !*/ + + template + void operator() ( + const std::vector& data, + const long num_dims, + std::vector >& result, + double &err, + const unsigned long num_iters = 1000, + const double err_delta = 1.0e-9 + ); + /*! + requires + - num_iters > 0 + - err_delta > 0 + - num_dims > 0 + - matrix_type should be a kind of dlib::matrix of doubles capable + of representing column vectors. + - for all valid i: + - is_col_vector(data[i]) == true + - data[0].size() == data[i].size() + (i.e. all the vectors in data must have the same dimensionality) + - if (data.size() != 0) then + - 0 < num_dims <= data[0].size() + (i.e. you can't project into a higher dimension than the input data, + only to a lower dimension.) + ensures + - This routine computes Sammon's dimensionality reduction method based on the + given input data. It will attempt to project the contents of data into a + num_dims dimensional space that preserves relative distances between the + input data points. + - #err == the final error value at the end of the algorithm. The goal of Sammon's + algorithm is to find a lower dimensional projection of the input data that + preserves the relative distances between points. The value in #err is a measure + of the total error at the end of the algorithm. So smaller values indicate + a better projection was found than if a large value is returned via #err. + - Sammon's algorithm will run until either num_iters iterations has executed + or the change in error from one iteration to the next is less than err_delta. + - Upon completion, the output of Sammon's projection is stored into #result, in + particular, we will have: + - #result == a set of column vectors that represent the Sammon projection of + the input data vectors. + - #result.size() == data.size() + - for all valid i: + - #result[i].size() == num_dims + - #result[i] == the Sammon projection of the input vector data[i] + !*/ + + }; + +} + +#endif // DLIB_SAMMoN_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/statistics/statistics.h b/ml/dlib/dlib/statistics/statistics.h new file mode 100644 index 000000000..9dee7006b --- /dev/null +++ b/ml/dlib/dlib/statistics/statistics.h @@ -0,0 +1,1890 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net), Steve Taylor +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STATISTICs_ +#define DLIB_STATISTICs_ + +#include "statistics_abstract.h" +#include +#include +#include "../algs.h" +#include "../matrix.h" +#include "../sparse_vector.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class running_stats + { + public: + + running_stats() + { + clear(); + + COMPILE_TIME_ASSERT (( + is_same_type::value || + is_same_type::value || + is_same_type::value + )); + } + + void clear() + { + sum = 0; + sum_sqr = 0; + sum_cub = 0; + sum_four = 0; + + n = 0; + min_value = std::numeric_limits::infinity(); + max_value = -std::numeric_limits::infinity(); + } + + void add ( + const T& val + ) + { + sum += val; + sum_sqr += val*val; + sum_cub += cubed(val); + sum_four += quaded(val); + + if (val < min_value) + min_value = val; + if (val > max_value) + max_value = val; + + ++n; + } + + T current_n ( + ) const + { + return n; + } + + T mean ( + ) const + { + if (n != 0) + return sum/n; + else + return 0; + } + + T max ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 0, + "\tT running_stats::max" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return max_value; + } + + T min ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 0, + "\tT running_stats::min" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return min_value; + } + + T variance ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_stats::variance" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + T temp = 1/(n-1); + temp = temp*(sum_sqr - sum*sum/n); + // make sure the variance is never negative. This might + // happen due to numerical errors. + if (temp >= 0) + return temp; + else + return 0; + } + + T stddev ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_stats::stddev" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return std::sqrt(variance()); + } + + T skewness ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 2, + "\tT running_stats::skewness" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + T temp = 1/n; + T temp1 = std::sqrt(n*(n-1))/(n-2); + temp = temp1*temp*(sum_cub - 3*sum_sqr*sum*temp + 2*cubed(sum)*temp*temp)/ + (std::sqrt(std::pow(temp*(sum_sqr-sum*sum*temp),3))); + + return temp; + } + + T ex_kurtosis ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 3, + "\tT running_stats::kurtosis" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + T temp = 1/n; + T m4 = temp*(sum_four - 4*sum_cub*sum*temp+6*sum_sqr*sum*sum*temp*temp + -3*quaded(sum)*cubed(temp)); + T m2 = temp*(sum_sqr-sum*sum*temp); + temp = (n-1)*((n+1)*m4/(m2*m2)-3*(n-1))/((n-2)*(n-3)); + + return temp; + } + + T scale ( + const T& val + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_stats::variance" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + return (val-mean())/std::sqrt(variance()); + } + + running_stats operator+ ( + const running_stats& rhs + ) const + { + running_stats temp(*this); + + temp.sum += rhs.sum; + temp.sum_sqr += rhs.sum_sqr; + temp.sum_cub += rhs.sum_cub; + temp.sum_four += rhs.sum_four; + temp.n += rhs.n; + temp.min_value = std::min(rhs.min_value, min_value); + temp.max_value = std::max(rhs.max_value, max_value); + return temp; + } + + template + friend void serialize ( + const running_stats& item, + std::ostream& out + ); + + template + friend void deserialize ( + running_stats& item, + std::istream& in + ); + + private: + T sum; + T sum_sqr; + T sum_cub; + T sum_four; + T n; + T min_value; + T max_value; + + T cubed (const T& val) const {return val*val*val; } + T quaded (const T& val) const {return val*val*val*val; } + }; + + template + void serialize ( + const running_stats& item, + std::ostream& out + ) + { + int version = 2; + serialize(version, out); + + serialize(item.sum, out); + serialize(item.sum_sqr, out); + serialize(item.sum_cub, out); + serialize(item.sum_four, out); + serialize(item.n, out); + serialize(item.min_value, out); + serialize(item.max_value, out); + } + + template + void deserialize ( + running_stats& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 2) + throw dlib::serialization_error("Unexpected version number found while deserializing dlib::running_stats object."); + + deserialize(item.sum, in); + deserialize(item.sum_sqr, in); + deserialize(item.sum_cub, in); + deserialize(item.sum_four, in); + deserialize(item.n, in); + deserialize(item.min_value, in); + deserialize(item.max_value, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class running_scalar_covariance + { + public: + + running_scalar_covariance() + { + clear(); + + COMPILE_TIME_ASSERT (( + is_same_type::value || + is_same_type::value || + is_same_type::value + )); + } + + void clear() + { + sum_xy = 0; + sum_x = 0; + sum_y = 0; + sum_xx = 0; + sum_yy = 0; + n = 0; + } + + void add ( + const T& x, + const T& y + ) + { + sum_xy += x*y; + + sum_xx += x*x; + sum_yy += y*y; + + sum_x += x; + sum_y += y; + + n += 1; + } + + T current_n ( + ) const + { + return n; + } + + T mean_x ( + ) const + { + if (n != 0) + return sum_x/n; + else + return 0; + } + + T mean_y ( + ) const + { + if (n != 0) + return sum_y/n; + else + return 0; + } + + T covariance ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance::covariance()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return 1/(n-1) * (sum_xy - sum_y*sum_x/n); + } + + T correlation ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance::correlation()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return covariance() / std::sqrt(variance_x()*variance_y()); + } + + T variance_x ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance::variance_x()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + T temp = 1/(n-1) * (sum_xx - sum_x*sum_x/n); + // make sure the variance is never negative. This might + // happen due to numerical errors. + if (temp >= 0) + return temp; + else + return 0; + } + + T variance_y ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance::variance_y()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + T temp = 1/(n-1) * (sum_yy - sum_y*sum_y/n); + // make sure the variance is never negative. This might + // happen due to numerical errors. + if (temp >= 0) + return temp; + else + return 0; + } + + T stddev_x ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance::stddev_x()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return std::sqrt(variance_x()); + } + + T stddev_y ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance::stddev_y()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return std::sqrt(variance_y()); + } + + running_scalar_covariance operator+ ( + const running_scalar_covariance& rhs + ) const + { + running_scalar_covariance temp(rhs); + + temp.sum_xy += sum_xy; + temp.sum_x += sum_x; + temp.sum_y += sum_y; + temp.sum_xx += sum_xx; + temp.sum_yy += sum_yy; + temp.n += n; + return temp; + } + + private: + + T sum_xy; + T sum_x; + T sum_y; + T sum_xx; + T sum_yy; + T n; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class running_scalar_covariance_decayed + { + public: + + explicit running_scalar_covariance_decayed( + T decay_halflife = 1000 + ) + { + DLIB_ASSERT(decay_halflife > 0); + + sum_xy = 0; + sum_x = 0; + sum_y = 0; + sum_xx = 0; + sum_yy = 0; + forget = std::pow(0.5, 1/decay_halflife); + n = 0; + + COMPILE_TIME_ASSERT (( + is_same_type::value || + is_same_type::value || + is_same_type::value + )); + } + + T forget_factor ( + ) const + { + return forget; + } + + void add ( + const T& x, + const T& y + ) + { + sum_xy = sum_xy*forget + x*y; + + sum_xx = sum_xx*forget + x*x; + sum_yy = sum_yy*forget + y*y; + + sum_x = sum_x*forget + x; + sum_y = sum_y*forget + y; + + n = n*forget + 1; + } + + T current_n ( + ) const + { + return n; + } + + T mean_x ( + ) const + { + if (n != 0) + return sum_x/n; + else + return 0; + } + + T mean_y ( + ) const + { + if (n != 0) + return sum_y/n; + else + return 0; + } + + T covariance ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance_decayed::covariance()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return 1/(n-1) * (sum_xy - sum_y*sum_x/n); + } + + T correlation ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance_decayed::correlation()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + T temp = std::sqrt(variance_x()*variance_y()); + if (temp != 0) + return covariance() / temp; + else + return 0; // just say it's zero if there isn't any variance in x or y. + } + + T variance_x ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance_decayed::variance_x()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + T temp = 1/(n-1) * (sum_xx - sum_x*sum_x/n); + // make sure the variance is never negative. This might + // happen due to numerical errors. + if (temp >= 0) + return temp; + else + return 0; + } + + T variance_y ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance_decayed::variance_y()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + T temp = 1/(n-1) * (sum_yy - sum_y*sum_y/n); + // make sure the variance is never negative. This might + // happen due to numerical errors. + if (temp >= 0) + return temp; + else + return 0; + } + + T stddev_x ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance_decayed::stddev_x()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return std::sqrt(variance_x()); + } + + T stddev_y ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_scalar_covariance_decayed::stddev_y()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return std::sqrt(variance_y()); + } + + private: + + T sum_xy; + T sum_x; + T sum_y; + T sum_xx; + T sum_yy; + T n; + T forget; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class running_stats_decayed + { + public: + + explicit running_stats_decayed( + T decay_halflife = 1000 + ) + { + DLIB_ASSERT(decay_halflife > 0); + + sum_x = 0; + sum_xx = 0; + forget = std::pow(0.5, 1/decay_halflife); + n = 0; + + COMPILE_TIME_ASSERT (( + is_same_type::value || + is_same_type::value || + is_same_type::value + )); + } + + T forget_factor ( + ) const + { + return forget; + } + + void add ( + const T& x + ) + { + + sum_xx = sum_xx*forget + x*x; + + sum_x = sum_x*forget + x; + + n = n*forget + 1; + } + + T current_n ( + ) const + { + return n; + } + + T mean ( + ) const + { + if (n != 0) + return sum_x/n; + else + return 0; + } + + T variance ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_stats_decayed::variance()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + T temp = 1/(n-1) * (sum_xx - sum_x*sum_x/n); + // make sure the variance is never negative. This might + // happen due to numerical errors. + if (temp >= 0) + return temp; + else + return 0; + } + + T stddev ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_n() > 1, + "\tT running_stats_decayed::stddev()" + << "\n\tyou have to add some numbers to this object first" + << "\n\tthis: " << this + ); + + return std::sqrt(variance()); + } + + template + friend void serialize ( + const running_stats_decayed& item, + std::ostream& out + ); + + template + friend void deserialize ( + running_stats_decayed& item, + std::istream& in + ); + + private: + + T sum_x; + T sum_xx; + T n; + T forget; + }; + + template + void serialize ( + const running_stats_decayed& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + + serialize(item.sum_x, out); + serialize(item.sum_xx, out); + serialize(item.n, out); + serialize(item.forget, out); + } + + template + void deserialize ( + running_stats_decayed& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw dlib::serialization_error("Unexpected version number found while deserializing dlib::running_stats_decayed object."); + + deserialize(item.sum_x, in); + deserialize(item.sum_xx, in); + deserialize(item.n, in); + deserialize(item.forget, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double mean_sign_agreement ( + const std::vector& a, + const std::vector& b + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(a.size() == b.size(), + "\t double mean_sign_agreement(a,b)" + << "\n\t a and b must be the same length." + << "\n\t a.size(): " << a.size() + << "\n\t b.size(): " << b.size() + ); + + + double temp = 0; + for (unsigned long i = 0; i < a.size(); ++i) + { + if ((a[i] >= 0 && b[i] >= 0) || + (a[i] < 0 && b[i] < 0)) + { + temp += 1; + } + } + + return temp/a.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double correlation ( + const std::vector& a, + const std::vector& b + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(a.size() == b.size() && a.size() > 1, + "\t double correlation(a,b)" + << "\n\t a and b must be the same length and have more than one element." + << "\n\t a.size(): " << a.size() + << "\n\t b.size(): " << b.size() + ); + + running_scalar_covariance rs; + for (unsigned long i = 0; i < a.size(); ++i) + { + rs.add(a[i], b[i]); + } + return rs.correlation(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double covariance ( + const std::vector& a, + const std::vector& b + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(a.size() == b.size() && a.size() > 1, + "\t double covariance(a,b)" + << "\n\t a and b must be the same length and have more than one element." + << "\n\t a.size(): " << a.size() + << "\n\t b.size(): " << b.size() + ); + + running_scalar_covariance rs; + for (unsigned long i = 0; i < a.size(); ++i) + { + rs.add(a[i], b[i]); + } + return rs.covariance(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double r_squared ( + const std::vector& a, + const std::vector& b + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(a.size() == b.size() && a.size() > 1, + "\t double r_squared(a,b)" + << "\n\t a and b must be the same length and have more than one element." + << "\n\t a.size(): " << a.size() + << "\n\t b.size(): " << b.size() + ); + + return std::pow(correlation(a,b),2.0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double mean_squared_error ( + const std::vector& a, + const std::vector& b + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(a.size() == b.size(), + "\t double mean_squared_error(a,b)" + << "\n\t a and b must be the same length." + << "\n\t a.size(): " << a.size() + << "\n\t b.size(): " << b.size() + ); + + return mean(squared(matrix_cast(mat(a))-matrix_cast(mat(b)))); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class running_covariance + { + /*! + INITIAL VALUE + - vect_size == 0 + - total_count == 0 + + CONVENTION + - vect_size == in_vector_size() + - total_count == current_n() + + - if (total_count != 0) + - total_sum == the sum of all vectors given to add() + - the covariance of all the elements given to add() is given + by: + - let avg == total_sum/total_count + - covariance == total_cov/total_count - avg*trans(avg) + !*/ + public: + + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + running_covariance( + ) + { + clear(); + } + + void clear( + ) + { + total_count = 0; + + vect_size = 0; + + total_sum.set_size(0); + total_cov.set_size(0,0); + } + + long in_vector_size ( + ) const + { + return vect_size; + } + + long current_n ( + ) const + { + return static_cast(total_count); + } + + void set_dimension ( + long size + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( size > 0, + "\t void running_covariance::set_dimension()" + << "\n\t Invalid inputs were given to this function" + << "\n\t size: " << size + << "\n\t this: " << this + ); + + clear(); + vect_size = size; + total_sum.set_size(size); + total_cov.set_size(size,size); + total_sum = 0; + total_cov = 0; + } + + template + typename disable_if >::type add ( + const T& val + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(((long)max_index_plus_one(val) <= in_vector_size() && in_vector_size() > 0), + "\t void running_covariance::add()" + << "\n\t Invalid inputs were given to this function" + << "\n\t max_index_plus_one(val): " << max_index_plus_one(val) + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t this: " << this + ); + + for (typename T::const_iterator i = val.begin(); i != val.end(); ++i) + { + total_sum(i->first) += i->second; + for (typename T::const_iterator j = val.begin(); j != val.end(); ++j) + { + total_cov(i->first, j->first) += i->second*j->second; + } + } + + ++total_count; + } + + template + typename enable_if >::type add ( + const T& val + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(val) && (in_vector_size() == 0 || val.size() == in_vector_size()), + "\t void running_covariance::add()" + << "\n\t Invalid inputs were given to this function" + << "\n\t is_col_vector(val): " << is_col_vector(val) + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t val.size(): " << val.size() + << "\n\t this: " << this + ); + + vect_size = val.size(); + if (total_count == 0) + { + total_cov = val*trans(val); + total_sum = val; + } + else + { + total_cov += val*trans(val); + total_sum += val; + } + ++total_count; + } + + const column_matrix mean ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( in_vector_size() != 0, + "\t running_covariance::mean()" + << "\n\t This object can not execute this function in its current state." + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t current_n(): " << current_n() + << "\n\t this: " << this + ); + + return total_sum/total_count; + } + + const general_matrix covariance ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( in_vector_size() != 0 && current_n() > 1, + "\t running_covariance::covariance()" + << "\n\t This object can not execute this function in its current state." + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t current_n(): " << current_n() + << "\n\t this: " << this + ); + + return (total_cov - total_sum*trans(total_sum)/total_count)/(total_count-1); + } + + const running_covariance operator+ ( + const running_covariance& item + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT((in_vector_size() == 0 || item.in_vector_size() == 0 || in_vector_size() == item.in_vector_size()), + "\t running_covariance running_covariance::operator+()" + << "\n\t The two running_covariance objects being added must have compatible parameters" + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t item.in_vector_size(): " << item.in_vector_size() + << "\n\t this: " << this + ); + + running_covariance temp(item); + + // make sure we ignore empty matrices + if (total_count != 0 && temp.total_count != 0) + { + temp.total_cov += total_cov; + temp.total_sum += total_sum; + temp.total_count += total_count; + } + else if (total_count != 0) + { + temp.total_cov = total_cov; + temp.total_sum = total_sum; + temp.total_count = total_count; + } + + return temp; + } + + + private: + + general_matrix total_cov; + column_matrix total_sum; + scalar_type total_count; + + long vect_size; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class running_cross_covariance + { + /*! + INITIAL VALUE + - x_vect_size == 0 + - y_vect_size == 0 + - total_count == 0 + + CONVENTION + - x_vect_size == x_vector_size() + - y_vect_size == y_vector_size() + - total_count == current_n() + + - if (total_count != 0) + - sum_x == the sum of all x vectors given to add() + - sum_y == the sum of all y vectors given to add() + - total_cov == sum of all x*trans(y) given to add() + !*/ + + public: + + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + running_cross_covariance( + ) + { + clear(); + } + + void clear( + ) + { + total_count = 0; + + x_vect_size = 0; + y_vect_size = 0; + + sum_x.set_size(0); + sum_y.set_size(0); + total_cov.set_size(0,0); + } + + long x_vector_size ( + ) const + { + return x_vect_size; + } + + long y_vector_size ( + ) const + { + return y_vect_size; + } + + void set_dimensions ( + long x_size, + long y_size + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( x_size > 0 && y_size > 0, + "\t void running_cross_covariance::set_dimensions()" + << "\n\t Invalid inputs were given to this function" + << "\n\t x_size: " << x_size + << "\n\t y_size: " << y_size + << "\n\t this: " << this + ); + + clear(); + x_vect_size = x_size; + y_vect_size = y_size; + sum_x.set_size(x_size); + sum_y.set_size(y_size); + total_cov.set_size(x_size,y_size); + + sum_x = 0; + sum_y = 0; + total_cov = 0; + } + + long current_n ( + ) const + { + return static_cast(total_count); + } + + template + typename enable_if_c::value && !is_matrix::value>::type add ( + const T& x, + const U& y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( ((long)max_index_plus_one(x) <= x_vector_size() && x_vector_size() > 0) && + ((long)max_index_plus_one(y) <= y_vector_size() && y_vector_size() > 0) , + "\t void running_cross_covariance::add()" + << "\n\t Invalid inputs were given to this function" + << "\n\t max_index_plus_one(x): " << max_index_plus_one(x) + << "\n\t max_index_plus_one(y): " << max_index_plus_one(y) + << "\n\t x_vector_size(): " << x_vector_size() + << "\n\t y_vector_size(): " << y_vector_size() + << "\n\t this: " << this + ); + + for (typename T::const_iterator i = x.begin(); i != x.end(); ++i) + { + sum_x(i->first) += i->second; + for (typename U::const_iterator j = y.begin(); j != y.end(); ++j) + { + total_cov(i->first, j->first) += i->second*j->second; + } + } + + // do sum_y += y + for (typename U::const_iterator j = y.begin(); j != y.end(); ++j) + { + sum_y(j->first) += j->second; + } + + ++total_count; + } + + template + typename enable_if_c::value && !is_matrix::value>::type add ( + const T& x, + const U& y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( (is_col_vector(x) && x.size() == x_vector_size() && x_vector_size() > 0) && + ((long)max_index_plus_one(y) <= y_vector_size() && y_vector_size() > 0) , + "\t void running_cross_covariance::add()" + << "\n\t Invalid inputs were given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t x.size(): " << x.size() + << "\n\t max_index_plus_one(y): " << max_index_plus_one(y) + << "\n\t x_vector_size(): " << x_vector_size() + << "\n\t y_vector_size(): " << y_vector_size() + << "\n\t this: " << this + ); + + sum_x += x; + + for (long i = 0; i < x.size(); ++i) + { + for (typename U::const_iterator j = y.begin(); j != y.end(); ++j) + { + total_cov(i, j->first) += x(i)*j->second; + } + } + + // do sum_y += y + for (typename U::const_iterator j = y.begin(); j != y.end(); ++j) + { + sum_y(j->first) += j->second; + } + + ++total_count; + } + + template + typename enable_if_c::value && is_matrix::value>::type add ( + const T& x, + const U& y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( ((long)max_index_plus_one(x) <= x_vector_size() && x_vector_size() > 0) && + (is_col_vector(y) && y.size() == (long)y_vector_size() && y_vector_size() > 0) , + "\t void running_cross_covariance::add()" + << "\n\t Invalid inputs were given to this function" + << "\n\t max_index_plus_one(x): " << max_index_plus_one(x) + << "\n\t is_col_vector(y): " << is_col_vector(y) + << "\n\t y.size(): " << y.size() + << "\n\t x_vector_size(): " << x_vector_size() + << "\n\t y_vector_size(): " << y_vector_size() + << "\n\t this: " << this + ); + + for (typename T::const_iterator i = x.begin(); i != x.end(); ++i) + { + sum_x(i->first) += i->second; + for (long j = 0; j < y.size(); ++j) + { + total_cov(i->first, j) += i->second*y(j); + } + } + + sum_y += y; + + ++total_count; + } + + template + typename enable_if_c::value && is_matrix::value>::type add ( + const T& x, + const U& y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(x) && (x_vector_size() == 0 || x.size() == x_vector_size()) && + is_col_vector(y) && (y_vector_size() == 0 || y.size() == y_vector_size()) && + x.size() != 0 && + y.size() != 0, + "\t void running_cross_covariance::add()" + << "\n\t Invalid inputs were given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t x_vector_size(): " << x_vector_size() + << "\n\t x.size(): " << x.size() + << "\n\t is_col_vector(y): " << is_col_vector(y) + << "\n\t y_vector_size(): " << y_vector_size() + << "\n\t y.size(): " << y.size() + << "\n\t this: " << this + ); + + x_vect_size = x.size(); + y_vect_size = y.size(); + if (total_count == 0) + { + total_cov = x*trans(y); + sum_x = x; + sum_y = y; + } + else + { + total_cov += x*trans(y); + sum_x += x; + sum_y += y; + } + ++total_count; + } + + const column_matrix mean_x ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( current_n() != 0, + "\t running_cross_covariance::mean()" + << "\n\t This object can not execute this function in its current state." + << "\n\t current_n(): " << current_n() + << "\n\t this: " << this + ); + + return sum_x/total_count; + } + + const column_matrix mean_y ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( current_n() != 0, + "\t running_cross_covariance::mean()" + << "\n\t This object can not execute this function in its current state." + << "\n\t current_n(): " << current_n() + << "\n\t this: " << this + ); + + return sum_y/total_count; + } + + const general_matrix covariance_xy ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( current_n() > 1, + "\t running_cross_covariance::covariance()" + << "\n\t This object can not execute this function in its current state." + << "\n\t x_vector_size(): " << x_vector_size() + << "\n\t y_vector_size(): " << y_vector_size() + << "\n\t current_n(): " << current_n() + << "\n\t this: " << this + ); + + return (total_cov - sum_x*trans(sum_y)/total_count)/(total_count-1); + } + + const running_cross_covariance operator+ ( + const running_cross_covariance& item + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT((x_vector_size() == 0 || item.x_vector_size() == 0 || x_vector_size() == item.x_vector_size()) && + (y_vector_size() == 0 || item.y_vector_size() == 0 || y_vector_size() == item.y_vector_size()), + "\t running_cross_covariance running_cross_covariance::operator+()" + << "\n\t The two running_cross_covariance objects being added must have compatible parameters" + << "\n\t x_vector_size(): " << x_vector_size() + << "\n\t item.x_vector_size(): " << item.x_vector_size() + << "\n\t y_vector_size(): " << y_vector_size() + << "\n\t item.y_vector_size(): " << item.y_vector_size() + << "\n\t this: " << this + ); + + running_cross_covariance temp(item); + + // make sure we ignore empty matrices + if (total_count != 0 && temp.total_count != 0) + { + temp.total_cov += total_cov; + temp.sum_x += sum_x; + temp.sum_y += sum_y; + temp.total_count += total_count; + } + else if (total_count != 0) + { + temp.total_cov = total_cov; + temp.sum_x = sum_x; + temp.sum_y = sum_y; + temp.total_count = total_count; + } + + return temp; + } + + + private: + + general_matrix total_cov; + column_matrix sum_x; + column_matrix sum_y; + scalar_type total_count; + + long x_vect_size; + long y_vect_size; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class vector_normalizer + { + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef matrix_type result_type; + + template + void train ( + const vector_type& samples + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() > 0, + "\tvoid vector_normalizer::train()" + << "\n\tyou have to give a nonempty set of samples to this function" + << "\n\tthis: " << this + ); + + m = mean(mat(samples)); + sd = reciprocal(sqrt(variance(mat(samples)))); + + DLIB_ASSERT(is_finite(m), "Some of the input vectors to vector_normalizer::train() have infinite or NaN values"); + } + + long in_vector_size ( + ) const + { + return m.nr(); + } + + long out_vector_size ( + ) const + { + return m.nr(); + } + + const matrix_type& means ( + ) const + { + return m; + } + + const matrix_type& std_devs ( + ) const + { + return sd; + } + + const result_type& operator() ( + const matrix_type& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.nr() == in_vector_size() && x.nc() == 1, + "\tmatrix vector_normalizer::operator()" + << "\n\t you have given invalid arguments to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t x.nc(): " << x.nc() + << "\n\t this: " << this + ); + + temp_out = pointwise_multiply(x-m, sd); + return temp_out; + } + + void swap ( + vector_normalizer& item + ) + { + m.swap(item.m); + sd.swap(item.sd); + temp_out.swap(item.temp_out); + } + + template + friend void deserialize ( + vector_normalizer& item, + std::istream& in + ); + + template + friend void serialize ( + const vector_normalizer& item, + std::ostream& out + ); + + private: + + // ------------------- private data members ------------------- + + matrix_type m, sd; + + // This is just a temporary variable that doesn't contribute to the + // state of this object. + mutable matrix_type temp_out; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + inline void swap ( + vector_normalizer& a, + vector_normalizer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + void deserialize ( + vector_normalizer& item, + std::istream& in + ) + { + deserialize(item.m, in); + deserialize(item.sd, in); + // Keep deserializing the pca matrix for backwards compatibility. + matrix pca; + deserialize(pca, in); + + if (pca.size() != 0) + throw serialization_error("Error deserializing object of type vector_normalizer\n" + "It looks like a serialized vector_normalizer_pca was accidentally deserialized into \n" + "a vector_normalizer object."); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + void serialize ( + const vector_normalizer& item, + std::ostream& out + ) + { + serialize(item.m, out); + serialize(item.sd, out); + // Keep serializing the pca matrix for backwards compatibility. + matrix pca; + serialize(pca, out); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class vector_normalizer_pca + { + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef matrix result_type; + + template + void train ( + const vector_type& samples, + const double eps = 0.99 + ) + { + // You are getting an error here because you are trying to apply PCA + // to a vector of fixed length. But PCA is going to try and do + // dimensionality reduction so you can't use a vector with a fixed dimension. + COMPILE_TIME_ASSERT(matrix_type::NR == 0); + + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() > 0, + "\tvoid vector_normalizer_pca::train_pca()" + << "\n\tyou have to give a nonempty set of samples to this function" + << "\n\tthis: " << this + ); + DLIB_ASSERT(0 < eps && eps <= 1, + "\tvoid vector_normalizer_pca::train_pca()" + << "\n\tyou have to give a nonempty set of samples to this function" + << "\n\tthis: " << this + ); + train_pca_impl(mat(samples),eps); + + DLIB_ASSERT(is_finite(m), "Some of the input vectors to vector_normalizer_pca::train() have infinite or NaN values"); + } + + long in_vector_size ( + ) const + { + return m.nr(); + } + + long out_vector_size ( + ) const + { + return pca.nr(); + } + + const matrix& means ( + ) const + { + return m; + } + + const matrix& std_devs ( + ) const + { + return sd; + } + + const matrix& pca_matrix ( + ) const + { + return pca; + } + + const result_type& operator() ( + const matrix_type& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.nr() == in_vector_size() && x.nc() == 1, + "\tmatrix vector_normalizer_pca::operator()" + << "\n\t you have given invalid arguments to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t x.nc(): " << x.nc() + << "\n\t this: " << this + ); + + // If we have a pca transform matrix on hand then + // also apply that. + temp_out = pca*pointwise_multiply(x-m, sd); + + return temp_out; + } + + void swap ( + vector_normalizer_pca& item + ) + { + m.swap(item.m); + sd.swap(item.sd); + pca.swap(item.pca); + temp_out.swap(item.temp_out); + } + + template + friend void deserialize ( + vector_normalizer_pca& item, + std::istream& in + ); + + template + friend void serialize ( + const vector_normalizer_pca& item, + std::ostream& out + ); + + private: + + template + void train_pca_impl ( + const mat_type& samples, + const double eps + ) + { + m = mean(samples); + sd = reciprocal(sqrt(variance(samples))); + + // fill x with the normalized version of the input samples + matrix x(samples); + for (long r = 0; r < x.size(); ++r) + x(r) = pointwise_multiply(x(r)-m, sd); + + matrix temp, eigen; + matrix eigenvalues; + + // Compute the svd of the covariance matrix of the normalized inputs + svd(covariance(x), temp, eigen, pca); + eigenvalues = diag(eigen); + + rsort_columns(pca, eigenvalues); + + // figure out how many eigenvectors we want in our pca matrix + const double thresh = sum(eigenvalues)*eps; + long num_vectors = 0; + double total = 0; + for (long r = 0; r < eigenvalues.size() && total < thresh; ++r) + { + ++num_vectors; + total += eigenvalues(r); + } + + // So now we know we want to use num_vectors of the first eigenvectors. So + // pull those out and discard the rest. + pca = trans(colm(pca,range(0,num_vectors-1))); + + // Apply the pca transform to the data in x. Then we will normalize the + // pca matrix below. + for (long r = 0; r < x.nr(); ++r) + { + x(r) = pca*x(r); + } + + // Here we just scale the output features from the pca transform so + // that the variance of each feature is 1. So this doesn't really change + // what the pca is doing, it just makes sure the output features are + // normalized. + pca = trans(scale_columns(trans(pca), reciprocal(sqrt(variance(x))))); + } + + + // ------------------- private data members ------------------- + + matrix m, sd; + matrix pca; + + // This is just a temporary variable that doesn't contribute to the + // state of this object. + mutable result_type temp_out; + }; + + template < + typename matrix_type + > + inline void swap ( + vector_normalizer_pca& a, + vector_normalizer_pca& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + void deserialize ( + vector_normalizer_pca& item, + std::istream& in + ) + { + deserialize(item.m, in); + deserialize(item.sd, in); + deserialize(item.pca, in); + if (item.pca.nc() != item.m.nr()) + throw serialization_error("Error deserializing object of type vector_normalizer_pca\n" + "It looks like a serialized vector_normalizer was accidentally deserialized into \n" + "a vector_normalizer_pca object."); + } + + template < + typename matrix_type + > + void serialize ( + const vector_normalizer_pca& item, + std::ostream& out + ) + { + serialize(item.m, out); + serialize(item.sd, out); + serialize(item.pca, out); + } + +// ---------------------------------------------------------------------------------------- + + inline double binomial_random_vars_are_different ( + uint64_t k1, + uint64_t n1, + uint64_t k2, + uint64_t n2 + ) + { + DLIB_ASSERT(k1 <= n1, "k1: "<< k1 << " n1: "<< n1); + DLIB_ASSERT(k2 <= n2, "k2: "<< k2 << " n2: "<< n2); + + const double p1 = k1/(double)n1; + const double p2 = k2/(double)n2; + const double p = (k1+k2)/(double)(n1+n2); + + auto ll = [](double p, uint64_t k, uint64_t n) { + if (p == 0 || p == 1) + return 0.0; + return k*std::log(p) + (n-k)*std::log(1-p); + }; + + auto logll = ll(p1,k1,n1) + ll(p2,k2,n2) - ll(p,k1,n1) - ll(p,k2,n2); + + // The basic statistic only tells you if the random variables are different. But + // it's nice to know which way they are different, i.e., which one is bigger. So + // stuff that information into the sign bit of the return value. + if (p1>=p2) + return logll; + else + return -logll; + } + +// ---------------------------------------------------------------------------------------- + + inline double event_correlation ( + uint64_t A_count, + uint64_t B_count, + uint64_t AB_count, + uint64_t total_num_observations + ) + { + DLIB_ASSERT(AB_count <= A_count && A_count <= total_num_observations, + "AB_count: " << AB_count << ", A_count: "<< A_count << ", total_num_observations: " << total_num_observations); + DLIB_ASSERT(AB_count <= B_count && B_count <= total_num_observations, + "AB_count: " << AB_count << ", B_count: "<< B_count << ", total_num_observations: " << total_num_observations); + + if (total_num_observations == 0) + return 0; + + DLIB_ASSERT(A_count + B_count - AB_count <= total_num_observations, + "AB_count: " << AB_count << " A_count: " << A_count << ", B_count: "<< B_count << ", total_num_observations: " << total_num_observations); + + + const auto AnotB_count = A_count - AB_count; + const auto notB_count = total_num_observations - B_count; + // How likely is it that the odds of A happening is different when conditioned on + // whether or not B happened? + return binomial_random_vars_are_different( + AB_count, B_count, // A conditional on the presence of B + AnotB_count, notB_count // A conditional on the absence of B + ); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STATISTICs_ + diff --git a/ml/dlib/dlib/statistics/statistics_abstract.h b/ml/dlib/dlib/statistics/statistics_abstract.h new file mode 100644 index 000000000..ef8f13802 --- /dev/null +++ b/ml/dlib/dlib/statistics/statistics_abstract.h @@ -0,0 +1,1387 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net), Steve Taylor +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STATISTICs_ABSTRACT_ +#ifdef DLIB_STATISTICs_ABSTRACT_ + +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../svm/sparse_vector_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double mean_sign_agreement ( + const std::vector& a, + const std::vector& b + ); + /*! + requires + - a.size() == b.size() + ensures + - returns the number of times a[i] has the same sign as b[i] divided by + a.size(). So we return the probability that elements of a and b have + the same sign. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double correlation ( + const std::vector& a, + const std::vector& b + ); + /*! + requires + - a.size() == b.size() + - a.size() > 1 + ensures + - returns the correlation coefficient between all the elements of a and b. + (i.e. how correlated is a(i) with b(i)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double covariance ( + const std::vector& a, + const std::vector& b + ); + /*! + requires + - a.size() == b.size() + - a.size() > 1 + ensures + - returns the covariance between all the elements of a and b. + (i.e. how does a(i) vary with b(i)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double r_squared ( + const std::vector& a, + const std::vector& b + ); + /*! + requires + - a.size() == b.size() + - a.size() > 1 + ensures + - returns the R^2 coefficient of determination between all the elements of a and b. + This value is just the square of correlation(a,b). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename alloc + > + double mean_squared_error ( + const std::vector& a, + const std::vector& b + ); + /*! + requires + - a.size() == b.size() + ensures + - returns the mean squared error between all the elements of a and b. + (i.e. mean(squared(mat(a)-mat(b)))) + !*/ + +// ---------------------------------------------------------------------------------------- + + double binomial_random_vars_are_different ( + uint64_t k1, + uint64_t n1, + uint64_t k2, + uint64_t n2 + ); + /*! + requires + - k1 <= n1 + - k2 <= n2 + ensures + - Given two binomially distributed random variables, X1 and X2, we want to know + if these variables have the same parameter (i.e. the chance of "success"). + So assume that: + - You observed X1 to give k1 successes out of n1 trials. + - You observed X2 to give k2 successes out of n2 trials. + - This function performs a simple likelihood ratio test to determine if X1 and + X2 have the same parameter. The return value of this function will be: + - Close to 0 if they are probably the same. + - Larger than 0 if X1 probably has a higher "success" rate than X2. + - Smaller than 0 if X2 probably has a higher "success" rate than X1. + Moreover, the larger the absolute magnitude of the return value the more + likely it is that X1 and X2 have different distributions. + - For a discussion of the technique and applications see: + Dunning, Ted. "Accurate methods for the statistics of surprise and + coincidence." Computational linguistics 19.1 (1993): 61-74. + !*/ + +// ---------------------------------------------------------------------------------------- + + double event_correlation ( + uint64_t A_count, + uint64_t B_count, + uint64_t AB_count, + uint64_t total_num_observations + ); + /*! + requires + - AB_count <= A_count <= total_num_observations + - AB_count <= B_count <= total_num_observations + - A_count + B_count - AB_count <= total_num_observations + ensures + - This function does a statistical test to determine if two events co-occur in + a statistically significant way. In particular, we assume you performed + total_num_observations measurements and during those measurements you: + - Observed event A to happen A_count times. + - Observed event B to happen B_count times. + - Observed AB_count co-occurrences of the events. That is, AB_count is the + number of times the events happened together during the same measurement. + - This function returns a number, COR, which can take any real value. It has + the following interpretations: + - COR == 0: there is no evidence of correlation between the two events. + They appear to be unrelated. + - COR > 0: There is evidence that A and B co-occur together. That is, + they happen at the same times more often than you would expect if they + were independent events. The larger the magnitude of COR the more + evidence we have for the correlation. + - COR < 0: There is evidence that A and B are anti-correlated. That is, + when A happens B is unlikely to happen and vise versa. The larger the + magnitude of COR the more evidence we have for the anti-correlation. + - This function implements the simple likelihood ratio test discussed in the + following paper: + Dunning, Ted. "Accurate methods for the statistics of surprise and + coincidence." Computational linguistics 19.1 (1993): 61-74. + So for an extended discussion of the method see the above paper. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class running_stats + { + /*! + REQUIREMENTS ON T + - T must be a float, double, or long double type + + INITIAL VALUE + - mean() == 0 + - current_n() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents something that can compute the running mean, + variance, skewness, and excess kurtosis of a stream of real numbers. + !*/ + public: + + running_stats( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear( + ); + /*! + ensures + - this object has its initial value + - clears all memory of any previous data points + !*/ + + T current_n ( + ) const; + /*! + ensures + - returns the number of points given to this object so far. + !*/ + + void add ( + const T& val + ); + /*! + ensures + - updates the mean, variance, skewness, and kurtosis stored in this object + so that the new value is factored into them. + - #mean() == mean()*current_n()/(current_n()+1) + val/(current_n()+1). + (i.e. the updated mean value that takes the new value into account) + - #variance() == the updated variance that takes this new value into account. + - #skewness() == the updated skewness that takes this new value into account. + - #ex_kurtosis() == the updated kurtosis that takes this new value into account. + - #current_n() == current_n() + 1 + !*/ + + T mean ( + ) const; + /*! + ensures + - returns the mean of all the values presented to this object + so far. + !*/ + + T variance ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the unbiased sample variance of all the values presented to this + object so far. + !*/ + + T stddev ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the unbiased sampled standard deviation of all the values + presented to this object so far. + !*/ + + T skewness ( + ) const; + /*! + requires + - current_n() > 2 + ensures + - returns the unbiased sample skewness of all the values presented + to this object so far. + !*/ + + T ex_kurtosis( + ) const; + /*! + requires + - current_n() > 3 + ensures + - returns the unbiased sample kurtosis of all the values presented + to this object so far. + !*/ + + T max ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the largest value presented to this object so far. + !*/ + + T min ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the smallest value presented to this object so far. + !*/ + + T scale ( + const T& val + ) const; + /*! + requires + - current_n() > 1 + ensures + - return (val-mean())/stddev(); + !*/ + + running_stats operator+ ( + const running_stats& rhs + ) const; + /*! + ensures + - returns a new running_stats object that represents the combination of all + the values given to *this and rhs. That is, this function returns a + running_stats object, R, that is equivalent to what you would obtain if + all calls to this->add() and rhs.add() had instead been done to R. + !*/ + }; + + template + void serialize ( + const running_stats& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template + void deserialize ( + running_stats& item, + std::istream& in + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class running_scalar_covariance + { + /*! + REQUIREMENTS ON T + - T must be a float, double, or long double type + + INITIAL VALUE + - mean_x() == 0 + - mean_y() == 0 + - current_n() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents something that can compute the running covariance + of a stream of real number pairs. + !*/ + + public: + + running_scalar_covariance( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear( + ); + /*! + ensures + - this object has its initial value + - clears all memory of any previous data points + !*/ + + void add ( + const T& x, + const T& y + ); + /*! + ensures + - updates the statistics stored in this object so that + the new pair (x,y) is factored into them. + - #current_n() == current_n() + 1 + !*/ + + T current_n ( + ) const; + /*! + ensures + - returns the number of points given to this object so far. + !*/ + + T mean_x ( + ) const; + /*! + ensures + - returns the mean value of all x samples presented to this object + via add(). + !*/ + + T mean_y ( + ) const; + /*! + ensures + - returns the mean value of all y samples presented to this object + via add(). + !*/ + + T covariance ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the covariance between all the x and y samples presented + to this object via add() + !*/ + + T correlation ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the correlation coefficient between all the x and y samples + presented to this object via add() + !*/ + + T variance_x ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the unbiased sample variance value of all x samples presented + to this object via add(). + !*/ + + T variance_y ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the unbiased sample variance value of all y samples presented + to this object via add(). + !*/ + + T stddev_x ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the unbiased sample standard deviation of all x samples + presented to this object via add(). + !*/ + + T stddev_y ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the unbiased sample standard deviation of all y samples + presented to this object via add(). + !*/ + + running_scalar_covariance operator+ ( + const running_covariance& rhs + ) const; + /*! + ensures + - returns a new running_scalar_covariance object that represents the + combination of all the values given to *this and rhs. That is, this + function returns a running_scalar_covariance object, R, that is + equivalent to what you would obtain if all calls to this->add() and + rhs.add() had instead been done to R. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class running_scalar_covariance_decayed + { + /*! + REQUIREMENTS ON T + - T must be a float, double, or long double type + + INITIAL VALUE + - mean_x() == 0 + - mean_y() == 0 + - current_n() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents something that can compute the running covariance of + a stream of real number pairs. It is essentially the same as + running_scalar_covariance except that it forgets about data it has seen + after a certain period of time. It does this by exponentially decaying old + statistics. + !*/ + + public: + + running_scalar_covariance_decayed( + T decay_halflife = 1000 + ); + /*! + requires + - decay_halflife > 0 + ensures + - #forget_factor() == std::pow(0.5, 1/decay_halflife); + (i.e. after decay_halflife calls to add() the data given to the first add + will be down weighted by 0.5 in the statistics stored in this object). + !*/ + + T forget_factor ( + ) const; + /*! + ensures + - returns the exponential forget factor used to forget old statistics when + add() is called. + !*/ + + void add ( + const T& x, + const T& y + ); + /*! + ensures + - updates the statistics stored in this object so that + the new pair (x,y) is factored into them. + - #current_n() == current_n()*forget_factor() + forget_factor() + - Down weights old statistics by a factor of forget_factor(). + !*/ + + T current_n ( + ) const; + /*! + ensures + - returns the effective number of points given to this object. As add() + is called this value will converge to a constant, the value of which is + based on the decay_halflife supplied to the constructor. + !*/ + + T mean_x ( + ) const; + /*! + ensures + - returns the mean value of all x samples presented to this object + via add(). + !*/ + + T mean_y ( + ) const; + /*! + ensures + - returns the mean value of all y samples presented to this object + via add(). + !*/ + + T covariance ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the covariance between all the x and y samples presented + to this object via add() + !*/ + + T correlation ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the correlation coefficient between all the x and y samples + presented to this object via add() + !*/ + + T variance_x ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the sample variance value of all x samples presented + to this object via add(). + !*/ + + T variance_y ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the sample variance value of all y samples presented + to this object via add(). + !*/ + + T stddev_x ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the sample standard deviation of all x samples + presented to this object via add(). + !*/ + + T stddev_y ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the sample standard deviation of all y samples + presented to this object via add(). + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class running_stats_decayed + { + /*! + REQUIREMENTS ON T + - T must be a float, double, or long double type + + INITIAL VALUE + - mean() == 0 + - current_n() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents something that can compute the running mean and + variance of a stream of real numbers. It is similar to running_stats + except that it forgets about data it has seen after a certain period of + time. It does this by exponentially decaying old statistics. + !*/ + + public: + + running_stats_decayed( + T decay_halflife = 1000 + ); + /*! + requires + - decay_halflife > 0 + ensures + - #forget_factor() == std::pow(0.5, 1/decay_halflife); + (i.e. after decay_halflife calls to add() the data given to the first add + will be down weighted by 0.5 in the statistics stored in this object). + !*/ + + T forget_factor ( + ) const; + /*! + ensures + - returns the exponential forget factor used to forget old statistics when + add() is called. + !*/ + + void add ( + const T& x + ); + /*! + ensures + - updates the statistics stored in this object so that x is factored into + them. + - #current_n() == current_n()*forget_factor() + forget_factor() + - Down weights old statistics by a factor of forget_factor(). + !*/ + + T current_n ( + ) const; + /*! + ensures + - returns the effective number of points given to this object. As add() + is called this value will converge to a constant, the value of which is + based on the decay_halflife supplied to the constructor. + !*/ + + T mean ( + ) const; + /*! + ensures + - returns the mean value of all x samples presented to this object + via add(). + !*/ + + T variance ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the sample variance value of all x samples presented to this + object via add(). + !*/ + + T stddev ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the sample standard deviation of all x samples presented to this + object via add(). + !*/ + + }; + + template + void serialize ( + const running_stats_decayed& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template + void deserialize ( + running_stats_decayed& item, + std::istream& in + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class running_covariance + { + /*! + REQUIREMENTS ON matrix_type + Must be some type of dlib::matrix. + + INITIAL VALUE + - in_vector_size() == 0 + - current_n() == 0 + + WHAT THIS OBJECT REPRESENTS + This object is a simple tool for computing the mean and + covariance of a sequence of vectors. + !*/ + public: + + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + running_covariance( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear( + ); + /*! + ensures + - this object has its initial value + - clears all memory of any previous data points + !*/ + + long current_n ( + ) const; + /*! + ensures + - returns the number of samples that have been presented to this object + !*/ + + long in_vector_size ( + ) const; + /*! + ensures + - if (this object has been presented with any input vectors or + set_dimension() has been called) then + - returns the dimension of the column vectors used with this object + - else + - returns 0 + !*/ + + void set_dimension ( + long size + ); + /*! + requires + - size > 0 + ensures + - #in_vector_size() == size + - #current_n() == 0 + !*/ + + template + void add ( + const T& val + ); + /*! + requires + - val must represent a column vector. It can either be a dlib::matrix + object or some kind of unsorted sparse vector type. See the top of + dlib/svm/sparse_vector_abstract.h for a definition of unsorted sparse vector. + - val must have a number of dimensions which is compatible with the current + setting of in_vector_size(). In particular, this means that the + following must hold: + - if (val is a dlib::matrix) then + - in_vector_size() == 0 || val.size() == val_vector_size() + - else + - max_index_plus_one(val) <= in_vector_size() + - in_vector_size() > 0 + (i.e. you must call set_dimension() prior to calling add() if + you want to use sparse vectors.) + ensures + - updates the mean and covariance stored in this object so that + the new value is factored into them. + - if (val is a dlib::matrix) then + - #in_vector_size() == val.size() + !*/ + + const column_matrix mean ( + ) const; + /*! + requires + - in_vector_size() != 0 + ensures + - returns the mean of all the vectors presented to this object + so far. + !*/ + + const general_matrix covariance ( + ) const; + /*! + requires + - in_vector_size() != 0 + - current_n() > 1 + ensures + - returns the unbiased sample covariance matrix for all the vectors + presented to this object so far. + !*/ + + const running_covariance operator+ ( + const running_covariance& item + ) const; + /*! + requires + - in_vector_size() == 0 || item.in_vector_size() == 0 || in_vector_size() == item.in_vector_size() + (i.e. the in_vector_size() of *this and item must match or one must be zero) + ensures + - returns a new running_covariance object that represents the combination of all + the vectors given to *this and item. That is, this function returns a + running_covariance object, R, that is equivalent to what you would obtain if all + calls to this->add() and item.add() had instead been done to R. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class running_cross_covariance + { + /*! + REQUIREMENTS ON matrix_type + Must be some type of dlib::matrix. + + INITIAL VALUE + - x_vector_size() == 0 + - y_vector_size() == 0 + - current_n() == 0 + + WHAT THIS OBJECT REPRESENTS + This object is a simple tool for computing the mean and cross-covariance + matrices of a sequence of pairs of vectors. + !*/ + + public: + + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef typename matrix_type::layout_type layout_type; + typedef matrix general_matrix; + typedef matrix column_matrix; + + running_cross_covariance( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear( + ); + /*! + ensures + - This object has its initial value. + - Clears all memory of any previous data points. + !*/ + + long x_vector_size ( + ) const; + /*! + ensures + - if (this object has been presented with any input vectors or + set_dimensions() has been called) then + - returns the dimension of the x vectors given to this object via add(). + - else + - returns 0 + !*/ + + long y_vector_size ( + ) const; + /*! + ensures + - if (this object has been presented with any input vectors or + set_dimensions() has been called) then + - returns the dimension of the y vectors given to this object via add(). + - else + - returns 0 + !*/ + + void set_dimensions ( + long x_size, + long y_size + ); + /*! + requires + - x_size > 0 + - y_size > 0 + ensures + - #x_vector_size() == x_size + - #y_vector_size() == y_size + - #current_n() == 0 + !*/ + + long current_n ( + ) const; + /*! + ensures + - returns the number of samples that have been presented to this object. + !*/ + + template + void add ( + const T& x, + const U& y + ); + /*! + requires + - x and y must represent column vectors. They can either be dlib::matrix + objects or some kind of unsorted sparse vector type. See the top of + dlib/svm/sparse_vector_abstract.h for a definition of unsorted sparse vector. + - x and y must have a number of dimensions which is compatible with the + current setting of x_vector_size() and y_vector_size(). In particular, + this means that the following must hold: + - if (x or y is a sparse vector type) then + - x_vector_size() > 0 && y_vector_size() > 0 + (i.e. you must call set_dimensions() prior to calling add() if + you want to use sparse vectors.) + - if (x is a dlib::matrix) then + - x_vector_size() == 0 || x.size() == x_vector_size() + - else + - max_index_plus_one(x) <= x_vector_size() + - if (y is a dlib::matrix) then + - y_vector_size() == 0 || y.size() == y_vector_size() + - else + - max_index_plus_one(y) <= y_vector_size() + ensures + - updates the mean and cross-covariance matrices stored in this object so + that the new (x,y) vector pair is factored into them. + - if (x is a dlib::matrix) then + - #x_vector_size() == x.size() + - if (y is a dlib::matrix) then + - #y_vector_size() == y.size() + !*/ + + const column_matrix mean_x ( + ) const; + /*! + requires + - current_n() != 0 + ensures + - returns the mean of all the x vectors presented to this object so far. + - The returned vector will have x_vector_size() dimensions. + !*/ + + const column_matrix mean_y ( + ) const; + /*! + requires + - current_n() != 0 + ensures + - returns the mean of all the y vectors presented to this object so far. + - The returned vector will have y_vector_size() dimensions. + !*/ + + const general_matrix covariance_xy ( + ) const; + /*! + requires + - current_n() > 1 + ensures + - returns the unbiased sample cross-covariance matrix for all the vector + pairs presented to this object so far. In particular, returns a matrix + M such that: + - M.nr() == x_vector_size() + - M.nc() == y_vector_size() + - M == the cross-covariance matrix of the data given to add(). + !*/ + + const running_cross_covariance operator+ ( + const running_cross_covariance& item + ) const; + /*! + requires + - x_vector_size() == 0 || item.x_vector_size() == 0 || x_vector_size() == item.x_vector_size() + (i.e. the x_vector_size() of *this and item must match or one must be zero) + - y_vector_size() == 0 || item.y_vector_size() == 0 || y_vector_size() == item.y_vector_size() + (i.e. the y_vector_size() of *this and item must match or one must be zero) + ensures + - returns a new running_cross_covariance object that represents the + combination of all the vectors given to *this and item. That is, this + function returns a running_cross_covariance object, R, that is equivalent + to what you would obtain if all calls to this->add() and item.add() had + instead been done to R. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class vector_normalizer + { + /*! + REQUIREMENTS ON matrix_type + - must be a dlib::matrix object capable of representing column + vectors + + INITIAL VALUE + - in_vector_size() == 0 + - out_vector_size() == 0 + - means().size() == 0 + - std_devs().size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents something that can learn to normalize a set + of column vectors. In particular, normalized column vectors should + have zero mean and a variance of one. + + Also, if desired, this object can use principal component + analysis for the purposes of reducing the number of elements in a + vector. + + THREAD SAFETY + Note that this object contains a cached matrix object it uses + to store intermediate results for normalization. This avoids + needing to reallocate it every time this object performs normalization + but also makes it non-thread safe. So make sure you don't share + instances of this object between threads. + !*/ + + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef matrix_type result_type; + + template + void train ( + const vector_type& samples + ); + /*! + requires + - samples.size() > 0 + - samples == a column matrix or something convertible to a column + matrix via mat(). Also, x should contain + matrix_type objects that represent nonempty column vectors. + - samples does not contain any infinite or NaN values + ensures + - #in_vector_size() == samples(0).nr() + - #out_vector_size() == samples(0).nr() + - This object has learned how to normalize vectors that look like + vectors in the given set of samples. + - #means() == mean(samples) + - #std_devs() == reciprocal(sqrt(variance(samples))); + !*/ + + long in_vector_size ( + ) const; + /*! + ensures + - returns the number of rows that input vectors are + required to contain if they are to be normalized by + this object. + !*/ + + long out_vector_size ( + ) const; + /*! + ensures + - returns the number of rows in the normalized vectors + that come out of this object. + !*/ + + const matrix_type& means ( + ) const; + /*! + ensures + - returns a matrix M such that: + - M.nc() == 1 + - M.nr() == in_vector_size() + - M(i) == the mean of the ith input feature shown to train() + !*/ + + const matrix_type& std_devs ( + ) const; + /*! + ensures + - returns a matrix SD such that: + - SD.nc() == 1 + - SD.nr() == in_vector_size() + - SD(i) == the reciprocal of the standard deviation of the ith + input feature shown to train() + !*/ + + const result_type& operator() ( + const matrix_type& x + ) const; + /*! + requires + - x.nr() == in_vector_size() + - x.nc() == 1 + ensures + - returns a normalized version of x, call it Z, that has the + following properties: + - Z.nr() == out_vector_size() + - Z.nc() == 1 + - the mean of each element of Z is 0 + - the variance of each element of Z is 1 + - Z == pointwise_multiply(x-means(), std_devs()); + !*/ + + void swap ( + vector_normalizer& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + + template < + typename matrix_type + > + inline void swap ( + vector_normalizer& a, + vector_normalizer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename matrix_type, + > + void deserialize ( + vector_normalizer& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + template < + typename matrix_type, + > + void serialize ( + const vector_normalizer& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class vector_normalizer_pca + { + /*! + REQUIREMENTS ON matrix_type + - must be a dlib::matrix object capable of representing column + vectors + + INITIAL VALUE + - in_vector_size() == 0 + - out_vector_size() == 0 + - means().size() == 0 + - std_devs().size() == 0 + - pca_matrix().size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents something that can learn to normalize a set + of column vectors. In particular, normalized column vectors should + have zero mean and a variance of one. + + Also, this object uses principal component analysis for the purposes + of reducing the number of elements in a vector. + + THREAD SAFETY + Note that this object contains a cached matrix object it uses + to store intermediate results for normalization. This avoids + needing to reallocate it every time this object performs normalization + but also makes it non-thread safe. So make sure you don't share + instances of this object between threads. + !*/ + + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef matrix result_type; + + template + void train ( + const vector_type& samples, + const double eps = 0.99 + ); + /*! + requires + - 0 < eps <= 1 + - samples.size() > 0 + - samples == a column matrix or something convertible to a column + matrix via mat(). Also, x should contain + matrix_type objects that represent nonempty column vectors. + - samples does not contain any infinite or NaN values + ensures + - This object has learned how to normalize vectors that look like + vectors in the given set of samples. + - Principal component analysis is performed to find a transform + that might reduce the number of output features. + - #in_vector_size() == samples(0).nr() + - 0 < #out_vector_size() <= samples(0).nr() + - eps is a number that controls how "lossy" the pca transform will be. + Large values of eps result in #out_vector_size() being larger and + smaller values of eps result in #out_vector_size() being smaller. + - #means() == mean(samples) + - #std_devs() == reciprocal(sqrt(variance(samples))); + - #pca_matrix() == the PCA transform matrix that is out_vector_size() + rows by in_vector_size() columns. + !*/ + + long in_vector_size ( + ) const; + /*! + ensures + - returns the number of rows that input vectors are + required to contain if they are to be normalized by + this object. + !*/ + + long out_vector_size ( + ) const; + /*! + ensures + - returns the number of rows in the normalized vectors + that come out of this object. + !*/ + + const matrix& means ( + ) const; + /*! + ensures + - returns a matrix M such that: + - M.nc() == 1 + - M.nr() == in_vector_size() + - M(i) == the mean of the ith input feature shown to train() + !*/ + + const matrix& std_devs ( + ) const; + /*! + ensures + - returns a matrix SD such that: + - SD.nc() == 1 + - SD.nr() == in_vector_size() + - SD(i) == the reciprocal of the standard deviation of the ith + input feature shown to train() + !*/ + + const matrix& pca_matrix ( + ) const; + /*! + ensures + - returns a matrix PCA such that: + - PCA.nr() == out_vector_size() + - PCA.nc() == in_vector_size() + - PCA == the principal component analysis transformation + matrix + !*/ + + const result_type& operator() ( + const matrix_type& x + ) const; + /*! + requires + - x.nr() == in_vector_size() + - x.nc() == 1 + ensures + - returns a normalized version of x, call it Z, that has the + following properties: + - Z.nr() == out_vector_size() + - Z.nc() == 1 + - the mean of each element of Z is 0 + - the variance of each element of Z is 1 + - Z == pca_matrix()*pointwise_multiply(x-means(), std_devs()); + !*/ + + void swap ( + vector_normalizer_pca& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + + template < + typename matrix_type + > + inline void swap ( + vector_normalizer_pca& a, + vector_normalizer_pca& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename matrix_type, + > + void deserialize ( + vector_normalizer_pca& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + template < + typename matrix_type, + > + void serialize ( + const vector_normalizer_pca& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STATISTICs_ABSTRACT_ + diff --git a/ml/dlib/dlib/statistics/vector_normalizer_frobmetric.h b/ml/dlib/dlib/statistics/vector_normalizer_frobmetric.h new file mode 100644 index 000000000..690370f80 --- /dev/null +++ b/ml/dlib/dlib/statistics/vector_normalizer_frobmetric.h @@ -0,0 +1,618 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_VECTOR_NORMALIZER_FRoBMETRIC_Hh_ +#define DLIB_VECTOR_NORMALIZER_FRoBMETRIC_Hh_ + +#include "vector_normalizer_frobmetric_abstract.h" +#include "../matrix.h" +#include "../optimization.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + struct frobmetric_training_sample + { + matrix_type anchor_vect; + std::vector near_vects; + std::vector far_vects; + + unsigned long num_triples ( + ) const { return near_vects.size() * far_vects.size(); } + + void clear() + { + near_vects.clear(); + far_vects.clear(); + } + }; + + template < + typename matrix_type + > + void serialize(const frobmetric_training_sample& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.anchor_vect, out); + serialize(item.near_vects, out); + serialize(item.far_vects, out); + } + + template < + typename matrix_type + > + void deserialize(frobmetric_training_sample& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::frobmetric_training_sample."); + deserialize(item.anchor_vect, in); + deserialize(item.near_vects, in); + deserialize(item.far_vects, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class vector_normalizer_frobmetric + { + + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef matrix_type result_type; + + private: + struct compact_frobmetric_training_sample + { + std::vector near_vects; + std::vector far_vects; + }; + + struct objective + { + objective ( + const std::vector& samples_, + matrix& Aminus_, + const matrix& bias_ + ) : samples(samples_), Aminus(Aminus_), bias(bias_) {} + + double operator()(const matrix& u) const + { + long idx = 0; + const long dims = samples[0].far_vects[0].size(); + // Here we compute \hat A from the paper, which we refer to as just A in + // the code. + matrix A(dims,dims); + A = 0; + std::vector ufar, unear; + for (unsigned long i = 0; i < samples.size(); ++i) + { + ufar.assign(samples[i].far_vects.size(),0); + unear.assign(samples[i].near_vects.size(),0); + for (unsigned long j = 0; j < unear.size(); ++j) + { + for (unsigned long k = 0; k < ufar.size(); ++k) + { + const double val = u(idx++); + ufar[k] -= val; + unear[j] += val; + } + } + for (unsigned long j = 0; j < unear.size(); ++j) + A += unear[j]*samples[i].near_vects[j]*trans(samples[i].near_vects[j]); + for (unsigned long j = 0; j < ufar.size(); ++j) + A += ufar[j]*samples[i].far_vects[j]*trans(samples[i].far_vects[j]); + } + + eigenvalue_decomposition > ed(make_symmetric(A)); + Aminus = ed.get_pseudo_v()*diagm(upperbound(ed.get_real_eigenvalues(),0))*trans(ed.get_pseudo_v()); + // Do this to avoid numeric instability later on since the above + // computation can make Aminus slightly non-symmetric. + Aminus = make_symmetric(Aminus); + + return dot(u,bias) - 0.5*sum(squared(Aminus)); + } + + private: + const std::vector& samples; + matrix& Aminus; + const matrix& bias; + }; + + struct derivative + { + derivative ( + unsigned long num_triples_, + const std::vector& samples_, + matrix& Aminus_, + const matrix& bias_ + ) : num_triples(num_triples_), samples(samples_), Aminus(Aminus_), bias(bias_) {} + + matrix operator()(const matrix& ) const + { + // Note that Aminus is a function of u (the argument to this function), but + // since Aminus will have been computed already by the most recent call to + // the objective function we don't need to do anything with u. We can just + // use Aminus right away. + matrix grad(num_triples); + + long idx = 0; + std::vector ufar, unear; + for (unsigned long i = 0; i < samples.size(); ++i) + { + ufar.resize(samples[i].far_vects.size()); + unear.resize(samples[i].near_vects.size()); + + for (unsigned long j = 0; j < unear.size(); ++j) + unear[j] = sum(pointwise_multiply(Aminus, samples[i].near_vects[j]*trans(samples[i].near_vects[j]))); + for (unsigned long j = 0; j < ufar.size(); ++j) + ufar[j] = sum(pointwise_multiply(Aminus, samples[i].far_vects[j]*trans(samples[i].far_vects[j]))); + + for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j) + { + for (unsigned long k = 0; k < samples[i].far_vects.size(); ++k) + { + grad(idx) = bias(idx) + ufar[k]-unear[j]; + idx++; + } + } + } + + return grad; + } + + private: + const unsigned long num_triples; + const std::vector& samples; + matrix& Aminus; + const matrix& bias; + }; + + + class custom_stop_strategy + { + public: + custom_stop_strategy( + double C_, + double eps_, + bool be_verbose_, + unsigned long max_iter_ + ) + { + _c = C_; + + _cur_iter = 0; + _gradient_thresh = eps_; + _max_iter = max_iter_; + _verbose = be_verbose_; + } + + template + bool should_continue_search ( + const T& u, + const double , + const T& grad + ) + { + ++_cur_iter; + + double max_gradient = 0; + for (long i = 0; i < grad.size(); ++i) + { + const bool at_lower_bound = (0 >= u(i) && grad(i) > 0); + const bool at_upper_bound = (_c/grad.size() <= u(i) && grad(i) < 0); + if (!at_lower_bound && !at_upper_bound) + max_gradient = std::max(std::abs(grad(i)), max_gradient); + } + + if (_verbose) + { + std::cout << "iteration: " << _cur_iter << " max_gradient: "<< max_gradient << std::endl; + } + + // Only stop when the largest non-bound-constrained element of the gradient + // is lower than the threshold. + if (max_gradient < _gradient_thresh) + return false; + + // Check if we have hit the max allowable number of iterations. + if (_cur_iter > _max_iter) + { + return false; + } + + return true; + } + + private: + bool _verbose; + + unsigned long _max_iter; + unsigned long _cur_iter; + double _c; + double _gradient_thresh; + }; + + public: + vector_normalizer_frobmetric ( + ) + { + verbose = false; + eps = 0.1; + C = 1; + max_iter = 5000; + _use_identity_matrix_prior = false; + } + + bool uses_identity_matrix_prior ( + ) const + { + return _use_identity_matrix_prior; + } + + void set_uses_identity_matrix_prior ( + bool use_prior + ) + { + _use_identity_matrix_prior = use_prior; + } + + void be_verbose( + ) + { + verbose = true; + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void vector_normalizer_frobmetric::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps: " << eps_ + ); + eps = eps_; + } + + double get_epsilon ( + ) const + { + return eps; + } + + void set_c ( + double C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void vector_normalizer_frobmetric::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + void set_max_iterations ( + unsigned long max_iterations + ) + { + max_iter = max_iterations; + } + + unsigned long get_max_iterations ( + ) const + { + return max_iter; + } + + double get_c ( + ) const + { + return C; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void train ( + const std::vector >& samples + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() > 0, + "\tvoid vector_normalizer_frobmetric::train()" + << "\n\t you have to give a nonempty set of samples to this function" + ); +#ifdef ENABLE_ASSERTS + { + const long dims = samples[0].anchor_vect.size(); + DLIB_ASSERT(dims != 0, + "\tvoid vector_normalizer_frobmetric::train()" + << "\n\t The dimension of the input vectors can't be zero." + ); + for (unsigned long i = 0; i < samples.size(); ++i) + { + DLIB_ASSERT(is_col_vector(samples[i].anchor_vect), + "\tvoid vector_normalizer_frobmetric::train()" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + ); + DLIB_ASSERT(samples[i].anchor_vect.size() == dims, + "\tvoid vector_normalizer_frobmetric::train()" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + << "\n\t dims: " << dims + << "\n\t samples[i].anchor_vect.size(): " << samples[i].anchor_vect.size() + ); + + DLIB_ASSERT(samples[i].num_triples() != 0, + "\tvoid vector_normalizer_frobmetric::train()" + << "\n\t It is illegal for a training sample to have no data in it" + << "\n\t i: " << i + ); + for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j) + { + DLIB_ASSERT(is_col_vector(samples[i].near_vects[j]), + "\tvoid vector_normalizer_frobmetric::train()" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + << "\n\t j: " << j + ); + DLIB_ASSERT(samples[i].near_vects[j].size() == dims, + "\tvoid vector_normalizer_frobmetric::train()" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + << "\n\t j: " << j + << "\n\t dims: " << dims + << "\n\t samples[i].near_vects[j].size(): " << samples[i].near_vects[j].size() + ); + } + for (unsigned long j = 0; j < samples[i].far_vects.size(); ++j) + { + DLIB_ASSERT(is_col_vector(samples[i].far_vects[j]), + "\tvoid vector_normalizer_frobmetric::train()" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + << "\n\t j: " << j + ); + DLIB_ASSERT(samples[i].far_vects[j].size() == dims, + "\tvoid vector_normalizer_frobmetric::train()" + << "\n\t Invalid inputs were given to this function." + << "\n\t i: " << i + << "\n\t j: " << j + << "\n\t dims: " << dims + << "\n\t samples[i].far_vects[j].size(): " << samples[i].far_vects[j].size() + ); + } + } + } +#endif // end ENABLE_ASSERTS + + + // compute the mean sample + m = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + m += samples[i].anchor_vect; + m /= samples.size(); + + DLIB_ASSERT(is_finite(m), "Some of the input vectors to vector_normalizer_frobmetric::train() have infinite or NaN values"); + + // Now we need to find tform. So we setup the optimization problem and run it + // over the next few lines of code. + unsigned long num_triples = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + num_triples += samples[i].near_vects.size()*samples[i].far_vects.size(); + + matrix u(num_triples); + matrix bias(num_triples); + u = 0; + bias = 1; + + + // precompute all the anchor_vect to far_vects/near_vects pairs + std::vector data(samples.size()); + unsigned long cnt = 0; + std::vector far_norm, near_norm; + for (unsigned long i = 0; i < data.size(); ++i) + { + far_norm.clear(); + near_norm.clear(); + data[i].far_vects.reserve(samples[i].far_vects.size()); + data[i].near_vects.reserve(samples[i].near_vects.size()); + for (unsigned long j = 0; j < samples[i].far_vects.size(); ++j) + { + data[i].far_vects.push_back(samples[i].anchor_vect - samples[i].far_vects[j]); + if (_use_identity_matrix_prior) + far_norm.push_back(length_squared(data[i].far_vects.back())); + } + for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j) + { + data[i].near_vects.push_back(samples[i].anchor_vect - samples[i].near_vects[j]); + if (_use_identity_matrix_prior) + near_norm.push_back(length_squared(data[i].near_vects.back())); + } + + // Note that this loop only executes if _use_identity_matrix_prior == true. + for (unsigned long j = 0; j < near_norm.size(); ++j) + { + for (unsigned long k = 0; k < far_norm.size(); ++k) + { + bias(cnt++) = 1 - (far_norm[k] - near_norm[j]); + } + } + } + + // Now run the main part of the algorithm + matrix Aminus; + find_max_box_constrained(lbfgs_search_strategy(10), + custom_stop_strategy(C, eps, verbose, max_iter), + objective(data, Aminus, bias), + derivative(num_triples, data, Aminus, bias), + u, 0, C/num_triples); + + + // What we need is the optimal Aminus which is a function of u. So we already + // have what we need and just need to put it into tform. + eigenvalue_decomposition > ed(make_symmetric(-Aminus)); + matrix eigs = ed.get_real_eigenvalues(); + // But first, discard the components that are zero to within the machine epsilon. + const double tol = max(eigs)*std::numeric_limits::epsilon(); + for (long i = 0; i < eigs.size(); ++i) + { + if (eigs(i) < tol) + eigs(i) = 0; + } + if (_use_identity_matrix_prior) + tform = matrix_cast(identity_matrix(Aminus) + diagm(sqrt(eigs))*trans(ed.get_pseudo_v())); + else + tform = matrix_cast(diagm(sqrt(eigs))*trans(ed.get_pseudo_v())); + + // Pre-apply the transform to m so we don't have to do it inside operator() + // every time it's called. + m = tform*m; + } + + long in_vector_size ( + ) const + { + return m.nr(); + } + + long out_vector_size ( + ) const + { + return m.nr(); + } + + const matrix& transformed_means ( + ) const + { + return m; + } + + const matrix& transform ( + ) const + { + return tform; + } + + const result_type& operator() ( + const matrix_type& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(in_vector_size() != 0 && in_vector_size() == x.size() && + is_col_vector(x) == true, + "\tmatrix vector_normalizer_frobmetric::operator()" + << "\n\t you have given invalid arguments to this function" + << "\n\t in_vector_size(): " << in_vector_size() + << "\n\t x.size(): " << x.size() + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t this: " << this + ); + + temp_out = tform*x-m; + return temp_out; + } + + template + friend void deserialize ( + vector_normalizer_frobmetric& item, + std::istream& in + ); + + template + friend void serialize ( + const vector_normalizer_frobmetric& item, + std::ostream& out + ); + + private: + + // ------------------- private data members ------------------- + + matrix_type m; + matrix tform; + bool verbose; + double eps; + double C; + unsigned long max_iter; + bool _use_identity_matrix_prior; + + // This is just a temporary variable that doesn't contribute to the + // state of this object. + mutable matrix_type temp_out; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + void serialize ( + const vector_normalizer_frobmetric& item, + std::ostream& out + ) + { + const int version = 2; + serialize(version, out); + + serialize(item.m, out); + serialize(item.tform, out); + serialize(item.verbose, out); + serialize(item.eps, out); + serialize(item.C, out); + serialize(item.max_iter, out); + serialize(item._use_identity_matrix_prior, out); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + void deserialize ( + vector_normalizer_frobmetric& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1 && version != 2) + throw serialization_error("Unsupported version found while deserializing dlib::vector_normalizer_frobmetric."); + + deserialize(item.m, in); + deserialize(item.tform, in); + deserialize(item.verbose, in); + deserialize(item.eps, in); + deserialize(item.C, in); + deserialize(item.max_iter, in); + if (version == 2) + deserialize(item._use_identity_matrix_prior, in); + else + item._use_identity_matrix_prior = false; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_VECTOR_NORMALIZER_FRoBMETRIC_Hh_ + diff --git a/ml/dlib/dlib/statistics/vector_normalizer_frobmetric_abstract.h b/ml/dlib/dlib/statistics/vector_normalizer_frobmetric_abstract.h new file mode 100644 index 000000000..302628330 --- /dev/null +++ b/ml/dlib/dlib/statistics/vector_normalizer_frobmetric_abstract.h @@ -0,0 +1,328 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_VECTOR_NORMALIZER_FRoBMETRIC_ABSTRACT_Hh_ +#ifdef DLIB_VECTOR_NORMALIZER_FRoBMETRIC_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + struct frobmetric_training_sample + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a training data sample for the + vector_normalizer_frobmetric object. It defines a set of training triplets + relative to a single anchor_vect vector. That is, it specifies that the + learned distance metric should satisfy num_triples() constraints which are, + for all valid i and j: + length(T*anchor_vect-T*near_vects[i]) + 1 < length(T*anchor_vect - T*far_vects[j]) + for some appropriate linear transformation T which will be learned by + vector_normalizer_frobmetric. + !*/ + + matrix_type anchor_vect; + std::vector near_vects; + std::vector far_vects; + + unsigned long num_triples ( + ) const { return near_vects.size() * far_vects.size(); } + /*! + ensures + - returns the number of training triplets defined by this object. + !*/ + + void clear() + /*! + ensures + - #near_vects.size() == 0 + - #far_vects.size() == 0 + !*/ + }; + + template < typename matrix_type > + void serialize(const frobmetric_training_sample& item, std::ostream& out) + template < typename matrix_type > + void deserialize(frobmetric_training_sample& item, std::istream& in) + /*! + provides serialisation support. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + class vector_normalizer_frobmetric + { + /*! + REQUIREMENTS ON matrix_type + - must be a dlib::matrix object capable of representing column + vectors + + INITIAL VALUE + - in_vector_size() == 0 + - out_vector_size() == 0 + - get_epsilon() == 0.1 + - get_c() == 1 + - get_max_iterations() == 5000 + - This object is not verbose + - uses_identity_matrix_prior() == false + + WHAT THIS OBJECT REPRESENTS + This object is a tool for performing the FrobMetric distance metric + learning algorithm described in the following paper: + A Scalable Dual Approach to Semidefinite Metric Learning + By Chunhua Shen, Junae Kim, Lei Wang, in CVPR 2011 + + Therefore, this object is a tool that takes as input training triplets + (anchor_vect, near, far) of vectors and attempts to learn a linear + transformation T such that: + length(T*anchor_vect-T*near) + 1 < length(T*anchor_vect - T*far) + That is, you give a bunch of anchor_vect vectors and for each anchor_vect + you specify some vectors which should be near to it and some that should be + far form it. This object then tries to find a transformation matrix that + makes the "near" vectors close to their anchors while the "far" vectors are + farther away. + + THREAD SAFETY + Note that this object contains a cached matrix object it uses + to store intermediate results for normalization. This avoids + needing to reallocate it every time this object performs normalization + but also makes it non-thread safe. So make sure you don't share + instances of this object between threads. + !*/ + + public: + typedef typename matrix_type::mem_manager_type mem_manager_type; + typedef typename matrix_type::type scalar_type; + typedef matrix_type result_type; + + vector_normalizer_frobmetric ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + bool uses_identity_matrix_prior ( + ) const; + /*! + ensures + - Normally this object will try and find a matrix transform() that + minimizes sum(squared(transform())) but also fits the training data. + However, if #uses_identity_matrix_prior() == true then it will instead + try to find the transformation matrix that minimizes + sum(squared(identity_matrix()-transform())). That is, it will try to + find the matrix most similar to the identity matrix that best fits the + training data. + !*/ + + void set_uses_identity_matrix_prior ( + bool use_prior + ); + /*! + ensures + - #uses_identity_matrix_prior() == use_prior + !*/ + + void be_verbose( + ); + /*! + ensures + - This object will print status messages to standard out so the user can + observe the progress of the train() routine. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out. + !*/ + + void set_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer to + execute. + !*/ + + void set_c ( + double C + ); + /*! + requires + - C > 0 + ensures + - #set_c() == C + !*/ + + double get_c ( + ) const; + /*! + ensures + - returns the regularization parameter. It is the parameter that + determines the trade-off between trying to fit the training data exactly + or allowing more errors but hopefully improving the generalization of the + resulting distance metric. Larger values encourage exact fitting while + smaller values of C may encourage better generalization. + !*/ + + void set_max_iterations ( + unsigned long max_iterations + ); + /*! + ensures + - #get_max_iterations() == max_iterations + !*/ + + unsigned long get_max_iterations ( + ) const; + /*! + ensures + - The train() routine uses an iterative numerical solver to find the best + distance metric. This function returns the maximum allowable number of + iterations it will use before terminating. Note that typically the + solver terminates prior to the max iteration count limit due to the error + dropping below get_epsilon(). + !*/ + + void train ( + const std::vector >& samples + ); + /*! + requires + - samples.size() != 0 + - All matrices inside samples (i.e. anchors and elements of near_vects and far_vects) + are column vectors with the same non-zero dimension. + - All the vectors in samples contain finite values. + - All elements of samples contain data, specifically, for all valid i: + - samples[i].num_triples() != 0 + ensures + - learns a distance metric from the given training samples. After train + finishes you can use this object's operator() to transform vectors + according to the learned distance metric. In particular, we will have: + - #transform() == The linear transformation learned by the FrobMetric + learning procedure. + - #in_vector_size() == samples[0].anchor_vect.size() + - You can call (*this)(x) to transform a vector according to the learned + distance metric. That is, it should generally be the case that: + - length((*this)(anchor_vect) - (*this)(near)) + 1 < length((*this)(anchor_vect) - (*this)(far)) + for the anchor_vect, near, and far vectors in the training data. + - #transformed_means() == the mean of the input anchor_vect vectors + after being transformed by #transform() + !*/ + + long in_vector_size ( + ) const; + /*! + ensures + - returns the number of rows that input vectors are required to contain if + they are to be normalized by this object. + !*/ + + long out_vector_size ( + ) const; + /*! + ensures + - returns the number of rows in the normalized vectors that come out of + this object. + - The value returned is always in_vector_size(). So out_vector_size() is + just provided to maintain interface consistency with other vector + normalizer objects. That is, the transformations applied by this object + do not change the dimensionality of the vectors. + !*/ + + const matrix& transformed_means ( + ) const; + /*! + ensures + - returns a column vector V such that: + - V.size() == in_vector_size() + - V is a vector such that subtracting it from transformed vectors + results in them having an expected value of 0. Therefore, it is + equal to transform() times the mean of the input anchor_vect vectors + given to train(). + !*/ + + const matrix& transform ( + ) const; + /*! + ensures + - returns a copy of the transformation matrix we learned during the last + call to train(). + - The returned matrix is square and has in_vector_size() by in_vector_size() + dimensions. + !*/ + + const result_type& operator() ( + const matrix_type& x + ) const; + /*! + requires + - in_vector_size() != 0 + - in_vector_size() == x.size() + - is_col_vector(x) == true + ensures + - returns a normalized version of x, call it Z, that has the following + properties: + - Z == The result of applying the linear transform we learned during + train() to the input vector x. + - Z == transform()*x-transformed_means() + - is_col_vector(Z) == true + - Z.size() == x.size() + - The expected value of each element of Z is 0. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + void serialize ( + const vector_normalizer_frobmetric& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type + > + void deserialize ( + vector_normalizer_frobmetric& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_VECTOR_NORMALIZER_FRoBMETRIC_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/std_allocator.h b/ml/dlib/dlib/std_allocator.h new file mode 100644 index 000000000..b6e411c12 --- /dev/null +++ b/ml/dlib/dlib/std_allocator.h @@ -0,0 +1,199 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STD_ALLOc_H_ +#define DLIB_STD_ALLOc_H_ + +#include +#include +#include "enable_if.h" +#include "algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename M + > + class std_allocator + { + /*! + REQUIREMENTS ON M + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + M::type can be set to anything. + + WHAT THIS OBJECT REPRESENTS + This object is an implementation of an allocator that conforms to the C++ standard + requirements for allocator objects. The M template argument is one of the dlib + memory manager objects and this allocator implementation will do all of its memory allocations + using whatever dlib memory manager you supply. + + Thus, using this allocator object you can use any of the dlib memory manager objects with + the containers in the STL or with any other object that requires a C++ allocator object. + + It is important to note that many STL implementations make the assumption that the memory + allocated by one allocator can be freed by another. This effectively means that you should + only use a global or stateless memory manager with the std_allocator. Either that or you + have to verify that your version of the STL isn't going to try and allocate and deallocate + memory with different allocators. + !*/ + + public: + //type definitions + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + typedef T value_type; + + //rebind std_allocator to type U + template + struct rebind { + typedef std_allocator other; + }; + + //return address of values + pointer address (reference value) const { return &value; } + + const_pointer address (const_reference value) const { return &value; } + + /*constructors and destructor + *-nothing to do because the std_allocator has no state + */ + std_allocator() throw() { } + + std_allocator(const std_allocator&) throw() { } + + template + std_allocator (const std_allocator&) throw() { } + + ~std_allocator() throw() { } + + //return maximum number of elements that can be allocated + size_type max_size () const throw() + { + //for numeric_limits see Section 4.3, page 59 + return std::numeric_limits::max() / sizeof(T); + } + + //allocate but don't initialize num elements of type T + pointer allocate ( + size_type num, + typename std_allocator::const_pointer = 0 + ) + { + return (pointer) pool.allocate_array(num*sizeof(T)); + } + + // This function is not required by the C++ standard but some versions of the STL + // distributed with gcc erroneously require it. See the bug report for further + // details: http://gcc.gnu.org/bugzilla/show_bug.cgi?id=51626 + void construct(pointer p) { return construct(p, value_type()); } + + //initialize elements of allocated storage p with value value + void construct (pointer p, const T& value) + { + //initialize memory with placement new + new((void*)p)T(value); + } + + + //destroy elements of initialized storage p + void destroy (pointer p) + { + // destroy objects by calling their destructor + p->~T(); + } + + //deallocate storage p of deleted elements + void deallocate (pointer p, size_type ) + { + pool.deallocate_array((char*)p); + } + + void swap ( + std_allocator& item + ) + { + pool.swap(item.pool); + } + + std_allocator& operator= (const std_allocator&) { return *this;} + + private: + typename M::template rebind::other pool; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename M + > + class std_allocator + { + public: + //type definitions + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + typedef void* pointer; + typedef const void* const_pointer; + typedef void value_type; + + //rebind std_allocator to type U + template + struct rebind { + typedef std_allocator other; + }; + + }; + +// ---------------------------------------------------------------------------------------- + + template + struct std_alloc_compare + { const static bool are_interchangeable = false; }; + + template + struct std_alloc_compare >::type> + { const static bool are_interchangeable = true; }; + + template + struct std_alloc_compare::type> + { const static bool are_interchangeable = true; }; + + //return that all specializations of this std_allocator are interchangeable if they use memory_manager_global + // instances with the same mm_global_type + template + bool operator== ( + const std_allocator&, + const std_allocator& + ) throw() + { return std_alloc_compare::are_interchangeable; } + + template + bool operator!= ( + const std_allocator&, + const std_allocator& + ) throw() + { return !std_alloc_compare::are_interchangeable; } + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + std_allocator& a, + std_allocator& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STD_ALLOc_H_ + diff --git a/ml/dlib/dlib/stl_checked.h b/ml/dlib/dlib/stl_checked.h new file mode 100644 index 000000000..ec15aa84e --- /dev/null +++ b/ml/dlib/dlib/stl_checked.h @@ -0,0 +1,10 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STL_CHECKEd_HEADER +#define DLIB_STL_CHECKEd_HEADER + +#include "stl_checked/std_vector_c.h" + +#endif // DLIB_STL_CHECKEd_HEADER + + diff --git a/ml/dlib/dlib/stl_checked/std_vector_c.h b/ml/dlib/dlib/stl_checked/std_vector_c.h new file mode 100644 index 000000000..d46c9850c --- /dev/null +++ b/ml/dlib/dlib/stl_checked/std_vector_c.h @@ -0,0 +1,333 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STD_VECTOr_C_H_ +#define DLIB_STD_VECTOr_C_H_ + +#include +#include +#include "../assert.h" +#include "std_vector_c_abstract.h" +#include "../serialize.h" +#include "../is_kind.h" + +namespace dlib +{ + + template < + typename T, + typename Allocator = std::allocator + > + class std_vector_c : public std::vector + { + typedef typename std::vector base_type; + public: + // types: + typedef typename Allocator::reference reference; + typedef typename Allocator::const_reference const_reference; + typedef typename base_type::iterator iterator; // See 23.1 + typedef typename base_type::const_iterator const_iterator; // See 23.1 + typedef typename base_type::size_type size_type; // See 23.1 + typedef typename base_type::difference_type difference_type;// See 23.1 + typedef T value_type; + typedef Allocator allocator_type; + typedef typename Allocator::pointer pointer; + typedef typename Allocator::const_pointer const_pointer; + typedef std::reverse_iterator reverse_iterator; + typedef std::reverse_iterator const_reverse_iterator; + + + // 23.2.4.1 construct/copy/destroy: + explicit std_vector_c(const Allocator& alloc= Allocator()) : base_type(alloc) {} + + explicit std_vector_c(size_type n, const T& value = T(), + const Allocator& alloc= Allocator()) : base_type(n, value, alloc) {} + + template + std_vector_c(InputIterator first, InputIterator last, + const Allocator& alloc= Allocator()) : base_type(first,last,alloc) {} + + std_vector_c(const std::vector& x) : base_type(x) {} + + std_vector_c& operator=(const std::vector& x) + { + static_cast(*this) = x; + return *this; + } + + template + void assign(InputIterator first, InputIterator last) { base_type::assign(first,last); } + void assign(size_type n, const T& u) { base_type::assign(n,u); } + allocator_type get_allocator() const { return base_type::get_allocator(); } + // iterators: + iterator begin() { return base_type::begin(); } + const_iterator begin() const { return base_type::begin(); } + iterator end() { return base_type::end(); } + const_iterator end() const { return base_type::end(); } + reverse_iterator rbegin() { return base_type::rbegin(); } + const_reverse_iterator rbegin() const { return base_type::rbegin(); } + reverse_iterator rend() { return base_type::rend(); } + const_reverse_iterator rend() const { return base_type::rend(); } + // 23.2.4.2 capacity: + size_type size() const { return base_type::size(); } + size_type max_size() const { return base_type::max_size(); } + void resize(size_type sz, T c = T()) { base_type::resize(sz,c); } + size_type capacity() const { return base_type::capacity(); } + bool empty() const { return base_type::empty(); } + void reserve(size_type n) { base_type::reserve(n); } + + // element access: + const_reference at(size_type n) const { return base_type::at(n); } + reference at(size_type n) { return base_type::at(n); } + + + // 23.2.4.3 modifiers: + void push_back(const T& x) { base_type::push_back(x); } + void swap(std_vector_c& x) { base_type::swap(x); } + void clear() { base_type::clear(); } + + + // ------------------------------------------------------ + // Things that have preconditions that should be checked. + // ------------------------------------------------------ + + reference operator[]( + size_type n + ) + { + DLIB_CASSERT(n < size(), + "\treference std_vector_c::operator[](n)" + << "\n\tYou have supplied an invalid index" + << "\n\tthis: " << this + << "\n\tn: " << n + << "\n\tsize(): " << size() + ); + return static_cast(*this)[n]; + } + + // ------------------------------------------------------ + + const_reference operator[]( + size_type n + ) const + { + DLIB_CASSERT(n < size(), + "\tconst_reference std_vector_c::operator[](n)" + << "\n\tYou have supplied an invalid index" + << "\n\tthis: " << this + << "\n\tn: " << n + << "\n\tsize(): " << size() + ); + return static_cast(*this)[n]; + } + + // ------------------------------------------------------ + + reference front( + ) + { + DLIB_CASSERT(size() > 0, + "\treference std_vector_c::front()" + << "\n\tYou can't call front() on an empty vector" + << "\n\tthis: " << this + ); + return base_type::front(); + } + + // ------------------------------------------------------ + + const_reference front( + ) const + { + DLIB_CASSERT(size() > 0, + "\tconst_reference std_vector_c::front()" + << "\n\tYou can't call front() on an empty vector" + << "\n\tthis: " << this + ); + return base_type::front(); + } + + // ------------------------------------------------------ + + reference back( + ) + { + DLIB_CASSERT(size() > 0, + "\treference std_vector_c::back()" + << "\n\tYou can't call back() on an empty vector" + << "\n\tthis: " << this + ); + return base_type::back(); + } + + // ------------------------------------------------------ + + const_reference back( + ) const + { + DLIB_CASSERT(size() > 0, + "\tconst_reference std_vector_c::back()" + << "\n\tYou can't call back() on an empty vector" + << "\n\tthis: " << this + ); + return base_type::back(); + } + + // ------------------------------------------------------ + + void pop_back( + ) + { + DLIB_CASSERT(size() > 0, + "\tconst_reference std_vector_c::pop_back()" + << "\n\tYou can't call pop_back() on an empty vector" + << "\n\tthis: " << this + ); + base_type::pop_back(); + } + + // ------------------------------------------------------ + + iterator insert( + iterator position, + const T& x + ) + { + DLIB_CASSERT( begin() <= position && position <= end(), + "\titerator std_vector_c::insert(position,x)" + << "\n\tYou have called insert() with an invalid position" + << "\n\tthis: " << this + ); + return base_type::insert(position, x); + } + + // ------------------------------------------------------ + + void insert( + iterator position, + size_type n, + const T& x + ) + { + DLIB_CASSERT( begin() <= position && position <= end(), + "\tvoid std_vector_c::insert(position,n,x)" + << "\n\tYou have called insert() with an invalid position" + << "\n\tthis: " << this + ); + base_type::insert(position, n, x); + } + + // ------------------------------------------------------ + + template + void insert( + iterator position, + InputIterator first, + InputIterator last + ) + { + DLIB_CASSERT( begin() <= position && position <= end(), + "\tvoid std_vector_c::insert(position,first,last)" + << "\n\tYou have called insert() with an invalid position" + << "\n\tthis: " << this + ); + base_type::insert(position, first, last); + } + + // ------------------------------------------------------ + + iterator erase( + iterator position + ) + { + DLIB_CASSERT( begin() <= position && position < end(), + "\titerator std_vector_c::erase(position)" + << "\n\tYou have called erase() with an invalid position" + << "\n\tthis: " << this + ); + return base_type::erase(position); + } + + // ------------------------------------------------------ + + iterator erase( + iterator first, + iterator last + ) + { + DLIB_CASSERT( begin() <= first && first <= last && last <= end(), + "\titerator std_vector_c::erase(first,last)" + << "\n\tYou have called erase() with an invalid range of iterators" + << "\n\tthis: " << this + ); + return base_type::erase(first,last); + } + + // ------------------------------------------------------ + + + }; + +// ---------------------------------------------------------------------------------------- + +// Add these swaps just to make absolutely sure the specialized swap always gets called even +// if the compiler is crappy and would otherwise mess it up. + template + void swap(std_vector_c& x, std_vector_c& y) { x.swap(y); } + + template + void swap(std::vector& x, std_vector_c& y) { x.swap(y); } + + template + void swap(std_vector_c& x, std::vector& y) { y.swap(x); } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std_vector_c& item, + std::ostream& out + ) + { + try + { + const unsigned long size = static_cast(item.size()); + + serialize(size,out); + for (unsigned long i = 0; i < item.size(); ++i) + serialize(item[i],out); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type std_vector_c"); } + } + +// ---------------------------------------------------------------------------------------- + + template + void deserialize ( + std_vector_c& item, + std::istream& in + ) + { + try + { + unsigned long size; + deserialize(size,in); + item.resize(size); + for (unsigned long i = 0; i < size; ++i) + deserialize(item[i],in); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type std_vector_c"); } + } + +// ---------------------------------------------------------------------------------------- + + template + struct is_std_vector > { const static bool value = true; }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STD_VECTOr_C_H_ + diff --git a/ml/dlib/dlib/stl_checked/std_vector_c_abstract.h b/ml/dlib/dlib/stl_checked/std_vector_c_abstract.h new file mode 100644 index 000000000..2a8045c72 --- /dev/null +++ b/ml/dlib/dlib/stl_checked/std_vector_c_abstract.h @@ -0,0 +1,470 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STD_VECTOr_C_ABSTRACT_H_ +#ifdef DLIB_STD_VECTOr_C_ABSTRACT_H_ + +#include +#include +#include "../assert.h" + +namespace dlib +{ + + template < + typename T, + typename Allocator = std::allocator + > + class std_vector_c : public std::vector + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple wrapper around the std::vector object. It + provides an identical interface but also checks the preconditions of + each member function. That is, if you violate a requires + clause the dlib::fatal_error exception is thrown. + !*/ + + typedef typename std::vector base_type; + public: + typedef typename Allocator::reference reference; + typedef typename Allocator::const_reference const_reference; + typedef typename base_type::iterator iterator; + typedef typename base_type::const_iterator const_iterator; + typedef typename base_type::size_type size_type; + typedef typename base_type::difference_type difference_type; + typedef T value_type; + typedef Allocator allocator_type; + typedef typename Allocator::pointer pointer; + typedef typename Allocator::const_pointer const_pointer; + typedef std::reverse_iterator reverse_iterator; + typedef std::reverse_iterator const_reverse_iterator; + + + explicit std_vector_c( + const Allocator& alloc = Allocator() + ); + /*! + ensures + - #get_allocator() == alloc + - #size() == 0 + !*/ + + explicit std_vector_c ( + size_type n, + const T& value = T(), + const Allocator& alloc = Allocator() + ); + /*! + ensures + - #size() == n + - #get_allocator() == alloc + - for all valid i: + - (*this)[i] == value + !*/ + + template + std_vector_c ( + InputIterator first, + InputIterator last, + const Allocator& alloc = Allocator() + ); + /*! + ensures + - #size() == std::distance(first,last) + - #get_allocator() == alloc + - std::equal(first, last, begin()) == true + !*/ + + std_vector_c( + const std::vector& x + ); + /*! + ensures + - #*this == x + !*/ + + std_vector_c& operator= ( + const std::vector& x + ); + /*! + ensures + - #*this == x + - returns #*this + !*/ + + template + void assign( + InputIterator first, + InputIterator last + ); + /*! + ensures + - #size() == std::distance(first,last) + - std::equal(first, last, begin()) == true + !*/ + + void assign( + size_type n, + const T& value + ); + /*! + ensures + - #size() == n + - for all valid i: + - (*this)[i] == value + !*/ + + allocator_type get_allocator( + ) const; + /*! + ensures + - returns the allocator used by this vector + !*/ + + iterator begin( + ); + /*! + ensures + - if (size() > 0) then + - returns an iterator referring to the first element in + this container. + - else + - returns end() + !*/ + + const_iterator begin( + ) const; + /*! + ensures + - if (size() > 0) then + - returns a const_iterator referring to the first element in + this container. + - else + - returns end() + !*/ + + iterator end( + ); + /*! + ensures + - returns an iterator that represents one past the end of + this container + !*/ + + const_iterator end( + ) const; + /*! + ensures + - returns an iterator that represents one past the end of + this container + !*/ + + reverse_iterator rbegin( + ); + /*! + ensures + - returns std::reverse_iterator(end()) + !*/ + + const_reverse_iterator rbegin( + ) const; + /*! + ensures + - returns std::reverse_iterator(end()) + !*/ + + reverse_iterator rend( + ); + /*! + ensures + - returns std::reverse_iterator(begin()) + !*/ + + const_reverse_iterator rend( + ) const; + /*! + ensures + - returns std::reverse_iterator(begin()) + !*/ + + size_type size( + ) const; + /*! + ensures + - returns end()-begin() + (i.e. returns the number of elements in this container) + !*/ + + size_type max_size( + ) const; + /*! + ensures + - returns the maximum number of elements this vector can contain + !*/ + + void resize( + size_type sz, + T c = T() + ); + /*! + ensures + - #size() == sz + - any element with index between 0 and sz - 1 which was in the + vector before the call to resize() retains its value and index. + All other elements have a value given by c. + !*/ + + size_type capacity( + ) const; + /*! + ensures + - returns the total number of elements that the vector can hold without + requiring reallocation. + !*/ + + bool empty( + ) const; + /*! + ensures + - if (size() == 0) then + - returns true + - else + - returns false + !*/ + + void reserve( + size_type n + ); + /*! + ensures + - #capacity() >= n + !*/ + + const_reference at( + size_type n + ) const; + /*! + ensures + - if (n < size()) then + - returns a const reference to (*this)[n] + - else + - throws std::out_of_range + !*/ + + reference at( + size_type n + ); + /*! + ensures + - if (n < size()) then + - returns a reference to (*this)[n] + - else + - throws std::out_of_range + !*/ + + void push_back( + const T& x + ); + /*! + ensures + - #size() == size() + 1 + - #back() == x + !*/ + + void swap( + std_vector_c& x + ); + /*! + ensures + - swaps the state of *this and x + !*/ + + void clear( + ); + /*! + ensures + - #size() == 0 + !*/ + + reference operator[]( + size_type n + ); + /*! + requires + - n < size() + ensures + - returns a reference to the nth element of this container + !*/ + + const_reference operator[]( + size_type n + ) const; + /*! + requires + - n < size() + ensures + - returns a const reference to the nth element of this container + !*/ + + reference front( + ); + /*! + requires + - size() > 0 + ensures + - returns a reference to (*this)[0] + !*/ + + const_reference front( + ) const; + /*! + requires + - size() > 0 + ensures + - returns a const reference to (*this)[0] + !*/ + + reference back( + ); + /*! + requires + - size() > 0 + ensures + - returns a reference to (*this)[size()-1] + !*/ + + const_reference back( + ) const; + /*! + requires + - size() > 0 + ensures + - returns a const reference to (*this)[size()-1] + !*/ + + void pop_back( + ); + /*! + requires + - size() > 0 + ensures + - #size() == size() - 1 + - removes the last element in the vector but leaves the others + unmodified. + !*/ + + iterator insert( + iterator position, + const T& x + ); + /*! + requires + - begin() <= position && position <= end() + (i.e. position references an element in this vector object) + ensures + - #size() == size() + 1 + - inserts a copy of x into *this before the given position + - returns an iterator that points to the copy of x inserted + into *this + !*/ + + void insert( + iterator position, + size_type n, + const T& x + ); + /*! + requires + - begin() <= position && position <= end() + (i.e. position references an element in this vector object) + ensures + - #size() == size() + n + - inserts n copies of x into *this before the given position + !*/ + + template + void insert( + iterator position, + InputIterator first, + InputIterator last + ); + /*! + requires + - begin() <= position && position <= end() + (i.e. position references an element in this vector object) + - first and last are not iterators into *this + ensures + - #size() == size() + std::distance(last,first) + - inserts copies of the range of elements [first,last) into *this + before the given position + !*/ + + iterator erase( + iterator position + ); + /*! + requires + - begin() <= position && position < end() + (i.e. position references an element in this vector object) + ensures + - #size() == size() - 1 + - removes the element in this vector referenced by position but + leaves all other elements in this vector unmodified. + - if (position < end()-1) then + - returns an iterator referencing the element immediately + following *position prior to the erase. + - else + - returns end() + !*/ + + iterator erase( + iterator first, + iterator last + ); + /*! + requires + - begin() <= first && first <= last && last <= end() + (i.e. the range [first,last) must be inside this container ) + ensures + - #size() == size() - (last-first) + - removes the elements in this vector referenced by the + iterator range [first,last) but leaves all other elements + in this vector unmodified. + - if (last < end()-1) then + - returns an iterator referencing the element immediately + following *last prior to the erase. + - else + - returns end() + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const std_vector_c& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void deserialize ( + std_vector_c& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STD_VECTOr_C_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/string.h b/ml/dlib/dlib/string.h new file mode 100644 index 000000000..671fee404 --- /dev/null +++ b/ml/dlib/dlib/string.h @@ -0,0 +1,9 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRINg_TOP_ +#define DLIB_STRINg_TOP_ + +#include "string/string.h" + +#endif // DLIB_STRINg_TOP_ + diff --git a/ml/dlib/dlib/string/cassert b/ml/dlib/dlib/string/cassert new file mode 100644 index 000000000..6139ba823 --- /dev/null +++ b/ml/dlib/dlib/string/cassert @@ -0,0 +1 @@ +#include "../dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/string/iomanip b/ml/dlib/dlib/string/iomanip new file mode 100644 index 000000000..6139ba823 --- /dev/null +++ b/ml/dlib/dlib/string/iomanip @@ -0,0 +1 @@ +#include "../dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/string/iosfwd b/ml/dlib/dlib/string/iosfwd new file mode 100644 index 000000000..6139ba823 --- /dev/null +++ b/ml/dlib/dlib/string/iosfwd @@ -0,0 +1 @@ +#include "../dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/string/iostream b/ml/dlib/dlib/string/iostream new file mode 100644 index 000000000..6139ba823 --- /dev/null +++ b/ml/dlib/dlib/string/iostream @@ -0,0 +1 @@ +#include "../dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/string/locale b/ml/dlib/dlib/string/locale new file mode 100644 index 000000000..6139ba823 --- /dev/null +++ b/ml/dlib/dlib/string/locale @@ -0,0 +1 @@ +#include "../dlib_include_path_tutorial.txt" diff --git a/ml/dlib/dlib/string/string.h b/ml/dlib/dlib/string/string.h new file mode 100644 index 000000000..2c2602198 --- /dev/null +++ b/ml/dlib/dlib/string/string.h @@ -0,0 +1,1004 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRINg_ +#define DLIB_STRINg_ + +#include "string_abstract.h" +#include +#include "../algs.h" +#include +#include +#include +#include "../error.h" +#include "../assert.h" +#include "../uintn.h" +#include +#include +#include +#include "../enable_if.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + inline const typename disable_if,std::string>::type narrow ( + const std::basic_string& str + ) + { + std::string temp; + temp.reserve(str.size()); + std::string::size_type i; + for (i = 0; i < str.size(); ++i) + { + if (zero_extend_cast(str[i]) > 255) + temp += ' '; + else + temp += zero_extend_cast(str[i]); + } + return temp; + } + + template < + typename charT, + typename traits, + typename alloc + > + inline const typename enable_if,std::string>::type narrow ( + const std::basic_string& str + ) + { + return str; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename traits, + typename alloc + > + const std::basic_string tolower ( + const std::basic_string& str + ) + { + std::basic_string temp; + + temp.resize(str.size()); + + for (typename std::basic_string::size_type i = 0; i < str.size(); ++i) + temp[i] = (char)std::tolower(str[i]); + + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename traits, + typename alloc + > + const std::basic_string toupper ( + const std::basic_string& str + ) + { + std::basic_string temp; + + temp.resize(str.size()); + + for (typename std::basic_string::size_type i = 0; i < str.size(); ++i) + temp[i] = (char)std::toupper(str[i]); + + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename traits, + typename alloc + > + bool strings_equal_ignore_case ( + const std::basic_string& str1, + const std::basic_string& str2 + ) + { + if (str1.size() != str2.size()) + return false; + + for (typename std::basic_string::size_type i = 0; i < str1.size(); ++i) + { + if (std::tolower(str1[i]) != std::tolower(str2[i])) + return false; + } + + return true; + } + + template < + typename traits, + typename alloc + > + bool strings_equal_ignore_case ( + const std::basic_string& str1, + const char* str2 + ) + { + typename std::basic_string::size_type i; + for (i = 0; i < str1.size(); ++i) + { + // if we hit the end of str2 then the strings aren't the same length + if (str2[i] == '\0') + return false; + + if (std::tolower(str1[i]) != std::tolower(str2[i])) + return false; + } + + // This happens when str2 is longer than str1 + if (str2[i] != '\0') + return false; + + return true; + } + + template < + typename traits, + typename alloc + > + bool strings_equal_ignore_case ( + const char* str1, + const std::basic_string& str2 + ) + { + return strings_equal_ignore_case(str2, str1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename traits, + typename alloc + > + bool strings_equal_ignore_case ( + const std::basic_string& str1, + const std::basic_string& str2, + unsigned long num + ) + { + if (str1.size() != str2.size() && (str1.size() < num || str2.size() < num)) + return false; + + for (typename std::basic_string::size_type i = 0; i < str1.size() && i < num; ++i) + { + if (std::tolower(str1[i]) != std::tolower(str2[i])) + return false; + } + + return true; + } + + template < + typename traits, + typename alloc + > + bool strings_equal_ignore_case ( + const std::basic_string& str1, + const char* str2, + unsigned long num + ) + { + typename std::basic_string::size_type i; + for (i = 0; i < str1.size() && i < num; ++i) + { + // if we hit the end of str2 then the strings aren't the same length + if (str2[i] == '\0') + return false; + + if (std::tolower(str1[i]) != std::tolower(str2[i])) + return false; + } + + return true; + } + + template < + typename traits, + typename alloc + > + bool strings_equal_ignore_case ( + const char* str1, + const std::basic_string& str2, + unsigned long num + ) + { + return strings_equal_ignore_case(str2, str1, num); + } + +// ---------------------------------------------------------------------------------------- + + class cast_to_string_error : public error + { + public: + cast_to_string_error():error(ECAST_TO_STRING) {} + }; + + template < + typename T + > + const std::string cast_to_string ( + const T& item + ) + { + std::ostringstream sout; + sout << item; + if (!sout) + throw cast_to_string_error(); + return sout.str(); + } + + // don't declare this if we are using mingw because it apparently doesn't + // support iostreams with wchar_t? +#if !(defined(__MINGW32__) && (__GNUC__ < 4)) + template < + typename T + > + const std::wstring cast_to_wstring ( + const T& item + ) + { + std::basic_ostringstream sout; + sout << item; + if (!sout) + throw cast_to_string_error(); + return sout.str(); + } +#endif + +// ---------------------------------------------------------------------------------------- + + inline std::string pad_int_with_zeros ( + int i, + unsigned long width = 6 + ) + { + std::ostringstream sout; + sout << std::setw(width) << std::setfill('0') << i; + return sout.str(); + } + +// ---------------------------------------------------------------------------------------- + + class string_cast_error : public error + { + public: + string_cast_error(const std::string& str): + error(ESTRING_CAST,"string cast error: invalid string = '" + str + "'") {} + }; + + template < + typename T + > + struct string_cast_helper + { + template < typename charT, typename traits, typename alloc > + static const T cast ( + const std::basic_string& str + ) + { + using namespace std; + basic_istringstream sin(str); + T temp; + sin >> temp; + if (!sin) throw string_cast_error(narrow(str)); + if (sin.get() != std::char_traits::eof()) throw string_cast_error(narrow(str)); + return temp; + } + }; + + template + struct string_cast_helper > + { + template < typename charT, typename traits, typename alloc > + static const std::basic_string cast ( + const std::basic_string& str + ) + { + std::basic_string temp; + temp.resize(str.size()); + for (unsigned long i = 0; i < str.size(); ++i) + temp[i] = zero_extend_cast(str[i]); + return temp; + } + }; + + template <> + struct string_cast_helper + { + template < typename charT, typename traits, typename alloc > + static bool cast ( + const std::basic_string& str + ) + { + using namespace std; + if (str.size() == 1 && str[0] == '1') + return true; + if (str.size() == 1 && str[0] == '0') + return false; + if (tolower(narrow(str)) == "true") + return true; + if (tolower(narrow(str)) == "false") + return false; + + throw string_cast_error(narrow(str)); + } + }; + +#define DLIB_STRING_CAST_INTEGRAL(type) \ + template <> \ + struct string_cast_helper \ + { \ + template < typename charT, typename traits, typename alloc> \ + static type cast ( \ + const std::basic_string& str \ + ) \ + { \ + using namespace std; \ + basic_istringstream sin(str); \ + type temp; \ + if (str.size() > 2 && str[0] == _dT(charT,'0') && str[1] == _dT(charT,'x')) \ + sin >> hex >> temp; \ + else \ + sin >> temp; \ + if (!sin) throw string_cast_error(narrow(str)); \ + if (sin.get() != std::char_traits::eof()) throw string_cast_error(narrow(str)); \ + return temp; \ + } \ + }; + + DLIB_STRING_CAST_INTEGRAL(unsigned short) + DLIB_STRING_CAST_INTEGRAL(short) + DLIB_STRING_CAST_INTEGRAL(unsigned int) + DLIB_STRING_CAST_INTEGRAL(int) + DLIB_STRING_CAST_INTEGRAL(unsigned long) + DLIB_STRING_CAST_INTEGRAL(long) + DLIB_STRING_CAST_INTEGRAL(uint64) + + template < + typename T, + typename charT, + typename traits, + typename alloc + > + inline const T string_cast ( + const std::basic_string& str + ) + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + return string_cast_helper::cast(str); + } + + template + inline const T string_cast (const char* str){ return string_cast(std::string(str)); } + template + inline const T string_cast (const wchar_t* str){ return string_cast(std::wstring(str)); } + +// ---------------------------------------------------------------------------------------- + + class string_assign + { + template < + typename charT, + typename traits, + typename alloc + > + class string_assign_helper + { + public: + string_assign_helper ( + const std::basic_string& str_ + ) : str(str_) {} + + template + operator T () const + { + return string_cast(str); + } + + private: + + const std::basic_string& str; + }; + + // ------------- + + class char_assign_helper + { + public: + char_assign_helper ( + const char* str_ + ) : str(str_) {} + + template + operator T () const + { + return string_cast(str); + } + + private: + + const char* str; + }; + + // ------------- + + class wchar_t_assign_helper + { + public: + wchar_t_assign_helper ( + const wchar_t* str_ + ) : str(str_) {} + + template + operator T () const + { + return string_cast(str); + } + + private: + + const wchar_t* str; + }; + + // ------------- + + public: + + template < + typename charT, + typename traits, + typename alloc + > + string_assign_helper operator=( + const std::basic_string& str + ) const + { + return string_assign_helper(str); + } + + char_assign_helper operator= ( + const char* str + ) const + { + return char_assign_helper(str); + } + + wchar_t_assign_helper operator= ( + const wchar_t* str + ) const + { + return wchar_t_assign_helper(str); + } + }; + + const string_assign sa = string_assign(); + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string wrap_string ( + const std::basic_string& str, + const unsigned long first_pad = 0, + const unsigned long rest_pad = 0, + const unsigned long max_per_line = 79 + ) + { + DLIB_ASSERT ( first_pad < max_per_line && rest_pad < max_per_line && + rest_pad >= first_pad, + "\tconst std::basic_string wrap_string()" + << "\n\tfirst_pad: " << first_pad + << "\n\trest_pad: " << rest_pad + << "\n\tmax_per_line: " << max_per_line ); + + using namespace std; + + basic_ostringstream sout; + basic_istringstream sin(str); + + for (unsigned long i = 0; i < rest_pad; ++i) + sout << _dT(charT," "); + const basic_string pad(sout.str()); + sout.str(_dT(charT,"")); + + for (unsigned long i = 0; i < first_pad; ++i) + sout << _dT(charT," "); + + + typename basic_string::size_type remaining = max_per_line - rest_pad; + + basic_string temp; + + sin >> temp; + while (sin) + { + if (temp.size() > remaining) + { + if (temp.size() + rest_pad >= max_per_line) + { + string::size_type i = 0; + for (; i < temp.size(); ++i) + { + sout << temp[i]; + --remaining; + if (remaining == 0) + { + sout << _dT(charT,"\n") << pad; + remaining = max_per_line - rest_pad; + } + } + } + else + { + sout << _dT(charT,"\n") << pad << temp; + remaining = max_per_line - rest_pad - temp.size(); + } + } + else if (temp.size() == remaining) + { + sout << temp; + remaining = 0; + } + else + { + sout << temp; + remaining -= temp.size(); + } + + sin >> temp; + if (remaining == 0 && sin) + { + sout << _dT(charT,"\n") << pad; + remaining = max_per_line - rest_pad; + } + else + { + sout << _dT(charT," "); + --remaining; + } + } + + return sout.str(); + } + + template < + typename charT + > + const std::basic_string wrap_string ( + const charT* str, + const unsigned long first_pad = 0, + const unsigned long rest_pad = 0, + const unsigned long max_per_line = 79 + ) { return wrap_string(std::basic_string(str),first_pad,rest_pad,max_per_line); } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string ltrim ( + const std::basic_string& str, + const std::basic_string& trim_chars + ) + { + typedef std::basic_string string; + typename string::size_type pos = str.find_first_not_of(trim_chars); + if (pos != string::npos) + return str.substr(pos); + else + return std::basic_string(); + } + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string ltrim ( + const std::basic_string& str, + const charT* trim_chars = _dT(charT," \t\r\n") + ) { return ltrim(str,std::basic_string(trim_chars)); } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string rtrim ( + const std::basic_string& str, + const std::basic_string& trim_chars + ) + { + typedef std::basic_string string; + + typename string::size_type pos = str.find_last_not_of(trim_chars); + if (pos != string::npos) + return str.substr(0,pos+1); + else + return std::basic_string(); + } + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string rtrim ( + const std::basic_string& str, + const charT* trim_chars = _dT(charT," \t\r\n") + ) { return rtrim(str,std::basic_string(trim_chars)); } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string trim ( + const std::basic_string& str, + const std::basic_string& trim_chars + ) + { + typedef std::basic_string string; + typename string::size_type lpos = str.find_first_not_of(trim_chars); + if (lpos != string::npos) + { + typename string::size_type rpos = str.find_last_not_of(trim_chars); + return str.substr(lpos,rpos-lpos+1); + } + else + { + return std::basic_string(); + } + } + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string trim ( + const std::basic_string& str, + const charT* trim_chars = _dT(charT," \t\r\n") + ) { return trim(str,std::basic_string(trim_chars)); } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string rpad ( + const std::basic_string& str, + long pad_length, + const std::basic_string& pad_string + ) + { + typedef std::basic_string string; + // if str is too big then just return str + if (pad_length <= static_cast(str.size())) + return str; + + // make the string we will padd onto the string + string P; + while (P.size() < pad_length - str.size()) + P += pad_string; + P = P.substr(0,pad_length - str.size()); + + // return the padded string + return str + P; + } + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string rpad ( + const std::basic_string& str, + long pad_length, + const charT* pad_string = _dT(charT," ") + ) { return rpad(str,pad_length,std::basic_string(pad_string)); } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string lpad ( + const std::basic_string& str, + long pad_length, + const std::basic_string& pad_string + ) + { + typedef std::basic_string string; + // if str is too big then just return str + if (pad_length <= static_cast(str.size())) + return str; + + // make the string we will padd onto the string + string P; + while (P.size() < pad_length - str.size()) + P += pad_string; + P = P.substr(0,pad_length - str.size()); + + // return the padded string + return P + str; + } + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string lpad ( + const std::basic_string& str, + long pad_length, + const charT* pad_string = _dT(charT," ") + ) { return lpad(str,pad_length,std::basic_string(pad_string)); } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string pad ( + const std::basic_string& str, + long pad_length, + const std::basic_string& pad_string + ) + { + const long str_size = static_cast(str.size()); + return rpad(lpad(str,(pad_length-str_size)/2 + str_size,pad_string), + pad_length, + pad_string); + } + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string pad ( + const std::basic_string& str, + long pad_length, + const charT* pad_string = _dT(charT," ") + ) { return pad(str,pad_length,std::basic_string(pad_string)); } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string left_substr ( + const std::basic_string& str, + const std::basic_string& delim + ) + { + return str.substr(0,str.find_first_of(delim)); + } + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string left_substr ( + const std::basic_string& str, + const charT* delim = _dT(charT," \n\r\t") + ) + { + return str.substr(0,str.find_first_of(delim)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string right_substr ( + const std::basic_string& str, + const std::basic_string& delim + ) + { + typename std::basic_string::size_type delim_pos = str.find_last_of(delim); + if (delim_pos != std::basic_string::npos) + return str.substr(delim_pos+1); + else + return _dT(charT,""); + } + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string right_substr ( + const std::basic_string& str, + const charT* delim = _dT(charT," \n\r\t") + ) + { + typename std::basic_string::size_type delim_pos = str.find_last_of(delim); + if (delim_pos != std::basic_string::npos) + return str.substr(delim_pos+1); + else + return _dT(charT,""); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + std::pair, std::basic_string > + split_on_first ( + const std::basic_string& str, + const charT* delim = _dT(charT," \n\r\t") + ) + { + typename std::basic_string::size_type delim_pos = str.find_first_of(delim); + if (delim_pos != std::basic_string::npos) + return std::make_pair(str.substr(0, delim_pos), str.substr(delim_pos+1)); + else + return std::make_pair(str, _dT(charT,"")); + } + + template < + typename charT, + typename traits, + typename alloc + > + inline std::pair, std::basic_string > + split_on_first ( + const std::basic_string& str, + const std::basic_string& delim + ) + { + return split_on_first(str, delim.c_str()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + std::pair, std::basic_string > + split_on_last ( + const std::basic_string& str, + const charT* delim = _dT(charT," \n\r\t") + ) + { + typename std::basic_string::size_type delim_pos = str.find_last_of(delim); + if (delim_pos != std::basic_string::npos) + return std::make_pair(str.substr(0, delim_pos), str.substr(delim_pos+1)); + else + return std::make_pair(str, _dT(charT,"")); + } + + template < + typename charT, + typename traits, + typename alloc + > + inline std::pair, std::basic_string > + split_on_last ( + const std::basic_string& str, + const std::basic_string& delim + ) + { + return split_on_last(str, delim.c_str()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::vector > split ( + const std::basic_string& str, + const charT* delim = _dT(charT," \n\r\t") + ) + { + std::basic_string temp; + + std::vector > res; + + for (unsigned long i = 0; i < str.size(); ++i) + { + // check if delim contains the character str[i] + bool hit = false; + const charT* d = delim; + while (*d != '\0') + { + if (str[i] == *d) + { + hit = true; + break; + } + ++d; + } + + if (hit) + { + if (temp.size() != 0) + { + res.push_back(temp); + temp.clear(); + } + } + else + { + temp.push_back(str[i]); + } + } + + if (temp.size() != 0) + res.push_back(temp); + + return res; + } + + template < + typename charT, + typename traits, + typename alloc + > + const std::vector > split ( + const std::basic_string& str, + const std::basic_string& delim + ) + { + return split(str,delim.c_str()); + } + + inline const std::vector split ( + const char* str, + const char* delim = " \n\r\t" + ) + { + return split(std::string(str),delim); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRINg_ + diff --git a/ml/dlib/dlib/string/string_abstract.h b/ml/dlib/dlib/string/string_abstract.h new file mode 100644 index 000000000..0a1ef0c79 --- /dev/null +++ b/ml/dlib/dlib/string/string_abstract.h @@ -0,0 +1,652 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRINg_ABSTRACT_ +#ifdef DLIB_STRINg_ABSTRACT_ + +#include +#include +#include +#include "../error.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class string_cast_error : public error + { + public: + string_cast_error():error(ECAST_TO_STRING) {} + }; + + template < + typename T, + typename charT, + typename traits, + typename alloc + > + const T string_cast ( + const std::basic_string& str + ); + /*! + requires + - T is not a pointer type + ensures + - returns str converted to T + throws + - string_cast_error + This exception is thrown if string_cast() is unable to convert + str into a T. Also, string_cast_error::info == str + !*/ + +// ---------------------------------------------------------------------------------------- + + class string_assign + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple tool which provides an alternative syntax for using + the string_cast() function. It can be understood by considering + the following example: + + string_assign sa; + int val; + double dval; + + val = sa = "1234"; // executes: val = string_cast("1234"); + dval = sa = "3.141"; // executes: val = string_cast("3.141"); + + After executing, val will be equal to 1234 and dval will be 3.141. + Note that you can use string_assign to assign to any type which you could + use with string_cast(), except for std::basic_string, assigning to this + type is ambiguous for boring technical reasons. But there isn't much + point in using this tool to assign from one string to another so it doesn't + matter. + + Additionally, note that there is a global instance of this object, dlib::sa. + So you never have to create a string_assign object yourself. Finally, this + object is totally stateless and threadsafe. + !*/ + }; + + const string_assign sa = string_assign(); + +// ---------------------------------------------------------------------------------------- + + class cast_to_string_error : public error + { + public: + cast_to_string_error():error(ECAST_TO_STRING) {} + }; + + template < + typename T + > + const std::string cast_to_string ( + const T& item + ); + /*! + requires + - T is not a pointer type + ensures + - returns item converted to std::string + throws + - cast_to_string_error + This exception is thrown if cast_to_string() is unable to convert + item into a std::string. + !*/ + + template < + typename T + > + const std::wstring cast_to_wstring ( + const T& item + ); + /*! + requires + - T is not a pointer type + ensures + - returns item converted to std::wstring + throws + - cast_to_string_error + This exception is thrown if cast_to_string() is unable to convert + item into a std::string. + !*/ + +// ---------------------------------------------------------------------------------------- + + std::string pad_int_with_zeros ( + int i, + unsigned long width = 6 + ); + /*! + ensures + - converts i into a string of at least width characters in length. If + necessary, the string will be padded with leading zeros to get + to width characters. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::string narrow ( + const std::basic_string& str + ); + /*! + ensures + - returns str as a std::string by converting every character in it to a char. + Note that any characters that do not have a mapping to type char will be + converted to a space. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string wrap_string ( + const std::basic_string& str, + const unsigned long first_pad = 0, + const unsigned long rest_pad = 0, + const unsigned long max_per_line = 79 + ); + /*! + requires + - first_pad < max_per_line + - rest_pad < max_per_line + - rest_pad >= first_pad + ensures + - returns a copy of str S such that: + - S is broken up into lines separated by the \n character. + - The first line starts with first_pad space characters. + - The second and all subsequent lines start with rest_pad space characters. + - The first line is no longer than max_per_line - (rest_pad-first_pad) characters. + - The second and all subsequent lines are no longer than max_per_line characters. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename traits + typename alloc + > + const std::basic_string tolower ( + const std::basic_string& str + ); + /*! + ensures + - returns a copy of str S such that: + - #S.size() == str.size() + - #S[i] == std::tolower(str[i]) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename traits, + typename alloc + > + const std::basic_string toupper ( + const std::basic_string& str + ); + /*! + ensures + - returns a copy of str S such that: + - #S.size() == str.size() + - #S[i] == std::toupper(str[i]) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename traits, + typename alloc + > + bool strings_equal_ignore_case ( + const std::basic_string& str1, + const std::basic_string& str2 + ); + /*! + ensures + - returns tolower(str1) == tolower(str2) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename traits, + typename alloc + > + bool strings_equal_ignore_case ( + const std::basic_string& str1, + const std::basic_string& str2, + unsigned long num + ); + /*! + ensures + - returns tolower(str1.substr(0,num)) == tolower(str2.substr(0,num)) + (i.e. only compares the first num characters) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string ltrim ( + const std::basic_string& str, + const std::basic_string& trim_chars + ); + /*! + ensures + - returns a copy of str with any leading trim_chars + from the left side of the string removed. + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string ltrim ( + const std::basic_string& str, + const charT* trim_chars = _dT(charT," \t\r\n") + ); + /*! + requires + - trim_chars == a valid null-terminated C string + ensures + - returns ltrim(str, std::basic_string(trim_chars)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string rtrim ( + const std::basic_string& str, + const std::basic_string& trim_chars + ); + /*! + ensures + - returns a copy of str with any trailing trim_chars + from the right side of the string removed. + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string rtrim ( + const std::basic_string& str, + const charT* trim_chars = _dT(charT," \t\r\n") + ); + /*! + requires + - trim_chars == a valid null-terminated C string + ensures + - returns rtrim(str, std::basic_string(trim_chars)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string trim ( + const std::basic_string& str, + const std::basic_string& trim_chars + ); + /*! + ensures + - returns a copy of str with any leading or trailing trim_chars + from the ends of the string removed. + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string trim ( + const std::basic_string& str, + const charT* trim_chars = _dT(charT," \t\r\n") + ); + /*! + requires + - trim_chars == a valid null-terminated C string + ensures + - returns trim(str, std::basic_string(trim_chars)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string rpad ( + const std::basic_string& str, + long pad_length, + const std::basic_string& pad_string + ); + /*! + ensures + - if (pad_length <= str.size()) then + - returns str + - else + - let P be a string defined as follows: + - P.size() == pad_length - str.size() + - P == (pad_string + pad_string + ... + pad_string).substr(0,pad_length - str.size()) + (i.e. P == a string with the above specified size that contains just + repitions of the pad_string) + - returns the string str + P + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string rpad ( + const std::basic_string& str, + long pad_length, + const charT* pad_string = _dT(charT," ") + ); + /*! + requires + - pad_string == a valid null-terminated C string + ensures + - returns rpad(str, pad_length, std::basic_string(pad_string)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string lpad ( + const std::basic_string& str, + long pad_length, + const std::basic_string& pad_string + ); + /*! + ensures + - if (pad_length <= str.size()) then + - returns str + - else + - let P be a string defined as follows: + - P.size() == pad_length - str.size() + - P == (pad_string + pad_string + ... + pad_string).substr(0,pad_length - str.size()) + (i.e. P == a string with the above specified size that contains just + repitions of the pad_string) + - returns the string P + str + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string lpad ( + const std::basic_string& str, + long pad_length, + const charT* pad_string = _dT(charT," ") + ); + /*! + requires + - pad_string == a valid null-terminated C string + ensures + - returns lpad(str, pad_length, std::basic_string(pad_string)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string pad ( + const std::basic_string& str, + long pad_length, + const std::basic_string& pad_string + ); + /*! + ensures + - let str_size == static_cast(str.size()) + - returns rpad( lpad(str, (pad_length-str_size)/2 + str_size, pad_string), + pad_length, + pad_string); + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string pad ( + const std::basic_string& str, + long pad_length, + const charT* pad_string = _dT(charT," ") + ); + /*! + requires + - pad_string == a valid null-terminated C string + ensures + - returns pad(str, pad_length, std::basic_string(pad_string)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string left_substr ( + const std::basic_string& str, + const std::basic_string& delim + ); + /*! + ensures + - let delim_pos = str.find_first_of(delim) + - returns str.substr(0,delim_pos) + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string left_substr ( + const std::basic_string& str, + const charT* delim = _dT(charT," \n\r\t") + ); + /*! + requires + - delim == a valid null-terminated C string + ensures + - returns left_substr(str, std::basic_string(delim)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::basic_string right_substr ( + const std::basic_string& str, + const std::basic_string& delim + ); + /*! + ensures + - let delim_pos = str.find_last_of(delim) + - if (delim_pos == std::string::npos) then + - returns "" + - else + - returns str.substr(delim_pos+1) + !*/ + + template < + typename charT, + typename traits + typename alloc + > + const std::basic_string right_substr ( + const std::basic_string& str, + const charT* delim = _dT(charT," \n\r\t") + ); + /*! + requires + - delim == a valid null-terminated C string + ensures + - returns right_substr(str, std::basic_string(delim)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + std::pair, std::basic_string > + split_on_first ( + const std::basic_string& str, + const charT* delim = _dT(charT, " \n\r\t") + ); + /*! + ensures + - This function splits string into two parts, the split is based on the first + occurrence of any character from delim. + - let delim_pos = str.find_first_of(delim) + - if (delim_pos == std::string::npos) then + - returns make_pair(str,"") + - else + - return make_pair(str.substr(0, delim_pos), str.substr(delim_pos+1)); + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + std::pair, std::basic_string > + split_on_first ( + const std::basic_string& str, + const std::basic_string& delim + ); + /*! + requires + - delim == a valid null-terminated C string + ensures + - returns split_on_first(str, delim.c_str()) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + std::pair, std::basic_string > + split_on_last ( + const std::basic_string& str, + const charT* delim = _dT(charT, " \n\r\t") + ); + /*! + ensures + - This function splits string into two parts, the split is based on the last + occurrence of any character from delim. + - let delim_pos = str.find_last_of(delim) + - if (delim_pos == std::string::npos) then + - returns make_pair(str,"") + - else + - return make_pair(str.substr(0, delim_pos), str.substr(delim_pos+1)); + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + std::pair, std::basic_string > + split_on_last ( + const std::basic_string& str, + const std::basic_string& delim + ); + /*! + requires + - delim == a valid null-terminated C string + ensures + - returns split_on_last(str, delim.c_str()) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename traits, + typename alloc + > + const std::vector > split ( + const std::basic_string& str, + const std::basic_string& delim + ); + /*! + ensures + - Breaks the given string str into a sequence of substrings delimited + by characters in delim and returns the results. + - returns a vector V such that: + - V.size() == the number of substrings found in str. + - for all i: V[i] == The ith substring. Note that it will not contain + any delimiter characters (i.e. characters in delim). It will also + never be an empty string. + - V contains the substrings in the order in which they appear in str. + That is, V[0] contains the first substring, V[1] the second, and + so on. + !*/ + + template < + typename charT, + typename traits, + typename alloc + > + const std::vector > split ( + const std::basic_string& str, + const charT* delim = _dT(charT," \n\r\t") + ); + /*! + requires + - trim_chars == a valid null-terminated C string + ensures + - returns split(str, std::basic_string(delim)) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRINg_ABSTRACT_ + diff --git a/ml/dlib/dlib/svm.h b/ml/dlib/dlib/svm.h new file mode 100644 index 000000000..4dc7382c8 --- /dev/null +++ b/ml/dlib/dlib/svm.h @@ -0,0 +1,60 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_SVm_HEADER +#define DLIB_SVm_HEADER + +#include "svm/svm_rank_trainer.h" +#include "svm/svm.h" +#include "svm/krls.h" +#include "svm/rls.h" +#include "svm/kcentroid.h" +#include "svm/kcentroid_overloads.h" +#include "svm/kkmeans.h" +#include "svm/feature_ranking.h" +#include "svm/rbf_network.h" +#include "svm/linearly_independent_subset_finder.h" +#include "svm/reduced.h" +#include "svm/rvm.h" +#include "svm/pegasos.h" +#include "svm/sparse_kernel.h" +#include "svm/null_trainer.h" +#include "svm/roc_trainer.h" +#include "svm/kernel_matrix.h" +#include "svm/empirical_kernel_map.h" +#include "svm/svm_c_linear_trainer.h" +#include "svm/svm_c_linear_dcd_trainer.h" +#include "svm/svm_c_ekm_trainer.h" +#include "svm/simplify_linear_decision_function.h" +#include "svm/krr_trainer.h" +#include "svm/sort_basis_vectors.h" +#include "svm/svm_c_trainer.h" +#include "svm/svm_one_class_trainer.h" +#include "svm/svr_trainer.h" + +#include "svm/one_vs_one_decision_function.h" +#include "svm/multiclass_tools.h" +#include "svm/cross_validate_multiclass_trainer.h" +#include "svm/cross_validate_regression_trainer.h" +#include "svm/cross_validate_object_detection_trainer.h" +#include "svm/cross_validate_sequence_labeler.h" +#include "svm/cross_validate_sequence_segmenter.h" +#include "svm/cross_validate_assignment_trainer.h" + +#include "svm/one_vs_all_decision_function.h" + +#include "svm/structural_svm_problem.h" +#include "svm/sequence_labeler.h" +#include "svm/assignment_function.h" +#include "svm/track_association_function.h" +#include "svm/active_learning.h" +#include "svm/svr_linear_trainer.h" +#include "svm/sequence_segmenter.h" + +#endif // DLIB_SVm_HEADER + + diff --git a/ml/dlib/dlib/svm/active_learning.h b/ml/dlib/dlib/svm/active_learning.h new file mode 100644 index 000000000..581540e67 --- /dev/null +++ b/ml/dlib/dlib/svm/active_learning.h @@ -0,0 +1,162 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ACTIVE_LEARnING_Hh_ +#define DLIB_ACTIVE_LEARnING_Hh_ + +#include "active_learning_abstract.h" + +#include "svm_c_linear_dcd_trainer.h" +#include + +namespace dlib +{ + + enum active_learning_mode + { + max_min_margin, + ratio_margin + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename in_sample_vector_type, + typename in_scalar_vector_type, + typename in_sample_vector_type2 + > + std::vector impl_rank_unlabeled_training_samples ( + const svm_c_linear_dcd_trainer& trainer, + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels, + const in_sample_vector_type2& unlabeled_samples, + const active_learning_mode mode + ) + { + DLIB_ASSERT(is_vector(unlabeled_samples) && + (samples.size() == 0 || is_learning_problem(samples, labels)) , + "\t std::vector rank_unlabeled_training_samples()" + << "\n\t Invalid inputs were given to this function" + << "\n\t is_vector(unlabeled_samples): " << is_vector(unlabeled_samples) + << "\n\t is_learning_problem(samples, labels): " << is_learning_problem(samples, labels) + << "\n\t samples.size(): " << samples.size() + << "\n\t labels.size(): " << labels.size() + ); + + // If there aren't any training samples then all unlabeled_samples are equally good. + // So just report an arbitrary ordering. + if (samples.size() == 0 || unlabeled_samples.size() == 0) + { + std::vector ret(unlabeled_samples.size()); + for (unsigned long i = 0; i < ret.size(); ++i) + ret[i] = i; + + return ret; + } + + // We are going to score each unlabeled sample and put the score and index into + // results. Then at the end of this function we just sort it and return the indices. + std::vector > results; + results.resize(unlabeled_samples.size()); + + // make sure we use this trainer's ability to warm start itself since that will make + // this whole function run a lot faster. But first, we need to find out what the state + // we will be warm starting from is. + typedef typename svm_c_linear_dcd_trainer::optimizer_state optimizer_state; + optimizer_state state; + trainer.train(samples, labels, state); // call train() just to get state + + decision_function df; + + std::vector temp_samples; + std::vector temp_labels; + temp_samples.reserve(samples.size()+1); + temp_labels.reserve(labels.size()+1); + temp_samples.assign(samples.begin(), samples.end()); + temp_labels.assign(labels.begin(), labels.end()); + temp_samples.resize(temp_samples.size()+1); + temp_labels.resize(temp_labels.size()+1); + + + for (long i = 0; i < unlabeled_samples.size(); ++i) + { + temp_samples.back() = unlabeled_samples(i); + // figure out the margin for each possible labeling of this sample. + + optimizer_state temp(state); + temp_labels.back() = +1; + df = trainer.train(temp_samples, temp_labels, temp); + const double margin_p = temp_labels.back()*df(temp_samples.back()); + + temp = state; + temp_labels.back() = -1; + df = trainer.train(temp_samples, temp_labels, temp); + const double margin_n = temp_labels.back()*df(temp_samples.back()); + + if (mode == max_min_margin) + { + // The score for this sample is its min possible margin over possible labels. + // Therefore, this score measures how much flexibility we have to label this + // sample however we want. The intuition being that the most useful points to + // label are the ones that are still free to obtain either label. + results[i] = std::make_pair(std::min(margin_p, margin_n), i); + } + else + { + // In this case, the score for the sample is a ratio that tells how close the + // two margin values are to each other. The closer they are the better. So in + // this case we are saying we are looking for samples that have the same + // preference for either class label. + if (std::abs(margin_p) >= std::abs(margin_n)) + { + if (margin_p != 0) + results[i] = std::make_pair(margin_n/margin_p, i); + else // if both are == 0 then say 0/0 == 1 + results[i] = std::make_pair(1, i); + } + else + { + results[i] = std::make_pair(margin_p/margin_n, i); + } + } + } + + // sort the results so the highest scoring samples come first. + std::sort(results.rbegin(), results.rend()); + + // transfer results into a vector with just sample indices so we can return it. + std::vector ret(results.size()); + for (unsigned long i = 0; i < ret.size(); ++i) + ret[i] = results[i].second; + return ret; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename in_sample_vector_type, + typename in_scalar_vector_type, + typename in_sample_vector_type2 + > + std::vector rank_unlabeled_training_samples ( + const svm_c_linear_dcd_trainer& trainer, + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels, + const in_sample_vector_type2& unlabeled_samples, + const active_learning_mode mode = max_min_margin + ) + { + return impl_rank_unlabeled_training_samples(trainer, + mat(samples), + mat(labels), + mat(unlabeled_samples), + mode); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ACTIVE_LEARnING_Hh_ + diff --git a/ml/dlib/dlib/svm/active_learning_abstract.h b/ml/dlib/dlib/svm/active_learning_abstract.h new file mode 100644 index 000000000..76a5120e3 --- /dev/null +++ b/ml/dlib/dlib/svm/active_learning_abstract.h @@ -0,0 +1,75 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ACTIVE_LEARnING_ABSTRACT_Hh_ +#ifdef DLIB_ACTIVE_LEARnING_ABSTRACT_Hh_ + +#include "svm_c_linear_dcd_trainer_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + enum active_learning_mode + { + max_min_margin, + ratio_margin + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename in_sample_vector_type, + typename in_scalar_vector_type, + typename in_sample_vector_type2 + > + std::vector rank_unlabeled_training_samples ( + const svm_c_linear_dcd_trainer& trainer, + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels, + const in_sample_vector_type2& unlabeled_samples, + const active_learning_mode mode = max_min_margin + ); + /*! + requires + - if (samples.size() != 0) then + - it must be legal to call trainer.train(samples, labels) + - is_learning_problem(samples, labels) == true + - unlabeled_samples must contain the same kind of vectors as samples. + - unlabeled_samples, samples, and labels must be matrices or types of + objects convertible to a matrix via mat(). + - is_vector(unlabeled_samples) == true + ensures + - Suppose that we wish to learn a binary classifier by calling + trainer.train(samples, labels) but we are also interested in selecting one of + the elements of unlabeled_samples to add to our training data. Since doing + this requires us to find out the label of the sample, a potentially tedious + or expensive process, we would like to select the "best" element from + unlabeled_samples for labeling. The rank_unlabeled_training_samples() + attempts to find this "best" element. In particular, this function returns a + ranked list of all the elements in unlabeled_samples such that that the + "best" elements come first. + - The method used by this function is described in the paper: + Support Vector Machine Active Learning with Applications to Text Classification + by Simon Tong and Daphne Koller + In particular, this function implements the MaxMin Margin and Ratio Margin + selection strategies described in the paper. Moreover, the mode argument + to this function selects which of these strategies is used. + - returns a std::vector V such that: + - V contains a list of all the indices from unlabeled_samples. Moreover, + they are ordered so that the most useful samples come first. + - V.size() == unlabeled_samples.size() + - unlabeled_samples[V[0]] == The best sample to add into the training set. + - unlabeled_samples[V[1]] == The second best sample to add into the training set. + - unlabeled_samples[V[i]] == The i-th best sample to add into the training set. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ACTIVE_LEARnING_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/assignment_function.h b/ml/dlib/dlib/svm/assignment_function.h new file mode 100644 index 000000000..fdacb2c17 --- /dev/null +++ b/ml/dlib/dlib/svm/assignment_function.h @@ -0,0 +1,255 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ASSIGNMENT_FuNCTION_Hh_ +#define DLIB_ASSIGNMENT_FuNCTION_Hh_ + +#include "assignment_function_abstract.h" +#include "../matrix.h" +#include +#include "../optimization/max_cost_assignment.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class assignment_function + { + public: + + typedef typename feature_extractor::lhs_element lhs_element; + typedef typename feature_extractor::rhs_element rhs_element; + + + typedef std::pair, std::vector > sample_type; + + typedef std::vector label_type; + typedef label_type result_type; + + assignment_function() + { + weights.set_size(fe.num_features()); + weights = 0; + bias = 0; + force_assignment = false; + } + + explicit assignment_function( + const matrix& weights_, + double bias_ + ) : + weights(weights_), + bias(bias_), + force_assignment(false) + { + // make sure requires clause is not broken + DLIB_ASSERT(fe.num_features() == static_cast(weights_.size()), + "\t assignment_function::assignment_function(weights_)" + << "\n\t These sizes should match" + << "\n\t fe.num_features(): " << fe.num_features() + << "\n\t weights_.size(): " << weights_.size() + << "\n\t this: " << this + ); + + } + + assignment_function( + const matrix& weights_, + double bias_, + const feature_extractor& fe_ + ) : + fe(fe_), + weights(weights_), + bias(bias_), + force_assignment(false) + { + // make sure requires clause is not broken + DLIB_ASSERT(fe_.num_features() == static_cast(weights_.size()), + "\t assignment_function::assignment_function(weights_,fe_)" + << "\n\t These sizes should match" + << "\n\t fe_.num_features(): " << fe_.num_features() + << "\n\t weights_.size(): " << weights_.size() + << "\n\t this: " << this + ); + } + + assignment_function( + const matrix& weights_, + double bias_, + const feature_extractor& fe_, + bool force_assignment_ + ) : + fe(fe_), + weights(weights_), + bias(bias_), + force_assignment(force_assignment_) + { + // make sure requires clause is not broken + DLIB_ASSERT(fe_.num_features() == static_cast(weights_.size()), + "\t assignment_function::assignment_function(weights_,fe_,force_assignment_)" + << "\n\t These sizes should match" + << "\n\t fe_.num_features(): " << fe_.num_features() + << "\n\t weights_.size(): " << weights_.size() + << "\n\t this: " << this + ); + } + + const feature_extractor& get_feature_extractor ( + ) const { return fe; } + + const matrix& get_weights ( + ) const { return weights; } + + double get_bias ( + ) const { return bias; } + + bool forces_assignment ( + ) const { return force_assignment; } + + void predict_assignments ( + const std::vector& lhs, + const std::vector& rhs, + result_type& assignment + ) const + { + assignment.clear(); + + matrix cost; + unsigned long size; + if (force_assignment) + { + size = std::max(lhs.size(), rhs.size()); + } + else + { + size = rhs.size() + lhs.size(); + } + cost.set_size(size, size); + + typedef typename feature_extractor::feature_vector_type feature_vector_type; + feature_vector_type feats; + + // now fill out the cost assignment matrix + for (long r = 0; r < cost.nr(); ++r) + { + for (long c = 0; c < cost.nc(); ++c) + { + if (r < (long)lhs.size() && c < (long)rhs.size()) + { + fe.get_features(lhs[r], rhs[c], feats); + cost(r,c) = dot(weights, feats) + bias; + } + else + { + cost(r,c) = 0; + } + } + } + + + if (cost.size() != 0) + { + // max_cost_assignment() only works with integer matrices, so convert from + // double to integer. + const double scale = (std::numeric_limits::max()/1000)/max(abs(cost)); + matrix int_cost = matrix_cast(round(cost*scale)); + assignment = max_cost_assignment(int_cost); + assignment.resize(lhs.size()); + } + + // adjust assignment so that non-assignments have a value of -1 + for (unsigned long i = 0; i < assignment.size(); ++i) + { + if (assignment[i] >= (long)rhs.size()) + assignment[i] = -1; + } + } + + void predict_assignments ( + const sample_type& item, + result_type& assignment + ) const + { + predict_assignments(item.first, item.second, assignment); + } + + result_type operator()( + const std::vector& lhs, + const std::vector& rhs + ) const + { + result_type temp; + predict_assignments(lhs,rhs,temp); + return temp; + } + + result_type operator() ( + const sample_type& item + ) const + { + return (*this)(item.first, item.second); + } + + private: + + + feature_extractor fe; + matrix weights; + double bias; + bool force_assignment; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void serialize ( + const assignment_function& item, + std::ostream& out + ) + { + int version = 2; + serialize(version, out); + serialize(item.get_feature_extractor(), out); + serialize(item.get_weights(), out); + serialize(item.get_bias(), out); + serialize(item.forces_assignment(), out); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void deserialize ( + assignment_function& item, + std::istream& in + ) + { + feature_extractor fe; + matrix weights; + double bias; + bool force_assignment; + int version = 0; + deserialize(version, in); + if (version != 2) + throw serialization_error("Unexpected version found while deserializing dlib::assignment_function."); + + deserialize(fe, in); + deserialize(weights, in); + deserialize(bias, in); + deserialize(force_assignment, in); + + item = assignment_function(weights, bias, fe, force_assignment); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ASSIGNMENT_FuNCTION_Hh_ + diff --git a/ml/dlib/dlib/svm/assignment_function_abstract.h b/ml/dlib/dlib/svm/assignment_function_abstract.h new file mode 100644 index 000000000..927731856 --- /dev/null +++ b/ml/dlib/dlib/svm/assignment_function_abstract.h @@ -0,0 +1,342 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ASSIGNMENT_FuNCTION_ABSTRACT_Hh_ +#ifdef DLIB_ASSIGNMENT_FuNCTION_ABSTRACT_Hh_ + +#include +#include "../optimization/max_cost_assignment_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class example_feature_extractor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines the interface a feature extractor must implement + if it is to be used with the assignment_function defined at the bottom + of this file. + + The model used by assignment_function objects is the following. + Given two sets of objects, the Left Hand Set (LHS) and Right Hand Set (RHS), + find a one-to-one mapping M from LHS into RHS such that: + M == argmax_m sum_{l in LHS} match_score(l,m(l)) + Where match_score() returns a scalar value indicating how good it is + to say l maps to the RHS element m(l). Additionally, in this model, + m() is allowed to indicate that l doesn't map to anything, and in this + case it is excluded from the sum. + + Finally, match_score() is defined as: + match_score(l,r) == dot(w, PSI(l,r)) + bias + where l is an element of LHS, r is an element of RHS, w is a parameter + vector and bias is a scalar valued parameter. + + Therefore, a feature extractor defines how the PSI() feature vector + is calculated. In particular, PSI() is defined by the get_features() + method of this class. + + THREAD SAFETY + Instances of this object are required to be threadsafe, that is, it should + be safe for multiple threads to make concurrent calls to the member + functions of this object. + + !*/ + + public: + + // This type should be a dlib::matrix capable of storing column vectors + // or an unsorted sparse vector type as defined in dlib/svm/sparse_vector_abstract.h. + typedef matrix_or_sparse_vector_type feature_vector_type; + + // These two typedefs define the types used to represent an element in + // the left hand and right hand sets. You can use any copyable types here. + typedef user_defined_type_1 lhs_element; + typedef user_defined_type_2 rhs_element; + + unsigned long num_features( + ) const; + /*! + ensures + - returns the dimensionality of the PSI() feature vector. + !*/ + + void get_features ( + const lhs_element& left, + const rhs_element& right, + feature_vector_type& feats + ) const; + /*! + ensures + - #feats == PSI(left,right) + (i.e. This function computes a feature vector which, in some sense, + captures information useful for deciding if matching left to right + is "good"). + !*/ + + unsigned long num_nonnegative_weights ( + ) const; + /*! + ensures + - returns the number of elements of the w parameter vector which should be + non-negative. That is, this feature extractor is intended to be used + with w vectors where the first num_nonnegative_weights() elements of w + are >= 0. That is, it should be the case that w(i) >= 0 for all i < + num_nonnegative_weights(). + - Note that num_nonnegative_weights() is just an optional method to allow + you to tell a tool like the structural_assignment_trainer that the + learned w should have a certain number of non-negative elements. + Therefore, if you do not provide a num_nonnegative_weights() method in + your feature extractor then it will default to a value of 0, indicating + that all elements of the w parameter vector may be any value. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + void serialize( + const example_feature_extractor& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize( + example_feature_extractor& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class assignment_function + { + /*! + REQUIREMENTS ON feature_extractor + It must be an object that implements an interface compatible with + the example_feature_extractor discussed above. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for solving the optimal assignment problem given a + user defined method for computing the quality of any particular assignment. + + To define this precisely, suppose you have two sets of objects, a + Left Hand Set (LHS) and a Right Hand Set (RHS) and you want to + find a one-to-one mapping M from LHS into RHS such that: + M == argmax_m sum_{l in LHS} match_score(l,m(l)) + Where match_score() returns a scalar value indicating how good it is + to say l maps to the RHS element m(l). Additionally, in this model, + m() is allowed to indicate that l doesn't map to anything, and in this + case it is excluded from the sum. + + Finally, this object supports match_score() functions of the form: + match_score(l,r) == dot(w, PSI(l,r)) + bias + where l is an element of LHS, r is an element of RHS, w is a parameter + vector, bias is a scalar valued parameter, and PSI() is defined by the + feature_extractor template argument. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call the const members of this object from multiple + threads so long as the feature_extractor is also threadsafe. This is + because the const members are purely read-only operations. However, + any operation that modifies an assignment_function is not threadsafe. + !*/ + + public: + + typedef typename feature_extractor::lhs_element lhs_element; + typedef typename feature_extractor::rhs_element rhs_element; + typedef std::vector label_type; + typedef label_type result_type; + typedef std::pair, std::vector > sample_type; + + assignment_function( + ); + /*! + ensures + - #get_feature_extractor() == feature_extractor() + (i.e. it will have its default value) + - #get_weights().size() == #get_feature_extractor().num_features() + - #get_weights() == 0 + - #get_bias() == 0 + - #forces_assignment() == false + !*/ + + explicit assignment_function( + const matrix& weights, + double bias + ); + /*! + requires + - feature_extractor().num_features() == weights.size() + ensures + - #get_feature_extractor() == feature_extractor() + (i.e. it will have its default value) + - #get_weights() == weights + - #get_bias() == bias + - #forces_assignment() == false + !*/ + + assignment_function( + const matrix& weights, + double bias, + const feature_extractor& fe + ); + /*! + requires + - fe.num_features() == weights.size() + ensures + - #get_feature_extractor() == fe + - #get_weights() == weights + - #get_bias() == bias + - #forces_assignment() == false + !*/ + + assignment_function( + const matrix& weights, + double bias, + const feature_extractor& fe, + bool force_assignment + ); + /*! + requires + - fe.num_features() == weights.size() + ensures + - #get_feature_extractor() == fe + - #get_weights() == weights + - #get_bias() == bias + - #forces_assignment() == force_assignment + !*/ + + const feature_extractor& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used by this object + !*/ + + const matrix& get_weights ( + ) const; + /*! + ensures + - returns the parameter vector (w) associated with this assignment function. + The length of the vector is get_feature_extractor().num_features(). + !*/ + + double get_bias ( + ) const; + /*! + ensures + - returns the bias parameter associated with this assignment function. + !*/ + + bool forces_assignment ( + ) const; + /*! + ensures + - returns true if this object is in the "forced assignment mode" and false + otherwise. + - When deciding how to match LHS to RHS, this object can operate in one of + two modes. In the default mode, this object will indicate that there is + no match for an element of LHS if the best matching element of RHS would + result in a negative match_score(). However, in the "forced assignment mode", + this object will always make the assignment if there is an available + element in RHS, regardless of the match_score(). + + Another way to understand this distinction is to consider an example. + Suppose LHS and RHS both contain 10 elements. Then in the default mode, + it is possible for this object to indicate that there are anywhere between + 0 to 10 matches between LHS and RHS. However, in forced assignment mode + it will always indicate exactly 10 matches. + !*/ + + result_type operator()( + const std::vector& lhs, + const std::vector& rhs + ) const + /*! + ensures + - returns a vector ASSIGN such that: + - ASSIGN.size() == lhs.size() + - if (ASSIGN[i] != -1) then + - lhs[i] is predicted to associate to rhs[ASSIGN[i]]. + - else + - lhs[i] doesn't associate with anything in rhs. + - All values in ASSIGN which are not equal to -1 are unique. + That is, ASSIGN will never indicate that more than one element + of lhs is assigned to a particular element of rhs. + !*/ + + result_type operator() ( + const sample_type& item + ) const; + /*! + ensures + - returns (*this)(item.first, item.second); + !*/ + + void predict_assignments ( + const sample_type& item, + result_type& assignment + ) const; + /*! + ensures + - #assignment == (*this)(item) + !*/ + + void predict_assignments ( + const std::vector& lhs, + const std::vector& rhs + result_type& assignment + ) const; + /*! + ensures + - #assignment == (*this)(lhs,rhs) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void serialize ( + const assignment_function& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void deserialize ( + assignment_function& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ASSIGNMENT_FuNCTION_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h b/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h new file mode 100644 index 000000000..8166e1c82 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h @@ -0,0 +1,181 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_ +#define DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_ + +#include "cross_validate_assignment_trainer_abstract.h" +#include +#include "../matrix.h" +#include "svm.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename assignment_function + > + double test_assignment_function ( + const assignment_function& assigner, + const std::vector& samples, + const std::vector& labels + ) + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + if (assigner.forces_assignment()) + { + DLIB_ASSERT(is_forced_assignment_problem(samples, labels), + "\t double test_assignment_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels) + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } + else + { + DLIB_ASSERT(is_assignment_problem(samples, labels), + "\t double test_assignment_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } +#endif + double total_right = 0; + double total = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + const std::vector& out = assigner(samples[i]); + for (unsigned long j = 0; j < out.size(); ++j) + { + if (out[j] == labels[i][j]) + ++total_right; + + ++total; + } + } + + if (total != 0) + return total_right/total; + else + return 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + double cross_validate_assignment_trainer ( + const trainer_type& trainer, + const std::vector& samples, + const std::vector& labels, + const long folds + ) + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + if (trainer.forces_assignment()) + { + DLIB_ASSERT(is_forced_assignment_problem(samples, labels) && + 1 < folds && folds <= static_cast(samples.size()), + "\t double cross_validate_assignment_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t folds: " << folds + << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels) + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } + else + { + DLIB_ASSERT(is_assignment_problem(samples, labels) && + 1 < folds && folds <= static_cast(samples.size()), + "\t double cross_validate_assignment_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t folds: " << folds + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } +#endif + + + + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::label_type label_type; + + const long num_in_test = samples.size()/folds; + const long num_in_train = samples.size() - num_in_test; + + + std::vector samples_test, samples_train; + std::vector labels_test, labels_train; + + + long next_test_idx = 0; + double total_right = 0; + double total = 0; + + + for (long i = 0; i < folds; ++i) + { + samples_test.clear(); + labels_test.clear(); + samples_train.clear(); + labels_train.clear(); + + // load up the test samples + for (long cnt = 0; cnt < num_in_test; ++cnt) + { + samples_test.push_back(samples[next_test_idx]); + labels_test.push_back(labels[next_test_idx]); + next_test_idx = (next_test_idx + 1)%samples.size(); + } + + // load up the training samples + long next = next_test_idx; + for (long cnt = 0; cnt < num_in_train; ++cnt) + { + samples_train.push_back(samples[next]); + labels_train.push_back(labels[next]); + next = (next + 1)%samples.size(); + } + + + const typename trainer_type::trained_function_type& df = trainer.train(samples_train,labels_train); + + // check how good df is on the test data + for (unsigned long i = 0; i < samples_test.size(); ++i) + { + const std::vector& out = df(samples_test[i]); + for (unsigned long j = 0; j < out.size(); ++j) + { + if (out[j] == labels_test[i][j]) + ++total_right; + + ++total; + } + } + + } // for (long i = 0; i < folds; ++i) + + if (total != 0) + return total_right/total; + else + return 1; + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/cross_validate_assignment_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_assignment_trainer_abstract.h new file mode 100644 index 000000000..05dd4758e --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_assignment_trainer_abstract.h @@ -0,0 +1,69 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_Hh_ +#ifdef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_Hh_ + +#include +#include "../matrix.h" +#include "svm.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename assignment_function + > + double test_assignment_function ( + const assignment_function& assigner, + const std::vector& samples, + const std::vector& labels + ); + /*! + requires + - is_assignment_problem(samples, labels) + - if (assigner.forces_assignment()) then + - is_forced_assignment_problem(samples, labels) + - assignment_function == an instantiation of the dlib::assignment_function + template or an object with a compatible interface. + ensures + - Tests assigner against the given samples and labels and returns the fraction + of assignments predicted correctly. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + double cross_validate_assignment_trainer ( + const trainer_type& trainer, + const std::vector& samples, + const std::vector& labels, + const long folds + ); + /*! + requires + - is_assignment_problem(samples, labels) + - if (trainer.forces_assignment()) then + - is_forced_assignment_problem(samples, labels) + - 1 < folds <= samples.size() + - trainer_type == dlib::structural_assignment_trainer or an object + with a compatible interface. + ensures + - performs k-fold cross validation by using the given trainer to solve the + given assignment learning problem for the given number of folds. Each fold + is tested using the output of the trainer and the fraction of assignments + predicted correctly is returned. + - The number of folds used is given by the folds argument. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer.h b/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer.h new file mode 100644 index 000000000..83e4e4048 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer.h @@ -0,0 +1,258 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_Hh_ +#define DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_Hh_ + +#include "../array.h" +#include "../graph_cuts/min_cut.h" +#include "svm.h" +#include "cross_validate_graph_labeling_trainer_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_labeler, + typename graph_type + > + matrix test_graph_labeling_function ( + const graph_labeler& labeler, + const dlib::array& samples, + const std::vector >& labels, + const std::vector >& losses + ) + { +#ifdef ENABLE_ASSERTS + std::string reason_for_failure; + DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) , + "\t matrix test_graph_labeling_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t reason_for_failure: " << reason_for_failure + ); + DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) && + all_values_are_nonnegative(losses) == true, + "\t matrix test_graph_labeling_function()" + << "\n\t Invalid inputs were given to this function." + << "\n\t labels.size(): " << labels.size() + << "\n\t losses.size(): " << losses.size() + << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses) + << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses) + ); +#endif + + std::vector temp; + double num_pos_correct = 0; + double num_pos = 0; + double num_neg_correct = 0; + double num_neg = 0; + + for (unsigned long i = 0; i < samples.size(); ++i) + { + labeler(samples[i], temp); + + for (unsigned long j = 0; j < labels[i].size(); ++j) + { + // What is the loss for this example? It's just 1 unless we have a + // per example loss vector. + const double loss = (losses.size() == 0) ? 1.0 : losses[i][j]; + + if (labels[i][j]) + { + num_pos += loss; + if (temp[j]) + num_pos_correct += loss; + } + else + { + num_neg += loss; + if (!temp[j]) + num_neg_correct += loss; + } + } + } + + matrix res; + if (num_pos != 0) + res(0) = num_pos_correct/num_pos; + else + res(0) = 1; + if (num_neg != 0) + res(1) = num_neg_correct/num_neg; + else + res(1) = 1; + return res; + } + + template < + typename graph_labeler, + typename graph_type + > + matrix test_graph_labeling_function ( + const graph_labeler& labeler, + const dlib::array& samples, + const std::vector >& labels + ) + { + std::vector > losses; + return test_graph_labeling_function(labeler, samples, labels, losses); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename graph_type + > + matrix cross_validate_graph_labeling_trainer ( + const trainer_type& trainer, + const dlib::array& samples, + const std::vector >& labels, + const std::vector >& losses, + const long folds + ) + { +#ifdef ENABLE_ASSERTS + std::string reason_for_failure; + DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure), + "\t matrix cross_validate_graph_labeling_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t reason_for_failure: " << reason_for_failure + ); + DLIB_ASSERT( 1 < folds && folds <= static_cast(samples.size()), + "\t matrix cross_validate_graph_labeling_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t folds: " << folds + ); + DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) && + all_values_are_nonnegative(losses) == true, + "\t matrix cross_validate_graph_labeling_trainer()" + << "\n\t Invalid inputs were given to this function." + << "\n\t labels.size(): " << labels.size() + << "\n\t losses.size(): " << losses.size() + << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses) + << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses) + ); +#endif + + typedef std::vector label_type; + + const long num_in_test = samples.size()/folds; + const long num_in_train = samples.size() - num_in_test; + + + dlib::array samples_test, samples_train; + std::vector labels_test, labels_train; + std::vector > losses_test, losses_train; + + + long next_test_idx = 0; + + std::vector temp; + double num_pos_correct = 0; + double num_pos = 0; + double num_neg_correct = 0; + double num_neg = 0; + + graph_type gtemp; + + for (long i = 0; i < folds; ++i) + { + samples_test.clear(); + labels_test.clear(); + losses_test.clear(); + samples_train.clear(); + labels_train.clear(); + losses_train.clear(); + + // load up the test samples + for (long cnt = 0; cnt < num_in_test; ++cnt) + { + copy_graph(samples[next_test_idx], gtemp); + samples_test.push_back(gtemp); + labels_test.push_back(labels[next_test_idx]); + if (losses.size() != 0) + losses_test.push_back(losses[next_test_idx]); + next_test_idx = (next_test_idx + 1)%samples.size(); + } + + // load up the training samples + long next = next_test_idx; + for (long cnt = 0; cnt < num_in_train; ++cnt) + { + copy_graph(samples[next], gtemp); + samples_train.push_back(gtemp); + labels_train.push_back(labels[next]); + if (losses.size() != 0) + losses_train.push_back(losses[next]); + next = (next + 1)%samples.size(); + } + + + const typename trainer_type::trained_function_type& labeler = trainer.train(samples_train,labels_train,losses_train); + + // check how good labeler is on the test data + for (unsigned long i = 0; i < samples_test.size(); ++i) + { + labeler(samples_test[i], temp); + for (unsigned long j = 0; j < labels_test[i].size(); ++j) + { + // What is the loss for this example? It's just 1 unless we have a + // per example loss vector. + const double loss = (losses_test.size() == 0) ? 1.0 : losses_test[i][j]; + + if (labels_test[i][j]) + { + num_pos += loss; + if (temp[j]) + num_pos_correct += loss; + } + else + { + num_neg += loss; + if (!temp[j]) + num_neg_correct += loss; + } + } + } + + } // for (long i = 0; i < folds; ++i) + + + matrix res; + if (num_pos != 0) + res(0) = num_pos_correct/num_pos; + else + res(0) = 1; + if (num_neg != 0) + res(1) = num_neg_correct/num_neg; + else + res(1) = 1; + return res; + } + + template < + typename trainer_type, + typename graph_type + > + matrix cross_validate_graph_labeling_trainer ( + const trainer_type& trainer, + const dlib::array& samples, + const std::vector >& labels, + const long folds + ) + { + std::vector > losses; + return cross_validate_graph_labeling_trainer(trainer, samples, labels, losses, folds); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_Hh_ + + diff --git a/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer_abstract.h new file mode 100644 index 000000000..cda4af91e --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_graph_labeling_trainer_abstract.h @@ -0,0 +1,147 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_ABSTRACT_Hh_ +#ifdef DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_ABSTRACT_Hh_ + +#include "../array/array_kernel_abstract.h" +#include +#include "../matrix/matrix_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_labeler, + typename graph_type + > + matrix test_graph_labeling_function ( + const graph_labeler& labeler, + const dlib::array& samples, + const std::vector >& labels + ); + /*! + requires + - is_graph_labeling_problem(samples,labels) == true + - graph_labeler == an object with an interface compatible with the + dlib::graph_labeler object. + - the following must be a valid expression: labeler(samples[0]); + ensures + - This function tests the accuracy of the given graph labeler against + the sample graphs and their associated labels. In particular, this + function returns a matrix R such that: + - R(0) == The fraction of nodes which are supposed to have a label of + true that are labeled as such by the labeler. + - R(1) == The fraction of nodes which are supposed to have a label of + false that are labeled as such by the labeler. + Therefore, if R is [1,1] then the labeler makes perfect predictions while + an R of [0,0] indicates that it gets everything wrong. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_labeler, + typename graph_type + > + matrix test_graph_labeling_function ( + const graph_labeler& labeler, + const dlib::array& samples, + const std::vector >& labels, + const std::vector >& losses + ); + /*! + requires + - is_graph_labeling_problem(samples,labels) == true + - graph_labeler == an object with an interface compatible with the + dlib::graph_labeler object. + - the following must be a valid expression: labeler(samples[0]); + - if (losses.size() != 0) then + - sizes_match(labels, losses) == true + - all_values_are_nonnegative(losses) == true + ensures + - This overload of test_graph_labeling_function() does the same thing as the + one defined above, except that instead of counting 1 for each labeling + mistake, it weights each mistake according to the corresponding value in + losses. That is, instead of counting a value of 1 for making a mistake on + samples[i].node(j), this routine counts a value of losses[i][j]. Under this + interpretation, the loss values represent how useful it is to correctly label + each node. Therefore, the values returned represent fractions of overall + labeling utility rather than raw labeling accuracy. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename graph_type + > + matrix cross_validate_graph_labeling_trainer ( + const trainer_type& trainer, + const dlib::array& samples, + const std::vector >& labels, + const long folds + ); + /*! + requires + - is_graph_labeling_problem(samples,labels) == true + - 1 < folds <= samples.size() + - trainer_type == an object which trains some kind of graph labeler object + (e.g. structural_graph_labeling_trainer) + ensures + - performs k-fold cross validation by using the given trainer to solve the + given graph labeling problem for the given number of folds. Each fold + is tested using the output of the trainer and the average classification + accuracy from all folds is returned. In particular, this function returns + a matrix R such that: + - R(0) == The fraction of nodes which are supposed to have a label of + true that are labeled as such by the learned labeler. + - R(1) == The fraction of nodes which are supposed to have a label of + false that are labeled as such by the learned labeler. + Therefore, if R is [1,1] then the labeler makes perfect predictions while + an R of [0,0] indicates that it gets everything wrong. + - The number of folds used is given by the folds argument. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename graph_type + > + matrix cross_validate_graph_labeling_trainer ( + const trainer_type& trainer, + const dlib::array& samples, + const std::vector >& labels, + const std::vector >& losses, + const long folds + ); + /*! + requires + - is_graph_labeling_problem(samples,labels) == true + - 1 < folds <= samples.size() + - trainer_type == an object which trains some kind of graph labeler object + (e.g. structural_graph_labeling_trainer) + - if (losses.size() != 0) then + - sizes_match(labels, losses) == true + - all_values_are_nonnegative(losses) == true + ensures + - This overload of cross_validate_graph_labeling_trainer() does the same thing + as the one defined above, except that instead of counting 1 for each labeling + mistake, it weights each mistake according to the corresponding value in + losses. That is, instead of counting a value of 1 for making a mistake on + samples[i].node(j), this routine counts a value of losses[i][j]. Under this + interpretation, the loss values represent how useful it is to correctly label + each node. Therefore, the values returned represent fractions of overall + labeling utility rather than raw labeling accuracy. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h new file mode 100644 index 000000000..be8fa3f3f --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h @@ -0,0 +1,208 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_ +#define DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_ + +#include +#include "../matrix.h" +#include "cross_validate_multiclass_trainer_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename dec_funct_type, + typename sample_type, + typename label_type + > + const matrix test_multiclass_decision_function ( + const dec_funct_type& dec_funct, + const std::vector& x_test, + const std::vector& y_test + ) + { + + // make sure requires clause is not broken + DLIB_ASSERT( is_learning_problem(x_test,y_test) == true, + "\tmatrix test_multiclass_decision_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_learning_problem(x_test,y_test): " + << is_learning_problem(x_test,y_test)); + + + const std::vector all_labels = dec_funct.get_labels(); + + // make a lookup table that maps from labels to their index in all_labels + std::map label_to_int; + for (unsigned long i = 0; i < all_labels.size(); ++i) + label_to_int[all_labels[i]] = i; + + matrix res; + res.set_size(all_labels.size(), all_labels.size()); + + res = 0; + + typename std::map::const_iterator iter; + + // now test this trained object + for (unsigned long i = 0; i < x_test.size(); ++i) + { + iter = label_to_int.find(y_test[i]); + // ignore samples with labels that the decision function doesn't know about. + if (iter == label_to_int.end()) + continue; + + const unsigned long truth = iter->second; + const unsigned long pred = label_to_int[dec_funct(x_test[i])]; + + res(truth,pred) += 1; + } + + return res; + } + +// ---------------------------------------------------------------------------------------- + + class cross_validation_error : public dlib::error + { + public: + cross_validation_error(const std::string& msg) : dlib::error(msg){}; + }; + + template < + typename trainer_type, + typename sample_type, + typename label_type + > + const matrix cross_validate_multiclass_trainer ( + const trainer_type& trainer, + const std::vector& x, + const std::vector& y, + const long folds + ) + { + typedef typename trainer_type::mem_manager_type mem_manager_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,y) == true && + 1 < folds && folds <= static_cast(x.size()), + "\tmatrix cross_validate_multiclass_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t x.size(): " << x.size() + << "\n\t folds: " << folds + << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y) + ); + + const std::vector all_labels = select_all_distinct_labels(y); + + // count the number of times each label shows up + std::map label_counts; + for (unsigned long i = 0; i < y.size(); ++i) + label_counts[y[i]] += 1; + + + // figure out how many samples from each class will be in the test and train splits + std::map num_in_test, num_in_train; + for (typename std::map::iterator i = label_counts.begin(); i != label_counts.end(); ++i) + { + const long in_test = i->second/folds; + if (in_test == 0) + { + std::ostringstream sout; + sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl; + sout << "than the number of elements of one of the training classes." << std::endl; + sout << " folds: "<< folds << std::endl; + sout << " size of class " << i->first << ": "<< i->second << std::endl; + throw cross_validation_error(sout.str()); + } + num_in_test[i->first] = in_test; + num_in_train[i->first] = i->second - in_test; + } + + + + std::vector x_test, x_train; + std::vector y_test, y_train; + + matrix res; + + std::map next_test_idx; + for (unsigned long i = 0; i < all_labels.size(); ++i) + next_test_idx[all_labels[i]] = 0; + + label_type label; + + for (long i = 0; i < folds; ++i) + { + x_test.clear(); + y_test.clear(); + x_train.clear(); + y_train.clear(); + + // load up the test samples + for (unsigned long j = 0; j < all_labels.size(); ++j) + { + label = all_labels[j]; + long next = next_test_idx[label]; + + long cur = 0; + const long num_needed = num_in_test[label]; + while (cur < num_needed) + { + if (y[next] == label) + { + x_test.push_back(x[next]); + y_test.push_back(label); + ++cur; + } + next = (next + 1)%x.size(); + } + + next_test_idx[label] = next; + } + + // load up the training samples + for (unsigned long j = 0; j < all_labels.size(); ++j) + { + label = all_labels[j]; + long next = next_test_idx[label]; + + long cur = 0; + const long num_needed = num_in_train[label]; + while (cur < num_needed) + { + if (y[next] == label) + { + x_train.push_back(x[next]); + y_train.push_back(label); + ++cur; + } + next = (next + 1)%x.size(); + } + } + + + try + { + // do the training and testing + res += test_multiclass_decision_function(trainer.train(x_train,y_train),x_test,y_test); + } + catch (invalid_nu_error&) + { + // just ignore cases which result in an invalid nu + } + + } // for (long i = 0; i < folds; ++i) + + return res; + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_ + diff --git a/ml/dlib/dlib/svm/cross_validate_multiclass_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer_abstract.h new file mode 100644 index 000000000..f84503cdc --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer_abstract.h @@ -0,0 +1,99 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_ABSTRACT_Hh_ +#ifdef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_ABSTRACT_Hh_ + +#include +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename dec_funct_type, + typename sample_type, + typename label_type + > + const matrix test_multiclass_decision_function ( + const dec_funct_type& dec_funct, + const std::vector& x_test, + const std::vector& y_test + ); + /*! + requires + - is_learning_problem(x_test, y_test) + - dec_funct_type == some kind of multiclass decision function object + (e.g. one_vs_one_decision_function) + ensures + - Tests dec_funct against the given samples in x_test and labels in y_test + and returns a confusion matrix summarizing the results. + - let L = dec_funct.get_labels(). Then the confusion matrix C returned + by this function has the following properties. + - C.nr() == C.nc() == L.size() + - C(r,c) == the number of times a sample with label L(r) was predicted + to have a label of L(c) + - Any samples with a y_test value not in L are ignored. That is, samples + with labels the decision function hasn't ever seen before are ignored. + !*/ + +// ---------------------------------------------------------------------------------------- + + class cross_validation_error : public dlib::error + { + /*! + This is the exception class used by the cross_validate_multiclass_trainer() + routine. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename sample_type, + typename label_type + > + const matrix cross_validate_multiclass_trainer ( + const trainer_type& trainer, + const std::vector& x, + const std::vector& y, + const long folds + ); + /*! + requires + - is_learning_problem(x,y) + - 1 < folds <= x.size() + - trainer_type == some kind of multiclass classification trainer object (e.g. one_vs_one_trainer) + ensures + - performs k-fold cross validation by using the given trainer to solve the + given multiclass classification problem for the given number of folds. + Each fold is tested using the output of the trainer and the confusion + matrix from all folds is summed and returned. + - The total confusion matrix is computed by running test_binary_decision_function() + on each fold and summing its output. + - The number of folds used is given by the folds argument. + - let L = select_all_distinct_labels(y). Then the confusion matrix C returned + by this function has the following properties. + - C.nr() == C.nc() == L.size() + - C(r,c) == the number of times a sample with label L(r) was predicted + to have a label of L(c) + + Note that sum(C) might be slightly less than x.size(). This happens if the number of + samples in a class is not an even multiple of folds. This is because each fold has the + same number of test samples in it and so if the number of samples in a class isn't a + multiple of folds then a few are not tested. + throws + - cross_validation_error + This exception is thrown if one of the classes has fewer samples than + the number of requested folds. + !*/ + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/cross_validate_object_detection_trainer.h b/ml/dlib/dlib/svm/cross_validate_object_detection_trainer.h new file mode 100644 index 000000000..7cb38f0b7 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_object_detection_trainer.h @@ -0,0 +1,430 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_Hh_ +#define DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_Hh_ + +#include "cross_validate_object_detection_trainer_abstract.h" +#include +#include "../matrix.h" +#include "svm.h" +#include "../geometry.h" +#include "../image_processing/full_object_detection.h" +#include "../image_processing/box_overlap_testing.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline unsigned long number_of_truth_hits ( + const std::vector& truth_boxes, + const std::vector& ignore, + const std::vector >& boxes, + const test_box_overlap& overlap_tester, + std::vector >& all_dets, + unsigned long& missing_detections, + const test_box_overlap& overlaps_ignore_tester + ) + /*! + ensures + - returns the number of elements in truth_boxes which are overlapped by an + element of boxes. In this context, two boxes, A and B, overlap if and only if + overlap_tester(A,B) == true. + - No element of boxes is allowed to account for more than one element of truth_boxes. + - The returned number is in the range [0,truth_boxes.size()] + - Adds the score for each box from boxes into all_dets and labels each with + a bool indicating if it hit a truth box. Note that we skip boxes that + don't hit any truth boxes and match an ignore box. + - Adds the number of truth boxes which didn't have any hits into + missing_detections. + !*/ + { + if (boxes.size() == 0) + { + missing_detections += truth_boxes.size(); + return 0; + } + + unsigned long count = 0; + std::vector used(boxes.size(),false); + for (unsigned long i = 0; i < truth_boxes.size(); ++i) + { + bool found_match = false; + // Find the first box that hits truth_boxes[i] + for (unsigned long j = 0; j < boxes.size(); ++j) + { + if (used[j]) + continue; + + if (overlap_tester(truth_boxes[i].get_rect(), boxes[j].second)) + { + used[j] = true; + ++count; + found_match = true; + break; + } + } + + if (!found_match) + ++missing_detections; + } + + for (unsigned long i = 0; i < boxes.size(); ++i) + { + // only out put boxes if they match a truth box or are not ignored. + if (used[i] || !overlaps_any_box(overlaps_ignore_tester, ignore, boxes[i].second)) + { + all_dets.push_back(std::make_pair(boxes[i].first, used[i])); + } + } + + return count; + } + + inline unsigned long number_of_truth_hits ( + const std::vector& truth_boxes, + const std::vector& ignore, + const std::vector >& boxes, + const test_box_overlap& overlap_tester, + std::vector >& all_dets, + unsigned long& missing_detections + ) + { + return number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlap_tester); + } + + // ------------------------------------------------------------------------------------ + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename object_detector_type, + typename image_array_type + > + const matrix test_object_detection_function ( + object_detector_type& detector, + const image_array_type& images, + const std::vector >& truth_dets, + const std::vector >& ignore, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( is_learning_problem(images,truth_dets) == true && + ignore.size() == images.size(), + "\t matrix test_object_detection_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets) + << "\n\t ignore.size(): " << ignore.size() + << "\n\t images.size(): " << images.size() + ); + + + + double correct_hits = 0; + double total_true_targets = 0; + + std::vector > all_dets; + unsigned long missing_detections = 0; + + + for (unsigned long i = 0; i < images.size(); ++i) + { + std::vector > hits; + detector(images[i], hits, adjust_threshold); + + correct_hits += impl::number_of_truth_hits(truth_dets[i], ignore[i], hits, overlap_tester, all_dets, missing_detections); + total_true_targets += truth_dets[i].size(); + } + + std::sort(all_dets.rbegin(), all_dets.rend()); + + double precision, recall; + + double total_hits = all_dets.size(); + + if (total_hits == 0) + precision = 1; + else + precision = correct_hits / total_hits; + + if (total_true_targets == 0) + recall = 1; + else + recall = correct_hits / total_true_targets; + + matrix res; + res = precision, recall, average_precision(all_dets, missing_detections); + return res; + } + + template < + typename object_detector_type, + typename image_array_type + > + const matrix test_object_detection_function ( + object_detector_type& detector, + const image_array_type& images, + const std::vector >& truth_dets, + const std::vector >& ignore, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ) + { + // convert into a list of regular rectangles. + std::vector > rects(truth_dets.size()); + for (unsigned long i = 0; i < truth_dets.size(); ++i) + { + for (unsigned long j = 0; j < truth_dets[i].size(); ++j) + { + rects[i].push_back(full_object_detection(truth_dets[i][j])); + } + } + + return test_object_detection_function(detector, images, rects, ignore, overlap_tester, adjust_threshold); + } + + template < + typename object_detector_type, + typename image_array_type + > + const matrix test_object_detection_function ( + object_detector_type& detector, + const image_array_type& images, + const std::vector >& truth_dets, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ) + { + std::vector > ignore(images.size()); + return test_object_detection_function(detector,images,truth_dets,ignore, overlap_tester, adjust_threshold); + } + + template < + typename object_detector_type, + typename image_array_type + > + const matrix test_object_detection_function ( + object_detector_type& detector, + const image_array_type& images, + const std::vector >& truth_dets, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ) + { + std::vector > ignore(images.size()); + return test_object_detection_function(detector,images,truth_dets,ignore, overlap_tester, adjust_threshold); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename array_type + > + struct array_subset_helper + { + typedef typename array_type::mem_manager_type mem_manager_type; + + array_subset_helper ( + const array_type& array_, + const std::vector& idx_set_ + ) : + array(array_), + idx_set(idx_set_) + { + } + + unsigned long size() const { return idx_set.size(); } + + typedef typename array_type::type type; + const type& operator[] ( + unsigned long idx + ) const { return array[idx_set[idx]]; } + + private: + const array_type& array; + const std::vector& idx_set; + }; + + template < + typename T + > + const matrix_op > > mat ( + const array_subset_helper& m + ) + { + typedef op_array_to_mat > op; + return matrix_op(op(m)); + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename image_array_type + > + const matrix cross_validate_object_detection_trainer ( + const trainer_type& trainer, + const image_array_type& images, + const std::vector >& truth_dets, + const std::vector >& ignore, + const long folds, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( is_learning_problem(images,truth_dets) == true && + ignore.size() == images.size() && + 1 < folds && folds <= static_cast(images.size()), + "\t matrix cross_validate_object_detection_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets) + << "\n\t folds: "<< folds + << "\n\t ignore.size(): " << ignore.size() + << "\n\t images.size(): " << images.size() + ); + + double correct_hits = 0; + double total_true_targets = 0; + + const long test_size = images.size()/folds; + + std::vector > all_dets; + unsigned long missing_detections = 0; + unsigned long test_idx = 0; + for (long iter = 0; iter < folds; ++iter) + { + std::vector train_idx_set; + std::vector test_idx_set; + + for (long i = 0; i < test_size; ++i) + test_idx_set.push_back(test_idx++); + + unsigned long train_idx = test_idx%images.size(); + std::vector > training_rects; + std::vector > training_ignores; + for (unsigned long i = 0; i < images.size()-test_size; ++i) + { + training_rects.push_back(truth_dets[train_idx]); + training_ignores.push_back(ignore[train_idx]); + train_idx_set.push_back(train_idx); + train_idx = (train_idx+1)%images.size(); + } + + + impl::array_subset_helper array_subset(images, train_idx_set); + typename trainer_type::trained_function_type detector = trainer.train(array_subset, training_rects, training_ignores, overlap_tester); + for (unsigned long i = 0; i < test_idx_set.size(); ++i) + { + std::vector > hits; + detector(images[test_idx_set[i]], hits, adjust_threshold); + + correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], ignore[i], hits, overlap_tester, all_dets, missing_detections); + total_true_targets += truth_dets[test_idx_set[i]].size(); + } + + } + + std::sort(all_dets.rbegin(), all_dets.rend()); + + + double precision, recall; + + double total_hits = all_dets.size(); + + if (total_hits == 0) + precision = 1; + else + precision = correct_hits / total_hits; + + if (total_true_targets == 0) + recall = 1; + else + recall = correct_hits / total_true_targets; + + matrix res; + res = precision, recall, average_precision(all_dets, missing_detections); + return res; + } + + template < + typename trainer_type, + typename image_array_type + > + const matrix cross_validate_object_detection_trainer ( + const trainer_type& trainer, + const image_array_type& images, + const std::vector >& truth_dets, + const std::vector >& ignore, + const long folds, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ) + { + // convert into a list of regular rectangles. + std::vector > dets(truth_dets.size()); + for (unsigned long i = 0; i < truth_dets.size(); ++i) + { + for (unsigned long j = 0; j < truth_dets[i].size(); ++j) + { + dets[i].push_back(full_object_detection(truth_dets[i][j])); + } + } + + return cross_validate_object_detection_trainer(trainer, images, dets, ignore, folds, overlap_tester, adjust_threshold); + } + + template < + typename trainer_type, + typename image_array_type + > + const matrix cross_validate_object_detection_trainer ( + const trainer_type& trainer, + const image_array_type& images, + const std::vector >& truth_dets, + const long folds, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ) + { + const std::vector > ignore(images.size()); + return cross_validate_object_detection_trainer(trainer,images,truth_dets,ignore,folds,overlap_tester,adjust_threshold); + } + + template < + typename trainer_type, + typename image_array_type + > + const matrix cross_validate_object_detection_trainer ( + const trainer_type& trainer, + const image_array_type& images, + const std::vector >& truth_dets, + const long folds, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ) + { + const std::vector > ignore(images.size()); + return cross_validate_object_detection_trainer(trainer,images,truth_dets,ignore,folds,overlap_tester,adjust_threshold); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_Hh_ + diff --git a/ml/dlib/dlib/svm/cross_validate_object_detection_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_object_detection_trainer_abstract.h new file mode 100644 index 000000000..575ed77fb --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_object_detection_trainer_abstract.h @@ -0,0 +1,297 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_Hh_ +#ifdef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_Hh_ + +#include +#include "../matrix.h" +#include "../geometry.h" +#include "../image_processing/full_object_detection_abstract.h" +#include "../dnn/layers_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename object_detector_type, + typename image_array_type + > + const matrix test_object_detection_function ( + object_detector_type& detector, + const image_array_type& images, + const std::vector >& truth_dets, + const std::vector >& ignore, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ); + /*! + requires + - is_learning_problem(images,truth_dets) + - images.size() == ignore.size() + - object_detector_type == some kind of object detector function object + (e.g. object_detector) + - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h + and it must contain objects which can be accepted by detector(). + ensures + - Tests the given detector against the supplied object detection problem and + returns the precision, recall, and average precision. Note that the task is + to predict, for each images[i], the set of object locations given by + truth_dets[i]. Additionally, any detections on image[i] that match a box in + ignore[i] are ignored. That is, detections matching a box in ignore[i] do + not count as a false alarm and similarly if any element of ignore[i] goes + undetected it does not count as a missed detection. So we say that ignore[i] + contains a set of boxes that we "don't care" if they are detected or not. + - In particular, returns a matrix M such that: + - M(0) == the precision of the detector object. This is a number + in the range [0,1] which measures the fraction of detector outputs + which correspond to a real target. A value of 1 means the detector + never produces any false alarms while a value of 0 means it only + produces false alarms. + - M(1) == the recall of the detector object. This is a number in the + range [0,1] which measures the fraction of targets found by the + detector. A value of 1 means the detector found all the targets + in truth_dets while a value of 0 means the detector didn't locate + any of the targets. + - M(2) == the average precision of the detector object. This is a number + in the range [0,1] which measures the overall quality of the detector. + We compute this by taking all the detections output by the detector and + ordering them in descending order of their detection scores. Then we use + the average_precision() routine to score the ranked listing and store the + output into M(2). + - This function considers a detector output D to match a rectangle T if and + only if overlap_tester(T,D) returns true. + - Note that you can use the adjust_threshold argument to raise or lower the + detection threshold. This value is passed into the identically named + argument to the detector object and therefore influences the number of + output detections. It can be useful, for example, to lower the detection + threshold because it results in more detections being output by the + detector, and therefore provides more information in the ranking, + possibly raising the average precision. + !*/ + + template < + typename object_detector_type, + typename image_array_type + > + const matrix test_object_detection_function ( + object_detector_type& detector, + const image_array_type& images, + const std::vector >& truth_dets, + const std::vector >& ignore, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ); + /*! + requires + - All the requirements of the above test_object_detection_function() routine. + ensures + - converts all the rectangles in truth_dets into full_object_detection objects + via full_object_detection's rectangle constructor. Then invokes + test_object_detection_function() on the full_object_detections and returns + the results. + !*/ + + template < + typename object_detector_type, + typename image_array_type + > + const matrix test_object_detection_function ( + object_detector_type& detector, + const image_array_type& images, + const std::vector >& truth_dets, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ); + /*! + requires + - All the requirements of the above test_object_detection_function() routine. + ensures + - This function simply invokes test_object_detection_function() with all the + given arguments and an empty set of ignore rectangles and returns the results. + !*/ + + template < + typename object_detector_type, + typename image_array_type + > + const matrix test_object_detection_function ( + object_detector_type& detector, + const image_array_type& images, + const std::vector >& truth_dets, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ); + /*! + requires + - All the requirements of the above test_object_detection_function() routine. + ensures + - This function simply invokes test_object_detection_function() with all the + given arguments and an empty set of ignore rectangles and returns the results. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename SUBNET, + typename image_array_type + > + const matrix test_object_detection_function ( + loss_mmod& detector, + const image_array_type& images, + const std::vector>& truth_dets, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0, + const test_box_overlap& overlaps_ignore_tester = test_box_overlap() + ); + /*! + requires + - is_learning_problem(images,truth_dets) + - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h + and it must contain objects which can be accepted by detector(). + ensures + - This function is just like the test_object_detection_function() for + object_detector's except it runs on CNNs that use loss_mmod. + - Tests the given detector against the supplied object detection problem and + returns the precision, recall, and average precision. Note that the task is + to predict, for each images[i], the set of object locations, and their + corresponding labels, given by truth_dets[i]. Additionally, any detections + on image[i] that match a box in truth_dets[i] that are marked ignore are + ignored. That is, detections matching an ignore box, regardless of the + ignore box's label, do not count as a false alarm and similarly if any + ignored box in truth_dets goes undetected it does not count as a missed + detection. To test if a box overlaps an ignore box, we use overlaps_ignore_tester. + - In particular, returns a matrix M such that: + - M(0) == the precision of the detector object. This is a number + in the range [0,1] which measures the fraction of detector outputs + which correspond to a real target. A value of 1 means the detector + never produces any false alarms while a value of 0 means it only + produces false alarms. + - M(1) == the recall of the detector object. This is a number in the + range [0,1] which measures the fraction of targets found by the detector. + A value of 1 means the detector found all the non-ignore targets in + truth_dets while a value of 0 means the detector didn't locate any of the + targets. + - M(2) == the average precision of the detector object. This is a number + in the range [0,1] which measures the overall quality of the detector. + We compute this by taking all the detections output by the detector and + ordering them in descending order of their detection scores. Then we use + the average_precision() routine to score the ranked listing and store the + output into M(2). + - This function considers a detector output D to match a truth rectangle T if + and only if overlap_tester(T,D) returns true and the labels are identical strings. + - Note that you can use the adjust_threshold argument to raise or lower the + detection threshold. This value is passed into the identically named + argument to the detector object and therefore influences the number of + output detections. It can be useful, for example, to lower the detection + threshold because it results in more detections being output by the + detector, and therefore provides more information in the ranking, + possibly raising the average precision. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename image_array_type + > + const matrix cross_validate_object_detection_trainer ( + const trainer_type& trainer, + const image_array_type& images, + const std::vector >& truth_dets, + const std::vector >& ignore, + const long folds, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ); + /*! + requires + - is_learning_problem(images,truth_dets) + - images.size() == ignore.size() + - 1 < folds <= images.size() + - trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer) + - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h + and it must contain objects which can be accepted by detector(). + - it is legal to call trainer.train(images, truth_dets) + ensures + - Performs k-fold cross-validation by using the given trainer to solve an + object detection problem for the given number of folds. Each fold is tested + using the output of the trainer and a matrix summarizing the results is + returned. The matrix contains the precision, recall, and average + precision of the trained detectors and is defined identically to the + test_object_detection_function() routine defined at the top of this file. + !*/ + + template < + typename trainer_type, + typename image_array_type + > + const matrix cross_validate_object_detection_trainer ( + const trainer_type& trainer, + const image_array_type& images, + const std::vector >& truth_dets, + const std::vector >& ignore, + const long folds, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ); + /*! + requires + - all the requirements of the above cross_validate_object_detection_trainer() routine. + ensures + - converts all the rectangles in truth_dets into full_object_detection objects + via full_object_detection's rectangle constructor. Then invokes + cross_validate_object_detection_trainer() on the full_object_detections and + returns the results. + !*/ + + template < + typename trainer_type, + typename image_array_type + > + const matrix cross_validate_object_detection_trainer ( + const trainer_type& trainer, + const image_array_type& images, + const std::vector >& truth_dets, + const long folds, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ); + /*! + requires + - All the requirements of the above cross_validate_object_detection_trainer() routine. + ensures + - This function simply invokes cross_validate_object_detection_trainer() with all + the given arguments and an empty set of ignore rectangles and returns the results. + !*/ + + template < + typename trainer_type, + typename image_array_type + > + const matrix cross_validate_object_detection_trainer ( + const trainer_type& trainer, + const image_array_type& images, + const std::vector >& truth_dets, + const long folds, + const test_box_overlap& overlap_tester = test_box_overlap(), + const double adjust_threshold = 0 + ); + /*! + requires + - All the requirements of the above cross_validate_object_detection_trainer() routine. + ensures + - This function simply invokes cross_validate_object_detection_trainer() with all + the given arguments and an empty set of ignore rectangles and returns the results. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/cross_validate_regression_trainer.h b/ml/dlib/dlib/svm/cross_validate_regression_trainer.h new file mode 100644 index 000000000..a4c6077c9 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_regression_trainer.h @@ -0,0 +1,155 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_ +#define DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_ + +#include +#include "../matrix.h" +#include "../statistics.h" +#include "cross_validate_regression_trainer_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename reg_funct_type, + typename sample_type, + typename label_type + > + matrix + test_regression_function ( + reg_funct_type& reg_funct, + const std::vector& x_test, + const std::vector& y_test + ) + { + + // make sure requires clause is not broken + DLIB_ASSERT( is_learning_problem(x_test,y_test) == true, + "\tmatrix test_regression_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_learning_problem(x_test,y_test): " + << is_learning_problem(x_test,y_test)); + + running_stats rs, rs_mae; + running_scalar_covariance rc; + + for (unsigned long i = 0; i < x_test.size(); ++i) + { + // compute error + const double output = reg_funct(x_test[i]); + const double temp = output - y_test[i]; + + rs_mae.add(std::abs(temp)); + rs.add(temp*temp); + rc.add(output, y_test[i]); + } + + matrix result; + result = rs.mean(), rc.correlation(), rs_mae.mean(), rs_mae.stddev(); + return result; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename sample_type, + typename label_type + > + matrix + cross_validate_regression_trainer ( + const trainer_type& trainer, + const std::vector& x, + const std::vector& y, + const long folds + ) + { + + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,y) == true && + 1 < folds && folds <= static_cast(x.size()), + "\tmatrix cross_validate_regression_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t x.size(): " << x.size() + << "\n\t folds: " << folds + << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y) + ); + + + + const long num_in_test = x.size()/folds; + const long num_in_train = x.size() - num_in_test; + + running_stats rs, rs_mae; + running_scalar_covariance rc; + + std::vector x_test, x_train; + std::vector y_test, y_train; + + + long next_test_idx = 0; + + + for (long i = 0; i < folds; ++i) + { + x_test.clear(); + y_test.clear(); + x_train.clear(); + y_train.clear(); + + // load up the test samples + for (long cnt = 0; cnt < num_in_test; ++cnt) + { + x_test.push_back(x[next_test_idx]); + y_test.push_back(y[next_test_idx]); + next_test_idx = (next_test_idx + 1)%x.size(); + } + + // load up the training samples + long next = next_test_idx; + for (long cnt = 0; cnt < num_in_train; ++cnt) + { + x_train.push_back(x[next]); + y_train.push_back(y[next]); + next = (next + 1)%x.size(); + } + + + try + { + const typename trainer_type::trained_function_type& df = trainer.train(x_train,y_train); + + // do the training and testing + for (unsigned long j = 0; j < x_test.size(); ++j) + { + // compute error + const double output = df(x_test[j]); + const double temp = output - y_test[j]; + + rs_mae.add(std::abs(temp)); + rs.add(temp*temp); + rc.add(output, y_test[j]); + } + } + catch (invalid_nu_error&) + { + // just ignore cases which result in an invalid nu + } + + } // for (long i = 0; i < folds; ++i) + + matrix result; + result = rs.mean(), rc.correlation(), rs_mae.mean(), rs_mae.stddev(); + return result; + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_Hh_ + + diff --git a/ml/dlib/dlib/svm/cross_validate_regression_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_regression_trainer_abstract.h new file mode 100644 index 000000000..d6298aa74 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_regression_trainer_abstract.h @@ -0,0 +1,82 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_ +#ifdef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_ + +#include +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename reg_funct_type, + typename sample_type, + typename label_type + > + matrix + test_regression_function ( + reg_funct_type& reg_funct, + const std::vector& x_test, + const std::vector& y_test + ); + /*! + requires + - is_learning_problem(x_test, y_test) + - reg_funct_type == some kind of regression function object + (e.g. a decision_function created by the svr_trainer ) + ensures + - Tests reg_funct against the given samples in x_test and target values in + y_test and returns a matrix M summarizing the results. Specifically: + - M(0) == the mean squared error. + The MSE is given by: sum over i: pow(reg_funct(x_test[i]) - y_test[i], 2.0) + - M(1) == the correlation between reg_funct(x_test[i]) and y_test[i]. + This is a number between -1 and 1. + - M(2) == the mean absolute error. + This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i]) + - M(3) == the standard deviation of the absolute error. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename sample_type, + typename label_type + > + matrix + cross_validate_regression_trainer ( + const trainer_type& trainer, + const std::vector& x, + const std::vector& y, + const long folds + ); + /*! + requires + - is_learning_problem(x,y) + - 1 < folds <= x.size() + - trainer_type == some kind of regression trainer object (e.g. svr_trainer) + ensures + - Performs k-fold cross validation by using the given trainer to solve a + regression problem for the given number of folds. Each fold is tested using + the output of the trainer. A matrix M summarizing the results is returned. + Specifically: + - M(0) == the mean squared error. + The MSE is given by: sum over i: pow(reg_funct(x[i]) - y[i], 2.0) + - M(1) == the correlation between a predicted y value and its true value. + This is a number between -1 and 1. + - M(2) == the mean absolute error. + This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i]) + - M(3) == the standard deviation of the absolute error. + !*/ + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/svm/cross_validate_sequence_labeler.h b/ml/dlib/dlib/svm/cross_validate_sequence_labeler.h new file mode 100644 index 000000000..75c4e363a --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_sequence_labeler.h @@ -0,0 +1,152 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_ +#define DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_ + +#include "cross_validate_sequence_labeler_abstract.h" +#include +#include "../matrix.h" +#include "svm.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sequence_labeler_type, + typename sequence_type + > + const matrix test_sequence_labeler ( + const sequence_labeler_type& labeler, + const std::vector& samples, + const std::vector >& labels + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_sequence_labeling_problem(samples, labels) == true, + "\tmatrix test_sequence_labeler()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_sequence_labeling_problem(samples, labels): " + << is_sequence_labeling_problem(samples, labels)); + + matrix res(labeler.num_labels(), labeler.num_labels()); + res = 0; + + std::vector pred; + for (unsigned long i = 0; i < samples.size(); ++i) + { + labeler.label_sequence(samples[i], pred); + + for (unsigned long j = 0; j < pred.size(); ++j) + { + const unsigned long truth = labels[i][j]; + if (truth >= static_cast(res.nr())) + { + // ignore labels the labeler doesn't know about. + continue; + } + + res(truth, pred[j]) += 1; + } + } + + return res; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename sequence_type + > + const matrix cross_validate_sequence_labeler ( + const trainer_type& trainer, + const std::vector& samples, + const std::vector >& labels, + const long folds + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true && + 1 < folds && folds <= static_cast(samples.size()), + "\tmatrix cross_validate_sequence_labeler()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t folds: " << folds + << "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels) + ); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < labels.size(); ++i) + { + for (unsigned long j = 0; j < labels[i].size(); ++j) + { + // make sure requires clause is not broken + DLIB_ASSERT(labels[i][j] < trainer.num_labels(), + "\t matrix cross_validate_sequence_labeler()" + << "\n\t The labels are invalid." + << "\n\t labels[i][j]: " << labels[i][j] + << "\n\t trainer.num_labels(): " << trainer.num_labels() + << "\n\t i: " << i + << "\n\t j: " << j + ); + } + } +#endif + + + + + const long num_in_test = samples.size()/folds; + const long num_in_train = samples.size() - num_in_test; + + std::vector x_test, x_train; + std::vector > y_test, y_train; + + + long next_test_idx = 0; + + matrix res; + + + for (long i = 0; i < folds; ++i) + { + x_test.clear(); + y_test.clear(); + x_train.clear(); + y_train.clear(); + + // load up the test samples + for (long cnt = 0; cnt < num_in_test; ++cnt) + { + x_test.push_back(samples[next_test_idx]); + y_test.push_back(labels[next_test_idx]); + next_test_idx = (next_test_idx + 1)%samples.size(); + } + + // load up the training samples + long next = next_test_idx; + for (long cnt = 0; cnt < num_in_train; ++cnt) + { + x_train.push_back(samples[next]); + y_train.push_back(labels[next]); + next = (next + 1)%samples.size(); + } + + + res += test_sequence_labeler(trainer.train(x_train,y_train), x_test, y_test); + + } // for (long i = 0; i < folds; ++i) + + return res; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_ + + diff --git a/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h b/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h new file mode 100644 index 000000000..3d2409b28 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h @@ -0,0 +1,83 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_ABSTRACT_Hh_ +#ifdef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_ABSTRACT_Hh_ + +#include +#include "../matrix.h" +#include "svm.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sequence_labeler_type, + typename sequence_type + > + const matrix test_sequence_labeler ( + const sequence_labeler_type& labeler, + const std::vector& samples, + const std::vector >& labels + ); + /*! + requires + - is_sequence_labeling_problem(samples, labels) + - sequence_labeler_type == dlib::sequence_labeler or an object with a + compatible interface. + ensures + - Tests labeler against the given samples and labels and returns a confusion + matrix summarizing the results. + - The confusion matrix C returned by this function has the following properties. + - C.nc() == labeler.num_labels() + - C.nr() == labeler.num_labels() + - C(T,P) == the number of times a sequence element with label T was predicted + to have a label of P. + - Any samples with a label value >= labeler.num_labels() are ignored. That + is, samples with labels the labeler hasn't ever seen before are ignored. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename sequence_type + > + const matrix cross_validate_sequence_labeler ( + const trainer_type& trainer, + const std::vector& samples, + const std::vector >& labels, + const long folds + ); + /*! + requires + - is_sequence_labeling_problem(samples, labels) + - 1 < folds <= samples.size() + - for all valid i and j: labels[i][j] < trainer.num_labels() + - trainer_type == dlib::structural_sequence_labeling_trainer or an object + with a compatible interface. + ensures + - performs k-fold cross validation by using the given trainer to solve the + given sequence labeling problem for the given number of folds. Each fold + is tested using the output of the trainer and the confusion matrix from all + folds is summed and returned. + - The total confusion matrix is computed by running test_sequence_labeler() + on each fold and summing its output. + - The number of folds used is given by the folds argument. + - The confusion matrix C returned by this function has the following properties. + - C.nc() == trainer.num_labels() + - C.nr() == trainer.num_labels() + - C(T,P) == the number of times a sequence element with label T was predicted + to have a label of P. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/svm/cross_validate_sequence_segmenter.h b/ml/dlib/dlib/svm/cross_validate_sequence_segmenter.h new file mode 100644 index 000000000..8413f9165 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_sequence_segmenter.h @@ -0,0 +1,187 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_ +#define DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_ + +#include "cross_validate_sequence_segmenter_abstract.h" +#include "sequence_segmenter.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename sequence_segmenter_type, + typename sequence_type + > + const matrix raw_metrics_test_sequence_segmenter ( + const sequence_segmenter_type& segmenter, + const std::vector& samples, + const std::vector > >& segments + ) + { + std::vector > truth; + std::vector > pred; + + double true_hits = 0; + double total_detections = 0; + double total_true_segments = 0; + + for (unsigned long i = 0; i < samples.size(); ++i) + { + segmenter.segment_sequence(samples[i], pred); + truth = segments[i]; + // sort the segments so they will be in the same orders + std::sort(truth.begin(), truth.end()); + std::sort(pred.begin(), pred.end()); + + total_true_segments += truth.size(); + total_detections += pred.size(); + + unsigned long j=0,k=0; + while (j < pred.size() && k < truth.size()) + { + if (pred[j].first == truth[k].first && + pred[j].second == truth[k].second) + { + ++true_hits; + ++j; + ++k; + } + else if (pred[j].first < truth[k].first) + { + ++j; + } + else + { + ++k; + } + } + } + + matrix res; + res = total_detections, total_true_segments, true_hits; + return res; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sequence_segmenter_type, + typename sequence_type + > + const matrix test_sequence_segmenter ( + const sequence_segmenter_type& segmenter, + const std::vector& samples, + const std::vector > >& segments + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_sequence_segmentation_problem(samples, segments) == true, + "\tmatrix test_sequence_segmenter()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_sequence_segmentation_problem(samples, segments): " + << is_sequence_segmentation_problem(samples, segments)); + + const matrix metrics = impl::raw_metrics_test_sequence_segmenter(segmenter, samples, segments); + + const double total_detections = metrics(0); + const double total_true_segments = metrics(1); + const double true_hits = metrics(2); + + const double precision = (total_detections ==0) ? 1 : true_hits/total_detections; + const double recall = (total_true_segments==0) ? 1 : true_hits/total_true_segments; + const double f1 = (precision+recall ==0) ? 0 : 2*precision*recall/(precision+recall); + + matrix res; + res = precision, recall, f1; + return res; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename sequence_type + > + const matrix cross_validate_sequence_segmenter ( + const trainer_type& trainer, + const std::vector& samples, + const std::vector > >& segments, + const long folds + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( is_sequence_segmentation_problem(samples, segments) == true && + 1 < folds && folds <= static_cast(samples.size()), + "\tmatrix cross_validate_sequence_segmenter()" + << "\n\t invalid inputs were given to this function" + << "\n\t folds: " << folds + << "\n\t is_sequence_segmentation_problem(samples, segments): " + << is_sequence_segmentation_problem(samples, segments)); + + + const long num_in_test = samples.size()/folds; + const long num_in_train = samples.size() - num_in_test; + + std::vector x_test, x_train; + std::vector > > y_test, y_train; + + long next_test_idx = 0; + + matrix metrics; + metrics = 0; + + for (long i = 0; i < folds; ++i) + { + x_test.clear(); + y_test.clear(); + x_train.clear(); + y_train.clear(); + + // load up the test samples + for (long cnt = 0; cnt < num_in_test; ++cnt) + { + x_test.push_back(samples[next_test_idx]); + y_test.push_back(segments[next_test_idx]); + next_test_idx = (next_test_idx + 1)%samples.size(); + } + + // load up the training samples + long next = next_test_idx; + for (long cnt = 0; cnt < num_in_train; ++cnt) + { + x_train.push_back(samples[next]); + y_train.push_back(segments[next]); + next = (next + 1)%samples.size(); + } + + + metrics += impl::raw_metrics_test_sequence_segmenter(trainer.train(x_train,y_train), x_test, y_test); + } // for (long i = 0; i < folds; ++i) + + + const double total_detections = metrics(0); + const double total_true_segments = metrics(1); + const double true_hits = metrics(2); + + const double precision = (total_detections ==0) ? 1 : true_hits/total_detections; + const double recall = (total_true_segments==0) ? 1 : true_hits/total_true_segments; + const double f1 = (precision+recall ==0) ? 0 : 2*precision*recall/(precision+recall); + + matrix res; + res = precision, recall, f1; + return res; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_ + + diff --git a/ml/dlib/dlib/svm/cross_validate_sequence_segmenter_abstract.h b/ml/dlib/dlib/svm/cross_validate_sequence_segmenter_abstract.h new file mode 100644 index 000000000..87e21d592 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_sequence_segmenter_abstract.h @@ -0,0 +1,80 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_ABSTRACT_Hh_ +#ifdef DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_ABSTRACT_Hh_ + +#include "sequence_segmenter_abstract.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sequence_segmenter_type, + typename sequence_type + > + const matrix test_sequence_segmenter ( + const sequence_segmenter_type& segmenter, + const std::vector& samples, + const std::vector > >& segments + ); + /*! + requires + - is_sequence_segmentation_problem(samples, segments) == true + - sequence_segmenter_type == dlib::sequence_segmenter or an object with a + compatible interface. + ensures + - Tests segmenter against the given samples and truth segments and returns the + precision, recall, and F1-score obtained by the segmenter. That is, the goal + of the segmenter should be to predict segments[i] given samples[i] as input. + The test_sequence_segmenter() routine therefore measures how well the + segmenter is able to perform this task. + - Returns a row matrix M with the following properties: + - M(0) == The precision of the segmenter measured against the task of + detecting the segments of each sample. This is a number in the range 0 + to 1 and represents the fraction of segments output by the segmenter + which correspond to true segments for each sample. + - M(1) == The recall of the segmenter measured against the task of + detecting the segments of each sample. This is a number in the range 0 + to 1 and represents the fraction of the true segments found by the + segmenter. + - M(2) == The F1-score for the segmenter. This is the harmonic mean of + M(0) and M(1). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename sequence_type + > + const matrix cross_validate_sequence_segmenter ( + const trainer_type& trainer, + const std::vector& samples, + const std::vector > >& segments, + const long folds + ); + /*! + requires + - is_sequence_segmentation_problem(samples, segments) == true + - 1 < folds <= samples.size() + - trainer_type == dlib::structural_sequence_segmentation_trainer or an object + with a compatible interface. + ensures + - Performs k-fold cross validation by using the given trainer to solve the + given sequence segmentation problem for the given number of folds. Each fold + is tested using the output of the trainer and the results from all folds are + summarized and returned. + - This function returns the precision, recall, and F1-score for the trainer. + In particular, the output is the same as the output from the + test_sequence_segmenter() routine defined above. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/cross_validate_track_association_trainer.h b/ml/dlib/dlib/svm/cross_validate_track_association_trainer.h new file mode 100644 index 000000000..dac519b7a --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_track_association_trainer.h @@ -0,0 +1,163 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_ +#define DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_ + +#include "cross_validate_track_association_trainer_abstract.h" +#include "structural_track_association_trainer.h" + +namespace dlib +{ +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename track_association_function, + typename detection_type, + typename label_type + > + void test_track_association_function ( + const track_association_function& assoc, + const std::vector > >& samples, + unsigned long& total_dets, + unsigned long& correctly_associated_dets + ) + { + const typename track_association_function::association_function_type& f = assoc.get_assignment_function(); + + typedef typename detection_type::track_type track_type; + using namespace impl; + + dlib::rand rnd; + std::vector tracks; + std::map track_idx; // tracks[track_idx[id]] == track with ID id. + + for (unsigned long j = 0; j < samples.size(); ++j) + { + std::vector > dets = samples[j]; + // Shuffle the order of the detections so we can be sure that there isn't + // anything funny going on like the detections always coming in the same + // order relative to their labels and the association function just gets + // lucky by picking the same assignment ordering every time. So this way + // we know the assignment function really is doing something rather than + // just being lucky. + randomize_samples(dets, rnd); + + total_dets += dets.size(); + std::vector assignments = f(get_unlabeled_dets(dets), tracks); + std::vector updated_track(tracks.size(), false); + // now update all the tracks with the detections that associated to them. + for (unsigned long k = 0; k < assignments.size(); ++k) + { + // If the detection is associated to tracks[assignments[k]] + if (assignments[k] != -1) + { + tracks[assignments[k]].update_track(dets[k].det); + updated_track[assignments[k]] = true; + + // if this detection was supposed to go to this track + if (track_idx.count(dets[k].label) && track_idx[dets[k].label]==assignments[k]) + ++correctly_associated_dets; + + track_idx[dets[k].label] = assignments[k]; + } + else + { + track_type new_track; + new_track.update_track(dets[k].det); + tracks.push_back(new_track); + + // if this detection was supposed to go to a new track + if (track_idx.count(dets[k].label) == 0) + ++correctly_associated_dets; + + track_idx[dets[k].label] = tracks.size()-1; + } + } + + // Now propagate all the tracks that didn't get any detections. + for (unsigned long k = 0; k < updated_track.size(); ++k) + { + if (!updated_track[k]) + tracks[k].propagate_track(); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename track_association_function, + typename detection_type, + typename label_type + > + double test_track_association_function ( + const track_association_function& assoc, + const std::vector > > >& samples + ) + { + unsigned long total_dets = 0; + unsigned long correctly_associated_dets = 0; + + for (unsigned long i = 0; i < samples.size(); ++i) + { + impl::test_track_association_function(assoc, samples[i], total_dets, correctly_associated_dets); + } + + return (double)correctly_associated_dets/(double)total_dets; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename detection_type, + typename label_type + > + double cross_validate_track_association_trainer ( + const trainer_type& trainer, + const std::vector > > >& samples, + const long folds + ) + { + const long num_in_test = samples.size()/folds; + const long num_in_train = samples.size() - num_in_test; + + std::vector > > > samples_train; + + long next_test_idx = 0; + unsigned long total_dets = 0; + unsigned long correctly_associated_dets = 0; + + for (long i = 0; i < folds; ++i) + { + samples_train.clear(); + + // load up the training samples + long next = (next_test_idx + num_in_test)%samples.size(); + for (long cnt = 0; cnt < num_in_train; ++cnt) + { + samples_train.push_back(samples[next]); + next = (next + 1)%samples.size(); + } + + const track_association_function& df = trainer.train(samples_train); + for (long cnt = 0; cnt < num_in_test; ++cnt) + { + impl::test_track_association_function(df, samples[next_test_idx], total_dets, correctly_associated_dets); + next_test_idx = (next_test_idx + 1)%samples.size(); + } + } + + return (double)correctly_associated_dets/(double)total_dets; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_Hh_ + + diff --git a/ml/dlib/dlib/svm/cross_validate_track_association_trainer_abstract.h b/ml/dlib/dlib/svm/cross_validate_track_association_trainer_abstract.h new file mode 100644 index 000000000..76b985600 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_track_association_trainer_abstract.h @@ -0,0 +1,69 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_ABSTRACT_Hh_ +#ifdef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_ABSTRACT_Hh_ + +#include "structural_track_association_trainer_abstract.h" +#include "svm_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename track_association_function, + typename detection_type, + typename label_type + > + double test_track_association_function ( + const track_association_function& assoc, + const std::vector > > >& samples + ); + /*! + requires + - is_track_association_problem(samples) + - track_association_function == an instantiation of the dlib::track_association_function + template or an object with a compatible interface. + ensures + - Tests assoc against the given samples and returns the fraction of detections + which were correctly associated to their tracks. That is, if assoc produces + perfect tracks when used then this function returns a value of 1. Similarly, + if 5% of the detections were associated to the incorrect track then the + return value is 0.05. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename detection_type, + typename label_type + > + double cross_validate_track_association_trainer ( + const trainer_type& trainer, + const std::vector > > >& samples, + const long folds + ); + /*! + requires + - is_track_association_problem(samples) + - 1 < folds <= samples.size() + - trainer_type == dlib::structural_track_association_trainer or an object with + a compatible interface. + ensures + - Performs k-fold cross validation by using the given trainer to solve the + given track association learning problem for the given number of folds. Each + fold is tested using the output of the trainer and the fraction of + mis-associated detections is returned (i.e. this function returns the same + measure of track association quality as test_track_association_function()). + - The number of folds used is given by the folds argument. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/empirical_kernel_map.h b/ml/dlib/dlib/svm/empirical_kernel_map.h new file mode 100644 index 000000000..7a91e591a --- /dev/null +++ b/ml/dlib/dlib/svm/empirical_kernel_map.h @@ -0,0 +1,429 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_EMPIRICAL_KERNEl_MAP_H_ +#define DLIB_EMPIRICAL_KERNEl_MAP_H_ + +#include "../matrix.h" +#include "empirical_kernel_map_abstract.h" +#include "linearly_independent_subset_finder.h" +#include +#include "../algs.h" +#include "kernel_matrix.h" +#include "function.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + const decision_function convert_to_decision_function ( + const projection_function& project_funct, + const matrix_exp& vect + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(project_funct.out_vector_size() > 0 && is_vector(vect) && + project_funct.out_vector_size() == vect.size() && project_funct.weights.nc() == project_funct.basis_vectors.size(), + "\t const decision_function convert_to_decision_function()" + << "\n\t Invalid inputs to this function." + << "\n\t project_funct.out_vector_size(): " << project_funct.out_vector_size() + << "\n\t project_funct.weights.nc(): " << project_funct.weights.nc() + << "\n\t project_funct.basis_vectors.size(): " << project_funct.basis_vectors.size() + << "\n\t is_vector(vect): " << is_vector(vect) + << "\n\t vect.size(): " << vect.size() + ); + + return decision_function(trans(project_funct.weights)*vect, + 0, + project_funct.kernel_function, + project_funct.basis_vectors); + } + +// ---------------------------------------------------------------------------------------- + + template + class empirical_kernel_map + { + public: + + struct empirical_kernel_map_error : public error + { + empirical_kernel_map_error(const std::string& message): error(message) {} + }; + + typedef kern_type kernel_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + void clear ( + ) + { + empirical_kernel_map().swap(*this); + } + + template + void load( + const kernel_type& kernel_, + const T& basis_samples + ) + { + load_impl(kernel_, mat(basis_samples)); + } + + void load( + const linearly_independent_subset_finder& lisf + ) + { + if (lisf.size() == 0) + { + std::ostringstream sout; + sout << "An empty linearly_independent_subset_finder was supplied to the\n" + << "empirical_kernel_map::load() function. One reason this might occur\n" + << "is if your dataset contains only zero vectors (or vectors \n" + << "approximately zero).\n"; + clear(); + throw empirical_kernel_map_error(sout.str()); + } + + kernel = lisf.get_kernel(); + weights = trans(chol(lisf.get_inv_kernel_marix())); + basis.resize(lisf.size()); + for (unsigned long i = 0; i < basis.size(); ++i) + basis[i] = lisf[i]; + + } + + const kernel_type get_kernel ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(out_vector_size() > 0, + "\tconst kernel_type empirical_kernel_map::get_kernel()" + << "\n\t You have to load this object with a kernel before you can call this function" + << "\n\t this: " << this + ); + + return kernel; + } + + long out_vector_size ( + ) const + { + return weights.nr(); + } + + unsigned long basis_size ( + ) const + { + return basis.size(); + } + + const sample_type& operator[] ( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( idx < basis_size(), + "\t const sample_type& empirical_kernel_map::operator[](idx)" + << "\n\t Invalid inputs to this function." + << "\n\t basis_size(): " << basis_size() + << "\n\t this: " << this + ); + + return basis[idx]; + } + + template + const decision_function convert_to_decision_function ( + const matrix_exp& vect + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(out_vector_size() != 0 && is_vector(vect) && out_vector_size() == vect.size(), + "\t const decision_function empirical_kernel_map::convert_to_decision_function()" + << "\n\t Invalid inputs to this function." + << "\n\t out_vector_size(): " << out_vector_size() + << "\n\t is_vector(vect): " << is_vector(vect) + << "\n\t vect.size(): " << vect.size() + << "\n\t this: " << this + ); + + return decision_function(trans(weights)*vect, 0, kernel, mat(basis)); + } + + template + const distance_function convert_to_distance_function ( + const matrix_exp& vect + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(out_vector_size() != 0 && is_vector(vect) && out_vector_size() == vect.size(), + "\t const distance_function empirical_kernel_map::convert_to_distance_function()" + << "\n\t Invalid inputs to this function." + << "\n\t out_vector_size(): " << out_vector_size() + << "\n\t is_vector(vect): " << is_vector(vect) + << "\n\t vect.size(): " << vect.size() + << "\n\t this: " << this + ); + + return distance_function(trans(weights)*vect, dot(vect,vect), kernel, mat(basis)); + } + + const projection_function get_projection_function ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(out_vector_size() != 0, + "\tconst projection_function empirical_kernel_map::get_projection_function()" + << "\n\t You have to load this object with data before you can call this function" + << "\n\t this: " << this + ); + + return projection_function(weights, kernel, mat(basis)); + } + + const matrix get_transformation_to ( + const empirical_kernel_map& target + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(out_vector_size() != 0 && + target.out_vector_size() != 0 && + get_kernel() == target.get_kernel(), + "\t const matrix empirical_kernel_map::get_transformation_to(target)" + << "\n\t Invalid inputs were given to this function" + << "\n\t out_vector_size(): " << out_vector_size() + << "\n\t target.out_vector_size(): " << target.out_vector_size() + << "\n\t get_kernel()==target.get_kernel(): " << (get_kernel()==target.get_kernel()) + << "\n\t this: " << this + ); + + return target.weights * kernel_matrix(target.get_kernel(),target.basis, basis)*trans(weights); + } + + void get_transformation_to ( + const empirical_kernel_map& target, + matrix& tmat, + projection_function& partial_projection + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(out_vector_size() != 0 && + target.out_vector_size() != 0 && + get_kernel() == target.get_kernel() && + basis_size() < target.basis_size(), + "\t void empirical_kernel_map::get_transformation_to(target, tmat, partial_projection)" + << "\n\t Invalid inputs were given to this function" + << "\n\t out_vector_size(): " << out_vector_size() + << "\n\t target.out_vector_size(): " << target.out_vector_size() + << "\n\t basis_size(): " << basis_size() + << "\n\t target.basis_size(): " << target.basis_size() + << "\n\t get_kernel()==target.get_kernel(): " << (get_kernel()==target.get_kernel()) + << "\n\t this: " << this + ); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < basis_size(); ++i) + { + DLIB_ASSERT(dlib::equal((*this)[i], target[i]), + "\t const matrix empirical_kernel_map::get_transformation_to(target, tmat, partial_projection)" + << "\n\t target must contain a superset of the basis vectors in *this" + << "\n\t i: " << i + << "\n\t this: " << this + ); + } +#endif + + const unsigned long num1 = basis.size(); + const unsigned long num2 = target.basis.size(); + + tmat = colm(target.weights, range(0,num1-1))*kernel_matrix(kernel, basis)*trans(weights); + + empirical_kernel_map temp_ekm; + temp_ekm.load(kernel, rowm(mat(target.basis), range(num1,num2-1))); + + partial_projection = temp_ekm.get_projection_function(); + + partial_projection.weights = colm(target.weights,range(num1,num2-1))* + kernel_matrix(kernel, temp_ekm.basis)* + trans(temp_ekm.weights)* + partial_projection.weights; + } + + const matrix& project ( + const sample_type& samp + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(out_vector_size() != 0, + "\tconst matrix empirical_kernel_map::project()" + << "\n\t You have to load this object with data before you can call this function" + << "\n\t this: " << this + ); + + temp1 = kernel_matrix(kernel, basis, samp); + temp2 = weights*temp1; + return temp2; + } + + const matrix& project ( + const sample_type& samp, + scalar_type& projection_error + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(out_vector_size() != 0, + "\tconst matrix empirical_kernel_map::project()" + << "\n\t You have to load this object with data before you can call this function" + << "\n\t this: " << this + ); + + temp1 = kernel_matrix(kernel, basis, samp); + temp2 = weights*temp1; + // This value should never be negative (it measures squared distance) but I'm putting the abs() + // here just for good measure since rounding error might push it slightly negative. + projection_error = std::abs( kernel(samp,samp) - dot(temp2,temp2)); + + return temp2; + } + + void swap ( + empirical_kernel_map& item + ) + { + basis.swap(item.basis); + weights.swap(item.weights); + std::swap(kernel, item.kernel); + + temp1.swap(item.temp1); + temp2.swap(item.temp2); + } + + friend void serialize ( + const empirical_kernel_map& item, + std::ostream& out + ) + { + serialize(item.basis, out); + serialize(item.weights, out); + serialize(item.kernel, out); + } + + friend void deserialize ( + empirical_kernel_map& item, + std::istream& in + ) + { + deserialize(item.basis, in); + deserialize(item.weights, in); + deserialize(item.kernel, in); + } + + private: + + template + void load_impl( + const kernel_type& kernel_, + const T& basis_samples + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(basis_samples.size() > 0 && is_vector(basis_samples), + "\tvoid empirical_kernel_map::load(kernel,basis_samples)" + << "\n\t You have to give a non-empty set of basis_samples and it must be a vector" + << "\n\t basis_samples.size(): " << basis_samples.size() + << "\n\t is_vector(basis_samples): " << is_vector(basis_samples) + << "\n\t this: " << this + ); + + // clear out the weights before we begin. This way if an exception throws + // this object will already be in the right state. + weights.set_size(0,0); + kernel = kernel_; + basis.clear(); + basis.reserve(basis_samples.size()); + + // find out the value of the largest norm of the elements in basis_samples. + const scalar_type max_norm = max(diag(kernel_matrix(kernel, basis_samples))); + // we will consider anything less than or equal to this number to be 0 + const scalar_type eps = max_norm*100*std::numeric_limits::epsilon(); + + // Copy all the basis_samples into basis but make sure we don't copy any samples + // that have length 0 + for (long i = 0; i < basis_samples.size(); ++i) + { + const scalar_type norm = kernel(basis_samples(i), basis_samples(i)); + if (norm > eps) + { + basis.push_back(basis_samples(i)); + } + } + + if (basis.size() == 0) + { + clear(); + throw empirical_kernel_map_error("All basis_samples given to empirical_kernel_map::load() were zero vectors"); + } + + matrix K(kernel_matrix(kernel, basis)), U,W,V; + + if (svd2(false,true,K,U,W,V)) + { + clear(); + throw empirical_kernel_map_error("While loading empirical_kernel_map with data, SVD failed to converge."); + } + + + // now count how many elements of W are non-zero + const long num_not_zero = static_cast(sum(W>eps)); + + // Really, this should never happen. But I'm checking for good measure. + if (num_not_zero == 0) + { + clear(); + throw empirical_kernel_map_error("While loading empirical_kernel_map with data, SVD failed"); + } + + weights.set_size(num_not_zero, basis.size()); + + // now fill the weights matrix with the output of the SVD + long counter = 0; + for (long i =0; i < W.size(); ++i) + { + double val = W(i); + if (val > eps) + { + val = std::sqrt(val); + set_rowm(weights,counter) = rowm(trans(V),i)/val; + ++counter; + } + } + + } + + + std::vector basis; + matrix weights; + kernel_type kernel; + + // These members don't contribute to the logical state of this object. They are + // just here so that they don't have to be reallocated every time the project() function + // is called. + mutable matrix temp1, temp2; + + }; + + template + void swap ( + empirical_kernel_map& a, + empirical_kernel_map& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_EMPIRICAL_KERNEl_MAP_H_ + diff --git a/ml/dlib/dlib/svm/empirical_kernel_map_abstract.h b/ml/dlib/dlib/svm/empirical_kernel_map_abstract.h new file mode 100644 index 000000000..8fc413447 --- /dev/null +++ b/ml/dlib/dlib/svm/empirical_kernel_map_abstract.h @@ -0,0 +1,430 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_EMPIRICAL_KERNEl_MAP_ABSTRACT_H_ +#ifdef DLIB_EMPIRICAL_KERNEl_MAP_ABSTRACT_H_ + +#include +#include "../matrix.h" +#include "kernel_abstract.h" +#include "function_abstract.h" +#include "linearly_independent_subset_finder_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename EXP + > + const decision_function convert_to_decision_function ( + const projection_function& project_funct, + const matrix_exp& vect + ); + /*! + requires + - is_vector(vect) == true + - vect.size() == project_funct.out_vector_size() + - project_funct.out_vector_size() > 0 + - project_funct.weights.nc() == project_funct.basis_vectors.size() + ensures + - This function interprets the given vector as a point in the kernel feature space defined + by the given projection function. The return value of this function is a decision + function, DF, that represents the given vector in the following sense: + - for all possible sample_type objects, S, it is the case that DF(S) == dot(project_funct(S), vect) + (i.e. the returned decision function computes dot products, in kernel feature space, + between vect and any argument you give it. Note also that this equality is exact, even + for sample_type objects not in the span of the basis_vectors.) + - DF.kernel_function == project_funct.kernel_function + - DF.b == 0 + - DF.basis_vectors == project_funct.basis_vectors. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename kern_type + > + class empirical_kernel_map + { + /*! + REQUIREMENTS ON kern_type + - must be a kernel function object as defined in dlib/svm/kernel_abstract.h + + INITIAL VALUE + - out_vector_size() == 0 + - basis_size() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a map from objects of sample_type (the kind of object + a kernel function operates on) to finite dimensional column vectors which + represent points in the kernel feature space defined by whatever kernel + is used with this object. + + To use the empirical_kernel_map you supply it with a particular kernel and a set of + basis samples. After that you can present it with new samples and it will project + them into the part of kernel feature space spanned by your basis samples. + + This means the empirical_kernel_map is a tool you can use to very easily kernelize + any algorithm that operates on column vectors. All you have to do is select a + set of basis samples and then use the empirical_kernel_map to project all your + data points into the part of kernel feature space spanned by those basis samples. + Then just run your normal algorithm on the output vectors and it will be effectively + kernelized. + + Regarding methods to select a set of basis samples, if you are working with only a + few thousand samples then you can just use all of them as basis samples. + Alternatively, the linearly_independent_subset_finder often works well for + selecting a basis set. I also find that picking a random subset typically works + well. + + + The empirical kernel map is something that has been around in the kernel methods + literature for a long time but is seemingly not well known. Anyway, one of the + best books on the subject is the following: + Learning with Kernels: Support Vector Machines, Regularization, Optimization, + and Beyond by Bernhard Schlkopf, Alexander J. Smola + The authors discuss the empirical kernel map as well as many other interesting + topics. + !*/ + + public: + + typedef kern_type kernel_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + struct empirical_kernel_map_error : public error; + /*! + This is an exception class used to indicate a failure to create a + kernel map from data given by the user. + !*/ + + empirical_kernel_map ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear ( + ); + /*! + ensures + - this object has its initial value + !*/ + + template + void load( + const kernel_type& kernel, + const T& basis_samples + ); + /*! + requires + - T must be a dlib::matrix type or something convertible to a matrix via mat() + (e.g. a std::vector) + - is_vector(basis_samples) == true + - basis_samples.size() > 0 + - kernel must be capable of operating on the elements of basis_samples. That is, + expressions such as kernel(basis_samples(0), basis_samples(0)) should make sense. + ensures + - 0 < #out_vector_size() <= basis_samples.size() + - #basis_size() == basis_samples.size() + - #get_kernel() == kernel + - This function constructs a map between normal sample_type objects and the + subspace of the kernel feature space defined by the given kernel and the + given set of basis samples. So after this function has been called you + will be able to project sample_type objects into kernel feature space + and obtain the resulting vector as a regular column matrix. + - The basis samples are loaded into this object in the order in which they + are stored in basis_samples. That is: + - for all valid i: (*this)[i] == basis_samples(i) + throws + - empirical_kernel_map_error + This exception is thrown if we are unable to create a kernel map. + If this happens then this object will revert back to its initial value. + !*/ + + void load( + const linearly_independent_subset_finder& lisf + ); + /*! + ensures + - #out_vector_size() == lisf.dictionary_size() + - #basis_size() == lisf.dictionary_size() + - #get_kernel() == lisf.get_kernel() + - Uses the dictionary vectors from lisf as a basis set. Thus, this function + constructs a map between normal sample_type objects and the subspace of + the kernel feature space defined by the given kernel and the given set + of basis samples. So after this function has been called you will be + able to project sample_type objects into kernel feature space and obtain + the resulting vector as a regular column matrix. + - The basis samples are loaded into this object in the order in which they + are stored in lisf. That is: + - for all valid i: (*this)[i] == lisf[i] + throws + - empirical_kernel_map_error + This exception is thrown if we are unable to create a kernel map. + E.g. if the lisf.size() == 0. + If this happens then this object will revert back to its initial value. + !*/ + + const kernel_type get_kernel ( + ) const; + /*! + requires + - out_vector_size() != 0 + ensures + - returns a copy of the kernel used by this object + !*/ + + long out_vector_size ( + ) const; + /*! + ensures + - if (this object has been loaded with basis samples) then + - returns the dimensionality of the vectors output by the project() function. + - else + - returns 0 + !*/ + + unsigned long basis_size ( + ) const; + /*! + ensures + - returns the number of basis vectors in projection_functions created + by this object. This is also equal to the number of basis vectors + given to the load() function. + !*/ + + const sample_type& operator[] ( + unsigned long idx + ) const; + /*! + requires + - idx < basis_size() + ensures + - returns a const reference to the idx'th basis vector contained inside + this object. + !*/ + + const matrix& project ( + const sample_type& sample + ) const; + /*! + requires + - out_vector_size() != 0 + ensures + - takes the given sample and projects it into the kernel feature space + of out_vector_size() dimensions defined by this kernel map and + returns the resulting vector. + - in more precise terms, this function returns a vector such that: + - The returned vector will contain out_vector_size() elements. + - for any sample_type object S, the following equality is approximately true: + - get_kernel()(sample,S) == dot(project(sample), project(S)). + - The approximation error in the above equality will be zero (within rounding error) + if both sample_type objects involved are within the span of the set of basis + samples given to the load() function. If they are not then there will be some + approximation error. Note that all the basis samples are always within their + own span. So the equality is always exact for the samples given to the load() + function. + !*/ + + const matrix& project ( + const sample_type& samp, + scalar_type& projection_error + ) const; + /*! + requires + - out_vector_size() != 0 + ensures + - This function returns project(samp) + (i.e. it returns the same thing as the above project() function) + - #projection_error == the square of the distance between the point samp + gets projected onto and samp's true image in kernel feature space. + That is, this value is equal to: + pow(convert_to_distance_function(project(samp))(samp),2) + !*/ + + template + const decision_function convert_to_decision_function ( + const matrix_exp& vect + ) const; + /*! + requires + - is_vector(vect) == true + - vect.size() == out_vector_size() + - out_vector_size() != 0 + ensures + - This function interprets the given vector as a point in the kernel feature space defined + by this empirical_kernel_map. The return value of this function is a decision + function, DF, that represents the given vector in the following sense: + - for all possible sample_type objects, S, it is the case that DF(S) == dot(project(S), vect) + (i.e. the returned decision function computes dot products, in kernel feature space, + between vect and any argument you give it. Note also that this equality is exact, even + for sample_type objects not in the span of the basis samples.) + - DF.kernel_function == get_kernel() + - DF.b == 0 + - DF.basis_vectors == these will be the basis samples given to the previous call to load(). Note + that it is possible for there to be fewer basis_vectors than basis samples given to load(). + - DF.basis_vectors.size() == basis_size() + !*/ + + template + const distance_function convert_to_distance_function ( + const matrix_exp& vect + ) const + /*! + requires + - is_vector(vect) == true + - vect.size() == out_vector_size() + - out_vector_size() != 0 + ensures + - This function interprets the given vector as a point in the kernel feature space defined + by this empirical_kernel_map. The return value of this function is a distance + function, DF, that represents the given vector in the following sense: + - for any sample_type object S, the following equality is approximately true: + - DF(S) == length(project(S) - vect) + (i.e. the returned distance function computes distances, in kernel feature space, + between vect and any argument you give it. ) + - The approximation error in the above equality will be zero (within rounding error) + if S is within the span of the set of basis samples given to the load() function. + If it is not then there will be some approximation error. Note that all the basis + samples are always within their own span. So the equality is always exact for the + samples given to the load() function. Note further that the distance computed + by DF(S) is always the correct distance in kernel feature space between vect and + the true projection of S. That is, the above equality is approximate only because + of potential error in the project() function, not in DF(S). + - DF.kernel_function == get_kernel() + - DF.b == dot(vect,vect) + - DF.basis_vectors == these will be the basis samples given to the previous call to load(). Note + that it is possible for there to be fewer basis_vectors than basis samples given to load(). + - DF.basis_vectors.size() == basis_size() + !*/ + + const projection_function get_projection_function ( + ) const; + /*! + requires + - out_vector_size() != 0 + ensures + - returns a projection_function, PF, that computes the same projection as project(). + That is, calling PF() on any sample will produce the same output vector as calling + this->project() on that sample. + - PF.basis_vectors.size() == basis_size() + !*/ + + const matrix get_transformation_to ( + const empirical_kernel_map& target + ) const; + /*! + requires + - get_kernel() == target.get_kernel() + - out_vector_size() != 0 + - target.out_vector_size() != 0 + ensures + - A point in the kernel feature space defined by the kernel get_kernel() typically + has different representations with respect to different empirical_kernel_maps. + This function lets you obtain a transformation matrix that will allow you + to project between these different representations. That is, this function returns + a matrix M with the following properties: + - M maps vectors represented according to *this into the representation used by target. + - M.nr() == target.out_vector_size() + - M.nc() == this->out_vector_size() + - Let V be a vector of this->out_vector_size() length. Then define two distance_functions + DF1 = this->convert_to_distance_function(V) + DF2 = target.convert_to_distance_function(M*V) + + Then DF1(DF2) == 0 // i.e. the distance between these two points should be 0 + + That is, DF1 and DF2 both represent the same point in kernel feature space. Note + that the above equality is only approximate. If the vector V represents a point in + kernel space that isn't in the span of the basis samples used by target then the + equality is approximate. However, if it is in their span then the equality will + be exact. For example, if target's basis samples are a superset of the basis samples + used by *this then the equality will always be exact (within rounding error). + !*/ + + void get_transformation_to ( + const empirical_kernel_map& target, + matrix& tmat, + projection_function& partial_projection + ) const; + /*! + requires + - get_kernel() == target.get_kernel() + - out_vector_size() != 0 + - target.out_vector_size() != 0 + - basis_size() < target.basis_size() + - for all i < basis_size(): (*this)[i] == target[i] + i.e. target must contain a superset of the basis vectors contained in *this. Moreover, + it must contain them in the same order. + ensures + - The single argument version of get_transformation_to() allows you to project + vectors from one empirical_kernel_map representation to another. This version + provides a somewhat different capability. Assuming target's basis vectors form a + superset of *this's basis vectors then this form of get_transformation_to() allows + you to reuse a vector from *this ekm to speed up the projection performed by target. + The defining relation is given below. + - for any sample S: + - target.project(S) == #tmat * this->project(S) + #partial_projection(S) + (this is always true to within rounding error for any S) + - #partial_projection.basis_vectors.size() == target.basis_vectors.size() - this->basis_vectors.size() + - #tmat.nr() == target.out_vector_size() + - #tmat.nc() == this->out_vector_size() + !*/ + + void swap ( + empirical_kernel_map& item + ); + /*! + ensures + - swaps the state of *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type + > + void swap ( + empirical_kernel_map& a, + empirical_kernel_map& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename kernel_type + > + void serialize ( + const empirical_kernel_map& item, + std::ostream& out + ); + /*! + provides serialization support for empirical_kernel_map objects + !*/ + + template < + typename kernel_type + > + void deserialize ( + empirical_kernel_map& item, + std::istream& in + ); + /*! + provides serialization support for empirical_kernel_map objects + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_EMPIRICAL_KERNEl_MAP_ABSTRACT_H_ + diff --git a/ml/dlib/dlib/svm/feature_ranking.h b/ml/dlib/dlib/svm/feature_ranking.h new file mode 100644 index 000000000..f6324fe3d --- /dev/null +++ b/ml/dlib/dlib/svm/feature_ranking.h @@ -0,0 +1,477 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_KERNEL_FEATURE_RANKINg_H_ +#define DLIB_KERNEL_FEATURE_RANKINg_H_ + +#include +#include + +#include "feature_ranking_abstract.h" +#include "kcentroid.h" +#include "../optimization.h" +#include "../statistics.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename sample_matrix_type, + typename label_matrix_type + > + matrix rank_features_impl ( + const kcentroid& kc, + const sample_matrix_type& samples, + const label_matrix_type& labels + ) + { + /* + This function ranks features by doing recursive feature elimination + + */ + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::mem_manager_type mm; + + + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(samples, labels) == true, + "\tmatrix rank_features()" + << "\n\t you have given invalid arguments to this function" + ); + + matrix results(samples(0).nr(), 2); + matrix mask(samples(0).nr()); + set_all_elements(mask,1); + + // figure out what the separation is between the two centroids when all the features are + // present. + scalar_type first_separation; + { + kcentroid c1(kc); + kcentroid c2(kc); + // find the centers of each class + for (long s = 0; s < samples.size(); ++s) + { + if (labels(s) < 0) + { + c1.train(samples(s)); + } + else + { + c2.train(samples(s)); + } + + } + first_separation = c1(c2); + } + + + using namespace std; + + for (long i = results.nr()-1; i >= 0; --i) + { + long worst_feature_idx = 0; + scalar_type worst_feature_score = -std::numeric_limits::infinity(); + + // figure out which feature to remove next + for (long j = 0; j < mask.size(); ++j) + { + // skip features we have already removed + if (mask(j) == 0) + continue; + + kcentroid c1(kc); + kcentroid c2(kc); + + // temporarily remove this feature from the working set of features + mask(j) = 0; + + // find the centers of each class + for (long s = 0; s < samples.size(); ++s) + { + if (labels(s) < 0) + { + c1.train(pointwise_multiply(samples(s),mask)); + } + else + { + c2.train(pointwise_multiply(samples(s),mask)); + } + + } + + // find the distance between the two centroids and use that + // as the score + const double score = c1(c2); + + if (score > worst_feature_score) + { + worst_feature_score = score; + worst_feature_idx = j; + } + + // add this feature back to the working set of features + mask(j) = 1; + + } + + // now that we know what the next worst feature is record it + mask(worst_feature_idx) = 0; + results(i,0) = worst_feature_idx; + results(i,1) = worst_feature_score; + } + + // now normalize the results + const scalar_type max_separation = std::max(max(colm(results,1)), first_separation); + set_colm(results,1) = colm(results,1)/max_separation; + for (long r = 0; r < results.nr()-1; ++r) + { + results(r,1) = results(r+1,1); + } + results(results.nr()-1,1) = first_separation/max_separation; + + return results; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename sample_matrix_type, + typename label_matrix_type + > + matrix rank_features ( + const kcentroid& kc, + const sample_matrix_type& samples, + const label_matrix_type& labels + ) + { + return rank_features_impl(kc, mat(samples), mat(labels)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename sample_matrix_type, + typename label_matrix_type + > + matrix rank_features_impl ( + const kcentroid& kc, + const sample_matrix_type& samples, + const label_matrix_type& labels, + const long num_features + ) + { + /* + This function ranks features by doing recursive feature addition + + */ + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::mem_manager_type mm; + + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(samples, labels) == true, + "\tmatrix rank_features()" + << "\n\t you have given invalid arguments to this function" + ); + DLIB_ASSERT(0 < num_features && num_features <= samples(0).nr(), + "\tmatrix rank_features()" + << "\n\t you have given invalid arguments to this function" + << "\n\t num_features: " << num_features + << "\n\t samples(0).nr(): " << samples(0).nr() + ); + + matrix results(num_features, 2); + matrix mask(samples(0).nr()); + set_all_elements(mask,0); + + using namespace std; + + for (long i = 0; i < results.nr(); ++i) + { + long best_feature_idx = 0; + scalar_type best_feature_score = -std::numeric_limits::infinity(); + + // figure out which feature to add next + for (long j = 0; j < mask.size(); ++j) + { + // skip features we have already added + if (mask(j) == 1) + continue; + + kcentroid c1(kc); + kcentroid c2(kc); + + // temporarily add this feature to the working set of features + mask(j) = 1; + + // find the centers of each class + for (long s = 0; s < samples.size(); ++s) + { + if (labels(s) < 0) + { + c1.train(pointwise_multiply(samples(s),mask)); + } + else + { + c2.train(pointwise_multiply(samples(s),mask)); + } + + } + + // find the distance between the two centroids and use that + // as the score + const double score = c1(c2); + + if (score > best_feature_score) + { + best_feature_score = score; + best_feature_idx = j; + } + + // take this feature back out of the working set of features + mask(j) = 0; + + } + + // now that we know what the next best feature is record it + mask(best_feature_idx) = 1; + results(i,0) = best_feature_idx; + results(i,1) = best_feature_score; + } + + // now normalize the results + set_colm(results,1) = colm(results,1)/max(colm(results,1)); + + return results; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename sample_matrix_type, + typename label_matrix_type + > + matrix rank_features ( + const kcentroid& kc, + const sample_matrix_type& samples, + const label_matrix_type& labels, + const long num_features + ) + { + if (mat(samples).nr() > 0 && num_features == mat(samples)(0).nr()) + { + // if we are going to rank them all then might as well do the recursive feature elimination version + return rank_features_impl(kc, mat(samples), mat(labels)); + } + else + { + return rank_features_impl(kc, mat(samples), mat(labels), num_features); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace rank_features_helpers + { + template < + typename K, + typename sample_matrix_type, + typename label_matrix_type + > + typename K::scalar_type centroid_gap ( + const kcentroid& kc, + const sample_matrix_type& samples, + const label_matrix_type& labels + ) + { + kcentroid kc1(kc); + kcentroid kc2(kc); + + // toss all the samples into our kcentroids + for (long i = 0; i < samples.size(); ++i) + { + if (labels(i) > 0) + kc1.train(samples(i)); + else + kc2.train(samples(i)); + } + + // now return the separation between the mean of these two centroids + return kc1(kc2); + } + + template < + typename sample_matrix_type, + typename label_matrix_type + > + class test + { + typedef typename sample_matrix_type::type sample_type; + typedef typename sample_type::type scalar_type; + typedef typename sample_type::mem_manager_type mem_manager_type; + + public: + test ( + const sample_matrix_type& samples_, + const label_matrix_type& labels_, + unsigned long num_sv_, + bool verbose_ + ) : samples(samples_), labels(labels_), num_sv(num_sv_), verbose(verbose_) + { + } + + double operator() ( + double gamma + ) const + { + using namespace std; + + // we are doing the optimization in log space so don't forget to convert back to normal space + gamma = std::exp(gamma); + + typedef radial_basis_kernel kernel_type; + // Make a kcentroid and find out what the gap is at the current gamma. Try to pick a reasonable + // tolerance. + const double tolerance = std::min(gamma*0.01, 0.01); + const kernel_type kern(gamma); + kcentroid kc(kern, tolerance, num_sv); + scalar_type temp = centroid_gap(kc, samples, labels); + + if (verbose) + { + cout << "\rChecking goodness of gamma = " << gamma << ". Goodness = " + << temp << " " << flush; + } + return temp; + } + + const sample_matrix_type& samples; + const label_matrix_type& labels; + unsigned long num_sv; + bool verbose; + + }; + + template < + typename sample_matrix_type, + typename label_matrix_type + > + double find_gamma_with_big_centroid_gap_impl ( + const sample_matrix_type& samples, + const label_matrix_type& labels, + double initial_gamma, + unsigned long num_sv, + bool verbose + ) + { + using namespace std; + + if (verbose) + { + cout << endl; + } + + test funct(samples, labels, num_sv, verbose); + double best_gamma = std::log(initial_gamma); + double goodness = find_max_single_variable(funct, best_gamma, -15, 15, 1e-3, 100); + + if (verbose) + { + cout << "\rBest gamma = " << std::exp(best_gamma) << ". Goodness = " + << goodness << " " << endl; + } + + return std::exp(best_gamma); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_matrix_type, + typename label_matrix_type + > + double find_gamma_with_big_centroid_gap ( + const sample_matrix_type& samples, + const label_matrix_type& labels, + double initial_gamma = 0.1, + unsigned long num_sv = 40 + ) + { + DLIB_ASSERT(initial_gamma > 0 && num_sv > 0 && is_binary_classification_problem(samples, labels), + "\t double find_gamma_with_big_centroid_gap()" + << "\n\t initial_gamma: " << initial_gamma + << "\n\t num_sv: " << num_sv + << "\n\t is_binary_classification_problem(): " << is_binary_classification_problem(samples, labels) + ); + + return rank_features_helpers::find_gamma_with_big_centroid_gap_impl(mat(samples), + mat(labels), + initial_gamma, + num_sv, + false); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_matrix_type, + typename label_matrix_type + > + double verbose_find_gamma_with_big_centroid_gap ( + const sample_matrix_type& samples, + const label_matrix_type& labels, + double initial_gamma = 0.1, + unsigned long num_sv = 40 + ) + { + DLIB_ASSERT(initial_gamma > 0 && num_sv > 0 && is_binary_classification_problem(samples, labels), + "\t double verbose_find_gamma_with_big_centroid_gap()" + << "\n\t initial_gamma: " << initial_gamma + << "\n\t num_sv: " << num_sv + << "\n\t is_binary_classification_problem(): " << is_binary_classification_problem(samples, labels) + ); + + return rank_features_helpers::find_gamma_with_big_centroid_gap_impl(mat(samples), + mat(labels), + initial_gamma, + num_sv, + true); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + double compute_mean_squared_distance ( + const vector_type& samples + ) + { + running_stats rs; + for (unsigned long i = 0; i < samples.size(); ++i) + { + for (unsigned long j = i+1; j < samples.size(); ++j) + { + rs.add(length_squared(samples[i] - samples[j])); + } + } + + return rs.mean(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KERNEL_FEATURE_RANKINg_H_ + + diff --git a/ml/dlib/dlib/svm/feature_ranking_abstract.h b/ml/dlib/dlib/svm/feature_ranking_abstract.h new file mode 100644 index 000000000..5a6fd3bb9 --- /dev/null +++ b/ml/dlib/dlib/svm/feature_ranking_abstract.h @@ -0,0 +1,136 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_KERNEL_FEATURE_RANKINg_ABSTRACT_H_ +#ifdef DLIB_KERNEL_FEATURE_RANKINg_ABSTRACT_H_ + +#include +#include + +#include "svm_abstract.h" +#include "kcentroid_abstract.h" +#include "../is_kind.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename sample_matrix_type, + typename label_matrix_type + > + matrix rank_features ( + const kcentroid& kc, + const sample_matrix_type& samples, + const label_matrix_type& labels, + const long num_features = samples(0).nr() + ); + /*! + requires + - sample_matrix_type == a matrix or something convertible to a matrix via mat() + - label_matrix_type == a matrix or something convertible to a matrix via mat() + - is_binary_classification_problem(samples, labels) == true + - kc.train(samples(0)) must be a valid expression. This means that + kc must use a kernel type that is capable of operating on the + contents of the samples matrix + - 0 < num_features <= samples(0).nr() + ensures + - Let Class1 denote the centroid of all the samples with labels that are < 0 + - Let Class2 denote the centroid of all the samples with labels that are > 0 + - finds a ranking of the features where the best features come first. This + function does this by computing the distance between the centroid of the Class1 + samples and the Class2 samples in kernel defined feature space. + Good features are then ones that result in the biggest separation between + the two centroids of Class1 and Class2. + - Uses the kc object to compute the centroids of the two classes + - returns a ranking matrix R where: + - R.nr() == num_features + - r.nc() == 2 + - R(i,0) == the index of the ith best feature according to our ranking. + (e.g. samples(n)(R(0,0)) is the best feature from sample(n) and + samples(n)(R(1,0)) is the second best, samples(n)(R(2,0)) the + third best and so on) + - R(i,1) == a number that indicates how much separation exists between + the two centroids when features 0 through i are used. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_matrix_type, + typename label_matrix_type + > + double find_gamma_with_big_centroid_gap ( + const sample_matrix_type& samples, + const label_matrix_type& labels, + double initial_gamma = 0.1, + unsigned long num_sv = 40 + ); + /*! + requires + - initial_gamma > 0 + - num_sv > 0 + - is_binary_classification_problem(samples, labels) == true + ensures + - This is a function that tries to pick a reasonable default value for the gamma + parameter of the radial_basis_kernel. It picks the parameter that gives the + largest separation between the centroids, in kernel feature space, of two classes + of data. It does this using the kcentroid object and it sets the kcentroid up + to use num_sv dictionary vectors. + - This function does a search for the best gamma and the search starts with + the value given by initial_gamma. Better initial guesses will give + better results since the routine may get stuck in a local minima. + - returns the value of gamma that results in the largest separation. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_matrix_type, + typename label_matrix_type + > + double verbose_find_gamma_with_big_centroid_gap ( + const sample_matrix_type& samples, + const label_matrix_type& labels, + double initial_gamma = 0.1, + unsigned long num_sv = 40 + ); + /*! + requires + - initial_gamma > 0 + - num_sv > 0 + - is_binary_classification_problem(samples, labels) == true + ensures + - This function does the same exact thing as the above find_gamma_with_big_centroid_gap() + except that it is also verbose in the sense that it will print status messages to + standard out during its processing. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + double compute_mean_squared_distance ( + const vector_type& samples + ); + /*! + requires + - vector_type is something with an interface compatible with std::vector. + Additionally, it must in turn contain dlib::matrix types which contain + scalars such as float or double values. + - for all valid i: is_vector(samples[i]) == true + ensures + - computes the average value of the squares of all the pairwise + distances between every element of samples. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KERNEL_FEATURE_RANKINg_ABSTRACT_H_ + + + diff --git a/ml/dlib/dlib/svm/function.h b/ml/dlib/dlib/svm/function.h new file mode 100644 index 000000000..f5a62a9f7 --- /dev/null +++ b/ml/dlib/dlib/svm/function.h @@ -0,0 +1,882 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_FUNCTION +#define DLIB_SVm_FUNCTION + +#include "function_abstract.h" +#include +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "../serialize.h" +#include "../rand.h" +#include "../statistics.h" +#include "kernel_matrix.h" +#include "kernel.h" +#include "sparse_kernel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + struct decision_function + { + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::scalar_type result_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + typedef matrix scalar_vector_type; + typedef matrix sample_vector_type; + + scalar_vector_type alpha; + scalar_type b; + K kernel_function; + sample_vector_type basis_vectors; + + decision_function ( + ) : b(0), kernel_function(K()) {} + + decision_function ( + const decision_function& d + ) : + alpha(d.alpha), + b(d.b), + kernel_function(d.kernel_function), + basis_vectors(d.basis_vectors) + {} + + decision_function ( + const scalar_vector_type& alpha_, + const scalar_type& b_, + const K& kernel_function_, + const sample_vector_type& basis_vectors_ + ) : + alpha(alpha_), + b(b_), + kernel_function(kernel_function_), + basis_vectors(basis_vectors_) + {} + + result_type operator() ( + const sample_type& x + ) const + { + result_type temp = 0; + for (long i = 0; i < alpha.nr(); ++i) + temp += alpha(i) * kernel_function(x,basis_vectors(i)); + + return temp - b; + } + }; + + template < + typename K + > + void serialize ( + const decision_function& item, + std::ostream& out + ) + { + try + { + serialize(item.alpha, out); + serialize(item.b, out); + serialize(item.kernel_function, out); + serialize(item.basis_vectors, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type decision_function"); + } + } + + template < + typename K + > + void deserialize ( + decision_function& item, + std::istream& in + ) + { + try + { + deserialize(item.alpha, in); + deserialize(item.b, in); + deserialize(item.kernel_function, in); + deserialize(item.basis_vectors, in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type decision_function"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename function_type + > + struct probabilistic_function + { + typedef typename function_type::scalar_type scalar_type; + typedef typename function_type::result_type result_type; + typedef typename function_type::sample_type sample_type; + typedef typename function_type::mem_manager_type mem_manager_type; + + scalar_type alpha; + scalar_type beta; + function_type decision_funct; + + probabilistic_function ( + ) : alpha(0), beta(0), decision_funct(function_type()) {} + + probabilistic_function ( + const probabilistic_function& d + ) : + alpha(d.alpha), + beta(d.beta), + decision_funct(d.decision_funct) + {} + + probabilistic_function ( + const scalar_type a_, + const scalar_type b_, + const function_type& decision_funct_ + ) : + alpha(a_), + beta(b_), + decision_funct(decision_funct_) + {} + + result_type operator() ( + const sample_type& x + ) const + { + result_type f = decision_funct(x); + return 1/(1 + std::exp(alpha*f + beta)); + } + }; + + template < + typename function_type + > + void serialize ( + const probabilistic_function& item, + std::ostream& out + ) + { + try + { + serialize(item.alpha, out); + serialize(item.beta, out); + serialize(item.decision_funct, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type probabilistic_function"); + } + } + + template < + typename function_type + > + void deserialize ( + probabilistic_function& item, + std::istream& in + ) + { + try + { + deserialize(item.alpha, in); + deserialize(item.beta, in); + deserialize(item.decision_funct, in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type probabilistic_function"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + struct probabilistic_decision_function + { + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::scalar_type result_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + scalar_type alpha; + scalar_type beta; + decision_function decision_funct; + + probabilistic_decision_function ( + ) : alpha(0), beta(0), decision_funct(decision_function()) {} + + probabilistic_decision_function ( + const probabilistic_function >& d + ) : + alpha(d.alpha), + beta(d.beta), + decision_funct(d.decision_funct) + {} + + probabilistic_decision_function ( + const probabilistic_decision_function& d + ) : + alpha(d.alpha), + beta(d.beta), + decision_funct(d.decision_funct) + {} + + probabilistic_decision_function ( + const scalar_type a_, + const scalar_type b_, + const decision_function& decision_funct_ + ) : + alpha(a_), + beta(b_), + decision_funct(decision_funct_) + {} + + result_type operator() ( + const sample_type& x + ) const + { + result_type f = decision_funct(x); + return 1/(1 + std::exp(alpha*f + beta)); + } + }; + + template < + typename K + > + void serialize ( + const probabilistic_decision_function& item, + std::ostream& out + ) + { + try + { + serialize(item.alpha, out); + serialize(item.beta, out); + serialize(item.decision_funct, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type probabilistic_decision_function"); + } + } + + template < + typename K + > + void deserialize ( + probabilistic_decision_function& item, + std::istream& in + ) + { + try + { + deserialize(item.alpha, in); + deserialize(item.beta, in); + deserialize(item.decision_funct, in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type probabilistic_decision_function"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class distance_function + { + public: + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::scalar_type result_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + typedef matrix scalar_vector_type; + typedef matrix sample_vector_type; + + + distance_function ( + ) : b(0), kernel_function(K()) {} + + explicit distance_function ( + const kernel_type& kern + ) : b(0), kernel_function(kern) {} + + distance_function ( + const kernel_type& kern, + const sample_type& samp + ) : + alpha(ones_matrix(1,1)), + b(kern(samp,samp)), + kernel_function(kern) + { + basis_vectors.set_size(1,1); + basis_vectors(0) = samp; + } + + distance_function ( + const decision_function& f + ) : + alpha(f.alpha), + b(trans(f.alpha)*kernel_matrix(f.kernel_function,f.basis_vectors)*f.alpha), + kernel_function(f.kernel_function), + basis_vectors(f.basis_vectors) + { + // make sure requires clause is not broken + DLIB_ASSERT(f.alpha.size() == f.basis_vectors.size(), + "\t distance_function(f)" + << "\n\t The supplied decision_function is invalid." + << "\n\t f.alpha.size(): " << f.alpha.size() + << "\n\t f.basis_vectors.size(): " << f.basis_vectors.size() + ); + } + + distance_function ( + const distance_function& d + ) : + alpha(d.alpha), + b(d.b), + kernel_function(d.kernel_function), + basis_vectors(d.basis_vectors) + { + } + + distance_function ( + const scalar_vector_type& alpha_, + const scalar_type& b_, + const K& kernel_function_, + const sample_vector_type& basis_vectors_ + ) : + alpha(alpha_), + b(b_), + kernel_function(kernel_function_), + basis_vectors(basis_vectors_) + { + // make sure requires clause is not broken + DLIB_ASSERT(alpha_.size() == basis_vectors_.size(), + "\t distance_function()" + << "\n\t The supplied arguments are invalid." + << "\n\t alpha_.size(): " << alpha_.size() + << "\n\t basis_vectors_.size(): " << basis_vectors_.size() + ); + } + + distance_function ( + const scalar_vector_type& alpha_, + const K& kernel_function_, + const sample_vector_type& basis_vectors_ + ) : + alpha(alpha_), + b(trans(alpha)*kernel_matrix(kernel_function_,basis_vectors_)*alpha), + kernel_function(kernel_function_), + basis_vectors(basis_vectors_) + { + // make sure requires clause is not broken + DLIB_ASSERT(alpha_.size() == basis_vectors_.size(), + "\t distance_function()" + << "\n\t The supplied arguments are invalid." + << "\n\t alpha_.size(): " << alpha_.size() + << "\n\t basis_vectors_.size(): " << basis_vectors_.size() + ); + } + + const scalar_vector_type& get_alpha ( + ) const { return alpha; } + + const scalar_type& get_squared_norm ( + ) const { return b; } + + const K& get_kernel( + ) const { return kernel_function; } + + const sample_vector_type& get_basis_vectors ( + ) const { return basis_vectors; } + + result_type operator() ( + const sample_type& x + ) const + { + result_type temp = 0; + for (long i = 0; i < alpha.nr(); ++i) + temp += alpha(i) * kernel_function(x,basis_vectors(i)); + + temp = b + kernel_function(x,x) - 2*temp; + if (temp > 0) + return std::sqrt(temp); + else + return 0; + } + + result_type operator() ( + const distance_function& x + ) const + { + result_type temp = 0; + for (long i = 0; i < alpha.nr(); ++i) + for (long j = 0; j < x.alpha.nr(); ++j) + temp += alpha(i)*x.alpha(j) * kernel_function(basis_vectors(i), x.basis_vectors(j)); + + temp = b + x.b - 2*temp; + if (temp > 0) + return std::sqrt(temp); + else + return 0; + } + + distance_function operator* ( + const scalar_type& val + ) const + { + return distance_function(val*alpha, + val*val*b, + kernel_function, + basis_vectors); + } + + distance_function operator/ ( + const scalar_type& val + ) const + { + return distance_function(alpha/val, + b/val/val, + kernel_function, + basis_vectors); + } + + distance_function operator+ ( + const distance_function& rhs + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_kernel() == rhs.get_kernel(), + "\t distance_function distance_function::operator+()" + << "\n\t You can only add two distance_functions together if they use the same kernel." + ); + + if (alpha.size() == 0) + return rhs; + else if (rhs.alpha.size() == 0) + return *this; + else + return distance_function(join_cols(alpha, rhs.alpha), + b + rhs.b + 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha, + kernel_function, + join_cols(basis_vectors, rhs.basis_vectors)); + } + + distance_function operator- ( + const distance_function& rhs + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_kernel() == rhs.get_kernel(), + "\t distance_function distance_function::operator-()" + << "\n\t You can only subtract two distance_functions if they use the same kernel." + ); + + if (alpha.size() == 0 && rhs.alpha.size() == 0) + return distance_function(kernel_function); + else if (alpha.size() != 0 && rhs.alpha.size() == 0) + return *this; + else if (alpha.size() == 0 && rhs.alpha.size() != 0) + return -1*rhs; + else + return distance_function(join_cols(alpha, -rhs.alpha), + b + rhs.b - 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha, + kernel_function, + join_cols(basis_vectors, rhs.basis_vectors)); + } + + private: + + scalar_vector_type alpha; + scalar_type b; + K kernel_function; + sample_vector_type basis_vectors; + + }; + + template < + typename K + > + distance_function operator* ( + const typename K::scalar_type& val, + const distance_function& df + ) { return df*val; } + + template < + typename K + > + void serialize ( + const distance_function& item, + std::ostream& out + ) + { + try + { + serialize(item.alpha, out); + serialize(item.b, out); + serialize(item.kernel_function, out); + serialize(item.basis_vectors, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type distance_function"); + } + } + + template < + typename K + > + void deserialize ( + distance_function& item, + std::istream& in + ) + { + try + { + deserialize(item.alpha, in); + deserialize(item.b, in); + deserialize(item.kernel_function, in); + deserialize(item.basis_vectors, in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type distance_function"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename function_type, + typename normalizer_type = vector_normalizer + > + struct normalized_function + { + typedef typename function_type::result_type result_type; + typedef typename function_type::sample_type sample_type; + typedef typename function_type::mem_manager_type mem_manager_type; + + normalizer_type normalizer; + function_type function; + + normalized_function ( + ){} + + normalized_function ( + const normalized_function& f + ) : + normalizer(f.normalizer), + function(f.function) + {} + + const std::vector get_labels( + ) const { return function.get_labels(); } + + unsigned long number_of_classes ( + ) const { return function.number_of_classes(); } + + normalized_function ( + const vector_normalizer& normalizer_, + const function_type& funct + ) : normalizer(normalizer_), function(funct) {} + + result_type operator() ( + const sample_type& x + ) const { return function(normalizer(x)); } + }; + + template < + typename function_type, + typename normalizer_type + > + void serialize ( + const normalized_function& item, + std::ostream& out + ) + { + try + { + serialize(item.normalizer, out); + serialize(item.function, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type normalized_function"); + } + } + + template < + typename function_type, + typename normalizer_type + > + void deserialize ( + normalized_function& item, + std::istream& in + ) + { + try + { + deserialize(item.normalizer, in); + deserialize(item.function, in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type normalized_function"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + struct projection_function + { + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + typedef matrix scalar_vector_type; + typedef matrix scalar_matrix_type; + typedef matrix sample_vector_type; + typedef scalar_vector_type result_type; + + scalar_matrix_type weights; + K kernel_function; + sample_vector_type basis_vectors; + + projection_function ( + ) {} + + projection_function ( + const projection_function& f + ) : weights(f.weights), kernel_function(f.kernel_function), basis_vectors(f.basis_vectors) {} + + projection_function ( + const scalar_matrix_type& weights_, + const K& kernel_function_, + const sample_vector_type& basis_vectors_ + ) : weights(weights_), kernel_function(kernel_function_), basis_vectors(basis_vectors_) {} + + long out_vector_size ( + ) const { return weights.nr(); } + + const result_type& operator() ( + const sample_type& x + ) const + { + // Run the x sample through all the basis functions we have and then + // multiply it by the weights matrix and return the result. Note that + // the temp vectors are here to avoid reallocating their memory every + // time this function is called. + temp1 = kernel_matrix(kernel_function, basis_vectors, x); + temp2 = weights*temp1; + return temp2; + } + + private: + mutable result_type temp1, temp2; + }; + + template < + typename K + > + void serialize ( + const projection_function& item, + std::ostream& out + ) + { + try + { + serialize(item.weights, out); + serialize(item.kernel_function, out); + serialize(item.basis_vectors, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type projection_function"); + } + } + + template < + typename K + > + void deserialize ( + projection_function& item, + std::istream& in + ) + { + try + { + deserialize(item.weights, in); + deserialize(item.kernel_function, in); + deserialize(item.basis_vectors, in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type projection_function"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename result_type_ = typename K::scalar_type + > + struct multiclass_linear_decision_function + { + typedef result_type_ result_type; + + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + typedef matrix scalar_vector_type; + typedef matrix scalar_matrix_type; + + // You are getting a compiler error on this line because you supplied a non-linear kernel + // to the multiclass_linear_decision_function object. You have to use one of the linear + // kernels with this object. + COMPILE_TIME_ASSERT((is_same_type >::value || + is_same_type >::value )); + + + scalar_matrix_type weights; + scalar_vector_type b; + std::vector labels; + + const std::vector& get_labels( + ) const { return labels; } + + unsigned long number_of_classes ( + ) const { return labels.size(); } + + std::pair predict ( + const sample_type& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(weights.size() > 0 && + weights.nr() == (long)number_of_classes() && + weights.nr() == b.size(), + "\t pair multiclass_linear_decision_function::predict(x)" + << "\n\t This object must be properly initialized before you can use it." + << "\n\t weights.size(): " << weights.size() + << "\n\t weights.nr(): " << weights.nr() + << "\n\t number_of_classes(): " << number_of_classes() + ); + + // Rather than doing something like, best_idx = index_of_max(weights*x-b) + // we do the following somewhat more complex thing because this supports + // both sparse and dense samples. + scalar_type best_val = dot(rowm(weights,0),x) - b(0); + unsigned long best_idx = 0; + + for (unsigned long i = 1; i < labels.size(); ++i) + { + scalar_type temp = dot(rowm(weights,i),x) - b(i); + if (temp > best_val) + { + best_val = temp; + best_idx = i; + } + } + + return std::make_pair(labels[best_idx], best_val); + } + + result_type operator() ( + const sample_type& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(weights.size() > 0 && + weights.nr() == (long)number_of_classes() && + weights.nr() == b.size(), + "\t result_type multiclass_linear_decision_function::operator()(x)" + << "\n\t This object must be properly initialized before you can use it." + << "\n\t weights.size(): " << weights.size() + << "\n\t weights.nr(): " << weights.nr() + << "\n\t number_of_classes(): " << number_of_classes() + ); + + return predict(x).first; + } + }; + + template < + typename K, + typename result_type_ + > + void serialize ( + const multiclass_linear_decision_function& item, + std::ostream& out + ) + { + try + { + serialize(item.weights, out); + serialize(item.b, out); + serialize(item.labels, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type multiclass_linear_decision_function"); + } + } + + template < + typename K, + typename result_type_ + > + void deserialize ( + multiclass_linear_decision_function& item, + std::istream& in + ) + { + try + { + deserialize(item.weights, in); + deserialize(item.b, in); + deserialize(item.labels, in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type multiclass_linear_decision_function"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_FUNCTION + + diff --git a/ml/dlib/dlib/svm/function_abstract.h b/ml/dlib/dlib/svm/function_abstract.h new file mode 100644 index 000000000..783a68c50 --- /dev/null +++ b/ml/dlib/dlib/svm/function_abstract.h @@ -0,0 +1,997 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_FUNCTION_ABSTRACT_ +#ifdef DLIB_SVm_FUNCTION_ABSTRACT_ + +#include +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "../statistics/statistics_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + struct decision_function + { + /*! + REQUIREMENTS ON K + K must be a kernel function object type as defined at the + top of dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a classification or regression function that was + learned by a kernel based learning algorithm. Therefore, it is a function + object that takes a sample object and returns a scalar value. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call operator() on this object from multiple threads so + long as the kernel, K, is also threadsafe. This is because operator() + is a read-only operation. However, any operation that modifies a + decision_function is not threadsafe. + !*/ + + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::scalar_type result_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + typedef matrix scalar_vector_type; + typedef matrix sample_vector_type; + + scalar_vector_type alpha; + scalar_type b; + K kernel_function; + sample_vector_type basis_vectors; + + decision_function ( + ); + /*! + ensures + - #b == 0 + - #alpha.nr() == 0 + - #basis_vectors.nr() == 0 + !*/ + + decision_function ( + const decision_function& f + ); + /*! + ensures + - #*this is a copy of f + !*/ + + decision_function ( + const scalar_vector_type& alpha_, + const scalar_type& b_, + const K& kernel_function_, + const sample_vector_type& basis_vectors_ + ) : alpha(alpha_), b(b_), kernel_function(kernel_function_), basis_vectors(basis_vectors_) {} + /*! + ensures + - populates the decision function with the given basis vectors, weights(i.e. alphas), + b term, and kernel function. + !*/ + + result_type operator() ( + const sample_type& x + ) const + /*! + ensures + - evaluates this sample according to the decision + function contained in this object. + !*/ + { + result_type temp = 0; + for (long i = 0; i < alpha.nr(); ++i) + temp += alpha(i) * kernel_function(x,basis_vectors(i)); + + return temp - b; + } + }; + + template < + typename K + > + void serialize ( + const decision_function& item, + std::ostream& out + ); + /*! + provides serialization support for decision_function + !*/ + + template < + typename K + > + void deserialize ( + decision_function& item, + std::istream& in + ); + /*! + provides serialization support for decision_function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename function_type + > + struct probabilistic_function + { + /*! + REQUIREMENTS ON function_type + - function_type must be a function object with an overloaded + operator() similar to the other function objects defined in + this file. The operator() should return a scalar type such as + double or float. + + WHAT THIS OBJECT REPRESENTS + This object represents a binary decision function that returns an + estimate of the probability that a given sample is in the +1 class. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call operator() on this object from multiple threads so + long as decision_funct is also threadsafe. This is because operator() + is a read-only operation. However, any operation that modifies a + probabilistic_function is not threadsafe. + !*/ + + typedef typename function_type::scalar_type scalar_type; + typedef typename function_type::result_type result_type; + typedef typename function_type::sample_type sample_type; + typedef typename function_type::mem_manager_type mem_manager_type; + + scalar_type alpha; + scalar_type beta; + function_type decision_funct; + + probabilistic_function ( + ); + /*! + ensures + - #alpha == 0 + - #beta == 0 + - #decision_funct has its initial value + !*/ + + probabilistic_function ( + const probabilistic_function& f + ); + /*! + ensures + - #*this is a copy of f + !*/ + + probabilistic_function ( + const scalar_type a, + const scalar_type b, + const function_type& decision_funct_ + ) : alpha(a), beta(b), decision_funct(decision_funct_) {} + /*! + ensures + - populates the probabilistic decision function with the given alpha, beta, + and decision function. + !*/ + + result_type operator() ( + const sample_type& x + ) const + /*! + ensures + - returns a number P such that: + - 0 <= P <= 1 + - P represents the probability that sample x is from + the class +1 + !*/ + { + // Evaluate the normal decision function + result_type f = decision_funct(x); + // Now basically normalize the output so that it is a properly + // conditioned probability of x being in the +1 class given + // the output of the decision function. + return 1/(1 + std::exp(alpha*f + beta)); + } + }; + + template < + typename function_type + > + void serialize ( + const probabilistic_function& item, + std::ostream& out + ); + /*! + provides serialization support for probabilistic_function + !*/ + + template < + typename function_type + > + void deserialize ( + probabilistic_function& item, + std::istream& in + ); + /*! + provides serialization support for probabilistic_function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + struct probabilistic_decision_function + { + /*! + REQUIREMENTS ON K + K must be a kernel function object type as defined at the + top of dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a binary decision function that returns an + estimate of the probability that a given sample is in the +1 class. + + Note that this object is essentially just a copy of + probabilistic_function but with the template argument + changed from being a function type to a kernel type. Therefore, this + type is just a convenient version of probabilistic_function + for the case where the decision function is a dlib::decision_function. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call operator() on this object from multiple threads so + long as the kernel, K, is also threadsafe. This is because operator() + is a read-only operation. However, any operation that modifies a + probabilistic_decision_function is not threadsafe. + !*/ + + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::scalar_type result_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + scalar_type alpha; + scalar_type beta; + decision_function decision_funct; + + probabilistic_decision_function ( + ); + /*! + ensures + - #alpha == 0 + - #beta == 0 + - #decision_funct has its initial value + !*/ + + probabilistic_decision_function ( + const probabilistic_decision_function& f + ); + /*! + ensures + - #*this is a copy of f + !*/ + + probabilistic_decision_function ( + const probabilistic_function >& d + ); + /*! + ensures + - #*this is a copy of f + !*/ + + probabilistic_decision_function ( + const scalar_type a, + const scalar_type b, + const decision_function& decision_funct_ + ) : alpha(a), beta(b), decision_funct(decision_funct_) {} + /*! + ensures + - populates the probabilistic decision function with the given alpha, beta, + and decision_function. + !*/ + + result_type operator() ( + const sample_type& x + ) const + /*! + ensures + - returns a number P such that: + - 0 <= P <= 1 + - P represents the probability that sample x is from + the class +1 + !*/ + { + // Evaluate the normal decision function + result_type f = decision_funct(x); + // Now basically normalize the output so that it is a properly + // conditioned probability of x being in the +1 class given + // the output of the decision function. + return 1/(1 + std::exp(alpha*f + beta)); + } + }; + + template < + typename K + > + void serialize ( + const probabilistic_decision_function& item, + std::ostream& out + ); + /*! + provides serialization support for probabilistic_decision_function + !*/ + + template < + typename K + > + void deserialize ( + probabilistic_decision_function& item, + std::istream& in + ); + /*! + provides serialization support for probabilistic_decision_function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class distance_function + { + /*! + REQUIREMENTS ON K + K must be a kernel function object type as defined at the + top of dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a point in kernel induced feature space. + You may use this object to find the distance from the point it + represents to points in input space as well as other points + represented by distance_functions. + + Specifically, if O() is the feature mapping associated with + the kernel used by this object. Then this object represents + the point: + sum alpha(i)*O(basis_vectors(i)) + + I.e. It represents a linear combination of the basis vectors where + the weights of the linear combination are stored in the alpha vector. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call the const members of this object from multiple + threads so long as the kernel, K, is also threadsafe. This is because + the const members are purely read-only operations. However, any + operation that modifies a distance_function is not threadsafe. + !*/ + + public: + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::scalar_type result_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + typedef matrix scalar_vector_type; + typedef matrix sample_vector_type; + + distance_function ( + ); + /*! + ensures + - #get_squared_norm() == 0 + - #get_alpha().size() == 0 + - #get_basis_vectors().size() == 0 + - #get_kernel() == K() (i.e. the default value of the kernel) + !*/ + + explicit distance_function ( + const kernel_type& kern + ); + /*! + ensures + - #get_squared_norm() == 0 + - #get_alpha().size() == 0 + - #get_basis_vectors().size() == 0 + - #get_kernel() == kern + !*/ + + distance_function ( + const kernel_type& kern, + const sample_type& samp + ); + /*! + ensures + - This object represents the point in kernel feature space which + corresponds directly to the given sample. In particular this means + that: + - #get_kernel() == kern + - #get_alpha() == a vector of length 1 which contains the value 1 + - #get_basis_vectors() == a vector of length 1 which contains samp + !*/ + + distance_function ( + const decision_function& f + ); + /*! + ensures + - Every decision_function represents a point in kernel feature space along + with a bias value. This constructor discards the bias value and creates + a distance_function which represents the point associated with the given + decision_function f. In particular, this means: + - #get_alpha() == f.alpha + - #get_kernel() == f.kernel_function + - #get_basis_vectors() == f.basis_vectors + !*/ + + distance_function ( + const distance_function& f + ); + /*! + requires + - f is a valid distance_function. In particular, this means that + f.alpha.size() == f.basis_vectors.size() + ensures + - #*this is a copy of f + !*/ + + distance_function ( + const scalar_vector_type& alpha, + const scalar_type& squared_norm, + const K& kernel_function, + const sample_vector_type& basis_vectors + ); + /*! + requires + - alpha.size() == basis_vectors.size() + - squared_norm == trans(alpha)*kernel_matrix(kernel_function,basis_vectors)*alpha + (Basically, squared_norm needs to be set properly for this object to make sense. + You should prefer to use the following constructor which computes squared_norm for + you. This version is provided just in case you already know squared_norm and + don't want to spend CPU cycles to recompute it.) + ensures + - populates the distance function with the given basis vectors, weights(i.e. alphas), + squared_norm value, and kernel function. I.e. + - #get_alpha() == alpha + - #get_squared_norm() == squared_norm + - #get_kernel() == kernel_function + - #get_basis_vectors() == basis_vectors + !*/ + + distance_function ( + const scalar_vector_type& alpha, + const K& kernel_function, + const sample_vector_type& basis_vectors + ); + /*! + requires + - alpha.size() == basis_vectors.size() + ensures + - populates the distance function with the given basis vectors, weights(i.e. alphas), + and kernel function. The correct b value is computed automatically. I.e. + - #get_alpha() == alpha + - #get_squared_norm() == trans(alpha)*kernel_matrix(kernel_function,basis_vectors)*alpha + (i.e. get_squared_norm() will be automatically set to the correct value) + - #get_kernel() == kernel_function + - #get_basis_vectors() == basis_vectors + !*/ + + const scalar_vector_type& get_alpha ( + ) const; + /*! + ensures + - returns the set of weights on each basis vector in this object + !*/ + + const scalar_type& get_squared_norm ( + ) const; + /*! + ensures + - returns the squared norm of the point represented by this object. This value is + equal to the following expression: + trans(get_alpha()) * kernel_matrix(get_kernel(),get_basis_vectors()) * get_alpha() + !*/ + + const K& get_kernel( + ) const; + /*! + ensures + - returns the kernel used by this object. + !*/ + + const sample_vector_type& get_basis_vectors ( + ) const; + /*! + ensures + - returns the set of basis vectors contained in this object + !*/ + + result_type operator() ( + const sample_type& x + ) const; + /*! + ensures + - Let O(x) represent the point x projected into kernel induced feature space. + - let c == sum_over_i get_alpha()(i)*O(get_basis_vectors()(i)) == the point in kernel space that + this object represents. That is, c is the weighted sum of basis vectors. + - Then this object returns the distance between the point O(x) and c in kernel + space. + !*/ + + result_type operator() ( + const distance_function& x + ) const; + /*! + requires + - kernel_function == x.kernel_function + ensures + - returns the distance between the points in kernel space represented by *this and x. + !*/ + + distance_function operator* ( + const scalar_type& val + ) const; + /*! + ensures + - multiplies the point represented by *this by val and returns the result. In + particular, this function returns a decision_function DF such that: + - DF.get_basis_vectors() == get_basis_vectors() + - DF.get_kernel() == get_kernel() + - DF.get_alpha() == get_alpha() * val + !*/ + + distance_function operator/ ( + const scalar_type& val + ) const; + /*! + ensures + - divides the point represented by *this by val and returns the result. In + particular, this function returns a decision_function DF such that: + - DF.get_basis_vectors() == get_basis_vectors() + - DF.get_kernel() == get_kernel() + - DF.get_alpha() == get_alpha() / val + !*/ + + distance_function operator+ ( + const distance_function& rhs + ) const; + /*! + requires + - get_kernel() == rhs.get_kernel() + ensures + - returns a distance function DF such that: + - DF represents the sum of the point represented by *this and rhs + - DF.get_basis_vectors().size() == get_basis_vectors().size() + rhs.get_basis_vectors().size() + - DF.get_basis_vectors() contains all the basis vectors in both *this and rhs. + - DF.get_kernel() == get_kernel() + - DF.alpha == join_cols(get_alpha(), rhs.get_alpha()) + !*/ + + distance_function operator- ( + const distance_function& rhs + ) const; + /*! + requires + - get_kernel() == rhs.get_kernel() + ensures + - returns a distance function DF such that: + - DF represents the difference of the point represented by *this and rhs (i.e. *this - rhs) + - DF.get_basis_vectors().size() == get_basis_vectors().size() + rhs.get_basis_vectors().size() + - DF.get_basis_vectors() contains all the basis vectors in both *this and rhs. + - DF.get_kernel() == get_kernel() + - DF.alpha == join_cols(get_alpha(), -1 * rhs.get_alpha()) + !*/ + }; + + template < + typename K + > + distance_function operator* ( + const typename K::scalar_type& val, + const distance_function& df + ) { return df*val; } + /*! + ensures + - multiplies the point represented by *this by val and returns the result. This + function just allows multiplication syntax of the form val*df. + !*/ + + template < + typename K + > + void serialize ( + const distance_function& item, + std::ostream& out + ); + /*! + provides serialization support for distance_function + !*/ + + template < + typename K + > + void deserialize ( + distance_function& item, + std::istream& in + ); + /*! + provides serialization support for distance_function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename function_type, + typename normalizer_type = vector_normalizer + > + struct normalized_function + { + /*! + REQUIREMENTS ON function_type + - function_type must be a function object with an overloaded + operator() similar to the other function objects defined in + this file. + + REQUIREMENTS ON normalizer_type + - normalizer_type must be a function object with an overloaded + operator() that takes a sample_type and returns a sample_type. + + WHAT THIS OBJECT REPRESENTS + This object represents a container for another function + object and an instance of a normalizer function. + + It automatically normalizes all inputs before passing them + off to the contained function object. + !*/ + + typedef typename function_type::result_type result_type; + typedef typename function_type::sample_type sample_type; + typedef typename function_type::mem_manager_type mem_manager_type; + + normalizer_type normalizer; + function_type function; + + normalized_function ( + ); + /*! + ensures + - the members of this object have their default values + !*/ + + normalized_function ( + const normalized_function& f + ); + /*! + ensures + - #*this is a copy of f + !*/ + + normalized_function ( + const vector_normalizer& normalizer_, + const function_type& funct + ) : normalizer(normalizer_), function(funct) {} + /*! + ensures + - populates this object with the vector_normalizer and function object + !*/ + + const std::vector get_labels( + ) const; + /*! + ensures + - returns function.get_labels() + !*/ + + unsigned long number_of_classes ( + ) const; + /*! + ensures + - returns function.number_of_classes() + !*/ + + result_type operator() ( + const sample_type& x + ) const + /*! + ensures + - returns function(normalizer(x)) + !*/ + }; + + template < + typename function_type, + typename normalizer_type + > + void serialize ( + const normalized_function& item, + std::ostream& out + ); + /*! + provides serialization support for normalized_function + !*/ + + template < + typename function_type, + typename normalizer_type + > + void deserialize ( + normalized_function& item, + std::istream& in + ); + /*! + provides serialization support for normalized_function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + struct projection_function + { + /*! + REQUIREMENTS ON K + K must be a kernel function object type as defined at the + top of dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a function that takes a data sample and projects + it into kernel feature space. The result is a real valued column vector that + represents a point in a kernel feature space. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + Instances of this object have a mutable cache which is used by const + member functions. Therefore, it is not safe to use one instance of + this object from multiple threads (unless protected by a mutex). + !*/ + + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + typedef matrix scalar_vector_type; + typedef matrix scalar_matrix_type; + typedef matrix sample_vector_type; + typedef scalar_vector_type result_type; + + scalar_matrix_type weights; + K kernel_function; + sample_vector_type basis_vectors; + + projection_function ( + ); + /*! + ensures + - #weights.size() == 0 + - #basis_vectors.size() == 0 + !*/ + + projection_function ( + const projection_function& f + ); + /*! + ensures + - #*this is a copy of f + !*/ + + projection_function ( + const scalar_matrix_type& weights_, + const K& kernel_function_, + const sample_vector_type& basis_vectors_ + ) : weights(weights_), kernel_function(kernel_function_), basis_vectors(basis_vectors_) {} + /*! + ensures + - populates the projection function with the given basis vectors, weights, + and kernel function. + !*/ + + long out_vector_size ( + ) const; + /*! + ensures + - returns weights.nr() + (i.e. returns the dimensionality of the vectors output by this projection_function.) + !*/ + + const result_type& operator() ( + const sample_type& x + ) const + /*! + requires + - weights.nc() == basis_vectors.size() + - out_vector_size() > 0 + ensures + - Takes the given x sample and projects it onto part of the kernel feature + space spanned by the basis_vectors. The exact projection arithmetic is + defined below. + !*/ + { + // Run the x sample through all the basis functions we have and then + // multiply it by the weights matrix and return the result. Note that + // the temp vectors are here to avoid reallocating their memory every + // time this function is called. + temp1 = kernel_matrix(kernel_function, basis_vectors, x); + temp2 = weights*temp1; + return temp2; + } + + private: + mutable result_type temp1, temp2; + }; + + template < + typename K + > + void serialize ( + const projection_function& item, + std::ostream& out + ); + /*! + provides serialization support for projection_function + !*/ + + template < + typename K + > + void deserialize ( + projection_function& item, + std::istream& in + ); + /*! + provides serialization support for projection_function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename result_type_ = typename K::scalar_type + > + struct multiclass_linear_decision_function + { + /*! + REQUIREMENTS ON K + K must be either linear_kernel or sparse_linear_kernel. + + WHAT THIS OBJECT REPRESENTS + This object represents a multiclass classifier built out of a set of + binary classifiers. Each binary classifier is used to vote for the + correct multiclass label using a one vs. all strategy. Therefore, + if you have N classes then there will be N binary classifiers inside + this object. Additionally, this object is linear in the sense that + each of these binary classifiers is a simple linear plane. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call the const member functions of this object from + multiple threads. This is because the const members are purely + read-only operations. However, any operation that modifies a + multiclass_linear_decision_function is not threadsafe. + !*/ + + typedef result_type_ result_type; + + typedef K kernel_type; + typedef typename K::scalar_type scalar_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + typedef matrix scalar_vector_type; + typedef matrix scalar_matrix_type; + + scalar_matrix_type weights; + scalar_vector_type b; + std::vector labels; + + const std::vector& get_labels( + ) const { return labels; } + /*! + ensures + - returns a vector containing all the labels which can be + predicted by this object. + !*/ + + unsigned long number_of_classes ( + ) const; + /*! + ensures + - returns get_labels().size() + (i.e. returns the number of different labels/classes predicted by + this object) + !*/ + + std::pair predict ( + const sample_type& x + ) const; + /*! + requires + - weights.size() > 0 + - weights.nr() == number_of_classes() == b.size() + - if (x is a dense vector, i.e. a dlib::matrix) then + - is_vector(x) == true + - x.size() == weights.nc() + (i.e. it must be legal to multiply weights with x) + ensures + - Returns the predicted label for the x sample and also it's score. + In particular, it returns the following: + std::make_pair(labels[index_of_max(weights*x-b)], max(weights*x-b)) + !*/ + + result_type operator() ( + const sample_type& x + ) const; + /*! + requires + - weights.size() > 0 + - weights.nr() == number_of_classes() == b.size() + - if (x is a dense vector, i.e. a dlib::matrix) then + - is_vector(x) == true + - x.size() == weights.nc() + (i.e. it must be legal to multiply weights with x) + ensures + - Returns the predicted label for the x sample. In particular, it returns + the following: + labels[index_of_max(weights*x-b)] + Or in other words, this function returns predict(x).first + !*/ + }; + + template < + typename K, + typename result_type_ + > + void serialize ( + const multiclass_linear_decision_function& item, + std::ostream& out + ); + /*! + provides serialization support for multiclass_linear_decision_function + !*/ + + template < + typename K, + typename result_type_ + > + void deserialize ( + multiclass_linear_decision_function& item, + std::istream& in + ); + /*! + provides serialization support for multiclass_linear_decision_function + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_FUNCTION_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/kcentroid.h b/ml/dlib/dlib/svm/kcentroid.h new file mode 100644 index 000000000..5f380486a --- /dev/null +++ b/ml/dlib/dlib/svm/kcentroid.h @@ -0,0 +1,614 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_KCENTROId_ +#define DLIB_KCENTROId_ + +#include + +#include "kcentroid_abstract.h" +#include "../matrix.h" +#include "function.h" +#include "../std_allocator.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + class kcentroid + { + /*! + This object represents a weighted sum of sample points in a kernel induced + feature space. It can be used to kernelize any algorithm that requires only + the ability to perform vector addition, subtraction, scalar multiplication, + and inner products. It uses the sparsification technique described in the + paper The Kernel Recursive Least Squares Algorithm by Yaakov Engel. + + To understand the code it would also be useful to consult page 114 of the book + Kernel Methods for Pattern Analysis by Taylor and Cristianini as well as page 554 + (particularly equation 18.31) of the book Learning with Kernels by Scholkopf and + Smola. Everything you really need to know is in the Engel paper. But the other + books help give more perspective on the issues involved. + + + INITIAL VALUE + - min_strength == 0 + - min_vect_idx == 0 + - K_inv.size() == 0 + - K.size() == 0 + - dictionary.size() == 0 + - bias == 0 + - bias_is_stale == false + + CONVENTION + - max_dictionary_size() == my_max_dictionary_size + - get_kernel() == kernel + + - K.nr() == dictionary.size() + - K.nc() == dictionary.size() + - for all valid r,c: + - K(r,c) == kernel(dictionary[r], dictionary[c]) + - K_inv == inv(K) + + - if (dictionary.size() == my_max_dictionary_size && my_remove_oldest_first == false) then + - for all valid 0 < i < dictionary.size(): + - Let STRENGTHS[i] == the delta you would get for dictionary[i] (i.e. Approximately + Linearly Dependent value) if you removed dictionary[i] from this object and then + tried to add it back in. + - min_strength == the minimum value from STRENGTHS + - min_vect_idx == the index of the element in STRENGTHS with the smallest value + + !*/ + + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + kcentroid ( + ) : + my_remove_oldest_first(false), + my_tolerance(0.001), + my_max_dictionary_size(1000000), + bias(0), + bias_is_stale(false) + { + clear_dictionary(); + } + + explicit kcentroid ( + const kernel_type& kernel_, + scalar_type tolerance_ = 0.001, + unsigned long max_dictionary_size_ = 1000000, + bool remove_oldest_first_ = false + ) : + my_remove_oldest_first(remove_oldest_first_), + kernel(kernel_), + my_tolerance(tolerance_), + my_max_dictionary_size(max_dictionary_size_), + bias(0), + bias_is_stale(false) + { + // make sure requires clause is not broken + DLIB_ASSERT(tolerance_ > 0 && max_dictionary_size_ > 1, + "\tkcentroid::kcentroid()" + << "\n\t You have to give a positive tolerance" + << "\n\t this: " << this + << "\n\t tolerance_: " << tolerance_ + << "\n\t max_dictionary_size_: " << max_dictionary_size_ + ); + + clear_dictionary(); + } + + scalar_type tolerance() const + { + return my_tolerance; + } + + unsigned long max_dictionary_size() const + { + return my_max_dictionary_size; + } + + bool remove_oldest_first ( + ) const + { + return my_remove_oldest_first; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel; + } + + void clear_dictionary () + { + dictionary.clear(); + alpha.clear(); + + min_strength = 0; + min_vect_idx = 0; + K_inv.set_size(0,0); + K.set_size(0,0); + samples_seen = 0; + bias = 0; + bias_is_stale = false; + } + + scalar_type operator() ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::operator()(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + // make sure the bias terms are up to date + refresh_bias(); + x.refresh_bias(); + + scalar_type temp = x.bias + bias - 2*inner_product(x); + + if (temp > 0) + return std::sqrt(temp); + else + return 0; + } + + scalar_type inner_product ( + const sample_type& x + ) const + { + scalar_type temp = 0; + for (unsigned long i = 0; i < alpha.size(); ++i) + temp += alpha[i]*kernel(dictionary[i], x); + return temp; + } + + scalar_type inner_product ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::inner_product(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + scalar_type temp = 0; + for (unsigned long i = 0; i < alpha.size(); ++i) + { + for (unsigned long j = 0; j < x.alpha.size(); ++j) + { + temp += alpha[i]*x.alpha[j]*kernel(dictionary[i], x.dictionary[j]); + } + } + return temp; + } + + scalar_type squared_norm ( + ) const + { + refresh_bias(); + return bias; + } + + scalar_type operator() ( + const sample_type& x + ) const + { + // make sure the bias terms are up to date + refresh_bias(); + + const scalar_type kxx = kernel(x,x); + + scalar_type temp = kxx + bias - 2*inner_product(x); + if (temp > 0) + return std::sqrt(temp); + else + return 0; + } + + scalar_type samples_trained ( + ) const + { + return samples_seen; + } + + scalar_type test_and_train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + return train_and_maybe_test(x,cscale,xscale,true); + } + + void train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + train_and_maybe_test(x,cscale,xscale,false); + } + + scalar_type test_and_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + return train_and_maybe_test(x,cscale,xscale,true); + } + + void scale_by ( + scalar_type cscale + ) + { + for (unsigned long i = 0; i < alpha.size(); ++i) + { + alpha[i] = cscale*alpha[i]; + } + } + + void train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + train_and_maybe_test(x,cscale,xscale,false); + } + + void swap ( + kcentroid& item + ) + { + exchange(min_strength, item.min_strength); + exchange(min_vect_idx, item.min_vect_idx); + exchange(my_remove_oldest_first, item.my_remove_oldest_first); + + exchange(kernel, item.kernel); + dictionary.swap(item.dictionary); + alpha.swap(item.alpha); + K_inv.swap(item.K_inv); + K.swap(item.K); + exchange(my_tolerance, item.my_tolerance); + exchange(samples_seen, item.samples_seen); + exchange(bias, item.bias); + a.swap(item.a); + k.swap(item.k); + exchange(bias_is_stale, item.bias_is_stale); + exchange(my_max_dictionary_size, item.my_max_dictionary_size); + } + + unsigned long dictionary_size ( + ) const { return dictionary.size(); } + + friend void serialize(const kcentroid& item, std::ostream& out) + { + serialize(item.min_strength, out); + serialize(item.min_vect_idx, out); + serialize(item.my_remove_oldest_first, out); + + serialize(item.kernel, out); + serialize(item.dictionary, out); + serialize(item.alpha, out); + serialize(item.K_inv, out); + serialize(item.K, out); + serialize(item.my_tolerance, out); + serialize(item.samples_seen, out); + serialize(item.bias, out); + serialize(item.bias_is_stale, out); + serialize(item.my_max_dictionary_size, out); + } + + friend void deserialize(kcentroid& item, std::istream& in) + { + deserialize(item.min_strength, in); + deserialize(item.min_vect_idx, in); + deserialize(item.my_remove_oldest_first, in); + + deserialize(item.kernel, in); + deserialize(item.dictionary, in); + deserialize(item.alpha, in); + deserialize(item.K_inv, in); + deserialize(item.K, in); + deserialize(item.my_tolerance, in); + deserialize(item.samples_seen, in); + deserialize(item.bias, in); + deserialize(item.bias_is_stale, in); + deserialize(item.my_max_dictionary_size, in); + } + + distance_function get_distance_function ( + ) const + { + refresh_bias(); + return distance_function(mat(alpha), + bias, + kernel, + mat(dictionary)); + } + + private: + + void refresh_bias ( + ) const + { + if (bias_is_stale) + { + bias_is_stale = false; + // recompute the bias term + bias = sum(pointwise_multiply(K, mat(alpha)*trans(mat(alpha)))); + } + } + + scalar_type train_and_maybe_test ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale, + bool do_test + ) + { + scalar_type test_result = 0; + const scalar_type kx = kernel(x,x); + if (alpha.size() == 0) + { + // just ignore this sample if it is the zero vector (or really close to being zero) + if (std::abs(kx) > std::numeric_limits::epsilon()) + { + // set initial state since this is the first training example we have seen + + K_inv.set_size(1,1); + K_inv(0,0) = 1/kx; + K.set_size(1,1); + K(0,0) = kx; + + alpha.push_back(xscale); + dictionary.push_back(x); + } + else + { + // the distance from an empty kcentroid and the zero vector is zero by definition. + return 0; + } + } + else + { + // fill in k + k.set_size(alpha.size()); + for (long r = 0; r < k.nr(); ++r) + k(r) = kernel(x,dictionary[r]); + + if (do_test) + { + refresh_bias(); + test_result = std::sqrt(kx + bias - 2*trans(mat(alpha))*k); + } + + // compute the error we would have if we approximated the new x sample + // with the dictionary. That is, do the ALD test from the KRLS paper. + a = K_inv*k; + scalar_type delta = kx - trans(k)*a; + + // if this new vector isn't approximately linearly dependent on the vectors + // in our dictionary. + if (delta > min_strength && delta > my_tolerance) + { + bool need_to_update_min_strength = false; + if (dictionary.size() >= my_max_dictionary_size) + { + // We need to remove one of the old members of the dictionary before + // we proceed with adding a new one. + long idx_to_remove; + if (my_remove_oldest_first) + { + // remove the oldest one + idx_to_remove = 0; + } + else + { + // if we have never computed the min_strength then we should compute it + if (min_strength == 0) + recompute_min_strength(); + + // select the dictionary vector that is most linearly dependent for removal + idx_to_remove = min_vect_idx; + need_to_update_min_strength = true; + } + + remove_dictionary_vector(idx_to_remove); + + // recompute these guys since they were computed with the old + // kernel matrix + k = remove_row(k,idx_to_remove); + a = K_inv*k; + delta = kx - trans(k)*a; + } + + // add x to the dictionary + dictionary.push_back(x); + + + // update K_inv by computing the new one in the temp matrix (equation 3.14) + matrix temp(K_inv.nr()+1, K_inv.nc()+1); + // update the middle part of the matrix + set_subm(temp, get_rect(K_inv)) = K_inv + a*trans(a)/delta; + // update the right column of the matrix + set_subm(temp, 0, K_inv.nr(),K_inv.nr(),1) = -a/delta; + // update the bottom row of the matrix + set_subm(temp, K_inv.nr(), 0, 1, K_inv.nr()) = trans(-a/delta); + // update the bottom right corner of the matrix + temp(K_inv.nr(), K_inv.nc()) = 1/delta; + // put temp into K_inv + temp.swap(K_inv); + + + + // update K (the kernel matrix) + temp.set_size(K.nr()+1, K.nc()+1); + set_subm(temp, get_rect(K)) = K; + // update the right column of the matrix + set_subm(temp, 0, K.nr(),K.nr(),1) = k; + // update the bottom row of the matrix + set_subm(temp, K.nr(), 0, 1, K.nr()) = trans(k); + temp(K.nr(), K.nc()) = kx; + // put temp into K + temp.swap(K); + + + // now update the alpha vector + for (unsigned long i = 0; i < alpha.size(); ++i) + { + alpha[i] *= cscale; + } + alpha.push_back(xscale); + + + if (need_to_update_min_strength) + { + // now we have to recompute the min_strength in this case + recompute_min_strength(); + } + } + else + { + // update the alpha vector so that this new sample has been added into + // the mean vector we are accumulating + for (unsigned long i = 0; i < alpha.size(); ++i) + { + alpha[i] = cscale*alpha[i] + xscale*a(i); + } + } + } + + bias_is_stale = true; + + return test_result; + } + + void remove_dictionary_vector ( + long i + ) + /*! + requires + - 0 <= i < dictionary.size() + ensures + - #dictionary.size() == dictionary.size() - 1 + - #alpha.size() == alpha.size() - 1 + - updates the K_inv matrix so that it is still a proper inverse of the + kernel matrix + - also removes the necessary row and column from the K matrix + - uses the this->a variable so after this function runs that variable + will contain a different value. + !*/ + { + // remove the dictionary vector + dictionary.erase(dictionary.begin()+i); + + // remove the i'th vector from the inverse kernel matrix. This formula is basically + // just the reverse of the way K_inv is updated by equation 3.14 during normal training. + K_inv = removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i); + + // now compute the updated alpha values to take account that we just removed one of + // our dictionary vectors + a = (K_inv*remove_row(K,i)*mat(alpha)); + + // now copy over the new alpha values + alpha.resize(alpha.size()-1); + for (unsigned long k = 0; k < alpha.size(); ++k) + { + alpha[k] = a(k); + } + + // update the K matrix as well + K = removerc(K,i,i); + } + + void recompute_min_strength ( + ) + /*! + ensures + - recomputes the min_strength and min_vect_idx values + so that they are correct with respect to the CONVENTION + - uses the this->a variable so after this function runs that variable + will contain a different value. + !*/ + { + min_strength = std::numeric_limits::max(); + + // here we loop over each dictionary vector and compute what its delta would be if + // we were to remove it from the dictionary and then try to add it back in. + for (unsigned long i = 0; i < dictionary.size(); ++i) + { + // compute a = K_inv*k but where dictionary vector i has been removed + a = (removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i)) * + (remove_row(colm(K,i),i)); + scalar_type delta = K(i,i) - trans(remove_row(colm(K,i),i))*a; + + if (delta < min_strength) + { + min_strength = delta; + min_vect_idx = i; + } + } + } + + + + typedef std_allocator alloc_sample_type; + typedef std_allocator alloc_scalar_type; + typedef std::vector dictionary_vector_type; + typedef std::vector alpha_vector_type; + + + scalar_type min_strength; + unsigned long min_vect_idx; + bool my_remove_oldest_first; + + kernel_type kernel; + dictionary_vector_type dictionary; + alpha_vector_type alpha; + + matrix K_inv; + matrix K; + + scalar_type my_tolerance; + unsigned long my_max_dictionary_size; + scalar_type samples_seen; + mutable scalar_type bias; + mutable bool bias_is_stale; + + + // temp variables here just so we don't have to reconstruct them over and over. Thus, + // they aren't really part of the state of this object. + matrix a; + matrix k; + + }; + +// ---------------------------------------------------------------------------------------- + + template + void swap(kcentroid& a, kcentroid& b) + { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KCENTROId_ + diff --git a/ml/dlib/dlib/svm/kcentroid_abstract.h b/ml/dlib/dlib/svm/kcentroid_abstract.h new file mode 100644 index 000000000..44b94c813 --- /dev/null +++ b/ml/dlib/dlib/svm/kcentroid_abstract.h @@ -0,0 +1,339 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_KCENTROId_ABSTRACT_ +#ifdef DLIB_KCENTROId_ABSTRACT_ + +#include "../algs.h" +#include "../serialize.h" +#include "kernel_abstract.h" + +namespace dlib +{ + + template < + typename kernel_type + > + class kcentroid + { + /*! + REQUIREMENTS ON kernel_type + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + INITIAL VALUE + - dictionary_size() == 0 + - samples_trained() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a weighted sum of sample points in a kernel induced + feature space. It can be used to kernelize any algorithm that requires only + the ability to perform vector addition, subtraction, scalar multiplication, + and inner products. + + An example use of this object is as an online algorithm for recursively estimating + the centroid of a sequence of training points. This object then allows you to + compute the distance between the centroid and any test points. So you can use + this object to predict how similar a test point is to the data this object has + been trained on (larger distances from the centroid indicate dissimilarity/anomalous + points). + + Also note that the algorithm internally keeps a set of "dictionary vectors" + that are used to represent the centroid. You can force the algorithm to use + no more than a set number of vectors by setting the 3rd constructor argument + to whatever you want. + + This object uses the sparsification technique described in the paper The + Kernel Recursive Least Squares Algorithm by Yaakov Engel. This technique + allows us to keep the number of dictionary vectors down to a minimum. In fact, + the object has a user selectable tolerance parameter that controls the trade off + between accuracy and number of stored dictionary vectors. + !*/ + + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + kcentroid ( + ); + /*! + ensures + - this object is properly initialized + - #tolerance() == 0.001 + - #get_kernel() == kernel_type() (i.e. whatever the kernel's default value is) + - #max_dictionary_size() == 1000000 + - #remove_oldest_first() == false + !*/ + + explicit kcentroid ( + const kernel_type& kernel_, + scalar_type tolerance_ = 0.001, + unsigned long max_dictionary_size_ = 1000000, + bool remove_oldest_first_ = false + ); + /*! + requires + - tolerance > 0 + - max_dictionary_size_ > 1 + ensures + - this object is properly initialized + - #tolerance() == tolerance_ + - #get_kernel() == kernel_ + - #max_dictionary_size() == max_dictionary_size_ + - #remove_oldest_first() == remove_oldest_first_ + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a const reference to the kernel used by this object + !*/ + + unsigned long max_dictionary_size( + ) const; + /*! + ensures + - returns the maximum number of dictionary vectors this object will + use at a time. That is, dictionary_size() will never be greater + than max_dictionary_size(). + !*/ + + bool remove_oldest_first ( + ) const; + /*! + ensures + - When the maximum dictionary size is reached this object sometimes + needs to discard dictionary vectors when new samples are added via + one of the train functions. When this happens this object chooses + the dictionary vector to discard based on the setting of the + remove_oldest_first() parameter. + - if (remove_oldest_first() == true) then + - This object discards the oldest dictionary vectors when necessary. + This is an appropriate mode when using this object in an online + setting and the input training samples come from a slowly + varying distribution. + - else (remove_oldest_first() == false) then + - This object discards the most linearly dependent dictionary vectors + when necessary. This it the default behavior and should be used + in most cases. + !*/ + + unsigned long dictionary_size ( + ) const; + /*! + ensures + - returns the number of basis vectors in the dictionary. These are + the basis vectors used by this object to represent a point in kernel + feature space. + !*/ + + scalar_type samples_trained ( + ) const; + /*! + ensures + - returns the number of samples this object has been trained on so far + !*/ + + scalar_type tolerance( + ) const; + /*! + ensures + - returns the tolerance to use for the approximately linearly dependent + test used for sparsification (see the KRLS paper for details). This is + a number which governs how accurately this object will approximate the + centroid it is learning. Smaller values generally result in a more + accurate estimate while also resulting in a bigger set of vectors in + the dictionary. Bigger tolerances values result in a less accurate + estimate but also in less dictionary vectors. (Note that in any case, + the max_dictionary_size() limits the number of dictionary vectors no + matter the setting of the tolerance) + - The exact meaning of the tolerance parameter is the following: + Imagine that we have an empirical_kernel_map that contains all + the current dictionary vectors. Then the tolerance is the minimum + projection error (as given by empirical_kernel_map::project()) required + to cause us to include a new vector in the dictionary. So each time + you call train() the kcentroid basically just computes the projection + error for that new sample and if it is larger than the tolerance + then that new sample becomes part of the dictionary. + !*/ + + void clear_dictionary ( + ); + /*! + ensures + - clears out all learned data (e.g. #dictionary_size() == 0) + - #samples_seen() == 0 + !*/ + + scalar_type operator() ( + const kcentroid& x + ) const; + /*! + requires + - x.get_kernel() == get_kernel() + ensures + - returns the distance in kernel feature space between this centroid and the + centroid represented by x. + !*/ + + scalar_type operator() ( + const sample_type& x + ) const; + /*! + ensures + - returns the distance in kernel feature space between the sample x and the + current estimate of the centroid of the training samples given + to this object so far. + !*/ + + scalar_type inner_product ( + const sample_type& x + ) const; + /*! + ensures + - returns the inner product of the given x point and the current + estimate of the centroid of the training samples given to this object + so far. + !*/ + + scalar_type inner_product ( + const kcentroid& x + ) const; + /*! + requires + - x.get_kernel() == get_kernel() + ensures + - returns the inner product between x and this centroid object. + !*/ + + scalar_type squared_norm ( + ) const; + /*! + ensures + - returns the squared norm of the centroid vector represented by this + object. I.e. returns this->inner_product(*this) + !*/ + + void train ( + const sample_type& x + ); + /*! + ensures + - adds the sample x into the current estimate of the centroid + - also note that calling this function is equivalent to calling + train(x, samples_trained()/(samples_trained()+1.0, 1.0/(samples_trained()+1.0). + That is, this function finds the normal unweighted centroid of all training points. + !*/ + + void train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ); + /*! + ensures + - adds the sample x into the current estimate of the centroid but + uses a user given scale. That is, this function performs: + - new_centroid = cscale*old_centroid + xscale*x + - This function allows you to weight different samples however + you want. + !*/ + + void scale_by ( + scalar_type cscale + ); + /*! + ensures + - multiplies the current centroid vector by the given scale value. + This function is equivalent to calling train(some_x_value, cscale, 0). + So it performs: + - new_centroid == cscale*old_centroid + !*/ + + scalar_type test_and_train ( + const sample_type& x + ); + /*! + ensures + - calls train(x) + - returns (*this)(x) + - The reason this function exists is because train() and operator() + both compute some of the same things. So this function is more efficient + than calling both individually. + !*/ + + scalar_type test_and_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ); + /*! + ensures + - calls train(x,cscale,xscale) + - returns (*this)(x) + - The reason this function exists is because train() and operator() + both compute some of the same things. So this function is more efficient + than calling both individually. + !*/ + + void swap ( + kcentroid& item + ); + /*! + ensures + - swaps *this with item + !*/ + + distance_function get_distance_function ( + ) const; + /*! + ensures + - returns a distance function F that represents the point learned + by this object so far. I.e. it is the case that: + - for all x: F(x) == (*this)(x) + !*/ + + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type + > + void swap( + kcentroid& a, + kcentroid& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename kernel_type + > + void serialize ( + const kcentroid& item, + std::ostream& out + ); + /*! + provides serialization support for kcentroid objects + !*/ + + template < + typename kernel_type + > + void deserialize ( + kcentroid& item, + std::istream& in + ); + /*! + provides serialization support for kcentroid objects + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KCENTROId_ABSTRACT_ + diff --git a/ml/dlib/dlib/svm/kcentroid_overloads.h b/ml/dlib/dlib/svm/kcentroid_overloads.h new file mode 100644 index 000000000..9c39f3d78 --- /dev/null +++ b/ml/dlib/dlib/svm/kcentroid_overloads.h @@ -0,0 +1,1324 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_KCENTROId_OVERLOADS_ +#define DLIB_KCENTROId_OVERLOADS_ + +#include "kcentroid_abstract.h" +#include "sparse_kernel.h" +#include "sparse_vector.h" +#include + +namespace dlib +{ + /* + This file contains optimized overloads of the kcentroid object for the following + linear cases: + kcentroid> + kcentroid> + kcentroid>> + kcentroid>> + */ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Overloads for when kernel_type == linear_kernel +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class kcentroid > + { + + + typedef linear_kernel kernel_type; + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + + explicit kcentroid ( + const kernel_type& kernel_, + scalar_type tolerance_ = 0.001, + unsigned long max_dictionary_size_ = 1000000, + bool remove_oldest_first_ = false + ) : + my_remove_oldest_first(remove_oldest_first_), + kernel(kernel_), + my_tolerance(tolerance_), + my_max_dictionary_size(max_dictionary_size_) + { + // make sure requires clause is not broken + DLIB_ASSERT(tolerance_ >= 0 && max_dictionary_size_ > 0, + "\tkcentroid::kcentroid()" + << "\n\t You have to give a positive tolerance" + << "\n\t this: " << this + << "\n\t tolerance_: " << tolerance_ + << "\n\t max_dictionary_size_: " << max_dictionary_size_ + ); + + clear_dictionary(); + } + + scalar_type tolerance() const { return my_tolerance; } + unsigned long max_dictionary_size() const { return my_max_dictionary_size; } + bool remove_oldest_first () const { return my_remove_oldest_first; } + const kernel_type& get_kernel () const { return kernel; } + scalar_type samples_trained () const { return samples_seen; } + + void clear_dictionary () + { + samples_seen = 0; + set_all_elements(w, 0); + alpha = 0; + } + + scalar_type operator() ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::operator()(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + if (w.size() > 0) + { + if (x.w.size() > 0) + return length(alpha*w - x.alpha*x.w); + else + return alpha*length(w); + } + else + { + if (x.w.size() > 0) + return x.alpha*length(x.w); + else + return 0; + } + } + + scalar_type inner_product ( + const sample_type& x + ) const + { + if (w.size() > 0) + return alpha*trans(w)*x; + else + return 0; + } + + scalar_type inner_product ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::inner_product(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + if (w.size() > 0 && x.w.size() > 0) + return alpha*x.alpha*trans(w)*x.w; + else + return 0; + } + + scalar_type squared_norm ( + ) const + { + if (w.size() > 0) + return alpha*alpha*trans(w)*w; + else + return 0; + } + + scalar_type operator() ( + const sample_type& x + ) const + { + if (w.size() > 0) + return length(x-alpha*w); + else + return length(x); + } + + scalar_type test_and_train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + + do_train(x, cscale, xscale); + + return (*this)(x); + } + + void train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + + do_train(x, cscale, xscale); + } + + scalar_type test_and_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + + do_train(x, cscale, xscale); + + return (*this)(x); + } + + void scale_by ( + scalar_type cscale + ) + { + alpha *= cscale; + } + + void train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + do_train(x, cscale, xscale); + } + + void swap ( + kcentroid& item + ) + { + exchange(my_remove_oldest_first, item.my_remove_oldest_first); + exchange(kernel, item.kernel); + exchange(w, item.w); + exchange(alpha, item.alpha); + exchange(my_tolerance, item.my_tolerance); + exchange(my_max_dictionary_size, item.my_max_dictionary_size); + exchange(samples_seen, item.samples_seen); + } + + unsigned long dictionary_size ( + ) const + { + if (samples_seen > 0) + return 1; + else + return 0; + } + + friend void serialize(const kcentroid& item, std::ostream& out) + { + serialize(item.my_remove_oldest_first, out); + serialize(item.kernel, out); + serialize(item.w, out); + serialize(item.alpha, out); + serialize(item.my_tolerance, out); + serialize(item.my_max_dictionary_size, out); + serialize(item.samples_seen, out); + } + + friend void deserialize(kcentroid& item, std::istream& in) + { + deserialize(item.my_remove_oldest_first, in); + deserialize(item.kernel, in); + deserialize(item.w, in); + deserialize(item.alpha, in); + deserialize(item.my_tolerance, in); + deserialize(item.my_max_dictionary_size, in); + deserialize(item.samples_seen, in); + } + + distance_function get_distance_function ( + ) const + { + if (samples_seen > 0) + { + typename distance_function::sample_vector_type temp_basis_vectors; + typename distance_function::scalar_vector_type temp_alpha; + + temp_basis_vectors.set_size(1); + temp_basis_vectors(0) = w; + temp_alpha.set_size(1); + temp_alpha(0) = alpha; + + return distance_function(temp_alpha, squared_norm(), kernel, temp_basis_vectors); + } + else + { + return distance_function(kernel); + } + } + + private: + + void do_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + set_size_of_w(x); + + const scalar_type temp = cscale*alpha; + + if (temp != 0) + { + w = w + xscale*x/temp; + alpha = temp; + } + else + { + w = cscale*alpha*w + xscale*x; + alpha = 1; + } + } + + void set_size_of_w ( + const sample_type& x + ) + { + if (x.size() != w.size()) + { + w.set_size(x.nr(), x.nc()); + set_all_elements(w, 0); + alpha = 0; + } + } + + bool my_remove_oldest_first; + + kernel_type kernel; + + sample_type w; + scalar_type alpha; + + + scalar_type my_tolerance; + unsigned long my_max_dictionary_size; + scalar_type samples_seen; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Overloads for when kernel_type == offset_kernel +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class kcentroid > > + { + + /*! + INITIAL VALUE + - x_extra == sqrt(kernel.offset) + + CONVENTION + - x_extra == sqrt(kernel.offset) + - w_extra == the value of the extra dimension tacked onto the + end of the w vector + !*/ + + typedef offset_kernel > kernel_type; + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + + explicit kcentroid ( + const kernel_type& kernel_, + scalar_type tolerance_ = 0.001, + unsigned long max_dictionary_size_ = 1000000, + bool remove_oldest_first_ = false + ) : + my_remove_oldest_first(remove_oldest_first_), + kernel(kernel_), + my_tolerance(tolerance_), + my_max_dictionary_size(max_dictionary_size_) + { + // make sure requires clause is not broken + DLIB_ASSERT(tolerance_ >= 0 && max_dictionary_size_ > 0, + "\tkcentroid::kcentroid()" + << "\n\t You have to give a positive tolerance" + << "\n\t this: " << this + << "\n\t tolerance_: " << tolerance_ + << "\n\t max_dictionary_size_: " << max_dictionary_size_ + ); + + x_extra = std::sqrt(kernel.offset); + + clear_dictionary(); + } + + scalar_type tolerance() const { return my_tolerance; } + unsigned long max_dictionary_size() const { return my_max_dictionary_size; } + bool remove_oldest_first () const { return my_remove_oldest_first; } + const kernel_type& get_kernel () const { return kernel; } + scalar_type samples_trained () const { return samples_seen; } + + void clear_dictionary () + { + samples_seen = 0; + set_all_elements(w, 0); + alpha = 0; + w_extra = x_extra; + } + + scalar_type operator() ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::operator()(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + if (w.size() > 0) + { + if (x.w.size() > 0) + { + scalar_type temp1 = length_squared(alpha*w - x.alpha*x.w); + scalar_type temp2 = alpha*w_extra - x.alpha*x.w_extra; + return std::sqrt(temp1 + temp2*temp2); + } + else + { + return alpha*std::sqrt(length_squared(w) + w_extra*w_extra); + } + } + else + { + if (x.w.size() > 0) + return x.alpha*std::sqrt(length_squared(x.w) + x.w_extra*x.w_extra); + else + return 0; + } + } + + scalar_type inner_product ( + const sample_type& x + ) const + { + if (w.size() > 0) + return alpha*(trans(w)*x + w_extra*x_extra); + else + return 0; + } + + scalar_type inner_product ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::inner_product(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + if (w.size() > 0 && x.w.size() > 0) + return alpha*x.alpha*(trans(w)*x.w + w_extra*x.w_extra); + else + return 0; + } + + scalar_type squared_norm ( + ) const + { + if (w.size() > 0) + return alpha*alpha*(trans(w)*w + w_extra*w_extra); + else + return 0; + } + + scalar_type operator() ( + const sample_type& x + ) const + { + if (w.size() > 0) + { + scalar_type temp1 = length_squared(x-alpha*w); + scalar_type temp2 = x_extra - alpha*w_extra; + return std::sqrt(temp1 + temp2*temp2); + } + else + { + return std::sqrt(length_squared(x) + x_extra*x_extra); + } + } + + scalar_type test_and_train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + + do_train(x, cscale, xscale); + + return (*this)(x); + } + + void train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + + do_train(x, cscale, xscale); + } + + scalar_type test_and_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + + do_train(x, cscale, xscale); + + return (*this)(x); + } + + void scale_by ( + scalar_type cscale + ) + { + alpha *= cscale; + w_extra *= cscale; + } + + void train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + do_train(x, cscale, xscale); + } + + void swap ( + kcentroid& item + ) + { + exchange(my_remove_oldest_first, item.my_remove_oldest_first); + exchange(kernel, item.kernel); + exchange(w, item.w); + exchange(alpha, item.alpha); + exchange(w_extra, item.w_extra); + exchange(x_extra, item.x_extra); + exchange(my_tolerance, item.my_tolerance); + exchange(my_max_dictionary_size, item.my_max_dictionary_size); + exchange(samples_seen, item.samples_seen); + } + + unsigned long dictionary_size ( + ) const + { + if (samples_seen > 0) + { + if (std::abs(w_extra) > std::numeric_limits::epsilon()) + return 1; + else + return 2; + } + else + return 0; + } + + friend void serialize(const kcentroid& item, std::ostream& out) + { + serialize(item.my_remove_oldest_first, out); + serialize(item.kernel, out); + serialize(item.w, out); + serialize(item.alpha, out); + serialize(item.w_extra, out); + serialize(item.x_extra, out); + serialize(item.my_tolerance, out); + serialize(item.my_max_dictionary_size, out); + serialize(item.samples_seen, out); + } + + friend void deserialize(kcentroid& item, std::istream& in) + { + deserialize(item.my_remove_oldest_first, in); + deserialize(item.kernel, in); + deserialize(item.w, in); + deserialize(item.alpha, in); + deserialize(item.w_extra, in); + deserialize(item.x_extra, in); + deserialize(item.my_tolerance, in); + deserialize(item.my_max_dictionary_size, in); + deserialize(item.samples_seen, in); + } + + distance_function get_distance_function ( + ) const + { + + if (samples_seen > 0) + { + typename distance_function::sample_vector_type temp_basis_vectors; + typename distance_function::scalar_vector_type temp_alpha; + + // What we are doing here needs a bit of explanation. The w vector + // has an implicit extra dimension tacked on to it with the value of w_extra. + // The kernel we are using takes normal vectors and implicitly tacks the value + // x_extra onto their end. So what we are doing here is scaling w so that + // the value it should have tacked onto it is x_scale. Note that we also + // adjust alpha so that the combination of alpha*w stays the same. + scalar_type scale; + + // if w_extra is basically greater than 0 + if (std::abs(w_extra) > std::numeric_limits::epsilon()) + { + scale = (x_extra/w_extra); + temp_basis_vectors.set_size(1); + temp_alpha.set_size(1); + temp_basis_vectors(0) = w*scale; + temp_alpha(0) = alpha/scale; + } + else + { + // In this case w_extra is zero. So the only way we can get the same + // thing in the output basis vector set is by using two vectors + temp_basis_vectors.set_size(2); + temp_alpha.set_size(2); + temp_basis_vectors(0) = 2*w; + temp_alpha(0) = alpha; + temp_basis_vectors(1) = w; + temp_alpha(1) = -alpha; + } + + + return distance_function(temp_alpha, squared_norm(), kernel, temp_basis_vectors); + } + else + { + return distance_function(kernel); + } + } + + private: + + void do_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + set_size_of_w(x); + + const scalar_type temp = cscale*alpha; + + if (temp != 0) + { + w = w + xscale*x/temp; + w_extra = w_extra + xscale*x_extra/temp; + alpha = temp; + } + else + { + w = cscale*alpha*w + xscale*x; + w_extra = cscale*alpha*w_extra + xscale*x_extra; + alpha = 1; + } + } + + void set_size_of_w ( + const sample_type& x + ) + { + if (x.size() != w.size()) + { + w.set_size(x.nr(), x.nc()); + set_all_elements(w, 0); + alpha = 0; + w_extra = x_extra; + } + } + + bool my_remove_oldest_first; + + kernel_type kernel; + + sample_type w; + scalar_type alpha; + + scalar_type w_extra; + scalar_type x_extra; + + + scalar_type my_tolerance; + unsigned long my_max_dictionary_size; + scalar_type samples_seen; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Overloads for when kernel_type == sparse_linear_kernel +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class kcentroid > + { + + + typedef sparse_linear_kernel kernel_type; + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + + explicit kcentroid ( + const kernel_type& kernel_, + scalar_type tolerance_ = 0.001, + unsigned long max_dictionary_size_ = 1000000, + bool remove_oldest_first_ = false + ) : + my_remove_oldest_first(remove_oldest_first_), + kernel(kernel_), + my_tolerance(tolerance_), + my_max_dictionary_size(max_dictionary_size_) + { + // make sure requires clause is not broken + DLIB_ASSERT(tolerance_ >= 0 && max_dictionary_size_ > 0, + "\tkcentroid::kcentroid()" + << "\n\t You have to give a positive tolerance" + << "\n\t this: " << this + << "\n\t tolerance_: " << tolerance_ + << "\n\t max_dictionary_size_: " << max_dictionary_size_ + ); + + clear_dictionary(); + } + + scalar_type tolerance() const { return my_tolerance; } + unsigned long max_dictionary_size() const { return my_max_dictionary_size; } + bool remove_oldest_first () const { return my_remove_oldest_first; } + const kernel_type& get_kernel () const { return kernel; } + scalar_type samples_trained () const { return samples_seen; } + + void clear_dictionary () + { + samples_seen = 0; + w.clear(); + alpha = 0; + } + + scalar_type operator() ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::operator()(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + return distance(alpha,w , x.alpha,x.w); + } + + scalar_type inner_product ( + const sample_type& x + ) const + { + return alpha*dot(w,x); + } + + scalar_type inner_product ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::inner_product(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + return alpha*x.alpha*dot(w,x.w); + } + + scalar_type squared_norm ( + ) const + { + return alpha*alpha*length_squared(w); + } + + scalar_type operator() ( + const sample_type& x + ) const + { + return distance(static_cast(1), x, alpha, w); + } + + scalar_type test_and_train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + + do_train(x, cscale, xscale); + + return (*this)(x); + } + + void train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + + do_train(x, cscale, xscale); + } + + scalar_type test_and_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + + do_train(x, cscale, xscale); + + return (*this)(x); + } + + void scale_by ( + scalar_type cscale + ) + { + alpha *= cscale; + } + + void train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + do_train(x, cscale, xscale); + } + + void swap ( + kcentroid& item + ) + { + exchange(my_remove_oldest_first, item.my_remove_oldest_first); + exchange(kernel, item.kernel); + exchange(w, item.w); + exchange(alpha, item.alpha); + exchange(my_tolerance, item.my_tolerance); + exchange(my_max_dictionary_size, item.my_max_dictionary_size); + exchange(samples_seen, item.samples_seen); + } + + unsigned long dictionary_size ( + ) const + { + if (samples_seen > 0) + return 1; + else + return 0; + } + + friend void serialize(const kcentroid& item, std::ostream& out) + { + serialize(item.my_remove_oldest_first, out); + serialize(item.kernel, out); + serialize(item.w, out); + serialize(item.alpha, out); + serialize(item.my_tolerance, out); + serialize(item.my_max_dictionary_size, out); + serialize(item.samples_seen, out); + } + + friend void deserialize(kcentroid& item, std::istream& in) + { + deserialize(item.my_remove_oldest_first, in); + deserialize(item.kernel, in); + deserialize(item.w, in); + deserialize(item.alpha, in); + deserialize(item.my_tolerance, in); + deserialize(item.my_max_dictionary_size, in); + deserialize(item.samples_seen, in); + } + + distance_function get_distance_function ( + ) const + { + if (samples_seen > 0) + { + typename distance_function::sample_vector_type temp_basis_vectors; + typename distance_function::scalar_vector_type temp_alpha; + + temp_basis_vectors.set_size(1); + temp_basis_vectors(0) = sample_type(w.begin(), w.end()); + temp_alpha.set_size(1); + temp_alpha(0) = alpha; + + return distance_function(temp_alpha, squared_norm(), kernel, temp_basis_vectors); + } + else + { + return distance_function(kernel); + } + } + + private: + + void do_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + const scalar_type temp = cscale*alpha; + + if (temp != 0) + { + // compute w += xscale*x/temp + typename sample_type::const_iterator i; + for (i = x.begin(); i != x.end(); ++i) + { + w[i->first] += xscale*(i->second)/temp; + } + + alpha = temp; + } + else + { + // first compute w = cscale*alpha*w + for (typename std::map::iterator i = w.begin(); i != w.end(); ++i) + { + i->second *= cscale*alpha; + } + + // now compute w += xscale*x + for (typename sample_type::const_iterator i = x.begin(); i != x.end(); ++i) + { + w[i->first] += xscale*(i->second); + } + + alpha = 1; + } + } + + bool my_remove_oldest_first; + + kernel_type kernel; + + std::map w; + scalar_type alpha; + + + scalar_type my_tolerance; + unsigned long my_max_dictionary_size; + scalar_type samples_seen; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Overloads for when kernel_type == offset_kernel +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class kcentroid > > + { + + /*! + INITIAL VALUE + - x_extra == sqrt(kernel.offset) + + CONVENTION + - x_extra == sqrt(kernel.offset) + - w_extra == the value of the extra dimension tacked onto the + end of the w vector + !*/ + + typedef offset_kernel > kernel_type; + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + + explicit kcentroid ( + const kernel_type& kernel_, + scalar_type tolerance_ = 0.001, + unsigned long max_dictionary_size_ = 1000000, + bool remove_oldest_first_ = false + ) : + my_remove_oldest_first(remove_oldest_first_), + kernel(kernel_), + my_tolerance(tolerance_), + my_max_dictionary_size(max_dictionary_size_) + { + // make sure requires clause is not broken + DLIB_ASSERT(tolerance_ >= 0 && max_dictionary_size_ > 0, + "\tkcentroid::kcentroid()" + << "\n\t You have to give a positive tolerance" + << "\n\t this: " << this + << "\n\t tolerance_: " << tolerance_ + << "\n\t max_dictionary_size_: " << max_dictionary_size_ + ); + + x_extra = std::sqrt(kernel.offset); + + clear_dictionary(); + } + + scalar_type tolerance() const { return my_tolerance; } + unsigned long max_dictionary_size() const { return my_max_dictionary_size; } + bool remove_oldest_first () const { return my_remove_oldest_first; } + const kernel_type& get_kernel () const { return kernel; } + scalar_type samples_trained () const { return samples_seen; } + + void clear_dictionary () + { + samples_seen = 0; + w.clear(); + alpha = 0; + w_extra = x_extra; + } + + scalar_type operator() ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::operator()(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + if (samples_seen > 0) + { + scalar_type temp1 = distance_squared(alpha,w , x.alpha,x.w); + scalar_type temp2 = alpha*w_extra - x.alpha*x.w_extra; + return std::sqrt(temp1 + temp2*temp2); + } + else + { + return 0; + } + } + + scalar_type inner_product ( + const sample_type& x + ) const + { + if (samples_seen > 0) + return alpha*(dot(w,x) + w_extra*x_extra); + else + return 0; + } + + scalar_type inner_product ( + const kcentroid& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(x.get_kernel() == get_kernel(), + "\tscalar_type kcentroid::inner_product(const kcentroid& x)" + << "\n\tYou can only compare two kcentroid objects if they use the same kernel" + << "\n\tthis: " << this + ); + + if (samples_seen > 0 && x.samples_seen > 0) + return alpha*x.alpha*(dot(w,x.w) + w_extra*x.w_extra); + else + return 0; + } + + scalar_type squared_norm ( + ) const + { + if (samples_seen > 0) + return alpha*alpha*(length_squared(w) + w_extra*w_extra); + else + return 0; + } + + scalar_type operator() ( + const sample_type& x + ) const + { + if (samples_seen > 0) + { + scalar_type temp1 = distance_squared(1,x,alpha,w); + scalar_type temp2 = x_extra - alpha*w_extra; + return std::sqrt(temp1 + temp2*temp2); + } + else + { + return std::sqrt(length_squared(x) + x_extra*x_extra); + } + } + + scalar_type test_and_train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + + do_train(x, cscale, xscale); + + return (*this)(x); + } + + void train ( + const sample_type& x + ) + { + ++samples_seen; + const scalar_type xscale = 1/samples_seen; + const scalar_type cscale = 1-xscale; + + do_train(x, cscale, xscale); + } + + scalar_type test_and_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + + do_train(x, cscale, xscale); + + return (*this)(x); + } + + void scale_by ( + scalar_type cscale + ) + { + alpha *= cscale; + w_extra *= cscale; + } + + void train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + ++samples_seen; + do_train(x, cscale, xscale); + } + + void swap ( + kcentroid& item + ) + { + exchange(my_remove_oldest_first, item.my_remove_oldest_first); + exchange(kernel, item.kernel); + exchange(w, item.w); + exchange(alpha, item.alpha); + exchange(w_extra, item.w_extra); + exchange(x_extra, item.x_extra); + exchange(my_tolerance, item.my_tolerance); + exchange(my_max_dictionary_size, item.my_max_dictionary_size); + exchange(samples_seen, item.samples_seen); + } + + unsigned long dictionary_size ( + ) const + { + if (samples_seen > 0) + { + if (std::abs(w_extra) > std::numeric_limits::epsilon()) + return 1; + else + return 2; + } + else + { + return 0; + } + } + + friend void serialize(const kcentroid& item, std::ostream& out) + { + serialize(item.my_remove_oldest_first, out); + serialize(item.kernel, out); + serialize(item.w, out); + serialize(item.alpha, out); + serialize(item.w_extra, out); + serialize(item.x_extra, out); + serialize(item.my_tolerance, out); + serialize(item.my_max_dictionary_size, out); + serialize(item.samples_seen, out); + } + + friend void deserialize(kcentroid& item, std::istream& in) + { + deserialize(item.my_remove_oldest_first, in); + deserialize(item.kernel, in); + deserialize(item.w, in); + deserialize(item.alpha, in); + deserialize(item.w_extra, in); + deserialize(item.x_extra, in); + deserialize(item.my_tolerance, in); + deserialize(item.my_max_dictionary_size, in); + deserialize(item.samples_seen, in); + } + + distance_function get_distance_function ( + ) const + { + if (samples_seen > 0) + { + typename distance_function::sample_vector_type temp_basis_vectors; + typename distance_function::scalar_vector_type temp_alpha; + + // What we are doing here needs a bit of explanation. The w vector + // has an implicit extra dimension tacked on to it with the value of w_extra. + // The kernel we are using takes normal vectors and implicitly tacks the value + // x_extra onto their end. So what we are doing here is scaling w so that + // the value it should have tacked onto it is x_scale. Note that we also + // adjust alpha so that the combination of alpha*w stays the same. + scalar_type scale; + + // if w_extra is basically greater than 0 + if (std::abs(w_extra) > std::numeric_limits::epsilon()) + { + scale = (x_extra/w_extra); + temp_basis_vectors.set_size(1); + temp_alpha.set_size(1); + temp_basis_vectors(0) = sample_type(w.begin(), w.end()); + dlib::scale_by(temp_basis_vectors(0), scale); + temp_alpha(0) = alpha/scale; + } + else + { + // In this case w_extra is zero. So the only way we can get the same + // thing in the output basis vector set is by using two vectors + temp_basis_vectors.set_size(2); + temp_alpha.set_size(2); + temp_basis_vectors(0) = sample_type(w.begin(), w.end()); + dlib::scale_by(temp_basis_vectors(0), 2); + temp_alpha(0) = alpha; + temp_basis_vectors(1) = sample_type(w.begin(), w.end()); + temp_alpha(1) = -alpha; + } + + return distance_function(temp_alpha, squared_norm(), kernel, temp_basis_vectors); + + } + else + { + return distance_function(kernel); + } + + } + + private: + + void do_train ( + const sample_type& x, + scalar_type cscale, + scalar_type xscale + ) + { + + const scalar_type temp = cscale*alpha; + + if (temp != 0) + { + // compute w += xscale*x/temp + typename sample_type::const_iterator i; + for (i = x.begin(); i != x.end(); ++i) + { + w[i->first] += xscale*(i->second)/temp; + } + + w_extra = w_extra + xscale*x_extra/temp; + alpha = temp; + } + else + { + // first compute w = cscale*alpha*w + for (typename std::map::iterator i = w.begin(); i != w.end(); ++i) + { + i->second *= cscale*alpha; + } + + // now compute w += xscale*x + for (typename sample_type::const_iterator i = x.begin(); i != x.end(); ++i) + { + w[i->first] += xscale*(i->second); + } + + + w_extra = cscale*alpha*w_extra + xscale*x_extra; + alpha = 1; + } + } + + bool my_remove_oldest_first; + + kernel_type kernel; + + std::map w; + scalar_type alpha; + + scalar_type w_extra; + scalar_type x_extra; + + + scalar_type my_tolerance; + unsigned long my_max_dictionary_size; + scalar_type samples_seen; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KCENTROId_OVERLOADS_ + + diff --git a/ml/dlib/dlib/svm/kernel.h b/ml/dlib/dlib/svm/kernel.h new file mode 100644 index 000000000..907420986 --- /dev/null +++ b/ml/dlib/dlib/svm/kernel.h @@ -0,0 +1,569 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_KERNEL +#define DLIB_SVm_KERNEL + +#include "kernel_abstract.h" +#include +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < typename kernel_type > struct kernel_derivative; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct radial_basis_kernel + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + // T must be capable of representing a column vector. + COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0); + + radial_basis_kernel(const scalar_type g) : gamma(g) {} + radial_basis_kernel() : gamma(0.1) {} + radial_basis_kernel( + const radial_basis_kernel& k + ) : gamma(k.gamma) {} + + + const scalar_type gamma; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + const scalar_type d = trans(a-b)*(a-b); + return std::exp(-gamma*d); + } + + radial_basis_kernel& operator= ( + const radial_basis_kernel& k + ) + { + const_cast(gamma) = k.gamma; + return *this; + } + + bool operator== ( + const radial_basis_kernel& k + ) const + { + return gamma == k.gamma; + } + }; + + template < + typename T + > + void serialize ( + const radial_basis_kernel& item, + std::ostream& out + ) + { + try + { + serialize(item.gamma, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type radial_basis_kernel"); + } + } + + template < + typename T + > + void deserialize ( + radial_basis_kernel& item, + std::istream& in + ) + { + typedef typename T::type scalar_type; + try + { + deserialize(const_cast(item.gamma), in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type radial_basis_kernel"); + } + } + + template < + typename T + > + struct kernel_derivative > + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + kernel_derivative(const radial_basis_kernel& k_) : k(k_){} + + const sample_type& operator() (const sample_type& x, const sample_type& y) const + { + // return the derivative of the rbf kernel + temp = 2*k.gamma*(x-y)*k(x,y); + return temp; + } + + const radial_basis_kernel& k; + mutable sample_type temp; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct polynomial_kernel + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + // T must be capable of representing a column vector. + COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0); + + polynomial_kernel(const scalar_type g, const scalar_type c, const scalar_type d) : gamma(g), coef(c), degree(d) {} + polynomial_kernel() : gamma(1), coef(0), degree(1) {} + polynomial_kernel( + const polynomial_kernel& k + ) : gamma(k.gamma), coef(k.coef), degree(k.degree) {} + + typedef T type; + const scalar_type gamma; + const scalar_type coef; + const scalar_type degree; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return std::pow(gamma*(trans(a)*b) + coef, degree); + } + + polynomial_kernel& operator= ( + const polynomial_kernel& k + ) + { + const_cast(gamma) = k.gamma; + const_cast(coef) = k.coef; + const_cast(degree) = k.degree; + return *this; + } + + bool operator== ( + const polynomial_kernel& k + ) const + { + return (gamma == k.gamma) && (coef == k.coef) && (degree == k.degree); + } + }; + + template < + typename T + > + void serialize ( + const polynomial_kernel& item, + std::ostream& out + ) + { + try + { + serialize(item.gamma, out); + serialize(item.coef, out); + serialize(item.degree, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type polynomial_kernel"); + } + } + + template < + typename T + > + void deserialize ( + polynomial_kernel& item, + std::istream& in + ) + { + typedef typename T::type scalar_type; + try + { + deserialize(const_cast(item.gamma), in); + deserialize(const_cast(item.coef), in); + deserialize(const_cast(item.degree), in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type polynomial_kernel"); + } + } + + template < + typename T + > + struct kernel_derivative > + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + kernel_derivative(const polynomial_kernel& k_) : k(k_){} + + const sample_type& operator() (const sample_type& x, const sample_type& y) const + { + // return the derivative of the rbf kernel + temp = k.degree*k.gamma*x*std::pow(k.gamma*(trans(x)*y) + k.coef, k.degree-1); + return temp; + } + + const polynomial_kernel& k; + mutable sample_type temp; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sigmoid_kernel + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + // T must be capable of representing a column vector. + COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0); + + sigmoid_kernel(const scalar_type g, const scalar_type c) : gamma(g), coef(c) {} + sigmoid_kernel() : gamma(0.1), coef(-1.0) {} + sigmoid_kernel( + const sigmoid_kernel& k + ) : gamma(k.gamma), coef(k.coef) {} + + typedef T type; + const scalar_type gamma; + const scalar_type coef; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return std::tanh(gamma*(trans(a)*b) + coef); + } + + sigmoid_kernel& operator= ( + const sigmoid_kernel& k + ) + { + const_cast(gamma) = k.gamma; + const_cast(coef) = k.coef; + return *this; + } + + bool operator== ( + const sigmoid_kernel& k + ) const + { + return (gamma == k.gamma) && (coef == k.coef); + } + }; + + template < + typename T + > + void serialize ( + const sigmoid_kernel& item, + std::ostream& out + ) + { + try + { + serialize(item.gamma, out); + serialize(item.coef, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type sigmoid_kernel"); + } + } + + template < + typename T + > + void deserialize ( + sigmoid_kernel& item, + std::istream& in + ) + { + typedef typename T::type scalar_type; + try + { + deserialize(const_cast(item.gamma), in); + deserialize(const_cast(item.coef), in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type sigmoid_kernel"); + } + } + + template < + typename T + > + struct kernel_derivative > + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + kernel_derivative(const sigmoid_kernel& k_) : k(k_){} + + const sample_type& operator() (const sample_type& x, const sample_type& y) const + { + // return the derivative of the rbf kernel + temp = k.gamma*x*(1-std::pow(k(x,y),2)); + return temp; + } + + const sigmoid_kernel& k; + mutable sample_type temp; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct linear_kernel + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + // T must be capable of representing a column vector. + COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0); + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return trans(a)*b; + } + + bool operator== ( + const linear_kernel& + ) const + { + return true; + } + }; + + template < + typename T + > + void serialize ( + const linear_kernel& , + std::ostream& + ){} + + template < + typename T + > + void deserialize ( + linear_kernel& , + std::istream& + ){} + + template < + typename T + > + struct kernel_derivative > + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + kernel_derivative(const linear_kernel& k_) : k(k_){} + + const sample_type& operator() (const sample_type& x, const sample_type& ) const + { + return x; + } + + const linear_kernel& k; + }; + +// ---------------------------------------------------------------------------------------- + + template + struct histogram_intersection_kernel + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + scalar_type temp = 0; + for (long i = 0; i < a.size(); ++i) + { + temp += std::min(a(i), b(i)); + } + return temp; + } + + bool operator== ( + const histogram_intersection_kernel& + ) const + { + return true; + } + }; + + template < + typename T + > + void serialize ( + const histogram_intersection_kernel& , + std::ostream& + ){} + + template < + typename T + > + void deserialize ( + histogram_intersection_kernel& , + std::istream& + ){} + +// ---------------------------------------------------------------------------------------- + + template + struct offset_kernel + { + typedef typename T::scalar_type scalar_type; + typedef typename T::sample_type sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + offset_kernel(const T& k, const scalar_type& offset_ + ) : kernel(k), offset(offset_) {} + offset_kernel() : kernel(T()), offset(0.01) {} + offset_kernel( + const offset_kernel& k + ) : kernel(k.kernel), offset(k.offset) {} + + const T kernel; + const scalar_type offset; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return kernel(a,b) + offset; + } + + offset_kernel& operator= ( + const offset_kernel& k + ) + { + const_cast(kernel) = k.kernel; + const_cast(offset) = k.offset; + return *this; + } + + bool operator== ( + const offset_kernel& k + ) const + { + return k.kernel == kernel && offset == k.offset; + } + }; + + template < + typename T + > + void serialize ( + const offset_kernel& item, + std::ostream& out + ) + { + try + { + serialize(item.offset, out); + serialize(item.kernel, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type offset_kernel"); + } + } + + template < + typename T + > + void deserialize ( + offset_kernel& item, + std::istream& in + ) + { + typedef typename offset_kernel::scalar_type scalar_type; + try + { + deserialize(const_cast(item.offset), in); + deserialize(const_cast(item.kernel), in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type offset_kernel"); + } + } + + template < + typename T + > + struct kernel_derivative > + { + typedef typename T::scalar_type scalar_type; + typedef typename T::sample_type sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + kernel_derivative(const offset_kernel& k) : der(k.kernel){} + + const sample_type operator() (const sample_type& x, const sample_type& y) const + { + return der(x,y); + } + + kernel_derivative der; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_KERNEL + + diff --git a/ml/dlib/dlib/svm/kernel_abstract.h b/ml/dlib/dlib/svm/kernel_abstract.h new file mode 100644 index 000000000..f72430eb8 --- /dev/null +++ b/ml/dlib/dlib/svm/kernel_abstract.h @@ -0,0 +1,681 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_KERNEL_ABSTRACT_ +#ifdef DLIB_SVm_KERNEL_ABSTRACT_ + +#include +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "../serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +/*!A Kernel_Function_Objects */ +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + /*! + WHAT IS A KERNEL FUNCTION OBJECT? + In the context of the dlib library documentation a kernel function object + is an object with an interface with the following properties: + - a public typedef named sample_type + - a public typedef named scalar_type which should be a float, double, or + long double type. + - an overloaded operator() that operates on two items of sample_type + and returns a scalar_type. + (e.g. scalar_type val = kernel_function(sample1,sample2); + would be a valid expression) + - a public typedef named mem_manager_type that is an implementation of + dlib/memory_manager/memory_manager_kernel_abstract.h or + dlib/memory_manager_global/memory_manager_global_kernel_abstract.h or + dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + - an overloaded == operator that tells you if two kernels are + identical or not. + + THREAD SAFETY + For a kernel function to be threadsafe it means that it must be safe to + evaluate an expression like val = kernel_function(sample1,sample2) + simultaneously from multiple threads, even when the threads operate on the same + object instances (i.e. kernel_function, sample1, and sample2). The most common + way to make this safe is to ensure that the kernel function does not mutate any + data, either in itself or in its arguments. + + For examples of kernel functions see the following objects + (e.g. the radial_basis_kernel). + !*/ + + template < + typename T + > + struct radial_basis_kernel + { + /*! + REQUIREMENTS ON T + T must be a dlib::matrix object + + WHAT THIS OBJECT REPRESENTS + This object represents a radial basis function kernel + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + const scalar_type gamma; + + radial_basis_kernel( + ); + /*! + ensures + - #gamma == 0.1 + !*/ + + radial_basis_kernel( + const radial_basis_kernel& k + ); + /*! + ensures + - #gamma == k.gamma + !*/ + + radial_basis_kernel( + const scalar_type g + ); + /*! + ensures + - #gamma == g + !*/ + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a.nc() == 1 + - b.nc() == 1 + - a.nr() == b.nr() + ensures + - returns exp(-gamma * ||a-b||^2) + !*/ + + radial_basis_kernel& operator= ( + const radial_basis_kernel& k + ); + /*! + ensures + - #gamma = k.gamma + - returns *this + !*/ + + bool operator== ( + const radial_basis_kernel& k + ) const; + /*! + ensures + - if (k and *this are identical) then + - returns true + - else + - returns false + !*/ + + }; + + template < + typename T + > + void serialize ( + const radial_basis_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for radial_basis_kernel + !*/ + + template < + typename T + > + void deserialize ( + radial_basis_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for radial_basis_kernel + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sigmoid_kernel + { + /*! + REQUIREMENTS ON T + T must be a dlib::matrix object + + WHAT THIS OBJECT REPRESENTS + This object represents a sigmoid kernel + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + const scalar_type gamma; + const scalar_type coef; + + sigmoid_kernel( + ); + /*! + ensures + - #gamma == 0.1 + - #coef == -1.0 + !*/ + + sigmoid_kernel( + const sigmoid_kernel& k + ); + /*! + ensures + - #gamma == k.gamma + - #coef == k.coef + !*/ + + sigmoid_kernel( + const scalar_type g, + const scalar_type c + ); + /*! + ensures + - #gamma == g + - #coef == c + !*/ + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a.nc() == 1 + - b.nc() == 1 + - a.nr() == b.nr() + ensures + - returns tanh(gamma*trans(a)*b + coef) + !*/ + + sigmoid_kernel& operator= ( + const sigmoid_kernel& k + ); + /*! + ensures + - #gamma = k.gamma + - #coef = k.coef + - returns *this + !*/ + + bool operator== ( + const sigmoid_kernel& k + ) const; + /*! + ensures + - if (k and *this are identical) then + - returns true + - else + - returns false + !*/ + }; + + template < + typename T + > + void serialize ( + const sigmoid_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for sigmoid_kernel + !*/ + + template < + typename T + > + void deserialize ( + sigmoid_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for sigmoid_kernel + !*/ + + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct polynomial_kernel + { + /*! + REQUIREMENTS ON T + T must be a dlib::matrix object + + WHAT THIS OBJECT REPRESENTS + This object represents a polynomial kernel + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + const scalar_type gamma; + const scalar_type coef; + const scalar_type degree; + + polynomial_kernel( + ); + /*! + ensures + - #gamma == 1 + - #coef == 0 + - #degree == 1 + !*/ + + polynomial_kernel( + const polynomial_kernel& k + ); + /*! + ensures + - #gamma == k.gamma + - #coef == k.coef + - #degree == k.degree + !*/ + + polynomial_kernel( + const scalar_type g, + const scalar_type c, + const scalar_type d + ); + /*! + ensures + - #gamma == g + - #coef == c + - #degree == d + !*/ + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a.nc() == 1 + - b.nc() == 1 + - a.nr() == b.nr() + ensures + - returns pow(gamma*trans(a)*b + coef, degree) + !*/ + + polynomial_kernel& operator= ( + const polynomial_kernel& k + ); + /*! + ensures + - #gamma = k.gamma + - #coef = k.coef + - #degree = k.degree + - returns *this + !*/ + + bool operator== ( + const polynomial_kernel& k + ) const; + /*! + ensures + - if (k and *this are identical) then + - returns true + - else + - returns false + !*/ + }; + + template < + typename T + > + void serialize ( + const polynomial_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for polynomial_kernel + !*/ + + template < + typename T + > + void deserialize ( + polynomial_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for polynomial_kernel + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct linear_kernel + { + /*! + REQUIREMENTS ON T + T must be a dlib::matrix object + + WHAT THIS OBJECT REPRESENTS + This object represents a linear function kernel + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a.nc() == 1 + - b.nc() == 1 + - a.nr() == b.nr() + ensures + - returns trans(a)*b + !*/ + + bool operator== ( + const linear_kernel& k + ) const; + /*! + ensures + - returns true + !*/ + }; + + template < + typename T + > + void serialize ( + const linear_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for linear_kernel + !*/ + + template < + typename T + > + void deserialize ( + linear_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for linear_kernel + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct histogram_intersection_kernel + { + /*! + REQUIREMENTS ON T + T must be a dlib::matrix object + + WHAT THIS OBJECT REPRESENTS + This object represents a histogram intersection kernel kernel + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - is_vector(a) + - is_vector(b) + - a.size() == b.size() + - min(a) >= 0 + - min(b) >= 0 + ensures + - returns sum over all i: std::min(a(i), b(i)) + !*/ + + bool operator== ( + const histogram_intersection_kernel& k + ) const; + /*! + ensures + - returns true + !*/ + }; + + template < + typename T + > + void serialize ( + const histogram_intersection_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for histogram_intersection_kernel + !*/ + + template < + typename T + > + void deserialize ( + histogram_intersection_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for histogram_intersection_kernel + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct offset_kernel + { + /*! + REQUIREMENTS ON T + T must be a kernel object (e.g. radial_basis_kernel, polynomial_kernel, etc.) + + WHAT THIS OBJECT REPRESENTS + This object represents a kernel with a fixed value offset + added to it. + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::scalar_type scalar_type; + typedef typename T::sample_type sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + const T kernel; + const scalar_type offset; + + offset_kernel( + ); + /*! + ensures + - #offset == 0.01 + !*/ + + offset_kernel( + const offset_kernel& k + ); + /*! + ensures + - #offset == k.offset + - #kernel == k.kernel + !*/ + + offset_kernel( + const T& k, + const scalar_type& off + ); + /*! + ensures + - #kernel == k + - #offset == off + !*/ + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + ensures + - returns kernel(a,b) + offset + !*/ + + offset_kernel& operator= ( + const offset_kernel& k + ); + /*! + ensures + - #offset == k.offset + - #kernel == k.kernel + !*/ + + bool operator== ( + const offset_kernel& k + ) const; + /*! + ensures + - if (k and *this are identical) then + - returns true + - else + - returns false + !*/ + }; + + template < + typename T + > + void serialize ( + const offset_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for offset_kernel + !*/ + + template < + typename T + > + void deserialize ( + offset_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for offset_kernel + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type + > + struct kernel_derivative + { + /*! + REQUIREMENTS ON kernel_type + kernel_type must be one of the following kernel types: + - radial_basis_kernel + - polynomial_kernel + - sigmoid_kernel + - linear_kernel + - offset_kernel + + WHAT THIS OBJECT REPRESENTS + This is a function object that computes the derivative of a kernel + function object. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + Instances of this object are allowed to have a mutable cache which is + used by const member functions. Therefore, it is not safe to use one + instance of this object from multiple threads (unless protected by a + mutex). + !*/ + + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + kernel_derivative( + const kernel_type& k_ + ); + /*! + ensures + - this object will return derivatives of the kernel object k_ + - #k == k_ + !*/ + + const sample_type operator() ( + const sample_type& x, + const sample_type& y + ) const; + /*! + ensures + - returns the derivative of k with respect to y. + !*/ + + const kernel_type& k; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_KERNEL_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/kernel_matrix.h b/ml/dlib/dlib/svm/kernel_matrix.h new file mode 100644 index 000000000..f6e1e0b90 --- /dev/null +++ b/ml/dlib/dlib/svm/kernel_matrix.h @@ -0,0 +1,268 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_KERNEL_MATRIX_ +#define DLIB_SVm_KERNEL_MATRIX_ + +#include +#include "kernel_matrix_abstract.h" +#include "../matrix.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline const typename T::type& access ( const matrix_exp& m, long i) + { + return m(i); + } + + // bind to anything that looks like an array and isn't a matrix + template + inline const typename disable_if,typename T::type>::type& access ( const T& m, long i) + { + return m[i]; + } + + // Only use this function if T isn't a std::pair because in that case the entire vector is + // probably itself a sparse sample. + template + inline typename disable_if,const T&>::type access ( const std::vector& m, long i) + { + return m[i]; + } + + // Only use this function if T isn't a std::pair because in that case the entire vector is + // probably a sparse sample. + template + inline typename disable_if,const T&>::type access ( const std_vector_c& m, long i) + { + return m[i]; + } + + template + inline const typename kernel_type::sample_type& access ( + const typename kernel_type::sample_type& samp, + long + ) + { + return samp; + } + + // -------------------------------------------- + + template + inline typename disable_if,unsigned long>::type + size ( const T& m) + { + return m.size(); + } + + template + inline size_t size ( + const typename kernel_type::sample_type& + ) + { + return 1; + } + + // -------------------------------------------- + + template + typename disable_if >::type assert_is_vector(const T&) + {} + + template + // This funny #ifdef thing is here because gcc sometimes gives a warning + // about v being unused otherwise. +#ifdef ENABLE_ASSERTS + void assert_is_vector(const matrix_exp& v) +#else + void assert_is_vector(const matrix_exp& ) +#endif + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(v) == true, + "\tconst matrix_exp kernel_matrix()" + << "\n\t You have to supply this function with row or column vectors" + << "\n\t v.nr(): " << v.nr() + << "\n\t v.nc(): " << v.nc() + ); + } + + } + + template + struct op_kern_mat + { + op_kern_mat( + const K& kern_, + const vect_type1& vect1_, + const vect_type2& vect2_ + ) : + kern(kern_), + vect1(vect1_), + vect2(vect2_) + { + // make sure the requires clauses get checked eventually + impl::assert_is_vector(vect1); + impl::assert_is_vector(vect2); + } + + const K& kern; + const vect_type1& vect1; + const vect_type2& vect2; + + typedef typename K::scalar_type type; + + const static long cost = 100; + const static long NR = (is_same_type::value) ? 1 : 0; + const static long NC = (is_same_type::value) ? 1 : 0; + + typedef const type const_ret_type; + typedef typename K::mem_manager_type mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const + { + return kern(impl::access(vect1,r), impl::access(vect2,c)); + } + + long nr () const { return impl::size(vect1); } + long nc () const { return impl::size(vect2); } + + template bool aliases ( const matrix_exp& item ) const { return alias_helper(item.ref()); } + template bool destructively_aliases ( const matrix_exp& item ) const { return alias_helper(item.ref()); } + + template bool alias_helper ( const U& ) const { return false; } + + typedef typename K::sample_type samp_type; + + // Say we destructively alias if one of the vect* objects is actually item. + bool alias_helper (const samp_type& item ) const { return are_same(item, vect1) || are_same(item, vect2); } + template bool are_same (const samp_type& , const U& ) const { return false; } + bool are_same (const samp_type& a, const samp_type& b) const { return (&a == &b); } + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename V1, + typename V2 + > + const matrix_op > kernel_matrix ( + const K& kern, + const V1& v1, + const V2& v2 + ) + { + typedef op_kern_mat op; + return matrix_op(op(kern,v1,v2)); + } + +// ---------------------------------------------------------------------------------------- + + /* + It is possible to implement the kernel_matrix() operator with just one operator + class but treating the version that takes only a single vector separately + leads to more efficient output by gcc in certain instances. + */ + template + struct op_kern_mat_single + { + op_kern_mat_single( + const K& kern_, + const vect_type1& vect1_ + ) : + kern(kern_), + vect1(vect1_) + { + // make sure the requires clauses get checked eventually + impl::assert_is_vector(vect1); + } + + const K& kern; + const vect_type1& vect1; + + typedef typename K::scalar_type type; + + const static long cost = 100; + const static long NR = (is_same_type::value) ? 1 : 0; + const static long NC = (is_same_type::value) ? 1 : 0; + + typedef const type const_ret_type; + typedef typename K::mem_manager_type mem_manager_type; + typedef row_major_layout layout_type; + + const_ret_type apply (long r, long c ) const + { + return kern(impl::access(vect1,r), impl::access(vect1,c)); + } + + long nr () const { return impl::size(vect1); } + long nc () const { return impl::size(vect1); } + + template bool aliases ( const matrix_exp& item ) const { return alias_helper(item.ref()); } + template bool destructively_aliases ( const matrix_exp& item ) const { return alias_helper(item.ref()); } + + template bool alias_helper ( const U& ) const { return false; } + + typedef typename K::sample_type samp_type; + + // Say we destructively alias if vect1 is actually item. + bool alias_helper (const samp_type& item ) const { return are_same(item, vect1); } + template bool are_same (const samp_type& , const U& ) const { return false; } + bool are_same (const samp_type& a, const samp_type& b) const { return (&a == &b); } + }; + + template < + typename K, + typename V + > + const matrix_op > kernel_matrix ( + const K& kern, + const V& v + ) + { + typedef op_kern_mat_single op; + return matrix_op(op(kern,v)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_dest_type, + typename K, + typename V + > + inline void matrix_assign ( + matrix_dest_type& dest, + const matrix_exp > >& src + ) + /*! + Overload matrix assignment so that when a kernel_matrix expression + gets assigned it only evaluates half the kernel matrix (since it is symmetric) + !*/ + { + for (long r = 0; r < src.nr(); ++r) + { + for (long c = r; c < src.nc(); ++c) + { + dest(r,c) = dest(c,r) = src(r,c); + } + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_KERNEL_MATRIX_ + diff --git a/ml/dlib/dlib/svm/kernel_matrix_abstract.h b/ml/dlib/dlib/svm/kernel_matrix_abstract.h new file mode 100644 index 000000000..4aa4b1ce2 --- /dev/null +++ b/ml/dlib/dlib/svm/kernel_matrix_abstract.h @@ -0,0 +1,115 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_KERNEL_MATRIX_ABSTRACT_ +#ifdef DLIB_SVm_KERNEL_MATRIX_ABSTRACT_ + +#include +#include "kernel_abstract.h" +#include "../matrix/matrix_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename V + > + const matrix_exp kernel_matrix ( + const kernel_type& kernel, + const V& v + ); + /*! + requires + - kernel == a kernel function object as defined by the file dlib/svm/kernel_abstract.h. + This kernel must also be capable of operating on the contents of v. + - V == dlib::matrix, std::vector, dlib::std_vector_c, dlib::random_subset_selector, + dlib::linearly_independent_subset_finder, or kernel_type::sample_type. + - if (V is a dlib::matrix) then + - is_vector(v) == true + ensures + - if (V is of type kernel_type::sample_type) then + - returns a matrix R such that: + - R::type == kernel_type::scalar_type + - R.size() == 1 + - R(0,0) == kernel(v,v) + - else + - returns a matrix R such that: + - R::type == kernel_type::scalar_type + - R is a square matrix of v.size() rows by v.size() columns + - for all valid r and c: + - R(r,c) == kernel(v(r), v(c)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename V1, + typename V2 + > + const matrix_exp kernel_matrix ( + const kernel_type& kernel, + const V1& v1, + const V2& v2 + ); + /*! + requires + - kernel == a kernel function object as defined by the file dlib/svm/kernel_abstract.h + This kernel must also be capable of operating on the contents of v1 and v2. + - V1 == dlib::matrix, std::vector, dlib::std_vector_c, dlib::random_subset_selector, + dlib::linearly_independent_subset_finder, or kernel_type::sample_type. + - V2 == dlib::matrix, std::vector, dlib::std_vector_c, dlib::random_subset_selector, + dlib::linearly_independent_subset_finder, or kernel_type::sample_type. + - if (V1 is a dlib::matrix) then + - is_vector(v1) == true + - if (V2 is a dlib::matrix) then + - is_vector(v2) == true + ensures + - if (V1 and V2 are of type kernel_type::sample_type) then + - returns a matrix R such that: + - R::type == kernel_type::scalar_type + - R.size() == 1 + - R(0,0) == kernel(v1,v2) + - else if (V1 is of type kernel_type::sample_type) then + - returns a matrix R such that: + - R::type == kernel_type::scalar_type + - R.nr() == 1 + - R.nc() == v2.size() + - for all valid c: + - R(0,c) == kernel(v1, v2(c)) + - else if (V2 is of type kernel_type::sample_type) then + - returns a matrix R such that: + - R::type == kernel_type::scalar_type + - R.nr() == v1.size() + - R.nc() == 1 + - for all valid r: + - R(r,0) == kernel(v1(r), v2) + - else + - returns a matrix R such that: + - R::type == kernel_type::scalar_type + - R.nr() == v1.size() + - R.nc() == v2.size() + - for all valid r and c: + - R(r,c) == kernel(v1(r), v2(c)) + + + A note about aliasing (see the examples/matrix_expressions_ex.cpp example program + for a discussion of what aliasing is in the context of the dlib::matrix): + kernel_matrix() expressions can detect aliasing of an argument if that + argument is of type kernel_type::sample_type. However, it can't detect + aliasing though std::vectors or other "list of sample type" container class + arguments. This means that it is safe to assign a kernel_matrix() expression + to a sample_type if V1 or V2 are of sample_type but not safe otherwise. However, + since the latter case results in a general n by m matrix rather than a column + or row vector you shouldn't ever be doing it anyway. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_KERNEL_MATRIX_ABSTRACT_ + diff --git a/ml/dlib/dlib/svm/kkmeans.h b/ml/dlib/dlib/svm/kkmeans.h new file mode 100644 index 000000000..4c72106d8 --- /dev/null +++ b/ml/dlib/dlib/svm/kkmeans.h @@ -0,0 +1,654 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_KKMEANs_ +#define DLIB_KKMEANs_ + +#include +#include + +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "kernel.h" +#include "../array.h" +#include "kcentroid.h" +#include "kkmeans_abstract.h" +#include "../noncopyable.h" + +namespace dlib +{ + + template < + typename kernel_type + > + class kkmeans : public noncopyable + { + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + kkmeans ( + const kcentroid& kc_ + ): + kc(kc_), + min_change(0.01) + { + set_number_of_centers(1); + } + + ~kkmeans() + { + } + + const kernel_type& get_kernel ( + ) const + { + return kc.get_kernel(); + } + + void set_kcentroid ( + const kcentroid& kc_ + ) + { + kc = kc_; + set_number_of_centers(number_of_centers()); + } + + const kcentroid& get_kcentroid ( + unsigned long i + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(i < number_of_centers(), + "\tkcentroid kkmeans::get_kcentroid(i)" + << "\n\tYou have given an invalid value for i" + << "\n\ti: " << i + << "\n\tnumber_of_centers(): " << number_of_centers() + << "\n\tthis: " << this + ); + + return *centers[i]; + } + + void set_number_of_centers ( + unsigned long num + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(num > 0, + "\tvoid kkmeans::set_number_of_centers()" + << "\n\tYou can't set the number of centers to zero" + << "\n\tthis: " << this + ); + + centers.set_max_size(num); + centers.set_size(num); + + for (unsigned long i = 0; i < centers.size(); ++i) + { + centers[i].reset(new kcentroid(kc)); + } + } + + unsigned long number_of_centers ( + ) const + { + return centers.size(); + } + + template + void train ( + const T& samples, + const U& initial_centers, + long max_iter = 1000 + ) + { + do_train(mat(samples),mat(initial_centers),max_iter); + } + + unsigned long operator() ( + const sample_type& sample + ) const + { + unsigned long label = 0; + scalar_type best_score = (*centers[0])(sample); + + // figure out which center the given sample is closest too + for (unsigned long i = 1; i < centers.size(); ++i) + { + scalar_type temp = (*centers[i])(sample); + if (temp < best_score) + { + label = i; + best_score = temp; + } + } + + return label; + } + + void set_min_change ( + scalar_type min_change_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( 0 <= min_change_ < 1, + "\tvoid kkmeans::set_min_change()" + << "\n\tInvalid arguments to this function" + << "\n\tthis: " << this + << "\n\tmin_change_: " << min_change_ + ); + min_change = min_change_; + } + + const scalar_type get_min_change ( + ) const + { + return min_change; + } + + void swap ( + kkmeans& item + ) + { + centers.swap(item.centers); + kc.swap(item.kc); + assignments.swap(item.assignments); + exchange(min_change, item.min_change); + } + + friend void serialize(const kkmeans& item, std::ostream& out) + { + serialize(item.centers.size(),out); + for (unsigned long i = 0; i < item.centers.size(); ++i) + { + serialize(*item.centers[i], out); + } + serialize(item.kc, out); + serialize(item.min_change, out); + } + + friend void deserialize(kkmeans& item, std::istream& in) + { + unsigned long num; + deserialize(num, in); + item.centers.resize(num); + for (unsigned long i = 0; i < item.centers.size(); ++i) + { + std::unique_ptr > temp(new kcentroid(kernel_type())); + deserialize(*temp, in); + item.centers[i].swap(temp); + } + + deserialize(item.kc, in); + deserialize(item.min_change, in); + } + + private: + + template + void do_train ( + const matrix_type& samples, + const matrix_type2& initial_centers, + long max_iter = 1000 + ) + { + COMPILE_TIME_ASSERT((is_same_type::value)); + COMPILE_TIME_ASSERT((is_same_type::value)); + + // make sure requires clause is not broken + DLIB_ASSERT(samples.nc() == 1 && initial_centers.nc() == 1 && + initial_centers.nr() == static_cast(number_of_centers()), + "\tvoid kkmeans::train()" + << "\n\tInvalid arguments to this function" + << "\n\tthis: " << this + << "\n\tsamples.nc(): " << samples.nc() + << "\n\tinitial_centers.nc(): " << initial_centers.nc() + << "\n\tinitial_centers.nr(): " << initial_centers.nr() + ); + + // clear out the old data and initialize the centers + for (unsigned long i = 0; i < centers.size(); ++i) + { + centers[i]->clear_dictionary(); + centers[i]->train(initial_centers(i)); + } + + assignments.resize(samples.size()); + + bool assignment_changed = true; + + // loop until the centers stabilize + long count = 0; + const unsigned long min_num_change = static_cast(min_change*samples.size()); + unsigned long num_changed = min_num_change; + while (assignment_changed && count < max_iter && num_changed >= min_num_change) + { + ++count; + assignment_changed = false; + num_changed = 0; + + // loop over all the samples and assign them to their closest centers + for (long i = 0; i < samples.size(); ++i) + { + // find the best center + unsigned long best_center = 0; + scalar_type best_score = (*centers[0])(samples(i)); + for (unsigned long c = 1; c < centers.size(); ++c) + { + scalar_type temp = (*centers[c])(samples(i)); + if (temp < best_score) + { + best_score = temp; + best_center = c; + } + } + + // if the current sample changed centers then make note of that + if (assignments[i] != best_center) + { + assignments[i] = best_center; + assignment_changed = true; + ++num_changed; + } + } + + if (assignment_changed) + { + // now clear out the old data + for (unsigned long i = 0; i < centers.size(); ++i) + centers[i]->clear_dictionary(); + + // recalculate the cluster centers + for (unsigned long i = 0; i < assignments.size(); ++i) + centers[assignments[i]]->train(samples(i)); + } + + } + + + } + + array > > centers; + kcentroid kc; + scalar_type min_change; + + // temp variables + array assignments; + }; + +// ---------------------------------------------------------------------------------------- + + template + void swap(kkmeans& a, kkmeans& b) + { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + struct dlib_pick_initial_centers_data + { + dlib_pick_initial_centers_data():idx(0), dist(std::numeric_limits::infinity()){} + long idx; + double dist; + bool operator< (const dlib_pick_initial_centers_data& d) const { return dist < d.dist; } + }; + + template < + typename vector_type1, + typename vector_type2, + typename kernel_type + > + void pick_initial_centers( + long num_centers, + vector_type1& centers, + const vector_type2& samples, + const kernel_type& k, + double percentile = 0.01 + ) + { + /* + This function is basically just a non-randomized version of the kmeans++ algorithm + described in the paper: + kmeans++: The Advantages of Careful Seeding by Arthur and Vassilvitskii + + */ + + + // make sure requires clause is not broken + DLIB_ASSERT(num_centers > 1 && 0 <= percentile && percentile < 1 && samples.size() > 1, + "\tvoid pick_initial_centers()" + << "\n\tYou passed invalid arguments to this function" + << "\n\tnum_centers: " << num_centers + << "\n\tpercentile: " << percentile + << "\n\tsamples.size(): " << samples.size() + ); + + std::vector scores(samples.size()); + std::vector scores_sorted(samples.size()); + centers.clear(); + + // pick the first sample as one of the centers + centers.push_back(samples[0]); + + const long best_idx = static_cast(std::max(0.0,samples.size() - samples.size()*percentile - 1)); + + // pick the next center + for (long i = 0; i < num_centers-1; ++i) + { + // Loop over the samples and compare them to the most recent center. Store + // the distance from each sample to its closest center in scores. + const double k_cc = k(centers[i], centers[i]); + for (unsigned long s = 0; s < samples.size(); ++s) + { + // compute the distance between this sample and the current center + const double dist = k_cc + k(samples[s],samples[s]) - 2*k(samples[s], centers[i]); + + if (dist < scores[s].dist) + { + scores[s].dist = dist; + scores[s].idx = s; + } + } + + scores_sorted = scores; + + // now find the winning center and add it to centers. It is the one that is + // far away from all the other centers. + sort(scores_sorted.begin(), scores_sorted.end()); + centers.push_back(samples[scores_sorted[best_idx].idx]); + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type1, + typename vector_type2 + > + void pick_initial_centers( + long num_centers, + vector_type1& centers, + const vector_type2& samples, + double percentile = 0.01 + ) + { + typedef typename vector_type1::value_type sample_type; + linear_kernel kern; + pick_initial_centers(num_centers, centers, samples, kern, percentile); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type, + typename sample_type, + typename alloc + > + void find_clusters_using_kmeans ( + const array_type& samples, + std::vector& centers, + unsigned long max_iter = 1000 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() > 0 && centers.size() > 0, + "\tvoid find_clusters_using_kmeans()" + << "\n\tYou passed invalid arguments to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t centers.size(): " << centers.size() + ); + +#ifdef ENABLE_ASSERTS + { + const long nr = samples[0].nr(); + const long nc = samples[0].nc(); + for (unsigned long i = 0; i < samples.size(); ++i) + { + DLIB_ASSERT(is_vector(samples[i]) && samples[i].nr() == nr && samples[i].nc() == nc, + "\tvoid find_clusters_using_kmeans()" + << "\n\t You passed invalid arguments to this function" + << "\n\t is_vector(samples[i]): " << is_vector(samples[i]) + << "\n\t samples[i].nr(): " << samples[i].nr() + << "\n\t nr: " << nr + << "\n\t samples[i].nc(): " << samples[i].nc() + << "\n\t nc: " << nc + << "\n\t i: " << i + ); + } + } +#endif + + typedef typename sample_type::type scalar_type; + + sample_type zero(centers[0]); + set_all_elements(zero, 0); + + std::vector center_element_count; + + // tells which center a sample belongs to + std::vector assignments(samples.size(), samples.size()); + + + unsigned long iter = 0; + bool centers_changed = true; + while (centers_changed && iter < max_iter) + { + ++iter; + centers_changed = false; + center_element_count.assign(centers.size(), 0); + + // loop over each sample and see which center it is closest to + for (unsigned long i = 0; i < samples.size(); ++i) + { + // find the best center for sample[i] + scalar_type best_dist = std::numeric_limits::max(); + unsigned long best_center = 0; + for (unsigned long j = 0; j < centers.size(); ++j) + { + scalar_type dist = length(centers[j] - samples[i]); + if (dist < best_dist) + { + best_dist = dist; + best_center = j; + } + } + + if (assignments[i] != best_center) + { + centers_changed = true; + assignments[i] = best_center; + } + + center_element_count[best_center] += 1; + } + + // now update all the centers + centers.assign(centers.size(), zero); + for (unsigned long i = 0; i < samples.size(); ++i) + { + centers[assignments[i]] += samples[i]; + } + for (unsigned long i = 0; i < centers.size(); ++i) + { + if (center_element_count[i] != 0) + centers[i] /= center_element_count[i]; + } + } + + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type, + typename sample_type, + typename alloc + > + void find_clusters_using_angular_kmeans ( + const array_type& samples, + std::vector& centers, + unsigned long max_iter = 1000 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(samples.size() > 0 && centers.size() > 0, + "\tvoid find_clusters_using_angular_kmeans()" + << "\n\tYou passed invalid arguments to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t centers.size(): " << centers.size() + ); + +#ifdef ENABLE_ASSERTS + { + const long nr = samples[0].nr(); + const long nc = samples[0].nc(); + for (unsigned long i = 0; i < samples.size(); ++i) + { + DLIB_ASSERT(is_vector(samples[i]) && samples[i].nr() == nr && samples[i].nc() == nc, + "\tvoid find_clusters_using_angular_kmeans()" + << "\n\t You passed invalid arguments to this function" + << "\n\t is_vector(samples[i]): " << is_vector(samples[i]) + << "\n\t samples[i].nr(): " << samples[i].nr() + << "\n\t nr: " << nr + << "\n\t samples[i].nc(): " << samples[i].nc() + << "\n\t nc: " << nc + << "\n\t i: " << i + ); + } + } +#endif + + typedef typename sample_type::type scalar_type; + + sample_type zero(centers[0]); + set_all_elements(zero, 0); + + unsigned long seed = 0; + + // tells which center a sample belongs to + std::vector assignments(samples.size(), samples.size()); + std::vector lengths; + for (unsigned long i = 0; i < samples.size(); ++i) + { + lengths.push_back(length(samples[i])); + // If there are zero vectors in samples then just say their length is 1 so we + // can avoid a division by zero check later on. Also, this doesn't matter + // since zero vectors can be assigned to any cluster randomly as there is no + // basis for picking one based on angle. + if (lengths.back() == 0) + lengths.back() = 1; + } + + // We will keep the centers as unit vectors at all times throughout the processing. + for (unsigned long i = 0; i < centers.size(); ++i) + { + double len = length(centers[i]); + // Avoid having length 0 centers. If that is the case then pick another center + // at random. + while(len == 0) + { + centers[i] = matrix_cast(gaussian_randm(centers[i].nr(), centers[i].nc(), seed++)); + len = length(centers[i]); + } + centers[i] /= len; + } + + + unsigned long iter = 0; + bool centers_changed = true; + while (centers_changed && iter < max_iter) + { + ++iter; + centers_changed = false; + + // loop over each sample and see which center it is closest to + for (unsigned long i = 0; i < samples.size(); ++i) + { + // find the best center for sample[i] + scalar_type best_angle = std::numeric_limits::max(); + unsigned long best_center = 0; + for (unsigned long j = 0; j < centers.size(); ++j) + { + scalar_type angle = -dot(centers[j],samples[i])/lengths[i]; + + if (angle < best_angle) + { + best_angle = angle; + best_center = j; + } + } + + if (assignments[i] != best_center) + { + centers_changed = true; + assignments[i] = best_center; + } + } + + // now update all the centers + centers.assign(centers.size(), zero); + for (unsigned long i = 0; i < samples.size(); ++i) + { + centers[assignments[i]] += samples[i]; + } + // Now length normalize all the centers. + for (unsigned long i = 0; i < centers.size(); ++i) + { + double len = length(centers[i]); + // Avoid having length 0 centers. If that is the case then pick another center + // at random. + while(len == 0) + { + centers[i] = matrix_cast(gaussian_randm(centers[i].nr(), centers[i].nc(), seed++)); + len = length(centers[i]); + centers_changed = true; + } + centers[i] /= len; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type, + typename EXP + > + unsigned long nearest_center ( + const array_type& centers, + const matrix_exp& sample + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(centers.size() > 0 && sample.size() > 0 && is_vector(sample), + "\t unsigned long nearest_center()" + << "\n\t You have given invalid inputs to this function." + << "\n\t centers.size(): " << centers.size() + << "\n\t sample.size(): " << sample.size() + << "\n\t is_vector(sample): " << is_vector(sample) + ); + + double best_dist = length_squared(centers[0] - sample); + unsigned long best_idx = 0; + for (unsigned long i = 1; i < centers.size(); ++i) + { + const double dist = length_squared(centers[i] - sample); + if (dist < best_dist) + { + best_dist = dist; + best_idx = i; + } + } + return best_idx; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KKMEANs_ + + diff --git a/ml/dlib/dlib/svm/kkmeans_abstract.h b/ml/dlib/dlib/svm/kkmeans_abstract.h new file mode 100644 index 000000000..9f9d7ccce --- /dev/null +++ b/ml/dlib/dlib/svm/kkmeans_abstract.h @@ -0,0 +1,365 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_KKMEANs_ABSTRACT_ +#ifdef DLIB_KKMEANs_ABSTRACT_ + +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "kernel_abstract.h" +#include "kcentroid_abstract.h" +#include "../noncopyable.h" + +namespace dlib +{ + + template < + typename kernel_type + > + class kkmeans : public noncopyable + { + /*! + REQUIREMENTS ON kernel_type + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + INITIAL VALUE + - number_of_centers() == 1 + - get_min_change() == 0.01 + + WHAT THIS OBJECT REPRESENTS + This is an implementation of a kernelized k-means clustering algorithm. + It performs k-means clustering by using the kcentroid object. + !*/ + + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + kkmeans ( + const kcentroid& kc_ + ); + /*! + ensures + - #number_of_centers() == 1 + - #get_min_change() == 0.01 + - #get_kcentroid(0) == a copy of kc_ + !*/ + + ~kkmeans( + ); + /*! + ensures + - all resources associated with *this have been released + !*/ + + void set_kcentroid ( + const kcentroid& kc_ + ); + /*! + ensures + - for all idx: + - #get_kcentroid(idx) == a copy of kc_ + !*/ + + const kcentroid& get_kcentroid ( + unsigned long i + ) const; + /*! + requires + - i < number_of_centers() + ensures + - returns a const reference to the ith kcentroid object contained in + this object. Each kcentroid represents one of the centers found + by the k-means clustering algorithm. + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a const reference to the kernel used by this object + !*/ + + void set_number_of_centers ( + unsigned long num + ); + /*! + requires + - num > 0 + ensures + - #number_of_centers() == num + !*/ + + unsigned long number_of_centers ( + ) const; + /*! + ensures + - returns the number of centers used in this instance of the k-means clustering + algorithm. + !*/ + + template < + typename matrix_type, + typename matrix_type2 + > + void train ( + const matrix_type& samples, + const matrix_type2& initial_centers, + long max_iter = 1000 + ); + /*! + requires + - matrix_type and matrix_type2 must either be dlib::matrix objects or convertible to dlib::matrix + via mat() + - matrix_type::type == sample_type (i.e. matrix_type should contain sample_type objects) + - matrix_type2::type == sample_type (i.e. matrix_type2 should contain sample_type objects) + - initial_centers.nc() == 1 (i.e. must be a column vector) + - samples.nc() == 1 (i.e. must be a column vector) + - initial_centers.nr() == number_of_centers() + ensures + - performs k-means clustering of the given set of samples. The initial center points + are taken from the initial_centers argument. + - loops over the data and continues to refine the clustering until either less than + get_min_change() fraction of the data points change clusters or we have done max_iter + iterations over the data. + - After this function finishes you can call the operator() function below + to determine which centroid a given sample is closest to. + !*/ + + unsigned long operator() ( + const sample_type& sample + ) const; + /*! + ensures + - returns a number idx such that: + - idx < number_of_centers() + - get_kcentroid(idx) == the centroid that is closest to the given + sample. + !*/ + + void set_min_change ( + scalar_type min_change + ); + /*! + requires + - 0 <= min_change < 1 + ensures + - #get_min_change() == min_change + !*/ + + const scalar_type get_min_change ( + ) const; + /*! + ensures + - returns the minimum fraction of data points that need to change + centers in an iteration of kmeans for the algorithm to keep going. + !*/ + + void swap ( + kkmeans& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type + > + void swap( + kkmeans& a, + kkmeans& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename kernel_type + > + void serialize ( + const kkmeans& item, + std::ostream& out + ); + /*! + provides serialization support for kkmeans objects + !*/ + + template < + typename kernel_type + > + void deserialize ( + kkmeans& item, + std::istream& in + ); + /*! + provides serialization support for kkmeans objects + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type1, + typename vector_type2, + typename kernel_type + > + void pick_initial_centers( + long num_centers, + vector_type1& centers, + const vector_type2& samples, + const kernel_type& k, + double percentile = 0.01 + ); + /*! + requires + - num_centers > 1 + - 0 <= percentile < 1 + - samples.size() > 1 + - vector_type1 == something with an interface compatible with std::vector + - vector_type2 == something with an interface compatible with std::vector + - k(samples[0],samples[0]) must be a valid expression that returns a double + - both centers and samples must be able to contain kernel_type::sample_type + objects + ensures + - finds num_centers candidate cluster centers in the data in the samples + vector. Assumes that k is the kernel that will be used during clustering + to define the space in which clustering occurs. + - The centers are found by looking for points that are far away from other + candidate centers. However, if the data is noisy you probably want to + ignore the farthest way points since they will be outliers. To do this + set percentile to the fraction of outliers you expect the data to contain. + - #centers.size() == num_centers + - #centers == a vector containing the candidate centers found + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type1, + typename vector_type2 + > + void pick_initial_centers( + long num_centers, + vector_type1& centers, + const vector_type2& samples, + double percentile = 0.01 + ); + /*! + requires + - num_centers > 1 + - 0 <= percentile < 1 + - samples.size() > 1 + - vector_type1 == something with an interface compatible with std::vector + - vector_type2 == something with an interface compatible with std::vector + - Both centers and samples must be able to contain dlib::matrix based row or + column vectors. + ensures + - performs: pick_initial_centers(num_centers, centers, samples, linear_kernel(), percentile) + (i.e. this function is simply an overload that uses the linear kernel. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type, + typename sample_type, + typename alloc + > + void find_clusters_using_kmeans ( + const array_type& samples, + std::vector& centers, + unsigned long max_iter = 1000 + ); + /*! + requires + - samples.size() > 0 + - samples == a bunch of row or column vectors and they all must be of the + same length. + - centers.size() > 0 + - array_type == something with an interface compatible with std::vector + and it must contain row or column vectors capable of being stored in + sample_type objects. + - sample_type == a dlib::matrix capable of representing vectors + ensures + - performs regular old linear kmeans clustering on the samples. The clustering + begins with the initial set of centers given as an argument to this function. + When it finishes #centers will contain the resulting centers. + - no more than max_iter iterations will be performed before this function + terminates. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type, + typename sample_type, + typename alloc + > + void find_clusters_using_angular_kmeans ( + const array_type& samples, + std::vector& centers, + unsigned long max_iter = 1000 + ); + /*! + requires + - samples.size() > 0 + - samples == a bunch of row or column vectors and they all must be of the + same length. + - centers.size() > 0 + - array_type == something with an interface compatible with std::vector + and it must contain row or column vectors capable of being stored in + sample_type objects. + - sample_type == a dlib::matrix capable of representing vectors + ensures + - performs linear kmeans clustering on the samples, except instead of using + Euclidean distance to compare samples to the centers it uses the angle + between a sample and a center (with respect to the origin). So we try to + cluster samples together if they have small angles with respect to each + other. The clustering begins with the initial set of centers given as an + argument to this function. When it finishes #centers will contain the + resulting centers. + - for all valid i: + - length(#centers[i]) == 1 + (i.e. the output centers are scaled to be unit vectors since their + magnitude is irrelevant. Moreover, this makes it so you can use + functions like nearest_center() with #centers to find the cluster + assignments.) + - No more than max_iter iterations will be performed before this function + terminates. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename array_type, + typename EXP + > + unsigned long nearest_center ( + const array_type& centers, + const matrix_exp& sample + ); + /*! + requires + - centers.size() > 0 + - sample.size() > 0 + - is_vector(sample) == true + - centers must be an array of vectors such that the following expression is + valid: length_squared(sample - centers[0]). (e.g. centers could be a + std::vector of matrix objects holding column vectors). + ensures + - returns the index that identifies the element of centers that is nearest to + sample. That is, returns a number IDX such that centers[IDX] is the element + of centers that minimizes length(centers[IDX]-sample). + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KKMEANs_ABSTRACT_ + diff --git a/ml/dlib/dlib/svm/krls.h b/ml/dlib/dlib/svm/krls.h new file mode 100644 index 000000000..6c72e45e8 --- /dev/null +++ b/ml/dlib/dlib/svm/krls.h @@ -0,0 +1,358 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_KRLs_ +#define DLIB_KRLs_ + +#include + +#include "krls_abstract.h" +#include "../matrix.h" +#include "function.h" +#include "../std_allocator.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + class krls + { + /*! + This is an implementation of the kernel recursive least squares algorithm described in the paper: + The Kernel Recursive Least Squares Algorithm by Yaakov Engel. + !*/ + + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + + explicit krls ( + const kernel_type& kernel_, + scalar_type tolerance_ = 0.001, + unsigned long max_dictionary_size_ = 1000000 + ) : + kernel(kernel_), + my_tolerance(tolerance_), + my_max_dictionary_size(max_dictionary_size_) + { + // make sure requires clause is not broken + DLIB_ASSERT(tolerance_ >= 0, + "\tkrls::krls()" + << "\n\t You have to give a positive tolerance" + << "\n\t this: " << this + << "\n\t tolerance: " << tolerance_ + ); + + clear_dictionary(); + } + + scalar_type tolerance() const + { + return my_tolerance; + } + + unsigned long max_dictionary_size() const + { + return my_max_dictionary_size; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel; + } + + void clear_dictionary () + { + dictionary.clear(); + alpha.clear(); + + K_inv.set_size(0,0); + K.set_size(0,0); + P.set_size(0,0); + } + + scalar_type operator() ( + const sample_type& x + ) const + { + scalar_type temp = 0; + for (unsigned long i = 0; i < alpha.size(); ++i) + temp += alpha[i]*kern(dictionary[i], x); + + return temp; + } + + void train ( + const sample_type& x, + scalar_type y + ) + { + const scalar_type kx = kern(x,x); + if (alpha.size() == 0) + { + // just ignore this sample if it is the zero vector (or really close to being zero) + if (std::abs(kx) > std::numeric_limits::epsilon()) + { + // set initial state since this is the first training example we have seen + + K_inv.set_size(1,1); + K_inv(0,0) = 1/kx; + K.set_size(1,1); + K(0,0) = kx; + + alpha.push_back(y/kx); + dictionary.push_back(x); + P.set_size(1,1); + P(0,0) = 1; + } + } + else + { + // fill in k + k.set_size(alpha.size()); + for (long r = 0; r < k.nr(); ++r) + k(r) = kern(x,dictionary[r]); + + // compute the error we would have if we approximated the new x sample + // with the dictionary. That is, do the ALD test from the KRLS paper. + a = K_inv*k; + scalar_type delta = kx - trans(k)*a; + + // if this new vector isn't approximately linearly dependent on the vectors + // in our dictionary. + if (delta > my_tolerance) + { + if (dictionary.size() >= my_max_dictionary_size) + { + // We need to remove one of the old members of the dictionary before + // we proceed with adding a new one. So remove the oldest one. + remove_dictionary_vector(0); + + // recompute these guys since they were computed with the old + // kernel matrix + k = remove_row(k,0); + a = K_inv*k; + delta = kx - trans(k)*a; + } + + // add x to the dictionary + dictionary.push_back(x); + + // update K_inv by computing the new one in the temp matrix (equation 3.14) + matrix temp(K_inv.nr()+1, K_inv.nc()+1); + // update the middle part of the matrix + set_subm(temp, get_rect(K_inv)) = K_inv + a*trans(a)/delta; + // update the right column of the matrix + set_subm(temp, 0, K_inv.nr(),K_inv.nr(),1) = -a/delta; + // update the bottom row of the matrix + set_subm(temp, K_inv.nr(), 0, 1, K_inv.nr()) = trans(-a/delta); + // update the bottom right corner of the matrix + temp(K_inv.nr(), K_inv.nc()) = 1/delta; + // put temp into K_inv + temp.swap(K_inv); + + + + + // update K (the kernel matrix) + temp.set_size(K.nr()+1, K.nc()+1); + set_subm(temp, get_rect(K)) = K; + // update the right column of the matrix + set_subm(temp, 0, K.nr(),K.nr(),1) = k; + // update the bottom row of the matrix + set_subm(temp, K.nr(), 0, 1, K.nr()) = trans(k); + temp(K.nr(), K.nc()) = kx; + // put temp into K + temp.swap(K); + + + + + // Now update the P matrix (equation 3.15) + temp.set_size(P.nr()+1, P.nc()+1); + set_subm(temp, get_rect(P)) = P; + // initialize the new sides of P + set_rowm(temp,P.nr()) = 0; + set_colm(temp,P.nr()) = 0; + temp(P.nr(), P.nc()) = 1; + temp.swap(P); + + // now update the alpha vector (equation 3.16) + const scalar_type k_a = (y-trans(k)*mat(alpha))/delta; + for (unsigned long i = 0; i < alpha.size(); ++i) + { + alpha[i] -= a(i)*k_a; + } + alpha.push_back(k_a); + } + else + { + q = P*a/(1+trans(a)*P*a); + + // update P (equation 3.12) + temp_matrix = trans(a)*P; + P -= q*temp_matrix; + + // update the alpha vector (equation 3.13) + const scalar_type k_a = y-trans(k)*mat(alpha); + for (unsigned long i = 0; i < alpha.size(); ++i) + { + alpha[i] += (K_inv*q*k_a)(i); + } + } + } + } + + void swap ( + krls& item + ) + { + exchange(kernel, item.kernel); + dictionary.swap(item.dictionary); + alpha.swap(item.alpha); + K_inv.swap(item.K_inv); + K.swap(item.K); + P.swap(item.P); + exchange(my_tolerance, item.my_tolerance); + q.swap(item.q); + a.swap(item.a); + k.swap(item.k); + temp_matrix.swap(item.temp_matrix); + exchange(my_max_dictionary_size, item.my_max_dictionary_size); + } + + unsigned long dictionary_size ( + ) const { return dictionary.size(); } + + decision_function get_decision_function ( + ) const + { + return decision_function( + mat(alpha), + -sum(mat(alpha))*tau, + kernel, + mat(dictionary) + ); + } + + friend void serialize(const krls& item, std::ostream& out) + { + serialize(item.kernel, out); + serialize(item.dictionary, out); + serialize(item.alpha, out); + serialize(item.K_inv, out); + serialize(item.K, out); + serialize(item.P, out); + serialize(item.my_tolerance, out); + serialize(item.my_max_dictionary_size, out); + } + + friend void deserialize(krls& item, std::istream& in) + { + deserialize(item.kernel, in); + deserialize(item.dictionary, in); + deserialize(item.alpha, in); + deserialize(item.K_inv, in); + deserialize(item.K, in); + deserialize(item.P, in); + deserialize(item.my_tolerance, in); + deserialize(item.my_max_dictionary_size, in); + } + + private: + + inline scalar_type kern (const sample_type& m1, const sample_type& m2) const + { + return kernel(m1,m2) + tau; + } + + void remove_dictionary_vector ( + long i + ) + /*! + requires + - 0 <= i < dictionary.size() + ensures + - #dictionary.size() == dictionary.size() - 1 + - #alpha.size() == alpha.size() - 1 + - updates the K_inv matrix so that it is still a proper inverse of the + kernel matrix + - also removes the necessary row and column from the K matrix + - uses the this->a variable so after this function runs that variable + will contain a different value. + !*/ + { + // remove the dictionary vector + dictionary.erase(dictionary.begin()+i); + + // remove the i'th vector from the inverse kernel matrix. This formula is basically + // just the reverse of the way K_inv is updated by equation 3.14 during normal training. + K_inv = removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i); + + // now compute the updated alpha values to take account that we just removed one of + // our dictionary vectors + a = (K_inv*remove_row(K,i)*mat(alpha)); + + // now copy over the new alpha values + alpha.resize(alpha.size()-1); + for (unsigned long k = 0; k < alpha.size(); ++k) + { + alpha[k] = a(k); + } + + // update the P matrix as well + P = removerc(P,i,i); + + // update the K matrix as well + K = removerc(K,i,i); + } + + + kernel_type kernel; + + typedef std_allocator alloc_sample_type; + typedef std_allocator alloc_scalar_type; + typedef std::vector dictionary_vector_type; + typedef std::vector alpha_vector_type; + + dictionary_vector_type dictionary; + alpha_vector_type alpha; + + matrix K_inv; + matrix K; + matrix P; + + scalar_type my_tolerance; + unsigned long my_max_dictionary_size; + + + // temp variables here just so we don't have to reconstruct them over and over. Thus, + // they aren't really part of the state of this object. + matrix q; + matrix a; + matrix k; + matrix temp_matrix; + + const static scalar_type tau; + + }; + + template + const typename kernel_type::scalar_type krls::tau = static_cast(0.01); + +// ---------------------------------------------------------------------------------------- + + template + void swap(krls& a, krls& b) + { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KRLs_ + diff --git a/ml/dlib/dlib/svm/krls_abstract.h b/ml/dlib/dlib/svm/krls_abstract.h new file mode 100644 index 000000000..7ea2d9872 --- /dev/null +++ b/ml/dlib/dlib/svm/krls_abstract.h @@ -0,0 +1,202 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_KRLs_ABSTRACT_ +#ifdef DLIB_KRLs_ABSTRACT_ + +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "kernel_abstract.h" + +namespace dlib +{ + + template < + typename kernel_type + > + class krls + { + /*! + REQUIREMENTS ON kernel_type + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + INITIAL VALUE + - dictionary_size() == 0 + + WHAT THIS OBJECT REPRESENTS + This is an implementation of the kernel recursive least squares algorithm + described in the paper: + The Kernel Recursive Least Squares Algorithm by Yaakov Engel. + + The long and short of this algorithm is that it is an online kernel based + regression algorithm. You give it samples (x,y) and it learns the function + f(x) == y. For a detailed description of the algorithm read the above paper. + + Also note that the algorithm internally keeps a set of "dictionary vectors" + that are used to represent the regression function. You can force the + algorithm to use no more than a set number of vectors by setting + the 3rd constructor argument to whatever you want. However, note that + doing this causes the algorithm to bias it's results towards more + recent training examples. + !*/ + + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + + explicit krls ( + const kernel_type& kernel_, + scalar_type tolerance_ = 0.001, + unsigned long max_dictionary_size_ = 1000000 + ); + /*! + requires + - tolerance >= 0 + ensures + - this object is properly initialized + - #tolerance() == tolerance_ + - #get_decision_function().kernel_function == kernel_ + (i.e. this object will use the given kernel function) + - #get_kernel() == kernel_ + - #max_dictionary_size() == max_dictionary_size_ + !*/ + + scalar_type tolerance( + ) const; + /*! + ensures + - returns the tolerance to use for the approximately linearly dependent + test in the KRLS algorithm. This is a number which governs how + accurately this object will approximate the decision function it is + learning. Smaller values generally result in a more accurate + estimate while also resulting in a bigger set of dictionary vectors in + the learned decision function. Bigger tolerances values result in a + less accurate decision function but also in less dictionary vectors. + - The exact meaning of the tolerance parameter is the following: + Imagine that we have an empirical_kernel_map that contains all + the current dictionary vectors. Then the tolerance is the minimum + projection error (as given by empirical_kernel_map::project()) required + to cause us to include a new vector in the dictionary. So each time + you call train() the krls object basically just computes the projection + error for that new sample and if it is larger than the tolerance + then that new sample becomes part of the dictionary. + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a const reference to the kernel used by this object + !*/ + + unsigned long max_dictionary_size( + ) const; + /*! + ensures + - returns the maximum number of dictionary vectors this object + will use at a time. That is, dictionary_size() will never be + greater than max_dictionary_size(). + !*/ + + void clear_dictionary ( + ); + /*! + ensures + - clears out all learned data + (e.g. #get_decision_function().basis_vectors.size() == 0) + !*/ + + scalar_type operator() ( + const sample_type& x + ) const; + /*! + ensures + - returns the current y estimate for the given x + !*/ + + void train ( + const sample_type& x, + scalar_type y + ); + /*! + ensures + - trains this object that the given x should be mapped to the given y + - if (dictionary_size() == max_dictionary_size() and training + would add another dictionary vector to this object) then + - discards the oldest dictionary vector so that we can still + add a new one and remain below the max number of dictionary + vectors. + !*/ + + void swap ( + krls& item + ); + /*! + ensures + - swaps *this with item + !*/ + + unsigned long dictionary_size ( + ) const; + /*! + ensures + - returns the number of vectors in the dictionary. That is, + returns a number equal to get_decision_function().basis_vectors.size() + !*/ + + decision_function get_decision_function ( + ) const; + /*! + ensures + - returns a decision function F that represents the function learned + by this object so far. I.e. it is the case that: + - for all x: F(x) == (*this)(x) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type + > + void swap( + krls& a, + krls& b + ) + { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename kernel_type + > + void serialize ( + const krls& item, + std::ostream& out + ); + /*! + provides serialization support for krls objects + !*/ + + template < + typename kernel_type + > + void deserialize ( + krls& item, + std::istream& in + ); + /*! + provides serialization support for krls objects + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_KRLs_ABSTRACT_ + diff --git a/ml/dlib/dlib/svm/krr_trainer.h b/ml/dlib/dlib/svm/krr_trainer.h new file mode 100644 index 000000000..a43431169 --- /dev/null +++ b/ml/dlib/dlib/svm/krr_trainer.h @@ -0,0 +1,368 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_KRR_TRAInER_Hh_ +#define DLIB_KRR_TRAInER_Hh_ + +#include "../algs.h" +#include "function.h" +#include "kernel.h" +#include "empirical_kernel_map.h" +#include "linearly_independent_subset_finder.h" +#include "../statistics.h" +#include "rr_trainer.h" +#include "krr_trainer_abstract.h" +#include +#include + +namespace dlib +{ + template < + typename K + > + class krr_trainer + { + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + krr_trainer ( + ) : + verbose(false), + max_basis_size(400), + ekm_stale(true) + { + } + + void be_verbose ( + ) + { + verbose = true; + trainer.be_verbose(); + } + + void be_quiet ( + ) + { + verbose = false; + trainer.be_quiet(); + } + + void use_regression_loss_for_loo_cv ( + ) + { + trainer.use_regression_loss_for_loo_cv(); + } + + void use_classification_loss_for_loo_cv ( + ) + { + trainer.use_classification_loss_for_loo_cv(); + } + + bool will_use_regression_loss_for_loo_cv ( + ) const + { + return trainer.will_use_regression_loss_for_loo_cv(); + } + + const kernel_type get_kernel ( + ) const + { + return kern; + } + + void set_kernel ( + const kernel_type& k + ) + { + kern = k; + } + + template + void set_basis ( + const T& basis_samples + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(basis_samples.size() > 0 && is_vector(mat(basis_samples)), + "\tvoid krr_trainer::set_basis(basis_samples)" + << "\n\t You have to give a non-empty set of basis_samples and it must be a vector" + << "\n\t basis_samples.size(): " << basis_samples.size() + << "\n\t is_vector(mat(basis_samples)): " << is_vector(mat(basis_samples)) + << "\n\t this: " << this + ); + + basis = mat(basis_samples); + ekm_stale = true; + } + + bool basis_loaded ( + ) const + { + return (basis.size() != 0); + } + + void clear_basis ( + ) + { + basis.set_size(0); + ekm.clear(); + ekm_stale = true; + } + + unsigned long get_max_basis_size ( + ) const + { + return max_basis_size; + } + + void set_max_basis_size ( + unsigned long max_basis_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_basis_size_ > 0, + "\t void krr_trainer::set_max_basis_size()" + << "\n\t max_basis_size_ must be greater than 0" + << "\n\t max_basis_size_: " << max_basis_size_ + << "\n\t this: " << this + ); + + max_basis_size = max_basis_size_; + } + + void set_lambda ( + scalar_type lambda_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(lambda_ >= 0, + "\t void krr_trainer::set_lambda()" + << "\n\t lambda must be greater than or equal to 0" + << "\n\t lambda_: " << lambda_ + << "\n\t this: " << this + ); + + trainer.set_lambda(lambda_); + } + + const scalar_type get_lambda ( + ) const + { + return trainer.get_lambda(); + } + + template + void set_search_lambdas ( + const matrix_exp& lambdas + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(lambdas) && lambdas.size() > 0 && min(lambdas) > 0, + "\t void krr_trainer::set_search_lambdas()" + << "\n\t lambdas must be a non-empty vector of values" + << "\n\t is_vector(lambdas): " << is_vector(lambdas) + << "\n\t lambdas.size(): " << lambdas.size() + << "\n\t min(lambdas): " << min(lambdas) + << "\n\t this: " << this + ); + + trainer.set_search_lambdas(lambdas); + } + + const matrix& get_search_lambdas ( + ) const + { + return trainer.get_search_lambdas(); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + std::vector temp; + scalar_type temp2; + return do_train(mat(x), mat(y), false, temp, temp2); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + std::vector& loo_values + ) const + { + scalar_type temp; + return do_train(mat(x), mat(y), true, loo_values, temp); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + std::vector& loo_values, + scalar_type& lambda_used + ) const + { + return do_train(mat(x), mat(y), true, loo_values, lambda_used); + } + + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + const bool output_loo_values, + std::vector& loo_values, + scalar_type& the_lambda + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,y), + "\t decision_function krr_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t is_vector(x): " << is_vector(x) + << "\n\t is_vector(y): " << is_vector(y) + << "\n\t x.size(): " << x.size() + << "\n\t y.size(): " << y.size() + ); + +#ifdef ENABLE_ASSERTS + if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y), + "\t decision_function krr_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + ); + } +#endif + + // The first thing we do is make sure we have an appropriate ekm ready for use below. + if (basis_loaded()) + { + if (ekm_stale) + { + ekm.load(kern, basis); + ekm_stale = false; + } + } + else + { + linearly_independent_subset_finder lisf(kern, max_basis_size); + fill_lisf(lisf, x); + ekm.load(lisf); + } + + if (verbose) + { + std::cout << "\nNumber of basis vectors used: " << ekm.out_vector_size() << std::endl; + } + + typedef matrix column_matrix_type; + + running_stats rs; + + // Now we project all the x samples into kernel space using our EKM + matrix proj_x; + proj_x.set_size(x.size()); + for (long i = 0; i < proj_x.size(); ++i) + { + scalar_type err; + // Note that we also append a 1 to the end of the vectors because this is + // a convenient way of dealing with the bias term later on. + if (verbose == false) + { + proj_x(i) = ekm.project(x(i)); + } + else + { + proj_x(i) = ekm.project(x(i),err); + rs.add(err); + } + } + + if (verbose) + { + std::cout << "Mean EKM projection error: " << rs.mean() << std::endl; + std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl; + } + + + decision_function > > lin_df; + + if (output_loo_values) + lin_df = trainer.train(proj_x,y, loo_values, the_lambda); + else + lin_df = trainer.train(proj_x,y); + + // convert the linear decision function into a kernelized one. + decision_function df; + df = ekm.convert_to_decision_function(lin_df.basis_vectors(0)); + df.b = lin_df.b; + + // If we used an automatically derived basis then there isn't any point in + // keeping the ekm around. So free its memory. + if (basis_loaded() == false) + { + ekm.clear(); + } + + return df; + } + + + /*! + CONVENTION + - if (ekm_stale) then + - kern or basis have changed since the last time + they were loaded into the ekm + + - get_lambda() == trainer.get_lambda() + - get_kernel() == kern + - get_max_basis_size() == max_basis_size + - will_use_regression_loss_for_loo_cv() == trainer.will_use_regression_loss_for_loo_cv() + - get_search_lambdas() == trainer.get_search_lambdas() + + - basis_loaded() == (basis.size() != 0) + !*/ + + rr_trainer > > trainer; + + bool verbose; + + + kernel_type kern; + unsigned long max_basis_size; + + matrix basis; + mutable empirical_kernel_map ekm; + mutable bool ekm_stale; + + }; + +} + +#endif // DLIB_KRR_TRAInER_Hh_ + + diff --git a/ml/dlib/dlib/svm/krr_trainer_abstract.h b/ml/dlib/dlib/svm/krr_trainer_abstract.h new file mode 100644 index 000000000..399802f6b --- /dev/null +++ b/ml/dlib/dlib/svm/krr_trainer_abstract.h @@ -0,0 +1,322 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_KRR_TRAInER_ABSTRACT_Hh_ +#ifdef DLIB_KRR_TRAInER_ABSTRACT_Hh_ + +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "empirical_kernel_map_abstract.h" + +namespace dlib +{ + template < + typename K + > + class krr_trainer + { + /*! + REQUIREMENTS ON K + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + INITIAL VALUE + - get_lambda() == 0 + - basis_loaded() == false + - get_max_basis_size() == 400 + - will_use_regression_loss_for_loo_cv() == true + - get_search_lambdas() == logspace(-9, 2, 50) + - this object will not be verbose unless be_verbose() is called + + WHAT THIS OBJECT REPRESENTS + This object represents a tool for performing kernel ridge regression + (This basic algorithm is also known my many other names, e.g. regularized + least squares or least squares SVM). + + The exact definition of what this algorithm does is this: + Find w and b that minimizes the following (x_i are input samples and y_i are target values): + lambda*dot(w,w) + sum_over_i( (f(x_i) - y_i)^2 ) + where f(x) == dot(x,w) - b + + Except the dot products are replaced by kernel functions. So this + algorithm is just regular old least squares regression but with the + addition of a regularization term which encourages small w and the + application of the kernel trick. + + + It is implemented using the empirical_kernel_map and thus allows you + to run the algorithm on large datasets and obtain sparse outputs. It is also + capable of estimating the lambda parameter using leave-one-out cross-validation. + + + The leave-one-out cross-validation implementation is based on the techniques + discussed in this paper: + Notes on Regularized Least Squares by Ryan M. Rifkin and Ross A. Lippert. + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + krr_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + const kernel_type get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + template + void set_basis ( + const T& basis_samples + ); + /*! + requires + - T must be a dlib::matrix type or something convertible to a matrix via mat() + (e.g. a std::vector) + - is_vector(basis_samples) == true + - basis_samples.size() > 0 + - get_kernel() must be capable of operating on the elements of basis_samples. That is, + expressions such as get_kernel()(basis_samples(0), basis_samples(0)) should make sense. + ensures + - #basis_loaded() == true + - training will be carried out in the span of the given basis_samples + !*/ + + bool basis_loaded ( + ) const; + /*! + ensures + - returns true if this object has been loaded with user supplied basis vectors and false otherwise. + !*/ + + void clear_basis ( + ); + /*! + ensures + - #basis_loaded() == false + !*/ + + unsigned long get_max_basis_size ( + ) const; + /*! + ensures + - returns the maximum number of basis vectors this object is allowed + to use. This parameter only matters when the user has not supplied + a basis via set_basis(). + !*/ + + void set_max_basis_size ( + unsigned long max_basis_size + ); + /*! + requires + - max_basis_size > 0 + ensures + - #get_max_basis_size() == max_basis_size + !*/ + + void set_lambda ( + scalar_type lambda + ); + /*! + requires + - lambda >= 0 + ensures + - #get_lambda() == lambda + !*/ + + const scalar_type get_lambda ( + ) const; + /*! + ensures + - returns the regularization parameter. It is the parameter that + determines the trade off between trying to fit the training data + exactly or allowing more errors but hopefully improving the + generalization ability of the resulting function. Smaller values + encourage exact fitting while larger values of lambda may encourage + better generalization. + + Note that a lambda of 0 has a special meaning. It indicates to this + object that it should automatically determine an appropriate lambda + value. This is done using leave-one-out cross-validation. + !*/ + + void use_regression_loss_for_loo_cv ( + ); + /*! + ensures + - #will_use_regression_loss_for_loo_cv() == true + !*/ + + void use_classification_loss_for_loo_cv ( + ); + /*! + ensures + - #will_use_regression_loss_for_loo_cv() == false + !*/ + + bool will_use_regression_loss_for_loo_cv ( + ) const; + /*! + ensures + - returns true if the automatic lambda estimation will attempt to estimate a lambda + appropriate for a regression task. Otherwise it will try and find one which + minimizes the number of classification errors. + !*/ + + template + void set_search_lambdas ( + const matrix_exp& lambdas + ); + /*! + requires + - is_vector(lambdas) == true + - lambdas.size() > 0 + - min(lambdas) > 0 + - lambdas must contain floating point numbers + ensures + - #get_search_lambdas() == lambdas + !*/ + + const matrix& get_search_lambdas ( + ) const; + /*! + ensures + - returns a matrix M such that: + - is_vector(M) == true + - M == a list of all the lambda values which will be tried when performing + LOO cross-validation for determining the best lambda. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + - is_learning_problem(x,y) == true + - if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false) then + - is_binary_classification_problem(x,y) == true + (i.e. if you want this algorithm to estimate a lambda appropriate for + classification functions then you had better give a valid classification + problem) + ensures + - performs kernel ridge regression given the training samples in x and target values in y. + - returns a decision_function F with the following properties: + - F(new_x) == predicted y value + + - if (basis_loaded()) then + - training will be carried out in the span of the user supplied basis vectors + - else + - this object will attempt to automatically select an appropriate basis + + - if (get_lambda() == 0) then + - This object will perform internal leave-one-out cross-validation to determine an + appropriate lambda automatically. It will compute the LOO error for each lambda + in get_search_lambdas() and select the best one. + - if (will_use_regression_loss_for_loo_cv()) then + - the lambda selected will be the one that minimizes the mean squared error. + - else + - the lambda selected will be the one that minimizes the number classification + mistakes. We say a point is classified correctly if the output of the + decision_function has the same sign as its label. + - #get_lambda() == 0 + (i.e. we don't change the get_lambda() value. If you want to know what the + automatically selected lambda value was then call the version of train() + defined below) + - else + - The user supplied value of get_lambda() will be used to perform the kernel + ridge regression. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + std::vector& loo_values + ) const; + /*! + requires + - all the requirements for train(x,y) must be satisfied + ensures + - returns train(x,y) + (i.e. executes train(x,y) and returns its result) + - #loo_values.size() == y.size() + - for all valid i: + - #loo_values[i] == leave-one-out prediction for the value of y(i) based + on all the training samples other than (x(i),y(i)). + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + std::vector& loo_values, + scalar_type& lambda_used + ) const; + /*! + requires + - all the requirements for train(x,y) must be satisfied + ensures + - returns train(x,y) + (i.e. executes train(x,y) and returns its result) + - #loo_values.size() == y.size() + - for all valid i: + - #loo_values[i] == leave-one-out prediction for the value of y(i) based + on all the training samples other than (x(i),y(i)). + - #lambda_used == the value of lambda used to generate the + decision_function. Note that this lambda value is always + equal to get_lambda() if get_lambda() isn't 0. + !*/ + + }; + +} + +#endif // DLIB_KRR_TRAInER_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/linearly_independent_subset_finder.h b/ml/dlib/dlib/svm/linearly_independent_subset_finder.h new file mode 100644 index 000000000..3bac0df2c --- /dev/null +++ b/ml/dlib/dlib/svm/linearly_independent_subset_finder.h @@ -0,0 +1,540 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LISfh_ +#define DLIB_LISfh_ + +#include + +#include "linearly_independent_subset_finder_abstract.h" +#include "../matrix.h" +#include "function.h" +#include "../std_allocator.h" +#include "../algs.h" +#include "../serialize.h" +#include "../is_kind.h" +#include "../string.h" +#include "../rand.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + class linearly_independent_subset_finder + { + /*! + INITIAL VALUE + - min_strength == 0 + - min_vect_idx == 0 + - K_inv.size() == 0 + - K.size() == 0 + - dictionary.size() == 0 + + CONVENTION + - max_dictionary_size() == my_max_dictionary_size + - get_kernel() == kernel + - minimum_tolerance() == min_tolerance + - size() == dictionary.size() + - get_dictionary() == mat(dictionary) + - K.nr() == dictionary.size() + - K.nc() == dictionary.size() + - for all valid r,c: + - K(r,c) == kernel(dictionary[r], dictionary[c]) + - K_inv == inv(K) + + - if (dictionary.size() == my_max_dictionary_size) then + - for all valid 0 < i < dictionary.size(): + - Let STRENGTHS[i] == the delta you would get for dictionary[i] (i.e. Approximately + Linearly Dependent value) if you removed dictionary[i] from this object and then + tried to add it back in. + - min_strength == the minimum value from STRENGTHS + - min_vect_idx == the index of the element in STRENGTHS with the smallest value + !*/ + + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::sample_type type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + linearly_independent_subset_finder ( + ) : + my_max_dictionary_size(100), + min_tolerance(0.001) + { + clear_dictionary(); + } + + linearly_independent_subset_finder ( + const kernel_type& kernel_, + unsigned long max_dictionary_size_, + scalar_type min_tolerance_ = 0.001 + ) : + kernel(kernel_), + my_max_dictionary_size(max_dictionary_size_), + min_tolerance(min_tolerance_) + { + // make sure requires clause is not broken + DLIB_ASSERT(min_tolerance_ > 0 && max_dictionary_size_ > 1, + "\tlinearly_independent_subset_finder()" + << "\n\tinvalid argument to constructor" + << "\n\tmin_tolerance_: " << min_tolerance_ + << "\n\tmax_dictionary_size_: " << max_dictionary_size_ + << "\n\tthis: " << this + ); + clear_dictionary(); + } + + unsigned long max_dictionary_size() const + { + return my_max_dictionary_size; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel; + } + + scalar_type minimum_tolerance( + ) const + { + return min_tolerance; + } + + void set_minimum_tolerance ( + scalar_type min_tol + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(min_tol > 0, + "\tlinearly_independent_subset_finder::set_minimum_tolerance()" + << "\n\tinvalid argument to this function" + << "\n\tmin_tol: " << min_tol + << "\n\tthis: " << this + ); + min_tolerance = min_tol; + } + + void clear_dictionary () + { + dictionary.clear(); + min_strength = 0; + min_vect_idx = 0; + + K_inv.set_size(0,0); + K.set_size(0,0); + } + + scalar_type projection_error ( + const sample_type& x + ) const + { + const scalar_type kx = kernel(x,x); + if (dictionary.size() == 0) + { + return kx; + } + else + { + // fill in k + k.set_size(dictionary.size()); + for (long r = 0; r < k.nr(); ++r) + k(r) = kernel(x,dictionary[r]); + + // compute the error we would have if we approximated the new x sample + // with the dictionary. That is, do the ALD test from the KRLS paper. + a = K_inv*k; + scalar_type delta = kx - trans(k)*a; + + return delta; + } + } + + bool add ( + const sample_type& x + ) + { + const scalar_type kx = kernel(x,x); + if (dictionary.size() == 0) + { + // just ignore this sample if it is the zero vector (or really close to being zero) + if (std::abs(kx) > std::numeric_limits::epsilon()) + { + // set initial state since this is the first sample we have seen + K_inv.set_size(1,1); + K_inv(0,0) = 1/kx; + + K.set_size(1,1); + K(0,0) = kx; + + dictionary.push_back(x); + return true; + } + return false; + } + else + { + // fill in k + k.set_size(dictionary.size()); + for (long r = 0; r < k.nr(); ++r) + k(r) = kernel(x,dictionary[r]); + + // compute the error we would have if we approximated the new x sample + // with the dictionary. That is, do the ALD test from the KRLS paper. + a = K_inv*k; + scalar_type delta = kx - trans(k)*a; + + // if this new vector is approximately linearly independent of the vectors + // in our dictionary. + if (delta > min_strength && delta > min_tolerance) + { + if (dictionary.size() == my_max_dictionary_size) + { + // if we have never computed the min_strength then we should compute it + if (min_strength == 0) + recompute_min_strength(); + + const long i = min_vect_idx; + + // replace the min strength vector with x. Put the new vector onto the end of + // dictionary and remove the vector at position i. + dictionary.erase(dictionary.begin()+i); + dictionary.push_back(x); + + // compute reduced K_inv. + // Remove the i'th vector from the inverse kernel matrix. This formula is basically + // just the reverse of the way K_inv is updated by equation 3.14 below. + temp = removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i); + + // recompute these guys since they were computed with the old + // kernel matrix + k2 = remove_row(k,i); + a2 = temp*k2; + delta = kx - trans(k2)*a2; + + // now update temp with the new dictionary vector + // update the middle part of the matrix + set_subm(K_inv, get_rect(temp)) = temp + a2*trans(a2)/delta; + // update the right column of the matrix + set_subm(K_inv, 0, temp.nr(),temp.nr(),1) = -a2/delta; + // update the bottom row of the matrix + set_subm(K_inv, temp.nr(), 0, 1, temp.nr()) = trans(-a2/delta); + // update the bottom right corner of the matrix + K_inv(temp.nr(), temp.nc()) = 1/delta; + + // now update the kernel matrix K + set_subm(K,get_rect(temp)) = removerc(K, i,i); + set_subm(K, 0, K.nr()-1,K.nr()-1,1) = k2; + // update the bottom row of the matrix + set_subm(K, K.nr()-1, 0, 1, K.nr()-1) = trans(k2); + K(K.nr()-1, K.nc()-1) = kx; + + // now we have to recompute the min_strength in this case + recompute_min_strength(); + } + else + { + // update K_inv by computing the new one in the temp matrix (equation 3.14 from Engel) + temp.set_size(K_inv.nr()+1, K_inv.nc()+1); + // update the middle part of the matrix + set_subm(temp, get_rect(K_inv)) = K_inv + a*trans(a)/delta; + // update the right column of the matrix + set_subm(temp, 0, K_inv.nr(),K_inv.nr(),1) = -a/delta; + // update the bottom row of the matrix + set_subm(temp, K_inv.nr(), 0, 1, K_inv.nr()) = trans(-a/delta); + // update the bottom right corner of the matrix + temp(K_inv.nr(), K_inv.nc()) = 1/delta; + // put temp into K_inv + temp.swap(K_inv); + + + // update K (the kernel matrix) + temp.set_size(K.nr()+1, K.nc()+1); + set_subm(temp, get_rect(K)) = K; + // update the right column of the matrix + set_subm(temp, 0, K.nr(),K.nr(),1) = k; + // update the bottom row of the matrix + set_subm(temp, K.nr(), 0, 1, K.nr()) = trans(k); + temp(K.nr(), K.nc()) = kx; + // put temp into K + temp.swap(K); + + + // add x to the dictionary + dictionary.push_back(x); + + } + return true; + } + else + { + return false; + } + } + } + + void swap ( + linearly_independent_subset_finder& item + ) + { + exchange(kernel, item.kernel); + dictionary.swap(item.dictionary); + exchange(min_strength, item.min_strength); + exchange(min_vect_idx, item.min_vect_idx); + K_inv.swap(item.K_inv); + K.swap(item.K); + exchange(my_max_dictionary_size, item.my_max_dictionary_size); + exchange(min_tolerance, item.min_tolerance); + + // non-state temp members + a.swap(item.a); + k.swap(item.k); + a2.swap(item.a2); + k2.swap(item.k2); + temp.swap(item.temp); + } + + size_t size ( + ) const { return dictionary.size(); } + + const matrix get_dictionary ( + ) const + { + return mat(dictionary); + } + + friend void serialize(const linearly_independent_subset_finder& item, std::ostream& out) + { + serialize(item.kernel, out); + serialize(item.dictionary, out); + serialize(item.min_strength, out); + serialize(item.min_vect_idx, out); + serialize(item.K_inv, out); + serialize(item.K, out); + serialize(item.my_max_dictionary_size, out); + serialize(item.min_tolerance, out); + } + + friend void deserialize(linearly_independent_subset_finder& item, std::istream& in) + { + deserialize(item.kernel, in); + deserialize(item.dictionary, in); + deserialize(item.min_strength, in); + deserialize(item.min_vect_idx, in); + deserialize(item.K_inv, in); + deserialize(item.K, in); + deserialize(item.my_max_dictionary_size, in); + deserialize(item.min_tolerance, in); + } + + const sample_type& operator[] ( + unsigned long index + ) const + { + return dictionary[index]; + } + + const matrix& get_kernel_matrix ( + ) const + { + return K; + } + + const matrix& get_inv_kernel_marix ( + ) const + { + return K_inv; + } + + private: + + typedef std_allocator alloc_sample_type; + typedef std_allocator alloc_scalar_type; + typedef std::vector dictionary_vector_type; + typedef std::vector scalar_vector_type; + + void recompute_min_strength ( + ) + /*! + ensures + - recomputes the min_strength and min_vect_idx values + so that they are correct with respect to the CONVENTION + !*/ + { + min_strength = std::numeric_limits::max(); + + // here we loop over each dictionary vector and compute what its delta would be if + // we were to remove it from the dictionary and then try to add it back in. + for (unsigned long i = 0; i < dictionary.size(); ++i) + { + // compute a2 = K_inv*k but where dictionary vector i has been removed + a2 = (removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i)) * + (remove_row(colm(K,i),i)); + scalar_type delta = K(i,i) - trans(remove_row(colm(K,i),i))*a2; + + if (delta < min_strength) + { + min_strength = delta; + min_vect_idx = i; + } + } + } + + + kernel_type kernel; + dictionary_vector_type dictionary; + scalar_type min_strength; + unsigned long min_vect_idx; + + matrix K_inv; + matrix K; + + unsigned long my_max_dictionary_size; + scalar_type min_tolerance; + + // temp variables here just so we don't have to reconstruct them over and over. Thus, + // they aren't really part of the state of this object. + mutable matrix a, a2; + mutable matrix k, k2; + mutable matrix temp; + + }; + +// ---------------------------------------------------------------------------------------- + + template + void swap(linearly_independent_subset_finder& a, linearly_independent_subset_finder& b) + { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const matrix_op > > mat ( + const linearly_independent_subset_finder& m + ) + { + typedef op_array_to_mat > op; + return matrix_op(op(m)); + } + +// ---------------------------------------------------------------------------------------- + namespace impl + { + template < + typename kernel_type, + typename vector_type, + typename rand_type + > + void fill_lisf ( + linearly_independent_subset_finder& lisf, + const vector_type& samples, + rand_type& rnd, + int sampling_size + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(samples) && sampling_size > 0, + "\t void fill_lisf()" + << "\n\t invalid arguments to this function" + << "\n\t is_vector(samples): " << is_vector(samples) + << "\n\t sampling_size: " << sampling_size + ); + + // no need to do anything if there aren't any samples + if (samples.size() == 0) + return; + + typedef typename kernel_type::scalar_type scalar_type; + + // Start out by guessing what a reasonable projection error tolerance is. We will use + // the biggest projection error we see in a small sample. + scalar_type tol = 0; + for (int i = 0; i < sampling_size; ++i) + { + const unsigned long idx = rnd.get_random_32bit_number()%samples.size(); + const scalar_type temp = lisf.projection_error(samples(idx)); + if (temp > tol) + tol = temp; + } + + const scalar_type min_tol = lisf.minimum_tolerance(); + + // run many rounds of random sampling. In each round we drop the tolerance lower. + while (tol >= min_tol && lisf.size() < lisf.max_dictionary_size()) + { + tol *= 0.5; + lisf.set_minimum_tolerance(std::max(tol, min_tol)); + int add_failures = 0; + + // Keep picking random samples and adding them into the lisf. Stop when we either + // fill it up or can't find any more samples with projection error larger than the + // current tolerance. + while (lisf.size() < lisf.max_dictionary_size() && add_failures < sampling_size) + { + if (lisf.add(samples(rnd.get_random_32bit_number()%samples.size())) == false) + { + ++add_failures; + } + } + } + + // set this back to its original value + lisf.set_minimum_tolerance(min_tol); + } + } + + template < + typename kernel_type, + typename vector_type + > + void fill_lisf ( + linearly_independent_subset_finder& lisf, + const vector_type& samples + ) + { + dlib::rand rnd; + impl::fill_lisf(lisf, mat(samples),rnd, 2000); + } + + template < + typename kernel_type, + typename vector_type, + typename rand_type + > + typename enable_if >::type fill_lisf ( + linearly_independent_subset_finder& lisf, + const vector_type& samples, + rand_type& rnd, + const int sampling_size = 2000 + ) + { + impl::fill_lisf(lisf, mat(samples),rnd, sampling_size); + } + + template < + typename kernel_type, + typename vector_type, + typename rand_type + > + typename disable_if >::type fill_lisf ( + linearly_independent_subset_finder& lisf, + const vector_type& samples, + rand_type random_seed, + const int sampling_size = 2000 + ) + { + dlib::rand rnd; + rnd.set_seed(cast_to_string(random_seed)); + impl::fill_lisf(lisf, mat(samples), rnd, sampling_size); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LISfh_ + diff --git a/ml/dlib/dlib/svm/linearly_independent_subset_finder_abstract.h b/ml/dlib/dlib/svm/linearly_independent_subset_finder_abstract.h new file mode 100644 index 000000000..3224f9a0a --- /dev/null +++ b/ml/dlib/dlib/svm/linearly_independent_subset_finder_abstract.h @@ -0,0 +1,327 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LISf_ABSTRACT_ +#ifdef DLIB_LISf_ABSTRACT_ + +#include "../algs.h" +#include "../serialize.h" +#include "kernel_abstract.h" + +namespace dlib +{ + + template < + typename kernel_type + > + class linearly_independent_subset_finder + { + /*! + REQUIREMENTS ON kernel_type + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + INITIAL VALUE + - size() == 0 + + WHAT THIS OBJECT REPRESENTS + This is an implementation of an online algorithm for recursively finding a + set (aka dictionary) of linearly independent vectors in a kernel induced + feature space. To use it you decide how large you would like the dictionary + to be and then you feed it sample points. + + The implementation uses the Approximately Linearly Dependent metric described + in the paper The Kernel Recursive Least Squares Algorithm by Yaakov Engel to + decide which points are more linearly independent than others. The metric is + simply the squared distance between a test point and the subspace spanned by + the set of dictionary vectors. + + Each time you present this object with a new sample point (via this->add()) + it calculates the projection distance and if it is sufficiently large then this + new point is included into the dictionary. Note that this object can be configured + to have a maximum size. Once the max dictionary size is reached each new point + kicks out a previous point. This is done by removing the dictionary vector that + has the smallest projection distance onto the others. That is, the "least linearly + independent" vector is removed to make room for the new one. + !*/ + + public: + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::sample_type type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + linearly_independent_subset_finder ( + ); + /*! + ensures + - #minimum_tolerance() == 0.001 + - this object is properly initialized + - #get_kernel() == kernel_type() (i.e. whatever the default is for the supplied kernel) + - #max_dictionary_size() == 100 + !*/ + + linearly_independent_subset_finder ( + const kernel_type& kernel_, + unsigned long max_dictionary_size_, + scalar_type min_tolerance = 0.001 + ); + /*! + requires + - min_tolerance > 0 + - max_dictionary_size > 1 + ensures + - #minimum_tolerance() == min_tolerance + - this object is properly initialized + - #get_kernel() == kernel_ + - #max_dictionary_size() == max_dictionary_size_ + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a const reference to the kernel used by this object + !*/ + + unsigned long max_dictionary_size( + ) const; + /*! + ensures + - returns the maximum number of dictionary vectors this object + will accumulate. That is, size() will never be + greater than max_dictionary_size(). + !*/ + + scalar_type minimum_tolerance( + ) const; + /*! + ensures + - returns the minimum projection error necessary to include a sample point + into the dictionary. + !*/ + + void set_minimum_tolerance ( + scalar_type min_tolerance + ); + /*! + requires + - min_tolerance > 0 + ensures + - #minimum_tolerance() == min_tolerance + !*/ + + void clear_dictionary ( + ); + /*! + ensures + - clears out all the data (e.g. #size() == 0) + !*/ + + bool add ( + const sample_type& x + ); + /*! + ensures + - if (size() < max_dictionary_size() then + - if (projection_error(x) > minimum_tolerance()) then + - adds x into the dictionary + - (*this)[#size()-1] == x + - #size() == size() + 1 + - returns true + - else + - the dictionary is not changed + - returns false + - else + - #size() == size() + (i.e. the number of vectors in this object doesn't change) + - since the dictionary is full adding a new element means we have to + remove one of the current ones. So let proj_error[i] be equal to the + projection error obtained when projecting dictionary vector (*this)[i] + onto the other elements of the dictionary. Then let min_proj_error + be equal to the minimum value in proj_error. The dictionary element + with the minimum projection error is the "least linearly independent" + vector in the dictionary and is the one which will be removed to make + room for a new element. + - if (projection_error(x) > minimum_tolerance() && projection_error(x) > min_proj_error) + - the least linearly independent vector in this object is removed + - adds x into the dictionary + - (*this)[#size()-1] == x + - returns true + - else + - the dictionary is not changed + - returns false + !*/ + + scalar_type projection_error ( + const sample_type& x + ) const; + /*! + ensures + - returns the squared distance between x and the subspace spanned by + the set of dictionary vectors. (e.g. this is the same number that + gets returned by the empirical_kernel_map::project() function's + projection_error argument when the ekm is loaded with the dictionary + vectors.) + - Note that if the dictionary is empty then the return value is + equal to get_kernel()(x,x). + !*/ + + void swap ( + linearly_independent_subset_finder& item + ); + /*! + ensures + - swaps *this with item + !*/ + + size_t size ( + ) const; + /*! + ensures + - returns the number of vectors in the dictionary. + !*/ + + const sample_type& operator[] ( + unsigned long index + ) const; + /*! + requires + - index < size() + ensures + - returns the index'th element in the set of linearly independent + vectors contained in this object. + !*/ + + const matrix get_dictionary ( + ) const; + /*! + ensures + - returns a column vector that contains all the dictionary + vectors in this object. + !*/ + + const matrix& get_kernel_matrix ( + ) const; + /*! + ensures + - returns a matrix K such that: + - K.nr() == K.nc() == size() + - K == kernel_matrix(get_kernel(), get_dictionary()) + i.e. K == the kernel matrix for the dictionary vectors + !*/ + + const matrix& get_inv_kernel_marix ( + ) const; + /*! + ensures + - if (size() != 0) + - returns inv(get_kernel_matrix()) + - else + - returns an empty matrix + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type + > + void swap( + linearly_independent_subset_finder& a, + linearly_independent_subset_finder& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename kernel_type + > + void serialize ( + const linearly_independent_subset_finder& item, + std::ostream& out + ); + /*! + provides serialization support for linearly_independent_subset_finder objects + !*/ + + template < + typename kernel_type + > + void deserialize ( + linearly_independent_subset_finder& item, + std::istream& in + ); + /*! + provides serialization support for linearly_independent_subset_finder objects + !*/ + + template < + typename T + > + const matrix_exp mat ( + const linearly_independent_subset_finder& m + ); + /*! + ensures + - converts m into a matrix + - returns a matrix R such that: + - is_col_vector(R) == true + - R.size() == m.size() + - for all valid r: + R(r) == m[r] + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename vector_type, + typename rand_type + > + void fill_lisf ( + linearly_independent_subset_finder& lisf, + const vector_type& samples, + rand_type& rnd, + int sampling_size = 2000 + ); + /*! + requires + - vector_type == a dlib::matrix or something convertible to one via + mat() + - is_vector(mat(samples)) == true + - rand_type == an implementation of rand/rand_kernel_abstract.h or a type + convertible to a string via cast_to_string() + - sampling_size > 0 + ensures + - The purpose of this function is to fill lisf with points from samples. It does + this by randomly sampling elements of samples until no more can be added. The + precise stopping condition is when sampling_size additions to lisf have failed + or the max dictionary size has been reached. + - This function employs a random number generator. If rand_type is a random + number generator then it uses the instance given. Otherwise it uses cast_to_string(rnd) + to seed a new random number generator. + !*/ + + template < + typename kernel_type, + typename vector_type + > + void fill_lisf ( + linearly_independent_subset_finder& lisf, + const vector_type& samples + ); + /*! + requires + - vector_type == a dlib::matrix or something convertible to one via + mat() + - is_vector(mat(samples)) == true + ensures + - performs fill_lisf(lisf, samples, default_rand_generator, 2000) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LISf_ABSTRACT_ + diff --git a/ml/dlib/dlib/svm/multiclass_tools.h b/ml/dlib/dlib/svm/multiclass_tools.h new file mode 100644 index 000000000..d97e8aa04 --- /dev/null +++ b/ml/dlib/dlib/svm/multiclass_tools.h @@ -0,0 +1,68 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MULTICLASS_TOoLS_Hh_ +#define DLIB_MULTICLASS_TOoLS_Hh_ + +#include "multiclass_tools_abstract.h" + +#include +#include +#include "../unordered_pair.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + std::vector select_all_distinct_labels ( + const std::vector& labels + ) + { + std::set temp; + temp.insert(labels.begin(), labels.end()); + return std::vector(temp.begin(), temp.end()); + } + +// ---------------------------------------------------------------------------------------- + + template + std::vector > find_missing_pairs ( + const std::map,U>& bdfs + ) + { + typedef std::map,U> map_type; + + // find all the labels + std::set temp; + for (typename map_type::const_iterator i = bdfs.begin(); i != bdfs.end(); ++i) + { + temp.insert(i->first.first); + temp.insert(i->first.second); + } + + std::vector > missing_pairs; + + // now make sure all label pairs are present + typename std::set::const_iterator i, j; + for (i = temp.begin(); i != temp.end(); ++i) + { + for (j = i, ++j; j != temp.end(); ++j) + { + const unordered_pair p(*i, *j); + + if (bdfs.count(p) == 0) + missing_pairs.push_back(p); + } + } + + return missing_pairs; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MULTICLASS_TOoLS_Hh_ + + diff --git a/ml/dlib/dlib/svm/multiclass_tools_abstract.h b/ml/dlib/dlib/svm/multiclass_tools_abstract.h new file mode 100644 index 000000000..9e7774d3f --- /dev/null +++ b/ml/dlib/dlib/svm/multiclass_tools_abstract.h @@ -0,0 +1,45 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MULTICLASS_TOoLS_ABSTRACT_Hh_ +#ifdef DLIB_MULTICLASS_TOoLS_ABSTRACT_Hh_ + +#include +#include +#include "../unordered_pair.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + std::vector select_all_distinct_labels ( + const std::vector& labels + ); + /*! + ensures + - Determines all distinct values present in labels and stores them + into a sorted vector and returns it. They are sorted in ascending + order. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + std::vector > find_missing_pairs ( + const std::map,U>& binary_decision_functions + ); + /*! + ensures + - Let L denote the set of all label_type values present in binary_decision_functions. + - This function finds all the label pairs with both elements distinct and in L but + not also in binary_decision_functions. All these missing pairs are stored + in a sorted vector and returned. They are sorted in ascending order. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MULTICLASS_TOoLS_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/null_df.h b/ml/dlib/dlib/svm/null_df.h new file mode 100644 index 000000000..2cbbf04a7 --- /dev/null +++ b/ml/dlib/dlib/svm/null_df.h @@ -0,0 +1,33 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_NULL_DECISION_FUnCTION_Hh_ +#define DLIB_NULL_DECISION_FUnCTION_Hh_ + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct null_df + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a type used to represent an unused field in the list of template + arguments of the one_vs_one_decision_function and one_vs_all_decision_function + templates. As such, null_df doesn't actually do anything. + !*/ + template + double operator() ( const T&) const { return 0; } + }; + + inline void serialize(const null_df&, std::ostream&) {} + inline void deserialize(null_df&, std::istream&) {} + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_NULL_DECISION_FUnCTION_Hh_ + diff --git a/ml/dlib/dlib/svm/null_trainer.h b/ml/dlib/dlib/svm/null_trainer.h new file mode 100644 index 000000000..015b00c15 --- /dev/null +++ b/ml/dlib/dlib/svm/null_trainer.h @@ -0,0 +1,61 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_NULL_TRAINERs_H_ +#define DLIB_NULL_TRAINERs_H_ + +#include "null_trainer_abstract.h" +#include "../algs.h" +#include "function_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename dec_funct_type + > + class null_trainer_type + { + public: + typedef typename dec_funct_type::kernel_type kernel_type; + typedef typename dec_funct_type::scalar_type scalar_type; + typedef typename dec_funct_type::sample_type sample_type; + typedef typename dec_funct_type::mem_manager_type mem_manager_type; + typedef dec_funct_type trained_function_type; + + null_trainer_type ( + ){} + + null_trainer_type ( + const dec_funct_type& dec_funct_ + ) : dec_funct(dec_funct_) {} + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const dec_funct_type& train ( + const in_sample_vector_type& , + const in_scalar_vector_type& + ) const { return dec_funct; } + + private: + dec_funct_type dec_funct; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename dec_funct_type + > + const null_trainer_type null_trainer ( + const dec_funct_type& dec_funct + ) { return null_trainer_type(dec_funct); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_NULL_TRAINERs_H_ + diff --git a/ml/dlib/dlib/svm/null_trainer_abstract.h b/ml/dlib/dlib/svm/null_trainer_abstract.h new file mode 100644 index 000000000..25f6a5443 --- /dev/null +++ b/ml/dlib/dlib/svm/null_trainer_abstract.h @@ -0,0 +1,101 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_NULL_TRAINERs_ABSTRACT_ +#ifdef DLIB_NULL_TRAINERs_ABSTRACT_ + +#include "../algs.h" +#include "function_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename dec_funct_type + > + class null_trainer_type + { + /*! + REQUIREMENTS ON dec_funct_type + dec_funct_type can be any copyable type that provides the needed + typedefs used below (e.g. kernel_type, scalar_type, etc.). + + WHAT THIS OBJECT REPRESENTS + This object is a simple tool for turning a decision function + into a trainer object that always returns the original decision + function when you try to train with it. + + dlib contains a few "training post processing" algorithms (e.g. + reduced() and reduced2()). These tools take in a trainer object, + tell it to perform training, and then they take the output decision + function and do some kind of post processing to it. The null_trainer_type + object is useful because you can use it to run an already + learned decision function through the training post processing + algorithms by turning a decision function into a null_trainer_type + and then giving it to a post processor. + !*/ + + public: + typedef typename dec_funct_type::kernel_type kernel_type; + typedef typename dec_funct_type::scalar_type scalar_type; + typedef typename dec_funct_type::sample_type sample_type; + typedef typename dec_funct_type::mem_manager_type mem_manager_type; + typedef dec_funct_type trained_function_type; + + null_trainer_type ( + ); + /*! + ensures + - any call to this->train(x,y) will return a default initialized + dec_funct_type object. + !*/ + + null_trainer_type ( + const dec_funct_type& dec_funct + ); + /*! + ensures + - any call to this->train(x,y) will always return a copy of + the given dec_funct object. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const dec_funct_type& train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + ensures + - returns a copy of the decision function object given to + this object's constructor. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename dec_funct_type + > + const null_trainer_type null_trainer ( + const dec_funct_type& dec_funct + ) { return null_trainer_type(dec_funct); } + /*! + ensures + - returns a null_trainer_type object that has been instantiated with + the given arguments. That is, this function returns a null_trainer_type + trainer that will return a copy of the given dec_funct object every time + someone calls its train() function. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_NULL_TRAINERs_ABSTRACT_ + + diff --git a/ml/dlib/dlib/svm/num_nonnegative_weights.h b/ml/dlib/dlib/svm/num_nonnegative_weights.h new file mode 100644 index 000000000..4f21f9b69 --- /dev/null +++ b/ml/dlib/dlib/svm/num_nonnegative_weights.h @@ -0,0 +1,76 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_NUM_NONNEGATIVE_WEIGHtS_Hh_ +#define DLIB_NUM_NONNEGATIVE_WEIGHtS_Hh_ + +#include "../enable_if.h" + +namespace dlib +{ + + namespace impl2 + { + template < + typename T, + unsigned long (T::*funct)()const + > + struct hnnf_helper + { + typedef char type; + }; + + template + char has_num_nonnegative_weights_helper( typename hnnf_helper::type = 0 ) { return 0;} + + struct two_bytes + { + char a[2]; + }; + + template + two_bytes has_num_nonnegative_weights_helper(int) { return two_bytes();} + + template + struct work_around_visual_studio_bug + { + const static unsigned long U = sizeof(has_num_nonnegative_weights_helper('a')); + }; + + + // This is a template to tell you if a feature_extractor has a num_nonnegative_weights function or not. + template ::U > + struct has_num_nonnegative_weights + { + static const bool value = false; + }; + + template + struct has_num_nonnegative_weights + { + static const bool value = true; + }; + + + } + + // call fe.num_nonnegative_weights() if it exists, otherwise return 0. + template + typename enable_if,unsigned long>::type num_nonnegative_weights ( + const feature_extractor& fe + ) + { + return fe.num_nonnegative_weights(); + } + + template + typename disable_if,unsigned long>::type num_nonnegative_weights ( + const feature_extractor& /*fe*/ + ) + { + return 0; + } + +} + +#endif // DLIB_NUM_NONNEGATIVE_WEIGHtS_Hh_ + diff --git a/ml/dlib/dlib/svm/one_vs_all_decision_function.h b/ml/dlib/dlib/svm/one_vs_all_decision_function.h new file mode 100644 index 000000000..8afa52344 --- /dev/null +++ b/ml/dlib/dlib/svm/one_vs_all_decision_function.h @@ -0,0 +1,265 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ONE_VS_ALL_DECISION_FUnCTION_Hh_ +#define DLIB_ONE_VS_ALL_DECISION_FUnCTION_Hh_ + +#include "one_vs_all_decision_function_abstract.h" + +#include "../serialize.h" +#include "../type_safe_union.h" +#include +#include +#include "../any.h" +#include "null_df.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename one_vs_all_trainer, + typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df, + typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df, + typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df, + typename DF10 = null_df + > + class one_vs_all_decision_function + { + public: + + typedef typename one_vs_all_trainer::label_type result_type; + typedef typename one_vs_all_trainer::sample_type sample_type; + typedef typename one_vs_all_trainer::scalar_type scalar_type; + typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type; + + typedef std::map > binary_function_table; + + one_vs_all_decision_function() :num_classes(0) {} + + explicit one_vs_all_decision_function( + const binary_function_table& dfs_ + ) : dfs(dfs_) + { + num_classes = dfs.size(); + } + + const binary_function_table& get_binary_decision_functions ( + ) const + { + return dfs; + } + + const std::vector get_labels ( + ) const + { + std::vector temp; + temp.reserve(dfs.size()); + for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i) + { + temp.push_back(i->first); + } + return temp; + } + + + template < + typename df1, typename df2, typename df3, typename df4, typename df5, + typename df6, typename df7, typename df8, typename df9, typename df10 + > + one_vs_all_decision_function ( + const one_vs_all_decision_function& item + ) : dfs(item.get_binary_decision_functions()), num_classes(item.number_of_classes()) {} + + unsigned long number_of_classes ( + ) const + { + return num_classes; + } + + std::pair predict ( + const sample_type& sample + ) const + { + DLIB_ASSERT(number_of_classes() != 0, + "\t pair one_vs_all_decision_function::predict()" + << "\n\t You can't make predictions with an empty decision function." + << "\n\t this: " << this + ); + + result_type best_label = result_type(); + scalar_type best_score = -std::numeric_limits::infinity(); + + // run all the classifiers over the sample and find the best one + for(typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i) + { + const scalar_type score = i->second(sample); + + if (score > best_score) + { + best_score = score; + best_label = i->first; + } + } + + return std::make_pair(best_label, best_score); + } + + result_type operator() ( + const sample_type& sample + ) const + { + DLIB_ASSERT(number_of_classes() != 0, + "\t result_type one_vs_all_decision_function::operator()" + << "\n\t You can't make predictions with an empty decision function." + << "\n\t this: " << this + ); + + return predict(sample).first; + } + + + + private: + binary_function_table dfs; + unsigned long num_classes; + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename DF1, typename DF2, typename DF3, + typename DF4, typename DF5, typename DF6, + typename DF7, typename DF8, typename DF9, + typename DF10 + > + void serialize( + const one_vs_all_decision_function& item, + std::ostream& out + ) + { + try + { + type_safe_union temp; + typedef typename T::label_type result_type; + typedef typename T::sample_type sample_type; + typedef typename T::scalar_type scalar_type; + typedef std::map > binary_function_table; + + const unsigned long version = 1; + serialize(version, out); + + const unsigned long size = item.get_binary_decision_functions().size(); + serialize(size, out); + + for(typename binary_function_table::const_iterator i = item.get_binary_decision_functions().begin(); + i != item.get_binary_decision_functions().end(); ++i) + { + serialize(i->first, out); + + if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else throw serialization_error("Can't serialize one_vs_all_decision_function. Not all decision functions defined."); + + serialize(temp,out); + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing an object of type one_vs_all_decision_function"); + } + + } + +// ---------------------------------------------------------------------------------------- + + namespace impl_ova + { + template + struct copy_to_df_helper + { + copy_to_df_helper(any_decision_function& target_) : target(target_) {} + + any_decision_function& target; + + template + void operator() ( + const T& item + ) const + { + target = item; + } + }; + } + + template < + typename T, + typename DF1, typename DF2, typename DF3, + typename DF4, typename DF5, typename DF6, + typename DF7, typename DF8, typename DF9, + typename DF10 + > + void deserialize( + one_vs_all_decision_function& item, + std::istream& in + ) + { + try + { + type_safe_union temp; + typedef typename T::label_type result_type; + typedef typename T::sample_type sample_type; + typedef typename T::scalar_type scalar_type; + typedef impl_ova::copy_to_df_helper copy_to; + + unsigned long version; + deserialize(version, in); + + if (version != 1) + throw serialization_error("Can't deserialize one_vs_all_decision_function. Wrong version."); + + unsigned long size; + deserialize(size, in); + + typedef std::map > binary_function_table; + binary_function_table dfs; + + result_type l; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(l, in); + deserialize(temp, in); + if (temp.template contains()) + throw serialization_error("A sub decision function of unknown type was encountered."); + + temp.apply_to_contents(copy_to(dfs[l])); + } + + item = one_vs_all_decision_function(dfs); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing an object of type one_vs_all_decision_function"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ONE_VS_ALL_DECISION_FUnCTION_Hh_ + + + diff --git a/ml/dlib/dlib/svm/one_vs_all_decision_function_abstract.h b/ml/dlib/dlib/svm/one_vs_all_decision_function_abstract.h new file mode 100644 index 000000000..8daacb8d6 --- /dev/null +++ b/ml/dlib/dlib/svm/one_vs_all_decision_function_abstract.h @@ -0,0 +1,214 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ONE_VS_ALL_DECISION_FUnCTION_ABSTRACT_Hh_ +#ifdef DLIB_ONE_VS_ALL_DECISION_FUnCTION_ABSTRACT_Hh_ + + +#include "../serialize.h" +#include +#include "../any/any_decision_function_abstract.h" +#include "one_vs_all_trainer_abstract.h" +#include "null_df.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename one_vs_all_trainer, + typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df, + typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df, + typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df, + typename DF10 = null_df + > + class one_vs_all_decision_function + { + /*! + REQUIREMENTS ON one_vs_all_trainer + This should be an instantiation of the one_vs_all_trainer template. + It is used to infer which types are used for various things, such as + representing labels. + + REQUIREMENTS ON DF* + These types can either be left at their default values or set + to any kind of decision function object capable of being + stored in an any_decision_function + object. These types should also be serializable. + + WHAT THIS OBJECT REPRESENTS + This object represents a multiclass classifier built out of a set of + binary classifiers. Each binary classifier is used to vote for the + correct multiclass label using a one vs. all strategy. Therefore, + if you have N classes then there will be N binary classifiers inside + this object. + + Note that the DF* template arguments are only used if you want + to serialize and deserialize one_vs_all_decision_function objects. + Specifically, all the types of binary decision function contained + within a one_vs_all_decision_function must be listed in the + template arguments if serialization and deserialization is to + be used. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call the const members of this object from multiple + threads so long as all the decision functions contained in this object + are also threadsafe. This is because the const members are purely + read-only operations. However, any operation that modifies a + one_vs_all_decision_function is not threadsafe. + !*/ + public: + + typedef typename one_vs_all_trainer::label_type result_type; + typedef typename one_vs_all_trainer::sample_type sample_type; + typedef typename one_vs_all_trainer::scalar_type scalar_type; + typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type; + + typedef std::map > binary_function_table; + + one_vs_all_decision_function( + ); + /*! + ensures + - #number_of_classes() == 0 + - #get_binary_decision_functions().size() == 0 + - #get_labels().size() == 0 + !*/ + + explicit one_vs_all_decision_function( + const binary_function_table& decision_functions + ); + /*! + ensures + - #get_binary_decision_functions() == decision_functions + - #get_labels() == a list of all the labels which appear in the + given set of decision functions + - #number_of_classes() == #get_labels().size() + !*/ + + template < + typename df1, typename df2, typename df3, typename df4, typename df5, + typename df6, typename df7, typename df8, typename df9, typename df10 + > + one_vs_all_decision_function ( + const one_vs_all_decision_function& item + ); + /*! + ensures + - #*this will be a copy of item + - #number_of_classes() == item.number_of_classes() + - #get_labels() == item.get_labels() + - #get_binary_decision_functions() == item.get_binary_decision_functions() + !*/ + + const binary_function_table& get_binary_decision_functions ( + ) const; + /*! + ensures + - returns the table of binary decision functions used by this + object. The label given to a test sample is computed by + determining which binary decision function has the largest + (i.e. most positive) output and returning the label associated + with that decision function. + !*/ + + const std::vector get_labels ( + ) const; + /*! + ensures + - returns a vector containing all the labels which can be + predicted by this object. + !*/ + + unsigned long number_of_classes ( + ) const; + /*! + ensures + - returns get_labels().size() + (i.e. returns the number of different labels/classes predicted by + this object) + !*/ + + std::pair predict ( + const sample_type& sample + ) const; + /*! + requires + - number_of_classes() != 0 + ensures + - Evaluates all the decision functions in get_binary_decision_functions() + and returns the predicted label and score for the input sample. That is, + returns std::make_pair(label,score) + - The label is determined by whichever classifier outputs the largest + score. + !*/ + + result_type operator() ( + const sample_type& sample + ) const + /*! + requires + - number_of_classes() != 0 + ensures + - Evaluates all the decision functions in get_binary_decision_functions() + and returns the predicted label. That is, returns predict(sample).first. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename DF1, typename DF2, typename DF3, + typename DF4, typename DF5, typename DF6, + typename DF7, typename DF8, typename DF9, + typename DF10 + > + void serialize( + const one_vs_all_decision_function& item, + std::ostream& out + ); + /*! + ensures + - writes the given item to the output stream out. + throws + - serialization_error. + This is thrown if there is a problem writing to the ostream or if item + contains a type of decision function not listed among the DF* template + arguments. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename DF1, typename DF2, typename DF3, + typename DF4, typename DF5, typename DF6, + typename DF7, typename DF8, typename DF9, + typename DF10 + > + void deserialize( + one_vs_all_decision_function& item, + std::istream& in + ); + /*! + ensures + - deserializes a one_vs_all_decision_function from in and stores it in item. + throws + - serialization_error. + This is thrown if there is a problem reading from the istream or if the + serialized data contains decision functions not listed among the DF* + template arguments. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ONE_VS_ALL_DECISION_FUnCTION_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/one_vs_all_trainer.h b/ml/dlib/dlib/svm/one_vs_all_trainer.h new file mode 100644 index 000000000..bcb006a41 --- /dev/null +++ b/ml/dlib/dlib/svm/one_vs_all_trainer.h @@ -0,0 +1,234 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ONE_VS_ALL_TRAiNER_Hh_ +#define DLIB_ONE_VS_ALL_TRAiNER_Hh_ + +#include "one_vs_all_trainer_abstract.h" + +#include "one_vs_all_decision_function.h" +#include + +#include "multiclass_tools.h" + +#include +#include + +#include "../any.h" +#include +#include +#include "../threads.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename any_trainer, + typename label_type_ = double + > + class one_vs_all_trainer + { + public: + typedef label_type_ label_type; + + typedef typename any_trainer::sample_type sample_type; + typedef typename any_trainer::scalar_type scalar_type; + typedef typename any_trainer::mem_manager_type mem_manager_type; + + typedef one_vs_all_decision_function trained_function_type; + + one_vs_all_trainer ( + ) : + verbose(false), + num_threads(4) + {} + + void set_trainer ( + const any_trainer& trainer + ) + { + default_trainer = trainer; + trainers.clear(); + } + + void set_trainer ( + const any_trainer& trainer, + const label_type& l + ) + { + trainers[l] = trainer; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + struct invalid_label : public dlib::error + { + invalid_label(const std::string& msg, const label_type& l_ + ) : dlib::error(msg), l(l_) {}; + + virtual ~invalid_label( + ) throw() {} + + label_type l; + }; + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(all_samples,all_labels), + "\t trained_function_type one_vs_all_trainer::train(all_samples,all_labels)" + << "\n\t invalid inputs were given to this function" + << "\n\t all_samples.size(): " << all_samples.size() + << "\n\t all_labels.size(): " << all_labels.size() + ); + + const std::vector distinct_labels = select_all_distinct_labels(all_labels); + + // make sure we have a trainer object for each of the label types. + for (unsigned long i = 0; i < distinct_labels.size(); ++i) + { + const label_type l = distinct_labels[i]; + const typename binary_function_table::const_iterator itr = trainers.find(l); + + if (itr == trainers.end() && default_trainer.is_empty()) + { + std::ostringstream sout; + sout << "In one_vs_all_trainer, no trainer registered for the " << l << " label."; + throw invalid_label(sout.str(), l); + } + } + + + // now do the training + parallel_for_helper helper(all_samples,all_labels,default_trainer,trainers,verbose,distinct_labels); + parallel_for(num_threads, 0, distinct_labels.size(), helper, 500); + + if (helper.error_message.size() != 0) + { + throw dlib::error("binary trainer threw while training one vs. all classifier. Error was: " + helper.error_message); + } + return trained_function_type(helper.dfs); + } + + private: + + typedef std::map binary_function_table; + struct parallel_for_helper + { + parallel_for_helper( + const std::vector& all_samples_, + const std::vector& all_labels_, + const any_trainer& default_trainer_, + const binary_function_table& trainers_, + const bool verbose_, + const std::vector& distinct_labels_ + ) : + all_samples(all_samples_), + all_labels(all_labels_), + default_trainer(default_trainer_), + trainers(trainers_), + verbose(verbose_), + distinct_labels(distinct_labels_) + {} + + void operator()(long i) const + { + try + { + std::vector labels; + + const label_type l = distinct_labels[i]; + + // setup one of the one vs all training sets + for (unsigned long k = 0; k < all_samples.size(); ++k) + { + if (all_labels[k] == l) + labels.push_back(+1); + else + labels.push_back(-1); + } + + + if (verbose) + { + auto_mutex lock(class_mutex); + std::cout << "Training classifier for " << l << " vs. all" << std::endl; + } + + any_trainer trainer; + // now train a binary classifier using the samples we selected + { auto_mutex lock(class_mutex); + const typename binary_function_table::const_iterator itr = trainers.find(l); + if (itr != trainers.end()) + trainer = itr->second; + else + trainer = default_trainer; + } + + any_decision_function binary_df = trainer.train(all_samples, labels); + + auto_mutex lock(class_mutex); + dfs[l] = binary_df; + } + catch (std::exception& e) + { + auto_mutex lock(class_mutex); + error_message = e.what(); + } + } + + mutable typename trained_function_type::binary_function_table dfs; + mutex class_mutex; + mutable std::string error_message; + + const std::vector& all_samples; + const std::vector& all_labels; + const any_trainer& default_trainer; + const binary_function_table& trainers; + const bool verbose; + const std::vector& distinct_labels; + }; + + any_trainer default_trainer; + + binary_function_table trainers; + + bool verbose; + unsigned long num_threads; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ONE_VS_ALL_TRAiNER_Hh_ + + diff --git a/ml/dlib/dlib/svm/one_vs_all_trainer_abstract.h b/ml/dlib/dlib/svm/one_vs_all_trainer_abstract.h new file mode 100644 index 000000000..fb719a7e4 --- /dev/null +++ b/ml/dlib/dlib/svm/one_vs_all_trainer_abstract.h @@ -0,0 +1,163 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ONE_VS_ALL_TRAiNER_ABSTRACT_Hh_ +#ifdef DLIB_ONE_VS_ALL_TRAiNER_ABSTRACT_Hh_ + + +#include "one_vs_all_decision_function_abstract.h" +#include + +#include "../any/any_trainer_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename any_trainer, + typename label_type_ = double + > + class one_vs_all_trainer + { + /*! + REQUIREMENTS ON any_trainer + must be an instantiation of the dlib::any_trainer template. + + REQUIREMENTS ON label_type_ + label_type_ must be default constructable, copyable, and comparable using + operator < and ==. It must also be possible to write it to an std::ostream + using operator<<. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for turning a bunch of binary classifiers into a + multiclass classifier. It does this by training the binary classifiers + in a one vs. all fashion. That is, if you have N possible classes then + it trains N binary classifiers which are then used to vote on the identity + of a test sample. + + This object works with any kind of binary classification trainer object + capable of being assigned to an any_trainer object. (e.g. the svm_nu_trainer) + !*/ + + public: + + + typedef label_type_ label_type; + + typedef typename any_trainer::sample_type sample_type; + typedef typename any_trainer::scalar_type scalar_type; + typedef typename any_trainer::mem_manager_type mem_manager_type; + + typedef one_vs_all_decision_function trained_function_type; + + one_vs_all_trainer ( + ); + /*! + ensures + - This object is properly initialized. + - This object will not be verbose unless be_verbose() is called. + - No binary trainers are associated with *this. I.e. you have to + call set_trainer() before calling train(). + - #get_num_threads() == 4 + !*/ + + void set_trainer ( + const any_trainer& trainer + ); + /*! + ensures + - sets the trainer used for all binary subproblems. Any previous + calls to set_trainer() are overridden by this function. Even the + more specific set_trainer(trainer, l) form. + !*/ + + void set_trainer ( + const any_trainer& trainer, + const label_type& l + ); + /*! + ensures + - Sets the trainer object used to create a binary classifier to + distinguish l labeled samples from all other samples. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + ensures + - #get_num_threads() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - returns the number of threads used during training. You should + usually set this equal to the number of processing cores on your + machine. + !*/ + + struct invalid_label : public dlib::error + { + /*! + This is the exception thrown by the train() function below. + !*/ + label_type l; + }; + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const; + /*! + requires + - is_learning_problem(all_samples, all_labels) + ensures + - trains a bunch of binary classifiers in a one vs all fashion to solve the given + multiclass classification problem. + - returns a one_vs_all_decision_function F with the following properties: + - F contains all the learned binary classifiers and can be used to predict + the labels of new samples. + - if (new_x is a sample predicted to have a label of L) then + - F(new_x) == L + - F.get_labels() == select_all_distinct_labels(all_labels) + - F.number_of_classes() == select_all_distinct_labels(all_labels).size() + throws + - invalid_label + This exception is thrown if there are labels in all_labels which don't have + any corresponding trainer object. This will never happen if set_trainer(trainer) + has been called. However, if only the set_trainer(trainer,l) form has been + used then this exception is thrown if not all labels have been given a trainer. + + invalid_label::l will contain the label which is missing a trainer object. + Additionally, the exception will contain an informative error message available + via invalid_label::what(). + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ONE_VS_ALL_TRAiNER_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/svm/one_vs_one_decision_function.h b/ml/dlib/dlib/svm/one_vs_one_decision_function.h new file mode 100644 index 000000000..02a5fa51e --- /dev/null +++ b/ml/dlib/dlib/svm/one_vs_one_decision_function.h @@ -0,0 +1,291 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ONE_VS_ONE_DECISION_FUnCTION_Hh_ +#define DLIB_ONE_VS_ONE_DECISION_FUnCTION_Hh_ + +#include "one_vs_one_decision_function_abstract.h" + +#include "../serialize.h" +#include "../type_safe_union.h" +#include +#include +#include +#include +#include "../any.h" +#include "../unordered_pair.h" +#include "null_df.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename one_vs_one_trainer, + typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df, + typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df, + typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df, + typename DF10 = null_df + > + class one_vs_one_decision_function + { + public: + + typedef typename one_vs_one_trainer::label_type result_type; + typedef typename one_vs_one_trainer::sample_type sample_type; + typedef typename one_vs_one_trainer::scalar_type scalar_type; + typedef typename one_vs_one_trainer::mem_manager_type mem_manager_type; + + typedef std::map, any_decision_function > binary_function_table; + + one_vs_one_decision_function() :num_classes(0) {} + + explicit one_vs_one_decision_function( + const binary_function_table& dfs_ + ) : dfs(dfs_) + { +#ifdef ENABLE_ASSERTS + { + const std::vector > missing_pairs = find_missing_pairs(dfs_); + if (missing_pairs.size() != 0) + { + std::ostringstream sout; + for (unsigned long i = 0; i < missing_pairs.size(); ++i) + { + sout << "\t (" << missing_pairs[i].first << ", " << missing_pairs[i].second << ")\n"; + } + DLIB_ASSERT(missing_pairs.size() == 0, + "\t void one_vs_one_decision_function::one_vs_one_decision_function()" + << "\n\t The supplied set of binary decision functions is incomplete." + << "\n\t this: " << this + << "\n\t Classifiers are missing for the following label pairs: \n" << sout.str() + ); + } + } +#endif + + // figure out how many labels are covered by this set of binary decision functions + std::set labels; + for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i) + { + labels.insert(i->first.first); + labels.insert(i->first.second); + } + num_classes = labels.size(); + } + + const binary_function_table& get_binary_decision_functions ( + ) const + { + return dfs; + } + + const std::vector get_labels ( + ) const + { + std::set labels; + for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i) + { + labels.insert(i->first.first); + labels.insert(i->first.second); + } + return std::vector(labels.begin(), labels.end()); + } + + + template < + typename df1, typename df2, typename df3, typename df4, typename df5, + typename df6, typename df7, typename df8, typename df9, typename df10 + > + one_vs_one_decision_function ( + const one_vs_one_decision_function& item + ) : dfs(item.get_binary_decision_functions()), num_classes(item.number_of_classes()) {} + + unsigned long number_of_classes ( + ) const + { + return num_classes; + } + + result_type operator() ( + const sample_type& sample + ) const + { + DLIB_ASSERT(number_of_classes() != 0, + "\t void one_vs_one_decision_function::operator()" + << "\n\t You can't make predictions with an empty decision function." + << "\n\t this: " << this + ); + + std::map votes; + + // run all the classifiers over the sample + for(typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i) + { + const scalar_type score = i->second(sample); + + if (score > 0) + votes[i->first.first] += 1; + else + votes[i->first.second] += 1; + } + + // now figure out who had the most votes + result_type best_label = result_type(); + int best_votes = 0; + for (typename std::map::iterator i = votes.begin(); i != votes.end(); ++i) + { + if (i->second > best_votes) + { + best_votes = i->second; + best_label = i->first; + } + } + + return best_label; + } + + + + private: + binary_function_table dfs; + unsigned long num_classes; + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename DF1, typename DF2, typename DF3, + typename DF4, typename DF5, typename DF6, + typename DF7, typename DF8, typename DF9, + typename DF10 + > + void serialize( + const one_vs_one_decision_function& item, + std::ostream& out + ) + { + try + { + type_safe_union temp; + typedef typename T::label_type result_type; + typedef typename T::sample_type sample_type; + typedef typename T::scalar_type scalar_type; + typedef std::map, any_decision_function > binary_function_table; + + const unsigned long version = 1; + serialize(version, out); + + const unsigned long size = item.get_binary_decision_functions().size(); + serialize(size, out); + + for(typename binary_function_table::const_iterator i = item.get_binary_decision_functions().begin(); + i != item.get_binary_decision_functions().end(); ++i) + { + serialize(i->first, out); + + if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else if (i->second.template contains()) temp.template get() = any_cast(i->second); + else throw serialization_error("Can't serialize one_vs_one_decision_function. Not all decision functions defined."); + + serialize(temp,out); + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing an object of type one_vs_one_decision_function"); + } + + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + struct copy_to_df_helper + { + copy_to_df_helper(any_decision_function& target_) : target(target_) {} + + any_decision_function& target; + + template + void operator() ( + const T& item + ) const + { + target = item; + } + }; + } + + template < + typename T, + typename DF1, typename DF2, typename DF3, + typename DF4, typename DF5, typename DF6, + typename DF7, typename DF8, typename DF9, + typename DF10 + > + void deserialize( + one_vs_one_decision_function& item, + std::istream& in + ) + { + try + { + type_safe_union temp; + typedef typename T::label_type result_type; + typedef typename T::sample_type sample_type; + typedef typename T::scalar_type scalar_type; + typedef impl::copy_to_df_helper copy_to; + + unsigned long version; + deserialize(version, in); + + if (version != 1) + throw serialization_error("Can't deserialize one_vs_one_decision_function. Wrong version."); + + unsigned long size; + deserialize(size, in); + + typedef std::map, any_decision_function > binary_function_table; + binary_function_table dfs; + + unordered_pair p; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(p, in); + deserialize(temp, in); + if (temp.template contains()) + throw serialization_error("A sub decision function of unknown type was encountered."); + + temp.apply_to_contents(copy_to(dfs[p])); + } + + item = one_vs_one_decision_function(dfs); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing an object of type one_vs_one_decision_function"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ONE_VS_ONE_DECISION_FUnCTION_Hh_ + + diff --git a/ml/dlib/dlib/svm/one_vs_one_decision_function_abstract.h b/ml/dlib/dlib/svm/one_vs_one_decision_function_abstract.h new file mode 100644 index 000000000..cf22e0ba7 --- /dev/null +++ b/ml/dlib/dlib/svm/one_vs_one_decision_function_abstract.h @@ -0,0 +1,213 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ONE_VS_ONE_DECISION_FUnCTION_ABSTRACT_Hh_ +#ifdef DLIB_ONE_VS_ONE_DECISION_FUnCTION_ABSTRACT_Hh_ + + +#include "../serialize.h" +#include +#include "../any/any_decision_function_abstract.h" +#include "../unordered_pair.h" +#include "one_vs_one_trainer_abstract.h" +#include "null_df.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename one_vs_one_trainer, + typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df, + typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df, + typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df, + typename DF10 = null_df + > + class one_vs_one_decision_function + { + /*! + REQUIREMENTS ON one_vs_one_trainer + This should be an instantiation of the one_vs_one_trainer template. + It is used to infer which types are used for various things, such as + representing labels. + + REQUIREMENTS ON DF* + These types can either be left at their default values or set + to any kind of decision function object capable of being + stored in an any_decision_function + object. These types should also be serializable. + + WHAT THIS OBJECT REPRESENTS + This object represents a multiclass classifier built out + of a set of binary classifiers. Each binary classifier + is used to vote for the correct multiclass label using a + one vs. one strategy. Therefore, if you have N classes then + there will be N*(N-1)/2 binary classifiers inside this object. + + Note that the DF* template arguments are only used if you want + to serialize and deserialize one_vs_one_decision_function objects. + Specifically, all the types of binary decision function contained + within a one_vs_one_decision_function must be listed in the + template arguments if serialization and deserialization is to + be used. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call the const members of this object from multiple + threads so long as all the decision functions contained in this object + are also threadsafe. This is because the const members are purely + read-only operations. However, any operation that modifies a + one_vs_one_decision_function is not threadsafe. + !*/ + public: + + typedef typename one_vs_one_trainer::label_type result_type; + typedef typename one_vs_one_trainer::sample_type sample_type; + typedef typename one_vs_one_trainer::scalar_type scalar_type; + typedef typename one_vs_one_trainer::mem_manager_type mem_manager_type; + + typedef std::map, any_decision_function > binary_function_table; + + one_vs_one_decision_function( + ); + /*! + ensures + - #number_of_classes() == 0 + - #get_binary_decision_functions().size() == 0 + - #get_labels().size() == 0 + !*/ + + explicit one_vs_one_decision_function( + const binary_function_table& decision_functions + ); + /*! + requires + - find_missing_pairs(decision_functions).size() == 0 + (i.e. all pairs of labels have an associated decision function) + ensures + - #get_binary_decision_functions() == decision_functions + - #get_labels() == a list of all the labels which appear in the + given set of decision functions + - #number_of_classes() == #get_labels().size() + !*/ + + template < + typename df1, typename df2, typename df3, typename df4, typename df5, + typename df6, typename df7, typename df8, typename df9, typename df10 + > + one_vs_one_decision_function ( + const one_vs_one_decision_function& item + ); + /*! + ensures + - #*this will be a copy of item + - #number_of_classes() == item.number_of_classes() + - #get_labels() == item.get_labels() + - #get_binary_decision_functions() == item.get_binary_decision_functions() + !*/ + + const binary_function_table& get_binary_decision_functions ( + ) const; + /*! + ensures + - returns the table of binary decision functions used by this + object. The correspondence between binary decision functions + and multiclass labels is the following: + - for each element i of get_binary_decision_functions() + - i->first == the label pair associated with binary decision + function i->second. + - if (decision function i->second outputs a value > 0) then + - i->second is indicating that a test sample should + receive a label of i->first.first + - else + - i->second is indicating that a test sample should + receive a label of i->first.second + !*/ + + const std::vector get_labels ( + ) const; + /*! + ensures + - returns a vector containing all the labels which can be + predicted by this object. + !*/ + + unsigned long number_of_classes ( + ) const; + /*! + ensures + - returns get_labels().size() + (i.e. returns the number of different labels/classes predicted by + this object) + !*/ + + result_type operator() ( + const sample_type& sample + ) const + /*! + requires + - number_of_classes() != 0 + ensures + - evaluates all the decision functions in get_binary_decision_functions() + and returns the label which received the most votes. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename DF1, typename DF2, typename DF3, + typename DF4, typename DF5, typename DF6, + typename DF7, typename DF8, typename DF9, + typename DF10 + > + void serialize( + const one_vs_one_decision_function& item, + std::ostream& out + ); + /*! + ensures + - writes the given item to the output stream out. + throws + - serialization_error. + This is thrown if there is a problem writing to the ostream or if item + contains a type of decision function not listed among the DF* template + arguments. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename DF1, typename DF2, typename DF3, + typename DF4, typename DF5, typename DF6, + typename DF7, typename DF8, typename DF9, + typename DF10 + > + void deserialize( + one_vs_one_decision_function& item, + std::istream& in + ); + /*! + ensures + - deserializes a one_vs_one_decision_function from in and stores it in item. + throws + - serialization_error. + This is thrown if there is a problem reading from the istream or if the + serialized data contains decision functions not listed among the DF* + template arguments. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ONE_VS_ONE_DECISION_FUnCTION_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/svm/one_vs_one_trainer.h b/ml/dlib/dlib/svm/one_vs_one_trainer.h new file mode 100644 index 000000000..2beec8f67 --- /dev/null +++ b/ml/dlib/dlib/svm/one_vs_one_trainer.h @@ -0,0 +1,249 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ONE_VS_ONE_TRAiNER_Hh_ +#define DLIB_ONE_VS_ONE_TRAiNER_Hh_ + +#include "one_vs_one_trainer_abstract.h" + +#include "one_vs_one_decision_function.h" +#include + +#include "../unordered_pair.h" +#include "multiclass_tools.h" + +#include +#include + +#include "../any.h" +#include +#include +#include "../threads.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename any_trainer, + typename label_type_ = double + > + class one_vs_one_trainer + { + public: + typedef label_type_ label_type; + + typedef typename any_trainer::sample_type sample_type; + typedef typename any_trainer::scalar_type scalar_type; + typedef typename any_trainer::mem_manager_type mem_manager_type; + + typedef one_vs_one_decision_function trained_function_type; + + one_vs_one_trainer ( + ) : + verbose(false), + num_threads(4) + {} + + void set_trainer ( + const any_trainer& trainer + ) + { + default_trainer = trainer; + trainers.clear(); + } + + void set_trainer ( + const any_trainer& trainer, + const label_type& l1, + const label_type& l2 + ) + { + trainers[make_unordered_pair(l1,l2)] = trainer; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + struct invalid_label : public dlib::error + { + invalid_label(const std::string& msg, const label_type& l1_, const label_type& l2_ + ) : dlib::error(msg), l1(l1_), l2(l2_) {}; + + virtual ~invalid_label( + ) throw() {} + + label_type l1, l2; + }; + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(all_samples,all_labels), + "\t trained_function_type one_vs_one_trainer::train(all_samples,all_labels)" + << "\n\t invalid inputs were given to this function" + << "\n\t all_samples.size(): " << all_samples.size() + << "\n\t all_labels.size(): " << all_labels.size() + ); + + const std::vector distinct_labels = select_all_distinct_labels(all_labels); + + + // fill pairs with all the pairs of labels. + std::vector > pairs; + for (unsigned long i = 0; i < distinct_labels.size(); ++i) + { + for (unsigned long j = i+1; j < distinct_labels.size(); ++j) + { + pairs.push_back(unordered_pair(distinct_labels[i], distinct_labels[j])); + + // make sure we have a trainer for this pair + const typename binary_function_table::const_iterator itr = trainers.find(pairs.back()); + if (itr == trainers.end() && default_trainer.is_empty()) + { + std::ostringstream sout; + sout << "In one_vs_one_trainer, no trainer registered for the (" + << pairs.back().first << ", " << pairs.back().second << ") label pair."; + throw invalid_label(sout.str(), pairs.back().first, pairs.back().second); + } + } + } + + + + // Now train on all the label pairs. + parallel_for_helper helper(all_samples,all_labels,default_trainer,trainers,verbose,pairs); + parallel_for(num_threads, 0, pairs.size(), helper, 500); + + if (helper.error_message.size() != 0) + { + throw dlib::error("binary trainer threw while training one vs. one classifier. Error was: " + helper.error_message); + } + return trained_function_type(helper.dfs); + } + + private: + + typedef std::map, any_trainer> binary_function_table; + + struct parallel_for_helper + { + parallel_for_helper( + const std::vector& all_samples_, + const std::vector& all_labels_, + const any_trainer& default_trainer_, + const binary_function_table& trainers_, + const bool verbose_, + const std::vector >& pairs_ + ) : + all_samples(all_samples_), + all_labels(all_labels_), + default_trainer(default_trainer_), + trainers(trainers_), + verbose(verbose_), + pairs(pairs_) + {} + + void operator()(long i) const + { + try + { + std::vector samples; + std::vector labels; + + const unordered_pair p = pairs[i]; + + // pick out the samples corresponding to these two classes + for (unsigned long k = 0; k < all_samples.size(); ++k) + { + if (all_labels[k] == p.first) + { + samples.push_back(all_samples[k]); + labels.push_back(+1); + } + else if (all_labels[k] == p.second) + { + samples.push_back(all_samples[k]); + labels.push_back(-1); + } + } + + if (verbose) + { + auto_mutex lock(class_mutex); + std::cout << "Training classifier for " << p.first << " vs. " << p.second << std::endl; + } + + any_trainer trainer; + // now train a binary classifier using the samples we selected + { auto_mutex lock(class_mutex); + const typename binary_function_table::const_iterator itr = trainers.find(p); + if (itr != trainers.end()) + trainer = itr->second; + else + trainer = default_trainer; + } + + any_decision_function binary_df = trainer.train(samples, labels); + + auto_mutex lock(class_mutex); + dfs[p] = binary_df; + } + catch (std::exception& e) + { + auto_mutex lock(class_mutex); + error_message = e.what(); + } + } + + mutable typename trained_function_type::binary_function_table dfs; + mutex class_mutex; + mutable std::string error_message; + + const std::vector& all_samples; + const std::vector& all_labels; + const any_trainer& default_trainer; + const binary_function_table& trainers; + const bool verbose; + const std::vector >& pairs; + }; + + + any_trainer default_trainer; + binary_function_table trainers; + bool verbose; + unsigned long num_threads; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ONE_VS_ONE_TRAiNER_Hh_ + diff --git a/ml/dlib/dlib/svm/one_vs_one_trainer_abstract.h b/ml/dlib/dlib/svm/one_vs_one_trainer_abstract.h new file mode 100644 index 000000000..42ba35815 --- /dev/null +++ b/ml/dlib/dlib/svm/one_vs_one_trainer_abstract.h @@ -0,0 +1,166 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ONE_VS_ONE_TRAiNER_ABSTRACT_Hh_ +#ifdef DLIB_ONE_VS_ONE_TRAiNER_ABSTRACT_Hh_ + + +#include "one_vs_one_decision_function_abstract.h" +#include + +#include "../any/any_trainer_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename any_trainer, + typename label_type_ = double + > + class one_vs_one_trainer + { + /*! + REQUIREMENTS ON any_trainer + must be an instantiation of the dlib::any_trainer template. + + REQUIREMENTS ON label_type_ + label_type_ must be default constructable, copyable, and comparable using + operator < and ==. It must also be possible to write it to an std::ostream + using operator<<. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for turning a bunch of binary classifiers + into a multiclass classifier. It does this by training the binary + classifiers in a one vs. one fashion. That is, if you have N possible + classes then it trains N*(N-1)/2 binary classifiers which are then used + to vote on the identity of a test sample. + + This object works with any kind of binary classification trainer object + capable of being assigned to an any_trainer object. (e.g. the svm_nu_trainer) + !*/ + + public: + + + typedef label_type_ label_type; + + typedef typename any_trainer::sample_type sample_type; + typedef typename any_trainer::scalar_type scalar_type; + typedef typename any_trainer::mem_manager_type mem_manager_type; + + typedef one_vs_one_decision_function trained_function_type; + + one_vs_one_trainer ( + ); + /*! + ensures + - This object is properly initialized + - This object will not be verbose unless be_verbose() is called. + - No binary trainers are associated with *this. I.e. you have to + call set_trainer() before calling train(). + - #get_num_threads() == 4 + !*/ + + void set_trainer ( + const any_trainer& trainer + ); + /*! + ensures + - sets the trainer used for all pairs of training. Any previous + calls to set_trainer() are overridden by this function. Even the + more specific set_trainer(trainer, l1, l2) form. + !*/ + + void set_trainer ( + const any_trainer& trainer, + const label_type& l1, + const label_type& l2 + ); + /*! + requires + - l1 != l2 + ensures + - Sets the trainer object used to create a binary classifier to + distinguish l1 labeled samples from l2 labeled samples. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + ensures + - #get_num_threads() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - returns the number of threads used during training. You should + usually set this equal to the number of processing cores on your + machine. + !*/ + + struct invalid_label : public dlib::error + { + /*! + This is the exception thrown by the train() function below. + !*/ + label_type l1, l2; + }; + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const; + /*! + requires + - is_learning_problem(all_samples, all_labels) + ensures + - trains a bunch of binary classifiers in a one vs one fashion to solve the given + multiclass classification problem. + - returns a one_vs_one_decision_function F with the following properties: + - F contains all the learned binary classifiers and can be used to predict + the labels of new samples. + - if (new_x is a sample predicted to have a label of L) then + - F(new_x) == L + - F.get_labels() == select_all_distinct_labels(all_labels) + - F.number_of_classes() == select_all_distinct_labels(all_labels).size() + throws + - invalid_label + This exception is thrown if there are labels in all_labels which don't have + any corresponding trainer object. This will never happen if set_trainer(trainer) + has been called. However, if only the set_trainer(trainer,l1,l2) form has been + used then this exception is thrown if not all necessary label pairs have been + given a trainer. + + invalid_label::l1 and invalid_label::l2 will contain the label pair which is + missing a trainer object. Additionally, the exception will contain an + informative error message available via invalid_label::what(). + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ONE_VS_ONE_TRAiNER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/pegasos.h b/ml/dlib/dlib/svm/pegasos.h new file mode 100644 index 000000000..c28093fe0 --- /dev/null +++ b/ml/dlib/dlib/svm/pegasos.h @@ -0,0 +1,710 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PEGASoS_ +#define DLIB_PEGASoS_ + +#include "pegasos_abstract.h" +#include +#include "../algs.h" +#include "function.h" +#include "kernel.h" +#include "kcentroid.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_pegasos + { + typedef kcentroid > kc_type; + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + template + struct rebind { + typedef svm_pegasos other; + }; + + svm_pegasos ( + ) : + max_sv(40), + lambda_c1(0.0001), + lambda_c2(0.0001), + tau(0.01), + tolerance(0.01), + train_count(0), + w(offset_kernel(kernel,tau),tolerance, max_sv, false) + { + } + + svm_pegasos ( + const kernel_type& kernel_, + const scalar_type& lambda_, + const scalar_type& tolerance_, + unsigned long max_num_sv + ) : + max_sv(max_num_sv), + kernel(kernel_), + lambda_c1(lambda_), + lambda_c2(lambda_), + tau(0.01), + tolerance(tolerance_), + train_count(0), + w(offset_kernel(kernel,tau),tolerance, max_sv, false) + { + // make sure requires clause is not broken + DLIB_ASSERT(lambda_ > 0 && tolerance > 0 && max_num_sv > 0, + "\tsvm_pegasos::svm_pegasos(kernel,lambda,tolerance)" + << "\n\t invalid inputs were given to this function" + << "\n\t lambda_: " << lambda_ + << "\n\t max_num_sv: " << max_num_sv + ); + } + + void clear ( + ) + { + // reset the w vector back to its initial state + w = kc_type(offset_kernel(kernel,tau),tolerance, max_sv, false); + train_count = 0; + } + + void set_kernel ( + kernel_type k + ) + { + kernel = k; + clear(); + } + + void set_max_num_sv ( + unsigned long max_num_sv + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_num_sv > 0, + "\tvoid svm_pegasos::set_max_num_sv(max_num_sv)" + << "\n\t invalid inputs were given to this function" + << "\n\t max_num_sv: " << max_num_sv + ); + max_sv = max_num_sv; + clear(); + } + + unsigned long get_max_num_sv ( + ) const + { + return max_sv; + } + + void set_tolerance ( + double tol + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < tol, + "\tvoid svm_pegasos::set_tolerance(tol)" + << "\n\t invalid inputs were given to this function" + << "\n\t tol: " << tol + ); + tolerance = tol; + clear(); + } + + void set_lambda ( + scalar_type lambda_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < lambda_, + "\tvoid svm_pegasos::set_lambda(lambda_)" + << "\n\t invalid inputs were given to this function" + << "\n\t lambda_: " << lambda_ + ); + lambda_c1 = lambda_; + lambda_c2 = lambda_; + + max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2)); + clear(); + } + + void set_lambda_class1 ( + scalar_type lambda_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < lambda_, + "\tvoid svm_pegasos::set_lambda_class1(lambda_)" + << "\n\t invalid inputs were given to this function" + << "\n\t lambda_: " << lambda_ + ); + lambda_c1 = lambda_; + max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2)); + clear(); + } + + void set_lambda_class2 ( + scalar_type lambda_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < lambda_, + "\tvoid svm_pegasos::set_lambda_class2(lambda_)" + << "\n\t invalid inputs were given to this function" + << "\n\t lambda_: " << lambda_ + ); + lambda_c2 = lambda_; + max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2)); + clear(); + } + + const scalar_type get_lambda_class1 ( + ) const + { + return lambda_c1; + } + + const scalar_type get_lambda_class2 ( + ) const + { + return lambda_c2; + } + + const scalar_type get_tolerance ( + ) const + { + return tolerance; + } + + const kernel_type get_kernel ( + ) const + { + return kernel; + } + + unsigned long get_train_count ( + ) const + { + return static_cast(train_count); + } + + scalar_type train ( + const sample_type& x, + const scalar_type& y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(y == -1 || y == 1, + "\tscalar_type svm_pegasos::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t y: " << y + ); + + const double lambda = (y==+1)? lambda_c1 : lambda_c2; + + ++train_count; + const scalar_type learning_rate = 1/(lambda*train_count); + + // if this sample point is within the margin of the current hyperplane + if (y*w.inner_product(x) < 1) + { + + // compute: w = (1-learning_rate*lambda)*w + y*learning_rate*x + w.train(x, 1 - learning_rate*lambda, y*learning_rate); + + scalar_type wnorm = std::sqrt(w.squared_norm()); + scalar_type temp = max_wnorm/wnorm; + if (temp < 1) + w.scale_by(temp); + } + else + { + w.scale_by(1 - learning_rate*lambda); + } + + // return the current learning rate + return 1/(std::min(lambda_c1,lambda_c2)*train_count); + } + + scalar_type operator() ( + const sample_type& x + ) const + { + return w.inner_product(x); + } + + const decision_function get_decision_function ( + ) const + { + distance_function > df = w.get_distance_function(); + return decision_function(df.get_alpha(), -tau*sum(df.get_alpha()), kernel, df.get_basis_vectors()); + } + + void swap ( + svm_pegasos& item + ) + { + exchange(max_sv, item.max_sv); + exchange(kernel, item.kernel); + exchange(lambda_c1, item.lambda_c1); + exchange(lambda_c2, item.lambda_c2); + exchange(max_wnorm, item.max_wnorm); + exchange(tau, item.tau); + exchange(tolerance, item.tolerance); + exchange(train_count, item.train_count); + exchange(w, item.w); + } + + friend void serialize(const svm_pegasos& item, std::ostream& out) + { + serialize(item.max_sv, out); + serialize(item.kernel, out); + serialize(item.lambda_c1, out); + serialize(item.lambda_c2, out); + serialize(item.max_wnorm, out); + serialize(item.tau, out); + serialize(item.tolerance, out); + serialize(item.train_count, out); + serialize(item.w, out); + } + + friend void deserialize(svm_pegasos& item, std::istream& in) + { + deserialize(item.max_sv, in); + deserialize(item.kernel, in); + deserialize(item.lambda_c1, in); + deserialize(item.lambda_c2, in); + deserialize(item.max_wnorm, in); + deserialize(item.tau, in); + deserialize(item.tolerance, in); + deserialize(item.train_count, in); + deserialize(item.w, in); + } + + private: + + unsigned long max_sv; + kernel_type kernel; + scalar_type lambda_c1; + scalar_type lambda_c2; + scalar_type max_wnorm; + scalar_type tau; + scalar_type tolerance; + scalar_type train_count; + kc_type w; + + }; // end of class svm_pegasos + + template < + typename K + > + void swap ( + svm_pegasos& a, + svm_pegasos& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + void replicate_settings ( + const svm_pegasos& source, + svm_pegasos& dest + ) + { + dest.set_tolerance(source.get_tolerance()); + dest.set_lambda_class1(source.get_lambda_class1()); + dest.set_lambda_class2(source.get_lambda_class2()); + dest.set_max_num_sv(source.get_max_num_sv()); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class batch_trainer + { + + // ------------------------------------------------------------------------------------ + + template < + typename K, + typename sample_vector_type + > + class caching_kernel + { + public: + typedef typename K::scalar_type scalar_type; + typedef long sample_type; + //typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + caching_kernel () {} + + caching_kernel ( + const K& kern, + const sample_vector_type& samps, + long cache_size_ + ) : real_kernel(kern), samples(&samps), counter(0) + { + cache_size = std::min(cache_size_, samps.size()); + + cache.reset(new cache_type); + cache->frequency_of_use.resize(samps.size()); + for (long i = 0; i < samps.size(); ++i) + cache->frequency_of_use[i] = std::make_pair(0, i); + + // Set the cache build/rebuild threshold so that we have to have + // as many cache misses as there are entries in the cache before + // we build/rebuild. + counter_threshold = samps.size()*cache_size; + cache->sample_location.assign(samples->size(), -1); + } + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + // rebuild the cache every so often + if (counter > counter_threshold ) + { + build_cache(); + } + + const long a_loc = cache->sample_location[a]; + const long b_loc = cache->sample_location[b]; + + cache->frequency_of_use[a].first += 1; + cache->frequency_of_use[b].first += 1; + + if (a_loc != -1) + { + return cache->kernel(a_loc, b); + } + else if (b_loc != -1) + { + return cache->kernel(b_loc, a); + } + else + { + ++counter; + return real_kernel((*samples)(a), (*samples)(b)); + } + } + + bool operator== ( + const caching_kernel& item + ) const + { + return item.real_kernel == real_kernel && + item.samples == samples; + } + + private: + K real_kernel; + + void build_cache ( + ) const + { + std::sort(cache->frequency_of_use.rbegin(), cache->frequency_of_use.rend()); + counter = 0; + + + cache->kernel.set_size(cache_size, samples->size()); + cache->sample_location.assign(samples->size(), -1); + + // loop over all the samples in the cache + for (long i = 0; i < cache_size; ++i) + { + const long cur = cache->frequency_of_use[i].second; + cache->sample_location[cur] = i; + + // now populate all possible kernel products with the current sample + for (long j = 0; j < samples->size(); ++j) + { + cache->kernel(i, j) = real_kernel((*samples)(cur), (*samples)(j)); + } + + } + + // reset the frequency of use metrics + for (long i = 0; i < samples->size(); ++i) + cache->frequency_of_use[i] = std::make_pair(0, i); + } + + + struct cache_type + { + matrix kernel; + + std::vector sample_location; // where in the cache a sample is. -1 means not in cache + std::vector > frequency_of_use; + }; + + const sample_vector_type* samples = 0; + + std::shared_ptr cache; + mutable unsigned long counter = 0; + unsigned long counter_threshold = 0; + long cache_size = 0; + }; + + // ------------------------------------------------------------------------------------ + + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + + batch_trainer ( + ) : + min_learning_rate(0.1), + use_cache(false), + cache_size(100) + { + } + + batch_trainer ( + const trainer_type& trainer_, + const scalar_type min_learning_rate_, + bool verbose_, + bool use_cache_, + long cache_size_ = 100 + ) : + trainer(trainer_), + min_learning_rate(min_learning_rate_), + verbose(verbose_), + use_cache(use_cache_), + cache_size(cache_size_) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < min_learning_rate_ && + cache_size_ > 0, + "\tbatch_trainer::batch_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t min_learning_rate_: " << min_learning_rate_ + << "\n\t cache_size_: " << cache_size_ + ); + + trainer.clear(); + } + + const scalar_type get_min_learning_rate ( + ) const + { + return min_learning_rate; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + if (use_cache) + return do_train_cached(mat(x), mat(y)); + else + return do_train(mat(x), mat(y)); + } + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + + dlib::rand rnd; + + trainer_type my_trainer(trainer); + + scalar_type cur_learning_rate = min_learning_rate + 10; + unsigned long count = 0; + + while (cur_learning_rate > min_learning_rate) + { + const long i = rnd.get_random_32bit_number()%x.size(); + // keep feeding the trainer data until its learning rate goes below our threshold + cur_learning_rate = my_trainer.train(x(i), y(i)); + + if (verbose) + { + if ( (count&0x7FF) == 0) + { + std::cout << "\rbatch_trainer(): Percent complete: " + << 100*min_learning_rate/cur_learning_rate << " " << std::flush; + } + ++count; + } + } + + if (verbose) + { + decision_function df = my_trainer.get_decision_function(); + std::cout << "\rbatch_trainer(): Percent complete: 100 " << std::endl; + std::cout << " Num sv: " << df.basis_vectors.size() << std::endl; + std::cout << " bias: " << df.b << std::endl; + return df; + } + else + { + return my_trainer.get_decision_function(); + } + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train_cached ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + + dlib::rand rnd; + + // make a caching kernel + typedef caching_kernel ckernel_type; + ckernel_type ck(trainer.get_kernel(), x, cache_size); + + // now rebind the trainer to use the caching kernel + typedef typename trainer_type::template rebind::other rebound_trainer_type; + rebound_trainer_type my_trainer; + my_trainer.set_kernel(ck); + replicate_settings(trainer, my_trainer); + + scalar_type cur_learning_rate = min_learning_rate + 10; + unsigned long count = 0; + + while (cur_learning_rate > min_learning_rate) + { + const long i = rnd.get_random_32bit_number()%x.size(); + // keep feeding the trainer data until its learning rate goes below our threshold + cur_learning_rate = my_trainer.train(i, y(i)); + + if (verbose) + { + if ( (count&0x7FF) == 0) + { + std::cout << "\rbatch_trainer(): Percent complete: " + << 100*min_learning_rate/cur_learning_rate << " " << std::flush; + } + ++count; + } + } + + if (verbose) + { + decision_function cached_df; + cached_df = my_trainer.get_decision_function(); + + std::cout << "\rbatch_trainer(): Percent complete: 100 " << std::endl; + std::cout << " Num sv: " << cached_df.basis_vectors.size() << std::endl; + std::cout << " bias: " << cached_df.b << std::endl; + + return decision_function ( + cached_df.alpha, + cached_df.b, + trainer.get_kernel(), + rowm(x, cached_df.basis_vectors) + ); + } + else + { + decision_function cached_df; + cached_df = my_trainer.get_decision_function(); + + return decision_function ( + cached_df.alpha, + cached_df.b, + trainer.get_kernel(), + rowm(x, cached_df.basis_vectors) + ); + } + } + + trainer_type trainer; + scalar_type min_learning_rate; + bool verbose; + bool use_cache; + long cache_size; + + }; // end of class batch_trainer + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const batch_trainer batch ( + const trainer_type& trainer, + const typename trainer_type::scalar_type min_learning_rate = 0.1 + ) { return batch_trainer(trainer, min_learning_rate, false, false); } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const batch_trainer verbose_batch ( + const trainer_type& trainer, + const typename trainer_type::scalar_type min_learning_rate = 0.1 + ) { return batch_trainer(trainer, min_learning_rate, true, false); } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const batch_trainer batch_cached ( + const trainer_type& trainer, + const typename trainer_type::scalar_type min_learning_rate = 0.1, + long cache_size = 100 + ) { return batch_trainer(trainer, min_learning_rate, false, true, cache_size); } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const batch_trainer verbose_batch_cached ( + const trainer_type& trainer, + const typename trainer_type::scalar_type min_learning_rate = 0.1, + long cache_size = 100 + ) { return batch_trainer(trainer, min_learning_rate, true, true, cache_size); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_PEGASoS_ + diff --git a/ml/dlib/dlib/svm/pegasos_abstract.h b/ml/dlib/dlib/svm/pegasos_abstract.h new file mode 100644 index 000000000..008b1cb94 --- /dev/null +++ b/ml/dlib/dlib/svm/pegasos_abstract.h @@ -0,0 +1,514 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_PEGASoS_ABSTRACT_ +#ifdef DLIB_PEGASoS_ABSTRACT_ + +#include +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename kern_type + > + class svm_pegasos + { + /*! + REQUIREMENTS ON kern_type + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object implements an online algorithm for training a support + vector machine for solving binary classification problems. + + The implementation of the Pegasos algorithm used by this object is based + on the following excellent paper: + Pegasos: Primal estimated sub-gradient solver for SVM (2007) + by Shai Shalev-Shwartz, Yoram Singer, Nathan Srebro + In ICML + + This SVM training algorithm has two interesting properties. First, the + pegasos algorithm itself converges to the solution in an amount of time + unrelated to the size of the training set (in addition to being quite fast + to begin with). This makes it an appropriate algorithm for learning from + very large datasets. Second, this object uses the dlib::kcentroid object + to maintain a sparse approximation of the learned decision function. + This means that the number of support vectors in the resulting decision + function is also unrelated to the size of the dataset (in normal SVM + training algorithms, the number of support vectors grows approximately + linearly with the size of the training set). + !*/ + + public: + typedef kern_type kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + template + struct rebind { + typedef svm_pegasos other; + }; + + svm_pegasos ( + ); + /*! + ensures + - this object is properly initialized + - #get_lambda_class1() == 0.0001 + - #get_lambda_class2() == 0.0001 + - #get_tolerance() == 0.01 + - #get_train_count() == 0 + - #get_max_num_sv() == 40 + !*/ + + svm_pegasos ( + const kernel_type& kernel_, + const scalar_type& lambda_, + const scalar_type& tolerance_, + unsigned long max_num_sv + ); + /*! + requires + - lambda_ > 0 + - tolerance_ > 0 + - max_num_sv > 0 + ensures + - this object is properly initialized + - #get_lambda_class1() == lambda_ + - #get_lambda_class2() == lambda_ + - #get_tolerance() == tolerance_ + - #get_kernel() == kernel_ + - #get_train_count() == 0 + - #get_max_num_sv() == max_num_sv + !*/ + + void clear ( + ); + /*! + ensures + - #get_train_count() == 0 + - clears out any memory of previous calls to train() + - doesn't change any of the algorithm parameters. I.e. + - #get_lambda_class1() == get_lambda_class1() + - #get_lambda_class2() == get_lambda_class2() + - #get_tolerance() == get_tolerance() + - #get_kernel() == get_kernel() + - #get_max_num_sv() == get_max_num_sv() + !*/ + + const scalar_type get_lambda_class1 ( + ) const; + /*! + ensures + - returns the SVM regularization term for the +1 class. It is the + parameter that determines the trade off between trying to fit the + +1 training data exactly or allowing more errors but hopefully + improving the generalization ability of the resulting classifier. + Smaller values encourage exact fitting while larger values may + encourage better generalization. It is also worth noting that the + number of iterations it takes for this algorithm to converge is + proportional to 1/lambda. So smaller values of this term cause + the running time of this algorithm to increase. For more + information you should consult the paper referenced above. + !*/ + + const scalar_type get_lambda_class2 ( + ) const; + /*! + ensures + - returns the SVM regularization term for the -1 class. It has + the same properties as the get_lambda_class1() parameter except that + it applies to the -1 class. + !*/ + + const scalar_type get_tolerance ( + ) const; + /*! + ensures + - returns the tolerance used by the internal kcentroid object to + represent the learned decision function. Smaller values of this + tolerance will result in a more accurate representation of the + decision function but will use more support vectors (up to + a max of get_max_num_sv()). + !*/ + + unsigned long get_max_num_sv ( + ) const; + /*! + ensures + - returns the maximum number of support vectors this object is + allowed to use. + !*/ + + const kernel_type get_kernel ( + ) const; + /*! + ensures + - returns the kernel used by this object + !*/ + + void set_kernel ( + kernel_type k + ); + /*! + ensures + - #get_kernel() == k + - #get_train_count() == 0 + (i.e. clears any memory of previous training) + !*/ + + void set_tolerance ( + double tol + ); + /*! + requires + - tol > 0 + ensures + - #get_tolerance() == tol + - #get_train_count() == 0 + (i.e. clears any memory of previous training) + !*/ + + void set_max_num_sv ( + unsigned long max_num_sv + ); + /*! + requires + - max_num_sv > 0 + ensures + - #get_max_num_sv() == max_num_sv + - #get_train_count() == 0 + (i.e. clears any memory of previous training) + !*/ + + void set_lambda ( + scalar_type lambda_ + ); + /*! + requires + - lambda_ > 0 + ensures + - #get_lambda_class1() == lambda_ + - #get_lambda_class2() == lambda_ + - #get_train_count() == 0 + (i.e. clears any memory of previous training) + !*/ + + void set_lambda_class1 ( + scalar_type lambda_ + ); + /*! + requires + - lambda_ > 0 + ensures + - #get_lambda_class1() == lambda_ + #get_train_count() == 0 + (i.e. clears any memory of previous training) + !*/ + + void set_lambda_class2 ( + scalar_type lambda_ + ); + /*! + requires + - lambda_ > 0 + ensures + - #get_lambda_class2() == lambda_ + #get_train_count() == 0 + (i.e. clears any memory of previous training) + !*/ + + unsigned long get_train_count ( + ) const; + /*! + ensures + - returns how many times this->train() has been called + since this object was constructed or last cleared. + !*/ + + scalar_type train ( + const sample_type& x, + const scalar_type& y + ); + /*! + requires + - y == 1 || y == -1 + ensures + - trains this svm using the given sample x and label y + - #get_train_count() == get_train_count() + 1 + - returns the current learning rate + (i.e. 1/(get_train_count()*min(get_lambda_class1(),get_lambda_class2())) ) + !*/ + + scalar_type operator() ( + const sample_type& x + ) const; + /*! + ensures + - classifies the given x sample using the decision function + this object has learned so far. + - if (x is a sample predicted have +1 label) then + - returns a number >= 0 + - else + - returns a number < 0 + !*/ + + const decision_function get_decision_function ( + ) const; + /*! + ensures + - returns a decision function F that represents the function learned + by this object so far. I.e. it is the case that: + - for all x: F(x) == (*this)(x) + !*/ + + void swap ( + svm_pegasos& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename kern_type + > + void swap( + svm_pegasos& a, + svm_pegasos& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename kern_type + > + void serialize ( + const svm_pegasos& item, + std::ostream& out + ); + /*! + provides serialization support for svm_pegasos objects + !*/ + + template < + typename kern_type + > + void deserialize ( + svm_pegasos& item, + std::istream& in + ); + /*! + provides serialization support for svm_pegasos objects + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + void replicate_settings ( + const svm_pegasos& source, + svm_pegasos& dest + ); + /*! + ensures + - copies all the parameters from the source trainer to the dest trainer. + - #dest.get_tolerance() == source.get_tolerance() + - #dest.get_lambda_class1() == source.get_lambda_class1() + - #dest.get_lambda_class2() == source.get_lambda_class2() + - #dest.get_max_num_sv() == source.get_max_num_sv() + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class batch_trainer + { + /*! + REQUIREMENTS ON trainer_type + - trainer_type == some kind of online trainer object (e.g. svm_pegasos) + replicate_settings() must also be defined for the type. + + WHAT THIS OBJECT REPRESENTS + This is a trainer object that is meant to wrap online trainer objects + that create decision_functions. It turns an online learning algorithm + such as svm_pegasos into a batch learning object. This allows you to + use objects like svm_pegasos with functions (e.g. cross_validate_trainer) + that expect batch mode training objects. + !*/ + + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + + batch_trainer ( + ); + /*! + ensures + - This object is in an uninitialized state. You must + construct a real one with the other constructor and assign it + to this instance before you use this object. + !*/ + + batch_trainer ( + const trainer_type& online_trainer, + const scalar_type min_learning_rate_, + bool verbose_, + bool use_cache_, + long cache_size_ = 100 + ); + /*! + requires + - min_learning_rate_ > 0 + - cache_size_ > 0 + ensures + - returns a batch trainer object that uses the given online_trainer object + to train a decision function. + - #get_min_learning_rate() == min_learning_rate_ + - if (verbose_ == true) then + - this object will output status messages to standard out while + training is under way. + - if (use_cache_ == true) then + - this object will cache up to cache_size_ columns of the kernel + matrix during the training process. + !*/ + + const scalar_type get_min_learning_rate ( + ) const; + /*! + ensures + - returns the min learning rate that the online trainer must reach + before this object considers training to be complete. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + ensures + - trains and returns a decision_function using the trainer that was + supplied to this object's constructor. + - training continues until the online training object indicates that + its learning rate has dropped below get_min_learning_rate(). + throws + - std::bad_alloc + - any exceptions thrown by the trainer_type object + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const batch_trainer batch ( + const trainer_type& trainer, + const typename trainer_type::scalar_type min_learning_rate = 0.1 + ) { return batch_trainer(trainer, min_learning_rate, false, false); } + /*! + requires + - min_learning_rate > 0 + - trainer_type == some kind of online trainer object that creates decision_function + objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type. + ensures + - returns a batch_trainer object that has been instantiated with the + given arguments. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const batch_trainer verbose_batch ( + const trainer_type& trainer, + const typename trainer_type::scalar_type min_learning_rate = 0.1 + ) { return batch_trainer(trainer, min_learning_rate, true, false); } + /*! + requires + - min_learning_rate > 0 + - trainer_type == some kind of online trainer object that creates decision_function + objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type. + ensures + - returns a batch_trainer object that has been instantiated with the + given arguments (and is verbose). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const batch_trainer batch_cached ( + const trainer_type& trainer, + const typename trainer_type::scalar_type min_learning_rate = 0.1, + long cache_size = 100 + ) { return batch_trainer(trainer, min_learning_rate, false, true, cache_size); } + /*! + requires + - min_learning_rate > 0 + - cache_size > 0 + - trainer_type == some kind of online trainer object that creates decision_function + objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type. + ensures + - returns a batch_trainer object that has been instantiated with the + given arguments (uses a kernel cache). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const batch_trainer verbose_batch_cached ( + const trainer_type& trainer, + const typename trainer_type::scalar_type min_learning_rate = 0.1, + long cache_size = 100 + ) { return batch_trainer(trainer, min_learning_rate, true, true, cache_size); } + /*! + requires + - min_learning_rate > 0 + - cache_size > 0 + - trainer_type == some kind of online trainer object that creates decision_function + objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type. + ensures + - returns a batch_trainer object that has been instantiated with the + given arguments (is verbose and uses a kernel cache). + !*/ + +// ---------------------------------------------------------------------------------------- + + +} + +#endif // DLIB_PEGASoS_ABSTRACT_ + + diff --git a/ml/dlib/dlib/svm/ranking_tools.h b/ml/dlib/dlib/svm/ranking_tools.h new file mode 100644 index 000000000..3c77b41ae --- /dev/null +++ b/ml/dlib/dlib/svm/ranking_tools.h @@ -0,0 +1,448 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RANKING_ToOLS_Hh_ +#define DLIB_RANKING_ToOLS_Hh_ + +#include "ranking_tools_abstract.h" + +#include "../algs.h" +#include "../matrix.h" +#include +#include +#include +#include "sparse_vector.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct ranking_pair + { + ranking_pair() {} + + ranking_pair( + const std::vector& r, + const std::vector& nr + ) : + relevant(r), nonrelevant(nr) + {} + + std::vector relevant; + std::vector nonrelevant; + }; + + template < + typename T + > + void serialize ( + const ranking_pair& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + serialize(item.relevant, out); + serialize(item.nonrelevant, out); + } + + + template < + typename T + > + void deserialize ( + ranking_pair& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw dlib::serialization_error("Wrong version found while deserializing dlib::ranking_pair"); + + deserialize(item.relevant, in); + deserialize(item.nonrelevant, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename disable_if,bool>::type is_ranking_problem ( + const std::vector >& samples + ) + { + if (samples.size() == 0) + return false; + + + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (samples[i].relevant.size() == 0) + return false; + if (samples[i].nonrelevant.size() == 0) + return false; + } + + return true; + } + + template < + typename T + > + typename enable_if,bool>::type is_ranking_problem ( + const std::vector >& samples + ) + { + if (samples.size() == 0) + return false; + + + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (samples[i].relevant.size() == 0) + return false; + if (samples[i].nonrelevant.size() == 0) + return false; + } + + // If these are dense vectors then they must all have the same dimensionality. + const long dims = max_index_plus_one(samples[0].relevant); + for (unsigned long i = 0; i < samples.size(); ++i) + { + for (unsigned long j = 0; j < samples[i].relevant.size(); ++j) + { + if (is_vector(samples[i].relevant[j]) == false) + return false; + + if (samples[i].relevant[j].size() != dims) + return false; + } + for (unsigned long j = 0; j < samples[i].nonrelevant.size(); ++j) + { + if (is_vector(samples[i].nonrelevant[j]) == false) + return false; + + if (samples[i].nonrelevant[j].size() != dims) + return false; + } + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + unsigned long max_index_plus_one ( + const ranking_pair& item + ) + { + return std::max(max_index_plus_one(item.relevant), max_index_plus_one(item.nonrelevant)); + } + + template < + typename T + > + unsigned long max_index_plus_one ( + const std::vector >& samples + ) + { + unsigned long dims = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + dims = std::max(dims, max_index_plus_one(samples[i])); + } + return dims; + } + +// ---------------------------------------------------------------------------------------- + + template + void count_ranking_inversions ( + const std::vector& x, + const std::vector& y, + std::vector& x_count, + std::vector& y_count + ) + { + x_count.assign(x.size(),0); + y_count.assign(y.size(),0); + + if (x.size() == 0 || y.size() == 0) + return; + + std::vector > xsort(x.size()); + std::vector > ysort(y.size()); + for (unsigned long i = 0; i < x.size(); ++i) + xsort[i] = std::make_pair(x[i], i); + for (unsigned long j = 0; j < y.size(); ++j) + ysort[j] = std::make_pair(y[j], j); + + std::sort(xsort.begin(), xsort.end()); + std::sort(ysort.begin(), ysort.end()); + + + unsigned long i, j; + + // Do the counting for the x values. + for (i = 0, j = 0; i < x_count.size(); ++i) + { + // Skip past y values that are in the correct order with respect to xsort[i]. + while (j < ysort.size() && ysort[j].first < xsort[i].first) + ++j; + + x_count[xsort[i].second] = ysort.size() - j; + } + + + // Now do the counting for the y values. + for (i = 0, j = 0; j < y_count.size(); ++j) + { + // Skip past x values that are in the incorrect order with respect to ysort[j]. + while (i < xsort.size() && !(ysort[j].first < xsort[i].first)) + ++i; + + y_count[ysort[j].second] = i; + } + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline bool compare_first_reverse_second ( + const std::pair& a, + const std::pair& b + ) + { + if (a.first < b.first) + return true; + else if (a.first > b.first) + return false; + else if (a.second && !b.second) + return true; + else + return false; + } + } + + template < + typename ranking_function, + typename T + > + matrix test_ranking_function ( + const ranking_function& funct, + const std::vector >& samples + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ranking_problem(samples), + "\t double test_ranking_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples) + ); + + unsigned long total_pairs = 0; + unsigned long total_wrong = 0; + + std::vector rel_scores; + std::vector nonrel_scores; + std::vector rel_counts; + std::vector nonrel_counts; + + running_stats rs; + std::vector > total_scores; + std::vector total_ranking; + + for (unsigned long i = 0; i < samples.size(); ++i) + { + rel_scores.resize(samples[i].relevant.size()); + nonrel_scores.resize(samples[i].nonrelevant.size()); + total_scores.clear(); + + for (unsigned long k = 0; k < rel_scores.size(); ++k) + { + rel_scores[k] = funct(samples[i].relevant[k]); + total_scores.push_back(std::make_pair(rel_scores[k], true)); + } + + for (unsigned long k = 0; k < nonrel_scores.size(); ++k) + { + nonrel_scores[k] = funct(samples[i].nonrelevant[k]); + total_scores.push_back(std::make_pair(nonrel_scores[k], false)); + } + + // Now compute the average precision for this sample. We need to sort the + // results and the back them into total_ranking. Note that we sort them so + // that, if you get a block of ranking values that are all equal, the elements + // marked as true will come last. This prevents a ranking from outputting a + // constant value for everything and still getting a good MAP score. + std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second); + total_ranking.clear(); + for (unsigned long i = 0; i < total_scores.size(); ++i) + total_ranking.push_back(total_scores[i].second); + rs.add(average_precision(total_ranking)); + + + count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts); + + total_pairs += rel_scores.size()*nonrel_scores.size(); + + // Note that we don't need to look at nonrel_counts since it is redundant with + // the information in rel_counts in this case. + total_wrong += sum(mat(rel_counts)); + } + + const double rank_swaps = static_cast(total_pairs - total_wrong) / total_pairs; + const double mean_average_precision = rs.mean(); + matrix res; + res = rank_swaps, mean_average_precision; + return res; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename ranking_function, + typename T + > + matrix test_ranking_function ( + const ranking_function& funct, + const ranking_pair& sample + ) + { + return test_ranking_function(funct, std::vector >(1,sample)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename T + > + matrix cross_validate_ranking_trainer ( + const trainer_type& trainer, + const std::vector >& samples, + const long folds + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ranking_problem(samples) && + 1 < folds && folds <= static_cast(samples.size()), + "\t double cross_validate_ranking_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t folds: " << folds + << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples) + ); + + + const long num_in_test = samples.size()/folds; + const long num_in_train = samples.size() - num_in_test; + + + std::vector > samples_test, samples_train; + + + long next_test_idx = 0; + + unsigned long total_pairs = 0; + unsigned long total_wrong = 0; + + std::vector rel_scores; + std::vector nonrel_scores; + std::vector rel_counts; + std::vector nonrel_counts; + + running_stats rs; + std::vector > total_scores; + std::vector total_ranking; + + for (long i = 0; i < folds; ++i) + { + samples_test.clear(); + samples_train.clear(); + + // load up the test samples + for (long cnt = 0; cnt < num_in_test; ++cnt) + { + samples_test.push_back(samples[next_test_idx]); + next_test_idx = (next_test_idx + 1)%samples.size(); + } + + // load up the training samples + long next = next_test_idx; + for (long cnt = 0; cnt < num_in_train; ++cnt) + { + samples_train.push_back(samples[next]); + next = (next + 1)%samples.size(); + } + + + const typename trainer_type::trained_function_type& df = trainer.train(samples_train); + + // check how good df is on the test data + for (unsigned long i = 0; i < samples_test.size(); ++i) + { + rel_scores.resize(samples_test[i].relevant.size()); + nonrel_scores.resize(samples_test[i].nonrelevant.size()); + + total_scores.clear(); + + for (unsigned long k = 0; k < rel_scores.size(); ++k) + { + rel_scores[k] = df(samples_test[i].relevant[k]); + total_scores.push_back(std::make_pair(rel_scores[k], true)); + } + + for (unsigned long k = 0; k < nonrel_scores.size(); ++k) + { + nonrel_scores[k] = df(samples_test[i].nonrelevant[k]); + total_scores.push_back(std::make_pair(nonrel_scores[k], false)); + } + + // Now compute the average precision for this sample. We need to sort the + // results and the back them into total_ranking. Note that we sort them so + // that, if you get a block of ranking values that are all equal, the elements + // marked as true will come last. This prevents a ranking from outputting a + // constant value for everything and still getting a good MAP score. + std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second); + total_ranking.clear(); + for (unsigned long i = 0; i < total_scores.size(); ++i) + total_ranking.push_back(total_scores[i].second); + rs.add(average_precision(total_ranking)); + + + count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts); + + total_pairs += rel_scores.size()*nonrel_scores.size(); + + // Note that we don't need to look at nonrel_counts since it is redundant with + // the information in rel_counts in this case. + total_wrong += sum(mat(rel_counts)); + } + + } // for (long i = 0; i < folds; ++i) + + const double rank_swaps = static_cast(total_pairs - total_wrong) / total_pairs; + const double mean_average_precision = rs.mean(); + matrix res; + res = rank_swaps, mean_average_precision; + return res; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANKING_ToOLS_Hh_ + diff --git a/ml/dlib/dlib/svm/ranking_tools_abstract.h b/ml/dlib/dlib/svm/ranking_tools_abstract.h new file mode 100644 index 000000000..af6c7a2e3 --- /dev/null +++ b/ml/dlib/dlib/svm/ranking_tools_abstract.h @@ -0,0 +1,247 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RANKING_ToOLS_ABSTRACT_Hh_ +#ifdef DLIB_RANKING_ToOLS_ABSTRACT_Hh_ + + +#include "../algs.h" +#include "../matrix.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct ranking_pair + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is used to contain a ranking example. In particular, we say + that a good ranking of T objects is one in which all the elements in + this->relevant are ranked higher than the elements of this->nonrelevant. + Therefore, ranking_pair objects are used to represent training examples for + learning-to-rank tasks. + !*/ + + ranking_pair() {} + /*! + ensures + - #relevant.size() == 0 + - #nonrelevant.size() == 0 + !*/ + + ranking_pair( + const std::vector& r, + const std::vector& nr + ) : relevant(r), nonrelevant(nr) {} + /*! + ensures + - #relevant == r + - #nonrelevant == nr + !*/ + + std::vector relevant; + std::vector nonrelevant; + }; + + template < + typename T + > + void serialize ( + const ranking_pair& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename T + > + void deserialize ( + ranking_pair& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool is_ranking_problem ( + const std::vector >& samples + ); + /*! + ensures + - returns true if the data in samples represents a valid learning-to-rank + learning problem. That is, this function returns true if all of the + following are true and false otherwise: + - samples.size() > 0 + - for all valid i: + - samples[i].relevant.size() > 0 + - samples[i].nonrelevant.size() > 0 + - if (is_matrix::value == true) then + - All the elements of samples::nonrelevant and samples::relevant must + represent row or column vectors and they must be the same dimension. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + unsigned long max_index_plus_one ( + const ranking_pair& item + ); + /*! + requires + - T must be a dlib::matrix capable of storing column vectors or T must be a + sparse vector type as defined in dlib/svm/sparse_vector_abstract.h. + ensures + - returns std::max(max_index_plus_one(item.relevant), max_index_plus_one(item.nonrelevant)). + Therefore, this function can be used to find the dimensionality of the + vectors stored in item. + !*/ + + template < + typename T + > + unsigned long max_index_plus_one ( + const std::vector >& samples + ); + /*! + requires + - T must be a dlib::matrix capable of storing column vectors or T must be a + sparse vector type as defined in dlib/svm/sparse_vector_abstract.h. + ensures + - returns the maximum of max_index_plus_one(samples[i]) over all valid values + of i. Therefore, this function can be used to find the dimensionality of the + vectors stored in samples + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void count_ranking_inversions ( + const std::vector& x, + const std::vector& y, + std::vector& x_count, + std::vector& y_count + ); + /*! + requires + - T objects must be copyable + - T objects must be comparable via operator< + ensures + - This function counts how many times we see a y value greater than or equal to + an x value. This is done efficiently in O(n*log(n)) time via the use of + quick sort. + - #x_count.size() == x.size() + - #y_count.size() == y.size() + - for all valid i: + - #x_count[i] == how many times a value in y was >= x[i]. + - for all valid j: + - #y_count[j] == how many times a value in x was <= y[j]. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename ranking_function, + typename T + > + matrix test_ranking_function ( + const ranking_function& funct, + const std::vector >& samples + ); + /*! + requires + - is_ranking_problem(samples) == true + - ranking_function == some kind of decision function object (e.g. decision_function) + ensures + - Tests the given ranking function on the supplied example ranking data and + returns the fraction of ranking pair orderings predicted correctly. This is + a number in the range [0,1] where 0 means everything was incorrectly + predicted while 1 means everything was correctly predicted. This function + also returns the mean average precision. + - In particular, this function returns a matrix M summarizing the results. + Specifically, it returns an M such that: + - M(0) == the fraction of times that the following is true: + - funct(samples[k].relevant[i]) > funct(samples[k].nonrelevant[j]) + (for all valid i,j,k) + - M(1) == the mean average precision of the rankings induced by funct. + (Mean average precision is a number in the range 0 to 1. Moreover, a + mean average precision of 1 means everything was correctly predicted + while smaller values indicate worse rankings. See the documentation + for average_precision() for details of its computation.) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename ranking_function, + typename T + > + matrix test_ranking_function ( + const ranking_function& funct, + const ranking_pair& sample + ); + /*! + requires + - is_ranking_problem(std::vector >(1, sample)) == true + - ranking_function == some kind of decision function object (e.g. decision_function) + ensures + - This is just a convenience routine for calling the above + test_ranking_function() routine. That is, it just copies sample into a + std::vector object and invokes the above test_ranking_function() routine. + This means that calling this function is equivalent to invoking: + return test_ranking_function(funct, std::vector >(1, sample)); + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename T + > + matrix cross_validate_ranking_trainer ( + const trainer_type& trainer, + const std::vector >& samples, + const long folds + ); + /*! + requires + - is_ranking_problem(samples) == true + - 1 < folds <= samples.size() + - trainer_type == some kind of ranking trainer object (e.g. svm_rank_trainer) + ensures + - Performs k-fold cross validation by using the given trainer to solve the + given ranking problem for the given number of folds. Each fold is tested + using the output of the trainer and the average ranking accuracy as well as + the mean average precision over the number of folds is returned. + - The accuracy is computed the same way test_ranking_function() computes its + accuracy. Therefore, it is a number in the range [0,1] that represents the + fraction of times a ranking pair's ordering was predicted correctly. Similarly, + the mean average precision is computed identically to test_ranking_function(). + In particular, this means that this function returns a matrix M such that: + - M(0) == the ranking accuracy + - M(1) == the mean average precision + - The number of folds used is given by the folds argument. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANKING_ToOLS_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/rbf_network.h b/ml/dlib/dlib/svm/rbf_network.h new file mode 100644 index 000000000..23a2c7424 --- /dev/null +++ b/ml/dlib/dlib/svm/rbf_network.h @@ -0,0 +1,162 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RBf_NETWORK_ +#define DLIB_RBf_NETWORK_ + +#include "../matrix.h" +#include "rbf_network_abstract.h" +#include "kernel.h" +#include "linearly_independent_subset_finder.h" +#include "function.h" +#include "../algs.h" + +namespace dlib +{ + +// ------------------------------------------------------------------------------ + + template < + typename Kern + > + class rbf_network_trainer + { + /*! + This is an implementation of an RBF network trainer that follows + the directions right off Wikipedia basically. So nothing + particularly fancy. Although the way the centers are selected + is somewhat unique. + !*/ + + public: + typedef Kern kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + rbf_network_trainer ( + ) : + num_centers(10) + { + } + + void set_kernel ( + const kernel_type& k + ) + { + kernel = k; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel; + } + + void set_num_centers ( + const unsigned long num + ) + { + num_centers = num; + } + + unsigned long get_num_centers ( + ) const + { + return num_centers; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + return do_train(mat(x), mat(y)); + } + + void swap ( + rbf_network_trainer& item + ) + { + exchange(kernel, item.kernel); + exchange(num_centers, item.num_centers); + } + + private: + + // ------------------------------------------------------------------------------------ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + typedef typename decision_function::scalar_vector_type scalar_vector_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,y), + "\tdecision_function rbf_network_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + ); + + // use the linearly_independent_subset_finder object to select the centers. So here + // we show it all the data samples so it can find the best centers. + linearly_independent_subset_finder lisf(kernel, num_centers); + fill_lisf(lisf, x); + + const long num_centers = lisf.size(); + + // fill the K matrix with the output of the kernel for all the center and sample point pairs + matrix K(x.nr(), num_centers+1); + for (long r = 0; r < x.nr(); ++r) + { + for (long c = 0; c < num_centers; ++c) + { + K(r,c) = kernel(x(r), lisf[c]); + } + // This last column of the K matrix takes care of the bias term + K(r,num_centers) = 1; + } + + // compute the best weights by using the pseudo-inverse + scalar_vector_type weights(pinv(K)*y); + + // now put everything into a decision_function object and return it + return decision_function (remove_row(weights,num_centers), + -weights(num_centers), + kernel, + lisf.get_dictionary()); + + } + + kernel_type kernel; + unsigned long num_centers; + + }; // end of class rbf_network_trainer + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + rbf_network_trainer& a, + rbf_network_trainer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RBf_NETWORK_ + diff --git a/ml/dlib/dlib/svm/rbf_network_abstract.h b/ml/dlib/dlib/svm/rbf_network_abstract.h new file mode 100644 index 000000000..782a4bdbd --- /dev/null +++ b/ml/dlib/dlib/svm/rbf_network_abstract.h @@ -0,0 +1,132 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RBf_NETWORK_ABSTRACT_ +#ifdef DLIB_RBf_NETWORK_ABSTRACT_ + +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class rbf_network_trainer + { + /*! + REQUIREMENTS ON K + is a kernel function object as defined in dlib/svm/kernel_abstract.h + (since this is supposed to be a RBF network it is probably reasonable + to use some sort of radial basis kernel) + + INITIAL VALUE + - get_num_centers() == 10 + + WHAT THIS OBJECT REPRESENTS + This object implements a trainer for a radial basis function network. + + The implementation of this algorithm follows the normal RBF training + process. For more details see the code or the Wikipedia article + about RBF networks. + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + rbf_network_trainer ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + void set_num_centers ( + const unsigned long num_centers + ); + /*! + ensures + - #get_num_centers() == num_centers + !*/ + + const unsigned long get_num_centers ( + ) const; + /*! + ensures + - returns the maximum number of centers (a.k.a. basis_vectors in the + trained decision_function) you will get when you train this object on data. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + /*! + requires + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + - is_learning_problem(x,y) == true + ensures + - trains a RBF network given the training samples in x and + labels in y and returns the resulting decision_function + throws + - std::bad_alloc + !*/ + + void swap ( + rbf_network_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + rbf_network_trainer& a, + rbf_network_trainer& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RBf_NETWORK_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/reduced.h b/ml/dlib/dlib/svm/reduced.h new file mode 100644 index 000000000..b4c5b63ca --- /dev/null +++ b/ml/dlib/dlib/svm/reduced.h @@ -0,0 +1,613 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_REDUCEd_TRAINERS_ +#define DLIB_REDUCEd_TRAINERS_ + +#include "reduced_abstract.h" +#include "../matrix.h" +#include "../algs.h" +#include "function.h" +#include "kernel.h" +#include "kcentroid.h" +#include "linearly_independent_subset_finder.h" +#include "../optimization.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class reduced_decision_function_trainer + { + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + reduced_decision_function_trainer ( + ) :num_bv(0) {} + + reduced_decision_function_trainer ( + const trainer_type& trainer_, + const unsigned long num_sb_ + ) : + trainer(trainer_), + num_bv(num_sb_) + { + // make sure requires clause is not broken + DLIB_ASSERT(num_bv > 0, + "\t reduced_decision_function_trainer()" + << "\n\t you have given invalid arguments to this function" + << "\n\t num_bv: " << num_bv + ); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(num_bv > 0, + "\t reduced_decision_function_trainer::train(x,y)" + << "\n\t You have tried to use an uninitialized version of this object" + << "\n\t num_bv: " << num_bv ); + return do_train(mat(x), mat(y)); + } + + private: + + // ------------------------------------------------------------------------------------ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + // get the decision function object we are going to try and approximate + const decision_function& dec_funct = trainer.train(x,y); + + // now find a linearly independent subset of the training points of num_bv points. + linearly_independent_subset_finder lisf(dec_funct.kernel_function, num_bv); + fill_lisf(lisf, x); + + // The next few statements just find the best weights with which to approximate + // the dec_funct object with the smaller set of vectors in the lisf dictionary. This + // is really just a simple application of some linear algebra. For the details + // see page 554 of Learning with kernels by Scholkopf and Smola where they talk + // about "Optimal Expansion Coefficients." + + const kernel_type kern(dec_funct.kernel_function); + + matrix alpha; + + alpha = lisf.get_inv_kernel_marix()*(kernel_matrix(kern,lisf,dec_funct.basis_vectors)*dec_funct.alpha); + + decision_function new_df(alpha, + 0, + kern, + lisf.get_dictionary()); + + // now we have to figure out what the new bias should be. It might be a little + // different since we just messed with all the weights and vectors. + double bias = 0; + for (long i = 0; i < x.nr(); ++i) + { + bias += new_df(x(i)) - dec_funct(x(i)); + } + + new_df.b = bias/x.nr(); + + return new_df; + } + + // ------------------------------------------------------------------------------------ + + trainer_type trainer; + unsigned long num_bv; + + + }; // end of class reduced_decision_function_trainer + + template + const reduced_decision_function_trainer reduced ( + const trainer_type& trainer, + const unsigned long num_bv + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(num_bv > 0, + "\tconst reduced_decision_function_trainer reduced()" + << "\n\t you have given invalid arguments to this function" + << "\n\t num_bv: " << num_bv + ); + + return reduced_decision_function_trainer(trainer, num_bv); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace red_impl + { + + // ------------------------------------------------------------------------------------ + + template + class objective + { + /* + This object represents the objective function we will try to + minimize in approximate_distance_function(). + + The objective is the distance, in kernel induced feature space, between + the original distance function and the approximated version. + + */ + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + public: + objective( + const distance_function& dist_funct_, + matrix& b_, + matrix& out_vectors_ + ) : + dist_funct(dist_funct_), + b(b_), + out_vectors(out_vectors_) + { + } + + const matrix state_to_vector ( + ) const + /*! + ensures + - returns a vector that contains all the information necessary to + reproduce the current state of the approximated distance function + !*/ + { + matrix z(b.nr() + out_vectors.size()*out_vectors(0).nr()); + long i = 0; + for (long j = 0; j < b.nr(); ++j) + { + z(i) = b(j); + ++i; + } + + for (long j = 0; j < out_vectors.size(); ++j) + { + for (long k = 0; k < out_vectors(j).size(); ++k) + { + z(i) = out_vectors(j)(k); + ++i; + } + } + return z; + } + + + void vector_to_state ( + const matrix& z + ) const + /*! + requires + - z came from the state_to_vector() function or has a compatible format + ensures + - loads the vector z into the state variables of the approximate + distance function (i.e. b and out_vectors) + !*/ + { + long i = 0; + for (long j = 0; j < b.nr(); ++j) + { + b(j) = z(i); + ++i; + } + + for (long j = 0; j < out_vectors.size(); ++j) + { + for (long k = 0; k < out_vectors(j).size(); ++k) + { + out_vectors(j)(k) = z(i); + ++i; + } + } + } + + double operator() ( + const matrix& z + ) const + /*! + ensures + - loads the current approximate distance function with z + - returns the distance between the original distance function + and the approximate one. + !*/ + { + vector_to_state(z); + const kernel_type k(dist_funct.get_kernel()); + + double temp = 0; + for (long i = 0; i < out_vectors.size(); ++i) + { + for (long j = 0; j < dist_funct.get_basis_vectors().nr(); ++j) + { + temp -= b(i)*dist_funct.get_alpha()(j)*k(out_vectors(i), dist_funct.get_basis_vectors()(j)); + } + } + + temp *= 2; + + for (long i = 0; i < out_vectors.size(); ++i) + { + for (long j = 0; j < out_vectors.size(); ++j) + { + temp += b(i)*b(j)*k(out_vectors(i), out_vectors(j)); + } + } + + return temp + dist_funct.get_squared_norm(); + } + + private: + + const distance_function& dist_funct; + matrix& b; + matrix& out_vectors; + + }; + + // ------------------------------------------------------------------------------------ + + template + class objective_derivative + { + /*! + This object represents the derivative of the objective object + !*/ + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + public: + + + objective_derivative( + const distance_function& dist_funct_, + matrix& b_, + matrix& out_vectors_ + ) : + dist_funct(dist_funct_), + b(b_), + out_vectors(out_vectors_) + { + } + + void vector_to_state ( + const matrix& z + ) const + /*! + requires + - z came from the state_to_vector() function or has a compatible format + ensures + - loads the vector z into the state variables of the approximate + distance function (i.e. b and out_vectors) + !*/ + { + long i = 0; + for (long j = 0; j < b.nr(); ++j) + { + b(j) = z(i); + ++i; + } + + for (long j = 0; j < out_vectors.size(); ++j) + { + for (long k = 0; k < out_vectors(j).size(); ++k) + { + out_vectors(j)(k) = z(i); + ++i; + } + } + } + + const matrix& operator() ( + const matrix& z + ) const + /*! + ensures + - loads the current approximate distance function with z + - returns the derivative of the distance between the original + distance function and the approximate one. + !*/ + { + vector_to_state(z); + res.set_size(z.nr()); + set_all_elements(res,0); + const kernel_type k(dist_funct.get_kernel()); + const kernel_derivative K_der(k); + + // first compute the gradient for the beta weights + for (long i = 0; i < out_vectors.size(); ++i) + { + for (long j = 0; j < out_vectors.size(); ++j) + { + res(i) += b(j)*k(out_vectors(i), out_vectors(j)); + } + } + for (long i = 0; i < out_vectors.size(); ++i) + { + for (long j = 0; j < dist_funct.get_basis_vectors().size(); ++j) + { + res(i) -= dist_funct.get_alpha()(j)*k(out_vectors(i), dist_funct.get_basis_vectors()(j)); + } + } + + + // now compute the gradient of the actual vectors that go into the kernel functions + long pos = out_vectors.size(); + const long num = out_vectors(0).nr(); + temp.set_size(num,1); + for (long i = 0; i < out_vectors.size(); ++i) + { + set_all_elements(temp,0); + for (long j = 0; j < out_vectors.size(); ++j) + { + temp += b(j)*K_der(out_vectors(j), out_vectors(i)); + } + for (long j = 0; j < dist_funct.get_basis_vectors().nr(); ++j) + { + temp -= dist_funct.get_alpha()(j)*K_der(dist_funct.get_basis_vectors()(j), out_vectors(i) ); + } + + // store the gradient for out_vectors(i) into result in the proper spot + set_subm(res,pos,0,num,1) = b(i)*temp; + pos += num; + } + + + res *= 2; + return res; + } + + private: + + mutable matrix res; + mutable sample_type temp; + + const distance_function& dist_funct; + matrix& b; + matrix& out_vectors; + + }; + + // ------------------------------------------------------------------------------------ + + } + + template < + typename K, + typename stop_strategy_type, + typename T + > + distance_function approximate_distance_function ( + stop_strategy_type stop_strategy, + const distance_function& target, + const T& starting_basis + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(target.get_basis_vectors().size() > 0 && + starting_basis.size() > 0, + "\t distance_function approximate_distance_function()" + << "\n\t Invalid inputs were given to this function." + << "\n\t target.get_basis_vectors().size(): " << target.get_basis_vectors().size() + << "\n\t starting_basis.size(): " << starting_basis.size() + ); + + using namespace red_impl; + // The next few statements just find the best weights with which to approximate + // the target object with the set of basis vectors in starting_basis. This + // is really just a simple application of some linear algebra. For the details + // see page 554 of Learning with kernels by Scholkopf and Smola where they talk + // about "Optimal Expansion Coefficients." + + const K kern(target.get_kernel()); + typedef typename K::scalar_type scalar_type; + typedef typename K::sample_type sample_type; + typedef typename K::mem_manager_type mem_manager_type; + + matrix beta; + + // Now we compute the fist approximate distance function. + beta = pinv(kernel_matrix(kern,starting_basis)) * + (kernel_matrix(kern,starting_basis,target.get_basis_vectors())*target.get_alpha()); + matrix out_vectors(mat(starting_basis)); + + + // Now setup to do a global optimization of all the parameters in the approximate + // distance function. + const objective obj(target, beta, out_vectors); + const objective_derivative obj_der(target, beta, out_vectors); + matrix opt_starting_point(obj.state_to_vector()); + + + // perform a full optimization of all the parameters (i.e. both beta and the basis vectors together) + find_min(lbfgs_search_strategy(20), + stop_strategy, + obj, obj_der, opt_starting_point, 0); + + // now make sure that the final optimized value is loaded into the beta and + // out_vectors matrices + obj.vector_to_state(opt_starting_point); + + // Do a final reoptimization of beta just to make sure it is optimal given the new + // set of basis vectors. + beta = pinv(kernel_matrix(kern,out_vectors))*(kernel_matrix(kern,out_vectors,target.get_basis_vectors())*target.get_alpha()); + + // It is possible that some of the beta weights will be very close to zero. Lets remove + // the basis vectors with these essentially zero weights. + const scalar_type eps = max(abs(beta))*std::numeric_limits::epsilon(); + for (long i = 0; i < beta.size(); ++i) + { + // if beta(i) is zero (but leave at least one beta no matter what) + if (std::abs(beta(i)) < eps && beta.size() > 1) + { + beta = remove_row(beta, i); + out_vectors = remove_row(out_vectors, i); + --i; + } + } + + return distance_function(beta, kern, out_vectors); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class reduced_decision_function_trainer2 + { + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + reduced_decision_function_trainer2 () : num_bv(0) {} + reduced_decision_function_trainer2 ( + const trainer_type& trainer_, + const long num_sb_, + const double eps_ = 1e-3 + ) : + trainer(trainer_), + num_bv(num_sb_), + eps(eps_) + { + COMPILE_TIME_ASSERT(is_matrix::value); + + // make sure requires clause is not broken + DLIB_ASSERT(num_bv > 0 && eps > 0, + "\t reduced_decision_function_trainer2()" + << "\n\t you have given invalid arguments to this function" + << "\n\t num_bv: " << num_bv + << "\n\t eps: " << eps + ); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(num_bv > 0, + "\t reduced_decision_function_trainer2::train(x,y)" + << "\n\t You have tried to use an uninitialized version of this object" + << "\n\t num_bv: " << num_bv ); + return do_train(mat(x), mat(y)); + } + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + // get the decision function object we are going to try and approximate + const decision_function& dec_funct = trainer.train(x,y); + const kernel_type kern(dec_funct.kernel_function); + + // now find a linearly independent subset of the training points of num_bv points. + linearly_independent_subset_finder lisf(kern, num_bv); + fill_lisf(lisf,x); + + distance_function approx, target; + target = dec_funct; + approx = approximate_distance_function(objective_delta_stop_strategy(eps), target, lisf); + + decision_function new_df(approx.get_alpha(), + 0, + kern, + approx.get_basis_vectors()); + + // now we have to figure out what the new bias should be. It might be a little + // different since we just messed with all the weights and vectors. + double bias = 0; + for (long i = 0; i < x.nr(); ++i) + { + bias += new_df(x(i)) - dec_funct(x(i)); + } + + new_df.b = bias/x.nr(); + + return new_df; + + } + + // ------------------------------------------------------------------------------------ + + trainer_type trainer; + long num_bv; + double eps; + + + }; // end of class reduced_decision_function_trainer2 + + template + const reduced_decision_function_trainer2 reduced2 ( + const trainer_type& trainer, + const long num_bv, + double eps = 1e-3 + ) + { + COMPILE_TIME_ASSERT(is_matrix::value); + + // make sure requires clause is not broken + DLIB_ASSERT(num_bv > 0 && eps > 0, + "\tconst reduced_decision_function_trainer2 reduced2()" + << "\n\t you have given invalid arguments to this function" + << "\n\t num_bv: " << num_bv + << "\n\t eps: " << eps + ); + + return reduced_decision_function_trainer2(trainer, num_bv, eps); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_REDUCEd_TRAINERS_ + diff --git a/ml/dlib/dlib/svm/reduced_abstract.h b/ml/dlib/dlib/svm/reduced_abstract.h new file mode 100644 index 000000000..8b186c033 --- /dev/null +++ b/ml/dlib/dlib/svm/reduced_abstract.h @@ -0,0 +1,267 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_REDUCED_TRAINERs_ABSTRACT_ +#ifdef DLIB_REDUCED_TRAINERs_ABSTRACT_ + +#include "../matrix.h" +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "../optimization.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class reduced_decision_function_trainer + { + /*! + REQUIREMENTS ON trainer_type + - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer) + + WHAT THIS OBJECT REPRESENTS + This object represents an implementation of a reduced set algorithm. + This object acts as a post processor for anything that creates + decision_function objects. It wraps another trainer object and + performs this reduced set post processing with the goal of + representing the original decision function in a form that + involves fewer basis vectors. + !*/ + + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + reduced_decision_function_trainer ( + ); + /*! + ensures + - This object is in an uninitialized state. You must + construct a real one with the other constructor and assign it + to this instance before you use this object. + !*/ + + reduced_decision_function_trainer ( + const trainer_type& trainer, + const unsigned long num_bv + ); + /*! + requires + - num_bv > 0 + ensures + - returns a trainer object that applies post processing to the decision_function + objects created by the given trainer object with the goal of creating + decision_function objects with fewer basis vectors. + - The reduced decision functions that are output will have at most + num_bv basis vectors. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + ensures + - trains a decision_function using the trainer that was supplied to + this object's constructor and then finds a reduced representation + for it and returns the reduced version. + throws + - std::bad_alloc + - any exceptions thrown by the trainer_type object + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const reduced_decision_function_trainer reduced ( + const trainer_type& trainer, + const unsigned long num_bv + ) { return reduced_decision_function_trainer(trainer, num_bv); } + /*! + requires + - num_bv > 0 + - trainer_type == some kind of batch trainer object that creates decision_function + objects (e.g. svm_nu_trainer) + ensures + - returns a reduced_decision_function_trainer object that has been + instantiated with the given arguments. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename stop_strategy_type, + typename T + > + distance_function approximate_distance_function ( + stop_strategy_type stop_strategy, + const distance_function& target, + const T& starting_basis + ); + /*! + requires + - stop_strategy == an object that defines a stop strategy such as one of + the objects from dlib/optimization/optimization_stop_strategies_abstract.h + - requirements on starting_basis + - T must be a dlib::matrix type or something convertible to a matrix via mat() + (e.g. a std::vector). Additionally, starting_basis must contain K::sample_type + objects which can be supplied to the kernel function used by target. + - is_vector(starting_basis) == true + - starting_basis.size() > 0 + - target.get_basis_vectors().size() > 0 + - kernel_derivative is defined + (i.e. The analytic derivative for the given kernel must be defined) + - K::sample_type must be a dlib::matrix object and the basis_vectors inside target + and starting_basis must be column vectors. + ensures + - This routine attempts to find a distance_function object which is close + to the given target. That is, it searches for an X such that target(X) is + minimized. The optimization begins with an X in the span of the elements + of starting_basis and searches for an X which locally minimizes target(X). + Since this problem can have many local minima, the quality of the starting + basis can significantly influence the results. + - The optimization is over all variables in a distance_function, however, + the size of the basis set is constrained to no more than starting_basis.size(). + That is, in the returned distance_function DF, we will have: + - DF.get_basis_vectors().size() <= starting_basis.size() + - The optimization is carried out until the stop_strategy indicates it + should stop. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class reduced_decision_function_trainer2 + { + /*! + REQUIREMENTS ON trainer_type + - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer) + - trainer_type::sample_type must be a dlib::matrix object + - kernel_derivative must be defined + + WHAT THIS OBJECT REPRESENTS + This object represents an implementation of a reduced set algorithm. + This object acts as a post processor for anything that creates + decision_function objects. It wraps another trainer object and + performs this reduced set post processing with the goal of + representing the original decision function in a form that + involves fewer basis vectors. + + This object's implementation is the same as that in the above + reduced_decision_function_trainer object except it also performs + a global gradient based optimization at the end to further + improve the approximation to the original decision function + object. + !*/ + + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + reduced_decision_function_trainer2 ( + ); + /*! + ensures + - This object is in an uninitialized state. You must + construct a real one with the other constructor and assign it + to this instance before you use this object. + !*/ + + reduced_decision_function_trainer2 ( + const trainer_type& trainer, + const unsigned long num_bv, + double eps = 1e-3 + ); + /*! + requires + - num_bv > 0 + - eps > 0 + ensures + - returns a trainer object that applies post processing to the decision_function + objects created by the given trainer object with the goal of creating + decision_function objects with fewer basis vectors. + - The reduced decision functions that are output will have at most + num_bv basis vectors. + - the gradient based optimization will continue until the change in the + objective function is less than eps. So smaller values of eps will + give better results but take longer to compute. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - x must be a list of objects which are each some kind of dlib::matrix + which represents column or row vectors. + ensures + - trains a decision_function using the trainer that was supplied to + this object's constructor and then finds a reduced representation + for it and returns the reduced version. + throws + - std::bad_alloc + - any exceptions thrown by the trainer_type object + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const reduced_decision_function_trainer2 reduced2 ( + const trainer_type& trainer, + const unsigned long num_bv, + double eps = 1e-3 + ) { return reduced_decision_function_trainer2(trainer, num_bv, eps); } + /*! + requires + - num_bv > 0 + - trainer_type == some kind of batch trainer object that creates decision_function + objects (e.g. svm_nu_trainer) + - kernel_derivative is defined + - eps > 0 + ensures + - returns a reduced_decision_function_trainer2 object that has been + instantiated with the given arguments. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_REDUCED_TRAINERs_ABSTRACT_ + diff --git a/ml/dlib/dlib/svm/rls.h b/ml/dlib/dlib/svm/rls.h new file mode 100644 index 000000000..edee6b062 --- /dev/null +++ b/ml/dlib/dlib/svm/rls.h @@ -0,0 +1,232 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RLs_Hh_ +#define DLIB_RLs_Hh_ + +#include "rls_abstract.h" +#include "../matrix.h" +#include "function.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rls + { + + public: + + + explicit rls( + double forget_factor_, + double C_ = 1000, + bool apply_forget_factor_to_C_ = false + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < forget_factor_ && forget_factor_ <= 1 && + 0 < C_, + "\t rls::rls()" + << "\n\t invalid arguments were given to this function" + << "\n\t forget_factor_: " << forget_factor_ + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + + C = C_; + forget_factor = forget_factor_; + apply_forget_factor_to_C = apply_forget_factor_to_C_; + } + + rls( + ) + { + C = 1000; + forget_factor = 1; + apply_forget_factor_to_C = false; + } + + double get_c( + ) const + { + return C; + } + + double get_forget_factor( + ) const + { + return forget_factor; + } + + bool should_apply_forget_factor_to_C ( + ) const + { + return apply_forget_factor_to_C; + } + + template + void train ( + const matrix_exp& x, + double y + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(x) && + (get_w().size() == 0 || get_w().size() == x.size()), + "\t void rls::train()" + << "\n\t invalid arguments were given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t x.size(): " << x.size() + << "\n\t get_w().size(): " << get_w().size() + << "\n\t this: " << this + ); + + if (R.size() == 0) + { + R = identity_matrix(x.size())*C; + w.set_size(x.size()); + w = 0; + } + + // multiply by forget factor and incorporate x*trans(x) into R. + const double l = 1.0/forget_factor; + const double temp = 1 + l*trans(x)*R*x; + tmp = R*x; + R = l*R - l*l*(tmp*trans(tmp))/temp; + + // Since we multiplied by the forget factor, we need to add (1-forget_factor) of the + // identity matrix back in to keep the regularization alive. + if (forget_factor != 1 && !apply_forget_factor_to_C) + add_eye_to_inv(R, (1-forget_factor)/C); + + // R should always be symmetric. This line improves numeric stability of this algorithm. + if (cnt%10 == 0) + R = 0.5*(R + trans(R)); + ++cnt; + + w = w + R*x*(y - trans(x)*w); + + } + + + + const matrix& get_w( + ) const + { + return w; + } + + template + double operator() ( + const matrix_exp& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(x) && get_w().size() == x.size(), + "\t double rls::operator()()" + << "\n\t invalid arguments were given to this function" + << "\n\t is_col_vector(x): " << is_col_vector(x) + << "\n\t x.size(): " << x.size() + << "\n\t get_w().size(): " << get_w().size() + << "\n\t this: " << this + ); + + return dot(x,w); + } + + decision_function > > get_decision_function ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_w().size() != 0, + "\t decision_function rls::get_decision_function()" + << "\n\t invalid arguments were given to this function" + << "\n\t get_w().size(): " << get_w().size() + << "\n\t this: " << this + ); + + decision_function > > df; + df.alpha.set_size(1); + df.basis_vectors.set_size(1); + df.b = 0; + df.alpha = 1; + df.basis_vectors(0) = w; + + return df; + } + + friend inline void serialize(const rls& item, std::ostream& out) + { + int version = 2; + serialize(version, out); + serialize(item.w, out); + serialize(item.R, out); + serialize(item.C, out); + serialize(item.forget_factor, out); + serialize(item.cnt, out); + serialize(item.apply_forget_factor_to_C, out); + } + + friend inline void deserialize(rls& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (!(1 <= version && version <= 2)) + throw dlib::serialization_error("Unknown version number found while deserializing rls object."); + + if (version >= 1) + { + deserialize(item.w, in); + deserialize(item.R, in); + deserialize(item.C, in); + deserialize(item.forget_factor, in); + } + item.cnt = 0; + item.apply_forget_factor_to_C = false; + if (version >= 2) + { + deserialize(item.cnt, in); + deserialize(item.apply_forget_factor_to_C, in); + } + } + + private: + + void add_eye_to_inv( + matrix& m, + double C + ) + /*! + ensures + - Let m == inv(M) + - this function returns inv(M + C*identity_matrix(m.nr())) + !*/ + { + for (long r = 0; r < m.nr(); ++r) + { + m = m - colm(m,r)*trans(colm(m,r))/(1/C + m(r,r)); + } + } + + + matrix w; + matrix R; + double C; + double forget_factor; + int cnt = 0; + bool apply_forget_factor_to_C; + + + // This object is here only to avoid reallocation during training. It don't + // logically contribute to the state of this object. + matrix tmp; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RLs_Hh_ + diff --git a/ml/dlib/dlib/svm/rls_abstract.h b/ml/dlib/dlib/svm/rls_abstract.h new file mode 100644 index 000000000..c593e4330 --- /dev/null +++ b/ml/dlib/dlib/svm/rls_abstract.h @@ -0,0 +1,175 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RLs_ABSTRACT_Hh_ +#ifdef DLIB_RLs_ABSTRACT_Hh_ + +#include "../matrix/matrix_abstract.h" +#include "function_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rls + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the linear version of the recursive least + squares algorithm. It accepts training points incrementally and, at + each step, maintains the solution to the following optimization problem: + find w minimizing: 0.5*dot(w,w) + C*sum_i(y_i - trans(x_i)*w)^2 + Where (x_i,y_i) are training pairs. x_i is some vector and y_i is a target + scalar value. + + This object can also be configured to use exponential forgetting. This is + where each training example is weighted by pow(forget_factor, i), where i + indicates the sample's age. So older samples are weighted less in the + least squares solution and therefore become forgotten after some time. + Therefore, with forgetting, this object solves the following optimization + problem at each step: + find w minimizing: 0.5*dot(w,w) + C*sum_i pow(forget_factor, i)*(y_i - trans(x_i)*w)^2 + Where i starts at 0 and i==0 corresponds to the most recent training point. + !*/ + + public: + + + explicit rls( + double forget_factor, + double C = 1000, + bool apply_forget_factor_to_C = false + ); + /*! + requires + - 0 < forget_factor <= 1 + - 0 < C + ensures + - #get_w().size() == 0 + - #get_c() == C + - #get_forget_factor() == forget_factor + - #should_apply_forget_factor_to_C() == apply_forget_factor_to_C + !*/ + + rls( + ); + /*! + ensures + - #get_w().size() == 0 + - #get_c() == 1000 + - #get_forget_factor() == 1 + - #should_apply_forget_factor_to_C() == false + !*/ + + double get_c( + ) const; + /*! + ensures + - returns the regularization parameter. It is the parameter + that determines the trade-off between trying to fit the training + data or allowing more errors but hopefully improving the generalization + of the resulting regression. Larger values encourage exact fitting while + smaller values of C may encourage better generalization. + !*/ + + double get_forget_factor( + ) const; + /*! + ensures + - returns the exponential forgetting factor. A value of 1 disables forgetting + and results in normal least squares regression. On the other hand, a smaller + value causes the regression to forget about old training examples and prefer + instead to fit more recent examples. The closer the forget factor is to + zero the faster old examples are forgotten. + !*/ + + bool should_apply_forget_factor_to_C ( + ) const; + /*! + ensures + - If this function returns false then it means we are optimizing the + objective function discussed in the WHAT THIS OBJECT REPRESENTS section + above. However, if it returns true then we will allow the forget factor + (get_forget_factor()) to be applied to the C value which causes the + algorithm to slowly increase C and convert into a textbook version of RLS + without regularization. The main reason you might want to do this is + because it can make the algorithm run significantly faster. + !*/ + + template + void train ( + const matrix_exp& x, + double y + ) + /*! + requires + - is_col_vector(x) == true + - if (get_w().size() != 0) then + - x.size() == get_w().size() + (i.e. all training examples must have the same + dimensionality) + ensures + - #get_w().size() == x.size() + - updates #get_w() such that it contains the solution to the least + squares problem of regressing the given x onto the given y as well + as all the previous training examples supplied to train(). + !*/ + + const matrix& get_w( + ) const; + /*! + ensures + - returns the regression weights. These are the values learned by the + least squares procedure. If train() has not been called then this + function returns an empty vector. + !*/ + + template + double operator() ( + const matrix_exp& x + ) const; + /*! + requires + - is_col_vector(x) == true + - get_w().size() == x.size() + ensures + - returns dot(x, get_w()) + !*/ + + decision_function > > get_decision_function ( + ) const; + /*! + requires + - get_w().size() != 0 + ensures + - returns a decision function DF such that: + - DF(x) == dot(x, get_w()) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const rls& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + rls& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RLs_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/roc_trainer.h b/ml/dlib/dlib/svm/roc_trainer.h new file mode 100644 index 000000000..fa2c0ef9b --- /dev/null +++ b/ml/dlib/dlib/svm/roc_trainer.h @@ -0,0 +1,149 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ROC_TRAINEr_H_ +#define DLIB_ROC_TRAINEr_H_ + +#include "roc_trainer_abstract.h" +#include "../algs.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class roc_trainer_type + { + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + roc_trainer_type ( + ) : desired_accuracy(0), class_selection(0){} + + roc_trainer_type ( + const trainer_type& trainer_, + const scalar_type& desired_accuracy_, + const scalar_type& class_selection_ + ) : trainer(trainer_), desired_accuracy(desired_accuracy_), class_selection(class_selection_) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= desired_accuracy && desired_accuracy <= 1 && + (class_selection == -1 || class_selection == +1), + "\t roc_trainer_type::roc_trainer_type()" + << "\n\t invalid inputs were given to this function" + << "\n\t desired_accuracy: " << desired_accuracy + << "\n\t class_selection: " << class_selection + ); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const trained_function_type train ( + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels + ) const + /*! + requires + - is_binary_classification_problem(samples, labels) == true + !*/ + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(samples, labels), + "\t roc_trainer_type::train()" + << "\n\t invalid inputs were given to this function" + ); + + + return do_train(mat(samples), mat(labels)); + } + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const trained_function_type do_train ( + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels + ) const + { + trained_function_type df = trainer.train(samples, labels); + + // clear out the old bias + df.b = 0; + + // obtain all the scores from the df using all the class_selection labeled samples + std::vector scores; + for (long i = 0; i < samples.size(); ++i) + { + if (labels(i) == class_selection) + scores.push_back(df(samples(i))); + } + + if (class_selection == +1) + std::sort(scores.rbegin(), scores.rend()); + else + std::sort(scores.begin(), scores.end()); + + // now pick out the index that gives us the desired accuracy with regards to selected class + unsigned long idx = static_cast(desired_accuracy*scores.size() + 0.5); + if (idx >= scores.size()) + idx = scores.size()-1; + + df.b = scores[idx]; + + // In this case add a very small extra amount to the bias so that all the samples + // with the class_selection label are classified correctly. + if (desired_accuracy == 1) + { + if (class_selection == +1) + df.b -= std::numeric_limits::epsilon()*df.b; + else + df.b += std::numeric_limits::epsilon()*df.b; + } + + return df; + } + + trainer_type trainer; + scalar_type desired_accuracy; + scalar_type class_selection; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const roc_trainer_type roc_c1_trainer ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& desired_accuracy + ) { return roc_trainer_type(trainer, desired_accuracy, +1); } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const roc_trainer_type roc_c2_trainer ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& desired_accuracy + ) { return roc_trainer_type(trainer, desired_accuracy, -1); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ROC_TRAINEr_H_ + + diff --git a/ml/dlib/dlib/svm/roc_trainer_abstract.h b/ml/dlib/dlib/svm/roc_trainer_abstract.h new file mode 100644 index 000000000..74e6f9b65 --- /dev/null +++ b/ml/dlib/dlib/svm/roc_trainer_abstract.h @@ -0,0 +1,135 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ROC_TRAINEr_ABSTRACT_ +#ifdef DLIB_ROC_TRAINEr_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class roc_trainer_type + { + /*! + REQUIREMENTS ON trainer_type + - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer) + + WHAT THIS OBJECT REPRESENTS + This object is a simple trainer post processor that allows you to + easily adjust the bias term in a trained decision_function object. + That is, this object lets you pick a point on the ROC curve and + it will adjust the bias term appropriately. + + So for example, suppose you wanted to set the bias term so that + the accuracy of your decision function on +1 labeled samples was 99%. + To do this you would use an instance of this object declared as follows: + roc_trainer_type(your_trainer, 0.99, +1); + !*/ + + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + roc_trainer_type ( + ); + /*! + ensures + - This object is in an uninitialized state. You must + construct a real one with the other constructor and assign it + to this instance before you use this object. + !*/ + + roc_trainer_type ( + const trainer_type& trainer_, + const scalar_type& desired_accuracy_, + const scalar_type& class_selection_ + ); + /*! + requires + - 0 <= desired_accuracy_ <= 1 + - class_selection_ == +1 or -1 + ensures + - when training is performed using this object it will automatically + adjust the bias term in the returned decision function so that it + achieves the desired accuracy on the selected class type. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const trained_function_type train ( + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels + ) const + /*! + requires + - is_binary_classification_problem(samples, labels) == true + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - performs training using the trainer object given to this object's + constructor, then modifies the bias term in the returned decision function + as discussed above, and finally returns the decision function. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const roc_trainer_type roc_c1_trainer ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& desired_accuracy + ) { return roc_trainer_type(trainer, desired_accuracy, +1); } + /*! + requires + - 0 <= desired_accuracy <= 1 + - trainer_type == some kind of batch trainer object that creates decision_function + objects (e.g. svm_nu_trainer) + ensures + - returns a roc_trainer_type object that has been instantiated with the given + arguments. The returned roc trainer will select the decision function + bias that gives the desired accuracy with respect to the +1 class. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const roc_trainer_type roc_c2_trainer ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& desired_accuracy + ) { return roc_trainer_type(trainer, desired_accuracy, -1); } + /*! + requires + - 0 <= desired_accuracy <= 1 + - trainer_type == some kind of batch trainer object that creates decision_function + objects (e.g. svm_nu_trainer) + ensures + - returns a roc_trainer_type object that has been instantiated with the given + arguments. The returned roc trainer will select the decision function + bias that gives the desired accuracy with respect to the -1 class. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ROC_TRAINEr_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/rr_trainer.h b/ml/dlib/dlib/svm/rr_trainer.h new file mode 100644 index 000000000..09177217e --- /dev/null +++ b/ml/dlib/dlib/svm/rr_trainer.h @@ -0,0 +1,456 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RR_TRAInER_Hh_ +#define DLIB_RR_TRAInER_Hh_ + +#include "../algs.h" +#include "function.h" +#include "kernel.h" +#include "empirical_kernel_map.h" +#include "linearly_independent_subset_finder.h" +#include "../statistics.h" +#include "rr_trainer_abstract.h" +#include +#include + +namespace dlib +{ + template < + typename K + > + class rr_trainer + { + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + // You are getting a compiler error on this line because you supplied a non-linear or + // sparse kernel to the rr_trainer object. You have to use dlib::linear_kernel with this trainer. + COMPILE_TIME_ASSERT((is_same_type >::value)); + + rr_trainer ( + ) : + verbose(false), + use_regression_loss(true), + lambda(0) + { + // default lambda search list + lams = matrix_cast(logspace(-9, 2, 50)); + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void use_regression_loss_for_loo_cv ( + ) + { + use_regression_loss = true; + } + + void use_classification_loss_for_loo_cv ( + ) + { + use_regression_loss = false; + } + + bool will_use_regression_loss_for_loo_cv ( + ) const + { + return use_regression_loss; + } + + const kernel_type get_kernel ( + ) const + { + return kernel_type(); + } + + void set_lambda ( + scalar_type lambda_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(lambda_ >= 0, + "\t void rr_trainer::set_lambda()" + << "\n\t lambda must be greater than or equal to 0" + << "\n\t lambda: " << lambda + << "\n\t this: " << this + ); + + lambda = lambda_; + } + + const scalar_type get_lambda ( + ) const + { + return lambda; + } + + template + void set_search_lambdas ( + const matrix_exp& lambdas + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(lambdas) && lambdas.size() > 0 && min(lambdas) > 0, + "\t void rr_trainer::set_search_lambdas()" + << "\n\t lambdas must be a non-empty vector of values" + << "\n\t is_vector(lambdas): " << is_vector(lambdas) + << "\n\t lambdas.size(): " << lambdas.size() + << "\n\t min(lambdas): " << min(lambdas) + << "\n\t this: " << this + ); + + + lams = matrix_cast(lambdas); + } + + const matrix& get_search_lambdas ( + ) const + { + return lams; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + std::vector temp; + scalar_type temp2; + return do_train(mat(x), mat(y), false, temp, temp2); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + std::vector& loo_values + ) const + { + scalar_type temp; + return do_train(mat(x), mat(y), true, loo_values, temp); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + std::vector& loo_values, + scalar_type& lambda_used + ) const + { + return do_train(mat(x), mat(y), true, loo_values, lambda_used); + } + + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + const bool output_loo_values, + std::vector& loo_values, + scalar_type& the_lambda + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,y), + "\t decision_function rr_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t is_vector(x): " << is_vector(x) + << "\n\t is_vector(y): " << is_vector(y) + << "\n\t x.size(): " << x.size() + << "\n\t y.size(): " << y.size() + ); + +#ifdef ENABLE_ASSERTS + if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y), + "\t decision_function rr_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + ); + } +#endif + + typedef matrix column_matrix_type; + typedef matrix general_matrix_type; + + const long dims = x(0).size(); + + /* + Notes on the solution of ridge regression + + Let A = an x.size() by dims matrix which contains all the data samples. + + Let I = an identity matrix + + Let C = trans(A)*A + Let L = trans(A)*y + + Then the optimal w is given by: + w = inv(C + lambda*I) * L + + + There is a trick to compute leave one out cross validation results for many different + lambda values quickly. The following paper has a detailed discussion of various + approaches: + + Notes on Regularized Least Squares by Ryan M. Rifkin and Ross A. Lippert. + + In the implementation of the rr_trainer I'm only using two simple equations + from the above paper. + + + First note that inv(C + lambda*I) can be computed for many different lambda + values in an efficient way by using an eigen decomposition of C. So we use + the fact that: + inv(C + lambda*I) == V*inv(D + lambda*I)*trans(V) + where V*D*trans(V) == C + + Also, via some simple linear algebra the above paper works out that the leave one out + value for a sample x(i) is equal to the following: + Let G = inv(C + lambda*I) + let val = trans(x(i))*G*x(i); + + leave one out value for sample x(i): + LOOV = (trans(w)*x(i) - y(i)*val) / (1 - val) + + leave one out error for sample x(i): + LOOE = loss(y(i), LOOV) + + + Finally, note that we will pretend there was a 1 appended to the end of each + vector in x. We won't actually do that though because we don't want to + have to make a copy of all the samples. So throughout the following code + I have explicitly dealt with this. + */ + + general_matrix_type C, tempm, G; + column_matrix_type L, tempv, w; + + // compute C and L + for (long i = 0; i < x.size(); ++i) + { + C += x(i)*trans(x(i)); + L += y(i)*x(i); + tempv += x(i); + } + + // Account for the extra 1 that we pretend is appended to x + // Make C = [C tempv + // tempv' x.size()] + C = join_cols(join_rows(C, tempv), + join_rows(trans(tempv), uniform_matrix(1,1, x.size()))); + L = join_cols(L, uniform_matrix(1,1, sum(y))); + + eigenvalue_decomposition eig(make_symmetric(C)); + const general_matrix_type V = eig.get_pseudo_v(); + const column_matrix_type D = eig.get_real_eigenvalues(); + + // We can save some work by pre-multiplying the x vectors by trans(V) + // and saving the result so we don't have to recompute it over and over later. + matrix Vx; + if (lambda == 0 || output_loo_values) + { + // Save the transpose of V into a temporary because the subsequent matrix + // vector multiplies will be faster (because of better cache locality). + const general_matrix_type transV( colm(trans(V),range(0,dims-1)) ); + // Remember the pretend 1 at the end of x(*). We want to multiply trans(V)*x(*) + // so to do this we pull the last column off trans(V) and store it separately. + const column_matrix_type lastV = colm(trans(V), dims); + Vx.set_size(x.size()); + for (long i = 0; i < x.size(); ++i) + { + Vx(i) = transV*x(i); + Vx(i) = squared(Vx(i) + lastV); + } + } + + the_lambda = lambda; + + // If we need to automatically select a lambda then do so using the LOOE trick described + // above. + bool did_loov = false; + scalar_type best_looe = std::numeric_limits::max(); + if (lambda == 0) + { + did_loov = true; + + // Compute leave one out errors for a bunch of different lambdas and pick the best one. + for (long idx = 0; idx < lams.size(); ++idx) + { + // first compute G + tempv = 1.0/(D + lams(idx)); + tempm = scale_columns(V,tempv); + G = tempm*trans(V); + + // compute the solution w for the current lambda + w = G*L; + + // make w have the same length as the x vectors. + const scalar_type b = w(dims); + w = colm(w,0,dims); + + scalar_type looe = 0; + for (long i = 0; i < x.size(); ++i) + { + // perform equivalent of: val = trans(x(i))*G*x(i); + const scalar_type val = dot(tempv, Vx(i)); + const scalar_type temp = (1 - val); + scalar_type loov; + if (temp != 0) + loov = (trans(w)*x(i) + b - y(i)*val) / temp; + else + loov = 0; + + looe += loss(loov, y(i)); + } + + // Keep track of the lambda which gave the lowest looe. If two lambdas + // have the same looe then pick the biggest lambda. + if (looe < best_looe || (looe == best_looe && lams(idx) > the_lambda)) + { + best_looe = looe; + the_lambda = lams(idx); + } + } + + best_looe /= x.size(); + } + + + + // Now perform the main training. That is, find w. + // first, compute G = inv(C + the_lambda*I) + tempv = 1.0/(D + the_lambda); + tempm = scale_columns(V,tempv); + G = tempm*trans(V); + w = G*L; + + // make w have the same length as the x vectors. + const scalar_type b = w(dims); + w = colm(w,0,dims); + + + // If we haven't done this already and we are supposed to then compute the LOO error rate for + // the current lambda and store the result in best_looe. + if (output_loo_values) + { + loo_values.resize(x.size()); + did_loov = true; + best_looe = 0; + for (long i = 0; i < x.size(); ++i) + { + // perform equivalent of: val = trans(x(i))*G*x(i); + const scalar_type val = dot(tempv, Vx(i)); + const scalar_type temp = (1 - val); + scalar_type loov; + if (temp != 0) + loov = (trans(w)*x(i) + b - y(i)*val) / temp; + else + loov = 0; + + best_looe += loss(loov, y(i)); + loo_values[i] = loov; + } + + best_looe /= x.size(); + + } + else + { + loo_values.clear(); + } + + if (verbose && did_loov) + { + using namespace std; + cout << "Using lambda: " << the_lambda << endl; + if (use_regression_loss) + cout << "LOO Mean Squared Error: " << best_looe << endl; + else + cout << "LOO Classification Error: " << best_looe << endl; + } + + // convert w into a proper decision function + decision_function df; + df.alpha.set_size(1); + df.alpha = 1; + df.basis_vectors.set_size(1); + df.basis_vectors(0) = w; + df.b = -b; // don't forget about the bias we stuck onto all the vectors + + return df; + } + + inline scalar_type loss ( + const scalar_type& a, + const scalar_type& b + ) const + { + if (use_regression_loss) + { + return (a-b)*(a-b); + } + else + { + // if a and b have the same sign then no loss + if (a*b >= 0) + return 0; + else + return 1; + } + } + + + /*! + CONVENTION + - get_lambda() == lambda + - get_kernel() == kernel_type() + - will_use_regression_loss_for_loo_cv() == use_regression_loss + - get_search_lambdas() == lams + !*/ + + bool verbose; + bool use_regression_loss; + + scalar_type lambda; + + matrix lams; + }; + +} + +#endif // DLIB_RR_TRAInER_Hh_ + + diff --git a/ml/dlib/dlib/svm/rr_trainer_abstract.h b/ml/dlib/dlib/svm/rr_trainer_abstract.h new file mode 100644 index 000000000..f2fe21068 --- /dev/null +++ b/ml/dlib/dlib/svm/rr_trainer_abstract.h @@ -0,0 +1,255 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RR_TRAInER_ABSTRACT_Hh_ +#ifdef DLIB_RR_TRAInER_ABSTRACT_Hh_ + +#include "../algs.h" +#include "function_abstract.h" + +namespace dlib +{ + template < + typename K + > + class rr_trainer + { + /*! + REQUIREMENTS ON K + is the dlib::linear_kernel instantiated with some kind of column vector. + + INITIAL VALUE + - get_lambda() == 0 + - will_use_regression_loss_for_loo_cv() == true + - get_search_lambdas() == logspace(-9, 2, 50) + - this object will not be verbose unless be_verbose() is called + + WHAT THIS OBJECT REPRESENTS + This object represents a tool for performing linear ridge regression + (This basic algorithm is also known my many other names, e.g. regularized + least squares or least squares SVM). + + The exact definition of what this algorithm does is this: + Find w and b that minimizes the following (x_i are input samples and y_i are target values): + lambda*dot(w,w) + sum_over_i( (f(x_i) - y_i)^2 ) + where f(x) == dot(x,w) - b + + So this algorithm is just regular old least squares regression but + with the addition of a regularization term which encourages small w. + + + It is capable of estimating the lambda parameter using leave-one-out cross-validation. + + + The leave-one-out cross-validation implementation is based on the techniques + discussed in this paper: + Notes on Regularized Least Squares by Ryan M. Rifkin and Ross A. Lippert. + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + rr_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + const kernel_type get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object. Since + the linear kernels don't have any parameters this function just + returns kernel_type() + !*/ + + void set_lambda ( + scalar_type lambda + ); + /*! + requires + - lambda >= 0 + ensures + - #get_lambda() == lambda + !*/ + + const scalar_type get_lambda ( + ) const; + /*! + ensures + - returns the regularization parameter. It is the parameter that + determines the trade off between trying to fit the training data + exactly or allowing more errors but hopefully improving the + generalization ability of the resulting function. Smaller values + encourage exact fitting while larger values of lambda may encourage + better generalization. + + Note that a lambda of 0 has a special meaning. It indicates to this + object that it should automatically determine an appropriate lambda + value. This is done using leave-one-out cross-validation. + !*/ + + void use_regression_loss_for_loo_cv ( + ); + /*! + ensures + - #will_use_regression_loss_for_loo_cv() == true + !*/ + + void use_classification_loss_for_loo_cv ( + ); + /*! + ensures + - #will_use_regression_loss_for_loo_cv() == false + !*/ + + bool will_use_regression_loss_for_loo_cv ( + ) const; + /*! + ensures + - returns true if the automatic lambda estimation will attempt to estimate a lambda + appropriate for a regression task. Otherwise it will try and find one which + minimizes the number of classification errors. + !*/ + + template + void set_search_lambdas ( + const matrix_exp& lambdas + ); + /*! + requires + - is_vector(lambdas) == true + - lambdas.size() > 0 + - min(lambdas) > 0 + - lambdas must contain floating point numbers + ensures + - #get_search_lambdas() == lambdas + !*/ + + const matrix& get_search_lambdas ( + ) const; + /*! + ensures + - returns a matrix M such that: + - is_vector(M) == true + - M == a list of all the lambda values which will be tried when performing + LOO cross-validation for determining the best lambda. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + - is_learning_problem(x,y) == true + - if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false) then + - is_binary_classification_problem(x,y) == true + (i.e. if you want this algorithm to estimate a lambda appropriate for + classification functions then you had better give a valid classification + problem) + ensures + - performs linear ridge regression given the training samples in x and target values in y. + - returns a decision_function F with the following properties: + - F(new_x) == predicted y value + - F.alpha.size() == 1 + - F.basis_vectors.size() == 1 + - F.alpha(0) == 1 + + - if (get_lambda() == 0) then + - This object will perform internal leave-one-out cross-validation to determine an + appropriate lambda automatically. It will compute the LOO error for each lambda + in get_search_lambdas() and select the best one. + - if (will_use_regression_loss_for_loo_cv()) then + - the lambda selected will be the one that minimizes the mean squared error. + - else + - the lambda selected will be the one that minimizes the number classification + mistakes. We say a point is classified correctly if the output of the + decision_function has the same sign as its label. + - #get_lambda() == 0 + (i.e. we don't change the get_lambda() value. If you want to know what the + automatically selected lambda value was then call the version of train() + defined below) + - else + - The user supplied value of get_lambda() will be used to perform the ridge regression. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + std::vector& loo_values + ) const; + /*! + requires + - all the requirements for train(x,y) must be satisfied + ensures + - returns train(x,y) + (i.e. executes train(x,y) and returns its result) + - #loo_values.size() == y.size() + - for all valid i: + - #loo_values[i] == leave-one-out prediction for the value of y(i) based + on all the training samples other than (x(i),y(i)). + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + std::vector& loo_values, + scalar_type& lambda_used + ) const; + /*! + requires + - all the requirements for train(x,y) must be satisfied + ensures + - returns train(x,y) + (i.e. executes train(x,y) and returns its result) + - #loo_values.size() == y.size() + - for all valid i: + - #loo_values[i] == leave-one-out prediction for the value of y(i) based + on all the training samples other than (x(i),y(i)). + - #lambda_used == the value of lambda used to generate the + decision_function. Note that this lambda value is always + equal to get_lambda() if get_lambda() isn't 0. + !*/ + + }; + +} + +#endif // DLIB_RR_TRAInER_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/rvm.h b/ml/dlib/dlib/svm/rvm.h new file mode 100644 index 000000000..e7ad495a2 --- /dev/null +++ b/ml/dlib/dlib/svm/rvm.h @@ -0,0 +1,1018 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RVm_ +#define DLIB_RVm_ + +#include "rvm_abstract.h" +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "function.h" +#include "kernel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace rvm_helpers + { + + // ------------------------------------------------------------------------------------ + + template + long find_next_best_alpha_to_update ( + const scalar_vector_type& S, + const scalar_vector_type& Q, + const scalar_vector_type& alpha, + const matrix& active_bases, + const bool search_all_alphas, + typename scalar_vector_type::type eps + ) + /*! + ensures + - if (we can find another alpha to update) then + - returns the index of said alpha + - else + - returns -1 + !*/ + { + typedef typename scalar_vector_type::type scalar_type; + // now use S and Q to find next alpha to update. What + // we want to do here is select the alpha to update that gives us + // the greatest improvement in marginal likelihood. + long selected_idx = -1; + scalar_type greatest_improvement = -1; + for (long i = 0; i < S.nr(); ++i) + { + scalar_type value = -1; + + // if i is currently in the active set + if (active_bases(i) >= 0) + { + const long idx = active_bases(i); + const scalar_type s = alpha(idx)*S(i)/(alpha(idx) - S(i)); + const scalar_type q = alpha(idx)*Q(i)/(alpha(idx) - S(i)); + + if (q*q-s > 0) + { + // only update an existing alpha if this is a narrow search + if (search_all_alphas == false) + { + // choosing this sample would mean doing an update of an + // existing alpha value. + scalar_type new_alpha = s*s/(q*q-s); + scalar_type cur_alpha = alpha(idx); + new_alpha = 1/new_alpha; + cur_alpha = 1/cur_alpha; + + // from equation 32 in the Tipping paper + value = Q(i)*Q(i)/(S(i) + 1/(new_alpha - cur_alpha) ) - + std::log(1 + S(i)*(new_alpha - cur_alpha)); + } + + } + // we only pick an alpha to remove if this is a wide search and it wasn't one of the recently added ones + else if (search_all_alphas && idx+2 < alpha.size() ) + { + // choosing this sample would mean the alpha value is infinite + // so we would remove the selected sample from our model. + + // from equation 37 in the Tipping paper + value = Q(i)*Q(i)/(S(i) - alpha(idx)) - + std::log(1-S(i)/alpha(idx)); + + } + } + else if (search_all_alphas) + { + const scalar_type s = S(i); + const scalar_type q = Q(i); + + if (q*q-s > 0) + { + // choosing this sample would mean we would add the selected + // sample to our model. + + // from equation 27 in the Tipping paper + value = (Q(i)*Q(i)-S(i))/S(i) + std::log(S(i)/(Q(i)*Q(i))); + } + } + + if (value > greatest_improvement) + { + greatest_improvement = value; + selected_idx = i; + } + } + + // If the greatest_improvement in marginal likelihood we would get is less + // than our epsilon then report that there isn't anything else to do. But + // if it is big enough then return the selected_idx. + if (greatest_improvement > eps) + return selected_idx; + else + return -1; + } + + } // end namespace rvm_helpers + + // ------------------------------------------------------------------------------------ + + + template < + typename kern_type + > + class rvm_trainer + { + /*! + This is an implementation of the binary classifier version of the + relevance vector machine algorithm described in the paper: + Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation + for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings + of the Ninth International Workshop on Artificial Intelligence and Statistics, + Key West, FL, Jan 3-6. + + This code mostly does what is described in the above paper with the exception + that here we use a different stopping condition as well as a modified alpha + selection rule. See the code for the exact details. + !*/ + + public: + typedef kern_type kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + rvm_trainer ( + ) : eps(0.001), max_iterations(2000) + { + } + + void set_max_iterations ( + unsigned long max_iterations_ + ) + { + max_iterations = max_iterations_; + } + + unsigned long get_max_iterations ( + ) const + { + return max_iterations; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\tvoid rvm_trainer::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps: " << eps_ + ); + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const + { + return eps; + } + + void set_kernel ( + const kernel_type& k + ) + { + kernel = k; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + return do_train(mat(x), mat(y)); + } + + void swap ( + rvm_trainer& item + ) + { + exchange(kernel, item.kernel); + exchange(eps, item.eps); + } + + private: + + // ------------------------------------------------------------------------------------ + + typedef matrix scalar_vector_type; + typedef matrix scalar_matrix_type; + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true, + "\tdecision_function rvm_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false") + ); + + // make a target vector where +1 examples have value 1 and -1 examples + // have a value of 0. + scalar_vector_type t(y.size()); + for (long i = 0; i < y.size(); ++i) + { + if (y(i) == 1) + t(i) = 1; + else + t(i) = 0; + } + + /*! This is the convention for the active_bases variable in the function: + - if (active_bases(i) >= 0) then + - alpha(active_bases(i)) == the alpha value associated with sample x(i) + - weights(active_bases(i)) == the weight value associated with sample x(i) + - colm(phi, active_bases(i)) == the column of phi associated with sample x(i) + - colm(phi, active_bases(i)) == kernel column i (from get_kernel_colum()) + - else + - the i'th sample isn't in the model and notionally has an alpha of infinity and + a weight of 0. + !*/ + matrix active_bases(x.nr()); + scalar_matrix_type phi(x.nr(),1); + scalar_vector_type alpha(1), prev_alpha; + scalar_vector_type weights(1), prev_weights; + + scalar_vector_type tempv, K_col; + + // set the initial values of these guys + set_all_elements(active_bases, -1); + long first_basis = pick_initial_vector(x,t); + get_kernel_colum(first_basis, x, K_col); + active_bases(first_basis) = 0; + set_colm(phi,0) = K_col; + alpha(0) = compute_initial_alpha(phi, t); + weights(0) = 1; + + + // now declare a bunch of other variables we will be using below + scalar_vector_type mu, t_hat, Q, S; + scalar_matrix_type sigma; + + matrix tempv2, tempv3; + scalar_matrix_type tempm; + + scalar_vector_type t_estimate; + scalar_vector_type beta; + + + Q.set_size(x.nr()); + S.set_size(x.nr()); + + bool recompute_beta = true; + + bool search_all_alphas = false; + unsigned long ticker = 0; + const unsigned long rounds_of_narrow_search = 100; + unsigned long iterations = 0; + + while (iterations != max_iterations) + { + iterations++; + if (recompute_beta) + { + // calculate the current t_estimate. (this is the predicted t value for each sample according to the + // current state of the classifier) + t_estimate = phi*weights; + + // calculate the current beta + beta = sigmoid(t_estimate); + beta = pointwise_multiply(beta,(uniform_matrix(beta.nr(),beta.nc(),1)-beta)); + recompute_beta = false; + } + + // Compute optimal weights and sigma for current alpha using IRLS. This is the same + // technique documented in the paper by equations 12-14. + scalar_type weight_delta = std::numeric_limits::max(); + int count = 0; + while (weight_delta > 0.0001) + { + // This is a sanity check to make sure we never get stuck in this + // loop to do some degenerate numerical condition + ++count; + if (count > 100) + { + // jump us to where search_all_alphas will be set to true + ticker = rounds_of_narrow_search; + break; + } + + // compute the updated sigma matrix + sigma = scale_columns(trans(phi),beta)*phi; + for (long r = 0; r < alpha.nr(); ++r) + sigma(r,r) += alpha(r); + sigma = inv(sigma); + + + // compute the updated weights vector (t_hat = phi*mu_mp + inv(B)*(t-y)) + t_hat = t_estimate + trans(scale_columns(trans(t-sigmoid(t_estimate)),reciprocal(beta))); + + // mu = sigma*trans(phi)*b*t_hat + mu = sigma*tmp(trans(phi)* trans(scale_columns(trans(t_hat), beta))); + + // now compute how much the weights vector changed during this iteration + // through this loop. + weight_delta = max(abs(mu-weights)); + + // put mu into the weights vector + mu.swap(weights); + + // calculate the current t_estimate + t_estimate = phi*weights; + + // calculate the current beta + beta = sigmoid(t_estimate); + beta = pointwise_multiply(beta, uniform_matrix(beta.nr(),beta.nc(),1)-beta); + + } + + // check if we should do a full search for the best alpha to optimize + if (ticker >= rounds_of_narrow_search) + { + // if the current alpha and weights are equal to what they were + // at the last time we were about to start a wide search then + // we are done. + if (equal(prev_alpha, alpha, eps) && equal(prev_weights, weights, eps)) + break; + + + prev_alpha = alpha; + prev_weights = weights; + search_all_alphas = true; + ticker = 0; + } + else + { + search_all_alphas = false; + } + ++ticker; + + // compute S and Q using equations 24 and 25 (tempv = phi*sigma*trans(phi)*B*t_hat) + tempv = phi*tmp(sigma*tmp(trans(phi)*trans(scale_columns(trans(t_hat),beta)))); + for (long i = 0; i < S.size(); ++i) + { + // if we are currently limiting the search for the next alpha to update + // to the set in the active set then skip a non-active vector. + if (search_all_alphas == false && active_bases(i) == -1) + continue; + + // get the column for this sample out of the kernel matrix. If it is + // something in the active set then just get it right out of phi, otherwise + // we have to calculate it. + if (active_bases(i) != -1) + K_col = colm(phi,active_bases(i)); + else + get_kernel_colum(i, x, K_col); + + // tempv2 = trans(phi_m)*B + tempv2 = scale_columns(trans(K_col), beta); + tempv3 = tempv2*phi; + S(i) = tempv2*K_col - tempv3*sigma*trans(tempv3); + Q(i) = tempv2*t_hat - tempv2*tempv; + } + + const long selected_idx = rvm_helpers::find_next_best_alpha_to_update(S,Q,alpha,active_bases, search_all_alphas, eps); + + + // if find_next_best_alpha_to_update didn't find any good alpha to update + if (selected_idx == -1) + { + if (search_all_alphas == false) + { + // jump us to where search_all_alphas will be set to true and try again + ticker = rounds_of_narrow_search; + continue; + } + else + { + // we are really done so quit the main loop + break; + } + } + + + // next we update the selected alpha. + + // if the selected alpha is in the active set + if (active_bases(selected_idx) >= 0) + { + const long idx = active_bases(selected_idx); + const scalar_type s = alpha(idx)*S(selected_idx)/(alpha(idx) - S(selected_idx)); + const scalar_type q = alpha(idx)*Q(selected_idx)/(alpha(idx) - S(selected_idx)); + + if (q*q-s > 0) + { + // reestimate the value of alpha + alpha(idx) = s*s/(q*q-s); + + } + else + { + // the new alpha value is infinite so remove the selected alpha from our model + active_bases(selected_idx) = -1; + phi = remove_col(phi, idx); + weights = remove_row(weights, idx); + alpha = remove_row(alpha, idx); + + // fix the index values in active_bases + for (long i = 0; i < active_bases.size(); ++i) + { + if (active_bases(i) > idx) + { + active_bases(i) -= 1; + } + } + + // we changed the number of weights so we need to remember to + // recompute the beta vector next time around the main loop. + recompute_beta = true; + } + } + else + { + const scalar_type s = S(selected_idx); + const scalar_type q = Q(selected_idx); + + if (q*q-s > 0) + { + // add the selected alpha to our model + + active_bases(selected_idx) = phi.nc(); + + // update alpha + tempv.set_size(alpha.size()+1); + set_subm(tempv, get_rect(alpha)) = alpha; + tempv(phi.nc()) = s*s/(q*q-s); + tempv.swap(alpha); + + // update weights + tempv.set_size(weights.size()+1); + set_subm(tempv, get_rect(weights)) = weights; + tempv(phi.nc()) = 0; + tempv.swap(weights); + + // update phi by adding the new sample's kernel matrix column in as one of phi's columns + tempm.set_size(phi.nr(), phi.nc()+1); + set_subm(tempm, get_rect(phi)) = phi; + get_kernel_colum(selected_idx, x, K_col); + set_colm(tempm, phi.nc()) = K_col; + tempm.swap(phi); + + + // we changed the number of weights so we need to remember to + // recompute the beta vector next time around the main loop. + recompute_beta = true; + } + } + + } // end while(true). So we have converged on the final answer. + + + // now put everything into a decision_function object and return it + std_vector_c dictionary; + std_vector_c final_weights; + for (long i = 0; i < active_bases.size(); ++i) + { + if (active_bases(i) >= 0) + { + dictionary.push_back(x(i)); + final_weights.push_back(weights(active_bases(i))); + } + } + + return decision_function ( mat(final_weights), + -sum(mat(final_weights))*tau, + kernel, + mat(dictionary)); + + } + + // ------------------------------------------------------------------------------------ + + template + long pick_initial_vector ( + const M1& x, + const M2& t + ) const + { + scalar_vector_type K_col; + double max_projection = -std::numeric_limits::infinity(); + long max_idx = 0; + // find the row in the kernel matrix that has the biggest normalized projection onto the t vector + for (long r = 0; r < x.nr(); ++r) + { + get_kernel_colum(r,x,K_col); + double temp = trans(K_col)*t; + temp = temp*temp/length_squared(K_col); + + if (temp > max_projection) + { + max_projection = temp; + max_idx = r; + } + } + + return max_idx; + } + + // ------------------------------------------------------------------------------------ + + template + void get_kernel_colum ( + long idx, + const T& x, + scalar_vector_type& col + ) const + { + col.set_size(x.nr()); + for (long i = 0; i < col.size(); ++i) + { + col(i) = kernel(x(idx), x(i)) + tau; + } + } + + // ------------------------------------------------------------------------------------ + + template + scalar_type compute_initial_alpha ( + const M1& phi, + const M2& t + ) const + { + const double temp = length_squared(phi); + const double temp2 = trans(phi)*t; + + return temp/( temp2*temp2/temp + variance(t)*0.1); + } + + // ------------------------------------------------------------------------------------ + + // private member variables + kernel_type kernel; + scalar_type eps; + unsigned long max_iterations; + + const static scalar_type tau; + + }; // end of class rvm_trainer + + template + const typename kernel_type::scalar_type rvm_trainer::tau = static_cast(0.001); + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + rvm_trainer& a, + rvm_trainer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename kern_type + > + class rvm_regression_trainer + { + /*! + This is an implementation of the regression version of the + relevance vector machine algorithm described in the paper: + Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation + for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings + of the Ninth International Workshop on Artificial Intelligence and Statistics, + Key West, FL, Jan 3-6. + + This code mostly does what is described in the above paper with the exception + that here we use a different stopping condition as well as a modified alpha + selection rule. See the code for the exact details. + !*/ + + public: + typedef kern_type kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + rvm_regression_trainer ( + ) : eps(0.001) + { + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\tvoid rvm_regression_trainer::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps: " << eps_ + ); + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const + { + return eps; + } + + void set_kernel ( + const kernel_type& k + ) + { + kernel = k; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& t + ) const + { + return do_train(mat(x), mat(t)); + } + + void swap ( + rvm_regression_trainer& item + ) + { + exchange(kernel, item.kernel); + exchange(eps, item.eps); + } + + private: + + // ------------------------------------------------------------------------------------ + + typedef matrix scalar_vector_type; + typedef matrix scalar_matrix_type; + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& t + ) const + { + + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,t) && x.size() > 0, + "\tdecision_function rvm_regression_trainer::train(x,t)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t t.nr(): " << t.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t t.nc(): " << t.nc() + ); + + + /*! This is the convention for the active_bases variable in the function: + - if (active_bases(i) >= 0) then + - alpha(active_bases(i)) == the alpha value associated with sample x(i) + - weights(active_bases(i)) == the weight value associated with sample x(i) + - colm(phi, active_bases(i)) == the column of phi associated with sample x(i) + - colm(phi, active_bases(i)) == kernel column i (from get_kernel_colum()) + - else + - the i'th sample isn't in the model and notionally has an alpha of infinity and + a weight of 0. + !*/ + matrix active_bases(x.nr()); + scalar_matrix_type phi(x.nr(),1); + scalar_vector_type alpha(1), prev_alpha; + scalar_vector_type weights(1), prev_weights; + + scalar_vector_type tempv, K_col; + scalar_type var = variance(t)*0.1; + + // set the initial values of these guys + set_all_elements(active_bases, -1); + long first_basis = pick_initial_vector(x,t); + get_kernel_colum(first_basis, x, K_col); + active_bases(first_basis) = 0; + set_colm(phi,0) = K_col; + alpha(0) = compute_initial_alpha(phi, t, var); + weights(0) = 1; + + + // now declare a bunch of other variables we will be using below + scalar_vector_type Q, S; + scalar_matrix_type sigma; + + matrix tempv2, tempv3; + scalar_matrix_type tempm; + + + Q.set_size(x.nr()); + S.set_size(x.nr()); + + + bool search_all_alphas = false; + unsigned long ticker = 0; + const unsigned long rounds_of_narrow_search = 100; + + while (true) + { + // Compute optimal weights and sigma for current alpha using equation 6. + sigma = trans(phi)*phi/var; + for (long r = 0; r < alpha.nr(); ++r) + sigma(r,r) += alpha(r); + sigma = inv(sigma); + weights = sigma*trans(phi)*t/var; + + + + // check if we should do a full search for the best alpha to optimize + if (ticker == rounds_of_narrow_search) + { + // if the current alpha and weights are equal to what they were + // at the last time we were about to start a wide search then + // we are done. + if (equal(prev_alpha, alpha, eps) && equal(prev_weights, weights, eps)) + break; + + prev_alpha = alpha; + prev_weights = weights; + search_all_alphas = true; + ticker = 0; + } + else + { + search_all_alphas = false; + } + ++ticker; + + // compute S and Q using equations 24 and 25 (tempv = phi*sigma*trans(phi)*B*t) + tempv = phi*tmp(sigma*tmp(trans(phi)*t/var)); + for (long i = 0; i < S.size(); ++i) + { + // if we are currently limiting the search for the next alpha to update + // to the set in the active set then skip a non-active vector. + if (search_all_alphas == false && active_bases(i) == -1) + continue; + + // get the column for this sample out of the kernel matrix. If it is + // something in the active set then just get it right out of phi, otherwise + // we have to calculate it. + if (active_bases(i) != -1) + K_col = colm(phi,active_bases(i)); + else + get_kernel_colum(i, x, K_col); + + // tempv2 = trans(phi_m)*B + tempv2 = trans(K_col)/var; + tempv3 = tempv2*phi; + S(i) = tempv2*K_col - tempv3*sigma*trans(tempv3); + Q(i) = tempv2*t - tempv2*tempv; + } + + const long selected_idx = rvm_helpers::find_next_best_alpha_to_update(S,Q,alpha,active_bases, search_all_alphas, eps); + + // if find_next_best_alpha_to_update didn't find any good alpha to update + if (selected_idx == -1) + { + if (search_all_alphas == false) + { + // jump us to where search_all_alphas will be set to true and try again + ticker = rounds_of_narrow_search; + continue; + } + else + { + // we are really done so quit the main loop + break; + } + } + + // recompute the variance + var = length_squared(t - phi*weights)/(x.nr() - weights.size() + trans(alpha)*diag(sigma)); + + // next we update the selected alpha. + + // if the selected alpha is in the active set + if (active_bases(selected_idx) >= 0) + { + const long idx = active_bases(selected_idx); + const scalar_type s = alpha(idx)*S(selected_idx)/(alpha(idx) - S(selected_idx)); + const scalar_type q = alpha(idx)*Q(selected_idx)/(alpha(idx) - S(selected_idx)); + + if (q*q-s > 0) + { + // reestimate the value of alpha + alpha(idx) = s*s/(q*q-s); + + } + else + { + // the new alpha value is infinite so remove the selected alpha from our model + active_bases(selected_idx) = -1; + phi = remove_col(phi, idx); + weights = remove_row(weights, idx); + alpha = remove_row(alpha, idx); + + // fix the index values in active_bases + for (long i = 0; i < active_bases.size(); ++i) + { + if (active_bases(i) > idx) + { + active_bases(i) -= 1; + } + } + } + } + else + { + const scalar_type s = S(selected_idx); + const scalar_type q = Q(selected_idx); + + if (q*q-s > 0) + { + // add the selected alpha to our model + + active_bases(selected_idx) = phi.nc(); + + // update alpha + tempv.set_size(alpha.size()+1); + set_subm(tempv, get_rect(alpha)) = alpha; + tempv(phi.nc()) = s*s/(q*q-s); + tempv.swap(alpha); + + // update weights + tempv.set_size(weights.size()+1); + set_subm(tempv, get_rect(weights)) = weights; + tempv(phi.nc()) = 0; + tempv.swap(weights); + + // update phi by adding the new sample's kernel matrix column in as one of phi's columns + tempm.set_size(phi.nr(), phi.nc()+1); + set_subm(tempm, get_rect(phi)) = phi; + get_kernel_colum(selected_idx, x, K_col); + set_colm(tempm, phi.nc()) = K_col; + tempm.swap(phi); + + } + } + + + + } // end while(true). So we have converged on the final answer. + + + // now put everything into a decision_function object and return it + std_vector_c dictionary; + std_vector_c final_weights; + for (long i = 0; i < active_bases.size(); ++i) + { + if (active_bases(i) >= 0) + { + dictionary.push_back(x(i)); + final_weights.push_back(weights(active_bases(i))); + } + } + + return decision_function ( mat(final_weights), + -sum(mat(final_weights))*tau, + kernel, + mat(dictionary)); + + } + + // ------------------------------------------------------------------------------------ + + template + void get_kernel_colum ( + long idx, + const T& x, + scalar_vector_type& col + ) const + { + col.set_size(x.nr()); + for (long i = 0; i < col.size(); ++i) + { + col(i) = kernel(x(idx), x(i)) + tau; + } + } + + // ------------------------------------------------------------------------------------ + + template + scalar_type compute_initial_alpha ( + const M1& phi, + const M2& t, + const scalar_type& var + ) const + { + const double temp = length_squared(phi); + const double temp2 = trans(phi)*t; + + return temp/( temp2*temp2/temp + var); + } + + // ------------------------------------------------------------------------------------ + + template + long pick_initial_vector ( + const M1& x, + const M2& t + ) const + { + scalar_vector_type K_col; + double max_projection = -std::numeric_limits::infinity(); + long max_idx = 0; + // find the row in the kernel matrix that has the biggest normalized projection onto the t vector + for (long r = 0; r < x.nr(); ++r) + { + get_kernel_colum(r,x,K_col); + double temp = trans(K_col)*t; + temp = temp*temp/length_squared(K_col); + + if (temp > max_projection) + { + max_projection = temp; + max_idx = r; + } + } + + return max_idx; + } + + // ------------------------------------------------------------------------------------ + + // private member variables + kernel_type kernel; + scalar_type eps; + + const static scalar_type tau; + + }; // end of class rvm_regression_trainer + + template + const typename kernel_type::scalar_type rvm_regression_trainer::tau = static_cast(0.001); + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + rvm_regression_trainer& a, + rvm_regression_trainer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RVm_ + + diff --git a/ml/dlib/dlib/svm/rvm_abstract.h b/ml/dlib/dlib/svm/rvm_abstract.h new file mode 100644 index 000000000..236d2ad3c --- /dev/null +++ b/ml/dlib/dlib/svm/rvm_abstract.h @@ -0,0 +1,278 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RVm_ABSTRACT_ +#ifdef DLIB_RVm_ABSTRACT_ + +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "function.h" +#include "kernel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename kern_type + > + class rvm_trainer + { + /*! + REQUIREMENTS ON kern_type + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object implements a trainer for a relevance vector machine for + solving binary classification problems. + + The implementation of the RVM training algorithm used by this object is based + on the following excellent paper: + Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation + for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings + of the Ninth International Workshop on Artificial Intelligence and Statistics, + Key West, FL, Jan 3-6. + !*/ + + public: + typedef kern_type kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + rvm_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used + to train a relevance vector machine. + - #get_epsilon() == 0.001 + - #get_max_iterations() == 2000 + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Generally a good value for this is 0.001. Smaller values may result + in a more accurate solution but take longer to execute. + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + unsigned long get_max_iterations ( + ) const; + /*! + ensures + - returns the maximum number of iterations the RVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - is_binary_classification_problem(x,y) == true + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - trains a relevance vector classifier given the training samples in x and + labels in y. + - returns a decision function F with the following properties: + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + throws + - std::bad_alloc + !*/ + + void swap ( + rvm_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + rvm_trainer& a, + rvm_trainer& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename kern_type + > + class rvm_regression_trainer + { + /*! + REQUIREMENTS ON kern_type + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object implements a trainer for a relevance vector machine for + solving regression problems. + + The implementation of the RVM training algorithm used by this object is based + on the following excellent paper: + Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation + for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings + of the Ninth International Workshop on Artificial Intelligence and Statistics, + Key West, FL, Jan 3-6. + !*/ + + public: + typedef kern_type kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + rvm_regression_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used + to train a relevance vector machine. + - #get_epsilon() == 0.001 + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Generally a good value for this is 0.001. Smaller values may result + in a more accurate solution but take longer to execute. + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + - is_learning_problem(x,y) == true + - x.size() > 0 + ensures + - trains a RVM given the training samples in x and + labels in y and returns the resulting decision_function. + throws + - std::bad_alloc + !*/ + + void swap ( + rvm_regression_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + rvm_regression_trainer& a, + rvm_regression_trainer& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RVm_ABSTRACT_ + diff --git a/ml/dlib/dlib/svm/sequence_labeler.h b/ml/dlib/dlib/svm/sequence_labeler.h new file mode 100644 index 000000000..882cdb881 --- /dev/null +++ b/ml/dlib/dlib/svm/sequence_labeler.h @@ -0,0 +1,339 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEQUENCE_LAbELER_H_h_ +#define DLIB_SEQUENCE_LAbELER_H_h_ + +#include "sequence_labeler_abstract.h" +#include "../matrix.h" +#include +#include "../optimization/find_max_factor_graph_viterbi.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace fe_helpers + { + template + struct dot_functor + { + dot_functor(const matrix_exp& lambda_) : lambda(lambda_), value(0) {} + + inline void operator() ( + unsigned long feat_index + ) + { + value += lambda(feat_index); + } + + inline void operator() ( + unsigned long feat_index, + double feat_value + ) + { + value += feat_value*lambda(feat_index); + } + + const matrix_exp& lambda; + double value; + }; + + template + double dot( + const matrix_exp& lambda, + const feature_extractor& fe, + const sequence_type& sequence, + const matrix_exp& candidate_labeling, + unsigned long position + ) + { + dot_functor dot(lambda); + fe.get_features(dot, sequence, candidate_labeling, position); + return dot.value; + } + + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST( + has_reject_labeling, + bool, + template reject_labeling >, + (const typename T::sequence_type&, const matrix_exp >&, unsigned long)const + ); + + template + typename enable_if,bool>::type call_reject_labeling_if_exists ( + const feature_extractor& fe, + const sequence_type& x, + const matrix_exp& y, + unsigned long position + ) + { + return fe.reject_labeling(x, y, position); + } + + template + typename disable_if,bool>::type call_reject_labeling_if_exists ( + const feature_extractor& , + const sequence_type& , + const matrix_exp& , + unsigned long + ) + { + return false; + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + typename enable_if,bool>::type contains_invalid_labeling ( + const feature_extractor& fe, + const typename feature_extractor::sequence_type& x, + const std::vector& y + ) + { + if (x.size() != y.size()) + return true; + + matrix node_states; + + for (unsigned long i = 0; i < x.size(); ++i) + { + node_states.set_size(std::min(fe.order(),i) + 1); + for (unsigned long j = 0; j < (unsigned long)node_states.size(); ++j) + node_states(j) = y[i-j]; + + if (fe.reject_labeling(x, node_states, i)) + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + typename disable_if,bool>::type contains_invalid_labeling ( + const feature_extractor& , + const typename feature_extractor::sequence_type& x, + const std::vector& y + ) + { + if (x.size() != y.size()) + return true; + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + bool contains_invalid_labeling ( + const feature_extractor& fe, + const std::vector& x, + const std::vector >& y + ) + { + if (x.size() != y.size()) + return true; + + for (unsigned long i = 0; i < x.size(); ++i) + { + if (contains_invalid_labeling(fe,x[i],y[i])) + return true; + } + return false; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class sequence_labeler + { + public: + typedef typename feature_extractor::sequence_type sample_sequence_type; + typedef std::vector labeled_sequence_type; + + private: + class map_prob + { + public: + unsigned long order() const { return fe.order(); } + unsigned long num_states() const { return fe.num_labels(); } + + map_prob( + const sample_sequence_type& x_, + const feature_extractor& fe_, + const matrix& weights_ + ) : + sequence(x_), + fe(fe_), + weights(weights_) + { + } + + unsigned long number_of_nodes( + ) const + { + return sequence.size(); + } + + template < + typename EXP + > + double factor_value ( + unsigned long node_id, + const matrix_exp& node_states + ) const + { + if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id)) + return -std::numeric_limits::infinity(); + + return fe_helpers::dot(weights, fe, sequence, node_states, node_id); + } + + const sample_sequence_type& sequence; + const feature_extractor& fe; + const matrix& weights; + }; + public: + + sequence_labeler() + { + weights.set_size(fe.num_features()); + weights = 0; + } + + explicit sequence_labeler( + const matrix& weights_ + ) : + weights(weights_) + { + // make sure requires clause is not broken + DLIB_ASSERT(fe.num_features() == static_cast(weights_.size()), + "\t sequence_labeler::sequence_labeler(weights_)" + << "\n\t These sizes should match" + << "\n\t fe.num_features(): " << fe.num_features() + << "\n\t weights_.size(): " << weights_.size() + << "\n\t this: " << this + ); + } + + sequence_labeler( + const matrix& weights_, + const feature_extractor& fe_ + ) : + fe(fe_), + weights(weights_) + { + // make sure requires clause is not broken + DLIB_ASSERT(fe_.num_features() == static_cast(weights_.size()), + "\t sequence_labeler::sequence_labeler(weights_,fe_)" + << "\n\t These sizes should match" + << "\n\t fe_.num_features(): " << fe_.num_features() + << "\n\t weights_.size(): " << weights_.size() + << "\n\t this: " << this + ); + } + + const feature_extractor& get_feature_extractor ( + ) const { return fe; } + + const matrix& get_weights ( + ) const { return weights; } + + unsigned long num_labels ( + ) const { return fe.num_labels(); } + + labeled_sequence_type operator() ( + const sample_sequence_type& x + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(num_labels() > 0, + "\t labeled_sequence_type sequence_labeler::operator()(x)" + << "\n\t You can't have no labels." + << "\n\t this: " << this + ); + + labeled_sequence_type y; + find_max_factor_graph_viterbi(map_prob(x,fe,weights), y); + return y; + } + + void label_sequence ( + const sample_sequence_type& x, + labeled_sequence_type& y + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(num_labels() > 0, + "\t void sequence_labeler::label_sequence(x,y)" + << "\n\t You can't have no labels." + << "\n\t this: " << this + ); + + find_max_factor_graph_viterbi(map_prob(x,fe,weights), y); + } + + private: + + feature_extractor fe; + matrix weights; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void serialize ( + const sequence_labeler& item, + std::ostream& out + ) + { + serialize(item.get_feature_extractor(), out); + serialize(item.get_weights(), out); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void deserialize ( + sequence_labeler& item, + std::istream& in + ) + { + feature_extractor fe; + matrix weights; + + deserialize(fe, in); + deserialize(weights, in); + + item = sequence_labeler(weights, fe); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEQUENCE_LAbELER_H_h_ + diff --git a/ml/dlib/dlib/svm/sequence_labeler_abstract.h b/ml/dlib/dlib/svm/sequence_labeler_abstract.h new file mode 100644 index 000000000..3970b723a --- /dev/null +++ b/ml/dlib/dlib/svm/sequence_labeler_abstract.h @@ -0,0 +1,396 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SEQUENCE_LAbELER_ABSTRACT_H_h_ +#ifdef DLIB_SEQUENCE_LAbELER_ABSTRACT_H_h_ + +#include "../matrix.h" +#include +#include "../optimization/find_max_factor_graph_viterbi_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class example_feature_extractor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines the interface a feature extractor must implement + if it is to be used with the sequence_labeler defined at the bottom + of this file. + + The model used by sequence_labeler objects is the following. + Given an input sequence x, predict an output label sequence y + such that: + y == argmax_Y dot(w, PSI(x,Y)) + Where w is a parameter vector. + + Therefore, a feature extractor defines how the PSI(x,y) feature vector + is calculated. It also defines how many output labels there are as + well as the order of the model. + + Finally, note that PSI(x,y) is a sum of feature vectors, each derived + from the entire input sequence x but only part of the label sequence y. + Each of these constituent feature vectors is defined by the get_features() + method of this class. + + THREAD SAFETY + Instances of this object are required to be threadsafe, that is, it should + be safe for multiple threads to make concurrent calls to the member + functions of this object. + !*/ + + public: + // This should be the type used to represent an input sequence. It can be + // anything so long as it has a .size() which returns the length of the sequence. + typedef the_type_used_to_represent_a_sequence sequence_type; + + example_feature_extractor ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + unsigned long num_features ( + ) const; + /*! + ensures + - returns the dimensionality of the PSI() feature vector. + !*/ + + unsigned long order( + ) const; + /*! + ensures + - This object represents a Markov model on the output labels. + This parameter defines the order of the model. That is, this + value controls how many previous label values get to be taken + into consideration when performing feature extraction for a + particular element of the input sequence. Note that the runtime + of the algorithm is exponential in the order. So don't make order + very large. + !*/ + + unsigned long num_labels( + ) const; + /*! + ensures + - returns the number of possible output labels. + !*/ + + template + bool reject_labeling ( + const sequence_type& x, + const matrix_exp& y, + unsigned long position + ) const; + /*! + requires + - EXP::type == unsigned long + (i.e. y contains unsigned longs) + - position < x.size() + - y.size() == min(position, order()) + 1 + - is_vector(y) == true + - max(y) < num_labels() + ensures + - for all valid i: + - interprets y(i) as the label corresponding to x[position-i] + - if (the labeling in y for x[position] is always the wrong labeling) then + - returns true + (note that reject_labeling() is just an optional tool to allow you + to overrule the normal labeling algorithm. You don't have to use + it. So if you don't include a reject_labeling() method in your + feature extractor it is the same as including one that always + returns false.) + - else + - returns false + !*/ + + template + void get_features ( + feature_setter& set_feature, + const sequence_type& x, + const matrix_exp& y, + unsigned long position + ) const; + /*! + requires + - EXP::type == unsigned long + (i.e. y contains unsigned longs) + - reject_labeling(x,y,position) == false + - position < x.size() + - y.size() == min(position, order()) + 1 + - is_vector(y) == true + - max(y) < num_labels() + - set_feature is a function object which allows expressions of the form: + - set_features((unsigned long)feature_index, (double)feature_value); + - set_features((unsigned long)feature_index); + ensures + - for all valid i: + - interprets y(i) as the label corresponding to x[position-i] + - This function computes the part of PSI() corresponding to the x[position] + element of the input sequence. Moreover, this part of PSI() is returned as + a sparse vector by invoking set_feature(). For example, to set the feature + with an index of 55 to the value of 1 this method would call: + set_feature(55); + Or equivalently: + set_feature(55,1); + Therefore, the first argument to set_feature is the index of the feature + to be set while the second argument is the value the feature should take. + Additionally, note that calling set_feature() multiple times with the same + feature index does NOT overwrite the old value, it adds to the previous + value. For example, if you call set_feature(55) 3 times then it will + result in feature 55 having a value of 3. + - This function only calls set_feature() with feature_index values < num_features() + !*/ + + unsigned long num_nonnegative_weights ( + ) const; + /*! + ensures + - returns the number of elements of the w parameter vector which should be + non-negative. That is, this feature extractor is intended to be used + with w vectors where the first num_nonnegative_weights() elements of w + are >= 0. That is, it should be the case that w(i) >= 0 for all i < + num_nonnegative_weights(). + - Note that num_nonnegative_weights() is just an optional method to allow + you to tell a tool like the structural_sequence_labeling_trainer that the + learned w should have a certain number of non-negative elements. + Therefore, if you do not provide a num_nonnegative_weights() method in + your feature extractor then it will default to a value of 0, indicating + that all elements of the w parameter vector may be any value. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + void serialize( + const example_feature_extractor& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize( + example_feature_extractor& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + bool contains_invalid_labeling ( + const feature_extractor& fe, + const typename feature_extractor::sequence_type& x, + const std::vector& y + ); + /*! + requires + - feature_extractor must be an object that implements an interface compatible + with the example_feature_extractor discussed above. + ensures + - if (x.size() != y.size() || + fe.reject_labeling() rejects any of the labels in y) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + bool contains_invalid_labeling ( + const feature_extractor& fe, + const std::vector& x, + const std::vector >& y + ); + /*! + requires + - feature_extractor must be an object that implements an interface compatible + with the example_feature_extractor discussed above. + ensures + - if (x.size() != y.size() || + contains_invalid_labeling(fe,x[i],y[i]) == true for some i ) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class sequence_labeler + { + /*! + REQUIREMENTS ON feature_extractor + It must be an object that implements an interface compatible with + the example_feature_extractor discussed above. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for doing sequence labeling. In particular, it is + capable of representing sequence labeling models such as those produced by + Hidden Markov SVMs or Chain Structured Conditional Random fields. See the + following papers for an introduction to these techniques: + - Hidden Markov Support Vector Machines by + Y. Altun, I. Tsochantaridis, T. Hofmann + - Shallow Parsing with Conditional Random Fields by + Fei Sha and Fernando Pereira + + + The model used by this object is the following. Given + an input sequence x, predict an output label sequence y + such that: + y == argmax_Y dot(get_weights(), PSI(x,Y)) + Where PSI() is defined by the feature_extractor template + argument. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call the const members of this object from multiple + threads so long as the feature_extractor is also threadsafe. This is + because the const members are purely read-only operations. However, + any operation that modifies a sequence_labeler is not threadsafe. + !*/ + + public: + typedef typename feature_extractor::sequence_type sample_sequence_type; + typedef std::vector labeled_sequence_type; + + sequence_labeler( + ); + /*! + ensures + - #get_feature_extractor() == feature_extractor() + (i.e. it will have its default value) + - #get_weights().size() == #get_feature_extractor().num_features() + - #get_weights() == 0 + !*/ + + explicit sequence_labeler( + const matrix& weights + ); + /*! + requires + - feature_extractor().num_features() == weights.size() + ensures + - #get_feature_extractor() == feature_extractor() + (i.e. it will have its default value) + - #get_weights() == weights + !*/ + + sequence_labeler( + const matrix& weights, + const feature_extractor& fe + ); + /*! + requires + - fe.num_features() == weights.size() + ensures + - #get_feature_extractor() == fe + - #get_weights() == weights + !*/ + + const feature_extractor& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used by this object + !*/ + + const matrix& get_weights ( + ) const; + /*! + ensures + - returns the parameter vector associated with this sequence labeler. + The length of the vector is get_feature_extractor().num_features(). + !*/ + + unsigned long num_labels ( + ) const; + /*! + ensures + - returns get_feature_extractor().num_labels() + (i.e. returns the number of possible output labels for each + element of a sequence) + !*/ + + labeled_sequence_type operator() ( + const sample_sequence_type& x + ) const; + /*! + requires + - num_labels() > 0 + ensures + - returns a vector Y of label values such that: + - Y.size() == x.size() + - for all valid i: + - Y[i] == the predicted label for x[i] + - 0 <= Y[i] < num_labels() + !*/ + + void label_sequence ( + const sample_sequence_type& x, + labeled_sequence_type& y + ) const; + /*! + requires + - num_labels() > 0 + ensures + - #y == (*this)(x) + (i.e. This is just another interface to the operator() routine + above. This one avoids returning the results by value and therefore + might be a little faster in some cases) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void serialize ( + const sequence_labeler& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void deserialize ( + sequence_labeler& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEQUENCE_LAbELER_ABSTRACT_H_h_ + + diff --git a/ml/dlib/dlib/svm/sequence_segmenter.h b/ml/dlib/dlib/svm/sequence_segmenter.h new file mode 100644 index 000000000..237023efa --- /dev/null +++ b/ml/dlib/dlib/svm/sequence_segmenter.h @@ -0,0 +1,468 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SEQUENCE_SeGMENTER_H_h_ +#define DLIB_SEQUENCE_SeGMENTER_H_h_ + +#include "sequence_segmenter_abstract.h" +#include "../matrix.h" +#include "sequence_labeler.h" +#include + +namespace dlib +{ + // This namespace contains implementation details for the sequence_segmenter. + namespace impl_ss + { + + // ------------------------------------------------------------------------------------ + + // BIO/BILOU labels + const unsigned int BEGIN = 0; + const unsigned int INSIDE = 1; + const unsigned int OUTSIDE = 2; + const unsigned int LAST = 3; + const unsigned int UNIT = 4; + + + // ------------------------------------------------------------------------------------ + + template + class feature_extractor + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a feature extractor for a sequence_labeler. It serves to map + the interface defined by a sequence_labeler into the kind of interface + defined for a sequence_segmenter. + !*/ + + public: + typedef typename ss_feature_extractor::sequence_type sequence_type; + + ss_feature_extractor fe; + + feature_extractor() {} + feature_extractor(const ss_feature_extractor& ss_fe_) : fe(ss_fe_) {} + + unsigned long num_nonnegative_weights ( + ) const + { + const unsigned long NL = ss_feature_extractor::use_BIO_model ? 3 : 5; + if (ss_feature_extractor::allow_negative_weights) + { + return 0; + } + else + { + // We make everything non-negative except for the label transition + // and bias features. + return num_features() - NL*NL - NL; + } + } + + friend void serialize(const feature_extractor& item, std::ostream& out) + { + serialize(item.fe, out); + } + + friend void deserialize(feature_extractor& item, std::istream& in) + { + deserialize(item.fe, in); + } + + unsigned long num_features() const + { + const unsigned long NL = ss_feature_extractor::use_BIO_model ? 3 : 5; + if (ss_feature_extractor::use_high_order_features) + return NL + NL*NL + (NL*NL+NL)*fe.num_features()*fe.window_size(); + else + return NL + NL*NL + NL*fe.num_features()*fe.window_size(); + } + + unsigned long order() const + { + return 1; + } + + unsigned long num_labels() const + { + if (ss_feature_extractor::use_BIO_model) + return 3; + else + return 5; + } + + private: + + template + struct dot_functor + { + /*! + WHAT THIS OBJECT REPRESENTS + This class wraps the feature_setter used by a sequence_labeler + and turns it into the kind needed by a sequence_segmenter. + !*/ + + dot_functor(feature_setter& set_feature_, unsigned long offset_) : + set_feature(set_feature_), offset(offset_) {} + + feature_setter& set_feature; + unsigned long offset; + + inline void operator() ( + unsigned long feat_index + ) + { + set_feature(offset+feat_index); + } + + inline void operator() ( + unsigned long feat_index, + double feat_value + ) + { + set_feature(offset+feat_index, feat_value); + } + }; + + public: + + template + bool reject_labeling ( + const sequence_type& x, + const matrix_exp& y, + unsigned long pos + ) const + { + if (ss_feature_extractor::use_BIO_model) + { + // Don't allow BIO label patterns that don't correspond to a sensical + // segmentation. + if (y.size() > 1 && y(0) == INSIDE && y(1) == OUTSIDE) + return true; + if (y.size() == 1 && y(0) == INSIDE) + return true; + } + else + { + // Don't allow BILOU label patterns that don't correspond to a sensical + // segmentation. + if (y.size() > 1) + { + if (y(1) == BEGIN && y(0) == OUTSIDE) + return true; + if (y(1) == BEGIN && y(0) == UNIT) + return true; + if (y(1) == BEGIN && y(0) == BEGIN) + return true; + + if (y(1) == INSIDE && y(0) == BEGIN) + return true; + if (y(1) == INSIDE && y(0) == OUTSIDE) + return true; + if (y(1) == INSIDE && y(0) == UNIT) + return true; + + if (y(1) == OUTSIDE && y(0) == INSIDE) + return true; + if (y(1) == OUTSIDE && y(0) == LAST) + return true; + + if (y(1) == LAST && y(0) == INSIDE) + return true; + if (y(1) == LAST && y(0) == LAST) + return true; + + if (y(1) == UNIT && y(0) == INSIDE) + return true; + if (y(1) == UNIT && y(0) == LAST) + return true; + + // if at the end of the sequence + if (pos == x.size()-1) + { + if (y(0) == BEGIN) + return true; + if (y(0) == INSIDE) + return true; + } + } + else + { + if (y(0) == INSIDE) + return true; + if (y(0) == LAST) + return true; + + // if at the end of the sequence + if (pos == x.size()-1) + { + if (y(0) == BEGIN) + return true; + } + } + } + return false; + } + + template + void get_features ( + feature_setter& set_feature, + const sequence_type& x, + const matrix_exp& y, + unsigned long position + ) const + { + unsigned long offset = 0; + + const int window_size = fe.window_size(); + + const int base_dims = fe.num_features(); + for (int i = 0; i < window_size; ++i) + { + const long pos = i-window_size/2 + static_cast(position); + if (0 <= pos && pos < (long)x.size()) + { + const unsigned long off1 = y(0)*base_dims; + dot_functor fs1(set_feature, offset+off1); + fe.get_features(fs1, x, pos); + + if (ss_feature_extractor::use_high_order_features && y.size() > 1) + { + const unsigned long off2 = num_labels()*base_dims + (y(0)*num_labels()+y(1))*base_dims; + dot_functor fs2(set_feature, offset+off2); + fe.get_features(fs2, x, pos); + } + } + + if (ss_feature_extractor::use_high_order_features) + offset += num_labels()*base_dims + num_labels()*num_labels()*base_dims; + else + offset += num_labels()*base_dims; + } + + // Pull out an indicator feature for the type of transition between the + // previous label and the current label. + if (y.size() > 1) + set_feature(offset + y(1)*num_labels() + y(0)); + + offset += num_labels()*num_labels(); + // pull out an indicator feature for the current label. This is the per + // label bias. + set_feature(offset + y(0)); + } + }; + + } // end namespace impl_ss + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + unsigned long total_feature_vector_size ( + const feature_extractor& fe + ) + { + const unsigned long NL = feature_extractor::use_BIO_model ? 3 : 5; + if (feature_extractor::use_high_order_features) + return NL + NL*NL + (NL*NL+NL)*fe.num_features()*fe.window_size(); + else + return NL + NL*NL + NL*fe.num_features()*fe.window_size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class sequence_segmenter + { + public: + typedef typename feature_extractor::sequence_type sample_sequence_type; + typedef std::vector > segmented_sequence_type; + + + sequence_segmenter() + { +#ifdef ENABLE_ASSERTS + const feature_extractor& fe = labeler.get_feature_extractor().fe; + DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1, + "\t sequence_segmenter::sequence_segmenter()" + << "\n\t An invalid feature extractor was supplied." + << "\n\t fe.window_size(): " << fe.window_size() + << "\n\t fe.num_features(): " << fe.num_features() + << "\n\t this: " << this + ); +#endif + } + + explicit sequence_segmenter( + const matrix& weights + ) : + labeler(weights) + { +#ifdef ENABLE_ASSERTS + const feature_extractor& fe = labeler.get_feature_extractor().fe; + // make sure requires clause is not broken + DLIB_ASSERT(total_feature_vector_size(fe) == (unsigned long)weights.size(), + "\t sequence_segmenter::sequence_segmenter(weights)" + << "\n\t These sizes should match" + << "\n\t total_feature_vector_size(fe): " << total_feature_vector_size(fe) + << "\n\t weights.size(): " << weights.size() + << "\n\t this: " << this + ); + DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1, + "\t sequence_segmenter::sequence_segmenter()" + << "\n\t An invalid feature extractor was supplied." + << "\n\t fe.window_size(): " << fe.window_size() + << "\n\t fe.num_features(): " << fe.num_features() + << "\n\t this: " << this + ); +#endif + } + + sequence_segmenter( + const matrix& weights, + const feature_extractor& fe + ) : + labeler(weights, impl_ss::feature_extractor(fe)) + { + // make sure requires clause is not broken + DLIB_ASSERT(total_feature_vector_size(fe) == (unsigned long)weights.size(), + "\t sequence_segmenter::sequence_segmenter(weights,fe)" + << "\n\t These sizes should match" + << "\n\t total_feature_vector_size(fe): " << total_feature_vector_size(fe) + << "\n\t weights.size(): " << weights.size() + << "\n\t this: " << this + ); + DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1, + "\t sequence_segmenter::sequence_segmenter()" + << "\n\t An invalid feature extractor was supplied." + << "\n\t fe.window_size(): " << fe.window_size() + << "\n\t fe.num_features(): " << fe.num_features() + << "\n\t this: " << this + ); + } + + const feature_extractor& get_feature_extractor ( + ) const { return labeler.get_feature_extractor().fe; } + + const matrix& get_weights ( + ) const { return labeler.get_weights(); } + + segmented_sequence_type operator() ( + const sample_sequence_type& x + ) const + { + segmented_sequence_type y; + segment_sequence(x,y); + return y; + } + + void segment_sequence ( + const sample_sequence_type& x, + segmented_sequence_type& y + ) const + { + y.clear(); + std::vector labels; + labeler.label_sequence(x, labels); + + if (feature_extractor::use_BIO_model) + { + // Convert from BIO tagging to the explicit segments representation. + for (unsigned long i = 0; i < labels.size(); ++i) + { + if (labels[i] == impl_ss::BEGIN) + { + const unsigned long begin = i; + ++i; + while (i < labels.size() && labels[i] == impl_ss::INSIDE) + ++i; + + y.push_back(std::make_pair(begin, i)); + --i; + } + } + } + else + { + // Convert from BILOU tagging to the explicit segments representation. + for (unsigned long i = 0; i < labels.size(); ++i) + { + if (labels[i] == impl_ss::BEGIN) + { + const unsigned long begin = i; + ++i; + while (i < labels.size() && labels[i] == impl_ss::INSIDE) + ++i; + + y.push_back(std::make_pair(begin, i+1)); + } + else if (labels[i] == impl_ss::UNIT) + { + y.push_back(std::make_pair(i, i+1)); + } + } + } + } + + friend void serialize(const sequence_segmenter& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + + // Save these just so we can compare them when we deserialize and make + // sure the feature_extractor being used is compatible with the model being + // loaded. + serialize(feature_extractor::use_BIO_model, out); + serialize(feature_extractor::use_high_order_features, out); + serialize(total_feature_vector_size(item.get_feature_extractor()), out); + + serialize(item.labeler, out); + } + + friend void deserialize(sequence_segmenter& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::sequence_segmenter."); + + // Try to check if the saved model is compatible with the current feature + // extractor. + bool use_BIO_model, use_high_order_features; + unsigned long dims; + deserialize(use_BIO_model, in); + deserialize(use_high_order_features, in); + deserialize(dims, in); + deserialize(item.labeler, in); + if (use_BIO_model != feature_extractor::use_BIO_model) + { + throw serialization_error("Incompatible feature extractor found while deserializing " + "dlib::sequence_segmenter. Wrong value of use_BIO_model."); + } + if (use_high_order_features != feature_extractor::use_high_order_features) + { + throw serialization_error("Incompatible feature extractor found while deserializing " + "dlib::sequence_segmenter. Wrong value of use_high_order_features."); + } + if (dims != total_feature_vector_size(item.get_feature_extractor())) + { + throw serialization_error("Incompatible feature extractor found while deserializing " + "dlib::sequence_segmenter. Wrong value of total_feature_vector_size()."); + } + } + + private: + sequence_labeler > labeler; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEQUENCE_SeGMENTER_H_h_ + + diff --git a/ml/dlib/dlib/svm/sequence_segmenter_abstract.h b/ml/dlib/dlib/svm/sequence_segmenter_abstract.h new file mode 100644 index 000000000..7229fee22 --- /dev/null +++ b/ml/dlib/dlib/svm/sequence_segmenter_abstract.h @@ -0,0 +1,452 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SEQUENCE_SeGMENTER_ABSTRACT_H_h_ +#ifdef DLIB_SEQUENCE_SeGMENTER_ABSTRACT_H_h_ + +#include "../matrix.h" +#include +#include "sequence_labeler_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class example_feature_extractor + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines the interface a feature extractor must implement if it + is to be used with the sequence_segmenter defined at the bottom of this + file. + + The model used by sequence_segmenter objects is the following. Given an + input sequence x, predict an output label sequence y such that: + y == argmax_Y dot(w, PSI(x,Y)) + Where w is a parameter vector and the label sequence defines a segmentation + of x. + + Recall that a sequence_segmenter uses the BIO or BILOU tagging model and is + also an instantiation of the dlib::sequence_labeler. Selecting to use the + BIO model means that each element of the label sequence y takes on one of + three possible values (B, I, or O) and together these labels define a + segmentation of the sequence. For example, to represent a segmentation of + the sequence of words "The dog ran to Bob Jones" where only "Bob Jones" was + segmented out we would use the label sequence OOOOBI. The BILOU model is + similar except that it uses five different labels and each segment is + labeled as U, BL, BIL, BIIL, BIIIL, and so on depending on its length. + Therefore, the BILOU model is able to more explicitly model the ends of the + segments than the BIO model, but has more parameters to estimate. + + Keeping all this in mind, the purpose of a sequence_segmenter is to take + care of the bookkeeping associated with creating BIO/BILOU tagging models + for segmentation tasks. In particular, it presents the user with a + simplified version of the interface used by the dlib::sequence_labeler. It + does this by completely hiding the BIO/BILOU tags from the user and instead + exposes an explicit sub-segment based labeling representation. It also + simplifies the construction of the PSI() feature vector. + + Like in the dlib::sequence_labeler, PSI() is a sum of feature vectors, each + derived from the entire input sequence x but only part of the label + sequence y. In the case of the sequence_segmenter, we use an order one + Markov model. This means that + PSI(x,y) == sum_i XI(x, y_{i-1}, y_{i}, i) + where the sum is taken over all the elements in the sequence. At each + element we extract a feature vector, XI(), that is expected to encode + important details describing what the i-th position of the sequence looks + like in the context of the current and previous labels. To do this, XI() + is allowed to look at any part of the input sequence x, the current and + previous labels, and of course it must also know the position in question, i. + + The sequence_segmenter simplifies this further by decomposing XI() into + components which model the current window around each position as well as + the conjunction of the current window around each position and the previous + label. In particular, the sequence_segmenter only asks a user to provide a + single feature vector which characterizes a position of the sequence + independent of any labeling. We denote this feature vector by ZI(x,i), where + x is the sequence and i is the position in question. + + For example, suppose we use a window size of 3 and BIO tags, then we can + put all this together and define XI() in terms of ZI(). To do this, we can + think of XI() as containing 12*3 slots which contain either a zero vector + or a ZI() vector. Each combination of window position and labeling has a + different slot. To explain further, consider the following examples where + we have annotated which parts of XI() correspond to each slot. + + If the previous and current label are both B and we use a window size of 3 + then XI() would be instantiated as: + XI(x, B, B, i) = [ZI(x,i-1) \ + ZI(x,i) > If current label is B + ZI(x,i+1) / + 0 \ + 0 > If current label is I + 0 / + 0 \ + 0 > If current label is O + 0 / + + ZI(x,i-1) \ + ZI(x,i) > If previous label is B and current label is B + ZI(x,i+1) / + 0 \ + 0 > If previous label is B and current label is I + 0 / + 0 \ + 0 > If previous label is B and current label is O + 0 / + + 0 \ + 0 > If previous label is I and current label is B + 0 / + 0 \ + 0 > If previous label is I and current label is I + 0 / + 0 \ + 0 > If previous label is I and current label is O + 0 / + + 0 \ + 0 > If previous label is O and current label is B + 0 / + 0 \ + 0 > If previous label is O and current label is I + 0 / + 0 \ + 0 > If previous label is O and current label is O + 0] / + + + If the previous label is I and the current label is O and we use a window + size of 3 then XI() would be instantiated as: + XI(x, I, O, i) = [0 \ + 0 > If current label is B + 0 / + 0 \ + 0 > If current label is I + 0 / + ZI(x,i-1) \ + ZI(x,i) > If current label is O + ZI(x,i+1) / + + 0 \ + 0 > If previous label is B and current label is B + 0 / + 0 \ + 0 > If previous label is B and current label is I + 0 / + 0 \ + 0 > If previous label is B and current label is O + 0 / + + 0 \ + 0 > If previous label is I and current label is B + 0 / + 0 \ + 0 > If previous label is I and current label is I + 0 / + ZI(x,i-1) \ + ZI(x,i) > If previous label is I and current label is O + ZI(x,i+1) / + + 0 \ + 0 > If previous label is O and current label is B + 0 / + 0 \ + 0 > If previous label is O and current label is I + 0 / + 0 \ + 0 > If previous label is O and current label is O + 0] / + + If we had instead used the BILOU tagging model the XI() vector would + have been similarly defined except that there would be 30*3 slots for + the various label combination instead of 12*3. + + Finally, while not shown here, we also include indicator features in + XI() to model label transitions and individual label biases. These are + 12 extra features in the case of the BIO tagging model and 30 extra in + the case of the BILOU tagging model. + + THREAD SAFETY + Instances of this object are required to be threadsafe, that is, it should + be safe for multiple threads to make concurrent calls to the member + functions of this object. + !*/ + + public: + // This should be the type used to represent an input sequence. It can be + // anything so long as it has a .size() which returns the length of the sequence. + typedef the_type_used_to_represent_a_sequence sequence_type; + + // If you want to use the BIO tagging model then set this bool to true. Set it to + // false to use the BILOU tagging model. + const static bool use_BIO_model = true; + + // In the WHAT THIS OBJECT REPRESENTS section above we discussed how we model the + // conjunction of the previous label and the window around each position. Doing + // this greatly expands the size of the parameter vector w. You can optionally + // disable these higher order features by setting the use_high_order_features bool + // to false. This will cause XI() to include only slots which are independent of + // the previous label. + const static bool use_high_order_features = true; + + // You use a tool like the structural_sequence_segmentation_trainer to learn the weight + // vector needed by a sequence_segmenter. You can tell the trainer to force all the + // elements of the weight vector corresponding to ZI() to be non-negative. This is all + // the elements of w except for the elements corresponding to the label transition and + // bias indicator features. To do this, just set allow_negative_weights to false. + const static bool allow_negative_weights = true; + + + example_feature_extractor ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + unsigned long num_features( + ) const; + /*! + ensures + - returns the dimensionality of the ZI() feature vector. This number is + always >= 1 + !*/ + + unsigned long window_size( + ) const; + /*! + ensures + - returns the size of the window ZI() vectors are extracted from. This + number is always >= 1. + !*/ + + template + void get_features ( + feature_setter& set_feature, + const sequence_type& x, + unsigned long position + ) const; + /*! + requires + - position < x.size() + - set_feature is a function object which allows expressions of the form: + - set_features((unsigned long)feature_index, (double)feature_value); + - set_features((unsigned long)feature_index); + ensures + - This function computes the ZI(x,position) feature vector. This is a + feature vector which should capture the properties of x[position] that + are informative relative to the sequence segmentation task you are trying + to perform. + - ZI(x,position) is returned as a sparse vector by invoking set_feature(). + For example, to set the feature with an index of 55 to the value of 1 + this method would call: + set_feature(55); + Or equivalently: + set_feature(55,1); + Therefore, the first argument to set_feature is the index of the feature + to be set while the second argument is the value the feature should take. + Additionally, note that calling set_feature() multiple times with the + same feature index does NOT overwrite the old value, it adds to the + previous value. For example, if you call set_feature(55) 3 times then it + will result in feature 55 having a value of 3. + - This function only calls set_feature() with feature_index values < num_features() + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + void serialize( + const example_feature_extractor& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize( + example_feature_extractor& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + unsigned long total_feature_vector_size ( + const feature_extractor& fe + ); + /*! + requires + - fe must be an object that implements an interface compatible with the + example_feature_extractor discussed above. + ensures + - returns the dimensionality of the PSI() vector defined by the given feature + extractor. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class sequence_segmenter + { + /*! + REQUIREMENTS ON feature_extractor + It must be an object that implements an interface compatible with + the example_feature_extractor discussed above. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for segmenting a sequence of objects into a set of + non-overlapping chunks. An example sequence segmentation task is to take + English sentences and identify all the named entities. In this example, + you would be using a sequence_segmenter to find all the chunks of + contiguous words which refer to proper names. + + Internally, the sequence_segmenter uses the BIO (Begin, Inside, Outside) or + BILOU (Begin, Inside, Last, Outside, Unit) sequence tagging model. + Moreover, it is implemented using a dlib::sequence_labeler object and + therefore sequence_segmenter objects are examples of chain structured + conditional random field style sequence taggers. + + THREAD SAFETY + It is always safe to use distinct instances of this object in different + threads. However, when a single instance is shared between threads then + the following rules apply: + It is safe to call the const members of this object from multiple + threads so long as the feature_extractor is also threadsafe. This is + because the const members are purely read-only operations. However, + any operation that modifies a sequence_segmenter is not threadsafe. + !*/ + + public: + typedef typename feature_extractor::sequence_type sample_sequence_type; + typedef std::vector > segmented_sequence_type; + + sequence_segmenter( + ); + /*! + ensures + - #get_feature_extractor() == feature_extractor() + (i.e. it will have its default value) + - #get_weights().size() == total_feature_vector_size(#get_feature_extractor()) + - #get_weights() == 0 + !*/ + + explicit sequence_segmenter( + const matrix& weights + ); + /*! + requires + - total_feature_vector_size(feature_extractor()) == weights.size() + ensures + - #get_feature_extractor() == feature_extractor() + (i.e. it will have its default value) + - #get_weights() == weights + !*/ + + sequence_segmenter( + const matrix& weights, + const feature_extractor& fe + ); + /*! + requires + - total_feature_vector_size(fe) == weights.size() + ensures + - #get_feature_extractor() == fe + - #get_weights() == weights + !*/ + + const feature_extractor& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used by this object. + !*/ + + const matrix& get_weights ( + ) const; + /*! + ensures + - returns the parameter vector associated with this sequence segmenter. + The length of the vector is total_feature_vector_size(get_feature_extractor()). + !*/ + + segmented_sequence_type operator() ( + const sample_sequence_type& x + ) const; + /*! + ensures + - Takes an input sequence and returns a list of detected segments within + that sequence. + - None of the returned segments will overlap. + - The returned segments are listed in the order they appeared in the input sequence. + - To be precise, this function returns a std::vector Y of segments such that: + - Y.size() == the number of segments detected in the input sequence x. + - for all valid i: + - Y[i].first == the start of the i-th segment. + - Y[i].second == one past the end of the i-th segment. + - Therefore, the i-th detected segment in x is composed of the elements + x[Y[i].first], x[Y[i].first+1], ..., x[Y[i].second-1] + - Y[i].first < x.size() + - Y[i].second <= x.size() + - Y[i].first < Y[i].second + (i.e. This function never outputs empty segments) + - Y[i].second <= Y[i+1].first + (i.e. the segments are listed in order of appearance and do not overlap) + !*/ + + void segment_sequence ( + const sample_sequence_type& x, + segmented_sequence_type& y + ) const; + /*! + ensures + - #y == (*this)(x) + (i.e. This is just another interface to the operator() routine + above. This one avoids returning the results by value and therefore + might be a little faster in some cases) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void serialize ( + const sequence_segmenter& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + void deserialize ( + sequence_segmenter& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SEQUENCE_SeGMENTER_ABSTRACT_H_h_ + diff --git a/ml/dlib/dlib/svm/simplify_linear_decision_function.h b/ml/dlib/dlib/svm/simplify_linear_decision_function.h new file mode 100644 index 000000000..4f5bef6f3 --- /dev/null +++ b/ml/dlib/dlib/svm/simplify_linear_decision_function.h @@ -0,0 +1,110 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_Hh_ +#define DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_Hh_ + +#include "simplify_linear_decision_function_abstract.h" +#include "../algs.h" +#include "function.h" +#include "sparse_kernel.h" +#include "kernel.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + decision_function > simplify_linear_decision_function ( + const decision_function >& df + ) + { + // don't do anything if we don't have to + if (df.basis_vectors.size() <= 1) + return df; + + decision_function > new_df; + + new_df.b = df.b; + new_df.basis_vectors.set_size(1); + new_df.alpha.set_size(1); + new_df.alpha(0) = 1; + + // now compute the weighted sum of all the sparse basis_vectors in df + typedef typename T::value_type pair_type; + typedef typename pair_type::first_type key_type; + typedef typename pair_type::second_type value_type; + std::map accum; + for (long i = 0; i < df.basis_vectors.size(); ++i) + { + typename T::const_iterator j = df.basis_vectors(i).begin(); + const typename T::const_iterator end = df.basis_vectors(i).end(); + for (; j != end; ++j) + { + accum[j->first] += df.alpha(i) * (j->second); + } + } + + new_df.basis_vectors(0) = T(accum.begin(), accum.end()); + + return new_df; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + decision_function > simplify_linear_decision_function ( + const decision_function >& df + ) + { + // don't do anything if we don't have to + if (df.basis_vectors.size() <= 1) + return df; + + decision_function > new_df; + + new_df.b = df.b; + new_df.basis_vectors.set_size(1); + new_df.alpha.set_size(1); + new_df.alpha(0) = 1; + + // now compute the weighted sum of all the basis_vectors in df + new_df.basis_vectors(0) = 0; + for (long i = 0; i < df.basis_vectors.size(); ++i) + { + new_df.basis_vectors(0) += df.alpha(i) * df.basis_vectors(i); + } + + return new_df; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + decision_function > simplify_linear_decision_function ( + const normalized_function >, vector_normalizer >& df + ) + { + decision_function > new_df = simplify_linear_decision_function(df.function); + + // now incorporate the normalization stuff into new_df + new_df.basis_vectors(0) = pointwise_multiply(new_df.basis_vectors(0), df.normalizer.std_devs()); + new_df.b += dot(new_df.basis_vectors(0), df.normalizer.means()); + + return new_df; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_Hh_ + diff --git a/ml/dlib/dlib/svm/simplify_linear_decision_function_abstract.h b/ml/dlib/dlib/svm/simplify_linear_decision_function_abstract.h new file mode 100644 index 000000000..cff8ae11f --- /dev/null +++ b/ml/dlib/dlib/svm/simplify_linear_decision_function_abstract.h @@ -0,0 +1,74 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_ABSTRACT_Hh_ +#ifdef DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_ABSTRACT_Hh_ + +#include "../algs.h" +#include "function_abstract.h" +#include "sparse_kernel_abstract.h" +#include "kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + decision_function > simplify_linear_decision_function ( + const decision_function >& df + ); + /*! + requires + - T must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h + ensures + - returns a simplified version of df that only has one basis vector. That + is, returns a decision function D such that: + - D.basis_vectors.size() == 1 (or 0 if df is empty) + - for all possible x: D(x) == df(x) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + decision_function > simplify_linear_decision_function ( + const decision_function >& df + ); + /*! + requires + - T must be a dlib::matrix object + ensures + - returns a simplified version of df that only has one basis vector. That + is, returns a decision function D such that: + - D.basis_vectors.size() == 1 (or 0 if df is empty) + - for all possible x: D(x) == df(x) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + decision_function > simplify_linear_decision_function ( + const normalized_function >, vector_normalizer >& df + ); + /*! + requires + - T must be a dlib::matrix object + ensures + - returns a simplified version of df that only has one basis vector and + doesn't involve an explicit vector_normalizer. That is, returns a + decision function D such that: + - D.basis_vectors.size() == 1 (or 0 if df is empty) + - for all possible x: D(x) == df(x) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/sort_basis_vectors.h b/ml/dlib/dlib/svm/sort_basis_vectors.h new file mode 100644 index 000000000..1d4605b41 --- /dev/null +++ b/ml/dlib/dlib/svm/sort_basis_vectors.h @@ -0,0 +1,224 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SORT_BASIS_VECTORs_Hh_ +#define DLIB_SORT_BASIS_VECTORs_Hh_ + +#include + +#include "sort_basis_vectors_abstract.h" +#include "../matrix.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace bs_impl + { + template + typename EXP::matrix_type invert ( + const matrix_exp& m + ) + { + eigenvalue_decomposition eig(make_symmetric(m)); + + typedef typename EXP::type scalar_type; + typedef typename EXP::mem_manager_type mm_type; + + matrix vals = eig.get_real_eigenvalues(); + + const scalar_type max_eig = max(abs(vals)); + const scalar_type thresh = max_eig*std::sqrt(std::numeric_limits::epsilon()); + + // Since m might be singular or almost singular we need to do something about + // any very small eigenvalues. So here we set the smallest eigenvalues to + // be equal to a large value to make the inversion stable. We can't just set + // them to zero like in a normal pseudo-inverse since we want the resulting + // inverse matrix to be full rank. + for (long i = 0; i < vals.size(); ++i) + { + if (std::abs(vals(i)) < thresh) + vals(i) = max_eig; + } + + // Build the inverse matrix. This is basically a pseudo-inverse. + return make_symmetric(eig.get_pseudo_v()*diagm(reciprocal(vals))*trans(eig.get_pseudo_v())); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename vect1_type, + typename vect2_type, + typename vect3_type + > + const std::vector sort_basis_vectors_impl ( + const kernel_type& kern, + const vect1_type& samples, + const vect2_type& labels, + const vect3_type& basis, + double eps + ) + { + DLIB_ASSERT(is_binary_classification_problem(samples, labels) && + 0 < eps && eps <= 1 && + basis.size() > 0, + "\t void sort_basis_vectors()" + << "\n\t Invalid arguments were given to this function." + << "\n\t is_binary_classification_problem(samples, labels): " << is_binary_classification_problem(samples, labels) + << "\n\t basis.size(): " << basis.size() + << "\n\t eps: " << eps + ); + + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::mem_manager_type mm_type; + + typedef matrix col_matrix; + typedef matrix gen_matrix; + + col_matrix c1_mean, c2_mean, temp, delta; + + + col_matrix weights; + + running_covariance cov; + + // compute the covariance matrix and the means of the two classes. + for (long i = 0; i < samples.size(); ++i) + { + temp = kernel_matrix(kern, basis, samples(i)); + cov.add(temp); + if (labels(i) > 0) + c1_mean += temp; + else + c2_mean += temp; + } + + c1_mean /= sum(labels > 0); + c2_mean /= sum(labels < 0); + + delta = c1_mean - c2_mean; + + gen_matrix cov_inv = bs_impl::invert(cov.covariance()); + + + matrix total_perm = trans(range(0, delta.size()-1)); + matrix perm = total_perm; + + std::vector > sorted_feats(delta.size()); + + long best_size = delta.size(); + long misses = 0; + matrix best_total_perm = perm; + + // Now we basically find fisher's linear discriminant over and over. Each + // time sorting the features so that the most important ones pile up together. + weights = trans(chol(cov_inv))*delta; + while (true) + { + + for (unsigned long i = 0; i < sorted_feats.size(); ++i) + sorted_feats[i] = make_pair(std::abs(weights(i)), i); + + std::sort(sorted_feats.begin(), sorted_feats.end()); + + // make a permutation vector according to the sorting + for (long i = 0; i < perm.size(); ++i) + perm(i) = sorted_feats[i].second; + + + // Apply the permutation. Doing this gives the same result as permuting all the + // features and then recomputing the delta and cov_inv from scratch. + cov_inv = subm(cov_inv,perm,perm); + delta = rowm(delta,perm); + + // Record all the permutations we have done so we will know how the final + // weights match up with the original basis vectors when we are done. + total_perm = rowm(total_perm, perm); + + // compute new Fisher weights for sorted features. + weights = trans(chol(cov_inv))*delta; + + // Measure how many features it takes to account for eps% of the weights vector. + const scalar_type total_weight = length_squared(weights); + scalar_type weight_accum = 0; + long size = 0; + // figure out how to get eps% of the weights + for (long i = weights.size()-1; i >= 0; --i) + { + ++size; + weight_accum += weights(i)*weights(i); + if (weight_accum/total_weight > eps) + break; + } + + // loop until the best_size stops dropping + if (size < best_size) + { + misses = 0; + best_size = size; + best_total_perm = total_perm; + } + else + { + ++misses; + + // Give up once we have had 10 rounds where we didn't find a weights vector with + // a smaller concentration of good features. + if (misses >= 10) + break; + } + + } + + // make sure best_size isn't zero + if (best_size == 0) + best_size = 1; + + std::vector sorted_basis; + + // permute the basis so that it matches up with the contents of the best weights + sorted_basis.resize(best_size); + for (unsigned long i = 0; i < sorted_basis.size(); ++i) + { + // Note that we load sorted_basis backwards so that the most important + // basis elements come first. + sorted_basis[i] = basis(best_total_perm(basis.size()-i-1)); + } + + return sorted_basis; + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename vect1_type, + typename vect2_type, + typename vect3_type + > + const std::vector sort_basis_vectors ( + const kernel_type& kern, + const vect1_type& samples, + const vect2_type& labels, + const vect3_type& basis, + double eps = 0.99 + ) + { + return bs_impl::sort_basis_vectors_impl(kern, + mat(samples), + mat(labels), + mat(basis), + eps); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SORT_BASIS_VECTORs_Hh_ + diff --git a/ml/dlib/dlib/svm/sort_basis_vectors_abstract.h b/ml/dlib/dlib/svm/sort_basis_vectors_abstract.h new file mode 100644 index 000000000..b43dca170 --- /dev/null +++ b/ml/dlib/dlib/svm/sort_basis_vectors_abstract.h @@ -0,0 +1,59 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SORT_BASIS_VECTORs_ABSTRACT_Hh_ +#ifdef DLIB_SORT_BASIS_VECTORs_ABSTRACT_Hh_ + +#include + +#include "../matrix.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename kernel_type, + typename vect1_type, + typename vect2_type, + typename vect3_type + > + const std::vector sort_basis_vectors ( + const kernel_type& kern, + const vect1_type& samples, + const vect2_type& labels, + const vect3_type& basis, + double eps = 0.99 + ); + /*! + requires + - is_binary_classification_problem(samples, labels) + - 0 < eps <= 1 + - basis.size() > 0 + - kernel_type is a kernel function object as defined in dlib/svm/kernel_abstract.h + It must be capable of operating on the elements of samples and basis. + - vect1_type == a matrix or something convertible to a matrix via mat() + - vect2_type == a matrix or something convertible to a matrix via mat() + - vect3_type == a matrix or something convertible to a matrix via mat() + ensures + - A kernel based learning method ultimately needs to select a set of basis functions + represented by a particular choice of kernel and a set of basis vectors. + sort_basis_vectors() attempts to order the elements of basis so that elements which are + most useful in solving the binary classification problem defined by samples and + labels come first. + - In particular, this function returns a std::vector, SB, of sorted basis vectors such that: + - 0 < SB.size() <= basis.size() + - SB will contain elements from basis but they will have been sorted so that + the most useful elements come first (i.e. SB[0] is the most important). + - eps notionally controls how big SB will be. Bigger eps corresponds to a + bigger basis. You can think of it like asking for eps percent of the + discriminating power from the input basis. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SORT_BASIS_VECTORs_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/sparse_kernel.h b/ml/dlib/dlib/svm/sparse_kernel.h new file mode 100644 index 000000000..f571135ec --- /dev/null +++ b/ml/dlib/dlib/svm/sparse_kernel.h @@ -0,0 +1,384 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_SPARSE_KERNEL +#define DLIB_SVm_SPARSE_KERNEL + +#include "sparse_kernel_abstract.h" +#include +#include +#include "../algs.h" +#include "../serialize.h" +#include "sparse_vector.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sparse_radial_basis_kernel + { + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + sparse_radial_basis_kernel(const scalar_type g) : gamma(g) {} + sparse_radial_basis_kernel() : gamma(0.1) {} + sparse_radial_basis_kernel( + const sparse_radial_basis_kernel& k + ) : gamma(k.gamma) {} + + + const scalar_type gamma; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + const scalar_type d = distance_squared(a,b); + return std::exp(-gamma*d); + } + + sparse_radial_basis_kernel& operator= ( + const sparse_radial_basis_kernel& k + ) + { + const_cast(gamma) = k.gamma; + return *this; + } + + bool operator== ( + const sparse_radial_basis_kernel& k + ) const + { + return gamma == k.gamma; + } + }; + + template < + typename T + > + void serialize ( + const sparse_radial_basis_kernel& item, + std::ostream& out + ) + { + try + { + serialize(item.gamma, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type sparse_radial_basis_kernel"); + } + } + + template < + typename T + > + void deserialize ( + sparse_radial_basis_kernel& item, + std::istream& in + ) + { + typedef typename T::value_type::second_type scalar_type; + try + { + deserialize(const_cast(item.gamma), in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type sparse_radial_basis_kernel"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sparse_polynomial_kernel + { + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + sparse_polynomial_kernel(const scalar_type g, const scalar_type c, const scalar_type d) : gamma(g), coef(c), degree(d) {} + sparse_polynomial_kernel() : gamma(1), coef(0), degree(1) {} + sparse_polynomial_kernel( + const sparse_polynomial_kernel& k + ) : gamma(k.gamma), coef(k.coef), degree(k.degree) {} + + typedef T type; + const scalar_type gamma; + const scalar_type coef; + const scalar_type degree; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return std::pow(gamma*(dot(a,b)) + coef, degree); + } + + sparse_polynomial_kernel& operator= ( + const sparse_polynomial_kernel& k + ) + { + const_cast(gamma) = k.gamma; + const_cast(coef) = k.coef; + const_cast(degree) = k.degree; + return *this; + } + + bool operator== ( + const sparse_polynomial_kernel& k + ) const + { + return (gamma == k.gamma) && (coef == k.coef) && (degree == k.degree); + } + }; + + template < + typename T + > + void serialize ( + const sparse_polynomial_kernel& item, + std::ostream& out + ) + { + try + { + serialize(item.gamma, out); + serialize(item.coef, out); + serialize(item.degree, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type sparse_polynomial_kernel"); + } + } + + template < + typename T + > + void deserialize ( + sparse_polynomial_kernel& item, + std::istream& in + ) + { + typedef typename T::value_type::second_type scalar_type; + try + { + deserialize(const_cast(item.gamma), in); + deserialize(const_cast(item.coef), in); + deserialize(const_cast(item.degree), in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type sparse_polynomial_kernel"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sparse_sigmoid_kernel + { + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + sparse_sigmoid_kernel(const scalar_type g, const scalar_type c) : gamma(g), coef(c) {} + sparse_sigmoid_kernel() : gamma(0.1), coef(-1.0) {} + sparse_sigmoid_kernel( + const sparse_sigmoid_kernel& k + ) : gamma(k.gamma), coef(k.coef) {} + + typedef T type; + const scalar_type gamma; + const scalar_type coef; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return std::tanh(gamma*(dot(a,b)) + coef); + } + + sparse_sigmoid_kernel& operator= ( + const sparse_sigmoid_kernel& k + ) + { + const_cast(gamma) = k.gamma; + const_cast(coef) = k.coef; + return *this; + } + + bool operator== ( + const sparse_sigmoid_kernel& k + ) const + { + return (gamma == k.gamma) && (coef == k.coef); + } + }; + + template < + typename T + > + void serialize ( + const sparse_sigmoid_kernel& item, + std::ostream& out + ) + { + try + { + serialize(item.gamma, out); + serialize(item.coef, out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type sparse_sigmoid_kernel"); + } + } + + template < + typename T + > + void deserialize ( + sparse_sigmoid_kernel& item, + std::istream& in + ) + { + typedef typename T::value_type::second_type scalar_type; + try + { + deserialize(const_cast(item.gamma), in); + deserialize(const_cast(item.coef), in); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing object of type sparse_sigmoid_kernel"); + } + } + +// ---------------------------------------------------------------------------------------- + + template + struct sparse_linear_kernel + { + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return dot(a,b); + } + + bool operator== ( + const sparse_linear_kernel& + ) const + { + return true; + } + }; + + template < + typename T + > + void serialize ( + const sparse_linear_kernel& , + std::ostream& + ){} + + template < + typename T + > + void deserialize ( + sparse_linear_kernel& , + std::istream& + ){} + +// ---------------------------------------------------------------------------------------- + + template + struct sparse_histogram_intersection_kernel + { + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + typename sample_type::const_iterator ai = a.begin(); + typename sample_type::const_iterator bi = b.begin(); + + scalar_type sum = 0; + while (ai != a.end() && bi != b.end()) + { + if (ai->first == bi->first) + { + sum += std::min(ai->second , bi->second); + ++ai; + ++bi; + } + else if (ai->first < bi->first) + { + ++ai; + } + else + { + ++bi; + } + } + + return sum; + } + + bool operator== ( + const sparse_histogram_intersection_kernel& + ) const + { + return true; + } + }; + + template < + typename T + > + void serialize ( + const sparse_histogram_intersection_kernel& , + std::ostream& + ){} + + template < + typename T + > + void deserialize ( + sparse_histogram_intersection_kernel& , + std::istream& + ){} + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_SPARSE_KERNEL + + + diff --git a/ml/dlib/dlib/svm/sparse_kernel_abstract.h b/ml/dlib/dlib/svm/sparse_kernel_abstract.h new file mode 100644 index 000000000..55f9d7caa --- /dev/null +++ b/ml/dlib/dlib/svm/sparse_kernel_abstract.h @@ -0,0 +1,486 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_SPARSE_KERNEL_ABSTRACT_ +#ifdef DLIB_SVm_SPARSE_KERNEL_ABSTRACT_ + +#include +#include +#include "../algs.h" +#include "../serialize.h" +#include "kernel_abstract.h" +#include "sparse_vector_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sparse_radial_basis_kernel + { + /*! + REQUIREMENTS ON T + Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a radial basis function kernel + that works with sparse vectors. + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + const scalar_type gamma; + + sparse_radial_basis_kernel( + ); + /*! + ensures + - #gamma == 0.1 + !*/ + + sparse_radial_basis_kernel( + const sparse_radial_basis_kernel& k + ); + /*! + ensures + - #gamma == k.gamma + !*/ + + sparse_radial_basis_kernel( + const scalar_type g + ); + /*! + ensures + - #gamma == g + !*/ + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a is a sparse vector + - b is a sparse vector + ensures + - returns exp(-gamma * distance_squared(a,b)) + !*/ + + sparse_radial_basis_kernel& operator= ( + const sparse_radial_basis_kernel& k + ); + /*! + ensures + - #gamma = k.gamma + - returns *this + !*/ + + bool operator== ( + const sparse_radial_basis_kernel& k + ) const; + /*! + ensures + - if (k and *this are identical) then + - returns true + - else + - returns false + !*/ + + }; + + template < + typename T + > + void serialize ( + const sparse_radial_basis_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for sparse_radial_basis_kernel + !*/ + + template < + typename T + > + void deserialize ( + sparse_radial_basis_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for sparse_radial_basis_kernel + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sparse_sigmoid_kernel + { + /*! + REQUIREMENTS ON T + Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a sigmoid kernel + that works with sparse vectors. + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + const scalar_type gamma; + const scalar_type coef; + + sparse_sigmoid_kernel( + ); + /*! + ensures + - #gamma == 0.1 + - #coef == -1.0 + !*/ + + sparse_sigmoid_kernel( + const sparse_sigmoid_kernel& k + ); + /*! + ensures + - #gamma == k.gamma + - #coef == k.coef + !*/ + + sparse_sigmoid_kernel( + const scalar_type g, + const scalar_type c + ); + /*! + ensures + - #gamma == g + - #coef == c + !*/ + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a is a sparse vector + - b is a sparse vector + ensures + - returns tanh(gamma * dot(a,b) + coef) + !*/ + + sparse_sigmoid_kernel& operator= ( + const sparse_sigmoid_kernel& k + ); + /*! + ensures + - #gamma = k.gamma + - #coef = k.coef + - returns *this + !*/ + + bool operator== ( + const sparse_sigmoid_kernel& k + ) const; + /*! + ensures + - if (k and *this are identical) then + - returns true + - else + - returns false + !*/ + }; + + template < + typename T + > + void serialize ( + const sparse_sigmoid_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for sparse_sigmoid_kernel + !*/ + + template < + typename T + > + void deserialize ( + sparse_sigmoid_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for sparse_sigmoid_kernel + !*/ + + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sparse_polynomial_kernel + { + /*! + REQUIREMENTS ON T + Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a polynomial kernel + that works with sparse vectors. + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + const scalar_type gamma; + const scalar_type coef; + const scalar_type degree; + + sparse_polynomial_kernel( + ); + /*! + ensures + - #gamma == 1 + - #coef == 0 + - #degree == 1 + !*/ + + sparse_polynomial_kernel( + const sparse_polynomial_kernel& k + ); + /*! + ensures + - #gamma == k.gamma + - #coef == k.coef + - #degree == k.degree + !*/ + + sparse_polynomial_kernel( + const scalar_type g, + const scalar_type c, + const scalar_type d + ); + /*! + ensures + - #gamma == g + - #coef == c + - #degree == d + !*/ + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a is a sparse vector + - b is a sparse vector + ensures + - returns pow(gamma * dot(a,b) + coef, degree) + !*/ + + sparse_polynomial_kernel& operator= ( + const sparse_polynomial_kernel& k + ); + /*! + ensures + - #gamma = k.gamma + - #coef = k.coef + - #degree = k.degree + - returns *this + !*/ + + bool operator== ( + const sparse_polynomial_kernel& k + ) const; + /*! + ensures + - if (k and *this are identical) then + - returns true + - else + - returns false + !*/ + }; + + template < + typename T + > + void serialize ( + const sparse_polynomial_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for sparse_polynomial_kernel + !*/ + + template < + typename T + > + void deserialize ( + sparse_polynomial_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for sparse_polynomial_kernel + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sparse_linear_kernel + { + /*! + REQUIREMENTS ON T + Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a linear function kernel + that works with sparse vectors. + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a is a sparse vector + - b is a sparse vector + ensures + - returns dot(a,b) + !*/ + + bool operator== ( + const sparse_linear_kernel& k + ) const; + /*! + ensures + - returns true + !*/ + }; + + template < + typename T + > + void serialize ( + const sparse_linear_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for sparse_linear_kernel + !*/ + + template < + typename T + > + void deserialize ( + sparse_linear_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for sparse_linear_kernel + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct sparse_histogram_intersection_kernel + { + /*! + REQUIREMENTS ON T + Must be a sparse vector as defined in dlib/svm/sparse_vector_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a histogram intersection kernel + that works with sparse vectors. + + THREAD SAFETY + This kernel is threadsafe. + !*/ + + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const; + /*! + requires + - a is a sparse vector + - b is a sparse vector + - all the values in a and b are >= 0 + ensures + - Let A(i) denote the value of the ith dimension of the a vector. + - Let B(i) denote the value of the ith dimension of the b vector. + - returns sum over all i: std::min(A(i), B(i)) + !*/ + + bool operator== ( + const sparse_histogram_intersection_kernel& k + ) const; + /*! + ensures + - returns true + !*/ + }; + + template < + typename T + > + void serialize ( + const sparse_histogram_intersection_kernel& item, + std::ostream& out + ); + /*! + provides serialization support for sparse_histogram_intersection_kernel + !*/ + + template < + typename T + > + void deserialize ( + sparse_histogram_intersection_kernel& item, + std::istream& in + ); + /*! + provides deserialization support for sparse_histogram_intersection_kernel + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_SPARSE_KERNEL_ABSTRACT_ + + diff --git a/ml/dlib/dlib/svm/sparse_vector.h b/ml/dlib/dlib/svm/sparse_vector.h new file mode 100644 index 000000000..c42723f89 --- /dev/null +++ b/ml/dlib/dlib/svm/sparse_vector.h @@ -0,0 +1,1170 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_SPARSE_VECTOR +#define DLIB_SVm_SPARSE_VECTOR + +#include "sparse_vector_abstract.h" +#include +#include +#include "../algs.h" +#include +#include +#include "../graph_utils/edge_list_graphs.h" +#include "../matrix.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type distance_squared ( + const T& a, + const U& b + ) + { + typedef typename T::value_type::second_type scalar_type; + typedef typename U::value_type::second_type scalar_typeU; + // Both T and U must contain the same kinds of elements + COMPILE_TIME_ASSERT((is_same_type::value)); + + typename T::const_iterator ai = a.begin(); + typename U::const_iterator bi = b.begin(); + + scalar_type sum = 0, temp = 0; + while (ai != a.end() && bi != b.end()) + { + if (ai->first == bi->first) + { + temp = ai->second - bi->second; + ++ai; + ++bi; + } + else if (ai->first < bi->first) + { + temp = ai->second; + ++ai; + } + else + { + temp = bi->second; + ++bi; + } + + sum += temp*temp; + } + + while (ai != a.end()) + { + sum += ai->second*ai->second; + ++ai; + } + while (bi != b.end()) + { + sum += bi->second*bi->second; + ++bi; + } + + return sum; + } + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type distance_squared ( + const V& a_scale, + const T& a, + const W& b_scale, + const U& b + ) + { + typedef typename T::value_type::second_type scalar_type; + typedef typename U::value_type::second_type scalar_typeU; + // Both T and U must contain the same kinds of elements + COMPILE_TIME_ASSERT((is_same_type::value)); + + typename T::const_iterator ai = a.begin(); + typename U::const_iterator bi = b.begin(); + + scalar_type sum = 0, temp = 0; + while (ai != a.end() && bi != b.end()) + { + if (ai->first == bi->first) + { + temp = a_scale*ai->second - b_scale*bi->second; + ++ai; + ++bi; + } + else if (ai->first < bi->first) + { + temp = a_scale*ai->second; + ++ai; + } + else + { + temp = b_scale*bi->second; + ++bi; + } + + sum += temp*temp; + } + + while (ai != a.end()) + { + sum += a_scale*a_scale*ai->second*ai->second; + ++ai; + } + while (bi != b.end()) + { + sum += b_scale*b_scale*bi->second*bi->second; + ++bi; + } + + return sum; + } + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type distance ( + const T& a, + const U& b + ) + { + return std::sqrt(distance_squared(a,b)); + } + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type distance ( + const V& a_scale, + const T& a, + const W& b_scale, + const U& b + ) + { + return std::sqrt(distance_squared(a_scale,a,b_scale,b)); + } + +// ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + template + typename enable_if >::type assign ( + T& dest, + const matrix_exp& src + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(src), + "\t void assign(dest,src)" + << "\n\t the src matrix must be a row or column vector" + ); + + dest = src; + } + + template + typename disable_if >::type assign ( + T& dest, + const matrix_exp& src + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(src), + "\t void assign(dest,src)" + << "\n\t the src matrix must be a row or column vector" + ); + + dest.clear(); + typedef typename T::value_type item_type; + for (long i = 0; i < src.size(); ++i) + { + dest.insert(dest.end(),item_type(i, src(i))); + } + } + + template + typename disable_if_c::value || is_matrix::value>::type assign ( + T& dest, // sparse + const U& src // sparse + ) + { + dest.assign(src.begin(), src.end()); + } + + template + typename disable_if >::type assign ( + std::map& dest, // sparse + const S& src // sparse + ) + { + dest.clear(); + dest.insert(src.begin(), src.end()); + } + +// ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + template + struct has_unsigned_keys + { + static const bool value = is_unsigned_type::value; + }; + +// ------------------------------------------------------------------------------------ + + namespace impl + { + template + typename T::value_type::second_type general_dot ( + const T& a, + const U& b + ) + { + typedef typename T::value_type::second_type scalar_type; + + typename T::const_iterator ai = a.begin(); + typename U::const_iterator bi = b.begin(); + + scalar_type sum = 0; + while (ai != a.end() && bi != b.end()) + { + if (ai->first == bi->first) + { + sum += ai->second * bi->second; + ++ai; + ++bi; + } + else if (ai->first < bi->first) + { + ++ai; + } + else + { + ++bi; + } + } + + return sum; + } + + template + inline typename T::value_type::second_type dot ( + const T& a, + const U& b + ) + { + return general_dot(a,b); + } + + template + U dot ( + const std::vector,alloc>& a, + const std::vector,alloc>& b + ) + { + // You are getting this error because you are attempting to use sparse sample vectors + // but you aren't using an unsigned integer as your key type in the sparse vectors. + COMPILE_TIME_ASSERT(is_unsigned_type::value); + + if (a.size() == 0 || b.size() == 0) + return 0; + + // if a is really a dense vector but just represented in a sparse container + if (a.back().first == a.size()-1) + { + double sum = 0; + for (unsigned long i = 0; i < b.size(); ++i) + { + if (b[i].first >= a.size()) + break; + sum += a[b[i].first].second * b[i].second; + } + return sum; + } + // if b is really a dense vector but just represented in a sparse container + else if (b.back().first == b.size()-1) + { + double sum = 0; + for (unsigned long i = 0; i < a.size(); ++i) + { + if (a[i].first >= b.size()) + break; + sum += b[a[i].first].second * a[i].second; + } + return sum; + } + else + { + return general_dot(a,b); + } + } + } + + template + inline typename T::value_type::second_type dot ( + const T& a, + const T& b + ) + { + return impl::dot(a,b); + } + + template + inline T4 dot ( + const std::vector& a, + const std::map& b + ) + { + return impl::dot(a,b); + } + + template + inline T4 dot ( + const std::map& a, + const std::vector& b + ) + { + return impl::dot(a,b); + } + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type dot ( + const T& a, + const matrix_exp& b + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(b), + "\t scalar_type dot(sparse_vector a, dense_vector b)" + << "\n\t 'b' must be a vector to be used in a dot product." + ); + + typedef typename T::value_type::second_type scalar_type; + typedef typename T::value_type::first_type first_type; + + scalar_type sum = 0; + for (typename T::const_iterator ai = a.begin(); + (ai != a.end()) && (ai->first < static_cast(b.size())); + ++ai) + { + sum += ai->second * b(ai->first); + } + + return sum; + } + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type dot ( + const matrix_exp& b, + const T& a + ) + { + return dot(a,b); + } + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type length_squared ( + const T& a + ) + { + typedef typename T::value_type::second_type scalar_type; + + typename T::const_iterator i; + + scalar_type sum = 0; + + for (i = a.begin(); i != a.end(); ++i) + { + sum += i->second * i->second; + } + + return sum; + } + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type length ( + const T& a + ) + { + return std::sqrt(length_squared(a)); + } + +// ------------------------------------------------------------------------------------ + + template + typename disable_if,void>::type scale_by ( + T& a, + const U& value + ) + { + for (typename T::iterator i = a.begin(); i != a.end(); ++i) + { + i->second *= value; + } + } + + template + typename enable_if,void>::type scale_by ( + T& a, + const U& value + ) + { + a *= value; + } + +// ------------------------------------------------------------------------------------ + + template + typename disable_if,T>::type add ( + const T& a, + const T& b + ) + { + T temp; + + typename T::const_iterator i = a.begin(); + typename T::const_iterator j = b.begin(); + while (i != a.end() && j != b.end()) + { + if (i->first == j->first) + { + temp.insert(temp.end(), std::make_pair(i->first, i->second + j->second)); + ++i; + ++j; + } + else if (i->first < j->first) + { + temp.insert(temp.end(), *i); + ++i; + } + else + { + temp.insert(temp.end(), *j); + ++j; + } + } + + while (i != a.end()) + { + temp.insert(temp.end(), *i); + ++i; + } + while (j != b.end()) + { + temp.insert(temp.end(), *j); + ++j; + } + + return temp; + } + + template + typename enable_if_c::value && is_matrix::value, matrix_add_exp >::type add ( + const T& a, + const U& b + ) + { + return matrix_add_exp(a.ref(),b.ref()); + } + +// ------------------------------------------------------------------------------------ + + template + typename disable_if,T>::type subtract ( + const T& a, + const T& b + ) + { + T temp; + + typename T::const_iterator i = a.begin(); + typename T::const_iterator j = b.begin(); + while (i != a.end() && j != b.end()) + { + if (i->first == j->first) + { + temp.insert(temp.end(), std::make_pair(i->first, i->second - j->second)); + ++i; + ++j; + } + else if (i->first < j->first) + { + temp.insert(temp.end(), *i); + ++i; + } + else + { + temp.insert(temp.end(), std::make_pair(j->first, -j->second)); + ++j; + } + } + + while (i != a.end()) + { + temp.insert(temp.end(), *i); + ++i; + } + while (j != b.end()) + { + temp.insert(temp.end(), std::make_pair(j->first, -j->second)); + ++j; + } + + return temp; + } + + template + typename enable_if_c::value && is_matrix::value, matrix_subtract_exp >::type subtract ( + const T& a, + const U& b + ) + { + return matrix_subtract_exp(a.ref(),b.ref()); + } + +// ------------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------------ + + namespace impl + { + template + typename enable_if,unsigned long>::type max_index_plus_one ( + const T& samples + ) + { + if (samples.size() > 0) + return samples(0).size(); + else + return 0; + } + + template + typename enable_if,unsigned long>::type max_index_plus_one ( + const T& sample + ) + { + return sample.size(); + } + + // This !is_built_in_scalar_type::value is here to avoid an inexplicable bug in Vistual Studio 2005 + template + typename enable_if_c<(!is_built_in_scalar_type::value) && (is_pair::value) ,unsigned long>::type + max_index_plus_one ( + const T& samples + ) + { + typedef typename T::type sample_type; + // You are getting this error because you are attempting to use sparse sample vectors + // but you aren't using an unsigned integer as your key type in the sparse vectors. + COMPILE_TIME_ASSERT(has_unsigned_keys::value); + + + // these should be sparse samples so look over all them to find the max index. + unsigned long max_dim = 0; + for (long i = 0; i < samples.size(); ++i) + { + if (samples(i).size() > 0) + max_dim = std::max(max_dim, (--samples(i).end())->first + 1); + } + + return max_dim; + } + } + + template + typename enable_if,unsigned long>::type max_index_plus_one ( + const T& sample + ) + { + if (sample.size() > 0) + return (--sample.end())->first + 1; + return 0; + } + + template + typename disable_if_c::value || + is_same_type::value || + is_same_type::value , unsigned long>::type + max_index_plus_one ( + const T& samples + ) + { + return impl::max_index_plus_one(mat(samples)); + } + +// ------------------------------------------------------------------------------------ + + template + inline void add_to ( + matrix& dest, + const matrix_exp& src + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast(dest.size()), + "\t void add_to(dest,src)" + << "\n\t dest must be a vector large enough to hold the src vector." + << "\n\t is_vector(dest): " << is_vector(dest) + << "\n\t max_index_plus_one(src): " << max_index_plus_one(src) + << "\n\t dest.size(): " << dest.size() + ); + + for (long r = 0; r < src.size(); ++r) + dest(r) += src(r); + } + + template + inline typename disable_if >::type add_to ( + matrix& dest, + const EXP& src + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast(dest.size()), + "\t void add_to(dest,src)" + << "\n\t dest must be a vector large enough to hold the src vector." + << "\n\t is_vector(dest): " << is_vector(dest) + << "\n\t max_index_plus_one(src): " << max_index_plus_one(src) + << "\n\t dest.size(): " << dest.size() + ); + + for (typename EXP::const_iterator i = src.begin(); i != src.end(); ++i) + dest(i->first) += i->second; + } + +// ------------------------------------------------------------------------------------ + + template + inline void add_to ( + matrix& dest, + const matrix_exp& src, + const U& C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast(dest.size()), + "\t void add_to(dest,src)" + << "\n\t dest must be a vector large enough to hold the src vector." + << "\n\t is_vector(dest): " << is_vector(dest) + << "\n\t max_index_plus_one(src): " << max_index_plus_one(src) + << "\n\t dest.size(): " << dest.size() + ); + + for (long r = 0; r < src.size(); ++r) + dest(r) += C*src(r); + } + + template + inline typename disable_if >::type add_to ( + matrix& dest, + const EXP& src, + const U& C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast(dest.size()), + "\t void add_to(dest,src)" + << "\n\t dest must be a vector large enough to hold the src vector." + << "\n\t is_vector(dest): " << is_vector(dest) + << "\n\t max_index_plus_one(src): " << max_index_plus_one(src) + << "\n\t dest.size(): " << dest.size() + ); + + for (typename EXP::const_iterator i = src.begin(); i != src.end(); ++i) + dest(i->first) += C*i->second; + } + +// ------------------------------------------------------------------------------------ + + template + inline void subtract_from ( + matrix& dest, + const matrix_exp& src + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast(dest.size()), + "\t void subtract_from(dest,src)" + << "\n\t dest must be a vector large enough to hold the src vector." + << "\n\t is_vector(dest): " << is_vector(dest) + << "\n\t max_index_plus_one(src): " << max_index_plus_one(src) + << "\n\t dest.size(): " << dest.size() + ); + + for (long r = 0; r < src.size(); ++r) + dest(r) -= src(r); + } + + template + inline typename disable_if >::type subtract_from ( + matrix& dest, + const EXP& src + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast(dest.size()), + "\t void subtract_from(dest,src)" + << "\n\t dest must be a vector large enough to hold the src vector." + << "\n\t is_vector(dest): " << is_vector(dest) + << "\n\t max_index_plus_one(src): " << max_index_plus_one(src) + << "\n\t dest.size(): " << dest.size() + ); + + for (typename EXP::const_iterator i = src.begin(); i != src.end(); ++i) + dest(i->first) -= i->second; + } + +// ------------------------------------------------------------------------------------ + + template + inline void subtract_from ( + matrix& dest, + const matrix_exp& src, + const U& C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast(dest.size()), + "\t void subtract_from(dest,src)" + << "\n\t dest must be a vector large enough to hold the src vector." + << "\n\t is_vector(dest): " << is_vector(dest) + << "\n\t max_index_plus_one(src): " << max_index_plus_one(src) + << "\n\t dest.size(): " << dest.size() + ); + + for (long r = 0; r < src.size(); ++r) + dest(r) -= C*src(r); + } + + template + inline typename disable_if >::type subtract_from ( + matrix& dest, + const EXP& src, + const U& C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(dest) && max_index_plus_one(src) <= static_cast(dest.size()), + "\t void subtract_from(dest,src)" + << "\n\t dest must be a vector large enough to hold the src vector." + << "\n\t is_vector(dest): " << is_vector(dest) + << "\n\t max_index_plus_one(src): " << max_index_plus_one(src) + << "\n\t dest.size(): " << dest.size() + ); + + for (typename EXP::const_iterator i = src.begin(); i != src.end(); ++i) + dest(i->first) -= C*i->second; + } + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type min ( + const T& a + ) + { + typedef typename T::value_type::second_type type; + + type temp = 0; + for (typename T::const_iterator i = a.begin(); i != a.end(); ++i) + { + if (temp > i->second) + temp = i->second; + } + return temp; + } + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type max ( + const T& a + ) + { + typedef typename T::value_type::second_type type; + + type temp = 0; + for (typename T::const_iterator i = a.begin(); i != a.end(); ++i) + { + if (temp < i->second) + temp = i->second; + } + return temp; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + inline matrix sparse_to_dense ( + const sparse_vector_type& vect, + unsigned long num_dimensions + ) + { + // You must use unsigned integral key types in your sparse vectors + typedef typename sparse_vector_type::value_type::first_type idx_type; + typedef typename sparse_vector_type::value_type::second_type value_type; + COMPILE_TIME_ASSERT(is_unsigned_type::value); + + matrix result; + + if (vect.size() == 0) + return result; + + result.set_size(num_dimensions); + result = 0; + + for (typename sparse_vector_type::const_iterator j = vect.begin(); j != vect.end(); ++j) + { + if ((long)(j->first) < result.size()) + { + result(j->first) += j->second; + } + } + + return result; + } + } + +// ---------------------------------------------------------------------------------------- + + template + matrix sparse_to_dense ( + const std::vector,alloc>& vect, + unsigned long num_dimensions + ) + { + return impl::sparse_to_dense(vect,num_dimensions); + } + +// ---------------------------------------------------------------------------------------- + + template + matrix sparse_to_dense ( + const std::vector,alloc>& vect + ) + { + return impl::sparse_to_dense(vect, max_index_plus_one(vect)); + } + +// ---------------------------------------------------------------------------------------- + + template + matrix sparse_to_dense ( + const std::map& vect, + unsigned long num_dimensions + ) + { + return impl::sparse_to_dense(vect,num_dimensions); + } + +// ---------------------------------------------------------------------------------------- + + template + matrix sparse_to_dense ( + const std::map& vect + ) + { + return impl::sparse_to_dense(vect, max_index_plus_one(vect)); + } + +// ---------------------------------------------------------------------------------------- + + template + typename enable_if,T&>::type sparse_to_dense( + T& item + ) { return item; } + + template + matrix sparse_to_dense( + const matrix_exp& item, + unsigned long num + ) + { + typedef typename EXP::type type; + if (item.size() == (long)num) + return item; + else if (item.size() < (long)num) + return join_cols(item, zeros_matrix((long)num-item.size(),1)); + else + return colm(item,0,(long)num); + } + +// ---------------------------------------------------------------------------------------- + + template + std::vector > sparse_to_dense ( + const std::vector& samples, + unsigned long num_dimensions + ) + { + typedef typename sample_type::value_type pair_type; + typedef typename pair_type::second_type value_type; + + std::vector< matrix > result; + + // now turn all the samples into dense samples + result.resize(samples.size()); + + for (unsigned long i = 0; i < samples.size(); ++i) + { + result[i] = sparse_to_dense(samples[i],num_dimensions); + } + + return result; + } + +// ---------------------------------------------------------------------------------------- + + template + std::vector > sparse_to_dense ( + const std::vector& samples + ) + { + return sparse_to_dense(samples, max_index_plus_one(samples)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T make_sparse_vector ( + const T& v + ) + { + // You must use unsigned integral key types in your sparse vectors + typedef typename T::value_type::first_type idx_type; + typedef typename T::value_type::second_type value_type; + COMPILE_TIME_ASSERT(is_unsigned_type::value); + std::map temp; + for (typename T::const_iterator i = v.begin(); i != v.end(); ++i) + { + temp[i->first] += i->second; + } + + return T(temp.begin(), temp.end()); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void make_sparse_vector_inplace( + T& vect + ) + { + vect = make_sparse_vector(vect); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename alloc + > + void make_sparse_vector_inplace ( + std::vector,alloc>& vect + ) + { + if (vect.size() > 0) + { + std::sort(vect.begin(), vect.end()); + + // merge duplicates + for (unsigned long i = 1; i < vect.size(); ++i) + { + // if we found a duplicate + if (vect[i-1].first == vect[i].first) + { + // now start collapsing and merging the vector + unsigned long j = i-1; + for (unsigned long k = i; k < vect.size(); ++k) + { + if (vect[j].first == vect[k].first) + { + vect[j].second += vect[k].second; + } + else + { + ++j; + vect[j] = vect[k]; + } + } + + + // we removed elements when we merged so we need to adjust the size. + vect.resize(j+1); + return; + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void sparse_matrix_vector_multiply ( + const std::vector& edges, + const matrix_exp& v, + matrix& result + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_index_plus_one(edges) <= (unsigned long)v.size() && + is_col_vector(v), + "\t void sparse_matrix_vector_multiply()" + << "\n\t Invalid inputs were given to this function" + << "\n\t max_index_plus_one(edges): " << max_index_plus_one(edges) + << "\n\t v.size(): " << v.size() + << "\n\t is_col_vector(v): " << is_col_vector(v) + ); + + result.set_size(v.nr(),v.nc()); + result = 0; + + for (unsigned long k = 0; k < edges.size(); ++k) + { + const long i = edges[k].index1(); + const long j = edges[k].index2(); + const double d = edges[k].distance(); + + result(i) += v(j)*d; + if (i != j) + result(j) += v(i)*d; + } + } + +// ---------------------------------------------------------------------------------------- + + template + matrix sparse_matrix_vector_multiply ( + const std::vector& edges, + const matrix_exp& v + ) + { + matrix result; + sparse_matrix_vector_multiply(edges,v,result); + return result; + } + +// ---------------------------------------------------------------------------------------- + + template + void sparse_matrix_vector_multiply ( + const std::vector& edges, + const matrix_exp& v, + matrix& result + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_index_plus_one(edges) <= (unsigned long)v.size() && + is_col_vector(v), + "\t void sparse_matrix_vector_multiply()" + << "\n\t Invalid inputs were given to this function" + << "\n\t max_index_plus_one(edges): " << max_index_plus_one(edges) + << "\n\t v.size(): " << v.size() + << "\n\t is_col_vector(v): " << is_col_vector(v) + ); + + + result.set_size(v.nr(),v.nc()); + result = 0; + + for (unsigned long k = 0; k < edges.size(); ++k) + { + const long i = edges[k].index1(); + const long j = edges[k].index2(); + const double d = edges[k].distance(); + + result(i) += v(j)*d; + } + } + +// ---------------------------------------------------------------------------------------- + + template + matrix sparse_matrix_vector_multiply ( + const std::vector& edges, + const matrix_exp& v + ) + { + matrix result; + sparse_matrix_vector_multiply(edges,v,result); + return result; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename sparse_vector_type, + typename T, + long NR, + long NC, + typename MM, + typename L + > + void sparse_matrix_vector_multiply ( + const matrix_exp& m, + const sparse_vector_type& v, + matrix& result + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_index_plus_one(v) <= (unsigned long)m.nc(), + "\t void sparse_matrix_vector_multiply()" + << "\n\t Invalid inputs were given to this function" + << "\n\t max_index_plus_one(v): " << max_index_plus_one(v) + << "\n\t m.size(): " << m.size() + ); + + result.set_size(m.nr(),1); + result = 0; + + for (typename sparse_vector_type::const_iterator i = v.begin(); i != v.end(); ++i) + { + for (long r = 0; r < result.nr(); ++r) + { + result(r) += m(r, i->first)*i->second; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename sparse_vector_type + > + matrix sparse_matrix_vector_multiply ( + const matrix_exp& m, + const sparse_vector_type& v + ) + { + matrix result; + sparse_matrix_vector_multiply(m,v,result); + return result; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_SPARSE_VECTOR + diff --git a/ml/dlib/dlib/svm/sparse_vector_abstract.h b/ml/dlib/dlib/svm/sparse_vector_abstract.h new file mode 100644 index 000000000..e0c8d1f8c --- /dev/null +++ b/ml/dlib/dlib/svm/sparse_vector_abstract.h @@ -0,0 +1,688 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_SPARSE_VECTOR_ABSTRACT_ +#ifdef DLIB_SVm_SPARSE_VECTOR_ABSTRACT_ + +#include +#include "../algs.h" +#include "../serialize.h" +#include "../matrix.h" +#include +#include +#include "../graph_utils/sample_pair_abstract.h" +#include "../graph_utils/ordered_sample_pair_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*!A sparse_vectors + + In dlib, sparse vectors are represented using the container objects + in the C++ STL. In particular, a sparse vector is any container that + contains a range of std::pair objects where: + - key is an unsigned integral type + - scalar_value is float, double, or long double + - the std::pair objects have unique key values + - the std::pair objects are sorted such that small keys come first + + Therefore, if an object satisfies the above requirements we call it a + "sparse vector". Additionally, we define the concept of an "unsorted sparse vector" + to be a sparse vector that doesn't necessarily have sorted or unique key values. + Therefore, all sparse vectors are valid unsorted sparse vectors but not the other + way around. + + An unsorted sparse vector with duplicate keys is always interpreted as + a vector where each dimension contains the sum of all corresponding elements + of the unsorted sparse vector. For example, an unsorted sparse vector + with the elements { (3,1), (0, 4), (3,5) } represents the 4D vector: + [4, 0, 0, 1+5] + + + + Examples of valid sparse vectors are: + - std::map + - std::vector > where the vector is sorted. + (you could make sure it was sorted by applying std::sort to it) + + + Finally, by "dense vector" we mean a dlib::matrix object which represents + either a row or column vector. + + The rest of this file defines a number of helper functions for doing normal + vector arithmetic things with sparse vectors. + !*/ + +// ---------------------------------------------------------------------------------------- + + /*!A has_unsigned_keys + + This is a template where has_unsigned_keys::value == true when T is a + sparse vector that contains unsigned integral keys and false otherwise. + !*/ + + template + struct has_unsigned_keys + { + static const bool value = is_unsigned_type::value; + }; + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type distance_squared ( + const T& a, + const U& b + ); + /*! + requires + - a and b are sparse vectors + ensures + - returns the squared distance between the vectors + a and b + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type distance_squared ( + const V& a_scale, + const T& a, + const W& b_scale, + const U& b + ); + /*! + requires + - a and b are sparse vectors + ensures + - returns the squared distance between the vectors + a_scale*a and b_scale*b + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type distance ( + const T& a, + const U& b + ); + /*! + requires + - a and b are sparse vectors + ensures + - returns the distance between the vectors + a and b. (i.e. std::sqrt(distance_squared(a,b))) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type distance ( + const V& a_scale, + const T& a, + const W& b_scale, + const U& b + ); + /*! + requires + - a and b are sparse vectors + ensures + - returns the distance between the vectors + a_scale*a and b_scale*b. (i.e. std::sqrt(distance_squared(a_scale,a,b_scale,b))) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void assign ( + T& dest, + const U& src + ); + /*! + requires + - dest == a sparse vector or a dense vector + - src == a sparse vector or a dense vector + - dest is not dense when src is sparse + (i.e. you can't assign a sparse vector to a dense vector. This is + because we don't know what the proper dimensionality should be for the + dense vector) + ensures + - #src represents the same vector as dest. + (conversion between sparse/dense formats is done automatically) + !*/ + + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type dot ( + const T& a, + const T& b + ); + /*! + requires + - a and b are sparse vectors + ensures + - returns the dot product between the vectors a and b + !*/ + + template + T4 dot ( + const std::vector& a, + const std::map& b + ); + /*! + requires + - a and b are sparse vectors + ensures + - returns the dot product between the vectors a and b + !*/ + + template + T4 dot ( + const std::map& a, + const std::vector& b + ); + /*! + requires + - a and b are sparse vectors + ensures + - returns the dot product between the vectors a and b + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type dot ( + const T& a, + const matrix_exp& b + ); + /*! + requires + - a is an unsorted sparse vector + - is_vector(b) == true + ensures + - returns the dot product between the vectors a and b. + - if (max_index_plus_one(a) >= b.size()) then + - a's dimensionality is greater than b's dimensionality. In this case we + pretend b is padded by as many zeros as is needed to make the dot product + work. So this means that any elements in a that go beyond the length of + b are simply ignored. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type dot ( + const matrix_exp& a, + const T& b + ); + /*! + requires + - b is an unsorted sparse vector + - is_vector(a) == true + ensures + - returns the dot product between the vectors a and b + - if (max_index_plus_one(b) >= a.size()) then + - b's dimensionality is greater than a's dimensionality. In this case we + pretend a is padded by as many zeros as is needed to make the dot product + work. So this means that any elements in b that go beyond the length of + a are simply ignored. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type length_squared ( + const T& a + ); + /*! + requires + - a is a sparse vector + ensures + - returns dot(a,a) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type length ( + const T& a + ); + /*! + requires + - a is a sparse vector + ensures + - returns std::sqrt(length_squared(a,a)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void scale_by ( + T& a, + const U& value + ); + /*! + requires + - a is an unsorted sparse vector or a dlib::matrix + ensures + - #a == a*value + (i.e. multiplies every element of the vector a by value) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + T add ( + const T& a, + const T& b + ); + /*! + requires + - a is a sparse vector or dlib::matrix + - b is a sparse vector or dlib::matrix + ensures + - returns a vector or matrix which represents a+b. If the inputs are + sparse vectors then the result is a sparse vector. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + T subtract ( + const T& a, + const T& b + ); + /*! + requires + - a is a sparse vector or dlib::matrix + - b is a sparse vector or dlib::matrix + ensures + - returns a vector or matrix which represents a-b. If the inputs are + sparse vectors then the result is a sparse vector. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + unsigned long max_index_plus_one ( + const T& samples + ); + /*! + requires + - samples == a single vector (either sparse or dense), or a container + of vectors which is either a dlib::matrix of vectors or something + convertible to a dlib::matrix via mat() (e.g. a std::vector) + Valid types of samples include (but are not limited to): + - dlib::matrix // A single dense vector + - std::map // A single sparse vector + - std::vector > // An array of dense vectors + - std::vector > // An array of sparse vectors + ensures + - This function tells you the dimensionality of a set of vectors. The vectors + can be either sparse or dense. + - if (samples.size() == 0) then + - returns 0 + - else if (samples contains dense vectors or is a dense vector) then + - returns the number of elements in the first sample vector. This means + we implicitly assume all dense vectors have the same length) + - else + - In this case samples contains sparse vectors or is a sparse vector. + - returns the largest element index in any sample + 1. Note that the element index values + are the values stored in std::pair::first. So this number tells you the dimensionality + of a set of sparse vectors. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + inline void add_to ( + matrix& dest, + const SRC& src, + const U& C = 1 + ); + /*! + requires + - SRC == a matrix expression or an unsorted sparse vector + - is_vector(dest) == true + - Let MAX denote the largest element index in src. + Then we require that: + - MAX < dest.size() + - (i.e. dest needs to be big enough to contain all the elements of src) + ensures + - #dest == dest + C*src + !*/ + +// ---------------------------------------------------------------------------------------- + + template + inline void subtract_from ( + matrix& dest, + const SRC& src, + const U& C = 1 + ); + /*! + requires + - SRC == a matrix expression or an unsorted sparse vector + - is_vector(dest) == true + - Let MAX denote the largest element index in src. + Then we require that: + - MAX < dest.size() + - (i.e. dest needs to be big enough to contain all the elements of src) + ensures + - #dest == dest - C*src + !*/ + +// ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type min ( + const T& vect + ); + /*! + requires + - T == an unsorted sparse vector + ensures + - returns the minimum value in the sparse vector vect. Note that + this value is always <= 0 since a sparse vector has an unlimited number + of 0 elements. + !*/ + +// ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type max ( + const T& vect + ); + /*! + requires + - T == an unsorted sparse vector + ensures + - returns the maximum value in the sparse vector vect. Note that + this value is always >= 0 since a sparse vector has an unlimited number + of 0 elements. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type + > + matrix sparse_to_dense ( + const sample_type& vect + ); + /*! + requires + - vect must be a sparse vector or a dense column vector. + ensures + - converts the single sparse or dense vector vect to a dense (column matrix form) + representation. That is, this function returns a vector V such that: + - V.size() == max_index_plus_one(vect) + - for all valid j: + - V(j) == The value of the j'th dimension of the vector vect. Note + that V(j) is zero if it is a sparse vector that doesn't contain an + entry for the j'th dimension. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type + > + matrix sparse_to_dense ( + const sample_type& vect, + unsigned long num_dimensions + ); + /*! + requires + - vect must be a sparse vector or a dense column vector. + ensures + - converts the single sparse or dense vector vect to a dense (column matrix form) + representation. That is, this function returns a vector V such that: + - V.size() == num_dimensions + - for all valid j: + - V(j) == The value of the j'th dimension of the vector vect. Note + that V(j) is zero if it is a sparse vector that doesn't contain an + entry for the j'th dimension. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename alloc + > + std::vector > sparse_to_dense ( + const std::vector& samples + ); + /*! + requires + - all elements of samples must be sparse vectors or dense column vectors. + ensures + - converts from sparse sample vectors to dense (column matrix form) + - That is, this function returns a std::vector R such that: + - R contains column matrices + - R.size() == samples.size() + - for all valid i: + - R[i] == sparse_to_dense(samples[i], max_index_plus_one(samples)) + (i.e. the dense (i.e. dlib::matrix) version of the sparse sample + given by samples[i].) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename alloc + > + std::vector > sparse_to_dense ( + const std::vector& samples, + unsigned long num_dimensions + ); + /*! + requires + - all elements of samples must be sparse vectors or dense column vectors. + ensures + - converts from sparse sample vectors to dense (column matrix form) + - That is, this function returns a std::vector R such that: + - R contains column matrices + - R.size() == samples.size() + - for all valid i: + - R[i] == sparse_to_dense(samples[i], num_dimensions) + (i.e. the dense (i.e. dlib::matrix) version of the sparse sample + given by samples[i].) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T make_sparse_vector ( + const T& v + ); + /*! + requires + - v is an unsorted sparse vector + ensures + - returns a copy of v which is a sparse vector. + (i.e. it will be properly sorted and not have any duplicate key values but + will still logically represent the same vector). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void make_sparse_vector_inplace( + T& vect + ); + /*! + requires + - v is an unsorted sparse vector + ensures + - vect == make_sparse_vector(vect) + - This function is just an optimized version of make_sparse_vector(), in + particular, when T is a std::vector> type it is much more + efficient. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename T, + long NR, + long NC, + typename MM, + typename L + > + void sparse_matrix_vector_multiply ( + const std::vector& edges, + const matrix_exp& v, + matrix& result + ); + /*! + requires + - is_col_vector(v) == true + - max_index_plus_one(edges) <= v.size() + ensures + - Interprets edges as representing a symmetric sparse matrix M. The elements + of M are defined such that, for all valid i,j: + - M(i,j) == sum of edges[k].distance() for all k where edges[k]==sample_pair(i,j) + - This means that any element of M that doesn't have any edges associated + with it will have a value of 0. + - #result == M*v + (i.e. this function multiplies the vector v with the sparse matrix + represented by edges and stores the output into result) + - get_rect(#result) == get_rect(v) + (i.e. result will have the same dimensions as v) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename T, + long NR, + long NC, + typename MM, + typename L + > + void sparse_matrix_vector_multiply ( + const std::vector& edges, + const matrix_exp& v, + matrix& result + ); + /*! + requires + - is_col_vector(v) == true + - max_index_plus_one(edges) <= v.size() + ensures + - Interprets edges as representing a square sparse matrix M. The elements of M + are defined such that, for all valid i,j: + - M(i,j) == sum of edges[k].distance() for all k where edges[k]==ordered_sample_pair(i,j) + - This means that any element of M that doesn't have any edges associated + with it will have a value of 0. + - #result == M*v + (i.e. this function multiplies the vector v with the sparse matrix + represented by edges and stores the output into result) + - get_rect(#result) == get_rect(v) + (i.e. result will have the same dimensions as v) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + matrix sparse_matrix_vector_multiply ( + const std::vector& edges, + const matrix_exp& v + ); + /*! + requires + - is_col_vector(v) == true + - max_index_plus_one(edges) <= v.size() + ensures + - This is just a convenience routine for invoking the above + sparse_matrix_vector_multiply() routine. In particular, it just calls + sparse_matrix_vector_multiply() with a temporary result matrix and then + returns the result. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + matrix sparse_matrix_vector_multiply ( + const std::vector& edges, + const matrix_exp& v + ); + /*! + requires + - is_col_vector(v) == true + - max_index_plus_one(edges) <= v.size() + ensures + - This is just a convenience routine for invoking the above + sparse_matrix_vector_multiply() routine. In particular, it just calls + sparse_matrix_vector_multiply() with a temporary result matrix and then + returns the result. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename sparse_vector_type, + typename T, + long NR, + long NC, + typename MM, + typename L + > + void sparse_matrix_vector_multiply ( + const matrix_exp& m, + const sparse_vector_type& v, + matrix& result + ); + /*! + requires + - max_index_plus_one(v) <= m.nc() + - v == an unsorted sparse vector + ensures + - #result == m*v + (i.e. multiply m by the vector v and store the output in result) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP, + typename sparse_vector_type + > + matrix sparse_matrix_vector_multiply ( + const matrix_exp& m, + const sparse_vector_type& v + ); + /*! + requires + - max_index_plus_one(v) <= m.nc() + - v == an unsorted sparse vector + ensures + - returns m*v + (i.e. multiply m by the vector v and return the resulting vector) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_SPARSE_VECTOR_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/structural_assignment_trainer.h b/ml/dlib/dlib/svm/structural_assignment_trainer.h new file mode 100644 index 000000000..d55b74ff0 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_assignment_trainer.h @@ -0,0 +1,294 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_ +#define DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_ + +#include "structural_assignment_trainer_abstract.h" +#include "../algs.h" +#include "../optimization.h" +#include "structural_svm_assignment_problem.h" +#include "num_nonnegative_weights.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class structural_assignment_trainer + { + public: + typedef typename feature_extractor::lhs_element lhs_element; + typedef typename feature_extractor::rhs_element rhs_element; + typedef std::pair, std::vector > sample_type; + typedef std::vector label_type; + typedef assignment_function trained_function_type; + + structural_assignment_trainer ( + ) + { + set_defaults(); + } + + explicit structural_assignment_trainer ( + const feature_extractor& fe_ + ) : fe(fe_) + { + set_defaults(); + } + + const feature_extractor& get_feature_extractor ( + ) const { return fe; } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void structural_assignment_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + double get_epsilon ( + ) const { return eps; } + + void set_max_cache_size ( + unsigned long max_size + ) + { + max_cache_size = max_size; + } + + unsigned long get_max_cache_size ( + ) const + { + return max_cache_size; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + void set_c ( + double C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void structural_assignment_trainer::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + double get_c ( + ) const + { + return C; + } + + bool forces_assignment( + ) const { return force_assignment; } + + void set_forces_assignment ( + bool new_value + ) + { + force_assignment = new_value; + } + + void set_loss_per_false_association ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_assignment_trainer::set_loss_per_false_association(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_false_association = loss; + } + + double get_loss_per_false_association ( + ) const + { + return loss_per_false_association; + } + + void set_loss_per_missed_association ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_assignment_trainer::set_loss_per_missed_association(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_missed_association = loss; + } + + double get_loss_per_missed_association ( + ) const + { + return loss_per_missed_association; + } + + bool forces_last_weight_to_1 ( + ) const + { + return last_weight_1; + } + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ) + { + last_weight_1 = should_last_weight_be_1; + } + + const assignment_function train ( + const std::vector& samples, + const std::vector& labels + ) const + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + if (force_assignment) + { + DLIB_ASSERT(is_forced_assignment_problem(samples, labels), + "\t assignment_function structural_assignment_trainer::train()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels) + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } + else + { + DLIB_ASSERT(is_assignment_problem(samples, labels), + "\t assignment_function structural_assignment_trainer::train()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } +#endif + + + + structural_svm_assignment_problem prob(samples,labels, fe, force_assignment, num_threads, + loss_per_false_association, loss_per_missed_association); + + if (verbose) + prob.be_verbose(); + + prob.set_c(C); + prob.set_epsilon(eps); + prob.set_max_cache_size(max_cache_size); + + matrix weights; + + // Take the min here because we want to prevent the user from accidentally + // forcing the bias term to be non-negative. + const unsigned long num_nonneg = std::min(fe.num_features(),num_nonnegative_weights(fe)); + if (last_weight_1) + solver(prob, weights, num_nonneg, fe.num_features()-1); + else + solver(prob, weights, num_nonneg); + + const double bias = weights(weights.size()-1); + return assignment_function(colm(weights,0,weights.size()-1), bias,fe,force_assignment); + + } + + + private: + + bool force_assignment; + double C; + oca solver; + double eps; + bool verbose; + unsigned long num_threads; + unsigned long max_cache_size; + double loss_per_false_association; + double loss_per_missed_association; + bool last_weight_1; + + void set_defaults () + { + force_assignment = false; + C = 100; + verbose = false; + eps = 0.01; + num_threads = 2; + max_cache_size = 5; + loss_per_false_association = 1; + loss_per_missed_association = 1; + last_weight_1 = false; + } + + feature_extractor fe; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_ + + + + diff --git a/ml/dlib/dlib/svm/structural_assignment_trainer_abstract.h b/ml/dlib/dlib/svm/structural_assignment_trainer_abstract.h new file mode 100644 index 000000000..ebd402d42 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_assignment_trainer_abstract.h @@ -0,0 +1,299 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_ABSTRACT_Hh_ + +#include "../algs.h" +#include "structural_svm_assignment_problem.h" +#include "assignment_function_abstract.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class structural_assignment_trainer + { + /*! + REQUIREMENTS ON feature_extractor + It must be an object that implements an interface compatible with + the example_feature_extractor defined in dlib/svm/assignment_function_abstract.h. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning to solve an assignment problem based + on a training dataset of example assignments. The training procedure produces an + assignment_function object which can be used to predict the assignments of + new data. + + Note that this is just a convenience wrapper around the + structural_svm_assignment_problem to make it look + similar to all the other trainers in dlib. + !*/ + + public: + typedef typename feature_extractor::lhs_element lhs_element; + typedef typename feature_extractor::rhs_element rhs_element; + typedef std::pair, std::vector > sample_type; + typedef std::vector label_type; + typedef assignment_function trained_function_type; + + structural_assignment_trainer ( + ); + /*! + ensures + - #get_c() == 100 + - this object isn't verbose + - #get_epsilon() == 0.01 + - #get_num_threads() == 2 + - #get_max_cache_size() == 5 + - #get_feature_extractor() == a default initialized feature_extractor + - #forces_assignment() == false + - #get_loss_per_false_association() == 1 + - #get_loss_per_missed_association() == 1 + - #forces_last_weight_to_1() == false + !*/ + + explicit structural_assignment_trainer ( + const feature_extractor& fe + ); + /*! + ensures + - #get_c() == 100 + - this object isn't verbose + - #get_epsilon() == 0.01 + - #get_num_threads() == 2 + - #get_max_cache_size() == 40 + - #get_feature_extractor() == fe + - #forces_assignment() == false + - #get_loss_per_false_association() == 1 + - #get_loss_per_missed_association() == 1 + - #forces_last_weight_to_1() == false + !*/ + + const feature_extractor& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used by this object + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + ensures + - #get_num_threads() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - returns the number of threads used during training. You should + usually set this equal to the number of processing cores on your + machine. + !*/ + + void set_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer + to train. You can think of this epsilon value as saying "solve the + optimization problem until the average number of assignment mistakes per + training sample is within epsilon of its optimal value". + !*/ + + void set_max_cache_size ( + unsigned long max_size + ); + /*! + ensures + - #get_max_cache_size() == max_size + !*/ + + unsigned long get_max_cache_size ( + ) const; + /*! + ensures + - During training, this object basically runs the assignment_function on + each training sample, over and over. To speed this up, it is possible to + cache the results of these invocations. This function returns the number + of cache elements per training sample kept in the cache. Note that a value + of 0 means caching is not used at all. + !*/ + + void set_loss_per_false_association ( + double loss + ); + /*! + requires + - loss > 0 + ensures + - #get_loss_per_false_association() == loss + !*/ + + double get_loss_per_false_association ( + ) const; + /*! + ensures + - returns the amount of loss experienced for associating two objects + together that shouldn't be associated. If you care more about avoiding + accidental associations than ensuring all possible associations are + identified then then you can increase this value. + !*/ + + void set_loss_per_missed_association ( + double loss + ); + /*! + requires + - loss > 0 + ensures + - #get_loss_per_missed_association() == loss + !*/ + + double get_loss_per_missed_association ( + ) const; + /*! + ensures + - returns the amount of loss experienced for failing to associate two + objects that are supposed to be associated. If you care more about + getting all the associations than avoiding accidentally associating + objects that shouldn't be associated then you can increase this value. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the structural SVM problem. + !*/ + + void set_c ( + double C + ); + /*! + requires + - C > 0 + ensures + - #get_c() = C + !*/ + + double get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter + that determines the trade-off between trying to fit the training + data (i.e. minimize the loss) or allowing more errors but hopefully + improving the generalization of the resulting assignment_function. + Larger values encourage exact fitting while smaller values of C may + encourage better generalization. + !*/ + + void set_forces_assignment ( + bool new_value + ); + /*! + ensures + - #forces_assignment() == new_value + !*/ + + bool forces_assignment( + ) const; + /*! + ensures + - returns the value of the forces_assignment() parameter for the + assignment_functions generated by this object. + !*/ + + bool forces_last_weight_to_1 ( + ) const; + /*! + ensures + - returns true if this trainer has the constraint that the last weight in + the learned parameter vector must be 1. This is the weight corresponding + to the feature in the training vectors with the highest dimension. + - Forcing the last weight to 1 also disables the bias and therefore the + get_bias() field of the learned assignment_function will be 0 when + forces_last_weight_to_1() == true. + !*/ + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ); + /*! + ensures + - #forces_last_weight_to_1() == should_last_weight_be_1 + !*/ + + const assignment_function train ( + const std::vector& samples, + const std::vector& labels + ) const; + /*! + requires + - is_assignment_problem(samples,labels) == true + - if (forces_assignment()) then + - is_forced_assignment_problem(samples,labels) == true + ensures + - Uses the structural_svm_assignment_problem to train an + assignment_function on the given samples/labels training pairs. + The idea is to learn to predict a label given an input sample. + - returns a function F with the following properties: + - F(new_sample) == A set of assignments indicating how the elements of + new_sample.first match up with the elements of new_sample.second. + - F.forces_assignment() == forces_assignment() + - F.get_feature_extractor() == get_feature_extractor() + - if (forces_last_weight_to_1()) then + - F.get_bias() == 0 + - F.get_weights()(F.get_weights().size()-1) == 1 + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/structural_graph_labeling_trainer.h b/ml/dlib/dlib/svm/structural_graph_labeling_trainer.h new file mode 100644 index 000000000..4d55c772b --- /dev/null +++ b/ml/dlib/dlib/svm/structural_graph_labeling_trainer.h @@ -0,0 +1,282 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_Hh_ +#define DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_Hh_ + +#include "structural_graph_labeling_trainer_abstract.h" +#include "../algs.h" +#include "../optimization.h" +#include "structural_svm_graph_labeling_problem.h" +#include "../graph_cuts/graph_labeler.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + class structural_graph_labeling_trainer + { + public: + typedef std::vector label_type; + typedef graph_labeler trained_function_type; + + structural_graph_labeling_trainer ( + ) + { + C = 10; + verbose = false; + eps = 0.1; + num_threads = 2; + max_cache_size = 5; + loss_pos = 1.0; + loss_neg = 1.0; + } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void structural_graph_labeling_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + double get_epsilon ( + ) const { return eps; } + + void set_max_cache_size ( + unsigned long max_size + ) + { + max_cache_size = max_size; + } + + unsigned long get_max_cache_size ( + ) const + { + return max_cache_size; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + void set_c ( + double C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void structural_graph_labeling_trainer::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + double get_c ( + ) const + { + return C; + } + + + void set_loss_on_positive_class ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss >= 0, + "\t structural_graph_labeling_trainer::set_loss_on_positive_class()" + << "\n\t Invalid inputs were given to this function." + << "\n\t loss: " << loss + << "\n\t this: " << this ); + + loss_pos = loss; + } + + void set_loss_on_negative_class ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss >= 0, + "\t structural_graph_labeling_trainer::set_loss_on_negative_class()" + << "\n\t Invalid inputs were given to this function." + << "\n\t loss: " << loss + << "\n\t this: " << this ); + + loss_neg = loss; + } + + double get_loss_on_negative_class ( + ) const { return loss_neg; } + + double get_loss_on_positive_class ( + ) const { return loss_pos; } + + + template < + typename graph_type + > + const graph_labeler train ( + const dlib::array& samples, + const std::vector& labels, + const std::vector >& losses + ) const + { +#ifdef ENABLE_ASSERTS + std::string reason_for_failure; + DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true , + "\t void structural_graph_labeling_trainer::train()" + << "\n\t Invalid inputs were given to this function." + << "\n\t reason_for_failure: " << reason_for_failure + << "\n\t samples.size(): " << samples.size() + << "\n\t labels.size(): " << labels.size() + << "\n\t this: " << this ); + DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) && + all_values_are_nonnegative(losses) == true, + "\t void structural_graph_labeling_trainer::train()" + << "\n\t Invalid inputs were given to this function." + << "\n\t labels.size(): " << labels.size() + << "\n\t losses.size(): " << losses.size() + << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses) + << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses) + << "\n\t this: " << this ); +#endif + + + structural_svm_graph_labeling_problem prob(samples, labels, losses, num_threads); + + if (verbose) + prob.be_verbose(); + + prob.set_c(C); + prob.set_epsilon(eps); + prob.set_max_cache_size(max_cache_size); + if (prob.get_losses().size() == 0) + { + prob.set_loss_on_positive_class(loss_pos); + prob.set_loss_on_negative_class(loss_neg); + } + + matrix w; + solver(prob, w, prob.get_num_edge_weights()); + + vector_type edge_weights; + vector_type node_weights; + populate_weights(w, edge_weights, node_weights, prob.get_num_edge_weights()); + return graph_labeler(edge_weights, node_weights); + } + + template < + typename graph_type + > + const graph_labeler train ( + const dlib::array& samples, + const std::vector& labels + ) const + { + std::vector > losses; + return train(samples, labels, losses); + } + + private: + + template + typename enable_if >::type populate_weights ( + const matrix& w, + T& edge_weights, + T& node_weights, + long split_idx + ) const + { + edge_weights = rowm(w,range(0, split_idx-1)); + node_weights = rowm(w,range(split_idx,w.size()-1)); + } + + template + typename disable_if >::type populate_weights ( + const matrix& w, + T& edge_weights, + T& node_weights, + long split_idx + ) const + { + edge_weights.clear(); + node_weights.clear(); + for (long i = 0; i < split_idx; ++i) + { + if (w(i) != 0) + edge_weights.insert(edge_weights.end(), std::make_pair(i,w(i))); + } + for (long i = split_idx; i < w.size(); ++i) + { + if (w(i) != 0) + node_weights.insert(node_weights.end(), std::make_pair(i-split_idx,w(i))); + } + } + + + double C; + oca solver; + double eps; + bool verbose; + unsigned long num_threads; + unsigned long max_cache_size; + double loss_pos; + double loss_neg; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/structural_graph_labeling_trainer_abstract.h b/ml/dlib/dlib/svm/structural_graph_labeling_trainer_abstract.h new file mode 100644 index 000000000..df88096a0 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_graph_labeling_trainer_abstract.h @@ -0,0 +1,265 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_ABSTRACT_Hh_ + +#include "../algs.h" +#include "../optimization.h" +#include "structural_svm_graph_labeling_problem_abstract.h" +#include "../graph_cuts/graph_labeler_abstract.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename vector_type + > + class structural_graph_labeling_trainer + { + /*! + REQUIREMENTS ON vector_type + - vector_type is a dlib::matrix capable of representing column + vectors or it is a sparse vector type as defined in dlib/svm/sparse_vector_abstract.h. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning to solve a graph labeling problem based + on a training dataset of example labeled graphs. The training procedure + produces a graph_labeler object which can be used to predict the labelings + of new graphs. + + Note that this is just a convenience wrapper around the + structural_svm_graph_labeling_problem to make it look + similar to all the other trainers in dlib. + !*/ + + public: + typedef std::vector label_type; + typedef graph_labeler trained_function_type; + + structural_graph_labeling_trainer ( + ); + /*! + ensures + - #get_c() == 10 + - this object isn't verbose + - #get_epsilon() == 0.1 + - #get_num_threads() == 2 + - #get_max_cache_size() == 5 + - #get_loss_on_positive_class() == 1.0 + - #get_loss_on_negative_class() == 1.0 + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + ensures + - #get_num_threads() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - returns the number of threads used during training. You should + usually set this equal to the number of processing cores on your + machine. + !*/ + + void set_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer + to train. You can think of this epsilon value as saying "solve the + optimization problem until the average number of labeling mistakes per + example graph is within epsilon of its optimal value". + !*/ + + void set_max_cache_size ( + unsigned long max_size + ); + /*! + ensures + - #get_max_cache_size() == max_size + !*/ + + unsigned long get_max_cache_size ( + ) const; + /*! + ensures + - During training, this object basically runs the graph_labeler on each + training sample, over and over. To speed this up, it is possible to + cache the results of these invocations. This function returns the number + of cache elements per training sample kept in the cache. Note that a value + of 0 means caching is not used at all. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the structural SVM problem. + !*/ + + void set_c ( + double C + ); + /*! + requires + - C > 0 + ensures + - #get_c() = C + !*/ + + double get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter + that determines the trade-off between trying to fit the training + data (i.e. minimize the loss) or allowing more errors but hopefully + improving the generalization of the resulting graph_labeler. Larger + values encourage exact fitting while smaller values of C may encourage + better generalization. + !*/ + + void set_loss_on_positive_class ( + double loss + ); + /*! + requires + - loss >= 0 + ensures + - #get_loss_on_positive_class() == loss + !*/ + + void set_loss_on_negative_class ( + double loss + ); + /*! + requires + - loss >= 0 + ensures + - #get_loss_on_negative_class() == loss + !*/ + + double get_loss_on_positive_class ( + ) const; + /*! + ensures + - returns the loss incurred when a graph node which is supposed to have + a label of true gets misclassified. This value controls how much we care + about correctly classifying nodes which should be labeled as true. Larger + loss values indicate that we care more strongly than smaller values. + !*/ + + double get_loss_on_negative_class ( + ) const; + /*! + ensures + - returns the loss incurred when a graph node which is supposed to have + a label of false gets misclassified. This value controls how much we care + about correctly classifying nodes which should be labeled as false. Larger + loss values indicate that we care more strongly than smaller values. + !*/ + + template < + typename graph_type + > + const graph_labeler train ( + const dlib::array& samples, + const std::vector& labels + ) const; + /*! + requires + - is_graph_labeling_problem(samples,labels) == true + ensures + - Uses the structural_svm_graph_labeling_problem to train a graph_labeler + on the given samples/labels training pairs. The idea is to learn to + predict a label given an input sample. + - The values of get_loss_on_positive_class() and get_loss_on_negative_class() + are used to determine how to value mistakes on each node during training. + - returns a function F with the following properties: + - F(new_sample) == The predicted labels for the nodes in the graph + new_sample. + !*/ + + template < + typename graph_type + > + const graph_labeler train ( + const dlib::array& samples, + const std::vector& labels, + const std::vector >& losses + ) const; + /*! + requires + - is_graph_labeling_problem(samples,labels) == true + - if (losses.size() != 0) then + - sizes_match(labels, losses) == true + - all_values_are_nonnegative(losses) == true + ensures + - Uses the structural_svm_graph_labeling_problem to train a graph_labeler + on the given samples/labels training pairs. The idea is to learn to + predict a label given an input sample. + - returns a function F with the following properties: + - F(new_sample) == The predicted labels for the nodes in the graph + new_sample. + - if (losses.size() == 0) then + - The values of get_loss_on_positive_class() and get_loss_on_negative_class() + are used to determine how to value mistakes on each node during training. + - The losses argument is effectively ignored if its size is zero. + - else + - Each node in the training data has its own loss value defined by the + corresponding entry of losses. In particular, this means that the + node with label labels[i][j] incurs a loss of losses[i][j] if it is + incorrectly labeled. + - The get_loss_on_positive_class() and get_loss_on_negative_class() + parameters are ignored. Only losses is used in this case. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_GRAPH_LABELING_tRAINER_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/structural_object_detection_trainer.h b/ml/dlib/dlib/svm/structural_object_detection_trainer.h new file mode 100644 index 000000000..bdf8c5b5c --- /dev/null +++ b/ml/dlib/dlib/svm/structural_object_detection_trainer.h @@ -0,0 +1,402 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_ +#define DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_ + +#include "structural_object_detection_trainer_abstract.h" +#include "../algs.h" +#include "../optimization.h" +#include "structural_svm_object_detection_problem.h" +#include "../image_processing/object_detector.h" +#include "../image_processing/box_overlap_testing.h" +#include "../image_processing/full_object_detection.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type, + typename svm_struct_prob_type + > + void configure_nuclear_norm_regularizer ( + const image_scanner_type&, + svm_struct_prob_type& + ) + { + // does nothing by default. Specific scanner types overload this function to do + // whatever is appropriate. + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + class structural_object_detection_trainer : noncopyable + { + + public: + typedef double scalar_type; + typedef default_memory_manager mem_manager_type; + typedef object_detector trained_function_type; + + + explicit structural_object_detection_trainer ( + const image_scanner_type& scanner_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(scanner_.get_num_detection_templates() > 0, + "\t structural_object_detection_trainer::structural_object_detection_trainer(scanner_)" + << "\n\t You can't have zero detection templates" + << "\n\t this: " << this + ); + + C = 1; + verbose = false; + eps = 0.1; + num_threads = 2; + max_cache_size = 5; + match_eps = 0.5; + loss_per_missed_target = 1; + loss_per_false_alarm = 1; + + scanner.copy_configuration(scanner_); + + auto_overlap_tester = true; + } + + const image_scanner_type& get_scanner ( + ) const + { + return scanner; + } + + bool auto_set_overlap_tester ( + ) const + { + return auto_overlap_tester; + } + + void set_overlap_tester ( + const test_box_overlap& tester + ) + { + overlap_tester = tester; + auto_overlap_tester = false; + } + + test_box_overlap get_overlap_tester ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(auto_set_overlap_tester() == false, + "\t test_box_overlap structural_object_detection_trainer::get_overlap_tester()" + << "\n\t You can't call this function if the overlap tester is generated dynamically." + << "\n\t this: " << this + ); + + return overlap_tester; + } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void structural_object_detection_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + scalar_type get_epsilon ( + ) const { return eps; } + + void set_max_cache_size ( + unsigned long max_size + ) + { + max_cache_size = max_size; + } + + unsigned long get_max_cache_size ( + ) const + { + return max_cache_size; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + void set_c ( + scalar_type C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void structural_object_detection_trainer::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + scalar_type get_c ( + ) const + { + return C; + } + + void set_match_eps ( + double eps + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < eps && eps < 1, + "\t void structural_object_detection_trainer::set_match_eps(eps)" + << "\n\t Invalid inputs were given to this function " + << "\n\t eps: " << eps + << "\n\t this: " << this + ); + + match_eps = eps; + } + + double get_match_eps ( + ) const + { + return match_eps; + } + + double get_loss_per_missed_target ( + ) const + { + return loss_per_missed_target; + } + + void set_loss_per_missed_target ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_object_detection_trainer::set_loss_per_missed_target(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_missed_target = loss; + } + + double get_loss_per_false_alarm ( + ) const + { + return loss_per_false_alarm; + } + + void set_loss_per_false_alarm ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_object_detection_trainer::set_loss_per_false_alarm(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_false_alarm = loss; + } + + template < + typename image_array_type + > + const trained_function_type train ( + const image_array_type& images, + const std::vector >& truth_object_detections + ) const + { + std::vector > empty_ignore(images.size()); + return train_impl(images, truth_object_detections, empty_ignore, test_box_overlap()); + } + + template < + typename image_array_type + > + const trained_function_type train ( + const image_array_type& images, + const std::vector >& truth_object_detections, + const std::vector >& ignore, + const test_box_overlap& ignore_overlap_tester = test_box_overlap() + ) const + { + return train_impl(images, truth_object_detections, ignore, ignore_overlap_tester); + } + + template < + typename image_array_type + > + const trained_function_type train ( + const image_array_type& images, + const std::vector >& truth_object_detections + ) const + { + std::vector > empty_ignore(images.size()); + return train(images, truth_object_detections, empty_ignore, test_box_overlap()); + } + + template < + typename image_array_type + > + const trained_function_type train ( + const image_array_type& images, + const std::vector >& truth_object_detections, + const std::vector >& ignore, + const test_box_overlap& ignore_overlap_tester = test_box_overlap() + ) const + { + std::vector > truth_dets(truth_object_detections.size()); + for (unsigned long i = 0; i < truth_object_detections.size(); ++i) + { + for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) + { + truth_dets[i].push_back(full_object_detection(truth_object_detections[i][j])); + } + } + + return train_impl(images, truth_dets, ignore, ignore_overlap_tester); + } + + private: + + template < + typename image_array_type + > + const trained_function_type train_impl ( + const image_array_type& images, + const std::vector >& truth_object_detections, + const std::vector >& ignore, + const test_box_overlap& ignore_overlap_tester + ) const + { +#ifdef ENABLE_ASSERTS + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(images,truth_object_detections) == true && images.size() == ignore.size(), + "\t trained_function_type structural_object_detection_trainer::train()" + << "\n\t invalid inputs were given to this function" + << "\n\t images.size(): " << images.size() + << "\n\t ignore.size(): " << ignore.size() + << "\n\t truth_object_detections.size(): " << truth_object_detections.size() + << "\n\t is_learning_problem(images,truth_object_detections): " << is_learning_problem(images,truth_object_detections) + ); + for (unsigned long i = 0; i < truth_object_detections.size(); ++i) + { + for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) + { + DLIB_ASSERT(truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template() && + all_parts_in_rect(truth_object_detections[i][j]) == true, + "\t trained_function_type structural_object_detection_trainer::train()" + << "\n\t invalid inputs were given to this function" + << "\n\t truth_object_detections["< + svm_prob(scanner, overlap_tester, auto_overlap_tester, images, + truth_object_detections, ignore, ignore_overlap_tester, num_threads); + + if (verbose) + svm_prob.be_verbose(); + + svm_prob.set_c(C); + svm_prob.set_epsilon(eps); + svm_prob.set_max_cache_size(max_cache_size); + svm_prob.set_match_eps(match_eps); + svm_prob.set_loss_per_missed_target(loss_per_missed_target); + svm_prob.set_loss_per_false_alarm(loss_per_false_alarm); + configure_nuclear_norm_regularizer(scanner, svm_prob); + matrix w; + + // Run the optimizer to find the optimal w. + solver(svm_prob,w); + + // report the results of the training. + return object_detector(scanner, svm_prob.get_overlap_tester(), w); + } + + image_scanner_type scanner; + test_box_overlap overlap_tester; + + double C; + oca solver; + double eps; + double match_eps; + bool verbose; + unsigned long num_threads; + unsigned long max_cache_size; + double loss_per_missed_target; + double loss_per_false_alarm; + bool auto_overlap_tester; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_Hh_ + + diff --git a/ml/dlib/dlib/svm/structural_object_detection_trainer_abstract.h b/ml/dlib/dlib/svm/structural_object_detection_trainer_abstract.h new file mode 100644 index 000000000..2dd799874 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_object_detection_trainer_abstract.h @@ -0,0 +1,390 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H_ABSTRACTh_ +#ifdef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H_ABSTRACTh_ + +#include "structural_svm_object_detection_problem_abstract.h" +#include "../image_processing/object_detector_abstract.h" +#include "../image_processing/box_overlap_testing_abstract.h" +#include "../image_processing/full_object_detection_abstract.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type + > + class structural_object_detection_trainer : noncopyable + { + /*! + REQUIREMENTS ON image_scanner_type + image_scanner_type must be an implementation of + dlib/image_processing/scan_fhog_pyramid_abstract.h or + dlib/image_processing/scan_image_custom_abstract.h or + dlib/image_processing/scan_image_pyramid_abstract.h or + dlib/image_processing/scan_image_boxes_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning to detect objects in images based on a + set of labeled images. The training procedure produces an object_detector + which can be used to predict the locations of objects in new images. + + Note that this is just a convenience wrapper around the structural_svm_object_detection_problem + to make it look similar to all the other trainers in dlib. + !*/ + + public: + typedef double scalar_type; + typedef default_memory_manager mem_manager_type; + typedef object_detector trained_function_type; + + + explicit structural_object_detection_trainer ( + const image_scanner_type& scanner + ); + /*! + requires + - scanner.get_num_detection_templates() > 0 + ensures + - #get_c() == 1 + - this object isn't verbose + - #get_epsilon() == 0.1 + - #get_num_threads() == 2 + - #get_max_cache_size() == 5 + - #get_match_eps() == 0.5 + - #get_loss_per_missed_target() == 1 + - #get_loss_per_false_alarm() == 1 + - This object will attempt to learn a model for the given + scanner object when train() is called. + - #get_scanner() == scanner + (note that only the "configuration" of scanner is copied. + I.e. the copy is done using copy_configuration()) + - #auto_set_overlap_tester() == true + !*/ + + const image_scanner_type& get_scanner ( + ) const; + /*! + ensures + - returns the image scanner used by this object. + !*/ + + bool auto_set_overlap_tester ( + ) const; + /*! + ensures + - if (this object will automatically determine an appropriate + state for the overlap tester used for non-max suppression.) then + - returns true + - In this case, it is determined using the find_tight_overlap_tester() + routine based on the truth_object_detections given to the + structural_object_detection_trainer::train() method. + - else + - returns false + !*/ + + void set_overlap_tester ( + const test_box_overlap& tester + ); + /*! + ensures + - #get_overlap_tester() == tester + - #auto_set_overlap_tester() == false + !*/ + + test_box_overlap get_overlap_tester ( + ) const; + /*! + requires + - auto_set_overlap_tester() == false + ensures + - returns the overlap tester object which will be used to perform non-max suppression. + In particular, this function returns the overlap tester which will populate the + object_detector returned by train(). + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + ensures + - #get_num_threads() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - returns the number of threads used during training. You should + usually set this equal to the number of processing cores on your + machine. + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer + to train. You can think of this epsilon value as saying "solve the + optimization problem until the average loss per sample is within epsilon + of its optimal value". + !*/ + + void set_max_cache_size ( + unsigned long max_size + ); + /*! + ensures + - #get_max_cache_size() == max_size + !*/ + + unsigned long get_max_cache_size ( + ) const; + /*! + ensures + - During training, this object basically runs the object detector on + each image, over and over. To speed this up, it is possible to cache + the results of these detector invocations. This function returns the + number of cache elements per training sample kept in the cache. Note + that a value of 0 means caching is not used at all. Note also that + each cache element takes up about sizeof(double)*scanner.get_num_dimensions() + memory (where scanner is the scanner given to this object's constructor). + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the structural SVM problem. + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c() = C + !*/ + + const scalar_type get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter + that determines the trade-off between trying to fit the training + data (i.e. minimize the loss) or allowing more errors but hopefully + improving the generalization of the resulting detector. Larger + values encourage exact fitting while smaller values of C may encourage + better generalization. + !*/ + + void set_match_eps ( + double eps + ); + /*! + requires + - 0 < eps < 1 + ensures + - #get_match_eps() == eps + !*/ + + double get_match_eps ( + ) const; + /*! + ensures + - returns the amount of alignment necessary for a detection to be considered + as matching with a ground truth rectangle. If it doesn't match then + it is considered to be a false alarm. To define this precisely, let + A and B be two rectangles, then A and B match if and only if: + A.intersect(B).area()/(A+B).area() > get_match_eps() + !*/ + + double get_loss_per_missed_target ( + ) const; + /*! + ensures + - returns the amount of loss experienced for failing to detect one of the + targets. If you care more about finding targets than having a low false + alarm rate then you can increase this value. + !*/ + + void set_loss_per_missed_target ( + double loss + ); + /*! + requires + - loss > 0 + ensures + - #get_loss_per_missed_target() == loss + !*/ + + double get_loss_per_false_alarm ( + ) const; + /*! + ensures + - returns the amount of loss experienced for emitting a false alarm detection. + Or in other words, the loss for generating a detection that doesn't correspond + to one of the truth rectangles. If you care more about having a low false + alarm rate than finding all the targets then you can increase this value. + !*/ + + void set_loss_per_false_alarm ( + double loss + ); + /*! + requires + - loss > 0 + ensures + - #get_loss_per_false_alarm() == loss + !*/ + + template < + typename image_array_type + > + const trained_function_type train ( + const image_array_type& images, + const std::vector >& truth_object_detections + ) const; + /*! + requires + - is_learning_problem(images, truth_object_detections) == true + - it must be valid to pass images[0] into the image_scanner_type::load() method. + (also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h) + - for all valid i, j: + - truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template() + - all_parts_in_rect(truth_object_detections[i][j]) == true + ensures + - Uses the structural_svm_object_detection_problem to train an object_detector + on the given images and truth_object_detections. + - returns a function F with the following properties: + - F(new_image) == A prediction of what objects are present in new_image. This + is a set of rectangles indicating their positions. + !*/ + + template < + typename image_array_type + > + const trained_function_type train ( + const image_array_type& images, + const std::vector >& truth_object_detections + ) const; + /*! + requires + - is_learning_problem(images, truth_object_detections) == true + - it must be valid to pass images[0] into the image_scanner_type::load() method. + (also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h) + - get_scanner().get_num_movable_components_per_detection_template() == 0 + ensures + - This function is identical to the above train(), except that it converts + each element of truth_object_detections into a full_object_detection by + passing it to full_object_detection's constructor taking only a rectangle. + Therefore, this version of train() is a convenience function for for the + case where you don't have any movable components of the detection templates. + !*/ + + template < + typename image_array_type + > + const trained_function_type train ( + const image_array_type& images, + const std::vector >& truth_object_detections, + const std::vector >& ignore, + const test_box_overlap& ignore_overlap_tester = test_box_overlap() + ) const; + /*! + requires + - is_learning_problem(images, truth_object_detections) == true + - it must be valid to pass images[0] into the image_scanner_type::load() method. + (also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h) + - ignore.size() == images.size() + - for all valid i, j: + - truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template() + - all_parts_in_rect(truth_object_detections[i][j]) == true + ensures + - Uses the structural_svm_object_detection_problem to train an object_detector + on the given images and truth_object_detections. + - for all valid i: + - Within images[i] any detections that match against a rectangle in + ignore[i], according to ignore_overlap_tester, are ignored. That is, + the optimizer doesn't care if the detector outputs a detection that + matches any of the ignore rectangles or if it fails to output a + detection for an ignore rectangle. Therefore, if there are objects + in your dataset that you are unsure if you want to detect or otherwise + don't care if the detector gets or doesn't then you can mark them + with ignore rectangles and the optimizer will simply ignore them. + - returns a function F with the following properties: + - F(new_image) == A prediction of what objects are present in new_image. This + is a set of rectangles indicating their positions. + !*/ + + template < + typename image_array_type + > + const trained_function_type train ( + const image_array_type& images, + const std::vector >& truth_object_detections, + const std::vector >& ignore, + const test_box_overlap& ignore_overlap_tester = test_box_overlap() + ) const; + /*! + requires + - is_learning_problem(images, truth_object_detections) == true + - ignore.size() == images.size() + - it must be valid to pass images[0] into the image_scanner_type::load() method. + (also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h) + - get_scanner().get_num_movable_components_per_detection_template() == 0 + ensures + - This function is identical to the above train(), except that it converts + each element of truth_object_detections into a full_object_detection by + passing it to full_object_detection's constructor taking only a rectangle. + Therefore, this version of train() is a convenience function for for the + case where you don't have any movable components of the detection templates. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H_ABSTRACTh_ + + diff --git a/ml/dlib/dlib/svm/structural_sequence_labeling_trainer.h b/ml/dlib/dlib/svm/structural_sequence_labeling_trainer.h new file mode 100644 index 000000000..9b61fd6c2 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_sequence_labeling_trainer.h @@ -0,0 +1,271 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_ +#define DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_ + +#include "structural_sequence_labeling_trainer_abstract.h" +#include "../algs.h" +#include "../optimization.h" +#include "structural_svm_sequence_labeling_problem.h" +#include "num_nonnegative_weights.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class structural_sequence_labeling_trainer + { + public: + typedef typename feature_extractor::sequence_type sample_sequence_type; + typedef std::vector labeled_sequence_type; + + typedef sequence_labeler trained_function_type; + + explicit structural_sequence_labeling_trainer ( + const feature_extractor& fe_ + ) : fe(fe_) + { + set_defaults(); + } + + structural_sequence_labeling_trainer ( + ) + { + set_defaults(); + } + + const feature_extractor& get_feature_extractor ( + ) const { return fe; } + + unsigned long num_labels ( + ) const { return fe.num_labels(); } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void structural_sequence_labeling_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + double get_epsilon ( + ) const { return eps; } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void set_max_cache_size ( + unsigned long max_size + ) + { + max_cache_size = max_size; + } + + unsigned long get_max_cache_size ( + ) const + { + return max_cache_size; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + void set_c ( + double C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void structural_sequence_labeling_trainer::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + double get_c ( + ) const + { + return C; + } + + double get_loss ( + unsigned long label + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(label < num_labels(), + "\t void structural_sequence_labeling_trainer::get_loss()" + << "\n\t invalid inputs were given to this function" + << "\n\t label: " << label + << "\n\t num_labels(): " << num_labels() + << "\n\t this: " << this + ); + + return loss_values[label]; + } + + void set_loss ( + unsigned long label, + double value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(label < num_labels() && value >= 0, + "\t void structural_sequence_labeling_trainer::set_loss()" + << "\n\t invalid inputs were given to this function" + << "\n\t label: " << label + << "\n\t num_labels(): " << num_labels() + << "\n\t value: " << value + << "\n\t this: " << this + ); + + loss_values[label] = value; + } + + + const sequence_labeler train( + const std::vector& x, + const std::vector& y + ) const + { + + // make sure requires clause is not broken + DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true && + contains_invalid_labeling(get_feature_extractor(), x, y) == false, + "\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.size(): " << x.size() + << "\n\t is_sequence_labeling_problem(x,y): " << is_sequence_labeling_problem(x,y) + << "\n\t contains_invalid_labeling(get_feature_extractor(),x,y): " << contains_invalid_labeling(get_feature_extractor(),x,y) + << "\n\t this: " << this + ); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < y.size(); ++i) + { + for (unsigned long j = 0; j < y[i].size(); ++j) + { + // make sure requires clause is not broken + DLIB_ASSERT(y[i][j] < num_labels(), + "\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)" + << "\n\t The given labels in y are invalid." + << "\n\t y[i][j]: " << y[i][j] + << "\n\t num_labels(): " << num_labels() + << "\n\t i: " << i + << "\n\t j: " << j + << "\n\t this: " << this + ); + } + } +#endif + + + + + structural_svm_sequence_labeling_problem prob(x, y, fe, num_threads); + matrix weights; + if (verbose) + prob.be_verbose(); + + prob.set_epsilon(eps); + prob.set_max_iterations(max_iterations); + prob.set_c(C); + prob.set_max_cache_size(max_cache_size); + for (unsigned long i = 0; i < loss_values.size(); ++i) + prob.set_loss(i,loss_values[i]); + + solver(prob, weights, num_nonnegative_weights(fe)); + + return sequence_labeler(weights,fe); + } + + private: + + double C; + oca solver; + double eps; + unsigned long max_iterations; + bool verbose; + unsigned long num_threads; + unsigned long max_cache_size; + std::vector loss_values; + + void set_defaults () + { + C = 100; + verbose = false; + eps = 0.1; + max_iterations = 10000; + num_threads = 2; + max_cache_size = 5; + loss_values.assign(num_labels(), 1); + } + + feature_extractor fe; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_Hh_ + + + diff --git a/ml/dlib/dlib/svm/structural_sequence_labeling_trainer_abstract.h b/ml/dlib/dlib/svm/structural_sequence_labeling_trainer_abstract.h new file mode 100644 index 000000000..43e5f5131 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_sequence_labeling_trainer_abstract.h @@ -0,0 +1,266 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_ABSTRACT_Hh_ + +#include "../algs.h" +#include "../optimization.h" +#include "structural_svm_sequence_labeling_problem_abstract.h" +#include "sequence_labeler_abstract.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class structural_sequence_labeling_trainer + { + /*! + REQUIREMENTS ON feature_extractor + It must be an object that implements an interface compatible with + the example_feature_extractor defined in dlib/svm/sequence_labeler_abstract.h. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning to do sequence labeling based + on a set of training data. The training procedure produces a + sequence_labeler object which can be used to predict the labels of + new data sequences. + + Note that this is just a convenience wrapper around the + structural_svm_sequence_labeling_problem to make it look + similar to all the other trainers in dlib. + !*/ + + public: + typedef typename feature_extractor::sequence_type sample_sequence_type; + typedef std::vector labeled_sequence_type; + typedef sequence_labeler trained_function_type; + + structural_sequence_labeling_trainer ( + ); + /*! + ensures + - #get_c() == 100 + - this object isn't verbose + - #get_epsilon() == 0.1 + - #get_max_iterations() == 10000 + - #get_num_threads() == 2 + - #get_max_cache_size() == 5 + - #get_feature_extractor() == a default initialized feature_extractor + !*/ + + explicit structural_sequence_labeling_trainer ( + const feature_extractor& fe + ); + /*! + ensures + - #get_c() == 100 + - this object isn't verbose + - #get_epsilon() == 0.1 + - #get_max_iterations() == 10000 + - #get_num_threads() == 2 + - #get_max_cache_size() == 5 + - #get_feature_extractor() == fe + !*/ + + const feature_extractor& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used by this object + !*/ + + unsigned long num_labels ( + ) const; + /*! + ensures + - returns get_feature_extractor().num_labels() + (i.e. returns the number of possible output labels for each + element of a sequence) + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + ensures + - #get_num_threads() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - returns the number of threads used during training. You should + usually set this equal to the number of processing cores on your + machine. + !*/ + + void set_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const double get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer + to train. You can think of this epsilon value as saying "solve the + optimization problem until the average number of labeling mistakes per + training sample is within epsilon of its optimal value". + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + unsigned long get_max_iterations ( + ); + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void set_max_cache_size ( + unsigned long max_size + ); + /*! + ensures + - #get_max_cache_size() == max_size + !*/ + + unsigned long get_max_cache_size ( + ) const; + /*! + ensures + - During training, this object basically runs the sequence_labeler on + each training sample, over and over. To speed this up, it is possible to + cache the results of these labeler invocations. This function returns the + number of cache elements per training sample kept in the cache. Note + that a value of 0 means caching is not used at all. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the structural SVM problem. + !*/ + + void set_c ( + double C + ); + /*! + requires + - C > 0 + ensures + - #get_c() = C + !*/ + + double get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter + that determines the trade-off between trying to fit the training + data (i.e. minimize the loss) or allowing more errors but hopefully + improving the generalization of the resulting sequence labeler. Larger + values encourage exact fitting while smaller values of C may encourage + better generalization. + !*/ + + double get_loss ( + unsigned long label + ) const; + /*! + requires + - label < num_labels() + ensures + - returns the loss incurred when a sequence element with the given + label is misclassified. This value controls how much we care about + correctly classifying this type of label. Larger loss values indicate + that we care more strongly than smaller values. + !*/ + + void set_loss ( + unsigned long label, + double value + ); + /*! + requires + - label < num_labels() + - value >= 0 + ensures + - #get_loss(label) == value + !*/ + + const sequence_labeler train( + const std::vector& x, + const std::vector& y + ) const; + /*! + requires + - is_sequence_labeling_problem(x, y) == true + - contains_invalid_labeling(get_feature_extractor(), x, y) == false + - for all valid i and j: y[i][j] < num_labels() + ensures + - Uses the structural_svm_sequence_labeling_problem to train a + sequence_labeler on the given x/y training pairs. The idea is + to learn to predict a y given an input x. + - returns a function F with the following properties: + - F(new_x) == A sequence of predicted labels for the elements of new_x. + - F(new_x).size() == new_x.size() + - for all valid i: + - F(new_x)[i] == the predicted label of new_x[i] + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_ABSTRACT_Hh_ + + + + diff --git a/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer.h b/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer.h new file mode 100644 index 000000000..2e0214008 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer.h @@ -0,0 +1,281 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_ +#define DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_ + +#include "structural_sequence_segmentation_trainer_abstract.h" +#include "structural_sequence_labeling_trainer.h" +#include "sequence_segmenter.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class structural_sequence_segmentation_trainer + { + public: + typedef typename feature_extractor::sequence_type sample_sequence_type; + typedef std::vector > segmented_sequence_type; + + typedef sequence_segmenter trained_function_type; + + explicit structural_sequence_segmentation_trainer ( + const feature_extractor& fe_ + ) : trainer(impl_ss::feature_extractor(fe_)) + { + loss_per_missed_segment = 1; + loss_per_false_alarm = 1; + } + + structural_sequence_segmentation_trainer ( + ) + { + loss_per_missed_segment = 1; + loss_per_false_alarm = 1; + } + + const feature_extractor& get_feature_extractor ( + ) const { return trainer.get_feature_extractor().fe; } + + void set_num_threads ( + unsigned long num + ) + { + trainer.set_num_threads(num); + } + + unsigned long get_num_threads ( + ) const + { + return trainer.get_num_threads(); + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void structural_sequence_segmentation_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + trainer.set_epsilon(eps_); + } + + double get_epsilon ( + ) const { return trainer.get_epsilon(); } + + unsigned long get_max_iterations ( + ) const { return trainer.get_max_iterations(); } + + void set_max_iterations ( + unsigned long max_iter + ) + { + trainer.set_max_iterations(max_iter); + } + + void set_max_cache_size ( + unsigned long max_size + ) + { + trainer.set_max_cache_size(max_size); + } + + unsigned long get_max_cache_size ( + ) const + { + return trainer.get_max_cache_size(); + } + + void be_verbose ( + ) + { + trainer.be_verbose(); + } + + void be_quiet ( + ) + { + trainer.be_quiet(); + } + + void set_oca ( + const oca& item + ) + { + trainer.set_oca(item); + } + + const oca get_oca ( + ) const + { + return trainer.get_oca(); + } + + void set_c ( + double C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void structural_sequence_segmentation_trainer::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + trainer.set_c(C_); + } + + double get_c ( + ) const + { + return trainer.get_c(); + } + + void set_loss_per_missed_segment ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss >= 0, + "\t void structural_sequence_segmentation_trainer::set_loss_per_missed_segment(loss)" + << "\n\t invalid inputs were given to this function" + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_missed_segment = loss; + + if (feature_extractor::use_BIO_model) + { + trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment); + trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment); + } + else + { + trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment); + trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment); + trainer.set_loss(impl_ss::LAST, loss_per_missed_segment); + trainer.set_loss(impl_ss::UNIT, loss_per_missed_segment); + } + } + + double get_loss_per_missed_segment ( + ) const + { + return loss_per_missed_segment; + } + + void set_loss_per_false_alarm ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss >= 0, + "\t void structural_sequence_segmentation_trainer::set_loss_per_false_alarm(loss)" + << "\n\t invalid inputs were given to this function" + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_false_alarm = loss; + + trainer.set_loss(impl_ss::OUTSIDE, loss_per_false_alarm); + } + + double get_loss_per_false_alarm ( + ) const + { + return loss_per_false_alarm; + } + + const sequence_segmenter train( + const std::vector& x, + const std::vector& y + ) const + { + + // make sure requires clause is not broken + DLIB_ASSERT(is_sequence_segmentation_problem(x,y) == true, + "\t sequence_segmenter structural_sequence_segmentation_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.size(): " << x.size() + << "\n\t is_sequence_segmentation_problem(x,y): " << is_sequence_segmentation_problem(x,y) + << "\n\t this: " << this + ); + + std::vector > labels(y.size()); + if (feature_extractor::use_BIO_model) + { + // convert y into tagged BIO labels + for (unsigned long i = 0; i < labels.size(); ++i) + { + labels[i].resize(x[i].size(), impl_ss::OUTSIDE); + for (unsigned long j = 0; j < y[i].size(); ++j) + { + const unsigned long begin = y[i][j].first; + const unsigned long end = y[i][j].second; + if (begin != end) + { + labels[i][begin] = impl_ss::BEGIN; + for (unsigned long k = begin+1; k < end; ++k) + labels[i][k] = impl_ss::INSIDE; + } + } + } + } + else + { + // convert y into tagged BILOU labels + for (unsigned long i = 0; i < labels.size(); ++i) + { + labels[i].resize(x[i].size(), impl_ss::OUTSIDE); + for (unsigned long j = 0; j < y[i].size(); ++j) + { + const unsigned long begin = y[i][j].first; + const unsigned long end = y[i][j].second; + if (begin != end) + { + if (begin+1==end) + { + labels[i][begin] = impl_ss::UNIT; + } + else + { + labels[i][begin] = impl_ss::BEGIN; + for (unsigned long k = begin+1; k+1 < end; ++k) + labels[i][k] = impl_ss::INSIDE; + labels[i][end-1] = impl_ss::LAST; + } + } + } + } + } + + sequence_labeler > temp; + temp = trainer.train(x, labels); + return sequence_segmenter(temp.get_weights(), trainer.get_feature_extractor().fe); + } + + private: + + structural_sequence_labeling_trainer > trainer; + double loss_per_missed_segment; + double loss_per_false_alarm; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer_abstract.h b/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer_abstract.h new file mode 100644 index 000000000..bcd927ca6 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_sequence_segmentation_trainer_abstract.h @@ -0,0 +1,264 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_ABSTRACT_Hh_ + +#include "sequence_segmenter_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class structural_sequence_segmentation_trainer + { + /*! + REQUIREMENTS ON feature_extractor + It must be an object that implements an interface compatible with + the example_feature_extractor defined in dlib/svm/sequence_segmenter_abstract.h. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning to do sequence segmentation based on a + set of training data. The training procedure produces a sequence_segmenter + object which can be used to identify the sub-segments of new data + sequences. + + This object internally uses the structural_sequence_labeling_trainer to + solve the learning problem. + !*/ + + public: + + typedef typename feature_extractor::sequence_type sample_sequence_type; + typedef std::vector > segmented_sequence_type; + + typedef sequence_segmenter trained_function_type; + + structural_sequence_segmentation_trainer ( + ); + /*! + ensures + - #get_c() == 100 + - this object isn't verbose + - #get_epsilon() == 0.1 + - #get_max_iterations() == 10000 + - #get_num_threads() == 2 + - #get_max_cache_size() == 40 + - #get_feature_extractor() == a default initialized feature_extractor + - #get_loss_per_missed_segment() == 1 + - #get_loss_per_false_alarm() == 1 + !*/ + + explicit structural_sequence_segmentation_trainer ( + const feature_extractor& fe + ); + /*! + ensures + - #get_c() == 100 + - this object isn't verbose + - #get_epsilon() == 0.1 + - #get_max_iterations() == 10000 + - #get_num_threads() == 2 + - #get_max_cache_size() == 40 + - #get_feature_extractor() == fe + - #get_loss_per_missed_segment() == 1 + - #get_loss_per_false_alarm() == 1 + !*/ + + const feature_extractor& get_feature_extractor ( + ) const; + /*! + ensures + - returns the feature extractor used by this object + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + ensures + - #get_num_threads() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - returns the number of threads used during training. You should + usually set this equal to the number of processing cores on your + machine. + !*/ + + void set_epsilon ( + double eps_ + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer + to train. You can think of this epsilon value as saying "solve the + optimization problem until the average number of segmentation mistakes + per training sample is within epsilon of its optimal value". + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + unsigned long get_max_iterations ( + ); + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void set_max_cache_size ( + unsigned long max_size + ); + /*! + ensures + - #get_max_cache_size() == max_size + !*/ + + unsigned long get_max_cache_size ( + ) const; + /*! + ensures + - During training, this object basically runs the sequence_segmenter on + each training sample, over and over. To speed this up, it is possible to + cache the results of these segmenter invocations. This function returns + the number of cache elements per training sample kept in the cache. Note + that a value of 0 means caching is not used at all. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a user can + observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the structural SVM problem. + !*/ + + void set_c ( + double C + ); + /*! + requires + - C > 0 + ensures + - #get_c() = C + !*/ + + double get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter that + determines the trade-off between trying to fit the training data (i.e. + minimize the loss) or allowing more errors but hopefully improving the + generalization of the resulting sequence labeler. Larger values + encourage exact fitting while smaller values of C may encourage better + generalization. + !*/ + + void set_loss_per_missed_segment ( + double loss + ); + /*! + requires + - loss >= 0 + ensures + - #get_loss_per_missed_segment() == loss + !*/ + + double get_loss_per_missed_segment ( + ) const; + /*! + ensures + - returns the amount of loss incurred for failing to detect a segment. The + larger the loss the more important it is to detect all the segments. + !*/ + + + void set_loss_per_false_alarm ( + double loss + ); + /*! + requires + - loss >= 0 + ensures + - #get_loss_per_false_alarm() == loss + !*/ + + double get_loss_per_false_alarm ( + ) const; + /*! + ensures + - returns the amount of loss incurred for outputting a false detection. The + larger the loss the more important it is to avoid outputting false + detections. + !*/ + + const sequence_segmenter train( + const std::vector& x, + const std::vector& y + ) const; + /*! + requires + - is_sequence_segmentation_problem(x, y) == true + ensures + - Uses the given training data to learn to do sequence segmentation. That + is, this function will try to find a sequence_segmenter capable of + predicting y[i] when given x[i] as input. Moreover, it should also be + capable of predicting the segmentation of new input sequences. Or in + other words, the learned sequence_segmenter should also generalize to new + data outside the training dataset. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/structural_svm_assignment_problem.h b/ml/dlib/dlib/svm/structural_svm_assignment_problem.h new file mode 100644 index 000000000..963af1631 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_assignment_problem.h @@ -0,0 +1,288 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_Hh_ +#define DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_Hh_ + + +#include "structural_svm_assignment_problem_abstract.h" +#include "../matrix.h" +#include +#include +#include "structural_svm_problem_threaded.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + template + struct column_matrix_static_resize + { + typedef T type; + }; + + template + struct column_matrix_static_resize > + { + typedef matrix type; + }; + + template + struct column_matrix_static_resize > + { + typedef matrix type; + }; + + template + struct add_one_to_static_feat_size + { + typedef typename column_matrix_static_resize<1,typename T::feature_vector_type>::type type; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class structural_svm_assignment_problem : noncopyable, + public structural_svm_problem_threaded, typename add_one_to_static_feat_size::type > + { + public: + typedef matrix matrix_type; + typedef typename add_one_to_static_feat_size::type feature_vector_type; + + typedef typename feature_extractor::lhs_element lhs_element; + typedef typename feature_extractor::rhs_element rhs_element; + + + typedef std::pair, std::vector > sample_type; + + typedef std::vector label_type; + + structural_svm_assignment_problem( + const std::vector& samples_, + const std::vector& labels_, + const feature_extractor& fe_, + bool force_assignment_, + unsigned long num_threads, + const double loss_per_false_association_, + const double loss_per_missed_association_ + ) : + structural_svm_problem_threaded(num_threads), + samples(samples_), + labels(labels_), + fe(fe_), + force_assignment(force_assignment_), + loss_per_false_association(loss_per_false_association_), + loss_per_missed_association(loss_per_missed_association_) + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + DLIB_ASSERT(loss_per_false_association > 0 && loss_per_missed_association > 0, + "\t structural_svm_assignment_problem::structural_svm_assignment_problem()" + << "\n\t invalid inputs were given to this function" + << "\n\t loss_per_false_association: " << loss_per_false_association + << "\n\t loss_per_missed_association: " << loss_per_missed_association + << "\n\t this: " << this + ); + if (force_assignment) + { + DLIB_ASSERT(is_forced_assignment_problem(samples, labels), + "\t structural_svm_assignment_problem::structural_svm_assignment_problem()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels) + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + << "\n\t this: " << this + ); + } + else + { + DLIB_ASSERT(is_assignment_problem(samples, labels), + "\t structural_svm_assignment_problem::structural_svm_assignment_problem()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + << "\n\t this: " << this + ); + } +#endif + + } + + private: + virtual long get_num_dimensions ( + ) const + { + return fe.num_features()+1; // +1 for the bias term + } + + virtual long get_num_samples ( + ) const + { + return samples.size(); + } + + template + typename enable_if >::type get_joint_feature_vector ( + const sample_type& sample, + const label_type& label, + psi_type& psi + ) const + { + typename feature_extractor::feature_vector_type feats; + psi.set_size(get_num_dimensions()); + psi = 0; + for (unsigned long i = 0; i < sample.first.size(); ++i) + { + if (label[i] != -1) + { + fe.get_features(sample.first[i], sample.second[label[i]], feats); + set_rowm(psi,range(0,feats.size()-1)) += feats; + psi(get_num_dimensions()-1) += 1; + } + } + } + + template + void append_to_sparse_vect ( + T& psi, + const T& vect + ) const + { + std::copy(vect.begin(), vect.end(), std::back_inserter(psi)); + } + + template + typename disable_if >::type get_joint_feature_vector ( + const sample_type& sample, + const label_type& label, + psi_type& psi + ) const + { + psi.clear(); + feature_vector_type feats; + int num_assignments = 0; + for (unsigned long i = 0; i < sample.first.size(); ++i) + { + if (label[i] != -1) + { + fe.get_features(sample.first[i], sample.second[label[i]], feats); + append_to_sparse_vect(psi, feats); + ++num_assignments; + } + } + psi.push_back(std::make_pair(get_num_dimensions()-1,num_assignments)); + } + + virtual void get_truth_joint_feature_vector ( + long idx, + feature_vector_type& psi + ) const + { + get_joint_feature_vector(samples[idx], labels[idx], psi); + } + + virtual void separation_oracle ( + const long idx, + const matrix_type& current_solution, + double& loss, + feature_vector_type& psi + ) const + { + matrix cost; + unsigned long size; + if (force_assignment) + { + unsigned long lhs_size = samples[idx].first.size(); + unsigned long rhs_size = samples[idx].second.size(); + size = std::max(lhs_size, rhs_size); + } + else + { + unsigned long rhs_size = samples[idx].second.size() + samples[idx].first.size(); + size = rhs_size; + } + cost.set_size(size, size); + + typename feature_extractor::feature_vector_type feats; + + // now fill out the cost assignment matrix + for (long r = 0; r < cost.nr(); ++r) + { + for (long c = 0; c < cost.nc(); ++c) + { + if (r < (long)samples[idx].first.size()) + { + if (c < (long)samples[idx].second.size()) + { + fe.get_features(samples[idx].first[r], samples[idx].second[c], feats); + const double bias = current_solution(current_solution.size()-1); + cost(r,c) = dot(colm(current_solution,0,current_solution.size()-1), feats) + bias; + + // add in the loss since this corresponds to an incorrect prediction. + if (c != labels[idx][r]) + { + cost(r,c) += loss_per_false_association; + } + } + else + { + if (labels[idx][r] == -1) + cost(r,c) = 0; + else + cost(r,c) = loss_per_missed_association; + } + + } + else + { + cost(r,c) = 0; + } + } + } + + std::vector assignment; + + if (cost.size() != 0) + { + // max_cost_assignment() only works with integer matrices, so convert from + // double to integer. + const double scale = (std::numeric_limits::max()/1000)/max(abs(cost)); + matrix int_cost = matrix_cast(round(cost*scale)); + assignment = max_cost_assignment(int_cost); + assignment.resize(samples[idx].first.size()); + } + + loss = 0; + // adjust assignment so that non-assignments have a value of -1. Also compute loss. + for (unsigned long i = 0; i < assignment.size(); ++i) + { + if (assignment[i] >= (long)samples[idx].second.size()) + assignment[i] = -1; + + if (assignment[i] != labels[idx][i]) + { + if (assignment[i] == -1) + loss += loss_per_missed_association; + else + loss += loss_per_false_association; + } + } + + get_joint_feature_vector(samples[idx], assignment, psi); + } + + const std::vector& samples; + const std::vector& labels; + const feature_extractor& fe; + bool force_assignment; + const double loss_per_false_association; + const double loss_per_missed_association; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_Hh_ + diff --git a/ml/dlib/dlib/svm/structural_svm_assignment_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_assignment_problem_abstract.h new file mode 100644 index 000000000..c06190726 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_assignment_problem_abstract.h @@ -0,0 +1,87 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_ABSTRACT_Hh_ + + +#include "../matrix.h" +#include +#include "structural_svm_problem_threaded_abstract.h" +#include "assignment_function_abstract.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + template < + typename feature_extractor + > + class structural_svm_assignment_problem : noncopyable, + public structural_svm_problem_threaded, + typename feature_extractor::feature_vector_type > + { + /*! + REQUIREMENTS ON feature_extractor + It must be an object that implements an interface compatible with + the example_feature_extractor defined in dlib/svm/assignment_function_abstract.h. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning the parameters needed to use an + assignment_function object. It learns the parameters by formulating the + problem as a structural SVM problem. + !*/ + + public: + typedef matrix matrix_type; + typedef typename feature_extractor::feature_vector_type feature_vector_type; + typedef typename feature_extractor::lhs_element lhs_element; + typedef typename feature_extractor::rhs_element rhs_element; + typedef std::pair, std::vector > sample_type; + typedef std::vector label_type; + + structural_svm_assignment_problem( + const std::vector& samples, + const std::vector& labels, + const feature_extractor& fe, + bool force_assignment, + unsigned long num_threads, + const double loss_per_false_association, + const double loss_per_missed_association + ); + /*! + requires + - loss_per_false_association > 0 + - loss_per_missed_association > 0 + - is_assignment_problem(samples,labels) == true + - if (force_assignment) then + - is_forced_assignment_problem(samples,labels) == true + ensures + - This object attempts to learn a mapping from the given samples to the + given labels. In particular, it attempts to learn to predict labels[i] + based on samples[i]. Or in other words, this object can be used to learn + a parameter vector and bias, w and b, such that an assignment_function declared as: + assignment_function assigner(w,b,fe,force_assignment) + results in an assigner object which attempts to compute the following mapping: + labels[i] == labeler(samples[i]) + - This object will use num_threads threads during the optimization + procedure. You should set this parameter equal to the number of + available processing cores on your machine. + - When solving the structural SVM problem, we will use + loss_per_false_association as the loss for incorrectly associating + objects that shouldn't be associated. + - When solving the structural SVM problem, we will use + loss_per_missed_association as the loss for failing to associate to + objects that are supposed to be associated with each other. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_ASSiGNMENT_PROBLEM_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/svm/structural_svm_distributed.h b/ml/dlib/dlib/svm/structural_svm_distributed.h new file mode 100644 index 000000000..a9542c70f --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_distributed.h @@ -0,0 +1,700 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_ +#define DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_ + +#include +#include +#include + +#include "structural_svm_distributed_abstract.h" +#include "structural_svm_problem.h" +#include "../bridge.h" +#include "../misc_api.h" +#include "../statistics.h" +#include "../threads.h" +#include "../pipe.h" +#include "../type_safe_union.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + template + struct oracle_response + { + typedef typename matrix_type::type scalar_type; + + matrix_type subgradient; + scalar_type loss; + long num; + + friend void swap (oracle_response& a, oracle_response& b) + { + a.subgradient.swap(b.subgradient); + std::swap(a.loss, b.loss); + std::swap(a.num, b.num); + } + + friend void serialize (const oracle_response& item, std::ostream& out) + { + serialize(item.subgradient, out); + dlib::serialize(item.loss, out); + dlib::serialize(item.num, out); + } + + friend void deserialize (oracle_response& item, std::istream& in) + { + deserialize(item.subgradient, in); + dlib::deserialize(item.loss, in); + dlib::deserialize(item.num, in); + } + }; + + // ---------------------------------------------------------------------------------------- + + template + struct oracle_request + { + typedef typename matrix_type::type scalar_type; + + matrix_type current_solution; + scalar_type saved_current_risk_gap; + bool skip_cache; + bool converged; + + friend void swap (oracle_request& a, oracle_request& b) + { + a.current_solution.swap(b.current_solution); + std::swap(a.saved_current_risk_gap, b.saved_current_risk_gap); + std::swap(a.skip_cache, b.skip_cache); + std::swap(a.converged, b.converged); + } + + friend void serialize (const oracle_request& item, std::ostream& out) + { + serialize(item.current_solution, out); + dlib::serialize(item.saved_current_risk_gap, out); + dlib::serialize(item.skip_cache, out); + dlib::serialize(item.converged, out); + } + + friend void deserialize (oracle_request& item, std::istream& in) + { + deserialize(item.current_solution, in); + dlib::deserialize(item.saved_current_risk_gap, in); + dlib::deserialize(item.skip_cache, in); + dlib::deserialize(item.converged, in); + } + }; + + } + +// ---------------------------------------------------------------------------------------- + + class svm_struct_processing_node : noncopyable + { + public: + + template < + typename T, + typename U + > + svm_struct_processing_node ( + const structural_svm_problem& problem, + unsigned short port, + unsigned short num_threads + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(port != 0 && problem.get_num_samples() != 0 && + problem.get_num_dimensions() != 0, + "\t svm_struct_processing_node()" + << "\n\t Invalid arguments were given to this function" + << "\n\t port: " << port + << "\n\t problem.get_num_samples(): " << problem.get_num_samples() + << "\n\t problem.get_num_dimensions(): " << problem.get_num_dimensions() + << "\n\t this: " << this + ); + + the_problem.reset(new node_type(problem, port, num_threads)); + } + + private: + + struct base + { + virtual ~base(){} + }; + + template < + typename matrix_type, + typename feature_vector_type + > + class node_type : public base, threaded_object + { + public: + typedef typename matrix_type::type scalar_type; + + node_type( + const structural_svm_problem& prob, + unsigned short port, + unsigned long num_threads + ) : in(3),out(3), problem(prob), tp(num_threads) + { + b.reconfigure(listen_on_port(port), receive(in), transmit(out)); + + start(); + } + + ~node_type() + { + in.disable(); + out.disable(); + wait(); + } + + private: + + void thread() + { + using namespace impl; + tsu_in msg; + tsu_out temp; + + timestamper ts; + running_stats with_buffer_time; + running_stats without_buffer_time; + unsigned long num_iterations_executed = 0; + + while (in.dequeue(msg)) + { + // initialize the cache and compute psi_true. + if (cache.size() == 0) + { + cache.resize(problem.get_num_samples()); + for (unsigned long i = 0; i < cache.size(); ++i) + cache[i].init(&problem,i); + + psi_true.set_size(problem.get_num_dimensions(),1); + psi_true = 0; + + const unsigned long num = problem.get_num_samples(); + feature_vector_type ftemp; + for (unsigned long i = 0; i < num; ++i) + { + cache[i].get_truth_joint_feature_vector_cached(ftemp); + + subtract_from(psi_true, ftemp); + } + } + + + if (msg.template contains() && + msg.template get().is_connected) + { + temp = problem.get_num_dimensions(); + out.enqueue(temp); + + } + else if (msg.template contains >()) + { + ++num_iterations_executed; + + const oracle_request& req = msg.template get >(); + + oracle_response& data = temp.template get >(); + + data.subgradient = psi_true; + data.loss = 0; + + data.num = problem.get_num_samples(); + + const uint64 start_time = ts.get_timestamp(); + + // pick fastest buffering strategy + bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean(); + + // every 50 iterations we should try to flip the buffering scheme to see if + // doing it the other way might be better. + if ((num_iterations_executed%50) == 0) + { + buffer_subgradients_locally = !buffer_subgradients_locally; + } + + binder b(*this, req, data, buffer_subgradients_locally); + parallel_for_blocked(tp, 0, data.num, b, &binder::call_oracle); + + const uint64 stop_time = ts.get_timestamp(); + if (buffer_subgradients_locally) + with_buffer_time.add(stop_time-start_time); + else + without_buffer_time.add(stop_time-start_time); + + out.enqueue(temp); + } + } + } + + struct binder + { + binder ( + const node_type& self_, + const impl::oracle_request& req_, + impl::oracle_response& data_, + bool buffer_subgradients_locally_ + ) : self(self_), req(req_), data(data_), + buffer_subgradients_locally(buffer_subgradients_locally_) {} + + void call_oracle ( + long begin, + long end + ) + { + // If we are only going to call the separation oracle once then don't + // run the slightly more complex for loop version of this code. Or if + // we just don't want to run the complex buffering one. The code later + // on decides if we should do the buffering based on how long it takes + // to execute. We do this because, when the subgradient is really high + // dimensional it can take a lot of time to add them together. So we + // might want to avoid doing that. + if (end-begin <= 1 || !buffer_subgradients_locally) + { + scalar_type loss; + feature_vector_type ftemp; + for (long i = begin; i < end; ++i) + { + self.cache[i].separation_oracle_cached(req.converged, + req.skip_cache, + req.saved_current_risk_gap, + req.current_solution, + loss, + ftemp); + + auto_mutex lock(self.accum_mutex); + data.loss += loss; + add_to(data.subgradient, ftemp); + } + } + else + { + scalar_type loss = 0; + matrix_type faccum(data.subgradient.size(),1); + faccum = 0; + + feature_vector_type ftemp; + + for (long i = begin; i < end; ++i) + { + scalar_type loss_temp; + self.cache[i].separation_oracle_cached(req.converged, + req.skip_cache, + req.saved_current_risk_gap, + req.current_solution, + loss_temp, + ftemp); + loss += loss_temp; + add_to(faccum, ftemp); + } + + auto_mutex lock(self.accum_mutex); + data.loss += loss; + add_to(data.subgradient, faccum); + } + } + + const node_type& self; + const impl::oracle_request& req; + impl::oracle_response& data; + bool buffer_subgradients_locally; + }; + + + + typedef type_safe_union, bridge_status> tsu_in; + typedef type_safe_union , long> tsu_out; + + pipe in; + pipe out; + bridge b; + + mutable matrix_type psi_true; + const structural_svm_problem& problem; + mutable std::vector > > cache; + + mutable thread_pool tp; + mutex accum_mutex; + }; + + + std::unique_ptr the_problem; + }; + +// ---------------------------------------------------------------------------------------- + + class svm_struct_controller_node : noncopyable + { + public: + + svm_struct_controller_node ( + ) : + eps(0.001), + max_iterations(10000), + cache_based_eps(std::numeric_limits::infinity()), + verbose(false), + C(1) + {} + + double get_cache_based_epsilon ( + ) const + { + return cache_based_eps; + } + + void set_cache_based_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void svm_struct_controller_node::set_cache_based_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + cache_based_eps = eps_; + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void svm_struct_controller_node::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + double get_epsilon ( + ) const { return eps; } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet( + ) + { + verbose = false; + } + + void add_nuclear_norm_regularizer ( + long first_dimension, + long rows, + long cols, + double regularization_strength + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= first_dimension && + 0 <= rows && 0 <= cols && + 0 < regularization_strength, + "\t void svm_struct_controller_node::add_nuclear_norm_regularizer()" + << "\n\t Invalid arguments were given to this function." + << "\n\t first_dimension: " << first_dimension + << "\n\t rows: " << rows + << "\n\t cols: " << cols + << "\n\t regularization_strength: " << regularization_strength + << "\n\t this: " << this + ); + + impl::nuclear_norm_regularizer temp; + temp.first_dimension = first_dimension; + temp.nr = rows; + temp.nc = cols; + temp.regularization_strength = regularization_strength; + nuclear_norm_regularizers.push_back(temp); + } + + unsigned long num_nuclear_norm_regularizers ( + ) const { return nuclear_norm_regularizers.size(); } + + void clear_nuclear_norm_regularizers ( + ) { nuclear_norm_regularizers.clear(); } + + + double get_c ( + ) const { return C; } + + void set_c ( + double C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void svm_struct_controller_node::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + void add_processing_node ( + const network_address& addr + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(addr.port != 0, + "\t void svm_struct_controller_node::add_processing_node()" + << "\n\t Invalid inputs were given to this function" + << "\n\t addr.host_address: " << addr.host_address + << "\n\t addr.port: " << addr.port + << "\n\t this: " << this + ); + + // check if this address is already registered + for (unsigned long i = 0; i < nodes.size(); ++i) + { + if (nodes[i] == addr) + { + return; + } + } + + nodes.push_back(addr); + } + + void add_processing_node ( + const std::string& ip_or_hostname, + unsigned short port + ) + { + add_processing_node(network_address(ip_or_hostname,port)); + } + + unsigned long get_num_processing_nodes ( + ) const + { + return nodes.size(); + } + + void remove_processing_nodes ( + ) + { + nodes.clear(); + } + + template + double operator() ( + const oca& solver, + matrix_type& w + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_num_processing_nodes() != 0, + "\t double svm_struct_controller_node::operator()" + << "\n\t You must add some processing nodes before calling this function." + << "\n\t this: " << this + ); + + problem_type problem(nodes); + problem.set_cache_based_epsilon(cache_based_eps); + problem.set_epsilon(eps); + problem.set_max_iterations(max_iterations); + if (verbose) + problem.be_verbose(); + problem.set_c(C); + for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i) + { + problem.add_nuclear_norm_regularizer( + nuclear_norm_regularizers[i].first_dimension, + nuclear_norm_regularizers[i].nr, + nuclear_norm_regularizers[i].nc, + nuclear_norm_regularizers[i].regularization_strength); + } + + return solver(problem, w); + } + + class invalid_problem : public error + { + public: + invalid_problem( + const std::string& a + ): error(a) {} + }; + + + private: + + template + class problem_type : public structural_svm_problem + { + public: + typedef typename matrix_type_::type scalar_type; + typedef matrix_type_ matrix_type; + + problem_type ( + const std::vector& nodes_ + ) : + nodes(nodes_), + in(3), + num_dims(0) + { + + // initialize all the transmit pipes + out_pipes.resize(nodes.size()); + for (unsigned long i = 0; i < out_pipes.size(); ++i) + { + out_pipes[i].reset(new pipe(3)); + } + + // make bridges that connect to all our remote processing nodes + bridges.resize(nodes.size()); + for (unsigned long i = 0; i< bridges.size(); ++i) + { + bridges[i].reset(new bridge(connect_to(nodes[i]), + receive(in), transmit(*out_pipes[i]))); + } + + + + // The remote processing nodes are supposed to all send the problem dimensionality + // upon connection. So get that and make sure everyone agrees on what it's supposed to be. + tsu_in temp; + unsigned long responses = 0; + bool seen_dim = false; + while (responses < nodes.size()) + { + in.dequeue(temp); + if (temp.template contains()) + { + ++responses; + // if this new dimension doesn't match what we have seen previously + if (seen_dim && num_dims != temp.template get()) + { + throw invalid_problem("remote hosts disagree on the number of dimensions!"); + } + seen_dim = true; + num_dims = temp.template get(); + } + } + } + + // These functions are just here because the structural_svm_problem requires + // them, but since we are overloading get_risk() they are never called so they + // don't matter. + virtual long get_num_samples () const {return 0;} + virtual void get_truth_joint_feature_vector ( long , matrix_type& ) const {} + virtual void separation_oracle ( const long , const matrix_type& , scalar_type& , matrix_type& ) const {} + + virtual long get_num_dimensions ( + ) const + { + return num_dims; + } + + virtual void get_risk ( + matrix_type& w, + scalar_type& risk, + matrix_type& subgradient + ) const + { + using namespace impl; + subgradient.set_size(w.size(),1); + subgradient = 0; + + // send out all the oracle requests + tsu_out temp_out; + for (unsigned long i = 0; i < out_pipes.size(); ++i) + { + temp_out.template get >().current_solution = w; + temp_out.template get >().saved_current_risk_gap = this->saved_current_risk_gap; + temp_out.template get >().skip_cache = this->skip_cache; + temp_out.template get >().converged = this->converged; + out_pipes[i]->enqueue(temp_out); + } + + // collect all the oracle responses + long num = 0; + scalar_type total_loss = 0; + tsu_in temp_in; + unsigned long responses = 0; + while (responses < out_pipes.size()) + { + in.dequeue(temp_in); + if (temp_in.template contains >()) + { + ++responses; + const oracle_response& data = temp_in.template get >(); + subgradient += data.subgradient; + total_loss += data.loss; + num += data.num; + } + } + + subgradient /= num; + total_loss /= num; + risk = total_loss + dot(subgradient,w); + + if (this->nuclear_norm_regularizers.size() != 0) + { + matrix_type grad; + double obj; + this->compute_nuclear_norm_parts(w, grad, obj); + risk += obj; + subgradient += grad; + } + } + + std::vector nodes; + + typedef type_safe_union > tsu_out; + typedef type_safe_union, long> tsu_in; + + std::vector > > out_pipes; + mutable pipe in; + std::vector > bridges; + long num_dims; + }; + + std::vector nodes; + double eps; + unsigned long max_iterations; + double cache_based_eps; + bool verbose; + double C; + std::vector nuclear_norm_regularizers; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_ + diff --git a/ml/dlib/dlib/svm/structural_svm_distributed_abstract.h b/ml/dlib/dlib/svm/structural_svm_distributed_abstract.h new file mode 100644 index 000000000..175a643c8 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_distributed_abstract.h @@ -0,0 +1,357 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_ABSTRACT_Hh_ + + +#include "structural_svm_problem_abstract.h" +#include "../optimization/optimization_oca_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class svm_struct_processing_node : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for distributing the work involved in solving + a dlib::structural_svm_problem across many computers. It is used in + conjunction with the svm_struct_controller_node defined below. + !*/ + + public: + + template < + typename T, + typename U + > + svm_struct_processing_node ( + const structural_svm_problem& problem, + unsigned short port, + unsigned short num_threads + ); + /*! + requires + - port != 0 + - problem.get_num_samples() != 0 + - problem.get_num_dimensions() != 0 + ensures + - This object will listen on the given port for a TCP connection from a + svm_struct_controller_node. Once connected, the controller node will + be able to access the given problem. + - Will use num_threads threads at a time to make concurrent calls to the + problem.separation_oracle() routine. You should set this parameter equal + to the number of available processing cores. + - Note that the following parameters within the given problem are ignored: + - problem.get_c() + - problem.get_epsilon() + - problem.get_cache_based_epsilon() + - problem.num_nuclear_norm_regularizers() + - weather the problem is verbose or not + Instead, they are defined by the svm_struct_controller_node. Note, however, + that the problem.get_max_cache_size() parameter is meaningful and controls + the size of the separation oracle cache within a svm_struct_processing_node. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class svm_struct_controller_node : noncopyable + { + /*! + INITIAL VALUE + - get_num_processing_nodes() == 0 + - get_epsilon() == 0.001 + - get_max_iterations() == 10000 + - get_c() == 1 + - This object will not be verbose + + WHAT THIS OBJECT REPRESENTS + This object is a tool for distributing the work involved in solving a + dlib::structural_svm_problem across many computers. The best way to understand + its use is via example: + + First, suppose you have defined a structural_svm_problem object by inheriting from + it and defining the appropriate virtual functions. You could solve it by passing + an instance to the oca optimizer. However, if your separation oracle takes a long + time to evaluate then the optimization will take a long time to solve. To speed + this up we can distribute the calls to the separation oracle across many computers. + + To make this concrete, lets imagine you want to distribute the work across three + computers. You can accomplish this by creating four programs. One containing a + svm_struct_controller_node and three containing svm_struct_processing_nodes. + + The programs might look like this: + + Controller program: + int main() + { + svm_struct_controller_node cont; + cont.set_c(100); + // Tell cont where the processing nodes are on your network. + cont.add_processing_node("192.168.1.10:12345"); + cont.add_processing_node("192.168.1.11:12345"); + cont.add_processing_node("192.168.1.12:12345"); + matrix w; + oca solver; + cont(solver, w); // Run the optimization. + // After this finishes w will contain the solution vector. + } + + Processing programs (they are all the same, except that each loads a different subset + of the training data): + int main() + { + // Put one third of your data into this problem object. How you do this depends on your problem. + your_structural_svm_problem problem; + svm_struct_processing_node node(problem, 12345, number_of_cores_on_this_computer); + cout << "hit enter to terminate this program" << endl; + cin.get(); + } + + !*/ + + public: + + svm_struct_controller_node ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void set_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer + to execute. Specifically, the algorithm stops when the average sample + risk (i.e. R(w) as defined by the dlib::structural_svm_problem object) is + within epsilon of its optimal value. + + Also note that sample risk is an upper bound on a sample's loss. So + you can think of this epsilon value as saying "solve the optimization + problem until the average loss per sample is within epsilon of it's + optimal value". + !*/ + + double get_cache_based_epsilon ( + ) const; + /*! + ensures + - if (get_max_cache_size() != 0) then + - The solver will not stop when the average sample risk is within + get_epsilon() of its optimal value. Instead, it will keep running + but will run the optimizer completely on the cache until the average + sample risk is within #get_cache_based_epsilon() of its optimal + value. This means that it will perform this additional refinement in + the solution accuracy without making any additional calls to the + separation_oracle(). This is useful when using a nuclear norm + regularization term because it allows you to quickly solve the + optimization problem to a high precision, which in the case of a + nuclear norm regularized problem means that many of the learned + matrices will be low rank or very close to low rank due to the + nuclear norm regularizer. This may not happen without solving the + problem to a high accuracy or their ranks may be difficult to + determine, so the extra accuracy given by the cache based refinement + is very useful. Finally, note that we include the nuclear norm term + as part of the "risk" for the purposes of determining when to stop. + - else + - The value of #get_cache_based_epsilon() has no effect. + !*/ + + void set_cache_based_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_cache_based_epsilon() == eps + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + unsigned long get_max_iterations ( + ); + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void add_nuclear_norm_regularizer ( + long first_dimension, + long rows, + long cols, + double regularization_strength + ); + /*! + requires + - 0 <= first_dimension < number of dimensions in problem + - 0 <= rows + - 0 <= cols + - first_dimension+rows*cols <= number of dimensions in problem + - 0 < regularization_strength + ensures + - Adds a nuclear norm regularization term to the optimization problem + solved by this object. That is, instead of solving: + Minimize: h(w) == 0.5*dot(w,w) + C*R(w) + this object will solve: + Minimize: h(w) == 0.5*dot(w,w) + C*R(w) + regularization_strength*nuclear_norm_of(part of w) + where "part of w" is the part of w indicated by the arguments to this + function. In particular, the part of w included in the nuclear norm is + exactly the matrix reshape(rowm(w, range(first_dimension, first_dimension+rows*cols-1)), rows, cols). + Therefore, if you think of the w vector as being the concatenation of a + bunch of matrices then you can use multiple calls to add_nuclear_norm_regularizer() + to add nuclear norm regularization terms to any of the matrices packed into w. + - #num_nuclear_norm_regularizers() == num_nuclear_norm_regularizers() + 1 + !*/ + + unsigned long num_nuclear_norm_regularizers ( + ) const; + /*! + ensures + - returns the number of nuclear norm regularizers that are currently a part + of this optimization problem. That is, returns the number of times + add_nuclear_norm_regularizer() has been called since the last call to + clear_nuclear_norm_regularizers() or object construction, whichever is + most recent. + !*/ + + void clear_nuclear_norm_regularizers ( + ); + /*! + ensures + - #num_nuclear_norm_regularizers() == 0 + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + double get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter that + determines the trade off between trying to fit the training data + exactly or allowing more errors but hopefully improving the + generalization of the resulting classifier. Larger values encourage + exact fitting while smaller values of C may encourage better + generalization. + !*/ + + void set_c ( + double C + ); + /*! + requires + - C > 0 + ensures + - #get_c() == C + !*/ + + void add_processing_node ( + const network_address& addr + ); + /*! + requires + - addr.port != 0 + ensures + - if (this address hasn't already been added) then + - #get_num_processing_nodes() == get_num_processing_nodes() + 1 + - When operator() is invoked to solve the structural svm problem this + object will connect to the svm_struct_processing_node located at the + given network address and will include it in the distributed + optimization. + !*/ + + void add_processing_node ( + const std::string& ip_or_hostname, + unsigned short port + ); + /*! + requires + - port != 0 + ensures + - invokes: add_processing_node(network_address(ip_or_hostname, port)) + !*/ + + unsigned long get_num_processing_nodes ( + ) const; + /*! + ensures + - returns the number of remote processing nodes that have been + registered with this object. + !*/ + + void remove_processing_nodes ( + ); + /*! + ensures + - #get_num_processing_nodes() == 0 + !*/ + + class invalid_problem : public error {}; + + template + double operator() ( + const oca& solver, + matrix_type& w + ) const; + /*! + requires + - get_num_processing_nodes() != 0 + - matrix_type == a dlib::matrix capable of storing column vectors + ensures + - connects to the processing nodes and begins optimizing the structural + svm problem using the given oca solver. + - stores the solution in #w + - returns the objective value at the solution #w + throws + - invalid_problem + This exception is thrown if the svm_struct_processing_nodes disagree + on the dimensionality of the problem. That is, if they disagree on + the value of structural_svm_problem::get_num_dimensions(). + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_DISTRIBUTeD_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem.h b/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem.h new file mode 100644 index 000000000..c677861c9 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem.h @@ -0,0 +1,542 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_ +#define DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_ + + +#include "structural_svm_graph_labeling_problem_abstract.h" +#include "../graph_cuts.h" +#include "../matrix.h" +#include "../array.h" +#include +#include +#include "structural_svm_problem_threaded.h" +#include "../graph.h" +#include "sparse_vector.h" +#include + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type + > + bool is_graph_labeling_problem ( + const dlib::array& samples, + const std::vector >& labels, + std::string& reason_for_failure + ) + { + typedef typename graph_type::type node_vector_type; + typedef typename graph_type::edge_type edge_vector_type; + // The graph must use all dense vectors or all sparse vectors. It can't mix the two types together. + COMPILE_TIME_ASSERT( (is_matrix::value && is_matrix::value) || + (!is_matrix::value && !is_matrix::value)); + + + std::ostringstream sout; + reason_for_failure.clear(); + + if (!is_learning_problem(samples, labels)) + { + reason_for_failure = "is_learning_problem(samples, labels) returned false."; + return false; + } + + const bool ismat = is_matrix::value; + + // these are -1 until assigned with a value + long node_dims = -1; + long edge_dims = -1; + + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (samples[i].number_of_nodes() != labels[i].size()) + { + sout << "samples["< + bool is_graph_labeling_problem ( + const dlib::array& samples, + const std::vector >& labels + ) + { + std::string reason_for_failure; + return is_graph_labeling_problem(samples, labels, reason_for_failure); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + bool sizes_match ( + const std::vector >& lhs, + const std::vector >& rhs + ) + { + if (lhs.size() != rhs.size()) + return false; + + for (unsigned long i = 0; i < lhs.size(); ++i) + { + if (lhs[i].size() != rhs[i].size()) + return false; + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + inline bool all_values_are_nonnegative ( + const std::vector >& x + ) + { + for (unsigned long i = 0; i < x.size(); ++i) + { + for (unsigned long j = 0; j < x[i].size(); ++j) + { + if (x[i][j] < 0) + return false; + } + } + return true; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename T, + typename enable = void + > + struct fvect + { + // In this case type should be some sparse vector type + typedef typename T::type type; + }; + + template < typename T > + struct fvect >::type> + { + // The point of this stuff is to create the proper matrix + // type to represent the concatenation of an edge vector + // with an node vector. + typedef typename T::type node_mat; + typedef typename T::edge_type edge_mat; + const static long NRd = node_mat::NR; + const static long NRe = edge_mat::NR; + const static long NR = ((NRd!=0) && (NRe!=0)) ? (NRd+NRe) : 0; + typedef typename node_mat::value_type value_type; + + typedef matrix type; + }; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type + > + class structural_svm_graph_labeling_problem : noncopyable, + public structural_svm_problem_threaded, + typename dlib::impl::fvect::type > + { + public: + typedef matrix matrix_type; + typedef typename dlib::impl::fvect::type feature_vector_type; + + typedef graph_type sample_type; + + typedef std::vector label_type; + + structural_svm_graph_labeling_problem( + const dlib::array& samples_, + const std::vector& labels_, + const std::vector >& losses_, + unsigned long num_threads = 2 + ) : + structural_svm_problem_threaded(num_threads), + samples(samples_), + labels(labels_), + losses(losses_) + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + std::string reason_for_failure; + DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true , + "\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()" + << "\n\t Invalid inputs were given to this function." + << "\n\t reason_for_failure: " << reason_for_failure + << "\n\t samples.size(): " << samples.size() + << "\n\t labels.size(): " << labels.size() + << "\n\t this: " << this ); + DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) && + all_values_are_nonnegative(losses) == true, + "\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()" + << "\n\t Invalid inputs were given to this function." + << "\n\t labels.size(): " << labels.size() + << "\n\t losses.size(): " << losses.size() + << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses) + << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses) + << "\n\t this: " << this ); +#endif + + loss_pos = 1.0; + loss_neg = 1.0; + + // figure out how many dimensions are in the node and edge vectors. + node_dims = 0; + edge_dims = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j) + { + node_dims = std::max(node_dims,(long)max_index_plus_one(samples[i].node(j).data)); + for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n) + { + edge_dims = std::max(edge_dims, (long)max_index_plus_one(samples[i].node(j).edge(n))); + } + } + } + } + + const std::vector >& get_losses ( + ) const { return losses; } + + long get_num_edge_weights ( + ) const + { + return edge_dims; + } + + void set_loss_on_positive_class ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss >= 0 && get_losses().size() == 0, + "\t void structural_svm_graph_labeling_problem::set_loss_on_positive_class()" + << "\n\t Invalid inputs were given to this function." + << "\n\t loss: " << loss + << "\n\t this: " << this ); + + loss_pos = loss; + } + + void set_loss_on_negative_class ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss >= 0 && get_losses().size() == 0, + "\t void structural_svm_graph_labeling_problem::set_loss_on_negative_class()" + << "\n\t Invalid inputs were given to this function." + << "\n\t loss: " << loss + << "\n\t this: " << this ); + + loss_neg = loss; + } + + double get_loss_on_negative_class ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_losses().size() == 0, + "\t double structural_svm_graph_labeling_problem::get_loss_on_negative_class()" + << "\n\t Invalid inputs were given to this function." + << "\n\t this: " << this ); + + return loss_neg; + } + + double get_loss_on_positive_class ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(get_losses().size() == 0, + "\t double structural_svm_graph_labeling_problem::get_loss_on_positive_class()" + << "\n\t Invalid inputs were given to this function." + << "\n\t this: " << this ); + + return loss_pos; + } + + + private: + virtual long get_num_dimensions ( + ) const + { + // The psi/w vector will begin with all the edge dims and then follow with the node dims. + return edge_dims + node_dims; + } + + virtual long get_num_samples ( + ) const + { + return samples.size(); + } + + template + typename enable_if >::type get_joint_feature_vector ( + const sample_type& sample, + const label_type& label, + psi_type& psi + ) const + { + psi.set_size(get_num_dimensions()); + psi = 0; + for (unsigned long i = 0; i < sample.number_of_nodes(); ++i) + { + // accumulate the node vectors + if (label[i] == true) + set_rowm(psi, range(edge_dims, psi.size()-1)) += sample.node(i).data; + + for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n) + { + const unsigned long j = sample.node(i).neighbor(n).index(); + + // Don't double count edges. Also only include the vector if + // the labels disagree. + if (i < j && label[i] != label[j]) + { + set_rowm(psi, range(0, edge_dims-1)) -= sample.node(i).edge(n); + } + } + } + } + + template + void add_to_sparse_vect ( + T& psi, + const T& vect, + unsigned long offset + ) const + { + for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i) + { + psi.insert(psi.end(), std::make_pair(i->first+offset, i->second)); + } + } + + template + void subtract_from_sparse_vect ( + T& psi, + const T& vect + ) const + { + for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i) + { + psi.insert(psi.end(), std::make_pair(i->first, -i->second)); + } + } + + template + typename disable_if >::type get_joint_feature_vector ( + const sample_type& sample, + const label_type& label, + psi_type& psi + ) const + { + psi.clear(); + for (unsigned long i = 0; i < sample.number_of_nodes(); ++i) + { + // accumulate the node vectors + if (label[i] == true) + add_to_sparse_vect(psi, sample.node(i).data, edge_dims); + + for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n) + { + const unsigned long j = sample.node(i).neighbor(n).index(); + + // Don't double count edges. Also only include the vector if + // the labels disagree. + if (i < j && label[i] != label[j]) + { + subtract_from_sparse_vect(psi, sample.node(i).edge(n)); + } + } + } + } + + virtual void get_truth_joint_feature_vector ( + long idx, + feature_vector_type& psi + ) const + { + get_joint_feature_vector(samples[idx], labels[idx], psi); + } + + virtual void separation_oracle ( + const long idx, + const matrix_type& current_solution, + double& loss, + feature_vector_type& psi + ) const + { + const sample_type& samp = samples[idx]; + + // setup the potts graph based on samples[idx] and current_solution. + graph::kernel_1a g; + copy_graph_structure(samp, g); + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + g.node(i).data = dot(rowm(current_solution,range(edge_dims,current_solution.size()-1)), + samp.node(i).data); + + // Include a loss augmentation so that we will get the proper loss augmented + // max when we use find_max_factor_graph_potts() below. + if (labels[idx][i]) + g.node(i).data -= get_loss_for_sample(idx,i,!labels[idx][i]); + else + g.node(i).data += get_loss_for_sample(idx,i,!labels[idx][i]); + + for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n) + { + const unsigned long j = g.node(i).neighbor(n).index(); + // Don't compute an edge weight more than once. + if (i < j) + { + g.node(i).edge(n) = dot(rowm(current_solution,range(0,edge_dims-1)), + samp.node(i).edge(n)); + } + } + + } + + std::vector labeling; + find_max_factor_graph_potts(g, labeling); + + + std::vector bool_labeling; + bool_labeling.reserve(labeling.size()); + // figure out the loss + loss = 0; + for (unsigned long i = 0; i < labeling.size(); ++i) + { + const bool predicted_label = (labeling[i]!= 0); + bool_labeling.push_back(predicted_label); + loss += get_loss_for_sample(idx, i, predicted_label); + } + + // compute psi + get_joint_feature_vector(samp, bool_labeling, psi); + } + + double get_loss_for_sample ( + long sample_idx, + long node_idx, + bool predicted_label + ) const + /*! + requires + - 0 <= sample_idx < labels.size() + - 0 <= node_idx < labels[sample_idx].size() + ensures + - returns the loss incurred for predicting that the node + samples[sample_idx].node(node_idx) has a label of predicted_label. + !*/ + { + const bool true_label = labels[sample_idx][node_idx]; + if (true_label != predicted_label) + { + if (losses.size() != 0) + return losses[sample_idx][node_idx]; + else if (true_label == true) + return loss_pos; + else + return loss_neg; + } + else + { + // no loss for making the correct prediction. + return 0; + } + } + + const dlib::array& samples; + const std::vector& labels; + const std::vector >& losses; + + long node_dims; + long edge_dims; + double loss_pos; + double loss_neg; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_ + + diff --git a/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem_abstract.h new file mode 100644 index 000000000..ab99ed8f4 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_graph_labeling_problem_abstract.h @@ -0,0 +1,249 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_ABSTRACT_Hh_ + +#include "../array/array_kernel_abstract.h" +#include "../graph/graph_kernel_abstract.h" +#include "../matrix/matrix_abstract.h" +#include "sparse_vector_abstract.h" +#include "structural_svm_problem_threaded_abstract.h" +#include + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type + > + bool is_graph_labeling_problem ( + const dlib::array& samples, + const std::vector >& labels + ); + /*! + requires + - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h + - graph_type::type and graph_type::edge_type are either both dlib::matrix types + capable of containing column vectors or both some kind of sparse vector type + as defined in dlib/svm/sparse_vector_abstract.h. + ensures + - Note that a graph labeling problem is a task to learn a binary classifier which + predicts the correct label for each node in the provided graphs. Additionally, + we have information in the form of edges between nodes where edges are present + when we believe the linked nodes are likely to have the same label. Therefore, + part of a graph labeling problem is to learn to score each edge in terms of how + strongly the edge should enforce labeling consistency between its two nodes. + Thus, to be a valid graph labeling problem, samples should contain example graphs + of connected nodes while labels should indicate the desired label of each node. + The precise requirements for a valid graph labeling problem are listed below. + - This function returns true if all of the following are true and false otherwise: + - is_learning_problem(samples, labels) == true + - All the vectors stored on the edges of each graph in samples + contain only values which are >= 0. + - for all valid i: + - graph_contains_length_one_cycle(samples[i]) == false + - samples[i].number_of_nodes() == labels[i].size() + (i.e. Every graph node gets its own label) + - if (graph_type::edge_type is a dlib::matrix) then + - All the nodes must contain vectors with the same number of dimensions. + - All the edges must contain vectors with the same number of dimensions. + (However, edge vectors may differ in dimension from node vectors.) + - All vectors have non-zero size. That is, they have more than 0 dimensions. + !*/ + + template < + typename graph_type + > + bool is_graph_labeling_problem ( + const dlib::array& samples, + const std::vector >& labels, + std::string& reason_for_failure + ); + /*! + This function is identical to the above version of is_graph_labeling_problem() + except that if it returns false it will populate reason_for_failure with a message + describing why the graph is not a valid learning problem. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + bool sizes_match ( + const std::vector >& lhs, + const std::vector >& rhs + ); + /*! + ensures + - returns true if the sizes of lhs and rhs, as well as their constituent vectors + all match. In particular, we return true if all of the following conditions are + met and false otherwise: + - lhs.size() == rhs.size() + - for all valid i: + - lhs[i].size() == rhs[i].size() + !*/ + +// ---------------------------------------------------------------------------------------- + + bool all_values_are_nonnegative ( + const std::vector >& x + ); + /*! + ensures + - returns true if all the double values contained in x are >= 0. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename graph_type + > + class structural_svm_graph_labeling_problem : noncopyable, + public structural_svm_problem_threaded, + typename graph_type::type > + { + /*! + REQUIREMENTS ON graph_type + - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h + - graph_type::type and graph_type::edge_type must be either matrix objects + capable of representing column vectors or some kind of sparse vector + type as defined in dlib/svm/sparse_vector_abstract.h. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning the weight vectors needed to use + a graph_labeler object. It learns the parameter vectors by formulating + the problem as a structural SVM problem. + !*/ + + public: + typedef matrix matrix_type; + typedef typename graph_type::type feature_vector_type; + typedef graph_type sample_type; + typedef std::vector label_type; + + structural_svm_graph_labeling_problem( + const dlib::array& samples, + const std::vector& labels, + const std::vector >& losses, + unsigned long num_threads + ); + /*! + requires + - is_graph_labeling_problem(samples,labels) == true + - if (losses.size() != 0) then + - sizes_match(labels, losses) == true + - all_values_are_nonnegative(losses) == true + ensures + - This object attempts to learn a mapping from the given samples to the + given labels. In particular, it attempts to learn to predict labels[i] + based on samples[i]. Or in other words, this object can be used to learn + parameter vectors, E and W, such that a graph_labeler declared as: + graph_labeler labeler(E,W) + results in a labeler object which attempts to compute the following mapping: + labels[i] == labeler(samples[i]) + - When you use this object with the oca optimizer you get back just one + big parameter vector as the solution. Therefore, note that this single + big vector is the concatenation of E and W. The first get_num_edge_weights() + elements of this vector correspond to E and the rest is W. + - This object will use num_threads threads during the optimization + procedure. You should set this parameter equal to the number of + available processing cores on your machine. + - if (losses.size() == 0) then + - #get_loss_on_positive_class() == 1.0 + - #get_loss_on_negative_class() == 1.0 + - #get_losses().size() == 0 + - The losses argument is effectively ignored if its size is zero. + - else + - #get_losses() == losses + - Each node in the training data has its own loss value defined by + the corresponding entry of losses. In particular, this means that + the node with label labels[i][j] incurs a loss of losses[i][j] if + it is incorrectly labeled. + - The get_loss_on_positive_class() and get_loss_on_negative_class() + parameters are ignored. Only get_losses() is used in this case. + !*/ + + const std::vector >& get_losses ( + ) const; + /*! + ensures + - returns the losses vector given to this object's constructor. + This vector defines the per sample loss values used. If the vector + is empty then the loss values defined by get_loss_on_positive_class() and + get_loss_on_positive_class() are used instead. + !*/ + + long get_num_edge_weights ( + ) const; + /*! + ensures + - returns the dimensionality of the edge weight vector. It is also + important to know that when using the oca solver with this object, + you must set it to generate non-negative weights for the edge weight + part of the total weight vector. You can do this by passing get_num_edge_weights() + to the third argument to oca::operator(). + !*/ + + void set_loss_on_positive_class ( + double loss + ); + /*! + requires + - loss >= 0 + - get_losses().size() == 0 + ensures + - #get_loss_on_positive_class() == loss + !*/ + + void set_loss_on_negative_class ( + double loss + ); + /*! + requires + - loss >= 0 + - get_losses().size() == 0 + ensures + - #get_loss_on_negative_class() == loss + !*/ + + double get_loss_on_positive_class ( + ) const; + /*! + requires + - get_losses().size() == 0 + ensures + - returns the loss incurred when a graph node which is supposed to have + a label of true gets misclassified. This value controls how much we care + about correctly classifying nodes which should be labeled as true. Larger + loss values indicate that we care more strongly than smaller values. + !*/ + + double get_loss_on_negative_class ( + ) const; + /*! + requires + - get_losses().size() == 0 + ensures + - returns the loss incurred when a graph node which is supposed to have + a label of false gets misclassified. This value controls how much we care + about correctly classifying nodes which should be labeled as false. Larger + loss values indicate that we care more strongly than smaller values. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_ABSTRACT_Hh_ + + + + diff --git a/ml/dlib/dlib/svm/structural_svm_object_detection_problem.h b/ml/dlib/dlib/svm/structural_svm_object_detection_problem.h new file mode 100644 index 000000000..1c54a42b1 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_object_detection_problem.h @@ -0,0 +1,531 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_Hh_ +#define DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_Hh_ + +#include "structural_svm_object_detection_problem_abstract.h" +#include "../matrix.h" +#include "structural_svm_problem_threaded.h" +#include +#include "../string.h" +#include "../array.h" +#include "../image_processing/full_object_detection.h" +#include "../image_processing/box_overlap_testing.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type, + typename image_array_type + > + class structural_svm_object_detection_problem : public structural_svm_problem_threaded >, + noncopyable + { + public: + + structural_svm_object_detection_problem( + const image_scanner_type& scanner, + const test_box_overlap& overlap_tester, + const bool auto_overlap_tester, + const image_array_type& images_, + const std::vector >& truth_object_detections_, + const std::vector >& ignore_, + const test_box_overlap& ignore_overlap_tester_, + unsigned long num_threads = 2 + ) : + structural_svm_problem_threaded >(num_threads), + boxes_overlap(overlap_tester), + images(images_), + truth_object_detections(truth_object_detections_), + ignore(ignore_), + ignore_overlap_tester(ignore_overlap_tester_), + match_eps(0.5), + loss_per_false_alarm(1), + loss_per_missed_target(1) + { +#ifdef ENABLE_ASSERTS + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(images_, truth_object_detections_) && + ignore_.size() == images_.size() && + scanner.get_num_detection_templates() > 0, + "\t structural_svm_object_detection_problem::structural_svm_object_detection_problem()" + << "\n\t Invalid inputs were given to this function " + << "\n\t scanner.get_num_detection_templates(): " << scanner.get_num_detection_templates() + << "\n\t is_learning_problem(images_,truth_object_detections_): " << is_learning_problem(images_,truth_object_detections_) + << "\n\t ignore.size(): " << ignore.size() + << "\n\t images.size(): " << images.size() + << "\n\t this: " << this + ); + for (unsigned long i = 0; i < truth_object_detections.size(); ++i) + { + for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) + { + DLIB_ASSERT(truth_object_detections[i][j].num_parts() == scanner.get_num_movable_components_per_detection_template(), + "\t trained_function_type structural_object_detection_trainer::train()" + << "\n\t invalid inputs were given to this function" + << "\n\t truth_object_detections["< max_num_dets) + max_num_dets = truth_object_detections[i].size(); + } + max_num_dets = max_num_dets*3 + 10; + + initialize_scanners(scanner, num_threads); + + if (auto_overlap_tester) + { + auto_configure_overlap_tester(); + } + } + + test_box_overlap get_overlap_tester ( + ) const + { + return boxes_overlap; + } + + void set_match_eps ( + double eps + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < eps && eps < 1, + "\t void structural_svm_object_detection_problem::set_match_eps(eps)" + << "\n\t Invalid inputs were given to this function " + << "\n\t eps: " << eps + << "\n\t this: " << this + ); + + match_eps = eps; + } + + double get_match_eps ( + ) const + { + return match_eps; + } + + double get_loss_per_missed_target ( + ) const + { + return loss_per_missed_target; + } + + void set_loss_per_missed_target ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_svm_object_detection_problem::set_loss_per_missed_target(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_missed_target = loss; + } + + double get_loss_per_false_alarm ( + ) const + { + return loss_per_false_alarm; + } + + void set_loss_per_false_alarm ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_svm_object_detection_problem::set_loss_per_false_alarm(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_false_alarm = loss; + } + + private: + + void auto_configure_overlap_tester( + ) + { + std::vector > mapped_rects(truth_object_detections.size()); + for (unsigned long i = 0; i < truth_object_detections.size(); ++i) + { + mapped_rects[i].resize(truth_object_detections[i].size()); + for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) + { + mapped_rects[i][j] = scanners[i].get_best_matching_rect(truth_object_detections[i][j].get_rect()); + } + } + + boxes_overlap = find_tight_overlap_tester(mapped_rects); + } + + + virtual long get_num_dimensions ( + ) const + { + return scanners[0].get_num_dimensions() + + 1;// for threshold + } + + virtual long get_num_samples ( + ) const + { + return images.size(); + } + + virtual void get_truth_joint_feature_vector ( + long idx, + feature_vector_type& psi + ) const + { + const image_scanner_type& scanner = scanners[idx]; + + psi.set_size(get_num_dimensions()); + std::vector mapped_rects; + + psi = 0; + for (unsigned long i = 0; i < truth_object_detections[idx].size(); ++i) + { + mapped_rects.push_back(scanner.get_best_matching_rect(truth_object_detections[idx][i].get_rect())); + scanner.get_feature_vector(truth_object_detections[idx][i], psi); + } + psi(scanner.get_num_dimensions()) = -1.0*truth_object_detections[idx].size(); + + // check if any of the boxes overlap. If they do then it is impossible for + // us to learn to correctly classify this sample + for (unsigned long i = 0; i < mapped_rects.size(); ++i) + { + for (unsigned long j = i+1; j < mapped_rects.size(); ++j) + { + if (boxes_overlap(mapped_rects[i], mapped_rects[j])) + { + const double area_overlap = mapped_rects[i].intersect(mapped_rects[j]).area(); + const double match_amount = area_overlap/(double)( mapped_rects[i]+mapped_rects[j]).area(); + const double overlap_amount = area_overlap/std::min(mapped_rects[i].area(),mapped_rects[j].area()); + + using namespace std; + ostringstream sout; + sout << "An impossible set of object labels was detected. This is happening because "; + sout << "the truth labels for an image contain rectangles which overlap according to the "; + sout << "test_box_overlap object supplied for non-max suppression. To resolve this, you "; + sout << "either need to relax the test_box_overlap object so it doesn't mark these rectangles as "; + sout << "overlapping or adjust the truth rectangles in your training dataset. "; + + // make sure the above string fits nicely into a command prompt window. + string temp = sout.str(); + sout.str(""); sout << wrap_string(temp,0,0) << endl << endl; + + + sout << "image index: "<< idx << endl; + sout << "The offending rectangles are:\n"; + sout << "rect1: "<< mapped_rects[i] << endl; + sout << "rect2: "<< mapped_rects[j] << endl; + sout << "match amount: " << match_amount << endl; + sout << "overlap amount: " << overlap_amount << endl; + throw dlib::impossible_labeling_error(sout.str()); + } + } + } + + // make sure the mapped rectangles are within match_eps of the + // truth rectangles. + for (unsigned long i = 0; i < mapped_rects.size(); ++i) + { + const double area = (truth_object_detections[idx][i].get_rect().intersect(mapped_rects[i])).area(); + const double total_area = (truth_object_detections[idx][i].get_rect() + mapped_rects[i]).area(); + if (area/total_area <= match_eps) + { + using namespace std; + ostringstream sout; + sout << "An impossible set of object labels was detected. This is happening because "; + sout << "none of the object locations checked by the supplied image scanner is a close "; + sout << "enough match to one of the truth boxes in your training dataset. To resolve this "; + sout << "you need to either lower the match_eps, adjust the settings of the image scanner "; + sout << "so that it is capable of hitting this truth box, or adjust the offending truth rectangle so it "; + sout << "can be matched by the current image scanner. Also, if you "; + sout << "are using the scan_fhog_pyramid object then you could try using a finer image pyramid. "; + sout << "Additionally, the scan_fhog_pyramid scans a fixed aspect ratio box across the image when it "; + sout << "searches for objects. So if you are getting this error and you are using the scan_fhog_pyramid, "; + sout << "it's very likely the problem is that your training dataset contains truth rectangles of widely "; + sout << "varying aspect ratios. The solution is to make sure your training boxes all have about the same aspect ratio. "; + + + // make sure the above string fits nicely into a command prompt window. + string temp = sout.str(); + sout.str(""); sout << wrap_string(temp,0,0) << endl << endl; + + sout << "image index "<< idx << endl; + sout << "match_eps: "<< match_eps << endl; + sout << "best possible match: "<< area/total_area << endl; + sout << "truth rect: "<< truth_object_detections[idx][i].get_rect() << endl; + sout << "truth rect width/height: "<< truth_object_detections[idx][i].get_rect().width()/(double)truth_object_detections[idx][i].get_rect().height() << endl; + sout << "truth rect area: "<< truth_object_detections[idx][i].get_rect().area() << endl; + sout << "nearest detection template rect: "<< mapped_rects[i] << endl; + sout << "nearest detection template rect width/height: "<< mapped_rects[i].width()/(double)mapped_rects[i].height() << endl; + sout << "nearest detection template rect area: "<< mapped_rects[i].area() << endl; + throw dlib::impossible_labeling_error(sout.str()); + } + + } + } + + virtual void separation_oracle ( + const long idx, + const matrix_type& current_solution, + scalar_type& loss, + feature_vector_type& psi + ) const + { + const image_scanner_type& scanner = scanners[idx]; + + std::vector > dets; + const double thresh = current_solution(scanner.get_num_dimensions()); + + + scanner.detect(current_solution, dets, thresh-loss_per_false_alarm); + + + // The loss will measure the number of incorrect detections. A detection is + // incorrect if it doesn't hit a truth rectangle or if it is a duplicate detection + // on a truth rectangle. + loss = truth_object_detections[idx].size()*loss_per_missed_target; + + // Measure the loss augmented score for the detections which hit a truth rect. + std::vector truth_score_hits(truth_object_detections[idx].size(), 0); + + // keep track of which truth boxes we have hit so far. + std::vector hit_truth_table(truth_object_detections[idx].size(), false); + + std::vector final_dets; + // The point of this loop is to fill out the truth_score_hits array. + for (unsigned long i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i) + { + if (overlaps_any_box(boxes_overlap, final_dets, dets[i].second)) + continue; + + const std::pair truth = find_best_match(truth_object_detections[idx], dets[i].second); + + final_dets.push_back(dets[i].second); + + const double truth_match = truth.first; + // if hit truth rect + if (truth_match > match_eps) + { + // if this is the first time we have seen a detect which hit truth_object_detections[idx][truth.second] + const double score = dets[i].first - thresh; + if (hit_truth_table[truth.second] == false) + { + hit_truth_table[truth.second] = true; + truth_score_hits[truth.second] += score; + } + else + { + truth_score_hits[truth.second] += score + loss_per_false_alarm; + } + } + } + + hit_truth_table.assign(hit_truth_table.size(), false); + + final_dets.clear(); +#ifdef ENABLE_ASSERTS + double total_score = 0; +#endif + // Now figure out which detections jointly maximize the loss and detection score sum. We + // need to take into account the fact that allowing a true detection in the output, while + // initially reducing the loss, may allow us to increase the loss later with many duplicate + // detections. + for (unsigned long i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i) + { + if (overlaps_any_box(boxes_overlap, final_dets, dets[i].second)) + continue; + + const std::pair truth = find_best_match(truth_object_detections[idx], dets[i].second); + + const double truth_match = truth.first; + if (truth_match > match_eps) + { + if (truth_score_hits[truth.second] > loss_per_missed_target) + { + if (!hit_truth_table[truth.second]) + { + hit_truth_table[truth.second] = true; + final_dets.push_back(dets[i].second); +#ifdef ENABLE_ASSERTS + total_score += dets[i].first; +#endif + loss -= loss_per_missed_target; + } + else + { + final_dets.push_back(dets[i].second); +#ifdef ENABLE_ASSERTS + total_score += dets[i].first; +#endif + loss += loss_per_false_alarm; + } + } + } + else if (!overlaps_ignore_box(idx,dets[i].second)) + { + // didn't hit anything + final_dets.push_back(dets[i].second); +#ifdef ENABLE_ASSERTS + total_score += dets[i].first; +#endif + loss += loss_per_false_alarm; + } + } + + psi.set_size(get_num_dimensions()); + psi = 0; + for (unsigned long i = 0; i < final_dets.size(); ++i) + scanner.get_feature_vector(scanner.get_full_object_detection(final_dets[i], current_solution), psi); + +#ifdef ENABLE_ASSERTS + const double psi_score = dot(psi, current_solution); + DLIB_CASSERT(std::abs(psi_score-total_score) <= 1e-4 * std::max(1.0,std::max(std::abs(psi_score),std::abs(total_score))), + "\t The get_feature_vector() and detect() methods of image_scanner_type are not in sync." + << "\n\t The relative error is too large to be attributed to rounding error." + << "\n\t error: " << std::abs(psi_score-total_score) + << "\n\t psi_score: " << psi_score + << "\n\t total_score: " << total_score + ); +#endif + + psi(scanner.get_num_dimensions()) = -1.0*final_dets.size(); + } + + + bool overlaps_ignore_box ( + const long idx, + const dlib::rectangle& rect + ) const + { + for (unsigned long i = 0; i < ignore[idx].size(); ++i) + { + if (ignore_overlap_tester(ignore[idx][i], rect)) + return true; + } + return false; + } + + std::pair find_best_match( + const std::vector& boxes, + const rectangle rect + ) const + /*! + ensures + - determines which rectangle in boxes matches rect the most and + returns the amount of this match. Specifically, the match is + a number O with the following properties: + - 0 <= O <= 1 + - Let R be the maximum matching rectangle in boxes, then + O == (R.intersect(rect)).area() / (R + rect).area() + - O == 0 if there is no match with any rectangle. + !*/ + { + double match = 0; + unsigned int best_idx = 0; + for (unsigned long i = 0; i < boxes.size(); ++i) + { + + const unsigned long area = rect.intersect(boxes[i].get_rect()).area(); + if (area != 0) + { + const double new_match = area / static_cast((rect + boxes[i].get_rect()).area()); + if (new_match > match) + { + match = new_match; + best_idx = i; + } + } + } + + return std::make_pair(match,best_idx); + } + + struct init_scanners_helper + { + init_scanners_helper ( + array& scanners_, + const image_array_type& images_ + ) : + scanners(scanners_), + images(images_) + {} + + array& scanners; + const image_array_type& images; + + void operator() (long i ) const + { + scanners[i].load(images[i]); + } + }; + + void initialize_scanners ( + const image_scanner_type& scanner, + unsigned long num_threads + ) + { + scanners.set_max_size(images.size()); + scanners.set_size(images.size()); + + for (unsigned long i = 0; i < scanners.size(); ++i) + scanners[i].copy_configuration(scanner); + + // now load the images into all the scanners + parallel_for(num_threads, 0, scanners.size(), init_scanners_helper(scanners, images)); + } + + + test_box_overlap boxes_overlap; + + mutable array scanners; + + const image_array_type& images; + const std::vector >& truth_object_detections; + const std::vector >& ignore; + const test_box_overlap ignore_overlap_tester; + + unsigned long max_num_dets; + double match_eps; + double loss_per_false_alarm; + double loss_per_missed_target; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_Hh_ + + diff --git a/ml/dlib/dlib/svm/structural_svm_object_detection_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_object_detection_problem_abstract.h new file mode 100644 index 000000000..d73c5920d --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_object_detection_problem_abstract.h @@ -0,0 +1,178 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_ABSTRACT_Hh_ + +#include "../matrix.h" +#include "structural_svm_problem_threaded_abstract.h" +#include +#include "../image_processing/full_object_detection_abstract.h" +#include "../image_processing/box_overlap_testing.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename image_scanner_type, + typename image_array_type + > + class structural_svm_object_detection_problem : public structural_svm_problem_threaded >, + noncopyable + { + /*! + REQUIREMENTS ON image_scanner_type + image_scanner_type must be an implementation of + dlib/image_processing/scan_fhog_pyramid_abstract.h or + dlib/image_processing/scan_image_custom_abstract.h or + dlib/image_processing/scan_image_pyramid_abstract.h or + dlib/image_processing/scan_image_boxes_abstract.h + + REQUIREMENTS ON image_array_type + image_array_type must be an implementation of dlib/array/array_kernel_abstract.h + and it must contain objects which can be accepted by image_scanner_type::load(). + + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning the parameter vector needed to use a + scan_image_pyramid, scan_fhog_pyramid, scan_image_custom, or + scan_image_boxes object. + + It learns the parameter vector by formulating the problem as a structural + SVM problem. The exact details of the method are described in the paper + Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046). + + + !*/ + + public: + + structural_svm_object_detection_problem( + const image_scanner_type& scanner, + const test_box_overlap& overlap_tester, + const bool auto_overlap_tester, + const image_array_type& images, + const std::vector >& truth_object_detections, + const std::vector >& ignore, + const test_box_overlap& ignore_overlap_tester, + unsigned long num_threads = 2 + ); + /*! + requires + - is_learning_problem(images, truth_object_detections) + - ignore.size() == images.size() + - scanner.get_num_detection_templates() > 0 + - scanner.load(images[0]) must be a valid expression. + - for all valid i, j: + - truth_object_detections[i][j].num_parts() == scanner.get_num_movable_components_per_detection_template() + - all_parts_in_rect(truth_object_detections[i][j]) == true + ensures + - This object attempts to learn a mapping from the given images to the + object locations given in truth_object_detections. In particular, it + attempts to learn to predict truth_object_detections[i] based on + images[i]. Or in other words, this object can be used to learn a + parameter vector, w, such that an object_detector declared as: + object_detector detector(scanner,get_overlap_tester(),w) + results in a detector object which attempts to compute the locations of + all the objects in truth_object_detections. So if you called + detector(images[i]) you would hopefully get a list of rectangles back + that had truth_object_detections[i].size() elements and contained exactly + the rectangles indicated by truth_object_detections[i]. + - if (auto_overlap_tester == true) then + - #get_overlap_tester() == a test_box_overlap object that is configured + using the find_tight_overlap_tester() routine and the contents of + truth_object_detections. + - else + - #get_overlap_tester() == overlap_tester + - #get_match_eps() == 0.5 + - This object will use num_threads threads during the optimization + procedure. You should set this parameter equal to the number of + available processing cores on your machine. + - #get_loss_per_missed_target() == 1 + - #get_loss_per_false_alarm() == 1 + - for all valid i: + - Within images[i] any detections that match against a rectangle in + ignore[i], according to ignore_overlap_tester, are ignored. That is, + the optimizer doesn't care if the detector outputs a detection that + matches any of the ignore rectangles or if it fails to output a + detection for an ignore rectangle. Therefore, if there are objects + in your dataset that you are unsure you want to detect or otherwise + don't care if the detector gets or doesn't then you can mark them + with ignore rectangles and the optimizer will simply ignore them. + !*/ + + test_box_overlap get_overlap_tester ( + ) const; + /*! + ensures + - returns the overlap tester used by this object. + !*/ + + void set_match_eps ( + double eps + ); + /*! + requires + - 0 < eps < 1 + ensures + - #get_match_eps() == eps + !*/ + + double get_match_eps ( + ) const; + /*! + ensures + - returns the amount of alignment necessary for a detection to be considered + as matching with a ground truth rectangle. The precise formula for determining + if two rectangles match each other is the following, rectangles A and B match + if and only if: + A.intersect(B).area()/(A+B).area() > get_match_eps() + !*/ + + double get_loss_per_missed_target ( + ) const; + /*! + ensures + - returns the amount of loss experienced for failing to detect one of the + targets. + !*/ + + void set_loss_per_missed_target ( + double loss + ); + /*! + requires + - loss > 0 + ensures + - #get_loss_per_missed_target() == loss + !*/ + + double get_loss_per_false_alarm ( + ) const; + /*! + ensures + - returns the amount of loss experienced for emitting a false alarm detection. + Or in other words, the loss for generating a detection that doesn't correspond + to one of the truth rectangles. + !*/ + + void set_loss_per_false_alarm ( + double loss + ); + /*! + requires + - loss > 0 + ensures + - #get_loss_per_false_alarm() == loss + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_ObJECT_DETECTION_PROBLEM_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/svm/structural_svm_problem.h b/ml/dlib/dlib/svm/structural_svm_problem.h new file mode 100644 index 000000000..3a73457b9 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_problem.h @@ -0,0 +1,649 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_SVM_PRObLEM_Hh_ +#define DLIB_STRUCTURAL_SVM_PRObLEM_Hh_ + +#include "structural_svm_problem_abstract.h" +#include "../algs.h" +#include +#include "../optimization/optimization_oca.h" +#include "../matrix.h" +#include "sparse_vector.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + struct nuclear_norm_regularizer + { + long first_dimension; + long nr; + long nc; + double regularization_strength; + }; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename structural_svm_problem + > + class cache_element_structural_svm + { + public: + + cache_element_structural_svm ( + ) : prob(0), sample_idx(0), last_true_risk_computed(std::numeric_limits::infinity()) {} + + typedef typename structural_svm_problem::scalar_type scalar_type; + typedef typename structural_svm_problem::matrix_type matrix_type; + typedef typename structural_svm_problem::feature_vector_type feature_vector_type; + + void init ( + const structural_svm_problem* prob_, + const long idx + ) + /*! + ensures + - This object will be a cache for the idx-th sample in the given + structural_svm_problem. + !*/ + { + prob = prob_; + sample_idx = idx; + + loss.clear(); + psi.clear(); + lru_count.clear(); + + if (prob->get_max_cache_size() != 0) + { + prob->get_truth_joint_feature_vector(idx, true_psi); + compact_sparse_vector(true_psi); + } + } + + void get_truth_joint_feature_vector_cached ( + feature_vector_type& psi + ) const + { + if (prob->get_max_cache_size() != 0) + psi = true_psi; + else + prob->get_truth_joint_feature_vector(sample_idx, psi); + + if (is_matrix::value) + { + DLIB_CASSERT((long)psi.size() == prob->get_num_dimensions(), + "The dimensionality of your PSI vector doesn't match get_num_dimensions()"); + } + } + + void separation_oracle_cached ( + const bool use_only_cache, + const bool skip_cache, + const scalar_type& saved_current_risk_gap, + const matrix_type& current_solution, + scalar_type& out_loss, + feature_vector_type& out_psi + ) const + { + const bool cache_enabled = prob->get_max_cache_size() != 0; + + // Don't waste time computing this if the cache isn't going to be used. + const scalar_type dot_true_psi = cache_enabled ? dot(true_psi, current_solution) : 0; + + scalar_type best_risk = -std::numeric_limits::infinity(); + unsigned long best_idx = 0; + long max_lru_count = 0; + if (cache_enabled) + { + // figure out which element in the cache is the best (i.e. has the biggest risk) + for (unsigned long i = 0; i < loss.size(); ++i) + { + const scalar_type risk = loss[i] + dot(psi[i], current_solution) - dot_true_psi; + if (risk > best_risk) + { + best_risk = risk; + out_loss = loss[i]; + best_idx = i; + } + if (lru_count[i] > max_lru_count) + max_lru_count = lru_count[i]; + } + + if (!skip_cache) + { + // Check if the best psi vector in the cache is still good enough to use as + // a proxy for the true separation oracle. If the risk value has dropped + // by enough to get into the stopping condition then the best psi isn't + // good enough. + if ((best_risk + saved_current_risk_gap > last_true_risk_computed && + best_risk >= 0) || use_only_cache) + { + out_psi = psi[best_idx]; + lru_count[best_idx] = max_lru_count + 1; + return; + } + } + } + + + prob->separation_oracle(sample_idx, current_solution, out_loss, out_psi); + if (is_matrix::value) + { + DLIB_CASSERT((long)out_psi.size() == prob->get_num_dimensions(), + "The dimensionality of your PSI vector doesn't match get_num_dimensions()"); + } + + if (!cache_enabled) + return; + + compact_sparse_vector(out_psi); + + last_true_risk_computed = out_loss + dot(out_psi, current_solution) - dot_true_psi; + + // If the separation oracle is only solved approximately then the result might + // not be as good as just selecting true_psi as the output. So here we check + // if that is the case. + if (last_true_risk_computed < 0 && best_risk < 0) + { + out_psi = true_psi; + out_loss = 0; + } + // Alternatively, an approximate separation oracle might not do as well as just + // selecting from the cache. So if that is the case when just take the best + // element from the cache. + else if (last_true_risk_computed < best_risk) + { + out_psi = psi[best_idx]; + out_loss = loss[best_idx]; + lru_count[best_idx] = max_lru_count + 1; + } + // if the cache is full + else if (loss.size() >= prob->get_max_cache_size()) + { + // find least recently used cache entry for idx-th sample + const long i = index_of_min(mat(lru_count)); + + // save our new data in the cache + loss[i] = out_loss; + psi[i] = out_psi; + + const long max_use = max(mat(lru_count)); + // Make sure this new cache entry has the best lru count since we have used + // it most recently. + lru_count[i] = max_use + 1; + } + else + { + // In this case we just append the new psi into the cache. + + loss.push_back(out_loss); + psi.push_back(out_psi); + long max_use = 1; + if (lru_count.size() != 0) + max_use = max(mat(lru_count)) + 1; + lru_count.push_back(max_use); + } + } + + private: + // Do nothing if T isn't actually a sparse vector + template void compact_sparse_vector( T& ) const { } + + template < + typename T, + typename U, + typename alloc + > + void compact_sparse_vector ( + std::vector,alloc>& vect + ) const + { + // If the sparse vector has more entires than dimensions then it must have some + // duplicate elements. So compact them using make_sparse_vector_inplace(). + if (vect.size() > (unsigned long)prob->get_num_dimensions()) + { + make_sparse_vector_inplace(vect); + // make sure the vector doesn't use more RAM than is necessary + std::vector,alloc>(vect).swap(vect); + } + } + + const structural_svm_problem* prob; + + long sample_idx; + + mutable feature_vector_type true_psi; + mutable std::vector loss; + mutable std::vector psi; + mutable std::vector lru_count; + mutable double last_true_risk_computed; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type_, + typename feature_vector_type_ = matrix_type_ + > + class structural_svm_problem : public oca_problem + { + public: + /*! + CONVENTION + - C == get_c() + - eps == get_epsilon() + - max_iterations == get_max_iterations() + - if (skip_cache) then + - we won't use the oracle cache when we need to evaluate the separation + oracle. Instead, we will directly call the user supplied separation_oracle(). + + - get_max_cache_size() == max_cache_size + + - if (cache.size() != 0) then + - cache.size() == get_num_samples() + - for all i: cache[i] == the cached results of calls to separation_oracle() + for the i-th sample. + !*/ + + typedef matrix_type_ matrix_type; + typedef typename matrix_type::type scalar_type; + typedef feature_vector_type_ feature_vector_type; + + structural_svm_problem ( + ) : + saved_current_risk_gap(0), + eps(0.001), + max_iterations(10000), + verbose(false), + skip_cache(true), + count_below_eps(0), + max_cache_size(5), + converged(false), + nuclear_norm_part(0), + cache_based_eps(std::numeric_limits::infinity()), + C(1) + {} + + scalar_type get_cache_based_epsilon ( + ) const + { + return cache_based_eps; + } + + void set_cache_based_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void structural_svm_problem::set_cache_based_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + cache_based_eps = eps_; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void structural_svm_problem::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const { return eps; } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void set_max_cache_size ( + unsigned long max_size + ) + { + max_cache_size = max_size; + } + + unsigned long get_max_cache_size ( + ) const { return max_cache_size; } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet( + ) + { + verbose = false; + } + + scalar_type get_c ( + ) const { return C; } + + void set_c ( + scalar_type C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void structural_svm_problem::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + void add_nuclear_norm_regularizer ( + long first_dimension, + long rows, + long cols, + double regularization_strength + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= first_dimension && first_dimension < get_num_dimensions() && + 0 <= rows && 0 <= cols && rows*cols+first_dimension <= get_num_dimensions() && + 0 < regularization_strength, + "\t void structural_svm_problem::add_nuclear_norm_regularizer()" + << "\n\t Invalid arguments were given to this function." + << "\n\t first_dimension: " << first_dimension + << "\n\t rows: " << rows + << "\n\t cols: " << cols + << "\n\t get_num_dimensions(): " << get_num_dimensions() + << "\n\t regularization_strength: " << regularization_strength + << "\n\t this: " << this + ); + + impl::nuclear_norm_regularizer temp; + temp.first_dimension = first_dimension; + temp.nr = rows; + temp.nc = cols; + temp.regularization_strength = regularization_strength; + nuclear_norm_regularizers.push_back(temp); + } + + unsigned long num_nuclear_norm_regularizers ( + ) const { return nuclear_norm_regularizers.size(); } + + void clear_nuclear_norm_regularizers ( + ) { nuclear_norm_regularizers.clear(); } + + virtual long get_num_dimensions ( + ) const = 0; + + virtual long get_num_samples ( + ) const = 0; + + virtual void get_truth_joint_feature_vector ( + long idx, + feature_vector_type& psi + ) const = 0; + + virtual void separation_oracle ( + const long idx, + const matrix_type& current_solution, + scalar_type& loss, + feature_vector_type& psi + ) const = 0; + + private: + + virtual bool risk_has_lower_bound ( + scalar_type& lower_bound + ) const + { + lower_bound = 0; + return true; + } + + virtual bool optimization_status ( + scalar_type current_objective_value, + scalar_type current_error_gap, + scalar_type current_risk_value, + scalar_type current_risk_gap, + unsigned long num_cutting_planes, + unsigned long num_iterations + ) const + { + if (verbose) + { + using namespace std; + if (nuclear_norm_regularizers.size() != 0) + { + cout << "objective: " << current_objective_value << endl; + cout << "objective gap: " << current_error_gap << endl; + cout << "risk: " << current_risk_value-nuclear_norm_part << endl; + cout << "risk+nuclear norm: " << current_risk_value << endl; + cout << "risk+nuclear norm gap: " << current_risk_gap << endl; + cout << "num planes: " << num_cutting_planes << endl; + cout << "iter: " << num_iterations << endl; + } + else + { + cout << "objective: " << current_objective_value << endl; + cout << "objective gap: " << current_error_gap << endl; + cout << "risk: " << current_risk_value << endl; + cout << "risk gap: " << current_risk_gap << endl; + cout << "num planes: " << num_cutting_planes << endl; + cout << "iter: " << num_iterations << endl; + } + cout << endl; + } + + if (num_iterations >= max_iterations) + return true; + + saved_current_risk_gap = current_risk_gap; + + if (converged) + { + return (current_risk_gap < std::max(cache_based_eps,cache_based_eps*current_risk_value)) || + (current_risk_gap == 0); + } + + if (current_risk_gap < eps) + { + // Only stop when we see that the risk gap is small enough on a non-cached + // iteration. But even then, if we are supposed to do the cache based + // refinement then we just mark that we have "converged" to avoid further + // calls to the separation oracle and run all subsequent iterations off the + // cache. + if (skip_cache || max_cache_size == 0) + { + converged = true; + skip_cache = false; + return (current_risk_gap < std::max(cache_based_eps,cache_based_eps*current_risk_value)) || + (current_risk_gap == 0); + } + + ++count_below_eps; + + // Only disable the cache if we have seen a few consecutive iterations that + // look to have converged. + if (count_below_eps > 1) + { + // Instead of stopping we shouldn't use the cache on the next iteration. This way + // we can be sure to have the best solution rather than assuming the cache is up-to-date + // enough. + skip_cache = true; + count_below_eps = 0; + } + } + else + { + count_below_eps = 0; + skip_cache = false; + } + + return false; + } + + virtual void get_risk ( + matrix_type& w, + scalar_type& risk, + matrix_type& subgradient + ) const + { + feature_vector_type ftemp; + const unsigned long num = get_num_samples(); + + // initialize the cache and compute psi_true. + if (cache.size() == 0) + { + cache.resize(get_num_samples()); + for (unsigned long i = 0; i < cache.size(); ++i) + cache[i].init(this,i); + + psi_true.set_size(w.size(),1); + psi_true = 0; + + for (unsigned long i = 0; i < num; ++i) + { + cache[i].get_truth_joint_feature_vector_cached(ftemp); + + subtract_from(psi_true, ftemp); + } + } + + subgradient = psi_true; + scalar_type total_loss = 0; + call_separation_oracle_on_all_samples(w,subgradient,total_loss); + + subgradient /= num; + total_loss /= num; + risk = total_loss + dot(subgradient,w); + + if (nuclear_norm_regularizers.size() != 0) + { + matrix_type grad; + scalar_type obj; + compute_nuclear_norm_parts(w, grad, obj); + risk += obj; + subgradient += grad; + } + } + + virtual void call_separation_oracle_on_all_samples ( + const matrix_type& w, + matrix_type& subgradient, + scalar_type& total_loss + ) const + { + feature_vector_type ftemp; + const unsigned long num = get_num_samples(); + for (unsigned long i = 0; i < num; ++i) + { + scalar_type loss; + separation_oracle_cached(i, w, loss, ftemp); + total_loss += loss; + add_to(subgradient, ftemp); + } + } + + protected: + + void compute_nuclear_norm_parts( + const matrix_type& m, + matrix_type& grad, + scalar_type& obj + ) const + { + obj = 0; + grad.set_size(m.size(), 1); + grad = 0; + + matrix u,v,w,f; + nuclear_norm_part = 0; + for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i) + { + const long nr = nuclear_norm_regularizers[i].nr; + const long nc = nuclear_norm_regularizers[i].nc; + const long size = nr*nc; + const long idx = nuclear_norm_regularizers[i].first_dimension; + const double strength = nuclear_norm_regularizers[i].regularization_strength; + + f = matrix_cast(reshape(rowm(m, range(idx, idx+size-1)), nr, nc)); + svd3(f, u,w,v); + + + const double norm = sum(w); + obj += strength*norm; + nuclear_norm_part += strength*norm/C; + + f = u*trans(v); + + set_rowm(grad, range(idx, idx+size-1)) = matrix_cast(strength*reshape_to_column_vector(f)); + } + + obj /= C; + grad /= C; + } + + void separation_oracle_cached ( + const long idx, + const matrix_type& current_solution, + scalar_type& loss, + feature_vector_type& psi + ) const + { + cache[idx].separation_oracle_cached(converged, + skip_cache, + saved_current_risk_gap, + current_solution, + loss, + psi); + } + + std::vector nuclear_norm_regularizers; + + mutable scalar_type saved_current_risk_gap; + mutable matrix_type psi_true; + scalar_type eps; + unsigned long max_iterations; + mutable bool verbose; + + + mutable std::vector > cache; + mutable bool skip_cache; + mutable int count_below_eps; + unsigned long max_cache_size; + mutable bool converged; + mutable double nuclear_norm_part; + scalar_type cache_based_eps; + + scalar_type C; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_PRObLEM_Hh_ + diff --git a/ml/dlib/dlib/svm/structural_svm_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_problem_abstract.h new file mode 100644 index 000000000..20b3d73a7 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_problem_abstract.h @@ -0,0 +1,348 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_SVM_PRObLEM_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_SVM_PRObLEM_ABSTRACT_Hh_ + +#include "../optimization/optimization_oca_abstract.h" +#include "sparse_vector_abstract.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type_, + typename feature_vector_type_ = matrix_type_ + > + class structural_svm_problem : public oca_problem + { + public: + /*! + REQUIREMENTS ON matrix_type_ + - matrix_type_ == a dlib::matrix capable of storing column vectors + + REQUIREMENTS ON feature_vector_type_ + - feature_vector_type_ == a dlib::matrix capable of storing column vectors + or an unsorted sparse vector type as defined in dlib/svm/sparse_vector_abstract.h. + + INITIAL VALUE + - get_epsilon() == 0.001 + - get_max_iterations() == 10000 + - get_max_cache_size() == 5 + - get_c() == 1 + - get_cache_based_epsilon() == std::numeric_limits::infinity() + (I.e. the cache based epsilon feature is disabled) + - num_nuclear_norm_regularizers() == 0 + - This object will not be verbose + + WHAT THIS OBJECT REPRESENTS + This object is a tool for solving the optimization problem associated with + a structural support vector machine. A structural SVM is a supervised + machine learning method for learning to predict complex outputs. This is + contrasted with a binary classifier which makes only simple yes/no + predictions. A structural SVM, on the other hand, can learn to predict + complex outputs such as entire parse trees or DNA sequence alignments. To + do this, it learns a function F(x,y) which measures how well a particular + data sample x matches a label y. When used for prediction, the best label + for a new x is given by the y which maximizes F(x,y). + + To use this object you inherit from it, provide implementations of its four + pure virtual functions, and then pass your object to the oca optimizer. + Also, you should only pass an instance of this object to the oca optimizer + once. That is, the act of using a structural_svm_problem instance with the + oca solver "uses" the structural_svm_problem instance. If you want to + solve the same problem multiple times then you must use a fresh instance of + your structural_svm_problem. + + + To define the optimization problem precisely, we first introduce some notation: + - let PSI(x,y) == the joint feature vector for input x and a label y. + - let F(x,y|w) == dot(w,PSI(x,y)). + - let LOSS(idx,y) == the loss incurred for predicting that the idx-th training + sample has a label of y. Note that LOSS() should always be >= 0 and should + become exactly 0 when y is the correct label for the idx-th sample. + - let x_i == the i-th training sample. + - let y_i == the correct label for the i-th training sample. + - The number of data samples is N. + + Then the optimization problem solved using this object is the following: + Minimize: h(w) == 0.5*dot(w,w) + C*R(w) + + Where R(w) == sum from i=1 to N: 1/N * sample_risk(i,w) + and sample_risk(i,w) == max over all Y: LOSS(i,Y) + F(x_i,Y|w) - F(x_i,y_i|w) + and C > 0 + + + + For an introduction to structured support vector machines you should consult + the following paper: + Predicting Structured Objects with Support Vector Machines by + Thorsten Joachims, Thomas Hofmann, Yisong Yue, and Chun-nam Yu + + For a more detailed discussion of the particular algorithm implemented by this + object see the following paper: + T. Joachims, T. Finley, Chun-Nam Yu, Cutting-Plane Training of Structural SVMs, + Machine Learning, 77(1):27-59, 2009. + + Note that this object is essentially a tool for solving the 1-Slack structural + SVM with margin-rescaling. Specifically, see Algorithm 3 in the above referenced + paper. + !*/ + + typedef matrix_type_ matrix_type; + typedef typename matrix_type::type scalar_type; + typedef feature_vector_type_ feature_vector_type; + + structural_svm_problem ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer + to execute. Specifically, the algorithm stops when the average sample + risk (i.e. R(w) as defined above) is within epsilon of its optimal value. + + Also note that sample risk is an upper bound on a sample's loss. So + you can think of this epsilon value as saying "solve the optimization + problem until the average loss per sample is within epsilon of it's + optimal value". + !*/ + + scalar_type get_cache_based_epsilon ( + ) const; + /*! + ensures + - if (get_max_cache_size() != 0) then + - The solver will not stop when the average sample risk is within + get_epsilon() of its optimal value. Instead, it will keep running + but will run the optimizer completely on the cache until the average + sample risk is within #get_cache_based_epsilon() of its optimal + value. This means that it will perform this additional refinement in + the solution accuracy without making any additional calls to the + separation_oracle(). This is useful when using a nuclear norm + regularization term because it allows you to quickly solve the + optimization problem to a high precision, which in the case of a + nuclear norm regularized problem means that many of the learned + matrices will be low rank or very close to low rank due to the + nuclear norm regularizer. This may not happen without solving the + problem to a high accuracy or their ranks may be difficult to + determine, so the extra accuracy given by the cache based refinement + is very useful. Finally, note that we include the nuclear norm term + as part of the "risk" for the purposes of determining when to stop. + - else + - The value of #get_cache_based_epsilon() has no effect. + !*/ + + void set_cache_based_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_cache_based_epsilon() == eps + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + unsigned long get_max_iterations ( + ); + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void set_max_cache_size ( + unsigned long max_size + ); + /*! + ensures + - #get_max_cache_size() == max_size + !*/ + + unsigned long get_max_cache_size ( + ) const; + /*! + ensures + - Returns the number of joint feature vectors per training sample kept in + the separation oracle cache. This cache is used to avoid unnecessary + calls to the user supplied separation_oracle() function. Note that a + value of 0 means that caching is not used at all. This is appropriate + if the separation oracle is cheap to evaluate. + !*/ + + void add_nuclear_norm_regularizer ( + long first_dimension, + long rows, + long cols, + double regularization_strength + ); + /*! + requires + - 0 <= first_dimension < get_num_dimensions() + - 0 <= rows + - 0 <= cols + - first_dimension+rows*cols <= get_num_dimensions() + - 0 < regularization_strength + ensures + - Adds a nuclear norm regularization term to the optimization problem + solved by this object. That is, instead of solving: + Minimize: h(w) == 0.5*dot(w,w) + C*R(w) + this object will solve: + Minimize: h(w) == 0.5*dot(w,w) + C*R(w) + regularization_strength*nuclear_norm_of(part of w) + where "part of w" is the part of w indicated by the arguments to this + function. In particular, the part of w included in the nuclear norm is + exactly the matrix reshape(rowm(w, range(first_dimension, first_dimension+rows*cols-1)), rows, cols). + Therefore, if you think of the w vector as being the concatenation of a + bunch of matrices then you can use multiple calls to add_nuclear_norm_regularizer() + to add nuclear norm regularization terms to any of the matrices packed into w. + - #num_nuclear_norm_regularizers() == num_nuclear_norm_regularizers() + 1 + !*/ + + unsigned long num_nuclear_norm_regularizers ( + ) const; + /*! + ensures + - returns the number of nuclear norm regularizers that are currently a part + of this optimization problem. That is, returns the number of times + add_nuclear_norm_regularizer() has been called since the last call to + clear_nuclear_norm_regularizers() or object construction, whichever is + most recent. + !*/ + + void clear_nuclear_norm_regularizers ( + ); + /*! + ensures + - #num_nuclear_norm_regularizers() == 0 + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + scalar_type get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter that + determines the trade off between trying to fit the training data + exactly or allowing more errors but hopefully improving the + generalization of the resulting classifier. Larger values encourage + exact fitting while smaller values of C may encourage better + generalization. + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c() == C + !*/ + + // -------------------------------- + // User supplied routines + // -------------------------------- + + virtual long get_num_dimensions ( + ) const = 0; + /*! + ensures + - returns the dimensionality of a joint feature vector + !*/ + + virtual long get_num_samples ( + ) const = 0; + /*! + ensures + - returns the number of training samples in this problem. + !*/ + + virtual void get_truth_joint_feature_vector ( + long idx, + feature_vector_type& psi + ) const = 0; + /*! + requires + - 0 <= idx < get_num_samples() + ensures + - #psi == PSI(x_idx, y_idx) + (i.e. the joint feature vector for the idx-th training sample its true label.) + !*/ + + virtual void separation_oracle ( + const long idx, + const matrix_type& current_solution, + scalar_type& loss, + feature_vector_type& psi + ) const = 0; + /*! + requires + - 0 <= idx < get_num_samples() + - current_solution.size() == get_num_dimensions() + ensures + - runs the separation oracle on the idx-th sample. We define this as follows: + - let X == the idx-th training sample. + - let PSI(X,y) == the joint feature vector for input X and an arbitrary label y. + - let F(X,y) == dot(current_solution,PSI(X,y)). + - let LOSS(idx,y) == the loss incurred for predicting that the idx-th sample + has a label of y. Note that LOSS() should always be >= 0 and should + become exactly 0 when y is the correct label for the idx-th sample. + + Then the separation oracle finds a Y such that: + Y = argmax over all y: LOSS(idx,y) + F(X,y) + (i.e. It finds the label which maximizes the above expression.) + + Finally, we can define the outputs of this function as: + - #loss == LOSS(idx,Y) + - #psi == PSI(X,Y) + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_PRObLEM_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/structural_svm_problem_threaded.h b/ml/dlib/dlib/svm/structural_svm_problem_threaded.h new file mode 100644 index 000000000..e981ba8d9 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_problem_threaded.h @@ -0,0 +1,157 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_ +#define DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_ + +#include "structural_svm_problem_threaded_abstract.h" +#include "../algs.h" +#include +#include "structural_svm_problem.h" +#include "../matrix.h" +#include "sparse_vector.h" +#include +#include "../threads.h" +#include "../misc_api.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type_, + typename feature_vector_type_ = matrix_type_ + > + class structural_svm_problem_threaded : public structural_svm_problem + { + public: + + typedef matrix_type_ matrix_type; + typedef typename matrix_type::type scalar_type; + typedef feature_vector_type_ feature_vector_type; + + explicit structural_svm_problem_threaded ( + unsigned long num_threads + ) : + tp(num_threads), + num_iterations_executed(0) + {} + + unsigned long get_num_threads ( + ) const { return tp.num_threads_in_pool(); } + + private: + + struct binder + { + binder ( + const structural_svm_problem_threaded& self_, + const matrix_type& w_, + matrix_type& subgradient_, + scalar_type& total_loss_, + bool buffer_subgradients_locally_ + ) : self(self_), w(w_), subgradient(subgradient_), total_loss(total_loss_), + buffer_subgradients_locally(buffer_subgradients_locally_){} + + void call_oracle ( + long begin, + long end + ) + { + // If we are only going to call the separation oracle once then don't run + // the slightly more complex for loop version of this code. Or if we just + // don't want to run the complex buffering one. The code later on decides + // if we should do the buffering based on how long it takes to execute. We + // do this because, when the subgradient is really high dimensional it can + // take a lot of time to add them together. So we might want to avoid + // doing that. + if (end-begin <= 1 || !buffer_subgradients_locally) + { + scalar_type loss; + feature_vector_type ftemp; + for (long i = begin; i < end; ++i) + { + self.separation_oracle_cached(i, w, loss, ftemp); + + auto_mutex lock(self.accum_mutex); + total_loss += loss; + add_to(subgradient, ftemp); + } + } + else + { + scalar_type loss = 0; + matrix_type faccum(subgradient.size(),1); + faccum = 0; + + feature_vector_type ftemp; + + for (long i = begin; i < end; ++i) + { + scalar_type loss_temp; + self.separation_oracle_cached(i, w, loss_temp, ftemp); + loss += loss_temp; + add_to(faccum, ftemp); + } + + auto_mutex lock(self.accum_mutex); + total_loss += loss; + add_to(subgradient, faccum); + } + } + + const structural_svm_problem_threaded& self; + const matrix_type& w; + matrix_type& subgradient; + scalar_type& total_loss; + bool buffer_subgradients_locally; + }; + + + virtual void call_separation_oracle_on_all_samples ( + const matrix_type& w, + matrix_type& subgradient, + scalar_type& total_loss + ) const + { + ++num_iterations_executed; + + const uint64 start_time = ts.get_timestamp(); + + bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean(); + + // every 50 iterations we should try to flip the buffering scheme to see if + // doing it the other way might be better. + if ((num_iterations_executed%50) == 0) + { + buffer_subgradients_locally = !buffer_subgradients_locally; + } + + binder b(*this, w, subgradient, total_loss, buffer_subgradients_locally); + parallel_for_blocked(tp, 0, this->get_num_samples(), b, &binder::call_oracle); + + const uint64 stop_time = ts.get_timestamp(); + + if (buffer_subgradients_locally) + with_buffer_time.add(stop_time-start_time); + else + without_buffer_time.add(stop_time-start_time); + + } + + mutable thread_pool tp; + mutable mutex accum_mutex; + mutable timestamper ts; + mutable running_stats with_buffer_time; + mutable running_stats without_buffer_time; + mutable unsigned long num_iterations_executed; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_ + + diff --git a/ml/dlib/dlib/svm/structural_svm_problem_threaded_abstract.h b/ml/dlib/dlib/svm/structural_svm_problem_threaded_abstract.h new file mode 100644 index 000000000..3cfc6a6eb --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_problem_threaded_abstract.h @@ -0,0 +1,68 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_ABSTRACT_Hh_ + +#include "structural_svm_problem_abstract.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type_, + typename feature_vector_type_ = matrix_type_ + > + class structural_svm_problem_threaded : public structural_svm_problem + { + public: + /*! + WHAT THIS OBJECT REPRESENTS + This object is identical to the structural_svm_problem object defined in + dlib/svm/structural_svm_problem_abstract.h except that its constructor + takes a number which defines how many threads will be used to make concurrent + calls to the separation_oracle() routine. + + So this object lets you take advantage of a multi-core system. You should + set the num_threads parameter equal to the number of available cores. Note + that the separation_oracle() function which you provide must be thread safe + if you are to use this version of the structural_svm_problem. In + particular, it must be safe to call separation_oracle() concurrently from + different threads. However, it is guaranteed that different threads will + never make concurrent calls to separation_oracle() using the same idx value + (i.e. the first argument). + !*/ + + typedef matrix_type_ matrix_type; + typedef typename matrix_type::type scalar_type; + typedef feature_vector_type_ feature_vector_type; + + structural_svm_problem ( + unsigned long num_threads + ); + /*! + ensures + - this object is properly initialized + - #get_num_threads() == num_threads + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - Returns the number of threads which will be used to make concurrent + calls to the separation_oracle() function. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem.h b/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem.h new file mode 100644 index 000000000..68dff66f5 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem.h @@ -0,0 +1,281 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_Hh_ +#define DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_Hh_ + + +#include "structural_svm_sequence_labeling_problem_abstract.h" +#include "../matrix.h" +#include "sequence_labeler.h" +#include +#include "structural_svm_problem_threaded.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + namespace fe_helpers + { + + // ---------------------------------------------------------------------------------------- + + struct get_feats_functor + { + get_feats_functor(std::vector >& feats_) : feats(feats_) {} + + inline void operator() ( + unsigned long feat_index, + double feat_value + ) + { + feats.push_back(std::make_pair(feat_index, feat_value)); + } + + inline void operator() ( + unsigned long feat_index + ) + { + feats.push_back(std::make_pair(feat_index, 1)); + } + + std::vector >& feats; + }; + + // ---------------------------------------------------------------------------------------- + + template + void get_feature_vector( + std::vector >& feats, + const feature_extractor& fe, + const sequence_type& sequence, + const matrix_exp& candidate_labeling, + unsigned long position + ) + { + get_feats_functor funct(feats); + fe.get_features(funct, sequence,candidate_labeling, position); + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class structural_svm_sequence_labeling_problem : noncopyable, + public structural_svm_problem_threaded, std::vector > > + { + public: + typedef matrix matrix_type; + typedef std::vector > feature_vector_type; + + typedef typename feature_extractor::sequence_type sequence_type; + + structural_svm_sequence_labeling_problem( + const std::vector& samples_, + const std::vector >& labels_, + const feature_extractor& fe_, + unsigned long num_threads = 2 + ) : + structural_svm_problem_threaded(num_threads), + samples(samples_), + labels(labels_), + fe(fe_) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true && + contains_invalid_labeling(fe, samples, labels) == false, + "\t structural_svm_sequence_labeling_problem::structural_svm_sequence_labeling_problem()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels) + << "\n\t contains_invalid_labeling(fe,samples,labels): " << contains_invalid_labeling(fe,samples,labels) + << "\n\t this: " << this + ); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < labels.size(); ++i) + { + for (unsigned long j = 0; j < labels[i].size(); ++j) + { + // make sure requires clause is not broken + DLIB_ASSERT(labels[i][j] < fe.num_labels(), + "\t structural_svm_sequence_labeling_problem::structural_svm_sequence_labeling_problem()" + << "\n\t The given labels in labels are invalid." + << "\n\t labels[i][j]: " << labels[i][j] + << "\n\t fe.num_labels(): " << fe.num_labels() + << "\n\t i: " << i + << "\n\t j: " << j + << "\n\t this: " << this + ); + } + } +#endif + + loss_values.assign(num_labels(), 1); + + } + + unsigned long num_labels ( + ) const { return fe.num_labels(); } + + double get_loss ( + unsigned long label + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(label < num_labels(), + "\t void structural_svm_sequence_labeling_problem::get_loss()" + << "\n\t invalid inputs were given to this function" + << "\n\t label: " << label + << "\n\t num_labels(): " << num_labels() + << "\n\t this: " << this + ); + + return loss_values[label]; + } + + void set_loss ( + unsigned long label, + double value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(label < num_labels() && value >= 0, + "\t void structural_svm_sequence_labeling_problem::set_loss()" + << "\n\t invalid inputs were given to this function" + << "\n\t label: " << label + << "\n\t num_labels(): " << num_labels() + << "\n\t value: " << value + << "\n\t this: " << this + ); + + loss_values[label] = value; + } + + private: + virtual long get_num_dimensions ( + ) const + { + return fe.num_features(); + } + + virtual long get_num_samples ( + ) const + { + return samples.size(); + } + + void get_joint_feature_vector ( + const sequence_type& sample, + const std::vector& label, + feature_vector_type& psi + ) const + { + psi.clear(); + + const int order = fe.order(); + + matrix candidate_labeling; + for (unsigned long i = 0; i < sample.size(); ++i) + { + candidate_labeling = rowm(mat(label), range(i, std::max((int)i-order,0))); + + fe_helpers::get_feature_vector(psi,fe,sample,candidate_labeling, i); + } + } + + virtual void get_truth_joint_feature_vector ( + long idx, + feature_vector_type& psi + ) const + { + get_joint_feature_vector(samples[idx], labels[idx], psi); + } + + class map_prob + { + public: + unsigned long order() const { return fe.order(); } + unsigned long num_states() const { return fe.num_labels(); } + + map_prob( + const sequence_type& sequence_, + const std::vector& label_, + const feature_extractor& fe_, + const matrix& weights_, + const std::vector& loss_values_ + ) : + sequence(sequence_), + label(label_), + fe(fe_), + weights(weights_), + loss_values(loss_values_) + { + } + + unsigned long number_of_nodes( + ) const + { + return sequence.size(); + } + + template < + typename EXP + > + double factor_value ( + unsigned long node_id, + const matrix_exp& node_states + ) const + { + if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id)) + return -std::numeric_limits::infinity(); + + double loss = 0; + if (node_states(0) != label[node_id]) + loss = loss_values[label[node_id]]; + + return fe_helpers::dot(weights, fe, sequence, node_states, node_id) + loss; + } + + const sequence_type& sequence; + const std::vector& label; + const feature_extractor& fe; + const matrix& weights; + const std::vector& loss_values; + }; + + virtual void separation_oracle ( + const long idx, + const matrix_type& current_solution, + scalar_type& loss, + feature_vector_type& psi + ) const + { + std::vector y; + find_max_factor_graph_viterbi(map_prob(samples[idx],labels[idx],fe,current_solution,loss_values), y); + + loss = 0; + for (unsigned long i = 0; i < y.size(); ++i) + { + if (y[i] != labels[idx][i]) + loss += loss_values[labels[idx][i]]; + } + + get_joint_feature_vector(samples[idx], y, psi); + } + + const std::vector& samples; + const std::vector >& labels; + const feature_extractor& fe; + std::vector loss_values; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_Hh_ + diff --git a/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem_abstract.h b/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem_abstract.h new file mode 100644 index 000000000..b46a55350 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_svm_sequence_labeling_problem_abstract.h @@ -0,0 +1,110 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_ABSTRACT_Hh_ + + +#include "../matrix.h" +#include +#include "structural_svm_problem_threaded_abstract.h" +#include "sequence_labeler_abstract.h" + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + template < + typename feature_extractor + > + class structural_svm_sequence_labeling_problem : noncopyable, + public structural_svm_problem_threaded, + std::vector > > + { + /*! + REQUIREMENTS ON feature_extractor + It must be an object that implements an interface compatible with + the example_feature_extractor defined in dlib/svm/sequence_labeler_abstract.h. + + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning the weight vector needed to use + a sequence_labeler object. + + It learns the parameter vector by formulating the problem as a structural + SVM problem. The general approach is discussed in the paper: + Hidden Markov Support Vector Machines by + Y. Altun, I. Tsochantaridis, T. Hofmann + While the particular optimization strategy used is the method from: + T. Joachims, T. Finley, Chun-Nam Yu, Cutting-Plane Training of + Structural SVMs, Machine Learning, 77(1):27-59, 2009. + !*/ + + public: + typedef typename feature_extractor::sequence_type sequence_type; + + structural_svm_sequence_labeling_problem( + const std::vector& samples, + const std::vector >& labels, + const feature_extractor& fe, + unsigned long num_threads = 2 + ); + /*! + requires + - is_sequence_labeling_problem(samples, labels) == true + - contains_invalid_labeling(fe, samples, labels) == false + - for all valid i and j: labels[i][j] < fe.num_labels() + ensures + - This object attempts to learn a mapping from the given samples to the + given labels. In particular, it attempts to learn to predict labels[i] + based on samples[i]. Or in other words, this object can be used to learn + a parameter vector, w, such that a sequence_labeler declared as: + sequence_labeler labeler(w,fe) + results in a labeler object which attempts to compute the following mapping: + labels[i] == labeler(samples[i]) + - This object will use num_threads threads during the optimization + procedure. You should set this parameter equal to the number of + available processing cores on your machine. + - #num_labels() == fe.num_labels() + - for all valid i: #get_loss(i) == 1 + !*/ + + unsigned long num_labels ( + ) const; + /*! + ensures + - returns the number of possible labels in this learning problem + !*/ + + double get_loss ( + unsigned long label + ) const; + /*! + requires + - label < num_labels() + ensures + - returns the loss incurred when a sequence element with the given + label is misclassified. This value controls how much we care about + correctly classifying this type of label. Larger loss values indicate + that we care more strongly than smaller values. + !*/ + + void set_loss ( + unsigned long label, + double value + ); + /*! + requires + - label < num_labels() + - value >= 0 + ensures + - #get_loss(label) == value + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/structural_track_association_trainer.h b/ml/dlib/dlib/svm/structural_track_association_trainer.h new file mode 100644 index 000000000..87fb829b2 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_track_association_trainer.h @@ -0,0 +1,404 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_ +#define DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_ + +#include "structural_track_association_trainer_abstract.h" +#include "../algs.h" +#include "svm.h" +#include +#include "track_association_function.h" +#include "structural_assignment_trainer.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename detection_type, + typename label_type + > + std::vector get_unlabeled_dets ( + const std::vector >& dets + ) + { + std::vector temp; + temp.reserve(dets.size()); + for (unsigned long i = 0; i < dets.size(); ++i) + temp.push_back(dets[i].det); + return temp; + } + + } + +// ---------------------------------------------------------------------------------------- + + class structural_track_association_trainer + { + public: + + structural_track_association_trainer ( + ) + { + set_defaults(); + } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void structural_track_association_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + double get_epsilon ( + ) const { return eps; } + + void set_max_cache_size ( + unsigned long max_size + ) + { + max_cache_size = max_size; + } + + unsigned long get_max_cache_size ( + ) const + { + return max_cache_size; + } + + void set_loss_per_false_association ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_track_association_trainer::set_loss_per_false_association(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_false_association = loss; + } + + double get_loss_per_false_association ( + ) const + { + return loss_per_false_association; + } + + void set_loss_per_track_break ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_track_association_trainer::set_loss_per_track_break(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_track_break = loss; + } + + double get_loss_per_track_break ( + ) const + { + return loss_per_track_break; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + void set_c ( + double C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void structural_track_association_trainer::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + double get_c ( + ) const + { + return C; + } + + bool learns_nonnegative_weights ( + ) const { return learn_nonnegative_weights; } + + void set_learns_nonnegative_weights ( + bool value + ) + { + learn_nonnegative_weights = value; + } + + template < + typename detection_type, + typename label_type + > + const track_association_function train ( + const std::vector > > >& samples + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_track_association_problem(samples), + "\t track_association_function structural_track_association_trainer::train()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_track_association_problem(samples): " << is_track_association_problem(samples) + ); + + typedef typename detection_type::track_type track_type; + + const unsigned long num_dims = find_num_dims(samples); + + feature_extractor_track_association fe(num_dims, learn_nonnegative_weights?num_dims:0); + structural_assignment_trainer > trainer(fe); + + + if (verbose) + trainer.be_verbose(); + + trainer.set_c(C); + trainer.set_epsilon(eps); + trainer.set_max_cache_size(max_cache_size); + trainer.set_num_threads(num_threads); + trainer.set_oca(solver); + trainer.set_loss_per_missed_association(loss_per_track_break); + trainer.set_loss_per_false_association(loss_per_false_association); + + std::vector, std::vector > > assignment_samples; + std::vector > labels; + for (unsigned long i = 0; i < samples.size(); ++i) + convert_dets_to_association_sets(samples[i], assignment_samples, labels); + + + return track_association_function(trainer.train(assignment_samples, labels)); + } + + template < + typename detection_type, + typename label_type + > + const track_association_function train ( + const std::vector > >& sample + ) const + { + std::vector > > > samples; + samples.push_back(sample); + return train(samples); + } + + private: + + template < + typename detection_type, + typename label_type + > + static unsigned long find_num_dims ( + const std::vector > > >& samples + ) + { + typedef typename detection_type::track_type track_type; + // find a detection_type object so we can call get_similarity_features() and + // find out how big the feature vectors are. + + // for all detection histories + for (unsigned long i = 0; i < samples.size(); ++i) + { + // for all time instances in the detection history + for (unsigned j = 0; j < samples[i].size(); ++j) + { + if (samples[i][j].size() > 0) + { + track_type new_track; + new_track.update_track(samples[i][j][0].det); + typename track_type::feature_vector_type feats; + new_track.get_similarity_features(samples[i][j][0].det, feats); + return feats.size(); + } + } + } + + DLIB_CASSERT(false, + "No detection objects were given in the call to dlib::structural_track_association_trainer::train()"); + } + + template < + typename detections_at_single_time_step, + typename detection_type, + typename track_type + > + static void convert_dets_to_association_sets ( + const std::vector& det_history, + std::vector, std::vector > >& data, + std::vector >& labels + ) + { + if (det_history.size() < 1) + return; + + typedef typename detections_at_single_time_step::value_type::label_type label_type; + std::vector tracks; + // track_labels maps from detection labels to the index in tracks. So track + // with detection label X is at tracks[track_labels[X]]. + std::map track_labels; + add_dets_to_tracks(tracks, track_labels, det_history[0]); + + using namespace impl; + for (unsigned long i = 1; i < det_history.size(); ++i) + { + data.push_back(std::make_pair(get_unlabeled_dets(det_history[i]), tracks)); + labels.push_back(get_association_labels(det_history[i], track_labels)); + add_dets_to_tracks(tracks, track_labels, det_history[i]); + } + } + + template < + typename labeled_detection, + typename label_type + > + static std::vector get_association_labels( + const std::vector& dets, + const std::map& track_labels + ) + { + std::vector assoc(dets.size(),-1); + // find out which detections associate to what tracks + for (unsigned long i = 0; i < dets.size(); ++i) + { + typename std::map::const_iterator j; + j = track_labels.find(dets[i].label); + // If this detection matches one of the tracks then record which track it + // matched with. + if (j != track_labels.end()) + assoc[i] = j->second; + } + return assoc; + } + + template < + typename track_type, + typename label_type, + typename labeled_detection + > + static void add_dets_to_tracks ( + std::vector& tracks, + std::map& track_labels, + const std::vector& dets + ) + { + std::vector updated_track(tracks.size(), false); + + // first assign the dets to the tracks + for (unsigned long i = 0; i < dets.size(); ++i) + { + const label_type& label = dets[i].label; + if (track_labels.count(label)) + { + const unsigned long track_idx = track_labels[label]; + tracks[track_idx].update_track(dets[i].det); + updated_track[track_idx] = true; + } + else + { + // this detection creates a new track + track_type new_track; + new_track.update_track(dets[i].det); + tracks.push_back(new_track); + track_labels[label] = tracks.size()-1; + } + + } + + // Now propagate all the tracks that didn't get any detections. + for (unsigned long i = 0; i < updated_track.size(); ++i) + { + if (!updated_track[i]) + tracks[i].propagate_track(); + } + } + + double C; + oca solver; + double eps; + bool verbose; + unsigned long num_threads; + unsigned long max_cache_size; + bool learn_nonnegative_weights; + double loss_per_track_break; + double loss_per_false_association; + + void set_defaults () + { + C = 100; + verbose = false; + eps = 0.001; + num_threads = 2; + max_cache_size = 5; + learn_nonnegative_weights = false; + loss_per_track_break = 1; + loss_per_false_association = 1; + } + }; + +} + +#endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_ + diff --git a/ml/dlib/dlib/svm/structural_track_association_trainer_abstract.h b/ml/dlib/dlib/svm/structural_track_association_trainer_abstract.h new file mode 100644 index 000000000..e78fadef7 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_track_association_trainer_abstract.h @@ -0,0 +1,268 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_ABSTRACT_Hh_ +#ifdef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_ABSTRACT_Hh_ + +#include "track_association_function_abstract.h" +#include "structural_assignment_trainer_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class structural_track_association_trainer + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for learning to solve a track association problem. That + is, it takes in a set of training data and outputs a track_association_function + you can use to do detection to track association. The training data takes the + form of a set or sets of "track histories". Each track history is a + std::vector where each element contains all the detections from a single time + step. Moreover, each detection has a label that uniquely identifies which + object (e.g. person or whatever) the detection really corresponds to. That is, + the labels indicate the correct detection to track associations. The goal of + this object is then to produce a track_association_function that can perform a + correct detection to track association at each time step. + !*/ + + public: + + structural_track_association_trainer ( + ); + /*! + ensures + - #get_c() == 100 + - this object isn't verbose + - #get_epsilon() == 0.001 + - #get_num_threads() == 2 + - #get_max_cache_size() == 5 + - #learns_nonnegative_weights() == false + - #get_loss_per_track_break() == 1 + - #get_loss_per_false_association() == 1 + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + ensures + - #get_num_threads() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - returns the number of threads used during training. You should + usually set this equal to the number of processing cores on your + machine. + !*/ + + void set_epsilon ( + double eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + double get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer to + train. You can think of this epsilon value as saying "solve the + optimization problem until the average number of association mistakes per + time step is within epsilon of its optimal value". + !*/ + + void set_max_cache_size ( + unsigned long max_size + ); + /*! + ensures + - #get_max_cache_size() == max_size + !*/ + + unsigned long get_max_cache_size ( + ) const; + /*! + ensures + - During training, this object basically runs the track_association_function on + each training sample, over and over. To speed this up, it is possible to + cache the results of these invocations. This function returns the number + of cache elements per training sample kept in the cache. Note that a value + of 0 means caching is not used at all. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a user can + observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_loss_per_false_association ( + double loss + ); + /*! + requires + - loss > 0 + ensures + - #get_loss_per_false_association() == loss + !*/ + + double get_loss_per_false_association ( + ) const; + /*! + ensures + - returns the amount of loss experienced for assigning a detection to the + wrong track. If you care more about avoiding false associations than + avoiding track breaks then you can increase this value. + !*/ + + void set_loss_per_track_break ( + double loss + ); + /*! + requires + - loss > 0 + ensures + - #get_loss_per_track_break() == loss + !*/ + + double get_loss_per_track_break ( + ) const; + /*! + ensures + - returns the amount of loss experienced for incorrectly assigning a + detection to a new track instead of assigning it to its existing track. + If you care more about avoiding track breaks than avoiding things like + track swaps then you can increase this value. + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - Internally this object treats track association learning as a structural + SVM problem. This routine returns a copy of the optimizer used to solve + the structural SVM problem. + !*/ + + void set_c ( + double C + ); + /*! + ensures + - returns the SVM regularization parameter. It is the parameter that + determines the trade-off between trying to fit the training data (i.e. + minimize the loss) or allowing more errors but hopefully improving the + generalization of the resulting track_association_function. Larger + values encourage exact fitting while smaller values of C may encourage + better generalization. + !*/ + + double get_c ( + ) const; + /*! + requires + - C > 0 + ensures + - #get_c() = C + !*/ + + bool learns_nonnegative_weights ( + ) const; + /*! + ensures + - Ultimately, the output of training is a parameter vector that defines the + behavior of the track_association_function. If + learns_nonnegative_weights() == true then the resulting learned parameter + vector will always have non-negative entries. + !*/ + + void set_learns_nonnegative_weights ( + bool value + ); + /*! + ensures + - #learns_nonnegative_weights() == value + !*/ + + template < + typename detection_type, + typename label_type + > + const track_association_function train ( + const std::vector > >& sample + ) const; + /*! + requires + - is_track_association_problem(sample) == true + ensures + - This function attempts to learn to do track association from the given + training data. Note that we interpret sample as a single track history such + that sample[0] are all detections from the first time step, then sample[1] + are detections from the second time step, and so on. + - returns a function F such that: + - Executing F(tracks, detections) will try to correctly associate the + contents of detections to the contents of tracks and perform track + updating and creation. + - if (learns_nonnegative_weights() == true) then + - min(F.get_assignment_function().get_weights()) >= 0 + !*/ + + template < + typename detection_type, + typename label_type + > + const track_association_function train ( + const std::vector > > >& sample + ) const; + /*! + requires + - is_track_association_problem(samples) == true + ensures + - This function attempts to learn to do track association from the given + training data. In this case, we take a set of track histories as + training data instead of just one track history as with the above train() + method. + - returns a function F such that: + - Executing F(tracks, detections) will try to correctly associate the + contents of detections to the contents of tracks and perform track + updating and creation. + - if (learns_nonnegative_weights() == true) then + - min(F.get_assignment_function().get_weights()) >= 0 + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/svm.h b/ml/dlib/dlib/svm/svm.h new file mode 100644 index 000000000..e0587ef4a --- /dev/null +++ b/ml/dlib/dlib/svm/svm.h @@ -0,0 +1,1205 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_ +#define DLIB_SVm_ + +#include "svm_abstract.h" +#include +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "../serialize.h" +#include "../rand.h" +#include "../std_allocator.h" +#include "function.h" +#include "kernel.h" +#include "../enable_if.h" +#include "../optimization.h" +#include "svm_nu_trainer.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + inline bool is_learning_problem_impl ( + const T& x, + const U& x_labels + ) + { + return is_col_vector(x) && + is_col_vector(x_labels) && + x.size() == x_labels.size() && + x.size() > 0; + } + + template < + typename T, + typename U + > + inline bool is_learning_problem ( + const T& x, + const U& x_labels + ) + { + return is_learning_problem_impl(mat(x), mat(x_labels)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + bool is_binary_classification_problem_impl ( + const T& x, + const U& x_labels + ) + { + bool seen_neg_class = false; + bool seen_pos_class = false; + + if (is_learning_problem_impl(x,x_labels) == false) + return false; + + if (x.size() <= 1) return false; + + for (long r = 0; r < x_labels.nr(); ++r) + { + if (x_labels(r) != -1 && x_labels(r) != 1) + return false; + + if (x_labels(r) == 1) + seen_pos_class = true; + if (x_labels(r) == -1) + seen_neg_class = true; + } + + return seen_pos_class && seen_neg_class; + } + + template < + typename T, + typename U + > + bool is_binary_classification_problem ( + const T& x, + const U& x_labels + ) + { + return is_binary_classification_problem_impl(mat(x), mat(x_labels)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename dec_funct_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix test_binary_decision_function_impl ( + const dec_funct_type& dec_funct, + const in_sample_vector_type& x_test, + const in_scalar_vector_type& y_test + ) + { + + // make sure requires clause is not broken + DLIB_ASSERT( is_binary_classification_problem(x_test,y_test) == true, + "\tmatrix test_binary_decision_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_binary_classification_problem(x_test,y_test): " + << ((is_binary_classification_problem(x_test,y_test))? "true":"false")); + + + // count the number of positive and negative examples + long num_pos = 0; + long num_neg = 0; + + + long num_pos_correct = 0; + long num_neg_correct = 0; + + + // now test this trained object + for (long i = 0; i < x_test.nr(); ++i) + { + // if this is a positive example + if (y_test(i) == +1.0) + { + ++num_pos; + if (dec_funct(x_test(i)) >= 0) + ++num_pos_correct; + } + else if (y_test(i) == -1.0) + { + ++num_neg; + if (dec_funct(x_test(i)) < 0) + ++num_neg_correct; + } + else + { + throw dlib::error("invalid input labels to the test_binary_decision_function() function"); + } + } + + + matrix res; + res(0) = (double)num_pos_correct/(double)(num_pos); + res(1) = (double)num_neg_correct/(double)(num_neg); + return res; + } + + template < + typename dec_funct_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix test_binary_decision_function ( + const dec_funct_type& dec_funct, + const in_sample_vector_type& x_test, + const in_scalar_vector_type& y_test + ) + { + return test_binary_decision_function_impl(dec_funct, + mat(x_test), + mat(y_test)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sequence_type + > + bool is_sequence_labeling_problem ( + const std::vector& samples, + const std::vector >& labels + ) + { + if (is_learning_problem(samples, labels)) + { + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (samples[i].size() != labels[i].size()) + return false; + } + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename sequence_type + > + bool is_sequence_segmentation_problem ( + const std::vector& samples, + const std::vector > >& segments + ) + { + if (is_learning_problem(samples, segments)) + { + for (unsigned long i = 0; i < samples.size(); ++i) + { + // Make sure the segments are inside samples[i] and don't overlap with each + // other. + std::vector hits(samples[i].size(), false); + for (unsigned long j = 0; j < segments[i].size(); ++j) + { + const unsigned long begin = segments[i][j].first; + const unsigned long end = segments[i][j].second; + // if the segment is outside the sequence + if (end > samples[i].size()) + return false; + + if (begin >= end) + return false; + + // check for overlap + for (unsigned long k = begin; k < end; ++k) + { + if (hits[k]) + return false; + hits[k] = true; + } + } + } + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename lhs_type, + typename rhs_type + > + bool is_assignment_problem ( + const std::vector, std::vector > >& samples, + const std::vector >& labels + ) + { + std::vector seen_label; + + if (is_learning_problem(samples, labels)) + { + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (samples[i].first.size() != labels[i].size()) + return false; + + seen_label.assign(samples[i].second.size(), false); + + for (unsigned long j = 0; j < labels[i].size(); ++j) + { + if (!(-1 <= labels[i][j] && labels[i][j] < (long)samples[i].second.size())) + return false; + + if (labels[i][j] != -1) + { + // check label uniqueness + if (seen_label[labels[i][j]]) + return false; + + seen_label[labels[i][j]] = true; + } + } + } + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename lhs_type, + typename rhs_type + > + bool is_forced_assignment_problem ( + const std::vector, std::vector > >& samples, + const std::vector >& labels + ) + { + if (is_assignment_problem(samples, labels)) + { + for (unsigned long i = 0; i < samples.size(); ++i) + { + const unsigned long N = sum(mat(labels[i]) != -1); + if (std::min(samples[i].first.size(), samples[i].second.size()) != N) + return false; + } + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename detection_type_, + typename label_type_ = long + > + struct labeled_detection + { + typedef detection_type_ detection_type; + typedef label_type_ label_type; + detection_type det; + label_type label; + }; + + template < + typename detection_type_, + typename label_type_ + > + inline void serialize ( const labeled_detection& item, std::ostream& out) + { + serialize(item.det, out); + serialize(item.label, out); + } + + template < + typename detection_type_, + typename label_type_ + > + inline void deserialize (labeled_detection& item, std::istream& in) + { + deserialize(item.det, in); + deserialize(item.label, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename detection_type, + typename label_type + > + bool is_track_association_problem ( + const std::vector > >& samples + ) + { + if (samples.size() == 0) + return false; + + unsigned long num_nonzero_elements = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (samples.size() > 0) + ++num_nonzero_elements; + } + if (num_nonzero_elements < 2) + return false; + + // now make sure the label_type values are unique within each time step. + for (unsigned long i = 0; i < samples.size(); ++i) + { + std::set vals; + for (unsigned long j = 0; j < samples[i].size(); ++j) + vals.insert(samples[i][j].label); + if (vals.size() != samples[i].size()) + return false; + } + + // passed all tests so it's good + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename detection_type, + typename label_type + > + bool is_track_association_problem ( + const std::vector > > >& samples + ) + { + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (!is_track_association_problem(samples[i])) + return false; + } + + // passed all tests so it's good + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix + cross_validate_trainer_impl ( + const trainer_type& trainer, + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + const long folds + ) + { + typedef typename in_scalar_vector_type::value_type scalar_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef matrix scalar_vector_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true && + 1 < folds && folds <= std::min(sum(y>0),sum(y<0)), + "\tmatrix cross_validate_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t std::min(sum(y>0),sum(y<0)): " << std::min(sum(y>0),sum(y<0)) + << "\n\t folds: " << folds + << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false") + ); + + + // count the number of positive and negative examples + long num_pos = 0; + long num_neg = 0; + for (long r = 0; r < y.nr(); ++r) + { + if (y(r) == +1.0) + ++num_pos; + else + ++num_neg; + } + + // figure out how many positive and negative examples we will have in each fold + const long num_pos_test_samples = num_pos/folds; + const long num_pos_train_samples = num_pos - num_pos_test_samples; + const long num_neg_test_samples = num_neg/folds; + const long num_neg_train_samples = num_neg - num_neg_test_samples; + + + matrix x_test, x_train; + scalar_vector_type y_test, y_train; + x_test.set_size (num_pos_test_samples + num_neg_test_samples); + y_test.set_size (num_pos_test_samples + num_neg_test_samples); + x_train.set_size(num_pos_train_samples + num_neg_train_samples); + y_train.set_size(num_pos_train_samples + num_neg_train_samples); + + long pos_idx = 0; + long neg_idx = 0; + + matrix res; + set_all_elements(res,0); + + for (long i = 0; i < folds; ++i) + { + long cur = 0; + + // load up our positive test samples + while (cur < num_pos_test_samples) + { + if (y(pos_idx) == +1.0) + { + x_test(cur) = pos_idx; + y_test(cur) = +1.0; + ++cur; + } + pos_idx = (pos_idx+1)%x.nr(); + } + + // load up our negative test samples + while (cur < x_test.nr()) + { + if (y(neg_idx) == -1.0) + { + x_test(cur) = neg_idx; + y_test(cur) = -1.0; + ++cur; + } + neg_idx = (neg_idx+1)%x.nr(); + } + + // load the training data from the data following whatever we loaded + // as the testing data + long train_pos_idx = pos_idx; + long train_neg_idx = neg_idx; + cur = 0; + + // load up our positive train samples + while (cur < num_pos_train_samples) + { + if (y(train_pos_idx) == +1.0) + { + x_train(cur) = train_pos_idx; + y_train(cur) = +1.0; + ++cur; + } + train_pos_idx = (train_pos_idx+1)%x.nr(); + } + + // load up our negative train samples + while (cur < x_train.nr()) + { + if (y(train_neg_idx) == -1.0) + { + x_train(cur) = train_neg_idx; + y_train(cur) = -1.0; + ++cur; + } + train_neg_idx = (train_neg_idx+1)%x.nr(); + } + + try + { + // do the training and testing + res += test_binary_decision_function(trainer.train(rowm(x,x_train),y_train),rowm(x,x_test),y_test); + } + catch (invalid_nu_error&) + { + // Just ignore the error in this case since we are going to + // interpret an invalid nu value the same as generating a decision + // function that miss-classifies everything. + } + + } // for (long i = 0; i < folds; ++i) + + return res/(double)folds; + } + + template < + typename trainer_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix + cross_validate_trainer ( + const trainer_type& trainer, + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + const long folds + ) + { + return cross_validate_trainer_impl(trainer, + mat(x), + mat(y), + folds); + } + +// ---------------------------------------------------------------------------------------- + + namespace prob_impl + { + template + struct objective + { + objective ( + const vect_type& f_, + const vect_type& t_ + ) : f(f_), t(t_) {} + + double operator() ( + const matrix& x + ) const + { + const double A = x(0); + const double B = x(1); + + double res = 0; + for (unsigned long i = 0; i < f.size(); ++i) + { + const double val = A*f[i]+B; + // See the paper "A Note on Platt's Probabilistic Outputs for Support Vector Machines" + // for an explanation of why this code looks the way it does (rather than being the + // obvious formula). + if (val < 0) + res += (t[i] - 1)*val + std::log(1 + std::exp(val)); + else + res += t[i]*val + std::log(1 + std::exp(-val)); + } + + return res; + } + + const vect_type& f; + const vect_type& t; + }; + + template + struct der + { + der ( + const vect_type& f_, + const vect_type& t_ + ) : f(f_), t(t_) {} + + matrix operator() ( + const matrix& x + ) const + { + const double A = x(0); + const double B = x(1); + + double derA = 0; + double derB = 0; + + for (unsigned long i = 0; i < f.size(); ++i) + { + const double val = A*f[i]+B; + double p; + // compute p = 1/(1+exp(val)) + // but do so in a way that avoids numerical overflow. + if (val < 0) + p = 1.0/(1 + std::exp(val)); + else + p = std::exp(-val)/(1 + std::exp(-val)); + + derA += f[i]*(t[i] - p); + derB += (t[i] - p); + } + + matrix res; + res = derA, derB; + return res; + } + + const vect_type& f; + const vect_type& t; + }; + + template + struct hessian + { + hessian ( + const vect_type& f_, + const vect_type& t_ + ) : f(f_), t(t_) {} + + matrix operator() ( + const matrix& x + ) const + { + const double A = x(0); + const double B = x(1); + + matrix h; + h = 0; + + for (unsigned long i = 0; i < f.size(); ++i) + { + const double val = A*f[i]+B; + // compute pp = 1/(1+exp(val)) and + // compute pn = 1 - pp + // but do so in a way that avoids numerical overflow and catastrophic cancellation. + double pp, pn; + if (val < 0) + { + const double temp = std::exp(val); + pp = 1.0/(1 + temp); + pn = temp*pp; + } + else + { + const double temp = std::exp(-val); + pn = 1.0/(1 + temp); + pp = temp*pn; + } + + h(0,0) += f[i]*f[i]*pp*pn; + const double temp2 = f[i]*pp*pn; + h(0,1) += temp2; + h(1,0) += temp2; + h(1,1) += pp*pn; + } + + return h; + } + + const vect_type& f; + const vect_type& t; + }; + } + +// ---------------------------------------------------------------------------------------- + + inline double platt_scale ( + const std::pair& params, + const double score + ) + { + return 1/(1 + std::exp(params.first*score + params.second)); + } + +// ---------------------------------------------------------------------------------------- + + template + std::pair learn_platt_scaling ( + const std::vector& scores, + const std::vector& labels + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(scores,labels) == true, + "\t std::pair learn_platt_scaling()" + << "\n\t invalid inputs were given to this function" + << "\n\t scores.size(): " << scores.size() + << "\n\t labels.size(): " << labels.size() + << "\n\t is_binary_classification_problem(scores,labels): " << is_binary_classification_problem(scores,labels) + ); + + const T num_pos = sum(mat(labels)>0); + const T num_neg = sum(mat(labels)<0); + const T hi_target = (num_pos+1)/(num_pos+2); + const T lo_target = 1.0/(num_neg+2); + + std::vector target; + for (unsigned long i = 0; i < labels.size(); ++i) + { + // if this was a positive example + if (labels[i] == +1.0) + { + target.push_back(hi_target); + } + else if (labels[i] == -1.0) + { + target.push_back(lo_target); + } + else + { + throw dlib::error("invalid input labels to the learn_platt_scaling() function."); + } + } + + // Now find the maximum likelihood parameters of the sigmoid. + + prob_impl::objective > obj(scores, target); + prob_impl::der > obj_der(scores, target); + prob_impl::hessian > obj_hessian(scores, target); + + matrix val; + val = 0; + find_min(newton_search_strategy(obj_hessian), + objective_delta_stop_strategy(), + obj, + obj_der, + val, + 0); + + const double A = val(0); + const double B = val(1); + + return std::make_pair(A,B); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename sample_vector_type, + typename label_vector_type + > + const probabilistic_function + train_probabilistic_decision_function ( + const trainer_type& trainer, + const sample_vector_type& x, + const label_vector_type& y, + const long folds + ) + { + typedef typename sample_vector_type::value_type sample_type; + typedef typename label_vector_type::value_type scalar_type; + + /* + This function fits a sigmoid function to the output of the + svm trained by svm_nu_trainer or a similar trainer. The + technique used is the one described in the papers: + + Probabilistic Outputs for Support Vector Machines and + Comparisons to Regularized Likelihood Methods by + John C. Platt. March 26, 1999 + + A Note on Platt's Probabilistic Outputs for Support Vector Machines + by Hsuan-Tien Lin, Chih-Jen Lin, and Ruby C. Weng + */ + + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true && + 1 < folds && folds <= (long)x.size(), + "\tprobabilistic_decision_function train_probabilistic_decision_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t x.size(): " << x.size() + << "\n\t y.size(): " << y.size() + << "\n\t folds: " << folds + << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y) + ); + + // count the number of positive and negative examples + const long num_pos = (long)sum(mat(y) > 0); + const long num_neg = (long)sum(mat(y) < 0); + + // figure out how many positive and negative examples we will have in each fold + const long num_pos_test_samples = num_pos/folds; + const long num_pos_train_samples = num_pos - num_pos_test_samples; + const long num_neg_test_samples = num_neg/folds; + const long num_neg_train_samples = num_neg - num_neg_test_samples; + + typename trainer_type::trained_function_type d; + std::vector x_test, x_train; + std::vector y_test, y_train; + x_test.resize (num_pos_test_samples + num_neg_test_samples); + y_test.resize (num_pos_test_samples + num_neg_test_samples); + x_train.resize(num_pos_train_samples + num_neg_train_samples); + y_train.resize(num_pos_train_samples + num_neg_train_samples); + + std::vector out, out_label; + + long pos_idx = 0; + long neg_idx = 0; + + for (long i = 0; i < folds; ++i) + { + long cur = 0; + + // load up our positive test samples + while (cur < num_pos_test_samples) + { + if (y[pos_idx] == +1.0) + { + x_test[cur] = x[pos_idx]; + y_test[cur] = +1.0; + ++cur; + } + pos_idx = (pos_idx+1)%x.size(); + } + + // load up our negative test samples + while (cur < (long)x_test.size()) + { + if (y[neg_idx] == -1.0) + { + x_test[cur] = x[neg_idx]; + y_test[cur] = -1.0; + ++cur; + } + neg_idx = (neg_idx+1)%x.size(); + } + + // load the training data from the data following whatever we loaded + // as the testing data + long train_pos_idx = pos_idx; + long train_neg_idx = neg_idx; + cur = 0; + + // load up our positive train samples + while (cur < num_pos_train_samples) + { + if (y[train_pos_idx] == +1.0) + { + x_train[cur] = x[train_pos_idx]; + y_train[cur] = +1.0; + ++cur; + } + train_pos_idx = (train_pos_idx+1)%x.size(); + } + + // load up our negative train samples + while (cur < (long)x_train.size()) + { + if (y[train_neg_idx] == -1.0) + { + x_train[cur] = x[train_neg_idx]; + y_train[cur] = -1.0; + ++cur; + } + train_neg_idx = (train_neg_idx+1)%x.size(); + } + + // do the training + d = trainer.train (x_train,y_train); + + // now test this fold + for (unsigned long i = 0; i < x_test.size(); ++i) + { + out.push_back(d(x_test[i])); + out_label.push_back(y_test[i]); + } + + } // for (long i = 0; i < folds; ++i) + + std::pair params = learn_platt_scaling(out, out_label); + + const double A = params.first; + const double B = params.second; + + return probabilistic_function( A, B, trainer.train(x,y) ); + } + +// ---------------------------------------------------------------------------------------- + + template + struct trainer_adapter_probabilistic + { + typedef probabilistic_function trained_function_type; + + const trainer_type trainer; + const long folds; + + trainer_adapter_probabilistic ( + const trainer_type& trainer_, + const long folds_ + ) : trainer(trainer_),folds(folds_) {} + + template < + typename T, + typename U + > + const trained_function_type train ( + const T& samples, + const U& labels + ) const + { + return train_probabilistic_decision_function(trainer, samples, labels, folds); + } + + }; + + template < + typename trainer_type + > + trainer_adapter_probabilistic probabilistic ( + const trainer_type& trainer, + const long folds + ) + { + return trainer_adapter_probabilistic(trainer,folds); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V, + typename rand_type + > + typename enable_if,void>::type randomize_samples ( + T& t, + U& u, + V& v, + rand_type& r + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(t) && is_vector(u) && is_vector(v) && u.size() == t.size() && + u.size() == v.size(), + "\t randomize_samples(t,u,v)" + << "\n\t invalid inputs were given to this function" + << "\n\t t.size(): " << t.size() + << "\n\t u.size(): " << u.size() + << "\n\t v.size(): " << v.size() + << "\n\t is_vector(t): " << is_vector(t) + << "\n\t is_vector(u): " << is_vector(u) + << "\n\t is_vector(v): " << is_vector(v) + ); + + long n = t.size()-1; + while (n > 0) + { + // pick a random index to swap into t[n] + const unsigned long idx = r.get_random_32bit_number()%(n+1); + + // swap our randomly selected index into the n position + exchange(t(idx), t(n)); + exchange(u(idx), u(n)); + exchange(v(idx), v(n)); + + --n; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V, + typename rand_type + > + typename disable_if,void>::type randomize_samples ( + T& t, + U& u, + V& v, + rand_type& r + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(u.size() == t.size() && u.size() == v.size(), + "\t randomize_samples(t,u,v)" + << "\n\t invalid inputs were given to this function" + << "\n\t t.size(): " << t.size() + << "\n\t u.size(): " << u.size() + << "\n\t v.size(): " << v.size() + ); + + long n = t.size()-1; + while (n > 0) + { + // pick a random index to swap into t[n] + const unsigned long idx = r.get_random_32bit_number()%(n+1); + + // swap our randomly selected index into the n position + exchange(t[idx], t[n]); + exchange(u[idx], u[n]); + exchange(v[idx], v[n]); + + --n; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V + > + typename disable_if,void>::type randomize_samples ( + T& t, + U& u, + V& v + ) + { + rand r; + randomize_samples(t,u,v,r); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename rand_type + > + typename enable_if_c::value && is_rand::value,void>::type randomize_samples ( + T& t, + U& u, + rand_type& r + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(t) && is_vector(u) && u.size() == t.size(), + "\t randomize_samples(t,u)" + << "\n\t invalid inputs were given to this function" + << "\n\t t.size(): " << t.size() + << "\n\t u.size(): " << u.size() + << "\n\t is_vector(t): " << (is_vector(t)? "true" : "false") + << "\n\t is_vector(u): " << (is_vector(u)? "true" : "false") + ); + + long n = t.size()-1; + while (n > 0) + { + // pick a random index to swap into t[n] + const unsigned long idx = r.get_random_32bit_number()%(n+1); + + // swap our randomly selected index into the n position + exchange(t(idx), t(n)); + exchange(u(idx), u(n)); + + --n; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename rand_type + > + typename disable_if_c::value || !is_rand::value,void>::type randomize_samples ( + T& t, + U& u, + rand_type& r + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(u.size() == t.size(), + "\t randomize_samples(t,u)" + << "\n\t invalid inputs were given to this function" + << "\n\t t.size(): " << t.size() + << "\n\t u.size(): " << u.size() + ); + + long n = t.size()-1; + while (n > 0) + { + // pick a random index to swap into t[n] + const unsigned long idx = r.get_random_32bit_number()%(n+1); + + // swap our randomly selected index into the n position + exchange(t[idx], t[n]); + exchange(u[idx], u[n]); + + --n; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + typename disable_if,void>::type randomize_samples ( + T& t, + U& u + ) + { + rand r; + randomize_samples(t,u,r); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename rand_type + > + typename enable_if_c::value && is_rand::value,void>::type randomize_samples ( + T& t, + rand_type& r + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(t), + "\t randomize_samples(t)" + << "\n\t invalid inputs were given to this function" + << "\n\t is_vector(t): " << (is_vector(t)? "true" : "false") + ); + + long n = t.size()-1; + while (n > 0) + { + // pick a random index to swap into t[n] + const unsigned long idx = r.get_random_32bit_number()%(n+1); + + // swap our randomly selected index into the n position + exchange(t(idx), t(n)); + + --n; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename rand_type + > + typename disable_if_c<(is_matrix::value==true)||(is_rand::value==false),void>::type randomize_samples ( + T& t, + rand_type& r + ) + { + long n = t.size()-1; + while (n > 0) + { + // pick a random index to swap into t[n] + const unsigned long idx = r.get_random_32bit_number()%(n+1); + + // swap our randomly selected index into the n position + exchange(t[idx], t[n]); + + --n; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void randomize_samples ( + T& t + ) + { + rand r; + randomize_samples(t,r); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_ + diff --git a/ml/dlib/dlib/svm/svm_abstract.h b/ml/dlib/dlib/svm/svm_abstract.h new file mode 100644 index 000000000..ec92cf55b --- /dev/null +++ b/ml/dlib/dlib/svm/svm_abstract.h @@ -0,0 +1,604 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_ABSTRACT_ +#ifdef DLIB_SVm_ABSTRACT_ + +#include +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "svm_nu_trainer_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + bool is_learning_problem ( + const T& x, + const U& x_labels + ); + /*! + requires + - T == a matrix or something convertible to a matrix via mat() + - U == a matrix or something convertible to a matrix via mat() + ensures + - returns true if all of the following are true and false otherwise: + - is_col_vector(x) == true + - is_col_vector(x_labels) == true + - x.size() == x_labels.size() + - x.size() > 0 + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + bool is_binary_classification_problem ( + const T& x, + const U& x_labels + ); + /*! + requires + - T == a matrix or something convertible to a matrix via mat() + - U == a matrix or something convertible to a matrix via mat() + ensures + - returns true if all of the following are true and false otherwise: + - is_learning_problem(x, x_labels) == true + - x.size() > 1 + - there exists at least one sample from both the +1 and -1 classes. + (i.e. all samples can't have the same label) + - for all valid i: + - x_labels(i) == -1 or +1 + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sequence_type + > + bool is_sequence_labeling_problem ( + const std::vector& samples, + const std::vector >& labels + ); + /*! + ensures + - returns true if all of the following are true and false otherwise: + - is_learning_problem(samples, labels) == true + - for all valid i: + - samples[i].size() == labels[i].size() + (i.e. The size of a label sequence need to match the size of + its corresponding sample sequence) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename sequence_type + > + bool is_sequence_segmentation_problem ( + const std::vector& samples, + const std::vector > >& segments + ); + /*! + ensures + - Note that a sequence segmentation problem is a task where you are given a + sequence of objects (e.g. words in a sentence) and your task is to find + certain types of sub-sequences (e.g. proper names). + - returns true if all of the following are true and false otherwise: + - is_learning_problem(samples, segments) == true + - for all valid i and j: + - We interpret segments[i][j] as defining a half open range starting + with segments[i][j].first and ending just before segments[i][j].second. + - segments[i][j].first < segments[i][j].second + - segments[i][j].second <= samples[i].size() + (i.e. Each segment must be contained within its associated sequence) + - segments[i][j] does not overlap with any of the other ranges in + segments[i]. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename lhs_type, + typename rhs_type + > + bool is_assignment_problem ( + const std::vector, std::vector > >& samples, + const std::vector >& labels + ); + /*! + ensures + - Note that an assignment problem is a task to associate each element of samples[i].first + to an element of samples[i].second, or to indicate that the element doesn't associate + with anything. Therefore, labels[i] should contain the association information for + samples[i]. + - This function returns true if all of the following are true and false otherwise: + - is_learning_problem(samples, labels) == true + - for all valid i: + - samples[i].first.size() == labels[i].size() + - for all valid j: + -1 <= labels[i][j] < samples[i].second.size() + (A value of -1 indicates that samples[i].first[j] isn't associated with anything. + All other values indicate the associating element of samples[i].second) + - All elements of labels[i] which are not equal to -1 are unique. That is, + multiple elements of samples[i].first can't associate to the same element + in samples[i].second. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename lhs_type, + typename rhs_type + > + bool is_forced_assignment_problem ( + const std::vector, std::vector > >& samples, + const std::vector >& labels + ); + /*! + ensures + - A regular assignment problem is allowed to indicate that all elements of + samples[i].first don't associate to anything. However, a forced assignment + problem is required to always associate an element of samples[i].first to + something in samples[i].second if there is an element of samples[i].second + that hasn't already been associated to something. + - This function returns true if all of the following are true and false otherwise: + - is_assignment_problem(samples, labels) == true + - for all valid i: + - let N denote the number of elements in labels[i] that are not equal to -1. + - min(samples[i].first.size(), samples[i].second.size()) == N + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename detection_type_, + typename label_type_ = long + > + struct labeled_detection + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple object, like std::pair, it just holds two objects. It + serves the same purpose as std::pair except that it has informative names + describing its two members and is intended for use with track association + problems. + !*/ + + typedef detection_type_ detection_type; + typedef label_type_ label_type; + + detection_type det; + label_type label; + }; + + template < + typename detection_type_, + typename label_type_ + > + void serialize (const labeled_detection& item, std::ostream& out); + /*! + provides serialization support + !*/ + + template < + typename detection_type_, + typename label_type_ + > + void deserialize (labeled_detection& item, std::istream& in); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename detection_type, + typename label_type + > + bool is_track_association_problem ( + const std::vector > >& samples + ); + /*! + ensures + - In this tracking model you get a set of detections at each time step and are + expected to associate each detection with a track or have it spawn a new + track. Therefore, a track association problem is a machine learning problem + where you are given a dataset of example input detections and are expected to + learn to perform the proper detection to track association. + - This function checks if samples can form a valid dataset for this machine + learning problem and returns true if this is the case. This means we should + interpret samples in the following way: + - samples is a track history and for each valid i: + - samples[i] is a set of labeled detections from the i-th time step. + Each detection has been labeled with its "true object identity". + That is, all the detection throughout the history with the same + label_type value are detections from the same object and therefore + should be associated to the same track. + Putting this all together, samples is a valid track association learning + problem if and only if the following are all true: + - samples.size() > 0 + - There are at least two values, i and j such that: + - i != j + - samples[i].size() > 0 + - samples[j].size() > 0 + Or in other words, there needs to be some detections in samples somewhere + or it is impossible to learn anything. + - for all valid i: + - for all valid j and k where j!=k: + - samples[i][j].label != samples[i][k].label + (i.e. the label_type values must be unique within each time step. + Or in other words, you can't have two detections on the same + object in a single time step.) + !*/ + + template < + typename detection_type, + typename label_type + > + bool is_track_association_problem ( + const std::vector > > >& samples + ); + /*! + ensures + - returns true if is_track_association_problem(samples[i]) == true for all + valid i and false otherwise. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + double platt_scale ( + const std::pair& params, + const double score + ); + /*! + ensures + - returns 1/(1 + std::exp(params.first*score + params.second)) + !*/ + +// ---------------------------------------------------------------------------------------- + + template + std::pair learn_platt_scaling ( + const std::vector& scores, + const std::vector& labels + ); + /*! + requires + - T should be either float, double, or long double + - is_binary_classification_problem(scores,labels) == true + ensures + - This function learns to map scalar values into well calibrated probabilities + using Platt scaling. In particular, it returns a params object such that, + for all valid i: + - platt_scale(params,scores[i]) == the scaled version of the scalar value + scores[i]. That is, the output is a number between 0 and 1. In + particular, platt_scale(params,scores[i]) is meant to represent the + probability that labels[i] == +1. + - This function is an implementation of the algorithm described in the following + papers: + Probabilistic Outputs for Support Vector Machines and Comparisons to + Regularized Likelihood Methods by John C. Platt. March 26, 1999 + + A Note on Platt's Probabilistic Outputs for Support Vector Machines + by Hsuan-Tien Lin, Chih-Jen Lin, and Ruby C. Weng + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename sample_vector_type, + typename label_vector_type + > + const probabilistic_function + train_probabilistic_decision_function ( + const trainer_type& trainer, + const sample_vector_type& x, + const label_vector_type& y, + const long folds + ); + /*! + requires + - 1 < folds <= x.size() + - is_binary_classification_problem(x,y) == true + - x and y must be std::vector objects or types with a compatible interface. + - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer) + ensures + - trains a classifier given the training samples in x and labels in y. + - returns a probabilistic_decision_function that represents the trained classifier. + - The parameters of the probability model are estimated by performing k-fold + cross validation. + - The number of folds used is given by the folds argument. + - This function is implemented using learn_platt_scaling() + throws + - any exceptions thrown by trainer.train() + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + trainer_adapter_probabilistic probabilistic ( + const trainer_type& trainer, + const long folds + ); + /*! + requires + - 1 < folds <= x.size() + - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer) + ensures + - returns a trainer adapter TA such that calling TA.train(samples, labels) + returns the same object as calling train_probabilistic_decision_function(trainer,samples,labels,folds). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Miscellaneous functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix cross_validate_trainer ( + const trainer_type& trainer, + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + const long folds + ); + /*! + requires + - is_binary_classification_problem(x,y) == true + - 1 < folds <= std::min(sum(y>0),sum(y<0)) + (e.g. There must be at least as many examples of each class as there are folds) + - trainer_type == some kind of binary classification trainer object (e.g. svm_nu_trainer) + ensures + - performs k-fold cross validation by using the given trainer to solve the + given binary classification problem for the given number of folds. + Each fold is tested using the output of the trainer and the average + classification accuracy from all folds is returned. + - The average accuracy is computed by running test_binary_decision_function() + on each fold and its output is averaged and returned. + - The number of folds used is given by the folds argument. + throws + - any exceptions thrown by trainer.train() + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename dec_funct_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix test_binary_decision_function ( + const dec_funct_type& dec_funct, + const in_sample_vector_type& x_test, + const in_scalar_vector_type& y_test + ); + /*! + requires + - is_binary_classification_problem(x_test,y_test) == true + - dec_funct_type == some kind of decision function object (e.g. decision_function) + ensures + - Tests the given decision function by calling it on the x_test and y_test samples. + The output of dec_funct is interpreted as a prediction for the +1 class + if its output is >= 0 and as a prediction for the -1 class otherwise. + - The test accuracy is returned in a row vector, let us call it R. Both + quantities in R are numbers between 0 and 1 which represent the fraction + of examples correctly classified. R(0) is the fraction of +1 examples + correctly classified and R(1) is the fraction of -1 examples correctly + classified. + throws + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U + > + void randomize_samples ( + T& samples, + U& labels + ); + /*! + requires + - T == a matrix object or an object compatible with std::vector that contains + a swappable type. + - U == a matrix object or an object compatible with std::vector that contains + a swappable type. + - if samples or labels are matrix objects then is_vector(samples) == true and + is_vector(labels) == true + - samples.size() == labels.size() + ensures + - randomizes the order of the samples and labels but preserves + the pairing between each sample and its label + - A default initialized random number generator is used to perform the randomizing. + Note that this means that each call this this function does the same thing. + That is, the random number generator always uses the same seed. + - for all valid i: + - let r == the random index samples(i) was moved to. then: + - #labels(r) == labels(i) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename rand_type + > + void randomize_samples ( + T& samples, + U& labels, + rand_type& rnd + ); + /*! + requires + - T == a matrix object or an object compatible with std::vector that contains + a swappable type. + - U == a matrix object or an object compatible with std::vector that contains + a swappable type. + - if samples or labels are matrix objects then is_vector(samples) == true and + is_vector(labels) == true + - samples.size() == labels.size() + - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface + ensures + - randomizes the order of the samples and labels but preserves + the pairing between each sample and its label + - the given rnd random number generator object is used to do the randomizing + - for all valid i: + - let r == the random index samples(i) was moved to. then: + - #labels(r) == labels(i) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void randomize_samples ( + T& samples + ); + /*! + requires + - T == a matrix object or an object compatible with std::vector that contains + a swappable type. + - if (samples is a matrix) then + - is_vector(samples) == true + ensures + - randomizes the order of the elements inside samples + - A default initialized random number generator is used to perform the randomizing. + Note that this means that each call this this function does the same thing. + That is, the random number generator always uses the same seed. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename rand_type + > + void randomize_samples ( + T& samples, + rand_type& rnd + ); + /*! + requires + - T == a matrix object or an object compatible with std::vector that contains + a swappable type. + - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface + - if (samples is a matrix) then + - is_vector(samples) == true + ensures + - randomizes the order of the elements inside samples + - the given rnd random number generator object is used to do the randomizing + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V + > + void randomize_samples ( + T& samples, + U& labels, + V& auxiliary + ); + /*! + requires + - T == a matrix object or an object compatible with std::vector that contains + a swappable type. + - U == a matrix object or an object compatible with std::vector that contains + a swappable type. + - V == a matrix object or an object compatible with std::vector that contains + a swappable type. + - if (samples, labels, or auxiliary are matrix objects) then + - is_vector(samples) == true + - is_vector(labels) == true + - is_vector(auxiliary) == true + - samples.size() == labels.size() == auxiliary.size() + ensures + - randomizes the order of the samples, labels, and auxiliary but preserves the + pairing between each sample, its label, and its auxiliary value. + - A default initialized random number generator is used to perform the + randomizing. Note that this means that each call this this function does the + same thing. That is, the random number generator always uses the same seed. + - for all valid i: + - let r == the random index samples(i) was moved to. then: + - #labels(r) == labels(i) + - #auxiliary(r) == auxiliary(i) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename U, + typename V, + typename rand_type + > + void randomize_samples ( + T& samples, + U& labels, + V& auxiliary, + rand_type& rnd + ); + /*! + requires + - T == a matrix object or an object compatible with std::vector that contains + a swappable type. + - U == a matrix object or an object compatible with std::vector that contains + a swappable type. + - V == a matrix object or an object compatible with std::vector that contains + a swappable type. + - if (samples, labels, or auxiliary are matrix objects) then + - is_vector(samples) == true + - is_vector(labels) == true + - is_vector(auxiliary) == true + - samples.size() == labels.size() == auxiliary.size() + - rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface + ensures + - randomizes the order of the samples, labels, and auxiliary but preserves the + pairing between each sample, its label, and its auxiliary value. + - the given rnd random number generator object is used to do the randomizing + - for all valid i: + - let r == the random index samples(i) was moved to. then: + - #labels(r) == labels(i) + - #auxiliary(r) == auxiliary(i) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_ABSTRACT_ + + diff --git a/ml/dlib/dlib/svm/svm_c_ekm_trainer.h b/ml/dlib/dlib/svm/svm_c_ekm_trainer.h new file mode 100644 index 000000000..735e0f22e --- /dev/null +++ b/ml/dlib/dlib/svm/svm_c_ekm_trainer.h @@ -0,0 +1,636 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVM_C_EKm_TRAINER_Hh_ +#define DLIB_SVM_C_EKm_TRAINER_Hh_ + +#include "../algs.h" +#include "function.h" +#include "kernel.h" +#include "empirical_kernel_map.h" +#include "svm_c_linear_trainer.h" +#include "svm_c_ekm_trainer_abstract.h" +#include "../statistics.h" +#include "../rand.h" +#include + +namespace dlib +{ + template < + typename K + > + class svm_c_ekm_trainer + { + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_c_ekm_trainer ( + ) + { + verbose = false; + ekm_stale = true; + + initial_basis_size = 10; + basis_size_increment = 50; + max_basis_size = 300; + } + + explicit svm_c_ekm_trainer ( + const scalar_type& C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t svm_c_ekm_trainer::svm_c_ekm_trainer()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + + ocas.set_c(C); + verbose = false; + ekm_stale = true; + + initial_basis_size = 10; + basis_size_increment = 50; + max_basis_size = 300; + } + + void set_epsilon ( + scalar_type eps + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps > 0, + "\t void svm_c_ekm_trainer::set_epsilon()" + << "\n\t eps must be greater than 0" + << "\n\t eps: " << eps + << "\n\t this: " << this + ); + + ocas.set_epsilon(eps); + } + + const scalar_type get_epsilon ( + ) const + { + return ocas.get_epsilon(); + } + + void set_max_iterations ( + unsigned long max_iter + ) + { + ocas.set_max_iterations(max_iter); + } + + unsigned long get_max_iterations ( + ) + { + return ocas.get_max_iterations(); + } + + void be_verbose ( + ) + { + verbose = true; + ocas.be_quiet(); + } + + void be_very_verbose ( + ) + { + verbose = true; + ocas.be_verbose(); + } + + void be_quiet ( + ) + { + verbose = false; + ocas.be_quiet(); + } + + void set_oca ( + const oca& item + ) + { + ocas.set_oca(item); + } + + const oca get_oca ( + ) const + { + return ocas.get_oca(); + } + + const kernel_type get_kernel ( + ) const + { + return kern; + } + + void set_kernel ( + const kernel_type& k + ) + { + kern = k; + ekm_stale = true; + } + + template + void set_basis ( + const T& basis_samples + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(basis_samples.size() > 0 && is_vector(mat(basis_samples)), + "\tvoid svm_c_ekm_trainer::set_basis(basis_samples)" + << "\n\t You have to give a non-empty set of basis_samples and it must be a vector" + << "\n\t basis_samples.size(): " << basis_samples.size() + << "\n\t is_vector(mat(basis_samples)): " << is_vector(mat(basis_samples)) + << "\n\t this: " << this + ); + + basis = mat(basis_samples); + ekm_stale = true; + } + + bool basis_loaded( + ) const + { + return (basis.size() != 0); + } + + void clear_basis ( + ) + { + basis.set_size(0); + ekm.clear(); + ekm_stale = true; + } + + unsigned long get_max_basis_size ( + ) const + { + return max_basis_size; + } + + void set_max_basis_size ( + unsigned long max_basis_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_basis_size_ > 0, + "\t void svm_c_ekm_trainer::set_max_basis_size()" + << "\n\t max_basis_size_ must be greater than 0" + << "\n\t max_basis_size_: " << max_basis_size_ + << "\n\t this: " << this + ); + + max_basis_size = max_basis_size_; + if (initial_basis_size > max_basis_size) + initial_basis_size = max_basis_size; + } + + unsigned long get_initial_basis_size ( + ) const + { + return initial_basis_size; + } + + void set_initial_basis_size ( + unsigned long initial_basis_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(initial_basis_size_ > 0, + "\t void svm_c_ekm_trainer::set_initial_basis_size()" + << "\n\t initial_basis_size_ must be greater than 0" + << "\n\t initial_basis_size_: " << initial_basis_size_ + << "\n\t this: " << this + ); + + initial_basis_size = initial_basis_size_; + + if (initial_basis_size > max_basis_size) + max_basis_size = initial_basis_size; + } + + unsigned long get_basis_size_increment ( + ) const + { + return basis_size_increment; + } + + void set_basis_size_increment ( + unsigned long basis_size_increment_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(basis_size_increment_ > 0, + "\t void svm_c_ekm_trainer::set_basis_size_increment()" + << "\n\t basis_size_increment_ must be greater than 0" + << "\n\t basis_size_increment_: " << basis_size_increment_ + << "\n\t this: " << this + ); + + basis_size_increment = basis_size_increment_; + } + + void set_c ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_ekm_trainer::set_c()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + ocas.set_c(C); + } + + const scalar_type get_c_class1 ( + ) const + { + return ocas.get_c_class1(); + } + + const scalar_type get_c_class2 ( + ) const + { + return ocas.get_c_class2(); + } + + void set_c_class1 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_ekm_trainer::set_c_class1()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + ocas.set_c_class1(C); + } + + void set_c_class2 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_ekm_trainer::set_c_class2()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + ocas.set_c_class2(C); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + scalar_type obj; + if (basis_loaded()) + return do_train_user_basis(mat(x),mat(y),obj); + else + return do_train_auto_basis(mat(x),mat(y),obj); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const + { + if (basis_loaded()) + return do_train_user_basis(mat(x),mat(y),svm_objective); + else + return do_train_auto_basis(mat(x),mat(y),svm_objective); + } + + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train_user_basis ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const + /*! + requires + - basis_loaded() == true + ensures + - trains an SVM with the user supplied basis + !*/ + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true, + "\t decision_function svm_c_ekm_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y) + ); + + if (ekm_stale) + { + ekm.load(kern, basis); + ekm_stale = false; + } + + // project all the samples with the ekm + running_stats rs; + std::vector > proj_samples; + proj_samples.reserve(x.size()); + for (long i = 0; i < x.size(); ++i) + { + if (verbose) + { + scalar_type err; + proj_samples.push_back(ekm.project(x(i), err)); + rs.add(err); + } + else + { + proj_samples.push_back(ekm.project(x(i))); + } + } + + if (verbose) + { + std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl; + std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl; + } + + // now do the training + decision_function > > df; + df = ocas.train(proj_samples, y, svm_objective); + + if (verbose) + { + std::cout << "Final svm objective: " << svm_objective << std::endl; + } + + decision_function final_df; + final_df = ekm.convert_to_decision_function(df.basis_vectors(0)); + final_df.b = df.b; + return final_df; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train_auto_basis ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true, + "\t decision_function svm_c_ekm_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y) + ); + + + std::vector > proj_samples(x.size()); + decision_function > > df; + + // we will use a linearly_independent_subset_finder to store our basis set. + linearly_independent_subset_finder lisf(get_kernel(), max_basis_size); + + dlib::rand rnd; + + // first pick the initial basis set randomly + for (unsigned long i = 0; i < 10*initial_basis_size && lisf.size() < initial_basis_size; ++i) + { + lisf.add(x(rnd.get_random_32bit_number()%x.size())); + } + + ekm.load(lisf); + + // first project all samples into the span of the current basis + for (long i = 0; i < x.size(); ++i) + { + proj_samples[i] = ekm.project(x(i)); + } + + + svm_c_linear_trainer > > trainer(ocas); + + const scalar_type min_epsilon = trainer.get_epsilon(); + // while we are determining what the basis set will be we are going to use a very + // lose stopping condition. We will tighten it back up before producing the + // final decision_function. + trainer.set_epsilon(0.2); + + scalar_type prev_svm_objective = std::numeric_limits::max(); + + empirical_kernel_map prev_ekm; + + // This loop is where we try to generate a basis for SVM training. We will + // do this by repeatedly training the SVM and adding a few points which violate the + // margin to the basis in each iteration. + while (true) + { + // if the basis is already as big as it's going to get then just do the most + // accurate training right now. + if (lisf.size() == max_basis_size) + trainer.set_epsilon(min_epsilon); + + while (true) + { + // now do the training. + df = trainer.train(proj_samples, y, svm_objective); + + if (svm_objective < prev_svm_objective) + break; + + // If the training didn't reduce the objective more than last time then + // try lowering the epsilon and doing it again. + if (trainer.get_epsilon() > min_epsilon) + { + trainer.set_epsilon(std::max(trainer.get_epsilon()*0.5, min_epsilon)); + if (verbose) + std::cout << " *** Reducing epsilon to " << trainer.get_epsilon() << std::endl; + } + else + break; + } + + if (verbose) + { + std::cout << "svm objective: " << svm_objective << std::endl; + std::cout << "basis size: " << lisf.size() << std::endl; + } + + // if we failed to make progress on this iteration then we are done + if (svm_objective >= prev_svm_objective) + break; + + prev_svm_objective = svm_objective; + + // now add more elements to the basis + unsigned long count = 0; + for (unsigned long j = 0; + (j < 100*basis_size_increment) && (count < basis_size_increment) && (lisf.size() < max_basis_size); + ++j) + { + // pick a random sample + const unsigned long idx = rnd.get_random_32bit_number()%x.size(); + // If it is a margin violator then it is useful to add it into the basis set. + if (df(proj_samples[idx])*y(idx) < 1) + { + // Add the sample into the basis set if it is linearly independent of all the + // vectors already in the basis set. + if (lisf.add(x(idx))) + { + ++count; + } + } + } + // if we couldn't add any more basis vectors then stop + if (count == 0) + { + if (verbose) + std::cout << "Stopping, couldn't add more basis vectors." << std::endl; + break; + } + + + // Project all the samples into the span of our newly enlarged basis. We will do this + // using the special transformation in the EKM that lets us project from a smaller + // basis set to a larger without needing to reevaluate kernel functions we have already + // computed. + ekm.swap(prev_ekm); + ekm.load(lisf); + projection_function proj_part; + matrix prev_to_new; + prev_ekm.get_transformation_to(ekm, prev_to_new, proj_part); + + + matrix temp; + for (long i = 0; i < x.size(); ++i) + { + // assign to temporary to avoid memory allocation that would result if we + // assigned this expression straight into proj_samples[i] + temp = prev_to_new*proj_samples[i] + proj_part(x(i)); + proj_samples[i] = temp; + + } + } + + // Reproject all the data samples using the final basis. We could just use what we + // already have but the recursive thing done above to compute the proj_samples + // might have accumulated a little numerical error. So lets just be safe. + running_stats rs, rs_margin; + for (long i = 0; i < x.size(); ++i) + { + if (verbose) + { + scalar_type err; + proj_samples[i] = ekm.project(x(i),err); + rs.add(err); + // if this point is within the margin + if (df(proj_samples[i])*y(i) < 1) + rs_margin.add(err); + } + else + { + proj_samples[i] = ekm.project(x(i)); + } + } + + // do the final training + trainer.set_epsilon(min_epsilon); + df = trainer.train(proj_samples, y, svm_objective); + + + if (verbose) + { + std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl; + std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl; + std::cout << "Mean EKM projection error for margin violators: " << rs_margin.mean() << std::endl; + std::cout << "Standard deviation of EKM projection error for margin violators: " << ((rs_margin.current_n()>1)?rs_margin.stddev():0) << std::endl; + + std::cout << "Final svm objective: " << svm_objective << std::endl; + } + + + decision_function final_df; + final_df = ekm.convert_to_decision_function(df.basis_vectors(0)); + final_df.b = df.b; + + // we don't need the ekm anymore so clear it out + ekm.clear(); + + return final_df; + } + + + + + /*! + CONVENTION + - if (ekm_stale) then + - kern or basis have changed since the last time + they were loaded into the ekm + !*/ + + svm_c_linear_trainer > > ocas; + bool verbose; + + kernel_type kern; + unsigned long max_basis_size; + unsigned long basis_size_increment; + unsigned long initial_basis_size; + + + matrix basis; + mutable empirical_kernel_map ekm; + mutable bool ekm_stale; + + }; + +} + +#endif // DLIB_SVM_C_EKm_TRAINER_Hh_ + + + diff --git a/ml/dlib/dlib/svm/svm_c_ekm_trainer_abstract.h b/ml/dlib/dlib/svm/svm_c_ekm_trainer_abstract.h new file mode 100644 index 000000000..d1ba2bf5f --- /dev/null +++ b/ml/dlib/dlib/svm/svm_c_ekm_trainer_abstract.h @@ -0,0 +1,384 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVM_C_EKm_TRAINER_ABSTRACT_Hh_ +#ifdef DLIB_SVM_C_EKm_TRAINER_ABSTRACT_Hh_ + +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "empirical_kernel_map_abstract.h" +#include "svm_c_linear_trainer_abstract.h" + +namespace dlib +{ + template < + typename K + > + class svm_c_ekm_trainer + { + /*! + REQUIREMENTS ON K + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object represents a tool for training the C formulation of + a support vector machine. It is implemented using the empirical_kernel_map + to kernelize the svm_c_linear_trainer. This makes it a very fast algorithm + capable of learning from very large datasets. + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_c_ekm_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_oca() == oca() (i.e. an instance of oca with default parameters) + - #get_c_class1() == 1 + - #get_c_class2() == 1 + - #get_epsilon() == 0.001 + - #basis_loaded() == false + - #get_initial_basis_size() == 10 + - #get_basis_size_increment() == 50 + - #get_max_basis_size() == 300 + - this object will not be verbose unless be_verbose() is called + - #get_max_iterations() == 10000 + !*/ + + explicit svm_c_ekm_trainer ( + const scalar_type& C + ); + /*! + requires + - C > 0 + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_oca() == oca() (i.e. an instance of oca with default parameters) + - #get_c_class1() == C + - #get_c_class2() == C + - #get_epsilon() == 0.001 + - #basis_loaded() == false + - #get_initial_basis_size() == 10 + - #get_basis_size_increment() == 50 + - #get_max_basis_size() == 300 + - this object will not be verbose unless be_verbose() is called + - #get_max_iterations() == 10000 + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer + to execute. + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + unsigned long get_max_iterations ( + ); + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_very_verbose ( + ); + /*! + ensures + - This object will print a lot of status messages to standard out so that a + user can observe the progress of the algorithm. In addition to the + few status messages normal verbosity produces this setting also causes + the underlying svm_c_linear_trainer to be verbose. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the SVM problem. + !*/ + + const kernel_type get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + template + void set_basis ( + const T& basis_samples + ); + /*! + requires + - T must be a dlib::matrix type or something convertible to a matrix via mat() + (e.g. a std::vector) + - is_vector(basis_samples) == true + - basis_samples.size() > 0 + - get_kernel() must be capable of operating on the elements of basis_samples. That is, + expressions such as get_kernel()(basis_samples(0), basis_samples(0)) should make sense. + ensures + - #basis_loaded() == true + - training will be carried out in the span of the given basis_samples + !*/ + + bool basis_loaded ( + ) const; + /*! + ensures + - returns true if this object has been loaded with user supplied basis vectors and false otherwise. + !*/ + + void clear_basis ( + ); + /*! + ensures + - #basis_loaded() == false + !*/ + + unsigned long get_max_basis_size ( + ) const; + /*! + ensures + - returns the maximum number of basis vectors this object is allowed + to use. This parameter only matters when the user has not supplied + a basis via set_basis(). + !*/ + + void set_max_basis_size ( + unsigned long max_basis_size + ); + /*! + requires + - max_basis_size > 0 + ensures + - #get_max_basis_size() == max_basis_size + - if (get_initial_basis_size() > max_basis_size) then + - #get_initial_basis_size() == max_basis_size + !*/ + + unsigned long get_initial_basis_size ( + ) const; + /*! + ensures + - If the user does not supply a basis via set_basis() then this object + will generate one automatically. It does this by starting with + a small basis of size N and repeatedly adds basis vectors to it + until a stopping condition is reached. This function returns that + initial size N. + !*/ + + void set_initial_basis_size ( + unsigned long initial_basis_size + ); + /*! + requires + - initial_basis_size > 0 + ensures + - #get_initial_basis_size() == initial_basis_size + - if (initial_basis_size > get_max_basis_size()) then + - #get_max_basis_size() == initial_basis_size + !*/ + + unsigned long get_basis_size_increment ( + ) const; + /*! + ensures + - If the user does not supply a basis via set_basis() then this object + will generate one automatically. It does this by starting with a small + basis and repeatedly adds sets of N basis vectors to it until a stopping + condition is reached. This function returns that increment size N. + !*/ + + void set_basis_size_increment ( + unsigned long basis_size_increment + ); + /*! + requires + - basis_size_increment > 0 + ensures + - #get_basis_size_increment() == basis_size_increment + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class1() == C + - #get_c_class2() == C + !*/ + + const scalar_type get_c_class1 ( + ) const; + /*! + ensures + - returns the SVM regularization parameter for the +1 class. + It is the parameter that determines the trade off between + trying to fit the +1 training data exactly or allowing more errors + but hopefully improving the generalization ability of the + resulting classifier. Larger values encourage exact fitting + while smaller values of C may encourage better generalization. + !*/ + + const scalar_type get_c_class2 ( + ) const; + /*! + ensures + - returns the SVM regularization parameter for the -1 class. + It is the parameter that determines the trade off between + trying to fit the -1 training data exactly or allowing more errors + but hopefully improving the generalization ability of the + resulting classifier. Larger values encourage exact fitting + while smaller values of C may encourage better generalization. + !*/ + + void set_c_class1 ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class1() == C + !*/ + + void set_c_class2 ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class2() == C + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - is_binary_classification_problem(x,y) == true + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - trains a C support vector classifier given the training samples in x and + labels in y. + - if (basis_loaded()) then + - training will be carried out in the span of the user supplied basis vectors + - else + - this object will attempt to automatically select an appropriate basis + + - returns a decision function F with the following properties: + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const; + /*! + requires + - is_binary_classification_problem(x,y) == true + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - trains a C support vector classifier given the training samples in x and + labels in y. + - if (basis_loaded()) then + - training will be carried out in the span of the user supplied basis vectors + - else + - this object will attempt to automatically select an appropriate basis + + - #svm_objective == the final value of the SVM objective function + - returns a decision function F with the following properties: + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + + }; + +} + +#endif // DLIB_SVM_C_EKm_TRAINER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer.h b/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer.h new file mode 100644 index 000000000..039b70993 --- /dev/null +++ b/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer.h @@ -0,0 +1,712 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_C_LINEAR_DCD_TRAINER_Hh_ +#define DLIB_SVm_C_LINEAR_DCD_TRAINER_Hh_ + +#include "svm_c_linear_dcd_trainer_abstract.h" +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "../rand.h" +#include "svm.h" + +#include "function.h" +#include "kernel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_c_linear_dcd_trainer + { + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + typedef typename decision_function::sample_vector_type sample_vector_type; + typedef typename decision_function::scalar_vector_type scalar_vector_type; + + // You are getting a compiler error on this line because you supplied a non-linear + // kernel to the svm_c_linear_dcd_trainer object. You have to use one of the + // linear kernels with this trainer. + COMPILE_TIME_ASSERT((is_same_type >::value || + is_same_type >::value )); + + svm_c_linear_dcd_trainer ( + ) : + Cpos(1), + Cneg(1), + eps(0.1), + max_iterations(10000), + verbose(false), + have_bias(true), + last_weight_1(false), + do_shrinking(true), + do_svm_l2(false) + { + } + + explicit svm_c_linear_dcd_trainer ( + const scalar_type& C_ + ) : + Cpos(C_), + Cneg(C_), + eps(0.1), + max_iterations(10000), + verbose(false), + have_bias(true), + last_weight_1(false), + do_shrinking(true), + do_svm_l2(false) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < C_, + "\tsvm_c_trainer::svm_c_linear_dcd_trainer(kernel,C)" + << "\n\t invalid inputs were given to this function" + << "\n\t C_: " << C_ + ); + } + + bool includes_bias ( + ) const + { + return have_bias; + } + + void include_bias ( + bool should_have_bias + ) + { + have_bias = should_have_bias; + } + + bool forces_last_weight_to_1 ( + ) const + { + return last_weight_1; + } + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ) + { + last_weight_1 = should_last_weight_be_1; + } + + bool shrinking_enabled ( + ) const { return do_shrinking; } + + void enable_shrinking ( + bool enabled + ) { do_shrinking = enabled; } + + bool solving_svm_l2_problem ( + ) const { return do_svm_l2; } + + void solve_svm_l2_problem ( + bool enabled + ) { do_svm_l2 = enabled; } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\tvoid svm_c_linear_dcd_trainer::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps_: " << eps_ + ); + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const + { + return eps; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel_type(); + } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void set_c ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_linear_dcd_trainer::set_c()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cpos = C; + Cneg = C; + } + + const scalar_type get_c_class1 ( + ) const + { + return Cpos; + } + + const scalar_type get_c_class2 ( + ) const + { + return Cneg; + } + + void set_c_class1 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_linear_dcd_trainer::set_c_class1()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cpos = C; + } + + void set_c_class2 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_linear_dcd_trainer::set_c_class2()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cneg = C; + } + + class optimizer_state + { + friend class svm_c_linear_dcd_trainer; + + public: + optimizer_state() : did_init(false) {} + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + void init( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + bool have_bias_, + bool last_weight_1_, + bool do_svm_l2_, + scalar_type Cpos, + scalar_type Cneg + ) + { + const long new_dims = max_index_plus_one(x); + long new_idx = 0; + + if (did_init) + { + DLIB_CASSERT(have_bias_ == have_bias && + last_weight_1_ == last_weight_1, + "\t decision_function svm_c_linear_dcd_trainer::train(x,y,state)" + << "\n\t The given state object is invalid because the previous trainer was configured differently." + << "\n\t have_bias_: " << have_bias_ + << "\n\t have_bias: " << have_bias + << "\n\t last_weight_1_: " << last_weight_1_ + << "\n\t last_weight_1: " << last_weight_1 + ); + + DLIB_CASSERT( new_dims >= dims, + "\t decision_function svm_c_linear_dcd_trainer::train(x,y,state)" + << "\n\t The given state object is invalid because the training data dimensions have shrunk." + << "\n\t new_dims: " << new_dims + << "\n\t dims: " << dims + ); + + DLIB_CASSERT( x.size() >= static_cast(alpha.size()), + "\t decision_function svm_c_linear_dcd_trainer::train(x,y,state)" + << "\n\t The given state object is invalid because the training data has fewer samples than previously." + << "\n\t x.size(): " << x.size() + << "\n\t alpha.size(): " << alpha.size() + ); + + // make sure we amortize the cost of growing the alpha vector. + if (alpha.capacity() < static_cast(x.size())) + alpha.reserve(x.size()*2); + + new_idx = alpha.size(); + + // Make sure alpha has the same length as x. So pad with extra zeros if + // necessary to make this happen. + alpha.resize(x.size(),0); + + + if (new_dims != dims) + { + // The only valid way the dimensions can be different here is if + // you are using a sparse vector type. This is because we might + // have had training samples which just happened to not include all + // the features previously. Therefore, max_index_plus_one() would + // have given too low of a result. But for dense vectors it is + // definitely a user error if the dimensions don't match. + + DLIB_CASSERT(is_matrix::value == false, + "\t decision_function svm_c_linear_dcd_trainer::train(x,y,state)" + << "\n\t The given state object is invalid because the training data dimensions have changed." + << "\n\t new_dims: " << new_dims + << "\n\t dims: " << dims + ); + + // extend w by the right number of elements + if (have_bias && !last_weight_1) + { + // Splice some zeros into the w vector so it will have the + // right length. Here we are being careful to move the bias + // weight to the end of the resulting vector. + w = join_cols(join_cols( + colm(w,0,dims), + zeros_matrix(new_dims-dims,1)), + uniform_matrix(1,1,w(dims)) + ); + } + else + { + // Just concatenate the right number of zeros. + w = join_cols(w, zeros_matrix(new_dims-dims,1)); + } + dims = new_dims; + } + + } + else + { + did_init = true; + have_bias = have_bias_; + last_weight_1 = last_weight_1_; + dims = new_dims; + + alpha.resize(x.size()); + + index.reserve(x.size()); + Q.reserve(x.size()); + + if (have_bias && !last_weight_1) + w.set_size(dims+1); + else + w.set_size(dims); + + w = 0; + } + + for (long i = new_idx; i < x.size(); ++i) + { + Q.push_back(length_squared(x(i))); + + if (have_bias && !last_weight_1) + { + index.push_back(i); + Q.back() += 1; + } + else if (Q.back() != 0) + { + index.push_back(i); + } + + if (do_svm_l2_) + { + if (y(i) > 0) + Q.back() += 1/(2*Cpos); + else + Q.back() += 1/(2*Cneg); + } + } + + if (last_weight_1) + w(dims-1) = 1; + } + + template + typename enable_if,scalar_type>::type length_squared (const T& x) const + { + if (!last_weight_1) + { + return dlib::dot(x,x); + } + else + { + // skip the last dimension + return dlib::dot(colm(x,0,x.size()-1), + colm(x,0,x.size()-1)); + } + + } + + template + typename disable_if,scalar_type>::type length_squared (const T& x) const + { + if (!last_weight_1) + { + return dlib::dot(x,x); + } + else + { + scalar_type temp = 0; + typename T::const_iterator i; + for (i = x.begin(); i != x.end(); ++i) + { + // skip the last dimension + if (static_cast(i->first) < dims-1) + temp += i->second*i->second; + } + return temp; + } + } + + + bool did_init; + bool have_bias; + bool last_weight_1; + std::vector alpha; + scalar_vector_type w; + std::vector Q; + std::vector index; + long dims; + dlib::rand rnd; + + public: + + const std::vector& get_alpha () const { return alpha; } + + friend void serialize(const optimizer_state& item, std::ostream& out) + { + const int version = 1; + dlib::serialize(version, out); + dlib::serialize(item.did_init, out); + dlib::serialize(item.have_bias, out); + dlib::serialize(item.last_weight_1, out); + dlib::serialize(item.alpha, out); + dlib::serialize(item.w, out); + dlib::serialize(item.Q, out); + dlib::serialize(item.index, out); + dlib::serialize(item.dims, out); + dlib::serialize(item.rnd, out); + } + + friend void deserialize(optimizer_state& item, std::istream& in) + { + int version = 0; + dlib::deserialize(version, in); + if (version != 1) + { + throw dlib::serialization_error( + "Error while deserializing dlib::svm_c_linear_dcd_trainer::optimizer_state, unexpected version." + ); + } + + dlib::deserialize(item.did_init, in); + dlib::deserialize(item.have_bias, in); + dlib::deserialize(item.last_weight_1, in); + dlib::deserialize(item.alpha, in); + dlib::deserialize(item.w, in); + dlib::deserialize(item.Q, in); + dlib::deserialize(item.index, in); + dlib::deserialize(item.dims, in); + dlib::deserialize(item.rnd, in); + } + + }; + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + optimizer_state state; + return do_train(mat(x), mat(y), state); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + optimizer_state& state + ) const + { + return do_train(mat(x), mat(y), state); + } + + private: + + // ------------------------------------------------------------------------------------ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + optimizer_state& state + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,y) == true, + "\t decision_function svm_c_linear_dcd_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.size(): " << x.size() + << "\n\t y.size(): " << y.size() + << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y) + ); +#ifdef ENABLE_ASSERTS + for (long i = 0; i < x.size(); ++i) + { + DLIB_ASSERT(y(i) == +1 || y(i) == -1, + "\t decision_function svm_c_linear_dcd_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t y("<& alpha = state.alpha; + scalar_vector_type& w = state.w; + std::vector& index = state.index; + const long dims = state.dims; + + + unsigned long active_size = index.size(); + + scalar_type PG_max_prev = std::numeric_limits::infinity(); + scalar_type PG_min_prev = -std::numeric_limits::infinity(); + + const scalar_type Dii_pos = 1/(2*Cpos); + const scalar_type Dii_neg = 1/(2*Cneg); + + // main loop + for (unsigned long iter = 0; iter < max_iterations; ++iter) + { + scalar_type PG_max = -std::numeric_limits::infinity(); + scalar_type PG_min = std::numeric_limits::infinity(); + + // randomly shuffle the indices + for (unsigned long i = 0; i < active_size; ++i) + { + // pick a random index >= i + const long j = i + state.rnd.get_random_32bit_number()%(active_size-i); + std::swap(index[i], index[j]); + } + + // for all the active training samples + for (unsigned long ii = 0; ii < active_size; ++ii) + { + const long i = index[ii]; + + scalar_type G = y(i)*dot(w, x(i)) - 1; + if (do_svm_l2) + { + if (y(i) > 0) + G += Dii_pos*alpha[i]; + else + G += Dii_neg*alpha[i]; + } + const scalar_type C = (y(i) > 0) ? Cpos : Cneg; + const scalar_type U = do_svm_l2 ? std::numeric_limits::infinity() : C; + + scalar_type PG = 0; + if (alpha[i] == 0) + { + if (G > PG_max_prev) + { + // shrink the active set of training examples + --active_size; + std::swap(index[ii], index[active_size]); + --ii; + continue; + } + + if (G < 0) + PG = G; + } + else if (alpha[i] == U) + { + if (G < PG_min_prev) + { + // shrink the active set of training examples + --active_size; + std::swap(index[ii], index[active_size]); + --ii; + continue; + } + + if (G > 0) + PG = G; + } + else + { + PG = G; + } + + if (PG > PG_max) + PG_max = PG; + if (PG < PG_min) + PG_min = PG; + + // if PG != 0 + if (std::abs(PG) > 1e-12) + { + const scalar_type alpha_old = alpha[i]; + alpha[i] = std::min(std::max(alpha[i] - G/state.Q[i], (scalar_type)0.0), U); + const scalar_type delta = (alpha[i]-alpha_old)*y(i); + add_to(w, x(i), delta); + if (have_bias && !last_weight_1) + w(w.size()-1) -= delta; + + if (last_weight_1) + w(dims-1) = 1; + } + + } + + if (verbose) + { + using namespace std; + cout << "gap: " << PG_max - PG_min << endl; + cout << "active_size: " << active_size << endl; + cout << "iter: " << iter << endl; + cout << endl; + } + + if (PG_max - PG_min <= eps) + { + // stop if we are within eps tolerance and the last iteration + // was over all the samples + if (active_size == index.size()) + break; + + // Turn off shrinking on the next iteration. We will stop if the + // tolerance is still <= eps when shrinking is off. + active_size = index.size(); + PG_max_prev = std::numeric_limits::infinity(); + PG_min_prev = -std::numeric_limits::infinity(); + } + else if (do_shrinking) + { + PG_max_prev = PG_max; + PG_min_prev = PG_min; + if (PG_max_prev <= 0) + PG_max_prev = std::numeric_limits::infinity(); + if (PG_min_prev >= 0) + PG_min_prev = -std::numeric_limits::infinity(); + } + + } // end of main optimization loop + + + + + // put the solution into a decision function and then return it + decision_function df; + if (have_bias && !last_weight_1) + df.b = w(w.size()-1); + else + df.b = 0; + + df.basis_vectors.set_size(1); + // Copy the plane normal into the output basis vector. The output vector might + // be a sparse vector container so we need to use this special kind of copy to + // handle that case. + assign(df.basis_vectors(0), colm(w, 0, dims)); + df.alpha.set_size(1); + df.alpha(0) = 1; + + return df; + } + + scalar_type dot ( + const scalar_vector_type& w, + const sample_type& sample + ) const + { + if (have_bias && !last_weight_1) + { + const long w_size_m1 = w.size()-1; + return dlib::dot(colm(w,0,w_size_m1), sample) - w(w_size_m1); + } + else + { + return dlib::dot(w, sample); + } + } + + // ------------------------------------------------------------------------------------ + + scalar_type Cpos; + scalar_type Cneg; + scalar_type eps; + unsigned long max_iterations; + bool verbose; + bool have_bias; // having a bias means we pretend all x vectors have an extra element which is always -1. + bool last_weight_1; + bool do_shrinking; + bool do_svm_l2; + + }; // end of class svm_c_linear_dcd_trainer + +// ---------------------------------------------------------------------------------------- + + +} + +#endif // DLIB_SVm_C_LINEAR_DCD_TRAINER_Hh_ + + diff --git a/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer_abstract.h b/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer_abstract.h new file mode 100644 index 000000000..b57c54260 --- /dev/null +++ b/ml/dlib/dlib/svm/svm_c_linear_dcd_trainer_abstract.h @@ -0,0 +1,382 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_C_LINEAR_DCD_TRAINER_ABSTRACT_Hh_ +#ifdef DLIB_SVm_C_LINEAR_DCD_TRAINER_ABSTRACT_Hh_ + +#include "function_abstract.h" +#include "kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_c_linear_dcd_trainer + { + /*! + REQUIREMENTS ON K + Is either linear_kernel or sparse_linear_kernel. + + WHAT THIS OBJECT REPRESENTS + This object represents a tool for training the C formulation of a support + vector machine. It is optimized for the case where linear kernels are + used. + + + In particular, it is implemented using the algorithm described in the + following paper: + A Dual Coordinate Descent Method for Large-scale Linear SVM + by Cho-Jui Hsieh, Kai-Wei Chang, and Chih-Jen Lin + + It solves the optimization problem of: + min_w: 0.5||w||^2 + C*sum_i (hinge loss for sample i) + where w is the learned SVM parameter vector. + + Note that this object is very similar to the svm_c_linear_trainer, however, + it interprets the C parameter slightly differently. In particular, C for + the DCD trainer is not automatically divided by the number of samples like + it is with the svm_c_linear_trainer. For example, a C value of 10 when + given to the svm_c_linear_trainer is equivalent to a C value of 10/N for + the svm_c_linear_dcd_trainer, where N is the number of training samples. + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + typedef typename decision_function::sample_vector_type sample_vector_type; + typedef typename decision_function::scalar_vector_type scalar_vector_type; + + + svm_c_linear_dcd_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used to train a + support vector machine. + - #get_c_class1() == 1 + - #get_c_class2() == 1 + - #get_epsilon() == 0.1 + - #get_max_iterations() == 10000 + - This object will not be verbose unless be_verbose() is called + - #forces_last_weight_to_1() == false + - #includes_bias() == true + - #shrinking_enabled() == true + - #solving_svm_l2_problem() == false + !*/ + + explicit svm_c_linear_dcd_trainer ( + const scalar_type& C + ); + /*! + requires + - C > 0 + ensures + - This object is properly initialized and ready to be used to train a + support vector machine. + - #get_c_class1() == C + - #get_c_class2() == C + - #get_epsilon() == 0.1 + - #get_max_iterations() == 10000 + - This object will not be verbose unless be_verbose() is called + - #forces_last_weight_to_1() == false + - #includes_bias() == true + - #shrinking_enabled() == true + - #solving_svm_l2_problem() == false + !*/ + + bool includes_bias ( + ) const; + /*! + ensures + - returns true if this trainer will produce decision_functions with + non-zero bias values. + !*/ + + void include_bias ( + bool should_have_bias + ); + /*! + ensures + - #includes_bias() == should_have_bias + !*/ + + bool forces_last_weight_to_1 ( + ) const; + /*! + ensures + - returns true if this trainer has the constraint that the last weight in + the learned parameter vector must be 1. This is the weight corresponding + to the feature in the training vectors with the highest dimension. + - Forcing the last weight to 1 also disables the bias and therefore the b + field of the learned decision_function will be 0 when forces_last_weight_to_1() == true. + This is true regardless of the setting of #include_bias(). + !*/ + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ); + /*! + ensures + - #forces_last_weight_to_1() == should_last_weight_be_1 + !*/ + + bool shrinking_enabled ( + ) const; + /*! + ensures + - returns true if the shrinking heuristic is enabled. Typically this makes + the algorithm run a lot faster so it should be enabled. + !*/ + + void enable_shrinking ( + bool enabled + ); + /*! + ensures + - #shrinking_enabled() == enabled + !*/ + + bool solving_svm_l2_problem ( + ) const; + /*! + ensures + - returns true if this solver will solve the L2 version of the SVM + objective function. That is, if solving_svm_l2_problem()==true then this + object, rather than using the hinge loss, uses the squared hinge loss. + !*/ + + void solve_svm_l2_problem ( + bool enabled + ); + /*! + ensures + - #solving_svm_l2_problem() == enabled + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_epsilon ( + scalar_type eps_ + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer to + train. + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object. Since the + linear kernels don't have any parameters this function just returns + kernel_type() + !*/ + + unsigned long get_max_iterations ( + ) const; + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class1() == C + - #get_c_class2() == C + !*/ + + const scalar_type get_c_class1 ( + ) const; + /*! + ensures + - returns the SVM regularization parameter for the +1 class. It is the + parameter that determines the trade off between trying to fit the +1 + training data exactly or allowing more errors but hopefully improving the + generalization of the resulting classifier. Larger values encourage + exact fitting while smaller values of C may encourage better + generalization. + !*/ + + const scalar_type get_c_class2 ( + ) const; + /*! + ensures + - returns the SVM regularization parameter for the -1 class. It is the + parameter that determines the trade off between trying to fit the -1 + training data exactly or allowing more errors but hopefully improving the + generalization of the resulting classifier. Larger values encourage + exact fitting while smaller values of C may encourage better + generalization. + !*/ + + void set_c_class1 ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class1() == C + !*/ + + void set_c_class2 ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class2() == C + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - is_learning_problem(x,y) == true + (Note that it is ok for x.size() == 1) + - All elements of y must be equal to +1 or -1 + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - Trains a C support vector classifier given the training samples in x and + labels in y. + - returns a decision function F with the following properties: + - F.alpha.size() == 1 + - F.basis_vectors.size() == 1 + - F.alpha(0) == 1 + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + + // optimizer_state is used to record the internal state of the SVM optimizer. It + // can be used with the following train() routine to warm-start the optimizer or + // access the optimal alpha values (see the Hsieh paper mentioned above). The + // optimizer_state objects are serializable and allow you to get the alphas, but + // are otherwise completely opaque to the user. + class optimizer_state + { + public: + const std::vector& get_alpha ( + ) const; + }; + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + optimizer_state& state + ) const; + /*! + requires + - is_learning_problem(x,y) == true + (Note that it is ok for x.size() == 1) + - All elements of y must be equal to +1 or -1 + - state must be either a default initialized optimizer_state object or all the + following conditions must be satisfied: + - Let LAST denote the previous trainer used with the state object, then + we must have: + - LAST.includes_bias() == includes_bias() + - LAST.forces_last_weight_to_1() == forces_last_weight_to_1() + - Let X denote the previous training samples used with state, then the + following must be satisfied: + - x.size() >= X.size() + - for all valid i: + - x(i) == X(i) + (i.e. the samples x and X have in common must be identical. + That is, the only allowed difference between x and X is that + x might have new training samples appended onto its end) + - if (x contains dense vectors) then + - max_index_plus_one(x) == max_index_plus_one(X) + - else + - max_index_plus_one(x) >= max_index_plus_one(X) + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - Trains a C support vector classifier given the training samples in x and + labels in y. + - The point of the state object is to allow you to warm start the SVM + optimizer from the solution to a previous call to train(). Doing this + might make the training run faster. This is useful when you are trying + different C values or have grown the training set and want to retrain. + - #state == the internal state of the optimizer at the solution to the SVM + problem. Therefore, passing #state to a new call to train() will start + the optimizer from the current solution. + - #state.get_alpha().size() == x.size() + - #state.get_alpha() == the optimal alpha/dual values learned by the optimizer. + - returns a decision function F with the following properties: + - F.alpha.size() == 1 + - F.basis_vectors.size() == 1 + - F.alpha(0) == 1 + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_C_LINEAR_DCD_TRAINER_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/svm_c_linear_trainer.h b/ml/dlib/dlib/svm/svm_c_linear_trainer.h new file mode 100644 index 000000000..8d136d711 --- /dev/null +++ b/ml/dlib/dlib/svm/svm_c_linear_trainer.h @@ -0,0 +1,706 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVM_C_LiNEAR_TRAINER_Hh_ +#define DLIB_SVM_C_LiNEAR_TRAINER_Hh_ + +#include "svm_c_linear_trainer_abstract.h" +#include "../algs.h" +#include "../optimization.h" +#include "../matrix.h" +#include "function.h" +#include "kernel.h" +#include +#include +#include "sparse_vector.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + class oca_problem_c_svm : public oca_problem + { + public: + /* + This class is used as part of the implementation of the svm_c_linear_trainer + defined towards the end of this file. + + + The bias parameter is dealt with by imagining that each sample vector has -1 + as its last element. + */ + + typedef typename matrix_type::type scalar_type; + + oca_problem_c_svm( + const scalar_type C_pos, + const scalar_type C_neg, + const in_sample_vector_type& samples_, + const in_scalar_vector_type& labels_, + const bool be_verbose_, + const scalar_type eps_, + const unsigned long max_iter, + const unsigned long dims_ + ) : + samples(samples_), + labels(labels_), + C(std::min(C_pos,C_neg)), + Cpos(C_pos/C), + Cneg(C_neg/C), + be_verbose(be_verbose_), + eps(eps_), + max_iterations(max_iter), + dims(dims_) + { + dot_prods.resize(samples.size()); + is_first_call = true; + } + + virtual scalar_type get_c ( + ) const + { + return C; + } + + virtual long get_num_dimensions ( + ) const + { + // plus 1 for the bias term + return dims + 1; + } + + virtual bool optimization_status ( + scalar_type current_objective_value, + scalar_type current_error_gap, + scalar_type current_risk_value, + scalar_type current_risk_gap, + unsigned long num_cutting_planes, + unsigned long num_iterations + ) const + { + if (be_verbose) + { + using namespace std; + cout << "objective: " << current_objective_value << endl; + cout << "objective gap: " << current_error_gap << endl; + cout << "risk: " << current_risk_value << endl; + cout << "risk gap: " << current_risk_gap << endl; + cout << "num planes: " << num_cutting_planes << endl; + cout << "iter: " << num_iterations << endl; + cout << endl; + } + + if (num_iterations >= max_iterations) + return true; + + if (current_risk_gap < eps) + return true; + + return false; + } + + virtual bool risk_has_lower_bound ( + scalar_type& lower_bound + ) const + { + lower_bound = 0; + return true; + } + + virtual void get_risk ( + matrix_type& w, + scalar_type& risk, + matrix_type& subgradient + ) const + { + line_search(w); + + subgradient.set_size(w.size(),1); + subgradient = 0; + risk = 0; + + + // loop over all the samples and compute the risk and its subgradient at the current solution point w + for (long i = 0; i < samples.size(); ++i) + { + // multiply current SVM output for the ith sample by its label + const scalar_type df_val = labels(i)*dot_prods[i]; + + if (labels(i) > 0) + risk += Cpos*std::max(0.0,1 - df_val); + else + risk += Cneg*std::max(0.0,1 - df_val); + + if (df_val < 1) + { + if (labels(i) > 0) + { + subtract_from(subgradient, samples(i), Cpos); + + subgradient(subgradient.size()-1) += Cpos; + } + else + { + add_to(subgradient, samples(i), Cneg); + + subgradient(subgradient.size()-1) -= Cneg; + } + } + } + + scalar_type scale = 1.0/samples.size(); + + risk *= scale; + subgradient = scale*subgradient; + } + + private: + + // ----------------------------------------------------- + // ----------------------------------------------------- + + void line_search ( + matrix_type& w + ) const + /*! + ensures + - does a line search to find a better w + - for all i: #dot_prods[i] == dot(colm(#w,0,w.size()-1), samples(i)) - #w(w.size()-1) + !*/ + { + // The reason for using w_size_m1 and not just w.size()-1 is because + // doing it this way avoids an inane warning from gcc that can occur in some cases. + const long w_size_m1 = w.size()-1; + for (long i = 0; i < samples.size(); ++i) + dot_prods[i] = dot(colm(w,0,w_size_m1), samples(i)) - w(w_size_m1); + + if (is_first_call) + { + is_first_call = false; + best_so_far = w; + dot_prods_best = dot_prods; + } + else + { + // do line search going from best_so_far to w. Store results in w. + // Here we use the line search algorithm presented in section 3.1.1 of Franc and Sonnenburg. + + const scalar_type A0 = length_squared(best_so_far - w); + const scalar_type BB0 = dot(best_so_far, w - best_so_far); + + const scalar_type scale_pos = (get_c()*Cpos)/samples.size(); + const scalar_type scale_neg = (get_c()*Cneg)/samples.size(); + + ks.clear(); + ks.reserve(samples.size()); + + scalar_type f0 = BB0; + for (long i = 0; i < samples.size(); ++i) + { + const scalar_type& scale = (labels(i)>0) ? scale_pos : scale_neg; + + const scalar_type B = scale*labels(i) * ( dot_prods_best[i] - dot_prods[i]); + const scalar_type C = scale*(1 - labels(i)* dot_prods_best[i]); + // Note that if B is 0 then it doesn't matter what k is set to. So 0 is fine. + scalar_type k = 0; + if (B != 0) + k = -C/B; + + if (k > 0) + ks.push_back(helper(k, std::abs(B))); + + if ( (B < 0 && k > 0) || (B > 0 && k <= 0) ) + f0 += B; + } + + scalar_type opt_k = 1; + // ks.size() == 0 shouldn't happen but check anyway + if (f0 >= 0 || ks.size() == 0) + { + // Getting here means that we aren't searching in a descent direction. + // We could take a zero step but instead lets just assign w to the new best + // so far point just to make sure we don't get stuck coming back to this + // case over and over. This might happen if we never move the best point + // seen so far. + + // So we let opt_k be 1 + } + else + { + std::sort(ks.begin(), ks.end()); + + // figure out where f0 goes positive. + for (unsigned long i = 0; i < ks.size(); ++i) + { + f0 += ks[i].B; + if (f0 + A0*ks[i].k >= 0) + { + opt_k = ks[i].k; + break; + } + } + + } + + // Don't let the step size get too big. Otherwise we might pick huge steps + // over and over that don't improve the cutting plane approximation. + if (opt_k > 1.0) + { + opt_k = 1.0; + } + + // take the step suggested by the line search + best_so_far = (1-opt_k)*best_so_far + opt_k*w; + + // update best_so_far dot products + for (unsigned long i = 0; i < dot_prods_best.size(); ++i) + dot_prods_best[i] = (1-opt_k)*dot_prods_best[i] + opt_k*dot_prods[i]; + + + const scalar_type mu = 0.1; + // Make sure we always take a little bit of a step towards w regardless of what the + // line search says to do. We do this since it is possible that some steps won't + // advance the best_so_far point. So this ensures we always make some progress each + // iteration. + w = (1-mu)*best_so_far + mu*w; + + // update dot products + for (unsigned long i = 0; i < dot_prods.size(); ++i) + dot_prods[i] = (1-mu)*dot_prods_best[i] + mu*dot_prods[i]; + } + } + + struct helper + { + helper(scalar_type k_, scalar_type B_) : k(k_), B(B_) {} + scalar_type k; + scalar_type B; + + bool operator< (const helper& item) const { return k < item.k; } + }; + + mutable std::vector ks; + + mutable bool is_first_call; + mutable std::vector dot_prods; + + mutable matrix_type best_so_far; // best w seen so far + mutable std::vector dot_prods_best; // dot products between best_so_far and samples + + + const in_sample_vector_type& samples; + const in_scalar_vector_type& labels; + const scalar_type C; + const scalar_type Cpos; + const scalar_type Cneg; + + const bool be_verbose; + const scalar_type eps; + const unsigned long max_iterations; + const unsigned long dims; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type, + typename in_sample_vector_type, + typename in_scalar_vector_type, + typename scalar_type + > + oca_problem_c_svm make_oca_problem_c_svm ( + const scalar_type C_pos, + const scalar_type C_neg, + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels, + const bool be_verbose, + const scalar_type eps, + const unsigned long max_iterations, + const unsigned long dims + ) + { + return oca_problem_c_svm( + C_pos, C_neg, samples, labels, be_verbose, eps, max_iterations, dims); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_c_linear_trainer + { + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + // You are getting a compiler error on this line because you supplied a non-linear kernel + // to the svm_c_linear_trainer object. You have to use one of the linear kernels with this + // trainer. + COMPILE_TIME_ASSERT((is_same_type >::value || + is_same_type >::value )); + + svm_c_linear_trainer ( + ) + { + Cpos = 1; + Cneg = 1; + verbose = false; + eps = 0.001; + max_iterations = 10000; + learn_nonnegative_weights = false; + last_weight_1 = false; + } + + explicit svm_c_linear_trainer ( + const scalar_type& C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t svm_c_linear_trainer::svm_c_linear_trainer()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cpos = C; + Cneg = C; + verbose = false; + eps = 0.001; + max_iterations = 10000; + learn_nonnegative_weights = false; + last_weight_1 = false; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void svm_c_linear_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const { return eps; } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + const kernel_type get_kernel ( + ) const + { + return kernel_type(); + } + + bool learns_nonnegative_weights ( + ) const { return learn_nonnegative_weights; } + + void set_learns_nonnegative_weights ( + bool value + ) + { + learn_nonnegative_weights = value; + if (learn_nonnegative_weights) + prior.set_size(0); + } + + bool forces_last_weight_to_1 ( + ) const + { + return last_weight_1; + } + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ) + { + last_weight_1 = should_last_weight_be_1; + if (last_weight_1) + prior.set_size(0); + } + + void set_prior ( + const trained_function_type& prior_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(prior_.basis_vectors.size() == 1 && + prior_.alpha(0) == 1, + "\t void svm_c_linear_trainer::set_prior()" + << "\n\t The supplied prior could not have been created by this object's train() method." + << "\n\t prior_.basis_vectors.size(): " << prior_.basis_vectors.size() + << "\n\t prior_.alpha(0): " << prior_.alpha(0) + << "\n\t this: " << this + ); + + prior = sparse_to_dense(prior_.basis_vectors(0)); + prior_b = prior_.b; + learn_nonnegative_weights = false; + last_weight_1 = false; + } + + bool has_prior ( + ) const + { + return prior.size() != 0; + } + + void set_c ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_linear_trainer::set_c()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cpos = C; + Cneg = C; + } + + const scalar_type get_c_class1 ( + ) const + { + return Cpos; + } + + const scalar_type get_c_class2 ( + ) const + { + return Cneg; + } + + void set_c_class1 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_linear_trainer::set_c_class1()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cpos = C; + } + + void set_c_class2 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_linear_trainer::set_c_class2()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cneg = C; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + scalar_type obj; + return do_train(mat(x),mat(y),obj); + } + + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const + { + return do_train(mat(x),mat(y),svm_objective); + } + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,y) == true, + "\t decision_function svm_c_linear_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y) + ); +#ifdef ENABLE_ASSERTS + for (long i = 0; i < x.size(); ++i) + { + DLIB_ASSERT(y(i) == +1 || y(i) == -1, + "\t decision_function svm_c_linear_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t y("< w_type; + w_type w; + + const unsigned long num_dims = max_index_plus_one(x); + + unsigned long num_nonnegative = 0; + if (learn_nonnegative_weights) + { + num_nonnegative = num_dims; + } + + unsigned long force_weight_1_idx = std::numeric_limits::max(); + if (last_weight_1) + { + force_weight_1_idx = num_dims-1; + } + + + if (has_prior()) + { + if (is_matrix::value) + { + // make sure requires clause is not broken + DLIB_CASSERT(num_dims == (unsigned long)prior.size(), + "\t decision_function svm_c_linear_trainer::train(x,y)" + << "\n\t The dimension of the training vectors must match the dimension of\n" + << "\n\t those used to create the prior." + << "\n\t num_dims: " << num_dims + << "\n\t prior.size(): " << prior.size() + ); + } + const unsigned long dims = std::max(num_dims, (unsigned long)prior.size()); + // In the case of sparse sample vectors, it is possible that the input + // vector dimensionality is larger than the prior vector dimensionality. + // We need to check for this case and pad prior with zeros if it is the + // case. + matrix prior_temp = join_cols(join_cols(prior, + zeros_matrix(dims-prior.size(),1)), + mat(prior_b)); + + svm_objective = solver( + make_oca_problem_c_svm(Cpos, Cneg, x, y, verbose, eps, max_iterations, dims), + w, + prior_temp); + } + else + { + svm_objective = solver( + make_oca_problem_c_svm(Cpos, Cneg, x, y, verbose, eps, max_iterations, num_dims), + w, + num_nonnegative, + force_weight_1_idx); + } + + // put the solution into a decision function and then return it + decision_function df; + df.b = static_cast(w(w.size()-1)); + df.basis_vectors.set_size(1); + // Copy the plane normal into the output basis vector. The output vector might be a + // sparse vector container so we need to use this special kind of copy to handle that case. + // As an aside, the reason for using max_index_plus_one() and not just w.size()-1 is because + // doing it this way avoids an inane warning from gcc that can occur in some cases. + const long out_size = max_index_plus_one(x); + assign(df.basis_vectors(0), matrix_cast(colm(w, 0, out_size))); + df.alpha.set_size(1); + df.alpha(0) = 1; + + return df; + } + + scalar_type Cpos; + scalar_type Cneg; + oca solver; + scalar_type eps; + bool verbose; + unsigned long max_iterations; + bool learn_nonnegative_weights; + bool last_weight_1; + matrix prior; + scalar_type prior_b = 0; + }; + +// ---------------------------------------------------------------------------------------- + +} + +// ---------------------------------------------------------------------------------------- + + +#endif // DLIB_SVM_C_LiNEAR_TRAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/svm_c_linear_trainer_abstract.h b/ml/dlib/dlib/svm/svm_c_linear_trainer_abstract.h new file mode 100644 index 000000000..1b7a128f0 --- /dev/null +++ b/ml/dlib/dlib/svm/svm_c_linear_trainer_abstract.h @@ -0,0 +1,359 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVM_C_LiNEAR_TRAINER_ABSTRACT_Hh_ +#ifdef DLIB_SVM_C_LiNEAR_TRAINER_ABSTRACT_Hh_ + +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "sparse_kernel_abstract.h" + +namespace dlib +{ + template < + typename K + > + class svm_c_linear_trainer + { + /*! + REQUIREMENTS ON K + Is either linear_kernel or sparse_linear_kernel. + + WHAT THIS OBJECT REPRESENTS + This object represents a tool for training the C formulation of + a support vector machine. It is optimized for the case where + linear kernels are used. + + + In particular, it is implemented using the OCAS algorithm + described in the following paper: + Optimized Cutting Plane Algorithm for Large-Scale Risk Minimization + Vojtech Franc, Soren Sonnenburg; Journal of Machine Learning + Research, 10(Oct):2157--2192, 2009. + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_c_linear_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_oca() == oca() (i.e. an instance of oca with default parameters) + - #get_c_class1() == 1 + - #get_c_class2() == 1 + - #get_epsilon() == 0.001 + - this object will not be verbose unless be_verbose() is called + - #get_max_iterations() == 10000 + - #learns_nonnegative_weights() == false + - #force_last_weight_to_1() == false + - #has_prior() == false + !*/ + + explicit svm_c_linear_trainer ( + const scalar_type& C + ); + /*! + requires + - C > 0 + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_oca() == oca() (i.e. an instance of oca with default parameters) + - #get_c_class1() == C + - #get_c_class2() == C + - #get_epsilon() == 0.001 + - this object will not be verbose unless be_verbose() is called + - #get_max_iterations() == 10000 + - #learns_nonnegative_weights() == false + - #force_last_weight_to_1() == false + - #has_prior() == false + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer to + train. You can think of this epsilon value as saying "solve the + optimization problem until the probability of misclassification is within + epsilon of its optimal value". + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + unsigned long get_max_iterations ( + ); + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the SVM problem. + !*/ + + const kernel_type get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object. Since + the linear kernels don't have any parameters this function just + returns kernel_type() + !*/ + + bool learns_nonnegative_weights ( + ) const; + /*! + ensures + - The output of training is a weight vector and a bias value. These + two things define the resulting decision function. That is, the + decision function simply takes the dot product between the learned + weight vector and a test sample, then subtracts the bias value. + Therefore, if learns_nonnegative_weights() == true then the resulting + learned weight vector will always have non-negative entries. The + bias value may still be negative though. + !*/ + + void set_learns_nonnegative_weights ( + bool value + ); + /*! + ensures + - #learns_nonnegative_weights() == value + - if (value == true) then + - #has_prior() == false + !*/ + + void set_prior ( + const trained_function_type& prior + ); + /*! + requires + - prior == a function produced by a call to this class's train() function. + Therefore, it must be the case that: + - prior.basis_vectors.size() == 1 + - prior.alpha(0) == 1 + ensures + - Subsequent calls to train() will try to learn a function similar to the + given prior. + - #has_prior() == true + - #learns_nonnegative_weights() == false + - #forces_last_weight_to_1() == false + !*/ + + bool has_prior ( + ) const + /*! + ensures + - returns true if a prior has been set and false otherwise. Having a prior + set means that you have called set_prior() and supplied a previously + trained function as a reference. In this case, any call to train() will + try to learn a function that matches the behavior of the prior as close + as possible but also fits the supplied training data. In more technical + detail, having a prior means we replace the ||w||^2 regularizer with one + of the form ||w-prior||^2 where w is the set of parameters for a learned + function. + !*/ + + bool forces_last_weight_to_1 ( + ) const; + /*! + ensures + - returns true if this trainer has the constraint that the last weight in + the learned parameter vector must be 1. This is the weight corresponding + to the feature in the training vectors with the highest dimension. + - Forcing the last weight to 1 also disables the bias and therefore the b + field of the learned decision_function will be 0 when forces_last_weight_to_1() == true. + !*/ + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ); + /*! + ensures + - #forces_last_weight_to_1() == should_last_weight_be_1 + - if (should_last_weight_be_1 == true) then + - #has_prior() == false + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class1() == C + - #get_c_class2() == C + !*/ + + const scalar_type get_c_class1 ( + ) const; + /*! + ensures + - returns the SVM regularization parameter for the +1 class. + It is the parameter that determines the trade off between + trying to fit the +1 training data exactly or allowing more errors + but hopefully improving the generalization of the resulting + classifier. Larger values encourage exact fitting while + smaller values of C may encourage better generalization. + !*/ + + const scalar_type get_c_class2 ( + ) const; + /*! + ensures + - returns the SVM regularization parameter for the -1 class. + It is the parameter that determines the trade off between + trying to fit the -1 training data exactly or allowing more errors + but hopefully improving the generalization of the resulting + classifier. Larger values encourage exact fitting while + smaller values of C may encourage better generalization. + !*/ + + void set_c_class1 ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class1() == C + !*/ + + void set_c_class2 ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class2() == C + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - is_learning_problem(x,y) == true + (Note that it is ok for x.size() == 1) + - All elements of y must be equal to +1 or -1 + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + - if (has_prior()) then + - The vectors in x must have the same dimensionality as the vectors + used to train the prior given to set_prior(). + ensures + - trains a C support vector classifier given the training samples in x and + labels in y. + - returns a decision function F with the following properties: + - F.alpha.size() == 1 + - F.basis_vectors.size() == 1 + - F.alpha(0) == 1 + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const; + /*! + requires + - is_learning_problem(x,y) == true + (Note that it is ok for x.size() == 1) + - All elements of y must be equal to +1 or -1 + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + - if (has_prior()) then + - The vectors in x must have the same dimensionality as the vectors + used to train the prior given to set_prior(). + ensures + - trains a C support vector classifier given the training samples in x and + labels in y. + - #svm_objective == the final value of the SVM objective function + - returns a decision function F with the following properties: + - F.alpha.size() == 1 + - F.basis_vectors.size() == 1 + - F.alpha(0) == 1 + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + + }; + +} + +#endif // DLIB_SVM_C_LiNEAR_TRAINER_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/svm_c_trainer.h b/ml/dlib/dlib/svm/svm_c_trainer.h new file mode 100644 index 000000000..14dcf3482 --- /dev/null +++ b/ml/dlib/dlib/svm/svm_c_trainer.h @@ -0,0 +1,359 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_C_TRAINER_Hh_ +#define DLIB_SVm_C_TRAINER_Hh_ + +//#include "local/make_label_kernel_matrix.h" + +#include "svm_c_trainer_abstract.h" +#include +#include +#include +#include "../matrix.h" +#include "../algs.h" + +#include "function.h" +#include "kernel.h" +#include "../optimization/optimization_solve_qp3_using_smo.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_c_trainer + { + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_c_trainer ( + ) : + Cpos(1), + Cneg(1), + cache_size(200), + eps(0.001) + { + } + + svm_c_trainer ( + const kernel_type& kernel_, + const scalar_type& C_ + ) : + kernel_function(kernel_), + Cpos(C_), + Cneg(C_), + cache_size(200), + eps(0.001) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < C_, + "\tsvm_c_trainer::svm_c_trainer(kernel,C)" + << "\n\t invalid inputs were given to this function" + << "\n\t C_: " << C_ + ); + } + + void set_cache_size ( + long cache_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(cache_size_ > 0, + "\tvoid svm_c_trainer::set_cache_size(cache_size_)" + << "\n\t invalid inputs were given to this function" + << "\n\t cache_size: " << cache_size_ + ); + cache_size = cache_size_; + } + + long get_cache_size ( + ) const + { + return cache_size; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\tvoid svm_c_trainer::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps_: " << eps_ + ); + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const + { + return eps; + } + + void set_kernel ( + const kernel_type& k + ) + { + kernel_function = k; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel_function; + } + + void set_c ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_trainer::set_c()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cpos = C; + Cneg = C; + } + + const scalar_type get_c_class1 ( + ) const + { + return Cpos; + } + + const scalar_type get_c_class2 ( + ) const + { + return Cneg; + } + + void set_c_class1 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_trainer::set_c_class1()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cpos = C; + } + + void set_c_class2 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_trainer::set_c_class2()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + Cneg = C; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + return do_train(mat(x), mat(y)); + } + + void swap ( + svm_c_trainer& item + ) + { + exchange(kernel_function, item.kernel_function); + exchange(Cpos, item.Cpos); + exchange(Cneg, item.Cneg); + exchange(cache_size, item.cache_size); + exchange(eps, item.eps); + } + + private: + + // ------------------------------------------------------------------------------------ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + typedef typename K::scalar_type scalar_type; + typedef typename decision_function::sample_vector_type sample_vector_type; + typedef typename decision_function::scalar_vector_type scalar_vector_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true, + "\tdecision_function svm_c_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y) + ); + + + scalar_vector_type alpha; + + solve_qp3_using_smo solver; + + solver(symmetric_matrix_cache((diagm(y)*kernel_matrix(kernel_function,x)*diagm(y)), cache_size), + //solver(symmetric_matrix_cache(make_label_kernel_matrix(kernel_matrix(kernel_function,x),y), cache_size), + uniform_matrix(y.size(),1,-1), + y, + 0, + Cpos, + Cneg, + alpha, + eps); + + scalar_type b; + calculate_b(y,alpha,solver.get_gradient(),Cpos,Cneg,b); + alpha = pointwise_multiply(alpha,y); + + // count the number of support vectors + const long sv_count = (long)sum(alpha != 0); + + scalar_vector_type sv_alpha; + sample_vector_type support_vectors; + + // size these column vectors so that they have an entry for each support vector + sv_alpha.set_size(sv_count); + support_vectors.set_size(sv_count); + + // load the support vectors and their alpha values into these new column matrices + long idx = 0; + for (long i = 0; i < alpha.nr(); ++i) + { + if (alpha(i) != 0) + { + sv_alpha(idx) = alpha(i); + support_vectors(idx) = x(i); + ++idx; + } + } + + // now return the decision function + return decision_function (sv_alpha, b, kernel_function, support_vectors); + } + + // ------------------------------------------------------------------------------------ + + template < + typename scalar_vector_type, + typename scalar_vector_type2 + > + void calculate_b( + const scalar_vector_type2& y, + const scalar_vector_type& alpha, + const scalar_vector_type& df, + const scalar_type& Cpos, + const scalar_type& Cneg, + scalar_type& b + ) const + { + using namespace std; + long num_free = 0; + scalar_type sum_free = 0; + + scalar_type upper_bound = -numeric_limits::infinity(); + scalar_type lower_bound = numeric_limits::infinity(); + + for(long i = 0; i < alpha.nr(); ++i) + { + if(y(i) == 1) + { + if(alpha(i) == Cpos) + { + if (df(i) > upper_bound) + upper_bound = df(i); + } + else if(alpha(i) == 0) + { + if (df(i) < lower_bound) + lower_bound = df(i); + } + else + { + ++num_free; + sum_free += df(i); + } + } + else + { + if(alpha(i) == Cneg) + { + if (-df(i) < lower_bound) + lower_bound = -df(i); + } + else if(alpha(i) == 0) + { + if (-df(i) > upper_bound) + upper_bound = -df(i); + } + else + { + ++num_free; + sum_free -= df(i); + } + } + } + + if(num_free > 0) + b = sum_free/num_free; + else + b = (upper_bound+lower_bound)/2; + } + + // ------------------------------------------------------------------------------------ + + + kernel_type kernel_function; + scalar_type Cpos; + scalar_type Cneg; + long cache_size; + scalar_type eps; + }; // end of class svm_c_trainer + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + svm_c_trainer& a, + svm_c_trainer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_C_TRAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/svm_c_trainer_abstract.h b/ml/dlib/dlib/svm/svm_c_trainer_abstract.h new file mode 100644 index 000000000..696cccdb7 --- /dev/null +++ b/ml/dlib/dlib/svm/svm_c_trainer_abstract.h @@ -0,0 +1,237 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_C_TRAINER_ABSTRACT_ +#ifdef DLIB_SVm_C_TRAINER_ABSTRACT_ + +#include +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "../optimization/optimization_solve_qp3_using_smo_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_c_trainer + { + /*! + REQUIREMENTS ON K + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object implements a trainer for a C support vector machine for + solving binary classification problems. It is implemented using the SMO + algorithm. + + The implementation of the C-SVM training algorithm used by this object is based + on the following paper: + - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector + machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm + + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_c_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_c_class1() == 1 + - #get_c_class2() == 1 + - #get_cache_size() == 200 + - #get_epsilon() == 0.001 + !*/ + + svm_c_trainer ( + const kernel_type& kernel, + const scalar_type& C + ); + /*! + requires + - 0 < C + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_kernel() == kernel + - #get_c_class1() == C + - #get_c_class2() == C + - #get_cache_size() == 200 + - #get_epsilon() == 0.001 + !*/ + + void set_cache_size ( + long cache_size + ); + /*! + requires + - cache_size > 0 + ensures + - #get_cache_size() == cache_size + !*/ + + const long get_cache_size ( + ) const; + /*! + ensures + - returns the number of megabytes of cache this object will use + when it performs training via the this->train() function. + (bigger values of this may make training go faster but won't affect + the result. However, too big a value will cause you to run out of + memory, obviously.) + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Generally a good value for this is 0.001. Smaller values may result + in a more accurate solution but take longer to execute. + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class1() == C + - #get_c_class2() == C + !*/ + + const scalar_type get_c_class1 ( + ) const; + /*! + ensures + - returns the SVM regularization parameter for the +1 class. + It is the parameter that determines the trade off between + trying to fit the +1 training data exactly or allowing more errors + but hopefully improving the generalization ability of the + resulting classifier. Larger values encourage exact fitting + while smaller values of C may encourage better generalization. + !*/ + + const scalar_type get_c_class2 ( + ) const; + /*! + ensures + - returns the SVM regularization parameter for the -1 class. + It is the parameter that determines the trade off between + trying to fit the -1 training data exactly or allowing more errors + but hopefully improving the generalization ability of the + resulting classifier. Larger values encourage exact fitting + while smaller values of C may encourage better generalization. + !*/ + + void set_c_class1 ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class1() == C + !*/ + + void set_c_class2 ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c_class2() == C + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - is_binary_classification_problem(x,y) == true + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - trains a C support vector classifier given the training samples in x and + labels in y. Training is done when the error is less than get_epsilon(). + - returns a decision function F with the following properties: + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + + void swap ( + svm_c_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + + template + void swap ( + svm_c_trainer& a, + svm_c_trainer& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_C_TRAINER_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/svm_multiclass_linear_trainer.h b/ml/dlib/dlib/svm/svm_multiclass_linear_trainer.h new file mode 100644 index 000000000..4727f7226 --- /dev/null +++ b/ml/dlib/dlib/svm/svm_multiclass_linear_trainer.h @@ -0,0 +1,432 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_MULTICLASS_LINEAR_TRAINER_Hh_ +#define DLIB_SVm_MULTICLASS_LINEAR_TRAINER_Hh_ + +#include "svm_multiclass_linear_trainer_abstract.h" +#include "structural_svm_problem_threaded.h" +#include +#include "../optimization/optimization_oca.h" +#include "../matrix.h" +#include "sparse_vector.h" +#include "function.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type, + typename sample_type, + typename label_type + > + class multiclass_svm_problem : public structural_svm_problem_threaded > > + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines the optimization problem for the multiclass SVM trainer + object at the bottom of this file. + + The joint feature vectors used by this object, the PSI(x,y) vectors, are + defined as follows: + PSI(x,0) = [x,0,0,0,0, ...,0] + PSI(x,1) = [0,x,0,0,0, ...,0] + PSI(x,2) = [0,0,x,0,0, ...,0] + That is, if there are N labels then the joint feature vector has a + dimension that is N times the dimension of a single x sample. Also, + note that we append a -1 value onto each x to account for the bias term. + !*/ + + public: + typedef typename matrix_type::type scalar_type; + typedef std::vector > feature_vector_type; + + multiclass_svm_problem ( + const std::vector& samples_, + const std::vector& labels_, + const std::vector& distinct_labels_, + const unsigned long dims_, + const unsigned long num_threads + ) : + structural_svm_problem_threaded > >(num_threads), + samples(samples_), + labels(labels_), + distinct_labels(distinct_labels_), + dims(dims_+1) // +1 for the bias + {} + + virtual long get_num_dimensions ( + ) const + { + return dims*distinct_labels.size(); + } + + virtual long get_num_samples ( + ) const + { + return static_cast(samples.size()); + } + + virtual void get_truth_joint_feature_vector ( + long idx, + feature_vector_type& psi + ) const + { + assign(psi, samples[idx]); + // Add a constant -1 to account for the bias term. + psi.push_back(std::make_pair(dims-1,static_cast(-1))); + + // Find which distinct label goes with this psi. + long label_idx = 0; + for (unsigned long i = 0; i < distinct_labels.size(); ++i) + { + if (distinct_labels[i] == labels[idx]) + { + label_idx = i; + break; + } + } + + offset_feature_vector(psi, dims*label_idx); + } + + virtual void separation_oracle ( + const long idx, + const matrix_type& current_solution, + scalar_type& loss, + feature_vector_type& psi + ) const + { + scalar_type best_val = -std::numeric_limits::infinity(); + unsigned long best_idx = 0; + + // Figure out which label is the best. That is, what label maximizes + // LOSS(idx,y) + F(x,y). Note that y in this case is given by distinct_labels[i]. + for (unsigned long i = 0; i < distinct_labels.size(); ++i) + { + // Compute the F(x,y) part: + // perform: temp == dot(relevant part of current solution, samples[idx]) - current_bias + scalar_type temp = dot(mat(¤t_solution(i*dims),dims-1), samples[idx]) - current_solution((i+1)*dims-1); + + // Add the LOSS(idx,y) part: + if (labels[idx] != distinct_labels[i]) + temp += 1; + + // Now temp == LOSS(idx,y) + F(x,y). Check if it is the biggest we have seen. + if (temp > best_val) + { + best_val = temp; + best_idx = i; + } + } + + assign(psi, samples[idx]); + // add a constant -1 to account for the bias term + psi.push_back(std::make_pair(dims-1,static_cast(-1))); + + offset_feature_vector(psi, dims*best_idx); + + if (distinct_labels[best_idx] == labels[idx]) + loss = 0; + else + loss = 1; + } + + private: + + void offset_feature_vector ( + feature_vector_type& sample, + const unsigned long val + ) const + { + if (val != 0) + { + for (typename feature_vector_type::iterator i = sample.begin(); i != sample.end(); ++i) + { + i->first += val; + } + } + } + + + const std::vector& samples; + const std::vector& labels; + const std::vector& distinct_labels; + const long dims; + }; + + +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename label_type_ = typename K::scalar_type + > + class svm_multiclass_linear_trainer + { + public: + typedef label_type_ label_type; + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + typedef multiclass_linear_decision_function trained_function_type; + + + // You are getting a compiler error on this line because you supplied a non-linear kernel + // to the svm_c_linear_trainer object. You have to use one of the linear kernels with this + // trainer. + COMPILE_TIME_ASSERT((is_same_type >::value || + is_same_type >::value )); + + svm_multiclass_linear_trainer ( + ) : + num_threads(4), + C(1), + eps(0.001), + max_iterations(10000), + verbose(false), + learn_nonnegative_weights(false) + { + } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void svm_multiclass_linear_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const { return eps; } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + const kernel_type get_kernel ( + ) const + { + return kernel_type(); + } + + bool learns_nonnegative_weights ( + ) const { return learn_nonnegative_weights; } + + void set_learns_nonnegative_weights ( + bool value + ) + { + learn_nonnegative_weights = value; + if (learn_nonnegative_weights) + prior = trained_function_type(); + } + + void set_c ( + scalar_type C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void svm_multiclass_linear_trainer::set_c()" + << "\n\t C must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + const scalar_type get_c ( + ) const + { + return C; + } + + void set_prior ( + const trained_function_type& prior_ + ) + { + prior = prior_; + learn_nonnegative_weights = false; + } + + bool has_prior ( + ) const + { + return prior.labels.size() != 0; + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const + { + scalar_type svm_objective = 0; + return train(all_samples, all_labels, svm_objective); + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels, + scalar_type& svm_objective + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(all_samples,all_labels), + "\t trained_function_type svm_multiclass_linear_trainer::train(all_samples,all_labels)" + << "\n\t invalid inputs were given to this function" + << "\n\t all_samples.size(): " << all_samples.size() + << "\n\t all_labels.size(): " << all_labels.size() + ); + + trained_function_type df; + df.labels = select_all_distinct_labels(all_labels); + if (has_prior()) + { + df.labels.insert(df.labels.end(), prior.labels.begin(), prior.labels.end()); + df.labels = select_all_distinct_labels(df.labels); + } + const long input_sample_dimensionality = max_index_plus_one(all_samples); + // If the samples are sparse then the right thing to do is to take the max + // dimensionality between the prior and the new samples. But if the samples + // are dense vectors then they definitely all have to have exactly the same + // dimensionality. + const long dims = std::max(df.weights.nc(),input_sample_dimensionality); + if (is_matrix::value && has_prior()) + { + DLIB_ASSERT(input_sample_dimensionality == prior.weights.nc(), + "\t trained_function_type svm_multiclass_linear_trainer::train(all_samples,all_labels)" + << "\n\t The training samples given to this function are not the same kind of training " + << "\n\t samples used to create the prior." + << "\n\t input_sample_dimensionality: " << input_sample_dimensionality + << "\n\t prior.weights.nc(): " << prior.weights.nc() + ); + } + + typedef matrix w_type; + w_type weights; + multiclass_svm_problem problem(all_samples, all_labels, df.labels, dims, num_threads); + if (verbose) + problem.be_verbose(); + + problem.set_max_cache_size(0); + problem.set_c(C); + problem.set_epsilon(eps); + problem.set_max_iterations(max_iterations); + + unsigned long num_nonnegative = 0; + if (learn_nonnegative_weights) + { + num_nonnegative = problem.get_num_dimensions(); + } + + if (!has_prior()) + { + svm_objective = solver(problem, weights, num_nonnegative); + } + else + { + matrix temp(df.labels.size(),dims); + w_type b(df.labels.size()); + temp = 0; + b = 0; + + const long pad_size = dims-prior.weights.nc(); + // Copy the prior into the temp and b matrices. We have to do this row + // by row copy because the new training data might have new labels we + // haven't seen before and therefore the sizes of these matrices could be + // different. + for (unsigned long i = 0; i < prior.labels.size(); ++i) + { + const long r = std::find(df.labels.begin(), df.labels.end(), prior.labels[i])-df.labels.begin(); + set_rowm(temp,r) = join_rows(rowm(prior.weights,i), zeros_matrix(1,pad_size)); + b(r) = prior.b(i); + } + + const w_type prior_vect = reshape_to_column_vector(join_rows(temp,b)); + svm_objective = solver(problem, weights, prior_vect); + } + + + df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1)); + df.b = colm(reshape(weights, df.labels.size(), dims+1), dims); + return df; + } + + private: + + unsigned long num_threads; + scalar_type C; + scalar_type eps; + unsigned long max_iterations; + bool verbose; + oca solver; + bool learn_nonnegative_weights; + + trained_function_type prior; + }; + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_SVm_MULTICLASS_LINEAR_TRAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/svm_multiclass_linear_trainer_abstract.h b/ml/dlib/dlib/svm/svm_multiclass_linear_trainer_abstract.h new file mode 100644 index 000000000..6561ce7b2 --- /dev/null +++ b/ml/dlib/dlib/svm/svm_multiclass_linear_trainer_abstract.h @@ -0,0 +1,275 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_MULTICLASS_LINEAR_TRAINER_ABSTRACT_Hh_ +#ifdef DLIB_SVm_MULTICLASS_LINEAR_TRAINER_ABSTRACT_Hh_ + +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "sparse_kernel_abstract.h" +#include "../optimization/optimization_oca_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename label_type_ = typename K::scalar_type + > + class svm_multiclass_linear_trainer + { + /*! + REQUIREMENTS ON K + Is either linear_kernel or sparse_linear_kernel. + + REQUIREMENTS ON label_type_ + label_type_ must be default constructable, copyable, and comparable using + operator < and ==. It must also be possible to write it to an std::ostream + using operator<<. + + INITIAL VALUE + - get_num_threads() == 4 + - learns_nonnegative_weights() == false + - get_epsilon() == 0.001 + - get_max_iterations() == 10000 + - get_c() == 1 + - this object will not be verbose unless be_verbose() is called + - #get_oca() == oca() (i.e. an instance of oca with default parameters) + - has_prior() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a tool for training a multiclass support + vector machine. It is optimized for the case where linear kernels + are used. + !*/ + + public: + typedef label_type_ label_type; + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef multiclass_linear_decision_function trained_function_type; + + svm_multiclass_linear_trainer ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer + to execute. + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + unsigned long get_max_iterations ( + ); + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a + user can observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the SVM problem. + !*/ + + void set_num_threads ( + unsigned long num + ); + /*! + ensures + - #get_num_threads() == num + !*/ + + unsigned long get_num_threads ( + ) const; + /*! + ensures + - returns the number of threads used during training. You should + usually set this equal to the number of processing cores on your + machine. + !*/ + + const kernel_type get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object. Since + the linear kernels don't have any parameters this function just + returns kernel_type() + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c() == C + !*/ + + const scalar_type get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter that + determines the trade off between trying to fit the training data + exactly or allowing more errors but hopefully improving the + generalization of the resulting classifier. Larger values encourage + exact fitting while smaller values of C may encourage better + generalization. + !*/ + + bool learns_nonnegative_weights ( + ) const; + /*! + ensures + - The output of training is a set of weights and bias values that together + define the behavior of a multiclass_linear_decision_function object. If + learns_nonnegative_weights() == true then the resulting weights and bias + values will always have non-negative values. That is, if this function + returns true then all the numbers in the multiclass_linear_decision_function + objects output by train() will be non-negative. + !*/ + + void set_learns_nonnegative_weights ( + bool value + ); + /*! + ensures + - #learns_nonnegative_weights() == value + - if (value == true) then + - #has_prior() == false + !*/ + + void set_prior ( + const trained_function_type& prior + ); + /*! + ensures + - Subsequent calls to train() will try to learn a function similar to the + given prior. + - #has_prior() == true + - #learns_nonnegative_weights() == false + !*/ + + bool has_prior ( + ) const + /*! + ensures + - returns true if a prior has been set and false otherwise. Having a prior + set means that you have called set_prior() and supplied a previously + trained function as a reference. In this case, any call to train() will + try to learn a function that matches the behavior of the prior as close + as possible but also fits the supplied training data. In more technical + detail, having a prior means we replace the ||w||^2 regularizer with one + of the form ||w-prior||^2 where w is the set of parameters for a learned + function. + !*/ + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const; + /*! + requires + - is_learning_problem(all_samples, all_labels) + - All the vectors in all_samples must have the same dimensionality. + - if (has_prior()) then + - The vectors in all_samples must have the same dimensionality as the + vectors used to train the prior given to set_prior(). + ensures + - trains a multiclass SVM to solve the given multiclass classification problem. + - returns a multiclass_linear_decision_function F with the following properties: + - if (new_x is a sample predicted to have a label of L) then + - F(new_x) == L + - F.get_labels() == select_all_distinct_labels(all_labels) + - F.number_of_classes() == select_all_distinct_labels(all_labels).size() + !*/ + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels, + scalar_type& svm_objective + ) const; + /*! + requires + - is_learning_problem(all_samples, all_labels) + - All the vectors in all_samples must have the same dimensionality. + - if (has_prior()) then + - The vectors in all_samples must have the same dimensionality as the + vectors used to train the prior given to set_prior(). + ensures + - trains a multiclass SVM to solve the given multiclass classification problem. + - returns a multiclass_linear_decision_function F with the following properties: + - if (new_x is a sample predicted to have a label of L) then + - F(new_x) == L + - F.get_labels() == select_all_distinct_labels(all_labels) + - F.number_of_classes() == select_all_distinct_labels(all_labels).size() + - #svm_objective == the final value of the SVM objective function + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_SVm_MULTICLASS_LINEAR_TRAINER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/svm_nu_trainer.h b/ml/dlib/dlib/svm/svm_nu_trainer.h new file mode 100644 index 000000000..1e89d6efa --- /dev/null +++ b/ml/dlib/dlib/svm/svm_nu_trainer.h @@ -0,0 +1,326 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_NU_TRAINER_Hh_ +#define DLIB_SVm_NU_TRAINER_Hh_ + +//#include "local/make_label_kernel_matrix.h" + +#include "svm_nu_trainer_abstract.h" +#include +#include +#include +#include "../matrix.h" +#include "../algs.h" +#include "../serialize.h" + +#include "function.h" +#include "kernel.h" +#include "../optimization/optimization_solve_qp2_using_smo.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_nu_trainer + { + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_nu_trainer ( + ) : + nu(0.1), + cache_size(200), + eps(0.001) + { + } + + svm_nu_trainer ( + const kernel_type& kernel_, + const scalar_type& nu_ + ) : + kernel_function(kernel_), + nu(nu_), + cache_size(200), + eps(0.001) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < nu && nu <= 1, + "\tsvm_nu_trainer::svm_nu_trainer(kernel,nu)" + << "\n\t invalid inputs were given to this function" + << "\n\t nu: " << nu + ); + } + + void set_cache_size ( + long cache_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(cache_size_ > 0, + "\tvoid svm_nu_trainer::set_cache_size(cache_size_)" + << "\n\t invalid inputs were given to this function" + << "\n\t cache_size: " << cache_size_ + ); + cache_size = cache_size_; + } + + long get_cache_size ( + ) const + { + return cache_size; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\tvoid svm_nu_trainer::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps: " << eps_ + ); + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const + { + return eps; + } + + void set_kernel ( + const kernel_type& k + ) + { + kernel_function = k; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel_function; + } + + void set_nu ( + scalar_type nu_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < nu_ && nu_ <= 1, + "\tvoid svm_nu_trainer::set_nu(nu_)" + << "\n\t invalid inputs were given to this function" + << "\n\t nu: " << nu_ + ); + nu = nu_; + } + + const scalar_type get_nu ( + ) const + { + return nu; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + return do_train(mat(x), mat(y)); + } + + void swap ( + svm_nu_trainer& item + ) + { + exchange(kernel_function, item.kernel_function); + exchange(nu, item.nu); + exchange(cache_size, item.cache_size); + exchange(eps, item.eps); + } + + private: + + // ------------------------------------------------------------------------------------ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + typedef typename K::scalar_type scalar_type; + typedef typename decision_function::sample_vector_type sample_vector_type; + typedef typename decision_function::scalar_vector_type scalar_vector_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true, + "\tdecision_function svm_nu_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y) + ); + + + scalar_vector_type alpha; + + solve_qp2_using_smo solver; + + solver(symmetric_matrix_cache((diagm(y)*kernel_matrix(kernel_function,x)*diagm(y)), cache_size), + //solver(symmetric_matrix_cache(make_label_kernel_matrix(kernel_matrix(kernel_function,x),y), cache_size), + y, + nu, + alpha, + eps); + + scalar_type rho, b; + calculate_rho_and_b(y,alpha,solver.get_gradient(),rho,b); + alpha = pointwise_multiply(alpha,y)/rho; + + // count the number of support vectors + const long sv_count = (long)sum(alpha != 0); + + scalar_vector_type sv_alpha; + sample_vector_type support_vectors; + + // size these column vectors so that they have an entry for each support vector + sv_alpha.set_size(sv_count); + support_vectors.set_size(sv_count); + + // load the support vectors and their alpha values into these new column matrices + long idx = 0; + for (long i = 0; i < alpha.nr(); ++i) + { + if (alpha(i) != 0) + { + sv_alpha(idx) = alpha(i); + support_vectors(idx) = x(i); + ++idx; + } + } + + // now return the decision function + return decision_function (sv_alpha, b, kernel_function, support_vectors); + } + + // ------------------------------------------------------------------------------------ + + template < + typename scalar_vector_type, + typename scalar_vector_type2, + typename scalar_type + > + void calculate_rho_and_b( + const scalar_vector_type2& y, + const scalar_vector_type& alpha, + const scalar_vector_type& df, + scalar_type& rho, + scalar_type& b + ) const + { + using namespace std; + long num_p_free = 0; + long num_n_free = 0; + scalar_type sum_p_free = 0; + scalar_type sum_n_free = 0; + + scalar_type upper_bound_p = -numeric_limits::infinity(); + scalar_type upper_bound_n = -numeric_limits::infinity(); + scalar_type lower_bound_p = numeric_limits::infinity(); + scalar_type lower_bound_n = numeric_limits::infinity(); + + for(long i = 0; i < alpha.nr(); ++i) + { + if(y(i) == 1) + { + if(alpha(i) == 1) + { + if (df(i) > upper_bound_p) + upper_bound_p = df(i); + } + else if(alpha(i) == 0) + { + if (df(i) < lower_bound_p) + lower_bound_p = df(i); + } + else + { + ++num_p_free; + sum_p_free += df(i); + } + } + else + { + if(alpha(i) == 1) + { + if (df(i) > upper_bound_n) + upper_bound_n = df(i); + } + else if(alpha(i) == 0) + { + if (df(i) < lower_bound_n) + lower_bound_n = df(i); + } + else + { + ++num_n_free; + sum_n_free += df(i); + } + } + } + + scalar_type r1,r2; + if(num_p_free > 0) + r1 = sum_p_free/num_p_free; + else + r1 = (upper_bound_p+lower_bound_p)/2; + + if(num_n_free > 0) + r2 = sum_n_free/num_n_free; + else + r2 = (upper_bound_n+lower_bound_n)/2; + + rho = (r1+r2)/2; + b = (r1-r2)/2/rho; + } + + // ------------------------------------------------------------------------------------ + + kernel_type kernel_function; + scalar_type nu; + long cache_size; + scalar_type eps; + }; // end of class svm_nu_trainer + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + svm_nu_trainer& a, + svm_nu_trainer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_NU_TRAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/svm_nu_trainer_abstract.h b/ml/dlib/dlib/svm/svm_nu_trainer_abstract.h new file mode 100644 index 000000000..5ae0fba4a --- /dev/null +++ b/ml/dlib/dlib/svm/svm_nu_trainer_abstract.h @@ -0,0 +1,210 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_NU_TRAINER_ABSTRACT_ +#ifdef DLIB_SVm_NU_TRAINER_ABSTRACT_ + +#include +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "../optimization/optimization_solve_qp2_using_smo_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_nu_trainer + { + /*! + REQUIREMENTS ON K + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object implements a trainer for a nu support vector machine for + solving binary classification problems. It is implemented using the SMO + algorithm. + + The implementation of the nu-svm training algorithm used by this object is based + on the following excellent papers: + - Chang and Lin, Training {nu}-Support Vector Classifiers: Theory and Algorithms + - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector + machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm + + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_nu_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_nu() == 0.1 + - #get_cache_size() == 200 + - #get_epsilon() == 0.001 + !*/ + + svm_nu_trainer ( + const kernel_type& kernel, + const scalar_type& nu + ); + /*! + requires + - 0 < nu <= 1 + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_kernel() == kernel + - #get_nu() == nu + - #get_cache_size() == 200 + - #get_epsilon() == 0.001 + !*/ + + void set_cache_size ( + long cache_size + ); + /*! + requires + - cache_size > 0 + ensures + - #get_cache_size() == cache_size + !*/ + + const long get_cache_size ( + ) const; + /*! + ensures + - returns the number of megabytes of cache this object will use + when it performs training via the this->train() function. + (bigger values of this may make training go faster but won't affect + the result. However, too big a value will cause you to run out of + memory, obviously.) + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Generally a good value for this is 0.001. Smaller values may result + in a more accurate solution but take longer to execute. + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + void set_nu ( + scalar_type nu + ); + /*! + requires + - 0 < nu <= 1 + ensures + - #get_nu() == nu + !*/ + + const scalar_type get_nu ( + ) const; + /*! + ensures + - returns the nu svm parameter. This is a value between 0 and + 1. It is the parameter that determines the trade off between + trying to fit the training data exactly or allowing more errors + but hopefully improving the generalization ability of the + resulting classifier. Smaller values encourage exact fitting + while larger values of nu may encourage better generalization. + For more information you should consult the papers referenced + above. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - is_binary_classification_problem(x,y) == true + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - trains a nu support vector classifier given the training samples in x and + labels in y. Training is done when the error is less than get_epsilon(). + - returns a decision function F with the following properties: + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + throws + - invalid_nu_error + This exception is thrown if get_nu() >= maximum_nu(y) + - std::bad_alloc + !*/ + + void swap ( + svm_nu_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + + template + void swap ( + svm_nu_trainer& a, + svm_nu_trainer& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_NU_TRAINER_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/svm_one_class_trainer.h b/ml/dlib/dlib/svm/svm_one_class_trainer.h new file mode 100644 index 000000000..be3cc8caf --- /dev/null +++ b/ml/dlib/dlib/svm/svm_one_class_trainer.h @@ -0,0 +1,284 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_ONE_CLASS_TRAINER_Hh_ +#define DLIB_SVm_ONE_CLASS_TRAINER_Hh_ + +#include "svm_one_class_trainer_abstract.h" +#include +#include +#include +#include "../matrix.h" +#include "../algs.h" + +#include "function.h" +#include "kernel.h" +#include "../optimization/optimization_solve_qp3_using_smo.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_one_class_trainer + { + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_one_class_trainer ( + ) : + nu(0.1), + cache_size(200), + eps(0.001) + { + } + + svm_one_class_trainer ( + const kernel_type& kernel_, + const scalar_type& nu_ + ) : + kernel_function(kernel_), + nu(nu_), + cache_size(200), + eps(0.001) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < nu && nu <= 1, + "\tsvm_one_class_trainer::svm_one_class_trainer(kernel,nu)" + << "\n\t invalid inputs were given to this function" + << "\n\t nu: " << nu + ); + } + + void set_cache_size ( + long cache_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(cache_size_ > 0, + "\tvoid svm_one_class_trainer::set_cache_size(cache_size_)" + << "\n\t invalid inputs were given to this function" + << "\n\t cache_size: " << cache_size_ + ); + cache_size = cache_size_; + } + + long get_cache_size ( + ) const + { + return cache_size; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\tvoid svm_one_class_trainer::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps: " << eps_ + ); + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const + { + return eps; + } + + void set_kernel ( + const kernel_type& k + ) + { + kernel_function = k; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel_function; + } + + void set_nu ( + scalar_type nu_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < nu_ && nu_ <= 1, + "\tvoid svm_one_class_trainer::set_nu(nu_)" + << "\n\t invalid inputs were given to this function" + << "\n\t nu: " << nu_ + ); + nu = nu_; + } + + const scalar_type get_nu ( + ) const + { + return nu; + } + + template < + typename in_sample_vector_type + > + const decision_function train ( + const in_sample_vector_type& x + ) const + { + return do_train(mat(x)); + } + + void swap ( + svm_one_class_trainer& item + ) + { + exchange(kernel_function, item.kernel_function); + exchange(nu, item.nu); + exchange(cache_size, item.cache_size); + exchange(eps, item.eps); + } + + private: + + // ------------------------------------------------------------------------------------ + + template < + typename in_sample_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x + ) const + { + typedef typename K::scalar_type scalar_type; + typedef typename decision_function::sample_vector_type sample_vector_type; + typedef typename decision_function::scalar_vector_type scalar_vector_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(x) && x.size() > 0, + "\tdecision_function svm_one_class_trainer::train(x)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t x.nc(): " << x.nc() + ); + + + scalar_vector_type alpha; + + solve_qp3_using_smo solver; + + solver(symmetric_matrix_cache(kernel_matrix(kernel_function,x), cache_size), + zeros_matrix(x.size(),1), + ones_matrix(x.size(),1), + nu*x.size(), + 1, + 1, + alpha, + eps); + + scalar_type rho; + calculate_rho(alpha,solver.get_gradient(),rho); + + + // count the number of support vectors + const long sv_count = (long)sum(alpha != 0); + + scalar_vector_type sv_alpha; + sample_vector_type support_vectors; + + // size these column vectors so that they have an entry for each support vector + sv_alpha.set_size(sv_count); + support_vectors.set_size(sv_count); + + // load the support vectors and their alpha values into these new column matrices + long idx = 0; + for (long i = 0; i < alpha.nr(); ++i) + { + if (alpha(i) != 0) + { + sv_alpha(idx) = alpha(i); + support_vectors(idx) = x(i); + ++idx; + } + } + + // now return the decision function + return decision_function (sv_alpha, rho, kernel_function, support_vectors); + } + + // ------------------------------------------------------------------------------------ + + template < + typename scalar_vector_type + > + void calculate_rho( + const scalar_vector_type& alpha, + const scalar_vector_type& df, + scalar_type& rho + ) const + { + using namespace std; + long num_p_free = 0; + scalar_type sum_p_free = 0; + + + scalar_type upper_bound_p; + scalar_type lower_bound_p; + + find_min_and_max(df, upper_bound_p, lower_bound_p); + + for(long i = 0; i < alpha.nr(); ++i) + { + if(alpha(i) == 1) + { + if (df(i) > upper_bound_p) + upper_bound_p = df(i); + } + else if(alpha(i) == 0) + { + if (df(i) < lower_bound_p) + lower_bound_p = df(i); + } + else + { + ++num_p_free; + sum_p_free += df(i); + } + } + + scalar_type r1; + if(num_p_free > 0) + r1 = sum_p_free/num_p_free; + else + r1 = (upper_bound_p+lower_bound_p)/2; + + rho = r1; + } + + kernel_type kernel_function; + scalar_type nu; + long cache_size; + scalar_type eps; + }; // end of class svm_one_class_trainer + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + svm_one_class_trainer& a, + svm_one_class_trainer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_ONE_CLASS_TRAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/svm_one_class_trainer_abstract.h b/ml/dlib/dlib/svm/svm_one_class_trainer_abstract.h new file mode 100644 index 000000000..6b55919ad --- /dev/null +++ b/ml/dlib/dlib/svm/svm_one_class_trainer_abstract.h @@ -0,0 +1,201 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_ONE_CLASS_TRAINER_ABSTRACT_ +#ifdef DLIB_SVm_ONE_CLASS_TRAINER_ABSTRACT_ + +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "../optimization/optimization_solve_qp3_using_smo_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_one_class_trainer + { + /*! + REQUIREMENTS ON K + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object implements a trainer for a support vector machine for + solving one-class classification problems. It is implemented using the SMO + algorithm. + + The implementation of the training algorithm used by this object is based + on the following excellent paper: + - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector + machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm + + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_one_class_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_nu() == 0.1 + - #get_cache_size() == 200 + - #get_epsilon() == 0.001 + !*/ + + svm_one_class_trainer ( + const kernel_type& kernel, + const scalar_type& nu + ); + /*! + requires + - 0 < nu <= 1 + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_kernel() == kernel + - #get_nu() == nu + - #get_cache_size() == 200 + - #get_epsilon() == 0.001 + !*/ + + void set_cache_size ( + long cache_size + ); + /*! + requires + - cache_size > 0 + ensures + - #get_cache_size() == cache_size + !*/ + + const long get_cache_size ( + ) const; + /*! + ensures + - returns the number of megabytes of cache this object will use + when it performs training via the this->train() function. + (bigger values of this may make training go faster but won't affect + the result. However, too big a value will cause you to run out of + memory, obviously.) + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Generally a good value for this is 0.001. Smaller values may result + in a more accurate solution but take longer to execute. + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + void set_nu ( + scalar_type nu + ); + /*! + requires + - 0 < nu <= 1 + ensures + - #get_nu() == nu + !*/ + + const scalar_type get_nu ( + ) const; + /*! + ensures + - returns the nu svm parameter. This is a value between 0 and + 1. It is the parameter that determines the trade off between + trying to fit the training data exactly or allowing more errors + but hopefully improving the generalization ability of the + resulting classifier. Smaller values encourage exact fitting + while larger values of nu may encourage better generalization. + For more information you should consult the papers referenced + above. + !*/ + + template < + typename in_sample_vector_type + > + const decision_function train ( + const in_sample_vector_type& x + ) const; + /*! + requires + - x.size() > 0 + - is_col_vector(x) == true + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + ensures + - trains a one-class support vector classifier given the training samples in x. + Training is done when the error is less than get_epsilon(). + - returns a decision function F with the following properties: + - if (new_x is a sample predicted to arise from the distribution + which generated the training samples) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + + void swap ( + svm_one_class_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + + template + void swap ( + svm_one_class_trainer& a, + svm_one_class_trainer& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_ONE_CLASS_TRAINER_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/svm_rank_trainer.h b/ml/dlib/dlib/svm/svm_rank_trainer.h new file mode 100644 index 000000000..0be737f48 --- /dev/null +++ b/ml/dlib/dlib/svm/svm_rank_trainer.h @@ -0,0 +1,495 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVM_RANK_TrAINER_Hh_ +#define DLIB_SVM_RANK_TrAINER_Hh_ + +#include "svm_rank_trainer_abstract.h" + +#include "ranking_tools.h" +#include "../algs.h" +#include "../optimization.h" +#include "function.h" +#include "kernel.h" +#include "sparse_vector.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type, + typename sample_type + > + class oca_problem_ranking_svm : public oca_problem + { + public: + /* + This class is used as part of the implementation of the svm_rank_trainer + defined towards the end of this file. + */ + + typedef typename matrix_type::type scalar_type; + + oca_problem_ranking_svm( + const scalar_type C_, + const std::vector >& samples_, + const bool be_verbose_, + const scalar_type eps_, + const unsigned long max_iter, + const unsigned long dims_ + ) : + samples(samples_), + C(C_), + be_verbose(be_verbose_), + eps(eps_), + max_iterations(max_iter), + dims(dims_) + { + } + + virtual scalar_type get_c ( + ) const + { + return C; + } + + virtual long get_num_dimensions ( + ) const + { + return dims; + } + + virtual bool optimization_status ( + scalar_type current_objective_value, + scalar_type current_error_gap, + scalar_type current_risk_value, + scalar_type current_risk_gap, + unsigned long num_cutting_planes, + unsigned long num_iterations + ) const + { + if (be_verbose) + { + using namespace std; + cout << "objective: " << current_objective_value << endl; + cout << "objective gap: " << current_error_gap << endl; + cout << "risk: " << current_risk_value << endl; + cout << "risk gap: " << current_risk_gap << endl; + cout << "num planes: " << num_cutting_planes << endl; + cout << "iter: " << num_iterations << endl; + cout << endl; + } + + if (num_iterations >= max_iterations) + return true; + + if (current_risk_gap < eps) + return true; + + return false; + } + + virtual bool risk_has_lower_bound ( + scalar_type& lower_bound + ) const + { + lower_bound = 0; + return true; + } + + virtual void get_risk ( + matrix_type& w, + scalar_type& risk, + matrix_type& subgradient + ) const + { + subgradient.set_size(w.size(),1); + subgradient = 0; + risk = 0; + + // Note that we want the risk value to be in terms of the fraction of overall + // rank flips. So a risk of 0.1 would mean that rank flips happen < 10% of the + // time. + + + std::vector rel_scores; + std::vector nonrel_scores; + std::vector rel_counts; + std::vector nonrel_counts; + + unsigned long total_pairs = 0; + + // loop over all the samples and compute the risk and its subgradient at the current solution point w + for (unsigned long i = 0; i < samples.size(); ++i) + { + rel_scores.resize(samples[i].relevant.size()); + nonrel_scores.resize(samples[i].nonrelevant.size()); + + for (unsigned long k = 0; k < rel_scores.size(); ++k) + rel_scores[k] = dot(samples[i].relevant[k], w); + + for (unsigned long k = 0; k < nonrel_scores.size(); ++k) + nonrel_scores[k] = dot(samples[i].nonrelevant[k], w) + 1; + + count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts); + + total_pairs += rel_scores.size()*nonrel_scores.size(); + + for (unsigned long k = 0; k < rel_counts.size(); ++k) + { + if (rel_counts[k] != 0) + { + risk -= rel_counts[k]*rel_scores[k]; + subtract_from(subgradient, samples[i].relevant[k], rel_counts[k]); + } + } + + for (unsigned long k = 0; k < nonrel_counts.size(); ++k) + { + if (nonrel_counts[k] != 0) + { + risk += nonrel_counts[k]*nonrel_scores[k]; + add_to(subgradient, samples[i].nonrelevant[k], nonrel_counts[k]); + } + } + + } + + const scalar_type scale = 1.0/total_pairs; + + risk *= scale; + subgradient = scale*subgradient; + } + + private: + + // ----------------------------------------------------- + // ----------------------------------------------------- + + + const std::vector >& samples; + const scalar_type C; + + const bool be_verbose; + const scalar_type eps; + const unsigned long max_iterations; + const unsigned long dims; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type, + typename sample_type, + typename scalar_type + > + oca_problem_ranking_svm make_oca_problem_ranking_svm ( + const scalar_type C, + const std::vector >& samples, + const bool be_verbose, + const scalar_type eps, + const unsigned long max_iterations, + const unsigned long dims + ) + { + return oca_problem_ranking_svm( + C, samples, be_verbose, eps, max_iterations, dims); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_rank_trainer + { + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + // You are getting a compiler error on this line because you supplied a non-linear kernel + // to the svm_rank_trainer object. You have to use one of the linear kernels with this + // trainer. + COMPILE_TIME_ASSERT((is_same_type >::value || + is_same_type >::value )); + + svm_rank_trainer ( + ) + { + C = 1; + verbose = false; + eps = 0.001; + max_iterations = 10000; + learn_nonnegative_weights = false; + last_weight_1 = false; + } + + explicit svm_rank_trainer ( + const scalar_type& C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t svm_rank_trainer::svm_rank_trainer()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + verbose = false; + eps = 0.001; + max_iterations = 10000; + learn_nonnegative_weights = false; + last_weight_1 = false; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void svm_rank_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const { return eps; } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + bool forces_last_weight_to_1 ( + ) const + { + return last_weight_1; + } + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ) + { + last_weight_1 = should_last_weight_be_1; + if (last_weight_1) + prior.set_size(0); + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + const kernel_type get_kernel ( + ) const + { + return kernel_type(); + } + + bool learns_nonnegative_weights ( + ) const { return learn_nonnegative_weights; } + + void set_learns_nonnegative_weights ( + bool value + ) + { + learn_nonnegative_weights = value; + if (learn_nonnegative_weights) + prior.set_size(0); + } + + void set_prior ( + const trained_function_type& prior_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(prior_.basis_vectors.size() == 1 && + prior_.alpha(0) == 1, + "\t void svm_rank_trainer::set_prior()" + << "\n\t The supplied prior could not have been created by this object's train() method." + << "\n\t prior_.basis_vectors.size(): " << prior_.basis_vectors.size() + << "\n\t prior_.alpha(0): " << prior_.alpha(0) + << "\n\t this: " << this + ); + + prior = sparse_to_dense(prior_.basis_vectors(0)); + learn_nonnegative_weights = false; + last_weight_1 = false; + } + + bool has_prior ( + ) const + { + return prior.size() != 0; + } + + void set_c ( + scalar_type C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void svm_rank_trainer::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + const scalar_type get_c ( + ) const + { + return C; + } + + const decision_function train ( + const std::vector >& samples + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(is_ranking_problem(samples) == true, + "\t decision_function svm_rank_trainer::train(samples)" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples) + ); + + + typedef matrix w_type; + w_type w; + + const unsigned long num_dims = max_index_plus_one(samples); + + unsigned long num_nonnegative = 0; + if (learn_nonnegative_weights) + { + num_nonnegative = num_dims; + } + + unsigned long force_weight_1_idx = std::numeric_limits::max(); + if (last_weight_1) + { + force_weight_1_idx = num_dims-1; + } + + if (has_prior()) + { + if (is_matrix::value) + { + // make sure requires clause is not broken + DLIB_CASSERT(num_dims == (unsigned long)prior.size(), + "\t decision_function svm_rank_trainer::train(samples)" + << "\n\t The dimension of the training vectors must match the dimension of\n" + << "\n\t those used to create the prior." + << "\n\t num_dims: " << num_dims + << "\n\t prior.size(): " << prior.size() + ); + } + const unsigned long dims = std::max(num_dims, (unsigned long)prior.size()); + // In the case of sparse sample vectors, it is possible that the input + // vector dimensionality is larger than the prior vector dimensionality. + // We need to check for this case and pad prior with zeros if it is the + // case. + if ((unsigned long)prior.size() < dims) + { + matrix prior_temp = join_cols(prior, zeros_matrix(dims-prior.size(),1)); + solver( make_oca_problem_ranking_svm(C, samples, verbose, eps, max_iterations, dims), + w, + prior_temp); + } + else + { + solver( make_oca_problem_ranking_svm(C, samples, verbose, eps, max_iterations, dims), + w, + prior); + } + + } + else + { + solver( make_oca_problem_ranking_svm(C, samples, verbose, eps, max_iterations, num_dims), + w, + num_nonnegative, + force_weight_1_idx); + } + + + // put the solution into a decision function and then return it + decision_function df; + df.b = 0; + df.basis_vectors.set_size(1); + // Copy the results into the output basis vector. The output vector might be a + // sparse vector container so we need to use this special kind of copy to + // handle that case. + assign(df.basis_vectors(0), matrix_cast(w)); + df.alpha.set_size(1); + df.alpha(0) = 1; + + return df; + } + + const decision_function train ( + const ranking_pair& sample + ) const + { + return train(std::vector >(1, sample)); + } + + private: + + scalar_type C; + oca solver; + scalar_type eps; + bool verbose; + unsigned long max_iterations; + bool learn_nonnegative_weights; + bool last_weight_1; + matrix prior; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVM_RANK_TrAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/svm_rank_trainer_abstract.h b/ml/dlib/dlib/svm/svm_rank_trainer_abstract.h new file mode 100644 index 000000000..4658d950f --- /dev/null +++ b/ml/dlib/dlib/svm/svm_rank_trainer_abstract.h @@ -0,0 +1,298 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVM_RANK_TrAINER_ABSTRACT_Hh_ +#ifdef DLIB_SVM_RANK_TrAINER_ABSTRACT_Hh_ + +#include "ranking_tools_abstract.h" +#include "sparse_vector_abstract.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svm_rank_trainer + { + /*! + REQUIREMENTS ON K + Is either linear_kernel or sparse_linear_kernel. + + WHAT THIS OBJECT REPRESENTS + This object represents a tool for training a ranking support vector machine + using linear kernels. In particular, this object is a tool for training + the Ranking SVM described in the paper: + Optimizing Search Engines using Clickthrough Data by Thorsten Joachims + + Note that we normalize the C parameter by multiplying it by 1/(number of ranking pairs). + Therefore, to make an exact comparison between this object and Equation 12 + in the paper you must multiply C by the appropriate normalizing quantity. + + Finally, note that the implementation of this object is done using the oca + optimizer and count_ranking_inversions() method. This means that it runs + in O(n*log(n)) time, making it suitable for use with large datasets. + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svm_rank_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used to train a + ranking support vector machine. + - #get_oca() == oca() (i.e. an instance of oca with default parameters) + - #get_c() == 1 + - #get_epsilon() == 0.001 + - this object will not be verbose unless be_verbose() is called + - #get_max_iterations() == 10000 + - #learns_nonnegative_weights() == false + - #forces_last_weight_to_1() == false + - #has_prior() == false + !*/ + + explicit svm_rank_trainer ( + const scalar_type& C + ); + /*! + requires + - C > 0 + ensures + - This object is properly initialized and ready to be used to train a + ranking support vector machine. + - #get_oca() == oca() (i.e. an instance of oca with default parameters) + - #get_c() == C + - #get_epsilon() == 0.001 + - this object will not be verbose unless be_verbose() is called + - #get_max_iterations() == 10000 + - #learns_nonnegative_weights() == false + - #forces_last_weight_to_1() == false + - #has_prior() == false + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ); + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer to + train. You can think of this epsilon value as saying "solve the + optimization problem until the average ranking accuracy is within epsilon + of its optimal value". Here we mean "ranking accuracy" in the same sense + used by test_ranking_function() and cross_validate_ranking_trainer(). + !*/ + + unsigned long get_max_iterations ( + ) const; + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a user can + observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + bool forces_last_weight_to_1 ( + ) const; + /*! + ensures + - returns true if this trainer has the constraint that the last weight in + the learned parameter vector must be 1. This is the weight corresponding + to the feature in the training vectors with the highest dimension. + !*/ + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ); + /*! + ensures + - #forces_last_weight_to_1() == should_last_weight_be_1 + - if (should_last_weight_be_1 == true) then + - #has_prior() == false + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the SVM problem. + !*/ + + const kernel_type get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object. Since the + linear kernels don't have any parameters this function just returns + kernel_type() + !*/ + + bool learns_nonnegative_weights ( + ) const; + /*! + ensures + - The output of training is a weight vector that defines the behavior of + the resulting decision function. That is, the decision function simply + takes the dot product between the learned weight vector and a test sample + and returns the result. Therefore, if learns_nonnegative_weights() == true + then the resulting learned weight vector will always have non-negative + entries. + !*/ + + void set_learns_nonnegative_weights ( + bool value + ); + /*! + ensures + - #learns_nonnegative_weights() == value + - if (value == true) then + - #has_prior() == false + !*/ + + void set_prior ( + const trained_function_type& prior + ); + /*! + requires + - prior == a function produced by a call to this class's train() function. + Therefore, it must be the case that: + - prior.basis_vectors.size() == 1 + - prior.alpha(0) == 1 + ensures + - Subsequent calls to train() will try to learn a function similar to the + given prior. + - #has_prior() == true + - #learns_nonnegative_weights() == false + - #forces_last_weight_to_1() == false + !*/ + + bool has_prior ( + ) const + /*! + ensures + - returns true if a prior has been set and false otherwise. Having a prior + set means that you have called set_prior() and supplied a previously + trained function as a reference. In this case, any call to train() will + try to learn a function that matches the behavior of the prior as close + as possible but also fits the supplied training data. In more technical + detail, having a prior means we replace the ||w||^2 regularizer with one + of the form ||w-prior||^2 where w is the set of parameters for a learned + function. + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c() == C + !*/ + + const scalar_type get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter that + determines the trade off between trying to fit the training data exactly + or allowing more errors but hopefully improving the generalization of the + resulting classifier. Larger values encourage exact fitting while + smaller values of C may encourage better generalization. + !*/ + + const decision_function train ( + const std::vector >& samples + ) const; + /*! + requires + - is_ranking_problem(samples) == true + - if (has_prior()) then + - The vectors in samples must have the same dimensionality as the + vectors used to train the prior given to set_prior(). + ensures + - trains a ranking support vector classifier given the training samples. + - returns a decision function F with the following properties: + - F.alpha.size() == 1 + - F.basis_vectors.size() == 1 + - F.alpha(0) == 1 + - Given two vectors, A and B, then A is predicted to come before B + in the learned ranking if and only if F(A) > F(B). + - Based on the contents of samples, F will attempt to give relevant + vectors higher scores than non-relevant vectors. + !*/ + + const decision_function train ( + const ranking_pair& sample + ) const; + /*! + requires + - is_ranking_problem(std::vector >(1, sample)) == true + - if (has_prior()) then + - The vectors in samples must have the same dimensionality as the + vectors used to train the prior given to set_prior(). + ensures + - This is just a convenience routine for calling the above train() + function. That is, it just copies sample into a std::vector object and + invokes the above train() method. This means that calling this function + is equivalent to invoking: + return train(std::vector >(1, sample)); + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVM_RANK_TrAINER_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/svm/svm_threaded.h b/ml/dlib/dlib/svm/svm_threaded.h new file mode 100644 index 000000000..37927456b --- /dev/null +++ b/ml/dlib/dlib/svm/svm_threaded.h @@ -0,0 +1,253 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_THREADED_ +#define DLIB_SVm_THREADED_ + +#include +#include +#include +#include +#include + +#include "svm_threaded_abstract.h" +#include "svm.h" +#include "../matrix.h" +#include "../algs.h" +#include "../serialize.h" +#include "function.h" +#include "kernel.h" +#include "../threads.h" +#include "../pipe.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace cvtti_helpers + { + template + struct job + { + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef matrix sample_vector_type; + typedef matrix scalar_vector_type; + + job() : x(0) {} + + trainer_type trainer; + matrix x_test, x_train; + scalar_vector_type y_test, y_train; + const in_sample_vector_type* x; + }; + + struct task + { + template < + typename trainer_type, + typename mem_manager_type, + typename in_sample_vector_type + > + void operator()( + job& j, + matrix& result + ) + { + try + { + result = test_binary_decision_function(j.trainer.train(rowm(*j.x,j.x_train), j.y_train), rowm(*j.x,j.x_test), j.y_test); + + // Do this just to make j release it's memory since people might run threaded cross validation + // on very large datasets. Every bit of freed memory helps out. + j = job(); + } + catch (invalid_nu_error&) + { + // If this is a svm_nu_trainer then we might get this exception if the nu is + // invalid. In this case just return a cross validation score of 0. + result = 0; + } + catch (std::bad_alloc&) + { + std::cerr << "\nstd::bad_alloc thrown while running cross_validate_trainer_threaded(). Not enough memory.\n" << std::endl; + throw; + } + } + }; + } + + template < + typename trainer_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix + cross_validate_trainer_threaded_impl ( + const trainer_type& trainer, + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + const long folds, + const long num_threads + ) + { + using namespace dlib::cvtti_helpers; + typedef typename trainer_type::mem_manager_type mem_manager_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true && + 1 < folds && folds <= std::min(sum(y>0),sum(y<0)) && + num_threads > 0, + "\tmatrix cross_validate_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t std::min(sum(y>0),sum(y<0)): " << std::min(sum(y>0),sum(y<0)) + << "\n\t folds: " << folds + << "\n\t num_threads: " << num_threads + << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false") + ); + + + task mytask; + thread_pool tp(num_threads); + + + // count the number of positive and negative examples + long num_pos = 0; + long num_neg = 0; + for (long r = 0; r < y.nr(); ++r) + { + if (y(r) == +1.0) + ++num_pos; + else + ++num_neg; + } + + // figure out how many positive and negative examples we will have in each fold + const long num_pos_test_samples = num_pos/folds; + const long num_pos_train_samples = num_pos - num_pos_test_samples; + const long num_neg_test_samples = num_neg/folds; + const long num_neg_train_samples = num_neg - num_neg_test_samples; + + + long pos_idx = 0; + long neg_idx = 0; + + + + std::vector > > jobs(folds); + std::vector > > results(folds); + + + for (long i = 0; i < folds; ++i) + { + job& j = jobs[i].get(); + + j.x = &x; + j.x_test.set_size (num_pos_test_samples + num_neg_test_samples); + j.y_test.set_size (num_pos_test_samples + num_neg_test_samples); + j.x_train.set_size(num_pos_train_samples + num_neg_train_samples); + j.y_train.set_size(num_pos_train_samples + num_neg_train_samples); + j.trainer = trainer; + + long cur = 0; + + // load up our positive test samples + while (cur < num_pos_test_samples) + { + if (y(pos_idx) == +1.0) + { + j.x_test(cur) = pos_idx; + j.y_test(cur) = +1.0; + ++cur; + } + pos_idx = (pos_idx+1)%x.nr(); + } + + // load up our negative test samples + while (cur < j.x_test.nr()) + { + if (y(neg_idx) == -1.0) + { + j.x_test(cur) = neg_idx; + j.y_test(cur) = -1.0; + ++cur; + } + neg_idx = (neg_idx+1)%x.nr(); + } + + // load the training data from the data following whatever we loaded + // as the testing data + long train_pos_idx = pos_idx; + long train_neg_idx = neg_idx; + cur = 0; + + // load up our positive train samples + while (cur < num_pos_train_samples) + { + if (y(train_pos_idx) == +1.0) + { + j.x_train(cur) = train_pos_idx; + j.y_train(cur) = +1.0; + ++cur; + } + train_pos_idx = (train_pos_idx+1)%x.nr(); + } + + // load up our negative train samples + while (cur < j.x_train.nr()) + { + if (y(train_neg_idx) == -1.0) + { + j.x_train(cur) = train_neg_idx; + j.y_train(cur) = -1.0; + ++cur; + } + train_neg_idx = (train_neg_idx+1)%x.nr(); + } + + // finally spawn a task to process this job + tp.add_task(mytask, jobs[i], results[i]); + + } // for (long i = 0; i < folds; ++i) + + matrix res; + set_all_elements(res,0); + + // now compute the total results + for (long i = 0; i < folds; ++i) + { + res += results[i].get(); + } + + return res/(double)folds; + } + + template < + typename trainer_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix + cross_validate_trainer_threaded ( + const trainer_type& trainer, + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + const long folds, + const long num_threads + ) + { + return cross_validate_trainer_threaded_impl(trainer, + mat(x), + mat(y), + folds, + num_threads); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_THREADED_ + + diff --git a/ml/dlib/dlib/svm/svm_threaded_abstract.h b/ml/dlib/dlib/svm/svm_threaded_abstract.h new file mode 100644 index 000000000..f9973fb5c --- /dev/null +++ b/ml/dlib/dlib/svm/svm_threaded_abstract.h @@ -0,0 +1,62 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_THREADED_ABSTRACT_ +#ifdef DLIB_SVm_THREADED_ABSTRACT_ + +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "../svm.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix + cross_validate_trainer_threaded ( + const trainer_type& trainer, + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + const long folds, + const long num_threads + ); + /*! + requires + - is_binary_classification_problem(x,y) == true + - 1 < folds <= std::min(sum(y>0),sum(y<0)) + (e.g. There must be at least as many examples of each class as there are folds) + - trainer_type == some kind of trainer object (e.g. svm_nu_trainer) + - num_threads > 0 + - It must be safe for multiple trainer objects to access the elements of x from + multiple threads at the same time. Note that all trainers and kernels in + dlib are thread safe in this regard since they do not mutate the elements of x. + ensures + - performs k-fold cross validation by using the given trainer to solve the + given binary classification problem for the given number of folds. + Each fold is tested using the output of the trainer and the average + classification accuracy from all folds is returned. + - uses num_threads threads of execution in doing the cross validation. + - The accuracy is returned in a row vector, let us call it R. Both + quantities in R are numbers between 0 and 1 which represent the fraction + of examples correctly classified. R(0) is the fraction of +1 examples + correctly classified and R(1) is the fraction of -1 examples correctly + classified. + - The number of folds used is given by the folds argument. + throws + - any exceptions thrown by trainer.train() + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_THREADED_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/svr_linear_trainer.h b/ml/dlib/dlib/svm/svr_linear_trainer.h new file mode 100644 index 000000000..27ce5b52a --- /dev/null +++ b/ml/dlib/dlib/svm/svr_linear_trainer.h @@ -0,0 +1,424 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVR_LINEAR_TrAINER_Hh_ +#define DLIB_SVR_LINEAR_TrAINER_Hh_ + +#include "svr_linear_trainer_abstract.h" + +#include "../algs.h" +#include "../optimization.h" +#include "function.h" +#include "kernel.h" +#include "sparse_vector.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type, + typename sample_type + > + class oca_problem_linear_svr : public oca_problem + { + public: + /* + This class is used as part of the implementation of the svr_linear_trainer + defined towards the end of this file. + */ + + typedef typename matrix_type::type scalar_type; + + oca_problem_linear_svr( + const scalar_type C_, + const std::vector& samples_, + const std::vector& targets_, + const bool be_verbose_, + const scalar_type eps_, + const scalar_type eps_insensitivity_, + const unsigned long max_iter + ) : + samples(samples_), + targets(targets_), + C(C_), + be_verbose(be_verbose_), + eps(eps_), + eps_insensitivity(eps_insensitivity_), + max_iterations(max_iter) + { + } + + virtual scalar_type get_c ( + ) const + { + return C; + } + + virtual long get_num_dimensions ( + ) const + { + // plus one for the bias term + return max_index_plus_one(samples) + 1; + } + + virtual bool optimization_status ( + scalar_type current_objective_value, + scalar_type current_error_gap, + scalar_type current_risk_value, + scalar_type current_risk_gap, + unsigned long num_cutting_planes, + unsigned long num_iterations + ) const + { + current_risk_value /= samples.size(); + current_risk_gap /= samples.size(); + if (be_verbose) + { + using namespace std; + cout << "objective: " << current_objective_value << endl; + cout << "objective gap: " << current_error_gap << endl; + cout << "risk: " << current_risk_value << endl; + cout << "risk gap: " << current_risk_gap << endl; + cout << "num planes: " << num_cutting_planes << endl; + cout << "iter: " << num_iterations << endl; + cout << endl; + } + + if (num_iterations >= max_iterations) + return true; + + if (current_risk_gap < eps*eps_insensitivity) + return true; + + return false; + } + + virtual bool risk_has_lower_bound ( + scalar_type& lower_bound + ) const + { + lower_bound = 0; + return true; + } + + virtual void get_risk ( + matrix_type& w, + scalar_type& risk, + matrix_type& subgradient + ) const + { + subgradient.set_size(w.size(),1); + subgradient = 0; + risk = 0; + + // loop over all the samples and compute the risk and its subgradient at the current solution point w + for (unsigned long i = 0; i < samples.size(); ++i) + { + const long w_size_m1 = w.size()-1; + const scalar_type prediction = dot(colm(w,0,w_size_m1), samples[i]) - w(w_size_m1); + + if (std::abs(prediction - targets[i]) > eps_insensitivity) + { + if (prediction < targets[i]) + { + subtract_from(subgradient, samples[i]); + subgradient(w_size_m1) += 1; + } + else + { + add_to(subgradient, samples[i]); + subgradient(w_size_m1) -= 1; + } + + risk += std::abs(prediction - targets[i]) - eps_insensitivity; + } + } + } + + private: + + // ----------------------------------------------------- + // ----------------------------------------------------- + + + const std::vector& samples; + const std::vector& targets; + const scalar_type C; + + const bool be_verbose; + const scalar_type eps; + const scalar_type eps_insensitivity; + const unsigned long max_iterations; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename matrix_type, + typename sample_type, + typename scalar_type + > + oca_problem_linear_svr make_oca_problem_linear_svr ( + const scalar_type C, + const std::vector& samples, + const std::vector& targets, + const bool be_verbose, + const scalar_type eps, + const scalar_type eps_insensitivity, + const unsigned long max_iterations + ) + { + return oca_problem_linear_svr( + C, samples, targets, be_verbose, eps, eps_insensitivity, max_iterations); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svr_linear_trainer + { + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + // You are getting a compiler error on this line because you supplied a non-linear kernel + // to the svr_linear_trainer object. You have to use one of the linear kernels with this + // trainer. + COMPILE_TIME_ASSERT((is_same_type >::value || + is_same_type >::value )); + + svr_linear_trainer ( + ) + { + C = 1; + verbose = false; + eps = 0.01; + max_iterations = 10000; + learn_nonnegative_weights = false; + last_weight_1 = false; + eps_insensitivity = 0.1; + } + + explicit svr_linear_trainer ( + const scalar_type& C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t svr_linear_trainer::svr_linear_trainer()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + verbose = false; + eps = 0.01; + max_iterations = 10000; + learn_nonnegative_weights = false; + last_weight_1 = false; + eps_insensitivity = 0.1; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void svr_linear_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const { return eps; } + + void set_epsilon_insensitivity ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\tvoid svr_linear_trainer::set_epsilon_insensitivity(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps_: " << eps_ + ); + eps_insensitivity = eps_; + } + + const scalar_type get_epsilon_insensitivity ( + ) const + { + return eps_insensitivity; + } + + unsigned long get_max_iterations ( + ) const { return max_iterations; } + + void set_max_iterations ( + unsigned long max_iter + ) + { + max_iterations = max_iter; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + bool forces_last_weight_to_1 ( + ) const + { + return last_weight_1; + } + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ) + { + last_weight_1 = should_last_weight_be_1; + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + const kernel_type get_kernel ( + ) const + { + return kernel_type(); + } + + bool learns_nonnegative_weights ( + ) const { return learn_nonnegative_weights; } + + void set_learns_nonnegative_weights ( + bool value + ) + { + learn_nonnegative_weights = value; + } + + void set_c ( + scalar_type C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void svr_linear_trainer::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + const scalar_type get_c ( + ) const + { + return C; + } + + const decision_function train ( + const std::vector& samples, + const std::vector& targets + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(is_learning_problem(samples, targets) == true, + "\t decision_function svr_linear_trainer::train(samples, targets)" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t targets.size(): " << targets.size() + << "\n\t is_learning_problem(samples,targets): " << is_learning_problem(samples,targets) + ); + + + typedef matrix w_type; + w_type w; + + const unsigned long num_dims = max_index_plus_one(samples); + + unsigned long num_nonnegative = 0; + if (learn_nonnegative_weights) + { + num_nonnegative = num_dims; + } + + unsigned long force_weight_1_idx = std::numeric_limits::max(); + if (last_weight_1) + { + force_weight_1_idx = num_dims-1; + } + + solver( make_oca_problem_linear_svr(C, samples, targets, verbose, eps, eps_insensitivity, max_iterations), + w, + num_nonnegative, + force_weight_1_idx); + + + // put the solution into a decision function and then return it + decision_function df; + df.b = static_cast(w(w.size()-1)); + df.basis_vectors.set_size(1); + // Copy the plane normal into the output basis vector. The output vector might be a + // sparse vector container so we need to use this special kind of copy to handle that case. + // As an aside, the reason for using max_index_plus_one() and not just w.size()-1 is because + // doing it this way avoids an inane warning from gcc that can occur in some cases. + const long out_size = max_index_plus_one(samples); + assign(df.basis_vectors(0), matrix_cast(colm(w, 0, out_size))); + df.alpha.set_size(1); + df.alpha(0) = 1; + + return df; + } + + private: + + scalar_type C; + oca solver; + scalar_type eps; + bool verbose; + unsigned long max_iterations; + bool learn_nonnegative_weights; + bool last_weight_1; + scalar_type eps_insensitivity; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVR_LINEAR_TrAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/svr_linear_trainer_abstract.h b/ml/dlib/dlib/svm/svr_linear_trainer_abstract.h new file mode 100644 index 000000000..c74310f06 --- /dev/null +++ b/ml/dlib/dlib/svm/svr_linear_trainer_abstract.h @@ -0,0 +1,269 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVR_LINEAR_TrAINER_ABSTRACT_Hh_ +#ifdef DLIB_SVR_LINEAR_TrAINER_ABSTRACT_Hh_ + +#include "sparse_vector_abstract.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svr_linear_trainer + { + /*! + REQUIREMENTS ON K + Is either linear_kernel or sparse_linear_kernel. + + WHAT THIS OBJECT REPRESENTS + This object implements a trainer for performing epsilon-insensitive support + vector regression. It uses the oca optimizer so it is very efficient at + solving this problem when linear kernels are used, making it suitable for + use with large datasets. + + For an introduction to support vector regression see the following paper: + A Tutorial on Support Vector Regression by Alex J. Smola and Bernhard Scholkopf. + Note that this object solves the version of support vector regression + defined by equation (3) in the paper, except that we incorporate the bias + term into the w vector by appending a 1 to the end of each sample. + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svr_linear_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used to train a + ranking support vector machine. + - #get_oca() == oca() (i.e. an instance of oca with default parameters) + - #get_c() == 1 + - #get_epsilon() == 0.01 + - #get_epsilon_insensitivity() = 0.1 + - This object will not be verbose unless be_verbose() is called + - #get_max_iterations() == 10000 + - #learns_nonnegative_weights() == false + - #forces_last_weight_to_1() == false + !*/ + + explicit svr_linear_trainer ( + const scalar_type& C + ); + /*! + requires + - C > 0 + ensures + - This object is properly initialized and ready to be used to train a + ranking support vector machine. + - #get_oca() == oca() (i.e. an instance of oca with default parameters) + - #get_c() == C + - #get_epsilon() == 0.01 + - #get_epsilon_insensitivity() = 0.1 + - This object will not be verbose unless be_verbose() is called + - #get_max_iterations() == 10000 + - #learns_nonnegative_weights() == false + - #forces_last_weight_to_1() == false + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Smaller values may result in a more accurate solution but take longer to + train. You can think of this epsilon value as saying "solve the + optimization problem until the average regression error is within epsilon + of its optimal value". See get_epsilon_insensitivity() below for a + definition of "regression error". + !*/ + + void set_epsilon_insensitivity ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon_insensitivity() == eps + !*/ + + const scalar_type get_epsilon_insensitivity ( + ) const; + /*! + ensures + - This object tries to find a function which minimizes the regression error + on a training set. This error is measured in the following way: + - if (abs(predicted_value - true_labeled_value) < eps) then + - The error is 0. That is, any function which gets within eps of + the correct output is good enough. + - else + - The error grows linearly once it gets bigger than eps. + + So epsilon-insensitive regression means we do regression but stop trying + to fit a data point once it is "close enough". This function returns + that eps value which controls what we mean by "close enough". + !*/ + + unsigned long get_max_iterations ( + ) const; + /*! + ensures + - returns the maximum number of iterations the SVM optimizer is allowed to + run before it is required to stop and return a result. + !*/ + + void set_max_iterations ( + unsigned long max_iter + ); + /*! + ensures + - #get_max_iterations() == max_iter + !*/ + + void be_verbose ( + ); + /*! + ensures + - This object will print status messages to standard out so that a user can + observe the progress of the algorithm. + !*/ + + void be_quiet ( + ); + /*! + ensures + - this object will not print anything to standard out + !*/ + + bool forces_last_weight_to_1 ( + ) const; + /*! + ensures + - returns true if this trainer has the constraint that the last weight in + the learned parameter vector must be 1. This is the weight corresponding + to the feature in the training vectors with the highest dimension. + - Forcing the last weight to 1 also disables the bias and therefore the b + field of the learned decision_function will be 0 when forces_last_weight_to_1() == true. + !*/ + + void force_last_weight_to_1 ( + bool should_last_weight_be_1 + ); + /*! + ensures + - #forces_last_weight_to_1() == should_last_weight_be_1 + !*/ + + void set_oca ( + const oca& item + ); + /*! + ensures + - #get_oca() == item + !*/ + + const oca get_oca ( + ) const; + /*! + ensures + - returns a copy of the optimizer used to solve the SVM problem. + !*/ + + const kernel_type get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object. Since the + linear kernels don't have any parameters this function just returns + kernel_type() + !*/ + + bool learns_nonnegative_weights ( + ) const; + /*! + ensures + - The output of training is a weight vector and a bias value. These two + things define the resulting decision function. That is, the decision + function simply takes the dot product between the learned weight vector + and a test sample, then subtracts the bias value. Therefore, if + learns_nonnegative_weights() == true then the resulting learned weight + vector will always have non-negative entries. The bias value may still + be negative though. + !*/ + + void set_learns_nonnegative_weights ( + bool value + ); + /*! + ensures + - #learns_nonnegative_weights() == value + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c() == C + !*/ + + const scalar_type get_c ( + ) const; + /*! + ensures + - returns the SVM regularization parameter. It is the parameter that + determines the trade off between trying to fit the training data exactly + or allowing more errors but hopefully improving the generalization of the + resulting classifier. Larger values encourage exact fitting while + smaller values of C may encourage better generalization. + !*/ + + const decision_function train ( + const std::vector& samples, + const std::vector& targets + ) const; + /*! + requires + - is_learning_problem(samples,targets) == true + ensures + - performs support vector regression given the training samples and targets. + - returns a decision_function F with the following properties: + - F(new_sample) == predicted target value for new_sample + - F.alpha.size() == 1 + - F.basis_vectors.size() == 1 + - F.alpha(0) == 1 + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVR_LINEAR_TrAINER_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm/svr_trainer.h b/ml/dlib/dlib/svm/svr_trainer.h new file mode 100644 index 000000000..bc6378a20 --- /dev/null +++ b/ml/dlib/dlib/svm/svr_trainer.h @@ -0,0 +1,393 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVm_EPSILON_REGRESSION_TRAINER_Hh_ +#define DLIB_SVm_EPSILON_REGRESSION_TRAINER_Hh_ + + +#include "svr_trainer_abstract.h" +#include +#include +#include "../matrix.h" +#include "../algs.h" + +#include "function.h" +#include "kernel.h" +#include "../optimization/optimization_solve_qp3_using_smo.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svr_trainer + { + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svr_trainer ( + ) : + C(1), + eps_insensitivity(0.1), + cache_size(200), + eps(0.001) + { + } + + void set_cache_size ( + long cache_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(cache_size_ > 0, + "\tvoid svr_trainer::set_cache_size(cache_size_)" + << "\n\t invalid inputs were given to this function" + << "\n\t cache_size: " << cache_size_ + ); + cache_size = cache_size_; + } + + long get_cache_size ( + ) const + { + return cache_size; + } + + void set_epsilon ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\tvoid svr_trainer::set_epsilon(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps_: " << eps_ + ); + eps = eps_; + } + + const scalar_type get_epsilon ( + ) const + { + return eps; + } + + void set_epsilon_insensitivity ( + scalar_type eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\tvoid svr_trainer::set_epsilon_insensitivity(eps_)" + << "\n\t invalid inputs were given to this function" + << "\n\t eps_: " << eps_ + ); + eps_insensitivity = eps_; + } + + const scalar_type get_epsilon_insensitivity ( + ) const + { + return eps_insensitivity; + } + + void set_kernel ( + const kernel_type& k + ) + { + kernel_function = k; + } + + const kernel_type& get_kernel ( + ) const + { + return kernel_function; + } + + void set_c ( + scalar_type C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void svr_trainer::set_c()" + << "\n\t C must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + const scalar_type get_c ( + ) const + { + return C; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + return do_train(mat(x), mat(y)); + } + + void swap ( + svr_trainer& item + ) + { + exchange(kernel_function, item.kernel_function); + exchange(C, item.C); + exchange(eps_insensitivity, item.eps_insensitivity); + exchange(cache_size, item.cache_size); + exchange(eps, item.eps); + } + + private: + + // ------------------------------------------------------------------------------------ + + template + struct op_quad + { + explicit op_quad( + const M& m_ + ) : m(m_) {} + + const M& m; + + typedef typename M::type type; + typedef type const_ret_type; + const static long cost = M::cost + 2; + + inline const_ret_type apply ( long r, long c) const + { + if (r < m.nr()) + { + if (c < m.nc()) + { + return m(r,c); + } + else + { + return -m(r,c-m.nc()); + } + } + else + { + if (c < m.nc()) + { + return -m(r-m.nr(),c); + } + else + { + return m(r-m.nr(),c-m.nc()); + } + } + } + + const static long NR = 2*M::NR; + const static long NC = 2*M::NC; + typedef typename M::mem_manager_type mem_manager_type; + typedef typename M::layout_type layout_type; + + long nr () const { return 2*m.nr(); } + long nc () const { return 2*m.nc(); } + + template bool aliases ( const matrix_exp& item) const + { return m.aliases(item); } + template bool destructively_aliases ( const matrix_exp& item) const + { return m.aliases(item); } + }; + + template < + typename EXP + > + const matrix_op > make_quad ( + const matrix_exp& m + ) const + /*! + ensures + - returns the following matrix: + m -m + -m m + - I.e. returns a matrix that is twice the size of m and just + contains copies of m and -m + !*/ + { + typedef op_quad op; + return matrix_op(op(m.ref())); + } + + // ------------------------------------------------------------------------------------ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function do_train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + typedef typename K::scalar_type scalar_type; + typedef typename decision_function::sample_vector_type sample_vector_type; + typedef typename decision_function::scalar_vector_type scalar_vector_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,y) == true, + "\tdecision_function svr_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + ); + + + scalar_vector_type alpha; + + solve_qp3_using_smo solver; + + solver(symmetric_matrix_cache(make_quad(kernel_matrix(kernel_function,x)), cache_size), + uniform_matrix(2*x.size(),1, eps_insensitivity) + join_cols(y,-y), + join_cols(uniform_matrix(x.size(),1,1), uniform_matrix(x.size(),1,-1)), + 0, + C, + C, + alpha, + eps); + + scalar_type b; + calculate_b(alpha,solver.get_gradient(),C,b); + + alpha = -rowm(alpha,range(0,x.size()-1)) + rowm(alpha,range(x.size(), alpha.size()-1)); + + // count the number of support vectors + const long sv_count = (long)sum(alpha != 0); + + scalar_vector_type sv_alpha; + sample_vector_type support_vectors; + + // size these column vectors so that they have an entry for each support vector + sv_alpha.set_size(sv_count); + support_vectors.set_size(sv_count); + + // load the support vectors and their alpha values into these new column matrices + long idx = 0; + for (long i = 0; i < alpha.nr(); ++i) + { + if (alpha(i) != 0) + { + sv_alpha(idx) = alpha(i); + support_vectors(idx) = x(i); + ++idx; + } + } + + // now return the decision function + return decision_function (sv_alpha, -b, kernel_function, support_vectors); + } + + // ------------------------------------------------------------------------------------ + + template < + typename scalar_vector_type + > + void calculate_b( + const scalar_vector_type& alpha, + const scalar_vector_type& df, + const scalar_type& C, + scalar_type& b + ) const + { + using namespace std; + long num_free = 0; + scalar_type sum_free = 0; + + scalar_type upper_bound = -numeric_limits::infinity(); + scalar_type lower_bound = numeric_limits::infinity(); + + find_min_and_max(df, upper_bound, lower_bound); + + for(long i = 0; i < alpha.nr(); ++i) + { + if(i < alpha.nr()/2) + { + if(alpha(i) == C) + { + if (df(i) > upper_bound) + upper_bound = df(i); + } + else if(alpha(i) == 0) + { + if (df(i) < lower_bound) + lower_bound = df(i); + } + else + { + ++num_free; + sum_free += df(i); + } + } + else + { + if(alpha(i) == C) + { + if (-df(i) < lower_bound) + lower_bound = -df(i); + } + else if(alpha(i) == 0) + { + if (-df(i) > upper_bound) + upper_bound = -df(i); + } + else + { + ++num_free; + sum_free -= df(i); + } + } + } + + if(num_free > 0) + b = sum_free/num_free; + else + b = (upper_bound+lower_bound)/2; + } + + // ------------------------------------------------------------------------------------ + + + kernel_type kernel_function; + scalar_type C; + scalar_type eps_insensitivity; + long cache_size; + scalar_type eps; + }; // end of class svr_trainer + +// ---------------------------------------------------------------------------------------- + + template + void swap ( + svr_trainer& a, + svr_trainer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_EPSILON_REGRESSION_TRAINER_Hh_ + diff --git a/ml/dlib/dlib/svm/svr_trainer_abstract.h b/ml/dlib/dlib/svm/svr_trainer_abstract.h new file mode 100644 index 000000000..c1dd5f1f3 --- /dev/null +++ b/ml/dlib/dlib/svm/svr_trainer_abstract.h @@ -0,0 +1,209 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_EPSILON_REGRESSION_TRAINER_ABSTRACT_ +#ifdef DLIB_SVm_EPSILON_REGRESSION_TRAINER_ABSTRACT_ + +#include +#include +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" +#include "../optimization/optimization_solve_qp3_using_smo_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class svr_trainer + { + /*! + REQUIREMENTS ON K + is a kernel function object as defined in dlib/svm/kernel_abstract.h + + WHAT THIS OBJECT REPRESENTS + This object implements a trainer for performing epsilon-insensitive support + vector regression. It is implemented using the SMO algorithm. + + The implementation of the eps-SVR training algorithm used by this object is based + on the following paper: + - Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector + machines, 2001. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function trained_function_type; + + svr_trainer ( + ); + /*! + ensures + - This object is properly initialized and ready to be used + to train a support vector machine. + - #get_c() == 1 + - #get_epsilon_insensitivity() == 0.1 + - #get_cache_size() == 200 + - #get_epsilon() == 0.001 + !*/ + + void set_cache_size ( + long cache_size + ); + /*! + requires + - cache_size > 0 + ensures + - #get_cache_size() == cache_size + !*/ + + const long get_cache_size ( + ) const; + /*! + ensures + - returns the number of megabytes of cache this object will use + when it performs training via the this->train() function. + (bigger values of this may make training go faster but won't affect + the result. However, too big a value will cause you to run out of + memory, obviously.) + !*/ + + void set_epsilon ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon() == eps + !*/ + + const scalar_type get_epsilon ( + ) const; + /*! + ensures + - returns the error epsilon that determines when training should stop. + Generally a good value for this is 0.001. Smaller values may result + in a more accurate solution but take longer to execute. + !*/ + + void set_epsilon_insensitivity ( + scalar_type eps + ); + /*! + requires + - eps > 0 + ensures + - #get_epsilon_insensitivity() == eps + !*/ + + const scalar_type get_epsilon_insensitivity ( + ) const; + /*! + ensures + - This object tries to find a function which minimizes the + regression error on a training set. This error is measured + in the following way: + - if (abs(predicted_value - true_labeled_value) < eps) then + - The error is 0. That is, any function which gets within + eps of the correct output is good enough. + - else + - The error grows linearly once it gets bigger than eps + + So epsilon-insensitive regression means we do regression but + stop trying to fit a data point once it is "close enough". + This function returns that eps value which controls what we + mean by "close enough". + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + void set_c ( + scalar_type C + ); + /*! + requires + - C > 0 + ensures + - #get_c() == C + !*/ + + const scalar_type get_c ( + ) const; + /*! + ensures + - returns the SVR regularization parameter. It is the parameter that + determines the trade-off between trying to reduce the training error + or allowing more errors but hopefully improving the generalization + of the resulting decision_function. Larger values encourage exact + fitting while smaller values of C may encourage better generalization. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const; + /*! + requires + - is_learning_problem(x,y) == true + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - performs support vector regression given the training samples in x and + target values in y. + - returns a decision_function F with the following properties: + - F(new_x) == predicted y value + !*/ + + void swap ( + svr_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + + template + void swap ( + svr_trainer& a, + svr_trainer& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_EPSILON_REGRESSION_TRAINER_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/svm/track_association_function.h b/ml/dlib/dlib/svm/track_association_function.h new file mode 100644 index 000000000..bf5ef36c7 --- /dev/null +++ b/ml/dlib/dlib/svm/track_association_function.h @@ -0,0 +1,154 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_ +#define DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_ + + +#include "track_association_function_abstract.h" +#include +#include +#include "../algs.h" +#include "../serialize.h" +#include "assignment_function.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename detection_type + > + class feature_extractor_track_association + { + public: + typedef typename detection_type::track_type track_type; + typedef typename track_type::feature_vector_type feature_vector_type; + + typedef detection_type lhs_element; + typedef track_type rhs_element; + + feature_extractor_track_association() : num_dims(0), num_nonnegative(0) {} + + explicit feature_extractor_track_association ( + unsigned long num_dims_, + unsigned long num_nonnegative_ + ) : num_dims(num_dims_), num_nonnegative(num_nonnegative_) {} + + unsigned long num_features( + ) const { return num_dims; } + + unsigned long num_nonnegative_weights ( + ) const { return num_nonnegative; } + + void get_features ( + const detection_type& det, + const track_type& track, + feature_vector_type& feats + ) const + { + track.get_similarity_features(det, feats); + } + + friend void serialize (const feature_extractor_track_association& item, std::ostream& out) + { + serialize(item.num_dims, out); + serialize(item.num_nonnegative, out); + } + + friend void deserialize (feature_extractor_track_association& item, std::istream& in) + { + deserialize(item.num_dims, in); + deserialize(item.num_nonnegative, in); + } + + private: + unsigned long num_dims; + unsigned long num_nonnegative; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename detection_type_ + > + class track_association_function + { + public: + + typedef detection_type_ detection_type; + typedef typename detection_type::track_type track_type; + typedef assignment_function > association_function_type; + + track_association_function() {} + + track_association_function ( + const association_function_type& assoc_ + ) : assoc(assoc_) + { + } + + const association_function_type& get_assignment_function ( + ) const + { + return assoc; + } + + void operator() ( + std::vector& tracks, + const std::vector& dets + ) const + { + std::vector assignments = assoc(dets, tracks); + std::vector updated_track(tracks.size(), false); + // now update all the tracks with the detections that associated to them. + for (unsigned long i = 0; i < assignments.size(); ++i) + { + if (assignments[i] != -1) + { + tracks[assignments[i]].update_track(dets[i]); + updated_track[assignments[i]] = true; + } + else + { + track_type new_track; + new_track.update_track(dets[i]); + tracks.push_back(new_track); + } + } + + // Now propagate all the tracks that didn't get any detections. + for (unsigned long i = 0; i < updated_track.size(); ++i) + { + if (!updated_track[i]) + tracks[i].propagate_track(); + } + } + + friend void serialize (const track_association_function& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.assoc, out); + } + friend void deserialize (track_association_function& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::track_association_function."); + + deserialize(item.assoc, in); + } + + private: + + assignment_function > assoc; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_ + diff --git a/ml/dlib/dlib/svm/track_association_function_abstract.h b/ml/dlib/dlib/svm/track_association_function_abstract.h new file mode 100644 index 000000000..8a6fe153c --- /dev/null +++ b/ml/dlib/dlib/svm/track_association_function_abstract.h @@ -0,0 +1,271 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_TRACK_ASSOCiATION_FUNCTION_ABSTRACT_Hh_ +#ifdef DLIB_TRACK_ASSOCiATION_FUNCTION_ABSTRACT_Hh_ + +#include +#include "assignment_function_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class example_detection + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines the interface a detection must implement if it is to be + used with the track_association_function defined at the bottom of this + file. In this case, the interface is very simple. A detection object is + only required to define the track_type typedef and it must also be possible + to store detection objects in a std::vector. + !*/ + + public: + // Each detection object should be designed to work with a specific track object. + // This typedef lets us determine which track type is meant for use with this + // detection object. + typedef class example_track track_type; + + }; + +// ---------------------------------------------------------------------------------------- + + class example_track + { + /*! + WHAT THIS OBJECT REPRESENTS + This object defines the interface a track must implement if it is to be + used with the track_association_function defined at the bottom of this + file. + !*/ + + public: + // This type should be a dlib::matrix capable of storing column vectors or an + // unsorted sparse vector type as defined in dlib/svm/sparse_vector_abstract.h. + typedef matrix_or_sparse_vector_type feature_vector_type; + + example_track( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void get_similarity_features ( + const example_detection& det, + feature_vector_type& feats + ) const; + /*! + requires + - update_track() has been called on this track at least once. + ensures + - #feats == A feature vector that contains information describing how + likely it is that det is a detection from the object corresponding to + this track. That is, the feature vector should contain information that + lets someone decide if det should be associated to this track. + - #feats.size() must be a constant. That is, every time we call + get_similarity_features() it must output a feature vector of the same + dimensionality. + !*/ + + void update_track ( + const example_detection& det + ); + /*! + ensures + - Updates this track with the given detection assuming that det is the most + current observation of the object under track. + !*/ + + void propagate_track ( + ); + /*! + ensures + - propagates this track forward in time one time step. + !*/ + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename detection_type + > + class feature_extractor_track_association + { + /*! + REQUIREMENTS ON detection_type + It must be an object that implements an interface compatible with the + example_detection discussed above. This also means that detection_type::track_type + must be an object that implements an interface compatible with example_track + defined above. + + WHAT THIS OBJECT REPRESENTS + This object is an adapter that converts from the detection/track style + interface defined above to the feature extraction interface required by the + association rule learning tools in dlib. Specifically, it converts the + detection/track interface into a form usable by the assignment_function and + its trainer object structural_assignment_trainer. + !*/ + + public: + typedef typename detection_type::track_type track_type; + typedef typename track_type::feature_vector_type feature_vector_type; + typedef detection_type lhs_element; + typedef track_type rhs_element; + + unsigned long num_features( + ) const; + /*! + ensures + - returns the dimensionality of the feature vectors produced by get_features(). + !*/ + + void get_features ( + const detection_type& det, + const track_type& track, + feature_vector_type& feats + ) const; + /*! + ensures + - performs: track.get_similarity_features(det, feats); + !*/ + }; + + template < + typename detection_type + > + void serialize ( + const feature_extractor_track_association& item, + std::ostream& out + ); + /*! + Provides serialization support. + !*/ + + template < + typename detection_type + > + void deserialize ( + feature_extractor_track_association& item, + std::istream& in + ); + /*! + Provides deserialization support. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename detection_type_ + > + class track_association_function + { + /*! + REQUIREMENTS ON detection_type + It must be an object that implements an interface compatible with the + example_detection discussed above. This also means that detection_type::track_type + must be an object that implements an interface compatible with example_track + defined above. + + WHAT THIS OBJECT REPRESENTS + This object is a tool that helps you implement an object tracker. So for + example, if you wanted to track people moving around in a video then this + object can help. In particular, imagine you have a tool for detecting the + positions of each person in an image. Then you can run this person + detector on the video and at each time step, i.e. at each frame, you get a + set of person detections. However, that by itself doesn't tell you how + many people there are in the video and where they are moving to and from. + To get that information you need to figure out which detections match each + other from frame to frame. This is where the track_association_function + comes in. It performs the detection to track association. It will also do + some of the track management tasks like creating a new track when a + detection doesn't match any of the existing tracks. + + Internally, this object is implemented using the assignment_function object. + In fact, it's really just a thin wrapper around assignment_function and + exists just to provide a more convenient interface to users doing detection + to track association. + !*/ + public: + + typedef detection_type_ detection_type; + typedef typename detection_type::track_type track_type; + typedef assignment_function > association_function_type; + + track_association_function( + ); + /*! + ensures + - #get_assignment_function() will be default initialized. + !*/ + + track_association_function ( + const association_function_type& assoc + ); + /*! + ensures + - #get_assignment_function() == assoc + !*/ + + const association_function_type& get_assignment_function ( + ) const; + /*! + ensures + - returns the assignment_function used by this object to assign detections + to tracks. + !*/ + + void operator() ( + std::vector& tracks, + const std::vector& dets + ) const; + /*! + ensures + - This function uses get_assignment_function() to assign each detection + in dets to its appropriate track in tracks. Then each track which + associates to a detection is updated by calling update_track() with the + associated detection. + - Detections that don't associate with any of the elements of tracks will + spawn new tracks. For each unassociated detection, this is done by + creating a new track_type object, calling update_track() on it with the + new detection, and then adding the new track into tracks. + - Tracks that don't have a detection associate to them are propagated + forward in time by calling propagate_track() on them. That is, we call + propagate_track() only on tracks that do not get associated with a + detection. + !*/ + }; + + template < + typename detection_type + > + void serialize ( + const track_association_function& item, + std::ostream& out + ); + /*! + Provides serialization support. + !*/ + + template < + typename detection_type + > + void deserialize ( + track_association_function& item, + std::istream& in + ); + /*! + Provides deserialization support. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TRACK_ASSOCiATION_FUNCTION_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/svm_threaded.h b/ml/dlib/dlib/svm_threaded.h new file mode 100644 index 000000000..f77fab705 --- /dev/null +++ b/ml/dlib/dlib/svm_threaded.h @@ -0,0 +1,36 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_SVm_THREADED_HEADER +#define DLIB_SVm_THREADED_HEADER + +#include "svm.h" +#include "svm/svm_threaded.h" +#include "svm/structural_svm_problem_threaded.h" +#include "svm/structural_svm_distributed.h" +#include "svm/structural_svm_object_detection_problem.h" +#include "svm/structural_object_detection_trainer.h" +#include "svm/structural_svm_sequence_labeling_problem.h" +#include "svm/structural_sequence_labeling_trainer.h" + +#include "svm/structural_svm_assignment_problem.h" +#include "svm/structural_assignment_trainer.h" +#include "svm/cross_validate_track_association_trainer.h" +#include "svm/structural_track_association_trainer.h" + +#include "svm/structural_svm_graph_labeling_problem.h" +#include "svm/structural_graph_labeling_trainer.h" +#include "svm/cross_validate_graph_labeling_trainer.h" +#include "svm/svm_multiclass_linear_trainer.h" +#include "svm/one_vs_one_trainer.h" +#include "svm/one_vs_all_trainer.h" +#include "svm/structural_sequence_segmentation_trainer.h" + +#endif // DLIB_SVm_THREADED_HEADER + + + diff --git a/ml/dlib/dlib/sync_extension.h b/ml/dlib/dlib/sync_extension.h new file mode 100644 index 000000000..8a50bce01 --- /dev/null +++ b/ml/dlib/dlib/sync_extension.h @@ -0,0 +1,31 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SYNC_EXTENSIOn_ +#define DLIB_SYNC_EXTENSIOn_ + +#include "sync_extension/sync_extension_kernel_1.h" + + + +namespace dlib +{ + + template < + typename base + > + class sync_extension + { + sync_extension() {} + public: + + //----------- kernels --------------- + + // kernel_1a + typedef sync_extension_kernel_1 + kernel_1a; + + }; +} + +#endif // DLIB_SYNC_EXTENSIOn_ + diff --git a/ml/dlib/dlib/sync_extension/sync_extension_kernel_1.h b/ml/dlib/dlib/sync_extension/sync_extension_kernel_1.h new file mode 100644 index 000000000..71fe7c391 --- /dev/null +++ b/ml/dlib/dlib/sync_extension/sync_extension_kernel_1.h @@ -0,0 +1,67 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SYNC_EXTENSION_KERNEl_1_ +#define DLIB_SYNC_EXTENSION_KERNEl_1_ + +#include "../threads.h" +#include "../algs.h" +#include "sync_extension_kernel_abstract.h" + +namespace dlib +{ + + template < + typename base + > + class sync_extension_kernel_1 : public base + { + + rmutex m; + rsignaler s; + + public: + + sync_extension_kernel_1 () : s(m) {} + + template < typename T > + sync_extension_kernel_1 (const T& one) : base(one),s(m) {} + template < typename T, typename U > + sync_extension_kernel_1 (const T& one, const U& two) : base(one,two),s(m) {} + + + const rmutex& get_mutex( + ) const { return m; } + + void lock ( + ) const { m.lock(); } + + void unlock ( + ) const { m.unlock(); } + + void wait ( + ) const { s.wait(); } + + bool wait_or_timeout ( + unsigned long milliseconds + ) const { return s.wait_or_timeout(milliseconds); } + + void broadcast ( + ) const { s.broadcast(); } + + void signal ( + ) const { s.signal(); } + + }; + + template < + typename base + > + inline void swap ( + sync_extension_kernel_1& a, + sync_extension_kernel_1& b + ) { a.swap(b); } + +} + +#endif // DLIB_SYNC_EXTENSION_KERNEl_1_ + diff --git a/ml/dlib/dlib/sync_extension/sync_extension_kernel_abstract.h b/ml/dlib/dlib/sync_extension/sync_extension_kernel_abstract.h new file mode 100644 index 000000000..35665d430 --- /dev/null +++ b/ml/dlib/dlib/sync_extension/sync_extension_kernel_abstract.h @@ -0,0 +1,190 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SYNC_EXTENSION_KERNEl_ABSTRACT_ +#ifdef DLIB_SYNC_EXTENSION_KERNEl_ABSTRACT_ + +#include "../threads/threads_kernel_abstract.h" +#include "../threads/rmutex_extension_abstract.h" +#include "../threads/rsignaler_extension_abstract.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename base + > + class sync_extension : public base + { + + /*! + REQUIREMENTS ON base + base must have a default constructor + base must implement swap(base&) + + + WHAT THIS OBJECT REPRESENTS + This object represents a general extension to any object (given the + restrictions on base). This object gives any object which it extends + an integrated rmutex and rsignaler object. The extended object will + then be able to be treated as if it was also a rmutex and rsignaler. + + NOTE that just like the threading api, this object does not check + its requires clauses so be careful with it. + + Also note that swap() does not swap the rmutex and rsignaler objects. + the rmutex and rsignaler are associated with the object instance itself, + not with whatever the object represents. + !*/ + + + public: + + sync_extension ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + this is thrown if there is a problem gathering memory + - dlib::thread_error + this is thrown if there is a problem creating threading objects + - any exception thrown by the constructor for the parent class base + !*/ + + template < + typename T + > + sync_extension ( + const T& one + ); + /*! + ensures + - #*this is properly initialized + - the argument one will be passed on to the constructor for the parent + class base. + throws + - std::bad_alloc + this is thrown if there is a problem gathering memory + - dlib::thread_error + this is thrown if there is a problem creating threading objects + - any exception thrown by the constructor for the parent class base + !*/ + + template < + typename T, + typename U + > + sync_extension ( + const T& one, + const T& two + ); + /*! + ensures + - #*this is properly initialized + - the argument one will be passed on to the constructor for the parent + class base as its first argument. + - the argument two will be passed on to the constructor for the parent + class base as its second argument. + throws + - std::bad_alloc + this is thrown if there is a problem gathering memory + - dlib::thread_error + this is thrown if there is a problem creating threading objects + - any exception thrown by the constructor for the parent class base + !*/ + + + const rmutex& get_mutex ( + ) const; + /*! + ensures + - returns the rmutex embedded in this object + !*/ + + void lock ( + ) const; + /*! + requires + - the thread calling lock() does not already have a lock on *this + ensures + - if (*this is currently locked by another thread) then + - the thread that called lock() on *this is put to sleep until + it becomes available + - if (*this is currently unlocked) then + - #*this becomes locked and the current thread is NOT put to sleep + but now "owns" #*this + !*/ + + void unlock ( + ) const; + /*! + ensures + - #*this is unlocked (i.e. other threads may now lock this object) + !*/ + + + void wait ( + ) const; + /*! + requires + - *this is locked and owned by the calling thread + ensures + - atomically unlocks *this and blocks the calling thread + - calling thread will wake if another thread calls signal() or broadcast() + on *this + - when wait returns the calling thread again has a lock on #*this + !*/ + + + bool wait_or_timeout ( + unsigned long milliseconds + ) const; + /*! + requires + - *this is locked and owned by the calling thread + ensures + - atomically unlocks *this and blocks the calling thread + - calling thread will wake if another thread calls signal() or broadcast() + on *this + - after the specified number of milliseconds has elapsed the calling thread + will wake once *this is free to be locked + - when wait returns the calling thread again has a lock on #*this + + - returns false if the call to wait_or_timeout timed out + - returns true if the call did not time out + !*/ + + void signal ( + ) const; + /*! + ensures + - if (at least one thread is waiting on *this) then + - at least one of the waiting threads will wake + !*/ + + void broadcast ( + ) const; + /*! + ensures + - any and all threads waiting on *this will wake + !*/ + + }; + + template < + typename base + > + inline void swap ( + sync_extension& a, + sync_extension& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_SYNC_EXTENSION_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/test/CMakeLists.txt b/ml/dlib/dlib/test/CMakeLists.txt new file mode 100644 index 000000000..d6147cb0e --- /dev/null +++ b/ml/dlib/dlib/test/CMakeLists.txt @@ -0,0 +1,181 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + +cmake_minimum_required(VERSION 2.8.12) + +# create a variable called target_name and set it to the string "dtest" +set (target_name dtest) +PROJECT(${target_name}) + +# compile the dlib/all/source.cpp file into its own object just to make sure it compiles +set(DLIB_TEST_COMPILE_ALL_SOURCE_CPP ON) + +add_subdirectory(.. dlib_build) + +# This variable contains a list of all the tests we are building +# into the regression test suite. +set (tests + example.cpp + active_learning.cpp + any.cpp + any_function.cpp + array2d.cpp + array.cpp + assignment_learning.cpp + base64.cpp + bayes_nets.cpp + bigint.cpp + binary_search_tree_kernel_1a.cpp + binary_search_tree_kernel_2a.cpp + binary_search_tree_mm1.cpp + binary_search_tree_mm2.cpp + bridge.cpp + bsp.cpp + byte_orderer.cpp + cca.cpp + clustering.cpp + cmd_line_parser.cpp + cmd_line_parser_wchar_t.cpp + compress_stream.cpp + conditioning_class_c.cpp + conditioning_class.cpp + config_reader.cpp + correlation_tracker.cpp + crc32.cpp + create_iris_datafile.cpp + data_io.cpp + directed_graph.cpp + discriminant_pca.cpp + disjoint_subsets.cpp + disjoint_subsets_sized.cpp + ekm_and_lisf.cpp + empirical_kernel_map.cpp + entropy_coder.cpp + entropy_encoder_model.cpp + example_args.cpp + face.cpp + fft.cpp + fhog.cpp + filtering.cpp + find_max_factor_graph_nmplp.cpp + find_max_factor_graph_viterbi.cpp + geometry.cpp + graph.cpp + graph_cuts.cpp + graph_labeler.cpp + hash.cpp + hash_map.cpp + hash_set.cpp + hash_table.cpp + hog_image.cpp + image.cpp + iosockstream.cpp + is_same_object.cpp + isotonic_regression.cpp + kcentroid.cpp + kernel_matrix.cpp + kmeans.cpp + learning_to_track.cpp + least_squares.cpp + linear_manifold_regularizer.cpp + lspi.cpp + lz77_buffer.cpp + map.cpp + matrix2.cpp + matrix3.cpp + matrix4.cpp + matrix_chol.cpp + matrix.cpp + matrix_eig.cpp + matrix_lu.cpp + matrix_qr.cpp + max_cost_assignment.cpp + max_sum_submatrix.cpp + md5.cpp + member_function_pointer.cpp + metaprogramming.cpp + mpc.cpp + multithreaded_object.cpp + numerical_integration.cpp + object_detector.cpp + oca.cpp + one_vs_all_trainer.cpp + one_vs_one_trainer.cpp + optimization.cpp + optimization_test_functions.cpp + global_optimization.cpp + opt_qp_solver.cpp + parallel_for.cpp + parse.cpp + pipe.cpp + pixel.cpp + probabilistic.cpp + pyramid_down.cpp + queue.cpp + rand.cpp + ranking.cpp + read_write_mutex.cpp + reference_counter.cpp + rls.cpp + random_forest.cpp + sammon.cpp + scan_image.cpp + sequence.cpp + sequence_labeler.cpp + sequence_segmenter.cpp + serialize.cpp + set.cpp + sldf.cpp + sliding_buffer.cpp + sockets2.cpp + sockets.cpp + sockstreambuf.cpp + sparse_vector.cpp + stack.cpp + static_map.cpp + static_set.cpp + statistics.cpp + std_vector_c.cpp + string.cpp + svm_c_linear.cpp + svm_c_linear_dcd.cpp + svm.cpp + svm_multiclass_linear.cpp + svm_struct.cpp + svr_linear_trainer.cpp + symmetric_matrix_cache.cpp + thread_pool.cpp + threads.cpp + timer.cpp + tokenizer.cpp + trust_region.cpp + tuple.cpp + type_safe_union.cpp + vectorstream.cpp + dnn.cpp + cublas.cpp + find_optimal_parameters.cpp + elastic_net.cpp + ) + + +# add all the cpp files we want to compile to this list. This tells +# cmake that they are part of our target (which is the executable named dtest) +ADD_EXECUTABLE(${target_name} main.cpp tester.cpp ${tests}) + +# Turn on all warnings when using gcc. +if (CMAKE_COMPILER_IS_GNUCXX) + add_definitions("-W -Wall") +endif() + + +TARGET_LINK_LIBRARIES(${target_name} dlib::dlib ) + + +if (NOT DLIB_NO_GUI_SUPPORT) + add_subdirectory(gui) + add_subdirectory(examples) + add_subdirectory(tools) +endif() diff --git a/ml/dlib/dlib/test/WINDOWS_build_and_run_all_unit_tests.bat b/ml/dlib/dlib/test/WINDOWS_build_and_run_all_unit_tests.bat new file mode 100644 index 000000000..0cacfa631 --- /dev/null +++ b/ml/dlib/dlib/test/WINDOWS_build_and_run_all_unit_tests.bat @@ -0,0 +1,42 @@ +date /T > test_log.txt +time /T >> test_log.txt + +rem the pings are to wait between builds so visual studio doesn't get in a funk. + + + +echo testing python >> test_log.txt +rmdir /S /Q build_python +mkdir build_python +cd build_python +cmake -G "Visual Studio 14 2015 Win64" ../../../tools/python -DPYTHON3=ON +cmake --build . --config Release --target install || exit /B +ping 127.0.0.1 -n 5 -w 1000 > null +cd .. + + + +echo testing vc2015 >> test_log.txt +rmdir /S /Q build_vc2015_64 +mkdir build_vc2015_64 +cd build_vc2015_64 +cmake -G "Visual Studio 14 2015 Win64" .. +cmake --build . --config Release || exit /B +ping 127.0.0.1 -n 5 -w 1000 > null +cmake --build . --config Debug || exit /B +ping 127.0.0.1 -n 5 -w 1000 > null +cd Release +dtest --runall -d || exit /B +cd .. +cd .. + + + + + +del null +type test_log.txt + +date /T +time /T + diff --git a/ml/dlib/dlib/test/active_learning.cpp b/ml/dlib/dlib/test/active_learning.cpp new file mode 100644 index 000000000..9dc0013a5 --- /dev/null +++ b/ml/dlib/dlib/test/active_learning.cpp @@ -0,0 +1,165 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.active_learning"); + +// ---------------------------------------------------------------------------------------- + + typedef matrix sample_type; + typedef radial_basis_kernel kernel_type; + +// ---------------------------------------------------------------------------------------- + + void make_dataset ( + std::vector& samples, + std::vector& labels + ) + { + for (int r = -10; r <= 10; ++r) + { + for (int c = -10; c <= 10; ++c) + { + sample_type samp(2); + samp(0) = r; + samp(1) = c; + samples.push_back(samp); + + // if this point is less than 10 from the origin + if (sqrt((double)r*r + c*c) <= 8) + labels.push_back(+1); + else + labels.push_back(-1); + + } + } + + + vector_normalizer normalizer; + normalizer.train(samples); + for (unsigned long i = 0; i < samples.size(); ++i) + samples[i] = normalizer(samples[i]); + + randomize_samples(samples, labels); + + /* + cout << "samples.size(): " << samples.size() << endl; + cout << "num +1 samples: "<< sum(mat(labels) > 0) << endl; + cout << "num -1 samples: "<< sum(mat(labels) < 0) << endl; + */ + + empirical_kernel_map ekm; + ekm.load(kernel_type(0.15), samples); + for (unsigned long i = 0; i < samples.size(); ++i) + samples[i] = ekm.project(samples[i]); + + //cout << "dims: "<< ekm.out_vector_size() << endl; + } + +// ---------------------------------------------------------------------------------------- + + double test_rank_unlabeled_training_samples ( + const std::vector& samples, + const std::vector& labels, + active_learning_mode mode, + int iterations, + bool pick_front + ) + { + matrix s; + s = sum(mat(labels) > 0), sum(mat(labels) < 0); + s /= labels.size(); + + + svm_c_linear_dcd_trainer > trainer; + trainer.set_c(25); + + const unsigned long initial_size = 1; + std::vector tsamples(samples.begin(), samples.begin()+initial_size); + std::vector tlabels(labels.begin(), labels.begin()+initial_size); + + decision_function > df; + + double random_score = 0; + double active_learning_score = 0; + for (int i = 0; i < iterations; ++i) + { + print_spinner(); + random_subset_selector sss = randomly_subsample(samples,50,i); + random_subset_selector ssl = randomly_subsample(labels,50,i); + std::vector results; + + results = rank_unlabeled_training_samples(trainer, tsamples, tlabels, sss, mode); + + const unsigned long idx = pick_front ? results.front() : results.back(); + tsamples.push_back(sss[idx]); + tlabels.push_back(ssl[idx]); + + df = trainer.train(tsamples, tlabels); + //cout << "tsamples.size(): " << tsamples.size() << endl; + const unsigned long num = tsamples.size(); + const double active = test_binary_decision_function(df, samples, labels)*s; + //cout << "test: "<< active; + df = trainer.train(randomly_subsample(samples,num,i), randomly_subsample(labels,num,i)); + const double random = test_binary_decision_function(df, samples, labels)*s; + //cout << "test: "<< random << endl; + + active_learning_score += active; + random_score += random; + + //cout << "\n\n***********\n\n" << flush; + } + + dlog << LINFO << "pick_front: " << pick_front << " mode: "<< mode; + dlog << LINFO << "active_learning_score: "<< active_learning_score; + dlog << LINFO << "random_score: "<< random_score; + return active_learning_score / random_score; + } + +// ---------------------------------------------------------------------------------------- + + class test_active_learning : public tester + { + public: + test_active_learning ( + ) : + tester ("test_active_learning", + "Runs tests on the active learning components.") + {} + + void perform_test ( + ) + { + std::vector samples; + std::vector labels; + print_spinner(); + make_dataset(samples, labels); + dlog << LINFO << "samples.size(): "<< samples.size(); + + // When we pick the best/front ranked element then the active learning method + // shouldn't do much worse than random selection (and often much better). + DLIB_TEST(test_rank_unlabeled_training_samples(samples, labels, max_min_margin, 35, true) >= 0.97); + DLIB_TEST(test_rank_unlabeled_training_samples(samples, labels, ratio_margin, 25, true) >= 0.96); + // However, picking the worst ranked element should do way worse than random + // selection. + DLIB_TEST(test_rank_unlabeled_training_samples(samples, labels, max_min_margin, 25, false) < 0.8); + DLIB_TEST(test_rank_unlabeled_training_samples(samples, labels, ratio_margin, 25, false) < 0.8); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/any.cpp b/ml/dlib/dlib/test/any.cpp new file mode 100644 index 000000000..355d00b31 --- /dev/null +++ b/ml/dlib/dlib/test/any.cpp @@ -0,0 +1,139 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../rand.h" + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.any"); + +// ---------------------------------------------------------------------------------------- + + void test_contains_4( + const any a + ) + { + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.contains() == false); + DLIB_TEST(any_cast(a) == 4); + } + +// ---------------------------------------------------------------------------------------- + + void run_test() + { + any a, b, c; + + DLIB_TEST(a.is_empty()); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.is_empty()); + + a = b; + + swap(a,b); + a.swap(b); + + a = 4; + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.contains() == false); + DLIB_TEST(any_cast(a) == 4); + + test_contains_4(a); + + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.contains() == false); + DLIB_TEST(any_cast(a) == 4); + + bool error = false; + try + { + any_cast(a); + } + catch (bad_any_cast&) + { + error = true; + } + DLIB_TEST(error); + + swap(a,b); + + test_contains_4(b); + + DLIB_TEST(a.is_empty()); + + a = b; + + test_contains_4(a); + + c.get() = "test string"; + DLIB_TEST(c.get() == "test string"); + + a = c; + DLIB_TEST(a.cast_to() == "test string"); + + + a.clear(); + DLIB_TEST(a.is_empty()); + error = false; + try + { + any_cast(a); + } + catch (bad_any_cast&) + { + error = true; + } + DLIB_TEST(error); + + + a = 1; + b = 2; + + int* a_ptr = &a.get(); + int* b_ptr = &b.get(); + + swap(a,b); + DLIB_TEST(a_ptr == &b.get()); + DLIB_TEST(b_ptr == &a.get()); + } + +// ---------------------------------------------------------------------------------------- + + class any_tester : public tester + { + public: + any_tester ( + ) : + tester ("test_any", + "Runs tests on the any component.") + {} + + void perform_test ( + ) + { + run_test(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/any_function.cpp b/ml/dlib/dlib/test/any_function.cpp new file mode 100644 index 000000000..8defb5988 --- /dev/null +++ b/ml/dlib/dlib/test/any_function.cpp @@ -0,0 +1,253 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../rand.h" + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.any_function"); + +// ---------------------------------------------------------------------------------------- + + int add ( int a, int b) { return a + b; } + string cat ( string a, string b) { return a + b; } + +// ---------------------------------------------------------------------------------------- + + void set_vals1( int& a) { a = 1; } + void set_vals2( int& a, int& b) { a = 1; b = 2; } + void set_vals3( int& a, int& b, int& c) { a = 1; b = 2; c = 3; } + void set_vals4( int& a, int& b, int& c, int& d) { a = 1; b = 2; c = 3; d = 4; } + void set_vals5( int& a, int& b, int& c, int& d, int& e) { a = 1; b = 2; c = 3; d = 4; e = 5; } + void set_vals6( int& a, int& b, int& c, int& d, int& e, int& f) { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; } + void set_vals7( int& a, int& b, int& c, int& d, int& e, int& f, int& g) { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; g = 7; } + + void set_vals8( int& a, int& b, int& c, int& d, int& e, int& f, int& g, int& h) + { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; g = 7; h = 8; } + + void set_vals9( int& a, int& b, int& c, int& d, int& e, int& f, int& g, int& h, int& i) + { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; g = 7; h = 8; i = 9;} + + void set_vals10( int& a, int& b, int& c, int& d, int& e, int& f, int& g, int& h, int& i, int& j) + { a = 1; b = 2; c = 3; d = 4; e = 5; f = 6; g = 7; h = 8; i = 9; j = 10;} + + void zero_vals( int& a, int& b, int& c, int& d, int& e, int& f, int& g, int& h, int& i, int& j) + { a = 0; b = 0; c = 0; d = 0; e = 0; f = 0; g = 0; h = 0; i = 0; j = 0;} + +// ---------------------------------------------------------------------------------------- + + struct test + { + int operator()() const { return 4; } + }; + + struct test2 + { + int v; + + test2() : v(0) {} + test2(int val) : v(val) {} + int operator()() const { return v; } + }; + +// ---------------------------------------------------------------------------------------- + + void test_contains_4( + const any_function a + ) + { + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.is_set() == true); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.contains() == false); + DLIB_TEST(any_cast(a)() == 4); + DLIB_TEST(a() == 4); + } + +// ---------------------------------------------------------------------------------------- + + void run_test() + { + any_function a, b, c; + + DLIB_TEST(a.is_empty()); + DLIB_TEST(a.is_set()==false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.is_empty()); + + a = b; + + swap(a,b); + a.swap(b); + + a = test(); + test_contains_4(a); + + + bool error = false; + try + { + any_cast(a); + } + catch (bad_any_cast&) + { + error = true; + } + DLIB_TEST(error); + + swap(a,b); + + test_contains_4(b); + + DLIB_TEST(a.is_empty()); + + a = b; + + test_contains_4(a); + + c.get() = test2(10); + DLIB_TEST(c.get().v == 10); + + a = c; + DLIB_TEST(a.cast_to().v == 10); + + + a.clear(); + DLIB_TEST(a.is_empty()); + error = false; + try + { + any_cast(a); + } + catch (bad_any_cast&) + { + error = true; + } + DLIB_TEST(error); + + } + +// ---------------------------------------------------------------------------------------- + + void run_test2() + { + any_function f = &add; + + DLIB_TEST(f(1,3) == 4); + + any_function g(&cat); + DLIB_TEST(g("one", "two") == "onetwo"); + } + +// ---------------------------------------------------------------------------------------- + + void run_test3() + { + any_function f1; + any_function f2; + any_function f3; + any_function f4; + any_function f5; + any_function f6; + any_function f7; + any_function f8; + any_function f9; + any_function f10; + + f1 = set_vals1; + f2 = set_vals2; + f3 = set_vals3; + f4 = set_vals4; + f5 = set_vals5; + f6 = set_vals6; + f7 = set_vals7; + f8 = set_vals8; + f9 = set_vals9; + f10 = set_vals10; + + int a,b,c,d,e,f,g,h,i,j; + + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f1(a); + DLIB_TEST(a==1); + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f2(a,b); + DLIB_TEST(a==1 && b==2); + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f3(a,b,c); + DLIB_TEST(a==1 && b==2 && c==3); + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f4(a,b,c,d); + DLIB_TEST(a==1 && b==2 && c==3 && d==4); + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f5(a,b,c,d,e); + DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5); + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f6(a,b,c,d,e,f); + DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6); + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f7(a,b,c,d,e,f,g); + DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6 && g==7); + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f8(a,b,c,d,e,f,g,h); + DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6 && g==7 && h==8); + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f9(a,b,c,d,e,f,g,h,i); + DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6 && g==7 && h==8 && i==9); + zero_vals(a,b,c,d,e,f,g,h,i,j); + + f10(a,b,c,d,e,f,g,h,i,j); + DLIB_TEST(a==1 && b==2 && c==3 && d==4 && e==5 && f==6 && g==7 && h==8 && i==9 && j==10); + zero_vals(a,b,c,d,e,f,g,h,i,j); + } +// ---------------------------------------------------------------------------------------- + + class test_any_function : public tester + { + public: + test_any_function ( + ) : + tester ("test_any_function", + "Runs tests on the any_function component.") + {} + + void perform_test ( + ) + { + print_spinner(); + run_test(); + print_spinner(); + run_test2(); + print_spinner(); + run_test3(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/array.cpp b/ml/dlib/dlib/test/array.cpp new file mode 100644 index 000000000..55ba1a724 --- /dev/null +++ b/ml/dlib/dlib/test/array.cpp @@ -0,0 +1,669 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include +#include +#include +#include + + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + using dlib::array; + + logger dlog("test.array"); + + template < + typename array + > + void array_expand_test ( + ) + /*! + requires + - array is an implementation of array/array_sort_abstract.h + array is instantiated with unsigned long + ensures + - runs tests on array for compliance with the specs + !*/ + { + dlib::rand rnd; + + DLIB_TEST(dlib::is_array::value == true); + + array a1, a2; + + { + array a4(4); + DLIB_TEST(a4.size() == 4); + } + + { + array a1, a2; + + for (int k = 1; k < 100000; k += 1000) + { + for (int i = 0; i < 10; ++i) + { + a1.clear(); + a1.set_max_size(500+k); + a1.set_size(500+k); + for (unsigned long j = 0; j < a1.size(); ++j) + { + a1[j] = j; + DLIB_TEST(a1[j] == j); + } + } + } + } + + DLIB_TEST(a1.max_size() == 0); + DLIB_TEST(a2.max_size() == 0); + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + swap(a1,a2); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.move_next() == false); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.size() == 0); + + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + a1.reset(); + a2.reset(); + + for (unsigned long k = 0; k < 4; ++k) + { + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + swap(a1,a2); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.move_next() == false); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.size() == 0); + + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + a1.clear(); + a2.clear(); + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + swap(a1,a2); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.move_next() == false); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.size() == 0); + + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + a1.clear(); + a2.clear(); + + + + + a1.set_max_size(100000); + a2.set_max_size(100000); + a1.set_size(10000); + DLIB_TEST(a1.size() == 10000); + a2.set_size(10000); + DLIB_TEST(a2.size() == 10000); + for (unsigned long i = 0; i < a1.size(); ++i) + { + unsigned long a = static_cast(rnd.get_random_32bit_number()); + a1[i] = a; + a2[i] = i; + DLIB_TEST(a1[i] == a); + DLIB_TEST(a2[i] == i); + } + + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next()); + DLIB_TEST(a1.current_element_valid()); + + DLIB_TEST(a1.at_start() == false); + a1.sort(); + DLIB_TEST(a1.at_start()); + a2.sort(); + DLIB_TEST(a1.size() == 10000); + DLIB_TEST(a2.size() == 10000); + + + for (unsigned long i = 0; i < a1.size(); ++i) + { + if (i+1 < a1.size()) + { + DLIB_TEST_MSG(a1[i] <= a1[i+1], + "a1[i]: " << a1[i] << " a1[i+1]: " << a1[i+1] + << " i: " << i); + } + DLIB_TEST_MSG(a2[i] == i,"i: " << i << " a2[i]: " << a2[i]); + } + + unsigned long last = 0; + unsigned long count = 0; + while (a1.move_next()) + { + DLIB_TEST(last <= a1.element()); + last = a1.element(); + ++count; + } + DLIB_TEST(count == a1.size()); + + last = 0; + count = 0; + while (a2.move_next()) + { + DLIB_TEST(last <= a2.element()); + last = a2.element(); + ++count; + } + DLIB_TEST(count == a2.size()); + + a2.set_size(15000); + + for (unsigned long i = 0; i < a1.size(); ++i) + { + if (i+1 < a1.size()) + { + DLIB_TEST(a1[i] <= a1[i+1]); + } + DLIB_TEST(a2[i] == i); + } + + for (unsigned long i = 10000; i < a2.size(); ++i) + { + a2[i] = i; + DLIB_TEST(a2[i] == i); + } + + for (unsigned long i = 0; i < a2.size(); ++i) + { + DLIB_TEST(a2[i] == i); + } + + a2.reset(); + last = 0; + while (a2.move_next()) + { + DLIB_TEST(last <= a2.element()); + last = a2.element(); + } + + a1.reset(); + last = 0; + while (a1.move_next()) + { + DLIB_TEST(last <= a1.element()); + last = a1.element(); + } + + a1.sort(); + last = 0; + while (a1.move_next()) + { + DLIB_TEST(last <= a1.element()); + last = a1.element(); + } + + swap(a2,a1); + + for (unsigned long i = 0; i < 15000; ++i) + { + DLIB_TEST(a1[i] == i); + } + + + + a1.clear(); + DLIB_TEST(a1.max_size() == 0); + + + + + a1.clear(); + a2.clear(); + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a2.size() == 0); + a1.set_max_size(100000); + a2.set_max_size(100000); + + a1.set_size(10000); + DLIB_TEST(a1.size() == 10000); + a2.set_size(10000); + DLIB_TEST(a2.size() == 10000); + for (unsigned long i = 0; i < a1.size(); ++i) + { + unsigned long a = static_cast(rnd.get_random_32bit_number()); + a1[i] = a; + a2[i] = i; + DLIB_TEST(a1[i] == a); + DLIB_TEST(a2[i] == i); + } + + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next()); + DLIB_TEST(a1.current_element_valid()); + + DLIB_TEST(a1.at_start() == false); + a1.sort(); + DLIB_TEST(a1.at_start()); + a2.sort(); + DLIB_TEST(a1.size() == 10000); + DLIB_TEST(a2.size() == 10000); + + + for (unsigned long i = 0; i < a1.size(); ++i) + { + if (i+1 < a1.size()) + { + DLIB_TEST(a1[i] <= a1[i+1]); + } + DLIB_TEST(a2[i] == i); + } + + last = 0; + while (a1.move_next()) + { + DLIB_TEST(last <= a1.element()); + last = a1.element(); + } + + last = 0; + while (a2.move_next()) + { + DLIB_TEST(last <= a2.element()); + last = a2.element(); + } + + a2.set_size(15000); + + for (unsigned long i = 0; i < a1.size(); ++i) + { + if (i+1 < a1.size()) + { + DLIB_TEST(a1[i] <= a1[i+1]); + } + DLIB_TEST(a2[i] == i); + } + + for (unsigned long i = 10000; i < a2.size(); ++i) + { + a2[i] = i; + DLIB_TEST(a2[i] == i); + } + + for (unsigned long i = 0; i < a2.size(); ++i) + { + DLIB_TEST(a2[i] == i); + } + + a2.reset(); + last = 0; + while (a2.move_next()) + { + DLIB_TEST(last <= a2.element()); + last = a2.element(); + } + + a1.reset(); + last = 0; + while (a1.move_next()) + { + DLIB_TEST(last <= a1.element()); + last = a1.element(); + } + + a1.sort(); + last = 0; + while (a1.move_next()) + { + DLIB_TEST(last <= a1.element()); + last = a1.element(); + } + + swap(a2,a1); + + for (unsigned long i = 0; i < 15000; ++i) + { + DLIB_TEST(a1[i] == i); + } + + + + a1.clear(); + DLIB_TEST(a1.max_size() == 0); + + a2.clear(); + print_spinner(); + } + + + + a1.set_max_size(2000000); + DLIB_TEST(a1.max_size() == 2000000); + DLIB_TEST(a1.size() == 0); + a1.set_size(2000000); + DLIB_TEST(a1.max_size() == 2000000); + DLIB_TEST(a1.size() == 2000000); + + for (unsigned long i = 0; i < a1.size(); ++i) + { + a1[i] = rnd.get_random_32bit_number(); + } + + print_spinner(); + a1.sort(); + + print_spinner(); + // serialize the state of a1, then clear a1, then + // load the state back into a1. + ostringstream sout; + serialize(a1,sout); + DLIB_TEST(a1.at_start() == true); + istringstream sin(sout.str()); + a1.clear(); + DLIB_TEST(a1.max_size() == 0); + deserialize(a1,sin); + + DLIB_TEST(a1.size() == 2000000); + + for (unsigned long i = 0; i < a1.size()-1; ++i) + { + DLIB_TEST(a1[i] <= a1[i+1]); + } + + DLIB_TEST(a1.max_size() == 2000000); + DLIB_TEST(a1.size() == 2000000); + + + swap(a1,a2); + + print_spinner(); + + DLIB_TEST(a2.size() == 2000000); + + for (unsigned long i = 0; i < a2.size()-1; ++i) + { + DLIB_TEST(a2[i] <= a2[i+1]); + } + + DLIB_TEST(a2.max_size() == 2000000); + DLIB_TEST(a2.size() == 2000000); + + swap(a1,a2); + + + a1.clear(); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.max_size() == 0); + + a1.resize(10); + DLIB_TEST(a1.size() == 10); + DLIB_TEST(a1.max_size() == 10); + + for (unsigned long i = 0; i < a1.size(); ++i) + { + a1[i] = i; + } + + print_spinner(); + a1.resize(100); + DLIB_TEST(a1.size() == 100); + DLIB_TEST(a1.max_size() == 100); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_TEST(a1[i] == i); + } + + a1.resize(50); + DLIB_TEST(a1.size() == 50); + DLIB_TEST(a1.max_size() == 100); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_TEST(a1[i] == i); + } + + a1.resize(10); + DLIB_TEST(a1.size() == 10); + DLIB_TEST(a1.max_size() == 100); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_TEST(a1[i] == i); + } + + a1.resize(20); + DLIB_TEST(a1.size() == 20); + DLIB_TEST(a1.max_size() == 100); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_TEST(a1[i] == i); + } + + + a1.resize(100); + DLIB_TEST(a1.size() == 100); + DLIB_TEST(a1.max_size() == 100); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_TEST(a1[i] == i); + } + + { + a1.clear(); + DLIB_TEST(a1.size() == 0); + for (unsigned long i = 0; i < 100; ++i) + { + unsigned long a = i; + a1.push_back(a); + DLIB_TEST(a1.size() == i+1); + DLIB_TEST(a1.back() == i); + } + for (unsigned long i = 0; i < 100; ++i) + { + DLIB_TEST(a1[i] == i); + } + for (unsigned long i = 0; i < 100; ++i) + { + unsigned long a = 0; + a1.pop_back(a); + DLIB_TEST(a == 99-i); + } + } + + { + a1.clear(); + DLIB_TEST(a1.size() == 0); + for (unsigned long i = 0; i < 100; ++i) + { + unsigned long a = i; + a1.push_back(a); + DLIB_TEST(a1.size() == i+1); + DLIB_TEST(a1.back() == i); + } + for (unsigned long i = 0; i < 100; ++i) + { + DLIB_TEST(a1[i] == i); + } + for (unsigned long i = 0; i < 100; ++i) + { + a1.pop_back(); + } + DLIB_TEST(a1.size() == 0); + } + + } + + struct stuff + { + int whatever; + }; + void another_array_test() + { + array a; + a.resize(5); + a[0].whatever = 0; + stuff temp; + temp.whatever = 99; + a.push_back(temp); + DLIB_TEST(a.size() == 6); + DLIB_TEST(a[5].whatever == 99); + + DLIB_TEST(dlib::is_array >::value == true); + } + + void test_array_split() + { + array temp(5); + + for (unsigned int i = 0; i < temp.size(); ++i) + temp[i] = i; + + array b; + + split_array(temp, b, 0.5); + DLIB_TEST(temp.size() == 2); + DLIB_TEST(b.size() == 3); + + DLIB_TEST(temp[0] == 0); + DLIB_TEST(temp[1] == 1); + DLIB_TEST(b[0] == 2); + DLIB_TEST(b[1] == 3); + DLIB_TEST(b[2] == 4); + } + + class array_tester : public tester + { + public: + array_tester ( + ) : + tester ("test_array", + "Runs tests on the array component.") + {} + + void perform_test ( + ) + { + print_spinner(); + another_array_test(); + + // test a checking version first for good measure + print_spinner(); + array_expand_test >(); + + DLIB_TEST(dlib::is_array::value == false); + test_array_split(); + } + } a; + + + + +} + diff --git a/ml/dlib/dlib/test/array2d.cpp b/ml/dlib/dlib/test/array2d.cpp new file mode 100644 index 000000000..12b8b586a --- /dev/null +++ b/ml/dlib/dlib/test/array2d.cpp @@ -0,0 +1,580 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "tester.h" +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.array2d"); + + template < + typename array2d + > + void array2d_kernel_test ( + ) + /*! + requires + - array2d is an implementation of array2d/array2d_kernel_abstract.h + is instantiated with unsigned long + ensures + - runs tests on array2d for compliance with the specs + !*/ + { + srand(static_cast(time(0))); + + array2d test,test2; + + long nc, nr; + + + DLIB_TEST(get_rect(test).is_empty()); + + enumerable& e = test; + DLIB_TEST(e.at_start() == true); + + + DLIB_TEST(e.size() == 0); + DLIB_TEST(e.at_start() == true); + DLIB_TEST(e.current_element_valid() == false); + + DLIB_TEST (e.move_next() == false); + DLIB_TEST (e.move_next() == false); + DLIB_TEST (e.move_next() == false); + DLIB_TEST (e.move_next() == false); + DLIB_TEST (e.move_next() == false); + DLIB_TEST (e.move_next() == false); + + + DLIB_TEST(e.size() == 0); + DLIB_TEST(e.at_start() == false); + DLIB_TEST(e.current_element_valid() == false); + + + e.reset(); + + DLIB_TEST(e.size() == 0); + DLIB_TEST(e.at_start() == true); + DLIB_TEST(e.current_element_valid() == false); + + + DLIB_TEST(get_rect(test).is_empty()); + + + + DLIB_TEST(test.at_start() == true); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + + test.reset(); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + test.clear(); + + + DLIB_TEST(test.at_start() == true); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + + + test.set_size(0,0); + + DLIB_TEST(get_rect(test).is_empty()); + + DLIB_TEST(test.at_start() == true); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + DLIB_TEST (test.move_next() == false); + + swap(test,test2); + DLIB_TEST (test2.at_start() == false); + DLIB_TEST (test2.current_element_valid() == false); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + swap(test,test2); + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.current_element_valid() == false); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + + test.reset(); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + + + + for (int j = 0; j < 30; ++j) + { + test2.clear(); + switch (j) + { + case 0: + nc = 10; + nr = 11; + break; + case 1: + nc = 1; + nr = 1; + break; + case 2: + nc = 100; + nr = 1; + break; + case 3: + nc = 1; + nr = 100; + break; + default: + nc = ::rand()%100 + 1; + nr = ::rand()%100 + 1; + break; + } + + test.set_size(nr,nc); + + DLIB_TEST(get_rect(test).left() == 0); + DLIB_TEST(get_rect(test).top() == 0); + DLIB_TEST(get_rect(test).right() == nc-1); + DLIB_TEST(get_rect(test).bottom() == nr-1); + + DLIB_TEST(test.size() == static_cast(nc*nr)); + DLIB_TEST(test.nr() == nr); + DLIB_TEST(test.nc() == nc); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + unsigned long i = 0; + while (test.move_next()) + { + DLIB_TEST(test.current_element_valid() == true); + DLIB_TEST(test.at_start() == false); + test.element() = i; + DLIB_TEST(const_cast(test).element() == i); + ++i; + } + DLIB_TEST(i == test.size()); + DLIB_TEST(test.current_element_valid() == false); + + DLIB_TEST(test.nr() == nr); + DLIB_TEST(test.nc() == nc); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.size() == static_cast(nc*nr)); + + i = 0; + for (long row = 0; row < test.nr(); ++row) + { + for (long col = 0; col < test.nc(); ++col) + { + DLIB_TEST_MSG(test[row][col] == i, + "\n\trow: " << row << + "\n\tcol: " << col << + "\n\ti: " << i << + "\n\ttest[row][col]: " << test[row][col]); + DLIB_TEST(test[row].nc() == test.nc()); + DLIB_TEST(test.current_element_valid() == false); + + DLIB_TEST(test.nr() == nr); + DLIB_TEST(test.nc() == nc); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.size() == static_cast(nc*nr)); + ++i; + } + } + + test.reset(); + + i = 0; + while (test.move_next()) + { + DLIB_TEST(test.element() == i); + ++i; + DLIB_TEST(test.current_element_valid() == true); + DLIB_TEST(test.at_start() == false); + } + DLIB_TEST(i == test.size()); + + test.reset(); + + + + + swap(test,test2); + + DLIB_TEST(test2.size() == static_cast(nc*nr)); + DLIB_TEST(test2.nr() == nr); + DLIB_TEST(test2.nc() == nc); + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.current_element_valid() == false); + + i = 0; + while (test2.move_next()) + { + DLIB_TEST(test2.current_element_valid() == true); + DLIB_TEST(test2.at_start() == false); + test2.element() = i; + ++i; + } + DLIB_TEST(i == test2.size()); + DLIB_TEST(test2.current_element_valid() == false); + + DLIB_TEST(test2.nr() == nr); + DLIB_TEST(test2.nr() == test2.nr()); + DLIB_TEST(test2.nc() == nc); + DLIB_TEST(test2.nc() == test2.nc()); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.size() == static_cast(nc*nr)); + + i = 0; + for (long row = 0; row < test2.nr(); ++row) + { + for (long col = 0; col < test2.nc(); ++col) + { + DLIB_TEST(test2[row][col] == i); + DLIB_TEST(const_cast(test2)[row][col] == i); + DLIB_TEST(test2[row].nc() == test2.nc()); + DLIB_TEST(test2.current_element_valid() == false); + + DLIB_TEST(test2.nr() == nr); + DLIB_TEST(test2.nr() == test2.nr()); + DLIB_TEST(test2.nc() == nc); + DLIB_TEST(test2.nc() == test2.nc()); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.size() == static_cast(nc*nr)); + ++i; + } + } + + test2.reset(); + + i = 0; + while (test2.move_next()) + { + DLIB_TEST(test2.element() == i); + DLIB_TEST(const_cast(test2).element() == i); + ++i; + DLIB_TEST(test2.current_element_valid() == true); + DLIB_TEST(test2.at_start() == false); + } + DLIB_TEST(i == test2.size()); + + + test2.clear(); + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.nr() == 0); + DLIB_TEST(test2.nc() == 0); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == true); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.nc() == 0); + DLIB_TEST(test.nr() == 0); + + test.set_size(nr,nc); + DLIB_TEST(test.size() == static_cast(nc*nr)); + DLIB_TEST(test.nc() == nc); + DLIB_TEST(test.nr() == nr); + + + + } + + + + + + // test the serialization + istringstream sin; + ostringstream sout; + test.clear(); + test2.clear(); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.nc() == 0); + DLIB_TEST(test.nr() == 0); + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.nc() == 0); + DLIB_TEST(test2.nr() == 0); + + test.set_size(10,10); + + for (long row = 0; row < test.nr(); ++row) + { + for (long col = 0; col < test.nc(); ++col) + { + test[row][col] = row*col; + } + } + + serialize(test,sout); + sin.str(sout.str()); + deserialize(test2,sin); + + DLIB_TEST(test2.size() == test.size()); + DLIB_TEST(test2.nc() == test.nc()); + DLIB_TEST(test2.nr() == test.nr()); + DLIB_TEST(test2.size() == 100); + DLIB_TEST(test2.nc() == 10); + DLIB_TEST(test2.nr() == 10); + + + for (long row = 0; row < test.nr(); ++row) + { + for (long col = 0; col < test.nc(); ++col) + { + DLIB_TEST(test[row][col] == static_cast(row*col)); + DLIB_TEST(test2[row][col] == static_cast(row*col)); + } + } + + + + + + + test.set_size(10,11); + DLIB_TEST(test.nr() == 10); + DLIB_TEST(test.nc() == 11); + test.set_size(0,0); + DLIB_TEST(test.nr() == 0); + DLIB_TEST(test.nc() == 0); + + } + + void test_serialization() + { + // Do these tests because there are overloads of the serialize routines + // specifically for these types of pixel (except for unsigned short, + // we do that because you can never have too many tests). + { + array2d img, img2; + img.set_size(3,2); + assign_all_pixels(img, 5); + img[1][1].red = 9; + img[1][1].green = 8; + img[1][1].blue = 7; + img[1][1].alpha = 3; + ostringstream sout; + serialize(img, sout); + istringstream sin(sout.str()); + deserialize(img2, sin); + + DLIB_TEST(img2.nr() == 3); + DLIB_TEST(img2.nc() == 2); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + DLIB_TEST(img[r][c].red == img2[r][c].red); + DLIB_TEST(img[r][c].green == img2[r][c].green); + DLIB_TEST(img[r][c].blue == img2[r][c].blue); + DLIB_TEST(img[r][c].alpha == img2[r][c].alpha); + } + } + } + { + array2d img, img2; + img.set_size(3,2); + assign_all_pixels(img, 5); + img[1][1].h = 9; + img[1][1].s = 2; + img[1][1].i = 3; + ostringstream sout; + serialize(img, sout); + istringstream sin(sout.str()); + deserialize(img2, sin); + + DLIB_TEST(img2.nr() == 3); + DLIB_TEST(img2.nc() == 2); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + DLIB_TEST(img[r][c].h == img2[r][c].h); + DLIB_TEST(img[r][c].s == img2[r][c].s); + DLIB_TEST(img[r][c].i == img2[r][c].i); + } + } + } + { + array2d img, img2; + img.set_size(3,2); + assign_all_pixels(img, 5); + img[1][1].red = 1; + img[1][1].green = 2; + img[1][1].blue = 3; + ostringstream sout; + serialize(img, sout); + istringstream sin(sout.str()); + deserialize(img2, sin); + + DLIB_TEST(img2.nr() == 3); + DLIB_TEST(img2.nc() == 2); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + DLIB_TEST(img[r][c].red == img2[r][c].red); + DLIB_TEST(img[r][c].green == img2[r][c].green); + DLIB_TEST(img[r][c].blue == img2[r][c].blue); + } + } + } + { + array2d img, img2; + img.set_size(3,2); + assign_all_pixels(img, 5); + img[1][1].red = 1; + img[1][1].green = 2; + img[1][1].blue = 3; + ostringstream sout; + serialize(img, sout); + istringstream sin(sout.str()); + deserialize(img2, sin); + + DLIB_TEST(img2.nr() == 3); + DLIB_TEST(img2.nc() == 2); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + DLIB_TEST(img[r][c].red == img2[r][c].red); + DLIB_TEST(img[r][c].green == img2[r][c].green); + DLIB_TEST(img[r][c].blue == img2[r][c].blue); + } + } + } + { + array2d img, img2; + img.set_size(3,2); + assign_all_pixels(img, 5); + img[1][1] = 9; + ostringstream sout; + serialize(img, sout); + istringstream sin(sout.str()); + deserialize(img2, sin); + + DLIB_TEST(img2.nr() == 3); + DLIB_TEST(img2.nc() == 2); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + DLIB_TEST(img[r][c] == img2[r][c]); + } + } + } + { + array2d img, img2; + img.set_size(3,2); + assign_all_pixels(img, 5); + img[1][1] = 9; + ostringstream sout; + serialize(img, sout); + istringstream sin(sout.str()); + deserialize(img2, sin); + + DLIB_TEST(img2.nr() == 3); + DLIB_TEST(img2.nc() == 2); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + DLIB_TEST(img[r][c] == img2[r][c]); + } + } + + DLIB_TEST((char*)&img[0][0] + img.width_step() == (char*)&img[1][0]); + } + + COMPILE_TIME_ASSERT(is_array2d >::value == true); + COMPILE_TIME_ASSERT(is_array2d >::value == true); + COMPILE_TIME_ASSERT(is_array2d::value == false); + } + + + class array2d_tester : public tester + { + public: + array2d_tester ( + ) : + tester ("test_array2d", + "Runs tests on the array2d component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + array2d_kernel_test >(); + print_spinner(); + test_serialization(); + print_spinner(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/assignment_learning.cpp b/ml/dlib/dlib/test/assignment_learning.cpp new file mode 100644 index 000000000..bba47db3d --- /dev/null +++ b/ml/dlib/dlib/test/assignment_learning.cpp @@ -0,0 +1,379 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include "tester.h" +#include +#include + + +typedef dlib::matrix lhs_element; +typedef dlib::matrix rhs_element; + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.assignment_learning"); + +// ---------------------------------------------------------------------------------------- + + +// ---------------------------------------------------------------------------------------- + + struct feature_extractor_dense + { + typedef matrix feature_vector_type; + + typedef ::lhs_element lhs_element; + typedef ::rhs_element rhs_element; + + unsigned long num_features() const + { + return 3; + } + + void get_features ( + const lhs_element& left, + const rhs_element& right, + feature_vector_type& feats + ) const + { + feats = squared(left - right); + } + + }; + + void serialize (const feature_extractor_dense& , std::ostream& ) {} + void deserialize (feature_extractor_dense& , std::istream& ) {} + +// ---------------------------------------------------------------------------------------- + + struct feature_extractor_sparse + { + typedef std::vector > feature_vector_type; + + typedef ::lhs_element lhs_element; + typedef ::rhs_element rhs_element; + + unsigned long num_features() const + { + return 3; + } + + void get_features ( + const lhs_element& left, + const rhs_element& right, + feature_vector_type& feats + ) const + { + feats.clear(); + feats.push_back(make_pair(0,squared(left-right)(0))); + feats.push_back(make_pair(1,squared(left-right)(1))); + feats.push_back(make_pair(2,squared(left-right)(2))); + } + + }; + + void serialize (const feature_extractor_sparse& , std::ostream& ) {} + void deserialize (feature_extractor_sparse& , std::istream& ) {} + +// ---------------------------------------------------------------------------------------- + + typedef std::pair, std::vector > sample_type; + typedef std::vector label_type; + +// ---------------------------------------------------------------------------------------- + + void make_data ( + std::vector& samples, + std::vector& labels + ) + { + lhs_element a, b, c, d; + a = 1,0,0; + b = 0,1,0; + c = 0,0,1; + d = 0,1,1; + + std::vector lhs; + std::vector rhs; + label_type label; + + lhs.push_back(a); + lhs.push_back(b); + lhs.push_back(c); + + rhs.push_back(b); + rhs.push_back(a); + rhs.push_back(c); + + label.push_back(1); + label.push_back(0); + label.push_back(2); + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(label); + + + + + lhs.clear(); + rhs.clear(); + label.clear(); + + lhs.push_back(a); + lhs.push_back(b); + lhs.push_back(c); + + rhs.push_back(c); + rhs.push_back(b); + rhs.push_back(a); + rhs.push_back(d); + + label.push_back(2); + label.push_back(1); + label.push_back(0); + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(label); + + + lhs.clear(); + rhs.clear(); + label.clear(); + + lhs.push_back(a); + lhs.push_back(b); + lhs.push_back(c); + + rhs.push_back(c); + rhs.push_back(a); + rhs.push_back(d); + + label.push_back(1); + label.push_back(-1); + label.push_back(0); + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(label); + + + + lhs.clear(); + rhs.clear(); + label.clear(); + + lhs.push_back(d); + lhs.push_back(b); + lhs.push_back(c); + + label.push_back(-1); + label.push_back(-1); + label.push_back(-1); + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(label); + + + + lhs.clear(); + rhs.clear(); + label.clear(); + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(label); + + } + +// ---------------------------------------------------------------------------------------- + + void make_data_force ( + std::vector& samples, + std::vector& labels + ) + { + lhs_element a, b, c, d; + a = 1,0,0; + b = 0,1,0; + c = 0,0,1; + d = 0,1,1; + + std::vector lhs; + std::vector rhs; + label_type label; + + lhs.push_back(a); + lhs.push_back(b); + lhs.push_back(c); + + rhs.push_back(b); + rhs.push_back(a); + rhs.push_back(c); + + label.push_back(1); + label.push_back(0); + label.push_back(2); + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(label); + + + + + lhs.clear(); + rhs.clear(); + label.clear(); + + lhs.push_back(a); + lhs.push_back(b); + lhs.push_back(c); + + rhs.push_back(c); + rhs.push_back(b); + rhs.push_back(a); + rhs.push_back(d); + + label.push_back(2); + label.push_back(1); + label.push_back(0); + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(label); + + + lhs.clear(); + rhs.clear(); + label.clear(); + + lhs.push_back(a); + lhs.push_back(c); + + rhs.push_back(c); + rhs.push_back(a); + + label.push_back(1); + label.push_back(0); + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(label); + + + + + + lhs.clear(); + rhs.clear(); + label.clear(); + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(label); + + } + +// ---------------------------------------------------------------------------------------- + + template + void test1(F make_data, bool force_assignment) + { + print_spinner(); + + std::vector samples; + std::vector labels; + + make_data(samples, labels); + make_data(samples, labels); + make_data(samples, labels); + + randomize_samples(samples, labels); + + structural_assignment_trainer trainer; + + DLIB_TEST(trainer.forces_assignment() == false); + DLIB_TEST(trainer.get_c() == 100); + DLIB_TEST(trainer.get_num_threads() == 2); + DLIB_TEST(trainer.get_max_cache_size() == 5); + + + trainer.set_forces_assignment(force_assignment); + trainer.set_num_threads(3); + trainer.set_c(50); + + DLIB_TEST(trainer.get_c() == 50); + DLIB_TEST(trainer.get_num_threads() == 3); + DLIB_TEST(trainer.forces_assignment() == force_assignment); + + assignment_function ass = trainer.train(samples, labels); + + for (unsigned long i = 0; i < samples.size(); ++i) + { + std::vector out = ass(samples[i]); + dlog << LINFO << "true labels: " << trans(mat(labels[i])); + dlog << LINFO << "pred labels: " << trans(mat(out)); + DLIB_TEST(trans(mat(labels[i])) == trans(mat(out))); + } + + double accuracy; + + dlog << LINFO << "samples.size(): "<< samples.size(); + accuracy = test_assignment_function(ass, samples, labels); + dlog << LINFO << "accuracy: "<< accuracy; + DLIB_TEST(accuracy == 1); + + accuracy = cross_validate_assignment_trainer(trainer, samples, labels, 3); + dlog << LINFO << "cv accuracy: "<< accuracy; + DLIB_TEST(accuracy == 1); + + ostringstream sout; + serialize(ass, sout); + istringstream sin(sout.str()); + assignment_function ass2; + deserialize(ass2, sin); + + DLIB_TEST(ass2.forces_assignment() == ass.forces_assignment()); + DLIB_TEST(length(ass2.get_weights() - ass.get_weights()) < 1e-10); + + for (unsigned long i = 0; i < samples.size(); ++i) + { + std::vector out = ass2(samples[i]); + dlog << LINFO << "true labels: " << trans(mat(labels[i])); + dlog << LINFO << "pred labels: " << trans(mat(out)); + DLIB_TEST(trans(mat(labels[i])) == trans(mat(out))); + } + } + +// ---------------------------------------------------------------------------------------- + + class test_assignment_learning : public tester + { + public: + test_assignment_learning ( + ) : + tester ("test_assignment_learning", + "Runs tests on the assignment learning code.") + {} + + void perform_test ( + ) + { + test1(make_data, false); + test1(make_data, false); + + test1(make_data_force, false); + test1(make_data_force, false); + test1(make_data_force, true); + test1(make_data_force, true); + } + } a; + +// ---------------------------------------------------------------------------------------- + +} + + diff --git a/ml/dlib/dlib/test/base64.cpp b/ml/dlib/dlib/test/base64.cpp new file mode 100644 index 000000000..f4d478018 --- /dev/null +++ b/ml/dlib/dlib/test/base64.cpp @@ -0,0 +1,208 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.base64"); + + template < + typename base64 + > + void base64_kernel_test ( + ) + /*! + requires + - base64 is an implementation of base64/base64_kernel_abstract.h + ensures + - runs tests on base64 for compliance with the specs + !*/ + { + + const unsigned int seed = static_cast(time(0)); + try + { + + srand(seed); + + base64 test; + + const string wiki_normal = "\ +Man is distinguished, not only by his reason, but by this singular passion from other \ +animals, which is a lust of the mind, that by a perseverance of delight in the continued \ +and indefatigable generation of knowledge, exceeds the short vehemence of any carnal pleasure."; + + const string wiki_encoded = "\ +TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0\n\ +aGlzIHNpbmd1bGFyIHBhc3Npb24gZnJvbSBvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1\n\ +c3Qgb2YgdGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodCBpbiB0\n\ +aGUgY29udGludWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdl\n\ +LCBleGNlZWRzIHRoZSBzaG9ydCB2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4="; + + + + string str; + + istringstream sin; + ostringstream sout; + + sin.str(wiki_encoded); + test.decode(sin,sout); + DLIB_TEST_MSG(sout.str() == wiki_normal, + "sout.str(): " << sout.str() << + "\nwiki_normal: " << wiki_normal); + + + sout.str(""); + sin.str(wiki_normal); + sin.clear(); + test.encode(sin,sout); + + string a(sout.str()), b(wiki_encoded); + // we want to strip all the whitespace from a and b now + sin.str(a); + a.clear(); + sin >> str; + while (sin) + { + a += str; + sin >> str; + } + + sin.clear(); + sin.str(b); + b.clear(); + sin >> str; + while (sin) + { + b += str; + sin >> str; + } + sin.clear(); + + DLIB_TEST_MSG(a == b, + "a: \n" << a << + "\n\nb: \n" << b); + + + + sin.clear(); + sin.str(""); + sout.str(""); + test.encode(sin,sout); + sin.str(sout.str()); + sout.str(""); + test.decode(sin,sout); + DLIB_TEST(sout.str() == ""); + + sin.clear(); + sin.str("a"); + sout.str(""); + test.encode(sin,sout); + sin.str(sout.str()); + sout.str(""); + test.decode(sin,sout); + DLIB_TEST(sout.str() == "a"); + + sin.clear(); + sin.str("da"); + sout.str(""); + test.encode(sin,sout); + sin.str(sout.str()); + sout.str(""); + test.decode(sin,sout); + DLIB_TEST(sout.str() == "da"); + + sin.clear(); + sin.str("dav"); + sout.str(""); + test.encode(sin,sout); + sin.str(sout.str()); + sout.str(""); + test.decode(sin,sout); + DLIB_TEST(sout.str() == "dav"); + + sin.clear(); + sin.str("davi"); + sout.str(""); + test.encode(sin,sout); + sin.str(sout.str()); + sout.str(""); + test.decode(sin,sout); + DLIB_TEST(sout.str() == "davi"); + + + for (int i = 0; i < 1000; ++i) + { + str.clear(); + sin.clear(); + sout.str(""); + sin.str(""); + + // fill str with random garbage + const int size = rand()%2000; + for (int j = 0; j < size; ++j) + { + unsigned char ch = rand()&0xFF; + str += ch; + } + + sin.str(str); + test.encode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + test.decode(sin,sout); + + DLIB_TEST(str == sout.str()); + + + } + + + + + } + catch (typename base64::decode_error& e) + { + DLIB_TEST_MSG(false, + "decode_error thrown when it shouldn't have been (" << seed << "):\n " + << e.info); + } + } + + + class base64_tester : public tester + { + public: + base64_tester ( + ) : + tester ("test_base64", + "Runs tests on the base64 component.") + {} + + void perform_test ( + ) + { + print_spinner(); + base64_kernel_test(); + } + } a; + + + +} + + + diff --git a/ml/dlib/dlib/test/bayes_nets.cpp b/ml/dlib/dlib/test/bayes_nets.cpp new file mode 100644 index 000000000..1a3035762 --- /dev/null +++ b/ml/dlib/dlib/test/bayes_nets.cpp @@ -0,0 +1,411 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include "dlib/graph_utils.h" +#include "dlib/graph.h" +#include "dlib/directed_graph.h" +#include "dlib/bayes_utils.h" +#include "dlib/set.h" +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.bayes_nets"); + enum nodes + { + A, T, S, L, O, B, D, X + }; + + template + void setup_simple_network ( + gtype& bn + ) + { + /* + A + / \ + T S + */ + + using namespace bayes_node_utils; + + bn.set_number_of_nodes(3); + bn.add_edge(A, T); + bn.add_edge(A, S); + + + set_node_num_values(bn, A, 2); + set_node_num_values(bn, T, 2); + set_node_num_values(bn, S, 2); + + assignment parents; + + // set probabilities for node A + set_node_probability(bn, A, 1, parents, 0.1); + set_node_probability(bn, A, 0, parents, 1-0.1); + + // set probabilities for node T + parents.add(A, 1); + set_node_probability(bn, T, 1, parents, 0.5); + set_node_probability(bn, T, 0, parents, 1-0.5); + parents[A] = 0; + set_node_probability(bn, T, 1, parents, 0.5); + set_node_probability(bn, T, 0, parents, 1-0.5); + + // set probabilities for node S + parents[A] = 1; + set_node_probability(bn, S, 1, parents, 0.5); + set_node_probability(bn, S, 0, parents, 1-0.5); + parents[A] = 0; + set_node_probability(bn, S, 1, parents, 0.5); + set_node_probability(bn, S, 0, parents, 1-0.5); + + + // test the serialization code here by pushing this network though it + ostringstream sout; + serialize(bn, sout); + bn.clear(); + DLIB_TEST(bn.number_of_nodes() == 0); + istringstream sin(sout.str()); + deserialize(bn, sin); + DLIB_TEST(bn.number_of_nodes() == 3); + } + + + template + void setup_dyspnea_network ( + gtype& bn, + bool deterministic_o_node = true + ) + { + /* + This is the example network used by David Zaret in his + reasoning under uncertainty class at Johns Hopkins + */ + + using namespace bayes_node_utils; + + bn.set_number_of_nodes(8); + bn.add_edge(A, T); + bn.add_edge(T, O); + + bn.add_edge(O, D); + bn.add_edge(O, X); + + bn.add_edge(S, B); + bn.add_edge(S, L); + + bn.add_edge(L, O); + bn.add_edge(B, D); + + + set_node_num_values(bn, A, 2); + set_node_num_values(bn, T, 2); + set_node_num_values(bn, O, 2); + set_node_num_values(bn, X, 2); + set_node_num_values(bn, L, 2); + set_node_num_values(bn, S, 2); + set_node_num_values(bn, B, 2); + set_node_num_values(bn, D, 2); + + assignment parents; + + // set probabilities for node A + set_node_probability(bn, A, 1, parents, 0.01); + set_node_probability(bn, A, 0, parents, 1-0.01); + + // set probabilities for node S + set_node_probability(bn, S, 1, parents, 0.5); + set_node_probability(bn, S, 0, parents, 1-0.5); + + // set probabilities for node T + parents.add(A, 1); + set_node_probability(bn, T, 1, parents, 0.05); + set_node_probability(bn, T, 0, parents, 1-0.05); + parents[A] = 0; + set_node_probability(bn, T, 1, parents, 0.01); + set_node_probability(bn, T, 0, parents, 1-0.01); + + // set probabilities for node L + parents.clear(); + parents.add(S,1); + set_node_probability(bn, L, 1, parents, 0.1); + set_node_probability(bn, L, 0, parents, 1-0.1); + parents[S] = 0; + set_node_probability(bn, L, 1, parents, 0.01); + set_node_probability(bn, L, 0, parents, 1-0.01); + + + // set probabilities for node B + parents[S] = 1; + set_node_probability(bn, B, 1, parents, 0.6); + set_node_probability(bn, B, 0, parents, 1-0.6); + parents[S] = 0; + set_node_probability(bn, B, 1, parents, 0.3); + set_node_probability(bn, B, 0, parents, 1-0.3); + + + // set probabilities for node O + double v; + if (deterministic_o_node) + v = 1; + else + v = 0.99; + + parents.clear(); + parents.add(T,1); + parents.add(L,1); + set_node_probability(bn, O, 1, parents, v); + set_node_probability(bn, O, 0, parents, 1-v); + parents[T] = 0; parents[L] = 1; + set_node_probability(bn, O, 1, parents, v); + set_node_probability(bn, O, 0, parents, 1-v); + parents[T] = 1; parents[L] = 0; + set_node_probability(bn, O, 1, parents, v); + set_node_probability(bn, O, 0, parents, 1-v); + parents[T] = 0; parents[L] = 0; + set_node_probability(bn, O, 1, parents, 1-v); + set_node_probability(bn, O, 0, parents, v); + + + // set probabilities for node D + parents.clear(); + parents.add(O,1); + parents.add(B,1); + set_node_probability(bn, D, 1, parents, 0.9); + set_node_probability(bn, D, 0, parents, 1-0.9); + parents[O] = 1; parents[B] = 0; + set_node_probability(bn, D, 1, parents, 0.7); + set_node_probability(bn, D, 0, parents, 1-0.7); + parents[O] = 0; parents[B] = 1; + set_node_probability(bn, D, 1, parents, 0.8); + set_node_probability(bn, D, 0, parents, 1-0.8); + parents[O] = 0; parents[B] = 0; + set_node_probability(bn, D, 1, parents, 0.1); + set_node_probability(bn, D, 0, parents, 1-0.1); + + + // set probabilities for node X + parents.clear(); + parents.add(O,1); + set_node_probability(bn, X, 1, parents, 0.98); + set_node_probability(bn, X, 0, parents, 1-0.98); + parents[O] = 0; + set_node_probability(bn, X, 1, parents, 0.05); + set_node_probability(bn, X, 0, parents, 1-0.05); + + + // test the serialization code here by pushing this network though it + ostringstream sout; + serialize(bn, sout); + bn.clear(); + DLIB_TEST(bn.number_of_nodes() == 0); + istringstream sin(sout.str()); + deserialize(bn, sin); + DLIB_TEST(bn.number_of_nodes() == 8); + } + + + void bayes_nets_test ( + ) + /*! + ensures + - runs tests on the bayesian network objects and functions for compliance with the specs + !*/ + { + + print_spinner(); + + directed_graph::kernel_1a_c bn; + setup_dyspnea_network(bn); + + using namespace bayes_node_utils; + + + graph::compare_1b_c, dlib::set::compare_1b_c>::kernel_1a_c join_tree; + + create_moral_graph(bn, join_tree); + create_join_tree(join_tree, join_tree); + + bayesian_network_join_tree solution(bn, join_tree); + + matrix dist; + + dist = solution.probability(A); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.01 ) < 1e-5); + + dist = solution.probability(T); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.0104) < 1e-5); + + dist = solution.probability(O); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.064828) < 1e-5); + + dist = solution.probability(X); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.11029004) < 1e-5); + + dist = solution.probability(L); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.055) < 1e-5); + + dist = solution.probability(S); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.5) < 1e-5); + + dist = solution.probability(B); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.4499999) < 1e-5); + + dist = solution.probability(D); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.4359706 ) < 1e-5); + + // now lets modify the probabilities of the bayesian network by making O + // not a deterministic node anymore but otherwise leave the network alone + setup_dyspnea_network(bn, false); + + set_node_value(bn, A, 1); + set_node_value(bn, X, 1); + set_node_value(bn, S, 1); + // lets also make some of these nodes evidence nodes + set_node_as_evidence(bn, A); + set_node_as_evidence(bn, X); + set_node_as_evidence(bn, S); + + // reload the solution now that we have changed the probabilities of node O + bayesian_network_join_tree(bn, join_tree).swap(solution); + DLIB_TEST(solution.number_of_nodes() == bn.number_of_nodes()); + + dist = solution.probability(A); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 1.0 ) < 1e-5); + + dist = solution.probability(T); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.253508694039 ) < 1e-5); + + dist = solution.probability(O); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.77856184024 ) < 1e-5); + + dist = solution.probability(X); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 1.0 ) < 1e-5); + + dist = solution.probability(L); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.5070173880 ) < 1e-5); + + dist = solution.probability(S); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 1.0 ) < 1e-5); + + dist = solution.probability(B); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.6 ) < 1e-5); + + dist = solution.probability(D); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.7535685520 ) < 1e-5); + + + // now lets test the bayesian_network_gibbs_sampler + set_node_value(bn, A, 1); + set_node_value(bn, T, 1); + set_node_value(bn, O, 1); + set_node_value(bn, X, 1); + set_node_value(bn, S, 1); + set_node_value(bn, L, 1); + set_node_value(bn, B, 1); + set_node_value(bn, D, 1); + + bayesian_network_gibbs_sampler sampler; + matrix counts; + set_all_elements(counts, 0); + const unsigned long rounds = 500000; + for (unsigned long i = 0; i < rounds; ++i) + { + sampler.sample_graph(bn); + + for (long c = 0; c < counts.nc(); ++c) + { + if (node_value(bn, c) == 1) + counts(c) += 1; + } + + if ((i&0x3FF) == 0) + { + print_spinner(); + } + } + + counts /= rounds; + + DLIB_TEST(abs(counts(A) - 1.0 ) < 1e-2); + DLIB_TEST(abs(counts(T) - 0.253508694039 ) < 1e-2); + DLIB_TEST_MSG(abs(counts(O) - 0.77856184024 ) < 1e-2,abs(counts(O) - 0.77856184024 ) ); + DLIB_TEST(abs(counts(X) - 1.0 ) < 1e-2); + DLIB_TEST(abs(counts(L) - 0.5070173880 ) < 1e-2); + DLIB_TEST(abs(counts(S) - 1.0 ) < 1e-2); + DLIB_TEST(abs(counts(B) - 0.6 ) < 1e-2); + DLIB_TEST(abs(counts(D) - 0.7535685520 ) < 1e-2); + + + setup_simple_network(bn); + create_moral_graph(bn, join_tree); + create_join_tree(join_tree, join_tree); + bayesian_network_join_tree(bn, join_tree).swap(solution); + DLIB_TEST(solution.number_of_nodes() == bn.number_of_nodes()); + + dist = solution.probability(A); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.1 ) < 1e-5); + + dist = solution.probability(T); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.5 ) < 1e-5); + + dist = solution.probability(S); + DLIB_TEST(abs(sum(dist) - 1.0) < 1e-5); + DLIB_TEST(abs(dist(1) - 0.5 ) < 1e-5); + + + } + + + + + class bayes_nets_tester : public tester + { + public: + bayes_nets_tester ( + ) : + tester ("test_bayes_nets", + "Runs tests on the bayes_nets objects and functions.") + {} + + void perform_test ( + ) + { + bayes_nets_test(); + } + } a; + +} + + + + diff --git a/ml/dlib/dlib/test/bigint.cpp b/ml/dlib/dlib/test/bigint.cpp new file mode 100644 index 000000000..3ddc631b4 --- /dev/null +++ b/ml/dlib/dlib/test/bigint.cpp @@ -0,0 +1,522 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include + +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.bigint"); + + namespace bigint_kernel_test_helpers + { + template < + typename bint + > + bint short_fact (unsigned short value) + /*! + ensures + - returns the factorial of value + !*/ + { + using namespace relational_operators; + + bint a = 1; + for (unsigned short i = 2; i <= value; ++i) + a *= i; + + return a; + } + + template < + typename bint + > + bint short_fact_squared (unsigned short value) + /*! + ensures + - returns the square of the factorial of value + !*/ + { + using namespace relational_operators; + + bint a = 1; + for (unsigned short i = 2; i <= value; ++i) + { + a *= i; + a *= i; + } + + return a; + } + + + template < + typename bint + > + bint big_fact (unsigned short value) + /*! + ensures + - returns the factorial of value + !*/ + { + using namespace relational_operators; + + bint a = 1; + int k = 0; + for (bint i = 2; i <= value; ++i) + { + ++k; + if (k%10 == 0) + print_spinner(); + a *= i; + } + + return a; + } + } + + template < + typename bint + > + void bigint_kernel_test ( + ) + /*! + requires + - bint is an implementation of bigint/bigint_kernel_abstract.h + ensures + - runs tests on bint for compliance with the specs + !*/ + { + using namespace bigint_kernel_test_helpers; + using namespace relational_operators; + istringstream sin; + ostringstream sout; + + bint i = 0; + bint a(5), b, c(0); + + DLIB_TEST(5 - a == 0); + DLIB_TEST(a - 5 == 0); + + DLIB_TEST(0 - c == 0); + DLIB_TEST(c - 0 == 0); + + DLIB_TEST(0 + c == 0); + DLIB_TEST(c + 0 == 0); + + DLIB_TEST(0 + a == 5); + DLIB_TEST(a + 0 == 5); + + DLIB_TEST(0 - b == 0); + DLIB_TEST(b - 0 == 0); + + DLIB_TEST(0 + b == 0); + DLIB_TEST(b + 0 == 0); + + DLIB_TEST(i == 0); + DLIB_TEST(a == 5); + DLIB_TEST(b == 0); + DLIB_TEST(c == 0); + + + + a -= 5; + DLIB_TEST(a == 0); + + + + for (int k = 0; k < 100; ++k) + { + // compute the factorial of k using the O(n) multiplication algorithm + a = short_fact(k); + // compute the factorial of k using the full blown big int + // multiplication algorithm. + b = big_fact(k); + // compute the square of the factorial of k using the full blown + // big int multiplication algorithm. + c = a*b; + // make sure a and b ended up being the same number + DLIB_TEST_MSG(a == b, + "k: " << k << "\n" + "short_fact: " << a << "\n" + "big_fact: " << b + ); + // make sure c really is the square of the factorial of k + DLIB_TEST_MSG(short_fact_squared(k) == c,"k: " << k); + print_spinner(); + } + + // do the same thing as the last loop but do it with way bigger numbers + for (int k = 1000; k < 10000; k += 2000) + { + bint a = short_fact(k); + bint b = big_fact(k); + bint c = a*b; + DLIB_TEST_MSG(a == b, + "k: " << k << "\n" + "short_fact: " << a << "\n" + "big_fact: " << b + ); + DLIB_TEST_MSG(short_fact_squared(k) == c,"k: " << k); + print_spinner(); + } + + + + // test the << and >> operators a little + a = big_fact(20); + sout << a; + DLIB_TEST_MSG( sout.str() == "2432902008176640000","was: " << a); + + sin.str("684626312793279327952039475203945"); + sin >> a; + sout.str(""); + sout << a; + DLIB_TEST(sout.str() == "684626312793279327952039475203945"); + + print_spinner(); + + DLIB_TEST(a > 0); + + + // make sure that when you try to read something that isn't a number + // into a bigint you get an error + DLIB_TEST(sin.fail() == false); + sin.str("the cat ate some cheese"); + sin >> a; + DLIB_TEST(sin.fail() == true); + sin.clear(); + sin.str(""); + + + + sin.str("3628913"); + sin >> i; + DLIB_TEST(short_fact(10) + short_fact(5) - 7 == i); + + sin.str("2432902008173011193"); + sin >> i; + DLIB_TEST(short_fact(20) - short_fact(10) - 7 == i); + + // test the serialization stuff + sout.str(""); + serialize(i,sout); + i = 0; + sin.str(sout.str()); + deserialize(i,sin); + + DLIB_TEST(short_fact(20) - short_fact(10) - 7 == i); + + + + + print_spinner(); + + + + + sin.str("100000"); + sin >> b; + a = b; + ++b; + DLIB_TEST_MSG ( a + 1 == b,"a==" << a << endl << "b==" << b << endl); + + + + + + // compute some stuff and see if you get the right value + a = 0; + b = 0; + sin.str("1000000"); + sin >> b; + int mel = 0; + for (i = a; i <= b; ++i) + { + // switch it up on em + if (i%2 == 0) + a = a + i; + else + a += i; + ++mel; + if ((mel&0xFFF) == 0) + print_spinner(); + } + DLIB_TEST_MSG(a == b*(b+1)/2, "a==" << a << endl << "b*(b+1)/2==" << b*(b+1)/2 << endl); + + + + + + + print_spinner(); + + + // compute some stuff and see if you get the right value + // this time going the other way using operator-- + a = 0; + b = 0; + sin.str("100000"); + sin >> b; + i = b; + DLIB_TEST(i == b); + DLIB_TEST_MSG(i > 0,"i==" << i); + mel = 0; + for (i = b; i > 0; --i) + { + // switch it up on em + if (i%2 == 0) + a = a + i; + else + a += i; + ++mel; + if ((mel&0xFF) == 0) + print_spinner(); + } + DLIB_TEST_MSG(a == b*(b+1)/2, "a==" << a << endl << "b*(b+1)/2==" << b*(b+1)/2 << endl); + + + + + + + + + + + + DLIB_TEST(short_fact(10)/short_fact(5) == 30240); + DLIB_TEST(short_fact(10)/(short_fact(5)+1) == 29990); + + sin.str("221172909834240000"); + sin >> a; + DLIB_TEST(short_fact(20)/(short_fact(5)+1) == a/11); + + sin.str("670442388044"); + sin >> b; + DLIB_TEST(short_fact(20)/(short_fact(10)+1) == b); + + print_spinner(); + + sin.str("1860479"); + sin >> i; + DLIB_TEST_MSG(short_fact(20)/(short_fact(15)+1) == i,short_fact(20)/(short_fact(15)+1)); + + // test the serialization stuff + sout.str(""); + serialize(i,sout); + i = 0; + sin.str(sout.str()); + deserialize(i,sin); + + DLIB_TEST_MSG(short_fact(20)/(short_fact(15)+1) == i,short_fact(20)/(short_fact(15)+1)); + + + print_spinner(); + + // test the serialization stuff + sout.str(""); + i = 0; + serialize(i,sout); + i = 1234; + sin.str(sout.str()); + deserialize(i,sin); + DLIB_TEST(i == 0); + + + DLIB_TEST(short_fact(10000)/short_fact(9999) == 10000); + + + DLIB_TEST(bint(5)%bint(1) == 0); + DLIB_TEST(bint(5)%bint(6) == 5); + DLIB_TEST(bint(25)%bint(6) == 1); + print_spinner(); + DLIB_TEST(bint(354)%bint(123) == 108); + DLIB_TEST(bint(20)%(bint(10)) == 0); + DLIB_TEST(bint(20)%(bint(10)+1) == 9); + + DLIB_TEST(bint(20)%(bint(15)+1) == 4); + + + DLIB_TEST(short_fact(10)%(short_fact(5)+2) == 32); + + sin.str("2908082"); + sin >> i; + DLIB_TEST(short_fact(15)%(short_fact(10)+2) == i); + + + + + + + // same as some of the above stuff but using big_fact + + DLIB_TEST(big_fact(10)%(big_fact(5)+2) == 32); + + sin.str("2908082"); + sin >> i; + DLIB_TEST(big_fact(15)%(big_fact(10)+2) == i); + + + print_spinner(); + + + DLIB_TEST(big_fact(10)/big_fact(5) == 30240); + DLIB_TEST(big_fact(10)/(big_fact(5)+1) == 29990); + + sin.str("221172909834240000"); + sin >> a; + DLIB_TEST(big_fact(20)/(big_fact(5)+1) == a/11); + + + sin.str("670442388044"); + sin >> b; + DLIB_TEST(big_fact(20)/(big_fact(10)+1) == b); + + + sin.str("1860479"); + sin >> i; + DLIB_TEST_MSG(big_fact(20)/(big_fact(15)+1) == i,big_fact(20)/(big_fact(15)+1)); + + DLIB_TEST(big_fact(100)/big_fact(99) == 100); + + + + + sout.str(""); + sout << "148571596448176149730952273362082573788556996128468876694221686370498539309"; + sout << "4065876545992131370884059645617234469978112000000000000000000000"; + sin.str(sout.str()); + sin >> a; + + sout.str(""); + sout << "933262154439441526816992388562667004907159682643816214685929638952175999932"; + sout << "299156089414639761565182862536979208272237582511852109168640000000000000000"; + sout << "000000"; + sin.str(sout.str()); + sin >> b; + + + sout.str(""); + sout << "138656248189732152054159609718432247180282092567575172939636909224427929240"; + sout << "834642263988043338170905744175653189424779336521852536242160190545537133916"; + sout << "649622615351174407746524657461692702500613722228638559932561661493048332720"; + sout << "6050692647868232055316807680000000000000000000000000000000000000000000"; + sin.str(sout.str()); + sin >> c; + + DLIB_TEST_MSG(a*b == c, + "a*b: " << a*b << + "\nc: " << c); + + + print_spinner(); + + i = 0; + mel = 0; + unsigned long j; + for (j = 0; i < bint(100000); ++j) + { + DLIB_TEST(i++ == bint(j)); + ++mel; + if((mel&0xFF) == 0) + print_spinner(); + } + DLIB_TEST(j == 100000); + + i = 1234; + + DLIB_TEST(i == 1234); + DLIB_TEST(i < 2345 ); + DLIB_TEST(i > 0 ); + DLIB_TEST(i > 123 ); + + DLIB_TEST(i != 1334); + DLIB_TEST(i <= 2345); + DLIB_TEST(i >= 0 ); + DLIB_TEST(i >= 123 ); + DLIB_TEST(i >= 1234); + DLIB_TEST(i <= 1234); + + + DLIB_TEST(1234 == i); + DLIB_TEST(2345 > i); + DLIB_TEST(0 < i); + DLIB_TEST(123 < i); + + DLIB_TEST(1334 != i); + DLIB_TEST(2345 >= i); + DLIB_TEST(0 <= i); + DLIB_TEST(123 <= i); + DLIB_TEST(1234 <= i); + DLIB_TEST(1234 >= i); + + + a = big_fact(200); + b = big_fact(100); + + DLIB_TEST(a > b); + DLIB_TEST(a != b); + DLIB_TEST(b < a); + DLIB_TEST(b != a); + DLIB_TEST(b <= a); + DLIB_TEST(a >= b); + + + + a = 10000; + a = a*a*a*a; a = a*a; a = a*a; + b = 2; + DLIB_TEST((a/b)*b == a); + + a = 10000*5; + a = a*a*a*a; a = a*a; a = a*a; + b = 5; + DLIB_TEST((a/b)*b == a); + } + + + + + class bigint_tester : public tester + { + public: + bigint_tester ( + ) : + tester ("test_bigint", + "Runs tests on the bigint component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + bigint_kernel_test(); + print_spinner(); + + dlog << LINFO << "testing kernel_1a_c"; + bigint_kernel_test(); + print_spinner(); + + dlog << LINFO << "testing kernel_2a"; + bigint_kernel_test(); + print_spinner(); + + dlog << LINFO << "testing kernel_2a_c"; + bigint_kernel_test(); + print_spinner(); + + } + } a; + +} + diff --git a/ml/dlib/dlib/test/binary_search_tree.h b/ml/dlib/dlib/test/binary_search_tree.h new file mode 100644 index 000000000..18bdff70d --- /dev/null +++ b/ml/dlib/dlib/test/binary_search_tree.h @@ -0,0 +1,889 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_H_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_H_ + + +#include +#include +#include +#include + +#include +#include +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.binary_search_tree"); + + template < + typename bst + > + void binary_search_tree_kernel_test ( + ) + /*! + requires + - bst is an implementation of + binary_search_tree/binary_search_tree_kernel_abstract.h is instantiated + to map int to int + ensures + - runs tests on bst for compliance with the specs + !*/ + { + + bst test, test2; + + srand(static_cast(time(0))); + + + DLIB_TEST(test.count(3) == 0); + + enumerable >& e = test; + DLIB_TEST(e.at_start() == true); + + DLIB_TEST(test.count(3) == 0); + + for (int i = 0; i < 4; ++i) + { + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.count(3) == 0); + DLIB_TEST(test.height() == 0); + DLIB_TEST(test[5] == 0); + DLIB_TEST(test[0] == 0); + DLIB_TEST(test.at_start()); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.count(3) == 0); + + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.count(3) == 0); + + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + test.clear(); + test.position_enumerator(5); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + test.position_enumerator(5); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + test.position_enumerator(9); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + test.clear(); + test.position_enumerator(5); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + test.position_enumerator(5); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + test.position_enumerator(9); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + test.clear(); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + + DLIB_TEST(test.count(3) == 0); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.height() == 0); + DLIB_TEST(test[5] == 0); + DLIB_TEST(test[0] == 0); + DLIB_TEST(const_cast(test)[5] == 0); + DLIB_TEST(const_cast(test)[0] == 0); + DLIB_TEST(test.at_start()); + DLIB_TEST(test.current_element_valid() == false); + + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + + DLIB_TEST(test.count(3) == 0); + test.reset(); + DLIB_TEST(test.count(3) == 0); + + DLIB_TEST(test.at_start()); + DLIB_TEST(test.current_element_valid() == false); + + + + + + + int a = 0, b = 0; + + for (int i = 0; i < 10000; ++i) + { + a = ::rand()%1000; + int temp = a; + unsigned long count = test.count(a); + test.add(a,b); + DLIB_TEST(test.count(temp) == count+1); + } + + + { + unsigned long count = test.count(3); + + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + } + + + test.clear(); + + + + + + for (int i = 0; i < 10000; ++i) + { + a = ::rand()&0x7FFF; + b = 0; + int temp = a; + unsigned long count = test.count(a); + test.add(a,b); + DLIB_TEST(test.count(temp) == count+1); + } + + // serialize the state of test, then clear test, then + // load the state back into test. + ostringstream sout; + serialize(test,sout); + istringstream sin(sout.str()); + test.clear(); + deserialize(test,sin); + + DLIB_TEST(test.size() == 10000); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + + DLIB_TEST_MSG(test.height() > 13 && test.height() <= 26,"this is somewhat of an implementation dependent " + << "but really it should be in this range or the implementation is just crap"); + + a = 0; + unsigned long count = 0; + while (test.move_next()) + { + DLIB_TEST_MSG(a <= test.element().key(),"the numers are out of order but they should be in order"); + a = test.element().key(); + ++count; + + + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == true); + } + + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + + DLIB_TEST(count == 10000); + + + + + DLIB_TEST_MSG(test.height() > 13 && test.height() <= 26,"this is somewhat of an implementation dependent " + << "but really it should be in this range or the implementation is just crap"); + + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.size() == 10000); + + + swap(test,test2); + + + test2.reset(); + count = 0; + a = 0; + while (test2.move_next()) + { + DLIB_TEST_MSG(a <= test2.element().key(),"the numers are out of order but they should be in order"); + a = test2.element().key(); + ++count; + + + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == true); + + if (count == 5000) + { + break; + } + } + + DLIB_TEST(test2.move_next() == true); + DLIB_TEST(test2.move_next() == true); + DLIB_TEST(test2.move_next() == true); + + + test2.reset(); + + count = 0; + a = 0; + while (test2.move_next()) + { + DLIB_TEST_MSG(a <= test2.element().key(),"the numers are out of order but they should be in order"); + a = test2.element().key(); + ++count; + + + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == true); + } + + DLIB_TEST(count == 10000); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.move_next() == false); + + + + + + + + + int last = 0; + asc_pair_remover& asdf = test2; + DLIB_TEST(asdf.size() > 0); + while (asdf.size() > 0) + { + asdf.remove_any(a,b); + DLIB_TEST(last <= a); + last = a; + --count; + DLIB_TEST(asdf.size() == count); + } + + + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.height() ==0); + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.move_next() == false); + + + + + for (int i = 0; i < 10000; ++i) + { + a = i; + b = i; + test2.add(a,b); + DLIB_TEST(test2.size() == (unsigned int)(i +1)); + DLIB_TEST(test2.count(i) == 1); + } + + a = 0; + test2.position_enumerator(a); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.element().key() == a); + DLIB_TEST(test2.element().value() == a); + a = 0; + test2.position_enumerator(a); + DLIB_TEST(test2.element().key() == a); + DLIB_TEST(test2.element().value() == a); + a = 8; + test2.position_enumerator(a); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.element().key() == a); + DLIB_TEST(test2.element().value() == a); + a = 1; + test2.position_enumerator(a); + DLIB_TEST(test2.element().key() == a); + DLIB_TEST(test2.element().value() == a); + a = -29; + test2.position_enumerator(a); + DLIB_TEST(test2.element().key() == 0); + DLIB_TEST(test2.element().value() == 0); + a = 10000; + test2.position_enumerator(a); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == false); + a = -29; + test2.position_enumerator(a); + DLIB_TEST(test2.element().key() == 0); + DLIB_TEST(test2.element().value() == 0); + a = 8; + test2.position_enumerator(a); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.element().key() == a); + DLIB_TEST(test2.element().value() == a); + test2.reset(); + + + DLIB_TEST_MSG(test2.height() > 13 && test2.height() <= 26,"this is somewhat of an implementation dependent " + << "but really it should be in this range or the implementation is just crap"); + + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.size() == 10000); + + + for (int i = 0; i < 10000; ++i) + { + DLIB_TEST(test2.move_next() == true); + DLIB_TEST(test2.element().key() == i); + } + + + + DLIB_TEST_MSG(test2.height() > 13 && test2.height() <= 26,"this is somewhat of an implementation dependent " + << "but really it should be in this range or the implementation is just crap"); + + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == true); + DLIB_TEST(test2.size() == 10000); + + + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.current_element_valid() == false); + + a = 3; + test2.add(a,b); + DLIB_TEST(test2.count(3) == 2); + + + for (int i = 0; i < 10000; ++i) + { + test2.remove(i,a,b); + DLIB_TEST(i == a); + } + test2.remove(3,a,b); + + + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.height() == 0); + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == false); + + + + test2.clear(); + + + int m = 0; + for (int i = 0; i < 10000; ++i) + { + a = ::rand()&0x7FFF; + m = max(a,m); + test2.add(a,b); + } + + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.move_next() == true); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == true); + DLIB_TEST(test2.move_next() == true); + DLIB_TEST(test2.current_element_valid() == true); + DLIB_TEST(test2.move_next() == true); + DLIB_TEST(test2.current_element_valid() == true); + DLIB_TEST(test2.move_next() == true); + DLIB_TEST(test2.current_element_valid() == true); + DLIB_TEST(test2.at_start() == false); + + for (int i = 0; i < 10000; ++i) + { + a = ::rand()&0xFFFF; + test2.position_enumerator(a); + if (test2[a]) + { + DLIB_TEST(test2.element().key() == a); + } + else if (a <= m) + { + DLIB_TEST(test2.element().key() > a); + } + } + + test2.clear(); + + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + + + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.height() == 0); + + + for (int i = 0; i < 20000; ++i) + { + a = ::rand()&0x7FFF; + b = a; + test2.add(a,b); + } + + + DLIB_TEST(test2.size() == 20000); + + + + // remove a bunch of elements randomly + int c; + for (int i = 0; i < 50000; ++i) + { + a = ::rand()&0x7FFF; + if (test2[a] != 0) + { + test2.remove(a,b,c); + DLIB_TEST(a == b); + } + } + + + // now add a bunch more + for (int i = 0; i < 10000; ++i) + { + a = ::rand()&0x7FFF; + b = a; + test2.add(a,b); + } + + + // now iterate over it all and then remove all elements + { + int* array = new int[test2.size()]; + int* tmp = array; + DLIB_TEST(test2.at_start() == true); + while (test2.move_next()) + { + *tmp = test2.element().key(); + ++tmp; + } + + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + + tmp = array; + for (int i = 0; i < 10000; ++i) + { + DLIB_TEST(*test2[*tmp] == *tmp); + DLIB_TEST(*test2[*tmp] == *tmp); + DLIB_TEST(*test2[*tmp] == *tmp); + DLIB_TEST(*const_cast(test2)[*tmp] == *tmp); + ++tmp; + } + + tmp = array; + while (test2.size() > 0) + { + unsigned long count = test2.count(*tmp); + test2.destroy(*tmp); + DLIB_TEST(test2.count(*tmp)+1 == count); + ++tmp; + } + + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.at_start() == false); + test.swap(test2); + test.reset(); + + delete [] array; + } + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.height() == 0); + + for (unsigned long i = 1; i < 100; ++i) + { + a = 1234; + test.add(a,b); + DLIB_TEST(test.count(1234) == i); + } + + test.clear(); + + + + + + + for (int m = 0; m < 3; ++m) + { + + test2.clear(); + + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + + + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.height() == 0); + + + int counter = 0; + while (counter < 10000) + { + a = ::rand()&0x7FFF; + b = ::rand()&0x7FFF; + if (test2[a] == 0) + { + test2.add(a,b); + ++counter; + } + + } + + + + DLIB_TEST(test2.size() == 10000); + + + + // remove a bunch of elements randomly + for (int i = 0; i < 20000; ++i) + { + a = ::rand()&0x7FFF; + if (test2[a] != 0) + { + test2.remove(a,b,c); + DLIB_TEST(a == b); + } + } + + + // now add a bunch more + for (int i = 0; i < 20000; ++i) + { + a = ::rand()&0x7FFF; + b = ::rand()&0x7FFF; + if (test2[a] == 0) + test2.add(a,b); + } + + + // now iterate over it all and then remove all elements + { + int* array = new int[test2.size()]; + int* array_val = new int[test2.size()]; + int* tmp = array; + int* tmp_val = array_val; + DLIB_TEST(test2.at_start() == true); + int count = 0; + while (test2.move_next()) + { + *tmp = test2.element().key(); + ++tmp; + *tmp_val = test2.element().value(); + ++tmp_val; + + DLIB_TEST(*test2[*(tmp-1)] == *(tmp_val-1)); + ++count; + } + + DLIB_TEST(count == (int)test2.size()); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + + tmp = array; + tmp_val = array_val; + for (unsigned long i = 0; i < test2.size(); ++i) + { + DLIB_TEST_MSG(*test2[*tmp] == *tmp_val,i); + DLIB_TEST(*test2[*tmp] == *tmp_val); + DLIB_TEST(*test2[*tmp] == *tmp_val); + DLIB_TEST(*const_cast(test2)[*tmp] == *tmp_val); + ++tmp; + ++tmp_val; + } + + // out << "\nsize: " << test2.size() << endl; + // out << "height: " << test2.height() << endl; + + tmp = array; + while (test2.size() > 0) + { + unsigned long count = test2.count(*tmp); + test2.destroy(*tmp); + DLIB_TEST(test2.count(*tmp)+1 == count); + ++tmp; + } + + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.at_start() == false); + test.swap(test2); + test.reset(); + + delete [] array; + delete [] array_val; + } + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.height() == 0); + + for (unsigned long i = 1; i < 100; ++i) + { + a = 1234; + test.add(a,b); + DLIB_TEST(test.count(1234) == i); + } + + test.clear(); + + } + + + + a = 1; + b = 2; + + test.add(a,b); + + test.position_enumerator(0); + a = 0; + b = 0; + DLIB_TEST(test.height() == 1); + test.remove_current_element(a,b); + DLIB_TEST(a == 1); + DLIB_TEST(b == 2); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.height() == 0); + DLIB_TEST(test.size() == 0); + + + a = 1; + b = 2; + test.add(a,b); + a = 1; + b = 2; + test.add(a,b); + + test.position_enumerator(0); + a = 0; + b = 0; + DLIB_TEST(test.height() == 2); + test.remove_current_element(a,b); + DLIB_TEST(a == 1); + DLIB_TEST(b == 2); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == true); + DLIB_TEST(test.height() == 1); + DLIB_TEST(test.size() == 1); + + test.remove_current_element(a,b); + DLIB_TEST(a == 1); + DLIB_TEST(b == 2); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.height() == 0); + DLIB_TEST(test.size() == 0); + + for (int i = 0; i < 100; ++i) + { + a = i; + b = i; + test.add(a,b); + } + + DLIB_TEST(test.size() == 100); + test.remove_last_in_order(a,b); + DLIB_TEST(a == 99); + DLIB_TEST(b == 99); + DLIB_TEST(test.size() == 99); + test.remove_last_in_order(a,b); + DLIB_TEST(a == 98); + DLIB_TEST(b == 98); + DLIB_TEST(test.size() == 98); + + test.position_enumerator(-10); + for (int i = 0; i < 97; ++i) + { + DLIB_TEST(test.element().key() == i); + DLIB_TEST(test.element().value() == i); + DLIB_TEST(test.move_next()); + } + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + + + test.position_enumerator(10); + for (int i = 10; i < 97; ++i) + { + DLIB_TEST(test.element().key() == i); + DLIB_TEST(test.element().value() == i); + DLIB_TEST(test.move_next()); + } + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + + test.reset(); + DLIB_TEST(test.at_start()); + DLIB_TEST(test.current_element_valid() == false); + for (int i = 0; i < 98; ++i) + { + DLIB_TEST(test.move_next()); + DLIB_TEST(test.element().key() == i); + DLIB_TEST(test.element().value() == i); + } + DLIB_TEST_MSG(test.size() == 98, test.size()); + DLIB_TEST(test.move_next() == false); + + test.position_enumerator(98); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + + + test.position_enumerator(50); + DLIB_TEST(test.element().key() == 50); + DLIB_TEST(test.element().value() == 50); + DLIB_TEST(test[50] != 0); + test.remove_current_element(a,b); + DLIB_TEST(test[50] == 0); + DLIB_TEST_MSG(test.size() == 97, test.size()); + DLIB_TEST(a == 50); + DLIB_TEST(b == 50); + DLIB_TEST(test.element().key() == 51); + DLIB_TEST(test.element().value() == 51); + DLIB_TEST(test.current_element_valid()); + test.remove_current_element(a,b); + DLIB_TEST_MSG(test.size() == 96, test.size()); + DLIB_TEST(a == 51); + DLIB_TEST(b == 51); + DLIB_TEST_MSG(test.element().key() == 52,test.element().key()); + DLIB_TEST_MSG(test.element().value() == 52,test.element().value()); + DLIB_TEST(test.current_element_valid()); + test.remove_current_element(a,b); + DLIB_TEST_MSG(test.size() == 95, test.size()); + DLIB_TEST(a == 52); + DLIB_TEST(b == 52); + DLIB_TEST_MSG(test.element().key() == 53,test.element().key()); + DLIB_TEST_MSG(test.element().value() == 53,test.element().value()); + DLIB_TEST(test.current_element_valid()); + test.position_enumerator(50); + DLIB_TEST_MSG(test.element().key() == 53,test.element().key()); + DLIB_TEST_MSG(test.element().value() == 53,test.element().value()); + DLIB_TEST(test.current_element_valid()); + test.position_enumerator(51); + DLIB_TEST_MSG(test.element().key() == 53,test.element().key()); + DLIB_TEST_MSG(test.element().value() == 53,test.element().value()); + DLIB_TEST(test.current_element_valid()); + test.position_enumerator(52); + DLIB_TEST_MSG(test.element().key() == 53,test.element().key()); + DLIB_TEST_MSG(test.element().value() == 53,test.element().value()); + DLIB_TEST(test.current_element_valid()); + test.position_enumerator(53); + DLIB_TEST_MSG(test.element().key() == 53,test.element().key()); + DLIB_TEST_MSG(test.element().value() == 53,test.element().value()); + DLIB_TEST(test.current_element_valid()); + + test.reset(); + test.move_next(); + int lasta = -1, lastb = -1; + count = 0; + while (test.current_element_valid() ) + { + ++count; + int c = test.element().key(); + int d = test.element().value(); + test.remove_current_element(a,b); + DLIB_TEST(c == a); + DLIB_TEST(d == a); + DLIB_TEST(lasta < a); + DLIB_TEST(lastb < b); + lasta = a; + lastb = b; + } + DLIB_TEST_MSG(count == 95, count); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.height() == 0); + + test.clear(); + + for (int i = 0; i < 1000; ++i) + { + a = 1; + b = 1; + test.add(a,b); + } + + for (int i = 0; i < 40; ++i) + { + int num = ::rand()%800 + 1; + test.reset(); + for (int j = 0; j < num; ++j) + { + DLIB_TEST(test.move_next()); + } + DLIB_TEST_MSG(test.current_element_valid(),"size: " << test.size() << " num: " << num); + test.remove_current_element(a,b); + DLIB_TEST_MSG(test.current_element_valid(),"size: " << test.size() << " num: " << num); + test.remove_current_element(a,b); + test.position_enumerator(1); + if (test.current_element_valid()) + test.remove_current_element(a,b); + DLIB_TEST(a == 1); + DLIB_TEST(b == 1); + } + + test.clear(); + + } + + + test.clear(); + test2.clear(); + + } + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_H_ + diff --git a/ml/dlib/dlib/test/binary_search_tree_kernel_1a.cpp b/ml/dlib/dlib/test/binary_search_tree_kernel_1a.cpp new file mode 100644 index 000000000..b7e0b3a1a --- /dev/null +++ b/ml/dlib/dlib/test/binary_search_tree_kernel_1a.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_ + + +#include +#include +#include +#include + +#include +#include +#include +#include "tester.h" +#include "binary_search_tree.h" + +namespace +{ + + + class binary_search_tree_tester : public tester + { + + public: + binary_search_tree_tester ( + ) : + tester ("test_binary_search_tree_kernel_1a", + "Runs tests on the binary_search_tree_kernel_1a component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + binary_search_tree_kernel_test::kernel_1a>(); + print_spinner(); + + dlog << LINFO << "testing kernel_1a_c"; + binary_search_tree_kernel_test::kernel_1a_c>(); + print_spinner(); + } + } a; + +} + +#endif diff --git a/ml/dlib/dlib/test/binary_search_tree_kernel_2a.cpp b/ml/dlib/dlib/test/binary_search_tree_kernel_2a.cpp new file mode 100644 index 000000000..e2be4b143 --- /dev/null +++ b/ml/dlib/dlib/test/binary_search_tree_kernel_2a.cpp @@ -0,0 +1,45 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_ + + +#include +#include +#include +#include + +#include +#include +#include +#include "tester.h" +#include "binary_search_tree.h" + +namespace +{ + + class binary_search_tree_tester : public tester + { + public: + binary_search_tree_tester ( + ) : + tester ("test_binary_search_tree_kernel_2a", + "Runs tests on the binary_search_tree_kernel_2a component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_2a"; + binary_search_tree_kernel_test::kernel_2a>(); + print_spinner(); + + dlog << LINFO << "testing kernel_2a_c"; + binary_search_tree_kernel_test::kernel_2a_c>(); + print_spinner(); + } + } a; + +} + +#endif diff --git a/ml/dlib/dlib/test/binary_search_tree_mm1.cpp b/ml/dlib/dlib/test/binary_search_tree_mm1.cpp new file mode 100644 index 000000000..a9693bd15 --- /dev/null +++ b/ml/dlib/dlib/test/binary_search_tree_mm1.cpp @@ -0,0 +1,66 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_ + + +#include +#include +#include +#include + +#include +#include +#include +#include "tester.h" +#include "binary_search_tree.h" + +namespace +{ + + class binary_search_tree_tester : public tester + { + struct factory + { + template + struct return_type { + typedef typename memory_manager::kernel_1c type; + }; + + template + static typename return_type::type* get_instance ( + ) + { + static typename return_type::type instance; + return &instance; + } + + }; + + + public: + binary_search_tree_tester ( + ) : + tester ("test_binary_search_tree_mm1", + "Runs tests on the binary_search_tree component with memory managers.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a /w memory_manager_global"; + binary_search_tree_kernel_test::kernel_1a>::kernel_1a>(); + print_spinner(); + + + dlog << LINFO << "testing kernel_1a /w memory_manager_stateless"; + binary_search_tree_kernel_test::kernel_1a>::kernel_1a>(); + print_spinner(); + } + } a; + +} + +#endif diff --git a/ml/dlib/dlib/test/binary_search_tree_mm2.cpp b/ml/dlib/dlib/test/binary_search_tree_mm2.cpp new file mode 100644 index 000000000..354b1f91c --- /dev/null +++ b/ml/dlib/dlib/test/binary_search_tree_mm2.cpp @@ -0,0 +1,48 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_TEST_ + + +#include +#include +#include +#include + +#include +#include +#include +#include "tester.h" +#include "binary_search_tree.h" + +namespace +{ + + class binary_search_tree_tester : public tester + { + + public: + binary_search_tree_tester ( + ) : + tester ("test_binary_search_tree_mm2", + "Runs tests on the binary_search_tree component with memory managers.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a /w memory_manager_stateless_2"; + binary_search_tree_kernel_test::kernel_2_2c>::kernel_1a>(); + print_spinner(); + + dlog << LINFO << "testing kernel_1a /w memory_manager_3"; + binary_search_tree_kernel_test::kernel_3b>::kernel_1a>(); + print_spinner(); + } + } a; + +} + +#endif diff --git a/ml/dlib/dlib/test/blas_bindings/CMakeLists.txt b/ml/dlib/dlib/test/blas_bindings/CMakeLists.txt new file mode 100644 index 000000000..5deddee04 --- /dev/null +++ b/ml/dlib/dlib/test/blas_bindings/CMakeLists.txt @@ -0,0 +1,33 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + +cmake_minimum_required(VERSION 2.8.12) + +# This variable contains a list of all the tests we are building +# into the regression test suite. +set (tests + blas_bindings_gemm.cpp + blas_bindings_gemv.cpp + blas_bindings_ger.cpp + blas_bindings_dot.cpp + blas_bindings_scal_axpy.cpp + vector.cpp + ) + +# create a variable called target_name and set it to the string "test" +set (target_name dtest) + +PROJECT(${target_name}) + +# add all the cpp files we want to compile to this list. This tells +# cmake that they are part of our target (which is the executable named test) +ADD_EXECUTABLE(${target_name} ../main.cpp ../tester.cpp ${tests}) + +ADD_DEFINITIONS(-DDLIB_TEST_BLAS_BINDINGS) + +# Tell cmake to link our target executable to dlib +include(../../cmake) +TARGET_LINK_LIBRARIES(${target_name} dlib ) + diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_dot.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_dot.cpp new file mode 100644 index 000000000..0571b0685 --- /dev/null +++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_dot.cpp @@ -0,0 +1,314 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "../tester.h" +#include + +#ifndef DLIB_USE_BLAS +#error "BLAS bindings must be used for this test to make any sense" +#endif + +namespace dlib +{ + namespace blas_bindings + { + // This is a little screwy. This function is used inside the BLAS + // bindings to count how many times each of the BLAS functions get called. +#ifdef DLIB_TEST_BLAS_BINDINGS + int& counter_dot() { static int counter = 0; return counter; } +#endif + + } +} + +namespace +{ + using namespace test; + using namespace std; + // Declare the logger we will use in this test. The name of the logger + // should start with "test." + dlib::logger dlog("test.dot"); + + + class blas_bindings_dot_tester : public tester + { + public: + blas_bindings_dot_tester ( + ) : + tester ( + "test_dot", // the command line argument name for this test + "Run tests for DOT routines.", // the command line argument description + 0 // the number of command line arguments for this test + ) + {} + + void test_mat_bindings() + { + using namespace dlib; + using namespace dlib::blas_bindings; + matrix rv(10); + matrix cv(10); + double val; + + rv = 1; cv = 1; + counter_dot() = 0; + val = rv*cv; + DLIB_TEST(val == 10); + DLIB_TEST(counter_dot() == 1); + + rv = 1; cv = 1; + counter_dot() = 0; + val = rv*mat(&cv(0),cv.size()); + DLIB_TEST(val == 10); + DLIB_TEST(counter_dot() == 1); + + rv = 1; cv = 1; + counter_dot() = 0; + val = trans(mat(&rv(0),rv.size()))*mat(&cv(0),cv.size()); + DLIB_TEST(val == 10); + DLIB_TEST(counter_dot() == 1); + + std::vector sv(10,1); + rv = 1; + counter_dot() = 0; + val = trans(mat(&rv(0),rv.size()))*mat(sv); + DLIB_TEST(val == 10); + DLIB_TEST(counter_dot() == 1); + + + counter_dot() = 0; + val = trans(mat(sv))*mat(sv); + DLIB_TEST(val == 10); + DLIB_TEST(counter_dot() == 1); + + std_vector_c svc(10,1); + counter_dot() = 0; + val = trans(mat(svc))*mat(svc); + DLIB_TEST(val == 10); + DLIB_TEST(counter_dot() == 1); + + + dlib::array arr(10); + for (unsigned int i = 0; i < arr.size(); ++i) + arr[i] = 1; + counter_dot() = 0; + val = trans(mat(arr))*mat(arr); + DLIB_TEST(val == 10); + DLIB_TEST(counter_dot() == 1); + } + + template + void test_dot_stuff( + matrix_type& m, + rv_type& rv, + cv_type& cv + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + rv_type rv2; + cv_type cv2; + matrix_type m2; + typedef typename matrix_type::type scalar_type; + scalar_type val; + + counter_dot() = 0; + m2 = rv*cv; + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = rv*cv; + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = rv*3*cv; + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = rv*trans(rv)*3; + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = trans(rv*trans(rv)*3 + trans(cv)*cv); + DLIB_TEST(counter_dot() == 2); + + + counter_dot() = 0; + val = trans(cv)*cv; + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = trans(cv)*trans(rv); + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = dot(rv,cv); + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = dot(rv,colm(cv,0)); + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = dot(cv,cv); + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = dot(colm(cv,0,cv.size()),colm(cv,0)); + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = dot(rv,rv); + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = dot(rv,trans(rv)); + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = dot(trans(cv),cv); + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = dot(trans(cv),trans(rv)); + DLIB_TEST(counter_dot() == 1); + + + // This does one dot and one gemv + counter_dot() = 0; + val = trans(cv)*m*trans(rv); + DLIB_TEST_MSG(counter_dot() == 1, counter_dot()); + + // This does one dot and two gemv + counter_dot() = 0; + val = (trans(cv)*m)*(m*trans(rv)); + DLIB_TEST_MSG(counter_dot() == 1, counter_dot()); + + // This does one dot and two gemv + counter_dot() = 0; + val = trans(cv)*m*trans(m)*trans(rv); + DLIB_TEST_MSG(counter_dot() == 1, counter_dot()); + } + + + template + void test_dot_stuff_conj( + matrix_type& , + rv_type& rv, + cv_type& cv + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + rv_type rv2; + cv_type cv2; + matrix_type m2; + typedef typename matrix_type::type scalar_type; + scalar_type val; + + counter_dot() = 0; + val = conj(rv)*cv; + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = trans(conj(cv))*cv; + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = trans(conj(cv))*trans(rv); + DLIB_TEST(counter_dot() == 1); + + counter_dot() = 0; + val = trans(conj(cv))*3*trans(rv); + DLIB_TEST(counter_dot() == 1); + } + + void perform_test ( + ) + { + using namespace dlib; + typedef dlib::memory_manager::kernel_1a mm; + + dlog << dlib::LINFO << "test double"; + { + matrix m = randm(4,4); + matrix rv = randm(1,4); + matrix cv = randm(4,1); + test_dot_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test float"; + { + matrix m = matrix_cast(randm(4,4)); + matrix rv = matrix_cast(randm(1,4)); + matrix cv = matrix_cast(randm(4,1)); + test_dot_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix > m = complex_matrix(randm(4,4), randm(4,4)); + matrix,1,0> rv = complex_matrix(randm(1,4), randm(1,4)); + matrix,0,1> cv = complex_matrix(randm(4,1), randm(4,1)); + test_dot_stuff(m,rv,cv); + test_dot_stuff_conj(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix > m = matrix_cast >(complex_matrix(randm(4,4), randm(4,4))); + matrix,1,0> rv = matrix_cast >(complex_matrix(randm(1,4), randm(1,4))); + matrix,0,1> cv = matrix_cast >(complex_matrix(randm(4,1), randm(4,1))); + test_dot_stuff(m,rv,cv); + test_dot_stuff_conj(m,rv,cv); + } + + + dlog << dlib::LINFO << "test double, column major"; + { + matrix m = randm(4,4); + matrix rv = randm(1,4); + matrix cv = randm(4,1); + test_dot_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test float, column major"; + { + matrix m = matrix_cast(randm(4,4)); + matrix rv = matrix_cast(randm(1,4)); + matrix cv = matrix_cast(randm(4,1)); + test_dot_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex, column major"; + { + matrix,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4)); + matrix,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4)); + matrix,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1)); + test_dot_stuff(m,rv,cv); + test_dot_stuff_conj(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex, column major"; + { + matrix,0,0,mm,column_major_layout > m = matrix_cast >(complex_matrix(randm(4,4), randm(4,4))); + matrix,1,0,mm,column_major_layout> rv = matrix_cast >(complex_matrix(randm(1,4), randm(1,4))); + matrix,0,1,mm,column_major_layout> cv = matrix_cast >(complex_matrix(randm(4,1), randm(4,1))); + test_dot_stuff(m,rv,cv); + test_dot_stuff_conj(m,rv,cv); + } + + + test_mat_bindings(); + + print_spinner(); + } + }; + + blas_bindings_dot_tester a; + +} + + diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemm.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemm.cpp new file mode 100644 index 000000000..83d41edd1 --- /dev/null +++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemm.cpp @@ -0,0 +1,311 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "../tester.h" +#include + +#ifndef DLIB_USE_BLAS +#error "BLAS bindings must be used for this test to make any sense" +#endif + +namespace dlib +{ + namespace blas_bindings + { + // This is a little screwy. This function is used inside the BLAS + // bindings to count how many times each of the BLAS functions get called. +#ifdef DLIB_TEST_BLAS_BINDINGS + int& counter_gemm() { static int counter = 0; return counter; } +#endif + + } +} + +namespace +{ + using namespace test; + using namespace std; + // Declare the logger we will use in this test. The name of the logger + // should start with "test." + dlib::logger dlog("test.gemm"); + + + class blas_bindings_gemm_tester : public tester + { + public: + blas_bindings_gemm_tester ( + ) : + tester ( + "test_gemm", // the command line argument name for this test + "Run tests for GEMM routines.", // the command line argument description + 0 // the number of command line arguments for this test + ) + {} + + template + void test_gemm_stuff( + const matrix_type& c + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + matrix_type b, a; + a = c; + + counter_gemm() = 0; + b = a*a; + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = a/2*a; + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = a*trans(a) + a; + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = (a+a)*(a+a); + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = a*(a-a); + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = trans(a)*trans(a) + a; + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = trans(trans(trans(a)*a + a)); + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = a*a*a*a; + DLIB_TEST(counter_gemm() == 3); + b = c; + + counter_gemm() = 0; + a = a*a*a*a; + DLIB_TEST(counter_gemm() == 3); + a = c; + + counter_gemm() = 0; + a = (b + a*trans(a)*a*3*a)*trans(b); + DLIB_TEST(counter_gemm() == 4); + a = c; + + counter_gemm() = 0; + a = trans((trans(b) + trans(a)*trans(a)*a*3*a)*trans(b)); + DLIB_TEST(counter_gemm() == 4); + a = c; + + counter_gemm() = 0; + a = trans((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(b)); + DLIB_TEST(counter_gemm() == 4); + a = c; + + counter_gemm() = 0; + a = trans((trans(b) + trans(a)*(a + b)*trans(a)*3*a)*trans(b)); + DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); + a = c; + + counter_gemm() = 0; + a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*a)*trans(b)); + DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); + a = c; + } + + template + void test_gemm_stuff_conj( + const matrix_type& c + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + matrix_type b, a; + a = c; + + counter_gemm() = 0; + b = a*conj(a); + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = a*trans(conj(a)) + a; + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = conj(trans(a))*trans(a) + a; + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = trans(trans(trans(a)*conj(a) + conj(a))); + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + b = a*a*conj(a)*a; + DLIB_TEST(counter_gemm() == 3); + b = c; + + counter_gemm() = 0; + a = a*trans(conj(a))*a*a; + DLIB_TEST(counter_gemm() == 3); + a = c; + + counter_gemm() = 0; + a = (b + a*trans(conj(a))*a*3*a)*trans(b); + DLIB_TEST(counter_gemm() == 4); + a = c; + + counter_gemm() = 0; + a = (trans((conj(trans(b)) + trans(a)*conj(trans(a))*a*3*a)*trans(b))); + DLIB_TEST(counter_gemm() == 4); + a = c; + + counter_gemm() = 0; + a = ((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(conj(b))); + DLIB_TEST(counter_gemm() == 4); + a = c; + + counter_gemm() = 0; + a = trans((trans(b) + trans(a)*conj(a + b)*trans(a)*3*a)*trans(b)); + DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); + a = c; + + counter_gemm() = 0; + a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*conj(a))*trans(b)); + DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); + a = c; + } + + void perform_test ( + ) + { + using namespace dlib; + typedef dlib::memory_manager::kernel_1a mm; + + print_spinner(); + + dlog << dlib::LINFO << "test double"; + { + matrix a = randm(4,4); + test_gemm_stuff(a); + } + + print_spinner(); + dlog << dlib::LINFO << "test float"; + { + matrix a = matrix_cast(randm(4,4)); + test_gemm_stuff(a); + } + + print_spinner(); + dlog << dlib::LINFO << "test complex"; + { + matrix a = matrix_cast(randm(4,4)); + matrix b = matrix_cast(randm(4,4)); + matrix > c = complex_matrix(a,b); + test_gemm_stuff(c); + test_gemm_stuff_conj(c); + } + + print_spinner(); + dlog << dlib::LINFO << "test complex"; + { + matrix a = matrix_cast(randm(4,4)); + matrix b = matrix_cast(randm(4,4)); + matrix > c = complex_matrix(a,b); + test_gemm_stuff(c); + test_gemm_stuff_conj(c); + } + + + print_spinner(); + + dlog << dlib::LINFO << "test double, column major"; + { + matrix a = randm(100,100); + test_gemm_stuff(a); + } + + print_spinner(); + dlog << dlib::LINFO << "test float, column major"; + { + matrix a = matrix_cast(randm(100,100)); + test_gemm_stuff(a); + } + + print_spinner(); + dlog << dlib::LINFO << "test complex, column major"; + { + matrix a = matrix_cast(randm(100,100)); + matrix b = matrix_cast(randm(100,100)); + matrix,100,100,mm,column_major_layout > c = complex_matrix(a,b); + test_gemm_stuff(c); + test_gemm_stuff_conj(c); + } + + print_spinner(); + + dlog << dlib::LINFO << "test complex, column major"; + { + matrix a = matrix_cast(randm(100,100)); + matrix b = matrix_cast(randm(100,100)); + matrix,100,100,mm,column_major_layout > c = complex_matrix(a,b); + test_gemm_stuff(c); + test_gemm_stuff_conj(c); + } + + { + using namespace dlib; + using namespace dlib::blas_bindings; + array2d a(100,100); + array2d b(100,100); + matrix c; + + counter_gemm() = 0; + c = mat(a)*mat(b); + DLIB_TEST(counter_gemm() == 1); + + counter_gemm() = 0; + c = trans(2*mat(a)*mat(b)); + DLIB_TEST(counter_gemm() == 1); + } + + { + using namespace dlib; + using namespace dlib::blas_bindings; + array2d a(100,100); + array2d b(100,100); + matrix aa(100,100); + matrix bb(100,100); + matrix c; + + counter_gemm() = 0; + c = mat(&a[0][0],100,100)*mat(&b[0][0],100,100); + DLIB_TEST(counter_gemm() == 1); + set_ptrm(&c(0,0),100,100) = mat(&a[0][0],100,100)*mat(&b[0][0],100,100); + DLIB_TEST(counter_gemm() == 2); + set_ptrm(&c(0,0),100,100) = aa*bb; + DLIB_TEST(counter_gemm() == 3); + + counter_gemm() = 0; + c = trans(2*mat(&a[0][0],100,100)*mat(&b[0][0],100,100)); + DLIB_TEST(counter_gemm() == 1); + set_ptrm(&c(0,0),100,100) = trans(2*mat(&a[0][0],100,100)*mat(&b[0][0],100,100)); + DLIB_TEST(counter_gemm() == 2); + set_ptrm(&c(0,0),100,100) = trans(2*mat(a)*mat(b)); + DLIB_TEST(counter_gemm() == 3); + } + + print_spinner(); + } + }; + + blas_bindings_gemm_tester a; + +} + + diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemv.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemv.cpp new file mode 100644 index 000000000..322438313 --- /dev/null +++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_gemv.cpp @@ -0,0 +1,226 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "../tester.h" +#include + +#ifndef DLIB_USE_BLAS +#error "BLAS bindings must be used for this test to make any sense" +#endif + +namespace dlib +{ + namespace blas_bindings + { + // This is a little screwy. This function is used inside the BLAS + // bindings to count how many times each of the BLAS functions get called. +#ifdef DLIB_TEST_BLAS_BINDINGS + int& counter_gemv() { static int counter = 0; return counter; } +#endif + + } +} + +namespace +{ + using namespace test; + using namespace std; + // Declare the logger we will use in this test. The name of the logger + // should start with "test." + dlib::logger dlog("test.gemv"); + + + class blas_bindings_gemv_tester : public tester + { + public: + blas_bindings_gemv_tester ( + ) : + tester ( + "test_gemv", // the command line argument name for this test + "Run tests for GEMV routines.", // the command line argument description + 0 // the number of command line arguments for this test + ) + {} + + template + void test_gemv_stuff( + matrix_type& m, + cv_type& cv, + rv_type& rv + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + cv_type cv2; + rv_type rv2; + typedef typename matrix_type::type scalar_type; + scalar_type val; + + counter_gemv() = 0; + cv2 = m*cv; + DLIB_TEST(counter_gemv() == 1); + + counter_gemv() = 0; + cv2 = m*2*cv; + DLIB_TEST(counter_gemv() == 1); + + counter_gemv() = 0; + cv2 = m*2*trans(rv); + DLIB_TEST(counter_gemv() == 1); + + counter_gemv() = 0; + rv2 = trans(m*2*cv); + DLIB_TEST(counter_gemv() == 1); + + counter_gemv() = 0; + rv2 = rv*m; + DLIB_TEST(counter_gemv() == 1); + + counter_gemv() = 0; + rv2 = (rv + rv)*m; + DLIB_TEST(counter_gemv() == 1); + + counter_gemv() = 0; + rv2 = trans(cv)*m; + DLIB_TEST(counter_gemv() == 1); + dlog << dlib::LTRACE << 1; + + counter_gemv() = 0; + rv2 = trans(cv)*trans(m) + rv*trans(m); + DLIB_TEST(counter_gemv() == 2); + dlog << dlib::LTRACE << 2; + + counter_gemv() = 0; + cv2 = m*trans(trans(cv)*trans(m) + 3*rv*trans(m)); + DLIB_TEST(counter_gemv() == 3); + + // This does one dot and one gemv + counter_gemv() = 0; + val = trans(cv)*m*trans(rv); + DLIB_TEST_MSG(counter_gemv() == 1, counter_gemv()); + + // This does one dot and two gemv + counter_gemv() = 0; + val = (trans(cv)*m)*(m*trans(rv)); + DLIB_TEST_MSG(counter_gemv() == 2, counter_gemv()); + + // This does one dot and two gemv + counter_gemv() = 0; + val = trans(cv)*m*trans(m)*trans(rv); + DLIB_TEST_MSG(counter_gemv() == 2, counter_gemv()); + } + + + template + void test_gemv_stuff_conj( + matrix_type& m, + cv_type& cv, + rv_type& rv + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + cv_type cv2; + rv_type rv2; + + counter_gemv() = 0; + cv2 = trans(cv)*conj(m); + DLIB_TEST(counter_gemv() == 1); + + counter_gemv() = 0; + cv2 = conj(trans(m))*rv; + DLIB_TEST(counter_gemv() == 1); + + counter_gemv() = 0; + cv2 = conj(trans(m))*trans(cv); + DLIB_TEST(counter_gemv() == 1); + + counter_gemv() = 0; + cv2 = trans(trans(cv)*conj(2*m) + conj(3*trans(m))*rv + conj(trans(m)*3)*trans(cv)); + DLIB_TEST(counter_gemv() == 3); + + } + + void perform_test ( + ) + { + using namespace dlib; + typedef dlib::memory_manager::kernel_1a mm; + + dlog << dlib::LINFO << "test double"; + { + matrix m = randm(4,4); + matrix cv = randm(4,1); + matrix rv = randm(1,4); + test_gemv_stuff(m,cv,rv); + } + + dlog << dlib::LINFO << "test float"; + { + matrix m = matrix_cast(randm(4,4)); + matrix cv = matrix_cast(randm(4,1)); + matrix rv = matrix_cast(randm(1,4)); + test_gemv_stuff(m,cv,rv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix > m = complex_matrix(randm(4,4), randm(4,4)); + matrix,0,1> cv = complex_matrix(randm(4,1), randm(4,1)); + matrix,1,0> rv = complex_matrix(randm(1,4), randm(1,4)); + test_gemv_stuff(m,cv,rv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix > m = matrix_cast >(complex_matrix(randm(4,4), randm(4,4))); + matrix,0,1> cv = matrix_cast >(complex_matrix(randm(4,1), randm(4,1))); + matrix,1,0> rv = matrix_cast >(complex_matrix(randm(1,4), randm(1,4))); + test_gemv_stuff(m,cv,rv); + } + + + dlog << dlib::LINFO << "test double"; + { + matrix m = randm(4,4); + matrix cv = randm(4,1); + matrix rv = randm(1,4); + test_gemv_stuff(m,cv,rv); + } + + dlog << dlib::LINFO << "test float"; + { + matrix m = matrix_cast(randm(4,4)); + matrix cv = matrix_cast(randm(4,1)); + matrix rv = matrix_cast(randm(1,4)); + test_gemv_stuff(m,cv,rv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4)); + matrix,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1)); + matrix,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4)); + test_gemv_stuff(m,cv,rv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix,0,0,mm,column_major_layout > m = matrix_cast >(complex_matrix(randm(4,4), randm(4,4))); + matrix,0,1,mm,column_major_layout> cv = matrix_cast >(complex_matrix(randm(4,1), randm(4,1))); + matrix,1,0,mm,column_major_layout> rv = matrix_cast >(complex_matrix(randm(1,4), randm(1,4))); + test_gemv_stuff(m,cv,rv); + } + + + print_spinner(); + } + }; + + blas_bindings_gemv_tester a; + +} + + diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_ger.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_ger.cpp new file mode 100644 index 000000000..2aac834d2 --- /dev/null +++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_ger.cpp @@ -0,0 +1,200 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "../tester.h" +#include + +#ifndef DLIB_USE_BLAS +#error "BLAS bindings must be used for this test to make any sense" +#endif + +namespace dlib +{ + namespace blas_bindings + { + // This is a little screwy. This function is used inside the BLAS + // bindings to count how many times each of the BLAS functions get called. +#ifdef DLIB_TEST_BLAS_BINDINGS + int& counter_ger() { static int counter = 0; return counter; } +#endif + + } +} + +namespace +{ + using namespace test; + using namespace std; + // Declare the logger we will use in this test. The name of the logger + // should start with "test." + dlib::logger dlog("test.ger"); + + + class blas_bindings_ger_tester : public tester + { + public: + blas_bindings_ger_tester ( + ) : + tester ( + "test_ger", // the command line argument name for this test + "Run tests for GER routines.", // the command line argument description + 0 // the number of command line arguments for this test + ) + {} + + template + void test_ger_stuff( + matrix_type& m, + rv_type& rv, + cv_type& cv + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + rv_type rv2; + cv_type cv2; + matrix_type m2; + + counter_ger() = 0; + m2 = m + cv*rv; + DLIB_TEST_MSG(counter_ger() == 1, counter_ger()); + + counter_ger() = 0; + m += trans(rv)*rv; + DLIB_TEST(counter_ger() == 1); + + counter_ger() = 0; + m += trans(rv)*trans(cv); + DLIB_TEST(counter_ger() == 1); + + counter_ger() = 0; + m += cv*trans(cv); + DLIB_TEST(counter_ger() == 1); + + counter_ger() = 0; + m += trans(rv)*rv + trans(cv*3*rv); + DLIB_TEST(counter_ger() == 2); + } + + + template + void test_ger_stuff_conj( + matrix_type& m, + rv_type& rv, + cv_type& cv + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + rv_type rv2; + cv_type cv2; + matrix_type m2; + + counter_ger() = 0; + m += cv*conj(rv); + DLIB_TEST_MSG(counter_ger() == 1, counter_ger()); + + counter_ger() = 0; + m += trans(rv)*conj(rv); + DLIB_TEST(counter_ger() == 1); + + counter_ger() = 0; + m += trans(rv)*conj(trans(cv)); + DLIB_TEST(counter_ger() == 1); + + counter_ger() = 0; + m += cv*trans(conj(cv)); + DLIB_TEST(counter_ger() == 1); + + counter_ger() = 0; + m += trans(rv)*rv + trans(cv*3*conj(rv)); + DLIB_TEST(counter_ger() == 2); + } + + void perform_test ( + ) + { + using namespace dlib; + typedef dlib::memory_manager::kernel_1a mm; + + dlog << dlib::LINFO << "test double"; + { + matrix m = randm(4,4); + matrix rv = randm(1,4); + matrix cv = randm(4,1); + test_ger_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test float"; + { + matrix m = matrix_cast(randm(4,4)); + matrix rv = matrix_cast(randm(1,4)); + matrix cv = matrix_cast(randm(4,1)); + test_ger_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix > m = complex_matrix(randm(4,4), randm(4,4)); + matrix,1,0> rv = complex_matrix(randm(1,4), randm(1,4)); + matrix,0,1> cv = complex_matrix(randm(4,1), randm(4,1)); + test_ger_stuff(m,rv,cv); + test_ger_stuff_conj(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix > m = matrix_cast >(complex_matrix(randm(4,4), randm(4,4))); + matrix,1,0> rv = matrix_cast >(complex_matrix(randm(1,4), randm(1,4))); + matrix,0,1> cv = matrix_cast >(complex_matrix(randm(4,1), randm(4,1))); + test_ger_stuff(m,rv,cv); + test_ger_stuff_conj(m,rv,cv); + } + + + dlog << dlib::LINFO << "test double"; + { + matrix m = randm(4,4); + matrix rv = randm(1,4); + matrix cv = randm(4,1); + test_ger_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test float"; + { + matrix m = matrix_cast(randm(4,4)); + matrix rv = matrix_cast(randm(1,4)); + matrix cv = matrix_cast(randm(4,1)); + test_ger_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4)); + matrix,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4)); + matrix,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1)); + test_ger_stuff(m,rv,cv); + test_ger_stuff_conj(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix,0,0,mm,column_major_layout > m = matrix_cast >(complex_matrix(randm(4,4), randm(4,4))); + matrix,1,0,mm,column_major_layout> rv = matrix_cast >(complex_matrix(randm(1,4), randm(1,4))); + matrix,0,1,mm,column_major_layout> cv = matrix_cast >(complex_matrix(randm(4,1), randm(4,1))); + test_ger_stuff(m,rv,cv); + test_ger_stuff_conj(m,rv,cv); + } + + + print_spinner(); + } + }; + + blas_bindings_ger_tester a; + +} + + diff --git a/ml/dlib/dlib/test/blas_bindings/blas_bindings_scal_axpy.cpp b/ml/dlib/dlib/test/blas_bindings/blas_bindings_scal_axpy.cpp new file mode 100644 index 000000000..d1a7b99e4 --- /dev/null +++ b/ml/dlib/dlib/test/blas_bindings/blas_bindings_scal_axpy.cpp @@ -0,0 +1,261 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "../tester.h" +#include + +#ifndef DLIB_USE_BLAS +#error "BLAS bindings must be used for this test to make any sense" +#endif + +namespace dlib +{ + namespace blas_bindings + { + // This is a little screwy. This function is used inside the BLAS + // bindings to count how many times each of the BLAS functions get called. +#ifdef DLIB_TEST_BLAS_BINDINGS + int& counter_axpy() { static int counter = 0; return counter; } + int& counter_scal() { static int counter = 0; return counter; } +#endif + + } +} + +namespace +{ + using namespace test; + using namespace std; + // Declare the logger we will use in this test. The name of the logger + // should start with "test." + dlib::logger dlog("test.scal_axpy"); + + + class blas_bindings_scal_axpy_tester : public tester + { + public: + blas_bindings_scal_axpy_tester ( + ) : + tester ( + "test_scal_axpy", // the command line argument name for this test + "Run tests for DOT routines.", // the command line argument description + 0 // the number of command line arguments for this test + ) + {} + + template + void test_scal_axpy_stuff( + matrix_type& m, + rv_type& rv, + cv_type& cv + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + rv_type rv2 = rv; + cv_type cv2 = cv; + matrix_type m2 = m; + typedef typename matrix_type::type scalar_type; + scalar_type val; + + counter_scal() = 0; + m = 5*m; + DLIB_TEST(counter_scal() == 1); + + counter_scal() = 0; + rv = 5*rv; + DLIB_TEST(counter_scal() == 1); + + counter_scal() = 0; + rv = 5*rv; + DLIB_TEST(counter_scal() == 1); + + + counter_axpy() = 0; + m2 += 5*m; + DLIB_TEST(counter_axpy() == 1); + + counter_axpy() = 0; + rv2 += 5*rv; + DLIB_TEST(counter_axpy() == 1); + + counter_axpy() = 0; + rv2 += 5*rv; + DLIB_TEST(counter_axpy() == 1); + + + + counter_scal() = 0; + m = m*5; + DLIB_TEST(counter_scal() == 1); + + counter_scal() = 0; + rv = rv*5; + DLIB_TEST(counter_scal() == 1); + + counter_scal() = 0; + cv = cv*5; + DLIB_TEST(counter_scal() == 1); + + + counter_axpy() = 0; + m2 += m*5; + DLIB_TEST(counter_axpy() == 1); + + counter_axpy() = 0; + rv2 += rv*5; + DLIB_TEST(counter_axpy() == 1); + + counter_axpy() = 0; + cv2 += cv*5; + DLIB_TEST(counter_axpy() == 1); + + + + + counter_axpy() = 0; + m2 = m2 + m*5; + DLIB_TEST(counter_axpy() == 1); + + counter_axpy() = 0; + rv2 = rv2 + rv*5; + DLIB_TEST(counter_axpy() == 1); + + counter_axpy() = 0; + cv2 = cv2 + cv*5; + DLIB_TEST(counter_axpy() == 1); + + + counter_axpy() = 0; + cv2 = 1; + cv = 1; + cv2 = 2*cv2 + cv*5; + DLIB_TEST(counter_axpy() == 1); + DLIB_TEST(max(abs(cv2 - 7)) == 0); + + + counter_axpy() = 0; + rv2 = 1; + rv = 1; + rv2 = 2*rv2 + rv*5; + DLIB_TEST(counter_axpy() == 1); + DLIB_TEST(max(abs(rv2 - 7)) == 0); + + counter_axpy() = 0; + m2 = 1; + m = 1; + m2 = 2*m2 + m*5; + DLIB_TEST(counter_axpy() == 1); + DLIB_TEST(max(abs(m2 - 7)) == 0); + + + if (is_same_type::value) + { + counter_axpy() = 0; + m2 = 1; + m = 1; + set_ptrm(&m2(0,0),m2.nr(),m2.nc()) = 2*m2 + m*5; + DLIB_TEST(max(abs(m2 - 7)) == 0); + DLIB_TEST(counter_axpy() == 1); + + counter_axpy() = 0; + m2 = 1; + m = 1; + set_ptrm(&m2(0,0),m2.nr(),m2.nc()) = 2*mat(&m2(0,0),m2.nr(),m2.nc()) + mat(&m(0,0),m.nr(),m.nc())*5; + DLIB_TEST(max(abs(m2 - 7)) == 0); + DLIB_TEST(counter_axpy() == 1); + + counter_axpy() = 0; + m2 = 1; + m = 1; + m2 = 2*mat(&m2(0,0),m2.nr(),m2.nc()) + mat(&m(0,0),m.nr(),m.nc())*5; + DLIB_TEST(max(abs(m2 - 7)) == 0); + DLIB_TEST(counter_axpy() == 1); + } + + } + + + + void perform_test ( + ) + { + using namespace dlib; + typedef dlib::memory_manager::kernel_1a mm; + + dlog << dlib::LINFO << "test double"; + { + matrix m = randm(4,4); + matrix rv = randm(1,4); + matrix cv = randm(4,1); + test_scal_axpy_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test float"; + { + matrix m = matrix_cast(randm(4,4)); + matrix rv = matrix_cast(randm(1,4)); + matrix cv = matrix_cast(randm(4,1)); + test_scal_axpy_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix > m = complex_matrix(randm(4,4), randm(4,4)); + matrix,1,0> rv = complex_matrix(randm(1,4), randm(1,4)); + matrix,0,1> cv = complex_matrix(randm(4,1), randm(4,1)); + test_scal_axpy_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex"; + { + matrix > m = matrix_cast >(complex_matrix(randm(4,4), randm(4,4))); + matrix,1,0> rv = matrix_cast >(complex_matrix(randm(1,4), randm(1,4))); + matrix,0,1> cv = matrix_cast >(complex_matrix(randm(4,1), randm(4,1))); + test_scal_axpy_stuff(m,rv,cv); + } + + + dlog << dlib::LINFO << "test double, column major"; + { + matrix m = randm(4,4); + matrix rv = randm(1,4); + matrix cv = randm(4,1); + test_scal_axpy_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test float, column major"; + { + matrix m = matrix_cast(randm(4,4)); + matrix rv = matrix_cast(randm(1,4)); + matrix cv = matrix_cast(randm(4,1)); + test_scal_axpy_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex, column major"; + { + matrix,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4)); + matrix,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4)); + matrix,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1)); + test_scal_axpy_stuff(m,rv,cv); + } + + dlog << dlib::LINFO << "test complex, column major"; + { + matrix,0,0,mm,column_major_layout > m = matrix_cast >(complex_matrix(randm(4,4), randm(4,4))); + matrix,1,0,mm,column_major_layout> rv = matrix_cast >(complex_matrix(randm(1,4), randm(1,4))); + matrix,0,1,mm,column_major_layout> cv = matrix_cast >(complex_matrix(randm(4,1), randm(4,1))); + test_scal_axpy_stuff(m,rv,cv); + } + + + print_spinner(); + } + }; + + blas_bindings_scal_axpy_tester a; + +} + + diff --git a/ml/dlib/dlib/test/blas_bindings/vector.cpp b/ml/dlib/dlib/test/blas_bindings/vector.cpp new file mode 100644 index 000000000..0a6f5f301 --- /dev/null +++ b/ml/dlib/dlib/test/blas_bindings/vector.cpp @@ -0,0 +1,115 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "../tester.h" +#include +#include + +#ifndef DLIB_USE_BLAS +#error "BLAS bindings must be used for this test to make any sense" +#endif + +namespace dlib +{ + namespace blas_bindings + { + +#ifdef DLIB_TEST_BLAS_BINDINGS + extern int& counter_gemm(); + extern int& counter_gemv(); + extern int& counter_ger(); + extern int& counter_dot(); +#endif + + } +} + +namespace +{ + using namespace test; + using namespace std; + // Declare the logger we will use in this test. The name of the logger + // should start with "test." + dlib::logger dlog("test.vector"); + + + class vector_tester : public tester + { + public: + vector_tester ( + ) : + tester ( + "test_vector", // the command line argument name for this test + "Run tests on dlib::vector.", // the command line argument description + 0 // the number of command line arguments for this test + ) + {} + + template + void test_vector( + ) const + { + using namespace dlib; + using namespace dlib::blas_bindings; + + dlib::vector a2, b2, c2; + dlib::vector a3, b3, c3; + + matrix mat2(2,2), mat3(3,3); + mat2 = 0; + mat3 = 0; + + type var = 0; + + // We want to make sure that the BLAS bindings are being called for the 2D and 3D vectors. That would + // be very slow. + counter_gemm() = 0; + counter_gemv() = 0; + counter_ger() = 0; + counter_dot() = 0; + + var = trans(a2)*(a2); + var = dot(a2,a2); + + a2 = mat2*b2; + var = trans(b2)*mat2*b2; + + var = trans(a3)*(a3); + var = dot(a3,a3); + + a3 = mat3*b3; + var = trans(b3)*mat3*b3; + + mat3 = c3*trans(a3); + mat2 = c2*trans(a2); + + DLIB_TEST(counter_gemm() == 0 && counter_gemv() == 0 && counter_ger() == 0 && counter_dot() == 0); + + } + + void perform_test ( + ) + { + using namespace dlib; + + dlog << dlib::LINFO << "test double"; + test_vector(); + + dlog << dlib::LINFO << "test float"; + test_vector(); + + dlog << dlib::LINFO << "test int"; + test_vector(); + + dlog << dlib::LINFO << "test short"; + test_vector(); + + print_spinner(); + } + }; + + vector_tester a; + +} + + diff --git a/ml/dlib/dlib/test/bridge.cpp b/ml/dlib/dlib/test/bridge.cpp new file mode 100644 index 000000000..36d270c74 --- /dev/null +++ b/ml/dlib/dlib/test/bridge.cpp @@ -0,0 +1,259 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.bridge"); + + const unsigned short testing_port = 41238; + + void do_test1() + { + dlib::pipe in(0), out(0); + + bridge b1(connect_to_ip_and_port("127.0.0.1",testing_port), receive(in)); + bridge b2(listen_on_port(testing_port), transmit(out)); + + for (int i = 0; i < 100; ++i) + { + int val = i; + out.enqueue(val); + val = 0; + in.dequeue(val); + DLIB_TEST(val == i); + } + } + + void do_test2() + { + dlib::pipe in(0), out(0), echo_pipe(0); + + bridge b2(listen_on_port(testing_port), transmit(out), receive(in)); + bridge echo(connect_to_ip_and_port("127.0.0.1",testing_port), receive(echo_pipe), transmit(echo_pipe)); + + for (int i = 0; i < 100; ++i) + { + int val = i; + out.enqueue(val); + val = 0; + in.dequeue(val); + DLIB_TEST(val == i); + } + } + + void do_test3() + { + dlib::pipe in(10), out(10), echo_pipe(10); + + bridge b2(listen_on_port(testing_port), transmit(out), receive(in)); + bridge echo(connect_to_ip_and_port("127.0.0.1",testing_port), receive(echo_pipe), transmit(echo_pipe)); + + b2.reconfigure(listen_on_port(testing_port), transmit(out), receive(in)); + + for (int i = 0; i < 100; ++i) + { + int val = i; + out.enqueue(val); + val = 0; + in.dequeue(val); + DLIB_TEST(val == i); + } + } + + void do_test4() + { + dlib::pipe in(0), out(0), echo_pipe(0); + + bridge b2, echo; + b2.reconfigure(listen_on_port(testing_port), receive(in), transmit(out)); + echo.reconfigure(connect_to_ip_and_port("127.0.0.1",testing_port), transmit(echo_pipe), receive(echo_pipe)); + + for (int i = 0; i < 100; ++i) + { + int val = i; + out.enqueue(val); + val = 0; + in.dequeue(val); + DLIB_TEST(val == i); + } + } + + void do_test5(int pipe_size) + { + typedef type_safe_union tsu_type; + + dlib::pipe out(pipe_size); + dlib::pipe in(pipe_size); + dlib::pipe out_status(pipe_size); + + bridge b1(connect_to_ip_and_port("127.0.0.1",testing_port), receive(in)); + tsu_type msg; + + msg = b1.get_bridge_status(); + DLIB_TEST(msg.contains() == true); + DLIB_TEST(msg.get().is_connected == false); + DLIB_TEST(msg.get().foreign_ip == ""); + DLIB_TEST(msg.get().foreign_port == 0); + + { + bridge b2(listen_on_port(testing_port), transmit(out), receive(out_status)); + + in.dequeue(msg); + DLIB_TEST(msg.contains() == true); + DLIB_TEST(msg.get().is_connected == true); + DLIB_TEST(msg.get().foreign_ip == "127.0.0.1"); + DLIB_TEST(msg.get().foreign_port == testing_port); + msg = b1.get_bridge_status(); + DLIB_TEST(msg.contains() == true); + DLIB_TEST(msg.get().is_connected == true); + DLIB_TEST(msg.get().foreign_ip == "127.0.0.1"); + DLIB_TEST(msg.get().foreign_port == testing_port); + + bridge_status temp; + out_status.dequeue(temp); + DLIB_TEST(temp.is_connected == true); + DLIB_TEST(temp.foreign_ip == "127.0.0.1"); + + for (int i = 0; i < 100; ++i) + { + msg = i; + out.enqueue(msg); + + msg.get() = 0; + + in.dequeue(msg); + DLIB_TEST(msg.contains() == true); + DLIB_TEST(msg.get() == i); + } + + } + + in.dequeue(msg); + DLIB_TEST(msg.contains() == true); + DLIB_TEST(msg.get().is_connected == false); + DLIB_TEST(msg.get().foreign_ip == "127.0.0.1"); + DLIB_TEST(msg.get().foreign_port == testing_port); + } + + void do_test5_5(int pipe_size) + { + typedef type_safe_union tsu_type; + + dlib::pipe out(pipe_size); + dlib::pipe in(pipe_size); + dlib::pipe out_status(pipe_size); + + bridge b1(connect_to_ip_and_port("127.0.0.1",testing_port), receive(in)); + tsu_type msg; + + bridge b2(listen_on_port(testing_port), transmit(out), receive(out_status)); + + in.dequeue(msg); + DLIB_TEST(msg.contains() == true); + DLIB_TEST(msg.get().is_connected == true); + DLIB_TEST(msg.get().foreign_ip == "127.0.0.1"); + DLIB_TEST(msg.get().foreign_port == testing_port); + + bridge_status temp; + out_status.dequeue(temp); + DLIB_TEST(temp.is_connected == true); + DLIB_TEST(temp.foreign_ip == "127.0.0.1"); + + for (int i = 0; i < 100; ++i) + { + msg = i; + out.enqueue(msg); + + msg.get() = 0; + + in.dequeue(msg); + DLIB_TEST(msg.contains() == true); + DLIB_TEST(msg.get() == i); + } + + b2.clear(); + msg = b2.get_bridge_status(); + DLIB_TEST(msg.contains() == true); + DLIB_TEST(msg.get().is_connected == false); + DLIB_TEST(msg.get().foreign_ip == ""); + DLIB_TEST(msg.get().foreign_port == 0); + + in.dequeue(msg); + DLIB_TEST(msg.contains() == true); + DLIB_TEST(msg.get().is_connected == false); + DLIB_TEST(msg.get().foreign_ip == "127.0.0.1"); + DLIB_TEST(msg.get().foreign_port == testing_port); + } + + void do_test6() + { + dlib::pipe in(0), out(300); + + bridge b1(connect_to_ip_and_port("127.0.0.1",testing_port), receive(in)); + bridge b2(listen_on_port(testing_port), transmit(out)); + + for (int i = 0; i < 100; ++i) + { + int val = i; + out.enqueue(val); + } + + int val = 10; + in.dequeue(val); + DLIB_TEST(val == 0); + dlib::sleep(100); + in.dequeue(val); + DLIB_TEST(val == 1); + dlib::sleep(100); + } + + class test_bridge : public tester + { + public: + test_bridge ( + ) : + tester ("test_bridge", + "Runs tests on the bridge component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing bridge, using local port number of " << testing_port; + + print_spinner(); + do_test1(); + print_spinner(); + do_test2(); + print_spinner(); + do_test3(); + print_spinner(); + do_test4(); + print_spinner(); + for (int i = 0; i < 5; ++i) + do_test5(i); + do_test5_5(1); + print_spinner(); + do_test6(); + } + } a; + + + +} + + + diff --git a/ml/dlib/dlib/test/bsp.cpp b/ml/dlib/dlib/test/bsp.cpp new file mode 100644 index 000000000..04367c1d3 --- /dev/null +++ b/ml/dlib/dlib/test/bsp.cpp @@ -0,0 +1,566 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.bsp"); + + + template + struct callfunct_helper + { + callfunct_helper ( + funct f_, + int port_, + bool& error_occurred_ + ) :f(f_), port(port_), error_occurred(error_occurred_) {} + + funct f; + int port; + bool& error_occurred; + + void operator() ( + ) const + { + try + { + bsp_listen(port, f); + } + catch (exception& e) + { + dlog << LERROR << "error calling bsp_listen(): " << e.what(); + error_occurred = true; + } + } + }; + + template + callfunct_helper callfunct(funct f, int port, bool& error_occurred) + { + return callfunct_helper(f,port,error_occurred); + + } + +// ---------------------------------------------------------------------------------------- + + template + struct callfunct_helper_pn + { + callfunct_helper_pn ( + funct f_, + int port_, + bool& error_occurred_, + dlib::pipe& port_pipe_ + ) :f(f_), port(port_), error_occurred(error_occurred_), port_pipe(port_pipe_) {} + + funct f; + int port; + bool& error_occurred; + dlib::pipe& port_pipe; + + struct helper + { + helper ( + dlib::pipe& port_pipe_ + ) : port_pipe(port_pipe_) {} + + dlib::pipe& port_pipe; + + void operator() (unsigned short p) { port_pipe.enqueue(p); } + }; + + void operator() ( + ) const + { + try + { + bsp_listen_dynamic_port(port, helper(port_pipe), f); + } + catch (exception& e) + { + dlog << LERROR << "error calling bsp_listen_dynamic_port(): " << e.what(); + error_occurred = true; + } + } + }; + + template + callfunct_helper_pn callfunct(funct f, int port, bool& error_occurred, dlib::pipe& port_pipe) + { + return callfunct_helper_pn(f,port,error_occurred,port_pipe); + } + +// ---------------------------------------------------------------------------------------- + + void sum_array_driver ( + bsp_context& obj, + const std::vector& v, + int& result + ) + { + obj.broadcast(v); + + result = 0; + int val; + while(obj.try_receive(val)) + result += val; + } + + void sum_array_other ( + bsp_context& obj + ) + { + std::vector v; + obj.receive(v); + + int sum = 0; + for (unsigned long i = 0; i < v.size(); ++i) + sum += v[i]; + + obj.send(sum, 0); + + + } + + + void dotest1() + { + dlog << LINFO << "start dotest1()"; + print_spinner(); + bool error_occurred = false; + { + thread_function t1(callfunct(sum_array_other, 12345, error_occurred)); + thread_function t2(callfunct(sum_array_other, 12346, error_occurred)); + thread_function t3(callfunct(sum_array_other, 12347, error_occurred)); + std::vector v; + int true_value = 0; + for (int i = 0; i < 10; ++i) + { + v.push_back(i); + true_value += i; + } + + // wait a little bit for the threads to start up + dlib::sleep(200); + + try + { + int result; + std::vector hosts; + hosts.push_back("127.0.0.1:12345"); + hosts.push_back("localhost:12346"); + hosts.push_back("127.0.0.1:12347"); + bsp_connect(hosts, sum_array_driver, dlib::ref(v), dlib::ref(result)); + + dlog << LINFO << "result: "<< result; + dlog << LINFO << "should be: "<< 3*true_value; + DLIB_TEST(result == 3*true_value); + } + catch (std::exception& e) + { + dlog << LERROR << "error during bsp_context: " << e.what(); + DLIB_TEST(false); + } + } + DLIB_TEST(error_occurred == false); + } + +// ---------------------------------------------------------------------------------------- + + template + void test2_job(bsp_context& obj) + { + if (obj.node_id() == id) + dlib::sleep(100); + } + + template + void dotest2() + { + dlog << LINFO << "start dotest2()"; + print_spinner(); + bool error_occurred = false; + { + thread_function t1(callfunct(test2_job, 12345, error_occurred)); + thread_function t2(callfunct(test2_job, 12346, error_occurred)); + thread_function t3(callfunct(test2_job, 12347, error_occurred)); + + // wait a little bit for the threads to start up + dlib::sleep(200); + + try + { + std::vector hosts; + hosts.push_back("127.0.0.1:12345"); + hosts.push_back("127.0.0.1:12346"); + hosts.push_back("127.0.0.1:12347"); + bsp_connect(hosts, test2_job); + } + catch (std::exception& e) + { + dlog << LERROR << "error during bsp_context: " << e.what(); + DLIB_TEST(false); + } + + } + DLIB_TEST(error_occurred == false); + } + +// ---------------------------------------------------------------------------------------- + + void test3_job_driver(bsp_context& obj, int& result) + { + + obj.broadcast(obj.node_id()); + + int accum = 0; + int temp = 0; + while(obj.try_receive(temp)) + accum += temp; + + // send to node 1 so it can sum everything + if (obj.node_id() != 1) + obj.send(accum, 1); + + while(obj.try_receive(temp)) + accum += temp; + + // Now hop the accum values along the nodes until the value from node 1 gets to + // node 0. + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + obj.receive(accum); + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + obj.receive(accum); + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + obj.receive(accum); + + // this whole block is a noop since it doesn't end up doing anything. + for (int k = 0; k < 100; ++k) + { + dlog << LINFO << "k: " << k; + for (int i = 0; i < 4; ++i) + { + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + obj.receive(accum); + } + } + + + dlog << LINFO << "TERMINATE"; + if (obj.node_id() == 0) + result = accum; + } + + + void test3_job(bsp_context& obj) + { + int junk; + test3_job_driver(obj, junk); + } + + + void dotest3() + { + dlog << LINFO << "start dotest3()"; + print_spinner(); + bool error_occurred = false; + { + dlib::pipe ports(5); + thread_function t1(callfunct(test3_job, 12345, error_occurred, ports)); + thread_function t2(callfunct(test3_job, 0, error_occurred, ports)); + thread_function t3(callfunct(test3_job, 12347, error_occurred, ports)); + + + try + { + std::vector hosts; + unsigned short port; + ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port; + ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port; + ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port; + int result = 0; + const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2; + bsp_connect(hosts, test3_job_driver, dlib::ref(result)); + + dlog << LINFO << "result: " << result; + dlog << LINFO << "should be: " << expected; + DLIB_TEST(result == expected); + } + catch (std::exception& e) + { + dlog << LERROR << "error during bsp_context: " << e.what(); + DLIB_TEST(false); + } + + } + DLIB_TEST(error_occurred == false); + } + +// ---------------------------------------------------------------------------------------- + + void test4_job_driver(bsp_context& obj, int& result) + { + + obj.broadcast(obj.node_id()); + + int accum = 0; + int temp = 0; + while(obj.try_receive(temp)) + accum += temp; + + // send to node 1 so it can sum everything + if (obj.node_id() != 1) + obj.send(accum, 1); + + while(obj.try_receive(temp)) + accum += temp; + + // Now hop the accum values along the nodes until the value from node 1 gets to + // node 0. + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + obj.receive(accum); + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + obj.receive(accum); + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + obj.receive(accum); + + // this whole block is a noop since it doesn't end up doing anything. + for (int k = 0; k < 40; ++k) + { + dlog << LINFO << "k: " << k; + for (int i = 0; i < 4; ++i) + { + obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes()); + obj.receive(accum); + + obj.receive(); + } + } + + + dlog << LINFO << "TERMINATE"; + if (obj.node_id() == 0) + result = accum; + } + + + void test4_job(bsp_context& obj) + { + int junk; + test4_job_driver(obj, junk); + } + + + void dotest4() + { + dlog << LINFO << "start dotest4()"; + print_spinner(); + bool error_occurred = false; + { + dlib::pipe ports(5); + thread_function t1(callfunct(test4_job, 0, error_occurred, ports)); + thread_function t2(callfunct(test4_job, 0, error_occurred, ports)); + thread_function t3(callfunct(test4_job, 0, error_occurred, ports)); + + + try + { + std::vector hosts; + unsigned short port; + ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port; + ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port; + ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port; + int result = 0; + const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2; + bsp_connect(hosts, test4_job_driver, dlib::ref(result)); + + dlog << LINFO << "result: " << result; + dlog << LINFO << "should be: " << expected; + DLIB_TEST(result == expected); + } + catch (std::exception& e) + { + dlog << LERROR << "error during bsp_context: " << e.what(); + DLIB_TEST(false); + } + + } + DLIB_TEST(error_occurred == false); + } + +// ---------------------------------------------------------------------------------------- + + void test5_job( + bsp_context& , + int& val + ) + { + val = 25; + } + + void dotest5() + { + dlog << LINFO << "start dotest5()"; + print_spinner(); + std::vector hosts; + int val = 0; + bsp_connect(hosts, test5_job, dlib::ref(val)); + DLIB_TEST(val == 25); + } + +// ---------------------------------------------------------------------------------------- + + double f ( double x) + { + return std::pow(x-2.0, 2.0); + } + + + void bsp_job_node_0 ( + bsp_context& context, + double& min_value, + double& optimal_x + ) + { + double left = -100; + double right = 100; + + min_value = std::numeric_limits::infinity(); + double interval_width = std::abs(right-left); + + // This is doing a BSP based grid search for the minimum of f(). Here we + // do 100 iterations where we keep shrinking the grid size. + for (int i = 0; i < 100; ++i) + { + context.broadcast(left); + context.broadcast(right); + + for (unsigned int k = 1; k < context.number_of_nodes(); ++k) + { + std::pair val; + context.receive(val); + if (val.second < min_value) + { + min_value = val.second; + optimal_x = val.first; + } + } + + interval_width *= 0.5; + left = optimal_x - interval_width/2; + right = optimal_x + interval_width/2; + } + } + + + void bsp_job_other_nodes ( + bsp_context& context + ) + { + double left, right; + while (context.try_receive(left)) + { + context.receive(right); + + const double l = (context.node_id()-1)/(context.number_of_nodes()-1.0); + const double r = context.node_id() /(context.number_of_nodes()-1.0); + + const double width = right-left; + matrix values_to_check = linspace(left +l*width, left + r*width, 100); + + double best_x = 0; + double best_val = std::numeric_limits::infinity(); + for (long j = 0; j < values_to_check.size(); ++j) + { + double temp = f(values_to_check(j)); + if (temp < best_val) + { + best_val = temp; + best_x = values_to_check(j); + } + } + + context.send(make_pair(best_x, best_val), 0); + } + } + + void dotest6() + { + dlog << LINFO << "start dotest6()"; + print_spinner(); + bool error_occurred = false; + { + dlib::pipe ports(5); + thread_function t1(callfunct(bsp_job_other_nodes, 0, error_occurred, ports)); + thread_function t2(callfunct(bsp_job_other_nodes, 0, error_occurred, ports)); + thread_function t3(callfunct(bsp_job_other_nodes, 0, error_occurred, ports)); + + + try + { + std::vector hosts; + unsigned short port; + ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port; + ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port; + ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port; + double min_value = 10, optimal_x = 0; + bsp_connect(hosts, bsp_job_node_0, dlib::ref(min_value), dlib::ref(optimal_x)); + + dlog << LINFO << "min_value: " << min_value; + dlog << LINFO << "optimal_x: " << optimal_x; + DLIB_TEST(std::abs(min_value - 0) < 1e-14); + DLIB_TEST(std::abs(optimal_x - 2) < 1e-14); + } + catch (std::exception& e) + { + dlog << LERROR << "error during bsp_context: " << e.what(); + DLIB_TEST(false); + } + + } + DLIB_TEST(error_occurred == false); + } +// ---------------------------------------------------------------------------------------- + + class bsp_tester : public tester + { + + public: + bsp_tester ( + ) : + tester ("test_bsp", + "Runs tests on the BSP components.") + {} + + void perform_test ( + ) + { + for (int i = 0; i < 3; ++i) + { + dotest1(); + dotest2<0>(); + dotest2<1>(); + dotest2<2>(); + dotest3(); + dotest4(); + dotest5(); + dotest6(); + } + } + } a; + +} + diff --git a/ml/dlib/dlib/test/byte_orderer.cpp b/ml/dlib/dlib/test/byte_orderer.cpp new file mode 100644 index 000000000..7200c1b4a --- /dev/null +++ b/ml/dlib/dlib/test/byte_orderer.cpp @@ -0,0 +1,111 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.byte_orderer"); + + + class byte_orderer_tester : public tester + { + public: + byte_orderer_tester ( + ) : + tester ("test_byte_orderer", + "Runs tests on the byte_orderer component.") + {} + + void perform_test ( + ) + { + byte_orderer bo; + + union data + { + unsigned char b[4]; + dlib::uint32 val; + }; + + data a; + + a.val = 1; + + if (bo.host_is_little_endian()) + { + DLIB_TEST(a.b[0] == 1); + DLIB_TEST(a.b[1] == 0); + DLIB_TEST(a.b[2] == 0); + DLIB_TEST(a.b[3] == 0); + + bo.host_to_big(a.val); + + DLIB_TEST(a.b[0] == 0); + DLIB_TEST(a.b[1] == 0); + DLIB_TEST(a.b[2] == 0); + DLIB_TEST(a.b[3] == 1); + + bo.big_to_host(a.val); + + DLIB_TEST(a.b[0] == 1); + DLIB_TEST(a.b[1] == 0); + DLIB_TEST(a.b[2] == 0); + DLIB_TEST(a.b[3] == 0); + + DLIB_TEST(a.val == 1); + bo.host_to_network(a.val); + DLIB_TEST(a.val == 0x01000000); + bo.network_to_host(a.val); + DLIB_TEST(a.val == 1); + } + else + { + DLIB_TEST(a.b[0] == 0); + DLIB_TEST(a.b[1] == 0); + DLIB_TEST(a.b[2] == 0); + DLIB_TEST(a.b[3] == 1); + + bo.host_to_little(a.val); + + DLIB_TEST(a.b[0] == 1); + DLIB_TEST(a.b[1] == 0); + DLIB_TEST(a.b[2] == 0); + DLIB_TEST(a.b[3] == 0); + + bo.little_to_host(a.val); + + DLIB_TEST(a.b[0] == 0); + DLIB_TEST(a.b[1] == 0); + DLIB_TEST(a.b[2] == 0); + DLIB_TEST(a.b[3] == 1); + + + DLIB_TEST(a.val == 1); + bo.network_to_host(a.val); + DLIB_TEST(a.val == 1); + bo.host_to_network(a.val); + DLIB_TEST(a.val == 1); + + } + + + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/cca.cpp b/ml/dlib/dlib/test/cca.cpp new file mode 100644 index 000000000..a2014121d --- /dev/null +++ b/ml/dlib/dlib/test/cca.cpp @@ -0,0 +1,460 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.cca"); + + dlib::rand rnd; +// ---------------------------------------------------------------------------------------- + + /* + std::vector > make_really_big_test_matrix ( + ) + { + std::vector > temp(30000); + for (unsigned long i = 0; i < temp.size(); ++i) + { + for (int k = 0; k < 30; ++k) + temp[i][rnd.get_random_32bit_number()%10000] = 1; + } + return temp; + } + */ + + template + std::vector > mat_to_sparse ( + const matrix& A + ) + { + std::vector > temp(A.nr()); + for (long r = 0; r < A.nr(); ++r) + { + for (long c = 0; c < A.nc(); ++c) + { + temp[r][c] = A(r,c); + } + } + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template + matrix rm_zeros ( + const matrix_exp& m + ) + { + // Do this to avoid trying to correlate super small numbers that are really just + // zero. Doing this avoids some potential false alarms in the unit tests below. + return round_zeros(m, max(abs(m))*1e-14); + } + +// ---------------------------------------------------------------------------------------- + + /* + void check_correlation ( + matrix L, + matrix R, + const matrix& Ltrans, + const matrix& Rtrans, + const matrix& correlations + ) + { + // apply the transforms + L = L*Ltrans; + R = R*Rtrans; + + // compute the real correlation values. Store them in A. + matrix A = compute_correlations(L, R); + + for (long i = 0; i < correlations.size(); ++i) + { + // compare what the measured correlation values are (in A) to the + // predicted values. + cout << "error: "<< A(i) - correlations(i); + } + } + */ + +// ---------------------------------------------------------------------------------------- + + void test_cca3() + { + print_spinner(); + const unsigned long rank = rnd.get_random_32bit_number()%10 + 1; + const unsigned long m = rank + rnd.get_random_32bit_number()%15; + const unsigned long n = rank + rnd.get_random_32bit_number()%15; + const unsigned long n2 = rank + rnd.get_random_32bit_number()%15; + const unsigned long rank2 = rank + rnd.get_random_32bit_number()%5; + + dlog << LINFO << "m: " << m; + dlog << LINFO << "n: " << n; + dlog << LINFO << "n2: " << n2; + dlog << LINFO << "rank: " << rank; + dlog << LINFO << "rank2: " << rank2; + + + matrix L = randm(m,rank, rnd)*randm(rank,n, rnd); + //matrix R = randm(m,rank, rnd)*randm(rank,n2, rnd); + matrix R = L*randm(n,n2, rnd); + //matrix L = randm(m,n, rnd); + //matrix R = randm(m,n2, rnd); + + matrix Ltrans, Rtrans; + matrix correlations; + + { + correlations = cca(L, R, Ltrans, Rtrans, min(m,n), max(n,n2)); + DLIB_TEST(Ltrans.nc() == Rtrans.nc()); + dlog << LINFO << "correlations: "<< trans(correlations); + + const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations)); + dlog << LINFO << "correlation error: "<< corr_error; + DLIB_TEST_MSG(corr_error < 1e-13, Ltrans << "\n\n" << Rtrans); + + const double trans_error = max(abs(L*Ltrans - R*Rtrans)); + dlog << LINFO << "trans_error: "<< trans_error; + DLIB_TEST_MSG(trans_error < 1e-9, trans_error); + } + { + correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, min(m,n), max(n,n2)+6, 4); + DLIB_TEST(Ltrans.nc() == Rtrans.nc()); + dlog << LINFO << "correlations: "<< trans(correlations); + dlog << LINFO << "computed cors: " << trans(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans))); + + const double trans_error = max(abs(L*Ltrans - R*Rtrans)); + dlog << LINFO << "trans_error: "<< trans_error; + const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations)); + dlog << LINFO << "correlation error: "<< corr_error; + DLIB_TEST_MSG(corr_error < 1e-13, Ltrans << "\n\n" << Rtrans); + + DLIB_TEST(trans_error < 2e-9); + } + + dlog << LINFO << "*****************************************************"; + } + + void test_cca2() + { + print_spinner(); + const unsigned long rank = rnd.get_random_32bit_number()%10 + 1; + const unsigned long m = rank + rnd.get_random_32bit_number()%15; + const unsigned long n = rank + rnd.get_random_32bit_number()%15; + const unsigned long n2 = rank + rnd.get_random_32bit_number()%15; + + dlog << LINFO << "m: " << m; + dlog << LINFO << "n: " << n; + dlog << LINFO << "n2: " << n2; + dlog << LINFO << "rank: " << rank; + + + matrix L = randm(m,n, rnd); + matrix R = randm(m,n2, rnd); + + matrix Ltrans, Rtrans; + matrix correlations; + + { + correlations = cca(L, R, Ltrans, Rtrans, min(n,n2), max(n,n2)-min(n,n2)); + DLIB_TEST(Ltrans.nc() == Rtrans.nc()); + dlog << LINFO << "correlations: "<< trans(correlations); + + if (Ltrans.nc() > 1) + { + // The CCA projection directions are supposed to be uncorrelated for + // non-matching pairs of projections. + const double corr_rot1_error = max(abs(compute_correlations(rm_zeros(L*rotate<0,1>(Ltrans)), rm_zeros(R*Rtrans)))); + dlog << LINFO << "corr_rot1_error: "<< corr_rot1_error; + DLIB_TEST(std::abs(corr_rot1_error) < 1e-10); + } + // Matching projection directions should be correlated with the amount of + // correlation indicated by the return value of cca(). + const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations)); + dlog << LINFO << "correlation error: "<< corr_error; + DLIB_TEST(corr_error < 1e-13); + } + { + correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, min(n,n2), max(n,n2)-min(n,n2)); + DLIB_TEST(Ltrans.nc() == Rtrans.nc()); + dlog << LINFO << "correlations: "<< trans(correlations); + + if (Ltrans.nc() > 1) + { + // The CCA projection directions are supposed to be uncorrelated for + // non-matching pairs of projections. + const double corr_rot1_error = max(abs(compute_correlations(rm_zeros(L*rotate<0,1>(Ltrans)), rm_zeros(R*Rtrans)))); + dlog << LINFO << "corr_rot1_error: "<< corr_rot1_error; + DLIB_TEST(std::abs(corr_rot1_error) < 1e-10); + } + // Matching projection directions should be correlated with the amount of + // correlation indicated by the return value of cca(). + const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations)); + dlog << LINFO << "correlation error: "<< corr_error; + DLIB_TEST(corr_error < 1e-13); + } + + dlog << LINFO << "*****************************************************"; + } + + void test_cca1() + { + print_spinner(); + const unsigned long rank = rnd.get_random_32bit_number()%10 + 1; + const unsigned long m = rank + rnd.get_random_32bit_number()%15; + const unsigned long n = rank + rnd.get_random_32bit_number()%15; + + dlog << LINFO << "m: " << m; + dlog << LINFO << "n: " << n; + dlog << LINFO << "rank: " << rank; + + matrix T = randm(n,n, rnd); + + matrix L = randm(m,rank, rnd)*randm(rank,n, rnd); + //matrix L = randm(m,n, rnd); + matrix R = L*T; + + matrix Ltrans, Rtrans; + matrix correlations; + + { + correlations = cca(L, R, Ltrans, Rtrans, rank); + DLIB_TEST(Ltrans.nc() == Rtrans.nc()); + if (Ltrans.nc() > 1) + { + // The CCA projection directions are supposed to be uncorrelated for + // non-matching pairs of projections. + const double corr_rot1_error = max(abs(compute_correlations(rm_zeros(L*rotate<0,1>(Ltrans)), rm_zeros(R*Rtrans)))); + dlog << LINFO << "corr_rot1_error: "<< corr_rot1_error; + DLIB_TEST(std::abs(corr_rot1_error) < 2e-9); + } + // Matching projection directions should be correlated with the amount of + // correlation indicated by the return value of cca(). + const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations)); + dlog << LINFO << "correlation error: "<< corr_error; + DLIB_TEST(corr_error < 1e-13); + + const double trans_error = max(abs(L*Ltrans - R*Rtrans)); + dlog << LINFO << "trans_error: "<< trans_error; + DLIB_TEST(trans_error < 2e-9); + + dlog << LINFO << "correlations: "<< trans(correlations); + } + { + correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, rank); + DLIB_TEST(Ltrans.nc() == Rtrans.nc()); + if (Ltrans.nc() > 1) + { + // The CCA projection directions are supposed to be uncorrelated for + // non-matching pairs of projections. + const double corr_rot1_error = max(abs(compute_correlations(rm_zeros(L*rotate<0,1>(Ltrans)), rm_zeros(R*Rtrans)))); + dlog << LINFO << "corr_rot1_error: "<< corr_rot1_error; + DLIB_TEST(std::abs(corr_rot1_error) < 2e-9); + } + // Matching projection directions should be correlated with the amount of + // correlation indicated by the return value of cca(). + const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations)); + dlog << LINFO << "correlation error: "<< corr_error; + DLIB_TEST(corr_error < 1e-13); + + const double trans_error = max(abs(L*Ltrans - R*Rtrans)); + dlog << LINFO << "trans_error: "<< trans_error; + DLIB_TEST(trans_error < 2e-9); + + dlog << LINFO << "correlations: "<< trans(correlations); + } + + dlog << LINFO << "*****************************************************"; + } + +// ---------------------------------------------------------------------------------------- + + void test_svd_fast( + long rank, + long m, + long n + ) + { + print_spinner(); + matrix A = randm(m,rank,rnd)*randm(rank,n,rnd); + matrix u,v; + matrix w; + + dlog << LINFO << "rank: "<< rank; + dlog << LINFO << "m: "<< m; + dlog << LINFO << "n: "<< n; + + svd_fast(A, u, w, v, rank, 2); + DLIB_TEST(u.nr() == m); + DLIB_TEST(u.nc() == rank); + DLIB_TEST(w.nr() == rank); + DLIB_TEST(w.nc() == 1); + DLIB_TEST(v.nr() == n); + DLIB_TEST(v.nc() == rank); + DLIB_TEST(max(abs(trans(u)*u - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(trans(v)*v - identity_matrix(u.nc()))) < 1e-13); + + DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-13); + svd_fast(mat_to_sparse(A), u, w, v, rank, 2); + DLIB_TEST(u.nr() == m); + DLIB_TEST(u.nc() == rank); + DLIB_TEST(w.nr() == rank); + DLIB_TEST(w.nc() == 1); + DLIB_TEST(v.nr() == n); + DLIB_TEST(v.nc() == rank); + DLIB_TEST(max(abs(trans(u)*u - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(trans(v)*v - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-13); + + svd_fast(A, u, w, v, rank, 0); + DLIB_TEST(u.nr() == m); + DLIB_TEST(u.nc() == rank); + DLIB_TEST(w.nr() == rank); + DLIB_TEST(w.nc() == 1); + DLIB_TEST(v.nr() == n); + DLIB_TEST(v.nc() == rank); + DLIB_TEST(max(abs(trans(u)*u - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(trans(v)*v - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST_MSG(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-9,max(abs(tmp(A - u*diagm(w)*trans(v))))); + svd_fast(mat_to_sparse(A), u, w, v, rank, 0); + DLIB_TEST(u.nr() == m); + DLIB_TEST(u.nc() == rank); + DLIB_TEST(w.nr() == rank); + DLIB_TEST(w.nc() == 1); + DLIB_TEST(v.nr() == n); + DLIB_TEST(v.nc() == rank); + DLIB_TEST(max(abs(trans(u)*u - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(trans(v)*v - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-10); + + svd_fast(A, u, w, v, rank+5, 0); + DLIB_TEST(max(abs(trans(u)*u - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(trans(v)*v - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-11); + svd_fast(mat_to_sparse(A), u, w, v, rank+5, 0); + DLIB_TEST(max(abs(trans(u)*u - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(trans(v)*v - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-11); + svd_fast(A, u, w, v, rank+5, 1); + DLIB_TEST(max(abs(trans(u)*u - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(trans(v)*v - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-12); + svd_fast(mat_to_sparse(A), u, w, v, rank+5, 1); + DLIB_TEST(max(abs(trans(u)*u - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(trans(v)*v - identity_matrix(u.nc()))) < 1e-13); + DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-12); + } + + void test_svd_fast() + { + for (int iter = 0; iter < 1000; ++iter) + { + const unsigned long rank = rnd.get_random_32bit_number()%10 + 1; + const unsigned long m = rank + rnd.get_random_32bit_number()%10; + const unsigned long n = rank + rnd.get_random_32bit_number()%10; + + test_svd_fast(rank, m, n); + + } + test_svd_fast(1, 1, 1); + test_svd_fast(1, 2, 2); + test_svd_fast(1, 1, 2); + test_svd_fast(1, 2, 1); + } + +// ---------------------------------------------------------------------------------------- + + /* + typedef std::vector> sv; + sv rand_sparse_vector() + { + static dlib::rand rnd; + sv v; + for (int i = 0; i < 50; ++i) + v.push_back(make_pair(rnd.get_integer(400000), rnd.get_random_gaussian()*100)); + + make_sparse_vector_inplace(v); + return v; + } + + sv rand_basis_combo(const std::vector& basis) + { + static dlib::rand rnd; + sv result; + + for (int i = 0; i < 5; ++i) + { + sv temp = basis[rnd.get_integer(basis.size())]; + scale_by(temp, rnd.get_random_gaussian()); + result = add(result,temp); + } + return result; + } + + void big_sparse_speed_test() + { + cout << "making A" << endl; + std::vector basis; + for (int i = 0; i < 100; ++i) + basis.emplace_back(rand_sparse_vector()); + + std::vector A; + for (int i = 0; i < 500000; ++i) + A.emplace_back(rand_basis_combo(basis)); + + cout << "done making A" << endl; + + matrix u,v; + matrix w; + { + timing::block aosijdf(0,"call it"); + svd_fast(A, u,w,v, 100, 5); + } + + timing::print(); + } + */ + +// ---------------------------------------------------------------------------------------- + + class test_cca : public tester + { + public: + test_cca ( + ) : + tester ("test_cca", + "Runs tests on the cca() and svd_fast() routines.") + {} + + void perform_test ( + ) + { + //big_sparse_speed_test(); + for (int i = 0; i < 200; ++i) + { + test_cca1(); + test_cca2(); + test_cca3(); + } + test_svd_fast(); + } + } a; + + + +} + + + + diff --git a/ml/dlib/dlib/test/checkerboard.h b/ml/dlib/dlib/test/checkerboard.h new file mode 100644 index 000000000..90c8f0cb8 --- /dev/null +++ b/ml/dlib/dlib/test/checkerboard.h @@ -0,0 +1,55 @@ +#ifndef DLIB_CHECKERBOARD_TeST_H_ +#define DLIB_CHECKERBOARD_TeST_H_ + +#include +#include +#include + +namespace dlib +{ + + template + void get_checkerboard_problem ( + std::vector >& x, + std::vector& y, + const long num_samples, + const long board_dimension = 8 + ) + /*! + requires + - num_samples > 0 + - board_dimension > 0 + ensures + - #x.size() == y.size() == num_samples + - is_binary_classification_problem(#x,#y) == true + - #x will contain points and #y labels that were + sampled randomly from a checkers board that has + board_dimension squares on each side. + !*/ + { + static dlib::rand rnd; + + x.clear(); + y.clear(); + + matrix sample; + for (long i = 0; i < num_samples; ++i) + { + sample(0) = rnd.get_random_double(); + sample(1) = rnd.get_random_double(); + sample *= board_dimension; + + x.push_back(sample); + if (((int)sum(floor(sample)) %2) == 0) + y.push_back(+1); + else + y.push_back(-1); + + } + } + + +} + +#endif // DLIB_CHECKERBOARD_TeST_H_ + diff --git a/ml/dlib/dlib/test/clustering.cpp b/ml/dlib/dlib/test/clustering.cpp new file mode 100644 index 000000000..a784c57c0 --- /dev/null +++ b/ml/dlib/dlib/test/clustering.cpp @@ -0,0 +1,410 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.clustering"); + +// ---------------------------------------------------------------------------------------- + + void make_test_graph( + dlib::rand& rnd, + std::vector& edges, + std::vector& labels, + const int groups, + const int group_size, + const int noise_level, + const double missed_edges + ) + { + labels.resize(groups*group_size); + + for (unsigned long i = 0; i < labels.size(); ++i) + { + labels[i] = i/group_size; + } + + edges.clear(); + for (int i = 0; i < groups; ++i) + { + for (int j = 0; j < group_size; ++j) + { + for (int k = 0; k < group_size; ++k) + { + if (j == k) + continue; + + if (rnd.get_random_double() < missed_edges) + continue; + + edges.push_back(sample_pair(j+group_size*i, k+group_size*i, 1)); + } + } + } + + for (int k = 0; k < groups*noise_level; ++k) + { + const int i = rnd.get_random_32bit_number()%labels.size(); + const int j = rnd.get_random_32bit_number()%labels.size(); + edges.push_back(sample_pair(i,j,1)); + } + + } + +// ---------------------------------------------------------------------------------------- + + void make_modularity_matrices ( + const std::vector& edges, + matrix& A, + matrix& P, + double& m + ) + { + const unsigned long num_nodes = max_index_plus_one(edges); + A.set_size(num_nodes, num_nodes); + P.set_size(num_nodes, num_nodes); + A = 0; + P = 0; + std::vector k(num_nodes,0); + + for (unsigned long i = 0; i < edges.size(); ++i) + { + const unsigned long n1 = edges[i].index1(); + const unsigned long n2 = edges[i].index2(); + k[n1] += edges[i].distance(); + if (n1 != n2) + { + k[n2] += edges[i].distance(); + A(n2,n1) += edges[i].distance(); + } + + A(n1,n2) += edges[i].distance(); + } + + m = sum(A)/2; + + for (long r = 0; r < P.nr(); ++r) + { + for (long c = 0; c < P.nc(); ++c) + { + P(r,c) = k[r]*k[c]/(2*m); + } + } + + } + + double compute_modularity_simple ( + const std::vector& edges, + std::vector labels + ) + { + double m; + matrix A,P; + make_modularity_matrices(edges, A, P, m); + matrix B = A - P; + + double Q = 0; + for (long r = 0; r < B.nr(); ++r) + { + for (long c = 0; c < B.nc(); ++c) + { + if (labels[r] == labels[c]) + { + Q += B(r,c); + } + } + } + return 1.0/(2*m) * Q; + } + +// ---------------------------------------------------------------------------------------- + + void test_modularity(dlib::rand& rnd) + { + print_spinner(); + std::vector edges; + std::vector oedges; + std::vector labels; + + make_test_graph(rnd, edges, labels, 10, 30, 3, 0.10); + if (rnd.get_random_double() < 0.5) + remove_duplicate_edges(edges); + convert_unordered_to_ordered(edges, oedges); + + + const double m1 = modularity(edges, labels); + const double m2 = compute_modularity_simple(edges, labels); + const double m3 = modularity(oedges, labels); + + DLIB_TEST(std::abs(m1-m2) < 1e-12); + DLIB_TEST(std::abs(m2-m3) < 1e-12); + DLIB_TEST(std::abs(m3-m1) < 1e-12); + } + + void test_newman_clustering(dlib::rand& rnd) + { + print_spinner(); + std::vector edges; + std::vector labels; + + make_test_graph(rnd, edges, labels, 5, 30, 3, 0.10); + if (rnd.get_random_double() < 0.5) + remove_duplicate_edges(edges); + + + std::vector labels2; + + unsigned long num_clusters = newman_cluster(edges, labels2); + DLIB_TEST(labels.size() == labels2.size()); + DLIB_TEST(num_clusters == 5); + + for (unsigned long i = 0; i < labels.size(); ++i) + { + for (unsigned long j = 0; j < labels.size(); ++j) + { + if (labels[i] == labels[j]) + { + DLIB_TEST(labels2[i] == labels2[j]); + } + else + { + DLIB_TEST(labels2[i] != labels2[j]); + } + } + } + } + + void test_chinese_whispers(dlib::rand& rnd) + { + print_spinner(); + std::vector edges; + std::vector labels; + + make_test_graph(rnd, edges, labels, 5, 30, 3, 0.10); + if (rnd.get_random_double() < 0.5) + remove_duplicate_edges(edges); + + + std::vector labels2; + + unsigned long num_clusters; + if (rnd.get_random_double() < 0.5) + num_clusters = chinese_whispers(edges, labels2, 200, rnd); + else + num_clusters = chinese_whispers(edges, labels2); + + DLIB_TEST(labels.size() == labels2.size()); + DLIB_TEST(num_clusters == 5); + + for (unsigned long i = 0; i < labels.size(); ++i) + { + for (unsigned long j = 0; j < labels.size(); ++j) + { + if (labels[i] == labels[j]) + { + DLIB_TEST(labels2[i] == labels2[j]); + } + else + { + DLIB_TEST(labels2[i] != labels2[j]); + } + } + } + } + + void test_bottom_up_clustering() + { + std::vector pts; + pts.push_back(dpoint(0.0,0.0)); + pts.push_back(dpoint(0.5,0.0)); + pts.push_back(dpoint(0.5,0.5)); + pts.push_back(dpoint(0.0,0.5)); + + pts.push_back(dpoint(3.0,3.0)); + pts.push_back(dpoint(3.5,3.0)); + pts.push_back(dpoint(3.5,3.5)); + pts.push_back(dpoint(3.0,3.5)); + + pts.push_back(dpoint(7.0,7.0)); + pts.push_back(dpoint(7.5,7.0)); + pts.push_back(dpoint(7.5,7.5)); + pts.push_back(dpoint(7.0,7.5)); + + matrix dists(pts.size(), pts.size()); + for (long r = 0; r < dists.nr(); ++r) + for (long c = 0; c < dists.nc(); ++c) + dists(r,c) = length(pts[r]-pts[c]); + + + matrix truth(12); + truth = 0, 0, 0, 0, + 1, 1, 1, 1, + 2, 2, 2, 2; + + std::vector labels; + DLIB_TEST(bottom_up_cluster(dists, labels, 3) == 3); + DLIB_TEST(mat(labels) == truth); + DLIB_TEST(bottom_up_cluster(dists, labels, 1, 4.0) == 3); + DLIB_TEST(mat(labels) == truth); + DLIB_TEST(bottom_up_cluster(dists, labels, 1, 4.95) == 2); + truth = 0, 0, 0, 0, + 0, 0, 0, 0, + 1, 1, 1, 1; + DLIB_TEST(mat(labels) == truth); + DLIB_TEST(bottom_up_cluster(dists, labels, 1) == 1); + truth = 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0; + DLIB_TEST(mat(labels) == truth); + + dists.set_size(0,0); + DLIB_TEST(bottom_up_cluster(dists, labels, 3) == 0); + DLIB_TEST(labels.size() == 0); + DLIB_TEST(bottom_up_cluster(dists, labels, 1) == 0); + DLIB_TEST(labels.size() == 0); + + dists.set_size(1,1); + dists = 1; + DLIB_TEST(bottom_up_cluster(dists, labels, 3) == 1); + DLIB_TEST(labels.size() == 1); + DLIB_TEST(labels[0] == 0); + DLIB_TEST(bottom_up_cluster(dists, labels, 1) == 1); + DLIB_TEST(labels.size() == 1); + DLIB_TEST(labels[0] == 0); + DLIB_TEST(bottom_up_cluster(dists, labels, 1, 0) == 1); + DLIB_TEST(labels.size() == 1); + DLIB_TEST(labels[0] == 0); + + dists.set_size(2,2); + dists = 1; + DLIB_TEST(bottom_up_cluster(dists, labels, 3) == 2); + DLIB_TEST(labels.size() == 2); + DLIB_TEST(labels[0] == 0); + DLIB_TEST(labels[1] == 1); + DLIB_TEST(bottom_up_cluster(dists, labels, 1) == 1); + DLIB_TEST(labels.size() == 2); + DLIB_TEST(labels[0] == 0); + DLIB_TEST(labels[1] == 0); + DLIB_TEST(bottom_up_cluster(dists, labels, 1, 1) == 1); + DLIB_TEST(labels.size() == 2); + DLIB_TEST(labels[0] == 0); + DLIB_TEST(labels[1] == 0); + DLIB_TEST(bottom_up_cluster(dists, labels, 1, 0.999) == 2); + DLIB_TEST(labels.size() == 2); + DLIB_TEST(labels[0] == 0); + DLIB_TEST(labels[1] == 1); + } + + void test_segment_number_line() + { + dlib::rand rnd; + + + std::vector x; + for (int i = 0; i < 5000; ++i) + { + x.push_back(rnd.get_double_in_range(-1.5, -1.01)); + x.push_back(rnd.get_double_in_range(-0.99, -0.01)); + x.push_back(rnd.get_double_in_range(0.01, 1)); + } + + auto r = segment_number_line(x,1); + std::sort(r.begin(), r.end()); + DLIB_TEST(r.size() == 3); + DLIB_TEST(-1.5 <= r[0].lower && r[0].lower < r[0].upper && r[0].upper <= -1.01); + DLIB_TEST(-0.99 <= r[1].lower && r[1].lower < r[1].upper && r[1].upper <= -0.01); + DLIB_TEST(0.01 <= r[2].lower && r[2].lower < r[2].upper && r[2].upper <= 1); + + x.clear(); + for (int i = 0; i < 5000; ++i) + { + x.push_back(rnd.get_double_in_range(-2, 1)); + x.push_back(rnd.get_double_in_range(-2, 1)); + x.push_back(rnd.get_double_in_range(-2, 1)); + } + + r = segment_number_line(x,1); + DLIB_TEST(r.size() == 3); + r = segment_number_line(x,1.5); + DLIB_TEST(r.size() == 2); + r = segment_number_line(x,10.5); + DLIB_TEST(r.size() == 1); + DLIB_TEST(-2 <= r[0].lower && r[0].lower < r[0].upper && r[0].upper <= 1); + } + + class test_clustering : public tester + { + public: + test_clustering ( + ) : + tester ("test_clustering", + "Runs tests on the clustering routines.") + {} + + void perform_test ( + ) + { + test_bottom_up_clustering(); + test_segment_number_line(); + + dlib::rand rnd; + + std::vector edges; + std::vector labels; + DLIB_TEST(newman_cluster(edges, labels) == 0); + DLIB_TEST(chinese_whispers(edges, labels) == 0); + + edges.push_back(sample_pair(0,1,1)); + DLIB_TEST(newman_cluster(edges, labels) == 1); + DLIB_TEST(labels.size() == 2); + DLIB_TEST(chinese_whispers(edges, labels) == 1); + DLIB_TEST(labels.size() == 2); + + edges.clear(); + edges.push_back(sample_pair(0,0,1)); + DLIB_TEST(newman_cluster(edges, labels) == 1); + DLIB_TEST(labels.size() == 1); + DLIB_TEST(chinese_whispers(edges, labels) == 1); + DLIB_TEST(labels.size() == 1); + + edges.clear(); + edges.push_back(sample_pair(1,1,1)); + DLIB_TEST(newman_cluster(edges, labels) == 1); + DLIB_TEST(labels.size() == 2); + DLIB_TEST(chinese_whispers(edges, labels) == 2); + DLIB_TEST(labels.size() == 2); + + edges.push_back(sample_pair(0,0,1)); + DLIB_TEST(newman_cluster(edges, labels) == 2); + DLIB_TEST(labels.size() == 2); + DLIB_TEST(chinese_whispers(edges, labels) == 2); + DLIB_TEST(labels.size() == 2); + + + for (int i = 0; i < 10; ++i) + test_modularity(rnd); + + for (int i = 0; i < 10; ++i) + test_newman_clustering(rnd); + + for (int i = 0; i < 10; ++i) + test_chinese_whispers(rnd); + + + } + } a; + + + +} + + + diff --git a/ml/dlib/dlib/test/cmd_line_parser.cpp b/ml/dlib/dlib/test/cmd_line_parser.cpp new file mode 100644 index 000000000..9216a76cc --- /dev/null +++ b/ml/dlib/dlib/test/cmd_line_parser.cpp @@ -0,0 +1,40 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include + +#include + +#include "tester.h" + +#include "cmd_line_parser.h" +namespace +{ + + class cmd_line_parser_tester : public tester + { + public: + cmd_line_parser_tester ( + ) : + tester ("test_cmd_line_parser_char", + "Runs tests on the cmd_line_parser component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a with char"; + cmd_line_parser_kernel_test::kernel_1a>(); + print_spinner(); + + dlog << LINFO << "testing kernel_1a_c with char"; + cmd_line_parser_kernel_test::kernel_1a_c>(); + print_spinner(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/cmd_line_parser.h b/ml/dlib/dlib/test/cmd_line_parser.h new file mode 100644 index 000000000..6f8e411a4 --- /dev/null +++ b/ml/dlib/dlib/test/cmd_line_parser.h @@ -0,0 +1,901 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_KERNEl_TEST_H_ +#define DLIB_CMD_LINE_PARSER_KERNEl_TEST_H_ + + +#include +#include + +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.cmd_line_parser"); + + template < + typename clp + > + void cmd_line_parser_kernel_test ( + ) + /*! + requires + - clp is an implementation of cmd_line_parser_kernel_abstract.h + ensures + - runs tests on clp for compliance with the specs + !*/ + { + typedef typename clp::char_type ct; + + + + + int argc; + const ct* argv[100]; + bool ok; + + for (int j = 0; j < 3; ++j) + { + clp test, test2; + + + + + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start()); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + + + + DLIB_TEST(test.parsed_line() == false); + DLIB_TEST(test.option_is_defined(_dT(ct,"a")) == false); + DLIB_TEST(test.option_is_defined(_dT(ct,"a")) == false); + DLIB_TEST(test.option_is_defined(_dT(ct,"a")) == false); + + DLIB_TEST(test.parsed_line() == false); + DLIB_TEST(test.option_is_defined(_dT(ct,"a")) == false); + DLIB_TEST(test.option_is_defined(_dT(ct,"b")) == false); + DLIB_TEST(test.option_is_defined(_dT(ct,"\0")) == false); + + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + + + + // program arg1 --davis arg2 -cZzarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis"); + argv[3] = _dT(ct,"arg2"); + argv[4] = _dT(ct,"-cZzarg"); + argv[5] = _dT(ct,"asdf"); + argc = 6; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option")); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + for (int k = 0; k < 5; ++k) + { + + try { test.parse(argc,argv); } + catch (error& e) + { + DLIB_TEST_MSG(false,e.info); + } + + DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis")); + DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c")); + DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z")); + DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 0); + DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0); + DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2); + DLIB_TEST(test.number_of_arguments() == 2); + DLIB_TEST(test[0] == _dT(ct,"arg1")); + DLIB_TEST(test[1] == _dT(ct,"arg2")); + DLIB_TEST(test.option(_dT(ct,"d")).count()==0); + DLIB_TEST(test.option(_dT(ct,"davis")).count()==1); + DLIB_TEST_MSG(test.option(_dT(ct,"c")).count()==1,test.option(_dT(ct,"c")).count()); + DLIB_TEST(test.option(_dT(ct,"Z")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg")); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf")); + + } + + + + swap(test,test2); + + + + + + // program arg1 --davis arg2 -cZ zarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis"); + argv[3] = _dT(ct,"arg2"); + argv[4] = _dT(ct,"-cZ"); + argv[5] = _dT(ct,"zarg"); + argv[6] = _dT(ct,"asdf"); + argc = 7; + + + + + for (int k = 0; k < 5; ++k) + { + + try { test2.parse(argc,argv); } + catch (error& e) + { + DLIB_TEST_MSG(false,e.info); + } + + DLIB_TEST(test2.option(_dT(ct,"davis")).name() == _dT(ct,"davis")); + DLIB_TEST(test2.option(_dT(ct,"c")).name() == _dT(ct,"c")); + DLIB_TEST(test2.option(_dT(ct,"Z")).name() == _dT(ct,"Z")); + DLIB_TEST(test2.option(_dT(ct,"davis")).number_of_arguments() == 0); + DLIB_TEST(test2.option(_dT(ct,"c")).number_of_arguments() == 0); + DLIB_TEST(test2.option(_dT(ct,"Z")).number_of_arguments() == 2); + DLIB_TEST(test2.number_of_arguments() == 2); + DLIB_TEST(test2[0] == _dT(ct,"arg1")); + DLIB_TEST(test2[1] == _dT(ct,"arg2")); + DLIB_TEST(test2.option(_dT(ct,"d")).count()==0); + DLIB_TEST(test2.option(_dT(ct,"davis")).count()==1); + DLIB_TEST(test2.option(_dT(ct,"c")).count()==1); + DLIB_TEST(test2.option(_dT(ct,"Z")).count()==1); + DLIB_TEST(test2.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf")); + DLIB_TEST_MSG(test2.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg"), + narrow(_dT(ct,"*") + test2.option(_dT(ct,"Z")).argument(0,0) + _dT(ct,"*"))); + + + } + + + + + + // program arg1 --davis= darg darg2 arg2 -cZzarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis="); + argv[3] = _dT(ct,"darg"); + argv[4] = _dT(ct,"darg2"); + argv[5] = _dT(ct,"arg2"); + argv[6] = _dT(ct,"-cZzarg"); + argv[7] = _dT(ct,"asdf"); + argc = 8; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 2); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + for (int k = 0; k < 5; ++k) + { + + try { test.parse(argc,argv); } + catch (error& e) + { + DLIB_TEST_MSG(false,e.info); + } + + DLIB_TEST(test.parsed_line()); + + int count = 0; + while (test.move_next()) + { + ++count; + if (test.element().name() == _dT(ct,"d")) + { + DLIB_TEST(test.element().count() == 0); + } + else + { + DLIB_TEST(test.element().count() == 1); + } + + } + DLIB_TEST_MSG(count == 4,count); + + DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis")); + DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c")); + DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z")); + DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 2); + DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0); + DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2); + DLIB_TEST(test.number_of_arguments() == 2); + DLIB_TEST(test[0] == _dT(ct,"arg1")); + DLIB_TEST(test[1] == _dT(ct,"arg2")); + DLIB_TEST(test.option(_dT(ct,"d")).count()==0); + DLIB_TEST(test.option(_dT(ct,"davis")).count()==1); + DLIB_TEST(test.option(_dT(ct,"c")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg")); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf")); + DLIB_TEST(test.option(_dT(ct,"davis")).argument(0,0) == _dT(ct,"darg")); + DLIB_TEST_MSG(test.option(_dT(ct,"davis")).argument(1,0) == _dT(ct,"darg2"), + narrow(test.option(_dT(ct,"davis")).argument(1,0))); + } + + + + + + + + + + + test.clear(); + + + + + + + + // program arg1 --dav-is=darg darg2 arg2 -cZzarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--dav-is=darg"); + argv[3] = _dT(ct,"darg2"); + argv[4] = _dT(ct,"arg2"); + argv[5] = _dT(ct,"-cZzarg"); + argv[6] = _dT(ct,"asdf"); + argc = 7; + + + test.add_option(_dT(ct,"dav-is"),_dT(ct,"davis option"), 2); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + for (int k = 0; k < 5; ++k) + { + + try { test.parse(argc,argv); } + catch (error& e) + { + DLIB_TEST_MSG(false,e.info); + } + + DLIB_TEST(test.parsed_line()); + + int count = 0; + while (test.move_next()) + { + ++count; + if (test.element().name() == _dT(ct,"d")) + { + DLIB_TEST(test.element().count() == 0); + } + else + { + DLIB_TEST(test.element().count() == 1); + } + + } + DLIB_TEST_MSG(count == 4,count); + + DLIB_TEST(test.option(_dT(ct,"dav-is")).name() == _dT(ct,"dav-is")); + DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c")); + DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z")); + DLIB_TEST(test.option(_dT(ct,"dav-is")).number_of_arguments() == 2); + DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0); + DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2); + DLIB_TEST(test.number_of_arguments() == 2); + DLIB_TEST(test[0] == _dT(ct,"arg1")); + DLIB_TEST(test[1] == _dT(ct,"arg2")); + DLIB_TEST(test.option(_dT(ct,"d")).count()==0); + DLIB_TEST(test.option(_dT(ct,"dav-is")).count()==1); + DLIB_TEST(test.option(_dT(ct,"c")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg")); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf")); + DLIB_TEST(test.option(_dT(ct,"dav-is")).argument(0,0) == _dT(ct,"darg")); + DLIB_TEST_MSG(test.option(_dT(ct,"dav-is")).argument(1,0) == _dT(ct,"darg2"), + narrow(test.option(_dT(ct,"dav-is")).argument(1,0))); + } + + + + + + + + + + test.clear(); + + + + + + + + // program arg1 --davis=darg darg2 arg2 -cZzarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis=darg"); + argv[3] = _dT(ct,"darg2"); + argv[4] = _dT(ct,"arg2"); + argv[5] = _dT(ct,"-cZzarg"); + argv[6] = _dT(ct,"asdf"); + argc = 7; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 2); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + for (int k = 0; k < 5; ++k) + { + + try { test.parse(argc,argv); } + catch (error& e) + { + DLIB_TEST_MSG(false,e.info); + } + + DLIB_TEST(test.parsed_line()); + + int count = 0; + while (test.move_next()) + { + ++count; + if (test.element().name() == _dT(ct,"d")) + { + DLIB_TEST(test.element().count() == 0); + } + else + { + DLIB_TEST(test.element().count() == 1); + } + + } + DLIB_TEST_MSG(count == 4,count); + + DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis")); + DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c")); + DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z")); + DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 2); + DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0); + DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2); + DLIB_TEST(test.number_of_arguments() == 2); + DLIB_TEST(test[0] == _dT(ct,"arg1")); + DLIB_TEST(test[1] == _dT(ct,"arg2")); + DLIB_TEST(test.option(_dT(ct,"d")).count()==0); + DLIB_TEST(test.option(_dT(ct,"davis")).count()==1); + DLIB_TEST(test.option(_dT(ct,"c")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg")); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf")); + DLIB_TEST(test.option(_dT(ct,"davis")).argument(0,0) == _dT(ct,"darg")); + DLIB_TEST_MSG(test.option(_dT(ct,"davis")).argument(1,0) == _dT(ct,"darg2"), + narrow(test.option(_dT(ct,"davis")).argument(1,0))); + } + + + + + + + + + + test.clear(); + + + + + + + + // program arg1 --davis=darg arg2 -cZzarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis=darg"); + argv[3] = _dT(ct,"arg2"); + argv[4] = _dT(ct,"-cZzarg"); + argv[5] = _dT(ct,"asdf"); + argc = 6; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + for (int k = 0; k < 5; ++k) + { + + try { test.parse(argc,argv); } + catch (error& e) + { + DLIB_TEST_MSG(false,e.info); + } + + DLIB_TEST(test.parsed_line()); + + int count = 0; + while (test.move_next()) + { + ++count; + if (test.element().name() == _dT(ct,"d")) + { + DLIB_TEST(test.element().count() == 0); + } + else + { + DLIB_TEST(test.element().count() == 1); + } + + } + DLIB_TEST_MSG(count == 4,count); + + DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis")); + DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c")); + DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z")); + DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 1); + DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0); + DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2); + DLIB_TEST(test.number_of_arguments() == 2); + DLIB_TEST(test[0] == _dT(ct,"arg1")); + DLIB_TEST(test[1] == _dT(ct,"arg2")); + DLIB_TEST(test.option(_dT(ct,"d")).count()==0); + DLIB_TEST(test.option(_dT(ct,"davis")).count()==1); + DLIB_TEST(test.option(_dT(ct,"c")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg")); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(1,0) == _dT(ct,"asdf")); + DLIB_TEST(test.option(_dT(ct,"davis")).argument(0,0) == _dT(ct,"darg")); + } + + + + + + + + + + test.clear(); + + + + + + + // program arg1 --davis darg arg2 -cZzarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis"); + argv[3] = _dT(ct,"darg"); + argv[4] = _dT(ct,"arg2"); + argv[5] = _dT(ct,"-cZzarg"); + argv[6] = _dT(ct,"asdf"); + argc = 7; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + for (int k = 0; k < 5; ++k) + { + + try { test.parse(argc,argv); } + catch (error& e) + { + DLIB_TEST_MSG(false,e.info); + } + + DLIB_TEST(test.parsed_line()); + + int count = 0; + while (test.move_next()) + { + ++count; + if (test.element().name() == _dT(ct,"d")) + { + DLIB_TEST(test.element().count() == 0); + } + else + { + DLIB_TEST(test.element().count() == 1); + } + + } + DLIB_TEST_MSG(count == 4,count); + + DLIB_TEST(test.option(_dT(ct,"davis")).name() == _dT(ct,"davis")); + DLIB_TEST(test.option(_dT(ct,"c")).name() == _dT(ct,"c")); + DLIB_TEST(test.option(_dT(ct,"Z")).name() == _dT(ct,"Z")); + DLIB_TEST(test.option(_dT(ct,"davis")).number_of_arguments() == 1); + DLIB_TEST(test.option(_dT(ct,"c")).number_of_arguments() == 0); + DLIB_TEST(test.option(_dT(ct,"Z")).number_of_arguments() == 2); + DLIB_TEST(test.number_of_arguments() == 2); + DLIB_TEST(test[0] == _dT(ct,"arg1")); + DLIB_TEST(test[1] == _dT(ct,"arg2")); + DLIB_TEST(test.option(_dT(ct,"d")).count()==0); + DLIB_TEST(test.option(_dT(ct,"davis")).count()==1); + DLIB_TEST(test.option(_dT(ct,"c")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).count()==1); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"zarg")); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(1) == _dT(ct,"asdf")); + DLIB_TEST(test.option(_dT(ct,"davis")).argument(0,0) == _dT(ct,"darg")); + } + + + + + + + + + + test.clear(); + + // this string is incorrect because there is no avis option + // program arg1 --avis darg arg2 -cZzarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--avis"); + argv[3] = _dT(ct,"darg"); + argv[4] = _dT(ct,"arg2"); + argv[5] = _dT(ct,"-cZzarg"); + argv[6] = _dT(ct,"asdf"); + argc = 7; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + for (int k = 0; k < 5; ++k) + { + + ok = false; + try { test.parse(argc,argv); } + catch (typename clp::cmd_line_parse_error& e) + { + DLIB_TEST(e.type == EINVALID_OPTION); + DLIB_TEST(e.item == _dT(ct,"avis")); + ok = true; + } + DLIB_TEST(ok); + + + } + + + + + + + + + + + + test.clear(); + + // the c argument appears twice. make sure its count is correct + // program arg1 --davis darg arg2 -ccZzarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis"); + argv[3] = _dT(ct,"darg"); + argv[4] = _dT(ct,"arg2"); + argv[5] = _dT(ct,"-ccZ"); + argv[6] = _dT(ct,"zarg"); + argv[7] = _dT(ct,"asdf"); + argc = 8; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + for (int k = 0; k < 5; ++k) + { + + ok = false; + test.parse(argc,argv); + + DLIB_TEST(test.option(_dT(ct,"c")).count()==2); + + } + + + + + + + + + + + + + + + + test.clear(); + + // this is a bad line because the davis argument requires 2 arguments but + // only gets one. + // program arg1 --davis darg darg2 --davis zarg + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis"); + argv[3] = _dT(ct,"darg"); + argv[4] = _dT(ct,"darg2"); + argv[5] = _dT(ct,"--davis"); + argv[6] = _dT(ct,"zarg"); + argc = 7; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 2); + test.add_option(_dT(ct,"b"),_dT(ct,"b option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + DLIB_TEST(test.option(_dT(ct,"davis")).description() == _dT(ct,"davis option")); + DLIB_TEST(test.option(_dT(ct,"b")).description() == _dT(ct,"b option")); + DLIB_TEST(test.option(_dT(ct,"d")).description() == _dT(ct,"d option")); + DLIB_TEST(test.option(_dT(ct,"Z")).description() == _dT(ct,"Z option")); + + for (int k = 0; k < 5; ++k) + { + + ok = false; + try { test.parse(argc,argv); } + catch (typename clp::cmd_line_parse_error& e) + { + DLIB_TEST(e.type == ETOO_FEW_ARGS); + DLIB_TEST(e.num == 2); + DLIB_TEST(e.item == _dT(ct,"davis")); + ok = true; + } + DLIB_TEST(ok); + + + + int count = 0; + while (test.move_next()) + { + ++count; + DLIB_TEST(test.element().count() == 0); + DLIB_TEST(test.option_is_defined(test.element().name())); + } + DLIB_TEST_MSG(count == 4,count); + + + } + + + + + + + + + + + + + + + + + + + test.clear(); + + // this is a bad line because the davis argument is not defined + // program arg1 --davis darg arg2 -davis zarg asdf + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis"); + argv[3] = _dT(ct,"darg"); + argv[4] = _dT(ct,"arg2"); + argv[5] = _dT(ct,"--davis"); + argv[6] = _dT(ct,"zarg"); + argv[7] = _dT(ct,"asdf"); + argc = 8; + + + DLIB_TEST(std::basic_string(argv[0]) == _dT(ct,"program")); + + test.add_option(_dT(ct,"mavis"),_dT(ct,"mavis option"), 1); + test.add_option(_dT(ct,"b"),_dT(ct,"b option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + DLIB_TEST(test.option(_dT(ct,"mavis")).description() == _dT(ct,"mavis option")); + DLIB_TEST(test.option(_dT(ct,"b")).description() == _dT(ct,"b option")); + DLIB_TEST(test.option(_dT(ct,"d")).description() == _dT(ct,"d option")); + DLIB_TEST(test.option(_dT(ct,"Z")).description() == _dT(ct,"Z option")); + + for (int k = 0; k < 5; ++k) + { + + ok = false; + try { test.parse(argc,argv); } + catch (typename clp::cmd_line_parse_error& e) + { + DLIB_TEST(e.type == EINVALID_OPTION); + DLIB_TEST(e.item == _dT(ct,"davis")); + ok = true; + } + DLIB_TEST(ok); + + + + int count = 0; + while (test.move_next()) + { + ++count; + DLIB_TEST(test.element().count() == 0); + DLIB_TEST(test.option_is_defined(test.element().name())); + } + DLIB_TEST_MSG(count == 4,count); + + + } + + + + + + + + + + + + + + + + test.clear(); + + + argv[0] = _dT(ct,"program"); + argc = 1; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),2); + + + DLIB_TEST(test.option(_dT(ct,"davis")).description() == _dT(ct,"davis option")); + DLIB_TEST(test.option(_dT(ct,"c")).description() == _dT(ct,"c option")); + DLIB_TEST(test.option(_dT(ct,"d")).description() == _dT(ct,"d option")); + DLIB_TEST(test.option(_dT(ct,"Z")).description() == _dT(ct,"Z option")); + + for (int k = 0; k < 5; ++k) + { + + test.parse(argc,argv); + + + DLIB_TEST(test.number_of_arguments() == 0); + + int count = 0; + while (test.move_next()) + { + ++count; + DLIB_TEST(test.element().count() == 0); + DLIB_TEST(test.option_is_defined(test.element().name())); + } + DLIB_TEST_MSG(count == 4,count); + + + } + + + + + + + + + + + + + test.clear(); + + // this is to make sure the -- command works right + // program arg1 --davis -darg -- arg2 -c asdf -Zeat -Zat -Zjoe's + argv[0] = _dT(ct,"program"); + argv[1] = _dT(ct,"arg1"); + argv[2] = _dT(ct,"--davis"); + argv[3] = _dT(ct,"-darg"); + argv[4] = _dT(ct,"-Zeat"); + argv[5] = _dT(ct,"-Zat"); + argv[6] = _dT(ct,"-Zjoe's"); + argv[7] = _dT(ct,"--"); + argv[8] = _dT(ct,"arg2"); + argv[9] = _dT(ct,"-c"); + argv[10] = _dT(ct,"asdf"); + + argc = 11; + + + test.add_option(_dT(ct,"davis"),_dT(ct,"davis option"), 1); + test.add_option(_dT(ct,"c"),_dT(ct,"c option")); + test.add_option(_dT(ct,"d"),_dT(ct,"d option")); + test.add_option(_dT(ct,"Z"),_dT(ct,"Z option"),1); + + + DLIB_TEST(test.option(_dT(ct,"davis")).description() == _dT(ct,"davis option")); + DLIB_TEST(test.option(_dT(ct,"c")).description() == _dT(ct,"c option")); + DLIB_TEST(test.option(_dT(ct,"d")).description() == _dT(ct,"d option")); + DLIB_TEST(test.option(_dT(ct,"Z")).description() == _dT(ct,"Z option")); + + for (int k = 0; k < 5; ++k) + { + + test.parse(argc,argv); + + DLIB_TEST_MSG(test.number_of_arguments() == 4,test.number_of_arguments()); + DLIB_TEST(test[0] == _dT(ct,"arg1")); + DLIB_TEST(test[1] == _dT(ct,"arg2")); + DLIB_TEST(test[2] == _dT(ct,"-c")); + DLIB_TEST(test[3] == _dT(ct,"asdf")); + + DLIB_TEST(test.option(_dT(ct,"davis")).count()==1); + DLIB_TEST(test.option(_dT(ct,"davis")).argument() == _dT(ct,"-darg")); + DLIB_TEST(test.option(_dT(ct,"c")).count()==0); + DLIB_TEST(test.option(_dT(ct,"d")).count()==0); + DLIB_TEST(test.option(_dT(ct,"Z")).count()==3); + + DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,0) == _dT(ct,"eat")); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,1) == _dT(ct,"at")); + DLIB_TEST(test.option(_dT(ct,"Z")).argument(0,2) == _dT(ct,"joe's")); + + + } + + + } + } + +} + + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_TEST_H_ + diff --git a/ml/dlib/dlib/test/cmd_line_parser_wchar_t.cpp b/ml/dlib/dlib/test/cmd_line_parser_wchar_t.cpp new file mode 100644 index 000000000..f771eeedb --- /dev/null +++ b/ml/dlib/dlib/test/cmd_line_parser_wchar_t.cpp @@ -0,0 +1,40 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include + +#include + +#include "tester.h" + +#include "cmd_line_parser.h" +namespace +{ + + class cmd_line_parser_tester : public tester + { + public: + cmd_line_parser_tester ( + ) : + tester ("test_cmd_line_parser_wchar_t", + "Runs tests on the cmd_line_parser component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a with wchar_t"; + cmd_line_parser_kernel_test::kernel_1a>(); + print_spinner(); + + dlog << LINFO << "testing kernel_1a_c with wchar_t"; + cmd_line_parser_kernel_test::kernel_1a_c>(); + print_spinner(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/compress_stream.cpp b/ml/dlib/dlib/test/compress_stream.cpp new file mode 100644 index 000000000..fbc57dd4c --- /dev/null +++ b/ml/dlib/dlib/test/compress_stream.cpp @@ -0,0 +1,306 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.compress_stream"); + + template < + typename cs + > + void compress_stream_kernel_test ( + unsigned long seed + ) + /*! + requires + - cs is an implementation of compress_stream/compress_stream_kernel_abstract.h + the alphabet_size for cc is 256 + ensures + - runs tests on cs for compliance with the specs + !*/ + { + + + srand(seed); + + cs test; + + + dlog << LTRACE << 1; + + int count = 0; + while (count < 2) + { + print_spinner(); + istringstream sin; + ostringstream sout; + string buffer; + buffer.reserve(10000); + // fill sin with a bunch of random data in the range 0 to 63 + for (int i = 0; i < 10000; ++i) + { + char temp = static_cast(::rand()&0x3f); + buffer.push_back(temp); + } + + print_spinner(); + sin.str(buffer); + string old_buffer = buffer; + + test.compress(sin,sout); + buffer = sout.str(); + + print_spinner(); + // corrput the data in buffer + buffer[buffer.size()/2]++; + + sin.str(buffer); + sout.str(""); + + bool detected_error = false; + try { + test.decompress(sin,sout); + } catch ( typename cs::decompression_error e ) + { + detected_error = true; + ++count; + } + + + DLIB_TEST_MSG(detected_error || sout.str() == old_buffer,(unsigned int)sout.str().size()); + + + + } /**/ + + + dlog << LTRACE << 2; + + for (int j = 0; j < 2; ++j) + { + + print_spinner(); + istringstream sin; + ostringstream sout; + + string buffer; + + buffer.reserve(10); + + // make sure a single char can be compressed and decompressed + for (int i = 0; i < 256; ++i) + { + sin.str(""); + sout.str(""); + char ch = static_cast(i); + buffer = ch; + sin.str(buffer); + + test.compress(sin,sout); + + sin.str(sout.str()); + sout.str(""); + test.decompress(sin,sout); + DLIB_TEST(sout.str() == buffer); + } + + print_spinner(); + + // make sure you can compress a single char, then append a new + // compressed single char. and make sure you can decode the + // two streams. Just to make sure the decoder doesn't leave + // extra bytes behind or eat more than it should. + for (int i = 0; i < 500; ++i) + { + sin.str(""); + sin.clear(); + sout.str(""); + sout.clear(); + char ch = static_cast(::rand()%256); + char ch2 = static_cast(::rand()%256); + + buffer = ch; + sin.str(buffer); + + + + test.compress(sin,sout); + + + + + buffer = ch2; + sin.str(buffer); + test.compress(sin,sout); + + sin.str(sout.str()); + + sout.str(""); + test.decompress(sin,sout); + buffer = ch; + DLIB_TEST(sout.str() == buffer); + + + + + sout.str(""); + test.decompress(sin,sout); + buffer = ch2; + DLIB_TEST(sout.str() == buffer); + + + } + print_spinner(); + + + // make sure you can compress and decompress the empty string + sout.str(""); + sin.str(""); + test.compress(sin,sout); + sin.str(sout.str()); + sout.str(""); + test.decompress(sin,sout); + DLIB_TEST_MSG(sout.str() == "",sout.str()); + + + + + + print_spinner(); + + sin.str(""); + sout.str(""); + buffer = ""; + + buffer.reserve(20000); + // fill buffer with a bunch of random data in the range 0 to 63 + for (int i = 0; i < 20000; ++i) + { + char temp = static_cast(::rand()&0x3f); + buffer.push_back(temp); + } + + sin.str(buffer); + + print_spinner(); + test.compress(sin,sout); + + sin.str(sout.str()); + sout.str(""); + + print_spinner(); + test.decompress(sin,sout); + + DLIB_TEST(sout.str() == buffer); + + print_spinner(); + } + + dlog << LTRACE << 3; + + // this block will try to compress a bunch of 'a' chars + { + + istringstream sin; + ostringstream sout; + + string buffer; + + + print_spinner(); + + sin.str(""); + sout.str(""); + buffer = ""; + + buffer.reserve(50000); + // fill buffer with a bunch of 'a' chars + for (int i = 0; i < 50000; ++i) + { + char temp = 'a'; + buffer.push_back(temp); + } + + sin.str(buffer); + + print_spinner(); + test.compress(sin,sout); + + sin.str(sout.str()); + sout.str(""); + + print_spinner(); + test.decompress(sin,sout); + + DLIB_TEST(sout.str() == buffer); + + print_spinner(); + + } + + dlog << LTRACE << 4; + + } + + + + + + + class compress_stream_tester : public tester + { + public: + compress_stream_tester ( + ) : + tester ("test_compress_stream", + "Runs tests on the compress_stream component.") + {} + + void perform_test ( + ) + { + const unsigned int seed = static_cast(time(0)); + dlog << LINFO << "using seed: " << seed; + + dlog << LINFO << "testing kernel_1a"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_1b"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_1c"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_1da"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_1db"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_1ea"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_1eb"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_1ec"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_2a"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_3a"; + compress_stream_kernel_test(seed); + dlog << LINFO << "testing kernel_3b"; + compress_stream_kernel_test(seed); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/conditioning_class.cpp b/ml/dlib/dlib/test/conditioning_class.cpp new file mode 100644 index 000000000..b4415eafd --- /dev/null +++ b/ml/dlib/dlib/test/conditioning_class.cpp @@ -0,0 +1,86 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include + +#include "tester.h" +#include "conditioning_class.h" + +namespace +{ + + + class conditioning_class_tester : public tester + { + public: + conditioning_class_tester ( + ) : + tester ("test_conditioning_class", + "Runs tests on the conditioning_class component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_1a, + conditioning_class<2>::kernel_1a + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_2a"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_2a, + conditioning_class<2>::kernel_2a + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_3a"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_3a, + conditioning_class<2>::kernel_3a + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_4a"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_4a, + conditioning_class<2>::kernel_4a + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_4b"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_4b, + conditioning_class<2>::kernel_4b + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_4c"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_4c, + conditioning_class<2>::kernel_4c + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_4d"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_4d, + conditioning_class<2>::kernel_4d + >(); + print_spinner(); + + + } + } a; + + +} + diff --git a/ml/dlib/dlib/test/conditioning_class.h b/ml/dlib/dlib/test/conditioning_class.h new file mode 100644 index 000000000..3c6c88b8d --- /dev/null +++ b/ml/dlib/dlib/test/conditioning_class.h @@ -0,0 +1,841 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TEST_CONDITIONING_CLASs_H_ +#define DLIB_TEST_CONDITIONING_CLASs_H_ + + +#include +#include +#include +#include + +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.conditioning_class"); + + template < + typename cc, + typename cc2 + > + void conditioning_class_kernel_test ( + ) + /*! + requires + - cc is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + the alphabet_size for cc is 256 + - cc2 is an implementation of conditioning_class/conditioning_class_kernel_abstract.h + the alphabet_size for cc2 is 2 + ensures + - runs tests on cc for compliance with the specs + !*/ + { + + srand(static_cast(time(0))); + + + + typename cc::global_state_type gs; + typename cc2::global_state_type gs2; + + + + + for (int g = 0; g < 2; ++g) + { + print_spinner(); + unsigned long amount=g+1; + cc2 test(gs2); + cc2 test2(gs2); + + + DLIB_TEST(test.get_memory_usage() != 0); + + const unsigned long alphabet_size = 2; + + + DLIB_TEST(test.get_total() == 1); + + DLIB_TEST(test.get_count(alphabet_size-1)==1); + for (unsigned long i = 0; i < alphabet_size-1; ++i) + { + unsigned long low_count, high_count, total_count; + DLIB_TEST_MSG(test.get_range(i,low_count,high_count,total_count) == 0,i); + DLIB_TEST(test.get_count(i) == 0); + DLIB_TEST(test.get_total() == 1); + } + + + + for (unsigned long i = 0; i < alphabet_size; ++i) + { + test.increment_count(i,static_cast(amount)); + unsigned long low_count = 0, high_count = 0, total_count = 0; + + if (i ==alphabet_size-1) + { + DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 1+amount); + + DLIB_TEST(high_count == low_count+1+amount); + DLIB_TEST(total_count == test.get_total()); + + + DLIB_TEST(test.get_count(i) == 1+amount); + } + else + { + DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == amount); + + DLIB_TEST(high_count == low_count+amount); + DLIB_TEST(total_count == test.get_total()); + + + DLIB_TEST(test.get_count(i) == amount); + } + DLIB_TEST(test.get_total() == (i+1)*amount + 1); + } + + + for (unsigned long i = 0; i < alphabet_size; ++i) + { + unsigned long temp = static_cast(::rand()%40); + for (unsigned long j = 0; j < temp; ++j) + { + test.increment_count(i,static_cast(amount)); + if (i == alphabet_size-1) + { + DLIB_TEST(test.get_count(i) == (j+1)*amount + 1 + amount); + } + else + { + DLIB_TEST(test.get_count(i) == (j+1)*amount + amount); + } + } + + unsigned long target = test.get_total()/2; + unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0; + + if (i == alphabet_size-1) + { + DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==temp*amount+1+amount); + DLIB_TEST(high_count-low_count == temp*amount+1+amount); + } + else + { + DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==temp*amount + amount); + DLIB_TEST(high_count-low_count == temp*amount + amount); + } + DLIB_TEST(total_count == test.get_total()); + + test.get_symbol(target,symbol,low_count,high_count); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + DLIB_TEST(low_count <= target); + DLIB_TEST(target < high_count); + DLIB_TEST(high_count <= test.get_total()); + + } + + test.clear(); + + + for (unsigned long i = 0; i < alphabet_size-1; ++i) + { + test.increment_count(i); + unsigned long low_count = 0, high_count = 0, total_count = 0; + DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 1); + + DLIB_TEST(high_count == low_count+1); + DLIB_TEST(total_count == test.get_total()); + + DLIB_TEST(test.get_count(i) == 1); + DLIB_TEST(test.get_total() == i+2); + } + + + + + unsigned long counts[alphabet_size]; + + + print_spinner(); + for (int k = 0; k < 10; ++k) + { + unsigned long range = ::rand()%50000 + 2; + + test.clear(); + + for (unsigned long i = 0; i < alphabet_size-1; ++i) + counts[i] = 0; + unsigned long total = 1; + counts[alphabet_size-1] = 1; + + + for (unsigned long i = 0; i < alphabet_size; ++i) + { + unsigned long temp = static_cast(::rand()%range); + for (unsigned long j = 0; j < temp; ++j) + { + test.increment_count(i); + + + if (total >= 65535) + { + total = 0; + for (unsigned long i = 0; i < alphabet_size; ++i) + { + counts[i] >>= 1; + total += counts[i]; + } + if (counts[alphabet_size-1]==0) + { + counts[alphabet_size-1] = 1; + ++total; + } + } + counts[i] = counts[i] + 1; + ++total; + + + } + + + unsigned long temp_total = 0; + for (unsigned long a = 0; a < alphabet_size; ++a) + { + temp_total += test.get_count(a); + } + DLIB_TEST_MSG(temp_total == test.get_total(), + "temp_total == " << temp_total << endl << + "test.get_total() == " << test.get_total() + ); + + DLIB_TEST(test.get_count(alphabet_size-1) == counts[alphabet_size-1]); + DLIB_TEST_MSG(test.get_total() == total, + "test.get_total() == " << test.get_total() << endl << + "total == " << total + ); + + unsigned long target = test.get_total()/2; + unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0; + + + DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==counts[symbol]); + + if (counts[symbol] != 0) + { + DLIB_TEST(total_count == total); + + DLIB_TEST(high_count <= total); + DLIB_TEST(low_count < high_count); + DLIB_TEST(high_count <= test.get_total()); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + } + + + if (target < total) + { + test.get_symbol(target,symbol,low_count,high_count); + + + DLIB_TEST(high_count <= total); + DLIB_TEST(low_count < high_count); + DLIB_TEST(high_count <= test.get_total()); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + DLIB_TEST(test.get_count(symbol) == counts[symbol]); + } + + + + + } + + } + + print_spinner(); + + for (unsigned long h = 0; h < 10; ++h) + { + test.clear(); + DLIB_TEST(test.get_total() == 1); + + // fill out test with some numbers + unsigned long temp = ::rand()%30000 + 50000; + for (unsigned long j = 0; j < temp; ++j) + { + unsigned long symbol = (unsigned long)::rand()%alphabet_size; + test.increment_count(symbol); + } + + // make sure all symbols have a count of at least one + for (unsigned long j = 0; j < alphabet_size; ++j) + { + if (test.get_count(j) == 0) + test.increment_count(j); + } + + unsigned long temp_total = 0; + for (unsigned long j = 0; j < alphabet_size; ++j) + { + temp_total += test.get_count(j); + } + DLIB_TEST(temp_total == test.get_total()); + + + unsigned long low_counts[alphabet_size]; + unsigned long high_counts[alphabet_size]; + // iterate over all the symbols + for (unsigned long j = 0; j < alphabet_size; ++j) + { + unsigned long total; + unsigned long count = test.get_range(j,low_counts[j],high_counts[j],total); + DLIB_TEST(count == test.get_count(j)); + DLIB_TEST(count == high_counts[j] - low_counts[j]); + + } + + + // make sure get_symbol() matches what get_range() told us + for (unsigned long j = 0; j < alphabet_size; ++j) + { + for (unsigned long k = low_counts[j]; k < high_counts[j]; ++k) + { + unsigned long symbol, low_count, high_count; + test.get_symbol(k,symbol,low_count,high_count); + DLIB_TEST(high_count - low_count == test.get_count(symbol)); + DLIB_TEST_MSG(j == symbol, + "j == " << j << endl << + "k == " << k << endl << + "symbol == " << symbol << endl << + "low_counts[j] == " << low_counts[j] << endl << + "high_counts[j] == " << high_counts[j] << endl << + "low_counts[symbol] == " << low_counts[symbol] << endl << + "high_counts[symbol] == " << high_counts[symbol] << endl << + "low_count == " << low_count << endl << + "high_count == " << high_count << endl << + "temp.count(j) == " << test.get_count(j) + ); + DLIB_TEST_MSG(low_count == low_counts[j], + "symbol: " << j << "\n" << + "target: " << k << "\n" << + "low_count: " << low_count << "\n" << + "low_counts[j]: " << low_counts[j]); + DLIB_TEST(high_count == high_counts[j]); + } + + } + + } + + + + print_spinner(); + + for (int h = 0; h < 10; ++h) + { + + + test.clear(); + + for (unsigned long k = 0; k < alphabet_size-1; ++k) + { + counts[k] = 0; + } + counts[alphabet_size-1] = 1; + unsigned long total = 1; + unsigned long i = ::rand()%alphabet_size; + + unsigned long temp = 65536; + for (unsigned long j = 0; j < temp; ++j) + { + test.increment_count(i); + + + if (total >= 65535) + { + total = 0; + for (unsigned long i = 0; i < alphabet_size; ++i) + { + counts[i] >>= 1; + total += counts[i]; + } + if (counts[alphabet_size-1] == 0) + { + ++total; + counts[alphabet_size-1] = 1; + } + } + counts[i] = counts[i] + 1; + ++total; + + } + + + DLIB_TEST(test.get_total() == total); + + unsigned long target = test.get_total()/2; + unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0; + + + DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==counts[symbol]); + + if (counts[symbol] != 0) + { + DLIB_TEST(total_count == total); + + DLIB_TEST(high_count <= total); + DLIB_TEST(low_count < high_count); + DLIB_TEST(high_count <= test.get_total()); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + } + + + + test.get_symbol(target,symbol,low_count,high_count); + + + DLIB_TEST(high_count <= total); + DLIB_TEST(low_count < high_count); + DLIB_TEST(high_count <= test.get_total()); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + DLIB_TEST(test.get_count(symbol) == counts[symbol]); + + + + + + + + } + + } // for (int g = 0; g < 2; ++g) + + + + + + + + + + + + + + for (int g = 0; g < 2; ++g) + { + print_spinner(); + unsigned long amount=g+1; + cc test(gs); + cc test2(gs); + + DLIB_TEST(test.get_memory_usage() != 0); + + const unsigned long alphabet_size = 256; + + + DLIB_TEST(test.get_total() == 1); + + DLIB_TEST(test.get_count(alphabet_size-1)==1); + for (unsigned long i = 0; i < alphabet_size-1; ++i) + { + unsigned long low_count, high_count, total_count; + DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 0); + DLIB_TEST(test.get_count(i) == 0); + DLIB_TEST(test.get_total() == 1); + } + + + bool oom = false; + for (unsigned long i = 0; i < alphabet_size; ++i) + { + bool status = test.increment_count(i,static_cast(amount)); + unsigned long low_count = 0, high_count = 0, total_count = 0; + if (!status) + oom = true; + + if (status) + { + if (i ==alphabet_size-1) + { + DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 1+amount); + + DLIB_TEST(high_count == low_count+1+amount); + DLIB_TEST(total_count == test.get_total()); + + + DLIB_TEST(test.get_count(i) == 1+amount); + } + else + { + DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == amount); + + DLIB_TEST(high_count == low_count+amount); + DLIB_TEST(total_count == test.get_total()); + + + DLIB_TEST(test.get_count(i) == amount); + } + if (!oom) + DLIB_TEST(test.get_total() == (i+1)*amount + 1); + } + } + + + oom = false; + for (unsigned long i = 0; i < alphabet_size; ++i) + { + unsigned long temp = static_cast(::rand()%40); + for (unsigned long j = 0; j < temp; ++j) + { + bool status = test.increment_count(i,static_cast(amount)); + if (!status) + oom = true; + if (status) + { + if (i == alphabet_size-1) + { + DLIB_TEST(test.get_count(i) == (j+1)*amount + 1 + amount); + } + else + { + DLIB_TEST(test.get_count(i) == (j+1)*amount + amount); + } + } + } + + unsigned long target = test.get_total()/2; + unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0; + + if (!oom) + { + if (i == alphabet_size-1) + { + DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==temp*amount+1+amount); + DLIB_TEST(high_count-low_count == temp*amount+1+amount); + } + else + { + DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==temp*amount + amount); + DLIB_TEST(high_count-low_count == temp*amount + amount); + } + DLIB_TEST(total_count == test.get_total()); + + + test.get_symbol(target,symbol,low_count,high_count); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + DLIB_TEST(low_count <= target); + DLIB_TEST(target < high_count); + DLIB_TEST(high_count <= test.get_total()); + } + + } + + test.clear(); + + + oom = false; + for (unsigned long i = 0; i < alphabet_size-1; ++i) + { + if(!test.increment_count(i)) + oom = true; + unsigned long low_count = 0, high_count = 0, total_count = 0; + + if (!oom) + { + DLIB_TEST(test.get_range(i,low_count,high_count,total_count) == 1); + + DLIB_TEST(high_count == low_count+1); + DLIB_TEST(total_count == test.get_total()); + + DLIB_TEST(test.get_count(i) == 1); + DLIB_TEST(test.get_total() == i+2); + } + } + + + + unsigned long counts[alphabet_size]; + + + for (int k = 0; k < 10; ++k) + { + unsigned long range = ::rand()%50000 + 2; + + test.clear(); + + for (unsigned long i = 0; i < alphabet_size-1; ++i) + counts[i] = 0; + unsigned long total = 1; + counts[alphabet_size-1] = 1; + + + oom = false; + for (unsigned long i = 0; i < alphabet_size; ++i) + { + unsigned long temp = static_cast(::rand()%range); + for (unsigned long j = 0; j < temp; ++j) + { + if (!test.increment_count(i)) + oom = true; + + + if (total >= 65535) + { + + total = 0; + for (unsigned long i = 0; i < alphabet_size; ++i) + { + counts[i] >>= 1; + total += counts[i]; + } + if (counts[alphabet_size-1]==0) + { + counts[alphabet_size-1] = 1; + ++total; + } + } + counts[i] = counts[i] + 1; + ++total; + + + } + + + unsigned long temp_total = 0; + for (unsigned long a = 0; a < alphabet_size; ++a) + { + temp_total += test.get_count(a); + } + + if (!oom) + { + DLIB_TEST_MSG(temp_total == test.get_total(), + "temp_total == " << temp_total << endl << + "test.get_total() == " << test.get_total() + ); + + DLIB_TEST(test.get_count(alphabet_size-1) == counts[alphabet_size-1]); + DLIB_TEST_MSG(test.get_total() == total, + "test.get_total() == " << test.get_total() << endl << + "total == " << total + ); + } + + unsigned long target = test.get_total()/2; + unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0; + + if (!oom) + { + + DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==counts[symbol]); + + if (counts[symbol] != 0) + { + DLIB_TEST(total_count == total); + + DLIB_TEST(high_count <= total); + DLIB_TEST(low_count < high_count); + DLIB_TEST(high_count <= test.get_total()); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + } + + + if (target < total) + { + test.get_symbol(target,symbol,low_count,high_count); + + + DLIB_TEST(high_count <= total); + DLIB_TEST(low_count < high_count); + DLIB_TEST(high_count <= test.get_total()); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + DLIB_TEST(test.get_count(symbol) == counts[symbol]); + } + } + + + + } + + } + + oom = false; + for (unsigned long h = 0; h < 10; ++h) + { + test.clear(); + DLIB_TEST(test.get_total() == 1); + + // fill out test with some numbers + unsigned long temp = ::rand()%30000 + 50000; + for (unsigned long j = 0; j < temp; ++j) + { + unsigned long symbol = (unsigned long)::rand()%alphabet_size; + if (!test.increment_count(symbol)) + oom = true; + } + + // make sure all symbols have a count of at least one + for (unsigned long j = 0; j < alphabet_size; ++j) + { + if (test.get_count(j) == 0) + test.increment_count(j); + } + + unsigned long temp_total = 0; + for (unsigned long j = 0; j < alphabet_size; ++j) + { + temp_total += test.get_count(j); + } + if (!oom) + DLIB_TEST(temp_total == test.get_total()); + + + unsigned long low_counts[alphabet_size]; + unsigned long high_counts[alphabet_size]; + + if (!oom) + { + + // iterate over all the symbols + for (unsigned long j = 0; j < alphabet_size; ++j) + { + unsigned long total; + unsigned long count = test.get_range(j,low_counts[j],high_counts[j],total); + DLIB_TEST(count == test.get_count(j)); + DLIB_TEST(count == high_counts[j] - low_counts[j]); + + } + + + + + // make sure get_symbol() matches what get_range() told us + for (unsigned long j = 0; j < alphabet_size; ++j) + { + for (unsigned long k = low_counts[j]; k < high_counts[j]; ++k) + { + unsigned long symbol, low_count, high_count; + test.get_symbol(k,symbol,low_count,high_count); + DLIB_TEST(high_count - low_count == test.get_count(symbol)); + DLIB_TEST_MSG(j == symbol, + "j == " << j << endl << + "k == " << k << endl << + "symbol == " << symbol << endl << + "low_counts[j] == " << low_counts[j] << endl << + "high_counts[j] == " << high_counts[j] << endl << + "low_counts[symbol] == " << low_counts[symbol] << endl << + "high_counts[symbol] == " << high_counts[symbol] << endl << + "low_count == " << low_count << endl << + "high_count == " << high_count << endl << + "temp.count(j) == " << test.get_count(j) + ); + DLIB_TEST_MSG(low_count == low_counts[j], + "symbol: " << j << "\n" << + "target: " << k << "\n" << + "low_count: " << low_count << "\n" << + "low_counts[j]: " << low_counts[j]); + DLIB_TEST(high_count == high_counts[j]); + } + + } + } + + } + + + + + for (int h = 0; h < 10; ++h) + { + + + test.clear(); + + for (unsigned long k = 0; k < alphabet_size-1; ++k) + { + counts[k] = 0; + } + counts[alphabet_size-1] = 1; + unsigned long total = 1; + unsigned long i = ::rand()%alphabet_size; + + unsigned long temp = 65536; + for (unsigned long j = 0; j < temp; ++j) + { + test.increment_count(i); + + + if (total >= 65535) + { + total = 0; + for (unsigned long i = 0; i < alphabet_size; ++i) + { + counts[i] >>= 1; + total += counts[i]; + } + if (counts[alphabet_size-1] == 0) + { + ++total; + counts[alphabet_size-1] = 1; + } + } + counts[i] = counts[i] + 1; + ++total; + + } + + + DLIB_TEST(test.get_total() == total); + + unsigned long target = test.get_total()/2; + unsigned long symbol = i, low_count = 0, high_count = 0, total_count = 0; + + + DLIB_TEST(test.get_range(symbol,low_count,high_count,total_count)==counts[symbol]); + + if (counts[symbol] != 0) + { + DLIB_TEST(total_count == total); + + DLIB_TEST(high_count <= total); + DLIB_TEST(low_count < high_count); + DLIB_TEST(high_count <= test.get_total()); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + } + + + + test.get_symbol(target,symbol,low_count,high_count); + + + DLIB_TEST(high_count <= total); + DLIB_TEST(low_count < high_count); + DLIB_TEST(high_count <= test.get_total()); + DLIB_TEST(test.get_count(symbol) == high_count-low_count); + DLIB_TEST(test.get_count(symbol) == counts[symbol]); + + + + + + + + } + + } // for (int g = 0; g < 2; ++g) + + + } + +} + +#endif // DLIB_TEST_CONDITIONING_CLASs_H_ + diff --git a/ml/dlib/dlib/test/conditioning_class_c.cpp b/ml/dlib/dlib/test/conditioning_class_c.cpp new file mode 100644 index 000000000..4bfd9f32a --- /dev/null +++ b/ml/dlib/dlib/test/conditioning_class_c.cpp @@ -0,0 +1,87 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include + +#include "tester.h" +#include "conditioning_class.h" + +namespace +{ + + + class conditioning_class_tester : public tester + { + public: + conditioning_class_tester ( + ) : + tester ("test_conditioning_class_c", + "Runs tests on the conditioning_class checked components.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a_c"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_1a_c, + conditioning_class<2>::kernel_1a_c + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_2a_c"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_2a_c, + conditioning_class<2>::kernel_2a_c + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_3a_c"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_3a_c, + conditioning_class<2>::kernel_3a_c + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_4a_c"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_4a_c, + conditioning_class<2>::kernel_4a_c + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_4b_c"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_4b_c, + conditioning_class<2>::kernel_4b_c + >(); + print_spinner(); + + + dlog << LINFO << "testing kernel_4c_c"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_4c_c, + conditioning_class<2>::kernel_4c_c + >(); + print_spinner(); + + dlog << LINFO << "testing kernel_4d_c"; + conditioning_class_kernel_test< + conditioning_class<256>::kernel_4d_c, + conditioning_class<2>::kernel_4d_c + >(); + print_spinner(); + + + } + } a; + + +} + diff --git a/ml/dlib/dlib/test/config_reader.cpp b/ml/dlib/dlib/test/config_reader.cpp new file mode 100644 index 000000000..20b5215f3 --- /dev/null +++ b/ml/dlib/dlib/test/config_reader.cpp @@ -0,0 +1,509 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +// This is called an unnamed-namespace and it has the effect of making everything inside this file "private" +// so that everything you declare will have static linkage. Thus we won't have any multiply +// defined symbol errors coming out of the linker when we try to compile the test suite. +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + // Declare the logger we will use in this test. The name of the tester + // should start with "test." + logger dlog("test.config_reader"); + + template < + typename config_reader + > + void do_the_tests ( + config_reader& cr + ) + { + DLIB_TEST(cr.is_key_defined("global")); + DLIB_TEST(cr.is_block_defined("all")); + DLIB_TEST(cr.is_key_defined("globalasfd") == false); + DLIB_TEST(cr.is_block_defined("all!") == false); + DLIB_TEST(cr["global"] == "hmm"); + DLIB_TEST(cr["global2"] == "hmm2"); + + std_vector_c blocks; + cr.block("all").get_blocks(blocks); + DLIB_TEST(blocks.size() == 4); + cr.block("all").block("block1").get_blocks(blocks); DLIB_TEST(blocks.size() == 0); + cr.block("all").block("block2").get_blocks(blocks); DLIB_TEST(blocks.size() == 0); + cr.block("all").block("block3").get_blocks(blocks); DLIB_TEST(blocks.size() == 0); + cr.block("all").block("block4").get_blocks(blocks); DLIB_TEST(blocks.size() == 0); + + DLIB_TEST(cr.block("all").block("block1").is_key_defined("name")); + DLIB_TEST(cr.block("all").block("block2").is_key_defined("name")); + DLIB_TEST(cr.block("all").block("block3").is_key_defined("name")); + DLIB_TEST(cr.block("all").block("block4").is_key_defined("name")); + DLIB_TEST(cr.block("all").block("block1").is_key_defined("age")); + DLIB_TEST(cr.block("all").block("block2").is_key_defined("age")); + DLIB_TEST(cr.block("all").block("block3").is_key_defined("age")); + DLIB_TEST(cr.block("all").block("block4").is_key_defined("age")); + + DLIB_TEST(cr.block("all").block("block1")["name"] == "davis king"); + DLIB_TEST(cr.block("all").block("block2")["name"] == "joel"); + DLIB_TEST(cr.block("all").block("block3")["name"] == "john"); + DLIB_TEST(cr.block("all").block("block4")["name"] == "dude"); + DLIB_TEST(cr.block("all").block("block1")["age"] == "24"); + DLIB_TEST(cr.block("all").block("block2")["age"] == "24"); + DLIB_TEST(cr.block("all").block("block3")["age"] == "24"); + DLIB_TEST(cr.block("all").block("block4")["age"] == "53"); + + + int count2 = 0; + cr.get_blocks(blocks); + DLIB_TEST(blocks.size() == 1); + DLIB_TEST(blocks[0] == "all"); + + + DLIB_TEST(cr.block("all").is_key_defined("global") == false); + DLIB_TEST(cr.block("all").is_key_defined("global2") == false); + DLIB_TEST(cr.block("all").is_key_defined("name") == false); + DLIB_TEST(cr.block("all").is_key_defined("age") == false); + + cr.block("all").get_blocks(blocks); + DLIB_TEST(blocks.size() == 4); + std::vector temp_blocks; + for (unsigned long i = 0; i < blocks.size(); ++i) + { + ++count2; + ostringstream sout; + sout << "block" << count2; + DLIB_TEST(blocks[i] == sout.str()); + + cr.block("all").block(blocks[i]).get_blocks(temp_blocks); + DLIB_TEST(temp_blocks.size() == 0); + + DLIB_TEST(cr.block("all").block(blocks[i]).is_key_defined("name")); + DLIB_TEST(cr.block("all").block(blocks[i]).is_key_defined("age")); + } + + + + bool found_error = false; + try + { + cr.block("bogus_block"); + } + catch (typename config_reader::config_reader_access_error& e) + { + DLIB_TEST(e.block_name == "bogus_block"); + DLIB_TEST(e.key_name == ""); + found_error = true; + } + DLIB_TEST(found_error); + + found_error = false; + try + { + cr["bogus_key"]; + } + catch (typename config_reader::config_reader_access_error& e) + { + DLIB_TEST(e.block_name == ""); + DLIB_TEST(e.key_name == "bogus_key"); + found_error = true; + } + DLIB_TEST(found_error); + + + found_error = false; + try + { + cr.block("all").block("block10"); + } + catch (typename config_reader::config_reader_access_error& e) + { + DLIB_TEST(e.block_name == "block10"); + DLIB_TEST(e.key_name == ""); + found_error = true; + } + DLIB_TEST(found_error); + + found_error = false; + try + { + cr.block("all")["msdofg"]; + } + catch (typename config_reader::config_reader_access_error& e) + { + DLIB_TEST(e.block_name == ""); + DLIB_TEST(e.key_name == "msdofg"); + found_error = true; + } + DLIB_TEST(found_error); + + } + + + + template < + typename config_reader + > + void config_reader_test ( + ) + /*! + requires + - config_reader is an implementation of config_reader/config_reader_kernel_abstract.h + is instantiated with int + ensures + - runs tests on config_reader for compliance with the specs + !*/ + { + + + + ostringstream sout; + + sout << "all#comment { { } \n"; + sout << "{ \n"; + sout << " block1 \n"; + sout << " { \n"; + sout << " name = davis king \n"; + sout << " age = 24 \n"; + sout << " } \n"; + sout << " \n"; + sout << " block2 \n"; + sout << " { \n"; + sout << " name= joel \n"; + sout << " age =24 \n"; + sout << " } \n"; + sout << " \n"; + sout << " block3 \n"; + sout << " { \n"; + sout << " name = john \n"; + sout << " age = 24 \n"; + sout << " } \n"; + sout << " #comment \n"; + sout << "#comment \n"; + sout << " block4{ # comment"; + sout << " \n"; + sout << " name = dude \n"; + sout << " age = 53}\n"; + sout << " \n"; + sout << "} \n"; + sout << " \n"; + sout << " \n"; + sout << "global=hmm#comment \n"; + sout << "global2=hmm2 \n"; + sout << " # comment \n"; + + string data = sout.str(); + + config_reader cr2; + for (int i = 0; i < 3; ++i) + { + istringstream sin; + + sin.clear(); + sin.str(data); + + config_reader cr(sin); + sin.clear(); + sin.str(data); + + cr2.load_from(sin); + + do_the_tests(cr); + do_the_tests(cr2); + + cr.clear(); + DLIB_TEST(cr.is_key_defined("global") == false); + } + + + sout.clear(); + sout.str(""); + + { + sout << "all#comment { { } \n"; + sout << "{ \n"; + sout << " block1 \n"; + sout << " { \n"; + sout << " name = davis king \n"; + sout << " age = 24 \n"; + sout << " } \n"; + sout << " \n"; + sout << " block2 \n"; + sout << " { \n"; + sout << " name= joel \n"; + sout << " age =24 \n"; + sout << " } \n"; + sout << " \n"; + sout << " block3 \n"; + sout << " {{ \n"; // error on this line + sout << " name = john \n"; + sout << " age = 24 \n"; + sout << " } \n"; + sout << " #comment \n"; + sout << "#comment \n"; + sout << " block4{ # comment"; + sout << " \n"; + sout << " name = dude \n"; + sout << " age = 53}\n"; + sout << " \n"; + sout << "} \n"; + sout << " \n"; + sout << " \n"; + sout << "global=hmm#comment \n"; + sout << "global2=hmm2 \n"; + sout << " # comment \n"; + + istringstream sin(sout.str()); + + bool error_found = false; + try + { + cr2.load_from(sin); + } + catch (typename config_reader::config_reader_error& e) + { + error_found = true; + DLIB_TEST(e.line_number == 16); + DLIB_TEST(e.redefinition == false); + } + DLIB_TEST(error_found); + } + + { + sout.str(""); + sout.clear(); + sout << "all#comment { { } \n"; + sout << "{ \n"; + sout << " block1 \n"; + sout << " { \n"; + sout << " name = davis king \n"; + sout << " age = 24 \n"; + sout << " } \n"; + sout << " \n"; + sout << " block2 \n"; + sout << " { \n"; + sout << " name= joel \n"; + sout << " age =24 \n"; + sout << " } \n"; + sout << " \n"; + sout << " block3 \n"; + sout << " { \n"; + sout << " name = john \n"; + sout << " age = 24 \n"; + sout << " } \n"; + sout << " #comment \n"; + sout << "#comment \n"; + sout << " block4{ # comment"; + sout << " \n"; + sout << " name = dude \n"; + sout << " age = 53}\n"; + sout << " \n"; + sout << "} \n"; + sout << " \n"; + sout << " \n"; + sout << "global=hmm#comment \n"; + sout << " \n"; + sout << "global=hmm2 \n"; // error on this line + sout << " # comment \n"; + + istringstream sin(sout.str()); + + bool error_found = false; + try + { + cr2.load_from(sin); + } + catch (typename config_reader::config_reader_error& e) + { + error_found = true; + DLIB_TEST_MSG(e.line_number == 31,e.line_number); + DLIB_TEST(e.redefinition == true); + } + DLIB_TEST(error_found); + } + + + { + sout.str(""); + sout.clear(); + sout << "all#comment { { } \n"; + sout << "{ \n"; + sout << " block1 \n"; + sout << " { \n"; + sout << " name = davis king \n"; + sout << " age = 24 \n"; + sout << " } \n"; + sout << " \n"; + sout << " block2 \n"; + sout << " { \n"; + sout << " name= joel \n"; + sout << " age =24 \n"; + sout << " } block2{} \n"; // error on this line + sout << " \n"; + sout << " block3 \n"; + sout << " { \n"; + sout << " name = john \n"; + sout << " age = 24 \n"; + sout << " } \n"; + sout << " #comment \n"; + sout << "#comment \n"; + sout << " block4{ # comment"; + sout << " \n"; + sout << " name = dude \n"; + sout << " age = 53}\n"; + sout << " \n"; + sout << "} \n"; + sout << " \n"; + sout << " \n"; + sout << "global=hmm#comment \n"; + sout << " \n"; + sout << " # comment \n"; + + istringstream sin(sout.str()); + + bool error_found = false; + try + { + cr2.load_from(sin); + } + catch (typename config_reader::config_reader_error& e) + { + error_found = true; + DLIB_TEST_MSG(e.line_number == 13,e.line_number); + DLIB_TEST(e.redefinition == true); + } + DLIB_TEST(error_found); + } + + + + } + + + void test_get_option() + { + const char* argv[100]; + int argc; + + // program --opt 4 -d dude + argv[0] = "program"; + argv[1] = "--opt"; + argv[2] = "4"; + argv[3] = "-d"; + argv[4] = "dude"; + argc = 5; + + std::ostringstream sout; + sout << "block#comment { { } \n"; + sout << "{ \n"; + sout << " opt = 5 \n"; + sout << " a = 6 \n"; + sout << " d = joel \n"; + sout << " subblock {} \n"; + sout << "} \n"; + sout << " \n"; + sout << " \n"; + sout << "opt = 8 \n"; + sout << "d = davis \n"; + sout << "a = 50 \n"; + sout << " # comment \n"; + + std::istringstream sin(sout.str()); + + config_reader cr(sin); + + dlib::cmd_line_parser::kernel_1a_c parser; + + parser.add_option("opt","",1); + parser.add_option("d","",1); + parser.add_option("a","",1); + parser.add_option("b","",1); + parser.parse(argc, argv); + + DLIB_TEST(get_option(cr, "d", "default") == "davis"); + DLIB_TEST(get_option(cr, "opt", "default") == "8"); + DLIB_TEST(get_option(cr, "opt", 1) == 8); + DLIB_TEST(get_option(cr, "optasdf", 1) == 1); + DLIB_TEST(get_option(cr, "optasdf", 1.1) == 1.1); + DLIB_TEST(get_option(cr.block("block"), "d", "default") == "joel"); + DLIB_TEST(get_option(cr.block("block"), "opt", "default") == "5"); + DLIB_TEST(get_option(cr.block("block"), "opt", 1) == 5); + DLIB_TEST(get_option(cr.block("block").block("subblock"), "d", "default") == "default"); + DLIB_TEST(get_option(cr.block("block").block("subblock"), "opt", "default") == "default"); + DLIB_TEST(get_option(cr.block("block").block("subblock"), "opt", 1) == 1); + DLIB_TEST(get_option(cr, "block.d", "default") == "joel"); + DLIB_TEST(get_option(cr, "block.opt", "default") == "5"); + DLIB_TEST(get_option(cr, "block.opt", 1) == 5); + DLIB_TEST(get_option(cr, "block.asdf.d", "default") == "default"); + DLIB_TEST(get_option(cr, "block.asdf.opt", "default") == "default"); + DLIB_TEST(get_option(cr, "block.asdf.opt", 2) == 2); + DLIB_TEST(get_option(cr, "block.subblock.d", "default") == "default"); + DLIB_TEST(get_option(cr, "block.subblock.opt", "default") == "default"); + DLIB_TEST(get_option(cr, "block.subblock.opt", 2) == 2); + + DLIB_TEST(get_option(parser, "opt", 99) == 4); + DLIB_TEST(get_option(parser, "d", "stuff") == "dude"); + DLIB_TEST(get_option(parser, "a", "stuff") == "stuff"); + DLIB_TEST(get_option(parser, "a", 99) == 99); + + DLIB_TEST(get_option(parser, cr, "d", "default") == "dude"); + DLIB_TEST(get_option(cr, parser, "d", "default") == "dude"); + DLIB_TEST(get_option(parser, cr, "a", 2) == 50); + DLIB_TEST(get_option(cr, parser, "a", 2) == 50); + DLIB_TEST(get_option(parser, cr, "opt", 2) == 4); + DLIB_TEST(get_option(cr, parser, "opt", 2) == 4); + DLIB_TEST(get_option(parser, cr, "b", 2) == 2); + DLIB_TEST(get_option(cr, parser, "b", 2) == 2); + + DLIB_TEST(get_option(parser, cr.block("block"), "a", 2) == 6); + DLIB_TEST(get_option(cr.block("block"), parser, "a", 2) == 6); + } + + + class config_reader_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a test for the config_reader object. When it is constructed + it adds itself into the testing framework. The command line switch is + specified as test_config_reader by passing that string to the tester constructor. + !*/ + public: + config_reader_tester ( + ) : + tester ("test_config_reader", + "Runs tests on the config_reader component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing config_reader"; + print_spinner(); + config_reader_test(); + + dlog << LINFO << "testing config_reader_thread_safe"; + print_spinner(); + config_reader_test(); + + dlog << LINFO << "testing get_option()"; + print_spinner(); + test_get_option(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/correlation_tracker.cpp b/ml/dlib/dlib/test/correlation_tracker.cpp new file mode 100644 index 000000000..5ccf61c57 --- /dev/null +++ b/ml/dlib/dlib/test/correlation_tracker.cpp @@ -0,0 +1,955 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.correlation_tracker"); + + + class correlation_tracker_tester : public tester + { + public: + correlation_tracker_tester( + ) : + tester ( + "test_correlation_tracker", // the command line argument name for this test + "Run tests on the correlation_tracker functions.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + void perform_test ( + ) + { + dlog << LINFO << "perform_test()"; + + typedef const std::string(*frame_fn_type)(); + // frames from examples folder + frame_fn_type frames[] = { &get_decoded_string_frame_000100, + &get_decoded_string_frame_000101, + &get_decoded_string_frame_000102, + &get_decoded_string_frame_000103 + }; + // correct tracking rectangles - recorded by successful runs + drectangle correct_rects[] = {drectangle(74, 67, 111, 152), + drectangle(76.025, 72.634, 112.799, 157.114), + drectangle(78.6849, 78.504, 115.413, 162.88), + drectangle(82.7572, 83.6035, 120.319, 169.895) + }; + // correct update results - recorded by successful runs + double correct_update_results[] = { 0, 18.3077, 16.8406, 13.1716 }; + + correlation_tracker tracker; + std::istringstream sin(frames[0]()); + array2d img; + load_bmp(img, sin); + tracker.start_track(img, centered_rect(point(93, 110), 38, 86)); + for (unsigned i = 1; i < sizeof(frames) / sizeof(frames[0]); ++i) + { + std::istringstream sin(frames[i]()); + load_bmp(img, sin); + + double res = tracker.update(img); + double correct_res = correct_update_results[i]; + double res_diff = abs(correct_res - res); + + drectangle pos = tracker.get_position(); + drectangle correct_pos = correct_rects[i]; + drectangle pos_intresect = pos.intersect(correct_pos); + double pos_area = pos.area(); + double intersect_area = pos_intresect.area(); + double rect_confidence = intersect_area / pos_area; + + dlog << LINFO << "Frame #" << i << " res: " << res << " correct res: " << correct_res << " pos: " << pos + << " correct pos: " << correct_pos << " rect confidence: " << rect_confidence; + + // small error possible due to rounding and different optimization options + DLIB_TEST(res_diff <= 1); + DLIB_TEST(rect_confidence >= 0.97); + print_spinner(); + } + } + + // ------------------------------------------------------------------------------------ + + // This function returns the contents of the file 'frame_000100.bmp' + static const std::string get_decoded_string_frame_000100() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file '..\..\examples\video_frames\frame_000100.bmp' we want to decode and return. + sout << "Qld+lmlZhXVX5NDRWIFG1T4+OGdJbkmxAXXdHZEtpDf6knVTlRWyAhgv85Tf11KJZbhKv1dcsKtQ"; + sout << "fX3y/9RxoGwtXxs/KaBv3IrfvBK5WFRWzOM7H11xVin5AjxdzyzgU28fRgEasu0Jvk3SaMBNM1cI"; + sout << "ZAK+MA+qEKwBn+rkuLbptS4sUNNC8PWaqLqNQ577TiMbsEgoa1FGkZX4feP8EfvV7w8H4FvbS+Cl"; + sout << "yJCj0Q0Bx2fqFCDaAVtlm41VIJUFS7AePKmSnVnAlYYOQ35XxWS2sjAlWZHXiEfKmWmSOTt3x9/u"; + sout << "5l78aPB0FvrUgF5RrvyVIOSgsC250JwINl8uulOOsykmhUUjRCw81dAITDr/d2EAYmPyiYNZWbvl"; + sout << "0Iy4f9sOJr0HnaPdCSvojI9/xdQVT3MVjHhu82UtpJ2cuc1MjUooCZdQmTn1mksdAdZtsn78aRMw"; + sout << "P97MKVgFmYMmG+yekaN6+nYFgTXnuEnatHDszyavB3+73iWTYiYs+CDRqEJ7XUkp+F97W/hUuvFr"; + sout << "DwePerItRkWZN8Nmnjx0Jibv2wu37zkFVnFIAydvV9GzaFDd75hUwQE5rjGN8jp4dBMNwcK6l7kt"; + sout << "xLm5hYF2eo9sESGJqR+8uAhGsQZESUpdBvLmf/+28MVfY7Li0lmeq+bz+JxvIa/jF5ptffuTRtZd"; + sout << "WLbJK+EXlfeorbfh+6di5CPp5/50W5zrOepXHeDfrzCIpz6XTCw4k749ajiMk5y5XUAk/gObe1bG"; + sout << "JZeGfMezy6bNcosolf5mmg5yDk07bK0JTLGFvYALT8aZFfUTHR8+47krp43r6XxaalqH3JFy3pYz"; + sout << "3t1IQPC2wz9FiCGBpn+UdULDwpBt0oqbGJYGEHDWgzwv8dv31NEbqez+G+HSEfF7CWs80SH3yYzL"; + sout << "9EjKo7ANmcYWtr8+3/SpUl0Yyzn/7zDL7191rWglr3N06+bB+bdQy7J1yGueGtTsO5XMqfUEQL5n"; + sout << "xQ2FOU9c39cruW0Fp8TDkEKM6gowBRTelXMA1w84+L9Dynheg1iLE8MSrzvA0/TlXoFD7Wt1cC/6"; + sout << "EoYPulG3BRy/3ugzUeiHm4ZWh95YOflbH2p8Ix+5/sBZ+W5xSr/37NRYi1nQ1jFwqPdAHx/OGvfW"; + sout << "BKhz+PF3146JpvuVbq9fTqjqYeM+qcAu1mS+VqDyJTw3IlvRUsixyJu7/eN01m08yumXP2mm4SwC"; + sout << "Q0j4aPKRzl4rRx5PJE4qnKxmrxSa5DavIO6ik+DKR+leGGk29zEqY+UcsqcJQLKWVy3k40wL8ouF"; + sout << "ihEPUISSZDTE/iU/4wntqVXCP9vRNYnrnXVbh6EFFA6nr4dBrj+qnkkpDxrt3tGA0Dml1u68hNhI"; + sout << "aI2bAca+a1TXb9mn18JTJqK6tkoJmE4JP/T5W03PSGYhRDa0ddNEt06eVPuCrSyDwQ2lQVjHiF7N"; + sout << "iHVbLJOQ49//SmXdeId3+sJ4NwRM0jquqURHrBAYwuEImtqZDgJ9Ac7iC5FDMuigc3KJ6Y3MRAI9"; + sout << "irrbcei2XbALqQH4hsHiOvplAaUwqE2f8iaHHQrzSKdNoRpl/0rWgEU4eWcKVl7JmmwLRWqRDg6x"; + sout << "jYDZ9Bu3KAR3llOBPHN6G6VcVx8TfP6qHb7bYh0+lKqIv7qNYUdbVNZLnfBSrgYkrbf8LmgiUAyS"; + sout << "d3JNuR7XsICjOPdEgC5OhKHNT8w/wcZBN1T2svgaokoYxI+dV9t84s5h+W1RgCLTnEu5Lz811RlG"; + sout << "KEAZ+AgqPAk6PneeZs/Ujk1Q8ISZ/K6pL1SBnZN2RaN731UHMYJKK4/yCfdkqGAtFBnS4vvA82tV"; + sout << "uv2SwPvhD2KHvZDCsAVPXFQWVVkzTWKRcSXe55vgnZjF33ziAFILy9xNklmZZ/tDpZa3I6VDkNUW"; + sout << "o5+eLuUcJRi/dlrKnNaHSsjV6jQd0ud+JB+Gv2YfV9XcxFZDPavRGkNRdOTe045rJ7QKTI5OqZ5o"; + sout << "DZ9Olo25TyXVteyt9CwVAlLB2owBoEmni8HFWUS1TiFWmU+DupCJw60bwE9SLhR1KpfUp32Myg2S"; + sout << "6e7PJVndER84KdqRkPFWZyJsBTXJ6U1g273ty7MLFxZYmCp7SwvtXNrPMJ5yCS5VGa7sX8xavNkJ"; + sout << "ade8iiSs0Upqbm71iDbMCp4oBRxa1w9zMWZ0VF+Qg4HVyKdyRmtJzBmRvUmujyYFCQwMtmzCgIu7"; + sout << "mPCbA7EWkoteGnuFIgag9P7kAzEF6pI46gd6UyOOFF44eUytnP1AqCJYHvNexViNZ0FXzgoCzHMs"; + sout << "m6RPhsqJnQjG+KaefJm52DCfLNOPd6+rj4mv5iQeVEyNUtdHuS3ll1svho7IKMF0KXyjajXV2X8i"; + sout << "BaHxZ25zSlogQy//I4qdHx2SlUkiGJC8jpRZ9LVmaCWGCnbEnHn3uiiwTLRiPO0G2qgbMjgMvius"; + sout << "XjOSu1Dlbz4A4XJ5nxRyF+7uG450OJbVF6LUYujHukrHn8LADkKQxY60sddDMcx16lofSVJEF2X8"; + sout << "pwhXS+EcqDHQ5pT0Z7+fMFSumC6kTWJIfcwXrTJDcWslF+cnQw/qmlnSjK1jnTyU5YpjyByETfWc"; + sout << "hpFJbt0INtrUIDjOwXha0dlbC/jLXrVbW26wTJrKmi7ftD6CJSE8cZBGLXCX4zakRRPg3DEHsv8D"; + sout << "BFwPWlX7keygRS8L6o8teh9LHlYSbFqlno7FNqKh1V7D1ERDvsBO+6xoabo+mYXOdtyVWPWxJ0Kg"; + sout << "1qsnFibHgmqi4RFHoSMkAcHVYj9cHhiFo3DSj8sm8l6FGBQSm+o+oehamFW6xIL9ICgm7gcBQH6N"; + sout << "BH3f6q0CorAetVw7kyLqNiEt7G3V1DRjJ7bcn0/ziYsn6dZQ6kq9gslBWZe0dp227DpameI0TliG"; + sout << "pZ+d5+sm6g1jE1H5H6Zlh6jJ9xDMywzNFOitOdcOL0YFuzsxKXVPA1Ye9/wYqtUyJbFvO2ZqDSbC"; + sout << "k2z1UcbCv9KQhonO7aO6qgkaRIFmPmV82R7JsHqjA7CPQC/hURfB7qWP5mzvyB8oEqtd8LSuMgCC"; + sout << "2WRsJTQ6GU40AoOcUFDPOo/Gv99uWzDQaGgJJzIxPlKhT3EtJWH7ZiapmojgRgdS2iEAXb4adzRQ"; + sout << "/NsD80C6PvsuxB6NEravtrBV4nz7fXq2u1mPkjv+/jmp0hHf5SV1YDE+j+EfZqNMgEMaWJfMDbxs"; + sout << "C70Ji0d8iv7zyevr8fIrVmTbttOxS73NShBRX7mHQnUcDZFDfLJy464v8PEvwX0gr8/ytzdMJTrN"; + sout << "LEYHS/+ZxEN3CZyr5tzYL+DdqsTY4qFyuoQOQ0brplrGc2E1LubNhGVHXoKTsAE1+D9oq81HG4VH"; + sout << "Y91M0A3+ckyDvQ+dTtPx+KT0on8OkZzSLy0VUJ6yHVmzlzL2R2B/C8Gp2KHKk2YJLxE8DLPRkOyz"; + sout << "ZPrWyDUWKir87Cm6xHLScUj0MWujlg/ag1CLz38LKEBzoYxCzqQ/Zmyfp3NZRi+PEhSlSj8AyVPv"; + sout << "FOFacvTBn+dQwWq6uLbOBXJiycohjgBGxrLc7ILTp2FHmNfpELCTCowepYKuO8VnEBGlIUJAH7+I"; + sout << "zTtM2bvgr7k0b8dwh4SMdCBSOWPfYxuY8kVy/vcW0xKeAkrwKlVuL1+HJAtUTK0p9BOWl3bi7E4q"; + sout << "Ur5UH6CSaVzbn7QvGh1wSKbr48XdeOx/MlqvH6yKeXxT+iogx7BdmQAMOsndXWH+T+HBbmqIF6Z8"; + sout << "qA68Q6qcooNpHu13lqIojIbbf6VUT9eG9RqGPYXxKBAdWuRDbMmklXYoM7PI4VlmBcIgq+4U9Bps"; + sout << "Yge/Oqf6DPFxPaqG+FnYPl7GpcXHNHAsBh5JwaDAM5uKCq4Y720KHLKqG1buziPjRwUr/5LhXyu6"; + sout << "CF1ZzEKb6w26NDjD47myrFKRihB1zVZBxb6uIlHVcEqFZhJqC0uV/QQjRUfAQrG9G804Bcf7nuGM"; + sout << "6u+nXr1Xl+oQIRiEYueMD1T0WbiIYIIO6lT0ICkSQQjoLQXq/8KP21Yrutv3H4PjDZXOPi9Z6Rcl"; + sout << "40Vfdt9jyqeRgBhkhloRYFOBwXLYxlb6umBcmXoT3cv/Zlh4SfQrpG/LrfAHtYFyfwoDzXcV8Xuq"; + sout << "qAtTkj97UcAVnfC0lNrBnWCT2SieQUl8nNLoIQl9uhzRBFHH7U9ey6uX3zop4esABGV/DL1ypLd2"; + sout << "Yu2Q6XU+1j7wh3Rn5zttubqvYR8P5jMEqMUaX2Fith42jHy0U1NuNzX/MPCm6DFLvah/G/0sPKiP"; + sout << "lboucbKekedV/knGV3kx+h6MmMgBRC4OkkPf3HG5ewpVpe4I3SZ/auTk1ZnePM5RL5wskdGCdbND"; + sout << "njzEv5sKtTE1TxCAgVh6iFUtvr7c+3DNDkTilcm+ILRMwEr9vCF/OGcOT3YWr9cW/eP6fy/801bU"; + sout << "nhotDzcKSPQbbN7jd3WIe6kaeXjsBk6LIHmPjtGPT3UdRFhEJZrnABhQrasLPeY0/+dODC5Y3GcC"; + sout << "fV68kmensUcZ/5hFSEpLsScx/OtGvFI7hLNBAAJ6bhEuvgId7OXxQgWjQ6wKDhwOJmEB1Ra/Bztw"; + sout << "4bGdRMl6c99nGcc9GmSWI1CmQZqbNfxL7AKO1oSkwpWbK9E9Cl+ZV5XyP28AQky4KQlYBvwUDH57"; + sout << "LvtOvuIe/HzSagQjG8Sxf7wGF+MAz7MLRtFYGkHviL+mRvtp0NSHSYDEOqaIsnk1HbTWHUj0FUQt"; + sout << "AD14VkX9ReVi1LnJGJNwmjeLYk7keYTkksi/kKEWCkh/8Av2Rlk7oZCwODxsAcnAyXStJV0tEAHN"; + sout << "lq7MaAJz5zWZ4ZsETA5mn+3Qj4jIi5vbqJScxOh5KCvyOYfio9Eo+DPvyJtvyjGnqVkoFFk7v+8H"; + sout << "e4d20Aq0vfYGKPzVNoXAtOPuA6/d/2yizrl9gEPBGVw2I3Q1/XlFMIHHZ3yT1TGJfGXcVkrYRhDY"; + sout << "s8k6oEKbp0QQUiSkercwxbLpSMblFDaGeudPcWzKVwPE52rMEzh1/jr72s6ZUd/+Os2b3+CEg+lj"; + sout << "AlnkU1tNZVnanmzOb5xBnk1H+3Q+FppUe9MgOK73344O7QlbBklWclKizZIoRVaOeP7WvjXSJK/A"; + sout << "PPXXyKLMwP8eYDDPKOTA7MvBH3q4r7j7Au5av2sn1ZGwxBHLH4aYFH54kXJa5rwl0TgsTAHsQ2+5"; + sout << "0/OWixbr8Ysl2JxunMjcewhJFI3mLHFmsYfVYDwC2+hIubW231HCxvqI5BgWZZRU+9DZ2hdQ+orw"; + sout << "qocOI3yK5uHs6Qaxv9HPgWzNXLG4a1h2C+RbhIFMGNUcDAGrCgBCgepruLT0RLuaSYrsZY/gWRwj"; + sout << "krk1M69/lnwvZlasFruau26RLyMAbELixpeeGpztv/tJx5fqStsjMGLHUsRj0SKgXgdawRdgUXUs"; + sout << "Hepp9HsmJ/HR+5xAv8ORiEQekO02b/2wgqGTYXyHGT+RP9f76A90kVJ6Px4njrOLrICyTUcIFrIh"; + sout << "0AvbAEysPnxbtYcp5tdNbG2RuarygtgscOxERIa1NaunqnBouzW88mOLeTgBV9Ofh2fvX3qalfcs"; + sout << "xYY4//taBvyKWHiCanf5drYyOBPm1m8uzNP9ew6K6IYC6S4m8aklO+NrkoCqL2sq6Ged5y1nsWws"; + sout << "BLeSaCakjVDU3ysh+J8YamF1Tzp+cOfh9dAquG4sk2jJmSfgl31jtG2lCBkBRQf+Cybgr7QfJYGd"; + sout << "ZoX4RLzmSK3BXrufSNHntV7dXKCDT+2NvrW4n1EUegwonALERWRTMRDmA2NPZYgqr5cy8v0KQr0V"; + sout << "hvEMjqP/9SPJdTtbjJY7WtdTNU5Er4KdzasYcJjphAPU3zC5PtCHbtTVUGXNz1UneXnzrN13zlK4"; + sout << "WyJ9UcXeWpjtztobQZ/8NwXiYOndZ+/qF3BdYjHYDYhSjod4JyCnmLuy+cVG9yB3e21XVeUVC8sh"; + sout << "7yg2oBx+9yn1ZK5ResSEAmX+m6Jq0etJWVnYvnl6TSN0XAFbZRk5n6r9w443Frw8RVRfYyISbuTm"; + sout << "oLMaClZFf/XxxvCIQAd6+IDMzukUDvEKObp2+rbf7IWZtDUeR3VpCp3CJhTMr2UBRC68fwB3mx/n"; + sout << "C2pAyPNX8WbZ8ZpAbtW3ax8U+yh2rH+hEK6zJSXFZk0Ea9Yc4MNhtDqGEXLgBvPX/OMH4E5wJxTP"; + sout << "H070u7+OQWwA+Aeup//kJ+jm8EqF2RTAyGCiiYl8JXAMNBdIGaG+kzb9RzRD+YiyOrV/WyXn2pFv"; + sout << "68NBt9sunIBnzNMLCPzT3lb4DcoOExmoOjcrziN6G2HjRsrkmyVnjvgoWO2wILYkm+yGy+ZEM1U4"; + sout << "MpVYZgXikHOos/FR0GG9H65gUCPthPn448mtVNFvuYJhhenvuuSkUWI2IFW+rhEeOUXg+r4ZISpZ"; + sout << "F26+EzKiDYKJWbJQNrVJv7CygE6SZ6E4VnKuXbzKHDct3O5EHoE3NDCO7J5kD7RyFE6HgUXusSgb"; + sout << "kH2DOgJrK8KXhwcDIH4AmrPFukCmgS0/yTWOjiZKW8dmp9q9UHKQlRf1rqgmkph2fGXXBbi+AMLd"; + sout << "qtk8gsm5Gb/a/fHIg8zOJ295ZShCqWm87z9jK5lx0/kFZKraB2JleJ0ryPFXCOp60hUbSvzzfxJ7"; + sout << "JkH1dWoRvWr3wpXrwZucfHH11AeFRe3ShRaAKy42+CclzPkWDFLnJQt8NXSkZeuoPjcx9A7lI6UP"; + sout << "WfwYZowJkoGDUq//TSqzSK6XAvQBUHuviclx/4/C3BvCwu40wXDadUtt+2xWJjNNvMlhnoPY3/i0"; + sout << "yKS/BSg4TsdTceqfkt0nU5FeTdyfpLW/1TOAFERXkV2bTkEnoJ8g5LXHqj58hDZRnO45lxDXDYvX"; + sout << "/3G7ja5OfBA+8flE+MDTtRnWAxgUBDxlFEKD2T2RBgQlLcy0Xa23LgAD3qIWYH+o5UVAmVoI5HjL"; + sout << "mJ4FxQdvJOH7VwyIubukaioVoB63Ls1ySl1Ysk7EdgCpdiugM6I39QsFlYeHE3AlLeoHn4caWBFE"; + sout << "lWIhgKer7kBvPFG7M6bN5flAtzAbEXFAH4M35L/6Cl+O2h+BvCJFiK1W/LZw2pkc+Ie0xC0YNWSn"; + sout << "5nAPPxgGYtxCHLVl3R2XJvBcPpAHzbc05fAp2G4AKscWc75SZwfayKKYarX2U+zu3PVJIYbioSqD"; + sout << "B1b3Wq4vO7MDdD7AMZjJO1ZrX3+8piEVo1MeVbqZBy3VTbjsnkkErQfyOulxlQfdcblMj2Zn9NRg"; + sout << "EQwOnGEsWulIurqEe7bQEZldvTTz4pTOZ2W+uv8TqfnvtKgNESaqYvofF72ZFSJ4rMU+Z18Fqaep"; + sout << "5MitEMIKwX3ttFQlgM4/CWeaygkhyNMbFA2C/T3jIopsy3bPo7AwwhHxaZIo9ghkvt0nTP80dgtU"; + sout << "DhP2a2tg+fNUkvWFSfdP9/NKyH5Fahlm9UPNozDehlI0dzeHvqSpdYFYuTYtf1jggNVeC3r/Zfq8"; + sout << "+B7U2qq+e4iGW1KVkOlimC3tYTRfPj4ZILtHAUIwHh4HY7pTxxpErDLpxyEZJNkBUxgN0hPLm8UN"; + sout << "hTu3qFXzB149UpZCv/JmjmmEgLAWMT3PbBCebNVPEjfO6I59rTOvl+fVD3X2UJM1It4mh3NJPHzj"; + sout << "4O7qLNm+S+A00bZx/g4ncCxrCvFkTpBmt21ZMAI/H4swFDlJ0gzKTq6+nhI/nbiWsuRY8z+GNqcE"; + sout << "Snutk54hYXlQ3tVAxhRY4srAcEQMcHjX8pWrJUZjRiiYWnVfnmUf6U3rki8P25851GwTUGB6/lL+"; + sout << "6MWKQ+s1Sa71hc5icdJtxTJSWADg41KO8n1xQ0Pu0KJzvdNMM9G4AswY87XnASsaF9EXKDTSW4q/"; + sout << "s601Ghu8vEYKQVBNfJdiZ0wPsaLM2SvRwSS4Ji7agaJRLBvcv/cK0Hrxml20CNIGS2Q81xGS0Hdk"; + sout << "ykMptYe+8pClFugfpJ+ETSqFgclYc0XoucOUxFh0vyg5CVE7WX6chpBmdhwWNnoyJz+WvNcoPIu9"; + sout << "UZdYaseFLVhArb8cVaA9Mfk7tmFxaxNeXsIBFjiE4dcWKJbZXMnkhCaqG4k3SIHsswxjtj9hFoTB"; + sout << "BzmbIwFhxPg7sZVrG7Af26CYC01P2PGqNqJnZRZQt1LwtPVQHzvMk2v0r1I3DDD6ugGlva+PrHIP"; + sout << "D8FCY/mc0poDzTANlwAx28hkPTbtCpJIrWScZ/Vu/KYJo3F1Gk4C4CwcghgkwTYLhe2eMwlnA+Ww"; + sout << "vOlw5SD9jca7GrTA2tkTvsUnlskDgGlkAvEwc6N5DkS6clO1XRbh1mwhr4UwRhkR9Z8sQLLUv0yt"; + sout << "N/Wq4i6xuXuROdy78DSiSh9Gis5XQbqCvf1VKUOKkaA+/H0Y+XsrHrRCcqE5uaY0iIZtc62XgVHW"; + sout << "QyrlDHsh4Lt9qGD93Dx2EZQqyl8KoyhQ/WbgnG/s77zdSNGTkoJDEQKIrKLRk9ReptdqQzjLPJvf"; + sout << "GwN7wgO2N68B9XbX8hfXUmHX+G7kVnncugQzg0DS1qQ0Hbp4ibZHKEAUHtkoPEVgzGcsVowGYCqB"; + sout << "KCGaq6FqyyIYp+UO5j00xyftvh6uta3atuu+JHkWAPGKY7uooV3MGmdcnnF8umE1NDEBfrcWrrNi"; + sout << "IzQcnZe31Nxqo3WsknqYpRvCDwUZiU1f/EidxhykoQ1NCo1CY6ociQna6kpsM52E9ALvYlUyobN+"; + sout << "7iJQcfagbJw6OpT+P54HFsedkIAlBgTIfZfag1u4lKZnibEZ5SEBfOGeI4pemf/ST+4a+2GalVHc"; + sout << "4W5T3xhuJjQiIeE8X2/nEwqEPQRHdTH109EW4BCv1rlh+XWqRXQVHPOYY3lVV4ONn7/iWN0fslOG"; + sout << "z2/MghTZb7Z1MhiWLxTcwpUXoyy/fexdBySIUf2ukM45ZPLr7T1aCrc+pW+XKfTewxqhWnKSjpVR"; + sout << "o4doFMv+eVk8mQPFjAP98MdIdYEgTakSQPoEdHijyq31ID0lI3qDPElGdFEq/yKpd9Tg0i2OaOvk"; + sout << "BiaBfldrXL+7jlCSTB14Vo0RCiGIGuxqmEgQhUrT2iHV1baKVWUuKERauwS3jVNz+xq6PsAU9lWn"; + sout << "mC1WbmHOWyHpvoxE/X3X53njIVVIbRnGwEma/1THUaFIZ+qb4UVfVqBXHCiCtyAnf7+aH2/T52Pd"; + sout << "gayp9h72ofbN/eBLbc0qXqECKcWLvwQlgiUkV5490CkjKnf49ubdFKRRH0PKCi0CR8T8lfhrHYnf"; + sout << "j29CDEgwX3LkIPw9GXt93ua7XgOilyqxhPm5vDXFpfgR98Y4iMUyheHms8U81xbRmM7SqAPXDJsO"; + sout << "vjoNtHmc20hgCb46BUQbMQMe/hhDjQNFMQ+HA83cvTyCqsXVdKcWIw3vPfD3RUqZ1R6LGgmLe+Xl"; + sout << "vS2S/uDJjBohq5pJ7xCSI2+ylMf47qtS1pYd1bUXssnpJOu3kplTwqW5idkz8l01FlR+duXcn4we"; + sout << "Wop+0ncMd5AsBpk2j868TAPVEJyC2lyQ9Qm6OouPEtMg4x0jyANl7JhyfEesTTvxIm3GNFe0fKTT"; + sout << "cFwOojUje215zqgOqwE0b7JkipU4ssQ6j1lSHuW8lWEVKa17Wr2B77lj21JxubQJ99iifjoarWGs"; + sout << "5NX5nxCQKJGDEPagDVmL7TZaQj+Cp34MR9mCJtRTyLmA4g2rwARMsbq1btaKXUG91484ipcu2jot"; + sout << "/F3wLGdAhXQFWrxhIJS0ORxalPcSZ01zkECazBTIMokBrneVaX0RulGEMrs+CIGiSE6pFJKX18xs"; + sout << "6dGqxF6TkPk3GalrQNn0al/8FgetsD8zUX+d5PdqOpb7qNqEfHAClOcNxhv341nNrGnR7YuO+r8R"; + sout << "67gAjqnan2NY02JmVhT7xqN1Uz9jCWqmbLNAPDUWPvVLvFtHAqqp556xhpVhJchWi81mIZJ3S2L2"; + sout << "TKd+jGHjojiYNyjQxz3DD1S6ZN4mPgzK6JMC4txgvcQ9qoWhoqBDRDrj9s3lT5kkjZnHRPO0RE+s"; + sout << "W6m51Z6sKMT1/yoneuJEnrwLxlD3Efd/cq8zmmUKREZgqR0q7WkXC1Tfs5sPVwHDlc45p2ejbze2"; + sout << "iDEv4ULiYrNnj/wdOSwVnNt/1DXvQ1FLtJ7K/Bt0qUszCNnNu6ZOwhOM2bqVugcOvONb6fLNOzY3"; + sout << "WYA+2RGS0nuHC/pFuzX/AQiRveFsq3IqtYxyexv8uxmlg4sPDb7Q8i3pT5fQDYAh6CIFX7mCITvy"; + sout << "2JnjSfiVqGdITtvsHyJtUVO0YlQKUETRBzsUQ7bvQ9gTETskkHNPe3h/VR/Dpa6ukhyspf1G9XI5"; + sout << "bobOkgu1/3ExBOTiCjbcxgGrgXw7VdmFURpq+FAQwuNxLDSoNfwFw6ISgP80lgBDV8/5l24p517f"; + sout << "fvNPTCus2I2A//vLxMIQqU6cz+kIeViDq0fcTxoBiD7SoISwJiRTqciiJ005uVYFAisBpU06k0NJ"; + sout << "NQf1DfYz5JKAiP05ehMhQLCnJaHjbC3yILIfxXYu4wEV2lfLWWZ2u7/0oDBaZKNHv5JLzjQglqBr"; + sout << "o+0GHf/hhu1EEqAyFRPWruxhJ1XzvbLt8sHc6wsxkQCYXPGfxlz5+WMrSdNP6jPoEleH7xcL/b1r"; + sout << "mw5oP8fsoppeoxiK63Td0Ut6WtCI9grmCKOYJt9UfzTYI4TLDsI7mtofZPUvX8IeOOr8LySclV+m"; + sout << "AcwU6BJepTxRmINnck9tCe7m5unJI8nBy+uy6a9NvSILGvuoJ6bidAqS0dojvAV9m1smC//ZJPLJ"; + sout << "8UMVkhMUHEgx8n/Ss08rQXBFFqao1rCOvbUR1XC6g31GAju2dLm8Zyk9eomFmdOdtRTobmK4XbX4"; + sout << "mNwCOlbKSZt1aBDN8JLYBR84cEBB4sjsPEXupIrJyKAe306c3RrnYj5lMxvnnknMqIfllkKr1BOm"; + sout << "kBYJ91aDe10mUd7zHQwW8KAjHI4pmDhLwusleQcl0RL2Q6CgOB3x6ZaYnXLtuWsOf958+niS6dg1"; + sout << "RA2Lv57ajSAOHzvuy7h/WV6uWhfsikjw6TEci8rQ6v5gY2DEMjrtSCbJeiJzcIqIC8bz4lvmEfiV"; + sout << "QpZPhGDgfS73qeZV7ljrfBcjvSuN9MPbMFQfkr5v9lTNJ+/AosXAjqM6aJ4TTrMq3XAYMcbuEaDt"; + sout << "89cI2TabaBe8B7cniD+JBOu73fndg6YglAy0Cl41GFjU0u/xq0xHIim7e9TLVtTJE47CjTWRrBEX"; + sout << "ZAJLPDlnhevjsz+0vEuLeqGnsX5yjbmtZzpkDvh+R6eZAJiBVq9uOLHbKqplwbjU31y2OO+Gf2Sg"; + sout << "C8nczxwSt6JT8ktG3CuGhuEIGi4l5LAjIjYd4LpQVcsUM/vP5cbzaC7/XyLmSY92KfBD8OuUL6FJ"; + sout << "kMNyHHEtawNcKlUWsW6N7ybRpvZEwRXhQP+5Q3QXHDbiQ9YJhvnnLWHFJx1TmMlhAQlyJBnuGD9q"; + sout << "Re5kqVi5ztXFHLY+yHAFHgCVbGq7WS3xxwQD+jeuLEnvbtNF5qUePj7u7psaPpYkQEj2DvRBoBGm"; + sout << "eEheol5Gc8NcW8wRw0hLr3Syw/8b+1uMl7c/Sxx08yvnB7e8JNQ4kQMkxVWFCWE+OXKul5nI5mBN"; + sout << "HRbSO5bBCuYU/9P7DvvK2U/WmcnOvNuFFax6mmBc7E8nqao9wKmpypDvddWJ+aStllc1xZt875G+"; + sout << "xrYOX7Iv0MSfk41j+VwrawGyVYKLrssBQoft2lbeDqOsJBgxS1wFmxQXgFl5pW2ynnye7GuR9A+t"; + sout << "pbgKYqFhYg8v2ZIYJ/SDJey5H6TImYlPNriXZv/ocn888rvlfOBjDcG9KiJ/CJDLs6RdMMI67EOE"; + sout << "5VBRISu61OMlhv/zp2bOsyvONztlHXERM+N6zViIlmbewanZ7ujo969A8Edx1YiD7eBR9AkB2VxW"; + sout << "nlnX2RzojsFbKIQjuE5NoFZdE49PsauvRhzYBid47XuPLu48IGlVWuX8dUENwWe8d3jUb3c8BPUv"; + sout << "2E2TUEOcbEx/b+q5ex28LRk3CL2AolCUHMhw/aFKOj+P+AJMC2qAgzhEkVUWLfD5oHnN0gFzIr28"; + sout << "EONuinHzkLIW/FcvxqVjNBn9O6svsz1RGQ7dOUStEjwpKDyYWboawNC12lNgh6+JO5uA9jGqmSyF"; + sout << "uBacrq7GoETHISP5/7Mc6PkdsT3IQ7iy+suT5vTybsqlATrvJCR2/6S+M8kZCJdkIN3tVrbKqKHf"; + sout << "O5sHBb6MgPVWotsXQyZ5m5t5x2bquGfonmlEdOcjBpYJSEM9Gd+8zvt36NgjRRr7Zgs8DU691FOE"; + sout << "eoOOYDfc5yzjHClH/b2k5po2saebjyh6Zf2/EC5mQuqkHEeLtRbxSDSa0u5jbh80vEZ//J/3nhsr"; + sout << "lL7Am8nXsj6aqyEFVYnomyDo9Uk+KLWD11PqLqJiCNTnxZbsCwiqZDo6qo4JngJbvdbGZBQMpgzU"; + sout << "kj5cDYxSZLDARNLr+/M49l7esBKTtJUVtgJX2+nejrlgeNAc2BwGQPecI9Zx9v7qHOhxSsCwSrGG"; + sout << "ZrHrTT3Pb08AKSzLQLmuijTWBRw9SmApDxTcUnSCI3RGTmqueOtpq8hGzD07OFhFU62SAdx8cW8w"; + sout << "oB/FAeZGOpIHY3hi47JXjZGTw3oGMtUkbq6RN6wQ5NJzij/HHjLzXyXKpd6CveDuxM24zzyvhEwY"; + sout << "Q3FNWhhgi8RcIFb5VutGDNZ4cd+jqf/veLh9Q1110CuDpR7MUUWm+MlG//rMcScEZDOZORL/qMuw"; + sout << "+aKT4+x2LXKb8G5uWwNovb3eZOJDTnb7Mvp4RfFql4X38BUjUgjZEAO5QXRPv+OfUi3PeeZBDQO9"; + sout << "yjqeCNsecLNlu4cf222h7Et1rdp8zYPwMbdfLrI3pH69VxM5eSJ81OtprdbsO36svGkPOOB8/rNq"; + sout << "Lu/xFCLKvWnHCPoGT5ng1vmcUl6Lkz3r+yD3+Uy0X96++sXHnOVdiWxKyRhew+3/gSQuM6zoPumc"; + sout << "xzAkEjOiKghShOmqPKQzrsEUi+sZ3NnCBwRyUwlaFw7/Fp1pi/N/9bCX30AEfaD7ucsxJGtC78W1"; + sout << "jWxi5K4qPpgaM+qU9VORlnJRFpYFBgu/nqvd8LkFrLkLsFDxkEO7bQDjEKX6xTltM7nbF/KVSasE"; + sout << "OXiMk4x9aOa8buelZrRiANp56cidcnJ8Ayv1GGAaA+hNYC+BNmTXZUal3NyMvxJOzFlKfaSUbgNB"; + sout << "EsYQusXQuVqkD6lW2Odx1Lt3hHzF0MYgMY3cuw5oeMKzarQoA97JDf74Lnp63MeuU+pjBanY5pNF"; + sout << "9D3/l1UIsrFh2ZyC0vS4i+5SGDI0Fza9ZsopX14fbUOG5FFSkrHPL1ArMvIXAG7B3OPSRF9ChX1u"; + sout << "uPCaRK+0JYxLdACdJB4x+Mhnv3fS8w5dDetUO2zy6swM2OonEXqln7lETWsBk5yQpXBo9sR8RamA"; + sout << "nt4HQhwaK0IAX4L6p2LMKJ27YUq2F3Xe2PWf9gMXdrUQPTvpq+MdEnDIJYf7sTkT+lyqLJJZap2y"; + sout << "jmuv5U6uQyaeJ2ayTd2TcvryYtrYg4dUEfTvjdlE1QdJPxUPCzjbkM/y9/SrfPCXsWVlDZDtytup"; + sout << "tSK5GBlCH8JOpiiwgyBR5nAHNZ8SKP/e0K2Ps+mDx6xcSz4WL6loJUN5+lPuMHYxXbT53LPOqwNX"; + sout << "nS+OCDPFclYZO/0TxEerFNvxvjXLU9VxB87zLZRV24VAztABk3d5zrY48iSZ8WDUudPF0mf39geE"; + sout << "4j9zhG2/3+rc1yYFhAQFN0ch0G+Gu/mEEkSkxSZy/PK/v+ITne9JIqao8JfEVgwmQY2VCSyy4Rj/"; + sout << "KHb5SOY6CDIWIRMRj7GsqxI="; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + // ------------------------------------------------------------------------------------ + + // This function returns the contents of the file 'frame_000101.bmp' + static const std::string get_decoded_string_frame_000101() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file '..\..\examples\video_frames\frame_000101.bmp' we want to decode and return. + sout << "Qld+lmlZhXVX5NDRWIFG1T4+OGdJbkmxAXXdHZEtpDf6knVTlRWyAhgv85Tf11KJZbhKv1dcsKtQ"; + sout << "fX3y/9RxoGs2WfSxuQ88PN3jZemvW7PfY1s1vwY1Govp+hD6LGyg6qbJF5iSr2edUHLWGq/00D7J"; + sout << "dhKHu/QsLpEM3w0yjnufc6hORX4GNjJ2a/k5xFnHf29G4ztGVp6OAlsQa4tcLsh3cOK/eyCjevG6"; + sout << "7GaMl/sfkK4TsjNfqdXM8sWxBCw/r6JTLrvlzVxDTmu6Sm2MpQ+IrRXFn4gK5vwPmc/mEp6etj4w"; + sout << "YeFI6w2jGVTFIg3R3ixaPUtvv+gaIOtPsF+4StbI3AlmD1kDBt0t2xEyPOsmjOBuh+oy/si3xWt7"; + sout << "RlCBakbDZnIihbzDMrLwTJvgoiQDgT3U4JRpH+5cL8E/3M5JAOuKWJU9nfdtNVIf1NuHCd1Ajsvt"; + sout << "LjRLI3+rqQErHsKre5sId70rx+sYSqTWaiiGi0Nwg+JdCRuA1OwFmK4rahCAEjFr55py0iZxFuo/"; + sout << "r9DLOVGqKKIWDX2HQQrQ8N1gATa5Wh9ypd1RWKz0s1wghRS8nzIpFqtvXxJxYz39gjai7qUOP25M"; + sout << "zEL4RlnFsqP7GJaT7YnpV4PLkKAzYoq2SMTq4nlK2MGohKoPNj1JCfVcq3syGMJNYtNNOtN1m/E2"; + sout << "3AKUnpPX0CAEgA43Y5yMLR5ca4+KGlvO/qWwex/bcU6uZSc6uEh2pcVaFvlnyUYdZlBg5huymL56"; + sout << "HojkUQWiv48vxAoYtiEc2xsgsD/HeJxF7z+kWJ5nDeyxLhCLpDZfrc0g76TVEOVy68CtdGLuhDe0"; + sout << "dhFhj+tlkmbNse+X7PlhhnI3Oq8kt8tFJly0YH9RXYd2eQDSK3Wz88jGhe9wcBEwcFir1foxdRa/"; + sout << "rj6qz9DeAo6bVEZkIEXPSXL9EmiyTxwqjHgIySK1il+csxW2Gpwb1hwlKQWr2x7bRf826W/DGplc"; + sout << "XNLAy0HLzQh+8E1iJZhpCwYthe751RQLrYaBxlFNZtlHDKIfuGg6ByiFdr8y3T9dx1e4xJHHAEhE"; + sout << "Aem48B79Zp5xD/ZZOgAmJx/hrXJCnYXdH1/5FNhTyySNQKG2fChxHwALxDoOsB5BwFP77RkEuXkO"; + sout << "QWLOWFUtu8XYAvf2re0ZafPDonS/GoPZkMBWo4S6KNoczfDskSXQ4PqKAW2stt1K3BoxO5LXfQYG"; + sout << "qepdQBcZfL1Ir2WmA0L2HC3UmFr4bPwVkzaYCs4uQJqivpw1lD7CcNKi4BfBqWFTs0uWXQrwY1Yv"; + sout << "rmhEavMquYjJRdo2QnufXkSjrW8VpVEthiQK895ABfggCWg4F3yiqBlCR+aGcCetNoJZ9pS3sb3Z"; + sout << "32tmtN4rrQ4jMb02QNnXPQ+vAavQaWok/dELaGYCG7PefjjcFqX48WNVmHQx+wLW4m0xo0mqXSxz"; + sout << "SQakpCGETfxuzyIlmYVRh41QcDdd/NH/CIwIMFVyga44Ewpcuew5mbz7bXP3k+FlxzXJ4e0ABVn9"; + sout << "hO2sOG1V8teDymKMNPQOBuec5AenB+kchCtCDL8ZtglpBBqmijUCegyxiVpjnCzdGQvIq1XdfVmg"; + sout << "e/hWiXT1Vi0+1wtdL3CayA1h+zvUiIJspWD5QOHj8O6buKKWky+V2UZDPK24AuIl125ajH481E6H"; + sout << "sjS0/FqKjHWfw/SayHcCdbqIt6/8AouwlkzGOK4C14HxAg1jmvGnRn+46SCXNjlHBDH8/mJ0Z1uw"; + sout << "C2Fv6G6t/IX1hcjJ8NBQOY/3gXAtDRmkNUc3KImzaKNDUkS9N8U5QLFVqrBQ9Dw2kjJfc28QJcXy"; + sout << "OOItYP1RQmG942DKVwW+QQV/fvK8BFpRP5yuNgmzjGZI8gP5cqIOkW+D4eBFBb8PYsa15fOF9AWB"; + sout << "RXDztTbO3uXWcTpa5w/dSbawwBqfeOY5wh3lnYd16Q+Yorn9D0O9VGFWr2JqKwl/NcRtH5FFsNC4"; + sout << "LpqcX12SoO3gsql2isJzJlyTmeSDB5C7o0mBE23CsBx9uqDwBJn+YAucIFxB6I4sxgF6EOilkA0B"; + sout << "5CpLkspGnHs66pvrLeByoDMCpkt56bXRlcvFhjRo5nljY+GjBh6Ux3ZJnMdUG4Qded7oVaY65JVX"; + sout << "MpoHUUG/VSR8FWuHUP+SXh71QiEjsrAAVQ+Fm+fAH3xQ9Vjiv1ffRvxkb72G9O/LjF6NWWgowFcj"; + sout << "QBhuD/mTI2q+6spFLWlZk9PmCWTBkQGsWouyZMuVQqr56sP0UEKANEGwpq+h3pq9t4+eN8WcGhyd"; + sout << "kKgOgUvt8aY7Ib1HdckQt9cbtOrYyjxrQ7o7CXR8eg1SLPDJoWc05YevckkECeF4nQcZvoRxOlEV"; + sout << "rQwpnngQ9oNoREpNnGlZGtBr97bvWTw3UWwiEkqDJWTTbgoDSzjEroYV3RjVyVKu7BZRTvWDFz+/"; + sout << "xohrwbZ/NFA8wGZpgm0qtJRMCyDUIAJuQSEjxa925hWfdkVJxhExxXgWp7NQzhawF7/16Fa1tyMK"; + sout << "62nnOYWV/OWcNOrRCt96RlQooSf0DbUqtLt0f6AzkjKobn3qpkaMxg4Fwzru+BaLmId2Cnsqdlvu"; + sout << "jCjsgpY9cfv7NdcnPvhLSPyykMPmBRe2XZycAzmQdpY04DpVI/A7gpkVyH3IvYP0ZC4SXNwt3Iht"; + sout << "CnmDh+VF9+iVjlaG/JJ2cSXG1mDSfjPgQqEO2DtotkjN3cHmqqSkLqbmA2PMnbzCn5FJbdhES/C7"; + sout << "uMRFLhsEyH7BF/UGHCyIkzbGQhF1q3PE1PP7A6UOtWr9tmKvctlDnsiEpGvxBN/8wfDERqImT7Ex"; + sout << "f/0fF7UbHD1+lJc9VzXKeV2ZsFhFRp7r89yVcejSwiAcgHenFuvMEtZ3lgRqy7mFxwN7KNAh3O6Y"; + sout << "RqGFSz9w0gb2gmX/1f7kVvoKTgibMHYXyHUSvusrjX9betzQ4NysjILBta51/ZbeBIZPF8V7mhNE"; + sout << "5c6iBo4j1i+jEzBGCbH4jGwNTRoD2S+uWFrDS2x3hYJL6uBoEWHN0oG5f7H3I82zkju7jkuUCM7V"; + sout << "bd02Suu/hjwDlr4YaQUTIo2aThnXT7ZgCaNhkWODTVU3BqX6M8WP3zFZrMWGtbuXrE6WdweK5ENk"; + sout << "+aV+0CDcPHertb4LmJCOY3yXtPe15OsDHwedxxBJCt4y1UdamFtFo57XjHeAr5c17B/MWQ82/ZlW"; + sout << "1Rwrm+gNysxVGgfmzvmT0hpIiE//g/HqI737ovjOo6yRfkjil8AqIlOT8Gw8mBbnME2OyjX/TPOw"; + sout << "wX4F52tAhs3oi9HmaKHW//GlC7ILPqFN1Yg1azmvTgqVeYq3l76JbgZwNmPUEM5z/EfOcWu1GKT1"; + sout << "Pc0Hz6WcPYYyANIlmv32CkBiZQwMi5AIhrUJvpGaD16hPW8f5lXO+oG3FSJlhZhiIQVKFX/8UqZk"; + sout << "AzFhmyPjUHNHOsJo+1c+5FTnPsjHDzYwW5fUQRSz3L3/rv6AY7g0BXKLmxCJYzIwxyQ5i7z810d0"; + sout << "jTPy1DANcssoEqdA0P2u6udO3d0rtL1AY7dKoH5QYyKrOjx+HkLZcfxX2lpjp+kbq4uy89JYfo25"; + sout << "M4tX0B45FKrUDDtIPKR9GRfb6VgsH2bi7aediwVXCoaTxPNgUyKaFWSf+y9gPBiD060TccAZfCqQ"; + sout << "HqI2fQUwEfsiv1XeTLGH8mNIHETMKP78LdWLhATM3ejl8aGcNxsRAbMXm9JgJmuHxaNViUdWmfPU"; + sout << "hJMksO1JBS66hcn0jBJhZEYF8qUwG855G65k8Vvo8tihDIbJhxGMDByP1beBLs6uoohknWTtDU0K"; + sout << "Q70nziEj+VDbaexOU+dF/b/ywH7Q+KUAFvHH/xObyc0/maoRS/e0pfyhmB4OeGsbWAs9aGQNIaAn"; + sout << "Toh1zF+xQBFMKGurc3P4aPvh3Vo/cw2hQTMYJpUJQvhYx8XAnA+hRtnXZtHD2Jas88G0JXCCX5EG"; + sout << "9te/j6MEP9Xwb/TjVLDGOc4IeMYqVftVFcmL9hrX6nk3TYDyUkvLjgn8vXEmJ7qB8CthBF7UUyjU"; + sout << "mcz3agjf1/6xrIotky+zbEAy45LvmZkcTfVVJt5nVyir5hvcGyXWUOpqp/9NQr0ClEVhzfD2d1cd"; + sout << "xSSLaeTKsr0jTWFV9GDeIhprg4lJhxL/HbxFu3iOocrysDTTxsbLSTjO33ndrGiz08uW96K1QY2F"; + sout << "lqmK8FOhR8C1eaUdUFfIa/cINoThxQRkoYq3tUr313+VzGTRO0I2SSJMIHgFBslLDHwX5AbkpJPs"; + sout << "7jCQESWTp+LDoN6g2X+RmiLlPQiU7iJXxF7Y2SU09kmHOg7HFPKWZQ8bI+m6CG7kvBF9lWNuzfr0"; + sout << "lWz08lEMExIEAAngM7G0LqqOuLJO+Dpk+lLWjO5OICoxQ/M7akvSJHtc2mVOhXNskHOYG7ZtEEEd"; + sout << "ceDN4bzzJ6qEuYMljZcKkh3NduZbETGnW7D7Ec/UqU6WNX01iMGt+4lCWtu8NHySWGqIXcX+8ITJ"; + sout << "euDFM4B86RSKk38VpVUXXLWd9GPUYa795WcoVxHlFRPULnHJK0G2/AUuKe/K3CxHNqnxsk1aGdxS"; + sout << "NGVyQfhXJBWBA22f0Eclu/Z/UdzpKZjCZFFZWt/4ppIWGyMvxheOSqjA55keu3QoFj+xhE0TUNlv"; + sout << "cYMT/b9kHJnHQ+j9X7ReicHMNdtWYmzybolMFV5P8fBwlReR6BS7nAk7hu37K4fpKoKujfMUfE76"; + sout << "PYTfRoTxKKY0atJg87ZhID745jym89xIqFqex1vqb3Ysw0+dZDWHbwiidUZjHza33U07hl66qo7M"; + sout << "iHWjOPzrMEhad9fsvLGQkA3WNjj4fjx5TLN0mfieIPgAjrQvZhCO9DspEWW+jwRP63BYUU3rB2sV"; + sout << "oEJfsyRSoam9ujGE2LF6ePYZhOOs41OHGsbUweduJ47XGdt8Z8+wxnZ0ykwvxc4eVsNbNERVJ0pz"; + sout << "5hULsfB5BCs0jc2cz6N+1MLS397qKGNwnim8OBuU90wU0vMy+QWF78OnVb9jlg3asOK5riTvCWMb"; + sout << "qgj1qPo0GfWRojb6L2JaKiQlU7r5LiJSmmCQLdy7qnbtc2ul9kKn7iQZqxZVlHJklsrsOCdA3/Lb"; + sout << "kzpWUbCp21oyQKgp8PRkMv4NhWt8JkushMklTjuvDhvqVTqrzFQE/UJoxWEBEjeZWRgTG9gYISJz"; + sout << "9mnT+89BXF+oSahu0YO8TKgNMqIDQ0d5GVrc/lakm/jQpZ3nl2tVbN5vLMtCuwbkeYRiFuaQKt8n"; + sout << "+0HpFoECCJZdk6vyVFpIHNPSP7bfnsrZJHRiprvhz41LowOajR20mGW/+mlqo5K5eVxrW4I2Uumc"; + sout << "Xxp+DK5yhhIdYxY1SkRYi4CynSNUPLqoD3RaTFKfo+aGwd+N1abtMJpWmE/Xp5k3NNHXWi95ltoF"; + sout << "tU4BxGQWWDjG3UB9t6eUDoV/WXwCy5pxs4rbXLb2O9CM2HaBC+lDaW9RjxpOjI1jvAXmcQrs/MeT"; + sout << "ex8n5RxFcyRzjbaWhd/V4vZ5+qY6eLoT/cUpYVOp0w3lQEBaGz0W8bcetjroDYGSV1U+3nBgcKPb"; + sout << "rW1wp9l+x/ihlfM4yHdApD1WUTYfdIgnO6YAw3tlm0UxARb3WxyVjIQ7JMr1Xgle4UbFSWrjI8pD"; + sout << "U3arZPJzf7oN1eFDdS/V4fZ7I/B+j580QOmGI5LGvSBVR3Ic2iVSbZuk2mjWgK/YW4vg3kp7LjZe"; + sout << "yIxulHxWGysi//Oe7bqcJsM8Jndw3sId9UYV9r9GaBGBf+E2YWibCRJCiDjzO4PizCVOnTGJlOMI"; + sout << "Y7BvFi0d0eLHGAfIm+3d1bUqD1XcPblx0IYVa5AHY90JfMS7gBJHmJiY9Gv9+WtQmLZmNkU6Tcy6"; + sout << "yNyIMaIfI3UO03ApmhM9babcNpX8xhxYotZXzuxf19nnSdyOebzl40qU4BPzwwQmQMISK35A80xS"; + sout << "dvnVDMbcRQIC4fjwo/OKNWtk2oLY9ERJ6ixMMxG532tF12jN4ZbWn1X1HGFNXvSU2bFhuOLMSgDC"; + sout << "1poVi4Cjjgd4n49BJmlkeBoL88z83sPikchLnW1kxHubG/0VBwyHYBmocPrK46PL1rx1+dSwc+Da"; + sout << "3E+edMWBVe/ca2Z2lBDFIvo29fjwwumf5Ljg/M2YybU9jZC4vLR9VcdPJ9J/gi53iNOktmDcBcp8"; + sout << "ThCOSXuaDyKBtQ6nvd3AUGdNdLADPuNXTRNybM+Lqc4k7cApgM6DZV3Elpz38473d3HDAW4pCOD3"; + sout << "K/y2pYyog+/27OSOW9SGoVkOuqOBKwqrFAuD1jIbI/yDq/LYajcJIPqhkUv3srTxHkPiOBbFy4pI"; + sout << "6NAftTHf4GGy8VTGIzeD+z7L1qJToCogS+FoBG8ixRGUYKvUYEQyc1EhXdoPOYY6pZiMsA4xSGyN"; + sout << "aU9RrfweZH3ld9QU9Y2kjSqVOhQavzIPtD2wQBIDWxk3/bmH1m76qrcEfY9WKCb7Sl//1oILVf/N"; + sout << "b0/TZEqVSAFOMoTzTXO1ClXymBTA9b1bJKlQL/8DRayUN0NUMllrwHT1PGOmpoJ+AyhUOjEmWREv"; + sout << "mKkaxiRkpyLdgKyphJhtYGid41FAAKacNN4CMl9W+fZnydgR0SDvYvpOwveSXr66xfDZlQti8ZR5"; + sout << "i8Il6sq3+2ybJj/oakohjWPxMAA2rEvsOYkuUbviTAYQOM3jMXAVPIkgAYgwwXhvTpHjeRurdxho"; + sout << "5dz5zGTwSRsxCUn6QSTtt4DhroBs4xBCmYa8BFsUAzG9nG+SP+ejIgc6HsrLCxZ9Orlgn5nkX26K"; + sout << "bI2dWHqUL8DIgM9OjrXeWath283tUxYQlnQ8XJ2/IsmRkbq7eNJfBATvOmvk6CTJJdGrK3iX/xaI"; + sout << "WTJ8J4qY19ehkQoD8fLA0W9FitmfQyLGzxFSI4b3YOcl5KQUvSDjPpG053Ok4LhcQCILRhoHk5jl"; + sout << "usMw7nWyw97HV8XHrQkjsGyvkT0G0LyeNV5iYo6o8yFtF/s44/mfI3sh+b5AzhdYOEyRcnPSlr7y"; + sout << "axPlZnAfZRfqccQnEricrTtaPvYaJIh/dbtnlM5crcub0tJ5WNWAlLdS6IShjiIivEzqakfqg1GD"; + sout << "w4TzFl4VeqaTR7JWeGT+yyRwTfqkIEFWp4/7DyRSXF1hbiUifVaBwy4mpar2tHuSL/eknEI2Hd9A"; + sout << "oIFg0F1y0Jg4uAewz1MvmFsfoHpVAT8ejEF5OdFjkeUL8fjSBlX2SwMNaV07UlWkDDGckuHch6r0"; + sout << "FIDtx71pldsL2ALaQgKk+85QC8YqJdzgtfK/lRT0MDYOh4NlCZoFqQAImSYqasRtsyeNYWVi7Hh5"; + sout << "HUcaH4dpUrWFu8HFXN20rzWwkplYAGTeG1/S6T/sBXzAcv78VXwd8ec3+Soa/ZZNXY+yWuHziyrP"; + sout << "K7FFd/fcYHML+2bB+E2Trm4MVpzv7n0a+Kuh9sy0SVe0IuVJMo68R2Ftl7tGgF+dTKVKmEHLHqcR"; + sout << "fSjxUAto2wxtM3Lormd/1Yz3loH8CL83JrK1f699TQTZee6voQPJvlvKJlctr3NeWB0uwk3TIACf"; + sout << "198pytEpHGhlLByCzpwV68aPjs9EV5yekrCJzEp0arZSIzfI/cgwYAZ35Ukff/bp6jxg6AB6yUL8"; + sout << "Muq0rqGDqwpfTfTDaeBJAhwevEMyapDwrbzwHhnJLi7ZT5l67Q6Xjo/A8/U60t9m2Sdm5ULwzBdo"; + sout << "KdSBd9S1UUueBxrx109Yh1gjvbk6k6d8L/Tuqjue77ZCUV5mOJ/rmuTi6OEDIKdzZuBgooisaIg+"; + sout << "CargeAu0U/JS54e+b+Emr/56cogvgeJo2QeIKT5yhqHxtoEvHjhA+7q8MGuosAeQtp/9yYt5kPiM"; + sout << "j0d5GjKTviHQqgISzuskcAWAhjIfXSyVrFhiL2tZ9hqS0u3juXdBoUy3nXcx9WF1tLkyxONcILs9"; + sout << "+dOGRg+VpsTcjRqQyJoKm6OzVoN5J8iWKFkdZGLm0IM+p7F+jRasFIgJd4iPjHMsxYBlFX/aNcxT"; + sout << "dt1W97L7uw42LMYVG+w3wnHdkO/ddrsp8DemUQzsT7yGDOhg5LMBqK+Lh1tcizQrb73NDpqjGmeh"; + sout << "wmSb9GCzxDxw9uxfaZBBbO7M4KXJiANLFF86djhSpYgmApyhO1QcTSl2UR4XptqikYbPGFFQzKIC"; + sout << "qjvmsXWnECWWSZK6pnNnHWZOMto8VUmPx64HIPe1hW7KeRM4ra7J5Imw6F5APG4+Jg8IUf/sH86g"; + sout << "v3F35nnPnYeVvkQuYP3iLNapQqaKR+pQizKxXj/wgPagwSRrzlSDbKM7KqwbvdBGVMGfeW4PTpuP"; + sout << "FxtmMKzWMWAGUtPSBHktXyt2yn/Gp7HVfirM1FYaIOsWvNPtBfsQ7HGA3dvKbP9f6ZvRSonW/+kU"; + sout << "S+jp/UP+sIfihswVxrP8TO0YqQCwz6+X3nCT1pW7sAJKggb+DZyEisyl0jwK8fVttFraIdhqO2Ko"; + sout << "o2Mfc8+V5jVTIC/VPtuiK0Wjv0+Eov2U1UgvFM0jsGIPfVZV5UDra0rWjp/vnzo5C9MBy/MxfCyZ"; + sout << "5tC8AnEXpI6V5JSHSb8xRAHWZg4HUMmoZG9u4mX+fW9Lc1OET8xVMzfN68kndi25/bnfx6SpR6f1"; + sout << "30Nhy8rd93qy7DpYUHhLLyhBgujuJiT4scogc2iDNwP/Fsam9RR/EeVjKv8gwDHrDSbbUbRbzfUD"; + sout << "fM7oph68ce4/sQ0rCh1WbtPrSQWbNU225oOYZr1lgrfxlYnI6ROd1M8nWS623ofggs4Wfh5CMYXJ"; + sout << "OYe0XlSNliVnU1CX4MXuX8dIXaEmU2HUjfKWif6YAQY9eSLikPWVgEYvpthn6dur/TP1jC+W3a/Q"; + sout << "vigZwfynqIl/NC8Pe7tpm2qe/K6TcOgru5ojdaPeyLUsVx/wqvvVlT8ZOgVV2vKOXAUQ8bLXAH6N"; + sout << "kccB9Aw6NLtXi63VjdwUrgfBwT3orFckTuYA4u60vxs6e2ocbkJa8YsPyN71Q13O7t5q7aYP3O4H"; + sout << "JbDKncX6fZddc26LFxgRQem44nxcab0yoiP6H3cp8mHB86+vkBuSQHWGSWN5uNK03BP0rlUntMT8"; + sout << "zis2iDvOY2gDKYHPj57HszSQ8/SZx7MFKiFRTIKMK5P39lNz0ylKhskAfZ4KDzLp6xzwjwxDsy70"; + sout << "LQatqB5ojtaMcglatNYvdU48iHz7T/KPIU7fs/vySYdx4EHZ+Oei/dKFvpdK+y9lFyRJUoBXpb/L"; + sout << "LCs202cmGv01lonaN1Qe0QEL8OTPqvPwGb/rqsf6CNobSZd7mmiMwlQ4dARKyk3PWgYvalT7nm7w"; + sout << "aZvPlDUnws05FXtATiepdeNlffsCK+G/TEzH4vxOzbFsNwhwqL2tVf0t+1UJwD+NQaF4p0mSoA8B"; + sout << "TZCslymygnhgF9Tu0jHPigJ99Nj36RhdYkndHiry2OlwpTACuBv2CFv9F8SDhfwuRw4sr3xQpomb"; + sout << "vpTNyGrXXmSVrsTnGP506nCBg4IQOi5zXA6MgVNXiy62mh6u0lceyZXfbvCKkGJi6wUj05DdrCvY"; + sout << "3TrTH7DxcwTZ01O1e2A+6A4v7u5/GyN2vEQb/p+lSZ9XAWYLywIId2pFMJ6nDuD4HKuoVljrA4+A"; + sout << "wrwACO6QAn0H15I4nIBRzxgEhL+/U9Cmhmudtl5iKeZCifsTrIda80CJ9rTl/w6knmll4/9rLceK"; + sout << "QmMKBXjd4sdpcDh8YNAUa4lWNsgL7ElkbTzmZlPe9ScYHAfcJGEtwrOuYDp5LfJOh76nM/LbGAkF"; + sout << "qR596/Sqf23C+pfe4JH2lerREvbvkx+N9b6lRHEEU6tN5qx4AVPUlVDi2/1qo4wcIFJEKYS4bfKc"; + sout << "jZdriM9msBQPzO1GDoVTfUxcY7vVtkVSjVONgpjlM4Bsr8YdqtRH7xAEuMFG5y97X5PK6KkIAqAc"; + sout << "fXc8BSroGGBcXmIx5dAh/U6oQAu8gOe8mSvy84/dz2sFbioZSlVWR8asbFLDjrfHHpo4d4EX3t1D"; + sout << "XsGHOG0QPuj5/lH2QZerKkfyMZiD7rSJTnpf8oUyY+gw/e0aZMHk510G4ybRKgY+9tCGnu/A3cJc"; + sout << "5wi4HcD1PUKmZwKZI2PIfoIv/VUYQSZt7U26rf0wUZEMcuUldIxiU583HZ58fUGjmNhi5uWdDeIF"; + sout << "pgIvY8tvmwUNT3W7H3iw+idsfn1w1CjWmWiCRo+yjAB+TiOxHD4JCCNAcM5PCWDa/NXLBA2tO1h9"; + sout << "Tk8h96lBlAtu4IQVJaB6+o7iPk2GNq6FAiFQwYC2F/Oh5A3wKWrd2kqQ8JeZtNdk4L2+eHg3aDAU"; + sout << "4aXaYzfxDSOYHVHLgDj4VY6ulSZx9DBst5UPsHbaLaANf98uIZuji2hXOcg/FEnchb7A5bb+UIyV"; + sout << "jJH9d2D4R80e5GwPF/vVZK2GQr7h/pEgdGTh6hr8piMsR5fzJpEI9dS2jTJnet4FtnhDo90fImU7"; + sout << "O07J4AT+uO6k0B+lKuc5QlmaXg5Hm3hCqq5CggvnHNKWHV4f1ux0GNn19RiXwX18EdkxSEDIX1W9"; + sout << "CmdCn7gQtqznt3pNq3apLeITu3phYapWi3FlHhxoxOE8gKjF204FYOBeKztMqCqSlYa3ALQiac1W"; + sout << "bk5HyFaFvWEG5Rb2gs7bHSgHuDFlJwErbiDfHM8415fDkbWAOMJFTDh3YY/tuNz4vU4bNC7TcJHn"; + sout << "o5gak21470Iy5oZ9WHOgQm07tsFNxyJg/98rl1fQ0lcjNF+2nrhj8O3bXRpgt0eh9AQeu6QVTC9A"; + sout << "d7DzZeQ7N3HCWJl+VbwFCW8Jyo5umAkO5qFn+c15SO7QZ0KEjEORTiUnBDXzAtQjDHURr4yEjtiS"; + sout << "LkLTQsXwiwCmBOnoTdEPegCHIEeVSIye1t0x+dR2DeqK3WA/hB8mFawdIlhtmdEOZXov1mutXVDB"; + sout << "LBN80uBR/EypdQyctL5IR7IZIS8p1vG9kY2yceXlszfRZ/t1H3jx3sRSJEdVoMOP6Es4IpKj5uw4"; + sout << "agZ/POEGbn9qKdL1N2MM0R3y6D6YjIXhQF6QO3XsYlzU/OeOWxoemfC3/qn4yxwbaYyEWjICn/9d"; + sout << "e+z3q8F3Kh2JC5VsuIVaKZ7MqXxvtn9SBuUQ2vsx/VqH/LyRi9glQZom7H7XYefzQ6neF1DobC/p"; + sout << "HsFSf3Q8QV215kXq7lm0FUxQHcA2p4F037be3CN4V8aa4NqL/ChPqdR/EMtB4C72nOJClAGY0BsL"; + sout << "4ApuEZj/SfshLKcnMdTcJTsqBDrm2ykdu8JEg5PQT3vpG72eUbAm3Mbm3CYImISFQiZ07XYDnN6J"; + sout << "DUOjaiwfjJ4ydnkaAQbDnChIS6JXZV1LEMS19DCB0TLoRl6jnX3mNu4S+fugfFHqXYxgKIAen83P"; + sout << "YdaCotX2Q4qj7YjNhknIksW1OPA+8VxgBhrO/xBocrJY+uEi4+iVFJE/8Ge8sXmYr4AcY8q6/CuF"; + sout << "jYWYkmj8UZp69kE9QGJdiwmHBFOLIgAdnKMXuRS2NFZgFcoGMWMxeasH12WFGRmtuoS+eBi7RzUn"; + sout << "CC7RpHf++XNJClYpMSPnEYafiDQOcMmjHcaBYkEBB6kRV6dPtwrb0ZQPGVFYnhhcOtcM8tkLD0k6"; + sout << "Ek1+UYaG75go18OFlshMOzU3Rq7SuuNwtSNQ6drsWWcGYEOWtQG1b6DniQ4+e9hkP1sHI9GE3jMV"; + sout << "tbuSi8Q7pRY+vzyHyNIOH5FmnWCFPYLMkHt+6aenSR4bh9b8Bx7khop+XymNuNOZg/hu863yOx7P"; + sout << "D/OCXH8BflcACMxVMFKcPMxYZHrxC4cINAO2wY7ooge06LSlM/3gNNwfizSn5miDw+3d4UsnnD6G"; + sout << "+oMA9AJTuPcy0tMqBsGqmBZMAIO3yHWroF9G0dS2TVJqMn8Yt5shS7TKoVn2ognWInejfbOSP8nl"; + sout << "wliaBSd4XCMRoacNMgoPilfwQsqwntNn1jsNRVG3YwQwJnSWXo0YUvbCeijqzP0pX7IezmCY8qgI"; + sout << "VG7spKv3J900W4CCvbAlNsfHnNAnjU/5MWjs7I2j+WhuFv+b++3vDsedHI5nUWjTUYYy7N8O48LX"; + sout << "XH042vXLucXjT4inm+IWDpj+br+GmppCY/bZzqO/IwzMt8pWiReyU9NNtawN8ag6FYA9fxrqGwCf"; + sout << "j98DdzKzgBgo28R8Y2Al0oC52pFovY0Ym/ormPTSwtehSaq3lTFgbCKBhYR9QPfx3vMbYZupWWCD"; + sout << "FjnZYkKtJltasN+SKs7Wp/cCU+U+6zPAqHmv+ZSWiNhC4lPgYDGSdyhPhp/sHz9rW582bdo0iUT0"; + sout << "8QRPz1UGheorTDK0K+V0sR6ltaIS4/5P7/QzmVvGnnilWTFmqdBViBZRFv+3wp5xMqMjtZthbZtI"; + sout << "3YmEvxUwnPvJCu3+veeErSbvB9AKDj0PJGmqLx8Cw+SbQgN874XSSq+w4WVXCT+GOUrspzZR1LNi"; + sout << "59SSCEs6AKkKmXmjk7sdl6FocyINLIGBUSPc+nwCgESD4xSaR1cx2SjSf0CynPF/YprCxvqRvQEZ"; + sout << "dcKWvwgeU6oJP/0tz55yIZjiIi47ucY4ySXSel5iPoofoh4Gg21tguOQovIAbzqYrkN44bmBy+2S"; + sout << "Bpxa7JRdQr+QXwMMEbsnjrXbdHL7J5fZa/80FkwOHY8Fvz1wCRoZ5y0EyAZCq1s5H4Vd4yHMaZmW"; + sout << "ZNJzXvADrDSO7NxyTDGo4pVX4Dxh/v854SiHkO12eDvoizsraCz2ENtE6cbpHIdRq0Tt/K1HEH1E"; + sout << "V9l7eNcD6XJrvSe+6yBEJKD7srWbyC3MPSzfl7cFEZPbVzfkjSP8H7AnZl9mFnEv8AbuLvHJTVzq"; + sout << "LdwmDV35e4jFpDbf61+kzeO0F9ABrphKOY2YP01f0sKoIicMWOcXyBVIMCtB/9Bav/Pn6bH0pyn9"; + sout << "fyVa22J4yLzvKdYKluBaKUYVVUS5GnvTT1Qxdtlv0AWceJApOg4sHx9zZmoUMDlQQqH54S9sikKk"; + sout << "NnW4TD+CbdPYFH7kgXAkoFveAOIKcZZhTUD3FLbarLFqwcU37NSuqM841MlPW4Koa5Jrc5ySoL2U"; + sout << "FU/Xu07j/PvrpxdVHbok0aRtPgNm44+vPDAWmuG8zuxqluSW2xxzPi/YVo7nsWPvJoDKjZA1aRMr"; + sout << "LehSLFoQa66fCq39/XXbQp83pL9b3tFGRndFnP8FxvMcCaYkG/SDevUg+cG1ysNbBSBJ1plCJe4S"; + sout << "loGs/jQUQCNIX2/mR9L9PHM7FS+KfQn0tHgjipj5DRGiFInscy8cXn/44uYSHLvZUTvdcg3+I9hi"; + sout << "qdV0luQoUCZm9ooT/SBeEo4VzB6+L+fsuOlYCQloHNtmJ3KrHtRtbk3caJcrbTcoQETMbSGN5awK"; + sout << "XueIC5Ar7zL30I4/DTjrhho+/6Kf4nRhplqPpt+0lyLOSk9KjeNxWktMwnnplhggnxLqb9laXNAM"; + sout << "juhmMFkTHN21kH6F3TDmBE6rl8iMHHn1+kSKhhYjDeb+n35qu3kYYEg651ow0NVdq+aPC4YYCvt8"; + sout << "44xqYd6LxPeXF66jNdNl+oOBdds5MLOyW8uIFdQbmh514Q41qZ3TEMieqT08Rp7qW4/mqiPjRzp8"; + sout << "DZgOaUXJnQnUKkVbqqaiBv+ot+ArYe5JB+e7nvzgyf2zxCZLk9/Szpqn+JbGGEHFixV99XjJQRTU"; + sout << "q5DcOkzNjd2/Da4VEXCfjQPWgYlYu1u4hSgasO4GVpGFnyuj76gnzPGbkSVq5GMJ4KzYtl7/sta6"; + sout << "sijpcScgVlvvL6ff7fhEVnwQfa4mCZn9/umsDB5ZbG+VJufprMVSb5CMYcMyOU2KYDyvWHE6jSP3"; + sout << "Za1W9Fjo1Fkzx99+pV5Hex/GeiEOxyYCPjNEvCNpn9xJWjU8fx6/6BbN2mGET9hO18lol0OFeBio"; + sout << "88h5F6msca4plFbxu7eLv1Mx4kyQNmvupkPaufis95NgiPhNwOUSPsGffK8WGIaJJx7+g/SkgN1h"; + sout << "RqaZGgxZdrnY9nOOB4TsqNd1dG4p8ybarjysGHg0JQEFRcxj22DlC2PoAMEEILAjYb3cIX0xxy18"; + sout << "D6TPca+XaHi6LoYixk1zz+S8FdUGweRHS43IaWzBa9KDJ5vEIxAAXpyJ5Uv5PRcBhAfdjqjeT8DN"; + sout << "bVN0R0D5QQA="; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin, sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin, sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + // ------------------------------------------------------------------------------------ + + // This function returns the contents of the file 'frame_000102.bmp' + static const std::string get_decoded_string_frame_000102() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file '..\..\examples\video_frames\frame_000102.bmp' we want to decode and return. + sout << "Qld+lmlZhXVX5NDRWIFG1T4+OGdJbkmxAXXdHZEtpDf6knVTlRWyAhgv85Tf11KJZbhKv1dcsKtQ"; + sout << "fX3y/9RxoGwtXxwFXCxbBrLRaklamboDoK5N06vzzd+CS5m2gEc6/c0nITY0hfeRURvJI6cvb8OM"; + sout << "W3UbsFYtU1ntq1umGkQDDu2ukl7sV9LfHj5eh+n3zjy7VIX/RAghwQA44rcpnh2d9icwDh3G2dCo"; + sout << "KxHo3W4qqxVLis7IIGg0dzwguoB8ZGmlvOg7hdMz20beAel/LlU8TMghQk2Aj+4NTKQRVwcfXfdd"; + sout << "AyOhcT9wYaLabnDZgL+7yq8Ji3VHvlGqgVD14WJj87s83dgMGikK03QgTzhyY4dwADT1aQ+xqeec"; + sout << "K5d37u4xE+btNpcqFhHx1YCp1r6+O9kjWRE08rfb1qWZB4qyCwTkxc1C5buXaeNRrw6pHHqSLWd0"; + sout << "9IFAbB0ukrEWmdHr8lVKAfBzL+fTFrAO8gjgT/vf15asPfqSIl+zJyHK+YJ1p0F/UwIzPkKl5nx2"; + sout << "O3u06yBtIHKdflKGjRSXomDYGjrjeoLNjqF//7zBBXyNXXv9TsWwekzsX7A+qRdm/j/C0lCHU+sL"; + sout << "g6wWh7L0kL4ARqClFQfanO+87/EPhnu9LuovQzTY8PhLB7a0R2fhAncpTEPCHBm/0bTEL5F7V+W3"; + sout << "8yzUfcgi1NOCAUBN1tVwWBDsVv/NUIl1x+FTKd1BFAsN7a0LLZO3IGnCikPyU/Tn+TCslyvvnU7C"; + sout << "SxgTwVZIeDC4sq0l7kk3Jc3UWrDI77EQHjxNFKG+lnJxMLXrNMy/lJ621WGz8uVPvpkcTajH/YdN"; + sout << "eaRRgbdqnkqmoSyX7cJeaZG4NE0n+isv3j7WU3cM6N8WKhYCH0ReG7XjbdBgrPHaUdFD1dV20eC1"; + sout << "2PSv5D2jEJkDUVNdMZSmwOtF/UDg4WRMY80YLGSKhxqLqoszhKWXltFdm1M2+WPtSFOdr4i3wZRs"; + sout << "W1At1+bmMTFCM18KnBg0mQv6zyjKRjzBPHLiQZQVBmv3al6bQnyLsbpNDM79oVd4cdEv+l6Cze9Q"; + sout << "gPwotlVgN8hi9gcF3lH+KTQ3Nmd7WK3+125inZjFasNLUVJZNLshz0DVnKWeaO4RGLx/Nb2ceMwG"; + sout << "YknvrhsARhzLZe8ZFcdvQs8JnlXwuNpN9OKtBNQMpQHf6ytCWCV/ZCsrFCT+GnKMQxCvR4/LNIFU"; + sout << "vWKvSIe1iIDoOhUCgJknwUH1M1Ug/VLDpOxT4ylapmV+YoKSYnJANOQ3+3FhqHRc2KG8vGMTNzab"; + sout << "3rs5TEQVOdP+fEtU8oynfcb/0WK1BAKgmDRhKcz8XG28EX15aWyr6uJkAxqYE8UXHOr+kn27TuWA"; + sout << "49KqCFNsNcjhvnjWQEBvwTwGRKofE4ASdnQi/fF6Ozk321H0GJYpXETW69XJiNzT0f6lDUYWR+Bd"; + sout << "HC1sW256R9dn/MBAcSH2FRUf0Mv7z0yhkb4aUwHNvGhHPX5g0Bhy6x1TCOqHtjubkHkxcq6YpwW0"; + sout << "f+9DXG2M/sK+IPqqgy4tY/fcJeYxq3Vd4+voDPYfVxcWpmy/4X5kzf+fFwtXEu2MBdYKQEKviV0Y"; + sout << "WixgGguhWig6d2vAbUKgAlJd9yixtBum62nEK1vMZmwzfokP80fimNdw7Wony6EA+v1oB+zHQojG"; + sout << "1fNaBTacRa+NsW58U848DFrsJ94S/9ss/RUx8GuaWwKP6WImvkhyHP+YdbtkykVPVzm3s798VWAX"; + sout << "nvjz9R2NlupUdBVgJg03BAH4DoJdYWOxnLtUChjW++StA7OGgWoay7DDs/6AJYycFh+TFW8D/HhA"; + sout << "sEQCUAYDnpze0aSX6CBZ7vHMr4z9Gr6V1zE4dbJ27Yv7ZFBetDLTN3MN1vk8SYDFeFHtzyMBfqor"; + sout << "hHXu4rKtGW2j084a3BQvlDVZ5bQX8mh6QBENJ+C3kBUhE4QthRIUwGsuarKCnOUAAL2Zt64CgGKi"; + sout << "gRF4wZ3V3Tt/FZ725pUJe2Td65CddbBd2mRqEzdN+qqn24au9ucznuEr7E2j0T3ey1RWEcU7Xsx9"; + sout << "dEX6LoyFqowIEvyGOwdTlycOq4v53Z2cJTaitjGSvtaL1h1ASDW11jvsX0OHueMJJyLFwKOjj5JL"; + sout << "JrtfkGvWRrycw3UdT6g2n8IPKk8j96hds+UlDvgwIzmlromWvPfwixA9mt8MRqJlBIo94lKjOvuw"; + sout << "sS5FUcr31Dqmaz7uQcrwmxh35jAvW3EbMObKNdbaojaZl3u2Szt4LA0yFs7UJpkkJgEO9E7lGte3"; + sout << "WAUC09lpDNq1CLvEfOEsWQqzXayT9AU86CCnFgi6iErKoA3uNToSaqQVOxOi8RWzQWuNfmkluBMF"; + sout << "6NmgipWLj1nlLLBUOpQD6nVlwYqZEeiaQOhHRYUiSUFbJ1EVdBPoRi9CPOquM1WX5H9qbUOBchB1"; + sout << "/oI2wsh86ZqOld75VfD5QsHXmrFsQ74mO4MyvAojodFWADFmRloe6UoEBFYhjXSlj2sUSTyAMfPh"; + sout << "jgRrJwdMb0/53hm/Emqz9AhG9FpTI5P6doDAm19dGp09NvQyYoZnCCWHC6fPpl+BS9WANNfsSZnl"; + sout << "dtI/JmBbtlx51QDTA+Dc/RqGZZFzIIDnBoIUnyFY6FBb/S03rQadSUdeuBInghBu3bw1UH0AImFZ"; + sout << "uy0gKklA51AkE3xVxtyGaCniGugrWwcGYl+FeiisBNKtZtbo3spTO/Rr2Fjg/mbJ1I0dCDf2gJt5"; + sout << "qkjF7zRl35+gl58FjkBxMdzM64289w8wqjKO6P1JBnJV2gsbU6QYGN/w3Xy9BC6Yq7WaHU2eUves"; + sout << "tyuWAxa2+pl8SzrrvubhGUV0Rx4wVcRCLilUKHQ9VoCEJeou1y15h/hl8BkmuU65Do4a5txsH+Xt"; + sout << "5zVkkmGnvrvq8xMd15UIz867U7lt+BwgV8UoDz5ZQnQXvwt4+NLK8dLaRfr3Tb7KpLPOdymoOW1H"; + sout << "7ZnfmR75beprkaMn29qygGhlTmVTxNC7CppNNfS97Yc3ki+g14rkFZMmIpETfUnM1eqp9O1+y5Lt"; + sout << "1p1DrkrZduSwZtg4DmJt/2g0eV5xO7LPishn8ezxLTcMLCQYZDM4gjhiWt26U0XEQClxxouxQKiV"; + sout << "jpcwn99nBEj5r7tkQLfJ4clQW8gOuBidgJ3gWiAlR/Pwkys4L7cdQybo8WegBvXRJyNi0mDGsM4k"; + sout << "cD8A38+zxN/mI7Z1Gv7y5zWUMrILQJvPVXCCWXo6gikdJpkLrprcp0RUXo0b1iUgGgOXVvY6PGA9"; + sout << "HrM5/u0t40+Yh7+3XqzLWWU9kuOgDLvjwYeuRu7oniMmTWOuRtkWyqKmTXlrwjBz+yqIQJyiAa0h"; + sout << "NF+n6CNX4Bk1X2h7DIbZf/hCbYAsEQ/zHAudSiEfzew7T8r/h4YG2huCTjV6NrBNOzxk2Bu2A2kd"; + sout << "PbIOFbLWphpDX3pQtfSvx46IL0uX4O03VgTiz2NCvNKaH6oZZ8mTxkV7QmVnRdHG2my+oFsqtkv0"; + sout << "uRF3h9xXKDij2r3lfNRhHrM3OAsySepTuXa0hN2bU0DIiX+BlxVjBKicOLIQUI2k9ws3akGMpicU"; + sout << "vO/fKIwOs7D5L2oUOqDcycceYM3DCTjyjdeP9hq9+w3nxnuvtoNuFnE67nR0FA/XSYLjZOxq5EYD"; + sout << "N4LdNhDWzYFjEHXgmyvc6IrCcltsisuRN1GNZmXDTd++/8CdA9ThbqH/AZvMeUt3zZeKMmb4Ad3o"; + sout << "T9in73sligZv9ed32/6rc3w96aoiGtHgv5szIJ+h49hCNfuVoSR6GB89zKr3cHcsQ1PyI9+S7ON/"; + sout << "/lTAQtwIUSGSQxM7xmawtKVSit1zv/gw307SfoWAZMmz3TzrMzKbe3mmNNwW91a6GLqe1iyaQ9Lf"; + sout << "KdIVsCHzqq3DKgex8X0XPnP699SlqSR9gjHWZvuVOzZkA5WE8oatWQNS/W/bhBaaAiT5/781O1s2"; + sout << "9y+RfF/lQh6tQbhPthbY/ALQDR07LM43qqvdaiinxT1gYh3pJJHs6oEQEAZwNvbjUZURmwbJ+hE7"; + sout << "fJnWbYtftZn7pxlJXzXueWe7IPmYuzrn6jJEh4xbnANe2gpzkGPi2mHIlJ8ij9vp5KwkaZ+SkmEB"; + sout << "El1hUdhsJoXXErzKm76RD0oLriVUGd2yvI/pM5Mm/3/pykt8cARjatq2CVc2YE8P2M0cC3ItKD0Z"; + sout << "YqnMdxmJpxys5JoN7q9yemZFpiV2OMJ24kPc8BGJtgbVQBWpMUnNEOZOI0P17D3QRehHWf543/X5"; + sout << "PqzfHj9J1tUcJMNCpp3GFxejtpOvKD8J7JB7+0QzqCCOE8SQIC3Bs6IB3IeZiYj70f3XU1UPF028"; + sout << "Kbkis0YuvkL79pWPoqNAoA7YVVGhnb7ed4p2/lqFw1NHG0+H4YQO1gz8U6MeiUrjcxA/IkwNFAZX"; + sout << "5WsC2QSadmv7OtUT6L9ifTCs0sogqcC5UETke+cLIfprQ/c4qIMP6JWMqRasR4+0qjxZbSxovEme"; + sout << "0oIBoRGE2za5L0/Ulwcx4URU5/iUu9u7IPEcUVQzABaMIN4wMA4Lbd2evs1Xjo6xcFLK/DBtOUdM"; + sout << "JTmiOUwloSGDd8MSx25d3BM1HdXfUc5uYihDFkk/0AUSLa3x0GbKimowJYGAkr6Px/5MsCnFW+Ix"; + sout << "anv9LFErILEMTwCvETg2LuRDSC/2uyGNnndivYpOu8QGplJbW9n4ALWpvS/dc6PJKjZo6GhqqK/X"; + sout << "U8Q39jn5hSYIGb/LXS0mMJZ7CPTO1AcNCAZb2r7Fc3orKeMoBDA+9R2x+ZQnlGWV8PdAzfpnfsIa"; + sout << "O+JL5Nh2NimtPQsqEVykbF1RF3caiYPCcEj+SAUFSDZ2pe+YfXPOpe2EOZ/BnGTQWdn8u87mw68G"; + sout << "Y2zcXQAjUHXuu8iF3IXR0bbRR4bDKp1MJ3YW3nlk2E37v0SEkyX7EpXg8hvNB0uReBLOk8pxw4Ha"; + sout << "8R6ekjMAfMnOQynT0gi3tOoy6/CNXBQFAp7/u+Owz+DO+pRX9StkBjFUWe6+8ZeG9YJLnNJZUGkY"; + sout << "Ki0EEoWpBbW4zvGa0pOPJ/OceaYIzhL2/1DohO4dO9jJkOU6UxpYaHMOqxr56K7MCpJF6k3ii/q5"; + sout << "yLDM1Ee3peTb6MfWrrkeYc9PhEtD6TdlJwQpHkivrFvBvp5tz4Tg2zoW4O27SZZn9Hr0MPAKCT9E"; + sout << "Vouc/LBstfyV8cwXm+nqNGg+f9ZLgns7No8isMXQTPso65SeRw1K48aIwSqeZ/LZ3fQZ+n8l0M0B"; + sout << "BSRAF11nGpSNKY/G+7vKuJ+FVV1z1MXmc9wp4tjnYRMEBsmhTnL9fFpoHvDlOAgI7ZN+YTKFOP8R"; + sout << "RzeXkz6ne3/ztgWIS6QZ1xFzl3rEPwvcADeMAT5C/j9COs8YXQUafQ4o2I42PM8PloMpz73nTNqR"; + sout << "yot/tF65hOa7X6sYFISZxTXvAq1qaV4j/bOyv7Rr4IMQkvPhZRnKuOKd1lEtEleREDKdNA+3Un3u"; + sout << "3czlYhhMp358CRLKGyMNsLc0oVmrc3Zf7AljUJg3j/WmyZdSWCfCFAUzh4y+OF0r0LSxy566EC0f"; + sout << "EvoeoqDl6Dn20my8VVm7Ek7kqtjGX8OPXh+OYw7q7GDHCLn+gSQKRPZEMAVwFPfY5J0qUdn2EhBi"; + sout << "bIdpDLJlX5A9Tuvsfk6/TgWJWrd1ljnk5LRU1JdLpy9Gr3i7RuUzHVmxutyDezdNhJQ2a/dGJoyL"; + sout << "wqtRLh+sAldKv7SyftQMUr3mVgcuViK24HOW55optbEe8LiDnZVyjx0fWPnvUF3jrOu9iKpCzEQb"; + sout << "UcV/ItitPx76PWI/HUsXW9nPkyHBkVhy5n4A5OFbK4O7r1eiQBRttalZs9UPIX/zGjhL+mKsmDPZ"; + sout << "5W1ejsVxJT54FOwFmCD6yxgeWyGirMHJ8iIwL12Vf2WG3NvnLTIMokCOIrJIGl2nsgpYWFh66aLF"; + sout << "BY52VplzlZ7uvieN8T1gFkoD1GKDTPvxJdbUIyCQG09lZhFsqWDT1h+sZMk2cE4UGFa9SUZY8cC3"; + sout << "n8nDCxxmmQybJp70Ig8cQPy9US51O+PL8ZjyRJirlSwqhSXndjyvgm4xk9uOSk7AooavjRJYlO7r"; + sout << "Mm+U11XEN5mZJL10W3nMa3Ls99DQIrifOcWVAaY0pEjmeBWHizxOwwqt3X9CY+Sa/X35awGgaQCe"; + sout << "CaaIXBkxwkRLD8/Of5fYQ8Dwg+LWZnkMjs1IBdauSjdWvT7B9CXK7O+YJC5kjduIGOAu8qfJqG6o"; + sout << "HP/uOEvCq4vT6Fg/gLbOjNenVhZSq6bAVSgG3C8/WonUeNQiWnEI0xLm+WOU1L6bcqBNHNBOOiOB"; + sout << "yn1YojmgQszuBRj2yC4rCAkddymD/yzJ99g8QmhpzlakTRtpjf9LTErhIBDIQSd19U8mNdu1u+cT"; + sout << "PeIl7vlVmhhnojOHfLGRvD5Zy5dC9XZiXRRcNWHj9vOTdrmtwy0mPIZZMPSBLJnur3djgCLfMniS"; + sout << "Q8iBSo8Uzf00zSZBUBUmTekAAHFvf+we0/ERcXVkJ2ikfZ0v3sxGf4zIVGAiI0Nx6PZTWgW1xH9a"; + sout << "3OYLo3mzzz7TVy/Yueh2awJyZ6CWq/ysu9x3tSJG8DbBBc+4segNq+IKeu52mYTcLmASIR9PT12z"; + sout << "ifAJWWtKNPR4T/FK1pDTkaVkULRPzqHP8jxerpBRMnt6FStshZiZOs+DZVrXJxSP64ZxFthgqMd2"; + sout << "UvzcqeI8yD8ndIyd69Os5t9OoeWseLPWq95+5OUX5v13wiS+EHbJBNnZLLt4K8udeohSNUXEBDND"; + sout << "tj7TDkzLyXYpcm3KYkJgtRY88kwidchhjvMQRDNw15ozLnR01t7UsNLVLuBij9sUBDA81i4YXyrr"; + sout << "5se81KvEZG5kLLkFCUnPKRU/3aRb40F3zvY9pcBU6KeJsjBQzKRX/EqFise10qyClgAyvADqyW8P"; + sout << "jMDslvAq1mY3kJJU74AQNRuFP0iBJWTX5B3BinqVoE+VJzZPbziFVs0riZccQIoff2DiHS49f06c"; + sout << "NvWgHb04xNXz1skoINdhFVRSsjik+qmxbFE90F+h9eshTMxwQVobiRLdLv1pGkVOpsRxl0eDsPMI"; + sout << "NECgDvYSyRtSxXz+SUGk4dokSgkrTPa6NDx82FYiyITDu7wcgtQBOTj+SXRWrrG+KB1MR5cRHWP6"; + sout << "nmELArZ2JquWF90mbylFHbzKZNB27/TLN3X4/0cVqLNxyzow44+Z8f26lOFFnh9qfzqS874839gB"; + sout << "IRvAiKnuSo7KCh9BbsvAHk8a23Ei/KNxJFH745cvLab2oVcPuFdwsDvuYgPNT22tuLb+/QN7djB+"; + sout << "6a84/73zWFCnfeMlPgtGbxUE9yns4nYTNx+jLDUpKAmYBWeiPEiGhsvJYwy/dKvXF+CVdUHFbzJO"; + sout << "r6rnCYGBY4Urxy+S6KPldLloAnA2SgsqfiU3B59g18OuY2NyPQ2fj5/Ytpum62hQo3SlZTVGXFq7"; + sout << "pInYYEueHtHIyJA7xpqX7TG8HrdRJRbz0Obs5X0P1vbpELhaaI0vppEIcfzZrmrqQfOT9fUILia4"; + sout << "DOCR6e6TWVirmzFa8hilcMLjG89m/+h34mO+pwqNsURvWoJ7oixSe9arMPUl2pKyA/wICP8AjQak"; + sout << "XXaacTp0DMIbcOsf1NY14cuKUgl7tcst714GMXCHfPCnYMUzoawg4o6LVOzHwHRgys0aWz4aNmQq"; + sout << "5UJDqc1UWVweo8LD1++5BCSMO2pNJaKdbHqVhT9yJP5NxIkV1UxvH1TowNN4SI7q7BMH2fxRBB5e"; + sout << "bu2XU2s/5Vl7snSkKk0PxmbiiwWGiN1Dx9YIUCEZGJcomUh7phIlPaoT44MEDMT96z7neW110We6"; + sout << "/FUCAtq2aeF887UH/ARplwM/elwZyI2BIDGAWF5udKyRtCrVznJdXgtlNND4a8MM2OQicC3D85Z4"; + sout << "suz+m7cwYIW1o0xBcuM2ASLhhAdATCHoyD7A8fIm3dD0vMKM7vA8QQi8ETjiN7GzRqoyHOLt1+4U"; + sout << "wM/BQdskv+ufGD3KADK5CrKPyKzaosnPyGYyzclqSBbt//QUdYBQO4AP2A3hqh/QySkGID1r4fUC"; + sout << "cM+fXd4DhwvNjAZHRPfGPD2xPguPOmmuCxkHU1CpcKdBSbxR2q4FTdaQ+1HBbU6KZ6358kqT+5by"; + sout << "3LP0eS7VemMfP0d9rshtgbK/+nG30EGTtSIJe74inTJ63k3mkHb7jRu0rn+vmO2aO60HQM4SUiK5"; + sout << "eKZUz9LGB6lD49IqbEwj55SpSUmLsITyO9HdZ2hhAAxRTQeMz2xatUny/fFOGSUiBSeYvyrmnWjs"; + sout << "R93uXWixgiS7DC15ARLHAkvkhT76lGJkxfdyhdkneQlImBYUecqxHmGRAEOQEUFB7o7ARObjdaFr"; + sout << "Cii//3Mj+6/YSOAaYlS1YInt2Wo2+gin8+eeqpBp88j4NIH7qyKiGw/fco3w0Yi/x/2tbc+OR79D"; + sout << "HylA5tAU38lKI9ZuW3/8nYGw1mXE133YDK8ZVzGsU5JnTnCuIW1+aVLf5pzzldfg3z49X2S11QnW"; + sout << "SNtd3gZ6cY2sWHnQ0gm/EIQCF8696kwZlp/J7fZonWd5Fr3P57QepBjktJs2ReoUXGQGAYABNJAK"; + sout << "VdbPpCSu/kWlizZFudp+2MxUJnGMg72fGFn2AYhlW47Luuu3/DNQiHhmrxaqXbnpqk+jn52j2rQT"; + sout << "BRiIujM4aCQk8bXqFCvgD2pcFpKMKtwgkKeLlLMhqX4H4kKZ2h2KpJPSHsTg4EW6p0p1boVMoD8H"; + sout << "ZFZWTuDKtJOorF/hCTAC7paeeY6EqeEs1S+5ujoI2XjTNcai4TMAOe/pAdgpVIRIybPZP8kWAc44"; + sout << "P56yfuv2RYxcqAvmpEjvPqhaP0lG+bOUzxAelJy0hu0xe14CdvwWZseSDq+mDaFrrM2fEqGp0eBw"; + sout << "ysDRYDp0HJaY+QM7mS1chGIw3iTOg5IfI+wgz4BqRz6ILSX1Ci93QzzMsr+t8AqDssxzDX9kMNsW"; + sout << "oSoRYt29qfxPr7wjBdWURj2KxpW5yCygTMfAh53XuqjyEmuiSu30o3fCOzdQFOppq/LUul6A/Xbj"; + sout << "anZBHBWlmaIpPtvtdvjcWrflDrwdsJVZENu3qKyEmY+wpE2MMqXUtctJRBr1eauRU/Kcn+3JfSji"; + sout << "PtgwJ2W6ztTzjqr+6+TRCqJ7xnzIZ2Cj9jxjcD5VCEcqAVg1/rlP50cHMlBrW+UIySQQRG1FtX7k"; + sout << "8zqh30BJhM6Qff6NrvLkpzAl6qKkgcSe3tsbmovCf5EmB0irBvKg15LiCDxI0HyYT+KfKv5NBgjB"; + sout << "NQDDNZzqulGPMIqzBGaZc29SfIHaMxYihQS5mONL345HTVNQg7MkcHIYalGwbpIbq3GKv5MumOVj"; + sout << "SP1kZ2BGFt2U1A0ytDfQLhHan2a7v4+7T5JgUkRrvP8hKHRm0JTMFqFxZpzGOEQ9dwl6atFoNdQw"; + sout << "f+5ZTWFfzYzWxfvZT0j18yQ4fo8VQzCqlOc5kLYatapswTogEQ1PGd3UH9gTs4SfMjmvrfre/kBL"; + sout << "nTISbSjaS3Z7N81jEQkExFHqRfxiMIK+h5VRZ/FwviueAATnsIY6vChZYueN9Y8m92Ou81EHj6Je"; + sout << "o4mKrrtkUETSWP7IkTd4wD/EC7hMWvA7c9cWpY7Q0nAcruSrfcUgks3mW//cNgLmMo0UE4+0W8LM"; + sout << "H5Ib3drj/Ha1OAZK3NRhtUN2TRgJEaLWrvJ3fo/xDyIdmm/Ap0jf6jXGlXtdWl7KaMElGsmdnFyb"; + sout << "xgIsGgh9k6hyZ2uTFcTzeRFcsAVOBkQWOBhyqtqh2asmPFjOro/sNvstEVD4+SrHSPG8rlqK3nr6"; + sout << "SzNq0Is07At+PS14wWqp9ZseQSLr/pWwK7a/CyyAwxnbdJqhhsnxmODTz7nIWFgQittN2Z5l0RhP"; + sout << "C06ZoojtRQVQPyjCr8C0BnFGyKapnGPGWVbuzlu6xMhqC9C4T+4zjsl/hVT5BaBKvGgjvxUJ9pWi"; + sout << "fs120RVW4ra/T6X1oAXvYrjp+2i8QVcOPAC/jfE3zT1UpsTrt4IKYR1ILwbWsOi9ORJUqpQgQKpb"; + sout << "ZrMseH1LAmp1RzFN7M3Xo+wWx1fu9S0W3ZYizaVrkbWrGJyv99zuiCeIY8Ns7oqUfQAfBbJM2TPb"; + sout << "6jmkCyD1U5a1tir6vGc3YDkBryvw1oXRdeQEuKPSF3YVsp7bHf03ALIn5glbDJwh+8FvJVX0rkPa"; + sout << "yK4n0dU6mBDfyCt8yZ7/uFqPHVr9Y3d7Ec6RJPEouC3MOlKhYzow/lkVnCVteEaGqhGh42TsKnHt"; + sout << "I4MUSgA6n2QVTcFfD/hsQdtjzWoFupA0v3PsXobN27vz/aEGrtHjaPXKNDLfbRyW0OQN1Rfj/nRG"; + sout << "+TbEr9Y17f87ar0qnd0aLd/hgEJ65lW41HWWj3bjRP/RL/X5HpKVzj5zOPBfnOH04vCT2XtjAK2o"; + sout << "J9gFPdwz+6hWrKXGuQlrIJS+7RpVwZG4woY4cgBv95rFjfWwLrLuuX7PYeCJNLFZKpoAz3WpzRYQ"; + sout << "PepcF0/AzD6U/dLKAaI3pr/g3kItwzrhGIOz8ZLN4IrVYrSQ+kw+R9NuRWq6Wxg84hoRBWN3hfxt"; + sout << "A/58J+s/+DnmBilHfrYMhysYMVe0TaN0fM/Am5VIQW+lQJOVY8nVFqxxxt3BTNDAQtVl1BwaRQw1"; + sout << "7PVWmpFsBKX/cIkJRfuUo47InS4SvzV6JEgNSJb4Jp7BHtNHQpwKDkCjfwcajIrde2nIEXCZ7CPX"; + sout << "6TmNL3i/ys3IU0zW7v3uDZja60mepUlqvb3mJsiPNu0OkX26rIO+K/E9roqWnwp4HdzTCgv4Lmuz"; + sout << "5jaVBx3ZVxdjB0Eb3Kkt96pZflyLdI6WSR5RVCfpl8ov/QOKAjZoihm3/YZJJvytsUsj+Wi9CS3G"; + sout << "If3aps72fkcCfzlzB/2H7HS2Xmfor9P+hBdfEjBiz0oMlkXo0+JuamAk7ueQ5RO7YpnI61Q190x4"; + sout << "vgeGRm588X3qNb5IXuhumYUZJ1ATmHc9mMkNSNWEGdFy13DJmhHl+fOrADwMAy2+J8g0L2yHloaO"; + sout << "rSycaN7DgGe0YJD8IUbgS8QMOz8Kq1/Dy2/Uinix2smPobySnP+uQjRJ6ilbHnZCAa5t9KB1fNRs"; + sout << "HExRV0YMN0tZ/njB/hUsI+N/sg3lvPtSOOpXb8wAQQIcj5v7rCaCBbiGdfjWcgUbHfLDgFLTxq2q"; + sout << "LND3J3HmONGoE7kwPwNRBMD06KUY9PYwSqjQ7spkdBOYtuBVqCbnnaT6hnh/DsSTqG6DkymArX6G"; + sout << "Y4SesPr5KhQtUvkrHHQAAmQDWALzA7w0fKLi2unHSsRmteT1ctL4lGwVdEb/YFmNzxiD7BM3Zqvq"; + sout << "JtEDUScoB94K8R2BBOy/CQqEcln+DCoukLW8BdLAJ6b3mDaw8tLIevYutQwrQln0pc6HwXKmVTeb"; + sout << "AKKBMmZaAAClk5amuEkyI5YPen9gSZDX45yRohZcef7NSVOmI1ma3mc0rFYJYkJkC7T9bcpcHqAb"; + sout << "WONeENewbaLz1onBbrVYBdqPeBWjrGt4kOi+YSXtjm6WdZRtLf0cy2uRq1gvzg5oNIDGer12uAAA"; + sout << "AIqr189l1IrG/3MtrYrKK/Km9wuukr7dJe5Qf/96FXuft+BUsbt8WYQKOxeMZJgWfoZVUzKwkzEN"; + sout << "/+uWsF3fYne2EclBulQem6566qaX95eQMQv+mzrBVL7+PG6lweDQ6r6ZETzX5ogxLqyeqLSue8qJ"; + sout << "xmVPMI9fr4rDFcRrt5XgwH2j/QB8ymhsIsqPVnZKbq9+jrPUoRJDs1X6VCycaMS9bG36PFJwxW9y"; + sout << "Be+LAmVi2N4wlBCmNwXclqK1URJRvocKWAKrt0Kn4dILU1Y7hXNPjcUx/0oG5Ffoh0f12glRsyva"; + sout << "FJjtV7ePY1hClgBK9v7owd1YZUz9e10XcWOzS6L8HOpX9pocEiBdfO/oTDYBPd6LNk/E5V032Usw"; + sout << "XdqReOtJL/zOBPW1vJ5xc1UMRyGxDeqUQFxCKLkFAQj46cG89dM3bhj+Dh4U5J5MPhQltLLhIVbZ"; + sout << "qHvKkLr8cCYqSS5h/SOF9HWJKjNfXOgpRbR2l3/m8t71h6sSi5XR5o0NCi5V3/VCgAs/CfXCYH/O"; + sout << "WEDkfZKw1IgO15FUghj5y1vzejXv1logti/MHvh3WqMlQ5i6B6SgaGmIbOdqNHdYUbQnNg41m+8u"; + sout << "k5VPCiPWtGkbzzuzaYSx/Na14VdvAbZXDztFPbN3l1fMrAolDKlexOXohzVZhYOsqyu0xdGka67Q"; + sout << "i/Bz6DbYM/U4rF1lWoX1Xh9eslMlBlJjaAPdicvaJF7YArjzLKOHB9p+EZ5tfrotIKGfdGIbt/Yt"; + sout << "MSsyTx9PBINYD1u/JeI4SqXTwvWcfF8+ZmnspgnEabm8NKufPXhaJODzrxL7filfb0/JbL6qmf/b"; + sout << "5ksG2XT2sPxBM3TjQE3DAtnS/Y3psQUGsXUuY8oGNF/PgRWd8e+Ek9IoAnMLuvW7S4D4MP/wmWTX"; + sout << "kZzlpZBnBBvlh7H5yDHGYCwGUofjCiOk+4hU1LUNU4LGgZYUsfjNW+zbj81/JSoK8AMlmB/ERH+S"; + sout << "yuBnOIDJNY7BG1n0q/YoDIU/YuM09fyzW4tAHNqfbxaYif2QldKZWbxUEiiOWaItpW4QTsQe0MaC"; + sout << "h6ppkigA6/7DUgUy3PKon+JB/jdrUdW/2Q1qgLqV22zBLHJoq76n4QYZFmpLN7FkLFtabiXZigdm"; + sout << "7y4GEh33CQIlETdes59sn/0qIUj36ICMh1T/ujQOe8hlG1EIruBdZk9CtprIsVYdpuiftqJHRLmz"; + sout << "271Lti8HdM/KSVqrupdYb1GO00XOGVU4/kujCa/nzhCukpCpESV65sKSuIaW+VQlvYFATzX74WWL"; + sout << "bKfIvaa0TLR8pMz30JVDT7f+I568gRLEMsBdtMnLyJ8bgkPGdngpnz4G9FBOqkDeZdb/Ji0aKN8v"; + sout << "gwRZxrwUrjBE4EjfcRQy886hvReIP6GvKax4bwVcGfmYGdNuliSsCsPC6pLXSH3fnmYBHD8DbGct"; + sout << "Q7dlu1Z/aFTO/uY3yup3clLe89WTrPTK/N1lsUDULj/gPNwbWY5l3ay7YRoTFt49EjYPmx6wz3kc"; + sout << "YXYzOroiVWGspN9WfmHL4XN6pXWhiCOrMRG2Z43vbdUnqNk1992xy1q/ZHTw2ikxIwKy77wygWfx"; + sout << "28l1Iqj1k8Km2Ci8sBy/QCLnql6olazxv2fxn6vj+4QNrOcJb3sDuLbEm6BkUHIoiFzZEpc5nIpw"; + sout << "5YBSARDuvqMt+01+ed7yBUTms08EFh6EwPIsJ5R/6YE/Qt1aP/hfB7nGWlpGb+yqycdZbUw5XyZm"; + sout << "pOEHQHvW1j30bk27jCRHzktRYXRXQ3BpP+gVcAlXgxeoARjFfC3hYPFtKSQsmKpGKOFFMdvdO+qI"; + sout << "lwrru0fgpd0ctBiPsWR4loLRLpt/GHedyVRF29GtPPfP+JekSMDzQj4/73RZfIdHlc2b7zX+JdLz"; + sout << "Jk/4yNb2NmUnum55ezFqq7sv4HPFvMBuOth14QPJsrT1f0NyUBP3qYocsTAY7yo/fg2gYqLFMBHK"; + sout << "+N9IvtHqxyn4actl6t8DIHeBTsM48AJVza+LkOrmuK6k1x61G5r/qyEVTbWVMqKagbySMYtIyJxd"; + sout << "G87XVnI5bbsfSdTdTSQ2EOry9U2H3HIrKIFzxOcj2NxJ+dNqpiLtreoN8XsfLvfM/46NeNaynG17"; + sout << "jIqcV8TL7MVqJ74/OYv6Mek/0vjLysUDQwPniEkka91ZCJboXzsuW41M5QkSfU/nxsD05TvQVbFj"; + sout << "GBNApTN8Oe8miBahONkbfBB2ks536yBwG8RXWGSHR5Jdl6yhUIw4mWb0WupVEd90sZmXEYeP1q4K"; + sout << "MvdHRuEQbNZfyXLskOtwKsFAqa9ebN6yVdM2tIOOp12Km3mabcUct0nOq///UyfEclOfrC5km0on"; + sout << "WWXFk++6vwChLbG+5WCZl5SNeuxEWu1odbh2XZ5ng+XSz/T2P2l3Fa7ujEKilHAYYbS3foWHO+ev"; + sout << "sZlG35o8e5wuKiqnf7gGs8tAIFFyONtfRZuGzH1qy5W7BysZaohtU3ihEzJdomSI1xKHE7+Dn94W"; + sout << "ZDi7S+Qp6VKbepW6JcdWctmtXyncC5AW+UzquyYuJzjoDRdVe7uEAA44Fhc7nHFHx7snTuZpaYcE"; + sout << "OfIBgzr61aCpVMKX7YJ6NOzRNgfTBIMl24JybjFG9Mcbk2qQVzKQ+w1StjfGTEexWIfYRDxhoUzB"; + sout << "I+2cAA=="; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin, sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin, sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + // ------------------------------------------------------------------------------------ + + // This function returns the contents of the file 'frame_000103.bmp' + static const std::string get_decoded_string_frame_000103() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file '..\..\examples\video_frames\frame_000103.bmp' we want to decode and return. + sout << "Qld+lmlZhXVX5NDRWIFG1T4+OGdJbkmxAXXdHZEtpDf6knVTlRWyAhgv85Tf11KJZbhKv1dcsKtQ"; + sout << "fX3y/9RxoGwtXxwFXCxbBrLRaklZ0wCJRoLkFEbV5T2V0aJZQJTKWkAqP6JwJ7zPTyCiJkDRbGA9"; + sout << "1cIxGeLInzTUAlhbGO8eexmWzXHz6KhgZ87K5c4n+YcRjVfwTo3Bl2AB4QPv/bs5da4g6jZd7iwv"; + sout << "PYnF6pRCiwuiCobwvi8eG43pWnNEX2bIUtfMVGP2zA3e9Qg7Q7w6/KNILIYSXeI4KjLlnmC9YPB2"; + sout << "2qPH0Q7DVlBLBBUHO8lZaFg1gR2Jb+4rgSvuUg6JaoUbqAUMaFFnQCjw9bwdN4bJ1GFXCu4G2afO"; + sout << "/izIyq+QxXHIy/Ez1TZWTtkJfWRkKAjj59IBD7fRSqyvL2tSatSiYQffAcsyQJzcPMTHahKq3XeO"; + sout << "pFu5XOr3OsyxVr7ME4/GAOcUjNpdW5pAMPNddRfL6/Jy6S9ICHxppdwhMAtbxsaHiKv2xgcc/t3I"; + sout << "J0t4NPThn/QTNed+F4iFR30+lwrS1EyOUiqRT7q51LFDQE+/ZyqCrq8vEGp+BRAxMEoL0ekA2B9W"; + sout << "vceyzfAlymEivr89KSSC9L+VZjE8qylGshgeXM30qladBEU4CKtcTlfJFWSHZX9Gm8qDXmp/5bX1"; + sout << "lT9/GNd2E1hBWWQtmQ50DMlInp8oFsoURaMgzNu8zOelG1WAD663IxutR2wJY6ILYQfXNND70XrM"; + sout << "O2PI0ZVRWenZjLzXtSefvwSXCG1G2u5sjq45P4Uk35wtetzV+N2WzVinM3yYtcABa51N3Se2nDWP"; + sout << "huxbCEVYT0c5sfaNCy7XRQ0ijcrMdqMDbecOuTMSlE8FEu12Z0cNFQqao0R8q3601D39vmJoxIPM"; + sout << "wRIiHH6qVcCMBV3oVPWIiD8T96pUZoN3V9BpyQx43MEC6ZdVFqUVQtiYu/PgWqK2pwk4Lh4vpou3"; + sout << "Jad9ENAsLdoVKZrtHlD8k70MVFOXNXJELqJarxwDy35JAha1+dHBvcgU3q/8pOnhS6xZ6judsmXX"; + sout << "LVBdUnKu4w/tbXgJA7lkLtwtODiXCL6edRCSkuXmTRLqXhrzUYzOVO3lZQSNWLLagHa73ce5B6Wu"; + sout << "rVTjMEVaNC3Guz5/cOkZrveG7MTk/eKR56x7MhkDS3H9pbNiCoqO7Ub52x7j/aicE48SYYkXFTrM"; + sout << "ETpxS8VFa1i5VpaiDxhfZKKV03S/fA0QoUcakVISiSN3VgSrhYsu9KvVjD3kjW9IMKGMlcVH040t"; + sout << "1cs0rHtLRVh/FjpMZlQxWxhI9ZMH0Ccfzpr7SbTN0Y4WfKaO1O/JfqYkzCnFTThvkrUMAeA1+zHJ"; + sout << "1HS44rFnT06RSAGFX6KaqrRFBOC3V3ljXEHyyrGJJe0OHWTUAwov+JzoXuB8onm10lcxUpV2Wd5R"; + sout << "RZ3dx/u9pEYgBolzQIXVkcE0YBYmkXMJfgXK3Rp6V+1pIR9bTWUU76/0CgMHqVBKZOtpl1ykTQ90"; + sout << "h/n2Mcjhmo9BtRIi+R7rBhr2a+oNj7GZHVPaaItTcdSvns42OxGjc8R6ZdzLHzNNPYx0NkbVeKnU"; + sout << "FSV4qd+iYq+N5hiPeTRfwdLU1wRL84Gvm+JfR3Y6eK8dBmMGCp0skspHGgpf8xTXXpYmOuZ1FmXV"; + sout << "YMKmrYMuTm0IigxtpmCmHPmRZ8U9xbe4Oc8UDxkkmCUvcuvMLS9SJKFOQiUJwxuXgf2IzzzePLcx"; + sout << "ePs3jrlOoD5GzQzSvDrJRIeSxjwTrVFiFT+qwMIIAR7TyTjBOWV51+kM0+5mU9g89OwLmO4yxXes"; + sout << "F++dusUFHv+DGSCxk2iGNwiNDQA8PktmdVKux3PpVBxWZ6P0W4Kyso6RNIXk9iAVZj59zCr9uPwK"; + sout << "MzIZ1ASN4a7lrfvv4cD+Uxovvt1s8Kck2wmIiWg95RiLdnopqSg8qZY/3U5u6XmXCqnEFTYuKKVY"; + sout << "dWsDem2GxTTNM3DtxFN9fYbHN6MdbuyiVRPovBz/T+9AYpx0PVaFst9WVLbf8Ph2Hs1S1vgrW5sS"; + sout << "qM3lBgjEYl4CvDzL5HbmqBV1gTOZrzkrEdr00ChKzRURvTBn7NlvPZewMhL+Hc0SVXYtzCLSl+p4"; + sout << "C3tIOZaI2l8JIt8k3znBFs3UrWx1WGCv3GY8HEsTa6uX984B5dVvdlUqmstnqK0FgMEkYxiPWoKV"; + sout << "eYlx3k/sV8GChiZWY/0k15KIY5c1ZUxBQezRc6a8IwYeuCK4dcfkMtn6AUpWu3W/vY/sx4dwNNaC"; + sout << "eWhKda7lBCMsfRdeOZvWgvFKoIknjnkh27uEJu0AskTF+xhdXdKwIfM2ThYFipLGAImNXGuHbtlN"; + sout << "CJ4FBcpn1BFWEY6E4Q+UeOlU1rv//LgKYxgNC2GX64oLOAoPGgEOG6dTKmnad/7y42k+kmQnvLvT"; + sout << "YNkffPIcqqBvhZCmvIg4eMR+cB5Kw5IN1tik/XNmrI2FP8y+7OhNlFah1F+slgX9oHVKoRLtJFPT"; + sout << "e4rQDWqAnWMAoSqsITbHCq9Em1lP6Cq3anR7zsCNH79KJCsbZOZ+J9ccoLm7uLJQUl9BW4irlZH8"; + sout << "fYZypsN59YE3kzHivG1xH0QsDLXrw4G9E2Z6CFojySfluCPiHXO9V9XxiMaU2X4uYHVb7ehs+/TX"; + sout << "h8uPIngi2rIspfThIzUPC8nl9k8xIH1BoWhryXHkzPi4M+ZAI9afBNr6dB9qa8O5gPvcfpeWPRvB"; + sout << "TGAw+OW3GX40CxJhqQ6n1G101LjjbFvWWdHTm94cR3cS4B2BYzfAqy9Fe2woNh8OT9520VCEMgWc"; + sout << "4xYjf8rbQFq3kBgPDVEyu86VZCdVBaAs91vTnfd1i7tKCN8JECyr04O4qJ5IM1+5Mtv1Z6wlgT4r"; + sout << "NmW5FaCZIKwfORsXG55kuEeKBmvbnBUd6iO+jVtIn3D3RjWCHS05cBsw5uCvixNIHXIF48lTiHpR"; + sout << "hGITp+vcydTLUEeVRxnlKIn0RJX0Ht2PWFNNbU0RV2WFS4K8QGQxekj6i3XFsB0UWJn7ipi4eHJl"; + sout << "CIP6hG5bQ0qG1tvhC3a+Ve4et41gI/DZu9LbSqg3oNWUPjEgFSFadjJ6ZByhJ3QTm8Pl0uVdpzra"; + sout << "myKEafoz0tbaRgJgqAGpQwkIPgp1ocPZBplMSbI9zjBhrTlI6fXfH8aOw4FgB6F436yGuVl3/Rjn"; + sout << "UuvOT27jH6TvUOGxWBbzXXcxU6250E9S45QSh17Ly1/pfal5GdAcemaHKLRBeCLMwMszIr5ZOM8M"; + sout << "LxeGCXvWuMuoLIPmElzolFuKKEht5zAr+KGbhsILlLcVpQBnQRM3NxXfh6bO86V/lEOZxhy1b+Js"; + sout << "1bQrn1Uj90QnhqDoI8t1P6aSFDWKvbsnzC30gPZ8E+FFTwsvvKlzijcoroFQ6I5rSdVehv8/YNvf"; + sout << "1yYpXA+hIK5wbbyIAXxxmi+5OP2viwGEihe5rwkuHRTmtJJ09rHWFT2TbVQAF7775ZFNpSogY57b"; + sout << "GJVCa9pBGN8qzsahk99Z3tjeP+fP+4F4Si8JyPJq+xmTO62ciKxHrHt9b6sE9H3N62huURtpP90g"; + sout << "91QbEbFJU/5JSPfJjJZIp299UhafpBSUACpOnnjBJT4xhVU1E3k5x6TKdrZYzAC/dsQ7BEwZgg+W"; + sout << "5wfLTnkJsuvMLQgS5e+Ot/EUNd5+16kkR/wFeH6vtzG0EyxkWK4V7GrKtsrLWV7jOC6yY2mGwfnB"; + sout << "iWLdaI9Wj9pIeTk7w+PcyTZZfdDldyOF9TwYKvHAz6ea5s7sUe569Y+iDqOOXOwpjX7TMUVgyTCW"; + sout << "4VC8ixQp+UNTdlKLoeS+2fw2QupH395hLQ1DHBcO5Tuil8Yqxxw6nV+j27THym0CaB3j0JgA/4Tq"; + sout << "vbqsEonwHgSaP55u/j+0QRpycKoXfEkS4gh/qMhbLAPQyTNY7QLjE5cVUqzLG1twetMMcXZL2Dqf"; + sout << "dP2uDCswsX+RJdIAN8BFoa+J+uXHgBMgbxO4DrGz8GxuSr8RfS0JeEFReLiLpE3n+SrnLwKAUr8y"; + sout << "qWjNrqRD6XKVWpX4SY1I6wcE9/3jX0c27mGX+87XnUZdIOZbNgXmntyS8J1P4uFaEZIy2rC2ahgg"; + sout << "1mdwALsK29JJ5QEaSg/qd9UEmm6jnWLfquCgBIkacJgjG1wc/029kL0/Br9t7tuI4jPiKKp+XRM8"; + sout << "ZZJGwjcB0DbJPB25JQ7zp01lbnEnNiGpNkNXUQtf7a9Z/F530W0M6wG4fOl3lW1HP2zF1Ad0GJxK"; + sout << "y9g9bj+tCsfJQfeymW/5xaC1C+9sEDOsg6hgVH6avh1a0ILiY/5MddtCHDQKQF5tgkuiKXG1+ihQ"; + sout << "CUlFov0Q7HG9UruG+eydcYSdyfKkdTqYQ6F8PjLP/rpiypX1jiZAodzEuvZwQcZLAgp3gqV1swZw"; + sout << "6rm2JpyS2Kqc4yCanywpHAU8LkVsjbPTanJiOgDguguVRtYGpAh4fTje2Kofw3rbPjznNu+CC6+4"; + sout << "8QxPZtcc8USR2EShYVlE90d5uu13F+epClENFc5hww6ymXysLDCIop2vM0PuA6csUxJqf6YFZHBm"; + sout << "c1h+7m2ULTr7DEVeRCbDkLvrXrL3zzqESvN2x7QDUW1lQRG+A+hLKS/nuibs6tI8uCZ0I3/VDQKX"; + sout << "hAhNjD9x82FqKKyPGFikNASTaFg7I1DAHm2iHm1Rbff4I5S57UUIk5zr4o/pfQBh12XxfpLtvSfn"; + sout << "SGXRAUcTO5jYoftuSmX9tiMfjjSXopedLWjMKSUXE6m2yggIzxfAzd60Zyr9lNMBVzeXTSkrReAU"; + sout << "QjMxJ9zsAF/86q2lsxr4gJKNmg9QkunRbsVlh3OXE7MkxBHlb2MkULioU83Gf8LiosiEEmYSB5vY"; + sout << "H8D6cgnXtB64BWoEJpQuG/R+zvnrWNHKmcBd3YkEi1j0dtL5G50SlMOJuai2xsD4RosbPEdV6Qx7"; + sout << "ESWrAO1naNHp0fewaVl3d8qXh2rFZAKNZXP3O3H7h3BAoMOsxWMFNJwQ0xZshhRvASlOr3Yn3HB4"; + sout << "fOqqLiPecX/o5r5FRa6qIfP75QaKqcu+iA23psV4/bszP8L4wREKioaSFncsUwsMU5NyZn3hPDUo"; + sout << "kdW+ZGGQk1F6Md46VYMIvSxC0xer/NcISrDfWmae5WCoGlJGhv1rSnSjRmMrj213eEh2qpDNhz2C"; + sout << "8dlByOZtSTOPejh42nYMZKCdBTK1lftdSeYdsHRGt51TQOauYYNjCwQkh87aW32Z5dsUi6H96k/c"; + sout << "NTnEoTZbmyqFdmov0QyQBgIfPmIUV/z+3x9PnfuVHJr2PBdElqqgp7h0Z+Z/7ESnDV+G/0/dx6nf"; + sout << "hfktS6/8axrd2xTKr7VrfBLvasV0dTC/GibZlQ1hvowffGgH8dgooP+Nc/S2UqhkYuMZz/Df1r5g"; + sout << "jYMY3rLQt3RqpbZmwU3p9MGe74SVOSCLgYXIDQiGyoGGPuwRwj4uhI5f26fC0UzG/0RCe2g4dDe1"; + sout << "vpDGatzd+3ixmyzEf6K+4UU4l7V7XAuUYs3lGEmtkz9li9mF3huDAUgOvvw+jErQrMnUGeyfzqDF"; + sout << "xyGfB50sK3Y1q/NSxYKQ1oHsyNYvnTr/aD5OPgMzbjv1Ii1mnwtP3m/EWmdSmAsvsjEq6Dy/yBuv"; + sout << "WZ9JO/XtxbusYXBNUxb6lU00DpJK5NGZVxv6UiQmrRqZvg531ratwkkaH0ZJSOwOzWlRAeVKiDPM"; + sout << "mxtnqoUGNcyr1SqIHBHvhFpr2BJKEeZ6MRwNHoQYCA1dQ4l2YJtZVITuL/o0SHNq1EXC+5sWB00T"; + sout << "UFfHRWFOYH6/baG2GJSvgzQrxBH9miEG4WzXQcdTQlmup0RZWjnwEoyHRMj4pc9cujSsTMaGktXD"; + sout << "ZzkAX52IncX0jqq2ZE4epQnx1UHg5IXANgb0Ed0y/FQjS24SZ57uIvl396ubCO1LQ04vGH2Mvs/m"; + sout << "bSM7jKGtcvjQ4uvVaCyk29Ek4EKJM+gUAuEj3IcKnoRNL/P7L2NEBlJuQBlIEgczCVG68kkIOFLZ"; + sout << "CqGgbf9IaBULEjthq4Pr+B5bUYmoBGtkUDM0KewVRJ+YA6A62AAAeStbx+myMgTTqq2b4idEllZn"; + sout << "0Hc4kztLuBgtHfwaizv9gYEUbgKbXQWfTQdUNAGSyKyeY5dQ425q+eyHezNiXZ7C+tDiM2UN+Lg9"; + sout << "t49U+7urUylh189bDcZcGmbx10unLEwYGT0CgDPGA8DJRlkHUkXSlImHpz8+3hjRtgljVcz+kFXi"; + sout << "jP8F2VcEoiUDRpaTwdJTi4pRsWSZF6pDEvHWntpnhtBI51zoDzEpbkHJbjRbvBA2zl0b2MqmwGXW"; + sout << "2qDOXOfSFVUVzTr7++omJ+UXxq9awcr/LVXz2tsuMyIvREj0Fx3c9jVozDSbOPzn97QD9LNh8Pau"; + sout << "bvPgfmMGz1Xj5f1UuKvUiCfvfP8ZZs5l08ChMadB64ipgmdWKK4adqE0ES0cTqcg+pJDDL8FgLYV"; + sout << "susPETHqp58vc13TMcBaqMAa44/xA98rFy+KZpYNxFf+l2U1eq44OH/zdytXhOg6y8TWrKYIdX1A"; + sout << "C91FZWXj0t5KfTahGZE5t0nP5iKl1IdXnE+dOIvIHuzrTnyjM+IiHmj1DxClG1VQXhZcwvfBzIY6"; + sout << "MRwFxUJXmt7O55uesEbZ/anmQB6LHdy7hMuvvCgeLBhoRnnyiRJYp3OAwNH5MLy7LXK84ZKTkodw"; + sout << "BuyRw5oJ9KOB13Buxj2yCRhLky572BT61cw0rXXkw2ZOGQgzHbO+cjikNk5LsVyMYfE49ClKM4w9"; + sout << "6pVtXDNzSPvQK0KSO0ZMKPDAr1/MrV0PD5Y2v1JIVPOuKSjq1s2vP7ypzarrFHAJnsD1FrigYXME"; + sout << "9amdaYhCmPIiOjj+9xSfmuBrrsHJ/THe82pg0fO3mpsF7TVY3zQCpiMDY/UqDyRrgc8NQu6A2WRc"; + sout << "oGCRQ5QCWCTFhbIV1HQ/KXjiCeFrCd88KkZDLDwrWMqeeGDTM8eWQcTIqoMMtu/JiCrts9c8yiTj"; + sout << "IFiHTXnpZMYwxDeT9hvNkifueXVjFYbEs6hHySwX0kcUSgtRP3fUKyQe+Iz+u/iOi1tFHQgPsV5W"; + sout << "gltpMLE72aQmlQ9C1lYkJQoKrQJRFiJmiQn5S2hQKXurP2HOIsAYSpW4mpeeFJpxsF8lekzeniil"; + sout << "96nrrpQlpc1vPdDhbmTY07tGoNlYRjoKTYaelkktnWA2a4n69x+kD8iBfVLANU+EFHNfCyQgmdJe"; + sout << "fWMYQwqwgxMWk88peAftdx7bDJ1ZJQ4q0zzkot7w7NB4oe0Q66awUwj2n4DYvDAum3SRz2LwsaiR"; + sout << "Ke8N8dMtRYN2mosZ/108MktIKlnJ08bkI6eDnO5vHlbeBfLOfEP++jGsgdS1KHrZHdIygtioQ9Zj"; + sout << "RzWb1fKOpoIbxV4mgh5MwpliqwHmG0DIq1UPIEqj8W/Vbd1T49FhWRJs8cyYraimx9MPklhyxlGA"; + sout << "HvHpQfuihMzLT7nNAVceWz6GzYfHy5uvje2MYBVHkoQ6RaMnJUCfFd64pbFTmsB+/+Teoo6q4I5G"; + sout << "fbeIMh26cs37X5MwocXcq7jSunVLgPJYDlBcbpIdhDt+Uuf5VPo1GXRAk80ouu1E9mrYvlZ4sanS"; + sout << "5EJ61PNevEOSDPPdYWwfMR0rKZuoefUCNfvy7XyXpEnZpNLHKsO7Cs6ZyL6zQykBk+KQCfdRnTHa"; + sout << "SVSpKvmD34O1gGkrq7kSfyWO6eU0TU8HImJY1tpI8fcMTheZ24hxx2QsV6IbsM0wbxS3d2Qo15XC"; + sout << "oP9xQxVcn17iiHUg70awt7j2VFEuCb1QjMbNnd5lUvjSJ0jH6gi+n/on65AWEJYm73Akn9i3fP+v"; + sout << "oLTDxIOmjRZsDgzsAbgtxb0Ozu0rNutnPRH0Am9C+GGJX78X3u5qDVn0q2+qD7eQkPOmW6oQiIYq"; + sout << "8rWM8ya2N940j+H5HsvyW0FKyrqjSSXQ5fF2LJE27Wc5ob0tVuvWAzOQ8COZWMMapu/S2C61f+2s"; + sout << "dFiCjJQqvr3Ai3pePmVrODBGROVHow3PXekC8iQLRwOYfNHwiM1jGwPOiMBZ3v+LghoBYHhmUh0Y"; + sout << "kZUb1j5Mncmgz3pctYS7KFhlHq1RHTDlF0pevtVLDXyy0QNXFoOqGydJaqLN2eILSl/eOGmsVlQz"; + sout << "lgyoCzPW0YsuN7Z4m0z0CfdP7WTpL4s2jlB7QpdaF3JOaip2nVZzX9xb6qXPosQ1WF9yixcznoAZ"; + sout << "bfWnipe5Lnxm+KHX8OUclbWUC6Twaxoq9Uo9jXuwkU6PZ5VzM5lao7D0JCPfVo7b55q1258j1AnN"; + sout << "9JNYYcZ4ZjTEKvbXGdtb943Wmmr7CnI4IyKv7AFpeiMAYPVRgxIvn0RjWpJ6lzV7LutAjv0HXwZ7"; + sout << "5XHtqvvuIe+IIIG7Iqj+JBW/pgQsbb7CKXiRmrWwl1MTtDIrQuh6UP6mstikBOF/d+yOSqCvU71B"; + sout << "A/Bj8gEqiPEqB7uih1Kakz9wRraRAYo6VZphC8Q5LdkFQG45Bj7KnACHbBs7e9xe7GDrP4OVbQHh"; + sout << "m5QCONkNpd+0umv74Adr64LDpsMr6CLBgHRrXukUXUGmFPqUCD7lC+quhu8psnZZrvSNPSa+goRp"; + sout << "PnO4wDlSeBSkuSdyC7utuC40jQtIZrkJmvkq5UURj4VcltYQ+5KiypO7ixAK4fvKyBTgxV2WNVeY"; + sout << "QvxNeJcNck+IPEhYLyplcLzyoekhnP4ZDD6ALG2jWZ8QoKpShA8+HdRSJoFguoERQRS7ePlMxywb"; + sout << "kcLPkcyEOlBu7lSjCRvjvdROmLAidz2GrlWvkuLrRe6clvs+U8JfBTjOON6P5nq41ElWNA33r9KN"; + sout << "kcCmEICyFl/Ier/YdjOUtXaHsKwHaMjmrXHQOTP7ELs7jouCbb4cTFVtgB20smKj3paJgoQysxmq"; + sout << "12Ur3QnHaHjaoxzIQnlaD69TF4O7T6U68/C6BwQDuZiD6iB+mkCQ7Uz0zFsiVMbF+RgFJKpjN4oo"; + sout << "SuR2gUL8sPlixeYmannuPwhCkJ1swN7NiVeUwTGFUDeQt2eJvJc2CllOpUZwF9PbyHSkxkQLs9kP"; + sout << "9DkvDhKkqJ2aEHDlPGBjETFYjNGdoBfMIcr1yvWshl9fTGlEi8c7mxMW/5Ed4sSdV3N6qRj3KZB7"; + sout << "zdQJ4Ql6tH7zkYWfo/ZtTEggQQAe6FzyQ12/xgFgdAulwF1IJZ/JlzY/SKePQPX0/89AooSUAl2q"; + sout << "/BfEFEI5XJhXDp2VJevsfFAjlmf0pbC5DkyBlUwdStZAoZpvfZPXcQ1NshCuD/hT4FinYmLl7GaW"; + sout << "LJSXkHegN09TBLWDdWe3DfKDZEWYqbIvtgt5Evgtw/5Qvy5DtefyHPq21BoIT8C1zxkgfOqEhH+l"; + sout << "E/fLaDsROj+RUJ055ycF9Wa8mzCOZMnaqsD13eTXtEY7LQ7MEo/15S1ny5JxaGNHxy7YJ3JfrshB"; + sout << "0pZ5vL3jvb9iYwocPnlmFpli+Md6qEdonEN6K86WwcRNqV57mMlF/EhBZXi8VnW0nqglR3mT2xTz"; + sout << "CKiwPK1W/zOnR2K0SsQrHV5kaLCfRi4vM3Jv/bGjlgJSUYBkB1QC7jBKE+uA3H5EmspwnDKPAb+b"; + sout << "kYXzpMWfQbvMvzpumEs+nY/xbCf2Hr1vhZAAFR1L4F49xrxDr8rjP3DryRgtBquQH6/5qMwdg9pR"; + sout << "4ACieRQts0S8jC77fxLAFW2GqifC5DBhwIcdhOhxR1FGLo5RzNktZkPob/fXEcC6u/Uz4v4Gy0j/"; + sout << "ZM783wye39lB2eymWQ5XPGC9FKR+ZLodyKJK+NXDBiFBXOjY4WNKEghllE7jAEU6VVmAN8NDkQdf"; + sout << "u6MiVlBZLK1cepTPDj6G/TGH2FK0I4O/z+Gb2WQdZ2VUObGhf3GpYNR4XEoBsJe0I8F+TKpHYpdY"; + sout << "guyaXCEMW9XpK0hTOBXfaooIQu4+Yb9e18fKnkrW74Aj45mdseL80RH2RJGm8LeZa/Nluqm71cqu"; + sout << "9zJkwZjxlfICDgLhkOXtt/Lihzoqav/abehnuZjlyWwJ34eIhCmEv9D7C0P9VE/7B4P1S/demat+"; + sout << "W+rmgFScowUvs1lfxNr1I5uCyLEJ9VphPUXnykyJ0XM6zm2Za8jD17Nh82k/eCX+JFMWv5znDNQe"; + sout << "aDzkIRjTQhXWt4e6GNZ5zU6g9rr/TIePeMJyRE4D09xOF6orHWMBh/TFAL6PQnJaI6RUqCg8le0+"; + sout << "ySd9qWCRpZsYBK+bZBEV9A1iJyvhLyKLgSGC8Q5Y0tQmE5zWSlYkRr4YtfASRn7e4l+Hz+hhcReD"; + sout << "F6s8jWCAHjsq4uZi8qf17LwZu2xpbSLmgCE1sij59Mli4sWlPDkX86ehGqvH2D2x3hKlXco1v+q5"; + sout << "dbNe/zvwcng5KoCwL0pmT3itlXfQs08+1lXbpoqUf2wqmQFEi3/TIWq+1lF/zV+6ICVx1d1eKarh"; + sout << "ZgdmD0C3QLlPNlUAG1j6qWQ4K32ICyKu2XWRjcaElm8uRxPBBJ2DKWJFyDwP8caoB80A2ApnnFCH"; + sout << "mrijgRO7LeQ6bQJJjOBSjeppPm4jTsVd5tlAefLV25uRwx4Reif0e4x+24fxtPj9udHRWggwGu/1"; + sout << "Zs4+wsJ7KmX2ekZ2LR5aotwtBCo8X1J9GaHal4WCzh9G7EqAeSjjrD63DoSgJx8PAUoH32QXUMxw"; + sout << "b6JR8Czkimh47Uv+aIqZsr6GDskW7xfxxgmzIPwSA0cJkulJKiOYEmjcfgzqaKHrcjMRb4LhJRTE"; + sout << "F7Xgdk4nJGT2hRv3Abqz4725TLPUsfd1DFTM8BJuJSWjGNa2cyQ6f3CTQPEZYFiyYMCtd4Ycytnq"; + sout << "N4djAycInnDJV/+XSjC94UB04UDvkq5lds+tgGMDf5YWClAz5EO0ztmL5PIQrA7qx1ykF7xiRHVB"; + sout << "Jy1Ln0oTEV9yf8jlIYILg3V/j716p61dkyJIvRtO6XhFdyE+4ia0WdXaNWNwvx4wkzqRHD4IxMD0"; + sout << "imBJe7IIMVztDzJ7T4K61P7nzwgGJgXHwfdh4xxYZC+4Qf7LVyGOeH30AgAiZBzEFu6kCHmPMxgZ"; + sout << "VQIAvbW0/84nBgbqTvUTQTVytf9ufnhuVZSJfXbMlFGTxeCtHG9C6rgAAk2QNPHb8up8NL4ZnVcA"; + sout << "AWk0OHyU1DZrLFb26nmHBxwVZySrSppYbldjEYxMdewD2OIYBYrCBCKIJzZpTMP8D7QOIX7B05Vq"; + sout << "zPH2i9cAAAWjZjAkQxG25u2KhBSufKHwxVZ9xgi82dPIUZJRMhNOyUQ3uVhK4MilNqxi3Giw2KtP"; + sout << "lO1IWReKc0m1pXxv70r03E0i7kf19iQ6AAAABRagQbu/yzAW7Stqw18BH3pVGZSoj8no2Sn/A9QF"; + sout << "E49JB+qp6BejcC0vGwH+ogKo7pk9fecrkHEc/6aH1Dt7ciMaDXfFII5rE9AAAAAB5kmFhDYODCNn"; + sout << "qZy5bb542mHiW4BxDLW3mQ4uKyXFV56IuntOj0rsjRZdrOATffLebKWyn1anGbGdQvJcOZcHvV6h"; + sout << "1Aih5pmFfPxyhA2ba90hh+kPYjhT2QAAAIYcLHepOiMlf+gULTiFNjbeyaz5vgWcGTiwkv4q4ED3"; + sout << "ZLE0lb+l/7qt620m8kgPvUAOLnK6ysWdLpoTtevH7MY5LafF3R/YxEPOo9xgr7zDLrMKLkZILeIb"; + sout << "yegmBAnGeWonZpv3eidKXsbBFhHGw7DoSakVKAT7Jahlmyk6howH8x8VEVpK1QNbaAXt/pFXGrzr"; + sout << "dAO6LjOcrwxvW8iYPjPChEHqwYloPizhrN7k0u0F/oq/urAZxyJ/MP8Tx9afuHY4aocW8FIGmN8f"; + sout << "NSozaWDqqw209VkydVcnibQweQucq14zNKxGZgvnuw3SKeLPAc40Z3fSWsEfG/DImcOJ6QdmY8S8"; + sout << "34MKBXUjR8W0w1P0gs407lSQFRZBEmiOpU3HW+A4BisSe4fcFlzFt+lhP9VNDr7lBHpbAqkg1Ob4"; + sout << "8/uu5w5qut8RiDJp2dYQFY1OieuY1lEN1BFYl4hoUuBQqGBpq3MDn6V40ZzXHl2Rbl+u2tXhNICL"; + sout << "FPOFo0Y+7b0ppLIQGPAd5fBGTfMnibWFP9Bw8KDo9wUbvGslyV0ny303hYXO4TzIliPUUu+nsJRc"; + sout << "pBCeJdyRvQUjr6BjCXrNVSldiB6oCr7e0AfMBU1BErZAhzseZYzzlIavra3tEOdW4xweFxOO1JNp"; + sout << "BqoG5a9gozEd6VQvhurdttGa5jsSm/tEmGDZl0t14nN18hjVW6KCb0u3+kCbcCRRcFz6tZ3+eomJ"; + sout << "TQIBUJu40uNzrLmur+zJyT7+JRRIq/xXSv2R1oJhIiYGe9P99wNd13TuTi5Oyx1b3hHJdysPjIt1"; + sout << "/I8UGCpIQ/FCIEEplSCkhDgedfL9OtDOnU/+bI2cipB8tWSwpwbczIslv49e3+xIqBzLf/4gBDd2"; + sout << "4ZYwe2lHWGL7OXGYCAQbC9ELg2KuRDpLrG2ad4fhUuAnizCLb+q+3Nd7qgkPXweeEMAbSHhp0vuN"; + sout << "XSbVB+cDaHkw7DxxJR0nn4IhdbkuQ54G975t+Wo5uPunpyZ5QvDXijG8NEhIAytpMmhXjkiyUbk/"; + sout << "XmaEnaSTOxOEVh/fBjeGfxHmjn0TgkQlw7GLHa1rB6Lp3y7hSWoZW0RSn0TmVQbWBFHen0cfhsp+"; + sout << "zc7gTzMeFxhMf3ZmBS9pmigcVHPAmU+rOLTtAzaklKumgxf0B+Su8YLZpcvc6Jy8t6rJgCYwIhR+"; + sout << "BiKEN1l6CxFeQtJp+2Do6oEZ4Lfk0DoAnJdV8l5DhkiXiuwX4zPFiKZQtKSBsrp7bq6dqjDphbgL"; + sout << "1A6JahmxkTYcJVOnVWG6LvfHYCIJ6hwDV4J23P+tNZtLCzfHPnDDdyJCmBExsOV1xXzNeGhqZhb3"; + sout << "kSWXntkMAstuTSdgb8vPGYSyEMsTx+pTeXFGv3TN+YCdQTQIovV2f0KTZMxmjoTmlZKueVeFwzSP"; + sout << "S+G65EMEcgQUGloWsFPdIaRzPLYHUktOcJc4bqEDvwIUDLqfjH4zr7kgF6i0KfRUMP4Jrje78S9G"; + sout << "C7o9XJLGWbS32yrdKVlIG+9f+2YQRrZECzhex6WjV69/JSe0jbilwXwsJSKtY6jFik8eE6RCZqN9"; + sout << "c1vOxMcd+5twMMkmnu4yvBEV6Y207UtegGSHaESpEW1MTvbDk/75X46WMSECbeEe8lzlmNaaWyWY"; + sout << "C+siHamjyTgHAvJLEwiNlQ26LS3mJr4HmpOVizbTykoe2WGV/jmUrhhUg2AN0CrS2LqriZolW/JX"; + sout << "jJeehNlJ9Jy4e4NZE2Xn/Q5UC3zVJm3bpbC3JKre7sALksQSQInLf7OoYArfwgYmAdMWhgxRakIv"; + sout << "0M0acqOuNUxmW0k4YEzEncfjSAQ3cAUkaL1XS55I5vm3o1RPiKBjjypzdi8GI7migOK5dRrIctP4"; + sout << "U0S4L6uojFipn9VGhIjGiv0Rg6OcaZO2FDq9SOGmLZ4b92sE7v4rtEt8oxlbbxPdgT6CLQkVCCJh"; + sout << "uvjCIeAmV0wwhaKYzE2TpSkPFrtCPWhBjSulx5OwEGvdgzUq1NAnHBqlc6Oazt+6Hb/b+zpziei2"; + sout << "huprTtScOqJcYzqJYGIwYebjxpVM0f/05a3LodBUWUNUIRNyzJpyEZluHCupDqfEgF7ObUDvtTDO"; + sout << "Zl1aCRWbztt2W0JE7SSCAZbzY+3Dq+m0VUJj1L7NZkg5py0mUHHnz2lDveZjtHr9UBqETzN2xCuw"; + sout << "8LQqbhDKR5I5ThHilGjZjxSN4VUKgBiBNUHXu8D2VN9RHr9xjQzmnSOJPHtcFuf9ioGO43OVtV+I"; + sout << "DvE3YjEt53F37uZBmvTQD3zcBvgRoNF24j6jjTeHE3I/MROKotCIBkxk5AZTF9Awzha4XwP3xOAI"; + sout << "FyhCle92a3102sdU/azu+1n3Jq0ZeifCSicBjrAHgQpVM7afn/yha/YKXqiwingX4pvrN6IRKADv"; + sout << "edJG6C4OttSovas7ayKdaTWURKwQWQ2NQF/24DUEmfGkWgk3e1uFjfdBNFez7/omGkRqvSAoT5pe"; + sout << "rXpbWzb7nigtdMRUdVMdzAHooSObROsbHc7ngUn6sJ6rD88EJZH94kk3adWN6plDDTyRe/eu5Ah0"; + sout << "XnBPmYIgVSYSZ/X6BvS5AZDElImmABuvuLEjOQ3Ydqc1NQIGOgMzWhn/Dc94YMxSOdmHCo5RyGjo"; + sout << "Zn0wntZDa+gKaR8PwzsIi+Q7RvReVKP4xKNPbI3W47v3sh1PeidJXt75OoHyKpoxJJcjwB8me4iZ"; + sout << "ctKBz0QZ5ZBuywtn2Lq69I5nmjHQbpIB+zmPqIxEJOworB7Qg6thkda97KcHsRbMJAA="; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin, sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin, sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + // ---------------------------------------------------------------------------------------- + + }; + + correlation_tracker_tester a; + +// ---------------------------------------------------------------------------------------- + +} + + diff --git a/ml/dlib/dlib/test/crc32.cpp b/ml/dlib/dlib/test/crc32.cpp new file mode 100644 index 000000000..32f67ed3a --- /dev/null +++ b/ml/dlib/dlib/test/crc32.cpp @@ -0,0 +1,74 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.crc32"); + + + class crc32_tester : public tester + { + public: + crc32_tester ( + ) : + tester ("test_crc32", + "Runs tests on the crc32 component.") + {} + + void perform_test ( + ) + { + DLIB_TEST(crc32("davis").get_checksum() == 0x0445527C); + + crc32 c, c2; + DLIB_TEST(c.get_checksum() == 0); + c.add("davis"); + DLIB_TEST(c.get_checksum() == 0x0445527C); + DLIB_TEST(c2.get_checksum() == 0); + c2 = c; + DLIB_TEST(c2.get_checksum() == 0x0445527C); + crc32 c3(c); + DLIB_TEST(c3.get_checksum() == 0x0445527C); + c.add('a'); + c2.add('a'); + c3.add('a'); + DLIB_TEST(c.get_checksum() == 0xB100C606); + DLIB_TEST(c2.get_checksum() == 0xB100C606); + DLIB_TEST(c3.get_checksum() == 0xB100C606); + + + crc32::kernel_1a cold; + DLIB_TEST(cold.get_checksum() == 0); + cold.add("davis"); + DLIB_TEST(cold.get_checksum() == 0x0445527C); + + c.clear(); + DLIB_TEST(c.get_checksum() == 0); + c.add("davis"); + DLIB_TEST(c.get_checksum() == 0x0445527C); + + std::vector buf; + for (int i = 0; i < 4000; ++i) + buf.push_back(i); + DLIB_TEST(crc32(buf) == 492662731); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/create_iris_datafile.cpp b/ml/dlib/dlib/test/create_iris_datafile.cpp new file mode 100644 index 000000000..1e19d2aac --- /dev/null +++ b/ml/dlib/dlib/test/create_iris_datafile.cpp @@ -0,0 +1,65 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include + +namespace +{ + // This function returns the contents of the file 'iris.scale' + const std::string get_decoded_string() + { + dlib::base64::kernel_1a base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'iris.scale' we want to decode and return. + sout << "MU66cCmT9lCXWJXwhdfOGELwlyExClbHEF1s9XoqxNDV7o8AdVVHws/C9oIKO5EShH1lI/QFTWk3"; + sout << "8EUdVpSw/NpZCUa7O9nq5uO6SE0gfRAyryH+pfIVL9jPiQi8rBdagDf4kUd4eggz9glwYnKEE+US"; + sout << "K4GUBnW33YDf/jMF2GIBLNvz69yGJj8RC5rOUeJxR4mlHrDmnEfgRSdFIfXk4OQ4V/XbOsE1bnhG"; + sout << "fmACcu7nYv6M/043Z6o8oaBeoJ2XK/9UqOWGFOwfVpQ46fz1a0oTlOzyDbbzMiniLr8z5P/VYwYd"; + sout << "iAE70MwxHXs6Ga3zMmD/h1WxB/uRRph39B1lPN1UXC7U6SIatmtGWY+JYpwBk6raAnR3sblTFBNs"; + sout << "UdPW+1a7AxinR0NZO6YEiCFy8lbpfPRZNAr5ENqPbD2DZtkHk3L8ARxSoFBgqPa8aO3fFow7rVxF"; + sout << "xIJ2TxcHS84+BtH7KvtWfH7kUPOZLQ+Ohqghn9I57IeMl7E3aoTRTiVv3P2twAbP5Y+ZaAUoU7CK"; + sout << "c9FptjKgMClUkuWxA7tGUEp069PqGT8NbI+yxorh/iVhkVhuGAzgjjXYS/D26OGj4bzF6mtRbnms"; + sout << "Y2OYlF7QqhawZaHLtmZ6xLhR2F8p/0nrbpAz2brQLNKgQAMvU9rTZ0XYpuJNbRSsARkRDorPopDO"; + sout << "kKNUORfkh2zfIytVToQ9tZ9W2LkfGZdWjJu/wEKjPDAU55q3bCfKOUk12tjq0sq/7qjUWJRcLSCu"; + sout << "bqo8EzaKJj3cTXVgXXLHP6WEOPZ9vShuxQUu1JWkh8YEinjwFSyA6UnAKqPtN/HsBgv8YbnfnY/q"; + sout << "e5JvUYWbs3Lk9enlhcI0vEVTV5f0GMjdkW87l3cWgmXJqiljJDREWEdKZJQ0rGBU/gW5kO3SAS1W"; + sout << "OETVJG2kJD8Ib7hT15Mu2lOVNQYFri6O3yWtp5/NLHsYXoDKIYrxoJtM9+GkprVwRuhDcwxE+eQa"; + sout << "pp5nC8qj38ameQHaJR2hJCuW2nvr4Wwm0ploF00ZP9cS9YznCO52cueUQX0+zil7bU++jghqSGP5"; + sout << "+JyRzWUWWbDhnCyanej2Y3sqfZ3o2kuUjaAgZFz5pLqK64uACjztp4bQFsaMRdc+OCV2uItqoaRg"; + sout << "a6u7/VrvS+ZigwcGWDjXSKev334f8ZqQQIR5hljdeseGuw7/5XySzUrgc8lCOvMa0pKNn9Nl8W/W"; + sout << "vbKz1VwA"; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } +} + +namespace dlib +{ + void create_iris_datafile ( + ) + { + std::ofstream fout("iris.scale"); + fout << get_decoded_string(); + } +} + diff --git a/ml/dlib/dlib/test/create_iris_datafile.h b/ml/dlib/dlib/test/create_iris_datafile.h new file mode 100644 index 000000000..805291f07 --- /dev/null +++ b/ml/dlib/dlib/test/create_iris_datafile.h @@ -0,0 +1,19 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CREATE_IRIS_DAtAFILE_Hh_ +#define DLIB_CREATE_IRIS_DAtAFILE_Hh_ + +namespace dlib +{ + void create_iris_datafile ( + ); + /*! + ensures + - Creates a local file called iris.scale that contains the + 150 samples from the 3-class Iris dataset from the UCI + repository. The file will be in LIBSVM format (it was + originally downloaded from the LIBSVM website). + !*/ +} + +#endif // DLIB_CREATE_IRIS_DAtAFILE_Hh_ diff --git a/ml/dlib/dlib/test/cublas.cpp b/ml/dlib/dlib/test/cublas.cpp new file mode 100644 index 000000000..3f3cb09dc --- /dev/null +++ b/ml/dlib/dlib/test/cublas.cpp @@ -0,0 +1,198 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../dnn/tensor_tools.h" + +#include "tester.h" + +// We only do these tests if CUDA is available to test in the first place. +#ifdef DLIB_USE_CUDA + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.cublas"); + + + void test_inv() + { + tt::tensor_rand rnd; + dlib::tt::inv tinv; + dlib::cuda::inv cinv; + resizable_tensor minv1, minv2; + for (int n = 1; n < 20; ++n) + { + print_spinner(); + resizable_tensor m(n,n); + rnd.fill_uniform(m); + + tinv(m, minv1); + cinv(m, minv2); + matrix mref = inv(mat(m)); + DLIB_TEST_MSG(mean(abs(mref-mat(minv1)))/mean(abs(mref)) < 1e-5, mean(abs(mref-mat(minv1)))/mean(abs(mref)) <<" n: " << n); + DLIB_TEST_MSG(mean(abs(mref-mat(minv2)))/mean(abs(mref)) < 1e-5, mean(abs(mref-mat(minv2)))/mean(abs(mref)) <<" n: " << n); + } + } + + + class cublas_tester : public tester + { + public: + cublas_tester ( + ) : + tester ("test_cublas", + "Runs tests on the cuBLAS bindings.") + {} + + void perform_test ( + ) + { + test_inv(); + { + resizable_tensor a(4,3), b(3,4), c(3,3); + + c = 1; + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + + matrix truth = 2*mat(c)+trans(mat(a))*trans(mat(b)); + + a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device(); + cuda::gemm(2, c, 1, a, true, b, true); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + { + resizable_tensor a(4,3), b(4,3), c(3,3); + + c = 1; + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + + matrix truth = 2*mat(c)+trans(mat(a))*mat(b); + + a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device(); + cuda::gemm(2, c, 1, a, true, b, false); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + { + resizable_tensor a(3,4), b(3,4), c(3,3); + + c = 1; + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + + matrix truth = 2*mat(c)+mat(a)*trans(mat(b)); + + a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device(); + cuda::gemm(2, c, 1, a, false, b, true); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + { + resizable_tensor a(3,4), b(3,4), c(3,3); + + c = 1; + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + + matrix truth = mat(c)+mat(a)*trans(mat(b)); + + a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device(); + cuda::gemm(1, c, 1, a, false, b, true); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + { + resizable_tensor a(3,4), b(4,3), c(3,3); + + c = 1; + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + + matrix truth = 2*mat(c)+mat(a)*mat(b); + + a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device(); + cuda::gemm(2, c, 1, a, false, b, false); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + { + resizable_tensor a(3,4), b(4,3), c(3,3); + + c = std::numeric_limits::infinity(); + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + a.async_copy_to_device(); b.async_copy_to_device(); c.async_copy_to_device(); + + matrix truth = mat(a)*mat(b); + + cuda::gemm(0, c, 1, a, false, b, false); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + { + resizable_tensor a(3,4), b(4,4), c(3,4); + + c = 1; + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + + matrix truth = 2*mat(c)+mat(a)*mat(b); + + cuda::gemm(2, c, 1, a, false, b, false); + DLIB_TEST(get_rect(truth) == get_rect(mat(c))); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + { + resizable_tensor a(4,3), b(4,4), c(3,4); + + c = 1; + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + + matrix truth = 2*mat(c)+trans(mat(a))*mat(b); + + cuda::gemm(2, c, 1, a, true, b, false); + DLIB_TEST(get_rect(truth) == get_rect(mat(c))); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + { + resizable_tensor a(4,3), b(4,5), c(3,5); + + c = 1; + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + + matrix truth = 2*mat(c)+trans(mat(a))*mat(b); + + cuda::gemm(2, c, 1, a, true, b, false); + DLIB_TEST(get_rect(truth) == get_rect(mat(c))); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + { + resizable_tensor a(4,3), b(4,5), c(3,5); + + c = std::numeric_limits::infinity(); + a = matrix_cast(gaussian_randm(a.num_samples(),a.size()/a.num_samples())); + b = matrix_cast(gaussian_randm(b.num_samples(),b.size()/b.num_samples())); + + matrix truth = trans(mat(a))*mat(b); + + cuda::gemm(0, c, 1, a, true, b, false); + DLIB_TEST(get_rect(truth) == get_rect(mat(c))); + DLIB_TEST(max(abs(truth-mat(c))) < 1e-6); + } + } + } a; + +} + +#endif // DLIB_USE_CUDA + diff --git a/ml/dlib/dlib/test/data_io.cpp b/ml/dlib/dlib/test/data_io.cpp new file mode 100644 index 000000000..8ced88008 --- /dev/null +++ b/ml/dlib/dlib/test/data_io.cpp @@ -0,0 +1,227 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include "create_iris_datafile.h" +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.data_io"); + + + class test_data_io : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + test_data_io ( + ) : + tester ( + "test_data_io", // the command line argument name for this test + "Run tests on the data_io stuff.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + + template + void run_test() + { + print_spinner(); + + typedef typename sample_type::value_type::second_type scalar_type; + + std::vector samples; + std::vector labels; + + load_libsvm_formatted_data("iris.scale",samples, labels); + save_libsvm_formatted_data("iris.scale2", samples, labels); + + DLIB_TEST(samples.size() == 150); + DLIB_TEST(labels.size() == 150); + DLIB_TEST(max_index_plus_one(samples) == 5); + fix_nonzero_indexing(samples); + DLIB_TEST(max_index_plus_one(samples) == 4); + + load_libsvm_formatted_data("iris.scale2",samples, labels); + + DLIB_TEST(samples.size() == 150); + DLIB_TEST(labels.size() == 150); + + DLIB_TEST(max_index_plus_one(samples) == 5); + fix_nonzero_indexing(samples); + DLIB_TEST(max_index_plus_one(samples) == 4); + + one_vs_one_trainer,scalar_type> trainer; + + typedef sparse_linear_kernel kernel_type; + trainer.set_trainer(krr_trainer()); + + randomize_samples(samples, labels); + matrix cv = cross_validate_multiclass_trainer(trainer, samples, labels, 4); + + dlog << LINFO << "confusion matrix: \n" << cv; + const scalar_type cv_accuracy = sum(diag(cv))/sum(cv); + dlog << LINFO << "cv accuracy: " << cv_accuracy; + DLIB_TEST(cv_accuracy > 0.97); + + + + + { + print_spinner(); + typedef matrix dsample_type; + std::vector dsamples = sparse_to_dense(samples); + DLIB_TEST(dsamples.size() == 150); + DLIB_TEST(dsamples[0].size() == 4); + DLIB_TEST(max_index_plus_one(dsamples) == 4); + + one_vs_one_trainer,scalar_type> trainer; + + typedef linear_kernel kernel_type; + trainer.set_trainer(rr_trainer()); + + cv = cross_validate_multiclass_trainer(trainer, dsamples, labels, 4); + + dlog << LINFO << "dense confusion matrix: \n" << cv; + const scalar_type cv_accuracy = sum(diag(cv))/sum(cv); + dlog << LINFO << "dense cv accuracy: " << cv_accuracy; + DLIB_TEST(cv_accuracy > 0.97); + } + + } + + + void test_sparse_to_dense() + { + { + std::map temp; + + matrix m, m2; + + m = sparse_to_dense(m); + DLIB_TEST(m.size() == 0); + m.set_size(2,1); + m = 1, 2; + m2 = sparse_to_dense(m); + DLIB_TEST(m == m2); + m2 = sparse_to_dense(m,1); + DLIB_TEST(m2.size() == 1); + DLIB_TEST(m2(0,0) == 1); + m2 = sparse_to_dense(m,0); + DLIB_TEST(m2.size() == 0); + + temp[3] = 2; + temp[5] = 4; + m2 = sparse_to_dense(temp); + m.set_size(6); + m = 0,0,0,2,0,4; + DLIB_TEST(m2 == m); + + m2 = sparse_to_dense(temp, 5); + m.set_size(5); + m = 0,0,0,2,0; + DLIB_TEST(m2 == m); + + m2 = sparse_to_dense(temp, 7); + m.set_size(7); + m = 0,0,0,2,0,4,0; + DLIB_TEST(m2 == m); + + std::vector > > vects; + + std::vector > v; + v.push_back(make_pair(5,2)); + v.push_back(make_pair(3,1)); + v.push_back(make_pair(5,2)); + v.push_back(make_pair(3,1)); + v = make_sparse_vector(v); + vects.push_back(v); + vects.push_back(v); + vects.push_back(v); + vects.push_back(v); + DLIB_TEST(max_index_plus_one(v) == 6); + m2 = sparse_to_dense(v); + m.set_size(6); + m = 0,0,0,2,0,4; + DLIB_TEST_MSG(m2 == m, m2 << "\n\n" << m ); + + m2 = sparse_to_dense(v,7); + m.set_size(7); + m = 0,0,0,2,0,4,0; + DLIB_TEST(m2 == m); + + m2 = sparse_to_dense(v,5); + m.set_size(5); + m = 0,0,0,2,0; + DLIB_TEST(m2 == m); + + v.clear(); + m2 = sparse_to_dense(v); + DLIB_TEST(m2.size() == 0); + + + std::vector > mvects = sparse_to_dense(vects); + DLIB_TEST(mvects.size() == 4); + m.set_size(6); + m = 0,0,0,2,0,4; + DLIB_TEST(mvects[0] == m); + DLIB_TEST(mvects[1] == m); + DLIB_TEST(mvects[2] == m); + DLIB_TEST(mvects[3] == m); + + + mvects = sparse_to_dense(vects, 7); + DLIB_TEST(mvects.size() == 4); + m.set_size(7); + m = 0,0,0,2,0,4,0; + DLIB_TEST(mvects[0] == m); + DLIB_TEST(mvects[1] == m); + DLIB_TEST(mvects[2] == m); + DLIB_TEST(mvects[3] == m); + + mvects = sparse_to_dense(vects, 5); + DLIB_TEST(mvects.size() == 4); + m.set_size(5); + m = 0,0,0,2,0; + DLIB_TEST(mvects[0] == m); + DLIB_TEST(mvects[1] == m); + DLIB_TEST(mvects[2] == m); + DLIB_TEST(mvects[3] == m); + + } + } + + + void perform_test ( + ) + { + print_spinner(); + create_iris_datafile(); + + test_sparse_to_dense(); + + run_test >(); + run_test >(); + run_test > >(); + run_test > >(); + } + }; + + test_data_io a; + +} + + diff --git a/ml/dlib/dlib/test/directed_graph.cpp b/ml/dlib/dlib/test/directed_graph.cpp new file mode 100644 index 000000000..b97976b96 --- /dev/null +++ b/ml/dlib/dlib/test/directed_graph.cpp @@ -0,0 +1,541 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +// This is called an unnamed-namespace and it has the effect of making everything inside this file "private" +// so that everything you declare will have static linkage. Thus we won't have any multiply +// defined symbol errors coming out of the linker when we try to compile the test suite. +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + // Declare the logger we will use in this test. The name of the tester + // should start with "test." + logger dlog("test.directed_graph"); + + template < + typename directed_graph + > + void directed_graph_test ( + ) + /*! + requires + - directed_graph is an implementation of directed_graph/directed_graph_kernel_abstract.h + is instantiated with int + ensures + - runs tests on directed_graph for compliance with the specs + !*/ + { + print_spinner(); + + COMPILE_TIME_ASSERT(is_directed_graph::value == true); + directed_graph a, b; + dlib::set::compare_1b_c s; + + DLIB_TEST(graph_contains_directed_cycle(a) == false); + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + + DLIB_TEST(a.number_of_nodes() == 0); + + DLIB_TEST(graph_contains_length_one_cycle(a) == false); + + a.set_number_of_nodes(5); + DLIB_TEST(graph_contains_length_one_cycle(a) == false); + DLIB_TEST(graph_is_connected(a) == false); + DLIB_TEST(graph_contains_directed_cycle(a) == false); + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + DLIB_TEST(a.number_of_nodes() == 5); + + for (int i = 0; i < 5; ++i) + { + a.node(i).data = i; + DLIB_TEST(a.node(i).index() == (unsigned int)i); + } + + a.remove_node(1); + + DLIB_TEST(a.number_of_nodes() == 4); + + + // make sure that only the number with data == 1 was remove + int count = 0; + for (int i = 0; i < 4; ++i) + { + count += a.node(i).data; + DLIB_TEST(a.node(i).number_of_children() == 0); + DLIB_TEST(a.node(i).number_of_parents() == 0); + DLIB_TEST(a.node(i).index() == (unsigned int)i); + } + + DLIB_TEST(count == 9); + + DLIB_TEST(graph_contains_directed_cycle(a) == false); + + a.add_edge(1,1); + DLIB_TEST(graph_contains_length_one_cycle(a) == true); + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + + DLIB_TEST(graph_contains_directed_cycle(a) == true); + + a.add_edge(1,2); + + DLIB_TEST(graph_contains_directed_cycle(a) == true); + + DLIB_TEST(a.node(1).number_of_children() == 2); + DLIB_TEST(a.node(1).number_of_parents() == 1); + DLIB_TEST_MSG(a.node(1).parent(0).index() == 1,""); + + DLIB_TEST_MSG(a.node(1).child(0).index() + a.node(1).child(1).index() == 3,""); + DLIB_TEST(a.node(2).number_of_children() == 0); + DLIB_TEST(a.node(2).number_of_parents() == 1); + DLIB_TEST(a.node(2).index() == 2); + + int val = a.node(1).data; + a.remove_node(1); + DLIB_TEST(graph_contains_length_one_cycle(a) == false); + + DLIB_TEST(graph_contains_directed_cycle(a) == false); + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + + + DLIB_TEST(a.number_of_nodes() == 3); + + count = 0; + for (int i = 0; i < 3; ++i) + { + count += a.node(i).data; + DLIB_TEST(a.node(i).number_of_children() == 0); + DLIB_TEST(a.node(i).number_of_parents() == 0); + DLIB_TEST(a.node(i).index() == (unsigned int)i); + } + DLIB_TEST(count == 9-val); + + + val = a.add_node(); + DLIB_TEST(val == 3); + DLIB_TEST(a.number_of_nodes() == 4); + + for (int i = 0; i < 4; ++i) + { + a.node(i).data = i; + DLIB_TEST(a.node(i).index() == (unsigned int)i); + } + + for (int i = 0; i < 4; ++i) + { + DLIB_TEST(a.node(i).data == i); + DLIB_TEST(a.node(i).index() == (unsigned int)i); + } + + a.add_edge(0, 1); + a.add_edge(0, 2); + DLIB_TEST(graph_is_connected(a) == false); + a.add_edge(1, 3); + DLIB_TEST(graph_is_connected(a) == true); + a.add_edge(2, 3); + DLIB_TEST(graph_is_connected(a) == true); + DLIB_TEST(graph_contains_length_one_cycle(a) == false); + + DLIB_TEST(a.has_edge(0, 1)); + DLIB_TEST(a.has_edge(0, 2)); + DLIB_TEST(a.has_edge(1, 3)); + DLIB_TEST(a.has_edge(2, 3)); + + DLIB_TEST(!a.has_edge(1, 0)); + DLIB_TEST(!a.has_edge(2, 0)); + DLIB_TEST(!a.has_edge(3, 1)); + DLIB_TEST(!a.has_edge(3, 2)); + + DLIB_TEST(a.node(0).number_of_parents() == 0); + DLIB_TEST(a.node(0).number_of_children() == 2); + + DLIB_TEST(a.node(1).number_of_parents() == 1); + DLIB_TEST(a.node(1).number_of_children() == 1); + DLIB_TEST(a.node(1).child(0).index() == 3); + DLIB_TEST(a.node(1).parent(0).index() == 0); + + DLIB_TEST(a.node(2).number_of_parents() == 1); + DLIB_TEST(a.node(2).number_of_children() == 1); + DLIB_TEST(a.node(2).child(0).index() == 3); + DLIB_TEST(a.node(2).parent(0).index() == 0); + + DLIB_TEST(a.node(3).number_of_parents() == 2); + DLIB_TEST(a.node(3).number_of_children() == 0); + + DLIB_TEST(graph_contains_directed_cycle(a) == false); + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + + a.remove_edge(0,1); + + DLIB_TEST(graph_contains_directed_cycle(a) == false); + + DLIB_TEST(!a.has_edge(0, 1)); + DLIB_TEST(a.has_edge(0, 2)); + DLIB_TEST(a.has_edge(1, 3)); + DLIB_TEST(a.has_edge(2, 3)); + + DLIB_TEST(!a.has_edge(1, 0)); + DLIB_TEST(!a.has_edge(2, 0)); + DLIB_TEST(!a.has_edge(3, 1)); + DLIB_TEST(!a.has_edge(3, 2)); + + + DLIB_TEST(a.node(0).number_of_parents() == 0); + DLIB_TEST(a.node(0).number_of_children() == 1); + + DLIB_TEST(a.node(1).number_of_parents() == 0); + DLIB_TEST(a.node(1).number_of_children() == 1); + DLIB_TEST(a.node(1).child(0).index() == 3); + + DLIB_TEST(a.node(2).number_of_parents() == 1); + DLIB_TEST(a.node(2).number_of_children() == 1); + DLIB_TEST(a.node(2).child(0).index() == 3); + DLIB_TEST(a.node(2).parent(0).index() == 0); + + DLIB_TEST(a.node(3).number_of_parents() == 2); + DLIB_TEST(a.node(3).number_of_children() == 0); + + for (int i = 0; i < 4; ++i) + { + DLIB_TEST(a.node(i).data == i); + DLIB_TEST(a.node(i).index() == (unsigned int)i); + } + + + + swap(a,b); + + DLIB_TEST(a.number_of_nodes() == 0); + DLIB_TEST(b.number_of_nodes() == 4); + DLIB_TEST(b.node(0).number_of_parents() == 0); + DLIB_TEST(b.node(0).number_of_children() == 1); + + DLIB_TEST(b.node(1).number_of_parents() == 0); + DLIB_TEST(b.node(1).number_of_children() == 1); + DLIB_TEST(b.node(1).child(0).index() == 3); + + DLIB_TEST(b.node(2).number_of_parents() == 1); + DLIB_TEST(b.node(2).number_of_children() == 1); + DLIB_TEST(b.node(2).child(0).index() == 3); + DLIB_TEST(b.node(2).parent(0).index() == 0); + + DLIB_TEST(b.node(3).number_of_parents() == 2); + DLIB_TEST(b.node(3).number_of_children() == 0); + b.node(0).child_edge(0) = static_cast(b.node(0).child(0).index()+1); + b.node(1).child_edge(0) = static_cast(b.node(1).child(0).index()+1); + b.node(2).child_edge(0) = static_cast(b.node(2).child(0).index()+1); + + DLIB_TEST_MSG(b.node(0).child_edge(0) == b.node(0).child(0).index()+1, + b.node(0).child_edge(0) << " " << b.node(0).child(0).index()+1); + DLIB_TEST_MSG(b.node(1).child_edge(0) == b.node(1).child(0).index()+1, + b.node(1).child_edge(0) << " " << b.node(1).child(0).index()+1); + DLIB_TEST_MSG(b.node(2).child_edge(0) == b.node(2).child(0).index()+1, + b.node(2).child_edge(0) << " " << b.node(2).child(0).index()+1); + + DLIB_TEST_MSG(b.node(2).parent_edge(0) == 2+1, + b.node(2).parent_edge(0) << " " << 2+1); + DLIB_TEST_MSG(b.node(3).parent_edge(0) == 3+1, + b.node(3).parent_edge(0) << " " << 3+1); + DLIB_TEST_MSG(b.node(3).parent_edge(1) == 3+1, + b.node(3).parent_edge(1) << " " << 3+1); + + ostringstream sout; + + serialize(b, sout); + + istringstream sin(sout.str()); + + a.set_number_of_nodes(20); + DLIB_TEST(a.number_of_nodes() == 20); + deserialize(a, sin); + DLIB_TEST(a.number_of_nodes() == 4); + + DLIB_TEST(!a.has_edge(0, 1)); + DLIB_TEST(a.has_edge(0, 2)); + DLIB_TEST(a.has_edge(1, 3)); + DLIB_TEST(a.has_edge(2, 3)); + + DLIB_TEST(!a.has_edge(1, 0)); + DLIB_TEST(!a.has_edge(2, 0)); + DLIB_TEST(!a.has_edge(3, 1)); + DLIB_TEST(!a.has_edge(3, 2)); + + DLIB_TEST_MSG(a.node(0).child_edge(0) == a.node(0).child(0).index()+1, + a.node(0).child_edge(0) << " " << a.node(0).child(0).index()+1); + DLIB_TEST_MSG(a.node(1).child_edge(0) == a.node(1).child(0).index()+1, + a.node(1).child_edge(0) << " " << a.node(1).child(0).index()+1); + DLIB_TEST_MSG(a.node(2).child_edge(0) == a.node(2).child(0).index()+1, + a.node(2).child_edge(0) << " " << a.node(2).child(0).index()+1); + DLIB_TEST_MSG(a.node(2).parent_edge(0) == 2+1, + a.node(2).parent_edge(0) << " " << 2+1); + DLIB_TEST_MSG(a.node(3).parent_edge(0) == 3+1, + a.node(3).parent_edge(0) << " " << 3+1); + DLIB_TEST_MSG(a.node(3).parent_edge(1) == 3+1, + a.node(3).parent_edge(1) << " " << 3+1); + + + + for (int i = 0; i < 4; ++i) + { + DLIB_TEST(a.node(i).data == i); + DLIB_TEST(a.node(i).index() == (unsigned int)i); + } + + + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + + DLIB_TEST(b.number_of_nodes() == 4); + DLIB_TEST(b.node(0).number_of_parents() == 0); + DLIB_TEST(b.node(0).number_of_children() == 1); + + DLIB_TEST(b.node(1).number_of_parents() == 0); + DLIB_TEST(b.node(1).number_of_children() == 1); + DLIB_TEST(b.node(1).child(0).index() == 3); + + DLIB_TEST(b.node(2).number_of_parents() == 1); + DLIB_TEST(b.node(2).number_of_children() == 1); + DLIB_TEST(b.node(2).child(0).index() == 3); + DLIB_TEST(b.node(2).parent(0).index() == 0); + + DLIB_TEST(b.node(3).number_of_parents() == 2); + DLIB_TEST(b.node(3).number_of_children() == 0); + + + DLIB_TEST(a.number_of_nodes() == 4); + DLIB_TEST(a.node(0).number_of_parents() == 0); + DLIB_TEST(a.node(0).number_of_children() == 1); + + DLIB_TEST(a.node(1).number_of_parents() == 0); + DLIB_TEST(a.node(1).number_of_children() == 1); + DLIB_TEST(a.node(1).child(0).index() == 3); + + DLIB_TEST(a.node(2).number_of_parents() == 1); + DLIB_TEST(a.node(2).number_of_children() == 1); + DLIB_TEST(a.node(2).child(0).index() == 3); + DLIB_TEST(a.node(2).parent(0).index() == 0); + + DLIB_TEST(a.node(3).number_of_parents() == 2); + DLIB_TEST(a.node(3).number_of_children() == 0); + + DLIB_TEST(a.number_of_nodes() == 4); + a.clear(); + DLIB_TEST(a.number_of_nodes() == 0); + + + DLIB_TEST(graph_contains_directed_cycle(a) == false); + + a.set_number_of_nodes(10); + + DLIB_TEST(graph_contains_directed_cycle(a) == false); + + a.add_edge(0,1); + a.add_edge(1,2); + a.add_edge(1,3); + a.add_edge(2,4); + a.add_edge(3,4); + a.add_edge(4,5); + a.add_edge(5,1); + + DLIB_TEST(graph_contains_directed_cycle(a) == true); + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + + a.remove_edge(5,1); + + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + DLIB_TEST(graph_contains_directed_cycle(a) == false); + a.add_edge(7,8); + DLIB_TEST(graph_contains_directed_cycle(a) == false); + a.add_edge(8,7); + DLIB_TEST(graph_contains_directed_cycle(a) == true); + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + + + a.clear(); + /* + Make a graph that looks like: + 0 1 + \ / + 2 + | + 3 + */ + a.set_number_of_nodes(4); + a.add_edge(0,2); + a.add_edge(1,2); + a.add_edge(2,3); + for (unsigned long i = 0; i < 4; ++i) + a.node(i).data = i; + + graph::kernel_1a_c g; + create_moral_graph(a,g); + + graph::compare_1b_c, dlib::set::compare_1a_c>::kernel_1a_c join_tree; + dlib::set::compare_1b_c>::kernel_1b_c sos; + + create_join_tree(g, join_tree); + DLIB_TEST(is_join_tree(g, join_tree)); + DLIB_TEST(join_tree.number_of_nodes() == 2); + DLIB_TEST(graph_contains_undirected_cycle(join_tree) == false); + DLIB_TEST(graph_is_connected(join_tree) == true); + + unsigned long temp; + triangulate_graph_and_find_cliques(g,sos); + + temp = 2; s.add(temp); + temp = 3; s.add(temp); + DLIB_TEST(sos.is_member(s)); + s.clear(); + temp = 0; s.add(temp); + temp = 1; s.add(temp); + temp = 2; s.add(temp); + DLIB_TEST(sos.is_member(s)); + DLIB_TEST(sos.size() == 2); + DLIB_TEST(sos.is_member(join_tree.node(0).data)); + DLIB_TEST(sos.is_member(join_tree.node(1).data)); + + + s.clear(); + temp = 0; s.add(temp); + DLIB_TEST(is_clique(g,s) == true); + DLIB_TEST(is_maximal_clique(g,s) == false); + temp = 3; s.add(temp); + DLIB_TEST(is_clique(g,s) == false); + s.destroy(3); + DLIB_TEST(is_clique(g,s) == true); + temp = 2; s.add(temp); + DLIB_TEST(is_clique(g,s) == true); + DLIB_TEST(is_maximal_clique(g,s) == false); + temp = 1; s.add(temp); + DLIB_TEST(is_clique(g,s) == true); + DLIB_TEST(is_maximal_clique(g,s) == true); + s.clear(); + DLIB_TEST(is_clique(g,s) == true); + temp = 3; s.add(temp); + DLIB_TEST(is_clique(g,s) == true); + temp = 2; s.add(temp); + DLIB_TEST(is_clique(g,s) == true); + DLIB_TEST(is_maximal_clique(g,s) == true); + + + DLIB_TEST(a.number_of_nodes() == 4); + DLIB_TEST(g.number_of_nodes() == 4); + for (unsigned long i = 0; i < 4; ++i) + DLIB_TEST( a.node(i).data == (int)i); + DLIB_TEST(g.has_edge(0,1)); + DLIB_TEST(g.has_edge(0,2)); + DLIB_TEST(g.has_edge(1,2)); + DLIB_TEST(g.has_edge(3,2)); + DLIB_TEST(g.has_edge(0,3) == false); + DLIB_TEST(g.has_edge(1,3) == false); + + } + + + void test_copy() + { + { + directed_graph::kernel_1a_c a,b; + + a.set_number_of_nodes(3); + a.node(0).data = 1; + a.node(1).data = 2; + a.node(2).data = 3; + a.add_edge(0,1); + a.add_edge(1,0); + a.add_edge(0,2); + edge(a,0,1) = 4; + edge(a,1,0) = 3; + edge(a,0,2) = 5; + + a.add_edge(0,0); + edge(a,0,0) = 9; + copy_graph(a, b); + + DLIB_TEST(b.number_of_nodes() == 3); + DLIB_TEST(b.node(0).data == 1); + DLIB_TEST(b.node(1).data == 2); + DLIB_TEST(b.node(2).data == 3); + DLIB_TEST(edge(b,0,1) == 4); + DLIB_TEST(edge(b,1,0) == 3); + DLIB_TEST(edge(b,0,2) == 5); + DLIB_TEST(edge(b,0,0) == 9); + } + { + directed_graph::kernel_1a_c a,b; + + a.set_number_of_nodes(4); + a.node(0).data = 1; + a.node(1).data = 2; + a.node(2).data = 3; + a.node(3).data = 8; + a.add_edge(0,1); + a.add_edge(0,2); + a.add_edge(2,3); + a.add_edge(3,2); + edge(a,0,1) = 4; + edge(a,0,2) = 5; + edge(a,2,3) = 6; + edge(a,3,2) = 3; + + copy_graph(a, b); + + DLIB_TEST(b.number_of_nodes() == 4); + DLIB_TEST(b.node(0).data == 1); + DLIB_TEST(b.node(1).data == 2); + DLIB_TEST(b.node(2).data == 3); + DLIB_TEST(b.node(3).data == 8); + DLIB_TEST(edge(b,0,1) == 4); + DLIB_TEST(edge(b,0,2) == 5); + DLIB_TEST(edge(b,2,3) == 6); + DLIB_TEST(edge(b,3,2) == 3); + } + } + + + + class directed_graph_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a test for the directed_graph object. When it is constructed + it adds itself into the testing framework. The command line switch is + specified as test_directed_graph by passing that string to the tester constructor. + !*/ + public: + directed_graph_tester ( + ) : + tester ("test_directed_graph", + "Runs tests on the directed_graph component.") + {} + + void perform_test ( + ) + { + test_copy(); + + dlog << LINFO << "testing kernel_1a_c"; + directed_graph_test::kernel_1a_c>(); + + dlog << LINFO << "testing kernel_1a"; + directed_graph_test::kernel_1a>(); + } + } a; + + +} + + diff --git a/ml/dlib/dlib/test/discriminant_pca.cpp b/ml/dlib/dlib/test/discriminant_pca.cpp new file mode 100644 index 000000000..2a7aa61d1 --- /dev/null +++ b/ml/dlib/dlib/test/discriminant_pca.cpp @@ -0,0 +1,365 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.discriminant_pca"); + + using dlib::equal; + + class discriminant_pca_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + discriminant_pca_tester ( + ) : + tester ( + "test_discriminant_pca", // the command line argument name for this test + "Run tests on the discriminant_pca object.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + thetime = 1407805946;// time(0); + } + + time_t thetime; + dlib::rand rnd; + + template + void test1() + { + + dpca_type dpca, dpca2, dpca3; + + DLIB_TEST(dpca.in_vector_size() == 0); + DLIB_TEST(dpca.between_class_weight() == 1); + DLIB_TEST(dpca.within_class_weight() == 1); + + // generate a bunch of 4 dimensional vectors and compute the normal PCA transformation matrix + // and just make sure it is a unitary matrix as it should be. + for (int i = 0; i < 5000; ++i) + { + dpca.add_to_total_variance(randm(4,1,rnd)); + DLIB_TEST(dpca.in_vector_size() == 4); + } + + + matrix mat = dpca.dpca_matrix(1); + + DLIB_TEST(equal(mat*trans(mat), identity_matrix(4))); + + mat = dpca.dpca_matrix(0.9); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(mat.nr()))); + + matrix eig; + dpca.dpca_matrix(mat, eig, 1); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(4))); + // check that all eigen values are grater than 0 + DLIB_TEST(min(eig > 0) == 1); + DLIB_TEST(eig.size() == mat.nr()); + DLIB_TEST(is_col_vector(eig)); + // check that the eigenvalues are sorted + double last = eig(0); + for (long i = 1; i < eig.size(); ++i) + { + DLIB_TEST(last >= eig(i)); + } + + { + matrix mat = dpca.dpca_matrix_of_size(4); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(4))); + } + { + matrix mat = dpca.dpca_matrix_of_size(3); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(3))); + } + + + dpca.set_within_class_weight(5); + dpca.set_between_class_weight(6); + + DLIB_TEST(dpca.in_vector_size() == 4); + DLIB_TEST(dpca.within_class_weight() == 5); + DLIB_TEST(dpca.between_class_weight() == 6); + + + ostringstream sout; + serialize(dpca, sout); + istringstream sin(sout.str()); + deserialize(dpca2, sin); + + // now make sure the serialization worked + DLIB_TEST(dpca.in_vector_size() == 4); + DLIB_TEST(dpca.within_class_weight() == 5); + DLIB_TEST(dpca.between_class_weight() == 6); + DLIB_TEST(dpca2.in_vector_size() == 4); + DLIB_TEST(dpca2.within_class_weight() == 5); + DLIB_TEST(dpca2.between_class_weight() == 6); + DLIB_TEST(equal(dpca.dpca_matrix(), dpca2.dpca_matrix(), 1e-10)); + DLIB_TEST(equal(mat, dpca2.dpca_matrix(1), 1e-10)); + DLIB_TEST(equal(dpca.dpca_matrix(1), mat, 1e-10)); + + // now test swap + dpca2.swap(dpca3); + DLIB_TEST(dpca2.in_vector_size() == 0); + DLIB_TEST(dpca2.between_class_weight() == 1); + DLIB_TEST(dpca2.within_class_weight() == 1); + + DLIB_TEST(dpca3.in_vector_size() == 4); + DLIB_TEST(dpca3.within_class_weight() == 5); + DLIB_TEST(dpca3.between_class_weight() == 6); + DLIB_TEST(equal(mat, dpca3.dpca_matrix(1), 1e-10)); + DLIB_TEST((dpca3 + dpca3).in_vector_size() == 4); + DLIB_TEST((dpca3 + dpca3).within_class_weight() == 5); + DLIB_TEST((dpca3 + dpca3).between_class_weight() == 6); + + dpca.clear(); + + DLIB_TEST(dpca.in_vector_size() == 0); + DLIB_TEST(dpca.between_class_weight() == 1); + DLIB_TEST(dpca.within_class_weight() == 1); + } + + template + void test2() + { + dpca_type dpca, dpca2, dpca3; + + typename dpca_type::column_matrix samp1(4), samp2(4); + + for (int i = 0; i < 5000; ++i) + { + dpca.add_to_total_variance(randm(4,1,rnd)); + DLIB_TEST(dpca.in_vector_size() == 4); + + // do this to subtract out the variance along the 3rd axis + samp1 = 0,0,0,0; + samp2 = 0,0,1,0; + dpca.add_to_within_class_variance(samp1, samp2); + } + + matrix mat; + + dpca.set_within_class_weight(0); + mat = dpca.dpca_matrix(1); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(4))); + DLIB_TEST(dpca.dpca_matrix(1).nr() == 4); + dpca.set_within_class_weight(1000); + DLIB_TEST(dpca.dpca_matrix(1).nr() == 3); + + // the 3rd column of the transformation matrix should be all zero since + // we killed all the variation long the 3rd axis + DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),2))) < 1e-5); + + mat = dpca.dpca_matrix(1); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(3))); + + + } + + template + void test3() + { + dpca_type dpca, dpca2, dpca3; + + typename dpca_type::column_matrix samp1(4), samp2(4); + + for (int i = 0; i < 5000; ++i) + { + dpca.add_to_total_variance(randm(4,1,rnd)); + DLIB_TEST(dpca.in_vector_size() == 4); + + // do this to subtract out the variance along the 3rd axis + samp1 = 0,0,0,0; + samp2 = 0,0,1,0; + dpca.add_to_within_class_variance(samp1, samp2); + + // do this to subtract out the variance along the 1st axis + samp1 = 0,0,0,0; + samp2 = 1,0,0,0; + dpca.add_to_within_class_variance(samp1, samp2); + } + + matrix mat; + + dpca.set_within_class_weight(0); + mat = dpca.dpca_matrix(1); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(4))); + DLIB_TEST(dpca.dpca_matrix(1).nr() == 4); + dpca.set_within_class_weight(10000); + DLIB_TEST(dpca.dpca_matrix(1).nr() == 2); + + // the 1st and 3rd columns of the transformation matrix should be all zero since + // we killed all the variation long the 1st and 3rd axes + DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),2))) < 1e-5); + DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),0))) < 1e-5); + + mat = dpca.dpca_matrix(1); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(2))); + + + } + + template + void test4() + { + dpca_type dpca, dpca2, dpca3; + + dpca_type add_dpca1, add_dpca2, add_dpca3, add_dpca4, sum_dpca; + + typename dpca_type::column_matrix samp1(4), samp2(4), samp; + + for (int i = 0; i < 5000; ++i) + { + samp = randm(4,1,rnd); + dpca.add_to_total_variance(samp); + add_dpca4.add_to_total_variance(samp); + DLIB_TEST(dpca.in_vector_size() == 4); + + // do this to subtract out the variance along the 3rd axis + samp1 = 0,0,0,0; + samp2 = 0,0,1,0; + dpca.add_to_within_class_variance(samp1, samp2); + add_dpca1.add_to_within_class_variance(samp1, samp2); + + // do this to subtract out the variance along the 1st axis + samp1 = 0,0,0,0; + samp2 = 1,0,0,0; + dpca.add_to_within_class_variance(samp1, samp2); + add_dpca2.add_to_within_class_variance(samp1, samp2); + + // do this to add the variance along the 3rd axis back in + samp1 = 0,0,0,0; + samp2 = 0,0,1,0; + dpca.add_to_between_class_variance(samp1, samp2); + add_dpca3.add_to_between_class_variance(samp1, samp2); + } + + matrix mat, mat2; + + sum_dpca += dpca_type() + dpca_type() + add_dpca1 + dpca_type() + add_dpca2 + add_dpca3 + add_dpca4; + dpca.set_within_class_weight(0); + dpca.set_between_class_weight(0); + sum_dpca.set_within_class_weight(0); + sum_dpca.set_between_class_weight(0); + mat = dpca.dpca_matrix(1); + DLIB_TEST(equal(mat, sum_dpca.dpca_matrix(1), 1e-10)); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(4))); + DLIB_TEST(dpca.dpca_matrix(1).nr() == 4); + dpca.set_within_class_weight(10000); + sum_dpca.set_within_class_weight(10000); + DLIB_TEST(dpca.dpca_matrix(1).nr() == 2); + + // the 1st and 3rd columns of the transformation matrix should be all zero since + // we killed all the variation long the 1st and 3rd axes + DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),2))) < 1e-4); + DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),0))) < 1e-4); + + mat = dpca.dpca_matrix(1); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(2))); + DLIB_TEST_MSG(equal(mat, mat2=sum_dpca.dpca_matrix(1), 1e-9), max(abs(mat - mat2))); + + + // now add the variance back in using the between class weight + dpca.set_within_class_weight(0); + dpca.set_between_class_weight(1); + mat = dpca.dpca_matrix(1); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(4))); + DLIB_TEST(dpca.dpca_matrix(1).nr() == 4); + dpca.set_within_class_weight (10000); + dpca.set_between_class_weight(100000); + sum_dpca.set_within_class_weight (10000); + sum_dpca.set_between_class_weight(100000); + DLIB_TEST(dpca.dpca_matrix(1).nr() == 3); + + // the first column should be all zeros + DLIB_TEST(sum(abs(colm(dpca.dpca_matrix(1),0))) < 1e-5); + + mat = dpca.dpca_matrix(1); + DLIB_TEST(equal(mat*trans(mat), identity_matrix(3))); + DLIB_TEST(equal(mat, sum_dpca.dpca_matrix(1))); + + + } + + template + void test5() + { + dpca_type dpca, dpca2; + typename dpca_type::column_matrix samp1(4), samp2(4); + + samp1 = 0,0,0,0; + samp2 = 0,0,1,0; + + for (int i = 0; i < 5000; ++i) + { + dpca.add_to_between_class_variance(samp1, samp2); + dpca2.add_to_total_variance(samp1); + dpca2.add_to_total_variance(samp2); + } + + matrix mat, eig; + dpca.dpca_matrix(mat, eig, 1); + + // make sure the eigenvalues come out the way they should for this simple data set + DLIB_TEST(eig.size() == 1); + DLIB_TEST_MSG(abs(eig(0) - 1) < 1e-10, abs(eig(0) - 1)); + + dpca2.dpca_matrix(mat, eig, 1); + + // make sure the eigenvalues come out the way they should for this simple data set + DLIB_TEST(eig.size() == 1); + DLIB_TEST(abs(eig(0) - 0.25) < 1e-10); + + } + + void perform_test ( + ) + { + ++thetime; + typedef matrix sample_type; + typedef discriminant_pca dpca_type; + + dlog << LINFO << "time seed: " << thetime; + rnd.set_seed(cast_to_string(thetime)); + + test5(); + + for (int i = 0; i < 10; ++i) + { + print_spinner(); + test1(); + print_spinner(); + test2(); + print_spinner(); + test3(); + print_spinner(); + test4(); + } + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + discriminant_pca_tester a; + +} + + diff --git a/ml/dlib/dlib/test/disjoint_subsets.cpp b/ml/dlib/dlib/test/disjoint_subsets.cpp new file mode 100644 index 000000000..2545219cd --- /dev/null +++ b/ml/dlib/dlib/test/disjoint_subsets.cpp @@ -0,0 +1,102 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.disjoint_subsets"); + + void test_disjoint_subset() + { + print_spinner(); + disjoint_subsets s; + + DLIB_TEST(s.size() == 0); + + s.set_size(5); + DLIB_TEST(s.size() == 5); + + DLIB_TEST(s.find_set(0) == 0); + DLIB_TEST(s.find_set(1) == 1); + DLIB_TEST(s.find_set(2) == 2); + DLIB_TEST(s.find_set(3) == 3); + DLIB_TEST(s.find_set(4) == 4); + + unsigned long id = s.merge_sets(1,3); + DLIB_TEST(s.find_set(0) == 0); + DLIB_TEST(s.find_set(1) == id); + DLIB_TEST(s.find_set(2) == 2); + DLIB_TEST(s.find_set(3) == id); + DLIB_TEST(s.find_set(4) == 4); + + id = s.merge_sets(s.find_set(1),4); + DLIB_TEST(s.find_set(0) == 0); + DLIB_TEST(s.find_set(1) == id); + DLIB_TEST(s.find_set(2) == 2); + DLIB_TEST(s.find_set(3) == id); + DLIB_TEST(s.find_set(4) == id); + + unsigned long id2 = s.merge_sets(0,2); + DLIB_TEST(s.find_set(0) == id2); + DLIB_TEST(s.find_set(1) == id); + DLIB_TEST(s.find_set(2) == id2); + DLIB_TEST(s.find_set(3) == id); + DLIB_TEST(s.find_set(4) == id); + + id = s.merge_sets(s.find_set(1),s.find_set(0)); + DLIB_TEST(s.find_set(0) == id); + DLIB_TEST(s.find_set(1) == id); + DLIB_TEST(s.find_set(2) == id); + DLIB_TEST(s.find_set(3) == id); + DLIB_TEST(s.find_set(4) == id); + + DLIB_TEST(s.size() == 5); + s.set_size(1); + DLIB_TEST(s.size() == 1); + DLIB_TEST(s.find_set(0) == 0); + s.set_size(2); + DLIB_TEST(s.size() == 2); + DLIB_TEST(s.find_set(0) == 0); + DLIB_TEST(s.find_set(1) == 1); + id = s.merge_sets(0,1); + DLIB_TEST(s.size() == 2); + DLIB_TEST(id == s.find_set(0)); + DLIB_TEST(id == s.find_set(1)); + DLIB_TEST(s.size() == 2); + s.clear(); + DLIB_TEST(s.size() == 0); + + } + + + class tester_disjoint_subsets : public tester + { + public: + tester_disjoint_subsets ( + ) : + tester ("test_disjoint_subsets", + "Runs tests on the disjoint_subsets component.") + {} + + void perform_test ( + ) + { + test_disjoint_subset(); + } + } a; + + +} diff --git a/ml/dlib/dlib/test/disjoint_subsets_sized.cpp b/ml/dlib/dlib/test/disjoint_subsets_sized.cpp new file mode 100644 index 000000000..57ced68e6 --- /dev/null +++ b/ml/dlib/dlib/test/disjoint_subsets_sized.cpp @@ -0,0 +1,143 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.disjoint_subsets_sized"); + + void test_disjoint_subsets_sized() + { + print_spinner(); + disjoint_subsets_sized s; + + DLIB_TEST(s.size() == 0); + DLIB_TEST(s.get_number_of_sets() == 0); + + s.set_size(5); + DLIB_TEST(s.size() == 5); + DLIB_TEST(s.get_number_of_sets() == 5); + + DLIB_TEST(s.find_set(0) == 0); + DLIB_TEST(s.find_set(1) == 1); + DLIB_TEST(s.find_set(2) == 2); + DLIB_TEST(s.find_set(3) == 3); + DLIB_TEST(s.find_set(4) == 4); + + DLIB_TEST(s.get_size_of_set(0) == 1); + DLIB_TEST(s.get_size_of_set(1) == 1); + DLIB_TEST(s.get_size_of_set(2) == 1); + DLIB_TEST(s.get_size_of_set(3) == 1); + DLIB_TEST(s.get_size_of_set(4) == 1); + + unsigned long id = s.merge_sets(1,3); + DLIB_TEST(s.get_number_of_sets() == 4); + DLIB_TEST(s.find_set(0) == 0); + DLIB_TEST(s.find_set(1) == id); + DLIB_TEST(s.find_set(2) == 2); + DLIB_TEST(s.find_set(3) == id); + DLIB_TEST(s.find_set(4) == 4); + DLIB_TEST(s.get_size_of_set(0) == 1); + DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 2); + DLIB_TEST(s.get_size_of_set(2) == 1); + DLIB_TEST(s.get_size_of_set(s.find_set(3)) == 2); + DLIB_TEST(s.get_size_of_set(4) == 1); + + id = s.merge_sets(s.find_set(1),4); + DLIB_TEST(s.get_number_of_sets() == 3); + DLIB_TEST(s.find_set(0) == 0); + DLIB_TEST(s.find_set(1) == id); + DLIB_TEST(s.find_set(2) == 2); + DLIB_TEST(s.find_set(3) == id); + DLIB_TEST(s.find_set(4) == id); + DLIB_TEST(s.get_size_of_set(0) == 1); + DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 3); + DLIB_TEST(s.get_size_of_set(2) == 1); + DLIB_TEST(s.get_size_of_set(s.find_set(3)) == 3); + DLIB_TEST(s.get_size_of_set(s.find_set(4)) == 3); + + unsigned long id2 = s.merge_sets(0,2); + DLIB_TEST(s.get_number_of_sets() == 2); + DLIB_TEST(s.find_set(0) == id2); + DLIB_TEST(s.find_set(1) == id); + DLIB_TEST(s.find_set(2) == id2); + DLIB_TEST(s.find_set(3) == id); + DLIB_TEST(s.find_set(4) == id); + DLIB_TEST(s.get_size_of_set(s.find_set(0)) == 2); + DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 3); + DLIB_TEST(s.get_size_of_set(s.find_set(2)) == 2); + DLIB_TEST(s.get_size_of_set(s.find_set(3)) == 3); + DLIB_TEST(s.get_size_of_set(s.find_set(4)) == 3); + + id = s.merge_sets(s.find_set(1),s.find_set(0)); + DLIB_TEST(s.get_number_of_sets() == 1); + DLIB_TEST(s.find_set(0) == id); + DLIB_TEST(s.find_set(1) == id); + DLIB_TEST(s.find_set(2) == id); + DLIB_TEST(s.find_set(3) == id); + DLIB_TEST(s.find_set(4) == id); + DLIB_TEST(s.get_size_of_set(s.find_set(0)) == 5); + DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 5); + DLIB_TEST(s.get_size_of_set(s.find_set(2)) == 5); + DLIB_TEST(s.get_size_of_set(s.find_set(3)) == 5); + DLIB_TEST(s.get_size_of_set(s.find_set(4)) == 5); + + DLIB_TEST(s.size() == 5); + s.set_size(1); + DLIB_TEST(s.size() == 1); + DLIB_TEST(s.get_number_of_sets() == 1); + DLIB_TEST(s.find_set(0) == 0); + DLIB_TEST(s.get_size_of_set(0) == 1); + s.set_size(2); + DLIB_TEST(s.size() == 2); + DLIB_TEST(s.get_number_of_sets() == 2); + DLIB_TEST(s.find_set(0) == 0); + DLIB_TEST(s.find_set(1) == 1); + DLIB_TEST(s.get_size_of_set(0) == 1); + DLIB_TEST(s.get_size_of_set(1) == 1); + id = s.merge_sets(0,1); + DLIB_TEST(s.size() == 2); + DLIB_TEST(s.get_number_of_sets() == 1); + DLIB_TEST(id == s.find_set(0)); + DLIB_TEST(id == s.find_set(1)); + DLIB_TEST(s.get_size_of_set(s.find_set(0)) == 2); + DLIB_TEST(s.get_size_of_set(s.find_set(1)) == 2); + DLIB_TEST(s.size() == 2); + s.clear(); + DLIB_TEST(s.size() == 0); + DLIB_TEST(s.get_number_of_sets() == 0); + + } + + + class tester_disjoint_subsets_sized : public tester + { + public: + tester_disjoint_subsets_sized ( + ) : + tester ("test_disjoint_subsets_sized", + "Runs tests on the disjoint_subsets_sized component.") + {} + + void perform_test ( + ) + { + test_disjoint_subsets_sized(); + } + } a; + + +} diff --git a/ml/dlib/dlib/test/dnn.cpp b/ml/dlib/dlib/test/dnn.cpp new file mode 100644 index 000000000..9d3258b70 --- /dev/null +++ b/ml/dlib/dlib/test/dnn.cpp @@ -0,0 +1,3261 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include "../dnn.h" + +#include "tester.h" + +#ifndef __INTELLISENSE__ + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.dnn"); + +// ---------------------------------------------------------------------------------------- + + template + float compare_gradients ( + const tensor& t, + T grad + ) + { + float max_error = 0; + auto p = t.host(); + for (size_t i = 0; i < t.size(); ++i) + { + max_error = std::max(max_error, std::abs(p[i]-grad(i))); + } + return max_error; + } + +// ---------------------------------------------------------------------------------------- + + void test_tanh() + { + using namespace dlib::tt; + print_spinner(); + resizable_tensor src, dest, gradient_input; + src = matrix_cast(gaussian_randm(5,5, 0)); + dest = matrix_cast(gaussian_randm(5,5, 1)); + gradient_input = matrix_cast(gaussian_randm(5,5, 2)); + + + + auto grad_src = [&](long idx) { + auto f = [&](float eps) { + const float old = src.host()[idx]; + src.host()[idx] += eps; + tanh(dest, src); + float result = dot(gradient_input, dest); + src.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + + resizable_tensor src_grad; + src_grad.copy_size(src); + src_grad = 0; + + tanh(dest, src); + tanh_gradient(src_grad, dest, gradient_input); + + auto grad_error = compare_gradients(src_grad, grad_src); + dlog << LINFO << "src error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + } + + void test_sigmoid() + { + using namespace dlib::tt; + print_spinner(); + resizable_tensor src, dest, gradient_input; + src = matrix_cast(gaussian_randm(5,5, 0)); + dest = matrix_cast(gaussian_randm(5,5, 1)); + gradient_input = matrix_cast(gaussian_randm(5,5, 2)); + + + + auto grad_src = [&](long idx) { + auto f = [&](float eps) { + const float old = src.host()[idx]; + src.host()[idx] += eps; + sigmoid(dest, src); + float result = dot(gradient_input, dest); + src.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + + resizable_tensor src_grad; + src_grad.copy_size(src); + src_grad = 0; + + sigmoid(dest, src); + sigmoid_gradient(src_grad, dest, gradient_input); + + auto grad_error = compare_gradients(src_grad, grad_src); + dlog << LINFO << "src error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + } + + void test_softmax() + { + using namespace dlib::tt; + print_spinner(); + const long nr = 3; + const long nc = 3; + resizable_tensor src(5,5,nr,nr), dest(5,5,nr,nc), gradient_input(5,5,nr,nc); + tt::tensor_rand rnd; + rnd.fill_uniform(src); + rnd.fill_uniform(dest); + // fill like this as a test of the assignment operator. + gradient_input = matrix_cast(gaussian_randm(5,5*nr*nc, 2)); + + + + auto grad_src = [&](long idx) { + auto f = [&](float eps) { + const float old = src.host()[idx]; + src.host()[idx] += eps; + tt::softmax(dest, src); + float result = dot(gradient_input, dest); + src.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + + resizable_tensor src_grad; + src_grad.copy_size(src); + src_grad = 0; + + tt::softmax(dest, src); + softmax_gradient(src_grad, dest, gradient_input); + + auto grad_error = compare_gradients(src_grad, grad_src); + dlog << LINFO << "src error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + +#ifdef DLIB_USE_CUDA + resizable_tensor src1 = src; + resizable_tensor src2 = src; + resizable_tensor dest1, dest2; + dest1.copy_size(src); + dest2.copy_size(src); + cuda::softmax_all(dest1, src1); + cpu::softmax_all(dest2, src2); + DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2))) < 1e-5, max(abs(mat(dest1)-mat(dest2)))); +#endif + } + + void test_softmax_all() + { + using namespace dlib::tt; + print_spinner(); + const long nr = 3; + const long nc = 3; + resizable_tensor src(5,5,nr,nr), dest(5,5,nr,nc), gradient_input(5,5,nr,nc); + tt::tensor_rand rnd; + rnd.fill_uniform(src); + rnd.fill_uniform(dest); + // fill like this as a test of the assignment operator. + gradient_input = matrix_cast(gaussian_randm(5,5*nr*nc, 2)); + + + + auto grad_src = [&](long idx) { + auto f = [&](float eps) { + const float old = src.host()[idx]; + src.host()[idx] += eps; + tt::softmax_all(dest, src); + float result = dot(gradient_input, dest); + src.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + + resizable_tensor src_grad; + src_grad.copy_size(src); + src_grad = 0; + + tt::softmax_all(dest, src); + softmax_all_gradient(src_grad, dest, gradient_input); + + auto grad_error = compare_gradients(src_grad, grad_src); + dlog << LINFO << "src error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + +#ifdef DLIB_USE_CUDA + resizable_tensor src1 = src; + resizable_tensor src2 = src; + resizable_tensor dest1, dest2; + dest1.copy_size(src); + dest2.copy_size(src); + cuda::softmax_all(dest1, src1); + cpu::softmax_all(dest2, src2); + DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2))) < 1e-5, max(abs(mat(dest1)-mat(dest2)))); +#endif + } + + void test_batch_normalize() + { + using namespace dlib::tt; + print_spinner(); + resizable_tensor src, gamma, beta, dest, dest2, dest3, means, vars, gradient_input; + src = matrix_cast(gaussian_randm(5,5, 0)); + gamma = matrix_cast(gaussian_randm(1,5, 1)); + beta = matrix_cast(gaussian_randm(1,5, 2)); + gradient_input = matrix_cast(gaussian_randm(5,5, 3)); + + gamma = 1; + beta = 0; + + resizable_tensor running_means; + resizable_tensor running_variances; + batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta); + const double scale = (src.num_samples())/(src.num_samples()-1.0); + // Turn back into biased variance estimate because that's how batch_normalize() works, so if we want to match it this is necessary. + running_variances = mat(running_variances)/scale; + batch_normalize_inference(DEFAULT_BATCH_NORM_EPS,dest2, src, gamma, beta, running_means, running_variances); + DLIB_TEST_MSG(max(abs(mat(dest2)-mat(dest))) < 1e-5, max(abs(mat(dest2)-mat(dest)))); + cpu::batch_normalize_inference(DEFAULT_BATCH_NORM_EPS,dest3, src, gamma, beta, running_means, running_variances); + DLIB_TEST_MSG(max(abs(mat(dest3)-mat(dest))) < 1e-5, max(abs(mat(dest3)-mat(dest)))); + + + auto grad_src = [&](long idx) { + auto f = [&](float eps) { + const float old = src.host()[idx]; + src.host()[idx] += eps; + batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta); + float result = dot(gradient_input, dest); + src.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + auto grad_gamma = [&](long idx) { + auto f = [&](float eps) { + const float old = gamma.host()[idx]; + gamma.host()[idx] += eps; + batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta); + float result = dot(gradient_input, dest); + gamma.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + auto grad_beta = [&](long idx) { + auto f = [&](float eps) { + const float old = beta.host()[idx]; + beta.host()[idx] += eps; + batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta); + float result = dot(gradient_input, dest); + beta.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + + resizable_tensor src_grad, gamma_grad, beta_grad; + src_grad.copy_size(src); + gamma_grad.copy_size(gamma); + beta_grad.copy_size(beta); + src_grad = 0; + gamma_grad = 8; + beta_grad = 8; + + batch_normalize_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, vars, src, gamma, src_grad, gamma_grad, beta_grad); + + auto grad_error = compare_gradients(src_grad, grad_src); + dlog << LINFO << "src error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + + grad_error = compare_gradients(gamma_grad, grad_gamma); + dlog << LINFO << "gamma error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + + grad_error = compare_gradients(beta_grad, grad_beta); + dlog << LINFO << "beta error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + } + + void test_batch_normalize_conv() + { + using namespace dlib::tt; + print_spinner(); + resizable_tensor src(5,5,4,4), gamma, beta, dest, dest2, dest3, means, vars, gradient_input(5,5,4,4); + tt::tensor_rand rnd; + rnd.fill_gaussian(src); + rnd.fill_gaussian(gradient_input); + gamma = matrix_cast(gaussian_randm(1,5, 1)); + beta = matrix_cast(gaussian_randm(1,5, 2)); + + gamma = 1; + beta = 0; + + resizable_tensor running_means; + resizable_tensor running_variances; + batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta); + const double scale = (src.num_samples()*src.nr()*src.nc())/(src.num_samples()*src.nr()*src.nc()-1.0); + // Turn back into biased variance estimate because that's how + // batch_normalize_conv() works, so if we want to match it this is necessary. + running_variances = mat(running_variances)/scale; + batch_normalize_conv_inference(DEFAULT_BATCH_NORM_EPS,dest2, src, gamma, beta, running_means, running_variances); + DLIB_TEST(max(abs(mat(dest2)-mat(dest))) < 1e-5); + cpu::batch_normalize_conv_inference(DEFAULT_BATCH_NORM_EPS,dest3, src, gamma, beta, running_means, running_variances); + DLIB_TEST(max(abs(mat(dest3)-mat(dest))) < 1e-5); + + + auto grad_src = [&](long idx) { + auto f = [&](float eps) { + const float old = src.host()[idx]; + src.host()[idx] += eps; + batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta); + float result = dot(gradient_input, dest); + src.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + auto grad_gamma = [&](long idx) { + auto f = [&](float eps) { + const float old = gamma.host()[idx]; + gamma.host()[idx] += eps; + batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta); + float result = dot(gradient_input, dest); + gamma.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + auto grad_beta = [&](long idx) { + auto f = [&](float eps) { + const float old = beta.host()[idx]; + beta.host()[idx] += eps; + batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest, means, vars, 1, running_means, running_variances, src, gamma, beta); + float result = dot(gradient_input, dest); + beta.host()[idx] = old; + return result; + }; + const float eps = 0.01; + return (f(+eps)-f(-eps))/(2*eps); + }; + + + resizable_tensor src_grad, gamma_grad, beta_grad; + src_grad.copy_size(src); + gamma_grad.copy_size(gamma); + beta_grad.copy_size(beta); + src_grad = 0; + gamma_grad = 9; + beta_grad = 9; + + batch_normalize_conv_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, vars, src, gamma, src_grad, gamma_grad, beta_grad); + + + auto grad_error = compare_gradients(src_grad, grad_src); + dlog << LINFO << "src error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + + grad_error = compare_gradients(gamma_grad, grad_gamma); + dlog << LINFO << "gamma error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + + grad_error = compare_gradients(beta_grad, grad_beta); + dlog << LINFO << "beta error: " << grad_error; + DLIB_TEST(grad_error < 0.001); + + } + +// ---------------------------------------------------------------------------------------- + + void test_basic_tensor_ops() + { + using namespace dlib::tt; + print_spinner(); + resizable_tensor dest, src(3,4), A(1,4), B(1,4); + src = 2; + dest.copy_size(src); + affine_transform(dest, src, 2, 3); + dlog << LINFO << mat(dest); + matrix truth1(3,4), truth2(3,4); + + truth1 = 2; + DLIB_TEST(max(abs(truth1-mat(src))) < 1e-5); + src *= 2; + truth1 = 4; + DLIB_TEST(max(abs(truth1-mat(src))) < 1e-5); + src = 2; + + + truth1 = 7; + truth2 = 7, 10, 7, 7, + 7, 10, 7, 7, + 7, 10, 7, 7; + DLIB_TEST(max(abs(truth1-mat(dest))) < 1e-5); + + A = 2; + B = 3; + A.host()[1] = 3; + B.host()[1] = 4; + dest = 0; + affine_transform(dest, src, A, B); + dlog << LINFO << mat(dest); + DLIB_TEST(max(abs(truth2-mat(dest))) < 1e-5); + + A = matrix_cast(gaussian_randm(3,4, 1)); + B = matrix_cast(gaussian_randm(3,4, 2)); + affine_transform(dest, src, A, B); + dlog << LINFO << mat(dest); + matrix truth3 = pointwise_multiply(mat(src), mat(A)) + mat(B); + DLIB_TEST(max(abs(truth3-mat(dest))) < 1e-5); + + matrix truth4 = pointwise_multiply(mat(A), mat(B)); + tt::multiply(false, A, A, B); + DLIB_TEST(max(abs(truth4-mat(A))) < 1e-5); + truth4 = pointwise_multiply(mat(A), mat(B)) + mat(A); + tt::multiply(true, A, A, B); + DLIB_TEST(max(abs(truth4-mat(A))) < 1e-5); + + matrix truth5 = mat(B) > 0.1; + dlog << LINFO << truth5; + threshold(B, 0.1); + DLIB_TEST(max(abs(truth5-mat(B))) < 1e-5); + + int cnt = 0; + for(auto& x : A) + x = cnt++; + + truth1.set_size(2,2); + truth2.set_size(2,2); + truth3.set_size(2,2); + truth1 = 0,1,2,3; + truth2 = 4,5,6,7; + truth3 = 8,9,10,11; + + alias_tensor at(2,2); + auto A0 = at(A,0); + auto A4 = at(A,4); + auto A8 = at(const_cast(A),8); + DLIB_TEST(mat(A0) == truth1); + DLIB_TEST(mat(at(A,4)) == truth2); + DLIB_TEST(mat(A8) == truth3); + + A4 += uniform_matrix(2,2,2); + truth2 += 2; + DLIB_TEST(mat(A4) == truth2); + truth1 = trans(reshape_to_column_vector(truth1)); + truth2 = trans(reshape_to_column_vector(truth2)); + truth3 = trans(reshape_to_column_vector(truth3)); + + DLIB_TEST(mat(A) == join_cols(truth1,join_cols(truth2,truth3))); + + affine_transform(A,A,1,2); + truth1 += 2; + truth2 += 2; + truth3 += 2; + DLIB_TEST(mat(at(A,4)) == reshape(truth2,2,2)); + DLIB_TEST(mat(A) == join_cols(truth1,join_cols(truth2,truth3))); + + { + resizable_tensor dest(3,4); + resizable_tensor A, B; + A = dest; + B = dest; + + tensor_rand rnd; + rnd.fill_uniform(dest); + rnd.fill_uniform(A); + rnd.fill_uniform(B); + + dest.set_size(1,4); + + tt::multiply(false, dest, A, B); + DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(mat(A),mat(B))))) < 1e-6); + + A.set_size(1,4); + rnd.fill_uniform(A); + matrix AA = join_cols(mat(A),mat(A)); AA = join_cols(mat(A),AA); + + tt::multiply(false, dest, A, B); + DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6); + + tt::multiply(false, dest, B, A); + DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6); + matrix prevdest = mat(dest); + tt::multiply(true, dest, B, A); + DLIB_TEST(max(abs(mat(dest)-prevdest-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6); + + dest.set_size(3,4); + tt::multiply(false, dest, B, A); + DLIB_TEST(max(abs(mat(dest)-pointwise_multiply(AA,mat(B)))) < 1e-6); + prevdest = mat(dest); + tt::multiply(true, dest, B, A); + DLIB_TEST(max(abs(mat(dest)-prevdest-pointwise_multiply(AA,mat(B)))) < 1e-6); + + tt::multiply(false, dest, A, B); + DLIB_TEST(max(abs(mat(dest)-pointwise_multiply(AA,mat(B)))) < 1e-6); + prevdest = mat(dest); + tt::multiply(true, dest, B, A); + DLIB_TEST(max(abs(mat(dest)-prevdest-pointwise_multiply(AA,mat(B)))) < 1e-6); + } + + { + resizable_tensor A, B, truth; + A.set_size(2,3,4,5); + truth.copy_size(A); + B.copy_size(A); + + A = 4; + B = 1; + truth = 1; + DLIB_TEST(max(abs(mat(B)- mat(truth))) < 1e-5); + memcpy(A, truth); + DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5); + + A = 4; + A.host(); + B.host(); + memcpy(A, truth); + DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5); + +#ifdef DLIB_USE_CUDA + A = 4; + A.device(); + B.host(); + memcpy(A, truth); + DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5); + + A = 4; + A.device(); + B.device(); + memcpy(A, truth); + DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5); + + A = 4; + A.host(); + B.device(); + memcpy(A, truth); + DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5); + + A = 4; + A.host_write_only(); + B.device(); + memcpy(A, truth); + DLIB_TEST(max(abs(mat(A)- mat(truth))) < 1e-5); +#endif + } + + { + resizable_tensor A, B; + A.set_size(11); + B.copy_size(A); + + A = 4; + B = 1; + matrix truth; + + + alias_tensor at(5); + A = 4; + A.host(); + B.host(); + { + // non-aliasing test + auto aA = at(A,5); + auto aB = at(B,5); + memcpy(aA, aB); + truth = {4,4,4,4,4, 1,1,1,1,1, 4}; + DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5); + } + { + // aliasing test + auto aA = at(A,1); + auto aB = at(A,6); + memcpy(aA, aB); + truth = {4,1,1,1,1, 4,1,1,1,1, 4}; + DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5); + } + + +#ifdef DLIB_USE_CUDA + A = 4; + A.device(); + B.host(); + { + // non-aliasing test + auto aA = at(A,5); + auto aB = at(B,5); + memcpy(aA, aB); + truth = {4,4,4,4,4, 1,1,1,1,1, 4}; + DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5); + } + { + // aliasing test + auto aA = at(A,1); + auto aB = at(A,6); + memcpy(aA, aB); + truth = {4,1,1,1,1, 4,1,1,1,1, 4}; + DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5); + } + + + A = 4; + A.device(); + B.device(); + { + // non-aliasing test + auto aA = at(A,5); + auto aB = at(B,5); + memcpy(aA, aB); + truth = {4,4,4,4,4, 1,1,1,1,1, 4}; + DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5); + } + { + // aliasing test + auto aA = at(A,1); + auto aB = at(A,6); + memcpy(aA, aB); + truth = {4,1,1,1,1, 4,1,1,1,1, 4}; + DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5); + } + + A = 4; + A.host(); + B.device(); + { + // non-aliasing test + auto aA = at(A,5); + auto aB = at(B,5); + memcpy(aA, aB); + truth = {4,4,4,4,4, 1,1,1,1,1, 4}; + DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5); + } + { + // aliasing test + auto aA = at(A,1); + auto aB = at(A,6); + memcpy(aA, aB); + truth = {4,1,1,1,1, 4,1,1,1,1, 4}; + DLIB_TEST(max(abs(mat(A)- truth)) < 1e-5); + } + +#endif + } + + { + resizable_tensor A(4,5), B(4); + + tensor_rand rnd; + rnd.fill_uniform(A); + rnd.fill_uniform(B); + + float alpha = 1.4; + float beta = 0.5; + + matrix a(mat(A)), b(mat(B)); + for (long c = 0; c < a.nc(); ++c) + { + set_colm(a,c) = beta*colm(a,c) + alpha*b; + } + + tt::add(beta, A, alpha, B); + DLIB_TEST_MSG(max(abs(mat(A)-a)) < 1e-6, max(abs(mat(A)-a))); + + beta = 0; + for (long c = 0; c < a.nc(); ++c) + { + set_colm(a,c) = beta*colm(a,c) + alpha*b; + } + + tt::add(beta, A, alpha, B); + DLIB_TEST(max(abs(mat(A)-a)) < 1e-6); + } + + { + resizable_tensor A, B; + A.set_size(2,3,4,5); + B.set_size(2,3,4,5); + + tensor_rand rnd; + rnd.fill_uniform(A); + rnd.fill_uniform(B); + + matrix truth; + + truth = 2*mat(A) + 3*mat(B); + tt::add(2, A, 3, B); + DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6); + + + rnd.fill_uniform(A); + rnd.fill_uniform(B); + truth = 0*mat(A) + 3*mat(B); + tt::add(0, A, 3, B); + DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6); + + rnd.fill_uniform(A); + rnd.fill_uniform(B); + truth = 1*mat(A) + 0*mat(B); + tt::add(1, A, 0, B); + DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6); + + + rnd.fill_uniform(A); + rnd.fill_uniform(B); + truth = 0*mat(A) + 0*mat(B); + tt::add(0, A, 0, B); + DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6); + + + B.set_size(1,3,4,5); + rnd.fill_uniform(A); + rnd.fill_uniform(B); + truth = 2*mat(A) + 3*join_cols(mat(B), mat(B)); + tt::add(2, A, 3, B); + DLIB_TEST(max(abs(mat(A)-truth )) < 1e-6); + DLIB_TEST(A.num_samples()==2); + + B.set_size(1,1,4,5); + rnd.fill_uniform(A); + rnd.fill_uniform(B); + matrix temp = join_rows(mat(B), join_rows(mat(B),mat(B))); + truth = 2*mat(A) + 3*join_cols(temp,temp); + tt::add(2, A, 3, B); + DLIB_TEST_MSG(max(abs(mat(A)-truth )) < 1e-6, max(abs(mat(A)-truth ))); + + B.set_size(1,3,1,1); + rnd.fill_uniform(A); + rnd.fill_uniform(B); + resizable_tensor AA(A), BB(B); + tt::add(2, A, 3, B); + cpu::add(2, AA, 3, BB); + DLIB_TEST_MSG(max(abs(mat(A)-mat(AA) )) < 1e-6, max(abs(mat(A)-mat(AA) ))); + } + + { + print_spinner(); + resizable_tensor dest1(123,456), dest2(123,456); + resizable_tensor src1(123,456), src2(123,456); + + tt::tensor_rand rnd; + + rnd.fill_uniform(src1); tt::affine_transform(src1, src1, 1, 2); src2 = src1; // random in range [2, 3] + dest1 = exp(mat(src1)); + tt::exp(dest2, src2); + tt::exp(src2, src2); // should work in place + DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2))) < 1e-5, max(abs(mat(dest1)-mat(dest2)))); + DLIB_TEST(max(abs(mat(dest1)-mat(src2))) < 1e-5); + + rnd.fill_uniform(src1); tt::affine_transform(src1, src1, 1, 2); src2 = src1; // random in range [2, 3] + dest1 = log(mat(src1)); + tt::log(dest2, src2); + tt::log(src2, src2); // should work in place + DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5); + DLIB_TEST(max(abs(mat(dest1)-mat(src2))) < 1e-5); + + rnd.fill_uniform(src1); tt::affine_transform(src1, src1, 1, 2); src2 = src1; // random in range [2, 3] + dest1 = log10(mat(src1)); + tt::log10(dest2, src2); + tt::log10(src2, src2); // should work in place + DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5); + DLIB_TEST(max(abs(mat(dest1)-mat(src2))) < 1e-5); + + } + } + +// ---------------------------------------------------------------------------------------- + +#ifdef DLIB_USE_CUDA + + void test_scale_channels() + { + tt::tensor_rand rnd; + + resizable_tensor dest1(2,3,4,5), dest2; + rnd.fill_gaussian(dest1); + dest2 = dest1; + + resizable_tensor src(2,3,4,5); + resizable_tensor scales(2,3); + rnd.fill_gaussian(src); + rnd.fill_gaussian(scales); + + + cpu::scale_channels(true, dest1, src, scales); + cuda::scale_channels(true, dest2, src, scales); + + DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-6); + + cpu::scale_channels(false, dest1, src, scales); + cuda::scale_channels(false, dest2, src, scales); + + DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-6); + } + +// ---------------------------------------------------------------------------------------- + + void test_affine_rect() + { + dlib::rand rnd; + + for (int iter = 0; iter < 20; ++iter) + { + + long nr = 1 + rnd.get_random_32bit_number()%10; + long nc = 1 + rnd.get_random_32bit_number()%10; + + resizable_tensor dest1(nr,nc), dest2(nr,nc), src1(nr,nc), src2(nr,nc), src3(nr,nc); + matrix dest3; + + dest1 = 1; + dest2 = 1; + dest3 = mat(dest1); + src1 = 2; + src2 = 3; + src3 = 4; + + point p1(rnd.get_random_32bit_number()%nc, rnd.get_random_32bit_number()%nr); + point p2(rnd.get_random_32bit_number()%nc, rnd.get_random_32bit_number()%nr); + rectangle rect(p1,p2); + + cuda::affine_transform(rect, dest1, src1, src2, src3, 2,3,4); + + cpu::affine_transform(rect, dest2, src1, src2, src3, 2,3,4); + + DLIB_TEST(mat(dest1) == mat(dest2)); + + set_subm(dest3,rect) = 2*subm(mat(src1),rect) + 3*subm(mat(src2),rect) + 4*subm(mat(src3),rect); + DLIB_TEST(dest3 == mat(dest1)); + + dest1 = 1; + tt::affine_transform(rect, dest1, src1, src2, src3, 2,3,4); + DLIB_TEST(dest3 == mat(dest1)); + } + } + + void test_conv() + { + cuda::tensor_conv conv1; + cpu::tensor_conv conv2; + + dlib::rand prnd; + for (int iter = 0; iter < 400; ++iter) + { + print_spinner(); + + resizable_tensor data(prnd.get_random_32bit_number()%5+1, + prnd.get_random_32bit_number()%5+1, + prnd.get_random_32bit_number()%25+1, + prnd.get_random_32bit_number()%25+1 + ); + resizable_tensor filters( + prnd.get_random_32bit_number()%5+1, + data.k(), + prnd.get_random_32bit_number()%6+1, + prnd.get_random_32bit_number()%6+1 + ); + + tt::tensor_rand rnd; + rnd.fill_uniform(data); + rnd.fill_uniform(filters); + + + resizable_tensor output1, output2; + + + const int stride_y = prnd.get_random_32bit_number()%5+1; + const int stride_x = prnd.get_random_32bit_number()%5+1; + int padding_y = prnd.get_random_32bit_number()%(filters.nr()/2+1); + int padding_x = prnd.get_random_32bit_number()%(filters.nc()/2+1); + if (!(filters.nr() <= data.nr() + 2*padding_y)) + padding_y = (filters.nr()-data.nr()+1)/2; + if (!(filters.nc() <= data.nc() + 2*padding_x)) + padding_x = (filters.nc()-data.nc()+1)/2; + conv1.setup(data,filters,stride_y,stride_x,padding_y,padding_x); + conv1(false, output1, data, filters); + conv2.setup(data,filters,stride_y,stride_x,padding_y,padding_x); + conv2(false, output2, data, filters); + dlog << LINFO << "forward error: "<< max(abs(mat(output1)-mat(output2))); + DLIB_TEST_MSG(max(abs(mat(output1)-mat(output2))) < 1e-3, max(abs(mat(output1)-mat(output2))) + <<"\n\t padding_y: "<< padding_y + <<"\n\t padding_x: "<< padding_x + ); + + conv1(true, output1, data, filters); + conv2(true, output2, data, filters); + dlog << LINFO << "forward error: "<< max(abs(mat(output1)-mat(output2))); + DLIB_TEST_MSG(max(abs(mat(output1)-mat(output2))) < 1e-3, max(abs(mat(output1)-mat(output2))) + <<"\n\t padding_y: "<< padding_y + <<"\n\t padding_x: "<< padding_x + ); + + + + resizable_tensor gi, data_gradient1, data_gradient2; + gi.copy_size(output1); + rnd.fill_uniform(gi); + + data_gradient1.copy_size(data); + data_gradient2.copy_size(data); + data_gradient1 = 1; + data_gradient2 = 1; + + conv1.get_gradient_for_data(true, gi, filters, data_gradient1); + conv2.get_gradient_for_data(true, gi, filters, data_gradient2); + + dlog << LINFO << "data gradient error: "<< max(abs(mat(data_gradient1)-mat(data_gradient2))); + DLIB_TEST(max(abs(mat(data_gradient1)-mat(data_gradient2))) < 1e-3); + + conv1.get_gradient_for_data(false, gi, filters, data_gradient1); + conv2.get_gradient_for_data(false, gi, filters, data_gradient2); + + dlog << LINFO << "data gradient error: "<< max(abs(mat(data_gradient1)-mat(data_gradient2))); + DLIB_TEST(max(abs(mat(data_gradient1)-mat(data_gradient2))) < 1e-3); + + + resizable_tensor filter_gradient1, filter_gradient2; + gi.copy_size(output1); + rnd.fill_uniform(gi); + + filter_gradient1.copy_size(filters); + filter_gradient2.copy_size(filters); + filter_gradient1 = 1; + filter_gradient2 = 1; + + conv1.get_gradient_for_filters(false, gi, data, filter_gradient1); + conv2.get_gradient_for_filters(false, gi, data, filter_gradient2); + + dlog << LINFO << "filter gradient error: "<< max(abs(mat(filter_gradient1)-mat(filter_gradient2))); + DLIB_TEST_MSG(max(abs(mat(filter_gradient1)-mat(filter_gradient2))) < 1e-3, max(abs(mat(filter_gradient1)-mat(filter_gradient2)))); + + + conv1.get_gradient_for_filters(true, gi, data, filter_gradient1); + conv2.get_gradient_for_filters(true, gi, data, filter_gradient2); + + dlog << LINFO << "filter gradient error: "<< max(abs(mat(filter_gradient1)-mat(filter_gradient2))); + DLIB_TEST_MSG(max(abs(mat(filter_gradient1)-mat(filter_gradient2))) < 2e-3, max(abs(mat(filter_gradient1)-mat(filter_gradient2)))); + } + } + + void compare_adam() + { + float t = 2; + tt::tensor_rand rnd; + resizable_tensor s, m, v, params, params_grad; + s.set_size(89,90,60,73); + m.copy_size(s); + v.copy_size(s); + params.copy_size(s); + params_grad.copy_size(s); + + rnd.fill_uniform(s); + rnd.fill_uniform(m); + rnd.fill_uniform(v); + rnd.fill_uniform(params); + rnd.fill_uniform(params_grad); + + resizable_tensor mm(m), vv(v); + cpu::compute_adam_update(0,params.size(),s, mm, vv, t, 0.01, 0.001, 0.9, 0.99, params, params_grad); + matrix s1 = mat(s); + + rnd.fill_uniform(s); + cuda::compute_adam_update(0,params.size(),s, m, v, t, 0.01, 0.001, 0.9, 0.99, params, params_grad); + matrix s2 = mat(s); + + DLIB_TEST_MSG(max(abs(s1-s2)) < 1e-6, max(abs(s1-s2))); + DLIB_TEST_MSG(max(abs(mat(m)-mat(mm))) < 1e-6, max(abs(mat(m)-mat(mm)))); + DLIB_TEST_MSG(max(abs(mat(v)-mat(vv))) < 1e-6, max(abs(mat(v)-mat(vv)))); + } + + void test_multiply_zero_padded() + { + print_spinner(); + dlib::rand rnd; + tt::tensor_rand trnd; + for (int iter = 0; iter < 300; ++iter) + { + resizable_tensor dest1(rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1); + resizable_tensor dest2; + dest2.copy_size(dest1); + resizable_tensor src1(rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1); + resizable_tensor src2(rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1); + + trnd.fill_uniform(dest1); + trnd.fill_uniform(dest2); + trnd.fill_uniform(src1); + trnd.fill_uniform(src2); + cpu::multiply_zero_padded(false, dest1, src1, src2); + cuda::multiply_zero_padded(false, dest2, src1, src2); + DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5); + + cpu::multiply_zero_padded(true, dest1, src1, src2); + cuda::multiply_zero_padded(true, dest2, src1, src2); + DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5); + } + + // make sure we have a test for the case where all tensors have the same + // dimensions. + resizable_tensor dest1(3,4,5,6); + resizable_tensor dest2; + resizable_tensor src1; + resizable_tensor src2; + dest2.copy_size(dest1); + src1.copy_size(dest1); + src2.copy_size(dest1); + + trnd.fill_uniform(dest1); + trnd.fill_uniform(dest2); + trnd.fill_uniform(src1); + trnd.fill_uniform(src2); + cpu::multiply_zero_padded(false, dest1, src1, src2); + cuda::multiply_zero_padded(false, dest2, src1, src2); + DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5); + + cpu::multiply_zero_padded(true, dest1, src1, src2); + cuda::multiply_zero_padded(true, dest2, src1, src2); + DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5); + } + + void test_add() + { + print_spinner(); + dlib::rand rnd; + tt::tensor_rand trnd; + for (int iter = 0; iter < 300; ++iter) + { + resizable_tensor dest1(rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1); + resizable_tensor dest2; + dest2.copy_size(dest1); + resizable_tensor src1(rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1); + resizable_tensor src2(rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1, + rnd.get_random_32bit_number()%4+1); + + trnd.fill_uniform(dest1); + trnd.fill_uniform(dest2); + trnd.fill_uniform(src1); + trnd.fill_uniform(src2); + cpu::add(dest1, src1, src2); + cuda::add(dest2, src1, src2); + + DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5); + } + + // make sure we have a test for the case where all tensors have the same + // dimensions. + resizable_tensor dest1(3,4,5,6); + resizable_tensor dest2; + resizable_tensor src1; + resizable_tensor src2; + dest2.copy_size(dest1); + src1.copy_size(dest1); + src2.copy_size(dest1); + + trnd.fill_uniform(dest1); + trnd.fill_uniform(dest2); + trnd.fill_uniform(src1); + trnd.fill_uniform(src2); + + cpu::add(dest1, src1, src2); + cuda::add(dest2, src1, src2); + + DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5); + } + + void test_more_ops(const long nr, const long nc) + { + using namespace dlib::tt; + print_spinner(); + // We are going to make sure that the CPU implementation of these things matches + // the CUDA implementation. + + tensor_rand rnd; + + resizable_tensor dest(nr,nc), src(nr,nc), dest2, src2; + resizable_tensor srcb(nr,nc), srcc(nr,nc), srcb2, srcc2; + + + rnd.fill_uniform(dest); + rnd.fill_uniform(src); + dest2 = dest; src2 = src; + cuda::multiply(false, dest, dest, src); + cpu::multiply(false, dest2, dest2, src2); + DLIB_TEST(equal(mat(dest),mat(dest2))); + cuda::multiply(true, dest, dest, src); + cpu::multiply(true, dest2, dest2, src2); + DLIB_TEST(equal(mat(dest),mat(dest2))); + + + rnd.fill_uniform(dest); + rnd.fill_uniform(src); + dest2 = dest; src2 = src; + cuda::affine_transform(dest, src, 2, 3); + cpu::affine_transform(dest2, src2, 2, 3); + DLIB_TEST(equal(mat(dest),mat(dest2))); + + rnd.fill_uniform(dest); + rnd.fill_uniform(src); + rnd.fill_uniform(srcb); + dest2 = dest; src2 = src; srcb2 = srcb; + cuda::affine_transform(dest, src, srcb, 2, 3, 4); + cpu::affine_transform(dest2, src2, srcb2, 2, 3, 4); + DLIB_TEST(equal(mat(dest),mat(dest2))); + + rnd.fill_uniform(dest); + rnd.fill_uniform(src); + rnd.fill_uniform(srcb); + rnd.fill_uniform(srcc); + dest2 = dest; src2 = src; srcb2 = srcb; srcc2 = srcc; + cuda::affine_transform(dest, src, srcb, srcc, 2, 3, 4, 5); + cpu::affine_transform(dest2, src2, srcb2, srcc2, 2, 3, 4, 5); + DLIB_TEST(equal(mat(dest),mat(dest2))); + + cuda::affine_transform(dest, src, srcb, srcc, 2, 3, 4, 0); + cpu::affine_transform(dest2, src2, srcb2, srcc2, 2, 3, 4, 0); + DLIB_TEST(equal(mat(dest),mat(dest2))); + + cuda::affine_transform_range(0, dest.size(), dest, src, srcb, srcc, 2, 3, 4); + cpu::affine_transform_range(0, dest2.size(), dest2, src2, srcb2, srcc2, 2, 3, 4); + DLIB_TEST(equal(mat(dest),mat(dest2))); + + if (3 < dest.size()) + { + dest = 999; + dest2 = 999; + cuda::affine_transform_range(3, dest.size()-1, dest, src, srcb, srcc, 2, 3, 4); + cpu::affine_transform_range(3, dest2.size()-1, dest2, src2, srcb2, srcc2, 2, 3, 4); + DLIB_TEST(equal(mat(dest),mat(dest2))); + + cuda::affine_transform_range(dest.size(), dest.size(), dest, src, srcb, srcc, 2, 3, 4); + cpu::affine_transform_range(dest2.size(), dest2.size(), dest2, src2, srcb2, srcc2, 2, 3, 4); + DLIB_TEST(equal(mat(dest),mat(dest2))); + } + + + rnd.fill_uniform(dest); + rnd.fill_uniform(src); + rnd.fill_uniform(srcb); + rnd.fill_uniform(srcc); + dest2 = dest; src2 = src; srcb2 = srcb; srcc2 = srcc; + cuda::affine_transform(dest, src, srcb, srcc); + cpu::affine_transform(dest2, src2, srcb2, srcc2); + DLIB_TEST(equal(mat(dest),mat(dest2))); + // now exercise code path where the A/B tensors have num_samples()==1 + srcb.set_size(1,nc); + srcc.set_size(1,nc); + rnd.fill_uniform(dest); + rnd.fill_uniform(src); + rnd.fill_uniform(srcb); + rnd.fill_uniform(srcc); + dest2 = dest; src2 = src; srcb2 = srcb; srcc2 = srcc; + cuda::affine_transform(dest, src, srcb, srcc); + cpu::affine_transform(dest2, src2, srcb2, srcc2); + DLIB_TEST(equal(mat(dest),mat(dest2))); + + + rnd.fill_uniform(src); + src2 = src; + cuda::threshold(src, 0.5); + cpu::threshold(src2, 0.5); + DLIB_TEST(equal(mat(src),mat(src2))); + + { + resizable_tensor dest(3,4); + resizable_tensor A, B; + A = dest; + B = dest; + + rnd.fill_uniform(dest); + rnd.fill_uniform(A); + rnd.fill_uniform(B); + + dest.set_size(1,4); + + cuda::multiply(false, dest, A, B); + DLIB_TEST_MSG(max(abs(mat(dest)-sum_rows(pointwise_multiply(mat(A),mat(B))))) < 1e-6, max(abs(mat(dest)-sum_rows(pointwise_multiply(mat(A),mat(B)))))); + + A.set_size(1,4); + rnd.fill_uniform(A); + matrix AA = join_cols(mat(A),mat(A)); AA = join_cols(mat(A),AA); + + cuda::multiply(false, dest, A, B); + DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6); + + cuda::multiply(false, dest, B, A); + DLIB_TEST(max(abs(mat(dest)-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6); + matrix prevdest = mat(dest); + cuda::multiply(true, dest, B, A); + DLIB_TEST(max(abs(mat(dest)-prevdest-sum_rows(pointwise_multiply(AA,mat(B))))) < 1e-6); + + dest.set_size(3,4); + cuda::multiply(false, dest, B, A); + DLIB_TEST(max(abs(mat(dest)-pointwise_multiply(AA,mat(B)))) < 1e-6); + prevdest = mat(dest); + cuda::multiply(true, dest, B, A); + DLIB_TEST(max(abs(mat(dest)-prevdest-pointwise_multiply(AA,mat(B)))) < 1e-6); + + cuda::multiply(false, dest, A, B); + DLIB_TEST(max(abs(mat(dest)-pointwise_multiply(AA,mat(B)))) < 1e-6); + } + + { + resizable_tensor invnorms1, invnorms2; + resizable_tensor data(4,5), out1, out2; + rnd.fill_uniform(data); + + const double eps = 0.1; + + invnorms2 = reciprocal(sqrt(sum_cols(squared(mat(data))) + eps)); + tt::inverse_norms(invnorms1, data, eps); + DLIB_TEST(max(abs(mat(invnorms1)-mat(invnorms2))) < 1e-6); + + out1.copy_size(data); + tt::scale_rows(out1, data, invnorms1); + out2 = scale_rows(mat(data), mat(invnorms1)); + DLIB_TEST(max(abs(mat(out1)-mat(out2))) < 1e-6); + } + + { + resizable_tensor a(123,432), b(123,432); + rnd.fill_gaussian(a); + rnd.fill_gaussian(b); + + resizable_tensor out; + dot_prods(out, a,b); + const matrix truth = sum_cols(pointwise_multiply(mat(a), mat(b))); + DLIB_TEST(max(abs(mat(out) - truth)) < 1e-4); + out = 0; + DLIB_TEST(max(abs(mat(out) - truth)) > 1e-2); + dot_prods(false, out, a,b); + DLIB_TEST(max(abs(mat(out) - truth)) < 1e-4); + dot_prods(true, out, a,b); + DLIB_TEST(max(abs(mat(out)/2 - truth)) < 1e-4); + DLIB_TEST(max(abs(mat(out) - truth)) > 1e-2); + } + } + +// ---------------------------------------------------------------------------------------- + + void compare_bn_gpu_and_cpu() + { + print_spinner(); + resizable_tensor dest, dest2; + resizable_tensor means, means2; + resizable_tensor invstds, invstds2; + resizable_tensor running_means, running_means2; + resizable_tensor running_variances, running_variances2; + resizable_tensor src(64,20,100,100); + resizable_tensor gamma(1,20,100,100); + resizable_tensor beta(1,20,100,100); + gamma = 2; + beta = 3; + tt::tensor_rand rnd; + rnd.fill_uniform(src); + + + cpu::batch_normalize(DEFAULT_BATCH_NORM_EPS,dest, means, invstds, 1, running_means, running_variances, src, gamma, beta); + cuda::batch_normalize(DEFAULT_BATCH_NORM_EPS,dest2,means2,invstds2, 1, running_means2, running_variances2, src, gamma, beta); + + dlog << LINFO << "dest error: "<< max(abs(mat(dest) -mat(dest2))); + dlog << LINFO << "means error: "<< max(abs(mat(means) -mat(means2))); + dlog << LINFO << "invstds error: "<< max(abs(mat(invstds) -mat(invstds2))); + dlog << LINFO << "running_means error: "<< max(abs(mat(running_means) -mat(running_means2))); + dlog << LINFO << "running_variances error: "<< max(abs(mat(running_variances) -mat(running_variances2))); + + DLIB_TEST(max(abs(mat(dest) -mat(dest2))) < 1e-4); + DLIB_TEST(max(abs(mat(means) -mat(means2))) < 1e-4); + DLIB_TEST(max(abs(mat(invstds) -mat(invstds2))) < 1e-4); + DLIB_TEST(max(abs(mat(running_means) -mat(running_means2))) < 1e-4); + DLIB_TEST_MSG(max(abs(mat(running_variances) -mat(running_variances2))) < 1e-4, + mean(mat(running_variances)) + << "\n" << mean(mat(running_variances2)) + << "\n" << max(abs(mat(running_variances) -mat(running_variances2))) + << "\n" << mean(abs(mat(running_variances) -mat(running_variances2))) + ); + + + // now check that the gradients match as well + resizable_tensor gradient_input; + resizable_tensor src_grad, gamma_grad, beta_grad; + resizable_tensor src_grad2, gamma_grad2, beta_grad2; + gradient_input.copy_size(dest); + src_grad.copy_size(src); src_grad = 0; src_grad2 = src_grad; + gamma_grad.copy_size(gamma); gamma_grad = 0; gamma_grad2 = gamma_grad; + beta_grad.copy_size(beta); beta_grad = 0; beta_grad2 = beta_grad; + rnd.fill_uniform(gradient_input); + + + cpu::batch_normalize_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); + cuda::batch_normalize_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, invstds, src, gamma, src_grad2, gamma_grad2, beta_grad2); + + dlog << LINFO << "src_grad error: " << max(abs(mat(src_grad)-mat(src_grad2))); + dlog << LINFO << "gamma_grad error: " << max(abs(mat(gamma_grad)-mat(gamma_grad2))); + dlog << LINFO << "beta_grad error: " << max(abs(mat(beta_grad)-mat(beta_grad2))); + DLIB_TEST(max(abs(mat(src_grad)-mat(src_grad2))) < 1e-4); + DLIB_TEST(max(abs(mat(gamma_grad)-mat(gamma_grad2))) < 1e-4); + DLIB_TEST(max(abs(mat(beta_grad)-mat(beta_grad2))) < 1e-4); + } + + void compare_bn_conv_gpu_and_cpu() + { + print_spinner(); + resizable_tensor dest, dest2; + resizable_tensor means, means2; + resizable_tensor invstds, invstds2; + resizable_tensor running_means, running_means2; + resizable_tensor running_variances, running_variances2; + resizable_tensor src(2,8,10,9); + resizable_tensor gamma(1,8); + resizable_tensor beta(1,8); + gamma = 2; + beta = 3; + tt::tensor_rand rnd; + rnd.fill_uniform(src); + + cpu::batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest,means,invstds,1,running_means,running_variances, src, gamma, beta); + cuda::batch_normalize_conv(DEFAULT_BATCH_NORM_EPS,dest2,means2,invstds2,1,running_means2,running_variances2, src, gamma, beta); + + dlog << LINFO << "dest error: "<< max(abs(mat(dest) -mat(dest2))); + dlog << LINFO << "means error: "<< max(abs(mat(means) -mat(means2))); + dlog << LINFO << "invstds error: "<< max(abs(mat(invstds) -mat(invstds2))); + dlog << LINFO << "running_means error: "<< max(abs(mat(running_means) -mat(running_means2))); + dlog << LINFO << "running_variances error: "<< max(abs(mat(running_variances) -mat(running_variances2))); + + DLIB_TEST(max(abs(mat(dest) -mat(dest2))) < 1e-4); + DLIB_TEST(max(abs(mat(means) -mat(means2))) < 1e-4); + DLIB_TEST(max(abs(mat(invstds) -mat(invstds2))) < 1e-4); + DLIB_TEST(max(abs(mat(running_means) -mat(running_means2))) < 1e-4); + DLIB_TEST(max(abs(mat(running_variances) -mat(running_variances2))) < 1e-4); + + resizable_tensor gradient_input; + resizable_tensor src_grad, gamma_grad, beta_grad; + resizable_tensor src_grad2, gamma_grad2, beta_grad2; + gradient_input.copy_size(dest); + src_grad.copy_size(src); src_grad = 0; src_grad2 = src_grad; + gamma_grad.copy_size(gamma); gamma_grad = 0; gamma_grad2 = gamma_grad; + beta_grad.copy_size(beta); beta_grad = 0; beta_grad2 = beta_grad; + rnd.fill_uniform(gradient_input); + + + cpu::batch_normalize_conv_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); + cuda::batch_normalize_conv_gradient(DEFAULT_BATCH_NORM_EPS,gradient_input, means, invstds, src, gamma, src_grad2, gamma_grad2, beta_grad2); + + dlog << LINFO << "src_grad error: " << max(abs(mat(src_grad)-mat(src_grad2))); + dlog << LINFO << "gamma_grad error: " << max(abs(mat(gamma_grad)-mat(gamma_grad2))); + dlog << LINFO << "beta_grad error: " << max(abs(mat(beta_grad)-mat(beta_grad2))); + DLIB_TEST(max(abs(mat(src_grad)-mat(src_grad2))) < 1e-4); + DLIB_TEST(max(abs(mat(gamma_grad)-mat(gamma_grad2))) < 1e-4); + DLIB_TEST(max(abs(mat(beta_grad)-mat(beta_grad2))) < 1e-4); + } + + + void test_more_ops2() + { + dlib::rand rnd; + tt::tensor_rand trand; + + for (int iter = 0; iter < 100; ++iter) + { + print_spinner(); + resizable_tensor dest1, dest2, src1, src2; + src1.set_size(rnd.get_random_32bit_number()%30+1, + rnd.get_random_32bit_number()%30+1, + rnd.get_random_32bit_number()%30+1, + rnd.get_random_32bit_number()%30+1); + dest1.copy_size(src1); + dest2.copy_size(src1); + src2.set_size(1,src1.k(),1,1); + + trand.fill_uniform(dest1); + trand.fill_uniform(dest2); + trand.fill_uniform(src1); + trand.fill_uniform(src2); + + cpu::multiply_conv(false, dest1, src1, src2); + cuda::multiply_conv(false, dest2, src1, src2); + DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5); + cpu::multiply_conv(true, dest1, src1, src2); + cuda::multiply_conv(true, dest2, src1, src2); + DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5); + + + // now try it using the other mode of multiply_conv + src2.copy_size(src1); + dest1.set_size(1,src1.k(),1,1); + dest2.set_size(1,src1.k(),1,1); + trand.fill_uniform(dest1); + trand.fill_uniform(dest2); + trand.fill_uniform(src1); + trand.fill_uniform(src2); + cpu::multiply_conv(false, dest1, src1, src2); + cuda::multiply_conv(false, dest2, src1, src2); + float scale = max(abs(mat(dest1))); + float scalem = mean(abs(mat(dest1))); + DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2)))/scale < 1e-4 , max(abs(mat(dest1)-mat(dest2)))/scale); + DLIB_TEST_MSG(mean(abs(mat(dest1)-mat(dest2)))/scalem < 1e-5 , mean(abs(mat(dest1)-mat(dest2)))/scalem); + matrix prevd2 = mat(dest2); + cpu::multiply_conv(false, dest1, src1, src2); + cuda::multiply_conv(true, dest2, src1, src2); + scale = max(abs(mat(dest1))); + scalem = mean(abs(mat(dest1))); + DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2)+prevd2))/scale < 1e-4 , max(abs(mat(dest1)-mat(dest2)+prevd2))/scale); + DLIB_TEST_MSG(mean(abs(mat(dest1)-mat(dest2)+prevd2))/scalem < 1e-5 , mean(abs(mat(dest1)-mat(dest2)+prevd2))/scalem); + } + + for (int iter = 0; iter < 100; ++iter) + { + print_spinner(); + resizable_tensor dest1, dest2, src, A, B; + src.set_size(rnd.get_random_32bit_number()%30+1, + rnd.get_random_32bit_number()%30+1, + rnd.get_random_32bit_number()%30+1, + rnd.get_random_32bit_number()%30+1); + dest1.copy_size(src); + dest2.copy_size(src); + A.set_size(1,src.k(),1,1); + B.set_size(1,src.k(),1,1); + + trand.fill_uniform(dest1); + trand.fill_uniform(dest2); + trand.fill_uniform(src); + trand.fill_uniform(A); + trand.fill_uniform(B); + + cpu::affine_transform_conv(dest1, src, A, B); + cuda::affine_transform_conv(dest2, src, A, B); + DLIB_TEST(max(abs(mat(dest1)-mat(dest2))) < 1e-5); + } + + for (int iter = 0; iter < 100; ++iter) + { + print_spinner(); + resizable_tensor dest1, dest2, g; + g.set_size(rnd.get_random_32bit_number()%30+1, + rnd.get_random_32bit_number()%30+1, + rnd.get_random_32bit_number()%30+1, + rnd.get_random_32bit_number()%30+1); + dest1.set_size(1,g.k(),1,1); + dest2.set_size(1,g.k(),1,1); + + trand.fill_uniform(dest1); + trand.fill_uniform(dest2); + trand.fill_uniform(g); + + cpu::assign_conv_bias_gradient(dest1, g); + cuda::assign_conv_bias_gradient(dest2, g); + const float scale = max(abs(mat(dest1))); + const float scalem = mean(abs(mat(dest1))); + DLIB_TEST_MSG(max(abs(mat(dest1)-mat(dest2)))/scale < 1e-4 , max(abs(mat(dest1)-mat(dest2)))/scale); + DLIB_TEST_MSG(mean(abs(mat(dest1)-mat(dest2)))/scalem < 1e-5 , mean(abs(mat(dest1)-mat(dest2)))/scalem); + } + + } + +#endif // DLIB_USE_CUDA + +// ---------------------------------------------------------------------------------------- + + void test_max_pool( + const int window_height, + const int window_width, + const int stride_y, + const int stride_x, + const int padding_y, + const int padding_x + ) + { + print_spinner(); + resizable_tensor A, B, gradient_input; + A.set_size(4,5,16,7); + B.copy_size(A); + gradient_input.copy_size(A); + + tt::tensor_rand rnd; + rnd.fill_gaussian(A,0,1); + rnd.fill_gaussian(B,0,1); + rnd.fill_gaussian(gradient_input,0,1); + + + tt::pooling mp; + + mp.setup_max_pooling(window_height,window_width,stride_y,stride_x,padding_y,padding_x); + mp(A, B); + + // make sure max pooling does what it's spec says it should. + DLIB_TEST( A.num_samples() == B.num_samples()); + DLIB_TEST( A.k() == B.k()); + + DLIB_TEST( A.nr() == 1+(B.nr()+2*padding_y-window_height)/stride_y); + DLIB_TEST( A.nc() == 1+(B.nc()+2*padding_x-window_width)/stride_x); + + const long x_offset = window_width/2 - padding_x; + const long y_offset = window_height/2 - padding_y; + for (long s = 0; s < A.num_samples(); ++s) + { + for (long k = 0; k < A.k(); ++k) + { + for (long r = 0; r < A.nr(); ++r) + { + for (long c = 0; c < A.nc(); ++c) + { + DLIB_TEST_MSG(image_plane(A,s,k)(r,c) == max(subm_clipped(image_plane(B,s,k), + centered_rect(c*stride_x+x_offset, + r*stride_y+y_offset, + window_width, + window_height))), + "padding: "<< padding_x << " " << padding_y + << " window size: " << window_width << " " << window_height + << " stride: " << stride_x << " " << stride_y + ); + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_avg_pool( + const int window_height, + const int window_width, + const int stride_y, + const int stride_x, + const int padding_y, + const int padding_x + ) + { + print_spinner(); + resizable_tensor A, B, gradient_input; + A.set_size(4,5,16,7); + B.copy_size(A); + gradient_input.copy_size(A); + + tt::tensor_rand rnd; + rnd.fill_gaussian(A,0,1); + rnd.fill_gaussian(B,0,1); + rnd.fill_gaussian(gradient_input,0,1); + + + tt::pooling mp; + + mp.setup_avg_pooling(window_height,window_width,stride_y,stride_x,padding_y,padding_x); + mp(A, B); + + // make sure avg pooling does what it's spec says it should. + DLIB_TEST( A.num_samples() == B.num_samples()); + DLIB_TEST( A.k() == B.k()); + DLIB_TEST( A.nr() == 1+(B.nr()+2*padding_y-window_height)/stride_y); + DLIB_TEST( A.nc() == 1+(B.nc()+2*padding_x-window_width)/stride_x); + + const long x_offset = window_width/2 - padding_x; + const long y_offset = window_height/2 - padding_y; + for (long s = 0; s < A.num_samples(); ++s) + { + for (long k = 0; k < A.k(); ++k) + { + for (long r = 0; r < A.nr(); ++r) + { + for (long c = 0; c < A.nc(); ++c) + { + float expected = mean(subm_clipped(image_plane(B,s,k), + centered_rect(c*stride_x+x_offset, + r*stride_y+y_offset, + window_width, + window_height))); + float err = abs(image_plane(A,s,k)(r,c) - expected); + DLIB_TEST_MSG(err < 1e-5, err << " " << expected << " " << image_plane(A,s,k)(r,c)); + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_layers() + { + { + print_spinner(); + extract_<0,2,2,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + extract_<3,2,1,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + extract_<0,2,1,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + upsample_<1,1> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + upsample_<2,1> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + upsample_<2,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + upsample_<3,3> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + l2normalize_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + multiply_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + max_pool_<3,3,1,1> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + avg_pool_<3,3,1,1> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + affine_ l(CONV_MODE); + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + affine_ l(FC_MODE); + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + bn_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + bn_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + cont_<3,3,3,2,2,0,0> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + cont_<3,3,3,2,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + cont_<3,3,3,1,1> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + cont_<3,3,3,1,1,0,0> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + cont_<3,2,2,2,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + con_<3,2,2,2,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + con_<3,3,3,1,1>l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + con_<3,3,2,1,1> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + con_<2,1,1,1,1> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + con_<3,0,2,2,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + con_<3,2,0,2,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + con_<3,0,0,2,2> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + fc_<1,FC_HAS_BIAS> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + fc_<5,FC_HAS_BIAS> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + fc_<4,FC_NO_BIAS> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + relu_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + prelu_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + sig_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + htan_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + softmax_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + softmax_all_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + } + +// ---------------------------------------------------------------------------------------- + + template using rcon = max_pool<2,2,2,2,relu>>>; + template using rfc = relu>>; + + void test_tagging( + ) + { + typedef loss_multiclass_log>>>>>>>>> net_type; + + net_type net; + net_type net2(num_fc_outputs(4)); + + DLIB_TEST(layer(net).num_computational_layers == 8); + DLIB_TEST(layer(net).num_computational_layers == 8+3+3); + DLIB_TEST(layer(net).num_layers == 10); + DLIB_TEST(layer(net).num_layers == 10+3+3+1); + DLIB_TEST(&layer(net).get_output() == &layer(net).get_output()); + DLIB_TEST(&layer(net).get_output() != &layer(net).subnet().subnet().get_output()); + DLIB_TEST(net.subnet().subnet().subnet().layer_details().get_num_outputs() == 10); + DLIB_TEST(net2.subnet().subnet().subnet().layer_details().get_num_outputs() == 4); + } + +// ---------------------------------------------------------------------------------------- + + template < + int N, + template class BN, + int stride, + typename SUBNET + > + using block = BN>>>>; + + template < + template class,int,typename> class block, + int N, + templateclass BN, + typename SUBNET + > + using residual = add_prev1>>; + + template < + template class,int,typename> class block, + int N, + templateclass BN, + typename SUBNET + > + using residual_down = add_prev2>>>>>; + + + template using res = relu>; + template using ares = relu>; + template using res_down = relu>; + template using ares_down = relu>; + + template + using pres = prelu>>>>>>>; + + void test_visit_funcions() + { + using net_type2 = loss_multiclass_log> + >>>>>>>>>>>; + + net_type2 pnet; + + DLIB_TEST_MSG(pnet.num_layers == 131, pnet.num_layers); + DLIB_TEST_MSG(pnet.num_computational_layers == 109, pnet.num_computational_layers); + + std::vector hit(pnet.num_computational_layers, false); + size_t count = 0; + visit_layer_parameter_gradients(pnet, [&](size_t i, tensor& ){hit[i] = true; ++count; }); + for (auto x : hit) + DLIB_TEST(x); + DLIB_TEST(count == pnet.num_computational_layers); + + count = 0; + std::vector hit2(pnet.num_computational_layers, false); + visit_layer_parameters(pnet, [&](size_t i, tensor& ){hit2[i] = true; ++count; }); + for (auto x : hit2) + DLIB_TEST(x); + DLIB_TEST(count == pnet.num_computational_layers); + } + + float tensor_read_cpu(const tensor& t, long i, long k, long r, long c) + { + const float* p = t.host() + t.k() * t.nr() * t.nc() * i + + t.nr() * t.nc() * k + t.nc() * r + c; + return *p; + } + void test_copy_tensor_cpu() + { + using namespace dlib::tt; + print_spinner(); + resizable_tensor dest(10, 9, 7, 15); + resizable_tensor src1(10, 3, 7, 15); + resizable_tensor src2(10, 3, 7, 15); + resizable_tensor src3(10, 9, 7, 15); + tt::tensor_rand rnd; + rnd.fill_gaussian(dest); + rnd.fill_gaussian(src1); + rnd.fill_gaussian(src2); + rnd.fill_gaussian(src3); + + cpu::copy_tensor(false, dest, 0, src1, 0, src1.k()); //full copy src1->dest + cpu::copy_tensor(false, dest, src1.k(), src2, 0, src2.k()); //full copy src2->dest with offset of src1 + cpu::copy_tensor(false, dest, src1.k() + src2.k(), src3, 3, 3); //partial copy src3 into the rest place of dest + + + for (long i = 0; i < dest.num_samples(); ++i) + { + for (long k = 0; k < dest.k(); ++k) + { + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + float dest_value = tensor_read_cpu(dest, i, k, r, c); + // first part is from src1 + if (k < src1.k()) + { + float src_value = tensor_read_cpu(src1, i, k, r, c); + DLIB_TEST(src_value == dest_value); + } + // second part is from src2 + else if (k < src1.k() + src2.k()) + { + float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c); + DLIB_TEST(src_value == dest_value); + } + // third part is from src3 + else + { + float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c); + DLIB_TEST(src_value == dest_value); + } + } + } + } + } + } + void test_copy_tensor_add_to_cpu() + { + using namespace dlib::tt; + print_spinner(); + resizable_tensor dest(10, 9, 7, 15); + resizable_tensor src1(10, 3, 7, 15); + resizable_tensor src2(10, 3, 7, 15); + resizable_tensor src3(10, 9, 7, 15); + tt::tensor_rand rnd; + rnd.fill_gaussian(dest); + rnd.fill_gaussian(src1); + rnd.fill_gaussian(src2); + rnd.fill_gaussian(src3); + + const resizable_tensor old_dest = dest; + + cpu::copy_tensor(true, dest, 0, src1, 0, src1.k()); //full copy src1->dest + cpu::copy_tensor(true, dest, src1.k(), src2, 0, src2.k()); //full copy src2->dest with offset of src1 + cpu::copy_tensor(true, dest, src1.k() + src2.k(), src3, 3, 3); //partial copy src3 into the rest place of dest + + + for (long i = 0; i < dest.num_samples(); ++i) + { + for (long k = 0; k < dest.k(); ++k) + { + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + float old_dest_value = tensor_read_cpu(old_dest, i, k, r, c); + float dest_value = tensor_read_cpu(dest, i, k, r, c); + // first part is from src1 + if (k < src1.k()) + { + float src_value = tensor_read_cpu(src1, i, k, r, c)+old_dest_value; + DLIB_TEST(std::abs(src_value - dest_value) < 1e-6); + } + // second part is from src2 + else if (k < src1.k() + src2.k()) + { + float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c)+old_dest_value; + DLIB_TEST(std::abs(src_value - dest_value) < 1e-6); + } + // third part is from src3 + else + { + float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c)+old_dest_value; + DLIB_TEST(std::abs(src_value - dest_value) < 1e-6); + } + } + } + } + } + } +#ifdef DLIB_USE_CUDA + void test_copy_tensor_gpu() + { + using namespace dlib::tt; + print_spinner(); + resizable_tensor dest(10, 9, 7, 15); + resizable_tensor src1(10, 3, 7, 15); + resizable_tensor src2(10, 3, 7, 15); + resizable_tensor src3(10, 9, 7, 15); + tt::tensor_rand rnd; + rnd.fill_gaussian(dest); + rnd.fill_gaussian(src1); + rnd.fill_gaussian(src2); + rnd.fill_gaussian(src3); + cuda::copy_tensor(false, dest, 0, src1, 0, src1.k()); //full copy src1->dest + cuda::copy_tensor(false, dest, src1.k(), src2, 0, src2.k()); //full copy src2->dest with offset of src1 + cuda::copy_tensor(false, dest, src1.k() + src2.k(), src3, 3, 3); //partial copy src3 into the rest place of dest + + + for (long i = 0; i < dest.num_samples(); ++i) + { + for (long k = 0; k < dest.k(); ++k) + { + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + float dest_value = tensor_read_cpu(dest, i, k, r, c); + // first part is from src1 + if (k < src1.k()) + { + float src_value = tensor_read_cpu(src1, i, k, r, c); + DLIB_TEST(src_value == dest_value); + } + // second part is from src2 + else if (k < src1.k() + src2.k()) + { + float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c); + DLIB_TEST(src_value == dest_value); + } + // third part is from src3 + else + { + float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c); + DLIB_TEST(src_value == dest_value); + } + } + } + } + } + } + void test_copy_tensor_add_to_gpu() + { + using namespace dlib::tt; + print_spinner(); + resizable_tensor dest(10, 9, 7, 15); + resizable_tensor src1(10, 3, 7, 15); + resizable_tensor src2(10, 3, 7, 15); + resizable_tensor src3(10, 9, 7, 15); + tt::tensor_rand rnd; + rnd.fill_gaussian(dest); + rnd.fill_gaussian(src1); + rnd.fill_gaussian(src2); + rnd.fill_gaussian(src3); + + const resizable_tensor old_dest = dest; + + cuda::copy_tensor(true, dest, 0, src1, 0, src1.k()); //full copy src1->dest + cuda::copy_tensor(true, dest, src1.k(), src2, 0, src2.k()); //full copy src2->dest with offset of src1 + cuda::copy_tensor(true, dest, src1.k() + src2.k(), src3, 3, 3); //partial copy src3 into the rest place of dest + + + for (long i = 0; i < dest.num_samples(); ++i) + { + for (long k = 0; k < dest.k(); ++k) + { + for (long r = 0; r < dest.nr(); ++r) + { + for (long c = 0; c < dest.nc(); ++c) + { + float old_dest_value = tensor_read_cpu(old_dest, i, k, r, c); + float dest_value = tensor_read_cpu(dest, i, k, r, c); + // first part is from src1 + if (k < src1.k()) + { + float src_value = tensor_read_cpu(src1, i, k, r, c)+old_dest_value; + DLIB_TEST_MSG(std::abs(src_value - dest_value) < 1e-6, std::abs(src_value - dest_value)); + } + // second part is from src2 + else if (k < src1.k() + src2.k()) + { + float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c)+old_dest_value; + DLIB_TEST(std::abs(src_value - dest_value) < 1e-6); + } + // third part is from src3 + else + { + float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c)+old_dest_value; + DLIB_TEST(std::abs(src_value - dest_value) < 1e-6); + } + } + } + } + } + } +#endif//DLIB_USE_CUDA + + template using concat_block1 = con<5,1,1,1,1,SUBNET>; + template using concat_block2 = con<8,3,3,1,1,SUBNET>; + template using concat_block3 = max_pool<3,3,1,1,SUBNET>; + template using concat_incept = inception3; + + void test_concat() + { + using namespace dlib::tt; + print_spinner(); + + using net_type = concat_incept>>; + + resizable_tensor data(10, 1, 111, 222); + tt::tensor_rand rnd; + rnd.fill_gaussian(data); + + net_type net; + + + auto& out = net.forward(data); + + auto& b1o = layer(net).get_output(); + auto& b2o = layer(net).get_output(); + auto& b3o = layer(net).get_output(); + + resizable_tensor dest(10, 14, 111, 222); + copy_tensor(false, dest, 0, b1o, 0, b1o.k()); + copy_tensor(false, dest, b1o.k(), b2o, 0, b2o.k()); + copy_tensor(false, dest, b1o.k() + b2o.k(), b3o, 0, b3o.k()); + + DLIB_TEST(dest.size() == out.size()); + int error = memcmp(dest.host(), out.host(), dest.size()); + DLIB_TEST(error == 0); + + resizable_tensor gr(10, 14, 111, 222); + rnd.fill_gaussian(gr); + + resizable_tensor params; + net.layer_details().backward(gr, net, params); + + auto& b1g = layer(net).subnet().get_gradient_input(); + auto& b2g = layer(net).subnet().get_gradient_input(); + auto& b3g = layer(net).subnet().get_gradient_input(); + + resizable_tensor g1(10, 5, 111, 222); + resizable_tensor g2(10, 8, 111, 222); + resizable_tensor g3(10, 1, 111, 222); + + copy_tensor(false, g1, 0, gr, 0, g1.k()); + copy_tensor(false, g2, 0, gr, g1.k(), g2.k()); + copy_tensor(false, g3, 0, gr, g1.k() + g2.k(), g3.k()); + DLIB_TEST(g1.size() == b1g.size()); + error = memcmp(g1.host(), b1g.host(), b1g.size()); + DLIB_TEST(error == 0); + DLIB_TEST(g2.size() == b2g.size()); + error = memcmp(g2.host(), b2g.host(), b2g.size()); + DLIB_TEST(error == 0); + DLIB_TEST(g3.size() == b3g.size()); + error = memcmp(g3.host(), b3g.host(), b3g.size()); + DLIB_TEST(error == 0); + } + +// ---------------------------------------------------------------------------------------- + + void test_simple_linear_regression() + { + const int num_samples = 1000; + ::std::vector> x(num_samples); + ::std::vector y(num_samples); + ::std::default_random_engine generator(16); + ::std::normal_distribution distribution(0,0.1); + const float true_intercept = 50.0; + const float true_slope = 10.0; + for ( int ii = 0; ii < num_samples; ++ii ) + { + const double val = static_cast(ii)/10; + matrix tmp(1,1); + tmp = val; + x[ii] = tmp; + y[ii] = (true_intercept + true_slope*static_cast(val) + distribution(generator)); + } + + using net_type = loss_mean_squared>>>; + net_type net; + layer<1>(net).layer_details().set_bias_learning_rate_multiplier(300); + sgd defsolver(0,0.9); + dnn_trainer trainer(net, defsolver); + trainer.set_learning_rate(1e-5); + trainer.set_min_learning_rate(1e-6); + trainer.set_mini_batch_size(50); + trainer.set_max_num_epochs(170); + trainer.train(x, y); + + const float slope = layer<1>(net).layer_details().get_weights().host()[0]; + const float slope_error = abs(true_slope - slope); + const float intercept = layer<1>(net).layer_details().get_biases().host()[0]; + const float intercept_error = abs(true_intercept - intercept); + const float eps_slope = 0.05, eps_intercept = 0.1; + + DLIB_TEST_MSG(slope_error <= eps_slope, + "Expected slope = " << true_slope << " Estimated slope = " << slope << " Error limit = " << eps_slope); + DLIB_TEST_MSG(intercept_error <= eps_intercept, + "Expected intercept = " << true_intercept << " Estimated intercept = " << intercept << " Error limit = " << eps_intercept); + + } + +// ---------------------------------------------------------------------------------------- + + void test_simple_linear_regression_eil() + { + print_spinner(); + const int num_samples = 1000; + ::std::vector> x(num_samples); + ::std::vector y(num_samples); + ::std::default_random_engine generator(16); + ::std::normal_distribution distribution(0,0.0001); + const float true_intercept = 50.0; + const float true_slope = 10.0; + for ( int ii = 0; ii < num_samples; ++ii ) + { + const double val = static_cast(ii)/10; + matrix tmp(1,1); + tmp = val; + x[ii] = tmp; + y[ii] = (true_intercept + true_slope*static_cast(val) + distribution(generator)); + } + + using net_type = loss_epsilon_insensitive>>>; + net_type net(0.01); + layer<1>(net).layer_details().set_bias_learning_rate_multiplier(300); + sgd defsolver(0,0.9); + dnn_trainer trainer(net, defsolver); + trainer.set_learning_rate(1e-5); + trainer.set_min_learning_rate(1e-8); + trainer.set_mini_batch_size(50); + trainer.set_max_num_epochs(570); + trainer.train(x, y); + + const float slope = layer<1>(net).layer_details().get_weights().host()[0]; + const float slope_error = abs(true_slope - slope); + const float intercept = layer<1>(net).layer_details().get_biases().host()[0]; + const float intercept_error = abs(true_intercept - intercept); + const float eps_slope = 0.01, eps_intercept = 0.1; + + dlog << LINFO << "slope_error: "<< slope_error; + dlog << LINFO << "intercept_error: "<< intercept_error; + DLIB_TEST_MSG(slope_error <= eps_slope, + "Expected slope = " << true_slope << " Estimated slope = " << slope << " Error limit = " << eps_slope); + DLIB_TEST_MSG(intercept_error <= eps_intercept, + "Expected intercept = " << true_intercept << " Estimated intercept = " << intercept << " Error limit = " << eps_intercept); + + } + +// ---------------------------------------------------------------------------------------- + + void test_simple_linear_regression_with_mult_prev() + { + srand(1234); + print_spinner(); + const int num_samples = 1000; + ::std::vector> x(num_samples); + ::std::vector y(num_samples); + const float true_slope = 2.0; + for ( int ii = 0; ii < num_samples; ++ii ) + { + const double val = static_cast(ii-500)/100; + matrix tmp(1,1); + tmp = val; + x[ii] = tmp; + y[ii] = ( true_slope*static_cast(val*val)); + } + + randomize_samples(x,y); + + using net_type = loss_mean_squared>>>>>>>; + net_type net; + sgd defsolver(0,0.9); + dnn_trainer trainer(net, defsolver); + trainer.set_learning_rate(1e-5); + trainer.set_min_learning_rate(1e-11); + trainer.set_mini_batch_size(50); + trainer.set_max_num_epochs(2000); + trainer.train(x, y); + + running_stats rs; + for (size_t i = 0; i < x.size(); ++i) + { + double val = y[i]; + double out = net(x[i]); + rs.add(std::abs(val-out)); + } + dlog << LINFO << "rs.mean(): " << rs.mean(); + dlog << LINFO << "rs.stddev(): " << rs.stddev(); + dlog << LINFO << "rs.max(): " << rs.max(); + DLIB_TEST(rs.mean() < 0.1); + } + +// ---------------------------------------------------------------------------------------- + + void test_multioutput_linear_regression() + { + const int num_outputs = 2; + const int num_samples = 1000; + ::std::vector> x(num_samples); + ::std::vector> y(num_samples); + ::std::default_random_engine generator(16); + ::std::normal_distribution distribution(0,0.1); + ::std::normal_distribution slope_distribution(10,5); + ::std::normal_distribution intercept_distribution(50,10); + ::std::vector true_intercepts(num_outputs); + ::std::vector true_slopes(num_outputs); + for ( int jj = 0; jj < num_outputs; ++jj ) + { + true_slopes[jj] = slope_distribution(generator); + true_intercepts[jj] = intercept_distribution(generator); + } + matrix ytmp(num_outputs, 1); + for ( int ii = 0; ii < num_samples; ++ii ) + { + const double val = static_cast(ii)/10; + matrix tmp(1,1); + tmp = val; + x[ii] = tmp; + for ( int jj = 0; jj < num_outputs; ++jj ) + ytmp(jj, 0) = (true_intercepts[jj] + true_slopes[jj]*static_cast(val) + distribution(generator)); + + y[ii] = ytmp; + } + + using net_type = loss_mean_squared_multioutput>>>; + net_type net; + layer<1>(net).layer_details().set_bias_learning_rate_multiplier(900); + sgd defsolver(0,0.9); + dnn_trainer trainer(net, defsolver); + trainer.set_learning_rate(1e-5); + trainer.set_min_learning_rate(1e-6); + trainer.set_mini_batch_size(50); + trainer.set_max_num_epochs(170); + trainer.train(x, y); + + float slope_error = 0.0; + float intercept_error = 0.0; + const float eps_slope = 0.05, eps_intercept = 0.1; + + for ( int jj = 0; jj < num_outputs; ++jj ) + { + slope_error += abs(layer<1>(net).layer_details().get_weights().host()[jj] - true_slopes[jj]); + intercept_error += abs(layer<1>(net).layer_details().get_biases().host()[jj] - true_intercepts[jj]); + } + + slope_error /= float(num_outputs); + intercept_error /= float(num_outputs); + + DLIB_TEST_MSG(slope_error <= eps_slope, + "Average absolute slope error = " << slope_error << " Error limit = " << eps_slope); + DLIB_TEST_MSG(intercept_error <= eps_intercept, + "Average absolute intercept error = " << intercept_error << " Error limit = " << eps_intercept); + + } + +// ---------------------------------------------------------------------------------------- + + void test_simple_autoencoder() + { + print_spinner(); + + srand(1234); + + const int output_width = 7; + const int output_height = 7; + const int num_samples = 100; + ::std::vector> x(num_samples); + + matrix tmp(output_width, output_height); + for (int i = 0; i < num_samples; ++i) + { + const int model = i % 4; + + for (int r = 0; r < output_height; ++r) + for (int c = 0; c < output_width; ++c) + switch (model) { + case 0: tmp(r, c) = r / output_height; break; + case 1: tmp(r, c) = c / output_width; break; + case 2: tmp(r, c) = 1.0 - r / output_height; break; + case 3: tmp(r, c) = 1.0 - c / output_width; break; + default: DLIB_TEST_MSG(false, "Invalid model: " << model << " (should be between 0 and 3)"); + } + + x[i] = tmp; + } + + using net_type = loss_mean_squared_per_pixel< + cont<1,output_height,output_width,2,2, + relu>>>>>; + net_type net; + + const auto autoencoder_error = [&x, &net, &output_height, &output_width]() + { + const auto y = net(x); + double error = 0.0; + for (size_t i = 0; i < x.size(); ++i) + for (int r = 0; r < output_height; ++r) + for (int c = 0; c < output_width; ++c) + error += fabs(y[i](r, c) - x[i](r, c)); + + return error / (x.size() * output_height * output_width); + }; + + // The autoencoder can't be very good before it's been trained + // (or at least the probability of the reconstruction error + // being small should be super low; in fact, the error ought to + // be much higher than 0.01, however since the initialization + // is random, putting the limit below too high could make the + // tests fail when other, unrelated tests are added into the + // sequence) + const double error_before = autoencoder_error(); + DLIB_TEST_MSG(error_before > 0.01, "Autoencoder error before training = " << error_before); + + // Make sure there's an information bottleneck, as intended + const auto& output2 = dlib::layer<2>(net).get_output(); + DLIB_TEST(output2.nr() == 1); + DLIB_TEST(output2.nc() == 1); + DLIB_TEST(output2.k() == 4); + + sgd defsolver(0,0.9); + dnn_trainer trainer(net, defsolver); + trainer.set_learning_rate(0.01); + trainer.set_max_num_epochs(1000); + trainer.train(x, x); + + // Now we should have learned everything there is to it + const double error_after = autoencoder_error(); + DLIB_TEST_MSG(error_after < 1e-6, "Autoencoder error after training = " << error_after); + } + +// ---------------------------------------------------------------------------------------- + + void test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task() + { + print_spinner(); + + constexpr uint16_t num_classes = 7; + constexpr uint16_t true_label = num_classes / 2; + + ::std::vector> x({ matrix({ 1 }) }); + ::std::vector> y({ matrix({ true_label }) }); + + using net_type = loss_multiclass_log_per_pixel>>>; + net_type net; + + dnn_trainer trainer(net, sgd(0,0)); + trainer.set_learning_rate(1e7); + trainer.set_max_num_epochs(1); + trainer.train(x, y); + + const tensor& learned_params = layer<1>(net).layer_details().get_layer_params(); + const float* learned_params_data = learned_params.host(); + + for (int is_bias = 0; is_bias <= 1; ++is_bias) { + for (uint16_t k = 0; k < num_classes; ++k) { + size_t index = k + is_bias * num_classes; + DLIB_TEST(index < learned_params.size()); + if (k == true_label) { + DLIB_TEST(learned_params_data[index] > 1e5); + } + else { + DLIB_TEST(learned_params_data[index] < -1e5); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_loss_multiclass_per_pixel_activations_on_trivial_single_pixel_task() + { + print_spinner(); + + constexpr int input_height = 35; + constexpr int input_width = 27; + constexpr int output_height = input_height; + constexpr int output_width = input_width; + constexpr int num_samples = 7; + constexpr int num_classes = 5; + + ::std::vector> x(num_samples); + ::std::vector> y(num_samples); + + matrix xtmp(input_height, input_width); + matrix ytmp(output_height, output_width); + + ::std::default_random_engine generator(16); + ::std::bernoulli_distribution coinflip(0.5); + + using filter_type = con>>; + + // Define a "truth" filter + filter_type truth_filter; + truth_filter(xtmp); // Set up the convolutional layer + + // Generate training data + for (int ii = 0; ii < num_samples; ++ii) { + // Generate random inputs x + for (int jj = 0; jj < input_height; ++jj) + for (int kk = 0; kk < input_width; ++kk) + xtmp(jj, kk) = coinflip(generator) ? 1.f : -1.f; + x[ii] = xtmp; + + // Generate target output y by applying the truth filter on x + const tensor& output = truth_filter(xtmp); + const float* const out_data = output.host(); + + const auto out_element = [&](int row, int column, int k) { + return out_data[(k * output.nr() + row) * output.nc() + column]; + }; + + for (int jj = 0; jj < output_height; ++jj) { + for (int kk = 0; kk < output_width; ++kk) { + uint16_t label = 0; + float max_value = out_element(jj, kk, 0); + for (long k = 1; k < num_classes; ++k) { + const float value = out_element(jj, kk, k); + if (value > max_value) { + label = static_cast(k); + max_value = value; + } + } + ytmp(jj, kk) = label; + } + } + y[ii] = ytmp; + } + + using net_type = loss_multiclass_log_per_pixel; + net_type net; + + dnn_trainer trainer(net, sgd(0,0)); + trainer.set_learning_rate(1e6); + trainer.set_max_num_epochs(1); + trainer.train(x, y); + + // Feed forward the training samples. + resizable_tensor temp_tensor; + net.subnet().to_tensor(&x[0], &x[0] + num_samples, temp_tensor); + net.subnet().forward(temp_tensor); + const dimpl::subnet_wrapper wsub(net.subnet()); + const tensor& output_tensor = wsub.get_output(); + const float* const out_data = output_tensor.host(); + + // Let's have a look at the activations before softmax. They should be pretty high + // (in terms of absolute value), because the learning task is trivial. + for (int ii = 0; ii < num_samples; ++ii) { + for (int jj = 0; jj < output_height; ++jj) { + for (int kk = 0; kk < output_width; ++kk) { + const uint16_t true_label = y[ii](jj, kk); + + for (long k = 0; k < num_classes; ++k) { + const size_t index = ((ii * output_tensor.k() + k) * output_tensor.nr() + jj) * output_tensor.nc() + kk; + DLIB_TEST(index < output_tensor.size()); + + if (k == true_label) { + DLIB_TEST(out_data[index] > 1e4); + } + else { + DLIB_TEST(out_data[index] < -1e4); + } + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_loss_multiclass_per_pixel_outputs_on_trivial_task() + { + print_spinner(); + + constexpr int input_height = 7; + constexpr int input_width = 5; + constexpr int output_height = input_height; + constexpr int output_width = input_width; + constexpr int num_samples = 7; + constexpr int num_classes = 5; + constexpr int filter_height = 3; + constexpr int filter_width = 3; + + ::std::vector> x(num_samples); + ::std::vector> y(num_samples); + + matrix xtmp(input_height, input_width); + matrix ytmp(output_height, output_width); + + ::std::default_random_engine generator(16); + ::std::bernoulli_distribution coinflip(0.5); + + using filter_type = con>>; + + // Define a "truth" filter + filter_type truth_filter; + truth_filter(xtmp); // Set up the convolutional layer + + // Generate training data + for (int ii = 0; ii < num_samples; ++ii) { + // Generate random inputs x + for (int jj = 0; jj < input_height; ++jj) + for (int kk = 0; kk < input_width; ++kk) + xtmp(jj, kk) = coinflip(generator) ? 1.f : -1.f; + x[ii] = xtmp; + + // Generate target output y by applying the truth filter on x + const tensor& output = truth_filter(xtmp); + const float* const out_data = output.host(); + + const auto out_element = [&](int row, int column, int k) { + return out_data[(k * output.nr() + row) * output.nc() + column]; + }; + + for (int jj = 0; jj < output_height; ++jj) { + for (int kk = 0; kk < output_width; ++kk) { + uint16_t label = 0; + float max_value = out_element(jj, kk, 0); + for (long k = 1; k < num_classes; ++k) { + const float value = out_element(jj, kk, k); + if (value > max_value) { + label = static_cast(k); + max_value = value; + } + } + ytmp(jj, kk) = label; + } + } + y[ii] = ytmp; + } + + using net_type = loss_multiclass_log_per_pixel; + net_type net; + + dnn_trainer trainer(net, sgd(0, 0.9)); + trainer.set_learning_rate(1); + trainer.set_max_num_epochs(2000); + trainer.train(x, y); + + // The learning task is separable, so the net should have no problem + // getting all the outputs right. + DLIB_TEST(net(x) == y); + } + +// ---------------------------------------------------------------------------------------- + + void test_loss_multiclass_per_pixel_with_noise_and_pixels_to_ignore() + { + // "Semantic segmentation" - see https://github.com/davisking/dlib/issues/288 + // Test learning when some pixels are to be ignored, etc. + + print_spinner(); + + constexpr int input_height = 5; + constexpr int input_width = 7; + constexpr int output_height = input_height; + constexpr int output_width = input_width; + const int num_samples = 1000; + const int num_classes = 6; + const double ignore_probability = 0.5; + const double noise_probability = 0.05; + + ::std::default_random_engine generator(16); + ::std::bernoulli_distribution ignore(ignore_probability); + ::std::bernoulli_distribution noise_occurrence(noise_probability); + ::std::uniform_int_distribution noisy_label(0, num_classes - 1); + + ::std::vector> x(num_samples); + ::std::vector> y(num_samples); + + ::std::vector truth_histogram(num_classes); + + matrix xtmp(input_height, input_width); + matrix ytmp(output_height, output_width); + + // The function to be learned. + const auto ground_truth = [num_classes](const matrix& x, int row, int column) { + double sum = 0.0; + const int first_column = std::max(0, column - 1); + const int last_column = std::min(static_cast(x.nc() - 1), column + 1); + for (int c = first_column; c <= last_column; ++c) { + sum += x(row, c); + } + DLIB_TEST(sum < num_classes); + return static_cast(sum); + }; + + for ( int ii = 0; ii < num_samples; ++ii ) { + for ( int jj = 0; jj < input_height; ++jj ) { + for ( int kk = 0; kk < input_width; ++kk ) { + // Generate numbers between 0 and 2. + double value = static_cast(ii + jj + kk) / 10.0; + value -= (static_cast(value) / 2) * 2; + DLIB_TEST(value >= 0.0 && value < 2.0); + xtmp(jj, kk) = value; + } + } + x[ii] = xtmp; + + for ( int jj = 0; jj < output_height; ++jj ) { + for ( int kk = 0; kk < output_width; ++kk ) { + uint16_t truth = ground_truth(x[ii], jj, kk); + DLIB_TEST(truth < num_classes); + ++truth_histogram[truth]; + if (ignore(generator)) { + ytmp(jj, kk) = loss_multiclass_log_per_pixel_::label_to_ignore; + } + else if (noise_occurrence(generator)) { + ytmp(jj, kk) = noisy_label(generator); + } + else { + ytmp(jj, kk) = truth; + } + } + } + + y[ii] = ytmp; + } + + const int num_total_elements = num_samples * output_height * output_width; + + { // Require a reasonably balanced truth histogram in order to make sure that a trivial classifier is not enough + const int required_min_histogram_value = static_cast(::std::ceil(num_total_elements / num_classes * 0.375)); + for (auto histogram_value : truth_histogram) { + DLIB_TEST_MSG(histogram_value >= required_min_histogram_value, + "Histogram value = " << histogram_value << ", required = " << required_min_histogram_value); + } + } + + using net_type = loss_multiclass_log_per_pixel>>>>; + net_type net; + sgd defsolver(0,0.9); + dnn_trainer trainer(net, defsolver); + trainer.set_learning_rate(0.1); + trainer.set_min_learning_rate(0.01); + trainer.set_mini_batch_size(50); + trainer.set_max_num_epochs(170); + trainer.train(x, y); + + const ::std::vector> predictions = net(x); + + int num_correct = 0; + + for ( int ii = 0; ii < num_samples; ++ii ) { + const matrix& prediction = predictions[ii]; + DLIB_TEST(prediction.nr() == output_height); + DLIB_TEST(prediction.nc() == output_width); + for ( int jj = 0; jj < output_height; ++jj ) + for ( int kk = 0; kk < output_width; ++kk ) + if ( prediction(jj, kk) == ground_truth(x[ii], jj, kk) ) + ++num_correct; + } + + // First some sanity checks. + const int num_correct_max = num_total_elements; + DLIB_TEST(num_correct_max == ::std::accumulate(truth_histogram.begin(), truth_histogram.end(), 0)); + DLIB_TEST_MSG(num_correct <= num_correct_max, + "Number of correctly classified elements = " << num_correct << ", max = " << num_correct_max); + + // This is the real test, verifying that we have actually learned something. + const int num_correct_required = static_cast(::std::ceil(0.9 * num_correct_max)); + DLIB_TEST_MSG(num_correct >= num_correct_required, + "Number of correctly classified elements = " << num_correct << ", required = " << num_correct_required); + } + +// ---------------------------------------------------------------------------------------- + + void test_loss_multiclass_per_pixel_weighted() + { + // Train with pixel-specific weights + + print_spinner(); + + constexpr int input_height = 5; + constexpr int input_width = 7; + constexpr int output_height = input_height; + constexpr int output_width = input_width; + const int num_samples = 1000; + const int num_classes = 6; + + ::std::default_random_engine generator(16); + ::std::uniform_real_distribution u01(0.0, 1.0); + ::std::uniform_int_distribution noisy_label(0, num_classes - 1); + + ::std::vector> x(num_samples); + ::std::vector> y(num_samples); + + matrix xtmp(input_height, input_width); + matrix ytmp(output_height, output_width); + + // Generate input data + for (int ii = 0; ii < num_samples; ++ii) { + for (int jj = 0; jj < input_height; ++jj) { + for (int kk = 0; kk < input_width; ++kk) { + xtmp(jj, kk) = u01(generator); + ytmp(jj, kk) = noisy_label(generator); + } + } + x[ii] = xtmp; + y[ii] = ytmp; + } + + using net_type = loss_multiclass_log_per_pixel_weighted>>>; + using weighted_label = loss_multiclass_log_per_pixel_weighted_::weighted_label; + + ::std::vector> y_weighted(num_samples); + + for (int weighted_class = 0; weighted_class < num_classes; ++weighted_class) { + + print_spinner(); + + // Assign weights + for (int ii = 0; ii < num_samples; ++ii) { + if (weighted_class == 0) { + y_weighted[ii].set_size(input_height, input_width); + } + for (int jj = 0; jj < input_height; ++jj) { + for (int kk = 0; kk < input_width; ++kk) { + const uint16_t label = y[ii](jj, kk); + const float weight + = label == weighted_class + ? 1.1f + : 0.9f; + y_weighted[ii](jj, kk) = weighted_label(label, weight); + } + } + } + + net_type net; + sgd defsolver(0,0.9); + dnn_trainer trainer(net, defsolver); + trainer.set_learning_rate(0.1); + trainer.set_min_learning_rate(0.01); + trainer.set_mini_batch_size(10); + trainer.set_max_num_epochs(10); + trainer.train(x, y_weighted); + + const ::std::vector> predictions = net(x); + + int num_weighted_class = 0; + int num_not_weighted_class = 0; + + for ( int ii = 0; ii < num_samples; ++ii ) { + const matrix& prediction = predictions[ii]; + DLIB_TEST(prediction.nr() == output_height); + DLIB_TEST(prediction.nc() == output_width); + for ( int jj = 0; jj < output_height; ++jj ) + for ( int kk = 0; kk < output_width; ++kk ) + if ( prediction(jj, kk) == weighted_class ) + ++num_weighted_class; + else + ++num_not_weighted_class; + } + + DLIB_TEST_MSG(num_weighted_class > num_not_weighted_class, + "The weighted class (" << weighted_class << ") does not dominate: " + << num_weighted_class << " <= " << num_not_weighted_class); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_tensor_resize_bilinear(long samps, long k, long nr, long nc, long onr, long onc) + { + resizable_tensor img(samps,k,nr,nc); + resizable_tensor out(samps,k,onr,onc); + resizable_tensor out2(samps,k,onr,onc); + + dlib::rand rnd; + for (int iter = 0; iter < 10; ++iter) + { + print_spinner(); + + const size_t idx = rnd.get_random_64bit_number()%img.size(); + + img = 1; + img.host()[idx] = 2; + cpu::resize_bilinear(out, img); +#ifdef DLIB_USE_CUDA + cuda::resize_bilinear(out2, img); + DLIB_TEST(max(abs(mat(out)-mat(out2))) < 1e-5); +#endif + + resizable_tensor gradient_input; + gradient_input.copy_size(out); + tt::tensor_rand rnd; + rnd.fill_uniform(gradient_input); + + const float h = 1e-2; + + img.host()[idx] = 2; + cpu::resize_bilinear(out, img); + float f1 = dot(out, gradient_input); + + img.host()[idx] = 2+h; + cpu::resize_bilinear(out, img); + float f2 = dot(out, gradient_input); + + const float numerical_grad = (f2-f1)/h; + dlog << LINFO << "numerical grad: " << numerical_grad; + + + resizable_tensor grad, grad2; + grad.copy_size(img); + grad = 0.1; + grad2.copy_size(img); + grad2 = 0.1; + + cpu::resize_bilinear_gradient(grad2, gradient_input); + dlog << LINFO << "analytic grad: "<< grad2.host()[idx]-0.1; + DLIB_TEST_MSG(std::abs(numerical_grad - grad2.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad2.host()[idx]+0.1) << " numerical_grad: " << numerical_grad); + +#ifdef DLIB_USE_CUDA + cuda::resize_bilinear_gradient(grad, gradient_input); + dlog << LINFO << "analytic grad: "<< grad.host()[idx]-0.1; + DLIB_TEST_MSG(std::abs(numerical_grad - grad.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad.host()[idx]+0.1) << " numerical_grad: " << numerical_grad); + DLIB_TEST(max(abs(mat(grad)-mat(grad2))) < 1e-5); +#endif + + } + + + // now test with strided/sub-window calls + alias_tensor aimg(samps, k, nr-2,nc-2); + alias_tensor aout(samps, k, onr-2,onc-2); + for (int iter = 0; iter < 10; ++iter) + { + print_spinner(); + + const size_t idx = rnd.get_random_64bit_number()%img.size(); + + img = 1; + img.host()[idx] = 2; + out = 9; + out2 = 9; + auto wout = aout(out, out.nc()*1+1); + auto wimg = aimg(img, img.nc()*1+1); + cpu::resize_bilinear(wout,out.nc(),out.nr()*out.nc(), wimg,img.nc(),img.nr()*img.nc()); +#ifdef DLIB_USE_CUDA + auto wout2 = aout(out2, out2.nc()*1+1); + cuda::resize_bilinear(wout2,out2.nc(),out2.nr()*out2.nc(), wimg,img.nc(),img.nr()*img.nc()); + DLIB_TEST(max(abs(mat(out)-mat(out2))) < 1e-5); +#endif + + + resizable_tensor gradient_input; + gradient_input.copy_size(out); + tt::tensor_rand rnd; + rnd.fill_uniform(gradient_input); + + const float h = 1e-2; + + img.host()[idx] = 2; + out = 0; + wout = aout(out, out.nc()*1+1); + wimg = aimg(img, img.nc()*1+1); + cpu::resize_bilinear(wout,out.nc(),out.nr()*out.nc(), wimg,img.nc(),img.nr()*img.nc()); + float f1 = dot(out, gradient_input); + + img.host()[idx] = 2+h; + out = 0; + cpu::resize_bilinear(wout,out.nc(),out.nr()*out.nc(), wimg,img.nc(),img.nr()*img.nc()); + float f2 = dot(out, gradient_input); + + const float numerical_grad = (f2-f1)/h; + dlog << LINFO << "numerical grad: " << numerical_grad; + + + resizable_tensor grad, grad2; + grad.copy_size(img); + grad = 0.1; + grad2.copy_size(img); + grad2 = 0.1; + + auto wgrad2 = aimg(grad2, grad2.nc()*1+1); + auto wgradient_input = aout(gradient_input, gradient_input.nc()*1+1); + cpu::resize_bilinear_gradient(wgrad2,grad2.nc(),grad2.nr()*grad2.nc(), wgradient_input,gradient_input.nc(),gradient_input.nr()*gradient_input.nc()); + dlog << LINFO << "analytic grad: "<< grad2.host()[idx]-0.1; + DLIB_TEST_MSG(std::abs(numerical_grad - grad2.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad2.host()[idx]+0.1) << " numerical_grad: " << numerical_grad); + +#ifdef DLIB_USE_CUDA + wgrad2 = aimg(grad, grad.nc()*1+1); + wgradient_input = aout(gradient_input, gradient_input.nc()*1+1); + cuda::resize_bilinear_gradient(wgrad2,grad.nc(),grad.nr()*grad.nc(), wgradient_input,gradient_input.nc(),gradient_input.nr()*gradient_input.nc()); + dlog << LINFO << "analytic grad: "<< grad.host()[idx]-0.1; + DLIB_TEST_MSG(std::abs(numerical_grad - grad.host()[idx]+0.1) < 1e-2, std::abs(numerical_grad - grad.host()[idx]+0.1) << " numerical_grad: " << numerical_grad); + DLIB_TEST_MSG(max(abs(mat(grad)-mat(grad2))) < 1e-5, max(abs(mat(grad)-mat(grad2)))); +#endif + + + } + } + + + void test_serialization() + { + print_spinner(); + + using net_type = loss_mean_squared>>>; + net_type net, net2; + + std::ostringstream out; + serialize(net, out); + const std::string serialized = out.str(); + std::istringstream in(serialized); + dlib::deserialize(net2, in); + } + +// ---------------------------------------------------------------------------------------- + + void test_loss_dot() + { + print_spinner(); + + std::vector> samples; + std::vector> labels; + + const matrix proj = matrix_cast(randm(2,3)); + for (int i = 0; i < 128; ++i) + { + // The task is going to be to learn the matrix proj. So we make our + // training data thusly: + matrix x = matrix_cast(randm(3,1)); + matrix y = normalize(proj*x); + samples.push_back(x); + labels.push_back(y); + } + + using net_type = loss_dot< + l2normalize> + >>>; + + net_type net; + dnn_trainer trainer(net, sgd(1e-4, 0.9)); + trainer.set_learning_rate(0.01); + trainer.set_min_learning_rate(0.0000001); + trainer.set_mini_batch_size(128); + trainer.set_max_num_epochs(50000); + trainer.train(samples, labels); + + + for (size_t i = 0; i < samples.size(); ++i) + { + DLIB_TEST(std::abs(1-dot(net(samples[i]),labels[i])) < 0.001); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_loss_multimulticlass_log() + { + print_spinner(); + std::map> all_labels; + all_labels["c1"] = {"a", "b", "c"}; + all_labels["c2"] = {"d", "e", "f"}; + + // make training data + std::vector> samples; + std::vector> labels; + for (int i = 0; i < 3; ++i) + { + for (int j = 0; j < 3; ++j) + { + matrix samp(2,3); + samp = 0; + samp(0,i) = 1; + samp(1,j) = 1; + samples.push_back(samp); + + std::map l; + if (i == 0) l["c1"] = "a"; + if (i == 1) l["c1"] = "b"; + if (i == 2) l["c1"] = "c"; + if (j == 0) l["c2"] = "d"; + if (j == 1) l["c2"] = "e"; + if (j == 2) l["c2"] = "f"; + labels.push_back(l); + } + } + + using net_type = loss_multimulticlass_log< + fc<1, + input> + >>; + + net_type net(all_labels); + net.subnet().layer_details().set_num_outputs(net.loss_details().number_of_labels()); + + dnn_trainer trainer(net, sgd(0.1)); + trainer.set_learning_rate(0.1); + trainer.set_min_learning_rate(0.00001); + trainer.set_iterations_without_progress_threshold(500); + + trainer.train(samples, labels); + + auto predicted_labels = net(samples); + + // make sure the network predicts the right labels + for (size_t i = 0; i < samples.size(); ++i) + { + DLIB_TEST(predicted_labels[i]["c1"] == labels[i]["c1"]); + DLIB_TEST(predicted_labels[i]["c2"] == labels[i]["c2"]); + } + + } + +// ---------------------------------------------------------------------------------------- + + class dnn_tester : public tester + { + public: + dnn_tester ( + ) : + tester ("test_dnn", + "Runs tests on the deep neural network tools.") + {} + + void run_tests ( + ) + { + // make the tests repeatable + srand(1234); + + test_tagging(); +#ifdef DLIB_USE_CUDA + test_affine_rect(); + test_conv(); + test_more_ops2(); + test_more_ops(1,1); + test_more_ops(3,4); + test_more_ops(4,3); + test_more_ops(4,1); + test_more_ops(1,4); + test_more_ops(10000,4); + compare_bn_gpu_and_cpu(); + compare_bn_conv_gpu_and_cpu(); + test_add(); + test_multiply_zero_padded(); + compare_adam(); + test_copy_tensor_gpu(); + test_copy_tensor_add_to_gpu(); + test_scale_channels(); +#endif + test_tensor_resize_bilinear(2, 3, 6,6, 11, 11); + test_tensor_resize_bilinear(2, 3, 6,6, 3, 4); + test_tensor_resize_bilinear(2, 3, 5,6, 12, 21); + test_max_pool(1,1,2,3,0,0); + test_max_pool(3,3,1,1,0,0); + test_max_pool(3,3,2,2,0,0); + test_max_pool(2,2,2,2,0,0); + test_max_pool(4,5,3,1,0,0); + test_avg_pool(1,1,2,3,0,0); + test_avg_pool(3,3,1,1,0,0); + test_avg_pool(3,3,2,2,0,0); + test_avg_pool(2,2,2,2,0,0); + test_avg_pool(4,5,3,1,0,0); + test_avg_pool(4,4,2,2,0,0); + test_avg_pool(4,5,40,50,0,0); + test_max_pool(2,2,2,3,1,1); + test_max_pool(3,3,1,1,1,1); + test_max_pool(3,3,2,2,2,1); + test_max_pool(2,2,2,2,1,0); + test_max_pool(4,5,3,1,2,3); + test_avg_pool(1,1,2,3,0,0); + test_avg_pool(3,3,1,1,1,2); + test_avg_pool(3,3,2,2,2,1); + test_avg_pool(2,2,2,2,1,0); + test_avg_pool(4,5,3,1,2,4); + test_avg_pool(4,4,2,2,1,3); + test_avg_pool(4,5,40,50,0,1); + test_tanh(); + test_softmax(); + test_softmax_all(); + test_sigmoid(); + test_batch_normalize(); + test_batch_normalize_conv(); + test_basic_tensor_ops(); + test_layers(); + test_visit_funcions(); + test_copy_tensor_cpu(); + test_copy_tensor_add_to_cpu(); + test_concat(); + test_simple_linear_regression(); + test_simple_linear_regression_eil(); + test_simple_linear_regression_with_mult_prev(); + test_multioutput_linear_regression(); + test_simple_autoencoder(); + test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task(); + test_loss_multiclass_per_pixel_activations_on_trivial_single_pixel_task(); + test_loss_multiclass_per_pixel_outputs_on_trivial_task(); + test_loss_multiclass_per_pixel_with_noise_and_pixels_to_ignore(); + test_loss_multiclass_per_pixel_weighted(); + test_serialization(); + test_loss_dot(); + test_loss_multimulticlass_log(); + } + + void perform_test() + { + dlog << LINFO << "NOW RUNNING TESTS WITH set_dnn_prefer_fastest_algorithms()"; + set_dnn_prefer_fastest_algorithms(); + run_tests(); + + dlog << LINFO << "NOW RUNNING TESTS WITH set_dnn_prefer_smallest_algorithms()"; + set_dnn_prefer_smallest_algorithms(); + run_tests(); + } + } a; +} + +#endif // __INTELLISENSE__ + diff --git a/ml/dlib/dlib/test/ekm_and_lisf.cpp b/ml/dlib/dlib/test/ekm_and_lisf.cpp new file mode 100644 index 000000000..b6f410177 --- /dev/null +++ b/ml/dlib/dlib/test/ekm_and_lisf.cpp @@ -0,0 +1,306 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.ekm_and_lisf"); + + + class empirical_kernel_map_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + empirical_kernel_map_tester ( + ) : + tester ( + "test_ekm_and_lisf", // the command line argument name for this test + "Run tests on the empirical_kernel_map and linearly_independent_subset_finder objects.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + thetime = time(0); + } + + time_t thetime; + dlib::rand rnd; + + template + void validate ( + const T& ekm_small, + const T& ekm_big + ) + { + matrix tmat; + projection_function proj; + + ekm_small.get_transformation_to(ekm_big, tmat, proj); + DLIB_TEST(tmat.nr() == ekm_big.out_vector_size()); + DLIB_TEST(tmat.nc() == ekm_small.out_vector_size()); + DLIB_TEST((unsigned long)proj.basis_vectors.size() == ekm_big.basis_size() - ekm_small.basis_size()); + for (unsigned long i = 0; i < 6; ++i) + { + const typename T::sample_type temp = randm(4,1,rnd); + DLIB_TEST(length(ekm_big.project(temp) - (tmat*ekm_small.project(temp) + proj(temp))) < 1e-10); + } + } + + + void test_transformation_stuff() + { + typedef matrix sample_type; + typedef radial_basis_kernel kernel_type; + const kernel_type kern(1); + + + for (unsigned long n = 1; n < 6; ++n) + { + print_spinner(); + for (unsigned long extra = 1; extra < 10; ++extra) + { + std::vector samps_small, samps_big; + linearly_independent_subset_finder lisf_small(kern, 1000); + linearly_independent_subset_finder lisf_big(kern, 1000); + for (unsigned long i = 0; i < n; ++i) + { + samps_small.push_back(randm(4,1,rnd)); + samps_big.push_back(samps_small.back()); + lisf_big.add(samps_small.back()); + lisf_small.add(samps_small.back()); + } + for (unsigned long i = 0; i < extra; ++i) + { + samps_big.push_back(randm(4,1,rnd)); + lisf_big.add(samps_big.back()); + } + + + // test no lisf + { + empirical_kernel_map ekm_small, ekm_big; + ekm_small.load(kern, samps_small); + ekm_big.load(kern, samps_big); + + validate(ekm_small, ekm_big); + } + + // test with lisf + { + empirical_kernel_map ekm_small, ekm_big; + ekm_small.load(lisf_small); + ekm_big.load(lisf_big); + + validate(ekm_small, ekm_big); + } + + // test with partly lisf + { + empirical_kernel_map ekm_small, ekm_big; + ekm_small.load(kern, samps_small); + ekm_big.load(lisf_big); + + validate(ekm_small, ekm_big); + } + + // test with partly lisf + { + empirical_kernel_map ekm_small, ekm_big; + ekm_small.load(lisf_small); + ekm_big.load(kern, samps_big); + + validate(ekm_small, ekm_big); + } + + } + } + + + // test what happens if the bigger ekm only has repeated basis vectors + { + empirical_kernel_map ekm_big, ekm_small; + std::vector samps_big, samps_small; + + sample_type temp = randm(4,1,rnd); + + samps_small.push_back(temp); + samps_big.push_back(temp); + samps_big.push_back(temp); + + ekm_big.load(kern, samps_big); + ekm_small.load(kern, samps_small); + + validate(ekm_small, ekm_big); + + } + { + empirical_kernel_map ekm_big, ekm_small; + linearly_independent_subset_finder lisf_small(kern, 1000); + std::vector samps_big; + + sample_type temp = randm(4,1,rnd); + + lisf_small.add(temp); + samps_big.push_back(temp); + samps_big.push_back(temp); + + ekm_big.load(kern, samps_big); + ekm_small.load(lisf_small); + + validate(ekm_small, ekm_big); + + } + { + empirical_kernel_map ekm_big, ekm_small; + std::vector samps_big, samps_small; + + sample_type temp = randm(4,1,rnd); + sample_type temp2 = randm(4,1,rnd); + + samps_small.push_back(temp); + samps_small.push_back(temp2); + samps_big.push_back(temp); + samps_big.push_back(temp2); + samps_big.push_back(randm(4,1,rnd)); + + ekm_big.load(kern, samps_big); + ekm_small.load(kern, samps_small); + + validate(ekm_small, ekm_big); + + } + { + empirical_kernel_map ekm_big, ekm_small; + linearly_independent_subset_finder lisf_small(kern, 1000); + std::vector samps_big; + + sample_type temp = randm(4,1,rnd); + sample_type temp2 = randm(4,1,rnd); + + lisf_small.add(temp); + lisf_small.add(temp2); + samps_big.push_back(temp); + samps_big.push_back(temp2); + samps_big.push_back(temp); + + ekm_big.load(kern, samps_big); + ekm_small.load(lisf_small); + + validate(ekm_small, ekm_big); + + } + + + } + + + + void perform_test ( + ) + { + ++thetime; + typedef matrix sample_type; + //dlog << LINFO << "time seed: " << thetime; + //rnd.set_seed(cast_to_string(thetime)); + + + typedef radial_basis_kernel kernel_type; + + + for (int n = 1; n < 10; ++n) + { + print_spinner(); + dlog << LINFO << "matrix size " << n; + + std::vector samples; + // make some samples + for (int i = 0; i < n; ++i) + { + samples.push_back(randm(4,1,rnd)); + // double up the samples just to mess with the lisf + if (n > 5) + samples.push_back(samples.back()); + } + + dlog << LINFO << "samples.size(): "<< samples.size(); + + const kernel_type kern(1); + + linearly_independent_subset_finder lisf(kern, 100, 1e-4); + unsigned long count = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (lisf.add(samples[i])) + { + DLIB_TEST(equal(lisf[lisf.size()-1], samples[i])); + ++count; + } + } + DLIB_TEST(count == lisf.size()); + + DLIB_TEST(lisf.size() == (unsigned int)n); + + + dlog << LINFO << "lisf.size(): "<< lisf.size(); + + // make sure the kernel matrices coming out of the lisf are correct + DLIB_TEST(dlib::equal(lisf.get_kernel_matrix(), kernel_matrix(kern, lisf), 1e-8)); + DLIB_TEST(dlib::equal(lisf.get_inv_kernel_marix(), inv(kernel_matrix(kern, lisf.get_dictionary())), 1e-8)); + + empirical_kernel_map ekm; + ekm.load(lisf); + DLIB_TEST(ekm.basis_size() == lisf.size()); + + std::vector proj_samples; + for (unsigned long i = 0; i < samples.size(); ++i) + { + double err; + proj_samples.push_back(ekm.project(samples[i], err)); + DLIB_TEST(err <= 1e-4); + const double error_agreement = std::abs(err - lisf.projection_error(samples[i])); + dlog << LTRACE << "err: " << err << " error_agreement: "<< error_agreement; + DLIB_TEST(error_agreement < 1e-11); + } + + for (int i = 0; i < 5; ++i) + { + sample_type temp = randm(4,1,rnd); + double err; + ekm.project(temp, err); + const double error_agreement = std::abs(err - lisf.projection_error(temp)); + dlog << LTRACE << "err: " << err << " error_agreement: "<< error_agreement; + DLIB_TEST(error_agreement < 1e-11); + } + + // make sure the EKM did the projection correctly + DLIB_TEST(dlib::equal(kernel_matrix(kern, samples), kernel_matrix(linear_kernel(), proj_samples), 1e-5)); + } + + + test_transformation_stuff(); + + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + empirical_kernel_map_tester a; + +} + + diff --git a/ml/dlib/dlib/test/elastic_net.cpp b/ml/dlib/dlib/test/elastic_net.cpp new file mode 100644 index 000000000..0e0501639 --- /dev/null +++ b/ml/dlib/dlib/test/elastic_net.cpp @@ -0,0 +1,122 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include "tester.h" +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.elastic_net"); + +// ---------------------------------------------------------------------------------------- + + matrix basic_elastic_net( + const matrix& X, + const matrix& Y, + double ridge_lambda, + double lasso_budget, + double eps + ) + { + DLIB_CASSERT(X.nc() == Y.nr(),""); + + + typedef matrix sample_type; + typedef linear_kernel kernel_type; + + svm_c_linear_dcd_trainer trainer; + trainer.solve_svm_l2_problem(true); + const double C = 1/(2*ridge_lambda); + trainer.set_c(C); + trainer.set_epsilon(eps); + trainer.enable_shrinking(true); + trainer.include_bias(false); + + + std::vector samples; + std::vector labels; + for (long r = 0; r < X.nr(); ++r) + { + sample_type temp = trans(rowm(X,r)); + + const double xmul = (1/lasso_budget); + samples.push_back(temp - xmul*Y); + labels.push_back(+1); + samples.push_back(temp + xmul*Y); + labels.push_back(-1); + } + + svm_c_linear_dcd_trainer::optimizer_state state; + auto df = trainer.train(samples, labels, state); + auto&& alpha = state.get_alpha(); + + matrix betas(alpha.size()/2); + for (long i = 0; i < betas.size(); ++i) + betas(i) = lasso_budget*(alpha[2*i] - alpha[2*i+1]); + betas /= sum(mat(alpha)); + return betas; + } + +// ---------------------------------------------------------------------------------------- + + class test_elastic_net : public tester + { + public: + test_elastic_net ( + ) : + tester ( + "test_elastic_net", + "Run tests on the elastic_net object.", + 0 + ) + { + } + + void perform_test ( + ) + { + matrix w = {1,2,0,4, 0,0,0,0,0, 6, 7,8,0, 9, 0}; + + matrix X = randm(w.size(),1000); + matrix Y = trans(X)*w; + Y += 0.1*(randm(Y.nr(), Y.nc())-0.5); + + + double ridge_lambda = 0.1; + double lasso_budget = sum(abs(w)); + double eps = 0.0000001; + + dlib::elastic_net solver(X*trans(X),X*Y); + solver.set_epsilon(eps); + + + matrix results; + matrix results2; + for (double s = 1.2; s > 0.10; s *= 0.9) + { + print_spinner(); + dlog << LINFO << "s: "<< s; + // make sure the two solvers agree. + results = basic_elastic_net(X, Y, ridge_lambda, lasso_budget*s, eps); + results2 = solver(ridge_lambda, lasso_budget*s); + dlog << LINFO << "error: "<< max(abs(results - results2)); + DLIB_TEST(max(abs(results - results2)) < 1e-3); + } + } + } a; + +// ---------------------------------------------------------------------------------------- + +} + + + diff --git a/ml/dlib/dlib/test/empirical_kernel_map.cpp b/ml/dlib/dlib/test/empirical_kernel_map.cpp new file mode 100644 index 000000000..95b085ab3 --- /dev/null +++ b/ml/dlib/dlib/test/empirical_kernel_map.cpp @@ -0,0 +1,444 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.empirical_kernel_map"); + + + class empirical_kernel_map_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + empirical_kernel_map_tester ( + ) : + tester ( + "test_empirical_kernel_map", // the command line argument name for this test + "Run tests on the empirical_kernel_map object.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + // always use the same time so that tests are repeatable + thetime = 0;//time(0); + } + + time_t thetime; + dlib::rand rnd; + + void test_projection_error() + { + for (int runs = 0; runs < 10; ++runs) + { + print_spinner(); + typedef matrix sample_type; + typedef radial_basis_kernel kernel_type; + const kernel_type kern(0.2); + + empirical_kernel_map ekm; + + // generate samples + const int num = rnd.get_random_8bit_number()%50 + 1; + std::vector samples; + for (int i = 0; i < num; ++i) + { + samples.push_back(randm(5,1,rnd)); + } + + + ekm.load(kern, samples); + DLIB_TEST(ekm.basis_size() == samples.size()); + + double err; + + // the samples in the basis should have zero projection error + for (unsigned long i = 0; i < samples.size(); ++i) + { + ekm.project(samples[i], err); + DLIB_TEST_MSG(abs(err) < 1e-13, abs(err)); + + } + + // Do some sanity tests on the conversion to distance functions while we are at it. + for (int i = 0; i < 30; ++i) + { + // pick two random samples + const sample_type samp1 = samples[rnd.get_random_32bit_number()%samples.size()]; + const sample_type samp2 = samples[rnd.get_random_32bit_number()%samples.size()]; + + const matrix proj1 = ekm.project(samp1); + const matrix proj2 = ekm.project(samp2); + + distance_function df1 = ekm.convert_to_distance_function(proj1); + distance_function df2 = ekm.convert_to_distance_function(proj2); + + DLIB_TEST(df1.get_kernel() == kern); + DLIB_TEST(df2.get_kernel() == kern); + + // make sure the norms are correct + DLIB_TEST(std::abs(df1.get_squared_norm() - + trans(df1.get_alpha())*kernel_matrix(df1.get_kernel(),df1.get_basis_vectors())*df1.get_alpha()) < 1e-10); + DLIB_TEST(std::abs(df2.get_squared_norm() - + trans(df2.get_alpha())*kernel_matrix(df2.get_kernel(),df2.get_basis_vectors())*df2.get_alpha()) < 1e-10); + + + const double true_dist = std::sqrt(kern(samp1,samp1) + kern(samp2,samp2) - 2*kern(samp1,samp2)); + DLIB_TEST_MSG(abs(df1(df2) - true_dist) < 1e-7, abs(df1(df2) - true_dist)); + DLIB_TEST_MSG(abs(length(proj1-proj2) - true_dist) < 1e-7, abs(length(proj1-proj2) - true_dist)); + + + // test distance function operators + const decision_function dec1 = ekm.convert_to_decision_function(proj1); + const decision_function dec2 = ekm.convert_to_decision_function(proj2); + DLIB_TEST(dec1.kernel_function == kern); + DLIB_TEST(dec2.kernel_function == kern); + + distance_function temp; + temp = dec1; + DLIB_TEST(std::abs(temp.get_squared_norm() - df1.get_squared_norm()) < 1e-10); + temp = dec2; + DLIB_TEST(std::abs(temp.get_squared_norm() - df2.get_squared_norm()) < 1e-10); + temp = distance_function(dec1.alpha, dec1.kernel_function, dec1.basis_vectors); + DLIB_TEST(std::abs(temp.get_squared_norm() - df1.get_squared_norm()) < 1e-10); + + df1 = dec1; + + temp = df1 + df2; + decision_function dec3(temp.get_alpha(), 0, temp.get_kernel(), temp.get_basis_vectors()); + DLIB_TEST(std::abs(temp.get_squared_norm() - + trans(temp.get_alpha())*kernel_matrix(temp.get_kernel(),temp.get_basis_vectors())*temp.get_alpha()) < 1e-10); + for (unsigned long j = 0; j < samples.size(); ++j) + { + DLIB_TEST(std::abs(dec3(samples[j]) - (dec1(samples[j]) + dec2(samples[j]))) < 1e-10); + } + + + temp = df1 - df2; + dec3 = decision_function(temp.get_alpha(), 0, temp.get_kernel(), temp.get_basis_vectors()); + DLIB_TEST(std::abs(temp.get_squared_norm() - + trans(temp.get_alpha())*kernel_matrix(temp.get_kernel(),temp.get_basis_vectors())*temp.get_alpha()) < 1e-10); + for (unsigned long j = 0; j < samples.size(); ++j) + { + DLIB_TEST(std::abs(dec3(samples[j]) - (dec1(samples[j]) - dec2(samples[j]))) < 1e-10); + } + + temp = 3*(df1 - df2)*2; + dec3 = decision_function(temp.get_alpha(), 0, temp.get_kernel(), temp.get_basis_vectors()); + DLIB_TEST(std::abs(temp.get_squared_norm() - + trans(temp.get_alpha())*kernel_matrix(temp.get_kernel(),temp.get_basis_vectors())*temp.get_alpha()) < 1e-10); + for (unsigned long j = 0; j < samples.size(); ++j) + { + DLIB_TEST(std::abs(dec3(samples[j]) - 6*(dec1(samples[j]) - dec2(samples[j]))) < 1e-10); + } + + distance_function df_empty(kern); + + temp = df_empty + (df1 + df2)/2 + df_empty - df_empty + (df_empty + df_empty) - (df_empty - df_empty); + dec3 = decision_function(temp.get_alpha(), 0, temp.get_kernel(), temp.get_basis_vectors()); + DLIB_TEST(std::abs(temp.get_squared_norm() - + trans(temp.get_alpha())*kernel_matrix(temp.get_kernel(),temp.get_basis_vectors())*temp.get_alpha()) < 1e-10); + for (unsigned long j = 0; j < samples.size(); ++j) + { + DLIB_TEST(std::abs(dec3(samples[j]) - 0.5*(dec1(samples[j]) + dec2(samples[j]))) < 1e-10); + } + } + // Do some sanity tests on the conversion to distance functions while we are at it. This + // time multiply one of the projections by 30 and see that it still all works out right. + for (int i = 0; i < 30; ++i) + { + // pick two random samples + const sample_type samp1 = samples[rnd.get_random_32bit_number()%samples.size()]; + const sample_type samp2 = samples[rnd.get_random_32bit_number()%samples.size()]; + + matrix proj1 = ekm.project(samp1); + matrix proj2 = 30*ekm.project(samp2); + + distance_function df1 = ekm.convert_to_distance_function(proj1); + distance_function df2 = ekm.convert_to_distance_function(proj2); + + DLIB_TEST_MSG(abs(length(proj1-proj2) - df1(df2)) < 1e-7, abs(length(proj1-proj2) - df1(df2))); + } + + + // now generate points with projection error + for (double i = 1; i < 10; ++i) + { + sample_type test_point = i*randm(5,1,rnd); + ekm.project(test_point, err); + // turn into normal distance rather than squared distance + err = sqrt(err); + dlog << LTRACE << "projection error: " << err; + + distance_function df = ekm.convert_to_distance_function(ekm.project(test_point)); + + // the projection error should be the distance between the test_point and the point it gets + // projected onto + DLIB_TEST_MSG(abs(df(test_point) - err) < 1e-10, abs(df(test_point) - err)); + // while we are at it make sure the squared norm in the distance function is right + double df_error = abs(df.get_squared_norm() - trans(df.get_alpha())*kernel_matrix(kern, samples)*df.get_alpha()); + DLIB_TEST_MSG( df_error < 1e-10, df_error); + } + + + + } + } + + template + void test_with_kernel(const kernel_type& kern) + { + typedef typename kernel_type::sample_type sample_type; + + empirical_kernel_map ekm, ekm2, ekm3; + + for (int j = 0; j < 10; ++j) + { + sample_type samp; + std::vector samples; + std::vector proj_samples; + print_spinner(); + const int num = rnd.get_random_8bit_number()%200 + 1; + // make some random samples + for (int i = 0; i < num; ++i) + { + samples.push_back(randm(4,1,rnd)); + } + // add on a little bit to make sure there is at least one non-zero sample. If all the + // samples are zero then empirical_kernel_map_error will be thrown and we don't want that. + samples.front()(0) += 0.001; + + ekm2.load(kern, samples); + DLIB_TEST(ekm2.basis_size() == samples.size()); + for (unsigned long i = 0; i < samples.size(); ++i) + DLIB_TEST(dlib::equal(ekm2[i] , samples[i])); + + // test serialization + ostringstream sout; + serialize(ekm2, sout); + ekm2.clear(); + istringstream sin(sout.str()); + deserialize(ekm3, sin); + // also test swap + ekm3.swap(ekm); + DLIB_TEST(ekm.get_kernel() == kern); + DLIB_TEST(ekm.out_vector_size() != 0); + DLIB_TEST(ekm2.out_vector_size() == 0); + DLIB_TEST(ekm3.out_vector_size() == 0); + + + + // project all the samples into kernel space + for (unsigned long i = 0; i < samples.size(); ++i) + { + proj_samples.push_back(ekm.project(samples[i])); + } + + DLIB_TEST(max(abs(kernel_matrix(kern, samples) - kernel_matrix(linear_kernel(), proj_samples))) < 1e-12); + DLIB_TEST(ekm.out_vector_size() == proj_samples[0].size()); + + for (int i = 0; i < 30; ++i) + { + const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size(); + const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size(); + decision_function dec_funct = ekm.convert_to_decision_function(proj_samples[idx1]); + distance_function dist_funct = ekm.convert_to_distance_function(proj_samples[idx1]); + + // make sure the distances match + const double dist_error = abs(length(proj_samples[idx1] - proj_samples[idx2]) - dist_funct(samples[idx2])); + DLIB_TEST_MSG( dist_error < 1e-6, dist_error); + // make sure the dot products match + DLIB_TEST(abs(dot(proj_samples[idx1],proj_samples[idx2]) - dec_funct(samples[idx2])) < 1e-10); + + // also try the dec_funct with samples that weren't in the original set + samp = 100*randm(4,1,rnd); + // make sure the dot products match + DLIB_TEST(abs(dot(proj_samples[idx1],ekm.project(samp)) - dec_funct(samp)) < 1e-10); + samp = randm(4,1,rnd); + // make sure the dot products match + DLIB_TEST(abs(dot(proj_samples[idx1],ekm.project(samp)) - dec_funct(samp)) < 1e-10); + } + + + + proj_samples.clear(); + + + // now do the projection but use the projection_function returned by get_projection_function() + projection_function proj2 = ekm.get_projection_function(); + projection_function proj; + sout.clear(); + sout.str(""); + sin.clear(); + sin.str(""); + // test serialization + serialize(proj2, sout); + sin.str(sout.str()); + deserialize(proj, sin); + + for (unsigned long i = 0; i < samples.size(); ++i) + { + proj_samples.push_back(proj(samples[i])); + } + + DLIB_TEST(max(abs(kernel_matrix(kern, samples) - kernel_matrix(linear_kernel(), proj_samples))) < 1e-12); + DLIB_TEST(ekm.out_vector_size() == proj_samples[0].size()); + DLIB_TEST(proj.out_vector_size() == proj_samples[0].size()); + + ekm.clear(); + DLIB_TEST(ekm.out_vector_size() == 0); + DLIB_TEST(ekm2.out_vector_size() == 0); + DLIB_TEST(ekm3.out_vector_size() == 0); + + + for (int i = 0; i < 30; ++i) + { + const unsigned long idx1 = rnd.get_random_32bit_number()%samples.size(); + const unsigned long idx2 = rnd.get_random_32bit_number()%samples.size(); + decision_function dec_funct = convert_to_decision_function(proj,proj_samples[idx1]); + + // make sure the dot products match + DLIB_TEST(abs(dot(proj_samples[idx1],proj_samples[idx2]) - dec_funct(samples[idx2])) < 1e-10); + + // also try the dec_funct with samples that weren't in the original set + samp = 100*randm(4,1,rnd); + // make sure the dot products match + DLIB_TEST(abs(dot(proj_samples[idx1],proj(samp)) - dec_funct(samp)) < 1e-10); + samp = randm(4,1,rnd); + // make sure the dot products match + DLIB_TEST(abs(dot(proj_samples[idx1],proj(samp)) - dec_funct(samp)) < 1e-10); + } + + + + + + } + + for (int j = 1; j <= 20; ++j) + { + dlog << LTRACE << "j: " << j; + sample_type samp, samp2; + std::vector samples1; + std::vector samples2; + print_spinner(); + // make some random samples. At the end samples1 will be a subset of samples2 + for (int i = 0; i < 5*j; ++i) + { + samples1.push_back(randm(10,1,rnd)); + samples2.push_back(samples1.back()); + } + for (int i = 0; i < 5*j; ++i) + { + samples2.push_back(randm(10,1,rnd)); + } + // add on a little bit to make sure there is at least one non-zero sample. If all the + // samples are zero then empirical_kernel_map_error will be thrown and we don't want that. + samples1.front()(0) += 0.001; + samples2.front()(0) += 0.001; + + ekm.load(kern, samples1); + for (unsigned long i = 0; i < samples1.size(); ++i) + DLIB_TEST(dlib::equal(ekm[i] , samples1[i])); + DLIB_TEST(ekm.basis_size() == samples1.size()); + ekm2.load(kern, samples2); + DLIB_TEST(ekm2.basis_size() == samples2.size()); + + dlog << LTRACE << "ekm.out_vector_size(): " << ekm.out_vector_size(); + dlog << LTRACE << "ekm2.out_vector_size(): " << ekm2.out_vector_size(); + const double eps = 1e-6; + + matrix transform; + // Make sure transformations back to yourself work right. Note that we can't just + // check that transform is the identity matrix since it might be an identity transform + // for only a subspace of vectors (this happens if the ekm maps points into a subspace of + // all possible ekm.out_vector_size() vectors). + transform = ekm.get_transformation_to(ekm); + DLIB_TEST(transform.nr() == ekm.out_vector_size()); + DLIB_TEST(transform.nc() == ekm.out_vector_size()); + for (unsigned long i = 0; i < samples1.size(); ++i) + { + samp = ekm.project(samples1[i]); + DLIB_TEST_MSG(length(samp - transform*samp) < eps, length(samp - transform*samp)); + samp = ekm.project((samples1[0] + samples1[i])/2); + DLIB_TEST_MSG(length(samp - transform*samp) < eps, length(samp - transform*samp)); + } + + transform = ekm2.get_transformation_to(ekm2); + DLIB_TEST(transform.nr() == ekm2.out_vector_size()); + DLIB_TEST(transform.nc() == ekm2.out_vector_size()); + for (unsigned long i = 0; i < samples2.size(); ++i) + { + samp = ekm2.project(samples2[i]); + DLIB_TEST_MSG(length(samp - transform*samp) < eps, length(samp - transform*samp)); + samp = ekm2.project((samples2[0] + samples2[i])/2); + DLIB_TEST_MSG(length(samp - transform*samp) < eps, length(samp - transform*samp)); + //dlog << LTRACE << "mapping error: " << length(samp - transform*samp); + } + + + // now test the transform from ekm to ekm2 + transform = ekm.get_transformation_to(ekm2); + DLIB_TEST(transform.nr() == ekm2.out_vector_size()); + DLIB_TEST(transform.nc() == ekm.out_vector_size()); + for (unsigned long i = 0; i < samples1.size(); ++i) + { + samp = ekm.project(samples1[i]); + distance_function df1 = ekm.convert_to_distance_function(samp); + distance_function df2 = ekm2.convert_to_distance_function(transform*samp); + DLIB_TEST_MSG(df1(df2) < eps, df1(df2)); + //dlog << LTRACE << "mapping error: " << df1(df2); + + + samp = ekm.project((samples1[0] + samples1[i])/2); + df1 = ekm.convert_to_distance_function(samp); + df2 = ekm2.convert_to_distance_function(transform*samp); + DLIB_TEST_MSG(df1(df2) < eps, df1(df2)); + } + + + } + } + + void perform_test ( + ) + { + ++thetime; + typedef matrix sample_type; + dlog << LINFO << "time seed: " << thetime; + rnd.set_seed(cast_to_string(thetime)); + + print_spinner(); + test_projection_error(); + print_spinner(); + dlog << LINFO << "test with linear kernel"; + test_with_kernel(linear_kernel()); + print_spinner(); + dlog << LINFO << "test with rbf kernel"; + test_with_kernel(radial_basis_kernel(0.2)); + print_spinner(); + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + empirical_kernel_map_tester a; + +} + + diff --git a/ml/dlib/dlib/test/entropy_coder.cpp b/ml/dlib/dlib/test/entropy_coder.cpp new file mode 100644 index 000000000..12a9a3305 --- /dev/null +++ b/ml/dlib/dlib/test/entropy_coder.cpp @@ -0,0 +1,587 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.entropy_coder"); + + namespace entropy_coder_kernel_test_helpers + { + template < + typename encoder, + typename decoder + > + std::string test ( + const std::string& input + ) + /*! + ensures + - encodes the data from in and then tries to decode it and returns + "" if it was successfully decoded else it returns the decoded string + !*/ + { + ostringstream sout; + istringstream sin; + istringstream in; + + + in.str(input); + + const unsigned long max_total = 65535; + + + + + unsigned long counts[256]; + for (int i = 0; i < 256; ++i) + { + counts[i] = 1; + } + + + encoder e; + + + + DLIB_TEST(e.stream_is_set() == false); + + e.set_stream(sout); + + DLIB_TEST(e.stream_is_set() == true); + DLIB_TEST(&e.get_stream() == &sout); + + unsigned char ch; + + unsigned long total = 256; + + while (in.read((char*)&ch,1)) + { + if (total > max_total) + { + total = 0; + for (int j = 0; j<256; ++j) + { + counts[j] >>= 1; + if (counts[j] == 0) + counts[j] = 1; + total += counts[j]; + + } + } + + unsigned long low_count = 0; + unsigned long high_count; + for (int i = 0; i < ch; ++i) + low_count += counts[i]; + high_count = low_count + counts[ch]; + + e.encode(low_count,high_count,total); + + + ++total; + counts[ch] += 1; + } + + DLIB_TEST(e.stream_is_set() == true); + DLIB_TEST(&e.get_stream() == &sout); + + + + e.clear(); + + DLIB_TEST(e.stream_is_set() == false); + + + // ***************************************** + + + decoder d; + + + DLIB_TEST(d.stream_is_set() == false); + DLIB_TEST(d.get_target_called() == false); + + sin.str(sout.str()); + sout.str(""); + + d.set_stream(sin); + + DLIB_TEST(d.get_target_called() == false); + + DLIB_TEST(d.stream_is_set() == true); + DLIB_TEST(&d.get_stream() == &sin); + + for (int i = 0; i < 256; ++i) + { + counts[i] = 1; + } + + total = 256; + + for (string::size_type i = 0; i < input.size() ; ++i) + { + if (total > max_total) + { + total = 0; + for (int j = 0; j<256; ++j) + { + counts[j] >>= 1; + if (counts[j] == 0) + counts[j] = 1; + total += counts[j]; + + } + } + + DLIB_TEST(d.get_target_called() == false); + + unsigned long target = d.get_target(total); + + DLIB_TEST(target < total); + + DLIB_TEST(d.get_target_called() == true); + + + unsigned long low_count; + unsigned long high_count = 0; + + unsigned long j; + for (j = 0; high_count <= target; ++j) + { + high_count += counts[j]; + } + --j; + low_count = high_count - counts[j]; + + + ch = static_cast(j); + + + sout.rdbuf()->sputn((char*)&ch,1); + + + + d.decode(low_count,high_count); + DLIB_TEST(d.get_target_called() == false); + ++total; + counts[ch] += 1; + + } + + DLIB_TEST(d.stream_is_set() == true); + DLIB_TEST(&d.get_stream() == &sin); + + d.clear(); + + DLIB_TEST(d.stream_is_set() == false); + DLIB_TEST_MSG(sout.str().size() == input.size(),"the test script is buggy"); + + + if (sout.str() == input) + return ""; + else + return sout.str(); + + } + + } + + + + + template < + typename encoder, + typename decoder + > + void entropy_coder_kernel_test ( + ) + /*! + requires + - encoder is an implementation of entropy_encoder/entropy_encoder_kernel_abstract.h + - decoder is an implementation of entropy_decoder/entropy_decoder_kernel_abstract.h + ensures + - runs tests on encoder and decoder for compliance with the specs + !*/ + { + using namespace entropy_coder_kernel_test_helpers; + + dlog << LTRACE << 1; + + print_spinner(); + string temp, temp2; + + srand(static_cast(time(0))); + + for (int k = 0; k < 10000; ++k) + { + string temp; + istringstream sin; + ostringstream sout; + decoder d; + encoder e; + + e.set_stream(sout); + + int num = ::rand() %200; + unsigned long total[200]; + unsigned long high_count[200]; + unsigned long low_count[200]; + for (int i = 0; i < num; ++i) + { + total[i] = ::rand()%256 + 20; + high_count[i] = ::rand()%total[i] + 1; + low_count[i] = ::rand()%high_count[i]; + + e.encode(low_count[i],high_count[i],total[i]); + } + + e.clear(); + + sout.rdbuf()->sputc('a'); + + sin.str(sout.str()); + + + d.set_stream(sin); + + + for (int i = 0; i < num; ++i) + { + unsigned long N = d.get_target(total[i]); + DLIB_TEST(low_count[i] <= N && N < high_count[i]); + d.decode(low_count[i],high_count[i]); + } + + + + + + + DLIB_TEST_MSG(sin.rdbuf()->sgetc() != EOF,"num: " << num); + DLIB_TEST_MSG(sin.rdbuf()->sgetc() == 'a', + "sin.rdbuf()->sgetc() == " << (char)sin.rdbuf()->sgetc() << + "\nnum: " << num + ); + DLIB_TEST(sin.rdbuf()->sbumpc() == 'a'); + DLIB_TEST(sin.rdbuf()->sgetc() == EOF); + + } // for (int k = 0; k < 10000; ++k) + + dlog << LTRACE << 2; + + print_spinner(); + + // the point of this block is to make sure that the return value + // from decoder.get_target(total) is a always less than total + for (int k = 0; k < 20; ++k) + { + string temp; + temp.push_back(static_cast(::rand()&0xff)); + istringstream sin(temp); + decoder d; + d.set_stream(sin); + unsigned long total = ::rand()%256 + 20; + unsigned long target = d.get_target(total); + DLIB_TEST(target target) + low_count = target; + + d.decode(low_count,high_count); + target = d.get_target(total); + DLIB_TEST_MSG(target(time(0)); + srand(seed1 ); + int array[65536]; + for (int i = 0; i < 65536; ++i) + { + array[i] = ::rand()%256; + } + for (int i = 0; i < 60; ++i) + { + int idx = ::rand()%65536; + int radius = 35; + if (idx > radius && idx <65536-radius) + { + for (int j = idx-radius; j < idx+radius; ++j) + array[j] = array[idx]; + } + } + + // test with 3 random strings of length 10000 + // but use the above array to bias the random numbers + for (int j = 0; j < 3; ++j) + { + print_spinner(); + temp = ""; + //seed2 = static_cast(time(0)); + srand(seed2 ); + for ( int i = 0; i < 10000; ++i) + { + int a = array[::rand()%65536]; + temp += (unsigned char)a; + } + string temp2; + temp2 = test(temp); + if (temp2 != "") + { + + int k = 0; + DLIB_TEST(temp != temp2); + while (temp[k] == temp2[k])++k; + } + + + DLIB_TEST_MSG(temp2 == "",""); + } + } + + print_spinner(); + + + dlog << LTRACE << 4; + + + + + // test with a large string which contains all the same character + temp = "eeeeeeeeee"; + for (int i = 0; i < 13; ++i) + { + temp = temp + temp; + } + temp = test(temp); + if (temp != "") + { + // crop off all the e's until we find the part that is messed up + string::size_type pos = temp.find_first_not_of("e"); + temp = temp.substr(pos); + } + DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\""); /**/ + + + dlog << LTRACE << 5; + + print_spinner(); + + temp = "davis"; + temp = test(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\""); + + temp = ""; + temp = test(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\""); + + // test for each single character + for ( int i = 0; i <= 255; ++i) + { + temp = (unsigned char)i; + temp = test(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\""); + } + + dlog << LTRACE << 6; + + // test with a long string with the same thing repeated many times + temp = "davis "; + for (int i = 0; i < 10; ++i) + { + temp = temp + temp; + } + temp = test(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\""); + + dlog << LTRACE << 7; + + // test with 10 random strings of length 1000 + for (int j = 0; j < 10; ++j) + { + temp = ""; + srand(static_cast(time(0))); + for ( int i = 0; i < 1000; ++i) + { + int a = ::rand()%256; + temp += (unsigned char)a; + } + temp = test(temp); DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\""); + } + + + dlog << LTRACE << 8; + + print_spinner(); + + // test with 15 random strings of length 30000 + for (int j = 0; j < 15; ++j) + { + print_spinner(); + temp = ""; + unsigned long seed = static_cast(time(0)); + srand(seed); + for ( int i = 0; i < 30000; ++i) + { + int a = ::rand()%256; + temp += (unsigned char)a; + } + temp = test(temp); DLIB_TEST_MSG(temp == "","seed: " << seed); + } + + + dlog << LTRACE << 9; + + print_spinner(); + + // test with a large string which contains all the same character + temp = " "; + for (int i = 0; i < 10; ++i) + { + temp = temp + temp; + } + temp = test(temp); + if (temp != "") + { + // crop off all the spacess until we find the part that is messed up + string::size_type pos = temp.find_first_not_of(" "); + temp = temp.substr(pos); + } + DLIB_TEST_MSG(temp == "","decoded string: \"" << temp << "\"");/**/ + + + + + + + dlog << LTRACE << 10; + + + + + + + // test with a large string which contains a bunch of a's followed by a + // bunch of z's + temp = "aaaaaaaa"; + temp2 = "zzzzzzzz"; + for (int i = 0; i < 12; ++i) + { + temp = temp + temp; + temp2 = temp2 + temp2; + } + temp += temp2; + print_spinner(); + temp = test(temp); + DLIB_TEST(temp == ""); + + + + dlog << LTRACE << 11; + + + + } + + + + + class entropy_coder_tester : public tester + { + public: + entropy_coder_tester ( + ) : + tester ("test_entropy_coder", + "Runs tests on the entropy_coder component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a 1"; + entropy_coder_kernel_test< + entropy_encoder::kernel_1a, + entropy_decoder::kernel_1a + >(); + + dlog << LINFO << "testing kernel_1a_c 2"; + entropy_coder_kernel_test< + entropy_encoder::kernel_1a_c, + entropy_decoder::kernel_1a_c + >(); + + dlog << LINFO << "testing kernel_1a 3"; + entropy_coder_kernel_test< + entropy_encoder::kernel_2a, + entropy_decoder::kernel_2a + >(); + + dlog << LINFO << "testing kernel_1a_c 4"; + entropy_coder_kernel_test< + entropy_encoder::kernel_2a_c, + entropy_decoder::kernel_2a_c + >(); + + dlog << LINFO << "testing kernel_1a 5"; + entropy_coder_kernel_test< + entropy_encoder::kernel_1a, + entropy_decoder::kernel_1a_c + >(); + + dlog << LINFO << "testing kernel_1a_c 6"; + entropy_coder_kernel_test< + entropy_encoder::kernel_1a_c, + entropy_decoder::kernel_1a + >(); + + dlog << LINFO << "testing kernel_1a 7"; + entropy_coder_kernel_test< + entropy_encoder::kernel_2a, + entropy_decoder::kernel_2a_c + >(); + + dlog << LINFO << "testing kernel_1a_c 8"; + entropy_coder_kernel_test< + entropy_encoder::kernel_2a_c, + entropy_decoder::kernel_2a + >(); + + } + } a; + + + + +} + diff --git a/ml/dlib/dlib/test/entropy_encoder_model.cpp b/ml/dlib/dlib/test/entropy_encoder_model.cpp new file mode 100644 index 000000000..0276cbe77 --- /dev/null +++ b/ml/dlib/dlib/test/entropy_encoder_model.cpp @@ -0,0 +1,198 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include +#include +#include +#include "tester.h" + +namespace +{ + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.entropy_coder_model"); + + template < + typename ee, + typename ed + > + void entropy_encoder_model_kernel_test ( + ) + /*! + requires + - ee is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size for ee is 256 + - ed is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size for ed is 256 + - ee and ed must share the same kernel number + ensures + - runs tests on ee and ed for compliance with the specs + !*/ + { + + print_spinner(); + srand(static_cast(time(0))); + + typedef typename ee::entropy_encoder_type ee_type; + typedef typename ed::entropy_decoder_type ed_type; + + + + { + + ee_type ecoder; + ed_type dcoder; + + ee elen(ecoder); + ed dlen(dcoder); + ee elit(ecoder); + ed dlit(dcoder); + + + istringstream sin; + ostringstream sout; + + ecoder.set_stream(sout); + + + unsigned long temp; + + + elen.encode(0); + elit.encode(9); + + elen.encode(0); + elit.encode(0); + + elen.encode(0); + elit.encode(4); + + elen.encode(0); + elit.encode(0); + + elen.encode(0); + elit.encode(2); + + elen.encode(0); + elit.encode(0); + + + + + + + + ecoder.clear(); + sin.str(sout.str()); + dcoder.set_stream(sin); + + + dlen.decode(temp); + DLIB_TEST(temp == 0); + dlit.decode(temp); + DLIB_TEST(temp == 9); + + dlen.decode(temp); + DLIB_TEST(temp == 0); + dlit.decode(temp); + DLIB_TEST(temp == 0); + + dlen.decode(temp); + DLIB_TEST(temp == 0); + dlit.decode(temp); + DLIB_TEST(temp == 4); + + dlen.decode(temp); + DLIB_TEST(temp == 0); + dlit.decode(temp); + DLIB_TEST(temp == 0); + + dlen.decode(temp); + DLIB_TEST(temp == 0); + dlit.decode(temp); + DLIB_TEST(temp == 2); + + dlen.decode(temp); + DLIB_TEST(temp == 0); + dlit.decode(temp); + DLIB_TEST(temp == 0); + + + + + } + + } + + + + + class entropy_encoder_model_tester : public tester + { + public: + entropy_encoder_model_tester ( + ) : + tester ("test_entropy_coder_model", + "Runs tests on the entropy_encoder_model and entropy_decoder_model components.") + {} + + void perform_test ( + ) + { + typedef entropy_encoder::kernel_2a_c ee; + typedef entropy_decoder::kernel_2a_c ed; + + dlog << LINFO << "testing kernel_1a"; + entropy_encoder_model_kernel_test< + entropy_encoder_model<256,ee>::kernel_1a, + entropy_decoder_model<256,ed>::kernel_1a>(); + + dlog << LINFO << "testing kernel_2a"; + entropy_encoder_model_kernel_test< + entropy_encoder_model<256,ee>::kernel_2a, + entropy_decoder_model<256,ed>::kernel_2a>(); + + dlog << LINFO << "testing kernel_3a"; + entropy_encoder_model_kernel_test< + entropy_encoder_model<256,ee>::kernel_3a, + entropy_decoder_model<256,ed>::kernel_3a>(); + + dlog << LINFO << "testing kernel_4a"; + entropy_encoder_model_kernel_test< + entropy_encoder_model<256,ee>::kernel_4a, + entropy_decoder_model<256,ed>::kernel_4a>(); + + dlog << LINFO << "testing kernel_4b"; + entropy_encoder_model_kernel_test< + entropy_encoder_model<256,ee>::kernel_4b, + entropy_decoder_model<256,ed>::kernel_4b>(); + + dlog << LINFO << "testing kernel_5a"; + entropy_encoder_model_kernel_test< + entropy_encoder_model<256,ee>::kernel_5a, + entropy_decoder_model<256,ed>::kernel_5a>(); + + dlog << LINFO << "testing kernel_5c"; + entropy_encoder_model_kernel_test< + entropy_encoder_model<256,ee>::kernel_5c, + entropy_decoder_model<256,ed>::kernel_5c>(); + + dlog << LINFO << "testing kernel_6a"; + entropy_encoder_model_kernel_test< + entropy_encoder_model<256,ee>::kernel_6a, + entropy_decoder_model<256,ed>::kernel_6a>(); + + } + } a; + +} + diff --git a/ml/dlib/dlib/test/example.cpp b/ml/dlib/dlib/test/example.cpp new file mode 100644 index 000000000..4cf927159 --- /dev/null +++ b/ml/dlib/dlib/test/example.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" + +// This is called an unnamed-namespace and it has the effect of making everything +// inside this file "private" so that everything you declare will have static linkage. +// Thus we won't have any multiply defined symbol errors coming out of the linker when +// we try to compile the test suite. +namespace +{ + using namespace test; + // Declare the logger we will use in this test. The name of the logger + // should start with "test." + dlib::logger dlog("test.example"); + + + class example_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + example_tester ( + ) : + tester ( + "test_example", // the command line argument name for this test + "Run example tests.", // the command line argument description + 0 // the number of command line arguments for this test + ) + {} + + void perform_test ( + ) + { + // This message gets logged to the file debug.txt if the user has enabled logging by + // supplying the -d option on the command line (and they haven't set the logging level + // to something higher than LINFO). + dlog << dlib::LINFO << "some message you want to log"; + + // This test is considered a success if this function doesn't throw an exception. + // So we can use the DLIB_TEST_MSG macro to perform our tests since it throws an + // exception containing a message if its first argument is false. + + // make sure 3 is bigger than 2 + DLIB_TEST_MSG(3 > 2,"This message prints if your compiler doesn't know 3 is bigger than 2"); + + // make sure 5 is not equal to 9 + DLIB_TEST_MSG(5 != 9,"This message prints if your compiler thinks 5 is the same as 9"); + + // This is a form of test you can use when you don't care about having a message + DLIB_TEST(5 != 8); + + // If your test takes a long time to run you can also call print_spinner() + // periodically. This will cause a spinning / character to display on the + // console to indicate to the user that your test is still running (rather + // than hung) + print_spinner(); + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + example_tester a; + +} + + diff --git a/ml/dlib/dlib/test/example_args.cpp b/ml/dlib/dlib/test/example_args.cpp new file mode 100644 index 000000000..573216c79 --- /dev/null +++ b/ml/dlib/dlib/test/example_args.cpp @@ -0,0 +1,75 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" + +// This is called an unnamed-namespace and it has the effect of making everything +// inside this file "private" so that everything you declare will have static linkage. +// Thus we won't have any multiply defined symbol errors coming out of the linker when +// we try to compile the test suite. +namespace +{ + // Declare the logger we will use in this test. The name of the logger + // should start with "test." + dlib::logger dlog("test.example_args"); + + using namespace test; + + class example_args_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + + This particular test requires the user to supply a command line + argument when they run it. + !*/ + public: + example_args_tester ( + ) : + tester ( + "test_example_args", // the command line argument name for this test + "Run example tests with argument.", // the command line argument description + 1 // the number of command line arguments for this test + ) + {} + + void perform_test ( + const std::string& arg + ) + { + // This message gets logged to the file debug.txt if the user has enabled logging by + // supplying the -d option on the command line (and they haven't set the logging level + // to something higher than LINFO). + dlog << dlib::LINFO << "some message you want to log"; + dlog << dlib::LINFO << "the argument passed to this test was " << arg; + + // This test is considered a success if this function doesn't throw an exception. + // So we can use the DLIB_TEST_MSG macro to perform our tests since it throws an + // exception containing a message if its first argument is false. + + // make sure 3 is bigger than 2 + DLIB_TEST_MSG(3 > 2,"This message prints if your compiler doesn't know 3 is bigger than 2"); + + // make sure 5 is not equal to 9 + DLIB_TEST_MSG(5 != 9,"This message prints if your compiler thinks 5 is the same as 9"); + + // If your test takes a long time to run you can also call print_spinner() + // periodically. This will cause a spinning / character to display on the + // console to indicate to the user that your test is still running (rather + // than hung) + print_spinner(); + } + + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + example_args_tester a; +} + + + diff --git a/ml/dlib/dlib/test/examples/CMakeLists.txt b/ml/dlib/dlib/test/examples/CMakeLists.txt new file mode 100644 index 000000000..93bd9a139 --- /dev/null +++ b/ml/dlib/dlib/test/examples/CMakeLists.txt @@ -0,0 +1,8 @@ + +# Disable some warnings from gcc when compiling the examples because fixing them would make the +# examples harder to read. +if (CMAKE_COMPILER_IS_GNUCXX) + add_definitions("-Wno-comment -Wno-unused-parameter") +endif() + +add_subdirectory(../../../examples examples_build) diff --git a/ml/dlib/dlib/test/face.cpp b/ml/dlib/dlib/test/face.cpp new file mode 100644 index 000000000..1bb1a94b6 --- /dev/null +++ b/ml/dlib/dlib/test/face.cpp @@ -0,0 +1,360 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include +#include +#include +#include + +//#include +//#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.face"); + + + class face_tester : public tester + { + public: + face_tester ( + ) : + tester ( + "test_face", // the command line argument name for this test + "Run tests on the face detection/landmarking modules.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + + void get_test_face_landmark_dataset ( + dlib::array >& images, + std::vector >& objects + ) + { + istringstream sin(get_decoded_string()); + images.resize(1); + objects.resize(1); + load_dng(images[0], sin); + pyramid_up(images[0]); + deserialize(objects[0], sin); + } + + void perform_test() + { + print_spinner(); + dlib::array > images; + std::vector > objects; + get_test_face_landmark_dataset(images, objects); + + frontal_face_detector detector = get_frontal_face_detector(); + + print_spinner(); + shape_predictor_trainer trainer; + trainer.set_tree_depth(2); + trainer.set_nu(0.05); + //trainer.be_verbose(); + + shape_predictor sp = trainer.train(images, objects); + + print_spinner(); + + // It should have been able to perfectly fit the data + DLIB_TEST(test_shape_predictor(sp, images, objects) == 0); + + print_spinner(); + + // While we are here, make sure the default face detector works + std::vector dets = detector(images[0]); + DLIB_TEST(dets.size() == 3); + + + /* + // visualize the detections + std::vector shapes; + for (unsigned long j = 0; j < dets.size(); ++j) + { + full_object_detection shape = sp(images[0], dets[j]); + shapes.push_back(shape); + } + image_window win(images[0]); + win.add_overlay(render_face_detections(shapes)); + cin.get(); + */ + + } + + + // ------------------------------------------------------------------------------------ + + // This function returns the contents of the file 'test_faces.dat' + const std::string get_decoded_string() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'test_faces.dat' we want to decode and return. + sout << "RFYXmMn7UA64INJ2Umw+YCh5xX6v+y3bqV/1EKP3ZtvdIWxDCoyZ8oJjj/LXTw3PyEMQReTyXO+8"; + sout << "423ExTTLMRLxc5gEI4PK8vLMVWdRjRKldYfFLisH5qk7o6TxYfRXqmch/t4c53UfsUuJuVMvA+nf"; + sout << "05tfgx0Z2VgrlMmvKSxX0VHEjQ3ZVqQl0mLeur1qoMRmcPjYA4QJOvcruwyz/hFpVU8snvsri4QA"; + sout << "hkCrOtvl/iTlLvcWXDM98OXEYLk4vr6z2v73mGGEdbaJFafXOgutY0NRESpJfingJkuZKXNV7fQ8"; + sout << "F1DmHwFM0Izoc3fUyN7+xtLo2LSEH8uohv79wjxdbm71n5kgHb6cnpYiiCLOgqKrK+qJUJG1/OBB"; + sout << "9IORGL6SHrogefDg76m4ayil0lE4pQLa1fKhXGqphdxDMyBXHkOpxA356i5Evk7jZb/AqHUu/hQV"; + sout << "QYOsW20TRhuFfrasrb/Veq10hMZomosAnNb+Uhgnu2Ip4igX6ozJqmCL4NngjEgoXuCA02KzmIqK"; + sout << "fcslrBjyg7WZnF2w7UZXcHoZ0AVijb/ANPUqvq9H4k5WCFh53pwG6oj+CG5N7EXLzt9UHtillGJH"; + sout << "S252/dLvzaNxSq0vDP+Y0IvtKhZgIlmX3duu+L7iJUMinE90OXQFzhwCaOM8oP3Mq+Becgto5Vjb"; + sout << "vv+2xeyrYL9hBfMyPzbBhVWZdaPyI62MGMPjTfmAUmufWm3/Pxey5Jw6MUNX1rfWMXYQTtamdtcE"; + sout << "XmuFGhLCWbnJN2GtlOWu+S0LdWAziktr0mdovdFZTwKnd2Fuf1VeMMHBCgF5D56UK6/JTXO1gPlj"; + sout << "VpavizirUH46rHVnGOFZHWarPKymXyrQ/VhFbVBSOi9UGbsMo9sYNYhvFlzSBN5orNoY+5AP+Y2X"; + sout << "BDw342Vg3YRIZBqaXZ7oOqyDtjFYNf2g7ColnaFrPiYCKg9VQvVQVXvr4Oa6TTryepXylHk11ijw"; + sout << "xYgjiENTDKVuXPLVrAEWZD6wq5+LGbmj1cgh59XgLErkoliKxVeEv4ktXqceCFK9YWljqJZSBxFj"; + sout << "T64N945O47Oek591sfRFYrjebJz3kooaaOQ7lm4BWpW64Am/Od0FIkggSzPUgf+jYrnPdH4oqv3d"; + sout << "NJIXJO8aeCi/WCqgbdgLwpu1MB8cTdbK2TrCbi3sHhQddE6rqZ1XIg25gYTw72q4ZFUHmRrBEgdz"; + sout << "i2xh7qHevCWC9Ht7HrUcLl35/UCcNck/ftiDa/xN3rBJp+cJCyS9ZXScFbgFKYfBau8Dx+JA3ygU"; + sout << "pvYbUP29UeK5hXZoQzC74UKw/liapW8GbR3ZPV+59xmw1b7phyApZfODkDLG5lyTB5Co6vIfmqXp"; + sout << "DSA1I0ZPIk7hJHWtddgvAg6Yv4GhMZs/cn08Z/ZRXYSMw20rA66BeMD+IPFe3MPFPnKsl/qRvXIC"; + sout << "gTmLcjPQ5naBo/3haQg4TjJCBCErTC2JktJE0vRfm8v+1kEFoycTsRfhDZtjMnqmqSiNO3VSbnbb"; + sout << "et/mFvCdd9kqBXQ4VaeYdJwORA7/TocC84G91jlfNRhbY2KW9xHXYYwscfaV5DtPWhyJBEimzLsz"; + sout << "F59UbkO/WOT7HEwJ3p1/ReH6feLrIYwR8OdSeuOUtK2oboodiUmx/if9pyNONd0If9jFGDXvWP/r"; + sout << "dL6NdRh68swmPyiCn7sAwaUbd7PF7K+jqgk3jMDMaoWE6p+aAFK3H/JxRO7th2XvX57uA+RKoXEq"; + sout << "LBJrDkoPGbA5Ctj6KUWknMM62HVIO1SjwZH7ROOlCVbHvDd3JTT6TNs2Lj8g78q6WRMnDVV0L9Q7"; + sout << "S6awjHePvyj9fDn66jpmRxI9TkqNPx0b8tGfOnUoXGKK/WcInBjVd+FYy1SO5FIVZ4JQdiSMgehu"; + sout << "EU7X2yjbdgCtVg+kY3ZgYcCMM72NLdm7+U+qVxzfY9bNtzLRsEDNFL2wnq3DkbRRiHKHswEugXjG"; + sout << "fFfzqjU71sDH3gc1dhOQT456XK0zhmV/62+N+oTxrFY2F/ArALeDiR1q1JOrE1UDzU/ezAZVWS/a"; + sout << "DJ8O9TVZsbOXOU4hAiQ6mR5hh6i5c/zb7jBl6ehNgixrJqguo9PT10P3S6I+IiTG2vHlws4ZoTmz"; + sout << "7iOj1ICCN6ivwPj5+afYScxPGLVujxWTW8ksey8RjSIbH+N9wUEAJAcZAGXo2UrP93uRkdLt/cvi"; + sout << "Q4eIONlUrJ79hlo1xyvPmyNw9Ye1zG6epUK+lsXxtUbtX+uX+sDM/Gkr5QUdzzorfa4Mj9vEglQp"; + sout << "iK5319kTeymqqV25YhNZZsoqCg+eckmoZP21cAHkPYZWAzg4lxhm71EAp+fE67Y44szeCSRNydRh"; + sout << "mCl0ZGV59wVfUONdohN/9OQJ8gxifHspscrRszBALwCu2FrniGwRMvIi6lD6tb49dmpPlO1y5SEl"; + sout << "HPF2I271RVx0YwFPprp+2/fzsnKIIhfgTNRVji2tfjHPoge84e5O02Z/cLlg7EzIfr/JK8Xd6mwp"; + sout << "t3kIofktLqWbnBrr8qEvz2BmHRlEjYlpb6HfbK82ZHiOusVJdhAxRcRIqmImh5wtunlJOXdQqfmg"; + sout << "8P0TYVDzQlc6ywbKt/VjTsxI4qisabVvVBkW5VsKU4/JONoNj7fgomwfHqryIIZcgc5vPLawnurD"; + sout << "7JR1LRiTy7+7riDIxKdYtgbTuOnK6tO103XC+5NT07cKydDgtTu1gYoTTgzY+gkK53lX2Swbbme7"; + sout << "wIFOXRwSK4ntGt4XXuuPbF8gc0T8ez8/tOEXE6Di8Ckj2O3uz2vGM27PQH9YU3Y3YCdcuaHlUIMG"; + sout << "s7bgx4nM+xMPmS5baTuAHiPAxp3SepeKpHg0on6Bv99NPSDf7nWRcvvd3+/MHhyDfkbzd88GPnkQ"; + sout << "Y7a91Tlj8RKG6B8W+zzn+SWzTIz5eJcLJCEKN7fXm1YjJjk3Tdi99PZA/K3Ek2k2lcd2oQPXj6fc"; + sout << "EUs9zy6sZdSjttwsYlZzLW0Rpiscjqs1XNA+2D5UahufMk4AVpnDPqZ/OvmSZIPpTT4r+uQsEZlJ"; + sout << "BYf6A9riAPhgR68zPz+i+ffpPcgwURCDviqf370nLIqfbDgJSztxXI2MPsHZqypB+VtvuLTq+M2s"; + sout << "NVdvA6z5J35d3fk8EHVo8/TIWbsSulqnpJvjIHT9GeTV4EvaJo0kh092cFQ3QkbRIIzEqnT/o06z"; + sout << "/+gqB0fKl93o50EUVzJbz02rPc/4qVaqfSJLg+HEZUg30PB8wQHb2oWqgL1lYDlqrv5plbK2kJm9"; + sout << "HW9aQfOyX57rBsYi4aljZl5icy/JElsuNanhnTcTvrHQ/9ntw8QqV2PuVesbbaUQSjDRWG6D1uaK"; + sout << "HSYB/6fvMZ+8hy+gq4d/tMKpoCzgERJjJOU+N0vHpKmyZgE0EGkPJwlgev/oxTJtsQndgrNviPkr"; + sout << "ub9ZRkS5uM7Jb+1mKztyVWDdtVvxnvgboNayuS5/VdwlbQezQKEiB8I8UWLsgEJiXg9sgjBaFrv8"; + sout << "3WgtSQmMRyukOnDPWwDskmUyIHKIye1wOY3H1BUuav52tg8gv1+y2CrVVebRm/8MmhJYe/8DpLeo"; + sout << "1eEAX2SNtZH4VzpZSuAdANtYXgBaNTc0uWtw9Wc8mwo2hTfbu5nVYJ6vlUFP7HH7L+1idfrz832x"; + sout << "l20/+8EKd8y2f20iP6m1mnKCQ9PUPPhWMfWMkJ21VhKJJDcpKhvQq20O/yqhfxabl+73QZiCS/eR"; + sout << "E1ih71FN4x+s952FXOakzYlo5Gge1OCgHE0+YVXBSal4fz6Ye1iRG7+XgLxDIx6AGbOfemRQbOzW"; + sout << "e/8pe0kPqkkS5ogdkCGemb3hKTgGFGXgP9IvJ18VuRDxPSHD1e5THwvULz7V0hUWO4aKx8t8tZIK"; + sout << "flDB/npEo3L/1jU5rLEuX5KQbxJfEY2V22hND1ohUZdU+Uy6BFZ/hdYzFAUyNdLAFS36wB3XThar"; + sout << "54CQ0RGIxCCtv5ucI7VQuc46jkk6SEmZV8FUwv79ExsrB+rAlBvhNOrmVA2vmikrb/iZfA6z1hC2"; + sout << "aj8BGiULdcW0YUiN/TXxoVT0qgh87VjPcOfkj8gTRMF8VGAgRs0HpXVXKZf+ncAOKJ59yu3EsOCp"; + sout << "vSG5zxDxrUYD7TFxrelev/7Jtan4J8ouFsK8ZA4WmBkbWLDBkKvV09c7Jxfmgt27aVI5uuUBJYb0"; + sout << "TdKIG6J/JC85GrRedIb/kRaMiTP0Els6jGp3C/B9PQq6HbzSCUEkLE8is+uWnTDOynbgTtUVEGxH"; + sout << "EKZGBtnPfqgZRDOnZTWMO9Hd9qATpI2qRrgxIvHTUhqD3DQQ37AGTcNsMmj/+mXTBV2vbM9H7Q5K"; + sout << "APzltdgkGc+hZIL8Fy3CXzzRFlXAEoIcnJ3BKT7AdGg3pEaW1YcX4akaOTDmImZelYTCoTGu1R4Y"; + sout << "ZD/rRCeiGR9txS9x9/ptvxeD3J8wYOXxDzCkyMQy1io+izFuN4kTd5MW3RlvepWT4FE2hvjyTyV+"; + sout << "G4F+lcZFjCGmJKguZEzH3Qtww3OTGZqOfL9oADqQsUpEl92hm0uNOPHH6+8gWPZgb3EcCRkWeXaA"; + sout << "DkTbCC5rV04N6Fg3j27K1v7qkykWB4y61W00dFejgPitoYuln4A8lY9RtIKzJYFfSOKnrAqYGMkn"; + sout << "PReI5ZneFTiklSL6FkhWHvr5oz4EsAoU3fLWjky27E1CtpFkPKHLGyc2mnf2N7iBMGa8j2ipy9xd"; + sout << "JwInqdDdGIHR5dUmNAy1A4vyGbOE7qa0ErVy4m3riKkyE0ObQTNoBeYa6CUHwNWCLCbiW2VzRY/H"; + sout << "s6APvI6QIco+Wu9dx36qklvA6/OfM79BsUACnCRG2yo3bHeTeKMwKIx05RkNNJ76eOhjxOiPJBVR"; + sout << "K5V+G3SutaRM0hSK0ABS5nfveZmowfJr8nZHRBAPHyIdv0bJW8lSNIRtZtgwm1dl13+eaeAJNpJm"; + sout << "3eVvhP8c43LG938Bjxs/nfDl5GMzPnyLOIYInIt4wTXCZlRbt6pMXs9IX9/DpUt/AF44b+XQLnUT"; + sout << "DgJa4In4qREDt58wel1oAGe9xpIqdFNPduUuCo/Ly0XHrBFd3jQFgp0JWZ039GtG2YTHpo8rLfu+"; + sout << "TiTAmKnFbVZa6dlOzmVDV+53ptJxiWNfNMa2ri/YEFI7BpKi1FZvMfpGBzNmAdZBAt7kaSvIzqdl"; + sout << "OHMqka3GBbzyW1PusVpj6SWBZ7rsfIFPRdVO4PGcOWSIQ/YlZYXtkV82cw+m4D/ScXdlj0VZBB04"; + sout << "YfeUz1m8tiWsLdHxIKRI1JcLek5PzCQX33RFmfeqGBEa7q/kCzGiJ0QHiJV07Fxm2NKRUZQIWiJO"; + sout << "n1roUf9C06IT03Wd0rcSlwG8Ji9SJZOlBEi7B9Vos61eXMEnkcNiRVeOxLaOfqr20XMde4kquqrS"; + sout << "qZi2ZhNVqokhXNswwMtlMDfWBWWI39q4evlS8c60lQjXg39kyKbyVZOzp6KrJ+xDYUxURb2d/7DZ"; + sout << "+UpikS9DmbUgWqV1pmx84ARtPp8/5teoBM7qd6sRrpJ1Q4cE4WLQfr/nVnVmSfZVzm6yqlxjGvhQ"; + sout << "Bwz85XvMn4xjWIRDNnuwyJh7PoXkoacU54idHA4k7B3qeW59MKtXD3hA34qKUqqE0amH1Xzi3W84"; + sout << "HpL4xT9EoNUPv+ufvo/4Yir/YIMIVuj1BFQACONJWezdbHO7ze7DQrceR9ojpVlkM442CArpnYWZ"; + sout << "3SR0NRB0PFjJcgA3RPVWW591v7Aw3G7/aUh5OLSZsOoxrnE2Hb8wM4wQXutleUjcUCxdliAbLl/k"; + sout << "RcLSE+IzpwRpnSo05sBx91Q71Ws9NeS0Pruy03wScuztfv15MPoBmMH4Nc+JhF/iokGM7C4IICOb"; + sout << "Woffq9KzUSWaIEEO+qaKqfnG7+/bW5U/gDW0xAfvTt0IuXtoJH5/Gq/g/anFdecnyCdBa4C02MFJ"; + sout << "gLqm9Dyiuh5Hny9bAuE08YxgrfpwOx3nGxQ/MrXwGUNjCxJEz4TkmFt/Z6HmOUhobQ4Xue5rK6gp"; + sout << "qWkg7gkXJmpNaj40TFnxQ3Fvw7S1UPNcY3vEvskuYXlB6mI16FCa6ApkR3+krkCKokxelBU5eMxx"; + sout << "+j5Lv8hYzlSlULUJrDOmGRxQ65lOJPmxbuukr6Uaeh3/i5m8GLqJ3EAkeIABkck+sT/OCCS+1hmY"; + sout << "/wRbadX8sd19GjHTLkZ7jwEss0qQj1LdtcEtqXgtzy7+8NhQv9c4KhxFNZdYDezehZYZuf4+UeHM"; + sout << "UAPE6+DX+AtKRD5lSwcAEoFID5GLbsLa+gGWzhmn5dfTDvaJrQNYzI6K1f0MBsWxuXu5SM6dRehm"; + sout << "9FR9es1OnhIFP8bg5uR7lfIEsfAH6ysgkJyylceoooXwE+cALB+IYfk3mIHNuDMbrQxWIMZmS1er"; + sout << "I45Gz53QPLgcSjH9Z1mzrhaZV3ZiJqrTYETtObMRX0136rNLBCaytijF1H4QBJaNEsuJpHAYQHsC"; + sout << "ko8cTk8lzMmh+ENtUjLvrG+pEx9FKCPIgUIZrAAM39fWS5uKhFaAPEkfD/FxbEMdM/hjzbQzcMHy"; + sout << "iYywZCaCYwolzrQtlTIL3I4VqzpUk/vWMIZt/PMK8lUGBDxzELkXNTYepRJ+uR6QZKOU+/LpLk8K"; + sout << "MD6+BZNl2karSYjx9qDDHuADoiNJbSXLV2NqyJiNyTmdKv8E4e52mVFfetoMp+gpd7vMJaAObPuJ"; + sout << "A7rUsNjWz+Zho52LXUSUC1G1MdZPRZPMfk/KvNnXZX93P/KaBXTNLlRf4KkRXQdrnmP5KgKFOvIX"; + sout << "KUMJ7UQtP0koYJl8OsFHG5yX+qI4BOJqVT0eUypXtlW/CHRX51Wl11wIDsTqbi7zEkBu3kFLMfnr"; + sout << "3MEYcVWthG5Y9KecfsdLtRnQVSFRMEKEn1kpVAb7xqZhwKDqREVs+32AawqNxPdO4JOCC/wW0zzH"; + sout << "QlJ8KhbcrxKLaygvVOrqaZGeCNTrrRxpG792CtX4OqXdkxpqcPNjhQJXMgSm/OvHJ1sJOrIHYaru"; + sout << "JPDH/J4e0qg28rlRpN/iY9xc5Q/dMxhoEsv87ehP+MPKF5ByxrWHbNc9IaOiz5BQIQWAc1glXkKc"; + sout << "oyMgoCbZ1Zss8zJ1+k4NwoC59BpnT4KbZHZP7+MbXZidCSosl1P88yC0vj0BX8MK+8x6PA3X9Zc5"; + sout << "Sg2OMThc6WRR3oi/DeHePYqJgPtJwhZt7N651llY8/YIJD7pqVEiB8KJGcrjZ9tDuU0MDTkKnrr7"; + sout << "qhlSl/XFBGq0x9zIfYeBt6k2wgrpE2/FjgziM3YuH/e7jfppGx3S3G4O+yrUuAynknZQ6Opq+Qs1"; + sout << "PYFZMW8Fj+CPMgXGy/2+JPXULK8vf0hN3FfNDeHiGpuGEVaEBHdZdE3CW8cWwyRSdvQHgPbXwEaL"; + sout << "pLGxPhgGQlPLK/JvDQhtL7gqS1LQEAER1sleOHoTV5zEpCsKUcvisz0V7TyFozYnzaG0IcfeGysR"; + sout << "Kx7O4dq05dlnXicv2IrhcCp9+QZSHDL/E/t/SEafxZJu+sIhqknM+Sgx3MAhE/U8KyOANAYy1aPB"; + sout << "H5Qt5RnDbYuFJXHDG9DfEWWcNP8e2EL1CX0/gncYvIz3b5Ge0gv4hwbV23Xzsi3hOki4A/B44nOd"; + sout << "fAE3Ao1D16+6XowEVC8gUh+y1TSnYPTtnB30NbrcMhHNaJP0pYzG41gNaMJNjK8SGXSnPKiXrsJ2"; + sout << "1jsvbLaMTEBdCqw+2lgg/QwEVgUef7JhVNKbQKf/HTtuD50Ofa4PXv37Q92/xnHWHwjbWfeNEZWA"; + sout << "Y5bqQI/MlWBKWSGFjpoQsfbCzS6P3ieD3ID91DmB7tDjTDYm5Wq6LoHAQzx+tTxiq45Qn5tUg0zx"; + sout << "7VC8Xh0ygItADoEqXpKUXdfv0o6G4zxjePM4dvyf4NBkh27q1S/aQwQp8UOury2Eunij57XSxaOb"; + sout << "IFwlFJYC6wtmg6oV0QZAburx8qJMLdIvHyqa5YJJ3HGuHJIyiD+tuxWVxDZSJQ1RedYqBPAiloyy"; + sout << "Z3j6EJKqNh0kq3c0K/B7gpX0P1ZjOFZUadIbEW6dd8bxl0RuyhzdSUMkn4rWgmx7zPvoOEgXzbK+"; + sout << "mJmBT9lEaYE4sFkXfZqDCGq66jdc4RrlD4IZbGkXwMeTuHsfLL1hviKSOyJCTbmS5dKZUdWuVk8R"; + sout << "XFqATzyHHdArHR9RvNa2bQk5b59tODpeKCqECv5DJ5T2ap8u21ctDCJiFCHq3gLfw6IumI1L2m6/"; + sout << "aPGdZ0d4ZMR9ooQvJuhAODKg0ZsA+LjHTakpfUwHpanPWkCbmhuh3oMC+hMX+AqOi55nr2O5uxRL"; + sout << "9UwixLBmRa8g3lbgUGdteTTbvZ2ePyzwxhWp1RS+aFJqtORmRMkgB6vRc4SC1CywKwwHoR3RKEj7"; + sout << "uW6oQyLQqPYevLuBrbO6yn7U6DkU6blGF7jg/2y12MpwYPr8l661YMXXsmVHCVSJ0AsPSGhJ8oL3"; + sout << "Cqk6DJbpyoN8O1xK4zNE0cToRFMhEBjlim0lg56HtbEHfLkqRwnFqfo6vvK6opxJoShMXDa5jrLH"; + sout << "GkpmE4DzaGtZ/P397TF3Y12c6lXJFDYWslp4tMskzsy9FdW63kRUvl6Q3UsWj72qGE6r/PGVetJ9"; + sout << "in33oiVHTjYppFrza0ryzPE02V4uobC3y2DtoG/YK/GJFkhm68O0QKxyuBURMfT4j046fBQAbwUo"; + sout << "B/Ylsb+srIUK6sIEnnzdJ/8ve3f961Yx5Kywf/9kVZnUz5RoiP/bOWXm6jasSq7LEvX2nT2fweup"; + sout << "/TI83XIDWd0rFQoGTfuIuXFfLQbXkX4oZpmYLdr0kQAKjFtNB/JYV5PTE7PskJKD/AS6iYLd2pOf"; + sout << "cEYJxPZWSWUSz+EmRNmeDl3lfch9LD6VXgaxY1xPF0/1uQfU+BBikdVQJPlMzVB9QK17ir6rynim"; + sout << "CP038a8ctWt5RMBsaJPZr7bieh11aTTW2mC65Y3PYQ8WsXofuw4x3xoXI1S/hGCM2QvmgWq29xHp"; + sout << "5Jkp/Z5YLaGBGxCHM+QQm4LzggVnhYjlguAbfEmWFapwhzmx1L26gv2q2AUiozn2I7dAh1vDRKD/"; + sout << "u5XMODPPJE+NTsr5DQz7aIdEwRLZynp1FO+VN71nYDa5G8ruF4v320ocRKm/mQ5uzU4X7w69CLdS"; + sout << "kmag9jfDGvuolYqCooAts5b7tFIcC3WjlXeLPq5y8HmzY69Z72HrISpq0Fyq4vcZaksLSdpv+Pil"; + sout << "iUu6Vmti4LYLHYsmue1UgMzL0qqdRMz6XmCPBkUXYDS4oOoQuDcH5iJo9aWRoADapKUHqIUnoR3O"; + sout << "Vx4h5b2/XtCEb8dNF0ubJ+oZGbAize7SWTGnUCQdpz2wxAWFymc5/5jFxfVU2BKSPhrKRYImgnVU"; + sout << "8GOms14wCx4wyGuRZY3p9s1uPUBNNjDnJqVg6I+7STXJKrYzeP2gP227k2JE3o9eLehe66hcqaPi"; + sout << "egdpG5RfPN9X87FtePJ6lPjJ8j0Ysgoa6l+DDUDuEZHp6APIG5miY893oac8uo/r2RgRNVv6vLFo"; + sout << "8VIFR7IwcLj1zXvwriv/Szw3POh7y8svSb4eGKr1c1/5JTJB58Fcjz0AMm2+rW0Twwb3STxn4SZ6"; + sout << "nXOyP0btY0VCckovcLoFij3lsl21ZGMdfG9cHVlKL88pg4Ip00QmgcQW3QmBBoCjlPlRVSSsDHGY"; + sout << "8TBBGLuxJi7NPzU1HNtjG3cnYw0t49og2hjrIbF+6fHb9x0pPtnJZwX7SbBYlk4Z8v84fR7cjC+9"; + sout << "l6SLvaRgqkTj/aaiiHtC17zaxNhP9wHqXmPUZdKM6xs3vsAF1dYOLPlIv+nmMLBPfravTZDkf3p1"; + sout << "x56EjPHwT1GULNkLBG8iod/cteD0E0GIZf0g5/o0hfI1t18CMGxfMyeaASZugV3+KL1yOwXD07D0"; + sout << "lp8iLn8FlYgXOgUO7+OweJcIu1IwkzLSM2aNbnH3VDPlu/Ff8ZHL0jiuxhQAVT6jdWJSUEOf2yiB"; + sout << "8mIGCK7CH9Xv9l/grSdh6GrE+NvvWqKQrhX1k8wxHEM4PnhMSO4R+5dRFWeeh6cYfTNHTYWZ2xzU"; + sout << "run1L6tULZzpksLtHYDg0vEEq3hDS/3yf+/cvCX4ibt5pUqorTpAtNvbTaguUyBy2TOAy6fSFGh2"; + sout << "eHil/3JEYZXnfFtcBo+pIvt3LWEPlWCUHNKFLQnpd77Y6wUFn3Ku7o7Nu7zBxemmvxxYUXImAlzt"; + sout << "354O5G/5G1GUGFf1K2u11fFcuXrfowEE+1eUEpC0KLyxZhOJa5nA6dKtnQbq8wrGXxuGuJGlDSAu"; + sout << "0sUYLPmELc+kGUyY6A27B0FKFN50bP1U5iWLAtxt0NmPqOnwzvnj5GYQ9R/ZIpzf73N2OoYL4+ba"; + sout << "6E32ION5IxY2YQ1IqEHxsjzvDW5KoQCb8oz63eKgwLHBz/1yhA0ELpzG9ti5pVE4WbGKOtS/2xTh"; + sout << "RIgpnpbB7bPUdtw33cjky7t6UAO+QYI1kg8rscd/Ug44hd627JK61SxnGlK5wBRj7aoUxH2yb3Dt"; + sout << "jMgmcgZYtdIsHjuU63vrN0acMHULcyCRIFuFEtXgnQNIKjPUG3iuN3714Y9sncW5HqDGAYLyRpaA"; + sout << "69XPtqYEajN2uLF1Q9KIeW701X1diQoHw7TFq0p5x1oTMRzjqcz5lnLIM0DycqCPGoGAnyL0o5A3"; + sout << "Rw9qaSq1bB5VOdqpZHyN38huATstlHmcO2GqN2r9k2BKqqDYxuzmhB1K5ugoJID0lm6KfR87e+2q"; + sout << "CQL1tr6ecqFe4LkO6nRR57w6gl48J0tFYuxsgRcktYvkt/BF1JHNnw/BChE9lDBgBZz8TfAqUv5b"; + sout << "Ofi4WiuGxq5sKIa8a5SRwfVbz1h/MBqBEd9nDlZ1acbg9ZGakzwCfH0WrcArLiN9FqGqmiZClS0d"; + sout << "cUwftQDUI7yEoiIb18/479c1VU2q5dJddCBaabm8CtGtd8w8KTANBXX4pZoSuOlYUUQaxYsZ/avm"; + sout << "RYqfU1uOkHlmqm++EvfKEGW/Rmoq/fCIl4mk7YoLpMcub9wlTSVP2W//uvZG0LYwlZScWGOGmo6Z"; + sout << "xa02FmCXVIyXUfnitTEv2oYj3CV/57nW4+1jkXTcYhH8wt5wX5G2eO7qw3esj1x1kL07xmven4Il"; + sout << "nmKMri2FNxrWUBzvCwmfKGfPh1IkQ9LAfaJJATfGeBFI33RQdSILTaozNl88dHbUO8I+ZnySvavV"; + sout << "uX9Ia4Gvrm0nzV6WYbE4SCDjppsaz0CSZy1exA9t/NFoQFjvf25pszQ9JlDtwDm65ssYqCTLjyoJ"; + sout << "AjMMUygef0nD9WxtWkVCywmyR5HIZWqwV/poAROWBUNL72p0kYxDsm8u8D6LFDQJj+/rSQFr4jzF"; + sout << "RkH7zemYp52d7nSo7ZV84Pf11aTTWqg1OCwAFz+bcg5wZObUL48+WzPAPePNtlo2ef4hn3tyi/pj"; + sout << "v62xZax1oHnnB4ozyZqwSd7aH8LGN4G0Pk37PCagVXLEyLEGQXA3NqxxokimT93xORZOF4hZjhUR"; + sout << "EM6aKChyS4DKkB/HN8IEMrPcKL7zadhDrWX6aeAxBOIbF02jTyCo7rFEO4g3TLuy+aX29SStyzAd"; + sout << "bESo1w0hmnjboB/cUVh9SU0rWvyHnBveXBU1QFsmKEpEVXbD9iao+VArYlDKQgXS0elIQIJHLJ2v"; + sout << "8cVk5GtsXQU9Yd5vyzC9R/ZuUD+fuRcoKYwevCjnVegbn4mCK6VICihqWM5etEr9CqdjjzAtemN3"; + sout << "RhW/4c1v6Gfcb3LQIctJmk08SbUfkImRlULjt3sr+iF/9gMx5AneRDqnq1YRbiusqAfpCl83zBFm"; + sout << "/txWD4btmU3/Q9TzIYjcDm8JFIpptv1+F6myuN1ElJPj5dcmfBZ2/KQknRf7cFBSFCfeKS2glsIm"; + sout << "tTj1p8jK+qKf3GS+v/n6VutWgGXAU7bjZSfaWn3wfNX1rJXOX4Czp1dXxmXuxltVQ09bTpEMQL5o"; + sout << "F9LI3ExjgsokVCsnuAc25hTRWP6bUkWsd3fvbVK/Qrg8sYEpq+3836NJyb2BcFVBqmDJaW+MZ2+P"; + sout << "dJwzMjQMFfRNnsEBRwHuVTDwa4tyR+1yFOqG514ohI4UKamdebXPlrjm278ztqaN/4ASFsVoQb7O"; + sout << "nPvqiMT9REV9ZLiK8z2NK4dDr4KUCr/UihBZqX7qQWnFRyy0VooOFAyP9CjfohLthZ8Y0HWFDMpx"; + sout << "0imNkxXlh3CIDNTRXAmgFupIEQyY08sZX/Oqx4NdzpxWLh2S0VJmIupzuxWTnKBRJXWZEVPKqsak"; + sout << "zBsNzdtcqrF5dIIY/7d+8LFdEBIZ56wftYkRRvJt9S7mPIY++gohDronAX/ohfwXwu4jCm0OLWzO"; + sout << "l4FIAdeN0m65Nng2pWnF6M+qB+b5BT3I5cAh7GF4t3woiWOBF7LnjOYmHe7ZkzPzftxeZ820e+eJ"; + sout << "i1pFSQl3/H+aneIsjEQVGxgynGtvGW3pz/0f5rWrRdUG9ZYPJItWXCjJz2k/Eb3fHDA9JVIT9FEM"; + sout << "XCsN2FQguDWy7cDxNMKxVRuhomjmVCR50/W9/Wsjg6+tnNNhY5Iukh40jU9uN/WShMpzXNv2QKxR"; + sout << "bfUtL0zpojBbZOkAd1LZ++bQREzjPBFQ6G18eQa7DphNarttvCzw9qdgvmwpl3CfvW4PeFChxQ2V"; + sout << "SmfUcuFnESjEN30zsEyFtIJNVd1E+4C6vDxk+FRWl+jGbzNkQ1GK+9z6M0HgUxBXMXcHiJbMEJ3s"; + sout << "Ri6PqeKjZUAJzORigJmgOO5wURFI+3EVQ0NuPWfaQ/6UZ1n7fdXi8ncmO4jVwY1ptN/4tJL+bAm0"; + sout << "eu/c6ug2bfNCTx6TnutxBAg5AvxlDrJIUsx+TsM50J4OZmv7pI4dC0UWU3TpzaYHa/cQ+OY61UT/"; + sout << "j1/QPSquxIlWA46SGt/xrUg9iPIWw+2hdRBkhTBE6PrMMHLl5jxWpJ96YH5WDnifNnnhzRDcLfyf"; + sout << "9nZYWQaBzm3JangobvnsDPgjF91OfeFJhVPyUwS5u+Sxw+NKPj9wVAaW3gZVS2BOODAud/EvFfQF"; + sout << "08t3Ab27GGwWNx9ExN/jlq8EHAM+VRcyfn3hBVjLZdt9fs6zZliYdfyNaA/fgO1iLwY4DOSKgALr"; + sout << "Vs5+KDTmyNf5DkAg/W/d1tA/o76/qXUz3gHdm8Q1TqA3ZuR3g1BE6oS4T9Sg62oYIthP99doW8ll"; + sout << "FPr0aF+JHwXzG0k6Ao4wGdKmqi+n91pVYI7r2zE1kKtk7GyF0fSmcVnnP42TQfaUbaKB1bMT+a1V"; + sout << "NMIydv7QD5F2af3El9np7/1S69sGpoNK7PLow71jdlCewGVrq+iYxZaVhe3xHwerQQ1uGfmRxUSK"; + sout << "exonRQPz0yD5fkjRCqq+Hbys1CXe4iKa2uO0pW+yMlvRnMWpV3Bx6uxYR6pECBJ0x8DQHk4cBwa9"; + sout << "J/vAlf6dDaVh4PIZy6PF783iAPRusNy0TcCxXTJbIg5YqUb5QAyJ2HEPbIaKeZwSnhytAJKP98JL"; + sout << "JC+81f3cN/tJKfWQAJz4RgcrbhUfyf5yWK9rPOwvFf5WKfecfPWc6wpT4IqSrlDe3LQqxpFxJwfG"; + sout << "dzNo23TgOFK+H176FNuW7jXD1sUsh9MA3Yycm3BWZ85cK55DOI3T49K2WB0KnFLkhA9wPJNb20qt"; + sout << "m1M8QtKwF6jzpMnbBff/vM4oNL0K0QzXovI4gJTDUkHWUlZ3XxMk9PQIqoKZDp9v2JTNW6cs1zcV"; + sout << "Jcm2XE2ZhTCFmOQ/DIG1AQIzQfOY6s/VRvHPSkIJH76fo1ex2jj5CmbGw2JFNLPDQIIHEjQJGHFw"; + sout << "KpVGso6XF6CdMtcmlswoRrIvN0odvC0md0D1lf09ZuoyNt68aMSxgeBhoQYW4qC6E86Hef3TmzWO"; + sout << "esZCvDs2I3UVKgPEkICz2TEHy7dcC1VzU2k/F/cm3+y33lloDem4Dh2igflppvhthcYCLOmFiEW4"; + sout << "YtxiVJKpIAKbNHqv9qpxNdcbLCsxkENYKwmmG8E3fLSyE7St/z3dvuTuDI35lxANzv1N04YrvfEr"; + sout << "dIPmvepllDz4Ua9TyYQoZezD6UsivW0gfWdbG9dbYnwAZRKEgQjRD1zvGUt/nGCUKdtBDh/Qu7jD"; + sout << "c8KMBJx0meoY4jLZBSw3Rccd/KOvzC3UTSsBeTBSkQL7wuSMWZf2cTI3BiWaVTaUXREVmUG0eeSy"; + sout << "W+Vym1qM2tPWkUm7V3toeS94LU44OlGDPyoHHUOsT663Zew0ao3+rSqS+KASC4L+6oqP3ffSI2WC"; + sout << "517CtYtPA09gFqdWnN02mJ/+gEWrYUZXJsNh31AGQ4e4N2L1Tupy+L+mgkjTyHeiV6dUsvVQ2J36"; + sout << "R2VL8EOQcchBMinBo0WKkP6xoTPcCwwMI7T5sHHR+KVUs2DJXTNluNFhYyOic2ImOwhoOHKwfr7I"; + sout << "bJERGZYKwwBqGPO0mMnB81MZFumFzoo0SNhvNC9a74X8U6gKtDsQCeIHNXWsWxaO2LmjhlZqXEMC"; + sout << "nrLi4UnXweqUgDscbWdq6fE6Ad/ZimmUJlj5iF6aXWj51B2VIgtYXxBFlVPURZDw9z9KPdyidzM1"; + sout << "PSitwb6KwWZK8fJ4ZUVOUt1bIki9wAnpyDdAhrPtDVtngS8RJrcCVRrjQ91Vhr95kSTG4b6VtN23"; + sout << "m3lhkU/T8RAVAphAait/TICzXeXdnFjALDIubsx1e3FMhV/TJHflVQhnahwTaUuIVqb1ZkmP9aMw"; + sout << "0K1G37NbwIRzyHGVWnvJXxKfLeV3n0OJZ3W5dDmTOZSE1s7JdJ4rdEXpiMVsW07TB6JEtIz0c/AV"; + sout << "IvDmgF9QH5Ly4Ko171Lg/tVMjBIhCiBU4zQco6brJHx+MoxEsNUXgUCNoMqqjd1a3F5wGsvhlWpn"; + sout << "iwNpTje2A/ccVQmoJpDSNByMTB/0GsGJTjZ0dWR0KiltaQaKTlRFlGn/2jgbk+i7ujAGj/Gls6DY"; + sout << "WoR/asND5U/rlyzwJS77EefK16ew8w3Nfy2g/vVvpIUQBG67CnjUzxpnutsI7KvZak2Yehg2NIZQ"; + sout << "IAuyXMd5zixobwku63RvLT6Fd6D63HnQhEtgtPTcLJT6rQOf2cXAXN3RLOghcpM44flJ0PW679AQ"; + sout << "pgVrmELVQr2Dvv2fCb3W5k/JdgvTKrJ2YPIY9Y7q5aSZ28VC1u1k6y/KsppJ+6t+TW8PloS+qKNb"; + sout << "2tawgDRraMsNVUTMhABYEPZ+qoYspJPV4rKfVnELc8otpvkR404ulUsELxml0TndtntTjy/f9lmi"; + sout << "I7pfDqPsOUfwUtdhuW1XyIhX3RFmavc/VFLHu8gVSmJgdgxjraoe1ldt4F8vgQu22uZu5yBLtxpw"; + sout << "WgKVifBtJx/lZ4iAPMNJUZt+Fo+1K3uofYmOg78dsM3BSeghNcNI5ki4NrUJ3yWTzfvrUIr90Ee3"; + sout << "VUaGz0/wtOsFI74ef6BaTCYNAt7psf/dwoow4r4FzHgXNqr9MtWvfp22gGCaRLuUWTHgRzIkCflH"; + sout << "YyNFYHD+yTaKmSrYay2bBYFGRFe6Vx9b8FR7fmdyqq4HzRwQ2/FVeSQ50OHeT3vWTTg07M31JkZY"; + sout << "VlTjnw8Ew9N+o+qI9AFaftksVdXr2YjXMpJMwbHcqK6himQLSkzkUEQXsBI6vQSPobBtPM2oNlCs"; + sout << "oEf4AvHYmHxyyzVdSRjQqEoekWURI6htJxRQh0DaDhqOaRQmUp0AkB5StSD8z6o6Wyf0yvzT8OJo"; + sout << "NEjaoZbz82eivQUGf3RmYLZIx4xjZ32GA2fUN50RfgyF/KzW1GNLLykJ2NK+saJKGZB7RC2ZlsI6"; + sout << "CmkHIZeNwG131Opp+ygasp0GBJUzlKnEzmV9CIBigCv2oPPxewEr7T8/OGDPXWhY2LEgVbgyxLdD"; + sout << "MnZlJC28aanAWGoGMcuXgyacFhOWc7I7TTml8jlQv1oDRJU39PQwZZ93HwhhZ2s87wuRYCQdXSH+"; + sout << "pdtb9YblyosciNtYJguXl+2iBdGGIqNAp6oxYj/rI5l0pPY7EUkGEnOzQq/U3zCDPL1tdUymaOh8"; + sout << "+XGYMMo5L34oxvi79Rh4snSaOO5h0fl0aaZES/v/x5QTtQu4H7qwVfWIsOg8ujFFDFeZmmJpb7pc"; + sout << "ff03fNUsHBfS3yf1JvrXhVGDh7byfy8DTFCy07gXUaKJzQJNOjEqX+iu25I+RNxXbfHsRxsWdNEw"; + sout << "6aHRUUb+zThVJGkglO5W/S32SJznnjkHdFu1WXBU/DbM7b/quJTUgWnZRLy6CugMu4I/hp5ArKCh"; + sout << "vekyQySI/9TjOEPEFW37few/H060cZcz1V3DilvOTBrCoItzdjxZB/cmGxPln3mhkPqGr+NsXRdJ"; + sout << "tPMRc91NxvDP1oBnXXEYD418bmxfvqbXZm4Q+bMq07TT5rWABCZ+NFmkgkXryGhfKQWqidxbRenT"; + sout << "8teao4iAbwSHWB/Za5+OPYUS4b2u7OULXqkmbQDnTeJX6ou0VFYUpkXStbtLITLJrvAG0BrMq7G9"; + sout << "09vbgw6e+krF5JHaFXKAgVf3/SlvRrZi6zgPT7wsY1WLGlMAFNAC58UhizDbVV5Xivj/mVxi29s2"; + sout << "tZHKQRshEnuw1KvEp61O4HEnO2dw27v1000YgRz9HtQp9Ra0SsePIjlh8/h5O2VTqJgMlqgAiy9g"; + sout << "l7nvX3Ft3a/K7S8GlVMwM8z4DiSzNf0irgnC/+sVu3ZAy13UiviJTv0gZ6iVL78GdPSqhKq6x8IU"; + sout << "JGX2sbzthVZUpYPm65WjEBKUKsHWIdeAokKh+sS2TGEAo9COawULVwzYttKSY30UlsBLL/ofpjF/"; + sout << "eCWfzunbpZ3MmkXJYVygXCsGgEuxDmWrhcxIF5/pPZDafmaxqHCml/zxew2twDN9lEJV/jqZrwSL"; + sout << "I2awqrzxIp/4TfqYj4C8HOQGb4N246snRl3iH2tvzQZrCg1I2mp1s0+xiHESKtfecHfMt8hbb2QZ"; + sout << "50c/3MAKQb9WatUDqfZuTnnwXMo5vm0Sh9KqLYQj81LFOzKzf1NDil38GSmFGWPsNm7vm8Q6S6BI"; + sout << "MZafbg2gM+ohtKS/u/36ZADS9/bxf90Fzkn5UEjZUOIRBhYowQNilZzHCABNNXFO/5SUzSJqgLIA"; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + // ---------------------------------------------------------------------------------------- + + }; + + face_tester a; + +// ---------------------------------------------------------------------------------------- + +} + + + diff --git a/ml/dlib/dlib/test/fft.cpp b/ml/dlib/dlib/test/fft.cpp new file mode 100644 index 000000000..9d491e8fb --- /dev/null +++ b/ml/dlib/dlib/test/fft.cpp @@ -0,0 +1,553 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.fft"); + +// ---------------------------------------------------------------------------------------- + + matrix > rand_complex(long nr, long nc) + { + static dlib::rand rnd; + matrix > m(nr,nc); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = complex(rnd.get_random_gaussian()*10, rnd.get_random_gaussian()*10); + } + } + return m; + } + +// ---------------------------------------------------------------------------------------- + + const std::string get_decoded_string(); + void test_against_saved_good_ffts() + { + print_spinner(); + istringstream sin(get_decoded_string()); + matrix > m1, m2; + matrix > fm1, fm2; + while (sin.peek() != EOF) + { + deserialize(m1,sin); + deserialize(m2,sin); + + fm1 = matrix_cast >(m1); + fm2 = matrix_cast >(m2); + + DLIB_TEST(max(norm(fft(m1)-m2)) < 1e-16); + DLIB_TEST(max(norm(m1-ifft(m2))) < 1e-16); + + DLIB_TEST(max(norm(fft(fm1)-fm2)) < 1e-7); + DLIB_TEST(max(norm(fm1-ifft(fm2))) < 1e-7); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_random_ffts() + { + for (int iter = 0; iter < 10; ++iter) + { + print_spinner(); + for (int nr = 1; nr <= 128; nr*=2) + { + for (int nc = 1; nc <= 128; nc *= 2) + { + const matrix > m1 = rand_complex(nr,nc); + const matrix > fm1 = matrix_cast >(rand_complex(nr,nc)); + + DLIB_TEST(max(norm(ifft(fft(m1))-m1)) < 1e-16); + DLIB_TEST(max(norm(ifft(fft(fm1))-fm1)) < 1e-7); + + matrix > temp = m1; + matrix > ftemp = fm1; + fft_inplace(temp); + fft_inplace(ftemp); + DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16); + DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7); + ifft_inplace(temp); + ifft_inplace(ftemp); + DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16); + DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void test_real_compile_time_sized_ffts() + { + print_spinner(); + const matrix,nr,nc> m1 = complex_matrix(real(rand_complex(nr,nc))); + const matrix,nr,nc> fm1 = matrix_cast >(complex_matrix(real(rand_complex(nr,nc)))); + + DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16); + DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7); + + matrix,nr,nc> temp = m1; + matrix,nr,nc> ftemp = fm1; + fft_inplace(temp); + fft_inplace(ftemp); + DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16); + DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7); + ifft_inplace(temp); + ifft_inplace(ftemp); + DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16); + DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7); + } + + void test_random_real_ffts() + { + for (int iter = 0; iter < 10; ++iter) + { + print_spinner(); + for (int nr = 1; nr <= 128; nr*=2) + { + for (int nc = 1; nc <= 128; nc *= 2) + { + const matrix > m1 = complex_matrix(real(rand_complex(nr,nc))); + const matrix > fm1 = matrix_cast >(complex_matrix(real(rand_complex(nr,nc)))); + + DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16); + DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7); + + matrix > temp = m1; + matrix > ftemp = fm1; + fft_inplace(temp); + fft_inplace(ftemp); + DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16); + DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7); + ifft_inplace(temp); + ifft_inplace(ftemp); + DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16); + DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7); + } + } + } + + test_real_compile_time_sized_ffts<16,16>(); + test_real_compile_time_sized_ffts<16,1>(); + test_real_compile_time_sized_ffts<1,16>(); + } + +// ---------------------------------------------------------------------------------------- + + class test_fft : public tester + { + public: + test_fft ( + ) : + tester ("test_fft", + "Runs tests on the fft routines.") + {} + + void perform_test ( + ) + { + test_against_saved_good_ffts(); + test_random_ffts(); + test_random_real_ffts(); + } + } a; + +// ---------------------------------------------------------------------------------------- + + // This function returns the contents of the file 'fft_test_data.dat' + const std::string get_decoded_string() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'fft_test_data.dat' we want to decode and return. + sout << "gO1l2wKz8OsyeYMPYcGx6QdBG65vnrB+omgAJ7Bnsuk9vkTw/Y9Y/UZEFXhVf6qnq92QHPLV16Fo"; + sout << "a+IUHNTjoPAfBOTyfb8QRcTj9SaWpxA65+UCJ+5L6x/TEyPKDtB23S0KRpRSdfxBSW9/rnUrkIv7"; + sout << "6i6LWcxKzdsw2WGsRCX1k3t0adQW49m/yb8LV9Loqs7/phzY7HkJ4D2PLtpc6Wyk1qG/h6KQ7nkF"; + sout << "GFkHIoh+xKXhHpqWaSofx8H8m/++H++g0VSPqfQ1ktFz+K8UtiGoyR2GqpP+br47YLXG3WqVU5Km"; + sout << "Di3+IjQoBH2m4jykD926aRvdRrgUH4gZunokl+U6shv20Zm0NL8j4A46/2f++YPGCVBNJJmcJdI7"; + sout << "9RlPL9SFbJ8rnH5bbLvZ2pKZmmbeZN78yzLUhdGwn4DGpf/Zo1fU2YPUjVKkwY6olW4w3tiBl05a"; + sout << "cS1HwBeQjnajqsXNyudbrBkM1Z9XiwM+J5iMsu5ldaJ8iLn30W2Te2RnZhJRHO8MgL7Fn1j0n0Qb"; + sout << "8dB+6aQYv0l/5LQkr5SX6YSRYX5b5rnqhi8IzJKms6dzoyBm97IGTm8pRxtLXcmsk1MvJcHF2gl2"; + sout << "CslQazsl5iIS6fMxEodmlMdwdfIpp/6MqmeIydSHwdyJJZnNPl2p5X+Il5egmwdaSoDQNphPfTaQ"; + sout << "R0Xh3xqsZKgHLKxB14Rsf/R7Eu9ZASTByX3UrEHsSzLSUo9/G+tS3n1iC30Liusksh2Wkt+/QtDy"; + sout << "A1ZX31H5OlSFwCYC/TYitwyl4U9k7WhHBDoT7MdmVTYQEK1dK48nwOhnZa9prE8n3dD40CCe25q3"; + sout << "Qo4VVYc5tBWu1TfTbshvkmHAcp3Gyw/caqq6jdq5Z2BD1b67i/bY66xhmowOFS8xeA7v6tKdkvpp"; + sout << "Rk8FegzVdB72wpw3872T4K+eplMDcCPGkwIieF5pZStWxhGsNOC0p2wvpFvTpQgfNOGUvRt69hsd"; + sout << "xaUEYlWZcY3sfsiOwPGgBUEEv6b+8W7+8Ddj8Nx4wG+bdWozphfz7THbmOeaDM63imIEHmJbZ47I"; + sout << "QgoyzFD5WoWtZ1wMEv4LL+a63B3FzBcvPvdPaa2QEmyiK9yN7GEePs2Fv2A3ymhGw5NeR1dOzAjz"; + sout << "lEQW01p8opk/dpyLO18zj8d+Hn4EnJkKD0p1u+XuLRda8AnRu/WmSOOpyG5EUrUoEyuvbECLbY9X"; + sout << "3AMgzkbxltmZlkyOOwfCGM0yumGYKdz0aGKdyid7ddLMTpQ908dCNLyRgTybdZG9137PQirgX5+O"; + sout << "08T/+L4EIyyrslOYxpUaLm2ASnSUgiivoIJvfnu8IeH2W9fPupY89ioXIYuwZU8f9FDCA9z7peQw"; + sout << "9H6l4PDdDrB7nwQhncpV9FYLkHQLbSgE1VD+eL6Y2k48pI2zUndcoHEZW72NcmK6E8fDvfgbKkYD"; + sout << "m02RiGuj4tvEEsIVuVa29Q0JGO+37n7Mlz7+RMcUMo1pLnh+jibas6R+1LCy7b4ubiKMFB1gvjut"; + sout << "gjMABy1dJxSOdb9xUa0K/Alwwu3MxdkrbTxwqkn0C2JnVV7z9S2I+PWcfZKzcpg8Itzh/ON6I/DE"; + sout << "EGK3s39XhLI2xPg3PE9R9QMaisqxb3FeP1NkBXrLQtuQfrSk+KZk6ArQWVgtem799fxgipQsa5RH"; + sout << "z2Dq9t+pJzNGUnWg5PWzaAY3lWMscn+BIRhsZfDJ3QBtS9Vmib8r2dtYwXi/Q+FhnAHFfcXbhDC3"; + sout << "GHn16aP2PY1sw8KMtfPRAcqY8Ylbr9EQXjWoIYUs0YyX2Ks8ZgibunTPFz/Wu98RVYswMtjubFaJ"; + sout << "jb0pK9S6qoe/w10CAAHqoAfca7uMOxw9trZZmjCf5vF4leH/nDgsNjesYn21rE6rLhSbg8vaZXo5"; + sout << "I/e1uhZlRz4ZNnMlZSnL70Jt0IjuR0YNphCsGZjmvvZ4ihxrcLrHvAcSTJuqW5EARtvjyQWqBKSP"; + sout << "5XhlkrI73Ejvy+Lhv6n6O7+VrfWa/tGRuvvAToS1wPOP1T2oniDXsNlD0QbMnCao+dTWgkTDiNTk"; + sout << "sFxsoN8YjwHqYUAp+hfnu1Vh2ovyemAUmo87vuG7at6f8MgFSuZffmBkGuijKNDDy7OrHoh7+5/+"; + sout << "aOkcvp0pW3ONZ4l6peRNvzaW5DEBTvcZGvRwVCHWII1eGpzeJKaHWvDfLqjaPkFrG5pR7SGCY/9L"; + sout << "73W2U0JCe2h7VjWbCM7hdvJEgYi/mEarVQpt+0P834es6Rm9rsMCbgbrWl7uv35+LVMTHU29Oxln"; + sout << "bDzBUJQs5KIA81IWR3R7D+HuJvpMkAYMF73c1owI7K74SBOsTq1ayC81aNlK7YwOOjZyBqwsQ5sy"; + sout << "zZi0k9AcKRGmTC323o7Tp/n/gkAU3NObTnqPEJitjGloXqrhPvorixBhHXSZy+wgL5R+04KiF1uU"; + sout << "LEFOzJ0zKUMstTB+fgC7D6ZnVEtUq3HEYnmaRRwEhRSgMTLXE8VvnOdo802pMVN5GMCkH299rJm5"; + sout << "Ina8mTwlC9JrNuYHot5KK/Gny4KPyUeS51cifByPwroemwBHe9EmKCkcEJPoDpG3QMaV36aopyJl"; + sout << "GwhZxaZSqbut9XSWr0IMxHUkFeslRB+n/7Vx+xWpDNjQ7JA5S/B0ZW+YBQPcjA3sRQTey25JD4Jy"; + sout << "RsULxNY5e3mjn59fI8OpBOYfNPTt2Jzppm1GDpym0LHuz7KZ6xk6QAyogk+HMjC/5RcQA7zJWDRM"; + sout << "dXC4CXUjrBxVzmm/YHXv76LrsaFdzJgn+/qzlM6IvIgicMhcJl+hA1swTkgcw6JRalJiDqnvapKP"; + sout << "V+T+/X5PSNMswgZURHQJ2l0PkMrUT909pBOC9t4GCsK8k4rYS2o0I0UYfcpm4jMRU5X34zlT8Qv+"; + sout << "GV3mA0oGq1U2dJwArlPX3gI5sZ2Jsw7Qa5edvQNG5GoRb2j2Muo4AkZXXjbx0KEa5leLIhVL4BAE"; + sout << "2GTdbL7T8hUGY3QlRQGwSVAytjUfXg4jCyn9w6ZbxUOu5MDBuCEtrhRSJNKuBLInK3Bh+fr2FshC"; + sout << "T1eDtIFE2EDEaSbLj4NCNWpTFdKMXZ9CQg2VtoVOIJfgKzqAjjcWX8kqWpMFlQgtdTfIqN7gnFit"; + sout << "do/FO0OzLghevyexHdl+Ze+MjITKOF0mTPPMkcIYcINIR1za6q3rLDZg03+GouzYhL8lwM3WAnkg"; + sout << "Qg+NM6reQATKFK3ieOxacZYnIwOR/ZMM/lO/rHY/ZbdAnJHbMBWwRtK1vDi+o+ZgS7EgsDpsmz/l"; + sout << "PguXPK0Ws51OUhIJJ5YDBv+nVPJabxOYV3dU0z49xFpxNTW9pTISo8mKZvLp2D765kExGJ9YKoAx"; + sout << "Hfi6WEg3pFS9YQLNhOZjE4bQThugIWXhi+2OPgqUIUoV5ctSnP5Lv+xhbkZfjnQQQQffrrU4peSz"; + sout << "6CuNEVLuNuG/mc3WEDZwf1HxYv3u9pr7A79QG0EROf23zPzaf5biE9e9xH+ruPApRHM58H2RpxXU"; + sout << "RlkYnfoAUqyvT3Lhhk6ngv8Axhi4otvz7sRiXQmZO7mtzWzsCTkCJoziwRKlD6P6LYnbm4fRYP1M"; + sout << "MvOuW3NhsQNrsDtgMuvqiVQpRzg157ES1i1qnTjJxTD5emK1RljuQEetbGksyetctWdWiEd8ZfSh"; + sout << "DHBJC2FLucmkMt0LHsVPnk4ni055uMRdKPRKjTE2MjpEsxR52xiWR3MtwXiEhH9fZnUl1IdBl3PG"; + sout << "TfLiZ286m4ePm6JOgNM1chtZir+q8pr4ghk/xycWvHmgkqT9dQcFP8iEtlVLCS22/2mS79cTev2r"; + sout << "yE90otp9vibcTnpORzrnLrMhhpmYRTxRjRaHGhwdJYluARJFBBVTMEenK2ubdLOJ8skZjLzPv1dt"; + sout << "9IrO1sNUwrMpEie8PG7D7DzQ7//jdlC/HUZaGKrwj5aMUULi+ZYiBLYoeL4N8ozAK1u3KtXLKlRE"; + sout << "3Akys4Py8+CmrY5qaaDOXZvwl3FF3skmGhx5KValRXrbndqr3Cks0hXglHgNonZh795galZwu0Jp"; + sout << "ww/mTQLCV0djTdEfjXBUnP49zyGXWWsEsl2jfqEAfBDcT4+mMzAUtBSwwPJYXXAJQz45R32MThNb"; + sout << "k21X+rw63QJe0zIbOJepHz3jaedMkj8GKNYBjqzibNqfYelunBUqW0bpi81HYdN5OFY/3GNKgygG"; + sout << "4R5HJaP+x9e1HxehpI/4pKFC+TAIb29uSV5GtkNTb1fYLm0kjeCZNA5GKtf42gBY52N6STl+lcI0"; + sout << "gD+jJ/ogknne3sRtEJEtCFFe1c50oikyJamQbeUO1PcDUBt8Phl1rI/p4PTP+H686usJVhSDY+b5"; + sout << "9CdS6F7XSSDiXlpFl+Esex30fRZ8zAQsTo9oN0sSZUUJKcyVk0dCqw2mHWPpyM6hYKQ3ij1nYjYl"; + sout << "3PzRfFMlu+dgStcBn70jvlEv5pOWXb2OqrN9nJtb29n8jrB2K2nlbcYoPPiQ3yXk+Wpom82LoT5W"; + sout << "F9NeNwwAB4EDWtB96OU6noW8NHJj7NiADQJGvQpk/3JzIzeBJQCxULYJMRJdBKf61+24F791reHa"; + sout << "qrH+rLUrrv05dIPDTUvGW5LQLTTFFa59OmMIu7WJE7Ln6gMIwDw3FXnGFzaWnHlHL/9jJ0zM1FQL"; + sout << "kfK4wTd++GbeI0gsnXWFK0N0kV/FiHm++J4udWwIXZxH7qZCHtwlT/5oGDVujtAtOPag+txUrjVc"; + sout << "G4iLeiPbV/2Vfc2D1oV5/yyXDDii9qrLH6SOOfgvdiJZr7X3uMUIDGO75x5wBDSxr9t3I2CrX2dM"; + sout << "M6kD7U1+bf5QVRbkh3Us4NAhFVnLNEcrm0x9Yx0wRmxPKgJeGGbWi7/BHi8ShIFllizuxuMyfypC"; + sout << "hhzSlxxbYAQwtcC3cHEnyYZAO7HC6hyke+HQJfxAmKyfguGtzEzsiG18XJVruwz7IoOpZS/O71zy"; + sout << "Nv+T8trOhy59ZUAgIsCAAQJYEBWl/T/qFtkE+tITbVRKtHjbxHSeN12OnHFRoKguJYaakTo4qLs0"; + sout << "fr4E4nZUMfjdF7oI7YutegY9TkiJ9ujLJw4pfY1XRtPrRukEl8orypWXq0gErnYO/RVtK3XImrDp"; + sout << "LY5sXH5pNzkqVH9VCl6lh9sg2HWjNwv9bDcDlIhvTL19Mx9yUtx/iQtG/OKy22tW6ByahPNnMNtA"; + sout << "tBVB38RLf6eJr68mhn10Qg68cXxVL7/zEIZd9rUaCo8xCzeFblDNErKfG02JJ6fbQ6M6ZdNez7Q0"; + sout << "x2IYbz2DEk0wHmR7OtA/oTFMvJlyMt+dDWTEpHnvqkbe+veENpxn2WWy3UsumkvhhtzzmLxyD6Sh"; + sout << "mMbMPwgUjvMG51JfRrgMfJzT49z0sebSfzvid/9QV4lNkR7s9nfUJEwAued4S4klRy3LiFdQhjQR"; + sout << "FOZZNqUge8vxVOzVCfS+xsjvnGrd7azt7LJg6wPXFgPfeE2bRlx+8AoRFG7SUpudmm/bkNw+uNgS"; + sout << "YRdaH8p16RyNoMlSfi/7BNDhtKwrl202pVuCqhFey0mPYehYee2HhLZs6ph+HKMYy8lZ/ac1Q17d"; + sout << "1tcI4WH0Hz0B/3GWl8xWfoq2OO40EIjuCPNhk70MpiytWXggJrKoKPu52GOqTU8+jZ6F+u6U2muZ"; + sout << "6QZLYXDwPaNz/lq5U4ACw767DkhUHd1/h0g6r/RwtLKxdrzYldQto99TAMmHc+z9aIciTv7kl/Gs"; + sout << "WA58nI8aODhwjIkOGaExdlR1k/3JR2tAAj5vRzYlJeakhAA82pA+8xMPZr3HRKMPCcCJZiOFUYkA"; + sout << "AAAIGp0hTQAAAAEAAAABAAO4M25wAAAAAQAAAAEAAAABAAAD22vyAAADGD3aK6jS5oPrMPzwAAjt"; + sout << "xbxcPWpGahG+GZHRkwABJjSNeG3dJthsOADhIXZ+4e8jAWWTyn1FQrzNkDMdvhTq6iuNBtgaodaU"; + sout << "LlEAl217cnKFEVX0Gz4mAAAAAQAQ+2CiQuAAAJqBbzxJicET21QADU29xADCbz6wuq0KAV9tgJYC"; + sout << "z4z1fuKC2IFuACLkvCvxAAAAaRl6AAKB7r5zjcz2HSjBdaDc1QgdOUAzmTdpgegYYD9XVQtaphsW"; + sout << "sUhXeeIliGybsdhruMlJ8hi2YzzBBZc1GwjNawB2sz18UCIbQKBoDMBo39MbAAAAP1M+WcfB4DGK"; + sout << "yDXgydAAAACCIiA890S+Coil7foud6zPIspdcqIDxhws3Kiht10f5NDHUhjgTYsAjtAgSoGLQ64J"; + sout << "iNm0zrLDZkuHC4Hm69wkOD43AHbIxu+k5r5vgCW/m4traQAAbol6XlYIroFESJxqPkH6Zdxu3wnY"; + sout << "DA1HzRnsBIlQ6SvPrboLtfVXncBN7aM7vLaX177RTxKn2qep5WX3/yG9sAP1QSkqapaSdLpB0Q2N"; + sout << "9t4O8ryKiLFvPQrbFhK3Pux8X2PWKHAzZfkmUSQh+OiOICNKDbPRY0Es2GaX3Qnl2oIVflDINm0i"; + sout << "Y6t8x0AZriURrafLgtzuoqaC7rAYwb1iQoFJPwXiZJNIk6W8g0623IGSygd7aZ+xqv5NjU+q0C9d"; + sout << "0PJs8kZ7Db15tjjTBIoS9gBNefJg2D5TCJjEPSzPcZeTjPzeZJeK7v1KHgunnfi6igrT1efQvxDa"; + sout << "KlbIBaR5Mk6fxdih+YOFTnVsT1mE8b8nffJ566Rgc6azYkJizsRen8b2g7jzgv2O0BOI7VtdrVpV"; + sout << "BxoQ58p3EGMCV/mlDhcfVTctvr8hrjUu4OhN7UUoqi94JRXL5XbNDcCbFwXScln/Lm0bkzNsDnlM"; + sout << "R4OqLiz3ktIpJhAtUemrNZtPa0+/Ge6PILu4jPNom92BcmxjJusoTOrHSTQg6cYtcSB8spjVRdpj"; + sout << "tFfYfDY/6rxKpE5X+LU8P59b1/cdrNSbqkuRIuGFFzkQDv1BUNfa50aIKo53OmvDWkeEpkpS2xD+"; + sout << "hz3uN05FteKkZ7kHDPmEEJ2lUHB4oicxgkseb8nWToePhOkr2JcNakTx6yc+ZT7bzPcoT0hueCpg"; + sout << "Ljzb2AQ0UdAJEAD9eruD1rF9PDEaXD/W4D3ja2EgvEY9wSMR56Ne8LSi0jeFjp8jKxcbmBo848xY"; + sout << "dofjq919a3V9KDRUZ3d9t2Bmfc4yFoS6nBZCVy9klxK2ZaKePGjeCbENr08cfenUT7kA+ZQURi4u"; + sout << "WEwCgdui67K4H5NPvbq+QAoKn9d+ZHx2YfullZgBCi34oLzT23yccD4uxP4GbZVXakvv9lLoq+rf"; + sout << "T2uxZr36V3aJlhqVoSReJ4Oqz5qbTKNH9F3S14GFJOTByKZP4XJwVytHIcuu9JLpQPDt/nkREX4Q"; + sout << "0cKeNqXqdujkp9XCAZEJ1RkvHE+F4tiUCTKHvoCslgs4x9alnUWVps4Qy5CMaB8+x1eHM3gkuf0Y"; + sout << "qNIFuIXUybNj0hnKO3IV5k8m2MURJiZsmM/dg0fkwJG0DIAmCFCXdugfHMHpqXZR7IBfA+Q6BICL"; + sout << "Njxa4BWCRQoeKc9LD1Nits6rbVmqkKPAlwc+yhfYzny/3/ZHfkMuF/s6CBYugf3dYMIch10I/QRi"; + sout << "9mVUWdCIacS1G42Yaja31j4s+304a/luTAsDKlOObmKTzDV2fDhbKSLToOY9iXxRK6KJY0GydspW"; + sout << "ak8OvtNCCVvEsA5NccWOn8njo/sryvQgM9yzV8XCI2MDF0blRNFJQR72EDdG9NXKgg1gj/vG51dH"; + sout << "CHj9E5ffneorEoPjn8pfOny29jcb16D9lVc1zcp7v83pLXuAyUp3lNC2ff/PcIbRMpNns1MnyV2J"; + sout << "8gjsQTqTL9AAcvz04ohVo1LbZRl6rl2D0PHAOCOJQAJ65OA6BHxfeh7EDDnniKfLlI3CD6W5XZxW"; + sout << "KPVAa4RGyCcSLbmb5d658vkB9rBKvaqncbiarQVbyLQ5cCRokVd9HTmqYX4Ky+0yPVcmoXAS2B7Z"; + sout << "S0zXkeuA8Lo2ZYRad7/CBg6foq40dpQ9EPw19H0YquMXeGVEsgSvu7jBCYrtuJzL/wlxe+plrjXy"; + sout << "E4iwHoZ78KKyOgwTSqKtynp20YkZG9NC59XWvyrd0oTBrGlFpOjer+OtcEgl+3XYl1q66yJTyQie"; + sout << "x/Q7AjBZWPmlhEkktLtXYW1yIqr+EmYJqJJ4Bbssvsa1/jd/EmZQV9//HRBmber5Kw0C8royzI93"; + sout << "uzM2t59sG8hnOHXCjAQgV/HuTCTD3hzeDgrv0aHMfIC+mx5Xkt7mLhrhLeODhJxguHb9a4pQwhfK"; + sout << "gD0DKhyk0RwVZBNF2+3UfK6bJ1zeUgf0FJ6Dvak5i2+BUVtohn2qjbcIRFZKEtE8Oekca+FE+9+S"; + sout << "mTsAIvR/cbfuDVG4cmQE8sPP2nB75KQJHylHW3ZvR5v1icnLfob/CfYMmfSPYDgDFx9cYroX/luf"; + sout << "lLKKxxQ+1r4TIUGKcBV0ZJcvjCwzU3C2eG6oN9P99+qndIXYmvRvpLmq5QYyjsRHmXn2MTKSa8nl"; + sout << "y+vPwWboutMugCUmEfTXhvc33o5/KKWU9cVSP5G3mrGBDu4BfG+GeqP2+DtbHTi9oOYm7723fFF8"; + sout << "CUxOvoAQkOQQWC7Wfd1QGgsVPxhz7FTsviov+G164lh/4Qlkqy/8pzEJdQK9uYmCsvcip7tKbzS9"; + sout << "lRX9OoPwQZQPWPosqZjbYBmEVr/e9jtFv+2pH8p+5S0GI4qiNg8n2fZE+isn9XamDpOqNEpKk1gd"; + sout << "Xg/ombkOajBBxBlpNbXQa4aZylfiz3ANw01leRleTqeHB3HdNnTaJn3mr/lkKg3DYIA1N9/Iv8jQ"; + sout << "1lQexZCF0jEDg51sUNPF2yDGnpSpZVkDLXCGIlKe4BI1/pR/kiqrEdg/ITP3+tVhND3x/pBWKlNs"; + sout << "AteO6/IZrK4XOQ1DnZgJSAdz3Qe2uO4wY/7MQBhFO37V5gAQAQVLcXsik3FDcobieQdvvoJM2Z3b"; + sout << "i+FGg1vcn01gNGDH0PvTKSEN/KFYTpyKDw3sjzgLbfEPOJcGmIFj4JUuaEYB2TfQfzOwzM7QCDQa"; + sout << "6jbMHpN4VvRPQqZ879vRWzHhg8P2M3rJmcQsrjaJDIL7t41eMOmB3Ey4Vajya50NPPyVjXEtFdYj"; + sout << "TPc49LO+npIw81WlxPbynhk9lmLeYwr8LALpc5dNFr8BkfCNc3C+IFi+RHSdq6tj7vXqvE+KtH+7"; + sout << "lFZNdvI4hFiIrDJssynPUFAR7bob0zEq2RxDGJBSe7AmmwhUgGCHRFXlpJg8bFCFs4rCgVrBctry"; + sout << "AY5TT1WWTTa5P/jPOCTfrUOwjxgGD6ubejoyGIoBPfsRM57XSYv9gQKTi5hk3k4qHQnItrf9p/AO"; + sout << "LIxqMGF/nky2Z11JWz1krqtE9phBhmXnMk9ap1YWJd83L5Et4wxpud+J2JxJoM4kjtnG+HXqdEsC"; + sout << "KCg7bNrSyAjtohz2vQvpXsZjWkq7QT+fTzI134PhMzbqnsWdKi/dDsZBNbXB4ua60uqFb/tLtb7n"; + sout << "sNBfRnWfch0qZLuil5eAWkLlRC4zyaF1zJmchHnVync81bHIJVNj22+ctIbdN+P7aCYBA1n3sl1U"; + sout << "0LDKWsXycnOxmKFUMm8kh+f/eP3BQu2Pe5W1tYDQnge9rvF1072VXvjpDBIna/VPuDwCFo2jPPl3"; + sout << "jbxOuvZTOSTWOwyiX5B8aLBfRXm7wXAwRoRy4WLr2dsiGFGSwP+pZlJLfNGY1vmbILPg/9iNNMdW"; + sout << "8MBRnhh4COlSw5/NxCl3FtRLfO3uLe3Z9+tBpX2qkUHcFyeulVFINjZaNB8gUBPr/Ub6/pk603yY"; + sout << "YTEqMSo239BJZvCsJmpgZaCA/weLepsvaEzpJvjI3Gvo1Jf/zdm+eY8VVwprkj4WKWKWjHq4miAf"; + sout << "y4TRxVYOG1lCH/cGcbepOcRMUG6uY0jTP7tqNfd4Re8IXTLCq1P4EOGuvkIe8hT/YdPP2B6W+Rd1"; + sout << "z37NgjkKJqiSZBxpnj4WZLmdJ2eOUUJOR4Mce1mHEKN6gOhzzFRK98ZB0WEI6jITJ4wNcdDpw7Eh"; + sout << "1Sg3JvoigX7Onq9eCnb+5uXFQbQrffuNiTCa++zx5Zm89m7ZuTJ5NFhssQ3v/mv8Wesl/9lebi05"; + sout << "KOVcyMAq9Fzx5Wtj9GMHrTNTyzS33CGVrx3Imh88CXZ79PBLBO9V2Lpjk/yIuCyP9pKezOZDzTED"; + sout << "rrMRA+AkfStHcPOL8WwGaAQgM+AlF6FVrKzs8TQW3hrBUztSeOFrxpDLod53zDcNaWRe/gpKtB/4"; + sout << "RKRjjso+LANmNLW/IepM3Jy3sepn4slG9YS/up5puIZY6zrsw8n7nnejBrUBSgZNUaeYLCLhWcWC"; + sout << "aa/M2BfeOT+X9PXJyzvQxDx/fw2duaVO0yRYb1ZNwtoOXWWXzBmoWKVrlXcegQ1wDzDkLW6lwnww"; + sout << "4wJ2tMR5bWMUJii/0Ep50BMgAG2TrSK5jrBpiAaGJaSOB7k39YsZ/3/8rozN1IKa2mrK6VqkvY2Y"; + sout << "AeOhdivfgdgccST6Ymbe81UvjsiVeCQx6tI2RcnR7NTxLC3guqwsSirHXDHTflWI8aP4bPb895Nq"; + sout << "7JdomRqug/eiaoSfv4AotVU89pWyzC6FAN8UdnEXjZvYNA0gf4plXcMvlLKmRfHj/vs5x3v878/m"; + sout << "elmdsnYX+sIHEc3hAZSISJoTZkGEptELFMW85Rj87/J7d7cC1q6vTMSNHgpysPX+2M9BNMEGDpJh"; + sout << "baYrMV6ASpbVikZm97PCzHNnwQAxsOPIFoA/ZtoyXleszvtD2jgNsKiMQo92dIFDAJU+4FIDWG06"; + sout << "7BoF0nZgUUrfwzC0kL6cN/ui9mBpnAu09t/9dValOs+/Prfp5NfdYesdUpJCqt+o3/VcgGCd5OfD"; + sout << "1n0lyUjd0g69rM73kEb7jOFV6LhN/5sfmzql9DIqUYtaxs5nJ94eFkp+lSLeXKJqBIrxEql0JL8H"; + sout << "ISH+HMQBpE/hWIkXJ/RryGxYv/SLm4mfKYNgp8i0KKzpp5fiK9ZlJVmyLAM0eRAMllUa4j7Q0zNl"; + sout << "p1EO0iv3pD9Lhbb60VlD7ObwhDApO6DpIy0mjj1G31DWJi4uEzYrUttVsSpj6a2+rrI+7nMA1wDO"; + sout << "OEGWM77EjjTElz2sU+1w9gTNn6j3uOFu5Y4+/ysXnAIVtg0zQluCSflOwSAFyUvpBNoKVnQD+mGH"; + sout << "870jvWJ/y5uoyjwaxl33+t79t78PC8ycBjvu5M+RFchnLW3QqUSP4l84gYj5oLsz38LzvrRU+lUk"; + sout << "mFDmrCoDe29pLZXYwnI7HHYPEtPBJX2fdBaudv244Z/a0NPgMSHS2uwdTkdSdWMK1lTBt1GB0+XE"; + sout << "HqdlnBbMPq9U4y9Uhh1hxdBZHvAcrhYjXDSvbDHDzZRE7acCvLmAOR6aRoGup/WIowNtU/wXfK5c"; + sout << "irNg4hSWLIjPNB54AkPsPc99IxGPH66PnE2sAWFALd7E3GnAo9N7t5WniKWWI3xtmQLbSOHwdI37"; + sout << "iAOHsNgX5TjGzUYwwnQndubqfV0NY4/V86wO1Rar1qBsIUSylkids760J0HqyUsWNXxsS6T9Hq8y"; + sout << "0szlFA/E8IBYgdH2LwkayEbDbKMsonJC+9NhEx8u3dj4ckmmKXOY8NT97XlqxomKyexbsKu9mb1G"; + sout << "NtxTI3Yt2mDKyt4hzMvhF/DgvKErMfTMTJ/h21uaZXrmwu9G+yQQJVaB/LEfrAenakK5mhmIR0Ne"; + sout << "+e9cQQSr+iD8oZMy/IkM5rd5FOk8FY9pZb91hXaWNIMDfwyfZZATaJruQfO5cgjMGXSjpq2gFHW2"; + sout << "9m1zGxyWrEkm2lV6FQlPLHBGybLKxq0VNUqIsdM8VX9Kv+6i8UgN9Ee+hyd4a9In41m0Z50jqgja"; + sout << "Jojh/Np7WSPqnPhNCv6/K0teVxq3bo3UfaXlrEZtkAi0hgy67XtZoBBQ2Se0ZWzOntgP3Yo3OI1S"; + sout << "02Ta7r55Ox/WYR5NoFoT5P4ihVcZVmD8DtuFLtjWeWKrHAWPpDDe0RYpN9Ma/DOOlx+vf1Ir7kzz"; + sout << "oXXOiVjM/hgKR9QjX1N05clcmkG8uSoxtaOU6Zov8BKIZAFfTMdQrhKBrW+Xgi5sKef2mBhPGwyx"; + sout << "pnudRmwMFKk1D8UXtuoEHYQX03cvYEDmWberg887C1ca4UdAg7b5/mysr2g5Md+vGHrqRtJE5Zhx"; + sout << "p2BFSzeFV83Zpe7PYf2uTLffLleJSKV5l6ohKDh4V2P4ztKvNwTKZYnYFdWlfMPUz3svuvihxG2h"; + sout << "OXNZJ9+Byx7v6RrU/42lLyY6p2078QYHif06BBp2267VkKcRN4pP9LlS2JECoex4vg43X/dE/48f"; + sout << "DbRfW8KeR5kSOnH9dDWcwdoco8MBwOV47KIkte6i3vj0BpAQZD+GFuR5EfIBxXuklMYFr7KRS9xK"; + sout << "oHMGA00uxrV5VWrzCm5lV4l4oMQcv9/hqFTLKo7nlQP5yu/TyoF+OSXP/qKX7N+CASLfNtwL1Ux5"; + sout << "fweiUkaKnLZh8f+bxh4x0/UO3H0LWyq07/1evSYUBQhHkzSYPhkTq2msJ+eFBc0+gpbOWntK37JY"; + sout << "udFJvL50jNoZf9clWcqzxnK5M/rqMjVCi1zzboiC1vyxWPhR8QMvEMRZ8XpVW/cAdLz5R+M3DGms"; + sout << "KCJtIpxrduXr5+Bq09jn/1oi4qro2/ikBRTVTLj4js+yaI2t0uE6XQX1PO2JwSHW3V9krhJ/7JJh"; + sout << "sLH3xidX1mf+O/hgu6NU1p4yrsGjz6qTYOQSjU7cVsCMRxW8y9vCWh2PcumONvuWeApgNKkkQj5d"; + sout << "c+8ftbo70YFtVWhKjTG4BHpNh6fXDC+0ZmkhTFEyHAwx0hF5k7217oTa7aKEQp2qGEROYxe/wBMO"; + sout << "Kq/lLPLQpVEldVSoOzmFNxGTkOo0j/bjAtwqwEd6mA87f8PvCiAWLTRpPFla4UbU8X0t82Ur943f"; + sout << "m+hm5CdEBtyV9P1uxMO6ez1SI5YezMPfMCkhpFqozJsPrjkPMIUSDc+WtF8frIfNsWWmjukFwFyB"; + sout << "3UU7r1AHWLJlEo4g9XPsMks3LkFeQ30cwm8Lnlj4wAistdLd3HNBydGSxO48Uaya2M6dfm3Xeu8Z"; + sout << "77+7God1knNpAnL7GFhTTTV7XcFxEr76Mfdt+KJTXHIf1M8ofe67hLrsT9FunpPmopFrNp69v2TD"; + sout << "H0/SQbseKjykzwvJkK46UeJFxYabRHSgL67prx6jwl2Cp9JnNIco/hWNPGbFVaEcMbZx5lkinzvT"; + sout << "KgwrJbVJ3oKnY2EJAnNDC2f1F8UpC1qQyEvhkA7g3yaJUKMF74TXqjz6BOhGmzn2c2GQBqzZ6d0/"; + sout << "Ko5f97NNO360xzWgTIiLbg1sPrA8OllufpaQRxyqFzTlU/kSU2BDLWXM2Iy6tGUAWCUGIkMgOUia"; + sout << "cXdp07NMyxlTzWBux4nmXRndTn9y53qZXjkeF87G+ZG2Q4lUmp2hunx+dyYVdOFWCePCYJ8TcORQ"; + sout << "kPI7Jcacu7QjoOz652vsf+PQPsarO1KlyKqX4rFudTs79TVOYdHsf53Y7RIMqp1NAFDjAsyFPdq3"; + sout << "YpOr3UrJu2RZm+eumrTK/JMHSRcbJuAyHJsBITQgvmisRy1wlcDqzw3k8OcfpkjFsJNPNaQ1kL7t"; + sout << "HiNZMtrv3t0ER3+NZL6PnWWp/tN+ASeE5p75S2UyKtgLONa6xg+gNfU1PzxuqgX6CU1me3GLx/xs"; + sout << "9pO0eUq57cy+gFcDzvTPw51WUrsaQhD+ayfUyRw9vpBiw34NQY1vMNReMgl6EbvRxFULAkx/oDFy"; + sout << "wnPkvTkMobxW3JnVGL2mb1iHUKmqQzBsMk6zQKiqtasPGrzg9Reb2DYSVdLpZWCha1BI8ySpOQy/"; + sout << "ndBNtz6DKHgIQ0lp2pzmJbU7JgqZD3UmfnKzHcjskto87KRz15aWYTnIFg7INMFMNUXqUYizX4V/"; + sout << "I6UpRr7GvKsITthFaIdwb1AGzKXRRmv7NrsH0x8ip+p6mDp/Sb9dUG0N4ODEYfLlGzA1U3KbZAWg"; + sout << "tfH2AF5+vVKtgUuuMH7XPQ9FBT8HwvqygWub9K6kLC1qwH9pK8YObWtrfOdTz0yfOwhaMo/kZCzS"; + sout << "+CWcit3PkXtOuNXlq3oO7fYcRDPdCOFbUs/KV522grIRqZ7mgludKZqP9b702FVEGFT+7TbxjS44"; + sout << "sx/F5YjT/XEfCP4ivZNjSYdEsQvSWfvnb5IBkjJafE8xnihUwcUNYYY9gUvKqZNm5XjKiTFLhPJR"; + sout << "SlaNrK/pzAMEDwVzHWPeF0ZG9y22WsIFfsYgDqJwjmXzd7yLkKBBe+NuRPlvhh/Adtzr+P/4NMoz"; + sout << "f7dcJzr/+VJMCqoq3tiRo6j5nOm3sylEvA7/HTethnRyHx494FMdSwJ1t4AXxrH1dSFNGOXWJ8Tn"; + sout << "WMi25e99RMpe+Fy3fBygtCiXgQ5/sozmxR6LEjv80uoqL0sopOWKSKe+aGZVzzPHhkfJV/HY9N9O"; + sout << "vJuNA1Cwp8X9OeFYzusBXdrWMfzTxfeYg6Qj7coKEToRF0SRuskob0+oWzufMPJVCCRRevu2BodH"; + sout << "QWE5HiZMuR4SYbEgwQBnH7RbF9vW/DB57H3HRvRc+NBpbYZcWipQyi3cy1RwmSNOtexX6XdqoXg/"; + sout << "iIBLmFENMiByy8AljjARQPUiH6OBU7xAY4zgiBrM0JCtsNlhnDuF1n1udIMnjmUcjAlJ4OkG9Q0L"; + sout << "w4RvKt6/d7xijIvnX+i0+jH898Jhe+fX+vu9prCfPxnhDPCeDOre7g5xcTAxpX9PyoFww93kdkUq"; + sout << "FAt7//v+bzkV8FvhUzso67ANCtS9w+ZmeEoU/ePml8oW2xGaOCx51hzfSzXhcvi6DQP70QhWGFKE"; + sout << "lqjV3s5Ra5WC4P0Ipqq6PKuho8bJn6hz4nlhlLaSF5GfDpqQKzHxvxvDw0//PJQGNv7LoCu5xtmf"; + sout << "u/CnU9wUfWie3YRDs6023yhdgyuY2nbSWSadumgv/zBNJH8zdP156zru/LLvr05hbn5QMlvdg3H6"; + sout << "jS5SxCi/FXaW5sOvngedtvRwZuN20nHHZQIFLr0JebQGxOS3Pceh6CeSmslzrvflpnrFJJk7wKDw"; + sout << "ponOE1ChzU50pPPl1Cr6vtz3mQ+cSSUfy/yj3jiIjVzs1ZRa73+iC3ibUKWzUcuZFeJtZ63UFYTL"; + sout << "I2oz6V2ylX4Swqr6INVp1abOAWFWba12EP97P36KMGN+Srm5E/BoowWKhpy9uR+AqQVP+NP5BW/m"; + sout << "02UQyGCthZMnbw7lSD+0Ihu4E/j0Zfh6K2vBz72vxGa9BW3aDUgNvLRyU4CyLf/X8Q5+73iT6Cwl"; + sout << "dlFumHdywzarpmRS01qfzCT7WN4Pd5HKKNvq9KaaUOUQqmkXrrGaaIoNTAKHh04hhi6BF5rfueW2"; + sout << "rOFlukiABYz2pLHwd8JCTUUH+l5ly6G5NhNWGdM9AH2tPxUpmqW2D3hmaF7k5I+ehdNQKHxnzXUB"; + sout << "Wv7O540zfZEVkF8cGAuaF9rvd4GTvTSU/0hO9JerDpYPDcwu7V7JT0lDjjlVyf8Rzr/5QtLAvsmR"; + sout << "ptFG+VFBQFzS1oSd0yQ5sSLnJbZcABQq5zK9kYMlSVg+DERu0yqSaxk7f5Kjm+KHLjoSR+9th9lB"; + sout << "AydFXCQdhbntrBkbSjG5t71xVSyhxcazUYoeulumJbiDMo2Nz/GBvviJL3ZAAyq2D0PFXJsVwk+o"; + sout << "0sSWItsGGLiiw55G5NN5KmQ/hxhbjxCMjzQ5ALsREiye7MPDdg9GgwA7xa7NdVdjprN8RxNlCS9a"; + sout << "bQRBO+z+P+2MoN7tk6JIWn2KU7ex6R6FKjwHz+kU4lJRaJF8LzpujYtTw78zAgqS2uTPDGpHdfJ9"; + sout << "uxW9x2IvyYg63TG7xBKS4+iFmSJGHsR5naozx3D72vCZ3jTe8D/fv8FlHOoPyYO4gOZgyC/cOIdC"; + sout << "jM/EFJKHkL99pSetGn8KALo7QbrosHpZER2s1nhIXc/kfG2rq+scF0ECChECnN9sVYEzerYuNXzk"; + sout << "nsAPu/3W0xYMg+TE0xQTdO8OTHsfnbGAm5ELN0wxTVfdXIE9QpYoSSRGtHphyFoKcOgRkyFHXMmb"; + sout << "zPuwR4bRhvT/HiW/bPnNrOBL2qpeoMKmhRyhU/8FpgoYANV+tuh2DCWGSu+b/xgzGO9kuqoekBaf"; + sout << "PGQjh+kV5tWT5u+NTO2SxkBaGLZkBK8j0a/h7CtCmwpK+7Hq6WyiFUgAUenY8FiZAgd/4lPXvcJf"; + sout << "JV+e2P83P/iAnnfFH9rt48Bq54rTjuhgDJ1FGHmW6uzHqX1XTINTidVSukSpy8+hpZvAiiNrQfTc"; + sout << "clOqzIsuuJVBivD5t9/BwKOp3dee9ZyDc9qTH9fjqkq8dKSjZjwTil0meMxI3EEhCOssrXql/+jH"; + sout << "JjPHBao4OnlnblFyPUFDp0MyFP3OE3o7Wa/RYrhLuq+edYM9aKgdWzvSbJAn8/LGHDt3/iH5joVH"; + sout << "UrMD9vp93N8RBsdbMY6hDoHEmHWXJvQAyVQpS/urgZbjbdtKT+vQKvwrRlUY4osNJ5fVGVDfY1qx"; + sout << "+3u5BWq4Fd2UwsKRM8Ut6m4dm0yog+pr3ZVcPfgGsyXG5HZdyDpdg0AJSb/wOAJiNZQGtvZQ87jJ"; + sout << "2Vk3fIVSb0gk2p4vINajCZksZCWordROndVW9kKviwD4QBpuvxTCgW1ww38M+0osHHzptdC/h98m"; + sout << "tFXeW1AOwfG41mRwckQs0MWVn5Be4cUFiZQ+qkgH99p+3ZoNiGdk2CCS5ABvejuAiJ07wbiAzi5S"; + sout << "2kdjxM6GgpBjmT3gOz3jvc4dgcXA6RZC+sbh9k3U7LakNE4PfZ/WME3qbN1nuj+9/xhHoa5/4gVF"; + sout << "qM571ayMhfBoKsoChNxVxiQDwMnMkIsbD+1xUEr/9OzIf1fnYKU26pX94koCtsdxN5l0HH5Url3J"; + sout << "+cSXI/XX9lOo/uUM5VW4eZ9aB9zfOx0/fuZTqcVxqkzUz8SH1gfOS2QzGBkoSDATOxVU++tTGsVS"; + sout << "jvCtloGsRzz7cSOdVEyGH7Pf1PweTT5MIyrqdmVDRyYiYcHNhxJvTS7iEcL6wdwXS0iX8NdIp9uq"; + sout << "NCrhqBcAlQZ8vKgXLe+FJ/oT0lwZQaF46SFEWjbvP8fm0xMVGm+pkaniNAVn2m7D1FADdIfcNTZt"; + sout << "QjWrboY5ZfHbOwOl06gTV9JTQGRjFDUjy98D4st3sQUZBq2JaFJ32Qg+fy7knNSrpu7sHWloP+vT"; + sout << "PvoJGbK/hFP0Zn5ZwN5V5UEbE2JzoJE91+VHhUwbh9TzNBxwnDroOxwpzHvJnCxnb8SPDKKrdT21"; + sout << "n+U+kVkf+0MiLo/GEQMa/365TlPqVNJDKzbtOwFzAV4/001Z8oHQwWqOeMZtOGkgbLx+mSs/sfCC"; + sout << "fMbMfcXXyjRgPYsG0JiQHXnB1AlQ7jdTslLgpTLra1uWB7PdFydbSd9xBrGJliq+s1VLML5vtHOx"; + sout << "zIkRy3z/6F985ZiZN991fTL8O+dnIr7/bhRAfnGz9zjdjVhfMCSZ1qbNg+pDIFre1eRezr7Vi/sf"; + sout << "/EAvYapgYXBpO2YHMivon5jd8BHIiDuMxvO3k8gtoc3N7cS4gRFbBX3KF4OY7st3c2TVBeXFGfQ4"; + sout << "pdfpb7uwAfOLY02qrouv7egOIQkfvVGGtFHHWyF7Ua+rBmdUgSN7qyAoX5ImUCkJBfYWKpDyV+qE"; + sout << "sl2Zv6TAlmoV1Ejb5zpSRGKTOZQeOL32IC4ObnmNN8Tt6PF2jWHKfF+E2EvDUnkvuryJgdaat1/e"; + sout << "ROv+l2tnyxLcsudPd5IC6PX1PtyQ/VRXVjPdgALubsid997mfwVUEKqN6lWCnyxWevrjqybNaNGy"; + sout << "TKR9jTz6c1lS/pvGyI6/tRau4N7yCm17IOPL4IS5LQxIrmZM5u4CJbXb66Jc+ugcFY1tLbh/2cOy"; + sout << "+XriVPTbEGDQ4A5uj+7xOo5ZRdoTrdVVPtDk+dlVfKzXzxpb25S81YFqhjO2mhIAQUVimhDvfzJI"; + sout << "3STEBOwpTG27aw//24pQ/l2/HG5fmjCFKnU27+lJU4OZWLVC4xyNFcC45PPCOcMbbGOv89uwVAzB"; + sout << "grsyHpXH8piJ+i48nWyjJYRcMY7QruLJj0XwI/zyfostfynEtzCQ4z7izYem68epGW8hJIno9YC9"; + sout << "ILlnQ3D1Q6ZuLS+DZbbYX5KtL51EGpOXIcURfvVEgUPpJDRszh3+2ftzLqq2kBq7/fp1qBVlX7V7"; + sout << "FpB+Atuo0nScYLh2qkckGVz6FD+sd55R1X87c2nyoq7Mneo23ed3+fOWL08vOS8k/TRYm8sah4Mw"; + sout << "jBMphehu83zxUzzmCxgyw/YqBpNeIvKOl0aG0QKJl+d/B6ZUujG2c/d21wxvGfhUdmPq0fIDY4f6"; + sout << "C+j6/lZmgeQKQtYoAN6ElDRpTXd1A+3QqqcI4B6stZQR4JS+yXVg9cbekGvrt4QQs4JX28naqY1a"; + sout << "2FUNMAiLZ2HiAOR65wSUstVLm7BfKhWmjwtsaC/0EW20Dk2UgnVOrBLnBQ9qAN4b7NmU0WonmxEY"; + sout << "ojTy/G3EUpNlGyA8/vW01VWvnHXY6vXYFIqbr5vVrcva9uTYUR55ZqOoEwieXtBtjMbrfsnHYv8P"; + sout << "tR/v9PeZbgs9t/KMYR/uMBEJ9AZiee7WVJ4h0btOaXK/46+XzeUtI5WcyLNL//5ryxuloF8d4dv+"; + sout << "Cc9x78QO31N65ELT/k1U0egzJxgrw/3Hki+p9JU8IkTsW0KpGjEGZBOV2pgikBdJvtj6FAa3wRUi"; + sout << "PJoSIfsvxzrV2luFEsFgl5yMEW90eOsiLKXcWDhqEEbulmLx2ij8UrCVCY56umeqc+LyPlDfKgrJ"; + sout << "EZjImfmAt2Ygq1eC4diAMOmE1UxpFq6naGmN7qqepdQluKbFsqDlYTpOaqfems/zLJqkcaw70LGr"; + sout << "VMp8Pqv5ii/0GDRXDBhkijv5WzyxkR3EypXXxNR/+1AiALmzJmi+es37MAcJrzTJXLa9VBqivhbR"; + sout << "8dN5+0P+2LwViHb55+k9sXzIjOGnOk0MEWlYHOGjMRTEGFRTN7A9nCyYXy9GNfPK2JcHVf87/hHY"; + sout << "5K2i4i7tUrCe2Csag8f30XMI6neY3ftCMT4Ig5IaZ1sFOq/T7tq9itnCcp4mwlEYPMyMIkG/F/LU"; + sout << "hee2A7h1pyXJfcGezwMznUz7W2ul6nPiyNiukKwkyscniZUpLlaY4QEqucllRuJ/68AaZ4b/Oej2"; + sout << "jyS1Ic2KncAmPZ+1Bg0neMcNSWquSJgYzaDClIpXV8f3PkHrr2uIsC9W3eNshopJHpqLMzuaQSHY"; + sout << "2vAiF94VOhVSyUWpQ/1azlaZt+to+uWlh913iWJJ1W6/ny5AwzWh7M0ikd65/vX/nmzcWDNIe4FB"; + sout << "dXtUgIbWZGSvTnb1Q0ff0C4F+XGSG/27sLrdSbmNv7IjZcqQJkNsqlIjOMUXKRPmfMmOD68qQomE"; + sout << "m02IDYew+Ah0vholXhlgVIa1V5eHItbFm0krwNfQsxBESR4dEQMJbkQAE17x52EneJzQRaNjkDdB"; + sout << "dCl8MkWvCwi4iNl6Mn5jQsO+3K2cC6fYHDRReXmp2OmWXlR600U+oSNyuQjNiUM6CD+q7IV0SWS9"; + sout << "txg/d76GC3ELUb9MtBhtm0dUBT4hCjyuLqMAVhzzPLfuR8ko2hc9l8VziIHDGnusZKle7mQWr84t"; + sout << "L/ERhHo94fQVR+BuogcMUEcFq0cAqYHWzwI41KVB/N77Tud7X5KuhGc5ulVJaPukTGGvqIrBDbBV"; + sout << "HXFCuuptn8cr123muSyZ9qBcXcSs7DRnZ1mFO1LoxW7j+SsuyE/c0bUN7ndjGTUC8OiS+NC8Bjdu"; + sout << "u6oxAUIqlNYIoY7EAPF0WphyKvId9I8knRWQUQhD4L2kk8yNvtYvjdv4udkJAUQydaZut/lnijSg"; + sout << "Ph21q7/c5IlPo22ZIiPI+a3O3wkV07i6vAPoC3zYn5Vu6bFRHSSyolNsOrEvCkWwQ0oTCcm0Mp2S"; + sout << "eiwjZU6/k1AjV5sbx86VksQ/AvKxU7RyMbeuElprxKJIMLdkghibR3x4clnKG1uRzM5c6eINHY97"; + sout << "vYFAN7XsTbpmtA5GgHUtBWVJl6OYhBplfM/ZECAqOp0ZbBpEpfEntR/Go9qLPlXT7uFTcqxKePSv"; + sout << "pYwsS8xLTovG/ELnF7/7NxL1ZPloXw7JXSPoSWEdbQ77qG6nOOYyeYf8i8MCkSC8tBM6rcORlcfl"; + sout << "GTUIeRALSs6B63QHlms4Ev//nz2bSKv7BHMmrA1+2EI23xeZgiuPoKBfdCe/YxF5kXZ0PSlxsj2N"; + sout << "7OCHsUTBiQrVCUoAO5hXZWAlSz93FgBGkc8bh/d6dLW98dXvWPxLdq6vx7URXleswpb3FPFNm8R5"; + sout << "Q6WOEu2+Y7ilKgIq4yv2259Heqthy3z/t3QuTfa+lmhTEpLDyFFZ1o6ERsqYmMrkPM5DoPn5K1BS"; + sout << "8svZoTTyRFw6yxjoVC3rk3ICe904M/cuuPxOaV6D3/7/k2/9Qn2+dDWV/Kv/HnH9kR5L/YjZ924u"; + sout << "yI4vfT98p5Jbxd+m8wguelLfXVTz2dNjBQ9fxfZbVWMJj2Eo1BjHIAXpFjE7g9fog+NbIqnlIb3J"; + sout << "rBh8jEGWy2lNrNZXVzYly0z4d7qmrcbGI5FhlWOkKBFAdp39ju66kI39BZiM/RoTydqIQ2iA/eZX"; + sout << "wqgzqMPQ/LrpImeroHcY3tMd334UHnzDDq2rJcqreapzK+YM/saPsBmiYWs4joiDCjTJcGuHOroE"; + sout << "PIDb3yGry6Hov+GNqI/NpjFMQIrEJlR25yBiunN3t4sAcHQJzT8OHxrDVV2BWGRbagCMqWqeN7KY"; + sout << "ePwovpts9a5QbgncQzbGT8X9p5WQ/uY9miOJACWeHr/kZXjQt2ngGBQdOm9i3RfYW8NOjBDGtBpJ"; + sout << "Ys31Wj9mwuC7+o+4DLfqtfXrjW/Or4zJca7EkuKJbRjlicYSyMNSrQPjGxq8qYkHds1dEQM0JydM"; + sout << "NjFG5vLb+nqNRrABnSVkuEl2W83udqnjDE+mJVEnfwHnPxF2k1Kr2nmxqO54ZUranYnyggZxrvuu"; + sout << "HGkg62IpJXv6u1R+7rF8XRErh36r2lkmps1cCkmW7PrqaB+z/wExOH7kDZG4cMJitrTDywtYsDVu"; + sout << "xltXSqRBUTH/nxrq/9EiQOM7j/ITCoiZJRfA+G2hW4srArUW8EXEy1dKfOKrooxrZnGeEh2E5sHi"; + sout << "Dib1M9iQlTCc+BZFPqvOv8lM8uIe4ylVvh5DtyzcKgg1ta5hj4nTKL11vLDBXyqBCVz288zqy+Ra"; + sout << "seXn0vPNCk9fsxOtFyjUH9UiE7sYcVB/SWYa662DLAvV832uOjFVlOEESlVhhYWb4nCEDZVnOnPH"; + sout << "1ilf7LOgVi+qHmMm3FS2F55YclALS9CpVmVbt4ADpgKaE8V0TlI84mglvHOJfqE5duRDMfdQkjBV"; + sout << "o4vxT6Y2xv49ARNgDI2J8yt0bxnQ257afzSUxvvlROxBrYYNNkQ5lPivD+zeWH5j/ojyxqF5iSNq"; + sout << "Arp6ClK878ubKMvRPjdmrifFhSfs9vhfL3Ox0YFPBq3zKLPbTJEalN3Eeiuzd0dDSnUatLnV38V+"; + sout << "wyxNsDDywi83kR+dulf2SytyhnCCKOHHstUa3M5klOVLCsrj5lL3CarIo89fqpqqY0H3gzosW+jw"; + sout << "e+Qd/PTgX7FQd08oasrwptSX41xYXtTH/K/ndPg6i/YzmdIBvGuV+A4MltyzGXXRYxulYZv+OLOE"; + sout << "9AIl10pH5amcenHV9Q97qZ2NAtGeBGcWQNpQRl7titOfka4CHpNBXjg1xm/go6SHCy2Lb220mw13"; + sout << "x2kdIEa7b+bRZrXwYJ7CNHE/OuKclyzDvZHfwOrqGl2mfYnG78oi+Gr482vOBZBr5sYf2GxUT4lr"; + sout << "7YPENdvyYarw/liMmmWE7Td9rTkHXTt3KvCaeJRddK63LIjQrOJNCalDytkGT2LfsY9hu5r6a7Of"; + sout << "yspHZhOqBBOyjGXe4/D5nQqwpFfccywWBEtnC/+DINOvqpfqx9CxFAZzdjJJRlIPtK6erv1dSTEn"; + sout << "AodulJz+SoMYUrOSyKPqQlVdN0ycHvLFTwspnZv/6NqNzV8iqYuIb/iylAww3BkmtbmHJBKpeRKc"; + sout << "DVNWSOiFRb83QaDQLh+nfH3J0p8OcEQgN0XSQIamAOGuRPcTILCNsn9gg2qEwpgpo/1gL50XfpvH"; + sout << "2SNQBJqD6DaZ1VwqzEyQXAw26niGJ83+UEB915mZR10FYo8mM3Et5WYyg3/0Y3u3+B0GHtF1lImL"; + sout << "wrpLFhUdRBijcxzCYwuPwrbC6tIHnfV06rwehfR5Anq0cwqip1O+uw2Af3IfjWh0wImoQx3Cm9k5"; + sout << "Z3iRoZ0+MttTVyFuWjHGrwVkSItkwzvzRL/UdeyTkRqSQvDM9zZiwNygcp8FrpdEBuxMmfC4zAyW"; + sout << "y0yNR75+2NpI4MPxmEh6jvDJrsKQ+hxl7teqfPRRi8leqIvz305ZBVcrP5bHMcuj7Iun9znHdqnR"; + sout << "ngddgfi7+SRsy5pQEtmhMoTI5Q3q7OB4I8IrttJPJRu5Y2HCUeCfmTlTVQD0iu/eXYeAbt9GXcb2"; + sout << "LptTexWOkchMYcVh5qFWVmnPWjRa3X1nl3ildvrkqjTjA96J4G9sd0TwldCyZ/GglrM+SP55tvLv"; + sout << "4MOevxl2gpNHipkZh19MiJpBArgBT1uyScDXrty46tWdIC3eKAmVfZprFbjswk6L9f+DmzDyeDlD"; + sout << "v3mfZy7c0rZt/AIrReTIJYb/aWRCno4vyWww6AbvH1AYm0nWD3/jcZazTEsLGk1LkZwQ6hpXOHa7"; + sout << "47lDu5kdU6hprzMEhGuNLdYDKf+Bk8I7sYp7fBn0FvpnD+w11jhINFJZHL+4MG53CzMqo0iiAbX0"; + sout << "9ppZbFaOVCrNfqJaRWKYQnkcyCFFOsF8MCKeujNtUynXLAz258d669Fr8c4jl/nzSdi6Y5SMqdJw"; + sout << "tIBdmkAHxdv6sCRys4gyeyFjVHDeUNrszqG34rEV/905CJn8JNTOtUB/kk+WJaxrKyLkGUMHQHWO"; + sout << "N8whx74uMJhsg4W+Qpogy9uopqF7Unk/kCbuBrDzXR9+XjYwhFE6uteAbUJ7RgrWCP8929150lOz"; + sout << "dk86E6B9KZjyIke0jrxH9AwH1hiOLkF8kIWVXdjTPgqF4Q5BjjIdSWhXWS55t2rpD7Hxg3cvOILF"; + sout << "c6IoU/uRuxpoMMnSuzq+weYnEjTBW6SlnjXjkIuZaK1+2U9rWDaqEgTn2ip8dgS+YA7vA7+dh4we"; + sout << "vR308ToEwblhv8vSWJICn4iyfaUDcFCXVegEzl4oJcLY0u4NAAxEVj8O4YN9XQcA51dADKZD3+70"; + sout << "yfVPjt4xF1StFZAt8zZuiuh17/SxoyRZXSOZ4lrX1ihQI+mVIXAblu+LpJM+ugowwdnnBS/poJwJ"; + sout << "b2W5jv3q+IFy6VJjYkvzUpckAnHGlQyDhYpyYARW6S8BcAox4++WgK6K7iYjgJVUxTKdMf3LmbWm"; + sout << "UhoR8Xy3P6NBq7lzZRddjFQLZ1wtKBgvqjSdtXDFg8C4oUG3s1UKF5rLQ8ZC4EXn9kHCJI6U4IdT"; + sout << "nVsNrF8RTfItEy3ji64JnvVGmFDPb3V0CPSlSrGwmiF9TeN4iT1OOWCXNfJyLL947Mpd2hOx1jpa"; + sout << "52OVo+GofOhImcAXIryA9BoupDVR//+7iFB8qffCuz2ZyhAeKTvVg5/TEXsItr/9Mc4kbKUo31E7"; + sout << "7cd47sykXhs+vdWU1qVpVkjDOmsHdnXLxkrL23+UUu0gUoE/zdEayg+yJ5nckV0EULpnOUdI6kD2"; + sout << "r7PFc0abFaw0FrIV61f0AHLAYuFGL4/pZNXoHbMPcTjIOr5MvuneN2n32lqhHQyoZr7er4gm4n6s"; + sout << "fPmgp/ezVB0uhb5Bovr7uKBjqCSSxwpyWLIioAUIy+J/bGQMqvC447PgH3Kkvkhmp33i0pEp1utN"; + sout << "hXTHtnu2vdNNwZNrdANwSs8/hzRjI4qdcFVjCpnRyrcIATvkalWB6345vTSyg77XPO0m7q3M41sN"; + sout << "me9p4AeLxlgRDIfkJrq27zPlYT+w+o27s3jXBS2MvMdSwl4CW6UyHCqrouaXy51PrWtFLY1LZjx9"; + sout << "3R0xGbp9XZC9WRqiL91VMa9l/sTS92bw7urPeRj1FHh5HNyDSpaLEDHQiihXJaz4xxGbdEveyHW7"; + sout << "4yhgD4sR0slooQhw9O6IaL7gjw0auFf+ty8jedYY/LfLHhbRGYvmdL1NiskwM6mdgPs/w1q++HFt"; + sout << "u613gA7BRTxRw9MIiriCtGyk4U8uO6PnGGqRoXNpnJXE9fgl9Qdoy6N+uw2JTIMcE6btB1lbjdQD"; + sout << "nlUalgZMbXzm07z+MNNNGSc5Lc3ylAJ3v19GL6qPYYvDtXcXC8lPRSmeZu3k4tDZ7gVlokrSewYR"; + sout << "iQznd/Cq4w/MBXFRy5nbtKMeAfQOEU0KTgcOPVYfd4sUJiQixYOyDpRjWHLuEfCXoXEbLIBLeJ8j"; + sout << "6YHwyOMMzxBepwBSxTNRlNnY5GF22kwhk/KUIB12mmkegsiTSfA4Kz3q05i7KF93WZB1gDRSkpJ5"; + sout << "DgtN6fALotQJlAdI7CgKxiqkTD6YD+wIA7g/bu7EwH5anE+r9mBEimDxrDNyhCSVwfYPnFmi/HLk"; + sout << "CtH5msyO61QCth9RCpOKMsnqHH2oqh6XrH+D6FjhmFxqlzp5PJ3SEybtCY25xxxnmwyJh0uDcAzr"; + sout << "a1EupoKzlsY8U8kFdcmhOtPOnSgQRqUoLQxZ0RUixPiKFI3ikbai5/ZNNlt7Cl6XcOWxnQn1XLGh"; + sout << "7KewlqwiweypCVOk1D6NHgTdfCGzW0NWYrs8cqzc47cWDSMQbn8Zzr9LMHdjq1t5VoT+0pyasBhf"; + sout << "Wxih23uQlVfaNO5SlxpMPuW3TFpASdhxiXVmJbIl/j02lo5w2MLYUUouGKT/WeTr2f2R1kNZkCcN"; + sout << "44DdifeZYVqoUrAV3etS4Py0H6OHDtAnFmU7983u/bZ7bpkfvFl97zvms5D5JWuFQ8HZTZikTzIM"; + sout << "uxDpDcxkFz1U24O/WnovgMCipjI6uV39rjWr3qXb3qzc9D7Fp+jh3eg3F9C/SeXY9Ru077P9mAI2"; + sout << "oRgh+2hJ86+1/YdCetlEVgmIVVVGVbcZWLNoHISNo2lnH3fSDcfWTRekNBlvUgiAE49tH+av0SN7"; + sout << "cxjm3vWlWv9Mdt1BvZORAk2sYugBnDJZbB+F53rKioLbUNdlCG2HcjDboLrRnY//cxOlOARg9UYB"; + sout << "N4+HueNJApuqkGhItKHRMKXWXWXKYkyN4PdPicVFpdlxwoeXNWMtNvUbOv9oAox2HwWqFERO9JGa"; + sout << "wSXE2oYyE0+BUUUPq4Bf0HobSR9yOVj2YmgvvEv61uS/7rDiSHElMnyrGnOpy1Jk9Qc7BwDJ4iJt"; + sout << "LH0qQR8w6qXJcDajjLT2dYkzFYj8Hl4iglGCRuFTKjhQLp/E1nium2AS5aSlvk2GwX4SHEnk6nAP"; + sout << "zYbc2wAbehdVW48RzUWWSW62Zckj2YpDSMoHlPJhysfxN5NhGzoBag5D28vAsbmuOYKw+aDDg7/Q"; + sout << "kmGxNMjXuZ8z2dUHAXA04UxRx7P4kIDYHrOPyd1s+V/Nhd2WSodXW+X0N1pC5Wv8EYYeNaUke7h/"; + sout << "8w83XCJ9cHSUO/PTRY8VLd1GX0UFkPb4m3Vgbv2ARqK+ZoZ7t09USLMz9JG9MeD3+J0vL9jShQW3"; + sout << "+4j4G9ZuiIW3VlOu5Qm7jdcLEMzd54mqjC6gYRqTYNP+1P3LtfVY7nhRdxR6S8Zt16B1z+hHtLsY"; + sout << "APWNCR21V6k3R3MjFIit1mBJZDLFFM2OEAYZW+SQ+TkcSueb5xCyFK2gwynzU6l5/CSbaudjWjsE"; + sout << "hpcqZTmHsv+ertehndIaMes/Ihe9ZWf8KJU8sYDn4gV/A/ZXesRLMDtGsTKTywN99GdKi1yX1vpv"; + sout << "ugSgxjMoplA6lZHmz3+sEUepfpLVQxpBQ7OvrZrhBuf8GxgKRC2C9LbdKdq8+qM/qjuNUf4CKz13"; + sout << "OZfp6K89LPJx4Qua2uEKtoa1i4LN3Yt5urZz6CZmlyZR8/jo2e/PBqsKc5zP7SzJdQ3fELngYdUm"; + sout << "HJC7uE6ElkbQWqTUg1Tjm8W2kGmbahkM6eVYps1yVrWqWQvFG/m3IUUrIhtvK5JwzlRfH13GIevg"; + sout << "cyQl2750lWOaewIa3Bni58GNVuR3icSmcf6hZ2PbeEdnnxMW8oHUp/cWMVHmZdG+AA=="; + + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + +} + + + diff --git a/ml/dlib/dlib/test/fhog.cpp b/ml/dlib/dlib/test/fhog.cpp new file mode 100644 index 000000000..88377bcc1 --- /dev/null +++ b/ml/dlib/dlib/test/fhog.cpp @@ -0,0 +1,684 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.fhog"); + + + class fhog_tester : public tester + { + public: + fhog_tester ( + ) : + tester ( + "test_fhog", // the command line argument name for this test + "Run tests on the fhog functions.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + template + void test_fhog_interlaced( + const image_type& img, + const int sbin, + const array2d >& ref_hog + ) + { + array2d > hog; + extract_fhog_features(img, hog, sbin); + + DLIB_TEST(hog.nr() == ref_hog.nr()); + DLIB_TEST(hog.nc() == ref_hog.nc()); + for (long r = 0; r < hog.nr(); ++r) + { + for (long c = 0; c < hog.nc(); ++c) + { + DLIB_TEST_MSG(max(abs(hog[r][c] - ref_hog[r][c])) < 1e-6, max(abs(hog[r][c] - ref_hog[r][c]))); + } + } + } + + template + void test_fhog_planar( + const image_type& img, + const int sbin, + const array2d >& ref_hog + ) + { + dlib::array > hog; + extract_fhog_features(img, hog, sbin); + DLIB_TEST(hog.size() == 31); + DLIB_TEST_MSG(hog[0].nr() == max(static_cast(img.nr()/(double)sbin+0.5)-2,0), + hog[0].nr() << " " << max(static_cast(img.nr()/(double)sbin+0.5)-2,0)); + DLIB_TEST(hog[0].nc() == max(static_cast(img.nc()/(double)sbin+0.5)-2,0)); + + DLIB_TEST(hog.size() == 31); + for (long o = 0; o < (long)hog.size(); ++o) + { + DLIB_TEST(hog[o].nr() == ref_hog.nr()); + DLIB_TEST(hog[o].nc() == ref_hog.nc()); + for (long r = 0; r < hog[o].nr(); ++r) + { + for (long c = 0; c < hog[o].nc(); ++c) + { + DLIB_TEST_MSG(std::abs(hog[o][r][c] - ref_hog[r][c](o)) < 1e-6, std::abs(hog[o][r][c] - ref_hog[r][c](o))); + } + } + } + } + + void test_on_small() + { + print_spinner(); + array2d img; + dlib::array > hog; + + // do this just to make sure it doesn't crash on small images + for (int i = 0; i < 10; ++i) + { + img.set_size(i,i); + assign_all_pixels(img, i); + extract_fhog_features(img, hog); + + DLIB_TEST(hog.size() == 31); + DLIB_TEST(hog[0].nr() == max(static_cast(img.nr()/8.0+0.5)-2,0)); + DLIB_TEST(hog[0].nc() == max(static_cast(img.nc()/8.0+0.5)-2,0)); + } + for (int i = 1; i < 10; ++i) + { + img.set_size(i,i+1); + assign_all_pixels(img, i); + extract_fhog_features(img, hog); + DLIB_TEST(hog.size() == 31); + DLIB_TEST(hog[0].nr() == max(static_cast(img.nr()/8.0+0.5)-2,0)); + DLIB_TEST(hog[0].nc() == max(static_cast(img.nc()/8.0+0.5)-2,0)); + } + for (int i = 1; i < 10; ++i) + { + img.set_size(i+1,i); + assign_all_pixels(img, i); + extract_fhog_features(img, hog); + DLIB_TEST(hog.size() == 31); + DLIB_TEST(hog[0].nr() == max(static_cast(img.nr()/8.0+0.5)-2,0)); + DLIB_TEST(hog[0].nc() == max(static_cast(img.nc()/8.0+0.5)-2,0)); + } + } + + void test_point_transforms() + { + dlib::rand rnd; + for (int iter = 0; iter < 100; ++iter) + { + for (int cell_size = 1; cell_size < 10; ++cell_size) + { + print_spinner(); + for (long i = -10; i <= 10; ++i) + { + for (long j = -10; j <= 10; ++j) + { + for (long k = -10; k <= 10; ++k) + { + for (long l = -10; l <= 10; ++l) + { + rectangle rect(point(i,j), point(k,l)); + const int rows = rnd.get_random_32bit_number()%11+1; + const int cols = rnd.get_random_32bit_number()%11+1; + DLIB_TEST_MSG(rect == image_to_fhog(fhog_to_image(rect,cell_size,rows,cols),cell_size,rows,cols), + " rows: "<< rows << + " cols: "<< cols << + " cell_size: "<< cell_size << + " rect: "<< rect << + " irect: "< img; + array2d gimg; + dlog << LINFO << "get_decoded_string_face_dng()"; + istringstream sin(get_decoded_string_face_dng()); + load_dng(img, sin); + assign_image(gimg, img); + dlog << LINFO << "get_decoded_string_fhog_feats()"; + sin.str(get_decoded_string_fhog_feats()); + int sbin1, sbin2, gsbin1; + array2d > vhog1, vhog2, gvhog1; + deserialize(sbin1, sin); + deserialize(vhog1, sin); + deserialize(sbin2, sin); + deserialize(vhog2, sin); + dlog << LINFO << "get_decoded_string_fhog_grayscale()"; + sin.str(get_decoded_string_fhog_grayscale()); + deserialize(gsbin1, sin); + deserialize(gvhog1, sin); + + /* + // code used to generate the saved feature data. + ofstream fout1("feats1.dat", ios::binary); + extract_fhog_features(img, vhog1, sbin1); + extract_fhog_features(img, vhog2, sbin2); + serialize(sbin1,fout1); + serialize(vhog1,fout1); + serialize(sbin2,fout1); + serialize(vhog2,fout1); + ofstream fout2("feats2.dat", ios::binary); + extract_fhog_features(gimg, gvhog1, gsbin1); + serialize(gsbin1,fout2); + serialize(gvhog1,fout2); + */ + + // make sure the feature extractor always outputs the same answer + dlog << LINFO << "1"; + test_fhog_planar(img, sbin1, vhog1); + dlog << LINFO << "2"; + test_fhog_planar(img, sbin2, vhog2); + dlog << LINFO << "3"; + test_fhog_planar(gimg, gsbin1, gvhog1); + dlog << LINFO << "4"; + test_fhog_interlaced(img, sbin1, vhog1); + dlog << LINFO << "5"; + test_fhog_interlaced(img, sbin2, vhog2); + dlog << LINFO << "6"; + test_fhog_interlaced(gimg, gsbin1, gvhog1); + + } + + // This function returns the contents of the file 'face.dng' + const std::string get_decoded_string_face_dng() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'face.dng' we want to decode and return. + sout << "RFYXmMpdiStV6dVZSJkJX8t7GVavYwTD+Fn11ZjivhnFQyvVJ5t39yJYrK6Qh6K58ovzMlgPiBLV"; + sout << "nd+ZR0JYVCVvwRapp+dznB9SG9dJbBrsYH68k04uOs9ehP8aK4/EXvcZG6s+rL0rnhVAf7SxL7PT"; + sout << "r/11jIuASMa1daKZjAm5Sc1icGXG2FJjO6CxM8mOzWJ1ze69MPD1bz/QYAWtMUqUUIAM0qOPHY0x"; + sout << "T8tdU+Vo6S6E+8dJpV6a6iDdocbp91meDQcT0/kadhC2tmn0eZoNulTn5MtmsmEeuPI2lLLcRJ9P"; + sout << "yt3c/OJIzI8FaDzYG6aWJ/yBQx/DJF0avAlh7V1UmbD8O/dMoF9nUFDwnhGyS6DYfTXxCYgVgoj+"; + sout << "Ik5RLHY0U/DhNTciFaLX41/MyIt0xcGtxhoVcvwkfIigKnYQsYfNpRdUWseRlZ1KYaR4Oc5B2tie"; + sout << "kH3e5AhrY/HtffCah0sf6MBWJEi7CH9AnVLDQefL8Ph+qCWJGf7cGnM/oAaHQCzHIHVi+mK6EBnN"; + sout << "1NDrzbdXmikwYneB3LUZxCLKZmxsFduB2HgiS0A+tTK6IYc+jqCHqz8N6Gw0sSjAK7rrPDTvxhSN"; + sout << "lX3f6E2IDfVmyvk0l3RhuA1PNEh/nlKR+YxcXHyYW4wGf+UfWScAzKGxrHLxLC7LQycCEaCMkU92"; + sout << "SQV5NSSlwKYKACabK6UJ3gGIpvuQK2Aw7VWmC0iLczqgWsX0GKJR0FAcVL9Ed3nV0Wd0s5BkjBsr"; + sout << "RbUKzw11Qu0toj6BNfwXo/5cY2dtjj93a+CBfNrSEuFyJzZU7cn890c9m+q8C41p+wQdf4pFpjcV"; + sout << "8Kz40Fyt8KtxItWSsACIwmUO9h7DGnyGskWBYrxgDV2VVlvuPAnnSCFPkbdsa/pfnohUq0C5a/ii"; + sout << "BjASduPdaBHpjZ64f+TIaXNAGdrFiN61W6e3fOx4fLFlzPQ8szyWuuDh2hIz1FMflbmu6UOEkQji"; + sout << "w+bwDDJ5OUFmY/00+3B0XAFmj7Pt8OQ70lAVLcX5fC553diQJzrrlJ5p9/8+ILln+oleUVJhtp2q"; + sout << "VCZ9XknXLjkQik30M7orOj+tZt7HDgC5sz/wHU5arOL3nIX5IuIHBJRlB8dERZgoPNQlB090rItP"; + sout << "MuT+Hyr/eR7Kcux7Fy2CoMxcIfWEXvxQoolLKC66q4+SFirdMRjXuwbRXrUBbenmBfMMNDAOkQKO"; + sout << "Bi7d8t1wI9ulNbACtqLbmPjW6iabc0yM4g69cZjRx/JYhV5AaykROJxCP6ZKTw+3ddAht8xoyHLN"; + sout << "40rB40fwEXIvv7qxCdCa3h6l6IRV26fOLdcew1G0qjPORcKK1TPhzmneYvhPZ1m0r6KxWNEnYcFq"; + sout << "WhDGNxj/05eBy2qIiIU/KUPhxKyipF0ekgPoCsWT3S8edYjoaIl5cI0TNpNEwKGRLQeOLBDt+MEh"; + sout << "z1yKm0i2jrtxDBpYf5FW/Fln/XJAK6z9yqSDDTDwQleabvRHDH9bc54bLc0TL9g7/eIj9xcshhaB"; + sout << "zbi7baB/yUnI1+0N6CZ45gV3BcD5n2QLiHME8wELveiMxps6MTf3SdKSRZJcnoVfN4AGYqQV42ec"; + sout << "dFB9y8FLZVL3/8rmB+XEu6YoiGcNK6iATLYQfF0GFRrur0Q6bQgdvXv1uZKtNfYfznsAAu/KBdxX"; + sout << "8qskZBMGA3LxJC3j41VW6Fviy+XUxxcmG9ykbf0COJWDul6ZQ7iRI7rn9EpFIYBM1lKzjdC0UFTW"; + sout << "yDWEE+mf9y+RZdlxHdROFj93FNwzSdzNr1yjqHvZHBZYiuArHEuDPXdxqVRePcID4EHzmpDgWFwR"; + sout << "o5qqDxU8e9UYfS8SG545SPZv69SJVJKld5fQLZ4FbcCjv7wTwrOKTROvurKkxopKt1n69BdDA14H"; + sout << "mViSyK22xK/F7/ydjLoqx6aJ8xyNpoUk6XIeJ5Ei2Lhk84VQk9dxzULVy3KsfRUrZCTTi4YiXkHJ"; + sout << "SmQx4NQKqHR2IOgnJBZuNG9J3Fzv3NKhQpmKL0ZbYLXWdKP9FHWUR0x7y8f74Su+GrplBsjh9NIm"; + sout << "QdaKLa3NvJB1TML1/GNcdJVZUuSaX0cQn4bbumvtcENVbC9u99fGnsaS5FOu4AHd3338zLUMy34C"; + sout << "OpJjU1c/IElgyKmaGYlAolgqGU3lixhxPGBhUlXGfHmST2ZWq/l6NxD//HQXRaRUiQGQzWCvhzOO"; + sout << "ywUlVzl9eJ5e5cdLWvffsPRzBgRMrdHJG4TbSuLAREsSD9QEGab3a6y+qa8T3Si/1Ut+Sn2QvPh2"; + sout << "meqqk9g0fRWtrWxcnbUDU6zMlk36L/o/y5inrHdGY+ixIewhI0n4/Nl3wD96SRITcEVSx6K/BVot"; + sout << "+qIP78I6uk+miUF6MW4AnFyCe1wRNyhT48638KIphSQSKdu7TaBndi2DNgFFvWrm6/cPOqkmCzGC"; + sout << "O22uwHyuY9XhafswKLH02+VD24PIS/Fw7JMP+KzvfCHQd4XxxdsISe0/cjwg26ZfGcnULLY2E+dX"; + sout << "LjdgCxNyFBzFTQ4gB4QExF0dHu+SPoo5T3VAojJbYZqIelFY+u2yQDuS4HCUISPkwuLHXHbcBuwg"; + sout << "5TeuFhyBrlwxHQC/OPACmQJREImiqpzrjmh5QipeEgYHK3Zc72tYSeY7eTzS4jj0eRQ8KiNIGSi2"; + sout << "2LjzAfN2Zm7HGbiBtKZVen96E8HLcrd3nSWnizfaLLWTWB3zu9zz9/vFdaa3TlO6BidYsKomTCgB"; + sout << "wy8yMeykE2qbxgrpRqEqmOkrOI9XtTTJIycfAlwBwoFwuqvGIPtFrYmC/MwRMCphg7acSRcjZg81"; + sout << "5IEqpoq9ca7Zc3s4foteVMCJT1A+qmNAJ/j7IoyeX7GnlM3jsqpYt9BmKfbw5Dr2JB9vzroPV++x"; + sout << "UN2VXRPbahjbIvrTULpeBdmlHU0i3Ya8H/C9RY6c2DhImZ1gDjgn0jQ9GC+CsZpiM2xBvfZZGOEu"; + sout << "c8N8pdo2owD8s5q2G5ZCGNdME/AG+iIlb0P00AX+XR8FYhxKb3y50i1giM41mnkKM/WMGFAnpiuo"; + sout << "YordYSi5plePBnxBfd1Iq46PpsD/n/uUTZMHs6TGp1hM6QriyEhOO261HNHoU+n8m1Omz2cfRJyx"; + sout << "AuFLwHSEqvGCSmslmoDpSg2qOaIWK1LWlN+1sYJj18iL4GRM0A5QzXaS0RThqEgmPjeBOkFBjfSO"; + sout << "hB7mb3sDbY49qbN6P48bGV+yF6y34gYAiVkm2NksHzN4ovwg4O6WMQZwEhNk+4gTIzG69jIm6Hbn"; + sout << "2l48A3CYmn8gcjZw39nrlSxpMf7KPkRsdvGmc5Qx9RjP71zH/KJ2TXP0xxzsaGgmqzXfey5l0Hih"; + sout << "XZtfZw8Y28fHBfm3bnIncS4w9S91no+RYMv0aqc9ty7l+Pa28ELwSgQj9eP4u/i5iq/GPmmSxiTd"; + sout << "Si/eeyK1RFJEP4Tv4f3PkV9Js+azu8BbtU+BLO1FBlVg3CzXH5Pc5FMujLdmlqa495hTmi8YW6Et"; + sout << "Fx8dkC80mYFGpVjS+B6pcQLbLBL9gmKzJf4L94/gXZ25BEDob66+XOaRnJ4RkSAN2g6gFJB9lJDh"; + sout << "rLerp3kP/ubPCvcFywuGx3UjJuwFNHE9m62uiaXFU4m04Kc4n7ccHc6hYUkhkY53v2Qb5SDx2qCf"; + sout << "Yg+PWVXujfYrqxRHSwqtV3yX5kMrtYsYpygb7crweOt58BWUa3duyo23UGJHaCwhGwXat6PEC5DQ"; + sout << "2Oe3LVJmc8eYtD97mHKFPhptBl5u2Bztb3zis/oNj1NdMjnDrNuscEAnrpk1CetvHKLglK63Zo/D"; + sout << "rf6SJcmGR2h9g6wAeV7UdsfD6AvteiPj5sl4UuY9x55pP3CTTYklBO1MaDd/XO3A66uMh95RZVGr"; + sout << "VWDd/uKL+rIuI+vKjz8rt80nv3SyUrY9fbftPdK4pBaVnIt73yZrrqv4Zr28H8XpFFQAV9BPlC9o"; + sout << "a8G+AFx/+W2cSfo9r1Uw7npVvRTe6TtIiKagYUmWpx5BfX0VH/VAW0FUh9oiVfx5rm9eaxfSQnD6"; + sout << "7qBINPxsKq+ZDSXni7qfC3J043Le/uL+3XUqsccvEMoU65akKC3lmw1txoUukv92oxyqPX0eOGsB"; + sout << "AU4JdXCldqjU9K3QhyCvv80ZWotGfUr0TlN1LVZqcF2iq3pX1UDOBsPwz9v0QNg8Bmlqy0Vs+MUj"; + sout << "nMCwU9xErzkXLsuVaG+Llk7mmAl7C34BF9O9qSl2kCmbQYoQ87zS7gm/pK7aKGNsICHrar6vlsKo"; + sout << "BJA++/8XKL3nseNZHzq7hKHnOTzagP52MRf+TPXbTVjQPKnCKVAZJcsOlkmuZc7iDnLn4muHDRjg"; + sout << "y09EYcYlFWhLAgsWmatQBsT028ytgMNrQGHDJdjuNkxYfPo+/91ijaaBiey+DgrUVn0fm20k6/Nm"; + sout << "colrwPwHrK3uOdgBn2ysDeUXU8NLMtR94fIL7etQ9tlUuufwrxEL9zYUM8tpks8HDR51xgTwUOVo"; + sout << "DyGFzOdYQRzwi+kkEPEwkpNQbB258d5w9G5eR00P8B/aSjm+w4FU0MsXM0GgPxnQ+gTpS1cezLTn"; + sout << "eelvJYiq/IInLLxoCXycZFPt3WFQqOBpcs6TV/QucjI/5xMZtP3JHUFv16UKPTFI7p9DF+8Ch5HN"; + sout << "gWXCnRSPdYR4ZRid+Xfzi0TvQsXV6u6PaE+H5MpyNMBWhCwxb6FdiLUW0BswGNpHBaFxjB26Qbmv"; + sout << "OW+s0OuXDvKigjQRkeaYawjRAIAN/+CEYR3oUad2HyJ5Ybr/lRlybQuuIqBhuvpkYzszS7BqrxOh"; + sout << "FJYaivT6r3HbHjaJ+Yz/zNW4KsL80zYkPMP7QgcbbSfE2mAavr+ciXdZBqMMUR50sDNLxep9+hoa"; + sout << "ys9wl75QMdx1jn1qn7f04JMSjCyZ7M4bWSyTW7VEr+NBBLmiMzhI6Ufh1iCUpvrIDSQwSDUL88wt"; + sout << "oSiouRbqizt36TldsvFV6afdLgjRrp2cb4vOQBiltwnY06JraGZnsrb4UCfHZhxd8sq/invK9tUd"; + sout << "D3z8hYyLGbS+3LBCK85r74IYvCuhoUp+KobIZPhvWuvdjmmq3SAxIKHNdLC5hnLVMhGJUrckc18H"; + sout << "9zK53uB3QXX6zGKK62Jph4aOdJoDQaPL0K/yHgn9UayEhH/N1uj3Ao39c05puaxzcSotfBeS5+6K"; + sout << "WYyOOMtt5ikKz79qfj6dVWge22fxXUc6yHYfdga0IbYWRocIx+DuyUZnrRQHihNKgYpvF2vhCX/o"; + sout << "R097oHI4ojZFAX/ZWJ7igJvX7ChiwTjK8KDk+vJ4SUd3IHXaiLkkkd9p6tCuc9Lw5jqWiGrrQKuI"; + sout << "7AmGsPFU2EsfOzmwdZctDGXq2/IutVDmwGiucpBKsRN4y12Q1FWKpceVj2q761LfDx2qJoeZKTPZ"; + sout << "jHPdXnGKcWy+DM6GoH9e5jP4CW+HfdHe474bHfLDbP4NE1oF1vdh4NcLy6woi47hg3FS60z+wePD"; + sout << "bWq29WsSwU5oXq58nMxKOBiMcbGFrkOme/Z59Ybi7Cw1+U3nGE3evCFyVMC6g4f/jvCyWF5I3Nm3"; + sout << "OqmkO6fmZ4ahql6C+RwfdRM8A3FllNPgO5riBNX6RA5xKj+JS+OZUrSSN+tUqcgN18IlmLBExEUt"; + sout << "rdG06PKy+WM8Cju3gtbOFX43H4URr9CQcDxWbN6NoqgF8k5a/4+xf7DilfGJg0E2Vu8GG7tmFSU/"; + sout << "LS6gtfLOFyEnQkTzqK8OhVPVLT62cEfCZN9ZY3iKQyZ+VLQhxwarUAgqeNAMXM40NBJqnUIaaTKa"; + sout << "ryhHefUHazhfVgx2+GikVF9wvMobCvvP1qYONlL9EH+ufuLEw1V35BYIIClbrC2uMrnF0H3QbuJQ"; + sout << "ma69tq8TDPkyDLiaczKuAxUzJoj9reJOYGYTxzP7AQKmmEmEZ2cX6+2klWcRXv23XXN8Ypjjnj+d"; + sout << "fTdzxV4kzcHwOYMsI92tadahezCm9uOR0d8p9IH61QQSlUlJw8tX0TpGkNhZpv23STjQhb+uxzAX"; + sout << "1vYdYbPOenr5vCyhnpp33QezQj9cLhSv1WweplUmZjHcJTkPBdflRA9AuqxDVVnbbofXd4EJDC6k"; + sout << "u7xBoD7EKC+kCEkx7ygj8Gv5GVKbgy1js4gLuYwhJ5aqdNpqm881kkxntfRMluVcdH3IGAUzWR5w"; + sout << "26eq7Je4Ttr1cC/xy452i3pJocbhCqrNUG85RyB5FXHAv6GMvm0rUIa6IyC/kfis+sQsdYkQ8GMQ"; + sout << "wL2s8fdDT6l38N9JbRNwdRv8Xa9QAjwcGNbP2v7tAzM5MyhHW7FImYVAaNAaLbzE8v95zeGpT8Cl"; + sout << "CWronhkcJRab8AKP2UcqAD+mW1hVEAqyDe7oWoZziKa2G5aW2vs/WG0z+NqL1zGvUekDcmJ5L4SK"; + sout << "XgdQgxMb/1k48YqYQZFtQrIqoBbYn4qPeB7i378T5TLcCgB6SldsFdlNzs/czN0doroozh3W+sli"; + sout << "d15Qnv5WMjOinjh0Ybt13wcUzeT2p0ovTtYLoiYAhDeAibydJETLdcozfpXFIJNUSoH6TcLge3tr"; + sout << "0uVP92B1O+n0MibJvLsLUKQ9ueIiHgZb6bUSUixAg89QCDRLCZgkg24DLZ4MMTg7IRfFm2eR2lmJ"; + sout << "Erpe62rf2+JE5JTqU88yn6kLK3bQ9vmaGRZ1NUxibTcdpo4hH91qIldLT+jrdmhrawRjYcRduYGO"; + sout << "WXgjTbgKRTxnqrXwRD4Hl/B1EV6ggYC+jn3LQHaT6bYd1hORmtuLKy9duSHVCNBAZvnto27l+h6g"; + sout << "VbUF8eZasmk+q8Fn83bx7C3eKoHjr6acEUyQxtWVCbmeaMd7h48Mt3Z3r8TyX3DkwQmpClciwpyC"; + sout << "E+pbYEWMZXGOuXPmcHTM/Iky8jNSWyw2lLVQQUzPOJ0v0dtNipYRZqBQCDtSE0JuA3Jo8l00uox2"; + sout << "bH11ErfGplsGZJejPGL8ba90e6xeLwH5oe/GduQ0/Xk4+faqBhy/7TFeexQcFDRCCTrC8+jATm82"; + sout << "vHo+NWJjjDlEI4+F2FOhpRg7MtrrzNP/e+cD++wYeGjkRplbxd5PeyALnjZJ6kghJqFLL3NJ1E4Q"; + sout << "gKmRErg1xWWQzyuDbbXPr5dwwRyZU0gkG0WTwyUy2dFV4KRyn2IbMH6STm+0af96YF0joZzkUroH"; + sout << "ztMN8dWtmQESq6EYQfGlhQzNoBKLXjN3LK4TMWBE+1N5ilXkgv3cnN74RZHdLhEXRJnF29x/DmVQ"; + sout << "qZQ4s31Vk0kqKdQ8tW0rs82+dMMtFv8+P2rYA1GZJQV4P5/TBU36BVetlN+swvULk4XpoqhTTMbx"; + sout << "Oj9tONiyIiJitC3YiCU+G5uL0YETB8nSKrtHRiBD8k7nYj4fbUtSbu9+lKRsVK7kU41mKdBImON1"; + sout << "6Qk0XqAx2DEK4w59khYMRRxOD4u2zZWDVp+Nl7Sd7ihas/vQx5yLXHKmIpjCK3SQYjJz09txIErQ"; + sout << "0wJJZoSxH8efhGsTPuVrbQpGcHLD7bIkWf5kjR9MmOBCmUGgeeGOyi45x0k6Cx+z9oaaTXYcvRtY"; + sout << "M+R8tW5gCLaOPjfbq4QjP6yfYogoaTiSEKOPcMgiOQKrXNiv2ahVBT/lvkm6Q8+IdesGWJtD6xqo"; + sout << "+CC486Du6CFDzAHcnLMk5c3CqDfFGl5Yf68bV4aGm28BaM4vikeKRhm2tULeM7PipfQiI9R9Gy/L"; + sout << "1yZB26qciwCalP4CA2NVjiJut7FZgTF2bO/g0qfvyKsAxMetRTmqALBJi8QvKqAE4i/8gRlTuwgV"; + sout << "x6EGsUPCIcQmD8aJkZgy8+erSAY7MLcnUXu90AC37BLyaOt0tzJKfVRb6cP8wfZHqJneoGSNAA=="; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + // ------------------------------------------------------------------------------------ + + // This function returns the contents of the file 'feats1.dat' + const std::string get_decoded_string_fhog_feats() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'fhog.feats' we want to decode and return. + sout << "AXAWQhjCEZzu/U+RFPgnRfCFsyRzQjyOA8TMshjEOL3FZhUZD4amYHfLuNN8NZrJyy2CdiXOWrFH"; + sout << "dTTKsMQkSz9f2VGKE6GCsTcZ+fpbk+d7fBnS2DEUkw5ttiSxj/VsrigkeV0MEUR//EP3/w1uvJad"; + sout << "96PFVeqVf6tEmcuuUbYzViC13lfnps7ftXmic5CwLZ97llqMYpLKgfEIXZpelB2al6PyVAg2FOfe"; + sout << "2L/SzviQ001tYYgx+L2v057L/lJdx2+uOQ7tovbsY0zEBEXTr14ZQW5Wd2lWnj+CrBNQ/4wMQZIV"; + sout << "86X5JV6nl/fxrM3eF5+U6U5cZoZFnn9vRONE0kf6/g6W4Z8qQNL2mJ/M8oKoLGtuaaEjymk0ENRM"; + sout << "0M9GMLrO/qfcH/DHUb/HaK851L5sPuf150ecpIbPmQpxTOpI5XwmvwDAP9BCmLazQWQV85wJJFom"; + sout << "LjT9dNi7eRsV0mNCZDwO9iwUTv6KfAR3Vn4t30X3n3+x+Yhnq5PU/4vu7GJpje25uxICm6+PC9bz"; + sout << "QJC3wiw7jrlF8eRb2kZywIJHWtrN39R1tddENrGZkU8LlPTDcqTqfuwC1HXSt0GKlDRBW42DkSxm"; + sout << "E7yKype2G4bgbSnuf/LKftuWlNXzMP3cDFdChtKJNOws6btgvS014C4UGWe3uNVp9SNAmAFyh+UE"; + sout << "rRTl1w3RNom0yvlee9IP1sB8mOEZn5Rrzo2lCyTm6w49PlLfnS293FxvnpRSSlvqUt1lhzzEh7Wd"; + sout << "YjK90oPWGqvNn/bP1vQW35KlVrc+JT9VlewcZRYJAq2OigVxa9Ao8Gs+fFyh8/Aym6Y3Z6+Y7Xna"; + sout << "fsmYdGaUtfniJemtzHatTuhbuo979fqkIDQ6bhHlMN35IyzK26QKzGrMCaHenZEKck9qR8NTkhc0"; + sout << "urO85952cv4aI/cJ5tz8S6f1v71F1gIz76zL0UNlkULQf3tqEleTqVF5Q+YKAtov0Tig9xiSuCZN"; + sout << "btjVVpg06zlRqmnRqWpZU4W7uu19MKe6BGoZ08l5K/3UNgucyuUBTGtTGt35QaML+N4WLcTYj90J"; + sout << "41Tlt+OD4C8nZ1ID982h7X+YxmnSdMNnaeM2Lc/h8twOwg11vHf6OSo6st4yjK201O8lH0Lemips"; + sout << "LsO2qhAsF4AcWVG7HKUNh5jMr7TTQ/SPP2qiJ6eglJYLFAlwtCSXRSJFuK93Y3eUtzDpVDh44tDJ"; + sout << "y+Bsa9oWJrBeBouo9DHG20N+wJB3rOVDprmq1ZGfhkSzqmVzPk6motDY/jkPGZ2E8/NIWpXpyU8z"; + sout << "cUbc8Er2cAVshJ9uGCg0tUh6lGWsrWucpgCS5lI44PJkbSuyFZMbIZrgxrVmkgOa70pmmuHebs+c"; + sout << "GW0FheFD0IHXVxaJtTuPeDuMML/YrAqdyzibareHbg2Hn4aS+qhRUcMyemOzrvjhGF23oSXbJqaO"; + sout << "sJw/DWkt33eFNqWxofXw2pWMMuF8akESryBDyGEfk9nofHoJhTciXGNkuxIiSZtVYsVnKNU9C4ei"; + sout << "31HWlbHjMtG9RD3zqkIXwerCkhElWfxG2M9ja1S7M4+0VrS9nT2ngc/n/KZDpinuB3GIOnnaqiqQ"; + sout << "8ASsOEl9/Ni9Lflomns/CdxWns3OlcU8KVhWHYz5hqV+MI6SQLlS9j39WFKd1IPaer4y1fd99x44"; + sout << "jkyP4ekjoVVfVxpJlb4favfwI1AnFD2K2TaUeULaALTuBwNT5DnRcLNnwqD+6r33rEv94Nk+2+bB"; + sout << "0QPLkojfPlDJocwcGon+z2EHQO4t+RPAafjQb42dNCymI/bIdgwL64vldhl9KfZtXhE4llYhYuLf"; + sout << "xliN92ELGSt8o5IdAsoOoCPdqQh5NujdDs3p8kVSETq1HLilGRnuxyVwrTqJEUC4G5ntPfqXr1Fp"; + sout << "revo1FOr4SRa8+c2GgA7K8/fvXwN59GONFMKxt+sI/xko8hbOQwOu1VEQ2Ak1aS0cc5MoUrhUFhF"; + sout << "gCTqGd+9bWwg95ELl944dD7MIauKl1wHy5iI7u7EDlhvlGhmtU+6lzCmBiDbbPnoki+yZLc0V0UL"; + sout << "jvaYU1WlKZhjWcA7fHkUyGchIe+KJE8DcW+huCO5iz6gPWwy3ZURQEW+wYPD4Sp6szOTOPsNKcN/"; + sout << "j9PxKGDbNO6/Rou+TEfznVuu4ltLKZSDfYMK9g37+XjMLHTU5jCoPm9KJDjFvCTRAcmRoQfwuXyY"; + sout << "o7w4QwcevrjFdb3/IYMrygb57S0iMkFJPUUiCF/bOfQA8tpLePYYtg2ILGGuH2UFwOLszxguBLLD"; + sout << "ziSVWU+xmCi46kKVuNE9gyeT2OCPtn/U+qEu2B89UkbynI5v/FVpJhJf6MjLc1jfDrEi5NflvvCQ"; + sout << "l2QLJGjtfDRJgcXFuuNzWBMiVglMOhT4n3bta7tV2KraK7Yc3Pc0GIZv/zVHf3BqeDEikXdw73tt"; + sout << "aomM2RQstiy71sMmGkmCvUnEbgrY8O60g5nCSmMZIFbbLit9dLyBjHFUELrLxXup7wmkxu+ZVEzT"; + sout << "FxshL754iYXXkXTqVGirp3NNNhGKPGc4g77Yne+nxpkZ4MwOi3wQ7YqPihwIkIdYewBMiQEJ4Y8W"; + sout << "0+Os7iD8OYrccbHKqvKMTE84QOKGZSsGIaCD+iIvv9F9/EIB2ZDv+2aEs/3ix85vtg8L2f6WmGdq"; + sout << "fOdoKVuU2pPIkzlyTNQAwai3NsjcnI++lEpVC8s0K7fpIN8uWDLRnGGuq/G2gFQEoqN8eP0E944k"; + sout << "l/YTW8+baMZVp/wQNA4mo3v2UvDdHrDcZmIvVuwEaAUW4dwylXMazecCG1SMkusYSXGoorB615oF"; + sout << "DaqndLnQSfyJrsXdBdZDAPmsssTuOvCjrjTPnb36+WebHGuLUNYZ1kqjhZt8hnxTFuYN0sTWb+UU"; + sout << "2tL72LNTK66l8LuuQKUMimV2uHT7Dsv4VX6XE+YXczrC0HycSaVmtshxRh99fnNTEtDKo2bXbSTt"; + sout << "s57BqGWi3/orxqSXecUPFYvQBjrIfDrinany5uht/8FK5JH+c4PMTsbiEeQ95RX7eOaBo6IoF7PB"; + sout << "zaXhJH5TW/qwdO6K/Caqvjp+no08tUTdn/hRdGMyQ8cIYXsMaiQKbkpXlGnhoBK34XA12T6pXa+E"; + sout << "QocfgTw773eQRFWN6vmhvuv8Pd3KPJAJO80slizTFSOvxVO68aM7gZdnDFTgfibe/v+2N1xIUq12"; + sout << "B+YWn9yGP232QOnNq0nuAvFLhuzlau3U+qR2n8DThKWTboF02vsqThaQzF+0EPFk/3we6AeAiatk"; + sout << "dcvXEbGk/TkGI1V5ICcpGS/fvivZlqYhAIL+yi5/5M3wPX14KwriXpFVMGKozUYaW05+27adupOM"; + sout << "p4/0EvfeM8T2m+MQ3GcPLA8njXEDbLLWnoZ+YMC3l2OVRMV/yFXkZvQd6tAQUymv1xB03lNv1M1K"; + sout << "tPE1Ps7ucMH3cOjff0fZOEYabEl81VbmYUCYfntdWlApBLrs3gBWiT0uLoiV3cJkq2VWtgpeyAcJ"; + sout << "PiZ4L7AblENmUS9gr/7gYdn3uridNqfoos0uvUCFrIS4a7siub0FpCMwiwSiyZwtUoD5vcq3Khza"; + sout << "DVGJoijIBo/yEgUTho23FJqMaOYyRnVen4i8GH1H7PUhJe7KThuQYk4UQP+XO1qdLITUwBlbNvks"; + sout << "ciB3IIN6QKQTcoDXEEQaMcPYRNhaaGYFDeSZ4yIRhLV5JyPhOiLFj5EOzrC2B14Op4lOkerAY2J5"; + sout << "cx0CN9xEUrm80GJGtKudSd0JKscXIDTBj9lxngSScCmKQRn4AWJ/acRm/fyc5Gpg5PLx+o0jCI97"; + sout << "hm6qOSqslV4GS3BozqP1x18yqrC++IJvOISjjfSvSAkW2s+qv4ba9gfNYIhsJgAan1vaAgSPDXOD"; + sout << "hf9RJHBE7Xi86Ux4uK/o/0GK3R4QsKa1/t6qij0XAlc6lmt6MXbSr/Tnjs4ykAbSjzmiwON9Jnzo"; + sout << "KYQiY46ULY+o+UjbHTMuxkTJQjCKtyertjpISD1yNYBxItA703l4ZLK0iklv0ZrFMyov8y4ySVmF"; + sout << "Tj/eFWy/PpTEQdyzGXrotmYbb+V+BG5e04bq40FsUhhALgSkcGEYoQtxLCZzkbyWQmEN/uXC9Gdk"; + sout << "wG6Iln2vWqzZSRMeGZ61VMm+dzr5CN2iHtFcNQnjwWHgr8C726YO5j5eWLvHqLouU8ufiojzvsvI"; + sout << "ycQ9aUfkcr3AXu8hv0+SOUK8mT5JdQ5aPXc9WOe4c6mdgfOK9Jq2ZJCLgQvj2swIoGQ4OVXn2D4t"; + sout << "MNaKJu0/5ujgOg/634gfxJbicVJLzn30TgWYXX6Ixj/JQW7yxM2iYyHyOl+/SltruQS+NnF+llzd"; + sout << "rYuXYejtciq7Sf+DFNOJPSaTJuI12jSl4yH4fOMMkkxWMM5QkFzgWGK7A90BJFo4Uhy05TELqU8J"; + sout << "FtUWdrngONpIU6haQcDpoYtXVQ8h06yZYJnVQBToS4szMcsPPUVYxITcGnbT470DR5+Rrsm4MSkO"; + sout << "ZXmf4sC2G/k5at6a8o2OKsohg34bNzdo+R6rpJLcMLmUogDdHj7uNjrvMXgZgN2VS33jpsKaVuww"; + sout << "0fm1I1pD6ltIfO+iF5oEOQz6Ml8TSxcGei4xi9Si7g4KGTZ6+KhS/KcdGQv7HsGH1+auYumQvKIJ"; + sout << "mWY1hPH9tLSAB5t6vh6YUg0Fx9G8Wf5Twz8JXsoyRbOs2AAvXcBQ5EPNbNiKz71rzPb0s4Kd3TKJ"; + sout << "cNZexkD8z0J8U/KVzDHA4kyJU9tKB2ZArGPkIYAWw+u9d9VGdxXgst9WtboxuCOy0+y0XiXez1nq"; + sout << "8Tib6FgDNCMKD4uk1GvVs697TYCphk9MlHPUevFlgBJuVI2uOjsnGjBtzQOeuWHyXoz8xQhyLtI6"; + sout << "MnW6c6lCs8REqpAWwuIPF6YLzAZd4uhpOyKFTWI5jus7I2Rkr7RmFDcOCXcnHw7M60Bvcgsa5xy7"; + sout << "AoMab6/pVJr6EjwK2JEmkLaxPUU/F0WqXpI8roFPbIZ0sfuUzZ7tZmkZtelboyTtuEZTxbKagOWR"; + sout << "78qwMM9feTEwicGFlvyiYdUGaIwiIckil3oQO9w6bQAaPnygDcykVMFK1fZaZztYQ8AKiEjyr/V5"; + sout << "dvc04CEShL8uRKDZNY8Y5cDGVHOSR2g/u0t3PDuzMCfmQeKNJQgc0uF3ozXP4xvTvopiK58Y3656"; + sout << "m2ZUfjDIf/g9gRXmrmte452CbaU4HQzlIblKaEJr153rXTQbZPbSHyHOHWRuczQAtnyP6k8YSC+3"; + sout << "zqBtSIvq637hKi+7Ov8NgiYhw78ehwHiPtenLDlB/YMCwqqPNmo+8eYGmtOoRgaQecIoYHYWfaUE"; + sout << "NMtwvnfU5g3JS0j6VZloZxGNmrtuDqWoXclUftoQdsDlE//Q/5+KHzyZjf60WrXx1Ix35UmF9IEI"; + sout << "jEvtLp8KOoWs077NCkXs0HVwuKxbpZx0v3qVb7HwsgcaoypbjhWaMGYflEKvEbJt+TD4MN8kXDEE"; + sout << "VHFdCSOzUplHRdcy4aOkQdzfEYvNTgaiTSO2CautGYgS2m+l6Cd8B4K5PZs99xYFa7t80L1d2Zpe"; + sout << "M9rYCt6RKakTTFDSR94nxxzdyfbA9sVYu6MCD1G0lN+zrEisA6hjqSAqUXsooaW4WTOd3HoNkZeS"; + sout << "tGviSThwRe0awN4pH5dXfHtxaRBkzOBVw12FB6urObgv+3jcnTzRZWvF14ioEmMuKkHb3Ienbv83"; + sout << "71BeuTrUarJazVmIzbH4ulyWwLxEHeKL0r4PIEfbUg7tiU44lMnfxmrFFR6FiwuBxrUvyv/yEmiP"; + sout << "AbMnRjX+4MVrIHr001qxVFOaeK5FehfvmWotuUW63TvubKCRp++5uQDUh/LeMp5ZSX/RRVpEKEwA"; + sout << "It9wUHj/cafFqSwt603pOOtAC+xL0ozsVF1kYyl9vBEow6mBe66WBwekd+DjlzLbehk5oIXSaYsp"; + sout << "zu0iBnukS7YoO/2Ho5JMdkJuHf5AtuC1bDt51FL9McuIHboobI3/K0YrpwPmFORwSUSroXle6XK8"; + sout << "LeX4Fcp6YaicZoFYpLVbhVvrJSW3F4zBERbSqBrHBMxoXMblajhf58RcPKwZPvPnSL+yB5V0VP5X"; + sout << "RJtdo0ir42WND1VCIfVqCCrSg/m+/R2sffd/CzqQ/JuJnwPtKrAvuaZ8zbmq9UkK0WAWY9Wql+gy"; + sout << "AYQNgM6Nd9TNJ7LEyEJQl3nK2qjg67rBCizIdwRcPFSWoWE9DjVz5eFPIJjPG//dpt296BsBeW7U"; + sout << "NhV2Ig9EstFK3+/GC3annnXsT7OQt7QCnx8BvzbVHiFU0n0yikhA1uU9iYm5qsVwdJ23NWimhTeo"; + sout << "XOWG/MVCrPHpV2qs/4PljiPW2eVdzodj+nfDPVwkHwmFm5Y6TGDRBfJWd+ZiXMkrDa2CEMrMvzEL"; + sout << "/DymDMQfMI9HgMiLWmiXS4DjcRL1Fp7DTHcQa9swLcZwbyE8r14L+5xXzDySF/EKJ9LwmldUaN+t"; + sout << "qjmQ7DecED+MkTHMYXnID/Gl3a1/wnrfxa9IuRAwHiIhuwb8siPJT1S2mncAqk7YwUyaY0hl6TB1"; + sout << "GJqoaPTaeFqzTD1zuz5N3cFxx1O0R7OBzHuBtV/JXrZlyzNKZhUaopWCyQ64+RGa/K/lQcyJNVo9"; + sout << "lErihKgf/LNzd7f94T3CmYMRUJZanOLNUWcOg2iJLGR3hy1pnlmc34IgQ5SHZisUBy2jsdK0xoKA"; + sout << "Kb/cPiaZj6ab/4sq+Owh5I66r7FjACv3uILtt69Fz1o4yaOb8z1ZSvFei4CPrikpxYXltEj+aPZq"; + sout << "pAyL3+qrlJ7eOYE2svEpEAiYTztUAn9/ukCZZ71IafGHnhPggc+eizHPuoqfdUa1NWeM7H/XcFHH"; + sout << "b5PQ2mQPCY1gD4SrBqPS7I1QvfMuOSRba9YFbJAecSC9bPHlQQ+c/rma+R3yhyLyXN2j2xzHpgJr"; + sout << "oT2cEn9yT2fHRiZIWqMAJa0KRJgRN0hLjhGQ1VKN4WZfKy8uyqCrAbzBqR0eYcDjkpvD/syBg81K"; + sout << "8acv+5OjBrHgCofewj2DFs0lC3mMlo8qe+GNjxu/CZufpnNabIY4ggwBxdMWZvZQ7j27VwMmvkdd"; + sout << "Uei0foug3YA/Gy/XUTRLhuyYKy9dqdv6VTSiek0aFQyM4kz4KqKlAbGixu17SFgvZ33fP6ZMZ5Pd"; + sout << "eWC+Z1vRafLRpsABf1DJujXUEUZ8HJfWBuuaMkJkxFApcPwctu4SMpZDjnoc/7gblzc0yLr8BRxD"; + sout << "gygNRNu4uVAT3sbaRzYOvF3tq7wr+R7qS2YmPEzhoZyfubEHdV48wmE3XD/RnZJ8maanzDjEH/Q9"; + sout << "yX/a34y6YSlrtNim3R8pI/AddLi86cAcEKfK3V9QXCGFqO7GfQuyZHO75grKpL+6MwRPUWigaLGN"; + sout << "9JfWrc5tpUu8NPY0ACmztq5p8iNqylVgIYOC9LRIl6TKJQXPSY1k73NpzNNzOJhVCCLY2IaCMUtY"; + sout << "pbxZs+VLufZmNFhq1D5qxUDfKII7x9ocFYqtwLnjZztYT1HhDTcwmB+5uQWHoXKRvsTgvSfR1l33"; + sout << "w4mqzDAI4ykfuNbLmymuKaVV2ST6W7pI6xfZd/YD1em6OQYF+F44tj2G846Ry1IsdVK8AiprcFHA"; + sout << "QFxuwxcY9eUgjSvhbK4BKkQgFSzx7yDZTtDxXSWrGsPoHUCh0GKU757B4ubLvDV197o/No9EKfGh"; + sout << "3tlmEwvvwQd7yTgwdbksgVieGZBJqLK5eRe3CqYXpHryPcxydrO+2LPpzUiiHPY5FDRZjnYY+as0"; + sout << "IC4dtlX5BicqVB2gbHiKHbjLf0pob27d/WqQKKfA/7h2wY/jYyckEiX8g9M6I5QeABzABxyyAMlx"; + sout << "zoi3bd/RD35ijUysmIl/+qi3GHaYKK9Bfu1SRb9oPgmFYKKKRrYxm/cypmJXp+GUjlNokWN1u7OM"; + sout << "Nm6oqRRuqYWiCwZge4HHVQkQs6BuH/Nqvd3Bqkt9U1eH4TnkS5MhnYVdHt0hxmG6Py5A446SERn1"; + sout << "gJRsUDXk/kRkird01o1UqCrlhwX9WG1cjY2I9nFsokudgnYryxn+d+hVdM4E4MO5ZHmOXicEn+hZ"; + sout << "N5eyfU7B75rBT/yZzzot1k+B07BE1x8K7N+S5JX07utAi8/htK/a4vxKhJiyx7p8etztBI06LKss"; + sout << "grLKdJtqZvX3kwFEIXqpJn5W2bijJNeQYnJxfFY+5D/k5POGuBYjf9lw7ilnIkbEDkaMWf9Mrt3W"; + sout << "A30zlSEMksIByxqL5/VFx7oM1oSx872AYwXDNNnfITgNwKBwniRHuU68SjWBO37CYYmDhcB7Ug/M"; + sout << "FMX7oBgPrld/Lg/ut/xbWwFOiM1M6L0TmT5HX6RcFSdKg3Sda6adAkSK3Ux2HTZAUVaK8zG2Mm1e"; + sout << "W+/4SSqK29MAZUBK+oEPrNB7An6NgWZ0TRu8sGeMcZCm2Zb+3+ZVUbsTFoG5pFhzzvMGr9fAsh3Z"; + sout << "4Ngvc76mkiqT7xDZw72BUnrPz+eO+RGqMG+oGWJlXd5PYD4XIkW25kBfOr+iK/gCFUsfZbFoHEFa"; + sout << "M9EFExjXiZYOz4iuSZnpeChRQ5wbl/56iTGSMHpB3KKU9Rfyv4dIDGS0KmnSH7Q7mdn2lXerjDmF"; + sout << "zlc7IbX35O0kHeSXaH08ZFf9vBweKQvEW6Qxs59FWPuS47FpXRAPFClwYskLxByN2ux9kxHprAFr"; + sout << "zuNKRyilr/NlCEJn5SiYU2/fAif/YrNNiyxXx7sWwQW2mf9Emqzsrkb+eCqAPs1NmwuHO8JXIyKO"; + sout << "0EbIUaVuENKa+DLtcrfI9HlxhrF8vd8m8s4ZeBHWr3jkbVdcjX9mmtcSrHyHXcWVImc8CJOiR7jw"; + sout << "0HTAydarj2G229/7kLX6RncrRqY6/Z2e0YEbfTt8RDnwriUmKXuf7VUMljP+tWx8aXMlbm47WMaY"; + sout << "nnjZ4wE7UTgrYUce7ehuGDBEM4VaKYrp4n782gXFdo9VhUZ7JEaveARi8Di1SI5MuTQ8N2hUfPKn"; + sout << "JbK8mEjSmcN1sOZBOXxxWX+e9RRsO5t5ujCSy+UBr4gaqMHAlwxtBmif+kl7s9o+UjHOlWLn7V74"; + sout << "PiLxVUzCTHz5A0rX7MyGXMBai8BT5XjmazTZz8YPIiq/ZmnyULon+uTrmBdivNEjDq4M476YfAma"; + sout << "rcqxp/picZEanq8yctrujPHilXH1WuXhkPM+gn9Gkjp/sGfA13JiFeXZO9GJHkqLwBKApqS1HJIf"; + sout << "bhomnqcdp6+hM/IEaNS3dVZ1aTSefIM/OTR2fNVbmzwfENzVke3dCJvw5B8zTsgEXeKuRE58pYaQ"; + sout << "H+ffv486NoO8gXKrUF47ptIj/Tt5apc5Pt1AVFsuaxvry3UrUmY7MLbfM4xB29ah1PY9iYx/OssX"; + sout << "DGvnDskGHiQJlz1s3saJl9qm3BZr3//6D1/CgGvXw0DffJYL9yIqlzoBRKAue/lQbWagNArmsccI"; + sout << "WxVk2TqYgTHwhl40fWSGIuY1kfj6FEAUxBBxGul8y9Cv3e/hMnE0gDQz2tZs7EGqP8WqKPF4BpDh"; + sout << "Ep+6vvj67y2Hf36GALRYNLlhCB5+HygzOW4jwtQvCHxbmtPWvQkRQ8iQ7XlZH0OOATz1FsccBZXr"; + sout << "3+O26cf7zcoZIE28MIQHIJqsDFHCJVr8JVcYITcv46R1z9ZgQQ+z7KK0M6FY6SOgCyu5kyQR6YJd"; + sout << "HEYkNp9J4N21iRr08DUYycJDtrf7xoyYfj2L5QjpPVEKrEhIkjb1kr8uZp/KfCjCkedPBrMsPuOJ"; + sout << "28Acpne2exkkKaunrJZr64Axfgl5t0OIBP79Cqy8PRnMYQTtqHf/pxaHNyrYaXO9vjtTVoV++kaM"; + sout << "svy9Ol/R4LxEXrgtlSYhKF4o6iCJUSiLhE6j/xzyOtRrCAHZ56WCJ7YTm0oa22TtH0DvoFMY11Tr"; + sout << "SZ04ig5Tl8At1jXKymUSnK7EXANVHZmZ+01xTE4efnFIf33N4uWK+c/hEqmS6KCVjuHT9FwTKEIm"; + sout << "4cKum8uhL11rodikW09dYfIyQV9yO0k6EJpeUSQMhPBqbgTWHnVQoHoot+c1uBoOdK/+bRuz5vip"; + sout << "l6+0nmpZZoO+OjdEap1pQqpcTEhmfBuUGXCibggZhHEvHvFGQo5an4N48OQA2B8CccDHMJIDP9+j"; + sout << "0JfmziBnF+ZOLfKLwuuE4uZg87iSWFkwSynsWwoUQ1Cy3u/URW620WLmkx0GDoGPkcsxJ0LVu9dh"; + sout << "OrWWaFesEANPRMCKuuU55hs913KzoOKdgZPzM8dheJZaZB15wi+u/RTm+obSWZVwibTDyLPQ55mz"; + sout << "FSv7Lyp1EDFvxyel+7osJa5ifhLrU5f6CAcKwS5t2IwaZaBxrVaFgX5lQmifY1Gd2new1mkfWYmb"; + sout << "n4AaURVZOBC4U1Dx65ch0PNeEYxk0DLAGOsvdBDbWbFMNc0LiGEF6GiCMBYVsw/cYnjY07cEZg5N"; + sout << "M9RxGfLfVlyy4MW9ek5ov4+NVLV+vaosZvA9gP92vaiS7jBj8qCb2uYIK1mfrjHcOUFqcxgzltBR"; + sout << "eSU3ewoVRJIaVreWX4RSXTomL1GOyfcFQ7gZghSJLKlwgkK6+yvqig0jyuvmiRFJGDYOxQtDykwM"; + sout << "EKOOZrckRdktXs6aFQOla+IRiN3XVUfY2wvx3egfOYPKrJS2HBEdXeIelB7LK6vRx8nwuFAAwZWQ"; + sout << "iB4OvgM8U9N2H1wa0xXWc1897unEm1KxAP1iS44BmZ0PPASPa1RrltkwWMwRtMp9fD08mwLdGe88"; + sout << "fsGIrnwHA0HIShoZaBmCSvX0yfaJ34nchnMukP0B/8v7f2j0t1F6hr5qJNSV3cTNcc2QToBy03ro"; + sout << "GphkM0UhwwoYTlFppF/8LV4QmtsNK9wuziTmv8ghXUgvFRTtVcBMlHY4SFcRKoagFF+B1UfCuVZt"; + sout << "G8fNilXM8zc2/eP3g5FTU3YXRgw2CAheUNiekne53EE4DmDJ1f6n4OyjZFP8PZ85Nzkkco9FW/8/"; + sout << "31JIClFqSE1VGbDXleZMOrdfoknF55krQB7v6Wsu82gzQU2AvvhbuFlPodDKH1xUuUT7gERKv4ka"; + sout << "fdS1xf/PRafedvPX7k8fFZQerqvvUlgO6PaEAxK5MqZqHiUgUh9A/pqBZNBl1kQLP9+G70i1CY9x"; + sout << "tTETsJJTgYz/HP6yXsqeFm52yO24myKHMURYBmOZU6AFyqFCsOdJG/GA5otaO9MG0A3hPdC3Py+2"; + sout << "GWbnl11yPDVImktL+LYbx2oolBWZrelLRuKAtDp4p/Svt5R6fOvYM9XxZnnpR/MNaTt7I8iEZSQe"; + sout << "w57IL33ZfHgKlxEr/ouJgT3vlxqXsuXnP49ytEVKGja0JAOzlShLYqB3GIY9Gb3EtN0h/id1jqQm"; + sout << "QwR8y5q54W8pdYflhgi92YWYroUUIFRxheAgPwUSqrOUNiN6xpSwr2usNY1zvGGsRqFKuhKgh7P7"; + sout << "cvGc2sOj3izPgl4NlR2DaoFTbXd6uDM3IrYGSCFJdbMlonX5cvO91ySJpKwPOnqjAbzBjnZCfepb"; + sout << "4px8fSH5I6LJVU1R+sGRCoewaHTFlsnaCQrsy9BTGMIWwAKCAYCbWN0T4ItZaXhJWarghYGsT57P"; + sout << "MagSlMpKeiToHWnWhPhtxkhm1ZVRuYekcpwrlWwsHBV6O5pEML3wmWZDWJNZWh/GkmhPhf0aF78F"; + sout << "2lOVKBWLYq+4Xt3lVNvqCr7m2rQNH6KzZUNfeoIJH48SgirJQiNNE2iQOsReTQTbCW87NCN4GKGh"; + sout << "Vf5A2JU03N+5fT+dorN/LTQmeKddK0MR2nshO1m0kZSQ9TDUE6Da7ITtIjGpKK1QqohIx9BEMoML"; + sout << "t3BcLTNkK1SaaYRE9Fm8AZXr4z9AILXUKktv5bytIRBZncs8078FdpF+O2JWEzELgG2s/FTHTyjU"; + sout << "NU3A5q6+8LoeSXuoqOZz2QwEMhKwkz7AlujJU/CJ3+uZBAUBythODaVBlaZM6/f8dSY3489Xiu9z"; + sout << "L917CSsYpq4JQWq7I7pOkPjy0t7Y7QmKPITQv4QQmF3P6SJXiainjWDQlZx59RTzg2Z2MDZ0dcWx"; + sout << "f7f44IkngOVEUHi3iwrSygMTySBYdDxei/kBt7oXAAkOMLucZ24VE26nLD1hMAwRz83ENc0sOZAz"; + sout << "F7He/e0H2I53NAYlk0s55wntUcdp3sHh3UOcGyBGcRUipg6NWy/LrWzxWqJdo1DsBpV4iazYLRfB"; + sout << "JdKobsmPGmWSyABV8tea/IuVgqUm5RrWfBUa6Kw2TqVuos2PsGqsRi4cHKl3XQ0GJjbV7qV7Nacn"; + sout << "mRVqIAetzUUZkuyp6Q8OMLX36zwwrlP0rPJaqk24VBEnx2FKjaR/LkP5lUFenOHJuPXSh89pYWA7"; + sout << "7/y1i2ejWuMrpElkJ7qfpTGD/A5ZkNWCyla2VXOXUVjK2t+cBYByjEwlMCwr+tXUXQYVMPf+UzzT"; + sout << "EVv45u14EhFg5XOu7urPZWyO4EbejgEvyyx0c6Zgt7RaQgXn+akwMHI3oDn2NcZYGo8Vwg/fSs6j"; + sout << "Ggzyse3ShPpUr6qb9TT3Vo6hXegWwd6tDPsqDlc62JhQ0MYdw2A1x8aGK1iUCz4pqsWmXWhGPO1g"; + sout << "rXyqBaIrVHDMim3cZO2LY/TOwIJZGuHNvVnQQ0TDYAQEgyej9mVLuWaFuCE/JZlw/8U2VngPVq0w"; + sout << "EePssUQnewMHYTfceXOETEBkU0xl4WLdwaJbvIXG+pPcVzYW8z/D3yh2LU7KHbklffbjWhYyAm3b"; + sout << "nTTR+YTZoF5PwWA9DmsxIbLQJn3Ejss6pOj1YSqq1cD2r1+TZjWefNmH/nncwjU0T6e/iBxkgQLc"; + sout << "NXouKJEemknL/HOvd5sUS6YYbIGg7Pa3ur3Qn0quKT8QSHRPTxZeTP47rSp+Koibf/XfZOdG5RzB"; + sout << "U2Kbz2vi24JkfzbNCG3tjNvzPJVcZrMde1TNVNQVFbXCzMZgO0Za7o3IRyxy05rmieBcGryGX1ln"; + sout << "hnhwSzxSSxnuJ3szdlfw1j68i8LsX+iaZHonWIw21afDPwXaoflD4cZIZI2uPLAKeBoOYMxKyFGj"; + sout << "2KhIVlr+poDcu9gZBMF3CwwxSRZpwVhpUllBfOJDJTLbM3UOyNRMs/U2T6TdxH1fRbgeQcfQ4CqU"; + sout << "D196crANUlHv6VL0bU8CFS5f85YwKO4JliKHZhvQRDtcBrwKIXttWPt1f3OONAscmsl+JtVgVt+r"; + sout << "h1R+X0b/puYAa4tqBvUBiiokN21cQR40Pp5PPLMszGskImO0XM6Fx9spo80xxDLhOXZsV4rJKs3q"; + sout << "lub2BoFFGa/nbd88uBheoBR8lh5d3sWciz8Y6eMMlEASpLmXFg/nqq1jzTlC7lGmAhyqSTAL2dHJ"; + sout << "+8ACp+IgMI8v7DR1TEp0qgwrVC7B6+8bozJlpH73HdqQYkyjAnVcOBTn6KO1pjAJ+YcXqe2ioiBt"; + sout << "OMAlNQVAJZgxpEnwwd16yXDmK+Fp7v8aGEx1EECVaQYW7ZVvKRC9249ioFaEgFjpo+8veoWFd9G3"; + sout << "1E4aBC8rOXrrsI6U3QqWbYlusqvogtKGwT2pMYSidnPneM9iaJ5ixPSlB7VPp4MwdZvPIpOOR8Hq"; + sout << "cFqQfZAa7sDq5idYKPUrRo5XD6szBNbb9Oob6y3G9RQndEEt2363luFr6Nv80r8yuyaUUlRcmPmU"; + sout << "/VCExotSxcx6B4uWyd6UoA1wZv1hY+OCUL0ahArlLzkR2ZxowfeDCn5eoawdQnazyx9wSAEmqGTT"; + sout << "lrT+MyINZeyxp9DTtvjFwVw1aDVT87PF7E4Yw2H5V/55elbLzn+Rlr4AgU7W+RnuwOk52LCKrSud"; + sout << "JTAq/vix0svgRHAs3E8wsLzBsFXc9SZ6/AWptBKuix0lkoDObM8rLJMjxpQPmIyVvB3jWXITDn+n"; + sout << "8ou5GgmZUDeGu3ZiOWbtPhO90akFKl1XDQJ+k9DGrlQd0dbsDa2lTgPbL+3wOD9fdjFzqfmGA2vI"; + sout << "kg75VqCCEOUXbQ8N2U2IALiOEuYblJk7WVsbmOwmEw7T57QVJvivVJ9Z36TJeWfpzIiHCeReso1y"; + sout << "3RGCo1qlsgfwtz5e3/ycu5aEq+Pg1W4EtSpeuIEZyH7zQaUILmAFP23JDmHyZNh9ewxSRfZGETHg"; + sout << "KKN32MMx6taVzf/8+RQhxI3JJ+PSk0vrvUEX4L/2Bri8meJbN5UtWuRN5gha/O8jPDixa4XrMAiq"; + sout << "etOumapH/QhEJBy5NHmcjLH7XpZ72qxTIL6VS6m2GjFUyekOdKZOHDyEKA0FR6wgEnoki7gS0L6N"; + sout << "5plyHCOmd1aBWEXf5+P8EJeh1nmu6AXtDful0kn8G9nNnWrbC+iIj2QQ+XZPxdZIGuJKd6MrXf0y"; + sout << "zW5dznFYvB8R+LTy2SW4WFvMWzJAlggi1N7w25/rvtnt57E+I5TimCjAKJ5Vcjj/R0DOqX4BNnOK"; + sout << "MqeMDogP+DE9NesTswgfFngZBjomZsx8f2Wdzzi/UW0xBZa3yPk+CV+wwTlWNxgBKFCYFC8GlJ41"; + sout << "MQt+TPDsDZJuBdbexuS/PA/zzOE/wZPVXgzLZFmTKsZAfz9894HEHalRlKLgTlvgW65XDihdoh71"; + sout << "HDwdA9Knb9r2qH0dwsOTpoU0uABJkND+V1Ezr3oi35dJ2zw2gR0omEnVXZM9dW8XIp8ln8zegt+S"; + sout << "dyMbNDzX9RClWSIVIGuRYGCahBqXMCJG/AAKvkXM87mZiQsIz6uLvJpkeGiG1ah7kTSjVFIKSYES"; + sout << "73oz5l7AIAIReUzdvMLk5ZbmqdW4JHBR5XBA3rTpAyNEunRi2Ddy1eBt4i8I7GlXaaZ7sVS028ze"; + sout << "VkY4WR/6FJKO8ccbofsJb8BiM8UZdPmkRdEKGgFv13dkq6inmo3S52P5S50mYapa1YKvWCvMHaov"; + sout << "z8BOe6WqptQRGc+vZ4v1vgq40yaWBa9pSTyvntLUKCkQX0qi3ifKhRykXP6SBylckIHs8DfuJqXP"; + sout << "X06djSx/F7IyvBQg90SLa1h5hYRTAchYX1ZgvG7LlSvPp7i+sWCx5D7KpTbF6O7AU8kMPpaRm2wE"; + sout << "RvZ4DCgXdh9Lgw69wRlMfWJEJcV3mOIGDRoBZToyJqgHMDJigCet6NIsD2xvvza5JWPf3DURWRHz"; + sout << "wdVqLXlUivD6/9r290ZcIbdWFMbz2aY9x/ojLrJmAGX/kySItIXsJlDbcvquJWNc6UYrpxurQEku"; + sout << "ejm7E6HjXUPmi9gh2kRkw29A5xIFaugoph9dZSwyCEx2BbqY21hwQ5elfdcvdJD/fe7iiWmfrtxQ"; + sout << "BqVebUOrY5+/vEvMf7EupYK8cxtCb3mvViCR0rGokTwuw7NVuMH6DEJ/zXB7BO/2bPXnr48Q2pmM"; + sout << "pThaXSCJ5Ta1SZqGZRG5pBsUWg8MCQBc/KmDWv5csQlqDJOLvEULRsKdxxcm2wBthRXx97JVHMwe"; + sout << "Tlb1TkjoSIQonMsvU+dRJCW3qhgbR4i0t7wGfbg9YdXG8LiKJHvkpu/IPKHZtaMxOp3O6kl7Lcb+"; + sout << "Zr412Aanu8BN5GP+rW8/C+TjEBG/WeR23iZr6SByC7TCzZwJO25J6mtC1nFxdAUizSmxUfwDwvdE"; + sout << "5WyGMq8TTIHZ9KvKMw/jODKXcviOmLHY9CsPtYXbChnF/fWgQx/ykshrLmtSqFEaOt4YCiPFKQHh"; + sout << "/IjsYqcs0UQeGTk/6tnp8Cewe8A1DAaBEI1RTllobQKLXEUBWAV0VcrVXILhemvApUW0uqu6dSqg"; + sout << "cfv9uvu1BblNEfNmfTcUssCnDQ5c5vFjEB1KESkrBTL+p6bEn2b1bbxQrhiqnVO6mi9anbtBgBU4"; + sout << "lTjSG4KMnsD63xnrLapoU7ReEZxGjsQgm1Af8/lewJajhNOZi3FmDgkb8Lhq2rz8OYy6pwXCEM07"; + sout << "JwMzOWFBZDUA"; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + // ------------------------------------------------------------------------------------ + + // This function returns the contents of the file 'feats2.dat' + const std::string get_decoded_string_fhog_grayscale() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'fhog.feats' we want to decode and return. + sout << "AXE4jk6QRzUCNtyVtAwaCkqQk/DMJKm1t48e6FXsZJ/Zfw0Utm0AVwNlJU7O2+XVsftV/yE3zO4S"; + sout << "YG9DY33gIQWV6sw7AGva02FE7zlkRbW3IOyZeG5LYs2r4vGYmZvZuYQJ8CVmkJpGrYioIuqWLyoD"; + sout << "IfAFmt0z7HREwTsGxP9BG6UIXb95jQCGocuoO+MxQNpq3qtMAr+C2xnN3Na+ITiKYUX+zEP+FrWJ"; + sout << "uqsYqhyN4H+1rlYaLhrj0nhUs4+Zp+fn3LZecDsZGeq6KDEIM68rOZAY8WWA/o4x30cQ0P299z2m"; + sout << "7Pl1vRxpN3MdgzkS2zlOpGnrexfA3TK07nFcTK4Lc97t/75JULM85uEo6yYUBjyeY7MPYTHaM0tH"; + sout << "+TGdAqQBf6TbnYya1MK0BM0nVLp4TkMeRZj49erCtZIVaFWOYmufs3LywApX8AQZAsBhU0QA5cQu"; + sout << "G6HKcdTiqxUxbjOIueAxWBbDbGIwUj8URJVh/WVW4TAXnzZu48JvRzif++Xeswz6O8+PXdPVp4BG"; + sout << "TaJv9bBwqV1y7lM/T9FxktP3YIcnbW+ezBUh/logAVeLzCAbYaqChRMUkBizO54DEiA+NBDV5ndu"; + sout << "Iax8NL9DFp1zmO8gdJBZZLgzf6K8ZbRwwKNlLEkVEZGKei3aVIwo7ed0tjG4aEIodgAuoXD8oa4p"; + sout << "t7v3sYbn/Wz94AyvAwFV0M1mAMaXqXhpfJSb5CJsSOyKqUtBBhpS58V6Gsq9M9weuwVHee5JI8PI"; + sout << "bY8WlfQNSVo4JcToIdKPhehz4Ywad3gdX8DhSd+WBT2aMomB71jr3m9CcC9da3oxXjh/jVaN/boQ"; + sout << "oY5UqFGZ7T1b42x79PiIoOY339sKYR9vr7WPDHTouVtas/3hJBkBKwXqKmBbRPU2N4DTgpt+1zFq"; + sout << "nig8prSmjiAh9ofp20Pbl5FbYoIdoOfLgYIluAOuT0XhKxLAwWoNGLBA37bGagSxzXnzlDPpFDZ2"; + sout << "WVdrBwehuD4WhBRxMiTZkzJwhzRkyvsBnO/JqFxXAD5b2CvH9p+3P/czeRv2hRjNZaKHEfdMtIWn"; + sout << "lPR1zWmZ3LTMBR8R2kmAUt50Sikzj2TpzLlBhemvUVAcv0HEdiILnp3eilOx/ee7GzEroOMHuOGe"; + sout << "moxWRB1EHj6h2Ya6TVmMb47SzMwKo7mf1PtZ5JLDk61k6Bv9M0WCDiwxlBwC/PjTxHDsTOWRrSJG"; + sout << "roZ/csy/RQF//nA0f8uKC9HWmGTOsrA6VAVteglGi9k6E9etxvmGCQD2ud2tyz2HSnklJAMutXd+"; + sout << "flvArk1avpvwXGoNgT3EsEicUGOg93vXl1A873vbHtwnycMzAa8NYjZGW/GPROg8yKkbpoxK+Zel"; + sout << "+VP14SvqQDKEYDevoGuQWhsQHhW7bYeSa0bSamm9DFzqK2Ld4/aU7JUHYdaAvAJPIHe3F/N8NcJ1"; + sout << "RE7VoQlrVZXBy33Ly6wwiv1rm8yK4sMdBnNXcxzCyG1NkPsmC/16A08Dy8RV7nJhd1iOJ3gy9BUG"; + sout << "Y4ofrw/XzITd6vL8mIJW2jyrVzW2OPqMWy6IO/7CHz7d3k5+7KTKWkksTRsuAYWDgZtU5Umph0kN"; + sout << "/RRnuBRIe/6KYj5/thAY7yLmGFoy8jhDmIikSP/l/pJhUjjgKk1R+oYHGDJ4FSD2KNfRDQ/xHZb2"; + sout << "++ObYSCHAuhB1HoJBwPcmxkFtxjS047iGXnXo+2nchItcrtifnQTeI0qyV6conJnskM4jfhnKsj7"; + sout << "A7y4owQjOvkDwu0GueuDOTo/9mW2NjAaiHCp8yarSV1dPI1b/XR231+p5IEp1zWzqFr+O3pDAL8F"; + sout << "+eU/vj8taOxtT0CcT0gW2rRr8oTigRWCUGP1hxDBdhJtAa7Q+CuFvmtpjm/4fmbroJ5ZVNA/71FN"; + sout << "yL9DKBmiVTwljf+yRpTGhpMG+xky30zS2R150N93YDDXVT9StjKaOrtLap+9w7BtCvXGdPiyR2QN"; + sout << "g7gqzrz31poE36fAwM4san1jbbTC+eYcTErQ2wXCQkVne3kbVOwErB0ayl7rNqkw5b3gAME4+DnN"; + sout << "IM13kdtxY2WeBND96g2enTmFizxHWFzW2asW4/XqGt8EVmzTB4XM/Ytd+XXaNadRUHh44wiNIhPp"; + sout << "txmIbtkSesyQ4YYSmUYEEcYkZwfcRxHUAGCcnQbSKGLq+N5IcyiLVwhXDfK4fqxH7Oi5DmOwAVmV"; + sout << "kdDQ6GP7wDcrkAQS9s6fL04bNf7rodacAPdX4HwIKa2YykdXWOiQhp6BRxjUG44AVV5fiTGP1WVn"; + sout << "MJKXYzSyY5oN3ADAT5em+cIYYVPnsnZBuUzAjHAw3WYj4VlRIrmP+oPdKPFncEyfTn1G7DbmyaPd"; + sout << "TL3DdB8EDImfZ+A2UMr2i7jynH/fFXzsi2PM9cFKxCsEqG2LGr7KZDP2FFEVvWwIDnvUClj9nrHd"; + sout << "zyioiild+DW6PoYvqFQX4LUf9Jr09RiJuIvz35I/15pBQbSnTxD1RFJwR92k0yA3514jYBAtfLmf"; + sout << "/1refalgDjOyD1ntgp3INSYSbpR3TqNIynWwSJGxq9I5aFJypV3Nq8w2Rn3/kld9nqDaTG79ns13"; + sout << "dwfPOiIZyfdrcxkYOtS4+iijs+YE3JhRKVWMf6ub0cVXttJGORgpzgkSoDWNVZ6O3hVydCziYzoS"; + sout << "RgHUH2Oas+aZK5IF3Z3aaRYc0A+wy/PEx84LRHsAYdPUQJr+bGseC4wScP1Dyc3xpeJSR1V9t2tc"; + sout << "AMbhtwcKGX4j5nGnhxSevKzbCDteEnB23TnSQwbuWYmzhB77V3jpu1Cm2h7FcKGM8vPbAQeRBr1b"; + sout << "+RY95s8hsZGi/USiu4TQ/wHZfEGYcHuVBI920d84pPRVe5EJ8+FzZj+Qy7JwriqLN+7WyUCFdwxn"; + sout << "4B+WXHTe2epBJMQzlE25kKDHKb2lDDv5HzVUlrZK4rEShkPNv8SB3U9u20GTLlHJbM8RDvPkNjmu"; + sout << "U4ZLDSJylDaRHWqgchgMnmX08aprI/o1HbgZ5aiByXAUoHSSGHanyYmW/S0LOW7YZH+jOgxzW68U"; + sout << "lheNnX6Z26RdMe5Xtkd5jx2jXgIT7HADCN8wWdZEVT7FvXRoxtO3nz30cbOim/+IvB2lt//OcTJU"; + sout << "BwkyweuhJtiXZV1yY+X2z3dWBjBXVFWidPMMWjTUIwalo9A91RL6ZS25kuBXKm/BV6X/zAHkY+jB"; + sout << "A7qJGLPe5h6SO3GKPSLv0wE+9G6VIhH2TPLfAd+PpqM+xuUNlWlo5JR6ItOgGHuFdjZeDlISYOID"; + sout << "CJn4zfaU6M4Dmw+m1wsIiyQy4Cw0DcZYUjwStEfLJu7BRyFI55wEbXJr36PRWpvB2wzkI4z1u2yM"; + sout << "7bFvguH0teVxtMwQy9S4lT9JlfX8QEqNRsuxzQprPJqj0ie2cFBrvJ2E0FtcuarikbcYAmC1iYdz"; + sout << "Tc2CRVe60taBHR92AsvuTv5OhCIssVc5W+Lbv5e3S0vLwUprWPhfbth/zNevKNiIeogTJOOTQexo"; + sout << "4EhJ2pf8QSz+ixlWi5ffYKLKSx03wYBZ3fNIhJDP+Y0mz2pd0IyGhxDrMAxYHhPjoEGCfYX6fYz2"; + sout << "XQGQ1eTGvNEATqz2v9HxYBicYAcW4mjEcTqsMHMYcboVgqiLxOD+jOHJqNrw+eiuC2Ucu6muQ2es"; + sout << "Zhbo1UNZO/J9XbsZIKa6t0PX0CSczL02Dti8r5rCQ7yeUgwZFzy5oJSShAuUvBDILXBeDhnuU4W3"; + sout << "eooTvk0lkb8sNAgkUlvHXgdN9jkfQuHqX453IOffB9TnxopLuxS3R6uuRa5ED2pWkL0L5ceftRv3"; + sout << "0T/sMlHmTNoQhcJSoINMf7f+WmHa95iD1HolQNFJaDEUzLV1DfD2mntncFZoJ4zlr+b/94qp8s/n"; + sout << "NM0CyoI4ansElbkkjUs7QQp+43pGXu8sRgg7tva7/3vUwLp3Md+8WkX+uppIhj25nlwfkD6IHxdk"; + sout << "uFNPNgdIikcM463W56CKek7LufT8wQG13mJOMJCObvhW/EU/yb/88hegDxeVJKrWHRyebllsFbZ9"; + sout << "UeEWkJPKdA+YHilgfSX0aFgofCDmmh1k/cO2RPqIGIqZSfT//6lFEDdzPcCLp/kd4oq7qh5Ko7lo"; + sout << "bg4N9n6F90oJI+JFJRLy6bEhZC/7obQlP6FFUUqSpizj18+zTLPUaDpt3/eg9QeXUeKcNlkHYt23"; + sout << "AqzS6PqB9t4bUr+E81QUxpegh5V0M4OPQJbFQ4/rna/AwZVDmGRjJLKzkWUwpnIv7f2+Gl+91ly2"; + sout << "ve9Ube6Jf9h9M/j7kppHRARY4QbXauYAcSp3wRaIueJWx36dMsqgzuYfwHIyuhGS3CbC1EHUZwM5"; + sout << "hgCHFvPGHLM7T7p4w6k0e7n8RZJsDybyDAydW8SfSI6LIeX2st4LOHrdLrpUNPyE9JyuIMCSnPaT"; + sout << "1XnjdWev4jjMeXxa1XWbzy0AQpeQ0UYA4OpqRSu1DyoPWf4IA+bf012m5Im/1BiGF972Ie/6CNvS"; + sout << "niYOfmxxLWTihdMtslxiy53y2MW2iDzLxX5nvUSyKXfO1XUlDJGTd/zfUywZcY8cgI/f1IPzr2Lk"; + sout << "3/YKqkJ2De4IpbixDEkbgAroIaoOZXEGD+yNzXVcyISIhMiKHJERdwVxp4j3duc3M04wrwJMZvtK"; + sout << "lk5jnn+ILhTxcpboq8gge4CbWziUUj9du6mklxZaeYBBpB6CzlVLitesXieA/zl2JnWzRMD7Ho3r"; + sout << "LLSqxh9UeFgjFtb3ilKSHNZH6x1DS0hFhEIA"; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + // ---------------------------------------------------------------------------------------- + + }; + + fhog_tester a; + +// ---------------------------------------------------------------------------------------- + +} + + diff --git a/ml/dlib/dlib/test/filtering.cpp b/ml/dlib/dlib/test/filtering.cpp new file mode 100644 index 000000000..61dc88440 --- /dev/null +++ b/ml/dlib/dlib/test/filtering.cpp @@ -0,0 +1,166 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.filtering"); + +// ---------------------------------------------------------------------------------------- + + template + double test_filter ( + filter_type kf, + int size + ) + { + // This test has a point moving in a circle around the origin. The point + // also gets a random bump in a random direction at each time step. + + running_stats rs; + + dlib::rand rnd; + int count = 0; + const dlib::vector z(0,0,1); + dlib::vector p(10,10), temp; + for (int i = 0; i < size; ++i) + { + // move the point around in a circle + p += z.cross(p).normalize()/0.5; + // randomly drop measurements + if (rnd.get_random_double() < 0.7 || count < 4) + { + // make a random bump + dlib::vector pp; + pp.x() = rnd.get_random_gaussian()/3; + pp.y() = rnd.get_random_gaussian()/3; + + ++count; + kf.update(p+pp); + } + else + { + kf.update(); + dlog << LTRACE << "MISSED MEASUREMENT"; + } + // figure out the next position + temp = (p+z.cross(p).normalize()/0.5); + const double error = length(temp - rowm(kf.get_predicted_next_state(),range(0,1))); + rs.add(error); + + dlog << LTRACE << temp << "("<< error << "): " << trans(kf.get_predicted_next_state()); + + // test the serialization a few times. + if (count < 10) + { + ostringstream sout; + serialize(kf, sout); + istringstream sin(sout.str()); + filter_type temp; + deserialize(temp, sin); + kf = temp; + } + } + + + return rs.mean(); + + } + +// ---------------------------------------------------------------------------------------- + + void test_kalman_filter() + { + matrix R; + R = 0.3, 0, + 0, 0.3; + + // the variables in the state are + // x,y, x velocity, y velocity, x acceleration, and y acceleration + matrix A; + A = 1, 0, 1, 0, 0, 0, + 0, 1, 0, 1, 0, 0, + 0, 0, 1, 0, 1, 0, + 0, 0, 0, 1, 0, 1, + 0, 0, 0, 0, 1, 0, + 0, 0, 0, 0, 0, 1; + + // the measurements only tell us the positions + matrix H; + H = 1, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0; + + + kalman_filter<6,2> kf; + kf.set_measurement_noise(R); + matrix pn = 0.01*identity_matrix(); + kf.set_process_noise(pn); + kf.set_observation_model(H); + kf.set_transition_model(A); + + DLIB_TEST(equal(kf.get_observation_model() , H)); + DLIB_TEST(equal(kf.get_transition_model() , A)); + DLIB_TEST(equal(kf.get_measurement_noise() , R)); + DLIB_TEST(equal(kf.get_process_noise() , pn)); + DLIB_TEST(equal(kf.get_current_estimation_error_covariance() , identity_matrix(pn))); + + double kf_error = test_filter(kf, 300); + + dlog << LINFO << "kf error: "<< kf_error; + DLIB_TEST_MSG(kf_error < 0.75, kf_error); + } + +// ---------------------------------------------------------------------------------------- + + void test_rls_filter() + { + + rls_filter rls(10, 0.99, 0.1); + + DLIB_TEST(rls.get_window_size() == 10); + DLIB_TEST(rls.get_forget_factor() == 0.99); + DLIB_TEST(rls.get_c() == 0.1); + + double rls_error = test_filter(rls, 1000); + + dlog << LINFO << "rls error: "<< rls_error; + DLIB_TEST_MSG(rls_error < 0.75, rls_error); + } + +// ---------------------------------------------------------------------------------------- + + class filtering_tester : public tester + { + public: + filtering_tester ( + ) : + tester ("test_filtering", + "Runs tests on the filtering stuff (rls and kalman filters).") + {} + + void perform_test ( + ) + { + test_rls_filter(); + test_kalman_filter(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp b/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp new file mode 100644 index 000000000..2260e92a1 --- /dev/null +++ b/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp @@ -0,0 +1,787 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.find_max_factor_graph_nmplp"); + +// ---------------------------------------------------------------------------------------- + + dlib::rand rnd; + + template + class map_problem + { + /* + This is a simple 8 node problem with two cycles in it unless fully_connected is true + and then it's a fully connected 8 note graph. + */ + + public: + + mutable std::map,std::map,double> > weights; + map_problem() + { + for (int i = 0; i < 8; ++i) + { + for (int j = i; j < 8; ++j) + { + weights[make_unordered_pair(i,j)][make_pair(0,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,j)][make_pair(0,1)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,j)][make_pair(1,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,j)][make_pair(1,1)] = rnd.get_random_gaussian(); + } + } + } + + struct node_iterator + { + node_iterator() {} + node_iterator(unsigned long nid_): nid(nid_) {} + bool operator== (const node_iterator& item) const { return item.nid == nid; } + bool operator!= (const node_iterator& item) const { return item.nid != nid; } + + node_iterator& operator++() + { + ++nid; + return *this; + } + + unsigned long nid; + }; + + struct neighbor_iterator + { + neighbor_iterator() : count(0) {} + + bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); } + bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); } + neighbor_iterator& operator++() + { + ++count; + return *this; + } + + unsigned long node_id () const + { + if (fully_connected) + { + if (count < home_node) + return count; + else + return count+1; + } + + if (home_node < 4) + { + if (count == 0) + return (home_node + 4 + 1)%4; + else if (count == 1) + return (home_node + 4 - 1)%4; + else + return 8; // one past the end + } + else + { + if (count == 0) + return (home_node + 4 + 1)%4 + 4; + else if (count == 1) + return (home_node + 4 - 1)%4 + 4; + else + return 8; // one past the end + } + } + + unsigned long home_node; + unsigned long count; + }; + + unsigned long number_of_nodes ( + ) const + { + return 8; + } + + node_iterator begin( + ) const + { + node_iterator temp; + temp.nid = 0; + return temp; + } + + node_iterator end( + ) const + { + node_iterator temp; + temp.nid = 8; + return temp; + } + + neighbor_iterator begin( + const node_iterator& it + ) const + { + neighbor_iterator temp; + temp.home_node = it.nid; + return temp; + } + + neighbor_iterator begin( + const neighbor_iterator& it + ) const + { + neighbor_iterator temp; + temp.home_node = it.node_id(); + return temp; + } + + neighbor_iterator end( + const node_iterator& + ) const + { + neighbor_iterator temp; + temp.home_node = 9; + temp.count = 8; + return temp; + } + + neighbor_iterator end( + const neighbor_iterator& + ) const + { + neighbor_iterator temp; + temp.home_node = 9; + temp.count = 8; + return temp; + } + + + unsigned long node_id ( + const node_iterator& it + ) const + { + return it.nid; + } + + unsigned long node_id ( + const neighbor_iterator& it + ) const + { + return it.node_id(); + } + + + unsigned long num_states ( + const node_iterator& + ) const + { + return 2; + } + + unsigned long num_states ( + const neighbor_iterator& + ) const + { + return 2; + } + + double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.nid, s1, s2); } + double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); } + double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); } + double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); } + + private: + + double basic_factor_value ( + unsigned long n1, + unsigned long n2, + unsigned long s1, + unsigned long s2 + ) const + { + if (n1 > n2) + { + swap(n1,n2); + swap(s1,s2); + } + return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)]; + } + + }; + +// ---------------------------------------------------------------------------------------- + + class map_problem_chain + { + /* + This is a chain structured 8 node graph (so no cycles). + */ + + public: + + mutable std::map,std::map,double> > weights; + map_problem_chain() + { + for (int i = 0; i < 7; ++i) + { + weights[make_unordered_pair(i,i+1)][make_pair(0,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,i+1)][make_pair(0,1)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,i+1)][make_pair(1,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,i+1)][make_pair(1,1)] = rnd.get_random_gaussian(); + } + } + + struct node_iterator + { + node_iterator() {} + node_iterator(unsigned long nid_): nid(nid_) {} + bool operator== (const node_iterator& item) const { return item.nid == nid; } + bool operator!= (const node_iterator& item) const { return item.nid != nid; } + + node_iterator& operator++() + { + ++nid; + return *this; + } + + unsigned long nid; + }; + + struct neighbor_iterator + { + neighbor_iterator() : count(0) {} + + bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); } + bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); } + neighbor_iterator& operator++() + { + ++count; + return *this; + } + + unsigned long node_id () const + { + if (count >= 2) + return 8; + return nid[count]; + } + + unsigned long nid[2]; + unsigned int count; + }; + + unsigned long number_of_nodes ( + ) const + { + return 8; + } + + node_iterator begin( + ) const + { + node_iterator temp; + temp.nid = 0; + return temp; + } + + node_iterator end( + ) const + { + node_iterator temp; + temp.nid = 8; + return temp; + } + + neighbor_iterator begin( + const node_iterator& it + ) const + { + neighbor_iterator temp; + if (it.nid == 0) + { + temp.nid[0] = it.nid+1; + temp.nid[1] = 8; + } + else if (it.nid == 7) + { + temp.nid[0] = it.nid-1; + temp.nid[1] = 8; + } + else + { + temp.nid[0] = it.nid-1; + temp.nid[1] = it.nid+1; + } + return temp; + } + + neighbor_iterator begin( + const neighbor_iterator& it + ) const + { + const unsigned long nid = it.node_id(); + neighbor_iterator temp; + if (nid == 0) + { + temp.nid[0] = nid+1; + temp.nid[1] = 8; + } + else if (nid == 7) + { + temp.nid[0] = nid-1; + temp.nid[1] = 8; + } + else + { + temp.nid[0] = nid-1; + temp.nid[1] = nid+1; + } + return temp; + } + + neighbor_iterator end( + const node_iterator& + ) const + { + neighbor_iterator temp; + temp.nid[0] = 8; + temp.nid[1] = 8; + return temp; + } + + neighbor_iterator end( + const neighbor_iterator& + ) const + { + neighbor_iterator temp; + temp.nid[0] = 8; + temp.nid[1] = 8; + return temp; + } + + + unsigned long node_id ( + const node_iterator& it + ) const + { + return it.nid; + } + + unsigned long node_id ( + const neighbor_iterator& it + ) const + { + return it.node_id(); + } + + + unsigned long num_states ( + const node_iterator& + ) const + { + return 2; + } + + unsigned long num_states ( + const neighbor_iterator& + ) const + { + return 2; + } + + double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.nid, s1, s2); } + double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); } + double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); } + double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); } + + private: + + double basic_factor_value ( + unsigned long n1, + unsigned long n2, + unsigned long s1, + unsigned long s2 + ) const + { + if (n1 > n2) + { + swap(n1,n2); + swap(s1,s2); + } + return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)]; + } + + }; + +// ---------------------------------------------------------------------------------------- + + + class map_problem2 + { + /* + This is a simple tree structured graph. In particular, it is a star made + up of 6 nodes. + */ + public: + matrix numbers; + + map_problem2() + { + numbers = randm(5,3,rnd); + } + + struct node_iterator + { + node_iterator() {} + node_iterator(unsigned long nid_): nid(nid_) {} + bool operator== (const node_iterator& item) const { return item.nid == nid; } + bool operator!= (const node_iterator& item) const { return item.nid != nid; } + + node_iterator& operator++() + { + ++nid; + return *this; + } + + unsigned long nid; + }; + + struct neighbor_iterator + { + neighbor_iterator() : count(0) {} + + bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); } + bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); } + neighbor_iterator& operator++() + { + ++count; + return *this; + } + + unsigned long node_id () const + { + if (home_node == 6) + return 6; + + if (home_node < 5) + { + // all the nodes are connected to node 5 and nothing else + if (count == 0) + return 5; + else + return 6; // the number returned by the end() functions. + } + else if (count < 5) + { + return count; + } + else + { + return 6; + } + + } + + unsigned long home_node; + unsigned long count; + }; + + unsigned long number_of_nodes ( + ) const + { + return 6; + } + + node_iterator begin( + ) const + { + node_iterator temp; + temp.nid = 0; + return temp; + } + + node_iterator end( + ) const + { + node_iterator temp; + temp.nid = 6; + return temp; + } + + neighbor_iterator begin( + const node_iterator& it + ) const + { + neighbor_iterator temp; + temp.home_node = it.nid; + return temp; + } + + neighbor_iterator begin( + const neighbor_iterator& it + ) const + { + neighbor_iterator temp; + temp.home_node = it.node_id(); + return temp; + } + + neighbor_iterator end( + const node_iterator& + ) const + { + neighbor_iterator temp; + temp.home_node = 6; + return temp; + } + + neighbor_iterator end( + const neighbor_iterator& + ) const + { + neighbor_iterator temp; + temp.home_node = 6; + return temp; + } + + + unsigned long node_id ( + const node_iterator& it + ) const + { + return it.nid; + } + + unsigned long node_id ( + const neighbor_iterator& it + ) const + { + return it.node_id(); + } + + + unsigned long num_states ( + const node_iterator& + ) const + { + return 3; + } + + unsigned long num_states ( + const neighbor_iterator& + ) const + { + return 3; + } + + double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.nid, s1, s2); } + double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); } + double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); } + double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); } + + private: + + double basic_factor_value ( + unsigned long n1, + unsigned long n2, + unsigned long s1, + unsigned long s2 + ) const + { + if (n1 > n2) + { + swap(n1,n2); + swap(s1,s2); + } + + + // basically ignore the other node in this factor. The node we + // are ignoring is the center node of this star graph. So we basically + // let it always have a value of 1. + if (s2 == 1) + return numbers(n1,s1) + 1; + else + return numbers(n1,s1); + } + + }; + +// ---------------------------------------------------------------------------------------- + + template + double find_total_score ( + const map_problem& prob, + const std::vector& map_assignment + ) + { + typedef typename map_problem::node_iterator node_iterator; + typedef typename map_problem::neighbor_iterator neighbor_iterator; + + double score = 0; + for (node_iterator i = prob.begin(); i != prob.end(); ++i) + { + const unsigned long id_i = prob.node_id(i); + for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j) + { + const unsigned long id_j = prob.node_id(j); + score += prob.factor_value(i,j, map_assignment[id_i], map_assignment[id_j]); + } + } + + return score; + } + +// ---------------------------------------------------------------------------------------- + + + template < + typename map_problem + > + void brute_force_find_max_factor_graph_nmplp ( + const map_problem& prob, + std::vector& map_assignment + ) + { + std::vector temp_assignment; + temp_assignment.resize(prob.number_of_nodes(),0); + + double best_score = -std::numeric_limits::infinity(); + + for (unsigned long i = 0; i < 255; ++i) + { + temp_assignment[0] = (i&0x01)!=0; + temp_assignment[1] = (i&0x02)!=0; + temp_assignment[2] = (i&0x04)!=0; + temp_assignment[3] = (i&0x08)!=0; + temp_assignment[4] = (i&0x10)!=0; + temp_assignment[5] = (i&0x20)!=0; + temp_assignment[6] = (i&0x40)!=0; + temp_assignment[7] = (i&0x80)!=0; + + double score = find_total_score(prob,temp_assignment); + if (score > best_score) + { + best_score = score; + map_assignment = temp_assignment; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void do_test( + ) + { + print_spinner(); + std::vector map_assignment1, map_assignment2; + map_problem prob; + find_max_factor_graph_nmplp(prob, map_assignment1, 1000, 1e-8); + + const double score1 = find_total_score(prob, map_assignment1); + + brute_force_find_max_factor_graph_nmplp(prob, map_assignment2); + const double score2 = find_total_score(prob, map_assignment2); + + dlog << LINFO << "score NMPLP: " << score1; + dlog << LINFO << "score MAP: " << score2; + + DLIB_TEST(std::abs(score1 - score2) < 1e-10); + DLIB_TEST(mat(map_assignment1) == mat(map_assignment2)); + } + +// ---------------------------------------------------------------------------------------- + + template + void do_test2( + ) + { + print_spinner(); + std::vector map_assignment1, map_assignment2; + map_problem prob; + find_max_factor_graph_nmplp(prob, map_assignment1, 10, 1e-8); + + const double score1 = find_total_score(prob, map_assignment1); + + map_assignment2.resize(6); + map_assignment2[0] = index_of_max(rowm(prob.numbers,0)); + map_assignment2[1] = index_of_max(rowm(prob.numbers,1)); + map_assignment2[2] = index_of_max(rowm(prob.numbers,2)); + map_assignment2[3] = index_of_max(rowm(prob.numbers,3)); + map_assignment2[4] = index_of_max(rowm(prob.numbers,4)); + map_assignment2[5] = 1; + const double score2 = find_total_score(prob, map_assignment2); + + dlog << LINFO << "score NMPLP: " << score1; + dlog << LINFO << "score MAP: " << score2; + dlog << LINFO << "MAP assignment: "<< trans(mat(map_assignment1)); + + DLIB_TEST(std::abs(score1 - score2) < 1e-10); + DLIB_TEST(mat(map_assignment1) == mat(map_assignment2)); + } + +// ---------------------------------------------------------------------------------------- + + class test_find_max_factor_graph_nmplp : public tester + { + public: + test_find_max_factor_graph_nmplp ( + ) : + tester ("test_find_max_factor_graph_nmplp", + "Runs tests on the find_max_factor_graph_nmplp routine.") + {} + + void perform_test ( + ) + { + rnd.clear(); + + dlog << LINFO << "test on a chain structured graph"; + for (int i = 0; i < 30; ++i) + do_test(); + + dlog << LINFO << "test on a 2 cycle graph"; + for (int i = 0; i < 30; ++i) + do_test >(); + + dlog << LINFO << "test on a fully connected graph"; + for (int i = 0; i < 5; ++i) + do_test >(); + + dlog << LINFO << "test on a tree structured graph"; + for (int i = 0; i < 10; ++i) + do_test2(); + } + } a; + +} + + + + diff --git a/ml/dlib/dlib/test/find_max_factor_graph_viterbi.cpp b/ml/dlib/dlib/test/find_max_factor_graph_viterbi.cpp new file mode 100644 index 000000000..82754aefd --- /dev/null +++ b/ml/dlib/dlib/test/find_max_factor_graph_viterbi.cpp @@ -0,0 +1,217 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.find_max_factor_graph_viterbi"); + +// ---------------------------------------------------------------------------------------- + + dlib::rand rnd; + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long O, + unsigned long NS, + unsigned long num_nodes, + bool all_negative + > + class map_problem + { + public: + unsigned long order() const { return O; } + unsigned long num_states() const { return NS; } + + map_problem() + { + data = randm(number_of_nodes(),(long)std::pow(num_states(),(double)order()+1), rnd); + if (all_negative) + data = -data; + } + + unsigned long number_of_nodes ( + ) const + { + return num_nodes; + } + + template < + typename EXP + > + double factor_value ( + unsigned long node_id, + const matrix_exp& node_states + ) const + { + if (node_states.size() == 1) + return data(node_id, node_states(0)); + else if (node_states.size() == 2) + return data(node_id, node_states(0) + node_states(1)*NS); + else if (node_states.size() == 3) + return data(node_id, (node_states(0) + node_states(1)*NS)*NS + node_states(2)); + else + return data(node_id, ((node_states(0) + node_states(1)*NS)*NS + node_states(2))*NS + node_states(3)); + } + + matrix data; + }; + + +// ---------------------------------------------------------------------------------------- + + template < + typename map_problem + > + void brute_force_find_max_factor_graph_viterbi ( + const map_problem& prob, + std::vector& map_assignment + ) + { + using namespace dlib::impl; + const int order = prob.order(); + const int num_states = prob.num_states(); + + map_assignment.resize(prob.number_of_nodes()); + double best_score = -std::numeric_limits::infinity(); + matrix node_states; + node_states.set_size(prob.number_of_nodes()); + node_states = 0; + do + { + double score = 0; + for (unsigned long i = 0; i < prob.number_of_nodes(); ++i) + { + score += prob.factor_value(i, (colm(node_states,range(i,i-std::min(order,i))))); + } + + if (score > best_score) + { + for (unsigned long i = 0; i < map_assignment.size(); ++i) + map_assignment[i] = node_states(i); + best_score = score; + } + + } while(advance_state(node_states,num_states)); + + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long order, + unsigned long num_states, + unsigned long num_nodes, + bool all_negative + > + void do_test_() + { + dlog << LINFO << "order: "<< order + << " num_states: " << num_states + << " num_nodes: " << num_nodes + << " all_negative: " << all_negative; + + for (int i = 0; i < 25; ++i) + { + print_spinner(); + map_problem prob; + std::vector assign, assign2; + brute_force_find_max_factor_graph_viterbi(prob, assign); + find_max_factor_graph_viterbi(prob, assign2); + + DLIB_TEST_MSG(mat(assign) == mat(assign2), + trans(mat(assign)) + << trans(mat(assign2)) + ); + } + } + + template < + unsigned long order, + unsigned long num_states, + unsigned long num_nodes + > + void do_test() + { + do_test_(); + } + + template < + unsigned long order, + unsigned long num_states, + unsigned long num_nodes + > + void do_test_negative() + { + do_test_(); + } + +// ---------------------------------------------------------------------------------------- + + class test_find_max_factor_graph_viterbi : public tester + { + public: + test_find_max_factor_graph_viterbi ( + ) : + tester ("test_find_max_factor_graph_viterbi", + "Runs tests on the find_max_factor_graph_viterbi routine.") + {} + + void perform_test ( + ) + { + do_test<1,3,0>(); + do_test<1,3,1>(); + do_test<1,3,2>(); + do_test<0,3,2>(); + do_test_negative<0,3,2>(); + + do_test<1,3,8>(); + do_test<2,3,7>(); + do_test_negative<2,3,7>(); + do_test<3,3,8>(); + do_test<4,3,8>(); + do_test_negative<4,3,8>(); + do_test<0,3,8>(); + do_test<4,3,1>(); + do_test<4,3,0>(); + + do_test<3,2,1>(); + do_test<3,2,0>(); + do_test<3,2,2>(); + do_test<2,2,1>(); + do_test_negative<3,2,1>(); + do_test_negative<3,2,0>(); + do_test_negative<3,2,2>(); + do_test_negative<2,2,1>(); + + do_test<0,3,0>(); + do_test<1,2,8>(); + do_test<2,2,7>(); + do_test<3,2,8>(); + do_test<0,2,8>(); + + do_test<1,1,8>(); + do_test<2,1,8>(); + do_test<3,1,8>(); + do_test<0,1,8>(); + } + } a; + +} + + + + diff --git a/ml/dlib/dlib/test/find_optimal_parameters.cpp b/ml/dlib/dlib/test/find_optimal_parameters.cpp new file mode 100644 index 000000000..9f2f5b348 --- /dev/null +++ b/ml/dlib/dlib/test/find_optimal_parameters.cpp @@ -0,0 +1,58 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.find_optimal_parameters"); + +// ---------------------------------------------------------------------------------------- + + + class find_optimal_parameters : public tester + { + public: + find_optimal_parameters ( + ) : + tester ("test_find_optimal_parameters", + "Runs tests on find_optimal_parameters().") + {} + + void perform_test ( + ) + { + print_spinner(); + matrix params = {0.5, 0.5}; + dlib::find_optimal_parameters(4, 0.001, 100, params, {-0.1, -0.01}, {5, 5}, [](const matrix& params) { + cout << "."; + return sum(squared(params)); + }); + + matrix true_params = {0,0}; + + DLIB_TEST(max(abs(true_params - params)) < 1e-10); + + params = {0.1}; + dlib::find_optimal_parameters(4, 0.001, 100, params, {-0.01}, {5}, [](const matrix& params) { + cout << "."; + return sum(squared(params)); + }); + + true_params = {0}; + DLIB_TEST(max(abs(true_params - params)) < 1e-10); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/geometry.cpp b/ml/dlib/dlib/test/geometry.cpp new file mode 100644 index 000000000..505a83d95 --- /dev/null +++ b/ml/dlib/dlib/test/geometry.cpp @@ -0,0 +1,883 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.geometry"); + + void geometry_test ( + ) + /*! + ensures + - runs tests on the geometry stuff compliance with the specs + !*/ + { + print_spinner(); + + point p1; + point p2(2,3); + + DLIB_TEST(p1.x() == 0); + DLIB_TEST(p1.y() == 0); + DLIB_TEST(p2.x() == 2); + DLIB_TEST(p2.y() == 3); + + DLIB_TEST((-p2).x() == -2); + DLIB_TEST((-p2).y() == -3); + + + p2 += p2; + DLIB_TEST(p2.x() == 4); + DLIB_TEST(p2.y() == 6); + + dlib::vector v1 = point(1,0); + dlib::vector v2(0,0,1); + + p1 = v2.cross(v1); + DLIB_TEST(p1 == point(0,1)); + DLIB_TEST(p1 != point(1,1)); + DLIB_TEST(p1 != point(1,0)); + + p1 = point(2,3); + rectangle rect1 = p1; + DLIB_TEST(rect1.width() == 1); + DLIB_TEST(rect1.height() == 1); + p2 = point(1,1); + + rect1 += p2; + DLIB_TEST(rect1.left() == 1); + DLIB_TEST(rect1.top() == 1); + DLIB_TEST(rect1.right() == 2); + DLIB_TEST(rect1.bottom() == 3); + + DLIB_TEST(rect1.width() == 2); + DLIB_TEST(rect1.height() == 3); + + // test the iostream << and >> operators (via string_cast and cast_to_string) + DLIB_TEST(string_cast(" (1, 2 )") == point(1,2)); + DLIB_TEST(string_cast(" ( -1, 2 )") == point(-1,2)); + DLIB_TEST(string_cast(" [(1, 2 )(3,4)]") == rectangle(1,2,3,4)); + DLIB_TEST(string_cast >(" (1, 2 , 3.5)") == dlib::vector(1,2,3.5)); + + DLIB_TEST(string_cast(cast_to_string(rect1)) == rect1); + DLIB_TEST(string_cast(cast_to_string(p1)) == p1); + DLIB_TEST(string_cast >(cast_to_string(v1)) == v1); + + rectangle rect2; + + // test the serialization code + ostringstream sout; + serialize(rect1,sout); + serialize(p1,sout); + serialize(v1,sout); + serialize(rect1,sout); + serialize(p1,sout); + serialize(v1,sout); + + istringstream sin(sout.str()); + + deserialize(rect2,sin); + deserialize(p2,sin); + deserialize(v2,sin); + DLIB_TEST(rect2 == rect1); + DLIB_TEST(p2 == p1); + DLIB_TEST(v2 == v1); + deserialize(rect2,sin); + deserialize(p2,sin); + deserialize(v2,sin); + DLIB_TEST(rect2 == rect1); + DLIB_TEST(p2 == p1); + DLIB_TEST(v2 == v1); + DLIB_TEST(sin.good()); + DLIB_TEST(sin.get() == EOF); + + + v1.x() = 1; + v1.y() = 2; + v1.z() = 3; + + matrix mv = v1; + DLIB_TEST(mv.nr() == 3); + DLIB_TEST(mv.nc() == 1); + DLIB_TEST(mv(0) == 1); + DLIB_TEST(mv(1) == 2); + DLIB_TEST(mv(2) == 3); + + set_all_elements(mv,0); + DLIB_TEST(mv(0) == 0); + DLIB_TEST(mv(1) == 0); + DLIB_TEST(mv(2) == 0); + + mv(0) = 5; + mv(1) = 6; + mv(2) = 7; + + v1 = mv; + DLIB_TEST(v1.x() == 5); + DLIB_TEST(v1.y() == 6); + DLIB_TEST(v1.z() == 7); + + + { + dlib::vector vd2; + dlib::vector vd3; + dlib::vector vl2; + dlib::vector vl3; + + vd2.x() = 2.3; + vd2.y() = 4.7; + + vd3.z() = 9; + + vd3 = vd2; + + + + vl2 = vd3; + vl3 = vd3; + + + DLIB_TEST(vd2.z() == 0); + DLIB_TEST(vd3.z() == 0); + DLIB_TEST(vl2.z() == 0); + DLIB_TEST(vl3.z() == 0); + + DLIB_TEST(vl2.x() == 2); + DLIB_TEST(vl3.x() == 2); + DLIB_TEST(vl2.y() == 5); + DLIB_TEST(vl3.y() == 5); + + + DLIB_TEST(abs(vd2.cross(vd3).dot(vd2)) < 1e-7); + DLIB_TEST(abs(vd3.cross(vd2).dot(vd2)) < 1e-7); + DLIB_TEST(abs(vd2.cross(vd3).dot(vd3)) < 1e-7); + DLIB_TEST(abs(vd3.cross(vd2).dot(vd3)) < 1e-7); + + DLIB_TEST(abs(vl2.cross(vl3).dot(vl2)) == 0); + DLIB_TEST(abs(vl3.cross(vl2).dot(vl2)) == 0); + DLIB_TEST(abs(vl2.cross(vl3).dot(vl3)) == 0); + DLIB_TEST(abs(vl3.cross(vl2).dot(vl3)) == 0); + + + DLIB_TEST((vd2-vd3).length() < 1e-7); + + DLIB_TEST(vl2 == vl3); + + + vl2.x() = 0; + vl2.y() = 0; + vl3 = vl2; + + vl2.x() = 4; + vl3.y() = 3; + + DLIB_TEST(vl2.cross(vl3).length() == 12); + DLIB_TEST(vl3.cross(vl2).length() == 12); + + + matrix m(3,3); + m = 1,2,3, + 4,5,6, + 7,8,9; + + vd3.x() = 2; + vd3.y() = 3; + vd3.z() = 4; + + vd3 = m*vd3; + + DLIB_TEST_MSG(vd3.x() == 1*2 + 2*3 + 3*4,vd3.x() << " == " << (1*2 + 2*3 + 3*4)); + DLIB_TEST(vd3.y() == 4*2 + 5*3 + 6*4); + DLIB_TEST(vd3.z() == 7*2 + 8*3 + 9*4); + + (vd3*2).dot(vd3); + (vd2*2).dot(vd3); + (vd3*2).dot(vd2); + (vd2*2).dot(vd2); + (2*vd3*2).dot(vd3); + (2*vd2*2).dot(vd3); + (2*vd3*2).dot(vd2); + (2*vd2*2).dot(vd2); + + (vd2 + vd3).dot(vd2); + (vd2 - vd3).dot(vd2); + (vd2/2).dot(vd2); + (vd3/2).dot(vd2); + } + + { + dlib::vector vd2; + dlib::vector vl3; + + vl3.x() = 1; + vl3.y() = 2; + vl3.z() = 3; + + vd2.x() = 6.5; + vd2.y() = 7.5; + + DLIB_TEST((vl3 + vd2).x() == 1+6.5); + DLIB_TEST((vl3 + vd2).y() == 2+7.5); + DLIB_TEST((vl3 + vd2).z() == 3+0); + + DLIB_TEST((vl3 - vd2).x() == 1-6.5); + DLIB_TEST((vl3 - vd2).y() == 2-7.5); + DLIB_TEST((vl3 - vd2).z() == 3-0); + + } + + { + dlib::vector v(3,4,5); + DLIB_TEST((-v).x() == -3.0); + DLIB_TEST((-v).y() == -4.0); + DLIB_TEST((-v).z() == -5.0); + } + + { + rectangle rect; + + point tl(2,3); + point tr(8,3); + point bl(2,9); + point br(8,9); + + rect += tl; + rect += tr; + rect += bl; + rect += br; + + DLIB_TEST(rect.tl_corner() == tl); + DLIB_TEST(rect.tr_corner() == tr); + DLIB_TEST(rect.bl_corner() == bl); + DLIB_TEST(rect.br_corner() == br); + + } + + { + point p1, center; + + center = point(3,4); + p1 = point(10,4); + + DLIB_TEST(rotate_point(center, p1, pi/2) == point(3,7+4)); + + center = point(3,3); + p1 = point(10,3); + + DLIB_TEST(rotate_point(center, p1, pi/4) == point(8,8)); + DLIB_TEST(rotate_point(center, p1, -pi/4) == point(8,-2)); + + DLIB_TEST(rotate_point(center, p1, pi/4 + 10*pi) == point(8,8)); + DLIB_TEST(rotate_point(center, p1, -pi/4 + 10*pi) == point(8,-2)); + DLIB_TEST(rotate_point(center, p1, pi/4 - 10*pi) == point(8,8)); + DLIB_TEST(rotate_point(center, p1, -pi/4 - 10*pi) == point(8,-2)); + + point_rotator rot(pi/2); + DLIB_TEST(rot(point(1,0)) == point(0,1)); + DLIB_TEST(rot(point(0,1)) == point(-1,0)); + DLIB_TEST(point(rot.get_m()*(dlib::vector(1,0))) == point(0,1)); + DLIB_TEST(point(rot.get_m()*(dlib::vector(0,1))) == point(-1,0)); + } + + { + rectangle rect; + + rect = grow_rect(rect,1); + DLIB_TEST(rect.width() == 2); + DLIB_TEST(rect.height() == 2); + DLIB_TEST(rect.left() == -1); + DLIB_TEST(rect.top() == -1); + DLIB_TEST(rect.right() == 0); + DLIB_TEST(rect.bottom() == 0); + } + { + rectangle rect; + + rect = grow_rect(rect,2); + DLIB_TEST(rect.width() == 4); + DLIB_TEST(rect.height() == 4); + DLIB_TEST(rect.left() == -2); + DLIB_TEST(rect.top() == -2); + DLIB_TEST(rect.right() == 1); + DLIB_TEST(rect.bottom() == 1); + + rect = shrink_rect(rect,1); + DLIB_TEST(rect.width() == 2); + DLIB_TEST(rect.height() == 2); + DLIB_TEST(rect.left() == -1); + DLIB_TEST(rect.top() == -1); + DLIB_TEST(rect.right() == 0); + DLIB_TEST(rect.bottom() == 0); + } + { + std::vector< dlib::vector > a; + + dlib::vector v; + dlib::rand rnd; + + for (int i = 0; i < 10; ++i) + { + v.x() = rnd.get_random_double(); + v.y() = rnd.get_random_double(); + v.z() = rnd.get_random_double(); + a.push_back(v); + + } + + // This test is just to make sure the covariance function can compile when used + // on a dlib::vector. The actual test doesn't matter. + DLIB_TEST(sum(covariance(mat(a))) < 10); + + } + + + DLIB_TEST(rectangle() + point(5,4) + point(10,10) == rectangle(5,4,10,10)); + + // make sure the center of a centered rectangle is always right + for (long x = -10; x <= 10; ++x) + { + for (long y = -10; y <= 10; ++y) + { + for (long w = 0; w < 10; ++w) + { + for (long h = 0; h < 10; ++h) + { + DLIB_TEST(center(centered_rect(x,y,w,h)) == point(x,y)); + } + } + } + } + + } + +// ---------------------------------------------------------------------------------------- + + void test_border_enumerator() + { + + + + border_enumerator be; + DLIB_TEST(be.at_start() == true); + DLIB_TEST(be.size() == 0); + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.move_next() == false); + DLIB_TEST(be.at_start() == false); + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.move_next() == false); + DLIB_TEST(be.at_start() == false); + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.size() == 0); + + be = border_enumerator(rectangle(4,4,4,4),1); + DLIB_TEST(be.at_start() == true); + DLIB_TEST(be.size() == 1); + be = border_enumerator(rectangle(4,4,4,4),3); + DLIB_TEST(be.at_start() == true); + DLIB_TEST(be.size() == 1); + be = border_enumerator(rectangle(4,4,4,4),0); + DLIB_TEST(be.at_start() == true); + DLIB_TEST(be.size() == 0); + be = border_enumerator(rectangle(4,4,5,5),0); + DLIB_TEST(be.at_start() == true); + DLIB_TEST(be.size() == 0); + be = border_enumerator(rectangle(4,4,5,5),1); + DLIB_TEST(be.at_start() == true); + DLIB_TEST(be.size() == 4); + be = border_enumerator(rectangle(4,4,5,5),2); + DLIB_TEST(be.size() == 4); + be = border_enumerator(rectangle(4,4,6,6),1); + DLIB_TEST(be.size() == 8); + be = border_enumerator(rectangle(4,4,6,6),2); + DLIB_TEST(be.size() == 9); + be = border_enumerator(rectangle(4,4,6,6),3); + DLIB_TEST(be.size() == 9); + DLIB_TEST(be.at_start() == true); + + array2d img, img2; + for (int size = 1; size < 10; ++size) + { + for (int bs = 0; bs < 4; ++bs) + { + img.set_size(size,size); + img2.set_size(size,size); + + assign_all_pixels(img, 1); + assign_all_pixels(img2, 1); + + zero_border_pixels(img2, bs,bs); + + be = border_enumerator(get_rect(img),bs); + DLIB_TEST(be.at_start() == true); + DLIB_TEST(be.current_element_valid() == false); + while (be.move_next()) + { + DLIB_TEST(be.at_start() == false); + DLIB_TEST(be.current_element_valid() == true); + DLIB_TEST_MSG(get_rect(img).contains(be.element()) == true, + get_rect(img) << " " << be.element() + ); + const point p = be.element(); + img[p.y()][p.x()] = 0; + } + DLIB_TEST(be.at_start() == false); + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.move_next() == false); + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.move_next() == false); + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.at_start() == false); + + DLIB_TEST(mat(img) == mat(img2)); + + } + } + + for (int size = 1; size < 10; ++size) + { + for (int bs = 0; bs < 4; ++bs) + { + img.set_size(size,size+5); + img2.set_size(size,size+5); + + assign_all_pixels(img, 1); + assign_all_pixels(img2, 1); + + zero_border_pixels(img2, bs,bs); + + const point shift(4,5); + + be = border_enumerator(translate_rect(get_rect(img),shift),bs); + DLIB_TEST(be.at_start() == true); + DLIB_TEST(be.current_element_valid() == false); + while (be.move_next()) + { + DLIB_TEST(be.current_element_valid() == true); + DLIB_TEST(be.at_start() == false); + DLIB_TEST_MSG(get_rect(img).contains(be.element()-shift) == true, + get_rect(img) << " " << be.element() + ); + const point p = be.element()-shift; + img[p.y()][p.x()] = 0; + } + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.move_next() == false); + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.move_next() == false); + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.at_start() == false); + + DLIB_TEST(mat(img) == mat(img2)); + + } + } + + for (int size = 1; size < 10; ++size) + { + for (int bs = 0; bs < 4; ++bs) + { + img.set_size(size+2,size); + img2.set_size(size+2,size); + + assign_all_pixels(img, 1); + assign_all_pixels(img2, 1); + + zero_border_pixels(img2, bs,bs); + + const point shift(-4,5); + + be = border_enumerator(translate_rect(get_rect(img),shift),bs); + DLIB_TEST(be.current_element_valid() == false); + while (be.move_next()) + { + DLIB_TEST(be.current_element_valid() == true); + DLIB_TEST_MSG(get_rect(img).contains(be.element()-shift) == true, + get_rect(img) << " " << be.element() + ); + const point p = be.element()-shift; + img[p.y()][p.x()] = 0; + } + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.move_next() == false); + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.move_next() == false); + DLIB_TEST(be.current_element_valid() == false); + + DLIB_TEST(mat(img) == mat(img2)); + + } + } + + { + matrix hits, truth; + const rectangle rect = rectangle(1,1,4,3); + + border_enumerator be(rect, rectangle(2,2, 3, 3)); + DLIB_TEST(be.size() == 8); + hits = false; + while (be.move_next()) + { + DLIB_TEST(rect.contains(be.element())); + hits(be.element().y(), be.element().x()) = true; + } + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.size() == 8); + truth = false; + truth(1,1) = truth(1,2) = truth(1,3) = truth(1,4) = truth(2,1) = + truth(3,1) = truth(2,4) = truth(3,4) = true; + DLIB_TEST_MSG(truth == hits, truth << endl << hits); + + + + + be = border_enumerator(rect, rectangle(0,0, 9, 9)); + DLIB_TEST(be.size() == 0); + hits = false; + while (be.move_next()) + { + DLIB_TEST(rect.contains(be.element())); + hits(be.element().y(), be.element().x()) = true; + } + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.size() == 0); + truth = false; + DLIB_TEST(truth == hits); + + + + be = border_enumerator(rect, rectangle(0,0, 3, 9)); + DLIB_TEST(be.size() == 3); + hits = false; + while (be.move_next()) + { + DLIB_TEST(rect.contains(be.element())); + hits(be.element().y(), be.element().x()) = true; + } + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.size() == 3); + truth = false; + truth(1,4) = truth(2,4) = truth(3,4) = true; + DLIB_TEST(truth == hits); + + + + + be = border_enumerator(rect, rectangle(2,1, 4, 3)); + DLIB_TEST(be.size() == 3); + hits = false; + while (be.move_next()) + { + DLIB_TEST(rect.contains(be.element())); + hits(be.element().y(), be.element().x()) = true; + } + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.size() == 3); + truth = false; + truth(1,1) = truth(2,1) = truth(3,1) = true; + DLIB_TEST(truth == hits); + + + + be = border_enumerator(rect, rectangle(1,1, 5, 2)); + DLIB_TEST(be.size() == 4); + hits = false; + while (be.move_next()) + { + DLIB_TEST(rect.contains(be.element())); + hits(be.element().y(), be.element().x()) = true; + } + DLIB_TEST(be.current_element_valid() == false); + DLIB_TEST(be.size() == 4); + truth = false; + truth(3,1) = truth(3,2) = truth(3,3) = truth(3,4) = true; + DLIB_TEST(truth == hits); + + + + } + + } + +// ---------------------------------------------------------------------------------------- + + void test_find_affine_transform() + { + //typedef dlib::vector vect; + typedef point vect; + std::vector from, to; + + from.push_back(vect(0,0)); + to.push_back(vect(0,1)); + + from.push_back(vect(0,1)); + to.push_back(vect(1,1)); + + from.push_back(vect(1,1)); + to.push_back(vect(1,0)); + + from.push_back(vect(1,0)); + to.push_back(vect(0,0)); + + point_transform_affine t = find_affine_transform(from,to); + point_transform_affine tinv = inv(t); + + for (unsigned long i = 0; i < from.size(); ++i) + { + dlog << LINFO << "affine transformation error: "<< length(t(from[i])-to[i]); + DLIB_TEST(length(t(from[i])-to[i]) < 1e-14); + DLIB_TEST(length(tinv(t(from[i]))-from[i]) < 1e-14); + DLIB_TEST(length(t(tinv(from[i]))-from[i]) < 1e-14); + + point_transform_affine temp = t*inv(t); + DLIB_TEST(length(temp.get_b()) < 1e-14); + DLIB_TEST(max(abs(temp.get_m() - identity_matrix(2))) < 1e-14); + } + + ostringstream sout; + serialize(t, sout); + istringstream sin(sout.str()); + point_transform_affine t2; + DLIB_TEST(length(t2(point(2,3)) - point(2,3)) < 1e-14); + deserialize(t2, sin); + DLIB_TEST(max(abs(t2.get_m()-t.get_m())) < 1e-14); + DLIB_TEST(max(abs(t2.get_b()-t.get_b())) < 1e-14); + } + +// ---------------------------------------------------------------------------------------- + + double projective_transform_pass_rate(const double error_rate) + { + print_spinner(); + dlog << LINFO << "projective_transform_pass_rate, error_rate: "<< error_rate; + dlib::rand rnd; + running_stats pass_rate; + for (int rounds = 0; rounds < 1000; ++rounds) + { + running_stats rs, rs_true; + matrix H = 2*(randm(3,3,rnd)-0.5); + + H(0,2) = rnd.get_random_gaussian()*10; + H(1,2) = rnd.get_random_gaussian()*10; + + + H(2,0) = rnd.get_random_double()*2.1; + H(2,1) = rnd.get_random_double()*2.1; + H(2,2) = 1 + rnd.get_random_gaussian()*3.1; + + point_transform_projective tran(H); + point_transform_projective traninv = inv(tran); + + const int num = rnd.get_random_32bit_number()%8 + 4; + + std::vector > from_points, to_points; + for (int i = 0; i < num; ++i) + { + dlib::vector p = randm(2,1,rnd)*1000; + from_points.push_back(p); + to_points.push_back(tran(p) + (randm(2,1,rnd)-0.5)*error_rate); + DLIB_TEST(length(traninv(tran(p))-p) <= 1e-5); + DLIB_TEST(length(tran(traninv(p))-p) <= 1e-5); + + point_transform_projective temp = tran*traninv; + DLIB_TEST_MSG(max(abs(temp.get_m() - identity_matrix(3))) < 1e-10, temp.get_m()); + temp = traninv*tran; + DLIB_TEST_MSG(max(abs(temp.get_m() - identity_matrix(3))) < 1e-10, temp.get_m()); + } + + + point_transform_projective tran2 = find_projective_transform(from_points, to_points); + + for (unsigned long i = 0; i < from_points.size(); ++i) + { + const double err = length_squared(tran2(from_points[i]) - to_points[i]); + rs.add(err); + const double err_true = length_squared(tran(from_points[i]) - to_points[i]); + rs_true.add(err_true); + } + + if ( rs.mean() < 0.01) + { + pass_rate.add(1); + } + else + { + dlog << LINFO << " errors: mean/max: " << rs.mean() << " " << rs.max(); + pass_rate.add(0); + } + + ostringstream sout; + serialize(tran, sout); + istringstream sin(sout.str()); + point_transform_projective tran3; + DLIB_TEST(length(tran3(point(2,3)) - point(2,3)) < 1e-14); + deserialize(tran3, sin); + DLIB_TEST(max(abs(tran3.get_m()-tran.get_m())) < 1e-14); + } + + dlog << LINFO << " pass_rate.mean(): "<< pass_rate.mean(); + return pass_rate.mean(); + } + +// ---------------------------------------------------------------------------------------- + + template + void test_find_similarity_transform() + { + print_spinner(); + std::vector > from_points, to_points; + + from_points.push_back(dlib::vector(0,0)); + from_points.push_back(dlib::vector(0,1)); + from_points.push_back(dlib::vector(1,0)); + + to_points.push_back(dlib::vector(8,0)); + to_points.push_back(dlib::vector(6,0)); + to_points.push_back(dlib::vector(8,2)); + + point_transform_affine tform = find_similarity_transform(from_points, to_points); + + for (unsigned long i = 0; i < from_points.size(); ++i) + { + DLIB_TEST(length(tform(from_points[i]) - to_points[i]) < 1e-14); + } + } + + template + void test_find_similarity_transform2() + { + print_spinner(); + std::vector > from_points, to_points; + + from_points.push_back(dlib::vector(0,0)); + from_points.push_back(dlib::vector(0,1)); + + to_points.push_back(dlib::vector(8,0)); + to_points.push_back(dlib::vector(6,0)); + + point_transform_affine tform = find_similarity_transform(from_points, to_points); + + for (unsigned long i = 0; i < from_points.size(); ++i) + { + DLIB_TEST(length(tform(from_points[i]) - to_points[i]) < 1e-14); + } + } + + +// ---------------------------------------------------------------------------------------- + + void test_rect_to_drect() + { + print_spinner(); + dlib::rand rnd; + for (int i = 0; i < 5000; ++i) + { + rectangle rect = centered_rect(rnd.get_random_32bit_number()%100, + rnd.get_random_32bit_number()%100, + rnd.get_random_32bit_number()%100, + rnd.get_random_32bit_number()%100); + + drectangle drect = rect; + rectangle rect2 = drect; + DLIB_TEST(rect2 == rect); + DLIB_TEST(rect.width() == drect.width()); + DLIB_TEST(rect.height() == drect.height()); + DLIB_TEST(dcenter(rect) == dcenter(drect)); + DLIB_TEST(rect.is_empty() == drect.is_empty()); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_affine3d() + { + const dlib::vector x(1,0,0); + const dlib::vector y(0,1,0); + const dlib::vector z(0,0,1); + const dlib::vector e(1,1,1); + const dlib::vector ex(-1,1,1); + const dlib::vector ey(1,-1,1); + const dlib::vector ez(1,1,-1); + + dlib::vector w; + + w = rotate_around_z(pi/2)(x); + DLIB_TEST(length(w-y) < 1e-12); + w = rotate_around_z(pi/2)(e); + DLIB_TEST(length(w-ex) < 1e-12); + + w = rotate_around_y(-pi/2)(x); + DLIB_TEST(length(w-z) < 1e-12); + w = rotate_around_y(pi/2)(e); + DLIB_TEST(length(w-ez) < 1e-12); + + w = rotate_around_x(pi/2)(y); + DLIB_TEST(length(w-z) < 1e-12); + w = rotate_around_x(pi/2)(e); + DLIB_TEST(length(w-ey) < 1e-12); + + w = translate_point(x)(y); + DLIB_TEST(length(w-x-y) < 1e-12); + + point_transform_affine3d tform; + tform = rotate_around_x(pi/2)*rotate_around_z(pi/2)*translate_point(x); + DLIB_TEST(length(tform(dlib::vector())-z) < 1e-12); + DLIB_TEST(length(inv(tform)(z)) < 1e-12); + + point_transform_affine tform2; + tform = tform*tform2;// the default tform is the identity mapping so this shouldn't do anything different + DLIB_TEST(length(tform(dlib::vector())-z) < 1e-12); + DLIB_TEST(length(inv(tform)(z)) < 1e-12); + } + +// ---------------------------------------------------------------------------------------- + + class geometry_tester : public tester + { + public: + geometry_tester ( + ) : + tester ("test_geometry", + "Runs tests on the geometry stuff.") + {} + + void perform_test ( + ) + { + test_affine3d(); + test_rect_to_drect(); + geometry_test(); + test_border_enumerator(); + test_find_affine_transform(); + DLIB_TEST(projective_transform_pass_rate(0.1) > 0.99); + DLIB_TEST(projective_transform_pass_rate(0.0) == 1); + + test_find_similarity_transform(); + test_find_similarity_transform2(); + test_find_similarity_transform(); + test_find_similarity_transform2(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/global_optimization.cpp b/ml/dlib/dlib/test/global_optimization.cpp new file mode 100644 index 000000000..fee80c81f --- /dev/null +++ b/ml/dlib/dlib/test/global_optimization.cpp @@ -0,0 +1,302 @@ +// Copyright (C) 2017 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.global_optimization"); + +// ---------------------------------------------------------------------------------------- + + void test_upper_bound_function(double relative_noise_magnitude, double solver_eps) + { + print_spinner(); + + dlog << LINFO << "test_upper_bound_function, relative_noise_magnitude="<< relative_noise_magnitude << ", solver_eps=" << solver_eps; + + auto rosen = [](const matrix& x) { return -1*( 100*std::pow(x(1) - x(0)*x(0),2.0) + std::pow(1 - x(0),2)); }; + + dlib::rand rnd; + auto make_rnd = [&rnd]() { matrix x(2); x = 2*rnd.get_random_double(), 2*rnd.get_random_double(); return x; }; + + + std::vector evals; + for (int i = 0; i < 100; ++i) + { + auto x = make_rnd(); + evals.emplace_back(x,rosen(x)); + } + + upper_bound_function ub(evals, relative_noise_magnitude, solver_eps); + DLIB_TEST(ub.num_points() == (long)evals.size()); + DLIB_TEST(ub.dimensionality() == 2); + for (auto& ev : evals) + { + dlog << LINFO << ub(ev.x) - ev.y; + DLIB_TEST_MSG(ub(ev.x) - ev.y > -1e10, ub(ev.x) - ev.y); + } + + + for (int i = 0; i < 100; ++i) + { + auto x = make_rnd(); + evals.emplace_back(x,rosen(x)); + ub.add(evals.back()); + } + + DLIB_TEST(ub.num_points() == (long)evals.size()); + DLIB_TEST(ub.dimensionality() == 2); + + for (auto& ev : evals) + { + dlog << LINFO << ub(ev.x) - ev.y; + DLIB_TEST_MSG(ub(ev.x) - ev.y > -1e10, ub(ev.x) - ev.y); + } + + + if (solver_eps < 0.001) + { + dlog << LINFO << "out of sample points: "; + for (int i = 0; i < 10; ++i) + { + auto x = make_rnd(); + dlog << LINFO << ub(x) - rosen(x); + DLIB_TEST_MSG(ub(x) - rosen(x) > 1e-10, ub(x) - rosen(x)); + } + } + } + +// ---------------------------------------------------------------------------------------- + + double complex_holder_table ( double x0, double x1) + { + // The regular HolderTable function + //return -std::abs(sin(x0)*cos(x1)*exp(std::abs(1-std::sqrt(x0*x0+x1*x1)/pi))); + + // My more complex version of it with discontinuities and more local minima. + double sign = 1; + for (double j = -4; j < 9; j += 0.5) + { + if (j < x0 && x0 < j+0.5) + x0 += sign*0.25; + sign *= -1; + } + // HolderTable function tilted towards 10,10 + return -std::abs(sin(x0)*cos(x1)*exp(std::abs(1-std::sqrt(x0*x0+x1*x1)/pi))) +(x0+x1)/10 + sin(x0*10)*cos(x1*10); + } + +// ---------------------------------------------------------------------------------------- + + void test_global_function_search() + { + + function_spec spec{{-10,-10}, {10,10}}; + function_spec spec2{{-10,-10, -50}, {10,10, 50}}; + global_function_search opt({spec, spec, spec2}); + + dlib::rand rnd; + bool found_optimal_point = false; + for (int i = 0; i < 400 && !found_optimal_point; ++i) + { + print_spinner(); + std::vector nexts; + for (int k = 0; k < rnd.get_integer_in_range(1,4); ++k) + nexts.emplace_back(opt.get_next_x()); + + for (auto& next : nexts) + { + switch (next.function_idx()) + { + case 0: next.set( -complex_holder_table(next.x()(0), next.x()(1))); break; + case 1: next.set( -10*complex_holder_table(next.x()(0), next.x()(1))); break; + case 2: next.set( -2*complex_holder_table(next.x()(0), next.x()(1))); break; + default: DLIB_TEST(false); break; + } + + matrix x; + double y; + size_t function_idx; + opt.get_best_function_eval(x,y,function_idx); + /* + cout << "\ni: "<< i << endl; + cout << "best eval x: "<< trans(x); + cout << "best eval y: "<< y << endl; + cout << "best eval function index: "<< function_idx << endl; + */ + + if (std::abs(y - 10*21.9210397) < 0.0001) + { + found_optimal_point = true; + break; + } + } + } + + DLIB_TEST(found_optimal_point); + } + +// ---------------------------------------------------------------------------------------- + + void test_find_max_global( + ) + { + print_spinner(); + auto rosen = [](const matrix& x) { return -1*( 100*std::pow(x(1) - x(0)*x(0),2.0) + std::pow(1 - x(0),2)); }; + + auto result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100), 0); + matrix true_x = {1,1}; + + dlog << LINFO << "rosen: " << trans(result.x); + DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); + print_spinner(); + + result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100)); + dlog << LINFO << "rosen: " << trans(result.x); + DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); + print_spinner(); + + result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, std::chrono::seconds(5)); + dlog << LINFO << "rosen: " << trans(result.x); + DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); + print_spinner(); + + result = find_max_global(rosen, {0.1, 0.1}, {2, 2}, {false,false}, max_function_calls(100)); + dlog << LINFO << "rosen: " << trans(result.x); + DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); + print_spinner(); + + result = find_max_global(rosen, {0.1, 0.1}, {0.9, 0.9}, {false,false}, max_function_calls(140)); + true_x = {0.9, 0.81}; + dlog << LINFO << "rosen, bounded at 0.9: " << trans(result.x); + DLIB_TEST_MSG(max(abs(true_x-result.x)) < 1e-5, max(abs(true_x-result.x))); + print_spinner(); + + result = find_max_global([](double x){ return -std::pow(x-2,2.0); }, -10, 10, max_function_calls(10), 0); + dlog << LINFO << "(x-2)^2: " << trans(result.x); + DLIB_TEST(result.x.size()==1); + DLIB_TEST(std::abs(result.x - 2) < 1e-9); + print_spinner(); + + result = find_max_global([](double x){ return -std::pow(x-2,2.0); }, -10, 1, max_function_calls(10)); + dlog << LINFO << "(x-2)^2, bound at 1: " << trans(result.x); + DLIB_TEST(result.x.size()==1); + DLIB_TEST(std::abs(result.x - 1) < 1e-9); + print_spinner(); + + result = find_max_global([](double x){ return -std::pow(x-2,2.0); }, -10, 1, std::chrono::seconds(2)); + dlog << LINFO << "(x-2)^2, bound at 1: " << trans(result.x); + DLIB_TEST(result.x.size()==1); + DLIB_TEST(std::abs(result.x - 1) < 1e-9); + print_spinner(); + + + result = find_max_global([](double a, double b){ return -complex_holder_table(a,b);}, + {-10, -10}, {10, 10}, max_function_calls(400), 0); + dlog << LINFO << "complex_holder_table y: "<< result.y; + DLIB_TEST_MSG(std::abs(result.y - 21.9210397) < 0.0001, std::abs(result.y - 21.9210397)); + } + +// ---------------------------------------------------------------------------------------- + + void test_find_min_global( + ) + { + print_spinner(); + auto rosen = [](const matrix& x) { return +1*( 100*std::pow(x(1) - x(0)*x(0),2.0) + std::pow(1 - x(0),2)); }; + + auto result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100), 0); + matrix true_x = {1,1}; + + dlog << LINFO << "rosen: " << trans(result.x); + DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x))); + print_spinner(); + + result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, max_function_calls(100)); + dlog << LINFO << "rosen: " << trans(result.x); + DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x))); + print_spinner(); + + result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, std::chrono::seconds(5)); + dlog << LINFO << "rosen: " << trans(result.x); + DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x))); + print_spinner(); + + result = find_min_global(rosen, {0.1, 0.1}, {2, 2}, {false,false}, max_function_calls(100)); + dlog << LINFO << "rosen: " << trans(result.x); + DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x))); + print_spinner(); + + result = find_min_global(rosen, {0, 0}, {0.9, 0.9}, {false,false}, max_function_calls(100)); + true_x = {0.9, 0.81}; + dlog << LINFO << "rosen, bounded at 0.9: " << trans(result.x); + DLIB_TEST_MSG(min(abs(true_x-result.x)) < 1e-5, min(abs(true_x-result.x))); + print_spinner(); + + result = find_min_global([](double x){ return std::pow(x-2,2.0); }, -10, 10, max_function_calls(10), 0); + dlog << LINFO << "(x-2)^2: " << trans(result.x); + DLIB_TEST(result.x.size()==1); + DLIB_TEST(std::abs(result.x - 2) < 1e-9); + print_spinner(); + + result = find_min_global([](double x){ return std::pow(x-2,2.0); }, -10, 1, max_function_calls(10)); + dlog << LINFO << "(x-2)^2, bound at 1: " << trans(result.x); + DLIB_TEST(result.x.size()==1); + DLIB_TEST(std::abs(result.x - 1) < 1e-9); + print_spinner(); + + result = find_min_global([](double x){ return std::pow(x-2,2.0); }, -10, 1, std::chrono::seconds(2)); + dlog << LINFO << "(x-2)^2, bound at 1: " << trans(result.x); + DLIB_TEST(result.x.size()==1); + DLIB_TEST(std::abs(result.x - 1) < 1e-9); + print_spinner(); + + + result = find_min_global([](double a, double b){ return complex_holder_table(a,b);}, + {-10, -10}, {10, 10}, max_function_calls(400), 0); + dlog << LINFO << "complex_holder_table y: "<< result.y; + DLIB_TEST_MSG(std::abs(result.y + 21.9210397) < 0.0001, std::abs(result.y + 21.9210397)); + } + +// ---------------------------------------------------------------------------------------- + + class global_optimization_tester : public tester + { + public: + global_optimization_tester ( + ) : + tester ("test_global_optimization", + "Runs tests on the global optimization components.") + {} + + void perform_test ( + ) + { + test_upper_bound_function(0.01, 1e-6); + test_upper_bound_function(0.0, 1e-6); + test_upper_bound_function(0.0, 1e-1); + test_global_function_search(); + test_find_max_global(); + test_find_min_global(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/graph.cpp b/ml/dlib/dlib/test/graph.cpp new file mode 100644 index 000000000..0651f3049 --- /dev/null +++ b/ml/dlib/dlib/test/graph.cpp @@ -0,0 +1,414 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +// This is called an unnamed-namespace and it has the effect of making everything inside this file "private" +// so that everything you declare will have static linkage. Thus we won't have any multiply +// defined symbol errors coming out of the linker when we try to compile the test suite. +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + + + // Declare the logger we will use in this test. The name of the tester + // should start with "test." + logger dlog("test.graph"); + + template < + typename graph + > + void graph_test ( + ) + /*! + requires + - graph is an implementation of graph/graph_kernel_abstract.h + is instantiated with int + ensures + - runs tests on graph for compliance with the specs + !*/ + { + + print_spinner(); + + COMPILE_TIME_ASSERT(is_graph::value); + + graph a, b; + dlib::set::compare_1b_c s; + + DLIB_TEST(graph_contains_length_one_cycle(a) == false); + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + + DLIB_TEST(a.number_of_nodes() == 0); + + a.set_number_of_nodes(5); + DLIB_TEST(graph_is_connected(a) == false); + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + DLIB_TEST(a.number_of_nodes() == 5); + DLIB_TEST(graph_contains_length_one_cycle(a) == false); + + for (int i = 0; i < 5; ++i) + { + a.node(i).data = i; + DLIB_TEST(a.node(i).index() == (unsigned int)i); + } + + a.remove_node(1); + + DLIB_TEST(a.number_of_nodes() == 4); + + + // make sure that only the number with data == 1 was removed + int count = 0; + for (int i = 0; i < 4; ++i) + { + count += a.node(i).data; + DLIB_TEST(a.node(i).number_of_neighbors() == 0); + DLIB_TEST(a.node(i).index() == (unsigned int)i); + } + + DLIB_TEST(count == 9); + + + a.add_edge(1,1); + DLIB_TEST(graph_contains_length_one_cycle(a) == true); + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + DLIB_TEST(a.has_edge(1,1)); + DLIB_TEST(a.node(1).number_of_neighbors() == 1); + + a.add_edge(1,3); + DLIB_TEST(a.node(1).number_of_neighbors() == 2); + DLIB_TEST(a.node(2).number_of_neighbors() == 0); + DLIB_TEST(a.node(3).number_of_neighbors() == 1); + DLIB_TEST(a.has_edge(1,1)); + DLIB_TEST(a.has_edge(1,3)); + DLIB_TEST(a.has_edge(3,1)); + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + a.remove_edge(1,1); + DLIB_TEST(graph_contains_length_one_cycle(a) == false); + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + DLIB_TEST(a.node(1).number_of_neighbors() == 1); + DLIB_TEST(a.node(2).number_of_neighbors() == 0); + DLIB_TEST(a.node(3).number_of_neighbors() == 1); + DLIB_TEST(a.has_edge(1,1) == false); + DLIB_TEST(a.has_edge(1,3)); + DLIB_TEST(a.has_edge(3,1)); + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + + swap(a,b); + + + DLIB_TEST(graph_contains_undirected_cycle(b) == false); + DLIB_TEST(b.node(1).number_of_neighbors() == 1); + DLIB_TEST(b.node(2).number_of_neighbors() == 0); + DLIB_TEST(b.node(3).number_of_neighbors() == 1); + DLIB_TEST(b.has_edge(1,1) == false); + DLIB_TEST(b.has_edge(1,3)); + DLIB_TEST(b.has_edge(3,1)); + DLIB_TEST(graph_contains_undirected_cycle(b) == false); + + DLIB_TEST(a.number_of_nodes() == 0); + DLIB_TEST(b.number_of_nodes() == 4); + + copy_graph_structure(b,b); + DLIB_TEST(b.number_of_nodes() == 4); + + b.add_edge(1,2); + DLIB_TEST(graph_contains_undirected_cycle(b) == false); + DLIB_TEST(graph_contains_undirected_cycle(b) == false); + b.add_edge(3,2); + DLIB_TEST(graph_contains_undirected_cycle(b) == true); + b.add_edge(1,1); + DLIB_TEST(graph_is_connected(b) == false); + b.add_edge(0,2); + DLIB_TEST(graph_is_connected(b) == true); + + DLIB_TEST(graph_contains_undirected_cycle(b) == true); + + DLIB_TEST(a.number_of_nodes() == 0); + + for (unsigned long i = 0; i < b.number_of_nodes(); ++i) + { + for (unsigned long j = 0; j < b.node(i).number_of_neighbors(); ++j) + { + b.node(i).edge(j) = 'c'; + } + } + + b.node(1).edge(0) = 'a'; + const unsigned long e1 = b.node(1).neighbor(0).index(); + b.node(0).edge(0) = 'n'; + const unsigned long e2 = b.node(0).neighbor(0).index(); + + ostringstream sout; + serialize(b, sout); + istringstream sin(sout.str()); + + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + + a.set_number_of_nodes(10); + deserialize(a, sin); + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + + for (unsigned long i = 0; i < a.number_of_nodes(); ++i) + { + for (unsigned long j = 0; j < a.node(i).number_of_neighbors(); ++j) + { + if ((i == 0 && a.node(i).neighbor(j).index() == e2) || + (i == e2 && a.node(i).neighbor(j).index() == 0) ) + { + DLIB_TEST(a.node(i).edge(j) == 'n'); + } + else if ((i == 1 && a.node(i).neighbor(j).index() == e1) || + (i == e1 && a.node(i).neighbor(j).index() == 1)) + { + DLIB_TEST(a.node(i).edge(j) == 'a'); + } + else + { + DLIB_TEST(i != 0 || a.node(i).neighbor(j).index() != e2); + DLIB_TEST_MSG(a.node(i).edge(j) == 'c',a.node(i).edge(j)); + } + } + } + + DLIB_TEST(a.number_of_nodes() == 4); + DLIB_TEST(a.has_edge(1,2) == true); + DLIB_TEST(a.has_edge(3,2) == true); + DLIB_TEST(a.has_edge(1,1) == true); + DLIB_TEST(a.has_edge(0,2) == true); + DLIB_TEST(a.has_edge(1,3) == true); + DLIB_TEST(a.has_edge(0,1) == false); + DLIB_TEST(a.has_edge(0,3) == false); + DLIB_TEST(a.has_edge(0,0) == false); + DLIB_TEST(a.has_edge(1,0) == false); + DLIB_TEST(a.has_edge(3,0) == false); + + + for (unsigned long i = 0; i < a.number_of_nodes(); ++i) + { + a.node(i).data = static_cast(i); + } + + a.remove_node(2); + DLIB_TEST(a.number_of_nodes() == 3); + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + + count = 0; + for (unsigned long i = 0; i < a.number_of_nodes(); ++i) + { + if (a.node(i).data == 0) + { + DLIB_TEST(a.node(i).number_of_neighbors() == 0); + } + else if (a.node(i).data == 1) + { + DLIB_TEST(a.node(i).number_of_neighbors() == 2); + } + else if (a.node(i).data == 3) + { + DLIB_TEST(a.node(i).number_of_neighbors() == 1); + } + else + { + DLIB_TEST_MSG(false,"this is impossible"); + } + + for (unsigned long j = 0; j < a.number_of_nodes(); ++j) + { + if ((a.node(i).data == 1 && a.node(j).data == 1) || + (a.node(i).data == 1 && a.node(j).data == 3) || + (a.node(i).data == 3 && a.node(j).data == 1)) + { + DLIB_TEST(a.has_edge(i,j) == true); + ++count; + } + else + { + DLIB_TEST(a.has_edge(i,j) == false); + } + } + } + DLIB_TEST_MSG(count == 3,count); + DLIB_TEST(graph_contains_undirected_cycle(a) == true); + a.remove_edge(1,1); + DLIB_TEST(graph_contains_undirected_cycle(a) == false); + + DLIB_TEST(b.number_of_nodes() == 4); + b.clear(); + DLIB_TEST(b.number_of_nodes() == 0); + + + a.clear(); + + /* + 1 7 + | / \ + 2 6 0 + \ / | + 3 / + / \ / + 4 5 + */ + a.set_number_of_nodes(8); + a.add_edge(1,2); + a.add_edge(2,3); + a.add_edge(3,4); + a.add_edge(3,5); + a.add_edge(3,6); + a.add_edge(6,7); + a.add_edge(7,0); + a.add_edge(0,5); + + DLIB_TEST(graph_is_connected(a)); + + dlib::set::compare_1b_c>::kernel_1b_c sos; + + dlib::graph::compare_1b_c, dlib::set::compare_1b_c>::kernel_1a_c join_tree; + unsigned long temp; + triangulate_graph_and_find_cliques(a,sos); + DLIB_TEST(a.number_of_nodes() == 8); + + create_join_tree(a, join_tree); + DLIB_TEST(join_tree.number_of_nodes() == 6); + DLIB_TEST(graph_is_connected(join_tree) == true); + DLIB_TEST(graph_contains_undirected_cycle(join_tree) == false); + DLIB_TEST(is_join_tree(a, join_tree)); + + // check old edges + DLIB_TEST(a.has_edge(1,2)); + DLIB_TEST(a.has_edge(2,3)); + DLIB_TEST(a.has_edge(3,4)); + DLIB_TEST(a.has_edge(3,5)); + DLIB_TEST(a.has_edge(3,6)); + DLIB_TEST(a.has_edge(6,7)); + DLIB_TEST(a.has_edge(7,0)); + DLIB_TEST(a.has_edge(0,5)); + + DLIB_TEST(graph_is_connected(a)); + + DLIB_TEST(sos.size() == 6); + + + temp = 1; s.add(temp); + temp = 2; s.add(temp); + DLIB_TEST(sos.is_member(s)); + s.clear(); + temp = 2; s.add(temp); + temp = 3; s.add(temp); + DLIB_TEST(sos.is_member(s)); + s.clear(); + temp = 4; s.add(temp); + temp = 3; s.add(temp); + DLIB_TEST(sos.is_member(s)); + + sos.reset(); + while (sos.move_next()) + { + DLIB_TEST(is_clique(a, sos.element())); + DLIB_TEST(is_maximal_clique(a, sos.element())); + } + + } + + + void test_copy() + { + { + graph::kernel_1a_c a,b; + + a.set_number_of_nodes(3); + a.node(0).data = 1; + a.node(1).data = 2; + a.node(2).data = 3; + a.add_edge(0,1); + a.add_edge(0,2); + edge(a,0,1) = 4; + edge(a,0,2) = 5; + + a.add_edge(0,0); + edge(a,0,0) = 9; + copy_graph(a, b); + + DLIB_TEST(b.number_of_nodes() == 3); + DLIB_TEST(b.node(0).data == 1); + DLIB_TEST(b.node(1).data == 2); + DLIB_TEST(b.node(2).data == 3); + DLIB_TEST(edge(b,0,1) == 4); + DLIB_TEST(edge(b,0,2) == 5); + DLIB_TEST(edge(b,0,0) == 9); + } + { + graph::kernel_1a_c a,b; + + a.set_number_of_nodes(4); + a.node(0).data = 1; + a.node(1).data = 2; + a.node(2).data = 3; + a.node(3).data = 8; + a.add_edge(0,1); + a.add_edge(0,2); + a.add_edge(2,3); + edge(a,0,1) = 4; + edge(a,0,2) = 5; + edge(a,2,3) = 6; + + copy_graph(a, b); + + DLIB_TEST(b.number_of_nodes() == 4); + DLIB_TEST(b.node(0).data == 1); + DLIB_TEST(b.node(1).data == 2); + DLIB_TEST(b.node(2).data == 3); + DLIB_TEST(b.node(3).data == 8); + DLIB_TEST(edge(b,0,1) == 4); + DLIB_TEST(edge(b,0,2) == 5); + DLIB_TEST(edge(b,2,3) == 6); + } + } + + + + class graph_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a test for the graph object. When it is constructed + it adds itself into the testing framework. The command line switch is + specified as test_directed_graph by passing that string to the tester constructor. + !*/ + public: + graph_tester ( + ) : + tester ("test_graph", + "Runs tests on the graph component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a_c"; + graph_test::kernel_1a_c>(); + + dlog << LINFO << "testing kernel_1a"; + graph_test::kernel_1a>(); + + test_copy(); + } + } a; + + +} + + + diff --git a/ml/dlib/dlib/test/graph_cuts.cpp b/ml/dlib/dlib/test/graph_cuts.cpp new file mode 100644 index 000000000..43ba35c16 --- /dev/null +++ b/ml/dlib/dlib/test/graph_cuts.cpp @@ -0,0 +1,1217 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + + logger dlog("test.graph_cuts"); + +// ---------------------------------------------------------------------------------------- + + class dense_potts_problem + { + public: + typedef double value_type; + private: + + matrix factors1; + matrix factors2; + matrix labels; + public: + + dense_potts_problem ( + unsigned long num_nodes, + dlib::rand& rnd + ) + { + factors1 = -7*(randm(num_nodes, 1, rnd)-0.5); + factors2 = make_symmetric(randm(num_nodes, num_nodes, rnd) > 0.5); + labels.set_size(num_nodes); + labels = FREE_NODE; + } + + unsigned long number_of_nodes ( + ) const { return factors1.nr(); } + + unsigned long number_of_neighbors ( + unsigned long // idx + ) const { return number_of_nodes()-1; } + + unsigned long get_neighbor_idx ( + unsigned long node_id1, + unsigned long node_id2 + ) const + { + if (node_id2 < node_id1) + return node_id2; + else + return node_id2-1; + } + + unsigned long get_neighbor ( + unsigned long node_id, + unsigned long idx + ) const + { + DLIB_TEST(node_id < number_of_nodes()); + DLIB_TEST(idx < number_of_neighbors(node_id)); + if (idx < node_id) + return idx; + else + return idx+1; + } + + void set_label ( + const unsigned long& idx, + node_label value + ) + { + labels(idx) = value; + } + + node_label get_label ( + const unsigned long& idx + ) const + { + return labels(idx); + } + + + value_type factor_value (unsigned long idx) const + { + DLIB_TEST(idx < number_of_nodes()); + + return factors1(idx); + } + + value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const + { + DLIB_TEST(idx1 != idx2); + DLIB_TEST(idx1 < number_of_nodes()); + DLIB_TEST(idx2 < number_of_nodes()); + DLIB_TEST(get_neighbor_idx(idx1,idx2) < number_of_neighbors(idx1)); + DLIB_TEST(get_neighbor_idx(idx2,idx1) < number_of_neighbors(idx2)); + + return factors2(idx1, idx2); + } + + }; + +// ---------------------------------------------------------------------------------------- + + class image_potts_problem + { + public: + typedef double value_type; + const static unsigned long max_number_of_neighbors = 4; + private: + + matrix factors1; + matrix factors2; + matrix labels; + long nr; + long nc; + rectangle rect, inner_rect; + mutable long count; + public: + + image_potts_problem ( + long nr_, + long nc_, + dlib::rand& rnd + ) : nr(nr_), nc(nc_) + { + rect = rectangle(0,0,nc-1,nr-1); + inner_rect = shrink_rect(rect,1); + const unsigned long num_nodes = nr*nc; + factors1 = -7*(randm(num_nodes, 1, rnd)); + factors2 = randm(num_nodes, 4, rnd) > 0.5; + + //factors1 = 0; + //set_rowm(factors1, range(0, factors1.nr()/2)) = -1; + + labels.set_size(num_nodes); + labels = FREE_NODE; + + count = 0; + } + + ~image_potts_problem() + { + dlog << LTRACE << "interface calls: " << count; + dlog << LTRACE << "labels hash: "<< murmur_hash3_128bit(&labels(0), labels.size()*sizeof(labels(0)), 0).first; + } + + unsigned long number_of_nodes ( + ) const { return factors1.nr(); } + + unsigned long number_of_neighbors ( + unsigned long idx + ) const + { + ++count; + const point& p = get_loc(idx); + if (inner_rect.contains(p)) + return 4; + else if (p == rect.tl_corner() || + p == rect.bl_corner() || + p == rect.tr_corner() || + p == rect.br_corner() ) + return 2; + else + return 3; + } + + unsigned long get_neighbor_idx ( + long node_id1, + long node_id2 + ) const + { + ++count; + const point& p = get_loc(node_id1); + long ret = 0; + if (rect.contains(p + point(1,0))) + { + if (node_id2-node_id1 == 1) + return ret; + ++ret; + } + + if (rect.contains(p - point(1,0))) + { + if (node_id2-node_id1 == -1) + return ret; + ++ret; + } + + if (rect.contains(p + point(0,1))) + { + if (node_id2-node_id1 == nc) + return ret; + ++ret; + } + + return ret; + } + + unsigned long get_neighbor ( + long node_id, + long idx + ) const + { + ++count; + const point& p = get_loc(node_id); + if (rect.contains(p + point(1,0))) + { + if (idx == 0) + return node_id+1; + --idx; + } + + if (rect.contains(p - point(1,0))) + { + if (idx == 0) + return node_id-1; + --idx; + } + + if (rect.contains(p + point(0,1))) + { + if (idx == 0) + return node_id+nc; + --idx; + } + + return node_id-nc; + } + + void set_label ( + const unsigned long& idx, + node_label value + ) + { + ++count; + labels(idx) = value; + } + + node_label get_label ( + const unsigned long& idx + ) const + { + ++count; + return labels(idx); + } + + value_type factor_value (unsigned long idx) const + { + ++count; + DLIB_TEST(idx < (unsigned long)number_of_nodes()); + + return factors1(idx); + } + + value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const + { + ++count; + DLIB_TEST(idx1 != idx2); + DLIB_TEST(idx1 < (unsigned long)number_of_nodes()); + DLIB_TEST(idx2 < (unsigned long)number_of_nodes()); + + // make this function symmetric + if (idx1 > idx2) + swap(idx1,idx2); + + + DLIB_TEST(get_neighbor(idx1, get_neighbor_idx(idx1, idx2)) == idx2); + DLIB_TEST(get_neighbor(idx2, get_neighbor_idx(idx2, idx1)) == idx1); + + // the neighbor relationship better be symmetric + DLIB_TEST(get_neighbor_idx(idx1,idx2) < number_of_neighbors(idx1)); + DLIB_TEST_MSG(get_neighbor_idx(idx2,idx1) < number_of_neighbors(idx2), + "\n idx1: "<< idx1 << + "\n idx2: "<< idx2 << + "\n get_neighbor_idx(idx2,idx1): "<< get_neighbor_idx(idx2,idx1) << + "\n number_of_neighbors(idx2): " << number_of_neighbors(idx2) << + "\n nr: "<< nr << + "\n nc: "<< nc + ); + + return factors2(idx1, get_neighbor_idx(idx1,idx2)); + } + + private: + point get_loc ( + const unsigned long& idx + ) const + { + return point(idx%nc, idx/nc); + } + + }; + +// ---------------------------------------------------------------------------------------- + + template + void brute_force_potts_model ( + potts_model& g + ) + { + potts_model m(g); + + const unsigned long num = (unsigned long)std::pow(2.0, (double)m.number_of_nodes()); + + double best_score = -std::numeric_limits::infinity(); + for (unsigned long i = 0; i < num; ++i) + { + for (unsigned long j = 0; j < m.number_of_nodes(); ++j) + { + unsigned long T = (1)< best_score) + { + best_score = score; + g = m; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void brute_force_potts_model_on_graph ( + const graph_type& g, + std::vector& labels_ + ) + { + std::vector labels; + labels.resize(g.number_of_nodes()); + + const unsigned long num = (unsigned long)std::pow(2.0, (double)g.number_of_nodes()); + + double best_score = -std::numeric_limits::infinity(); + for (unsigned long i = 0; i < num; ++i) + { + for (unsigned long j = 0; j < g.number_of_nodes(); ++j) + { + unsigned long T = (1)< best_score) + { + best_score = score; + labels_ = labels; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void make_random_undirected_graph( + dlib::rand& rnd, + graph_type& g + ) + { + typedef typename graph_type::edge_type edge_weight_type; + g.clear(); + const unsigned int num_nodes = rnd.get_random_32bit_number()%8; + g.set_number_of_nodes(num_nodes); + + const unsigned int num_edges = static_cast(num_nodes*(num_nodes-1)/2*rnd.get_random_double() + 0.5); + + // add the right number of randomly selected edges + unsigned int count = 0; + while (count < num_edges) + { + unsigned long i = rnd.get_random_32bit_number()%g.number_of_nodes(); + unsigned long j = rnd.get_random_32bit_number()%g.number_of_nodes(); + if (i != j && g.has_edge(i, j) == false) + { + ++count; + g.add_edge(i, j); + edge(g, i, j) = static_cast(rnd.get_random_double()*50); + } + } + + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + g.node(i).data = static_cast(rnd.get_random_gaussian()*200); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_graph_potts_model( + dlib::rand& rnd + ) + { + using namespace std; + double brute_force_score; + double graph_cut_score; + + graph::kernel_1a_c temp; + make_random_undirected_graph(rnd,temp); + + { + std::vector labels; + + brute_force_potts_model_on_graph(temp, labels); + + for (unsigned long i = 0; i < temp.number_of_nodes(); ++i) + { + dlog << LTRACE << "node " << i << ": "<< (int)labels[i]; + } + + brute_force_score = potts_model_score(temp, labels); + dlog << LTRACE << "brute force score: "<< brute_force_score; + } + dlog << LTRACE << "******************"; + + { + std::vector labels; + find_max_factor_graph_potts(temp, labels); + DLIB_TEST(temp.number_of_nodes() == labels.size()); + + for (unsigned long i = 0; i < temp.number_of_nodes(); ++i) + { + dlog << LTRACE << "node " << i << ": "<< (int)labels[i]; + } + graph_cut_score = potts_model_score(temp, labels); + dlog << LTRACE << "graph cut score: "<< graph_cut_score; + } + + DLIB_TEST_MSG(graph_cut_score == brute_force_score, std::abs(graph_cut_score - brute_force_score)); + + dlog << LTRACE << "##################"; + dlog << LTRACE << "##################"; + dlog << LTRACE << "##################"; + } + +// ---------------------------------------------------------------------------------------- + + template + void impl_test_potts_model ( + potts_prob& p + ) + { + using namespace std; + double brute_force_score; + double graph_cut_score; + + { + potts_prob temp(p); + brute_force_potts_model(temp); + + for (unsigned long i = 0; i < temp.number_of_nodes(); ++i) + { + dlog << LTRACE << "node " << i << ": "<< (int)temp.get_label(i); + } + brute_force_score = potts_model_score(temp); + dlog << LTRACE << "brute force score: "<< brute_force_score; + } + dlog << LTRACE << "******************"; + + { + potts_prob temp(p); + find_max_factor_graph_potts(temp); + + for (unsigned long i = 0; i < temp.number_of_nodes(); ++i) + { + dlog << LTRACE << "node " << i << ": "<< (int)temp.get_label(i); + } + graph_cut_score = potts_model_score(temp); + dlog << LTRACE << "graph cut score: "<< graph_cut_score; + } + + DLIB_TEST_MSG(graph_cut_score == brute_force_score, std::abs(graph_cut_score - brute_force_score)); + + dlog << LTRACE << "##################"; + dlog << LTRACE << "##################"; + dlog << LTRACE << "##################"; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// BASIC MIN CUT STUFF +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + void brute_force_min_cut ( + directed_graph& g, + unsigned long source, + unsigned long sink + ) + { + typedef typename directed_graph::edge_type edge_weight_type; + const unsigned long num = (unsigned long)std::pow(2.0, (double)g.number_of_nodes()); + + std::vector best_cut(g.number_of_nodes(),FREE_NODE); + + edge_weight_type best_score = std::numeric_limits::max(); + for (unsigned long i = 0; i < num; ++i) + { + for (unsigned long j = 0; j < g.number_of_nodes(); ++j) + { + unsigned long T = (1)< + void print_graph( + const directed_graph& g + ) + { + using namespace std; + dlog << LTRACE << "number of nodes: "<< g.number_of_nodes(); + for (unsigned long i = 0; i < g.number_of_nodes(); ++i) + { + for (unsigned long n = 0; n < g.node(i).number_of_children(); ++n) + dlog << LTRACE << i << " -(" << g.node(i).child_edge(n) << ")-> " << g.node(i).child(n).index(); + } + } + + template + void copy_edge_weights ( + directed_graph& dest, + const directed_graph& src + ) + { + for (unsigned long i = 0; i < src.number_of_nodes(); ++i) + { + for (unsigned long n = 0; n < src.node(i).number_of_children(); ++n) + { + dest.node(i).child_edge(n) = src.node(i).child_edge(n); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void pick_random_source_and_sink ( + dlib::rand& rnd, + const graph_type& g, + unsigned long& source, + unsigned long& sink + ) + { + source = rnd.get_random_32bit_number()%g.number_of_nodes(); + sink = rnd.get_random_32bit_number()%g.number_of_nodes(); + while (sink == source) + sink = rnd.get_random_32bit_number()%g.number_of_nodes(); + } + +// ---------------------------------------------------------------------------------------- + + template + void make_random_graph( + dlib::rand& rnd, + dgraph_type& g, + unsigned long& source, + unsigned long& sink + ) + { + typedef typename dgraph_type::edge_type edge_weight_type; + g.clear(); + const unsigned int num_nodes = rnd.get_random_32bit_number()%7 + 2; + g.set_number_of_nodes(num_nodes); + + const unsigned int num_edges = static_cast(num_nodes*(num_nodes-1)/2*rnd.get_random_double() + 0.5); + + // add the right number of randomly selected edges + unsigned int count = 0; + while (count < num_edges) + { + unsigned long parent = rnd.get_random_32bit_number()%g.number_of_nodes(); + unsigned long child = rnd.get_random_32bit_number()%g.number_of_nodes(); + if (parent != child && g.has_edge(parent, child) == false) + { + ++count; + g.add_edge(parent, child); + edge(g, parent, child) = static_cast(rnd.get_random_double()*50); + + // have to have edges both ways + swap(parent, child); + g.add_edge(parent, child); + edge(g, parent, child) = static_cast(rnd.get_random_double()*50); + } + } + + pick_random_source_and_sink(rnd, g, source, sink); + } + +// ---------------------------------------------------------------------------------------- + + template + void make_random_chain_graph( + dlib::rand& rnd, + dgraph_type& g, + unsigned long& source, + unsigned long& sink + ) + { + typedef typename dgraph_type::edge_type edge_weight_type; + g.clear(); + const unsigned int num_nodes = rnd.get_random_32bit_number()%7 + 2; + g.set_number_of_nodes(num_nodes); + + for (unsigned long i = 1; i < g.number_of_nodes(); ++i) + { + g.add_edge(i,i-1); + g.add_edge(i-1,i); + edge(g, i, i-1) = static_cast(rnd.get_random_double()*50); + edge(g, i-1, i) = static_cast(rnd.get_random_double()*50); + } + + pick_random_source_and_sink(rnd, g, source, sink); + } + +// ---------------------------------------------------------------------------------------- + + template + void make_random_grid_graph( + dlib::rand& rnd, + dgraph_type& g, + unsigned long& source, + unsigned long& sink + ) + /*! + ensures + - makes a grid graph like the kind used for potts models. + !*/ + { + typedef typename dgraph_type::edge_type edge_weight_type; + g.clear(); + const long nr = rnd.get_random_32bit_number()%2 + 2; + const long nc = rnd.get_random_32bit_number()%2 + 2; + g.set_number_of_nodes(nr*nc+2); + + const rectangle rect(0,0,nc-1,nr-1); + for (long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + const point p(c,r); + const unsigned long i = p.y()*nc + p.x(); + + const point n2(c-1,r); + if (rect.contains(n2)) + { + const unsigned long j = n2.y()*nc + n2.x(); + g.add_edge(i,j); + g.add_edge(j,i); + edge(g,i,j) = static_cast(rnd.get_random_double()*50); + edge(g,j,i) = static_cast(rnd.get_random_double()*50); + } + + const point n4(c,r-1); + if (rect.contains(n4)) + { + const unsigned long j = n4.y()*nc + n4.x(); + g.add_edge(i,j); + g.add_edge(j,i); + edge(g,i,j) = static_cast(rnd.get_random_double()*50); + edge(g,j,i) = static_cast(rnd.get_random_double()*50); + } + } + } + + // use the last two nodes as source and sink. Also connect them to all the other nodes. + source = g.number_of_nodes()-1; + sink = g.number_of_nodes()-2; + for (unsigned long i = 0; i < g.number_of_nodes()-2; ++i) + { + g.add_edge(i,source); + g.add_edge(source,i); + g.add_edge(i,sink); + g.add_edge(sink,i); + + edge(g,i,source) = static_cast(rnd.get_random_double()*50); + edge(g,source,i) = static_cast(rnd.get_random_double()*50); + edge(g,i,sink) = static_cast(rnd.get_random_double()*50); + edge(g,sink,i) = static_cast(rnd.get_random_double()*50); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void run_test_on_graphs ( + const min_cut& mc, + dgraph_type& g1, + dgraph_type& g2, + unsigned long source, + unsigned long sink + ) + { + typedef typename dgraph_type::edge_type edge_weight_type; + using namespace std; + + + dlog << LTRACE << "number of nodes: "<< g1.number_of_nodes(); + dlog << LTRACE << "is graph connected: "<< graph_is_connected(g1); + dlog << LTRACE << "has self loops: "<< graph_contains_length_one_cycle(g1); + dlog << LTRACE << "SOURCE_CUT: " << source; + dlog << LTRACE << "SINK_CUT: " << sink; + mc(g1, source, sink); + brute_force_min_cut(g2, source, sink); + + print_graph(g1); + + // make sure the flow residuals are 0 at the cut locations + for (unsigned long i = 0; i < g1.number_of_nodes(); ++i) + { + for (unsigned long j = 0; j < g1.node(i).number_of_children(); ++j) + { + if ((g1.node(i).data == SOURCE_CUT && g1.node(i).child(j).data != SOURCE_CUT) || + (g1.node(i).data != SINK_CUT && g1.node(i).child(j).data == SINK_CUT) + ) + { + DLIB_TEST_MSG(g1.node(i).child_edge(j) == 0, g1.node(i).child_edge(j)); + } + } + } + + // copy the edge weights from g2 back to g1 so we can compute cut scores + copy_edge_weights(g1, g2); + + DLIB_TEST(g1.number_of_nodes() == g2.number_of_nodes()); + for (unsigned long i = 0; i < g1.number_of_nodes(); ++i) + { + dlog << LTRACE << "node " << i << ": " << (int)g1.node(i).data << ", " << (int)g2.node(i).data; + if (g1.node(i).data != g2.node(i).data) + { + edge_weight_type cut_score = graph_cut_score(g1); + edge_weight_type brute_force_score = graph_cut_score(g2); + dlog << LTRACE << "graph cut score: "<< cut_score; + dlog << LTRACE << "brute force score: "<< brute_force_score; + + if (brute_force_score != cut_score) + print_graph(g1); + DLIB_TEST_MSG(brute_force_score == cut_score,std::abs(brute_force_score-cut_score)); + } + } + + } + +// ---------------------------------------------------------------------------------------- + + template + void test_graph_cuts(dlib::rand& rnd) + { + typedef typename dlib::directed_graph::kernel_1a_c dgraph_type; + // we will create two identical graphs. + dgraph_type g1, g2; + min_cut mc; + + unsigned long source, sink; + + dlib::rand rnd_copy(rnd); + make_random_graph(rnd,g1, source, sink); + make_random_graph(rnd_copy,g2, source, sink); + run_test_on_graphs(mc, g1, g2, source, sink); + + rnd_copy = rnd; + make_random_grid_graph(rnd,g1, source, sink); + make_random_grid_graph(rnd_copy,g2, source, sink); + run_test_on_graphs(mc, g1, g2, source, sink); + + rnd_copy = rnd; + make_random_chain_graph(rnd,g1, source, sink); + make_random_chain_graph(rnd_copy,g2, source, sink); + run_test_on_graphs(mc, g1, g2, source, sink); + + } + +// ---------------------------------------------------------------------------------------- + + class test_potts_grid_problem + { + public: + test_potts_grid_problem(int seed_) :seed(seed_){} + int seed; + + long nr() const { return 3;} + long nc() const { return 3;} + + typedef double value_type; + + value_type factor_value(unsigned long idx) const + { + // Copy idx into a char buffer to avoid warnings about violation of strict aliasing + // rules when murmur_hash3() gets inlined into this function. + char buf[sizeof(idx)]; + memcpy(buf,&idx,sizeof(idx)); + // now hash the buffer rather than idx. + return ((double)murmur_hash3(buf, sizeof(buf), seed) - std::numeric_limits::max()/2.0)/1000.0; + } + + value_type factor_value_disagreement(unsigned long idx1, unsigned long idx2) const + { + return std::abs(factor_value(idx1+idx2)/10.0); + } + }; + +// ---------------------------------------------------------------------------------------- + + template + void brute_force_potts_grid_problem( + const prob_type& prob, + array2d& labels + ) + { + const unsigned long num = (unsigned long)std::pow(2.0, (double)prob.nr()*prob.nc()); + + array2d temp(prob.nr(), prob.nc()); + unsigned char* data = &temp[0][0]; + + double best_score = -std::numeric_limits::infinity(); + for (unsigned long i = 0; i < num; ++i) + { + for (unsigned long j = 0; j < temp.size(); ++j) + { + unsigned long T = (1)< best_score) + { + best_score = score; + assign_image(labels, temp); + } + } + } + + void test_inf() + { + graph::kernel_1a_c g; + g.set_number_of_nodes(4); + g.add_edge(0,1); + g.add_edge(1,2); + g.add_edge(2,3); + g.add_edge(3,0); + + g.node(0).data = std::numeric_limits::infinity(); + g.node(1).data = -std::numeric_limits::infinity(); + g.node(2).data = std::numeric_limits::infinity(); + g.node(3).data = -std::numeric_limits::infinity(); + + edge(g,0,1) = 1; + edge(g,1,2) = 1; + edge(g,2,3) = 1; + edge(g,3,0) = 1; + + std::vector labels; + find_max_factor_graph_potts(g, labels); + + DLIB_TEST(labels[0] != 0); + DLIB_TEST(labels[1] == 0); + DLIB_TEST(labels[2] != 0); + DLIB_TEST(labels[3] == 0); + + // -------------------------- + + g.node(0).data = std::numeric_limits::infinity(); + g.node(1).data = 0; + g.node(2).data = 0; + g.node(3).data = -3; + + edge(g,0,1) = 1; + edge(g,1,2) = 1; + edge(g,2,3) = 1; + edge(g,3,0) = 1; + + find_max_factor_graph_potts(g, labels); + + DLIB_TEST(labels[0] != 0); + DLIB_TEST(labels[1] != 0); + DLIB_TEST(labels[2] != 0); + DLIB_TEST(labels[3] == 0); + + // -------------------------- + + g.node(0).data = std::numeric_limits::infinity(); + g.node(1).data = 0; + g.node(2).data = 0; + g.node(3).data = -0.1; + + edge(g,0,1) = 1; + edge(g,1,2) = 1; + edge(g,2,3) = 1; + edge(g,3,0) = 1; + + find_max_factor_graph_potts(g, labels); + + DLIB_TEST(labels[0] != 0); + DLIB_TEST(labels[1] != 0); + DLIB_TEST(labels[2] != 0); + DLIB_TEST(labels[3] != 0); + + // -------------------------- + + g.node(0).data = std::numeric_limits::infinity(); + g.node(1).data = 0; + g.node(2).data = 0; + g.node(3).data = -0.1; + + edge(g,0,1) = 1; + edge(g,1,2) = 1; + edge(g,2,3) = 0; + edge(g,3,0) = 0; + + find_max_factor_graph_potts(g, labels); + + DLIB_TEST(labels[0] != 0); + DLIB_TEST(labels[1] != 0); + DLIB_TEST(labels[2] != 0); + DLIB_TEST(labels[3] == 0); + + // -------------------------- + + g.node(0).data = -std::numeric_limits::infinity(); + g.node(1).data = 0; + g.node(2).data = 0; + g.node(3).data = 0.1; + + edge(g,0,1) = 1; + edge(g,1,2) = 1; + edge(g,2,3) = 0; + edge(g,3,0) = 0; + + find_max_factor_graph_potts(g, labels); + + DLIB_TEST(labels[0] == 0); + DLIB_TEST(labels[1] == 0); + DLIB_TEST(labels[2] == 0); + DLIB_TEST(labels[3] != 0); + + // -------------------------- + + g.node(0).data = -std::numeric_limits::infinity(); + g.node(1).data = std::numeric_limits::infinity(); + g.node(2).data = 0; + g.node(3).data = 0.1; + + edge(g,0,1) = 1; + edge(g,1,2) = 1; + edge(g,2,3) = 0; + edge(g,3,0) = 0; + + find_max_factor_graph_potts(g, labels); + + DLIB_TEST(labels[0] == 0); + DLIB_TEST(labels[1] != 0); + DLIB_TEST(labels[2] != 0); + DLIB_TEST(labels[3] != 0); + + // -------------------------- + + g.node(0).data = -10; + g.node(1).data = std::numeric_limits::infinity(); + g.node(2).data = 0; + g.node(3).data = 0.1; + + edge(g,0,1) = std::numeric_limits::infinity(); + edge(g,1,2) = std::numeric_limits::infinity(); + edge(g,2,3) = std::numeric_limits::infinity(); + edge(g,3,0) = std::numeric_limits::infinity(); + + find_max_factor_graph_potts(g, labels); + + DLIB_TEST(labels[0] != 0); + DLIB_TEST(labels[1] != 0); + DLIB_TEST(labels[2] != 0); + DLIB_TEST(labels[3] != 0); + + // -------------------------- + + g.node(0).data = 10; + g.node(1).data = -std::numeric_limits::infinity(); + g.node(2).data = 20.05; + g.node(3).data = -0.1; + + edge(g,0,1) = std::numeric_limits::infinity(); + edge(g,1,2) = 10; + edge(g,2,3) = std::numeric_limits::infinity(); + edge(g,3,0) = 10; + + find_max_factor_graph_potts(g, labels); + + DLIB_TEST(labels[0] == 0); + DLIB_TEST(labels[1] == 0); + DLIB_TEST(labels[2] == 0); + DLIB_TEST(labels[3] == 0); + + // -------------------------- + + g.node(0).data = 10; + g.node(1).data = -std::numeric_limits::infinity(); + g.node(2).data = 20.2; + g.node(3).data = -0.1; + + edge(g,0,1) = std::numeric_limits::infinity(); + edge(g,1,2) = 10; + edge(g,2,3) = std::numeric_limits::infinity(); + edge(g,3,0) = 10; + + find_max_factor_graph_potts(g, labels); + + DLIB_TEST(labels[0] == 0); + DLIB_TEST(labels[1] == 0); + DLIB_TEST(labels[2] != 0); + DLIB_TEST(labels[3] != 0); + } + + struct potts_pair_image_model + { + typedef double value_type; + + template + value_type factor_value ( + const pixel_type1& , + const pixel_type2& v2 + ) const + { + return v2; + } + + template + value_type factor_value_disagreement ( + const pixel_type& v1, + const pixel_type& v2 + ) const + { + if (v1 == v2) + return 10; + else + return 0; + } + }; + + void test_potts_pair_grid() + { + array2d img1(40,40); + array2d img2(40,40); + + assign_all_pixels(img1, -1); + assign_all_pixels(img2, -1); + + img1[4][4] = 1000; + + img2[4][3] = 1; + img2[4][4] = 1; + img2[4][5] = 1; + img2[3][3] = 1; + img2[3][4] = 1; + img2[3][5] = 1; + img2[5][3] = 1; + img2[5][4] = 1; + img2[5][5] = 1; + + array2d labels; + find_max_factor_graph_potts(make_potts_grid_problem(potts_pair_image_model(),img2,img1), labels); + + dlog << LINFO << "num true labels: " << sum(matrix_cast(mat(labels)!=0)); + DLIB_TEST(sum(matrix_cast(mat(labels)!=0)) == 9); + DLIB_TEST(sum(matrix_cast(mat(labels)==0)) == (int)img1.size()-9); + + DLIB_TEST(labels[4][3] != 0); + DLIB_TEST(labels[4][4] != 0); + DLIB_TEST(labels[4][5] != 0); + DLIB_TEST(labels[3][3] != 0); + DLIB_TEST(labels[3][4] != 0); + DLIB_TEST(labels[3][5] != 0); + DLIB_TEST(labels[5][3] != 0); + DLIB_TEST(labels[5][4] != 0); + DLIB_TEST(labels[5][5] != 0); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class graph_cuts_tester : public tester + { + public: + graph_cuts_tester ( + ) : + tester ("test_graph_cuts", + "Runs tests on the graph cuts tools.") + {} + + dlib::rand rnd; + + void perform_test ( + ) + { + test_potts_pair_grid(); + test_inf(); + + for (int i = 0; i < 500; ++i) + { + array2d labels, brute_labels; + test_potts_grid_problem prob(i); + find_max_factor_graph_potts(prob, labels); + brute_force_potts_grid_problem(prob, brute_labels); + + DLIB_TEST(labels.nr() == brute_labels.nr()); + DLIB_TEST(labels.nc() == brute_labels.nc()); + for (long r = 0; r < labels.nr(); ++r) + { + for (long c = 0; c < labels.nc(); ++c) + { + bool normal = (labels[r][c] != 0); + bool brute = (brute_labels[r][c] != 0); + DLIB_TEST(normal == brute); + } + } + } + + for (int i = 0; i < 1000; ++i) + { + print_spinner(); + dlog << LTRACE << "test_grpah_cuts iter: " << i; + test_graph_cuts(rnd); + print_spinner(); + dlog << LTRACE << "test_grpah_cuts iter: " << i; + test_graph_cuts(rnd); + } + + + for (int k = 0; k < 300; ++k) + { + dlog << LTRACE << "image_potts_problem iter " << k; + print_spinner(); + image_potts_problem p(3,3, rnd); + impl_test_potts_model(p); + } + for (int k = 0; k < 300; ++k) + { + dlog << LTRACE << "dense_potts_problem iter " << k; + print_spinner(); + dense_potts_problem p(6, rnd); + impl_test_potts_model(p); + } + + for (int k = 0; k < 300; ++k) + { + dlog << LTRACE << "dense_potts_problem iter " << k; + print_spinner(); + test_graph_potts_model(rnd); + } + } + } a; + + +} + + + + diff --git a/ml/dlib/dlib/test/graph_labeler.cpp b/ml/dlib/dlib/test/graph_labeler.cpp new file mode 100644 index 000000000..112873896 --- /dev/null +++ b/ml/dlib/dlib/test/graph_labeler.cpp @@ -0,0 +1,472 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + + + logger dlog("test.graph_cuts"); + + + template < + typename graph_type, + typename samples_type, + typename labels_type + > + void make_data( + samples_type& samples, + labels_type& labels + ) + { + //samples.clear(); + //labels.clear(); + + std::vector label; + graph_type g; + + // --------------------------- + g.set_number_of_nodes(4); + label.resize(g.number_of_nodes()); + g.node(0).data = 0, 0, 1; label[0] = true; + g.node(1).data = 0, 0, 1; label[1] = true; + g.node(2).data = 0, 1, 0; label[2] = false; + g.node(3).data = 0, 1, 0; label[3] = false; + + g.add_edge(0,1); + g.add_edge(1,2); + g.add_edge(2,3); + g.add_edge(3,0); + + edge(g,0,1) = 1, 1; + edge(g,1,2) = 1, 1; + edge(g,2,3) = 1, 1; + edge(g,3,0) = 1, 1; + samples.push_back(g); + labels.push_back(label); + // --------------------------- + + g.clear(); + g.set_number_of_nodes(4); + label.resize(g.number_of_nodes()); + g.node(0).data = 0, 0, 1; label[0] = true; + g.node(1).data = 0, 0, 0; label[1] = true; + g.node(2).data = 0, 1, 0; label[2] = false; + g.node(3).data = 0, 0, 0; label[3] = false; + + g.add_edge(0,1); + g.add_edge(1,2); + g.add_edge(2,3); + g.add_edge(3,0); + + edge(g,0,1) = 1, 0; + edge(g,1,2) = 0, 1; + edge(g,2,3) = 1, 0; + edge(g,3,0) = 0, 1; + samples.push_back(g); + labels.push_back(label); + // --------------------------- + + g.clear(); + g.set_number_of_nodes(4); + label.resize(g.number_of_nodes()); + g.node(0).data = 0, 1, 0; label[0] = false; + g.node(1).data = 0, 1, 0; label[1] = false; + g.node(2).data = 0, 1, 0; label[2] = false; + g.node(3).data = 0, 0, 0; label[3] = false; + + g.add_edge(0,1); + g.add_edge(1,2); + g.add_edge(2,3); + g.add_edge(3,0); + + edge(g,0,1) = 1, 0; + edge(g,1,2) = 0, 1; + edge(g,2,3) = 1, 0; + edge(g,3,0) = 0, 1; + samples.push_back(g); + labels.push_back(label); + // --------------------------- + } + + + + + template < + typename graph_type, + typename samples_type, + typename labels_type + > + void make_data_sparse( + samples_type& samples, + labels_type& labels + ) + { + //samples.clear(); + //labels.clear(); + + std::vector label; + graph_type g; + typename graph_type::edge_type v; + + // --------------------------- + g.set_number_of_nodes(4); + label.resize(g.number_of_nodes()); + g.node(0).data[2] = 1; label[0] = true; + g.node(1).data[2] = 1; label[1] = true; + g.node(2).data[1] = 1; label[2] = false; + g.node(3).data[1] = 1; label[3] = false; + + g.add_edge(0,1); + g.add_edge(1,2); + g.add_edge(2,3); + g.add_edge(3,0); + g.add_edge(3,1); + + v[0] = 1; v[1] = 1; + edge(g,0,1) = v; + edge(g,1,2) = v; + edge(g,2,3) = v; + edge(g,3,0) = v; + samples.push_back(g); + labels.push_back(label); + // --------------------------- + + g.clear(); + g.set_number_of_nodes(5); + label.resize(g.number_of_nodes()); + g.node(0).data[2] = 1; label[0] = true; + g.node(1).data[0] = 0; label[1] = true; + g.node(2).data[1] = 1; label[2] = false; + g.node(3).data[0] = 0; label[3] = false; + label[4] = true; + + g.add_edge(0,1); + g.add_edge(1,4); + g.add_edge(1,2); + g.add_edge(2,3); + g.add_edge(3,0); + + edge(g,0,1)[0] = 1; + edge(g,1,4)[0] = 1; + edge(g,1,2)[1] = 1; + edge(g,2,3)[0] = 1; + edge(g,3,0)[1] = 1; + samples.push_back(g); + labels.push_back(label); + // --------------------------- + + g.clear(); + g.set_number_of_nodes(4); + label.resize(g.number_of_nodes()); + g.node(0).data[1] = 1; label[0] = false; + g.node(1).data[1] = 1; label[1] = false; + g.node(2).data[1] = 1; label[2] = false; + g.node(3).data[1] = 0; label[3] = false; + + g.add_edge(0,1); + g.add_edge(1,2); + g.add_edge(2,3); + g.add_edge(3,0); + + edge(g,0,1)[0] = 1; + edge(g,1,2)[1] = 1; + edge(g,2,3)[0] = 1; + edge(g,3,0)[1] = 1; + samples.push_back(g); + labels.push_back(label); + // --------------------------- + } + + + + + + + template < + typename graph_type, + typename samples_type, + typename labels_type + > + void make_data2( + samples_type& samples, + labels_type& labels + ) + { + //samples.clear(); + //labels.clear(); + + std::vector label; + graph_type g; + + // --------------------------- + g.set_number_of_nodes(4); + label.resize(g.number_of_nodes()); + g.node(0).data = 0, 0, 1; label[0] = true; + g.node(1).data = 0, 0, 1; label[1] = true; + g.node(2).data = 0, 1, 0; label[2] = false; + g.node(3).data = 0, 1, 0; label[3] = false; + + g.add_edge(0,1); + g.add_edge(1,2); + g.add_edge(2,3); + g.add_edge(3,0); + + edge(g,0,1) = 1, 1; + edge(g,1,2) = 1, 1; + edge(g,2,3) = 1, 1; + edge(g,3,0) = 1, 1; + samples.push_back(g); + labels.push_back(label); + // --------------------------- + } + + + + + template < + typename graph_type, + typename samples_type, + typename labels_type + > + void make_data2_sparse( + samples_type& samples, + labels_type& labels + ) + { + //samples.clear(); + //labels.clear(); + + std::vector label; + graph_type g; + typename graph_type::edge_type v; + + // --------------------------- + g.set_number_of_nodes(4); + label.resize(g.number_of_nodes()); + g.node(0).data[2] = 1; label[0] = true; + g.node(1).data[2] = 1; label[1] = true; + g.node(2).data[1] = 1; label[2] = false; + g.node(3).data[1] = 1; label[3] = false; + + g.add_edge(0,1); + g.add_edge(1,2); + g.add_edge(2,3); + g.add_edge(3,0); + + v[0] = 1; v[1] = 1; + edge(g,0,1) = v; + edge(g,1,2) = v; + edge(g,2,3) = v; + edge(g,3,0) = v; + samples.push_back(g); + labels.push_back(label); + // --------------------------- + + } + + + + + + template < + typename node_vector_type, + typename edge_vector_type, + typename vector_type, + typename graph_type + > + void test1( + const dlib::array& samples, + const std::vector >& labels + ) + { + dlog << LINFO << "begin test1()"; + + structural_graph_labeling_trainer trainer; + //trainer.be_verbose(); + trainer.set_epsilon(1e-12); + graph_labeler labeler = trainer.train(samples, labels); + + + // test serialization code for the labeler. + std::ostringstream sout; + serialize(labeler, sout); + std::istringstream sin(sout.str()); + labeler = graph_labeler(); + deserialize(labeler, sin); + + std::vector temp; + for (unsigned long k = 0; k < samples.size(); ++k) + { + temp = labeler(samples[k]); + for (unsigned long i = 0; i < temp.size(); ++i) + { + const bool true_label = (labels[k][i] != 0); + const bool pred_label = (temp[i] != 0); + DLIB_TEST(true_label == pred_label); + } + } + + matrix cv; + + cv = test_graph_labeling_function(labeler, samples, labels); + DLIB_TEST(sum(cv) == 2); + cv = cross_validate_graph_labeling_trainer(trainer, samples, labels, 4); + DLIB_TEST(sum(cv) == 2); + + dlog << LINFO << "edge weights: " << trans(sparse_to_dense(labeler.get_edge_weights())); + dlog << LINFO << "node weights: " << trans(sparse_to_dense(labeler.get_node_weights())); + } + + + + class graph_labeling_tester : public tester + { + public: + graph_labeling_tester ( + ) : + tester ("test_graph_labeling", + "Runs tests on the graph labeling component.") + {} + + void perform_test ( + ) + { + print_spinner(); + // test with dense vectors + { + typedef matrix node_vector_type; + typedef matrix edge_vector_type; + typedef matrix vector_type; + typedef dlib::graph::kernel_1a_c graph_type; + + dlib::array samples; + std::vector > labels; + + make_data(samples, labels); + make_data(samples, labels); + make_data(samples, labels); + make_data(samples, labels); + + + test1(samples, labels); + } + print_spinner(); + // test with dense vectors and sparse vectors together + { + typedef matrix node_vector_type; + typedef matrix edge_vector_type; + typedef std::map vector_type; + typedef dlib::graph::kernel_1a_c graph_type; + + dlib::array samples; + std::vector > labels; + + make_data(samples, labels); + make_data(samples, labels); + make_data(samples, labels); + make_data(samples, labels); + + + test1(samples, labels); + } + print_spinner(); + // test with sparse vectors + { + typedef std::vector > vector_type; + typedef std::map edge_vector_type; + typedef std::map node_vector_type; + typedef dlib::graph::kernel_1a_c graph_type; + + dlib::array samples; + std::vector > labels; + + make_data_sparse(samples, labels); + make_data_sparse(samples, labels); + make_data_sparse(samples, labels); + make_data_sparse(samples, labels); + + + test1(samples, labels); + } + + + + print_spinner(); + // test with dense vectors + { + typedef matrix node_vector_type; + typedef matrix edge_vector_type; + typedef matrix vector_type; + typedef dlib::graph::kernel_1a_c graph_type; + + dlib::array samples; + std::vector > labels; + + make_data2(samples, labels); + make_data2(samples, labels); + make_data2(samples, labels); + make_data2(samples, labels); + + + test1(samples, labels); + } + print_spinner(); + // test with sparse vectors + { + typedef std::vector > vector_type; + typedef std::map edge_vector_type; + typedef std::map node_vector_type; + typedef dlib::graph::kernel_1a_c graph_type; + + dlib::array samples; + std::vector > labels; + + make_data2_sparse(samples, labels); + make_data2_sparse(samples, labels); + make_data2_sparse(samples, labels); + make_data2_sparse(samples, labels); + + + test1(samples, labels); + } + print_spinner(); + // test with sparse vectors and dense mix + { + typedef matrix vector_type; + typedef std::map edge_vector_type; + typedef std::map node_vector_type; + typedef dlib::graph::kernel_1a_c graph_type; + + dlib::array samples; + std::vector > labels; + + make_data2_sparse(samples, labels); + make_data2_sparse(samples, labels); + make_data2_sparse(samples, labels); + make_data2_sparse(samples, labels); + + + test1(samples, labels); + } + } + } a; + + +} + + + + diff --git a/ml/dlib/dlib/test/gui/CMakeLists.txt b/ml/dlib/dlib/test/gui/CMakeLists.txt new file mode 100644 index 000000000..2ab3c2b47 --- /dev/null +++ b/ml/dlib/dlib/test/gui/CMakeLists.txt @@ -0,0 +1,20 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + +# create a variable called target_name and set it to the string "test" +set (target_name gui) + +project(${target_name}) + +add_subdirectory(../.. dlib_build) + +# add all the cpp files we want to compile to this list. This tells +# cmake that they are part of our target (which is the executable named test) +add_executable(${target_name} main.cpp ) + + +# Tell cmake to link our target executable to dlib. +target_link_libraries(${target_name} dlib::dlib ) + diff --git a/ml/dlib/dlib/test/gui/main.cpp b/ml/dlib/dlib/test/gui/main.cpp new file mode 100644 index 000000000..d61aba8c0 --- /dev/null +++ b/ml/dlib/dlib/test/gui/main.cpp @@ -0,0 +1,840 @@ +#include +#include +#include +#include +#include + +#include "dlib/image_io.h" +#include "dlib/array2d.h" +#include "dlib/gui_core.h" +#include "dlib/assert.h" +#include "dlib/misc_api.h" + +#include "dlib/image_transforms.h" + +#include "dlib/timer.h" + +#include "dlib/gui_widgets.h" +#include "dlib/queue.h" + +using namespace dlib; +using namespace std; + + +typedef dlib::array2d image; + + + + +#include "dlib/base64.h" + + + + +class color_box : public draggable +{ + unsigned char red, green,blue; + +public: + color_box ( + drawable_window& w, + rectangle area, + unsigned char red_, + unsigned char green_, + unsigned char blue_ + ) : + draggable(w, MOUSE_WHEEL), + red(red_), + green(green_), + blue(blue_), + t(*this,&color_box::action) + { + rect = area; + + t.set_delay_time(4); + // t.start(); + + set_draggable_area(rectangle(10,10,500,500)); + + enable_events(); + } + + ~color_box() + { + disable_events(); + } + +private: + + void action ( + ) + { + ++red; + parent.invalidate_rectangle(rect); + } + + void draw ( + const canvas& c + ) const + { + if (hidden == false ) + { + fill_rect(c,rect,rgb_pixel(red,green,blue)); + std::vector poly; + poly.push_back((rect.tl_corner()+rect.tr_corner())/2); + poly.push_back((rect.tr_corner()+rect.br_corner())/2); + poly.push_back((rect.br_corner()+rect.bl_corner())/2); + poly.push_back((rect.bl_corner()+rect.tl_corner())/2); + draw_solid_convex_polygon(c,poly,rgb_alpha_pixel(0,0,0,70)); + } + } + + void on_wheel_up( + unsigned long state + ) + { + if (state == base_window::NONE) + cout << "up scroll, NONE" << endl; + else if (state&base_window::LEFT) + cout << "up scroll, LEFT" << endl; + else if (state&base_window::RIGHT) + cout << "up scroll, RIGHT" << endl; + else if (state&base_window::MIDDLE) + cout << "up scroll, MIDDLE" << endl; + else if (state&base_window::SHIFT) + cout << "up scroll, SHIFT" << endl; + else if (state&base_window::CONTROL) + cout << "up scroll, CONTROL" << endl; + + } + + void on_wheel_down( + unsigned long state + ) + { + + if (state == base_window::NONE) + cout << "down scroll, NONE" << endl; + else if (state&base_window::LEFT) + cout << "down scroll, LEFT" << endl; + else if (state&base_window::RIGHT) + cout << "down scroll, RIGHT" << endl; + else if (state&base_window::MIDDLE) + cout << "down scroll, MIDDLE" << endl; + else if (state&base_window::SHIFT) + cout << "down scroll, SHIFT" << endl; + else if (state&base_window::CONTROL) + cout << "down scroll, CONTROL" << endl; + + } + + + void on_window_resized () + { + draggable::on_window_resized(); + } + timer t; +}; + + + + + + +class win : public drawable_window +{ + + label lbl_last_keydown; + label lbl_mod_shift; + label lbl_mod_control; + label lbl_mod_alt; + label lbl_mod_meta; + label lbl_mod_caps_lock; + label lbl_mod_num_lock; + label lbl_mod_scroll_lock; + void on_keydown ( + unsigned long key, + bool is_printable, + unsigned long state + ) + { + if (is_printable) + lbl_last_keydown.set_text(string("last keydown: ") + (char)key); + else + lbl_last_keydown.set_text(string("last keydown: nonprintable")); + + if (state&base_window::KBD_MOD_SHIFT) + lbl_mod_shift.set_text("shift is on"); + else + lbl_mod_shift.set_text("shift is off"); + + if (state&base_window::KBD_MOD_CONTROL) + lbl_mod_control.set_text("control is on"); + else + lbl_mod_control.set_text("control is off"); + + if (state&base_window::KBD_MOD_ALT) + lbl_mod_alt.set_text("alt is on"); + else + lbl_mod_alt.set_text("alt is off"); + + + if (state&base_window::KBD_MOD_META) + lbl_mod_meta.set_text("meta is on"); + else + lbl_mod_meta.set_text("meta is off"); + + if (state&base_window::KBD_MOD_CAPS_LOCK) + lbl_mod_caps_lock.set_text("caps_lock is on"); + else + lbl_mod_caps_lock.set_text("caps_lock is off"); + + if (state&base_window::KBD_MOD_NUM_LOCK) + lbl_mod_num_lock.set_text("num_lock is on"); + else + lbl_mod_num_lock.set_text("num_lock is off"); + + + if (state&base_window::KBD_MOD_SCROLL_LOCK) + lbl_mod_scroll_lock.set_text("scroll_lock is on"); + else + lbl_mod_scroll_lock.set_text("scroll_lock is off"); + + drawable_window::on_keydown(key,is_printable,state); + } + + void rb_click ( + ) + { + if (rb.is_checked()) + rb.set_name("radio button checked"); + else + rb.set_name("radio button"); + rb.set_checked(); + } + + void cb_sb_enabled ( + toggle_button& + ) + { + if (sb_enabled.is_checked()) + { + sb.enable(); + lb.enable(); + b.enable(); + } + else + { + lb.disable(); + sb.disable(); + b.disable(); + } + + if (sb_enabled.is_checked()) + rb.enable(); + else + rb.disable(); + + if (sb_enabled.is_checked()) + tabs.enable(); + else + tabs.disable(); + + if (sb_enabled.is_checked()) + tf.enable(); + else + tf.disable(); + + if (sb_enabled.is_checked()) + tb.enable(); + else + tb.disable(); + + } + + void cb_sb_shown ( + ) + { + if (sb_shown.is_checked()) + { + sb.show(); + tabs.show(); + lb.show(); + } + else + { + sb.hide(); + tabs.hide(); + lb.hide(); + } + } + + + void tab_change ( + unsigned long new_idx, + unsigned long + ) + { + tab_label.set_text(tabs.tab_name(new_idx)); + } + + void scroll_handler ( + ) + { + ostringstream sout; + sout << "scroll bar pos: " << sb.slider_pos(); + sbl.set_text(sout.str()); + } + + void scroll2_handler ( + ) + { + sb.set_length(sb2.slider_pos()); + ostringstream sout; + sout << "scroll bar2 pos: " << sb2.slider_pos(); + sbl2.set_text(sout.str()); + scroll_handler(); + } + + void scroll3_handler ( + ) + { + sb.set_max_slider_pos(sb3.slider_pos()); + ostringstream sout; + sout << "scroll bar3 pos: " << sb3.slider_pos(); + sbl3.set_text(sout.str()); + scroll_handler(); + } + + void lb_double_click ( + unsigned long + ) + { + dlib::queue::kernel_2a_c sel; + lb.get_selected(sel); + sel.reset(); + while (sel.move_next()) + { + cout << lb[sel.element()] << endl; + } + //message_box("list_box",lb[idx]); + } + + void msg_box ( + ) + { + message_box("title","you clicked the ok button!\n HURRAY!"); + } + + static void try_this_junk ( + void* param + ) + { + win& p = *reinterpret_cast(param); + put_on_clipboard(p.tf.text() + "\nfoobar"); + + + } + + void on_set_clipboard ( + ) + { + create_new_thread(try_this_junk,this); + //try_this_junk(this); + } + + static void try_this_junk2 ( + void* + ) + { + + string temp; + get_from_clipboard(temp); + message_box("clipboard",temp); + + } + void on_get_clipboard ( + ) + { + create_new_thread(try_this_junk2,this); + } + + + void on_show_msg_click ( + ) + { + message_box("title","This is a test message.",*this,&win::msg_box); + } + + void on_menu_help ( + ) + { + message_box("About","This is the messy dlib gui regression test program"); + } + +public: + + ~win() + { + close_window(); + } + + void cbox_clicked ( + ) + { + if (cbox.is_checked()) + cbl.set_text(cbox.name() + " box is checked"); + else + cbl.set_text("box NOT is checked"); + } + + win ( + ): + drawable_window(true), + lbl_last_keydown(*this), + lbl_mod_shift(*this), + lbl_mod_control(*this), + lbl_mod_alt(*this), + lbl_mod_meta(*this), + lbl_mod_caps_lock(*this), + lbl_mod_num_lock(*this), + lbl_mod_scroll_lock(*this), + b(*this), + btn_count(*this), + btn_get_clipboard(*this), + btn_set_clipboard(*this), + btn_show_message(*this), + cb1(*this,rectangle(100,100,200,200),255,0,0), + cb2(*this,rectangle(150,150,250,240),0,255,0), + cbl(*this), + cbox(*this), + group1(*this), + group2(*this), + group3(*this), + keyboard_count(1), + keydown(*this), + keyup(*this), + l1(*this), + l2(*this), + l3(*this), + lb(*this), + leave_count(*this), + left_down(*this), + left_up(*this), + middle_down(*this), + middle_up(*this), + mouse_state(*this), + mt(*this), + nrect(*this), + pos(*this), + rb(*this), + right_down(*this), + right_up(*this), + sb2(*this,scroll_bar::VERTICAL), + sb3(*this,scroll_bar::VERTICAL), + sb_enabled(*this), + sbl2(*this), + sbl3(*this), + sbl(*this), + sb_shown(*this), + sb(*this,scroll_bar::HORIZONTAL), + scroll(*this), + tab_label(*this), + tabs(*this), + tf(*this), + tb(*this), + mbar(*this) + { + bool use_bdf_fonts = false; + + std::shared_ptr f(new bdf_font); + + if (use_bdf_fonts) + { + + ifstream fin("/home/davis/source/10x20.bdf"); + f->read_bdf_file(fin,0xFFFF); + + mt.set_main_font(f); + } + //mt.hide(); + mt.set_pos(5,200); + + + lbl_last_keydown.set_text("?"); + lbl_mod_shift.set_text("?"); + lbl_mod_control.set_text("?"); + lbl_mod_alt.set_text("?"); + lbl_mod_meta.set_text("?"); + lbl_mod_caps_lock.set_text("?"); + lbl_mod_num_lock.set_text("?"); + lbl_mod_scroll_lock.set_text("?"); + + lbl_last_keydown.set_pos(20,420); + lbl_mod_shift.set_pos(20,lbl_last_keydown.bottom()+5); + lbl_mod_control.set_pos(20,lbl_mod_shift.bottom()+5); + lbl_mod_alt.set_pos(20,lbl_mod_control.bottom()+5); + lbl_mod_meta.set_pos(20,lbl_mod_alt.bottom()+5); + lbl_mod_caps_lock.set_pos(20,lbl_mod_meta.bottom()+5); + lbl_mod_num_lock.set_pos(20,lbl_mod_caps_lock.bottom()+5); + lbl_mod_scroll_lock.set_pos(20,lbl_mod_num_lock.bottom()+5); + + lb.set_pos(580,200); + lb.set_size(200,300); + if (use_bdf_fonts) + lb.set_main_font(f); + + dlib::queue::kernel_2a_c qos; + string a; + a = "Davis"; qos.enqueue(a); + a = "king"; qos.enqueue(a); + a = "one"; qos.enqueue(a); + a = "two"; qos.enqueue(a); + a = "three"; qos.enqueue(a); + a = "yo yo yo alsdkjf asfj lsa jfsf\n this is a long phrase"; qos.enqueue(a); + a = "four"; qos.enqueue(a); + a = "five"; qos.enqueue(a); + a = "six"; qos.enqueue(a); + a = "seven"; qos.enqueue(a); + a = "eight"; qos.enqueue(a); + a = "nine"; qos.enqueue(a); + a = "ten"; qos.enqueue(a); + a = "eleven"; qos.enqueue(a); + a = "twelve"; qos.enqueue(a); + for (int i = 0; i < 1000; ++i) + { + a = "thirteen"; qos.enqueue(a); + } + lb.load(qos); + lb.select(1); + lb.select(2); + lb.select(3); + lb.select(5); + lb.enable_multiple_select(); + lb.set_double_click_handler(*this,&win::lb_double_click); + // lb.disable_multiple_select(); + + btn_show_message.set_pos(50,350); + btn_show_message.set_name("message_box()"); + mbar.set_number_of_menus(2); + mbar.set_menu_name(0,"File",'F'); + mbar.set_menu_name(1,"Help",'H'); + mbar.menu(0).add_menu_item(menu_item_text("show msg click",*this,&win::on_show_msg_click,'s')); + mbar.menu(0).add_menu_item(menu_item_text("get clipboard",*this,&win::on_get_clipboard,'g')); + mbar.menu(0).add_menu_item(menu_item_text("set clipboard",*this,&win::on_set_clipboard,'c')); + mbar.menu(0).add_menu_item(menu_item_separator()); + mbar.menu(0).add_submenu(menu_item_submenu("submenu",'m'), submenu); + submenu.add_menu_item(menu_item_separator()); + submenu.add_menu_item(menu_item_separator()); + submenu.add_menu_item(menu_item_text("show msg click",*this,&win::on_show_msg_click,'s')); + submenu.add_menu_item(menu_item_text("get clipboard",*this,&win::on_get_clipboard,'g')); + submenu.add_menu_item(menu_item_text("set clipboard",*this,&win::on_set_clipboard,'c')); + submenu.add_menu_item(menu_item_separator()); + submenu.add_menu_item(menu_item_separator()); + mbar.menu(1).add_menu_item(menu_item_text("About",*this,&win::on_menu_help,'A')); + + btn_show_message.set_click_handler(*this,&win::on_show_msg_click); + btn_get_clipboard.set_pos(btn_show_message.right()+5,btn_show_message.top()); + btn_get_clipboard.set_name("get_from_clipboard()"); + btn_get_clipboard.set_click_handler(*this,&win::on_get_clipboard); + + btn_get_clipboard.set_style(button_style_toolbar1()); + btn_set_clipboard.set_pos(btn_get_clipboard.right()+5,btn_get_clipboard.top()); + btn_set_clipboard.set_name("put_on_clipboard()"); + btn_set_clipboard.set_click_handler(*this,&win::on_set_clipboard); + + nrect.set_size(700,500); + nrect.set_name("test widgets"); + nrect.set_pos(2,mbar.bottom()+2); + + //throw dlib::error("holy crap batman"); + tab_label.set_pos(10,440); + + tabs.set_click_handler(*this,&win::tab_change); + tabs.set_pos(5,mbar.bottom()+10); + tabs.set_size(280,100); + tabs.set_number_of_tabs(3); + tabs.set_tab_name(0,"davis"); + tabs.set_tab_name(1,"edward"); + tabs.set_tab_name(2,"king alsklsdkfj asfd"); + tabs.set_tab_group(0,group1); + tabs.set_tab_group(1,group2); + tabs.set_tab_group(2,group3); + + l1.set_text("group one"); + l2.set_text("group two"); + l3.set_text("group three"); + + group1.add(l1,0,0); + group2.add(l2,20,10); + group3.add(l3,0,0); + + + + sb_enabled.set_name("enabled"); + sb_shown.set_name("shown"); + sb_shown.set_checked(); + sb_enabled.set_checked(); + sb_shown.set_click_handler(*this,&win::cb_sb_shown); + sb_enabled.set_click_handler(*this,&win::cb_sb_enabled); + + sb_shown.set_tooltip_text("I'm a checkbox"); + + rb.set_click_handler(*this,&win::rb_click); + + + sb3.set_pos(440,mbar.bottom()+10); + sb3.set_max_slider_pos(300); + sb3.set_slider_pos(150); + sb3.set_length(300); + sb2.set_pos(470,mbar.bottom()+10); + sb2.set_max_slider_pos(300); + sb2.set_length(300); + sb.set_pos(500,mbar.bottom()+10); + sb.set_max_slider_pos(30); + sb.set_length(300); + + + sb.set_scroll_handler(*this,&win::scroll_handler); + sb2.set_scroll_handler(*this,&win::scroll2_handler); + sb3.set_scroll_handler(*this,&win::scroll3_handler); + sbl.set_pos(540,mbar.bottom()+20); + sbl2.set_pos(540,mbar.bottom()+40); + sbl3.set_pos(540,mbar.bottom()+60); + + cbox.set_pos(300,mbar.bottom()+30); + cbox.set_name("davis king"); + cbox.set_click_handler(*this,&win::cbox_clicked); + + cbl.set_pos(300,cbox.get_rect().bottom()+1); + cbox.set_checked(); + sb_enabled.set_pos(cbox.get_rect().left(),cbox.get_rect().bottom()+20); + sb_shown.set_pos(sb_enabled.get_rect().left(),sb_enabled.get_rect().bottom()+2); + + + + if (use_bdf_fonts) + rb.set_main_font(f); + rb.set_name("radio button"); + rb.set_pos(sb_shown.get_rect().left(),sb_shown.get_rect().bottom()+2); + + + cb1.set_z_order(10); + cb2.set_z_order(20); + + pos.set_pos(50,50); + left_up.set_pos(50,70); + left_down.set_pos(50,90); + middle_up.set_pos(50,110); + middle_down.set_pos(50,130); + right_up.set_pos(50,150); + right_down.set_pos(50,170); + + mouse_state.set_pos(50,190); + + leave_count.set_pos(50,210); + + scroll_count = 0; + scroll.set_pos(50,230); + + btn_count.set_pos(50,250); + + + keydown.set_pos(50,270); + keyup.set_pos(50,290); + + tf.set_pos(50,310); + tf.set_text("Davis685g@"); + tf.set_width(500); + tf.set_text_color(rgb_pixel(255,0,0)); + tf.set_enter_key_handler(*this,&win::on_enter_key); + tf.set_focus_lost_handler(*this,&win::on_tf_focus_lost); + + tb.set_pos(250,400); + tb.set_text("initial test\nstring"); + tb.set_size(300,300); + tb.set_text_color(rgb_pixel(255,0,0)); + tb.set_enter_key_handler(*this,&win::on_enter_key); + tb.set_focus_lost_handler(*this,&win::on_tf_focus_lost); + + + button_count = 0; + count = 0; + b.set_name("button"); + b.set_pos(540,100); + b.set_click_handler(*this,&win::on_click); + b.set_tooltip_text("hurray i'm a button!"); + if (use_bdf_fonts) + b.set_main_font(f); + + + set_size(815,730); + + nrect.wrap_around( + cbox.get_rect() + + rb.get_rect() + + sb_enabled.get_rect() + + sb_shown.get_rect()); + + flip = 0; + open_file_box(*this,&win::on_open_file); + open_existing_file_box(*this,&win::on_open_file); + save_file_box(*this,&win::on_open_file); + + if (use_bdf_fonts) + { + tf.set_main_font(f); + tb.set_main_font(f); + } + if (use_bdf_fonts) + tabs.set_main_font(f); + + } + +private: + + + void on_enter_key() + { + cout << "enter key pressed" << endl; + } + + void on_tf_focus_lost() + { + cout << "text field/box lost focus" << endl; + } + + + void on_open_file (const std::string& file) + { + message_box("file opened",file); + } + + + + + void on_click ( + ) + { + ostringstream sout; + sout << "text field: " << tf.text(); + ++button_count; + btn_count.set_text(sout.str()); + + if (flip == 0) + { + flip = 1; + lb.set_size(200,200); + } + else if (flip == 1) + { + flip = 2; + lb.set_size(150,200); + } + else if (flip == 2) + { + flip = 3; + lb.set_size(150,300); + } + else + { + flip = 0; + lb.set_size(200,300); + } + } + + + button b; + label btn_count; + button btn_get_clipboard; + button btn_set_clipboard; + button btn_show_message; + int button_count; + color_box cb1; + color_box cb2; + label cbl; + check_box cbox; + int count; + int flip; + widget_group group1; + widget_group group2; + widget_group group3; + int keyboard_count; + label keydown; + label keyup; + label l1; + label l2; + label l3; + list_box lb; + label leave_count; + label left_down; + label left_up; + label middle_down; + label middle_up; + label mouse_state; + mouse_tracker mt; + named_rectangle nrect; + label pos; + radio_button rb; + label right_down; + label right_up; + scroll_bar sb2; + scroll_bar sb3; + check_box sb_enabled; + label sbl2; + label sbl3; + label sbl; + check_box sb_shown; + scroll_bar sb; + int scroll_count; + label scroll; + label tab_label; + tabbed_display tabs; + text_field tf; + text_box tb; + menu_bar mbar; + popup_menu submenu; + +}; + + +win w; + +int main() +{ + + try + { + + image_window win; + + array2d img; + img.set_size(100,100); + assign_all_pixels(img,0); + + fill_rect(img, rectangle(1,1,1,1), 255); + fill_rect(img, rectangle(1,3,2,5), 255); + fill_rect(img, rectangle(4,3,5,4), 255); + fill_rect(img, rectangle(9,9,13,10), 255); + + win.set_image(img); + + win.add_overlay(image_display::overlay_rect(rectangle(1,1,1,1), rgb_pixel(255,0,0))); + win.add_overlay(image_display::overlay_rect(rectangle(1,3,2,5), rgb_pixel(255,0,0))); + win.add_overlay(image_display::overlay_rect(rectangle(4,3,5,4), rgb_pixel(255,0,0))); + win.add_overlay(image_display::overlay_rect(rectangle(9,9,13,10), rgb_pixel(255,0,0))); + + + + w.set_pos (100,200); + w.set_title("test window"); + w.show(); + + w.wait_until_closed(); + } + catch (exception& e) + { + cout << e.what() << endl; + } + +} diff --git a/ml/dlib/dlib/test/hash.cpp b/ml/dlib/dlib/test/hash.cpp new file mode 100644 index 000000000..94930e629 --- /dev/null +++ b/ml/dlib/dlib/test/hash.cpp @@ -0,0 +1,369 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.hash"); + + + template + void to_little ( + std::vector& item + ) + { + byte_orderer bo; + for (unsigned long i = 0; i < item.size(); ++i) + bo.host_to_little(item[i]); + } + + + template + void to_little ( + matrix& item + ) + { + byte_orderer bo; + for (long r = 0; r < item.nr(); ++r) + { + for (long c = 0; c < item.nc(); ++c) + { + bo.host_to_little(item(r,c)); + } + } + } + + // Run the official test for MurmurHash3 + void murmur_hash_test() + { + uint8 key[256]; + uint32 hashes[256]; + uint32 final = 0; + + memset(key,0,sizeof(key)); + memset(hashes,0,sizeof(hashes)); + + // Hash keys of the form {0}, {0,1}, {0,1,2}... up to N=255,using 256-N as + // the seed. + for(int i = 0; i < 256; i++) + { + key[i] = (uint8)i; + + hashes[i] = murmur_hash3(key,i,256-i); + } + + byte_orderer bo; + bo.host_to_little(hashes); + final = murmur_hash3(hashes,sizeof(hashes),0); + + // using ostringstream to avoid compiler error in visual studio 2005 + ostringstream sout; + sout << hex << final; + dlog << LINFO << "final: "<< sout.str(); + DLIB_TEST(final == 0xB0F57EE3); + } + + void murmur_hash_128_test() + { + uint8 key[256]; + uint64 hashes[256*2]; + uint32 final = 0; + + memset(key,0,sizeof(key)); + memset(hashes,0,sizeof(hashes)); + + // Hash keys of the form {0}, {0,1}, {0,1,2}... up to N=255,using 256-N as + // the seed. + for(int i = 0; i < 256; i++) + { + key[i] = (uint8)i; + + const std::pair temp = murmur_hash3_128bit(key,i,256-i); + hashes[2*i] = temp.first; + hashes[2*i+1] = temp.second; + } + + byte_orderer bo; + bo.host_to_little(hashes); + final = static_cast(murmur_hash3_128bit(hashes,sizeof(hashes),0).first); + + // using ostringstream to avoid compiler error in visual studio 2005 + ostringstream sout; + sout << hex << final; + dlog << LINFO << "final 64: "<< sout.str(); + DLIB_TEST(final == 0x6384BA69); + } + + void test_murmur_hash_128_4() + { + byte_orderer bo; + dlib::rand rnd; + for (int i = 0; i < 100; ++i) + { + uint32 buf[4] = { rnd.get_random_32bit_number(), + rnd.get_random_32bit_number(), + rnd.get_random_32bit_number(), + rnd.get_random_32bit_number() + }; + + bo.host_to_little(buf); + + std::pair temp1, temp2; + + // Make sure the 4 integer version of murmur hash does the same thing + // as the memory block version. + temp1 = murmur_hash3_128bit(buf, sizeof(buf), 0); + temp2 = murmur_hash3_128bit(buf[0], buf[1], buf[2], buf[3]); + DLIB_TEST( temp1.first == temp2.first); + DLIB_TEST( temp1.second == temp2.second); + } + } + + void test_murmur_hash_128_3() + { + byte_orderer bo; + dlib::rand rnd; + for (int i = 0; i < 100; ++i) + { + uint64 buf[2] = { rnd.get_random_64bit_number(), + rnd.get_random_64bit_number(), + }; + + const uint32 seed = rnd.get_random_32bit_number(); + + bo.host_to_little(buf); + std::pair temp1, temp2; + + // Make sure the 3 integer version of murmur hash does the same thing + // as the memory block version. + temp1 = murmur_hash3_128bit(buf, sizeof(buf), seed); + temp2 = murmur_hash3_128bit_3(buf[0], buf[1], seed); + DLIB_TEST( temp1.first == temp2.first); + DLIB_TEST( temp1.second == temp2.second); + } + } + + void test_murmur_hash_64_2() + { + byte_orderer bo; + dlib::rand rnd; + for (int i = 0; i < 100; ++i) + { + uint32 val = rnd.get_random_32bit_number(); + const uint32 seed = rnd.get_random_32bit_number(); + + + bo.host_to_little(val); + uint32 temp1, temp2; + + // Make sure the 2 integer version of murmur hash does the same thing + // as the memory block version. + temp1 = murmur_hash3(&val, sizeof(val), seed); + temp2 = murmur_hash3_2(val, seed); + DLIB_TEST(temp1 == temp2); + } + } + + void test_murmur_hash_64_3() + { + byte_orderer bo; + dlib::rand rnd; + for (int i = 0; i < 100; ++i) + { + uint32 buf[2] = {rnd.get_random_32bit_number(), + rnd.get_random_32bit_number()}; + const uint32 seed = rnd.get_random_32bit_number(); + + + bo.host_to_little(buf); + uint32 temp1, temp2; + + // Make sure the 2 integer version of murmur hash does the same thing + // as the memory block version. + temp1 = murmur_hash3(&buf, sizeof(buf), seed); + temp2 = murmur_hash3_3(buf[0], buf[1], seed); + DLIB_TEST(temp1 == temp2); + } + } + +// ---------------------------------------------------------------------------------------- + + uint64 slow_count_bits ( uint64 v) + { + uint64 count = 0; + for (int i = 0; i < 64; ++i) + { + if (v&1) + ++count; + v >>= 1; + } + return count; + } + + + uint32 slow_count_bits ( uint32 v) + { + uint32 count = 0; + for (int i = 0; i < 32; ++i) + { + if (v&1) + ++count; + v >>= 1; + } + return count; + } + + +// ---------------------------------------------------------------------------------------- + + void test_hamming_stuff() + { + dlib::rand rnd; + for (int i = 0; i < 10000; ++i) + { + uint32 v = rnd.get_random_32bit_number(); + uint64 v2 = rnd.get_random_64bit_number(); + DLIB_TEST(slow_count_bits(v) == count_bits(v)); + DLIB_TEST(slow_count_bits(v2) == count_bits(v2)); + } + + DLIB_TEST(hamming_distance((uint32)0x1F, (uint32)0x0F) == 1); + DLIB_TEST(hamming_distance((uint32)0x1F, (uint32)0x1F) == 0); + DLIB_TEST(hamming_distance((uint32)0x1F, (uint32)0x19) == 2); + DLIB_TEST(hamming_distance((uint32)0x2F, (uint32)0x19) == 4); + } + +// ---------------------------------------------------------------------------------------- + + class test_hash : public tester + { + public: + test_hash ( + ) : + tester ("test_hash", + "Runs tests on the hash routines.") + {} + + void perform_test ( + ) + { + print_spinner(); + + test_hamming_stuff(); + + murmur_hash_test(); + murmur_hash_128_test(); + + std::string str1 = "some random string"; + matrix mat(2,2); + + mat = 1,2,3,4; + + matrix mat2(2,3); + + mat2 = 1,2,3,4,5,6; + + to_little(mat2); + + std::vector v(4); + v[0] = 'c'; + v[1] = 'a'; + v[2] = 't'; + v[3] = '!'; + + std::vector v2(4); + v[0] = 'c'; + v[1] = 'a'; + v[2] = 't'; + v[3] = '!'; + to_little(v2); + + std::map m; + m['c'] = 'C'; + m['a'] = 'A'; + m['t'] = 'T'; + + dlog << LINFO << "hash(str1): "<< dlib::hash(str1); + dlog << LINFO << "hash(v): "<< dlib::hash(v); + dlog << LINFO << "hash(v2): "<< dlib::hash(v2); + dlog << LINFO << "hash(m): "<< dlib::hash(m); + dlog << LINFO << "hash(mat): "<< dlib::hash(mat); + dlog << LINFO << "hash(mat2): "<< dlib::hash(mat2); + + uint32 ui1 = 123485393; + uint64 ui2 = ui1; + ui2 *= ui2; + ui2 *= ui2; + dlog << LINFO << "hash(ui1): "<< dlib::hash(ui1); + dlog << LINFO << "hash(ui2): "<< dlib::hash(ui2); + dlog << LINFO << "hash(make_pair(ui2,ui1)): "<< dlib::hash(make_pair(ui2,ui1)); + dlog << LINFO << "hash(make_pair(ui2,ui2)): "<< dlib::hash(make_pair(ui2,ui2)); + dlog << LINFO << "hash(make_pair(ui1,ui1)): "<< dlib::hash(make_pair(ui1,ui1)); + dlog << LINFO << "hash(ui1,3): "<< dlib::hash(ui1,3); + dlog << LINFO << "hash(ui2,3): "<< dlib::hash(ui2,3); + dlog << LINFO << "hash(make_pair(ui2,ui1),3): "<< dlib::hash(make_pair(ui2,ui1),3); + dlog << LINFO << "hash(make_pair(ui2,ui2),3): "<< dlib::hash(make_pair(ui2,ui2),3); + dlog << LINFO << "hash(make_pair(ui1,ui1),3): "<< dlib::hash(make_pair(ui1,ui1),3); + + DLIB_TEST(dlib::hash(ui1) == 0x63e272e4); + DLIB_TEST(dlib::hash(ui2) == 0xaf55561a); + DLIB_TEST(dlib::hash(make_pair(ui2,ui1)) == 0x52685376); + DLIB_TEST(dlib::hash(make_pair(ui2,ui2)) == 0xd25d6929); + DLIB_TEST(dlib::hash(make_pair(ui1,ui1)) == 0xeea3b63e); + DLIB_TEST(dlib::hash(ui1,3) == 0x95d1c4c0); + DLIB_TEST(dlib::hash(ui2,3) == 0x6ada728d); + DLIB_TEST(dlib::hash(make_pair(ui2,ui1),3) == 0x2f72a0ff); + DLIB_TEST(dlib::hash(make_pair(ui2,ui2),3) == 0xac1407f0); + DLIB_TEST(dlib::hash(make_pair(ui1,ui1),3) == 0x39ad637a); + + + DLIB_TEST(dlib::hash(str1) == 0x3ffe6bf6); + DLIB_TEST(dlib::hash(v) == 0xf1af2ca6); + DLIB_TEST(dlib::hash(v2) == 0x63852afc); + DLIB_TEST(dlib::hash(m) == 0xaacc3f6f); + DLIB_TEST(dlib::hash(mat) == 0x3e349da5); + DLIB_TEST(dlib::hash(mat2) == 0x3a95dc52); + DLIB_TEST(murmur_hash3(&str1[0], str1.size(), 0) == 0x3ffe6bf6); + + dlog << LINFO << "hash(str1,1): "<< dlib::hash(str1,1); + dlog << LINFO << "hash(v,3): "<< dlib::hash(v,3); + dlog << LINFO << "hash(v2,3): "<< dlib::hash(v2,3); + dlog << LINFO << "hash(m,4): "<< dlib::hash(m,4); + dlog << LINFO << "hash(mat,5): "<< dlib::hash(mat,5); + dlog << LINFO << "hash(mat2,6): "<< dlib::hash(mat2,6); + + DLIB_TEST(dlib::hash(str1,1) == 0xb17cea93); + DLIB_TEST(dlib::hash(v,3) == 0x7ec9284c); + DLIB_TEST(dlib::hash(v2,3) == 0xb2ce147f); + DLIB_TEST(dlib::hash(m,4) == 0xfa5e7ac2); + DLIB_TEST(dlib::hash(mat,5) == 0x8de27259); + DLIB_TEST(dlib::hash(mat2,6) == 0xb8aa7714); + DLIB_TEST(murmur_hash3(&str1[0], str1.size(), 1) == 0xb17cea93); + + test_murmur_hash_128_4(); + test_murmur_hash_128_3(); + test_murmur_hash_64_2(); + test_murmur_hash_64_3(); + } + } a; + + + +} + + + diff --git a/ml/dlib/dlib/test/hash_map.cpp b/ml/dlib/dlib/test/hash_map.cpp new file mode 100644 index 000000000..09af09936 --- /dev/null +++ b/ml/dlib/dlib/test/hash_map.cpp @@ -0,0 +1,450 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.hash_map"); + + template < + typename hash_map + > + void hash_map_kernel_test ( + ) + /*! + requires + - hash_map is an implementation of hash_map/hash_map_kernel_abstract.h and + is instantiated to map int to int + ensures + - runs tests on hash_map for compliance with the specs + !*/ + { + + srand(static_cast(time(0))); + + print_spinner(); + + + hash_map test, test2; + + enumerable >& e = test; + DLIB_TEST(e.at_start() == true); + + for (int j = 0; j < 4; ++j) + { + print_spinner(); + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + + + int a,b; + a = 8; + b = 94; + test.add(a,b); + DLIB_TEST(test.size() == 1); + DLIB_TEST(test.is_in_domain(8) == true); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + DLIB_TEST(test[8] == 94); + a = 53; + b = 4; + test.add(a,b); + DLIB_TEST(test.size() == 2); + DLIB_TEST(test.is_in_domain(53) == true); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + DLIB_TEST(test[53] == 4); + + + swap(test,test2); + + + DLIB_TEST_MSG(test2.size() == 2,test2.size()); + DLIB_TEST(test2.is_in_domain(8) == true); + DLIB_TEST(test2.is_in_domain(5) == false); + DLIB_TEST(test2.is_in_domain(0) == false); + DLIB_TEST(test2.is_in_domain(-999) == false); + DLIB_TEST(test2.is_in_domain(4999) == false); + DLIB_TEST(test2[8] == 94); + DLIB_TEST(test2.size() == 2); + DLIB_TEST(test2.is_in_domain(53) == true); + DLIB_TEST(test2.is_in_domain(5) == false); + DLIB_TEST(test2.is_in_domain(0) == false); + DLIB_TEST(test2.is_in_domain(-999) == false); + DLIB_TEST(test2.is_in_domain(4999) == false); + DLIB_TEST(test2[53] == 4); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_in_domain(8) == false); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_in_domain(53) == false); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + + + test.clear(); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + + + DLIB_TEST(test.size() == 0); + + while (test.size() < 10000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + DLIB_TEST(test.size() == 10000); + test.clear(); + DLIB_TEST(test.size() == 0); + + while (test.size() < 10000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + DLIB_TEST(test.size() == 10000); + + int count = 0; + while (test.move_next()) + { + DLIB_TEST(test.element().key() == test.element().key()); + DLIB_TEST(test.element().value() == test.element().value()); + DLIB_TEST(test.element().key() == test.element().key()); + DLIB_TEST(test.element().value() == test.element().value()); + + + + ++count; + } + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.move_next() == false); + + DLIB_TEST(count == 10000); + + test.swap(test2); + + DLIB_TEST(test.size() == 2); + DLIB_TEST(test2.size() == 10000); + count = 0; + test2.reset(); + + test2.move_next(); + test2.element().value() = 99; + DLIB_TEST(test2[test2.element().key()] == 99); + DLIB_TEST(test2.element().value() == 99); + + test2.reset(); + + while (test2.move_next()) + { + DLIB_TEST(test2[test2.element().key()] == test2.element().value()); + DLIB_TEST(test2.element().key() == test2.element().key()); + DLIB_TEST(test2.element().value() == test2.element().value()); + DLIB_TEST(test2.element().key() == test2.element().key()); + DLIB_TEST(test2.element().value() == test2.element().value()); + + ++count; + } + DLIB_TEST(test2.size() == 10000); + DLIB_TEST(count == 10000); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.move_next() == false); + + + + test2.clear(); + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.at_start() == true); + + while (test.size() < 20000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + DLIB_TEST(test.at_start() == true); + + { + int* array1 = new int[test.size()]; + int* array2 = new int[test.size()]; + + int* tmp1 = array1; + int* tmp2 = array2; + + + + // serialize the state of test, then clear test, then + // load the state back into test. + ostringstream sout; + serialize(test,sout); + DLIB_TEST(test.at_start() == true); + istringstream sin(sout.str()); + test.clear(); + deserialize(test,sin); + DLIB_TEST(test.at_start() == true); + + + count = 0; + while (test.move_next()) + { + DLIB_TEST(test.element().key() == test.element().key()); + DLIB_TEST(test.element().value() == test.element().value()); + DLIB_TEST(test.element().key() == test.element().key()); + DLIB_TEST(test.current_element_valid() == true); + *tmp1 = test.element().key(); + *tmp2 = test.element().value(); + ++tmp1; + ++tmp2; + ++count; + } + DLIB_TEST(count == 20000); + + tmp1 = array1; + tmp2 = array2; + for (int i = 0; i < 20000; ++i) + { + DLIB_TEST(test.is_in_domain(*tmp1) == true); + DLIB_TEST(test[*tmp1] == *tmp2); + ++tmp1; + ++tmp2; + } + + DLIB_TEST(test.size() == 20000); + + tmp1 = array1; + tmp2 = array2; + count = 0; + while (test.size() > 10000) + { + test.remove(*tmp1,a,b); + DLIB_TEST(*tmp1 == a); + DLIB_TEST(*tmp2 == b); + ++tmp1; + ++tmp2; + ++count; + } + DLIB_TEST(count == 10000); + DLIB_TEST(test.size() == 10000); + + while (test.move_next()) + { + DLIB_TEST(test.element().key() == *tmp1); + DLIB_TEST(test.element().key() == *tmp1); + DLIB_TEST(test.element().key() == *tmp1); + DLIB_TEST(test.element().value() == *tmp2); + DLIB_TEST(test.element().value() == *tmp2); + DLIB_TEST(test.element().value() == *tmp2); + ++tmp1; + ++tmp2; + ++count; + } + DLIB_TEST(count == 20000); + DLIB_TEST(test.size() == 10000); + + while (test.size() < 20000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + test2.swap(test); + + count = 0; + while (test2.move_next()) + { + DLIB_TEST(test2.element().key() == test2.element().key()); + DLIB_TEST(test2.element().value() == test2.element().value()); + DLIB_TEST(test2.element().key() == test2.element().key()); + + ++count; + } + + DLIB_TEST(count == 20000); + DLIB_TEST(test2.size() == 20000); + + int c = 0; + while (test2.size()>0) + { + test2.remove_any(b,c); + + } + + DLIB_TEST(test2.size() == 0); + delete [] array1; + delete [] array2; + } + + test.clear(); + test2.clear(); + while (test.size() < 10000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + count = 0; + while (test.move_next()) + { + + DLIB_TEST(test[test.element().key()] == test.element().value()); + + ++count; + if (count == 5000) + break; + DLIB_TEST(test.current_element_valid() == true); + } + + test.reset(); + + count = 0; + + while (test.move_next()) + { + + ++count; + DLIB_TEST(test.current_element_valid() == true); + } + + DLIB_TEST(count == 10000); + + + test.clear(); + test2.clear(); + } + + + + + { + test.clear(); + DLIB_TEST(test.size() == 0); + int a = 5; + int b = 6; + test.add(a,b); + a = 7; + b = 8; + test.add(a,b); + DLIB_TEST(test.size() == 2); + DLIB_TEST(test[7] == 8); + DLIB_TEST(test[5] == 6); + DLIB_TEST(test.is_in_domain(7)); + DLIB_TEST(test.is_in_domain(5)); + test.destroy(7); + DLIB_TEST(test.size() == 1); + DLIB_TEST(!test.is_in_domain(7)); + DLIB_TEST(test.is_in_domain(5)); + test.destroy(5); + DLIB_TEST(test.size() == 0); + DLIB_TEST(!test.is_in_domain(7)); + DLIB_TEST(!test.is_in_domain(5)); + } + + + + } + + + + + + class hash_map_tester : public tester + { + public: + hash_map_tester ( + ) : + tester ("test_hash_map", + "Runs tests on the hash_map component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + hash_map_kernel_test::kernel_1a>(); + + dlog << LINFO << "testing kernel_1b_c"; + hash_map_kernel_test::kernel_1a_c>(); + + dlog << LINFO << "testing kernel_1b"; + hash_map_kernel_test::kernel_1b>(); + + dlog << LINFO << "testing kernel_1a_c"; + hash_map_kernel_test::kernel_1b_c>(); + + dlog << LINFO << "testing kernel_1c"; + hash_map_kernel_test::kernel_1c>(); + + dlog << LINFO << "testing kernel_1c_c"; + hash_map_kernel_test::kernel_1c_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/hash_set.cpp b/ml/dlib/dlib/test/hash_set.cpp new file mode 100644 index 000000000..02b665bdb --- /dev/null +++ b/ml/dlib/dlib/test/hash_set.cpp @@ -0,0 +1,387 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.hash_set"); + + template < + typename hash_set + > + void hash_set_kernel_test ( + ) + /*! + requires + - hash_set is an implementation of hash_set/hash_set_kernel_abstract.h and + is instantiated with int + ensures + - runs tests on hash_set for compliance with the specs + !*/ + { + + + srand(static_cast(time(0))); + + + print_spinner(); + + hash_set test, test2; + + + enumerable& e = test; + DLIB_TEST(e.at_start() == true); + + + for (int j = 0; j < 4; ++j) + { + print_spinner(); + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + + + int a,b = 0; + a = 8; + test.add(a); + DLIB_TEST(test.size() == 1); + DLIB_TEST(test.is_member(8) == true); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + a = 53; + test.add(a); + DLIB_TEST(test.size() == 2); + DLIB_TEST(test.is_member(53) == true); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + + + swap(test,test2); + + + + DLIB_TEST(test2.is_member(8) == true); + DLIB_TEST(test2.is_member(5) == false); + DLIB_TEST(test2.is_member(0) == false); + DLIB_TEST(test2.is_member(-999) == false); + DLIB_TEST(test2.is_member(4999) == false); + DLIB_TEST(test2.size() == 2); + DLIB_TEST(test2.is_member(53) == true); + DLIB_TEST(test2.is_member(5) == false); + DLIB_TEST(test2.is_member(0) == false); + DLIB_TEST(test2.is_member(-999) == false); + DLIB_TEST(test2.is_member(4999) == false); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_member(8) == false); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_member(53) == false); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + + + test.clear(); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + + + DLIB_TEST(test.size() == 0); + + while (test.size() < 10000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + DLIB_TEST(test.size() == 10000); + test.clear(); + DLIB_TEST(test.size() == 0); + + while (test.size() < 10000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + DLIB_TEST(test.size() == 10000); + + int count = 0; + while (test.move_next()) + { + DLIB_TEST(test.element() == test.element()); + DLIB_TEST(test.element() == test.element()); + DLIB_TEST(test.element() == test.element()); + + + ++count; + } + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.move_next() == false); + + DLIB_TEST(count == 10000); + + test.swap(test2); + + DLIB_TEST(test.size() == 2); + DLIB_TEST(test2.size() == 10000); + count = 0; + test2.reset(); + while (test2.move_next()) + { + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(test2.element() == test2.element()); + + ++count; + } + DLIB_TEST(test2.size() == 10000); + DLIB_TEST(count == 10000); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.move_next() == false); + + + + test2.clear(); + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.at_start() == true); + + while (test.size() < 20000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + DLIB_TEST(test.at_start() == true); + + { + int* array = new int[test.size()]; + int* tmp = array; + + // serialize the state of test, then clear test, then + // load the state back into test. + ostringstream sout; + serialize(test,sout); + DLIB_TEST(test.at_start() == true); + istringstream sin(sout.str()); + test.clear(); + deserialize(test,sin); + + + + count = 0; + while (test.move_next()) + { + DLIB_TEST(test.element() == test.element()); + DLIB_TEST(test.element() == test.element()); + DLIB_TEST(test.element() == test.element()); + *tmp = test.element(); + ++tmp; + ++count; + } + DLIB_TEST(count == 20000); + + tmp = array; + for (int i = 0; i < 20000; ++i) + { + DLIB_TEST(test.is_member(*tmp) == true); + ++tmp; + } + + DLIB_TEST(test.size() == 20000); + + tmp = array; + count = 0; + while (test.size() > 10000) + { + test.remove(*tmp,a); + DLIB_TEST(*tmp == a); + ++tmp; + ++count; + } + DLIB_TEST(count == 10000); + DLIB_TEST(test.size() == 10000); + + while (test.move_next()) + { + ++count; + } + DLIB_TEST(count == 20000); + DLIB_TEST(test.size() == 10000); + + while (test.size() < 20000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + test2.swap(test); + + count = 0; + while (test2.move_next()) + { + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(test2.element() == test2.element()); + + ++count; + } + + DLIB_TEST(count == 20000); + DLIB_TEST(test2.size() == 20000); + + + while (test2.size()>0) + { + test2.remove_any(b); + } + + DLIB_TEST(test2.size() == 0); + delete [] array; + } + + test.clear(); + test2.clear(); + while (test.size() < 10000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + count = 0; + while (test.move_next()) + { + ++count; + if (count == 5000) + break; + DLIB_TEST(test.current_element_valid() == true); + } + + test.reset(); + + count = 0; + while (test.move_next()) + { + ++count; + DLIB_TEST(test.current_element_valid() == true); + } + + DLIB_TEST(count == 10000); + + + test.clear(); + test2.clear(); + } + + + { + test.clear(); + DLIB_TEST(test.size() == 0); + int a = 5; + test.add(a); + a = 7; + test.add(a); + DLIB_TEST(test.size() == 2); + DLIB_TEST(test.is_member(7)); + DLIB_TEST(test.is_member(5)); + test.destroy(7); + DLIB_TEST(test.size() == 1); + DLIB_TEST(!test.is_member(7)); + DLIB_TEST(test.is_member(5)); + test.destroy(5); + DLIB_TEST(test.size() == 0); + DLIB_TEST(!test.is_member(7)); + DLIB_TEST(!test.is_member(5)); + } + + } + + + + + class hash_set_tester : public tester + { + public: + hash_set_tester ( + ) : + tester ("test_hash_set", + "Runs tests on the hash_set component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + hash_set_kernel_test::kernel_1a>(); + dlog << LINFO << "testing kernel_1a_c"; + hash_set_kernel_test::kernel_1a_c>(); + dlog << LINFO << "testing kernel_1b"; + hash_set_kernel_test::kernel_1b>(); + dlog << LINFO << "testing kernel_1b_c"; + hash_set_kernel_test::kernel_1b_c>(); + dlog << LINFO << "testing kernel_1c"; + hash_set_kernel_test::kernel_1c>(); + dlog << LINFO << "testing kernel_1c_c"; + hash_set_kernel_test::kernel_1c_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/hash_table.cpp b/ml/dlib/dlib/test/hash_table.cpp new file mode 100644 index 000000000..f4754835e --- /dev/null +++ b/ml/dlib/dlib/test/hash_table.cpp @@ -0,0 +1,663 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.hash_table"); + + template < + typename hash_table + > + void hash_table_kernel_test ( + ) + /*! + requires + - hash_table is an implementation of hash_table/hash_table_kernel_abstract.h + and is instantiated to map ints to ints + ensures + - runs tests on hash_table for compliance with the specs + !*/ + { + + srand(static_cast(time(0))); + + + + + { + hash_table test(16); + + DLIB_TEST(test.count(3) == 0); + + enumerable >& e = test; + DLIB_TEST(e.at_start() == true); + + hash_table test2(16); + + hash_table test3(0); + hash_table test4(0); + + + print_spinner(); + + int b; + for (int j = 0; j < 4; ++j) + { + int a = 4; + b = 5; + test2.add(a,b); + DLIB_TEST(test2.size() == 1); + DLIB_TEST(*test2[4] == 5); + DLIB_TEST(test2[99] == 0); + + DLIB_TEST(test2.move_next()); + DLIB_TEST(test2.element().key() == 4); + DLIB_TEST(test2.element().value() == 5); + + swap(test,test2); + DLIB_TEST(test.size() == 1); + DLIB_TEST(*test[4] == 5); + DLIB_TEST(test[99] == 0); + + test.swap(test2); + + a = 99; + b = 35; + test2.add(a,b); + DLIB_TEST(test2.size() == 2); + DLIB_TEST(*test2[4] == 5); + DLIB_TEST(*test2[99] == 35); + DLIB_TEST(test2[99] != 0); + DLIB_TEST(test2[949] == 0); + + test2.destroy(4); + DLIB_TEST(test2.size() == 1); + DLIB_TEST(test2[4] == 0); + DLIB_TEST(*test2[99] == 35); + DLIB_TEST(test2[99] != 0); + DLIB_TEST(test2[949] == 0); + + + + test2.destroy(99); + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2[4] == 0); + DLIB_TEST(test2[99] == 0); + DLIB_TEST(test2[949] == 0); + + + + test2.clear(); + } + + + print_spinner(); + + + + + for (int j = 0; j < 4; ++j) + { + + DLIB_TEST(test.count(3) == 0); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + + int a; + + for (int i = 0; i < 10000; ++i) + { + a = ::rand()%1000; + int temp = a; + unsigned long count = test.count(a); + test.add(a,b); + DLIB_TEST(test.count(temp) == count+1); + } + + { + unsigned long count = test.count(3); + + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + } + + + test.clear(); + + + for (int i = 0; i < 10000; ++i) + { + a = b = i; + unsigned long count = test.count(a); + test.add(a,b); + DLIB_TEST(test.count(i) == count+1); + } + + DLIB_TEST(test.size() == 10000); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == true); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.current_element_valid() == true); + + + test.reset(); + + DLIB_TEST(test.size() == 10000); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + + if (test.size() > 0) + { + int* array = new int[test.size()]; + int* tmp = array; + + int count = 0; + while (test.move_next()) + { + ++count; + *tmp = test.element().key(); + DLIB_TEST(test[*tmp] != 0); + DLIB_TEST(*tmp == test.element().key()); + DLIB_TEST(*tmp == test.element().value()); + DLIB_TEST(*tmp == test.element().key()); + DLIB_TEST(test.current_element_valid() == true); + ++tmp; + } + + DLIB_TEST(count == 10000); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + DLIB_TEST(test.size() == 10000); + + swap(test,test2); + + + + + // serialize the state of test2, then clear test2, then + // load the state back into test2. + ostringstream sout; + serialize(test2,sout); + DLIB_TEST(test2.at_start() == true); + istringstream sin(sout.str()); + test2.clear(); + deserialize(test2,sin); + DLIB_TEST(test2.at_start() == true); + + + + + tmp = array; + for (int i = 0; i < 10000; ++i) + { + DLIB_TEST(*test2[*tmp] == *tmp); + DLIB_TEST(*test2[*tmp] == *tmp); + DLIB_TEST(*test2[*tmp] == *tmp); + ++tmp; + } + + test2.swap(test); + test.reset(); + + DLIB_TEST(test.at_start() == true); + count = 0; + tmp = array; + while (test.size() > 0) + { + test.remove(*tmp,a,b); + + ++tmp; + ++count; + } + + DLIB_TEST(count == 10000); + DLIB_TEST(test.size() == 0); + + + + DLIB_TEST(count == 10000); + + + + + + + + delete [] array; + } + + test.move_next(); + + for (int i = 0; i < 10000; ++i) + { + a = ::rand(); + test.add(a,b); + } + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.move_next() == true); + + DLIB_TEST(test.size() == 10000); + + for (int i = 0; i < 10000; ++i) + { + test.remove_any(a,b); + } + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.size() == 0); + + test.clear(); + + + + + + + + + + int* dtmp = new int[10000]; + int* rtmp = new int[10000]; + + int* d = dtmp; + int* r = rtmp; + for (unsigned long i = 0; i < 10000; ++i) + { + a = ::rand(); + b = ::rand(); + *d = a; + *r = b; + if (test[a] != 0) + { + --i; + continue; + } + test.add(a,b); + ++d; + ++r; + DLIB_TEST(test.size() == i+1); + } + + DLIB_TEST(test.size() == 10000); + + for (int i = 0; i < 10000; ++i) + { + DLIB_TEST(*test[dtmp[i]] == rtmp[i]); + } + + + delete [] dtmp; + delete [] rtmp; + + test.clear(); + }} + + + print_spinner(); + + + + + + + + + + + + + + + + + + + + + + + + + // now do the same thing as above but with a much smaller hash table + { + hash_table test(13); + + DLIB_TEST(test.count(3) == 0); + + enumerable >& e = test; + DLIB_TEST(e.at_start() == true); + + hash_table test2(16); + + hash_table test3(0); + hash_table test4(0); + + + int b; + for (int j = 0; j < 4; ++j) + { + int a = 4; + b = 5; + test2.add(a,b); + DLIB_TEST(test2.size() == 1); + DLIB_TEST(*test2[4] == 5); + DLIB_TEST(test2[99] == 0); + + + DLIB_TEST(test2.move_next()); + DLIB_TEST(test2.element().key() == 4); + DLIB_TEST(test2.element().value() == 5); + + swap(test,test2); + DLIB_TEST(test.size() == 1); + DLIB_TEST(*test[4] == 5); + DLIB_TEST(test[99] == 0); + + test.swap(test2); + + a = 99; + b = 35; + test2.add(a,b); + DLIB_TEST(test2.size() == 2); + DLIB_TEST(*test2[4] == 5); + DLIB_TEST(*test2[99] == 35); + DLIB_TEST(test2[99] != 0); + DLIB_TEST(test2[949] == 0); + + test2.destroy(4); + DLIB_TEST(test2.size() == 1); + DLIB_TEST(test2[4] == 0); + DLIB_TEST(*test2[99] == 35); + DLIB_TEST(test2[99] != 0); + DLIB_TEST(test2[949] == 0); + + + + test2.destroy(99); + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2[4] == 0); + DLIB_TEST(test2[99] == 0); + DLIB_TEST(test2[949] == 0); + + + + test2.clear(); + } + + + print_spinner(); + + + + + for (int j = 0; j < 4; ++j) + { + + DLIB_TEST(test.count(3) == 0); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + + int a; + + for (int i = 0; i < 10000; ++i) + { + a = ::rand()%1000; + int temp = a; + unsigned long count = test.count(a); + test.add(a,b); + DLIB_TEST(test.count(temp) == count+1); + } + + { + unsigned long count = test.count(3); + + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + a = 3; test.add(a,b); ++count; + DLIB_TEST(test.count(3) == count); + } + + + test.clear(); + + + for (int i = 0; i < 10000; ++i) + { + a = b = i; + unsigned long count = test.count(a); + test.add(a,b); + DLIB_TEST(test.count(i) == count+1); + } + + DLIB_TEST(test.size() == 10000); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == true); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.current_element_valid() == true); + + + test.reset(); + + DLIB_TEST(test.size() == 10000); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + + if (test.size() > 0) + { + int* array = new int[test.size()]; + int* tmp = array; + + int count = 0; + while (test.move_next()) + { + ++count; + *tmp = test.element().key(); + DLIB_TEST(test[*tmp] != 0); + DLIB_TEST(*tmp == test.element().key()); + DLIB_TEST(*tmp == test.element().value()); + DLIB_TEST(*tmp == test.element().key()); + DLIB_TEST(test.current_element_valid() == true); + ++tmp; + } + + DLIB_TEST(count == 10000); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + DLIB_TEST(test.size() == 10000); + + swap(test,test2); + + tmp = array; + for (int i = 0; i < 10000; ++i) + { + DLIB_TEST(*test2[*tmp] == *tmp); + DLIB_TEST(*test2[*tmp] == *tmp); + DLIB_TEST(*test2[*tmp] == *tmp); + ++tmp; + } + + test2.swap(test); + test.reset(); + + DLIB_TEST(test.at_start() == true); + count = 0; + tmp = array; + while (test.size() > 0) + { + test.remove(*tmp,a,b); + + ++tmp; + ++count; + } + + DLIB_TEST(count == 10000); + DLIB_TEST(test.size() == 0); + + + + DLIB_TEST(count == 10000); + + + + + + + + delete [] array; + } + + test.move_next(); + + for (int i = 0; i < 10000; ++i) + { + a = ::rand(); + test.add(a,b); + } + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.move_next() == true); + + DLIB_TEST(test.size() == 10000); + + for (int i = 0; i < 10000; ++i) + { + test.remove_any(a,b); + } + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.size() == 0); + + test.clear(); + + + + + + + + + int* dtmp = new int[10000]; + int* rtmp = new int[10000]; + + int* d = dtmp; + int* r = rtmp; + for (unsigned long i = 0; i < 10000; ++i) + { + a = ::rand(); + b = ::rand(); + *d = a; + *r = b; + if (test[a] != 0) + { + --i; + continue; + } + test.add(a,b); + ++d; + ++r; + DLIB_TEST(test.size() == i+1); + } + + DLIB_TEST(test.size() == 10000); + + for (int i = 0; i < 10000; ++i) + { + DLIB_TEST(*test[dtmp[i]] == rtmp[i]); + } + + + delete [] dtmp; + delete [] rtmp; + + test.clear(); + }} + + } + + + + + class hash_table_tester : public tester + { + public: + hash_table_tester ( + ) : + tester ("test_hash_table", + "Runs tests on the hash_table component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + hash_table_kernel_test::kernel_1a> (); + dlog << LINFO << "testing kernel_1a_c"; + hash_table_kernel_test::kernel_1a_c>(); + dlog << LINFO << "testing kernel_2a"; + hash_table_kernel_test::kernel_2a> (); + dlog << LINFO << "testing kernel_2a_c"; + hash_table_kernel_test::kernel_2a_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/hog_image.cpp b/ml/dlib/dlib/test/hog_image.cpp new file mode 100644 index 000000000..615e4d2ff --- /dev/null +++ b/ml/dlib/dlib/test/hog_image.cpp @@ -0,0 +1,126 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.hog_image"); + +// ---------------------------------------------------------------------------------------- + + class test_hog_image : public tester + { + public: + test_hog_image ( + ) : + tester ("test_hog_image", + "Runs tests on the hog_image object.") + {} + + void perform_test ( + ) + { + print_spinner(); + array2d img; + img.set_size(200,200); + + assign_all_pixels(img, 0); + + hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> hog1, hog1_deserialized; + hog_image<4,4,2,4,hog_signed_gradient,hog_full_interpolation> hog2; + + hog1.load(img); + hog2.load(img); + + + // Just test all the coordinate mapping functions. + + DLIB_TEST(hog1.get_block_rect(0,0).width() == 3*3); + DLIB_TEST(hog1.get_block_rect(0,0).height() == 3*3); + DLIB_TEST(hog2.get_block_rect(0,0).width() == 4*4); + DLIB_TEST(hog2.get_block_rect(0,0).height() == 4*4); + + DLIB_TEST(get_rect(img).contains(hog1.get_block_rect(0,0))); + DLIB_TEST(get_rect(img).contains(hog1.get_block_rect(hog1.nr()-1,hog1.nc()-1))); + DLIB_TEST(get_rect(img).contains(hog2.get_block_rect(0,0))); + DLIB_TEST(get_rect(img).contains(hog2.get_block_rect(hog2.nr()-1,hog2.nc()-1))); + + dlib::rand rnd; + for (int i = 0; i < 20000; ++i) + { + point p(rnd.get_random_16bit_number(), rnd.get_random_16bit_number()); + p.x() -= 20000; + p.y() -= 20000; + + DLIB_TEST((hog1.feat_to_image_space(hog1.image_to_feat_space(p)) - p).length() <= 3); + DLIB_TEST((hog2.feat_to_image_space(hog2.image_to_feat_space(p)) - p).length() <= 10); + + DLIB_TEST_MSG((hog1.image_to_feat_space(hog1.feat_to_image_space(p)) - p).length() <= 3, + p << " " << hog1.feat_to_image_space(p) << " " << hog1.image_to_feat_space(hog1.feat_to_image_space(p)) ); + DLIB_TEST((hog2.image_to_feat_space(hog2.feat_to_image_space(p)) - p).length() <= 10); + } + + + DLIB_TEST(hog1.feat_to_image_space(point(0,0)) == point(5,5)); + DLIB_TEST(hog2.feat_to_image_space(point(0,0)) == point(9,9)); + + DLIB_TEST(hog1.feat_to_image_space(point(1,1)) == point(8,8)); + DLIB_TEST(hog2.feat_to_image_space(point(1,1)) == point(17,17)); + + DLIB_TEST(hog1.image_to_feat_space(hog1.feat_to_image_space(point(0,0))) == point(0,0)); + DLIB_TEST(hog2.image_to_feat_space(hog2.feat_to_image_space(point(0,0))) == point(0,0)); + DLIB_TEST(hog1.image_to_feat_space(hog1.feat_to_image_space(point(1,1))) == point(1,1)); + DLIB_TEST(hog2.image_to_feat_space(hog2.feat_to_image_space(point(1,1))) == point(1,1)); + DLIB_TEST(hog1.image_to_feat_space(hog1.feat_to_image_space(point(1,2))) == point(1,2)); + DLIB_TEST(hog2.image_to_feat_space(hog2.feat_to_image_space(point(1,2))) == point(1,2)); + + + + DLIB_TEST(hog1_deserialized.size() != hog1.size()); + DLIB_TEST(hog1_deserialized.nr() != hog1.nr()); + DLIB_TEST(hog1_deserialized.nc() != hog1.nc()); + ostringstream sout; + serialize(hog1, sout); + istringstream sin(sout.str()); + deserialize(hog1_deserialized, sin); + + DLIB_TEST(hog1_deserialized.size() == hog1.size()); + DLIB_TEST(hog1_deserialized.nr() == hog1.nr()); + DLIB_TEST(hog1_deserialized.nc() == hog1.nc()); + DLIB_TEST(hog1_deserialized(0,2) == hog1(0,2)); + DLIB_TEST(hog1_deserialized.get_block_rect(1,2) == hog1.get_block_rect(1,2)); + DLIB_TEST(hog1_deserialized.image_to_feat_space(hog1_deserialized.feat_to_image_space(point(0,0))) == point(0,0)); + DLIB_TEST(hog1_deserialized.image_to_feat_space(hog1_deserialized.feat_to_image_space(point(1,1))) == point(1,1)); + DLIB_TEST(hog1_deserialized.image_to_feat_space(hog1_deserialized.feat_to_image_space(point(1,2))) == point(1,2)); + + + + DLIB_TEST(hog1.size() > 1); + DLIB_TEST(hog1.nr() > 1); + DLIB_TEST(hog1.nc() > 1); + hog1.clear(); + DLIB_TEST(hog1.size() == 0); + DLIB_TEST(hog1.nr() == 0); + DLIB_TEST(hog1.nc() == 0); + } + } a; + +} + + + + diff --git a/ml/dlib/dlib/test/image.cpp b/ml/dlib/dlib/test/image.cpp new file mode 100644 index 000000000..01f1410cf --- /dev/null +++ b/ml/dlib/dlib/test/image.cpp @@ -0,0 +1,1903 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.image"); + + + void image_test ( + ) + /*! + ensures + - runs tests on pixel objects and functions for compliance with the specs + !*/ + { + + print_spinner(); + + array2d img1, img2; + + img1.set_size(100,100); + + assign_all_pixels(img1,7); + + assign_image(img2, img1); + + DLIB_TEST_MSG(img1.nr() == 100 && img1.nc() == 100 && + img2.nr() == 100 && img2.nc() == 100,""); + + + for (long r = 0; r < img1.nr(); ++r) + { + for (long c = 0; c < img1.nc(); ++c) + { + DLIB_TEST(img1[r][c] == 7); + DLIB_TEST(img2[r][c] == 7); + } + } + + img2.clear(); + DLIB_TEST(img2.size() == 0); + DLIB_TEST(img2.nr() == 0); + DLIB_TEST(img2.nc() == 0); + assign_image(img2, mat(img1)); + + DLIB_TEST_MSG(img1.nr() == 100 && img1.nc() == 100 && + img2.nr() == 100 && img2.nc() == 100,""); + + + for (long r = 0; r < img1.nr(); ++r) + { + for (long c = 0; c < img1.nc(); ++c) + { + DLIB_TEST(img1[r][c] == 7); + DLIB_TEST(img2[r][c] == 7); + } + } + + + threshold_image(img1, img2, 4); + + for (long r = 0; r < img1.nr(); ++r) + { + for (long c = 0; c < img1.nc(); ++c) + { + DLIB_TEST(img1[r][c] == 7); + DLIB_TEST(img2[r][c] == on_pixel); + } + } + + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c].h = static_cast(r*14 + c + 1); + img[r][c].s = static_cast(r*14 + c + 2); + img[r][c].i = static_cast(r*14 + c + 3); + } + } + + ostringstream sout; + save_dng(img, sout); + istringstream sin(sout.str()); + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_dng(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c].h == r*14 + c + 1); + DLIB_TEST(img[r][c].s == r*14 + c + 2); + DLIB_TEST(img[r][c].i == r*14 + c + 3); + } + } + } + + + + + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c].red = static_cast(r*14 + c + 1); + img[r][c].green = static_cast(r*14 + c + 2); + img[r][c].blue = static_cast(r*14 + c + 3); + img[r][c].alpha = static_cast(r*14 + c + 4); + } + } + + ostringstream sout; + save_dng(img, sout); + istringstream sin(sout.str()); + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_dng(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c].red == r*14 + c + 1); + DLIB_TEST(img[r][c].green == r*14 + c + 2); + DLIB_TEST(img[r][c].blue == r*14 + c + 3); + DLIB_TEST(img[r][c].alpha == r*14 + c + 4); + } + } + } + +#ifdef DLIB_PNG_SUPPORT + { + array2d img; + array2d img2, img3; + img.set_size(14,15); + img2.set_size(img.nr(),img.nc()); + img3.set_size(img.nr(),img.nc()); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c].red = static_cast(r*14 + c + 1); + img[r][c].green = static_cast(r*14 + c + 2); + img[r][c].blue = static_cast(r*14 + c + 3); + img[r][c].alpha = static_cast(r*14 + c + 4); + } + } + + save_png(img, "test.png"); + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_png(img, "test.png"); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + assign_all_pixels(img2, 255); + assign_all_pixels(img3, 0); + load_png(img2, "test.png"); + assign_image(img3, img); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c].red == r*14 + c + 1); + DLIB_TEST(img[r][c].green == r*14 + c + 2); + DLIB_TEST(img[r][c].blue == r*14 + c + 3); + DLIB_TEST(img[r][c].alpha == r*14 + c + 4); + + DLIB_TEST(img2[r][c].red == img3[r][c].red); + DLIB_TEST(img2[r][c].green == img3[r][c].green); + DLIB_TEST(img2[r][c].blue == img3[r][c].blue); + } + } + } +#endif // DLIB_PNG_SUPPORT + + + + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c].red = static_cast(r*14 + c + 1); + img[r][c].green = static_cast(r*14 + c + 2); + img[r][c].blue = static_cast(r*14 + c + 3); + } + } + + ostringstream sout; + save_dng(img, sout); + save_bmp(img, sout); + save_dng(img, sout); + save_bmp(img, sout); + istringstream sin(sout.str()); + + for (int i = 0; i < 2; ++i) + { + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_dng(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c].red == r*14 + c + 1); + DLIB_TEST(img[r][c].green == r*14 + c + 2); + DLIB_TEST(img[r][c].blue == r*14 + c + 3); + } + } + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_bmp(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST_MSG(img[r][c].red == r*14 + c + 1, "got " << (int)img[r][c].red << " but expected " << r*14 + c + 1); + DLIB_TEST(img[r][c].green == r*14 + c + 2); + DLIB_TEST(img[r][c].blue == r*14 + c + 3); + } + } + } + } + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c].red = static_cast(r*14 + c + 1); + img[r][c].green = static_cast(r*14 + c + 2); + img[r][c].blue = static_cast(r*14 + c + 3); + } + } + + ostringstream sout; + save_dng(img, sout); + save_bmp(img, sout); + save_dng(img, sout); + save_bmp(img, sout); + istringstream sin(sout.str()); + + for (int i = 0; i < 2; ++i) + { + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_dng(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c].red == r*14 + c + 1); + DLIB_TEST(img[r][c].green == r*14 + c + 2); + DLIB_TEST(img[r][c].blue == r*14 + c + 3); + } + } + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_bmp(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST_MSG(img[r][c].red == r*14 + c + 1, "got " << (int)img[r][c].red << " but expected " << r*14 + c + 1); + DLIB_TEST(img[r][c].green == r*14 + c + 2); + DLIB_TEST(img[r][c].blue == r*14 + c + 3); + } + } + } + } + +#ifdef DLIB_PNG_SUPPORT + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c].red = static_cast(r*14 + c + 1); + img[r][c].green = static_cast(r*14 + c + 2); + img[r][c].blue = static_cast(r*14 + c + 3); + } + } + + save_png(img, "test.png"); + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_png(img, "test.png"); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c].red == r*14 + c + 1); + DLIB_TEST(img[r][c].green == r*14 + c + 2); + DLIB_TEST(img[r][c].blue == r*14 + c + 3); + } + } + } + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c].red = static_cast(r*14 + c + 1); + img[r][c].green = static_cast(r*14 + c + 2); + img[r][c].blue = static_cast(r*14 + c + 3); + } + } + + save_png(img, "test.png"); + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_png(img, "test.png"); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c].red == r*14 + c + 1); + DLIB_TEST(img[r][c].green == r*14 + c + 2); + DLIB_TEST(img[r][c].blue == r*14 + c + 3); + } + } + } +#endif // DLIB_PNG_SUPPORT + + + + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c] = static_cast(r*14 + c + 0xF0); + } + } + + ostringstream sout; + save_dng(img, sout); + istringstream sin(sout.str()); + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_dng(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c] == r*14 + c + 0xF0); + } + } + } + + +#ifdef DLIB_PNG_SUPPORT + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c] = static_cast(r*14 + c + 0xF0); + } + } + + save_png(img, "test.png"); + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_png(img, "test.png"); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c] == r*14 + c + 0xF0); + } + } + } +#endif // DLIB_PNG_SUPPORT + + + + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c] = static_cast(r*14 + c*111); + } + } + + ostringstream sout; + save_dng(img, sout); + save_bmp(img, sout); + save_dng(img, sout); + save_bmp(img, sout); + istringstream sin(sout.str()); + + for (int i = 0; i < 2; ++i) + { + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_dng(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c] == static_cast(r*14 + c*111)); + } + } + + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_bmp(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c] == static_cast(r*14 + c*111)); + } + } + } + } + + +#ifdef DLIB_PNG_SUPPORT + { + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c] = static_cast(r*14 + c); + } + } + + save_png(img, "test.png"); + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_png(img, "test.png"); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c] == r*14 + c); + } + } + + } +#endif // DLIB_PNG_SUPPORT + + + { + // in this test we will only assign pixel values that can be + // represented with 8 bits even though we are using a wider pixel type. + array2d img; + img.set_size(14,15); + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + img[r][c] = static_cast(r*14 + c); + } + } + + ostringstream sout; + save_dng(img, sout); + save_bmp(img, sout); + save_dng(img, sout); + save_bmp(img, sout); + istringstream sin(sout.str()); + + for (int i = 0; i < 2; ++i) + { + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_dng(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c] == r*14 + c); + } + } + + + img.clear(); + DLIB_TEST(img.nr() == 0); + DLIB_TEST(img.nc() == 0); + + load_bmp(img, sin); + + DLIB_TEST(img.nr() == 14); + DLIB_TEST(img.nc() == 15); + + for (long r = 0; r < 14; ++r) + { + for (long c = 0; c < 15; ++c) + { + DLIB_TEST(img[r][c] == r*14 + c); + } + } + } + } + + { + array2d img1; + array2d img2; + img1.set_size(10,10); + assign_all_pixels(img1, 0); + + img1[5][5] = 10000; + img1[7][7] = 10000; + + equalize_histogram(img1, img2); + + for (long r = 0; r < img1.nr(); ++r) + { + for (long c = 0; c < img2.nc(); ++c) + { + if ((r == 5 && c == 5) || + (r == 7 && c == 7)) + { + DLIB_TEST(img2[r][c] == 255); + } + else + { + DLIB_TEST(img2[r][c] == 0); + } + } + } + + } + + { + array2d img; + img.set_size(10,10); + assign_all_pixels(img, 0); + + assign_border_pixels(img, 2,2, 4); + + DLIB_TEST(zeros_matrix(6,6) == subm(mat(img), rectangle(2,2,7,7))); + DLIB_TEST(uniform_matrix(1,10, 4) == rowm(mat(img), 0)); + DLIB_TEST(uniform_matrix(1,10, 4) == rowm(mat(img), 1)); + DLIB_TEST(uniform_matrix(1,10, 4) == rowm(mat(img), 8)); + DLIB_TEST(uniform_matrix(1,10, 4) == rowm(mat(img), 9)); + + DLIB_TEST(uniform_matrix(10,1, 4) == colm(mat(img), 0)); + DLIB_TEST(uniform_matrix(10,1, 4) == colm(mat(img), 1)); + DLIB_TEST(uniform_matrix(10,1, 4) == colm(mat(img), 8)); + DLIB_TEST(uniform_matrix(10,1, 4) == colm(mat(img), 9)); + + + assign_border_pixels(img, 7, 7, 5); + DLIB_TEST(uniform_matrix(10,10, 5) == mat(img)); + assign_border_pixels(img, 37, 47, 5); + DLIB_TEST(uniform_matrix(10,10, 5) == mat(img)); + } + + { + array2d img; + img.set_size(11,11); + assign_all_pixels(img, 0); + + assign_border_pixels(img, 2,2, 4); + + DLIB_TEST(zeros_matrix(7,7) == subm(mat(img), rectangle(2,2,8,8))); + DLIB_TEST(uniform_matrix(1,11, 4) == rowm(mat(img), 0)); + DLIB_TEST(uniform_matrix(1,11, 4) == rowm(mat(img), 1)); + DLIB_TEST(uniform_matrix(1,11, 4) == rowm(mat(img), 9)); + DLIB_TEST(uniform_matrix(1,11, 4) == rowm(mat(img), 10)); + + DLIB_TEST(uniform_matrix(11,1, 4) == colm(mat(img), 0)); + DLIB_TEST(uniform_matrix(11,1, 4) == colm(mat(img), 1)); + DLIB_TEST(uniform_matrix(11,1, 4) == colm(mat(img), 9)); + DLIB_TEST(uniform_matrix(11,1, 4) == colm(mat(img), 10)); + + assign_border_pixels(img, 7, 7, 5); + DLIB_TEST(uniform_matrix(11,11, 5) == mat(img)); + assign_border_pixels(img, 70, 57, 5); + DLIB_TEST(uniform_matrix(11,11, 5) == mat(img)); + } + + + } + + + template + void test_integral_image ( + ) + { + dlib::rand rnd; + + array2d img; + integral_image_generic int_img; + + int_img.load(img); + DLIB_TEST(int_img.nr() == 0); + DLIB_TEST(int_img.nc() == 0); + + // make 5 random images + for (int i = 0; i < 5; ++i) + { + print_spinner(); + img.set_size(rnd.get_random_16bit_number()%200+1, rnd.get_random_16bit_number()%200+1); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = (int)rnd.get_random_8bit_number() - 100; + } + } + + int_img.load(img); + DLIB_TEST(int_img.nr() == img.nr()); + DLIB_TEST(int_img.nc() == img.nc()); + + // make 200 random rectangles + for (int j = 0; j < 500; ++j) + { + point p1(rnd.get_random_32bit_number()%img.nc(), rnd.get_random_32bit_number()%img.nr()); + point p2(rnd.get_random_32bit_number()%img.nc(), rnd.get_random_32bit_number()%img.nr()); + rectangle rect(p1,p2); + DLIB_TEST(int_img.get_sum_of_area(rect) == sum(subm(matrix_cast(mat(img)), rect))); + rect = rectangle(p1,p1); + DLIB_TEST(int_img.get_sum_of_area(rect) == sum(subm(matrix_cast(mat(img)), rect))); + } + + } + + + } + + void test_filtering2(int nr, int nc, dlib::rand& rnd) + { + print_spinner(); + dlog << LINFO << "test_filtering2(): " << nr << " " << nc; + array2d img(302,301); + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = rnd.get_random_gaussian(); + } + } + matrix filt = matrix_cast(randm(nr,nc,rnd)); + + matrix out = xcorr_same(mat(img),filt); + matrix out2 = subm(conv(mat(img),flip(filt)), filt.nr()/2, filt.nc()/2, img.nr(), img.nc()); + // make sure xcorr_same does exactly what the docs say it should. + DLIB_TEST(max(abs(out-out2)) < 1e-7); + + // Now compare the filtering functions to xcorr_same to make sure everything does + // filtering in the same way. + array2d imout(img.nr(), img.nc()); + assign_all_pixels(imout, 10); + rectangle rect = spatially_filter_image(img, imout, filt); + border_enumerator be(get_rect(imout),rect); + while (be.move_next()) + { + DLIB_TEST(imout[be.element().y()][be.element().x()] == 0); + } + DLIB_TEST_MSG(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-5, max(abs(subm(mat(imout),rect) - subm(out,rect)))); + + + assign_all_pixels(imout, 10); + out = 10; + rect = spatially_filter_image(img, imout, filt,2,true,true); + be = border_enumerator(get_rect(imout),rect); + while (be.move_next()) + { + DLIB_TEST(imout[be.element().y()][be.element().x()] == 10); + } + out += abs(xcorr_same(mat(img),filt)/2); + DLIB_TEST(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-7); + + + assign_all_pixels(imout, -10); + out = -10; + rect = spatially_filter_image(img, imout, filt,2,false,true); + be = border_enumerator(get_rect(imout),rect); + while (be.move_next()) + { + DLIB_TEST(imout[be.element().y()][be.element().x()] == -10); + } + out += xcorr_same(mat(img),filt)/2; + DLIB_TEST_MSG(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-5, max(abs(subm(mat(imout),rect) - subm(out,rect)))); + + + + + matrix row_filt = matrix_cast(randm(nc,1,rnd)); + matrix col_filt = matrix_cast(randm(nr,1,rnd)); + assign_all_pixels(imout, 10); + rect = spatially_filter_image_separable(img, imout, row_filt, col_filt); + out = xcorr_same(tmp(xcorr_same(mat(img),trans(row_filt))), col_filt); + DLIB_TEST_MSG(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-5, max(abs(subm(mat(imout),rect) - subm(out,rect)))); + + be = border_enumerator(get_rect(imout),rect); + while (be.move_next()) + { + DLIB_TEST(imout[be.element().y()][be.element().x()] == 0); + } + + + assign_all_pixels(imout, 10); + out = 10; + rect = spatially_filter_image_separable(img, imout, row_filt, col_filt,2,true,true); + out += abs(xcorr_same(tmp(xcorr_same(mat(img),trans(row_filt))), col_filt)/2); + DLIB_TEST_MSG(max(abs(subm(mat(imout),rect) - subm(out,rect))) < 1e-7, + max(abs(subm(mat(imout),rect) - subm(out,rect)))); + + be = border_enumerator(get_rect(imout),rect); + while (be.move_next()) + { + DLIB_TEST(imout[be.element().y()][be.element().x()] == 10); + } + + } + + template + void test_filtering(bool use_abs, unsigned long scale ) + { + print_spinner(); + dlog << LINFO << "test_filtering(" << use_abs << "," << scale << ")"; + array2d img, img2, img3; + img.set_size(10,11); + + assign_all_pixels(img, 10); + + matrix filter2; + filter2 = 1,1,1,1,1, + 1,1,1,1,1, + 1,1,1,1,1; + + assign_all_pixels(img2,3); + rectangle brect = spatially_filter_image(img, img2, filter2); + DLIB_TEST(brect == shrink_rect(get_rect(img), filter2.nc()/2, filter2.nr()/2)); + + const rectangle rect(2,1,img.nc()-3,img.nr()-2); + + for (long r = 0; r row_filter; + matrix col_filter; + + row_filter = 1,1,1,1,1; + col_filter = 1,1,1; + + spatially_filter_image_separable(img, img3, row_filter, col_filter); + + DLIB_TEST(mat(img2) == mat(img3)); + + + dlib::rand rnd; + + for (int i = 0; i < 30; ++i) + { + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = rnd.get_random_8bit_number(); + } + } + + row_filter(0) = ((int)rnd.get_random_8bit_number() - 100)/10; + row_filter(1) = ((int)rnd.get_random_8bit_number() - 100)/10; + row_filter(2) = ((int)rnd.get_random_8bit_number() - 100)/10; + row_filter(3) = ((int)rnd.get_random_8bit_number() - 100)/10; + row_filter(4) = ((int)rnd.get_random_8bit_number() - 100)/10; + col_filter(0) = ((int)rnd.get_random_8bit_number() - 100)/10; + col_filter(1) = ((int)rnd.get_random_8bit_number() - 100)/10; + col_filter(2) = ((int)rnd.get_random_8bit_number() - 100)/10; + + const matrix filter = trans(col_filter)*row_filter; + + assign_all_pixels(img2,3); + assign_all_pixels(img3,3); + // Just make sure both filtering methods give the same results. + rectangle brect1, brect2; + brect1 = spatially_filter_image(img, img2, filter, scale, use_abs); + brect2 = spatially_filter_image_separable(img, img3, row_filter, col_filter, scale, use_abs); + DLIB_TEST(mat(img2) == mat(img3)); + + DLIB_TEST(brect1 == shrink_rect(get_rect(img), filter.nc()/2, filter.nr()/2)); + DLIB_TEST(brect1 == brect2); + } + + { + array2d img, img2; + img.set_size(3,4); + + matrix filter(3,3); + filter = 1; + assign_all_pixels(img,-1); + + spatially_filter_image(img,img2,filter); + + DLIB_TEST(img2[0][0] == 0); + DLIB_TEST(img2[0][1] == 0); + DLIB_TEST(img2[0][2] == 0); + DLIB_TEST(img2[0][3] == 0); + + DLIB_TEST(img2[1][0] == 0); + DLIB_TEST(img2[1][1] == -9); + DLIB_TEST(img2[1][2] == -9); + DLIB_TEST(img2[1][3] == 0); + + DLIB_TEST(img2[2][0] == 0); + DLIB_TEST(img2[2][1] == 0); + DLIB_TEST(img2[2][2] == 0); + DLIB_TEST(img2[2][3] == 0); + + assign_all_pixels(img,-1); + + spatially_filter_image(img,img2,filter,2,true); + + DLIB_TEST(img2[0][0] == 0); + DLIB_TEST(img2[0][1] == 0); + DLIB_TEST(img2[0][2] == 0); + DLIB_TEST(img2[0][3] == 0); + + DLIB_TEST(img2[1][0] == 0); + DLIB_TEST(img2[1][1] == 4); + DLIB_TEST(img2[1][2] == 4); + DLIB_TEST(img2[1][3] == 0); + + DLIB_TEST(img2[2][0] == 0); + DLIB_TEST(img2[2][1] == 0); + DLIB_TEST(img2[2][2] == 0); + DLIB_TEST(img2[2][3] == 0); + + matrix rowf(3,1), colf(3,1); + rowf = 1; + colf = 1; + assign_all_pixels(img,-1); + + spatially_filter_image_separable(img,img2,rowf,colf); + DLIB_TEST(img2[0][0] == 0); + DLIB_TEST(img2[0][1] == 0); + DLIB_TEST(img2[0][2] == 0); + DLIB_TEST(img2[0][3] == 0); + + DLIB_TEST(img2[1][0] == 0); + DLIB_TEST(img2[1][1] == -9); + DLIB_TEST(img2[1][2] == -9); + DLIB_TEST(img2[1][3] == 0); + + DLIB_TEST(img2[2][0] == 0); + DLIB_TEST(img2[2][1] == 0); + DLIB_TEST(img2[2][2] == 0); + DLIB_TEST(img2[2][3] == 0); + + spatially_filter_image_separable(img,img2,rowf,colf,1,true); + DLIB_TEST(img2[0][0] == 0); + DLIB_TEST(img2[0][1] == 0); + DLIB_TEST(img2[0][2] == 0); + DLIB_TEST(img2[0][3] == 0); + + DLIB_TEST(img2[1][0] == 0); + DLIB_TEST(img2[1][1] == 9); + DLIB_TEST(img2[1][2] == 9); + DLIB_TEST(img2[1][3] == 0); + + DLIB_TEST(img2[2][0] == 0); + DLIB_TEST(img2[2][1] == 0); + DLIB_TEST(img2[2][2] == 0); + DLIB_TEST(img2[2][3] == 0); + + assign_all_pixels(img2, 3); + spatially_filter_image_separable(img,img2,rowf,colf,1,true, true); + DLIB_TEST(img2[0][0] == 3); + DLIB_TEST(img2[0][1] == 3); + DLIB_TEST(img2[0][2] == 3); + DLIB_TEST(img2[0][3] == 3); + + DLIB_TEST(img2[1][0] == 3); + DLIB_TEST_MSG(img2[1][1] == 9+3, img2[1][1] ); + DLIB_TEST(img2[1][2] == 9+3); + DLIB_TEST(img2[1][3] == 3); + + DLIB_TEST(img2[2][0] == 3); + DLIB_TEST(img2[2][1] == 3); + DLIB_TEST(img2[2][2] == 3); + DLIB_TEST(img2[2][3] == 3); + } + { + array2d img, img2; + img.set_size(3,4); + + matrix filter(3,3); + filter = 1; + assign_all_pixels(img,-1); + + spatially_filter_image(img,img2,filter,2); + + DLIB_TEST(img2[0][0] == 0); + DLIB_TEST(img2[0][1] == 0); + DLIB_TEST(img2[0][2] == 0); + DLIB_TEST(img2[0][3] == 0); + + DLIB_TEST(img2[1][0] == 0); + DLIB_TEST(std::abs(img2[1][1] - -4.5) < 1e-14); + DLIB_TEST(std::abs(img2[1][2] - -4.5) < 1e-14); + DLIB_TEST(img2[1][3] == 0); + + DLIB_TEST(img2[2][0] == 0); + DLIB_TEST(img2[2][1] == 0); + DLIB_TEST(img2[2][2] == 0); + DLIB_TEST(img2[2][3] == 0); + + } + { + array2d img, img2; + img.set_size(3,4); + img2.set_size(3,4); + assign_all_pixels(img2, 8); + + matrix filter(3,3); + filter = 1; + assign_all_pixels(img,-1); + + spatially_filter_image(img,img2,filter,2, false, true); + + DLIB_TEST(img2[0][0] == 8); + DLIB_TEST(img2[0][1] == 8); + DLIB_TEST(img2[0][2] == 8); + DLIB_TEST(img2[0][3] == 8); + + DLIB_TEST(img2[1][0] == 8); + DLIB_TEST(std::abs(img2[1][1] - -4.5 - 8) < 1e-14); + DLIB_TEST(std::abs(img2[1][2] - -4.5 - 8) < 1e-14); + DLIB_TEST(img2[1][3] == 8); + + DLIB_TEST(img2[2][0] == 8); + DLIB_TEST(img2[2][1] == 8); + DLIB_TEST(img2[2][2] == 8); + DLIB_TEST(img2[2][3] == 8); + + } + } + + + void test_zero_border_pixels( + ) + { + array2d img; + img.set_size(4,5); + + assign_all_pixels(img, 1); + zero_border_pixels(img, 2,1); + + DLIB_TEST(img[0][0] == 0); + DLIB_TEST(img[1][0] == 0); + DLIB_TEST(img[2][0] == 0); + DLIB_TEST(img[3][0] == 0); + DLIB_TEST(img[0][1] == 0); + DLIB_TEST(img[1][1] == 0); + DLIB_TEST(img[2][1] == 0); + DLIB_TEST(img[3][1] == 0); + + DLIB_TEST(img[0][3] == 0); + DLIB_TEST(img[1][3] == 0); + DLIB_TEST(img[2][3] == 0); + DLIB_TEST(img[3][3] == 0); + DLIB_TEST(img[0][4] == 0); + DLIB_TEST(img[1][4] == 0); + DLIB_TEST(img[2][4] == 0); + DLIB_TEST(img[3][4] == 0); + + DLIB_TEST(img[0][2] == 0); + DLIB_TEST(img[3][2] == 0); + + DLIB_TEST(img[1][2] == 1); + DLIB_TEST(img[2][2] == 1); + + rectangle rect = get_rect(img); + rect.left()+=2; + rect.top()+=1; + rect.right()-=2; + rect.bottom()-=1; + assign_all_pixels(img, 1); + zero_border_pixels(img, rect); + + DLIB_TEST(img[0][0] == 0); + DLIB_TEST(img[1][0] == 0); + DLIB_TEST(img[2][0] == 0); + DLIB_TEST(img[3][0] == 0); + DLIB_TEST(img[0][1] == 0); + DLIB_TEST(img[1][1] == 0); + DLIB_TEST(img[2][1] == 0); + DLIB_TEST(img[3][1] == 0); + + DLIB_TEST(img[0][3] == 0); + DLIB_TEST(img[1][3] == 0); + DLIB_TEST(img[2][3] == 0); + DLIB_TEST(img[3][3] == 0); + DLIB_TEST(img[0][4] == 0); + DLIB_TEST(img[1][4] == 0); + DLIB_TEST(img[2][4] == 0); + DLIB_TEST(img[3][4] == 0); + + DLIB_TEST(img[0][2] == 0); + DLIB_TEST(img[3][2] == 0); + + DLIB_TEST(img[1][2] == 1); + DLIB_TEST(img[2][2] == 1); + + rect.right()+=1; + assign_all_pixels(img, 1); + zero_border_pixels(img, rect); + DLIB_TEST(img[0][0] == 0); + DLIB_TEST(img[1][0] == 0); + DLIB_TEST(img[2][0] == 0); + DLIB_TEST(img[3][0] == 0); + DLIB_TEST(img[0][1] == 0); + DLIB_TEST(img[1][1] == 0); + DLIB_TEST(img[2][1] == 0); + DLIB_TEST(img[3][1] == 0); + + DLIB_TEST(img[0][3] == 0); + DLIB_TEST(img[1][3] == 1); + DLIB_TEST(img[2][3] == 1); + DLIB_TEST(img[3][3] == 0); + DLIB_TEST(img[0][4] == 0); + DLIB_TEST(img[1][4] == 0); + DLIB_TEST(img[2][4] == 0); + DLIB_TEST(img[3][4] == 0); + + DLIB_TEST(img[0][2] == 0); + DLIB_TEST(img[3][2] == 0); + + DLIB_TEST(img[1][2] == 1); + DLIB_TEST(img[2][2] == 1); + } + + + void test_label_connected_blobs() + { + array2d img; + img.set_size(400,401); + + assign_all_pixels(img,0); + + rectangle rect1, rect2, rect3; + + rect1 = centered_rect(99,120, 50,70); + rect2 = centered_rect(199,80, 34,68); + rect3 = centered_rect(249,180, 120,78); + + fill_rect(img, rect1, 255); + fill_rect(img, rect2, 255); + fill_rect(img, rect3, 255); + + array2d labels; + unsigned long num; + num = label_connected_blobs(img, + zero_pixels_are_background(), + neighbors_8(), + connected_if_both_not_zero(), + labels); + + DLIB_TEST(num == 4); + DLIB_TEST(labels.nr() == img.nr()); + DLIB_TEST(labels.nc() == img.nc()); + + const unsigned char l1 = labels[rect1.top()][rect1.left()]; + const unsigned char l2 = labels[rect2.top()][rect2.left()]; + const unsigned char l3 = labels[rect3.top()][rect3.left()]; + + DLIB_TEST(l1 != 0 && l2 != 0 && l3 != 0); + DLIB_TEST(l1 != l2 && l1 != l3 && l2 != l3); + + for (long r = 0; r < labels.nr(); ++r) + { + for (long c = 0; c < labels.nc(); ++c) + { + if (rect1.contains(c,r)) + { + DLIB_TEST(labels[r][c] == l1); + } + else if (rect2.contains(c,r)) + { + DLIB_TEST(labels[r][c] == l2); + } + else if (rect3.contains(c,r)) + { + DLIB_TEST(labels[r][c] == l3); + } + else + { + DLIB_TEST(labels[r][c] == 0); + } + } + } + } + + void test_label_connected_blobs2() + { + array2d img; + img.set_size(400,401); + + assign_all_pixels(img,0); + + rectangle rect1, rect2, rect3; + + rect1 = centered_rect(99,120, 50,70); + rect2 = centered_rect(199,80, 34,68); + rect3 = centered_rect(249,180, 120,78); + + fill_rect(img, rect1, 255); + fill_rect(img, rect2, 253); + fill_rect(img, rect3, 255); + + array2d labels; + unsigned long num; + num = label_connected_blobs(img, + nothing_is_background(), + neighbors_4(), + connected_if_equal(), + labels); + + DLIB_TEST(num == 5); + DLIB_TEST(labels.nr() == img.nr()); + DLIB_TEST(labels.nc() == img.nc()); + + const unsigned char l0 = labels[0][0]; + const unsigned char l1 = labels[rect1.top()][rect1.left()]; + const unsigned char l2 = labels[rect2.top()][rect2.left()]; + const unsigned char l3 = labels[rect3.top()][rect3.left()]; + + DLIB_TEST(l0 != 0 && l1 != 0 && l2 != 0 && l3 != 0); + DLIB_TEST(l1 != l2 && l1 != l3 && l2 != l3 && + l0 != l1 && l0 != l2 && l0 != l3); + + for (long r = 0; r < labels.nr(); ++r) + { + for (long c = 0; c < labels.nc(); ++c) + { + if (rect1.contains(c,r)) + { + DLIB_TEST(labels[r][c] == l1); + } + else if (rect2.contains(c,r)) + { + DLIB_TEST(labels[r][c] == l2); + } + else if (rect3.contains(c,r)) + { + DLIB_TEST(labels[r][c] == l3); + } + else + { + DLIB_TEST(labels[r][c] == l0); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename in_image_type, + typename out_image_type + > + void downsample_image ( + const unsigned long downsample, + const in_image_type& in_img, + out_image_type& out_img, + bool add_to + ) + { + out_img.set_size((in_img.nr()+downsample-1)/downsample, + (in_img.nc()+downsample-1)/downsample); + + for (long r = 0; r < out_img.nr(); ++r) + { + for (long c = 0; c < out_img.nc(); ++c) + { + if (add_to) + out_img[r][c] += in_img[r*downsample][c*downsample]; + else + out_img[r][c] = in_img[r*downsample][c*downsample]; + } + } + } + + template < + typename in_image_type, + typename out_image_type, + typename EXP1, + typename EXP2, + typename T + > + void test_spatially_filter_image_separable_down_simple ( + const unsigned long downsample, + const in_image_type& in_img, + out_image_type& out_img, + const matrix_exp& row_filter, + const matrix_exp& col_filter, + T scale, + bool use_abs = false, + bool add_to = false + ) + { + out_image_type temp; + spatially_filter_image_separable(in_img, temp, row_filter, col_filter, scale, use_abs, false); + downsample_image(downsample, temp, out_img, add_to); + } + + + + + template + void test_downsampled_filtering_helper(long row_filt_size, long col_filt_size) + { + print_spinner(); + dlog << LTRACE << "***********************************"; + dlog << LTRACE << "downsample: " << downsample; + dlog << LTRACE << "row_filt_size: "<< row_filt_size; + dlog << LTRACE << "col_filt_size: "<< col_filt_size; + dlib::rand rnd; + array2d out1, out2; + for (long nr = 0; nr < 3; ++nr) + { + for (long nc = 0; nc < 3; ++nc) + { + dlog << LTRACE << "nr: "<< nr; + dlog << LTRACE << "nc: "<< nc; + array2d img(25+nr,25+nc); + for (int k = 0; k < 5; ++k) + { + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = rnd.get_random_8bit_number(); + } + } + + matrix row_filter(row_filt_size); + matrix col_filter(col_filt_size); + + row_filter = matrix_cast(10*randm(row_filt_size,1, rnd)); + col_filter = matrix_cast(10*randm(col_filt_size,1, rnd)); + + row_filter -= 3; + col_filter -= 3; + + + test_spatially_filter_image_separable_down_simple(downsample, img, out1, row_filter, col_filter,1 ); + spatially_filter_image_separable_down(downsample, img, out2, row_filter, col_filter); + + DLIB_TEST(get_rect(out1) == get_rect(out2)); + DLIB_TEST(mat(out1) == mat(out2)); + + test_spatially_filter_image_separable_down_simple(downsample, img, out1, row_filter, col_filter,3, true, true ); + spatially_filter_image_separable_down(downsample, img, out2, row_filter, col_filter, 3, true, true); + + DLIB_TEST(get_rect(out1) == get_rect(out2)); + DLIB_TEST(mat(out1) == mat(out2)); + + } + } + } + } + + void test_downsampled_filtering() + { + test_downsampled_filtering_helper<1>(5,5); + test_downsampled_filtering_helper<2>(5,5); + test_downsampled_filtering_helper<3>(5,5); + test_downsampled_filtering_helper<1>(3,5); + test_downsampled_filtering_helper<2>(3,5); + test_downsampled_filtering_helper<3>(3,5); + test_downsampled_filtering_helper<1>(5,3); + test_downsampled_filtering_helper<2>(5,3); + test_downsampled_filtering_helper<3>(5,3); + + test_downsampled_filtering_helper<1>(3,3); + test_downsampled_filtering_helper<2>(3,3); + test_downsampled_filtering_helper<3>(3,3); + + test_downsampled_filtering_helper<1>(1,1); + test_downsampled_filtering_helper<2>(1,1); + test_downsampled_filtering_helper<3>(1,1); + + } + +// ---------------------------------------------------------------------------------------- + + template + void test_segment_image() + { + print_spinner(); + array2d img(100,100); + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + if (c < 50 || r < 50) + assign_pixel(img[r][c], 0); + else + assign_pixel(img[r][c], 255); + } + } + + array2d out; + segment_image(img, out); + + DLIB_TEST(get_rect(img) == get_rect(out)); + const unsigned long v1 = out[0][0]; + const unsigned long v2 = out[90][90]; + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + if (c < 50 || r < 50) + { + DLIB_TEST(out[r][c] == v1); + } + else + { + DLIB_TEST(out[r][c] == v2); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void test_dng_floats(double scale) + { + dlog << LINFO << "in test_dng_floats"; + print_spinner(); + array2d img(100,101); + + dlib::rand rnd; + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + T val = rnd.get_random_double()*scale; + img[r][c] = val; + + // Lets the float_details object while we are here doing this stuff. + float_details temp = val; + T val2 = temp; + // for the same type we should exactly reproduce the value (unless + // it's long double and then maybe it's slightly different). + if (is_same_type::value) + { + DLIB_TEST(std::abs(val2-val) < scale*std::numeric_limits::epsilon()); + } + else + { + DLIB_TEST(val2 == val); + } + + float valf = temp; + double vald = temp; + long double vall = temp; + + DLIB_TEST(std::abs(valf-val) < scale*std::numeric_limits::epsilon()); + DLIB_TEST(std::abs(vald-val) < scale*std::numeric_limits::epsilon()); + DLIB_TEST(std::abs(vall-val) < scale*std::numeric_limits::epsilon()); + } + } + + ostringstream sout; + save_dng(img, sout); + istringstream sin; + + array2d img1; + array2d img2; + array2d img3; + + sin.clear(); sin.str(sout.str()); + load_dng(img1, sin); + + sin.clear(); sin.str(sout.str()); + load_dng(img2, sin); + + sin.clear(); sin.str(sout.str()); + load_dng(img3, sin); + + DLIB_TEST(img.nr() == img1.nr()); + DLIB_TEST(img.nr() == img2.nr()); + DLIB_TEST(img.nr() == img3.nr()); + DLIB_TEST(img.nc() == img1.nc()); + DLIB_TEST(img.nc() == img2.nc()); + DLIB_TEST(img.nc() == img3.nc()); + + DLIB_TEST(max(abs(mat(img) - matrix_cast(mat(img1)))) < scale*std::numeric_limits::epsilon()); + DLIB_TEST(max(abs(mat(img) - matrix_cast(mat(img2)))) < scale*std::numeric_limits::epsilon()); + DLIB_TEST(max(abs(mat(img) - matrix_cast(mat(img3)))) < scale*std::numeric_limits::epsilon()); + } + + void test_dng_float_int() + { + dlog << LINFO << "in test_dng_float_int"; + print_spinner(); + + array2d img; + assign_image(img, gaussian_randm(101,100)*10000); + + ostringstream sout; + save_dng(img, sout); + istringstream sin(sout.str()); + array2d img2; + load_dng(img2, sin); + sout.clear(); sout.str(""); + + save_dng(img2, sout); + sin.clear(); sin.str(sout.str()); + array2d img3; + load_dng(img3, sin); + + // this whole thing should have been totally lossless. + DLIB_TEST(mat(img) == mat(img3)); + } + +// ---------------------------------------------------------------------------------------- + + template + void test_filtering_center ( + dlib::rand& rnd + ) + { + array2d img(rnd.get_random_32bit_number()%100+1, + rnd.get_random_32bit_number()%100+1); + matrix filt(rnd.get_random_32bit_number()%10+1, + rnd.get_random_32bit_number()%10+1); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = rnd.get_random_32bit_number()%100; + } + } + for (long r = 0; r < filt.nr(); ++r) + { + for (long c = 0; c < filt.nc(); ++c) + { + filt(r,c) = rnd.get_random_32bit_number()%100; + } + } + + array2d out; + const rectangle area = spatially_filter_image(img, out, filt); + + for (long r = 0; r < out.nr(); ++r) + { + for (long c = 0; c < out.nc(); ++c) + { + const rectangle rect = centered_rect(point(c,r), filt.nc(), filt.nr()); + if (get_rect(out).contains(rect)) + { + T val = sum(pointwise_multiply(filt, subm(mat(img),rect))); + DLIB_TEST_MSG(val == out[r][c],"err: " << val-out[r][c]); + DLIB_TEST(area.contains(point(c,r))); + } + else + { + DLIB_TEST(!area.contains(point(c,r))); + } + } + } + } + + template + void test_separable_filtering_center ( + dlib::rand& rnd + ) + { + array2d img(rnd.get_random_32bit_number()%100+1, + rnd.get_random_32bit_number()%100+1); + matrix row_filt(rnd.get_random_32bit_number()%10+1); + matrix col_filt(rnd.get_random_32bit_number()%10+1); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = rnd.get_random_32bit_number()%10; + } + } + for (long r = 0; r < row_filt.size(); ++r) + { + row_filt(r) = rnd.get_random_32bit_number()%10; + } + for (long r = 0; r < col_filt.size(); ++r) + { + col_filt(r) = rnd.get_random_32bit_number()%10; + } + + array2d out; + const rectangle area = spatially_filter_image_separable(img, out, row_filt, col_filt); + + for (long r = 0; r < out.nr(); ++r) + { + for (long c = 0; c < out.nc(); ++c) + { + const rectangle rect = centered_rect(point(c,r), row_filt.size(), col_filt.size()); + if (get_rect(out).contains(rect)) + { + T val = sum(pointwise_multiply(col_filt*row_filt, subm(mat(img),rect))); + DLIB_TEST_MSG(val == out[r][c],"err: " << val-out[r][c]); + + DLIB_TEST(area.contains(point(c,r))); + } + else + { + DLIB_TEST(!area.contains(point(c,r))); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void run_hough_test() + { + array2d img(300,300); + + + for (int k = -2; k <= 2; ++k) + { + print_spinner(); + running_stats rs; + array2d himg; + hough_transform ht(200+k); + double angle1 = 0; + double angle2 = 0; + const int len = 90; + // Draw a bunch of random lines, hough transform them, then make sure the hough + // transform detects them accurately. + for (int i = 0; i < 500; ++i) + { + point cent = center(get_rect(img)); + point arc = cent + point(len,0); + arc = rotate_point(cent, arc, angle1); + + point l = arc + point(500,0); + point r = arc - point(500,0); + l = rotate_point(arc, l, angle2); + r = rotate_point(arc, r, angle2); + + angle1 += pi/13; + angle2 += pi/40; + + assign_all_pixels(img, 0); + draw_line(img, l, r, 255); + rectangle box = translate_rect(get_rect(ht),point(50,50)); + ht(img, box, himg); + + point p = max_point(mat(himg)); + DLIB_TEST(himg[p.y()][p.x()] > 255*3); + + l -= point(50,50); + r -= point(50,50); + std::pair line = ht.get_line(p); + // make sure the best scoring hough point matches the line we drew. + double dist1 = distance_to_line(make_pair(l,r), line.first); + double dist2 = distance_to_line(make_pair(l,r), line.second); + //cout << "DIST1: " << dist1 << endl; + //cout << "DIST2: " << dist2 << endl; + rs.add(dist1); + rs.add(dist2); + DLIB_TEST(dist1 < 2.5); + DLIB_TEST(dist2 < 2.5); + } + //cout << "rs.mean(): " << rs.mean() << endl; + DLIB_TEST(rs.mean() < 0.7); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_extract_image_chips() + { + dlib::rand rnd; + + // Make sure that cropping a white box out of a larger white image always produces an + // exact white box. This should catch any bad border effects from a messed up internal + // cropping. + for (int iter = 0; iter < 1000; ++iter) + { + print_spinner(); + const long nr = rnd.get_random_32bit_number()%100 + 1; + const long nc = rnd.get_random_32bit_number()%100 + 1; + const long size = rnd.get_random_32bit_number()%10000 + 4; + const double angle = rnd.get_random_double() * pi; + + matrix img(501,501), chip; + img = 255; + chip_details details(centered_rect(center(get_rect(img)),nr,nc), size, angle); + extract_image_chip(img, details, chip); + DLIB_TEST_MSG(max(abs(chip-255))==0,"nr: " << nr << " nc: "<< nc << " size: " << size << " angle: " << angle + << " error: " << max(abs(chip-255)) ); + } + + + { + // Make sure that the interpolation in extract_image_chip() keeps stuff in the + // right places. + + matrix img(10,10), chip; + img = 0; + img(1,1) = 255; + img(8,8) = 255; + + extract_image_chip(img, chip_details(get_rect(img), 9*9), chip); + + DLIB_TEST(chip(1,1) == 195); + DLIB_TEST(chip(7,7) == 195); + chip(1,1) -= 195; + chip(7,7) -= 195; + DLIB_TEST(sum(matrix_cast(chip)) == 0); + } + + + + // Test the rotation ability of extract_image_chip(). Do this by drawing a line and + // then rotating it so it's horizontal. Check that it worked correctly by hough + // transforming it. + hough_transform ht(151); + matrix img(300,300); + for (int iter = 0; iter < 1000; ++iter) + { + print_spinner(); + img = 0; + const int len = 9000; + point cent = center(get_rect(img)); + point l = cent + point(len,0); + point r = cent - point(len,0); + const double angle = rnd.get_random_double()*pi*3; + l = rotate_point(cent, l, angle); + r = rotate_point(cent, r, angle); + draw_line(img, l, r, 255); + + + const long wsize = rnd.get_random_32bit_number()%350 + 150; + + matrix temp; + chip_details details(centered_rect(center(get_rect(img)), wsize,wsize), chip_dims(ht.size(),ht.size()), angle); + extract_image_chip(img, details, temp); + + + matrix tform; + ht(temp, get_rect(temp), tform); + std::pair line = ht.get_line(max_point(tform)); + + DLIB_TEST_MSG(line.first.y() == line.second.y()," wsize: " << wsize); + DLIB_TEST(length(line.first-line.second) > 100); + DLIB_TEST(length((line.first+line.second)/2.0 - center(get_rect(temp))) <= 1); + } + + } + +// ---------------------------------------------------------------------------------------- + + class image_tester : public tester + { + public: + image_tester ( + ) : + tester ("test_image", + "Runs tests on the image processing objects and functions.") + {} + + void perform_test ( + ) + { + image_test(); + run_hough_test(); + test_extract_image_chips(); + test_integral_image(); + test_integral_image(); + test_integral_image(); + test_integral_image(); + + test_zero_border_pixels(); + + test_filtering(false,1); + test_filtering(true,1); + test_filtering(false,3); + test_filtering(true,3); + test_filtering(false,1); + test_filtering(true,1); + test_filtering(false,3); + test_filtering(true,3); + + test_label_connected_blobs(); + test_label_connected_blobs2(); + test_downsampled_filtering(); + + test_segment_image(); + test_segment_image(); + test_segment_image(); + test_segment_image(); + test_segment_image(); + test_segment_image(); + + test_dng_floats(1); + test_dng_floats(1); + test_dng_floats(1); + test_dng_floats(1e30); + test_dng_floats(1e30); + test_dng_floats(1e30); + + test_dng_float_int(); + + dlib::rand rnd; + for (int i = 0; i < 10; ++i) + { + // the spatial filtering stuff is the same as xcorr_same when the filter + // sizes are odd. + test_filtering2(3,3,rnd); + test_filtering2(5,5,rnd); + test_filtering2(7,7,rnd); + } + + for (int i = 0; i < 100; ++i) + test_filtering_center(rnd); + for (int i = 0; i < 100; ++i) + test_filtering_center(rnd); + for (int i = 0; i < 100; ++i) + test_separable_filtering_center(rnd); + for (int i = 0; i < 100; ++i) + test_separable_filtering_center(rnd); + + { + print_spinner(); + matrix img(40,80); + assign_all_pixels(img, 255); + skeleton(img); + + DLIB_TEST(sum(matrix_cast(mat(img)))/255 == 40); + draw_line(img, point(20,19), point(59,19), 00); + DLIB_TEST(sum(matrix_cast(mat(img))) == 0); + } + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/iosockstream.cpp b/ml/dlib/dlib/test/iosockstream.cpp new file mode 100644 index 000000000..c68de7bd3 --- /dev/null +++ b/ml/dlib/dlib/test/iosockstream.cpp @@ -0,0 +1,181 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + + logger dlog("test.iosockstream"); + +// ---------------------------------------------------------------------------------------- + + class serv : public server_iostream + { + virtual void on_connect ( + std::istream& in, + std::ostream& out, + const std::string& , + const std::string& , + unsigned short , + unsigned short , + uint64 + ) + { + try + { + dlog << LINFO << "serv1: serving connection"; + + std::string temp; + in >> temp; + DLIB_TEST(temp == "word"); + in >> temp; + DLIB_TEST(temp == "another"); + out << "yay words "; + in >> temp; + DLIB_TEST(temp == "yep"); + } + catch (error& e) + { + error_string = e.what(); + } + + } + + + public: + std::string error_string; + + }; + + class serv2 : public server_iostream + { + virtual void on_connect ( + std::istream& , + std::ostream& out, + const std::string& , + const std::string& , + unsigned short , + unsigned short , + uint64 + ) + { + try + { + dlog << LINFO << "serv2: serving connection"; + + out << "one two three four five"; + } + catch (error& e) + { + error_string = e.what(); + } + + } + + + public: + std::string error_string; + + }; + +// ---------------------------------------------------------------------------------------- + + void test1() + { + dlog << LINFO << "in test1()"; + serv theserv; + theserv.set_listening_port(12345); + theserv.start_async(); + + // wait a little bit to make sure the server has started listening before we try + // to connect to it. + dlib::sleep(500); + + for (int i = 0; i < 200; ++i) + { + dlog << LINFO << "i: " << i; + print_spinner(); + iosockstream stream("localhost:12345"); + + stream << "word another "; + std::string temp; + stream >> temp; + DLIB_TEST(temp == "yay"); + stream >> temp; + DLIB_TEST(temp == "words"); + stream << "yep "; + } + + // Just to make sure the server finishes processing the last connection before + // we kill it and accidentally trigger a DLIB_TEST(). + dlib::sleep(500); + + if (theserv.error_string.size() != 0) + throw error(theserv.error_string); + } + +// ---------------------------------------------------------------------------------------- + + void test2() + { + dlog << LINFO << "in test2()"; + serv2 theserv; + theserv.set_listening_port(12345); + theserv.start_async(); + + // wait a little bit to make sure the server has started listening before we try + // to connect to it. + dlib::sleep(500); + + for (int i = 0; i < 200; ++i) + { + dlog << LINFO << "i: " << i; + print_spinner(); + iosockstream stream("localhost:12345"); + + std::string temp; + stream >> temp; DLIB_TEST(temp == "one"); + stream >> temp; DLIB_TEST(temp == "two"); + stream >> temp; DLIB_TEST(temp == "three"); + stream >> temp; DLIB_TEST(temp == "four"); + stream >> temp; DLIB_TEST(temp == "five"); + } + } + +// ---------------------------------------------------------------------------------------- + + class test_iosockstream : public tester + { + public: + test_iosockstream ( + ) : + tester ("test_iosockstream", + "Runs tests on the iosockstream component.") + {} + + void perform_test ( + ) + { + test1(); + test2(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/is_same_object.cpp b/ml/dlib/dlib/test/is_same_object.cpp new file mode 100644 index 000000000..ed71c5bef --- /dev/null +++ b/ml/dlib/dlib/test/is_same_object.cpp @@ -0,0 +1,141 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.is_same_object"); + + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_booya_template, void, template booya, (std::string)const); + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_booya2_template, void, template booya2, (int)const); + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_funct_int, void, funct, (int)); + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_funct_double, void, funct, (double)); + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_funct_f, float, funct_f, (int)); + + class htest + { + public: + template + void booya(std::string) const {} + + template + void booya2(EXP) const {} + + void funct(double) {} + }; + + class htest2 + { + public: + + void funct(int) {} + + float funct_f(int) { return 0;} + }; + + void test_metaprog() + { + DLIB_TEST(has_booya2_template::value == true); + DLIB_TEST(has_booya2_template::value == false); + +#if _MSC_VER > 1600 // there is a bug in visual studio 2010 and older that prevents this test from working + DLIB_TEST(has_booya_template::value == true); +#endif + + DLIB_TEST(has_booya_template::value == false); + + DLIB_TEST(has_funct_int::value == false); + DLIB_TEST(has_funct_int::value == true); + DLIB_TEST(has_funct_double::value == true); + DLIB_TEST(has_funct_double::value == false); + + DLIB_TEST(has_funct_f::value == false); + DLIB_TEST(has_funct_f::value == true); + } + + class is_same_object_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + is_same_object_tester ( + ) : + tester ( + "test_is_same_object", // the command line argument name for this test + "Run tests on the is_same_object function.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + struct base {}; + struct derived : public base {}; + + template + void go(const base& a, const base& b) + { + DLIB_TEST( is_same_object(a,b) == truth) ; + DLIB_TEST( is_same_object(b,a) == truth) ; + } + + + template + void go2(const base& a, const derived& b) + { + DLIB_TEST( is_same_object(a,b) == truth) ; + DLIB_TEST( is_same_object(b,a) == truth) ; + } + + + void perform_test ( + ) + { + print_spinner(); + + int a, b; + double d; + DLIB_TEST( is_same_object(a,a) == true) ; + DLIB_TEST( is_same_object(a,b) == false) ; + DLIB_TEST( is_same_object(d,b) == false) ; + DLIB_TEST( is_same_object(d,d) == true) ; + + base sb; + derived sd, sd2; + + DLIB_TEST( is_same_object(sb,sd) == false) ; + DLIB_TEST( is_same_object(sd,sb) == false) ; + + go(sd, sd); + go(sd, sd2); + go(sb, sb); + go(sd, sb); + + go2(sd, sd); + go2(sd2, sd); + go2(sd, sd2); + go2(sb, sd); + + test_metaprog(); + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + is_same_object_tester a; + +} + + + diff --git a/ml/dlib/dlib/test/isotonic_regression.cpp b/ml/dlib/dlib/test/isotonic_regression.cpp new file mode 100644 index 000000000..2ab46903c --- /dev/null +++ b/ml/dlib/dlib/test/isotonic_regression.cpp @@ -0,0 +1,103 @@ +// Copyright (C) 2018 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.isotonic_regression"); + +// ---------------------------------------------------------------------------------------- + + class optimization_tester : public tester + { + public: + optimization_tester ( + ) : + tester ("test_isotonic_regression", + "Runs tests on the isotonic_regression object.") + {} + + void perform_test ( + ) + { + dlib::rand rnd; + + for (int round = 0; round < 100; ++round) + { + print_spinner(); + std::vector vect; + for (int i = 0; i < 5; ++i) + vect.push_back(put_in_range(-1,1,rnd.get_random_gaussian())); + + + auto f = [&](const matrix& x) + { + double dist = 0; + double sum = 0; + for (long i = 0; i < x.size(); ++i) + { + sum += x(i); + dist += (sum-vect[i])*(sum-vect[i]); + } + return dist; + }; + + auto objval = [vect](const matrix& x) + { + return sum(squared(mat(vect)-x)); + }; + + auto is_monotonic = [](const matrix& x) + { + for (long i = 1; i < x.size(); ++i) + { + if (x(i-1) > x(i)) + return false; + } + return true; + }; + + matrix lower(5), upper(5); + lower = 0; + lower(0) = -4; + upper = 4; + // find the solution with find_min_global() and then check that it matches + auto result = find_min_global(f, lower, upper, max_function_calls(40)); + + for (long i = 1; i < result.x.size(); ++i) + result.x(i) += result.x(i-1); + + isotonic_regression mr; + mr(vect); + + dlog << LINFO << "err: "<< objval(mat(vect)) - objval(result.x); + + DLIB_CASSERT(is_monotonic(mat(vect))); + DLIB_CASSERT(is_monotonic(result.x)); + // isotonic_regression should be at least as good as find_min_global(). + DLIB_CASSERT(objval(mat(vect)) - objval(result.x) < 1e-13); + } + + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/kcentroid.cpp b/ml/dlib/dlib/test/kcentroid.cpp new file mode 100644 index 000000000..c16ab6eca --- /dev/null +++ b/ml/dlib/dlib/test/kcentroid.cpp @@ -0,0 +1,684 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" +#include "checkerboard.h" +#include + +#include "tester.h" +#include + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.kcentroid"); + +// ---------------------------------------------------------------------------------------- + + template + struct unopt_sparse_linear_kernel + { + typedef typename T::value_type::second_type scalar_type; + typedef T sample_type; + typedef default_memory_manager mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return dot(a,b); + } + + bool operator== ( + const unopt_sparse_linear_kernel& + ) const + { + return true; + } + }; + + template + struct unopt_linear_kernel + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + scalar_type operator() ( + const sample_type& a, + const sample_type& b + ) const + { + return trans(a)*b; + } + + bool operator== ( + const unopt_linear_kernel& + ) const + { + return true; + } + }; + + bool approx_equal(double a, double b) + { + return (std::abs(a-b) < 1000*(std::numeric_limits::epsilon())); + } + + bool approx_equal(double a, double b, double eps) + { + return (std::abs(a-b) < eps); + } + + template + double dist ( + const K& k, + const matrix& a, + const matrix& b + ) + /*! + ensures + - returns the distance between the a and b vectors in the + feature space defined by the given kernel k. + !*/ + { + const double bias = std::sqrt(k.offset); + return std::sqrt(length_squared(a-colm(b,0,4)) + std::pow(b(4)-bias,2.0)); + + } + + template + double dist ( + const K& k, + std::map a, + std::map b + ) + /*! + ensures + - returns the distance between the a and b vectors in the + feature space defined by the given kernel k. + !*/ + { + double temp = 0; + const double bias = std::sqrt(k.offset); + temp += std::pow(a[0]-b[0],2.0); + temp += std::pow(a[1]-b[1],2.0); + temp += std::pow(a[2]-b[2],2.0); + temp += std::pow(a[3]-b[3],2.0); + temp += std::pow(bias-b[4],2.0); + + return std::sqrt(temp); + + } + +// ---------------------------------------------------------------------------------------- + + template + void test_kcentroid_with_linear_kernel( + ) + /*! + requires + - kernel_type::sample_type == a matrix + - kernel_type == a kernel that just computes a dot product + between its inputs. I.e. a linear kernel + ensures + - tests the kcentroid object with the given kernel + !*/ + { + // Here we declare that our samples will be 2 dimensional column vectors. + typedef typename kernel_type::sample_type sample_type; + + kernel_type default_kernel; + kcentroid test(default_kernel,0.001,20); + + sample_type temp, temp2; + + temp = 2,0,0,0,0; + dlog << LDEBUG << test(temp) ; + dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ; + + DLIB_TEST(approx_equal(test(temp), 2)); + DLIB_TEST(approx_equal(test.squared_norm(), 0)); + + // make test store the point(2,0,0,0,0) + test.train(temp, 0, 1); + dlog << LDEBUG << test(temp) ; + dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ; + DLIB_TEST(approx_equal(test(temp), 0)); + DLIB_TEST(approx_equal(test.get_distance_function()(temp), 0)); + DLIB_TEST(approx_equal(test.squared_norm(), 4)); + + temp = 0,2,0,0,0; + dlog << LDEBUG << test(temp) ; + DLIB_TEST(approx_equal(test(temp), std::sqrt(2*2 + 2*2.0))); + DLIB_TEST(approx_equal(test.squared_norm(), 4)); + + // make test store the point(0,2,0,0,0) + test.train(temp, 0, 1); + + dlog << LDEBUG << test(temp) ; + DLIB_TEST(approx_equal(test(temp), 0)); + DLIB_TEST(approx_equal(test.squared_norm(), 4)); + + temp = 2,0,0,0,0; + DLIB_TEST(approx_equal(test(temp), std::sqrt(2*2 + 2*2.0))); + DLIB_TEST(approx_equal(test.squared_norm(), 4)); + + // make test store the point(1,1,0,0,0) + test.train(temp, 0.5, 0.5); + + temp = 0; + DLIB_TEST(approx_equal(test(temp), std::sqrt(2.0))); + DLIB_TEST(approx_equal(test.squared_norm(), 2)); + + // make test store the point(1,1,0,3,0) + temp = 0,0,0,3,0; + temp2 = 1,1,0,3,0; + test.train(temp, 1, 1); + + temp = 0; + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + temp = 1,2,3,4,5; + DLIB_TEST(approx_equal(test(temp), length(temp2-temp))); + DLIB_TEST(approx_equal(test.get_distance_function()(temp), length(temp2-temp))); + + // make test store the point(0,1,0,3,-1) + temp = 1,0,0,0,1; + test.train(temp, 1, -1); + temp2 = 0,1,0,3,-1; + + temp = 0; + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + temp = 1,2,3,4,5; + DLIB_TEST(approx_equal(test(temp), length(temp2-temp))); + + + // make test store the -1*point(0,1,0,3,-1) + temp = 0,0,0,0,0; + test.train(temp, -1, 0); + temp2 = -temp2; + + temp = 0; + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + temp = 1,2,-3,4,5; + DLIB_TEST(approx_equal(test(temp), length(temp2-temp))); + + + + // make test store the point(0,0,0,0,0) + temp = 0,0,0,0,0; + test.train(temp, 0, 0); + temp2 = 0; + + temp = 0; + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + temp = 1,2,-3,4,5; + DLIB_TEST(approx_equal(test(temp), length(temp2-temp))); + + + + // make test store the point(1,0,0,0,0) + temp = 1,0,0,0,0; + test.train(temp, 1, 1); + temp2 = 1,0,0,0,0; + + temp = 0; + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + DLIB_TEST(approx_equal(test.inner_product(test), length_squared(temp2))); + temp = 1,2,-3,4,5; + DLIB_TEST(approx_equal(test(temp), length(temp2-temp))); + DLIB_TEST(approx_equal(test(test), 0)); + DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0)); + + } + +// ---------------------------------------------------------------------------------------- + + template + void test_kcentroid_with_offset_linear_kernel( + ) + /*! + requires + - kernel_type::sample_type == a matrix + - kernel_type == a kernel that just computes a dot product + between its inputs + some constant. I.e. a linear kernel + wrapped by offset_kernel + ensures + - tests the kcentroid object with the given kernel + !*/ + { + // Here we declare that our samples will be 2 dimensional column vectors. + typedef typename kernel_type::sample_type sample_type; + + kernel_type k; + kcentroid test(k,0.001,20); + + sample_type temp, temp2, temp3; + + matrix val, val2; + + const double b = std::sqrt(k.offset); + + temp = 2,0,0,0; + temp2 = 0; + val = 0; + DLIB_TEST(approx_equal(test(temp), dist(k,temp,val))); + DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + + + temp2 = 0; + + // make test store the point(0,0,0,0,b) + val = 0,0,0,0,b; + test.train(temp2, 0,1); + + temp = 2,0,0,0; + dlog << LDEBUG << test(temp) ; + dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ; + + DLIB_TEST(approx_equal(test(temp), dist(k,temp,val))); + DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val))); + DLIB_TEST_MSG(approx_equal(test.get_distance_function()(temp2), dist(k,temp2,val), 1e-6), + test.get_distance_function()(temp2) - dist(k,temp2,val) << " compare to: " << + test(temp2) - dist(k,temp2,val) + ); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + + + // make test store the point(0,0,0,0,0) + val = 0,0,0,0,0; + test.train(temp2, 1,-1); + DLIB_TEST(approx_equal(test(temp), dist(k,temp,val))); + DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val))); + DLIB_TEST_MSG(approx_equal(test.get_distance_function()(temp2), dist(k,temp2,val)), + test.get_distance_function()(temp2) - dist(k,temp2,val) << " compare to: " << + test(temp2) - dist(k,temp2,val) + ); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + + + + val2 = 0,1,0,0,b; + val += val2; + temp2 = 0,1,0,0; + // make test store the point val + test.train(temp2, 1,1); + + temp = 1,0,3,0; + DLIB_TEST(approx_equal(test(temp), dist(k,temp,val))); + DLIB_TEST_MSG(approx_equal(test(temp2), dist(k,temp2,val), 1e-7), + test(temp2) - dist(k,temp2,val)); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + DLIB_TEST_MSG(approx_equal(test(test), 0, 1e-7), test(test)); + + + val2 = 0,1,2.6,8,b; + val += val2; + temp2 = 0,1,2.6,8; + // make test store the point val + test.train(temp2, 1,1); + + temp = 1,1,3,0; + DLIB_TEST(approx_equal(test(temp), dist(k,temp,val))); + DLIB_TEST_MSG(approx_equal(test(temp2), dist(k,temp2,val)), test(temp2) - dist(k,temp2,val)); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + DLIB_TEST(approx_equal(test.inner_product(test), length_squared(val))); + DLIB_TEST(approx_equal(test(test), 0)); + DLIB_TEST_MSG(approx_equal(test.get_distance_function()(test.get_distance_function()), 0, 1e-6), + test.get_distance_function()(test.get_distance_function())); + } + +// ---------------------------------------------------------------------------------------- + + template + void test_kcentroid_with_sparse_linear_kernel( + ) + /*! + requires + - kernel_type::sample_type == a std::map + - kernel_type == a kernel that just computes a dot product + between its inputs. I.e. a linear kernel + ensures + - tests the kcentroid object with the given kernel + !*/ + { + // Here we declare that our samples will be 2 dimensional column vectors. + typedef typename kernel_type::sample_type sample_type; + + kernel_type default_kernel; + kcentroid test(default_kernel,0.001,20); + + dlog << LDEBUG << "AAAA 1" ; + + sample_type temp, temp2; + + temp[0] = 2; + dlog << LDEBUG << test(temp) ; + dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ; + + DLIB_TEST(approx_equal(test(temp), 2)); + DLIB_TEST(approx_equal(test.squared_norm(), 0)); + + // make test store the point(2,0,0,0,0) + test.train(temp, 0, 1); + dlog << LDEBUG << test(temp) ; + dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ; + DLIB_TEST(approx_equal(test(temp), 0)); + DLIB_TEST(approx_equal(test.squared_norm(), 4)); + + dlog << LDEBUG << "AAAA 2" ; + temp.clear(); + temp[1] = 2; + dlog << LDEBUG << test(temp) ; + DLIB_TEST(approx_equal(test(temp), std::sqrt(2*2 + 2*2.0))); + DLIB_TEST(approx_equal(test.squared_norm(), 4)); + + // make test store the point(0,2,0,0,0) + test.train(temp, 0, 1); + + dlog << LDEBUG << test(temp) ; + DLIB_TEST(approx_equal(test(temp), 0)); + DLIB_TEST(approx_equal(test.squared_norm(), 4)); + + temp.clear(); + temp[0] = 2; + DLIB_TEST(approx_equal(test(temp), std::sqrt(2*2 + 2*2.0))); + DLIB_TEST(approx_equal(test.squared_norm(), 4)); + + // make test store the point(1,1,0,0,0) + test.train(temp, 0.5, 0.5); + + dlog << LDEBUG << "AAAA 3" ; + temp.clear(); + DLIB_TEST(approx_equal(test(temp), std::sqrt(2.0))); + DLIB_TEST(approx_equal(test.squared_norm(), 2)); + DLIB_TEST(approx_equal(test(test), 0)); + DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0)); + + dlog << LDEBUG << "AAAA 3.1" ; + // make test store the point(1,1,0,3,0) + temp.clear(); temp[3] = 3; + temp2.clear(); + temp2[0] = 1; + temp2[1] = 1; + temp2[3] = 3; + test.train(temp, 1, 1); + + dlog << LDEBUG << "AAAA 3.2" ; + temp.clear(); + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + dlog << LDEBUG << "AAAA 3.3" ; + temp[0] = 1; + temp[1] = 2; + temp[2] = 3; + temp[3] = 4; + temp[4] = 5; + dlog << LDEBUG << "AAAA 3.4" ; + double junk = dlib::distance(temp2,temp); + dlog << LDEBUG << "AAAA 3.5" ; + DLIB_TEST(approx_equal(test(temp), junk) ); + + dlog << LDEBUG << "AAAA 4" ; + // make test store the point(0,1,0,3,-1) + temp.clear(); + temp[0] = 1; + temp[4] = 1; + test.train(temp, 1, -1); + temp2.clear(); + temp2[1] = 1; + temp2[3] = 3; + temp2[4] = -1; + + temp.clear(); + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + temp[0] = 1; + temp[1] = 2; + temp[2] = 3; + temp[3] = 4; + temp[4] = 5; + DLIB_TEST(approx_equal(test(temp), dlib::distance(temp2,temp))); + + + // make test store the -1*point(0,1,0,3,-1) + temp.clear(); + test.train(temp, -1, 0); + temp2[0] = -temp2[0]; + temp2[1] = -temp2[1]; + temp2[2] = -temp2[2]; + temp2[3] = -temp2[3]; + temp2[4] = -temp2[4]; + + dlog << LDEBUG << "AAAA 5" ; + temp.clear(); + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + temp[0] = 1; + temp[1] = 2; + temp[2] = -3; + temp[3] = 4; + temp[4] = 5; + DLIB_TEST(approx_equal(test(temp), dlib::distance(temp2,temp))); + + + + // make test store the point(0,0,0,0,0) + temp.clear(); + test.train(temp, 0, 0); + temp2.clear(); + + temp.clear(); + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + temp[0] = 1; + temp[1] = 2; + temp[2] = -3; + temp[3] = 4; + temp[4] = 5; + DLIB_TEST(approx_equal(test(temp), dlib::distance(temp2,temp))); + DLIB_TEST(approx_equal(test.get_distance_function()(temp), dlib::distance(temp2,temp))); + + + dlog << LDEBUG << "AAAA 6" ; + + // make test store the point(1,0,0,0,0) + temp.clear(); + temp[0] = 1; + test.train(temp, 1, 1); + temp2.clear(); + temp2[0] = 1; + + temp.clear(); + DLIB_TEST(approx_equal(test(temp), length(temp2))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(temp2))); + DLIB_TEST(approx_equal(test.inner_product(test), length_squared(temp2))); + temp[0] = 1; + temp[1] = 2; + temp[2] = -3; + temp[3] = 4; + temp[4] = 5; + DLIB_TEST(approx_equal(test(temp), dlib::distance(temp2,temp))); + DLIB_TEST(approx_equal(test.get_distance_function()(temp), dlib::distance(temp2,temp))); + DLIB_TEST(approx_equal(test(test), 0)); + DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0)); + + dlog << LDEBUG << "AAAA 7" ; + } + +// ---------------------------------------------------------------------------------------- + + template + void test_kcentroid_with_offset_sparse_linear_kernel( + ) + /*! + requires + - kernel_type::sample_type == a std::map + - kernel_type == a kernel that just computes a dot product + between its inputs + some constant. I.e. a linear kernel + wrapped by offset_kernel + ensures + - tests the kcentroid object with the given kernel + !*/ + { + // Here we declare that our samples will be 2 dimensional column vectors. + typedef typename kernel_type::sample_type sample_type; + + kernel_type k; + kcentroid test(k,0.001,20); + + sample_type temp, temp2, temp3; + + std::map val, val2; + + const double b = std::sqrt(k.offset); + + temp.clear(); + temp[0] = 2; + temp2.clear(); + val.clear(); + DLIB_TEST(approx_equal(test(temp), dist(k,temp,val))); + DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + + + temp2.clear(); + + // make test store the point(0,0,0,0,b) + val.clear(); + val[4] = b; + test.train(temp2, 0,1); + + temp.clear(); + temp[0] = 2; + dlog << LDEBUG << test(temp) ; + dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ; + + DLIB_TEST(approx_equal(test(temp), dist(k,temp,val))); + DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val))); + DLIB_TEST_MSG(approx_equal(test.get_distance_function()(temp2), dist(k,temp2,val), 1e-7), + test.get_distance_function()(temp2) - dist(k,temp2,val) + ); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + DLIB_TEST(approx_equal(test(test), 0)); + DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0, 1e-6)); + + // make test store the point(0,0,0,0,0) + val.clear(); + test.train(temp2, 1,-1); + + temp.clear(); + temp[0] = 2; + dlog << LDEBUG << test(temp) ; + dlog << LDEBUG << "squared_norm(): " << test.squared_norm() ; + + DLIB_TEST_MSG(approx_equal(test(temp), dist(k,temp,val)), test(temp) - dist(k,temp,val)); + DLIB_TEST(approx_equal(test(temp2), dist(k,temp2,val))); + DLIB_TEST(approx_equal(test.get_distance_function()(temp2), dist(k,temp2,val))); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + DLIB_TEST(approx_equal(test(test), 0)); + DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0)); + + val2.clear(); + val2[0] = 0; + val2[1] = 1; + val2[2] = 0; + val2[3] = 0; + val2[4] = b; + for (unsigned int i = 0; i < 5; ++i) val[i] += val2[i]; + temp2.clear(); + temp2[1] = 1; + // make test store the point val + test.train(temp2, 1,1); + + temp.clear(); + temp[0] = 1; + temp[2] = 3; + DLIB_TEST(approx_equal(test(temp), dist(k,temp,val))); + DLIB_TEST_MSG(approx_equal(test(temp2), dist(k,temp2,val), 1e-7), + test(temp2) - dist(k,temp2,val)); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + + + val2.clear(); + val2[0] = 0; + val2[1] = 1; + val2[2] = 2.6; + val2[3] = 8; + val2[4] = b; + for (unsigned int i = 0; i < 5; ++i) val[i] += val2[i]; + + temp2.clear(); + temp2[0] = 0; + temp2[1] = 1; + temp2[2] = 2.6; + temp2[3] = 8; + // make test store the point val + test.train(temp2, 1,1); + + temp.clear(); + temp[0] = 1; + temp[1] = 1; + temp[2] = 3; + temp[3] = 0; + DLIB_TEST(approx_equal(test(temp), dist(k,temp,val))); + DLIB_TEST_MSG(approx_equal(test(temp2), dist(k,temp2,val)), test(temp2) - dist(k,temp2,val)); + DLIB_TEST(approx_equal(test.squared_norm(), length_squared(val))); + DLIB_TEST(approx_equal(test.inner_product(test), length_squared(val))); + DLIB_TEST_MSG(approx_equal(test(test), 0, 1e-6), test(test)); + DLIB_TEST(approx_equal(test.get_distance_function()(test.get_distance_function()), 0)); + } + +// ---------------------------------------------------------------------------------------- + + class kcentroid_tester : public tester + { + public: + kcentroid_tester ( + ) : + tester ("test_kcentroid", + "Runs tests on the kcentroid components.") + {} + + void perform_test ( + ) + { + // The idea here is to exercize all the various overloads of the kcentroid object. We also want + // to exercize the non-overloaded default version. That is why we have these unopt_* linear + // kernels + test_kcentroid_with_linear_kernel > >(); + test_kcentroid_with_offset_linear_kernel > > >(); + test_kcentroid_with_linear_kernel > >(); + test_kcentroid_with_offset_linear_kernel > > >(); + test_kcentroid_with_sparse_linear_kernel > >(); + test_kcentroid_with_offset_sparse_linear_kernel > > >(); + test_kcentroid_with_sparse_linear_kernel > >(); + test_kcentroid_with_offset_sparse_linear_kernel > > >(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/kernel_matrix.cpp b/ml/dlib/dlib/test/kernel_matrix.cpp new file mode 100644 index 000000000..8fc3a2d2c --- /dev/null +++ b/ml/dlib/dlib/test/kernel_matrix.cpp @@ -0,0 +1,161 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.kernel_matrix"); + + + class kernel_matrix_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + kernel_matrix_tester ( + ) : + tester ( + "test_kernel_matrix", // the command line argument name for this test + "Run tests on the kernel_matrix functions.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + + void perform_test ( + ) + { + print_spinner(); + + typedef matrix sample_type; + typedef radial_basis_kernel kernel_type; + kernel_type kern(0.1); + + std::vector vect1; + std::vector vect2; + + const sample_type samp = randm(4,1); + sample_type samp2, samp3; + + vect1.push_back(randm(4,1)); + vect1.push_back(randm(4,1)); + vect1.push_back(randm(4,1)); + vect1.push_back(randm(4,1)); + + vect2.push_back(randm(4,1)); + vect2.push_back(randm(4,1)); + vect2.push_back(randm(4,1)); + vect2.push_back(randm(4,1)); + vect2.push_back(randm(4,1)); + + matrix K; + + K.set_size(vect1.size(), vect2.size()); + for (long r = 0; r < K.nr(); ++r) + { + for (long c = 0; c < K.nc(); ++c) + { + K(r,c) = kern(vect1[r], vect2[c]); + } + } + DLIB_TEST(equal(K, kernel_matrix(kern, vect1, vect2))); + DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect1), mat(vect2)))); + + + K.set_size(vect2.size(), vect1.size()); + for (long r = 0; r < K.nr(); ++r) + { + for (long c = 0; c < K.nc(); ++c) + { + K(r,c) = kern(vect2[r], vect1[c]); + } + } + DLIB_TEST(equal(K, kernel_matrix(kern, vect2, vect1))); + DLIB_TEST(equal(K, tmp(kernel_matrix(kern, vect2, vect1)))); + DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect2), mat(vect1)))); + + + K.set_size(vect1.size(), vect1.size()); + for (long r = 0; r < K.nr(); ++r) + { + for (long c = 0; c < K.nc(); ++c) + { + K(r,c) = kern(vect1[r], vect1[c]); + } + } + DLIB_TEST(equal(K, kernel_matrix(kern, vect1, vect1))); + DLIB_TEST(equal(K, tmp(kernel_matrix(kern, vect1, vect1)))); + DLIB_TEST(equal(K, kernel_matrix(kern, vect1))); + DLIB_TEST(equal(K, tmp(kernel_matrix(kern, vect1)))); + DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect1), mat(vect1)))); + DLIB_TEST(equal(K, tmp(kernel_matrix(kern, mat(vect1), mat(vect1))))); + DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect1)))); + DLIB_TEST(equal(K, tmp(kernel_matrix(kern, mat(vect1))))); + + + K.set_size(vect1.size(),1); + for (long r = 0; r < K.nr(); ++r) + { + for (long c = 0; c < K.nc(); ++c) + { + K(r,c) = kern(vect1[r], samp); + } + } + DLIB_TEST(equal(K, kernel_matrix(kern, vect1, samp))); + DLIB_TEST(equal(K, kernel_matrix(kern, mat(vect1), samp))); + + + K.set_size(1, vect1.size()); + for (long r = 0; r < K.nr(); ++r) + { + for (long c = 0; c < K.nc(); ++c) + { + K(r,c) = kern(samp, vect1[c]); + } + } + DLIB_TEST(equal(K, kernel_matrix(kern, samp, vect1))); + DLIB_TEST(equal(K, kernel_matrix(kern, samp, mat(vect1)))); + DLIB_TEST(equal(K, tmp(kernel_matrix(kern, samp, vect1)))); + DLIB_TEST(equal(K, tmp(kernel_matrix(kern, samp, mat(vect1))))); + + + + samp2 = samp; + samp3 = samp; + + // test the alias detection + samp2 = kernel_matrix(kern, vect1, samp2); + DLIB_TEST(equal(samp2, kernel_matrix(kern, vect1, samp))); + + samp3 = trans(kernel_matrix(kern, samp3, vect2)); + DLIB_TEST(equal(samp3, trans(kernel_matrix(kern, samp, vect2)))); + + + samp2 += kernel_matrix(kern, vect1, samp); + DLIB_TEST(equal(samp2, 2*kernel_matrix(kern, vect1, samp))); + + samp3 += trans(kernel_matrix(kern, samp, vect2)); + DLIB_TEST(equal(samp3, 2*trans(kernel_matrix(kern, samp, vect2)))); + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + kernel_matrix_tester a; + +} + + diff --git a/ml/dlib/dlib/test/kmeans.cpp b/ml/dlib/dlib/test/kmeans.cpp new file mode 100644 index 000000000..95c037b77 --- /dev/null +++ b/ml/dlib/dlib/test/kmeans.cpp @@ -0,0 +1,163 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.kmeans"); + + dlib::rand rnd; + + template + void run_test( + const std::vector& seed_centers + ) + { + print_spinner(); + + + sample_type samp; + + std::vector samples; + + + for (unsigned long j = 0; j < seed_centers.size(); ++j) + { + for (int i = 0; i < 250; ++i) + { + samp = randm(seed_centers[0].size(),1,rnd) - 0.5; + samples.push_back(samp + seed_centers[j]); + } + } + + randomize_samples(samples); + + { + std::vector centers; + pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel()); + + find_clusters_using_kmeans(samples, centers); + + DLIB_TEST(centers.size() == seed_centers.size()); + + std::vector hits(centers.size(),0); + for (unsigned long i = 0; i < samples.size(); ++i) + { + unsigned long best_idx = 0; + double best_dist = 1e100; + for (unsigned long j = 0; j < centers.size(); ++j) + { + if (length(samples[i] - centers[j]) < best_dist) + { + best_dist = length(samples[i] - centers[j]); + best_idx = j; + } + } + hits[best_idx]++; + } + + for (unsigned long i = 0; i < hits.size(); ++i) + { + DLIB_TEST(hits[i] == 250); + } + } + { + std::vector centers; + pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel()); + + find_clusters_using_angular_kmeans(samples, centers); + + DLIB_TEST(centers.size() == seed_centers.size()); + + std::vector hits(centers.size(),0); + for (unsigned long i = 0; i < samples.size(); ++i) + { + unsigned long best_idx = 0; + double best_dist = 1e100; + for (unsigned long j = 0; j < centers.size(); ++j) + { + if (length(samples[i] - centers[j]) < best_dist) + { + best_dist = length(samples[i] - centers[j]); + best_idx = j; + } + } + hits[best_idx]++; + } + + for (unsigned long i = 0; i < hits.size(); ++i) + { + DLIB_TEST(hits[i] == 250); + } + } + } + + + class test_kmeans : public tester + { + public: + test_kmeans ( + ) : + tester ("test_kmeans", + "Runs tests on the find_clusters_using_kmeans() function.") + {} + + void perform_test ( + ) + { + { + dlog << LINFO << "test dlib::vector"; + typedef dlib::vector sample_type; + std::vector seed_centers; + seed_centers.push_back(sample_type(10,10)); + seed_centers.push_back(sample_type(10,-10)); + seed_centers.push_back(sample_type(-10,10)); + seed_centers.push_back(sample_type(-10,-10)); + + run_test(seed_centers); + } + { + dlog << LINFO << "test dlib::vector"; + typedef dlib::vector sample_type; + std::vector seed_centers; + seed_centers.push_back(sample_type(10,10)); + seed_centers.push_back(sample_type(10,-10)); + seed_centers.push_back(sample_type(-10,10)); + seed_centers.push_back(sample_type(-10,-10)); + + run_test(seed_centers); + } + { + dlog << LINFO << "test dlib::matrix"; + typedef dlib::matrix sample_type; + std::vector seed_centers; + sample_type samp; + samp = 10,10,0; seed_centers.push_back(samp); + samp = -10,10,1; seed_centers.push_back(samp); + samp = -10,-10,2; seed_centers.push_back(samp); + + run_test(seed_centers); + } + + + } + } a; + + + +} + + + diff --git a/ml/dlib/dlib/test/learning_to_track.cpp b/ml/dlib/dlib/test/learning_to_track.cpp new file mode 100644 index 000000000..730c01ba8 --- /dev/null +++ b/ml/dlib/dlib/test/learning_to_track.cpp @@ -0,0 +1,306 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include "tester.h" +#include +#include + + + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.learning_to_track"); + +// ---------------------------------------------------------------------------------------- + + struct detection_dense + { + typedef struct track_dense track_type; + matrix measurements; + }; + + + struct track_dense + { + typedef matrix feature_vector_type; + + track_dense() + { + time_since_last_association = 0; + } + + void get_similarity_features(const detection_dense det, feature_vector_type& feats) const + { + feats = abs(last_measurements - det.measurements); + } + + void update_track(const detection_dense det) + { + last_measurements = det.measurements; + time_since_last_association = 0; + } + + void propagate_track() + { + ++time_since_last_association; + } + + matrix last_measurements; + unsigned long time_since_last_association; + }; + +// ---------------------------------------------------------------------------------------- + + struct detection_sparse + { + typedef struct track_sparse track_type; + matrix measurements; + }; + + + struct track_sparse + { + typedef std::vector > feature_vector_type; + + track_sparse() + { + time_since_last_association = 0; + } + + void get_similarity_features(const detection_sparse det, feature_vector_type& feats) const + { + matrix temp = abs(last_measurements - det.measurements); + feats.clear(); + for (long i = 0; i < temp.size(); ++i) + feats.push_back(make_pair(i, temp(i))); + } + + void update_track(const detection_sparse det) + { + last_measurements = det.measurements; + time_since_last_association = 0; + } + + void propagate_track() + { + ++time_since_last_association; + } + + matrix last_measurements; + unsigned long time_since_last_association; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + dlib::rand rnd; + const long num_objects = 4; + const long num_properties = 6; + std::vector > object_properties(num_objects); + + void initialize_object_properties() + { + rnd.set_seed("23ja2oirfjaf"); + for (unsigned long i = 0; i < object_properties.size(); ++i) + object_properties[i] = randm(num_properties,1,rnd); + } + + template + detection sample_detection_from_sensor(long object_id) + { + DLIB_CASSERT(object_id < num_objects, + "You can't ask to sample a detection from an object that doesn't exist."); + detection temp; + // Set the measurements equal to the object's true property values plus a little bit of + // noise. + temp.measurements = object_properties[object_id] + randm(num_properties,1,rnd)*0.1; + return temp; + } + +// ---------------------------------------------------------------------------------------- + + + template + std::vector > > make_random_tracking_data_for_training() + { + typedef std::vector > detections_at_single_time_step; + typedef std::vector track_history; + + track_history data; + + // At each time step we get a set of detections from the objects in the world. + // Simulate 100 time steps worth of data where there are 3 objects present. + const int num_time_steps = 100; + for (int i = 0; i < num_time_steps; ++i) + { + detections_at_single_time_step dets(3); + // sample a detection from object 0 + dets[0].det = sample_detection_from_sensor(0); + dets[0].label = 0; + + // sample a detection from object 1 + dets[1].det = sample_detection_from_sensor(1); + dets[1].label = 1; + + // sample a detection from object 2 + dets[2].det = sample_detection_from_sensor(2); + dets[2].label = 2; + + randomize_samples(dets, rnd); + data.push_back(dets); + } + + // Now let's imagine object 1 and 2 are gone but a new object, object 3 has arrived. + for (int i = 0; i < num_time_steps; ++i) + { + detections_at_single_time_step dets(2); + // sample a detection from object 0 + dets[0].det = sample_detection_from_sensor(0); + dets[0].label = 0; + + // sample a detection from object 3 + dets[1].det = sample_detection_from_sensor(3); + dets[1].label = 3; + + randomize_samples(dets, rnd); + data.push_back(dets); + } + + return data; + } + +// ---------------------------------------------------------------------------------------- + + template + std::vector make_random_detections(long num_dets) + { + DLIB_CASSERT(num_dets <= num_objects, + "You can't ask for more detections than there are objects in our little simulation."); + + std::vector dets(num_dets); + for (unsigned long i = 0; i < dets.size(); ++i) + { + dets[i] = sample_detection_from_sensor(i); + } + randomize_samples(dets, rnd); + return dets; + } + +// ---------------------------------------------------------------------------------------- + + template + void test_tracking_stuff() + { + print_spinner(); + + + typedef std::vector > detections_at_single_time_step; + typedef std::vector track_history; + std::vector data; + data.push_back(make_random_tracking_data_for_training()); + data.push_back(make_random_tracking_data_for_training()); + data.push_back(make_random_tracking_data_for_training()); + data.push_back(make_random_tracking_data_for_training()); + data.push_back(make_random_tracking_data_for_training()); + + + structural_track_association_trainer trainer; + trainer.set_c(1000); + track_association_function assoc = trainer.train(data); + + double test_val = test_track_association_function(assoc, data); + DLIB_TEST_MSG( test_val == 1, test_val); + test_val = cross_validate_track_association_trainer(trainer, data, 5); + DLIB_TEST_MSG ( test_val == 1, test_val); + + + + typedef typename detection::track_type track; + std::vector tracks; + + std::vector dets = make_random_detections(3); + assoc(tracks, dets); + DLIB_TEST(tracks.size() == 3); + + dets = make_random_detections(3); + assoc(tracks, dets); + DLIB_TEST(tracks.size() == 3); + + dets = make_random_detections(3); + assoc(tracks, dets); + DLIB_TEST(tracks.size() == 3); + + dets = make_random_detections(4); + assoc(tracks, dets); + DLIB_TEST(tracks.size() == 4); + + dets = make_random_detections(3); + assoc(tracks, dets); + DLIB_TEST(tracks.size() == 4); + unsigned long total_miss = 0; + for (unsigned long i = 0; i < tracks.size(); ++i) + total_miss += tracks[i].time_since_last_association; + DLIB_TEST(total_miss == 1); + + dets = make_random_detections(3); + assoc(tracks, dets); + DLIB_TEST(tracks.size() == 4); + total_miss = 0; + unsigned long num_zero = 0; + for (unsigned long i = 0; i < tracks.size(); ++i) + { + total_miss += tracks[i].time_since_last_association; + if (tracks[i].time_since_last_association == 0) + ++num_zero; + } + DLIB_TEST(total_miss == 2); + DLIB_TEST(num_zero == 3); + + + + ostringstream sout; + serialize(assoc, sout); + + istringstream sin(sout.str()); + deserialize(assoc, sin); + DLIB_TEST( test_track_association_function(assoc, data) == 1); + } + + +// ---------------------------------------------------------------------------------------- + + class test_learning_to_track : public tester + { + public: + test_learning_to_track ( + ) : + tester ("test_learning_to_track", + "Runs tests on the assignment learning code.") + {} + + void perform_test ( + ) + { + initialize_object_properties(); + for (int i = 0; i < 3; ++i) + { + dlog << LINFO << "run test_tracking_stuff()"; + test_tracking_stuff(); + dlog << LINFO << "run test_tracking_stuff()"; + test_tracking_stuff(); + } + } + } a; + +// ---------------------------------------------------------------------------------------- + +} + + diff --git a/ml/dlib/dlib/test/least_squares.cpp b/ml/dlib/dlib/test/least_squares.cpp new file mode 100644 index 000000000..c4282ad05 --- /dev/null +++ b/ml/dlib/dlib/test/least_squares.cpp @@ -0,0 +1,452 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include "optimization_test_functions.h" +#include +#include +#include +#include +#include +#include "../rand.h" + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + using namespace dlib::test_functions; + + logger dlog("test.least_squares"); + +// ---------------------------------------------------------------------------------------- + + void test_with_chebyquad() + { + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(2); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch); + dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2)); + + DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5); + + } + { + matrix ch; + + ch = chebyquad_start(2); + + solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "LM chebyquad 2 obj: " << chebyquad(ch); + dlog << LINFO << "LM chebyquad 2 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "LM chebyquad 2 error: " << length(ch - chebyquad_solution(2)); + + DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5); + + } + + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(2); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch); + dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2)); + + DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(2); + + solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "LM chebyquad 2 obj: " << chebyquad(ch); + dlog << LINFO << "LM chebyquad 2 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "LM chebyquad 2 error: " << length(ch - chebyquad_solution(2)); + + DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5); + + } + + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(4); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "chebyquad 4 obj: " << chebyquad(ch); + dlog << LINFO << "chebyquad 4 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "chebyquad 4 error: " << length(ch - chebyquad_solution(4)); + + DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(4); + + solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "LM chebyquad 4 obj: " << chebyquad(ch); + dlog << LINFO << "LM chebyquad 4 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "LM chebyquad 4 error: " << length(ch - chebyquad_solution(4)); + + DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5); + + } + + + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(6); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "chebyquad 6 obj: " << chebyquad(ch); + dlog << LINFO << "chebyquad 6 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "chebyquad 6 error: " << length(ch - chebyquad_solution(6)); + + // the ch variable contains a permutation of what is in chebyquad_solution(6). + // Apparently there is more than one minimum?. Just check that the objective + // goes to zero. + DLIB_TEST(chebyquad(ch) < 1e-10); + + } + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(6); + + solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "LM chebyquad 6 obj: " << chebyquad(ch); + dlog << LINFO << "LM chebyquad 6 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "LM chebyquad 6 error: " << length(ch - chebyquad_solution(6)); + + DLIB_TEST(chebyquad(ch) < 1e-10); + + } + + + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(8); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "chebyquad 8 obj: " << chebyquad(ch); + dlog << LINFO << "chebyquad 8 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "chebyquad 8 error: " << length(ch - chebyquad_solution(8)); + + DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(8); + + solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), + chebyquad_residual, + derivative(chebyquad_residual), + range(0,ch.size()-1), + ch); + + dlog << LINFO << "LM chebyquad 8 obj: " << chebyquad(ch); + dlog << LINFO << "LM chebyquad 8 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "LM chebyquad 8 error: " << length(ch - chebyquad_solution(8)); + + DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5); + + } + } + +// ---------------------------------------------------------------------------------------- + + void test_with_brown() + { + print_spinner(); + { + matrix ch; + + ch = brown_start(); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 300), + brown_residual, + derivative(brown_residual), + range(1,20), + ch); + + dlog << LINFO << "brown obj: " << brown(ch); + dlog << LINFO << "brown der: " << length(brown_derivative(ch)); + dlog << LINFO << "brown error: " << length(ch - brown_solution()); + + DLIB_TEST_MSG(length(ch - brown_solution()) < 1e-5,length(ch - brown_solution()) ); + + } + print_spinner(); + { + matrix ch; + + ch = brown_start(); + + solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), + brown_residual, + derivative(brown_residual), + range(1,20), + ch); + + dlog << LINFO << "LM brown obj: " << brown(ch); + dlog << LINFO << "LM brown der: " << length(brown_derivative(ch)); + dlog << LINFO << "LM brown error: " << length(ch - brown_solution()); + + DLIB_TEST(length(ch - brown_solution()) < 1e-5); + + } + } + +// ---------------------------------------------------------------------------------------- + +// These functions are declared here because wrapping the real rosen functions in this +// way avoids triggering a bug in visual studio 2005 which prevents this code from compiling. + double rosen_residual_double (int i, const matrix& m) + { return rosen_residual(i,m); } + float rosen_residual_float (int i, const matrix& m) + { return rosen_residual(i,m); } + + matrix rosen_residual_derivative_double (int i, const matrix& m) + { return rosen_residual_derivative(i,m); } + /* + matrix rosen_residual_derivative_float (int i, const matrix& m) + { return rosen_residual_derivative(i,m); } + */ + + double rosen_big_residual_double (int i, const matrix& m) + { return rosen_big_residual(i,m); } + +// ---------------------------------------------------------------------------------------- + + void test_with_rosen() + { + + print_spinner(); + { + matrix ch; + + ch = rosen_start(); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 80), + rosen_residual_double, + rosen_residual_derivative_double, + range(1,20), + ch); + + dlog << LINFO << "rosen obj: " << rosen(ch); + dlog << LINFO << "rosen error: " << length(ch - rosen_solution()); + + DLIB_TEST(length(ch - rosen_solution()) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = rosen_start(); + + solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), + rosen_residual_double, + rosen_residual_derivative_double, + range(1,20), + ch); + + dlog << LINFO << "lm rosen obj: " << rosen(ch); + dlog << LINFO << "lm rosen error: " << length(ch - rosen_solution()); + + DLIB_TEST(length(ch - rosen_solution()) < 1e-5); + + } + + + + print_spinner(); + { + matrix ch; + + ch = rosen_start(); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 80), + rosen_residual_double, + derivative(rosen_residual_double), + range(1,20), + ch); + + dlog << LINFO << "rosen obj: " << rosen(ch); + dlog << LINFO << "rosen error: " << length(ch - rosen_solution()); + + DLIB_TEST(length(ch - rosen_solution()) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = rosen_start(); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 80), + rosen_residual_float, + derivative(rosen_residual_float), + range(1,20), + ch); + + dlog << LINFO << "float rosen obj: " << rosen(ch); + dlog << LINFO << "float rosen error: " << length(ch - rosen_solution()); + + DLIB_TEST(length(ch - rosen_solution()) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = rosen_start(); + + solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), + rosen_residual_float, + derivative(rosen_residual_float), + range(1,20), + ch); + + dlog << LINFO << "LM float rosen obj: " << rosen(ch); + dlog << LINFO << "LM float rosen error: " << length(ch - rosen_solution()); + + DLIB_TEST(length(ch - rosen_solution()) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = rosen_start(); + + solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), + rosen_residual_double, + derivative(rosen_residual_double), + range(1,20), + ch); + + dlog << LINFO << "LM rosen obj: " << rosen(ch); + dlog << LINFO << "LM rosen error: " << length(ch - rosen_solution()); + + DLIB_TEST(length(ch - rosen_solution()) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = rosen_big_start(); + + solve_least_squares(objective_delta_stop_strategy(1e-13, 80), + rosen_big_residual_double, + derivative(rosen_big_residual_double), + range(1,2), + ch); + + dlog << LINFO << "rosen big obj: " << rosen_big(ch); + dlog << LINFO << "rosen big error: " << length(ch - rosen_big_solution()); + + DLIB_TEST(length(ch - rosen_big_solution()) < 1e-5); + + } + } + +// ---------------------------------------------------------------------------------------- + + class optimization_tester : public tester + { + public: + optimization_tester ( + ) : + tester ("test_least_squares", + "Runs tests on the least squares optimization component.") + {} + + void perform_test ( + ) + { + test_with_chebyquad(); + test_with_brown(); + test_with_rosen(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/linear_manifold_regularizer.cpp b/ml/dlib/dlib/test/linear_manifold_regularizer.cpp new file mode 100644 index 000000000..e73b1c8d3 --- /dev/null +++ b/ml/dlib/dlib/test/linear_manifold_regularizer.cpp @@ -0,0 +1,408 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.linear_manifold_regularizer"); + + template + void test_find_k_nearest_neighbors_lsh( + const samples_type& samples + ) + { + std::vector edges1, edges2; + + find_k_nearest_neighbors(samples, cosine_distance(), 2, edges1); + find_k_nearest_neighbors_lsh(samples, cosine_distance(), hash_type(), 2, 6, edges2, 2); + + std::sort(edges1.begin(), edges1.end(), order_by_index); + std::sort(edges2.begin(), edges2.end(), order_by_index); + + DLIB_TEST_MSG(edges1.size() == edges2.size(), edges1.size() << " " << edges2.size()); + for (unsigned long i = 0; i < edges1.size(); ++i) + { + DLIB_TEST(edges1[i] == edges2[i]); + DLIB_TEST_MSG(std::abs(edges1[i].distance() - edges2[i].distance()) < 1e-7, + edges1[i].distance() - edges2[i].distance()); + } + } + + template + void test_knn_lsh_sparse() + { + dlib::rand rnd; + std::vector > samples; + samples.resize(20); + for (unsigned int i = 0; i < samples.size(); ++i) + { + samples[i][0] = rnd.get_random_gaussian(); + samples[i][2] = rnd.get_random_gaussian(); + } + + test_find_k_nearest_neighbors_lsh(samples); + test_find_k_nearest_neighbors_lsh(samples); + test_find_k_nearest_neighbors_lsh(samples); + test_find_k_nearest_neighbors_lsh(samples); + } + + template + void test_knn_lsh_dense() + { + dlib::rand rnd; + std::vector > samples; + samples.resize(20); + for (unsigned int i = 0; i < samples.size(); ++i) + { + samples[i].set_size(2); + samples[i](0) = rnd.get_random_gaussian(); + samples[i](1) = rnd.get_random_gaussian(); + } + + test_find_k_nearest_neighbors_lsh(samples); + test_find_k_nearest_neighbors_lsh(samples); + test_find_k_nearest_neighbors_lsh(samples); + test_find_k_nearest_neighbors_lsh(samples); + } + + + + class linear_manifold_regularizer_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + linear_manifold_regularizer_tester ( + ) : + tester ( + "test_linear_manifold_regularizer", // the command line argument name for this test + "Run tests on the linear_manifold_regularizer object.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + seed = 1; + } + + dlib::rand rnd; + + unsigned long seed; + + typedef matrix sample_type; + typedef radial_basis_kernel kernel_type; + + void do_the_test() + { + print_spinner(); + std::vector samples; + + // Declare an instance of the kernel we will be using. + const kernel_type kern(0.1); + + const unsigned long num_points = 200; + + // create a large dataset with two concentric circles. + generate_circle(samples, 1, num_points); // circle of radius 1 + generate_circle(samples, 5, num_points); // circle of radius 5 + + std::vector edges; + find_percent_shortest_edges_randomly(samples, squared_euclidean_distance(0.1, 4), 1, 10000, "random seed", edges); + + dlog << LTRACE << "number of edges generated: " << edges.size(); + + empirical_kernel_map ekm; + + ekm.load(kern, randomly_subsample(samples, 100)); + + // Project all the samples into the span of our 50 basis samples + for (unsigned long i = 0; i < samples.size(); ++i) + samples[i] = ekm.project(samples[i]); + + + // Now create the manifold regularizer. The result is a transformation matrix that + // embodies the manifold assumption discussed above. + linear_manifold_regularizer lmr; + lmr.build(samples, edges, use_gaussian_weights(0.1)); + matrix T = lmr.get_transformation_matrix(10000); + + print_spinner(); + + // generate the T matrix manually and make sure it matches. The point of this test + // is to make sure that the more complex version of this that happens inside the linear_manifold_regularizer + // is correct. It uses a tedious block of loops to do it in a way that is a lot faster for sparse + // W matrices but isn't super straight forward. + matrix X(samples[0].size(), samples.size()); + for (unsigned long i = 0; i < samples.size(); ++i) + set_colm(X,i) = samples[i]; + + matrix W(samples.size(), samples.size()); + W = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + W(edges[i].index1(), edges[i].index2()) = use_gaussian_weights(0.1)(edges[i]); + W(edges[i].index2(), edges[i].index1()) = use_gaussian_weights(0.1)(edges[i]); + } + matrix L = diagm(sum_rows(W)) - W; + matrix trueT = inv_lower_triangular(chol(identity_matrix(X.nr()) + (10000.0/sum(lowerm(W)))*X*L*trans(X))); + + dlog << LTRACE << "T error: "<< max(abs(T - trueT)); + DLIB_TEST(max(abs(T - trueT)) < 1e-7); + + + print_spinner(); + // Apply the transformation generated by the linear_manifold_regularizer to + // all our samples. + for (unsigned long i = 0; i < samples.size(); ++i) + samples[i] = T*samples[i]; + + + // For convenience, generate a projection_function and merge the transformation + // matrix T into it. + projection_function proj = ekm.get_projection_function(); + proj.weights = T*proj.weights; + + + // Pick 2 different labeled points. One on the inner circle and another on the outer. + // For each of these test points we will see if using the single plane that separates + // them is a good way to separate the concentric circles. Also do this a bunch + // of times with different randomly chosen points so we can see how robust the result is. + for (int itr = 0; itr < 10; ++itr) + { + print_spinner(); + std::vector test_points; + // generate a random point from the radius 1 circle + generate_circle(test_points, 1, 1); + // generate a random point from the radius 5 circle + generate_circle(test_points, 5, 1); + + // project the two test points into kernel space. Recall that this projection_function + // has the manifold regularizer incorporated into it. + const sample_type class1_point = proj(test_points[0]); + const sample_type class2_point = proj(test_points[1]); + + double num_wrong = 0; + + // Now attempt to classify all the data samples according to which point + // they are closest to. The output of this program shows that without manifold + // regularization this test will fail but with it it will perfectly classify + // all the points. + for (unsigned long i = 0; i < samples.size(); ++i) + { + double distance_to_class1 = length(samples[i] - class1_point); + double distance_to_class2 = length(samples[i] - class2_point); + + bool predicted_as_class_1 = (distance_to_class1 < distance_to_class2); + + bool really_is_class_1 = (i < num_points); + + // now count how many times we make a mistake + if (predicted_as_class_1 != really_is_class_1) + ++num_wrong; + } + + DLIB_TEST_MSG(num_wrong == 0, num_wrong); + } + + } + + void generate_circle ( + std::vector& samples, + double radius, + const long num + ) + { + sample_type m(2,1); + + for (long i = 0; i < num; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + samples.push_back(m); + } + } + + + void test_knn1() + { + std::vector > samples; + + matrix test; + + test = 0,0; samples.push_back(test); + test = 1,1; samples.push_back(test); + test = 1,-1; samples.push_back(test); + test = -1,1; samples.push_back(test); + test = -1,-1; samples.push_back(test); + + std::vector edges; + find_k_nearest_neighbors(samples, squared_euclidean_distance(), 1, edges); + DLIB_TEST(edges.size() == 4); + + std::sort(edges.begin(), edges.end(), &order_by_index); + + DLIB_TEST(edges[0] == sample_pair(0,1,0)); + DLIB_TEST(edges[1] == sample_pair(0,2,0)); + DLIB_TEST(edges[2] == sample_pair(0,3,0)); + DLIB_TEST(edges[3] == sample_pair(0,4,0)); + + find_k_nearest_neighbors(samples, squared_euclidean_distance(), 3, edges); + DLIB_TEST(edges.size() == 8); + + find_k_nearest_neighbors(samples, squared_euclidean_distance(3.9, 4.1), 3, edges); + DLIB_TEST(edges.size() == 4); + + std::sort(edges.begin(), edges.end(), &order_by_index); + + DLIB_TEST(edges[0] == sample_pair(1,2,0)); + DLIB_TEST(edges[1] == sample_pair(1,3,0)); + DLIB_TEST(edges[2] == sample_pair(2,4,0)); + DLIB_TEST(edges[3] == sample_pair(3,4,0)); + + find_k_nearest_neighbors(samples, squared_euclidean_distance(30000, 4.1), 3, edges); + DLIB_TEST(edges.size() == 0); + } + + void test_knn1_approx() + { + std::vector > samples; + + matrix test; + + test = 0,0; samples.push_back(test); + test = 1,1; samples.push_back(test); + test = 1,-1; samples.push_back(test); + test = -1,1; samples.push_back(test); + test = -1,-1; samples.push_back(test); + + std::vector edges; + find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 1, 10000, seed, edges); + DLIB_TEST(edges.size() == 4); + + std::sort(edges.begin(), edges.end(), &order_by_index); + + DLIB_TEST(edges[0] == sample_pair(0,1,0)); + DLIB_TEST(edges[1] == sample_pair(0,2,0)); + DLIB_TEST(edges[2] == sample_pair(0,3,0)); + DLIB_TEST(edges[3] == sample_pair(0,4,0)); + + find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 3, 10000, seed, edges); + DLIB_TEST(edges.size() == 8); + + find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(3.9, 4.1), 3, 10000, seed, edges); + DLIB_TEST(edges.size() == 4); + + std::sort(edges.begin(), edges.end(), &order_by_index); + + DLIB_TEST(edges[0] == sample_pair(1,2,0)); + DLIB_TEST(edges[1] == sample_pair(1,3,0)); + DLIB_TEST(edges[2] == sample_pair(2,4,0)); + DLIB_TEST(edges[3] == sample_pair(3,4,0)); + + find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(30000, 4.1), 3, 10000, seed, edges); + DLIB_TEST(edges.size() == 0); + } + + void test_knn2() + { + std::vector > samples; + + matrix test; + + test = 1,1; samples.push_back(test); + test = 1,-1; samples.push_back(test); + test = -1,1; samples.push_back(test); + test = -1,-1; samples.push_back(test); + + std::vector edges; + find_k_nearest_neighbors(samples, squared_euclidean_distance(), 2, edges); + DLIB_TEST(edges.size() == 4); + + std::sort(edges.begin(), edges.end(), &order_by_index); + + DLIB_TEST(edges[0] == sample_pair(0,1,0)); + DLIB_TEST(edges[1] == sample_pair(0,2,0)); + DLIB_TEST(edges[2] == sample_pair(1,3,0)); + DLIB_TEST(edges[3] == sample_pair(2,3,0)); + + find_k_nearest_neighbors(samples, squared_euclidean_distance(), 200, edges); + DLIB_TEST(edges.size() == 4*3/2); + } + + void test_knn2_approx() + { + std::vector > samples; + + matrix test; + + test = 1,1; samples.push_back(test); + test = 1,-1; samples.push_back(test); + test = -1,1; samples.push_back(test); + test = -1,-1; samples.push_back(test); + + std::vector edges; + // For this simple graph and high number of samples we will do we should obtain the exact + // knn solution. + find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 2, 10000, seed, edges); + DLIB_TEST(edges.size() == 4); + + std::sort(edges.begin(), edges.end(), &order_by_index); + + DLIB_TEST(edges[0] == sample_pair(0,1,0)); + DLIB_TEST(edges[1] == sample_pair(0,2,0)); + DLIB_TEST(edges[2] == sample_pair(1,3,0)); + DLIB_TEST(edges[3] == sample_pair(2,3,0)); + + + find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 200, 10000, seed, edges); + DLIB_TEST(edges.size() == 4*3/2); + } + + void perform_test ( + ) + { + for (int i = 0; i < 5; ++i) + { + do_the_test(); + + ++seed; + test_knn1_approx(); + test_knn2_approx(); + } + test_knn1(); + test_knn2(); + test_knn_lsh_sparse(); + test_knn_lsh_sparse(); + test_knn_lsh_dense(); + test_knn_lsh_dense(); + + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + linear_manifold_regularizer_tester a; + +} + + + diff --git a/ml/dlib/dlib/test/lspi.cpp b/ml/dlib/dlib/test/lspi.cpp new file mode 100644 index 000000000..013887115 --- /dev/null +++ b/ml/dlib/dlib/test/lspi.cpp @@ -0,0 +1,258 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.lspi"); + + template + struct chain_model + { + typedef int state_type; + typedef int action_type; // 0 is move left, 1 is move right + const static bool force_last_weight_to_1 = have_prior; + + + const static int num_states = 4; // not required in the model interface + + matrix offset; + chain_model() + { + offset = + 2.048 , + 2.56 , + 2.048 , + 3.2 , + 2.56 , + 4 , + 3.2, + 5 ; + if (!have_prior) + offset = 0; + + } + + unsigned long num_features( + ) const + { + if (have_prior) + return num_states*2 + 1; + else + return num_states*2; + } + + action_type find_best_action ( + const state_type& state, + const matrix& w + ) const + { + if (w(state*2)+offset(state*2) >= w(state*2+1)+offset(state*2+1)) + //if (w(state*2) >= w(state*2+1)) + return 0; + else + return 1; + } + + void get_features ( + const state_type& state, + const action_type& action, + matrix& feats + ) const + { + feats.set_size(num_features()); + feats = 0; + feats(state*2 + action) = 1; + if (have_prior) + feats(num_features()-1) = offset(state*2+action); + } + + }; + + void test_lspi_prior1() + { + print_spinner(); + typedef process_sample > sample_type; + std::vector samples; + + samples.push_back(sample_type(0,0,0,0)); + samples.push_back(sample_type(0,1,1,0)); + + samples.push_back(sample_type(1,0,0,0)); + samples.push_back(sample_type(1,1,2,0)); + + samples.push_back(sample_type(2,0,1,0)); + samples.push_back(sample_type(2,1,3,0)); + + samples.push_back(sample_type(3,0,2,0)); + samples.push_back(sample_type(3,1,3,1)); + + + lspi > trainer; + //trainer.be_verbose(); + trainer.set_lambda(0); + policy > pol = trainer.train(samples); + + dlog << LINFO << pol.get_weights(); + + matrix w = pol.get_weights(); + DLIB_TEST(pol.get_weights().size() == 9); + DLIB_TEST(w(w.size()-1) == 1); + w(w.size()-1) = 0; + DLIB_TEST_MSG(length(w) < 1e-12, length(w)); + + dlog << LINFO << "action: " << pol(0); + dlog << LINFO << "action: " << pol(1); + dlog << LINFO << "action: " << pol(2); + dlog << LINFO << "action: " << pol(3); + DLIB_TEST(pol(0) == 1); + DLIB_TEST(pol(1) == 1); + DLIB_TEST(pol(2) == 1); + DLIB_TEST(pol(3) == 1); + } + + void test_lspi_prior2() + { + print_spinner(); + typedef process_sample > sample_type; + std::vector samples; + + samples.push_back(sample_type(0,0,0,0)); + samples.push_back(sample_type(0,1,1,0)); + + samples.push_back(sample_type(1,0,0,0)); + samples.push_back(sample_type(1,1,2,0)); + + samples.push_back(sample_type(2,0,1,0)); + samples.push_back(sample_type(2,1,3,1)); + + samples.push_back(sample_type(3,0,2,0)); + samples.push_back(sample_type(3,1,3,0)); + + + lspi > trainer; + //trainer.be_verbose(); + trainer.set_lambda(0); + policy > pol = trainer.train(samples); + + + dlog << LINFO << "action: " << pol(0); + dlog << LINFO << "action: " << pol(1); + dlog << LINFO << "action: " << pol(2); + dlog << LINFO << "action: " << pol(3); + DLIB_TEST(pol(0) == 1); + DLIB_TEST(pol(1) == 1); + DLIB_TEST(pol(2) == 1); + DLIB_TEST(pol(3) == 0); + } + + void test_lspi_noprior1() + { + print_spinner(); + typedef process_sample > sample_type; + std::vector samples; + + samples.push_back(sample_type(0,0,0,0)); + samples.push_back(sample_type(0,1,1,0)); + + samples.push_back(sample_type(1,0,0,0)); + samples.push_back(sample_type(1,1,2,0)); + + samples.push_back(sample_type(2,0,1,0)); + samples.push_back(sample_type(2,1,3,0)); + + samples.push_back(sample_type(3,0,2,0)); + samples.push_back(sample_type(3,1,3,1)); + + + lspi > trainer; + //trainer.be_verbose(); + trainer.set_lambda(0.01); + policy > pol = trainer.train(samples); + + dlog << LINFO << pol.get_weights(); + DLIB_TEST(pol.get_weights().size() == 8); + + + dlog << LINFO << "action: " << pol(0); + dlog << LINFO << "action: " << pol(1); + dlog << LINFO << "action: " << pol(2); + dlog << LINFO << "action: " << pol(3); + DLIB_TEST(pol(0) == 1); + DLIB_TEST(pol(1) == 1); + DLIB_TEST(pol(2) == 1); + DLIB_TEST(pol(3) == 1); + } + void test_lspi_noprior2() + { + print_spinner(); + typedef process_sample > sample_type; + std::vector samples; + + samples.push_back(sample_type(0,0,0,0)); + samples.push_back(sample_type(0,1,1,0)); + + samples.push_back(sample_type(1,0,0,0)); + samples.push_back(sample_type(1,1,2,1)); + + samples.push_back(sample_type(2,0,1,0)); + samples.push_back(sample_type(2,1,3,0)); + + samples.push_back(sample_type(3,0,2,0)); + samples.push_back(sample_type(3,1,3,0)); + + + lspi > trainer; + //trainer.be_verbose(); + trainer.set_lambda(0.01); + policy > pol = trainer.train(samples); + + dlog << LINFO << pol.get_weights(); + DLIB_TEST(pol.get_weights().size() == 8); + + + dlog << LINFO << "action: " << pol(0); + dlog << LINFO << "action: " << pol(1); + dlog << LINFO << "action: " << pol(2); + dlog << LINFO << "action: " << pol(3); + DLIB_TEST(pol(0) == 1); + DLIB_TEST(pol(1) == 1); + DLIB_TEST(pol(2) == 0); + DLIB_TEST(pol(3) == 0); + } + + class lspi_tester : public tester + { + public: + lspi_tester ( + ) : + tester ( + "test_lspi", // the command line argument name for this test + "Run tests on the lspi object.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + void perform_test ( + ) + { + test_lspi_prior1(); + test_lspi_prior2(); + + test_lspi_noprior1(); + test_lspi_noprior2(); + } + }; + + lspi_tester a; +} + diff --git a/ml/dlib/dlib/test/lz77_buffer.cpp b/ml/dlib/dlib/test/lz77_buffer.cpp new file mode 100644 index 000000000..ccbb1a24c --- /dev/null +++ b/ml/dlib/dlib/test/lz77_buffer.cpp @@ -0,0 +1,569 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.lz77_buffer"); + + template < + typename buf + > + void lz77_buffer_kernel_test ( + ) + /*! + requires + - buf is an implementation of lz77_buffer/lz77_buffer_kernel_abstract.h + ensures + - runs tests on buf for compliance with the specs + !*/ + { + typedef dlib::sliding_buffer::kernel_1a sbuf; + + buf test(8,20); + srand(static_cast(time(0))); + + DLIB_TEST(test.get_lookahead_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_size() == 0); + DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,test.get_history_buffer_limit()); + DLIB_TEST(test.get_lookahead_buffer_limit() == 20); + + + for (int g = 0; g < 2; ++g) + { + test.clear(); + + for (int i = 0; i < 1000; ++i) + { + test.add('a'); + } + DLIB_TEST(test.get_lookahead_buffer_size() == 20); + + + test.shift_buffers(5); + + DLIB_TEST(test.get_lookahead_buffer_size() == 15); + + + + unsigned long index, length, temp; + temp = test.get_lookahead_buffer_size(); + test.find_match(index,length,5); + + + DLIB_TEST_MSG(length <= temp, + "length: " << length << + "\ntemp: " << temp); + DLIB_TEST(test.get_lookahead_buffer_size() <= 15); + + + } + + + for (int g = 0; g < 2; ++g) + { + + + + test.clear(); + + + + DLIB_TEST(test.get_lookahead_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_limit() == 256-20); + DLIB_TEST(test.get_lookahead_buffer_limit() == 20); + + unsigned long a,b, temp = test.get_lookahead_buffer_size(); + test.find_match(a,b,0); + DLIB_TEST(b <= temp); + DLIB_TEST(b == 0); + + test.find_match(a,b,5); + DLIB_TEST(b == 0); + + DLIB_TEST(test.get_lookahead_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_limit() == 256-20); + DLIB_TEST(test.get_lookahead_buffer_limit() == 20); + + + + ostringstream sout; + sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_size() == 0,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n"; + sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_limit() == 20,);\n"; + sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_size() == 0,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n"; + sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_limit() == 20,);\n"; + istringstream sin(sout.str()); + + sout.str(""); + sout.clear(); + + unsigned char ch; + sbuf sbuffer; + sbuffer.set_size(8); + + + ch = sin.get(); + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch); + DLIB_TEST(test.get_lookahead_buffer_size() == 1); + DLIB_TEST(test.get_history_buffer_size() == 0); + DLIB_TEST(test.get_lookahead_buffer_limit() == 20); + + + + ch = sin.get(); + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch); + DLIB_TEST(test.get_lookahead_buffer_size() == 2); + DLIB_TEST(test.get_history_buffer_size() == 0); + DLIB_TEST(test.get_lookahead_buffer_limit() == 20); + + + + ch = sin.get(); + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch); + DLIB_TEST(test.get_lookahead_buffer_size() == 3); + DLIB_TEST(test.get_history_buffer_size() == 0); + + // add 17 chars to test so that the lookahead buffer will be full + for (int i = 0; i < 17; ++i) + { + ch = sin.get(); + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch); + } + + DLIB_TEST(test.get_lookahead_buffer_size() == 20); + DLIB_TEST(test.get_history_buffer_size() == 0); + DLIB_TEST(test.lookahead_buffer(0) == sbuffer[20]); + + + ch = sin.get(); + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch); + DLIB_TEST(test.get_lookahead_buffer_size() == 20); + DLIB_TEST(test.get_history_buffer_size() == 1); + + + + + + + // add the above text to test and make sure it gives the correct results + ch = sin.get(); + while (sin) + { + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + DLIB_TEST(test.lookahead_buffer(test.get_lookahead_buffer_size()-1) == ch); + DLIB_TEST(test.history_buffer(0) == sbuffer[21]); + DLIB_TEST(test.history_buffer(1) == sbuffer[22]); + + ch = sin.get(); + } + + + + // make sure the contents of lookahead_buffer and history_buffer + // match what is in sbuffer + sbuffer.rotate_right(1); + for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i)); + } + for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i)); + } + sbuffer.rotate_left(1); + + + + + + + + + + + sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0] + + unsigned long match_index, match_length; + unsigned long ltemp = test.get_lookahead_buffer_size(); + test.find_match(match_index,match_length,0); + DLIB_TEST(match_length <= ltemp); + + + // verify the match with sbuffer + for (unsigned int i = 0; i < match_length; ++i) + { + DLIB_TEST_MSG(sbuffer[19-i] == sbuffer[match_index+20-i],i); + } + + + sin.str(""); + sin.clear(); + + } // for (int g = 0; g < 2; ++g) + + + for (int g = 0; g < 8; ++g) + { + test.clear(); + + + DLIB_TEST(test.get_lookahead_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_limit() == 256-20); + DLIB_TEST(test.get_lookahead_buffer_limit() == 20); + + + sbuf sbuffer; + sbuffer.set_size(8); + + ostringstream sout; + sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_size() == 0,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n"; + sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_limit() == 20,);\n"; + sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n"; + sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_limit() == 20,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n"; + sout << "DLIB_TEST_MSG(test.get_lookahead_buffer_size() == 0,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n"; + sout << "DLIB_TEST_MSG(test.get_history_buffer_limit() == 256-20,);\n"; + istringstream sin(sout.str()); + + unsigned char ch; + for (int i = 0; i < 100; ++i) + { + ch = sin.get(); + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + } + + // make sure the contents of lookahead_buffer and history_buffer + // match what is in sbuffer + sbuffer.rotate_right(1); + for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i)); + } + for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i)); + } + sbuffer.rotate_left(1); + + + + + unsigned long match_index, match_length; + unsigned long ltemp = test.get_lookahead_buffer_size(); + test.find_match(match_index,match_length,0); + DLIB_TEST(match_length <= ltemp); + + DLIB_TEST(test.get_lookahead_buffer_size() == 20-match_length); + + sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0] + // verify the match with sbuffer + for (unsigned int i = 0; i < match_length; ++i) + { + DLIB_TEST(sbuffer[i+20-match_length] == sbuffer[i+1+match_index+20-match_length]); + } + sbuffer.rotate_left(1); // free up sbuffer[0] for new data + + + + + for (int i = 0; i < 7+g*2; ++i) + { + ch = sin.get(); + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + } + + ch = '?'; + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + ch = 'a'; + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + ch = 'v'; + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + ch = 'i'; + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + ch = 's'; + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + + + // adjust sbuffer due to the last call to test.find_match() + // but only if we haven't already added enough (20 or more) chars + // to fill the lookahead buffer already. + if (match_length > static_cast(12+g*2)) + sbuffer.rotate_left(match_length-(12+g*2)); + + + + + + // make sure the contents of lookahead_buffer and history_buffer + // match what is in sbuffer + sbuffer.rotate_right(1); + for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i)); + } + for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i)); + } + sbuffer.rotate_left(1); + + + + + test.find_match(match_index,match_length,10+g); + + if (match_length > 0) + DLIB_TEST(match_length >= static_cast(10+g) ); + + + sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0] + // verify the match with sbuffer + for (unsigned int i = 0; i < match_length; ++i) + { + DLIB_TEST(sbuffer[i+20-match_length] == sbuffer[i+1+match_index+20-match_length]); + } + sbuffer.rotate_left(1); // free up sbuffer[0] for new data + + } // for (int g = 0; g < 8; ++g) + + + + + + + + srand(static_cast(time(0))); + + for (int g = 0; g < 200; ++g) + { + test.clear(); + + DLIB_TEST(test.get_lookahead_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_limit() == 256-20); + DLIB_TEST(test.get_lookahead_buffer_limit() == 20); + + + sbuf sbuffer; + sbuffer.set_size(8); + + ostringstream sout; + int l = ::rand()%500; + for (int i = 0; i < l; ++i) + { + char temp = static_cast(::rand()%256); + sout << temp; + } + istringstream sin(sout.str()); + + unsigned char ch; + for (int i = 0; i < l; ++i) + { + ch = sin.get(); + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + } + + // make sure the contents of lookahead_buffer and history_buffer + // match what is in sbuffer + sbuffer.rotate_right(1); + + // adjust so that sbuffer[19] is the same as lookahead_buffer[0] + if (test.get_lookahead_buffer_size() < 20) + sbuffer.rotate_left(20-test.get_lookahead_buffer_size()); + + for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i)); + } + for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i)); + } + sbuffer.rotate_left(1); + + + + unsigned long match_index, match_length; + unsigned long lookahead_size_before = test.get_lookahead_buffer_size(); + test.find_match(match_index,match_length,0); + DLIB_TEST(match_length <= lookahead_size_before); + + + DLIB_TEST(test.get_lookahead_buffer_size() == lookahead_size_before-match_length); + + sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0] + // verify the match with sbuffer + for (unsigned int i = 0; i < match_length; ++i) + { + DLIB_TEST_MSG(sbuffer[19-i] == sbuffer[match_index+20-i],i); + } + sbuffer.rotate_left(1); // free up sbuffer[0] for new data + + } // for (int g = 0; g < 200; ++g) + + + + + + + + + srand(static_cast(time(0))); + + for (int g = 0; g < 300; ++g) + { + test.clear(); + + DLIB_TEST(test.get_lookahead_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_size() == 0); + DLIB_TEST(test.get_history_buffer_limit() == 256-20); + DLIB_TEST(test.get_lookahead_buffer_limit() == 20); + + + sbuf sbuffer; + sbuffer.set_size(8); + + ostringstream sout; + int l = ::rand()%500; + for (int i = 0; i < l; ++i) + { + char temp = static_cast(::rand()%20); + sout << temp; + sout << temp; + sout << temp; + sout << temp; + sout << temp; + sout << temp; + } + istringstream sin(sout.str()); + + unsigned char ch; + for (int i = 0; i < l; ++i) + { + ch = sin.get(); + sbuffer[0] = ch; sbuffer.rotate_left(1); + test.add(ch); + } + + // make sure the contents of lookahead_buffer and history_buffer + // match what is in sbuffer + sbuffer.rotate_right(1); + + // adjust so that sbuffer[19] is the same as lookahead_buffer[0] + if (test.get_lookahead_buffer_size() < 20) + sbuffer.rotate_left(20-test.get_lookahead_buffer_size()); + + for (unsigned int i = 0; i < test.get_history_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()+i] == test.history_buffer(i)); + } + for (unsigned int i = 0; i < test.get_lookahead_buffer_size(); ++i) + { + DLIB_TEST(sbuffer[test.get_lookahead_buffer_limit()-1-i] == test.lookahead_buffer(i)); + } + sbuffer.rotate_left(1); + + + + unsigned long match_index = 0, match_length = 0; + unsigned long lookahead_size_before = test.get_lookahead_buffer_size(); + unsigned long history_size_before = test.get_history_buffer_size(); + test.find_match(match_index,match_length,2); + + if (match_length != 0) + { + DLIB_TEST_MSG(match_index < history_size_before, + "match_index: " << match_index << + "\nhistory_size_before: " << history_size_before); + + } + + + DLIB_TEST(test.get_lookahead_buffer_size() == lookahead_size_before-match_length); + + sbuffer.rotate_right(1); // do this because we never put anything in sbuffer[0] + // verify the match with sbuffer + for (unsigned int i = 0; i < match_length; ++i) + { + DLIB_TEST_MSG(sbuffer[19-i] == sbuffer[match_index+20-i],i); + } + sbuffer.rotate_left(1); // free up sbuffer[0] for new data + + + + } // for (int g = 0; g < 300; ++g) + + } + + + + + class lz77_buffer_tester : public tester + { + public: + lz77_buffer_tester ( + ) : + tester ("test_lz77_buffer", + "Runs tests on the lz77_buffer component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + lz77_buffer_kernel_test (); + dlog << LINFO << "testing kernel_1a_c"; + lz77_buffer_kernel_test(); + dlog << LINFO << "testing kernel_2a"; + lz77_buffer_kernel_test (); + dlog << LINFO << "testing kernel_2a_c"; + lz77_buffer_kernel_test(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/main.cpp b/ml/dlib/dlib/test/main.cpp new file mode 100644 index 000000000..4800a7211 --- /dev/null +++ b/ml/dlib/dlib/test/main.cpp @@ -0,0 +1,217 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include "tester.h" +#include + + +using namespace std; +using namespace dlib; +using namespace test; + +typedef cmd_line_parser::check_1a_c clp; + +static logger dlog("test.main"); + +int main (int argc, char** argv) +{ + try + { + clp parser; + + parser.add_option("runall","Run all the tests that don't take any arguments."); + parser.add_option("h","Displays this information."); + parser.add_option("n","How many times to run the selected tests. The default is 1.",1); + parser.add_option("d","log debugging statements to file debug.txt."); + parser.add_option("l","Set the logging level (all, trace, debug, info, warn, error, or fatal), the default is all.",1); + parser.add_option("a","Append debugging messages to debug.txt rather than clearing the file at program startup."); + parser.add_option("q","Be quiet. Don't print the testing progress or results to standard out."); + + unsigned long num = 1; + + // add the options for all the different tests + testers().reset(); + while (testers().move_next()) + { + tester& test = *testers().element().value(); + parser.add_option(test.cmd_line_switch(), test.description(), test.num_of_args()); + } + + parser.parse(argc,argv); + + parser.check_option_arg_range("n",1,1000000000); + const char* singles[] = {"d","l","a","n","h","runall","q"}; + parser.check_one_time_options(singles); + const char* d_sub[] = {"l","a"}; + const char* l_args[] = {"all", "trace", "debug", "info", "warn", "error", "fatal"}; + parser.check_sub_options("d",d_sub); + parser.check_option_arg_range("l",l_args); + + + if (parser.option("n")) + { + num = string_cast(parser.option("n").argument()); + } + + if (parser.option("q")) + { + be_verbose = false; + } + + if (parser.option("h")) + { + cout << "Usage: test [options]\n"; + parser.print_options(cout); + cout << "\n\n"; + return 0; + } + + ofstream fout; + if (parser.option("d")) + { + if (parser.option("a")) + fout.open("debug.txt",ios::app); + else + fout.open("debug.txt"); + + set_all_logging_output_streams(fout); + + if (parser.option("l").count() == 0) + set_all_logging_levels(LALL); + else if (parser.option("l").argument() == "all") + set_all_logging_levels(LALL); + else if (parser.option("l").argument() == "trace") + set_all_logging_levels(LTRACE); + else if (parser.option("l").argument() == "debug") + set_all_logging_levels(LDEBUG); + else if (parser.option("l").argument() == "info") + set_all_logging_levels(LINFO); + else if (parser.option("l").argument() == "warn") + set_all_logging_levels(LWARN); + else if (parser.option("l").argument() == "error") + set_all_logging_levels(LERROR); + else if (parser.option("l").argument() == "fatal") + set_all_logging_levels(LFATAL); + } + else + { + set_all_logging_levels(LNONE); + } + + unsigned long num_of_failed_tests = 0; + unsigned long num_of_passed_tests = 0; + for (unsigned long i = 0; i < num; ++i) + { + dlog << LINFO << "************ Starting Test Run " << i+1 << " of " << num << ". ************"; + + // loop over all the testers and see if they are supposed to run + testers().reset(); + while (testers().move_next()) + { + tester& test= *testers().element().value(); + const clp::option_type& opt = parser.option(test.cmd_line_switch()); + // run the test for this option as many times as the user has requested. + for (unsigned long j = 0; j < parser.option("runall").count() + opt.count(); ++j) + { + // quit this loop if this option has arguments and this round through the loop is + // from the runall option being present. + if (test.num_of_args() > 0 && j == opt.count()) + break; + + if (be_verbose) + cout << "Running " << test.cmd_line_switch() << " " << flush; + + dlog << LINFO << "Running " << test.cmd_line_switch(); + try + { + switch (test.num_of_args()) + { + case 0: + test.perform_test(); + break; + case 1: + test.perform_test(opt.argument(0,j)); + break; + case 2: + test.perform_test(opt.argument(0,j), opt.argument(1,j)); + break; + default: + cerr << "\n\nThe test '" << test.cmd_line_switch() << "' requested " << test.num_of_args() + << " arguments but only 2 are supported." << endl; + dlog << LINFO << "The test '" << test.cmd_line_switch() << "' requested " << test.num_of_args() + << " arguments but only 2 are supported."; + break; + } + if (be_verbose) + cout << "\r \r"; + + ++num_of_passed_tests; + + } + catch (std::exception& e) + { + if (be_verbose) + { + cout << "\n\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"; + cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TEST FAILED: " << test.cmd_line_switch() + << " !!!!!!!!!!!!!!!!!!!!!!!!!!!!!"; + cout << "\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n"; + cout << "Failure message from test: " << e.what() << endl; + } + + + dlog << LERROR << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"; + dlog << LERROR << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TEST FAILED: " << test.cmd_line_switch() + << " !!!!!!!!!!!!!!!!!!!!!!!!!!!!!"; + dlog << LERROR << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"; + dlog << LERROR << "Failure message from test: " << e.what(); + ++num_of_failed_tests; + } + } + } + } + dlog << LINFO << "Testing Finished"; + if (num_of_passed_tests == 0 && num_of_failed_tests == 0) + { + cout << "You didn't select any tests to run.\n"; + cout << "Try the -h option for more information.\n"; + } + else if (num_of_failed_tests == 0) + { + if (be_verbose) + { + cout << "\n\nTesting Finished\n"; + cout << "Total number of individual testing statements executed: "<< number_of_testing_statements_executed() << endl; + cout << "All tests completed successfully\n\n"; + } + dlog << LINFO << "Total number of individual testing statements executed: "<< number_of_testing_statements_executed(); + dlog << LINFO << "All tests completed successfully"; + } + else + { + if (be_verbose) + { + cout << "\n\nTesting Finished\n"; + cout << "Total number of individual testing statements executed: "<< number_of_testing_statements_executed() << endl; + cout << "Number of failed tests: " << num_of_failed_tests << "\n"; + cout << "Number of passed tests: " << num_of_passed_tests << "\n\n"; + } + dlog << LINFO << "Total number of individual testing statements executed: "<< number_of_testing_statements_executed(); + dlog << LWARN << "Number of failed tests: " << num_of_failed_tests; + dlog << LWARN << "Number of passed tests: " << num_of_passed_tests; + } + + return num_of_failed_tests; + } + catch (exception& e) + { + cout << e.what() << endl; + cout << "\nTry the -h option for more information.\n"; + cout << endl; + } +} + diff --git a/ml/dlib/dlib/test/makefile b/ml/dlib/dlib/test/makefile new file mode 100644 index 000000000..d4d478705 --- /dev/null +++ b/ml/dlib/dlib/test/makefile @@ -0,0 +1,185 @@ +# This is the makefile used to build the dlib C++ library's regression test suite +# on Debian Linux using the gcc compiler. + +# this is the name of the output executable +TARGET = dtest + +# these are the compile time flags passed to gcc +CXXFLAGS ?= -ggdb -Wall +CPPFLAGS ?= -std=c++11 -DDEBUG -DDLIB_NO_GUI_SUPPORT -I../.. + +# These are the link time flags passed to gcc +LFLAGS = -lpthread -lnsl + +# The name of the compiler. If you only have one version of +# gcc installed then you probably want to change this to just g++ +CXX ?= nice g++ + +#################################################### +#################################################### +# Here we list all the cpp files we want to compile + +SRC = main.cpp +SRC += tester.cpp +SRC += ../all/source.cpp + +SRC += example.cpp +SRC += example_args.cpp + +SRC += active_learning.cpp +SRC += any.cpp +SRC += any_function.cpp +SRC += array2d.cpp +SRC += array.cpp +SRC += assignment_learning.cpp +SRC += base64.cpp +SRC += bayes_nets.cpp +SRC += bigint.cpp +SRC += binary_search_tree_kernel_1a.cpp +SRC += binary_search_tree_kernel_2a.cpp +SRC += binary_search_tree_mm1.cpp +SRC += binary_search_tree_mm2.cpp +SRC += bridge.cpp +SRC += bsp.cpp +SRC += byte_orderer.cpp +SRC += cca.cpp +SRC += clustering.cpp +SRC += cmd_line_parser.cpp +SRC += cmd_line_parser_wchar_t.cpp +SRC += compress_stream.cpp +SRC += conditioning_class_c.cpp +SRC += conditioning_class.cpp +SRC += config_reader.cpp +SRC += crc32.cpp +SRC += create_iris_datafile.cpp +SRC += data_io.cpp +SRC += directed_graph.cpp +SRC += discriminant_pca.cpp +SRC += disjoint_subsets.cpp +SRC += ekm_and_lisf.cpp +SRC += empirical_kernel_map.cpp +SRC += entropy_coder.cpp +SRC += entropy_encoder_model.cpp +SRC += face.cpp +SRC += fft.cpp +SRC += fhog.cpp +SRC += filtering.cpp +SRC += find_max_factor_graph_nmplp.cpp +SRC += find_max_factor_graph_viterbi.cpp +SRC += geometry.cpp +SRC += graph.cpp +SRC += graph_cuts.cpp +SRC += graph_labeler.cpp +SRC += hash.cpp +SRC += hash_map.cpp +SRC += hash_set.cpp +SRC += hash_table.cpp +SRC += hog_image.cpp +SRC += image.cpp +SRC += iosockstream.cpp +SRC += is_same_object.cpp +SRC += kcentroid.cpp +SRC += kernel_matrix.cpp +SRC += kmeans.cpp +SRC += learning_to_track.cpp +SRC += least_squares.cpp +SRC += linear_manifold_regularizer.cpp +SRC += lspi.cpp +SRC += lz77_buffer.cpp +SRC += map.cpp +SRC += matrix2.cpp +SRC += matrix3.cpp +SRC += matrix4.cpp +SRC += matrix_chol.cpp +SRC += matrix.cpp +SRC += matrix_eig.cpp +SRC += matrix_lu.cpp +SRC += matrix_qr.cpp +SRC += max_cost_assignment.cpp +SRC += max_sum_submatrix.cpp +SRC += md5.cpp +SRC += member_function_pointer.cpp +SRC += metaprogramming.cpp +SRC += mpc.cpp +SRC += multithreaded_object.cpp +SRC += numerical_integration.cpp +SRC += object_detector.cpp +SRC += oca.cpp +SRC += one_vs_all_trainer.cpp +SRC += one_vs_one_trainer.cpp +SRC += optimization.cpp +SRC += optimization_test_functions.cpp +SRC += opt_qp_solver.cpp +SRC += parallel_for.cpp +SRC += parse.cpp +SRC += pipe.cpp +SRC += pixel.cpp +SRC += probabilistic.cpp +SRC += pyramid_down.cpp +SRC += queue.cpp +SRC += rand.cpp +SRC += ranking.cpp +SRC += read_write_mutex.cpp +SRC += reference_counter.cpp +SRC += rls.cpp +SRC += sammon.cpp +SRC += scan_image.cpp +SRC += sequence.cpp +SRC += sequence_labeler.cpp +SRC += sequence_segmenter.cpp +SRC += serialize.cpp +SRC += set.cpp +SRC += sldf.cpp +SRC += sliding_buffer.cpp +SRC += sockets2.cpp +SRC += sockets.cpp +SRC += sockstreambuf.cpp +SRC += sparse_vector.cpp +SRC += stack.cpp +SRC += static_map.cpp +SRC += static_set.cpp +SRC += statistics.cpp +SRC += std_vector_c.cpp +SRC += string.cpp +SRC += svm_c_linear.cpp +SRC += svm_c_linear_dcd.cpp +SRC += svm.cpp +SRC += svm_multiclass_linear.cpp +SRC += svm_struct.cpp +SRC += svr_linear_trainer.cpp +SRC += symmetric_matrix_cache.cpp +SRC += thread_pool.cpp +SRC += threads.cpp +SRC += timer.cpp +SRC += tokenizer.cpp +SRC += trust_region.cpp +SRC += tuple.cpp +SRC += type_safe_union.cpp +SRC += vectorstream.cpp + + +#################################################### + +TMP = $(SRC:.cpp=.o) +OBJ = $(TMP:.c=.o) + +$(TARGET): $(OBJ) + @echo Linking $@ + @$(CXX) $(LDFLAGS) $(OBJ) $(LFLAGS) -o $@ + @echo Build Complete + +clean: + @rm -f $(OBJ) $(TARGET) + @echo All object files and binaries removed + +dep: + @echo Running makedepend + @makedepend -- $(CFLAGS) -- $(SRC) 2> /dev/null + @echo Completed makedepend + +############################################################################### +########## Stuff from makedepend ##### +########## type make dep at the command line to rebuild the dependencies ##### +########## Also, DON'T edit the contents of this file beyond this line. ##### +############################################################################### + diff --git a/ml/dlib/dlib/test/map.cpp b/ml/dlib/dlib/test/map.cpp new file mode 100644 index 000000000..6901ddf05 --- /dev/null +++ b/ml/dlib/dlib/test/map.cpp @@ -0,0 +1,441 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.map"); + + template < + typename map + > + void map_kernel_test ( + ) + /*! + requires + - map is an implementation of map/map_kernel_abstract.h and + is instantiated to map int to int + ensures + - runs tests on map for compliance with the specs + !*/ + { + + print_spinner(); + + srand(static_cast(time(0))); + + + + map test, test2; + + enumerable >& e = test; + DLIB_TEST(e.at_start() == true); + + for (int j = 0; j < 4; ++j) + { + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + + + int a,b; + a = 8; + b = 94; + test.add(a,b); + DLIB_TEST(test.size() == 1); + DLIB_TEST(test.is_in_domain(8) == true); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + DLIB_TEST(test[8] == 94); + a = 53; + b = 4; + test.add(a,b); + DLIB_TEST(test.size() == 2); + DLIB_TEST(test.is_in_domain(53) == true); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + DLIB_TEST(test[53] == 4); + + + swap(test,test2); + + + DLIB_TEST(test2.size() == 2); + DLIB_TEST(test2.is_in_domain(8) == true); + DLIB_TEST(test2.is_in_domain(5) == false); + DLIB_TEST(test2.is_in_domain(0) == false); + DLIB_TEST(test2.is_in_domain(-999) == false); + DLIB_TEST(test2.is_in_domain(4999) == false); + DLIB_TEST(test2[8] == 94); + DLIB_TEST(test2.size() == 2); + DLIB_TEST(test2.is_in_domain(53) == true); + DLIB_TEST(test2.is_in_domain(5) == false); + DLIB_TEST(test2.is_in_domain(0) == false); + DLIB_TEST(test2.is_in_domain(-999) == false); + DLIB_TEST(test2.is_in_domain(4999) == false); + DLIB_TEST(test2[53] == 4); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_in_domain(8) == false); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_in_domain(53) == false); + DLIB_TEST(test.is_in_domain(5) == false); + DLIB_TEST(test.is_in_domain(0) == false); + DLIB_TEST(test.is_in_domain(-999) == false); + DLIB_TEST(test.is_in_domain(4999) == false); + + + test.clear(); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + + + DLIB_TEST(test.size() == 0); + + while (test.size() < 10000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + DLIB_TEST(test.size() == 10000); + test.clear(); + DLIB_TEST(test.size() == 0); + + while (test.size() < 10000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + DLIB_TEST(test.size() == 10000); + + int count = 0; + a = -1; + while (test.move_next()) + { + DLIB_TEST(test.element().key() == test.element().key()); + DLIB_TEST(test.element().value() == test.element().value()); + DLIB_TEST(test.element().key() == test.element().key()); + DLIB_TEST(test.element().value() == test.element().value()); + + + DLIB_TEST(a < test.element().key()); + a = test.element().key(); + ++count; + } + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.move_next() == false); + + DLIB_TEST(count == 10000); + + test.swap(test2); + + DLIB_TEST(test.size() == 2); + DLIB_TEST(test2.size() == 10000); + count = 0; + a = -1; + test2.reset(); + + test2.move_next(); + test2.element().value() = 99; + DLIB_TEST(test2[test2.element().key()] == 99); + DLIB_TEST(test2.element().value() == 99); + + test2.reset(); + + while (test2.move_next()) + { + DLIB_TEST(test2[test2.element().key()] == test2.element().value()); + DLIB_TEST(test2.element().key() == test2.element().key()); + DLIB_TEST(test2.element().value() == test2.element().value()); + DLIB_TEST(test2.element().key() == test2.element().key()); + DLIB_TEST(test2.element().value() == test2.element().value()); + DLIB_TEST(a < test2.element().key()); + a = test2.element().key(); + ++count; + } + DLIB_TEST(test2.size() == 10000); + DLIB_TEST(count == 10000); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.move_next() == false); + + + + test2.clear(); + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.at_start() == true); + + while (test.size() < 20000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + // serialize the state of test, then clear test, then + // load the state back into test. + ostringstream sout; + serialize(test,sout); + istringstream sin(sout.str()); + test.clear(); + deserialize(test,sin); + + DLIB_TEST(test.at_start() == true); + + { + int* array1 = new int[test.size()]; + int* array2 = new int[test.size()]; + + int* tmp1 = array1; + int* tmp2 = array2; + + count = 0; + while (test.move_next()) + { + DLIB_TEST(test.element().key() == test.element().key()); + DLIB_TEST(test.element().value() == test.element().value()); + DLIB_TEST(test.element().key() == test.element().key()); + DLIB_TEST(test.current_element_valid() == true); + *tmp1 = test.element().key(); + *tmp2 = test.element().value(); + ++tmp1; + ++tmp2; + ++count; + } + DLIB_TEST(count == 20000); + + tmp1 = array1; + tmp2 = array2; + for (int i = 0; i < 20000; ++i) + { + DLIB_TEST(test.is_in_domain(*tmp1) == true); + DLIB_TEST(test[*tmp1] == *tmp2); + ++tmp1; + ++tmp2; + } + + DLIB_TEST(test.size() == 20000); + + tmp1 = array1; + tmp2 = array2; + count = 0; + while (test.size() > 10000) + { + test.remove(*tmp1,a,b); + DLIB_TEST(*tmp1 == a); + DLIB_TEST(*tmp2 == b); + ++tmp1; + ++tmp2; + ++count; + } + DLIB_TEST(count == 10000); + DLIB_TEST(test.size() == 10000); + + while (test.move_next()) + { + DLIB_TEST(test.element().key() == *tmp1); + DLIB_TEST(test.element().key() == *tmp1); + DLIB_TEST(test.element().key() == *tmp1); + DLIB_TEST(test.element().value() == *tmp2); + DLIB_TEST(test.element().value() == *tmp2); + DLIB_TEST(test.element().value() == *tmp2); + ++tmp1; + ++tmp2; + ++count; + } + DLIB_TEST(count == 20000); + DLIB_TEST(test.size() == 10000); + + while (test.size() < 20000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + test2.swap(test); + + count = 0; + a = -1; + while (test2.move_next()) + { + DLIB_TEST(test2.element().key() == test2.element().key()); + DLIB_TEST(test2.element().value() == test2.element().value()); + DLIB_TEST(test2.element().key() == test2.element().key()); + DLIB_TEST(a < test2.element().key()); + a = test2.element().key(); + ++count; + } + + DLIB_TEST(count == 20000); + DLIB_TEST(test2.size() == 20000); + + a = -1; + int c = 0; + while (test2.size()>0) + { + test2.remove_any(b,c); + DLIB_TEST( a < b); + a = b; + } + + DLIB_TEST(test2.size() == 0); + delete [] array1; + delete [] array2; + } + + test.clear(); + test2.clear(); + while (test.size() < 10000) + { + a = ::rand(); + b = ::rand(); + if (!test.is_in_domain(a)) + test.add(a,b); + } + + count = 0; + a = -1; + while (test.move_next()) + { + DLIB_TEST(a < test.element().key()); + DLIB_TEST(test[test.element().key()] == test.element().value()); + a = test.element().key(); + ++count; + if (count == 5000) + break; + DLIB_TEST(test.current_element_valid() == true); + } + + test.reset(); + + count = 0; + a = -1; + while (test.move_next()) + { + DLIB_TEST(a < test.element().key()); + a = test.element().key(); + ++count; + DLIB_TEST(test.current_element_valid() == true); + } + + DLIB_TEST(count == 10000); + + + test.clear(); + test2.clear(); + } + + + { + test.clear(); + DLIB_TEST(test.size() == 0); + int a = 5; + int b = 6; + test.add(a,b); + a = 7; + b = 8; + test.add(a,b); + DLIB_TEST(test.size() == 2); + DLIB_TEST(test[7] == 8); + DLIB_TEST(test[5] == 6); + DLIB_TEST(test.is_in_domain(7)); + DLIB_TEST(test.is_in_domain(5)); + test.destroy(7); + DLIB_TEST(test.size() == 1); + DLIB_TEST(!test.is_in_domain(7)); + DLIB_TEST(test.is_in_domain(5)); + test.destroy(5); + DLIB_TEST(test.size() == 0); + DLIB_TEST(!test.is_in_domain(7)); + DLIB_TEST(!test.is_in_domain(5)); + } + + } + + + + + class map_tester : public tester + { + public: + map_tester ( + ) : + tester ("test_map", + "Runs tests on the map component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + map_kernel_test::kernel_1a> (); + dlog << LINFO << "testing kernel_1a_c"; + map_kernel_test::kernel_1a_c>(); + dlog << LINFO << "testing kernel_1b"; + map_kernel_test::kernel_1b> (); + dlog << LINFO << "testing kernel_1b_c"; + map_kernel_test::kernel_1b_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/matrix.cpp b/ml/dlib/dlib/test/matrix.cpp new file mode 100644 index 000000000..0a3ea5996 --- /dev/null +++ b/ml/dlib/dlib/test/matrix.cpp @@ -0,0 +1,1519 @@ + +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" + +#include "tester.h" +#include +#include + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + dlib::rand rnd; + + logger dlog("test.matrix"); + + template + const matrix rand_sp_banded(long n, long bw) + { + matrix m = 10 * identity_matrix(n); + for (long row = 0; row < m.nr(); ++row) + { + for (long col = row; col < min(m.nc(), row + bw); ++col) + { + type r = rnd.get_random_double(); + m(row,col) += r; + m(col,row) += r; + } + } + + return m; + } + + void matrix_test ( + ) + /*! + ensures + - runs tests on the matrix stuff compliance with the specs + !*/ + { + typedef memory_manager_stateless::kernel_2_2a MM; + print_spinner(); + + + { + matrix,2,2,MM> m; + set_all_elements(m,complex(1,2)); + DLIB_TEST((conj(m) == uniform_matrix,2,2>(conj(m(0,0))))); + DLIB_TEST((real(m) == uniform_matrix(1))); + DLIB_TEST((imag(m) == uniform_matrix(2))); + DLIB_TEST_MSG((sum(abs(norm(m) - uniform_matrix(5))) < 1e-10 ),norm(m)); + + } + + { + matrix m(5,5); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + m = cos(exp(m)); + + + matrix mi = pinv(m ); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); + + mi = pinv(m,1e-12); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); + + m = diagm(diag(m)); + mi = pinv(diagm(diag(m)),1e-12); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); + + mi = pinv(m,0); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); + + m = diagm(diag(m)); + mi = pinv(diagm(diag(m)),0); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); + DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); + } + { + matrix m(5,5); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + m = cos(exp(m)); + + + matrix mi = pinv(m ); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + } + + { + matrix m(5,5); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + m = cos(exp(m)); + + + matrix mi = pinv(m ); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + } + + + { + matrix m(5,5); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + m = cos(exp(m)); + + + matrix mi = pinv(m ); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + } + + { + matrix m; + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + m = cos(exp(m)); + + + matrix mi = pinv(m ); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + } + + { + matrix m(5,2); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + m = cos(exp(m)); + + + matrix mi = pinv(m ); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + } + + { + matrix m; + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + m = cos(exp(m)); + + + matrix mi = trans(pinv(trans(m) )); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + } + + { + matrix m(5,2); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + m = cos(exp(m)); + + + matrix mi = trans(pinv(trans(m) )); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix()))); + } + + { + matrix a1(5,1); + matrix a2(1,5); + matrix b1(5,1); + matrix b2(1,5); + matrix c1(5,1); + matrix c2(1,5); + matrix d1(5,1); + matrix d2(1,5); + + for (long i = 0; i < 5; ++i) + { + a1(i) = i; + a2(i) = i; + b1(i) = i; + b2(i) = i; + c1(i) = i; + c2(i) = i; + d1(i) = i; + d2(i) = i; + } + + DLIB_TEST(a1 == trans(a2)); + DLIB_TEST(a1 == trans(b2)); + DLIB_TEST(a1 == trans(c2)); + DLIB_TEST(a1 == trans(d2)); + + DLIB_TEST(a1 == b1); + DLIB_TEST(a1 == c1); + DLIB_TEST(a1 == d1); + + DLIB_TEST(trans(a1) == c2); + DLIB_TEST(trans(b1) == c2); + DLIB_TEST(trans(c1) == c2); + DLIB_TEST(trans(d1) == c2); + + DLIB_TEST(sum(a1) == 10); + DLIB_TEST(sum(a2) == 10); + DLIB_TEST(sum(b1) == 10); + DLIB_TEST(sum(b2) == 10); + DLIB_TEST(sum(c1) == 10); + DLIB_TEST(sum(c2) == 10); + DLIB_TEST(sum(d1) == 10); + DLIB_TEST(sum(d2) == 10); + + const matrix orig1 = a1; + const matrix orig2 = a2; + + ostringstream sout; + serialize(a1,sout); + serialize(a2,sout); + serialize(b1,sout); + serialize(b2,sout); + serialize(c1,sout); + serialize(c2,sout); + serialize(d1,sout); + serialize(d2,sout); + + DLIB_TEST(a1 == orig1); + DLIB_TEST(a2 == orig2); + DLIB_TEST(b1 == orig1); + DLIB_TEST(b2 == orig2); + DLIB_TEST(c1 == orig1); + DLIB_TEST(c2 == orig2); + DLIB_TEST(d1 == orig1); + DLIB_TEST(d2 == orig2); + + set_all_elements(a1,99); + set_all_elements(a2,99); + set_all_elements(b1,99); + set_all_elements(b2,99); + set_all_elements(c1,99); + set_all_elements(c2,99); + set_all_elements(d1,99); + set_all_elements(d2,99); + + DLIB_TEST(a1 != orig1); + DLIB_TEST(a2 != orig2); + DLIB_TEST(b1 != orig1); + DLIB_TEST(b2 != orig2); + DLIB_TEST(c1 != orig1); + DLIB_TEST(c2 != orig2); + DLIB_TEST(d1 != orig1); + DLIB_TEST(d2 != orig2); + + istringstream sin(sout.str()); + + deserialize(a1,sin); + deserialize(a2,sin); + deserialize(b1,sin); + deserialize(b2,sin); + deserialize(c1,sin); + deserialize(c2,sin); + deserialize(d1,sin); + deserialize(d2,sin); + + DLIB_TEST(a1 == orig1); + DLIB_TEST(a2 == orig2); + DLIB_TEST(b1 == orig1); + DLIB_TEST(b2 == orig2); + DLIB_TEST(c1 == orig1); + DLIB_TEST(c2 == orig2); + DLIB_TEST(d1 == orig1); + DLIB_TEST(d2 == orig2); + + + } + + { + matrix a(5); + matrix b(5); + matrix c(5); + matrix d(5); + DLIB_TEST(a.nr() == 1); + DLIB_TEST(a.nc() == 5); + DLIB_TEST(c.nr() == 1); + DLIB_TEST(c.nc() == 5); + + DLIB_TEST(b.nc() == 1); + DLIB_TEST(b.nr() == 5); + DLIB_TEST(d.nc() == 1); + DLIB_TEST(d.nr() == 5); + } + + { + matrix a; + matrix b; + matrix c; + matrix d; + + a.set_size(5); + b.set_size(5); + c.set_size(5); + d.set_size(5); + + DLIB_TEST(a.nr() == 1); + DLIB_TEST(a.nc() == 5); + DLIB_TEST(c.nr() == 1); + DLIB_TEST(c.nc() == 5); + + DLIB_TEST(b.nc() == 1); + DLIB_TEST(b.nr() == 5); + DLIB_TEST(d.nc() == 1); + DLIB_TEST(d.nr() == 5); + } + + { + matrix a(1,5); + matrix b(5,1); + + set_all_elements(a,1); + set_all_elements(b,1); + + + a = a*b; + + DLIB_TEST(a(0) == 5); + } + + { + matrix a(6,7); + + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = r*a.nc() + c; + } + } + + + + DLIB_TEST(rowm(a,1).nr() == 1); + DLIB_TEST(rowm(a,1).nc() == a.nc()); + DLIB_TEST(colm(a,1).nr() == a.nr()); + DLIB_TEST(colm(a,1).nc() == 1); + + for (long c = 0; c < a.nc(); ++c) + { + DLIB_TEST( rowm(a,1)(c) == 1*a.nc() + c); + } + + for (long r = 0; r < a.nr(); ++r) + { + DLIB_TEST( colm(a,1)(r) == r*a.nc() + 1); + } + + rectangle rect(2, 1, 3+2-1, 2+1-1); + DLIB_TEST(get_rect(a).contains(get_rect(a))); + DLIB_TEST(get_rect(a).contains(rect)); + for (long r = 0; r < 2; ++r) + { + for (long c = 0; c < 3; ++c) + { + DLIB_TEST(subm(a,1,2,2,3)(r,c) == (r+1)*a.nc() + c+2); + DLIB_TEST(subm(a,1,2,2,3) == subm(a,rect)); + DLIB_TEST(subm_clipped(a,1,2,2,3) == subm(a,rect)); + DLIB_TEST(subm_clipped(a,1,2,2,3) == subm_clipped(a,rect)); + } + } + + DLIB_TEST(subm(a,rectangle()).nr() == 0); + DLIB_TEST(subm(a,rectangle()).nc() == 0); + + } + + { + array2d a; + a.set_size(6,7); + + + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a[r][c] = r*a.nc() + c; + } + } + + + + DLIB_TEST(rowm(mat(a),1).nr() == 1); + DLIB_TEST(rowm(mat(a),1).nc() == a.nc()); + DLIB_TEST(colm(mat(a),1).nr() == a.nr()); + DLIB_TEST(colm(mat(a),1).nc() == 1); + + for (long c = 0; c < a.nc(); ++c) + { + DLIB_TEST( rowm(mat(a),1)(c) == 1*a.nc() + c); + } + + for (long r = 0; r < a.nr(); ++r) + { + DLIB_TEST( colm(mat(a),1)(r) == r*a.nc() + 1); + } + + for (long r = 0; r < 2; ++r) + { + for (long c = 0; c < 3; ++c) + { + DLIB_TEST(subm(mat(a),1,2,2,3)(r,c) == (r+1)*a.nc() + c+2); + } + } + + + } + + { + array2d m; + m.set_size(5,5); + + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m[r][c] = r*c; + } + } + + + matrix mi = pinv(cos(exp(mat(m))) ); + DLIB_TEST(mi.nr() == m.nc()); + DLIB_TEST(mi.nc() == m.nr()); + DLIB_TEST((equal(round_zeros(mi*cos(exp(mat(m))),0.000001) , identity_matrix()))); + DLIB_TEST((equal(round_zeros(cos(exp(mat(m)))*mi,0.000001) , identity_matrix()))); + } + + { + matrix m1, res; + matrix m2; + + set_all_elements(m1,0); + + + long res_vals[] = { + 9, 9, 9, 9, 9, + 0, 1, 1, 0, 0, + 0, 1, 1, 0, 2, + 0, 0, 2, 2, 2, + 0, 0, 2, 2, 0 + }; + + res = res_vals; + + set_all_elements(m2, 1); + set_subm(m1, range(1,2), range(1,2)) = subm(m2,0,0,2,2); + set_all_elements(m2, 2); + set_subm(m1, 3,2,2,2) = m2; + + set_colm(m1,4) = trans(rowm(m1,4)); + set_rowm(m1,0) = 9; + + DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res); + + set_subm(m1,0,0,5,5) = m1*m1; + DLIB_TEST_MSG(m1 == res*res, "m1: \n" << m1 << "\nres*res: \n" << res*res); + + m1 = res; + set_subm(m1,1,1,2,2) = subm(m1,0,0,2,2); + + long res_vals2[] = { + 9, 9, 9, 9, 9, + 0, 9, 9, 0, 0, + 0, 0, 1, 0, 2, + 0, 0, 2, 2, 2, + 0, 0, 2, 2, 0 + }; + + res = res_vals2; + DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res); + + + } + + { + matrix m1, res; + matrix m2; + + set_all_elements(m1,0); + + + long res_vals[] = { + 9, 9, 9, 9, 9, + 0, 1, 1, 0, 0, + 0, 1, 1, 0, 2, + 0, 0, 2, 2, 2, + 0, 0, 2, 2, 0 + }; + + res = res_vals; + + set_all_elements(m2, 1); + set_subm(m1, rectangle(1,1,2,2)) = subm(m2,0,0,2,2); + set_all_elements(m2, 2); + set_subm(m1, 3,2,2,2) = m2; + + set_colm(m1,4) = trans(rowm(m1,4)); + set_rowm(m1,0) = 9; + + DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res); + + set_subm(m1,0,0,5,5) = m1*m1; + DLIB_TEST_MSG(m1 == res*res, "m1: \n" << m1 << "\nres*res: \n" << res*res); + + m1 = res; + set_subm(m1,1,1,2,2) = subm(m1,0,0,2,2); + + long res_vals2[] = { + 9, 9, 9, 9, 9, + 0, 9, 9, 0, 0, + 0, 0, 1, 0, 2, + 0, 0, 2, 2, 2, + 0, 0, 2, 2, 0 + }; + + res = res_vals2; + DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res); + + + } + + { + matrix m1, res; + matrix m2; + + set_all_elements(m1,0); + + + long res_vals[] = { + 9, 0, 3, 3, 0, + 9, 2, 2, 2, 0, + 9, 2, 2, 2, 0, + 4, 4, 4, 4, 4, + 9, 0, 3, 3, 0 + }; + long res_vals_c3[] = { + 9, 0, 3, 0, + 9, 2, 2, 0, + 9, 2, 2, 0, + 4, 4, 4, 4, + 9, 0, 3, 0 + }; + long res_vals_r2[] = { + 9, 0, 3, 3, 0, + 9, 2, 2, 2, 0, + 4, 4, 4, 4, 4, + 9, 0, 3, 3, 0 + }; + + matrix temp; + + res = res_vals; + + temp = matrix(res_vals_r2); + DLIB_TEST(remove_row<2>(res) == temp); + DLIB_TEST(remove_row<2>(res)(3,3) == 3); + DLIB_TEST(remove_row<2>(res).nr() == 4); + DLIB_TEST(remove_row<2>(res).nc() == 5); + DLIB_TEST(remove_row(res,2) == temp); + DLIB_TEST(remove_row(res,2)(3,3) == 3); + DLIB_TEST(remove_row(res,2).nr() == 4); + DLIB_TEST(remove_row(res,2).nc() == 5); + + temp = matrix(res_vals); + temp = remove_row(res,2); + DLIB_TEST((temp == matrix(res_vals_r2))); + temp = matrix(res_vals); + temp = remove_col(res,3); + DLIB_TEST((temp == matrix(res_vals_c3))); + + matrix vect; + set_all_elements(vect,1); + temp = identity_matrix(3); + temp = temp*vect; + DLIB_TEST(temp == vect); + + temp = matrix(res_vals_c3); + DLIB_TEST(remove_col(res,3) == temp); + DLIB_TEST(remove_col(res,3)(2,3) == 0); + DLIB_TEST(remove_col(res,3).nr() == 5); + DLIB_TEST(remove_col(res,3).nc() == 4); + + set_all_elements(m2, 1); + set_subm(m1, rectangle(1,1,3,2)) = 2; + set_all_elements(m2, 2); + set_subm(m1, 3,2,2,2) = 3; + + set_colm(m1,0) = 9; + set_rowm(m1,0) = rowm(m1,4); + set_rowm(m1,3) = 4; + + DLIB_TEST_MSG(m1 == res, "m1: \n" << m1 << "\nres: \n" << res); + + } + + } + + + void matrix_test2() + { + print_spinner(); + + + { + + const double stuff[] = { + 1, 2, 3, + 6, 3, 3, + 7, 3, 9}; + + matrix m(stuff); + + // make m be symmetric + m = m*trans(m); + + matrix L = chol(m); + DLIB_TEST(equal(L*trans(L), m)); + + DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), ""); + DLIB_TEST(equal(round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10), identity_matrix(3), 1e-10)); + DLIB_TEST(equal(round_zeros(inv_lower_triangular((L))*(L),1e-10) ,identity_matrix(3),1e-10)); + + } + + { + + const double stuff[] = { + 1, 2, 3, 6, 3, 4, + 6, 3, 3, 1, 2, 3, + 7, 3, 9, 54.3, 5, 3, + -6, 3, -3, 1, 2, 3, + 1, 2, 3, 5, -3, 3, + 7, 3, -9, 54.3, 5, 3 + }; + + matrix m(stuff); + + // make m be symmetric + m = m*trans(m); + + matrix L = chol(m); + DLIB_TEST_MSG(equal(L*trans(L), m, 1e-10), L*trans(L)-m); + + DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), ""); + DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*inv_lower_triangular((L))), ""); + DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*trans(inv_upper_triangular(trans(L)))), ""); + DLIB_TEST_MSG(equal(round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10) , identity_matrix(6), 1e-10), + round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10)); + DLIB_TEST_MSG(equal(round_zeros(inv_lower_triangular((L))*(L),1e-10) ,identity_matrix(6), 1e-10), + round_zeros(inv_lower_triangular((L))*(L),1e-10)); + + } + + { + // Test band chol + matrix m = rand_sp_banded(10, 3); + + matrix L = chol(m); + DLIB_TEST_MSG(equal(L*trans(L), m, 1e-10), L*trans(L)-m); + DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), ""); + DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*inv_lower_triangular((L))), ""); + DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*trans(inv_upper_triangular(trans(L)))), ""); + } + + { + // Test band chol in column major layout + matrix m(rand_sp_banded(10, 3)); + + matrix L = chol(m); + DLIB_TEST_MSG(equal(L*trans(L), m, 1e-10), L*trans(L)-m); + DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), ""); + DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*inv_lower_triangular((L))), ""); + DLIB_TEST_MSG(equal(inv(m), trans(inv_lower_triangular(L))*trans(inv_upper_triangular(trans(L)))), ""); + } + + { + matrix m(3,4), m2; + m = 1,2,3,4, + 4,5,6,6, + 6,1,8,0; + m2 = m; + DLIB_TEST(round(m) == m2); + DLIB_TEST(round_zeros(m) == m2); + + m2 = 0,2,3,4, + 4,5,6,6, + 6,0,8,0; + + DLIB_TEST(round_zeros(m,2) == m2); + } + + + { + + matrix m(identity_matrix(6)*4.5); + + matrix L = chol(m); + DLIB_TEST_MSG(equal(L*trans(L), m, 1e-10), L*trans(L)-m); + + DLIB_TEST_MSG(equal(inv(m), inv_upper_triangular(trans(L))*inv_lower_triangular((L))), ""); + DLIB_TEST_MSG(equal(round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10) , identity_matrix(6), 1e-10), + round_zeros(inv_upper_triangular(trans(L))*trans(L),1e-10)); + DLIB_TEST_MSG(equal(round_zeros(inv_lower_triangular((L))*(L),1e-10) ,identity_matrix(6), 1e-10), + round_zeros(inv_lower_triangular((L))*(L),1e-10)); + + } + + { + + matrix m(identity_matrix(6)*4.5); + m(1,4) = 2; + + DLIB_TEST_MSG(dlib::equal(inv_upper_triangular(m), inv(m),1e-10), inv_upper_triangular(m)-inv(m)); + DLIB_TEST_MSG(dlib::equal(inv_lower_triangular(trans(m)), inv(trans(m)),1e-10), inv_lower_triangular(trans(m))-inv(trans(m))); + + } + + { + matrix a; + matrix b; + matrix i; + a.set_size(1000,10); + b.set_size(1000,10); + i.set_size(1000,10); + dlib::rand rnd; + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = rnd.get_random_double(); + b(r,c) = rnd.get_random_float(); + i(r,c) = r+c*r; + } + } + + // make sure the multiply optimizations aren't messing things up + DLIB_TEST(trans(i)*i == tmp(trans(i)*i)); + DLIB_TEST_MSG(equal(trans(a)*a , tmp(trans(a)*a), 1e-11),max(abs(trans(a)*a - tmp(trans(a)*a)))); + DLIB_TEST_MSG(equal(trans(b)*b , tmp(trans(b)*b), 1e-3f),max(abs(trans(b)*b - tmp(trans(b)*b)))); + } + + { + matrix i(4,1); + i(0) = 1; + i(1) = 2; + i(2) = 3; + i(3) = 4; + matrix m; + set_all_elements(m,0); + m(0,0) = 1; + m(1,1) = 2; + m(2,2) = 3; + m(3,3) = 4; + + DLIB_TEST(diagm(i) == m); + } + + { + matrix i; + i(0) = 1; + i(1) = 2; + i(2) = 3; + i(3) = 4; + matrix m; + set_all_elements(m,0); + m(0,0) = 1; + m(1,1) = 2; + m(2,2) = 3; + m(3,3) = 4; + + DLIB_TEST(diagm(i) == m); + } + + { + matrix i(4,1); + i(0) = 1; + i(1) = 2; + i(2) = 3; + i(3) = 4; + matrix m(4,4); + set_all_elements(m,0); + m(0,0) = 1; + m(1,1) = 2; + m(2,2) = 3; + m(3,3) = 4; + + DLIB_TEST(diagm(i) == m); + } + + { + matrix i(1,4); + i(0) = 1; + i(1) = 2; + i(2) = 3; + i(3) = 4; + matrix m(4,4); + set_all_elements(m,0); + m(0,0) = 1; + m(1,1) = 2; + m(2,2) = 3; + m(3,3) = 4; + + DLIB_TEST(diagm(i) == m); + } + + { + DLIB_TEST(range(0,5).nc() == 6); + DLIB_TEST(range(1,5).nc() == 5); + DLIB_TEST(range(0,5).nr() == 1); + DLIB_TEST(range(1,5).nr() == 1); + DLIB_TEST(trans(range(0,5)).nr() == 6); + DLIB_TEST(trans(range(1,5)).nr() == 5); + DLIB_TEST(trans(range(0,5)).nc() == 1); + DLIB_TEST(trans(range(1,5)).nc() == 1); + + DLIB_TEST(range(0,2,5).nc() == 3); + DLIB_TEST(range(1,2,5).nc() == 3); + DLIB_TEST(range(0,2,5).nr() == 1); + DLIB_TEST(range(1,2,5).nr() == 1); + DLIB_TEST(trans(range(0,2,5)).nr() == 3); + DLIB_TEST(trans(range(1,2,5)).nr() == 3); + DLIB_TEST(trans(range(0,2,5)).nc() == 1); + DLIB_TEST(trans(range(1,2,5)).nc() == 1); + + DLIB_TEST(range(0,3,6).nc() == 3); + DLIB_TEST(range(1,3,5).nc() == 2); + DLIB_TEST(range(0,3,5).nr() == 1); + DLIB_TEST(range(1,3,5).nr() == 1); + DLIB_TEST(trans(range(0,3,6)).nr() == 3); + DLIB_TEST(trans(range(1,3,5)).nr() == 2); + DLIB_TEST(trans(range(0,3,5)).nc() == 1); + DLIB_TEST(trans(range(1,3,5)).nc() == 1); + + DLIB_TEST(range(1,9,5).nc() == 1); + DLIB_TEST(range(1,9,5).nr() == 1); + + DLIB_TEST(range(0,0).nc() == 1); + DLIB_TEST(range(0,0).nr() == 1); + + DLIB_TEST(range(1,1)(0) == 1); + + DLIB_TEST(range(0,5)(0) == 0 && range(0,5)(1) == 1 && range(0,5)(5) == 5); + DLIB_TEST(range(1,2,5)(0) == 1 && range(1,2,5)(1) == 3 && range(1,2,5)(2) == 5); + DLIB_TEST((range<0,5>()(0) == 0 && range<0,5>()(1) == 1 && range<0,5>()(5) == 5)); + DLIB_TEST((range<1,2,5>()(0) == 1 && range<1,2,5>()(1) == 3 && range<1,2,5>()(2) == 5)); + + + DLIB_TEST((range<0,5>().nc() == 6)); + DLIB_TEST((range<1,5>().nc() == 5)); + DLIB_TEST((range<0,5>().nr() == 1)); + DLIB_TEST((range<1,5>().nr() == 1)); + DLIB_TEST((trans(range<0,5>()).nr() == 6)); + DLIB_TEST((trans(range<1,5>()).nr() == 5)); + DLIB_TEST((trans(range<0,5>()).nc() == 1)); + DLIB_TEST((trans(range<1,5>()).nc() == 1)); + + DLIB_TEST((range<0,2,5>().nc() == 3)); + DLIB_TEST((range<1,2,5>().nc() == 3)); + DLIB_TEST((range<0,2,5>().nr() == 1)); + DLIB_TEST((range<1,2,5>().nr() == 1)); + DLIB_TEST((trans(range<0,2,5>()).nr() == 3)); + DLIB_TEST((trans(range<1,2,5>()).nr() == 3)); + DLIB_TEST((trans(range<0,2,5>()).nc() == 1)); + DLIB_TEST((trans(range<1,2,5>()).nc() == 1)); + + DLIB_TEST((range<0,3,6>().nc() == 3)); + DLIB_TEST((range<1,3,5>().nc() == 2)); + DLIB_TEST((range<0,3,5>().nr() == 1)); + DLIB_TEST((range<1,3,5>().nr() == 1)); + DLIB_TEST((trans(range<0,3,6>()).nr() == 3)); + DLIB_TEST((trans(range<1,3,5>()).nr() == 2)); + DLIB_TEST((trans(range<0,3,5>()).nc() == 1)); + DLIB_TEST((trans(range<1,3,5>()).nc() == 1)); + } + + { + DLIB_TEST(range(5,0).nc() == 6); + DLIB_TEST(range(5,1).nc() == 5); + DLIB_TEST(range(5,0).nr() == 1); + DLIB_TEST(range(5,1).nr() == 1); + DLIB_TEST(trans(range(5,0)).nr() == 6); + DLIB_TEST(trans(range(5,1)).nr() == 5); + DLIB_TEST(trans(range(5,0)).nc() == 1); + DLIB_TEST(trans(range(5,1)).nc() == 1); + + DLIB_TEST(range(5,2,0).nc() == 3); + DLIB_TEST(range(5,2,1).nc() == 3); + DLIB_TEST(range(5,2,0).nr() == 1); + DLIB_TEST(range(5,2,1).nr() == 1); + DLIB_TEST(trans(range(5,2,0)).nr() == 3); + DLIB_TEST(trans(range(5,2,1)).nr() == 3); + DLIB_TEST(trans(range(5,2,0)).nc() == 1); + DLIB_TEST(trans(range(5,2,1)).nc() == 1); + + DLIB_TEST(range(6,3,0).nc() == 3); + DLIB_TEST(range(5,3,1).nc() == 2); + DLIB_TEST(range(5,3,0).nr() == 1); + DLIB_TEST(range(5,3,1).nr() == 1); + DLIB_TEST(trans(range(6,3,0)).nr() == 3); + DLIB_TEST(trans(range(5,3,1)).nr() == 2); + DLIB_TEST(trans(range(5,3,0)).nc() == 1); + DLIB_TEST(trans(range(5,3,1)).nc() == 1); + + DLIB_TEST(range(5,9,1).nc() == 1); + DLIB_TEST(range(5,9,1).nr() == 1); + + DLIB_TEST(range(0,0).nc() == 1); + DLIB_TEST(range(0,0).nr() == 1); + + DLIB_TEST(range(1,1)(0) == 1); + + DLIB_TEST(range(5,0)(0) == 5 && range(5,0)(1) == 4 && range(5,0)(5) == 0); + DLIB_TEST(range(5,2,1)(0) == 5 && range(5,2,1)(1) == 3 && range(5,2,1)(2) == 1); + DLIB_TEST((range<5,0>()(0) == 5 && range<5,0>()(1) == 4 && range<5,0>()(5) == 0)); + DLIB_TEST((range<5,2,1>()(0) == 5 && range<5,2,1>()(1) == 3 && range<5,2,1>()(2) == 1)); + + + DLIB_TEST((range<5,0>().nc() == 6)); + DLIB_TEST((range<5,1>().nc() == 5)); + DLIB_TEST((range<5,0>().nr() == 1)); + DLIB_TEST((range<5,1>().nr() == 1)); + DLIB_TEST((trans(range<5,0>()).nr() == 6)); + DLIB_TEST((trans(range<5,1>()).nr() == 5)); + DLIB_TEST((trans(range<5,0>()).nc() == 1)); + DLIB_TEST((trans(range<5,1>()).nc() == 1)); + + DLIB_TEST((range<5,2,0>().nc() == 3)); + DLIB_TEST((range<5,2,1>().nc() == 3)); + DLIB_TEST((range<5,2,0>().nr() == 1)); + DLIB_TEST((range<5,2,1>().nr() == 1)); + DLIB_TEST((trans(range<5,2,0>()).nr() == 3)); + DLIB_TEST((trans(range<5,2,1>()).nr() == 3)); + DLIB_TEST((trans(range<5,2,0>()).nc() == 1)); + DLIB_TEST((trans(range<5,2,1>()).nc() == 1)); + + DLIB_TEST((range<6,3,0>().nc() == 3)); + DLIB_TEST((range<5,3,1>().nc() == 2)); + DLIB_TEST((range<5,3,0>().nr() == 1)); + DLIB_TEST((range<5,3,1>().nr() == 1)); + DLIB_TEST((trans(range<6,3,0>()).nr() == 3)); + DLIB_TEST((trans(range<5,3,1>()).nr() == 2)); + DLIB_TEST((trans(range<5,3,0>()).nc() == 1)); + DLIB_TEST((trans(range<5,3,1>()).nc() == 1)); + } + + { + matrix m(4,3); + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + DLIB_TEST(subm(m,range(0,3),range(0,0)) == colm(m,0)); + DLIB_TEST(subm(m,range(0,3),range(1,1)) == colm(m,1)); + DLIB_TEST(subm(m,range(0,3),range(2,2)) == colm(m,2)); + + DLIB_TEST(subm(m,range(0,0),range(0,2)) == rowm(m,0)); + DLIB_TEST(subm(m,range(1,1),range(0,2)) == rowm(m,1)); + DLIB_TEST(subm(m,range(2,2),range(0,2)) == rowm(m,2)); + DLIB_TEST(subm(m,range(3,3),range(0,2)) == rowm(m,3)); + + DLIB_TEST(subm(m,0,0,2,2) == subm(m,range(0,1),range(0,1))); + DLIB_TEST(subm(m,1,1,2,2) == subm(m,range(1,2),range(1,2))); + + matrix m2 = subm(m,range(0,2,2),range(0,2,2)); + + DLIB_TEST(m2(0,0) == m(0,0)); + DLIB_TEST(m2(0,1) == m(0,2)); + DLIB_TEST(m2(1,0) == m(2,0)); + DLIB_TEST(m2(1,1) == m(2,2)); + + + } + { + matrix m(4,3); + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = r*c; + } + } + + DLIB_TEST(subm(m,range<0,3>(),range<0,0>()) == colm(m,0)); + DLIB_TEST(subm(m,range<0,3>(),range<1,1>()) == colm(m,1)); + DLIB_TEST(subm(m,range<0,3>(),range<2,2>()) == colm(m,2)); + + DLIB_TEST(subm(m,range<0,0>(),range<0,2>()) == rowm(m,0)); + DLIB_TEST(subm(m,range<1,1>(),range<0,2>()) == rowm(m,1)); + DLIB_TEST(subm(m,range<2,2>(),range<0,2>()) == rowm(m,2)); + DLIB_TEST(subm(m,range<3,3>(),range<0,2>()) == rowm(m,3)); + + DLIB_TEST(subm(m,0,0,2,2) == subm(m,range<0,1>(),range<0,1>())); + DLIB_TEST(subm(m,1,1,2,2) == subm(m,range<1,2>(),range<1,2>())); + + matrix m2 = subm(m,range<0,2,2>(),range<0,2,2>()); + + DLIB_TEST(m2(0,0) == m(0,0)); + DLIB_TEST(m2(0,1) == m(0,2)); + DLIB_TEST(m2(1,0) == m(2,0)); + DLIB_TEST(m2(1,1) == m(2,2)); + + + } + + { + matrix a = randm(3,4); + matrix b = randm(3,4); + + matrix m1, m2; + + m1 = max_pointwise(a,b); + m2 = min_pointwise(a,b); + DLIB_TEST(m1.nr() == a.nr()); + DLIB_TEST(m1.nc() == a.nc()); + DLIB_TEST(m2.nr() == a.nr()); + DLIB_TEST(m2.nc() == a.nc()); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + DLIB_TEST_MSG(m1(r,c) == std::max(a(r,c), b(r,c)), m1(r,c) << " : " << a(r,c) << " " << b(r,c)); + DLIB_TEST(m2(r,c) == std::min(a(r,c), b(r,c))); + } + } + } + + { + matrix m; + set_subm(m, range(0,3), range(0,4)) = 4; + DLIB_TEST(min(m) == max(m) && min(m) == 4); + + set_subm(m,range(1,1),range(0,4)) = 7; + DLIB_TEST((rowm(m,0) == uniform_matrix(1,5, 4))); + DLIB_TEST((rowm(m,1) == uniform_matrix(1,5, 7))); + DLIB_TEST((rowm(m,2) == uniform_matrix(1,5, 4))); + DLIB_TEST((rowm(m,3) == uniform_matrix(1,5, 4))); + + + set_subm(m, range(0,2,3), range(0,2,4)) = trans(subm(m,0,0,3,2)); + + + DLIB_TEST(m(0,2) == 7); + DLIB_TEST(m(2,2) == 7); + + DLIB_TEST(sum(m) == 7*5+ 7+7 + 4*(4*5 - 7)); + + } + + { + matrix mat(4,5); + DLIB_TEST((uniform_matrix(4,5,1) == ones_matrix(4,5))); + DLIB_TEST((uniform_matrix(4,5,1) == ones_matrix(mat))); + DLIB_TEST((uniform_matrix(4,5,0) == zeros_matrix(4,5))); + DLIB_TEST((uniform_matrix(4,5,0) == zeros_matrix(mat))); + DLIB_TEST((uniform_matrix(4,5,1) == ones_matrix(4,5))); + DLIB_TEST((uniform_matrix(4,5,0) == zeros_matrix(4,5))); + DLIB_TEST((uniform_matrix >(4,5,1) == ones_matrix >(4,5))); + DLIB_TEST((uniform_matrix >(4,5,0) == zeros_matrix >(4,5))); + DLIB_TEST((uniform_matrix >(4,5,1) == ones_matrix >(4,5))); + DLIB_TEST((uniform_matrix >(4,5,0) == zeros_matrix >(4,5))); + DLIB_TEST((complex_matrix(ones_matrix(3,3), zeros_matrix(3,3)) == complex_matrix(ones_matrix(3,3)))); + DLIB_TEST((pointwise_multiply(complex_matrix(ones_matrix(3,3)), ones_matrix(3,3)*2) == + complex_matrix(2*ones_matrix(3,3)))); + } + + { + DLIB_TEST(( uniform_matrix(303,303, 3)*identity_matrix(303) == uniform_matrix(3) ) ); + DLIB_TEST(( uniform_matrix(3)*identity_matrix() == uniform_matrix(3) )); + } + + { + matrix m(2,3); + m = 1,2,3, + 5,6,7; + + DLIB_TEST_MSG(m(0,0) == 1 && m(0,1) == 2 && m(0,2) == 3 && + m(1,0) == 5 && m(1,1) == 6 && m(1,2) == 7,""); + + m = 4; + DLIB_TEST((m == uniform_matrix(4))); + + matrix m2; + m2 = 1,2,3, + 5,6,7; + DLIB_TEST_MSG(m2(0,0) == 1 && m2(0,1) == 2 && m2(0,2) == 3 && + m2(1,0) == 5 && m2(1,1) == 6 && m2(1,2) == 7,""); + + matrix m3; + m3 = 1, + 5; + DLIB_TEST(m3(0) == 1 && m3(1) == 5 ); + + matrix m4; + m4 = 1, 5; + DLIB_TEST(m3(0) == 1 && m3(1) == 5 ); + } + + { + matrix m(4,1); + m = 3, 1, 5, 2; + DLIB_TEST(index_of_min(m) == 1); + DLIB_TEST(index_of_max(m) == 2); + DLIB_TEST(index_of_min(trans(m)) == 1); + DLIB_TEST(index_of_max(trans(m)) == 2); + } + + { + matrix m1(1,5), m2; + + m1 = 3.0000, 3.7500, 4.5000, 5.2500, 6.0000; + m2 = linspace(3, 6, 5); + + DLIB_TEST(equal(m1, m2)); + + m1 = pow(10, m1); + m2 = logspace(3, 6, 5); + + DLIB_TEST(equal(m1, m2)); + } + + { + matrix m = cartesian_product(range(1,3), range(0,1)); + + matrix c0, c1, c2, c3, c4, c5; + c0 = 1, 0; + c1 = 1, 1; + c2 = 2, 0; + c3 = 2, 1; + c4 = 3, 0; + c5 = 3, 1; + + DLIB_TEST_MSG(colm(m,0) == c0, colm(m,0) << "\n\n" << c0); + DLIB_TEST(colm(m,1) == c1); + DLIB_TEST(colm(m,2) == c2); + DLIB_TEST(colm(m,3) == c3); + DLIB_TEST(colm(m,4) == c4); + DLIB_TEST(colm(m,5) == c5); + } + + + { + matrix m(2,2), mr(2,2), mr_max(2,2); + + m = 1, 2, + 0, 4; + + mr = 1, 1.0/2.0, + 0, 1.0/4.0; + + mr_max = 1, 1.0/2.0, + std::numeric_limits::max(), 1.0/4.0; + + DLIB_TEST(equal(reciprocal(m), mr)); + DLIB_TEST(equal(reciprocal_max(m), mr_max)); + + } + + { + matrix m1, m2; + m1.set_size(3,1); + m2.set_size(1,3); + + m1 = 1,2,3; + m2 = 4,5,6; + DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6); + } + + { + matrix m1, m2; + m1.set_size(3,1); + m2.set_size(3,1); + + m1 = 1,2,3; + m2 = 4,5,6; + DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6); + } + { + matrix m1, m2; + m1.set_size(1,3); + m2.set_size(1,3); + + m1 = 1,2,3; + m2 = 4,5,6; + DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6); + } + { + matrix m1; + matrix m2; + m1.set_size(1,3); + m2.set_size(3,1); + + m1 = 1,2,3; + m2 = 4,5,6; + DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6); + DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6); + } + + { + matrix m1(3,3), m2(3,3); + + m1 = 1; + m2 = 1; + m1 = m1*subm(m2,0,0,3,3); + DLIB_TEST(is_finite(m1)); + } + { + matrix m1; + matrix m2(3,3); + + m1 = 1; + m2 = 1; + m1 = subm(m2,0,0,3,3)*m1; + } + + { + matrix m(2,1); + + m = 3,3; + m /= m(0); + + DLIB_TEST(m(0) == 1); + DLIB_TEST(m(1) == 1); + } + { + matrix m(2,1); + + m = 3,3; + m *= m(0); + + DLIB_TEST(m(0) == 9); + DLIB_TEST(m(1) == 9); + } + { + matrix m(2,1); + + m = 3,3; + m -= m(0); + + DLIB_TEST(m(0) == 0); + DLIB_TEST(m(1) == 0); + } + { + matrix m(2,1); + + m = 3,3; + m += m(0); + + DLIB_TEST(m(0) == 6); + DLIB_TEST(m(1) == 6); + DLIB_TEST(is_finite(m)); + } + + + { + matrix m(3,3); + m = 3; + m(1,1) = std::numeric_limits::infinity(); + DLIB_TEST(is_finite(m) == false); + m(1,1) = -std::numeric_limits::infinity(); + DLIB_TEST(is_finite(m) == false); + m(1,1) = 2; + DLIB_TEST(is_finite(m)); + } + + { + matrix m(4,1), mm, mmm; + + mmm = mm = (m = 1,2,3,4); + DLIB_TEST(m(0) == 1); + DLIB_TEST(m(1) == 2); + DLIB_TEST(m(2) == 3); + DLIB_TEST(m(3) == 4); + DLIB_TEST(mm == m); + DLIB_TEST(mmm == m); + DLIB_TEST(mm(0) == 1); + DLIB_TEST(mm(1) == 2); + DLIB_TEST(mm(2) == 3); + DLIB_TEST(mm(3) == 4); + } + + { + const long n = 5; + matrix m1, m2, m3, truth; + m1 = randm(n,n); + m2 = randm(n,n); + + rectangle rect1(1,1,3,3); + rectangle rect2(2,1,4,3); + + truth = subm(m1,rect1)*subm(m2,rect2); + m3 = mat(&m1(0,0)+6, 3,3, m1.nc()) * mat(&m2(0,0)+7, 3,3, m2.nc()); + + DLIB_TEST(max(abs(truth-m3)) < 1e-13); + } + + { + const long n = 5; + matrix m1, m2, m3, truth; + m1 = randm(n,n); + m2 = randm(n,n); + m3 = randm(n,n); + + + truth = m1*m2; + m3 = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n); + DLIB_TEST(max(abs(truth-m3)) < 1e-13); + m3 = 0; + set_ptrm(&m3(0,0),n,n) = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n); + DLIB_TEST(max(abs(truth-m3)) < 1e-13); + set_ptrm(&m3(0,0),n,n) = m1*m2; + DLIB_TEST(max(abs(truth-m3)) < 1e-13); + + // now make sure it deals with aliasing correctly. + truth = m1*m2; + m1 = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n); + DLIB_TEST(max(abs(truth-m1)) < 1e-13); + + m1 = randm(n,n); + truth = m1*m2; + set_ptrm(&m1(0,0),n,n) = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n); + DLIB_TEST(max(abs(truth-m1)) < 1e-13); + + m1 = randm(n,n); + truth = m1*m2; + set_ptrm(&m1(0,0),n,n) = m1*m2; + DLIB_TEST(max(abs(truth-m1)) < 1e-13); + + m1 = randm(n,n); + truth = m1+m1*m2; + set_ptrm(&m1(0,0),n,n) += m1*m2; + DLIB_TEST(max(abs(truth-m1)) < 1e-13); + + m1 = randm(n,n); + truth = m1-m1*m2; + set_ptrm(&m1(0,0),n,n) -= m1*m2; + DLIB_TEST(max(abs(truth-m1)) < 1e-13); + + } + + { + matrix a(3,3); + matrix m = randm(3,3); + matrix b = randm(3,1); + + a = 0; + set_colm(a,0) = m*b; + DLIB_TEST(colm(a,0) == m*b); + a = 0; + set_rowm(a,0) = trans(m*b); + DLIB_TEST(rowm(a,0) == trans(m*b)); + DLIB_TEST(rowm(a,0) != m*b); + } + { + matrix a(3,3); + matrix m = randm(3,3); + matrix b = randm(3,1); + + a = 0; + set_colm(a,0) = m*b; + DLIB_TEST(equal(colm(a,0) , m*b)); + a = 0; + set_rowm(a,0) = trans(m*b); + DLIB_TEST(equal(rowm(a,0) , trans(m*b))); + DLIB_TEST(!equal(rowm(a,0) , m*b)); + } + { + matrix a(3,3); + matrix m = randm(3,3); + matrix b = randm(3,1); + + a = 0; + set_colm(a,0) = m*b; + DLIB_TEST(equal(colm(a,0) , m*b)); + a = 0; + set_rowm(a,0) = trans(m*b); + DLIB_TEST(equal(rowm(a,0) , trans(m*b))); + DLIB_TEST(!equal(rowm(a,0) , m*b)); + } + } + + + + + + + class matrix_tester : public tester + { + public: + matrix_tester ( + ) : + tester ("test_matrix", + "Runs tests on the matrix component.") + {} + + void perform_test ( + ) + { + matrix_test(); + matrix_test2(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/matrix2.cpp b/ml/dlib/dlib/test/matrix2.cpp new file mode 100644 index 000000000..8de17fc7f --- /dev/null +++ b/ml/dlib/dlib/test/matrix2.cpp @@ -0,0 +1,1158 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" + +#include "tester.h" +#include +#include + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.matrix2"); + + dlib::rand rnd; + + void matrix_test1 ( + ) + { + typedef memory_manager_stateless::kernel_2_2a MM; + print_spinner(); + + const double ident[] = { + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1 }; + + const double uniform3[] = { + 3, 3, 3, 3, + 3, 3, 3, 3, + 3, 3, 3, 3, + 3, 3, 3, 3 + }; + + const double uniform1[] = { + 1, 1, 1, 1, + 1, 1, 1, 1, + 1, 1, 1, 1, + 1, 1, 1, 1 + }; + + const double uniform0[] = { + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0 + }; + + const int array[] = { + 42, 58, 9, 1, + 9, 5, 8, 2, + 98, 28, 4, 77, + 9, 2, 44, 88 }; + + const int array2[] = { + 1, 22, 3, + 4, 52, 6, + 7, 8, 9 }; + + const int array2_r[] = { + 52, 6, 4, + 8, 9, 7, + 22, 3, 1 + }; + + const double array_f[] = { + -0.99, + 0.99}; + + + matrix fm(array_f); + + DLIB_TEST(fm.size() == 2); + matrix dfm(fm); + DLIB_TEST(round(fm)(0) == -1); + DLIB_TEST(round(fm)(1) == 1); + DLIB_TEST(round(dfm)(0) == -1); + DLIB_TEST(round(dfm)(1) == 1); + DLIB_TEST(round(dfm).size() == dfm.size()); + + + const int array3[] = { 1, 2, 3, 4 }; + + matrix m3(array2); + matrix dm3; + DLIB_TEST(dm3.size() == 0); + DLIB_TEST(dm3.nr() == 0); + DLIB_TEST(dm3.nc() == 0); + dm3.set_size(3,4); + DLIB_TEST(dm3.nr() == 3); + DLIB_TEST(dm3.nc() == 4); + DLIB_TEST(dm3.size() == 3*4); + dm3.set_size(3,3); + DLIB_TEST(dm3.nr() == 3); + DLIB_TEST(dm3.nc() == 3); + dm3 = m3; + dm3(0,0)++; + DLIB_TEST( dm3 != m3); + dm3 = m3; + DLIB_TEST( dm3 == m3); + DLIB_TEST( abs(sum(squared(normalize(dm3))) - 1.0) < 1e-10); + + matrix mrc; + mrc.set_size(3,4); + + set_all_elements(mrc,1); + + DLIB_TEST(diag(mrc) == uniform_matrix(3,1,1)); + DLIB_TEST(diag(matrix(mrc)) == uniform_matrix(3,1,1)); + + matrix mrc2; + set_all_elements(mrc2,1); + DLIB_TEST((removerc<1,1>(mrc) == mrc2)); + DLIB_TEST((removerc(mrc,1,1) == mrc2)); + + matrix m4, m5, m6; + set_all_elements(m4, 4); + set_all_elements(m5, 4); + set_all_elements(m6, 1); + + DLIB_TEST(squared(m4) == pointwise_multiply(m4,m4)); + DLIB_TEST(cubed(m4) == pointwise_multiply(m4,m4,m4)); + DLIB_TEST(pow(matrix_cast(m4),2) == squared(matrix_cast(m4))); + DLIB_TEST(pow(matrix_cast(m4),3) == cubed(matrix_cast(m4))); + + matrix dm4; + matrix::kernel_2_2a> dm5; + dm4 = dm4; + dm4 = dm5; + DLIB_TEST(dm4.nr() == 0); + dm4 = m4; + dm5 = m5; + DLIB_TEST(dm4 == dm5); + + + DLIB_TEST(m4 == m5); + DLIB_TEST(m6 != m5); + m4.swap(m6); + DLIB_TEST(m6 == m5); + DLIB_TEST(m4 != m5); + + DLIB_TEST(m3.nr() == 3); + DLIB_TEST(m3.nc() == 3); + + matrix v(array3), v2; + DLIB_TEST(v.nr() == 4); + DLIB_TEST(v.nc() == 1); + + std::vector stdv(4); + std_vector_c stdv_c(4); + dlib::array arr; + arr.resize(4); + for (long i = 0; i < 4; ++i) + stdv[i] = stdv_c[i] = arr[i] = i+1; + + DLIB_TEST(mat(stdv)(0) == 1); + DLIB_TEST(mat(stdv)(1) == 2); + DLIB_TEST(mat(stdv)(2) == 3); + DLIB_TEST(mat(stdv)(3) == 4); + DLIB_TEST(mat(stdv).nr() == 4); + DLIB_TEST(mat(stdv).nc() == 1); + DLIB_TEST(mat(stdv).size() == 4); + DLIB_TEST(equal(trans(mat(stdv))*mat(stdv), trans(v)*v)); + DLIB_TEST(equal(trans(mat(stdv))*mat(stdv), tmp(trans(v)*v))); + + DLIB_TEST(mat(stdv_c)(0) == 1); + DLIB_TEST(mat(stdv_c)(1) == 2); + DLIB_TEST(mat(stdv_c)(2) == 3); + DLIB_TEST(mat(stdv_c)(3) == 4); + DLIB_TEST(mat(stdv_c).nr() == 4); + DLIB_TEST(mat(stdv_c).nc() == 1); + DLIB_TEST(mat(stdv_c).size() == 4); + DLIB_TEST(equal(trans(mat(stdv_c))*mat(stdv_c), trans(v)*v)); + + DLIB_TEST(mat(arr)(0) == 1); + DLIB_TEST(mat(arr)(1) == 2); + DLIB_TEST(mat(arr)(2) == 3); + DLIB_TEST(mat(arr)(3) == 4); + DLIB_TEST(mat(arr).nr() == 4); + DLIB_TEST(mat(arr).nc() == 1); + DLIB_TEST(mat(arr).size() == 4); + DLIB_TEST(equal(trans(mat(arr))*mat(arr), trans(v)*v)); + + DLIB_TEST(v(0) == 1); + DLIB_TEST(v(1) == 2); + DLIB_TEST(v(2) == 3); + DLIB_TEST(v(3) == 4); + matrix dv = v; + DLIB_TEST((trans(v)*v).size() == 1); + DLIB_TEST((trans(v)*v).nr() == 1); + DLIB_TEST((trans(v)*dv).nr() == 1); + DLIB_TEST((trans(dv)*dv).nr() == 1); + DLIB_TEST((trans(dv)*v).nr() == 1); + DLIB_TEST((trans(v)*v).nc() == 1); + DLIB_TEST((trans(v)*dv).nc() == 1); + DLIB_TEST((trans(dv)*dv).nc() == 1); + DLIB_TEST((trans(dv)*v).nc() == 1); + DLIB_TEST((trans(v)*v)(0) == 1*1 + 2*2 + 3*3 + 4*4); + DLIB_TEST((trans(dv)*v)(0) == 1*1 + 2*2 + 3*3 + 4*4); + DLIB_TEST((trans(dv)*dv)(0) == 1*1 + 2*2 + 3*3 + 4*4); + DLIB_TEST((trans(v)*dv)(0) == 1*1 + 2*2 + 3*3 + 4*4); + + dv = trans(dv)*v; + DLIB_TEST(dv.nr() == 1); + DLIB_TEST(dv.nc() == 1); + + dm3 = m3; + DLIB_TEST(floor(det(m3)+0.01) == -444); + DLIB_TEST(floor(det(dm3)+0.01) == -444); + DLIB_TEST(min(m3) == 1); + DLIB_TEST(m3(min_point(m3).y(),min_point(m3).x()) == 1); + DLIB_TEST(min(dm3) == 1); + DLIB_TEST(max(m3) == 52); + DLIB_TEST(m3(max_point(m3).y(),max_point(m3).x()) == 52); + DLIB_TEST(max(dm3) == 52); + DLIB_TEST(sum(m3) == 112); + DLIB_TEST(sum(dm3) == 112); + DLIB_TEST(prod(m3) == 41513472); + DLIB_TEST(prod(dm3) == 41513472); + DLIB_TEST(prod(diag(m3)) == 1*52*9); + DLIB_TEST(prod(diag(dm3)) == 1*52*9); + DLIB_TEST(sum(diag(m3)) == 1+52+9); + DLIB_TEST(sum(diag(dm3)) == 1+52+9); + DLIB_TEST(equal(round(10000*m3*inv(m3))/10000 , identity_matrix())); + DLIB_TEST(equal(round(10000*dm3*inv(m3))/10000 , identity_matrix())); + DLIB_TEST(equal(round(10000*dm3*inv(dm3))/10000 , identity_matrix())); + DLIB_TEST(equal(round(10000*m3*inv(dm3))/10000 , identity_matrix())); + DLIB_TEST(equal(round(10000*tmp(m3*inv(m3)))/10000 , identity_matrix())); + DLIB_TEST(equal(round(10000*tmp(dm3*inv(m3)))/10000 , identity_matrix())); + DLIB_TEST(equal(round(10000*tmp(dm3*inv(dm3)))/10000 , identity_matrix())); + DLIB_TEST(equal(round(10000*tmp(m3*inv(dm3)))/10000 , identity_matrix())); + DLIB_TEST(-1*m3 == -m3); + DLIB_TEST(-1*dm3 == -m3); + DLIB_TEST(-1*m3 == -dm3); + DLIB_TEST(-1*dm3 == -dm3); + + DLIB_TEST(m3 == dm3); + m3(1,1) = 99; + DLIB_TEST(m3 != dm3); + m3 = dm3; + DLIB_TEST(m3 == dm3); + + matrix mident(ident); + matrix muniform0(uniform0); + matrix muniform1(uniform1); + matrix muniform3(uniform3); + matrix m1(array), m2; + DLIB_TEST(m1.nr() == 4); + DLIB_TEST(m1.nc() == 4); + + DLIB_TEST(muniform1 + muniform1 + muniform1 == muniform3); + DLIB_TEST(muniform1*2 + muniform1 + muniform1 - muniform1 == muniform3); + DLIB_TEST(2*muniform1 + muniform1 + muniform1 - muniform1 == muniform3); + DLIB_TEST(muniform1 + muniform1 + muniform1 - muniform3 == muniform0); + DLIB_TEST(equal(muniform3/3 , muniform1)); + DLIB_TEST(v != m1); + DLIB_TEST(v == v); + DLIB_TEST(m1 == m1); + + muniform0.swap(muniform1); + DLIB_TEST((muniform1 == matrix_cast(uniform_matrix()))); + DLIB_TEST((muniform0 == matrix_cast(uniform_matrix()))); + DLIB_TEST((muniform1 == matrix_cast(uniform_matrix(4,4,0)))); + DLIB_TEST((muniform0 == matrix_cast(uniform_matrix(4,4,1)))); + swap(muniform0,muniform1); + + DLIB_TEST((mident == identity_matrix())); + DLIB_TEST((muniform0 == matrix_cast(uniform_matrix()))); + DLIB_TEST((muniform1 == matrix_cast(uniform_matrix()))); + DLIB_TEST((muniform3 == matrix_cast(uniform_matrix()))); + DLIB_TEST((muniform1*8 == matrix_cast(uniform_matrix()))); + + set_all_elements(m2,7); + DLIB_TEST(m2 == muniform1*7); + m2 = array; + DLIB_TEST(m2 == m1); + + const double m1inv[] = { + -0.00946427624, 0.0593272941, 0.00970564379, -0.00973323731, + 0.0249312057, -0.0590122427, -0.00583102756, 0.00616002729, + -0.00575431149, 0.110081189, -0.00806792253, 0.00462297692, + 0.00327847478, -0.0597669712, 0.00317386196, 0.00990759201 + }; + + m2 = m1inv; + DLIB_TEST((round(m2*m1) == identity_matrix())); + DLIB_TEST((round(tmp(m2*m1)) == identity_matrix())); + + DLIB_TEST_MSG(round(m2*10000) == round(inv(m1)*10000), + round(m2*10000) - round(inv(m1)*10000) + << "\n\n" << round(m2*10000) + << "\n\n" << round(inv(m1)*10000) + << "\n\n" << m2 + << "\n\n" << inv(m1) + ); + DLIB_TEST(m1 == abs(-1*m1)); + DLIB_TEST(abs(m2) == abs(-1*m2)); + + DLIB_TEST_MSG(floor(det(m1)+0.01) == 3297875,"\nm1: \n" << m1 << "\ndet(m1): " << det(m1)); + + + ostringstream sout; + m1 = m2; + serialize(m1,sout); + set_all_elements(m1,0); + istringstream sin(sout.str()); + deserialize(m1,sin); + DLIB_TEST_MSG(round(100000*m1) == round(100000*m2),"m1: \n" << m1 << endl << "m2: \n" << m2); + + + set_all_elements(v,2); + v2 = pointwise_multiply(v, v*2); + set_all_elements(v,8); + DLIB_TEST(v == v2); + DLIB_TEST(v == tmp(v2)); + DLIB_TEST((v == rotate<2,0>(v))); + + m4 = array2; + m5 = array2_r; + DLIB_TEST((m5 == rotate<1,1>(m4))); + + m5 = array2; + DLIB_TEST((m5*2 == pointwise_multiply(m5,uniform_matrix()))); + DLIB_TEST((tmp(m5*2) == tmp(pointwise_multiply(m5,uniform_matrix())))); + + v = tmp(v); + + + + + matrix dm10(10,5); + DLIB_TEST(dm10.nr() == 10); + DLIB_TEST(dm10.nc() == 5); + set_all_elements(dm10,4); + DLIB_TEST(dm10.nr() == 10); + DLIB_TEST(dm10.nc() == 5); + matrix m10; + DLIB_TEST(m10.nr() == 10); + DLIB_TEST(m10.nc() == 5); + set_all_elements(m10,4); + DLIB_TEST(dm10 == m10); + DLIB_TEST((clamp<0,3>(dm10) == clamp<0,3>(m10))); + DLIB_TEST((clamp<0,3>(dm10)(0,2) == 3)); + + set_all_elements(dm10,1); + set_all_elements(m10,4); + DLIB_TEST(4*dm10 == m10); + DLIB_TEST(5*dm10 - dm10 == m10); + DLIB_TEST((16*dm10)/4 == m10); + DLIB_TEST(dm10+dm10+2*dm10 == m10); + DLIB_TEST(dm10+tmp(dm10+2*dm10) == m10); + set_all_elements(dm10,4); + DLIB_TEST(dm10 == m10); + DLIB_TEST_MSG(sum(abs(sigmoid(dm10) -sigmoid(m10))) < 1e-10,sum(abs(sigmoid(dm10) -sigmoid(m10))) ); + + { + matrix x, l, u, out; + x = 3,4; + + l = 1,1; + u = 2,2.2; + + out = 2, 2.2; + DLIB_TEST(equal(dlib::clamp(x, l, u) , out)); + out = 3, 2.2; + DLIB_TEST(!equal(dlib::clamp(x, l, u) , out)); + out = 2, 4.2; + DLIB_TEST(!equal(dlib::clamp(x, l, u) , out)); + + x = 1.5, 1.5; + out = x; + DLIB_TEST(equal(dlib::clamp(x, l, u) , out)); + + x = 0.5, 1.5; + out = 1, 1.5; + DLIB_TEST(equal(dlib::clamp(x, l, u) , out)); + + x = 1.5, 0.5; + out = 1.5, 1.0; + DLIB_TEST(equal(dlib::clamp(x, l, u) , out)); + + } + + matrix m7; + matrix dm7(7,7); + dm7 = randm(7,7, rnd); + m7 = dm7; + + DLIB_TEST_MSG(max(abs(dm7*inv(dm7) - identity_matrix(7))) < 1e-12, max(abs(dm7*inv(dm7) - identity_matrix(7)))); + DLIB_TEST(equal(inv(dm7), inv(m7))); + DLIB_TEST(abs(det(dm7) - det(m7)) < 1e-14); + DLIB_TEST(abs(min(dm7) - min(m7)) < 1e-14); + DLIB_TEST(abs(max(dm7) - max(m7)) < 1e-14); + DLIB_TEST_MSG(abs(sum(dm7) - sum(m7)) < 1e-13,sum(dm7) - sum(m7)); + DLIB_TEST(abs(prod(dm7) -prod(m7)) < 1e-14); + DLIB_TEST(equal(diag(dm7) , diag(m7))); + DLIB_TEST(equal(trans(dm7) , trans(m7))); + DLIB_TEST(equal(abs(dm7) , abs(m7))); + DLIB_TEST(equal(round(dm7) , round(m7))); + DLIB_TEST(matrix_cast(dm7) == matrix_cast(m7)); + DLIB_TEST((rotate<2,3>(dm7) == rotate<2,3>(m7))); + DLIB_TEST((sum(pointwise_multiply(dm7,dm7) - pointwise_multiply(m7,m7))) < 1e-10); + DLIB_TEST((sum(pointwise_multiply(dm7,dm7,dm7) - pointwise_multiply(m7,m7,m7))) < 1e-10); + DLIB_TEST_MSG((sum(pointwise_multiply(dm7,dm7,dm7,dm7) - pointwise_multiply(m7,m7,m7,m7))) < 1e-10, + (sum(pointwise_multiply(dm7,dm7,dm7,dm7) - pointwise_multiply(m7,m7,m7,m7))) + ); + + + matrix temp(5,5); + matrix dsm(5,5); + matrix sm; + + set_all_elements(dsm,1); + set_all_elements(sm,1); + set_all_elements(temp,1); + + dsm += dsm; + sm += sm; + DLIB_TEST(dsm == 2*temp); + DLIB_TEST(sm == 2*temp); + temp = dsm*sm + dsm; + dsm += dsm*sm; + DLIB_TEST_MSG(temp == dsm,temp - dsm); + + set_all_elements(dsm,1); + set_all_elements(sm,1); + set_all_elements(temp,1); + + dsm += dsm; + sm += sm; + DLIB_TEST(dsm == 2*temp); + DLIB_TEST(sm == 2*temp); + temp = dsm*sm + dsm; + sm += dsm*sm; + DLIB_TEST_MSG(temp == sm,temp - sm); + + set_all_elements(dsm,1); + set_all_elements(sm,1); + set_all_elements(temp,1); + + dsm += dsm; + sm += sm; + DLIB_TEST(dsm == 2*temp); + DLIB_TEST(sm == 2*temp); + temp = sm - dsm*sm ; + sm -= dsm*sm; + DLIB_TEST_MSG(temp == sm,temp - sm); + + set_all_elements(dsm,1); + set_all_elements(sm,1); + set_all_elements(temp,1); + + dsm += dsm; + sm += sm; + DLIB_TEST(dsm == 2*temp); + DLIB_TEST(sm == 2*temp); + temp = dsm - dsm*sm ; + dsm -= dsm*sm; + DLIB_TEST_MSG(temp == dsm,temp - dsm); + + set_all_elements(dsm,1); + set_all_elements(sm,1); + set_all_elements(temp,2); + + dsm *= 2; + sm *= 2; + DLIB_TEST(dsm == temp); + DLIB_TEST(sm == temp); + dsm /= 2; + sm /= 2; + DLIB_TEST(dsm == temp/2); + DLIB_TEST(sm == temp/2); + + dsm += dsm; + sm += sm; + DLIB_TEST(dsm == temp); + DLIB_TEST(sm == temp); + dsm += sm; + sm += dsm; + DLIB_TEST(dsm == 2*temp); + DLIB_TEST(sm == temp*3); + dsm -= sm; + sm -= dsm; + DLIB_TEST(dsm == -temp); + DLIB_TEST(sm == 4*temp); + sm -= sm; + dsm -= dsm; + DLIB_TEST(dsm == 0*temp); + DLIB_TEST(sm == 0*temp); + + set_all_elements(dsm,1); + set_all_elements(sm,1); + set_all_elements(temp,3); + dsm += sm+sm; + DLIB_TEST(dsm == temp); + + set_all_elements(dsm,1); + set_all_elements(sm,1); + set_all_elements(temp,-1); + dsm -= sm+sm; + DLIB_TEST(dsm == temp); + + set_all_elements(dsm,1); + set_all_elements(sm,1); + set_all_elements(temp,-1); + sm -= dsm+dsm; + DLIB_TEST(sm == temp); + + set_all_elements(dsm,1); + set_all_elements(sm,1); + set_all_elements(temp,3); + sm += dsm+dsm; + DLIB_TEST(sm == temp); + + + + // test the implicit conversion to bool stuff + { + matrix bt1(3,1); + matrix bt2; + set_all_elements(bt1,2); + set_all_elements(bt2,3); + + float val = trans(bt1)*bt2; + DLIB_TEST((float)(trans(bt1)*bt2) == 18); + DLIB_TEST((float)(trans(bt1)*bt2) != 19); + DLIB_TEST(val == 18); + } + { + matrix bt1; + matrix bt2(3,1); + set_all_elements(bt1,2); + set_all_elements(bt2,3); + + float val = trans(bt1)*bt2; + DLIB_TEST((float)(trans(bt1)*bt2) == 18); + DLIB_TEST((float)(trans(bt1)*bt2) != 19); + DLIB_TEST(val == 18); + } + { + matrix bt1(3,1); + matrix bt2(3,1); + set_all_elements(bt1,2); + set_all_elements(bt2,3); + + float val = trans(bt1)*bt2; + DLIB_TEST((float)(trans(bt1)*bt2) == 18); + DLIB_TEST((float)(trans(bt1)*bt2) != 19); + DLIB_TEST(val == 18); + } + { + matrix bt1; + matrix bt2; + set_all_elements(bt1,2); + set_all_elements(bt2,3); + + float val = trans(bt1)*bt2; + DLIB_TEST((float)(trans(bt1)*bt2) == 18); + DLIB_TEST((float)(trans(bt1)*bt2) != 19); + DLIB_TEST(val == 18); + } + + + + + { + srand(423452); + const long M = 50; + const long N = 40; + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u, u2; + matrix q, q2; + matrix v, v2; + + matrix a2; + a2 = tmp(a/2); + + + svd2(true,true,a2+a2,u,q,v); + + double err = max(abs(a - subm(u,get_rect(a2+a2))*diagm(q)*trans(v))); + DLIB_TEST_MSG( err < 1e-11,"err: " << err); + using dlib::equal; + DLIB_TEST((equal(trans(u)*u , identity_matrix(), 1e-10))); + DLIB_TEST((equal(trans(v)*v , identity_matrix(), 1e-10))); + + svd2(false,true,a2+a2,u,q,v2); + svd2(true,false,a2+a2,u2,q,v); + svd2(false,false,a2+a2,u,q2,v); + + err = max(abs(a - subm(u2,get_rect(a2+a2))*diagm(q2)*trans(v2))); + DLIB_TEST_MSG( err < 1e-11,"err: " << err); + DLIB_TEST((equal(trans(u2)*u2 , identity_matrix(), 1e-10))); + DLIB_TEST((equal(trans(v2)*v2 , identity_matrix(), 1e-10))); + + } + + + { + srand(423452); + const long M = 3; + const long N = 3; + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u, u2; + matrix q, q2; + matrix v, v2; + + matrix a2; + a2 = tmp(a/2); + + + svd2(true,true,a2+a2,u,q,v); + + double err = max(abs(a - subm(u,get_rect(a2+a2))*diagm(q)*trans(v))); + DLIB_TEST_MSG( err < 1e-11,"err: " << err); + using dlib::equal; + DLIB_TEST((equal(trans(u)*u , identity_matrix(), 1e-10))); + DLIB_TEST((equal(trans(v)*v , identity_matrix(), 1e-10))); + + svd2(false,true,a2+a2,u,q,v2); + svd2(true,false,a2+a2,u2,q,v); + svd2(false,false,a2+a2,u,q2,v); + + err = max(abs(a - subm(u2,get_rect(a2+a2))*diagm(q2)*trans(v2))); + DLIB_TEST_MSG( err < 1e-11,"err: " << err); + DLIB_TEST((equal(trans(u2)*u2 , identity_matrix(), 1e-10))); + DLIB_TEST((equal(trans(v2)*v2 , identity_matrix(), 1e-10))); + + } + + { + srand(423452); + const long M = 3; + const long N = 3; + + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u, u2; + matrix q, q2; + matrix v, v2; + + matrix a2; + a2 = tmp(a/2); + + + svd2(true,true,a2+a2,u,q,v); + + double err = max(abs(a - subm(u,get_rect(a2+a2))*diagm(q)*trans(v))); + DLIB_TEST_MSG( err < 1e-11,"err: " << err); + using dlib::equal; + DLIB_TEST((equal(trans(u)*u , identity_matrix(), 1e-10))); + DLIB_TEST((equal(trans(v)*v , identity_matrix(), 1e-10))); + + svd2(false,true,a2+a2,u,q,v2); + svd2(true,false,a2+a2,u2,q,v); + svd2(false,false,a2+a2,u,q2,v); + + err = max(abs(a - subm(u2,get_rect(a2+a2))*diagm(q2)*trans(v2))); + DLIB_TEST_MSG( err < 1e-11,"err: " << err); + DLIB_TEST((equal(trans(u2)*u2 , identity_matrix(), 1e-10))); + DLIB_TEST((equal(trans(v2)*v2 , identity_matrix(), 1e-10))); + + } + + + + { + srand(423452); + const long M = 10; + const long N = 7; + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u; + matrix q; + matrix v; + + matrix a2; + a2 = tmp(a/2); + + + svd2(true,true,a2+a2,u,q,v); + + double err = sum(round(1e10*(a - subm(u,get_rect(a2+a2))*diagm(q)*trans(v)))); + DLIB_TEST_MSG( err == 0,"err: " << err); + DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix())); + DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix())); + } + + + } + + + void matrix_test2 ( + ) + { + typedef memory_manager_stateless::kernel_2_2a MM; + { + srand(423452); + const long M = 10; + const long N = 7; + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u(M,N); + matrix w; + matrix v(N,N); + + matrix a2; + a2 = tmp(a/2); + + + svd(a2+a2,u,w,v); + + DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0); + DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix())); + DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix())); + } + + { + srand(423452); + const long M = 1; + const long N = 1; + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u; + matrix w; + matrix v; + + matrix a2; + a2 = 0; + a2 = tmp(a/2); + + + svd(a2+a2,u,w,v); + + DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0); + DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix())); + DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix())); + } + + + { + srand(53434); + const long M = 5; + const long N = 5; + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u(M,N); + matrix w; + matrix v; + + svd(a,u,w,v); + + DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0); + DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix())); + DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix())); + } + + + { + srand(11234); + const long M = 9; + const long N = 4; + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u; + matrix w; + matrix v; + + svd(a,u,w,v); + + DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0); + DLIB_TEST((round(1e10*trans(u)*u) == 1e10*identity_matrix())); + DLIB_TEST((round(1e10*trans(v)*v) == 1e10*identity_matrix())); + } + + + + { + srand(53934); + const long M = 2; + const long N = 4; + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u; + matrix w; + matrix v; + + svd(a,u,w,v); + + DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0); + } + + + { + srand(53234); + const long M = 9; + const long N = 40; + + matrix a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + matrix u; + matrix w; + matrix v; + + svd(a,u,w,v); + + DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0); + } + + { + srand(53234); + const long M = 9; + const long N = 40; + + typedef matrix mat; + mat a(M,N); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = 10*((double)::rand())/RAND_MAX; + } + } + + mat u; + mat w; + mat v; + + svd(a,u,w,v); + + DLIB_TEST( sum(round(1e10*(a - u*w*trans(v)))) == 0); + } + + + { + matrix a(3,3); + matrix b; + set_all_elements(a,0); + + a(0,0) = 1; + a(1,1) = 2; + a(2,2) = 3; + b = a; + + DLIB_TEST(diag(a)(0) == 1); + DLIB_TEST(diag(a)(1) == 2); + DLIB_TEST(diag(a)(2) == 3); + DLIB_TEST(diag(a).nr() == 3); + DLIB_TEST(diag(a).nc() == 1); + + DLIB_TEST(diag(b)(0) == 1); + DLIB_TEST(diag(b)(1) == 2); + DLIB_TEST(diag(b)(2) == 3); + DLIB_TEST(diag(b).nr() == 3); + DLIB_TEST(diag(b).nc() == 1); + + DLIB_TEST(pointwise_multiply(a,b)(0,0) == 1); + DLIB_TEST(pointwise_multiply(a,b)(1,1) == 4); + DLIB_TEST(pointwise_multiply(a,b)(2,2) == 9); + DLIB_TEST(pointwise_multiply(a,b)(1,0) == 0); + DLIB_TEST(pointwise_multiply(a,b,a)(1,0) == 0); + DLIB_TEST(pointwise_multiply(a,b,a,b)(1,0) == 0); + + + DLIB_TEST(complex_matrix(a,b)(0,0) == std::complex(1,1)); + DLIB_TEST(complex_matrix(a,b)(2,2) == std::complex(3,3)); + DLIB_TEST(complex_matrix(a,b)(2,1) == std::complex(0,0)); + } + + { + matrix > m(2,2), m2(2,2); + complex val1(1,2), val2(1.0/complex(1,2)); + m = val1; + m2 = val2; + + DLIB_TEST(equal(reciprocal(m) , m2)); + } + { + matrix > m(2,2), m2(2,2); + complex val1(1,2), val2(1.0f/complex(1,2)); + m = val1; + m2 = val2; + + DLIB_TEST(equal(reciprocal(m) , m2)); + } + + { + matrix m1, m2; + set_all_elements(m1,2.0); + set_all_elements(m2,1.0/2.0); + DLIB_TEST(reciprocal(m1) == m2); + DLIB_TEST((reciprocal(uniform_matrix(2.0)) == m2)); + DLIB_TEST((round_zeros(uniform_matrix(1e-8f)) == uniform_matrix(0)) ); + set_all_elements(m1,2.0); + m2 = m1; + m1(1,0) = static_cast(1e-8); + m2(1,0) = 0; + DLIB_TEST(round_zeros(m1) == m2); + m1 = round_zeros(m1); + DLIB_TEST(m1 == m2); + } + + { + matrix > m; + m.set_size(3,3); + set_all_elements(m,uniform_matrix(1)); + DLIB_TEST((sum(m) == uniform_matrix(9))); + DLIB_TEST((round_zeros(sqrt(sum(m)) - uniform_matrix(3)) == uniform_matrix(0))); + } + + { + matrix m1; + matrix m2; + m2.set_size(2,2); + + set_all_elements(m1,2); + m2 = uniform_matrix(2); + + m1 = m1 + m2; + DLIB_TEST((m1 == uniform_matrix(4))); + + set_all_elements(m1,2); + set_all_elements(m2,2); + m1 = m1*m1; + DLIB_TEST((m1 == uniform_matrix(8))); + + m1(1,0) = 1; + set_all_elements(m2,8); + m2(0,1) = 1; + m1 = trans(m1); + DLIB_TEST(m1 == m2); + } + + { + matrix m; + matrix m2(2,3); + + set_all_elements(m,1); + DLIB_TEST(mean(m) == 1); + set_all_elements(m,2); + DLIB_TEST(mean(m) == 2); + m(0,0) = 1; + m(0,1) = 1; + m(0,2) = 1; + DLIB_TEST(abs(mean(m) - 1.5) < 1e-10); + DLIB_TEST(abs(variance(m) - 0.3) < 1e-10); + + set_all_elements(m2,1); + DLIB_TEST(mean(m2) == 1); + set_all_elements(m2,2); + DLIB_TEST(mean(m2) == 2); + m2(0,0) = 1; + m2(0,1) = 1; + m2(0,2) = 1; + DLIB_TEST(abs(mean(m2) - 1.5) < 1e-10); + DLIB_TEST(abs(variance(m2) - 0.3) < 1e-10); + + set_all_elements(m,0); + DLIB_TEST(abs(variance(m)) < 1e-10); + set_all_elements(m,1); + DLIB_TEST(abs(variance(m)) < 1e-10); + set_all_elements(m,23.4); + DLIB_TEST(abs(variance(m)) < 1e-10); + } + + { + matrix,2,2,MM> m; + set_all_elements(m,uniform_matrix(1)); + DLIB_TEST((round_zeros(variance(m)) == uniform_matrix(0))); + DLIB_TEST((round_zeros(mean(m)) == uniform_matrix(1))); + m(0,0) = uniform_matrix(9); + DLIB_TEST((round_zeros(variance(m)) == uniform_matrix(16))); + DLIB_TEST((round_zeros(mean(m)) == uniform_matrix(3))); + + matrix > m2(2,2); + set_all_elements(m2,uniform_matrix(1)); + DLIB_TEST((round_zeros(variance(m2)) == uniform_matrix(0))); + DLIB_TEST((round_zeros(mean(m2)) == uniform_matrix(1))); + m2(0,0) = uniform_matrix(9); + DLIB_TEST((round_zeros(variance(m2)) == uniform_matrix(16))); + DLIB_TEST((round_zeros(mean(m2)) == uniform_matrix(3))); + } + + + { + matrix m(4,4), m2; + m = 1,2,3,4, + 1,2,3,4, + 4,6,8,10, + 4,6,8,10; + m2 = m; + + DLIB_TEST(colm(m,range(0,3)) == m); + DLIB_TEST(rowm(m,range(0,3)) == m); + DLIB_TEST(colm(m,range(0,0)) == colm(m,0)); + DLIB_TEST(rowm(m,range(0,0)) == rowm(m,0)); + DLIB_TEST(colm(m,range(1,1)) == colm(m,1)); + DLIB_TEST(rowm(m,range(1,1)) == rowm(m,1)); + + DLIB_TEST(colm(m,range(2,2)) == colm(m,2)); + DLIB_TEST(rowm(m,range(2,2)) == rowm(m,2)); + + DLIB_TEST(colm(m,range(1,2)) == subm(m,0,1,4,2)); + DLIB_TEST(rowm(m,range(1,2)) == subm(m,1,0,2,4)); + + set_colm(m,range(1,2)) = 9; + set_subm(m2,0,1,4,2) = 9; + DLIB_TEST(m == m2); + + set_colm(m,range(1,2)) = 11; + set_subm(m2,0,1,4,2) = 11; + DLIB_TEST(m == m2); + } + + { + print_spinner(); + matrix m1; + matrix m2; + matrix m3; + matrix m4; + + dlib::rand rnd; + for (int i = 0; i < 50; ++i) + { + m1 = randm(1,1,rnd); + m2 = randm(2,2,rnd); + m3 = randm(3,3,rnd); + m4 = randm(4,4,rnd); + + DLIB_TEST(max(abs(m1*inv(m1) - identity_matrix(m1))) < 1e-13); + DLIB_TEST(max(abs(m2*inv(m2) - identity_matrix(m2))) < 1e-12); + DLIB_TEST(max(abs(m3*inv(m3) - identity_matrix(m3))) < 1e-13); + DLIB_TEST_MSG(max(abs(m4*inv(m4) - identity_matrix(m4))) < 1e-12, max(abs(m4*inv(m4) - identity_matrix(m4)))); + } + } + + } + + + + + + + class matrix_tester : public tester + { + public: + matrix_tester ( + ) : + tester ("test_matrix2", + "Runs tests on the matrix component.") + {} + + void perform_test ( + ) + { + matrix_test1(); + matrix_test2(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/matrix3.cpp b/ml/dlib/dlib/test/matrix3.cpp new file mode 100644 index 000000000..b66af638c --- /dev/null +++ b/ml/dlib/dlib/test/matrix3.cpp @@ -0,0 +1,1134 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" + +#include "tester.h" +#include +#include + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.matrix3"); + + + const double eps_mul = 200; + + template + void check_equal ( + const T& a, + const U& b + ) + { + DLIB_TEST(a.nr() == b.nr()); + DLIB_TEST(a.nc() == b.nc()); + typedef typename T::type type; + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + type error = std::abs(a(r,c) - b(r,c)); + DLIB_TEST_MSG(error < std::sqrt(std::numeric_limits::epsilon())*eps_mul, "error: " << error << + " eps: " << std::sqrt(std::numeric_limits::epsilon())*eps_mul); + } + } + } + + template + void c_check_equal ( + const T& a, + const U& b + ) + { + DLIB_TEST(a.nr() == b.nr()); + DLIB_TEST(a.nc() == b.nc()); + typedef typename T::type type; + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + typename type::value_type error = std::abs(a(r,c) - b(r,c)); + DLIB_TEST_MSG(error < std::sqrt(std::numeric_limits::epsilon())*eps_mul, "error: " << error << + " eps: " << std::sqrt(std::numeric_limits::epsilon())*eps_mul); + } + } + } + + template + void assign_no_blas ( + const T& a_, + const U& b + ) + { + T& a = const_cast(a_); + DLIB_TEST(a.nr() == b.nr()); + DLIB_TEST(a.nc() == b.nc()); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = b(r,c); + } + } + } + + template + type rnd_num (dlib::rand& rnd) + { + return static_cast(10*rnd.get_random_double()); + } + + template + void test_blas( long rows, long cols) + { + // The tests in this function exercise the BLAS bindings located in the matrix/matrix_blas_bindings.h file. + // It does this by performing an assignment that is subject to BLAS bindings and comparing the + // results directly to an unevaluated matrix_exp that should be equal. + + dlib::rand rnd; + + matrix a(rows,cols), temp, temp2, temp3; + + for (int k = 0; k < 6; ++k) + { + for (long r= 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = rnd_num(rnd); + } + } + matrix at; + at = trans(a); + + matrix > c_a(rows,cols), c_at, c_sqr; + for (long r= 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + c_a(r,c) = complex(rnd_num(rnd),rnd_num(rnd)); + } + } + c_at = trans(c_a); + const int size = max(rows,cols); + c_sqr = 10*matrix_cast >(complex_matrix(randm(size,size,rnd), randm(size,size,rnd))); + + + matrix > c_temp(cols,cols), c_temp2(cols,cols); + const complex i(0,1); + + const type one = 1; + const type two = 1; + const type num1 = static_cast(3.6); + const type num2 = static_cast(6.6); + const type num3 = static_cast(8.6); + + matrix,0,1> c_cv4(cols), c_cv3(rows); + matrix,1,0> c_rv4(cols), c_rv3(rows); + + matrix cv4(cols); + + for (long idx = 0; idx < cv4.size(); ++idx) + cv4(idx) = rnd_num(rnd); + + for (long idx = 0; idx < c_cv4.size(); ++idx) + c_cv4(idx) = complex(rnd_num(rnd),rnd_num(rnd)); + + matrix rv3(rows); + + for (long idx = 0; idx < rv3.size(); ++idx) + rv3(idx) = rnd_num(rnd); + + for (long idx = 0; idx < c_rv3.size(); ++idx) + c_rv3(idx) = complex(rnd_num(rnd),rnd_num(rnd)); + + matrix cv3(rows); + + for (long idx = 0; idx < cv3.size(); ++idx) + cv3(idx) = rnd_num(rnd); + + for (long idx = 0; idx < c_cv3.size(); ++idx) + c_cv3(idx) = complex(rnd_num(rnd),rnd_num(rnd)); + + matrix rv4(cols); + for (long idx = 0; idx < rv4.size(); ++idx) + rv4(idx) = rnd_num(rnd); + + for (long idx = 0; idx < c_rv4.size(); ++idx) + c_rv4(idx) = complex(rnd_num(rnd),rnd_num(rnd)); + + + + // GEMM tests + dlog << LTRACE << "1.1"; + check_equal(tmp(at*a), at*a); + check_equal(tmp(trans(at*a)), trans(at*a)); + check_equal(tmp(2.4*trans(4*trans(at*a) + at*3*a)), 2.4*trans(4*trans(at*a) + at*3*a)); + dlog << LTRACE << "1.2"; + check_equal(tmp(trans(a)*a), trans(a)*a); + check_equal(tmp(trans(trans(a)*a)), trans(trans(a)*a)); + dlog << LTRACE << "1.3"; + check_equal(tmp(at*trans(at)), at*trans(at)); + check_equal(tmp(trans(at*trans(at))), trans(at*trans(at))); + dlog << LTRACE << "1.4"; + check_equal(tmp(trans(at)*trans(a)), a*at); + check_equal(tmp(trans(trans(at)*trans(a))), trans(a*at)); + dlog << LTRACE << "1.5"; + + print_spinner(); + c_check_equal(tmp(conj(trans(c_a))*c_a), trans(conj(c_a))*c_a); + dlog << LTRACE << "1.5.1"; + c_check_equal(tmp(trans(conj(trans(c_a))*c_a)), trans(trans(conj(c_a))*c_a)); + dlog << LTRACE << "1.5.2"; + c_check_equal(tmp((conj(trans(c_sqr))*trans(c_sqr))), (trans(conj(c_sqr))*trans(c_sqr))); + dlog << LTRACE << "1.5.3"; + c_check_equal(tmp(trans(conj(trans(c_sqr))*trans(c_sqr))), trans(trans(conj(c_sqr))*trans(c_sqr))); + dlog << LTRACE << "1.6"; + c_check_equal(tmp(c_at*trans(conj(c_at))), c_at*conj(trans(c_at))); + dlog << LTRACE << "1.6.1"; + c_check_equal(tmp(trans(c_at*trans(conj(c_at)))), trans(c_at*conj(trans(c_at)))); + dlog << LTRACE << "1.6.2"; + c_check_equal(tmp((c_sqr)*trans(conj(c_sqr))), (c_sqr)*conj(trans(c_sqr))); + dlog << LTRACE << "1.6.2.1"; + c_check_equal(tmp(trans(c_sqr)*trans(conj(c_sqr))), trans(c_sqr)*conj(trans(c_sqr))); + dlog << LTRACE << "1.6.3"; + c_check_equal(tmp(trans(trans(c_sqr)*trans(conj(c_sqr)))), trans(trans(c_sqr)*conj(trans(c_sqr)))); + dlog << LTRACE << "1.7"; + c_check_equal(tmp(conj(trans(c_at))*trans(conj(c_a))), conj(trans(c_at))*trans(conj(c_a))); + c_check_equal(tmp(trans(conj(trans(c_at))*trans(conj(c_a)))), trans(conj(trans(c_at))*trans(conj(c_a)))); + dlog << LTRACE << "1.8"; + + check_equal(tmp(a*trans(rowm(a,1))) , a*trans(rowm(a,1))); + check_equal(tmp(a*colm(at,1)) , a*colm(at,1)); + check_equal(tmp(subm(a,1,1,2,2)*subm(a,1,2,2,2)), subm(a,1,1,2,2)*subm(a,1,2,2,2)); + + dlog << LTRACE << "1.9"; + check_equal(tmp(trans(a*trans(rowm(a,1)))) , trans(a*trans(rowm(a,1)))); + dlog << LTRACE << "1.10"; + check_equal(tmp(trans(a*colm(at,1))) , trans(a*colm(at,1))); + dlog << LTRACE << "1.11"; + check_equal(tmp(trans(subm(a,1,1,2,2)*subm(a,1,2,2,2))), trans(subm(a,1,1,2,2)*subm(a,1,2,2,2))); + dlog << LTRACE << "1.12"; + + { + temp = at*a; + temp2 = temp; + + temp += 3.5*at*a; + assign_no_blas(temp2, temp2 + 3.5*at*a); + check_equal(temp, temp2); + + temp -= at*3.5*a; + assign_no_blas(temp2, temp2 - at*3.5*a); + check_equal(temp, temp2); + + temp = temp + 4*at*a; + assign_no_blas(temp2, temp2 + 4*at*a); + check_equal(temp, temp2); + + temp = temp - 2.4*at*a; + assign_no_blas(temp2, temp2 - 2.4*at*a); + check_equal(temp, temp2); + } + dlog << LTRACE << "1.13"; + { + temp = trans(at*a); + temp2 = temp; + temp3 = temp; + + dlog << LTRACE << "1.14"; + temp += trans(3.5*at*a); + assign_no_blas(temp2, temp2 + trans(3.5*at*a)); + check_equal(temp, temp2); + + dlog << LTRACE << "1.15"; + temp -= trans(at*3.5*a); + assign_no_blas(temp2, temp2 - trans(at*3.5*a)); + check_equal(temp, temp2); + + dlog << LTRACE << "1.16"; + temp = trans(temp + 4*at*a); + assign_no_blas(temp3, trans(temp2 + 4*at*a)); + check_equal(temp, temp3); + + temp2 = temp; + dlog << LTRACE << "1.17"; + temp = trans(temp - 2.4*at*a); + assign_no_blas(temp3, trans(temp2 - 2.4*at*a)); + check_equal(temp, temp3); + } + + dlog << LTRACE << "1.17.1"; + { + matrix m1, m2; + + m1 = matrix_cast(randm(rows, cols, rnd)); + m2 = matrix_cast(randm(cols, rows + 8, rnd)); + check_equal(tmp(m1*m2), m1*m2); + check_equal(tmp(trans(m1*m2)), trans(m1*m2)); + + m1 = trans(m1); + check_equal(tmp(trans(m1)*m2), trans(m1)*m2); + check_equal(tmp(trans(trans(m1)*m2)), trans(trans(m1)*m2)); + + m2 = trans(m2); + check_equal(tmp(trans(m1)*trans(m2)), trans(m1)*trans(m2)); + check_equal(tmp(trans(trans(m1)*trans(m2))), trans(trans(m1)*trans(m2))); + + m1 = trans(m1); + check_equal(tmp(m1*trans(m2)), m1*trans(m2)); + check_equal(tmp(trans(m1*trans(m2))), trans(m1*trans(m2))); + } + + dlog << LTRACE << "1.17.5"; + { + matrix r; + matrix c; + + r = matrix_cast(randm(1, rows+9, rnd)); + c = matrix_cast(randm(rows, 1, rnd)); + + check_equal(tmp(c*r), c*r); + check_equal(tmp(trans(c*r)), trans(c*r)); + + check_equal(tmp(trans(r)*trans(c)), trans(r)*trans(c)); + check_equal(tmp(trans(trans(r)*trans(c))), trans(trans(r)*trans(c))); + } + + dlog << LTRACE << "1.18"; + + // GEMV tests + check_equal(tmp(a*cv4), a*cv4); + check_equal(tmp(trans(a*cv4)), trans(a*cv4)); + check_equal(tmp(rv3*a), rv3*a); + check_equal(tmp(trans(cv4)*at), trans(cv4)*at); + check_equal(tmp(a*trans(rv4)), a*trans(rv4)); + check_equal(tmp(trans(a*trans(rv4))), trans(a*trans(rv4))); + + check_equal(tmp(trans(a)*cv3), trans(a)*cv3); + check_equal(tmp(rv4*trans(a)), rv4*trans(a)); + check_equal(tmp(trans(cv3)*trans(at)), trans(cv3)*trans(at)); + check_equal(tmp(trans(cv3)*a), trans(cv3)*a); + check_equal(tmp(trans(a)*trans(rv3)), trans(a)*trans(rv3)); + + + c_check_equal(tmp(trans(conj(c_a))*c_cv3), trans(conj(c_a))*c_cv3); + c_check_equal(tmp(c_rv4*trans(conj(c_a))), c_rv4*trans(conj(c_a))); + c_check_equal(tmp(trans(c_cv3)*trans(conj(c_at))), trans(c_cv3)*trans(conj(c_at))); + c_check_equal(tmp(conj(trans(c_a))*trans(c_rv3)), trans(conj(c_a))*trans(c_rv3)); + c_check_equal(tmp(c_rv4*conj(c_at)), c_rv4*conj(c_at)); + c_check_equal(tmp(trans(c_cv4)*conj(c_at)), trans(c_cv4)*conj(c_at)); + + dlog << LTRACE << "2.00"; + + c_check_equal(tmp(trans(trans(conj(c_a))*c_cv3)), trans(trans(conj(c_a))*c_cv3)); + c_check_equal(tmp(trans(c_rv4*trans(conj(c_a)))), trans(c_rv4*trans(conj(c_a)))); + c_check_equal(tmp(trans(trans(c_cv3)*trans(conj(c_at)))), trans(trans(c_cv3)*trans(conj(c_at)))); + dlog << LTRACE << "2.20"; + c_check_equal(tmp(trans(conj(trans(c_a))*trans(c_rv3))), trans(trans(conj(c_a))*trans(c_rv3))); + c_check_equal(tmp(trans(c_rv4*conj(c_at))), trans(c_rv4*conj(c_at))); + c_check_equal(tmp(trans(trans(c_cv4)*conj(c_at))), trans(trans(c_cv4)*conj(c_at))); + + + + dlog << LTRACE << "6"; + temp = a*at; + check_equal(temp, a*at); + temp = temp + a*at + trans(at)*at + trans(at)*sin(at); + check_equal(temp, a*at + a*at+ trans(at)*at + trans(at)*sin(at)); + + dlog << LTRACE << "6.1"; + temp = a*at; + check_equal(temp, a*at); + temp = a*at + temp; + check_equal(temp, a*at + a*at); + + print_spinner(); + dlog << LTRACE << "6.2"; + temp = a*at; + check_equal(temp, a*at); + dlog << LTRACE << "6.2.3"; + temp = temp - a*at; + dlog << LTRACE << "6.2.4"; + check_equal(temp, a*at-a*at); + + dlog << LTRACE << "6.3"; + temp = a*at; + dlog << LTRACE << "6.3.5"; + check_equal(temp, a*at); + dlog << LTRACE << "6.3.6"; + temp = a*at - temp; + dlog << LTRACE << "6.4"; + check_equal(temp, a*at-a*at); + + + + const long d = min(rows,cols); + rectangle rect(1,1,d,d); + temp.set_size(max(rows,cols)+4,max(rows,cols)+4); + set_all_elements(temp,4); + temp2 = temp; + + dlog << LTRACE << "7"; + set_subm(temp,rect) = a*at; + assign_no_blas( set_subm(temp2,rect) , a*at); + check_equal(temp, temp2); + + temp = a; + temp2 = a; + + set_colm(temp,1) = a*cv4; + assign_no_blas( set_colm(temp2,1) , a*cv4); + check_equal(temp, temp2); + + set_rowm(temp,1) = rv3*a; + assign_no_blas( set_rowm(temp2,1) , rv3*a); + check_equal(temp, temp2); + + + // Test BLAS GER + { + temp.set_size(cols,cols); + set_all_elements(temp,3); + temp2 = temp; + + + dlog << LTRACE << "8"; + temp += cv4*rv4; + assign_no_blas(temp2, temp2 + cv4*rv4); + check_equal(temp, temp2); + + dlog << LTRACE << "8.3"; + temp = temp + cv4*rv4; + assign_no_blas(temp2, temp2 + cv4*rv4); + check_equal(temp, temp2); + dlog << LTRACE << "8.9"; + } + { + temp.set_size(cols,cols); + set_all_elements(temp,3); + temp2 = temp; + temp3 = 0; + + dlog << LTRACE << "8.10"; + + temp += trans(cv4*rv4); + assign_no_blas(temp3, temp2 + trans(cv4*rv4)); + check_equal(temp, temp3); + temp3 = 0; + + dlog << LTRACE << "8.11"; + temp2 = temp; + temp = trans(temp + cv4*rv4); + assign_no_blas(temp3, trans(temp2 + cv4*rv4)); + check_equal(temp, temp3); + dlog << LTRACE << "8.12"; + } + { + matrix > temp, temp2, temp3; + matrix,0,1 > cv4; + matrix,1,0 > rv4; + cv4.set_size(cols); + rv4.set_size(cols); + temp.set_size(cols,cols); + set_all_elements(temp,complex(3,5)); + temp(cols-1, cols-4) = 9; + temp2 = temp; + temp3.set_size(cols,cols); + temp3 = 0; + + for (long i = 0; i < rv4.size(); ++i) + { + rv4(i) = complex(rnd_num(rnd),rnd_num(rnd)); + cv4(i) = complex(rnd_num(rnd),rnd_num(rnd)); + } + + dlog << LTRACE << "8.13"; + + temp += trans(cv4*rv4); + assign_no_blas(temp3, temp2 + trans(cv4*rv4)); + c_check_equal(temp, temp3); + temp3 = 0; + + dlog << LTRACE << "8.14"; + temp2 = temp; + temp = trans(temp + cv4*rv4); + assign_no_blas(temp3, trans(temp2 + cv4*rv4)); + c_check_equal(temp, temp3); + dlog << LTRACE << "8.15"; + } + + + + + set_all_elements(c_temp, one + num1*i); + c_temp2 = c_temp; + set_all_elements(c_rv4, one + num2*i); + set_all_elements(c_cv4, two + num3*i); + + + dlog << LTRACE << "9"; + c_temp += c_cv4*c_rv4; + assign_no_blas(c_temp2, c_temp2 + c_cv4*c_rv4); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.1"; + c_temp += c_cv4*conj(c_rv4); + assign_no_blas(c_temp2, c_temp2 + c_cv4*conj(c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.2"; + c_temp = c_cv4*conj(c_rv4) + c_temp; + assign_no_blas(c_temp2, c_temp2 + c_cv4*conj(c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.3"; + c_temp = trans(c_rv4)*trans(conj(c_cv4)) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(c_rv4)*trans(conj(c_cv4))); + c_check_equal(c_temp, c_temp2); + + + dlog << LTRACE << "9.4"; + c_temp += conj(c_cv4)*c_rv4; + assign_no_blas(c_temp2, c_temp2 + conj(c_cv4)*c_rv4); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.5"; + c_temp += conj(c_cv4)*conj(c_rv4); + assign_no_blas(c_temp2, c_temp2 + conj(c_cv4)*conj(c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.6"; + c_temp = conj(c_cv4)*conj(c_rv4) + c_temp; + assign_no_blas(c_temp2, c_temp2 + conj(c_cv4)*conj(c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.7"; + c_temp = conj(trans(c_rv4))*trans(conj(c_cv4)) + c_temp; + assign_no_blas(c_temp2, c_temp2 + conj(trans(c_rv4))*trans(conj(c_cv4))); + c_check_equal(c_temp, c_temp2); + + + dlog << LTRACE << "10"; + c_temp += trans(c_cv4*c_rv4); + assign_no_blas(c_temp2, c_temp2 + trans(c_cv4*c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.1"; + c_temp += trans(c_cv4*conj(c_rv4)); + assign_no_blas(c_temp2, c_temp2 + trans(c_cv4*conj(c_rv4))); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.2"; + c_temp = trans(c_cv4*conj(c_rv4)) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(c_cv4*conj(c_rv4))); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.3"; + c_temp = trans(trans(c_rv4)*trans(conj(c_cv4))) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(trans(c_rv4)*trans(conj(c_cv4)))); + c_check_equal(c_temp, c_temp2); + + + dlog << LTRACE << "10.4"; + c_temp += trans(conj(c_cv4)*c_rv4); + assign_no_blas(c_temp2, c_temp2 + trans(conj(c_cv4)*c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.5"; + c_temp += trans(conj(c_cv4)*conj(c_rv4)); + assign_no_blas(c_temp2, c_temp2 + trans(conj(c_cv4)*conj(c_rv4))); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.6"; + c_temp = trans(conj(c_cv4)*conj(c_rv4)) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(conj(c_cv4)*conj(c_rv4))); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.7"; + c_temp = trans(conj(trans(c_rv4))*trans(conj(c_cv4))) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(conj(trans(c_rv4))*trans(conj(c_cv4)))); + c_check_equal(c_temp, c_temp2); + + dlog << LTRACE << "10.8"; + + + print_spinner(); + + // Test DOT + check_equal( tmp(rv4*cv4), rv4*cv4); + check_equal( tmp(trans(rv4*cv4)), trans(rv4*cv4)); + check_equal( tmp(trans(cv4)*trans(rv4)), trans(cv4)*trans(rv4)); + check_equal( tmp(rv4*3.9*cv4), rv4*3.9*cv4); + check_equal( tmp(trans(cv4)*3.9*trans(rv4)), trans(cv4)*3.9*trans(rv4)); + check_equal( tmp(rv4*cv4*3.9), rv4*3.9*cv4); + check_equal( tmp(trans(cv4)*trans(rv4)*3.9), trans(cv4)*3.9*trans(rv4)); + + + check_equal( tmp(trans(rv4*cv4)), trans(rv4*cv4)); + check_equal( tmp(trans(trans(rv4*cv4))), trans(trans(rv4*cv4))); + check_equal( tmp(trans(trans(cv4)*trans(rv4))), trans(trans(cv4)*trans(rv4))); + check_equal( tmp(trans(rv4*3.9*cv4)), trans(rv4*3.9*cv4)); + check_equal( tmp(trans(trans(cv4)*3.9*trans(rv4))), trans(trans(cv4)*3.9*trans(rv4))); + check_equal( tmp(trans(rv4*cv4*3.9)), trans(rv4*3.9*cv4)); + check_equal( tmp(trans(trans(cv4)*trans(rv4)*3.9)), trans(trans(cv4)*3.9*trans(rv4))); + + + temp.set_size(1,1); + temp = 4; + check_equal( tmp(temp + rv4*cv4), temp + rv4*cv4); + check_equal( tmp(temp + trans(cv4)*trans(rv4)), temp + trans(cv4)*trans(rv4)); + + dlog << LTRACE << "11"; + + + + c_check_equal( tmp(conj(c_rv4)*c_cv4), conj(c_rv4)*c_cv4); + c_check_equal( tmp(conj(trans(c_cv4))*trans(c_rv4)), trans(conj(c_cv4))*trans(c_rv4)); + + c_check_equal( tmp(conj(c_rv4)*i*c_cv4), conj(c_rv4)*i*c_cv4); + c_check_equal( tmp(conj(trans(c_cv4))*i*trans(c_rv4)), trans(conj(c_cv4))*i*trans(c_rv4)); + + c_temp.set_size(1,1); + c_temp = 4; + c_check_equal( tmp(c_temp + conj(c_rv4)*c_cv4), c_temp + conj(c_rv4)*c_cv4); + c_check_equal( tmp(c_temp + trans(conj(c_cv4))*trans(c_rv4)), c_temp + trans(conj(c_cv4))*trans(c_rv4)); + + complex tmp = c_rv4*c_cv4; + DLIB_TEST(abs((tmp + i) - ((c_rv4*c_cv4)(0) + i)) < std::sqrt(std::numeric_limits::epsilon())*eps_mul ); + DLIB_TEST(max(abs((rv4*cv4 + 1.0) - ((rv4*cv4)(0) + 1.0))) < std::sqrt(std::numeric_limits::epsilon())*eps_mul); + + } + + { + matrix m(2,3), m2(6,1); + + m = 1,2,3, + 4,5,6; + + m2 = 1,2,3,4,5,6; + + DLIB_TEST(reshape_to_column_vector(m) == m2); + DLIB_TEST(reshape_to_column_vector(m+m) == m2+m2); + + } + { + matrix m(2,3); + matrix m2(6,1); + + m = 1,2,3, + 4,5,6; + + m2 = 1,2,3,4,5,6; + + DLIB_TEST(reshape_to_column_vector(m) == m2); + DLIB_TEST(reshape_to_column_vector(m+m) == m2+m2); + + } + } + + + void matrix_test ( + ) + /*! + ensures + - runs tests on the matrix stuff compliance with the specs + !*/ + { + print_spinner(); + + + { + matrix m1(2,2), m2(2,2); + + m1 = 1, 2, + 3, 4; + + m2 = 4, 5, + 6, 7; + + + DLIB_TEST(subm(tensor_product(m1,m2),range(0,1), range(0,1)) == 1*m2); + DLIB_TEST(subm(tensor_product(m1,m2),range(0,1), range(2,3)) == 2*m2); + DLIB_TEST(subm(tensor_product(m1,m2),range(2,3), range(0,1)) == 3*m2); + DLIB_TEST(subm(tensor_product(m1,m2),range(2,3), range(2,3)) == 4*m2); + } + + { + print_spinner(); + dlog << LTRACE << "testing blas stuff"; + dlog << LTRACE << " \nsmall double"; + test_blas(3,4); + print_spinner(); + dlog << LTRACE << " \nsmall float"; + test_blas(3,4); + print_spinner(); + dlog << LTRACE << " \nbig double"; + test_blas(120,131); + print_spinner(); + dlog << LTRACE << " \nbig float"; + test_blas(120,131); + print_spinner(); + dlog << LTRACE << "testing done"; + } + + + { + matrix m(3,4), ml(3,4), mu(3,4); + m = 1,2,3,4, + 4,5,6,7, + 7,8,9,0; + + ml = 1,0,0,0, + 4,5,0,0, + 7,8,9,0; + + mu = 1,2,3,4, + 0,5,6,7, + 0,0,9,0; + + + DLIB_TEST(lowerm(m) == ml); + DLIB_TEST(upperm(m) == mu); + + ml = 3,0,0,0, + 4,3,0,0, + 7,8,3,0; + + mu = 4,2,3,4, + 0,4,6,7, + 0,0,4,0; + + DLIB_TEST(lowerm(m,3) == ml); + DLIB_TEST(upperm(m,4) == mu); + + } + + { + matrix m(3,4), row(1,3), col(2,1); + m = 1,2,3,4, + 4,5,6,7, + 7,8,9,0; + + row = 4,5,6; + col = 3,6; + + DLIB_TEST(rowm(m, 1, 3) == row); + DLIB_TEST(colm(m, 2, 2) == col); + + } + + + { + std::vector v(34, 8); + std::vector v2(34, 9); + + DLIB_TEST(mat(&v[0], v.size()) == mat(v)); + DLIB_TEST(mat(&v2[0], v.size()) != mat(v)); + } + + { + std::vector v(1, 3); + std::vector v2(1, 2); + + DLIB_TEST(mat(&v[0], v.size()) == mat(v)); + DLIB_TEST(mat(&v2[0], v.size()) != mat(v)); + } + + { + matrix a(3,3), b(3,3); + a = 1, 2.5, 1, + 3, 4, 5, + 0.5, 2.2, 3; + + b = 0, 1, 0, + 1, 1, 1, + 0, 1, 1; + + DLIB_TEST((a>1) == b); + DLIB_TEST((1=1) == b); + DLIB_TEST((1<=a) == b); + + b = 0, 0, 0, + 0, 0, 0, + 0, 1, 0; + DLIB_TEST((a==2.2) == b); + DLIB_TEST((a!=2.2) == (b==0)); + DLIB_TEST((2.2==a) == b); + DLIB_TEST((2.2!=a) == (0==b)); + + b = 0, 0, 0, + 0, 0, 0, + 1, 0, 0; + DLIB_TEST((a<1) == b); + DLIB_TEST((1>a) == b); + + b = 1, 0, 1, + 0, 0, 0, + 1, 0, 0; + DLIB_TEST((a<=1) == b); + DLIB_TEST((1>=a) == b); + } + + { + matrix a, b, c; + a = randm(4,2); + + b += a; + c -= a; + + DLIB_TEST(equal(a, b)); + DLIB_TEST(equal(-a, c)); + + b += a; + c -= a; + + DLIB_TEST(equal(2*a, b)); + DLIB_TEST(equal(-2*a, c)); + + b += a + a; + c -= a + a; + + DLIB_TEST(equal(4*a, b)); + DLIB_TEST(equal(-4*a, c)); + + b.set_size(0,0); + c.set_size(0,0); + + + b += a + a; + c -= a + a; + + DLIB_TEST(equal(2*a, b)); + DLIB_TEST(equal(-2*a, c)); + } + + { + matrix a, b, c; + + a.set_size(2, 3); + b.set_size(2, 6); + c.set_size(4, 3); + + a = 1, 2, 3, + 4, 5, 6; + + b = 1, 2, 3, 1, 2, 3, + 4, 5, 6, 4, 5, 6; + + c = 1, 2, 3, + 4, 5, 6, + 1, 2, 3, + 4, 5, 6; + + DLIB_TEST(join_rows(a,a) == b); + DLIB_TEST(join_rows(a,abs(a)) == b); + DLIB_TEST(join_cols(trans(a), trans(a)) == trans(b)); + DLIB_TEST(join_cols(a,a) == c); + DLIB_TEST(join_cols(a,abs(a)) == c); + DLIB_TEST(join_rows(trans(a),trans(a)) == trans(c)); + } + + { + matrix a; + matrix b; + matrix c; + + a = 1, 2, 3, + 4, 5, 6; + + b = 1, 2, 3, 1, 2, 3, + 4, 5, 6, 4, 5, 6; + + c = 1, 2, 3, + 4, 5, 6, + 1, 2, 3, + 4, 5, 6; + + DLIB_TEST(join_rows(a,a) == b); + DLIB_TEST(join_rows(a,abs(a)) == b); + DLIB_TEST(join_cols(trans(a), trans(a)) == trans(b)); + DLIB_TEST(join_cols(a,a) == c); + DLIB_TEST(join_cols(a,abs(a)) == c); + DLIB_TEST(join_rows(trans(a),trans(a)) == trans(c)); + } + + { + matrix a; + matrix a2; + matrix b; + matrix c; + + a = 1, 2, 3, + 4, 5, 6; + + a2 = a; + + b = 1, 2, 3, 1, 2, 3, + 4, 5, 6, 4, 5, 6; + + c = 1, 2, 3, + 4, 5, 6, + 1, 2, 3, + 4, 5, 6; + + DLIB_TEST(join_rows(a,a2) == b); + DLIB_TEST(join_rows(a2,a) == b); + DLIB_TEST(join_cols(trans(a2), trans(a)) == trans(b)); + DLIB_TEST(join_cols(a2,a) == c); + DLIB_TEST(join_cols(a,a2) == c); + DLIB_TEST(join_rows(trans(a2),trans(a)) == trans(c)); + } + + { + matrix a, b; + + a.set_size(2,3); + + a = 1, 2, 3, + 4, 5, 6; + + b.set_size(3,2); + b = 1, 2, + 3, 4, + 5, 6; + + DLIB_TEST(reshape(a, 3, 2) == b); + + b.set_size(2,3); + b = 1, 4, 2, + 5, 3, 6; + + DLIB_TEST(reshape(trans(a), 2, 3) == b); + + } + + { + matrix a; + matrix b; + + a = 1, 2, 3, + 4, 5, 6; + + b.set_size(3,2); + b = 1, 2, + 3, 4, + 5, 6; + + DLIB_TEST(reshape(a, 3, 2) == b); + + b.set_size(2,3); + b = 1, 4, 2, + 5, 3, 6; + + DLIB_TEST(reshape(trans(a), 2, 3) == b); + + } + + { + std::vector v(6); + for (unsigned long i = 0; i < v.size(); ++i) + v[i] = i; + + matrix a; + a = 0, 1, 2, + 3, 4, 5; + + DLIB_TEST(mat(&v[0], 2, 3) == a); + } + + { + matrix a(3,4); + matrix b(3,1), c(1,4); + + a = 1, 2, 3, 6, + 4, 5, 6, 9, + 1, 1, 1, 3; + + b(0) = sum(rowm(a,0)); + b(1) = sum(rowm(a,1)); + b(2) = sum(rowm(a,2)); + + c(0) = sum(colm(a,0)); + c(1) = sum(colm(a,1)); + c(2) = sum(colm(a,2)); + c(3) = sum(colm(a,3)); + + DLIB_TEST(sum_cols(a) == b); + DLIB_TEST(sum_rows(a) == c); + + } + + { + matrix m(3,3); + + m = 1, 2, 3, + 4, 5, 6, + 7, 8, 9; + + DLIB_TEST(make_symmetric(m) == trans(make_symmetric(m))); + DLIB_TEST(lowerm(make_symmetric(m)) == lowerm(m)); + DLIB_TEST(upperm(make_symmetric(m)) == trans(lowerm(m))); + } + + { + matrix a; + matrix b(3,1), c(1,4); + + a = 1, 2, 3, 6, + 4, 5, 6, 9, + 1, 1, 1, 3; + + b(0) = sum(rowm(a,0)); + b(1) = sum(rowm(a,1)); + b(2) = sum(rowm(a,2)); + + c(0) = sum(colm(a,0)); + c(1) = sum(colm(a,1)); + c(2) = sum(colm(a,2)); + c(3) = sum(colm(a,3)); + + DLIB_TEST(sum_cols(a) == b); + DLIB_TEST(sum_rows(a) == c); + + } + + { + matrix m(3,4), s(3,4); + m = -2, 1, 5, -5, + 5, 5, 5, 5, + 9, 0, -4, -2; + + s = -1, 1, 1, -1, + 1, 1, 1, 1, + 1, 1, -1, -1; + + DLIB_TEST(sign(m) == s); + DLIB_TEST(sign(matrix_cast(m)) == matrix_cast(s)); + } + + } + + + void test_matrix_IO() + { + dlib::rand rnd; + print_spinner(); + + for (int i = 0; i < 400; ++i) + { + ostringstream sout; + sout.precision(20); + + matrix m1, m2, m3; + + const long r = rnd.get_random_32bit_number()%7+1; + const long c = rnd.get_random_32bit_number()%7+1; + const long num = rnd.get_random_32bit_number()%2+1; + + m1 = randm(r,c,rnd); + sout << m1; + if (num != 1) + sout << "\n" << m1; + + if (rnd.get_random_double() < 0.3) + sout << " \n"; + else if (rnd.get_random_double() < 0.3) + sout << " \n\n 3 3 3 3"; + else if (rnd.get_random_double() < 0.3) + sout << " \n \n v 3 3 3 3 3"; + + istringstream sin(sout.str()); + sin >> m2; + DLIB_TEST_MSG(equal(m1,m2), m1 << "\n***********\n" << m2); + + if (num != 1) + { + sin >> m3; + DLIB_TEST_MSG(equal(m1,m3), m1 << "\n***********\n" << m3); + } + } + + + { + istringstream sin(" 1 2\n3"); + matrix m; + DLIB_TEST(sin.good()); + sin >> m; + DLIB_TEST(!sin.good()); + } + { + istringstream sin(""); + matrix m; + DLIB_TEST(sin.good()); + sin >> m; + DLIB_TEST(!sin.good()); + } + } + + + void test_axpy() + { + const int n = 4; + matrix B = dlib::randm(n,n); + + matrix g = dlib::uniform_matrix(n,1,0.0); + + const double tau = 1; + + matrix p = g + tau*dlib::colm(B,0); + matrix q = dlib::colm(B,0); + DLIB_TEST(max(abs(p-q)) < 1e-14); + + p = tau*dlib::colm(B,0); + q = dlib::colm(B,0); + DLIB_TEST(max(abs(p-q)) < 1e-14); + + + + + g = dlib::uniform_matrix(n,n,0.0); + p = g + tau*B; + DLIB_TEST(max(abs(p-B)) < 1e-14); + + p = g + tau*subm(B,get_rect(B)); + DLIB_TEST(max(abs(p-B)) < 1e-14); + + g = dlib::uniform_matrix(2,2,0.0); + p = g + tau*subm(B,1,1,2,2); + DLIB_TEST(max(abs(p-subm(B,1,1,2,2))) < 1e-14); + + set_subm(p,0,0,2,2) = g + tau*subm(B,1,1,2,2); + DLIB_TEST(max(abs(p-subm(B,1,1,2,2))) < 1e-14); + } + + + class matrix_tester : public tester + { + public: + matrix_tester ( + ) : + tester ("test_matrix3", + "Runs tests on the matrix component.") + {} + + void perform_test ( + ) + { + test_axpy(); + test_matrix_IO(); + matrix_test(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/matrix4.cpp b/ml/dlib/dlib/test/matrix4.cpp new file mode 100644 index 000000000..d2b83b712 --- /dev/null +++ b/ml/dlib/dlib/test/matrix4.cpp @@ -0,0 +1,1119 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" + +#include "tester.h" +#include +#include + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.matrix4"); + + void matrix_test ( + ) + /*! + ensures + - runs tests on the matrix stuff compliance with the specs + !*/ + { + print_spinner(); + + { + matrix m = round(10*randm(3,3)); + matrix v = round(10*randm(3,1)); + + DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) )); + DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) )); + + DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m )); + DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m )); + } + + { + matrix m = round(10*randm(3,3)); + matrix v = round(10*randm(1,3)); + + DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) )); + DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) )); + + DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m )); + DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m )); + } + + { + matrix m = round(10*randm(3,3)); + matrix v = round(10*randm(1,3)); + + DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) )); + DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) )); + + DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m )); + DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m )); + } + + { + matrix m = round(10*randm(3,3)); + matrix v = round(10*randm(1,3)); + + DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) )); + DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) )); + + DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m )); + DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m )); + } + + + { + matrix m = round(10*randm(3,3)); + matrix v = round(10*randm(1,3)); + + DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) )); + DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) )); + + DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m )); + DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m )); + } + + { + matrix m = round(10*randm(3,3)); + matrix v = round(10*randm(3,1)); + + DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) )); + DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) )); + + DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m )); + DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m )); + } + + + { + matrix m = round(10*randm(3,3)); + matrix v = round(10*randm(3,1)); + + DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) )); + DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) )); + + DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m )); + DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m )); + } + + { + matrix m = round(10*randm(3,3)); + matrix v = round(10*randm(3,1)); + + DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) )); + DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) )); + + DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m )); + DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m )); + } + + + { + matrix m = round(10*randm(3,3)); + matrix v = round(10*randm(3,1)); + + DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) )); + DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) )); + + DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m )); + DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m )); + } + + { + matrix m = round(10*randm(3,5)); + matrix v1 = round(10*randm(5,1)); + matrix v2 = round(10*randm(3,1)); + + DLIB_TEST(equal( m*diagm(v1) , m*tmp(diagm(v1)) )); + DLIB_TEST(equal( scale_columns(m,v1) , m*tmp(diagm(v1)) )); + + DLIB_TEST(equal( diagm(v2)*m , tmp(diagm(v2))*m )); + DLIB_TEST(equal( scale_rows(m,v2) , tmp(diagm(v2))*m )); + } + + { + matrix m = round(10*randm(3,5)); + matrix v1 = round(10*randm(5,1)); + matrix v2 = round(10*randm(3,1)); + + DLIB_TEST(equal( m*diagm(v1) , m*tmp(diagm(v1)) )); + DLIB_TEST(equal( scale_columns(m,v1) , m*tmp(diagm(v1)) )); + + DLIB_TEST(equal( diagm(v2)*m , tmp(diagm(v2))*m )); + DLIB_TEST(equal( scale_rows(m,v2) , tmp(diagm(v2))*m )); + } + + } + + void test_stuff() + { + print_spinner(); + + { + matrix m(3,3), lr(3,3), ud(3,3); + + m = 1,2,3, + 4,5,6, + 7,8,9; + + + lr = 3,2,1, + 6,5,4, + 9,8,7; + + ud = 7,8,9, + 4,5,6, + 1,2,3; + + DLIB_TEST(lr == fliplr(m)); + DLIB_TEST(ud == flipud(m)); + } + { + matrix m(3,2), lr(3,2), ud(3,2); + + m = 1,2, + 3,4, + 5,6; + + lr = 2,1, + 4,3, + 6,5; + + ud = 5,6, + 3,4, + 1,2; + + DLIB_TEST(lr == fliplr(m)); + DLIB_TEST(ud == flipud(m)); + } + + { + matrix a, b; + + a = matrix_cast(round(10*randm(3,3))); + b = a; + + b *= b; + DLIB_TEST(b == a*a); + } + + { + matrix m(2,3), m2(2,3); + + m = 1,2,3, + 4,5,6; + + + m2 = 3,4,5, + 6,7,8; + + DLIB_TEST(m + 2 == m2); + DLIB_TEST(2 + m == m2); + + m += 2; + DLIB_TEST(m == m2); + m -= 2; + + m2 = 0,1,2, + 3,4,5; + + DLIB_TEST(m - 1 == m2); + + m -= 1; + DLIB_TEST(m == m2); + m += 1; + + + m2 = 5,4,3, + 2,1,0; + + DLIB_TEST(6 - m == m2); + } + + { + matrix m(2,3), m2(2,3); + + m = 1,2,3, + 4,5,6; + + + m2 = 3,4,5, + 6,7,8; + + DLIB_TEST(m + 2 == m2); + DLIB_TEST(2 + m == m2); + + m += 2; + DLIB_TEST(m == m2); + m -= 2; + + m2 = 0,1,2, + 3,4,5; + + DLIB_TEST(m - 1 == m2); + + m -= 1; + DLIB_TEST(m == m2); + m += 1; + + + m2 = 5,4,3, + 2,1,0; + + DLIB_TEST(6 - m == m2); + } + + { + matrix m(2,3), m2(2,3); + + m = 1,2,3, + 4,5,6; + + + m2 = 3,4,5, + 6,7,8; + + DLIB_TEST(m + 2 == m2); + DLIB_TEST(2 + m == m2); + + m += 2; + DLIB_TEST(m == m2); + m -= 2; + + m2 = 0,1,2, + 3,4,5; + + DLIB_TEST(m - 1 == m2); + + m -= 1; + DLIB_TEST(m == m2); + m += 1; + + + m2 = 5,4,3, + 2,1,0; + + DLIB_TEST(6 - m == m2); + } + + { + matrix m, m2; + + m = 1,2,3, + 4,5,6; + + + m2 = 3,4,5, + 6,7,8; + + DLIB_TEST(m + 2 == m2); + DLIB_TEST(2 + m == m2); + + m += 2; + DLIB_TEST(m == m2); + m -= 2; + + m2 = 0,1,2, + 3,4,5; + + DLIB_TEST(m - 1 == m2); + + m -= 1; + DLIB_TEST(m == m2); + m += 1; + + + m2 = 5,4,3, + 2,1,0; + + DLIB_TEST(6 - m == m2); + } + + { + matrix m(2,3), m2(3,2); + + m = 1,2,3, + 4,5,6; + + m2 = 2,5, + 3,6, + 4,7; + + DLIB_TEST(trans(m+1) == m2); + DLIB_TEST(trans(m)+1 == m2); + DLIB_TEST(1+trans(m) == m2); + DLIB_TEST(1+m-1 == m); + + m = trans(m+1); + DLIB_TEST(m == m2); + m = trans(m-1); + DLIB_TEST(trans(m+1) == m2); + m = trans(m)+1; + DLIB_TEST(m == m2); + } + + { + matrix d(3,1), di(3,1); + matrix m(3,3); + + m = 1,2,3, + 4,5,6, + 7,8,9; + + d = 1,2,3; + + di = 1, 1/2.0, 1/3.0; + + DLIB_TEST(inv(diagm(d)) == diagm(di)); + DLIB_TEST(pinv(diagm(d)) == diagm(di)); + DLIB_TEST(inv(diagm(d))*m == tmp(diagm(di))*m); + DLIB_TEST(m*inv(diagm(d)) == m*tmp(diagm(di))); + + DLIB_TEST(equal(inv(diagm(d)) + m , tmp(diagm(di)) + m)); + DLIB_TEST(equal(m + inv(diagm(d)) , tmp(diagm(di)) + m)); + + DLIB_TEST((m + identity_matrix(3) == m + tmp(identity_matrix(3)))); + DLIB_TEST((m + identity_matrix() == m + tmp(identity_matrix()))); + DLIB_TEST((m + 2*identity_matrix(3) == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((m + 2*identity_matrix() == m + 2*tmp(identity_matrix()))); + DLIB_TEST((m + identity_matrix(3)*2 == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((m + identity_matrix()*2 == m + 2*tmp(identity_matrix()))); + + DLIB_TEST((identity_matrix(3) + m == m + tmp(identity_matrix(3)))); + DLIB_TEST((identity_matrix() + m == m + tmp(identity_matrix()))); + DLIB_TEST((2*identity_matrix(3) + m == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((2*identity_matrix() + m == m + 2*tmp(identity_matrix()))); + + } + { + matrix d(3,1), di(3,1); + matrix m(3,3); + + m = 1,2,3, + 4,5,6, + 7,8,9; + + d = 1,2,3; + + di = 1, 1/2.0, 1/3.0; + + DLIB_TEST(equal(inv(diagm(d)) , diagm(di))); + DLIB_TEST(equal(inv(diagm(d)) , diagm(di))); + DLIB_TEST(equal(inv(diagm(d))*m , tmp(diagm(di))*m)); + DLIB_TEST(equal(m*inv(diagm(d)) , m*tmp(diagm(di)))); + + DLIB_TEST_MSG(equal(inv(diagm(d)) + m , tmp(diagm(di)) + m), + (inv(diagm(d)) + m) - (tmp(diagm(di)) + m) ); + DLIB_TEST(equal(m + inv(diagm(d)) , tmp(diagm(di)) + m)); + + + DLIB_TEST((m + identity_matrix(3) == m + tmp(identity_matrix(3)))); + DLIB_TEST((m + identity_matrix() == m + tmp(identity_matrix()))); + DLIB_TEST((m + 2*identity_matrix(3) == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((m + 2*identity_matrix() == m + 2*tmp(identity_matrix()))); + DLIB_TEST((m + identity_matrix(3)*2 == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((m + identity_matrix()*2 == m + 2*tmp(identity_matrix()))); + + DLIB_TEST((identity_matrix(3) + m == m + tmp(identity_matrix(3)))); + DLIB_TEST((identity_matrix() + m == m + tmp(identity_matrix()))); + DLIB_TEST((2*identity_matrix(3) + m == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((2*identity_matrix() + m == m + 2*tmp(identity_matrix()))); + } + + { + matrix d(1,3), di(1,3); + matrix m(3,3); + + m = 1,2,3, + 4,5,6, + 7,8,9; + + d = 1,2,3; + + di = 1, 1/2.0, 1/3.0; + + DLIB_TEST(equal(inv(diagm(d)) , diagm(di))); + DLIB_TEST(equal(inv(diagm(d)) , diagm(di))); + DLIB_TEST(equal(inv(diagm(d))*m , tmp(diagm(di))*m)); + DLIB_TEST(equal(m*inv(diagm(d)) , m*tmp(diagm(di)))); + + DLIB_TEST(equal(inv(diagm(d)) + m , tmp(diagm(di)) + m)); + DLIB_TEST(equal(m + inv(diagm(d)) , tmp(diagm(di)) + m)); + + + DLIB_TEST((m + identity_matrix(3) == m + tmp(identity_matrix(3)))); + DLIB_TEST((m + identity_matrix() == m + tmp(identity_matrix()))); + DLIB_TEST((m + 2*identity_matrix(3) == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((m + 2*identity_matrix() == m + 2*tmp(identity_matrix()))); + DLIB_TEST((m + identity_matrix(3)*2 == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((m + identity_matrix()*2 == m + 2*tmp(identity_matrix()))); + + DLIB_TEST((identity_matrix(3) + m == m + tmp(identity_matrix(3)))); + DLIB_TEST((identity_matrix() + m == m + tmp(identity_matrix()))); + DLIB_TEST((2*identity_matrix(3) + m == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((2*identity_matrix() + m == m + 2*tmp(identity_matrix()))); + } + + { + matrix d(1,3), di(1,3); + matrix m(3,3); + + m = 1,2,3, + 4,5,6, + 7,8,9; + + d = 1,2,3; + + di = 1, 1/2.0, 1/3.0; + + DLIB_TEST(equal(inv(diagm(d)) , diagm(di))); + DLIB_TEST(equal(inv(diagm(d)) , diagm(di))); + DLIB_TEST(equal(inv(diagm(d))*m , tmp(diagm(di))*m)); + DLIB_TEST(equal(m*inv(diagm(d)) , m*tmp(diagm(di)))); + + DLIB_TEST(equal(inv(diagm(d)) + m , tmp(diagm(di)) + m)); + DLIB_TEST(equal(m + inv(diagm(d)) , tmp(diagm(di)) + m)); + + + DLIB_TEST((m + identity_matrix(3) == m + tmp(identity_matrix(3)))); + DLIB_TEST((m + identity_matrix() == m + tmp(identity_matrix()))); + DLIB_TEST((m + 2*identity_matrix(3) == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((m + 2*identity_matrix() == m + 2*tmp(identity_matrix()))); + DLIB_TEST((m + identity_matrix(3)*2 == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((m + identity_matrix()*2 == m + 2*tmp(identity_matrix()))); + + DLIB_TEST((identity_matrix(3) + m == m + tmp(identity_matrix(3)))); + DLIB_TEST((identity_matrix() + m == m + tmp(identity_matrix()))); + DLIB_TEST((2*identity_matrix(3) + m == m + 2*tmp(identity_matrix(3)))); + DLIB_TEST((2*identity_matrix() + m == m + 2*tmp(identity_matrix()))); + } + + + { + matrix d1, d2; + + d1 = 1,2,3; + + d2 = 2,3,4; + + matrix ans; + ans = 2, 0, 0, + 0, 6, 0, + 0, 0, 12; + + DLIB_TEST(ans == diagm(d1)*diagm(d2)); + } + + + dlib::rand rnd; + for (int i = 0; i < 1; ++i) + { + matrix d1 = randm(4,1,rnd); + matrix d2 = randm(5,1,rnd); + + matrix m = randm(4,5,rnd); + + DLIB_TEST_MSG(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*m*diagm(d2)), + pointwise_multiply(d1*trans(d2), m) - diagm(d1)*m*diagm(d2) + ); + DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*(m*diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , (diagm(d1)*m)*diagm(d2))); + + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*m*inv(diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*(m*inv(diagm(d2))))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , (inv(diagm(d1))*m)*inv(diagm(d2)))); + + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*m*(diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*(m*(diagm(d2))))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , (inv(diagm(d1))*m)*(diagm(d2)))); + + DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*m*inv(diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*(m*inv(diagm(d2))))); + DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , ((diagm(d1))*m)*inv(diagm(d2)))); + } + for (int i = 0; i < 1; ++i) + { + matrix d1 = randm(4,1,rnd); + matrix d2 = randm(5,1,rnd); + + matrix m = randm(4,5,rnd); + + DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*m*diagm(d2))); + DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*(m*diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , (diagm(d1)*m)*diagm(d2))); + + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*m*inv(diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*(m*inv(diagm(d2))))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , (inv(diagm(d1))*m)*inv(diagm(d2)))); + + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*m*(diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*(m*(diagm(d2))))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , (inv(diagm(d1))*m)*(diagm(d2)))); + + DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*m*inv(diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*(m*inv(diagm(d2))))); + DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , ((diagm(d1))*m)*inv(diagm(d2)))); + } + for (int i = 0; i < 1; ++i) + { + matrix d1 = randm(4,1,rnd); + matrix d2 = randm(5,1,rnd); + + matrix m = randm(4,5,rnd); + + DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*m*diagm(d2))); + DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , diagm(d1)*(m*diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply(d1*trans(d2), m) , (diagm(d1)*m)*diagm(d2))); + + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*m*inv(diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , inv(diagm(d1))*(m*inv(diagm(d2))))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) , (inv(diagm(d1))*m)*inv(diagm(d2)))); + + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*m*(diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , inv(diagm(d1))*(m*(diagm(d2))))); + DLIB_TEST(equal(pointwise_multiply(reciprocal(d1)*trans((d2)), m) , (inv(diagm(d1))*m)*(diagm(d2)))); + + DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*m*inv(diagm(d2)))); + DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , (diagm(d1))*(m*inv(diagm(d2))))); + DLIB_TEST(equal(pointwise_multiply((d1)*trans(reciprocal(d2)), m) , ((diagm(d1))*m)*inv(diagm(d2)))); + } + + + + { + for (int i = 0; i < 5; ++i) + { + matrix m = randm(3,4) + 1; + + DLIB_TEST(equal(1.0/m , reciprocal(m))); + DLIB_TEST(equal(0.0/m , zeros_matrix(3,4))); + } + } + + { + matrix m(2,3); + m = 1,2,3, + 4,5,6; + matrix M(2,3); + M = m; + + DLIB_TEST(upperbound(m,6) == M); + DLIB_TEST(upperbound(m,60) == M); + DLIB_TEST(lowerbound(m,-2) == M); + DLIB_TEST(lowerbound(m,0) == M); + + M = 2,2,3, + 4,5,6; + DLIB_TEST(lowerbound(m,2) == M); + + M = 0,0,0, + 0,0,0; + DLIB_TEST(upperbound(m,0) == M); + + M = 1,2,3, + 3,3,3; + DLIB_TEST(upperbound(m,3) == M); + } + + { + matrix A = randm(9,5); + matrix B = A; + + orthogonalize(A); + orthogonalize(B); + + DLIB_TEST(equal(A,B)); + } + } + + + + template < + long D1, + long D2, + long D3, + long D4 + > + void test_conv() + { + dlog << LINFO << D1 << " " << D2 << " " << D3 << " " << D4; + matrix a(1,1); + matrix b(2,2); + matrix c(3,3); + matrix d(4,1); + + a = 4; + + b = 1,2, + 3,4; + + c = 1,2,3, + 4,5,6, + 7,8,9; + + d = 1, + 2, + 3, + 4; + + matrix temp(4,4), temp2; + temp = 1, 4, 7, 6, + 7, 23, 33, 24, + 19, 53, 63, 42, + 21, 52, 59, 36; + + DLIB_TEST(conv(b,c) == temp); + DLIB_TEST(conv(c,b) == temp); + DLIB_TEST(xcorr(c,flip(b)) == temp); + + temp.set_size(2,2); + temp = 23, 33, + 53, 63; + DLIB_TEST(conv_same(b,c) == temp); + DLIB_TEST(xcorr_same(b,flip(c)) == temp); + + temp2.set_size(2,2); + temp2 = 63, 53, + 33, 23; + DLIB_TEST(flip(temp) == temp2); + DLIB_TEST(flip(temp) == fliplr(flipud(temp))); + + DLIB_TEST(conv_valid(b,c).nr() == 0); + DLIB_TEST(conv_valid(b,c).nc() == 0); + + DLIB_TEST(conv_valid(c,b) == temp); + DLIB_TEST(xcorr_valid(c,flip(b)) == temp); + + temp.set_size(1,1); + temp = 16; + + DLIB_TEST(conv(a,a) == temp); + DLIB_TEST(conv_same(a,a) == temp); + DLIB_TEST(conv_valid(a,a) == temp); + DLIB_TEST(xcorr(a,a) == temp); + DLIB_TEST(xcorr_same(a,a) == temp); + DLIB_TEST(xcorr_valid(a,a) == temp); + + temp.set_size(0,0); + DLIB_TEST(conv(temp,temp).nr() == 0); + DLIB_TEST(conv(temp,temp).nc() == 0); + DLIB_TEST(conv_same(temp,temp).nr() == 0); + DLIB_TEST(conv_same(temp,temp).nc() == 0); + DLIB_TEST_MSG(conv_valid(temp,temp).nr() == 0, conv_valid(temp,temp).nr()); + DLIB_TEST(conv_valid(temp,temp).nc() == 0); + DLIB_TEST(conv(c,temp).nr() == 0); + DLIB_TEST(conv(c,temp).nc() == 0); + DLIB_TEST(conv_same(c,temp).nr() == 0); + DLIB_TEST(conv_same(c,temp).nc() == 0); + DLIB_TEST(conv_valid(c,temp).nr() == 0); + DLIB_TEST(conv_valid(c,temp).nc() == 0); + DLIB_TEST(conv(temp,c).nr() == 0); + DLIB_TEST(conv(temp,c).nc() == 0); + DLIB_TEST(conv_same(temp,c).nr() == 0); + DLIB_TEST(conv_same(temp,c).nc() == 0); + DLIB_TEST(conv_valid(temp,c).nr() == 0); + DLIB_TEST(conv_valid(temp,c).nc() == 0); + + temp.set_size(5,2); + temp = 1, 2, + 5, 8, + 9, 14, + 13, 20, + 12, 16; + DLIB_TEST(conv(b,d) == temp); + DLIB_TEST(xcorr(b,flip(d)) == temp); + + temp.set_size(2,2); + temp = 9, 14, + 13, 20; + DLIB_TEST(conv_same(b,d) == temp); + DLIB_TEST(xcorr_same(b,flip(d)) == temp); + + DLIB_TEST(conv_valid(b,d).nr() == 0); + DLIB_TEST(xcorr_valid(b,flip(d)).nr() == 0); + DLIB_TEST_MSG(conv_valid(b,d).nc() == 0, conv_valid(b,d).nc()); + DLIB_TEST(xcorr_valid(b,flip(d)).nc() == 0); + + temp.set_size(5,5); + temp = 1, 4, 10, 12, 9, + 8, 26, 56, 54, 36, + 30, 84, 165, 144, 90, + 56, 134, 236, 186, 108, + 49, 112, 190, 144, 81; + + DLIB_TEST(conv(c,c) == temp); + DLIB_TEST(xcorr(c,flip(c)) == temp); + matrix temp3 = c; + temp3 = conv(temp3,c); + DLIB_TEST(temp3 == temp); + + temp3 = c; + temp3 = conv(c,temp3); + DLIB_TEST(temp3 == temp); + + + temp.set_size(3,3); + temp = 26, 56, 54, + 84, 165, 144, + 134, 236, 186; + DLIB_TEST(conv_same(c,c) == temp); + DLIB_TEST(xcorr_same(c,flip(c)) == temp); + temp3 = c; + temp3 = conv_same(c,temp3); + DLIB_TEST(temp3 == temp); + temp3 = c; + temp3 = conv_same(temp3,c); + DLIB_TEST(temp3 == temp); + + temp.set_size(1,1); + temp = 165; + DLIB_TEST(conv_valid(c,c) == temp); + DLIB_TEST(xcorr_valid(c,flip(c)) == temp); + temp3 = c; + temp3 = conv_valid(c,temp3); + DLIB_TEST(temp3 == temp); + temp3 = c; + temp3 = conv_valid(temp3,c); + DLIB_TEST(temp3 == temp); + + + dlib::rand rnd; + for (int i = 0; i < 3; ++i) + { + matrix > a, b; + a = complex_matrix(matrix_cast(round(20*randm(2,7,rnd))), + matrix_cast(round(20*randm(2,7,rnd)))); + b = complex_matrix(matrix_cast(round(20*randm(3,2,rnd))), + matrix_cast(round(20*randm(3,2,rnd)))); + + DLIB_TEST(xcorr(a,b) == conv(a, flip(conj(b)))); + DLIB_TEST(xcorr_valid(a,b) == conv_valid(a, flip(conj(b)))); + DLIB_TEST(xcorr_same(a,b) == conv_same(a, flip(conj(b)))); + } + + + for (int i = 0; i < 30; ++i) + { + auto nr1 = rnd.get_integer_in_range(1,30); + auto nc1 = rnd.get_integer_in_range(1,30); + auto nr2 = rnd.get_integer_in_range(1,30); + auto nc2 = rnd.get_integer_in_range(1,30); + matrix a, b; + a = randm(nr1,nc1,rnd); + b = randm(nr2,nc2,rnd); + + DLIB_TEST(max(abs(xcorr(a,b) - xcorr_fft(a,b))) < 1e-12); + } + } + + void test_complex() + { + matrix > a, b; + + a = complex_matrix(linspace(1,7,7), linspace(2,8,7)); + b = complex_matrix(linspace(4,10,7), linspace(2,8,7)); + + DLIB_TEST(mean(a) == complex(4, 5)); + } + + void test_setsubs() + { + { + matrix m(3,3); + m = 0; + + set_colm(m,0) += 1; + set_rowm(m,0) += 1; + set_subm(m,1,1,2,2) += 5; + + matrix m2(3,3); + m2 = 2, 1, 1, + 1, 5, 5, + 1, 5, 5; + + DLIB_TEST(m == m2); + + set_colm(m,0) -= 1; + set_rowm(m,0) -= 1; + set_subm(m,1,1,2,2) -= 5; + + m2 = 0; + DLIB_TEST(m == m2); + + matrix r; + matrix c; + matrix b; + r = 1,2,3; + + c = 2, + 3, + 4; + + b = 2,3, + 4,5; + + set_colm(m,1) += c; + set_rowm(m,1) += r; + set_subm(m,1,1,2,2) += b; + + m2 = 0, 2, 0, + 1, 7, 6, + 0, 8, 5; + + DLIB_TEST(m2 == m); + + set_colm(m,1) -= c; + set_rowm(m,1) -= r; + set_subm(m,1,1,2,2) -= b; + + m2 = 0; + DLIB_TEST(m2 == m); + + + // check that the code path for destructive aliasing works right. + m = 2*identity_matrix(3); + set_colm(m,1) += m*c; + m2 = 2, 4, 0, + 0, 8, 0, + 0, 8, 2; + DLIB_TEST(m == m2); + + m = 2*identity_matrix(3); + set_colm(m,1) -= m*c; + m2 = 2, -4, 0, + 0, -4, 0, + 0, -8, 2; + DLIB_TEST(m == m2); + + m = 2*identity_matrix(3); + set_rowm(m,1) += r*m; + m2 = 2, 0, 0, + 2, 6, 6, + 0, 0, 2; + DLIB_TEST(m == m2); + + m = 2*identity_matrix(3); + set_rowm(m,1) -= r*m; + m2 = 2, 0, 0, + -2, -2, -6, + 0, 0, 2; + DLIB_TEST(m == m2); + + m = identity_matrix(3); + const rectangle rect(0,0,1,1); + set_subm(m,rect) += subm(m,rect)*b; + m2 = 3, 3, 0, + 4, 6, 0, + 0, 0, 1; + DLIB_TEST(m == m2); + + m = identity_matrix(3); + set_subm(m,rect) -= subm(m,rect)*b; + m2 = -1, -3, 0, + -4, -4, 0, + 0, 0, 1; + DLIB_TEST(m == m2); + + } + + { + matrix a, b; + a = 2; + b = 3; + DLIB_TEST(dot(a,b) == 6); + } + { + matrix a; + matrix b(1); + a = 2; + b = 3; + DLIB_TEST(dot(a,b) == 6); + DLIB_TEST(dot(b,a) == 6); + } + { + matrix a; + matrix b(1); + a = 2; + b = 3; + DLIB_TEST(dot(a,b) == 6); + DLIB_TEST(dot(b,a) == 6); + } + } + + template + std::vector tovect1(const T& m) + { + std::vector temp; + for (typename T::const_iterator i = m.begin(); i != m.end(); ++i) + { + temp.push_back(*i); + } + return temp; + } + + template + std::vector tovect2(const T& m) + { + std::vector temp; + for (typename T::const_iterator i = m.begin(); i != m.end(); i++) + { + temp.push_back(*i); + } + return temp; + } + + template + std::vector tovect3(const T& m_) + { + matrix m(m_); + std::vector temp; + for (matrix::iterator i = m.begin(); i != m.end(); ++i) + { + temp.push_back(*i); + } + return temp; + } + + template + std::vector tovect4(const T& m_) + { + matrix m(m_); + std::vector temp; + for (matrix::iterator i = m.begin(); i != m.end(); i++) + { + temp.push_back(*i); + } + return temp; + } + + void test_iterators() + { + matrix m(3,2); + m = 1,2,3, + 4,5,6; + + std::vector v1 = tovect1(m); + std::vector v2 = tovect2(m); + std::vector v3 = tovect3(m); + std::vector v4 = tovect4(m); + + std::vector v5 = tovect1(m+m); + std::vector v6 = tovect2(m+m); + std::vector v7 = tovect3(m+m); + std::vector v8 = tovect4(m+m); + + + std::vector a1, a2; + for (int i = 1; i <= 6; ++i) + { + a1.push_back(i); + a2.push_back(i*2); + } + + DLIB_TEST(max(abs(mat(v1) - mat(a1))) == 0); + DLIB_TEST(max(abs(mat(v2) - mat(a1))) == 0); + DLIB_TEST(max(abs(mat(v3) - mat(a1))) == 0); + DLIB_TEST(max(abs(mat(v4) - mat(a1))) == 0); + + DLIB_TEST(max(abs(mat(v5) - mat(a2))) == 0); + DLIB_TEST(max(abs(mat(v6) - mat(a2))) == 0); + DLIB_TEST(max(abs(mat(v7) - mat(a2))) == 0); + DLIB_TEST(max(abs(mat(v8) - mat(a2))) == 0); + } + + void test_linpiece() + { + matrix temp = linpiece(5, linspace(-1, 9, 2)); + DLIB_CASSERT(temp.size() == 1,""); + DLIB_CASSERT(std::abs(temp(0) - 6) < 1e-13,""); + + temp = linpiece(5, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,""); + + temp = linpiece(4, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 1) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,""); + + temp = linpiece(40, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 2) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 2) < 1e-13,""); + + temp = linpiece(-40, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,""); + + temp = linpiece(0, linspace(-1, 9, 6)); + DLIB_CASSERT(temp.size() == 5,""); + DLIB_CASSERT(std::abs(temp(0) - 1) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(1) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(2) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(3) - 0) < 1e-13,""); + DLIB_CASSERT(std::abs(temp(4) - 0) < 1e-13,""); + + } + + class matrix_tester : public tester + { + public: + matrix_tester ( + ) : + tester ("test_matrix4", + "Runs tests on the scale_rows and scale_columns functions.") + {} + + void perform_test ( + ) + { + test_iterators(); + test_setsubs(); + + test_conv<0,0,0,0>(); + test_conv<1,2,3,4>(); + + test_stuff(); + for (int i = 0; i < 10; ++i) + matrix_test(); + + test_complex(); + test_linpiece(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/matrix_chol.cpp b/ml/dlib/dlib/test/matrix_chol.cpp new file mode 100644 index 000000000..b46c4866d --- /dev/null +++ b/ml/dlib/dlib/test/matrix_chol.cpp @@ -0,0 +1,182 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.matrix_chol"); + + dlib::rand rnd; + +// ---------------------------------------------------------------------------------------- + + template + const matrix symm(const mat_type& m) { return m*trans(m); } + +// ---------------------------------------------------------------------------------------- + + template + const matrix randmat(long r, long c) + { + matrix m(r,c); + for (long row = 0; row < m.nr(); ++row) + { + for (long col = 0; col < m.nc(); ++col) + { + m(row,col) = static_cast(rnd.get_random_double()); + } + } + + return m; + } + + template + const matrix randmat() + { + matrix m; + for (long row = 0; row < m.nr(); ++row) + { + for (long col = 0; col < m.nc(); ++col) + { + m(row,col) = static_cast(rnd.get_random_double()); + } + } + + return m; + } + +// ---------------------------------------------------------------------------------------- + + template + void test_cholesky ( const matrix_type& m) + { + typedef typename matrix_type::type type; + const type eps = 10*max(abs(m))*sqrt(std::numeric_limits::epsilon()); + dlog << LDEBUG << "test_cholesky(): " << m.nr() << " x " << m.nc() << " eps: " << eps; + print_spinner(); + + + cholesky_decomposition test(m); + + // none of the matrices we should be passing in to test_cholesky() should be non-spd. + DLIB_TEST(test.is_spd() == true); + + type temp; + DLIB_TEST_MSG( (temp= max(abs(test.get_l()*trans(test.get_l()) - m))) < eps,temp); + + { + matrix mat = chol(m); + DLIB_TEST_MSG( (temp= max(abs(mat*trans(mat) - m))) < eps,temp); + } + + + matrix m2; + matrix col; + + m2 = identity_matrix(m.nr()); + DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); + m2 = randmat(m.nr(),5); + DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); + m2 = randmat(m.nr(),1); + DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); + col = randmat(m.nr(),1); + DLIB_TEST_MSG(equal(m*test.solve(col), col,eps),max(abs(m*test.solve(m2)- m2))); + + // now make us a non-spd matrix + if (m.nr() > 2) + { + matrix sm(lowerm(m)); + sm(1,1) = 0; + + cholesky_decomposition test2(sm); + DLIB_TEST_MSG(test2.is_spd() == false, test2.get_l()); + + + cholesky_decomposition test3(sm*trans(sm)); + DLIB_TEST_MSG(test3.is_spd() == false, test3.get_l()); + + sm = sm*trans(sm); + sm(1,1) = 5; + sm(1,0) -= 1; + cholesky_decomposition test4(sm); + DLIB_TEST_MSG(test4.is_spd() == false, test4.get_l()); + } + + } + +// ---------------------------------------------------------------------------------------- + + void matrix_test_double() + { + + test_cholesky(uniform_matrix(1,1,1) + 10*symm(randmat(1,1))); + test_cholesky(uniform_matrix(2,2,1) + 10*symm(randmat(2,2))); + test_cholesky(uniform_matrix(3,3,1) + 10*symm(randmat(3,3))); + test_cholesky(uniform_matrix(4,4,1) + 10*symm(randmat(4,4))); + test_cholesky(uniform_matrix(15,15,1) + 10*symm(randmat(15,15))); + test_cholesky(uniform_matrix(101,101,1) + 10*symm(randmat(101,101))); + + typedef matrix mat; + test_cholesky(mat(uniform_matrix(101,101,1) + 10*symm(randmat(101,101)))); + } + +// ---------------------------------------------------------------------------------------- + + void matrix_test_float() + { + + test_cholesky(uniform_matrix(1,1,1) + 2*symm(randmat(1,1))); + test_cholesky(uniform_matrix(2,2,1) + 2*symm(randmat(2,2))); + test_cholesky(uniform_matrix(3,3,1) + 2*symm(randmat(3,3))); + + typedef matrix mat; + test_cholesky(mat(uniform_matrix(3,3,1) + 2*symm(randmat(3,3)))); + } + +// ---------------------------------------------------------------------------------------- + + class matrix_tester : public tester + { + public: + matrix_tester ( + ) : + tester ("test_matrix_chol", + "Runs tests on the matrix cholesky component.") + { + rnd.set_seed(cast_to_string(time(0))); + } + + void perform_test ( + ) + { + dlog << LINFO << "seed string: " << rnd.get_seed(); + + dlog << LINFO << "begin testing with double"; + matrix_test_double(); + dlog << LINFO << "begin testing with float"; + matrix_test_float(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/matrix_eig.cpp b/ml/dlib/dlib/test/matrix_eig.cpp new file mode 100644 index 000000000..9fbce6598 --- /dev/null +++ b/ml/dlib/dlib/test/matrix_eig.cpp @@ -0,0 +1,245 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.matrix_eig"); + + dlib::rand rnd; + +// ---------------------------------------------------------------------------------------- + + template + const matrix randm(long r, long c) + { + matrix m(r,c); + for (long row = 0; row < m.nr(); ++row) + { + for (long col = 0; col < m.nc(); ++col) + { + m(row,col) = static_cast(rnd.get_random_double()); + } + } + + return m; + } + + template + const matrix randm() + { + matrix m; + for (long row = 0; row < m.nr(); ++row) + { + for (long col = 0; col < m.nc(); ++col) + { + m(row,col) = static_cast(rnd.get_random_double()); + } + } + + return m; + } + +// ---------------------------------------------------------------------------------------- + + template + void test_eigenvalue_impl ( const matrix_type& m, const eigenvalue_decomposition& test ) + { + typedef typename matrix_type::type type; + const type eps = 10*max(abs(m))*sqrt(std::numeric_limits::epsilon()); + dlog << LDEBUG << "test_eigenvalue(): " << m.nr() << " x " << m.nc() << " eps: " << eps; + print_spinner(); + + + DLIB_TEST(test.dim() == m.nr()); + + // make sure all the various ways of asking for the eigenvalues are actually returning a + // consistent set of eigenvalues. + DLIB_TEST(equal(real(test.get_eigenvalues()), test.get_real_eigenvalues(), eps)); + DLIB_TEST(equal(imag(test.get_eigenvalues()), test.get_imag_eigenvalues(), eps)); + DLIB_TEST(equal(real(diag(test.get_d())), test.get_real_eigenvalues(), eps)); + DLIB_TEST(equal(imag(diag(test.get_d())), test.get_imag_eigenvalues(), eps)); + + matrix eig1 ( real_eigenvalues(m)); + matrix eig2 ( test.get_real_eigenvalues()); + sort(&eig1(0), &eig1(0) + eig1.size()); + sort(&eig2(0), &eig2(0) + eig2.size()); + DLIB_TEST(max(abs(eig1 - eig2)) < eps); + + const matrix V = test.get_pseudo_v(); + const matrix D = test.get_pseudo_d(); + const matrix > CV = test.get_v(); + const matrix > CD = test.get_d(); + const matrix > CM = complex_matrix(m, uniform_matrix(m.nr(),m.nc(),0)); + + DLIB_TEST(V.nr() == test.dim()); + DLIB_TEST(V.nc() == test.dim()); + DLIB_TEST(D.nr() == test.dim()); + DLIB_TEST(D.nc() == test.dim()); + + // CD is a diagonal matrix + DLIB_TEST(diagm(diag(CD)) == CD); + + // verify that these things are actually eigenvalues and eigenvectors of m + DLIB_TEST_MSG(max(abs(m*V - V*D)) < eps, max(abs(m*V - V*D)) << " " << eps); + DLIB_TEST(max(norm(CM*CV - CV*CD)) < eps); + + // if m is a symmetric matrix + if (max(abs(m-trans(m))) < 1e-5) + { + dlog << LTRACE << "m is symmetric"; + // there aren't any imaginary eigenvalues + DLIB_TEST(max(abs(test.get_imag_eigenvalues())) < eps); + DLIB_TEST(diagm(diag(D)) == D); + + // only check the determinant against the eigenvalues for small matrices + // because for huge ones the determinant might be so big it overflows a floating point number. + if (m.nr() < 50) + { + const type mdet = det(m); + DLIB_TEST_MSG(std::abs(prod(test.get_real_eigenvalues()) - mdet) < std::abs(mdet)*sqrt(std::numeric_limits::epsilon()), + std::abs(prod(test.get_real_eigenvalues()) - mdet) <<" eps: " << std::abs(mdet)*sqrt(std::numeric_limits::epsilon()) + << " mdet: "<< mdet << " prod(eig): " << prod(test.get_real_eigenvalues()) + ); + } + + // V is orthogonal + DLIB_TEST(equal(V*trans(V), identity_matrix(test.dim()), eps)); + DLIB_TEST(equal(m , V*D*trans(V), eps)); + } + else + { + dlog << LTRACE << "m is NOT symmetric"; + DLIB_TEST_MSG(equal(m , V*D*inv(V), eps), max(abs(m - V*D*inv(V)))); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void test_eigenvalue ( const matrix_type& m ) + { + typedef typename matrix_type::type type; + typedef typename matrix_type::mem_manager_type MM; + matrix mr(m); + matrix mc(m); + + { + eigenvalue_decomposition test(mr); + test_eigenvalue_impl(mr, test); + + eigenvalue_decomposition test_symm(make_symmetric(mr)); + test_eigenvalue_impl(make_symmetric(mr), test_symm); + } + + { + eigenvalue_decomposition test(mc); + test_eigenvalue_impl(mc, test); + + eigenvalue_decomposition test_symm(make_symmetric(mc)); + test_eigenvalue_impl(make_symmetric(mc), test_symm); + } + } + +// ---------------------------------------------------------------------------------------- + + void matrix_test_double() + { + + test_eigenvalue(10*randm(1,1)); + test_eigenvalue(10*randm(2,2)); + test_eigenvalue(10*randm(3,3)); + test_eigenvalue(10*randm(4,4)); + test_eigenvalue(10*randm(15,15)); + test_eigenvalue(10*randm(150,150)); + + test_eigenvalue(10*randm()); + test_eigenvalue(10*randm()); + test_eigenvalue(10*randm()); + } + +// ---------------------------------------------------------------------------------------- + + void matrix_test_float() + { + + test_eigenvalue(10*randm(1,1)); + test_eigenvalue(10*randm(2,2)); + test_eigenvalue(10*randm(3,3)); + test_eigenvalue(10*randm(4,4)); + test_eigenvalue(10*randm(15,15)); + test_eigenvalue(10*randm(50,50)); + + test_eigenvalue(10*randm()); + test_eigenvalue(10*randm()); + test_eigenvalue(10*randm()); + } + + template + void test_eigenvalue2() + { + for (int seed = 0; seed < 10; ++seed) + { + print_spinner(); + matrix H = gaussian_randm(dims,dims,seed); + H = H*trans(H); + + eigenvalue_decomposition > eig(H); + matrix HH = eig.get_pseudo_v()*diagm(eig.get_real_eigenvalues())*trans(eig.get_pseudo_v()); + DLIB_TEST_MSG(max(abs(H - HH))<1e-12, "dims: " << dims << " error: " << max(abs(H - HH))); + } + } + +// ---------------------------------------------------------------------------------------- + + class matrix_tester : public tester + { + public: + matrix_tester ( + ) : + tester ("test_matrix_eig", + "Runs tests on the matrix eigen decomp component.") + { + //rnd.set_seed(cast_to_string(time(0))); + } + + void perform_test ( + ) + { + dlog << LINFO << "seed string: " << rnd.get_seed(); + + dlog << LINFO << "begin testing with double"; + matrix_test_double(); + dlog << LINFO << "begin testing with float"; + matrix_test_float(); + + test_eigenvalue2<10>(); + test_eigenvalue2<11>(); + test_eigenvalue2<3>(); + test_eigenvalue2<2>(); + test_eigenvalue2<1>(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/matrix_lu.cpp b/ml/dlib/dlib/test/matrix_lu.cpp new file mode 100644 index 000000000..f5425b355 --- /dev/null +++ b/ml/dlib/dlib/test/matrix_lu.cpp @@ -0,0 +1,223 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.matrix_lu"); + + dlib::rand rnd; + +// ---------------------------------------------------------------------------------------- + + template + const matrix symm(const mat_type& m) { return m*trans(m); } + +// ---------------------------------------------------------------------------------------- + + template + const matrix randmat(long r, long c) + { + matrix m(r,c); + for (long row = 0; row < m.nr(); ++row) + { + for (long col = 0; col < m.nc(); ++col) + { + m(row,col) = static_cast(rnd.get_random_double()); + } + } + + return m; + } + + template + const matrix randmat() + { + matrix m; + for (long row = 0; row < m.nr(); ++row) + { + for (long col = 0; col < m.nc(); ++col) + { + m(row,col) = static_cast(rnd.get_random_double()); + } + } + + return m; + } + +// ---------------------------------------------------------------------------------------- + + template + void test_lu ( const matrix_type& m) + { + typedef typename matrix_type::type type; + const type eps = 10*max(abs(m))*sqrt(std::numeric_limits::epsilon()); + dlog << LDEBUG << "test_lu(): " << m.nr() << " x " << m.nc() << " eps: " << eps; + print_spinner(); + + + lu_decomposition test(m); + + DLIB_TEST(test.is_square() == (m.nr() == m.nc())); + + DLIB_TEST(test.nr() == m.nr()); + DLIB_TEST(test.nc() == m.nc()); + + dlog << LDEBUG << "m.nr(): " << m.nr() << " m.nc(): " << m.nc(); + + type temp; + DLIB_TEST_MSG( (temp= max(abs(test.get_l()*test.get_u() - rowm(m,test.get_pivot())))) < eps,temp); + + if (test.is_square()) + { + // none of the matrices we should be passing in to test_lu() should be singular. + DLIB_TEST_MSG (abs(test.det()) > eps/100, "det: " << test.det() ); + dlog << LDEBUG << "big det: " << test.det(); + + DLIB_TEST(test.is_singular() == false); + + matrix m2; + matrix col; + + m2 = identity_matrix(m.nr()); + DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); + m2 = randmat(m.nr(),5); + DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); + m2 = randmat(m.nr(),1); + DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); + col = randmat(m.nr(),1); + DLIB_TEST_MSG(equal(m*test.solve(col), col,eps),max(abs(m*test.solve(m2)- m2))); + + // now make us a singular matrix + if (m.nr() > 1) + { + matrix sm(m); + set_colm(sm,0) = colm(sm,1); + + lu_decomposition test2(sm); + DLIB_TEST_MSG( (temp= max(abs(test2.get_l()*test2.get_u() - rowm(sm,test2.get_pivot())))) < eps,temp); + + // these checks are only accurate for small matrices + if (test2.nr() < 100) + { + DLIB_TEST_MSG(test2.is_singular() == true,"det: " << test2.det()); + DLIB_TEST_MSG(abs(test2.det()) < eps,"det: " << test2.det()); + } + + } + } + + } + +// ---------------------------------------------------------------------------------------- + + void matrix_test_double() + { + + + test_lu(10*randmat(2,2)); + test_lu(10*randmat(1,1)); + test_lu(10*symm(randmat(2,2))); + test_lu(10*randmat(4,4)); + test_lu(10*randmat(9,4)); + test_lu(10*randmat(3,8)); + test_lu(10*randmat(15,15)); + test_lu(2*symm(randmat(15,15))); + test_lu(10*randmat(100,100)); + test_lu(10*randmat(137,200)); + test_lu(10*randmat(200,101)); + + test_lu(10*randmat()); + test_lu(10*randmat()); + test_lu(10*randmat()); + test_lu(10*randmat()); + test_lu(10*randmat()); + test_lu(10*randmat()); + test_lu(10*randmat()); + test_lu(10*randmat()); + test_lu(10*randmat()); + test_lu(10*randmat()); + + typedef matrix mat; + test_lu(mat(3*randmat(4,4))); + test_lu(mat(3*randmat(9,4))); + test_lu(mat(3*randmat(3,8))); + } + +// ---------------------------------------------------------------------------------------- + + void matrix_test_float() + { + + // ------------------------------- + + test_lu(3*randmat(1,1)); + test_lu(3*randmat(2,2)); + test_lu(3*randmat(4,4)); + test_lu(3*randmat(9,4)); + test_lu(3*randmat(3,8)); + test_lu(3*randmat(137,200)); + test_lu(3*randmat(200,101)); + + test_lu(3*randmat()); + test_lu(3*randmat()); + test_lu(3*randmat()); + test_lu(3*randmat()); + test_lu(3*randmat()); + test_lu(3*randmat()); + test_lu(3*randmat()); + test_lu(3*randmat()); + + typedef matrix mat; + test_lu(mat(3*randmat(4,4))); + test_lu(mat(3*randmat(9,4))); + test_lu(mat(3*randmat(3,8))); + } + +// ---------------------------------------------------------------------------------------- + + class matrix_tester : public tester + { + public: + matrix_tester ( + ) : + tester ("test_matrix_lu", + "Runs tests on the matrix LU component.") + { + //rnd.set_seed(cast_to_string(time(0))); + } + + void perform_test ( + ) + { + dlog << LINFO << "seed string: " << rnd.get_seed(); + + dlog << LINFO << "begin testing with double"; + matrix_test_double(); + dlog << LINFO << "begin testing with float"; + matrix_test_float(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/matrix_qr.cpp b/ml/dlib/dlib/test/matrix_qr.cpp new file mode 100644 index 000000000..e3c7c4e42 --- /dev/null +++ b/ml/dlib/dlib/test/matrix_qr.cpp @@ -0,0 +1,208 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.matrix_qr"); + + dlib::rand rnd; + +// ---------------------------------------------------------------------------------------- + + template + const matrix symm(const mat_type& m) { return m*trans(m); } + +// ---------------------------------------------------------------------------------------- + + template + const matrix randmat(long r, long c) + { + matrix m(r,c); + for (long row = 0; row < m.nr(); ++row) + { + for (long col = 0; col < m.nc(); ++col) + { + m(row,col) = static_cast(rnd.get_random_double()); + } + } + + return m; + } + + template + const matrix randmat() + { + matrix m; + for (long row = 0; row < m.nr(); ++row) + { + for (long col = 0; col < m.nc(); ++col) + { + m(row,col) = static_cast(rnd.get_random_double()); + } + } + + return m; + } + +// ---------------------------------------------------------------------------------------- + + template + void test_qr ( const matrix_type& m) + { + typedef typename matrix_type::type type; + const type eps = 10*max(abs(m))*sqrt(std::numeric_limits::epsilon()); + dlog << LDEBUG << "test_qr(): " << m.nr() << " x " << m.nc() << " eps: " << eps; + print_spinner(); + + + qr_decomposition test(m); + + + DLIB_TEST(test.nr() == m.nr()); + DLIB_TEST(test.nc() == m.nc()); + + + type temp; + DLIB_TEST_MSG( (temp= max(abs(test.get_q()*test.get_r() - m))) < eps,temp); + + // none of the matrices we should be passing in to test_qr() should be non-full rank. + DLIB_TEST(test.is_full_rank() == true); + + if (m.nr() == m.nc()) + { + matrix m2; + matrix col; + + m2 = identity_matrix(m.nr()); + DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); + m2 = randmat(m.nr(),5); + DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); + m2 = randmat(m.nr(),1); + DLIB_TEST_MSG(equal(m*test.solve(m2), m2,eps),max(abs(m*test.solve(m2)- m2))); + col = randmat(m.nr(),1); + DLIB_TEST_MSG(equal(m*test.solve(col), col,eps),max(abs(m*test.solve(m2)- m2))); + } + else + { + DLIB_TEST_MSG(dlib::equal(pinv(m), test.solve(identity_matrix(m.nr())), eps), + max(abs(pinv(m) - test.solve(identity_matrix(m.nr())))) ); + } + + // now make us a non-full rank matrix + if (m.nc() > 1) + { + matrix sm(m); + set_colm(sm,0) = colm(sm,1); + + qr_decomposition test2(sm); + DLIB_TEST_MSG( (temp= max(abs(test.get_q()*test.get_r() - m))) < eps,temp); + + if (test2.nc() < 100) + { + DLIB_TEST_MSG(test2.is_full_rank() == false,"eps: " << eps); + } + + } + + } + +// ---------------------------------------------------------------------------------------- + + void matrix_test_double() + { + + test_qr(10*randmat(1,1)); + test_qr(10*randmat(2,2)); + test_qr(10*symm(randmat(2,2))); + test_qr(10*randmat(4,4)); + test_qr(10*randmat(9,4)); + test_qr(10*randmat(15,15)); + test_qr(2*symm(randmat(15,15))); + test_qr(10*randmat(100,100)); + test_qr(10*randmat(237,200)); + test_qr(10*randmat(200,101)); + + test_qr(10*randmat()); + test_qr(10*randmat()); + test_qr(10*randmat()); + test_qr(10*randmat()); + test_qr(10*randmat()); + test_qr(10*randmat()); + test_qr(10*randmat()); + + typedef matrix mat; + test_qr(mat(3*randmat(9,4))); + test_qr(mat(3*randmat(9,9))); + } + +// ---------------------------------------------------------------------------------------- + + void matrix_test_float() + { + + + test_qr(3*randmat(1,1)); + test_qr(3*randmat(2,2)); + test_qr(3*randmat(4,4)); + test_qr(3*randmat(9,4)); + test_qr(3*randmat(237,200)); + + test_qr(3*randmat()); + test_qr(3*randmat()); + test_qr(3*randmat()); + test_qr(3*randmat()); + test_qr(3*randmat()); + + typedef matrix mat; + test_qr(mat(3*randmat(5,4))); + test_qr(mat(3*randmat(9,9))); + } + +// ---------------------------------------------------------------------------------------- + + class matrix_tester : public tester + { + public: + matrix_tester ( + ) : + tester ("test_matrix_qr", + "Runs tests on the matrix QR component.") + { + //rnd.set_seed(cast_to_string(time(0))); + } + + void perform_test ( + ) + { + dlog << LINFO << "seed string: " << rnd.get_seed(); + + dlog << LINFO << "begin testing with double"; + matrix_test_double(); + dlog << LINFO << "begin testing with float"; + matrix_test_float(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/max_cost_assignment.cpp b/ml/dlib/dlib/test/max_cost_assignment.cpp new file mode 100644 index 000000000..852418764 --- /dev/null +++ b/ml/dlib/dlib/test/max_cost_assignment.cpp @@ -0,0 +1,157 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../rand.h" + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.max_cost_assignment"); + +// ---------------------------------------------------------------------------------------- + + std::vector > permutations ( + matrix vals + ) + { + if (vals.size() == 0) + { + return std::vector >(); + } + else if (vals.size() == 1) + { + return std::vector >(1,std::vector(1,vals(0))); + } + + + std::vector > temp; + + + for (long i = 0; i < vals.size(); ++i) + { + const std::vector >& res = permutations(remove_col(vals,i)); + + for (unsigned long j = 0; j < res.size(); ++j) + { + temp.resize(temp.size()+1); + std::vector& part = temp.back(); + part.push_back(vals(i)); + part.insert(part.end(), res[j].begin(), res[j].end()); + } + } + + + return temp; + } + +// ---------------------------------------------------------------------------------------- + + template + std::vector brute_force_max_cost_assignment ( + matrix cost + ) + { + if (cost.size() == 0) + return std::vector(); + + const std::vector >& perms = permutations(range(0,cost.nc()-1)); + + T best_cost = std::numeric_limits::min(); + unsigned long best_idx = 0; + for (unsigned long i = 0; i < perms.size(); ++i) + { + const T temp = assignment_cost(cost, perms[i]); + if (temp > best_cost) + { + best_idx = i; + best_cost = temp; + } + } + + return perms[best_idx]; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class test_max_cost_assignment : public tester + { + public: + test_max_cost_assignment ( + ) : + tester ("test_max_cost_assignment", + "Runs tests on the max_cost_assignment function.") + {} + + dlib::rand rnd; + + template + void test_hungarian() + { + long size = rnd.get_random_32bit_number()%7; + long range = rnd.get_random_32bit_number()%100; + matrix cost = matrix_cast(randm(size,size,rnd)*range) - range/2; + + // use a uniform cost matrix sometimes + if ((rnd.get_random_32bit_number()%100) == 0) + cost = rnd.get_random_32bit_number()%100; + + // negate the cost matrix every now and then + if ((rnd.get_random_32bit_number()%100) == 0) + cost = -cost; + + + std::vector assign = brute_force_max_cost_assignment(cost); + T true_eval = assignment_cost(cost, assign); + assign = max_cost_assignment(cost); + DLIB_TEST(assignment_cost(cost,assign) == true_eval); + assign = max_cost_assignment(matrix_cast(cost)); + DLIB_TEST(assignment_cost(cost,assign) == true_eval); + + + cost = matrix_cast(randm(size,size,rnd)*range); + assign = brute_force_max_cost_assignment(cost); + true_eval = assignment_cost(cost, assign); + assign = max_cost_assignment(cost); + DLIB_TEST(assignment_cost(cost,assign) == true_eval); + assign = max_cost_assignment(matrix_cast(cost)); + DLIB_TEST(assignment_cost(cost,assign) == true_eval); + assign = max_cost_assignment(matrix_cast::type>(cost)); + DLIB_TEST(assignment_cost(cost,assign) == true_eval); + } + + void perform_test ( + ) + { + for (long i = 0; i < 1000; ++i) + { + if ((i%100) == 0) + print_spinner(); + + test_hungarian(); + test_hungarian(); + test_hungarian(); + test_hungarian(); + } + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/max_sum_submatrix.cpp b/ml/dlib/dlib/test/max_sum_submatrix.cpp new file mode 100644 index 000000000..34b4756ff --- /dev/null +++ b/ml/dlib/dlib/test/max_sum_submatrix.cpp @@ -0,0 +1,177 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.max_sum_submatrix"); + +// ---------------------------------------------------------------------------------------- + + bool order_rects ( + const rectangle& a, + const rectangle& b + ) + { + if (a.left() < b.left()) return true; + else if (a.left() > b.left()) return false; + + if (a.right() < b.right()) return true; + else if (a.right() > b.right()) return false; + + if (a.top() < b.top()) return true; + else if (a.top() > b.top()) return false; + + if (a.bottom() < b.bottom()) return true; + else if (a.bottom() > b.bottom()) return false; + + return false; + } + + void run_test( + const int num + ) + { + static dlib::rand rnd; + + matrix mat, mask; + + mat.set_size(rnd.get_random_32bit_number()%1000 + 1, + rnd.get_random_32bit_number()%1000 + 1); + mask.set_size(mat.nr(), mat.nc()); + mask = 0; + + mat = -10000; + + std::vector true_rects; + + for (int i = 0; i < num; ++i) + { + const int width = rnd.get_random_32bit_number()%100 + 1; + const int height = rnd.get_random_32bit_number()%100 + 1; + + rectangle rect = centered_rect(rnd.get_random_16bit_number()%mat.nc(), + rnd.get_random_16bit_number()%mat.nr(), + width,height); + rect = get_rect(mat).intersect(rect); + + // make sure this new rectangle doesn't overlap or abut any others + if (sum(subm(mask,grow_rect(rect,1).intersect(get_rect(mask)))) == 0) + { + set_subm(mat, rect) = rnd.get_random_8bit_number()%100 + 1; + set_subm(mask, rect) = 1; + true_rects.push_back(rect); + } + } + + + std::vector res; + res = max_sum_submatrix(mat, true_rects.size()+10, 0); + + DLIB_TEST(res.size() == true_rects.size()); + + // make sure big rectangles come first + for (unsigned long i = 0; i+1 < res.size(); ++i) + { + DLIB_TEST(sum(subm(mat,res[i])) >= sum(subm(mat,res[i+1]))); + } + + // make sure rectangles match + sort(true_rects.begin(), true_rects.end(), order_rects); + sort(res.begin(), res.end(), order_rects); + for (unsigned long i = 0; i < res.size(); ++i) + { + DLIB_TEST_MSG(res[i] == true_rects[i], + "i: " << i << " res[i]: " << res[i] << " true_rects[i]: " << true_rects[i]); + } + + } + +// ---------------------------------------------------------------------------------------- + + template + void run_test2() + { + matrix mat(100,100); + mat = 1; + std::vector res = max_sum_submatrix(mat, 0, 0); + + DLIB_TEST(res.size() == 0); + res = max_sum_submatrix(mat, 1, 0); + DLIB_TEST(res.size() == 1); + DLIB_TEST(res[0] == get_rect(mat)); + res = max_sum_submatrix(mat, 3, 0); + DLIB_TEST(res.size() == 1); + DLIB_TEST(res[0] == get_rect(mat)); + res = max_sum_submatrix(mat, 3, 10); + DLIB_TEST(res.size() == 1); + DLIB_TEST(res[0] == get_rect(mat)); + + res = max_sum_submatrix(mat, 3, mat.size()); + DLIB_TEST(res.size() == 0); + + mat = -1; + res = max_sum_submatrix(mat, 1, 0); + DLIB_TEST(res.size() == 0); + + const rectangle rect1 = rectangle(10,10,40,40); + const rectangle rect2 = rectangle(35,35,80,80); + + set_subm(mat, rect1) = 2; + set_subm(mat, rect2) = 1; + res = max_sum_submatrix(mat, 3, 0); + DLIB_TEST(res.size() == 2); + DLIB_TEST(res[0] == rect2); + DLIB_TEST(res[1] == rect1); + + res = max_sum_submatrix(mat, 3, 2*rect1.area() - 2*(rect1.intersect(rect2)).area()); + DLIB_TEST(res.size() == 1); + DLIB_TEST(res[0] == rect2); + } + +// ---------------------------------------------------------------------------------------- + + + class test_max_sum_submatrix : public tester + { + public: + test_max_sum_submatrix ( + ) : + tester ("test_max_sum_submatrix", + "Runs tests on the max_sum_submatrix() function.") + {} + + void perform_test ( + ) + { + for (int j = 0; j < 5; ++j) + { + print_spinner(); + for (int i = 0; i < 40; ++i) + run_test(i); + } + + run_test2(); + run_test2(); + run_test2(); + run_test2(); + run_test2(); + } + } a; + +} + + + + diff --git a/ml/dlib/dlib/test/md5.cpp b/ml/dlib/dlib/test/md5.cpp new file mode 100644 index 000000000..15fafac3c --- /dev/null +++ b/ml/dlib/dlib/test/md5.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.md5"); + + void md5_test ( + ) + /*! + ensures + - runs tests on the md5 stuff compliance with the specs + !*/ + { + + DLIB_TEST(md5 ("") == "d41d8cd98f00b204e9800998ecf8427e"); + DLIB_TEST(md5 ("a") == "0cc175b9c0f1b6a831c399e269772661"); + DLIB_TEST(md5 ("abc") == "900150983cd24fb0d6963f7d28e17f72"); + DLIB_TEST(md5 ("message digest") == "f96b697d7cb7938d525a2f31aaf161d0"); + DLIB_TEST(md5 ("abcdefghijklmnopqrstuvwxyz") == "c3fcd3d76192e4007dfb496cca67e13b"); + DLIB_TEST(md5 ("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") == "d174ab98d277d9f5a5611c2c9f419d9f"); + DLIB_TEST(md5 ("12345678901234567890123456789012345678901234567890123456789012345678901234567890") == "57edf4a22be3c955ac49da2e2107b67a"); + + // make sure the two versions of md5() always agree + for (int num = 0; num < 2000; ++num) + { + std::string temp; + for (int i = 0; i < num; ++i) + temp += 'a'; + + istringstream str(temp); + DLIB_TEST(md5(temp) == md5(str)); + } + + } + + + class md5_tester : public tester + { + public: + md5_tester ( + ) : + tester ("test_md5", + "Runs tests on the md5 component.") + {} + + void perform_test ( + ) + { + md5_test(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/member_function_pointer.cpp b/ml/dlib/dlib/test/member_function_pointer.cpp new file mode 100644 index 000000000..72aa3aa35 --- /dev/null +++ b/ml/dlib/dlib/test/member_function_pointer.cpp @@ -0,0 +1,553 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.member_function_pointer"); + + class mfp_test_helper_other + { + public: + mfp_test_helper_other ( + ): i(-1) {} + + + mutable int i; + + + void go0 ( + ) { i = 0; } + void go1 ( + int v1 + ) { i = 1*v1; } + void go2 ( + int v1,int v2 + ) { i = 2*v1*v2; } + void go3 ( + int v1,int v2,int v3 + ) { i = 3*v1*v2*v3; } + void go4 ( + int v1,int v2,int v3,int v4 + ) { i = 4*v1*v2*v3*v4; } + + }; + + + class mfp_test_helper + { + public: + mfp_test_helper ( + ): i(-1) {} + + + mutable int i; + + + void go0 ( + ) { i = 0; } + void go1 ( + int v1 + ) { i = 1*v1; } + void go2 ( + int v1,int v2 + ) { i = 2*v1*v2; } + void go3 ( + int v1,int v2,int v3 + ) { i = 3*v1*v2*v3; } + void go4 ( + int v1,int v2,int v3,int v4 + ) { i = 4*v1*v2*v3*v4; } + + }; + + class mfp_test_helper_const + { + public: + mfp_test_helper_const ( + ): i(-1) {} + + + mutable int i; + + void go0 ( + ) const { i = 0; } + void go1 ( + int v1 + ) const { i = 1*v1; } + void go2 ( + int v1,int v2 + ) const { i = 2*v1*v2; } + void go3 ( + int v1,int v2,int v3 + ) const { i = 3*v1*v2*v3; } + void go4 ( + int v1,int v2,int v3,int v4 + ) const { i = 4*v1*v2*v3*v4; } + }; + + template < + template class mfp, + typename test_helper + > + void member_function_pointer_kernel_test ( + ) + /*! + requires + - mfp is an implementation of member_function_pointer/member_function_pointer_kernel_abstract.h + ensures + - runs tests on mfp for compliance with the specs + !*/ + { + + + test_helper helper; + + mfp<> a0, b0; + mfp a1, b1; + mfp a2, b2; + mfp a3, b3; + mfp a4, b4; + + mfp<> a0c, b0c; + mfp a1c, b1c; + mfp a2c, b2c; + mfp a3c, b3c; + mfp a4c, b4c; + + DLIB_TEST(a0c == b0c); + DLIB_TEST(a1c == b1c); + DLIB_TEST(a2c == b2c); + DLIB_TEST(a3c == b3c); + DLIB_TEST(a4c == b4c); + DLIB_TEST((a0c != b0c) == false); + DLIB_TEST((a1c != b1c) == false); + DLIB_TEST((a2c != b2c) == false); + DLIB_TEST((a3c != b3c) == false); + DLIB_TEST((a4c != b4c) == false); + + DLIB_TEST(a0.is_set() == false); + DLIB_TEST(b0.is_set() == false); + DLIB_TEST(a0c.is_set() == false); + DLIB_TEST(b0c.is_set() == false); + + DLIB_TEST(!a0 ); + DLIB_TEST(!b0 ); + DLIB_TEST(!a0c); + DLIB_TEST(!b0c); + + DLIB_TEST(a1.is_set() == false); + DLIB_TEST(b1.is_set() == false); + DLIB_TEST(a1c.is_set() == false); + DLIB_TEST(b1c.is_set() == false); + + DLIB_TEST(!a1 ); + DLIB_TEST(!b1 ); + DLIB_TEST(!a1c); + DLIB_TEST(!b1c); + + + DLIB_TEST(a2.is_set() == false); + DLIB_TEST(b2.is_set() == false); + DLIB_TEST(a2c.is_set() == false); + DLIB_TEST(b2c.is_set() == false); + + DLIB_TEST(!a2); + DLIB_TEST(!b2); + DLIB_TEST(!a2c); + DLIB_TEST(!b2c); + + DLIB_TEST(a3.is_set() == false); + DLIB_TEST(b3.is_set() == false); + DLIB_TEST(a3c.is_set() == false); + DLIB_TEST(b3c.is_set() == false); + + DLIB_TEST(!a3); + DLIB_TEST(!b3); + DLIB_TEST(!a3c); + DLIB_TEST(!b3c); + + DLIB_TEST(a4.is_set() == false); + DLIB_TEST(b4.is_set() == false); + DLIB_TEST(a4c.is_set() == false); + DLIB_TEST(b4c.is_set() == false); + + DLIB_TEST(!a4); + DLIB_TEST(!b4); + DLIB_TEST(!a4c); + DLIB_TEST(!b4c); + + a0.set(helper,&test_helper::go0); + a0c.set(helper,&test_helper::go0); + DLIB_TEST(a0.is_set() == true); + DLIB_TEST(a0c.is_set() == true); + DLIB_TEST(b0.is_set() == false); + DLIB_TEST(b0c.is_set() == false); + + DLIB_TEST(a0); + DLIB_TEST(a0c); + DLIB_TEST(!b0); + DLIB_TEST(!b0c); + + a0 = a0; + DLIB_TEST(a0 == a0); + DLIB_TEST(!(a0 != a0)); + DLIB_TEST(a0.is_set() == true); + DLIB_TEST(a0c.is_set() == true); + DLIB_TEST(b0.is_set() == false); + DLIB_TEST(b0c.is_set() == false); + + DLIB_TEST(a0); + DLIB_TEST(a0c); + DLIB_TEST(!b0); + DLIB_TEST(!b0c); + + swap(a0,b0); + swap(a0c,b0c); + DLIB_TEST(a0.is_set() == false); + DLIB_TEST(a0c.is_set() == false); + DLIB_TEST(b0.is_set() == true); + DLIB_TEST(b0c.is_set() == true); + + DLIB_TEST(!a0); + DLIB_TEST(!a0c); + DLIB_TEST(b0); + DLIB_TEST(b0c); + + a0 = b0; + DLIB_TEST(a0 == a0); + DLIB_TEST(a0 == b0); + DLIB_TEST(!(a0 != b0)); + DLIB_TEST(a0.is_set() == true); + DLIB_TEST(a0c.is_set() == false); + DLIB_TEST(b0.is_set() == true); + DLIB_TEST(b0c.is_set() == true); + + DLIB_TEST(a0 ); + DLIB_TEST(!a0c); + DLIB_TEST(b0); + DLIB_TEST(b0c); + + + a0.clear(); + a0c.clear(); + b0.clear(); + b0c.clear(); + DLIB_TEST(a0.is_set() == false); + DLIB_TEST(a0c.is_set() == false); + DLIB_TEST(b0.is_set() == false); + DLIB_TEST(b0c.is_set() == false); + + + a1.set(helper,&test_helper::go1); + a1c.set(helper,&test_helper::go1); + DLIB_TEST(a1.is_set() == true); + DLIB_TEST(a1c.is_set() == true); + DLIB_TEST(b1.is_set() == false); + DLIB_TEST(b1c.is_set() == false); + swap(a1,b1); + swap(a1c,b1c); + DLIB_TEST(a1.is_set() == false); + DLIB_TEST(a1c.is_set() == false); + DLIB_TEST(b1.is_set() == true); + DLIB_TEST(b1c.is_set() == true); + + DLIB_TEST(!a1); + DLIB_TEST(!a1c); + DLIB_TEST(b1); + DLIB_TEST(b1c); + + + a1 = b1; + DLIB_TEST(a1 == a1); + DLIB_TEST(a1 == b1); + DLIB_TEST(!(a1 != b1)); + DLIB_TEST(a1.is_set() == true); + DLIB_TEST(a1c.is_set() == false); + DLIB_TEST(b1.is_set() == true); + DLIB_TEST(b1c.is_set() == true); + + + a1.clear(); + a1c.clear(); + b1.clear(); + b1c.clear(); + DLIB_TEST(a1.is_set() == false); + DLIB_TEST(a1c.is_set() == false); + DLIB_TEST(b1.is_set() == false); + DLIB_TEST(b1c.is_set() == false); + + + a2.set(helper,&test_helper::go2); + a2c.set(helper,&test_helper::go2); + DLIB_TEST(a2.is_set() == true); + DLIB_TEST(a2c.is_set() == true); + DLIB_TEST(b2.is_set() == false); + DLIB_TEST(b2c.is_set() == false); + swap(a2,b2); + swap(a2c,b2c); + DLIB_TEST(a2.is_set() == false); + DLIB_TEST(a2c.is_set() == false); + DLIB_TEST(b2.is_set() == true); + DLIB_TEST(b2c.is_set() == true); + + DLIB_TEST(!a2); + DLIB_TEST(!a2c); + DLIB_TEST(b2); + DLIB_TEST(b2c); + if (b2) + { + } + else + { + DLIB_TEST(false); + } + + if (a2c) + { + DLIB_TEST(false); + } + else + { + DLIB_TEST(true); + } + + a2 = b2; + DLIB_TEST(a2 == a2); + DLIB_TEST(a2 == b2); + DLIB_TEST(!(a2 != b2)); + DLIB_TEST(a2.is_set() == true); + DLIB_TEST(a2c.is_set() == false); + DLIB_TEST(b2.is_set() == true); + DLIB_TEST(b2c.is_set() == true); + + a2.clear(); + a2c.clear(); + b2.clear(); + b2c.clear(); + DLIB_TEST(a2.is_set() == false); + DLIB_TEST(a2c.is_set() == false); + DLIB_TEST(b2.is_set() == false); + DLIB_TEST(b2c.is_set() == false); + + + a3.set(helper,&test_helper::go3); + a3c.set(helper,&test_helper::go3); + DLIB_TEST(a3.is_set() == true); + DLIB_TEST(a3c.is_set() == true); + DLIB_TEST(b3.is_set() == false); + DLIB_TEST(b3c.is_set() == false); + swap(a3,b3); + swap(a3c,b3c); + DLIB_TEST(a3.is_set() == false); + DLIB_TEST(a3c.is_set() == false); + DLIB_TEST(b3.is_set() == true); + DLIB_TEST(b3c.is_set() == true); + + a3 = b3; + DLIB_TEST(a3 == a3); + DLIB_TEST(a3 == b3); + DLIB_TEST(!(a3 != b3)); + DLIB_TEST(a3.is_set() == true); + DLIB_TEST(a3c.is_set() == false); + DLIB_TEST(b3.is_set() == true); + DLIB_TEST(b3c.is_set() == true); + + + a3.clear(); + a3c.clear(); + b3.clear(); + b3c.clear(); + DLIB_TEST(a3.is_set() == false); + DLIB_TEST(a3c.is_set() == false); + DLIB_TEST(b3.is_set() == false); + DLIB_TEST(b3c.is_set() == false); + + + a4.set(helper,&test_helper::go4); + a4c.set(helper,&test_helper::go4); + DLIB_TEST(a4.is_set() == true); + DLIB_TEST(a4c.is_set() == true); + DLIB_TEST(b4.is_set() == false); + DLIB_TEST(b4c.is_set() == false); + swap(a4,b4); + swap(a4c,b4c); + DLIB_TEST(a4.is_set() == false); + DLIB_TEST(a4c.is_set() == false); + DLIB_TEST(b4.is_set() == true); + DLIB_TEST(b4c.is_set() == true); + + a4 = b4; + a4 = b4; + a4 = b4; + a4 = b4; + DLIB_TEST(a4 == a4); + DLIB_TEST(a4 == b4); + DLIB_TEST(!(a4 != b4)); + DLIB_TEST(a4.is_set() == true); + DLIB_TEST(a4c.is_set() == false); + DLIB_TEST(b4.is_set() == true); + DLIB_TEST(b4c.is_set() == true); + + + a4.clear(); + a4c.clear(); + b4.clear(); + b4c.clear(); + DLIB_TEST(a4.is_set() == false); + DLIB_TEST(a4c.is_set() == false); + DLIB_TEST(b4.is_set() == false); + DLIB_TEST(b4c.is_set() == false); + + + a0.set(helper,&test_helper::go0); + a0c.set(helper,&test_helper::go0); + b0 = a0; + b0c = a0c; + helper.i = -1; + a0(); + DLIB_TEST(helper.i == 0); + helper.i = -1; + b0(); + DLIB_TEST(helper.i == 0); + helper.i = -1; + a0c(); + DLIB_TEST(helper.i == 0); + helper.i = -1; + b0c(); + DLIB_TEST(helper.i == 0); + + + a1.set(helper,&test_helper::go1); + a1c.set(helper,&test_helper::go1); + b1 = a1; + b1c = a1c; + helper.i = -1; + a1(1); + DLIB_TEST(helper.i == 1); + helper.i = -1; + b1(10); + DLIB_TEST(helper.i == 1*10); + helper.i = -1; + a1c(20); + DLIB_TEST(helper.i == 1*20); + helper.i = -1; + b1c(30); + DLIB_TEST(helper.i == 1*30); + + + a2.set(helper,&test_helper::go2); + a2c.set(helper,&test_helper::go2); + b2 = a2; + b2c = a2c; + helper.i = -1; + a2(1,2); + DLIB_TEST(helper.i == 2*1*2); + helper.i = -1; + b2(3,4); + DLIB_TEST(helper.i == 2*3*4); + helper.i = -1; + a2c(5,6); + DLIB_TEST(helper.i == 2*5*6); + helper.i = -1; + b2c(7,8); + DLIB_TEST(helper.i == 2*7*8); + + + a3.set(helper,&test_helper::go3); + a3c.set(helper,&test_helper::go3); + b3 = a3; + b3c = a3c; + helper.i = -1; + a3(1,2,3); + DLIB_TEST(helper.i == 3*1*2*3); + helper.i = -1; + b3(4,5,6); + DLIB_TEST(helper.i == 3*4*5*6); + helper.i = -1; + a3c(7,8,9); + DLIB_TEST(helper.i == 3*7*8*9); + helper.i = -1; + b3c(1,2,3); + DLIB_TEST(helper.i == 3*1*2*3); + + + a4.set(helper,&test_helper::go4); + a4c.set(helper,&test_helper::go4); + DLIB_TEST(a4 == a4c); + b4 = a4; + b4c = a4c; + helper.i = -1; + a4(1,2,3,4); + DLIB_TEST(helper.i == 4*1*2*3*4); + helper.i = -1; + b4(5,6,7,8); + DLIB_TEST(helper.i == 4*5*6*7*8); + helper.i = -1; + a4c(9,1,2,3); + DLIB_TEST(helper.i == 4*9*1*2*3); + helper.i = -1; + b4c(4,5,6,7); + DLIB_TEST(helper.i == 4*4*5*6*7); + + DLIB_TEST(a4 == b4); + DLIB_TEST(a4); + DLIB_TEST(a4 == b4); + a4.clear(); + DLIB_TEST(a4 != b4); + DLIB_TEST(!a4); + DLIB_TEST(a4 == 0); + DLIB_TEST(a4 == a4); + a4 = a4; + DLIB_TEST(a4 != b4); + DLIB_TEST(!a4); + DLIB_TEST(a4 == a4); + mfp_test_helper_other other; + a4.set(other,&mfp_test_helper_other::go4); + DLIB_TEST(a4 != b4); + DLIB_TEST(a4); + DLIB_TEST(a4 == a4); + a4.set(helper,&test_helper::go4); + DLIB_TEST(a4 == b4); + DLIB_TEST(a4); + DLIB_TEST(a4 == a4); + + + + } + + + + class member_function_pointer_tester : public tester + { + public: + member_function_pointer_tester ( + ) : + tester ("test_member_function_pointer", + "Runs tests on the member_function_pointer component.") + {} + + void perform_test ( + ) + { + member_function_pointer_kernel_test(); + member_function_pointer_kernel_test(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/metaprogramming.cpp b/ml/dlib/dlib/test/metaprogramming.cpp new file mode 100644 index 000000000..344d5ca3a --- /dev/null +++ b/ml/dlib/dlib/test/metaprogramming.cpp @@ -0,0 +1,94 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.metaprogramming"); + + + void metaprogramming_test ( + ) + /*! + ensures + - runs tests on template metaprogramming objects and functions for compliance with the specs + !*/ + { + + print_spinner(); + + DLIB_TEST(is_signed_type::value == true); + DLIB_TEST(is_signed_type::value == true); + DLIB_TEST(is_signed_type::value == true); + DLIB_TEST(is_signed_type::value == true); + DLIB_TEST(is_unsigned_type::value == false); + DLIB_TEST(is_unsigned_type::value == false); + DLIB_TEST(is_unsigned_type::value == false); + DLIB_TEST(is_unsigned_type::value == false); + + DLIB_TEST(is_unsigned_type::value == true); + DLIB_TEST(is_unsigned_type::value == true); + DLIB_TEST(is_unsigned_type::value == true); + DLIB_TEST(is_unsigned_type::value == true); + DLIB_TEST(is_signed_type::value == false); + DLIB_TEST(is_signed_type::value == false); + DLIB_TEST(is_signed_type::value == false); + DLIB_TEST(is_signed_type::value == false); + + + COMPILE_TIME_ASSERT(is_signed_type::value == true); + COMPILE_TIME_ASSERT(is_signed_type::value == true); + COMPILE_TIME_ASSERT(is_signed_type::value == true); + COMPILE_TIME_ASSERT(is_signed_type::value == true); + COMPILE_TIME_ASSERT(is_unsigned_type::value == false); + COMPILE_TIME_ASSERT(is_unsigned_type::value == false); + COMPILE_TIME_ASSERT(is_unsigned_type::value == false); + COMPILE_TIME_ASSERT(is_unsigned_type::value == false); + + COMPILE_TIME_ASSERT(is_unsigned_type::value == true); + COMPILE_TIME_ASSERT(is_unsigned_type::value == true); + COMPILE_TIME_ASSERT(is_unsigned_type::value == true); + COMPILE_TIME_ASSERT(is_unsigned_type::value == true); + COMPILE_TIME_ASSERT(is_signed_type::value == false); + COMPILE_TIME_ASSERT(is_signed_type::value == false); + COMPILE_TIME_ASSERT(is_signed_type::value == false); + COMPILE_TIME_ASSERT(is_signed_type::value == false); + + + } + + + + + class metaprogramming_tester : public tester + { + public: + metaprogramming_tester ( + ) : + tester ("test_metaprogramming", + "Runs tests on the metaprogramming objects and functions.") + {} + + void perform_test ( + ) + { + metaprogramming_test(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/mpc.cpp b/ml/dlib/dlib/test/mpc.cpp new file mode 100644 index 000000000..c0a98dd17 --- /dev/null +++ b/ml/dlib/dlib/test/mpc.cpp @@ -0,0 +1,346 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include + +#include +#include +#include "tester.h" + +namespace +{ + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.mpc"); + + template < + typename EXP1, + typename EXP2, + typename T, long NR, long NC, typename MM, typename L + > + unsigned long solve_qp_box_using_smo ( + const matrix_exp& _Q, + const matrix_exp& _b, + matrix& alpha, + matrix& lower, + matrix& upper, + T eps, + unsigned long max_iter + ) + /*! + ensures + - solves: 0.5*trans(x)*Q*x + trans(b)*x where x is box constrained. + !*/ + { + const_temp_matrix Q(_Q); + const_temp_matrix b(_b); + //cout << "IN QP SOLVER" << endl; + //cout << "max eig: " << max(real_eigenvalues(Q)) << endl; + //cout << "min eig: " << min(real_eigenvalues(Q)) << endl; + //cout << "Q: \n" << Q << endl; + //cout << "b: \n" << b << endl; + + // make sure requires clause is not broken + DLIB_ASSERT(Q.nr() == Q.nc() && + alpha.size() == lower.size() && + alpha.size() == upper.size() && + is_col_vector(b) && + is_col_vector(alpha) && + is_col_vector(lower) && + is_col_vector(upper) && + b.size() == alpha.size() && + b.size() == Q.nr() && + alpha.size() > 0 && + 0 <= min(alpha-lower) && + 0 <= max(upper-alpha) && + eps > 0 && + max_iter > 0, + "\t unsigned long solve_qp_box_using_smo()" + << "\n\t Invalid arguments were given to this function" + << "\n\t Q.nr(): " << Q.nr() + << "\n\t Q.nc(): " << Q.nc() + << "\n\t is_col_vector(b): " << is_col_vector(b) + << "\n\t is_col_vector(alpha): " << is_col_vector(alpha) + << "\n\t is_col_vector(lower): " << is_col_vector(lower) + << "\n\t is_col_vector(upper): " << is_col_vector(upper) + << "\n\t b.size(): " << b.size() + << "\n\t alpha.size(): " << alpha.size() + << "\n\t lower.size(): " << lower.size() + << "\n\t upper.size(): " << upper.size() + << "\n\t Q.nr(): " << Q.nr() + << "\n\t min(alpha-lower): " << min(alpha-lower) + << "\n\t max(upper-alpha): " << max(upper-alpha) + << "\n\t eps: " << eps + << "\n\t max_iter: " << max_iter + ); + + + // Compute f'(alpha) (i.e. the gradient of f(alpha)) for the current alpha. + matrix df = Q*alpha + b; + matrix QQ = reciprocal_max(diag(Q)); + + + unsigned long iter = 0; + for (; iter < max_iter; ++iter) + { + T max_df = 0; + long best_r =0; + for (long r = 0; r < Q.nr(); ++r) + { + if (alpha(r) <= lower(r) && df(r) > 0) + ;//alpha(r) = lower(r); + else if (alpha(r) >= upper(r) && df(r) < 0) + ;//alpha(r) = upper(r); + else if (std::abs(df(r)) > max_df) + { + best_r = r; + max_df = std::abs(df(r)); + } + } + + //for (long r = 0; r < Q.nr(); ++r) + long r = best_r; + { + + const T old_alpha = alpha(r); + alpha(r) = -(df(r)-Q(r,r)*alpha(r))*QQ(r); + if (alpha(r) < lower(r)) + alpha(r) = lower(r); + else if (alpha(r) > upper(r)) + alpha(r) = upper(r); + + const T delta = old_alpha-alpha(r); + + // Now update the gradient. We will perform the equivalent of: df = Q*alpha + b; + for(long k = 0; k < df.nr(); ++k) + df(k) -= Q(r,k)*delta; + } + + if (max_df < eps) + break; + } + //cout << "df: \n" << trans(df) << endl; + //cout << "objective value: " << 0.5*trans(alpha)*Q*alpha + trans(b)*alpha << endl; + + return iter+1; + } + +// ---------------------------------------------------------------------------------------- + + namespace impl_mpc + { + template + void pack( + matrix& out, + const std::vector >& item + ) + { + DLIB_CASSERT(item.size() != 0,""); + out.set_size(item.size()*item[0].size()); + long j = 0; + for (unsigned long i = 0; i < item.size(); ++i) + for (long r = 0; r < item[i].size(); ++r) + out(j++) = item[i](r); + } + + template + void pack( + matrix& out, + const matrix& item, + const long num + ) + { + out.set_size(item.size()*num); + long j = 0; + for (long r = 0; r < num; ++r) + for (long i = 0; i < item.size(); ++i) + out(j++) = item(i); + } + + template + void unpack( + std::vector >& out, + const matrix& item + ) + { + DLIB_CASSERT(out.size() != 0,""); + DLIB_CASSERT((long)out.size()*out[0].size() == item.size(),""); + long j = 0; + for (unsigned long i = 0; i < out.size(); ++i) + for (long r = 0; r < out[i].size(); ++r) + out[i](r) = item(j++); + } + } + + template + unsigned long solve_linear_mpc ( + const matrix& A, + const matrix& B, + const matrix& C, + const matrix& Q, + const matrix& R, + const matrix& _lower, + const matrix& _upper, + const std::vector >& target, + const matrix& initial_state, + std::vector >& controls // input and output + ) + { + using namespace impl_mpc; + DLIB_CASSERT(target.size() == controls.size(),""); + + matrix K(B.nr()*controls.size(), B.nc()*controls.size()); + matrix M(B.nr()*controls.size()); + + // compute powers of A: Apow[i] == A^i + std::vector > Apow(controls.size()); + Apow[0] = identity_matrix(A); + for (unsigned long i = 1; i < Apow.size(); ++i) + Apow[i] = A*Apow[i-1]; + + // fill in K + K = 0; + for (unsigned long r = 0; r < controls.size(); ++r) + for (unsigned long c = 0; c <= r; ++c) + set_subm(K,r*B.nr(),c*B.nc(), B.nr(), B.nc()) = Apow[r-c]*B; + + // fill in M + set_subm(M,0*A.nr(),0,A.nr(),1) = A*initial_state + C; + for (unsigned long i = 1; i < controls.size(); ++i) + set_subm(M,i*A.nr(),0,A.nr(),1) = A*subm(M,(i-1)*A.nr(),0,A.nr(),1) + C; + + //cout << "M: \n" << M << endl; + //cout << "K: \n" << K << endl; + + matrix t, v, lower, upper; + pack(t, target); + pack(v, controls); + pack(lower, _lower, controls.size()); + pack(upper, _upper, controls.size()); + + + matrix QQ(K.nr(),K.nr()), RR(K.nc(),K.nc()); + QQ = 0; + RR = 0; + for (unsigned long c = 0; c < controls.size(); ++c) + { + set_subm(QQ,c*Q.nr(),c*Q.nr(),Q.nr(),Q.nr()) = diagm(Q); + set_subm(RR,c*R.nr(),c*R.nr(),R.nr(),R.nr()) = diagm(R); + } + + matrix m1 = trans(K)*QQ*K+RR; + matrix m2 = trans(K)*QQ*(M-t); + + + // run the solver... + unsigned long iter; + iter = solve_qp_box_using_smo( + m1, + m2, + v, + lower, + upper, + 0.00000001, + 100000); + + //cout << "iterations: " << iter << endl; + + unpack(controls, v); + return iter; + } + + + + class test_mpc : public tester + { + public: + test_mpc ( + ) : + tester ("test_mpc", + "Runs tests on the mpc object.") + {} + + void perform_test ( + ) + { + // a basic position + velocity model + matrix A; + A = 1, 1, + 0, 1; + matrix B, C; + B = 0, + 1; + + C = 0.02,0.1; // no constant bias + + matrix Q; + Q = 2, 0; // only care about getting the position right + matrix R, lower, upper; + R = 1; + + lower = -0.2; + upper = 0.2; + + std::vector > controls(30); + std::vector > target(30); + for (unsigned long i = 0; i < controls.size(); ++i) + { + controls[i] = 0; + target[i] = 0; + } + + mpc<2,1,30> solver(A,B,C,Q,R,lower,upper); + solver.set_epsilon(0.00000001); + solver.set_max_iterations(10000); + matrix initial_state; + initial_state = 0; + initial_state(0) = 5; + for (int i = 0; i < 30; ++i) + { + print_spinner(); + matrix control = solver(initial_state); + + for (unsigned long i = 1; i < controls.size(); ++i) + controls[i-1] = controls[i]; + + // Compute the correct control via SMO and make sure it matches. + solve_linear_mpc(A,B,C,Q,R,lower,upper, target, initial_state, controls); + dlog << LINFO << "ERROR: " << length(control-controls[0]); + DLIB_TEST(length(control-controls[0]) < 1e-7); + + initial_state = A*initial_state + B*control + C; + //cout << control(0) << "\t" << trans(initial_state); + } + + { + // also just generally test our QP solver. + matrix Q = gaussian_randm(20,20,5); + Q = Q*trans(Q); + + matrix b = randm(20,1)-0.5; + matrix alpha, lower, upper, alpha2; + alpha = 0; + alpha2 = 0; + lower = -4; + upper = 3; + + solve_qp_box_using_smo(Q,b,alpha,lower, upper, 0.000000001, 500000); + solve_qp_box_constrained(Q,b,alpha2,lower, upper, 0.000000001, 50000); + dlog << LINFO << trans(alpha); + dlog << LINFO << trans(alpha2); + dlog << LINFO << "objective value: " << 0.5*trans(alpha)*Q*alpha + trans(b)*alpha; + dlog << LINFO << "objective value2: " << 0.5*trans(alpha2)*Q*alpha + trans(b)*alpha2; + DLIB_TEST_MSG(max(abs(alpha-alpha2)) < 1e-7, max(abs(alpha-alpha2))); + } + } + } a; + +} + + + + diff --git a/ml/dlib/dlib/test/multithreaded_object.cpp b/ml/dlib/dlib/test/multithreaded_object.cpp new file mode 100644 index 000000000..96f8ea26e --- /dev/null +++ b/ml/dlib/dlib/test/multithreaded_object.cpp @@ -0,0 +1,321 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include + +#include +#include "tester.h" + +namespace +{ + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.multithreaded_object"); + + dlib::mutex cm; + int count; + + class test1 : multithreaded_object + { + public: + test1 () + { + DLIB_TEST(number_of_threads_registered() == 0); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + clear(); + DLIB_TEST(number_of_threads_registered() == 0); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + } + + ~test1 () + { + DLIB_TEST(number_of_threads_registered() == 0); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + stop(); + DLIB_TEST(number_of_threads_registered() == 0); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + wait(); + DLIB_TEST(number_of_threads_registered() == 0); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + } + + private: + }; + + class test2 : private multithreaded_object + { + public: + test2() + { + DLIB_TEST(number_of_threads_registered() == 0); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + register_thread(*this,&test2::thread); + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + clear(); + DLIB_TEST(number_of_threads_registered() == 0); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + register_thread(*this,&test2::thread); + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + } + + ~test2() + { + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + stop(); + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + wait(); + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + } + + private: + + void thread() + { + auto_mutex M(cm); + ++count; + } + + }; + + class test3_c1 : private multithreaded_object + { + public: + test3_c1() + { + DLIB_TEST(number_of_threads_registered() == 0); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + register_thread(*this,&test3_c1::thread); + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + start(); + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(is_running() == true); + } + + ~test3_c1() + { + DLIB_TEST(number_of_threads_registered() == 1); + stop(); + DLIB_TEST(is_running() == false); + DLIB_TEST(number_of_threads_registered() == 1); + wait(); + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + } + + private: + + void thread() + { + cm.lock(); + ++count; + cm.unlock(); + // wait until we are supposed to stop + while (!should_stop()) + dlib::sleep(1); + } + + }; + + class test4_c2 : private multithreaded_object + { + public: + test4_c2() + { + DLIB_TEST(number_of_threads_registered() == 0); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + register_thread(*this,&test4_c2::thread); + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + start(); + DLIB_TEST(number_of_threads_registered() == 1); + DLIB_TEST(number_of_threads_alive() == 1); + DLIB_TEST(is_running() == true); + register_thread(*this,&test4_c2::thread); + DLIB_TEST(number_of_threads_registered() == 2); + DLIB_TEST(number_of_threads_alive() == 2); + DLIB_TEST(is_running() == true); + start(); + DLIB_TEST(number_of_threads_registered() == 2); + DLIB_TEST(number_of_threads_alive() == 2); + DLIB_TEST(is_running() == true); + start(); + DLIB_TEST(number_of_threads_registered() == 2); + DLIB_TEST(number_of_threads_alive() == 2); + DLIB_TEST(is_running() == true); + start(); + DLIB_TEST(number_of_threads_registered() == 2); + DLIB_TEST(number_of_threads_alive() == 2); + DLIB_TEST(is_running() == true); + start(); + DLIB_TEST(number_of_threads_registered() == 2); + DLIB_TEST(number_of_threads_alive() == 2); + DLIB_TEST(is_running() == true); + pause(); + DLIB_TEST(number_of_threads_registered() == 2); + DLIB_TEST(number_of_threads_alive() == 2); + DLIB_TEST(is_running() == false); + } + + ~test4_c2() + { + try + { + DLIB_TEST(number_of_threads_registered() == 2); + DLIB_TEST(number_of_threads_alive() == 2); + DLIB_TEST_MSG(is_running() == false,"is_running(): " << is_running()); + stop(); + DLIB_TEST(number_of_threads_registered() == 2); + DLIB_TEST(is_running() == false); + wait(); + DLIB_TEST(number_of_threads_registered() == 2); + DLIB_TEST(number_of_threads_alive() == 0); + DLIB_TEST(is_running() == false); + } + catch(std::exception& e) + { + std::cerr << e.what() << std::endl; + exit(1); + } + } + + private: + + void thread() + { + auto_mutex M(cm); + ++count; + while (!should_stop()) + dlib::sleep(10); + } + + }; + + + class test5 : private multithreaded_object + { + public: + test5() + { + register_thread(*this,&test5::thread1); + register_thread(*this,&test5::thread2); + register_thread(*this,&test5::thread3); + register_thread(*this,&test5::thread3); + start(); + } + + ~test5() + { + stop(); + wait(); + } + + private: + + void thread1() + { + while (!should_stop()) + dlib::sleep(10); + } + + void thread2() + { + while (!should_stop()) + dlib::sleep(10); + } + + void thread3() + { + while (!should_stop()) + dlib::sleep(10); + } + + }; + + + void multithreaded_object_test ( + ) + /*! + ensures + - runs tests on dlib::multithreaded_object for compliance with the specs + !*/ + { + + count = 0; + + for (int i = 0; i < 5; ++i) + { + { + test1 a1; + test2 a2; + test3_c1 a3; + test4_c2 a4; + test5 a5; + } + DLIB_TEST(count == (i+1)*3); + print_spinner(); + } + count = 0; + + for (int i = 0; i < 5; ++i) + { + { + test1 a1; + test2 a2; + test3_c1 a3; + test4_c2 a4; + test5 a5; + dlib::sleep(50); + } + DLIB_TEST(count == (i+1)*3); + print_spinner(); + } + } + + + class multithreaded_object_tester : public tester + { + public: + multithreaded_object_tester ( + ) : + tester ("test_multithreaded_object", + "Runs tests on the multithreaded_object component.") + {} + + void perform_test ( + ) + { + multithreaded_object_test(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/numerical_integration.cpp b/ml/dlib/dlib/test/numerical_integration.cpp new file mode 100644 index 000000000..d0e247623 --- /dev/null +++ b/ml/dlib/dlib/test/numerical_integration.cpp @@ -0,0 +1,228 @@ +// Copyright (C) 2013 Steve Taylor (steve98654@gmail.com) +// License: Boost Software License See LICENSE.txt for the full license. + +// This function test battery is given in: +// +// Test functions taken from Pedro Gonnet's dissertation at ETH: +// Adaptive Quadrature Re-Revisited +// http://e-collection.library.ethz.ch/eserv/eth:65/eth-65-02.pdf + +#include +#include +#include +#include +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.numerical_integration"); + + class numerical_integration_tester : public tester + { + public: + numerical_integration_tester ( + ) : + tester ("test_numerical_integration", + "Runs tests on the numerical integration function.", + 0 + ) + {} + + void perform_test() + { + + dlog < m; + double tol = 1e-10; + double eps = 1e-8; + + m(0) = integrate_function_adapt_simp(&gg1, 0.0, 1.0, tol); + m(1) = integrate_function_adapt_simp(&gg2, 0.0, 1.0, tol); + m(2) = integrate_function_adapt_simp(&gg3, 0.0, 1.0, tol); + m(3) = integrate_function_adapt_simp(&gg4, 0.0, 1.0, tol); + m(4) = integrate_function_adapt_simp(&gg5, -1.0, 1.0, tol); + m(5) = integrate_function_adapt_simp(&gg6, 0.0, 1.0, tol); + m(6) = integrate_function_adapt_simp(&gg7, 0.0, 1.0, tol); + m(7) = integrate_function_adapt_simp(&gg8, 0.0, 1.0, tol); + m(8) = integrate_function_adapt_simp(&gg9, 0.0, 1.0, tol); + m(9) = integrate_function_adapt_simp(&gg10, 0.0, 1.0, tol); + m(10) = integrate_function_adapt_simp(&gg11, 0.0, 1.0, tol); + m(11) = integrate_function_adapt_simp(&gg12, 1e-6, 1.0, tol); + m(12) = integrate_function_adapt_simp(&gg13, 0.0, 10.0, tol); + m(13) = integrate_function_adapt_simp(&gg14, 0.0, 10.0, tol); + m(14) = integrate_function_adapt_simp(&gg15, 0.0, 10.0, tol); + m(15) = integrate_function_adapt_simp(&gg16, 0.01, 1.0, tol); + m(16) = integrate_function_adapt_simp(&gg17, 0.0, pi, tol); + m(17) = integrate_function_adapt_simp(&gg18, 0.0, 1.0, tol); + m(18) = integrate_function_adapt_simp(&gg19, -1.0, 1.0, tol); + m(19) = integrate_function_adapt_simp(&gg20, 0.0, 1.0, tol); + m(20) = integrate_function_adapt_simp(&gg21, 0.0, 1.0, tol); + m(21) = integrate_function_adapt_simp(&gg22, 0.0, 5.0, tol); + + // Here we compare the approximated integrals against + // highly accurate approximations generated either from + // the exact integral values or Mathematica's NIntegrate + // function using a working precision of 20. + + DLIB_TEST(abs(m(0) - 1.7182818284590452354) < 1e-11); + DLIB_TEST(abs(m(1) - 0.7000000000000000000) < eps); + DLIB_TEST(abs(m(2) - 0.6666666666666666667) < eps); + DLIB_TEST(abs(m(3) - 0.2397141133444008336) < eps); + DLIB_TEST(abs(m(4) - 1.5822329637296729331) < 1e-11); + DLIB_TEST(abs(m(5) - 0.4000000000000000000) < eps); + DLIB_TEST(abs(m(6) - 2.0000000000000000000) < 1e-4); + DLIB_TEST(abs(m(7) - 0.8669729873399110375) < eps); + DLIB_TEST(abs(m(8) - 1.1547005383792515290) < eps); + DLIB_TEST(abs(m(9) - 0.6931471805599453094) < eps); + DLIB_TEST(abs(m(10) - 0.3798854930417224753) < eps); + DLIB_TEST(abs(m(11) - 0.7775036341124982763) < eps); + DLIB_TEST(abs(m(12) - 0.5000000000000000000) < eps); + DLIB_TEST(abs(m(13) - 1.0000000000000000000) < eps); + DLIB_TEST(abs(m(14) - 0.4993633810764567446) < eps); + DLIB_TEST(abs(m(15) - 0.1121393035410217 ) < eps); + DLIB_TEST(abs(m(16) - 0.2910187828600526985) < eps); + DLIB_TEST(abs(m(17) + 0.4342944819032518276) < 1e-5); + DLIB_TEST(abs(m(18) - 1.56439644406905 ) < eps); + DLIB_TEST(abs(m(19) - 0.1634949430186372261) < eps); + DLIB_TEST(abs(m(20) - 0.0134924856494677726) < eps); + } + + static double gg1(double x) + { + return pow(e,x); + } + + static double gg2(double x) + { + if(x > 0.3) + { + return 1.0; + } + else + { + return 0; + } + } + + static double gg3(double x) + { + return pow(x,0.5); + } + + static double gg4(double x) + { + return 23.0/25.0*cosh(x)-cos(x); + } + + static double gg5(double x) + { + return 1/(pow(x,4) + pow(x,2) + 0.9); + } + + static double gg6(double x) + { + return pow(x,1.5); + } + + static double gg7(double x) + { + return pow(x,-0.5); + } + + static double gg8(double x) + { + return 1/(1 + pow(x,4)); + } + + static double gg9(double x) + { + return 2/(2 + sin(10*pi*x)); + } + + static double gg10(double x) + { + return 1/(1+x); + } + + static double gg11(double x) + { + return 1.0/(1 + pow(e,x)); + } + + static double gg12(double x) + { + return x/(pow(e,x)-1.0); + } + + static double gg13(double x) + { + return sqrt(50.0)*pow(e,-50.0*pi*x*x); + } + + static double gg14(double x) + { + return 25.0*pow(e,-25.0*x); + } + + static double gg15(double x) + { + return 50.0/(pi*(2500.0*x*x+1)); + } + + static double gg16(double x) + { + return 50.0*pow((sin(50.0*pi*x)/(50.0*pi*x)),2); + } + + static double gg17(double x) + { + return cos(cos(x)+3*sin(x)+2*cos(2*x)+3*cos(3*x)); + } + + static double gg18(double x) + { + return log10(x); + } + + static double gg19(double x) + { + return 1/(1.005+x*x); + } + + static double gg20(double x) + { + return 1/cosh(20.0*(x-1.0/5.0)) + 1/cosh(400.0*(x-2.0/5.0)) + + 1/cosh(8000.0*(x-3.0/5.0)); + } + + static double gg21(double x) + { + return 1.0/(1.0+(230.0*x-30.0)*(230.0*x-30.0)); + } + + static double gg22(double x) + { + if(x < 1) + { + return (x + 1.0); + } + else if(x >= 1 && x <= 3) + { + return (3.0 - x); + } + else + { + return 2.0; + } + } + + }; + + numerical_integration_tester a; +} + diff --git a/ml/dlib/dlib/test/object_detector.cpp b/ml/dlib/dlib/test/object_detector.cpp new file mode 100644 index 000000000..fdb72f520 --- /dev/null +++ b/ml/dlib/dlib/test/object_detector.cpp @@ -0,0 +1,1028 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include "tester.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.object_detector"); + +// ---------------------------------------------------------------------------------------- + + struct funny_image + { + array2d img; + long nr() const { return img.nr(); } + long nc() const { return img.nc(); } + }; + + void swap(funny_image& a, funny_image& b) + { + a.img.swap(b.img); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type, + typename detector_type + > + void validate_some_object_detector_stuff ( + const image_array_type& images, + detector_type& detector, + double eps = 1e-10 + ) + { + for (unsigned long i = 0; i < images.size(); ++i) + { + std::vector dets = detector(images[i]); + std::vector > dets2; + + detector(images[i], dets2); + + matrix psi(detector.get_w().size()); + matrix psi2(detector.get_w().size()); + const double thresh = detector.get_w()(detector.get_w().size()-1); + + DLIB_TEST(dets.size() == dets2.size()); + for (unsigned long j = 0; j < dets.size(); ++j) + { + DLIB_TEST(dets[j] == dets2[j].second); + + const full_object_detection fdet = detector.get_scanner().get_full_object_detection(dets[j], detector.get_w()); + psi = 0; + detector.get_scanner().get_feature_vector(fdet, psi); + + double check_score = dot(psi,detector.get_w()) - thresh; + DLIB_TEST_MSG(std::abs(check_score - dets2[j].first) < eps, std::abs(check_score - dets2[j].first) << " check_score: "<< check_score); + } + + } + } + +// ---------------------------------------------------------------------------------------- + + class very_simple_feature_extractor : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a feature extractor which goes to every pixel in an image and + produces a 32 dimensional feature vector. This vector is an indicator vector + which records the pattern of pixel values in a 4-connected region. So it should + be able to distinguish basic things like whether or not a location falls on the + corner of a white box, on an edge, in the middle, etc. + + + Note that this object also implements the interface defined in dlib/image_keypoint/hashed_feature_image_abstract.h. + This means all the member functions in this object are supposed to behave as + described in the hashed_feature_image specification. So when you define your own + feature extractor objects you should probably refer yourself to that documentation + in addition to reading this example program. + !*/ + + + public: + + inline void load ( + const funny_image& img_ + ) + { + const array2d& img = img_.img; + + feat_image.set_size(img.nr(), img.nc()); + assign_all_pixels(feat_image,0); + for (long r = 1; r+1 < img.nr(); ++r) + { + for (long c = 1; c+1 < img.nc(); ++c) + { + unsigned char f = 0; + if (img[r][c]) f |= 0x1; + if (img[r][c+1]) f |= 0x2; + if (img[r][c-1]) f |= 0x4; + if (img[r+1][c]) f |= 0x8; + if (img[r-1][c]) f |= 0x10; + + // Store the code value for the pattern of pixel values in the 4-connected + // neighborhood around this row and column. + feat_image[r][c] = f; + } + } + } + + inline void load ( + const array2d& img + ) + { + feat_image.set_size(img.nr(), img.nc()); + assign_all_pixels(feat_image,0); + for (long r = 1; r+1 < img.nr(); ++r) + { + for (long c = 1; c+1 < img.nc(); ++c) + { + unsigned char f = 0; + if (img[r][c]) f |= 0x1; + if (img[r][c+1]) f |= 0x2; + if (img[r][c-1]) f |= 0x4; + if (img[r+1][c]) f |= 0x8; + if (img[r-1][c]) f |= 0x10; + + // Store the code value for the pattern of pixel values in the 4-connected + // neighborhood around this row and column. + feat_image[r][c] = f; + } + } + } + + inline size_t size () const { return feat_image.size(); } + inline long nr () const { return feat_image.nr(); } + inline long nc () const { return feat_image.nc(); } + + inline long get_num_dimensions ( + ) const + { + // Return the dimensionality of the vectors produced by operator() + return 32; + } + + typedef std::vector > descriptor_type; + + inline const descriptor_type& operator() ( + long row, + long col + ) const + /*! + requires + - 0 <= row < nr() + - 0 <= col < nc() + ensures + - returns a sparse vector which describes the image at the given row and column. + In particular, this is a vector that is 0 everywhere except for one element. + !*/ + { + feat.clear(); + const unsigned long only_nonzero_element_index = feat_image[row][col]; + feat.push_back(make_pair(only_nonzero_element_index,1.0)); + return feat; + } + + // This block of functions is meant to provide a way to map between the row/col space taken by + // this object's operator() function and the images supplied to load(). In this example it's trivial. + // However, in general, you might create feature extractors which don't perform extraction at every + // possible image location (e.g. the hog_image) and thus result in some more complex mapping. + inline const rectangle get_block_rect ( long row, long col) const { return centered_rect(col,row,3,3); } + inline const point image_to_feat_space ( const point& p) const { return p; } + inline const rectangle image_to_feat_space ( const rectangle& rect) const { return rect; } + inline const point feat_to_image_space ( const point& p) const { return p; } + inline const rectangle feat_to_image_space ( const rectangle& rect) const { return rect; } + + inline friend void serialize ( const very_simple_feature_extractor& item, std::ostream& out) { serialize(item.feat_image, out); } + inline friend void deserialize ( very_simple_feature_extractor& item, std::istream& in ) { deserialize(item.feat_image, in); } + + void copy_configuration ( const very_simple_feature_extractor& ){} + + private: + array2d feat_image; + + // This variable doesn't logically contribute to the state of this object. It is here + // only to avoid returning a descriptor_type object by value inside the operator() method. + mutable descriptor_type feat; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + void make_simple_test_data ( + image_array_type& images, + std::vector >& object_locations + ) + { + images.clear(); + object_locations.clear(); + + images.resize(3); + images[0].set_size(400,400); + images[1].set_size(400,400); + images[2].set_size(400,400); + + // set all the pixel values to black + assign_all_pixels(images[0], 0); + assign_all_pixels(images[1], 0); + assign_all_pixels(images[2], 0); + + // Now make some squares and draw them onto our black images. All the + // squares will be 70 pixels wide and tall. + + std::vector temp; + temp.push_back(centered_rect(point(100,100), 70,70)); + fill_rect(images[0],temp.back(),255); // Paint the square white + temp.push_back(centered_rect(point(200,300), 70,70)); + fill_rect(images[0],temp.back(),255); // Paint the square white + object_locations.push_back(temp); + + temp.clear(); + temp.push_back(centered_rect(point(140,200), 70,70)); + fill_rect(images[1],temp.back(),255); // Paint the square white + temp.push_back(centered_rect(point(303,200), 70,70)); + fill_rect(images[1],temp.back(),255); // Paint the square white + object_locations.push_back(temp); + + temp.clear(); + temp.push_back(centered_rect(point(123,121), 70,70)); + fill_rect(images[2],temp.back(),255); // Paint the square white + object_locations.push_back(temp); + + // corrupt each image with random noise just to make this a little more + // challenging + dlib::rand rnd; + for (unsigned long i = 0; i < images.size(); ++i) + { + for (long r = 0; r < images[i].nr(); ++r) + { + for (long c = 0; c < images[i].nc(); ++c) + { + typedef typename image_array_type::type image_type; + typedef typename image_type::type type; + images[i][r][c] = (type)put_in_range(0,255,images[i][r][c] + 10*rnd.get_random_gaussian()); + } + } + } + } + + template < + typename image_array_type + > + void make_simple_test_data ( + image_array_type& images, + std::vector >& object_locations + ) + { + images.clear(); + object_locations.clear(); + + + images.resize(3); + images[0].set_size(400,400); + images[1].set_size(400,400); + images[2].set_size(400,400); + + // set all the pixel values to black + assign_all_pixels(images[0], 0); + assign_all_pixels(images[1], 0); + assign_all_pixels(images[2], 0); + + // Now make some squares and draw them onto our black images. All the + // squares will be 70 pixels wide and tall. + const int shrink = 0; + std::vector temp; + + rectangle rect = centered_rect(point(100,100), 70,71); + std::vector movable_parts; + movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); + temp.push_back(full_object_detection(rect, movable_parts)); + fill_rect(images[0],rect,255); // Paint the square white + + rect = centered_rect(point(200,200), 70,71); + movable_parts.clear(); + movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); + temp.push_back(full_object_detection(rect, movable_parts)); + fill_rect(images[0],rect,255); // Paint the square white + + object_locations.push_back(temp); + // ------------------------------------ + temp.clear(); + + rect = centered_rect(point(140,200), 70,71); + movable_parts.clear(); + movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); + temp.push_back(full_object_detection(rect, movable_parts)); + fill_rect(images[1],rect,255); // Paint the square white + + + rect = centered_rect(point(303,200), 70,71); + movable_parts.clear(); + movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); + temp.push_back(full_object_detection(rect, movable_parts)); + fill_rect(images[1],rect,255); // Paint the square white + + object_locations.push_back(temp); + // ------------------------------------ + temp.clear(); + + rect = centered_rect(point(123,121), 70,71); + movable_parts.clear(); + movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); + movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); + temp.push_back(full_object_detection(rect, movable_parts)); + fill_rect(images[2],rect,255); // Paint the square white + + object_locations.push_back(temp); + + // corrupt each image with random noise just to make this a little more + // challenging + dlib::rand rnd; + for (unsigned long i = 0; i < images.size(); ++i) + { + for (long r = 0; r < images[i].nr(); ++r) + { + for (long c = 0; c < images[i].nc(); ++c) + { + typedef typename image_array_type::type image_type; + typedef typename image_type::type type; + images[i][r][c] = (type)put_in_range(0,255,images[i][r][c] + 40*rnd.get_random_gaussian()); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_fhog_pyramid ( + ) + { + print_spinner(); + dlog << LINFO << "test_fhog_pyramid()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef scan_fhog_pyramid > image_scanner_type; + image_scanner_type scanner; + scanner.set_detection_window_size(35,35); + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + trainer.set_overlap_tester(test_box_overlap(0,0)); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + validate_some_object_detector_stuff(images, detector, 1e-6); + } + + { + std::vector > detectors; + detectors.push_back(detector); + detectors.push_back(detector); + detectors.push_back(detector); + + std::vector dets1 = evaluate_detectors(detectors, images[0]); + std::vector dets2 = detector(images[0]); + DLIB_TEST(dets1.size() > 0); + DLIB_TEST(dets2.size()*3 == dets1.size()); + dlib::set::kernel_1a_c d1, d2; + for (unsigned long i = 0; i < dets1.size(); ++i) + { + if (!d1.is_member(dets1[i])) + d1.add(dets1[i]); + } + for (unsigned long i = 0; i < dets2.size(); ++i) + { + if (!d2.is_member(dets2[i])) + d2.add(dets2[i]); + } + DLIB_TEST(d1.size() == d2.size()); + DLIB_TEST(set_intersection_size(d1,d2) == d1.size()); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_1 ( + ) + { + print_spinner(); + dlog << LINFO << "test_1()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef hashed_feature_image > feature_extractor_type; + typedef scan_image_pyramid, feature_extractor_type> image_scanner_type; + image_scanner_type scanner; + const rectangle object_box = compute_box_dimensions(1,35*35); + scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2)); + setup_hashed_features(scanner, images, 9); + use_uniform_feature_weights(scanner); + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + trainer.set_overlap_tester(test_box_overlap(0,0)); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + validate_some_object_detector_stuff(images, detector); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_1_boxes ( + ) + { + print_spinner(); + dlog << LINFO << "test_1_boxes()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef hashed_feature_image > feature_extractor_type; + typedef scan_image_boxes image_scanner_type; + image_scanner_type scanner; + setup_hashed_features(scanner, images, 9); + use_uniform_feature_weights(scanner); + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + trainer.set_overlap_tester(test_box_overlap(0,0)); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + validate_some_object_detector_stuff(images, detector); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_1m ( + ) + { + print_spinner(); + dlog << LINFO << "test_1m()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef hashed_feature_image > feature_extractor_type; + typedef scan_image_pyramid, feature_extractor_type> image_scanner_type; + image_scanner_type scanner; + const rectangle object_box = compute_box_dimensions(1,35*35); + std::vector mboxes; + const int mbox_size = 20; + mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); + mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); + mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); + mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); + scanner.add_detection_template(object_box, create_grid_detection_template(object_box,1,1), mboxes); + setup_hashed_features(scanner, images, 9); + use_uniform_feature_weights(scanner); + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + trainer.set_overlap_tester(test_box_overlap(0,0)); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + validate_some_object_detector_stuff(images, detector); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_1_fine_hog ( + ) + { + print_spinner(); + dlog << LINFO << "test_1_fine_hog()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef hashed_feature_image > feature_extractor_type; + typedef scan_image_pyramid, feature_extractor_type> image_scanner_type; + image_scanner_type scanner; + const rectangle object_box = compute_box_dimensions(1,35*35); + scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2)); + setup_hashed_features(scanner, images, 9); + use_uniform_feature_weights(scanner); + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + trainer.set_overlap_tester(test_box_overlap(0,0)); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + validate_some_object_detector_stuff(images, detector); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_1_poly ( + ) + { + print_spinner(); + dlog << LINFO << "test_1_poly()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef hashed_feature_image > feature_extractor_type; + typedef scan_image_pyramid, feature_extractor_type> image_scanner_type; + image_scanner_type scanner; + const rectangle object_box = compute_box_dimensions(1,35*35); + scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2)); + setup_hashed_features(scanner, images, 9); + use_uniform_feature_weights(scanner); + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + trainer.set_overlap_tester(test_box_overlap(0,0)); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + validate_some_object_detector_stuff(images, detector); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_1m_poly ( + ) + { + print_spinner(); + dlog << LINFO << "test_1_poly()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef hashed_feature_image > feature_extractor_type; + typedef scan_image_pyramid, feature_extractor_type> image_scanner_type; + image_scanner_type scanner; + const rectangle object_box = compute_box_dimensions(1,35*35); + std::vector mboxes; + const int mbox_size = 20; + mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); + mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); + mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); + mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); + scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2), mboxes); + setup_hashed_features(scanner, images, 9); + use_uniform_feature_weights(scanner); + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + trainer.set_overlap_tester(test_box_overlap(0,0)); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + validate_some_object_detector_stuff(images, detector); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_1_poly_nn ( + ) + { + print_spinner(); + dlog << LINFO << "test_1_poly_nn()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef nearest_neighbor_feature_image > feature_extractor_type; + typedef scan_image_pyramid, feature_extractor_type> image_scanner_type; + image_scanner_type scanner; + + setup_grid_detection_templates(scanner, object_locations, 2, 2); + feature_extractor_type nnfe; + pyramid_down<2> pyr_down; + poly_image<5> polyi; + nnfe.set_basis(randomly_sample_image_features(images, pyr_down, polyi, 80)); + scanner.copy_configuration(nnfe); + + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + validate_some_object_detector_stuff(images, detector); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_1_poly_nn_boxes ( + ) + { + print_spinner(); + dlog << LINFO << "test_1_poly_nn_boxes()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef nearest_neighbor_feature_image > feature_extractor_type; + typedef scan_image_boxes image_scanner_type; + image_scanner_type scanner; + + feature_extractor_type nnfe; + pyramid_down<2> pyr_down; + poly_image<5> polyi; + nnfe.set_basis(randomly_sample_image_features(images, pyr_down, polyi, 80)); + scanner.copy_configuration(nnfe); + + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + validate_some_object_detector_stuff(images, detector); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_2 ( + ) + { + print_spinner(); + dlog << LINFO << "test_2()"; + + typedef dlib::array > grayscale_image_array_type; + grayscale_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images, object_locations); + + typedef scan_image_pyramid, very_simple_feature_extractor> image_scanner_type; + image_scanner_type scanner; + const rectangle object_box = compute_box_dimensions(1,70*70); + scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2)); + scanner.set_max_pyramid_levels(1); + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(0); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3); + dlog << LINFO << "3-fold cross validation (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + validate_some_object_detector_stuff(images, detector); + } + } + +// ---------------------------------------------------------------------------------------- + + class pyramid_down_funny : noncopyable + { + pyramid_down<2> pyr; + public: + + template + dlib::vector point_down ( const dlib::vector& p) const { return pyr.point_down(p); } + + template + dlib::vector point_up ( const dlib::vector& p) const { return pyr.point_up(p); } + + template + dlib::vector point_down ( const dlib::vector& p, unsigned int levels) const { return pyr.point_down(p,levels); } + + template + dlib::vector point_up ( const dlib::vector& p, unsigned int levels) const { return pyr.point_up(p,levels); } + + rectangle rect_up ( const rectangle& rect) const { return pyr.rect_up(rect); } + + rectangle rect_up ( const rectangle& rect, unsigned int levels) const { return pyr.rect_up(rect,levels); } + + rectangle rect_down ( const rectangle& rect) const { return pyr.rect_down(rect); } + + rectangle rect_down ( const rectangle& rect, unsigned int levels) const { return pyr.rect_down(rect,levels); } + + template < + typename in_image_type, + typename out_image_type + > + void operator() ( + const in_image_type& original, + out_image_type& down + ) const + { + pyr(original.img, down.img); + } + + }; + + // make sure everything works even when the image isn't a dlib::array2d. + // So test with funny_image. + void test_3 ( + ) + { + print_spinner(); + dlog << LINFO << "test_3()"; + + + typedef dlib::array > grayscale_image_array_type; + typedef dlib::array funny_image_array_type; + grayscale_image_array_type images_temp; + funny_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images_temp, object_locations); + images.resize(images_temp.size()); + for (unsigned long i = 0; i < images_temp.size(); ++i) + { + images[i].img.swap(images_temp[i]); + } + + typedef scan_image_pyramid image_scanner_type; + image_scanner_type scanner; + const rectangle object_box = compute_box_dimensions(1,70*70); + scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2)); + scanner.set_max_pyramid_levels(1); + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3); + dlog << LINFO << "3-fold cross validation (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + } + } + +// ---------------------------------------------------------------------------------------- + + class funny_box_generator + { + public: + template + void operator() ( + const image_type& img, + std::vector& rects + ) const + { + rects.clear(); + find_candidate_object_locations(img.img, rects); + dlog << LINFO << "funny_box_generator, rects.size(): "<< rects.size(); + } + }; + + inline void serialize(const funny_box_generator&, std::ostream& ) {} + inline void deserialize(funny_box_generator&, std::istream& ) {} + + + // make sure everything works even when the image isn't a dlib::array2d. + // So test with funny_image. + void test_3_boxes ( + ) + { + print_spinner(); + dlog << LINFO << "test_3_boxes()"; + + + typedef dlib::array > grayscale_image_array_type; + typedef dlib::array funny_image_array_type; + grayscale_image_array_type images_temp; + funny_image_array_type images; + std::vector > object_locations; + make_simple_test_data(images_temp, object_locations); + images.resize(images_temp.size()); + for (unsigned long i = 0; i < images_temp.size(); ++i) + { + images[i].img.swap(images_temp[i]); + } + + typedef scan_image_boxes image_scanner_type; + image_scanner_type scanner; + structural_object_detection_trainer trainer(scanner); + trainer.set_num_threads(4); + object_detector detector = trainer.train(images, object_locations); + + matrix res = test_object_detection_function(detector, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3); + dlog << LINFO << "3-fold cross validation (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + + { + ostringstream sout; + serialize(detector, sout); + istringstream sin(sout.str()); + object_detector d2; + deserialize(d2, sin); + matrix res = test_object_detection_function(d2, images, object_locations); + dlog << LINFO << "Test detector (precision,recall): " << res; + DLIB_TEST(sum(res) == 3); + } + } + +// ---------------------------------------------------------------------------------------- + + class object_detector_tester : public tester + { + public: + object_detector_tester ( + ) : + tester ("test_object_detector", + "Runs tests on the structural object detection stuff.") + {} + + void perform_test ( + ) + { + test_fhog_pyramid(); + test_1_boxes(); + test_1_poly_nn_boxes(); + test_3_boxes(); + + test_1(); + test_1m(); + test_1_fine_hog(); + test_1_poly(); + test_1m_poly(); + test_1_poly_nn(); + test_2(); + test_3(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/oca.cpp b/ml/dlib/dlib/test/oca.cpp new file mode 100644 index 000000000..97881a758 --- /dev/null +++ b/ml/dlib/dlib/test/oca.cpp @@ -0,0 +1,244 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.oca"); + +// ---------------------------------------------------------------------------------------- + + class test_oca : public tester + { + + public: + test_oca ( + ) : + tester ("test_oca", + "Runs tests on the oca component.") + { + } + + void perform_test( + ) + { + print_spinner(); + + typedef matrix w_type; + w_type w; + + decision_function > df; + svm_c_linear_trainer > trainer; + trainer.set_c_class1(2); + trainer.set_c_class1(3); + trainer.set_learns_nonnegative_weights(true); + trainer.set_epsilon(1e-12); + + std::vector x; + w_type temp(2); + temp = -1, 1; + x.push_back(temp); + temp = 1, -1; + x.push_back(temp); + + std::vector y; + y.push_back(+1); + y.push_back(-1); + + w_type true_w(3); + + oca solver; + + // test the version without a non-negativity constraint on w. + solver(make_oca_problem_c_svm(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0); + dlog << LINFO << trans(w); + true_w = -0.5, 0.5, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + solver.solve_with_elastic_net(make_oca_problem_c_svm(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0.5); + dlog << LINFO << trans(w); + true_w = -0.5, 0.5, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + print_spinner(); + + w_type prior = true_w; + solver(make_oca_problem_c_svm(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior); + dlog << LINFO << trans(w); + true_w = -0.5, 0.5, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + prior = 0,0,0; + solver(make_oca_problem_c_svm(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior); + dlog << LINFO << trans(w); + true_w = -0.5, 0.5, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + prior = -1,1,0; + solver(make_oca_problem_c_svm(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior); + dlog << LINFO << trans(w); + true_w = -1.0, 1.0, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + prior = -0.2,0.2,0; + solver(make_oca_problem_c_svm(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior); + dlog << LINFO << trans(w); + true_w = -0.5, 0.5, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + prior = -10.2,-1,0; + solver(make_oca_problem_c_svm(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior); + dlog << LINFO << trans(w); + true_w = -10.2, -1.0, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + print_spinner(); + + // test the version with a non-negativity constraint on w. + solver(make_oca_problem_c_svm(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 9999); + dlog << LINFO << trans(w); + true_w = 0, 1, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + df = trainer.train(x,y); + w = join_cols(df.basis_vectors(0), uniform_matrix(1,1,-df.b)); + true_w = 0, 1, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST_MSG(max(abs(w-true_w)) < 1e-9, max(abs(w-true_w))); + + + print_spinner(); + + // test the version with a non-negativity constraint on w. + solver(make_oca_problem_c_svm(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 2); + dlog << LINFO << trans(w); + true_w = 0, 1, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + print_spinner(); + + + // test the version with a non-negativity constraint on w. + solver(make_oca_problem_c_svm(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 1); + dlog << LINFO << trans(w); + true_w = 0, 1, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + print_spinner(); + + + // switching the labels should change which w weight goes negative. + y.clear(); + y.push_back(-1); + y.push_back(+1); + + + solver(make_oca_problem_c_svm(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0); + dlog << LINFO << trans(w); + true_w = 0.5, -0.5, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + print_spinner(); + + solver(make_oca_problem_c_svm(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 1); + dlog << LINFO << trans(w); + true_w = 0.5, -0.5, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + print_spinner(); + + solver(make_oca_problem_c_svm(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 2); + dlog << LINFO << trans(w); + true_w = 1, 0, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + print_spinner(); + + solver(make_oca_problem_c_svm(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 5); + dlog << LINFO << trans(w); + true_w = 1, 0, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + df = trainer.train(x,y); + w = join_cols(df.basis_vectors(0), uniform_matrix(1,1,-df.b)); + true_w = 1, 0, 0; + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST_MSG(max(abs(w-true_w)) < 1e-9, max(abs(w-true_w))); + + + + x.clear(); + y.clear(); + temp = -2, 2; + x.push_back(temp); + temp = 0, -0; + x.push_back(temp); + + y.push_back(+1); + y.push_back(-1); + + trainer.set_c(10); + df = trainer.train(x,y); + w = join_cols(df.basis_vectors(0), uniform_matrix(1,1,-df.b)); + true_w = 0, 1, -1; + dlog << LINFO << "w: " << trans(w); + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + + x.clear(); + y.clear(); + temp = -2, 2; + x.push_back(temp); + temp = 0, -0; + x.push_back(temp); + + y.push_back(-1); + y.push_back(+1); + + trainer.set_c(10); + df = trainer.train(x,y); + w = join_cols(df.basis_vectors(0), uniform_matrix(1,1,-df.b)); + true_w = 1, 0, 1; + dlog << LINFO << "w: " << trans(w); + dlog << LINFO << "error: "<< max(abs(w-true_w)); + DLIB_TEST(max(abs(w-true_w)) < 1e-10); + + } + + } a; + +} + + + diff --git a/ml/dlib/dlib/test/one_vs_all_trainer.cpp b/ml/dlib/dlib/test/one_vs_all_trainer.cpp new file mode 100644 index 000000000..b928714b2 --- /dev/null +++ b/ml/dlib/dlib/test/one_vs_all_trainer.cpp @@ -0,0 +1,305 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.one_vs_all_trainer"); + + + class test_one_vs_all_trainer : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + test_one_vs_all_trainer ( + ) : + tester ( + "test_one_vs_all_trainer", // the command line argument name for this test + "Run tests on the one_vs_all_trainer stuff.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + + + template + void generate_data ( + std::vector& samples, + std::vector& labels + ) + { + const long num = 50; + + sample_type m; + + dlib::rand rnd; + + + // make some samples near the origin + double radius = 0.5; + for (long i = 0; i < num+10; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + // add this sample to our set of samples we will run k-means + samples.push_back(m); + labels.push_back(1); + } + + // make some samples in a circle around the origin but far away + radius = 10.0; + for (long i = 0; i < num+20; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + // add this sample to our set of samples we will run k-means + samples.push_back(m); + labels.push_back(2); + } + + // make some samples in a circle around the point (25,25) + radius = 4.0; + for (long i = 0; i < num+30; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + // translate this point away from the origin + m(0) += 25; + m(1) += 25; + + // add this sample to our set of samples we will run k-means + samples.push_back(m); + labels.push_back(3); + } + } + + template + void run_test ( + ) + { + print_spinner(); + typedef matrix sample_type; + + std::vector samples, norm_samples; + std::vector labels; + + // First, get our labeled set of training data + generate_data(samples, labels); + + typedef one_vs_all_trainer,label_type > ova_trainer; + + + ova_trainer trainer; + + typedef polynomial_kernel poly_kernel; + typedef radial_basis_kernel rbf_kernel; + + // make the binary trainers and set some parameters + krr_trainer rbf_trainer; + svm_nu_trainer poly_trainer; + poly_trainer.set_kernel(poly_kernel(0.1, 1, 2)); + rbf_trainer.set_kernel(rbf_kernel(0.1)); + + + trainer.set_trainer(rbf_trainer); + trainer.set_trainer(poly_trainer, 1); + + randomize_samples(samples, labels); + matrix res = cross_validate_multiclass_trainer(trainer, samples, labels, 2); + + print_spinner(); + + matrix ans(3,3); + ans = 60, 0, 0, + 0, 70, 0, + 0, 0, 80; + + DLIB_TEST_MSG(ans == res, "res: \n" << res); + + // test using a normalized_function with a one_vs_all_decision_function + { + poly_trainer.set_kernel(poly_kernel(1.1, 1, 2)); + trainer.set_trainer(poly_trainer, 1); + vector_normalizer normalizer; + normalizer.train(samples); + for (unsigned long i = 0; i < samples.size(); ++i) + norm_samples.push_back(normalizer(samples[i])); + normalized_function > ndf; + ndf.function = trainer.train(norm_samples, labels); + ndf.normalizer = normalizer; + DLIB_TEST(ndf(samples[0]) == labels[0]); + DLIB_TEST(ndf(samples[40]) == labels[40]); + DLIB_TEST(ndf(samples[90]) == labels[90]); + DLIB_TEST(ndf(samples[120]) == labels[120]); + poly_trainer.set_kernel(poly_kernel(0.1, 1, 2)); + trainer.set_trainer(poly_trainer, 1); + print_spinner(); + } + + one_vs_all_decision_function df = trainer.train(samples, labels); + + DLIB_TEST(df.number_of_classes() == 3); + + DLIB_TEST(df(samples[0]) == labels[0]); + DLIB_TEST(df(samples[90]) == labels[90]); + + + one_vs_all_decision_function, // This is the output of the poly_trainer + decision_function // This is the output of the rbf_trainer + > df2, df3; + + + df2 = df; + ofstream fout("df.dat", ios::binary); + serialize(df2, fout); + fout.close(); + + // load the function back in from disk and store it in df3. + ifstream fin("df.dat", ios::binary); + deserialize(df3, fin); + + + DLIB_TEST(df3(samples[0]) == labels[0]); + DLIB_TEST(df3(samples[90]) == labels[90]); + res = test_multiclass_decision_function(df3, samples, labels); + + DLIB_TEST(res == ans); + + + } + + template + void run_probabilistic_test ( + ) + { + print_spinner(); + typedef matrix sample_type; + + std::vector samples; + std::vector labels; + + // First, get our labeled set of training data + generate_data(samples, labels); + + typedef one_vs_all_trainer,label_type > ova_trainer; + + + ova_trainer trainer; + + typedef polynomial_kernel poly_kernel; + typedef radial_basis_kernel rbf_kernel; + + // make the binary trainers and set some parameters + krr_trainer rbf_trainer; + svm_nu_trainer poly_trainer; + poly_trainer.set_kernel(poly_kernel(0.1, 1, 2)); + rbf_trainer.set_kernel(rbf_kernel(0.1)); + + + trainer.set_trainer(probabilistic(rbf_trainer, 3)); + trainer.set_trainer(probabilistic(poly_trainer, 3), 1); + + randomize_samples(samples, labels); + matrix res = cross_validate_multiclass_trainer(trainer, samples, labels, 2); + + print_spinner(); + + matrix ans(3,3); + ans = 60, 0, 0, + 0, 70, 0, + 0, 0, 80; + + DLIB_TEST_MSG(ans == res, "res: \n" << res); + + one_vs_all_decision_function df = trainer.train(samples, labels); + + DLIB_TEST(df.number_of_classes() == 3); + + DLIB_TEST(df(samples[0]) == labels[0]); + DLIB_TEST(df(samples[90]) == labels[90]); + + + one_vs_all_decision_function >, // This is the output of the poly_trainer + probabilistic_function > // This is the output of the rbf_trainer + > df2, df3; + + + df2 = df; + ofstream fout("df.dat", ios::binary); + serialize(df2, fout); + fout.close(); + + // load the function back in from disk and store it in df3. + ifstream fin("df.dat", ios::binary); + deserialize(df3, fin); + + + DLIB_TEST(df3(samples[0]) == labels[0]); + DLIB_TEST(df3(samples[90]) == labels[90]); + res = test_multiclass_decision_function(df3, samples, labels); + + DLIB_TEST(res == ans); + + + } + + void perform_test ( + ) + { + dlog << LINFO << "run_test()"; + run_test(); + + dlog << LINFO << "run_test()"; + run_test(); + + dlog << LINFO << "run_test()"; + run_test(); + + dlog << LINFO << "run_test()"; + run_test(); + + dlog << LINFO << "run_probabilistic_test()"; + run_probabilistic_test(); + + dlog << LINFO << "run_probabilistic_test()"; + run_probabilistic_test(); + + dlog << LINFO << "run_probabilistic_test()"; + run_probabilistic_test(); + + dlog << LINFO << "run_probabilistic_test()"; + run_probabilistic_test(); + } + }; + + test_one_vs_all_trainer a; + +} + + diff --git a/ml/dlib/dlib/test/one_vs_one_trainer.cpp b/ml/dlib/dlib/test/one_vs_one_trainer.cpp new file mode 100644 index 000000000..70cdefaf9 --- /dev/null +++ b/ml/dlib/dlib/test/one_vs_one_trainer.cpp @@ -0,0 +1,218 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.one_vs_one_trainer"); + + + class test_one_vs_one_trainer : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + test_one_vs_one_trainer ( + ) : + tester ( + "test_one_vs_one_trainer", // the command line argument name for this test + "Run tests on the one_vs_one_trainer stuff.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + + + template + void generate_data ( + std::vector& samples, + std::vector& labels + ) + { + const long num = 50; + + sample_type m; + + dlib::rand rnd; + + + // make some samples near the origin + double radius = 0.5; + for (long i = 0; i < num+10; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + // add this sample to our set of samples we will run k-means + samples.push_back(m); + labels.push_back(1); + } + + // make some samples in a circle around the origin but far away + radius = 10.0; + for (long i = 0; i < num+20; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + // add this sample to our set of samples we will run k-means + samples.push_back(m); + labels.push_back(2); + } + + // make some samples in a circle around the point (25,25) + radius = 4.0; + for (long i = 0; i < num+30; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + // translate this point away from the origin + m(0) += 25; + m(1) += 25; + + // add this sample to our set of samples we will run k-means + samples.push_back(m); + labels.push_back(3); + } + } + + template + void run_test ( + ) + { + print_spinner(); + typedef matrix sample_type; + + std::vector samples, norm_samples; + std::vector labels; + + // First, get our labeled set of training data + generate_data(samples, labels); + + typedef one_vs_one_trainer,label_type > ovo_trainer; + + + ovo_trainer trainer; + + typedef histogram_intersection_kernel hist_kernel; + typedef radial_basis_kernel rbf_kernel; + + // make the binary trainers and set some parameters + krr_trainer rbf_trainer; + svm_nu_trainer hist_trainer; + rbf_trainer.set_kernel(rbf_kernel(0.1)); + + + trainer.set_trainer(rbf_trainer); + trainer.set_trainer(hist_trainer, 1, 2); + + randomize_samples(samples, labels); + matrix res = cross_validate_multiclass_trainer(trainer, samples, labels, 2); + + print_spinner(); + + matrix ans(3,3); + ans = 60, 0, 0, + 0, 70, 0, + 0, 0, 80; + + DLIB_TEST_MSG(ans == res, "res: \n" << res); + + // test using a normalized_function with a one_vs_one_decision_function + { + trainer.set_trainer(hist_trainer, 1, 2); + vector_normalizer normalizer; + normalizer.train(samples); + for (unsigned long i = 0; i < samples.size(); ++i) + norm_samples.push_back(normalizer(samples[i])); + normalized_function > ndf; + ndf.function = trainer.train(norm_samples, labels); + ndf.normalizer = normalizer; + DLIB_TEST(ndf(samples[0]) == labels[0]); + DLIB_TEST(ndf(samples[40]) == labels[40]); + DLIB_TEST(ndf(samples[90]) == labels[90]); + DLIB_TEST(ndf(samples[120]) == labels[120]); + trainer.set_trainer(hist_trainer, 1, 2); + print_spinner(); + } + + + + + one_vs_one_decision_function df = trainer.train(samples, labels); + + DLIB_TEST(df.number_of_classes() == 3); + + DLIB_TEST(df(samples[0]) == labels[0]); + DLIB_TEST(df(samples[90]) == labels[90]); + + + one_vs_one_decision_function, // This is the output of the hist_trainer + decision_function // This is the output of the rbf_trainer + > df2, df3; + + + df2 = df; + ofstream fout("df.dat", ios::binary); + serialize(df2, fout); + fout.close(); + + // load the function back in from disk and store it in df3. + ifstream fin("df.dat", ios::binary); + deserialize(df3, fin); + + + DLIB_TEST(df3(samples[0]) == labels[0]); + DLIB_TEST(df3(samples[90]) == labels[90]); + res = test_multiclass_decision_function(df3, samples, labels); + + DLIB_TEST(res == ans); + + + } + + void perform_test ( + ) + { + dlog << LINFO << "run_test()"; + run_test(); + + dlog << LINFO << "run_test()"; + run_test(); + + dlog << LINFO << "run_test()"; + run_test(); + + dlog << LINFO << "run_test()"; + run_test(); + } + }; + + test_one_vs_one_trainer a; + +} + + diff --git a/ml/dlib/dlib/test/opt_qp_solver.cpp b/ml/dlib/dlib/test/opt_qp_solver.cpp new file mode 100644 index 000000000..ffa386323 --- /dev/null +++ b/ml/dlib/dlib/test/opt_qp_solver.cpp @@ -0,0 +1,813 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.opt_qp_solver"); + +// ---------------------------------------------------------------------------------------- + + class test_smo + { + public: + double penalty; + double C; + + double operator() ( + const matrix& alpha + ) const + { + + double obj = 0.5* trans(alpha)*Q*alpha - trans(alpha)*b; + double c1 = pow(sum(alpha)-C,2); + double c2 = sum(pow(pointwise_multiply(alpha, alpha<0), 2)); + + obj += penalty*(c1 + c2); + + return obj; + } + + matrix Q, b; + }; + +// ---------------------------------------------------------------------------------------- + + class test_smo_derivative + { + public: + double penalty; + double C; + + matrix operator() ( + const matrix& alpha + ) const + { + + matrix obj = Q*alpha - b; + matrix c1 = uniform_matrix(alpha.size(),1, 2*(sum(alpha)-C)); + matrix c2 = 2*pointwise_multiply(alpha, alpha<0); + + return obj + penalty*(c1 + c2); + } + + matrix Q, b; + }; + +// ---------------------------------------------------------------------------------------- + + double compute_objective_value ( + const matrix& w, + const matrix& A, + const matrix& b, + const double C + ) + { + return 0.5*dot(w,w) + C*max(trans(A)*w + b); + } + +// ---------------------------------------------------------------------------------------- + + void test_qp4_test1() + { + matrix A(3,2); + A = 1,2, + -3,1, + 6,7; + + matrix b(2); + b = 1, + 2; + + const double C = 2; + + matrix alpha(2), true_alpha(2), d(3), lambda; + alpha = C/2, C/2; + d = 0; + + solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800); + matrix w = lowerbound(-A*alpha, 0); + + dlog << LINFO << "*******************************************************"; + + dlog << LINFO << "w: " << trans(w); + + dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C); + w = 0; + dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C); + + dlog << LINFO << "alpha: " << trans(alpha); + true_alpha = 0, 2; + dlog << LINFO << "true alpha: "<< trans(true_alpha); + + dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha)); + DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9); + } + +// ---------------------------------------------------------------------------------------- + + void test_qp4_test2() + { + matrix A(3,2); + A = 1,2, + 3,-1, + 6,7; + + matrix b(2); + b = 1, + 2; + + const double C = 2; + + matrix alpha(2), true_alpha(2), d(3), lambda; + alpha = C/2, C/2; + d = 0; + + solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800); + matrix w = lowerbound(-A*alpha, 0); + + dlog << LINFO << "*******************************************************"; + + dlog << LINFO << "w: " << trans(w); + + dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C); + w = 0, 0.25, 0; + dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C); + + dlog << LINFO << "alpha: " << trans(alpha); + true_alpha = 0.43750, 1.56250; + dlog << LINFO << "true alpha: "<< trans(true_alpha); + + dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha)); + DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9); + } + +// ---------------------------------------------------------------------------------------- + + void test_qp4_test3() + { + matrix A(3,2); + A = 1,2, + -3,-1, + 6,7; + + matrix b(2); + b = 1, + 2; + + const double C = 2; + + matrix alpha(2), true_alpha(2), d(3), lambda; + alpha = C/2, C/2; + d = 0; + + solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800); + matrix w = lowerbound(-A*alpha, 0); + + dlog << LINFO << "*******************************************************"; + + dlog << LINFO << "w: " << trans(w); + + dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C); + w = 0, 2, 0; + dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C); + + dlog << LINFO << "alpha: " << trans(alpha); + true_alpha = 0, 2; + dlog << LINFO << "true alpha: "<< trans(true_alpha); + + dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha)); + DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9); + } + +// ---------------------------------------------------------------------------------------- + + void test_qp4_test5() + { + matrix A(3,3); + A = 1,2,4, + 3,1,6, + 6,7,-2; + + matrix b(3); + b = 1, + 2, + 3; + + const double C = 2; + + matrix alpha(3), true_alpha(3), d(3), lambda; + alpha = C/2, C/2, 0; + d = 0; + + solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800); + matrix w = lowerbound(-A*alpha, 0); + + + dlog << LINFO << "*******************************************************"; + + dlog << LINFO << "w: " << trans(w); + + dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C); + w = 0, 0, 0.11111111111111111111; + dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C); + + dlog << LINFO << "alpha: " << trans(alpha); + true_alpha = 0, 0.432098765432099, 1.567901234567901; + dlog << LINFO << "true alpha: "<< trans(true_alpha); + + dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha)); + DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9); + } + +// ---------------------------------------------------------------------------------------- + + void test_qp4_test4() + { + matrix A(3,2); + A = 1,2, + 3,1, + 6,7; + + matrix b(2); + b = 1, + 2; + + const double C = 2; + + matrix alpha(2), d(3), lambda; + alpha = C/2, C/2; + + solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 800); + matrix w = lowerbound(-A*alpha, 0); + + dlog << LINFO << "*******************************************************"; + + dlog << LINFO << "w: " << trans(w); + + const double computed_obj = compute_objective_value(w,A,b,C); + w = 0, 0, 0; + const double true_obj = compute_objective_value(w,A,b,C); + dlog << LINFO << "computed obj: "<< computed_obj; + dlog << LINFO << "with true w obj: "<< true_obj; + + DLIB_TEST_MSG(abs(computed_obj - true_obj) < 1e-8, abs(computed_obj - true_obj)); + } + + void test_qp4_test6() + { + matrix A(3,3); + A = 1,2,4, + 3,1,6, + 6,7,-2; + + matrix b(3); + b = -1, + -2, + -3; + + const double C = 2; + + matrix alpha(3), d(3), lambda; + d = 0; + alpha = C/2, C/2, 0; + + unsigned long iters = solve_qp4_using_smo(A, tmp(trans(A)*A), b, d, alpha, lambda, 1e-9, 3000); + matrix w = lowerbound(-A*alpha, 0); + + dlog << LINFO << "*******************************************************"; + + dlog << LINFO << "alpha: " << trans(alpha); + dlog << LINFO << "lambda: " << trans(lambda); + dlog << LINFO << "w: " << trans(w); + + + const double computed_obj = compute_objective_value(w,A,b,C); + w = 0, 0, 0; + const double true_obj = compute_objective_value(w,A,b,C); + dlog << LINFO << "computed obj: "<< computed_obj; + dlog << LINFO << "with true w obj: "<< true_obj; + + DLIB_TEST_MSG(abs(computed_obj - true_obj) < 1e-8, + "computed_obj: "<< computed_obj << " true_obj: " << true_obj << " delta: "<< abs(computed_obj - true_obj) + << " iters: " << iters + << "\n alpha: " << trans(alpha) + << " lambda: " << trans(lambda) + ); + } + + void test_qp4_test7() + { + matrix A(3,3); + A = -1,2,4, + -3,1,6, + -6,7,-2; + + matrix b(3); + b = -1, + -2, + 3; + + matrix Q(3,3); + Q = 4,-5,6, + 1,-4,2, + -9,-4,5; + Q = Q*trans(Q); + + const double C = 2; + + matrix alpha(3), true_alpha(3), d(3), lambda; + alpha = C/2, C/2, 0; + d = 0; + + solve_qp4_using_smo(A, Q, b, d, alpha, lambda, 1e-9, 800); + + dlog << LINFO << "*******************************************************"; + + dlog << LINFO << "alpha: " << trans(alpha); + true_alpha = 0, 2, 0; + dlog << LINFO << "true alpha: "<< trans(true_alpha); + + dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha)); + DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9); + + } + +// ---------------------------------------------------------------------------------------- + + void test_solve_qp4_using_smo() + { + test_qp4_test1(); + test_qp4_test2(); + test_qp4_test3(); + test_qp4_test4(); + test_qp4_test5(); + test_qp4_test6(); + test_qp4_test7(); + } + +// ---------------------------------------------------------------------------------------- + + double max_distance_to( + const std::vector>& a, + const std::vector>& b + ) + { + double best_dist = 0; + for (auto&& aa : a) + { + for (auto&& bb : b) + { + double dist = length(aa-bb); + if (dist > best_dist) + best_dist = dist; + } + } + return best_dist; + } + + double min_distance_to( + const std::vector>& a, + const std::vector>& b + ) + { + double best_dist = std::numeric_limits::infinity(); + for (auto&& aa : a) + { + for (auto&& bb : b) + { + double dist = length(aa-bb); + if (dist < best_dist) + best_dist = dist; + } + } + return best_dist; + } + + double min_distance_to( + const std::vector>& s, + const matrix& v + ) + { + double best_dist = std::numeric_limits::infinity(); + for (auto& x : s) + { + double dist = length(v-x); + if (dist < best_dist) + { + best_dist = dist; + } + } + return best_dist; + } + + double max_distance_to( + const std::vector>& s, + const matrix& v + ) + { + double best_dist = 0; + for (auto& x : s) + { + double dist = length(v-x); + if (dist > best_dist) + { + best_dist = dist; + } + } + return best_dist; + } + + void test_find_gap_between_convex_hulls() + { + print_spinner(); + std::vector> set1, set2; + + const double dist_thresh = 5.47723; + + // generate two groups of points that are pairwise close within each set and + // pairwise far apart between each set, according to dist_thresh distance threshold. + bool which = true; + for (size_t i = 0; i < 10000; ++i) + { + matrix v = gaussian_randm(15,1,i); + const auto min_dist1 = min_distance_to(set1,v); + const auto min_dist2 = min_distance_to(set2,v); + const auto max_dist1 = max_distance_to(set1,v); + const auto max_dist2 = max_distance_to(set2,v); + if (which) + { + if ((set1.size()==0 || max_dist1 < dist_thresh) && min_dist2 > dist_thresh ) + { + set1.push_back(v); + which = !which; + } + } + else + { + if ((set2.size()==0 || max_dist2 < dist_thresh) && min_dist1 > dist_thresh) + { + set2.push_back(v); + which = !which; + } + } + } + + dlog << LINFO << "set1.size(): "<< set1.size(); + dlog << LINFO << "set2.size(): "<< set2.size(); + + + // make sure we generated the points correctly. + dlog << LINFO << "dist_thresh: "<< dist_thresh; + dlog << LINFO << "max distance between set1 and set1: "<< max_distance_to(set1,set1); + dlog << LINFO << "max distance between set2 and set2: "<< max_distance_to(set2,set2); + DLIB_TEST(max_distance_to(set1,set1) < dist_thresh); + DLIB_TEST(max_distance_to(set2,set2) < dist_thresh); + dlog << LINFO << "min distance between set2 and set1: "<< min_distance_to(set2,set1); + DLIB_TEST(min_distance_to(set2,set1) > dist_thresh); + + + // It is slightly counterintuitive but true that points picked using the above procedure + // will have elements of their convex hulls that are much closer together than + // dist_thresh, even though none of the vertices of the hulls are that close + // together. This is especially true in high dimensions. So let's use this to + // test find_gap_between_convex_hulls(). It should be able to find a pair of + // points in the convex hulls of our sets that are a lot closer together than + // dist_thresh. + + // First we need to convert the vectors to matrices. + matrix A, B; + A.set_size(set1[0].size(), set1.size()); + B.set_size(set2[0].size(), set2.size()); + for (long c = 0; c < A.nc(); ++c) + set_colm(A,c) = set1[c]; + for (long c = 0; c < B.nc(); ++c) + set_colm(B,c) = set2[c]; + + matrix c1, c2; + find_gap_between_convex_hulls(A, B, c1, c2, 0.0001); + // make sure c1 and c2 are convex combinations. + DLIB_TEST(abs(sum(c1)-1) < 1e-8); + DLIB_TEST(abs(sum(c2)-1) < 1e-8); + DLIB_TEST(min(c1) >= 0); + DLIB_TEST(min(c2) >= 0); + + // now test that the points found are close together. + dlog << LINFO << "dist: "<< length(A*c1 - B*c2); + DLIB_TEST(length(A*c1 - B*c2) < 4); + } + +// ---------------------------------------------------------------------------------------- + + void test_solve_qp_box_constrained_blockdiag() + { + dlib::rand rnd; + for (int iter = 0; iter < 50; ++iter) + { + print_spinner(); + + matrix Q1, Q2; + matrix b1, b2; + + Q1 = randm(4,4,rnd); Q1 = Q1*trans(Q1); + Q2 = randm(4,4,rnd); Q2 = Q2*trans(Q2); + b1 = gaussian_randm(4,1, iter*2+0); + b2 = gaussian_randm(4,1, iter*2+1); + + std::map, matrix> offdiag; + + if (rnd.get_random_gaussian() > 0) + offdiag[make_unordered_pair(0,0)] = randm(4,1,rnd); + if (rnd.get_random_gaussian() > 0) + offdiag[make_unordered_pair(1,0)] = randm(4,1,rnd); + if (rnd.get_random_gaussian() > 0) + offdiag[make_unordered_pair(1,1)] = randm(4,1,rnd); + + std::vector> Q_blocks = {Q1, Q2}; + std::vector> bs = {b1, b2}; + + + // make the single big Q and b + matrix Q = join_cols(join_rows(Q1, zeros_matrix(Q1)), + join_rows(zeros_matrix(Q2),Q2)); + matrix b = join_cols(b1,b2); + for (auto& p : offdiag) + { + long r = p.first.first; + long c = p.first.second; + set_subm(Q, 4*r,4*c, 4,4) += diagm(p.second); + if (c != r) + set_subm(Q, 4*c,4*r, 4,4) += diagm(p.second); + } + + + matrix alpha = zeros_matrix(b); + matrix lower = -10000*ones_matrix(b); + matrix upper = 10000*ones_matrix(b); + + auto iters = solve_qp_box_constrained(Q, b, alpha, lower, upper, 1e-9, 10000); + dlog << LINFO << "iters: "<< iters; + dlog << LINFO << "alpha: " << trans(alpha); + + dlog << LINFO; + + std::vector> alphas(2); + alphas[0] = zeros_matrix(4,1); alphas[1] = zeros_matrix(4,1); + + lower = -10000*ones_matrix(alphas[0]); + upper = 10000*ones_matrix(alphas[0]); + std::vector> lowers = {lower,lower}, uppers = {upper, upper}; + auto iters2 = solve_qp_box_constrained_blockdiag(Q_blocks, bs, offdiag, alphas, lowers, uppers, 1e-9, 10000); + dlog << LINFO << "iters2: "<< iters2; + dlog << LINFO << "alpha: " << trans(join_cols(alphas[0],alphas[1])); + + dlog << LINFO << "obj1: "<< 0.5*trans(alpha)*Q*alpha + trans(b)*alpha; + dlog << LINFO << "obj2: "<< 0.5*trans(join_cols(alphas[0],alphas[1]))*Q*join_cols(alphas[0],alphas[1]) + trans(b)*join_cols(alphas[0],alphas[1]); + dlog << LINFO << "obj1-obj2: "<<(0.5*trans(alpha)*Q*alpha + trans(b)*alpha) - (0.5*trans(join_cols(alphas[0],alphas[1]))*Q*join_cols(alphas[0],alphas[1]) + trans(b)*join_cols(alphas[0],alphas[1])); + + DLIB_TEST_MSG(max(abs(alpha - join_cols(alphas[0], alphas[1]))) < 1e-6, max(abs(alpha - join_cols(alphas[0], alphas[1])))); + + DLIB_TEST(iters == iters2); + + } + } + +// ---------------------------------------------------------------------------------------- + + void test_solve_qp_box_constrained_blockdiag_compact(dlib::rand& rnd, double percent_off_diag_present) + { + print_spinner(); + + dlog << LINFO << "test_solve_qp_box_constrained_blockdiag_compact(), percent_off_diag_present==" << percent_off_diag_present; + + std::map, matrix> offdiag; + std::vector> Q_blocks; + std::vector> bs; + + const long num_blocks = 20; + const long dims = 4; + const double lambda = 10; + for (long i = 0; i < num_blocks; ++i) + { + matrix Q1; + matrix b1; + Q1 = randm(dims,dims,rnd); Q1 = Q1*trans(Q1); + b1 = gaussian_randm(dims,1, i); + + Q_blocks.push_back(Q1); + bs.push_back(b1); + + // test with some graph regularization terms + for (long j = 0; j < num_blocks; ++j) + { + if (rnd.get_random_double() < percent_off_diag_present) + { + if (i==j) + offdiag[make_unordered_pair(i,j)] = (num_blocks-1)*lambda*rnd.get_random_double()*ones_matrix(dims,1); + else + offdiag[make_unordered_pair(i,j)] = -lambda*rnd.get_random_double()*ones_matrix(dims,1); + } + } + } + + // build out the dense version of the QP so we can test it against the dense solver. + matrix Q(num_blocks*dims, num_blocks*dims); + Q = 0; + matrix b(num_blocks*dims); + for (long i = 0; i < num_blocks; ++i) + { + set_subm(Q,i*dims,i*dims,dims,dims) = Q_blocks[i]; + set_subm(b,i*dims,0,dims,1) = bs[i]; + } + for (auto& p : offdiag) + { + long r = p.first.first; + long c = p.first.second; + set_subm(Q, dims*r,dims*c, dims,dims) += diagm(p.second); + if (c != r) + set_subm(Q, dims*c,dims*r, dims,dims) += diagm(p.second); + } + + + + matrix alpha = zeros_matrix(dims*num_blocks,1); + matrix lower = -10000*ones_matrix(dims*num_blocks,1); + matrix upper = 10000*ones_matrix(dims*num_blocks,1); + + auto iters = solve_qp_box_constrained(Q, b, alpha, lower, upper, 1e-9, 20000); + dlog << LINFO << "iters: "<< iters; + + + matrix init_alpha = zeros_matrix(bs[0]); + lower = -10000*ones_matrix(bs[0]); + upper = 10000*ones_matrix(bs[0]); + + std::vector> alphas(num_blocks, init_alpha); + std::vector> lowers(num_blocks, lower); + std::vector> uppers(num_blocks, upper); + + auto iters2 = solve_qp_box_constrained_blockdiag(Q_blocks, bs, offdiag, alphas, lowers, uppers, 1e-9, 20000); + dlog << LINFO << "iters2: "<< iters2; + + + const matrix refalpha = reshape(alpha, num_blocks, dims); + + // now make sure the two solvers agree on the outputs. + for (long r = 0; r < num_blocks; ++r) + { + for (long c = 0; c < dims; ++c) + { + DLIB_TEST_MSG(std::abs(refalpha(r,c) - alphas[r](c)) < 1e-6, std::abs(refalpha(r,c) - alphas[r](c))); + } + } + } + +// ---------------------------------------------------------------------------------------- + + class opt_qp_solver_tester : public tester + { + /* + The idea here is just to solve the same problem with two different + methods and check that they basically agree. The SMO solver should be + very accurate but for this problem the BFGS solver is relatively + inaccurate. So this test is really just a sanity check on the SMO + solver. + */ + public: + opt_qp_solver_tester ( + ) : + tester ("test_opt_qp_solver", + "Runs tests on the solve_qp_using_smo component.") + { + thetime = time(0); + } + + time_t thetime; + dlib::rand rnd; + + void perform_test( + ) + { + print_spinner(); + test_solve_qp4_using_smo(); + print_spinner(); + + ++thetime; + //dlog << LINFO << "time seed: " << thetime; + //rnd.set_seed(cast_to_string(thetime)); + + running_stats rs; + + for (int i = 0; i < 40; ++i) + { + for (long dims = 1; dims < 6; ++dims) + { + rs.add(do_the_test(dims, 1.0)); + } + } + + for (int i = 0; i < 40; ++i) + { + for (long dims = 1; dims < 6; ++dims) + { + rs.add(do_the_test(dims, 5.0)); + } + } + + dlog << LINFO << "disagreement mean: " << rs.mean(); + dlog << LINFO << "disagreement stddev: " << rs.stddev(); + DLIB_TEST_MSG(rs.mean() < 0.001, rs.mean()); + DLIB_TEST_MSG(rs.stddev() < 0.001, rs.stddev()); + + + test_find_gap_between_convex_hulls(); + test_solve_qp_box_constrained_blockdiag(); + + // try a range of off diagonal sparseness. We do this to make sure we exercise both + // the compact and sparse code paths within the solver. + test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.001); + test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.01); + test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.04); + test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.10); + test_solve_qp_box_constrained_blockdiag_compact(rnd, 0.50); + test_solve_qp_box_constrained_blockdiag_compact(rnd, 1.00); + } + + double do_the_test ( + const long dims, + double C + ) + { + print_spinner(); + dlog << LINFO << "dims: " << dims; + dlog << LINFO << "testing with C == " << C; + test_smo test; + + test.Q = randm(dims, dims, rnd); + test.Q = trans(test.Q)*test.Q; + test.b = randm(dims,1, rnd); + test.C = C; + + test_smo_derivative der; + der.Q = test.Q; + der.b = test.b; + der.C = test.C; + + + matrix x(dims), alpha(dims); + + + test.penalty = 20000; + der.penalty = test.penalty; + + alpha = C/alpha.size(); + x = alpha; + + const unsigned long max_iter = 400000; + solve_qp_using_smo(test.Q, test.b, alpha, 0.00000001, max_iter); + DLIB_TEST_MSG(abs(sum(alpha) - C) < 1e-13, abs(sum(alpha) - C) ); + dlog << LTRACE << "alpha: " << alpha; + dlog << LINFO << "SMO: true objective: "<< 0.5*trans(alpha)*test.Q*alpha - trans(alpha)*test.b; + + + double obj = find_min(bfgs_search_strategy(), + objective_delta_stop_strategy(1e-13, 5000), + test, + der, + x, + -10); + + + dlog << LINFO << "BFGS: objective: " << obj; + dlog << LINFO << "BFGS: true objective: "<< 0.5*trans(x)*test.Q*x - trans(x)*test.b; + dlog << LINFO << "sum(x): " << sum(x); + dlog << LINFO << x; + + double disagreement = max(abs(x-alpha)); + dlog << LINFO << "Disagreement: " << disagreement; + return disagreement; + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/optimization.cpp b/ml/dlib/dlib/test/optimization.cpp new file mode 100644 index 000000000..b47449abe --- /dev/null +++ b/ml/dlib/dlib/test/optimization.cpp @@ -0,0 +1,1231 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include "optimization_test_functions.h" +#include +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.optimization"); + +// ---------------------------------------------------------------------------------------- + + bool approx_equal ( + double a, + double b + ) + { + return std::abs(a - b) < 100*std::numeric_limits::epsilon(); + } + +// ---------------------------------------------------------------------------------------- + + long total_count = 0; + + + template + double apq ( const T& x) + { + DLIB_ASSERT(x.nr() > 1 && x.nc() == 1,""); + COMPILE_TIME_ASSERT(is_matrix::value); + double temp = 0; + for (long r = 0; r < x.nr(); ++r) + { + temp += (r+1)*x(r)*x(r); + } + + ++total_count; + + return temp + 1/100.0*(x(0) + x(x.nr()-1))*(x(0) + x(x.nr()-1)); + } + + template + T der_apq ( const T& x) + { + DLIB_ASSERT(x.nr() > 1 && x.nc() == 1,""); + COMPILE_TIME_ASSERT(is_matrix::value); + T temp(x.nr()); + for (long r = 0; r < x.nr(); ++r) + { + temp(r) = 2*(r+1)*x(r) ; + } + + temp(0) += 1/50.0*(x(0) + x(x.nr()-1)); + temp(x.nr()-1) += 1/50.0*(x(0) + x(x.nr()-1)); + + ++total_count; + + return temp; + } + +// ---------------------------------------------------------------------------------------- + + // Rosenbrock's function. minimum at (1,1) + double rosen ( const matrix& x) + { + ++total_count; + return 100*pow(x(1) - x(0)*x(0),2) + pow(1 - x(0),2); + } + + matrix der_rosen ( const matrix& x) + { + ++total_count; + matrix res; + res(0) = -400*x(0)*(x(1)-x(0)*x(0)) - 2*(1-x(0)); + res(1) = 200*(x(1)-x(0)*x(0)); + return res; + } + +// ---------------------------------------------------------------------------------------- + + // negative of Rosenbrock's function. minimum at (1,1) + double neg_rosen ( const matrix& x) + { + ++total_count; + return -(100*pow(x(1) - x(0)*x(0),2) + pow(1 - x(0),2)); + } + + matrix der_neg_rosen ( const matrix& x) + { + ++total_count; + matrix res; + res(0) = -400*x(0)*(x(1)-x(0)*x(0)) - 2*(1-x(0)); + res(1) = 200*(x(1)-x(0)*x(0)); + return -res; + } + +// ---------------------------------------------------------------------------------------- + + double simple ( const matrix& x) + { + ++total_count; + return 10*x(0)*x(0) + x(1)*x(1); + } + + matrix der_simple ( const matrix& x) + { + ++total_count; + matrix res; + res(0) = 20*x(0); + res(1) = 2*x(1); + return res; + } + +// ---------------------------------------------------------------------------------------- + + double powell ( const matrix& x) + { + ++total_count; + return pow(x(0) + 10*x(1),2) + + pow(std::sqrt(5.0)*(x(2) - x(3)),2) + + pow((x(1) - 2*x(2))*(x(1) - 2*x(2)),2) + + pow(std::sqrt(10.0)*(x(0) - x(3))*(x(0) - x(3)),2); + } + +// ---------------------------------------------------------------------------------------- + +// a simple function with a minimum at zero + double single_variable_function ( double x) + { + ++total_count; + return 3*x*x + 5; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void test_apq ( + const matrix p + ) + { + typedef matrix T; + const double eps = 1e-12; + const double minf = -10; + matrix x(p.nr()), opt(p.nr()); + set_all_elements(opt, 0); + double val = 0; + + if (p.size() < 20) + dlog << LINFO << "testing with apq and the start point: " << trans(p); + else + dlog << LINFO << "testing with apq and a big vector with " << p.size() << " components."; + + // don't use bfgs on really large vectors + if (p.size() < 20) + { + total_count = 0; + x = p; + val = find_min(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + wrap_function(apq), wrap_function(der_apq), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , apq(x))); + dlog << LINFO << "find_min() bgfs: got apq in " << total_count; + + total_count = 0; + x = p; + find_min(bfgs_search_strategy(), + gradient_norm_stop_strategy(), + wrap_function(apq), wrap_function(der_apq), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + dlog << LINFO << "find_min() bgfs(gn): got apq in " << total_count; + } + + + if (p.size() < 100) + { + total_count = 0; + x = p; + val=find_min_bobyqa(wrap_function(apq), x, 2*x.size()+1, + uniform_matrix(x.size(),1,-1e100), + uniform_matrix(x.size(),1,1e100), + (max(abs(x))+1)/10, + 1e-6, + 10000); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , apq(x))); + dlog << LINFO << "find_min_bobyqa(): got apq in " << total_count; + } + + total_count = 0; + x = p; + val=find_min(lbfgs_search_strategy(10), + objective_delta_stop_strategy(eps), + wrap_function(apq), wrap_function(der_apq), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , apq(x))); + dlog << LINFO << "find_min() lbgfs-10: got apq in " << total_count; + + + total_count = 0; + x = p; + val=find_min(lbfgs_search_strategy(1), + objective_delta_stop_strategy(eps), + wrap_function(apq), wrap_function(der_apq), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , apq(x))); + dlog << LINFO << "find_min() lbgfs-1: got apq in " << total_count; + + + total_count = 0; + x = p; + val=find_min(cg_search_strategy(), + objective_delta_stop_strategy(eps), + wrap_function(apq), wrap_function(der_apq), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , apq(x))); + dlog << LINFO << "find_min() cg: got apq in " << total_count; + + + // don't do approximate derivative tests if the input point is really long + if (p.size() < 20) + { + total_count = 0; + x = p; + val=find_min(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + wrap_function(apq), derivative(wrap_function(apq)), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , apq(x))); + dlog << LINFO << "find_min() bfgs: got apq/noder in " << total_count; + + + total_count = 0; + x = p; + val=find_min(cg_search_strategy(), + objective_delta_stop_strategy(eps), + wrap_function(apq), derivative(wrap_function(apq)), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , apq(x))); + dlog << LINFO << "find_min() cg: got apq/noder in " << total_count; + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + wrap_function(apq), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , apq(x))); + dlog << LINFO << "find_min() bfgs: got apq/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(lbfgs_search_strategy(10), + objective_delta_stop_strategy(eps), + wrap_function(apq), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + dlog << LINFO << "find_min() lbfgs-10: got apq/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(cg_search_strategy(), + objective_delta_stop_strategy(eps), + wrap_function(apq), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , apq(x))); + dlog << LINFO << "find_min() cg: got apq/noder2 in " << total_count; + } + } + + void test_powell ( + const matrix p + ) + { + const double eps = 1e-15; + const double minf = -1; + matrix x, opt; + opt(0) = 0; + opt(1) = 0; + opt(2) = 0; + opt(3) = 0; + + double val = 0; + + dlog << LINFO << "testing with powell and the start point: " << trans(p); + + /* + total_count = 0; + x = p; + val=find_min(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + powell, derivative(powell,1e-8), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-2),opt-x); + DLIB_TEST(approx_equal(val , powell(x))); + dlog << LINFO << "find_min() bfgs: got powell/noder in " << total_count; + + + total_count = 0; + x = p; + val=find_min(cg_search_strategy(), + objective_delta_stop_strategy(eps), + powell, derivative(powell,1e-9), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-2),opt-x); + DLIB_TEST(approx_equal(val , powell(x))); + dlog << LINFO << "find_min() cg: got powell/noder in " << total_count; + */ + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + powell, x, minf, 1e-10); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-1),opt-x); + DLIB_TEST(approx_equal(val , powell(x))); + dlog << LINFO << "find_min() bfgs: got powell/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(lbfgs_search_strategy(4), + objective_delta_stop_strategy(eps), + powell, x, minf, 1e-10); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-1),opt-x); + DLIB_TEST(approx_equal(val , powell(x))); + dlog << LINFO << "find_min() lbfgs-4: got powell/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(lbfgs_search_strategy(4), + gradient_norm_stop_strategy(), + powell, x, minf, 1e-10); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-1),opt-x); + DLIB_TEST(approx_equal(val , powell(x))); + dlog << LINFO << "find_min() lbfgs-4(gn): got powell/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(cg_search_strategy(), + objective_delta_stop_strategy(eps), + powell, x, minf, 1e-10); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-1),opt-x); + DLIB_TEST(approx_equal(val , powell(x))); + dlog << LINFO << "find_min() cg: got powell/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_min_bobyqa(powell, x, 2*x.size()+1, + uniform_matrix(x.size(),1,-1e100), + uniform_matrix(x.size(),1,1e100), + (max(abs(x))+1)/10, + 1e-8, + 10000); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-3),opt-x); + DLIB_TEST(approx_equal(val , powell(x))); + dlog << LINFO << "find_min_bobyqa(): got powell in " << total_count; + + } + + + + void test_simple ( + const matrix p + ) + { + const double eps = 1e-12; + const double minf = -10000; + matrix x, opt; + opt(0) = 0; + opt(1) = 0; + double val = 0; + + dlog << LINFO << "testing with simple and the start point: " << trans(p); + + total_count = 0; + x = p; + val=find_min(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + simple, der_simple, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() bfgs: got simple in " << total_count; + + + total_count = 0; + x = p; + val=find_min(bfgs_search_strategy(), + gradient_norm_stop_strategy(), + simple, der_simple, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() bfgs(gn): got simple in " << total_count; + + + total_count = 0; + x = p; + val=find_min(lbfgs_search_strategy(3), + objective_delta_stop_strategy(eps), + simple, der_simple, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() lbfgs-3: got simple in " << total_count; + + + total_count = 0; + x = p; + val=find_min(cg_search_strategy(), + objective_delta_stop_strategy(eps), + simple, der_simple, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() cg: got simple in " << total_count; + + + + total_count = 0; + x = p; + val=find_min(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + simple, derivative(simple), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() bfgs: got simple/noder in " << total_count; + + + total_count = 0; + x = p; + val=find_min(lbfgs_search_strategy(8), + objective_delta_stop_strategy(eps), + simple, derivative(simple), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() lbfgs-8: got simple/noder in " << total_count; + + + total_count = 0; + x = p; + val=find_min(cg_search_strategy(), + objective_delta_stop_strategy(eps), + simple, derivative(simple), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() cg: got simple/noder in " << total_count; + + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + simple, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() bfgs: got simple/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(lbfgs_search_strategy(6), + objective_delta_stop_strategy(eps), + simple, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() lbfgs-6: got simple/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(cg_search_strategy(), + objective_delta_stop_strategy(eps), + simple, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min() cg: got simple/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_min_bobyqa(simple, x, 2*x.size()+1, + uniform_matrix(x.size(),1,-1e100), + uniform_matrix(x.size(),1,1e100), + (max(abs(x))+1)/10, + 1e-6, + 10000); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , simple(x))); + dlog << LINFO << "find_min_bobyqa(): got simple in " << total_count; + + } + + + void test_rosen ( + const matrix p + ) + { + const double eps = 1e-15; + const double minf = -10; + matrix x, opt; + opt(0) = 1; + opt(1) = 1; + + double val = 0; + + dlog << LINFO << "testing with rosen and the start point: " << trans(p); + + total_count = 0; + x = p; + val=find_min(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + rosen, der_rosen, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x); + DLIB_TEST(approx_equal(val , rosen(x))); + dlog << LINFO << "find_min() bfgs: got rosen in " << total_count; + + + total_count = 0; + x = p; + val=find_min(bfgs_search_strategy(), + gradient_norm_stop_strategy(), + rosen, der_rosen, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x); + DLIB_TEST(approx_equal(val , rosen(x))); + dlog << LINFO << "find_min() bfgs(gn): got rosen in " << total_count; + + + total_count = 0; + x = p; + val=find_min(lbfgs_search_strategy(20), + objective_delta_stop_strategy(eps), + rosen, der_rosen, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x); + DLIB_TEST(approx_equal(val , rosen(x))); + dlog << LINFO << "find_min() lbfgs-20: got rosen in " << total_count; + + + total_count = 0; + x = p; + val=find_min(cg_search_strategy(), + objective_delta_stop_strategy(eps), + rosen, der_rosen, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x); + DLIB_TEST(approx_equal(val , rosen(x))); + dlog << LINFO << "find_min() cg: got rosen in " << total_count; + + + + total_count = 0; + x = p; + val=find_min(bfgs_search_strategy(), + objective_delta_stop_strategy(eps), + rosen, derivative(rosen,1e-5), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-4),opt-x); + DLIB_TEST(approx_equal(val , rosen(x))); + dlog << LINFO << "find_min() bfgs: got rosen/noder in " << total_count; + + + total_count = 0; + x = p; + val=find_min(lbfgs_search_strategy(5), + objective_delta_stop_strategy(eps), + rosen, derivative(rosen,1e-5), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-4),opt-x); + DLIB_TEST(approx_equal(val , rosen(x))); + dlog << LINFO << "find_min() lbfgs-5: got rosen/noder in " << total_count; + + + total_count = 0; + x = p; + val=find_min(cg_search_strategy(), + objective_delta_stop_strategy(eps), + rosen, derivative(rosen,1e-5), x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-4),opt-x); + DLIB_TEST(approx_equal(val , rosen(x))); + dlog << LINFO << "find_min() cg: got rosen/noder in " << total_count; + + + total_count = 0; + x = p; + val=find_min_using_approximate_derivatives(cg_search_strategy(), + objective_delta_stop_strategy(eps), + rosen, x, minf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-4),opt-x); + DLIB_TEST(approx_equal(val , rosen(x))); + dlog << LINFO << "find_min() cg: got rosen/noder2 in " << total_count; + + + if (max(abs(p)) < 1000) + { + total_count = 0; + x = p; + val=find_min_bobyqa(rosen, x, 2*x.size()+1, + uniform_matrix(x.size(),1,-1e100), + uniform_matrix(x.size(),1,1e100), + (max(abs(x))+1)/10, + 1e-6, + 10000); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , rosen(x))); + dlog << LINFO << "find_min_bobyqa(): got rosen in " << total_count; + } + } + + + void test_neg_rosen ( + const matrix p + ) + { + const double eps = 1e-15; + const double maxf = 10; + matrix x, opt; + opt(0) = 1; + opt(1) = 1; + + double val = 0; + + dlog << LINFO << "testing with neg_rosen and the start point: " << trans(p); + + total_count = 0; + x = p; + val=find_max( + bfgs_search_strategy(), + objective_delta_stop_strategy(eps), neg_rosen, der_neg_rosen, x, maxf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x); + DLIB_TEST(approx_equal(val , neg_rosen(x))); + dlog << LINFO << "find_max() bfgs: got neg_rosen in " << total_count; + + total_count = 0; + x = p; + val=find_max( + lbfgs_search_strategy(5), + objective_delta_stop_strategy(eps), neg_rosen, der_neg_rosen, x, maxf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x); + DLIB_TEST(approx_equal(val , neg_rosen(x))); + dlog << LINFO << "find_max() lbfgs-5: got neg_rosen in " << total_count; + + total_count = 0; + x = p; + val=find_max( + lbfgs_search_strategy(5), + objective_delta_stop_strategy(eps), neg_rosen, derivative(neg_rosen), x, maxf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x); + DLIB_TEST(approx_equal(val , neg_rosen(x))); + dlog << LINFO << "find_max() lbfgs-5: got neg_rosen/noder in " << total_count; + + + total_count = 0; + x = p; + val=find_max_using_approximate_derivatives( + cg_search_strategy(), + objective_delta_stop_strategy(eps), neg_rosen, x, maxf); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x); + DLIB_TEST(approx_equal(val , neg_rosen(x))); + dlog << LINFO << "find_max() cg: got neg_rosen/noder2 in " << total_count; + + + total_count = 0; + x = p; + val=find_max_bobyqa(neg_rosen, x, 2*x.size()+1, + uniform_matrix(x.size(),1,-1e100), + uniform_matrix(x.size(),1,1e100), + (max(abs(x))+1)/10, + 1e-6, + 10000); + DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x); + DLIB_TEST(approx_equal(val , neg_rosen(x))); + dlog << LINFO << "find_max_bobyqa(): got neg_rosen in " << total_count; + } + +// ---------------------------------------------------------------------------------------- + + void test_single_variable_function ( + const double p + ) + { + const double eps = 1e-7; + + + dlog << LINFO << "testing with single_variable_function and the start point: " << p; + double out, x; + + total_count = 0; + x = p; + out = find_min_single_variable(single_variable_function, x, -1e100, 1e100, eps, 1000); + DLIB_TEST_MSG(std::abs(out-5) < 1e-6, out-5); + DLIB_TEST_MSG(std::abs(x) < 1e-6, x); + dlog << LINFO << "find_min_single_variable(): got single_variable_function in " << total_count; + + + total_count = 0; + x = p; + out = -find_max_single_variable(negate_function(single_variable_function), x, -1e100, 1e100, eps, 1000); + DLIB_TEST_MSG(std::abs(out-5) < 1e-6, out-5); + DLIB_TEST_MSG(std::abs(x) < 1e-6, x); + dlog << LINFO << "find_max_single_variable(): got single_variable_function in " << total_count; + + + if (p > 0) + { + total_count = 0; + x = p; + out = find_min_single_variable(single_variable_function, x, -1e-4, 1e100, eps, 1000); + DLIB_TEST_MSG(std::abs(out-5) < 1e-6, out-5); + DLIB_TEST_MSG(std::abs(x) < 1e-6, x); + dlog << LINFO << "find_min_single_variable(): got single_variable_function in " << total_count; + + + if (p > 3) + { + total_count = 0; + x = p; + out = -find_max_single_variable(negate_function(single_variable_function), x, 3, 1e100, eps, 1000); + DLIB_TEST_MSG(std::abs(out - (3*3*3+5)) < 1e-6, out-(3*3*3+5)); + DLIB_TEST_MSG(std::abs(x-3) < 1e-6, x); + dlog << LINFO << "find_max_single_variable(): got single_variable_function in " << total_count; + } + } + + if (p < 0) + { + total_count = 0; + x = p; + out = find_min_single_variable(single_variable_function, x, -1e100, 1e-4, eps, 1000); + DLIB_TEST_MSG(std::abs(out-5) < 1e-6, out-5); + DLIB_TEST_MSG(std::abs(x) < 1e-6, x); + dlog << LINFO << "find_min_single_variable(): got single_variable_function in " << total_count; + + if (p < -3) + { + total_count = 0; + x = p; + out = find_min_single_variable(single_variable_function, x, -1e100, -3, eps, 1000); + DLIB_TEST_MSG(std::abs(out - (3*3*3+5)) < 1e-6, out-(3*3*3+5)); + DLIB_TEST_MSG(std::abs(x+3) < 1e-6, x); + dlog << LINFO << "find_min_single_variable(): got single_variable_function in " << total_count; + } + } + + } + +// ---------------------------------------------------------------------------------------- + + void optimization_test ( + ) + /*! + ensures + - runs tests on the optimization stuff compliance with the specs + !*/ + { + matrix p; + + print_spinner(); + + p.set_size(2); + + // test with single_variable_function + test_single_variable_function(0); + test_single_variable_function(1); + test_single_variable_function(-10); + test_single_variable_function(-100); + test_single_variable_function(900.53); + + // test with the rosen function + p(0) = 9; + p(1) = -4.9; + test_rosen(p); + test_neg_rosen(p); + + p(0) = 0; + p(1) = 0; + test_rosen(p); + + p(0) = 5323; + p(1) = 98248; + test_rosen(p); + + // test with the simple function + p(0) = 1; + p(1) = 1; + test_simple(p); + + p(0) = 0.5; + p(1) = -9; + test_simple(p); + + p(0) = 645; + p(1) = 839485; + test_simple(p); + + print_spinner(); + + // test with the apq function + p.set_size(5); + + p(0) = 1; + p(1) = 1; + p(2) = 1; + p(3) = 1; + p(4) = 1; + test_apq(p); + + p(0) = 1; + p(1) = 2; + p(2) = 3; + p(3) = 4; + p(4) = 5; + test_apq(p); + + p(0) = 1; + p(1) = 2; + p(2) = -3; + p(3) = 4; + p(4) = 5; + test_apq(p); + + print_spinner(); + + p(0) = 1; + p(1) = 2324; + p(2) = -3; + p(3) = 4; + p(4) = 534534; + test_apq(p); + + p.set_size(10); + p(0) = 1; + p(1) = 2; + p(2) = -3; + p(3) = 4; + p(4) = 5; + p(5) = 1; + p(6) = 2; + p(7) = -3; + p(8) = 4; + p(9) = 5; + test_apq(p); + + // test apq with a big vector + p.set_size(500); + dlib::rand rnd; + for (long i = 0; i < p.size(); ++i) + { + p(i) = rnd.get_random_double()*20 - 10; + } + test_apq(p); + + print_spinner(); + + // test with the powell function + p.set_size(4); + + p(0) = 3; + p(1) = -1; + p(2) = 0; + p(3) = 1; + test_powell(p); + + { + matrix m; + m(0) = -0.43; + m(1) = 0.919; + DLIB_TEST(dlib::equal(der_rosen(m) , derivative(rosen)(m),1e-5)); + + DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(0) - + make_line_search_function(derivative(rosen),m,m)(0)) < 1e-5,""); + DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(1) - + make_line_search_function(derivative(rosen),m,m)(1)) < 1e-5,""); + + DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(0) - + make_line_search_function(der_rosen,m,m)(0)) < 1e-5,""); + DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(1) - + make_line_search_function(der_rosen,m,m)(1)) < 1e-5,""); + } + { + matrix m; + m(0) = 1; + m(1) = 2; + DLIB_TEST(dlib::equal(der_rosen(m) , derivative(rosen)(m),1e-5)); + + DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(0) - + make_line_search_function(derivative(rosen),m,m)(0)) < 1e-5,""); + DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(1) - + make_line_search_function(derivative(rosen),m,m)(1)) < 1e-5,""); + + DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(0) - + make_line_search_function(der_rosen,m,m)(0)) < 1e-5,""); + DLIB_TEST_MSG(std::abs(derivative(make_line_search_function(rosen,m,m))(1) - + make_line_search_function(der_rosen,m,m)(1)) < 1e-5,""); + } + + { + matrix m; + m = 1,2; + DLIB_TEST(std::abs(neg_rosen(m) - negate_function(rosen)(m) ) < 1e-16); + } + + } + + template + double unconstrained_gradient_magnitude ( + const der_funct& grad, + const T& x, + const T& lower, + const T& upper + ) + { + T g = grad(x); + + double unorm = 0; + + for (long i = 0; i < g.size(); ++i) + { + if (lower(i) < x(i) && x(i) < upper(i)) + unorm += g(i)*g(i); + else if (x(i) == lower(i) && g(i) < 0) + unorm += g(i)*g(i); + else if (x(i) == upper(i) && g(i) > 0) + unorm += g(i)*g(i); + } + + return unorm; + } + + template + double unconstrained_gradient_magnitude_neg_funct ( + const der_funct& grad, + const T& x, + const T& lower, + const T& upper + ) + { + T g = grad(x); + + double unorm = 0; + + for (long i = 0; i < g.size(); ++i) + { + if (lower(i) < x(i) && x(i) < upper(i)) + unorm += g(i)*g(i); + else if (x(i) == lower(i) && g(i) > 0) + unorm += g(i)*g(i); + else if (x(i) == upper(i) && g(i) < 0) + unorm += g(i)*g(i); + } + + return unorm; + } + + template + double test_bound_solver_neg_rosen (dlib::rand& rnd, search_strategy_type search_strategy) + { + using namespace dlib::test_functions; + print_spinner(); + matrix starting_point, lower, upper, x; + + + // pick random bounds + lower = rnd.get_random_gaussian()+1, rnd.get_random_gaussian()+1; + upper = rnd.get_random_gaussian()+1, rnd.get_random_gaussian()+1; + while (upper(0) < lower(0)) upper(0) = rnd.get_random_gaussian()+1; + while (upper(1) < lower(1)) upper(1) = rnd.get_random_gaussian()+1; + + starting_point = rnd.get_random_double()*(upper(0)-lower(0))+lower(0), + rnd.get_random_double()*(upper(1)-lower(1))+lower(1); + + dlog << LINFO << "lower: "<< trans(lower); + dlog << LINFO << "upper: "<< trans(upper); + dlog << LINFO << "starting: "<< trans(starting_point); + + x = starting_point; + double val = find_max_box_constrained( + search_strategy, + objective_delta_stop_strategy(1e-16, 500), + neg_rosen, der_neg_rosen, x, + lower, + upper + ); + + DLIB_TEST_MSG(std::abs(val - neg_rosen(x)) < 1e-11, std::abs(val - neg_rosen(x))); + dlog << LINFO << "neg_rosen solution:\n" << x; + + dlog << LINFO << "neg_rosen gradient: "<< trans(der_neg_rosen(x)); + const double gradient_residual = unconstrained_gradient_magnitude_neg_funct(der_neg_rosen, x, lower, upper); + dlog << LINFO << "gradient_residual: "<< gradient_residual; + + return gradient_residual; + } + + template + double test_bound_solver_rosen (dlib::rand& rnd, search_strategy_type search_strategy) + { + using namespace dlib::test_functions; + print_spinner(); + matrix starting_point, lower, upper, x; + + + // pick random bounds and sometimes put the upper bound at zero so we can have + // a test where the optimal value has a bound active at 0 so make sure this case + // works properly. + if (rnd.get_random_double() > 0.2) + { + lower = rnd.get_random_gaussian()+1, rnd.get_random_gaussian()+1; + upper = rnd.get_random_gaussian()+1, rnd.get_random_gaussian()+1; + while (upper(0) < lower(0)) upper(0) = rnd.get_random_gaussian()+1; + while (upper(1) < lower(1)) upper(1) = rnd.get_random_gaussian()+1; + } + else + { + upper = 0,0; + if (rnd.get_random_double() > 0.5) + upper(0) = -rnd.get_random_double(); + if (rnd.get_random_double() > 0.5) + upper(1) = -rnd.get_random_double(); + + lower = rnd.get_random_double()+1, rnd.get_random_double()+1; + lower = upper - lower; + } + const bool pick_uniform_bounds = rnd.get_random_double() > 0.9; + if (pick_uniform_bounds) + { + double x = rnd.get_random_gaussian()*2; + double y = rnd.get_random_gaussian()*2; + lower = min(x,y); + upper = max(x,y); + } + + starting_point = rnd.get_random_double()*(upper(0)-lower(0))+lower(0), + rnd.get_random_double()*(upper(1)-lower(1))+lower(1); + + dlog << LINFO << "lower: "<< trans(lower); + dlog << LINFO << "upper: "<< trans(upper); + dlog << LINFO << "starting: "<< trans(starting_point); + + x = starting_point; + double val; + if (!pick_uniform_bounds) + { + val = find_min_box_constrained( + search_strategy, + objective_delta_stop_strategy(1e-16, 500), + rosen, der_rosen, x, + lower, + upper + ); + } + else + { + val = find_min_box_constrained( + search_strategy, + objective_delta_stop_strategy(1e-16, 500), + rosen, der_rosen, x, + lower(0), + upper(0) + ); + } + + + DLIB_TEST_MSG(std::abs(val - rosen(x)) < 1e-11, std::abs(val - rosen(x))); + dlog << LINFO << "rosen solution:\n" << x; + + dlog << LINFO << "rosen gradient: "<< trans(der_rosen(x)); + const double gradient_residual = unconstrained_gradient_magnitude(der_rosen, x, lower, upper); + dlog << LINFO << "gradient_residual: "<< gradient_residual; + + return gradient_residual; + } + + template + double test_bound_solver_brown (dlib::rand& rnd, search_strategy_type search_strategy) + { + using namespace dlib::test_functions; + print_spinner(); + matrix starting_point(4), lower(4), upper(4), x; + + const matrix solution = brown_solution(); + + // pick random bounds + lower = rnd.get_random_gaussian(), rnd.get_random_gaussian(), rnd.get_random_gaussian(), rnd.get_random_gaussian(); + lower = lower*10 + solution; + upper = rnd.get_random_gaussian(), rnd.get_random_gaussian(), rnd.get_random_gaussian(), rnd.get_random_gaussian(); + upper = upper*10 + solution; + for (int i = 0; i < lower.size(); ++i) + { + if (upper(i) < lower(i)) + swap(upper(i),lower(i)); + } + + starting_point = rnd.get_random_double()*(upper(0)-lower(0))+lower(0), + rnd.get_random_double()*(upper(1)-lower(1))+lower(1), + rnd.get_random_double()*(upper(2)-lower(2))+lower(2), + rnd.get_random_double()*(upper(3)-lower(3))+lower(3); + + dlog << LINFO << "lower: "<< trans(lower); + dlog << LINFO << "upper: "<< trans(upper); + dlog << LINFO << "starting: "<< trans(starting_point); + + x = starting_point; + double val = find_min_box_constrained( + search_strategy, + objective_delta_stop_strategy(1e-16, 500), + brown, brown_derivative, x, + lower, + upper + ); + + DLIB_TEST(std::abs(val - brown(x)) < 1e-14); + dlog << LINFO << "brown solution:\n" << x; + return unconstrained_gradient_magnitude(brown_derivative, x, lower, upper); + } + + template + void test_box_constrained_optimizers(search_strategy_type search_strategy) + { + dlib::rand rnd; + running_stats rs; + + dlog << LINFO << "test find_min_box_constrained() on rosen"; + for (int i = 0; i < 10000; ++i) + rs.add(test_bound_solver_rosen(rnd, search_strategy)); + dlog << LINFO << "mean rosen gradient: " << rs.mean(); + dlog << LINFO << "max rosen gradient: " << rs.max(); + DLIB_TEST(rs.mean() < 1e-12); + DLIB_TEST(rs.max() < 1e-9); + + dlog << LINFO << "test find_min_box_constrained() on brown"; + rs.clear(); + for (int i = 0; i < 1000; ++i) + rs.add(test_bound_solver_brown(rnd, search_strategy)); + dlog << LINFO << "mean brown gradient: " << rs.mean(); + dlog << LINFO << "max brown gradient: " << rs.max(); + dlog << LINFO << "min brown gradient: " << rs.min(); + DLIB_TEST(rs.mean() < 4e-5); + DLIB_TEST_MSG(rs.max() < 3e-2, rs.max()); + DLIB_TEST(rs.min() < 1e-10); + + dlog << LINFO << "test find_max_box_constrained() on neg_rosen"; + rs.clear(); + for (int i = 0; i < 1000; ++i) + rs.add(test_bound_solver_neg_rosen(rnd, search_strategy)); + dlog << LINFO << "mean neg_rosen gradient: " << rs.mean(); + dlog << LINFO << "max neg_rosen gradient: " << rs.max(); + DLIB_TEST(rs.mean() < 1e-12); + DLIB_TEST(rs.max() < 1e-9); + + } + + void test_poly_min_extract_2nd() + { + double off; + + off = 0.0; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + off = 0.1; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + off = 0.2; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + off = 0.3; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + off = 0.4; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + off = 0.5; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + off = 0.6; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + off = 0.8; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + off = 0.9; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + off = 1.0; DLIB_TEST(std::abs( poly_min_extrap(off*off, -2*off, (1-off)*(1-off)) - off) < 1e-13); + } + + void test_solve_trust_region_subproblem_bounded() + { + print_spinner(); + matrix H(2,2); + H = 1, 0, + 0, 1; + matrix g, lower, upper, p, true_p; + g = {0, 0}; + + double radius = 0.5; + lower = {0.5, 0}; + upper = {10, 10}; + + + solve_trust_region_subproblem_bounded(H,g, radius, p, 0.001, 500, lower, upper); + true_p = { 0.5, 0}; + DLIB_TEST_MSG(length(p-true_p) < 1e-12, p); + + } + +// ---------------------------------------------------------------------------------------- + + class optimization_tester : public tester + { + public: + optimization_tester ( + ) : + tester ("test_optimization", + "Runs tests on the optimization component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "test_box_constrained_optimizers(bfgs_search_strategy())"; + test_box_constrained_optimizers(bfgs_search_strategy()); + dlog << LINFO << "test_box_constrained_optimizers(lbfgs_search_strategy(5))"; + test_box_constrained_optimizers(lbfgs_search_strategy(5)); + test_poly_min_extract_2nd(); + optimization_test(); + test_solve_trust_region_subproblem_bounded(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/optimization_test_functions.cpp b/ml/dlib/dlib/test/optimization_test_functions.cpp new file mode 100644 index 000000000..1cded50ee --- /dev/null +++ b/ml/dlib/dlib/test/optimization_test_functions.cpp @@ -0,0 +1,425 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include "optimization_test_functions.h" + +/* + + Most of the code in this file is converted from the set of Fortran 90 routines + created by John Burkardt. + + The original Fortran can be found here: http://orion.math.iastate.edu/burkardt/f_src/testopt/testopt.html + +*/ + + +namespace dlib +{ + namespace test_functions + { + + // ---------------------------------------------------------------------------------------- + + matrix chebyquad_residuals(const matrix& x) + { + matrix fvec(x.size()); + const int n = x.size(); + int i; + int j; + double t; + double t1; + double t2; + double th; + fvec = 0; + + for (j = 1; j <= n; ++j) + { + t1 = 1.0E+00; + t2 = 2.0E+00 * x(j-1) - 1.0E+00; + t = 2.0E+00 * t2; + for (i = 1; i <= n; ++i) + { + fvec(i-1) = fvec(i-1) + t2; + th = t * t2 - t1; + t1 = t2; + t2 = th; + } + } + + for (i = 1; i <= n; ++i) + { + fvec(i-1) = fvec(i-1) / (double) ( n ); + if ( ( i%2 ) == 0 ) + fvec(i-1) = fvec(i-1) + 1.0E+00 / ( (double)i*i - 1.0E+00 ); + } + + return fvec; + } + + // ---------------------------------------------------------------------------------------- + + double chebyquad_residual(int i, const matrix& x) + { + return chebyquad_residuals(x)(i); + } + + // ---------------------------------------------------------------------------------------- + + int& chebyquad_calls() + { + static int count = 0; + return count; + } + + double chebyquad(const matrix& x ) + { + chebyquad_calls()++; + return sum(squared(chebyquad_residuals(x))); + } + + // ---------------------------------------------------------------------------------------- + + matrix chebyquad_derivative (const matrix& x) + { + const int n = x.size(); + matrix fvec = chebyquad_residuals(x); + matrix g(n); + int i; + int j; + double s1; + double s2; + double t; + double t1; + double t2; + double th; + + for (j = 1; j <= n; ++j) + { + g(j-1) = 0.0E+00; + t1 = 1.0E+00; + t2 = 2.0E+00 * x(j-1) - 1.0E+00; + t = 2.0E+00 * t2; + s1 = 0.0E+00; + s2 = 2.0E+00; + for (i = 1; i <= n; ++i) + { + g(j-1) = g(j-1) + fvec(i-1) * s2; + th = 4.0E+00 * t2 + t * s2 - s1; + s1 = s2; + s2 = th; + th = t * t2 - t1; + t1 = t2; + t2 = th; + } + } + + g = 2.0E+00 * g / (double) ( n ); + + return g; + } + + // ---------------------------------------------------------------------------------------- + + matrix chebyquad_start (int n) + { + int i; + matrix x(n); + + for (i = 1; i <= n; ++i) + x(i-1) = double ( i ) / double ( n + 1 ); + + return x; + } + + // ---------------------------------------------------------------------------------------- + + matrix chebyquad_solution (int n) + { + matrix x(n); + + x = 0; + switch (n) + { + case 2: + x = 0.2113249E+00, 0.7886751E+00; + break; + case 4: + x = 0.1026728E+00, 0.4062037E+00, 0.5937963E+00, 0.8973272E+00; + break; + case 6: + x = 0.066877E+00, 0.288741E+00, 0.366682E+00, 0.633318E+00, 0.711259E+00, 0.933123E+00; + break; + case 8: + x = 0.043153E+00, 0.193091E+00, 0.266329E+00, 0.500000E+00, 0.500000E+00, 0.733671E+00, 0.806910E+00, 0.956847E+00; + break; + default: + std::ostringstream sout; + sout << "don't know chebyquad solution for n = " << n; + throw dlib::error(sout.str()); + break; + } + + return x; + } + + // ---------------------------------------------------------------------------------------- + + matrix chebyquad_hessian(const matrix& x) + { + const int lda = x.size(); + const int n = x.size(); + double d1; + double d2; + matrix fvec = chebyquad_residuals(x); + matrix gvec(n); + matrix h(lda,n); + int i; + int j; + int k; + double p1; + double p2; + double s1; + double s2; + double ss1; + double ss2; + double t; + double t1; + double t2; + double th; + double tt; + double tth; + double tt1; + double tt2; + h = 0; + + d1 = 1.0E+00 / double ( n ); + d2 = 2.0E+00 * d1; + + for (j = 1; j <= n; ++j) + { + + h(j-1,j-1) = 4.0E+00 * d1; + t1 = 1.0E+00; + t2 = 2.0E+00 * x(j-1) - 1.0E+00; + t = 2.0E+00 * t2; + s1 = 0.0E+00; + s2 = 2.0E+00; + p1 = 0.0E+00; + p2 = 0.0E+00; + gvec(0) = s2; + + for (i = 2; i <= n; ++i) + { + th = 4.0E+00 * t2 + t * s2 - s1; + s1 = s2; + s2 = th; + th = t * t2 - t1; + t1 = t2; + t2 = th; + th = 8.0E+00 * s1 + t * p2 - p1; + p1 = p2; + p2 = th; + gvec(i-1) = s2; + h(j-1,j-1) = h(j-1,j-1) + fvec(i-1) * th + d1 * s2*s2; + } + + h(j-1,j-1) = d2 * h(j-1,j-1); + + for (k = 1; k <= j-1; ++k) + { + + h(j-1,k-1) = 0.0; + tt1 = 1.0E+00; + tt2 = 2.0E+00 * x(k-1) - 1.0E+00; + tt = 2.0E+00 * tt2; + ss1 = 0.0E+00; + ss2 = 2.0E+00; + + for (i = 1; i <= n; ++i) + { + h(j-1,k-1) = h(j-1,k-1) + ss2 * gvec(i-1); + tth = 4.0E+00 * tt2 + tt * ss2 - ss1; + ss1 = ss2; + ss2 = tth; + tth = tt * tt2 - tt1; + tt1 = tt2; + tt2 = tth; + } + + h(j-1,k-1) = d2 * d1 * h(j-1,k-1); + + } + + } + + h = make_symmetric(h); + return h; + } + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + double brown_residual (int i, const matrix& x) + /*! + requires + - 1 <= i <= 20 + ensures + - returns the ith brown residual + !*/ + { + double c; + double f; + double f1; + double f2; + + f = 0.0E+00; + + + c = double ( i ) / 5.0E+00; + f1 = x(0) + c * x(1) - std::exp ( c ); + f2 = x(2) + std::sin ( c ) * x(3) - std::cos ( c ); + + f = f1*f1 + f2*f2; + + return f; + } + + // ---------------------------------------------------------------------------------------- + + double brown ( const matrix& x) + { + double f; + int i; + + f = 0; + + for (i = 1; i <= 20; ++i) + { + f += std::pow(brown_residual(i, x), 2); + } + + return f; + } + + // ---------------------------------------------------------------------------------------- + + matrix brown_derivative ( const matrix& x) + { + double c; + double df1dx1; + double df1dx2; + double df2dx3; + double df2dx4; + double f1; + double f2; + matrix g; + int i; + + g = 0; + + for (i = 1; i <= 20; ++i) + { + + c = double ( i ) / 5.0E+00; + + f1 = x(0) + c * x(1) - std::exp ( c ); + f2 = x(2) + std::sin ( c ) * x(3) - std::cos ( c ); + + df1dx1 = 1.0E+00; + df1dx2 = c; + df2dx3 = 1.0E+00; + df2dx4 = std::sin ( c ); + + using std::pow; + g(0) = g(0) + 4.0E+00 * ( pow(f1,3) * df1dx1 + f1 * pow(f2,2) * df1dx1 ); + g(1) = g(1) + 4.0E+00 * ( pow(f1,3) * df1dx2 + f1 * pow(f2,2) * df1dx2 ); + g(2) = g(2) + 4.0E+00 * ( pow(f1,2) * f2 * df2dx3 + pow(f2,3) * df2dx3 ); + g(3) = g(3) + 4.0E+00 * ( pow(f1,2) * f2 * df2dx4 + pow(f2,3) * df2dx4 ); + + } + + return g; + } + + // ---------------------------------------------------------------------------------------- + + matrix brown_hessian ( const matrix& x) + { + double c; + double df1dx1; + double df1dx2; + double df2dx3; + double df2dx4; + double f1; + double f2; + matrix h; + int i; + + h = 0; + + for (i = 1; i <= 20; ++i) + { + + c = double ( i ) / 5.0E+00; + + f1 = x(0) + c * x(1) - std::exp ( c ); + f2 = x(2) + std::sin ( c ) * x(3) - std::cos ( c ); + + df1dx1 = 1.0E+00; + df1dx2 = c; + df2dx3 = 1.0E+00; + df2dx4 = std::sin ( c ); + + using std::pow; + h(0,0) = h(0,0) + 12.0E+00 * pow(f1,2) * df1dx1 * df1dx1 + 4.0E+00 * pow(f2,2) * df1dx1 * df1dx1; + h(0,1) = h(0,1) + 12.0E+00 * pow(f1,2) * df1dx1 * df1dx2 + 4.0E+00 * pow(f2,2) * df1dx1 * df1dx2; + h(0,2) = h(0,2) + 8.0E+00 * f1 * f2 * df1dx1 * df2dx3; + h(0,3) = h(0,3) + 8.0E+00 * f1 * f2 * df1dx1 * df2dx4; + + h(1,0) = h(1,0) + 12.0E+00 * pow(f1,2) * df1dx2 * df1dx1 + 4.0E+00 * pow(f2,2) * df1dx2 * df1dx1; + h(1,1) = h(1,1) + 12.0E+00 * pow(f1,2) * df1dx2 * df1dx2 + 4.0E+00 * pow(f2,2) * df1dx2 * df1dx2; + h(1,2) = h(1,2) + 8.0E+00 * f1 * f2 * df1dx2 * df2dx3; + h(1,3) = h(1,3) + 8.0E+00 * f1 * f2 * df1dx2 * df2dx4; + + h(2,0) = h(2,0) + 8.0E+00 * f1 * f2 * df2dx3 * df1dx1; + h(2,1) = h(2,1) + 8.0E+00 * f1 * f2 * df2dx3 * df1dx2; + h(2,2) = h(2,2) + 4.0E+00 * pow(f1,2) * df2dx3 * df2dx3 + 12.0E+00 * pow(f2,2) * df2dx3 * df2dx3; + h(2,3) = h(2,3) + 4.0E+00 * pow(f1,2) * df2dx4 * df2dx3 + 12.0E+00 * pow(f2,2) * df2dx3 * df2dx4; + + h(3,0) = h(3,0) + 8.0E+00 * f1 * f2 * df2dx4 * df1dx1; + h(3,1) = h(3,1) + 8.0E+00 * f1 * f2 * df2dx4 * df1dx2; + h(3,2) = h(3,2) + 4.0E+00 * pow(f1,2) * df2dx3 * df2dx4 + 12.0E+00 * pow(f2,2) * df2dx4 * df2dx3; + h(3,3) = h(3,3) + 4.0E+00 * pow(f1,2) * df2dx4 * df2dx4 + 12.0E+00 * pow(f2,2) * df2dx4 * df2dx4; + + } + + return make_symmetric(h); + } + + // ---------------------------------------------------------------------------------------- + + matrix brown_start () + { + matrix x; + x = 25.0E+00, 5.0E+00, -5.0E+00, -1.0E+00; + return x; + } + + // ---------------------------------------------------------------------------------------- + + matrix brown_solution () + { + matrix x; + // solution from original documentation. + //x = -11.5844E+00, 13.1999E+00, -0.406200E+00, 0.240998E+00; + x = -11.594439905669450042, 13.203630051593080452, -0.40343948856573402795, 0.23677877338218666914; + return x; + } + + // ---------------------------------------------------------------------------------------- + + } +} + + diff --git a/ml/dlib/dlib/test/optimization_test_functions.h b/ml/dlib/dlib/test/optimization_test_functions.h new file mode 100644 index 000000000..a20523eae --- /dev/null +++ b/ml/dlib/dlib/test/optimization_test_functions.h @@ -0,0 +1,310 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_OPTIMIZATION_TEST_FUNCTiONS_H_h_ +#define DLIB_OPTIMIZATION_TEST_FUNCTiONS_H_h_ + +#include +#include +#include + +/* + + Most of the code in this file is converted from the set of Fortran 90 routines + created by John Burkardt. + + The original Fortran can be found here: http://orion.math.iastate.edu/burkardt/f_src/testopt/testopt.html + +*/ + +// GCC 4.8 gives false alarms about some variables being uninitialized. Disable these +// false warnings. +#if ( defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8) + #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif + + +namespace dlib +{ + namespace test_functions + { + + // ---------------------------------------------------------------------------------------- + + matrix chebyquad_residuals(const matrix& x); + + double chebyquad_residual(int i, const matrix& x); + + int& chebyquad_calls(); + + double chebyquad(const matrix& x ); + + matrix chebyquad_derivative (const matrix& x); + + matrix chebyquad_start (int n); + + matrix chebyquad_solution (int n); + + matrix chebyquad_hessian(const matrix& x); + + // ---------------------------------------------------------------------------------------- + + class chebyquad_function_model + { + public: + + // Define the type used to represent column vectors + typedef matrix column_vector; + // Define the type used to represent the hessian matrix + typedef matrix general_matrix; + + double operator() ( + const column_vector& x + ) const + { + return chebyquad(x); + } + + void get_derivative_and_hessian ( + const column_vector& x, + column_vector& d, + general_matrix& h + ) const + { + d = chebyquad_derivative(x); + h = chebyquad_hessian(x); + } + }; + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + double brown_residual (int i, const matrix& x); + /*! + requires + - 1 <= i <= 20 + ensures + - returns the ith brown residual + !*/ + + double brown ( const matrix& x); + + matrix brown_derivative ( const matrix& x); + + matrix brown_hessian ( const matrix& x); + + matrix brown_start (); + + matrix brown_solution (); + + class brown_function_model + { + public: + + // Define the type used to represent column vectors + typedef matrix column_vector; + // Define the type used to represent the hessian matrix + typedef matrix general_matrix; + + double operator() ( + const column_vector& x + ) const + { + return brown(x); + } + + void get_derivative_and_hessian ( + const column_vector& x, + column_vector& d, + general_matrix& h + ) const + { + d = brown_derivative(x); + h = brown_hessian(x); + } + }; + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + template + matrix rosen_big_start() + { + matrix x; + x = -1.2, -1; + return x; + } + + // This is a variation on the Rosenbrock test function but with large residuals. The + // minimum is at 1, 1 and the objective value is 1. + template + T rosen_big_residual (int i, const matrix& m) + { + using std::pow; + const T x = m(0); + const T y = m(1); + + if (i == 1) + { + return 100*pow(y - x*x,2)+1.0; + } + else + { + return pow(1 - x,2) + 1.0; + } + } + + template + T rosen_big ( const matrix& m) + { + using std::pow; + return 0.5*(pow(rosen_big_residual(1,m),2) + pow(rosen_big_residual(2,m),2)); + } + + template + matrix rosen_big_solution () + { + matrix x; + // solution from original documentation. + x = 1,1; + return x; + } + + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + // ---------------------------------------------------------------------------------------- + + template + matrix rosen_start() + { + matrix x; + x = -1.2, -1; + return x; + } + + template + T rosen ( const matrix& m) + { + const T x = m(0); + const T y = m(1); + + using std::pow; + // compute Rosenbrock's function and return the result + return 100.0*pow(y - x*x,2) + pow(1 - x,2); + } + + template + T rosen_residual (int i, const matrix& m) + { + const T x = m(0); + const T y = m(1); + + + if (i == 1) + { + return 10*(y - x*x); + } + else + { + return 1 - x; + } + } + + template + matrix rosen_residual_derivative (int i, const matrix& m) + { + const T x = m(0); + + matrix d; + + if (i == 1) + { + d = -20*x, 10; + } + else + { + d = -1, 0; + } + return d; + } + + template + const matrix rosen_derivative ( const matrix& m) + { + const T x = m(0); + const T y = m(1); + + // make us a column vector of length 2 + matrix res(2); + + // now compute the gradient vector + res(0) = -400*x*(y-x*x) - 2*(1-x); // derivative of rosen() with respect to x + res(1) = 200*(y-x*x); // derivative of rosen() with respect to y + return res; + } + + template + const matrix rosen_hessian ( const matrix& m) + { + const T x = m(0); + const T y = m(1); + + // make us a column vector of length 2 + matrix res; + + // now compute the gradient vector + res(0,0) = -400*y + 3*400*x*x + 2; + res(1,1) = 200; + + res(0,1) = -400*x; + res(1,0) = -400*x; + return res; + } + + template + matrix rosen_solution () + { + matrix x; + // solution from original documentation. + x = 1,1; + return x; + } + + // ------------------------------------------------------------------------------------ + + template + struct rosen_function_model + { + typedef matrix column_vector; + typedef matrix general_matrix; + + T operator() ( column_vector x) const + { + return static_cast(rosen(x)); + } + + void get_derivative_and_hessian ( + const column_vector& x, + column_vector& d, + general_matrix& h + ) const + { + d = rosen_derivative(x); + h = rosen_hessian(x); + } + + }; + + // ---------------------------------------------------------------------------------------- + + } +} + +#endif // DLIB_OPTIMIZATION_TEST_FUNCTiONS_H_h_ + + + diff --git a/ml/dlib/dlib/test/parallel_for.cpp b/ml/dlib/dlib/test/parallel_for.cpp new file mode 100644 index 000000000..5bee40955 --- /dev/null +++ b/ml/dlib/dlib/test/parallel_for.cpp @@ -0,0 +1,334 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.parallel_for"); + + class assign_element + { + public: + + assign_element( + std::vector& vect_ + ) : vect(vect_){} + + std::vector& vect; + + void go (long i ) + { + DLIB_TEST( 0 <= i && i < (long)vect.size()); + vect[i] = i; + } + + void operator() (long i ) const + { + DLIB_TEST( 0 <= i && i < (long)vect.size()); + vect[i] = i; + } + + }; + + void test_parallel_for(long start) + { + std::vector vect(200,0); + + parallel_for(4, start, vect.size(), assign_element(vect)); + + for (long i = 0; i < start; ++i) + { + DLIB_TEST(vect[i] == 0); + } + for (long i = start; i < (long)vect.size(); ++i) + { + DLIB_TEST(vect[i] == i); + } + } + + void test_parallel_for2(long start) + { + std::vector vect(200,0); + + assign_element temp(vect); + parallel_for(4, start, vect.size(), temp, &assign_element::go); + + for (long i = 0; i < start; ++i) + { + DLIB_TEST(vect[i] == 0); + } + for (long i = start; i < (long)vect.size(); ++i) + { + DLIB_TEST(vect[i] == i); + } + } + + struct parfor_test_helper + { + mutable std::vector test; + + parfor_test_helper() : test(400,100000) + { + } + + void go(long begin, long end) + { + for (long i = begin; i < end; ++i) + test[i] = i; + } + + void operator()(long begin, long end) const + { + for (long i = begin; i < end; ++i) + test[i] = i; + } + + void go2(long i) + { + test[i] = i; + } + + }; + + struct parfor_test_helper2 + { + mutable std::vector test; + + parfor_test_helper2() : test(400,100000) + { + } + + void operator()(long i) const + { + test[i] = i; + } + + }; + + void test_parallel_for_additional() + { + { + parfor_test_helper helper; + parallel_for(4, 0, helper.test.size(), helper, &parfor_test_helper::go2); + + for (unsigned long i = 0; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for(4, 10, helper.test.size(), helper, &parfor_test_helper::go2); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]); + } + for (unsigned long i = 10; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for_blocked(4, 0, helper.test.size(), helper, &parfor_test_helper::go); + + for (unsigned long i = 0; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for_blocked(4, 10, helper.test.size(), helper, &parfor_test_helper::go); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]); + } + for (unsigned long i = 10; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for_blocked(4, 0, helper.test.size(), helper); + + for (unsigned long i = 0; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for_blocked(4, 10, helper.test.size(), helper); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]); + } + for (unsigned long i = 10; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper2 helper; + parallel_for(4, 0, helper.test.size(), helper); + + for (unsigned long i = 0; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper2 helper; + parallel_for(4, 10, helper.test.size(), helper); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]); + } + for (unsigned long i = 10; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + + + + + + + { + parfor_test_helper helper; + parallel_for_verbose(4, 0, helper.test.size(), helper, &parfor_test_helper::go2); + + for (unsigned long i = 0; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for_verbose(4, 10, helper.test.size(), helper, &parfor_test_helper::go2); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]); + } + for (unsigned long i = 10; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for_blocked_verbose(4, 0, helper.test.size(), helper, &parfor_test_helper::go); + + for (unsigned long i = 0; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for_blocked_verbose(4, 10, helper.test.size(), helper, &parfor_test_helper::go); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]); + } + for (unsigned long i = 10; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for_blocked_verbose(4, 0, helper.test.size(), helper); + + for (unsigned long i = 0; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper helper; + parallel_for_blocked_verbose(4, 10, helper.test.size(), helper); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]); + } + for (unsigned long i = 10; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper2 helper; + parallel_for_verbose(4, 0, helper.test.size(), helper); + + for (unsigned long i = 0; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + { + parfor_test_helper2 helper; + parallel_for_verbose(4, 10, helper.test.size(), helper); + + for (unsigned long i = 0; i < 10; ++i) + { + DLIB_CASSERT(helper.test[i] == 100000, helper.test[i]); + } + for (unsigned long i = 10; i < helper.test.size(); ++i) + { + DLIB_CASSERT(helper.test[i] == (long)i, helper.test[i]); + } + } + } + + class test_parallel_for_routines : public tester + { + public: + test_parallel_for_routines ( + ) : + tester ( + "test_parallel_for", // the command line argument name for this test + "Run tests on the parallel_for routines.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + void perform_test ( + ) + { + test_parallel_for(0); + test_parallel_for(30); + test_parallel_for(50); + test_parallel_for2(0); + test_parallel_for2(30); + test_parallel_for2(50); + + test_parallel_for_additional(); + } + }; + + test_parallel_for_routines a; + +} + + + + diff --git a/ml/dlib/dlib/test/parse.cpp b/ml/dlib/dlib/test/parse.cpp new file mode 100644 index 000000000..b0ea13b1b --- /dev/null +++ b/ml/dlib/dlib/test/parse.cpp @@ -0,0 +1,233 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + + logger dlog("test.parse"); + +// ---------------------------------------------------------------------------------------- + + const unsigned long DET = 0; + const unsigned long N = 1; + const unsigned long V = 2; + const unsigned long NP = 3; + const unsigned long VP = 4; + const unsigned long S = 5; + const unsigned long B = 6; + const unsigned long G = 7; + const unsigned long A = 8; + + typedef unsigned long tags; + + template + void user_defined_ruleset ( + const std::vector& words, + const constituent& c, + std::vector >& possible_ids + ) + { + DLIB_TEST(c.begin < c.k && c.k < c.end && c.end <= words.size()); + DLIB_TEST(possible_ids.size() == 0); + + if (c.left_tag == NP && c.right_tag == VP) possible_ids.push_back(make_pair(S,log(0.80))); + else if (c.left_tag == DET && c.right_tag == N) possible_ids.push_back(make_pair(NP,log(0.30))); + else if (c.left_tag == VP && c.right_tag == A) possible_ids.push_back(make_pair(VP,log(0.30))); + else if (c.left_tag == V && c.right_tag == NP) + { + possible_ids.push_back(make_pair(VP,log(0.20))); + possible_ids.push_back(make_pair(B,0.10)); + } + else if (has_glue_term) + { + possible_ids.push_back(make_pair(G, log(0.01))); + } + } + +// ---------------------------------------------------------------------------------------- + + void dotest1() + { + print_spinner(); + dlog << LINFO << "in dotest1()"; + + std::vector words; + std::vector sequence; + for (int i = 0; i < 8; ++i) + { + sequence.push_back(DET); + sequence.push_back(N); + sequence.push_back(V); + sequence.push_back(DET); + sequence.push_back(N); + sequence.push_back(A); + + words.push_back("The"); + words.push_back("flight"); + words.push_back("includes"); + words.push_back("a"); + words.push_back("meal"); + words.push_back("AWORD"); + } + + std::vector > parse_tree; + + find_max_parse_cky(sequence, user_defined_ruleset, parse_tree); + DLIB_TEST(parse_tree.size() != 0); + + + std::vector roots; + find_trees_not_rooted_with_tag(parse_tree, G, roots); + DLIB_TEST(roots.size() == 8); + + for (unsigned long i = 0; i < roots.size(); ++i) + { + dlog << LINFO << parse_tree_to_string(parse_tree, words, roots[i]); + DLIB_TEST(parse_tree_to_string(parse_tree, words, roots[i]) == "[[The flight] [[includes [a meal]] AWORD]]"); + dlog << LINFO << parse_tree_to_string_tagged(parse_tree, words, roots[i]); + DLIB_TEST(parse_tree_to_string_tagged(parse_tree, words, roots[i]) == "[5 [3 The flight] [4 [4 includes [3 a meal]] AWORD]]"); + } + + + words.clear(); + sequence.clear(); + + for (int i = 0; i < 2; ++i) + { + sequence.push_back(DET); + sequence.push_back(N); + sequence.push_back(V); + sequence.push_back(DET); + sequence.push_back(N); + + words.push_back("The"); + words.push_back("flight"); + words.push_back("includes"); + words.push_back("a"); + words.push_back("meal"); + } + + find_max_parse_cky(sequence, user_defined_ruleset, parse_tree); + DLIB_TEST(parse_tree.size() != 0); + + const std::string str1 = "[[[The flight] [includes [a meal]]] [[The flight] [includes [a meal]]]]"; + const std::string str2 = "[7 [5 [3 The flight] [4 includes [3 a meal]]] [5 [3 The flight] [4 includes [3 a meal]]]]"; + dlog << LINFO << parse_tree_to_string(parse_tree, words); + DLIB_TEST(parse_tree_to_string(parse_tree, words) == str1); + dlog << LINFO << parse_tree_to_string_tagged(parse_tree, words); + DLIB_TEST(parse_tree_to_string_tagged(parse_tree, words) == str2); + + const std::string str3 = "[[The flight] [includes [a meal]]] [[The flight] [includes [a meal]]]"; + const std::string str4 = "[5 [3 The flight] [4 includes [3 a meal]]] [5 [3 The flight] [4 includes [3 a meal]]]"; + dlog << LINFO << parse_trees_to_string(parse_tree, words, G); + DLIB_TEST(parse_trees_to_string(parse_tree, words, G) == str3); + dlog << LINFO << parse_trees_to_string_tagged(parse_tree, words, G); + DLIB_TEST(parse_trees_to_string_tagged(parse_tree, words, G) == str4); + + sequence.clear(); + find_max_parse_cky(sequence, user_defined_ruleset, parse_tree); + DLIB_TEST(parse_tree.size() == 0); + } + +// ---------------------------------------------------------------------------------------- + + void dotest2() + { + print_spinner(); + dlog << LINFO << "in dotest2()"; + + std::vector words; + std::vector sequence; + for (int i = 0; i < 8; ++i) + { + sequence.push_back(DET); + sequence.push_back(N); + sequence.push_back(V); + sequence.push_back(DET); + sequence.push_back(N); + + words.push_back("The"); + words.push_back("flight"); + words.push_back("includes"); + words.push_back("a"); + words.push_back("meal"); + } + + std::vector > parse_tree; + + find_max_parse_cky(sequence, user_defined_ruleset, parse_tree); + DLIB_TEST(parse_tree.size() == 0); + + + std::vector roots; + find_trees_not_rooted_with_tag(parse_tree, G, roots); + DLIB_TEST(roots.size() == 0); + + + words.clear(); + sequence.clear(); + + for (int i = 0; i < 2; ++i) + { + sequence.push_back(DET); + sequence.push_back(N); + sequence.push_back(V); + sequence.push_back(DET); + sequence.push_back(N); + + words.push_back("The"); + words.push_back("flight"); + words.push_back("includes"); + words.push_back("a"); + words.push_back("meal"); + } + + find_max_parse_cky(sequence, user_defined_ruleset, parse_tree); + DLIB_TEST(parse_tree.size() == 0); + + sequence.clear(); + find_max_parse_cky(sequence, user_defined_ruleset, parse_tree); + DLIB_TEST(parse_tree.size() == 0); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class parse_tester : public tester + { + public: + parse_tester ( + ) : + tester ("test_parse", + "Runs tests on the parsing tools.") + {} + + + void perform_test ( + ) + { + dotest1(); + dotest2(); + } + } a; + + +} + + + + diff --git a/ml/dlib/dlib/test/pipe.cpp b/ml/dlib/dlib/test/pipe.cpp new file mode 100644 index 000000000..d84dd0d44 --- /dev/null +++ b/ml/dlib/dlib/test/pipe.cpp @@ -0,0 +1,688 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.pipe"); + + namespace pipe_kernel_test_helpers + { + const unsigned long proc1_count = 10000; + dlib::mutex m; + signaler s(m); + unsigned long threads_running = 0; + bool found_error; + + inline void add_running_thread ( + ) + { + auto_mutex M(m); + ++threads_running; + } + + inline void remove_running_thread ( + ) + { + auto_mutex M(m); + --threads_running; + s.broadcast(); + } + + inline void wait_for_threads ( + ) + { + auto_mutex M(m); + while (threads_running > 0) + s.wait(); + } + + template < + typename pipe + > + void threadproc1 ( + void* param + ) + { + add_running_thread(); + pipe& p = *static_cast(param); + try + { + + int last = -1; + for (unsigned long i = 0; i < proc1_count; ++i) + { + int cur=0; + DLIB_TEST(p.dequeue(cur) == true); + DLIB_TEST(last + 1 == cur); + last = cur; + } + DLIB_TEST(p.size() == 0); + } + catch(exception& e) + { + auto_mutex M(m); + found_error = true; + cout << "\n\nERRORS FOUND" << endl; + cout << e.what() << endl; + dlog << LWARN << "ERRORS FOUND"; + dlog << LWARN << e.what(); + p.disable(); + } + + remove_running_thread(); + } + + + template < + typename pipe + > + void threadproc2 ( + void* param + ) + { + add_running_thread(); + pipe& p = *static_cast(param); + try + { + + int last = -1; + int cur; + while (p.dequeue(cur)) + { + DLIB_TEST(last < cur); + last = cur; + } + auto_mutex M(m); + } + catch(exception& e) + { + auto_mutex M(m); + found_error = true; + cout << "\n\nERRORS FOUND" << endl; + cout << e.what() << endl; + dlog << LWARN << "ERRORS FOUND"; + dlog << LWARN << e.what(); + p.disable(); + } + remove_running_thread(); + } + + + + template < + typename pipe + > + void threadproc3 ( + void* param + ) + { + add_running_thread(); + pipe& p = *static_cast(param); + try + { + + int last = -1; + int cur; + while (p.dequeue_or_timeout(cur,100000)) + { + DLIB_TEST(last < cur); + last = cur; + } + auto_mutex M(m); + } + catch(exception& e) + { + auto_mutex M(m); + found_error = true; + cout << "\n\nERRORS FOUND" << endl; + cout << e.what() << endl; + dlog << LWARN << "ERRORS FOUND"; + dlog << LWARN << e.what(); + p.disable(); + } + remove_running_thread(); + } + + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + class PipelineProcessor : private dlib::threaded_object + { + public: + PipelineProcessor( + dlib::pipe & in, + dlib::pipe & out) : + InPipe(in), + OutPipe(out), + InMsg(), + OutMsg() { + start(); + } + + ~PipelineProcessor() { + // signal the thread to stop + stop(); + wait(); + } + + private: + dlib::pipe & InPipe; + dlib::pipe & OutPipe; + + in_type InMsg; + out_type OutMsg; + + void thread() + { + while (!should_stop()) { + if(InPipe.dequeue_or_timeout(InMsg, 100)) + { + // if function signals ready to send OutMsg + while (!OutPipe.enqueue_or_timeout(OutMsg, 100)) + { + // try to send until should stop + if (should_stop()) + { + return; + } + } + } + } + }; + }; + + + void do_zero_size_test_with_timeouts() + { + dlog << LINFO << "in do_zero_size_test_with_timeouts()"; + // make sure we can get though this without deadlocking + for (int k = 0; k < 10; ++k) + { + dlib::pipe in_pipe(10); + dlib::pipe out_pipe(0); + { + PipelineProcessor pp(in_pipe, out_pipe); + + int in = 1; + in_pipe.enqueue(in); + in = 2; + in_pipe.enqueue(in); + in = 3; + in_pipe.enqueue(in); + // sleep to make sure thread enqueued + dlib::sleep(100); + + float out = 1.0f; + out_pipe.dequeue(out); + dlib::sleep(100); + } + print_spinner(); + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename pipe + > + void pipe_kernel_test ( + ) + /*! + requires + - pipe is an implementation of pipe/pipe_kernel_abstract.h and + is instantiated with int + ensures + - runs tests on pipe for compliance with the specs + !*/ + { + using namespace pipe_kernel_test_helpers; + found_error = false; + + + print_spinner(); + pipe test(10), test2(100); + pipe test_0(0), test2_0(0); + pipe test_1(1), test2_1(1); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test_0.size() == 0); + DLIB_TEST(test2_0.size() == 0); + DLIB_TEST(test_1.size() == 0); + DLIB_TEST(test2_1.size() == 0); + + DLIB_TEST(test.is_enqueue_enabled() == true); + DLIB_TEST(test.is_dequeue_enabled() == true); + DLIB_TEST(test.is_enabled() == true); + + test.empty(); + test2.empty(); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test2.size() == 0); + + test_0.empty(); + test2_0.empty(); + DLIB_TEST(test_0.size() == 0); + DLIB_TEST(test2_0.size() == 0); + + test_1.empty(); + test2_1.empty(); + DLIB_TEST(test_1.size() == 0); + DLIB_TEST(test2_1.size() == 0); + + + + int a; + a = 3; + test.enqueue(a); + DLIB_TEST(test.size() == 1); + a = 5; + test.enqueue(a); + DLIB_TEST(test.size() == 2); + + a = 0; + test.dequeue(a); + DLIB_TEST(a == 3); + DLIB_TEST(test.size() == 1); + + a = 0; + test.dequeue(a); + DLIB_TEST(a == 5); + DLIB_TEST(test.size() == 0); + + + print_spinner(); + { + dlog << LINFO << "starting normal length pipe tests"; + create_new_thread(&threadproc1,&test); + create_new_thread(&threadproc2,&test2); + create_new_thread(&threadproc2,&test2); + create_new_thread(&threadproc2,&test2); + + for (unsigned long i = 0; i < proc1_count; ++i) + { + a = i; + test.enqueue(a); + } + DLIB_TEST(test.is_enqueue_enabled() == true); + test.disable_enqueue(); + DLIB_TEST(test.is_enqueue_enabled() == false); + for (unsigned long i = 0; i < proc1_count; ++i) + { + a = i; + test.enqueue(a); + } + + for (unsigned long i = 0; i < 100000; ++i) + { + a = i; + if (i%2 == 0) + test2.enqueue(a); + else + test2.enqueue_or_timeout(a,100000); + } + + test2.wait_for_num_blocked_dequeues(3); + DLIB_TEST(test2.size() == 0); + test2.disable(); + + wait_for_threads(); + DLIB_TEST(test2.size() == 0); + + test2.enable(); + + print_spinner(); + + create_new_thread(&threadproc3,&test2); + create_new_thread(&threadproc3,&test2); + + + for (unsigned long i = 0; i < 100000; ++i) + { + a = i; + if (i%2 == 0) + test2.enqueue(a); + else + test2.enqueue_or_timeout(a,100000); + } + + test2.wait_for_num_blocked_dequeues(2); + DLIB_TEST(test2.size() == 0); + test2.disable(); + + wait_for_threads(); + DLIB_TEST(test2.size() == 0); + + } + + + print_spinner(); + { + dlog << LINFO << "starting 0 length pipe tests"; + create_new_thread(&threadproc1,&test_0); + create_new_thread(&threadproc2,&test2_0); + create_new_thread(&threadproc2,&test2_0); + create_new_thread(&threadproc2,&test2_0); + dlog << LTRACE << "0: 1"; + + for (unsigned long i = 0; i < proc1_count; ++i) + { + a = i; + test_0.enqueue(a); + } + + dlog << LTRACE << "0: 2"; + DLIB_TEST(test_0.is_enqueue_enabled() == true); + test_0.disable_enqueue(); + DLIB_TEST(test_0.is_enqueue_enabled() == false); + for (unsigned long i = 0; i < proc1_count; ++i) + { + a = i; + test_0.enqueue(a); + } + + dlog << LTRACE << "0: 3"; + for (unsigned long i = 0; i < 100000; ++i) + { + a = i; + if (i%2 == 0) + test2_0.enqueue(a); + else + test2_0.enqueue_or_timeout(a,100000); + } + + print_spinner(); + dlog << LTRACE << "0: 4"; + test2_0.wait_for_num_blocked_dequeues(3); + DLIB_TEST(test2_0.size() == 0); + test2_0.disable(); + + wait_for_threads(); + DLIB_TEST(test2_0.size() == 0); + + dlog << LTRACE << "0: 5"; + test2_0.enable(); + + + create_new_thread(&threadproc3,&test2_0); + create_new_thread(&threadproc3,&test2_0); + + + for (unsigned long i = 0; i < 20000; ++i) + { + if ((i%100) == 0) + print_spinner(); + + a = i; + if (i%2 == 0) + test2_0.enqueue(a); + else + test2_0.enqueue_or_timeout(a,100000); + } + + dlog << LTRACE << "0: 6"; + test2_0.wait_for_num_blocked_dequeues(2); + DLIB_TEST(test2_0.size() == 0); + test2_0.disable(); + + wait_for_threads(); + DLIB_TEST(test2_0.size() == 0); + + dlog << LTRACE << "0: 7"; + } + + print_spinner(); + { + dlog << LINFO << "starting 1 length pipe tests"; + create_new_thread(&threadproc1,&test_1); + create_new_thread(&threadproc2,&test2_1); + create_new_thread(&threadproc2,&test2_1); + create_new_thread(&threadproc2,&test2_1); + + for (unsigned long i = 0; i < proc1_count; ++i) + { + a = i; + test_1.enqueue(a); + } + DLIB_TEST(test_1.is_enqueue_enabled() == true); + test_1.disable_enqueue(); + DLIB_TEST(test_1.is_enqueue_enabled() == false); + for (unsigned long i = 0; i < proc1_count; ++i) + { + a = i; + test_1.enqueue(a); + } + print_spinner(); + + for (unsigned long i = 0; i < 100000; ++i) + { + a = i; + if (i%2 == 0) + test2_1.enqueue(a); + else + test2_1.enqueue_or_timeout(a,100000); + } + + test2_1.wait_for_num_blocked_dequeues(3); + DLIB_TEST(test2_1.size() == 0); + test2_1.disable(); + + wait_for_threads(); + DLIB_TEST(test2_1.size() == 0); + + test2_1.enable(); + + + create_new_thread(&threadproc3,&test2_1); + create_new_thread(&threadproc3,&test2_1); + + + for (unsigned long i = 0; i < 100000; ++i) + { + a = i; + if (i%2 == 0) + test2_1.enqueue(a); + else + test2_1.enqueue_or_timeout(a,100000); + } + + test2_1.wait_for_num_blocked_dequeues(2); + DLIB_TEST(test2_1.size() == 0); + test2_1.disable(); + + wait_for_threads(); + DLIB_TEST(test2_1.size() == 0); + + } + + test.enable_enqueue(); + test_0.enable_enqueue(); + test_1.enable_enqueue(); + + DLIB_TEST(test.is_enabled()); + DLIB_TEST(test.is_enqueue_enabled()); + DLIB_TEST(test_0.is_enabled()); + DLIB_TEST(test_0.is_enqueue_enabled()); + DLIB_TEST(test_1.is_enabled()); + DLIB_TEST(test_1.is_enqueue_enabled()); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test_0.size() == 0); + DLIB_TEST(test_1.size() == 0); + DLIB_TEST(test.max_size() == 10); + DLIB_TEST(test_0.max_size() == 0); + DLIB_TEST(test_1.max_size() == 1); + + + for (int i = 0; i < 100; ++i) + { + a = 1; + test.enqueue_or_timeout(a,0); + a = 1; + test_0.enqueue_or_timeout(a,0); + a = 1; + test_1.enqueue_or_timeout(a,0); + } + + DLIB_TEST_MSG(test.size() == 10,"size: " << test.size() ); + DLIB_TEST_MSG(test_0.size() == 0,"size: " << test.size() ); + DLIB_TEST_MSG(test_1.size() == 1,"size: " << test.size() ); + + for (int i = 0; i < 10; ++i) + { + a = 0; + DLIB_TEST(test.enqueue_or_timeout(a,10) == false); + a = 0; + DLIB_TEST(test_0.enqueue_or_timeout(a,10) == false); + a = 0; + DLIB_TEST(test_1.enqueue_or_timeout(a,10) == false); + } + + DLIB_TEST_MSG(test.size() == 10,"size: " << test.size() ); + DLIB_TEST_MSG(test_0.size() == 0,"size: " << test.size() ); + DLIB_TEST_MSG(test_1.size() == 1,"size: " << test.size() ); + + for (int i = 0; i < 10; ++i) + { + a = 0; + DLIB_TEST(test.dequeue_or_timeout(a,0) == true); + DLIB_TEST(a == 1); + } + + DLIB_TEST(test.max_size() == 10); + DLIB_TEST(test_0.max_size() == 0); + DLIB_TEST(test_1.max_size() == 1); + + a = 0; + DLIB_TEST(test_1.dequeue_or_timeout(a,0) == true); + + DLIB_TEST(test.max_size() == 10); + DLIB_TEST(test_0.max_size() == 0); + DLIB_TEST(test_1.max_size() == 1); + + + DLIB_TEST_MSG(a == 1,"a: " << a); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test_0.size() == 0); + DLIB_TEST(test_1.size() == 0); + + DLIB_TEST(test.dequeue_or_timeout(a,0) == false); + DLIB_TEST(test_0.dequeue_or_timeout(a,0) == false); + DLIB_TEST(test_1.dequeue_or_timeout(a,0) == false); + DLIB_TEST(test.dequeue_or_timeout(a,10) == false); + DLIB_TEST(test_0.dequeue_or_timeout(a,10) == false); + DLIB_TEST(test_1.dequeue_or_timeout(a,10) == false); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test_0.size() == 0); + DLIB_TEST(test_1.size() == 0); + + DLIB_TEST(found_error == false); + + + + + { + test.enable(); + test.enable_enqueue(); + test.empty(); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_enabled() == true); + DLIB_TEST(test.is_enqueue_enabled() == true); + DLIB_TEST(test.is_dequeue_enabled() == true); + test.disable_dequeue(); + dlog << LINFO << "Make sure disable_dequeue() works right..."; + DLIB_TEST(test.is_dequeue_enabled() == false); + DLIB_TEST(test.dequeue(a) == false); + test.wait_until_empty(); + a = 4; + test.enqueue(a); + test.wait_until_empty(); + test.wait_for_num_blocked_dequeues(4); + DLIB_TEST(test.size() == 1); + DLIB_TEST(test.dequeue(a) == false); + DLIB_TEST(test.dequeue_or_timeout(a,10000) == false); + DLIB_TEST(test.size() == 1); + a = 0; + test.enable_dequeue(); + DLIB_TEST(test.is_dequeue_enabled() == true); + DLIB_TEST(test.dequeue(a) == true); + DLIB_TEST(a == 4); + test_1.wait_until_empty(); + } + { + test_1.enable(); + test_1.enable_enqueue(); + test_1.empty(); + DLIB_TEST(test_1.size() == 0); + DLIB_TEST(test_1.is_enabled() == true); + DLIB_TEST(test_1.is_enqueue_enabled() == true); + DLIB_TEST(test_1.is_dequeue_enabled() == true); + test_1.disable_dequeue(); + dlog << LINFO << "Make sure disable_dequeue() works right..."; + DLIB_TEST(test_1.is_dequeue_enabled() == false); + DLIB_TEST(test_1.dequeue(a) == false); + a = 4; + test_1.wait_for_num_blocked_dequeues(4); + test_1.wait_for_num_blocked_dequeues(0); + test_1.enqueue(a); + test_1.wait_until_empty(); + DLIB_TEST(test_1.size() == 1); + DLIB_TEST(test_1.dequeue(a) == false); + DLIB_TEST(test_1.dequeue_or_timeout(a,10000) == false); + DLIB_TEST(test_1.size() == 1); + a = 0; + test_1.enable_dequeue(); + DLIB_TEST(test_1.is_dequeue_enabled() == true); + DLIB_TEST(test_1.dequeue(a) == true); + DLIB_TEST(a == 4); + test_1.wait_until_empty(); + } + + } + + + + + class pipe_tester : public tester + { + public: + pipe_tester ( + ) : + tester ("test_pipe", + "Runs tests on the pipe component.") + {} + + void perform_test ( + ) + { + pipe_kernel_test >(); + + do_zero_size_test_with_timeouts(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/pixel.cpp b/ml/dlib/dlib/test/pixel.cpp new file mode 100644 index 000000000..40772b130 --- /dev/null +++ b/ml/dlib/dlib/test/pixel.cpp @@ -0,0 +1,777 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.pixel"); + + + void pixel_test ( + ) + /*! + ensures + - runs tests on pixel objects and functions for compliance with the specs + !*/ + { + + print_spinner(); + + unsigned char p_gray; + unsigned short p_gray16; + long p_int; + float p_float; + signed char p_schar; + rgb_pixel p_rgb,p_rgb2; + hsi_pixel p_hsi, p_hsi2; + rgb_alpha_pixel p_rgba; + lab_pixel p_lab, p_lab2; + + assign_pixel(p_int, 0.0f); + assign_pixel(p_float, 0.0f); + assign_pixel(p_schar, 0); + + assign_pixel(p_gray, -2); + assign_pixel(p_rgb,0); + assign_pixel(p_hsi, -4); + assign_pixel(p_rgba, p_int); + assign_pixel(p_gray16,0); + assign_pixel(p_lab,-400); + + DLIB_TEST(p_int == 0); + DLIB_TEST(p_float == 0); + DLIB_TEST(p_schar == 0); + + DLIB_TEST(p_gray == 0); + DLIB_TEST(p_gray16 == 0); + + DLIB_TEST(p_rgb.red == 0); + DLIB_TEST(p_rgb.green == 0); + DLIB_TEST(p_rgb.blue == 0); + + DLIB_TEST(p_rgba.red == 0); + DLIB_TEST(p_rgba.green == 0); + DLIB_TEST(p_rgba.blue == 0); + DLIB_TEST(p_rgba.alpha == 255); + + DLIB_TEST(p_hsi.h == 0); + DLIB_TEST(p_hsi.s == 0); + DLIB_TEST(p_hsi.i == 0); + + DLIB_TEST(p_lab.l == 0); + DLIB_TEST(p_lab.a == 128); + DLIB_TEST(p_lab.b == 128); + + assign_pixel(p_gray,10); + assign_pixel(p_gray16,10); + assign_pixel(p_rgb,10); + assign_pixel(p_hsi,10); + assign_pixel(p_rgba,10); + assign_pixel(p_lab,10); + + assign_pixel(p_int, -10); + assign_pixel(p_float, -10); + assign_pixel(p_schar, -10); + + DLIB_TEST(p_int == -10); + DLIB_TEST(p_float == -10); + DLIB_TEST(p_schar == -10); + + DLIB_TEST(p_gray == 10); + DLIB_TEST(p_gray16 == 10); + + DLIB_TEST(p_rgb.red == 10); + DLIB_TEST(p_rgb.green == 10); + DLIB_TEST(p_rgb.blue == 10); + + DLIB_TEST(p_rgba.red == 10); + DLIB_TEST(p_rgba.green == 10); + DLIB_TEST(p_rgba.blue == 10); + DLIB_TEST(p_rgba.alpha == 255); + + DLIB_TEST(p_hsi.h == 0); + DLIB_TEST(p_hsi.s == 0); + DLIB_TEST(p_hsi.i == 10); + + DLIB_TEST(p_lab.l == 10); + DLIB_TEST(p_lab.a == 128); + DLIB_TEST(p_lab.b == 128); + + assign_pixel(p_gray16,12345); + DLIB_TEST(p_gray16 == 12345); + + assign_pixel(p_float,3.141); + DLIB_TEST(p_float == 3.141f); + + p_rgb.red = 255; + p_rgb.green = 100; + p_rgb.blue = 50; + + p_rgba.alpha = 4; + assign_pixel(p_gray,p_rgb); + assign_pixel(p_rgb,p_rgb); + assign_pixel(p_rgba,p_rgb); + assign_pixel(p_hsi,p_rgb); + assign_pixel(p_lab,p_rgb); + + assign_pixel(p_float,p_rgb); + assign_pixel(p_int,p_rgb); + assign_pixel(p_schar,p_rgb); + + DLIB_TEST(p_schar == std::numeric_limits::max()); + + DLIB_TEST(p_int == (255+100+50)/3); + DLIB_TEST_MSG(p_float == (255+100+50)/3, p_float - (255+100+50)/3); + DLIB_TEST(p_gray == (255+100+50)/3); + + DLIB_TEST(p_rgb.red == 255); + DLIB_TEST(p_rgb.green == 100); + DLIB_TEST(p_rgb.blue == 50); + + DLIB_TEST(p_rgba.red == 255); + DLIB_TEST(p_rgba.green == 100); + DLIB_TEST(p_rgba.blue == 50); + DLIB_TEST(p_rgba.alpha == 255); + + DLIB_TEST(p_hsi.i > 0); + DLIB_TEST(p_hsi.s > 0); + DLIB_TEST(p_hsi.h > 0); + + DLIB_TEST(p_lab.l > 0); + DLIB_TEST(p_lab.a > 0); + DLIB_TEST(p_lab.b > 0); + + assign_pixel(p_rgb,0); + DLIB_TEST(p_rgb.red == 0); + DLIB_TEST(p_rgb.green == 0); + DLIB_TEST(p_rgb.blue == 0); + assign_pixel(p_rgb, p_hsi); + + DLIB_TEST_MSG(p_rgb.red > 251 ,(int)p_rgb.green); + DLIB_TEST_MSG(p_rgb.green > 96 && p_rgb.green < 104,(int)p_rgb.green); + DLIB_TEST_MSG(p_rgb.blue > 47 && p_rgb.blue < 53,(int)p_rgb.green); + + assign_pixel(p_rgb,0); + DLIB_TEST(p_rgb.red == 0); + DLIB_TEST(p_rgb.green == 0); + DLIB_TEST(p_rgb.blue == 0); + + assign_pixel(p_rgb, p_lab); + DLIB_TEST_MSG(p_rgb.red > 251 ,(int)p_rgb.green); + DLIB_TEST_MSG(p_rgb.green > 96 && p_rgb.green < 104,(int)p_rgb.green); + DLIB_TEST_MSG(p_rgb.blue > 47 && p_rgb.blue < 53,(int)p_rgb.green); + + assign_pixel(p_hsi2, p_hsi); + DLIB_TEST(p_hsi.h == p_hsi2.h); + DLIB_TEST(p_hsi.s == p_hsi2.s); + DLIB_TEST(p_hsi.i == p_hsi2.i); + assign_pixel(p_hsi,0); + DLIB_TEST(p_hsi.h == 0); + DLIB_TEST(p_hsi.s == 0); + DLIB_TEST(p_hsi.i == 0); + assign_pixel(p_hsi, p_rgba); + + DLIB_TEST(p_hsi.h == p_hsi2.h); + DLIB_TEST(p_hsi.s == p_hsi2.s); + DLIB_TEST(p_hsi.i == p_hsi2.i); + + assign_pixel(p_lab2, p_lab); + DLIB_TEST(p_lab.l == p_lab2.l); + DLIB_TEST(p_lab.a == p_lab2.a); + DLIB_TEST(p_lab.b == p_lab2.b); + assign_pixel(p_lab,0); + DLIB_TEST(p_lab.l == 0); + DLIB_TEST(p_lab.a == 128); + DLIB_TEST(p_lab.b == 128); + assign_pixel(p_lab, p_rgba); + + DLIB_TEST(p_lab.l == p_lab2.l); + DLIB_TEST(p_lab.a == p_lab2.a); + DLIB_TEST(p_lab.b == p_lab2.b); + + assign_pixel(p_rgba, 100); + assign_pixel(p_gray, 10); + assign_pixel(p_rgb, 10); + assign_pixel(p_hsi, 10); + + assign_pixel(p_schar, 10); + assign_pixel(p_float, 10); + assign_pixel(p_int, 10); + + p_rgba.alpha = 0; + assign_pixel(p_gray, p_rgba); + DLIB_TEST(p_gray == 10); + assign_pixel(p_schar, p_rgba); + DLIB_TEST(p_schar == 10); + assign_pixel(p_int, p_rgba); + DLIB_TEST(p_int == 10); + assign_pixel(p_float, p_rgba); + DLIB_TEST(p_float == 10); + assign_pixel(p_rgb, p_rgba); + DLIB_TEST(p_rgb.red == 10); + DLIB_TEST(p_rgb.green == 10); + DLIB_TEST(p_rgb.blue == 10); + + assign_pixel(p_hsi, p_rgba); + assign_pixel(p_hsi2, p_rgb); + DLIB_TEST(p_hsi.h == 0); + DLIB_TEST(p_hsi.s == 0); + DLIB_TEST_MSG(p_hsi.i < p_hsi2.i+2 && p_hsi.i > p_hsi2.i -2,(int)p_hsi.i << " " << (int)p_hsi2.i); + + // this value corresponds to RGB(10,10,10) + p_lab.l = 7; + p_lab.a = 128; + p_lab.b = 128; + + assign_pixel(p_lab, p_rgba); + assign_pixel(p_lab2, p_rgb); + DLIB_TEST(p_lab.a == 128); + DLIB_TEST(p_lab.b == 128); + DLIB_TEST_MSG(p_lab.l < p_lab2.l+2 && p_lab.l > p_lab2.l -2,(int)p_lab.l << " " << (int)p_lab2.l); + + assign_pixel(p_lab, 128); + DLIB_TEST(p_lab.l == 128); + DLIB_TEST(p_lab.a == 128); + DLIB_TEST(p_lab.b == 128); + assign_pixel(p_rgb, p_lab); + //Lab midpoint (50,0,0) is not same as RGB midpoint (127,127,127) + DLIB_TEST(p_rgb.red == 119); + DLIB_TEST(p_rgb.green == 119); + DLIB_TEST(p_rgb.blue == 119); + + //Lab limit values test + //red, green, blue, yellow, black, white + p_lab.l = 84; + p_lab.a = 164; + p_lab.b = 56; + assign_pixel(p_rgb, p_lab); + DLIB_TEST(p_rgb.red == 0); + DLIB_TEST(p_rgb.green == 64); + DLIB_TEST(p_rgb.blue == 194); + + p_lab.l = 255; + p_lab.a = 0; + p_lab.b = 0; + assign_pixel(p_rgb, p_lab); + DLIB_TEST(p_rgb.red == 0); + DLIB_TEST(p_rgb.green == 255); + DLIB_TEST(p_rgb.blue == 255); + + p_lab.l = 0; + p_lab.a = 255; + p_lab.b = 0; + + assign_pixel(p_rgb, p_lab); + DLIB_TEST(p_rgb.red == 0); + DLIB_TEST(p_rgb.green == 0); + DLIB_TEST(p_rgb.blue == 195); + + p_lab.l = 0; + p_lab.a = 0; + p_lab.b = 255; + + assign_pixel(p_rgb, p_lab); + DLIB_TEST(p_rgb.red == 0); + DLIB_TEST(p_rgb.green == 45); + DLIB_TEST(p_rgb.blue == 0); + + p_lab.l = 255; + p_lab.a = 255; + p_lab.b = 0; + assign_pixel(p_rgb, p_lab); + DLIB_TEST(p_rgb.red == 255); + DLIB_TEST(p_rgb.green == 139); + DLIB_TEST(p_rgb.blue == 255); + + p_lab.l = 0; + p_lab.a = 255; + p_lab.b = 255; + assign_pixel(p_rgb, p_lab); + DLIB_TEST(p_rgb.red == 132); + DLIB_TEST(p_rgb.green == 0); + DLIB_TEST(p_rgb.blue == 0); + + p_lab.l = 255; + p_lab.a = 0; + p_lab.b = 255; + assign_pixel(p_rgb, p_lab); + DLIB_TEST(p_rgb.red == 0); + DLIB_TEST(p_rgb.green == 255); + DLIB_TEST(p_rgb.blue == 0); + + p_lab.l = 255; + p_lab.a = 255; + p_lab.b = 255; + assign_pixel(p_rgb, p_lab); + DLIB_TEST(p_rgb.red == 255); + DLIB_TEST(p_rgb.green == 70); + DLIB_TEST(p_rgb.blue == 0); + + //RGB limit tests + p_rgb.red = 0; + p_rgb.green = 0; + p_rgb.blue = 0; + assign_pixel(p_lab, p_rgb); + assign_pixel(p_rgb2, p_lab); + DLIB_TEST(p_rgb2.red < 3); + DLIB_TEST(p_rgb2.green < 3); + DLIB_TEST(p_rgb2.blue < 3); + + p_rgb.red = 255; + p_rgb.green = 0; + p_rgb.blue = 0; + assign_pixel(p_lab, p_rgb); + assign_pixel(p_rgb2, p_lab); + DLIB_TEST(p_rgb2.red > 252); + DLIB_TEST(p_rgb2.green < 3); + DLIB_TEST(p_rgb2.blue < 3); + + p_rgb.red = 0; + p_rgb.green = 255; + p_rgb.blue = 0; + assign_pixel(p_lab, p_rgb); + assign_pixel(p_rgb2, p_lab); + DLIB_TEST(p_rgb2.red < 8); + DLIB_TEST(p_rgb2.green > 252); + DLIB_TEST(p_rgb2.blue < 5); + + p_rgb.red = 0; + p_rgb.green = 0; + p_rgb.blue = 255; + assign_pixel(p_lab, p_rgb); + assign_pixel(p_rgb2, p_lab); + DLIB_TEST(p_rgb2.red < 3); + DLIB_TEST(p_rgb2.green < 3); + DLIB_TEST(p_rgb2.blue > 252); + + p_rgb.red = 255; + p_rgb.green = 255; + p_rgb.blue = 0; + assign_pixel(p_lab, p_rgb); + assign_pixel(p_rgb2, p_lab); + DLIB_TEST(p_rgb2.red > 252); + DLIB_TEST(p_rgb2.green > 252); + DLIB_TEST(p_rgb2.blue < 9); + + p_rgb.red = 0; + p_rgb.green = 255; + p_rgb.blue = 255; + assign_pixel(p_lab, p_rgb); + assign_pixel(p_rgb2, p_lab); + DLIB_TEST(p_rgb2.red < 5); + DLIB_TEST(p_rgb2.green > 252); + DLIB_TEST(p_rgb2.blue > 252); + + p_rgb.red = 255; + p_rgb.green = 0; + p_rgb.blue = 255; + assign_pixel(p_lab, p_rgb); + assign_pixel(p_rgb2, p_lab); + DLIB_TEST(p_rgb2.red> 252); + DLIB_TEST(p_rgb2.green < 6); + DLIB_TEST(p_rgb2.blue > 252); + + p_rgb.red = 255; + p_rgb.green = 255; + p_rgb.blue = 255; + assign_pixel(p_lab, p_rgb); + assign_pixel(p_rgb2, p_lab); + DLIB_TEST(p_rgb2.red > 252 ); + DLIB_TEST(p_rgb2.green> 252); + DLIB_TEST(p_rgb2.blue > 252); + + + assign_pixel(p_rgba, 100); + assign_pixel(p_gray, 10); + assign_pixel(p_schar, 10); + assign_pixel(p_float, 10); + assign_pixel(p_int, 10); + + assign_pixel(p_rgb, 10); + p_rgba.alpha = 128; + assign_pixel(p_gray, p_rgba); + assign_pixel(p_schar, p_rgba); + assign_pixel(p_float, p_rgba); + assign_pixel(p_int, p_rgba); + assign_pixel(p_rgb, p_rgba); + DLIB_TEST(p_gray == (100 + 10)/2); + DLIB_TEST(p_schar == (100 + 10)/2); + DLIB_TEST(p_int == (100 + 10)/2); + DLIB_TEST(p_float == (100 + 10)/2); + DLIB_TEST(p_rgb.red == (100 + 10)/2); + DLIB_TEST(p_rgb.green == (100 + 10)/2); + DLIB_TEST(p_rgb.blue == (100 + 10)/2); + + assign_pixel(p_rgba, 100); + assign_pixel(p_gray, 10); + assign_pixel(p_schar, 10); + assign_pixel(p_int, 10); + assign_pixel(p_float, 10); + assign_pixel(p_rgb, 10); + DLIB_TEST(p_rgba.alpha == 255); + assign_pixel(p_gray, p_rgba); + assign_pixel(p_schar, p_rgba); + assign_pixel(p_int, p_rgba); + assign_pixel(p_float, p_rgba); + assign_pixel(p_rgb, p_rgba); + DLIB_TEST(p_gray == 100); + DLIB_TEST(p_schar == 100); + DLIB_TEST(p_int == 100); + DLIB_TEST(p_float == 100); + DLIB_TEST(p_rgb.red == 100); + DLIB_TEST(p_rgb.green == 100); + DLIB_TEST(p_rgb.blue == 100); + + + p_rgb.red = 1; + p_rgb.green = 2; + p_rgb.blue = 3; + + p_rgba.red = 4; + p_rgba.green = 5; + p_rgba.blue = 6; + p_rgba.alpha = 7; + + p_gray = 8; + p_schar = 8; + p_int = 8; + p_float = 8; + + p_hsi.h = 9; + p_hsi.s = 10; + p_hsi.i = 11; + + p_lab.l = 10; + p_lab.a = 9; + p_lab.b = 8; + + ostringstream sout; + serialize(p_rgb,sout); + serialize(p_rgba,sout); + serialize(p_gray,sout); + serialize(p_schar,sout); + serialize(p_int,sout); + serialize(p_float,sout); + serialize(p_hsi,sout); + serialize(p_lab,sout); + + assign_pixel(p_rgb,0); + assign_pixel(p_rgba,0); + assign_pixel(p_gray,0); + assign_pixel(p_schar,0); + assign_pixel(p_int,0); + assign_pixel(p_float,0); + assign_pixel(p_hsi,0); + assign_pixel(p_lab,0); + + istringstream sin(sout.str()); + + deserialize(p_rgb,sin); + deserialize(p_rgba,sin); + deserialize(p_gray,sin); + deserialize(p_schar,sin); + deserialize(p_int,sin); + deserialize(p_float,sin); + deserialize(p_hsi,sin); + deserialize(p_lab,sin); + + DLIB_TEST(p_rgb.red == 1); + DLIB_TEST(p_rgb.green == 2); + DLIB_TEST(p_rgb.blue == 3); + + DLIB_TEST(p_rgba.red == 4); + DLIB_TEST(p_rgba.green == 5); + DLIB_TEST(p_rgba.blue == 6); + DLIB_TEST(p_rgba.alpha == 7); + + DLIB_TEST(p_gray == 8); + DLIB_TEST(p_schar == 8); + DLIB_TEST(p_int == 8); + DLIB_TEST(p_float == 8); + + DLIB_TEST(p_hsi.h == 9); + DLIB_TEST(p_hsi.s == 10); + DLIB_TEST(p_hsi.i == 11); + + DLIB_TEST(p_lab.l == 10); + DLIB_TEST(p_lab.a == 9); + DLIB_TEST(p_lab.b == 8); + + { + matrix m_gray, m_schar, m_int, m_float; + matrix m_rgb, m_hsi, m_lab; + + m_gray = pixel_to_vector(p_gray); + m_schar = pixel_to_vector(p_schar); + m_int = pixel_to_vector(p_int); + m_float = pixel_to_vector(p_float); + + m_hsi = pixel_to_vector(p_hsi); + m_rgb = pixel_to_vector(p_rgb); + m_lab = pixel_to_vector(p_lab); + + DLIB_TEST(m_gray(0) == p_gray); + DLIB_TEST(m_float(0) == p_float); + DLIB_TEST(m_int(0) == p_int); + DLIB_TEST(m_schar(0) == p_schar); + + DLIB_TEST(m_rgb(0) == p_rgb.red); + DLIB_TEST(m_rgb(1) == p_rgb.green); + DLIB_TEST(m_rgb(2) == p_rgb.blue); + DLIB_TEST(m_hsi(0) == p_hsi.h); + DLIB_TEST(m_hsi(1) == p_hsi.s); + DLIB_TEST(m_hsi(2) == p_hsi.i); + DLIB_TEST(m_lab(0) == p_lab.l); + DLIB_TEST(m_lab(1) == p_lab.a); + DLIB_TEST(m_lab(2) == p_lab.b); + + DLIB_TEST(p_rgb.red == 1); + DLIB_TEST(p_rgb.green == 2); + DLIB_TEST(p_rgb.blue == 3); + + DLIB_TEST(p_rgba.red == 4); + DLIB_TEST(p_rgba.green == 5); + DLIB_TEST(p_rgba.blue == 6); + DLIB_TEST(p_rgba.alpha == 7); + + DLIB_TEST(p_gray == 8); + DLIB_TEST(p_int == 8); + DLIB_TEST(p_float == 8); + DLIB_TEST(p_schar == 8); + + DLIB_TEST(p_hsi.h == 9); + DLIB_TEST(p_hsi.s == 10); + DLIB_TEST(p_hsi.i == 11); + + DLIB_TEST(p_lab.l == 10); + DLIB_TEST(p_lab.a == 9); + DLIB_TEST(p_lab.b == 8); + + assign_pixel(p_gray,0); + assign_pixel(p_hsi,0); + assign_pixel(p_rgb,0); + assign_pixel(p_lab,0); + + vector_to_pixel(p_gray, m_gray); + vector_to_pixel(p_hsi, m_hsi); + vector_to_pixel(p_rgb, m_rgb); + vector_to_pixel(p_lab, m_lab); + + DLIB_TEST(p_rgb.red == 1); + DLIB_TEST(p_rgb.green == 2); + DLIB_TEST(p_rgb.blue == 3); + + DLIB_TEST(p_rgba.red == 4); + DLIB_TEST(p_rgba.green == 5); + DLIB_TEST(p_rgba.blue == 6); + DLIB_TEST(p_rgba.alpha == 7); + + DLIB_TEST(p_gray == 8); + + DLIB_TEST(p_hsi.h == 9); + DLIB_TEST(p_hsi.s == 10); + DLIB_TEST(p_hsi.i == 11); + + DLIB_TEST(p_lab.l == 10); + DLIB_TEST(p_lab.a == 9); + DLIB_TEST(p_lab.b == 8); + } + + + + + { + unsigned char p_gray; + unsigned short p_gray16; + long p_int; + float p_float; + signed char p_schar; + rgb_pixel p_rgb; + hsi_pixel p_hsi, p_hsi2; + rgb_alpha_pixel p_rgba; + lab_pixel p_lab; + + + assign_pixel(p_gray, 0); + assign_pixel(p_gray16, 0); + assign_pixel(p_int, 0); + assign_pixel(p_float, 0); + assign_pixel(p_schar, 0); + assign_pixel(p_rgb, 0); + assign_pixel(p_hsi, 0); + assign_pixel(p_lab, 0); + + + assign_pixel(p_gray, 100); + assign_pixel(p_schar, p_gray); + DLIB_TEST(p_schar == 100); + + assign_pixel(p_gray, 200); + assign_pixel(p_schar, p_gray); + DLIB_TEST(p_schar == std::numeric_limits::max()); + + assign_pixel(p_int, p_gray); + DLIB_TEST(p_int == 200); + + assign_pixel(p_float, p_gray); + DLIB_TEST(p_float == 200); + + assign_pixel(p_rgb, p_float); + DLIB_TEST(p_rgb.red == 200); + DLIB_TEST(p_rgb.green == 200); + DLIB_TEST(p_rgb.blue == 200); + + p_schar = 0; + assign_pixel(p_schar, p_rgb); + DLIB_TEST(p_schar == std::numeric_limits::max()); + + + p_schar = -10; + assign_pixel(p_float, p_schar); + DLIB_TEST(p_float == -10); + assign_pixel(p_int, p_schar); + DLIB_TEST(p_int == -10); + assign_pixel(p_schar, p_schar); + DLIB_TEST(p_schar == -10); + assign_pixel(p_gray, p_schar); + DLIB_TEST(p_gray == 0); + + assign_pixel(p_rgb, p_schar); + DLIB_TEST(p_rgb.red == 0); + DLIB_TEST(p_rgb.green == 0); + DLIB_TEST(p_rgb.blue == 0); + + assign_pixel(p_gray16, p_schar); + DLIB_TEST(p_gray16 == 0); + + DLIB_TEST(get_pixel_intensity(p_float) == -10); + DLIB_TEST(get_pixel_intensity(p_int) == -10); + DLIB_TEST(get_pixel_intensity(p_schar) == -10); + DLIB_TEST(get_pixel_intensity(p_rgb) == 0); + DLIB_TEST(get_pixel_intensity(p_gray16) == 0); + + p_rgb.red = 100; + p_rgb.green = 100; + p_rgb.blue = 100; + DLIB_TEST(get_pixel_intensity(p_rgb) == 100); + p_rgb.red = 1; + p_rgb.green = 2; + p_rgb.blue = 3; + DLIB_TEST(get_pixel_intensity(p_rgb) == 2); + p_rgba.alpha = 100; + p_rgba.red = 100; + p_rgba.green = 100; + p_rgba.blue = 100; + DLIB_TEST(get_pixel_intensity(p_rgba) == 100); + p_rgba.red = 1; + p_rgba.green = 2; + p_rgba.blue = 3; + p_rgba.alpha = 0; + DLIB_TEST(get_pixel_intensity(p_rgba) == 2); + p_hsi.h = 123; + p_hsi.s = 100; + p_hsi.i = 84; + DLIB_TEST(get_pixel_intensity(p_hsi) == 84); + + p_lab.l = 123; + p_lab.a = 100; + p_lab.b = 84; + DLIB_TEST(get_pixel_intensity(p_lab) == 123); + + p_float = 54.25; + DLIB_TEST(get_pixel_intensity(p_float) == 54.25); + + assign_pixel(p_gray, p_float); + DLIB_TEST(get_pixel_intensity(p_gray) == 54); + + assign_pixel_intensity(p_float, -1000); + assign_pixel_intensity(p_schar, -100); + assign_pixel_intensity(p_int, -10000); + assign_pixel_intensity(p_gray, -100); + + p_rgba.red = 10; + p_rgba.green = 10; + p_rgba.blue = 10; + p_rgba.alpha = 0; + DLIB_TEST_MSG(get_pixel_intensity(p_rgba) == 10, (int)get_pixel_intensity(p_rgba)); + assign_pixel_intensity(p_rgba, 2); + DLIB_TEST_MSG(p_rgba.red == 2, (int)p_rgba.red); + DLIB_TEST_MSG(p_rgba.green == 2, (int)p_rgba.green); + DLIB_TEST_MSG(p_rgba.blue == 2, (int)p_rgba.blue); + DLIB_TEST_MSG(p_rgba.alpha == 0, (int)p_rgba.alpha); + DLIB_TEST_MSG(get_pixel_intensity(p_rgba) == 2, (int)get_pixel_intensity(p_rgba)); + + DLIB_TEST(p_float == -1000); + DLIB_TEST(get_pixel_intensity(p_float) == -1000); + DLIB_TEST(p_schar == -100); + DLIB_TEST(get_pixel_intensity(p_schar) == -100); + DLIB_TEST(p_int == -10000); + DLIB_TEST(get_pixel_intensity(p_int) == -10000); + DLIB_TEST(p_gray == 0); + assign_pixel_intensity(p_gray, 1000); + DLIB_TEST(p_gray == 255); + DLIB_TEST(get_pixel_intensity(p_gray) == 255); + + assign_pixel_intensity(p_float, p_gray); + DLIB_TEST(p_float == 255); + DLIB_TEST(get_pixel_intensity(p_float) == 255); + + assign_pixel_intensity(p_int, p_gray); + DLIB_TEST(p_int == 255); + DLIB_TEST(get_pixel_intensity(p_int) == 255); + + + p_float = 1e10; + assign_pixel(p_schar, p_float); + DLIB_TEST(p_schar == std::numeric_limits::max()); + + p_float = -1e10; + assign_pixel(p_schar, p_float); + DLIB_TEST(p_schar == std::numeric_limits::min()); + + double p_double = 1e200; + assign_pixel(p_float, p_double); + DLIB_TEST(p_float == std::numeric_limits::max()); + + p_double = -1e200; + assign_pixel(p_float, p_double); + DLIB_TEST(p_float == -std::numeric_limits::max()); + } + + + } + + + + + class pixel_tester : public tester + { + public: + pixel_tester ( + ) : + tester ("test_pixel", + "Runs tests on the pixel objects and functions.") + {} + + void perform_test ( + ) + { + pixel_test(); + } + } a; + +} diff --git a/ml/dlib/dlib/test/probabilistic.cpp b/ml/dlib/dlib/test/probabilistic.cpp new file mode 100644 index 000000000..e8a24829a --- /dev/null +++ b/ml/dlib/dlib/test/probabilistic.cpp @@ -0,0 +1,123 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" +#include "checkerboard.h" +#include + +#include "tester.h" +#include + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.probabilistic"); + +// ---------------------------------------------------------------------------------------- + + class test_probabilistic : public tester + { + public: + test_probabilistic ( + ) : + tester ("test_probabilistic", + "Runs tests on the probabilistic trainer adapter.") + {} + + void perform_test ( + ) + { + print_spinner(); + + + typedef double scalar_type; + typedef matrix sample_type; + + std::vector x; + std::vector > x_linearized; + std::vector y; + + get_checkerboard_problem(x,y, 1000, 2); + + random_subset_selector rx; + random_subset_selector ry; + rx.set_max_size(x.size()); + ry.set_max_size(x.size()); + + dlog << LINFO << "pos labels: "<< sum(mat(y) == +1); + dlog << LINFO << "neg labels: "<< sum(mat(y) == -1); + + for (unsigned long i = 0; i < x.size(); ++i) + { + rx.add(x[i]); + ry.add(y[i]); + } + + const scalar_type gamma = 2.0; + + typedef radial_basis_kernel kernel_type; + + krr_trainer krr_trainer; + krr_trainer.use_classification_loss_for_loo_cv(); + krr_trainer.set_kernel(kernel_type(gamma)); + krr_trainer.set_basis(randomly_subsample(x, 100)); + probabilistic_decision_function df; + + dlog << LINFO << "cross validation: " << cross_validate_trainer(krr_trainer, rx,ry, 4); + print_spinner(); + + running_stats rs_pos, rs_neg; + + print_spinner(); + df = probabilistic(krr_trainer,3).train(x, y); + for (unsigned long i = 0; i < x.size(); ++i) + { + if (y[i] > 0) + rs_pos.add(df(x[i])); + else + rs_neg.add(df(x[i])); + } + dlog << LINFO << "rs_pos.mean(): "<< rs_pos.mean(); + dlog << LINFO << "rs_neg.mean(): "<< rs_neg.mean(); + DLIB_TEST_MSG(rs_pos.mean() > 0.95, rs_pos.mean()); + DLIB_TEST_MSG(rs_neg.mean() < 0.05, rs_neg.mean()); + rs_pos.clear(); + rs_neg.clear(); + + + print_spinner(); + df = probabilistic(krr_trainer,3).train(rx, ry); + for (unsigned long i = 0; i < x.size(); ++i) + { + if (y[i] > 0) + rs_pos.add(df(x[i])); + else + rs_neg.add(df(x[i])); + } + dlog << LINFO << "rs_pos.mean(): "<< rs_pos.mean(); + dlog << LINFO << "rs_neg.mean(): "<< rs_neg.mean(); + DLIB_TEST_MSG(rs_pos.mean() > 0.95, rs_pos.mean()); + DLIB_TEST_MSG(rs_neg.mean() < 0.05, rs_neg.mean()); + rs_pos.clear(); + rs_neg.clear(); + + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/pyramid_down.cpp b/ml/dlib/dlib/test/pyramid_down.cpp new file mode 100644 index 000000000..c026a8162 --- /dev/null +++ b/ml/dlib/dlib/test/pyramid_down.cpp @@ -0,0 +1,424 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +//#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.pyramid_down"); + +// ---------------------------------------------------------------------------------------- + +void test_pyramid_down_grayscale() +{ + array2d img, down; + pyramid_down<2> pyr; + + img.set_size(300,264); + + assign_all_pixels(img, 10); + + pyr(img, down); + + DLIB_TEST(std::abs(down.nr()*2 - img.nr()) < 5); + DLIB_TEST(std::abs(down.nc()*2 - img.nc()) < 5); + + rectangle rect1 = get_rect(img); + rectangle rect2 = pyr.rect_up(get_rect(down)); + double overlap = rect1.intersect(rect2).area() / (double)(rect1 + rect2).area(); + DLIB_TEST(overlap > 0.95); + + rect1 = get_rect(down); + rect2 = pyr.rect_down(get_rect(img)); + overlap = rect1.intersect(rect2).area() / (double)(rect1 + rect2).area(); + DLIB_TEST(overlap > 0.95); + + DLIB_TEST(min(mat(down)) == 10); + DLIB_TEST(max(mat(down)) == 10); +} + +void test_pyramid_down_rgb() +{ + array2d img; + array2d down; + pyramid_down<2> pyr; + + img.set_size(231, 351); + + assign_all_pixels(img, rgb_pixel(1,2,3)); + + pyr(img, down); + + DLIB_TEST(std::abs(down.nr()*2 - img.nr()) < 5); + DLIB_TEST(std::abs(down.nc()*2 - img.nc()) < 5); + + rectangle rect1 = get_rect(img); + rectangle rect2 = pyr.rect_up(get_rect(down)); + double overlap = rect1.intersect(rect2).area() / (double)(rect1 + rect2).area(); + DLIB_TEST(overlap > 0.95); + + rect1 = get_rect(down); + rect2 = pyr.rect_down(get_rect(img)); + overlap = rect1.intersect(rect2).area() / (double)(rect1 + rect2).area(); + DLIB_TEST(overlap > 0.95); + + bool pixels_match = true; + for (long r = 0; r < down.nr(); ++r) + { + for (long c = 0; c < down.nc(); ++c) + { + if (down[r][c].red != 1 || + down[r][c].green != 2 || + down[r][c].blue != 3 ) + { + pixels_match = false; + } + } + } + DLIB_TEST(pixels_match); +} + +// ---------------------------------------------------------------------------- + +template +rgb_pixel mean_pixel ( + const image_type& img, + const rectangle& rect +) +{ + long red = 0; + long green = 0; + long blue = 0; + for (long r = rect.top(); r <= rect.bottom(); ++r) + { + for (long c = rect.left(); c <= rect.right(); ++c) + { + red += img[r][c].red; + green += img[r][c].green; + blue += img[r][c].blue; + } + } + + const long n = rect.area(); + return rgb_pixel(red/n, green/n, blue/n); +} + +// ---------------------------------------------------------------------------- + +template +void test_pyramid_down_rgb2() +{ + array2d img, img3; + array2d img2, img4; + + + img.set_size(300,400); + assign_all_pixels(img, 0); + rectangle rect1 = centered_rect( 10,10, 14, 14); + rectangle rect2 = centered_rect( 100,100, 34, 42); + rectangle rect3 = centered_rect( 310,215, 65, 21); + + fill_rect(img, rect1, rgb_pixel(255,0,0)); + fill_rect(img, rect2, rgb_pixel(0,255,0)); + fill_rect(img, rect3, rgb_pixel(0,0,255)); + + + + pyramid_down_type pyr; + + pyr(img, img2); + pyr(img, img3); + + + DLIB_TEST(((rect1.tl_corner() - pyr.rect_down(pyr.rect_up(rect1,2),2).tl_corner()).length()) < 1); + DLIB_TEST(((rect1.br_corner() - pyr.rect_down(pyr.rect_up(rect1,2),2).br_corner()).length()) < 1); + DLIB_TEST(((rect2.tl_corner() - pyr.rect_down(pyr.rect_up(rect2,2),2).tl_corner()).length()) < 1); + DLIB_TEST(((rect2.br_corner() - pyr.rect_down(pyr.rect_up(rect2,2),2).br_corner()).length()) < 1); + DLIB_TEST(((rect3.tl_corner() - pyr.rect_down(pyr.rect_up(rect3,2),2).tl_corner()).length()) < 1); + DLIB_TEST(((rect3.br_corner() - pyr.rect_down(pyr.rect_up(rect3,2),2).br_corner()).length()) < 1); + + rect1 = shrink_rect(pyr.rect_down(rect1),1); + rect2 = shrink_rect(pyr.rect_down(rect2),1); + rect3 = shrink_rect(pyr.rect_down(rect3),1); + + DLIB_TEST(rect1.area() > 10); + DLIB_TEST(rect2.area() > 10); + DLIB_TEST(rect3.area() > 10); + + /* + image_window my_window(img); + image_window win2(img2); + image_window win3(img3); + win2.add_overlay(image_window::overlay_rect(rect1, rgb_pixel(255,0,0))); + win2.add_overlay(image_window::overlay_rect(rect2, rgb_pixel(255,0,0))); + win2.add_overlay(image_window::overlay_rect(rect3, rgb_pixel(255,0,0))); + win3.add_overlay(image_window::overlay_rect(rect1, rgb_pixel(255,0,0))); + win3.add_overlay(image_window::overlay_rect(rect2, rgb_pixel(255,0,0))); + win3.add_overlay(image_window::overlay_rect(rect3, rgb_pixel(255,0,0))); + */ + + + DLIB_TEST(std::abs((int)mean(subm(matrix_cast(mat(img2)),rect1)) - 255/3) < 3); + DLIB_TEST(std::abs((int)mean(subm(matrix_cast(mat(img2)),rect2)) - 255/3) < 3); + DLIB_TEST(std::abs((int)mean(subm(matrix_cast(mat(img2)),rect3)) - 255/3) < 3); + assign_image(img4, img); + DLIB_TEST(std::abs((int)mean(mat(img4)) - mean(mat(img2))) < 2); + + + rgb_pixel mean1 = mean_pixel(img3, rect1); + rgb_pixel mean2 = mean_pixel(img3, rect2); + rgb_pixel mean3 = mean_pixel(img3, rect3); + rgb_pixel mean_all_true = mean_pixel(img, get_rect(img)); + rgb_pixel mean_all = mean_pixel(img3, get_rect(img3)); + DLIB_TEST(mean1.red > 250); + DLIB_TEST(mean1.green < 3); + DLIB_TEST(mean1.blue < 3); + + DLIB_TEST(mean2.red < 3); + DLIB_TEST(mean2.green > 250); + DLIB_TEST(mean2.blue < 3); + + DLIB_TEST(mean3.red < 3); + DLIB_TEST(mean3.green < 3); + DLIB_TEST(mean3.blue > 250); + + DLIB_TEST(std::abs((int)mean_all_true.red - mean_all.red) < 1); + DLIB_TEST(std::abs((int)mean_all_true.green - mean_all.green) < 1); + DLIB_TEST(std::abs((int)mean_all_true.blue - mean_all.blue) < 1); + + //my_window.wait_until_closed(); +} + + +// ---------------------------------------------------------------------------------------- + +template +void test_pyramid_down_grayscale2() +{ + array2d img; + array2d img2, img4; + + + img.set_size(300,400); + assign_all_pixels(img, 0); + rectangle rect1 = centered_rect( 10,10, 14, 14); + rectangle rect2 = centered_rect( 100,100, 34, 42); + rectangle rect3 = centered_rect( 310,215, 65, 21); + + fill_rect(img, rect1, 255); + fill_rect(img, rect2, 170); + fill_rect(img, rect3, 100); + + + + pyramid_down_type pyr; + + pyr(img, img2); + + + DLIB_TEST(((rect1.tl_corner() - pyr.rect_down(pyr.rect_up(rect1,2),2).tl_corner()).length()) < 1); + DLIB_TEST(((rect1.br_corner() - pyr.rect_down(pyr.rect_up(rect1,2),2).br_corner()).length()) < 1); + DLIB_TEST(((rect2.tl_corner() - pyr.rect_down(pyr.rect_up(rect2,2),2).tl_corner()).length()) < 1); + DLIB_TEST(((rect2.br_corner() - pyr.rect_down(pyr.rect_up(rect2,2),2).br_corner()).length()) < 1); + DLIB_TEST(((rect3.tl_corner() - pyr.rect_down(pyr.rect_up(rect3,2),2).tl_corner()).length()) < 1); + DLIB_TEST(((rect3.br_corner() - pyr.rect_down(pyr.rect_up(rect3,2),2).br_corner()).length()) < 1); + + rect1 = shrink_rect(pyr.rect_down(rect1),1); + rect2 = shrink_rect(pyr.rect_down(rect2),1); + rect3 = shrink_rect(pyr.rect_down(rect3),1); + + DLIB_TEST(rect1.area() > 10); + DLIB_TEST(rect2.area() > 10); + DLIB_TEST(rect3.area() > 10); + + /* + image_window my_window(img); + image_window win2(img2); + win2.add_overlay(image_window::overlay_rect(rect1, rgb_pixel(255,0,0))); + win2.add_overlay(image_window::overlay_rect(rect2, rgb_pixel(255,0,0))); + win2.add_overlay(image_window::overlay_rect(rect3, rgb_pixel(255,0,0))); + */ + + + DLIB_TEST(std::abs((int)mean(subm(matrix_cast(mat(img2)),rect1)) - 255) <= 3); + DLIB_TEST(std::abs((int)mean(subm(matrix_cast(mat(img2)),rect2)) - 170) < 3); + DLIB_TEST(std::abs((int)mean(subm(matrix_cast(mat(img2)),rect3)) - 100) < 3); + assign_image(img4, img); + DLIB_TEST(std::abs((int)mean(mat(img4)) - mean(mat(img2))) < 2); + + + //my_window.wait_until_closed(); + + + + // make sure the coordinate mapping is invertible when it should be + for (int l = 0; l < 4; ++l) + { + for (long x = -10; x <= 10; ++x) + { + for (long y = -10; y <= 10; ++y) + { + DLIB_TEST_MSG(point(pyr.point_down(pyr.point_up(point(x,y),l),l)) == point(x,y), + point(x,y) << " " << pyr.point_up(point(x,y),l) << " " << pyr.point_down(pyr.point_up(point(x,y),l),l)); + DLIB_TEST_MSG(point(pyr.point_down(point(pyr.point_up(point(x,y),l)),l)) == point(x,y), + point(x,y) << " " << pyr.point_up(point(x,y),l) << " " << pyr.point_down(point(pyr.point_up(point(x,y),l)),l)); + } + } + } +} + +// ---------------------------------------------------------------------------------------- + +template +void test_pyr_sizes() +{ + dlib::rand rnd; + + for (int iter = 0; iter < 20; ++iter) + { + long nr = rnd.get_random_32bit_number()%10+40; + long nc = rnd.get_random_32bit_number()%10+40; + + array2d img(nr,nc), img2; + assign_all_pixels(img,0); + + pyramid_down_type pyr; + + pyr(img, img2); + find_pyramid_down_output_image_size(pyr, nr, nc); + DLIB_TEST(img2.nr() == nr); + DLIB_TEST(img2.nc() == nc); + } +} + + +// ---------------------------------------------------------------------------------------- + +template +void test_pyramid_down_small_sizes() +{ + print_spinner(); + // just make sure it doesn't get messed up with small images. This test + // is only really useful if asserts are enabled. + pyramid_down_type pyr; + + for (int size = 0; size < 20; ++size) + { + array2d img1(size,size); + array2d img2(size,size); + + array2d out1; + array2d out2; + + assign_all_pixels(img1, 0); + assign_all_pixels(img2, 0); + + pyr(img1, out1); + pyr(img2, out2); + } +} + +// ---------------------------------------------------------------------------------------- + + + class test_pyramid_down : public tester + { + public: + test_pyramid_down ( + ) : + tester ("test_pyramid_down", + "Runs tests on the pyramid_down() function.") + {} + + void perform_test ( + ) + { + print_spinner(); + test_pyramid_down_grayscale(); + print_spinner(); + test_pyramid_down_rgb(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_small_sizes >();"; + test_pyramid_down_small_sizes >(); + dlog << LINFO << "call test_pyramid_down_small_sizes >();"; + test_pyramid_down_small_sizes >(); + dlog << LINFO << "call test_pyramid_down_small_sizes >();"; + test_pyramid_down_small_sizes >(); + dlog << LINFO << "call test_pyramid_down_small_sizes >();"; + test_pyramid_down_small_sizes >(); + dlog << LINFO << "call test_pyramid_down_small_sizes();"; + test_pyramid_down_small_sizes(); + dlog << LINFO << "call test_pyramid_down_small_sizes >();"; + test_pyramid_down_small_sizes >(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_rgb2 >();"; + test_pyramid_down_rgb2 >(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_rgb2 >();"; + test_pyramid_down_rgb2 >(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_rgb2 >();"; + test_pyramid_down_rgb2 >(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_rgb2 >();"; + test_pyramid_down_rgb2 >(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_rgb2 >();"; + test_pyramid_down_rgb2 >(); + + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_grayscale2 >();"; + test_pyramid_down_grayscale2 >(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_grayscale2 >();"; + test_pyramid_down_grayscale2 >(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_grayscale2 >();"; + test_pyramid_down_grayscale2 >(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_grayscale2 >();"; + test_pyramid_down_grayscale2 >(); + + print_spinner(); + dlog << LINFO << "call test_pyramid_down_grayscale2 >();"; + test_pyramid_down_grayscale2 >(); + + + test_pyr_sizes>(); + test_pyr_sizes>(); + test_pyr_sizes>(); + test_pyr_sizes>(); + test_pyr_sizes>(); + test_pyr_sizes>(); + test_pyr_sizes>(); + test_pyr_sizes>(); + test_pyr_sizes>(); + } + } a; + +} + + + + diff --git a/ml/dlib/dlib/test/queue.cpp b/ml/dlib/dlib/test/queue.cpp new file mode 100644 index 000000000..efcdf6054 --- /dev/null +++ b/ml/dlib/dlib/test/queue.cpp @@ -0,0 +1,426 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +// This is called an unnamed-namespace and it has the effect of making everything inside this file "private" +// so that everything you declare will have static linkage. Thus we won't have any multiply +// defined symbol errors coming out of the linker when we try to compile the test suite. +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + // Declare the logger we will use in this test. The name of the tester + // should start with "test." + logger dlog("test.queue"); + + template < + typename queue + > + void queue_sort_test ( + ) + /*! + requires + - queue is an implementation of queue/queue_sort_abstract.h + is instantiated with int + ensures + - runs tests on queue for compliance with the specs + !*/ + { + + print_spinner(); + srand(static_cast(time(0))); + + queue q,q2; + + enumerable& e = q; + + // I will use these DLIB_TEST_MSG macros to assert that conditions are true. If they are + // false then it means we have detected an error in the queue object. CASSERT + // will then throw an exception which we will catch at the end of this function and + // report as an error/failed test. + DLIB_TEST(e.at_start() == true); + + int a = 0; + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q.at_start() == true); + DLIB_TEST(q.current_element_valid() == false); + + q.sort(); + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q.at_start() == true); + DLIB_TEST(q.current_element_valid() == false); + + DLIB_TEST (q.move_next() == false); + DLIB_TEST (q.move_next() == false); + DLIB_TEST (q.move_next() == false); + DLIB_TEST (q.move_next() == false); + DLIB_TEST (q.move_next() == false); + DLIB_TEST (q.move_next() == false); + + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q.at_start() == false); + DLIB_TEST(q.current_element_valid() == false); + + + q.reset(); + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q.at_start() == true); + DLIB_TEST(q.current_element_valid() == false); + + + + + + + + + + + + q.clear(); + q2.clear(); + DLIB_TEST(q.size() == 0); + DLIB_TEST(q2.size() == 0); + + for (int i = 0; i < 10000; ++i) + { + int a = i; + q.enqueue(a); + } + + q2.cat(q); + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q2.size() == 10000); + + int g = 0; + while (q2.move_next()) + { + DLIB_TEST_MSG(q2.element() == g,g); + ++g; + } + + for (int i = 0;i < 10000; ++i) + { + int a = 0; + q2.dequeue(a); + DLIB_TEST(a == i); + } + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q2.size() == 0); + q.clear(); + q2.clear(); + + + + + print_spinner(); + + + dlog << LTRACE << "creating big pre-sorted queue"; + q.clear(); + DLIB_TEST(q.size() == 0); + + for (int i = 0; i < 10000; ++i) + { + int a = i; + q.enqueue(a); + } + + dlog << LTRACE << "sorting already sorted queue"; + q.sort(); + + + dlog << LTRACE << "done sorting, checking the results"; + for (int i = 0; i < 10000; ++i) + { + q.dequeue(a); + DLIB_TEST(a == i); + } + + + q.clear(); + dlog << LTRACE << "done with the big pre-sorted queue test"; + + + + + + + + + + + + + + + + q.clear(); + q2.clear(); + DLIB_TEST(q.size() == 0); + DLIB_TEST(q2.size() == 0); + + for (int i = 0; i < 1; ++i) + { + int a = i; + q.enqueue(a); + } + + q2.cat(q); + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q2.size() == 1); + + + + g = 0; + while (q2.move_next()) + { + DLIB_TEST_MSG(q2.element() == g,g); + ++g; + } + + for (int i = 0;i < 1; ++i) + { + int a = 0; + q2.dequeue(a); + DLIB_TEST(a == i); + } + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q2.size() == 0); + q.clear(); + q2.clear(); + + + + + + + + print_spinner(); + + + + + + + + + + + + for (int j = 0; j < 3; ++j) + { + for (int i = 0; i < 10000; ++i) + { + a = ::rand(); + q.enqueue(a); + } + + while (q.move_next()) ; + + DLIB_TEST(q.at_start() == false); + + q.sort(); + + DLIB_TEST(q.at_start() == true); + + // serialize the state of q, then clear q, then + // load the state back into q. + ostringstream sout; + serialize(q,sout); + DLIB_TEST(q.at_start() == true); + istringstream sin(sout.str()); + q.clear(); + deserialize(q,sin); + + + DLIB_TEST(q.at_start() == true); + + a = 0; + int last = 0; + while (q.move_next()) + { + ++a; + DLIB_TEST_MSG(last <= q.element(),"items weren't actually sorted"); + last = q.element(); + DLIB_TEST(q.current_element_valid() == true); + DLIB_TEST(q.at_start() == false); + DLIB_TEST(q.current_element_valid() == true); + + + } + DLIB_TEST_MSG(a == 10000,"some items were lost between the sorting and iterating"); + + + DLIB_TEST(q.size() == 10000); + swap(q,q2); + DLIB_TEST(q2.at_start() == false); + DLIB_TEST(q2.current_element_valid() == false); + + DLIB_TEST (q2.move_next() == false); + DLIB_TEST (q2.move_next() == false); + DLIB_TEST (q2.move_next() == false); + DLIB_TEST (q2.move_next() == false); + DLIB_TEST (q2.move_next() == false); + DLIB_TEST (q2.move_next() == false); + + + DLIB_TEST(q2.size() == 10000); + DLIB_TEST(q2.at_start() == false); + DLIB_TEST(q2.current_element_valid() == false); + + q2.clear(); + + q.swap(q2); + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q.at_start() == true); + DLIB_TEST(q.current_element_valid() == false); + } + + + + print_spinner(); + + + + // try the above code but this time with just one element + // in the queue + for (int j = 0; j < 3; ++j) + { + for (int i = 0; i < 1; ++i) + { + a = ::rand(); + q.enqueue(a); + } + + q.sort(); + + a = 0; + int last = 0; + while (q.move_next()) + { + ++a; + DLIB_TEST_MSG(last <= q.element(),"items weren't actually sorted"); + DLIB_TEST(q.current_element_valid() == true); + + } + DLIB_TEST_MSG(a == 1,"some items were lost between the sorting and iterating"); + + + DLIB_TEST(q.size() == 1); + DLIB_TEST(q.at_start() == false); + DLIB_TEST(q.current_element_valid() == false); + + q.clear(); + + DLIB_TEST(q.size() == 0); + DLIB_TEST(q.at_start() == true); + DLIB_TEST(q.current_element_valid() == false); + } + + + print_spinner(); + + { + q.clear(); + remover& go = q; + for (int i = 0; i < 100; ++i) + { + int a = 3; + q.enqueue(a); + } + DLIB_TEST(go.size() == 100); + for (int i = 0; i < 100; ++i) + { + int a = 9; + q.remove_any(a); + DLIB_TEST(a == 3); + } + DLIB_TEST(go.size() == 0); + } + + } + + + struct factory + { + template + struct return_type { + typedef typename memory_manager::kernel_3c type; + }; + + template + static typename return_type::type* get_instance ( + ) + { + static typename return_type::type a; + return &a; + } + }; + + + + + class queue_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a test for the queue object. When it is constructed + it adds itself into the testing framework. The command line switch is + specified as test_queue by passing that string to the tester constructor. + !*/ + public: + queue_tester ( + ) : + tester ("test_queue", + "Runs tests on the queue component.") + {} + + void perform_test ( + ) + { + // There are multiple implementations of the queue object so use + // the templated function defined above to test them all and report + // a failed test if any of them don't pass. + + typedef dlib::memory_manager_global::kernel_1a mm; + + + dlog << LINFO << "testing sort_1a_c"; + queue_sort_test::sort_1a_c> (); + dlog << LINFO << "testing sort_1a"; + queue_sort_test::sort_1a>(); + dlog << LINFO << "testing sort_1b"; + queue_sort_test::sort_1b> (); + dlog << LINFO << "testing sort_1b_c"; + queue_sort_test::sort_1b_c>(); + dlog << LINFO << "testing sort_1c"; + queue_sort_test::sort_1c> (); + dlog << LINFO << "testing sort_1c_c"; + queue_sort_test::sort_1c_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/rand.cpp b/ml/dlib/dlib/test/rand.cpp new file mode 100644 index 000000000..db051c530 --- /dev/null +++ b/ml/dlib/dlib/test/rand.cpp @@ -0,0 +1,436 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.rand"); + + void check_bpp ( + const std::string str + ) + { + istringstream rdata; + ostringstream sout; + rdata.str(str); + double compressed_size; + compress_stream::kernel_1a cs1; + compress_stream::kernel_2a cs2; + + compress_stream_kernel_1< + entropy_encoder_model_kernel_5<257,entropy_encoder::kernel_1a,4000000,4>, + entropy_decoder_model_kernel_5<257,entropy_decoder::kernel_1a,4000000,4>, + crc32::kernel_1a + > cs3; + + + print_spinner(); + + rdata.clear(); + rdata.seekg(0); + sout.clear(); + sout.str(""); + cs1.compress(rdata,sout); + compressed_size = sout.str().size(); + compressed_size *= 8; + compressed_size /= str.size(); + DLIB_TEST_MSG(compressed_size >= 8, "order 0 bps: " << compressed_size); + dlog << LINFO << "order 0: " << compressed_size; + + print_spinner(); + + rdata.clear(); + rdata.seekg(0); + sout.clear(); + sout.str(""); + cs2.compress(rdata,sout); + compressed_size = sout.str().size(); + compressed_size *= 8; + compressed_size /= str.size(); + DLIB_TEST_MSG(compressed_size >= 8, "order 1 bps: " << compressed_size); + dlog << LINFO << "order 1: " << compressed_size; + + print_spinner(); + + rdata.clear(); + rdata.seekg(0); + sout.clear(); + sout.str(""); + cs3.compress(rdata,sout); + compressed_size = sout.str().size(); + compressed_size *= 8; + compressed_size /= str.size(); + DLIB_TEST_MSG(compressed_size >= 8, "order 4 bps: " << compressed_size); + dlog << LINFO << "order 4: " << compressed_size; + + } + + template < + typename rand + > + void rand_test ( + ) + /*! + requires + - rand is an implementation of rand/rand_kernel_abstract.h + is instantiated with int + ensures + - runs tests on rand for compliance with the specs + !*/ + { + + ostringstream seed; + seed << (unsigned int)time(0); + + ostringstream sout; + + + rand r, r2; + DLIB_TEST(r.get_seed() == ""); + r.set_seed(seed.str()); + + DLIB_TEST(r.get_seed() == seed.str()); + r.clear(); + DLIB_TEST(r.get_seed() == ""); + swap(r,r2); + DLIB_TEST(r.get_seed() == ""); + r.set_seed(seed.str()); + DLIB_TEST(r.get_seed() == seed.str()); + swap(r,r2); + DLIB_TEST(r2.get_seed() == seed.str()); + DLIB_TEST(r.get_seed() == ""); + swap(r,r2); + DLIB_TEST(r.get_seed() == seed.str()); + DLIB_TEST(r2.get_seed() == ""); + + print_spinner(); + unsigned long size = 100000; + for (unsigned long i = 0; i < size; ++i) + { + uint32 ch = r.get_random_32bit_number(); + sout.write((char*)&ch,4); + } + + check_bpp(sout.str()); + sout.clear(); + sout.str(""); + + print_spinner(); + for (unsigned long i = 0; i < size; ++i) + { + uint16 ch = r.get_random_16bit_number(); + sout.write((char*)&ch,2); + } + + check_bpp(sout.str()); + sout.clear(); + sout.str(""); + + print_spinner(); + for (unsigned long i = 0; i < size; ++i) + { + unsigned char ch = r.get_random_8bit_number(); + sout.write((char*)&ch,1); + } + + check_bpp(sout.str()); + sout.clear(); + sout.str(""); + + + // make sure the things can serialize right + { + r.clear(); + r2.clear(); + + + for (int i =0; i < 1000; ++i) + { + r.get_random_32bit_number(); + r.get_random_gaussian(); + } + + ostringstream sout; + serialize(r, sout); + + istringstream sin(sout.str()); + deserialize(r2, sin); + + + for (int i =0; i < 1000; ++i) + { + DLIB_TEST(r.get_random_32bit_number() == r2.get_random_32bit_number()); + DLIB_TEST(std::abs(r.get_random_gaussian() - r2.get_random_gaussian()) < 1e-14); + } + } + + + // make sure calling clear() and set_seed("") do the same thing + { + r.clear(); + r2.set_seed(""); + rand r3; + + + DLIB_TEST(r.get_seed() == r2.get_seed()); + DLIB_TEST(r.get_seed() == r3.get_seed()); + + + for (int i =0; i < 1000; ++i) + { + const uint32 num1 = r.get_random_32bit_number(); + const uint32 num2 = r2.get_random_32bit_number(); + const uint32 num3 = r3.get_random_32bit_number(); + DLIB_TEST( num1 == num2); + DLIB_TEST( num1 == num3); + } + } + + } + + + template + void test_normal_numbers( + rand_type& rnd + ) + { + print_spinner(); + dlog << LINFO << "test normality"; + double cnt1 = 0; // num <= -1.2 + double cnt2 = 0; // num <= -0.5 + double cnt3 = 0; // num <= 0 + double cnt4 = 0; // num <= 0.5 + double cnt5 = 0; // num <= 1.2 + + const unsigned long total = 1000000; + for (unsigned long i = 0; i < total; ++i) + { + const double r = rnd.get_random_gaussian(); + if (r <= -1.2) cnt1 += 1; + if (r <= -0.5) cnt2 += 1; + if (r <= 0) cnt3 += 1; + if (r <= 0.5) cnt4 += 1; + if (r <= 1.2) cnt5 += 1; + } + + cnt1 /= total; + cnt2 /= total; + cnt3 /= total; + cnt4 /= total; + cnt5 /= total; + + dlog << LINFO << "cnt1: "<< cnt1; + dlog << LINFO << "cnt2: "<< cnt2; + dlog << LINFO << "cnt3: "<< cnt3; + dlog << LINFO << "cnt4: "<< cnt4; + dlog << LINFO << "cnt5: "<< cnt5; + + DLIB_TEST(std::abs(cnt1 - 0.11507) < 0.001); + DLIB_TEST(std::abs(cnt2 - 0.30854) < 0.001); + DLIB_TEST(std::abs(cnt3 - 0.5) < 0.001); + DLIB_TEST(std::abs(cnt4 - 0.69146) < 0.001); + DLIB_TEST(std::abs(cnt5 - 0.88493) < 0.001); + + } + + void test_gaussian_random_hash() + { + print_spinner(); + dlog << LINFO << "test_gaussian_random_hash()"; + double cnt1 = 0; // num <= -1.2 + double cnt2 = 0; // num <= -0.5 + double cnt3 = 0; // num <= 0 + double cnt4 = 0; // num <= 0.5 + double cnt5 = 0; // num <= 1.2 + + const unsigned long total = 1000000; + for (unsigned long i = 0; i < total; ++i) + { + const double r = gaussian_random_hash(i,0,0); + if (r <= -1.2) cnt1 += 1; + if (r <= -0.5) cnt2 += 1; + if (r <= 0) cnt3 += 1; + if (r <= 0.5) cnt4 += 1; + if (r <= 1.2) cnt5 += 1; + } + for (unsigned long i = 0; i < total; ++i) + { + const double r = gaussian_random_hash(0,i,0); + if (r <= -1.2) cnt1 += 1; + if (r <= -0.5) cnt2 += 1; + if (r <= 0) cnt3 += 1; + if (r <= 0.5) cnt4 += 1; + if (r <= 1.2) cnt5 += 1; + } + for (unsigned long i = 0; i < total; ++i) + { + const double r = gaussian_random_hash(0,0,i); + if (r <= -1.2) cnt1 += 1; + if (r <= -0.5) cnt2 += 1; + if (r <= 0) cnt3 += 1; + if (r <= 0.5) cnt4 += 1; + if (r <= 1.2) cnt5 += 1; + } + + cnt1 /= total*3; + cnt2 /= total*3; + cnt3 /= total*3; + cnt4 /= total*3; + cnt5 /= total*3; + + dlog << LINFO << "cnt1: "<< cnt1; + dlog << LINFO << "cnt2: "<< cnt2; + dlog << LINFO << "cnt3: "<< cnt3; + dlog << LINFO << "cnt4: "<< cnt4; + dlog << LINFO << "cnt5: "<< cnt5; + + DLIB_TEST(std::abs(cnt1 - 0.11507) < 0.001); + DLIB_TEST(std::abs(cnt2 - 0.30854) < 0.001); + DLIB_TEST(std::abs(cnt3 - 0.5) < 0.001); + DLIB_TEST(std::abs(cnt4 - 0.69146) < 0.001); + DLIB_TEST(std::abs(cnt5 - 0.88493) < 0.001); + } + + void test_uniform_random_hash() + { + print_spinner(); + dlog << LINFO << "test_uniform_random_hash()"; + double cnt1 = 0; // num <= 0.2 + double cnt2 = 0; // num <= 0.4 + double cnt3 = 0; // num <= 0.6 + double cnt4 = 0; // num <= 0.8 + double cnt5 = 0; // num <= 1.0 + + double min_val = 10; + double max_val = 0; + + const unsigned long total = 1000000; + for (unsigned long i = 0; i < total; ++i) + { + const double r = uniform_random_hash(i,0,0); + min_val = min(r,min_val); + max_val = max(r,max_val); + + if (r <= 0.2) cnt1 += 1; + if (r <= 0.4) cnt2 += 1; + if (r <= 0.6) cnt3 += 1; + if (r <= 0.8) cnt4 += 1; + if (r <= 1.0) cnt5 += 1; + } + for (unsigned long i = 0; i < total; ++i) + { + const double r = uniform_random_hash(0,i,0); + min_val = min(r,min_val); + max_val = max(r,max_val); + + if (r <= 0.2) cnt1 += 1; + if (r <= 0.4) cnt2 += 1; + if (r <= 0.6) cnt3 += 1; + if (r <= 0.8) cnt4 += 1; + if (r <= 1.0) cnt5 += 1; + } + for (unsigned long i = 0; i < total; ++i) + { + const double r = uniform_random_hash(0,0,i); + min_val = min(r,min_val); + max_val = max(r,max_val); + + if (r <= 0.2) cnt1 += 1; + if (r <= 0.4) cnt2 += 1; + if (r <= 0.6) cnt3 += 1; + if (r <= 0.8) cnt4 += 1; + if (r <= 1.0) cnt5 += 1; + } + + cnt1 /= total*3; + cnt2 /= total*3; + cnt3 /= total*3; + cnt4 /= total*3; + cnt5 /= total*3; + + dlog << LINFO << "cnt1: "<< cnt1; + dlog << LINFO << "cnt2: "<< cnt2; + dlog << LINFO << "cnt3: "<< cnt3; + dlog << LINFO << "cnt4: "<< cnt4; + dlog << LINFO << "cnt5: "<< cnt5; + dlog << LINFO << "min_val: "<< min_val; + dlog << LINFO << "max_val: "<< max_val; + + DLIB_TEST(std::abs(cnt1 - 0.2) < 0.001); + DLIB_TEST(std::abs(cnt2 - 0.4) < 0.001); + DLIB_TEST(std::abs(cnt3 - 0.6) < 0.001); + DLIB_TEST(std::abs(cnt4 - 0.8) < 0.001); + DLIB_TEST(std::abs(cnt5 - 1.0) < 0.001); + DLIB_TEST(std::abs(min_val - 0.0) < 0.001); + DLIB_TEST(std::abs(max_val - 1.0) < 0.001); + } + + void test_get_integer() + { + + print_spinner(); + dlib::rand rnd; + + + int big_val = 0; + int small_val = 0; + + const long long maxval = (((unsigned long long)1)<<62) + (((unsigned long long)1)<<61); + for (int i = 0; i < 10000000; ++i) + { + if (rnd.get_integer(maxval) > maxval/2) + ++big_val; + else + ++small_val; + } + + // make sure there isn't any funny bias + DLIB_TEST(std::abs(big_val/(double)small_val - 1) < 0.001); + + //cout << big_val/(double)small_val << endl; + + } + + class rand_tester : public tester + { + public: + rand_tester ( + ) : + tester ("test_rand", + "Runs tests on the rand component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + rand_test(); + rand_test(); + + dlib::rand rnd; + test_normal_numbers(rnd); + test_gaussian_random_hash(); + test_uniform_random_hash(); + test_get_integer(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/random_forest.cpp b/ml/dlib/dlib/test/random_forest.cpp new file mode 100644 index 000000000..b3447bf5c --- /dev/null +++ b/ml/dlib/dlib/test/random_forest.cpp @@ -0,0 +1,405 @@ +// Copyright (C) 2018 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include + +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + + logger dlog("test.random_forest"); + + const std::string get_decoded_string(); + +// ---------------------------------------------------------------------------------------- + + + + class test_random_forest : public tester + { + public: + test_random_forest ( + ) : + tester ("test_random_forest", + "Runs tests on the random forest tools.") + {} + + + void perform_test ( + ) + { + istringstream sin(get_decoded_string()); + + print_spinner(); + + typedef matrix sample_type; + std::vector labels; + std::vector samples; + + deserialize(samples, sin); + deserialize(labels, sin); + + DLIB_TEST(samples.size() == 506); + + random_forest_regression_trainer trainer; + trainer.set_num_trees(1000); + trainer.set_seed("random forest"); + + std::vector oobs; + auto df = trainer.train(samples, labels, oobs); + + DLIB_TEST(df.get_num_trees() == 1000); + + auto result = test_regression_function(df, samples, labels); + // train: 2.239 0.987173 0.970669 1.1399 + dlog << LINFO << "train: " << result; + DLIB_TEST_MSG(result(0) < 2.3, result(0)); + + running_stats rs; + for (size_t i = 0; i < oobs.size(); ++i) + rs.add(std::pow(oobs[i]-labels[i],2.0)); + dlog << LINFO << "OOB MSE: "<< rs.mean(); + DLIB_TEST_MSG(rs.mean() < 10.2, rs.mean()); + + print_spinner(); + + stringstream ss; + serialize(df, ss); + decltype(df) df2; + deserialize(df2, ss); + DLIB_TEST(df2.get_num_trees() == 1000); + result = test_regression_function(df2, samples, labels); + // train: 2.239 0.987173 0.970669 1.1399 + dlog << LINFO << "serialized train results: " << result; + DLIB_TEST_MSG(result(0) < 2.3, result(0)); + } + } a; + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + + // This function returns the contents of the file './housing_data.dat' + const std::string get_decoded_string() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file './housing_data.dat' we want to decode and return. + sout << "AvlDWRK3FGmPtCL8V/RoXQLzsOKukA0zQosjXGZPM6yNf1U5LjdhRVloILZ5baZq5Tj0hLQGf1JY"; + sout << "ggTd0DEbBez7lzvZ6hOBABkJ6U0aZAEeIMUL/K19h5gHhpwdWvcolJA+VbQVD4acLQFxcIgsgN8N"; + sout << "2zjQsQCUkHGSGyop7/xrVZAJS4Nwy8qMWJyjgc/9mZXdDsaPWtSVTeoxpb3d94bDYdrEQ8T0Z6N1"; + sout << "fCXCQo/3bus7NA+FJoOtw23DHMsYnsv+tfvaNzzzX7lc0cRPe8Pi5q9JDBs4Gc5bPh3Hw8QJz3n5"; + sout << "beGJpU1KRHu9vq1zgFuavXvyZ1HQVHLj/yJRm4lL5eUv7CQeULGP9UmIkvVCJTZ4uw6mIFJyPYjN"; + sout << "jjygWMfJ1dnlGpI/ZlaVNJaTEB28tNymV0UKeKb/Sg42s9+wNNQzWMBm5TRvqmplf/0Gx+4Tcmcd"; + sout << "FtUb1pgkz6OB59Ko4L2hsxVQYYZHGpQt6QBMiubW7GVe2g4yWjRhqlQTh2sjceGRi26SFOrD+gnH"; + sout << "9xZlbyKdlKlcT3nVfcKYziLKAjbmr5QIu+W6WR7M+p90CHkDrjkVK0WTSc23kOhuua5feG54aoht"; + sout << "hcViWRpASVPXsKKJcl2yTlZ02uFQak5Lid/znCDmWQk0fjfzEZZNzgoCNVi8xCx68lH4Mjm3MdF4"; + sout << "ExZqX0jsNlxDOwKp7TYIbhfY4/XdzFSi20CtxXbB3knkuggb+ru/u/6lh8rmqwLjsANqb2CWG0EH"; + sout << "32i1/gmtlY54kAYG58GWB89klTnUwRImQv/QjJBrf8TwK0jUOjkrnOWgKWwNsK0jba54750QPal4"; + sout << "SJFvmeBIR52/T2ZsC2iAQGQuod1IZtIfI2pn0dpkdWm/Y3JdelJ/EADtNjJch6maVoamxKmoWlVw"; + sout << "TRoHTIzdsJKtQiewn21H8bFX7HzZ1yE3/iuSEeLPB7T7gnfrtCEEzqAWiUb/t7mXtqBqt6Kdk6By"; + sout << "JZqE0KdtJ7MJ/yV1MHQ94ExDhYI0qaItZVGwyH8ETCzr6xqmue9hPD6SRq3oy+aQKvJxOqFMcqsj"; + sout << "cL+/2S2F1frRgZDGCb6tB3TdMCJDhZoChQNmJ3hyoAdWXrPEysL0TW8LFSIItAylIRljMsXgnMRE"; + sout << "RCkfYPdbweT71k6l7FiaAsqHeKh4w8CxKJkzhJeALEPLz4QvqFt6/DoFmhKTrX4xUk3M/Y+dU/LY"; + sout << "3B/S+e6v9cWuluEGXDzgj/LKBeruPk7hcnhIilMmd3D8sew3tvOdIowxmM67lqW6fwExcK6oKPlT"; + sout << "aDZcttWKTndVEKsUvnZr/PQ8sta49+GGSfkXw/MS0TAjTQ0Wck8wSJ2CMoUGBmVSKLSwEWvdAqdo"; + sout << "lQLDxAVayR3GeKasJshQw69o/3d4JnUOBcU5ZJM0z2D51EDQM3lvmnB9dtiJ2rcypWG53ETvQqYc"; + sout << "S3suPgaDqKmxZbRNvnfuYbG6+qPeHDN6WmdAt9Iw5XWdjlG6u8BGI6+vqY1C8J6pJ2p7ITUVBVCU"; + sout << "NRYyXVhz0PQrzy5jFwbVeRZo3IV/bkPlHV4UujcSEOi67fqk4ixOzA2JsxYzei+Rsp1ahpK3Wmuk"; + sout << "9ZEeqD1/xqqeosV5pvwcXhjjp4UJ0bGY0pEf7w9uDW0aZT+D8prSmTXGjFAQGiSBmfLFw1Yk2CyG"; + sout << "V8RG7/7uxM6qyj9LYsNTGvdyD8DvQNEJ0v7J9IwCihdJAFhuKgqxlmkJx3mz6MiiIC19CuQKR0NC"; + sout << "1/PG2wh7zEhnDwfINSR2b41ZDcm8/ky/k+xhJ6fi3ZqpSlkqyPRA/YAID9Dl7Ngn9/xeWQNXzEHD"; + sout << "pn6aPHclBaDm5keUHlcN9d+vrxrRjV/GdRhs5M2eVRl47djEiYTaQ3Ua9Lg0oWBXYFkC1ncNhsIT"; + sout << "tFLwx1Fc89lNwN3kHW91X3g1MMDFZoyzwBan2v/3DOZpPH4U1cPr1SbAPu0HITHK6pCUtMRY2/DZ"; + sout << "9MTmm6FLeJhjiVma1+ALD4wYTNhebWkmX12jeDPT5gMyDhq3dYQKIvq83aVY1MQ2UroMDf9NVEdh"; + sout << "V94rpEjw0ewZeDKHwPazkOU4q5m69VxwkKSbNzKZ1oe0P5s7j4Z44NYP3o7Qq8MPLi9l7QVKqge8"; + sout << "6PqEdYoxGr9a/QDB7wcSdpcZTku0MippOZm1uG0zA2K6WlTmC/3lCm4m8EZBZXq6YkTjUpPreeio"; + sout << "umUsmp+XbtJ1LWK3V5hMl0R67Fa6tBd6x9bP6/FSwjeYPHj8dz3nK0VLX+NC7+NjIPBkKgAcqs+1"; + sout << "m9u/BA72sNE/cn7NgfgXPlHsho3reV8Iujq+MTN5iayeTH2fyG7XmV0RkpOY770bJEdugB4QlWWL"; + sout << "nZgYd2yyNFlBcvXRKZpoG9cTWuzdxSbJnYdgKGLuPQ0B0BYwrSLnHEM8pboXCN4uIrALW6ipmKeO"; + sout << "/S8bW4u73Bqgamnh7/pdLAoWo6EiD5C3uNrW0OYdnnSjhLHkV+krDhPQr6nAfjH/0vo+CXPDbMyW"; + sout << "DmkJVgJ/cBt+EWNyIOeBLliqbe9zY6hqzGRt4b6cp1fDH/tbYMCsxhp0LIPnDNDouUQK3j+VBB3X"; + sout << "E8NnCSzBa4SdhMNww7jbJea+RXqe+g1clXqkBf/ZitToTPmzHhYcPRAcvhLcRE0n/uWj7jbv5vOD"; + sout << "swyUpiJaXQgRG77rh89xNVqz9qio6uP2xGdlyW0IENbxMYsGq+XKxNMAMHvfRh8JwakGw7ZI/Wn7"; + sout << "8uWdjM2lNmenBCYSS9qe+2DKnqTxOSnm5ugsYr6IXftmlzev0ke2rRBfvllAv8GSY8GTJul+gbZV"; + sout << "+3Wu8xZawgLFjngRph0aq4PviIwrMS1PhE5M7pC65E40uaY+xJv4rQGNCLF3/+SLvnLfTRdH0QZU"; + sout << "r/hXG0BCcaWE4hb7HIzon9mNIZf2Eb+IWxAhUQ2/Nhe/hNTRx+DpB/8H2DurZPFK4nrOPvxmmzgA"; + sout << "3VFL0kJjNfeXGlo2sSQEM8sDecXQkl47KGWROHIaJyRZAoBOBpMHxTXs//3aqWhlOZ88pftZEmSL"; + sout << "K0sXXxS3BAKB8SLzu4VNPNdvmtT7Z4sHmfYj5NXayXS3V2d2646L2QW4jzzwHjpJ6p2/4mjKniFO"; + sout << "TSQu1wKSv16L/QVbiwvg6wYgIL8cct+XSRSUGfvVgo9lAt+OHTIL7s/9A66jqdDmK3UnHcV8XOVN"; + sout << "Wd+DnXOG7nI5shOiGHAqAAtuQZK9sAZZtyLmg68UN02ZroY0pUePwuGaBribKuKLDtcfMcDHwv0x"; + sout << "lSlCkHu9SjsC1Qswk90Yx8YddYY1ePYaoez6xUAraj+zOLNuFCZtm6hGTQq+pPZ5xn/K6zzOvaos"; + sout << "sxDaWBSNDEuOfAK2dHgctL6XKH7/kHAZxBa7CbTSe0zFuO1WbicRrqO1NpUyO9P1L82dv1VCSuyi"; + sout << "Mtj7UNnywrmZmMBf5x5yYBdVwBUKIBy/sNJdrpDHbAi6MJoYzCai8TqTulNL8jGAzkjpXYRqPf99"; + sout << "fXWRTzN1PMHjvEbNOBIX4OorGR4lj4E7i+d1DKaJjXCDgvJQTUGvRWdu7MOOkXPCmIlxFL9Wv2CB"; + sout << "LpzzNp+I3NuLg2F30ratBBLoqIrnBBXb390pABYah8bRnARCUJLjFXugVqTWoMwAsrbS6sfdFRf8"; + sout << "fKt/+Nx2vX8tRJBFFgBEbS2le05ekg7HC6egGCLImh8j8sf4gs+2xdGKXh9mnW8BrqZJvQPkeR4D"; + sout << "Fro5V/EFe7EAIXpQfMRoNpHUSyn5oPJDFYMjjc1EEO4C6qqJ29nV149m60BjWDuVK1y+mdkCvaDo"; + sout << "iZfAKI4TiDExfMpdAJM5Ti6G7pauQnW/lxGNX44JOR82strVKwsQxSecUc+uT+kYjcdV9Mx29qJC"; + sout << "qXJgi2cUhrNsVH0LIe45X3pSzfFnMQ2o+fGAURgQCSW6wToqmtHBsCorr0v32ew524X6tZz11HMC"; + sout << "7DKppzWxGTBOCBwOPrAjwlUb8zaRpj3bQGXWkvDk7E7ZUMQ6TJu0wgkuBNIPeIAh2tLfGqrqZqxp"; + sout << "Y2hM/G/qQG+Bukj8yyxitXqIwceSp3v2BnLnL/WriBpSVEm9LnjiPou/BL98WMhp13UKWe3L3XkC"; + sout << "izri1YMoCyxaQFX6RODu4We1vi/9gjamFSCAc5Tj+CmUasCJpbmkjt7Fp+v3PhXa4qpt32ZR0yuF"; + sout << "G0dowpb8WYvT3U7nWAOFKBRgj9Ri6QPlCVaUa/LnkjZ+fNzytlkzQ9TTsPpOkEeJo+nCF3cCUYBH"; + sout << "Y6lIyjcQRk9guLIgk955mBLpyjni8ZFOOsjTsW+LoOvAiZhVTGwA75/g5z6IcYLlcb0nwpZ/O2fS"; + sout << "QPFcb5V6uhi5TnQHDQGHihSU4MBo5BQfNd+VuSxliK/TVvFU0yYjqPzKxCxgBpDO8qKsPMbc2YKL"; + sout << "SFY2ygJ7PwksSIEQUum0MSEFf1ZJ3WNTajxSvFLToOkAtLpnvZlWymYkI72/Dgi7jBpfhIw1U1Td"; + sout << "tuTLc0L6IfALX2I2VL2tOhBcisUL8IRhDipxhBTRBaJYLG2RB6ICKBuAQaXf4ODAPKbLhzfRSss+"; + sout << "2VTojSwerCyQkKyoUZLR67G2ysWWLERwD1btSNH4IjaPYVEmaWk4I4F1YZhrmN3q5du7t7g3E4C4"; + sout << "/UVLrCVTQD0CnBVBB5hzMEByG/4ZhIu+JWx+jRx1288XA1k84c28NLfMnqDsLHGtVxOLFDBgwFxs"; + sout << "vD8S2E1+G4La3DQWc/X6jfkC+dtp0ihh5qQxGaGCKh0mcd3BHnNJYUSqSRQLRhOjiZBxmujGrhJG"; + sout << "oHPaUCxfgY3vl9y6KAVlcLcKTYZpmukkjxCEVOClOy2pHYivhgkO2HR7okgNGpj8qN9EcVTWPq8u"; + sout << "dbBjHLQ2GbqHamyaDJFUhJfsibyXR5ZtG2WAZDH3uXlL8AGNriBhwgGVcaRGH4sO/NmWWdM/gnap"; + sout << "6geVpginIZALN+egxDTQtxTH5qTPfkMg9tdjlX/zB7e1LbVR40waeP5PtanIvb7VU/GbVbQMEDKc"; + sout << "Lqfj3v6RyK6wX7mDvF7HWFtav8R/j9wlgf75kOiXz+2eN7GeXEmF68LqH6g4n7Ulyhq3uqszT4Jg"; + sout << "hk7ynKJoLURg5KyJPTUCvadDaiqaLH4hF2bErrQzIIbbKDCq5Cb7n0EhCZ03MjLFs9+HSa3yfaGN"; + sout << "NUS3wdGiM9x6rNaKDP6/vySXZzBvgtinskFBvb7UqCwmQ2MF1lwr0+nTNfH7R9fw3fi+tHXB6kyh"; + sout << "PovdaPH+3dfnsbqVSoJLj2OvjsFfTFQXn35xd/IW3UEdBYVSDZP8VGRnXQUSS4BbJ79VUMNOTmwz"; + sout << "CsoiZzIZNgHShekR2XKv1oXM+BheSAxK+r/d+VdPgjlkByfCwuw8iP/odUoaXzk6iTh8h1pGyESL"; + sout << "QY8mNIzzPsU39opNlK7JmOzlYG2wtCS+DcG+bw4HLJP3Or9mChHpN+V3xzL5Tsb/5fGeqcQ0hvsA"; + sout << "aXMhlsnRtxRSDkfE0s1HW0r/O63X5Hm1Yw/vJw8BzEtNYg8h3x7xvECS4vAwwuKLS30rjlVqjPqI"; + sout << "TNchzWOA98U8AoC3t0asTAaEXce8tPtLqXD0EyycoU3slyfpErU4vySzpCXtkv3BShevfZy9yhX0"; + sout << "2HG5zTc+l8GdXayf6mVSXaQ2N2OV6gCwd+hwqHjqvYSg4a0Ug+/cEw3zVi5AGiLIzTGDGsfJHE86"; + sout << "9ohKS4z7yI+doqegx0f7N95Njw315nKSZnSSf6Pa/I20SrcQabMC36H2vdv9gkOlsYlZLyCOL54P"; + sout << "ZOXlim7GgCt8LPEO6maHmQn0f5mtJAYIxMrJKoMasXvc2ZI2tktbh6bAJNfpSL0KbTfeQtaFJnVX"; + sout << "C2f8RXf61VY9rNVd+qtpNuiavf8ZuaVbSLsLzF/beAFpS4djI0Nn78CJBhZAnzhPD76byg/vXG42"; + sout << "nyD/u/FLJ3eccPnvs5umbz/gPiFk6gW+HnTXaYEdwaGWDdlr4QxvDki8Wsr0AWPlzA8D0nmkPCZW"; + sout << "EZLBUIUjnBN2K86dyqEDW8+C42vuwXfa31wkOX0/8S7FSuT1BET8HdK8fykJ8NxdKlUsIFNr9SPz"; + sout << "maMVyvkPp6IQ+DG9PBOFIaFy+zHbPzCRNNd3LTBhkQ0K6bP7u4tG6b4fdmmQzSSGsfXqEUiXkjGX"; + sout << "ge14Qh+f/2KA4TjBZDQWF5NKR6/x4lsHfj8dZDg3+fEwY2fqezjD/jptis7N/VfeSIM/3xD+gF3w"; + sout << "BqJn4Wz+ohlWucLfS0JRREnPAWfje7RQYatBkLok2Uy2hO7lgfw6ipUHNVPUw6XmK3VW+McnK0Ur"; + sout << "L4CI9LAFF+kDdBfTs8hnhmLtk6h2Sucjo1ahEBxAyUuRgqMko5Sy8Lr9Eo49KiKO+V8LpA5ZDMq8"; + sout << "iabdyb1WLFnyvE01K4uKqGHLoeh7heD7/0sAbIVySks0mv5cH46AT288mIVcHrSUurhtxawYZY4P"; + sout << "/DqW0jHbqIkZWJiOquIfeTbtgRax76gX01JSeEdL0UyPHTmoqkvMdQVwjYcIBdLrCPPWQWNjkmWa"; + sout << "XBedBCzwmp6fZX8ew5AABNCmlBBlhqNFZQ4yG91NXuiDoS25xckthn+6l7Mn0FHs9418wKqa+3eB"; + sout << "uGqiAJVwpgWx7CWhWi9MonFdA1nni9AyjubzEaUSbzjL4ghneOGC38FyEQcuKIxrqgY+ManAAdlc"; + sout << "hVaHl9Rx4r16AITagHNPLwCbbeJ7nbM+arjvU6qmK4Bg4E96IDjrrp2EQJMYZrs5+oRbpdtomWTx"; + sout << "k3hcUFCMUiulWQs/pc5bXm+Xvx0mNnpu9A4GtFMzzpKO4M9Q/mtKch9H547N8hjV4jrWsJCqplho"; + sout << "XIp9yBwPKVaEwWYpFcmpfQIJW+I56ewA0xthRBdqqqQSoS3zMJwdQEUm+XibYA9XALC2dkbTH2fo"; + sout << "H9a4ImxxquSZZpoqUuRsbsgejD2v0ynbipTQ/lNVwswk18Wma+Whg4sOdCSQIMns2QW/3quqHoqb"; + sout << "DC06jJZuQJnLNly7x48Pcus1gsWc6aVhbpl7cCLq/YUY0ACMS5toED6q+5mqakCg69pK4dm9WHIf"; + sout << "D0hRK0v/05LheyllMYmQhO98Z0Jz8IJQQsl2sZahUr4Q8oTTFt9rLRKd+onL5HwdJQDAiYB/fex0"; + sout << "3MHYPjyzK2aH6mN7qOG2VquMGb6LmOPszLSggrzFsVzPqYrzqH/6sGZPV1wMduWw/qYadGPlpzu/"; + sout << "XIgnQ2Qzb21xqwTnivx30xgWDf8kuQWDdECGT+/GPAqNx1P9RGLru+tSPe06oZ8RuoSs64zDqJ3E"; + sout << "KMmOeFAt+xQFBg6dEBs1pgfO78AfZDAbadYzvPCEp2zGGeYwhZlF+a+QIbFCyLe4EO6a0wlrPEBn"; + sout << "3DESwdT7ETielqoEQvKdeSQiks0UGvlXRhxKdH+ylQslmdolzE57pMkquAwiFMXddLGFegrctP9s"; + sout << "tmsvLPKWDIqiHy+F79eU6vOfwwS7btaRg5zuRKWkQ+B2CU8F/kx4FR4ZxhK8fzGjMUyjAmHZhEXf"; + sout << "kvnchtB6z0pN7wUf0n+Clxo0DiXlJlRQPo3pZDttbC685azJ3OoH04xS37vxUSx1ir/LWLz/tjkW"; + sout << "iFYq3qxftzK+jU7XzDx2nif7ZLc/+ecfHdQPXK4YZzJ1x8C7SvC7rBLRxnKqTYgv2bL9G1sCU+x6"; + sout << "0hQtMba3x42k//w4RtV2KkazHoMTZc9UuNSsaSoAoGauzw0cs99op7HCpOgoyRu5JeY+fimo2H5C"; + sout << "cXBecQbQdUB0uVxxEQHPwJN7vi94JfbpdnIMLLRjBwRs/2FOmMWbWWcShUYoWDSmJOLaw3Piwtk6"; + sout << "bg/ppKqGAfrzDJkR0n1OZgKvUbnb8WRyZse0W+tO+PcsL1wvwG+8mMJU+AOBs1P/iVLxW/Y4CuXi"; + sout << "/e7SckKJ3vsm/pQawrzhDIjOwofxzBWQ4kODfSEWHZvpQD0HNf/qP6IYfqhUu/0JtRJGLhlQ8hQJ"; + sout << "iJBGtwsCRJWKrBgu6cizrYcA664+XPgjF/FQYLGmPiPrBdrbWjVxSk3tEOgVFOuK+bkI0EX3p0hm"; + sout << "gYbr3oIec4bKzrSgYsIQtHMo1FnQl1xwHL0vH24KF6V6eyYpgVBfg42MNDk/aaCZ4XVIgH0H0wns"; + sout << "sRXftElLUVk8yLhqq9kXBmgHvPZMfA5WTP+KhXFRbfxw0A2nWGbztsniRcoA3N0pGdqwDyOE5VGg"; + sout << "tX94o9eS0eOJzh80SKaHFaX8GtlpVhogNJiMVlzwVoNJASK2Pr7Yp8uIqcUT7+e0VtkdsVlG5wv8"; + sout << "WLEbqmRXrsKLs9f1p23SelLo78kI9nBujEvDSOCChnNwqNPG85kiz24jL0LMaWHAHlnY6uZypDyM"; + sout << "TUmjsyosdrCobZRnQFf4UwUhuNtj3f6sQke+GQhzr474hTpfSDqLGMW6IcE3OcU4x3waC87DPSRC"; + sout << "7PtmJ/+8nWIElEbGJtjS+rL+Ue6faqpkh+dkPC5ZsWHHRvXzyuRNawC15L0kLhhCc9Y+s+fWOppC"; + sout << "iWtPPQKk4PKDA/g5TRA+KPkFH0B6YchdiEaCMLmleDqF9uo+XNzMnHdOKrkTZ29gPosM9w8CpTSF"; + sout << "neDroZT/v1ckXkEZ6rhlVkF2pBmqG0DTL3LPclzO3JC6i6noY+kFU0jSARjXgQU3NXrYpgeRhLuv"; + sout << "hKlC4Fl4xTK8l+p8J8Uk9zKTKsAdyDcAT6rBGfRpmJiN5a+lxuiBSCyygeHDQGhaJFROD53Q/m9M"; + sout << "dxBlTjqMI7r5M2BNFKaZ4vhFdukvCbu2dPm8t10pc4brs7L5TeBVo5qxaFgRbkpS1mTCoDtFH4fi"; + sout << "Wl+zpEdF6VpKrmFaCSSduE2tMhr164re7m4P7CxeJWvYdFfWGD+uDhFM7oPVXvvkC5gYmjdAPIYq"; + sout << "co6IvSMraA9ANQd8b/hO6u+zo+U2Fos9Xe/1u5YWr9JdXZ9oFsdNGQLaLH6VcRrwAyx7+tRjf5Ia"; + sout << "IHli9+TQmZ4tbtxERYe8TaqFqugpCGtvmcE/DFo0BeWgFRFnAY6nyXGgJCrzbxrMdCOBdvIYzuuB"; + sout << "+A68idNO9ifsfNxRalfJCQNymwy4MAylkG8ncZ6Bqx1XztI9ckbD7U7TBMHCWt9xnMxQz9G0HmsQ"; + sout << "pIa0x8tKk5zZ2TyOVe1LjwBXzwhn17i67Ph+NTA2pw8La9KcI2xdlDtAhD+LxRBANo95CeGL9NKp"; + sout << "cDjWMrqwRlvXL0qroKeJuRqtSPGC+hbFEJgTX7iWDq4QJCyvscIm/lWz0ZQIHyXyh3yV/UyGbMXD"; + sout << "hc6mVp20J+AGPR0NCEN3mTh01ON3LJI1t1P6OT8oGM2ofce1YsHLdMlY3uu00ErXy0YF6vz7jft1"; + sout << "0St41Ydx54E5As7cbimxngKnJsFVIdJC8uh8SaAJQJWtQK6DG5sXJVDADSIKXM0TuhRTxRREGQWB"; + sout << "W6Hd2jP7ZArcBCuB2GGMw4sn8iAYMK1LP12hxoZsBZ8iAbohy3MpWZHiE9MDU5PbGRyNsEnuLmQp"; + sout << "APmj5sFUcAsA0MSUYZli2jB2WWAWwTaQ1CGm92tdrdflShh5FR8IhAZrwXPzI/w/1vAianD9yheA"; + sout << "j0cYf/EaB92n9xRSK5zeajIV+DFT/451rNvi0Dqea6cDxRkza31G3d7pWwPkY6WiGdvSzlgy8uJN"; + sout << "pt2gJoZJ6VzzeYD1bsb5X/FYBtYwFuiGicRbUadEA0b736fEC3AG2OZVh5bFmVArBoUukUoBNf8S"; + sout << "gWgzfeYNyL25qa5jaeg+X8okmGtNUzfxLtmJiCY4/A9aDh3yoSCIHwH1we8m18DjMTXYzoc2b99i"; + sout << "18h/UV7FF5xvl5awkZyLjDPPSDtdafb5ufHyNVjblORDSqS2soTzwyoyZm4PTCeiXWaO9Dpz81fM"; + sout << "+bCkqFrXm7qEJqYSrCGLDlxwVZJeHm/lCNpbO+GYq6Cd5CuaJVactLLRre6s43nyD7IxiRfCmb/f"; + sout << "LUyVi5sXEIaWiw80Me64uq/s9ADmuDUkX9Gd8WA+7fyyytSuMpooPNEsVBY9LE5nKWOlOy8Hrqmf"; + sout << "piWc2Pf5nUtyZSsK3p94XbsysthhunMLsiv5j4mcs61xi2IyEgWB5hJ01qk4gQV/8SHPYJ4stRxA"; + sout << "Ea80306xhLLKQjYSpPKHOvoil9kCHIgBzOp6lZas2vOzK/w50AVekKYXFQK2lWMs8TUBzWy5fYPK"; + sout << "CZgcrWP0fShh9pthBw0DBEmpGM6fQZ5XQlBC2hG8HEDbbakqvIUkpuL7jlFde+HW31aTQRluTYl3"; + sout << "huZ21u40SMx86ghv3VUzHpmF5x2b+UIIi+l0yg5FZHXRsth2xRsm4pNO6bspgiL/HrWMfsjwD2Vz"; + sout << "d/2Kx2Dn9FLRUg/ASedunqth0Ovq3z9Qds7pH6QVdBUmtPokcHoC3KKl1gmY7/cN880Az+h0SMpn"; + sout << "eqduvQM9adP2tmuybV5zgKGCt1q6cc0fPPBD1DuwAgr832VjU87nVOl13p4TV9NKX6wvnfRcw1bQ"; + sout << "nJdFr911d2uMjwuPJdKusPo6o86c8YHTOcmUkC2QkMM6gsYp/lK+lv9fwJhvXUKg0aBlNceJ0/eK"; + sout << "LGzHHzsVCweHXjlVY6Z89uMHZZeqp/wzEfatokYF+jIfD9rP+9AyuMIRuIOuXTCkemsGcAHqpg3F"; + sout << "EcSZcaimDAyropuc7CYsVhhxKRDQBYjTbnd0dhIIDZ9WVP/MbG7QRmJF77TB1+a6GlNjoOYEuJfm"; + sout << "RX34p0IQ/ycmc8PcUbFXAC2/epoQKPRprwg2+EbciWSYQe9i8T9gzJVuVHaWF1GjlsNJNvJDnWVW"; + sout << "2ffDvQuZ/YZ//zqKcA6e9A6tTttCUD4XebQmhT5vIesFMuKNUHBvJZwerszeY+AY1Hs8kwTJNMB/"; + sout << "DDj71I1sz1vq7X8OczT4vaHqLDg/4MiyHFatIaGMlbegVLtthaj7BdhwxM7xz0iilKncYQ2zYw9S"; + sout << "wMgYGoTth7eZQe/q0rgzXi25acEvNkbidVbeI+PtUQ1694G/eKRqOYnmaWmhMsCsEUJH5ZI+XhkN"; + sout << "+94T9Tjb6s/P9z0PisH0UAUDT0Rp+DeikJF1h/yLnxhQ3KxIwt9yB+ZlizVXB+6F7xcOAXuVocD+"; + sout << "AyoxOZRmI7dRlFB28ki5Bcl/EHXa70EEFyFao+xc66nv7luVhyscR7PydzdbIlYba2tnkr/QS5RC"; + sout << "kQ+4t8Z2smt7YKo2d/A4Gz3YNk3K3ZbUDWWSClHkcUklQVJDGg4b2da+V9RV8iAuugNuEdDVrs1r"; + sout << "ixGKhyyEGGfIaFUPEaL5/NrAZqBuX6FSuloE7m/MShmkxRuilG5Ngqy9Gb273F5O5x5T6/koSaTc"; + sout << "5tLEdNFmpZYmj7vdIAsHeoyjxmmSycfE8lsCFh1yZRp3I58aXVxBoTHGnQYkQoIeBpr9GBPTo9hx"; + sout << "k05R5M/LAy+Y11NSEW/gRNSiDkmUDGclorU8nz+dLVuyq4ZFX0fGt+IH92B/Ut6oX+S8CaL0iDcf"; + sout << "L+AtFn+m7o5UUKZx9KN2YEv3EbxEdl3m3BsSsgJr/KnVvd88zCEoILj6zZAHE0tkqrYRDX0rKjGc"; + sout << "1LNaJQ+TGXjE0btPlj4hVLWwZHJx9JT70uDDDQ/xz4v8p4Q8MQqfyhf5dc6/GeZtMP031eq1R84J"; + sout << "TeEyGjpF7F4Kngzg1q8mFT8Ay4dmF1rCwvwMl7QGwMAIp0lx+MRC3J9YK5TALpvK2NA70LNDzquo"; + sout << "vuu2cNlAGDtCBJNQ7n9vFEDy37OeXwbTpoJDOWXx4uDL6HQD4yZKxeFbX9AdyS3pTcl12wDkodRU"; + sout << "ESXKNoL9DydL+atZpSrTK06OYqoo4s5ihsGUh/CRJe1owWDoCHJEmT4ghhVeU8YVHxxdVEpOtXw2"; + sout << "csBk3ljCjfpZoXf5yLtEa5Md5JdAP7fVuqDv8sPzg/I0IvXvD24a1RxPcalo5Z5adVfWWGZkC1W/"; + sout << "oBAEgcYFTCVW7IprKK/JuNv1988z19JHhpqMEDNWr7JszAEQ9KRTNtLYjFb/uDCdUSgqiQbV6tjD"; + sout << "PeeKTQbxZ4r6fmtEuV75z/0be62g4t+/aHGNWJjJuZ0P1A6of0LOPZwKhRY2kydC8okBVp3TsM76"; + sout << "9p8yuQAs3WuzaSJR8H2woYQoykHJV/ARapMBuxHlvrhDWFITpyN2LXl5suea3UK1GBJ6HWSiFrIQ"; + sout << "RvMpY13CsRH7uPdx0svXicHK/GRnOPr6ei7cmMsp+nOKXmE15XfatnD8N6OHLImCrfY+bLS1FO4K"; + sout << "EOWti+cmcfz06z70BNUnGSHJqWNohvvGVsre5rimgFSRUJrxN1RTrievQuaVB+hya9rL+dKBRjmf"; + sout << "Uc95nLFFBzuhO/CYEzaGX8JpyyQgh0I38lMww3jK+FRbw3AocP7/rdaNpqi6coY/eFrl8Iv8drWh"; + sout << "5B59c5boqoOa6ZFvIkqB3oJp7ogpO6zFnSS3rGXt0tMWyj5PkSWeN2Tq9pO2gdSM/p0UgN4Ywcpn"; + sout << "rU8gtJgD/zct8G5pH4rAETV3vjfKEnlqG47oIDJzi2PY5zuiSlf3z5pY2nnPjhAFhlBAOGxCV+Ch"; + sout << "i0Y0ziAO3PKo9YXNF8q2hnzroT2o/GXjOZdC56mkRdzYALMv5vkPTQMBqddjahpZDVJLN/jCsnGF"; + sout << "fVOCW57+dzaFVT5Zlsk9xqIaUim+bDg2IHv832FM49MIQx7sJAS4lRmsZNlS1NjWKHwsOtgLPK7+"; + sout << "jRIc6qhU1i7l7cWFd6+oM0U3Sv5yBXTLdTGbWbUniUCn0Izv29BjX9KouaFPRNTyKYfNnoE2dDq/"; + sout << "jNGe+Uxcnbt8vxewpCqRvS7iGBX+Ylf6MW+HkhFu5eKu2pSEK5JyLLS/+kSRHyLhdmhz1PBRh+mr"; + sout << "duozCqvGZN1cMESXzPAVSgKE2sFz7au28raq7YvYI+Pe/8AbD75HPkYlEdVu6SXzwNGrCksJuE1A"; + sout << "u/tAl4GZzEzvyQqUtcf3HD1dWV/ihrtPgXbpCR+GeR9zWrj4MjTDm7ZtsL2NDm599UTNNPJgaIxD"; + sout << "5coKin30t2hOg6LzFoGKpAwGTijauINY/xAqgQBA8vEQ7uYGK2bkPn9llAAG9e2L+KPKVS6nLyFf"; + sout << "unzr/rPkxU4VITFN6V9GGZoJQZ/QFiCm6kLO+beLPgsPkZAqJNO5Gl9OTa6728Ew05ZnziMsWJaM"; + sout << "OaAqjBrE92wtFITs3Qdr+CHKgKqsrGXEUK6hfylB0pOhtNq78gtwsQ8rFzwyu3hoDv4YVEt6FBOx"; + sout << "zb1KBFOMz2x/RZ1qTdO9bMONUe31rjsOTmiFu8/CQyYfjciF8cp367jbzgcWV2aFhQkY2tA7SL/P"; + sout << "HCU5bSb4qpqJG4zg+RPv2Hx/6DpmIDXwJgSajh3a7O+HfcMKIRvWOXqKkc4LPK1RmMy8aDYzNnav"; + sout << "frd4Ii8a2KCeVqsGmJylpypEeMyjQCX8CcIYBYOEVQmUQGAoO71Cauftv1pt1yFAUzpDn+gGGSmN"; + sout << "pp2ZCm/qPb63lb6Kz94Piq6oVw23zqFrr0pqXr2TEi4e7jTLzMGCcXBUj/qNiCly7TmlFzM4obpO"; + sout << "1ev0yo6ccmAF8H4BOeyX7lqeqlwZHmpjc/8oa7QwuQnBXB7c7HHm+L9F3N9QoPnEqLtSmrJmNPoL"; + sout << "qk1d0l1174BQPBZCl6dcCHQdefAZAQ5v66WcPAoZNlt0lVL6CBman1pk5p0e4zU1EPrqYIUxzfBG"; + sout << "mLv2zWim3OpVjpYpB82fhtIlyIwOst+2rkbeCdIm/3X54LiC/hudzp7zUa4pe99DM8jenauzTIR6"; + sout << "BdqYbHBRQTc5rKaUmRDq/+JPVaG2dAjWVTdPHLs+rFM3MvLdd0wPG2T26uwwQyhAx+PHT9JEhU+t"; + sout << "pSJpE/s6LlZmqt6RPuGgYuO6jifhECGWmdy6SgT2wYl/9REvPSsMIMiB9DAbdKAFv7ios1KtcjOq"; + sout << "pIOUs4NKeJ3QMSU3lE5JXf0V45VBkJ5JfO9lyCMgHRGFd89mRf8/HON8GYkSidyFF+d2Z+po7tHS"; + sout << "Bhfq86T6T4vTUDE3KCIcsir6kE+hyZylWw+fnBRzWYVYMBp2YCKHybxBKdkxpvAnDLibZYdyEtd/"; + sout << "20RPeXxxk54lJkFAjNy6vtnh6vLomfNALcZ8oqS8iZWX5v4q35b152XHo9lEWbTxokbbXeMmqLdL"; + sout << "UjBrCPkD1j9ogboDfWD5TNg60dJVikPCyUHbTSTTEU+I9niREiVDdZToXbeaeRgOKYxtnwWiA/FM"; + sout << "BYXiPb57y+/il8TD1ZT74JRjz6kAmSa3bM7GClr//V3Mdl3TpoP949ZhX1L16IHc8WwgE0rQnQHV"; + sout << "NxPxVWS0EhAdmlYB73ib5jPua1rtlwRbeYbDF2iNsuw8ss8phioK+ZR6BnZX0XZN2Tw58Fa1e5kV"; + sout << "y2kHZGIkY6hk/lrA/vAP+QuV8Pxb/P7VXVsgkHglBYuZkkZLgjqAZDvPjvZcBqEDQjCHrG6V4woD"; + sout << "X7xWGVYXVRt5vJcwOX0hKCcUSv3+1XL6+ID0urNTuFqMw96QKtb3//H6qZvZIO/bYGjJcDVbQHJB"; + sout << "6Q//OjqKgSiw67SotBAvIVQrgDeHUzaS7JdjyvWHCooClKrAiIVcnKX02r1y+mVpeL36166MC9D7"; + sout << "dAi/coCiOwcQ6STbBY2PPhWrko31E8o/l3uDeC4cl3TIVjkFjS61GMogqvd2ISZ3Egq1jZkRYMln"; + sout << "962j1y32UizRikxEK+2d0UREZ70oZzW3UoD6YqshoElP78GkH7HGS8hjpL+UhuCjCura1FK4/qVG"; + sout << "yp5GneObh4S9DzJV+IZuviLd6drknWn+nKYS4YPWkbOqBu7RBvh3eb4bQMzvVl2thL1b4Ff8Q79d"; + sout << "mWxu8ajtptUs0OrSor9gnqrLu3K3RWRTWSElrrMHjCD5GZcsrR/qp7Ip4GDBRPCVJNLa1Tmm8TMc"; + sout << "kTkGtbN0E96kT2qVA4/s0vYCxntAUPunP0k/JtWg5YzVXQi5/hKDdIE3EfpV2PXjveDoBasXAmEi"; + sout << "6oUcZk1zKqwFygj1793MM02voesHS7BxWz50W6yD+SozDsnliV7fY4S9+8vj8CipRYSD8t7MoczJ"; + sout << "hg2tEMYVVBo/nV/4l607tvicVT8HYWOkBhuOgeMNrocTp/3zcfiCODeda8rzbGPWjmKOCPcgOmTD"; + sout << "TO5jx0f568jqAuhFgeg6wv/3uXuY3mmFGWsCx4Sf+2nlWqhCMKmSsRT84UXcFYpp0Kcvx1OUFC/c"; + sout << "E64GAqCYYhzb3hF7jys36qQefbfm9+t3owVhP0d9udWNBzw/QeMJ6djmHvk11Tl3qvjGUGV0iXh6"; + sout << "4ZxyCeLQWZA6cNcN8Ovd2vRrtQ8SHlFPhMKqoevmMLkyZKMrD9CUPHmgzlpTuAasa7PEnbaAFHcO"; + sout << "sVE1KNm7uMU7QjFI8u10VRJ6qBpTx6Z3GXPq7Jslk2V6z0xmH/elkzMAu2Wr3MId0/7GCuheVZhh"; + sout << "VAf9EWJ2ZfZxGBOXd8Io9eiatd/VdlOh7FBglIGSpx8UHU3jzpSu1fcnCjVRg3XxWmUI0iRrqxQc"; + sout << "iVT6ttiezDImpPlHxP2OjgIghWD6EZ2Gesunu93a2lep85rEZqN8sACV2sXDy5K7CySNLyhNE7fX"; + sout << "eV1bU3FdOstad/82yh7TJdIpIdFV1tEOV+gAOMZa+516EjdnDs4WJpfWHPWG6xdVJWAvCHss1B8s"; + sout << "k11txpDa6vs3+NCx/mF/6ElB7TxkisPo1m+KfgjBGpI/YHlm6c216WwSh1k1hLk0T3s4wFEb8M5w"; + sout << "BlbxPkt89Y0wWwc/Eg37XquMIme6aBZjU3CZK1NwiAPkKXa9Y6fBTZWLT3W/NpP2Vk8KGxae2fRo"; + sout << "F83VOgFxb1SUIePgZ2vMS/6OnuRxNqiwkcDEI15uVcK+l1AynHFqODaA31sxqQuHnw/FOrPG5yJR"; + sout << "OvmgJOJ9ss9QZkvXaTZFc3vfEZElcKdW9K5xEZfiymZWX9Qihiv4PG/L+qo5Tm6cFxi/8MVW7Tgd"; + sout << "OV5whxO+RWOCj1kyuILneXxrwiYL3tn74Z4+tT6oP2I9l2UXF574JXrOzLLeBhULrD9fpPFlb/lM"; + sout << "5YfP97MRad6MEZfY5uMUa36kUR8s7FMdTpSyKqhEmzmepGD7JI0uGPNutfSDO5SJsPfK8Yh1uPWA"; + sout << "5J5d/NEUp/fqvoniWgm/2ye0ApW90EtX9eDE/DyfPQimpqOdFBO7/EvDkDHlZ0u08F7+sCavgzxE"; + sout << "jHMXJ+uOF5q96vNnGDhyc8WC9NRzX66GChvTCh7nDhBYLMCZzSvWU7wPWC9OLOCvA+lyTTGvFCgs"; + sout << "qCrk5Hoc5etEIrpyOCe6evI681jAoNI0KK4Opb/vqVtKgLTxJhBbi/EVhJTzdALEB/WduuYpfqqD"; + sout << "sUbdDlNRdSGgricRZUxD0iN5GZUBDEzVzp6moEW+Q7NxsryW6/88Ow3ky1fm41QPwQ40Rk45cnIU"; + sout << "5nZxTrYUkBoaWR/Lh43XtKDh3cBhADGiaMIPBVJL/tOUtXTCj93ca73/LqanFOO/rrnROhXUCn1h"; + sout << "+LWEWn553i9PF0CqFkFjtK/dP68wPQ7y09/NYYw9P86YdgoWwm6ZlqskD2ByyYgR5KimmyJLQyxb"; + sout << "CmITcgVbucDiDxUUFG4Rq9Sgn4PW6jGaoujCgclaxOgLnM7lqREfW8XlN/yH8Pvyt7hjAIWNOHyg"; + sout << "hiscgSb1ev+btDSm+aRCjPgRNfBDApl2zINR7ey0wA+wOl6wZh7cxYwbeb55/UqH6OzEXbjJzjki"; + sout << "6c9Z95PGsu8U+FaBJmWyJWqb/SxBxsf4zPH154KF7wE6b0AoP2Y1tGD2hDN8WdEthVWbztLm66nI"; + sout << "n0fz38pv7CNJCuVf+jeqvk6bh8/dI0hVF0CsMd5uVe+fOLzK3nmIma5nB/TQvcCqfR1YjDGV3mhd"; + sout << "QPQ1a7ypy6WplKOU5LamJw8IaYrik3Gg2tLo63FfE8njUkrYdtqiYpgk+okF8vF4S8c8NdWlgTNn"; + sout << "KE+9+dPQZr174KhRz1JwzspIUqW/d/QjGIgJm5NcXAo07/hgpfI2zfHlkGP5ETxLnquzkCcCsrnK"; + sout << "IQ+vacYRwHRVSTVWpvQiDXaNK7RJrfW2ei23/bW8YkE72+lZIACjehXSDgKiBg0EIStpOVLwTSln"; + sout << "fXVW5fHQkyAc4fOVX6bGqcBu40hHN/LRFJY7Jv5qS/AjCtLUR4UCRhMWAo7ITDLgbXoIDKnyrRLF"; + sout << "vS7LGaJ9IHbFv0Lw8Xiqr1YSpjnLZZDcREIsjCf+2G/7kNBY96DPmp63sGuBfIR0f/I6MzUi8HNG"; + sout << "R+EBmcCO0AfPzTeoGkIM9+5YmMpXwvMU0zCz71PxeLlnTUVUggnXTBAa92YS6HuhGKIl5aU0OCXW"; + sout << "76CfhWijfNZCG33EdpOJbXxbtbX97y+Xjn1KyPnNW/Wyk5VO8lqD2ZGC1Wnohpx4q2VtDkVzw9h+"; + sout << "DDbj2hpyfv2byvYJPUSlZ4bJ77aCkYuZsjMOaz4MYjVfgdi7svHdnIBauBgv+rubJAwx8oB9bk3K"; + sout << "Wy2EDPfLHGxtUKt9+keHGnx9xZ9vGjLXzaLiGltRQ5YDiVK/vXPUc8eKy7p+NyNu80yopPIjJcgV"; + sout << "IC0mPwR2BvdGVweS/iLAuMVJw3STNQyl7XVej6lppzt2SWidtjqxYnYjJwRYNnFr0l7SgVA9Tink"; + sout << "jZz7M4OGdDN521PxFXjOcRRRM1Kr7/56n+VF+LdTRCihmUsOuGi4bBlZ+eKTcUmg7FX6LXKptWLm"; + sout << "h6QUA5ZXcUZ9XUrJb1AxaHmfP/XXd9r1bsi+F3JdFUPuCOZHrzok6cshx+9r8WH1MNCBBVTeUrFl"; + sout << "+65D1RufbIVoqZfQ8l3+zghe+U/ujnqEox8ysQAZkhYG6beS6ksj/QhWbGno9W8TeaXr2NvcF2ct"; + sout << "PAh85hlFNWcxBdGnHXl8fQJ+p5+t8+jftQLoVqd6B9/beAUgp9KYrLugImLUnzw6Q7PjXCiFmo8d"; + sout << "gqPvKtiqj4mKu5L5L4/ul2zdN4zQtRE6tm24+ENVyfJ00adWwqEGxrCIvUMiNbOzSI9TtoTVrKgx"; + sout << "lON3nzfcVy3ucT93AtQfa2reu/HwRVNhn0GP/xZGBCy746if7jr5Fa34dAESoJ1mlelIKOion8QJ"; + sout << "bsgQixF1UWwd1/O4DgHG+U43HYhLWOSw+NOnajpEaVoKB+jf/M0u5+BQvbwCcE6zd+Vhmhzy6x7i"; + sout << "5ySbdBren80S4wUt97VFIq7HI0KvvWpkbAeQjGOhAWIqRrtbJup9yQa09VJiR0hiMQ8K2UlTx3By"; + sout << "BFtRAiJO5mn85Q0YxqyogwSE2Kf4F7qu3EBnT/rEta+a0e/UcElFcZPpJfqSwS18mNBPDEyagvBv"; + sout << "BQiEumckYIVwtFMpfrvELupxCTY51anDR44j+RTJlQMf91CHDZi9eewt9Rfo4ja5HTTdK5Wmjkvt"; + sout << "DR2SBS/Z4YJX4elAbie5bsrFWR7IH9M8M/vtR++6Sw2SKEAQjS7D23xmoKXD/bC9JWoz9P0yICkJ"; + sout << "Ckdwh6H+FNzh6Ms8qlcEWtQD28SI271I9tt0Yiw7FvwIYH/+aG7+PnXpyei7Y8OmV2scS2CaAR2e"; + sout << "+k0qUumoXI/yPunLeJE8BQHjslaw8eglGnb7YjfgmoGPC3lzrc69UX3Okmnb8JJUoC3KEsN0Zf0n"; + sout << "GPErNUSm7ABFRAQvHvTiDxfGyf8aN6UNYZDC0rNO5MgYpQxvmeTnCMtXz0P9OEpaBs3ZuNoJVR+n"; + sout << "s1zhiz5JteC9VmCFM8PknTz2tQXBWonqXdHxCPOIi7k0K1Jn8ddygrfUczXeMeZY/H/akoWADYHT"; + sout << "cGPqpi/81wQUaE4GmV6X27XcuHUcELZTkyTqpkBprMoPJJDspnbFj0aaHiz19Ws0weE3wWhUFiA4"; + sout << "D4WlUi4uF40OEAhnv0qcSVHil71C7AEpTcS1dxe+GSs02PppfT+PZDexeWk/S/C8E+O/Fr4BVfO0"; + sout << "G+M+XdcvagPHLZYLTEJhLhvQzPQRbOb9vdPUhiIsodTUzejdQ+zA6k7ynQqNEDAaVwyy82AQRyoi"; + sout << "4BFGQzQ2hYCUp+In9B/69qyBSvT+q6E1YDUARctcr3h/AR/3J7uG4VdxhFA6YOKAnJGCSCMdunwi"; + sout << "hb1or9pPHhW5rDfvmtH9iZCsuvvJmFA8FXmAURRDRziSxYR/zcApR0JBT63mtyw3lMCHP1qDg95w"; + sout << "0ysG3hLdyMI3lnM/g1W0tg1hvXR0C3qjivvuthJHyt3fqSZbVnvNEkD2k+7BoG6KWk5EpoSqv+Ni"; + sout << "znQvLFggfPOtd6Y8BkmBcNhCTCcyMa2ZkvzUNO7419VR/gNeU56athpUvD21r8d0/MWrEi2yKWsM"; + sout << "aCVwYr/eAgp9kpn9i8RuSJmxe4NTb7Isqn25vOfUeUY17vS+c6wjCDqbkmukld4nOHl8uHlHdJoK"; + sout << "elhC/yXSgUyOM8yKFSyw+Ap6v9CISYh3y09g6gBL35LjS7Vb4Ew/ks2t2L9zWrU+IC4BLcQDmMfG"; + sout << "QUk5tt1n05bmuLmnSboYhPj2MNT9/EiqHY9AfwPhMYmWbcSCshdShSoYs1aciC1gc1H393KbvaTl"; + sout << "N6TeDKz6bEkvM0mOu+wRDR5u+tnr0wWSW8JkMt+WkjPvqWkGYiKZuhJT/4rTPRhJEKXE34gxeyFA"; + sout << "16oufb8W9C8Fj30WPJq1bd43y/r0Fs1dwOyZAyLsHv/2AtO5/Jw2KEHsTIgolGr0Qe6//jc/c50R"; + sout << "TN/aDVTBCbvhcRtV/zOuy7oqb+p0IM3p7eWtKFrZ1wXly7FvvPQ5ODn9CnG5wo59XTikanWOMLKf"; + sout << "MwQt89ZW60v9bQTgIUnFYpVlp9SF2XPxmeW6x+NVCEOiXZQ8AM3nDiLA5M0ctrI2DzzLeHoEKWem"; + sout << "TvVZ1MrwJ8jZoRsigx45HqoB353+bXlS+5vMB/M+zUEPu/HHuU5k8zqwU93NFkTR208ZtecG3IPs"; + sout << "ENQd7J10XsbhbYhAvEkU9ZS/3FS7aK7bDvf7cqD/QcGCfIad4rk1Ks6nuGSeR7hUrH1NK1fe9lpY"; + sout << "lNk3EaNPlBEBmNkbGpNIUcJ7ntUy+b1q+VoC+q300a4qo3gIOkN3s25NaDJime/eJmlRYZhH4ip8"; + sout << "6+m+nA8as1/T0/4d7HXiFWQwZ49NZsSESry1yCs97C5IQ6nScahxHDi72AbQeNDB+RtGaQJiOUi7"; + sout << "NSBNuSlm9G9GpR55HvMi/JUqF09BrOn+49zwYqMbSkf8CjoenLM7UzCoywGlk0nXSBsbANAz7D2B"; + sout << "65qWXswtvr4xnCSfRLUNHs3AtlfEqpsdlaF9gDQm5Z4IhT1WOXbETcY5S8T/5DDBIoWi9rHnKuxM"; + sout << "+zmu881jhj3d9p8fbcH65hevrYM49+ZQfWPXTMUY77YbwYTGmYgScAWmqaMugoMWD1ocpYYRM0IJ"; + sout << "+SEiUb57moAOoEeiYZcPqmckTWuHJhYtgbuBojwXqaK/qvDssM/59aTMWagOYHcapC4gBG+s99No"; + sout << "pOCnbe2brIm4+6xWs7LzSA38RZHZSdh66V3n+83R0/wAIw9+X35SXMwrXC96OqXF/6AFvqkL2Wnk"; + sout << "SBbvyq0txWR6b7AaZ418Dmngg3yQh04fwc8xZLy7/1ZYAbGLRRV1mNrpc2Fa1kLjxoRHZMBA75Pt"; + sout << "HirY4CHOvKaEdlk27BW2px1QCTCkZQ/gojWhiZ1kPUAUiW7VcyFSzjtXzswHEIAnGR2dWHgGZDVT"; + sout << "OVuBJ0nTPs8itQ2Htelag60Et9jrwDZzYa4Rhy1FjngWN1S/QAp9iGe95SRVXuBtLNgAVp+sx7SU"; + sout << "VOECSHoLfpSeZPvlm5ibeSN83gFbIG2rsTZ3IlvJjWq82Npzas6p9WVKTEPGS+Ux8nWIBT/enw7o"; + sout << "7KX9phVWQqcYH5IB2waRO+Ke7h6/y696NQMq0R4Xbki9lmjoWNKFtM+GgLygVqxWWWp9iyQFkUQx"; + sout << "7tRJT1da0ImlgCXS/uTTRxvcG9d/E5FMotFa6mA7py7P+eraScFdEHL4J0kA"; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + + + +} + + + + + diff --git a/ml/dlib/dlib/test/ranking.cpp b/ml/dlib/dlib/test/ranking.cpp new file mode 100644 index 000000000..83f70d8cc --- /dev/null +++ b/ml/dlib/dlib/test/ranking.cpp @@ -0,0 +1,485 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + + logger dlog("test.ranking"); + +// ---------------------------------------------------------------------------------------- + + template + void brute_force_count_ranking_inversions ( + const std::vector& x, + const std::vector& y, + std::vector& x_count, + std::vector& y_count + ) + { + x_count.assign(x.size(),0); + y_count.assign(y.size(),0); + + for (unsigned long i = 0; i < x.size(); ++i) + { + for (unsigned long j = 0; j < y.size(); ++j) + { + if (x[i] <= y[j]) + { + x_count[i]++; + y_count[j]++; + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_count_ranking_inversions() + { + print_spinner(); + dlog << LINFO << "in test_count_ranking_inversions()"; + + dlib::rand rnd; + std::vector x, y; + std::vector x_count, y_count; + std::vector x_count2, y_count2; + for (int iter = 0; iter < 5000; ++iter) + { + x.resize(rnd.get_random_32bit_number()%10); + y.resize(rnd.get_random_32bit_number()%10); + for (unsigned long i = 0; i < x.size(); ++i) + x[i] = ((int)rnd.get_random_32bit_number()%10) - 5; + for (unsigned long i = 0; i < y.size(); ++i) + y[i] = ((int)rnd.get_random_32bit_number()%10) - 5; + + count_ranking_inversions(x, y, x_count, y_count); + brute_force_count_ranking_inversions(x, y, x_count2, y_count2); + + DLIB_TEST(mat(x_count) == mat(x_count2)); + DLIB_TEST(mat(y_count) == mat(y_count2)); + } + } + +// ---------------------------------------------------------------------------------------- + + void run_prior_test() + { + print_spinner(); + typedef matrix sample_type; + typedef linear_kernel kernel_type; + + svm_rank_trainer trainer; + + ranking_pair data; + + sample_type samp; + samp = 0, 0, 1; data.relevant.push_back(samp); + samp = 0, 1, 0; data.nonrelevant.push_back(samp); + + trainer.set_c(10); + decision_function df = trainer.train(data); + + trainer.set_prior(df); + + data.relevant.clear(); + data.nonrelevant.clear(); + samp = 1, 0, 0; data.relevant.push_back(samp); + samp = 0, 1, 0; data.nonrelevant.push_back(samp); + + df = trainer.train(data); + + dlog << LINFO << trans(df.basis_vectors(0)); + DLIB_TEST(df.basis_vectors(0)(0) > 0); + DLIB_TEST(df.basis_vectors(0)(1) < 0); + DLIB_TEST(df.basis_vectors(0)(2) > 0); + } + +// ---------------------------------------------------------------------------------------- + + void run_prior_sparse_test() + { + print_spinner(); + typedef std::map sample_type; + typedef sparse_linear_kernel kernel_type; + + svm_rank_trainer trainer; + + ranking_pair data; + + sample_type samp; + samp[0] = 1; data.relevant.push_back(samp); samp.clear(); + samp[1] = 1; data.nonrelevant.push_back(samp); samp.clear(); + + trainer.set_c(10); + decision_function df = trainer.train(data); + + trainer.set_prior(df); + + data.relevant.clear(); + data.nonrelevant.clear(); + samp[2] = 1; data.relevant.push_back(samp); samp.clear(); + samp[1] = 1; data.nonrelevant.push_back(samp); samp.clear(); + + df = trainer.train(data); + + matrix w = sparse_to_dense(df.basis_vectors(0)); + dlog << LINFO << trans(w); + DLIB_TEST(w(0) > 0.1); + DLIB_TEST(w(1) < -0.1); + DLIB_TEST(w(2) > 0.1); + } + +// ---------------------------------------------------------------------------------------- + + void dotest1() + { + print_spinner(); + dlog << LINFO << "in dotest1()"; + + typedef matrix sample_type; + + typedef linear_kernel kernel_type; + + svm_rank_trainer trainer; + + + std::vector > samples; + + ranking_pair p; + sample_type samp; + + samp = 0, 0, 0, 1; p.relevant.push_back(samp); + samp = 1, 0, 0, 0; p.nonrelevant.push_back(samp); + samples.push_back(p); + + samp = 0, 0, 1, 0; p.relevant.push_back(samp); + samp = 1, 0, 0, 0; p.nonrelevant.push_back(samp); + samp = 0, 1, 0, 0; p.nonrelevant.push_back(samp); + samp = 0, 1, 0, 0; p.nonrelevant.push_back(samp); + samples.push_back(p); + + + trainer.set_c(10); + + decision_function df = trainer.train(samples); + + dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples); + matrix res; + res = 1,1; + DLIB_TEST(equal(test_ranking_function(df, samples), res)); + + DLIB_TEST(equal(test_ranking_function(trainer.train(samples[1]), samples), res)); + + trainer.set_epsilon(1e-13); + df = trainer.train(samples); + + dlog << LINFO << df.basis_vectors(0); + sample_type truew; + truew = -0.5, -0.5, 0.5, 0.5; + DLIB_TEST(length(truew - df.basis_vectors(0)) < 1e-10); + + dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples); + DLIB_TEST(equal(test_ranking_function(df, samples), res)); + + dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,2); + DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,2)(0) - 0.7777777778) < 0.0001); + + trainer.set_learns_nonnegative_weights(true); + df = trainer.train(samples); + truew = 0, 0, 1.0, 1.0; + dlog << LINFO << df.basis_vectors(0); + DLIB_TEST(length(truew - df.basis_vectors(0)) < 1e-10); + dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples); + DLIB_TEST(equal(test_ranking_function(df, samples), res)); + + + samples.clear(); + samples.push_back(p); + samples.push_back(p); + samples.push_back(p); + samples.push_back(p); + dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,4); + DLIB_TEST(equal(cross_validate_ranking_trainer(trainer, samples,4) , res)); + + df.basis_vectors(0) = 0; + dlog << LINFO << "BAD RANKING:" << test_ranking_function(df, samples); + DLIB_TEST(test_ranking_function(df, samples)(1) < 0.5); + } + +// ---------------------------------------------------------------------------------------- + + void dotest_sparse_vectors() + { + print_spinner(); + dlog << LINFO << "in dotest_sparse_vectors()"; + + typedef std::map sample_type; + + typedef sparse_linear_kernel kernel_type; + + svm_rank_trainer trainer; + + + std::vector > samples; + + ranking_pair p; + sample_type samp; + + samp[3] = 1; p.relevant.push_back(samp); samp.clear(); + samp[0] = 1; p.nonrelevant.push_back(samp); samp.clear(); + samples.push_back(p); + + samp[2] = 1; p.relevant.push_back(samp); samp.clear(); + samp[0] = 1; p.nonrelevant.push_back(samp); samp.clear(); + samp[1] = 1; p.nonrelevant.push_back(samp); samp.clear(); + samp[1] = 1; p.nonrelevant.push_back(samp); samp.clear(); + samples.push_back(p); + + + trainer.set_c(10); + + decision_function df = trainer.train(samples); + + matrix res; + res = 1,1; + + dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples); + DLIB_TEST(equal(test_ranking_function(df, samples), res)); + + DLIB_TEST(equal(test_ranking_function(trainer.train(samples[1]), samples), res)); + + trainer.set_epsilon(1e-13); + df = trainer.train(samples); + + dlog << LINFO << sparse_to_dense(df.basis_vectors(0)); + sample_type truew; + truew[0] = -0.5; + truew[1] = -0.5; + truew[2] = 0.5; + truew[3] = 0.5; + DLIB_TEST(length(subtract(truew , df.basis_vectors(0))) < 1e-10); + + dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples); + DLIB_TEST(equal(test_ranking_function(df, samples), res)); + + dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,2); + DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,2)(0) - 0.7777777778) < 0.0001); + + trainer.set_learns_nonnegative_weights(true); + df = trainer.train(samples); + truew[0] = 0.0; + truew[1] = 0.0; + truew[2] = 1.0; + truew[3] = 1.0; + dlog << LINFO << sparse_to_dense(df.basis_vectors(0)); + DLIB_TEST(length(subtract(truew , df.basis_vectors(0))) < 1e-10); + dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples); + DLIB_TEST(equal(test_ranking_function(df, samples), res)); + + + samples.clear(); + samples.push_back(p); + samples.push_back(p); + samples.push_back(p); + samples.push_back(p); + dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,4); + DLIB_TEST(equal(cross_validate_ranking_trainer(trainer, samples,4) , res) ); + } + +// ---------------------------------------------------------------------------------------- + + template + class simple_rank_trainer + { + public: + template + decision_function train ( + const ranking_pair& pair + ) const + { + typedef matrix sample_type; + + std::vector relevant = pair.relevant; + std::vector nonrelevant = pair.nonrelevant; + + std::vector samples; + std::vector labels; + for (unsigned long i = 0; i < relevant.size(); ++i) + { + for (unsigned long j = 0; j < nonrelevant.size(); ++j) + { + samples.push_back(relevant[i] - nonrelevant[j]); + labels.push_back(+1); + samples.push_back(nonrelevant[i] - relevant[j]); + labels.push_back(-1); + } + } + + if (use_dcd_trainer) + { + svm_c_linear_dcd_trainer trainer; + trainer.set_c(1.0/samples.size()); + trainer.set_epsilon(1e-10); + trainer.force_last_weight_to_1(true); + //trainer.be_verbose(); + return trainer.train(samples, labels); + } + else + { + svm_c_linear_trainer trainer; + trainer.set_c(1.0); + trainer.set_epsilon(1e-13); + trainer.force_last_weight_to_1(true); + //trainer.be_verbose(); + decision_function df = trainer.train(samples, labels); + DLIB_TEST_MSG(df.b == 0, df.b); + return df; + } + } + }; + + template + void test_svmrank_weight_force_dense() + { + print_spinner(); + dlog << LINFO << "use_dcd_trainer: "<< use_dcd_trainer; + + typedef matrix sample_type; + typedef linear_kernel kernel_type; + + ranking_pair pair; + + for (int i = 0; i < 20; ++i) + { + pair.relevant.push_back(abs(gaussian_randm(10,1,i))); + } + + for (int i = 0; i < 20; ++i) + { + pair.nonrelevant.push_back(-abs(gaussian_randm(10,1,i+10000))); + pair.nonrelevant.back()(9) += 1; + } + + + svm_rank_trainer trainer; + trainer.force_last_weight_to_1(true); + trainer.set_epsilon(1e-13); + //trainer.be_verbose(); + decision_function df; + df = trainer.train(pair); + + matrix res; + res = 1,1; + dlog << LINFO << "weights: "<< trans(df.basis_vectors(0)); + const matrix acc1 = test_ranking_function(df, pair); + dlog << LINFO << "ranking accuracy: " << acc1; + DLIB_TEST(equal(acc1,res)); + + simple_rank_trainer strainer; + decision_function df2; + df2 = strainer.train(pair); + dlog << LINFO << "weights: "<< trans(df2.basis_vectors(0)); + const matrix acc2 = test_ranking_function(df2, pair); + dlog << LINFO << "ranking accuracy: " << acc2; + DLIB_TEST(equal(acc2,res)); + + dlog << LINFO << "w error: " << max(abs(df.basis_vectors(0) - df2.basis_vectors(0))); + dlog << LINFO << "b error: " << abs(df.b - df2.b); + DLIB_TEST(std::abs(max(abs(df.basis_vectors(0) - df2.basis_vectors(0)))) < 1e-8); + DLIB_TEST(std::abs(abs(df.b - df2.b)) < 1e-8); + } + +// ---------------------------------------------------------------------------------------- + + void test_dnn_ranking_loss() + { + print_spinner(); + typedef matrix sample_type; + + + ranking_pair data; + sample_type samp; + + // Make one relevant example. + samp = 1, 0; + data.relevant.push_back(samp); + + // Now make a non-relevant example. + samp = 0, 1; + data.nonrelevant.push_back(samp); + + + using net_type = loss_ranking>>>; + net_type net; + dnn_trainer trainer(net, sgd(1.0, 0.9)); + std::vector> x; + std::vector y; + + x.push_back(matrix_cast(data.relevant[0])); y.push_back(1); + x.push_back(matrix_cast(data.nonrelevant[0])); y.push_back(-1); + + //trainer.be_verbose(); + trainer.set_learning_rate_schedule(logspace(-1, -7, 4000)); + trainer.train(x,y); + + matrix params = mat(net.subnet().layer_details().get_layer_params()); + dlog << LINFO << "params: "<< params; + dlog << LINFO << "relevant output score: " << net(x[0]); + dlog << LINFO << "nonrelevant output score: " << net(x[1]); + + DLIB_TEST(std::abs(params(0) - 1) < 0.001); + DLIB_TEST(std::abs(params(1) + 1) < 0.001); + DLIB_TEST(std::abs(net(x[0]) - 1) < 0.001); + DLIB_TEST(std::abs(net(x[1]) + 1) < 0.001); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class test_ranking_tools : public tester + { + public: + test_ranking_tools ( + ) : + tester ("test_ranking", + "Runs tests on the ranking tools.") + {} + + + void perform_test ( + ) + { + test_count_ranking_inversions(); + dotest1(); + dotest_sparse_vectors(); + test_svmrank_weight_force_dense(); + test_svmrank_weight_force_dense(); + run_prior_test(); + run_prior_sparse_test(); + test_dnn_ranking_loss(); + + } + } a; + + +} + + + + diff --git a/ml/dlib/dlib/test/read_write_mutex.cpp b/ml/dlib/dlib/test/read_write_mutex.cpp new file mode 100644 index 000000000..fb8bdb84c --- /dev/null +++ b/ml/dlib/dlib/test/read_write_mutex.cpp @@ -0,0 +1,208 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.read_write_mutex"); + + class read_write_mutex_tester : public tester, multithreaded_object + { + public: + read_write_mutex_tester ( + ) : + tester ("test_read_write_mutex", + "Runs tests on the read_write_mutex component.") + { + register_thread(*this, &read_write_mutex_tester::thread_write); + register_thread(*this, &read_write_mutex_tester::thread_write); + register_thread(*this, &read_write_mutex_tester::thread_write); + + register_thread(*this, &read_write_mutex_tester::thread_readonly); + register_thread(*this, &read_write_mutex_tester::thread_readonly); + register_thread(*this, &read_write_mutex_tester::thread_readonly); + register_thread(*this, &read_write_mutex_tester::thread_readonly2); + register_thread(*this, &read_write_mutex_tester::thread_readonly2); + register_thread(*this, &read_write_mutex_tester::thread_readonly2); + + } + + read_write_mutex m; + + dlib::mutex mut; + int num_write; + int num_read; + int max_read; + + bool failure; + + void thread_write () + { + // do this so that the readonly threads can get into their loops first. This way + // we can see if the mutex lets many readers into their area + dlib::sleep(250); + for (int i = 0; i < 6; ++i) + { + auto_mutex lock(m); + + mut.lock(); + ++num_write; + mut.unlock(); + + // only one write thread should ever be active at once + if (num_write != 1) + { + failure = true; + dlog << LERROR << "1"; + } + + dlib::sleep(300); + + // only one write thread should ever be active at once + if (num_write != 1) + { + failure = true; + dlog << LERROR << "2"; + } + + mut.lock(); + --num_write; + mut.unlock(); + + print_spinner(); + } + dlog << LINFO << "exit thread_write()"; + } + + void do_readonly_stuff() + { + mut.lock(); + ++num_read; + max_read = max(num_read, max_read); + mut.unlock(); + + if (num_write != 0) + { + failure = true; + dlog << LERROR << "3"; + } + + dlib::sleep(300); + + if (num_write != 0) + { + failure = true; + dlog << LERROR << "4"; + } + + mut.lock(); + max_read = max(num_read, max_read); + --num_read; + mut.unlock(); + + print_spinner(); + } + + void thread_readonly () + { + for (int i = 0; i < 6; ++i) + { + auto_mutex_readonly lock(m); + DLIB_TEST(lock.has_read_lock()); + DLIB_TEST(!lock.has_write_lock()); + do_readonly_stuff(); + + lock.lock_readonly(); + DLIB_TEST(lock.has_read_lock()); + DLIB_TEST(!lock.has_write_lock()); + lock.unlock(); + DLIB_TEST(!lock.has_read_lock()); + DLIB_TEST(!lock.has_write_lock()); + lock.lock_readonly(); + DLIB_TEST(lock.has_read_lock()); + DLIB_TEST(!lock.has_write_lock()); + lock.lock_write(); + DLIB_TEST(!lock.has_read_lock()); + DLIB_TEST(lock.has_write_lock()); + lock.lock_write(); + DLIB_TEST(!lock.has_read_lock()); + DLIB_TEST(lock.has_write_lock()); + } + + dlog << LINFO << "exit thread_readonly()"; + } + + void thread_readonly2 () + { + for (int i = 0; i < 6; ++i) + { + m.lock_readonly(); + auto_unlock_readonly unlock(m); + + do_readonly_stuff(); + } + dlog << LINFO << "exit thread_readonly2()"; + } + + + void perform_test ( + ) + { + num_write = 0; + num_read = 0; + max_read = 0; + failure = false; + + // doing this big block of weird stuff should have no effect. + { + m.unlock(); + + m.lock_readonly(); + m.lock_readonly(); + m.unlock(); + m.unlock_readonly(); + m.unlock(); + m.unlock_readonly(); + + m.unlock(); + m.unlock_readonly(); + + m.lock(); + m.unlock_readonly(); + m.unlock_readonly(); + m.unlock(); + } + + + // start up our testing threads + start(); + + // wait for the threads to finish + wait(); + + + DLIB_TEST(failure == false); + DLIB_TEST_MSG(max_read == 6, "max_read: "<< max_read); + + } + + } a; + + +} + + + diff --git a/ml/dlib/dlib/test/reference_counter.cpp b/ml/dlib/dlib/test/reference_counter.cpp new file mode 100644 index 000000000..330ceed94 --- /dev/null +++ b/ml/dlib/dlib/test/reference_counter.cpp @@ -0,0 +1,122 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.reference_counter"); + + template < + typename ref_counter + > + void reference_counter_test ( + ) + /*! + requires + - ref_counter is an implementation of reference_counter/reference_counter_kernel_abstract.h + and is instantiated to contain an int + ensures + - runs tests on reference_counter for compliance with the specs + !*/ + { + + ref_counter a, b, c; + + for (long i = 0; i < 10; ++i) + { + print_spinner(); + for (long j = 0; j < 10000; ++j) + { + a.modify() = j; + b.modify() = j+1; + c.modify() = j+2; + DLIB_ASSERT(a.access() == j,""); + DLIB_ASSERT(b.access() == j+1,""); + DLIB_ASSERT(c.access() == j+2,""); + DLIB_ASSERT(a.modify() == j,""); + DLIB_ASSERT(b.modify() == j+1,""); + DLIB_ASSERT(c.modify() == j+2,""); + DLIB_ASSERT(a.access() == j,""); + DLIB_ASSERT(b.access() == j+1,""); + DLIB_ASSERT(c.access() == j+2,""); + DLIB_ASSERT(a.modify() == j,""); + DLIB_ASSERT(b.modify() == j+1,""); + DLIB_ASSERT(c.modify() == j+2,""); + a = c; + DLIB_ASSERT(a.access() == j+2,""); + DLIB_ASSERT(b.access() == j+1,""); + DLIB_ASSERT(c.access() == j+2,""); + DLIB_ASSERT(a.modify() == j+2,""); + DLIB_ASSERT(b.modify() == j+1,""); + DLIB_ASSERT(c.modify() == j+2,""); + DLIB_ASSERT(a.access() == j+2,""); + DLIB_ASSERT(b.access() == j+1,""); + DLIB_ASSERT(c.access() == j+2,""); + DLIB_ASSERT(a.modify() == j+2,""); + DLIB_ASSERT(b.modify() == j+1,""); + DLIB_ASSERT(c.modify() == j+2,""); + + a = b = c; + DLIB_ASSERT(a.access() == b.access(),""); + DLIB_ASSERT(a.access() == c.access(),""); + DLIB_ASSERT(c.access() == b.access(),""); + a.modify() = j; + DLIB_ASSERT(a.access() == j,""); + DLIB_ASSERT(a.access() != b.access(),""); + DLIB_ASSERT(a.access() != c.access(),""); + DLIB_ASSERT(c.access() == b.access(),""); + DLIB_ASSERT(c.access() == j+2,""); + DLIB_ASSERT(b.access() == j+2,""); + + DLIB_ASSERT(a.access() == j,""); + a = a; + DLIB_ASSERT(a.access() == j,""); + c = c; + DLIB_ASSERT(c.access() == j+2,""); + DLIB_ASSERT(b.access() == j+2,""); + swap(a,c); + DLIB_ASSERT(a.access() == j+2,""); + DLIB_ASSERT(c.access() == j,""); + DLIB_ASSERT(b.access() == j+2,""); + } + } + + } + + + + + + class reference_counter_tester : public tester + { + public: + reference_counter_tester ( + ) : + tester ("test_reference_counter", + "Runs tests on the reference_counter component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + reference_counter_test::kernel_1a> (); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/rls.cpp b/ml/dlib/dlib/test/rls.cpp new file mode 100644 index 000000000..c4516ad74 --- /dev/null +++ b/ml/dlib/dlib/test/rls.cpp @@ -0,0 +1,196 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.rls"); + + + void test_rls() + { + dlib::rand rnd; + + running_stats rs1, rs2, rs3, rs4, rs5; + + for (int k = 0; k < 2; ++k) + { + for (long num_vars = 1; num_vars < 4; ++num_vars) + { + print_spinner(); + for (long size = 1; size < 300; ++size) + { + { + matrix X = randm(size,num_vars,rnd); + matrix Y = randm(size,1,rnd); + + + const double C = 1000; + const double forget_factor = 1.0; + rls r(forget_factor, C); + for (long i = 0; i < Y.size(); ++i) + { + r.train(trans(rowm(X,i)), Y(i)); + } + + + matrix w = pinv(1.0/C*identity_matrix(X.nc()) + trans(X)*X)*trans(X)*Y; + + rs1.add(length(r.get_w() - w)); + } + + { + matrix X = randm(size,num_vars,rnd); + matrix Y = randm(size,1,rnd); + + matrix G(size,1); + + const double C = 10000; + const double forget_factor = 0.8; + rls r(forget_factor, C); + for (long i = 0; i < Y.size(); ++i) + { + r.train(trans(rowm(X,i)), Y(i)); + + G(i) = std::pow(forget_factor, i/2.0); + } + + G = flipud(G); + + X = diagm(G)*X; + Y = diagm(G)*Y; + + matrix w = pinv(1.0/C*identity_matrix(X.nc()) + trans(X)*X)*trans(X)*Y; + + rs5.add(length(r.get_w() - w)); + } + + { + matrix X = randm(size,num_vars,rnd); + matrix Y = colm(X,0)*10; + + + const double C = 1000000; + const double forget_factor = 1.0; + rls r(forget_factor, C); + for (long i = 0; i < Y.size(); ++i) + { + r.train(trans(rowm(X,i)), Y(i)); + } + + + matrix w = pinv(1.0/C*identity_matrix(X.nc()) + trans(X)*X)*trans(X)*Y; + + rs2.add(length(r.get_w() - w)); + } + + { + matrix X = join_rows(randm(size,num_vars,rnd)-0.5, ones_matrix(size,1)); + matrix Y = uniform_matrix(size,1,10); + + + const double C = 1e7; + const double forget_factor = 1.0; + + matrix w = pinv(1.0/C*identity_matrix(X.nc()) + trans(X)*X)*trans(X)*Y; + + rls r(forget_factor, C); + for (long i = 0; i < Y.size(); ++i) + { + r.train(trans(rowm(X,i)), Y(i)); + rs3.add(std::abs(r(trans(rowm(X,i))) - 10)); + } + + + } + { + matrix X = randm(size,num_vars,rnd)-0.5; + matrix Y = colm(X,0)*10; + + + const double C = 1e6; + const double forget_factor = 0.7; + + + rls r(forget_factor, C); + DLIB_TEST(std::abs(r.get_c() - C) < 1e-10); + DLIB_TEST(std::abs(r.get_forget_factor() - forget_factor) < 1e-15); + DLIB_TEST(r.get_w().size() == 0); + + for (long i = 0; i < Y.size(); ++i) + { + r.train(trans(rowm(X,i)), Y(i)); + rs4.add(std::abs(r(trans(rowm(X,i))) - X(i,0)*10)); + } + + DLIB_TEST(r.get_w().size() == num_vars); + + decision_function > > df = r.get_decision_function(); + DLIB_TEST(std::abs(df(trans(rowm(X,0))) - r(trans(rowm(X,0)))) < 1e-15); + } + } + } + } + + dlog << LINFO << "rs1.mean(): " << rs1.mean(); + dlog << LINFO << "rs2.mean(): " << rs2.mean(); + dlog << LINFO << "rs3.mean(): " << rs3.mean(); + dlog << LINFO << "rs4.mean(): " << rs4.mean(); + dlog << LINFO << "rs5.mean(): " << rs5.mean(); + dlog << LINFO << "rs1.max(): " << rs1.max(); + dlog << LINFO << "rs2.max(): " << rs2.max(); + dlog << LINFO << "rs3.max(): " << rs3.max(); + dlog << LINFO << "rs4.max(): " << rs4.max(); + dlog << LINFO << "rs5.max(): " << rs5.max(); + + DLIB_TEST_MSG(rs1.mean() < 1e-10, rs1.mean()); + DLIB_TEST_MSG(rs2.mean() < 1e-9, rs2.mean()); + DLIB_TEST_MSG(rs3.mean() < 1e-6, rs3.mean()); + DLIB_TEST_MSG(rs4.mean() < 1e-6, rs4.mean()); + DLIB_TEST_MSG(rs5.mean() < 1e-3, rs5.mean()); + + DLIB_TEST_MSG(rs1.max() < 1e-10, rs1.max()); + DLIB_TEST_MSG(rs2.max() < 1e-6, rs2.max()); + DLIB_TEST_MSG(rs3.max() < 0.001, rs3.max()); + DLIB_TEST_MSG(rs4.max() < 0.01, rs4.max()); + DLIB_TEST_MSG(rs5.max() < 0.1, rs5.max()); + + } + + + + + class rls_tester : public tester + { + public: + rls_tester ( + ) : + tester ("test_rls", + "Runs tests on the rls component.") + {} + + void perform_test ( + ) + { + test_rls(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/sammon.cpp b/ml/dlib/dlib/test/sammon.cpp new file mode 100644 index 000000000..5328bd1f6 --- /dev/null +++ b/ml/dlib/dlib/test/sammon.cpp @@ -0,0 +1,211 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.sammon"); + + + std::vector > make_test_data4( + ) + { + std::vector > data; + + matrix m; + + m = 0,0,0, 0; data.push_back(m); + m = 1,0,0, 0; data.push_back(m); + m = 0,1,0, 0; data.push_back(m); + m = 0,0,1, 0; data.push_back(m); + + return data; + } + + std::vector > make_test_data3( + ) + { + std::vector > data; + + matrix m; + + m = 0,0,0; data.push_back(m); + m = 1,0,0; data.push_back(m); + m = 0,1,0; data.push_back(m); + m = 0,0,1; data.push_back(m); + + return data; + } + + std::vector > make_test_data3d( + ) + { + std::vector > data; + + matrix m; + + m = 0,0,0; data.push_back(m); + m = 1,0,0; data.push_back(m); + m = 0,1,0; data.push_back(m); + m = 0,0,1; data.push_back(m); + + return data; + } + + + void runtest() + { + sammon_projection s; + std::vector > projs = s(make_test_data3(),2); + running_stats rs1, rs2; + + rs1.add(length(projs[0] - projs[1])); + rs1.add(length(projs[0] - projs[2])); + rs1.add(length(projs[0] - projs[3])); + + rs2.add(length(projs[1] - projs[2])); + rs2.add(length(projs[2] - projs[3])); + rs2.add(length(projs[3] - projs[1])); + + DLIB_TEST(rs1.stddev()/rs1.mean() < 1e-4); + DLIB_TEST(rs2.stddev()/rs2.mean() < 1e-4); + + + + projs = s(make_test_data4(),2); + rs1.clear(); + rs2.clear(); + + rs1.add(length(projs[0] - projs[1])); + rs1.add(length(projs[0] - projs[2])); + rs1.add(length(projs[0] - projs[3])); + + rs2.add(length(projs[1] - projs[2])); + rs2.add(length(projs[2] - projs[3])); + rs2.add(length(projs[3] - projs[1])); + + DLIB_TEST(rs1.stddev()/rs1.mean() < 1e-4); + DLIB_TEST(rs2.stddev()/rs2.mean() < 1e-4); + + projs = s(make_test_data3d(),2); + rs1.clear(); + rs2.clear(); + + rs1.add(length(projs[0] - projs[1])); + rs1.add(length(projs[0] - projs[2])); + rs1.add(length(projs[0] - projs[3])); + + rs2.add(length(projs[1] - projs[2])); + rs2.add(length(projs[2] - projs[3])); + rs2.add(length(projs[3] - projs[1])); + + DLIB_TEST(rs1.stddev()/rs1.mean() < 1e-4); + DLIB_TEST(rs2.stddev()/rs2.mean() < 1e-4); + } + + void runtest2() + { + sammon_projection s; + std::vector > projs, temp; + + DLIB_TEST(s(projs,3).size() == 0); + + matrix m; + m = 1,2; + projs.push_back(m); + temp = s(projs,2); + DLIB_TEST(temp.size() == 1); + DLIB_TEST(temp[0].size() == 2); + + projs.push_back(m); + temp = s(projs,1); + DLIB_TEST(temp.size() == 2); + DLIB_TEST(temp[0].size() == 1); + DLIB_TEST(temp[1].size() == 1); + } + + void runtest3(int num_dims) + { + sammon_projection s; + std::vector > projs; + matrix m; + m = 1, 1, 1; + projs.push_back(m); + + m = 1, 2, 1; + projs.push_back(m); + + m = 1, 3, 1; + projs.push_back(m); + + projs = s(projs,num_dims); + + const double d1a = length(projs[0] - projs[1]); + const double d1b = length(projs[1] - projs[2]); + const double d2 = length(projs[0] - projs[2]); + + DLIB_TEST(std::abs(d1a-d1b)/d1a < 1e-8); + DLIB_TEST(std::abs(d2/d1a-2) < 1e-8); + } + + void runtest4(int num_dims) + { + sammon_projection s; + std::vector > projs; + matrix m; + m = 1, 1, 1; + projs.push_back(m); + + m = 1, 2, 1; + projs.push_back(m); + + + projs = s(projs,num_dims); + + DLIB_TEST(length(projs[0] - projs[1]) > 1e-5); + } + + class sammon_tester : public tester + { + public: + sammon_tester ( + ) : + tester ("test_sammon", + "Runs tests on the sammon_projection component.") + {} + + void perform_test ( + ) + { + print_spinner(); + runtest(); + print_spinner(); + runtest2(); + print_spinner(); + runtest3(2); + print_spinner(); + runtest4(2); + runtest3(1); + print_spinner(); + runtest4(1); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/scan_image.cpp b/ml/dlib/dlib/test/scan_image.cpp new file mode 100644 index 000000000..c3a0115e3 --- /dev/null +++ b/ml/dlib/dlib/test/scan_image.cpp @@ -0,0 +1,713 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include "dlib/image_processing.h" + +#include "dlib/test/tester.h" + +#include "dlib/image_transforms.h" +#include "dlib/pixel.h" +#include "dlib/array2d.h" +#include "dlib/array.h" + +// ---------------------------------------------------------------------------------------- + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + using dlib::array; + + // Declare the logger we will use in this test. The name of the tester + // should start with "test." + logger dlog("test.scan_image"); + +// ---------------------------------------------------------------------------------------- + + template + void sum_filter_i ( + const image_type1& img, + image_type2& out, + const rectangle& rect + ) + { + typedef typename image_type1::type pixel_type; + typedef typename promote::type ptype; + integral_image_generic iimg; + iimg.load(img); + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + const rectangle temp = translate_rect(rect, point(c,r)).intersect(get_rect(iimg)); + if (temp.is_empty() == false) + out[r][c] += iimg.get_sum_of_area(temp); + } + } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + void scan_image_i ( + std::vector >& dets, + const image_array_type& images, + const std::vector >& rects, + const double thresh, + const unsigned long max_dets + ) + { + typedef typename image_array_type::type::type pixel_type; + typedef typename promote::type ptype; + array > iimg; + iimg.set_max_size(images.size()); + iimg.set_size(images.size()); + + for (unsigned long i = 0; i < iimg.size(); ++i) + iimg[i].load(images[i]); + + + dets.clear(); + + + for (long r = 0; r < images[0].nr(); ++r) + { + for (long c = 0; c < images[0].nc(); ++c) + { + ptype temp = 0; + for (unsigned long i = 0; i < rects.size(); ++i) + { + rectangle rtemp = translate_rect(rects[i].second,point(c,r)).intersect(get_rect(images[0])); + if (rtemp.is_empty() == false) + temp += iimg[rects[i].first].get_sum_of_area(rtemp); + } + if (temp > thresh) + { + dets.push_back(std::make_pair(temp, point(c,r))); + + if (dets.size() >= max_dets) + return; + + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_array_type + > + void scan_image_old ( + std::vector >& dets, + const image_array_type& images, + const std::vector >& rects, + const double thresh, + const unsigned long max_dets + ) + { + dets.clear(); + if (max_dets == 0) + return; + + typedef typename image_array_type::type::type pixel_type; + typedef typename promote::type ptype; + + std::vector > column_sums(rects.size()); + for (unsigned long i = 0; i < column_sums.size(); ++i) + { + const typename image_array_type::type& img = images[rects[i].first]; + column_sums[i].resize(img.nc() + rects[i].second.width(),0); + + const long top = -1 + rects[i].second.top(); + const long bottom = -1 + rects[i].second.bottom(); + long left = rects[i].second.left()-1; + + // initialize column_sums[i] at row -1 + for (unsigned long j = 0; j < column_sums[i].size(); ++j) + { + rectangle strip(left,top,left,bottom); + strip = strip.intersect(get_rect(img)); + if (!strip.is_empty()) + { + column_sums[i][j] = sum(matrix_cast(subm(mat(img),strip))); + } + + ++left; + } + } + + + const rectangle area = get_rect(images[0]); + + for (long r = 0; r < images[0].nr(); ++r) + { + // set to sum at point(-1,r). i.e. should be equal to sum_of_rects_in_images(images, rects, point(-1,r)) + // We compute it's value in the next loop. + ptype cur_sum = 0; + + // Update the first part of column_sums since we only work on the c+width part of column_sums + // in the main loop. + for (unsigned long i = 0; i < rects.size(); ++i) + { + const typename image_array_type::type& img = images[rects[i].first]; + const long top = r + rects[i].second.top() - 1; + const long bottom = r + rects[i].second.bottom(); + const long width = rects[i].second.width(); + for (long k = 0; k < width; ++k) + { + const long right = k-width + rects[i].second.right(); + + const ptype br_corner = area.contains(right,bottom) ? img[bottom][right] : 0; + const ptype tr_corner = area.contains(right,top) ? img[top][right] : 0; + // update the sum in this column now that we are on the next row + column_sums[i][k] = column_sums[i][k] + br_corner - tr_corner; + cur_sum += column_sums[i][k]; + } + } + + for (long c = 0; c < images[0].nc(); ++c) + { + for (unsigned long i = 0; i < rects.size(); ++i) + { + const typename image_array_type::type& img = images[rects[i].first]; + const long top = r + rects[i].second.top() - 1; + const long bottom = r + rects[i].second.bottom(); + const long right = c + rects[i].second.right(); + const long width = rects[i].second.width(); + + const ptype br_corner = area.contains(right,bottom) ? img[bottom][right] : 0; + const ptype tr_corner = area.contains(right,top) ? img[top][right] : 0; + // update the sum in this column now that we are on the next row + column_sums[i][c+width] = column_sums[i][c+width] + br_corner - tr_corner; + + + // add in the new right side of the rect and subtract the old right side. + cur_sum = cur_sum + column_sums[i][c+width] - column_sums[i][c]; + + } + + if (cur_sum > thresh) + { + dets.push_back(std::make_pair(cur_sum, point(c,r))); + + if (dets.size() >= max_dets) + return; + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void run_test1() + { + dlog << LINFO << "run_test1()"; + + print_spinner(); + array2d img, temp_img; + img.set_size(600,600); + assign_all_pixels(img,0); + rectangle rect = centered_rect(10,10,5,5); + dlog << LTRACE << "expected: 10,10"; + fill_rect(img, rect, 255); + + + array > images; + std::vector > rects; + for (int i = 0; i < 10; ++i) + { + assign_image(temp_img, img); + images.push_back(temp_img); + rects.push_back(make_pair(i,centered_rect(0,0,5,5))); + } + + std::vector > dets, dets2, dets3; + + + dlog << LTRACE << "best score: "<< sum_of_rects_in_images(images,rects,point(10,10)); + scan_image(dets,images,rects,30000, 100); + scan_image_i(dets2,images,rects,30000, 100); + scan_image_old(dets3,images,rects,30000, 100); + + + + dlog << LTRACE << "dets.size(): "<< dets.size(); + dlog << LTRACE << "dets2.size(): "<< dets2.size(); + dlog << LTRACE << "dets3.size(): "<< dets3.size(); + + DLIB_TEST(dets.size() == dets2.size()); + DLIB_TEST(dets.size() == dets3.size()); + + for (unsigned long i = 0; i < dets.size(); ++i) + { + //dlog << LTRACE << "dets["< " << dets[i].first; + //dlog << LTRACE << "dets2["< " << dets2[i].first; + //dlog << LTRACE << "dets3["< " << dets3[i].first; + + DLIB_TEST(sum_of_rects_in_images(images, rects, dets[i].second) == dets[i].first); + DLIB_TEST(sum_of_rects_in_images(images, rects, dets2[i].second) == dets2[i].first); + DLIB_TEST(sum_of_rects_in_images(images, rects, dets3[i].second) == dets3[i].first); + } + + + } + +// ---------------------------------------------------------------------------------------- + + void run_test2() + { + print_spinner(); + dlog << LINFO << "run_test2()"; + array2d img, temp_img; + img.set_size(600,600); + assign_all_pixels(img,0); + rectangle rect = centered_rect(10,11,5,6); + dlog << LTRACE << "expected: 10,11"; + fill_rect(img, rect, 255); + + + array > images; + std::vector > rects; + for (int i = 0; i < 10; ++i) + { + assign_image(temp_img, img); + images.push_back(temp_img); + rects.push_back(make_pair(i,centered_rect(0,0,5,5))); + rects.push_back(make_pair(i,centered_rect(3,2,5,6))); + } + + std::vector > dets, dets2, dets3; + + + scan_image(dets,images,rects,30000, 100); + scan_image_i(dets2,images,rects,30000, 100); + scan_image_old(dets3,images,rects,30000, 100); + + + + dlog << LTRACE << "dets.size(): "<< dets.size(); + dlog << LTRACE << "dets2.size(): "<< dets2.size(); + dlog << LTRACE << "dets3.size(): "<< dets3.size(); + + DLIB_TEST(dets.size() == dets2.size()); + DLIB_TEST(dets.size() == dets3.size()); + + for (unsigned long i = 0; i < dets.size(); ++i) + { + //dlog << LTRACE << "dets["< " << dets[i].first; + //dlog << LTRACE << "dets2["< " << dets2[i].first; + //dlog << LTRACE << "dets3["< " << dets3[i].first; + + DLIB_TEST(sum_of_rects_in_images(images, rects, dets[i].second) == dets[i].first); + DLIB_TEST(sum_of_rects_in_images(images, rects, dets2[i].second) == dets2[i].first); + DLIB_TEST(sum_of_rects_in_images(images, rects, dets3[i].second) == dets3[i].first); + } + + + } + +// ---------------------------------------------------------------------------------------- + + template + void run_test3(const double thresh) + { + dlog << LINFO << "running run_test3("< > images; + images.resize(1); + images[0].set_size(200,180); + + for (int iter = 0; iter < 50; ++iter) + { + print_spinner(); + assign_all_pixels(images[0], thresh - 0.0001); + + for (int i = 0; i < 20; ++i) + { + point p1(rnd.get_random_32bit_number()%images[0].nc(), + rnd.get_random_32bit_number()%images[0].nr()); + point p2(rnd.get_random_32bit_number()%images[0].nc(), + rnd.get_random_32bit_number()%images[0].nr()); + + rectangle rect(p1,p2); + fill_rect(images[0], rect, static_cast(rnd.get_random_double()*10 - 5)); + } + + std::vector > rects; + rects.push_back(make_pair(0,centered_rect(0,0,1+rnd.get_random_32bit_number()%40,1+rnd.get_random_32bit_number()%40))); + rects.push_back(make_pair(0,centered_rect(0,0,1+rnd.get_random_32bit_number()%40,1+rnd.get_random_32bit_number()%40))); + + + + + std::vector > dets, dets2, dets3; + scan_image(dets,images,rects,thresh, 100); + scan_image_i(dets2,images,rects,thresh, 100); + scan_image_old(dets3,images,rects,thresh, 100); + + dlog << LTRACE << "dets.size(): "<< dets.size(); + dlog << LTRACE << "dets2.size(): "<< dets2.size(); + dlog << LTRACE << "dets3.size(): "<< dets3.size(); + + DLIB_TEST(dets.size() == dets2.size()); + DLIB_TEST(dets.size() == dets3.size()); + + for (unsigned long i = 0; i < dets.size(); ++i) + { + //dlog << LTRACE << "dets["< " << dets[i].first; + //dlog << LTRACE << "dets2["< " << dets2[i].first; + //dlog << LTRACE << "dets3["< " << dets3[i].first; + + DLIB_TEST_MSG(std::abs(sum_of_rects_in_images(images, rects, dets[i].second) - dets[i].first) < 1e-6, + "error: "<< sum_of_rects_in_images(images, rects, dets[i].second) - dets[i].first + << " dets["< + void test_sum_filter ( + ) + { + dlib::rand rnd; + + for (int k = 0; k < 20; ++k) + { + print_spinner(); + + array2d img(1 + rnd.get_random_32bit_number()%100, + 1 + rnd.get_random_32bit_number()%100); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = static_cast(100*(rnd.get_random_double()-0.5)); + } + } + + array2d test1(img.nr(), img.nc()); + array2d test2(img.nr(), img.nc()); + array2d test1_i(img.nr(), img.nc()); + array2d test2_i(img.nr(), img.nc()); + + assign_all_pixels(test1, 0); + assign_all_pixels(test2, 0); + assign_all_pixels(test1_i, 0); + assign_all_pixels(test2_i, 0); + + for (int i = 0; i < 10; ++i) + { + const long width = rnd.get_random_32bit_number()%10 + 1; + const long height = rnd.get_random_32bit_number()%10 + 1; + const point p(rnd.get_random_32bit_number()%img.nc(), + rnd.get_random_32bit_number()%img.nr()); + + const rectangle rect = centered_rect(p, width, height); + sum_filter(img, test1, rect); + sum_filter(img, test2, rect); + sum_filter(img, test1_i, rect); + sum_filter(img, test2_i, rect); + + DLIB_TEST(mat(test1) == mat(test1_i)); + DLIB_TEST(mat(test2) == mat(test2_i)); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename image_type1, + typename image_type2 + > + void naive_max_filter ( + const image_type1& img, + image_type2& out, + const long width, + const long height, + typename image_type1::type thresh + ) + { + const rectangle area = get_rect(img); + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + const rectangle win = centered_rect(point(c,r),width,height).intersect(area); + out[r][c] += std::max(dlib::max(subm(mat(img),win)), thresh); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_max_filter(long rows, long cols, long width, long height, dlib::rand& rnd) + { + array2d img(rows, cols); + rectangle rect = centered_rect(0,0, width, height); + + array2d out(img.nr(),img.nc()); + assign_all_pixels(out, 0); + array2d out2(img.nr(),img.nc()); + assign_all_pixels(out2, 0); + + for (long r = 0; r < img.nr(); ++r) + { + for (long c = 0; c < img.nc(); ++c) + { + img[r][c] = rnd.get_random_32bit_number(); + } + } + + const int thresh = rnd.get_random_32bit_number(); + + naive_max_filter(img, out2, rect.width(), rect.height(), thresh); + max_filter(img, out, rect.width(), rect.height(), thresh); + + DLIB_TEST_MSG(mat(out) == mat(out2), + "rows: "<< rows + << "\ncols: "<< rows + << "\nwidth: "<< width + << "\nheight: "<< height ); + } + +// ---------------------------------------------------------------------------------------- + + void test_max_filter() + { + dlib::rand rnd; + for (int iter = 0; iter < 300; ++iter) + { + print_spinner(); + test_max_filter(0,0,1,1,rnd); + test_max_filter(0,0,3,1,rnd); + test_max_filter(0,0,3,3,rnd); + test_max_filter(0,0,1,3,rnd); + test_max_filter(1,1,1,1,rnd); + test_max_filter(2,2,1,1,rnd); + test_max_filter(3,3,1,1,rnd); + test_max_filter(3,3,3,3,rnd); + test_max_filter(3,3,2,2,rnd); + test_max_filter(3,3,3,5,rnd); + test_max_filter(3,3,6,8,rnd); + test_max_filter(20,20,901,901,rnd); + test_max_filter(5,5,1,5,rnd); + test_max_filter(50,50,9,9,rnd); + test_max_filter(50,50,9,9,rnd); + test_max_filter(50,50,10,10,rnd); + test_max_filter(50,50,11,10,rnd); + test_max_filter(50,50,10,11,rnd); + test_max_filter(50,50,10,21,rnd); + test_max_filter(50,50,20,10,rnd); + test_max_filter(50,50,20,10,rnd); + test_max_filter(50,50,9,9,rnd); + test_max_filter(20,20,1,901,rnd); + test_max_filter(20,20,3,901,rnd); + test_max_filter(20,20,901,1,rnd); + } + + for (int iter = 0; iter < 200; ++iter) + { + print_spinner(); + test_max_filter((int)rnd.get_random_8bit_number()%100+1, + (int)rnd.get_random_8bit_number()%100+1, + (int)rnd.get_random_8bit_number()%150+1, + (int)rnd.get_random_8bit_number()%150+1, + rnd); + } + } + +// ---------------------------------------------------------------------------------------- + + void make_images ( + dlib::rand& rnd, + array >& images, + long num, + long nr, + long nc + ) + { + images.resize(num); + for (unsigned long i = 0; i < images.size(); ++i) + { + images[i].set_size(nr,nc); + } + + for (unsigned long i = 0; i < images.size(); ++i) + { + for (long r = 0; r < nr; ++r) + { + for (long c = 0; c < nc; ++c) + { + images[i][r][c] = rnd.get_random_8bit_number(); + } + } + } + } + + + template < + typename image_array_type + > + void brute_force_scan_image_movable_parts ( + std::vector >& dets, + const image_array_type& images, + const rectangle& window, + const std::vector >& fixed_rects, + const std::vector >& movable_rects, + const double thresh, + const unsigned long + ) + { + dets.clear(); + if (movable_rects.size() == 0 && fixed_rects.size() == 0) + return; + + for (long r = 0; r < images[0].nr(); ++r) + { + for (long c = 0; c < images[0].nc(); ++c) + { + const point p(c,r); + double score = sum_of_rects_in_images_movable_parts(images, + window, + fixed_rects, + movable_rects, + p); + + if (score >= thresh) + { + dets.push_back(make_pair(score,p)); + } + } + } + } + + void test_scan_images_movable_parts() + { + array > images; + dlib::rand rnd; + for (int iter = 0; iter < 40; ++iter) + { + print_spinner(); + const int num_images = rnd.get_random_32bit_number()%4+1; + + make_images(rnd,images, num_images, + rnd.get_random_32bit_number()%50+1, + rnd.get_random_32bit_number()%50+1 + ); + + std::vector > dets1, dets2; + std::vector > fixed_rects, movable_rects; + + double total_area = 0; + for (unsigned long i = 0; i < images.size(); ++i) + { + fixed_rects.push_back(make_pair(i, centered_rect( + rnd.get_random_32bit_number()%10-5, + rnd.get_random_32bit_number()%10-5, + rnd.get_random_32bit_number()%10, + rnd.get_random_32bit_number()%10 + ))); + + total_area += fixed_rects.back().second.area(); + + movable_rects.push_back(make_pair(i, centered_rect( + 0, + 0, + rnd.get_random_32bit_number()%10+1, + rnd.get_random_32bit_number()%10+1 + ))); + total_area += movable_rects.back().second.area(); + } + + const rectangle window = centered_rect(0,0, + rnd.get_random_32bit_number()%15+1, + rnd.get_random_32bit_number()%15+1); + dlog << LINFO << "window size: "<< window.width() << ", " << window.height(); + const double thresh = total_area*130; + const unsigned long max_dets = get_rect(images[0]).area(); + + scan_image_movable_parts(dets1,images,window,fixed_rects,movable_rects,thresh, max_dets); + brute_force_scan_image_movable_parts(dets2,images,window,fixed_rects,movable_rects,thresh, max_dets); + + dlog << LINFO << "max_possible dets: " << max_dets; + dlog << LINFO << "regular dets: " << dets1.size(); + dlog << LINFO << "brute force: " << dets2.size(); + DLIB_TEST(dets1.size() == dets2.size()); + + array2d check(images[0].nr(), images[0].nc()); + assign_all_pixels(check, 1e-300); + for (unsigned long i = 0; i < dets1.size(); ++i) + { + const point p = dets1[i].second; + check[p.y()][p.x()] = dets1[i].first; + } + for (unsigned long i = 0; i < dets2.size(); ++i) + { + const point p = dets2[i].second; + DLIB_TEST(std::abs(check[p.y()][p.x()] - dets2[i].first) < 1e-10); + } + dlog << LINFO << "=======================\n"; + } + } + +// ---------------------------------------------------------------------------------------- + + class scan_image_tester : public tester + { + public: + scan_image_tester ( + ) : + tester ("test_scan_image", + "Runs tests on the scan_image routine.") + {} + + void perform_test ( + ) + { + test_scan_images_movable_parts(); + test_max_filter(); + + run_test1(); + run_test2(); + run_test3(1); + run_test3(-1); + run_test3(1); + run_test3(-1); + + test_sum_filter(); + test_sum_filter(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/sequence.cpp b/ml/dlib/dlib/test/sequence.cpp new file mode 100644 index 000000000..ffa6efdb8 --- /dev/null +++ b/ml/dlib/dlib/test/sequence.cpp @@ -0,0 +1,312 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.sequence"); + + template < + typename seq + > + void sequence_sort_test ( + ) + /*! + requires + - seq is an implementation of sequence/sequence_sort_aseqract.h is instantiated + with int + ensures + - runs tests on seq for compliance with the specs + !*/ + { + + + srand(static_cast(time(0))); + + + print_spinner(); + + + + + + { + // this test is to make sure that jumping around via + // operator[] doesn't corrupt the object + + seq a; + + for (int i = 0; i < 100; ++i) + { + int x = i; + a.add(a.size(),x); + } + + + int x = 0; + + for (int i = 0; i < (int)a.size(); ++i) + { + DLIB_TEST_MSG(a[i] >= i,"1"); + // cout << a[i] << endl; + } + + for (unsigned long i = 0; i < a.size(); ++i) + { + for (unsigned long j = i+1; j < a.size(); ++j) + { + if ((a[j]+a[i])%3 ==0) + { + a.remove(j,x); + --j; + } + } + } + + //cout << endl; + + for (int i = 0; i < (int)a.size(); ++i) + { + // cout << a[i] << endl; + DLIB_TEST_MSG(a[i] >= i,"2"); + } + + } + + + + + + + + seq test, test2; + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + + enumerable& e = test; + + DLIB_TEST(e.at_start() == true); + DLIB_TEST(e.current_element_valid() == false); + + + for (int g = 0; g < 5; ++g) + { + test.clear(); + test2.clear(); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(e.at_start() == true); + DLIB_TEST(e.current_element_valid() == false); + + DLIB_TEST(e.move_next() == false); + DLIB_TEST(e.current_element_valid() == false); + DLIB_TEST(e.at_start() == false); + DLIB_TEST(test.at_start() == false); + swap(test,test2); + DLIB_TEST(test.at_start() == true); + test.clear(); + test2.clear(); + + int a; + + + for (int i = 0; i < 100; ++i) + { + a = i; + test.add(i,a); + } + + DLIB_TEST(test.size() == 100); + + for (int i = 0; i < static_cast(test.size()); ++i) + { + DLIB_TEST(test[i] == i); + } + + swap(test,test2); + + a = 0; + DLIB_TEST(test2.at_start() == true); + while(test2.move_next()) + { + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == true); + DLIB_TEST(test2.element() == a); + ++a; + } + + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == false); + + test2.reset(); + + DLIB_TEST(test2.at_start() == true); + DLIB_TEST(test2.current_element_valid() == false); + + a = 0; + while(test2.move_next()) + { + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == true); + DLIB_TEST(test2.element() == a); + ++a; + } + + + + + + for (int i = 0; i < 1000; ++i) + { + a = ::rand(); + test.add(0,a); + } + DLIB_TEST(test.size() == 1000); + + test.sort(); + + + for (unsigned long i = 0; i < test.size()-1; ++i) + { + DLIB_TEST(test[i] <= test[i+1]); + } + + a = 0; + while(test.move_next()) + { + DLIB_TEST(a <= test.element()); + a = test.element(); + } + + + test.clear(); + test2.clear(); + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test2.size() == 0); + + for (int i = 0; i < 100; ++i) + { + a = i; + test.add(i,a); + } + + for (int i = 100; i < 200; ++i) + { + a = i; + test.add(i,a); + } + + test.cat(test2); + DLIB_TEST(test.size() == 200); + DLIB_TEST(test2.size() == 0); + + + // serialize the state of test, then clear test, then + // load the state back into test. + ostringstream sout; + serialize(test,sout); + DLIB_TEST(test.at_start() == true); + istringstream sin(sout.str()); + test.clear(); + deserialize(test,sin); + + + for (int i = 0; i < 200; ++i) + { + DLIB_TEST(test[i] == i); + } + + a = 0; + while (test.move_next()) + { + DLIB_TEST(test.element() == a); + DLIB_TEST(test[0]==0); + ++a; + } + + DLIB_TEST(a == 200); + + DLIB_TEST(test[9] == 9); + test.remove(9,a); + DLIB_TEST(a == 9); + DLIB_TEST(test[9] == 10); + DLIB_TEST(test.size() == 199); + + test.remove(0,a); + DLIB_TEST(test[0] == 1); + DLIB_TEST(test.size() == 198); + DLIB_TEST(a == 0); + DLIB_TEST(test[9] == 11); + DLIB_TEST(test[20] == 22); + + + + + } + + { + test.clear(); + for (int i = 0; i < 100; ++i) + { + int a = 3; + test.add(0,a); + } + DLIB_TEST(test.size() == 100); + remover& go = test; + for (int i = 0; i < 100; ++i) + { + int a = 9; + go.remove_any(a); + DLIB_TEST(a == 3); + } + DLIB_TEST(go.size() == 0); + } + + + } + + + + + class sequence_tester : public tester + { + public: + sequence_tester ( + ) : + tester ("test_sequence", + "Runs tests on the sequence component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing sort_1a"; + sequence_sort_test::sort_1a> (); + dlog << LINFO << "testing sort_1a_c"; + sequence_sort_test::sort_1a_c>(); + dlog << LINFO << "testing sort_2a"; + sequence_sort_test::sort_2a> (); + dlog << LINFO << "testing sort_2a_c"; + sequence_sort_test::sort_2a_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/sequence_labeler.cpp b/ml/dlib/dlib/test/sequence_labeler.cpp new file mode 100644 index 000000000..4002d6821 --- /dev/null +++ b/ml/dlib/dlib/test/sequence_labeler.cpp @@ -0,0 +1,461 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include "tester.h" +#include +#include + + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.sequence_labeler"); + +// ---------------------------------------------------------------------------------------- + + const unsigned long num_label_states = 3; // the "hidden" states + const unsigned long num_sample_states = 3; + +// ---------------------------------------------------------------------------------------- + + struct funny_sequence + { + std::vector item; + unsigned long size() const { return item.size(); } + }; + funny_sequence make_funny_sequence(const std::vector& item) + { + funny_sequence temp; + temp.item = item; + return temp; + } + +// ---------------------------------------------------------------------------------------- + + class feature_extractor + { + public: + typedef funny_sequence sequence_type; + + unsigned long num_features() const + { + return num_label_states*num_label_states + num_label_states*num_sample_states; + } + + unsigned long order() const + { + return 1; + } + + unsigned long num_labels() const + { + return num_label_states; + } + + template + void get_features ( + feature_setter& set_feature, + const sequence_type& x, + const matrix_exp& y, + unsigned long position + ) const + { + if (y.size() > 1) + set_feature(y(1)*num_label_states + y(0)); + + set_feature(num_label_states*num_label_states + + y(0)*num_sample_states + x.item[position]); + } + }; + + class feature_extractor_partial + { + public: + typedef funny_sequence sequence_type; + + unsigned long num_features() const + { + return num_label_states*num_label_states + num_label_states*num_sample_states; + } + + unsigned long order() const + { + return 1; + } + + unsigned long num_labels() const + { + return num_label_states; + } + + template + void get_features ( + feature_setter& set_feature, + const sequence_type& x, + const matrix_exp& y, + unsigned long position + ) const + { + if (y.size() > 1) + { + set_feature(y(1)*num_label_states + y(0), 0.5); + set_feature(y(1)*num_label_states + y(0), 0.5); + } + + set_feature(num_label_states*num_label_states + + y(0)*num_sample_states + x.item[position],0.25); + set_feature(num_label_states*num_label_states + + y(0)*num_sample_states + x.item[position],0.75); + } + }; + + bool called_rejct_labeling = false; + class feature_extractor2 + { + public: + typedef funny_sequence sequence_type; + + unsigned long num_features() const + { + return num_label_states*num_label_states + num_label_states*num_sample_states; + } + + unsigned long order() const + { + return 1; + } + + unsigned long num_labels() const + { + return num_label_states; + } + + template + bool reject_labeling ( + const sequence_type& , + const matrix_exp& , + unsigned long + ) const + { + called_rejct_labeling = true; + return false; + } + + template + void get_features ( + feature_setter& set_feature, + const sequence_type& x, + const matrix_exp& y, + unsigned long position + ) const + { + if (y.size() > 1) + set_feature(y(1)*num_label_states + y(0)); + + set_feature(num_label_states*num_label_states + + y(0)*num_sample_states + x.item[position]); + } + }; + + void serialize(const feature_extractor&, std::ostream&) {} + void deserialize(feature_extractor&, std::istream&) {} + void serialize(const feature_extractor2&, std::ostream&) {} + void deserialize(feature_extractor2&, std::istream&) {} + +// ---------------------------------------------------------------------------------------- + + void sample_hmm ( + dlib::rand& rnd, + const matrix& transition_probabilities, + const matrix& emission_probabilities, + unsigned long previous_label, + unsigned long& next_label, + unsigned long& next_sample + ) + /*! + requires + - previous_label < transition_probabilities.nr() + - transition_probabilities.nr() == transition_probabilities.nc() + - transition_probabilities.nr() == emission_probabilities.nr() + - The rows of transition_probabilities and emission_probabilities must sum to 1. + (i.e. sum_cols(transition_probabilities) and sum_cols(emission_probabilities) + must evaluate to vectors of all 1s.) + ensures + - This function randomly samples the HMM defined by transition_probabilities + and emission_probabilities assuming that the previous hidden state + was previous_label. + - The HMM is defined by: + - P(next_label |previous_label) == transition_probabilities(previous_label, next_label) + - P(next_sample|next_label) == emission_probabilities (next_label, next_sample) + - #next_label == the sampled value of the hidden state + - #next_sample == the sampled value of the observed state + !*/ + { + // sample next_label + double p = rnd.get_random_double(); + for (long c = 0; p >= 0 && c < transition_probabilities.nc(); ++c) + { + next_label = c; + p -= transition_probabilities(previous_label, c); + } + + // now sample next_sample + p = rnd.get_random_double(); + for (long c = 0; p >= 0 && c < emission_probabilities.nc(); ++c) + { + next_sample = c; + p -= emission_probabilities(next_label, c); + } + } + +// ---------------------------------------------------------------------------------------- + + void make_dataset ( + const matrix& transition_probabilities, + const matrix& emission_probabilities, + std::vector& samples, + std::vector >& labels, + unsigned long dataset_size + ) + /*! + requires + - transition_probabilities.nr() == transition_probabilities.nc() + - transition_probabilities.nr() == emission_probabilities.nr() + - The rows of transition_probabilities and emission_probabilities must sum to 1. + (i.e. sum_cols(transition_probabilities) and sum_cols(emission_probabilities) + must evaluate to vectors of all 1s.) + ensures + - This function randomly samples a bunch of sequences from the HMM defined by + transition_probabilities and emission_probabilities. + - The HMM is defined by: + - The probability of transitioning from hidden state H1 to H2 + is given by transition_probabilities(H1,H2). + - The probability of a hidden state H producing an observed state + O is given by emission_probabilities(H,O). + - #samples.size() == labels.size() == dataset_size + - for all valid i: + - #labels[i] is a randomly sampled sequence of hidden states from the + given HMM. #samples[i] is its corresponding randomly sampled sequence + of observed states. + !*/ + { + samples.clear(); + labels.clear(); + + dlib::rand rnd; + + // now randomly sample some labeled sequences from our Hidden Markov Model + for (unsigned long iter = 0; iter < dataset_size; ++iter) + { + const unsigned long sequence_size = rnd.get_random_32bit_number()%20+3; + std::vector sample(sequence_size); + std::vector label(sequence_size); + + unsigned long previous_label = rnd.get_random_32bit_number()%num_label_states; + for (unsigned long i = 0; i < sample.size(); ++i) + { + unsigned long next_label=0, next_sample=0; + sample_hmm(rnd, transition_probabilities, emission_probabilities, + previous_label, next_label, next_sample); + + label[i] = next_label; + sample[i] = next_sample; + + previous_label = next_label; + } + + samples.push_back(make_funny_sequence(sample)); + labels.push_back(label); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void do_test() + { + called_rejct_labeling = false; + + matrix transition_probabilities(num_label_states, num_label_states); + transition_probabilities = 0.05, 0.90, 0.05, + 0.05, 0.05, 0.90, + 0.90, 0.05, 0.05; + + matrix emission_probabilities(num_label_states,num_sample_states); + emission_probabilities = 0.5, 0.5, 0.0, + 0.0, 0.5, 0.5, + 0.5, 0.0, 0.5; + + print_spinner(); + + + std::vector samples; + std::vector > labels; + make_dataset(transition_probabilities,emission_probabilities, + samples, labels, 1000); + + dlog << LINFO << "samples.size(): "<< samples.size(); + + // print out some of the randomly sampled sequences + for (int i = 0; i < 10; ++i) + { + dlog << LINFO << "hidden states: " << trans(mat(labels[i])); + dlog << LINFO << "observed states: " << trans(mat(samples[i].item)); + dlog << LINFO << "******************************"; + } + + print_spinner(); + structural_sequence_labeling_trainer trainer; + trainer.set_c(4); + DLIB_TEST(trainer.get_c() == 4); + trainer.set_num_threads(4); + DLIB_TEST(trainer.get_num_threads() == 4); + + + + // Learn to do sequence labeling from the dataset + sequence_labeler labeler = trainer.train(samples, labels); + + std::vector predicted_labels = labeler(samples[0]); + dlog << LINFO << "true hidden states: "<< trans(mat(labels[0])); + dlog << LINFO << "predicted hidden states: "<< trans(mat(predicted_labels)); + + DLIB_TEST(mat(labels[0]) == mat(predicted_labels)); + + + print_spinner(); + + + // We can also do cross-validation + matrix confusion_matrix; + confusion_matrix = cross_validate_sequence_labeler(trainer, samples, labels, 4); + dlog << LINFO << "cross-validation: "; + dlog << LINFO << confusion_matrix; + double accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix); + dlog << LINFO << "label accuracy: "<< accuracy; + DLIB_TEST(std::abs(accuracy - 0.882) < 0.01); + + print_spinner(); + + + matrix true_hmm_model_weights = log(join_cols(reshape_to_column_vector(transition_probabilities), + reshape_to_column_vector(emission_probabilities))); + + sequence_labeler labeler_true(true_hmm_model_weights); + + confusion_matrix = test_sequence_labeler(labeler_true, samples, labels); + dlog << LINFO << "True HMM model: "; + dlog << LINFO << confusion_matrix; + accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix); + dlog << LINFO << "label accuracy: "<< accuracy; + DLIB_TEST(std::abs(accuracy - 0.882) < 0.01); + + + + print_spinner(); + + + + + // Finally, the labeler can be serialized to disk just like most dlib objects. + ostringstream sout; + serialize(labeler, sout); + + sequence_labeler labeler2; + // recall from disk + istringstream sin(sout.str()); + deserialize(labeler2, sin); + confusion_matrix = test_sequence_labeler(labeler2, samples, labels); + dlog << LINFO << "deserialized labeler: "; + dlog << LINFO << confusion_matrix; + accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix); + dlog << LINFO << "label accuracy: "<< accuracy; + DLIB_TEST(std::abs(accuracy - 0.882) < 0.01); + } + +// ---------------------------------------------------------------------------------------- + + void test2() + { + /* + The point of this test is to make sure calling set_feature() multiple + times works the way it is supposed to. + */ + + print_spinner(); + std::vector samples; + std::vector > labels; + + matrix transition_probabilities(num_label_states, num_label_states); + transition_probabilities = 0.05, 0.90, 0.05, + 0.05, 0.05, 0.90, + 0.90, 0.05, 0.05; + + matrix emission_probabilities(num_label_states,num_sample_states); + emission_probabilities = 0.5, 0.5, 0.0, + 0.0, 0.5, 0.5, + 0.5, 0.0, 0.5; + + + make_dataset(transition_probabilities,emission_probabilities, + samples, labels, 1000); + + dlog << LINFO << "samples.size(): "<< samples.size(); + + structural_sequence_labeling_trainer trainer; + structural_sequence_labeling_trainer trainer_part; + trainer.set_c(4); + trainer_part.set_c(4); + trainer.set_num_threads(4); + trainer_part.set_num_threads(4); + trainer.set_epsilon(1e-8); + trainer_part.set_epsilon(1e-8); + + + + // Learn to do sequence labeling from the dataset + sequence_labeler labeler = trainer.train(samples, labels); + sequence_labeler labeler_part = trainer_part.train(samples, labels); + + dlog << LINFO << "weight disagreement: "<< max(abs(labeler.get_weights() - labeler_part.get_weights())); + dlog << LINFO << "max weight magnitude: "<< max(abs(labeler.get_weights())); + + // Both feature extractors should be equivalent. + DLIB_TEST(max(abs(labeler.get_weights() - labeler_part.get_weights())) < 1e-6); + + } + +// ---------------------------------------------------------------------------------------- + + class sequence_labeler_tester : public tester + { + public: + sequence_labeler_tester ( + ) : + tester ("test_sequence_labeler", + "Runs tests on the sequence labeling code.") + {} + + void perform_test ( + ) + { + do_test(); + DLIB_TEST(called_rejct_labeling == false); + do_test(); + DLIB_TEST(called_rejct_labeling == true); + + test2(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/sequence_segmenter.cpp b/ml/dlib/dlib/test/sequence_segmenter.cpp new file mode 100644 index 000000000..acdcd69be --- /dev/null +++ b/ml/dlib/dlib/test/sequence_segmenter.cpp @@ -0,0 +1,294 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include "tester.h" +#include +#include + + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.sequence_segmenter"); + +// ---------------------------------------------------------------------------------------- + + dlib::rand rnd; + + template + class unigram_extractor + { + public: + + const static bool use_BIO_model = use_BIO_model_; + const static bool use_high_order_features = use_high_order_features_; + const static bool allow_negative_weights = allow_negative_weights_; + + typedef std::vector sequence_type; + + std::map > feats; + + unigram_extractor() + { + matrix v1, v2, v3; + v1 = randm(num_features(), 1, rnd); + v2 = randm(num_features(), 1, rnd); + v3 = randm(num_features(), 1, rnd); + v1(0) = 1; + v2(1) = 1; + v3(2) = 1; + v1(3) = -1; + v2(4) = -1; + v3(5) = -1; + for (unsigned long i = 0; i < num_features(); ++i) + { + if ( i < 3) + feats[i] = v1; + else if (i < 6) + feats[i] = v2; + else + feats[i] = v3; + } + } + + unsigned long num_features() const { return 10; } + unsigned long window_size() const { return 3; } + + template + void get_features ( + feature_setter& set_feature, + const sequence_type& x, + unsigned long position + ) const + { + const matrix& m = feats.find(x[position])->second; + for (unsigned long i = 0; i < num_features(); ++i) + { + set_feature(i, m(i)); + } + } + + }; + + template + void serialize(const unigram_extractor& item , std::ostream& out ) + { + serialize(item.feats, out); + } + + template + void deserialize(unigram_extractor& item, std::istream& in) + { + deserialize(item.feats, in); + } + +// ---------------------------------------------------------------------------------------- + + void make_dataset ( + std::vector >& samples, + std::vector >& labels, + unsigned long dataset_size + ) + { + samples.clear(); + labels.clear(); + + samples.resize(dataset_size); + labels.resize(dataset_size); + + + unigram_extractor fe; + dlib::rand rnd; + + for (unsigned long iter = 0; iter < dataset_size; ++iter) + { + + samples[iter].resize(10); + labels[iter].resize(10); + + for (unsigned long i = 0; i < samples[iter].size(); ++i) + { + samples[iter][i] = rnd.get_random_32bit_number()%fe.num_features(); + if (samples[iter][i] < 3) + { + labels[iter][i] = impl_ss::BEGIN; + } + else if (samples[iter][i] < 6) + { + labels[iter][i] = impl_ss::INSIDE; + } + else + { + labels[iter][i] = impl_ss::OUTSIDE; + } + + if (i != 0) + { + // do rejection sampling to avoid impossible labels + if (labels[iter][i] == impl_ss::INSIDE && + labels[iter][i-1] == impl_ss::OUTSIDE) + { + --i; + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + void make_dataset2 ( + std::vector >& samples, + std::vector > >& segments, + unsigned long dataset_size + ) + { + segments.clear(); + std::vector > labels; + make_dataset(samples, labels, dataset_size); + segments.resize(samples.size()); + + // Convert from BIO tagging to the explicit segments representation. + for (unsigned long k = 0; k < labels.size(); ++k) + { + for (unsigned long i = 0; i < labels[k].size(); ++i) + { + if (labels[k][i] == impl_ss::BEGIN) + { + const unsigned long begin = i; + ++i; + while (i < labels[k].size() && labels[k][i] == impl_ss::INSIDE) + ++i; + + segments[k].push_back(std::make_pair(begin, i)); + --i; + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void do_test() + { + dlog << LINFO << "use_BIO_model: "<< use_BIO_model; + dlog << LINFO << "use_high_order_features: "<< use_high_order_features; + dlog << LINFO << "allow_negative_weights: "<< allow_negative_weights; + + std::vector > samples; + std::vector > > segments; + make_dataset2( samples, segments, 100); + + print_spinner(); + typedef unigram_extractor fe_type; + + fe_type fe_temp; + fe_type fe_temp2; + structural_sequence_segmentation_trainer trainer(fe_temp2); + trainer.set_c(5); + trainer.set_num_threads(1); + + + sequence_segmenter labeler = trainer.train(samples, segments); + + print_spinner(); + + const std::vector > predicted_labels = labeler(samples[1]); + const std::vector > true_labels = segments[1]; + /* + for (unsigned long i = 0; i < predicted_labels.size(); ++i) + cout << "["< 0); + DLIB_TEST(predicted_labels.size() == true_labels.size()); + for (unsigned long i = 0; i < predicted_labels.size(); ++i) + { + DLIB_TEST(predicted_labels[i].first == true_labels[i].first); + DLIB_TEST(predicted_labels[i].second == true_labels[i].second); + } + + + matrix res; + + res = cross_validate_sequence_segmenter(trainer, samples, segments, 3); + dlog << LINFO << "cv res: "<< res; + DLIB_TEST(min(res) > 0.98); + make_dataset2( samples, segments, 100); + res = test_sequence_segmenter(labeler, samples, segments); + dlog << LINFO << "test res: "<< res; + DLIB_TEST(min(res) > 0.98); + + print_spinner(); + + ostringstream sout; + serialize(labeler, sout); + istringstream sin(sout.str()); + sequence_segmenter labeler2; + deserialize(labeler2, sin); + + res = test_sequence_segmenter(labeler2, samples, segments); + dlog << LINFO << "test res2: "<< res; + DLIB_TEST(min(res) > 0.98); + + long N; + if (use_BIO_model) + N = 3*3+3; + else + N = 5*5+5; + const double min_normal_weight = min(colm(labeler2.get_weights(), 0, labeler2.get_weights().size()-N)); + const double min_trans_weight = min(labeler2.get_weights()); + dlog << LINFO << "min_normal_weight: " << min_normal_weight; + dlog << LINFO << "min_trans_weight: " << min_trans_weight; + if (allow_negative_weights) + { + DLIB_TEST(min_normal_weight < 0); + DLIB_TEST(min_trans_weight < 0); + } + else + { + DLIB_TEST(min_normal_weight == 0); + DLIB_TEST(min_trans_weight < 0); + } + } + +// ---------------------------------------------------------------------------------------- + + + class unit_test_sequence_segmenter : public tester + { + public: + unit_test_sequence_segmenter ( + ) : + tester ("test_sequence_segmenter", + "Runs tests on the sequence segmenting code.") + {} + + void perform_test ( + ) + { + do_test(); + do_test(); + do_test(); + do_test(); + do_test(); + do_test(); + do_test(); + do_test(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/serialize.cpp b/ml/dlib/dlib/test/serialize.cpp new file mode 100644 index 000000000..f8b3384b9 --- /dev/null +++ b/ml/dlib/dlib/test/serialize.cpp @@ -0,0 +1,1087 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace dlib +{ + static bool operator!=(const rgb_pixel& a, const rgb_pixel& b) + { + return !(a.red==b.red && a.green==b.green && a.blue==b.blue); + } + static bool operator!=(const bgr_pixel& a, const bgr_pixel& b) + { + return !(a.red==b.red && a.green==b.green && a.blue==b.blue); + } + + static bool operator!=(const hsi_pixel& a, const hsi_pixel& b) + { + return !(a.h==b.h && a.s==b.s && a.i==b.i); + } + static bool operator!=(const rgb_alpha_pixel& a, const rgb_alpha_pixel& b) + { + return !(a.red==b.red && a.green==b.green && a.blue==b.blue && a.alpha==b.alpha); + } + +} + +namespace +{ + +// ---------------------------------------------------------------------------------------- + + using namespace test; + using namespace dlib; + using namespace std; + + struct test_object + { + signed char i1; + signed short i2; + signed long i3; + unsigned char i4; + unsigned short i5; + unsigned long i6; + uint64 i7; + int64 i8; + + signed char i1_0; + signed short i2_0; + signed long i3_0; + unsigned char i4_0; + unsigned short i5_0; + unsigned long i6_0; + uint64 i7_0; + int64 i8_0; + + signed char i1_n; + signed short i2_n; + signed long i3_n; + + + float f1; + double f2; + long double f3; + float f1_inf; + double f2_inf; + long double f3_inf; + float f1_ninf; + double f2_ninf; + long double f3_ninf; + float f1_qnan; + double f2_qnan; + long double f3_qnan; + float f1_snan; + double f2_snan; + long double f3_snan; + + std::string s1; + std::wstring s2; + + int array[10]; + + bool b_true; + bool b_false; + + + void set_state_1( + ) + { + i1 = 1; + i2 = 2; + i3 = 3; + i4 = 4; + i5 = 5; + i6 = 6; + i7 = 7; + i8 = 8; + + i1_0 = 0; + i2_0 = 0; + i3_0 = 0; + i4_0 = 0; + i5_0 = 0; + i6_0 = 0; + i7_0 = 0; + i8_0 = 0; + + i1_n = -1; + i2_n = -2; + i3_n = -3; + + f1 = 123.456f; + f2 = 543.341; + f3 = 5234234.23; + + f1_inf = numeric_limits::infinity(); + f2_inf = numeric_limits::infinity(); + f3_inf = numeric_limits::infinity(); + f1_ninf = -numeric_limits::infinity(); + f2_ninf = -numeric_limits::infinity(); + f3_ninf = -numeric_limits::infinity(); + f1_qnan = numeric_limits::quiet_NaN(); + f2_qnan = numeric_limits::quiet_NaN(); + f3_qnan = numeric_limits::quiet_NaN(); + f1_snan = numeric_limits::signaling_NaN(); + f2_snan = numeric_limits::signaling_NaN(); + f3_snan = numeric_limits::signaling_NaN(); + + s1 = "davis"; + s2 = L"yo yo yo"; + + for (int i = 0; i < 10; ++i) + array[i] = i; + + b_true = true; + b_false = false; + } + + void set_state_2( + ) + { + i1 = 10; + i2 = 20; + i3 = 30; + i4 = 40; + i5 = 50; + i6 = 60; + i7 = 70; + i8 = 80; + + i1_0 = 5; + i2_0 = 6; + i3_0 = 7; + i4_0 = 8; + i5_0 = 9; + i6_0 = 10; + i7_0 = 11; + i8_0 = 12; + + i1_n = -13; + i2_n = -25; + i3_n = -12; + + f1 = 45.3f; + f2 = 0.001; + f3 = 2.332; + + f1_inf = f1; + f2_inf = f2; + f3_inf = f3; + f1_ninf = f1; + f2_ninf = f2; + f3_ninf = f3; + f1_qnan = f1; + f2_qnan = f2; + f3_qnan = f3; + f1_snan = f1; + f2_snan = f2; + f3_snan = f3; + + s1 = ""; + s2 = L""; + + for (int i = 0; i < 10; ++i) + array[i] = 10-i; + + b_true = false; + b_false = true; + } + + void assert_in_state_1 ( + ) + { + DLIB_TEST (i1 == 1); + DLIB_TEST (i2 == 2); + DLIB_TEST (i3 == 3); + DLIB_TEST (i4 == 4); + DLIB_TEST (i5 == 5); + DLIB_TEST (i6 == 6); + DLIB_TEST (i7 == 7); + DLIB_TEST (i8 == 8); + + DLIB_TEST (i1_0 == 0); + DLIB_TEST (i2_0 == 0); + DLIB_TEST (i3_0 == 0); + DLIB_TEST (i4_0 == 0); + DLIB_TEST (i5_0 == 0); + DLIB_TEST (i6_0 == 0); + DLIB_TEST (i7_0 == 0); + DLIB_TEST (i8_0 == 0); + + DLIB_TEST (i1_n == -1); + DLIB_TEST (i2_n == -2); + DLIB_TEST (i3_n == -3); + + DLIB_TEST (abs(f1 -123.456) < 1e-5); + DLIB_TEST (abs(f2 - 543.341) < 1e-10); + DLIB_TEST (abs(f3 - 5234234.23) < 1e-10); + + DLIB_TEST (f1_inf == numeric_limits::infinity()); + DLIB_TEST (f2_inf == numeric_limits::infinity()); + DLIB_TEST (f3_inf == numeric_limits::infinity()); + DLIB_TEST (f1_ninf == -numeric_limits::infinity()); + DLIB_TEST (f2_ninf == -numeric_limits::infinity()); + DLIB_TEST (f3_ninf == -numeric_limits::infinity()); + DLIB_TEST (!(f1_qnan <= numeric_limits::infinity() && f1_qnan >= -numeric_limits::infinity() )); + DLIB_TEST (!(f2_qnan <= numeric_limits::infinity() && f1_qnan >= -numeric_limits::infinity() )); + DLIB_TEST (!(f3_qnan <= numeric_limits::infinity() && f1_qnan >= -numeric_limits::infinity() )); + DLIB_TEST (!(f1_snan <= numeric_limits::infinity() && f1_qnan >= -numeric_limits::infinity() )); + DLIB_TEST (!(f2_snan <= numeric_limits::infinity() && f1_qnan >= -numeric_limits::infinity() )); + DLIB_TEST (!(f3_snan <= numeric_limits::infinity() && f1_qnan >= -numeric_limits::infinity() )); + + DLIB_TEST (s1 == "davis"); + DLIB_TEST (s2 == L"yo yo yo"); + + for (int i = 0; i < 10; ++i) + { + DLIB_TEST (array[i] == i); + } + + DLIB_TEST (b_true == true); + DLIB_TEST (b_false == false); + + } + + void assert_in_state_2 ( + ) + { + DLIB_TEST (i1 == 10); + DLIB_TEST (i2 == 20); + DLIB_TEST (i3 == 30); + DLIB_TEST (i4 == 40); + DLIB_TEST (i5 == 50); + DLIB_TEST (i6 == 60); + DLIB_TEST (i7 == 70); + DLIB_TEST (i8 == 80); + + DLIB_TEST (i1_0 == 5); + DLIB_TEST (i2_0 == 6); + DLIB_TEST (i3_0 == 7); + DLIB_TEST (i4_0 == 8); + DLIB_TEST (i5_0 == 9); + DLIB_TEST (i6_0 == 10); + DLIB_TEST (i7_0 == 11); + DLIB_TEST (i8_0 == 12); + + DLIB_TEST (i1_n == -13); + DLIB_TEST (i2_n == -25); + DLIB_TEST (i3_n == -12); + + DLIB_TEST (abs(f1 - 45.3) < 1e-5); + DLIB_TEST (abs(f2 - 0.001) < 1e-10); + DLIB_TEST (abs(f3 - 2.332) < 1e-10); + DLIB_TEST (abs(f1_inf - 45.3) < 1e-5); + DLIB_TEST (abs(f2_inf - 0.001) < 1e-10); + DLIB_TEST (abs(f3_inf - 2.332) < 1e-10); + DLIB_TEST (abs(f1_ninf - 45.3) < 1e-5); + DLIB_TEST (abs(f2_ninf - 0.001) < 1e-10); + DLIB_TEST (abs(f3_ninf - 2.332) < 1e-10); + DLIB_TEST (abs(f1_qnan - 45.3) < 1e-5); + DLIB_TEST (abs(f2_qnan - 0.001) < 1e-10); + DLIB_TEST (abs(f3_qnan - 2.332) < 1e-10); + DLIB_TEST (abs(f1_snan - 45.3) < 1e-5); + DLIB_TEST (abs(f2_snan - 0.001) < 1e-10); + DLIB_TEST (abs(f3_snan - 2.332) < 1e-10); + + DLIB_TEST (s1 == ""); + DLIB_TEST (s2 == L""); + + for (int i = 0; i < 10; ++i) + { + DLIB_TEST (array[i] == 10-i); + } + + DLIB_TEST (b_true == false); + DLIB_TEST (b_false == true); + + } + + }; + +// ---------------------------------------------------------------------------------------- + + void serialize ( + const test_object& item, + std::ostream& out + ) + { + dlib::serialize(item.i1,out); + dlib::serialize(item.i2,out); + dlib::serialize(item.i3,out); + dlib::serialize(item.i4,out); + dlib::serialize(item.i5,out); + dlib::serialize(item.i6,out); + dlib::serialize(item.i7,out); + dlib::serialize(item.i8,out); + + dlib::serialize(item.i1_0,out); + dlib::serialize(item.i2_0,out); + dlib::serialize(item.i3_0,out); + dlib::serialize(item.i4_0,out); + dlib::serialize(item.i5_0,out); + dlib::serialize(item.i6_0,out); + dlib::serialize(item.i7_0,out); + dlib::serialize(item.i8_0,out); + + dlib::serialize(item.i1_n,out); + dlib::serialize(item.i2_n,out); + dlib::serialize(item.i3_n,out); + + + dlib::serialize(item.f1,out); + dlib::serialize(item.f2,out); + dlib::serialize(item.f3,out); + + dlib::serialize(item.f1_inf,out); + dlib::serialize(item.f2_inf,out); + dlib::serialize(item.f3_inf,out); + dlib::serialize(item.f1_ninf,out); + dlib::serialize(item.f2_ninf,out); + dlib::serialize(item.f3_ninf,out); + dlib::serialize(item.f1_qnan,out); + dlib::serialize(item.f2_qnan,out); + dlib::serialize(item.f3_qnan,out); + dlib::serialize(item.f1_snan,out); + dlib::serialize(item.f2_snan,out); + dlib::serialize(item.f3_snan,out); + + dlib::serialize(item.s1,out); + dlib::serialize(item.s2,out); + + dlib::serialize(item.array,out); + + dlib::serialize(item.b_true,out); + dlib::serialize(item.b_false,out); + } + +// ---------------------------------------------------------------------------------------- + + void deserialize ( + test_object& item, + std::istream& in + ) + { + dlib::deserialize(item.i1,in); + dlib::deserialize(item.i2,in); + dlib::deserialize(item.i3,in); + dlib::deserialize(item.i4,in); + dlib::deserialize(item.i5,in); + dlib::deserialize(item.i6,in); + dlib::deserialize(item.i7,in); + dlib::deserialize(item.i8,in); + + dlib::deserialize(item.i1_0,in); + dlib::deserialize(item.i2_0,in); + dlib::deserialize(item.i3_0,in); + dlib::deserialize(item.i4_0,in); + dlib::deserialize(item.i5_0,in); + dlib::deserialize(item.i6_0,in); + dlib::deserialize(item.i7_0,in); + dlib::deserialize(item.i8_0,in); + + dlib::deserialize(item.i1_n,in); + dlib::deserialize(item.i2_n,in); + dlib::deserialize(item.i3_n,in); + + + dlib::deserialize(item.f1,in); + dlib::deserialize(item.f2,in); + dlib::deserialize(item.f3,in); + + dlib::deserialize(item.f1_inf,in); + dlib::deserialize(item.f2_inf,in); + dlib::deserialize(item.f3_inf,in); + dlib::deserialize(item.f1_ninf,in); + dlib::deserialize(item.f2_ninf,in); + dlib::deserialize(item.f3_ninf,in); + dlib::deserialize(item.f1_qnan,in); + dlib::deserialize(item.f2_qnan,in); + dlib::deserialize(item.f3_qnan,in); + dlib::deserialize(item.f1_snan,in); + dlib::deserialize(item.f2_snan,in); + dlib::deserialize(item.f3_snan,in); + + dlib::deserialize(item.s1,in); + dlib::deserialize(item.s2,in); + + dlib::deserialize(item.array,in); + + dlib::deserialize(item.b_true,in); + dlib::deserialize(item.b_false,in); + } + +// ---------------------------------------------------------------------------------------- + + // This function returns the contents of the file 'stuff.bin' but using the old + // floating point serialization format. + const std::string get_decoded_string() + { + dlib::base64::kernel_1a base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + + // The base64 encoded data from the file 'stuff.bin' we want to decode and return. + sout << "AVaifX9zEbXa9aocsrcRuvnNrR3WLuuU5eLWiy0UeXmnKXGLKZz8V44gzT4CM6wnCmAHFQug8G3C"; + sout << "4cuLdNgp2ApkeLcvwFNJRENE0ShrRaxEBFEA8nah7vm8B2VmgImNblCejuP5IcDt60EaCKlqiit8"; + sout << "+JGrzYxqBm3xFS4P+qlOROdbxc7pXBmUdh0rqNSEvn0FBPdoqY/5SpHgA2yAcH8XFrM1cdu0xS3P"; + sout << "8PBcmLMJ7bFdzplwhrjuxtm4NfEOi6Rl9sU44AXycYgJd0+uH+dyoI9X3co5b3YWJtjvdVeztNAr"; + sout << "BfSPfR6oAVNfiMBG7QA="; + + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + + // This function returns the contents of the file 'stuff.bin' but using the new + // floating point serialization format. + const std::string get_decoded_string2() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'stuff.bin' we want to decode and return. + sout << "AVaifX9zEbXa9aocsrcRuvnNqzZLptZ5mRd46xScCIfX6sq/46hG9JwIInElG50EtJKJY/+jAWit"; + sout << "TpDBWrxBz124JRLsBz62h0D3Tqgnd8zygRx7t33Ybw40o07MrhzNEHgYavUukaPje5by78JIWHgk"; + sout << "l7nb/TK+9ndVLrAThJ4v+GiPT3kh9H1tAAAAAQhbLa06pQjhrnjTXcRox1ZBEAV9/q1zAA=="; + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + +// ---------------------------------------------------------------------------------------- + + // Declare the logger we will use in this test. The name of the tester + // should start with "test." + logger dlog("test.serialize"); + + void serialize_test ( + ) + /*! + ensures + - runs tests on the serialization code for compliance with the specs + !*/ + { + + + print_spinner(); + + ostringstream sout; + test_object obj; + + obj.set_state_1(); + obj.assert_in_state_1(); + serialize(obj, sout); + obj.assert_in_state_1(); + + obj.set_state_2(); + obj.assert_in_state_2(); + serialize(obj, sout); + obj.assert_in_state_2(); + + + istringstream sin(sout.str()); + + deserialize(obj,sin); + obj.assert_in_state_1(); + deserialize(obj,sin); + obj.assert_in_state_2(); + + + // now do the same thing as above but deserialize from some stored binary + // data to make sure the serialized values are portable between different + // machines + + sin.clear(); + sin.str(get_decoded_string()); + deserialize(obj,sin); + obj.assert_in_state_1(); + deserialize(obj,sin); + obj.assert_in_state_2(); + + + sin.clear(); + sin.str(get_decoded_string2()); + deserialize(obj,sin); + obj.assert_in_state_1(); + deserialize(obj,sin); + obj.assert_in_state_2(); + + + /* + // This is the code that produced the encoded data stored in the get_decoded_string() function + ofstream fout("stuff.bin",ios::binary); + obj.set_state_1(); + obj.assert_in_state_1(); + serialize(obj, fout); + obj.assert_in_state_1(); + + obj.set_state_2(); + obj.assert_in_state_2(); + serialize(obj, fout); + obj.assert_in_state_2(); + */ + + + test_object obj2; + obj.set_state_1(); + obj2.set_state_2(); + dlib::serialize("serialization_test.dat") << obj << obj2; + obj.assert_in_state_1(); + obj2.assert_in_state_2(); + obj.set_state_2(); + obj2.set_state_1(); + obj.assert_in_state_2(); + obj2.assert_in_state_1(); + dlib::deserialize("serialization_test.dat") >> obj >> obj2; + obj.assert_in_state_1(); + obj2.assert_in_state_2(); + } + + + template + void test_vector ( + ) + { + std::vector a, b; + + for (int i = -10; i < 30; ++i) + { + a.push_back(i); + } + + ostringstream sout; + dlib::serialize(a, sout); + istringstream sin(sout.str()); + + dlib::deserialize(b, sin); + + + DLIB_TEST(a.size() == b.size()); + DLIB_TEST(a.size() == 40); + for (unsigned long i = 0; i < a.size(); ++i) + { + DLIB_TEST(a[i] == b[i]); + } + + std::vector c; + sout.str(""); + dlib::serialize(c, sout); + sin.str(sout.str()); + dlib::deserialize(a, sin); + DLIB_TEST(a.size() == 0); + DLIB_TEST(c.size() == 0); + } + + void test_std_array ( + ) + { + std::array a, b; + + a = {1, 2, 3, 4, 5}; + + ostringstream sout; + dlib::serialize(a, sout); + istringstream sin(sout.str()); + + dlib::deserialize(b, sin); + + + DLIB_TEST(a.size() == b.size()); + DLIB_TEST(a.size() == 5); + for (unsigned long i = 0; i < a.size(); ++i) + { + DLIB_TEST(a[i] == b[i]); + } + + std::array aa, bb; + sout.str(""); + dlib::serialize(aa, sout); + sin.str(sout.str()); + dlib::deserialize(bb, sin); + DLIB_TEST(bb.size() == 0); + } + + void test_vector_bool ( + ) + { + std::vector a, b; + + a.push_back(true); + a.push_back(true); + a.push_back(false); + a.push_back(true); + a.push_back(false); + a.push_back(true); + + ostringstream sout; + dlib::serialize(a, sout); + istringstream sin(sout.str()); + + dlib::deserialize(b, sin); + + + DLIB_TEST(a.size() == b.size()); + DLIB_TEST(a.size() == 6); + for (unsigned long i = 0; i < a.size(); ++i) + { + DLIB_TEST(a[i] == b[i]); + } + } + +// ---------------------------------------------------------------------------------------- + + // This function returns the contents of the file 'matarray.dat' + const std::string get_decoded_string_matarray_old() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'matarray.dat' we want to decode and return. + sout << "AW852sEbTIeV+m/wLUcKJKPW+6IclviUWZcFh1daDZ0blDjPNTgPx0Lv56sIEwlG4I6C5OJzJBkZ"; + sout << "PvczLjS7IEKh6eg7amNOyEexsQSgojL1oMe2gDEfkyInUGPJV90sNS0cvp/hIB134V8JCTYUP6vH"; + sout << "9qpegLSIIQG+/NjLWyK2472vC88BJfKgkL3CPLMjQwB3tB928FNLbESDLIvpnb6q9ve68iuoyZZt"; + sout << "z3TTJxHW3MIdgzuhNomvPxfo/Q+7lC/Orj0FewUX90al6DckwzOtLVRidh/ZKpsQsxzJYQGkjdX5"; + sout << "mDzzXKqQb3Y3DnzEmwtRD9CUON3iRv1r26gHWLYorrYA"; + + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + // This function returns the contents of the file 'matarray.dat' + const std::string get_decoded_string_matarray() + { + dlib::base64 base64_coder; + dlib::compress_stream::kernel_1ea compressor; + std::ostringstream sout; + std::istringstream sin; + + // The base64 encoded data from the file 'matarray.dat' we want to decode and return. + sout << "gO6XH2WGbm8Xaw3a5FJbh3V823W6P2Qk/vHaAAAAARccIppHWdmViaKby7JA5PQvXjYMWUYvXRHv"; + sout << "xPdURZl1un3CT/rjT11Yry0y3+1W7GBmfBJ0gVFKGdiGuqoNAMtmzL/ll3YfEQ7ED7aB33aDTktw"; + sout << "AWVkHT+gqTbKwjP+8YvB3s3ziK640ITOAWazAghKDVl7AHGn+fjq29paBZMczuJofl8FinZUhwa9"; + sout << "Ol5gdAEQa6VZDmJUeo2soTJcEDpkW9LkRmXvjQkyEHfEHQNFDfQq4p2U+dHz4lOKlcj3VzQIeG/s"; + sout << "oxa9KhJND4aQ5xeNUUHUzFBU3XhQHlyDIn/RNdX/ZwA="; + + + // Put the data into the istream sin + sin.str(sout.str()); + sout.str(""); + + // Decode the base64 text into its compressed binary form + base64_coder.decode(sin,sout); + sin.clear(); + sin.str(sout.str()); + sout.str(""); + + // Decompress the data into its original form + compressor.decompress(sin,sout); + + // Return the decoded and decompressed data + return sout.str(); + } + + void setup_mats_and_arrays ( + array2d& a, + matrix& m, + array2d& img1, + array2d& img2, + array2d& img3, + array2d& img4, + array2d& img5 + ) + { + a.set_size(3,5); + int cnt = 0; + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a[r][c] = cnt++; + } + } + m = mat(a); + + img1.set_size(3,5); + img2.set_size(3,5); + img3.set_size(3,5); + img4.set_size(3,5); + img5.set_size(3,5); + + assign_all_pixels(img1, 0); + assign_all_pixels(img2, 0); + assign_all_pixels(img3, 0); + assign_all_pixels(img4, 0); + assign_all_pixels(img5, 0); + + unsigned char pcnt = 0; + for (long r = 0; r < img1.nr(); ++r) + { + for (long c = 0; c < img1.nc(); ++c) + { + rgb_alpha_pixel temp; + temp.red = pcnt++; + temp.green = pcnt++; + temp.blue = pcnt++; + temp.alpha = 150+pcnt++; + assign_pixel(img1[r][c], temp); + assign_pixel(img2[r][c], temp); + assign_pixel(img3[r][c], temp); + assign_pixel(img4[r][c], temp); + } + } + + for (long r = 0; r < img5.nr(); ++r) + { + for (long c = 0; c < img5.nc(); ++c) + { + img5[r][c].h = pcnt++; + img5[r][c].s = pcnt++; + img5[r][c].i = pcnt++; + } + } + } + + + void test_deserialize( + std::istream& fin + ) + { + array2d a; + matrix m; + array2d img1; + array2d img2; + array2d img3; + array2d img4; + array2d img5; + setup_mats_and_arrays(a,m,img1,img2,img3,img4,img5); + + + array2d img1_; + array2d img2_; + array2d img3_; + array2d img4_; + array2d img5_; + + matrix m_; + array2d a_; + + deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a)); + deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m)); + deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a)); + deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m)); + + deserialize(img1_, fin); DLIB_TEST(mat(img1_) == mat(img1)); + deserialize(img2_, fin); DLIB_TEST(mat(img2_) == mat(img2)); + deserialize(img3_, fin); DLIB_TEST(mat(img3_) == mat(img3)); + deserialize(img4_, fin); DLIB_TEST(mat(img4_) == mat(img4)); + deserialize(img5_, fin); DLIB_TEST(mat(img5_) == mat(img5)); + } + + void test_deserialize_all_array2d( + std::istream& fin + ) + { + array2d a; + matrix m; + array2d img1; + array2d img2; + array2d img3; + array2d img4; + array2d img5; + setup_mats_and_arrays(a,m,img1,img2,img3,img4,img5); + + + array2d img1_; + array2d img2_; + array2d img3_; + array2d img4_; + array2d img5_; + + array2d m_; + array2d a_; + + deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a)); + deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m)); + deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a)); + deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m)); + + deserialize(img1_, fin); DLIB_TEST(mat(img1_) == mat(img1)); + deserialize(img2_, fin); DLIB_TEST(mat(img2_) == mat(img2)); + deserialize(img3_, fin); DLIB_TEST(mat(img3_) == mat(img3)); + deserialize(img4_, fin); DLIB_TEST(mat(img4_) == mat(img4)); + deserialize(img5_, fin); DLIB_TEST(mat(img5_) == mat(img5)); + } + + void test_deserialize_all_matrix( + std::istream& fin + ) + { + array2d a; + matrix m; + array2d img1; + array2d img2; + array2d img3; + array2d img4; + array2d img5; + setup_mats_and_arrays(a,m,img1,img2,img3,img4,img5); + + + matrix img1_; + matrix img2_; + matrix img3_; + matrix img4_; + matrix img5_; + + matrix m_; + matrix a_; + + deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a)); + deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m)); + deserialize(a_, fin); DLIB_TEST(mat(a_) == mat(a)); + deserialize(m_, fin); DLIB_TEST(mat(m_) == mat(m)); + + deserialize(img1_, fin); DLIB_TEST(mat(img1_) == mat(img1)); + deserialize(img2_, fin); DLIB_TEST(mat(img2_) == mat(img2)); + deserialize(img3_, fin); DLIB_TEST(mat(img3_) == mat(img3)); + deserialize(img4_, fin); DLIB_TEST(mat(img4_) == mat(img4)); + deserialize(img5_, fin); DLIB_TEST(mat(img5_) == mat(img5)); + } + + void test_array2d_and_matrix_serialization() + { + ostringstream sout; + array2d a; + matrix m; + array2d img1; + array2d img2; + array2d img3; + array2d img4; + array2d img5; + setup_mats_and_arrays(a,m,img1,img2,img3,img4,img5); + + serialize(a, sout); + serialize(m, sout); + serialize(a, sout); + serialize(m, sout); + + serialize(img1, sout); + serialize(img2, sout); + serialize(img3, sout); + serialize(img4, sout); + serialize(img5, sout); + + // -------------------- + + { + istringstream sin(sout.str()); + test_deserialize(sin); + } + { + istringstream sin(sout.str()); + test_deserialize_all_array2d(sin); + } + { + istringstream sin(sout.str()); + test_deserialize_all_matrix(sin); + } + + + { + istringstream sin(get_decoded_string_matarray()); + test_deserialize(sin); + } + { + istringstream sin(get_decoded_string_matarray()); + test_deserialize_all_array2d(sin); + } + { + istringstream sin(get_decoded_string_matarray()); + test_deserialize_all_matrix(sin); + } + + + { + // Make sure we can still deserialize the serialization + // format for array2d and matrix objects used by older versions + // of dlib. + istringstream sin(get_decoded_string_matarray_old()); + test_deserialize(sin); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_strings() + { + string str1 = "stuff"; + char buf[6]; + buf[0] = 0; + buf[1] = 1; + buf[2] = 2; + buf[3] = 0; + buf[4] = 3; + buf[5] = 3; + + dlib::serialize("ser_test_string.dat") << str1 << buf << "morestuff" << ""; + + string str2, str3, str4; + char buf2[6]; + memset(buf2,0,sizeof(buf2)); + dlib::deserialize("ser_test_string.dat") >> str2 >> buf2 >> str3 >> str4; + DLIB_TEST(str2 == "stuff"); + DLIB_TEST(str3 == "morestuff"); + DLIB_TEST(str4 == ""); + DLIB_TEST(buf2[0] == 0); + DLIB_TEST(buf2[1] == 1); + DLIB_TEST(buf2[2] == 2); + DLIB_TEST(buf2[3] == 0); + DLIB_TEST(buf2[4] == 3); + DLIB_TEST(buf2[5] == 3); + + + ofstream fout("ser_test_string.dat", ios::binary); + dlib::serialize(str1, fout); + dlib::serialize(buf, fout); + dlib::serialize("morestuff", fout); + fout.close(); + ifstream fin("ser_test_string.dat", ios::binary); + memset(buf2,0,sizeof(buf2)); + str2.clear(); + str3.clear(); + dlib::deserialize(str2, fin); + dlib::deserialize(buf2, fin); + dlib::deserialize(str3, fin); + + DLIB_TEST(str2 == "stuff"); + DLIB_TEST(str3 == "morestuff"); + DLIB_TEST(buf2[0] == 0); + DLIB_TEST(buf2[1] == 1); + DLIB_TEST(buf2[2] == 2); + DLIB_TEST(buf2[3] == 0); + DLIB_TEST(buf2[4] == 3); + DLIB_TEST(buf2[5] == 3); + + + + // make sure ramdump() overloads compile and work. + { + matrix a = {1,2,3,4}; + const matrix b = {3,2,3,4}; + dlib::serialize("ramdump_mat.dat") << ramdump(a) << ramdump(b); + matrix A, B; + dlib::deserialize("ramdump_mat.dat") >> ramdump(A) >> ramdump(B); + + DLIB_TEST(A == a); + DLIB_TEST(B == b); + A = 0; + B = 0; + DLIB_TEST(A != a); + DLIB_TEST(B != b); + + ostringstream sout; + dlib::serialize(ramdump(a), sout); + dlib::serialize(ramdump(b), sout); + istringstream sin(sout.str()); + dlib::deserialize(ramdump(A), sin); + dlib::deserialize(ramdump(B), sin); + + DLIB_TEST(A == a); + DLIB_TEST(B == b); + } + } + +// ---------------------------------------------------------------------------------------- + + class serialize_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a test for the serialization . When it is constructed + it adds itself into the testing framework. The command line switch is + specified as test_serialize by passing that string to the tester constructor. + !*/ + public: + serialize_tester ( + ) : + tester ("test_serialize", + "Runs tests on the serialization code.") + {} + + void perform_test ( + ) + { + serialize_test(); + test_vector(); + test_vector(); + test_vector(); + test_vector_bool(); + test_array2d_and_matrix_serialization(); + test_strings(); + test_std_array(); + } + } a; + + +} + + diff --git a/ml/dlib/dlib/test/set.cpp b/ml/dlib/dlib/test/set.cpp new file mode 100644 index 000000000..f8d3bd374 --- /dev/null +++ b/ml/dlib/dlib/test/set.cpp @@ -0,0 +1,464 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.set"); + + template < + typename set + > + void set_compare_test ( + ) + /*! + requires + - set is an implementation of set/set_compare_abstract.h and + is instantiated with int + ensures + - runs tests on set for compliance with the specs + !*/ + { + + + srand(static_cast(time(0))); + + + + set test, test2; + + enumerable& e = test; + DLIB_TEST(e.at_start() == true); + + for (int j = 0; j < 4; ++j) + { + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + + + int a,b = 0; + a = 8; + test.add(a); + DLIB_TEST(test.size() == 1); + DLIB_TEST(test.is_member(8) == true); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + a = 53; + test.add(a); + DLIB_TEST(test.size() == 2); + DLIB_TEST(test.is_member(53) == true); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + + + swap(test,test2); + + + + DLIB_TEST(test2.is_member(8) == true); + DLIB_TEST(test2.is_member(5) == false); + DLIB_TEST(test2.is_member(0) == false); + DLIB_TEST(test2.is_member(-999) == false); + DLIB_TEST(test2.is_member(4999) == false); + DLIB_TEST(test2.size() == 2); + DLIB_TEST(test2.is_member(53) == true); + DLIB_TEST(test2.is_member(5) == false); + DLIB_TEST(test2.is_member(0) == false); + DLIB_TEST(test2.is_member(-999) == false); + DLIB_TEST(test2.is_member(4999) == false); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_member(8) == false); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.is_member(53) == false); + DLIB_TEST(test.is_member(5) == false); + DLIB_TEST(test.is_member(0) == false); + DLIB_TEST(test.is_member(-999) == false); + DLIB_TEST(test.is_member(4999) == false); + + + test.clear(); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + + + DLIB_TEST(test.size() == 0); + + while (test.size() < 10000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + DLIB_TEST(test.size() == 10000); + test.clear(); + DLIB_TEST(test.size() == 0); + + while (test.size() < 10000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + DLIB_TEST(test.size() == 10000); + + int count = 0; + a = 0; + while (test.move_next()) + { + enumerable& gogo = test; + gogo.element(); + + DLIB_TEST(test.element() == test.element()); + DLIB_TEST(test.element() == test.element()); + DLIB_TEST(test.element() == test.element()); + + DLIB_TEST(a <= test.element()); + a = test.element(); + ++count; + } + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.move_next() == false); + + DLIB_TEST(count == 10000); + + test.swap(test2); + + DLIB_TEST(test.size() == 2); + DLIB_TEST(test2.size() == 10000); + count = 0; + a = -1; + test2.reset(); + while (test2.move_next()) + { + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(a < test2.element()); + a = test2.element(); + ++count; + } + DLIB_TEST(test2.size() == 10000); + DLIB_TEST(count == 10000); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.move_next() == false); + + + + test2.clear(); + DLIB_TEST(test2.size() == 0); + DLIB_TEST(test2.at_start() == true); + + while (test.size() < 20000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + DLIB_TEST(test.at_start() == true); + + { + int* array = new int[test.size()]; + int* tmp = array; + + count = 0; + while (test.move_next()) + { + DLIB_TEST(test.element() == test.element()); + DLIB_TEST(test.element() == test.element()); + DLIB_TEST(test.element() == test.element()); + *tmp = test.element(); + ++tmp; + ++count; + } + DLIB_TEST(count == 20000); + + // serialize the state of test, then clear test, then + // load the state back into test. + ostringstream sout; + serialize(test,sout); + DLIB_TEST(test.at_start() == true); + istringstream sin(sout.str()); + test.clear(); + deserialize(test,sin); + + + + tmp = array; + for (int i = 0; i < 20000; ++i) + { + DLIB_TEST(test.is_member(*tmp) == true); + ++tmp; + } + + DLIB_TEST(test.size() == 20000); + + tmp = array; + count = 0; + while (test.size() > 10000) + { + test.remove(*tmp,a); + DLIB_TEST(*tmp == a); + ++tmp; + ++count; + } + DLIB_TEST(count == 10000); + DLIB_TEST(test.size() == 10000); + + while (test.move_next()) + { + DLIB_TEST(test.element() == *tmp); + DLIB_TEST(test.element() == *tmp); + DLIB_TEST(test.element() == *tmp); + ++tmp; + ++count; + } + DLIB_TEST(count == 20000); + DLIB_TEST(test.size() == 10000); + + while (test.size() < 20000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + test2.swap(test); + + count = 0; + a = 0; + while (test2.move_next()) + { + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(test2.element() == test2.element()); + DLIB_TEST(a <= test2.element()); + a = test2.element(); + ++count; + } + + DLIB_TEST(count == 20000); + DLIB_TEST(test2.size() == 20000); + + a = -1; + while (test2.size()>0) + { + test2.remove_any(b); + DLIB_TEST( a < b); + a = b; + } + + DLIB_TEST(test2.size() == 0); + delete [] array; + } + + test.clear(); + test2.clear(); + while (test.size() < 10000) + { + a = ::rand(); + if (!test.is_member(a)) + test.add(a); + } + + count = 0; + a = -1; + while (test.move_next()) + { + DLIB_TEST(a < test.element()); + a = test.element(); + ++count; + if (count == 5000) + break; + DLIB_TEST(test.current_element_valid() == true); + } + + test.reset(); + + count = 0; + a = -1; + while (test.move_next()) + { + DLIB_TEST(a < test.element()); + a = test.element(); + ++count; + DLIB_TEST(test.current_element_valid() == true); + } + + DLIB_TEST(count == 10000); + + + test.clear(); + test2.clear(); + } + + + + { + DLIB_TEST(test == test2); + DLIB_TEST((test < test2) == false); + DLIB_TEST((test2 < test) == false); + + int a = 3, b = 3; + test.add(a); + test2.add(b); + test.move_next(); + DLIB_TEST(test == test2); + DLIB_TEST(test.at_start() && test2.at_start()); + test.move_next(); + DLIB_TEST((test < test2) == false); + DLIB_TEST(test.at_start() && test2.at_start()); + test.move_next(); + DLIB_TEST((test2 < test) == false); + DLIB_TEST(test.at_start() && test2.at_start()); + + a = 2; b = 5; + test.add(a); + test2.add(b); + DLIB_TEST(test.at_start() && test2.at_start()); + test2.move_next(); + DLIB_TEST((test == test2) == false); + DLIB_TEST(test.at_start() && test2.at_start()); + test2.move_next(); + DLIB_TEST((test < test2) == true); + DLIB_TEST(test.at_start() && test2.at_start()); + test2.move_next(); + DLIB_TEST((test2 < test) == false); + DLIB_TEST(test.at_start() && test2.at_start()); + + + a = 8; + test.add(a); + DLIB_TEST(test.at_start() && test2.at_start()); + test2.move_next(); + DLIB_TEST((test == test2) == false); + DLIB_TEST(test.at_start() && test2.at_start()); + test2.move_next(); + DLIB_TEST((test < test2) == false); + DLIB_TEST(test.at_start() && test2.at_start()); + test2.move_next(); + DLIB_TEST((test2 < test) == true); + DLIB_TEST(test.at_start() && test2.at_start()); + + test.clear(); + + DLIB_TEST(test.at_start() && test2.at_start()); + test2.move_next(); + DLIB_TEST((test == test2) == false); + DLIB_TEST(test.at_start() && test2.at_start()); + test2.move_next(); + DLIB_TEST((test < test2) == true); + DLIB_TEST(test.at_start() && test2.at_start()); + test2.move_next(); + DLIB_TEST((test2 < test) == false); + DLIB_TEST(test.at_start() && test2.at_start()); + + + } + + + { + test.clear(); + DLIB_TEST(test.size() == 0); + int a = 5; + test.add(a); + a = 7; + test.add(a); + DLIB_TEST(test.size() == 2); + DLIB_TEST(test.is_member(7)); + DLIB_TEST(test.is_member(5)); + test.destroy(7); + DLIB_TEST(test.size() == 1); + DLIB_TEST(!test.is_member(7)); + DLIB_TEST(test.is_member(5)); + test.destroy(5); + DLIB_TEST(test.size() == 0); + DLIB_TEST(!test.is_member(7)); + DLIB_TEST(!test.is_member(5)); + } + + + } + + + + + class set_tester : public tester + { + public: + set_tester ( + ) : + tester ("test_set", + "Runs tests on the set component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing compare_1a"; + set_compare_test::compare_1a> (); + dlog << LINFO << "testing compare_1a_c"; + set_compare_test::compare_1a_c>(); + dlog << LINFO << "testing compare_1b"; + set_compare_test::compare_1b> (); + dlog << LINFO << "testing compare_1b_c"; + set_compare_test::compare_1b_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/sldf.cpp b/ml/dlib/dlib/test/sldf.cpp new file mode 100644 index 000000000..4ca9a3dd0 --- /dev/null +++ b/ml/dlib/dlib/test/sldf.cpp @@ -0,0 +1,296 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.sldf"); + + + class sldf_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + sldf_tester ( + ) : + tester ( + "test_sldf", // the command line argument name for this test + "Run tests on the simplify_linear_decision_function routines.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + dlib::rand rnd; + + + void perform_test ( + ) + { + print_spinner(); + typedef std::map sample_type; + + typedef matrix dense_sample_type; + + typedef sparse_linear_kernel kernel_type; + typedef linear_kernel dense_kernel_type; + + + svm_nu_trainer linear_trainer; + linear_trainer.set_nu(0.2); + svm_nu_trainer dense_linear_trainer; + dense_linear_trainer.set_nu(0.2); + + std::vector samples; + std::vector labels; + + // make an instance of a sample vector so we can use it below + sample_type sample; + + // Now lets go into a loop and randomly generate 300 samples. + double label = +1; + for (int i = 0; i < 300; ++i) + { + // flip this flag + label *= -1; + + sample.clear(); + + // now make a random sparse sample with at most 10 non-zero elements + for (int j = 0; j < 10; ++j) + { + int idx = rnd.get_random_32bit_number()%100; + double value = rnd.get_random_double(); + + sample[idx] = label*value; + } + + // Also save the samples we are generating so we can let the svm_c_linear_trainer + // learn from them below. + samples.push_back(sample); + labels.push_back(label); + } + + + { + print_spinner(); + dlog << LINFO << " test with sparse samples "; + decision_function df = linear_trainer.train(samples, labels); + + dlog << LINFO << "df.basis_vectors.size(): "<< df.basis_vectors.size(); + DLIB_TEST(df.basis_vectors.size() > 4); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(df, samples, labels); + + // save the outputs of the decision function before we mess with it + std::vector prev_vals; + for (unsigned long i = 0; i < samples.size(); ++i) + prev_vals.push_back(df(samples[i])); + + df = simplify_linear_decision_function(df); + + dlog << LINFO << "df.basis_vectors.size(): "<< df.basis_vectors.size(); + DLIB_TEST(df.basis_vectors.size() == 1); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(df, samples, labels); + + // now check that the simplified decision function still produces the same results + std::vector cur_vals; + for (unsigned long i = 0; i < samples.size(); ++i) + cur_vals.push_back(df(samples[i])); + + const double err = max(abs(mat(cur_vals) - mat(prev_vals))); + dlog << LINFO << "simplify error: "<< err; + DLIB_TEST(err < 1e-13); + + } + + + // same as above but call simplify_linear_decision_function() two times + { + print_spinner(); + dlog << LINFO << " test with sparse samples "; + decision_function df = linear_trainer.train(samples, labels); + + dlog << LINFO << "df.basis_vectors.size(): "<< df.basis_vectors.size(); + DLIB_TEST(df.basis_vectors.size() > 4); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(df, samples, labels); + + // save the outputs of the decision function before we mess with it + std::vector prev_vals; + for (unsigned long i = 0; i < samples.size(); ++i) + prev_vals.push_back(df(samples[i])); + + df = simplify_linear_decision_function(df); + df = simplify_linear_decision_function(df); + + dlog << LINFO << "df.basis_vectors.size(): "<< df.basis_vectors.size(); + DLIB_TEST(df.basis_vectors.size() == 1); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(df, samples, labels); + + // now check that the simplified decision function still produces the same results + std::vector cur_vals; + for (unsigned long i = 0; i < samples.size(); ++i) + cur_vals.push_back(df(samples[i])); + + const double err = max(abs(mat(cur_vals) - mat(prev_vals))); + dlog << LINFO << "simplify error: "<< err; + DLIB_TEST(err < 1e-13); + + } + + + { + print_spinner(); + dlog << LINFO << " test with dense samples "; + std::vector dense_samples(sparse_to_dense(samples)); + + // In addition to the rule we learned with the pegasos trainer lets also use our linear_trainer + // to learn a decision rule. + decision_function dense_df = dense_linear_trainer.train(dense_samples, labels); + + dlog << LINFO << "dense_df.basis_vectors.size(): "<< dense_df.basis_vectors.size(); + DLIB_TEST(dense_df.basis_vectors.size() > 4); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels); + + // save the outputs of the decision function before we mess with it + std::vector prev_vals; + for (unsigned long i = 0; i < dense_samples.size(); ++i) + prev_vals.push_back(dense_df(dense_samples[i])); + + dense_df = simplify_linear_decision_function(dense_df); + + dlog << LINFO << "dense_df.basis_vectors.size(): "<< dense_df.basis_vectors.size(); + DLIB_TEST(dense_df.basis_vectors.size() == 1); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels); + + + // now check that the simplified decision function still produces the same results + std::vector cur_vals; + for (unsigned long i = 0; i < dense_samples.size(); ++i) + cur_vals.push_back(dense_df(dense_samples[i])); + + const double err = max(abs(mat(cur_vals) - mat(prev_vals))); + dlog << LINFO << "simplify error: "<< err; + DLIB_TEST(err < 1e-13); + } + + // same as above but call simplify_linear_decision_function() two times + { + print_spinner(); + dlog << LINFO << " test with dense samples "; + std::vector dense_samples(sparse_to_dense(samples)); + + // In addition to the rule we learned with the pegasos trainer lets also use our linear_trainer + // to learn a decision rule. + decision_function dense_df = dense_linear_trainer.train(dense_samples, labels); + + dlog << LINFO << "dense_df.basis_vectors.size(): "<< dense_df.basis_vectors.size(); + DLIB_TEST(dense_df.basis_vectors.size() > 4); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels); + + // save the outputs of the decision function before we mess with it + std::vector prev_vals; + for (unsigned long i = 0; i < dense_samples.size(); ++i) + prev_vals.push_back(dense_df(dense_samples[i])); + + dense_df = simplify_linear_decision_function(dense_df); + dense_df = simplify_linear_decision_function(dense_df); + + dlog << LINFO << "dense_df.basis_vectors.size(): "<< dense_df.basis_vectors.size(); + DLIB_TEST(dense_df.basis_vectors.size() == 1); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels); + + + // now check that the simplified decision function still produces the same results + std::vector cur_vals; + for (unsigned long i = 0; i < dense_samples.size(); ++i) + cur_vals.push_back(dense_df(dense_samples[i])); + + const double err = max(abs(mat(cur_vals) - mat(prev_vals))); + dlog << LINFO << "simplify error: "<< err; + DLIB_TEST(err < 1e-13); + } + + { + print_spinner(); + + dlog << LINFO << " test with sparse samples and a vector normalizer"; + std::vector dense_samples(sparse_to_dense(samples)); + std::vector norm_samples; + + // make a normalizer and normalize everything + vector_normalizer normalizer; + normalizer.train(dense_samples); + for (unsigned long i = 0; i < dense_samples.size(); ++i) + norm_samples.push_back(normalizer(dense_samples[i])); + + normalized_function > dense_df; + + dense_df.normalizer = normalizer; + dense_df.function = dense_linear_trainer.train(norm_samples, labels); + + dlog << LINFO << "dense_df.function.basis_vectors.size(): "<< dense_df.function.basis_vectors.size(); + DLIB_TEST(dense_df.function.basis_vectors.size() > 4); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(dense_df, dense_samples, labels); + + // save the outputs of the decision function before we mess with it + std::vector prev_vals; + for (unsigned long i = 0; i < dense_samples.size(); ++i) + prev_vals.push_back(dense_df(dense_samples[i])); + + + decision_function simple_df = simplify_linear_decision_function(dense_df); + + dlog << LINFO << "simple_df.basis_vectors.size(): "<< simple_df.basis_vectors.size(); + DLIB_TEST(simple_df.basis_vectors.size() == 1); + + dlog << LINFO << "test scores: "<< test_binary_decision_function(simple_df, dense_samples, labels); + + + // now check that the simplified decision function still produces the same results + std::vector cur_vals; + for (unsigned long i = 0; i < dense_samples.size(); ++i) + cur_vals.push_back(simple_df(dense_samples[i])); + + const double err = max(abs(mat(cur_vals) - mat(prev_vals))); + dlog << LINFO << "simplify error: "<< err; + DLIB_TEST(err < 1e-13); + + } + + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + sldf_tester a; + +} + + + diff --git a/ml/dlib/dlib/test/sliding_buffer.cpp b/ml/dlib/dlib/test/sliding_buffer.cpp new file mode 100644 index 000000000..449ea858e --- /dev/null +++ b/ml/dlib/dlib/test/sliding_buffer.cpp @@ -0,0 +1,439 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.sliding_buffer"); + + template < + typename buf + > + void sliding_buffer_kernel_test ( + ) + /*! + requires + - buf is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h + - buf is instantiated with T=unsigned char + ensures + - runs tests on buf for compliance with the specs + !*/ + { + + print_spinner(); + + buf test; + + DLIB_TEST(test.size() == 0); + + test.set_size(3); + buf test2; + + DLIB_TEST(test.size() == 8); + + for (int g = 0; g < 2; ++g) + { + + test.clear(); + + DLIB_TEST(test.size() == 0); + test.set_size(2); + + DLIB_TEST(test.size() == 4); + + + + test[0] = 'a'; + test[1] = 's'; + test[2] = 'd'; + test[3] = 'f'; + + unsigned long id = test.get_element_id(2); + DLIB_TEST(test[test.get_element_index(id)] == 'd'); + + + DLIB_TEST(test[0] == 'a'); + DLIB_TEST(test[1] == 's'); + DLIB_TEST(test[2] == 'd'); + DLIB_TEST(test[3] == 'f'); + + DLIB_TEST(test2.size() == 0); + swap(test,test2); + DLIB_TEST(test2.size() == 4); + + DLIB_TEST(test2[0] == 'a'); + DLIB_TEST(test2[1] == 's'); + DLIB_TEST(test2[2] == 'd'); + DLIB_TEST(test2[3] == 'f'); + + swap(test,test2); + + test.rotate_left(4); + + + DLIB_TEST(test[test.get_element_index(id)] == 'd'); + + DLIB_TEST(test[0] == 'a'); + DLIB_TEST(test[1] == 's'); + DLIB_TEST(test[2] == 'd'); + DLIB_TEST(test[3] == 'f'); + + test.rotate_right(1); + + DLIB_TEST(test[test.get_element_index(id)] == 'd'); + + DLIB_TEST(test[0] == 's'); + DLIB_TEST(test[1] == 'd'); + DLIB_TEST(test[2] == 'f'); + DLIB_TEST(test[3] == 'a'); + + + test.rotate_left(1); + + DLIB_TEST(test[test.get_element_index(id)] == 'd'); + DLIB_TEST(test[0] == 'a'); + DLIB_TEST(test[1] == 's'); + DLIB_TEST(test[2] == 'd'); + DLIB_TEST(test[3] == 'f'); + + + test.rotate_left(16); + + DLIB_TEST(test[test.get_element_index(id)] == 'd'); + DLIB_TEST(test[0] == 'a'); + DLIB_TEST(test[1] == 's'); + DLIB_TEST(test[2] == 'd'); + DLIB_TEST(test[3] == 'f'); + + + test.rotate_left(2); + + DLIB_TEST(test[test.get_element_index(id)] == 'd'); + + DLIB_TEST(test[0] == 'd'); + DLIB_TEST(test[1] == 'f'); + DLIB_TEST(test[2] == 'a'); + DLIB_TEST(test[3] == 's'); + + test.rotate_left(1); + + DLIB_TEST(test[test.get_element_index(id)] == 'd'); + DLIB_TEST(test[0] == 's'); + DLIB_TEST(test[1] == 'd'); + DLIB_TEST(test[2] == 'f'); + DLIB_TEST(test[3] == 'a'); + + test.rotate_left(1); + + DLIB_TEST(test[test.get_element_index(id)] == 'd'); + DLIB_TEST(test[0] == 'a'); + DLIB_TEST(test[1] == 's'); + DLIB_TEST(test[2] == 'd'); + DLIB_TEST(test[3] == 'f'); + + DLIB_TEST(test.size() == 4); + + test[0] = 'x'; + + DLIB_TEST(test[0] == 'x'); + DLIB_TEST(test[1] == 's'); + DLIB_TEST(test[2] == 'd'); + DLIB_TEST(test[3] == 'f'); + + test.rotate_left(1); + + DLIB_TEST_MSG(test[0] == 'f',test[0]); + DLIB_TEST(test[1] == 'x'); + DLIB_TEST(test[2] == 's'); + DLIB_TEST(test[3] == 'd'); + + + test[0] = 'x'; + + DLIB_TEST(test[0] == 'x'); + DLIB_TEST(test[1] == 'x'); + DLIB_TEST(test[2] == 's'); + DLIB_TEST(test[3] == 'd'); + + + test.rotate_left(1); + + + DLIB_TEST(test[0] == 'd'); + DLIB_TEST(test[1] == 'x'); + DLIB_TEST(test[2] == 'x'); + DLIB_TEST(test[3] == 's'); + + + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == true); + + test.clear(); + test2.clear(); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == false); + + swap(test,test2); + + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == false); + DLIB_TEST(test2.move_next() == false); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test2.current_element_valid() == false); + + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.at_start() == false); + + test.set_size(3); + DLIB_TEST(test.size() == 8); + + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == true); + test.reset(); + DLIB_TEST(test.size() == 8); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == true); + + + test.rotate_right(1); + DLIB_TEST(test.size() == 8); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == true); + + test.rotate_left(1); + DLIB_TEST(test.size() == 8); + DLIB_TEST(test.at_start() == true); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == true); + DLIB_TEST(test.at_start() == false); + DLIB_TEST(test.current_element_valid() == true); + test.reset(); + + + for (unsigned long i = 0; i < test.size(); ++i) + { + test[i] = static_cast(i); + } + + unsigned long count = 0; + while (test.move_next()) + { + DLIB_TEST(test.element() == count); + ++count; + } + + DLIB_TEST(count == test.size()); + + + test2.clear(); + ostringstream sout; + istringstream sin; + + serialize(test,sout); + sin.str(sout.str()); + deserialize(test2,sin); + + char ch; + sin >> ch; + DLIB_TEST( !sin); + + DLIB_TEST(test2.size() == test.size()); + + + for (unsigned long i = 0; i < test.size(); ++i) + { + DLIB_TEST_MSG(test[i] == test2[i], + "\ni: " << i << + "\ntest[i]: " << test[i] << + "\ntest2[i]: " << test2[i]); + } + + count = 0; + while (test.move_next() && test2.move_next()) + { + DLIB_TEST(test.element() == count); + DLIB_TEST(test2.element() == count); + ++count; + } + + DLIB_TEST(test2.size() == count); + DLIB_TEST(test.size() == count); + + test2.clear(); + + + } // for (int g = 0; g < 2; ++g) + + + } + + + + void test_circular_buffer() + { + circular_buffer buf; + + DLIB_TEST(buf.size() == 0); + + buf.assign(4, 0); + DLIB_TEST(buf.size() == 4); + + DLIB_TEST(buf[0] == 0); + DLIB_TEST(buf[1] == 0); + DLIB_TEST(buf[2] == 0); + DLIB_TEST(buf[3] == 0); + buf.push_back(1); + DLIB_TEST(buf[0] == 0); + DLIB_TEST(buf[1] == 0); + DLIB_TEST(buf[2] == 0); + DLIB_TEST(buf[3] == 1); + buf.push_back(2); + DLIB_TEST(buf[0] == 0); + DLIB_TEST(buf[1] == 0); + DLIB_TEST(buf[2] == 1); + DLIB_TEST(buf[3] == 2); + buf.push_front(3); + DLIB_TEST(buf[0] == 3); + DLIB_TEST(buf[1] == 0); + DLIB_TEST(buf[2] == 0); + DLIB_TEST(buf[3] == 1); + buf.push_front(4); + DLIB_TEST(buf.front() == 4); + DLIB_TEST(buf[0] == 4); + DLIB_TEST(buf[1] == 3); + DLIB_TEST(buf[2] == 0); + DLIB_TEST(buf[3] == 0); + + buf.assign(4, 5); + DLIB_TEST(buf[0] == 5); + DLIB_TEST(buf[1] == 5); + DLIB_TEST(buf[2] == 5); + DLIB_TEST(buf[3] == 5); + + buf.push_back(3); + DLIB_TEST(buf[0] == 5); + DLIB_TEST(buf[1] == 5); + DLIB_TEST(buf[2] == 5); + DLIB_TEST(buf[3] == 3); + buf.push_back(2); + DLIB_TEST(buf[0] == 5); + DLIB_TEST(buf[1] == 5); + DLIB_TEST(buf[2] == 3); + DLIB_TEST(buf[3] == 2); + buf.push_back(1); + DLIB_TEST(buf[0] == 5); + DLIB_TEST(buf[1] == 3); + DLIB_TEST(buf[2] == 2); + DLIB_TEST(buf[3] == 1); + buf.push_back(0); + DLIB_TEST(buf[0] == 3); + DLIB_TEST(buf[1] == 2); + DLIB_TEST(buf[2] == 1); + DLIB_TEST(buf[3] == 0); + buf.push_back(-1); + DLIB_TEST(buf.back() == -1); + DLIB_TEST(buf[0] == 2); + DLIB_TEST(buf[1] == 1); + DLIB_TEST(buf[2] == 0); + DLIB_TEST(buf[3] == -1); + + buf.resize(1); + buf[0] = 9; + DLIB_TEST(buf.size() == 1); + DLIB_TEST(buf[0] == 9); + buf.push_back(1); + DLIB_TEST(buf[0] == 1); + buf.push_back(4); + DLIB_TEST(buf[0] == 4); + buf.push_front(3); + DLIB_TEST(buf[0] == 3); + + buf.clear(); + DLIB_TEST(buf.size() == 0); + + buf.assign(3, 0); + + circular_buffer buf2, buf3; + + buf.push_back(1); + buf.push_back(2); + + ostringstream sout; + serialize(buf, sout); + istringstream sin(sout.str()); + deserialize(buf2, sin); + + DLIB_TEST(buf.size() == buf2.size()); + for (unsigned long i = 0; i < buf.size(); ++i) + DLIB_TEST(buf[i] == buf2[i]); + + buf.swap(buf3); + DLIB_TEST(buf.size() == 0); + DLIB_TEST(buf3.size() == buf2.size()); + for (unsigned long i = 0; i < buf3.size(); ++i) + DLIB_TEST(buf3[i] == buf2[i]); + + + + } + + + + class sliding_buffer_tester : public tester + { + public: + sliding_buffer_tester ( + ) : + tester ("test_sliding_buffer", + "Runs tests on the sliding_buffer component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + sliding_buffer_kernel_test::kernel_1a> (); + dlog << LINFO << "testing kernel_1a_c"; + sliding_buffer_kernel_test::kernel_1a_c>(); + + test_circular_buffer(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/smart_pointers.cpp b/ml/dlib/dlib/test/smart_pointers.cpp new file mode 100644 index 000000000..c2281efda --- /dev/null +++ b/ml/dlib/dlib/test/smart_pointers.cpp @@ -0,0 +1,449 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +// This is a legacy test for old dlib smart pointers which is excluded +// from CMakeLists.txt. Including this test will pull legacy smart_pointers.h +// code which is uncompilable on C++17 compilers + +#include +#include +#include +#include +#include + +#include "tester.h" + +// Don't warn about auto_ptr +#if (defined(__GNUC__) && ((__GNUC__ >= 4 && __GNUC_MINOR__ >= 6) || (__GNUC__ > 4))) || \ + (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4))) +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif + +namespace +{ + bool used_array_delete; + template + struct test_deleter + { + void operator() (T* item) const + { + used_array_delete = false; + delete item; + } + }; + + template + struct test_deleter + { + void operator() (T* item) const + { + used_array_delete = true; + delete [] item; + } + }; + + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.smart_pointers"); + + int counter = 0; + struct base + { + int num; + virtual ~base() {} + }; + + struct derived : public base + { + derived() { ++counter; } + ~derived() { --counter; } + }; + + int deleter_called = 0; + void deleter ( derived* p) { ++deleter_called; delete p; } + void deleter_base ( base* p) { ++deleter_called; delete p; } + typedef void (*D)(derived*); + typedef void (*Db)(base*); + + void smart_pointers_test ( + ) + /*! + ensures + - runs tests on the smart pointers for compliance with the specs + !*/ + { + counter = 0; + deleter_called = 0; + + { + DLIB_TEST_MSG(counter == 0,counter); + scoped_ptr p1(new derived); + scoped_ptr p2(new derived); + scoped_ptr p3; + DLIB_TEST_MSG(counter == 2,counter); + DLIB_TEST(!p3); + + p1->num = 1; + p2->num = 2; + DLIB_TEST(p1->num == 1); + DLIB_TEST(p2->num == 2); + + (*p1).num = 3; + (*p2).num = 4; + DLIB_TEST(p1->num == 3); + DLIB_TEST(p2->num == 4); + + DLIB_TEST_MSG(counter == 2,counter); + + DLIB_TEST(p1); + DLIB_TEST(p2); + + DLIB_TEST_MSG(counter == 2,counter); + p1.reset(); + DLIB_TEST_MSG(counter == 1,counter); + DLIB_TEST(!p1); + DLIB_TEST(p2); + p1.reset(new derived); + DLIB_TEST_MSG(counter == 2,counter); + DLIB_TEST(p1); + + + DLIB_TEST_MSG(counter == 2,counter); + p2.reset(); + DLIB_TEST_MSG(counter == 1,counter); + DLIB_TEST(!p2); + derived* d = new derived; + p2.reset(d); + DLIB_TEST(p2.get() == d); + DLIB_TEST_MSG(counter == 2,counter); + DLIB_TEST(p2); + DLIB_TEST(!p3); + p2->num = 9; + swap(p2,p3); + DLIB_TEST(!p2); + DLIB_TEST(p3); + DLIB_TEST(p3->num == 9); + p2.swap(p3); + DLIB_TEST(p2); + DLIB_TEST(!p3); + DLIB_TEST(p2->num == 9); + + + DLIB_TEST_MSG(counter == 2,counter); + + } + DLIB_TEST_MSG(counter == 0,counter); + + { + base* realp1 = new derived; + derived* realp2 = new derived; + dlib::shared_ptr p1(realp1); + dlib::shared_ptr p2(realp2,&deleter); + dlib::shared_ptr p3; + dlib::shared_ptr p4; + DLIB_TEST(p4.get() == 0); + DLIB_TEST(p1); + DLIB_TEST(p2); + DLIB_TEST(!p3); + DLIB_TEST(!p4); + DLIB_TEST(p1.get() == realp1); + DLIB_TEST(p2.get() == realp2); + p1->num = 1; + p2->num = 2; + DLIB_TEST((*p1).num == 1); + DLIB_TEST((*p2).num == 2); + + p1.swap(p3); + DLIB_TEST(!p1); + DLIB_TEST(p3); + DLIB_TEST((*p3).num == 1); + DLIB_TEST(p3->num == 1); + swap(p1,p3); + DLIB_TEST(p1); + DLIB_TEST(!p3); + DLIB_TEST((*p1).num == 1); + DLIB_TEST(p1->num == 1); + DLIB_TEST_MSG(counter == 2,counter); + + DLIB_TEST(p1.unique()); + DLIB_TEST(p2.unique()); + DLIB_TEST(!p3.unique()); + DLIB_TEST(!p4.unique()); + + DLIB_TEST(p1.use_count() == 1); + DLIB_TEST(p2.use_count() == 1); + DLIB_TEST(p3.use_count() == 0); + DLIB_TEST(p4.use_count() == 0); + + dlib::shared_ptr p11(p1); + + DLIB_TEST(!p1.unique()); + DLIB_TEST(p2.unique()); + DLIB_TEST(!p3.unique()); + DLIB_TEST(!p4.unique()); + + DLIB_TEST(p1.use_count() == 2); + DLIB_TEST(p2.use_count() == 1); + DLIB_TEST(p3.use_count() == 0); + DLIB_TEST(p4.use_count() == 0); + + dlib::shared_ptr p22(p2); + + DLIB_TEST(!p1.unique()); + DLIB_TEST(!p2.unique()); + DLIB_TEST(!p3.unique()); + DLIB_TEST(!p4.unique()); + + DLIB_TEST(p1.use_count() == 2); + DLIB_TEST(p2.use_count() == 2); + DLIB_TEST(p3.use_count() == 0); + DLIB_TEST(p4.use_count() == 0); + + DLIB_TEST(p11.get() == realp1); + DLIB_TEST(p11 == p1); + DLIB_TEST(p22 == p2); + DLIB_TEST(p3 == p4); + DLIB_TEST(p11 != p22); + DLIB_TEST(p1 != p2); + DLIB_TEST(p3 != p1); + DLIB_TEST(p3 != p11); + DLIB_TEST(p3 != p2); + + + p1 = p1 = p1; + DLIB_TEST(p1.use_count() == 2); + DLIB_TEST(p1->num == 1); + DLIB_TEST(p11.use_count() == 2); + p1.reset(); + DLIB_TEST(p1.get() == 0); + DLIB_TEST(p1.use_count() == 0); + DLIB_TEST(p1.unique() == false); + DLIB_TEST(p11.use_count() == 1); + p11 = p2; + DLIB_TEST(p1.use_count() == 0); + DLIB_TEST(p1.unique() == false); + DLIB_TEST(p11.use_count() == 3); + DLIB_TEST(p11.unique() == false); + + // now p11, p2, and p22 all reference the same thing and the rest are null + DLIB_TEST_MSG((p11 < p2) == false,""); + DLIB_TEST_MSG((p2 < p11) == false,""); + + DLIB_TEST(get_deleter(p4) == 0); + p4 = p2; + DLIB_TEST(get_deleter(p4) != 0); + DLIB_TEST(get_deleter(p4) == get_deleter(p2)); + DLIB_TEST(get_deleter(p4) == get_deleter(p11)); + DLIB_TEST(get_deleter(p4) == 0); + + realp1 = new derived; + p1.reset(realp1, &deleter_base); + DLIB_TEST(p1.get() == realp1); + DLIB_TEST(p1.unique()); + DLIB_TEST(p1.use_count() == 1); + DLIB_TEST(*get_deleter(p1) == &deleter_base); + DLIB_TEST(p1 != p4); + p4 = dynamic_pointer_cast(p1); + DLIB_TEST(!p1.unique()); + DLIB_TEST(p1.use_count() == 2); + DLIB_TEST(p1 == p4); + + realp1 = new derived; + p1.reset(realp1); + DLIB_TEST(p1.get() == realp1); + DLIB_TEST(p1.unique()); + DLIB_TEST(p1.use_count() == 1); + DLIB_TEST(get_deleter(p1) == 0); + + + auto_ptr ap1(new derived); + auto_ptr ap2(new derived); + ap1->num = 35; + ap2->num = 36; + + DLIB_TEST(ap1.get() != 0); + DLIB_TEST(ap2.get() != 0); + p1 = ap2; + p2 = ap1; + + DLIB_TEST(ap1.get() == 0); + DLIB_TEST(p1.unique()); + DLIB_TEST(p1.use_count() == 1); + DLIB_TEST(ap2.get() == 0); + DLIB_TEST(p2.unique()); + DLIB_TEST(p2.use_count() == 1); + DLIB_TEST(p1->num == 36); + DLIB_TEST(p2->num == 35); + + } + + DLIB_TEST_MSG(counter == 0,counter); + DLIB_TEST_MSG(deleter_called == 2,counter); + + dlib::weak_ptr wp4; + { + dlib::shared_ptr p1(new derived, &deleter_base); + dlib::shared_ptr p2; + dlib::shared_ptr p3; + + dlib::weak_ptr wp1; + dlib::weak_ptr wp2; + dlib::weak_ptr wp3; + + dlib::weak_ptr wp1c(p1); + dlib::weak_ptr wp2c(p1); + dlib::weak_ptr wp3c(p2); + + DLIB_TEST(wp1c.use_count() == 1); + DLIB_TEST(wp1c.lock() == p1); + DLIB_TEST(wp1c.expired() == false); + + DLIB_TEST(wp2c.use_count() == 1); + DLIB_TEST(wp2c.lock() == p1); + DLIB_TEST(wp2c.expired() == false); + + DLIB_TEST(wp3c.use_count() == 0); + DLIB_TEST(wp3c.lock() == dlib::shared_ptr()); + DLIB_TEST(wp3c.expired() == true); + + DLIB_TEST(wp2.use_count() == 0); + DLIB_TEST(wp2.expired() == true); + DLIB_TEST(wp2.lock().use_count() == 0); + DLIB_TEST(wp2.lock().unique() == false); + + wp1 = p1; + wp2 = wp1; + wp3 = p1; + + DLIB_TEST(p1.use_count() == 1); + DLIB_TEST(p1.unique()); + DLIB_TEST(wp1.use_count() == 1); + DLIB_TEST(wp2.use_count() == 1); + DLIB_TEST(wp3.use_count() == 1); + DLIB_TEST(wp1.expired() == false); + DLIB_TEST(wp2.expired() == false); + DLIB_TEST(wp3.expired() == false); + DLIB_TEST(wp1.lock() == p1); + DLIB_TEST(wp2.lock() == p1); + DLIB_TEST(wp3.lock() == p1); + + wp3.reset(); + + DLIB_TEST(p1.use_count() == 1); + DLIB_TEST(p1.unique()); + DLIB_TEST(wp1.use_count() == 1); + DLIB_TEST(wp2.use_count() == 1); + DLIB_TEST(wp3.use_count() == 0); + DLIB_TEST(wp1.expired() == false); + DLIB_TEST(wp2.expired() == false); + DLIB_TEST(wp3.expired() == true); + DLIB_TEST(wp1.lock() == p1); + DLIB_TEST(wp2.lock() == p1); + DLIB_TEST(wp3.lock() == dlib::shared_ptr()); + + + p1.reset(); + + DLIB_TEST(p1.use_count() == 0); + DLIB_TEST(p1.unique() == false); + DLIB_TEST(wp1.use_count() == 0); + DLIB_TEST(wp2.use_count() == 0); + DLIB_TEST(wp3.use_count() == 0); + DLIB_TEST(wp1.expired() == true); + DLIB_TEST(wp2.expired() == true); + DLIB_TEST(wp3.expired() == true); + DLIB_TEST(wp1.lock() == dlib::shared_ptr()); + DLIB_TEST(wp2.lock() == dlib::shared_ptr()); + DLIB_TEST(wp3.lock() == dlib::shared_ptr()); + + p1.reset(new derived); + + DLIB_TEST(p1.use_count() == 1); + DLIB_TEST(p1.unique() == true); + DLIB_TEST(wp1.use_count() == 0); + DLIB_TEST(wp2.use_count() == 0); + DLIB_TEST(wp3.use_count() == 0); + DLIB_TEST(wp1.expired() == true); + DLIB_TEST(wp2.expired() == true); + DLIB_TEST(wp3.expired() == true); + DLIB_TEST(wp1.lock() == dlib::shared_ptr()); + DLIB_TEST(wp2.lock() == dlib::shared_ptr()); + DLIB_TEST(wp3.lock() == dlib::shared_ptr()); + + DLIB_TEST(wp4.expired() == true); + DLIB_TEST(wp4.lock() == dlib::shared_ptr()); + wp4 = p1; + p3 = p1; + DLIB_TEST(wp4.expired() == false); + DLIB_TEST(wp4.lock() == p3); + + + bool ok = false; + try { + dlib::shared_ptr bad_ptr(wp1); + } catch (dlib::bad_weak_ptr&) + { + ok = true; + } + DLIB_TEST(ok); + } + DLIB_TEST(wp4.expired() == true); + DLIB_TEST(wp4.lock() == dlib::shared_ptr()); + + + DLIB_TEST_MSG(counter == 0,counter); + DLIB_TEST_MSG(deleter_called == 3,counter); + + { + scoped_ptr a(new int[10]); + + { + used_array_delete = false; + scoped_ptr > b(new int[10]); + + for (int i = 0; i < 10; ++i) + { + a[i] = i; + b[i] = i; + } + } + DLIB_TEST(used_array_delete == true); + + + { + used_array_delete = true; + scoped_ptr > c(new int); + } + DLIB_TEST(used_array_delete == false); + + scoped_ptr const_a(new int[10]); + + } + + } + + + + class smart_pointers_tester : public tester + { + public: + smart_pointers_tester ( + ) : + tester ("test_smart_pointers", + "Runs tests on the smart pointers.") + {} + + void perform_test ( + ) + { + smart_pointers_test(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/sockets.cpp b/ml/dlib/dlib/test/sockets.cpp new file mode 100644 index 000000000..920fa9402 --- /dev/null +++ b/ml/dlib/dlib/test/sockets.cpp @@ -0,0 +1,247 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "tester.h" + +namespace { + + using namespace test; + using namespace dlib; + using namespace std; + + dlib::mutex gm; + dlib::signaler gs(gm); + const char magic_num = 42; + const int min_bytes_sent = 10000; + int assigned_port; + + logger dlog("test.sockets"); + +// ---------------------------------------------------------------------------------------- + + class serv : public server::kernel_1a_c + { + public: + serv ( + ) : + error_occurred (false), + got_connections(false) + {} + + void on_listening_port_assigned ( + ) + { + auto_mutex M(gm); + assigned_port = get_listening_port(); + gs.broadcast(); + } + + + void on_connect ( + connection& con + ) + { + dlog << LINFO << "in serv::on_connect(): got new connection"; + int status; + int count = 0; + char buf[100]; + while ((status = con.read(buf,sizeof(buf))) > 0) + { + for (int i = 0; i < status; ++i) + { + if (buf[i] != magic_num) + { + tag = 4.0; + error_occurred = true; + } + } + count += status; + } + if (count != min_bytes_sent) + { + tag = 5.0; + error_occurred = true; + } + got_connections = true; + dlog << LINFO << "in serv::on_connect(): on_connect ending"; + } + + bool error_occurred; + bool got_connections; + double tag; + }; + +// ---------------------------------------------------------------------------------------- + + class thread_container : public multithreaded_object + { + public: + + serv& srv; + + thread_container ( + serv& srv_ + ) : srv(srv_) + { + for (int i = 0; i < 10; ++i) + register_thread(*this, &thread_container::thread_proc); + + // start up the threads + start(); + } + + ~thread_container () + { + // wait for all threads to terminate + wait(); + } + + void thread_proc ( + ) + { + try + { + dlog << LTRACE << "enter thread"; + { + auto_mutex M(gm); + while (assigned_port == 0) + gs.wait(); + } + + int status; + std::unique_ptr con; + string hostname; + string ip; + status = get_local_hostname(hostname); + if (status) + { + srv.tag = 1.0; + srv.error_occurred = true; + srv.clear(); + dlog << LERROR << "leaving thread, line: " << __LINE__; + dlog << LERROR << "get_local_hostname() failed"; + return; + } + + status = hostname_to_ip(hostname,ip); + if (status) + { + srv.tag = 2.0; + srv.error_occurred = true; + srv.clear(); + dlog << LERROR << "leaving thread, line: " << __LINE__; + dlog << LERROR << "hostname_to_ip() failed"; + return; + } + + dlog << LTRACE << "try to connect to the server at port " << srv.get_listening_port(); + status = create_connection(con,srv.get_listening_port(),ip); + if (status) + { + srv.tag = 3.0; + srv.error_occurred = true; + srv.clear(); + dlog << LERROR << "leaving thread, line: " << __LINE__; + dlog << LERROR << "create_connection() failed"; + return; + } + + dlog << LTRACE << "sending magic_num to server"; + int i; + for (i = 0; i < min_bytes_sent; ++i) + { + con->write(&magic_num,1); + } + + dlog << LTRACE << "shutting down connection to server"; + close_gracefully(con); + dlog << LTRACE << "finished calling close_gracefully() on the connection"; + } + catch (exception& e) + { + srv.error_occurred = true; + dlog << LERROR << "exception thrown in thread_proc(): " << e.what(); + cout << "exception thrown in thread_proc(): " << e.what(); + } + dlog << LTRACE << "exit thread"; + } + }; + + void run_server(serv* srv) + { + dlog << LTRACE << "calling srv.start()"; + srv->start(); + dlog << LTRACE << "srv.start() just ended."; + } + + void sockets_test ( + ) + /*! + requires + - sockets is an implementation of sockets/sockets_kernel_abstract.h + is instantiated with int + ensures + - runs tests on sockets for compliance with the specs + !*/ + { + + dlog << LTRACE << "starting test"; + serv srv; + + assigned_port = 0; + + + dlog << LTRACE << "spawning threads"; + thread_container stuff(srv); + + + + thread_function thread2(run_server, &srv); + + // wait until all the sending threads have ended + stuff.wait(); + + if (srv.error_occurred) + { + dlog << LDEBUG << "tag: " << srv.tag; + } + + srv.clear(); + + dlog << LTRACE << "ending successful test"; + DLIB_TEST( !srv.error_occurred); + DLIB_TEST( srv.got_connections); + } + +// ---------------------------------------------------------------------------------------- + + + class sockets_tester : public tester + { + public: + sockets_tester ( + ) : + tester ("test_sockets", + "Runs tests on the sockets component.") + {} + + void perform_test ( + ) + { + sockets_test(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/sockets2.cpp b/ml/dlib/dlib/test/sockets2.cpp new file mode 100644 index 000000000..3521e751d --- /dev/null +++ b/ml/dlib/dlib/test/sockets2.cpp @@ -0,0 +1,204 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include + +#include "tester.h" +#include +#include +#include + +// This is called an unnamed-namespace and it has the effect of making everything +// inside this file "private" so that everything you declare will have static linkage. +// Thus we won't have any multiply defined symbol errors coming out of the linker when +// we try to compile the test suite. +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + // Declare the logger we will use in this test. The name of the logger + // should start with "test." + dlib::logger dlog("test.sockets2"); + + + class sockets2_tester : public tester, private multithreaded_object + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + + short port_num; + string data_to_send; + + bool test_failed; + + void write_thread ( + ) + { + try + { + std::unique_ptr con(connect("127.0.0.1", port_num)); + + // Send a copy of the data down the connection so we can test our the read() function + // that uses timeouts in the main thread. + if (con->write(data_to_send.data(), data_to_send.size()) != (int)data_to_send.size()) + { + test_failed = true; + dlog << LERROR << "failed to send all the data down the connection"; + } + + close_gracefully(con,300000); + } + catch (exception& e) + { + test_failed = true; + dlog << LERROR << e.what(); + } + } + + void no_write_thread ( + ) + { + try + { + std::unique_ptr con(connect("127.0.0.1", port_num)); + + // just do nothing until the connection closes + char ch; + con->read(&ch, 1); + dlog << LDEBUG << "silent connection finally closing"; + } + catch (exception& e) + { + test_failed = true; + dlog << LERROR << e.what(); + } + } + + public: + sockets2_tester ( + ) : + tester ( + "test_sockets2", // the command line argument name for this test + "Run sockets2 tests.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + register_thread(*this, &sockets2_tester::write_thread); + register_thread(*this, &sockets2_tester::write_thread); + register_thread(*this, &sockets2_tester::write_thread); + register_thread(*this, &sockets2_tester::write_thread); + register_thread(*this, &sockets2_tester::write_thread); + register_thread(*this, &sockets2_tester::no_write_thread); + } + + void perform_test ( + ) + { + run_tests(0); + run_tests(40); + } + + void run_tests ( + unsigned long timeout_to_use + ) + { + // make sure there aren't any threads running + wait(); + + port_num = 5000; + test_failed = false; + + print_spinner(); + data_to_send = "oi 2m3ormao2m fo2im3fo23mi o2mi3 foa2m3fao23ifm2o3fmia23oima23iom3giugbiua"; + // make the block of data much larger + for (int i = 0; i < 11; ++i) + data_to_send = data_to_send + data_to_send; + + dlog << LINFO << "data block size: " << data_to_send.size(); + + + std::unique_ptr list; + DLIB_TEST(create_listener(list, port_num, "127.0.0.1") == 0); + DLIB_TEST(bool(list)); + + // kick off the sending threads + start(); + + + dlib::array > cons; + std::vector bytes_received(6,0); + std::unique_ptr con_temp; + + // accept the 6 connections we should get + for (int i = 0; i < 6; ++i) + { + DLIB_TEST(list->accept(con_temp) == 0); + cons.push_back(con_temp); + print_spinner(); + } + + int finished_cons = 0; + + // now receive all the bytes from the sending threads + while (finished_cons < 5) + { + for (unsigned long i = 0; i < cons.size(); ++i) + { + if (cons[i]) + { + const int buf_size = 3000; + char buf[buf_size]; + + int status = cons[i]->read(buf, buf_size, timeout_to_use); + + if (status > 0) + { + DLIB_TEST(equal(buf, buf+status, data_to_send.begin()+bytes_received[i])); + bytes_received[i] += status; + } + else if (status == 0) + { + // the connection is closed to kill it + cons[i].reset(); + ++finished_cons; + } + } + } + print_spinner(); + } + + for (unsigned long i = 0; i < bytes_received.size(); ++i) + { + DLIB_TEST(bytes_received[i] == (long)data_to_send.size() || cons[i]); + } + + + dlog << LINFO << "All data received correctly"; + + cons.clear(); + + + print_spinner(); + + DLIB_TEST(test_failed == false); + + + // wait for all the sending threads to terminate + wait(); + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + sockets2_tester a; + +} + + diff --git a/ml/dlib/dlib/test/sockstreambuf.cpp b/ml/dlib/dlib/test/sockstreambuf.cpp new file mode 100644 index 000000000..519feb2b5 --- /dev/null +++ b/ml/dlib/dlib/test/sockstreambuf.cpp @@ -0,0 +1,253 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + dlib::mutex m; + dlib::signaler s(m); + bool thread_running; + + logger dlog("test.sockstreambuf"); + +// ---------------------------------------------------------------------------------------- + + template + struct thread_proc_struct + { + static void thread_proc ( + void* param + ) + { + + listener& list = *static_cast(param); + connection* con; + list.accept(con); + + ssb buf(con); + ostream out(&buf); + + + char ch; + char* bigbuf = new char[1000000]; + + + for (int i = 'a'; i < 'z'; ++i) + { + ch = i; + out << ch << " "; + } + + out.put('A'); + + for (int i = 0; i < 256; ++i) + { + ch = i; + out.write(&ch,1); + } + + for (int i = -100; i < 25600; ++i) + { + out << i << " "; + } + + out.put('A'); + + for (int i = -100; i < 25600; ++i) + { + out.write((char*)&i,sizeof(i)); + } + + for (int i = 0; i < 1000000; ++i) + { + bigbuf[i] = (i&0xFF); + } + out.write(bigbuf,1000000); + + out.put('d'); + out.put('a'); + out.put('v'); + out.put('i'); + out.put('s'); + + + string tstring = "this is a test"; + int tint = -853; + unsigned int tuint = 89; + serialize(tstring,out); + serialize(tint,out); + serialize(tuint,out); + + + out.flush(); + + + auto_mutex M(m); + thread_running = false; + s.signal(); + + dlib::sleep(300); + delete con; + delete &list; + + delete [] bigbuf; + } + }; + + template + void sockstreambuf_test ( + ) + /*! + requires + - ssb is an implementation of sockstreambuf/sockstreambuf_kernel_abstract.h + ensures + - runs tests on ssb for compliance with the specs + !*/ + { + char ch; + vector vbuf; + vbuf.resize(1000000); + char* bigbuf = &vbuf[0]; + connection* con; + + print_spinner(); + thread_running = true; + listener* list; + if (create_listener(list,0)) + { + DLIB_TEST_MSG(false, "Unable to create a listener"); + } + + create_new_thread(&thread_proc_struct::thread_proc,list); + + if (create_connection(con,list->get_listening_port(),"127.0.0.1")) + { + DLIB_TEST_MSG(false, "Unable to create a connection"); + } + + // make sure con gets deleted + std::unique_ptr del_con(con); + + ssb buf(con); + istream in(&buf); + + + + for (int i = 'a'; i < 'z'; ++i) + { + in >> ch; + char c = i; + DLIB_TEST_MSG(ch == c,"ch: " << (int)ch << " c: " << (int)c); + } + + in.get(); + DLIB_TEST_MSG(in.peek() == 'A', "*" << in.peek() << "*"); + in.get(); + + for (int i = 0; i < 256; ++i) + { + in.read(&ch,1); + char c = i; + DLIB_TEST_MSG(ch == c,"ch: " << (int)ch << " c: " << (int)c ); + } + + for (int i = -100; i < 25600; ++i) + { + int n = 0; + in >> n; + DLIB_TEST_MSG(n == i,"n: " << n << " i:" << i); + } + + in.get(); + DLIB_TEST_MSG(in.peek() == 'A', "*" << in.peek() << "*"); + in.get(); + + for (int i = -100; i < 25600; ++i) + { + int n; + in.read((char*)&n,sizeof(n)); + DLIB_TEST_MSG(n == i,"n: " << n << " i:" << i); + } + + in.read(bigbuf,1000000); + for (int i = 0; i < 1000000; ++i) + { + DLIB_TEST(bigbuf[i] == (char)(i&0xFF)); + } + + DLIB_TEST(in.get() == 'd'); + DLIB_TEST(in.get() == 'a'); + DLIB_TEST(in.get() == 'v'); + DLIB_TEST(in.get() == 'i'); + + DLIB_TEST(in.peek() == 's'); + + DLIB_TEST(in.get() == 's'); + + in.putback('s'); + DLIB_TEST(in.peek() == 's'); + + DLIB_TEST(in.get() == 's'); + + + string tstring; + int tint; + unsigned int tuint; + deserialize(tstring,in); + deserialize(tint,in); + deserialize(tuint,in); + + DLIB_TEST(tstring == "this is a test"); + DLIB_TEST(tint == -853); + DLIB_TEST(tuint == 89); + + + + auto_mutex M(m); + while (thread_running) + s.wait(); + + } + +// ---------------------------------------------------------------------------------------- + + + class sockstreambuf_tester : public tester + { + public: + sockstreambuf_tester ( + ) : + tester ("test_sockstreambuf", + "Runs tests on the sockstreambuf component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing sockstreambuf"; + sockstreambuf_test(); + dlog << LINFO << "testing sockstreambuf_unbuffered"; + sockstreambuf_test(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/sparse_vector.cpp b/ml/dlib/dlib/test/sparse_vector.cpp new file mode 100644 index 000000000..97b60b7b3 --- /dev/null +++ b/ml/dlib/dlib/test/sparse_vector.cpp @@ -0,0 +1,301 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include "tester.h" +#include +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.sparse_vector"); + + void test_sparse_matrix_vector_multiplies() + { + dlib::rand rnd; + + const long size = 30; + + for (int iter = 0; iter < 10; ++iter) + { + print_spinner(); + + std::vector edges; + std::vector oedges; + matrix M(size,size); + M = 0; + for (long i = 0; i < M.size()/3; ++i) + { + const long r = rnd.get_random_32bit_number()%M.nr(); + const long c = rnd.get_random_32bit_number()%M.nc(); + const double d = rnd.get_random_gaussian()*10; + M(r,c) += d; + oedges.push_back(ordered_sample_pair(r,c,d)); + } + + matrix SM(size,size); + SM = 0; + for (long i = 0; i < SM.size()/3; ++i) + { + const long r = rnd.get_random_32bit_number()%SM.nr(); + const long c = rnd.get_random_32bit_number()%SM.nc(); + const double d = rnd.get_random_gaussian()*10; + SM(r,c) += d; + if (r != c) + SM(c,r) += d; + edges.push_back(sample_pair(r,c,d)); + } + + const matrix v = randm(size,1); + + matrix result; + + sparse_matrix_vector_multiply(oedges, v, result); + DLIB_TEST_MSG(length(M*v - result) < 1e-12, length(M*v - result)); + + sparse_matrix_vector_multiply(edges, v, result); + DLIB_TEST_MSG(length(SM*v - result) < 1e-12, length(SM*v - result)); + + } + } + +// ---------------------------------------------------------------------------------------- + + void test_sparse_matrix_vector_multiply1() + { + print_spinner(); + std::map sv; + sv[2] = 8; + sv[6] = 2.3; + + matrix v; + v = 0; + v(2) = 8; + v(6) = 2.3; + + + matrix r1, r2; + + r1 = gaussian_randm(4,10)*v; + r2 = sparse_matrix_vector_multiply(gaussian_randm(4,std::numeric_limits::max()),sv); + + DLIB_TEST(max(abs(r1-r2)) < 1e-15); + } + +// ---------------------------------------------------------------------------------------- + + void test_sparse_matrix_vector_multiply2() + { + std::vector > sv; + sv.push_back(make_pair(6, 1.42)); + sv.push_back(make_pair(3, 5)); + + matrix v; + v = 0; + v(3) = 5; + v(6) = 1.42; + + + matrix r1, r2; + + r1 = gaussian_randm(3,9)*v; + r2 = sparse_matrix_vector_multiply(gaussian_randm(3,std::numeric_limits::max()),sv); + + DLIB_TEST(max(abs(r1-r2)) < 1e-15); + } + +// ---------------------------------------------------------------------------------------- + + void test_make_sparse_vector_inplace() + { + std::vector > vect; + vect.push_back(make_pair(4,1)); + vect.push_back(make_pair(0,1)); + vect.push_back(make_pair(4,1)); + vect.push_back(make_pair(3,1)); + vect.push_back(make_pair(8,1)); + vect.push_back(make_pair(8,1)); + vect.push_back(make_pair(8,1)); + vect.push_back(make_pair(8,1)); + + make_sparse_vector_inplace(vect); + + DLIB_TEST(vect.size() == 4); + DLIB_TEST(vect[0].first == 0); + DLIB_TEST(vect[1].first == 3); + DLIB_TEST(vect[2].first == 4); + DLIB_TEST(vect[3].first == 8); + + DLIB_TEST(vect[0].second == 1); + DLIB_TEST(vect[1].second == 1); + DLIB_TEST(vect[2].second == 2); + DLIB_TEST(vect[3].second == 4); + } + +// ---------------------------------------------------------------------------------------- + + class sparse_vector_tester : public tester + { + public: + sparse_vector_tester ( + ) : + tester ( + "test_sparse_vector", // the command line argument name for this test + "Run tests on the sparse_vector routines.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + + void perform_test ( + ) + { + test_make_sparse_vector_inplace(); + + std::map v; + v[4] = 8; + v[2] = -4; + v[9] = 10; + + DLIB_TEST(max(v) == 10); + DLIB_TEST(min(v) == -4); + + v.clear(); + v[4] = 8; + v[9] = 10; + DLIB_TEST(max(v) == 10); + DLIB_TEST(min(v) == 0); + + + v.clear(); + v[4] = -9; + v[9] = -4; + DLIB_TEST(max(v) == 0); + DLIB_TEST(min(v) == -9); + + + { + matrix a(2,2), b(2,2); + a = randm(2,2); + b = randm(2,2); + + DLIB_TEST(equal(a-b, subtract(a,b))); + DLIB_TEST(equal(a+b, add(a,b))); + DLIB_TEST(equal(a-(b+b), subtract(a,b+b))); + DLIB_TEST(equal(a+b+b, add(a,b+b))); + } + + { + std::map a, b, c; + a[1] = 2; + a[3] = 5; + + b[0] = 3; + b[1] = 1; + + c = add(a,b); + DLIB_TEST(c.size() == 3); + DLIB_TEST(c[0] == 3); + DLIB_TEST(c[1] == 3); + DLIB_TEST(c[3] == 5); + + c = subtract(a,b); + DLIB_TEST(c.size() == 3); + DLIB_TEST(c[0] == -3); + DLIB_TEST(c[1] == 1); + DLIB_TEST(c[3] == 5); + + c = add(b,a); + DLIB_TEST(c.size() == 3); + DLIB_TEST(c[0] == 3); + DLIB_TEST(c[1] == 3); + DLIB_TEST(c[3] == 5); + + c = subtract(b,a); + DLIB_TEST(c.size() == 3); + DLIB_TEST(c[0] == 3); + DLIB_TEST(c[1] == -1); + DLIB_TEST(c[3] == -5); + + std::vector > aa, bb, cc; + + aa.assign(a.begin(), a.end()); + bb.assign(b.begin(), b.end()); + + cc = add(aa,bb); + DLIB_TEST(cc.size() == 3); + DLIB_TEST(cc[0].first == 0); + DLIB_TEST(cc[1].first == 1); + DLIB_TEST(cc[2].first == 3); + DLIB_TEST(cc[0].second == 3); + DLIB_TEST(cc[1].second == 3); + DLIB_TEST(cc[2].second == 5); + + cc = subtract(aa,bb); + DLIB_TEST(cc.size() == 3); + DLIB_TEST(cc[0].first == 0); + DLIB_TEST(cc[1].first == 1); + DLIB_TEST(cc[2].first == 3); + DLIB_TEST(cc[0].second == -3); + DLIB_TEST(cc[1].second == 1); + DLIB_TEST(cc[2].second == 5); + + cc = add(bb,aa); + DLIB_TEST(cc.size() == 3); + DLIB_TEST(cc[0].first == 0); + DLIB_TEST(cc[1].first == 1); + DLIB_TEST(cc[2].first == 3); + DLIB_TEST(cc[0].second == 3); + DLIB_TEST(cc[1].second == 3); + DLIB_TEST(cc[2].second == 5); + + cc = subtract(bb,aa); + DLIB_TEST(cc.size() == 3); + DLIB_TEST(cc[0].first == 0); + DLIB_TEST(cc[1].first == 1); + DLIB_TEST(cc[2].first == 3); + DLIB_TEST(cc[0].second == 3); + DLIB_TEST(cc[1].second == -1); + DLIB_TEST(cc[2].second == -5); + + } + + test_sparse_matrix_vector_multiplies(); + test_sparse_matrix_vector_multiply1(); + test_sparse_matrix_vector_multiply2(); + + { + matrix a, b; + a = gaussian_randm(6,1, 0); + b = gaussian_randm(6,1, 1); + + std::vector > aa, bb; + + assign(aa, a); + assign(bb, b); + + // dot() does something special when the sparse vectors have entries for + // each dimension, which is what happens when they are copied from dense + // vectors. So the point of the tests in this block is to make sure dot() + // works right in this case. + DLIB_TEST(std::abs(dot(a,b) - dot(aa,bb)) < 1e-14); + a(3) = 0; + assign(aa, a); + DLIB_TEST(std::abs(dot(a,b) - dot(aa,bb)) < 1e-14); + } + } + }; + + sparse_vector_tester a; + +} + + + diff --git a/ml/dlib/dlib/test/stack.cpp b/ml/dlib/dlib/test/stack.cpp new file mode 100644 index 000000000..0b92eaeb9 --- /dev/null +++ b/ml/dlib/dlib/test/stack.cpp @@ -0,0 +1,294 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.stack"); + + template < + typename stack + > + void stack_kernel_test ( + ) + /*! + requires + - stack is an implementation of stack/stack_sort_abstract.h + stack is instantiated with int + ensures + - runs tests on stack for compliance with the specs + !*/ + { + + + srand(static_cast(time(0))); + + print_spinner(); + + stack a1, a2; + + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + swap(a1,a2); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.move_next() == false); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.size() == 0); + + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + a1.reset(); + a2.reset(); + + for (unsigned long k = 0; k < 4; ++k) + { + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + swap(a1,a2); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.move_next() == false); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.size() == 0); + + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + a1.clear(); + a2.clear(); + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + swap(a1,a2); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.move_next() == false); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.size() == 0); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a2.size() == 0); + + + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start()); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.move_next() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a1.at_start() == false); + DLIB_TEST(a1.size() == 0); + + a1.clear(); + a2.clear(); + + + for (unsigned long i = 0; i < 100; ++i) + { + int a = (int)i; + a1.push(a); + } + + DLIB_TEST(a1.size() == 100); + + int count = 99; + while (a1.move_next()) + { + DLIB_TEST_MSG(a1.element() == count,a1.element() << " : " << count); + --count; + } + + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == false); + + a1.swap(a2); + + count = 99; + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.at_start() == false); + DLIB_TEST(a1.current_element_valid() == false); + DLIB_TEST(a1.at_start() == true); + + DLIB_TEST(a1.size() == 0); + DLIB_TEST(a2.size() == 100); + DLIB_TEST(a2.current() == 99); + + a2.reset(); + while (a2.move_next()) + { + DLIB_TEST(a2.element() == count--); + } + + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.at_start() == false); + int b = 4; + a2.push(b); + DLIB_TEST(a2.current_element_valid() == false); + DLIB_TEST(a2.at_start() == true); + + DLIB_TEST(a2.current() == 4); + int c = 0; + a2.pop(c); + DLIB_TEST(c == 4); + + // serialize the state of a2, then clear a2, then + // load the state back into a2. + ostringstream sout; + serialize(a2,sout); + DLIB_TEST(a2.at_start() == true); + istringstream sin(sout.str()); + a2.clear(); + deserialize(a2,sin); + + + count = 99; + while (a2.size()) + { + int a = 0; + DLIB_TEST(a2.current() == count); + DLIB_TEST(const_cast(a2).current() == count); + a2.pop(a); + DLIB_TEST(a == count--); + } + + + + + + + a1.clear(); + a2.clear(); + } + + + { + a1.clear(); + remover& go = a1; + for (int i = 0; i < 100; ++i) + { + int a = 3; + a1.push(a); + } + DLIB_TEST(go.size() == 100); + for (int i = 0; i < 100; ++i) + { + int a = 9; + a1.remove_any(a); + DLIB_TEST(a == 3); + } + DLIB_TEST(go.size() == 0); + } + + } + + + + + class stack_tester : public tester + { + public: + stack_tester ( + ) : + tester ("test_stack", + "Runs tests on the stack component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + stack_kernel_test::kernel_1a> (); + dlog << LINFO << "testing kernel_1a_c"; + stack_kernel_test::kernel_1a_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/static_map.cpp b/ml/dlib/dlib/test/static_map.cpp new file mode 100644 index 000000000..931ae1fae --- /dev/null +++ b/ml/dlib/dlib/test/static_map.cpp @@ -0,0 +1,323 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include + +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.static_map"); + + template < + typename map + > + void static_map_kernel_test ( + ) + /*! + requires + - map is an implementation of static_map/static_map_kernel_abstract.h and + is instantiated to map int to int + ensures + - runs tests on map for compliance with the specs + !*/ + { + + print_spinner(); + srand(static_cast(time(0))); + + typedef binary_search_tree::kernel_2a_c bst; + typedef hash_table::kernel_1a_c ht; + + const unsigned long table_4_max_size = 100; + const unsigned long tree_max_size = 50000; + ht table_4(4); + ht table_8(8); + bst tree; + + ht table_4b(4); + ht table_8b(8); + bst treeb; + + + // just do the following to make sure operator[] doesn't hang + // under some instances + { + int g = 1, h = 1; + treeb.add(g,h); + map test; + map test2; + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test.at_start()); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.move_next() == false); + DLIB_TEST(test.current_element_valid() == false); + DLIB_TEST(test.at_start() == false); + + swap(test,test2); + DLIB_TEST(test2.at_start() == false); + DLIB_TEST(test.at_start() == true); + + swap(test,test2); + DLIB_TEST(test.at_start() == false); + + + DLIB_TEST(test.size() == 0); + DLIB_TEST(test[1] == 0); + DLIB_TEST(test[2] == 0); + DLIB_TEST(test[3] == 0); + DLIB_TEST(test[0] == 0); + + test.load(treeb); + DLIB_TEST(test.at_start()); + DLIB_TEST(test[1] != 0); + DLIB_TEST(test[2] == 0); + DLIB_TEST(test[3] == 0); + DLIB_TEST(test[0] == 0); + + test2.clear(); + swap(test2,test); + DLIB_TEST(test2[1] != 0); + DLIB_TEST(test2[2] == 0); + DLIB_TEST(test2[3] == 0); + DLIB_TEST(test2[0] == 0); + DLIB_TEST(test[1] == 0); + DLIB_TEST(test[2] == 0); + DLIB_TEST(test[3] == 0); + DLIB_TEST(test[0] == 0); + + + DLIB_TEST(treeb.size() == 0); + treeb.clear(); + } + + + for (unsigned long i = 0; i < table_4_max_size; ++i) + { + int a = ::rand()&0xFF; + int b = a + 1; + int ab = a; + int bb = b; + table_4.add(a,b); + table_4b.add(ab,bb); + } + + for (unsigned long i = 0; i < table_4_max_size; ++i) + { + int a = ::rand()&0xF; + int b = a + 1; + int ab = a; + int bb = b; + table_8.add(a,b); + table_8b.add(ab,bb); + } + + for (unsigned long i = 0; i < tree_max_size; ++i) + { + int a = ::rand()&0xFFF; + int b = a + 1; + int ab = a; + int bb = b; + tree.add(a,b); + treeb.add(ab,bb); + } + + map m_4; + m_4.load(table_4); + map m_8; + m_8.load(table_8); + map m_t; + m_t.load(tree); + map e; + e.load(table_4); + + DLIB_TEST(e.size() == 0); + DLIB_TEST(e.at_start() == true); + DLIB_TEST(e.current_element_valid() == false); + DLIB_TEST(e.move_next() == false); + DLIB_TEST(e.at_start() == false); + DLIB_TEST(e.current_element_valid() == false); + + DLIB_TEST(m_4.size() == table_4b.size()); + DLIB_TEST(m_8.size() == table_8b.size()); + DLIB_TEST(m_t.size() == treeb.size()); + + DLIB_TEST(m_4.at_start() == true); + DLIB_TEST(m_8.at_start() == true); + DLIB_TEST(m_t.at_start() == true); + DLIB_TEST(m_4.current_element_valid() == false); + DLIB_TEST(m_8.current_element_valid() == false); + DLIB_TEST(m_t.current_element_valid() == false); + + + DLIB_TEST(m_4.move_next() == true); + DLIB_TEST(m_4.at_start() == false); + DLIB_TEST(m_4.current_element_valid() == true); + DLIB_TEST(m_8.move_next() == true); + DLIB_TEST(m_8.at_start() == false); + DLIB_TEST(m_8.current_element_valid() == true); + DLIB_TEST(m_t.move_next() == true); + DLIB_TEST(m_t.at_start() == false); + DLIB_TEST(m_t.current_element_valid() == true); + + m_4.reset(); + m_8.reset(); + m_t.reset(); + + while (m_4.move_next()) + { + DLIB_TEST( table_4b[m_4.element().key()] != 0); + DLIB_TEST( *table_4b[m_4.element().key()] == m_4.element().value()); + } + + // serialize the state of m_4, then clear m_4, then + // load the state back into m_4. + ostringstream sout; + serialize(m_4,sout); + DLIB_TEST(m_4.at_start() == true); + istringstream sin(sout.str()); + m_4.clear(); + deserialize(m_4,sin); + DLIB_TEST(m_4.at_start() == true); + + + + while (table_4b.move_next()) + { + DLIB_TEST( m_4[table_4b.element().key()] != 0); + DLIB_TEST( *m_4[table_4b.element().key()] == table_4b.element().value()); + } + + // serialize the state of m_8, then clear m_8, then + // load the state back into m_8. + sout.str(""); + serialize(m_8,sout); + DLIB_TEST(m_8.at_start() == true); + sin.str(sout.str()); + m_8.clear(); + deserialize(m_8,sin); + DLIB_TEST(m_8.at_start() == true); + + while (m_8.move_next()) + { + DLIB_TEST( table_8b[m_8.element().key()] != 0); + DLIB_TEST( *table_8b[m_8.element().key()] == m_8.element().value()); + } + + while (table_8b.move_next()) + { + DLIB_TEST( m_8[table_8b.element().key()] != 0); + DLIB_TEST( *m_8[table_8b.element().key()] == table_8b.element().value()); + } + + + while (m_t.move_next()) + { + DLIB_TEST( treeb[m_t.element().key()] != 0); + DLIB_TEST( *treeb[m_t.element().key()] == m_t.element().value()); + } + + // make sure operator[] doesn't hang + for (int l = 1; l < 10000; ++l) + { + DLIB_TEST(m_t[l+0xFFF] == 0); + } + + while (treeb.move_next()) + { + DLIB_TEST( m_t[treeb.element().key()] != 0); + DLIB_TEST( *m_t[treeb.element().key()] == treeb.element().value()); + } + + + + m_4.reset(); + m_8.reset(); + m_t.reset(); + + int last = 0; + while (m_4.move_next()) + { + DLIB_TEST(last <= m_4.element().key()); + DLIB_TEST(m_4.element().key() + 1 == m_4.element().value()); + last = m_4.element().key(); + } + + last = 0; + while (m_8.move_next()) + { + DLIB_TEST(last <= m_8.element().key()); + DLIB_TEST(m_8.element().key() + 1 == m_8.element().value()); + last = m_8.element().key(); + } + + last = 0; + while (m_t.move_next()) + { + DLIB_TEST(last <= m_t.element().key()); + DLIB_TEST(m_t.element().key() + 1 == m_t.element().value()); + last = m_t.element().key(); + } + + + + + + + // this is just to test swap + m_4.swap(m_8); + m_4.reset(); + table_4b.reset(); + while (m_8.move_next()) + { + DLIB_TEST( table_4b[m_8.element().key()] != 0); + DLIB_TEST( *table_4b[m_8.element().key()] == m_8.element().value()); + } + + while (table_4b.move_next()) + { + DLIB_TEST( m_8[table_4b.element().key()] != 0); + DLIB_TEST( *m_8[table_4b.element().key()] == table_4b.element().value()); + } + + } + + + + + + class static_map_tester : public tester + { + public: + static_map_tester ( + ) : + tester ("test_static_map", + "Runs tests on the static_map component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + static_map_kernel_test::kernel_1a> (); + dlog << LINFO << "testing kernel_1a_c"; + static_map_kernel_test::kernel_1a_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/static_set.cpp b/ml/dlib/dlib/test/static_set.cpp new file mode 100644 index 000000000..0ad864e4a --- /dev/null +++ b/ml/dlib/dlib/test/static_set.cpp @@ -0,0 +1,206 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.static_set"); + + template < + typename set + > + void static_set_kernel_test ( + ) + /*! + requires + - set is an implementation of static_set/static_set_kernel_abstract.h and + is instantiated to hold ints + ensures + - runs tests on set for compliance with the specs + !*/ + { + + print_spinner(); + + srand(static_cast(time(0))); + + typedef queue::kernel_2a_c queue_of_int; + typedef dlib::set::kernel_1a_c set_of_int; + + queue_of_int q, qb, qc; + set_of_int ds; + + set S; + S.load(ds); + + for (int k = 1; k < 1000; ++k) + { + q.clear(); + qb.clear(); + qc.clear(); + unsigned long num = k; + for (unsigned long i = 0; i < num; ++i) + { + int a = ::rand()&0xFF; + int b = a; + int c = a; + q.enqueue(a); + qb.enqueue(b); + qc.enqueue(c); + } + + + + set s; + + DLIB_TEST(s.size() == 0); + DLIB_TEST(s.at_start()); + DLIB_TEST(s.current_element_valid() == false); + DLIB_TEST(s.move_next() == false); + DLIB_TEST(s.current_element_valid() == false); + DLIB_TEST(s.at_start() == false); + + s.load(q); + DLIB_TEST(s.at_start()); + set se; + se.load(q); + + DLIB_TEST(se.size() == 0); + DLIB_TEST(se.at_start() == true); + DLIB_TEST(se.current_element_valid() == false); + DLIB_TEST(se.move_next() == false); + DLIB_TEST(se.at_start() == false); + DLIB_TEST(se.current_element_valid() == false); + + + DLIB_TEST(s.size() == qb.size()); + DLIB_TEST(s.at_start() == true); + DLIB_TEST(s.current_element_valid() == false); + DLIB_TEST(s.move_next() == true); + DLIB_TEST(s.at_start() == false); + DLIB_TEST(s.current_element_valid() == true); + s.reset(); + se.reset(); + + swap(se,s); + + DLIB_TEST(s.size() == 0); + DLIB_TEST(s.at_start() == true); + DLIB_TEST(s.current_element_valid() == false); + DLIB_TEST(s.move_next() == false); + DLIB_TEST(s.at_start() == false); + DLIB_TEST(s.current_element_valid() == false); + + DLIB_TEST(se.size() == qb.size()); + DLIB_TEST(se.at_start() == true); + DLIB_TEST(se.current_element_valid() == false); + DLIB_TEST(se.move_next() == true); + DLIB_TEST(se.at_start() == false); + DLIB_TEST(se.current_element_valid() == true); + s.reset(); + se.reset(); + + swap(se,s); + + + + int last = 0; + while (s.move_next()) + { + DLIB_TEST(last <= s.element()); + last = s.element(); + } + + + + while (qb.move_next()) + { + int a; + qb.dequeue(a); + DLIB_TEST(s.is_member(a)); + DLIB_TEST(!se.is_member(a)); + + // make sure is_member() doesn't hang + for (int l = 0; l < 100; ++l) + { + int a = ::rand(); + s.is_member(a); + } + } + + swap(s,se); + + // serialize the state of se, then clear se, then + // load the state back into se. + ostringstream sout; + serialize(se,sout); + DLIB_TEST(se.at_start() == true); + istringstream sin(sout.str()); + se.clear(); + deserialize(se,sin); + DLIB_TEST(se.at_start() == true); + + + last = 0; + while (se.move_next()) + { + DLIB_TEST(last <= se.element()); + last = se.element(); + } + + + DLIB_TEST(s.size() == 0); + DLIB_TEST(se.size() == qc.size()); + + while (qc.move_next()) + { + int a; + qc.dequeue(a); + DLIB_TEST(se.is_member(a)); + DLIB_TEST(!s.is_member(a)); + } + + + } + } + + + + + + class static_set_tester : public tester + { + public: + static_set_tester ( + ) : + tester ("test_static_set", + "Runs tests on the static_set component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + static_set_kernel_test::kernel_1a> (); + dlog << LINFO << "testing kernel_1a_c"; + static_set_kernel_test::kernel_1a_c>(); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/statistics.cpp b/ml/dlib/dlib/test/statistics.cpp new file mode 100644 index 000000000..0394286ad --- /dev/null +++ b/ml/dlib/dlib/test/statistics.cpp @@ -0,0 +1,915 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.statistics"); + + + + class statistics_tester : public tester + { + public: + statistics_tester ( + ) : + tester ("test_statistics", + "Runs tests on the statistics component.") + {} + + void test_random_subset_selector () + { + random_subset_selector rand_set; + + for (int j = 0; j < 30; ++j) + { + print_spinner(); + + running_stats rs, rs2; + + rand_set.set_max_size(1000); + + for (double i = 0; i < 100000; ++i) + { + rs.add(i); + rand_set.add(i); + } + + + for (unsigned long i = 0; i < rand_set.size(); ++i) + rs2.add(rand_set[i]); + + + dlog << LDEBUG << "true mean: " << rs.mean(); + dlog << LDEBUG << "true sampled: " << rs2.mean(); + double ratio = rs.mean()/rs2.mean(); + DLIB_TEST_MSG(0.96 < ratio && ratio < 1.04, " ratio: " << ratio); + } + + + { + random_subset_selector r1, r2; + r1.set_max_size(300); + for (int i = 0; i < 4000; ++i) + r1.add(i); + + ostringstream sout; + serialize(r1, sout); + istringstream sin(sout.str()); + deserialize(r2, sin); + + DLIB_TEST(r1.size() == r2.size()); + DLIB_TEST(r1.max_size() == r2.max_size()); + DLIB_TEST(r1.next_add_accepts() == r2.next_add_accepts()); + DLIB_TEST(std::equal(r1.begin(), r1.end(), r2.begin())); + + for (int i = 0; i < 4000; ++i) + { + r1.add(i); + r2.add(i); + } + + DLIB_TEST(r1.size() == r2.size()); + DLIB_TEST(r1.max_size() == r2.max_size()); + DLIB_TEST(r1.next_add_accepts() == r2.next_add_accepts()); + DLIB_TEST(std::equal(r1.begin(), r1.end(), r2.begin())); + } + } + + void test_random_subset_selector2 () + { + random_subset_selector rand_set; + DLIB_TEST(rand_set.next_add_accepts() == false); + DLIB_TEST(rand_set.size() == 0); + DLIB_TEST(rand_set.max_size() == 0); + + for (int j = 0; j < 30; ++j) + { + print_spinner(); + + running_stats rs, rs2; + + rand_set.set_max_size(1000); + DLIB_TEST(rand_set.next_add_accepts() == true); + + for (double i = 0; i < 100000; ++i) + { + rs.add(i); + if (rand_set.next_add_accepts()) + rand_set.add(i); + else + rand_set.add(); + } + + DLIB_TEST(rand_set.size() == 1000); + DLIB_TEST(rand_set.max_size() == 1000); + + for (unsigned long i = 0; i < rand_set.size(); ++i) + rs2.add(rand_set[i]); + + + dlog << LDEBUG << "true mean: " << rs.mean(); + dlog << LDEBUG << "true sampled: " << rs2.mean(); + double ratio = rs.mean()/rs2.mean(); + DLIB_TEST_MSG(0.96 < ratio && ratio < 1.04, " ratio: " << ratio); + } + } + + void test_running_cross_covariance () + { + running_cross_covariance > rcc1, rcc2; + + matrix xm, ym; + const int num = 40; + + dlib::rand rnd; + for (int i = 0; i < num; ++i) + { + matrix x = randm(4,1,rnd); + matrix y = randm(4,1,rnd); + + xm += x/num; + ym += y/num; + + if (i < 15) + rcc1.add(x,y); + else + rcc2.add(x,y); + } + + rnd.clear(); + matrix cov; + for (int i = 0; i < num; ++i) + { + matrix x = randm(4,1,rnd); + matrix y = randm(4,1,rnd); + cov += (x-xm)*trans(y-ym); + } + cov /= num-1; + + running_cross_covariance > rcc = rcc1 + rcc2; + DLIB_TEST(max(abs(rcc.covariance_xy()-cov)) < 1e-14); + DLIB_TEST(max(abs(rcc.mean_x()-xm)) < 1e-14); + DLIB_TEST(max(abs(rcc.mean_y()-ym)) < 1e-14); + } + + std::map dense_to_sparse ( + const matrix& x + ) + { + std::map temp; + for (long i = 0; i < x.size(); ++i) + temp[i] = x(i); + return temp; + } + + void test_running_cross_covariance_sparse() + { + running_cross_covariance > rcc1, rcc2; + + running_covariance > rc1, rc2; + + matrix xm, ym; + const int num = 40; + + rc1.set_dimension(4); + rc2.set_dimension(4); + + rcc1.set_dimensions(4,5); + rcc2.set_dimensions(4,5); + + dlib::rand rnd; + for (int i = 0; i < num; ++i) + { + matrix x = randm(4,1,rnd); + matrix y = randm(5,1,rnd); + + xm += x/num; + ym += y/num; + + if (i < 15) + { + rcc1.add(x,dense_to_sparse(y)); + rc1.add(x); + } + else if (i < 30) + { + rcc2.add(dense_to_sparse(x),y); + rc2.add(dense_to_sparse(x)); + } + else + { + rcc2.add(dense_to_sparse(x),dense_to_sparse(y)); + rc2.add(x); + } + } + + rnd.clear(); + matrix cov, cov2; + for (int i = 0; i < num; ++i) + { + matrix x = randm(4,1,rnd); + matrix y = randm(5,1,rnd); + cov += (x-xm)*trans(y-ym); + cov2 += (x-xm)*trans(x-xm); + } + cov /= num-1; + cov2 /= num-1; + + running_cross_covariance > rcc = rcc1 + rcc2; + DLIB_TEST_MSG(max(abs(rcc.covariance_xy()-cov)) < 1e-14, max(abs(rcc.covariance_xy()-cov))); + DLIB_TEST(max(abs(rcc.mean_x()-xm)) < 1e-14); + DLIB_TEST(max(abs(rcc.mean_y()-ym)) < 1e-14); + + running_covariance > rc = rc1 + rc2; + DLIB_TEST(max(abs(rc.covariance()-cov2)) < 1e-14); + DLIB_TEST(max(abs(rc.mean()-xm)) < 1e-14); + } + + void test_running_covariance ( + ) + { + dlib::rand rnd; + std::vector > vects; + + running_covariance > cov, cov2; + DLIB_TEST(cov.in_vector_size() == 0); + + for (unsigned long dims = 1; dims < 5; ++dims) + { + for (unsigned long samps = 2; samps < 10; ++samps) + { + vects.clear(); + cov.clear(); + DLIB_TEST(cov.in_vector_size() == 0); + for (unsigned long i = 0; i < samps; ++i) + { + vects.push_back(randm(dims,1,rnd)); + cov.add(vects.back()); + + } + DLIB_TEST(cov.in_vector_size() == (long)dims); + + DLIB_TEST(equal(mean(mat(vects)), cov.mean())); + DLIB_TEST_MSG(equal(covariance(mat(vects)), cov.covariance()), + max(abs(covariance(mat(vects)) - cov.covariance())) + << " dims = " << dims << " samps = " << samps + ); + } + } + + for (unsigned long dims = 1; dims < 5; ++dims) + { + for (unsigned long samps = 2; samps < 10; ++samps) + { + vects.clear(); + cov.clear(); + cov2.clear(); + DLIB_TEST(cov.in_vector_size() == 0); + for (unsigned long i = 0; i < samps; ++i) + { + vects.push_back(randm(dims,1,rnd)); + if ((i%2) == 0) + cov.add(vects.back()); + else + cov2.add(vects.back()); + + } + DLIB_TEST((cov+cov2).in_vector_size() == (long)dims); + + DLIB_TEST(equal(mean(mat(vects)), (cov+cov2).mean())); + DLIB_TEST_MSG(equal(covariance(mat(vects)), (cov+cov2).covariance()), + max(abs(covariance(mat(vects)) - (cov+cov2).covariance())) + << " dims = " << dims << " samps = " << samps + ); + } + } + + } + + void test_running_stats() + { + print_spinner(); + + running_stats rs, rs2; + + running_scalar_covariance rsc1, rsc2; + running_scalar_covariance_decayed rscd1(1000000), rscd2(1000000); + + for (double i = 0; i < 100; ++i) + { + rs.add(i); + + rsc1.add(i,i); + rsc2.add(i,i); + rsc2.add(i,-i); + + rscd1.add(i,i); + rscd2.add(i,i); + rscd2.add(i,-i); + } + + // make sure the running_stats and running_scalar_covariance agree + DLIB_TEST_MSG(std::abs(rs.mean() - rsc1.mean_x()) < 1e-10, std::abs(rs.mean() - rsc1.mean_x())); + DLIB_TEST(std::abs(rs.mean() - rsc1.mean_y()) < 1e-10); + DLIB_TEST(std::abs(rs.stddev() - rsc1.stddev_x()) < 1e-10); + DLIB_TEST(std::abs(rs.stddev() - rsc1.stddev_y()) < 1e-10); + DLIB_TEST(std::abs(rs.variance() - rsc1.variance_x()) < 1e-10); + DLIB_TEST(std::abs(rs.variance() - rsc1.variance_y()) < 1e-10); + DLIB_TEST(rs.current_n() == rsc1.current_n()); + + DLIB_TEST(std::abs(rsc1.correlation() - 1) < 1e-10); + DLIB_TEST(std::abs(rsc2.correlation() - 0) < 1e-10); + + + DLIB_TEST_MSG(std::abs(rs.mean() - rscd1.mean_x()) < 1e-2, std::abs(rs.mean() - rscd1.mean_x()) << " " << rscd1.mean_x()); + DLIB_TEST(std::abs(rs.mean() - rscd1.mean_y()) < 1e-2); + DLIB_TEST_MSG(std::abs(rs.stddev() - rscd1.stddev_x()) < 1e-2, std::abs(rs.stddev() - rscd1.stddev_x())); + DLIB_TEST(std::abs(rs.stddev() - rscd1.stddev_y()) < 1e-2); + DLIB_TEST_MSG(std::abs(rs.variance() - rscd1.variance_x()) < 1e-2, std::abs(rs.variance() - rscd1.variance_x())); + DLIB_TEST(std::abs(rs.variance() - rscd1.variance_y()) < 1e-2); + DLIB_TEST(std::abs(rscd1.correlation() - 1) < 1e-2); + DLIB_TEST(std::abs(rscd2.correlation() - 0) < 1e-2); + + + + // test serialization of running_stats + ostringstream sout; + serialize(rs, sout); + istringstream sin(sout.str()); + deserialize(rs2, sin); + // make sure the running_stats and running_scalar_covariance agree + DLIB_TEST_MSG(std::abs(rs2.mean() - rsc1.mean_x()) < 1e-10, std::abs(rs2.mean() - rsc1.mean_x())); + DLIB_TEST(std::abs(rs2.mean() - rsc1.mean_y()) < 1e-10); + DLIB_TEST(std::abs(rs2.stddev() - rsc1.stddev_x()) < 1e-10); + DLIB_TEST(std::abs(rs2.stddev() - rsc1.stddev_y()) < 1e-10); + DLIB_TEST(std::abs(rs2.variance() - rsc1.variance_x()) < 1e-10); + DLIB_TEST(std::abs(rs2.variance() - rsc1.variance_y()) < 1e-10); + DLIB_TEST(rs2.current_n() == rsc1.current_n()); + + rsc1.clear(); + rsc1.add(1, -1); + rsc1.add(0, 0); + rsc1.add(1, -1); + rsc1.add(0, 0); + rsc1.add(1, -1); + rsc1.add(0, 0); + + DLIB_TEST(std::abs(rsc1.covariance() - -0.3) < 1e-10); + } + + void test_skewness_and_kurtosis_1() + { + + dlib::rand rnum; + running_stats rs1; + + double tp = 0; + + rnum.set_seed("DlibRocks"); + + for(int i = 0; i< 1000000; i++) + { + tp = rnum.get_random_gaussian(); + rs1.add(tp); + } + + // check the unbiased skewness and excess kurtosis of one million Gaussian + // draws are both near_vects zero. + DLIB_TEST(abs(rs1.skewness()) < 0.1); + DLIB_TEST(abs(rs1.ex_kurtosis()) < 0.1); + } + + void test_skewness_and_kurtosis_2() + { + + string str = "DlibRocks"; + + for(int j = 0; j<5 ; j++) + { + matrix dat; + dlib::rand rnum; + running_stats rs1; + + double tp = 0; + double n = 100000; + double xb = 0; + + double sknum = 0; + double skdenom = 0; + double unbi_skew = 0; + + double exkurnum = 0; + double exkurdenom = 0; + double unbi_exkur = 0; + + random_shuffle(str.begin(), str.end()); + rnum.set_seed(str); + + for(int i = 0; i t(15),u(15),v(15); + + for (unsigned long i = 0; i < t.size(); ++i) + { + t[i] = i; + u[i] = i+1; + v[i] = i+2; + } + randomize_samples(t,u,v); + + DLIB_TEST(t.size() == 15); + DLIB_TEST(u.size() == 15); + DLIB_TEST(v.size() == 15); + + for (unsigned long i = 0; i < t.size(); ++i) + { + const unsigned long val = t[i]; + DLIB_TEST(u[i] == val+1); + DLIB_TEST(v[i] == val+2); + } + } + void test_randomize_samples2() + { + dlib::matrix t(15),u(15),v(15); + + for (long i = 0; i < t.size(); ++i) + { + t(i) = i; + u(i) = i+1; + v(i) = i+2; + } + randomize_samples(t,u,v); + + DLIB_TEST(t.size() == 15); + DLIB_TEST(u.size() == 15); + DLIB_TEST(v.size() == 15); + + for (long i = 0; i < t.size(); ++i) + { + const long val = t(i); + DLIB_TEST(u(i) == val+1); + DLIB_TEST(v(i) == val+2); + } + } + + void another_test() + { + std::vector a; + + running_stats rs1, rs2; + + for (int i = 0; i < 10; ++i) + { + rs1.add(i); + a.push_back(i); + } + + DLIB_TEST(std::abs(variance(mat(a)) - rs1.variance()) < 1e-13); + DLIB_TEST(std::abs(stddev(mat(a)) - rs1.stddev()) < 1e-13); + DLIB_TEST(std::abs(mean(mat(a)) - rs1.mean()) < 1e-13); + + for (int i = 10; i < 20; ++i) + { + rs2.add(i); + a.push_back(i); + } + + DLIB_TEST(std::abs(variance(mat(a)) - (rs1+rs2).variance()) < 1e-13); + DLIB_TEST(std::abs(mean(mat(a)) - (rs1+rs2).mean()) < 1e-13); + DLIB_TEST((rs1+rs2).current_n() == 20); + + running_scalar_covariance rc1, rc2, rc3; + dlib::rand rnd; + for (double i = 0; i < 10; ++i) + { + const double a = i + rnd.get_random_gaussian(); + const double b = i + rnd.get_random_gaussian(); + rc1.add(a,b); + rc3.add(a,b); + } + for (double i = 11; i < 20; ++i) + { + const double a = i + rnd.get_random_gaussian(); + const double b = i + rnd.get_random_gaussian(); + rc2.add(a,b); + rc3.add(a,b); + } + + DLIB_TEST(std::abs((rc1+rc2).mean_x() - rc3.mean_x()) < 1e-13); + DLIB_TEST(std::abs((rc1+rc2).mean_y() - rc3.mean_y()) < 1e-13); + DLIB_TEST_MSG(std::abs((rc1+rc2).variance_x() - rc3.variance_x()) < 1e-13, std::abs((rc1+rc2).variance_x() - rc3.variance_x())); + DLIB_TEST(std::abs((rc1+rc2).variance_y() - rc3.variance_y()) < 1e-13); + DLIB_TEST(std::abs((rc1+rc2).covariance() - rc3.covariance()) < 1e-13); + DLIB_TEST((rc1+rc2).current_n() == rc3.current_n()); + + } + + void test_average_precision() + { + std::vector items; + DLIB_TEST(average_precision(items) == 1); + DLIB_TEST(average_precision(items,1) == 0); + + items.push_back(true); + DLIB_TEST(average_precision(items) == 1); + DLIB_TEST(std::abs(average_precision(items,1) - 0.5) < 1e-14); + + items.push_back(true); + DLIB_TEST(average_precision(items) == 1); + DLIB_TEST(std::abs(average_precision(items,1) - 2.0/3.0) < 1e-14); + + items.push_back(false); + + DLIB_TEST(average_precision(items) == 1); + DLIB_TEST(std::abs(average_precision(items,1) - 2.0/3.0) < 1e-14); + + items.push_back(true); + + DLIB_TEST(std::abs(average_precision(items) - (2.0+3.0/4.0)/3.0) < 1e-14); + + items.push_back(true); + + DLIB_TEST(std::abs(average_precision(items) - (2.0 + 4.0/5.0 + 4.0/5.0)/4.0) < 1e-14); + DLIB_TEST(std::abs(average_precision(items,1) - (2.0 + 4.0/5.0 + 4.0/5.0)/5.0) < 1e-14); + } + + + template + void check_distance_metrics ( + const std::vector >& samples + ) + { + running_stats rs; + for (unsigned long i = 0; i < samples.size(); ++i) + { + for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j) + { + const double d1 = length_squared(samples[i].anchor_vect - samples[i].near_vects[j]); + for (unsigned long k = 0; k < samples[i].far_vects.size(); ++k) + { + const double d2 = length_squared(samples[i].anchor_vect - samples[i].far_vects[k]); + rs.add(d2-d1); + } + } + } + + dlog << LINFO << "dist gap max: "<< rs.max(); + dlog << LINFO << "dist gap min: "<< rs.min(); + dlog << LINFO << "dist gap mean: "<< rs.mean(); + dlog << LINFO << "dist gap stddev: "<< rs.stddev(); + DLIB_TEST(rs.min() >= 0.99); + DLIB_TEST(rs.mean() >= 0.9999); + } + + void test_vector_normalizer_frobmetric(dlib::rand& rnd) + { + print_spinner(); + typedef matrix sample_type; + vector_normalizer_frobmetric normalizer; + + std::vector > samples; + frobmetric_training_sample samp; + + const long key = 1; + const long dims = 5; + // Lets make some two class training data. Each sample will have dims dimensions but + // only the one with index equal to key will be meaningful. In particular, if the key + // dimension is > 0 then the sample is class +1 and -1 otherwise. + + long k = 0; + for (int i = 0; i < 50; ++i) + { + samp.clear(); + samp.anchor_vect = gaussian_randm(dims,1,k++); + if (samp.anchor_vect(key) > 0) + samp.anchor_vect(key) = rnd.get_random_double() + 5; + else + samp.anchor_vect(key) = -(rnd.get_random_double() + 5); + + matrix temp; + + for (int j = 0; j < 5; ++j) + { + // Don't always put an equal number of near_vects and far_vects vectors into the + // training samples. + const int numa = rnd.get_random_32bit_number()%2 + 1; + const int numb = rnd.get_random_32bit_number()%2 + 1; + + for (int num = 0; num < numa; ++num) + { + temp = gaussian_randm(dims,1,k++); temp(key) = 0.1; + //temp = gaussian_randm(dims,1,k++); temp(key) = std::abs(temp(key)); + if (samp.anchor_vect(key) > 0) samp.near_vects.push_back(temp); + else samp.far_vects.push_back(temp); + } + + for (int num = 0; num < numb; ++num) + { + temp = gaussian_randm(dims,1,k++); temp(key) = -0.1; + //temp = gaussian_randm(dims,1,k++); temp(key) = -std::abs(temp(key)); + if (samp.anchor_vect(key) < 0) samp.near_vects.push_back(temp); + else samp.far_vects.push_back(temp); + } + } + samples.push_back(samp); + } + + normalizer.set_epsilon(0.0001); + normalizer.set_c(100); + normalizer.set_max_iterations(6000); + normalizer.train(samples); + + dlog << LINFO << "learned transform: \n" << normalizer.transform(); + + matrix total; + + for (unsigned long i = 0; i < samples.size(); ++i) + { + samples[i].anchor_vect = normalizer(samples[i].anchor_vect); + total += samples[i].anchor_vect; + for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j) + samples[i].near_vects[j] = normalizer(samples[i].near_vects[j]); + for (unsigned long j = 0; j < samples[i].far_vects.size(); ++j) + samples[i].far_vects[j] = normalizer(samples[i].far_vects[j]); + } + total /= samples.size(); + dlog << LINFO << "sample transformed means: "<< trans(total); + DLIB_TEST(length(total) < 1e-9); + check_distance_metrics(samples); + + // make sure serialization works + stringstream os; + serialize(normalizer, os); + vector_normalizer_frobmetric normalizer2; + deserialize(normalizer2, os); + DLIB_TEST(equal(normalizer.transform(), normalizer2.transform())); + DLIB_TEST(equal(normalizer.transformed_means(), normalizer2.transformed_means())); + DLIB_TEST(normalizer.in_vector_size() == normalizer2.in_vector_size()); + DLIB_TEST(normalizer.out_vector_size() == normalizer2.out_vector_size()); + DLIB_TEST(normalizer.get_max_iterations() == normalizer2.get_max_iterations()); + DLIB_TEST(std::abs(normalizer.get_c() - normalizer2.get_c()) < 1e-14); + DLIB_TEST(std::abs(normalizer.get_epsilon() - normalizer2.get_epsilon()) < 1e-14); + + } + + void prior_frobnorm_test() + { + frobmetric_training_sample > sample; + std::vector > > samples; + + matrix x, near_, far_; + x = 0,0,0; + near_ = 1,0,0; + far_ = 0,1,0; + + sample.anchor_vect = x; + sample.near_vects.push_back(near_); + sample.far_vects.push_back(far_); + + samples.push_back(sample); + + vector_normalizer_frobmetric > trainer; + trainer.set_c(100); + print_spinner(); + trainer.train(samples); + + matrix correct; + correct = 0, 0, 0, + 0, 1, 0, + 0, 0, 0; + + dlog << LDEBUG << trainer.transform(); + DLIB_TEST(max(abs(trainer.transform()-correct)) < 1e-8); + + trainer.set_uses_identity_matrix_prior(true); + print_spinner(); + trainer.train(samples); + correct = 1, 0, 0, + 0, 2, 0, + 0, 0, 1; + + dlog << LDEBUG << trainer.transform(); + DLIB_TEST(max(abs(trainer.transform()-correct)) < 1e-8); + + } + + void test_lda () + { + // This test makes sure we pick the right direction in a simple 2D -> 1D LDA + typedef matrix sample_type; + + std::vector labels; + std::vector samples; + for (int i=0; i<4; i++) + { + sample_type s; + s(0) = i; + s(1) = i+1; + samples.push_back(s); + labels.push_back(1); + + sample_type s1; + s1(0) = i+1; + s1(1) = i; + samples.push_back(s1); + labels.push_back(2); + } + + matrix X; + X.set_size(8,2); + for (int i=0; i<8; i++){ + X(i,0) = samples[i](0); + X(i,1) = samples[i](1); + } + + matrix mean; + + dlib::compute_lda_transform(X,mean,labels,1); + + std::vector vals1, vals2; + for (unsigned long i = 0; i < samples.size(); ++i) + { + double val = X*samples[i]-mean; + if (i%2 == 0) + vals1.push_back(val); + else + vals2.push_back(val); + dlog << LINFO << "1D LDA output: " << val; + } + + if (vals1[0] > vals2[0]) + swap(vals1, vals2); + + const double err = equal_error_rate(vals1, vals2).first; + dlog << LINFO << "LDA ERR: " << err; + DLIB_TEST(err == 0); + DLIB_TEST(equal_error_rate(vals2, vals1).first == 1); + } + + void test_running_stats_decayed() + { + print_spinner(); + std::vector tmp(300); + std::vector tmp_var(tmp.size()); + dlib::rand rnd; + const int num_rounds = 100000; + for (int rounds = 0; rounds < num_rounds; ++rounds) + { + running_stats_decayed rs(100); + + for (size_t i = 0; i < tmp.size(); ++i) + { + rs.add(rnd.get_random_gaussian() + 1); + tmp[i] += rs.mean(); + if (i > 0) + tmp_var[i] += rs.variance(); + } + } + + // should print all 1s basically since the mean and variance should always be 1. + for (size_t i = 0; i < tmp.size(); ++i) + { + DLIB_TEST(std::abs(1-tmp[i]/num_rounds) < 0.001); + if (i > 1) + DLIB_TEST(std::abs(1-tmp_var[i]/num_rounds) < 0.01); + } + } + + void test_running_scalar_covariance_decayed() + { + print_spinner(); + std::vector tmp(300); + std::vector tmp_var(tmp.size()); + std::vector tmp_covar(tmp.size()); + dlib::rand rnd; + const int num_rounds = 500000; + for (int rounds = 0; rounds < num_rounds; ++rounds) + { + running_scalar_covariance_decayed rs(100); + + for (size_t i = 0; i < tmp.size(); ++i) + { + rs.add(rnd.get_random_gaussian() + 1, rnd.get_random_gaussian() + 1); + tmp[i] += (rs.mean_y()+rs.mean_x())/2; + if (i > 0) + { + tmp_var[i] += (rs.variance_y()+rs.variance_x())/2; + tmp_covar[i] += rs.covariance(); + } + } + } + + // should print all 1s basically since the mean and variance should always be 1. + for (size_t i = 0; i < tmp.size(); ++i) + { + DLIB_TEST(std::abs(1-tmp[i]/num_rounds) < 0.001); + if (i > 1) + { + DLIB_TEST(std::abs(1-tmp_var[i]/num_rounds) < 0.01); + DLIB_TEST(std::abs(tmp_covar[i]/num_rounds) < 0.001); + } + } + } + + + void test_event_corr() + { + print_spinner(); + DLIB_TEST(event_correlation(1000,1000,500,2000) == 0); + DLIB_TEST(std::abs(event_correlation(1000,1000,300,2000) + 164.565757010104) < 1e-11); + DLIB_TEST(std::abs(event_correlation(1000,1000,700,2000) - 164.565757010104) < 1e-11); + + DLIB_TEST(event_correlation(10,1000,5,2000) == 0); + DLIB_TEST(event_correlation(1000,10,5,2000) == 0); + DLIB_TEST(std::abs(event_correlation(10,1000,1,2000) - event_correlation(1000,10,1,2000)) < 1e-11); + DLIB_TEST(std::abs(event_correlation(10,1000,9,2000) - event_correlation(1000,10,9,2000)) < 1e-11); + + DLIB_TEST(std::abs(event_correlation(10,1000,1,2000) + 3.69672251700842) < 1e-11); + DLIB_TEST(std::abs(event_correlation(10,1000,9,2000) - 3.69672251700842) < 1e-11); + } + + void perform_test ( + ) + { + prior_frobnorm_test(); + dlib::rand rnd; + for (int i = 0; i < 5; ++i) + test_vector_normalizer_frobmetric(rnd); + + test_random_subset_selector(); + test_random_subset_selector2(); + test_running_covariance(); + test_running_cross_covariance(); + test_running_cross_covariance_sparse(); + test_running_stats(); + test_skewness_and_kurtosis_1(); + test_skewness_and_kurtosis_2(); + test_randomize_samples(); + test_randomize_samples2(); + another_test(); + test_average_precision(); + test_lda(); + test_event_corr(); + test_running_stats_decayed(); + test_running_scalar_covariance_decayed(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/std_vector_c.cpp b/ml/dlib/dlib/test/std_vector_c.cpp new file mode 100644 index 000000000..fe7f82514 --- /dev/null +++ b/ml/dlib/dlib/test/std_vector_c.cpp @@ -0,0 +1,101 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +// This is called an unnamed-namespace and it has the effect of making everything inside this file "private" +// so that everything you declare will have static linkage. Thus we won't have any multiply +// defined symbol errors coming out of the linker when we try to compile the test suite. +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + // Declare the logger we will use in this test. The name of the tester + // should start with "test." + logger dlog("test.std_vector_c"); + + + class std_vector_c_tester : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a test for the std_vector_c object. When it is constructed + it adds itself into the testing framework. The command line switch is + specified as test_std_vector_c by passing that string to the tester constructor. + !*/ + public: + std_vector_c_tester ( + ) : + tester ("test_std_vector_c", + "Runs tests on the std_vector_c component.") + {} + + void perform_test ( + ) + { + std::vector c; + std_vector_c a, b; + a.push_back(3); + a.push_back(2); + a.push_back(1); + + DLIB_TEST(a[0] == 3); + DLIB_TEST(a[1] == 2); + DLIB_TEST(a[2] == 1); + c = a; + DLIB_TEST(c[0] == 3); + DLIB_TEST(c[1] == 2); + DLIB_TEST(c[2] == 1); + DLIB_TEST(c.size() == 3); + DLIB_TEST(a.size() == 3); + DLIB_TEST(b.size() == 0); + + DLIB_TEST(a == c); + DLIB_TEST(!(a != c)); + DLIB_TEST(a <= c); + DLIB_TEST(a >= c); + DLIB_TEST(!(a < c)); + DLIB_TEST(!(a > c)); + + swap(b,c); + DLIB_TEST(b[0] == 3); + DLIB_TEST(b[1] == 2); + DLIB_TEST(b[2] == 1); + DLIB_TEST(c.size() == 0); + DLIB_TEST(b.size() == 3); + swap(c,b); + DLIB_TEST(c[0] == 3); + DLIB_TEST(c[1] == 2); + DLIB_TEST(c[2] == 1); + DLIB_TEST(c.size() == 3); + DLIB_TEST(b.size() == 0); + swap(a,b); + DLIB_TEST(b[0] == 3); + DLIB_TEST(b[1] == 2); + DLIB_TEST(b[2] == 1); + DLIB_TEST(b.size() == 3); + DLIB_TEST(a.size() == 0); + + + swap(b,c); + swap(c,c); + + + std_vector_c h(a); + std_vector_c i(c); + std::vector j(b); + } + } a; + +} + diff --git a/ml/dlib/dlib/test/string.cpp b/ml/dlib/dlib/test/string.cpp new file mode 100644 index 000000000..18f9035c5 --- /dev/null +++ b/ml/dlib/dlib/test/string.cpp @@ -0,0 +1,329 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.string"); + + + void string_test ( + ) + /*! + ensures + - runs tests on string functions for compliance with the specs + !*/ + { + + print_spinner(); + + string a = " davis "; + string A = " DAVIS "; + string empty = " "; + + dlog << LTRACE << 1; + + double dval; + int ival; + bool bval; + + DLIB_TEST_MSG(string_cast("5") == 5,string_cast("5")); + DLIB_TEST_MSG(string_cast("0x5") == 5,string_cast("0x5")); + DLIB_TEST_MSG(string_cast("0xA") == 10,string_cast("0xA")); + DLIB_TEST(string_cast("0.5") == 0.5); + DLIB_TEST((dval = sa ="0.5") == 0.5); + DLIB_TEST(string_cast("0.5 !") == "0.5 !"); + DLIB_TEST(string_cast("true") == true); + DLIB_TEST((bval = sa = "true") == true); + DLIB_TEST(string_cast("false") == false); + DLIB_TEST(string_cast("TRUE") == true); + DLIB_TEST(string_cast("FALSE") == false); + DLIB_TEST((bval = sa = "FALSE") == false); + + dlog << LTRACE << 2; + + DLIB_TEST_MSG(string_cast(L"5") == 5,string_cast("5")); + DLIB_TEST_MSG((ival = sa = L"5") == 5,string_cast("5")); + dlog << LTRACE << 2.1; + DLIB_TEST_MSG(string_cast(L"0x5") == 5,string_cast("0x5")); + DLIB_TEST_MSG(string_cast(L"0xA") == 10,string_cast("0xA")); + DLIB_TEST(string_cast(L"0.5") == 0.5); + DLIB_TEST(string_cast(L"0.5 !") == "0.5 !"); + DLIB_TEST(string_cast(L"true") == true); + DLIB_TEST(string_cast(L"false") == false); + DLIB_TEST(string_cast(L"TRUE") == true); + DLIB_TEST((bval = sa = L"TRUE") == true); + DLIB_TEST(string_cast(L"FALSE") == false); + + dlog << LTRACE << 3; + + DLIB_TEST(cast_to_string(5) == "5"); + DLIB_TEST(cast_to_string(5.5) == "5.5"); + + dlog << LTRACE << 4; + DLIB_TEST(cast_to_wstring(5) == L"5"); + DLIB_TEST(cast_to_wstring(5.5) == L"5.5"); + dlog << LTRACE << 5; + DLIB_TEST(toupper(a) == A); + DLIB_TEST(toupper(A) == A); + DLIB_TEST(tolower(a) == a); + DLIB_TEST(tolower(A) == a); + DLIB_TEST(trim(a) == "davis"); + DLIB_TEST(ltrim(a) == "davis "); + DLIB_TEST(rtrim(a) == " davis"); + DLIB_TEST(trim(string_cast(a)) == L"davis"); + DLIB_TEST(ltrim(string_cast(a)) == L"davis "); + DLIB_TEST(rtrim(string_cast(a)) == L" davis"); + DLIB_TEST(trim(a, " ") == "davis"); + DLIB_TEST(ltrim(a, " ") == "davis "); + DLIB_TEST(rtrim(a, " ") == " davis"); + DLIB_TEST(trim(empty) == ""); + DLIB_TEST(ltrim(empty) == ""); + DLIB_TEST(rtrim(empty) == ""); + DLIB_TEST(trim(string_cast(empty)) == L""); + DLIB_TEST(ltrim(string_cast(empty)) == L""); + DLIB_TEST(rtrim(string_cast(empty)) == L""); + DLIB_TEST(trim(empty, " ") == ""); + DLIB_TEST(ltrim(empty, " ") == ""); + DLIB_TEST(rtrim(empty, " ") == ""); + + + dlog << LTRACE << 6; + DLIB_TEST( (lpad(wstring(L"davis"), 10) == L" davis")); + DLIB_TEST( (rpad(wstring(L"davis"), 10) == L"davis ")); + DLIB_TEST( (pad(wstring(L"davis"), 10) == L" davis ")); + + DLIB_TEST( (lpad(string("davis"), -10) == "davis")); + DLIB_TEST( (rpad(string("davis"), -10) == "davis")); + DLIB_TEST( (pad(string("davis"), -10) == "davis")); + DLIB_TEST( (lpad(string("davis"), 10) == " davis")); + DLIB_TEST( (rpad(string("davis"), 10) == "davis ")); + DLIB_TEST( (pad(string("davis"), 10) == " davis ")); + DLIB_TEST( (lpad(string("davis"), 10, string("*")) == "*****davis")); + DLIB_TEST( (rpad(string("davis"), 10, string("*")) == "davis*****")); + DLIB_TEST( (pad(string("davis"), 10, string("*")) == "**davis***")); + DLIB_TEST( (lpad(string("davis"), 10, string("_-")) == "_-_-_davis")); + DLIB_TEST( (rpad(string("davis"), 10, string("_-")) == "davis_-_-_")); + DLIB_TEST( (pad(string("davis"), 10, string("_-")) == "_-davis_-_")); + DLIB_TEST( (lpad(string("davis"), 10, string("willy wanka")) == "willydavis")); + DLIB_TEST( (rpad(string("davis"), 10, string("willy wanka")) == "daviswilly")); + DLIB_TEST( (pad(string("davis"), 10, string("willy wanka")) == "widaviswil")); + DLIB_TEST( (lpad(string("davis"), 10, "*")) == "*****davis"); + DLIB_TEST( (rpad(string("davis"), 10, "*") == "davis*****")); + DLIB_TEST( (pad(string("davis"), 10, "*") == "**davis***")); + DLIB_TEST( (lpad(string("davis"), 10, "_-") == "_-_-_davis")); + DLIB_TEST( (rpad(string("davis"), 10, "_-") == "davis_-_-_")); + DLIB_TEST( (pad(string("davis"), 10, "_-") == "_-davis_-_")); + DLIB_TEST( (lpad(string("davis"), 10, "willy wanka") == "willydavis")); + DLIB_TEST( (rpad(string("davis"), 10, "willy wanka") == "daviswilly")); + DLIB_TEST( (pad(string("davis"), 10, "willy wanka") == "widaviswil")); + dlog << LTRACE << 7; + + a = "file.txt"; + DLIB_TEST( (left_substr(a,string(".")) == "file")); + DLIB_TEST( (left_substr(a,".") == "file")); + DLIB_TEST( (right_substr(a,string(".")) == "txt")); + DLIB_TEST( (right_substr(a,".") == "txt")); + + DLIB_TEST( (left_substr(a," ") == "file.txt")); + DLIB_TEST( (right_substr(a," ") == "")); + + DLIB_TEST( (left_substr(a,"") == "file.txt")); + DLIB_TEST( (right_substr(a,"") == "")); + + wstring ws = L"file.txt"; + DLIB_TEST( (left_substr(ws,wstring(L".")) == L"file")); + DLIB_TEST_MSG( (left_substr(ws,L".") == L"file"), L""); + DLIB_TEST( (right_substr(ws,wstring(L".")) == L"txt")); + DLIB_TEST_MSG( (right_substr(ws,L".") == L"txt"), L""); + + + dlog << LTRACE << 8; + { + ostringstream sout; + wchar_t w = 85; + char c = 4; + serialize(w,sout); + serialize(c,sout); + w = static_cast(-1); + serialize(w,sout); + c = static_cast(-1); + serialize(c,sout); + + istringstream sin(sout.str()); + w = 0; + c = 0; + deserialize(w,sin); + deserialize(c,sin); + DLIB_TEST(w == 85); + DLIB_TEST(c == 4); + deserialize(w,sin); + deserialize(c,sin); + DLIB_TEST(w == static_cast(-1)); + DLIB_TEST(c == static_cast(-1)); + + wstring str = L"test string"; + + sout.str(""); + serialize(str, sout); + sin.clear(); + sin.str(sout.str()); + str = L"something else"; + deserialize(str,sin); + DLIB_TEST(str == L"test string"); + } + } + + + void test_split() + { + std::vector v; + + string str; + string delim = " , "; + + v = split(string("one, two,three four")," ,"); + DLIB_TEST(v.size() == 4); + DLIB_TEST(v[0] == "one"); + DLIB_TEST(v[1] == "two"); + DLIB_TEST(v[2] == "three"); + DLIB_TEST(v[3] == "four"); + + v = split(string("one, two,three four"),delim); + DLIB_TEST(v.size() == 4); + DLIB_TEST(v[0] == "one"); + DLIB_TEST(v[1] == "two"); + DLIB_TEST(v[2] == "three"); + DLIB_TEST(v[3] == "four"); + + v = split(string("")); + DLIB_TEST(v.size() == 0); + + v = split(string(" ")); + DLIB_TEST(v.size() == 0); + + v = split(string(" one two ")); + DLIB_TEST(v.size() == 2); + DLIB_TEST(v[0] == "one"); + DLIB_TEST(v[1] == "two"); + + v = split(string(" one ")); + DLIB_TEST(v.size() == 1); + DLIB_TEST(v[0] == "one"); + + v = split(string("one")); + DLIB_TEST(v.size() == 1); + DLIB_TEST(v[0] == "one"); + + v = split(string("o")); + DLIB_TEST(v.size() == 1); + DLIB_TEST(v[0] == "o"); + + + std::vector wv; + wstring wstr = L"test string"; + wv = split(wstr); + DLIB_TEST(wv.size() == 2); + DLIB_TEST(wv[0] == L"test"); + DLIB_TEST(wv[1] == L"string"); + wv = split(wstr,L" "); + DLIB_TEST(wv.size() == 2); + DLIB_TEST(wv[0] == L"test"); + DLIB_TEST(wv[1] == L"string"); + + wstr = L"Über alle Maßen\u00A0Öttingenstraße"; + wv = split(wstr, L" \u00A0\n\r\t"); + DLIB_TEST(wv.size() == 4); + DLIB_TEST(wv[0] == L"Über"); + DLIB_TEST(wv[1] == L"alle"); + DLIB_TEST(wv[2] == L"Maßen"); + DLIB_TEST(wv[3] == L"Öttingenstraße"); + + wstr = L"test string hah"; + DLIB_TEST(split_on_first(wstr).first == L"test"); + DLIB_TEST(split_on_first(wstr).second == L"string hah"); + DLIB_TEST(split_on_first(wstr,L"#").first == L"test string hah"); + DLIB_TEST(split_on_first(wstr,L"#").second == L""); + DLIB_TEST(split_on_last(wstr).first == L"test string"); + DLIB_TEST(split_on_last(wstr).second == L"hah"); + DLIB_TEST(split_on_last(wstr,L"#").first == L"test string hah"); + DLIB_TEST(split_on_last(wstr,L"#").second == L""); + wstr = L""; + DLIB_TEST(split_on_first(wstr).first == L""); + DLIB_TEST(split_on_first(wstr).second == L""); + + str = "test string hah"; + DLIB_TEST(split_on_first(str).first == "test"); + DLIB_TEST(split_on_first(str).second == "string hah"); + DLIB_TEST(split_on_first(str,"#").first == "test string hah"); + DLIB_TEST(split_on_first(str,"#").second == ""); + DLIB_TEST(split_on_last(str).first == "test string"); + DLIB_TEST(split_on_last(str).second == "hah"); + DLIB_TEST(split_on_last(str,"#").first == "test string hah"); + DLIB_TEST(split_on_last(str,"#").second == ""); + str = ""; + DLIB_TEST(split_on_first(str).first == ""); + DLIB_TEST(split_on_first(str).second == ""); + + wstr = L"test.string.hah"; + DLIB_TEST(split_on_first(wstr,L".").first == L"test"); + DLIB_TEST(split_on_first(wstr,L".").second == L"string.hah"); + DLIB_TEST(split_on_first(wstr).first == L"test.string.hah"); + DLIB_TEST(split_on_first(wstr).second == L""); + DLIB_TEST(split_on_last(wstr,L".").first == L"test.string"); + DLIB_TEST(split_on_last(wstr,L".").second == L"hah"); + DLIB_TEST(split_on_last(wstr).first == L"test.string.hah"); + DLIB_TEST(split_on_last(wstr).second == L""); + wstr = L""; + DLIB_TEST(split_on_first(wstr).first == L""); + DLIB_TEST(split_on_first(wstr).second == L""); + + str = "test.string.hah"; + DLIB_TEST(split_on_first(str,".").first == "test"); + DLIB_TEST(split_on_first(str,".").second == "string.hah"); + DLIB_TEST(split_on_first(str).first == "test.string.hah"); + DLIB_TEST(split_on_first(str).second == ""); + DLIB_TEST(split_on_last(str,".").first == "test.string"); + DLIB_TEST(split_on_last(str,".").second == "hah"); + DLIB_TEST(split_on_last(str).first == "test.string.hah"); + DLIB_TEST(split_on_last(str).second == ""); + str = ""; + DLIB_TEST(split_on_first(str).first == ""); + DLIB_TEST(split_on_first(str).second == ""); + } + + + + class string_tester : public tester + { + public: + string_tester ( + ) : + tester ("test_string", + "Runs tests on the string objects and functions.") + {} + + void perform_test ( + ) + { + string_test(); + test_split(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/svm.cpp b/ml/dlib/dlib/test/svm.cpp new file mode 100644 index 000000000..b46d44331 --- /dev/null +++ b/ml/dlib/dlib/test/svm.cpp @@ -0,0 +1,661 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" +#include "checkerboard.h" +#include + +#include "tester.h" +#include + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.svm"); + +// ---------------------------------------------------------------------------------------- + + void test_clutering ( + ) + { + dlog << LINFO << " being test_clutering()"; + // Here we declare that our samples will be 2 dimensional column vectors. + typedef matrix sample_type; + + // Now we are making a typedef for the kind of kernel we want to use. I picked the + // radial basis kernel because it only has one parameter and generally gives good + // results without much fiddling. + typedef radial_basis_kernel kernel_type; + + // Here we declare an instance of the kcentroid object. The first argument to the constructor + // is the kernel we wish to use. The second is a parameter that determines the numerical + // accuracy with which the object will perform part of the learning algorithm. Generally + // smaller values give better results but cause the algorithm to run slower. You just have + // to play with it to decide what balance of speed and accuracy is right for your problem. + // Here we have set it to 0.01. + kcentroid kc(kernel_type(0.1),0.01); + + // Now we make an instance of the kkmeans object and tell it to use kcentroid objects + // that are configured with the parameters from the kc object we defined above. + kkmeans test(kc); + + std::vector samples; + std::vector initial_centers; + + sample_type m; + + dlib::rand rnd; + + print_spinner(); + // we will make 50 points from each class + const long num = 50; + + // make some samples near the origin + double radius = 0.5; + for (long i = 0; i < num; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + // add this sample to our set of samples we will run k-means + samples.push_back(m); + } + + // make some samples in a circle around the origin but far away + radius = 10.0; + for (long i = 0; i < num; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + // add this sample to our set of samples we will run k-means + samples.push_back(m); + } + + // make some samples in a circle around the point (25,25) + radius = 4.0; + for (long i = 0; i < num; ++i) + { + double sign = 1; + if (rnd.get_random_double() < 0.5) + sign = -1; + m(0) = 2*radius*rnd.get_random_double()-radius; + m(1) = sign*sqrt(radius*radius - m(0)*m(0)); + + // translate this point away from the origin + m(0) += 25; + m(1) += 25; + + // add this sample to our set of samples we will run k-means + samples.push_back(m); + } + print_spinner(); + + // tell the kkmeans object we made that we want to run k-means with k set to 3. + // (i.e. we want 3 clusters) + test.set_number_of_centers(3); + + // You need to pick some initial centers for the k-means algorithm. So here + // we will use the dlib::pick_initial_centers() function which tries to find + // n points that are far apart (basically). + pick_initial_centers(3, initial_centers, samples, test.get_kernel()); + + print_spinner(); + // now run the k-means algorithm on our set of samples. + test.train(samples,initial_centers); + print_spinner(); + + const unsigned long class1 = test(samples[0]); + const unsigned long class2 = test(samples[num]); + const unsigned long class3 = test(samples[2*num]); + // now loop over all our samples and print out their predicted class. In this example + // all points are correctly identified. + for (unsigned long i = 0; i < samples.size()/3; ++i) + { + DLIB_TEST(test(samples[i]) == class1); + DLIB_TEST(test(samples[i+num]) == class2); + DLIB_TEST(test(samples[i+2*num]) == class3); + } + + dlog << LINFO << " end test_clutering()"; + } + +// ---------------------------------------------------------------------------------------- + + // Here is the sinc function we will be trying to learn with the krls + // object. + double sinc(double x) + { + if (x == 0) + return 1; + return sin(x)/x; + } + + + void test_regression ( + ) + { + dlog << LINFO << " being test_regression()"; + // Here we declare that our samples will be 1 dimensional column vectors. The reason for + // using a matrix here is that in general you can use N dimensional vectors as inputs to the + // krls object. But here we only have 1 dimension to make the example simple. + typedef matrix sample_type; + + // Now we are making a typedef for the kind of kernel we want to use. I picked the + // radial basis kernel because it only has one parameter and generally gives good + // results without much fiddling. + typedef radial_basis_kernel kernel_type; + + // Here we declare an instance of the krls object. The first argument to the constructor + // is the kernel we wish to use. The second is a parameter that determines the numerical + // accuracy with which the object will perform part of the regression algorithm. Generally + // smaller values give better results but cause the algorithm to run slower. You just have + // to play with it to decide what balance of speed and accuracy is right for your problem. + // Here we have set it to 0.001. + krls test(kernel_type(0.1),0.001); + rvm_regression_trainer rvm_test; + rvm_test.set_kernel(test.get_kernel()); + + krr_trainer krr_test; + krr_test.set_kernel(test.get_kernel()); + + svr_trainer svr_test; + svr_test.set_kernel(test.get_kernel()); + svr_test.set_epsilon_insensitivity(0.0001); + svr_test.set_c(10); + + rbf_network_trainer rbf_test; + rbf_test.set_kernel(test.get_kernel()); + rbf_test.set_num_centers(13); + + print_spinner(); + std::vector samples; + std::vector samples2; + std::vector labels; + std::vector labels2; + // now we train our object on a few samples of the sinc function. + sample_type m; + for (double x = -10; x <= 5; x += 0.6) + { + m(0) = x; + test.train(m, sinc(x)); + + samples.push_back(m); + samples2.push_back(m); + labels.push_back(sinc(x)); + labels2.push_back(2); + } + + print_spinner(); + decision_function test2 = rvm_test.train(samples, labels); + print_spinner(); + decision_function test3 = rbf_test.train(samples, labels); + print_spinner(); + decision_function test4 = krr_test.train(samples, labels); + print_spinner(); + decision_function test5 = svr_test.train(samples, labels); + print_spinner(); + + // now we output the value of the sinc function for a few test points as well as the + // value predicted by krls object. + m(0) = 2.5; dlog << LDEBUG << "krls: " << sinc(m(0)) << " " << test(m); DLIB_TEST(abs(sinc(m(0)) - test(m)) < 0.01); + m(0) = 0.1; dlog << LDEBUG << "krls: " << sinc(m(0)) << " " << test(m); DLIB_TEST(abs(sinc(m(0)) - test(m)) < 0.01); + m(0) = -4; dlog << LDEBUG << "krls: " << sinc(m(0)) << " " << test(m); DLIB_TEST(abs(sinc(m(0)) - test(m)) < 0.01); + m(0) = 5.0; dlog << LDEBUG << "krls: " << sinc(m(0)) << " " << test(m); DLIB_TEST(abs(sinc(m(0)) - test(m)) < 0.01); + + m(0) = 2.5; dlog << LDEBUG << "rvm: " << sinc(m(0)) << " " << test2(m); DLIB_TEST(abs(sinc(m(0)) - test2(m)) < 0.01); + m(0) = 0.1; dlog << LDEBUG << "rvm: " << sinc(m(0)) << " " << test2(m); DLIB_TEST(abs(sinc(m(0)) - test2(m)) < 0.01); + m(0) = -4; dlog << LDEBUG << "rvm: " << sinc(m(0)) << " " << test2(m); DLIB_TEST(abs(sinc(m(0)) - test2(m)) < 0.01); + m(0) = 5.0; dlog << LDEBUG << "rvm: " << sinc(m(0)) << " " << test2(m); DLIB_TEST(abs(sinc(m(0)) - test2(m)) < 0.01); + + m(0) = 2.5; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01); + m(0) = 0.1; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01); + m(0) = -4; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01); + m(0) = 5.0; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01); + + m(0) = 2.5; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01); + m(0) = 0.1; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01); + m(0) = -4; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01); + m(0) = 5.0; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01); + + m(0) = 2.5; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01); + m(0) = 0.1; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01); + m(0) = -4; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01); + m(0) = 5.0; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01); + + + randomize_samples(samples, labels); + dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6); + dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6); + matrix cv = cross_validate_regression_trainer(krr_test, samples, labels, 6); + DLIB_TEST(cv(0) < 1e-4); + DLIB_TEST(cv(1) > 0.99); + cv = cross_validate_regression_trainer(svr_test, samples, labels, 6); + DLIB_TEST(cv(0) < 1e-4); + DLIB_TEST(cv(1) > 0.99); + + + + + randomize_samples(samples2, labels2); + dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples2, labels2, 6); + dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples2, labels2, 6); + cv = cross_validate_regression_trainer(krr_test, samples2, labels2, 6); + DLIB_TEST(cv(0) < 1e-4); + cv = cross_validate_regression_trainer(svr_test, samples2, labels2, 6); + DLIB_TEST(cv(0) < 1e-4); + + dlog << LINFO << " end test_regression()"; + } + +// ---------------------------------------------------------------------------------------- + + void test_anomaly_detection ( + ) + { + dlog << LINFO << " begin test_anomaly_detection()"; + // Here we declare that our samples will be 2 dimensional column vectors. + typedef matrix sample_type; + + // Now we are making a typedef for the kind of kernel we want to use. I picked the + // radial basis kernel because it only has one parameter and generally gives good + // results without much fiddling. + typedef radial_basis_kernel kernel_type; + + // Here we declare an instance of the kcentroid object. The first argument to the constructor + // is the kernel we wish to use. The second is a parameter that determines the numerical + // accuracy with which the object will perform part of the learning algorithm. Generally + // smaller values give better results but cause the algorithm to run slower. You just have + // to play with it to decide what balance of speed and accuracy is right for your problem. + // Here we have set it to 0.01. + kcentroid test(kernel_type(0.1),0.01); + + + svm_one_class_trainer one_class_trainer; + one_class_trainer.set_nu(0.4); + one_class_trainer.set_kernel(kernel_type(0.2)); + + std::vector samples; + + // now we train our object on a few samples of the sinc function. + sample_type m; + for (double x = -15; x <= 8; x += 1) + { + m(0) = x; + m(1) = sinc(x); + test.train(m); + samples.push_back(m); + } + + decision_function df = one_class_trainer.train(samples); + + running_stats rs; + + // Now lets output the distance from the centroid to some points that are from the sinc function. + // These numbers should all be similar. We will also calculate the statistics of these numbers + // by accumulating them into the running_stats object called rs. This will let us easily + // find the mean and standard deviation of the distances for use below. + dlog << LDEBUG << "Points that are on the sinc function:\n"; + m(0) = -1.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m)); + m(0) = -1.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m)); + m(0) = -0; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m)); + m(0) = -0.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m)); + m(0) = -4.1; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m)); + m(0) = -1.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m)); + m(0) = -0.5; m(1) = sinc(m(0)); dlog << LDEBUG << " " << test(m); rs.add(test(m)); + + m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m))); + m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m))); + m(0) = -0; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m))); + m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m))); + m(0) = -4.1; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m))); + m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m))); + m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m))); + + const double thresh = 0.01; + m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m)); + m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m)); + m(0) = -0; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m)); + m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m)); + m(0) = -4.1; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m)); + m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m)); + m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m)); + + dlog << LDEBUG; + // Lets output the distance from the centroid to some points that are NOT from the sinc function. + // These numbers should all be significantly bigger than previous set of numbers. We will also + // use the rs.scale() function to find out how many standard deviations they are away from the + // mean of the test points from the sinc function. So in this case our criterion for "significantly bigger" + // is > 3 or 4 standard deviations away from the above points that actually are on the sinc function. + dlog << LDEBUG << "Points that are NOT on the sinc function:\n"; + m(0) = -1.5; m(1) = sinc(m(0))+4; + dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; + DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); + DLIB_TEST_MSG(df(m) + thresh < 0, df(m)); + + m(0) = -1.5; m(1) = sinc(m(0))+3; + dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; + DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); + DLIB_TEST_MSG(df(m) + thresh < 0, df(m)); + + m(0) = -0; m(1) = -sinc(m(0)); + dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; + DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); + DLIB_TEST_MSG(df(m) + thresh < 0, df(m)); + + m(0) = -0.5; m(1) = -sinc(m(0)); + dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; + DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); + DLIB_TEST_MSG(df(m) + thresh < 0, df(m)); + + m(0) = -4.1; m(1) = sinc(m(0))+2; + dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; + DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); + DLIB_TEST_MSG(df(m) + thresh < 0, df(m)); + + m(0) = -1.5; m(1) = sinc(m(0))+0.9; + dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; + DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); + DLIB_TEST_MSG(df(m) + thresh < 0, df(m)); + + m(0) = -0.5; m(1) = sinc(m(0))+1; + dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; + DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); + DLIB_TEST_MSG(df(m) + thresh < 0, df(m)); + + dlog << LINFO << " end test_anomaly_detection()"; + } + +// ---------------------------------------------------------------------------------------- + + void unittest_binary_classification ( + ) + /*! + ensures + - runs tests on the svm stuff compliance with the specs + !*/ + { + dlog << LINFO << " begin unittest_binary_classification()"; + print_spinner(); + + + typedef double scalar_type; + typedef matrix sample_type; + + std::vector x; + std::vector > x_linearized; + std::vector y; + + get_checkerboard_problem(x,y, 300, 2); + const scalar_type gamma = 1; + + typedef radial_basis_kernel kernel_type; + + rbf_network_trainer rbf_trainer; + rbf_trainer.set_kernel(kernel_type(gamma)); + rbf_trainer.set_num_centers(100); + + rvm_trainer rvm_trainer; + rvm_trainer.set_kernel(kernel_type(gamma)); + + krr_trainer krr_trainer; + krr_trainer.use_classification_loss_for_loo_cv(); + krr_trainer.set_kernel(kernel_type(gamma)); + + svm_pegasos pegasos_trainer; + pegasos_trainer.set_kernel(kernel_type(gamma)); + pegasos_trainer.set_lambda(0.00001); + + + svm_c_ekm_trainer ocas_ekm_trainer; + ocas_ekm_trainer.set_kernel(kernel_type(gamma)); + ocas_ekm_trainer.set_c(100000); + + svm_nu_trainer trainer; + trainer.set_kernel(kernel_type(gamma)); + trainer.set_nu(0.05); + + svm_c_trainer c_trainer; + c_trainer.set_kernel(kernel_type(gamma)); + c_trainer.set_c(100); + + svm_c_linear_trainer > > lin_trainer; + lin_trainer.set_c(100000); + // use an ekm to linearize this dataset so we can use it with the lin_trainer + empirical_kernel_map ekm; + ekm.load(kernel_type(gamma), x); + for (unsigned long i = 0; i < x.size(); ++i) + x_linearized.push_back(ekm.project(x[i])); + + + print_spinner(); + matrix rvm_cv = cross_validate_trainer_threaded(rvm_trainer, x,y, 4, 2); + print_spinner(); + matrix krr_cv = cross_validate_trainer_threaded(krr_trainer, x,y, 4, 2); + print_spinner(); + matrix svm_cv = cross_validate_trainer(trainer, x,y, 4); + print_spinner(); + matrix svm_c_cv = cross_validate_trainer(c_trainer, x,y, 4); + print_spinner(); + matrix rbf_cv = cross_validate_trainer_threaded(rbf_trainer, x,y, 10, 2); + print_spinner(); + matrix lin_cv = cross_validate_trainer_threaded(lin_trainer, x_linearized, y, 4, 2); + print_spinner(); + matrix ocas_ekm_cv = cross_validate_trainer_threaded(ocas_ekm_trainer, x, y, 4, 2); + print_spinner(); + ocas_ekm_trainer.set_basis(randomly_subsample(x, 300)); + matrix ocas_ekm_cv2 = cross_validate_trainer_threaded(ocas_ekm_trainer, x, y, 4, 2); + print_spinner(); + matrix peg_cv = cross_validate_trainer_threaded(batch(pegasos_trainer,1.0), x,y, 4, 2); + print_spinner(); + matrix peg_c_cv = cross_validate_trainer_threaded(batch_cached(pegasos_trainer,1.0), x,y, 4, 2); + print_spinner(); + + dlog << LDEBUG << "rvm cv: " << rvm_cv; + dlog << LDEBUG << "krr cv: " << krr_cv; + dlog << LDEBUG << "nu-svm cv: " << svm_cv; + dlog << LDEBUG << "C-svm cv: " << svm_c_cv; + dlog << LDEBUG << "rbf cv: " << rbf_cv; + dlog << LDEBUG << "lin cv: " << lin_cv; + dlog << LDEBUG << "ocas_ekm cv: " << ocas_ekm_cv; + dlog << LDEBUG << "ocas_ekm cv2: " << ocas_ekm_cv2; + dlog << LDEBUG << "peg cv: " << peg_cv; + dlog << LDEBUG << "peg cached cv: " << peg_c_cv; + + // make sure the cached version of pegasos computes the same result + DLIB_TEST_MSG(sum(abs(peg_cv - peg_c_cv)) < std::sqrt(std::numeric_limits::epsilon()), + sum(abs(peg_cv - peg_c_cv)) << " \n" << peg_cv << peg_c_cv ); + + DLIB_TEST_MSG(mean(rvm_cv) > 0.9, rvm_cv); + DLIB_TEST_MSG(mean(krr_cv) > 0.9, krr_cv); + DLIB_TEST_MSG(mean(svm_cv) > 0.9, svm_cv); + DLIB_TEST_MSG(mean(svm_c_cv) > 0.9, svm_c_cv); + DLIB_TEST_MSG(mean(rbf_cv) > 0.9, rbf_cv); + DLIB_TEST_MSG(mean(lin_cv) > 0.9, lin_cv); + DLIB_TEST_MSG(mean(peg_cv) > 0.9, peg_cv); + DLIB_TEST_MSG(mean(peg_c_cv) > 0.9, peg_c_cv); + DLIB_TEST_MSG(mean(ocas_ekm_cv) > 0.9, ocas_ekm_cv); + DLIB_TEST_MSG(mean(ocas_ekm_cv2) > 0.9, ocas_ekm_cv2); + + const long num_sv = trainer.train(x,y).basis_vectors.size(); + print_spinner(); + const long num_rv = rvm_trainer.train(x,y).basis_vectors.size(); + print_spinner(); + dlog << LDEBUG << "num sv: " << num_sv; + dlog << LDEBUG << "num rv: " << num_rv; + print_spinner(); + ocas_ekm_trainer.clear_basis(); + const long num_bv = ocas_ekm_trainer.train(x,y).basis_vectors.size(); + dlog << LDEBUG << "num ekm bv: " << num_bv; + + DLIB_TEST(num_rv <= 17); + DLIB_TEST_MSG(num_sv <= 45, num_sv); + DLIB_TEST_MSG(num_bv <= 45, num_bv); + + decision_function df = reduced2(trainer, 19).train(x,y); + print_spinner(); + + matrix svm_reduced_error = test_binary_decision_function(df, x, y); + print_spinner(); + dlog << LDEBUG << "svm reduced test error: " << svm_reduced_error; + dlog << LDEBUG << "svm reduced num sv: " << df.basis_vectors.size(); + DLIB_TEST(mean(svm_reduced_error) > 0.9); + + svm_cv = cross_validate_trainer(reduced(trainer,30), x,y, 4); + dlog << LDEBUG << "svm reduced cv: " << svm_cv; + DLIB_TEST_MSG(mean(svm_cv) > 0.9, svm_cv); + + DLIB_TEST(df.basis_vectors.size() <= 19); + dlog << LINFO << " end unittest_binary_classification()"; + } + +// ---------------------------------------------------------------------------------------- + + template + struct kernel_der_obj + { + typename kernel_type::sample_type x; + kernel_type k; + + double operator()(const typename kernel_type::sample_type& y) const { return k(x,y); } + }; + + + template + void test_kernel_derivative ( + const kernel_type& k, + const typename kernel_type::sample_type& x, + const typename kernel_type::sample_type& y + ) + { + kernel_der_obj obj; + obj.x = x; + obj.k = k; + kernel_derivative der(obj.k); + DLIB_TEST(dlib::equal(derivative(obj)(y) , der(obj.x,y), 1e-5)); + } + + void test_kernel_derivative ( + ) + { + typedef matrix sample_type; + + sigmoid_kernel k1; + radial_basis_kernel k2; + linear_kernel k3; + polynomial_kernel k4(2,3,4); + + offset_kernel > k5; + offset_kernel > k6; + + dlib::rand rnd; + + sample_type x, y; + for (int i = 0; i < 10; ++i) + { + x = randm(2,1,rnd); + y = randm(2,1,rnd); + test_kernel_derivative(k1, x, y); + test_kernel_derivative(k2, x, y); + test_kernel_derivative(k3, x, y); + test_kernel_derivative(k4, x, y); + test_kernel_derivative(k5, x, y); + test_kernel_derivative(k6, x, y); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_svm_trainer2() + { + typedef matrix sample_type; + typedef linear_kernel kernel_type; + + + std::vector samples; + std::vector labels; + + sample_type samp; + samp(0) = 1; + samp(1) = 1; + samples.push_back(samp); + labels.push_back(+1); + + samp(0) = 1; + samp(1) = 2; + samples.push_back(samp); + labels.push_back(-1); + + svm_c_trainer trainer; + + decision_function df = trainer.train(samples, labels); + + samp(0) = 1; + samp(1) = 1; + dlog << LINFO << "test +1 : "<< df(samp); + DLIB_TEST(df(samp) > 0); + samp(0) = 1; + samp(1) = 2; + dlog << LINFO << "test -1 : "<< df(samp); + DLIB_TEST(df(samp) < 0); + + svm_nu_trainer trainer2; + df = trainer2.train(samples, labels); + + samp(0) = 1; + samp(1) = 1; + dlog << LINFO << "test +1 : "<< df(samp); + DLIB_TEST(df(samp) > 0); + samp(0) = 1; + samp(1) = 2; + dlog << LINFO << "test -1 : "<< df(samp); + DLIB_TEST(df(samp) < 0); + + } + +// ---------------------------------------------------------------------------------------- + + class svm_tester : public tester + { + public: + svm_tester ( + ) : + tester ("test_svm", + "Runs tests on the svm/kernel algorithm components.") + {} + + void perform_test ( + ) + { + test_kernel_derivative(); + unittest_binary_classification(); + test_clutering(); + test_regression(); + test_anomaly_detection(); + test_svm_trainer2(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/svm_c_linear.cpp b/ml/dlib/dlib/test/svm_c_linear.cpp new file mode 100644 index 000000000..9e30d81f7 --- /dev/null +++ b/ml/dlib/dlib/test/svm_c_linear.cpp @@ -0,0 +1,392 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" +#include "checkerboard.h" +#include + +#include "tester.h" +#include + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.svm_c_linear"); + + typedef matrix sample_type; + typedef std::vector > sparse_sample_type; + +// ---------------------------------------------------------------------------------------- + + void run_prior_test() + { + typedef matrix sample_type; + typedef linear_kernel kernel_type; + + svm_c_linear_trainer trainer; + + std::vector samples; + std::vector labels; + + sample_type samp; + samp = 0, 0, 1; samples.push_back(samp); labels.push_back(+1); + samp = 0, 1, 0; samples.push_back(samp); labels.push_back(-1); + + trainer.set_c(10); + decision_function df = trainer.train(samples, labels); + + trainer.set_prior(df); + + samples.clear(); + labels.clear(); + samp = 1, 0, 0; samples.push_back(samp); labels.push_back(+1); + samp = 0, 1, 0; samples.push_back(samp); labels.push_back(-1); + + df = trainer.train(samples, labels); + + samp = 0, 0, 1; samples.push_back(samp); labels.push_back(+1); + matrix rs = test_binary_decision_function(df, samples, labels); + dlog << LINFO << rs; + DLIB_TEST(rs(0) == 1); + DLIB_TEST(rs(1) == 1); + + dlog << LINFO << trans(df.basis_vectors(0)); + DLIB_TEST(df.basis_vectors(0)(0) > 0); + DLIB_TEST(df.basis_vectors(0)(1) < 0); + DLIB_TEST(df.basis_vectors(0)(2) > 0); + } + + void run_prior_sparse_test() + { + typedef std::map sample_type; + typedef sparse_linear_kernel kernel_type; + + svm_c_linear_trainer trainer; + + std::vector samples; + std::vector labels; + + sample_type samp; + samp[0] = 1; samples.push_back(samp); labels.push_back(+1); samp.clear(); + samp[1] = 1; samples.push_back(samp); labels.push_back(-1); samp.clear(); + + trainer.set_c(10); + decision_function df = trainer.train(samples, labels); + + trainer.set_prior(df); + + samples.clear(); + labels.clear(); + samp[2] = 1; samples.push_back(samp); labels.push_back(+1); samp.clear(); + samp[1] = 1; samples.push_back(samp); labels.push_back(-1); samp.clear(); + + df = trainer.train(samples, labels); + + matrix rs = test_binary_decision_function(df, samples, labels); + dlog << LINFO << rs; + DLIB_TEST(rs(0) == 1); + DLIB_TEST(rs(1) == 1); + + matrix w = sparse_to_dense(df.basis_vectors(0)); + dlog << LINFO << trans(w); + DLIB_TEST(w(0) > 0.1); + DLIB_TEST(w(1) < -0.1); + DLIB_TEST(w(2) > 0.1); + } + + void get_simple_points ( + std::vector& samples, + std::vector& labels + ) + { + samples.clear(); + labels.clear(); + sample_type samp(2); + + samp = 0,0; + samples.push_back(samp); + labels.push_back(-1); + + samp = 0,1; + samples.push_back(samp); + labels.push_back(-1); + + samp = 3,0; + samples.push_back(samp); + labels.push_back(+1); + + samp = 3,1; + samples.push_back(samp); + labels.push_back(+1); + } + +// ---------------------------------------------------------------------------------------- + + void get_simple_points_sparse ( + std::vector& samples, + std::vector& labels + ) + { + samples.clear(); + labels.clear(); + sparse_sample_type samp; + + samp.push_back(make_pair(0, 0.0)); + samp.push_back(make_pair(1, 0.0)); + samples.push_back(samp); + labels.push_back(-1); + + samp.clear(); + samp.push_back(make_pair(0, 0.0)); + samp.push_back(make_pair(1, 1.0)); + samples.push_back(samp); + labels.push_back(-1); + + samp.clear(); + samp.push_back(make_pair(0, 3.0)); + samp.push_back(make_pair(1, 0.0)); + samples.push_back(samp); + labels.push_back(+1); + + samp.clear(); + samp.push_back(make_pair(0, 3.0)); + samp.push_back(make_pair(1, 1.0)); + samples.push_back(samp); + labels.push_back(+1); + } + +// ---------------------------------------------------------------------------------------- + + void test_sparse ( + ) + { + print_spinner(); + dlog << LINFO << "test with sparse vectors"; + std::vector samples; + std::vector labels; + + sample_type samp; + + get_simple_points_sparse(samples,labels); + + svm_c_linear_trainer > trainer; + trainer.set_c(1e4); + //trainer.be_verbose(); + trainer.set_epsilon(1e-11); + + + double obj; + decision_function > df = trainer.train(samples, labels, obj); + dlog << LDEBUG << "obj: "<< obj; + DLIB_TEST_MSG(abs(obj - 0.72222222222) < 1e-7, obj); + + DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6); + DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6); + DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6); + DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6); + + + // While we are at it, make sure the krr_trainer works with sparse samples + krr_trainer > krr; + + df = krr.train(samples, labels); + DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6); + DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6); + DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6); + DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6); + + + // Now test some of the sparse helper functions + DLIB_TEST(max_index_plus_one(samples) == 2); + DLIB_TEST(max_index_plus_one(samples[0]) == 2); + + matrix m; + m = 1; + add_to(m, samples[3]); + DLIB_TEST(m(0) == 1 + samples[3][0].second); + DLIB_TEST(m(1) == 1 + samples[3][1].second); + DLIB_TEST(m(2) == 1); + + m = 1; + subtract_from(m, samples[3]); + DLIB_TEST(m(0) == 1 - samples[3][0].second); + DLIB_TEST(m(1) == 1 - samples[3][1].second); + DLIB_TEST(m(2) == 1); + + m = 1; + add_to(m, samples[3], 2); + DLIB_TEST(m(0) == 1 + 2*samples[3][0].second); + DLIB_TEST(m(1) == 1 + 2*samples[3][1].second); + DLIB_TEST(m(2) == 1); + + m = 1; + subtract_from(m, samples[3], 2); + DLIB_TEST(m(0) == 1 - 2*samples[3][0].second); + DLIB_TEST(m(1) == 1 - 2*samples[3][1].second); + DLIB_TEST(m(2) == 1); + + } + +// ---------------------------------------------------------------------------------------- + + void test_dense ( + ) + { + print_spinner(); + dlog << LINFO << "test with dense vectors"; + std::vector samples; + std::vector labels; + + sample_type samp; + + get_simple_points(samples,labels); + + svm_c_linear_trainer > trainer; + trainer.set_c(1e4); + //trainer.be_verbose(); + trainer.set_epsilon(1e-11); + + + double obj; + decision_function > df = trainer.train(samples, labels, obj); + dlog << LDEBUG << "obj: "<< obj; + DLIB_TEST_MSG(abs(obj - 0.72222222222) < 1e-7, abs(obj - 0.72222222222)); + // There shouldn't be any margin violations since this dataset is so trivial. So that means the objective + // should be exactly the squared norm of the decision plane (times 0.5). + DLIB_TEST_MSG(abs(length_squared(df.basis_vectors(0))*0.5 + df.b*df.b*0.5 - 0.72222222222) < 1e-7, + length_squared(df.basis_vectors(0))*0.5 + df.b*df.b*0.5); + + DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6); + DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6); + DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6); + DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6); + } + +// ---------------------------------------------------------------------------------------- + + class tester_svm_c_linear : public tester + { + public: + tester_svm_c_linear ( + ) : + tester ("test_svm_c_linear", + "Runs tests on the svm_c_linear_trainer.") + {} + + void perform_test ( + ) + { + test_dense(); + test_sparse(); + run_prior_test(); + run_prior_sparse_test(); + + // test mixed sparse and dense dot products + { + std::map sv; + matrix dv(4); + + dv = 1,2,3,4; + + sv[0] = 1; + sv[3] = 1; + + + DLIB_TEST(dot(sv,dv) == 5); + DLIB_TEST(dot(dv,sv) == 5); + DLIB_TEST(dot(dv,dv) == 30); + DLIB_TEST(dot(sv,sv) == 2); + + sv[10] = 9; + DLIB_TEST(dot(sv,dv) == 5); + } + + // test mixed sparse dense assignments + { + std::map sv, sv2; + std::vector > sv3; + matrix dv(4), dv2; + + dv = 1,2,3,4; + + sv[0] = 1; + sv[3] = 1; + + + assign(dv2, dv); + + DLIB_TEST(dv2.size() == 4); + DLIB_TEST(dv2(0) == 1); + DLIB_TEST(dv2(1) == 2); + DLIB_TEST(dv2(2) == 3); + DLIB_TEST(dv2(3) == 4); + + assign(sv2, dv); + DLIB_TEST(sv2.size() == 4); + DLIB_TEST(sv2[0] == 1); + DLIB_TEST(sv2[1] == 2); + DLIB_TEST(sv2[2] == 3); + DLIB_TEST(sv2[3] == 4); + + assign(sv2, sv); + DLIB_TEST(sv2.size() == 2); + DLIB_TEST(sv2[0] == 1); + DLIB_TEST(sv2[1] == 0); + DLIB_TEST(sv2[2] == 0); + DLIB_TEST(sv2[3] == 1); + + assign(sv3, sv); + DLIB_TEST(sv3.size() == 2); + DLIB_TEST(sv3[0].second == 1); + DLIB_TEST(sv3[1].second == 1); + DLIB_TEST(sv3[0].first == 0); + DLIB_TEST(sv3[1].first == 3); + + assign(sv3, dv); + DLIB_TEST(sv3.size() == 4); + DLIB_TEST(sv3[0].second == 1); + DLIB_TEST(sv3[1].second == 2); + DLIB_TEST(sv3[2].second == 3); + DLIB_TEST(sv3[3].second == 4); + DLIB_TEST(sv3[0].first == 0); + DLIB_TEST(sv3[1].first == 1); + DLIB_TEST(sv3[2].first == 2); + DLIB_TEST(sv3[3].first == 3); + + assign(sv3, sv); + DLIB_TEST(sv3.size() == 2); + DLIB_TEST(sv3[0].second == 1); + DLIB_TEST(sv3[1].second == 1); + DLIB_TEST(sv3[0].first == 0); + DLIB_TEST(sv3[1].first == 3); + + sv.clear(); + assign(sv, sv3); + DLIB_TEST(sv.size() == 2); + DLIB_TEST(sv[0] == 1); + DLIB_TEST(sv[1] == 0); + DLIB_TEST(sv[2] == 0); + DLIB_TEST(sv[3] == 1); + + } + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/svm_c_linear_dcd.cpp b/ml/dlib/dlib/test/svm_c_linear_dcd.cpp new file mode 100644 index 000000000..93c99db30 --- /dev/null +++ b/ml/dlib/dlib/test/svm_c_linear_dcd.cpp @@ -0,0 +1,545 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.svm_c_linear_dcd"); + +// ---------------------------------------------------------------------------------------- + + void test_sparse() + { + typedef std::map sample_type; + + + typedef sparse_linear_kernel kernel_type; + + + + svm_c_linear_trainer linear_trainer_cpa; + svm_c_linear_dcd_trainer linear_trainer; + + svm_c_linear_dcd_trainer::optimizer_state state; + + const double C = 0.2; + linear_trainer.set_epsilon(1e-10); + linear_trainer_cpa.set_epsilon(1e-10); + + + std::vector samples; + std::vector labels; + + // make an instance of a sample vector so we can use it below + sample_type sample; + + decision_function df, df2, df3; + + dlib::rand rnd; + // Now lets go into a loop and randomly generate 10000 samples. + double label = +1; + for (int i = 0; i < 100; ++i) + { + // flip this flag + label *= -1; + + sample.clear(); + + // now make a random sparse sample with at most 10 non-zero elements + for (int j = 0; j < 5; ++j) + { + int idx = rnd.get_random_32bit_number()%10; + double value = rnd.get_random_double(); + + sample[idx] = label*value; + } + + // Also save the samples we are generating so we can let the svm_c_linear_trainer + // learn from them below. + samples.push_back(sample); + labels.push_back(label); + + if (samples.size() > 1) + { + linear_trainer_cpa.set_c_class1(C); + linear_trainer_cpa.set_c_class2(1.5*C); + linear_trainer.set_c_class1(C/samples.size()); + linear_trainer.set_c_class2(1.5*C/samples.size()); + + df = linear_trainer.train(samples, labels, state); + df2 = linear_trainer_cpa.train(samples, labels); + df3 = linear_trainer.train(samples, labels); + + DLIB_TEST_MSG( dlib::distance(df.basis_vectors(0), df2.basis_vectors(0)) < 1e-8, dlib::distance(df.basis_vectors(0), df2.basis_vectors(0))); + DLIB_TEST( std::abs(df.b - df2.b) < 1e-8); + DLIB_TEST( dlib::distance(df.basis_vectors(0), df3.basis_vectors(0)) < 1e-8); + DLIB_TEST( std::abs(df.b - df3.b) < 1e-8); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_normal_no_bias() + { + typedef matrix sample_type; + + + typedef linear_kernel kernel_type; + + + + svm_c_linear_trainer linear_trainer_cpa; + svm_c_linear_dcd_trainer linear_trainer; + + svm_c_linear_dcd_trainer::optimizer_state state; + + const double C = 1.0; + linear_trainer.set_epsilon(1e-10); + linear_trainer_cpa.set_epsilon(1e-10); + + linear_trainer.include_bias(false); + + + std::vector samples, samples_explict_bias; + std::vector labels; + + // make an instance of a sample vector so we can use it below + sample_type sample; + + decision_function df, df2, df3; + + dlib::rand rnd; + // Now lets go into a loop and randomly generate 10000 samples. + double label = +1; + for (int i = 0; i < 100; ++i) + { + // flip this flag + label *= -1; + + sample = 0; + + // now make a random sparse sample with at most 10 non-zero elements + for (int j = 0; j < 5; ++j) + { + int idx = rnd.get_random_32bit_number()%9; + double value = rnd.get_random_double(); + + sample(idx) = label*value; + } + + // Also save the samples we are generating so we can let the svm_c_linear_trainer + // learn from them below. + samples.push_back(sample); + labels.push_back(label); + + sample(9) = -1; + samples_explict_bias.push_back(sample); + + if (samples.size() > 1) + { + linear_trainer_cpa.set_c_class1(C); + linear_trainer_cpa.set_c_class2(1.5*C); + linear_trainer.set_c_class1(C/samples.size()); + linear_trainer.set_c_class2(1.5*C/samples.size()); + + df = linear_trainer.train(samples_explict_bias, labels, state); + df2 = linear_trainer_cpa.train(samples, labels); + df3 = linear_trainer.train(samples_explict_bias, labels); + + DLIB_TEST( std::abs(df2.basis_vectors(0)(9)) < 1e-7); + DLIB_TEST_MSG( max(abs(colm(df.basis_vectors(0),0,9) - colm(df2.basis_vectors(0),0,9))) < 1e-6, max(abs(colm(df.basis_vectors(0),0,9) - colm(df2.basis_vectors(0),0,9)))); + DLIB_TEST( std::abs(df.basis_vectors(0)(9) - df2.b) < 1e-6); + DLIB_TEST( max(abs(df.basis_vectors(0) - df3.basis_vectors(0))) < 1e-6); + DLIB_TEST( std::abs(df.b - df3.b) < 1e-7); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_normal() + { + typedef matrix sample_type; + + + typedef linear_kernel kernel_type; + + + + svm_c_linear_trainer linear_trainer_cpa; + svm_c_linear_dcd_trainer linear_trainer; + + svm_c_linear_dcd_trainer::optimizer_state state; + + const double C = 1; + linear_trainer.set_epsilon(1e-10); + linear_trainer_cpa.set_epsilon(1e-10); + + std::vector samples; + std::vector labels; + + // make an instance of a sample vector so we can use it below + sample_type sample; + + decision_function df, df2, df3; + + dlib::rand rnd; + // Now lets go into a loop and randomly generate 10000 samples. + double label = +1; + for (int i = 0; i < 100; ++i) + { + // flip this flag + label *= -1; + + sample = 0; + + // now make a random sparse sample with at most 10 non-zero elements + for (int j = 0; j < 5; ++j) + { + int idx = rnd.get_random_32bit_number()%10; + double value = rnd.get_random_double(); + + sample(idx) = label*value; + } + + // Also save the samples we are generating so we can let the svm_c_linear_trainer + // learn from them below. + samples.push_back(sample); + labels.push_back(label); + + if (samples.size() > 1) + { + linear_trainer_cpa.set_c_class1(C); + linear_trainer_cpa.set_c_class2(1.5*C); + linear_trainer.set_c_class1(C/samples.size()); + linear_trainer.set_c_class2(1.5*C/samples.size()); + + df = linear_trainer.train(samples, labels, state); + df2 = linear_trainer_cpa.train(samples, labels); + df3 = linear_trainer.train(samples, labels); + + DLIB_TEST_MSG( max(abs(df.basis_vectors(0) - df2.basis_vectors(0))) < 1e-7, max(abs(df.basis_vectors(0) - df2.basis_vectors(0)))); + DLIB_TEST( std::abs(df.b - df2.b) < 1e-7); + DLIB_TEST( max(abs(df.basis_vectors(0) - df3.basis_vectors(0))) < 1e-7); + DLIB_TEST( std::abs(df.b - df3.b) < 1e-7); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_normal_force_last_weight(bool have_bias, bool force_weight) + { + typedef matrix sample_type; + dlog << LINFO << "have_bias: "<< have_bias << " force_weight: "<< force_weight; + + + typedef linear_kernel kernel_type; + + + svm_c_linear_trainer linear_trainer_cpa; + + svm_c_linear_dcd_trainer linear_trainer; + + svm_c_linear_dcd_trainer::optimizer_state state; + + const double C = 1; + linear_trainer.set_epsilon(1e-10); + linear_trainer_cpa.set_epsilon(1e-11); + + linear_trainer_cpa.force_last_weight_to_1(force_weight); + + linear_trainer.force_last_weight_to_1(force_weight); + linear_trainer.include_bias(have_bias); + + std::vector samples; + std::vector labels; + + // make an instance of a sample vector so we can use it below + sample_type sample; + + decision_function df, df2; + + running_stats rs; + + dlib::rand rnd; + // Now lets go into a loop and randomly generate 10000 samples. + double label = +1; + for (int i = 0; i < 40; ++i) + { + // flip this flag + label *= -1; + + sample = 0; + + // now make a random sparse sample with at most 10 non-zero elements + for (int j = 0; j < 5; ++j) + { + int idx = rnd.get_random_32bit_number()%9; + double value = rnd.get_random_double(); + + sample(idx) = label*value + label; + } + + sample(9) = 4; + + // Also save the samples we are generating so we can let the svm_c_linear_trainer + // learn from them below. + samples.push_back(sample); + labels.push_back(label); + + linear_trainer.set_c(C); + linear_trainer_cpa.set_c(C*samples.size()); + + df = linear_trainer.train(samples, labels, state); + + if (force_weight) + { + DLIB_TEST(std::abs(df.basis_vectors(0)(9) - 1) < 1e-8); + DLIB_TEST(std::abs(df.b) < 1e-8); + + if (samples.size() > 1) + { + df2 = linear_trainer_cpa.train(samples, labels); + DLIB_TEST_MSG( max(abs(df.basis_vectors(0) - df2.basis_vectors(0))) < 1e-7, max(abs(df.basis_vectors(0) - df2.basis_vectors(0)))); + DLIB_TEST( std::abs(df.b - df2.b) < 1e-7); + } + } + + if (!have_bias) + DLIB_TEST(std::abs(df.b) < 1e-8); + + + for (unsigned long k = 0; k < samples.size(); ++k) + { + //cout << "pred: "<< labels[k]*df(samples[k]) << endl; + rs.add(labels[k]*df(samples[k])); + } + } + DLIB_TEST_MSG(std::abs(rs.min()-1) < 1e-7, std::abs(rs.min()-1)); + } + +// ---------------------------------------------------------------------------------------- + + void test_normal_1_sample(double label) + { + typedef matrix sample_type; + + + typedef linear_kernel kernel_type; + + + + svm_c_linear_dcd_trainer linear_trainer; + + svm_c_linear_dcd_trainer::optimizer_state state; + + const double C = 10; + linear_trainer.set_epsilon(1e-10); + linear_trainer.set_c(C); + + + linear_trainer.force_last_weight_to_1(true); + linear_trainer.include_bias(false); + + std::vector samples; + std::vector labels; + + // make an instance of a sample vector so we can use it below + sample_type sample; + + sample = 0; + sample(0) = -1; + sample(1) = -1; + sample(9) = 4; + + samples.push_back(sample); + labels.push_back(label); + + for (int i = 0; i < 4; ++i) + { + decision_function df; + df = linear_trainer.train(samples, labels); + + if (label > 0) + { + DLIB_TEST(std::abs(df(samples[0])-4) < 1e-8); + } + else + { + DLIB_TEST(std::abs(df(samples[0])+1) < 1e-8); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_sparse_1_sample(double label) + { + typedef std::vector > sample_type; + + + typedef sparse_linear_kernel kernel_type; + + + + svm_c_linear_dcd_trainer linear_trainer; + + svm_c_linear_dcd_trainer::optimizer_state state; + + const double C = 10; + linear_trainer.set_epsilon(1e-10); + linear_trainer.set_c(C); + + + linear_trainer.force_last_weight_to_1(true); + linear_trainer.include_bias(false); + + std::vector samples; + std::vector labels; + + // make an instance of a sample vector so we can use it below + sample_type sample; + + sample.push_back(make_pair(0,-1)); + sample.push_back(make_pair(1,1)); + sample.push_back(make_pair(9,4)); + + for (int i = 0; i < 4; ++i) + { + samples.push_back(sample); + labels.push_back(label); + + decision_function df; + df = linear_trainer.train(samples, labels); + + + if (label > 0) + { + DLIB_TEST(std::abs(df(samples[0])-4) < 1e-8); + } + else + { + DLIB_TEST(std::abs(df(samples[0])+1) < 1e-8); + } + } + } + +// ---------------------------------------------------------------------------------------- + + void test_l2_version () + { + typedef std::map sample_type; + typedef sparse_linear_kernel kernel_type; + + svm_c_linear_dcd_trainer linear_trainer; + linear_trainer.set_c(10); + linear_trainer.set_epsilon(1e-5); + + std::vector samples; + std::vector labels; + + // make an instance of a sample vector so we can use it below + sample_type sample; + + + // Now let's go into a loop and randomly generate 10000 samples. + double label = +1; + for (int i = 0; i < 1000; ++i) + { + // flip this flag + label *= -1; + + sample.clear(); + + // now make a random sparse sample with at most 10 non-zero elements + for (int j = 0; j < 10; ++j) + { + int idx = std::rand()%100; + double value = static_cast(std::rand())/RAND_MAX; + + sample[idx] = label*value; + } + + // Also save the samples we are generating so we can let the svm_c_linear_trainer + // learn from them below. + samples.push_back(sample); + labels.push_back(label); + } + + decision_function df = linear_trainer.train(samples, labels); + + sample.clear(); + sample[4] = 0.3; + sample[10] = 0.9; + DLIB_TEST(df(sample) > 0); + + sample.clear(); + sample[83] = -0.3; + sample[26] = -0.9; + sample[58] = -0.7; + DLIB_TEST(df(sample) < 0); + + sample.clear(); + sample[0] = -0.2; + sample[9] = -0.8; + DLIB_TEST(df(sample) < 0); + } + + class tester_svm_c_linear_dcd : public tester + { + public: + tester_svm_c_linear_dcd ( + ) : + tester ("test_svm_c_linear_dcd", + "Runs tests on the svm_c_linear_dcd_trainer.") + {} + + void perform_test ( + ) + { + test_normal(); + print_spinner(); + test_normal_no_bias(); + print_spinner(); + test_sparse(); + print_spinner(); + test_normal_force_last_weight(false,false); + print_spinner(); + test_normal_force_last_weight(false,true); + print_spinner(); + test_normal_force_last_weight(true,false); + print_spinner(); + test_normal_force_last_weight(true,true); + print_spinner(); + test_normal_1_sample(+1); + print_spinner(); + test_normal_1_sample(-1); + print_spinner(); + test_sparse_1_sample(+1); + print_spinner(); + test_sparse_1_sample(-1); + print_spinner(); + + test_l2_version(); + } + } a; + +} + + + + diff --git a/ml/dlib/dlib/test/svm_multiclass_linear.cpp b/ml/dlib/dlib/test/svm_multiclass_linear.cpp new file mode 100644 index 000000000..e01d48892 --- /dev/null +++ b/ml/dlib/dlib/test/svm_multiclass_linear.cpp @@ -0,0 +1,226 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include "create_iris_datafile.h" +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.svm_multiclass_trainer"); + + + class test_svm_multiclass_trainer : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + test_svm_multiclass_trainer ( + ) : + tester ( + "test_svm_multiclass_trainer", // the command line argument name for this test + "Run tests on the svm_multiclass_linear_trainer stuff.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + + void test_prior () + { + print_spinner(); + typedef matrix sample_type; + typedef linear_kernel kernel_type; + + std::vector samples; + std::vector labels; + + for (int i = 0; i < 4; ++i) + { + if (i==2) + ++i; + for (int iter = 0; iter < 5; ++iter) + { + sample_type samp; + samp = 0; + samp(i) = 1; + samples.push_back(samp); + labels.push_back(i); + } + } + + + svm_multiclass_linear_trainer trainer; + + multiclass_linear_decision_function df = trainer.train(samples, labels); + + //cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl; + //cout << df.weights << endl; + //cout << df.b << endl; + + std::vector samples2; + std::vector labels2; + int i = 2; + for (int iter = 0; iter < 5; ++iter) + { + sample_type samp; + samp = 0; + samp(i) = 1; + samples2.push_back(samp); + labels2.push_back(i); + samples.push_back(samp); + labels.push_back(i); + } + + trainer.set_prior(df); + trainer.set_c(0.1); + df = trainer.train(samples2, labels2); + + matrix res = test_multiclass_decision_function(df, samples, labels); + dlog << LINFO << "test: \n" << res; + dlog << LINFO << df.weights; + dlog << LINFO << df.b; + DLIB_TEST((unsigned int)sum(diag(res))==samples.size()); + } + + void test_prior_sparse () + { + print_spinner(); + typedef std::map sample_type; + typedef sparse_linear_kernel kernel_type; + + std::vector samples; + std::vector labels; + + for (int i = 0; i < 4; ++i) + { + if (i==2) + ++i; + for (int iter = 0; iter < 5; ++iter) + { + sample_type samp; + samp[i] = 1; + samples.push_back(samp); + labels.push_back(i); + } + } + + + svm_multiclass_linear_trainer trainer; + + multiclass_linear_decision_function df = trainer.train(samples, labels); + + //cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl; + //cout << df.weights << endl; + //cout << df.b << endl; + + std::vector samples2; + std::vector labels2; + int i = 2; + for (int iter = 0; iter < 5; ++iter) + { + sample_type samp; + samp[i] = 1; + samp[i+10] = 1; + samples2.push_back(samp); + labels2.push_back(i); + samples.push_back(samp); + labels.push_back(i); + } + + trainer.set_prior(df); + trainer.set_c(0.1); + df = trainer.train(samples2, labels2); + + matrix res = test_multiclass_decision_function(df, samples, labels); + dlog << LINFO << "test: \n" << res; + dlog << LINFO << df.weights; + dlog << LINFO << df.b; + DLIB_TEST((unsigned int)sum(diag(res))==samples.size()); + } + + template + void run_test() + { + print_spinner(); + + typedef typename sample_type::value_type::second_type scalar_type; + + std::vector samples; + std::vector labels; + + load_libsvm_formatted_data("iris.scale",samples, labels); + + DLIB_TEST(samples.size() == 150); + DLIB_TEST(labels.size() == 150); + + typedef sparse_linear_kernel kernel_type; + svm_multiclass_linear_trainer trainer; + trainer.set_c(100); + trainer.set_epsilon(0.000001); + + randomize_samples(samples, labels); + matrix cv = cross_validate_multiclass_trainer(trainer, samples, labels, 4); + + dlog << LINFO << "confusion matrix: \n" << cv; + const scalar_type cv_accuracy = sum(diag(cv))/sum(cv); + dlog << LINFO << "cv accuracy: " << cv_accuracy; + DLIB_TEST(cv_accuracy > 0.97); + + + + + { + print_spinner(); + typedef matrix dsample_type; + std::vector dsamples = sparse_to_dense(samples); + DLIB_TEST(dsamples.size() == 150); + + typedef linear_kernel kernel_type; + svm_multiclass_linear_trainer trainer; + trainer.set_c(100); + + cv = cross_validate_multiclass_trainer(trainer, dsamples, labels, 4); + + dlog << LINFO << "dense confusion matrix: \n" << cv; + const scalar_type cv_accuracy = sum(diag(cv))/sum(cv); + dlog << LINFO << "dense cv accuracy: " << cv_accuracy; + DLIB_TEST(cv_accuracy > 0.97); + } + + } + + + + + void perform_test ( + ) + { + print_spinner(); + create_iris_datafile(); + + run_test >(); + run_test >(); + run_test > >(); + run_test > >(); + + test_prior(); + test_prior_sparse(); + } + }; + + test_svm_multiclass_trainer a; + +} + + diff --git a/ml/dlib/dlib/test/svm_struct.cpp b/ml/dlib/dlib/test/svm_struct.cpp new file mode 100644 index 000000000..00208c48d --- /dev/null +++ b/ml/dlib/dlib/test/svm_struct.cpp @@ -0,0 +1,641 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.svm_struct"); + + + template < + typename matrix_type, + typename sample_type, + typename label_type + > + class test_multiclass_svm_problem : public structural_svm_problem_threaded > > + { + + public: + typedef typename matrix_type::type scalar_type; + typedef std::vector > feature_vector_type; + + test_multiclass_svm_problem ( + const std::vector& samples_, + const std::vector& labels_ + ) : + structural_svm_problem_threaded > >(2), + samples(samples_), + labels(labels_), + dims(10+1) // +1 for the bias + { + for (int i = 0; i < 10; ++i) + { + distinct_labels.push_back(i); + } + } + + virtual long get_num_dimensions ( + ) const + { + return dims*10; + } + + virtual long get_num_samples ( + ) const + { + return static_cast(samples.size()); + } + + virtual void get_truth_joint_feature_vector ( + long idx, + feature_vector_type& psi + ) const + { + assign(psi, samples[idx]); + // Add a constant -1 to account for the bias term. + psi.push_back(std::make_pair(dims-1,static_cast(-1))); + + // Find which distinct label goes with this psi. + const long label_idx = index_of_max(mat(distinct_labels) == labels[idx]); + + offset_feature_vector(psi, dims*label_idx); + } + + virtual void separation_oracle ( + const long idx, + const matrix_type& current_solution, + scalar_type& loss, + feature_vector_type& psi + ) const + { + scalar_type best_val = -std::numeric_limits::infinity(); + unsigned long best_idx = 0; + + // Figure out which label is the best. That is, what label maximizes + // LOSS(idx,y) + F(x,y). Note that y in this case is given by distinct_labels[i]. + for (unsigned long i = 0; i < distinct_labels.size(); ++i) + { + // Compute the F(x,y) part: + // perform: temp == dot(relevant part of current solution, samples[idx]) - current_bias + scalar_type temp = dot(rowm(current_solution, range(i*dims, (i+1)*dims-2)), samples[idx]) - current_solution((i+1)*dims-1); + + // Add the LOSS(idx,y) part: + if (labels[idx] != distinct_labels[i]) + temp += 1; + + // Now temp == LOSS(idx,y) + F(x,y). Check if it is the biggest we have seen. + if (temp > best_val) + { + best_val = temp; + best_idx = i; + } + } + + assign(psi, samples[idx]); + // add a constant -1 to account for the bias term + psi.push_back(std::make_pair(dims-1,static_cast(-1))); + + offset_feature_vector(psi, dims*best_idx); + + if (distinct_labels[best_idx] == labels[idx]) + loss = 0; + else + loss = 1; + } + + private: + + void offset_feature_vector ( + feature_vector_type& sample, + const unsigned long val + ) const + { + if (val != 0) + { + for (typename feature_vector_type::iterator i = sample.begin(); i != sample.end(); ++i) + { + i->first += val; + } + } + } + + + const std::vector& samples; + const std::vector& labels; + std::vector distinct_labels; + const long dims; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename label_type_ = typename K::scalar_type + > + class test_svm_multiclass_linear_trainer2 + { + public: + typedef label_type_ label_type; + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + typedef multiclass_linear_decision_function trained_function_type; + + + test_svm_multiclass_linear_trainer2 ( + ) : + C(10), + eps(1e-4), + verbose(false) + { + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const + { + scalar_type svm_objective = 0; + return train(all_samples, all_labels, svm_objective); + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels, + scalar_type& svm_objective + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(all_samples,all_labels), + "\t trained_function_type test_svm_multiclass_linear_trainer2::train(all_samples,all_labels)" + << "\n\t invalid inputs were given to this function" + << "\n\t all_samples.size(): " << all_samples.size() + << "\n\t all_labels.size(): " << all_labels.size() + ); + + typedef matrix w_type; + w_type weights; + std::vector samples1(all_samples.begin(), all_samples.begin()+all_samples.size()/2); + std::vector samples2(all_samples.begin()+all_samples.size()/2, all_samples.end()); + + std::vector labels1(all_labels.begin(), all_labels.begin()+all_labels.size()/2); + std::vector labels2(all_labels.begin()+all_labels.size()/2, all_labels.end()); + test_multiclass_svm_problem problem1(samples1, labels1); + test_multiclass_svm_problem problem2(samples2, labels2); + problem1.set_max_cache_size(3); + problem2.set_max_cache_size(0); + + svm_struct_processing_node node1(problem1, 12345, 3); + svm_struct_processing_node node2(problem2, 12346, 0); + + solver.set_inactive_plane_threshold(50); + solver.set_subproblem_epsilon(1e-4); + + svm_struct_controller_node controller; + controller.set_c(C); + controller.set_epsilon(eps); + if (verbose) + controller.be_verbose(); + controller.add_processing_node("127.0.0.1", 12345); + controller.add_processing_node("localhost:12346"); + svm_objective = controller(solver, weights); + + + + trained_function_type df; + + const long dims = max_index_plus_one(all_samples); + df.labels = select_all_distinct_labels(all_labels); + df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1)); + df.b = colm(reshape(weights, df.labels.size(), dims+1), dims); + return df; + } + + private: + scalar_type C; + scalar_type eps; + bool verbose; + mutable oca solver; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename label_type_ = typename K::scalar_type + > + class test_svm_multiclass_linear_trainer3 + { + public: + typedef label_type_ label_type; + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + typedef multiclass_linear_decision_function trained_function_type; + + + test_svm_multiclass_linear_trainer3 ( + ) : + C(10), + eps(1e-4), + verbose(false) + { + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const + { + scalar_type svm_objective = 0; + return train(all_samples, all_labels, svm_objective); + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels, + scalar_type& svm_objective + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(all_samples,all_labels), + "\t trained_function_type test_svm_multiclass_linear_trainer3::train(all_samples,all_labels)" + << "\n\t invalid inputs were given to this function" + << "\n\t all_samples.size(): " << all_samples.size() + << "\n\t all_labels.size(): " << all_labels.size() + ); + + typedef matrix w_type; + w_type weights; + test_multiclass_svm_problem problem(all_samples, all_labels); + problem.set_max_cache_size(0); + + problem.set_c(C); + problem.set_epsilon(eps); + + if (verbose) + problem.be_verbose(); + + solver.set_inactive_plane_threshold(50); + solver.set_subproblem_epsilon(1e-4); + svm_objective = solver(problem, weights); + + + trained_function_type df; + + const long dims = max_index_plus_one(all_samples); + df.labels = select_all_distinct_labels(all_labels); + df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1)); + df.b = colm(reshape(weights, df.labels.size(), dims+1), dims); + return df; + } + + private: + scalar_type C; + scalar_type eps; + bool verbose; + mutable oca solver; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename label_type_ = typename K::scalar_type + > + class test_svm_multiclass_linear_trainer4 + { + public: + typedef label_type_ label_type; + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + typedef multiclass_linear_decision_function trained_function_type; + + + test_svm_multiclass_linear_trainer4 ( + ) : + C(10), + eps(1e-4), + verbose(false) + { + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const + { + scalar_type svm_objective = 0; + return train(all_samples, all_labels, svm_objective); + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels, + scalar_type& svm_objective + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(all_samples,all_labels), + "\t trained_function_type test_svm_multiclass_linear_trainer4::train(all_samples,all_labels)" + << "\n\t invalid inputs were given to this function" + << "\n\t all_samples.size(): " << all_samples.size() + << "\n\t all_labels.size(): " << all_labels.size() + ); + + typedef matrix w_type; + w_type weights; + test_multiclass_svm_problem problem(all_samples, all_labels); + problem.set_max_cache_size(3); + + problem.set_c(C); + problem.set_epsilon(eps); + + if (verbose) + problem.be_verbose(); + + solver.set_inactive_plane_threshold(50); + solver.set_subproblem_epsilon(1e-4); + svm_objective = solver(problem, weights); + + + trained_function_type df; + + const long dims = max_index_plus_one(all_samples); + df.labels = select_all_distinct_labels(all_labels); + df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1)); + df.b = colm(reshape(weights, df.labels.size(), dims+1), dims); + return df; + } + + private: + scalar_type C; + scalar_type eps; + bool verbose; + mutable oca solver; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename K, + typename label_type_ = typename K::scalar_type + > + class test_svm_multiclass_linear_trainer5 + { + public: + typedef label_type_ label_type; + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + + typedef multiclass_linear_decision_function trained_function_type; + + + test_svm_multiclass_linear_trainer5 ( + ) : + C(10), + eps(1e-4), + verbose(false) + { + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels + ) const + { + scalar_type svm_objective = 0; + return train(all_samples, all_labels, svm_objective); + } + + trained_function_type train ( + const std::vector& all_samples, + const std::vector& all_labels, + scalar_type& svm_objective + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(all_samples,all_labels), + "\t trained_function_type test_svm_multiclass_linear_trainer5::train(all_samples,all_labels)" + << "\n\t invalid inputs were given to this function" + << "\n\t all_samples.size(): " << all_samples.size() + << "\n\t all_labels.size(): " << all_labels.size() + ); + + typedef matrix w_type; + w_type weights; + const long dims = max_index_plus_one(all_samples); + trained_function_type df; + df.labels = select_all_distinct_labels(all_labels); + multiclass_svm_problem problem(all_samples, all_labels, df.labels, dims, 4); + problem.set_max_cache_size(3); + + problem.set_c(C); + problem.set_epsilon(eps); + + if (verbose) + problem.be_verbose(); + + solver.set_inactive_plane_threshold(50); + solver.set_subproblem_epsilon(1e-4); + svm_objective = solver(problem, weights); + + + + df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1)); + df.b = colm(reshape(weights, df.labels.size(), dims+1), dims); + return df; + } + + private: + scalar_type C; + scalar_type eps; + bool verbose; + mutable oca solver; + }; + + +// ---------------------------------------------------------------------------------------- + + typedef matrix sample_type; + typedef double scalar_type; + + void make_dataset ( + std::vector& samples, + std::vector& labels, + int num, + dlib::rand& rnd + ) + { + samples.clear(); + labels.clear(); + for (int i = 0; i < 10; ++i) + { + for (int j = 0; j < num; ++j) + { + sample_type samp; + samp = 0; + samp(i) = 10*rnd.get_random_double()+1; + + samples.push_back(samp); + labels.push_back(i); + } + } + } + +// ---------------------------------------------------------------------------------------- + + class test_svm_struct : public tester + { + public: + test_svm_struct ( + ) : + tester ("test_svm_struct", + "Runs tests on the structural svm components.") + {} + + void run_test ( + const std::vector& samples, + const std::vector& labels, + const double true_obj + ) + { + typedef linear_kernel kernel_type; + svm_multiclass_linear_trainer trainer1; + test_svm_multiclass_linear_trainer2 trainer2; + test_svm_multiclass_linear_trainer3 trainer3; + test_svm_multiclass_linear_trainer4 trainer4; + test_svm_multiclass_linear_trainer5 trainer5; + + trainer1.set_epsilon(1e-4); + trainer1.set_c(10); + + + multiclass_linear_decision_function df1, df2, df3, df4, df5; + double obj1, obj2, obj3, obj4, obj5; + + // Solve a multiclass SVM a whole bunch of different ways and make sure + // they all give the same answer. + print_spinner(); + df1 = trainer1.train(samples, labels, obj1); + print_spinner(); + df2 = trainer2.train(samples, labels, obj2); + print_spinner(); + df3 = trainer3.train(samples, labels, obj3); + print_spinner(); + df4 = trainer4.train(samples, labels, obj4); + print_spinner(); + df5 = trainer5.train(samples, labels, obj5); + print_spinner(); + + dlog << LINFO << "obj1: "<< obj1; + dlog << LINFO << "obj2: "<< obj2; + dlog << LINFO << "obj3: "<< obj3; + dlog << LINFO << "obj4: "<< obj4; + dlog << LINFO << "obj5: "<< obj5; + DLIB_TEST(std::abs(obj1 - obj2) < 1e-2); + DLIB_TEST(std::abs(obj1 - obj3) < 1e-2); + DLIB_TEST(std::abs(obj1 - obj4) < 1e-2); + DLIB_TEST(std::abs(obj1 - obj5) < 1e-2); + DLIB_TEST(std::abs(obj1 - true_obj) < 1e-2); + DLIB_TEST(std::abs(obj2 - true_obj) < 1e-2); + DLIB_TEST(std::abs(obj3 - true_obj) < 1e-2); + DLIB_TEST(std::abs(obj4 - true_obj) < 1e-2); + DLIB_TEST(std::abs(obj5 - true_obj) < 1e-2); + + dlog << LINFO << "weight error: "<< max(abs(df1.weights - df2.weights)); + dlog << LINFO << "weight error: "<< max(abs(df1.weights - df3.weights)); + dlog << LINFO << "weight error: "<< max(abs(df1.weights - df4.weights)); + dlog << LINFO << "weight error: "<< max(abs(df1.weights - df5.weights)); + + DLIB_TEST(max(abs(df1.weights - df2.weights)) < 1e-2); + DLIB_TEST(max(abs(df1.weights - df3.weights)) < 1e-2); + DLIB_TEST(max(abs(df1.weights - df4.weights)) < 1e-2); + DLIB_TEST(max(abs(df1.weights - df5.weights)) < 1e-2); + + dlog << LINFO << "b error: "<< max(abs(df1.b - df2.b)); + dlog << LINFO << "b error: "<< max(abs(df1.b - df3.b)); + dlog << LINFO << "b error: "<< max(abs(df1.b - df4.b)); + dlog << LINFO << "b error: "<< max(abs(df1.b - df5.b)); + DLIB_TEST(max(abs(df1.b - df2.b)) < 1e-2); + DLIB_TEST(max(abs(df1.b - df3.b)) < 1e-2); + DLIB_TEST(max(abs(df1.b - df4.b)) < 1e-2); + DLIB_TEST(max(abs(df1.b - df5.b)) < 1e-2); + + matrix res = test_multiclass_decision_function(df1, samples, labels); + dlog << LINFO << res; + dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res); + DLIB_TEST(sum(diag(res)) == samples.size()); + + res = test_multiclass_decision_function(df2, samples, labels); + dlog << LINFO << res; + dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res); + DLIB_TEST(sum(diag(res)) == samples.size()); + + res = test_multiclass_decision_function(df3, samples, labels); + dlog << LINFO << res; + dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res); + DLIB_TEST(sum(diag(res)) == samples.size()); + + res = test_multiclass_decision_function(df4, samples, labels); + dlog << LINFO << res; + dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res); + DLIB_TEST(sum(diag(res)) == samples.size()); + + res = test_multiclass_decision_function(df5, samples, labels); + dlog << LINFO << res; + dlog << LINFO << "accuracy: " << sum(diag(res))/sum(res); + DLIB_TEST(sum(diag(res)) == samples.size()); + } + + void perform_test ( + ) + { + std::vector samples; + std::vector labels; + + dlib::rand rnd; + + dlog << LINFO << "test with 100 samples per class"; + make_dataset(samples, labels, 100, rnd); + run_test(samples, labels, 1.155); + + dlog << LINFO << "test with 1 sample per class"; + make_dataset(samples, labels, 1, rnd); + run_test(samples, labels, 0.251); + + dlog << LINFO << "test with 2 sample per class"; + make_dataset(samples, labels, 2, rnd); + run_test(samples, labels, 0.444); + } + } a; + + + +} + + + + diff --git a/ml/dlib/dlib/test/svr_linear_trainer.cpp b/ml/dlib/dlib/test/svr_linear_trainer.cpp new file mode 100644 index 000000000..ca1a5442f --- /dev/null +++ b/ml/dlib/dlib/test/svr_linear_trainer.cpp @@ -0,0 +1,161 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" +#include + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.svr_linear_trainer"); + + typedef matrix sample_type; + typedef std::vector > sparse_sample_type; + +// ---------------------------------------------------------------------------------------- + + double sinc(double x) + { + if (x == 0) + return 1; + return sin(x)/x; + } + + template + void test1() + { + typedef matrix sample_type; + + typedef radial_basis_kernel kernel_type; + + print_spinner(); + + std::vector samples; + std::vector targets; + + // The first thing we do is pick a few training points from the sinc() function. + sample_type m(1); + for (scalar_type x = -10; x <= 4; x += 1) + { + m(0) = x; + + samples.push_back(m); + targets.push_back(sinc(x)+1.1); + } + + randomize_samples(samples, targets); + + empirical_kernel_map ekm; + ekm.load(kernel_type(0.1), samples); + + for (unsigned long i = 0; i < samples.size(); ++i) + samples[i] = ekm.project(samples[i]); + + svr_linear_trainer > linear_trainer; + linear_trainer.set_epsilon(0.0001); + linear_trainer.set_c(30); + linear_trainer.set_epsilon_insensitivity(0.001); + + matrix res = cross_validate_regression_trainer(linear_trainer, samples, targets, 5); + dlog << LINFO << "MSE and R-Squared: "<< res; + DLIB_TEST(res(0) < 1e-4); + DLIB_TEST(res(1) > 0.99); + + dlib::rand rnd; + + samples.clear(); + targets.clear(); + std::vector noisefree_targets; + for (scalar_type x = 0; x <= 5; x += 0.1) + { + m(0) = x; + samples.push_back(matrix_cast(linpiece(m, linspace(0,5,20)))); + targets.push_back(x*x + rnd.get_random_gaussian()); + noisefree_targets.push_back(x*x); + } + linear_trainer.set_learns_nonnegative_weights(true); + linear_trainer.set_epsilon_insensitivity(1.0); + decision_function > df2 = linear_trainer.train(samples, targets); + + print_spinner(); + res = test_regression_function(df2, samples, noisefree_targets); + dlog << LINFO << "MSE and R-Squared: "<< res; + DLIB_TEST(res(0) < 0.15); + DLIB_TEST(res(1) > 0.98); + DLIB_TEST(df2.basis_vectors.size()==1); + DLIB_TEST(max(df2.basis_vectors(0)) >= 0); + + linear_trainer.force_last_weight_to_1(true); + df2 = linear_trainer.train(samples, targets); + DLIB_TEST(std::abs(df2.basis_vectors(0)(samples[0].size()-1) - 1.0) < 1e-14); + + res = test_regression_function(df2, samples, noisefree_targets); + dlog << LINFO << "MSE and R-Squared: "<< res; + DLIB_TEST(res(0) < 0.20); + DLIB_TEST(res(1) > 0.98); + + + // convert into sparse vectors and try it out + typedef std::vector > sparse_samp; + std::vector ssamples; + for (unsigned long i = 0; i < samples.size(); ++i) + { + sparse_samp s; + for (long j = 0; j < samples[i].size(); ++j) + s.push_back(make_pair(j,samples[i](j))); + ssamples.push_back(s); + } + + svr_linear_trainer > strainer; + strainer.set_learns_nonnegative_weights(true); + strainer.set_epsilon_insensitivity(1.0); + strainer.set_c(30); + decision_function > df; + df = strainer.train(ssamples, targets); + res = test_regression_function(df, ssamples, noisefree_targets); + dlog << LINFO << "MSE and R-Squared: "<< res; + DLIB_TEST(res(0) < 0.15); + DLIB_TEST(res(1) > 0.98); + DLIB_TEST(df2.basis_vectors.size()==1); + DLIB_TEST(max(sparse_to_dense(df2.basis_vectors(0))) >= 0); + } + + +// ---------------------------------------------------------------------------------------- + + class tester_svr_linear_trainer : public tester + { + public: + tester_svr_linear_trainer ( + ) : + tester ("test_svr_linear_trainer", + "Runs tests on the svr_linear_trainer.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "TEST double"; + test1(); + dlog << LINFO << "TEST float"; + test1(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/symmetric_matrix_cache.cpp b/ml/dlib/dlib/test/symmetric_matrix_cache.cpp new file mode 100644 index 000000000..6d93a4daa --- /dev/null +++ b/ml/dlib/dlib/test/symmetric_matrix_cache.cpp @@ -0,0 +1,212 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include +#include +#include +#include + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.symmetric_matrix_cache"); + + + class test_symmetric_matrix_cache : public tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a unit test. When it is constructed + it adds itself into the testing framework. + !*/ + public: + test_symmetric_matrix_cache ( + ) : + tester ( + "test_symmetric_matrix_cache", // the command line argument name for this test + "Run tests on the symmetric_matrix_cache function.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + dlib::rand rnd; + + // ----------------------------------- + + template + void test_colm_exp ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + for (long i = 0; i < m1.nc(); ++i) + { + + typename colm_exp::type c1 = colm(m1,i); + typename colm_exp::type c2 = colm(m2,i); + + DLIB_TEST(equal(c1 , c2)); + DLIB_TEST(equal(colm(m1,i) , c2)); + DLIB_TEST(equal(c1 , colm(m2,i))); + DLIB_TEST(equal(colm(m1,i) , colm(m2,i))); + } + + + // Get a bunch of columns at once to test out the reference + // counting and automatic cache expansion built into the symmetric_matrix_cache. + // This test verifies that, for example, getting column 3 doesn't stomp on + // any of the previous columns. + typename colm_exp::type c1_0 = colm(m1,0); + typename colm_exp::type c1_1 = colm(m1,1); + typename colm_exp::type c1_2 = colm(m1,2); + typename colm_exp::type c1_3 = colm(m1,3); + typename colm_exp::type c1_4 = colm(m1,4); + typename colm_exp::type c1_5 = colm(m1,5); + + typename colm_exp::type c2_0 = colm(m2,0); + typename colm_exp::type c2_1 = colm(m2,1); + typename colm_exp::type c2_2 = colm(m2,2); + typename colm_exp::type c2_3 = colm(m2,3); + typename colm_exp::type c2_4 = colm(m2,4); + typename colm_exp::type c2_5 = colm(m2,5); + + DLIB_TEST(equal(c1_0, c2_0)); + DLIB_TEST(equal(c1_1, c2_1)); + DLIB_TEST(equal(c1_2, c2_2)); + DLIB_TEST(equal(c1_3, c2_3)); + DLIB_TEST(equal(c1_4, c2_4)); + DLIB_TEST(equal(c1_5, c2_5)); + } + + // ----------------------------------- + + template + void test_rowm_exp ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + for (long i = 0; i < m1.nc(); ++i) + { + + typename rowm_exp::type r1 = rowm(m1,i); + typename rowm_exp::type r2 = rowm(m2,i); + + DLIB_TEST(equal(r1 , r2)); + DLIB_TEST(equal(rowm(m1,i) , r2)); + DLIB_TEST(equal(r1 , rowm(m2,i))); + DLIB_TEST(equal(rowm(m1,i) , rowm(m2,i))); + } + + + // Get a bunch of rows at once to test out the reference + // counting and automatic cache expansion built into the symmetric_matrix_cache. + // This test verifies that, for example, getting row 3 doesn't stomp on + // any of the previous rows. + typename rowm_exp::type r1_0 = rowm(m1,0); + typename rowm_exp::type r1_1 = rowm(m1,1); + typename rowm_exp::type r1_2 = rowm(m1,2); + typename rowm_exp::type r1_3 = rowm(m1,3); + typename rowm_exp::type r1_4 = rowm(m1,4); + typename rowm_exp::type r1_5 = rowm(m1,5); + + typename rowm_exp::type r2_0 = rowm(m2,0); + typename rowm_exp::type r2_1 = rowm(m2,1); + typename rowm_exp::type r2_2 = rowm(m2,2); + typename rowm_exp::type r2_3 = rowm(m2,3); + typename rowm_exp::type r2_4 = rowm(m2,4); + typename rowm_exp::type r2_5 = rowm(m2,5); + + DLIB_TEST(equal(r1_0, r2_0)); + DLIB_TEST(equal(r1_1, r2_1)); + DLIB_TEST(equal(r1_2, r2_2)); + DLIB_TEST(equal(r1_3, r2_3)); + DLIB_TEST(equal(r1_4, r2_4)); + DLIB_TEST(equal(r1_5, r2_5)); + } + + // ----------------------------------- + + template + void test_diag_exp ( + const matrix_exp& m1, + const matrix_exp& m2 + ) + { + + typename diag_exp::type c1 = diag(m1); + typename diag_exp::type c2 = diag(m2); + + DLIB_TEST(equal(c1 , c2)); + DLIB_TEST(equal(diag(m1) , c2)); + DLIB_TEST(equal(c1 , diag(m2))); + DLIB_TEST(equal(diag(m1) , diag(m2))); + } + + // ----------------------------------- + + void test_stuff ( + long csize + ) + { + print_spinner(); + dlog << LINFO << "csize: "<< csize; + matrix m = randm(10,10,rnd); + + m = make_symmetric(m); + + DLIB_TEST(equal(symmetric_matrix_cache(m, csize), matrix_cast(m))); + DLIB_TEST(equal(symmetric_matrix_cache(m, csize), matrix_cast(m))); + + dlog << LINFO << "test colm/rowm"; + + + for (long i = 0; i < m.nr(); ++i) + { + DLIB_TEST(equal(colm(symmetric_matrix_cache(m, csize),i), colm(matrix_cast(m),i))); + DLIB_TEST(equal(rowm(symmetric_matrix_cache(m, csize),i), rowm(matrix_cast(m),i))); + // things are supposed to be symmetric + DLIB_TEST(equal(colm(symmetric_matrix_cache(m, csize),i), trans(rowm(matrix_cast(m),i)))); + DLIB_TEST(equal(rowm(symmetric_matrix_cache(m, csize),i), trans(colm(matrix_cast(m),i)))); + } + + dlog << LINFO << "test diag"; + DLIB_TEST(equal(diag(symmetric_matrix_cache(m,csize)), diag(matrix_cast(m)))); + + test_colm_exp(symmetric_matrix_cache(m,csize), matrix_cast(m)); + test_rowm_exp(symmetric_matrix_cache(m,csize), matrix_cast(m)); + test_diag_exp(symmetric_matrix_cache(m,csize), matrix_cast(m)); + + test_colm_exp(tmp(symmetric_matrix_cache(m,csize)), tmp(matrix_cast(m))); + test_rowm_exp(symmetric_matrix_cache(m,csize), tmp(matrix_cast(m))); + test_diag_exp(tmp(symmetric_matrix_cache(m,csize)), tmp(matrix_cast(m))); + } + + + void perform_test ( + ) + { + + for (int itr = 0; itr < 5; ++itr) + { + test_stuff(0); + test_stuff(1); + test_stuff(2); + } + + } + }; + + // Create an instance of this object. Doing this causes this test + // to be automatically inserted into the testing framework whenever this cpp file + // is linked into the project. Note that since we are inside an unnamed-namespace + // we won't get any linker errors about the symbol a being defined multiple times. + test_symmetric_matrix_cache a; + +} + + diff --git a/ml/dlib/dlib/test/tester.cpp b/ml/dlib/dlib/test/tester.cpp new file mode 100644 index 000000000..2fb4d41ac --- /dev/null +++ b/ml/dlib/dlib/test/tester.cpp @@ -0,0 +1,175 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include "tester.h" +#include +#include + +namespace test +{ +// ----------------------------------------------------------------------------- + + bool be_verbose = true; + +// ----------------------------------------------------------------------------- + + static dlib::mutex spinner_mutex; + static dlib::mutex test_count_mutex; + dlib::uint64 test_count = 0; + +// ----------------------------------------------------------------------------- + + dlib::uint64 number_of_testing_statements_executed ( + ) + { + dlib::auto_mutex lock(test_count_mutex); + return test_count; + } + + void increment_test_count ( + ) + { + test_count_mutex.lock(); + ++test_count; + test_count_mutex.unlock(); + } + +// ----------------------------------------------------------------------------- + + void check_test ( + bool _exp, + long line, + const char* file, + const char* _exp_str + ) + { + test_count_mutex.lock(); + ++test_count; + test_count_mutex.unlock(); + if ( !(_exp) ) + { + std::ostringstream dlib_o_out; + dlib_o_out << "\n\nError occurred at line " << line << ".\n"; + dlib_o_out << "Error occurred in file " << file << ".\n"; + dlib_o_out << "Failing expression was " << _exp_str << ".\n"; + throw dlib::error(dlib_o_out.str()); + } + } + +// ----------------------------------------------------------------------------- + + map_of_testers& testers ( + ) + { + static map_of_testers t; + return t; + } + +// ----------------------------------------------------------------------------- + + tester:: + tester ( + const std::string& switch_name_x, + const std::string& description_x, + unsigned long num_of_args_x + ) : + switch_name(switch_name_x), + description_(description_x), + num_of_args_(num_of_args_x) + { + using namespace std; + if (testers().is_in_domain(switch_name)) + { + cerr << "ERROR: More than one tester has been defined with the switch '" << switch_name << "'." << endl; + exit(1); + } + + string temp(switch_name); + tester* t = this; + testers().add(temp,t); + } + +// ----------------------------------------------------------------------------- + + const std::string& tester:: + cmd_line_switch ( + ) const + { + return switch_name; + } + +// ----------------------------------------------------------------------------- + + const std::string& tester:: + description ( + ) const + { + return description_; + } + +// ----------------------------------------------------------------------------- + + unsigned long tester:: + num_of_args ( + ) const + { + return num_of_args_; + } + +// ----------------------------------------------------------------------------- + + void tester:: + perform_test ( + ) + { + } + +// ----------------------------------------------------------------------------- + + void tester:: + perform_test ( + const std::string& + ) + { + } + +// ----------------------------------------------------------------------------- + + void tester:: + perform_test ( + const std::string&, + const std::string& + ) + { + } + +// ----------------------------------------------------------------------------- + + void print_spinner ( + ) + { + if (be_verbose) + { + using namespace std; + dlib::auto_mutex M(spinner_mutex); + static int i = 0; + cout << "\b\b"; + switch (i) + { + case 0: cout << '|'; break; + case 1: cout << '/'; break; + case 2: cout << '-'; break; + case 3: cout << '\\'; break; + } + cout << " " << flush; + i = (i+1)%4; + } + } + +// ----------------------------------------------------------------------------- + +} + + + diff --git a/ml/dlib/dlib/test/tester.h b/ml/dlib/dlib/test/tester.h new file mode 100644 index 000000000..e16647cf5 --- /dev/null +++ b/ml/dlib/dlib/test/tester.h @@ -0,0 +1,187 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TESTEr_ +#define DLIB_TESTEr_ + +#include +#include +#include +#include +#include +#include +#include + +#ifdef __INTEL_COMPILER +// ignore the bogus warning about not overloading perform_test() all the way +#pragma warning (disable: 654) +#endif + + +#define DLIB_TEST(_exp) check_test(bool(_exp), __LINE__, __FILE__, #_exp) + +#define DLIB_TEST_MSG(_exp,_message) \ + do{increment_test_count(); if ( !(_exp) ) \ + { \ + std::ostringstream dlib_o_out; \ + dlib_o_out << "\n\nError occurred at line " << __LINE__ << ".\n"; \ + dlib_o_out << "Error occurred in file " << __FILE__ << ".\n"; \ + dlib_o_out << "Failing expression was " << #_exp << ".\n"; \ + dlib_o_out << _message << "\n"; \ + throw dlib::error(dlib_o_out.str()); \ + }}while(0) + +namespace test +{ + class tester; + typedef dlib::map::kernel_1a_c map_of_testers; + + map_of_testers& testers ( + ); + +// ----------------------------------------------------------------------------- + + void check_test ( + bool _exp, + long line, + const char* file, + const char* _exp_str + ); + +// ----------------------------------------------------------------------------- + +// This bool controls any cout statements in this program. Only print to +// standard out if we should be verbose. The default is true + extern bool be_verbose; + +// ----------------------------------------------------------------------------- + + dlib::uint64 number_of_testing_statements_executed ( + ); + /*! + ensures + - returns the total number of DLIB_TEST and DLIB_TEST_MSG + statements executed since program startup. + !*/ + + void increment_test_count ( + ); + /*! + ensures + - increments number_of_testing_statements_executed() + !*/ + +// ----------------------------------------------------------------------------- + + void print_spinner ( + ); + /*! + ensures + - reprints the spinner + !*/ + +// ----------------------------------------------------------------------------- + + class tester + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a generic regression test. + !*/ + + public: + + tester ( + const std::string& switch_name, + const std::string& description_, + unsigned long num_of_args_ = 0 + ); + /*! + requires + - testers().is_in_domain(switch_name) == false + ensures + - #cmd_line_switch() == switch_name + - #description() == description_ + - #num_of_args() == num_of_args_ + - adds this tester to the testers() map. + !*/ + + virtual ~tester ( + ){} + + const std::string& cmd_line_switch ( + ) const; + /*! + ensures + - returns the name of the command line switch for this tester. + !*/ + + const std::string& description ( + ) const; + /*! + ensures + - returns the description of what this tester tests. + !*/ + + unsigned long num_of_args ( + ) const; + /*! + ensures + - returns the number of arguments this test expects + !*/ + + virtual void perform_test ( + ); + /*! + requires + - is invoked when number_of_args() == 0 + ensures + - performs the test and throws an exception + derived from std::exception if the test fails. + !*/ + + virtual void perform_test ( + const std::string& arg + ); + /*! + requires + - is invoked when number_of_args() == 1 + ensures + - performs the test and throws an exception + derived from std::exception if the test fails. + !*/ + + virtual void perform_test ( + const std::string& arg1, + const std::string& arg2 + ); + /*! + requires + - is invoked when number_of_args() == 2 + ensures + - performs the test and throws an exception + derived from std::exception if the test fails. + !*/ + + private: + + // --------------------------------------------------------------------------- + // Implementation Details + // --------------------------------------------------------------------------- + + /*! + CONVENTION + - switch_name == cmd_line_switch() + - description_ == description() + - num_of_args_ == num_of_args() + - test::tester[switch_name] == this + !*/ + + const std::string switch_name; + const std::string description_; + const unsigned long num_of_args_; + }; + +} + +#endif // DLIB_TESTEr_ + diff --git a/ml/dlib/dlib/test/thread_pool.cpp b/ml/dlib/dlib/test/thread_pool.cpp new file mode 100644 index 000000000..73ccb346e --- /dev/null +++ b/ml/dlib/dlib/test/thread_pool.cpp @@ -0,0 +1,428 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.thread_pool"); + + + struct some_struct : noncopyable + { + float val; + }; + + int global_var = 0; + + struct add_functor + { + add_functor() { var = 1;} + add_functor(int v):var(v) {} + + template + void operator()(T a, U b, V& res) + { + dlib::sleep(20); + res = a + b; + } + + void set_global_var() { global_var = 9; } + void set_global_var_const() const { global_var = 9; } + + void set_global_var_arg1(int val) { global_var = val; } + void set_global_var_const_arg1(int val) const { global_var = val; } + void set_global_var_arg2(int val, int val2) { global_var = val+val2; } + void set_global_var_const_arg2(int val, int val2) const { global_var = val+val2; } + + void operator()() + { + global_var = 9; + } + + // use an any just so that if this object goes out of scope + // then var will get all messed up. + any var; + void operator()(int& a) { dlib::sleep(100); a = var.get(); } + void operator()(int& a, int& b) { dlib::sleep(100); a = var.get(); b = 2; } + void operator()(int& a, int& b, int& c) { dlib::sleep(100); a = var.get(); b = 2; c = 3; } + void operator()(int& a, int& b, int& c, int& d) { dlib::sleep(100); a = var.get(); b = 2; c = 3; d = 4; } + }; + + + void set_global_var() { global_var = 9; } + + void gset_struct_to_zero (some_struct& a) { a.val = 0; } + void gset_to_zero (int& a) { a = 0; } + void gincrement (int& a) { ++a; } + void gadd (int a, const int& b, int& res) { dlib::sleep(20); res = a + b; } + void gadd1(int& a, int& res) { res += a; } + void gadd2 (int c, int a, const int& b, int& res) { dlib::sleep(20); res = a + b + c; } + + class thread_pool_tester : public tester + { + public: + thread_pool_tester ( + ) : + tester ("test_thread_pool", + "Runs tests on the thread_pool component.") + {} + + void perform_test ( + ) + { + add_functor f; + for (int num_threads= 0; num_threads < 4; ++num_threads) + { + dlib::future a, b, c, res, d; + thread_pool tp(num_threads); + print_spinner(); + + dlib::future obj; + + + for (int i = 0; i < 4; ++i) + { + a = 1; + b = 2; + c = 3; + res = 4; + + + DLIB_TEST(a==a); + DLIB_TEST(a!=b); + DLIB_TEST(a==1); + + tp.add_task(gset_to_zero, a); + tp.add_task(gset_to_zero, b); + tp.add_task(*this, &thread_pool_tester::set_to_zero, c); + tp.add_task(gset_to_zero, res); + DLIB_TEST(a == 0); + DLIB_TEST(b == 0); + DLIB_TEST(c == 0); + DLIB_TEST(res == 0); + + + tp.add_task(gincrement, a); + tp.add_task(*this, &thread_pool_tester::increment, b); + tp.add_task(*this, &thread_pool_tester::increment, c); + tp.add_task(gincrement, res); + + DLIB_TEST(a == 1); + DLIB_TEST(b == 1); + DLIB_TEST(c == 1); + DLIB_TEST(res == 1); + + tp.add_task(&gincrement, a); + tp.add_task(*this, &thread_pool_tester::increment, b); + tp.add_task(*this, &thread_pool_tester::increment, c); + tp.add_task(&gincrement, res); + tp.add_task(gincrement, a); + tp.add_task(*this, &thread_pool_tester::increment, b); + tp.add_task(*this, &thread_pool_tester::increment, c); + tp.add_task(gincrement, res); + + DLIB_TEST(a == 3); + DLIB_TEST(b == 3); + DLIB_TEST(c == 3); + DLIB_TEST(res == 3); + + tp.add_task(*this, &thread_pool_tester::increment, c); + tp.add_task(gincrement, res); + DLIB_TEST(c == 4); + DLIB_TEST(res == 4); + + + tp.add_task(gadd, a, b, res); + DLIB_TEST(res == a+b); + DLIB_TEST(res == 6); + a = 3; + b = 4; + res = 99; + DLIB_TEST(res == 99); + tp.add_task(*this, &thread_pool_tester::add, a, b, res); + DLIB_TEST(res == a+b); + DLIB_TEST(res == 7); + + a = 1; + b = 2; + c = 3; + res = 88; + DLIB_TEST(res == 88); + DLIB_TEST(a == 1); + DLIB_TEST(b == 2); + DLIB_TEST(c == 3); + + tp.add_task(gadd2, a, b, c, res); + DLIB_TEST(res == 6); + DLIB_TEST(a == 1); + DLIB_TEST(b == 2); + DLIB_TEST(c == 3); + + a = 1; + b = 2; + c = 3; + res = 88; + DLIB_TEST(res == 88); + DLIB_TEST(a == 1); + DLIB_TEST(b == 2); + DLIB_TEST(c == 3); + tp.add_task(*this, &thread_pool_tester::add2, a, b, c, res); + DLIB_TEST(res == 6); + DLIB_TEST(a == 1); + DLIB_TEST(b == 2); + DLIB_TEST(c == 3); + + a = 1; + b = 2; + c = 3; + res = 88; + tp.add_task(gadd1, a, b); + DLIB_TEST(a == 1); + DLIB_TEST(b == 3); + a = 2; + tp.add_task(*this, &thread_pool_tester::add1, a, b); + DLIB_TEST(a == 2); + DLIB_TEST(b == 5); + + + val = 4; + uint64 id = tp.add_task(*this, &thread_pool_tester::zero_val); + tp.wait_for_task(id); + DLIB_TEST(val == 0); + id = tp.add_task(*this, &thread_pool_tester::accum2, 1,2); + tp.wait_for_all_tasks(); + DLIB_TEST(val == 3); + id = tp.add_task(*this, &thread_pool_tester::accum1, 3); + tp.wait_for_task(id); + DLIB_TEST(val == 6); + + + obj.get().val = 8; + DLIB_TEST(obj.get().val == 8); + tp.add_task(gset_struct_to_zero, obj); + DLIB_TEST(obj.get().val == 0); + obj.get().val = 8; + DLIB_TEST(obj.get().val == 8); + tp.add_task(*this,&thread_pool_tester::set_struct_to_zero, obj); + DLIB_TEST(obj.get().val == 0); + + a = 1; + b = 2; + res = 0; + tp.add_task(f, a, b, res); + DLIB_TEST(a == 1); + DLIB_TEST(b == 2); + DLIB_TEST(res == 3); + + + global_var = 0; + DLIB_TEST(global_var == 0); + id = tp.add_task(&set_global_var); + tp.wait_for_task(id); + DLIB_TEST(global_var == 9); + + global_var = 0; + DLIB_TEST(global_var == 0); + id = tp.add_task(f); + tp.wait_for_task(id); + DLIB_TEST(global_var == 9); + + global_var = 0; + DLIB_TEST(global_var == 0); + id = tp.add_task(f, &add_functor::set_global_var); + tp.wait_for_task(id); + DLIB_TEST(global_var == 9); + + global_var = 0; + a = 4; + DLIB_TEST(global_var == 0); + id = tp.add_task(f, &add_functor::set_global_var_arg1, a); + tp.wait_for_task(id); + DLIB_TEST(global_var == 4); + + global_var = 0; + a = 4; + DLIB_TEST(global_var == 0); + id = tp.add_task_by_value(f, &add_functor::set_global_var_arg1, a); + tp.wait_for_task(id); + DLIB_TEST(global_var == 4); + + + + global_var = 0; + a = 4; + b = 3; + DLIB_TEST(global_var == 0); + id = tp.add_task(f, &add_functor::set_global_var_arg2, a, b); + tp.wait_for_task(id); + DLIB_TEST(global_var == 7); + + global_var = 0; + a = 4; + b = 3; + DLIB_TEST(global_var == 0); + id = tp.add_task_by_value(f, &add_functor::set_global_var_arg2, a, b); + tp.wait_for_task(id); + DLIB_TEST(global_var == 7); + + global_var = 0; + a = 4; + b = 3; + DLIB_TEST(global_var == 0); + id = tp.add_task(f, &add_functor::set_global_var_const_arg2, a, b); + tp.wait_for_task(id); + DLIB_TEST(global_var == 7); + + global_var = 0; + a = 4; + b = 3; + DLIB_TEST(global_var == 0); + id = tp.add_task_by_value(f, &add_functor::set_global_var_const_arg2, a, b); + tp.wait_for_task(id); + DLIB_TEST(global_var == 7); + + + + + + + global_var = 0; + a = 4; + DLIB_TEST(global_var == 0); + id = tp.add_task(f, &add_functor::set_global_var_const_arg1, a); + tp.wait_for_task(id); + DLIB_TEST(global_var == 4); + + global_var = 0; + a = 4; + DLIB_TEST(global_var == 0); + id = tp.add_task_by_value(f, &add_functor::set_global_var_const_arg1, a); + tp.wait_for_task(id); + DLIB_TEST(global_var == 4); + + global_var = 0; + DLIB_TEST(global_var == 0); + id = tp.add_task_by_value(f, &add_functor::set_global_var); + tp.wait_for_task(id); + DLIB_TEST(global_var == 9); + + + global_var = 0; + DLIB_TEST(global_var == 0); + id = tp.add_task(f, &add_functor::set_global_var_const); + tp.wait_for_task(id); + DLIB_TEST(global_var == 9); + + + global_var = 0; + DLIB_TEST(global_var == 0); + id = tp.add_task_by_value(f, &add_functor::set_global_var_const); + tp.wait_for_task(id); + DLIB_TEST(global_var == 9); + + + + } + + // add this task just to to perterb the thread pool before it goes out of scope + tp.add_task(f, a, b, res); + + for (int k = 0; k < 3; ++k) + { + print_spinner(); + global_var = 0; + tp.add_task_by_value(add_functor()); + tp.wait_for_all_tasks(); + DLIB_TEST(global_var == 9); + + a = 0; b = 0; c = 0; d = 0; + tp.add_task_by_value(add_functor(), a); + DLIB_TEST(a == 1); + a = 0; b = 0; c = 0; d = 0; + tp.add_task_by_value(add_functor(8), a, b); + DLIB_TEST(a == 8); + DLIB_TEST(b == 2); + a = 0; b = 0; c = 0; d = 0; + tp.add_task_by_value(add_functor(), a, b, c); + DLIB_TEST(a == 1); + DLIB_TEST(b == 2); + DLIB_TEST(c == 3); + a = 0; b = 0; c = 0; d = 0; + tp.add_task_by_value(add_functor(5), a, b, c, d); + DLIB_TEST(a == 5); + DLIB_TEST(b == 2); + DLIB_TEST(c == 3); + DLIB_TEST(d == 4); + } + + + tp.wait_for_all_tasks(); + + // make sure exception propagation from tasks works correctly. + auto f_throws = []() { throw dlib::error("test exception");}; + bool got_exception = false; + try + { + tp.add_task_by_value(f_throws); + tp.wait_for_all_tasks(); + } + catch(dlib::error& e) + { + DLIB_TEST(e.info == "test exception"); + got_exception = true; + } + DLIB_TEST(got_exception); + + dlib::future aa; + auto f_throws2 = [](int& a) { a = 1; throw dlib::error("test exception");}; + got_exception = false; + try + { + tp.add_task(f_throws2, aa); + aa.get(); + } + catch(dlib::error& e) + { + DLIB_TEST(e.info == "test exception"); + got_exception = true; + } + DLIB_TEST(got_exception); + + } + } + + long val; + void accum1(long a) { val += a; } + void accum2(long a, long b) { val += a + b; } + void zero_val() { dlib::sleep(20); val = 0; } + + + void set_struct_to_zero (some_struct& a) { a.val = 0; } + void set_to_zero (int& a) { dlib::sleep(20); a = 0; } + void increment (int& a) const { dlib::sleep(20); ++a; } + void add (int a, const int& b, int& res) { dlib::sleep(20); res = a + b; } + void add1(int& a, int& res) const { res += a; } + void add2 (int c, int a, const int& b, int& res) { res = a + b + c; } + + + } a; + + +} + + + diff --git a/ml/dlib/dlib/test/threads.cpp b/ml/dlib/dlib/test/threads.cpp new file mode 100644 index 000000000..1aeb1a3f9 --- /dev/null +++ b/ml/dlib/dlib/test/threads.cpp @@ -0,0 +1,158 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.threads"); + + void test_async() + { +#if __cplusplus >= 201103 + print_spinner(); + auto v1 = dlib::async([]() { dlib::sleep(500); return 1; }).share(); + auto v2 = dlib::async([v1]() { dlib::sleep(400); return v1.get()+1; }).share(); + auto v3 = dlib::async([v2](int a) { dlib::sleep(300); return v2.get()+a; },2).share(); + auto v4 = dlib::async([v3]() { dlib::sleep(200); return v3.get()+1; }); + + DLIB_TEST(v4.get() == 5); + + print_spinner(); + auto except = dlib::async([](){ dlib::sleep(300); throw error("oops"); }); + bool got_exception = false; + try + { + except.get(); + } + catch (error&e) + { + got_exception = true; + DLIB_TEST(e.what() == string("oops")); + } + DLIB_TEST(got_exception); +#endif + } + + class threads_tester : public tester + { + public: + threads_tester ( + ) : + tester ("test_threads", + "Runs tests on the threads component."), + sm(cm) + {} + + thread_specific_data tsd; + rmutex cm; + rsignaler sm; + int count; + bool failure; + + void perform_test ( + ) + { + failure = false; + print_spinner(); + + + count = 10; + if (!create_new_thread(*this)) failure = true; + if (!create_new_thread(*this)) failure = true; + if (!create_new_thread(*this)) failure = true; + if (!create_new_thread(*this)) failure = true; + if (!create_new_thread(*this)) failure = true; + if (!create_new_thread(*this)) failure = true; + if (!create_new_thread(*this)) failure = true; + if (!create_new_thread(*this)) failure = true; + if (!create_new_thread(*this)) failure = true; + if (!create_new_thread(*this)) failure = true; + + thread(66); + + // this should happen in the main program thread + if (is_dlib_thread()) + failure = true; + + auto_mutex M(cm); + while (count > 0 && !failure) + sm.wait(); + + + DLIB_TEST(!failure); + + test_async(); + } + + void thread_end_handler ( + ) + { + auto_mutex M(cm); + --count; + if (count == 0) + sm.signal(); + } + + void thread1() { thread(1); } + void thread2() + { + thread(2); + if (is_dlib_thread() == false) + failure = true; + } + void thread3() { thread(3); } + void thread4() { thread(4); } + void thread5() { thread(5); } + void thread6() { thread(6); } + void thread7() { thread(7); } + void thread8() { thread(8); } + void thread9() { thread(9); } + void thread10() { thread(10); } + + void thread ( + int num + ) + { + dlog << LTRACE << "starting thread num " << num; + if (is_dlib_thread()) + register_thread_end_handler(*this,&threads_tester::thread_end_handler); + tsd.data() = num; + for (int i = 0; i < 0x3FFFF; ++i) + { + if ((i&0xFFF) == 0) + { + print_spinner(); + dlib::sleep(10); + } + // if this isn't equal to num then there is a problem with the thread specific data stuff + if (tsd.data() != num) + { + auto_mutex M(cm); + failure = true; + sm.signal(); + } + } + dlog << LTRACE << "ending of thread num " << num; + + + } + } a; + + +} + + + diff --git a/ml/dlib/dlib/test/timer.cpp b/ml/dlib/dlib/test/timer.cpp new file mode 100644 index 000000000..ae004a55d --- /dev/null +++ b/ml/dlib/dlib/test/timer.cpp @@ -0,0 +1,347 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include + +#include +#include +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.timer"); + + class timer_test_helper + { + public: + dlib::mutex m; + int count; + dlib::uint64 timestamp; + dlib::timestamper ts; + + timer_test_helper():count(0), timestamp(0){} + void add() + { + m.lock(); + ++count; + m.unlock(); + } + + void delayed_add() + { + dlib::sleep(1000); + print_spinner(); + add(); + } + + void set_timestamp() + { + m.lock(); + timestamp = ts.get_timestamp(); + dlog << LTRACE << "in set_timestamp(), time is " << timestamp; + dlib::sleep(1); + print_spinner(); + m.unlock(); + } + }; + + template < + typename timer_t + > + void timer_test2 ( + ) + /*! + requires + - timer_t is an implementation of dlib/timer/timer_abstract.h is instantiated + timer_test_helper + ensures + - runs tests on timer_t for compliance with the specs + !*/ + { + for (int j = 0; j < 4; ++j) + { + print_spinner(); + timer_test_helper h; + + timer_t t1(h,&timer_test_helper::set_timestamp); + t1.set_delay_time(0); + dlog << LTRACE << "t1.start()"; + t1.start(); + + dlib::sleep(60); + print_spinner(); + t1.stop_and_wait(); + + dlib::uint64 cur_time = h.ts.get_timestamp(); + dlog << LTRACE << "get current time: " << cur_time; + + // make sure the action function has been called recently + DLIB_TEST_MSG((cur_time-h.timestamp)/1000 < 30, (cur_time-h.timestamp)/1000); + + } + } + + template < + typename timer_t + > + void timer_test ( + ) + /*! + requires + - timer_t is an implementation of dlib/timer/timer_abstract.h is instantiated + timer_test_helper + ensures + - runs tests on timer_t for compliance with the specs + !*/ + { + + print_spinner(); + for (int j = 0; j < 3; ++j) + { + timer_test_helper h; + + timer_t t1(h,&timer_test_helper::add); + timer_t t2(h,&timer_test_helper::add); + timer_t t3(h,&timer_test_helper::add); + + DLIB_TEST(t1.delay_time() == 1000); + DLIB_TEST(t2.delay_time() == 1000); + DLIB_TEST(t3.delay_time() == 1000); + DLIB_TEST(t1.is_running() == false); + DLIB_TEST(t2.is_running() == false); + DLIB_TEST(t3.is_running() == false); + DLIB_TEST(t1.action_function() == &timer_test_helper::add); + DLIB_TEST(t2.action_function() == &timer_test_helper::add); + DLIB_TEST(t3.action_function() == &timer_test_helper::add); + DLIB_TEST(&t1.action_object() == &h); + DLIB_TEST(&t2.action_object() == &h); + DLIB_TEST(&t3.action_object() == &h); + + t1.set_delay_time(1000); + t2.set_delay_time(500); + t3.set_delay_time(1500); + + DLIB_TEST(t1.delay_time() == 1000); + DLIB_TEST(t2.delay_time() == 500); + DLIB_TEST(t3.delay_time() == 1500); + DLIB_TEST(t1.is_running() == false); + DLIB_TEST(t2.is_running() == false); + DLIB_TEST(t3.is_running() == false); + DLIB_TEST(t1.action_function() == &timer_test_helper::add); + DLIB_TEST(t2.action_function() == &timer_test_helper::add); + DLIB_TEST(t3.action_function() == &timer_test_helper::add); + DLIB_TEST(&t1.action_object() == &h); + DLIB_TEST(&t2.action_object() == &h); + DLIB_TEST(&t3.action_object() == &h); + dlib::sleep(1100); + print_spinner(); + DLIB_TEST(h.count == 0); + + t1.stop_and_wait(); + t2.stop_and_wait(); + t3.stop_and_wait(); + + dlib::sleep(1100); + print_spinner(); + DLIB_TEST(h.count == 0); + DLIB_TEST(t1.delay_time() == 1000); + DLIB_TEST(t2.delay_time() == 500); + DLIB_TEST(t3.delay_time() == 1500); + DLIB_TEST(t1.is_running() == false); + DLIB_TEST(t2.is_running() == false); + DLIB_TEST(t3.is_running() == false); + DLIB_TEST(t1.action_function() == &timer_test_helper::add); + DLIB_TEST(t2.action_function() == &timer_test_helper::add); + DLIB_TEST(t3.action_function() == &timer_test_helper::add); + DLIB_TEST(&t1.action_object() == &h); + DLIB_TEST(&t2.action_object() == &h); + DLIB_TEST(&t3.action_object() == &h); + + t1.start(); + t2.start(); + t3.start(); + + DLIB_TEST(t1.delay_time() == 1000); + DLIB_TEST(t2.delay_time() == 500); + DLIB_TEST(t3.delay_time() == 1500); + DLIB_TEST(t1.is_running() == true); + DLIB_TEST(t2.is_running() == true); + DLIB_TEST(t3.is_running() == true); + DLIB_TEST(t1.action_function() == &timer_test_helper::add); + DLIB_TEST(t2.action_function() == &timer_test_helper::add); + DLIB_TEST(t3.action_function() == &timer_test_helper::add); + DLIB_TEST(&t1.action_object() == &h); + DLIB_TEST(&t2.action_object() == &h); + DLIB_TEST(&t3.action_object() == &h); + + t1.stop(); + t2.stop(); + t3.stop(); + + DLIB_TEST(t1.delay_time() == 1000); + DLIB_TEST(t2.delay_time() == 500); + DLIB_TEST(t3.delay_time() == 1500); + DLIB_TEST(t1.is_running() == false); + DLIB_TEST(t2.is_running() == false); + DLIB_TEST(t3.is_running() == false); + DLIB_TEST(t1.action_function() == &timer_test_helper::add); + DLIB_TEST(t2.action_function() == &timer_test_helper::add); + DLIB_TEST(t3.action_function() == &timer_test_helper::add); + DLIB_TEST(&t1.action_object() == &h); + DLIB_TEST(&t2.action_object() == &h); + DLIB_TEST(&t3.action_object() == &h); + + DLIB_TEST(h.count == 0); + dlib::sleep(1100); + print_spinner(); + DLIB_TEST(h.count == 0); + + for (int i = 1; i <= 3; ++i) + { + t1.start(); + t2.start(); + t3.start(); + + DLIB_TEST(t1.is_running() == true); + DLIB_TEST(t2.is_running() == true); + DLIB_TEST(t3.is_running() == true); + + dlib::sleep(1800); + print_spinner(); + // this should allow the timers to trigger 5 times + t1.stop(); + t2.stop(); + t3.stop(); + + DLIB_TEST_MSG(h.count == 5*i,"h.count: " << h.count << " i: " << i); + dlib::sleep(1100); + DLIB_TEST_MSG(h.count == 5*i,"h.count: " << h.count << " i: " << i); + } + + + t1.stop_and_wait(); + + h.count = 0; + t1.start(); + dlib::sleep(300); + print_spinner(); + DLIB_TEST_MSG(h.count == 0,h.count); + t1.set_delay_time(400); + dlib::sleep(200); + print_spinner(); + DLIB_TEST_MSG(h.count == 1,h.count); + dlib::sleep(250); + print_spinner(); + DLIB_TEST_MSG(h.count == 1,h.count); + dlib::sleep(100); + print_spinner(); + DLIB_TEST_MSG(h.count == 2,h.count); + t1.set_delay_time(2000); + DLIB_TEST_MSG(h.count == 2,h.count); + dlib::sleep(1000); + print_spinner(); + DLIB_TEST_MSG(h.count == 2,h.count); + t1.clear(); + + h.count = 0; + t3.start(); + DLIB_TEST(t3.is_running() == true); + DLIB_TEST(t3.delay_time() == 1500); + DLIB_TEST_MSG(h.count == 0,h.count); + t3.clear(); + DLIB_TEST(t3.is_running() == false); + DLIB_TEST(t3.delay_time() == 1000); + DLIB_TEST_MSG(h.count == 0,h.count); + dlib::sleep(200); + print_spinner(); + DLIB_TEST(t3.is_running() == false); + DLIB_TEST(t3.delay_time() == 1000); + DLIB_TEST_MSG(h.count == 0,h.count); + + + { + h.count = 0; + timer_t t4(h,&timer_test_helper::delayed_add); + t4.set_delay_time(100); + t4.start(); + DLIB_TEST_MSG(h.count == 0,h.count); + dlib::sleep(400); + print_spinner(); + DLIB_TEST_MSG(h.count == 0,h.count); + t4.stop_and_wait(); + DLIB_TEST_MSG(h.count == 1,h.count); + DLIB_TEST(t4.is_running() == false); + } + + { + h.count = 0; + timer_t t4(h,&timer_test_helper::delayed_add); + t4.set_delay_time(100); + t4.start(); + DLIB_TEST_MSG(h.count == 0,h.count); + dlib::sleep(400); + print_spinner(); + DLIB_TEST_MSG(h.count == 0,h.count); + t4.clear(); + DLIB_TEST(t4.is_running() == false); + DLIB_TEST_MSG(h.count == 0,h.count); + t4.stop_and_wait(); + DLIB_TEST_MSG(h.count == 1,h.count); + DLIB_TEST(t4.is_running() == false); + } + + { + h.count = 0; + timer_t t5(h,&timer_test_helper::delayed_add); + t5.set_delay_time(100); + t5.start(); + DLIB_TEST_MSG(h.count == 0,h.count); + dlib::sleep(400); + print_spinner(); + DLIB_TEST_MSG(h.count == 0,h.count); + } + DLIB_TEST_MSG(h.count == 1,h.count); + + } + + } + + + + + class timer_tester : public tester + { + public: + timer_tester ( + ) : + tester ("test_timer", + "Runs tests on the timer component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing timer_heavy with test_timer"; + timer_test > (); + dlog << LINFO << "testing timer_heavy with test_timer2"; + timer_test2 > (); + + dlog << LINFO << "testing timer with test_timer"; + timer_test > (); + dlog << LINFO << "testing timer with test_timer2"; + timer_test2 > (); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/tokenizer.cpp b/ml/dlib/dlib/test/tokenizer.cpp new file mode 100644 index 000000000..95a95a7e1 --- /dev/null +++ b/ml/dlib/dlib/test/tokenizer.cpp @@ -0,0 +1,378 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include + +#include +#include "tester.h" + +namespace +{ + using namespace test; + using namespace std; + using namespace dlib; + + logger dlog("test.tokenizer"); + + template < + typename tok + > + void tokenizer_kernel_test ( + ) + /*! + requires + - tok is an implementation of tokenizer_kernel_abstract.h + ensures + - runs tests on tok for compliance with the specs + !*/ + { + + print_spinner(); + + tok test; + + DLIB_TEST(test.numbers() == "0123456789"); + DLIB_TEST(test.uppercase_letters() == "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + DLIB_TEST(test.lowercase_letters() == "abcdefghijklmnopqrstuvwxyz"); + + DLIB_TEST_MSG(test.get_identifier_body() == "_" + test.lowercase_letters() + + test.uppercase_letters() + test.numbers(),""); + DLIB_TEST_MSG(test.get_identifier_head() == "_" + test.lowercase_letters() + + test.uppercase_letters(),""); + + DLIB_TEST(test.stream_is_set() == false); + test.clear(); + DLIB_TEST(test.stream_is_set() == false); + + DLIB_TEST_MSG(test.get_identifier_body() == "_" + test.lowercase_letters() + + test.uppercase_letters() + test.numbers(),""); + DLIB_TEST_MSG(test.get_identifier_head() == "_" + test.lowercase_letters() + + test.uppercase_letters(),""); + + tok test2; + + ostringstream sout; + istringstream sin; + test2.set_stream(sin); + + DLIB_TEST(test2.stream_is_set()); + DLIB_TEST(&test2.get_stream() == &sin); + + int type; + string token; + + test2.get_token(type,token); + DLIB_TEST(type == tok::END_OF_FILE); + test2.get_token(type,token); + DLIB_TEST(type == tok::END_OF_FILE); + test2.get_token(type,token); + DLIB_TEST(type == tok::END_OF_FILE); + + + sin.clear(); + sin.str(" The cat 123asdf1234 ._ \n test."); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST(token == " "); + + DLIB_TEST(test2.peek_type() == tok::IDENTIFIER); + DLIB_TEST(test2.peek_token() == "The"); + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "The"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST(token == " "); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "cat"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST(token == " "); + + test2.get_token(type,token); + DLIB_TEST(type == tok::NUMBER); + DLIB_TEST_MSG(token == "123","token: " << token); + + DLIB_TEST(test2.peek_type() == tok::IDENTIFIER); + DLIB_TEST(test2.peek_token() == "asdf1234"); + DLIB_TEST(test2.peek_type() == tok::IDENTIFIER); + DLIB_TEST(test2.peek_token() == "asdf1234"); + DLIB_TEST(test2.peek_type() == tok::IDENTIFIER); + DLIB_TEST(test2.peek_token() == "asdf1234"); + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "asdf1234"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST_MSG(token == " ","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::CHAR); + DLIB_TEST_MSG(token == ".","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "_"); + + DLIB_TEST(test2.peek_type() == tok::WHITE_SPACE); + DLIB_TEST_MSG(test2.peek_token() == " ","token: \"" << token << "\"" << + "\ntoken size: " << (unsigned int)token.size()); + + swap(test,test2); + + DLIB_TEST(test2.stream_is_set() == false); + + DLIB_TEST(test.peek_type() == tok::WHITE_SPACE); + DLIB_TEST_MSG(test.peek_token() == " ","token: \"" << token << "\"" << + "\ntoken size: " << (unsigned int)token.size()); + test.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST_MSG(token == " ","token: \"" << token << "\"" << + "\ntoken size: " << (unsigned int)token.size()); + + test.get_token(type,token); + DLIB_TEST_MSG(type == tok::END_OF_LINE,"token: " << token); + DLIB_TEST_MSG(token == "\n","token: " << token); + + swap(test,test2); + DLIB_TEST(test.stream_is_set() == false); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST_MSG(token == " ","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST_MSG(token == "test","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::CHAR); + DLIB_TEST_MSG(token == ".","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::END_OF_FILE); + + + + + + + + + + + test2.set_identifier_token("_" + test.uppercase_letters() + + test.lowercase_letters(),test.numbers() + "_" + test.uppercase_letters() + +test.lowercase_letters()); + + + sin.clear(); + sin.str(" The cat 123asdf1234 ._ \n\r test."); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST(token == " "); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "The"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST(token == " "); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "cat"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST(token == " "); + + test2.get_token(type,token); + DLIB_TEST(type == tok::NUMBER); + DLIB_TEST_MSG(token == "123","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "asdf1234"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST_MSG(token == " ","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::CHAR); + DLIB_TEST_MSG(token == ".","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "_"); + + swap(test,test2); + + DLIB_TEST(test2.stream_is_set() == false); + + test.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST_MSG(token == " ","token: \"" << token << "\"" << + "\ntoken size: " << (unsigned int)token.size()); + + test.get_token(type,token); + DLIB_TEST_MSG(type == tok::END_OF_LINE,"token: " << token); + DLIB_TEST_MSG(token == "\n","token: " << token); + + swap(test,test2); + DLIB_TEST(test.stream_is_set() == false); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST_MSG(token == "\r ","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST_MSG(token == "test","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::CHAR); + DLIB_TEST_MSG(token == ".","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::END_OF_FILE); + + + + + + + + + + + + + + test2.set_identifier_token(test.uppercase_letters() + + test.lowercase_letters(),test.numbers() + test.uppercase_letters() + +test.lowercase_letters()); + + + sin.clear(); + sin.str(" The cat 123as_df1234 ._ \n test."); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST(token == " "); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "The"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST(token == " "); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "cat"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST(token == " "); + + test2.get_token(type,token); + DLIB_TEST(type == tok::NUMBER); + DLIB_TEST_MSG(token == "123","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "as"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::CHAR); + DLIB_TEST_MSG(token == "_","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST(token == "df1234"); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST_MSG(token == " ","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::CHAR); + DLIB_TEST_MSG(token == ".","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::CHAR); + DLIB_TEST(token == "_"); + + swap(test,test2); + + DLIB_TEST(test2.stream_is_set() == false); + + test.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST_MSG(token == " ","token: \"" << token << "\"" << + "\ntoken size: " << (unsigned int)token.size()); + + test.get_token(type,token); + DLIB_TEST_MSG(type == tok::END_OF_LINE,"token: " << token); + DLIB_TEST_MSG(token == "\n","token: " << token); + + swap(test,test2); + DLIB_TEST(test.stream_is_set() == false); + + test2.get_token(type,token); + DLIB_TEST(type == tok::WHITE_SPACE); + DLIB_TEST_MSG(token == " ","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::IDENTIFIER); + DLIB_TEST_MSG(token == "test","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::CHAR); + DLIB_TEST_MSG(token == ".","token: " << token); + + test2.get_token(type,token); + DLIB_TEST(type == tok::END_OF_FILE); + + + } + + + + + + class tokenizer_tester : public tester + { + public: + tokenizer_tester ( + ) : + tester ("test_tokenizer", + "Runs tests on the tokenizer component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "testing kernel_1a"; + tokenizer_kernel_test (); + dlog << LINFO << "testing kernel_1a_c"; + tokenizer_kernel_test(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/tools/CMakeLists.txt b/ml/dlib/dlib/test/tools/CMakeLists.txt new file mode 100644 index 000000000..adbd43cb9 --- /dev/null +++ b/ml/dlib/dlib/test/tools/CMakeLists.txt @@ -0,0 +1,5 @@ +cmake_minimum_required(VERSION 2.8.12) + +add_subdirectory(../../../tools/imglab imglab_build) +add_subdirectory(../../../tools/htmlify htmlify_build) +add_subdirectory(../../../tools/convert_dlib_nets_to_caffe convert_dlib_nets_to_caffe_build) diff --git a/ml/dlib/dlib/test/trust_region.cpp b/ml/dlib/dlib/test/trust_region.cpp new file mode 100644 index 000000000..aa2775b9c --- /dev/null +++ b/ml/dlib/dlib/test/trust_region.cpp @@ -0,0 +1,329 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include "optimization_test_functions.h" +#include +#include +#include +#include +#include +#include "../rand.h" + +#include "tester.h" + + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + using namespace dlib::test_functions; + + logger dlog("test.trust_region"); + +// ---------------------------------------------------------------------------------------- + + template + struct neg_rosen_model + { + typedef matrix column_vector; + typedef matrix general_matrix; + + T operator() ( column_vector x) const + { + return -static_cast(rosen(x)); + } + + void get_derivative_and_hessian ( + const column_vector& x, + column_vector& d, + general_matrix& h + ) const + { + d = -matrix_cast(rosen_derivative(x)); + h = -matrix_cast(rosen_hessian(x)); + } + + }; + +// ---------------------------------------------------------------------------------------- + + dlib::rand rnd; + + template + void test_with_rosen() + { + print_spinner(); + + matrix ans; + ans = 1,1; + + matrix p = 100*matrix_cast(randm(2,1,rnd)) - 50; + + T obj = find_min_trust_region(objective_delta_stop_strategy(1e-12, 100), rosen_function_model(), p); + + DLIB_TEST_MSG(std::abs(obj) < 1e-10, "obj: " << obj); + DLIB_TEST_MSG(length(p-ans) < 1e-5, "length(p): " << length(p-ans)); + + matrix p2 = 100*matrix_cast(randm(2,1,rnd)) - 50; + obj = find_max_trust_region(objective_delta_stop_strategy(1e-12, 100), neg_rosen_model(), p2); + + DLIB_TEST_MSG(std::abs(obj) < 1e-10, "obj: " << obj); + DLIB_TEST_MSG(length(p-ans) < 1e-5, "length(p): " << length(p-ans)); + } + +// ---------------------------------------------------------------------------------------- + + void test_trust_region_sub_problem() + { + dlog << LINFO << "subproblem test 1"; + { + matrix B; + B = 1, 0, + 0, 1; + + matrix g, p, ans; + g = 0; + + ans = 0; + + solve_trust_region_subproblem(B,g,1,p, 0.001, 10); + + DLIB_TEST(length(p-ans) < 1e-10); + solve_trust_region_subproblem(B,g,1,p, 0.001, 1); + DLIB_TEST(length(p-ans) < 1e-10); + } + + dlog << LINFO << "subproblem test 2"; + { + matrix B; + B = 1, 0, + 0, 1; + + B *= 0.1; + + matrix g, p, ans; + g = 1; + + ans = -g / length(g); + + solve_trust_region_subproblem(B,g,1,p, 1e-6, 20); + + DLIB_TEST(length(p-ans) < 1e-4); + } + + dlog << LINFO << "subproblem test 3"; + { + matrix B; + B = 0, 0, + 0, 0; + + matrix g, p, ans; + g = 1; + + ans = -g / length(g); + + solve_trust_region_subproblem(B,g,1,p, 1e-6, 20); + + dlog << LINFO << "ans: " << trans(ans); + dlog << LINFO << "p: " << trans(p); + DLIB_TEST(length(p-ans) < 1e-4); + } + return; + + dlog << LINFO << "subproblem test 4"; + { + matrix B; + B = 2, 0, + 0, -1; + + + matrix g, p, ans; + g = 0; + + ans = 0, -1; + + solve_trust_region_subproblem(B,g,1,p, 1e-6, 20); + + DLIB_TEST(length(p-ans) < 1e-4); + } + + + dlog << LINFO << "subproblem test 5"; + { + matrix B; + B = 2, 0, + 0, -1; + + + matrix g, p, ans; + g = 0, 1; + + ans = 0, -1; + + solve_trust_region_subproblem(B,g,1,p, 1e-6, 20); + + DLIB_TEST(length(p-ans) < 1e-4); + } + + dlog << LINFO << "subproblem test 6"; + for (int i = 0; i < 10; ++i) + { + matrix B; + + B = randm(10,10, rnd); + + B = 0.01*B*trans(B); + + + matrix g, p, ans; + g = 1; + + solve_trust_region_subproblem(B,g,1,p, 1e-6, 20); + + DLIB_TEST(std::abs(length(p) - 1) < 1e-4); + } + } + +// ---------------------------------------------------------------------------------------- + + void test_problems() + { + print_spinner(); + { + matrix ch; + + ch = brown_start(); + + find_min_trust_region(objective_delta_stop_strategy(1e-7, 80), + brown_function_model(), + ch); + + dlog << LINFO << "brown obj: " << brown(ch); + dlog << LINFO << "brown der: " << length(brown_derivative(ch)); + dlog << LINFO << "brown error: " << length(ch - brown_solution()); + + DLIB_TEST(length(ch - brown_solution()) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = rosen_start(); + + find_min_trust_region(objective_delta_stop_strategy(1e-7, 80), + rosen_function_model(), + ch); + + dlog << LINFO << "rosen obj: " << rosen(ch); + dlog << LINFO << "rosen der: " << length(rosen_derivative(ch)); + dlog << LINFO << "rosen error: " << length(ch - rosen_solution()); + + DLIB_TEST(length(ch - rosen_solution()) < 1e-5); + } + + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(2); + + find_min_trust_region(objective_delta_stop_strategy(1e-7, 80), + chebyquad_function_model(), + ch); + + dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch); + dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2)); + + DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(4); + + find_min_trust_region(objective_delta_stop_strategy(1e-7, 80), + chebyquad_function_model(), + ch); + + dlog << LINFO << "chebyquad 4 obj: " << chebyquad(ch); + dlog << LINFO << "chebyquad 4 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "chebyquad 4 error: " << length(ch - chebyquad_solution(4)); + + DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5); + } + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(6); + + find_min_trust_region(objective_delta_stop_strategy(1e-12, 80), + chebyquad_function_model(), + ch); + + dlog << LINFO << "chebyquad 6 obj: " << chebyquad(ch); + dlog << LINFO << "chebyquad 6 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "chebyquad 6 error: " << length(ch - chebyquad_solution(6)); + + DLIB_TEST(length(ch - chebyquad_solution(6)) < 1e-5); + + } + print_spinner(); + { + matrix ch; + + ch = chebyquad_start(8); + + find_min_trust_region(objective_delta_stop_strategy(1e-10, 80), + chebyquad_function_model(), + ch); + + dlog << LINFO << "chebyquad 8 obj: " << chebyquad(ch); + dlog << LINFO << "chebyquad 8 der: " << length(chebyquad_derivative(ch)); + dlog << LINFO << "chebyquad 8 error: " << length(ch - chebyquad_solution(8)); + + DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5); + } + + } + + + + class optimization_tester : public tester + { + public: + optimization_tester ( + ) : + tester ("test_trust_region", + "Runs tests on the trust region optimization component.") + {} + + void perform_test ( + ) + { + dlog << LINFO << "test with rosen"; + for (int i = 0; i < 50; ++i) + test_with_rosen(); + + dlog << LINFO << "test with rosen"; + for (int i = 0; i < 50; ++i) + test_with_rosen(); + + + test_trust_region_sub_problem(); + + test_problems(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test/tuple.cpp b/ml/dlib/dlib/test/tuple.cpp new file mode 100644 index 000000000..da7a18ec8 --- /dev/null +++ b/ml/dlib/dlib/test/tuple.cpp @@ -0,0 +1,186 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.tuple"); + + struct s_nil + { + template + void operator() ( + const T& + ) const + { + } + }; + + + struct inc + { + template + void operator() ( + T& a + ) const + { + a += 1; + } + }; + + + template + void check_const ( + const T& t + ) + { + t.template get<0>(); + + typedef typename T::template get_type<0>::type type0; + t.template get(); + t.template index(); + } + + template + void check_nonconst ( + T& t + ) + { + t.template get<0>(); + + typedef typename T::template get_type<0>::type type0; + t.template get(); + t.template index(); + } + + void tuple_test ( + ) + /*! + ensures + - runs tests on tuple functions for compliance with the specs + !*/ + { + + print_spinner(); + + using dlib::tuple; + + tuple<> a; + tuple b; + tuple c; + + + a.get<1>(); + a.get<2>(); + a.get<3>(); + a.get<4>(); + a.get<5>(); + + check_nonconst(b); + check_nonconst(c); + check_const(b); + check_const(c); + + COMPILE_TIME_ASSERT((is_same_type::get_type<0>::type, null_type>::value)); + COMPILE_TIME_ASSERT((is_same_type::get_type<0>::type, int>::value)); + COMPILE_TIME_ASSERT((is_same_type::get_type<0>::type, int>::value)); + COMPILE_TIME_ASSERT((is_same_type::get_type<1>::type, float>::value)); + COMPILE_TIME_ASSERT((is_same_type::get_type<2>::type, null_type>::value)); + + b.get<0>() = 8; + DLIB_TEST(b.get() == 8); + DLIB_TEST(b.index() == 0); + + c.get<0>() = 9; + DLIB_TEST(c.get() == 9); + DLIB_TEST(c.index() == 0); + c.get<1>() = 3.0; + DLIB_TEST(c.get() == 3.0); + DLIB_TEST(c.index() == 1); + + + + { + typedef tuple T; + T a, b; + a.get<0>() = 1; + a.get<1>() = 3; + a.get<2>() = 2; + + b = a; + + inc i; + s_nil n; + a.for_each(inc()); + a.for_each(i); + const_cast(a).for_each(s_nil()); + const_cast(a).for_each(n); + + DLIB_TEST(a.get<0>() == b.get<0>()+2); + DLIB_TEST(a.get<1>() == b.get<1>()+2); + DLIB_TEST(a.get<2>() == b.get<2>()+2); + + ostringstream sout; + + serialize(a,sout); + istringstream sin(sout.str()); + deserialize(b,sin); + + DLIB_TEST(a.get<0>() == b.get<0>()); + DLIB_TEST(a.get<1>() == b.get<1>()); + DLIB_TEST(a.get<2>() == b.get<2>()); + + a.for_index(i,0); + a.for_index(inc(),1); + const_cast(a).for_index(n,2); + const_cast(a).for_index(s_nil(),0); + + DLIB_TEST(a.get<0>() == b.get<0>()+1); + DLIB_TEST(a.get<1>() == b.get<1>()+1); + DLIB_TEST(a.get<2>() == b.get<2>()+0); + + swap(a,b); + + DLIB_TEST(b.get<0>() == a.get<0>()+1); + DLIB_TEST(b.get<1>() == a.get<1>()+1); + DLIB_TEST(b.get<2>() == a.get<2>()+0); + } + + + } + + + + + class tuple_tester : public tester + { + public: + tuple_tester ( + ) : + tester ("test_tuple", + "Runs tests on the tuple object") + {} + + void perform_test ( + ) + { + tuple_test(); + } + } a; + +} + + + diff --git a/ml/dlib/dlib/test/type_safe_union.cpp b/ml/dlib/dlib/test/type_safe_union.cpp new file mode 100644 index 000000000..6a18fa8e1 --- /dev/null +++ b/ml/dlib/dlib/test/type_safe_union.cpp @@ -0,0 +1,455 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.type_safe_union"); + + struct can_not_copy: noncopyable {}; + void serialize(const can_not_copy&, std::ostream&) {} + void deserialize(can_not_copy&, std::istream&) {} + + void swap(can_not_copy&, can_not_copy&) {} + + class test + { + + private: + + enum kind + { + FLOAT, DOUBLE, CHAR, STRING, NONE + }; + + void operator() (float val) + { + DLIB_TEST(val == f_val); + last_kind = FLOAT; + } + + void operator() (double val) + { + DLIB_TEST(val == d_val); + last_kind = DOUBLE; + } + + void operator() (char val) + { + DLIB_TEST(val == c_val); + last_kind = CHAR; + } + + void operator()(std::string& val) + { + DLIB_TEST(val == s_val); + last_kind = STRING; + } + + void operator()(const std::string& val) + { + DLIB_TEST(val == s_val); + last_kind = STRING; + } + + // ------------------------------ + + friend class type_safe_union; + typedef type_safe_union tsu; + tsu a, b, c; + + float f_val; + double d_val; + char c_val; + std::string s_val; + + kind last_kind; + + public: + void test_stuff() + { + DLIB_TEST(a.is_empty() == true); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + + DLIB_TEST(a.get_type_id() == -1); + DLIB_TEST(a.get_type_id() == 1); + DLIB_TEST(a.get_type_id() == 2); + DLIB_TEST(a.get_type_id() == 3); + DLIB_TEST(a.get_type_id() == 4); + DLIB_TEST(a.get_type_id() == -1); + + + f_val = 4.345f; + a.get() = f_val; + DLIB_TEST(a.cast_to() == f_val); + DLIB_TEST(const_cast(a).cast_to() == f_val); + bool exception_thrown = false; + try {a.cast_to(); } + catch (bad_type_safe_union_cast&) { exception_thrown = true;} + DLIB_TEST(exception_thrown); + + + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + + + last_kind = NONE; + const_cast(a).apply_to_contents(*this); + DLIB_TEST(last_kind == FLOAT); + + // ----------- + + d_val = 4.345; + a.get() = d_val; + last_kind = NONE; + a.apply_to_contents(*this); + DLIB_TEST(last_kind == DOUBLE); + + // ----------- + + c_val = 'a'; + a.get() = c_val; + last_kind = NONE; + const_cast(a).apply_to_contents(*this); + DLIB_TEST(last_kind == CHAR); + + // ----------- + + s_val = "test string"; + a.get() = s_val; + last_kind = NONE; + a.apply_to_contents(*this); + DLIB_TEST(last_kind == STRING); + + DLIB_TEST(a.cast_to() == s_val); + exception_thrown = false; + try {a.cast_to(); } + catch (bad_type_safe_union_cast&) { exception_thrown = true;} + DLIB_TEST(exception_thrown); + + // ----------- + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.contains() == false); + // ----------- + + a.swap(b); + + DLIB_TEST(a.is_empty() == true); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + + DLIB_TEST(b.is_empty() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(b.contains() == false); + + + last_kind = NONE; + b.apply_to_contents(*this); + DLIB_TEST(last_kind == STRING); + + // ----------- + + b.swap(a); + + DLIB_TEST(b.is_empty() == true); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.contains() == false); + + + last_kind = NONE; + a.apply_to_contents(*this); + DLIB_TEST(last_kind == STRING); + last_kind = NONE; + b.apply_to_contents(*this); + DLIB_TEST(last_kind == NONE); + + + a.get() = 'a'; + b.get() = 'b'; + + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(b.is_empty() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(a.contains() == false); + DLIB_TEST(b.contains() == false); + + + DLIB_TEST(a.get() == 'a'); + DLIB_TEST(b.get() == 'b'); + + swap(a,b); + + + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(b.is_empty() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(a.contains() == false); + DLIB_TEST(b.contains() == false); + + DLIB_TEST(a.get() == 'b'); + DLIB_TEST(b.get() == 'a'); + + // ----------- + + a.get() = 'a'; + b.get() = "a string"; + + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(b.is_empty() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(b.contains() == true); + + + DLIB_TEST(a.get() == 'a'); + DLIB_TEST(b.get() == "a string"); + + swap(a,b); + + DLIB_TEST(b.is_empty() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(a.is_empty() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(a.contains() == true); + + + DLIB_TEST(b.get() == 'a'); + DLIB_TEST(a.get() == "a string"); + + + + + { + type_safe_union a, b, empty_union; + + ostringstream sout; + istringstream sin; + + a.get() = 'd'; + + serialize(a, sout); + + sin.str(sout.str()); + deserialize(b, sin); + + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(b.get() == 'd'); + + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.get() == 'd'); + + sin.clear(); + sout.clear(); + sout.str(""); + + a.get() = "davis"; + + serialize(a, sout); + sin.str(sout.str()); + deserialize(b, sin); + + + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(b.get() == "davis"); + + sin.clear(); + sout.clear(); + sout.str(""); + + serialize(empty_union, sout); + sin.str(sout.str()); + deserialize(b, sin); + + DLIB_TEST(b.is_empty() == true); + + } + + { + type_safe_union a, b, empty_union; + + ostringstream sout; + istringstream sin; + + a = 'd'; + + serialize(a, sout); + + sin.str(sout.str()); + deserialize(b, sin); + + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(b.get() == 'd'); + + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.get() == 'd'); + + sin.clear(); + sout.clear(); + sout.str(""); + + a = std::string("davis"); + + serialize(a, sout); + sin.str(sout.str()); + deserialize(b, sin); + + + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(b.get() == "davis"); + + sin.clear(); + sout.clear(); + sout.str(""); + + serialize(empty_union, sout); + sin.str(sout.str()); + deserialize(b, sin); + + DLIB_TEST(b.is_empty() == true); + + } + + { + typedef type_safe_union tsu_type; + tsu_type a('d'), aa(std::string("davis")), b, empty_union; + + ostringstream sout; + istringstream sin; + + + serialize(a, sout); + + sin.str(sout.str()); + deserialize(b, sin); + + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(b.get() == 'd'); + + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == false); + DLIB_TEST(a.contains() == true); + DLIB_TEST(a.get() == 'd'); + + DLIB_TEST(aa.contains() == false); + DLIB_TEST(aa.contains() == false); + DLIB_TEST(aa.contains() == false); + DLIB_TEST(aa.contains() == true); + + sin.clear(); + sout.clear(); + sout.str(""); + + + serialize(aa, sout); + sin.str(sout.str()); + deserialize(b, sin); + + + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == false); + DLIB_TEST(b.contains() == true); + DLIB_TEST(b.get() == "davis"); + + sin.clear(); + sout.clear(); + sout.str(""); + + serialize(empty_union, sout); + sin.str(sout.str()); + deserialize(b, sin); + + DLIB_TEST(b.is_empty() == true); + + a.get(); + DLIB_TEST(a.contains() == true); + + } + } + + }; + + + + class type_safe_union_tester : public tester + { + public: + type_safe_union_tester ( + ) : + tester ("test_type_safe_union", + "Runs tests on the type_safe_union object") + {} + + void perform_test ( + ) + { + for (int i = 0; i < 10; ++i) + { + test a; + a.test_stuff(); + } + } + } a; + +} + + + + diff --git a/ml/dlib/dlib/test/vectorstream.cpp b/ml/dlib/dlib/test/vectorstream.cpp new file mode 100644 index 000000000..a955961a2 --- /dev/null +++ b/ml/dlib/dlib/test/vectorstream.cpp @@ -0,0 +1,142 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include + +#include +#include +#include +#include +#include + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + + logger dlog("test.vectorstream"); + +// ---------------------------------------------------------------------------------------- + + void test1() + { + print_spinner(); + + std::vector buf; + vectorstream s(buf); + + for (int i = -1000; i <= 1000; ++i) + { + char ch = i; + s.put(ch); + } + + DLIB_TEST(buf.size() == 2001); + + int cnt = -1000; + for (unsigned long i = 0; i < buf.size(); ++i) + { + char ch = cnt; + DLIB_TEST(buf[i] == ch); + ++cnt; + } + + for (int i = -1000; i <= 1000; ++i) + { + DLIB_TEST(s.peek() != EOF); + char ch1 = i; + char ch2 = s.get(); + DLIB_TEST(ch1 == ch2); + } + + DLIB_TEST(s.peek() == EOF); + DLIB_TEST(s.get() == EOF); + + s.clear(); + s.seekg(6); + + for (int i = -1000+6; i <= 1000; ++i) + { + DLIB_TEST(s.peek() != EOF); + char ch1 = i; + char ch2 = s.get(); + DLIB_TEST(ch1 == ch2); + } + + DLIB_TEST(s.peek() == EOF); + DLIB_TEST(s.get() == EOF); + + std::string temp; + temp = "one two three!"; + + s.seekg(0); + buf.clear(); + s.clear(); + + serialize(temp, s); + std::string temp2; + deserialize(temp2, s); + DLIB_TEST(temp2 == temp); + + s.put('1'); + s.put('2'); + s.put('3'); + s.put('4'); + DLIB_TEST(s.get() == '1'); + DLIB_TEST(s.get() == '2'); + DLIB_TEST(s.get() == '3'); + DLIB_TEST(s.get() == '4'); + + s.putback('4'); + DLIB_TEST(s.get() == '4'); + s.putback('4'); + s.putback('3'); + s.putback('2'); + s.putback('1'); + DLIB_TEST(s.get() == '1'); + DLIB_TEST(s.get() == '2'); + DLIB_TEST(s.get() == '3'); + DLIB_TEST(s.get() == '4'); + DLIB_TEST(s.good() == true); + DLIB_TEST(s.get() == EOF); + DLIB_TEST(s.good() == false); + + // make sure seeking to a crazy offset doesn't mess things up + s.clear(); + s.seekg(1000000); + DLIB_TEST(s.get() == EOF); + DLIB_TEST(s.good() == false); + s.clear(); + s.seekg(1000000); + char sbuf[100]; + s.read(sbuf, sizeof(sbuf)); + DLIB_TEST(s.good() == false); + } + +// ---------------------------------------------------------------------------------------- + + class test_vectorstream : public tester + { + public: + test_vectorstream ( + ) : + tester ("test_vectorstream", + "Runs tests on the vectorstream component.") + {} + + void perform_test ( + ) + { + test1(); + } + } a; + +} + + diff --git a/ml/dlib/dlib/test_for_odr_violations.cpp b/ml/dlib/dlib/test_for_odr_violations.cpp new file mode 100644 index 000000000..fcf785995 --- /dev/null +++ b/ml/dlib/dlib/test_for_odr_violations.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TEST_FOR_ODR_VIOLATIONS_CPp_ +#define DLIB_TEST_FOR_ODR_VIOLATIONS_CPp_ + +#include "test_for_odr_violations.h" + +extern "C" +{ +// The point of this block of code is to cause a link time error that will prevent a user +// from compiling part of their application with DLIB_ASSERT enabled and part with them +// disabled since doing that would be a violation of C++'s one definition rule. +#ifdef ENABLE_ASSERTS + const int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1 = 0; +#else + const int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1_ = 0; +#endif + + +// The point of this block of code is to cause a link time error if someone builds dlib via +// cmake as a separately installable library, and therefore generates a dlib/config.h from +// cmake, but then proceeds to use the default unconfigured dlib/config.h from version +// control. It should be obvious why this is bad, if it isn't you need to read a book +// about C++. Moreover, it can only happen if someone manually copies files around and +// messes things up. If instead they run `make install` or `cmake --build . --target +// install` things will be setup correctly, which is what they should do. To summarize: DO +// NOT BUILD A STANDALONE DLIB AND THEN GO CHERRY PICKING FILES FROM THE BUILD FOLDER AND +// MIXING THEM WITH THE SOURCE FROM GITHUB. USE CMAKE'S INSTALL SCRIPTS TO INSTALL DLIB. +// Or even better, don't install dlib at all and instead build your program as shown in +// examples/CMakeLists.txt +#if defined(DLIB_NOT_CONFIGURED) && !defined(DLIB__CMAKE_GENERATED_A_CONFIG_H_FILE) + const int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_2 = 0; +#endif + + + + + +#ifdef DLIB_CHECK_FOR_VERSION_MISMATCH + const int DLIB_CHECK_FOR_VERSION_MISMATCH = 0; +#endif + +} + + +#endif // DLIB_TEST_FOR_ODR_VIOLATIONS_CPp_ + diff --git a/ml/dlib/dlib/test_for_odr_violations.h b/ml/dlib/dlib/test_for_odr_violations.h new file mode 100644 index 000000000..5fa5111ba --- /dev/null +++ b/ml/dlib/dlib/test_for_odr_violations.h @@ -0,0 +1,57 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TEST_FOR_ODR_VIOLATIONS_H_ +#define DLIB_TEST_FOR_ODR_VIOLATIONS_H_ + +#include "assert.h" +#include "config.h" + +extern "C" +{ +// =========================>>> WHY YOU ARE GETTING AN ERROR HERE <<<========================= +// The point of this block of code is to cause a link time error that will prevent a user +// from compiling part of their application with DLIB_ASSERT enabled and part with it +// disabled since doing that would be a violation of C++'s one definition rule. So if you +// are getting an error here then you are either not enabling DLIB_ASSERT consistently +// (e.g. by compiling part of your program in a debug mode and part in a release mode) or +// you have simply forgotten to compile dlib/all/source.cpp into your application. +// =========================>>> WHY YOU ARE GETTING AN ERROR HERE <<<========================= +#ifdef ENABLE_ASSERTS + const extern int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1; + const int DLIB_NO_WARN_UNUSED dlib_check_assert_helper_variable = USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1; +#else + const extern int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1_; + const int DLIB_NO_WARN_UNUSED dlib_check_assert_helper_variable = USER_ERROR__inconsistent_build_configuration__see_dlib_faq_1_; +#endif + + + +// The point of this block of code is to cause a link time error if someone builds dlib via +// cmake as a separately installable library, and therefore generates a dlib/config.h from +// cmake, but then proceeds to use the default unconfigured dlib/config.h from version +// control. It should be obvious why this is bad, if it isn't you need to read a book +// about C++. Moreover, it can only happen if someone manually copies files around and +// messes things up. If instead they run `make install` or `cmake --build . --target +// install` things will be setup correctly, which is what they should do. To summarize: DO +// NOT BUILD A STANDALONE DLIB AND THEN GO CHERRY PICKING FILES FROM THE BUILD FOLDER AND +// MIXING THEM WITH THE SOURCE FROM GITHUB. USE CMAKE'S INSTALL SCRIPTS TO INSTALL DLIB. +// Or even better, don't install dlib at all and instead build your program as shown in +// examples/CMakeLists.txt +#if defined(DLIB_NOT_CONFIGURED) && !defined(DLIB__CMAKE_GENERATED_A_CONFIG_H_FILE) + const extern int USER_ERROR__inconsistent_build_configuration__see_dlib_faq_2; + const int DLIB_NO_WARN_UNUSED dlib_check_not_configured_helper_variable = USER_ERROR__inconsistent_build_configuration__see_dlib_faq_2; +#endif + + + +// Cause the user to get a linker error if they try to use header files from one version of +// dlib with the compiled binary from a different version of dlib. +#ifdef DLIB_CHECK_FOR_VERSION_MISMATCH + const extern int DLIB_CHECK_FOR_VERSION_MISMATCH; + const int DLIB_NO_WARN_UNUSED dlib_check_for_version_mismatch = DLIB_CHECK_FOR_VERSION_MISMATCH; +#endif + +} + +#endif // DLIB_TEST_FOR_ODR_VIOLATIONS_H_ + diff --git a/ml/dlib/dlib/threads.h b/ml/dlib/dlib/threads.h new file mode 100644 index 000000000..371a317e0 --- /dev/null +++ b/ml/dlib/dlib/threads.h @@ -0,0 +1,28 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_THREADs_ +#define DLIB_THREADs_ + +#include "threads/threads_kernel.h" + +#include "threads/auto_mutex_extension.h" +#include "threads/auto_unlock_extension.h" +#include "threads/create_new_thread_extension.h" +#include "threads/multithreaded_object_extension.h" +#include "threads/rmutex_extension.h" +#include "threads/rsignaler_extension.h" +#include "threads/threaded_object_extension.h" +#include "threads/thread_specific_data_extension.h" +#include "threads/thread_function_extension.h" +#include "threads/thread_pool_extension.h" +#include "threads/read_write_mutex_extension.h" +#include "threads/parallel_for_extension.h" +#include "threads/async.h" + +#endif // DLIB_THREADs_ + diff --git a/ml/dlib/dlib/threads/async.cpp b/ml/dlib/dlib/threads/async.cpp new file mode 100644 index 000000000..6aa947bcb --- /dev/null +++ b/ml/dlib/dlib/threads/async.cpp @@ -0,0 +1,48 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AsYNC_CPP_ +#define DLIB_AsYNC_CPP_ + +// C++11 things don't work in old versions of visual studio +#if !defined( _MSC_VER) || _MSC_VER >= 1900 + +#include "async.h" +#include +#include "../string.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + unsigned long default_num_threads() + { + try + { + char* nt = getenv("DLIB_NUM_THREADS"); + if (nt) + return string_cast(nt); + } catch(string_cast_error&) {} + return std::thread::hardware_concurrency(); + } + } + +// ---------------------------------------------------------------------------------------- + + thread_pool& default_thread_pool() + { + static thread_pool tp(impl::default_num_threads()); + return tp; + } +} + +// ---------------------------------------------------------------------------------------- + +#endif + +#endif // DLIB_AsYNC_CPP_ + + diff --git a/ml/dlib/dlib/threads/async.h b/ml/dlib/dlib/threads/async.h new file mode 100644 index 000000000..bc6fe5575 --- /dev/null +++ b/ml/dlib/dlib/threads/async.h @@ -0,0 +1,105 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AsYNC_Hh_ +#define DLIB_AsYNC_Hh_ + +// C++11 things don't work in old versions of visual studio +#if !defined( _MSC_VER) || _MSC_VER >= 1900 + +#include "async_abstract.h" +#include "thread_pool_extension.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template struct selector {}; + + template + void call_prom_set_value( + T& prom, + U& fun, + selector + ) + { + prom.set_value(fun()); + } + + template + void call_prom_set_value( + T& prom, + U& fun, + selector + ) + { + fun(); + prom.set_value(); + } + } + +// ---------------------------------------------------------------------------------------- + + thread_pool& default_thread_pool(); + +// ---------------------------------------------------------------------------------------- + + template < + typename Function, + typename ...Args + > + std::future::type> async( + thread_pool& tp, + Function&& f, + Args&&... args + ) + { + auto prom = std::make_shared::type>>(); + std::future::type> ret = prom->get_future(); + using bind_t = decltype(std::bind(std::forward(f), std::forward(args)...)); + auto fun = std::make_shared(std::bind(std::forward(f), std::forward(args)...)); + tp.add_task_by_value([fun, prom]() + { + try + { + impl::call_prom_set_value(*prom, *fun, impl::selector::type>()); + } + catch(...) + { + prom->set_exception(std::current_exception()); + } + }); + return std::move(ret); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename Function, + typename ...Args + > + std::future::type> async( + Function&& f, + Args&&... args + ) + { + return async(default_thread_pool(), std::forward(f), std::forward(args)...); + } + +} + +// ---------------------------------------------------------------------------------------- + +#ifdef NO_MAKEFILE +#include "async.cpp" +#endif + +#endif +#endif // DLIB_AsYNC_Hh_ + + + diff --git a/ml/dlib/dlib/threads/async_abstract.h b/ml/dlib/dlib/threads/async_abstract.h new file mode 100644 index 000000000..a9fa1e458 --- /dev/null +++ b/ml/dlib/dlib/threads/async_abstract.h @@ -0,0 +1,67 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AsYNC_ABSTRACT_Hh_ +#ifdef DLIB_AsYNC_ABSTRACT_Hh_ + +#include "thread_pool_extension_abstract.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + thread_pool& default_thread_pool( + ); + /*! + ensures + - returns a reference to a global thread_pool. If the DLIB_NUM_THREADS + environment variable is set to an integer then the thread pool will contain + DLIB_NUM_THREADS threads, otherwise it will contain + std::thread::hardware_concurrency() threads. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename Function, + typename ...Args + > + std::future::type> async( + thread_pool& tp, + Function&& f, + Args&&... args + ); + /*! + requires + - f must be a function and f(args...) must be a valid expression. + ensures + - This function behaves just like std::async(std::launch::async, f, args) + except that instead of spawning a new thread to process each task it submits + the task to the provided dlib::thread_pool. Therefore, dlib::async() is + guaranteed to use a bounded number of threads unlike std::async(). This also + means that calls to dlib::async() will block if there aren't any free threads + in the thread pool. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename Function, + typename ...Args + > + std::future::type> async( + Function&& f, + Args&&... args + ); + /*! + ensures + - Calling this function is equivalent to directly calling async(default_thread_pool(), f, args...) + !*/ +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_AsYNC_ABSTRACT_Hh_ + diff --git a/ml/dlib/dlib/threads/auto_mutex_extension.h b/ml/dlib/dlib/threads/auto_mutex_extension.h new file mode 100644 index 000000000..595c1b176 --- /dev/null +++ b/ml/dlib/dlib/threads/auto_mutex_extension.h @@ -0,0 +1,180 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AUTO_MUTEX_EXTENSIOn_ +#define DLIB_AUTO_MUTEX_EXTENSIOn_ + +#include "threads_kernel.h" +#include "rmutex_extension.h" +#include "read_write_mutex_extension.h" +#include "auto_mutex_extension_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class auto_mutex + { + /*! + INITIAL VALUE + - if (m != 0) then + - the mutex pointed to by m is locked + - if (r != 0) then + - the mutex pointed to by r is locked + - if (rw != 0) then + - the mutex pointed to by rw is locked + - exactly one of r, m, or rw is not 0. + + CONVENTION + - if (m != 0) then + - the mutex pointed to by m is locked + - if (r != 0) then + - the mutex pointed to by r is locked + - if (rw != 0) then + - the mutex pointed to by rw is locked + - exactly one of r, m, or rw is not 0. + !*/ + public: + + explicit auto_mutex ( + const mutex& m_ + ) : m(&m_), + r(0), + rw(0) + { + m->lock(); + } + + explicit auto_mutex ( + const rmutex& r_ + ) : m(0), + r(&r_), + rw(0) + { + r->lock(); + } + + explicit auto_mutex ( + const read_write_mutex& rw_ + ) : m(0), + r(0), + rw(&rw_) + { + rw->lock(); + } + + void unlock() + { + if (m != 0) + { + m->unlock(); + m = 0; + } + else if (r != 0) + { + r->unlock(); + r = 0; + } + else if (rw != 0) + { + rw->unlock(); + rw = 0; + } + } + + ~auto_mutex ( + ) + { + unlock(); + } + + private: + + const mutex* m; + const rmutex* r; + const read_write_mutex* rw; + + // restricted functions + auto_mutex(auto_mutex&); // copy constructor + auto_mutex& operator=(auto_mutex&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class auto_mutex_readonly + { + public: + + explicit auto_mutex_readonly ( + const read_write_mutex& rw_ + ) : rw(rw_), _has_write_lock(false), _has_read_lock(true) + { + rw.lock_readonly(); + } + + ~auto_mutex_readonly ( + ) + { + unlock(); + } + + void lock_readonly ( + ) + { + if (!_has_read_lock) + { + unlock(); + rw.lock_readonly(); + _has_read_lock = true; + } + } + + void lock_write ( + ) + { + if (!_has_write_lock) + { + unlock(); + rw.lock(); + _has_write_lock = true; + } + } + + void unlock ( + ) + { + if (_has_write_lock) + { + rw.unlock(); + _has_write_lock = false; + } + else if (_has_read_lock) + { + rw.unlock_readonly(); + _has_read_lock = false; + } + } + + bool has_read_lock ( + ) { return _has_read_lock; } + + bool has_write_lock ( + ) { return _has_write_lock; } + + private: + + const read_write_mutex& rw; + bool _has_write_lock; + bool _has_read_lock; + + // restricted functions + auto_mutex_readonly(auto_mutex_readonly&); // copy constructor + auto_mutex_readonly& operator=(auto_mutex_readonly&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AUTO_MUTEX_EXTENSIOn_ + diff --git a/ml/dlib/dlib/threads/auto_mutex_extension_abstract.h b/ml/dlib/dlib/threads/auto_mutex_extension_abstract.h new file mode 100644 index 000000000..1990c834e --- /dev/null +++ b/ml/dlib/dlib/threads/auto_mutex_extension_abstract.h @@ -0,0 +1,185 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AUTO_MUTEX_EXTENSIOn_ABSTRACT_ +#ifdef DLIB_AUTO_MUTEX_EXTENSIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" +#include "rmutex_extension_abstract.h" +#include "read_write_mutex_extension_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class auto_mutex + { + /*! + INITIAL VALUE + The mutex given in the constructor is locked and associated with this + object. + + WHAT THIS OBJECT REPRESENTS + This object represents a mechanism for automatically locking and unlocking + a mutex object. + !*/ + public: + + explicit auto_mutex ( + const mutex& m + ); + /*! + ensures + - #*this is properly initialized + - m will be locked + !*/ + + explicit auto_mutex ( + const rmutex& m + ); + /*! + ensures + - #*this is properly initialized + - m will be locked + !*/ + + explicit auto_mutex ( + const read_write_mutex& m + ); + /*! + ensures + - #*this is properly initialized + - m will be locked via m.lock() (i.e. a write lock will be obtained) + !*/ + + void unlock( + ); + /*! + ensures + - if (unlock() has not already been called) then + - The mutex associated with *this has been unlocked. This is useful if + you want to unlock a mutex before the auto_mutex destructor executes. + !*/ + + ~auto_mutex ( + ); + /*! + ensures + - all resources allocated by *this have been freed + - calls unlock() + !*/ + + private: + // restricted functions + auto_mutex(auto_mutex&); // copy constructor + auto_mutex& operator=(auto_mutex&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class auto_mutex_readonly + { + /*! + INITIAL VALUE + The mutex given in the constructor is locked using a read-only lock and + associated with this object. + + WHAT THIS OBJECT REPRESENTS + This object represents a mechanism for automatically locking and unlocking + a read_write_mutex object. In particular, a readonly lock is used. + !*/ + public: + + explicit auto_mutex_readonly ( + const read_write_mutex& m + ); + /*! + ensures + - #*this is properly initialized + - a readonly lock will be obtained on m using m.lock_readonly() + - #has_read_lock() == true + !*/ + + ~auto_mutex_readonly ( + ); + /*! + ensures + - all resources allocated by *this have been freed + - the mutex associated with *this has been unlocked + !*/ + + bool has_read_lock ( + ); + /*! + ensures + - returns true if this object has called read_write_mutex::lock_readonly() + on its associated mutex and has yet to release that lock. + !*/ + + bool has_write_lock ( + ); + /*! + ensures + - returns true if this object has called read_write_mutex::lock() on its + associated mutex and has yet to release that lock. + !*/ + + void lock_readonly ( + ); + /*! + ensures + - This function converts the lock on the associated mutex into a readonly lock. + Specifically: + if (!has_read_lock()) then + - if (has_write_lock()) then + - unlocks the associated mutex and then relocks it by calling + read_write_mutex::lock_readonly() + - else + - locks the associated mutex by calling read_write_mutex::lock_readonly() + - #has_read_lock() == true + - Note that the lock switch is not atomic. This means that whatever + resource is protected by the mutex might have been modified during the + call to lock_readonly(). + !*/ + + void lock_write ( + ); + /*! + ensures + - This function converts the lock on the associated mutex into a write lock. + Specifically: + if (!has_write_lock()) then + - if (has_read_lock()) then + - unlocks the associated mutex and then relocks it by calling + read_write_mutex::lock() + - else + - locks the associated mutex by calling read_write_mutex::lock() + - #has_write_lock() == true + - Note that the lock switch is not atomic. This means that whatever + resource is protected by the mutex might have been modified during the + call to lock_write(). + !*/ + + void unlock ( + ); + /*! + ensures + - if (has_read_lock() || has_write_lock()) then + - unlocks the associated mutex. This is useful if you want to unlock a + mutex before the auto_mutex_readonly destructor executes. + - #has_read_lock() == false + - #has_write_lock() == false + !*/ + + private: + // restricted functions + auto_mutex_readonly(auto_mutex_readonly&); // copy constructor + auto_mutex_readonly& operator=(auto_mutex_readonly&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AUTO_MUTEX_EXTENSIOn_ABSTRACT_ + diff --git a/ml/dlib/dlib/threads/auto_unlock_extension.h b/ml/dlib/dlib/threads/auto_unlock_extension.h new file mode 100644 index 000000000..cd1d4db9a --- /dev/null +++ b/ml/dlib/dlib/threads/auto_unlock_extension.h @@ -0,0 +1,116 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AUTO_UNLOCK_EXTENSIOn_ +#define DLIB_AUTO_UNLOCK_EXTENSIOn_ + +#include "threads_kernel.h" +#include "rmutex_extension.h" +#include "read_write_mutex_extension.h" +#include "auto_unlock_extension_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class auto_unlock + { + /*! + INITIAL VALUE + - if (m != 0) then + - the mutex pointed to by m is locked + - if (r != 0) then + - the mutex pointed to by r is locked + - if (rw != 0) then + - the mutex pointed to by rw is locked + - exactly one of r, m, or rw is not 0. + + CONVENTION + - if (m != 0) then + - the mutex pointed to by m is locked + - if (r != 0) then + - the mutex pointed to by r is locked + - if (rw != 0) then + - the mutex pointed to by rw is locked + - exactly one of r, m, or rw is not 0. + !*/ + public: + + explicit auto_unlock ( + const mutex& m_ + ) : m(&m_), + r(0), + rw(0) + {} + + explicit auto_unlock ( + const rmutex& r_ + ) : m(0), + r(&r_), + rw(0) + {} + + explicit auto_unlock ( + const read_write_mutex& rw_ + ) : m(0), + r(0), + rw(&rw_) + {} + + ~auto_unlock ( + ) + { + if (m != 0) + m->unlock(); + else if (r != 0) + r->unlock(); + else + rw->unlock(); + } + + private: + + const mutex* m; + const rmutex* r; + const read_write_mutex* rw; + + // restricted functions + auto_unlock(auto_unlock&); // copy constructor + auto_unlock& operator=(auto_unlock&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class auto_unlock_readonly + { + + public: + + explicit auto_unlock_readonly ( + const read_write_mutex& rw_ + ) : + rw(rw_) + {} + + ~auto_unlock_readonly ( + ) + { + rw.unlock_readonly(); + } + + private: + + const read_write_mutex& rw; + + // restricted functions + auto_unlock_readonly(auto_unlock_readonly&); // copy constructor + auto_unlock_readonly& operator=(auto_unlock_readonly&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AUTO_UNLOCK_EXTENSIOn_ + + diff --git a/ml/dlib/dlib/threads/auto_unlock_extension_abstract.h b/ml/dlib/dlib/threads/auto_unlock_extension_abstract.h new file mode 100644 index 000000000..f947d4879 --- /dev/null +++ b/ml/dlib/dlib/threads/auto_unlock_extension_abstract.h @@ -0,0 +1,116 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AUTO_UNLOCK_EXTENSIOn_ABSTRACT_ +#ifdef DLIB_AUTO_UNLOCK_EXTENSIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" +#include "rmutex_extension_abstract.h" +#include "read_write_mutex_extension_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class auto_unlock + { + /*! + INITIAL VALUE + The mutex given in the constructor is associated with this object. + + WHAT THIS OBJECT REPRESENTS + This object represents a mechanism for automatically unlocking + a mutex object. It is useful when you already have a locked mutex + and want to make sure it gets unlocked even if an exception is thrown + or you quit the function at a weird spot. + !*/ + public: + + explicit auto_unlock ( + const mutex& m + ); + /*! + ensures + - #*this is properly initialized + - does not modify m in any way + !*/ + + explicit auto_unlock ( + const rmutex& m + ); + /*! + ensures + - #*this is properly initialized + - does not modify m in any way + !*/ + + explicit auto_unlock ( + const read_write_mutex& m + ); + /*! + ensures + - #*this is properly initialized + - does not modify m in any way + !*/ + + ~auto_unlock ( + ); + /*! + ensures + - all resources allocated by *this have been freed + - calls unlock() on the mutex associated with *this + !*/ + + private: + // restricted functions + auto_unlock(auto_unlock&); // copy constructor + auto_unlock& operator=(auto_unlock&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + class auto_unlock_readonly + { + /*! + INITIAL VALUE + The mutex given in the constructor is associated with this object. + + WHAT THIS OBJECT REPRESENTS + This object represents a mechanism for automatically unlocking + a read_write_mutex object. It is useful when you already have a locked mutex + and want to make sure it gets unlocked even if an exception is thrown + or you quit the function at a weird spot. Note that the mutex + is unlocked by calling unlock_readonly() on it. + !*/ + public: + + explicit auto_unlock_readonly ( + const read_write_mutex& m + ); + /*! + ensures + - #*this is properly initialized + - does not modify m in any way + !*/ + + ~auto_unlock_readonly ( + ); + /*! + ensures + - all resources allocated by *this have been freed + - calls unlock_readonly() on the mutex associated with *this + !*/ + + private: + // restricted functions + auto_unlock_readonly(auto_unlock_readonly&); // copy constructor + auto_unlock_readonly& operator=(auto_unlock_readonly&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AUTO_UNLOCK_EXTENSIOn_ABSTRACT_ + + diff --git a/ml/dlib/dlib/threads/create_new_thread_extension.h b/ml/dlib/dlib/threads/create_new_thread_extension.h new file mode 100644 index 000000000..8f419b6be --- /dev/null +++ b/ml/dlib/dlib/threads/create_new_thread_extension.h @@ -0,0 +1,46 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CREATE_NEW_THREAD_EXTENSIOn_ +#define DLIB_CREATE_NEW_THREAD_EXTENSIOn_ + +#include "threads_kernel_abstract.h" +#include "create_new_thread_extension_abstract.h" +#include "../threads.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + void (T::*funct)() + > + inline void dlib_create_new_thread_helper ( + void* obj + ) + { + T* o = static_cast(obj); + (o->*funct)(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + void (T::*funct)() + > + inline bool create_new_thread ( + T& obj + ) + { + return create_new_thread(dlib_create_new_thread_helper,&obj); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CREATE_NEW_THREAD_EXTENSIOn_ + + diff --git a/ml/dlib/dlib/threads/create_new_thread_extension_abstract.h b/ml/dlib/dlib/threads/create_new_thread_extension_abstract.h new file mode 100644 index 000000000..43fbc474d --- /dev/null +++ b/ml/dlib/dlib/threads/create_new_thread_extension_abstract.h @@ -0,0 +1,33 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CREATE_NEW_THREAD_EXTENSIOn_ABSTRACT_ +#ifdef DLIB_CREATE_NEW_THREAD_EXTENSIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + void (T::*funct)() + > + bool create_new_thread ( + T& obj + ); + /*! + ensures + - creates a new thread and calls obj.*funct() from it. + - returns true upon success and false upon failure to create the new thread. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CREATE_NEW_THREAD_EXTENSIOn_ABSTRACT_ + + + diff --git a/ml/dlib/dlib/threads/multithreaded_object_extension.cpp b/ml/dlib/dlib/threads/multithreaded_object_extension.cpp new file mode 100644 index 000000000..def4af5f2 --- /dev/null +++ b/ml/dlib/dlib/threads/multithreaded_object_extension.cpp @@ -0,0 +1,241 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MULTITHREADED_OBJECT_EXTENSIOn_CPP +#define DLIB_MULTITHREADED_OBJECT_EXTENSIOn_CPP + +#include "multithreaded_object_extension.h" +#include "create_new_thread_extension.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + multithreaded_object:: + multithreaded_object ( + ): + s(m_), + is_running_(false), + should_stop_(false), + threads_started(0) + { + } + +// ---------------------------------------------------------------------------------------- + + multithreaded_object:: + ~multithreaded_object ( + ) + { + try + { + DLIB_ASSERT(number_of_threads_alive() == 0, + "\tmultithreaded_object::~multithreaded_object()" + << "\n\tYou have let a multithreaded object destruct itself before terminating its threads" + << "\n\tthis: " << this + ); + } + catch (std::exception& e) + { + std::cerr << e.what() << std::endl; + assert(false); + abort(); + } + } + +// ---------------------------------------------------------------------------------------- + + void multithreaded_object:: + clear ( + ) + { + auto_mutex M(m_); + stop(); + wait(); + dead_threads.clear(); + is_running_ = false; + should_stop_ = false; + } + +// ---------------------------------------------------------------------------------------- + + bool multithreaded_object:: + is_running ( + ) const + { + auto_mutex M(m_); + return is_running_; + } + +// ---------------------------------------------------------------------------------------- + + unsigned long multithreaded_object:: + number_of_threads_registered ( + ) const + { + auto_mutex M(m_); + return thread_ids.size() + dead_threads.size(); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long multithreaded_object:: + number_of_threads_alive ( + ) const + { + auto_mutex M(m_); + return threads_started; + } + +// ---------------------------------------------------------------------------------------- + + void multithreaded_object:: + wait ( + ) const + { + auto_mutex M(m_); + + DLIB_ASSERT(thread_ids.is_in_domain(get_thread_id()) == false, + "\tvoid multithreaded_object::wait()" + << "\n\tYou can NOT call this function from one of the threads registered in this object" + << "\n\tthis: " << this + ); + + while (threads_started > 0) + s.wait(); + } + +// ---------------------------------------------------------------------------------------- + + void multithreaded_object:: + start ( + ) + { + auto_mutex M(m_); + const unsigned long num_threads_registered = dead_threads.size() + thread_ids.size(); + // start any dead threads + for (unsigned long i = threads_started; i < num_threads_registered; ++i) + { + if (create_new_thread(*this) == false) + { + should_stop_ = true; + is_running_ = false; + throw thread_error(); + } + ++threads_started; + } + is_running_ = true; + should_stop_ = false; + s.broadcast(); + } + +// ---------------------------------------------------------------------------------------- + + void multithreaded_object:: + pause ( + ) + { + auto_mutex M(m_); + is_running_ = false; + } + +// ---------------------------------------------------------------------------------------- + + void multithreaded_object:: + stop ( + ) + { + auto_mutex M(m_); + should_stop_ = true; + is_running_ = false; + s.broadcast(); + } + +// ---------------------------------------------------------------------------------------- + + bool multithreaded_object:: + should_stop ( + ) const + { + auto_mutex M(m_); + DLIB_ASSERT(thread_ids.is_in_domain(get_thread_id()), + "\tbool multithreaded_object::should_stop()" + << "\n\tYou can only call this function from one of the registered threads in this object" + << "\n\tthis: " << this + ); + while (is_running_ == false && should_stop_ == false) + s.wait(); + return should_stop_; + } + +// ---------------------------------------------------------------------------------------- + + multithreaded_object::raii_thread_helper:: + raii_thread_helper( + multithreaded_object& self_, + thread_id_type id_ + ) : self(self_), id(id_){} + + multithreaded_object::raii_thread_helper:: + ~raii_thread_helper() + { + auto_mutex M(self.m_); + if (self.thread_ids.is_in_domain(id)) + { + mfp temp; + thread_id_type id_temp; + self.thread_ids.remove(id,id_temp,temp); + // put this thread's registered function back into the dead_threads queue + self.dead_threads.enqueue(temp); + } + + --self.threads_started; + // If this is the last thread to terminate then + // signal that that is the case. + if (self.threads_started == 0) + { + self.is_running_ = false; + self.should_stop_ = false; + self.s.broadcast(); + } + } + +// ---------------------------------------------------------------------------------------- + + void multithreaded_object:: + thread_helper( + ) + { + mfp mf; + thread_id_type id = get_thread_id(); + + // this guy's destructor does all the necessary cleanup in this function + raii_thread_helper raii(*this, id); + + // if there is a dead_thread sitting around then pull it + // out and put it into mf + { + auto_mutex M(m_); + if (dead_threads.size() > 0) + { + dead_threads.dequeue(mf); + mfp temp(mf); + thread_ids.add(id,temp); + } + } + + if (mf.is_set()) + { + // call the registered thread function + mf(); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MULTITHREADED_OBJECT_EXTENSIOn_CPP + + diff --git a/ml/dlib/dlib/threads/multithreaded_object_extension.h b/ml/dlib/dlib/threads/multithreaded_object_extension.h new file mode 100644 index 000000000..9dd37fdcc --- /dev/null +++ b/ml/dlib/dlib/threads/multithreaded_object_extension.h @@ -0,0 +1,153 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MULTITHREADED_OBJECT_EXTENSIOn_ +#define DLIB_MULTITHREADED_OBJECT_EXTENSIOn_ + +#include "multithreaded_object_extension_abstract.h" +#include "threads_kernel.h" +#include "auto_mutex_extension.h" +#include "rmutex_extension.h" +#include "rsignaler_extension.h" +#include "../algs.h" +#include "../assert.h" +#include "../map.h" +#include "../member_function_pointer.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class multithreaded_object + { + /*! + INITIAL VALUE + - is_running_ == false + - should_stop_ == false + - thread_ids.size() == 0 + - dead_threads.size() == 0 + - threads_started == 0 + + CONVENTION + - number_of_threads_registered() == thread_ids.size() + dead_threads.size() + - number_of_threads_alive() == threads_started + + - is_running() == is_running_ + - should_stop() == should_stop_ + + - thread_ids == a map of current thread ids to the member function + pointers that that thread runs. + - threads_started == the number of threads that have been spawned to run + thread_helper but haven't ended yet. + + - dead_threads == a queue that contains all the member function pointers + for threads that are currently registered but not running + + - m_ == the mutex used to protect all our variables + - s == the signaler for m_ + !*/ + + public: + + multithreaded_object ( + ); + + virtual ~multithreaded_object ( + ) = 0; + + void clear ( + ); + + bool is_running ( + ) const; + + unsigned long number_of_threads_alive ( + ) const; + + unsigned long number_of_threads_registered ( + ) const; + + void wait ( + ) const; + + void start ( + ); + + void pause ( + ); + + void stop ( + ); + + protected: + + bool should_stop ( + ) const; + + template < + typename T + > + void register_thread ( + T& object, + void (T::*thread)() + ) + { + auto_mutex M(m_); + try + { + mfp mf; + mf.set(object,thread); + dead_threads.enqueue(mf); + if (is_running_) + start(); + } + catch (...) + { + is_running_ = false; + should_stop_ = true; + s.broadcast(); + throw; + } + } + + private: + + class raii_thread_helper + { + public: + raii_thread_helper(multithreaded_object& self_, thread_id_type id_); + ~raii_thread_helper(); + + multithreaded_object& self; + thread_id_type id; + }; + + void thread_helper( + ); + + typedef member_function_pointer<> mfp; + + rmutex m_; + rsignaler s; + map::kernel_2a>::kernel_1a thread_ids; + queue::kernel_2a>::kernel_1a dead_threads; + + bool is_running_; + bool should_stop_; + unsigned long threads_started; + + // restricted functions + multithreaded_object(multithreaded_object&); // copy constructor + multithreaded_object& operator=(multithreaded_object&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "multithreaded_object_extension.cpp" +#endif + +#endif // DLIB_MULTITHREADED_OBJECT_EXTENSIOn_ + diff --git a/ml/dlib/dlib/threads/multithreaded_object_extension_abstract.h b/ml/dlib/dlib/threads/multithreaded_object_extension_abstract.h new file mode 100644 index 000000000..e7862b78f --- /dev/null +++ b/ml/dlib/dlib/threads/multithreaded_object_extension_abstract.h @@ -0,0 +1,186 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MULTITHREADED_OBJECT_EXTENSIOn_ABSTRACT_ +#ifdef DLIB_MULTITHREADED_OBJECT_EXTENSIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class multithreaded_object + { + /*! + INITIAL VALUE + - is_running() == false + - number_of_threads_alive() == 0 + - number_of_threads_registered() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a multithreaded object. It is similar to + the threaded_object except it allows you to have many threads in a + single object rather than just one. To use it you inherit from it + and register the member functions in your new class that you want + to run in their own threads by calling register_thread(). Then when + you call start() it will spawn all the registered functions + in their own threads. + !*/ + + public: + + multithreaded_object ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create threading objects. + !*/ + + virtual ~multithreaded_object ( + ) = 0; + /*! + requires + - number_of_threads_alive() == 0 + (i.e. in the destructor for the object you derive from this one you + must wait for all the threads to end.) + ensures + - all resources allocated by *this have been freed. + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + - blocks until all threads have terminated + throws + - std::bad_alloc or dlib::thread_error + if an exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + bool is_running ( + ) const; + /*! + ensures + - if (number_of_threads_alive() > 0 && the threads are currently supposed to be executing) then + - returns true + - else + - returns false + !*/ + + unsigned long number_of_threads_alive ( + ) const; + /*! + ensures + - returns the number of threads that are currently alive (i.e. + the number of threads that have started but not yet terminated) + !*/ + + unsigned long number_of_threads_registered ( + ) const; + /*! + ensures + - returns the number of threads that have been registered by + calls to register_thread() + !*/ + + void wait ( + ) const; + /*! + requires + - is not called from one of this object's threads + ensures + - if (number_of_threads_alive() > 0) then + - blocks until all the threads in this object have terminated + (i.e. blocks until number_of_threads_alive() == 0) + !*/ + + void start ( + ); + /*! + ensures + - #number_of_threads_alive() == number_of_threads_registered() + - #is_running() == true + - #should_stop() == false + - all the threads registered are up and running. + throws + - std::bad_alloc or dlib::thread_error + If either of these exceptions are thrown then + #is_running() == false and should_stop() == true + !*/ + + void pause ( + ); + /*! + ensures + - #is_running() == false + !*/ + + void stop ( + ); + /*! + ensures + - #should_stop() == true + - #is_running() == false + !*/ + + protected: + + template < + typename T + > + void register_thread ( + T& object, + void (T::*thread)() + ); + /*! + requires + - (object.*thread)() forms a valid function call + - the thread function does not throw + ensures + - registers the member function pointed to by thread as one of the threads + that runs when is_running() == true + - #number_of_threads_registered() == number_of_threads_registered() + 1 + - if (is_running() == true) + - spawns this new member function in its own thread + - #number_of_threads_alive() += number_of_threads_alive() + 1 + throws + - std::bad_alloc or dlib::thread_error + If either of these exceptions are thrown then + #is_running() == false and should_stop() == true + !*/ + + bool should_stop ( + ) const; + /*! + requires + - is only called from one of the registered threads in this object + ensures + - if (is_running() == false && should_stop() == false) then + - blocks until (#is_running() == true || #should_stop() == true) + - if (this thread is supposed to terminate) then + - returns true + - else + - returns false + !*/ + + private: + + // restricted functions + multithreaded_object(multithreaded_object&); // copy constructor + multithreaded_object& operator=(multithreaded_object&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MULTITHREADED_OBJECT_EXTENSIOn_ABSTRACT_ + diff --git a/ml/dlib/dlib/threads/parallel_for_extension.h b/ml/dlib/dlib/threads/parallel_for_extension.h new file mode 100644 index 000000000..60b64b1b4 --- /dev/null +++ b/ml/dlib/dlib/threads/parallel_for_extension.h @@ -0,0 +1,676 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_PARALLEL_FoR_Hh_ +#define DLIB_PARALLEL_FoR_Hh_ + +#include "parallel_for_extension_abstract.h" +#include "thread_pool_extension.h" +#include "../console_progress_indicator.h" +#include "async.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + template + class helper_parallel_for + { + public: + helper_parallel_for ( + T& obj_, + void (T::*funct_)(long) + ) : + obj(obj_), + funct(funct_) + {} + + T& obj; + void (T::*funct)(long); + + void process_block (long begin, long end) + { + for (long i = begin; i < end; ++i) + (obj.*funct)(i); + } + }; + + template + class helper_parallel_for_funct + { + public: + helper_parallel_for_funct ( + const T& funct_ + ) : funct(funct_) {} + + const T& funct; + + void run(long i) + { + funct(i); + } + }; + + template + class helper_parallel_for_funct2 + { + public: + helper_parallel_for_funct2 ( + const T& funct_ + ) : funct(funct_) {} + + const T& funct; + + void run(long begin, long end) + { + funct(begin, end); + } + }; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked ( + thread_pool& tp, + long begin, + long end, + T& obj, + void (T::*funct)(long, long), + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_blocked()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + if (tp.num_threads_in_pool() != 0) + { + const long num = end-begin; + const long num_workers = static_cast(tp.num_threads_in_pool()); + // How many samples to process in a single task (aim for chunks_per_thread jobs per worker) + const long block_size = std::max(1L, num/(num_workers*chunks_per_thread)); + for (long i = 0; i < num; i+=block_size) + { + tp.add_task(obj, funct, begin+i, begin+std::min(i+block_size, num)); + } + tp.wait_for_all_tasks(); + } + else + { + // Since there aren't any threads in the pool we might as well just invoke + // the function directly since that's all the thread_pool object would do. + // But doing it ourselves skips a mutex lock. + (obj.*funct)(begin, end); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked ( + unsigned long num_threads, + long begin, + long end, + T& obj, + void (T::*funct)(long, long), + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_blocked()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + thread_pool tp(num_threads); + parallel_for_blocked(tp, begin, end, obj, funct, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked ( + thread_pool& tp, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_blocked()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::helper_parallel_for_funct2 helper(funct); + parallel_for_blocked(tp, begin, end, helper, &impl::helper_parallel_for_funct2::run, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked ( + unsigned long num_threads, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_blocked()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + thread_pool tp(num_threads); + parallel_for_blocked(tp, begin, end, funct, chunks_per_thread); + } + + template + void parallel_for_blocked ( + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + parallel_for_blocked(default_thread_pool(), begin, end, funct, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + thread_pool& tp, + long begin, + long end, + T& obj, + void (T::*funct)(long), + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::helper_parallel_for helper(obj, funct); + parallel_for_blocked(tp, begin, end, helper, &impl::helper_parallel_for::process_block, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + unsigned long num_threads, + long begin, + long end, + T& obj, + void (T::*funct)(long), + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + thread_pool tp(num_threads); + parallel_for(tp, begin, end, obj, funct, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + thread_pool& tp, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::helper_parallel_for_funct helper(funct); + parallel_for(tp, begin, end, helper, &impl::helper_parallel_for_funct::run, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + unsigned long num_threads, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + thread_pool tp(num_threads); + parallel_for(tp, begin, end, funct, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + parallel_for(default_thread_pool(), begin, end, funct, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template + class parfor_verbose_helper + { + public: + parfor_verbose_helper(T& obj_, void (T::*funct_)(long), long begin, long end) : + obj(obj_), funct(funct_), pbar(end-begin) + { + count = 0; + wrote_to_screen = pbar.print_status(0); + } + + ~parfor_verbose_helper() + { + if (wrote_to_screen) + std::cout << std::endl; + } + + mutable long count; + T& obj; + void (T::*funct)(long); + mutable console_progress_indicator pbar; + mutable bool wrote_to_screen; + mutex m; + + void operator()(long i) const + { + (obj.*funct)(i); + { + auto_mutex lock(m); + wrote_to_screen = pbar.print_status(++count) || wrote_to_screen; + } + } + + }; + + template + class parfor_verbose_helper3 + { + public: + parfor_verbose_helper3(T& obj_, void (T::*funct_)(long,long), long begin, long end) : + obj(obj_), funct(funct_), pbar(end-begin) + { + count = 0; + wrote_to_screen = pbar.print_status(0); + } + + ~parfor_verbose_helper3() + { + if (wrote_to_screen) + std::cout << std::endl; + } + + mutable long count; + T& obj; + void (T::*funct)(long,long); + mutable console_progress_indicator pbar; + mutable bool wrote_to_screen; + mutex m; + + void operator()(long begin, long end) const + { + (obj.*funct)(begin, end); + { + auto_mutex lock(m); + count += end-begin; + wrote_to_screen = pbar.print_status(count) || wrote_to_screen; + } + } + }; + + template + class parfor_verbose_helper2 + { + public: + parfor_verbose_helper2(const T& obj_, long begin, long end) : obj(obj_), pbar(end-begin) + { + count = 0; + wrote_to_screen = pbar.print_status(0); + } + + ~parfor_verbose_helper2() + { + if (wrote_to_screen) + std::cout << std::endl; + } + + mutable long count; + const T& obj; + mutable console_progress_indicator pbar; + mutable bool wrote_to_screen; + mutex m; + + void operator()(long i) const + { + obj(i); + { + auto_mutex lock(m); + wrote_to_screen = pbar.print_status(++count) || wrote_to_screen; + } + } + + void operator()(long begin, long end) const + { + obj(begin, end); + { + auto_mutex lock(m); + count += end-begin; + wrote_to_screen = pbar.print_status(count) || wrote_to_screen; + } + } + }; + } + + template + void parallel_for_verbose ( + thread_pool& tp, + long begin, + long end, + T& obj, + void (T::*funct)(long), + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper helper(obj, funct, begin, end); + parallel_for(tp, begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_verbose ( + unsigned long num_threads, + long begin, + long end, + T& obj, + void (T::*funct)(long), + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper helper(obj, funct, begin, end); + parallel_for(num_threads, begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_verbose ( + thread_pool& tp, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper2 helper(funct, begin, end); + parallel_for(tp, begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_verbose ( + unsigned long num_threads, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper2 helper(funct, begin, end); + parallel_for(num_threads, begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_verbose ( + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper2 helper(funct, begin, end); + parallel_for(begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + thread_pool& tp, + long begin, + long end, + T& obj, + void (T::*funct)(long,long), + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_blocked_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper3 helper(obj, funct, begin, end); + parallel_for_blocked(tp, begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + unsigned long num_threads, + long begin, + long end, + T& obj, + void (T::*funct)(long,long), + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_blocked_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper3 helper(obj, funct, begin, end); + parallel_for_blocked(num_threads, begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + thread_pool& tp, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_blocked_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper2 helper(funct, begin, end); + parallel_for_blocked(tp, begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + unsigned long num_threads, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_blocked_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper2 helper(funct, begin, end); + parallel_for_blocked(num_threads, begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(begin <= end && chunks_per_thread > 0, + "\t void parallel_for_blocked_verbose()" + << "\n\t Invalid inputs were given to this function" + << "\n\t begin: " << begin + << "\n\t end: " << end + << "\n\t chunks_per_thread: " << chunks_per_thread + ); + + impl::parfor_verbose_helper2 helper(funct, begin, end); + parallel_for_blocked(begin, end, helper, chunks_per_thread); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_PARALLEL_FoR_Hh_ + diff --git a/ml/dlib/dlib/threads/parallel_for_extension_abstract.h b/ml/dlib/dlib/threads/parallel_for_extension_abstract.h new file mode 100644 index 000000000..ffd2e0c44 --- /dev/null +++ b/ml/dlib/dlib/threads/parallel_for_extension_abstract.h @@ -0,0 +1,469 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_PARALLEL_FoR_ABSTRACT_Hh_ +#ifdef DLIB_PARALLEL_FoR_ABSTRACT_Hh_ + +#include "thread_pool_extension_abstract.h" +#include "async_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked ( + thread_pool& tp, + long begin, + long end, + T& obj, + void (T::*funct)(long, long), + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This is a convenience function for submitting a block of jobs to a thread_pool. + In particular, given the half open range [begin, end), this function will + split the range into approximately tp.num_threads_in_pool()*chunks_per_thread + blocks, which it will then submit to the thread_pool. The given thread_pool + will then call (obj.*funct)() on each of the subranges. + - To be precise, suppose we have broken the range [begin, end) into the + following subranges: + - [begin[0], end[0]) + - [begin[1], end[1]) + - [begin[2], end[2]) + ... + - [begin[n], end[n]) + Then parallel_for_blocked() submits each of these subranges to tp for + processing such that (obj.*funct)(begin[i], end[i]) is invoked for all valid + values of i. Moreover, the subranges are non-overlapping and completely + cover the total range of [begin, end). + - This function will not perform any memory allocations or create any system + resources such as mutex objects. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked ( + unsigned long num_threads, + long begin, + long end, + T& obj, + void (T::*funct)(long, long), + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is equivalent to the following block of code: + thread_pool tp(num_threads); + parallel_for_blocked(tp, begin, end, obj, funct, chunks_per_thread); + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked ( + thread_pool& tp, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - chunks_per_thread > 0 + - begin <= end + ensures + - This is a convenience function for submitting a block of jobs to a + thread_pool. In particular, given the range [begin, end), this function will + split the range into approximately tp.num_threads_in_pool()*chunks_per_thread + blocks, which it will then submit to the thread_pool. The given thread_pool + will then call funct() on each of the subranges. + - To be precise, suppose we have broken the range [begin, end) into the + following subranges: + - [begin[0], end[0]) + - [begin[1], end[1]) + - [begin[2], end[2]) + ... + - [begin[n], end[n]) + Then parallel_for_blocked() submits each of these subranges to tp for + processing such that funct(begin[i], end[i]) is invoked for all valid values + of i. + - This function will not perform any memory allocations or create any system + resources such as mutex objects. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked ( + unsigned long num_threads, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is equivalent to the following block of code: + thread_pool tp(num_threads); + parallel_for_blocked(tp, begin, end, funct, chunks_per_thread); + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked ( + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is equivalent to the following block of code: + parallel_for_blocked(default_thread_pool(), begin, end, funct, chunks_per_thread); + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + thread_pool& tp, + long begin, + long end, + T& obj, + void (T::*funct)(long), + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is equivalent to the following function call: + parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub) + { + for (long i = begin_sub; i < end_sub; ++i) + (obj.*funct)(i); + }, chunks_per_thread); + - Therefore, this routine invokes (obj.*funct)(i) for all i in the range + [begin, end). However, it does so using tp.num_threads_in_pool() parallel + threads. + - This function will not perform any memory allocations or create any system + resources such as mutex objects. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + unsigned long num_threads, + long begin, + long end, + T& obj, + void (T::*funct)(long), + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is equivalent to the following block of code: + thread_pool tp(num_threads); + parallel_for(tp, begin, end, obj, funct, chunks_per_thread); + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + thread_pool& tp, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is equivalent to the following function call: + parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub) + { + for (long i = begin_sub; i < end_sub; ++i) + funct(i); + }, chunks_per_thread); + - Therefore, this routine invokes funct(i) for all i in the range [begin, end). + However, it does so using tp.num_threads_in_pool() parallel threads. + - This function will not perform any memory allocations or create any system + resources such as mutex objects. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + unsigned long num_threads, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is equivalent to the following block of code: + thread_pool tp(num_threads); + parallel_for(tp, begin, end, funct, chunks_per_thread); + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for ( + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is equivalent to the following block of code: + parallel_for(default_thread_pool(), begin, end, funct, chunks_per_thread); + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_verbose ( + thread_pool& tp, + long begin, + long end, + T& obj, + void (T::*funct)(long), + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for() routine defined above except + that it will print messages to cout showing the progress in executing the + parallel for loop. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_verbose ( + unsigned long num_threads, + long begin, + long end, + T& obj, + void (T::*funct)(long), + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for() routine defined above except + that it will print messages to cout showing the progress in executing the + parallel for loop. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_verbose ( + thread_pool& tp, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for() routine defined above except + that it will print messages to cout showing the progress in executing the + parallel for loop. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_verbose ( + unsigned long num_threads, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for() routine defined above except + that it will print messages to cout showing the progress in executing the + parallel for loop. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_verbose ( + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for() routine defined above except + that it will print messages to cout showing the progress in executing the + parallel for loop. + - It will also use the default_thread_pool(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + thread_pool& tp, + long begin, + long end, + T& obj, + void (T::*funct)(long,long), + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for_blocked() routine defined + above except that it will print messages to cout showing the progress in + executing the parallel for loop. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + unsigned long num_threads, + long begin, + long end, + T& obj, + void (T::*funct)(long,long), + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for_blocked() routine defined + above except that it will print messages to cout showing the progress in + executing the parallel for loop. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + thread_pool& tp, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for_blocked() routine defined + above except that it will print messages to cout showing the progress in + executing the parallel for loop. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + unsigned long num_threads, + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for_blocked() routine defined + above except that it will print messages to cout showing the progress in + executing the parallel for loop. + !*/ + +// ---------------------------------------------------------------------------------------- + + template + void parallel_for_blocked_verbose ( + long begin, + long end, + const T& funct, + long chunks_per_thread = 8 + ); + /*! + requires + - begin <= end + - chunks_per_thread > 0 + ensures + - This function is identical to the parallel_for_blocked() routine defined + above except that it will print messages to cout showing the progress in + executing the parallel for loop. + - It will also use the default_thread_pool() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_PARALLEL_FoR_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/threads/posix.h b/ml/dlib/dlib/threads/posix.h new file mode 100644 index 000000000..7226743e1 --- /dev/null +++ b/ml/dlib/dlib/threads/posix.h @@ -0,0 +1,6 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADS_KERNEl_1_ +#include "threads_kernel_2.h" +#endif + diff --git a/ml/dlib/dlib/threads/read_write_mutex_extension.h b/ml/dlib/dlib/threads/read_write_mutex_extension.h new file mode 100644 index 000000000..20e5d5ed8 --- /dev/null +++ b/ml/dlib/dlib/threads/read_write_mutex_extension.h @@ -0,0 +1,177 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_READ_WRITE_MUTEX_EXTENSIOn_ +#define DLIB_READ_WRITE_MUTEX_EXTENSIOn_ + +#include "threads_kernel.h" +#include "read_write_mutex_extension_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class read_write_mutex + { + /*! + INITIAL VALUE + - max_locks == defined by constructor + - available_locks == max_locks + - write_lock_in_progress == false + - write_lock_active == false + + CONVENTION + - Each time someone gets a read only lock they take one of the "available locks" + and each write lock takes all possible locks (i.e. max_locks). The number of + available locks is recorded in available_locks. Any time you try to lock this + object and there aren't available locks you have to wait. + + - max_locks == max_readonly_locks() + + - if (some thread is on the process of obtaining a write lock) then + - write_lock_in_progress == true + - else + - write_lock_in_progress == false + + - if (some thread currently has a write lock on this mutex) then + - write_lock_active == true + - else + - write_lock_active == false + !*/ + + public: + + read_write_mutex ( + ) : s(m), + max_locks(0xFFFFFFFF), + available_locks(max_locks), + write_lock_in_progress(false), + write_lock_active(false) + {} + + explicit read_write_mutex ( + unsigned long max_locks_ + ) : s(m), + max_locks(max_locks_), + available_locks(max_locks_), + write_lock_in_progress(false), + write_lock_active(false) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_locks > 0, + "\t read_write_mutex::read_write_mutex(max_locks)" + << "\n\t You must give a non-zero value for max_locks" + << "\n\t this: " << this + ); + } + + ~read_write_mutex ( + ) + {} + + void lock ( + ) const + { + m.lock(); + + // If another write lock is already in progress then wait for it to finish + // before we start trying to grab all the available locks. This way we + // don't end up fighting over the locks. + while (write_lock_in_progress) + s.wait(); + + // grab the right to perform a write lock + write_lock_in_progress = true; + + // now start grabbing all the locks + unsigned long locks_obtained = available_locks; + available_locks = 0; + while (locks_obtained != max_locks) + { + s.wait(); + locks_obtained += available_locks; + available_locks = 0; + } + + write_lock_in_progress = false; + write_lock_active = true; + + m.unlock(); + } + + void unlock ( + ) const + { + m.lock(); + + // only do something if there really was a lock in place + if (write_lock_active) + { + available_locks = max_locks; + write_lock_active = false; + s.broadcast(); + } + + m.unlock(); + } + + void lock_readonly ( + ) const + { + m.lock(); + + while (available_locks == 0) + s.wait(); + + --available_locks; + + m.unlock(); + } + + void unlock_readonly ( + ) const + { + m.lock(); + + // If this condition is false then it means there are no more readonly locks + // to free. So we don't do anything. + if (available_locks != max_locks && !write_lock_active) + { + ++available_locks; + + // only perform broadcast when there is another thread that might be listening + if (available_locks == 1 || write_lock_in_progress) + { + s.broadcast(); + } + } + + m.unlock(); + } + + unsigned long max_readonly_locks ( + ) const + { + return max_locks; + } + + private: + mutex m; + signaler s; + const unsigned long max_locks; + mutable unsigned long available_locks; + mutable bool write_lock_in_progress; + mutable bool write_lock_active; + + // restricted functions + read_write_mutex(read_write_mutex&); // copy constructor + read_write_mutex& operator=(read_write_mutex&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_READ_WRITE_MUTEX_EXTENSIOn_ + + diff --git a/ml/dlib/dlib/threads/read_write_mutex_extension_abstract.h b/ml/dlib/dlib/threads/read_write_mutex_extension_abstract.h new file mode 100644 index 000000000..18672b057 --- /dev/null +++ b/ml/dlib/dlib/threads/read_write_mutex_extension_abstract.h @@ -0,0 +1,146 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_READWRITE_MUTEX_EXTENSIOn_ABSTRACT_ +#ifdef DLIB_READWRITE_MUTEX_EXTENSIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class read_write_mutex + { + /*! + INITIAL VALUE + read_write_mutex is in the fully unlocked state + + WHAT THIS OBJECT REPRESENTS + This object represents a mutex intended to be used for synchronous + thread control of shared data. When a thread wants to access some + shared data it locks out other threads by calling lock() and calls + unlock() when it is finished. + + This mutex also has the additional ability to distinguish between + a lock for the purposes of modifying some shared data, a write lock, + and a lock for the purposes of only reading shared data, a readonly + lock. The lock() and unlock() functions are used for write locks while + the lock_readonly() and unlock_readonly() are for readonly locks. + + The difference between a readonly and write lock can be understood as + follows. The read_write_mutex will allow many threads to obtain simultaneous + readonly locks but will only allow a single thread to obtain a write lock. + Moreover, while the write lock is obtained no other threads are allowed + to have readonly locks. + !*/ + public: + + read_write_mutex ( + ); + /*! + ensures + - #*this is properly initialized + - max_readonly_locks() == 0xFFFFFFFF + (i.e. about 4 billion) + throws + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create the read_write_mutex. + !*/ + + explicit read_write_mutex ( + unsigned long max_locks + ); + /*! + requires + - max_locks > 0 + ensures + - #*this is properly initialized + - max_readonly_locks() == max_locks + throws + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create the read_write_mutex. + !*/ + + ~read_write_mutex ( + ); + /*! + requires + - *this is not locked + ensures + - all resources allocated by *this have been freed + !*/ + + void lock ( + ) const; + /*! + requires + - The thread calling this function does not have any kind of lock on this + object + ensures + - if (there is any kind of lock on *this) then + - the calling thread is put to sleep until a write lock becomes available. + Once available, a write lock is obtained on this mutex and this function + terminates. + - else + - a write lock is obtained on this mutex and the calling thread is not put to sleep + !*/ + + void unlock ( + ) const; + /*! + ensures + - if (there is a write lock on *this) then + - #*this is unlocked (i.e. other threads may now lock this object) + - else + - the call to unlock() has no effect + !*/ + + unsigned long max_readonly_locks ( + ) const; + /*! + ensures + - returns the maximum number of concurrent readonly locks this object will allow. + !*/ + + void lock_readonly ( + ) const; + /*! + requires + - The thread calling this function does not already have a write + lock on this object + ensures + - if (there is a write lock on *this or there are no free readonly locks) then + - the calling thread is put to sleep until there is no longer a write lock + and a free readonly lock is available. Once this is the case, a readonly + lock is obtained and this function terminates. + - else + - a readonly lock is obtained on *this and the calling thread is not put + to sleep. Note that multiple readonly locks can be obtained at once. + !*/ + + void unlock_readonly ( + ) const; + /*! + ensures + - if (there is a readonly lock on *this) then + - one readonly lock is removed from *this. + - else + - the call to unlock_readonly() has no effect. + !*/ + + private: + // restricted functions + read_write_mutex(read_write_mutex&); // copy constructor + read_write_mutex& operator=(read_write_mutex&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_READWRITE_MUTEX_EXTENSIOn_ABSTRACT_ + + diff --git a/ml/dlib/dlib/threads/rmutex_extension.h b/ml/dlib/dlib/threads/rmutex_extension.h new file mode 100644 index 000000000..b7bf998be --- /dev/null +++ b/ml/dlib/dlib/threads/rmutex_extension.h @@ -0,0 +1,109 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RMUTEX_EXTENSIOn_ +#define DLIB_RMUTEX_EXTENSIOn_ + +#include "threads_kernel.h" +#include "rmutex_extension_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rmutex + { + /*! + INITIAL VALUE + count == 0 + thread_id == 0 + + CONVENTION + - count == lock_count() + + - if (no thread currently has a lock on this mutex) then + - count == 0 + - else + - count == the number of times the thread that owns this mutex has + called lock() + - thread_id == the id of this thread. + !*/ + public: + + rmutex ( + ) : s(m), + thread_id(0), + count(0) + {} + + ~rmutex ( + ) + {} + + unsigned long lock_count ( + ) const + { + return count; + } + + void lock ( + unsigned long times = 1 + ) const + { + const thread_id_type current_thread_id = get_thread_id(); + m.lock(); + if (thread_id == current_thread_id) + { + // we already own this mutex in this case + count += times; + } + else + { + // wait for our turn to claim this rmutex + while (count != 0) + s.wait(); + + count = times; + thread_id = current_thread_id; + } + m.unlock(); + } + + void unlock ( + unsigned long times = 1 + ) const + { + const thread_id_type current_thread_id = get_thread_id(); + m.lock(); + if (thread_id == current_thread_id) + { + if (count <= times) + { + count = 0; + s.signal(); + } + else + { + count -= times; + } + } + m.unlock(); + } + + private: + mutex m; + signaler s; + mutable thread_id_type thread_id; + mutable unsigned long count; + + // restricted functions + rmutex(rmutex&); // copy constructor + rmutex& operator=(rmutex&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RMUTEX_EXTENSIOn_ + diff --git a/ml/dlib/dlib/threads/rmutex_extension_abstract.h b/ml/dlib/dlib/threads/rmutex_extension_abstract.h new file mode 100644 index 000000000..144dbf4d7 --- /dev/null +++ b/ml/dlib/dlib/threads/rmutex_extension_abstract.h @@ -0,0 +1,107 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RMUTEX_EXTENSIOn_ABSTRACT_ +#ifdef DLIB_RMUTEX_EXTENSIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rmutex + { + /*! + INITIAL VALUE + rmutex is in the unlocked state + + WHAT THIS OBJECT REPRESENTS + This object represents a recursive mutex intended to be used for synchronous + thread control of shared data. When a thread wants to access some + shared data it locks out other threads by calling lock() and calls + unlock() when it is finished. + + The difference between this and the normal mutex object is that it is safe to + call lock() from a thread that already has a lock on this mutex. Doing + so just increments a counter but otherwise has no effect on the mutex. + Note that unlock() must be called for each call to lock() to release the + mutex. + !*/ + public: + + rmutex ( + ); + /*! + ensures + - #*this is properly initialized + throws + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create the rmutex. + !*/ + + ~rmutex ( + ); + /*! + requires + - *this is not locked + ensures + - all resources allocated by *this have been freed + !*/ + + unsigned long lock_count ( + ) const; + /*! + requires + - the calling thread has a lock on this mutex + ensures + - returns the number of times the thread has called lock() + !*/ + + void lock ( + unsigned long times = 1 + ) const; + /*! + ensures + - if (*this is currently locked by another thread) then + - the thread that called lock() on *this is put to sleep until + it becomes available. + - #lock_count() == times + - if (*this is currently unlocked) then + - #*this becomes locked and the current thread is NOT put to sleep + but now "owns" #*this + - #lock_count() == times + - if (*this is locked and owned by the current thread) then + - the calling thread retains its lock on *this and isn't put to sleep. + - #lock_count() == lock_count() + times + !*/ + + void unlock ( + unsigned long times = 1 + ) const; + /*! + ensures + - if (*this is currently locked and owned by the thread calling unlock) then + - if (lock_count() <= times ) then + - #*this is unlocked (i.e. other threads may now lock this object) + - else + - #*this will remain locked + - #lock_count() == lock_count() - times + - else + - the call to unlock() has no effect + !*/ + + + private: + // restricted functions + rmutex(rmutex&); // copy constructor + rmutex& operator=(rmutex&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RMUTEX_EXTENSIOn_ABSTRACT_ + diff --git a/ml/dlib/dlib/threads/rsignaler_extension.h b/ml/dlib/dlib/threads/rsignaler_extension.h new file mode 100644 index 000000000..bfb5a7ecb --- /dev/null +++ b/ml/dlib/dlib/threads/rsignaler_extension.h @@ -0,0 +1,90 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RSIGNALER_EXTENSIOn_ +#define DLIB_RSIGNALER_EXTENSIOn_ + +#include "rsignaler_extension_abstract.h" +#include "rmutex_extension.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rsignaler + { + public: + rsignaler ( + const rmutex& associated_mutex + ) : + assoc_mutex(associated_mutex), + s(m) + {} + + ~rsignaler ( + ) + {} + + void wait ( + ) const + { + m.lock(); + const unsigned long lock_count = assoc_mutex.lock_count(); + assoc_mutex.unlock(lock_count); + s.wait(); + m.unlock(); + assoc_mutex.lock(lock_count); + } + + bool wait_or_timeout ( + unsigned long milliseconds + ) const + { + m.lock(); + const unsigned long lock_count = assoc_mutex.lock_count(); + assoc_mutex.unlock(lock_count); + bool res = s.wait_or_timeout(milliseconds); + m.unlock(); + assoc_mutex.lock(lock_count); + return res; + } + + void signal ( + ) const + { + m.lock(); + s.signal(); + m.unlock(); + } + + void broadcast ( + ) const + { + m.lock(); + s.broadcast(); + m.unlock(); + } + + const rmutex& get_mutex ( + ) const { return assoc_mutex; } + + private: + + const rmutex& assoc_mutex; + mutex m; + signaler s; + + + // restricted functions + rsignaler(rsignaler&); // copy constructor + rsignaler& operator=(rsignaler&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RSIGNALER_EXTENSIOn_ + + + diff --git a/ml/dlib/dlib/threads/rsignaler_extension_abstract.h b/ml/dlib/dlib/threads/rsignaler_extension_abstract.h new file mode 100644 index 000000000..ae5f450d7 --- /dev/null +++ b/ml/dlib/dlib/threads/rsignaler_extension_abstract.h @@ -0,0 +1,123 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RSIGNALER_EXTENSIOn_ABSTRACT_ +#ifdef DLIB_RSIGNALER_EXTENSIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" +#include "rmutex_extension_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rsignaler + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an event signaling system for threads. It gives + a thread the ability to wake up other threads that are waiting for a + particular signal. + + Each rsignaler object is associated with one and only one rmutex object. + More than one rsignaler object may be associated with a single rmutex + but a signaler object may only be associated with a single rmutex. + + NOTE: + You must guard against spurious wakeups. This means that a thread + might return from a call to wait even if no other thread called + signal. This is rare but must be guarded against. + + Also note that this object is identical to the signaler object + except that it works with rmutex objects rather than mutex objects. + !*/ + + public: + + rsignaler ( + const rmutex& associated_mutex + ); + /*! + ensures + - #*this is properly initialized + - #get_mutex() == associated_mutex + throws + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create the signaler. + !*/ + + + ~rsignaler ( + ); + /*! + ensures + - all resources allocated by *this have been freed + !*/ + + void wait ( + ) const; + /*! + requires + - get_mutex() is locked and owned by the calling thread + ensures + - atomically unlocks get_mutex() and blocks the calling thread + - calling thread may wake if another thread calls signal() or broadcast() + on *this + - when wait() returns the calling thread again has a lock on get_mutex() + !*/ + + bool wait_or_timeout ( + unsigned long milliseconds + ) const; + /*! + requires + - get_mutex() is locked and owned by the calling thread + ensures + - atomically unlocks get_mutex() and blocks the calling thread + - calling thread may wake if another thread calls signal() or broadcast() + on *this + - after the specified number of milliseconds has elapsed the calling thread + will wake once get_mutex() is free + - when wait returns the calling thread again has a lock on get_mutex() + + - returns false if the call to wait_or_timeout timed out + - returns true if the call did not time out + !*/ + + void signal ( + ) const; + /*! + ensures + - if (at least one thread is waiting on *this) then + - at least one of the waiting threads will wake + !*/ + + void broadcast ( + ) const; + /*! + ensures + - any and all threads waiting on *this will wake + !*/ + + const rmutex& get_mutex ( + ) const; + /*! + ensures + - returns a const reference to the rmutex associated with *this + !*/ + + + private: + // restricted functions + rsignaler(rsignaler&); // copy constructor + rsignaler& operator=(rsignaler&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RSIGNALER_EXTENSIOn_ABSTRACT_ + + diff --git a/ml/dlib/dlib/threads/thread_function_extension.h b/ml/dlib/dlib/threads/thread_function_extension.h new file mode 100644 index 000000000..7ecdd6520 --- /dev/null +++ b/ml/dlib/dlib/threads/thread_function_extension.h @@ -0,0 +1,215 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREAD_FUNCTIOn_ +#define DLIB_THREAD_FUNCTIOn_ + +#include + +#include "thread_function_extension_abstract.h" +#include "threads_kernel.h" +#include "auto_mutex_extension.h" +#include "threaded_object_extension.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class thread_function : private threaded_object + { + + class base_funct + { + public: + virtual void go() = 0; + virtual ~base_funct() {} + }; + + template + class super_funct_4 : public base_funct + { + public: + super_funct_4 ( F funct, T1 arg1, T2 arg2, T3 arg3, T4 arg4) : + f(funct), + a1(arg1), + a2(arg2), + a3(arg3), + a4(arg4) + { + } + + void go() { f(a1, a2, a3, a4); } + + + F f; + T1 a1; + T2 a2; + T3 a3; + T4 a4; + }; + + template + class super_funct_3 : public base_funct + { + public: + super_funct_3 ( F funct, T1 arg1, T2 arg2, T3 arg3): + f(funct), + a1(arg1), + a2(arg2), + a3(arg3) + { + } + + void go() { f(a1, a2, a3); } + + + F f; + T1 a1; + T2 a2; + T3 a3; + }; + + template + class super_funct_2 : public base_funct + { + public: + super_funct_2 ( F funct, T1 arg1, T2 arg2) : + f(funct), + a1(arg1), + a2(arg2) + { + } + + void go() { f(a1, a2); } + + + F f; + T1 a1; + T2 a2; + }; + + template + class super_funct_1 : public base_funct + { + public: + super_funct_1 ( F funct, T arg) : f(funct), a(arg) + { + } + + void go() { f(a); } + + + F f; + T a; + }; + + template + class super_funct_0 : public base_funct + { + public: + super_funct_0 ( F funct) : f(funct) + { + } + + void go() { f(); } + + F f; + }; + + public: + + template + thread_function ( + F funct + ) + { + f.reset(new super_funct_0(funct)); + start(); + } + + template + thread_function ( + F funct, + T arg + ) + { + f.reset(new super_funct_1(funct,arg)); + start(); + } + + template + thread_function ( + F funct, + T1 arg1, + T2 arg2 + ) + { + f.reset(new super_funct_2(funct, arg1, arg2)); + start(); + } + + template + thread_function ( + F funct, + T1 arg1, + T2 arg2, + T3 arg3 + ) + { + f.reset(new super_funct_3(funct, arg1, arg2, arg3)); + start(); + } + + template + thread_function ( + F funct, + T1 arg1, + T2 arg2, + T3 arg3, + T4 arg4 + ) + { + f.reset(new super_funct_4(funct, arg1, arg2, arg3, arg4)); + start(); + } + + ~thread_function ( + ) + { + threaded_object::wait(); + } + + bool is_alive ( + ) const + { + return threaded_object::is_alive(); + } + + void wait ( + ) const + { + threaded_object::wait(); + } + + private: + + void thread () + { + f->go(); + } + + std::unique_ptr f; + + // restricted functions + thread_function(thread_function&); // copy constructor + thread_function& operator=(thread_function&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THREAD_FUNCTIOn_ + + + diff --git a/ml/dlib/dlib/threads/thread_function_extension_abstract.h b/ml/dlib/dlib/threads/thread_function_extension_abstract.h new file mode 100644 index 000000000..65ea998ac --- /dev/null +++ b/ml/dlib/dlib/threads/thread_function_extension_abstract.h @@ -0,0 +1,146 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_THREAD_FUNCTIOn_ABSTRACT_ +#ifdef DLIB_THREAD_FUNCTIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class thread_function + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a thread on a global C++ function or function + object. That is, it allows you to run a function in its own thread. + !*/ + public: + + template + thread_function ( + F funct + ); + /*! + ensures + - #*this is properly initialized + - the function funct has been started in its own thread + throws + - std::bad_alloc + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create threading objects. + !*/ + + template + thread_function ( + F funct, + T1 arg1 + ); + /*! + ensures + - #*this is properly initialized + - A thread has been created and it will call funct(arg1) + throws + - std::bad_alloc + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create threading objects. + !*/ + + template + thread_function ( + F funct, + T1 arg1, + T2 arg2 + ); + /*! + ensures + - #*this is properly initialized + - A thread has been created and it will call funct(arg1, arg2) + throws + - std::bad_alloc + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create threading objects. + !*/ + + template + thread_function ( + F funct, + T1 arg1, + T2 arg2, + T3 arg3 + ); + /*! + ensures + - #*this is properly initialized + - A thread has been created and it will call funct(arg1, arg2, arg3) + throws + - std::bad_alloc + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create threading objects. + !*/ + + template + thread_function ( + F funct, + T1 arg1, + T2 arg2, + T3 arg3, + T4 arg4 + ); + /*! + ensures + - #*this is properly initialized + - A thread has been created and it will call funct(arg1, arg2, arg3, arg4) + throws + - std::bad_alloc + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create threading objects. + !*/ + + ~thread_function ( + ); + /*! + ensures + - all resources allocated by *this have been freed. + - blocks until is_alive() == false + !*/ + + bool is_alive ( + ) const; + /*! + ensures + - if (this object's thread has yet to terminate) then + - returns true + - else + - returns false + !*/ + + void wait ( + ) const; + /*! + ensures + - if (is_alive() == true) then + - blocks until this object's thread terminates + !*/ + + private: + + // restricted functions + thread_function(thread_function&); // copy constructor + thread_function& operator=(thread_function&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THREAD_FUNCTIOn_ABSTRACT_ + + diff --git a/ml/dlib/dlib/threads/thread_pool_extension.cpp b/ml/dlib/dlib/threads/thread_pool_extension.cpp new file mode 100644 index 000000000..00d99b910 --- /dev/null +++ b/ml/dlib/dlib/threads/thread_pool_extension.cpp @@ -0,0 +1,347 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREAD_POOl_CPPh_ +#define DLIB_THREAD_POOl_CPPh_ + +#include "thread_pool_extension.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + thread_pool_implementation:: + thread_pool_implementation ( + unsigned long num_threads + ) : + task_done_signaler(m), + task_ready_signaler(m), + we_are_destructing(false) + { + tasks.resize(num_threads); + threads.resize(num_threads); + for (unsigned long i = 0; i < num_threads; ++i) + { + threads[i] = std::thread([&](){this->thread();}); + } + } + +// ---------------------------------------------------------------------------------------- + + void thread_pool_implementation:: + shutdown_pool ( + ) + { + { + auto_mutex M(m); + + // first wait for all pending tasks to finish + bool found_task = true; + while (found_task) + { + found_task = false; + for (unsigned long i = 0; i < tasks.size(); ++i) + { + // If task bucket i has a task that is currently supposed to be processed + if (tasks[i].is_empty() == false) + { + found_task = true; + break; + } + } + + if (found_task) + task_done_signaler.wait(); + } + + // now tell the threads to kill themselves + we_are_destructing = true; + task_ready_signaler.broadcast(); + } + + // wait for all threads to terminate + for (auto& t : threads) + t.join(); + threads.clear(); + + // Throw any unhandled exceptions. Since shutdown_pool() is only called in the + // destructor this will kill the program. + for (auto&& task : tasks) + task.propagate_exception(); + } + +// ---------------------------------------------------------------------------------------- + + thread_pool_implementation:: + ~thread_pool_implementation() + { + shutdown_pool(); + } + +// ---------------------------------------------------------------------------------------- + + unsigned long thread_pool_implementation:: + num_threads_in_pool ( + ) const + { + auto_mutex M(m); + return tasks.size(); + } + +// ---------------------------------------------------------------------------------------- + + void thread_pool_implementation:: + wait_for_task ( + uint64 task_id + ) const + { + auto_mutex M(m); + if (tasks.size() != 0) + { + const unsigned long idx = task_id_to_index(task_id); + while (tasks[idx].task_id == task_id) + task_done_signaler.wait(); + + for (auto&& task : tasks) + task.propagate_exception(); + } + } + +// ---------------------------------------------------------------------------------------- + + void thread_pool_implementation:: + wait_for_all_tasks ( + ) const + { + const thread_id_type thread_id = get_thread_id(); + + auto_mutex M(m); + bool found_task = true; + while (found_task) + { + found_task = false; + for (unsigned long i = 0; i < tasks.size(); ++i) + { + // If task bucket i has a task that is currently supposed to be processed + // and it originated from the calling thread + if (tasks[i].is_empty() == false && tasks[i].thread_id == thread_id) + { + found_task = true; + break; + } + } + + if (found_task) + task_done_signaler.wait(); + } + + // throw any exceptions generated by the tasks + for (auto&& task : tasks) + task.propagate_exception(); + } + +// ---------------------------------------------------------------------------------------- + + bool thread_pool_implementation:: + is_worker_thread ( + const thread_id_type id + ) const + { + for (unsigned long i = 0; i < worker_thread_ids.size(); ++i) + { + if (worker_thread_ids[i] == id) + return true; + } + + // if there aren't any threads in the pool then we consider all threads + // to be worker threads + if (tasks.size() == 0) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + void thread_pool_implementation:: + thread ( + ) + { + { + // save the id of this worker thread into worker_thread_ids + auto_mutex M(m); + thread_id_type id = get_thread_id(); + worker_thread_ids.push_back(id); + } + + task_state_type task; + while (we_are_destructing == false) + { + long idx = 0; + + // wait for a task to do + { auto_mutex M(m); + while ( (idx = find_ready_task()) == -1 && we_are_destructing == false) + task_ready_signaler.wait(); + + if (we_are_destructing) + break; + + tasks[idx].is_being_processed = true; + task = tasks[idx]; + } + + std::exception_ptr eptr = nullptr; + try + { + // now do the task + if (task.bfp) + task.bfp(); + else if (task.mfp0) + task.mfp0(); + else if (task.mfp1) + task.mfp1(task.arg1); + else if (task.mfp2) + task.mfp2(task.arg1, task.arg2); + } + catch(...) + { + eptr = std::current_exception(); + } + + // Now let others know that we finished the task. We do this + // by clearing out the state of this task + { auto_mutex M(m); + tasks[idx].is_being_processed = false; + tasks[idx].task_id = 0; + tasks[idx].bfp.clear(); + tasks[idx].mfp0.clear(); + tasks[idx].mfp1.clear(); + tasks[idx].mfp2.clear(); + tasks[idx].arg1 = 0; + tasks[idx].arg2 = 0; + tasks[idx].eptr = eptr; + task_done_signaler.broadcast(); + } + + } + } + +// ---------------------------------------------------------------------------------------- + + long thread_pool_implementation:: + find_empty_task_slot ( + ) const + { + for (auto&& task : tasks) + task.propagate_exception(); + + for (unsigned long i = 0; i < tasks.size(); ++i) + { + if (tasks[i].is_empty()) + return i; + } + + return -1; + } + +// ---------------------------------------------------------------------------------------- + + long thread_pool_implementation:: + find_ready_task ( + ) const + { + for (unsigned long i = 0; i < tasks.size(); ++i) + { + if (tasks[i].is_ready()) + return i; + } + + return -1; + } + +// ---------------------------------------------------------------------------------------- + + uint64 thread_pool_implementation:: + make_next_task_id ( + long idx + ) + { + uint64 id = tasks[idx].next_task_id * tasks.size() + idx; + tasks[idx].next_task_id += 1; + return id; + } + +// ---------------------------------------------------------------------------------------- + + unsigned long thread_pool_implementation:: + task_id_to_index ( + uint64 id + ) const + { + return static_cast(id%tasks.size()); + } + +// ---------------------------------------------------------------------------------------- + + uint64 thread_pool_implementation:: + add_task_internal ( + const bfp_type& bfp, + std::shared_ptr& item + ) + { + auto_mutex M(m); + const thread_id_type my_thread_id = get_thread_id(); + + // find a thread that isn't doing anything + long idx = find_empty_task_slot(); + if (idx == -1 && is_worker_thread(my_thread_id)) + { + // this function is being called from within a worker thread and there + // aren't any other worker threads free so just perform the task right + // here + + M.unlock(); + bfp(); + + // return a task id that is both non-zero and also one + // that is never normally returned. This way calls + // to wait_for_task() will never block given this id. + return 1; + } + + // wait until there is a thread that isn't doing anything + while (idx == -1) + { + task_done_signaler.wait(); + idx = find_empty_task_slot(); + } + + tasks[idx].thread_id = my_thread_id; + tasks[idx].task_id = make_next_task_id(idx); + tasks[idx].bfp = bfp; + tasks[idx].function_copy.swap(item); + + task_ready_signaler.signal(); + + return tasks[idx].task_id; + } + +// ---------------------------------------------------------------------------------------- + + bool thread_pool_implementation:: + is_task_thread ( + ) const + { + auto_mutex M(m); + return is_worker_thread(get_thread_id()); + } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_THREAD_POOl_CPPh_ + diff --git a/ml/dlib/dlib/threads/thread_pool_extension.h b/ml/dlib/dlib/threads/thread_pool_extension.h new file mode 100644 index 000000000..bc2e1782c --- /dev/null +++ b/ml/dlib/dlib/threads/thread_pool_extension.h @@ -0,0 +1,1392 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREAD_POOl_Hh_ +#define DLIB_THREAD_POOl_Hh_ + +#include +#include +#include + +#include "thread_pool_extension_abstract.h" +#include "multithreaded_object_extension.h" +#include "../member_function_pointer.h" +#include "../bound_function_pointer.h" +#include "threads_kernel.h" +#include "auto_mutex_extension.h" +#include "../uintn.h" +#include "../array.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class thread_pool_implementation; + + template < + typename T + > + class future + { + /*! + INITIAL VALUE + - task_id == 0 + - tp.get() == 0 + + CONVENTION + - is_ready() == (tp.get() == 0) + - get() == var + + - if (tp.get() != 0) + - tp == a pointer to the thread_pool_implementation that is using this future object + - task_id == the task id of the task in the thread pool tp that is using + this future object. + !*/ + public: + + future ( + ) : task_id(0) {} + + future ( + const T& item + ) : task_id(0), var(item) {} + + future ( + const future& item + ) :task_id(0), var(item.get()) {} + + ~future ( + ) { wait(); } + + future& operator=( + const T& item + ) { get() = item; return *this; } + + future& operator=( + const future& item + ) { get() = item.get(); return *this; } + + operator T& ( + ) { return get(); } + + operator const T& ( + ) const { return get(); } + + T& get ( + ) { wait(); return var; } + + const T& get ( + ) const { wait(); return var; } + + bool is_ready ( + ) const { return tp.get() == 0; } + + private: + + friend class thread_pool; + + inline void wait () const; + + mutable uint64 task_id; + mutable std::shared_ptr tp; + + T var; + }; + +// ---------------------------------------------------------------------------------------- + + template + inline void swap ( + future& a, + future& b + ) { dlib::exchange(a.get(), b.get()); } + // Note that dlib::exchange() just calls std::swap. I'm only using it because + // this works around some bugs in certain compilers. + +// ---------------------------------------------------------------------------------------- + + template bool operator== (const future& a, const future& b) { return a.get() == b.get(); } + template bool operator!= (const future& a, const future& b) { return a.get() != b.get(); } + template bool operator<= (const future& a, const future& b) { return a.get() <= b.get(); } + template bool operator>= (const future& a, const future& b) { return a.get() >= b.get(); } + template bool operator< (const future& a, const future& b) { return a.get() < b.get(); } + template bool operator> (const future& a, const future& b) { return a.get() > b.get(); } + + template bool operator== (const future& a, const T& b) { return a.get() == b; } + template bool operator== (const T& a, const future& b) { return a == b.get(); } + template bool operator!= (const future& a, const T& b) { return a.get() != b; } + template bool operator!= (const T& a, const future& b) { return a != b.get(); } + template bool operator<= (const future& a, const T& b) { return a.get() <= b; } + template bool operator<= (const T& a, const future& b) { return a <= b.get(); } + template bool operator>= (const future& a, const T& b) { return a.get() >= b; } + template bool operator>= (const T& a, const future& b) { return a >= b.get(); } + template bool operator< (const future& a, const T& b) { return a.get() < b; } + template bool operator< (const T& a, const future& b) { return a < b.get(); } + template bool operator> (const future& a, const T& b) { return a.get() > b; } + template bool operator> (const T& a, const future& b) { return a > b.get(); } + +// ---------------------------------------------------------------------------------------- + + class thread_pool_implementation + { + /*! + CONVENTION + - num_threads_in_pool() == tasks.size() + - if (the destructor has been called) then + - we_are_destructing == true + - else + - we_are_destructing == false + + - is_task_thread() == is_worker_thread(get_thread_id()) + + - m == the mutex used to protect everything in this object + - worker_thread_ids == an array that contains the thread ids for + all the threads in the thread pool + !*/ + typedef bound_function_pointer::kernel_1a_c bfp_type; + + friend class thread_pool; + explicit thread_pool_implementation ( + unsigned long num_threads + ); + + public: + ~thread_pool_implementation( + ); + + void wait_for_task ( + uint64 task_id + ) const; + + unsigned long num_threads_in_pool ( + ) const; + + void wait_for_all_tasks ( + ) const; + + bool is_task_thread ( + ) const; + + template + uint64 add_task ( + T& obj, + void (T::*funct)() + ) + { + auto_mutex M(m); + const thread_id_type my_thread_id = get_thread_id(); + + // find a thread that isn't doing anything + long idx = find_empty_task_slot(); + if (idx == -1 && is_worker_thread(my_thread_id)) + { + // this function is being called from within a worker thread and there + // aren't any other worker threads free so just perform the task right + // here + + M.unlock(); + (obj.*funct)(); + + // return a task id that is both non-zero and also one + // that is never normally returned. This way calls + // to wait_for_task() will never block given this id. + return 1; + } + + // wait until there is a thread that isn't doing anything + while (idx == -1) + { + task_done_signaler.wait(); + idx = find_empty_task_slot(); + } + + tasks[idx].thread_id = my_thread_id; + tasks[idx].task_id = make_next_task_id(idx); + tasks[idx].mfp0.set(obj,funct); + + task_ready_signaler.signal(); + + return tasks[idx].task_id; + } + + template + uint64 add_task ( + T& obj, + void (T::*funct)(long), + long arg1 + ) + { + auto_mutex M(m); + const thread_id_type my_thread_id = get_thread_id(); + + // find a thread that isn't doing anything + long idx = find_empty_task_slot(); + if (idx == -1 && is_worker_thread(my_thread_id)) + { + // this function is being called from within a worker thread and there + // aren't any other worker threads free so just perform the task right + // here + + M.unlock(); + (obj.*funct)(arg1); + + // return a task id that is both non-zero and also one + // that is never normally returned. This way calls + // to wait_for_task() will never block given this id. + return 1; + } + + // wait until there is a thread that isn't doing anything + while (idx == -1) + { + task_done_signaler.wait(); + idx = find_empty_task_slot(); + } + + tasks[idx].thread_id = my_thread_id; + tasks[idx].task_id = make_next_task_id(idx); + tasks[idx].mfp1.set(obj,funct); + tasks[idx].arg1 = arg1; + + task_ready_signaler.signal(); + + return tasks[idx].task_id; + } + + template + uint64 add_task ( + T& obj, + void (T::*funct)(long,long), + long arg1, + long arg2 + ) + { + auto_mutex M(m); + const thread_id_type my_thread_id = get_thread_id(); + + // find a thread that isn't doing anything + long idx = find_empty_task_slot(); + if (idx == -1 && is_worker_thread(my_thread_id)) + { + // this function is being called from within a worker thread and there + // aren't any other worker threads free so just perform the task right + // here + + M.unlock(); + (obj.*funct)(arg1, arg2); + + // return a task id that is both non-zero and also one + // that is never normally returned. This way calls + // to wait_for_task() will never block given this id. + return 1; + } + + // wait until there is a thread that isn't doing anything + while (idx == -1) + { + task_done_signaler.wait(); + idx = find_empty_task_slot(); + } + + tasks[idx].thread_id = my_thread_id; + tasks[idx].task_id = make_next_task_id(idx); + tasks[idx].mfp2.set(obj,funct); + tasks[idx].arg1 = arg1; + tasks[idx].arg2 = arg2; + + task_ready_signaler.signal(); + + return tasks[idx].task_id; + } + + struct function_object_copy + { + virtual ~function_object_copy(){} + }; + + template + struct function_object_copy_instance : function_object_copy + { + function_object_copy_instance(const T& item_) : item(item_) {} + T item; + virtual ~function_object_copy_instance(){} + }; + + uint64 add_task_internal ( + const bfp_type& bfp, + std::shared_ptr& item + ); + /*! + ensures + - adds a task to call the given bfp object. + - swaps item into the internal task object which will have a lifetime + at least as long as the running task. + - returns the task id for this new task + !*/ + + uint64 add_task_internal ( + const bfp_type& bfp + ) { std::shared_ptr temp; return add_task_internal(bfp, temp); } + /*! + ensures + - adds a task to call the given bfp object. + - returns the task id for this new task + !*/ + + void shutdown_pool ( + ); + /*! + ensures + - causes all threads to terminate and blocks the + caller until this happens. + !*/ + + private: + + bool is_worker_thread ( + const thread_id_type id + ) const; + /*! + requires + - m is locked + ensures + - if (thread with given id is one of the thread pool's worker threads or num_threads_in_pool() == 0) then + - returns true + - else + - returns false + !*/ + + void thread ( + ); + /*! + this is the function that executes the threads in the thread pool + !*/ + + long find_empty_task_slot ( + ) const; + /*! + requires + - m is locked + ensures + - if (there is currently a empty task slot) then + - returns the index of that task slot in tasks + - there is a task slot + - else + - returns -1 + !*/ + + long find_ready_task ( + ) const; + /*! + requires + - m is locked + ensures + - if (there is currently a task to do) then + - returns the index of that task in tasks + - else + - returns -1 + !*/ + + uint64 make_next_task_id ( + long idx + ); + /*! + requires + - m is locked + - 0 <= idx < tasks.size() + ensures + - returns the next index to be used for tasks that are placed in + tasks[idx] + !*/ + + unsigned long task_id_to_index ( + uint64 id + ) const; + /*! + requires + - m is locked + - num_threads_in_pool() != 0 + ensures + - returns the index in tasks corresponding to the given id + !*/ + + struct task_state_type + { + task_state_type() : is_being_processed(false), task_id(0), next_task_id(2), arg1(0), arg2(0), eptr(nullptr) {} + + bool is_ready () const + /*! + ensures + - if (is_empty() == false && no thread is currently processing this task) then + - returns true + - else + - returns false + !*/ + { + return !is_being_processed && !is_empty(); + } + + bool is_empty () const + /*! + ensures + - if (this task state is empty. i.e. it doesn't contain a task to be processed) then + - returns true + - else + - returns false + !*/ + { + return task_id == 0; + } + + bool is_being_processed; // true when a thread is working on this task + uint64 task_id; // the id of this task. 0 means this task is empty + thread_id_type thread_id; // the id of the thread that requested this task + + uint64 next_task_id; + + long arg1; + long arg2; + + member_function_pointer<> mfp0; + member_function_pointer mfp1; + member_function_pointer mfp2; + bfp_type bfp; + + std::shared_ptr function_copy; + mutable std::exception_ptr eptr; // non-null if the task threw an exception + + void propagate_exception() const + { + if (eptr) + { + auto tmp = eptr; + eptr = nullptr; + std::rethrow_exception(tmp); + } + } + + }; + + array tasks; + array worker_thread_ids; + + mutex m; + signaler task_done_signaler; + signaler task_ready_signaler; + bool we_are_destructing; + + std::vector threads; + + // restricted functions + thread_pool_implementation(thread_pool_implementation&); // copy constructor + thread_pool_implementation& operator=(thread_pool_implementation&); // assignment operator + + }; + + +// ---------------------------------------------------------------------------------------- + + class thread_pool + { + /*! + This object is just a shell that holds a std::shared_ptr + to the real thread_pool_implementation object. The reason for doing + it this way is so that we can allow any mixture of destruction orders + between thread_pool objects and futures. Whoever gets destroyed + last cleans up the thread_pool_implementation resources. + !*/ + typedef bound_function_pointer::kernel_1a_c bfp_type; + + public: + explicit thread_pool ( + unsigned long num_threads + ) + { + impl.reset(new thread_pool_implementation(num_threads)); + } + + ~thread_pool ( + ) + { + try + { + impl->shutdown_pool(); + } + catch (std::exception& e) + { + std::cerr << "An unhandled exception was inside a dlib::thread_pool when it was destructed." << std::endl; + std::cerr << "It's what string is: \n" << e.what() << std::endl; + using namespace std; + assert(false); + abort(); + } + catch (...) + { + std::cerr << "An unhandled exception was inside a dlib::thread_pool when it was destructed." << std::endl; + using namespace std; + assert(false); + abort(); + } + } + + void wait_for_task ( + uint64 task_id + ) const { impl->wait_for_task(task_id); } + + unsigned long num_threads_in_pool ( + ) const { return impl->num_threads_in_pool(); } + + void wait_for_all_tasks ( + ) const { impl->wait_for_all_tasks(); } + + bool is_task_thread ( + ) const { return impl->is_task_thread(); } + + template + uint64 add_task ( + T& obj, + void (T::*funct)() + ) + { + return impl->add_task(obj, funct); + } + + template + uint64 add_task ( + T& obj, + void (T::*funct)(long), + long arg1 + ) + { + return impl->add_task(obj, funct, arg1); + } + + template + uint64 add_task ( + T& obj, + void (T::*funct)(long,long), + long arg1, + long arg2 + ) + { + return impl->add_task(obj, funct, arg1, arg2); + } + + // -------------------- + + template + uint64 add_task ( + F& function_object + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + bfp_type temp; + temp.set(function_object); + uint64 id = impl->add_task_internal(temp); + + return id; + } + + template + uint64 add_task_by_value ( + const F& function_object + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(function_object); + std::shared_ptr function_copy(ptr); + + + bfp_type temp; + temp.set(ptr->item); + uint64 id = impl->add_task_internal(temp, function_copy); + + return id; + } + + template + uint64 add_task ( + const T& obj, + void (T::*funct)() const + ) + { + bfp_type temp; + temp.set(obj,funct); + uint64 id = impl->add_task_internal(temp); + + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)() const + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item,funct); + uint64 id = impl->add_task_internal(temp, function_copy); + + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)() + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item,funct); + uint64 id = impl->add_task_internal(temp, function_copy); + + return id; + } + + uint64 add_task ( + void (*funct)() + ) + { + bfp_type temp; + temp.set(funct); + uint64 id = impl->add_task_internal(temp); + + return id; + } + + // -------------------- + + template + uint64 add_task ( + F& function_object, + future& arg1 + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + bfp_type temp; + temp.set(function_object,arg1.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const F& function_object, + future& arg1 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(function_object); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, arg1.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + return id; + } + + template + uint64 add_task ( + T& obj, + void (T::*funct)(T1), + future& arg1 + ) + { + bfp_type temp; + temp.set(obj,funct,arg1.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1), + future& arg1 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item,funct,arg1.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + return id; + } + + + template + uint64 add_task ( + const T& obj, + void (T::*funct)(T1) const, + future& arg1 + ) + { + bfp_type temp; + temp.set(obj,funct,arg1.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1) const, + future& arg1 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item,funct,arg1.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + return id; + } + + template + uint64 add_task ( + void (*funct)(T1), + future& arg1 + ) + { + bfp_type temp; + temp.set(funct,arg1.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + return id; + } + + // -------------------- + + template + uint64 add_task ( + F& function_object, + future& arg1, + future& arg2 + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + bfp_type temp; + temp.set(function_object, arg1.get(), arg2.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const F& function_object, + future& arg1, + future& arg2 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(function_object); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, arg1.get(), arg2.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + return id; + } + + template + uint64 add_task ( + T& obj, + void (T::*funct)(T1,T2), + future& arg1, + future& arg2 + ) + { + bfp_type temp; + temp.set(obj, funct, arg1.get(), arg2.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2), + future& arg1, + future& arg2 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, funct, arg1.get(), arg2.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + return id; + } + + template + uint64 add_task ( + const T& obj, + void (T::*funct)(T1,T2) const, + future& arg1, + future& arg2 + ) + { + bfp_type temp; + temp.set(obj, funct, arg1.get(), arg2.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2) const, + future& arg1, + future& arg2 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, funct, arg1.get(), arg2.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + return id; + } + + template + uint64 add_task ( + void (*funct)(T1,T2), + future& arg1, + future& arg2 + ) + { + bfp_type temp; + temp.set(funct, arg1.get(), arg2.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + return id; + } + + // -------------------- + + template + uint64 add_task ( + F& function_object, + future& arg1, + future& arg2, + future& arg3 + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + bfp_type temp; + temp.set(function_object, arg1.get(), arg2.get(), arg3.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const F& function_object, + future& arg1, + future& arg2, + future& arg3 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(function_object); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, arg1.get(), arg2.get(), arg3.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + return id; + } + + template + uint64 add_task ( + T& obj, + void (T::*funct)(T1,T2,T3), + future& arg1, + future& arg2, + future& arg3 + ) + { + bfp_type temp; + temp.set(obj, funct, arg1.get(), arg2.get(), arg3.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2,T3), + future& arg1, + future& arg2, + future& arg3 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, funct, arg1.get(), arg2.get(), arg3.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + return id; + } + + template + uint64 add_task ( + const T& obj, + void (T::*funct)(T1,T2,T3) const, + future& arg1, + future& arg2, + future& arg3 + ) + { + bfp_type temp; + temp.set(obj, funct, arg1.get(), arg2.get(), arg3.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2,T3) const, + future& arg1, + future& arg2, + future& arg3 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, funct, arg1.get(), arg2.get(), arg3.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + return id; + } + + template + uint64 add_task ( + void (*funct)(T1,T2,T3), + future& arg1, + future& arg2, + future& arg3 + ) + { + bfp_type temp; + temp.set(funct, arg1.get(), arg2.get(), arg3.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + return id; + } + + // -------------------- + + template + uint64 add_task ( + F& function_object, + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ) + { + COMPILE_TIME_ASSERT(is_function::value == false); + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + bfp_type temp; + temp.set(function_object, arg1.get(), arg2.get(), arg3.get(), arg4.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + arg4.task_id = id; + arg4.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const F& function_object, + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(function_object); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, arg1.get(), arg2.get(), arg3.get(), arg4.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the future to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + arg4.task_id = id; + arg4.tp = impl; + return id; + } + + template + uint64 add_task ( + T& obj, + void (T::*funct)(T1,T2,T3,T4), + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ) + { + bfp_type temp; + temp.set(obj, funct, arg1.get(), arg2.get(), arg3.get(), arg4.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + arg4.task_id = id; + arg4.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2,T3,T4), + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, funct, arg1.get(), arg2.get(), arg3.get(), arg4.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + arg4.task_id = id; + arg4.tp = impl; + return id; + } + + template + uint64 add_task ( + const T& obj, + void (T::*funct)(T1,T2,T3,T4) const, + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ) + { + bfp_type temp; + temp.set(obj, funct, arg1.get(), arg2.get(), arg3.get(), arg4.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + arg4.task_id = id; + arg4.tp = impl; + return id; + } + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2,T3,T4) const, + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ) + { + thread_pool_implementation::function_object_copy_instance* ptr = 0; + ptr = new thread_pool_implementation::function_object_copy_instance(obj); + std::shared_ptr function_copy(ptr); + + bfp_type temp; + temp.set(ptr->item, funct, arg1.get(), arg2.get(), arg3.get(), arg4.get()); + uint64 id = impl->add_task_internal(temp, function_copy); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + arg4.task_id = id; + arg4.tp = impl; + return id; + } + + template + uint64 add_task ( + void (*funct)(T1,T2,T3,T4), + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ) + { + bfp_type temp; + temp.set(funct, arg1.get(), arg2.get(), arg3.get(), arg4.get()); + uint64 id = impl->add_task_internal(temp); + + // tie the futures to this task + arg1.task_id = id; + arg1.tp = impl; + arg2.task_id = id; + arg2.tp = impl; + arg3.task_id = id; + arg3.tp = impl; + arg4.task_id = id; + arg4.tp = impl; + return id; + } + + private: + + std::shared_ptr impl; + + // restricted functions + thread_pool(thread_pool&); // copy constructor + thread_pool& operator=(thread_pool&); // assignment operator + + }; + + +// ---------------------------------------------------------------------------------------- + + template + void future:: + wait ( + ) const + { + if (tp) + { + tp->wait_for_task(task_id); + tp.reset(); + task_id = 0; + } + } + +} + +// ---------------------------------------------------------------------------------------- + +#ifdef NO_MAKEFILE +#include "thread_pool_extension.cpp" +#endif + +#endif // DLIB_THREAD_POOl_Hh_ + + diff --git a/ml/dlib/dlib/threads/thread_pool_extension_abstract.h b/ml/dlib/dlib/threads/thread_pool_extension_abstract.h new file mode 100644 index 000000000..ba54a7546 --- /dev/null +++ b/ml/dlib/dlib/threads/thread_pool_extension_abstract.h @@ -0,0 +1,842 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_THREAD_POOl_ABSTRACT_Hh_ +#ifdef DLIB_THREAD_POOl_ABSTRACT_Hh_ + +#include "threads_kernel_abstract.h" +#include "../uintn.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class future + { + /*! + INITIAL VALUE + - is_ready() == true + + WHAT THIS OBJECT REPRESENTS + This object represents a container that allows you to safely pass objects + into the tasks performed by the thread_pool object defined below. An + example will make it clear: + + // Suppose you have a global function defined as follows + void add (int a, int b, int& result) { result = a + b; } + + // Also suppose you have a thread_pool named tp defined somewhere. + // Then you could do the following. + future a, b, result; + a = 3; + b = 4; + // this function call causes another thread to execute a call to the add() function + // and passes in the int objects contained in a, b, and result + tp.add_task(add,a,b,result); + // This line will wait for the task in the thread pool to finish and then print the + // value in the result integer. So it will print a 7. + cout << result << endl; + !*/ + + public: + future ( + ); + /*! + ensures + - The object of type T contained in this future has + an initial value for its type. + - #is_ready() == true + !*/ + + future ( + const T& item + ); + /*! + ensures + - #get() == item + - #is_ready() == true + !*/ + + future ( + const future& item + ); + /*! + ensures + - if (item.is_ready() == false) then + - the call to this function blocks until the thread processing the task related + to the item future has finished. + - #is_ready() == true + - #item.is_ready() == true + - #get() == item.get() + !*/ + + ~future ( + ); + /*! + ensures + - if (is_ready() == false) then + - the call to this function blocks until the thread processing the task related + to this future has finished. + !*/ + + bool is_ready ( + ) const; + /*! + ensures + - if (the value of this future may not yet be ready to be accessed because it + is in use by a task in a thread_pool) then + - returns false + - else + - returns true + !*/ + + future& operator=( + const T& item + ); + /*! + ensures + - if (is_ready() == false) then + - the call to this function blocks until the thread processing the task related + to this future has finished. + - #is_ready() == true + - #get() == item + - returns *this + !*/ + + future& operator=( + const future& item + ); + /*! + ensures + - if (is_ready() == false || item.is_ready() == false) then + - the call to this function blocks until the threads processing the tasks related + to this future and the item future have finished. + - #is_ready() == true + - #item.is_ready() == true + - #get() == item.get() + - returns *this + !*/ + + operator T& ( + ); + /*! + ensures + - if (is_ready() == false) then + - the call to this function blocks until the thread processing the task related + to this future has finished. + - #is_ready() == true + - returns get() + !*/ + + operator const T& ( + ) const; + /*! + ensures + - if (is_ready() == false) then + - the call to this function blocks until the thread processing the task related + to this future has finished. + - #is_ready() == true + - returns get() + !*/ + + T& get ( + ); + /*! + ensures + - if (is_ready() == false) then + - the call to this function blocks until the thread processing the task related + to this future has finished. + - #is_ready() == true + - returns a non-const reference to the object of type T contained inside this future + !*/ + + const T& get ( + ) const; + /*! + ensures + - if (is_ready() == false) then + - the call to this function blocks until the thread processing the task related + to this future has finished. + - #is_ready() == true + - returns a const reference to the object of type T contained inside this future + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template + inline void swap ( + future& a, + future& b + ) { std::swap(a.get(), b.get()); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + +// The future object comes with overloads for all the usual comparison operators. + + template bool operator== (const future& a, const future& b) { return a.get() == b.get(); } + template bool operator!= (const future& a, const future& b) { return a.get() != b.get(); } + template bool operator<= (const future& a, const future& b) { return a.get() <= b.get(); } + template bool operator>= (const future& a, const future& b) { return a.get() >= b.get(); } + template bool operator< (const future& a, const future& b) { return a.get() < b.get(); } + template bool operator> (const future& a, const future& b) { return a.get() > b.get(); } + + template bool operator== (const future& a, const T& b) { return a.get() == b; } + template bool operator== (const T& a, const future& b) { return a == b.get(); } + template bool operator!= (const future& a, const T& b) { return a.get() != b; } + template bool operator!= (const T& a, const future& b) { return a != b.get(); } + template bool operator<= (const future& a, const T& b) { return a.get() <= b; } + template bool operator<= (const T& a, const future& b) { return a <= b.get(); } + template bool operator>= (const future& a, const T& b) { return a.get() >= b; } + template bool operator>= (const T& a, const future& b) { return a >= b.get(); } + template bool operator< (const future& a, const T& b) { return a.get() < b; } + template bool operator< (const T& a, const future& b) { return a < b.get(); } + template bool operator> (const future& a, const T& b) { return a.get() > b; } + template bool operator> (const T& a, const future& b) { return a > b.get(); } + +// ---------------------------------------------------------------------------------------- + + class thread_pool + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a fixed size group of threads which you can + submit tasks to and then wait for those tasks to be completed. + + Note that setting the number of threads to 0 is a valid way to + use this object. It causes it to not contain any threads + at all. When tasks are submitted to the object in this mode + the tasks are processed within the calling thread. So in this + mode any thread that calls add_task() is considered to be + a thread_pool thread capable of executing tasks. + + This object is also implemented such that no memory allocations occur + after the thread_pool has been constructed so long as the user doesn't + call any of the add_task_by_value() routines. The future object also + doesn't perform any memory allocations or contain any system resources + such as mutex objects. + + EXCEPTIONS + Note that if an exception is thrown inside a task thread and is not caught + then the exception will be trapped inside the thread pool and rethrown at a + later time when someone calls one of the add task or wait member functions + of the thread pool. This allows exceptions to propagate out of task threads + and into the calling code where they can be handled. + !*/ + + public: + explicit thread_pool ( + unsigned long num_threads + ); + /*! + ensures + - #num_threads_in_pool() == num_threads + throws + - std::bad_alloc + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create threading objects. + !*/ + + ~thread_pool( + ); + /*! + ensures + - blocks until all tasks in the pool have finished. + - If one of the threads has generated an exception but it hasn't yet been + rethrown to the caller (e.g. by calling wait_for_all_tasks()) then the + program will be terminated. So make sure you handle all the possible + exceptions from your tasks. + !*/ + + bool is_task_thread ( + ) const; + /*! + ensures + - if (the thread calling this function is one of the threads in this + thread pool or num_threads_in_pool() == 0) then + - returns true + - else + - returns false + !*/ + + unsigned long num_threads_in_pool ( + ) const; + /*! + ensures + - returns the number of threads contained in this thread pool. That is, returns + the maximum number of tasks that this object will process concurrently. + !*/ + + template + uint64 add_task_by_value ( + const F& function_object + ); + /*! + requires + - function_object() is a valid expression + ensures + - makes a copy of function_object, call it FCOPY. + - if (is_task_thread() == true and there aren't any free threads available) then + - calls FCOPY() within the calling thread and returns when it finishes + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls FCOPY(). + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task ( + T& obj, + void (T::*funct)() + ); + /*! + requires + - funct == a valid member function pointer for class T + - obj will not go out of scope until after the task has completed (i.e. + this function passes obj to the task by reference. If you want to avoid + this restriction then use add_task_by_value()) + ensures + - if (is_task_thread() == true and there aren't any free threads available) then + - calls (obj.*funct)() within the calling thread and returns + when it finishes. + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls (obj.*funct)() + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)() + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - makes a copy of obj, call it OBJ_COPY. + - if (is_task_thread() == true and there aren't any free threads available) then + - calls (OBJ_COPY.*funct)() within the calling thread and returns + when it finishes. + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls (OBJ_COPY.*funct)(). + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task ( + T& obj, + void (T::*funct)(long), + long arg1 + ); + /*! + requires + - funct == a valid member function pointer for class T + - obj will not go out of scope until after the task has completed (i.e. + this function passes obj to the task by reference. If you want to avoid + this restriction then use add_task_by_value()) + ensures + - if (is_task_thread() == true and there aren't any free threads available) then + - calls (obj.*funct)(arg1) within the calling thread and returns + when it finishes + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls (obj.*funct)(arg1) + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task ( + T& obj, + void (T::*funct)(long,long), + long arg1, + long arg2 + ); + /*! + requires + - funct == a valid member function pointer for class T + - obj will not go out of scope until after the task has completed (i.e. + this function passes obj to the task by reference. If you want to avoid + this restriction then use add_task_by_value()) + ensures + - if (is_task_thread() == true and there aren't any free threads available) then + - calls (obj.*funct)(arg1,arg2) within the calling thread and returns + when it finishes + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls (obj.*funct)(arg1,arg2) + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + void wait_for_task ( + uint64 task_id + ) const; + /*! + ensures + - if (there is currently a task with the given id being executed in the thread pool) then + - the call to this function blocks until the task with the given id is complete + - else + - the call to this function returns immediately + !*/ + + void wait_for_all_tasks ( + ) const; + /*! + ensures + - the call to this function blocks until all tasks which were submitted + to the thread pool by the thread that is calling this function have + finished. + !*/ + + // -------------------- + + template + uint64 add_task ( + F& function_object, + future& arg1 + ); + /*! + requires + - function_object(arg1.get()) is a valid expression + (i.e. The A1 type stored in the future must be a type that can be passed into the given function object) + - function_object will not go out of scope until after the task has completed (i.e. + this function passes function_object to the task by reference. If you want to avoid + this restriction then use add_task_by_value()) + ensures + - if (is_task_thread() == true and there aren't any free threads available) then + - calls function_object(arg1.get()) within the calling thread and returns + when it finishes + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls function_object(arg1.get()). + - #arg1.is_ready() == false + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task_by_value ( + const F& function_object, + future& arg1 + ); + /*! + requires + - function_object(arg1.get()) is a valid expression + (i.e. The A1 type stored in the future must be a type that can be passed into the given function object) + ensures + - makes a copy of function_object, call it FCOPY. + - if (is_task_thread() == true and there aren't any free threads available) then + - calls FCOPY(arg1.get()) within the calling thread and returns when it finishes + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls FCOPY(arg1.get()). + - #arg1.is_ready() == false + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task ( + T& obj, + void (T::*funct)(T1), + future& arg1 + ); + /*! + requires + - funct == a valid member function pointer for class T + - (obj.*funct)(arg1.get()) must be a valid expression. + (i.e. The A1 type stored in the future must be a type that can be passed into the given function) + - obj will not go out of scope until after the task has completed (i.e. + this function passes obj to the task by reference. If you want to avoid + this restriction then use add_task_by_value()) + ensures + - if (is_task_thread() == true and there aren't any free threads available) then + - calls (obj.*funct)(arg1.get()) within the calling thread and returns + when it finishes + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls (obj.*funct)(arg1.get()). + - #arg1.is_ready() == false + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1), + future& arg1 + ); + /*! + requires + - funct == a valid member function pointer for class T + - (obj.*funct)(arg1.get()) must be a valid expression. + (i.e. The A1 type stored in the future must be a type that can be passed into the given function) + ensures + - makes a copy of obj, call it OBJ_COPY. + - if (is_task_thread() == true and there aren't any free threads available) then + - calls (OBJ_COPY.*funct)(arg1.get()) within the calling thread and returns + when it finishes. + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls (OBJ_COPY.*funct)(arg1.get()). + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task ( + const T& obj, + void (T::*funct)(T1) const, + future& arg1 + ); + /*! + requires + - funct == a valid member function pointer for class T + - (obj.*funct)(arg1.get()) must be a valid expression. + (i.e. The A1 type stored in the future must be a type that can be passed into the given function) + - obj will not go out of scope until after the task has completed (i.e. + this function passes obj to the task by reference. If you want to avoid + this restriction then use add_task_by_value()) + ensures + - if (is_task_thread() == true and there aren't any free threads available) then + - calls (obj.*funct)(arg1.get()) within the calling thread and returns + when it finishes + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls (obj.*funct)(arg1.get()). + - #arg1.is_ready() == false + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1) const, + future& arg1 + ); + /*! + requires + - funct == a valid member function pointer for class T + - (obj.*funct)(arg1.get()) must be a valid expression. + (i.e. The A1 type stored in the future must be a type that can be passed into the given function) + ensures + - makes a copy of obj, call it OBJ_COPY. + - if (is_task_thread() == true and there aren't any free threads available) then + - calls (OBJ_COPY.*funct)(arg1.get()) within the calling thread and returns + when it finishes. + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls (OBJ_COPY.*funct)(arg1.get()). + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + template + uint64 add_task ( + void (*funct)(T1), + future& arg1 + ); + /*! + requires + - funct == a valid function pointer + - (funct)(arg1.get()) must be a valid expression. + (i.e. The A1 type stored in the future must be a type that can be passed into the given function) + ensures + - if (is_task_thread() == true and there aren't any free threads available) then + - calls funct(arg1.get()) within the calling thread and returns + when it finishes + - else + - the call to this function blocks until there is a free thread in the pool + to process this new task. Once a free thread is available the task + is handed off to that thread which then calls funct(arg1.get()). + - #arg1.is_ready() == false + - returns a task id that can be used by this->wait_for_task() to wait + for the submitted task to finish. + !*/ + + // -------------------------------------------------------------------------------- + // The remainder of this class just contains overloads for add_task() and add_task_by_value() + // that take up to 4 futures (as well as 0 futures). Their behavior is identical to the above + // add_task() and add_task_by_value() functions. + // -------------------------------------------------------------------------------- + + template + uint64 add_task ( + F& function_object, + future& arg1, + future& arg2 + ); + + template + uint64 add_task_by_value ( + const F& function_object, + future& arg1, + future& arg2 + ); + + template + uint64 add_task ( + T& obj, + void (T::*funct)(T1,T2), + future& arg1, + future& arg2 + ); + + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2), + future& arg1, + future& arg2 + ); + + template + uint64 add_task ( + const T& obj, + void (T::*funct)(T1,T2) const, + future& arg1, + future& arg2 + ); + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2) const, + future& arg1, + future& arg2 + ); + + template + uint64 add_task ( + void (*funct)(T1,T2), + future& arg1, + future& arg2 + ); + + // -------------------- + + template + uint64 add_task ( + F& function_object, + future& arg1, + future& arg2, + future& arg3 + ); + + template + uint64 add_task_by_value ( + const F& function_object, + future& arg1, + future& arg2, + future& arg3 + ); + + template + uint64 add_task ( + T& obj, + void (T::*funct)(T1,T2,T3), + future& arg1, + future& arg2, + future& arg3 + ); + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2,T3), + future& arg1, + future& arg2, + future& arg3 + ); + + template + uint64 add_task ( + const T& obj, + void (T::*funct)(T1,T2,T3) const, + future& arg1, + future& arg2, + future& arg3 + ); + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2,T3) const, + future& arg1, + future& arg2, + future& arg3 + ); + + template + uint64 add_task ( + void (*funct)(T1,T2,T3), + future& arg1, + future& arg2, + future& arg3 + ); + + // -------------------- + + template + uint64 add_task ( + F& function_object, + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ); + + template + uint64 add_task_by_value ( + const F& function_object, + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ); + + template + uint64 add_task ( + T& obj, + void (T::*funct)(T1,T2,T3,T4), + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ); + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2,T3,T4), + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ); + + template + uint64 add_task ( + const T& obj, + void (T::*funct)(T1,T2,T3,T4) const, + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ); + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)(T1,T2,T3,T4) const, + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ); + + template + uint64 add_task ( + void (*funct)(T1,T2,T3,T4), + future& arg1, + future& arg2, + future& arg3, + future& arg4 + ); + + // -------------------- + + template + uint64 add_task ( + F& function_object + ); + + template + uint64 add_task ( + const T& obj, + void (T::*funct)() const, + ); + + template + uint64 add_task_by_value ( + const T& obj, + void (T::*funct)() const + ); + + uint64 add_task ( + void (*funct)() + ); + + // -------------------- + + private: + + // restricted functions + thread_pool(thread_pool&); // copy constructor + thread_pool& operator=(thread_pool&); // assignment operator + }; + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_THREAD_POOl_ABSTRACT_Hh_ + + + diff --git a/ml/dlib/dlib/threads/thread_specific_data_extension.h b/ml/dlib/dlib/threads/thread_specific_data_extension.h new file mode 100644 index 000000000..0b5339200 --- /dev/null +++ b/ml/dlib/dlib/threads/thread_specific_data_extension.h @@ -0,0 +1,141 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_ +#define DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_ + +#include "thread_specific_data_extension_abstract.h" +#include "threads_kernel_abstract.h" +#include "../binary_search_tree.h" +#include "auto_mutex_extension.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class thread_specific_data + { + /*! + CONVENTION + - for all valid ID: + (*items[ID]) == pointer to the data for thread with id ID + !*/ + public: + + thread_specific_data ( + ) + { + thread_end_handler_calls_left = 0; + } + + ~thread_specific_data ( + ) + { + // We should only call the unregister_thread_end_handler function if there are + // some outstanding callbacks we expect to get. Otherwise lets avoid calling it + // since the dlib state that maintains the registered thread end handlers may have + // been destructed already (since the program might be in the process of terminating). + bool call_unregister = false; + m.lock(); + if (thread_end_handler_calls_left > 0) + call_unregister = true; + m.unlock(); + + if (call_unregister) + unregister_thread_end_handler(const_cast(*this),&thread_specific_data::thread_end_handler); + + auto_mutex M(m); + items.reset(); + while (items.move_next()) + { + delete items.element().value(); + } + } + + inline T& data ( + ) { return get_data(); } + + inline const T& data ( + ) const { return get_data(); } + + private: + + T& get_data ( + ) const + { + thread_id_type id = get_thread_id(); + auto_mutex M(m); + + T** item = items[id]; + if (item) + { + return **item; + } + else + { + // register an end handler for this thread so long as it is a dlib created thread. + T* new_item = new T; + + bool in_tree = false; + try + { + T* temp_item = new_item; + thread_id_type temp_id = id; + items.add(temp_id,temp_item); + in_tree = true; + + if (is_dlib_thread(id)) + { + register_thread_end_handler(const_cast(*this),&thread_specific_data::thread_end_handler); + ++thread_end_handler_calls_left; + } + } + catch (...) + { + if (in_tree) + { + items.destroy(id); + } + delete new_item; + throw; + } + + return *new_item; + } + } + + void thread_end_handler ( + ) + { + const thread_id_type id = get_thread_id(); + thread_id_type junk = 0; + T* item = 0; + auto_mutex M(m); + --thread_end_handler_calls_left; + if (items[id]) + { + items.remove(id,junk,item); + delete item; + } + } + + mutable typename binary_search_tree::kernel_2a items; + mutex m; + mutable long thread_end_handler_calls_left; + + // restricted functions + thread_specific_data(thread_specific_data&); // copy constructor + thread_specific_data& operator=(thread_specific_data&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_ + + + diff --git a/ml/dlib/dlib/threads/thread_specific_data_extension_abstract.h b/ml/dlib/dlib/threads/thread_specific_data_extension_abstract.h new file mode 100644 index 000000000..03fb9ddaa --- /dev/null +++ b/ml/dlib/dlib/threads/thread_specific_data_extension_abstract.h @@ -0,0 +1,87 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_ABSTRACT_ +#ifdef DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class thread_specific_data + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents a container of thread specific data. When + a thread calls the data() member function it gets a reference to a T object + that is specific to its own thread. Each subsequent call to data() from that + thread returns the same instance. Also note that when a thread ends + the instance of its data() object gets destroyed and freed (if the thread + was created by the dlib library). So any pointers or references to the object + will be invalid after the thread has ended. + !*/ + public: + + thread_specific_data ( + ); + /*! + ensures + - #*this is properly initialized + !*/ + + ~thread_specific_data ( + ); + /*! + ensures + - all resources allocated by *this have been freed. This includes + all the thread specific data returned by the data() functions. + !*/ + + T& data ( + ); + /*! + ensures + - if (the calling thread has NOT called this->data() before) then + - constructs an instance of T that is specific to the calling + thread. + - returns a reference to the T instance that was constructed for + the calling thread. + throws + - std::bad_alloc or any exception thrown by T's constructor + If an exception is thrown then the call to data() will have + no effect on *this. + !*/ + + const T& data ( + ) const; + /*! + ensures + - if (the calling thread has NOT called this->data() before) then + - constructs an instance of T that is specific to the calling + thread. + - returns a const reference to the T instance that was constructed for + the calling thread. + throws + - std::bad_alloc or any exception thrown by T's constructor + If an exception is thrown then the call to data() will have + no effect on *this. + !*/ + + private: + // restricted functions + thread_specific_data(thread_specific_data&); // copy constructor + thread_specific_data& operator=(thread_specific_data&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THREAD_SPECIFIC_DATA_EXTENSIOn_ABSTRACT_ + + diff --git a/ml/dlib/dlib/threads/threaded_object_extension.cpp b/ml/dlib/dlib/threads/threaded_object_extension.cpp new file mode 100644 index 000000000..a7326c11d --- /dev/null +++ b/ml/dlib/dlib/threads/threaded_object_extension.cpp @@ -0,0 +1,290 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADED_OBJECT_EXTENSIOn_CPP +#define DLIB_THREADED_OBJECT_EXTENSIOn_CPP + +#include "threaded_object_extension.h" +#include "create_new_thread_extension.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + threaded_object:: + threaded_object ( + ): + s(m_), + id1(0), + is_running_(false), + is_alive_(false), + should_stop_(false), + id_valid(false) + { + } + +// ---------------------------------------------------------------------------------------- + + threaded_object:: + ~threaded_object ( + ) + { + try + { + DLIB_ASSERT(is_alive() == false, + "\tthreaded_object::~threaded_object()" + << "\n\tYou have let a threaded object destruct itself before terminating its thread" + << "\n\tthis: " << this + ); + } + catch (std::exception& e) + { + std::cerr << e.what() << std::endl; + assert(false); + abort(); + } + } + +// ---------------------------------------------------------------------------------------- + + bool threaded_object:: + is_running ( + ) const + { + auto_mutex M(m_); + + DLIB_ASSERT(id1 != get_thread_id() || id_valid == false, + "\tbool threaded_object::is_running()" + << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + + return is_running_; + } + +// ---------------------------------------------------------------------------------------- + + bool threaded_object:: + is_alive ( + ) const + { + auto_mutex M(m_); + + DLIB_ASSERT(id1 != get_thread_id() || id_valid == false, + "\tbool threaded_object::is_alive()" + << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + + return is_alive_; + } + +// ---------------------------------------------------------------------------------------- + + void threaded_object:: + wait ( + ) const + { + auto_mutex M(m_); + + DLIB_ASSERT(id1 != get_thread_id() || id_valid == false, + "\tvoid threaded_object::wait()" + << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + + while (is_alive_) + s.wait(); + } + +// ---------------------------------------------------------------------------------------- + + void threaded_object:: + start ( + ) + { + auto_mutex M(m_); + + DLIB_ASSERT(id1 != get_thread_id() || id_valid == false, + "\tvoid threaded_object::start()" + << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + + if (is_alive_ == false) + { + if (create_new_thread(*this) == false) + { + is_running_ = false; + throw thread_error(); + } + } + is_alive_ = true; + is_running_ = true; + should_stop_ = false; + s.broadcast(); + } + +// ---------------------------------------------------------------------------------------- + + void threaded_object:: + restart ( + ) + { + auto_mutex M(m_); + + DLIB_ASSERT(id1 != get_thread_id() || id_valid == false, + "\tvoid threaded_object::restart()" + << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + + if (is_alive_ == false) + { + if (create_new_thread(*this) == false) + { + is_running_ = false; + throw thread_error(); + } + should_respawn_ = false; + } + else + { + should_respawn_ = true; + } + is_alive_ = true; + is_running_ = true; + should_stop_ = false; + s.broadcast(); + } + +// ---------------------------------------------------------------------------------------- + + void threaded_object:: + set_respawn ( + ) + { + auto_mutex M(m_); + + DLIB_ASSERT(id1 != get_thread_id() || id_valid == false, + "\tvoid threaded_object::set_respawn()" + << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + + should_respawn_ = true; + } + +// ---------------------------------------------------------------------------------------- + + bool threaded_object:: + should_respawn ( + ) const + { + auto_mutex M(m_); + + DLIB_ASSERT(id1 != get_thread_id() || id_valid == false, + "\tbool threaded_object::should_respawn()" + << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + + return should_respawn_; + } + +// ---------------------------------------------------------------------------------------- + + void threaded_object:: + pause ( + ) + { + auto_mutex M(m_); + + DLIB_ASSERT(id1 != get_thread_id() || id_valid == false, + "\tvoid threaded_object::pause()" + << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + + is_running_ = false; + } + +// ---------------------------------------------------------------------------------------- + + void threaded_object:: + stop ( + ) + { + auto_mutex M(m_); + + DLIB_ASSERT(id1 != get_thread_id() || id_valid == false, + "\tvoid threaded_object::stop()" + << "\n\tYou can NOT call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + + should_stop_ = true; + is_running_ = false; + should_respawn_ = false; + s.broadcast(); + } + +// ---------------------------------------------------------------------------------------- + + bool threaded_object:: + should_stop ( + ) const + { + auto_mutex M(m_); + DLIB_ASSERT(is_alive_ && id1 == get_thread_id() && id_valid == true, + "\tbool threaded_object::should_stop()" + << "\n\tYou can only call this function from the thread that executes threaded_object::thread" + << "\n\tthis: " << this + ); + while (is_running_ == false && should_stop_ == false) + s.wait(); + return should_stop_; + } + +// ---------------------------------------------------------------------------------------- + + void threaded_object:: + thread_helper( + ) + { +#ifdef ENABLE_ASSERTS + id1 = get_thread_id(); + id_valid = true; +#endif + while (true) + { + m_.lock(); + should_respawn_ = false; + m_.unlock(); + + thread(); + + auto_mutex M(m_); + + if (should_respawn_) + continue; + +#ifdef ENABLE_ASSERTS + id_valid = false; +#endif + + is_alive_ = false; + is_running_ = false; + should_stop_ = false; + s.broadcast(); + + return; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THREADED_OBJECT_EXTENSIOn_CPP + diff --git a/ml/dlib/dlib/threads/threaded_object_extension.h b/ml/dlib/dlib/threads/threaded_object_extension.h new file mode 100644 index 000000000..dcf00daea --- /dev/null +++ b/ml/dlib/dlib/threads/threaded_object_extension.h @@ -0,0 +1,123 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADED_OBJECT_EXTENSIOn_ +#define DLIB_THREADED_OBJECT_EXTENSIOn_ + +#include "threaded_object_extension_abstract.h" +#include "threads_kernel.h" +#include "auto_mutex_extension.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class threaded_object + { + /*! + INITIAL VALUE + - is_running_ == false + - is_alive_ == false + - should_stop_ == false + - should_respawn_ == false + +#ifdef ENABLE_ASSERTS + - id_valid == false + - id1 == get_main_thread_id() +#endif + + CONVENTION + - is_running() == is_running_ + - is_alive() == is_alive_ + - should_stop() == should_stop_ + - should_respawn() == should_respawn_ + + +#ifdef ENABLE_ASSERTS + - if (when thread() is executing) then + - id1 == the id of the running thread + - id_valid == true + - else + - id1 == an undefined value + - id_valid == false +#endif + + - m_ == the mutex used to protect all our variables + - s == the signaler for m_ + !*/ + + public: + + threaded_object ( + ); + + virtual ~threaded_object ( + ); + + bool is_running ( + ) const; + + bool is_alive ( + ) const; + + void wait ( + ) const; + + void start ( + ); + + void restart ( + ); + + void set_respawn ( + ); + + bool should_respawn ( + ) const; + + void pause ( + ); + + void stop ( + ); + + protected: + + bool should_stop ( + ) const; + + private: + + void thread_helper( + ); + + virtual void thread ( + ) = 0; + + mutex m_; + signaler s; + thread_id_type id1; + bool is_running_; + bool is_alive_; + bool should_stop_; + bool should_respawn_; + bool id_valid; + + // restricted functions + threaded_object(threaded_object&); // copy constructor + threaded_object& operator=(threaded_object&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "threaded_object_extension.cpp" +#endif + +#endif // DLIB_THREADED_OBJECT_EXTENSIOn_ + + diff --git a/ml/dlib/dlib/threads/threaded_object_extension_abstract.h b/ml/dlib/dlib/threads/threaded_object_extension_abstract.h new file mode 100644 index 000000000..32a8fbc31 --- /dev/null +++ b/ml/dlib/dlib/threads/threaded_object_extension_abstract.h @@ -0,0 +1,199 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_THREADED_OBJECT_EXTENSIOn_ABSTRACT_ +#ifdef DLIB_THREADED_OBJECT_EXTENSIOn_ABSTRACT_ + +#include "threads_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class threaded_object + { + /*! + INITIAL VALUE + - is_running() == false + - is_alive() == false + - should_respawn() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a simple threaded object. To use it you inherit + from it and define the thread() function. Then when you call start() + it will spawn a thread that calls this->thread(). + !*/ + public: + + threaded_object ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create threading objects. + !*/ + + virtual ~threaded_object ( + ); + /*! + requires + - is_alive() == false + (i.e. in the destructor for the object you derive from this one you + must wait for this->thread() to end.) + ensures + - all resources allocated by *this have been freed. + !*/ + + bool is_running ( + ) const; + /*! + requires + - is not called from this->thread() + ensures + - if (is_alive() && this->thread() is currently supposed to be executing) then + - returns true + - else + - returns false + !*/ + + bool is_alive ( + ) const; + /*! + requires + - is not called from this->thread() + ensures + - if (this->thread() has been called by some thread and has yet to terminate) then + - returns true + - else + - returns false + !*/ + + void wait ( + ) const; + /*! + requires + - is not called from this->thread() + ensures + - if (is_alive() == true) then + - blocks until this->thread() terminates + !*/ + + void start ( + ); + /*! + requires + - is not called from this->thread() + ensures + - #is_alive() == true + - #is_running() == true + - #should_stop() == false + throws + - std::bad_alloc or dlib::thread_error + If either of these exceptions are thrown then + #is_alive() == false and #is_running() == false + !*/ + + void set_respawn ( + ); + /*! + requires + - is not called from this->thread() + ensures + - #should_respawn() == true + !*/ + + bool should_respawn ( + ) const; + /*! + requires + - is not called from this->thread() + ensures + - returns true if the thread will automatically restart upon termination and + false otherwise. Note that every time a thread starts it sets should_respawn() + back to false. Therefore, a single call to set_respawn() can cause at most + one respawn to occur. + !*/ + + void restart ( + ); + /*! + requires + - is not called from this->thread() + ensures + - This function atomically executes set_respawn() and start(). The precise meaning of this + is defined below. + - if (is_alive()) then + - #should_respawn() == true + - else + - #should_respawn() == false + - #is_alive() == true + - #is_running() == true + - #should_stop() == false + throws + - std::bad_alloc or dlib::thread_error + If either of these exceptions are thrown then + #is_alive() == false and #is_running() == false + !*/ + + void pause ( + ); + /*! + requires + - is not called from this->thread() + ensures + - #is_running() == false + !*/ + + void stop ( + ); + /*! + requires + - is not called from this->thread() + ensures + - #should_stop() == true + - #is_running() == false + - #should_respawn() == false + !*/ + + protected: + + bool should_stop ( + ) const; + /*! + requires + - is only called from the thread that executes this->thread() + ensures + - calls to this function block until (#is_running() == true || #should_stop() == true) + - if (this thread is supposed to terminate) then + - returns true + - else + - returns false + !*/ + + private: + + virtual void thread ( + ) = 0; + /*! + requires + - is executed in its own thread + - is only executed in one thread at a time + throws + - does not throw any exceptions + !*/ + + // restricted functions + threaded_object(threaded_object&); // copy constructor + threaded_object& operator=(threaded_object&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THREADED_OBJECT_EXTENSIOn_ABSTRACT_ + diff --git a/ml/dlib/dlib/threads/threads_kernel.h b/ml/dlib/dlib/threads/threads_kernel.h new file mode 100644 index 000000000..77cb16d92 --- /dev/null +++ b/ml/dlib/dlib/threads/threads_kernel.h @@ -0,0 +1,18 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADs_KERNEL_ +#define DLIB_THREADs_KERNEL_ + +#include "../platform.h" + +#ifdef WIN32 +#include "windows.h" +#endif + +#ifndef WIN32 +#include "posix.h" +#endif + +#endif // DLIB_THREADs_KERNEL_ + + diff --git a/ml/dlib/dlib/threads/threads_kernel_1.cpp b/ml/dlib/dlib/threads/threads_kernel_1.cpp new file mode 100644 index 000000000..cb36b8d3f --- /dev/null +++ b/ml/dlib/dlib/threads/threads_kernel_1.cpp @@ -0,0 +1,83 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADS_KERNEL_1_CPp_ +#define DLIB_THREADS_KERNEL_1_CPp_ + +#include "../platform.h" + +#ifdef WIN32 + +#include "threads_kernel_1.h" + +#include + + +namespace dlib +{ + namespace threads_kernel_shared_helpers + { + + // ----------------------------------------------------------------------------------- + + struct info + { + void* param; + void (*funct)(void*); + }; + + // ----------------------------------------------------------------------------------- + + unsigned int __stdcall thread_starter ( + void* param + ) + { + info* alloc_p = static_cast(param); + info p = *alloc_p; + delete alloc_p; + + p.funct(p.param); + return 0; + } + + // ----------------------------------------------------------------------------------- + + bool spawn_thread ( + void (*funct)(void*), + void* param + ) + { + info* p; + try { p = new info; } + catch (...) { return false; } + + p->funct = funct; + p->param = param; + + + unsigned int garbage; + + HANDLE thandle = (HANDLE)_beginthreadex (NULL,0,thread_starter,p,0,&garbage); + // make thread and add it to the pool + + // return false if _beginthreadex didn't work + if ( thandle == 0) + { + delete p; + return false; + } + + // throw away the thread handle + CloseHandle(thandle); + return true; + } + + // ----------------------------------------------------------------------------------- + + } + +} + +#endif // WIN32 + +#endif // DLIB_THREADS_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/threads/threads_kernel_1.h b/ml/dlib/dlib/threads/threads_kernel_1.h new file mode 100644 index 000000000..586a21b7e --- /dev/null +++ b/ml/dlib/dlib/threads/threads_kernel_1.h @@ -0,0 +1,158 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADS_KERNEl_1_ +#define DLIB_THREADS_KERNEl_1_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + +#include "threads_kernel_abstract.h" + +#include "../windows_magic.h" +#include +#include "../algs.h" +#include +#include +#include + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + typedef DWORD thread_id_type; + + inline thread_id_type get_thread_id ( + ) + { + return GetCurrentThreadId(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // mutex object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + // forward declaration of signaler + class signaler; + + class mutex + { + public: + + mutex ( + ) + { + } + + ~mutex ( + ) { } + + void lock ( + ) const { cs.lock(); } + + void unlock ( + ) const { cs.unlock(); } + + private: + + friend class signaler; + + mutable std::mutex cs; + + // restricted functions + mutex(mutex&); // copy constructor + mutex& operator=(mutex&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // signaler object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class signaler + { + + public: + signaler ( + const mutex& associated_mutex + ) : + m(associated_mutex) + { + + } + + ~signaler ( + ) { } + + void wait ( + ) const + { + std::unique_lock cs(m.cs, std::defer_lock); + cv.wait(cs); + } + + bool wait_or_timeout ( + unsigned long milliseconds + ) const + { + std::unique_lock cs(m.cs, std::defer_lock); + auto status = cv.wait_until(cs, std::chrono::system_clock::now() + std::chrono::milliseconds(milliseconds)); + return status == std::cv_status::no_timeout; + } + + void signal ( + ) const + { + cv.notify_one(); + } + + void broadcast ( + ) const + { + cv.notify_all(); + } + + const mutex& get_mutex ( + ) const { return m; } + + private: + + mutable std::condition_variable cv; + + const mutex& m; + + // restricted functions + signaler(signaler&); // copy constructor + signaler& operator=(signaler&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + + namespace threads_kernel_shared_helpers + { + bool spawn_thread ( + void (*funct)(void*), + void* param + ); + /*! + is identical to create_new_thread() but just doesn't use any thread pooling. + !*/ + } + +// ---------------------------------------------------------------------------------------- + +} + +#include "threads_kernel_shared.h" + +#ifdef NO_MAKEFILE +#include "threads_kernel_1.cpp" +#endif + +#endif // DLIB_THREADS_KERNEl_1_ + diff --git a/ml/dlib/dlib/threads/threads_kernel_2.cpp b/ml/dlib/dlib/threads/threads_kernel_2.cpp new file mode 100644 index 000000000..06fb80d00 --- /dev/null +++ b/ml/dlib/dlib/threads/threads_kernel_2.cpp @@ -0,0 +1,75 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADS_KERNEL_2_CPp_ +#define DLIB_THREADS_KERNEL_2_CPp_ + +#include "../platform.h" + +#ifdef POSIX + +#include "threads_kernel_2.h" + + +namespace dlib +{ + namespace threads_kernel_shared_helpers + { + + // ----------------------------------------------------------------------------------- + + struct info + { + void* param; + void (*funct)(void*); + }; + + // ----------------------------------------------------------------------------------- + + void* thread_starter ( + void* param + ) + { + info* alloc_p = static_cast(param); + info p = *alloc_p; + delete alloc_p; + + // detach self + pthread_detach(pthread_self()); + + p.funct(p.param); + return 0; + } + + // ----------------------------------------------------------------------------------- + + bool spawn_thread ( + void (*funct)(void*), + void* param + ) + { + info* p; + try { p = new info; } + catch (...) { return false; } + + p->funct = funct; + p->param = param; + + pthread_t thread_id; + if ( pthread_create (&thread_id, 0, thread_starter, p) ) + { + delete p; + return false; + } + return true; + } + + // ----------------------------------------------------------------------------------- + + } + +} + +#endif // POSIX + +#endif // DLIB_THREADS_KERNEL_2_CPp_ + diff --git a/ml/dlib/dlib/threads/threads_kernel_2.h b/ml/dlib/dlib/threads/threads_kernel_2.h new file mode 100644 index 000000000..209142131 --- /dev/null +++ b/ml/dlib/dlib/threads/threads_kernel_2.h @@ -0,0 +1,180 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADS_KERNEl_2_ +#define DLIB_THREADS_KERNEl_2_ + +#ifdef DLIB_ISO_CPP_ONLY +#error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." +#endif + +#include "threads_kernel_abstract.h" +#include +#include +#include +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + typedef pthread_t thread_id_type; + + inline thread_id_type get_thread_id ( + ) + { + return pthread_self(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // mutex object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + // forward declaration of signaler + class signaler; + + class mutex + { + // give signaler access to hMutex + friend class signaler; + public: + + mutex ( + ) + { + if (pthread_mutex_init(&myMutex,0)) + { + throw dlib::thread_error(ECREATE_MUTEX, + "in function mutex::mutex() an error occurred making the mutex" + ); + } + } + + ~mutex ( + ) { pthread_mutex_destroy(&myMutex); } + + void lock ( + ) const { pthread_mutex_lock(&myMutex); } + + void unlock ( + ) const { pthread_mutex_unlock(&myMutex); } + + private: + + mutable pthread_mutex_t myMutex; + + // restricted functions + mutex(mutex&); // copy constructor + mutex& operator=(mutex&); // assignement opertor + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // signaler object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class signaler + { + + public: + + + signaler ( + const mutex& assoc_mutex + ) : + associated_mutex(&assoc_mutex.myMutex), + m(assoc_mutex) + { + if (pthread_cond_init(&cond,0)) + { + throw dlib::thread_error(ECREATE_SIGNALER, + "in function signaler::signaler() an error occurred making the signaler" + ); + } + } + + ~signaler ( + ) { pthread_cond_destroy(&cond); } + + void wait ( + ) const + { + pthread_cond_wait(&cond,associated_mutex); + } + + bool wait_or_timeout ( + unsigned long milliseconds + ) const + { + timespec time_to_wait; + + timeval curtime; + gettimeofday(&curtime,0); + + // get the time and adjust the timespec object by the appropriate amount + time_to_wait.tv_sec = milliseconds/1000 + curtime.tv_sec; + time_to_wait.tv_nsec = curtime.tv_usec; + time_to_wait.tv_nsec *= 1000; + time_to_wait.tv_nsec += (milliseconds%1000)*1000000; + + time_to_wait.tv_sec += time_to_wait.tv_nsec/1000000000; + time_to_wait.tv_nsec = time_to_wait.tv_nsec%1000000000; + + if ( pthread_cond_timedwait(&cond,associated_mutex,&time_to_wait) == ETIMEDOUT) + { + return false; + } + else + { + return true; + } + } + + void signal ( + ) const { pthread_cond_signal(&cond); } + + void broadcast ( + ) const { pthread_cond_broadcast(&cond); } + + const mutex& get_mutex ( + ) const { return m; } + + private: + + pthread_mutex_t* const associated_mutex; + mutable pthread_cond_t cond; + const mutex& m; + + // restricted functions + signaler(signaler&); // copy constructor + signaler& operator=(signaler&); // assignement opertor + }; + +// ---------------------------------------------------------------------------------------- + + namespace threads_kernel_shared_helpers + { + bool spawn_thread ( + void (*funct)(void*), + void* param + ); + /*! + is identical to create_new_thread() but just doesn't use any thread pooling. + !*/ + } + +// ---------------------------------------------------------------------------------------- + +} + +#include "threads_kernel_shared.h" + +#ifdef NO_MAKEFILE +#include "threads_kernel_2.cpp" +#endif + +#endif // DLIB_THREADS_KERNEl_2_ + diff --git a/ml/dlib/dlib/threads/threads_kernel_abstract.h b/ml/dlib/dlib/threads/threads_kernel_abstract.h new file mode 100644 index 000000000..d88d37dad --- /dev/null +++ b/ml/dlib/dlib/threads/threads_kernel_abstract.h @@ -0,0 +1,302 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_THREADS_KERNEl_ABSTRACT_ +#ifdef DLIB_THREADS_KERNEl_ABSTRACT_ + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*! + THREAD POOLING + When threads end they go into a global thread pool and each waits there + for 30 seconds before timing out and having its resources returned to the + operating system. When create_new_thread() is called it first looks in the + thread pool to see if there are any threads it can snatch from the pool, if + not then it makes a new one. + + Note that whenever I say something happens when a thread "terminates" or "ends" + I mean "when it returns to the thread pool." From the client programmer point + of view a thread terminates/ends when it returns to the dlib thread pool and you + shouldn't and indeed don't need to know when it actually gets its resources + reclaimed by the operating system. + + If you want to change the timeout to a different value you can #define + DLIB_THREAD_POOL_TIMEOUT to whatever value (in milliseconds) that you like. + + EXCEPTIONS + Unless specified otherwise, nothing in this file throws exceptions. + !*/ + +// ---------------------------------------------------------------------------------------- + + thread_id_type get_thread_id ( + ); + /*! + ensures + - returns a unique id for the calling thread. Note that while the id is unique + among all currently existing threads it may have been used by a previous + thread that has terminated. + !*/ + +// ---------------------------------------------------------------------------------------- + + bool is_dlib_thread ( + thread_id_type id = get_thread_id() + ); + /*! + ensures + - if (the thread with the given id was spawned by a call to + dlib::create_new_thread) then + - returns true + - else + - returns false + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void register_thread_end_handler ( + T& obj, + void (T::*handler)() + ); + /*! + requires + - handler == a valid member function pointer for class T + - handler does not throw + - handler does not call register_thread_end_handler() + - handler does not block + - is_dlib_thread() == true (i.e. the calling thread was spawned by dlib::create_new_thread()) + ensures + - let ID == the thread id for the thread calling register_thread_end_handler() + - (obj.*handler)() will be called when the thread with thread id ID is + terminating and it will be called from within that terminating thread. + (i.e. inside the handler function get_thread_id() == ID == the id of the + thread that is terminating. ) + - each call to this function adds another handler that will be called when + the given thread terminates. This means that if you call it a bunch of + times then you will end up registering multiple handlers (or single + handlers multiple times) that will be called when the thread ends. + throws + - std::bad_alloc + If this exception is thrown then the call to this function had no effect. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void unregister_thread_end_handler ( + T& obj, + void (T::*handler)() + ); + /*! + requires + - handler == a valid member function pointer for class T + ensures + - Undoes all previous calls to register_thread_end_handler(obj,handler). + So the given handler won't be called when any threads end. + throws + - std::bad_alloc + If this exception is thrown then the call to this function had no effect. + !*/ + +// ---------------------------------------------------------------------------------------- + + bool create_new_thread ( + void (*funct)(void*), + void* param + ); + /*! + ensures + - creates a new thread for the function pointed to by funct + - passes it param as its parameter. (i.e. calls funct(param) from the new thread) + - returns true upon success and false upon failure to create the new thread + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // mutex object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class mutex + { + /*! + INITIAL VALUE + mutex is in the unlocked state + + WHAT THIS OBJECT REPRESENTS + This object represents a mutex intended to be used for synchronous + thread control of shared data. When a thread wants to access some + shared data it locks out other threads by calling lock() and calls + unlock() when it is finished. + !*/ + public: + + mutex ( + ); + /*! + ensures + - #*this is properly initialized + throws + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create the mutex. + !*/ + + ~mutex ( + ); + /*! + requires + - *this is not locked + ensures + - all resources allocated by *this have been freed + !*/ + + void lock ( + ) const; + /*! + requires + - the thread calling lock() does not already have a lock on *this + ensures + - if (*this is currently locked by another thread) then + - the thread that called lock() on *this is put to sleep until + it becomes available + - if (*this is currently unlocked) then + - #*this becomes locked and the current thread is NOT put to sleep + but now "owns" #*this + !*/ + + void unlock ( + ) const; + /*! + requires + - the thread calling unlock() already has a lock on *this + ensures + - #*this is unlocked (i.e. other threads may now lock this object) + !*/ + + + private: + // restricted functions + mutex(mutex&); // copy constructor + mutex& operator=(mutex&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // signaler object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class signaler + { + /*! + + WHAT THIS OBJECT REPRESENTS + This object represents an event signaling system for threads. It gives + a thread the ability to wake up other threads that are waiting for a + particular signal. + + Each signaler object is associated with one and only one mutex object. + More than one signaler object may be associated with a single mutex + but a signaler object may only be associated with a single mutex. + + NOTE: + You must guard against spurious wakeups. This means that a thread + might return from a call to wait even if no other thread called + signal. This is rare but must be guarded against. + !*/ + public: + + signaler ( + const mutex& associated_mutex + ); + /*! + ensures + - #*this is properly initialized + - #get_mutex() == associated_mutex + throws + - dlib::thread_error + the constructor may throw this exception if there is a problem + gathering resources to create the signaler. + !*/ + + + ~signaler ( + ); + /*! + ensures + - all resources allocated by *this have been freed + !*/ + + void wait ( + ) const; + /*! + requires + - get_mutex() is locked and owned by the calling thread + ensures + - atomically unlocks get_mutex() and blocks the calling thread + - calling thread may wake if another thread calls signal() or broadcast() + on *this + - when wait() returns the calling thread again has a lock on get_mutex() + !*/ + + bool wait_or_timeout ( + unsigned long milliseconds + ) const; + /*! + requires + - get_mutex() is locked and owned by the calling thread + ensures + - atomically unlocks get_mutex() and blocks the calling thread + - calling thread may wake if another thread calls signal() or broadcast() + on *this + - after the specified number of milliseconds has elapsed the calling thread + will wake once get_mutex() is free + - when wait returns the calling thread again has a lock on get_mutex() + + - returns false if the call to wait_or_timeout timed out + - returns true if the call did not time out + !*/ + + + void signal ( + ) const; + /*! + ensures + - if (at least one thread is waiting on *this) then + - at least one of the waiting threads will wake + !*/ + + void broadcast ( + ) const; + /*! + ensures + - any and all threads waiting on *this will wake + !*/ + + const mutex& get_mutex ( + ) const; + /*! + ensures + - returns a const reference to the mutex associated with *this + !*/ + + private: + // restricted functions + signaler(signaler&); // copy constructor + signaler& operator=(signaler&); // assignment operator + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THREADS_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/threads/threads_kernel_shared.cpp b/ml/dlib/dlib/threads/threads_kernel_shared.cpp new file mode 100644 index 000000000..8e81193e9 --- /dev/null +++ b/ml/dlib/dlib/threads/threads_kernel_shared.cpp @@ -0,0 +1,318 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADS_KERNEL_SHARED_CPp_ +#define DLIB_THREADS_KERNEL_SHARED_CPp_ + +#include "threads_kernel_shared.h" +#include "../assert.h" +#include "../platform.h" +#include + + +#ifndef DLIB_THREAD_POOL_TIMEOUT +// default to 30000 milliseconds +#define DLIB_THREAD_POOL_TIMEOUT 30000 +#endif + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// threader functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace threads_kernel_shared + { + + bool thread_pool_has_been_destroyed = false; + +// ---------------------------------------------------------------------------------------- + + struct threader_destruct_helper + { + // cause the thread pool to begin its destruction process when + // global objects start to be destroyed + ~threader_destruct_helper() + { + thread_pool().destruct_if_ready(); + } + }; + +// ---------------------------------------------------------------------------------------- + + threader& thread_pool ( + ) + { + static threader* thread_pool = new threader; + static threader_destruct_helper a; + return *thread_pool; + } + +// ---------------------------------------------------------------------------------------- + + bool threader:: + is_dlib_thread ( + thread_id_type id + ) + { + auto_mutex M(data_mutex); + return thread_ids.is_member(id); + } + +// ---------------------------------------------------------------------------------------- + + threader:: + threader ( + ) : + total_count(0), + function_pointer(0), + pool_count(0), + data_ready(data_mutex), + data_empty(data_mutex), + destruct(false), + destructed(data_mutex), + do_not_ever_destruct(false) + { +#ifdef WIN32 + // Trying to destroy the global thread pool when we are part of a DLL and the + // DLL is being unloaded can sometimes lead to weird behavior. For example, in + // the python interpreter you will get the interpreter to hang. Or if we are + // part of a MATLAB mex file and the file is being unloaded there can also be + // similar weird issues. So when we are using dlib on windows we just disable + // the destruction of the global thread pool since it doesn't matter anyway. + // It's resources will just get freed by the OS. This is even the recommended + // thing to do by Microsoft (http://blogs.msdn.com/b/oldnewthing/archive/2012/01/05/10253268.aspx). + // + // As an aside, it's worth pointing out that the reason we try and free + // resources on program shutdown on other operating systems is so we can have + // clean reports from tools like valgrind which check for memory leaks. But + // trying to do this on windows is a lost cause so we give up in this case and + // follow the Microsoft recommendation. + do_not_ever_destruct = true; +#endif // WIN32 + } + +// ---------------------------------------------------------------------------------------- + + threader:: + ~threader ( + ) + { + data_mutex.lock(); + destruct = true; + data_ready.broadcast(); + + // wait for all the threads to end + while (total_count > 0) + destructed.wait(); + + thread_pool_has_been_destroyed = true; + data_mutex.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + void threader:: + destruct_if_ready ( + ) + { + if (do_not_ever_destruct) + return; + + data_mutex.lock(); + + // if there aren't any active threads, just maybe some sitting around + // in the pool then just destroy the threader + if (total_count == pool_count) + { + destruct = true; + data_ready.broadcast(); + data_mutex.unlock(); + delete this; + } + else + { + // There are still some user threads running so there isn't + // much we can really do. Just let the program end without + // cleaning up threading resources. + data_mutex.unlock(); + } + } + +// ---------------------------------------------------------------------------------------- + + void threader:: + call_end_handlers ( + ) + { + reg.m.lock(); + const thread_id_type id = get_thread_id(); + thread_id_type id_copy; + member_function_pointer<> mfp; + + // Remove all the member function pointers for this thread from the tree + // and call them. + while (reg.reg[id] != 0) + { + reg.reg.remove(id,id_copy,mfp); + reg.m.unlock(); + mfp(); + reg.m.lock(); + } + reg.m.unlock(); + } + + // ------------------------------------------------------------------------------------ + + bool threader:: + create_new_thread ( + void (*funct)(void*), + void* param + ) + { + + // get a lock on the data mutex + auto_mutex M(data_mutex); + + // loop to ensure that the new function pointer is in the data + while (true) + { + // if the data is empty then add new data and quit loop + if (function_pointer == 0) + { + parameter = param; + function_pointer = funct; + break; + } + else + { + // wait for data to become empty + data_empty.wait(); + } + } + + + // get a thread for this new data + // if a new thread must be created + if (pool_count == 0) + { + // make thread and add it to the pool + if ( threads_kernel_shared_helpers::spawn_thread(thread_starter, this) == false ) + { + function_pointer = 0; + parameter = 0; + data_empty.signal(); + return false; + } + ++total_count; + } + // wake up a thread from the pool + else + { + data_ready.signal(); + } + + return true; + } + + // ------------------------------------------------------------------------------------ + + void thread_starter ( + void* object + ) + { + // get a reference to the calling threader object + threader& self = *static_cast(object); + + + { + auto_mutex M(self.data_mutex); + + // add this thread id + thread_id_type thread_id = get_thread_id(); + self.thread_ids.add(thread_id); + + // indicate that this thread is now in the thread pool + ++self.pool_count; + + while (self.destruct == false) + { + // if data is ready then process it and launch the thread + // if its not ready then go back into the pool + while (self.function_pointer != 0) + { + // indicate that this thread is now out of the thread pool + --self.pool_count; + + // get the data for the function call + void (*funct)(void*) = self.function_pointer; + void* param = self.parameter; + self.function_pointer = 0; + + // signal that the data is now empty + self.data_empty.signal(); + + self.data_mutex.unlock(); + // Call funct with its intended parameter. If this function throws then + // we intentionally let the exception escape the thread and result in whatever + // happens when it gets caught by the OS (generally the program is terminated). + funct(param); + self.call_end_handlers(); + + self.data_mutex.lock(); + + // indicate that this thread is now back in the thread pool + ++self.pool_count; + } + + if (self.destruct == true) + break; + + // if we timed out and there isn't any work to do then + // this thread will quit this loop and end. + if (self.data_ready.wait_or_timeout(DLIB_THREAD_POOL_TIMEOUT) == false && + self.function_pointer == 0) + break; + + } + + // remove this thread id from thread_ids + thread_id = get_thread_id(); + self.thread_ids.destroy(thread_id); + + // indicate that this thread is now out of the thread pool + --self.pool_count; + --self.total_count; + + self.destructed.signal(); + + } // end of auto_mutex M(self.data_mutex) block + } + + // ------------------------------------------------------------------------------------ + + } + +// ---------------------------------------------------------------------------------------- + + bool is_dlib_thread ( + thread_id_type id + ) + { + return threads_kernel_shared::thread_pool().is_dlib_thread(id); + } + + bool is_dlib_thread ( + ) + { + return is_dlib_thread(get_thread_id()); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_THREADS_KERNEL_SHARED_CPp_ + diff --git a/ml/dlib/dlib/threads/threads_kernel_shared.h b/ml/dlib/dlib/threads/threads_kernel_shared.h new file mode 100644 index 000000000..b4526e8db --- /dev/null +++ b/ml/dlib/dlib/threads/threads_kernel_shared.h @@ -0,0 +1,274 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADS_KERNEl_SHARED_ +#define DLIB_THREADS_KERNEl_SHARED_ + +// this file should be included at the bottom of one of the thread kernel headers for a +// specific platform. +//#include "../threads.h" +#include "auto_mutex_extension.h" +#include "../binary_search_tree.h" +#include "../member_function_pointer.h" +#include "../memory_manager.h" +#include "../queue.h" +#include "../set.h" +#include "../test_for_odr_violations.h" + + + + + +namespace dlib +{ + + +// ---------------------------------------------------------------------------------------- + + namespace threads_kernel_shared + { + void thread_starter ( + void* + ); + + class threader + { + /*! + INITIAL VALUE + - pool_count == 0 and + - data_ready is associated with the mutex data_mutex + - data_empty is associated with the mutex data_mutex + - destructed is associated with the mutex data_mutex + - destruct == false + - total_count == 0 + - function_pointer == 0 + - do_not_ever_destruct == false + + CONVENTION + - data_ready is associated with the mutex data_mutex + - data_empty is associated with the mutex data_mutex + - data_ready == a signaler used signal when there is new data waiting + to start a thread with. + - data_empty == a signaler used to signal when the data is now empty + - pool_count == the number of suspended threads in the thread pool + - total_count == the number of threads that are executing anywhere. i.e. + pool_count + the ones that are currently running some user function. + - if (function_pointer != 0) then + - parameter == a void pointer pointing to the parameter which + should be used to start the next thread + - function_pointer == a pointer to the next function to make a + new thread with + + - if (the destructor is running) then + - destruct == true + - else + - destruct == false + + - thread_ids is locked by the data_mutex + - thread_ids == a set that contains the thread id for each thread spawned by this + object. + !*/ + + + public: + threader ( + ); + + ~threader ( + ); + + void destruct_if_ready ( + ); + /*! + ensures + - if (there are no threads currently running and we haven't set do_not_ever_destruct) then + - calls delete this + - else + - does nothing + !*/ + + bool create_new_thread ( + void (*funct)(void*), + void* param + ); + + template < + typename T + > + void unregister_thread_end_handler ( + T& obj, + void (T::*handler)() + ) + { + member_function_pointer<> mfp, junk_mfp; + mfp.set(obj,handler); + + thread_id_type junk_id; + + // find any member function pointers in the registry that point to the same + // thing as mfp and remove them + auto_mutex M(reg.m); + reg.reg.reset(); + while (reg.reg.move_next()) + { + while (reg.reg.current_element_valid() && reg.reg.element().value() == mfp) + { + reg.reg.remove_current_element(junk_id, junk_mfp); + } + } + } + + template < + typename T + > + void register_thread_end_handler ( + T& obj, + void (T::*handler)() + ) + { + thread_id_type id = get_thread_id(); + member_function_pointer<> mfp; + mfp.set(obj,handler); + + auto_mutex M(reg.m); + reg.reg.add(id,mfp); + } + + bool is_dlib_thread ( + thread_id_type id + ); + + private: + + friend void thread_starter ( + void* + ); + + void call_end_handlers ( + ); + /*! + ensures + - calls the registered end handlers for the calling thread and + then removes them from reg.reg + !*/ + + + // private data + set::kernel_2b>::kernel_1b_c thread_ids; + unsigned long total_count; + void* parameter; + void (*function_pointer)(void*); + unsigned long pool_count; + mutex data_mutex; // mutex to protect the above data + signaler data_ready; // signaler to signal when there is new data + signaler data_empty; // signaler to signal when the data is empty + bool destruct; + signaler destructed; // signaler to signal when a thread has ended + bool do_not_ever_destruct; + + struct registry_type + { + mutex m; + binary_search_tree< + thread_id_type, + member_function_pointer<>, + memory_manager::kernel_2a + >::kernel_2a_c reg; + }; + + // stuff for the register_thread_end_handler + registry_type reg; + + + // restricted functions + threader(threader&); // copy constructor + threader& operator=(threader&); // assignement opertor + + }; + + // ------------------------------------------------------------------------------------ + + threader& thread_pool ( + ); + /*! + ensures + - returns a reference to the global threader object + !*/ + + // ------------------------------------------------------------------------------------ + + extern bool thread_pool_has_been_destroyed; + } + + bool is_dlib_thread ( + thread_id_type id + ); + + bool is_dlib_thread ( + ); + +// ---------------------------------------------------------------------------------------- + + inline bool create_new_thread ( + void (*funct)(void*), + void* param + ) + { + try + { + // now make this thread + return threads_kernel_shared::thread_pool().create_new_thread(funct,param); + } + catch (std::bad_alloc&) + { + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + inline void register_thread_end_handler ( + T& obj, + void (T::*handler)() + ) + { + DLIB_ASSERT(is_dlib_thread(), + "\tvoid register_thread_end_handler" + << "\n\tYou can't register a thread end handler for a thread dlib didn't spawn." + ); + + threads_kernel_shared::thread_pool().register_thread_end_handler(obj,handler); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + inline void unregister_thread_end_handler ( + T& obj, + void (T::*handler)() + ) + { + // Check if the thread pool has been destroyed and if it has then don't do anything. + // This bool here is always true except when the program has started to terminate and + // the thread pool object has been destroyed. This if is here to catch other global + // objects that have destructors that try to call unregister_thread_end_handler(). + // Without this check we get into trouble if the thread pool is destroyed before these + // objects. + if (threads_kernel_shared::thread_pool_has_been_destroyed == false) + threads_kernel_shared::thread_pool().unregister_thread_end_handler(obj,handler); + } + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "threads_kernel_shared.cpp" +#endif + +#endif // DLIB_THREADS_KERNEl_SHARED_ + diff --git a/ml/dlib/dlib/threads/windows.h b/ml/dlib/dlib/threads/windows.h new file mode 100644 index 000000000..f7c775950 --- /dev/null +++ b/ml/dlib/dlib/threads/windows.h @@ -0,0 +1,6 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_THREADS_KERNEl_2_ +#include "threads_kernel_1.h" +#endif + diff --git a/ml/dlib/dlib/time_this.h b/ml/dlib/dlib/time_this.h new file mode 100644 index 000000000..aec0d2de8 --- /dev/null +++ b/ml/dlib/dlib/time_this.h @@ -0,0 +1,36 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TIME_THIs_ +#define DLIB_TIME_THIs_ + + +#include + +// ---------------------------------------------------------------------------------------- + +#define TIME_THIS_TO(_tt_op,_tt_out) \ + { \ + auto _tt_start = std::chrono::high_resolution_clock::now(); \ + {_tt_op;} \ + auto _tt_stop = std::chrono::high_resolution_clock::now(); \ + auto _tt_thetime = _tt_stop-_tt_start; \ + using std::chrono::duration_cast; \ + using std::chrono::duration; \ + if (_tt_thetime >= std::chrono::minutes(1)) \ + _tt_out << "\ntime: " << duration_cast>>(_tt_thetime).count() << "min\n"; \ + else if (_tt_thetime >= std::chrono::seconds(1)) \ + _tt_out << "\ntime: " << duration_cast>(_tt_thetime).count() << "sec\n"; \ + else if (_tt_thetime >= std::chrono::milliseconds(1)) \ + _tt_out << "\ntime: " << duration_cast>(_tt_thetime).count() << "ms\n"; \ + else if (_tt_thetime >= std::chrono::microseconds(1)) \ + _tt_out << "\ntime: " << duration_cast>(_tt_thetime).count() << "us\n"; \ + else \ + _tt_out << "\ntime: " << duration_cast>(_tt_thetime).count() << "ns\n"; \ + } + +#define TIME_THIS(_tt_op) TIME_THIS_TO(_tt_op,std::cout) + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_TIME_THIs_ + diff --git a/ml/dlib/dlib/timeout.h b/ml/dlib/dlib/timeout.h new file mode 100644 index 000000000..2c3590e46 --- /dev/null +++ b/ml/dlib/dlib/timeout.h @@ -0,0 +1,10 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TIMEOUt_ +#define DLIB_TIMEOUt_ + +#include "timeout/timeout.h" + +#endif // DLIB_TIMEOUt_ + + diff --git a/ml/dlib/dlib/timeout/timeout.h b/ml/dlib/dlib/timeout/timeout.h new file mode 100644 index 000000000..663e0ca65 --- /dev/null +++ b/ml/dlib/dlib/timeout/timeout.h @@ -0,0 +1,200 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TIMEOUT_KERNEl_1_ +#define DLIB_TIMEOUT_KERNEl_1_ + +#include "../threads.h" +#include "../algs.h" +#include "../misc_api.h" +#include "timeout_abstract.h" +#include "../uintn.h" +#include "../timer.h" + +#ifdef _MSC_VER +// this is to disable the "'this' : used in base member initializer list" +// warning you get from some of the GUI objects since all the objects +// require that their parent class be passed into their constructor. +// In this case though it is totally safe so it is ok to disable this warning. +#pragma warning(disable : 4355) +#endif // _MSC_VER + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class timeout + { + /*! + INITIAL VALUE + - b == a pointer to some kind of bind object + + CONVENTION + - b == a pointer to some kind of bind object + !*/ + + class bind + { + public: + virtual void go() = 0; + virtual ~bind() {} + }; + + template + class functor : public bind + { + public: + functor(const T& f) : function(f) {} + T function; + void go() { function(); } + }; + + template + class zero : public bind + { + public: + T* object; + R (T::*callback_function)(); + void go() { (object->*callback_function)(); } + + }; + + template + class one : public bind + { + public: + T* object; + R (T::*callback_function)(U); + U val; + void go() { (object->*callback_function)(val); } + }; + + public: + + // This typedef is here for backwards compatibility with previous versions of dlib. + typedef timeout kernel_1a; + + template < + typename T + > + timeout ( + T callback_function, + unsigned long ms_to_timeout + ) : + t(*this,&timeout::trigger_timeout) + { + b = new functor(callback_function); + t.set_delay_time(ms_to_timeout); + t.start(); + } + + template < + typename T + > + timeout ( + T& object, + void (T::*callback_function)(), + unsigned long ms_to_timeout + ): + t(*this,&timeout::trigger_timeout) + { + zero* B = new zero; + b = B; + B->object = &object; + B->callback_function = callback_function; + t.set_delay_time(ms_to_timeout); + t.start(); + } + + template < + typename T, + typename U + > + timeout ( + T& object, + void (T::*callback_function)(U callback_function_argument), + unsigned long ms_to_timeout, + U callback_function_argument + ): + t(*this,&timeout::trigger_timeout) + { + one* B = new one; + b = B; + B->object = &object; + B->callback_function = callback_function; + B->val = callback_function_argument; + t.set_delay_time(ms_to_timeout); + t.start(); + } + + template < + typename T + > + timeout ( + T& object, + int (T::*callback_function)(), + unsigned long ms_to_timeout + ): + t(*this,&timeout::trigger_timeout) + { + zero* B = new zero; + b = B; + B->object = &object; + B->callback_function = callback_function; + t.set_delay_time(ms_to_timeout); + t.start(); + } + + template < + typename T, + typename U + > + timeout ( + T& object, + int (T::*callback_function)(U callback_function_argument), + unsigned long ms_to_timeout, + U callback_function_argument + ): + t(*this,&timeout::trigger_timeout) + { + one* B = new one; + b = B; + B->object = &object; + B->callback_function = callback_function; + B->val = callback_function_argument; + t.set_delay_time(ms_to_timeout); + t.start(); + } + + virtual ~timeout ( + ) + { + t.stop_and_wait(); + delete b; + } + + private: + + void trigger_timeout () + { + b->go(); + t.stop(); + } + + dlib::timer t; + bind* b; + + // restricted functions + timeout(const timeout&); // copy constructor + timeout& operator=(const timeout&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TIMEOUT_KERNEl_1_ + + + diff --git a/ml/dlib/dlib/timeout/timeout_abstract.h b/ml/dlib/dlib/timeout/timeout_abstract.h new file mode 100644 index 000000000..c2401bb68 --- /dev/null +++ b/ml/dlib/dlib/timeout/timeout_abstract.h @@ -0,0 +1,188 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_TIMEOUT_KERNEl_ABSTRACT_ +#ifdef DLIB_TIMEOUT_KERNEl_ABSTRACT_ + +#include "../threads.h" + +namespace dlib +{ + + class timeout + { + /*! + WHAT THIS OBJECT REPRESENTS + This object provides a simple way to implement a timeout. An example will + make its use clear. Suppose we want to read from a socket but we want to + terminate the connection if the read takes longer than 10 seconds. This + could be accomplished as follows: + + connection* con = a connection from somewhere; + { + // setup a timer that will call con->shutdown() in 10 seconds + timeout t(*con,&connection::shutdown,10000); + // Now call read on the connection. If this call to read() takes more + // than 10 seconds then the t timeout will trigger and shutdown the + // connection. If read completes in less than 10 seconds then the t + // object will be destructed on the next line due to the } and then the + // timeout won't trigger. + con->read(buf,100); + } + + + Alternatively, if you have a compiler capable of using C++11 lambda + functions, you can use a syntax like this: + { + timeout t([con](){ con->shutdown(); }, 10000); + con->read(buf,100); + } + + More generally, you can use this with things other than sockets. For + example, the following statement will print "Hello world!" after 1000ms: + timeout t([](){ cout << "Hello world!" << endl; }, 1000); + + + + THREAD SAFETY + All methods of this class are thread safe. + !*/ + + public: + + template < + typename T + > + timeout ( + T callback_function, + unsigned long ms_to_timeout + ); + /*! + requires + - callback_function does not throw + ensures + - does not block. + - #*this is properly initialized + - if (this object isn't destructed in ms_to_timeout milliseconds) then + - callback_function() will be called in ms_to_timeout milliseconds. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + template < + typename T + > + timeout ( + T& object, + void (T::*callback_function)(), + unsigned long ms_to_timeout + ); + /*! + requires + - callback_function does not throw + ensures + - does not block. + - #*this is properly initialized + - if (this object isn't destructed in ms_to_timeout milliseconds) then + - (object.*callback_function)() will be called in ms_to_timeout + milliseconds. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + template < + typename T, + typename U + > + timeout ( + T& object, + void (T::*callback_function)(U callback_function_argument), + unsigned long ms_to_timeout, + U callback_function_argument + ); + /*! + requires + - callback_function does not throw + ensures + - does not block. + - #*this is properly initialized + - if (this object isn't destructed in ms_to_timeout milliseconds) then + - (object.*callback_function)(callback_function_argument) will be + called in ms_to_timeout milliseconds. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + template < + typename T + > + timeout ( + T& object, + int (T::*callback_function)(), + unsigned long ms_to_timeout + ); + /*! + requires + - callback_function does not throw + ensures + - does not block. + - #*this is properly initialized + - if (this object isn't destructed in ms_to_timeout milliseconds) then + - (object.*callback_function)() will be called in ms_to_timeout + milliseconds. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + template < + typename T, + typename U + > + timeout ( + T& object, + int (T::*callback_function)(U callback_function_argument), + unsigned long ms_to_timeout, + U callback_function_argument + ); + /*! + requires + - callback_function does not throw + ensures + - does not block. + - #*this is properly initialized + - if (this object isn't destructed in ms_to_timeout milliseconds) then + - (object.*callback_function)(callback_function_argument) will be + called in ms_to_timeout milliseconds. + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~timeout ( + ); + /*! + requires + - is not called from inside the callback_function given to the + constructor. + ensures + - any resources associated with *this have been released + - if (the callback_function hasn't been called yet) then + - the callback_function specified in the constructor will not be called + !*/ + + private: + + // restricted functions + timeout(const timeout&); // copy constructor + timeout& operator=(const timeout&); // assignment operator + + }; + +} + +#endif // DLIB_TIMEOUT_KERNEl_ABSTRACT_ + + diff --git a/ml/dlib/dlib/timer.h b/ml/dlib/dlib/timer.h new file mode 100644 index 000000000..918078e4d --- /dev/null +++ b/ml/dlib/dlib/timer.h @@ -0,0 +1,10 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TIMEr_ +#define DLIB_TIMEr_ + +#include "timer/timer.h" +#include "timer/timer_heavy.h" + +#endif // DLIB_TIMEr_ + diff --git a/ml/dlib/dlib/timer/timer.cpp b/ml/dlib/dlib/timer/timer.cpp new file mode 100644 index 000000000..d12aa6f29 --- /dev/null +++ b/ml/dlib/dlib/timer/timer.cpp @@ -0,0 +1,235 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TIMER_cPPh_ +#define DLIB_TIMER_cPPh_ + +#include "timer.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + timer_global_clock:: + timer_global_clock( + ): + s(m), + shutdown(false), + running(false) + { + } + +// ---------------------------------------------------------------------------------------- + + timer_global_clock:: + ~timer_global_clock() + { + // The only time this destructor is called is when + // + // a) the process terminates + // b) the dynamic library(.so/.dll) is unloaded (could be a part of a)) + // + // in case of a) + // windows: the process termination is especially painful, since threads are killed + // before destructors of the process image .dll's are called. + // Thus, for the windows platform, there is no threads running, so the only thing + // to do here is just let the standard memberwise destructors run + // linux: it's ok to just signal shutdown and wait for the running thread, to exit + // + // in case of b) + // windows: + // if it's part of the termination process, a) applies + // if its part of user doing manual load_library/unload_library + // there is no (safe/robust)solution, but best practices are described here + // https://msdn.microsoft.com/en-us/library/windows/desktop/dn633971.aspx + // to support such a clean shutdown, you are required to make a call prior to + // unload dll, that shutdown all the threads in the contained dll. + // This could be done in this module by providing a global_delete_clock() + // + // linux: the destructor for linux will do it's usual job regardless. + // + + #ifndef _WIN32 + m.lock(); + shutdown = true; + s.signal(); + m.unlock(); + wait(); + #endif + } + +// ---------------------------------------------------------------------------------------- + + void timer_global_clock:: + add ( + timer_base* r + ) + { + if (r->in_global_clock == false) + { + // if the thread isn't running then start it up + if (!running) + { + start(); + running = true; + } + + uint64 t = ts.get_timestamp() + r->delay*1000; + tm.reset(); + if (!tm.move_next() || t < tm.element().key()) + { + // we need to make the thread adjust its next time to + // trigger if this new event occurrs sooner than the + // next event in tm + s.signal(); + } + timer_base* rtemp = r; + uint64 ttemp = t; + tm.add(ttemp,rtemp); + r->next_time_to_run = t; + r->in_global_clock = true; + } + } + +// ---------------------------------------------------------------------------------------- + + void timer_global_clock:: + remove ( + timer_base* r + ) + { + if (r->in_global_clock) + { + tm.position_enumerator(r->next_time_to_run-1); + do + { + if (tm.element().value() == r) + { + uint64 t; + timer_base* rtemp; + tm.remove_current_element(t,rtemp); + r->in_global_clock = false; + break; + } + } while (tm.move_next()); + } + } + +// ---------------------------------------------------------------------------------------- + + void timer_global_clock:: + adjust_delay ( + timer_base* r, + unsigned long new_delay + ) + { + if (r->in_global_clock) + { + remove(r); + // compute the new next_time_to_run and store it in t + uint64 t = r->next_time_to_run; + t -= r->delay*1000; + t += new_delay*1000; + + tm.reset(); + if (!tm.move_next() || t < tm.element().key()) + { + // we need to make the thread adjust its next time to + // trigger if this new event occurrs sooner than the + // next event in tm + s.signal(); + } + + // set this incase add throws + r->running = false; + r->delay = new_delay; + + timer_base* rtemp = r; + uint64 ttemp = t; + tm.add(ttemp,rtemp); + r->next_time_to_run = t; + r->in_global_clock = true; + + // put this back now that we know add didn't throw + r->running = true; + + } + else + { + r->delay = new_delay; + } + } + +// ---------------------------------------------------------------------------------------- + + void timer_global_clock:: + thread() + { + auto_mutex M(m); + while (!shutdown) + { + unsigned long delay = 100000; + + tm.reset(); + tm.move_next(); + // loop and start all the action functions for timers that should have + // triggered. + while(tm.current_element_valid()) + { + const uint64 cur_time = ts.get_timestamp(); + uint64 t = tm.element().key(); + // if the next event in tm is ready to trigger + if (t <= cur_time + 999) + { + // remove this event from the tm map + timer_base* r = tm.element().value(); + timer_base* rtemp; + tm.remove_current_element(t,rtemp); + r->in_global_clock = false; + + // if this timer is still "running" then start its action function + if (r->running) + { + r->restart(); + } + } + else + { + // there aren't any more timers that should trigger so we compute + // the delay to the next timer event. + delay = static_cast((t - cur_time)/1000); + break; + } + } + + s.wait_or_timeout(delay); + } + } + +// ---------------------------------------------------------------------------------------- + + std::shared_ptr get_global_clock() + { + static std::shared_ptr d(new timer_global_clock); + return d; + } + +// ---------------------------------------------------------------------------------------- + + // do this just to make sure get_global_clock() gets called at program startup + class timer_global_clock_helper + { + public: + timer_global_clock_helper() + { + get_global_clock(); + } + }; + static timer_global_clock_helper call_get_global_clock; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TIMER_cPPh_ + diff --git a/ml/dlib/dlib/timer/timer.h b/ml/dlib/dlib/timer/timer.h new file mode 100644 index 000000000..c8b6c0af6 --- /dev/null +++ b/ml/dlib/dlib/timer/timer.h @@ -0,0 +1,427 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TIMEr_Hh_ +#define DLIB_TIMEr_Hh_ + +#include + +#include "../threads.h" +#include "../algs.h" +#include "../misc_api.h" +#include "timer_abstract.h" +#include "../uintn.h" +#include "../binary_search_tree.h" +#include "timer_heavy.h" + +namespace dlib +{ + + struct timer_base : public threaded_object + { + /*! + WHAT THIS OBJECT REPRESENTS + This object contains the base members of the timer object. + It exists so that we can access them from outside any templated functions. + !*/ + + unsigned long delay; + // these are only modified by the global_clock + uint64 next_time_to_run; + timestamper ts; + bool running; + bool in_global_clock; + }; + +// ---------------------------------------------------------------------------------------- + + class timer_global_clock : private threaded_object + { + /*! + This object sets up a timer that triggers the action function + for timer objects that are tracked inside this object. + INITIAL VALUE + - shutdown == false + - running == false + + CONVENTION + - if (shutdown) then + - thread() should terminate + - else (running) then + - thread() is running + + - tm[time] == pointer to a timer_base object + !*/ + typedef binary_search_tree::kernel_2b>::kernel_2a_c time_map; + public: + + ~timer_global_clock(); + + void add ( + timer_base* r + ); + /*! + requires + - m is locked + ensures + - starts the thread if it isn't already started + - adds r to tm + - #r->in_global_clock == true + - updates r->next_time_to_run appropriately according to + r->delay + !*/ + + void remove ( + timer_base* r + ); + /*! + requires + - m is locked + ensures + - if (r is in tm) then + - removes r from tm + - #r->in_global_clock == false + !*/ + + void adjust_delay ( + timer_base* r, + unsigned long new_delay + ); + /*! + requires + - m is locked + ensures + - #r->delay == new_delay + - if (r->in_global_clock) then + - the time to the next event will have been appropriately adjusted + !*/ + + mutex m; + + friend std::shared_ptr get_global_clock(); + + private: + timer_global_clock(); + + time_map tm; + signaler s; + bool shutdown; + bool running; + timestamper ts; + + void thread(); + /*! + ensures + - spawns timer tasks as is appropriate + !*/ + }; + std::shared_ptr get_global_clock(); + /*! + ensures + - returns the global instance of the timer_global_clock object + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class timer : private timer_base + { + /*! + INITIAL VALUE + - running == false + - delay == 1000 + - ao == a pointer to the action_object() + - af == a pointer to the action_function() + - in_global_clock == false + - next_time_to_run == 0 + - gc == get_global_clock() + + CONVENTION + - the mutex used to lock everything is gc->m + - running == is_running() + - delay == delay_time() + - *ao == action_object() + - af == action_function() + - if (!running) then + - in_global_clock == false + - else + - next_time_to_run == the next time this timer should run according + to the timestamper in the global_clock + !*/ + + public: + + // These typedefs are here for backwards compatibility with previous versions of + // dlib. + typedef timer_heavy kernel_1a; + typedef timer kernel_2a; + + typedef void (T::*af_type)(); + + timer( + T& ao_, + af_type af_ + ); + + virtual ~timer( + ); + + void clear( + ); + + af_type action_function ( + ) const; + + const T& action_object ( + ) const; + + T& action_object ( + ); + + bool is_running ( + ) const; + + unsigned long delay_time ( + ) const; + + void set_delay_time ( + unsigned long milliseconds + ); + + void start ( + ); + + void stop ( + ); + + void stop_and_wait ( + ); + + private: + + void thread ( + ); + /*! + ensures + - calls the action function + !*/ + + // data members + T& ao; + const af_type af; + std::shared_ptr gc; + + // restricted functions + timer(const timer&); // copy constructor + timer& operator=(const timer&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + timer:: + timer( + T& ao_, + af_type af_ + ) : + ao(ao_), + af(af_), + gc(get_global_clock()) + { + delay = 1000; + next_time_to_run = 0; + running = false; + in_global_clock = false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + timer:: + ~timer( + ) + { + clear(); + wait(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer:: + clear( + ) + { + auto_mutex M(gc->m); + running = false; + gc->remove(this); + delay = 1000; + next_time_to_run = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename timer::af_type timer:: + action_function ( + ) const + { + return af; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const T& timer:: + action_object ( + ) const + { + return ao; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T& timer:: + action_object ( + ) + { + return ao; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool timer:: + is_running ( + ) const + { + auto_mutex M(gc->m); + return running; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + unsigned long timer:: + delay_time ( + ) const + { + auto_mutex M(gc->m); + return delay; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer:: + set_delay_time ( + unsigned long milliseconds + ) + { + auto_mutex M(gc->m); + gc->adjust_delay(this,milliseconds); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer:: + start ( + ) + { + auto_mutex M(gc->m); + if (!running) + { + gc->add(this); + running = true; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer:: + stop ( + ) + { + gc->m.lock(); + running = false; + gc->remove(this); + gc->m.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer:: + thread ( + ) + { + // call the action function + (ao.*af)(); + auto_mutex M(gc->m); + if (running) + { + gc->remove(this); + gc->add(this); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer:: + stop_and_wait ( + ) + { + gc->m.lock(); + running = false; + gc->remove(this); + gc->m.unlock(); + wait(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "timer.cpp" +#endif + +#endif // DLIB_TIMEr_Hh_ + + diff --git a/ml/dlib/dlib/timer/timer_abstract.h b/ml/dlib/dlib/timer/timer_abstract.h new file mode 100644 index 000000000..180cd490d --- /dev/null +++ b/ml/dlib/dlib/timer/timer_abstract.h @@ -0,0 +1,190 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_TIMER_KERNEl_ABSTRACT_ +#ifdef DLIB_TIMER_KERNEl_ABSTRACT_ + +#include "../threads.h" + +namespace dlib +{ + + template < + typename T + > + class timer + { + /*! + INITIAL VALUE + is_running() == false + delay_time() == 1000 + action_object() == The object that is passed into the constructor + action_function() == The member function pointer that is passed to + the constructor. + + WHAT THIS OBJECT REPRESENTS + This object represents a timer that will call a given member function + (the action function) repeatedly at regular intervals and in its own + thread. + + Note that the delay_time() is measured in milliseconds but you are not + guaranteed to have that level of resolution. The actual resolution + is implementation dependent. + + THREAD SAFETY + All methods of this class are thread safe. + !*/ + + public: + + typedef void (T::*af_type)(); + + timer ( + T& ao, + af_type af + ); + /*! + requires + - af does not throw + ensures + - does not block. + - #*this is properly initialized + - #action_object() == ao + - #action_function() == af + (af is a member function pointer to a member in the class T) + throws + - std::bad_alloc + - dlib::thread_error + !*/ + + virtual ~timer ( + ); + /*! + requires + - is not called from inside the action_function() + ensures + - any resources associated with *this have been released + - will not call the action_function() anymore. + - if (the action function is currently executing) then + - blocks until it finishes + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + - does not block + throws + - std::bad_alloc or dlib::thread_error + If either of these exceptions are thrown then #*this is unusable + until clear() is called and succeeds. + !*/ + + af_type action_function ( + ) const; + /*! + ensures + - does not block. + - returns a pointer to the member function of action_object() that is + called by *this. + !*/ + + const T& action_object ( + ) const; + /*! + ensures + - does not block. + - returns a const reference to the object used to call the member + function pointer action_function() + !*/ + + T& action_object ( + ); + /*! + ensures + - does not block. + - returns a non-const reference to the object used to call the member + function pointer action_function() + !*/ + + bool is_running ( + ) const; + /*! + ensures + - does not block. + - if (*this is currently scheduled to call the action_function()) then + - returns true + - else + - returns false + !*/ + + unsigned long delay_time ( + ) const; + /*! + ensures + - does not block. + - returns the amount of time, in milliseconds, that *this will wait between + the return of one call to the action_function() and the beginning of the + next call to the action_function(). + !*/ + + void set_delay_time ( + unsigned long milliseconds + ); + /*! + ensures + - does not block. + - #delay_time() == milliseconds + throws + - std::bad_alloc or dlib::thread_error + If either of these exceptions are thrown then #is_running() == false + but otherwise this function succeeds + !*/ + + void start ( + ); + /*! + ensures + - does not block. + - if (is_running() == false) then + - #is_running() == true + - The action_function() will run in another thread. + - The first call to the action_function() will occur in roughly + delay_time() milliseconds. + - else + - this call to start() has no effect + throws + - dlib::thread_error or std::bad_alloc + If this exception is thrown then #is_running() == false but + otherwise this call to start() has no effect. + !*/ + + void stop ( + ); + /*! + ensures + - #is_running() == false + - does not block. + !*/ + + void stop_and_wait ( + ); + /*! + ensures + - #is_running() == false + - if (the action function is currently executing) then + - blocks until it finishes + !*/ + + private: + + // restricted functions + timer(const timer&); // copy constructor + timer& operator=(const timer&); // assignment operator + + }; + +} + +#endif // DLIB_TIMER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/timer/timer_heavy.h b/ml/dlib/dlib/timer/timer_heavy.h new file mode 100644 index 000000000..693b91ad9 --- /dev/null +++ b/ml/dlib/dlib/timer/timer_heavy.h @@ -0,0 +1,392 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TIMER_KERNEl_1_ +#define DLIB_TIMER_KERNEl_1_ + +#include "../threads.h" +#include "../algs.h" +#include "../misc_api.h" +#include "timer_abstract.h" + +namespace dlib +{ + + template < + typename T + > + class timer_heavy + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the timer_abstract.h interface. It is very + simple and uses only one thread which is always alive in a timer_heavy. + The reason this object exists is for historical reasons. Originally, the + dlib::timer was a multi-implementation component and the timer_heavy was + its first implementation. It was superseded later by the more efficient + dlib::timer. However, timer_heavy is still around so that + dlib::timer::kernel_1a has something to refer to. This way, old client + code which somehow depends on the same thread always calling a timer action + function isn't going to be disrupted. + + + INITIAL VALUE + - running == false + - delay == 1000 + - ao == a pointer to the action_object() + - af == a pointer to the action_function() + - m == a mutex that locks everything in this class + - s == a signaler for mutex m + - stop_running == false + + CONVENTION + - running && !stop_running == is_running() + - delay == delay_time() + - *ao == action_object() + - af == action_function() + + - if (running) then + - there is a thread running + - if (is_running()) then + - next_time_to_run == the time when the next execution of the action + function should occur. (the time is given by ts.get_timestamp()) + + - stop_running is used to tell the thread to quit. If it is + set to true then the thread should end. + !*/ + + public: + + typedef void (T::*af_type)(); + + timer_heavy( + T& ao_, + af_type af_ + ); + + virtual ~timer_heavy( + ); + + void clear( + ); + + af_type action_function ( + ) const; + + const T& action_object ( + ) const; + + T& action_object ( + ); + + bool is_running ( + ) const; + + unsigned long delay_time ( + ) const; + + void set_delay_time ( + unsigned long milliseconds + ); + + void start ( + ); + + void stop ( + ); + + void stop_and_wait ( + ); + + private: + + void thread ( + ); + /*! + requires + - is run in its own thread + ensures + - calls the action function for the given timer object in the manner + specified by timer_kernel_abstract.h + !*/ + + // data members + T& ao; + const af_type af; + unsigned long delay; + mutex m; + signaler s; + + bool running; + bool stop_running; + timestamper ts; + uint64 next_time_to_run; + + // restricted functions + timer_heavy(const timer_heavy&); // copy constructor + timer_heavy& operator=(const timer_heavy&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + timer_heavy:: + timer_heavy( + T& ao_, + af_type af_ + ) : + ao(ao_), + af(af_), + delay(1000), + s(m), + running(false), + stop_running(false) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + timer_heavy:: + ~timer_heavy( + ) + { + stop_and_wait(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer_heavy:: + clear( + ) + { + m.lock(); + stop_running = true; + delay = 1000; + s.broadcast(); + m.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename timer_heavy::af_type timer_heavy:: + action_function ( + ) const + { + return af; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const T& timer_heavy:: + action_object ( + ) const + { + return ao; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T& timer_heavy:: + action_object ( + ) + { + return ao; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + bool timer_heavy:: + is_running ( + ) const + { + auto_mutex M(m); + return running && !stop_running; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + unsigned long timer_heavy:: + delay_time ( + ) const + { + auto_mutex M(m); + return delay; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer_heavy:: + set_delay_time ( + unsigned long milliseconds + ) + { + m.lock(); + + // if (is_running()) then we should adjust next_time_to_run + if (running && !stop_running) + { + next_time_to_run -= delay*1000; + next_time_to_run += milliseconds*1000; + } + + delay = milliseconds; + s.broadcast(); + m.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer_heavy:: + start ( + ) + { + auto_mutex M(m); + + // if (is_running() == false) then reset the countdown to the next call + // to the action_function() + if ( (running && !stop_running) == false) + next_time_to_run = ts.get_timestamp() + delay*1000; + + stop_running = false; + if (running == false) + { + running = true; + + // start the thread + if (create_new_thread(*this) == false) + { + running = false; + throw dlib::thread_error("error creating new thread in timer_heavy::start"); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer_heavy:: + stop ( + ) + { + m.lock(); + stop_running = true; + s.broadcast(); + m.unlock(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer_heavy:: + thread ( + ) + { + auto_mutex M(m); + unsigned long delay_remaining; + uint64 current_time = ts.get_timestamp(); + + if (current_time < next_time_to_run) + delay_remaining = static_cast((next_time_to_run-current_time)/1000); + else + delay_remaining = 0; + + while (stop_running == false) + { + if (delay_remaining > 0) + s.wait_or_timeout(delay_remaining); + + if (stop_running) + break; + + current_time = ts.get_timestamp(); + if (current_time < next_time_to_run) + { + // then we woke up too early so we should keep waiting + delay_remaining = static_cast((next_time_to_run-current_time)/1000); + + // rounding might make this be zero anyway. So if it is + // then we will say we have hit the next time to run. + if (delay_remaining > 0) + continue; + } + + // call the action function + m.unlock(); + (ao.*af)(); + m.lock(); + + current_time = ts.get_timestamp(); + next_time_to_run = current_time + delay*1000; + delay_remaining = delay; + } + running = false; + stop_running = false; + s.broadcast(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void timer_heavy:: + stop_and_wait ( + ) + { + m.lock(); + if (running) + { + // make the running thread terminate + stop_running = true; + + s.broadcast(); + // wait for the thread to quit + while (running) + s.wait(); + } + m.unlock(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TIMER_KERNEl_1_ + diff --git a/ml/dlib/dlib/timing.h b/ml/dlib/dlib/timing.h new file mode 100644 index 000000000..2d5116ade --- /dev/null +++ b/ml/dlib/dlib/timing.h @@ -0,0 +1,196 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TImING_Hh_ +#define DLIB_TImING_Hh_ + +#include +#include +#include +#include "string.h" + +#include + +// ---------------------------------------------------------------------------------------- + +/*!A timing + + This set of functions is useful for determining how much time is spent + executing blocks of code. Consider the following example: + + int main() + { + using namespace dlib::timing; + for (int i = 0; i < 10; ++i) + { + // timing block #1 + start(1,"block #1"); + dlib::sleep(500); + stop(1); + + // timing block #2 + start(2,"block #2"); + dlib::sleep(1000); + stop(2); + } + + print(); + } + + This program would output: + Timing report: + block #1: 5.0 seconds + block #2: 10.0 seconds + + So we spent 5 seconds in block #1 and 10 seconds in block #2 + + + + Additionally, note that you can use an RAII style timing block object. For + example, if we wanted to find out how much time we spent in a loop a convenient + way to do this would be as follows: + + int main() + { + using namespace dlib::timing; + for (int i = 0; i < 10; ++i) + { + block tb(1, "main loop"); + + dlib::sleep(1500); + } + + print(); + } + + This program would output: + Timing report: + block main loop: 15.0 seconds + +!*/ + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + namespace timing + { + const int TIME_SLOTS = 500; + const int NAME_LENGTH = 40; + + inline std::atomic* time_buf() + { + static std::atomic buf[TIME_SLOTS]; + return buf; + } + + inline char* name_buf(int i, const char* name) + { + static char buf[TIME_SLOTS][NAME_LENGTH] = {{0}}; + // if this name buffer is empty then copy name into it + if (buf[i][0] == '\0') + { + using namespace std; + strncpy(buf[i], name, NAME_LENGTH-1); + buf[i][NAME_LENGTH-1] = '\0'; + } + // return the name buffer + return buf[i]; + } + + inline uint64_t ts() + { + using namespace std::chrono; + return duration_cast>(high_resolution_clock::now().time_since_epoch()).count(); + } + + inline void start(int i) + { + time_buf()[i] -= ts(); + } + + inline void start(int i, const char* name) + { + time_buf()[i] -= ts(); + name_buf(i,name); + } + + inline void stop(int i) + { + time_buf()[i] += ts(); + } + + inline void print() + { + using namespace std; + cout << "Timing report: " << endl; + + // figure out how long the longest name is going to be. + unsigned long max_name_length = 0; + for (int i = 0; i < TIME_SLOTS; ++i) + { + string name; + // Check if the name buffer is empty. Use the name it contains if it isn't. + if (name_buf(i,"")[0] != '\0') + name = cast_to_string(i) + ": " + name_buf(i,""); + else + name = cast_to_string(i); + max_name_length = std::max(max_name_length, name.size()); + } + + for (int i = 0; i < TIME_SLOTS; ++i) + { + if (time_buf()[i] != 0) + { + double time = time_buf()[i]/1000.0/1000.0; + string name; + // Check if the name buffer is empty. Use the name it contains if it isn't. + if (name_buf(i,"")[0] != '\0') + name = cast_to_string(i) + ": " + name_buf(i,""); + else + name = cast_to_string(i); + + // make sure the name is always the same length. Do so by padding with spaces + if (name.size() < max_name_length) + name += string(max_name_length-name.size(),' '); + + if (time < 1000) + cout << " " << name << ": " << time << " milliseconds" << endl; + else if (time < 1000*60) + cout << " " << name << ": " << time/1000.0 << " seconds" << endl; + else if (time < 1000*60*60) + cout << " " << name << ": " << time/1000.0/60.0 << " minutes" << endl; + else + cout << " " << name << ": " << time/1000.0/60.0/60.0 << " hours" << endl; + } + } + } + + inline void clear() + { + for (int i = 0; i < TIME_SLOTS; ++i) + { + // clear timing buffer + time_buf()[i] = 0; + // clear name buffer + name_buf(i,"")[0] = '\0'; + } + } + + struct block + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an RAII tool for calling start() and stop() + !*/ + + block(int i):idx(i) {start(idx);} + block(int i, const char* str):idx(i) {start(idx,str);} + ~block() { stop(idx); } + const int idx; + }; + } +} + + +#endif // DLIB_TImING_Hh_ + diff --git a/ml/dlib/dlib/tokenizer.h b/ml/dlib/dlib/tokenizer.h new file mode 100644 index 000000000..01b6fcf83 --- /dev/null +++ b/ml/dlib/dlib/tokenizer.h @@ -0,0 +1,33 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TOKENIZEr_ +#define DLIB_TOKENIZEr_ + +#include "tokenizer/tokenizer_kernel_1.h" +#include "tokenizer/tokenizer_kernel_c.h" + + +namespace dlib +{ + + class tokenizer + { + tokenizer() {} + + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef tokenizer_kernel_1 + kernel_1a; + typedef tokenizer_kernel_c + kernel_1a_c; + + + }; +} + +#endif // DLIB_TOKENIZEr_ + diff --git a/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.cpp b/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.cpp new file mode 100644 index 000000000..daa83184c --- /dev/null +++ b/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.cpp @@ -0,0 +1,295 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TOKENIZER_KERNEL_1_CPp_ +#define DLIB_TOKENIZER_KERNEL_1_CPp_ +#include "tokenizer_kernel_1.h" + +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + tokenizer_kernel_1:: + tokenizer_kernel_1 ( + ) : + headset(0), + bodyset(0), + have_peeked(false) + { + try + { + headset = new bool[UCHAR_MAX]; + bodyset = new bool[UCHAR_MAX]; + + clear(); + } + catch (...) + { + if (headset) delete [] headset; + if (bodyset) delete [] bodyset; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + tokenizer_kernel_1:: + ~tokenizer_kernel_1 ( + ) + { + delete [] bodyset; + delete [] headset; + } + +// ---------------------------------------------------------------------------------------- + + void tokenizer_kernel_1:: + clear( + ) + { + using namespace std; + + in = 0; + streambuf = 0; + have_peeked = false; + + head = "_" + lowercase_letters() + uppercase_letters(); + body = "_" + lowercase_letters() + uppercase_letters() + numbers(); + + for (unsigned long i = 0; i < UCHAR_MAX; ++i) + { + headset[i] = false; + bodyset[i] = false; + } + + for (string::size_type i = 0; i < head.size(); ++i) + headset[static_cast(head[i])] = true; + for (string::size_type i = 0; i < body.size(); ++i) + bodyset[static_cast(body[i])] = true; + } + +// ---------------------------------------------------------------------------------------- + + void tokenizer_kernel_1:: + set_stream ( + std::istream& in_ + ) + { + in = &in_; + streambuf = in_.rdbuf(); + have_peeked = false; + } + +// ---------------------------------------------------------------------------------------- + + bool tokenizer_kernel_1:: + stream_is_set ( + ) const + { + return (in != 0); + } + +// ---------------------------------------------------------------------------------------- + + std::istream& tokenizer_kernel_1:: + get_stream ( + ) const + { + return *in; + } + +// ---------------------------------------------------------------------------------------- + + void tokenizer_kernel_1:: + get_token ( + int& type, + std::string& token + ) + { + if (!have_peeked) + { + std::streambuf::int_type ch; + ch = streambuf->sbumpc(); + + switch (ch) + { + case EOF: + type = END_OF_FILE; + token.clear(); + return; + + case '\n': + type = END_OF_LINE; + token = "\n"; + return; + + case '\r': + case ' ': + case '\t': + type = WHITE_SPACE; + token = static_cast(ch); + ch = streambuf->sgetc(); + while ((ch == ' ' || ch == '\t' || ch == '\r') && ch != EOF) + { + token += static_cast(ch); + ch = streambuf->snextc(); + } + return; + + default: + if (headset[static_cast(ch)]) + { + type = IDENTIFIER; + token = static_cast(ch); + ch = streambuf->sgetc(); + while ( bodyset[static_cast(ch)] && ch != EOF ) + { + token += static_cast(ch); + ch = streambuf->snextc(); + } + } + else if ('0' <= ch && ch <= '9') + { + type = NUMBER; + token = static_cast(ch); + ch = streambuf->sgetc(); + while (('0' <= ch && ch <= '9') && ch != EOF) + { + token += static_cast(ch); + ch = streambuf->snextc(); + } + } + else + { + type = CHAR; + token = static_cast(ch); + } + return; + } // switch (ch) + } + + // if we get this far it means we have peeked so we should + // return the peek data. + type = next_type; + token = next_token; + have_peeked = false; + } + +// ---------------------------------------------------------------------------------------- + + int tokenizer_kernel_1:: + peek_type ( + ) const + { + const_cast(this)->get_token(next_type,next_token); + have_peeked = true; + return next_type; + } + +// ---------------------------------------------------------------------------------------- + + const std::string& tokenizer_kernel_1:: + peek_token ( + ) const + { + const_cast(this)->get_token(next_type,next_token); + have_peeked = true; + return next_token; + } + +// ---------------------------------------------------------------------------------------- + + void tokenizer_kernel_1:: + swap ( + tokenizer_kernel_1& item + ) + { + exchange(in,item.in); + exchange(streambuf,item.streambuf); + exchange(head,item.head); + exchange(body,item.body); + exchange(bodyset,item.bodyset); + exchange(headset,item.headset); + exchange(have_peeked,item.have_peeked); + exchange(next_type,item.next_type); + exchange(next_token,item.next_token); + } + +// ---------------------------------------------------------------------------------------- + + void tokenizer_kernel_1:: + set_identifier_token ( + const std::string& head_, + const std::string& body_ + ) + { + using namespace std; + + head = head_; + body = body_; + + for (unsigned long i = 0; i < UCHAR_MAX; ++i) + { + headset[i] = false; + bodyset[i] = false; + } + + for (string::size_type i = 0; i < head.size(); ++i) + headset[static_cast(head[i])] = true; + for (string::size_type i = 0; i < body.size(); ++i) + bodyset[static_cast(body[i])] = true; + } + +// ---------------------------------------------------------------------------------------- + + const std::string tokenizer_kernel_1:: + get_identifier_head ( + ) const + { + return head; + } + +// ---------------------------------------------------------------------------------------- + + const std::string tokenizer_kernel_1:: + get_identifier_body ( + ) const + { + return body; + } + +// ---------------------------------------------------------------------------------------- + + const std::string tokenizer_kernel_1:: + lowercase_letters ( + ) const + { + return std::string("abcdefghijklmnopqrstuvwxyz"); + } + +// ---------------------------------------------------------------------------------------- + + const std::string tokenizer_kernel_1:: + uppercase_letters ( + ) const + { + return std::string("ABCDEFGHIJKLMNOPQRSTUVWXYZ"); + } + +// ---------------------------------------------------------------------------------------- + + const std::string tokenizer_kernel_1:: + numbers ( + ) const + { + return std::string("0123456789"); + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_TOKENIZER_KERNEL_1_CPp_ + diff --git a/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.h b/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.h new file mode 100644 index 000000000..d67ae278f --- /dev/null +++ b/ml/dlib/dlib/tokenizer/tokenizer_kernel_1.h @@ -0,0 +1,155 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TOKENIZER_KERNEl_1_ +#define DLIB_TOKENIZER_KERNEl_1_ + +#include +#include +#include +#include "../algs.h" +#include "tokenizer_kernel_abstract.h" + +namespace dlib +{ + + class tokenizer_kernel_1 + { + /*! + INITIAL VALUE + - in == 0 + - streambuf == 0 + - have_peeked == false + - head == "_" + lowercase_letters() + uppercase_letters() + - body == "_" + lowercase_letters() + uppercase_letters() + numbers() + - headset == pointer to an array of UCHAR_MAX bools and set according + to the CONVENTION. + - bodyset == pointer to an array of UCHAR_MAX bools and set according + to the CONVENTION. + + CONVENTION + - if (stream_is_set()) then + - get_stream() == *in + - streambuf == in->rdbuf() + - else + - in == 0 + - streambuf == 0 + + - body == get_identifier_body() + - head == get_identifier_head() + + - if (the char x appears in head) then + - headset[static_cast(x)] == true + - else + - headset[static_cast(x)] == false + + - if (the char x appears in body) then + - bodyset[static_cast(x)] == true + - else + - bodyset[static_cast(x)] == false + + - if (have_peeked) then + - next_token == the next token to be returned from get_token() + - next_type == the type of token in peek_token + !*/ + + public: + + // The name of this enum is irrelevant but on some compilers (gcc on MAC OS X) not having it named + // causes an error for whatever reason + enum some_random_name + { + END_OF_LINE, + END_OF_FILE, + IDENTIFIER, + CHAR, + NUMBER, + WHITE_SPACE + }; + + tokenizer_kernel_1 ( + ); + + virtual ~tokenizer_kernel_1 ( + ); + + void clear( + ); + + void set_stream ( + std::istream& in + ); + + bool stream_is_set ( + ) const; + + std::istream& get_stream ( + ) const; + + void get_token ( + int& type, + std::string& token + ); + + void swap ( + tokenizer_kernel_1& item + ); + + void set_identifier_token ( + const std::string& head, + const std::string& body + ); + + int peek_type ( + ) const; + + const std::string& peek_token ( + ) const; + + const std::string get_identifier_head ( + ) const; + + const std::string get_identifier_body ( + ) const; + + const std::string lowercase_letters ( + ) const; + + const std::string uppercase_letters ( + ) const; + + const std::string numbers ( + ) const; + + private: + + // restricted functions + tokenizer_kernel_1(const tokenizer_kernel_1&); // copy constructor + tokenizer_kernel_1& operator=(const tokenizer_kernel_1&); // assignment operator + + + // data members + std::istream* in; + std::streambuf* streambuf; + std::string head; + std::string body; + bool* headset; + bool* bodyset; + + mutable std::string next_token; + mutable int next_type; + mutable bool have_peeked; + }; + + inline void swap ( + tokenizer_kernel_1& a, + tokenizer_kernel_1& b + ) { a.swap(b); } + +} + +#ifdef NO_MAKEFILE +#include "tokenizer_kernel_1.cpp" +#endif + +#endif // DLIB_TOKENIZER_KERNEl_1 + diff --git a/ml/dlib/dlib/tokenizer/tokenizer_kernel_abstract.h b/ml/dlib/dlib/tokenizer/tokenizer_kernel_abstract.h new file mode 100644 index 000000000..f534b8f7f --- /dev/null +++ b/ml/dlib/dlib/tokenizer/tokenizer_kernel_abstract.h @@ -0,0 +1,289 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_TOKENIZER_KERNEl_ABSTRACT_ +#ifdef DLIB_TOKENIZER_KERNEl_ABSTRACT_ + +#include +#include + +namespace dlib +{ + + class tokenizer + { + /*! + INITIAL VALUE + stream_is_set() == false + get_identifier_head() == "_" + lowercase_letters() + uppercase_letters() + get_identifier_body() == "_" + lowercase_letters() + uppercase_letters() + + numbers() + + WHAT THIS OBJECT REPRESENTS + This object represents a simple tokenizer for textual data. + + BUFFERING + This object is allowed to buffer data from the input stream. + Thus if you clear it or switch streams (via calling set_stream()) + any buffered data will be lost. + + TOKENS + When picking out tokens the tokenizer will always extract the + longest token it can. For example, if faced with the string + "555" it will consider the three 5s to be a single NUMBER + token not three smaller NUMBER tokens. + + Also note that no characters in the input stream are discarded. + They will all be returned in the text of some token. + Additionally, each character will never be returned more than once. + This means that if you concatenated all returned tokens it would exactly + reproduce the contents of the input stream. + + The tokens are defined as follows: + + END_OF_LINE + This is a single character token and is always the '\n' + character. + + END_OF_FILE + This token represents the end of file. It doesn't have any + actual characters associated with it. + + IDENTIFIER + This is a multi-character token. It is defined as a string that + begins with a character from get_identifier_head() and is + followed by any number of characters from get_identifier_body(). + + NUMBER + This is a multi-character token. It is defined as a sequence of + numbers. + + WHITE_SPACE + This is a multi character token. It is defined as a sequence of + one or more spaces, carrage returns, and tabs. I.e. It is + composed of characters from the following string " \r\t". + + CHAR + This is a single character token. It matches anything that isn't + part of one of the above tokens. + !*/ + + public: + + enum + { + END_OF_LINE, + END_OF_FILE, + IDENTIFIER, + CHAR, + NUMBER, + WHITE_SPACE + }; + + tokenizer ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~tokenizer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + If this exception is thrown then #*this is unusable + until clear() is called and succeeds. + !*/ + + void set_stream ( + std::istream& in + ); + /*! + ensures + - #*this will read data from in and tokenize it + - #stream_is_set() == true + - #get_stream() == in + !*/ + + bool stream_is_set ( + ) const; + /*! + ensures + - returns true if a stream has been associated with *this by calling + set_stream() + !*/ + + std::istream& get_stream ( + ) const; + /*! + requires + - stream_is_set() == true + ensures + - returns a reference to the istream object that *this is reading + from. + !*/ + + void get_token ( + int& type, + std::string& token + ); + /*! + requires + - stream_is_set() == true + ensures + - #token == the next token from the input stream get_stream() + - #type == the type of the token in #token + throws + - bad_alloc + If this exception is thrown then the call to this function will + have no effect on *this but the values of #type and #token will be + undefined. Additionally, some characters may have been read + from the stream get_stream() and lost. + !*/ + + int peek_type ( + ) const; + /*! + requires + - stream_is_set() == true + ensures + - returns the type of the token that will be returned from + the next call to get_token() + throws + - bad_alloc + If this exception is thrown then the call to this function will + have no effect on *this. However, some characters may have been + read from the stream get_stream() and lost. + !*/ + + const std::string& peek_token ( + ) const; + /*! + requires + - stream_is_set() == true + ensures + - returns the text of the token that will be returned from + the next call to get_token() + throws + - bad_alloc + If this exception is thrown then the call to this function will + have no effect on *this. However, some characters may have been + read from the stream get_stream() and lost. + !*/ + + void set_identifier_token ( + const std::string& head, + const std::string& body + ); + /*! + requires + - head.find_first_of(" \r\t\n0123456789") == std::string::npos + (i.e. head doesn't contain any characters from the string + " \r\t\n0123456789"). + - body.find_frst_of(" \r\t\n") == std::string::npos + (i.e. body doesn't contain any characters from the string " \r\t\n"). + ensures + - #get_identifier_head() == head + - #get_identifier_body() == body + throws + - std::bad_alloc + If this exception is thrown then #*this is unusable + until clear() is called and succeeds. + !*/ + + const std::string get_identifier_head ( + ) const; + /*! + ensures + - returns a string containing the characters that can be the start + of an IDENTIFIER token. + throws + - std::bad_alloc + If this exception is thrown then the call to this function + has no effect. + !*/ + + const std::string get_identifier_body ( + ) const; + /*! + ensures + - returns a string containing the characters that can appear in the + body of an IDENTIFIER token. + throws + - std::bad_alloc + If this exception is thrown then the call to this function + has no effect. + !*/ + + const std::string lowercase_letters ( + ) const; + /*! + ensures + - returns "abcdefghijklmnopqrstuvwxyz" + throws + - std::bad_alloc + If this exception is thrown then the call to this function + has no effect. + !*/ + + const std::string uppercase_letters ( + ) const; + /*! + ensures + - returns "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + throws + - std::bad_alloc + If this exception is thrown then the call to this function + has no effect. + !*/ + + const std::string numbers ( + ) const; + /*! + ensures + - returns "0123456789" + throws + - std::bad_alloc + If this exception is thrown then the call to this function + has no effect. + !*/ + + void swap ( + tokenizer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + tokenizer(const tokenizer&); // copy constructor + tokenizer& operator=(const tokenizer&); // assignment operator + + }; + + inline void swap ( + tokenizer& a, + tokenizer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_TOKENIZER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/tokenizer/tokenizer_kernel_c.h b/ml/dlib/dlib/tokenizer/tokenizer_kernel_c.h new file mode 100644 index 000000000..f9604809d --- /dev/null +++ b/ml/dlib/dlib/tokenizer/tokenizer_kernel_c.h @@ -0,0 +1,167 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TOKENIZER_KERNEl_C_ +#define DLIB_TOKENIZER_KERNEl_C_ + +#include "tokenizer_kernel_abstract.h" +#include "../assert.h" +#include +#include + +namespace dlib +{ + + template < + typename tokenizer + > + class tokenizer_kernel_c : public tokenizer + { + + public: + std::istream& get_stream ( + ) const; + + void get_token ( + int& type, + std::string& token + ); + + void set_identifier_token ( + const std::string& head, + const std::string& body + ); + + int peek_type ( + ) const; + + const std::string& peek_token ( + ) const; + }; + + template < + typename tokenizer + > + inline void swap ( + tokenizer_kernel_c& a, + tokenizer_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename tokenizer + > + void tokenizer_kernel_c:: + set_identifier_token ( + const std::string& head, + const std::string& body + ) + { + using namespace std; + // make sure requires clause is not broken + DLIB_CASSERT( head.find_first_of(" \r\t\n0123456789") == string::npos && + body.find_first_of(" \r\t\n") == string::npos , + "\tvoid tokenizer::set_identifier_token()" + << "\n\tyou can't define the IDENTIFIER token this way." + << "\n\thead: " << head + << "\n\tbody: " << body + << "\n\tthis: " << this + ); + + // call the real function + tokenizer::set_identifier_token(head,body); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tokenizer + > + std::istream& tokenizer_kernel_c:: + get_stream ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tstd::istream& tokenizer::get_stream()" + << "\n\tyou must set a stream for this object before you can get it" + << "\n\tthis: " << this + ); + + // call the real function + return tokenizer::get_stream(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tokenizer + > + int tokenizer_kernel_c:: + peek_type ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tint tokenizer::peek_type()" + << "\n\tyou must set a stream for this object before you peek at what it contains" + << "\n\tthis: " << this + ); + + // call the real function + return tokenizer::peek_type(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tokenizer + > + const std::string& tokenizer_kernel_c:: + peek_token ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tint tokenizer::peek_token()" + << "\n\tyou must set a stream for this object before you peek at what it contains" + << "\n\tthis: " << this + ); + + // call the real function + return tokenizer::peek_token(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename tokenizer + > + void tokenizer_kernel_c:: + get_token ( + int& type, + std::string& token + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->stream_is_set() == true, + "\tvoid tokenizer::get_token()" + << "\n\tyou must set a stream for this object before you can get tokens from it." + << "\n\tthis: " << this + ); + + // call the real function + tokenizer::get_token(type,token); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TOKENIZER_KERNEl_C_ + + diff --git a/ml/dlib/dlib/travis/build-and-test.sh b/ml/dlib/dlib/travis/build-and-test.sh new file mode 100755 index 000000000..4ee74e36b --- /dev/null +++ b/ml/dlib/dlib/travis/build-and-test.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +# Exit if anything fails. +set -eux + +# execute the contents of MATRIX_EVAL if it's set +if [[ -v MATRIX_EVAL ]]; then + eval "${MATRIX_EVAL}" +fi + +# build dlib and tests +if [ "$VARIANT" = "test" ]; then + mkdir build + cd build + cmake ../dlib/test + cmake --build . --target dtest -- -j 2 + ./dtest --runall +fi + +if [ "$VARIANT" = "dlib_all_source_cpp" ]; then + mkdir build + cd build + cmake ../dlib/test + cmake --build . --target dlib_all_source_cpp -- -j 2 +fi + +if [ "$VARIANT" = "tools" ]; then + mkdir build + cd build + cmake ../dlib/test/tools + cmake --build . -- -j 2 +fi + +if [ "$VARIANT" = "examples" ]; then + mkdir build + cd build + cmake ../examples + cmake --build . -- -j 1 +fi + +if [ "$VARIANT" = "python-api" ]; then + python setup.py test --clean + pip uninstall numpy -y + python setup.py test --clean +fi + diff --git a/ml/dlib/dlib/tuple.h b/ml/dlib/dlib/tuple.h new file mode 100644 index 000000000..2cecddeef --- /dev/null +++ b/ml/dlib/dlib/tuple.h @@ -0,0 +1,10 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TUPLe_TOP_ +#define DLIB_TUPLe_TOP_ + +#include "tuple/tuple.h" + +#endif // DLIB_TUPLe_TOPh_ + + diff --git a/ml/dlib/dlib/tuple/tuple.h b/ml/dlib/dlib/tuple/tuple.h new file mode 100644 index 000000000..cac2b4ade --- /dev/null +++ b/ml/dlib/dlib/tuple/tuple.h @@ -0,0 +1,410 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TUPLe_H_ +#define DLIB_TUPLe_H_ + +#include "../enable_if.h" +#include "../algs.h" +#include "../serialize.h" +#include "tuple_abstract.h" + +// ---------------------------------------------------------------------------------------- + +#define DLIB_TUPLE_GLOBAL_HELPERS(N) \ + DLIB_TUPLE_GH(N##0) DLIB_TUPLE_GH(N##1) DLIB_TUPLE_GH(N##2) DLIB_TUPLE_GH(N##3) \ + DLIB_TUPLE_GH(N##4) DLIB_TUPLE_GH(N##5) DLIB_TUPLE_GH(N##6) DLIB_TUPLE_GH(N##7) + +#define DLIB_TUPLE_GH(N) DLIB_TUPLE_GET_INDEX(N) DLIB_TUPLE_GET_ITEM(N) DLIB_TUPLE_GET_HELPER_STRUCT(N) + +#define DLIB_TUPLE_MEMBER_GET(N) \ + DLIB_TUPLE_MG(N##0) DLIB_TUPLE_MG(N##1) DLIB_TUPLE_MG(N##2) DLIB_TUPLE_MG(N##3) \ + DLIB_TUPLE_MG(N##4) DLIB_TUPLE_MG(N##5) DLIB_TUPLE_MG(N##6) DLIB_TUPLE_MG(N##7) + +#define DLIB_TUPLE_GET_INDEX(N) \ + template const typename enable_if, long>::type get_index (const T&) {return N;} + +#define DLIB_TUPLE_GET_ITEM(N) \ + template const typename enable_if,Q>::type& get_item_const (const T& t) {return t.v##N;}\ + template typename enable_if,Q>::type& get_item ( T& t) {return t.v##N;} + + +#define DLIB_TUPLE_GET_HELPER_STRUCT(N) \ + template struct get_helper \ + { \ + typedef typename T::type##N type; \ + static const type& get(const T& t) { return t.v##N; } \ + static type& get( T& t) { return t.v##N; } \ + }; + +#define DLIB_TUPLE_TEMPLATE_LIST(N) \ + class T##N##0 = null_type, class T##N##1 = null_type, class T##N##2 = null_type, class T##N##3 = null_type, \ + class T##N##4 = null_type, class T##N##5 = null_type, class T##N##6 = null_type, class T##N##7 = null_type + +#define DLIB_TUPLE_VARIABLE_LIST(N) \ + T##N##0 v##N##0; T##N##1 v##N##1; T##N##2 v##N##2; T##N##3 v##N##3; \ + T##N##4 v##N##4; T##N##5 v##N##5; T##N##6 v##N##6; T##N##7 v##N##7; \ + typedef T##N##0 type##N##0; typedef T##N##1 type##N##1; typedef T##N##2 type##N##2; \ + typedef T##N##3 type##N##3; typedef T##N##4 type##N##4; typedef T##N##5 type##N##5; \ + typedef T##N##6 type##N##6; typedef T##N##7 type##N##7; + +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + struct null_type{}; + + // provide default serialization for the null_type + inline void serialize ( + const null_type& , + std::ostream& + ){} + inline void deserialize ( + null_type& , + std::istream& + ){} + +// ---------------------------------------------------------------------------------------- + + namespace tuple_helpers + { + template struct get_helper; + + // use these preprocessor macros to declare all the global stuff used by the + // tuple member functions. + DLIB_TUPLE_GLOBAL_HELPERS(0) + DLIB_TUPLE_GLOBAL_HELPERS(01) + DLIB_TUPLE_GLOBAL_HELPERS(02) + DLIB_TUPLE_GLOBAL_HELPERS(03) + + // ------------------------------------------------------------------------------------ + + // use templates to recursively enumerate everything in the tuple that isn't a null_type + template < + typename T, + typename F, + long i = 0, + typename enabled = void + > + struct for_each + { + static void go( + T& a, + F& funct + ) + { + funct(a.template get()); + for_each::go(a,funct); + } + + static bool go( + T& a, + F& funct, + long idx + ) + /*! + ensures + - returns true if the function was applied to the given index + - returns false if the index is invalid so the function wasn't + applied to anything + !*/ + { + if (idx == i) + { + funct(a.template get()); + return true; + } + else + { + return for_each::go(a,funct,idx); + } + } + }; + + template struct template_or { const static bool value = true; }; + template <> struct template_or { const static bool value = false; }; + + // the base case of the recursion + template < + typename T, + typename F, + long i + > + struct for_each::type >::value> >::type > + { + static void go( T&, F& ) { } + static bool go( T&, F&, long ) { return false; } + }; + + // ------------------------------------------------------------------------------------ + + // use templates to recursively enumerate everything in the tuple that isn't a null_type + template < + typename T, + long i = 0, + typename enabled = void + > + struct tuple_swap + { + static void go( + T& a, + T& b + ) + { + exchange(a.template get(), b.template get()); + tuple_swap::go(a,b); + } + }; + + template + struct at_base_case + { + + }; + + // the base case of the recursion + template < + typename T, + long i + > + struct tuple_swap::type >::value > >::type > + { static void go( T&, T& ) { } }; + + // ------------------------------------------------------------------------------------ + + struct tuple_serialize + { + tuple_serialize (std::ostream& out_) : out(out_){} + std::ostream& out; + + template + void operator() ( + T& a + ) const { serialize(a,out); } + }; + + // ------------------------------------------------------------------------------------ + + struct tuple_deserialize + { + tuple_deserialize (std::istream& in_) : in(in_){} + std::istream& in; + template + void operator() ( + T& a + ) const { deserialize(a,in); } + }; + + + } + +// ---------------------------------------------------------------------------------------- + + // use these preprocessor macros to declare 4*8 template arguments (below we count them in octal) + template < + DLIB_TUPLE_TEMPLATE_LIST(0), // args 00-07 + DLIB_TUPLE_TEMPLATE_LIST(01), // args 010-017 + DLIB_TUPLE_TEMPLATE_LIST(02), // args 020-027 + DLIB_TUPLE_TEMPLATE_LIST(03) // args 030-037 + > + class tuple + { + public: + + // use these macros to declare 8*4 member variables + DLIB_TUPLE_VARIABLE_LIST(0) + DLIB_TUPLE_VARIABLE_LIST(01) + DLIB_TUPLE_VARIABLE_LIST(02) + DLIB_TUPLE_VARIABLE_LIST(03) + + const static long max_fields = 4*8; + + template < long idx > + struct get_type + { + typedef typename tuple_helpers::get_helper::type type; + }; + + template < long idx > + const typename tuple_helpers::get_helper::type& get ( + ) const { return tuple_helpers::get_helper::get(*this); } + + template < long idx > + typename tuple_helpers::get_helper::type& get ( + ) { return tuple_helpers::get_helper::get(*this); } + + template < class Q> + long index ( + ) const { return tuple_helpers::get_index(*this); } + + template + const Q& get ( + ) const {return tuple_helpers::get_item_const(*this);} + + template + Q& get ( + ) {return tuple_helpers::get_item(*this);} + + + + + template + void for_index ( + F& funct, + long idx + ) + { + // do this #ifdef stuff to avoid getting a warning about valid_idx not being + // used when ENABLE_ASSERTS isn't defined. +#ifdef ENABLE_ASSERTS + const bool valid_idx = tuple_helpers::for_each::go(*this,funct,idx); +#else + tuple_helpers::for_each::go(*this,funct,idx); +#endif + DLIB_ASSERT(valid_idx, + "\tvoid tuple::for_index()" + << "\n\tYou have attempted to call for_index() with an index out of the valid range" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + } + + template + void for_index ( + F& funct, + long idx + ) const + { + // do this #ifdef stuff to avoid getting a warning about valid_idx not being + // used when ENABLE_ASSERTS isn't defined. +#ifdef ENABLE_ASSERTS + const bool valid_idx = tuple_helpers::for_each::go(*this,funct,idx); +#else + tuple_helpers::for_each::go(*this,funct,idx); +#endif + DLIB_ASSERT(valid_idx, + "\tvoid tuple::for_index()" + << "\n\tYou have attempted to call for_index() with an index out of the valid range" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + } + + template + void for_index ( + const F& funct, + long idx + ) + { + // do this #ifdef stuff to avoid getting a warning about valid_idx not being + // used when ENABLE_ASSERTS isn't defined. +#ifdef ENABLE_ASSERTS + const bool valid_idx = tuple_helpers::for_each::go(*this,funct,idx); +#else + tuple_helpers::for_each::go(*this,funct,idx); +#endif + DLIB_ASSERT(valid_idx, + "\tvoid tuple::for_index()" + << "\n\tYou have attempted to call for_index() with an index out of the valid range" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + } + + template + void for_index ( + const F& funct, + long idx + ) const + { + // do this #ifdef stuff to avoid getting a warning about valid_idx not being + // used when ENABLE_ASSERTS isn't defined. +#ifdef ENABLE_ASSERTS + const bool valid_idx = tuple_helpers::for_each::go(*this,funct,idx); +#else + tuple_helpers::for_each::go(*this,funct,idx); +#endif + DLIB_ASSERT(valid_idx, + "\tvoid tuple::for_index()" + << "\n\tYou have attempted to call for_index() with an index out of the valid range" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + } + + + + + template + void for_each ( + F& funct + ) { tuple_helpers::for_each::go(*this,funct); } + + template + void for_each ( + F& funct + ) const { tuple_helpers::for_each::go(*this,funct); } + + template + void for_each ( + const F& funct + ) const { tuple_helpers::for_each::go(*this,funct); } + + template + void for_each ( + const F& funct + ) { tuple_helpers::for_each::go(*this,funct); } + + + + + inline friend void serialize ( + tuple& item, + std::ostream& out + ) + { + try + { + item.for_each(tuple_helpers::tuple_serialize(out)); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing an object of type dlib::tuple<>"); + } + } + + inline friend void deserialize ( + tuple& item, + std::istream& in + ) + { + try + { + item.for_each(tuple_helpers::tuple_deserialize(in)); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing an object of type dlib::tuple<>"); + } + } + + inline friend void swap ( + tuple& a, + tuple& b + ) + { + tuple_helpers::tuple_swap::go(a,b); + } + + inline void swap( + tuple& item + ) { tuple_helpers::tuple_swap::go(item,*this); } + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TUPLe_H_ + diff --git a/ml/dlib/dlib/tuple/tuple_abstract.h b/ml/dlib/dlib/tuple/tuple_abstract.h new file mode 100644 index 000000000..aff9b8122 --- /dev/null +++ b/ml/dlib/dlib/tuple/tuple_abstract.h @@ -0,0 +1,302 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_TUPLe_ABSTRACT_H_ +#ifdef DLIB_TUPLe_ABSTRACT_H_ + +#include "../algs.h" +#include "../serialize.h" +#include "tuple_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct null_type + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is the default type used as the default + template argument to the tuple object's template arguments. + + Also note that it has no state associated with it. + !*/ + }; + + inline void serialize ( + const null_type& , + std::ostream& + ){} + inline void deserialize ( + null_type& , + std::istream& + ){} + /*! + Serialization support is provided for null_type because in some cases + it makes your code a little more concise and easier to deal with + when using tuple objects and serialization. The serialization literally + does nothing though. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T0 = null_type, + typename T1 = null_type, + typename T2 = null_type, + typename T3 = null_type, + ... + typename T31 = null_type + > + class tuple + { + /*! + INITIAL VALUE + Each object in the tuple is default initialized by its own constructor. + The tuple object itself does not modify them or add any additional + state. + + WHAT THIS OBJECT REPRESENTS + This object represents a container of between 0 and 31 objects + where the objects contained are specified in the template + arguments. + + EXAMPLE + We can declare a tuple that contains an int, a float, and a char like so: + tuple ex; + + Then we can access each of these by their index number. The index number + is just the order each type has in the template argument list. So we have: + ex.get<0>() = 5; // assign the int the value 5 + ex.get<1>() = 3.14; // assign the float the value 3.14 + ex.get<2>() = 'c'; // assign the char the value 'c' + + Also, since we only have one of each type in this example tuple we can + unambiguously access each field in the tuple by their types. So for + example, we can use this syntax to access our fields: + ex.get() // returns 5 + ex.get() // returns 3.14 + ex.get() // returns 'c' + + We can also get the indexes of each of these fields like so: + ex.index() // returns 0 + ex.index() // returns 1 + ex.index() // returns 2 + !*/ + + public: + // the maximum number of items this tuple template can contain + const static long max_fields = 32; + + template + struct get_type + { + typedef (the type of the Tindex template argument) type; + }; + + template + const get_type::type& get ( + ) const; + /*! + requires + - 0 <= index <= 31 + ensures + - returns a const reference to the index(th) object contained + inside this tuple + !*/ + + template + get_type::type& get ( + ); + /*! + requires + - 0 <= index <= 31 + ensures + - returns a non-const reference to the index(th) object contained + inside this tuple + !*/ + + template + const long index ( + ) const; + /*! + requires + - Q is a type of object contained in this tuple and there is + only one object of that type in the tuple + ensures + - returns the index of the object in this tuple with type Q + !*/ + + template + const Q& get ( + ) const; + /*! + requires + - Q is a type of object contained in this tuple and there is + only one object of that type in the tuple + ensures + - returns a const reference to the object in this tuple + with type Q + !*/ + + template + Q& get ( + ); + /*! + requires + - Q is a type of object contained in this tuple and there is + only one object of that type in the tuple + ensures + - returns a non-const reference to the object in this tuple + with type Q + !*/ + + template + void for_each ( + F& funct + ); + /*! + requires + - funct is a templated function object + ensures + - for each item X in this tuple that isn't a null_type object: + - calls funct(X); + !*/ + + template + void for_each ( + F& funct + ) const; + /*! + requires + - funct is a templated function object + ensures + - for each item X in this tuple that isn't a null_type object: + - calls funct(X); + !*/ + + template + void for_each ( + const F& funct + ); + /*! + requires + - funct is a templated function object + ensures + - for each item X in this tuple that isn't a null_type object: + - calls funct(X); + !*/ + + template + void for_each ( + const F& funct + ) const; + /*! + requires + - funct is a templated function object + ensures + - for each item X in this tuple that isn't a null_type object: + - calls funct(X); + !*/ + + template + void for_index ( + F& funct, + long idx + ); + /*! + requires + - funct is a templated function object + - 0 <= idx < max_fields && get_type::type != null_type + (i.e. idx must be the index of a non-null_type object in this tuple) + ensures + - calls funct(this->get()); + !*/ + + template + void for_index ( + F& funct, + long idx + ) const; + /*! + requires + - funct is a templated function object + - 0 <= idx < max_fields && get_type::type != null_type + (i.e. idx must be the index of a non-null_type object in this tuple) + ensures + - calls funct(this->get()); + !*/ + + template + void for_index ( + const F& funct, + long idx + ); + /*! + requires + - funct is a templated function object + - 0 <= idx < max_fields && get_type::type != null_type + (i.e. idx must be the index of a non-null_type object in this tuple) + ensures + - calls funct(this->get()); + !*/ + + template + void for_index ( + const F& funct, + long idx + ) const; + /*! + requires + - funct is a templated function object + - 0 <= idx < max_fields && get_type::type != null_type + (i.e. idx must be the index of a non-null_type object in this tuple) + ensures + - calls funct(this->get()); + !*/ + + void swap ( + tuple& item + ); + /*! + ensures + - swaps *this and item + !*/ + + // ------------------------------------------------- + // global functions for tuple objects + // ------------------------------------------------- + + friend void swap ( + tuple& a, + tuple& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + friend void serialize ( + const tuple& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + friend void deserialize ( + tuple& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TUPLe_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/type_safe_union.h b/ml/dlib/dlib/type_safe_union.h new file mode 100644 index 000000000..d04bf0d17 --- /dev/null +++ b/ml/dlib/dlib/type_safe_union.h @@ -0,0 +1,11 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TYPE_SAFE_UNIOn_TOP_ +#define DLIB_TYPE_SAFE_UNIOn_TOP_ + +#include "type_safe_union/type_safe_union_kernel.h" + +#endif // DLIB_TYPE_SAFE_UNIOn_TOP_ + + + diff --git a/ml/dlib/dlib/type_safe_union/type_safe_union_kernel.h b/ml/dlib/dlib/type_safe_union/type_safe_union_kernel.h new file mode 100644 index 000000000..76a171286 --- /dev/null +++ b/ml/dlib/dlib/type_safe_union/type_safe_union_kernel.h @@ -0,0 +1,711 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TYPE_SAFE_UNIOn_h_ +#define DLIB_TYPE_SAFE_UNIOn_h_ + +#include "type_safe_union_kernel_abstract.h" +#include "../algs.h" +#include "../noncopyable.h" +#include "../serialize.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bad_type_safe_union_cast : public std::bad_cast + { + public: + virtual const char * what() const throw() + { + return "bad_type_safe_union_cast"; + } + }; + +// ---------------------------------------------------------------------------------------- + + struct _void{}; + inline void serialize( const _void&, std::ostream&){} + inline void deserialize( _void&, std::istream&){} + +// ---------------------------------------------------------------------------------------- + + template < + typename T1, + typename T2 = _void, + typename T3 = _void, + typename T4 = _void, + typename T5 = _void, + typename T6 = _void, + typename T7 = _void, + typename T8 = _void, + typename T9 = _void, + typename T10 = _void, + + typename T11 = _void, + typename T12 = _void, + typename T13 = _void, + typename T14 = _void, + typename T15 = _void, + typename T16 = _void, + typename T17 = _void, + typename T18 = _void, + typename T19 = _void, + typename T20 = _void + > + class type_safe_union : noncopyable + { + /*! + CONVENTION + - is_empty() == (type_identity == 0) + - contains() == (type_identity == get_type_id()) + - mem.get() == the block of memory on the stack which is + where objects in the union are stored + !*/ + + private: + + template + void invoke_on ( + T& obj, + U& item + ) const + { + obj(item); + } + + template + void invoke_on ( + T& , + _void + ) const + { + } + + + const static size_t max_size = tmax::value, + sizeof(T3)>::value, + sizeof(T4)>::value, + sizeof(T5)>::value, + sizeof(T6)>::value, + sizeof(T7)>::value, + sizeof(T8)>::value, + sizeof(T9)>::value, + sizeof(T10)>::value, + sizeof(T11)>::value, + sizeof(T12)>::value, + sizeof(T13)>::value, + sizeof(T14)>::value, + sizeof(T15)>::value, + sizeof(T16)>::value, + sizeof(T17)>::value, + sizeof(T18)>::value, + sizeof(T19)>::value, + sizeof(T20)>::value; + + // -------------------------------------------- + + // member data + stack_based_memory_block mem; + int type_identity; + + // -------------------------------------------- + + template + void validate_type() const + { + // ERROR: You are trying to get a type of object that isn't + // representable by this type_safe_union. I.e. The given + // type T isn't one of the ones given to this object's template + // arguments. + COMPILE_TIME_ASSERT(( is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value || + is_same_type::value + )); + + } + + + struct destruct_helper + { + template + void operator() (T& item) const + { + item.~T(); + } + }; + + void destruct ( + ) + /*! + ensures + - #is_empty() == true + !*/ + { + // destruct whatever is in this object + apply_to_contents(destruct_helper()); + + // mark this object as being empty + type_identity = 0; + } + + template + void construct ( + ) + { + if (type_identity != get_type_id()) + { + destruct(); + new(mem.get()) T(); + type_identity = get_type_id(); + } + } + + template + void construct ( + const T& item + ) + { + if (type_identity != get_type_id()) + { + destruct(); + new(mem.get()) T(item); + type_identity = get_type_id(); + } + } + + template + T& unchecked_get( + ) + /*! + requires + - contains() == true + ensures + - returns a non-const reference to the T object + !*/ + { + return *static_cast(mem.get()); + } + + template + const T& unchecked_get( + ) const + /*! + requires + - contains() == true + ensures + - returns a const reference to the T object + !*/ + { + return *static_cast(mem.get()); + } + + template + void operator() (T& item) + /* + This function is used by the swap function of this class. See that + function to see how this works. + */ + { + exchange(get(), item); + } + + public: + + typedef T1 type1; + typedef T2 type2; + typedef T3 type3; + typedef T4 type4; + typedef T5 type5; + typedef T6 type6; + typedef T7 type7; + typedef T8 type8; + typedef T9 type9; + typedef T10 type10; + typedef T11 type11; + typedef T12 type12; + typedef T13 type13; + typedef T14 type14; + typedef T15 type15; + typedef T16 type16; + typedef T17 type17; + typedef T18 type18; + typedef T19 type19; + typedef T20 type20; + + + type_safe_union() : type_identity(0) + { + } + + template + type_safe_union ( + const T& item + ) : type_identity(0) + { + validate_type(); + construct(item); + } + + ~type_safe_union() + { + destruct(); + } + + template + static int get_type_id ( + ) + { + if (is_same_type::value) return 1; + if (is_same_type::value) return 2; + if (is_same_type::value) return 3; + if (is_same_type::value) return 4; + if (is_same_type::value) return 5; + + if (is_same_type::value) return 6; + if (is_same_type::value) return 7; + if (is_same_type::value) return 8; + if (is_same_type::value) return 9; + if (is_same_type::value) return 10; + + if (is_same_type::value) return 11; + if (is_same_type::value) return 12; + if (is_same_type::value) return 13; + if (is_same_type::value) return 14; + if (is_same_type::value) return 15; + + if (is_same_type::value) return 16; + if (is_same_type::value) return 17; + if (is_same_type::value) return 18; + if (is_same_type::value) return 19; + if (is_same_type::value) return 20; + + // return a number that doesn't match any of the + // valid states of type_identity + return -1; + } + + template + bool contains ( + ) const + { + return type_identity == get_type_id(); + } + + bool is_empty ( + ) const + { + return type_identity == 0; + } + + + public: + + template < + typename t1, typename t2, typename t3, typename t4, typename t5, + typename t6, typename t7, typename t8, typename t9, typename t10, + typename t11, typename t12, typename t13, typename t14, typename t15, + typename t16, typename t17, typename t18, typename t19, typename t20 + > + friend void serialize ( + const type_safe_union& item, + std::ostream& out + ); + + + template < + typename T + > + void apply_to_contents ( + T& obj + ) + { + switch (type_identity) + { + // do nothing because we are empty + case 0: break; + + case 1: invoke_on(obj,unchecked_get()); break; + case 2: invoke_on(obj,unchecked_get()); break; + case 3: invoke_on(obj,unchecked_get()); break; + case 4: invoke_on(obj,unchecked_get()); break; + case 5: invoke_on(obj,unchecked_get()); break; + + case 6: invoke_on(obj,unchecked_get()); break; + case 7: invoke_on(obj,unchecked_get()); break; + case 8: invoke_on(obj,unchecked_get()); break; + case 9: invoke_on(obj,unchecked_get()); break; + case 10: invoke_on(obj,unchecked_get()); break; + + case 11: invoke_on(obj,unchecked_get()); break; + case 12: invoke_on(obj,unchecked_get()); break; + case 13: invoke_on(obj,unchecked_get()); break; + case 14: invoke_on(obj,unchecked_get()); break; + case 15: invoke_on(obj,unchecked_get()); break; + + case 16: invoke_on(obj,unchecked_get()); break; + case 17: invoke_on(obj,unchecked_get()); break; + case 18: invoke_on(obj,unchecked_get()); break; + case 19: invoke_on(obj,unchecked_get()); break; + case 20: invoke_on(obj,unchecked_get()); break; + } + } + + template < + typename T + > + void apply_to_contents ( + const T& obj + ) + { + switch (type_identity) + { + // do nothing because we are empty + case 0: break; + + case 1: invoke_on(obj,unchecked_get()); break; + case 2: invoke_on(obj,unchecked_get()); break; + case 3: invoke_on(obj,unchecked_get()); break; + case 4: invoke_on(obj,unchecked_get()); break; + case 5: invoke_on(obj,unchecked_get()); break; + + case 6: invoke_on(obj,unchecked_get()); break; + case 7: invoke_on(obj,unchecked_get()); break; + case 8: invoke_on(obj,unchecked_get()); break; + case 9: invoke_on(obj,unchecked_get()); break; + case 10: invoke_on(obj,unchecked_get()); break; + + case 11: invoke_on(obj,unchecked_get()); break; + case 12: invoke_on(obj,unchecked_get()); break; + case 13: invoke_on(obj,unchecked_get()); break; + case 14: invoke_on(obj,unchecked_get()); break; + case 15: invoke_on(obj,unchecked_get()); break; + + case 16: invoke_on(obj,unchecked_get()); break; + case 17: invoke_on(obj,unchecked_get()); break; + case 18: invoke_on(obj,unchecked_get()); break; + case 19: invoke_on(obj,unchecked_get()); break; + case 20: invoke_on(obj,unchecked_get()); break; + } + } + + template < + typename T + > + void apply_to_contents ( + T& obj + ) const + { + switch (type_identity) + { + // do nothing because we are empty + case 0: break; + + case 1: invoke_on(obj,unchecked_get()); break; + case 2: invoke_on(obj,unchecked_get()); break; + case 3: invoke_on(obj,unchecked_get()); break; + case 4: invoke_on(obj,unchecked_get()); break; + case 5: invoke_on(obj,unchecked_get()); break; + + case 6: invoke_on(obj,unchecked_get()); break; + case 7: invoke_on(obj,unchecked_get()); break; + case 8: invoke_on(obj,unchecked_get()); break; + case 9: invoke_on(obj,unchecked_get()); break; + case 10: invoke_on(obj,unchecked_get()); break; + + case 11: invoke_on(obj,unchecked_get()); break; + case 12: invoke_on(obj,unchecked_get()); break; + case 13: invoke_on(obj,unchecked_get()); break; + case 14: invoke_on(obj,unchecked_get()); break; + case 15: invoke_on(obj,unchecked_get()); break; + + case 16: invoke_on(obj,unchecked_get()); break; + case 17: invoke_on(obj,unchecked_get()); break; + case 18: invoke_on(obj,unchecked_get()); break; + case 19: invoke_on(obj,unchecked_get()); break; + case 20: invoke_on(obj,unchecked_get()); break; + } + } + + template < + typename T + > + void apply_to_contents ( + const T& obj + ) const + { + switch (type_identity) + { + // do nothing because we are empty + case 0: break; + + case 1: invoke_on(obj,unchecked_get()); break; + case 2: invoke_on(obj,unchecked_get()); break; + case 3: invoke_on(obj,unchecked_get()); break; + case 4: invoke_on(obj,unchecked_get()); break; + case 5: invoke_on(obj,unchecked_get()); break; + + case 6: invoke_on(obj,unchecked_get()); break; + case 7: invoke_on(obj,unchecked_get()); break; + case 8: invoke_on(obj,unchecked_get()); break; + case 9: invoke_on(obj,unchecked_get()); break; + case 10: invoke_on(obj,unchecked_get()); break; + + case 11: invoke_on(obj,unchecked_get()); break; + case 12: invoke_on(obj,unchecked_get()); break; + case 13: invoke_on(obj,unchecked_get()); break; + case 14: invoke_on(obj,unchecked_get()); break; + case 15: invoke_on(obj,unchecked_get()); break; + + case 16: invoke_on(obj,unchecked_get()); break; + case 17: invoke_on(obj,unchecked_get()); break; + case 18: invoke_on(obj,unchecked_get()); break; + case 19: invoke_on(obj,unchecked_get()); break; + case 20: invoke_on(obj,unchecked_get()); break; + } + } + + void swap ( + type_safe_union& item + ) + { + // if both *this and item contain the same type of thing + if (type_identity == item.type_identity) + { + // swap the things in this and item. + item.apply_to_contents(*this); + } + else if (type_identity == 0) + { + // *this doesn't contain anything. So swap this and item and + // then destruct item. + item.apply_to_contents(*this); + item.destruct(); + } + else if (item.type_identity == 0) + { + // *this doesn't contain anything. So swap this and item and + // then destruct this. + apply_to_contents(item); + destruct(); + } + else + { + type_safe_union temp; + // swap *this into temp + apply_to_contents(temp); + // swap item into *this + item.apply_to_contents(*this); + // swap temp into item + temp.apply_to_contents(item); + } + } + + template + T& get( + ) + { + validate_type(); + construct(); + return *static_cast(mem.get()); + } + + template + const T& cast_to ( + ) const + { + validate_type(); + if (contains()) + return *static_cast(mem.get()); + else + throw bad_type_safe_union_cast(); + } + + template + T& cast_to ( + ) + { + validate_type(); + if (contains()) + return *static_cast(mem.get()); + else + throw bad_type_safe_union_cast(); + } + + template + type_safe_union& operator= ( const T& item) { get() = item; return *this; } + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20 + > + inline void swap ( + type_safe_union& a, + type_safe_union& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename from, + typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20 + > + struct is_convertible > + { + const static bool value = is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value || + is_convertible::value; + }; + +// ---------------------------------------------------------------------------------------- + + namespace impl_tsu + { + struct serialize_helper + { + /* + This is a function object to help us serialize type_safe_unions + */ + + std::ostream& out; + serialize_helper(std::ostream& out_): out(out_) {} + template + void operator() (const T& item) const { serialize(item, out); } + }; + } + + template < + typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20 + > + void serialize ( + const type_safe_union& item, + std::ostream& out + ) + { + try + { + // save the type_identity + serialize(item.type_identity, out); + item.apply_to_contents(dlib::impl_tsu::serialize_helper(out)); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing an object of type type_safe_union"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20 + > + void deserialize ( + type_safe_union& item, + std::istream& in + ) + { + try + { + typedef type_safe_union tsu_type; + + int type_identity; + deserialize(type_identity, in); + switch (type_identity) + { + // swap an empty type_safe_union into item since it should be in the empty state + case 0: tsu_type().swap(item); break; + + case 1: deserialize(item.template get(), in); break; + case 2: deserialize(item.template get(), in); break; + case 3: deserialize(item.template get(), in); break; + case 4: deserialize(item.template get(), in); break; + case 5: deserialize(item.template get(), in); break; + + case 6: deserialize(item.template get(), in); break; + case 7: deserialize(item.template get(), in); break; + case 8: deserialize(item.template get(), in); break; + case 9: deserialize(item.template get(), in); break; + case 10: deserialize(item.template get(), in); break; + + case 11: deserialize(item.template get(), in); break; + case 12: deserialize(item.template get(), in); break; + case 13: deserialize(item.template get(), in); break; + case 14: deserialize(item.template get(), in); break; + case 15: deserialize(item.template get(), in); break; + + case 16: deserialize(item.template get(), in); break; + case 17: deserialize(item.template get(), in); break; + case 18: deserialize(item.template get(), in); break; + case 19: deserialize(item.template get(), in); break; + case 20: deserialize(item.template get(), in); break; + + default: throw serialization_error("Corrupt data detected while deserializing type_safe_union"); + } + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while deserializing an object of type type_safe_union"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TYPE_SAFE_UNIOn_h_ + diff --git a/ml/dlib/dlib/type_safe_union/type_safe_union_kernel_abstract.h b/ml/dlib/dlib/type_safe_union/type_safe_union_kernel_abstract.h new file mode 100644 index 000000000..3d041e6a3 --- /dev/null +++ b/ml/dlib/dlib/type_safe_union/type_safe_union_kernel_abstract.h @@ -0,0 +1,329 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_TYPE_SAFE_UNION_KERNEl_ABSTRACT_ +#ifdef DLIB_TYPE_SAFE_UNION_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include "../noncopyable.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bad_type_safe_union_cast : public std::bad_cast + { + /*! + This is the exception object thrown by type_safe_union::cast_to() if the + type_safe_union does not contain the type of object being requested. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T1, + typename T2 = _void, // _void indicates parameter not used. + typename T3 = _void, + typename T4 = _void, + typename T5 = _void, + typename T6 = _void, + typename T7 = _void, + typename T8 = _void, + typename T9 = _void, + typename T10 = _void, + typename T11 = _void, + typename T12 = _void, + typename T13 = _void, + typename T14 = _void, + typename T15 = _void, + typename T16 = _void, + typename T17 = _void, + typename T18 = _void, + typename T19 = _void, + typename T20 = _void + > + class type_safe_union : noncopyable + { + /*! + REQUIREMENTS ON ALL TEMPLATE ARGUMENTS + All template arguments must be default constructable and have + a global swap. + + INITIAL VALUE + - is_empty() == true + - contains() == false, for all possible values of U + + WHAT THIS OBJECT REPRESENTS + This object is a type safe analogue of the classic C union object. + The type_safe_union, unlike a union, can contain non-POD types such + as std::string. + + For example: + union my_union + { + int a; + std::string b; // Error, std::string isn't a POD + }; + + type_safe_union my_type_safe_union; // No error + !*/ + + public: + + typedef T1 type1; + typedef T2 type2; + typedef T3 type3; + typedef T4 type4; + typedef T5 type5; + typedef T6 type6; + typedef T7 type7; + typedef T8 type8; + typedef T9 type9; + typedef T10 type10; + typedef T11 type11; + typedef T12 type12; + typedef T13 type13; + typedef T14 type14; + typedef T15 type15; + typedef T16 type16; + typedef T17 type17; + typedef T18 type18; + typedef T19 type19; + typedef T20 type20; + + type_safe_union( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template + type_safe_union ( + const T& item + ); + /*! + requires + - T must be one of the types given to this object's template arguments + ensures + - this object is properly initialized + - #get() == item + (i.e. this object will contain a copy of item) + !*/ + + ~type_safe_union( + ); + /*! + ensures + - all resources associated with this object have been freed + !*/ + + template + static int get_type_id ( + ); + /*! + ensures + - if (T is the same type as one of the template arguments) then + - returns a number indicating which template argument it is. + (e.g. if T is the same type as T3 then this function returns 3) + - else + - returns -1 + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this type_safe_union currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty ( + ) const; + /*! + ensures + - if (this type_safe_union currently contains any object at all) then + - returns true + - else + - returns false + !*/ + + template + void apply_to_contents ( + T& obj + ); + /*! + requires + - obj is a function object capable of operating on all the types contained + in this type_safe_union. I.e. obj(this->get()) must be a valid + expression for all the possible U types. + ensures + - if (is_empty() == false) then + - Let U denote the type of object currently contained in this type_safe_union + - calls obj(this->get()) + - The object returned by this->get() will be non-const + !*/ + + template + void apply_to_contents ( + const T& obj + ); + /*! + requires + - obj is a function object capable of operating on all the types contained + in this type_safe_union. I.e. obj(this->get()) must be a valid + expression for all the possible U types. + ensures + - if (is_empty() == false) then + - Let U denote the type of object currently contained in this type_safe_union + - calls obj(this->get()) + - The object returned by this->get() will be non-const + !*/ + + template + void apply_to_contents ( + T& obj + ) const; + /*! + requires + - obj is a function object capable of operating on all the types contained + in this type_safe_union. I.e. obj(this->get()) must be a valid + expression for all the possible U types. + ensures + - if (is_empty() == false) then + - Let U denote the type of object currently contained in this type_safe_union + - calls obj(this->get()) + - The object returned by this->get() will be const + !*/ + + template + void apply_to_contents ( + const T& obj + ) const; + /*! + requires + - obj is a function object capable of operating on all the types contained + in this type_safe_union. I.e. obj(this->get()) must be a valid + expression for all the possible U types. + ensures + - if (is_empty() == false) then + - Let U denote the type of object currently contained in this type_safe_union + - calls obj(this->get()) + - The object returned by this->get() will be const + !*/ + + template + T& get( + ); + /*! + requires + - T must be one of the types given to this object's template arguments + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in this type_safe_union. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this type_safe_union is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + template + const T& cast_to ( + ) const; + /*! + requires + - T must be one of the types given to this object's template arguments + ensures + - if (contains() == true) then + - returns a const reference to the object contained in this type_safe_union. + - else + - throws bad_type_safe_union_cast + !*/ + + template + T& cast_to ( + ); + /*! + requires + - T must be one of the types given to this object's template arguments + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained in this type_safe_union. + - else + - throws bad_type_safe_union_cast + !*/ + + template + type_safe_union& operator= ( + const T& item + ); + /*! + requires + - T must be one of the types given to this object's template arguments + ensures + - #get() == item + (i.e. this object will contain a copy of item) + - returns *this + !*/ + + void swap ( + type_safe_union& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < ... > + inline void swap ( + type_safe_union<...>& a, + type_safe_union<...>& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < ... > + void serialize ( + const type_safe_union<...>& item, + std::ostream& out + ); + /*! + provides serialization support + + Note that type_safe_union objects are serialized as follows: + - if (item.is_empty()) then + - perform: serialize(0, out) + - else + - perform: serialize(item.get_type_id(), out); + serialize(item.get(), out); + !*/ + + template < ... > + void deserialize ( + type_safe_union<...>& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_TYPE_SAFE_UNION_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/uintn.h b/ml/dlib/dlib/uintn.h new file mode 100644 index 000000000..8e2726546 --- /dev/null +++ b/ml/dlib/dlib/uintn.h @@ -0,0 +1,96 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifndef DLIB_UINtn_ +#define DLIB_UINtn_ + +#include "assert.h" + +namespace dlib +{ + + /*! + uint64 is a typedef for an unsigned integer that is exactly 64 bits wide. + uint32 is a typedef for an unsigned integer that is exactly 32 bits wide. + uint16 is a typedef for an unsigned integer that is exactly 16 bits wide. + uint8 is a typedef for an unsigned integer that is exactly 8 bits wide. + + int64 is a typedef for an integer that is exactly 64 bits wide. + int32 is a typedef for an integer that is exactly 32 bits wide. + int16 is a typedef for an integer that is exactly 16 bits wide. + int8 is a typedef for an integer that is exactly 8 bits wide. + !*/ + + +#ifdef __GNUC__ + typedef unsigned long long uint64; + typedef long long int64; +#elif defined(__BORLANDC__) + typedef unsigned __int64 uint64; + typedef __int64 int64; +#elif defined(_MSC_VER) + typedef unsigned __int64 uint64; + typedef __int64 int64; +#else + typedef unsigned long long uint64; + typedef long long int64; +#endif + + typedef unsigned short uint16; + typedef unsigned int uint32; + typedef unsigned char uint8; + + typedef short int16; + typedef int int32; + typedef char int8; + + + // make sure these types have the right sizes on this platform + COMPILE_TIME_ASSERT(sizeof(uint8) == 1); + COMPILE_TIME_ASSERT(sizeof(uint16) == 2); + COMPILE_TIME_ASSERT(sizeof(uint32) == 4); + COMPILE_TIME_ASSERT(sizeof(uint64) == 8); + + COMPILE_TIME_ASSERT(sizeof(int8) == 1); + COMPILE_TIME_ASSERT(sizeof(int16) == 2); + COMPILE_TIME_ASSERT(sizeof(int32) == 4); + COMPILE_TIME_ASSERT(sizeof(int64) == 8); + + + + template + struct unsigned_type; + template + struct unsigned_type { typedef uint8 type; }; + template + struct unsigned_type { typedef uint16 type; }; + template + struct unsigned_type { typedef uint32 type; }; + template + struct unsigned_type { typedef uint64 type; }; + /*! + ensures + - sizeof(unsigned_type::type) == sizeof(T) + - unsigned_type::type is an unsigned integral type + !*/ + + template + T zero_extend_cast( + const U val + ) + /*! + requires + - U and T are integral types + ensures + - let ut be a typedef for unsigned_type::type + - return static_cast(static_cast(val)); + !*/ + { + typedef typename unsigned_type::type ut; + return static_cast(static_cast(val)); + } + +} + +#endif // DLIB_UINtn_ + diff --git a/ml/dlib/dlib/unicode.h b/ml/dlib/dlib/unicode.h new file mode 100644 index 000000000..3c1598811 --- /dev/null +++ b/ml/dlib/dlib/unicode.h @@ -0,0 +1,9 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_UNICODe_TOP_ +#define DLIB_UNICODe_TOP_ + +#include "unicode/unicode.h" + +#endif // DLIB_UNICODe_TOP_ + diff --git a/ml/dlib/dlib/unicode/unicode.cpp b/ml/dlib/dlib/unicode/unicode.cpp new file mode 100644 index 000000000..2facc919c --- /dev/null +++ b/ml/dlib/dlib/unicode/unicode.cpp @@ -0,0 +1,175 @@ +// Copyright (C) 2008 Keita Mochizuki, Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_UNICODe_CPp_ +#define DLIB_UNICODe_CPp_ +#include "unicode.h" +#include +#include "../string.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + static const unichar SURROGATE_FIRST_TOP = 0xD800; + static const unichar SURROGATE_SECOND_TOP = 0xDC00; + static const unichar SURROGATE_CLEARING_MASK = 0x03FF; + static const unichar SURROGATE_TOP = SURROGATE_FIRST_TOP; + static const unichar SURROGATE_END = 0xE000; + static const unichar SMP_TOP = 0x10000; + static const int VALID_BITS = 10; + +// ---------------------------------------------------------------------------------------- + + template bool is_surrogate(T ch) + { + return (zero_extend_cast(ch) >= SURROGATE_TOP && + zero_extend_cast(ch) < SURROGATE_END); + } + +// ---------------------------------------------------------------------------------------- + + template unichar surrogate_pair_to_unichar(T first, T second) + { + return ((first & SURROGATE_CLEARING_MASK) << VALID_BITS) | ((second & SURROGATE_CLEARING_MASK) + SMP_TOP); + } + //110110 0000000000 + //110111 0000000000 + +// ---------------------------------------------------------------------------------------- + + void unichar_to_surrogate_pair(unichar input, unichar &first, unichar &second) + { + first = ((input - SMP_TOP) >> VALID_BITS) | SURROGATE_FIRST_TOP; + second = (input & SURROGATE_CLEARING_MASK) | SURROGATE_SECOND_TOP; + } + +// ---------------------------------------------------------------------------------------- + + template void wstr2ustring_t(const wchar_t *src, size_t src_len, ustring &dest); + + template <> void wstr2ustring_t<4>(const wchar_t *src, size_t , ustring &dest) + { + dest.assign((const unichar *)(src)); + } + + template <> void wstr2ustring_t<2>(const wchar_t *src, size_t src_len, ustring &dest) + { + size_t wlen = 0; + for (size_t i = 0; i < src_len; i++) + { + is_surrogate(src[i]) ? i++, wlen++ : wlen++; + } + dest.resize(wlen); + for (size_t i = 0, ii = 0; ii < src_len; ++i) + { + if (is_surrogate(src[ii])) + { + dest[i] = surrogate_pair_to_unichar(src[ii], src[ii+1]); + ii += 2; + }else + { + dest[i] = zero_extend_cast(src[ii]); + ii++; + } + } + } + +// ---------------------------------------------------------------------------------------- + + const ustring convert_wstring_to_utf32(const std::wstring &src) + { + ustring dest; + wstr2ustring_t(src.c_str(), src.size(), dest); + return dest; + } + +// ---------------------------------------------------------------------------------------- + + template struct ustring2wstr + { + }; + + // for the environment of sizeof(wchar_t) == 2 (i.e. Win32) + template <> struct ustring2wstr<2> + { + wchar_t *wstr; + size_t wlen; + ustring2wstr(const ustring &src){ + wlen = 0; + for (size_t i = 0; i < src.length(); ++i) + { + if (src[i] < SMP_TOP) wlen++; + else wlen += 2; + } + wstr = new wchar_t[wlen+1]; + wstr[wlen] = L'\0'; + + size_t wi = 0; + for (size_t i = 0; i < src.length(); ++i) + { + if (src[i] < SMP_TOP) + { + wstr[wi++] = (wchar_t)src[i]; + }else + { + unichar high, low; + unichar_to_surrogate_pair(src[i], high, low); + wstr[wi++] = (wchar_t)high; + wstr[wi++] = (wchar_t)low; + } + } + } + ~ustring2wstr() + { + delete[] wstr; + } + }; + + // for the environment of sizeof(wchar_t) == 4 (i.e. Unix gcc) + template <> struct ustring2wstr<4> + { + const wchar_t *wstr; + size_t wlen; + ustring2wstr(const ustring &src){ + wstr = (const wchar_t *)(src.c_str()); + wlen = src.size(); + } + }; + +// ---------------------------------------------------------------------------------------- + + const std::wstring convert_utf32_to_wstring(const ustring &src) + { + ustring2wstr conv(src); + std::wstring dest(conv.wstr); + return dest; + } + +// ---------------------------------------------------------------------------------------- + + const std::wstring convert_mbstring_to_wstring(const std::string &src) + { + std::vector wstr(src.length()+5); + std::mbstowcs(&wstr[0], src.c_str(), src.length()+1); + return std::wstring(&wstr[0]); + } + +// ---------------------------------------------------------------------------------------- + + const std::string convert_wstring_to_mbstring(const std::wstring &src) + { + using namespace std; + std::string str; + str.resize((src.length() + 1) * MB_CUR_MAX); + wcstombs(&str[0], src.c_str(), str.size()); + return std::string(&str[0]); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_UNICODe_CPp_ + diff --git a/ml/dlib/dlib/unicode/unicode.h b/ml/dlib/dlib/unicode/unicode.h new file mode 100644 index 000000000..d7510e34a --- /dev/null +++ b/ml/dlib/dlib/unicode/unicode.h @@ -0,0 +1,622 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net), and Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_UNICODe_H_ +#define DLIB_UNICODe_H_ + +#include "../uintn.h" +#include "../algs.h" +#include "unicode_abstract.h" +#include +#include + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + typedef uint32 unichar; + +#if defined(__GNUC__) && __GNUC__ < 4 && __GNUC_MINOR__ < 4 + struct unichar_traits + { + typedef dlib::unichar char_type; + typedef dlib::unichar int_type; + typedef std::streamoff off_type; + typedef std::streampos pos_type; + typedef std::mbstate_t state_type; + + static void assign(char_type& c1, const char_type& c2) { c1 = c2; } + static bool eq(const char_type& c1, const char_type& c2) { return c1 == c2; } + static bool lt(const char_type& c1, const char_type& c2) { return c1 < c2; } + static int compare(const char_type* s1, const char_type* s2, size_t n) + { + for (size_t i = 0; i < n; ++i) + { + if (s1[i] < s2[i]) + return -1; + else if (s1[i] > s2[i]) + return 1; + } + return 0; + } + + static size_t length(const char_type* s) + { + size_t i = 0; + while (s[i] != 0) + ++i; + return i; + } + + static const char_type* find(const char_type* s, size_t n, + const char_type& a) + { + for (size_t i = 0; i < n; ++i) + { + if (s[i] == a) + { + return s+i; + } + } + return 0; + } + + static char_type* move(char_type* s1, const char_type* s2, size_t n) + { + return static_cast(std::memmove(s1, s2, sizeof(char_type)*n)); + } + + static char_type* copy(char_type* s1, const char_type* s2, size_t n) + { + for (size_t i = 0; i < n; ++i) + s1[i] = s2[i]; + + return s1; + } + + static char_type* assign(char_type* s, size_t n, char_type a) + { + for (size_t i = 0; i < n; ++i) + s[i] = a; + + return s; + } + + + static int_type not_eof(const int_type& c) + { + if (!eq_int_type(c,eof())) + return to_int_type(c); + else + return 0; + } + + static char_type to_char_type(const int_type& c) { return static_cast(c); } + static int_type to_int_type(const char_type& c) { return zero_extend_cast(c); } + + static bool eq_int_type(const int_type& c1, const int_type& c2) { return c1 == c2; } + + static int_type eof() { return static_cast(EOF); } + }; + + typedef std::basic_string ustring; +#else + typedef std::basic_string ustring; +#endif + +// ---------------------------------------------------------------------------------------- + + namespace unicode_helpers + { + + template < + typename charT + > + int u8_to_u32( + charT& result, + std::istream& in + ) + /*! + ensures + - if (there just wasn't any more data and we hit EOF) then + - returns 0 + - else if (we decoded another character without error) then + - #result == the decoded character + - returns the number of bytes consumed to make this character + - else + - some error occurred + - returns -1 + !*/ + { + int val = in.get(); + if (val == EOF) + return 0; + + unichar ch[4]; + ch[0] = zero_extend_cast(val); + if ( ch[0] < 0x80 ) + { + result = static_cast(ch[0]); + return 1; + } + if ( ( ch[0] & ~0x3F ) == 0x80 ) + { + // invalid leading byte + return -1; + } + if ( ( ch[0] & ~0x1F ) == 0xC0 ) + { + val = in.get(); + if ( val == EOF ) + return -1; + + ch[1] = zero_extend_cast(val); + if ( ( ch[1] & ~0x3F ) != 0x80 ) + return -1; // invalid tail + if ( ( ch[0] & ~0x01 ) == 0xC0 ) + return -1; // overlong form + ch[0] &= 0x1F; + ch[1] &= 0x3F; + result = static_cast(( ch[0] << 6 ) | ch[1]); + return 2; + } + if ( ( ch[0] & ~0x0F ) == 0xE0 ) + { + for ( unsigned n = 1;n < 3;n++ ) + { + val = in.get(); + if ( val == EOF ) + return -1; + ch[n] = zero_extend_cast(val); + if ( ( ch[n] & ~0x3F ) != 0x80 ) + return -1; // invalid tail + ch[n] &= 0x3F; + } + ch[0] &= 0x0F; + result = static_cast(( ch[0] << 12 ) | ( ch[1] << 6 ) | ch[2]); + if ( result < 0x0800 ) + return -1; // overlong form + if ( result >= 0xD800 && result < 0xE000 ) + return -1; // invalid character (UTF-16 surrogate pairs) + if ( result >= 0xFDD0 && result <= 0xFDEF ) + return -1; // noncharacter + if ( result >= 0xFFFE ) + return -1; // noncharacter + return 3; + } + if ( ( ch[0] & ~0x07 ) == 0xF0 ) + { + for ( unsigned n = 1;n < 4;n++ ) + { + val = in.get(); + if ( val == EOF ) + return -1; + ch[n] = zero_extend_cast(val); + if ( ( ch[n] & ~0x3F ) != 0x80 ) + return -1; // invalid tail + ch[n] &= 0x3F; + } + if ( ( ch[0] ^ 0xF6 ) < 4 ) + return -1; + ch[0] &= 0x07; + result = static_cast(( ch[0] << 18 ) | ( ch[1] << 12 ) | ( ch[2] << 6 ) | ch[3]); + if ( result < 0x10000 ) + return -1; // overlong form + if ( (result & 0xFFFF) >= 0xFFFE ) + return -1; // noncharacter + return 4; + } + return -1; + } + + // ------------------------------------------------------------------------------------ + + template + class basic_utf8_streambuf : public std::basic_streambuf + { + public: + basic_utf8_streambuf ( + std::ifstream& fin_ + ) : + fin(fin_) + { + this->setg(in_buffer+max_putback, + in_buffer+max_putback, + in_buffer+max_putback); + } + + protected: + + typedef typename std::basic_streambuf::int_type int_type; + + // input functions + int_type underflow( + ) + { + if (this->gptr() < this->egptr()) + { + return zero_extend_cast(*this->gptr()); + } + + int num_put_back = static_cast(this->gptr() - this->eback()); + if (num_put_back > max_putback) + { + num_put_back = max_putback; + } + + // copy the putback characters into the putback end of the in_buffer + std::memmove(in_buffer+(max_putback-num_put_back), this->gptr()-num_put_back, num_put_back); + + + // fill the buffer with characters + int n = in_buffer_size-max_putback; + int i; + for (i = 0; i < n; ++i) + { + charT ch; + if (unicode_helpers::u8_to_u32(ch,fin) > 0) + { + (in_buffer+max_putback)[i] = ch; + } + else + { + break; + } + } + + if (i == 0) + { + // an error occurred or we hit EOF + return EOF; + } + + // reset in_buffer pointers + this->setg (in_buffer+(max_putback-num_put_back), + in_buffer+max_putback, + in_buffer+max_putback+i); + + return zero_extend_cast(*this->gptr()); + } + + private: + std::ifstream& fin; + static const int max_putback = 4; + static const int in_buffer_size = 10; + charT in_buffer[in_buffer_size]; + }; + } + +// ---------------------------------------------------------------------------------------- +#if defined(__GNUC__) && __GNUC__ >= 6 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmisleading-indentation" +#endif + template + bool is_combining_char( + const T ch_ + ) + { + const unichar ch = zero_extend_cast(ch_); + if ( ch < 0x300 ) return false; + if ( ch < 0x370 ) return true; + + if ( ch < 0x800 ) + { + if ( ch < 0x483 )return false;if ( ch < 0x48A )return true; + + if ( ch < 0x591 )return false;if ( ch < 0x5D0 ) + { + if ( ch == 0x5C0 )return false; + if ( ch == 0x5C3 )return false; + if ( ch == 0x5C6 )return false; + return true; + } + if ( ch < 0x610 )return false;if ( ch < 0x616 )return true; + if ( ch < 0x64B )return false;if ( ch < 0x660 )return true; + + if ( ch == 0x670 )return true; + + if ( ch < 0x6D6 )return false;if ( ch < 0x6EE ) + { + if ( ch == 0x6DD )return false; + if ( ch == 0x6E5 )return false; + if ( ch == 0x6E6 )return false; + if ( ch == 0x6E9 )return false; + return true; + } + if ( ch == 0x711 )return true; + + if ( ch < 0x730 )return false;if ( ch < 0x74B )return true; + if ( ch < 0x7A6 )return false;if ( ch < 0x7B1 )return true; + if ( ch < 0x7EB )return false;if ( ch < 0x7F4 )return true; + return false; + } + if ( ch < 0xA00 ) + { + if ( ch < 0x901 )return false;if ( ch < 0x904 )return true; + if ( ch < 0x93C )return false;if ( ch < 0x955 ) + { + if ( ch == 0x93D )return false; + if ( ch == 0x950 )return false; + return true; + } + if ( ch < 0x962 )return false;if ( ch < 0x964 )return true; + if ( ch < 0x981 )return false;if ( ch < 0x984 )return true; + if ( ch < 0x9BC )return false;if ( ch < 0x9D8 ) + { + if ( ch == 0x9BD )return false; + if ( ch == 0x9CE )return false; + return true; + } + if ( ch < 0x9E2 )return false;if ( ch < 0x9E4 )return true; + return false; + } + if ( ch < 0xC00 ) + { + if ( ch < 0xA01 )return false;if ( ch < 0xA04 )return true; + if ( ch < 0xA3C )return false;if ( ch < 0xA4E )return true; + if ( ch < 0xA70 )return false;if ( ch < 0xA72 )return true; + if ( ch < 0xA81 )return false;if ( ch < 0xA84 )return true; + if ( ch < 0xABC )return false;if ( ch < 0xACE ) + { + if ( ch == 0xABD )return false; + return true; + } + if ( ch < 0xAE2 )return false;if ( ch < 0xAE4 )return true; + if ( ch < 0xB01 )return false;if ( ch < 0xB04 )return true; + if ( ch < 0xB3C )return false;if ( ch < 0xB58 ) + { + if ( ch == 0xB3D )return false; + return true; + } + if ( ch == 0xB82 )return true; + + if ( ch < 0xBBE )return false;if ( ch < 0xBD8 )return true; + + if ( ch == 0xBF4 )return true; + if ( ch == 0xBF8 )return true; + return false; + } + if(ch < 0xE00) + { + if ( ch < 0xC01 )return false;if ( ch < 0xC04 )return true; + if ( ch < 0xC3E )return false;if ( ch < 0xC57 )return true; + if ( ch < 0xC82 )return false;if ( ch < 0xC84 )return true; + if ( ch < 0xCBC )return false;if ( ch < 0xCD7 ) + { + if ( ch == 0xCBD )return false; + return true; + } + if ( ch < 0xCE2 )return false;if ( ch < 0xCE4 )return true; + if ( ch < 0xD02 )return false;if ( ch < 0xD04 )return true; + if ( ch < 0xD3E )return false;if ( ch < 0xD58 )return true; + if ( ch < 0xD82 )return false;if ( ch < 0xD84 )return true; + if ( ch < 0xDCA )return false;if ( ch < 0xDF4 )return true; + return false; + } + if(ch < 0x1000) + { + if ( ch == 0xE31 )return true; + + if ( ch < 0xE34 )return false;if ( ch < 0xE3B )return true; + if ( ch < 0xE47 )return false;if ( ch < 0xE4F )return true; + + if ( ch == 0xEB1 )return true; + + if ( ch < 0xEB4 )return false;if ( ch < 0xEBD )return true; + if ( ch < 0xEC8 )return false;if ( ch < 0xECE )return true; + if ( ch < 0xF18 )return false;if ( ch < 0xF1A )return true; + + if ( ch == 0xF35 )return true; + if ( ch == 0xF37 )return true; + if ( ch == 0xF39 )return true; + + if ( ch < 0xF3E )return false;if ( ch < 0xF40 )return true; + if ( ch < 0xF71 )return false;if ( ch < 0xF88 ) + { + if ( ch == 0xF85 )return false; + return true; + } + if ( ch < 0xF90 )return false;if ( ch < 0xFBD )return true; + + if ( ch == 0xFC6 )return true; + return false; + } + if ( ch < 0x1800 ) + { + if ( ch < 0x102C )return false;if ( ch < 0x1040 )return true; + if ( ch < 0x1056 )return false;if ( ch < 0x105A )return true; + + if ( ch == 0x135F )return true; + + if ( ch < 0x1712 )return false;if ( ch < 0x1715 )return true; + if ( ch < 0x1732 )return false;if ( ch < 0x1735 )return true; + if ( ch < 0x1752 )return false;if ( ch < 0x1754 )return true; + if ( ch < 0x1772 )return false;if ( ch < 0x1774 )return true; + if ( ch < 0x17B6 )return false;if ( ch < 0x17D4 )return true; + + if ( ch == 0x17DD )return true; + return false; + } + if(ch < 0x2000) + { + if ( ch < 0x180B )return false;if ( ch < 0x180E )return true; + + if ( ch == 0x18A9 )return true; + + if ( ch < 0x1920 )return false;if ( ch < 0x193C )return true; + if ( ch < 0x19B0 )return false;if ( ch < 0x19C1 )return true; + if ( ch < 0x19C8 )return false;if ( ch < 0x19CA )return true; + if ( ch < 0x1A17 )return false;if ( ch < 0x1A1C )return true; + if ( ch < 0x1B00 )return false;if ( ch < 0x1B05 )return true; + if ( ch < 0x1B34 )return false;if ( ch < 0x1B45 )return true; + if ( ch < 0x1B6B )return false;if ( ch < 0x1B74 )return true; + if ( ch < 0x1DC0 )return false;if ( ch < 0x1E00 )return true; + return false; + } + if ( ch < 0x20D0 )return false;if ( ch < 0x2100 )return true; + if ( ch < 0x302A )return false;if ( ch < 0x3030 )return true; + if ( ch < 0x3099 )return false;if ( ch < 0x309B )return true; + + if ( ch == 0xA802 )return true; + if ( ch == 0xA806 )return true; + if ( ch == 0xA80B )return true; + + if ( ch < 0xA823 )return false;if ( ch < 0xA828 )return true; + + if ( ch == 0xFB1E )return true; + + if ( ch < 0xFE00 )return false;if ( ch < 0xFE10 )return true; + if ( ch < 0xFE20 )return false;if ( ch < 0xFE30 )return true; + if ( ch < 0x10A01 )return false;if ( ch < 0x10A10 )return true; + if ( ch < 0x10A38 )return false;if ( ch < 0x10A40 )return true; + if ( ch < 0x1D165 )return false;if ( ch < 0x1D16A )return true; + if ( ch < 0x1D16D )return false;if ( ch < 0x1D173 )return true; + if ( ch < 0x1D17B )return false;if ( ch < 0x1D183 )return true; + if ( ch < 0x1D185 )return false;if ( ch < 0x1D18C )return true; + if ( ch < 0x1D1AA )return false;if ( ch < 0x1D1AE )return true; + if ( ch < 0x1D242 )return false;if ( ch < 0x1D245 )return true; + if ( ch < 0xE0100 )return false;if ( ch < 0xE01F0 )return true; + return false; + } +#if defined(__GNUC__) && __GNUC__ >= 6 +#pragma GCC diagnostic pop +#endif + +// ---------------------------------------------------------------------------------------- + + class invalid_utf8_error : public error + { + public: + invalid_utf8_error():error(EUTF8_TO_UTF32) {} + }; + + inline const ustring convert_utf8_to_utf32 ( + const std::string& str + ) + { + using namespace unicode_helpers; + ustring temp; + std::istringstream sin(str); + + temp.reserve(str.size()); + + int status; + unichar ch; + while ( (status = u8_to_u32(ch,sin)) > 0) + temp.push_back(ch); + + if (status < 0) + throw invalid_utf8_error(); + + return temp; + } + +// ---------------------------------------------------------------------------------------- + + bool is_surrogate(unichar ch); + + unichar surrogate_pair_to_unichar(unichar first, unichar second); + + void unichar_to_surrogate_pair(unichar unicode, unichar &first, unichar &second); + + + const ustring convert_wstring_to_utf32 ( + const std::wstring &wstr + ); + + const std::wstring convert_utf32_to_wstring ( + const ustring &src + ); + + const std::wstring convert_mbstring_to_wstring ( + const std::string &src + ); + + const std::string convert_wstring_to_mbstring( + const std::wstring &src + ); + +// ---------------------------------------------------------------------------------------- + + template + class basic_utf8_ifstream : public std::basic_istream + { + public: + + basic_utf8_ifstream ( + ) : std::basic_istream(&buf), buf(fin) {} + + basic_utf8_ifstream ( + const char* file_name, + std::ios_base::openmode mode = std::ios::in + ) : + std::basic_istream(&buf), + buf(fin) + { + fin.open(file_name,mode); + // make this have the same error state as fin + this->clear(fin.rdstate()); + } + + basic_utf8_ifstream ( + const std::string& file_name, + std::ios_base::openmode mode = std::ios::in + ) : + std::basic_istream(&buf), + buf(fin) + { + fin.open(file_name.c_str(),mode); + // make this have the same error state as fin + this->clear(fin.rdstate()); + } + + void open( + const std::string& file_name, + std::ios_base::openmode mode = std::ios::in + ) + { + open(file_name.c_str(),mode); + } + + void open ( + const char* file_name, + std::ios_base::openmode mode = std::ios::in + ) + { + fin.close(); + fin.clear(); + fin.open(file_name,mode); + // make this have the same error state as fin + this->clear(fin.rdstate()); + } + + void close ( + ) + { + fin.close(); + // make this have the same error state as fin + this->clear(fin.rdstate()); + } + + private: + + std::ifstream fin; + unicode_helpers::basic_utf8_streambuf buf; + }; + + typedef basic_utf8_ifstream utf8_uifstream; + typedef basic_utf8_ifstream utf8_wifstream; + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "unicode.cpp" +#endif + +#endif // DLIB_UNICODe_H_ + diff --git a/ml/dlib/dlib/unicode/unicode_abstract.h b/ml/dlib/dlib/unicode/unicode_abstract.h new file mode 100644 index 000000000..ed5b9ab4e --- /dev/null +++ b/ml/dlib/dlib/unicode/unicode_abstract.h @@ -0,0 +1,233 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net), and Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_UNICODe_ABSTRACT_H_ +#ifdef DLIB_UNICODe_ABSTRACT_H_ + +#include "../uintn.h" +#include "../error.h" +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + // a typedef for an unsigned 32bit integer to hold our UNICODE characters + typedef uint32 unichar; + + // a typedef for a string object to hold our UNICODE strings + typedef std::basic_string ustring; + +// ---------------------------------------------------------------------------------------- + + template + bool is_combining_char( + const T ch_ + ); + /*! + ensures + - if (ch_ is a unicode combining character) then + - returns true + - else + - returns false + !*/ + + bool is_surrogate( + unichar ch + ); + /*! + ensures + - if (ch is a unicode surrogate character) then + - returns true + - else + - returns false + !*/ + + unichar surrogate_pair_to_unichar( + unichar first, + unichar second + ); + /*! + requires + - 0xD800 <= first < 0xDC00 + - 0xDC00 <= second < 0xE000 + - is_surrogate(first) == true + - is_surrogate(second) == true + ensures + - converts two surrogates into one unicode character + !*/ + + void unichar_to_surrogate_pair( + unichar ch, + unichar& first, + unichar& second + ); + /*! + requires + - ch >= 0x10000 (i.e. is not in Basic Multilingual Plane) + ensures + - surrogate_pair_to_unichar(#first,#second) == ch + (i.e. converts ch into two surrogate characters) + !*/ + +// ---------------------------------------------------------------------------------------- + + class invalid_utf8_error : public error + { + public: + invalid_utf8_error():error(EUTF8_TO_UTF32) {} + }; + + const ustring convert_utf8_to_utf32 ( + const std::string& str + ); + /*! + ensures + - if (str is a valid UTF-8 encoded string) then + - returns a copy of str that has been converted into a + unichar string + - else + - throws invalid_utf8_error + !*/ + +// ---------------------------------------------------------------------------------------- + + const ustring convert_wstring_to_utf32 ( + const std::wstring &wstr + ); + /*! + requires + - wstr is a valid UTF-16 string when sizeof(wchar_t) == 2 + - wstr is a valid UTF-32 string when sizeof(wchar_t) == 4 + ensures + - converts wstr into UTF-32 string + !*/ + +// ---------------------------------------------------------------------------------------- + + const std::wstring convert_utf32_to_wstring ( + const ustring &str + ); + /*! + requires + - str is a valid UTF-32 encoded string + ensures + - converts str into wstring whose encoding is UTF-16 when sizeof(wchar_t) == 2 + - converts str into wstring whose encoding is UTF-32 when sizeof(wchar_t) == 4 + !*/ + +// ---------------------------------------------------------------------------------------- + + const std::wstring convert_mbstring_to_wstring ( + const std::string &str + ); + /*! + requires + - str is a valid multibyte string whose encoding is same as current locale setting + ensures + - converts str into wstring whose encoding is UTF-16 when sizeof(wchar_t) == 2 + - converts str into wstring whose encoding is UTF-32 when sizeof(wchar_t) == 4 + !*/ + +// ---------------------------------------------------------------------------------------- + + const std::string convert_wstring_to_mbstring ( + const std::wstring &src + ); + /*! + requires + - str is a valid wide character string string whose encoding is same as current + locale setting + ensures + - returns a multibyte encoded version of the given string + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT + > + class basic_utf8_ifstream : public std::basic_istream + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an input file stream much like the + normal std::ifstream except that it knows how to read UTF-8 + data. So when you read characters out of this stream it will + automatically convert them from the UTF-8 multibyte encoding + into a fixed width wide character encoding. + !*/ + + public: + + basic_utf8_ifstream ( + ); + /*! + ensures + - constructs an input stream that isn't yet associated with + a file. + !*/ + + basic_utf8_ifstream ( + const char* file_name, + std::ios_base::openmode mode = std::ios::in + ); + /*! + ensures + - tries to open the given file for reading by this stream + - mode is interpreted exactly the same was as the open mode + argument used by std::ifstream. + !*/ + + basic_utf8_ifstream ( + const std::string& file_name, + std::ios_base::openmode mode = std::ios::in + ); + /*! + ensures + - tries to open the given file for reading by this stream + - mode is interpreted exactly the same was as the open mode + argument used by std::ifstream. + !*/ + + void open( + const std::string& file_name, + std::ios_base::openmode mode = std::ios::in + ); + /*! + ensures + - tries to open the given file for reading by this stream + - mode is interpreted exactly the same was as the open mode + argument used by std::ifstream. + !*/ + + void open ( + const char* file_name, + std::ios_base::openmode mode = std::ios::in + ); + /*! + ensures + - tries to open the given file for reading by this stream + - mode is interpreted exactly the same was as the open mode + argument used by std::ifstream. + !*/ + + void close ( + ); + /*! + ensures + - any file opened by this stream has been closed + !*/ + }; + + typedef basic_utf8_ifstream utf8_uifstream; + typedef basic_utf8_ifstream utf8_wifstream; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_UNICODe_ABSTRACT_H_ + + diff --git a/ml/dlib/dlib/unordered_pair.h b/ml/dlib/dlib/unordered_pair.h new file mode 100644 index 000000000..9ea75b912 --- /dev/null +++ b/ml/dlib/dlib/unordered_pair.h @@ -0,0 +1,176 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_UNORDERED_PAiR_Hh_ +#define DLIB_UNORDERED_PAiR_Hh_ + +#include "serialize.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct unordered_pair + { + /*! + REQUIREMENTS ON T + T must be default constructable, copyable, and comparable using + operator < and == + + WHAT THIS OBJECT REPRESENTS + This object is very similar to the std::pair struct except unordered_pair + is only capable of representing an unordered set of two items rather than + an ordered list of two items like std::pair. + + This is best illustrated by example. Suppose we have the following + five variables: + std::pair p1(1, 5), p2(5,1); + unordered_pair up1(1,5), up2(5,1), up3(6,7); + + Then it is the case that: + up1 == up2 + up1 != up3 + p1 != p2 + + So the unordered_pair doesn't care about the order of the arguments. + In this case, up1 and up2 are both equivalent. + + !*/ + + typedef T type; + typedef T first_type; + typedef T second_type; + + const T first; + const T second; + + unordered_pair() : first(), second() + /*! + ensures + - #first and #second are default initialized + !*/ {} + + unordered_pair( + const T& a, + const T& b + ) : + first( a < b ? a : b), + second(a < b ? b : a) + /*! + ensures + - #first <= #second + - #first and #second contain copies of the items a and b. + !*/ {} + + template + unordered_pair ( + const unordered_pair & p + ) : + first(p.first), + second(p.second) + /*! + ensures + - #*this is a copy of p + !*/ {} + + unordered_pair& operator= ( + const unordered_pair& item + ) + /*! + ensures + - #*this == item + !*/ + { + const_cast(first) = item.first; + const_cast(second) = item.second; + return *this; + } + }; + +// ---------------------------------------------------------------------------------------- + + template + bool operator==(const unordered_pair& a, const unordered_pair & b) + { + return a.first == b.first && a.second == b.second; + } + + template + bool operator!=(const unordered_pair& a, const unordered_pair & b) + { + return !(a == b); + } + + template + bool operator<(const unordered_pair& a, const unordered_pair& b) + { + return (a.first < b.first || (!(b.first < a.first) && a.second < b.second)); + } + + template + bool operator>(const unordered_pair& a, const unordered_pair & b) + { + return b < a; + } + + template + bool operator<=(const unordered_pair& a, const unordered_pair & b) + { + return !(b < a); + } + + template + bool operator>=(const unordered_pair& a, const unordered_pair & b) + { + return !(a < b); + } + + template + unordered_pair make_unordered_pair (const T& a, const T& b) + { + return unordered_pair(a,b); + } + +// ---------------------------------------------------------------------------------------- + + template + void serialize ( + const unordered_pair& item, + std::ostream& out + ) + { + try + { + serialize(item.first,out); + serialize(item.second,out); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while serializing object of type unordered_pair"); } + } + + template + void deserialize ( + unordered_pair& item, + std::istream& in + ) + { + try + { + T a, b; + deserialize(a,in); + deserialize(b,in); + item = make_unordered_pair(a,b); + } + catch (serialization_error& e) + { throw serialization_error(e.info + "\n while deserializing object of type unordered_pair"); } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_UNORDERED_PAiR_Hh_ + diff --git a/ml/dlib/dlib/vectorstream.h b/ml/dlib/dlib/vectorstream.h new file mode 100644 index 000000000..afeb247de --- /dev/null +++ b/ml/dlib/dlib/vectorstream.h @@ -0,0 +1,11 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_VECTORSTReAMh_ +#define DLIB_VECTORSTReAMh_ + +#include "vectorstream/vectorstream.h" +#include "vectorstream/unserialize.h" + + +#endif // DLIB_VECTORSTReAMh_ + diff --git a/ml/dlib/dlib/vectorstream/unserialize.h b/ml/dlib/dlib/vectorstream/unserialize.h new file mode 100644 index 000000000..dbfe5584b --- /dev/null +++ b/ml/dlib/dlib/vectorstream/unserialize.h @@ -0,0 +1,98 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_uNSERIALIZE_Hh_ +#define DLIB_uNSERIALIZE_Hh_ + +#include "unserialize_abstract.h" + +#include "../serialize.h" +#include "../algs.h" +#include "vectorstream.h" + + + +namespace dlib +{ + class unserialize : public std::istream + { + class mystreambuf : public std::streambuf + { + typedef std::vector::size_type size_type; + size_type read_pos; // buffer[read_pos] == next byte to read from buffer + public: + std::vector buffer; + std::istream& str; + + template + mystreambuf( + const T& item, + std::istream& str_ + ) : + read_pos(0), + str(str_) + { + // put the item into our buffer. + vectorstream vstr(buffer); + serialize(item, vstr); + } + + + // ------------------------ INPUT FUNCTIONS ------------------------ + + int_type underflow( + ) + { + if (read_pos < buffer.size()) + return static_cast(buffer[read_pos]); + else + return str.peek(); + } + + int_type uflow( + ) + { + if (read_pos < buffer.size()) + return static_cast(buffer[read_pos++]); + else + return str.get(); + } + + std::streamsize xsgetn ( + char* s, + std::streamsize n + ) + { + if (read_pos < buffer.size()) + { + const size_type num = std::min(n, buffer.size()-read_pos); + std::memcpy(s, &buffer[read_pos], num); + read_pos += num; + return num; + } + else + { + return str.rdbuf()->sgetn(s,n); + } + return 0; + } + + }; + + public: + + template + unserialize ( + const T& item, + std::istream& str + ) : + std::istream(&buf), + buf(item, str) + {} + + private: + mystreambuf buf; + }; +} + +#endif // DLIB_uNSERIALIZE_Hh_ + diff --git a/ml/dlib/dlib/vectorstream/unserialize_abstract.h b/ml/dlib/dlib/vectorstream/unserialize_abstract.h new file mode 100644 index 000000000..b7d67836c --- /dev/null +++ b/ml/dlib/dlib/vectorstream/unserialize_abstract.h @@ -0,0 +1,58 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_uNSERIALIZE_ABSTRACT_Hh_ +#ifdef DLIB_uNSERIALIZE_ABSTRACT_Hh_ + +#include "../serialize.h" +#include + +namespace dlib +{ + class unserialize : public std::istream + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool that allows you to effectively put an object you just + deserialized from a stream back into the stream. Its use is best + illustrated via an example. + + void example(std::istream& in) + { + // Suppose that in contains serialized copies of three "some_type" + // objects. You could read them as follows: + some_type obj1, obj2, obj3; + + deserialize(obj1, in); // reads obj1 from stream. + deserialize(obj2, in); // reads obj2 from stream. + + unserialize in2(obj2, in); // make the in2 stream that has obj2 at its front. + deserialize(obj2, in2); // reads obj2 from stream again. + deserialize(obj3, in2); // reads obj3 from stream. + } + + The reason unserialize is useful is because it allows you to peek at the + next object in a stream and potentially do something different based on + what object is coming next, but still allowing subsequent deserialize() + statements to be undisturbed by the fact that you peeked at the data. + !*/ + + public: + + template + unserialize ( + const T& item, + std::istream& in + ); + /*! + requires + - T must be serializable + ensures + - The bytes in this stream begin with a serialized copy of item followed + immediately by the bytes in the given istream. + !*/ + }; +} + +#endif // DLIB_uNSERIALIZE_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/vectorstream/vectorstream.h b/ml/dlib/dlib/vectorstream/vectorstream.h new file mode 100644 index 000000000..a1b7b8b07 --- /dev/null +++ b/ml/dlib/dlib/vectorstream/vectorstream.h @@ -0,0 +1,138 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_VECTORStREAM_Hh_ +#define DLIB_VECTORStREAM_Hh_ + +#include "vectorstream_abstract.h" + +#include +#include +#include +#include +#include +#include "../algs.h" + + +#ifdef _MSC_VER +// Disable the warning about inheriting from std::iostream 'via dominance' since this warning is a warning about +// visual studio conforming to the standard and is ignorable. See http://connect.microsoft.com/VisualStudio/feedback/details/733720/inheriting-from-std-fstream-produces-c4250-warning +// for further details if interested. +#pragma warning(disable : 4250) +#endif // _MSC_VER + +namespace dlib +{ + class vectorstream : public std::iostream + { + class vector_streambuf : public std::streambuf + { + typedef std::vector::size_type size_type; + size_type read_pos; // buffer[read_pos] == next byte to read from buffer + public: + std::vector& buffer; + + vector_streambuf( + std::vector& buffer_ + ) : + read_pos(0), + buffer(buffer_) + {} + + + void seekg(size_type pos) + { + read_pos = pos; + } + + // ------------------------ OUTPUT FUNCTIONS ------------------------ + + int_type overflow ( int_type c) + { + if (c != EOF) buffer.push_back(static_cast(c)); + return c; + } + + std::streamsize xsputn ( const char* s, std::streamsize num) + { + buffer.insert(buffer.end(), s, s+num); + return num; + } + + // ------------------------ INPUT FUNCTIONS ------------------------ + + int_type underflow( + ) + { + if (read_pos < buffer.size()) + return static_cast(buffer[read_pos]); + else + return EOF; + } + + int_type uflow( + ) + { + if (read_pos < buffer.size()) + return static_cast(buffer[read_pos++]); + else + return EOF; + } + + int_type pbackfail( + int_type c + ) + { + // if they are trying to push back a character that they didn't read last + // that is an error + const unsigned long prev = read_pos-1; + if (c != EOF && prev < buffer.size() && + c != static_cast(buffer[prev])) + { + return EOF; + } + + read_pos = prev; + return 1; + } + + std::streamsize xsgetn ( + char* s, + std::streamsize n + ) + { + if (read_pos < buffer.size()) + { + const size_type num = std::min(n, buffer.size()-read_pos); + std::memcpy(s, &buffer[read_pos], num); + read_pos += num; + return num; + } + return 0; + } + + }; + + public: + + vectorstream ( + std::vector& buffer + ) : + std::iostream(&buf), + buf(buffer) + {} + + std::istream& seekg ( + std::streampos pos + ) + { + buf.seekg(pos); + return *this; + } + + private: + vector_streambuf buf; + }; +} + +#endif // DLIB_VECTORStREAM_Hh_ + diff --git a/ml/dlib/dlib/vectorstream/vectorstream_abstract.h b/ml/dlib/dlib/vectorstream/vectorstream_abstract.h new file mode 100644 index 000000000..f5fe1004c --- /dev/null +++ b/ml/dlib/dlib/vectorstream/vectorstream_abstract.h @@ -0,0 +1,62 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_VECTORStREAM_ABSTRACT_Hh_ +#ifdef DLIB_VECTORStREAM_ABSTRACT_Hh_ + +#include +#include + +namespace dlib +{ + class vectorstream : public std::iostream + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an iostream object that reads and writes from an in-memory buffer. + It functions very much the same way as the std::stringstream object. + However, while the std::stringstream holds its buffer internally and it can + only be accessed by copying it out, the vectorstream uses an external + std::vector as its buffer. That is, it holds a reference to an + external vector and does not contain any internal buffers of its own. + + This object is useful as a slightly more efficient alternative to the + std::stringstream since you can avoid the overhead of copying buffer + contents to and from the stream. This is particularly useful when used as + a source or target for serialization routines. + !*/ + + public: + + vectorstream ( + std::vector& buffer + ); + /*! + ensures + - This object will use the given vector as its read/write buffer. That is: + - Any data written to this stream will be appended to the given buffer + - Any data read from this stream is read from the given buffer, + starting with buffer[0], then buffer[1], and so on. Just like + std::stringstream, writes to the stream do not move the position of + the next byte that will be read from the buffer. + - This constructor does not copy the buffer. Only a reference to it will + be used. Therefore, any time data is written to this stream it will + immediately show up in the buffer. + !*/ + + std::istream& seekg ( + std::streampos pos + ); + /*! + ensures + - The next read from this object will read from the position buffer[pos], + where buffer is the std::vector given to this object's constructor. Note + that if pos >= buffer.size() then the next read will simply return EOF. + - returns *this + !*/ + + }; +} + +#endif // DLIB_VECTORStREAM_ABSTRACT_Hh_ + + diff --git a/ml/dlib/dlib/windows_magic.h b/ml/dlib/dlib/windows_magic.h new file mode 100644 index 000000000..cfb7f22ed --- /dev/null +++ b/ml/dlib/dlib/windows_magic.h @@ -0,0 +1,50 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_WINDOWS_MAGIc_ +#define DLIB_WINDOWS_MAGIc_ + +#include "platform.h" + +#ifdef WIN32 + +// This file contains all the magical #defines you have to setup before you +// include the windows header files. + +#ifndef NOMINMAX +#define NOMINMAX // prevent windows from messing with std::min and std::max +#endif + +// Prevent windows from #defining IN or OUT +#ifndef _NO_W32_PSEUDO_MODIFIERS +#define _NO_W32_PSEUDO_MODIFIERS +#endif + +// now just for good measure undefine min and max if they are defined +#ifdef min +#undef min +#endif + +#ifdef max +#undef max +#endif + +#ifdef NO_MAKEFILE +// only define this if all the cpp files are going to be sucked into the headers +// because otherwise we don't need it since everything is isolated in the sockets +// cpp file and this declaration for _WINSOCKAPI_ appears there also. +#ifndef _WINSOCKAPI_ +#define _WINSOCKAPI_ /* Prevent inclusion of winsock.h in windows.h */ +#endif +#endif + +// This is something stupid you have to do to make visual studio include the right +// stuff. I don't really know what the deal is with this. +#if _WIN32_WINNT < 0x0500 +#undef _WIN32_WINNT +#define _WIN32_WINNT 0x0500 +#endif + +#endif // WIN32 + +#endif // DLIB_WINDOWS_MAGIc_ + diff --git a/ml/dlib/dlib/xml_parser.h b/ml/dlib/dlib/xml_parser.h new file mode 100644 index 000000000..4e60ed68f --- /dev/null +++ b/ml/dlib/dlib/xml_parser.h @@ -0,0 +1,13 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_XML_PARSEr_ +#define DLIB_XML_PARSEr_ + +#include + +#include "xml_parser/xml_parser_kernel_interfaces.h" +#include "xml_parser/xml_parser_kernel_1.h" + + +#endif // DLIB_XML_PARSEr_ + diff --git a/ml/dlib/dlib/xml_parser/xml_parser_kernel_1.h b/ml/dlib/dlib/xml_parser/xml_parser_kernel_1.h new file mode 100644 index 000000000..e1854bc26 --- /dev/null +++ b/ml/dlib/dlib/xml_parser/xml_parser_kernel_1.h @@ -0,0 +1,1532 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_XML_PARSER_KERNEl_1_ +#define DLIB_XML_PARSER_KERNEl_1_ + + +#include "xml_parser_kernel_abstract.h" + +#include +#include +#include +#include +#include "xml_parser_kernel_interfaces.h" +#include "../algs.h" +#include +#include "../map.h" +#include "../stack.h" +#include "../sequence.h" +#include "../memory_manager.h" + +namespace dlib +{ + + class xml_parser + { + typedef dlib::map::kernel_2a>::kernel_1b map; + typedef dlib::stack::kernel_2a>::kernel_1a stack; + typedef sequence::kernel_2a seq_dh; + typedef sequence::kernel_2a seq_eh; + + /*! + INITIAL VALUE + dh_list.size() == 0 + eh_list.size() == 0 + + CONVENTION + dh_list == a sequence of pointers to all the document_handlers that + have been added to the xml_parser + eh_list == a sequence of pointers to all the error_handlers that + have been added to the xml_parser + + map is used to implement the attribute_list interface + stack is used just inside the parse function + seq_dh is used to make the dh_list member variable + seq_eh is used to make the eh_list member variable + !*/ + + + + public: + + // These typedefs are here for backwards compatibly with previous versions of + // dlib. + typedef xml_parser kernel_1a; + typedef xml_parser kernel_1a_c; + + xml_parser( + ) {} + + virtual ~xml_parser( + ){} + + inline void clear( + ); + + inline void parse ( + std::istream& in + ); + + inline void add_document_handler ( + document_handler& item + ); + + inline void add_error_handler ( + error_handler& item + ); + + + inline void swap ( + xml_parser& item + ); + + + private: + + // ----------------------------------- + + // attribute_list interface implementation + class attrib_list : public attribute_list + { + public: + // the list of attribute name/value pairs + map list; + + bool is_in_list ( + const std::string& key + ) const + { + return list.is_in_domain(key); + } + + const std::string& operator[] ( + const std::string& key + ) const + { + if (is_in_list(key)) + return list[key]; + else + throw xml_attribute_list_error("No XML attribute named " + key + " is present in tag."); + } + + bool at_start ( + ) const { return list.at_start(); } + + void reset ( + ) const { return list.reset(); } + + bool current_element_valid ( + ) const { return list.current_element_valid(); } + + const type& element ( + ) const { return list.element(); } + + type& element ( + ) { return list.element(); } + + bool move_next ( + ) const { return list.move_next(); } + + size_t size ( + ) const { return list.size(); } + }; + + + // ----------------------------------- + + enum token_type + { + element_start, // the first tag of an element + element_end, // the last tag of an element + empty_element, // the singular tag of an empty element + pi, // processing instruction + chars, // the non-markup data between tags + chars_cdata, // the data from a CDATA section + eof, // this token is returned when we reach the end of input + error, // this token indicates that the tokenizer couldn't + // determine which category the next token fits into + dtd, // this token is for an entire dtd + comment // this is a token for comments + }; + /* + notes about the tokens: + the tokenizer guarantees that the following tokens to not + contain the '<' character except as the first character of the token + element_start, element_end, empty_element, and pi. they also only + contain the '>' characer as their last character. + + it is also guaranteed that pi is at least of the form . that + is to say that it always always begins with . + + it is also guaranteed that all markup tokens will begin with the '<' + character and end with the '>'. there won't be any leading or + trailing whitespaces. this whitespace is considered a chars token. + */ + + + // private member functions + inline void get_next_token( + std::istream& in, + std::string& token_text, + int& token_kind, + unsigned long& line_number + ); + /*! + ensures + gets the next token from in and puts it in token_text and + token_kind == the kind of the token found and + line_number is incremented every time a '\n' is encountered and + entity references are translated into the characters they represent + only for chars tokens + !*/ + + inline int parse_element ( + const std::string& token, + std::string& name, + attrib_list& atts + ); + /*! + requires + token is a token of kind start_element or empty_element + ensures + gets the element name and puts it into the string name and + parses out the attributes and puts them into the attribute_list atts + + return 0 upon success or + returns -1 if it failed to parse token + !*/ + + inline int parse_pi ( + const std::string& token, + std::string& target, + std::string& data + ); + /*! + requires + token is a token of kind pi + ensures + the target from the processing instruction is put into target and + the data from the processing instruction is put into data + + return 0 upon success or + returns -1 if it failed to parse token + !*/ + + inline int parse_element_end ( + const std::string& token, + std::string& name + ); + /*! + requires + token is a token of kind element_end + ensures + the name from the ending element tag is put into the string name + + return 0 upon success or + returns -1 if it failed to parse token + !*/ + + inline int change_entity ( + std::istream& in + ); + /*! + ensures + performs the following translations and returns the new character + amp; -> & + lt; -> < + gt; -> > + apos; -> ' + quot; -> " + + or returns -1 if we hit an undefined entity reference or EOF. + (i.e. it was not one of the entities listed above) + + !*/ + + // ----------------------------------- + + // private member data + seq_dh dh_list; + seq_eh eh_list; + + // ----------------------------------- + + // restricted functions: assignment and copy construction + xml_parser(xml_parser&); + xml_parser& operator= ( + xml_parser& + ); + + }; + + inline void swap ( + xml_parser& a, + xml_parser& b + ) { a.swap(b); } + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void xml_parser:: + clear( + ) + { + // unregister all event handlers + eh_list.clear(); + dh_list.clear(); + } + +// ---------------------------------------------------------------------------------------- + + void xml_parser:: + parse ( + std::istream& in + ) + { + DLIB_CASSERT ( in.fail() == false , + "\tvoid xml_parser::parse" + << "\n\tthe input stream must not be in the fail state" + << "\n\tthis: " << this + ); + + + // save which exceptions in will throw and make it so it won't throw any + // for the life of this function + std::ios::iostate old_exceptions = in.exceptions(); + // set it to not throw anything + in.exceptions(std::ios::goodbit); + + + try + { + unsigned long line_number = 1; + + // skip any whitespace before the start of the document + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n' || in.peek() == '\r' ) + { + if (in.peek() == '\n') + ++line_number; + in.get(); + } + + + + stack tags; // this stack contains the last start tag seen + bool seen_fatal_error = false; + bool seen_root_tag = false; // this is true after we have seen the root tag + + + + // notify all the document_handlers that we are about to being parsing + for (unsigned long i = 0; i < dh_list.size(); ++i) + { + dh_list[i]->start_document(); + } + + + std::string chars_buf; // used to collect chars data between consecutive + // chars and chars_cdata tokens so that + // document_handlers receive all chars data between + // tags in one call + + // variables to be used with the parsing functions + attrib_list atts; + std::string name; + std::string target; + std::string data; + + + + // variables to use with the get_next_token() function + std::string token_text; + int token_kind; + + get_next_token(in,token_text,token_kind,line_number); + + + while (token_kind != eof) + { + bool is_empty = false; // this becomes true when this token is an empty_element + + switch (token_kind) + { + + + case empty_element: is_empty = true; + // fall through + case element_start: + { + seen_root_tag = true; + + int status = parse_element(token_text,name,atts); + // if there was no error parsing the element + if (status == 0) + { + // notify all the document_handlers + for (unsigned long i = 0; i < dh_list.size(); ++i) + { + dh_list[i]->start_element(line_number,name,atts); + if (is_empty) + dh_list[i]->end_element(line_number,name); + } + } + else + { + seen_fatal_error = true; + } + + // if this is an element_start token then push the name of + // the element on to the stack + if (token_kind == element_start) + { + tags.push(name); + } + + }break; + + // ---------------------------------------- + + case element_end: + { + + int status = parse_element_end (token_text,name); + + // if there was no error parsing the element + if (status == 0) + { + // make sure this ending element tag matches the last start + // element tag we saw + if ( tags.size() == 0 || name != tags.current()) + { + // they don't match so signal a fatal error + seen_fatal_error = true; + } + else + { + // notify all the document_handlers + for (unsigned long i = 0; i < dh_list.size(); ++i) + { + dh_list[i]->end_element(line_number,name); + } + + // they match so throw away this element name + tags.pop(name); + } + } + else + { + seen_fatal_error = true; + } + + + }break; + + // ---------------------------------------- + + case pi: + { + + int status = parse_pi (token_text,target,data); + // if there was no error parsing the element + if (status == 0) + { + // notify all the document_handlers + for (unsigned long i = 0; i < dh_list.size(); ++i) + { + dh_list[i]->processing_instruction(line_number,target,data); + } + } + else + { + // notify all the error_handlers + for (unsigned long i = 0; i < eh_list.size(); ++i) + { + eh_list[i]->error(line_number); + } + } + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n' || in.peek() == '\r' ) + { + if (in.peek() == '\n') + ++line_number; + in.get(); + } + + + }break; + + // ---------------------------------------- + + case chars: + { + if (tags.size() != 0) + { + chars_buf += token_text; + } + else if (token_text.find_first_not_of(" \t\r\n") != std::string::npos) + { + // you can't have non whitespace chars data outside the root element + seen_fatal_error = true; + } + }break; + + // ---------------------------------------- + + case chars_cdata: + { + if (tags.size() != 0) + { + chars_buf += token_text; + } + else + { + // you can't have chars_data outside the root element + seen_fatal_error = true; + } + }break; + + // ---------------------------------------- + + case eof: + break; + + // ---------------------------------------- + + case error: + { + seen_fatal_error = true; + }break; + + // ---------------------------------------- + + case dtd: // fall though + case comment: // do nothing + break; + + // ---------------------------------------- + + + } + + // if there was a fatal error then quit loop + if (seen_fatal_error) + break; + + // if we have seen the last tag then quit the loop + if (tags.size() == 0 && seen_root_tag) + break; + + + get_next_token(in,token_text,token_kind,line_number); + + // if the next token is not a chars or chars_cdata token then flush + // the chars_buf to the document_handlers + if ( (token_kind != chars) && + (token_kind != chars_cdata) && + (token_kind != dtd) && + (token_kind != comment) && + (chars_buf.size() != 0) + ) + { + // notify all the document_handlers + for (unsigned long i = 0; i < dh_list.size(); ++i) + { + dh_list[i]->characters(chars_buf); + } + chars_buf.erase(); + } + + + } //while (token_kind != eof) + + + + + // you can't have any unmatched tags or any fatal erros + if (tags.size() != 0 || seen_fatal_error) + { + // notify all the error_handlers + for (unsigned long i = 0; i < eh_list.size(); ++i) + { + eh_list[i]->fatal_error(line_number); + } + + } + + + // notify all the document_handlers that we have ended parsing + for (unsigned long i = 0; i < dh_list.size(); ++i) + { + dh_list[i]->end_document(); + } + + } + catch (...) + { + // notify all the document_handlers that we have ended parsing + for (unsigned long i = 0; i < dh_list.size(); ++i) + { + dh_list[i]->end_document(); + } + + // restore the old exception settings to in + in.exceptions(old_exceptions); + + // don't forget to rethrow the exception + throw; + } + + // restore the old exception settings to in + in.exceptions(old_exceptions); + + } + +// ---------------------------------------------------------------------------------------- + + void xml_parser:: + add_document_handler ( + document_handler& item + ) + { + document_handler* temp = &item; + dh_list.add(dh_list.size(),temp); + } + +// ---------------------------------------------------------------------------------------- + + void xml_parser:: + add_error_handler ( + error_handler& item + ) + { + error_handler* temp = &item; + eh_list.add(eh_list.size(),temp); + } + +// ---------------------------------------------------------------------------------------- + + void xml_parser:: + swap ( + xml_parser& item + ) + { + dh_list.swap(item.dh_list); + eh_list.swap(item.eh_list); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void xml_parser:: + get_next_token( + std::istream& in, + std::string& token_text, + int& token_kind, + unsigned long& line_number + ) + { + + token_text.erase(); + + std::istream::int_type ch1 = in.get(); + std::istream::int_type ch2; + + + switch (ch1) + { + + // ----------------------------------------- + + // this is the start of some kind of a tag + case '<': + { + ch2 = in.get(); + switch (ch2) + { + + // --------------------------------- + + // this is a dtd, comment, or chars_cdata token + case '!': + { + // if this is a CDATA section ******************************* + if ( in.peek() == '[') + { + token_kind = chars_cdata; + + // throw away the '[' + in.get(); + + // make sure the next chars are CDATA[ + std::istream::int_type ch = in.get(); + if (ch != 'C') + token_kind = error; + ch = in.get(); + if (ch != 'D') + token_kind = error; + ch = in.get(); + if (ch != 'A') + token_kind = error; + ch = in.get(); + if (ch != 'T') + token_kind = error; + ch = in.get(); + if (ch != 'A') + token_kind = error; + ch = in.get(); + if (ch != '[') + token_kind = error; + // if this is an error token then end + if (token_kind == error) + break; + + + // get the rest of the chars and put them into token_text + int brackets_seen = 0; // this is the number of ']' chars + // we have seen in a row + bool seen_closing = false; // true if we have seen ]]> + do + { + ch = in.get(); + + if (ch == '\n') + ++line_number; + + token_text += ch; + + // if this is the closing + if (brackets_seen == 2 && ch == '>') + seen_closing = true; + // if we are seeing a bracket + else if (ch == ']') + ++brackets_seen; + // if we didn't see a bracket + else + brackets_seen = 0; + + + } while ( (!seen_closing) && (ch != EOF) ); + + // check if this is an error token + if (ch == EOF) + { + token_kind = error; + } + else + { + token_text.erase(token_text.size()-3); + } + + + + } + // this is a comment token **************************** + else if (in.peek() == '-') + { + + token_text += ch1; + token_text += ch2; + token_text += '-'; + + token_kind = comment; + + // throw away the '-' char + in.get(); + + // make sure the next char is another '-' + std::istream::int_type ch = in.get(); + if (ch != '-') + { + token_kind = error; + break; + } + + token_text += '-'; + + + // get the rest of the chars and put them into token_text + int hyphens_seen = 0; // this is the number of '-' chars + // we have seen in a row + bool seen_closing = false; // true if we have seen ]]> + do + { + ch = in.get(); + + if (ch == '\n') + ++line_number; + + token_text += ch; + + // if this should be a closing block + if (hyphens_seen == 2) + { + if (ch == '>') + seen_closing = true; + else // this isn't a closing so make it signal error + ch = EOF; + } + // if we are seeing a hyphen + else if (ch == '-') + ++hyphens_seen; + // if we didn't see a hyphen + else + hyphens_seen = 0; + + + } while ( (!seen_closing) && (ch != EOF) ); + + // check if this is an error token + if (ch == EOF) + { + token_kind = error; + } + + + + + + } + else // this is a dtd token ************************* + { + + token_text += ch1; + token_text += ch2; + int bracket_depth = 1; // this is the number of '<' chars seen + // minus the number of '>' chars seen + + std::istream::int_type ch; + do + { + ch = in.get(); + if (ch == '>') + --bracket_depth; + else if (ch == '<') + ++bracket_depth; + else if (ch == '\n') + ++line_number; + + token_text += ch; + + } while ( (bracket_depth > 0) && (ch != EOF) ); + + // make sure we didn't just hit EOF + if (bracket_depth == 0) + { + token_kind = dtd; + } + else + { + token_kind = error; + } + } + } + break; + + // --------------------------------- + + // this is a pi token + case '?': + { + token_text += ch1; + token_text += ch2; + std::istream::int_type ch; + + do + { + ch = in.get(); + token_text += ch; + if (ch == '\n') + ++line_number; + // else if we hit a < then thats an error + else if (ch == '<') + ch = EOF; + } while (ch != '>' && ch != EOF); + // if we hit the end of the pi + if (ch == '>') + { + // make sure there was a trailing '?' + if ( (token_text.size() > 3) && + (token_text[token_text.size()-2] != '?') + ) + { + token_kind = error; + } + else + { + token_kind = pi; + } + } + // if we hit EOF unexpectidely then error + else + { + token_kind = error; + } + } + break; + + // --------------------------------- + + // this is an error token + case EOF: + { + token_kind = error; + } + break; + + // --------------------------------- + // this is an element_end token + case '/': + { + token_kind = element_end; + token_text += ch1; + token_text += ch2; + std::istream::int_type ch; + do + { + ch = in.get(); + if (ch == '\n') + ++line_number; + // else if we hit a < then thats an error + else if (ch == '<') + ch = EOF; + token_text += ch; + } while ( (ch != '>') && (ch != EOF)); + + // check if this is an error token + if (ch == EOF) + { + token_kind = error; + } + } + break; + + + // --------------------------------- + + // this is an element_start or empty_element token + default: + { + + token_text += ch1; + token_text += ch2; + std::istream::int_type ch = '\0'; + std::istream::int_type last; + do + { + last = ch; + ch = in.get(); + if (ch == '\n') + ++line_number; + // else if we hit a < then thats an error + else if (ch == '<') + ch = EOF; + token_text += ch; + } while ( (ch != '>') && (ch != EOF)); + + // check if this is an error token + if (ch == EOF) + { + token_kind = error; + } + // if this is an empty_element + else if (last == '/') + { + token_kind = empty_element; + } + else + { + token_kind = element_start; + } + + + } + break; + + // --------------------------------- + + } + + } + break; + + // ----------------------------------------- + + // this is an eof token + case EOF: + { + token_kind = eof; + } + break; + + // ----------------------------------------- + + // this is a chars token + default: + { + if (ch1 == '\n') + { + ++line_number; + token_text += ch1; + } + // if the first thing in this chars token is an entity reference + else if (ch1 == '&') + { + + int temp = change_entity(in); + if (temp == -1) + { + token_kind = error; + break; + } + else + { + token_text += temp; + } + } + else + { + token_text += ch1; + } + + + token_kind = chars; + + std::istream::int_type ch = 0; + while (in.peek() != '<' && in.peek() != EOF) + { + ch = in.get(); + + if (ch == '\n') + ++line_number; + + // if this is one of the predefined entity references then change it + if (ch == '&') + { + int temp = change_entity(in); + if (temp == -1) + { + ch = EOF; + break; + } + else + token_text += temp; + } + else + { + token_text += ch; + } + } + + // if this is an error token + if (ch == EOF) + { + token_kind = error; + } + + } + break; + + // ----------------------------------------- + + } + + + } + + + +// ---------------------------------------------------------------------------------------- + + int xml_parser:: + parse_element ( + const std::string& token, + std::string& name, + attrib_list& atts + ) + { + name.erase(); + atts.list.clear(); + + // there must be at least one character between the <> + if (token[1] == '>') + return -1; + + std::string::size_type i; + std::istream::int_type ch = token[1]; + i = 2; + + // fill out name. the name can not contain any of the following characters + while ( (ch != '>') && + (ch != ' ') && + (ch != '=') && + (ch != '/') && + (ch != '\t') && + (ch != '\r') && + (ch != '\n') + ) + { + name += ch; + ch = token[i]; + ++i; + } + + // skip any whitespaces + while ( ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' ) + { + ch = token[i]; + ++i; + } + + // find any attributes + while (ch != '>' && ch != '/') + { + std::string attribute_name; + std::string attribute_value; + + // fill out attribute_name + while ( (ch != '=') && + (ch != ' ') && + (ch != '\t') && + (ch != '\r') && + (ch != '\n') && + (ch != '>') + ) + { + attribute_name += ch; + ch = token[i]; + ++i; + } + + // you can't have empty attribute names + if (attribute_name.size() == 0) + return -1; + + // if we hit > too early then return error + if (ch == '>') + return -1; + + // skip any whitespaces + while (ch == ' ' || ch == '\t' || ch =='\n' || ch =='\r') + { + ch = token[i]; + ++i; + } + + // the next char should be a '=', error if it's not + if (ch != '=') + return -1; + + // get the next char + ch = token[i]; + ++i; + + // skip any whitespaces + while (ch == ' ' || ch == '\t' || ch =='\n' || ch =='\r') + { + ch = token[i]; + ++i; + } + + + // get the delimiter for the attribute value + std::istream::int_type delimiter = ch; // this should be either a ' or " character + ch = token[i]; // get the next char + ++i; + if (delimiter != '\'' && delimiter!='"') + return -1; + + + // fill out attribute_value + while ( (ch != delimiter) && + (ch != '>') + ) + { + attribute_value += ch; + ch = token[i]; + ++i; + } + + + // if there was no delimiter then this is an error + if (ch == '>') + { + return -1; + } + + // go to the next char + ch = token[i]; + ++i; + + // the next char must be either a '>' or '/' (denoting the end of the tag) + // or a white space character + if (ch != '>' && ch != ' ' && ch != '/' && ch != '\t' && ch !='\n' && ch !='\r') + return -1; + + // skip any whitespaces + while (ch == ' ' || ch == '\t' || ch =='\n' || ch =='\r') + { + ch = token[i]; + ++i; + } + + + // add attribute_value and attribute_name to atts + if (atts.list.is_in_domain(attribute_name)) + { + // attributes may not be multiply defined + return -1; + } + else + { + atts.list.add(attribute_name,attribute_value); + } + + + } + + // you can't have an element with no name + if (name.size() == 0) + return -1; + + return 0; + + } + +// ---------------------------------------------------------------------------------------- + + int xml_parser:: + parse_pi ( + const std::string& token, + std::string& target, + std::string& data + ) + { + target.erase(); + data.erase(); + + std::istream::int_type ch = token[2]; + std::string::size_type i = 3; + while (ch != ' ' && ch != '?' && ch != '\t' && ch != '\n' && ch!='\r') + { + target += ch; + ch = token[i]; + ++i; + } + if (target.size() == 0) + return -1; + + // if we aren't at a ? character then go to the next character + if (ch != '?' ) + { + ch = token[i]; + ++i; + } + + // if we still aren't at the end of the processing instruction then + // set this stuff in the data section + while (ch != '?') + { + data += ch; + ch = token[i]; + ++i; + } + + return 0; + } + +// ---------------------------------------------------------------------------------------- + + int xml_parser:: + parse_element_end ( + const std::string& token, + std::string& name + ) + { + name.erase(); + std::string::size_type end = token.size()-1; + for (std::string::size_type i = 2; i < end; ++i) + { + if (token[i] == ' ' || token[i] == '\t' || token[i] == '\n'|| token[i] == '\r') + break; + name += token[i]; + } + + if (name.size() == 0) + return -1; + + return 0; + } + +// ---------------------------------------------------------------------------------------- + + int xml_parser:: + change_entity ( + std::istream& in + ) + { + + std::istream::int_type buf[6]; + + + buf[1] = in.get(); + + // if this is an undefined entity reference then return error + if (buf[1] != 'a' && + buf[1] != 'l' && + buf[1] != 'g' && + buf[1] != 'q' + ) + return -1; + + + buf[2] = in.get(); + // if this is an undefined entity reference then return error + if (buf[2] != 'm' && + buf[2] != 't' && + buf[2] != 'p' && + buf[2] != 'u' + ) + return -1; + + + buf[3] = in.get(); + // if this is an undefined entity reference then return error + if (buf[3] != 'p' && + buf[3] != ';' && + buf[3] != 'o' + ) + return -1; + + // check if this is < or > + if (buf[3] == ';') + { + if (buf[2] != 't') + return -1; + + // if this is < then return '<' + if (buf[1] == 'l') + return '<'; + // if this is > then return '>' + if (buf[1] == 'g') + return '>'; + + // it is neither so it must be an undefined entity reference + return -1; + } + + + buf[4] = in.get(); + // if this should be & + if (buf[4] == ';') + { + // if this is not & then return error + if (buf[1] != 'a' || + buf[2] != 'm' || + buf[3] != 'p' + ) + return -1; + + return '&'; + } + + buf[5] = in.get(); + + // if this should be ' + if (buf[1] == 'a' && + buf[2] == 'p' && + buf[3] == 'o' && + buf[4] == 's' && + buf[5] == ';' + ) + return '\''; + + + // if this should be " + if (buf[1] == 'q' && + buf[2] == 'u' && + buf[3] == 'o' && + buf[4] == 't' && + buf[5] == ';' + ) + return '"'; + + + // it was an undefined entity reference + return -1; + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class xml_parse_error : public error + { + public: + xml_parse_error( + const std::string& a + ): error(a) {} + }; + + namespace impl + { + class default_xml_error_handler : public error_handler + { + std::string filename; + + public: + + default_xml_error_handler ( + ) {} + + default_xml_error_handler ( + const std::string& filename_ + ) :filename(filename_) {} + + virtual void error ( + const unsigned long + ) + { + // just ignore non-fatal errors + } + + virtual void fatal_error ( + const unsigned long line_number + ) + { + std::ostringstream sout; + if (filename.size() != 0) + sout << "There is a fatal error on line " << line_number << " in the XML file '"< +#include +#include "xml_parser_kernel_interfaces.h" + +namespace dlib +{ + + class xml_parser + { + + /*! + INITIAL VALUE + no objects are registered to receive events + + + WHAT THIS OBJECT REPRESENTS + This object represents a simple SAX style event driven XML parser. + It takes its input from an input stream object and sends events to all + registered document_handler and error_handler objects. + + note that this xml parser ignores all DTD related XML markup. It will + parse XML documents with DTD's but it just won't check if the document + is valid. This also means that entity references may not be used except + for the predefined ones which are as follows: + & + < + > + ' + " + + also note that there is no interpreting of entity references inside + a CDATA section or inside of tags, they are only interpreted inside + normal non-markup data. + + This parser considers the end of the xml document to be the closing + tag of the root tag (as opposed to using EOF as the end of the + document). This is a deviation from the xml standard. + + Aside from ignoring DTD stuff and entity references everywhere but + data, and the above comment regarding EOF, this parser should conform + to the rest of the XML standard. + !*/ + + public: + + + xml_parser( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~xml_parser( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void parse ( + std::istream& in + ); + /*! + requires + - in.fail() == false + ensures + - the data from the input stream in will be parsed and the appropriate + events will be generated + - parsing will stop when the parser has reached the closing tag + for the xml document or EOF (which ever comes first). Note that + hitting EOF first is a fatal error. + throws + - std::bad_alloc + if parse() throws then it will be unusable until clear() is + called and succeeds + - other exceptions + document_handlers and error_handlers my throw any exception. If + they throw while parse() is running then parse() will let the + exception propagate out and the xml_parser object will be unusable + until clear() is called and succeeds. note that end_document() + is still called. + !*/ + + void add_document_handler ( + document_handler& item + ); + /*! + ensures + - item will now receive document events from the parser + throws + - std::bad_alloc + if add_document_handler() throws then it has no effect + !*/ + + void add_error_handler ( + error_handler& item + ); + /*! + ensures + - item will now receive error events from the parser + throws + - std::bad_alloc + if add_error_handler() throws then it has no effect + !*/ + + + void swap ( + xml_parser& item + ); + /*! + ensures + - swaps *this and item + !*/ + + + private: + + // restricted functions + xml_parser(xml_parser&); // copy constructor + xml_parser& operator=(xml_parser&); // assignment operator + + }; + + + inline void swap ( + xml_parser& a, + xml_parser& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class xml_parse_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception object thrown by the parse_xml() routines defined + below. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + void parse_xml ( + std::istream& in, + document_handler& dh, + error_handler& eh + ); + /*! + ensures + - makes an xml_parser and tells it to parse the given input stream using the + supplied document_handler and error_handler. + !*/ + + void parse_xml ( + std::istream& in, + error_handler& eh, + document_handler& dh + ) + /*! + ensures + - makes an xml_parser and tells it to parse the given input stream using the + supplied document_handler and error_handler. + !*/ + + void parse_xml ( + std::istream& in, + error_handler& eh + ); + /*! + ensures + - makes an xml_parser and tells it to parse the given input stream using the + supplied error_handler. + !*/ + + void parse_xml ( + std::istream& in, + document_handler& dh + ); + /*! + ensures + - makes an xml_parser and tells it to parse the given input stream using the + supplied document_handler. + - Uses a default error handler that will throw an xml_parse_error exception + if a fatal parsing error is encountered. + throws + - xml_parse_error + Thrown if a fatal parsing error is encountered. + !*/ + +// ---------------------------------------------------------------------------------------- + + void parse_xml ( + const std::string& filename, + document_handler& dh, + error_handler& eh + ); + /*! + ensures + - makes an xml_parser and tells it to parse the given input file using the + supplied error_handler and document_handler. + throws + - xml_parse_error + Thrown if there is a problem parsing the input file. + !*/ + + void parse_xml ( + const std::string& filename, + error_handler& eh, + document_handler& dh + ); + /*! + ensures + - makes an xml_parser and tells it to parse the given input file using the + supplied error_handler and document_handler. + throws + - xml_parse_error + Thrown if there is a problem parsing the input file. + !*/ + + void parse_xml ( + const std::string& filename, + error_handler& eh + ); + /*! + ensures + - makes an xml_parser and tells it to parse the given input file using the + supplied error_handler. + throws + - xml_parse_error + Thrown if there is a problem parsing the input file. + !*/ + + void parse_xml ( + const std::string& filename, + document_handler& dh + ); + /*! + ensures + - makes an xml_parser and tells it to parse the given input file using the + supplied document_handler. + - Uses a default error handler that will throw an xml_parse_error exception + if a fatal parsing error is encountered. + throws + - xml_parse_error + Thrown if there is a problem parsing the input file. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_XML_PARSER_KERNEl_ABSTRACT_ + diff --git a/ml/dlib/dlib/xml_parser/xml_parser_kernel_interfaces.h b/ml/dlib/dlib/xml_parser/xml_parser_kernel_interfaces.h new file mode 100644 index 000000000..a0edf3317 --- /dev/null +++ b/ml/dlib/dlib/xml_parser/xml_parser_kernel_interfaces.h @@ -0,0 +1,244 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_XML_PARSER_KERNEl_INTERFACES_ +#define DLIB_XML_PARSER_KERNEl_INTERFACES_ + +#include +#include "../interfaces/enumerable.h" +#include "../interfaces/map_pair.h" +#include "../error.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class xml_attribute_list_error : public dlib::error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an exception object thrown by attribute_list objects if you try to + access a non-existent attribute. + !*/ + public: + xml_attribute_list_error(const std::string& msg) : dlib::error(msg){} + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class attribute_list : public enumerable > + { + + /*! + WHAT THIS OBJECT REPRESENTS + this object represents a list of the attributes found in + an XML element. each attribute is associated with a value. + !*/ + + + public: + + inline virtual ~attribute_list ( + ) =0; + + + virtual bool is_in_list ( + const std::string& key + ) const =0; + /*! + ensures + - returns true if there is an attribute named key in the list + - returns false + !*/ + + virtual const std::string& operator[] ( + const std::string& key + ) const =0; + /*! + ensures + if (is_in_list(key) == true) then + - returns a const reference to the value associated with the attribute + named key. + - else + - throws xml_attribute_list_error + !*/ + + protected: + + // restricted functions + attribute_list& operator=(attribute_list&) {return *this;} + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class document_handler + { + /*! + EXCEPTIONS + a document_handler is allowed to throw any exception + + + WHAT THIS OBJECT REPRESENTS + this object is an interface for handling the basic events + generated by an XML parser + !*/ + + + public: + + inline virtual ~document_handler ( + ) =0; + + virtual void start_document ( + )=0; + /*! + requires + - is called when the document parsing begins + !*/ + + virtual void end_document ( + )=0; + /*! + requires + - is called after the document parsing has ended. note that this + is always called, even if an error occurs. + !*/ + + virtual void start_element ( + const unsigned long line_number, + const std::string& name, + const dlib::attribute_list& atts + )=0; + /*! + requires + - is called when an opening element tag is encountered. + - line_number == the line number where the opening tag for this element + was encountered. + - name == the name of the element encountered + - atts == a list containing all the attributes in this element and their + associated values + !*/ + + virtual void end_element ( + const unsigned long line_number, + const std::string& name + )=0; + /*! + requires + - is called when a closing element tag is encountered. (note that this + includes tags such as . I.e. the previous tag would + trigger a start_element() callback as well as an end_element() callback) + - line_number == the line number where the closing tag for this + element was encountered and + - name == the name of the element encountered + !*/ + + virtual void characters ( + const std::string& data + )=0; + /*! + requires + - is called just before we encounter a start_element, end_element, or + processing_instruction tag but only if there was data between the + last and next tag. + (i.e. data will never be "") + - data == all the normal non-markup data and CDATA between the next and + last tag in the document. + !*/ + + virtual void processing_instruction ( + const unsigned long line_number, + const std::string& target, + const std::string& data + )=0; + /*! + requires + - is called when a processing instruction is encountered + - line_number == the line number where this processing instruction + was encountered + - target == the target value for this processing instruction + - data == the data value for this processing instruction + !*/ + + protected: + + // restricted functions + document_handler& operator=(document_handler&) { return *this; } + }; + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class error_handler + { + /*! + EXCEPTIONS + an error_handler is allowed to throw any exception + + + WHAT THIS OBJECT REPRESENTS + this object is an interface for handling the error/warning + events generated by an XML parser + !*/ + + public: + + inline virtual ~error_handler ( + ) =0; + + virtual void error ( + const unsigned long line_number + )=0; + /*! + requires + - is called when an error that does NOT require the parser to halt + is encountered. (i.e. somewhat minor errors in the input) + - line_number == the line number where this error was encountered + + the following events trigger an error: + an invalid processing instruction + !*/ + + virtual void fatal_error ( + const unsigned long line_number + )=0; + /*! + requires + - is called when an error that requires the parser to abort its parsing + is encountered (i.e. fatal errors in the input) + - line_number == the line number where this fatal error was encountered + + the following events trigger a fatal_error: + Everything other than the events listed above for error. + Also note that encountering an entity reference other than the + predefined ones listed in xml_parser_kernel_abstract is a fatal_error. + Hitting EOF before the closing tag for the document is also a fatal_error. + !*/ + + protected: + + // restricted functions + error_handler& operator=(error_handler&) { return *this;} + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + document_handler::~document_handler ( + ){} + attribute_list::~attribute_list ( + ){} + error_handler::~error_handler ( + ){} + +} + +#endif // DLIB_XML_PARSER_KERNEl_INTERFACES_ + diff --git a/ml/dlib/docs/.logger_revnum b/ml/dlib/docs/.logger_revnum new file mode 100644 index 000000000..bf6b783fb --- /dev/null +++ b/ml/dlib/docs/.logger_revnum @@ -0,0 +1 @@ +f9ef9feebe06 diff --git a/ml/dlib/docs/README.txt b/ml/dlib/docs/README.txt new file mode 100644 index 000000000..155043460 --- /dev/null +++ b/ml/dlib/docs/README.txt @@ -0,0 +1,72 @@ +This "package" is just a copy of the stuff I use to generate the documentation +for the dlib library. It contains a copy of the XSLT and XML I use to +generate the HTML documentation. + +The current version of these files can be obtained from the dlib GitHub +repository at: https://github.com/davisking/dlib + +======================== Overview ======================== + +I write all my documentation in XML files. If you look through the files in +the docs folder you will see each of them. There is also a stylesheet.xsl +file which contains all the XSLT I wrote to transform XML files into HTML. +Anyway, I use that stylesheet to generate the dlib documentation from those +XML files. + +There is also a stylesheet inside the docs/chm folder (htmlhelp_stylesheet.xsl) +that knows how to look at the XML files and generate the table of contents +files needed by the htmlhelp tool (the thing that makes chm help files). + +Also note that the first 80 or so lines of the stylesheet.xsl file contains +stuff specific to the dlib project and thus should be changed or removed +as appropriate if you want to reuse it for a different project. + +======================== Installing the required tools ======================== + +To begin with, the XML and XSLT is usable on any operating system, however, +all the scripts I have in the docs folder that automate everything are bash +shell scripts. I also use stuff like wine and other Linux tools and I have +only ever tested any of this in Debian. So if you want to use all the scripts +then you should probably run this stuff in Linux. But if not you can probably +hack something together :) + +There are four scripts in the docs folder. + + - testenv_rel: This script tests your environment for all the needed utilities. + Run it and it should tell you what else you need to install. + Note that the htmlify utility is something I wrote and is in + dlib's repository in the tools/htmlify folder. You should + build and install it. (go into that folder, make a subfolder + called build, then cd into build and say: "cmake ..; make; + sudo make install". You will need to install cmake if you + don't have it already) + + - makedocs: This remakes all the HTML documentation by pulling files out + of the dlib repository. If you want to use this stuff for your + own projects you will need to edit this file a bit. + + Note that this script puts its output in the docs/web and + docs/chm/docs folders. I use the chm folder for off-line + documentation while the web folder contains what goes onto + dlib.net. Both sets of HTML are generated from the same XML + files and are mostly the same. You will see and + tags inside the XML though in cases where the two + differ. + + - makerel: Runs makedocs as well as creates tar and zip files of the project. + It also runs htmlhelp in wine to generate the chm help files. + Note that you will need to run docs/chm/htmlhelp/setup_htmlhelp.sh + before it will work in wine. + + +======================== License for documentation files ======================== + +To the extent possible under law, Davis E King has waived all copyright and +related or neighboring rights to dlib documentation (XML, HTML, and XSLT files). +This work is published from United States. + +That is, I (Davis the author) don't care what you do with this. So do +whatever you want :) + + + diff --git a/ml/dlib/docs/bash_helper_functions b/ml/dlib/docs/bash_helper_functions new file mode 100755 index 000000000..66603e6d2 --- /dev/null +++ b/ml/dlib/docs/bash_helper_functions @@ -0,0 +1,30 @@ +#/bin/sh +# +# This script defines some helper functions used by other scripts in the docs +# folder. + +get_short_revision_number() +{ + RESULT=`hg log -r $1 | grep changeset | awk '{print $2}' | sed -e 's/:.*//'` +} + +get_last_modified_date() +{ + RESULT=`hg log $1 -l1 --template '{date|date}\n' | awk '{ print $2" "$3", " $5}'` +} + +get_dlib_version() +{ + cat ../dlib/CMakeLists.txt | awk '/set\(CPACK_PACKAGE_VERSION_'$1'/{ match($2,"\"(.*)\"",a); print a[1]}' +} + + +# call like: set_dlib_version MAJOR 42 +set_dlib_version() +{ + sed -i -e 's/\(set(CPACK_PACKAGE_VERSION_'$1' *"\).*\(".*\)/\1'$2'\2/' ../dlib/CMakeLists.txt +} + +MAJOR_NUM=`get_dlib_version MAJOR` +MINOR_NUM=`get_dlib_version MINOR` +PATCH_NUM=`get_dlib_version PATCH` diff --git a/ml/dlib/docs/docs/algorithms.xml b/ml/dlib/docs/docs/algorithms.xml new file mode 100644 index 000000000..82c58f94a --- /dev/null +++ b/ml/dlib/docs/docs/algorithms.xml @@ -0,0 +1,1118 @@ + + + + + Algorithms + + + + + +

    + This page documents library components that are all basically just implementations of + mathematical functions or algorithms that don't fit in any of the other pages + of the dlib documentation. So this includes things like checksums, cryptographic hashes, + sorting, etc. +

    + + + + + + + +
    + Tools + bigint + disjoint_subsets + disjoint_subsets_sized + + Quantum Computing + + quantum_register + gate + + + hsort_array + isort_array + numeric_constants + put_in_range + qsort_array + split_array + integrate_function_adapt_simp + square_root + + Set Utilities + + set_intersection_size + set_intersection + set_union + set_difference + + +
    + + + +
    + Statistics + rand + median + running_stats + running_stats_decayed + running_scalar_covariance_decayed + running_gradient + running_scalar_covariance + mean_sign_agreement + correlation + covariance + r_squared + mean_squared_error + running_covariance + running_cross_covariance + random_subset_selector + randomly_subsample + find_upper_quantile + count_steps_without_decrease_robust + count_steps_without_decrease + count_steps_without_increase + + binomial_random_vars_are_different + event_correlation + max_scoring_element + min_scoring_element + +
    + +
    + Hashing + md5 + crc32 + hash + count_bits + hamming_distance + murmur_hash3 + murmur_hash3_128bit + gaussian_random_hash + uniform_random_hash + projection_hash + create_random_projection_hash + create_max_margin_projection_hash + hash_samples + hash_similar_angles_64 + hash_similar_angles_128 + hash_similar_angles_256 + hash_similar_angles_512 +
    + +
    + Filtering + kalman_filter + rls_filter + momentum_filter + rect_filter + find_optimal_rect_filter + find_optimal_momentum_filter +
    + +
    +
    + + + + + + + + + + + hash_similar_angles_64 + dlib/lsh.h + dlib/lsh/hashes_abstract.h + + This object is a tool for computing locality sensitive hashes that give + vectors with small angles between each other similar hash values. In + particular, this object creates 64 random planes which pass though the + origin and uses them to create a 64bit hash. + + + + + + + + hash_similar_angles_128 + dlib/lsh.h + dlib/lsh/hashes_abstract.h + + This object is a tool for computing locality sensitive hashes that give + vectors with small angles between each other similar hash values. In + particular, this object creates 128 random planes which pass though the + origin and uses them to create a 128bit hash. + + + + + + + + + hash_similar_angles_256 + dlib/lsh.h + dlib/lsh/hashes_abstract.h + + This object is a tool for computing locality sensitive hashes that give + vectors with small angles between each other similar hash values. In + particular, this object creates 256 random planes which pass though the + origin and uses them to create a 256bit hash. + + + + + + + + + hash_similar_angles_512 + dlib/lsh.h + dlib/lsh/hashes_abstract.h + + This object is a tool for computing locality sensitive hashes that give + vectors with small angles between each other similar hash values. In + particular, this object creates 512 random planes which pass though the + origin and uses them to create a 512bit hash. + + + + + + + + hash_samples + dlib/graph_utils_threaded.h + dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h + + This is a simple function for hashing a bunch of vectors using a + locality sensitive hashing object such as hash_similar_angles_128. + It is also capable of running in parallel on a multi-core CPU. + + + + + + + + bigint + dlib/bigint.h + dlib/bigint/bigint_kernel_abstract.h + + This object represents an arbitrary precision unsigned integer. It's pretty simple. + It's interface is just like a normal int, you don't have to tell it how much memory + to use or anything unusual. It just goes :) + + + + + bigint_kernel_1 + dlib/bigint/bigint_kernel_1.h + + This implementation is done using an array of unsigned shorts. It is also reference counted. + For further details see the above link. Also note that kernel_2 should be + faster in almost every case so you should really just use that version of the bigint object. + + + + + kernel_1a + is a typedef for bigint_kernel_1 + + + + + + + bigint_kernel_2 + dlib/bigint/bigint_kernel_2.h + + This implementation is basically the same as kernel_1 except it uses the + Fast Fourier Transform to perform multiplications much faster. + + + + + kernel_2a + is a typedef for bigint_kernel_2 + + + + + + + + + + + + + + crc32 + dlib/crc32.h + dlib/crc32/crc32_kernel_abstract.h + + This object represents the CRC-32 algorithm for calculating checksums. + + + + + + + + gaussian_random_hash + dlib/hash.h + dlib/general_hash/random_hashing_abstract.h + + This function uses hashing to generate Gaussian distributed random values + with mean 0 and variance 1. + + + + + + + + uniform_random_hash + dlib/hash.h + dlib/general_hash/random_hashing_abstract.h + + This function uses hashing to generate uniform random values in the range [0,1). + + + + + + + + murmur_hash3 + dlib/hash.h + dlib/general_hash/murmur_hash3_abstract.h + + This function takes a block of memory and returns a 32bit hash. The + hashing algorithm used is Austin Appleby's excellent + MurmurHash3. + + + + + + + + murmur_hash3_128bit + dlib/hash.h + dlib/general_hash/murmur_hash3_abstract.h + + This function takes a block of memory and returns a 128bit hash. The + hashing algorithm used is Austin Appleby's excellent + MurmurHash3. + + + + + + + + kalman_filter + dlib/filtering.h + dlib/filtering/kalman_filter_abstract.h + + This object implements the Kalman filter, which is a tool for + recursively estimating the state of a process given measurements + related to that process. To use this tool you will have to + be familiar with the workings of the Kalman filter. An excellent + introduction can be found in the paper: +
    + An Introduction to the Kalman Filter + by Greg Welch and Gary Bishop +
    +
    +
    + + + + + momentum_filter + dlib/filtering.h + dlib/filtering/kalman_filter_abstract.h + + This object is a simple tool for filtering a single scalar value that + measures the location of a moving object that has some non-trivial + momentum. Importantly, the measurements are noisy and the object can + experience sudden unpredictable accelerations. To accomplish this + filtering we use a simple Kalman filter with a + state transition model of: +
    +
    +   position_{i+1} = position_{i} + velocity_{i} 
    +   velocity_{i+1} = velocity_{i} + some_unpredictable_acceleration
    +
    +
    + + and a measurement model of: +
    +
    +   measured_position_{i} = position_{i} + measurement_noise
    +
    +
    + + Where some_unpredictable_acceleration and measurement_noise are 0 mean Gaussian + noise sources. + + To allow for really sudden and large but infrequent accelerations, at each + step we check if the current measured position deviates from the predicted + filtered position by more than a user specified amount, + and if so we adjust the filter's state to keep it within these bounds. + This allows the moving object to undergo large unmodeled accelerations, far + in excess of what would be suggested by the basic Kalman filter's noise model, without + then experiencing a long lag time where the Kalman filter has to "catch + up" to the new position. +
    +
    + + + + + rect_filter + dlib/filtering.h + dlib/filtering/kalman_filter_abstract.h + + This object is just a momentum_filter applied to the + four corners of a rectangle. It allows + you to filter a stream of rectangles, for instance, bounding boxes from an object detector + applied to a video stream. + + + + + + + find_optimal_momentum_filter + dlib/filtering.h + dlib/filtering/kalman_filter_abstract.h + + This function finds the "optimal" settings of a momentum_filter + based on unfiltered measurement data. + + + + + + + find_optimal_rect_filter + dlib/filtering.h + dlib/filtering/kalman_filter_abstract.h + + This function finds the "optimal" settings of a rect_filter + based on unfiltered measurement data. + + + + + + + rls_filter + dlib/filtering.h + dlib/filtering/rls_filter_abstract.h + + This object is a tool for doing time series prediction using + linear recursive least squares. In particular, + this object takes a sequence of points from the user and, at each + step, attempts to predict the value of the next point. + + + + + + + + projection_hash + dlib/lsh.h + dlib/lsh/projection_hash_abstract.h + + This is a tool for hashing elements of a vector space into the integers. + It is intended to represent locality sensitive hashing functions such as + the popular random projection hashing method. + + + + + + + + create_random_projection_hash + dlib/lsh.h + dlib/lsh/create_random_projection_hash_abstract.h + + Creates a random projection based locality sensitive + hashing function. The projection matrix + is generated by sampling its elements from a Gaussian random number generator. + + + + + + + + create_max_margin_projection_hash + dlib/lsh.h + dlib/lsh/create_random_projection_hash_abstract.h + + Creates a random projection based locality sensitive + hashing function. + This is accomplished using a variation on the random hyperplane generation + technique from the paper: +
    + Random Maximum Margin Hashing by Alexis Joly and Olivier Buisson +
    + In particular, we use a linear support vector machine to generate planes. + We train it on randomly selected and randomly labeled points from + the data to be hashed. +
    + +
    + + + + + hash + dlib/hash.h + dlib/general_hash/hash_abstract.h + + This is a set of convenience functions for invoking murmur_hash3 + on std::strings, std::vectors, std::maps, or dlib::matrix objects. +

    + As an aside, the hash() for matrix objects is defined here. + It has the same interface as all the others. +

    +
    + +
    + + + + + count_bits + dlib/hash.h + dlib/general_hash/count_bits_abstract.h + + This function counts the number of bits in an unsigned integer which are + set to 1. + + + + + + + + hamming_distance + dlib/hash.h + dlib/general_hash/count_bits_abstract.h + + This function returns the hamming distance between two unsigned integers. + That is, it returns the number of bits which differer in the two integers. + + + + + + + + rand + dlib/rand.h + dlib/rand/rand_kernel_abstract.h + + This object represents a pseudorandom number generator. + + + + + + + + disjoint_subsets + dlib/disjoint_subsets.h + dlib/disjoint_subsets/disjoint_subsets_abstract.h + + This object represents a set of integers which is partitioned into + a number of disjoint subsets. It supports the two fundamental operations + of finding which subset a particular integer belongs to as well as + merging subsets. + + + + + + + + + disjoint_subsets_sized + dlib/disjoint_subsets.h + dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h + + This object is just like disjoint_subsets except that it + also keeps track of the size of each set. + + + + + + + + + running_stats + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This object represents something that can compute the running mean, + variance, skewness, and kurtosis statistics of a stream of real numbers. + + + + running_stats_ex.cpp.html + kcentroid_ex.cpp.html + + + + + + + + + running_stats_decayed + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This object represents something that can compute the running mean and + variance of a stream of real numbers. It is similar to running_stats + except that it forgets about data it has seen after a certain period of + time. It does this by exponentially decaying old statistics. + + + + + + + + running_scalar_covariance_decayed + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This object represents something that can compute the running covariance of + a stream of real number pairs. It is essentially the same as + running_scalar_covariance except that it forgets about data it has seen + after a certain period of time. It does this by exponentially decaying old + statistics. + + + + + + + + running_gradient + dlib/statistics/running_gradient.h + dlib/statistics/running_gradient_abstract.h + + This object is a tool for estimating if a noisy sequence of numbers is + trending up or down and by how much. It does this by finding the least + squares fit of a line to the data and then allows you to perform a + statistical test on the slope of that line. + + + + + + + find_upper_quantile + dlib/statistics/running_gradient.h + dlib/statistics/running_gradient_abstract.h + + Finds and returns the scalar value such that a user specified percentage of + the values in a container are greater than said value. For example, 0.5 + would find the median value in a container while 0.1 would find the value + that lower bounded the 10% largest values in a container. + + + + + + + count_steps_without_increase + dlib/statistics/running_gradient.h + dlib/statistics/running_gradient_abstract.h + + Given a potentially noisy time series, this function returns a count of how + long the time series has gone without noticeably increasing in value. It does + this by adding the elements of the time series into a running_gradient object and counting how many + elements, starting with the most recent, you need to examine before you + are confident that the series has been increasing in value. + + + + + + + binomial_random_vars_are_different + dlib/statistics/statistic.h + dlib/statistics/statistics_abstract.h + + This function performs a simple statistical test to check if two binomially + distributed random variables have the same parameter (i.e. the chance of + "success"). It uses the simple likelihood ratio test discussed in + the following paper: +
    + Dunning, Ted. "Accurate methods for the statistics of surprise and + coincidence." Computational linguistics 19.1 (1993): 61-74. +
    + So for an extended discussion of the method see the above paper. +
    +
    + + + + + event_correlation + dlib/statistics/statistic.h + dlib/statistics/statistics_abstract.h + + This function does a statistical test to determine if two events co-occur in a + statistically significant way. It uses the simple likelihood ratio + test discussed in the following paper: +
    + Dunning, Ted. "Accurate methods for the statistics of surprise and + coincidence." Computational linguistics 19.1 (1993): 61-74. +
    + So for an extended discussion of the method see the above paper. +
    +
    + + + + + max_scoring_element + dlib/algs.h + dlib/algs.h + + This function finds the element of container that has the largest score, + according to a user supplied score function, and returns a std::pair containing + that maximal element along with the score. + + + + + + + min_scoring_element + dlib/algs.h + dlib/algs.h + + This function finds the element of container that has the smallest score, + according to a user supplied score function, and returns a std::pair containing + that minimal element along with the score. + + + + + + + count_steps_without_decrease + dlib/statistics/running_gradient.h + dlib/statistics/running_gradient_abstract.h + + Given a potentially noisy time series, this function returns a count of how + long the time series has gone without noticeably decreasing in value. It does + this by adding the elements of the time series into a running_gradient object and counting how many + elements, starting with the most recent, you need to examine before you + are confident that the series has been decreasing in value. + + + + + + + count_steps_without_decrease_robust + dlib/statistics/running_gradient.h + dlib/statistics/running_gradient_abstract.h + + This function behaves just like count_steps_without_decrease except + that it ignores times series values that are anomalously large. This makes it + robust to sudden noisy but transient spikes in the time series values. + + + + + + + running_covariance + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This object is a simple tool for computing the mean and + covariance of a sequence of vectors. + + + + + + + + running_cross_covariance + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This object is a simple tool for computing the mean and + cross-covariance matrices of a sequence of pairs of vectors. + + + + + + + + running_scalar_covariance + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This object is a simple tool for computing the covariance of a + sequence of scalar values. + + + + + + + mean_sign_agreement + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This is a function for computing the probability that + matching elements of two std::vectors have the same sign. + + + + + + + correlation + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This is a function for computing the correlation between + matching elements of two std::vectors. + + + + + + + covariance + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This is a function for computing the covariance between + matching elements of two std::vectors. + + + + + + + r_squared + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This is a function for computing the R squared coefficient between + matching elements of two std::vectors. + + + + + + + mean_squared_error + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This is a function for computing the mean squared error between + matching elements of two std::vectors. + + + + + + + random_subset_selector + dlib/statistics.h + dlib/statistics/random_subset_selector_abstract.h + + This object is a tool to help you select a random subset of a large body of data. + In particular, it is useful when the body of data is too large to fit into memory. + + + + + + + + + + randomly_subsample + dlib/statistics.h + dlib/statistics/random_subset_selector_abstract.h + + This is a set of convenience functions for + creating random subsets of data. + + + + + + + + hsort_array + dlib/sort.h + dlib/sort.h + + hsort_array is an implementation of the heapsort algorithm. It will sort anything that has an + array like operator[] interface. + + + + + + + + put_in_range + dlib/algs.h + dlib/algs.h + + This is a simple function that takes a range and a value and returns the given + value if it is within the range. If it isn't in the range then it returns the + end of range value that is closest. + + + + + + + + isort_array + dlib/sort.h + dlib/sort.h + + isort_array is an implementation of the insertion sort algorithm. It will sort anything that has an + array like operator[] interface. + + + + + + + + + numeric_constants + dlib/numeric_constants.h + dlib/numeric_constants.h + + This is just a header file containing definitions of common numeric constants such as pi and e. + + + + + + + qsort_array + dlib/sort.h + dlib/sort.h + + qsort_array is an implementation of the QuickSort algorithm. It will sort anything that has an array like + operator[] interface. If the quick sort becomes unstable then it switches to a heap sort. This + way sorting is guaranteed to take at most N*log(N) time. + + + + + + + + + split_array + dlib/array.h + dlib/array/array_tools_abstract.h + + This function is used to efficiently split array + like objects into two parts. It uses the global swap() function instead + of copying to move elements around, so it works on arrays of non-copyable + types. + + + + + + + integrate_function_adapt_simp + dlib/numerical_integration.h + dlib/numerical_integration/integrate_function_adapt_simpson_abstract.h + + Computes an approximation of the integral of a real valued function using the + adaptive Simpson method outlined in +
    + Gander, W. and W. Gautshi, "Adaptive + Quadrature -- Revisited" BIT, Vol. 40, (2000), pp.84-101 +
    +
    + + integrate_function_adapt_simp_ex.cpp.html + + +
    + + + + + + md5 + dlib/md5.h + dlib/md5/md5_kernel_abstract.h + + This is an implementation of The MD5 Message-Digest Algorithm as described in rfc1321. + + + + + + + + + + median + dlib/algs.h + dlib/algs.h + + This function takes three parameters and finds the median of the three. The median is swapped into + the first parameter and the first parameter ends up in one of the other two, unless the first parameter was + the median to begin with of course. + + + + + + + + square_root + dlib/algs.h + dlib/algs.h + + square_root is a function which takes an unsigned long and returns the square root of it or + if the root is not an integer then it is rounded up to the next integer. + + + + + + + + set_intersection + dlib/set_utils.h + dlib/set_utils/set_utils_abstract.h + + This function takes two set objects and + gives you their intersection. + + + + + + + + set_union + dlib/set_utils.h + dlib/set_utils/set_utils_abstract.h + + This function takes two set objects and + gives you their union. + + + + + + + + set_difference + dlib/set_utils.h + dlib/set_utils/set_utils_abstract.h + + This function takes two set objects and + gives you their difference. + + + + + + + + set_intersection_size + dlib/set_utils.h + dlib/set_utils/set_utils_abstract.h + + This function takes two set objects and tells you + how many items they have in common. + + + + + + + + quantum_register + dlib/quantum_computing.h + dlib/quantum_computing/quantum_computing_abstract.h + + This object represents a set of quantum bits. It can be used + with the quantum gate object to simulate + quantum algorithms. + + + + quantum_computing_ex.cpp.html + + + + + + + + gate + dlib/quantum_computing.h + dlib/quantum_computing/quantum_computing_abstract.h + + This object represents a quantum gate that operates on a + quantum_register. + + + quantum_computing_ex.cpp.html + + + + + + +
    + + + + + + diff --git a/ml/dlib/docs/docs/api.xml b/ml/dlib/docs/docs/api.xml new file mode 100644 index 000000000..2d6811b44 --- /dev/null +++ b/ml/dlib/docs/docs/api.xml @@ -0,0 +1,1289 @@ + + + + + API Wrappers + + + + + +

    + + These wrappers provide a portable object oriented interface for networking, multithreading, + GUI development, and file browsing. + Programs written using them can be compiled under POSIX or MS Windows platforms without changing the code. +

    + + + + + + + + + +
    + API + + + + + gui_widgets + + + widgets + + + draggable + dlib/gui_widgets/base_widgets_abstract.h.html#draggable + + + tooltip + dlib/gui_widgets/base_widgets_abstract.h.html#tooltip + + + popup_menu_region + dlib/gui_widgets/base_widgets_abstract.h.html#popup_menu_region + + + button_action + dlib/gui_widgets/base_widgets_abstract.h.html#button_action + + + scrollable_region + dlib/gui_widgets/base_widgets_abstract.h.html#scrollable_region + + + zoomable_region + dlib/gui_widgets/base_widgets_abstract.h.html#zoomable_region + + + mouse_over_event + dlib/gui_widgets/base_widgets_abstract.h.html#mouse_over_event + + + scroll_bar + dlib/gui_widgets/base_widgets_abstract.h.html#scroll_bar + + + widget_group + dlib/gui_widgets/base_widgets_abstract.h.html#widget_group + + + image_widget + dlib/gui_widgets/base_widgets_abstract.h.html#image_widget + + + popup_menu + dlib/gui_widgets/base_widgets_abstract.h.html#popup_menu + + + menu_item + dlib/gui_widgets/base_widgets_abstract.h.html#menu_item + + + menu_item_text + dlib/gui_widgets/base_widgets_abstract.h.html#menu_item_text + + + menu_item_separator + dlib/gui_widgets/base_widgets_abstract.h.html#menu_item_separator + + + menu_item_submenu + dlib/gui_widgets/base_widgets_abstract.h.html#menu_item_submenu + + + named_rectangle + dlib/gui_widgets/widgets_abstract.h.html#named_rectangle + + + menu_bar + dlib/gui_widgets/widgets_abstract.h.html#menu_bar + + + perspective_window + dlib/gui_widgets/widgets_abstract.h.html#perspective_window + + + perspective_display + dlib/gui_widgets/widgets_abstract.h.html#perspective_display + + + image_window + dlib/gui_widgets/widgets_abstract.h.html#image_window + + + image_display + dlib/gui_widgets/widgets_abstract.h.html#image_display + + + message_box + dlib/gui_widgets/widgets_abstract.h.html#message_box + + + message_box_blocking + dlib/gui_widgets/widgets_abstract.h.html#message_box_blocking + + + open_file_box + dlib/gui_widgets/widgets_abstract.h.html#open_file_box + + + open_existing_file_box + dlib/gui_widgets/widgets_abstract.h.html#open_existing_file_box + + + save_file_box + dlib/gui_widgets/widgets_abstract.h.html#save_file_box + + + label + dlib/gui_widgets/widgets_abstract.h.html#label + + + button + dlib/gui_widgets/base_widgets_abstract.h.html#button + + + toggle_button + dlib/gui_widgets/widgets_abstract.h.html#toggle_button + + + text_grid + dlib/gui_widgets/widgets_abstract.h.html#text_grid + + + directed_graph_drawer + dlib/gui_widgets/widgets_abstract.h.html#directed_graph_drawer + + + list_box + dlib/gui_widgets/widgets_abstract.h.html#list_box + + + check_box + dlib/gui_widgets/widgets_abstract.h.html#check_box + + + radio_button + dlib/gui_widgets/widgets_abstract.h.html#radio_button + + + text_field + dlib/gui_widgets/widgets_abstract.h.html#text_field + + + text_box + dlib/gui_widgets/widgets_abstract.h.html#text_box + + + tabbed_display + dlib/gui_widgets/widgets_abstract.h.html#tabbed_display + + + mouse_tracker + dlib/gui_widgets/widgets_abstract.h.html#mouse_tracker + + + + + styles + + + button_style + dlib/gui_widgets/style_abstract.h.html#button_style + + + button_style_default + dlib/gui_widgets/style_abstract.h.html#button_style_default + + + button_style_arrow + dlib/gui_widgets/style_abstract.h.html#button_style_arrow + + + button_style_toolbar1 + dlib/gui_widgets/style_abstract.h.html#button_style_toolbar1 + + + button_style_toolbar_icon1 + dlib/gui_widgets/style_abstract.h.html#button_style_toolbar_icon1 + + + + + toggle_button_style + dlib/gui_widgets/style_abstract.h.html#toggle_button_style + + + toggle_button_style_default + dlib/gui_widgets/style_abstract.h.html#toggle_button_style_default + + + toggle_button_style_check_box + dlib/gui_widgets/style_abstract.h.html#toggle_button_style_check_box + + + toggle_button_style_radio_button + dlib/gui_widgets/style_abstract.h.html#toggle_button_style_radio_button + + + + + scroll_bar_style + dlib/gui_widgets/style_abstract.h.html#scroll_bar_style + + + scroll_bar_style_default + dlib/gui_widgets/style_abstract.h.html#scroll_bar_style_default + + + + + scrollable_region_style + dlib/gui_widgets/style_abstract.h.html#scrollable_region_style + + + scrollable_region_style_default + dlib/gui_widgets/style_abstract.h.html#scrollable_region_style_default + + + + + list_box_style + dlib/gui_widgets/style_abstract.h.html#list_box_style + + + list_box_style_default + dlib/gui_widgets/style_abstract.h.html#list_box_style_default + + + + + text_field_style + dlib/gui_widgets/style_abstract.h.html#text_field_style + + + text_field_style_default + dlib/gui_widgets/style_abstract.h.html#text_field_style_default + + + + + text_box_style + dlib/gui_widgets/style_abstract.h.html#text_box_style + + + text_box_style_default + dlib/gui_widgets/style_abstract.h.html#text_box_style_default + + + + + canvas drawing functions + + + draw_line + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_line + + + draw_rectangle + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_rectangle + + + draw_circle + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_circle + + + draw_pixel + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_pixel + + + draw_solid_circle + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_solid_circle + + + draw_solid_convex_polygon + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_solid_convex_polygon + + + draw_button_down + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_button_down + + + draw_sunken_rectangle + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_sunken_rectangle + + + draw_button_up + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_button_up + + + draw_checkered + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_checkered + + + draw_image + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_image + + + fill_rect + dlib/gui_widgets/canvas_drawing_abstract.h.html#fill_rect + + + fill_rect_with_vertical_gradient + dlib/gui_widgets/canvas_drawing_abstract.h.html#fill_rect_with_vertical_gradient + + + fill_gradient_rounded + dlib/gui_widgets/canvas_drawing_abstract.h.html#fill_gradient_rounded + + + draw_rounded_rectangle + dlib/gui_widgets/canvas_drawing_abstract.h.html#draw_rounded_rectangle + + + + + drawable + dlib/gui_widgets/drawable_abstract.h.html#drawable + + + set_main_font + dlib/gui_widgets/drawable_abstract.h.html#set_main_font + + + main_font + dlib/gui_widgets/drawable_abstract.h.html#main_font + + + z_order + dlib/gui_widgets/drawable_abstract.h.html#z_order + + + next_free_user_event_number + dlib/gui_widgets/drawable_abstract.h.html#next_free_user_event_number + + + set_z_order + dlib/gui_widgets/drawable_abstract.h.html#set_z_order + + + shape + + + get_rect + dlib/gui_widgets/drawable_abstract.h.html#get_rect + + + bottom + dlib/gui_widgets/drawable_abstract.h.html#bottom + + + top + dlib/gui_widgets/drawable_abstract.h.html#top + + + left + dlib/gui_widgets/drawable_abstract.h.html#left + + + right + dlib/gui_widgets/drawable_abstract.h.html#right + + + width + dlib/gui_widgets/drawable_abstract.h.html#width + + + height + dlib/gui_widgets/drawable_abstract.h.html#height + + + + + set_pos + dlib/gui_widgets/drawable_abstract.h.html#set_pos + + + is_enabled + dlib/gui_widgets/drawable_abstract.h.html#is_enabled + + + enable + dlib/gui_widgets/drawable_abstract.h.html#enable + + + disable + dlib/gui_widgets/drawable_abstract.h.html#disable + + + is_hidden + dlib/gui_widgets/drawable_abstract.h.html#is_hidden + + + show + dlib/gui_widgets/drawable_abstract.h.html#show + + + hide + dlib/gui_widgets/drawable_abstract.h.html#hide + + + parent_window + dlib/gui_widgets/drawable_abstract.h.html#parent_window + + + enable_events + dlib/gui_widgets/drawable_abstract.h.html#enable_events + + + events_are_enabled + dlib/gui_widgets/drawable_abstract.h.html#events_are_enabled + + + disable_events + dlib/gui_widgets/drawable_abstract.h.html#disable_events + + + events + + + on_window_resized + dlib/gui_widgets/drawable_abstract.h.html#on_window_resized + + + on_window_moved + dlib/gui_widgets/drawable_abstract.h.html#on_window_moved + + + on_focus_gained + dlib/gui_widgets/drawable_abstract.h.html#on_focus_gained + + + on_focus_lost + dlib/gui_widgets/drawable_abstract.h.html#on_focus_lost + + + on_mouse_up + dlib/gui_widgets/drawable_abstract.h.html#on_mouse_up + + + on_mouse_move + dlib/gui_widgets/drawable_abstract.h.html#on_mouse_move + + + on_mouse_enter + dlib/gui_widgets/drawable_abstract.h.html#on_mouse_enter + + + on_mouse_leave + dlib/gui_widgets/drawable_abstract.h.html#on_mouse_leave + + + on_mouse_down + dlib/gui_widgets/drawable_abstract.h.html#on_mouse_down + + + on_wheel_up + dlib/gui_widgets/drawable_abstract.h.html#on_wheel_up + + + on_wheel_down + dlib/gui_widgets/drawable_abstract.h.html#on_wheel_down + + + on_keydown + dlib/gui_widgets/drawable_abstract.h.html#on_keydown + + + on_string_put + dlib/gui_widgets/drawable_abstract.h.html#on_string_put + + + on_user_event + dlib/gui_widgets/drawable_abstract.h.html#on_user_event + + + draw + dlib/gui_widgets/drawable_abstract.h.html#draw + + + + + + + drawable_window + dlib/gui_widgets/drawable_abstract.h.html#drawable_window + + + fonts + + + letter + dlib/gui_widgets/fonts_abstract.h.html#letter + + + font + dlib/gui_widgets/fonts_abstract.h.html#font + + + default_font + dlib/gui_widgets/fonts_abstract.h.html#default_font + + + get_native_font + dlib/gui_widgets/fonts_abstract.h.html#get_native_font + + + bdf_font + dlib/gui_widgets/fonts_abstract.h.html#bdf_font + + + + + + + + gui_core + + + base_window + dlib/gui_core/gui_core_kernel_abstract.h.html#base_window + + + canvas + dlib/gui_core/gui_core_kernel_abstract.h.html#canvas + + + get_from_clipboard + dlib/gui_core/gui_core_kernel_abstract.h.html#get_from_clipboard + + + put_on_clipboard + dlib/gui_core/gui_core_kernel_abstract.h.html#put_on_clipboard + + + + + + dir_nav + + + get_filesystem_roots + dlib/dir_nav/dir_nav_kernel_abstract.h.html#get_filesystem_roots + + + file + dlib/dir_nav/dir_nav_kernel_abstract.h.html#file + + + directory + dlib/dir_nav/dir_nav_kernel_abstract.h.html#directory + + + get_files_in_directory_tree + dlib/dir_nav/dir_nav_extensions_abstract.h.html#get_files_in_directory_tree + + + get_parent_directory + dlib/dir_nav/dir_nav_extensions_abstract.h.html#get_parent_directory + + + file_exists + dlib/dir_nav/dir_nav_extensions_abstract.h.html#file_exists + + + select_oldest_file + dlib/dir_nav/dir_nav_extensions_abstract.h.html#select_oldest_file + + + select_newest_file + dlib/dir_nav/dir_nav_extensions_abstract.h.html#select_newest_file + + + + + misc_api + + + sleep + dlib/misc_api/misc_api_kernel_abstract.h.html#sleep + + + get_current_dir + dlib/misc_api/misc_api_kernel_abstract.h.html#get_current_dir + + + set_current_dir + dlib/misc_api/misc_api_kernel_abstract.h.html#set_current_dir + + + locally_change_current_dir + dlib/misc_api/misc_api_kernel_abstract.h.html#locally_change_current_dir + + + create_directory + dlib/misc_api/misc_api_kernel_abstract.h.html#create_directory + + + timestamper + dlib/misc_api/misc_api_kernel_abstract.h.html#timestamper + + + + + threads + + + extensions + + thread_specific_data + create_new_thread_extension + rsignaler + rmutex + read_write_mutex + auto_mutex + auto_mutex_readonly + auto_unlock + auto_unlock_readonly + threaded_object + thread_pool + async + default_thread_pool + parallel_for + + future + dlib/threads/thread_pool_extension_abstract.h.html#future + + thread_function + multithreaded_object + + + + is_dlib_thread + dlib/threads/threads_kernel_abstract.h.html#is_dlib_thread + + + create_new_thread + dlib/threads/threads_kernel_abstract.h.html#create_new_thread + + + mutex + dlib/threads/threads_kernel_abstract.h.html#mutex + + + unregister_thread_end_handler + dlib/threads/threads_kernel_abstract.h.html#unregister_thread_end_handler + + + register_thread_end_handler + dlib/threads/threads_kernel_abstract.h.html#register_thread_end_handler + + + signaler + dlib/threads/threads_kernel_abstract.h.html#signaler + + + get_thread_id + dlib/threads/threads_kernel_abstract.h.html#get_thread_id + + + + + sockets + + + extensions + + + network_address + dlib/sockets/sockets_extensions_abstract.h.html#network_address + + + connect + dlib/sockets/sockets_extensions_abstract.h.html#connect + + + is_ip_address + dlib/sockets/sockets_extensions_abstract.h.html#is_ip_address + + + close_gracefully + dlib/sockets/sockets_extensions_abstract.h.html#close_gracefully + + + + + objects + + + connection + dlib/sockets/sockets_kernel_abstract.h.html#connection + + + listener + dlib/sockets/sockets_kernel_abstract.h.html#listener + + + + + functions + + + create_connection + dlib/sockets/sockets_kernel_abstract.h.html#create_connection + + + create_listener + dlib/sockets/sockets_kernel_abstract.h.html#create_listener + + + get_local_hostname + dlib/sockets/sockets_kernel_abstract.h.html#get_local_hostname + + + hostname_to_ip + dlib/sockets/sockets_kernel_abstract.h.html#hostname_to_ip + + + ip_to_hostname + dlib/sockets/sockets_kernel_abstract.h.html#ip_to_hostname + + + + + +
    +
    +
    + + + + + + + + + dir_nav + dlib/dir_nav.h + dlib/dir_nav/dir_nav_kernel_abstract.h + + This is a set of objects that provide an easy and portable way to traverse a directory tree. + + + + dir_nav_ex.cpp.html + + + + + dir_nav_kernel_1 + dlib/dir_nav/dir_nav_kernel_1.h + + MS Windows implementation + + + + + dir_nav_kernel_2 + dlib/dir_nav/dir_nav_kernel_2.h + + POSIX implementation + + + + + + + + dir_nav_extensions + dlib/dir_nav/dir_nav_extensions_abstract.h + + This is just some miscellaneous extensions to the dir_nav component. + + + + + + + + + + + gui_core + dlib/gui_core.h + dlib/gui_core/gui_core_kernel_abstract.h + + This is a set of objects and functions which provide a very basic + framework for manipulating windows. It is intended to provide a portable + interface which can be used to build a more complex windowing toolkit. + + + + + gui_core_kernel_1 + dlib/gui_core/gui_core_kernel_1.h + + MS Windows implementation + + + + + gui_core_kernel_2 + dlib/gui_core/gui_core_kernel_2.h + + X Windows implementation + + + + + + + + + + + + + misc_api + dlib/misc_api.h + dlib/misc_api/misc_api_kernel_abstract.h + + This is just a collection of miscellaneous APIs that were small/simple + enough not to warrant their own module. + + + + + misc_api_kernel_1 + dlib/misc_api/misc_api_kernel_1.h + + MS Windows implementation + + + + + misc_api_kernel_2 + dlib/misc_api/misc_api_kernel_2.h + + POSIX implementation + + + + + + + + + + + + + sockets + dlib/sockets.h + dlib/sockets/sockets_kernel_abstract.h + + This is a set of objects that provides an easy to use and object oriented + interface for dealing with TCP networking. There are currently two implementations, + one for UNIX and another for all versions of Windows after Windows95. + Both provide the exact same interface so programs written with them can be + recompiled on either platform without a problem. +

    + You also may want to take note of the timeout object. + It provides a mechanism which you can use to add a timeout to a network operation. +

    +
    + + + iosockstream_ex.cpp.html + sockets_ex.cpp.html + sockstreambuf_ex.cpp.html + server_http_ex.cpp.html + server_iostream_ex.cpp.html + + + + + sockets_kernel_1 + dlib/sockets/sockets_kernel_1.h + + MS Windows implementation + + + + + sockets_kernel_2 + dlib/sockets/sockets_kernel_2.h + + POSIX implementation + + + + + + + + + sockets_extensions + dlib/sockets/sockets_extensions_abstract.h + + This is just some miscellaneous extensions to the socket api. + + + + + + +
    + + + + + + + threads + dlib/threads.h + dlib/threads/threads_kernel_abstract.h + + This is a set of objects that provides an easy to use and object oriented interface + for creating multi-threaded programs. There are currently two implementations, one + for UNIX and another for any variant of MS Windows after Windows 95. Both provide + the exact same interface so programs written with them can be recompiled on either + platform without a problem. +

    + You also probably want to take note of the pipe object. + It provides an easy to use typesafe mechanism to send messages between threads. +

    +
    + + + threads_ex.cpp.html + logger_ex_2.cpp.html + pipe_ex.cpp.html + multithreaded_object_ex.cpp.html + threaded_object_ex.cpp.html + thread_function_ex.cpp.html + thread_pool_ex.cpp.html + + + + + threads_kernel_1 + dlib/threads/threads_kernel_1.h + + MS Windows implementation + + + + + threads_kernel_2 + dlib/threads/threads_kernel_2.h + + POSIX implementation + + + + + + + + + rsignaler + dlib/threads/rsignaler_extension_abstract.h + + This extension adds a signaler object that can be used with the rmutex object. + Also note that this extension is included by dlib/threads.h so you don't have to include + anything extra to get it. + + + + + + thread_specific_data + dlib/threads/thread_specific_data_extension_abstract.h + + This extension adds the ability to easily create thread specific data. + + + + + + rmutex + dlib/threads/rmutex_extension_abstract.h + + This extension adds a mutex object that can handle recursive calls + to lock(). + Also note that this extension is included by dlib/threads.h so you don't have to include + anything extra to get it. + + + + + read_write_mutex + dlib/threads/read_write_mutex_extension_abstract.h + + This extension adds a mutex object that can perform both normal "write locks" as well as "readonly locks". + See the specification for details. + Also note that this extension is included by dlib/threads.h so you don't have to include + anything extra to get it. + + + + + create_new_thread_extension + dlib/threads/create_new_thread_extension_abstract.h + + This extension adds some templated overloads to the + create_new_thread() function. They allow you to create new threads using member functions from a class. + Also note that this extension is included by dlib/threads.h so you don't have to include + anything extra to get it. + + + + + + + auto_mutex + dlib/threads/auto_mutex_extension_abstract.h + + This extension adds a mechanism to automatically lock and unlock a mutex. + Also note that this extension is included by dlib/threads.h so you don't have to include + anything extra to get it. + + + + threads_ex.cpp.html + + + + + + auto_mutex_readonly + dlib/threads/auto_mutex_extension_abstract.h + + This extension adds a mechanism to automatically perform a readonly lock and unlock + of a read_write_mutex. + Also note that this extension is included by dlib/threads.h so you don't have to include + anything extra to get it. + + + + + auto_unlock_readonly + dlib/threads/auto_unlock_extension_abstract.h + + This extension adds a mechanism to automatically remove a readonly unlock from a read_write_mutex. + Also note that this extension is included by dlib/threads.h so you don't have to include + anything extra to get it. + + + + + thread_function + dlib/threads/thread_function_extension_abstract.h + + This object represents a thread on a global C++ function. That is, it allows you + to run a global function in its own thread. + + + + thread_function_ex.cpp.html + + + + + + + threaded_object + dlib/threads/threaded_object_extension_abstract.h + + This extension represents a simple threaded object. It provides a convenient + mechanism to create an object that contains a thread. + + + + threaded_object_ex.cpp.html + + + + + + + thread_pool + dlib/threads/thread_pool_extension_abstract.h + +

    + This object represents a fixed size group of threads which you can + submit tasks to and then wait for those tasks to be completed. It also + provides a future object + that provides a container which allows you to safely pass objects into the tasks. +

    + The implementation of this extension can be found + here. It is + implemented such that no memory allocations occur after the thread pool + has been constructed so long as the user doesn't call + any of the add_task_by_value() routines. The future object also doesn't + perform any memory allocations or contain any system resources such as mutex objects. +
    + + + thread_pool_ex.cpp.html + +
    + + + default_thread_pool + dlib/threads/async_abstract.h + + This function returns a reference to a global thread_pool. If the DLIB_NUM_THREADS + environment variable is set to an integer then the thread pool will contain + DLIB_NUM_THREADS threads, otherwise it will contain + std::thread::hardware_concurrency() threads. + + + + + async + dlib/threads/async_abstract.h + + This function behaves just like std::async() + except that instead of spawning a new thread to process each task it submits + the task to a dlib::thread_pool. Therefore, dlib::async() is + guaranteed to use a bounded number of threads unlike std::async(). This also + means that calls to dlib::async() will block if there aren't any free threads + in the thread pool. + + + + + parallel_for + dlib/threads/parallel_for_extension_abstract.h + + This is a set of functions for executing the contents of a for loop in parallel. + It is useful for taking advantage of multi-processor systems. + + + parallel_for_ex.cpp.html + + + + + multithreaded_object + dlib/threads/multithreaded_object_extension_abstract.h + + This object represents a multithreaded object. It is similar to + the threaded_object except it allows you to have many threads in a + single object rather than just one. + + + + multithreaded_object_ex.cpp.html + pipe_ex.cpp.html + + + + + auto_unlock + dlib/threads/auto_unlock_extension_abstract.h + + This extension adds a mechanism to automatically unlock a mutex. + Also note that this extension is included by dlib/threads.h so you don't have to include + anything extra to get it. + + + + +
    + + + +
    + + + + + + + + gui_widgets + dlib/gui_widgets.h + +

    + This component is a collection of various windowing widgets such as buttons, + labels, text boxes, and so on. It also includes the drawable + interface, drawable_window, and font handling objects. + dlib/gui_widgets/widgets_abstract.h + defines all of the high level graphical widgets provided by this + component that can appear in a drawable_window. To view the specifications for the other members of this + component look at dlib/gui_widgets/fonts_abstract.h, + dlib/gui_widgets/drawable_abstract.h, + and dlib/gui_widgets/base_widgets_abstract.h. +

    +

    This component isn't actually a wrapper on top of OS APIs. Rather, it is + implemented on top of the gui_core + component. I put it on this page just because I expect that people would + look here when searching for the sort of functionality provided by this component. +

    + + +

    Primary widgets

    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + +
    + + + gui_api_ex.cpp.html + image_ex.cpp.html + 3d_point_cloud_ex.cpp.html + surf_ex.cpp.html + bayes_net_gui_ex.cpp.html + +
    + + + + + +
    + + + + +
    diff --git a/ml/dlib/docs/docs/bayes.xml b/ml/dlib/docs/docs/bayes.xml new file mode 100644 index 000000000..2ed11fe1d --- /dev/null +++ b/ml/dlib/docs/docs/bayes.xml @@ -0,0 +1,377 @@ + + + + + Bayesian Networks + + + + + +

    + This page documents all the tools within the dlib library that relate + to the construction and evaluation of Bayesian networks. If you want + a quick introduction to the tools then you should consult the + Bayesian Net example program. + +

    +

    + The + library also comes with a graphical application to assist in the + creation of bayesian networks. This application is one of the + example programs, so to use it + you have to compile it yourself. +

    + + + + + + + +
    + Tools + assignment + joint_probability_table + conditional_probability_table + bayes_node + bayesian_network_gibbs_sampler + bayesian_network_join_tree + +
    + +
    + Node Utilities + set_node_value + node_value + node_is_evidence + set_node_as_evidence + set_node_as_nonevidence + set_node_num_values + node_num_values + node_probability + set_node_probability + node_first_parent_assignment + node_next_parent_assignment + node_cpt_filled_out + +
    + +
    +
    + + + + + + + + + + + bayesian_network_join_tree + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This object represents an implementation of the join tree algorithm + (a.k.a. the junction tree algorithm) + for inference in bayesian networks. + + + bayes_net_ex.cpp.html + bayes_net_gui_ex.cpp.html + bayes_net_from_disk_ex.cpp.html + + + + + + + + bayesian_network_gibbs_sampler + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This object performs Markov Chain Monte Carlo sampling of a bayesian + network using the Gibbs sampling technique. + + + bayes_net_ex.cpp.html + + + + + + + + bayes_node + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This object represents a node in a bayesian network. It is + intended to be used inside the directed_graph object to + represent bayesian networks. + + + + bayes_net_ex.cpp.html + bayes_net_gui_ex.cpp.html + + + + + + + conditional_probability_table + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This object represents a conditional probability table. + + + + + + + + joint_probability_table + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This object represents a joint probability table. + + + + + + + + assignment + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This object models an assignment of random variables to particular values. + It is used with the joint_probability_table and + conditional_probability_table + objects to represent assignments of various random variables to actual values. + + + bayes_net_ex.cpp.html + bayes_net_gui_ex.cpp.html + + + + + + + + set_node_probability + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily set the probability of a + bayes_node given its parents when it is inside + a directed_graph object. + + + bayes_net_ex.cpp.html + + + + + + + + node_first_parent_assignment + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily obtain an assignment + that contains all the parents of a node in a bayesian network. + + + bayes_net_gui_ex.cpp.html + + + + + + + + node_next_parent_assignment + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily loop through all the parent assignments + of a node in a bayesian network. + + + bayes_net_gui_ex.cpp.html + + + + + + + + node_cpt_filled_out + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily verify that a node + in a bayesian network has its conditional_probability_table + completely filled out. + + + bayes_net_gui_ex.cpp.html + + + + + + + + + node_probability + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily obtain the probability of a + bayes_node given its parents when it is inside + a directed_graph object. + + + + + + + + node_num_values + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily obtain the number of values of a + bayes_node when it is inside + a directed_graph object. + + + + + + + + set_node_num_values + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily set the number of values of a + bayes_node when it is inside + a directed_graph object. + + + bayes_net_ex.cpp.html + + + + + + + + set_node_as_nonevidence + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily remove the evidence flag of a + bayes_node when it is inside + a directed_graph object. + + + bayes_net_ex.cpp.html + + + + + + + + set_node_as_evidence + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily set the evidence flag of a + bayes_node when it is inside + a directed_graph object. + + + bayes_net_ex.cpp.html + + + + + + + + node_is_evidence + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily determine if a + bayes_node is evidence when it is inside + a directed_graph object. + + + + + + + + node_value + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily obtain the value of a + bayes_node when it is inside a + directed_graph object. + + + + + + + + set_node_value + dlib/bayes_utils.h + dlib/bayes_utils/bayes_utils_abstract.h + + This is a function declared in the dlib::bayes_node_utils namespace. It + is a convenience function that allows you to easily modify the value of a + bayes_node when it is inside a + directed_graph object. + + + bayes_net_ex.cpp.html + + + + + + + + + + + +
    + + diff --git a/ml/dlib/docs/docs/bayesopt_vs_lipo.svg b/ml/dlib/docs/docs/bayesopt_vs_lipo.svg new file mode 100644 index 000000000..f88f4ab5a --- /dev/null +++ b/ml/dlib/docs/docs/bayesopt_vs_lipo.svg @@ -0,0 +1,21764 @@ + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 + + + 20 + + + 40 + + + 60 + + + 80 + + + 100 + + + 120 + + + 140 + + + 160 + + + 180 + + + 200 + + + Objective Function Calls + + + + + + + + + + + + + + + + + + + + + -25 + + + -20 + + + -15 + + + -10 + + + -5 + + + 0 + + + 5 + + + log10(error) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + BayesOpt + + + + + + + + + + + + + + + + + + + MaxLIPO+TR + + + + + + + + + + + + + + + + + + + Solving Holder Table Function: BayesOpt vs. MaxLIPO+TR + diff --git a/ml/dlib/docs/docs/bigminus.gif b/ml/dlib/docs/docs/bigminus.gif new file mode 100644 index 000000000..aea8e5c01 Binary files /dev/null and b/ml/dlib/docs/docs/bigminus.gif differ diff --git a/ml/dlib/docs/docs/bigplus.gif b/ml/dlib/docs/docs/bigplus.gif new file mode 100644 index 000000000..6bee68e21 Binary files /dev/null and b/ml/dlib/docs/docs/bigplus.gif differ diff --git a/ml/dlib/docs/docs/books.xml b/ml/dlib/docs/docs/books.xml new file mode 100644 index 000000000..cda81bd86 --- /dev/null +++ b/ml/dlib/docs/docs/books.xml @@ -0,0 +1,306 @@ + + + + + Suggested Books + + + + + + + +

    + One of the major goals of dlib is to have documentation that enables + someone to easily make use of its various components. Ideally, + you would read a short description of something, understand it immediately, + and begin using it in your application without any difficulty. Obviously, this + depends partly on the background of the user. For example, if you have + never written C++ software before then it probably isn't going to be this easy. +

    +

    + This page is meant to complement the main library documentation by providing + references to books, along with my commentary, which explain most of + the background material needed to understand the various parts of the library. + In most cases these are the books I learned from during the process + of creating dlib. As always, if you disagree with anything or think I have left out + an important text then shoot me an email. +

    +

    + + +

    General Programming

    +
      +

      C++

      +
        +
      • Programming: Principles and Practice Using C++ by Bjarne Stroustrup +
          This is the sort of book you would use in a freshman introduction-to-programming class. + So if you are just beginning to study programming and are interested in C++ then I think + it is probably safe to say this is one of the best books you could read.

        +
      • +
      • Accelerated C++: Practical Programming by Example by Andrew Koenig and Barbara E. Moo +
          If you are new to C++ but already know how to program then this is a great book. It's also + about one fourth the size of the Stroustrup book.

        +
      • +
      • Effective C++: 55 Specific Ways to Improve Your Programs and Designs (3rd Edition) by Scott Meyers +
          This is a great intermediate level C++ book. Most people have heard the jokes about + how easy it is to shoot yourself in the foot with C++. This book explains many things you + need to know about the language to avoid doing so on a regular basis. So if you are + writing C++ software then this is a must-read. I would even claim that + you are a danger to the C++ software you touch unless you know what is in this book. + I'm not kidding. Finally, the book isn't just about the quirks of C++. It also discusses many general + software engineering ideas which have wide applicability. So in this + respect it is a great book for any software developer to read. +

        +
      • +
      • More Effective C++: 35 New Ways to Improve Your Programs and Designs by Scott Meyers +
          Consider this an expansion to Effective C++. If you are going to read the above + book then you would almost certainly benefit from reading this one as well. +

        +
      • +
      • The C++ Standard Library: A Tutorial and Reference by Nicolai M. Josuttis +
          If you are going to buy a reference book on the C++ standard library then this + is the one to get. I think you + will find it is better than any of the available online references. So if you find + yourself frustrated with the online resources, then this is the book for you. +

        +
      • +
      • Online C++ Standard Library Reference +
          What I said aside, this is a good online reference. I often find myself referring to it + when I do not have the Josuttis book on hand. +

        +
      • +
      + + +

      Multithreading

      +
        +
      • Programming with POSIX Threads by David R. Butenhof +
          When I was an undergrad, this book was my main resource for learning about multithreading. + It was enjoyable to read, as are all the books on this list, and covered everything + in great depth without becoming overbearing. Also, despite what the title may suggest, + this book is useful for understanding multithreading broadly, not just multithreading + on POSIX systems. +

        +
      • +
      + +

      Network Programming

      +
        +
      • Unix Network Programming, Volume 1: The Sockets Networking API (3rd Edition) + by W. Richard Stevens +
          A lot of people call this book the network programming Bible and + this praise is well deserved. If you want a deep understanding of how computer networks + function, including the Internet, then this is the book to read. As with + the Butenhof book above, this is an excellent choice even for people who do not + intend to write software for Unix systems. +

        +
      • +
      + +

      WIN32 Programming

      + It has been a long time since I needed to refer to these two books. However, + they contained information I couldn't find elsewhere no matter + how hard I looked. So I recommend them in case you need to create or understand + some low level win32 code. +
      +
      +
        +
      • Win32 Programming by Brent E. Rector and Joseph M. Newcomer
      • +
      • Programming Windows by Charles Petzold
      • +
      • MSDN Library +
          This is Microsoft's online reference documentation. It is very large and sometimes + confusing. But at the end of the day you should be able to find the documentation + for just about every function in the entire Windows API. +

        +
      • +
      +
    + + + + + +

    Computer Science: Algorithms and Data Structures

    +
      +
    • Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein +
        You should get this book if you are looking for a good discussion of the classic computer science + algorithms and data structures (e.g. most of the components on the containers + page). +

      +
    • +
    • Algorithms in C++, Parts 1-4: Fundamentals, Data Structure, Sorting, Searching + (3rd Edition) by Robert Sedgewick +
        This is another good algorithms book. If you are going to get only one book on this + subject then get the one above. However, when I was learning about these topics I + used both these books and on many occasions I found it helpful to read the description + of an algorithm or data structure in both. Where one description was a little vague or + confusing the other generally filled in the gaps. +

      +
    • +
    + + + +

    Lossless Data Compression

    +
      +
    • Text Compression by Bell, Cleary, and Witten +
        When I was studying data compression this was my most useful + resource. If you are looking to understand how lossless data compression + algorithms work then this is the book you want. It is completely self-contained + and an absolute joy to read. Note that contrary to one of the reviews on + amazon.com, the book Managing Gigabytes is not the second edition of this book; + if this topic interests you then be sure you get the 318 page + book published in 1990. +

      +
    • +
    + + + +

    General Math

    + +
      +
    • Linear Algebra Done Right by Sheldon Jay Axler +
        If a matrix seems like an arbitrary grid of numbers or you find that + you are confused by vectors, matrices, and the various things + that get done with them then this book will change your whole view of this subject. + It doesn't teach you any algorithms. Instead, it will give you a general + framework in which to think about all this stuff. Once you have that down + everything else will start to make a lot more sense. If all goes well + you will even start to agree with the following: linear algebra is beautiful. :) +

      +
    • +
    • Numerical Linear Algebra by Trefethen and Bau +
        While Linear Algebra Done Right is fairly abstract, this book by + Trefethen and Bau will + explain some of the actual algorithms that are often used. + This is a great second book if you find that you want to know more about + the SVD, LU decomposition, or various other algorithms involving linear algebra. +

      +
    • +
    • Calculus: Single and Multivariable by Hughes-Hallett, Gleason, and McCallum +
        + Some of the books below will require and understanding of basic calculus. So + I'm recommending this book. It was the book I used as an undergrad and I + remember it being alright. That isn't exactly a glowing review so if you + are really considering buying a calculus book you may want to check out + other reviews before picking this one. +

      +
    • +
    • Introduction to Real Analysis (third edition) by Bartle and Sherbert +
        At some level real analysis is like a really rigorous repeat of calculus. + So if you already have an undergraduate education in calculus and + you are reading things that seem reminiscent of calculus but involve + stuff you haven't seen before (e.g. sup, inf, "sets of numbers", sequences of points) + then you may be in need of a real analysis book. This one is quite good and should + be accessible to someone with the usual undergraduate computer science math background. +

      +
    • +
    + + + + + +

    Optimization

    + + The subject of linear algebra is fundamental to optimization. So you must be familiar + with the contents of a book like Linear Algebra Done Right if you are going to study + this area. You will also need to know how to find the derivative of a function and + understand what a derivative is all about. So you will need to know a little bit of + calculus. Finally, once in a while you will need to know a little bit about real + analysis. Ultimately, what you need all depends on how deep you want to go. + +
      +
    • Practical Methods of Optimization (second edition) by R. Fletcher 1987 +
        I love this book. When I got it I literally spent my weekends sitting around + reading it for hours. It is a fascinating and well written introduction to + the subject of optimization. This has been my most valuable resource for + learning the fundamentals of optimization and I cannot recommend it highly enough. +

      +
    • +
    • Numerical Optimization by Jorge Nocedal and Stephen Wright 2006 +
        This is a more recent text on optimization that is also very good. It + covers many algorithms not covered by the above book. +

      +
    • +
    • Introduction to Derivative-Free Optimization by Conn, Scheinberg, and Vicente +
        If you want to understand algorithms like BOBYQA + then this is a good recent book on the subject. Note that a book like Practical Methods of Optimization + is almost certainly a prerequisite for reading this book. As an aside, BOBYQA is not discussed in this book but + its predecessor, NEWUOA is. +

      +
    • +
    + + + + +

    Machine Learning

    + +
      +
    • Artificial Intelligence: A Modern Approach (3rd Edition) by Stuart Russell and Peter Norvig +
        This book is about the much broader field of AI but it contains an excellent introduction + to machine learning and it also covers other useful topics like bayesian networks. + Moreover, it is very well written and self-contained. So you don't need any particular + background to be able to learn from it apart from a typical undergraduate background + in computer science. +

      +
    • +
    • Learning with Kernels: Support Vector Machines, Regularization, Optimization, and Beyond + by Bernhard Schlkopf and Alexander J. Smola +
        Most of the machine learning tools in dlib are implementations of various kernel methods. + So if you want a book that covers this topic in great depth as well as breadth then this is + probably the book for you. The most important prerequisite for this book is linear + algebra. Virtually everything in this book depends on linear algebra in a fundamental way. +

        + The second important subject is optimization. Whenever you see the text + mention the KKT conditions, duality, "primal variables", or quadratic programming it + is talking about ideas from optimization. A good book which will explain all this to you + is Practical Methods of Optimization. Note that this book calls the KKT conditions + just the "KT" conditions. It is talking about the same thing. Also, duality + is something that comes up a lot in optimization but in the context of machine learning + usually people are talking about a particular form known as the Wolfe Dual. +

        + It would also be good (but maybe not critical depending on which parts you want to read) to + be familiar with real analysis. +

      +
    • +
    • Kernel Methods for Pattern Analysis by John Shawe-Taylor and Nello Cristianini +
        This is another good book about kernel methods. If you have to choose between + this book and Learning with Kernels I would go with Learning with Kernels. However, it is + good to have both since reading different presentations of difficult subjects + usually makes learning them easier. +

      +
    • + +
    • Structured Prediction and Learning in Computer Vision by Sebastian Nowozin and Christoph H. Lampert 2011 +
        If you are looking for a book discussing the background material necessary + for understanding things like the Structural SVM + tools in dlib then this is a good book. It is also available online + in PDF form. +

      +
    • + +
    + +

    Image Processing

    +
      +
    • Digital Image Processing by Rafael C. Gonzalez and Richard E. Woods +
        This is a terrific introduction to digital image processing. + By and large this book doesn't require any special prerequisites. Sometimes + calculus shows up, but not too much. +

      +
    • +
    + + + + + + + + +
    + diff --git a/ml/dlib/docs/docs/boost.png b/ml/dlib/docs/docs/boost.png new file mode 100644 index 000000000..b4d51fcd5 Binary files /dev/null and b/ml/dlib/docs/docs/boost.png differ diff --git a/ml/dlib/docs/docs/change_log.xml b/ml/dlib/docs/docs/change_log.xml new file mode 100644 index 000000000..6088c869e --- /dev/null +++ b/ml/dlib/docs/docs/change_log.xml @@ -0,0 +1,11 @@ + + + + + Change Log + +
    +
    Old Change Logs
    +
    + +
    diff --git a/ml/dlib/docs/docs/chm/READ THE README. DO NOT EDIT THE TABLE OF CONTENTS FILE b/ml/dlib/docs/docs/chm/READ THE README. DO NOT EDIT THE TABLE OF CONTENTS FILE new file mode 100644 index 000000000..e69de29bb diff --git a/ml/dlib/docs/docs/chm/READ THE README. DO NOT EDIT THE TABLE OF CONTENTS FILE2 b/ml/dlib/docs/docs/chm/READ THE README. DO NOT EDIT THE TABLE OF CONTENTS FILE2 new file mode 100644 index 000000000..e69de29bb diff --git a/ml/dlib/docs/docs/chm/READ THE README. DO NOT EDIT THE TABLE OF CONTENTS FILE3 b/ml/dlib/docs/docs/chm/READ THE README. DO NOT EDIT THE TABLE OF CONTENTS FILE3 new file mode 100644 index 000000000..e69de29bb diff --git a/ml/dlib/docs/docs/chm/README.txt b/ml/dlib/docs/docs/chm/README.txt new file mode 100644 index 000000000..ce1e21ee1 --- /dev/null +++ b/ml/dlib/docs/docs/chm/README.txt @@ -0,0 +1,5 @@ +The Table of Contents.hhc file is auto generated by the toc.xml and htmlhelp_stylesheet.xsl files. +You really can edit it if you want but I suggest you use the stylesheet to auto generate it instead. + +If you want to regenerate the table of contents file you can do so with +the command "msxsl toc.xml htmlhelp_stylesheet.xsl" if you are using msxsl.exe. \ No newline at end of file diff --git a/ml/dlib/docs/docs/chm/documentation.html b/ml/dlib/docs/docs/chm/documentation.html new file mode 100644 index 000000000..2a03451c7 --- /dev/null +++ b/ml/dlib/docs/docs/chm/documentation.html @@ -0,0 +1,20 @@ + + +dlib C++ library + + + + + + + + +

    + +click here to go to the documentation +

    + + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/chm/htmlhelp/hha.dll b/ml/dlib/docs/docs/chm/htmlhelp/hha.dll new file mode 100644 index 000000000..07518f2c1 Binary files /dev/null and b/ml/dlib/docs/docs/chm/htmlhelp/hha.dll differ diff --git a/ml/dlib/docs/docs/chm/htmlhelp/hhc.exe b/ml/dlib/docs/docs/chm/htmlhelp/hhc.exe new file mode 100644 index 000000000..9a1f31de1 Binary files /dev/null and b/ml/dlib/docs/docs/chm/htmlhelp/hhc.exe differ diff --git a/ml/dlib/docs/docs/chm/htmlhelp/htmlhelp.reg b/ml/dlib/docs/docs/chm/htmlhelp/htmlhelp.reg new file mode 100644 index 000000000..3d91e08b0 --- /dev/null +++ b/ml/dlib/docs/docs/chm/htmlhelp/htmlhelp.reg @@ -0,0 +1,5 @@ +REGEDIT4 + + +[HKEY_CURRENT_USER\Software\Wine\AppDefaults\hhc.exe\DllOverrides] +"itss"="native" diff --git a/ml/dlib/docs/docs/chm/htmlhelp/itcc.dll b/ml/dlib/docs/docs/chm/htmlhelp/itcc.dll new file mode 100644 index 000000000..5e78ebb8e Binary files /dev/null and b/ml/dlib/docs/docs/chm/htmlhelp/itcc.dll differ diff --git a/ml/dlib/docs/docs/chm/htmlhelp/itircl.dll b/ml/dlib/docs/docs/chm/htmlhelp/itircl.dll new file mode 100644 index 000000000..85d1ec9ae Binary files /dev/null and b/ml/dlib/docs/docs/chm/htmlhelp/itircl.dll differ diff --git a/ml/dlib/docs/docs/chm/htmlhelp/itss.dll b/ml/dlib/docs/docs/chm/htmlhelp/itss.dll new file mode 100644 index 000000000..da3293be4 Binary files /dev/null and b/ml/dlib/docs/docs/chm/htmlhelp/itss.dll differ diff --git a/ml/dlib/docs/docs/chm/htmlhelp/setup_htmlhelp.sh b/ml/dlib/docs/docs/chm/htmlhelp/setup_htmlhelp.sh new file mode 100755 index 000000000..020f83b1a --- /dev/null +++ b/ml/dlib/docs/docs/chm/htmlhelp/setup_htmlhelp.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +cp *.dll ~/.wine/drive_c/windows/system32/ + +# Setup the registry +wine regedit htmlhelp.reg + +wine regsvr32 itcc.dll +wine regsvr32 itircl.dll + diff --git a/ml/dlib/docs/docs/chm/htmlhelp_stylesheet.xsl b/ml/dlib/docs/docs/chm/htmlhelp_stylesheet.xsl new file mode 100644 index 000000000..8b101f986 --- /dev/null +++ b/ml/dlib/docs/docs/chm/htmlhelp_stylesheet.xsl @@ -0,0 +1,223 @@ + + + + + + + + + + + abcdefghijklmnopqrstuvwxyz + ABCDEFGHIJKLMNOPQRSTUVWXYZ + + + + + + + +
      + + + + + + + + + + +
    + +
    + + + + + + + + + + + + + + +
  • + + +
  • +
      + + + + + + + +
    +
    + + + + + + + + + +
    +
    + + + + + + + + + + + + + + + + + + + + +
  • + + + + + + + + + + + + + +
  • + + + +
      + + + + + + + +
    +
    + +
      + + + + +
    +
    +
    +
    + + + +
  • + + + +
  • + + + +
      + + + +
    • + + + +
    • +
      + +
    • + + + +
    • +
      +
      +
      + + + + +
    • + + + +
    • +
      + +
    • + + + +
    • +
      +
      +
      +
      + +
    • + + + +
    • +
      + +
    • + + +
    • +
        + +
      • + + + +
      • +
        +
      +
      +
    +
    +
    +
    +
    +
    + + + +
      + + + + +
    +
    + + + + +
    diff --git a/ml/dlib/docs/docs/chm/lib.hhp b/ml/dlib/docs/docs/chm/lib.hhp new file mode 100644 index 000000000..de5633f1d --- /dev/null +++ b/ml/dlib/docs/docs/chm/lib.hhp @@ -0,0 +1,77 @@ +[OPTIONS] +Binary TOC=Yes +Compatibility=1.1 or later +Compiled file=help.chm +Contents file=Table of Contents.hhc +Default topic=docs/index.html +Display compile progress=Yes +Full-text search=Yes +Language=0x409 English (United States) +Title=dLib + + +[FILES] +docs/guipics/button.png +docs/guipics/check_box.png +docs/guipics/directed_graph_drawer.png +docs/guipics/image_window.jpg +docs/guipics/label.png +docs/guipics/list_box.png +docs/guipics/menu_bar.png +docs/guipics/message_box.png +docs/guipics/mouse_tracker.png +docs/guipics/named_rectangle.png +docs/guipics/open_existing_file_box.png +docs/guipics/open_file_box.png +docs/guipics/popup_menu.png +docs/guipics/radio_button.png +docs/guipics/save_file_box.png +docs/guipics/scroll_bar.png +docs/guipics/tabbed_display.png +docs/guipics/text_box.png +docs/guipics/text_field.png +docs/guipics/text_grid.png +docs/api.html +docs/dlib/test/makefile +docs/right.gif +docs/down.gif +docs/plus.gif +docs/minus.gif +docs/change_log.html +docs/compile.html +docs/compress_stream_ex.cpp.html +docs/compression.html +docs/containers.html +docs/dir_nav_ex.cpp.html +docs/gui_api_ex.cpp.html +docs/index.html +docs/intro.html +docs/optimization.html +docs/kernel_1a.html +docs/kernel_1b.html +docs/kernel_1c.html +docs/kernel_1da.html +docs/kernel_1db.html +docs/kernel_1ea.html +docs/kernel_1eb.html +docs/kernel_1ec.html +docs/kernel_2a.html +docs/kernel_3a.html +docs/kernel_3b.html +docs/license.html +docs/network.html +docs/other.html +docs/metaprogramming.html +docs/imaging.html +docs/parsing.html +docs/queue_ex.cpp.html +docs/release_notes.html +docs/old_release_notes.html +docs/sockets_ex.cpp.html +docs/faq.html +docs/rbf_normal.gif +docs/rbf_big_gamma.gif +docs/rbf_small_gamma.gif + +[INFOTYPES] + diff --git a/ml/dlib/docs/docs/chm/toc.xml b/ml/dlib/docs/docs/chm/toc.xml new file mode 100644 index 000000000..37593b1ed --- /dev/null +++ b/ml/dlib/docs/docs/chm/toc.xml @@ -0,0 +1,10 @@ + + + + + docs + ../main_menu.xml + false + + + diff --git a/ml/dlib/docs/docs/compile.xml b/ml/dlib/docs/docs/compile.xml new file mode 100644 index 000000000..bb21b3fa9 --- /dev/null +++ b/ml/dlib/docs/docs/compile.xml @@ -0,0 +1,227 @@ + + + + + How to compile + + + + + + + + + +

    Compiling C++ Example Programs on Any Operating System Using CMake

    + The best way to compile a program that uses dlib is to use CMake. For + example, the following commands will compile the example programs on any operating + system: + +cd examples +mkdir build +cd build +cmake .. +cmake --build . --config Release + +Note that you also need to have a C++11 compiler installed on your system. There are free C++11 compilers +for most operating systems. For example, Visual Studio is free on Windows and GCC is free and +works well on Mac OS X and Linux systems. If you have multiple compilers/IDEs installed then you can +tell CMake which one you want it to use via the -G option. +

    + The examples/CMakeLists.txt file tells CMake how to build + the examples. You can create your own projects by starting with this file and editing it however you like. + You can also perform additional configuration of a cmake project using the cmake-gui or ccmake tool. For example, + if you are using dlib's face detector then you should turn on either SSE4 or AVX instructions since this + makes it run much faster (also see this FAQ). +

    +

    + Finally, note that when using Visual Studio, CMake will by default generate a 32bit executable. + This means the programs you compile will only be able to use 2GB of RAM. To avoid this, you need + to tell CMake to generate a 64bit executable. You do this by using a command like + cmake -G "Visual Studio 14 2015 Win64" -T host=x64 .. instead of cmake .. + You can see the list of valid arguments to -G by running cmake with no options. Note also the -T host=x64 + option, which tells Visual Studio to let the compiler use more than 2GB of RAM. That is important if you don't want the compiler to + crash from running out of RAM in some situations. +

    + + +
    +

    Compiling Dlib's Python Interface

    +

    + Go to the base folder of the dlib repository and run python setup.py install. That + should compile and install the dlib python API on your system. +

    +

    + Alternatively, if you want to add more python bindings to dlib's + python interface then you probably want to avoid the setup.py file + and work directly using CMake. In particular, dlib's python API is + built by the CMake project in the tools/python folder. You build + this project using the usual CMake commands and when compiled it + outputs the dlib shared library that defines the python API for dlib. +

    +
    + + +

    Compiling C++ Examples Without CMake

    + +

    + In most cases, to use this library all you have to do is extract it somewhere, make + sure the folder containing the dlib folder is in your include path, and + finally add dlib/all/source.cpp to your + project. It is worth noting that most of dlib is "header-only" which means that, in + many cases, you don't actually have to build dlib/all/source.cpp into your + application. So if you don't get linker errors when you exclude dlib/all/source.cpp + from your project then you don't need it. +

    +

    + An example makefile that uses this library can be found here: dlib/test/makefile. It is the makefile used to build + the regression test suite for this library. +

    +

    + Again, note that you should not add the dlib folder itself to your compiler's include path. + Doing so will cause the + build to fail because of name collisions (e.g. dlib/string.h with string.h from the standard library). + Instead you should add the folder that contains the dlib folder to your include search path and then use + include statements of the form #include <dlib/queue.h>. This will ensure that everything + builds correctly. +

    +

    + Note also that if you want to work with jpeg/png/gif files using dlib then you will + need to link your program with libjpeg, libpng, and/or libgif. You also need to tell dlib + about this by defining the DLIB_JPEG_SUPPORT, DLIB_PNG_SUPPORT, and DLIB_GIF_SUPPORT preprocessor directives. + How you "link to libjpeg/libpng/libgif" varies from platform to platform. On UNIX machines you + usually just add a -ljpeg, -lpng, or -lgif switch to your compiler (after installing the libraries). + On windows it's less well defined. So dlib comes with a copy of libjpeg and libpng in the dlib/external + folder so you can statically compile them into your application if no system wide version + is available on your machine. If all this talk about linking is confusing to you then + just use CMake. It will set this all up for you. +

    +

    + Dlib is also capable of using any optimized BLAS or LAPACK libraries that are + installed on your system. Linking to these libraries will make many things run + faster. To do this you define the DLIB_USE_BLAS and/or DLIB_USE_LAPACK preprocessor + directives and then link your program with whatever BLAS or LAPACK libraries you + have. If you use CMake it will set this up automatically. +

    + + + + +

    Compiling on Linux From Command Line

    + From within the examples folder, you can compile nearly all of the examples with a single command like so: + +g++ -std=c++11 -O3 -I.. ../dlib/all/source.cpp -lpthread -lX11 example_program_name.cpp + + Note that not all examples require this much work. For example, the svm_ex.cpp example + can be compiled with just: + +g++ -std=c++11 -O3 -I.. svm_ex.cpp + + +On non-Linux systems like Solaris, you might have to link to other libraries. For example, I have seen systems +where it was also necessary to supply -lnsl or -lsocket options to g++. Additionally, the X11 development +library isn't installed on Ubuntu by default. So if you require it and are using Ubuntu you can install +it by typing: + +sudo apt-get install libx11-dev + + +

    Compiling on Windows Using GCC

    +

    + The commands for gcc on windows are the same as above but you may also have to link + (via the -l option) to the following libraries: gdi32, comctl32, user32, winmm, ws2_32, or imm32. +

    + +

    Compiling on Windows Using Visual Studio 2015 or Newer

    +

    + All you need to do is create an empty console project. Then add dlib/all/source.cpp to it and add the + folder containing the dlib folder to the #include search path. Then you can compile any example program + by adding it to your project. +

    +

    + Again, note that dlib will only be able to work with jpeg and png files if you link + in libjpeg and libpng. In Visual Studio, the easiest way to do this is to add all the + libjpeg, libpng, and zlib source files in the dlib/external folder into your project and also define the + DLIB_PNG_SUPPORT and DLIB_JPEG_SUPPORT preprocessor directives. If you don't know + how to configure Visual Studio then you should use CMake as shown above since it will + take care of everything automatically. +

    + + +
    +

    Miscellaneous Preprocessor Directives

    + +

    + In addition to the preprocessor directives mentioned above, there + are a few more you can supply during the build process to cause the + library to build in various optional ways. By default, the library + will always do something reasonable, but they are listed here in + the event that you need to use them. +

    + + + ENABLE_ASSERTS +

    #define ENABLE_ASSERTS

    +

    + Defining this directive causes all the DLIB_ASSERT macros to + be active. If you are using Visual Studio or CMake then ENABLE_ASSERTS will be automatically enabled + for you when you compile in debug mode. However, if you are using a different build system then you + might have to manually enable it if you want to turn the asserts on. +

    + + DLIB_ISO_CPP_ONLY +

    #define DLIB_ISO_CPP_ONLY

    +

    + This is a #define directive that you can set to cause the library to exclude all non ISO C++ code (The things in the API wrappers section and any objects that depend on those wrappers). + This is useful if you are trying to build on a system that isn't fully supported by the library or if you + just decide you don't want any of that stuff compiled into your program for your own reasons. +

    + DLIB_NO_GUI_SUPPORT +

    #define DLIB_NO_GUI_SUPPORT

    +

    + This is just like the DLIB_ISO_CPP_ONLY option except that it excludes only the GUI part of the library. + An example of when you might want to use this would be if you don't need GUI support and you are building + on a UNIX platform that doesn't have the X11 headers installed. +

    + NO_MAKEFILE +

    #define NO_MAKEFILE

    +

    + This preprocessor directive causes the dlib headers to pull in all the + code that would normally be built in dlib/all/source.cpp. Thus if you #define NO_MAKEFILE you won't + have to add dlib/all/source.cpp to your project. The only time this is useful is when your + project consists of a single translation unit (i.e. a single cpp file). In this instance NO_MAKEFILE + allows you to easily build your project on the command line by saying something like g++ -DNO_MAKEFILE + project.cpp. But again, this is only for single cpp file projects. If you use NO_MAKEFILE with projects + that contain more than one cpp file you will get linker errors about multiply defined symbols. +

    +

    + Also note that if you use this macro then the stack trace + functionality in the library will be disabled. +

    + DLIB_THREAD_POOL_TIMEOUT +

    #define DLIB_THREAD_POOL_TIMEOUT <time-in-milliseconds>

    +

    + If you use dlib to create your threads then you receive the benefit of the dlib dynamic thread pool (Note that the + dlib::thread_pool object is something else unrelated to this so don't confuse + the two). This pool + enables dlib to spawn new threads very rapidly since it draws threads back out of its thread pool when + the pool isn't empty. +

    +

    + Thus, when a thread that was created by dlib ends it actually goes back into the dlib thread pool + and waits DLIB_THREAD_POOL_TIMEOUT milliseconds before totally terminating and releasing its resources back + to the operating system. The default timeout used by this library is 30,000 milliseconds (30 seconds). You + may however change this to whatever you like by defining DLIB_THREAD_POOL_TIMEOUT to some new value. +

    + + + + + + + + + + +
    diff --git a/ml/dlib/docs/docs/compression.xml b/ml/dlib/docs/docs/compression.xml new file mode 100644 index 000000000..a5b3897b2 --- /dev/null +++ b/ml/dlib/docs/docs/compression.xml @@ -0,0 +1,881 @@ + + + + + Data Compression + + + + +

    + This page contains a bunch of objects that implement various parts of compression algorithms. + They can be put together in different ways to construct many different algorithms. + Note that the compress_stream object contains complete compression algorithms. So if you + just want to compress some data then you can easily use that object and not bother with the others. +

    +

    + In the column to the right you can see benchmark data for each of the compress_stream + typedefs. The times measured are the time it takes to compress and then + decompress each file. It was run on a 3.0ghz P4. For reference see the Canterbury corpus + web site. +

    + + + + + + + + + +
    + Objects + compress_stream + conditioning_class + entropy_decoder + entropy_encoder + entropy_decoder_model + entropy_encoder_model + lz77_buffer + lzp_buffer +
    + + + + +
    + Benchmarks + + kernel_1a + kernel_1a.html + + + kernel_1b + kernel_1b.html + + + kernel_1c + kernel_1c.html + + + kernel_1da + kernel_1da.html + + + kernel_1db + kernel_1db.html + + + kernel_1ea + kernel_1ea.html + + + kernel_1eb + kernel_1eb.html + + + kernel_1ec + kernel_1ec.html + + + kernel_2a + kernel_2a.html + + + kernel_3a + kernel_3a.html + + + kernel_3b + kernel_3b.html + +
    +
    +
    + + + + + + + + + + + + compress_stream + dlib/compress_stream.h + dlib/compress_stream/compress_stream_kernel_abstract.h + + This object is pretty straight forward. It has no state and just + contains the functions compress and decompress. + They do just what their names imply to iostream objects. + + + + compress_stream_ex.cpp.html + file_to_code_ex.cpp.html + + + + + compress_stream_kernel_1 + dlib/compress_stream/compress_stream_kernel_1.h + + This implementation is done using the entropy_encoder_model and + entropy_decoder_model objects. + + + + + + kernel_1a + is a typedef for compress_stream_kernel_1 which uses entropy_decoder_model_kernel_1b and entropy_decoder_model_kernel_1b + + + kernel_1b + is a typedef for compress_stream_kernel_1 which uses entropy_decoder_model_kernel_2b and entropy_decoder_model_kernel_2b + + + kernel_1c + is a typedef for compress_stream_kernel_1 which uses entropy_decoder_model_kernel_3b and entropy_decoder_model_kernel_3b + + + kernel_1da + is a typedef for compress_stream_kernel_1 which uses entropy_decoder_model_kernel_4a and entropy_decoder_model_kernel_4a + + + kernel_1db + is a typedef for compress_stream_kernel_1 which uses entropy_decoder_model_kernel_4b and entropy_decoder_model_kernel_4b + + + kernel_1ea + is a typedef for compress_stream_kernel_1 which uses entropy_decoder_model_kernel_5a and entropy_decoder_model_kernel_5a + + + kernel_1eb + is a typedef for compress_stream_kernel_1 which uses entropy_decoder_model_kernel_5b and entropy_decoder_model_kernel_5b + + + kernel_1ec + is a typedef for compress_stream_kernel_1 which uses entropy_decoder_model_kernel_5c and entropy_decoder_model_kernel_5c + + + + + + + compress_stream_kernel_2 + dlib/compress_stream/compress_stream_kernel_2.h + + This implementation is done using the entropy_encoder_model and + entropy_decoder_model objects. It also uses the + lz77_buffer object. It uses the entropy coder models to + encode symbols when there is no match found by the lz77_buffer. + + + + + + kernel_2a + is a typedef for compress_stream_kernel_2 which uses entropy_encoder_model_kernel_2b, entropy_decoder_model_kernel_2b, and lz77_buffer_kernel_2a. + + + + + + + + compress_stream_kernel_3 + dlib/compress_stream/compress_stream_kernel_3.h + + This implementation is done using the lzp_buffer object and + crc32 object. It does not use any sort of entropy coding, instead + a byte aligned output method is used. + + + + + + kernel_3a + is a typedef for compress_stream_kernel_3 which uses lzp_buffer_kernel_1. + + + kernel_3b + is a typedef for compress_stream_kernel_3 which uses lzp_buffer_kernel_2. + + + + + + + + + + + + + + conditioning_class + dlib/conditioning_class.h + dlib/conditioning_class/conditioning_class_kernel_abstract.h + + This object represents a conditioning class used for arithmetic style + compression. It maintains the cumulative counts which are needed + by the entropy_encoder and entropy_decoder objects below. + + + + + conditioning_class_kernel_1 + dlib/conditioning_class/conditioning_class_kernel_1.h + + This implementation is done using an array to store all the counts and they are summed + whenever the cumulative counts are requested. It's pretty straight forward. + + + + + kernel_1a + is a typedef for conditioning_class_kernel_1 + + + + + + conditioning_class_kernel_2 + dlib/conditioning_class/conditioning_class_kernel_2.h + + This implementation is done using a binary tree where each node in the tree represents one symbol and + contains that symbols count and the sum of all the counts for the nodes to the left. This way + when you request a cumulative count it can be computed by visiting log n nodes where n is the + size of the alphabet. + + + + + kernel_2a + is a typedef for conditioning_class_kernel_2 + + + + + + + conditioning_class_kernel_3 + dlib/conditioning_class/conditioning_class_kernel_3.h + + This implementation is done using an array to store all the counts and they are + summed whenever the cumulative counts are requested. The counts are also kept in + semi-sorted order to speed up the calculation of the cumulative count. + + + + + kernel_3a + is a typedef for conditioning_class_kernel_3 + + + + + + + conditioning_class_kernel_4 + dlib/conditioning_class/conditioning_class_kernel_4.h + + This implementation is done using a linked list to store all the counts and they are + summed whenever the cumulative counts are requested. The counts are also kept in + semi-sorted order to speed up the calculation of the cumulative count. This implementation + also uses the memory_manager component to create a + memory pool of linked list nodes. This implementation is especially useful for high order + contexts and/or very large and sparse alphabets. + + + + + + kernel_4a + is a typedef for conditioning_class_kernel_4 with a memory pool of 10,000 nodes. + + + kernel_4b + is a typedef for conditioning_class_kernel_4 with a memory pool of 100,000 nodes. + + + kernel_4c + is a typedef for conditioning_class_kernel_4 with a memory pool of 1,000,000 nodes. + + + kernel_4d + is a typedef for conditioning_class_kernel_4 with a memory pool of 10,000,000 nodes. + + + + + + + + + + + + + entropy_decoder + dlib/entropy_decoder.h + dlib/entropy_decoder/entropy_decoder_kernel_abstract.h + + This object represents an entropy decoder. E.g. the decoding part of + an arithmetic coder. + + + + + entropy_decoder_kernel_1 + dlib/entropy_decoder/entropy_decoder_kernel_1.h + + This object is implemented using arithmetic coding and is done in the + straight forward way using integers and fixed precision math. + + + + + kernel_1a + is a typedef for entropy_decoder_kernel_1 + + + + + + entropy_decoder_kernel_2 + dlib/entropy_decoder/entropy_decoder_kernel_2.h + + This object is implemented using "range" coding and is done + in the straight forward way using integers and fixed precision math. + + + + + kernel_2a + is a typedef for entropy_decoder_kernel_2 + + + + + + + + + + + + + + entropy_encoder + dlib/entropy_encoder.h + dlib/entropy_encoder/entropy_encoder_kernel_abstract.h + + This object represents an entropy encoder. E.g. the encoding part of + an arithmetic coder. + + + + + entropy_encoder_kernel_1 + dlib/entropy_encoder/entropy_encoder_kernel_1.h + + This object is implemented using arithmetic coding and is done in the + straight forward way using integers and fixed precision math. + + + + + kernel_1a + is a typedef for entropy_encoder_kernel_1 + + + + + + entropy_encoder_kernel_2 + dlib/entropy_encoder/entropy_encoder_kernel_2.h + + This object is implemented using "range" coding and is done + in the straight forward way using integers and fixed precision math. + + + + + kernel_2a + is a typedef for entropy_encoder_kernel_2 + + + + + + + + + + + + + + entropy_decoder_model + dlib/entropy_decoder_model.h + dlib/entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + + This object represents some kind of statistical model. You + can use it to read symbols from an entropy_decoder and it will calculate + the cumulative counts/probabilities and manage contexts for you. + + + + + entropy_decoder_model_kernel_1 + dlib/entropy_decoder_model/entropy_decoder_model_kernel_1.h + + This object is implemented using the conditioning_class component. + It implements an order-0 finite context model and uses lazy exclusions and update exclusions. + The escape method used is method D. + + + + + kernel_1a + is a typedef for entropy_decoder_model_kernel_1 that uses conditioning_class_kernel_1a + + + kernel_1b + is a typedef for entropy_decoder_model_kernel_1 that uses conditioning_class_kernel_2a + + + kernel_1c + is a typedef for entropy_decoder_model_kernel_1 that uses conditioning_class_kernel_3a + + + + + + + entropy_decoder_model_kernel_2 + dlib/entropy_decoder_model/entropy_decoder_model_kernel_2.h + + This object is implemented using the conditioning_class component. + It implements an order-1-0 finite context model and uses lazy exclusions and update exclusions. + The escape method used is method D. + + + + + kernel_2a + is a typedef for entropy_decoder_model_kernel_2 that uses conditioning_class_kernel_1a + + + kernel_2b + is a typedef for entropy_decoder_model_kernel_2 that uses conditioning_class_kernel_2a + + + kernel_2c + is a typedef for entropy_decoder_model_kernel_2 that uses conditioning_class_kernel_3a + + + kernel_2d + is a typedef for entropy_decoder_model_kernel_2 that uses conditioning_class_kernel_2a for its order-0 + context and conditioning_class_kernel_4b for its order-1 context. + + + + + + + entropy_decoder_model_kernel_3 + dlib/entropy_decoder_model/entropy_decoder_model_kernel_3.h + + This object is implemented using the conditioning_class component. + It implements an order-2-1-0 finite context model and uses lazy exclusions and update exclusions. + The escape method used is method D. + + + + + kernel_3a + is a typedef for entropy_decoder_model_kernel_3 that uses conditioning_class_kernel_1a for orders 0 and 1 + and conditioning_class_kernel_4b for order-2. + + + kernel_3b + is a typedef for entropy_decoder_model_kernel_3 that uses conditioning_class_kernel_2a for orders 0 and 1 + and conditioning_class_kernel_4b for order-2. + + + kernel_3c + is a typedef for entropy_decoder_model_kernel_3 that uses conditioning_class_kernel_3a for orders 0 and 1 + and conditioning_class_kernel_4b for order-2. + + + + + + + entropy_decoder_model_kernel_4 + dlib/entropy_decoder_model/entropy_decoder_model_kernel_4.h + + This object is implemented using a variation of the PPM algorithm described by Alistair Moffat in his paper "Implementing + the PPM data compression scheme." + It provides template arguments to select the maximum order and maximum memory to use. For speed, + exclusions are not used. The escape method used is method D. + + + + + kernel_4a + is a typedef for entropy_decoder_model_kernel_4 with the max order set to 4 and the max number + of nodes set to 200,000 + + + kernel_4b + is a typedef for entropy_decoder_model_kernel_4 with the max order set to 5 and the max number + of nodes set to 1,000,000 + + + + + + + entropy_decoder_model_kernel_5 + dlib/entropy_decoder_model/entropy_decoder_model_kernel_5.h + + This object is implemented using a variation of the PPM algorithm described by Alistair Moffat in his paper "Implementing + the PPM data compression scheme." + It provides template arguments to select the maximum order and maximum memory to use. Exclusions are used. The escape method used is method D. + This implementation is very much like kernel_4 except it is tuned for higher compression rather than speed. + This also uses Dmitry Shkarin's Information Inheritance scheme. + + + + + kernel_5a + is a typedef for entropy_decoder_model_kernel_5 with the max order set to 4 and the max number + of nodes set to 200,000 + + + kernel_5b + is a typedef for entropy_decoder_model_kernel_5 with the max order set to 5 and the max number + of nodes set to 1,000,000 + + + kernel_5c + is a typedef for entropy_decoder_model_kernel_5 with the max order set to 7 and the max number + of nodes set to 2,500,000 + + + + + + + entropy_decoder_model_kernel_6 + dlib/entropy_decoder_model/entropy_decoder_model_kernel_6.h + + This object just assigns every symbol the same probability. I.e. it uses an order-(-1) model. + + + + + kernel_6a + is a typedef for entropy_decoder_model_kernel_6 + + + + + + + + + + + + + + + entropy_encoder_model + dlib/entropy_encoder_model.h + dlib/entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + + This object represents some kind of statistical model. You + can use it to write symbols to an entropy_encoder and it will calculate + the cumulative counts/probabilities and manage contexts for you. + + + + + entropy_encoder_model_kernel_1 + dlib/entropy_encoder_model/entropy_encoder_model_kernel_1.h + + This object is implemented using the conditioning_class component. + It implements an order-0 finite context model and uses lazy exclusions and update exclusions. + The escape method used is method D. + + + + + kernel_1a + is a typedef for entropy_encoder_model_kernel_1 that uses conditioning_class_kernel_1a + + + kernel_1b + is a typedef for entropy_encoder_model_kernel_1 that uses conditioning_class_kernel_2a + + + kernel_1c + is a typedef for entropy_encoder_model_kernel_1 that uses conditioning_class_kernel_3a + + + + + + + entropy_encoder_model_kernel_2 + dlib/entropy_encoder_model/entropy_encoder_model_kernel_2.h + + This object is implemented using the conditioning_class component. + It implements an order-1-0 finite context model and uses lazy exclusions and update exclusions. + The escape method used is method D. + + + + + kernel_2a + is a typedef for entropy_encoder_model_kernel_2 that uses conditioning_class_kernel_1a + + + kernel_2b + is a typedef for entropy_encoder_model_kernel_2 that uses conditioning_class_kernel_2a + + + kernel_2c + is a typedef for entropy_encoder_model_kernel_2 that uses conditioning_class_kernel_3a + + + kernel_2d + is a typedef for entropy_encoder_model_kernel_2 that uses conditioning_class_kernel_2a for its order-0 + context and conditioning_class_kernel_4b for its order-1 context. + + + + + + + entropy_encoder_model_kernel_3 + dlib/entropy_encoder_model/entropy_encoder_model_kernel_3.h + + This object is implemented using the conditioning_class component. + It implements an order-2-1-0 finite context model and uses lazy exclusions and update exclusions. + The escape method used is method D. + + + + + kernel_3a + is a typedef for entropy_encoder_model_kernel_3 that uses conditioning_class_kernel_1a for orders 0 and 1 + and conditioning_class_kernel_4b for order-2. + + + kernel_3b + is a typedef for entropy_encoder_model_kernel_3 that uses conditioning_class_kernel_2a for orders 0 and 1 + and conditioning_class_kernel_4b for order-2. + + + kernel_3c + is a typedef for entropy_encoder_model_kernel_3 that uses conditioning_class_kernel_3a for orders 0 and 1 + and conditioning_class_kernel_4b for order-2. + + + + + + + entropy_encoder_model_kernel_4 + dlib/entropy_encoder_model/entropy_encoder_model_kernel_4.h + + This object is implemented using a variation of the PPM algorithm described by Alistair Moffat in his paper "Implementing + the PPM data compression scheme." + It provides template arguments to select the maximum order and maximum memory to use. For speed, + exclusions are not used. The escape method used is method D. + + + + + kernel_4a + is a typedef for entropy_encoder_model_kernel_4 with the max order set to 4 and the max number + of nodes set to 200,000 + + + kernel_4b + is a typedef for entropy_encoder_model_kernel_4 with the max order set to 5 and the max number + of nodes set to 1,000,000 + + + + + + + entropy_encoder_model_kernel_5 + dlib/entropy_encoder_model/entropy_encoder_model_kernel_5.h + + This object is implemented using a variation of the PPM algorithm described by Alistair Moffat in his paper "Implementing + the PPM data compression scheme." + It provides template arguments to select the maximum order and maximum memory to use. Exclusions are used. The escape method used is method D. + This implementation is very much like kernel_4 except it is tuned for higher compression rather than speed. + This also uses Dmitry Shkarin's Information Inheritance scheme. + + + + + kernel_5a + is a typedef for entropy_encoder_model_kernel_5 with the max order set to 4 and the max number + of nodes set to 200,000 + + + kernel_5b + is a typedef for entropy_encoder_model_kernel_5 with the max order set to 5 and the max number + of nodes set to 1,000,000 + + + kernel_5c + is a typedef for entropy_encoder_model_kernel_5 with the max order set to 7 and the max number + of nodes set to 2,500,000 + + + + + + + entropy_encoder_model_kernel_6 + dlib/entropy_encoder_model/entropy_encoder_model_kernel_6.h + + This object just assigns every symbol the same probability. I.e. it uses an order-(-1) model. + + + + + kernel_6a + is a typedef for entropy_encoder_model_kernel_6 + + + + + + + + + + + + + + + lz77_buffer + dlib/lz77_buffer.h + dlib/lz77_buffer/lz77_buffer_kernel_abstract.h + + This object represents a pair of buffers (history and lookahead buffers) + used during lz77 style compression. + + + + + lz77_buffer_kernel_1 + dlib/lz77_buffer/lz77_buffer_kernel_1.h + + This object is implemented using the sliding_buffer and it + just does simple linear searches of the history buffer to find matches. + + + + + kernel_1a + is a typedef for lz77_buffer_kernel_1 that uses sliding_buffer_kernel_1 + + + + + + lz77_buffer_kernel_2 + dlib/lz77_buffer/lz77_buffer_kernel_2.h + + This object is implemented using the sliding_buffer. It + finds matches by using a hash table. + + + + + kernel_2a + is a typedef for lz77_buffer_kernel_2 that uses sliding_buffer_kernel_1 + + + + + + + + + + + + + + lzp_buffer + dlib/lzp_buffer.h + dlib/lzp_buffer/lzp_buffer_kernel_abstract.h + + This object represents some variation on the LZP algorithm + described by Charles Bloom in his paper "LZP: a new data + compression algorithm" + + + + + lzp_buffer_kernel_1 + dlib/lzp_buffer/lzp_buffer_kernel_1.h + + This object is implemented using the sliding_buffer and uses + an order-3 model to predict matches. + + + + + kernel_1a + is a typedef for lzp_buffer_kernel_1 that uses sliding_buffer_kernel_1 + + + + + + lzp_buffer_kernel_2 + dlib/lzp_buffer/lzp_buffer_kernel_2.h + + This object is implemented using the sliding_buffer and uses + an order-5-4-3 model to predict matches. + + + + + kernel_2a + is a typedef for lzp_buffer_kernel_2 that uses sliding_buffer_kernel_1 + + + + + + + + + + + + + + + + +
    diff --git a/ml/dlib/docs/docs/containers.xml b/ml/dlib/docs/docs/containers.xml new file mode 100644 index 000000000..8409b6907 --- /dev/null +++ b/ml/dlib/docs/docs/containers.xml @@ -0,0 +1,1201 @@ + + + + + Containers + + + + +

    + Many of these containers were inspired by the work of the Reusable + Software Research Group at Ohio State. As such, many of the objects do not support + copying in any form, only swapping is allowed. That is, when objects + are added or removed from most of these containers they are swapped in + and out, not copied. +

    + +

    + This allows you to do things like have containers of containers of + containers without encountering the overhead of the massive copying + that would likely result if you did the same thing with the STL. It + also means you can store objects that are not copyable inside these + containers, which is not something you can do with the STL prior to C++11. +

    + +

    + Note that it is assumed by these containers that swap() and + operator< do not throw. They may not function correctly if this + assumption is broken. Also note that the built in types (int, long, + char, etc.) and std::string will not cause operator< or swap() to + throw. +

    + +

    + Note also that most of the containers inherit from the + enumerable interface. Thus, all the + member functions inherited from enumerable are defined in the + enumerable class and their documentation is not repeated in each + container's documentation. This includes the size() member + function in each container. +

    + + + + + + + + +
    + Objects + static_set + any + any_trainer + any_function + any_decision_function + array + array2d + binary_search_tree + hash_map + hash_set + hash_table + directed_graph + graph + map + queue + reference_counter + type_safe_union + unordered_pair + sequence + set + stack + std_vector_c + static_map + sliding_buffer + circular_buffer + tuple + reference_wrapper + +
    + +
    + Interfaces + map_pair + enumerable + + remover + + + remover + dlib/interfaces/remover.h.html#remover + + + asc_remover + dlib/interfaces/remover.h.html#asc_remover + + + pair_remover + dlib/interfaces/remover.h.html#pair_remover + + + asc_pair_remover + dlib/interfaces/remover.h.html#asc_pair_remover + + + +
    +
    +
    + + + + + + + + + array + dlib/array.h + dlib/array/array_kernel_abstract.h + + This object represents a 1-Dimensional array of objects. + + + + + + + sliding_buffer + dlib/sliding_buffer.h + dlib/sliding_buffer/sliding_buffer_kernel_abstract.h + + This object represents an array with the ability to rotate its contents + left or right. Note that the size of this object is always a power of two. + If you need arbitrary sized objects then use a circular_buffer. + + + + + sliding_buffer_kernel_1 + dlib/sliding_buffer/sliding_buffer_kernel_1.h + + This object is implemented using a C style array in the obvious way. See the code for details. + + + + + kernel_1a + is a typedef for sliding_buffer_kernel_1 + + + + + + + + + + + + + + circular_buffer + dlib/sliding_buffer.h + dlib/sliding_buffer/circular_buffer_abstract.h + + This object represents a simple sliding buffer which can contain + and arbitrary number of elements. + + + + + + + + array2d + dlib/array2d.h + dlib/array2d/array2d_kernel_abstract.h + + This object represents a 2-Dimensional array of objects. + + + + image_ex.cpp.html + + + + + + + + binary_search_tree + dlib/binary_search_tree.h + dlib/binary_search_tree/binary_search_tree_kernel_abstract.h + + This object represents a data dictionary that is built on top of some kind of binary search tree. + + + + + binary_search_tree_kernel_1 + dlib/binary_search_tree/binary_search_tree_kernel_1.h + + This implementation is done using an AVL binary search tree. It uses the + memory_manager for all memory allocations. + + + + + kernel_1a + is a typedef for binary_search_tree_kernel_1 + + + + + + binary_search_tree_kernel_2 + dlib/binary_search_tree/binary_search_tree_kernel_2.h + + This implementation is done using a red-black binary search tree. It uses the + memory_manager for all memory allocations. + + + + kernel_2a + is a typedef for binary_search_tree_kernel_2 + + + + + + + + + + + + + + hash_map + dlib/hash_map.h + dlib/hash_map/hash_map_kernel_abstract.h + + This object represents a hashed mapping of items of type domain onto items of type range. + + + + + hash_map_kernel_1 + dlib/hash_map/hash_map_kernel_1.h + + This implementation is done using a hash_table object. It uses the + memory_manager for all memory allocations. + + + + + + + kernel_1a + is a typedef for hash_map_kernel_1 that uses hash_table_kernel_1a + + + kernel_1b + is a typedef for hash_map_kernel_1 that uses hash_table_kernel_2a + + + kernel_1c + is a typedef for hash_map_kernel_1 that uses hash_table_kernel_2b + + + + + + + + + + + + + + + + hash_set + dlib/hash_set.h + dlib/hash_set/hash_set_kernel_abstract.h + + This object represents a hashed unordered and unaddressed collection of unique items. + + + + + hash_set_kernel_1 + dlib/hash_set/hash_set_kernel_1.h + + This implementation is done using a hash_table object. It uses the + memory_manager for all memory allocations. + + + + + + + kernel_1a + is a typedef for hash_set_kernel_1 that uses hash_table_kernel_1a + + + kernel_1b + is a typedef for hash_set_kernel_1 that uses hash_table_kernel_2a + + + kernel_1c + is a typedef for hash_set_kernel_1 that uses hash_table_kernel_2b + + + + + + + + + + + + + + + + hash_table + dlib/hash_table.h + dlib/hash_table/hash_table_kernel_abstract.h + + This object represents a data dictionary that is built on top of some kind of + hash table. + + + + + hash_table_kernel_1 + dlib/hash_table/hash_table_kernel_1.h + + This implementation is done using singly linked lists as hashing buckets. It uses the + memory_manager for all memory allocations. + + + + + + kernel_1a + is a typedef for hash_table_kernel_1. + + + + + + + hash_table_kernel_2 + dlib/hash_table/hash_table_kernel_2.h + + This implementation is done using + binary_search_tree objects as hashing buckets. It uses the + memory_manager for all memory allocations. + + + + + + + kernel_2a + is a typedef for hash_table_kernel_2 that uses binary_search_tree_kernel_1 + + + kernel_2b + is a typedef for hash_table_kernel_2 that uses binary_search_tree_kernel_2 + + + + + + + + + + + + + + + + + map + dlib/map.h + dlib/map/map_kernel_abstract.h + + This object represents a mapping of items of type domain onto items of type range. + + + + + map_kernel_1 + dlib/map/map_kernel_1.h + + This is implemented using the binary_search_tree component. It uses the + memory_manager for all memory allocations. + + + + + + + kernel_1a + is a typedef for map_kernel_1 that uses binary_search_tree_kernel_1 + + + kernel_1b + is a typedef for map_kernel_1 that uses binary_search_tree_kernel_2 + + + + + + + + + + + + + + + + + enumerable + dlib/interfaces/enumerable.h + dlib/interfaces/enumerable.h + + This object is an abstract class which represents an interface for iterating over + all the elements of a container. + + + + + + + + + + map_pair + dlib/interfaces/map_pair.h + dlib/interfaces/map_pair.h + + This object is an abstract class which represents an interface for accessing a + pair from a container such as the map, hash_table, etc. + + + + + + + + + + remover + dlib/interfaces/remover.h + dlib/interfaces/remover.h + + This is a set of interfaces which gives the ability to remove all the items in a + container without actually knowing what kind of container contains them. + + + + + + + + + + type_safe_union + dlib/type_safe_union.h + dlib/type_safe_union/type_safe_union_kernel_abstract.h + + This object is a type safe analogue of the classic C union object. + The type_safe_union, unlike a union, can contain non-POD types such + as std::string. +

    It is also implemented without performing any + heap memory allocations and instead it stores everything on the stack.

    +
    + + + pipe_ex_2.cpp.html + bridge_ex.cpp.html + + +
    + + + + + unordered_pair + dlib/unordered_pair.h + dlib/unordered_pair.h + + This object is very similar to the std::pair struct except unordered_pair + is only capable of representing an unordered set of two items rather than + an ordered list of two items like std::pair. + + + + + + + + any + dlib/any.h + dlib/any/any_abstract.h + + This object is basically a type-safe version of a void*. In particular, + it is a container which can contain only one object but the object may + be of any type. + +

    + It is somewhat like the type_safe_union except you don't have to declare + the set of possible content types beforehand. So in some sense this is + like a less type-strict version of the type_safe_union. +

    +
    +
    + + + + + any_decision_function + dlib/any.h + dlib/any/any_decision_function_abstract.h + + This object is a version of dlib::any that is restricted to containing + elements which are some kind of function object with an operator() with + the following signature: + result_type operator()(const sample_type&) const + +

    + It is intended to be used to contain dlib::decision_function + objects and other types which represent learned decision functions. It allows you + to write code which contains and processes these decision functions + without needing to know the specific types of decision functions used. +

    +
    +
    + + + + + any_function + dlib/any.h + dlib/any/any_function_abstract.h + + This object is a version of dlib::any that is restricted to containing + elements which are some kind of function or function object. + + + + + + + any_trainer + dlib/any.h + dlib/any/any_trainer_abstract.h + + This object is a version of dlib::any that is restricted to containing + elements which are some kind of object with a .train() method compatible + with the following signature: +
     decision_function train(
    +      const std::vector<sample_type>& samples,
    +      const std::vector<scalar_type>& labels
    +   ) const
    +
    + Where decision_function is a type capable of being stored in an + any_decision_function object. + +

    + any_trainer is intended to be used to contain objects such as the svm_nu_trainer + and other similar types which represent supervised machine learning algorithms. + It allows you to write code which contains and processes these trainer objects + without needing to know the specific types of trainer objects used. +

    +
    +
    + + + + + tuple + dlib/tuple.h + dlib/tuple/tuple_abstract.h + + This is an implementation of a very simple templated container object. + It contains between 0 and 31 objects where each object is listed + explicitly in the tuple's template arguments. + +

    + Note that there is only one implementation of this object so there aren't any + different kernels to choose from when you create instances of the tuple object. + So for example, you + could declare a tuple of 3 ints using the following statement: + dlib::tuple<int,int,int> t; +

    +
    + +
    + + + + + reference_wrapper + dlib/ref.h + dlib/ref.h + + This is a simple object that just holds a reference to another object. + It is useful because it can serve as a kind of "copyable reference". + + + thread_function_ex.cpp.html + + + + + + + + graph + dlib/graph.h + dlib/graph/graph_kernel_abstract.h + + This object represents a graph which is a set of nodes with undirected + edges connecting various nodes. + + + + + graph_kernel_1 + dlib/graph/graph_kernel_1.h + + This is implemented using std::vector to contain all the nodes and edges. + + + + + kernel_1a + is a typedef for graph_kernel_1 + + + + + + + + + + + + directed_graph + dlib/directed_graph.h + dlib/directed_graph/directed_graph_kernel_abstract.h + + This object represents a directed graph which is a set of nodes with directed + edges connecting various nodes. + + + + + directed_graph_kernel_1 + dlib/directed_graph/directed_graph_kernel_1.h + + This is implemented using std::vector to contain all the nodes and edges. + + + + + kernel_1a + is a typedef for directed_graph_kernel_1 + + + + + + + + + + + + + queue + dlib/queue.h + dlib/queue/queue_kernel_abstract.h + + This object represents a first in first out queue. + + + + queue_ex.cpp.html + + + + + queue_kernel_1 + dlib/queue/queue_kernel_1.h + + This is implemented in the obvious way using a singly linked list. It does not use the + memory_manager at all. + + + + + kernel_1a + is a typedef for queue_kernel_1 + + + + + + queue_kernel_2 + dlib/queue/queue_kernel_2.h + + This is implemented using a singly linked list and each node in the list + contains block_size (a template parameter) elements. It uses the + memory_manager for all memory allocations. + + + + + kernel_2a + is a typedef for queue_kernel_2 with a block_size of 20 + + + kernel_2b + is a typedef for queue_kernel_2 with a block_size of 100 + + + + + + + + + queue_sort + dlib/queue/queue_sort_abstract.h + + This extension gives a queue the ability to sort its contents. + + + + + queue_sort_1 + dlib/queue/queue_sort_1.h + + This is a version of the QuickSort algorithm. + + + + + sort_1a + is a typedef for queue_kernel_1a extended by queue_sort_1 + + + sort_1b + is a typedef for queue_kernel_2a extended by queue_sort_1 + + + sort_1c + is a typedef for queue_kernel_2b extended by queue_sort_1 + + + + + + + + + + + + + + + + + + + reference_counter + dlib/reference_counter.h + dlib/reference_counter/reference_counter_kernel_abstract.h + + This object represents a container for an object and provides reference counting + capabilities for the object it contains. + + + + + reference_counter_kernel_1 + dlib/reference_counter/reference_counter_kernel_1.h + + This implementation is done using pointers in the obvious way. + + + + + kernel_1a + is a typedef for reference_counter_kernel_1 + + + + + + + + + + + + + + + sequence + dlib/sequence.h + dlib/sequence/sequence_kernel_abstract.h + + This object represents an ordered sequence of items, each item is + associated with an integer value. The items are numbered from 0 to the number of items in the + sequence minus 1. + + + + + sequence_kernel_1 + dlib/sequence/sequence_kernel_1.h + + This is implemented as an AVL binary search tree. + Accessing(or adding or removing) an element always takes O(log n) time. + It uses the memory_manager for all memory allocations. + + + + + kernel_1a + is a typedef for sequence_kernel_1 + + + + + + sequence_kernel_2 + dlib/sequence/sequence_kernel_2.h + + This implementation is done using a doubly linked list in the shape of a ring. + It will remember the last element accessed(or added or removed) and give O(1) + access time to the elements just left and right of it. Aside from that, + accessing(or adding or removing) a random element will take O(n) and in the worst + case it will take time proportional to the size of the sequence/2. +

    + It does not use the + memory_manager at all. +

    + +
    + + + + kernel_2a + is a typedef for sequence_kernel_2 + + + +
    +
    + + + + + sequence_sort + dlib/sequence/sequence_sort_abstract.h + + This extension gives a sequence the ability to sort its contents. + + + + + sequence_sort_1 + dlib/sequence/sequence_sort_1.h + + This is a version of the QuickSort algorithm and it sorts sequences of less + than 30 elements with a selection sort. This implementation is fastest when + used with sequence_kernel_2 and fairly slow when used with sequence_kernel_1 + + + + + sort_1a + is a typedef for sequence_kernel_2a extended by sequence_sort_1 + + + + + + sequence_sort_2 + dlib/sequence/sequence_sort_2.h + + This is a version of the QuickSort algorithm. This implementation of sort is + the best to use with sequence_kernel_1 objects but gives extremely poor performance + with sequence_kernel_2 objects. + + + + + sort_2a + is a typedef for sequence_kernel_1a extended by sequence_sort_2 + + + + + + + + + + sequence_compare + dlib/sequence/sequence_compare_abstract.h + + This extension gives sequences the ability to compare themselves using + operator< and operator==. Thus they can be used in the other container classes + that require this ability. (maps, sets, etc.) + + + + + sequence_compare_1 + dlib/sequence/sequence_compare_1.h + + The implementation is obvious. Click on the sequence_compare_1 link if you want to see. + + + + + compare_1a + is a typedef for sequence_kernel_1a extended by sequence_compare_1 + + + compare_1b + is a typedef for sequence_kernel_2a extended by sequence_compare_1 + + + + + + + + + + + +
    + + + + + + + set + dlib/set.h + dlib/set/set_kernel_abstract.h + + This object represents an unordered and unaddressed collection of unique items. + + + + + set_kernel_1 + dlib/set/set_kernel_1.h + + This is implemented using the binary_search_tree component. It uses the + memory_manager for all memory allocations. + + + + + + kernel_1a + is a typedef for set_kernel_1 that uses binary_search_tree_kernel_1 + + + kernel_1b + is a typedef for set_kernel_1 that uses binary_search_tree_kernel_2 + + + + + + + + + + + set_compare + dlib/set/set_compare_abstract.h + + This extension gives sets the ability to compare themselves using operator< and + operator==. Thus they can be used in the other container classes that require + this ability. (maps, sets, etc.) + + + + + set_compare_1 + dlib/set/set_compare_1.h + + The implementation is obvious. Click on the set_compare_1 link if you want to see. + + + + + compare_1a + is a typedef for set_kernel_1a extended by set_compare_1 + + + compare_1b + is a typedef for set_kernel_1b extended by set_compare_1 + + + + + + + + + + + + + + + + + + stack + dlib/stack.h + dlib/stack/stack_kernel_abstract.h + + This object represents a last in first out stack. + + + + + stack_kernel_1 + dlib/stack/stack_kernel_1.h + + This implementation is done in the obvious way using a singly linked list. It uses the + memory_manager for all memory allocations. + + + + + + kernel_1a + is a typedef for stack_kernel_1 + + + + + + + + + + + + + + static_map + dlib/static_map.h + dlib/static_map/static_map_kernel_abstract.h + + This object represents a mapping of items of type domain onto items of type range. + The difference between this object and the normal map object is that it does not support adding + or removing individual objects from itself. This allows implementations to focus on using less memory and + achieving faster searching. + + + + + static_map_kernel_1 + dlib/static_map/static_map_kernel_1.h + + This implementation is just a sorted array which can be searched using a binary search. + + + + + kernel_1a + is a typedef for static_map_kernel_1 + + + + + + + + + + + + + + static_set + dlib/static_set.h + dlib/static_set/static_set_kernel_abstract.h + + This object represents an unordered and unaddressed collection of items. + The difference between this object and the normal set object is that it does not support adding + or removing individual objects from itself. This allows implementations to focus on using less memory and + achieving faster searching. + + + + + static_set_kernel_1 + dlib/static_set/static_set_kernel_1.h + + This implementation is just a sorted array which can be searched using a binary search. + + + + + kernel_1a + is a typedef for static_set_kernel_1 + + + + + + + + + + + static_set_compare + dlib/static_set/static_set_compare_abstract.h + + This extension gives static_sets the ability to compare themselves using operator< and + operator==. Thus they can be used in the other container classes that require + this ability. (maps, static_sets, etc.) + + + + + static_set_compare_1 + dlib/static_set/static_set_compare_1.h + + The implementation is obvious. Click on the static_set_compare_1 link if you want to see. + + + + + compare_1a + is a typedef for static_set_kernel_1a extended by static_set_compare_1 + + + + + + + + + + + + + + + + + std_vector_c + dlib/stl_checked.h + dlib/stl_checked/std_vector_c_abstract.h + + This object is a simple wrapper around the std::vector object. It + provides an identical interface but also checks the preconditions of + each member function. That is, if you violate a requires + clause the dlib::fatal_error exception is thrown. + + + + + +
    + + + + +
    diff --git a/ml/dlib/docs/docs/dlib-icon-30x32.png b/ml/dlib/docs/docs/dlib-icon-30x32.png new file mode 100644 index 000000000..6e7732b9f Binary files /dev/null and b/ml/dlib/docs/docs/dlib-icon-30x32.png differ diff --git a/ml/dlib/docs/docs/dlib-icon-32.png b/ml/dlib/docs/docs/dlib-icon-32.png new file mode 100644 index 000000000..0c6684fa9 Binary files /dev/null and b/ml/dlib/docs/docs/dlib-icon-32.png differ diff --git a/ml/dlib/docs/docs/dlib-icon-48.png b/ml/dlib/docs/docs/dlib-icon-48.png new file mode 100644 index 000000000..af5c0fb4f Binary files /dev/null and b/ml/dlib/docs/docs/dlib-icon-48.png differ diff --git a/ml/dlib/docs/docs/dlib-icon-64.png b/ml/dlib/docs/docs/dlib-icon-64.png new file mode 100644 index 000000000..056dc7d13 Binary files /dev/null and b/ml/dlib/docs/docs/dlib-icon-64.png differ diff --git a/ml/dlib/docs/docs/dlib-icon.ico b/ml/dlib/docs/docs/dlib-icon.ico new file mode 100644 index 000000000..31d431509 Binary files /dev/null and b/ml/dlib/docs/docs/dlib-icon.ico differ diff --git a/ml/dlib/docs/docs/dlib-logo-and-icons.svg b/ml/dlib/docs/docs/dlib-logo-and-icons.svg new file mode 100644 index 000000000..8a552c464 --- /dev/null +++ b/ml/dlib/docs/docs/dlib-logo-and-icons.svg @@ -0,0 +1,1602 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + 16x16 icon + larger icon + + + + + + Download + + + + + + + + diff --git a/ml/dlib/docs/docs/dlib-logo-small.png b/ml/dlib/docs/docs/dlib-logo-small.png new file mode 100644 index 000000000..6abcc9dfe Binary files /dev/null and b/ml/dlib/docs/docs/dlib-logo-small.png differ diff --git a/ml/dlib/docs/docs/dlib-logo.png b/ml/dlib/docs/docs/dlib-logo.png new file mode 100644 index 000000000..592227229 Binary files /dev/null and b/ml/dlib/docs/docs/dlib-logo.png differ diff --git a/ml/dlib/docs/docs/dlib.css b/ml/dlib/docs/docs/dlib.css new file mode 100644 index 000000000..e7e9fb45c --- /dev/null +++ b/ml/dlib/docs/docs/dlib.css @@ -0,0 +1,369 @@ + +/* + * Overall page layout stuff. It goes like: + * + * ============ page_header ============= + * | | + * ============ top_content ============= + * | | | | + * | main_menu | main_text | right_menu | + * | | | | + * ============ bottom_content ========== + * | | + * | . | + * | . | + * | . | + * | | + */ + +body +{ + margin:0px; + background-color: #EDF3EE; + width:62.0em; + margin-left: auto; + margin-right: auto; + font-size: 17px; +} + +h1,h2,h3 +{ + line-height:1.2; +} + +#page_header { + height:59px; + text-align: left; + margin-top: 0.4em; +} + +#top_content +{ + width: 100%; + height: 900px; + display: table; +} +#top_content > div +{ + display: table-cell; +} +#main_menu +{ + width: fit-content; + padding-bottom: 6em; /* So the menu_footer always has room. */ +} +#main_text +{ + background-color: white; + border: 1px solid rgb(102,102,102); + width: 100%; + line-height: 1.3em; +} +#main_text_title +{ + margin-top: 10px; + margin-bottom: 20px; + text-align: center; + font-size: 2.1em; + font-weight: bold; +} +#main_text_body +{ + vertical-align: top; + padding-top: 1px; + padding-left: 10px; + padding-right: 10px; +} + +#right_menu +{ + width: fit-content; + min-width: 9em; +} + +div.menu +{ + vertical-align: top; + position: relative; + background-color:#F5F5F5; + padding:7px; + border: 1px solid rgb(102,102,102); +} +div.menu_top { +} +div.menu_footer { + position: absolute; + bottom: 10px; +} + + +#bottom_content +{ + line-height: 1.3em; +} + +/* ============================================================= */ + +pre {margin:0px;} + +ul.tree li { list-style: none; margin-left:10px;} +ul.tree { margin:0px; padding:0px; margin-left:5px; font-size:0.95em; } +ul.tree li ul { margin-left:10px; padding:0px; } + +li#term { list-style: none; } + +div.include_file_more_details_wrapper +{ + margin-top: 1em; + margin-bottom: 5px; + overflow: auto; + width: 100%; + display: inline-block; +} +div.include_file +{ + font-size: 1.3em; + font-weight: bold; + font-family: monospace; + + float: right; + + padding: 0.4em; +} +a.more_details, a.more_details_extension +{ + font-size: 1.6em; + font-weight: bold; + float: left; + + + margin: 5px; + + font: 200 1.6em source-sans-pro, sans-serif; + text-align: center; + padding-left: 1.5em; + padding-right: 1.5em; + -webkit-font-smoothing: antialiased; + background-color: #f59820; + border-color: #d47c0a; + color: #fff; + + padding: 0.2em 1.5em 0.2em; + + -moz-box-shadow: 2px 2px 9px #777777; + -webkit-box-shadow: 2px 2px 9px #777777; + box-shadow: 1px 1px 5px #777777; + border-radius: 3px; + -moz-border-radius: 3px; + -webkit-border-radius: 3px; + + transition: 320ms; + display: inline-block; + +} +a.more_details:hover, a.more_details_extension:hover +{ + text-decoration: none; + -moz-box-shadow: 1px 1px 9px #77a777; + -webkit-box-shadow: 1px 1px 9px #77a777; + box-shadow: 1px 1px 5px #77a777; + color: #ffffff; + + background-color: #d5780a; +} + +a.more_details_extension +{ + margin-top: 10px; + padding: 0.1em 1.0em 0.1em; + display: block; + font-size: 1.3em; + float: none; + width: 160px; +} + +div.component { + background-color:white; + border: 2px solid rgb(102,102,102); + text-align:left; + margin-top: 1.5em; + padding: 0.7em; +} + +div.question { + background-color:white; + border: 2px solid rgb(102,102,102); + text-align:left; + margin-top: 1.5em; + margin-bottom: 90%; + padding: 0.7em; +} + +div.function { + background-color:white; + border: 2px solid rgb(102,102,255); + text-align:left; + margin-top: 0.3em; + padding: 0.3em; +} + +div.class { + background-color:white; + border: 2px solid rgb(255,102,102); + text-align:left; + margin-top: 0.3em; + padding: 0.3em; +} + +div.extension { + background-color:#FDFDFD; + border: 1px solid rgb(102,102,102); + text-align:left; + margin-top: 1.0em; + padding: 0.7em; +} + +div.logb { + text-align:left; + padding: 0.0em; + float: left; + background-color:#c0c0c0; + border: double ; + margin: 0.5em; +} + +div.name { + float: left; +} +div.line1 { + float:left; + width:100%; + background-color:#dfdfdf; +} +div.line2 { + float:left; + width:100%; +} +div.inc { + float: right; +} + +video +{ + border: black dotted 1px; + border-bottom: black solid 2px; +} + +.code_box +{ + color: black; + margin: 1em 0.25in; + padding: 0.5em; + background: rgb(240,240,240); + border-top: black dotted 1px; + border-left: black dotted 1px; + border-right: black solid 2px; + border-bottom: black solid 2px; +} + +tt +{ + padding: 0.3em; + font-size:13px; + font-family: monospace; + background: rgb(240,240,240); + color: black; +} + + + +.bdotted {border-bottom: 1px dotted} +.bdashed {border-bottom: 1px dashed} +.bsolid {border-bottom: 1px solid} +.bdouble {border-bottom: 1px double} +.bgroove {border-bottom: 1px groove} +.bridge {border-bottom: 1px ridge} +.binset {border-bottom: 1px inset} +.boutset {border-bottom: 1px outset} + +div.row1 { + background-color:#dfdfdf; +} +div.row2 { + background-color:#f2f2f2; +} + +div.typedefs { + margin-left: 1.5em; + margin-top: 0.2em; + border: 1px dotted; + width: 52em; +} + +div.tdn { + width: 10em; +} + +.fullhr { + clear: both; +} + + +a { + text-decoration: none; + font-family: sans-serif; +} +a:hover{ + text-decoration: underline; +} +a.menu{ + white-space: nowrap; +} +a.sub{ + cursor: pointer; + margin-left:-9px; + color: green; +} +#download_button { + font: 200 16px source-sans-pro, sans-serif; + text-align: center; + padding-left: 1.5em; + padding-right: 1.5em; + -webkit-font-smoothing: antialiased; + background-color: #2098f5; + border-color: #0a7cd4; + color: #fff; + + padding: 0.7em 1em 0.8em; + + -moz-box-shadow: 2px 2px 9px #777777; + -webkit-box-shadow: 2px 2px 9px #777777; + box-shadow: 1px 1px 5px #777777; + border-radius: 3px; + -moz-border-radius: 3px; + -webkit-border-radius: 3px; + + transition: 320ms; + display: inline-block; +} +#download_button:hover { + text-decoration: none; + -moz-box-shadow: 1px 1px 9px #77a777; + -webkit-box-shadow: 1px 1px 9px #77a777; + box-shadow: 1px 1px 5px #77a777; + color: #ffffff; + background-color: #0a7cd4; +} + +div { + display:block; +} + +#dlib_version { + color: #fff; + display: block; + font-size: 0.8em; + font-weight: 400; + margin: 0; +} + diff --git a/ml/dlib/docs/docs/dlib.js b/ml/dlib/docs/docs/dlib.js new file mode 100644 index 000000000..025bbf996 --- /dev/null +++ b/ml/dlib/docs/docs/dlib.js @@ -0,0 +1,94 @@ + +function init_page() +{ + if (navigator.appVersion.indexOf("Win")!=-1) + { + var a = document.getElementById("download_button"); + a.href = a.href.replace("tar.bz2", "zip"); + } +} +window.onload = init_page; + +// -------------------------------------------------------------- +// Tree collapse stuff +// -------------------------------------------------------------- + +function Toggle(node) +{ + // Unfold the branch if it isn't visible + var next_node = node.nextSibling; + if (next_node.style.display == 'none') + { + // Change the image (if there is an image) + if (node.childNodes.length > 0) + { + if (node.childNodes.length > 0) + { + if (node.childNodes.item(0).nodeName == "IMG") + { + node.childNodes.item(0).src = "minus.gif"; + } + } + } + + next_node.style.display = 'block'; + } + // Collapse the branch if it IS visible + else + { + // Change the image (if there is an image) + if (node.childNodes.length > 0) + { + if (node.childNodes.length > 0) + { + if (node.childNodes.item(0).nodeName == "IMG") + { + node.childNodes.item(0).src = "plus.gif"; + } + } + } + + next_node.style.display = 'none'; + } + +} +function BigToggle(node) +{ + // Unfold the branch if it isn't visible + var next_node = node.nextSibling; + if (next_node.style.display == 'none') + { + // Change the image (if there is an image) + if (node.childNodes.length > 0) + { + if (node.childNodes.length > 0) + { + if (node.childNodes.item(0).nodeName == "IMG") + { + node.childNodes.item(0).src = "bigminus.gif"; + } + } + } + + next_node.style.display = 'block'; + } + // Collapse the branch if it IS visible + else + { + // Change the image (if there is an image) + if (node.childNodes.length > 0) + { + if (node.childNodes.length > 0) + { + if (node.childNodes.item(0).nodeName == "IMG") + { + node.childNodes.item(0).src = "bigplus.gif"; + } + } + } + + next_node.style.display = 'none'; + } + +} + diff --git a/ml/dlib/docs/docs/down.gif b/ml/dlib/docs/docs/down.gif new file mode 100644 index 000000000..5d7bb6467 Binary files /dev/null and b/ml/dlib/docs/docs/down.gif differ diff --git a/ml/dlib/docs/docs/enable_if.html b/ml/dlib/docs/docs/enable_if.html new file mode 100644 index 000000000..09cc2d4d2 --- /dev/null +++ b/ml/dlib/docs/docs/enable_if.html @@ -0,0 +1,387 @@ + + +enable_if + + + + + + + + + + +
    +
    + + +

    +enable_if

    +
    +
    +Copyright 2003 Jaakko Järvi, Jeremiah Willcock, Andrew Lumsdaine.
    +
    + + +

    1  Introduction

    + + +The enable_if family of templates is a set of tools to allow a function template or a class template specialization +to include or exclude itself from a set of matching functions or specializations +based on properties of its template arguments. +For example, one can define function templates that +are only enabled for, and thus only match, an arbitrary set of types +defined by a traits class. The enable_if templates can also be +applied to enable class template specializations. Applications of +enable_if are discussed in length +in [1] and [2].
    +
    + + +

    1.1  Synopsis

    + + +
    namespace boost {
    +  template <class Cond, class T = void> struct enable_if;
    +  template <class Cond, class T = void> struct disable_if;
    +  template <class Cond, class T> struct lazy_enable_if;
    +  template <class Cond, class T> struct lazy_disable_if;
    +
    +  template <bool B, class T = void> struct enable_if_c;
    +  template <bool B, class T = void> struct disable_if_c;
    +  template <bool B, class T> struct lazy_enable_if_c;
    +  template <bool B, class T> struct lazy_disable_if_c;
    +}
    +
    + + +

    1.2  Background

    + + +Sensible operation of template function overloading in C++ relies +on the SFINAE (substitution-failure-is-not-an-error) +principle [3]: if an invalid argument +or return type is formed during the instantiation of a function +template, the instantiation is removed from the overload resolution +set instead of causing a compilation error. The following example, +taken from [1], +demonstrates why this is important: +
    int negate(int i) { return -i; }
    +
    +template <class F>
    +typename F::result_type negate(const F& f) { return -f(); }
    +
    +
    +Suppose the compiler encounters the call negate(1). The first +definition is obviously a better match, but the compiler must +nevertheless consider (and instantiate the prototypes) of both +definitions to find this out. Instantiating the latter definition with +F as int would result in: +
    int::result_type negate(const int&);
    +
    +
    +where the return type is invalid. If this was an error, adding an unrelated function template +(that was never called) could break otherwise valid code. +Due to the SFINAE principle the above example is not, however, erroneous. +The latter definition of negate is simply removed from the overload resolution set.
    +
    +The enable_if templates are tools for controlled creation of the SFINAE +conditions.
    +
    + + +

    2  The enable_if templates

    + + +The names of the enable_if templates have three parts: an optional lazy_ tag, +either enable_if or disable_if, and an optional _c tag. +All eight combinations of these parts are supported. +The meaning of the lazy_ tag is described in Section 3.3. +The second part of the name indicates whether a true condition argument should +enable or disable the current overload. +The third part of the name indicates whether the condition argument is a bool value +(_c suffix), or a type containing a static bool constant named value (no suffix). +The latter version interoperates with Boost.MPL.
    +
    +The definitions of enable_if_c and enable_if are as follows (we use enable_if templates +unqualified but they are in the boost namespace). +
    template <bool B, class T = void>
    +struct enable_if_c {
    +  typedef T type;
    +};
    +
    +template <class T>
    +struct enable_if_c<false, T> {};
    +
    +template <class Cond, class T = void>
    +struct enable_if : public enable_if_c<Cond::value, T> {};
    +
    +
    +An instantiation of the enable_if_c template with the parameter +B as true contains a member type type, defined +to be T. If B is +false, no such member is defined. Thus +enable_if_c<B, T>::type is either a valid or an invalid type +expression, depending on the value of B. +When valid, enable_if_c<B, T>::type equals T. +The enable_if_c template can thus be used for controlling when functions are considered for +overload resolution and when they are not. +For example, the following function is defined for all arithmetic types (according to the +classification of the Boost type_traits library): +
    template <class T>
    +typename enable_if_c<boost::is_arithmetic<T>::value, T>::type 
    +foo(T t) { return t; }
    +
    +
    +The disable_if_c template is provided as well, and has the +same functionality as enable_if_c except for the negated condition. The following +function is enabled for all non-arithmetic types. +
    template <class T>
    +typename disable_if_c<boost::is_arithmetic<T>::value, T>::type 
    +bar(T t) { return t; }
    +
    +
    +For easier syntax in some cases and interoperation with Boost.MPL we provide versions of +the enable_if templates taking any type with a bool member constant named +value as the condition argument. +The MPL bool_, and_, or_, and not_ templates are likely to be +useful for creating such types. Also, the traits classes in the Boost.Type_traits library +follow this convention. +For example, the above example function foo can be alternatively written as: +
    template <class T>
    +typename enable_if<boost::is_arithmetic<T>, T>::type 
    +foo(T t) { return t; }
    +
    +
    + + +

    3  Using enable_if

    + + +The enable_if templates are defined in +boost/utility/enable_if.hpp, which is included by boost/utility.hpp.
    +
    +The enable_if template can be used either as the return type, or as an +extra argument. For example, the foo function in the previous section could also be written +as: +
    template <class T>
    +T foo(T t, typename enable_if<boost::is_arithmetic<T> >::type* dummy = 0); 
    +
    +
    Hence, an extra parameter of type void* is added, but it is given +a default value to keep the parameter hidden from client code. +Note that the second template argument was not given to enable_if, as the default +void gives the desired behavior.
    +
    +Whether to write the enabler as an argument or within the return type is +largely a matter of taste, but for certain functions, only one +alternative is possible: +
    • +Operators have a fixed number of arguments, thus enable_if must be used in the return type. +
    • Constructors and destructors do not have a return type; an extra argument is the only option. +
    • There does not seem to be a way to specify an enabler for a conversion operator. Converting constructors, +however, can have enablers as extra default arguments. +
    + + +

    3.1  Enabling template class specializations

    + + +Class template specializations can be enabled or disabled with enable_if. +One extra template parameter needs to be added for the enabler expressions. +This parameter has the default value void. +For example: +
    template <class T, class Enable = void> 
    +class A { ... };
    +
    +template <class T>
    +class A<T, typename enable_if<is_integral<T> >::type> { ... };
    +
    +template <class T>
    +class A<T, typename enable_if<is_float<T> >::type> { ... };
    +
    +
    Instantiating A with any integral type matches the first specialization, +whereas any floating point type matches the second one. All other types +match the primary template. +The condition can be any compile-time boolean expression that depends on the +template arguments of the class. +Note that again, the second argument to enable_if is not needed; the default (void) +is the correct value.
    +
    + + +

    3.2  Overlapping enabler conditions

    + + +Once the compiler has examined the enabling conditions and included the +function into the overload resolution set, normal C++ overload resolution +rules are used to select the best matching function. +In particular, there is no ordering between enabling conditions. +Function templates with enabling conditions that are not mutually exclusive can +lead to ambiguities. For example: +
    template <class T>
    +typename enable_if<boost::is_integral<T>, void>::type 
    +foo(T t) {}
    +
    +template <class T>
    +typename enable_if<boost::is_arithmetic<T>, void>::type 
    +foo(T t) {}
    +
    +
    +All integral types are also arithmetic. Therefore, say, for the call foo(1), +both conditions are true and both functions are thus in the overload resolution set. +They are both equally good matches and thus ambiguous. +Of course, more than one enabling condition can be simultaneously true as long as +other arguments disambiguate the functions.
    +
    +The above discussion applies to using enable_if in class template +partial specializations as well.
    +
    + + +

    3.3  Lazy enable_if

    + + +In some cases it is necessary to avoid instantiating part of a +function signature unless an enabling condition is true. For example: +
    template <class T, class U> class mult_traits;
    +
    +template <class T, class U>
    +typename enable_if<is_multipliable<T, U>, typename mult_traits<T, U>::type>::type
    +operator*(const T& t, const U& u) { ... }
    +
    +
    Assume the class template mult_traits is a traits class defining +the resulting type of a multiplication operator. The is_multipliable traits +class specifies for which types to enable the operator. Whenever +is_multipliable<A, B>::value is true for some types A and B, +then mult_traits<A, B>::type is defined.
    +
    +Now, trying to invoke (some other overload) of operator* with, say, operand types C and D +for which is_multipliable<C, D>::value is false +and mult_traits<C, D>::type is not defined is an error on some compilers. +The SFINAE principle is not applied because +the invalid type occurs as an argument to another template. The lazy_enable_if +and lazy_disable_if templates (and their _c versions) can be used in such +situations: +
    template<class T, class U>
    +typename lazy_enable_if<is_multipliable<T, U>, mult_traits<T, U> >::type
    +operator*(const T& t, const U& u) { ... }
    +
    +
    The second argument of lazy_enable_if must be a class type +that defines a nested type named type whenever the first +parameter (the condition) is true.
    +
    + + +
    Note
    + +Referring to one member type or static constant in a traits class +causes all of the members (type and static constant) of that +specialization to be instantiated. Therefore, if your traits classes +can sometimes contain invalid types, you should use two distinct +templates for describing the conditions and the type mappings. In the +above example, is_multipliable<T, U>::value defines when +mult_traits<T, U>::type is valid.
    +
    + + +

    3.4  Compiler workarounds

    + + +Some compilers flag functions as ambiguous if the only distinguishing factor is a different +condition in an enabler (even though the functions could never be ambiguous). For example, +some compilers (e.g. GCC 3.2) diagnose the following two functions as ambiguous: +
    template <class T>
    +typename enable_if<boost::is_arithmetic<T>, T>::type 
    +foo(T t);
    +
    +template <class T>
    +typename disable_if<boost::is_arithmetic<T>, T>::type 
    +foo(T t);
    +
    +
    Two workarounds can be applied: +
    • +Use an extra dummy parameter which disambiguates the functions. Use a default value for +it to hide the parameter from the caller. For example: +
      template <int> struct dummy { dummy(int) {} };
      +
      +template <class T>
      +typename enable_if<boost::is_arithmetic<T>, T>::type 
      +foo(T t, dummy<0> = 0);
      +
      +template <class T>
      +typename disable_if<boost::is_arithmetic<T>, T>::type 
      +foo(T t, dummy<1> = 0);
      +

      +
      +
    • Define the functions in different namespaces and bring them into a common +namespace with using declarations: +
      namespace A {
      +  template <class T>
      +  typename enable_if<boost::is_arithmetic<T>, T>::type 
      +  foo(T t);
      +}
      +
      +namespace B {
      +  template <class T>
      +  typename disable_if<boost::is_arithmetic<T>, T>::type 
      +  foo(T t);
      +}
      +
      +using A::foo;
      +using B::foo;
      +
      +
      +Note that the second workaround above cannot be used for member +templates. On the other hand, operators do not accept extra arguments, +which makes the first workaround unusable. As the net effect, +neither of the workarounds are of assistance for templated operators that +need to be defined as member functions (assignment and +subscript operators). +
    + + +

    4  Acknowledgements

    + +We are grateful to Howard Hinnant, Jason Shirk, Paul Mensonides, and Richard +Smith whose findings have influenced the library.
    +
    + + +

    References

    +
    [1]
    +Jaakko Järvi, Jeremiah Willcock, Howard Hinnant, and Andrew Lumsdaine. +Function overloading based on arbitrary properties of types. +C/C++ Users Journal, 21(6):25--32, June 2003.
    +
    +
    [2]
    +Jaakko Järvi, Jeremiah Willcock, and Andrew Lumsdaine. +Concept-controlled polymorphism. +In Frank Pfennig and Yannis Smaragdakis, editors, Generative + Programming and Component Engineering, volume 2830 of LNCS, pages + 228--244. Springer Verlag, September 2003.
    +
    +
    [3]
    +David Vandevoorde and Nicolai M. Josuttis. +C++ Templates: The Complete Guide. +Addison-Wesley, 2002.
    + + + + + +
    + +Contributed by:
    +Jaakko Järvi, Jeremiah Willcock and Andrew Lumsdaine
    +{jajarvi|jewillco|lums}@osl.iu.edu
    +Indiana University
    +Open Systems Lab + + + +
    +
    This document was translated from LATEX by +HEVEA. +
    + + diff --git a/ml/dlib/docs/docs/face_landmarking_example.png b/ml/dlib/docs/docs/face_landmarking_example.png new file mode 100644 index 000000000..27f70ecbf Binary files /dev/null and b/ml/dlib/docs/docs/face_landmarking_example.png differ diff --git a/ml/dlib/docs/docs/faq.xml b/ml/dlib/docs/docs/faq.xml new file mode 100644 index 000000000..dbec1b847 --- /dev/null +++ b/ml/dlib/docs/docs/faq.xml @@ -0,0 +1,547 @@ + + + + + Frequently Asked Questions + + + + + + You are getting this error because you are not compiling all the C++ + code in your program with consistent settings. This is a violation of + C++'s One Definition Rule. + In this case, you are compiling some translation units with dlib's + assert macros enabled and others with them disabled. +

    + For reference, the code that generates this error is: + dlib/test_for_odr_violations.h and + dlib/test_for_odr_violations.cpp. +

    +
    + + + + You are getting this error because you are not compiling all the C++ + code in your program with consistent settings. This is a violation of + C++'s One Definition Rule. + In this case, you compiled a standalone copy of dlib with CMake and instead of using make install + or cmake --build . --target install to copy the resulting build files somewhere you went + and cherry picked files manually and messed it up. In particular, CMake compiled dlib + with a bunch of settings recorded in the CMake generated config.h file but you instead + are now trying to build more dlib related code with the + dlib/config.h from source control. +

    + For reference, the code that generates this error is: + dlib/test_for_odr_violations.h and + dlib/test_for_odr_violations.cpp. +

    +
    + + + Do not post a question like "I'm using dlib, and it doesn't work?" or + "I'm using the object detector and it doesn't work, what do I do?". + If this is all you say then I have no idea what is wrong. 99% of the + time it's some kind of user error. 1% of the time it's some problem + in dlib. But again, without more information it's impossible to know. + So please don't post questions like this. + +

    + If you think you found some kind of bug or problem in dlib then feel + free to submit a dlib issue on github. + But include the version of dlib you are using, what you + are trying, what happened, what you expected to have happened instead, etc. +

    + +

    + On the other hand, if you haven't found a bug or problem in dlib, but + instead are looking for machine learning/computer vision/programming + help then post your question to stack overflow with the dlib tag. +

    +
    + + +

    + First, note that you need a version of Visual Studio with decent + C++11 support. This means you need Visual Studio 2015 or newer. +

    + There are instructions on the How to Compile page. + If you do not understand the instructions in the "Compiling on Windows Using Visual Studio" section + or are getting errors then follow the instructions in the "Compiling on Any Operating System Using CMake" + section. In particular, install CMake and then type + these exact commands from within the root of the dlib distribution: + +cd examples +mkdir build +cd build +del /F /S /Q * +cmake .. +cmake --build . --config Release + + That should compile the dlib examples in visual studio. The output executables will appear in the Release folder. + The del /F /S /Q * command is to make sure you clear out any extraneous files you might have placed in + the build folder and is not necessary if build begins empty. +
    + + + + + Dlib isn't slow. I get this question many times a week and 95% of the time it's from someone + using Visual Studio who has compiled their program in Debug mode rather than the optimized + Release mode. So if you are using Visual Studio then realize that Visual Studio has these two modes. + The default is Debug. The mode is selectable via a drop down: +

    + Debug mode disables compiler optimizations. So the program will be very slow if you run it in Debug mode. + So click the drop down, +

    + and select Release. +

    + Then when you compile the program it will appear in a folder named Release rather than in a folder named Debug. + +
    +
    + Finally, you can enable either SSE4 or AVX instruction use. These will make certain operations much faster (e.g. face detection). + You do this using CMake's cmake-gui tool. For example, if you execute + these commands you will get the cmake-gui screen: + +cd examples +mkdir build +cd build +cmake .. +cmake-gui . + + Which looks like this: +

    + Where you can select SSE4 or AVX instruction use. Then you click configure and then generate. After that + when you build your visual studio project some things will be faster. + + Finally, note that AVX is a little bit faster than SSE4 but if your computer is fairly old it might + not support it. In that case, either buy a new computer or use SSE4 + instructions. +
    + + + + + + + If you use dlib in your research then please use the following citation: +
    +
    +Davis E. King. Dlib-ml: A Machine Learning Toolkit. + Journal of Machine Learning Research 10, pp. 1755-1758, 2009 +
    +
    +
    +@Article{dlib09,
    +  author = {Davis E. King},
    +  title = {Dlib-ml: A Machine Learning Toolkit},
    +  journal = {Journal of Machine Learning Research},
    +  year = {2009},
    +  volume = {10},
    +  pages = {1755-1758},
    +}
    +         
    +
    + + + + + + + Here are the possibilities: +
      +
    • You are using a file stream and forgot to put it into binary mode. + You need to do something like this: + +std::ifstream fin("myfile", std::ios::binary); + +or + +std::ofstream fout("myfile", std::ios::binary); + + +If you don't give std::ios::binary then the iostream will mess with the binary data and cause serialization +to not work right. +
    • + +
      +
    • The iostream is in a bad state. You can check the state by calling mystream.good(). + If it returns false then the stream is in an error state such as end-of-file or maybe it failed + to do the I/O. Also note that if you close a file stream and reopen it you might have to call + mystream.clear() to clear out the error flags. +
    • +
    + +
    + + + + Long answer, read the matrix example program. +

    + Short answer, here are some examples: + +matrix<double> mat; +mat.set_size(4,5); + +matrix<double,0,1> column_vect; +column_vect.set_size(6); + +matrix<double,0,1> column_vect2(6); // give size to constructor + +matrix<double,1> row_vect; +row_vect.set_size(5); + + +
    + + + + If you can't find something then check the index. +

    + Also, the bulk of the documentation can be found by following the + links. +
    + + + + + + There should never be anything in dlib that prevents you from using or + interacting with other libraries. Moreover, there are some additional tools + in dlib to make some interactions easier: + +
      +
    • BLAS and LAPACK libraries are used by the matrix + automatically if you #define DLIB_USE_BLAS and/or DLIB_USE_LAPACK and link against + the appropriate library files. Note that the CMakeLists.txt file that comes with dlib will + do this for you automatically in many instances.

    • + +
    • Armadillo and Eigen libraries have matrix objects which can be converted into + dlib matrix objects by calling dlib::mat() on them.

    • + +
    • OpenCV image objects can be converted into a form usable by dlib routines + by using cv_image. You can also convert from a + dlib matrix or image to an OpenCV Mat using dlib::toMat().

    • + +
    • Google Protocol Buffers can be serialized by the dlib + serialization routines. + This means that, for example, you can pass protocol buffer objects through a + bridge. +

    • + +
    • libpng and libjpeg are used by load_image whenever + DLIB_PNG_SUPPORT and DLIB_JPEG_SUPPORT are defined respectively. + You must also tell your compiler to link against these libraries to + use them. However, CMake will try to link against them + automatically if they are installed.

    • + + + +
    • SQLite is used by the database object. In + fact, it is just a wrapper around SQLite's C interface which simplifies its use (e.g. + makes resource management use RAII).
    • + +
    +
    + +
    + + + + + + + + + + The optimization algorithm is somewhat unpredictable. Sometimes it is fast and + sometimes it is slow. What usually makes it really slow is if you use a radial basis + kernel and you set the gamma parameter to something too large. This causes the + algorithm to start using a whole lot of relevance vectors (i.e. basis vectors) which + then makes it slow. The algorithm is only fast as long as the number of relevance vectors + remains small but it is hard to know beforehand if that will be the case. +

    + You should try kernel ridge regression instead since it + also doesn't take any parameters but is always very fast. +

    + +
    + + + + + This function makes a copy of your training data for each thread. So you are probably running out + of memory. To avoid this, use the randomly_subsample function + to reduce the amount of data you are using or use fewer threads. +

    + For example, you could reduce the amount of data by saying this: + +// reduce to only 1000 samples +cross_validate_trainer_threaded(trainer, + randomly_subsample(samples, 1000), + randomly_subsample(labels, 1000), + 4, // num folds + 4); // num threads + +

    +
    + + + + + See the Using Custom Kernels example program. + + + + + +

    + Picking the right kernel all comes down to understanding your data, and obviously this is + highly dependent on your problem. +

    + +

    + One thing that's sometimes useful is to plot each feature against the target value. You can get an idea of + what your overall feature space looks like and maybe tell if a linear kernel is the right solution. But + this still hides important information from you. For example, imagine you have two diagonal lines which + are very close together and are both the same length. Suppose one line is of the +1 class and the other is the -1 + class. Each feature (the x or y coordinate values) by itself tells you almost nothing about which class + a point belongs to but together they tell you everything you need to know. +

    + +

    + On the other hand, if you know something about the data you are working with then you can also try and + generate your own features. So for example, if your data is a bunch of images and you know that one + of your classes contains a lot of lines then you can make a feature that attempts to measure the number + of lines in an image using a hough transform or sobel edge filter or whatever. Generally, try and + think up features which should be highly correlated with your target value. A good way to do this is + to try and actually hand code N solutions to the problem using whatever you know about your data or + domain. If you do a good job then you will have N really great features and a linear or rbf kernel + will probably do very well when using them. +

    + +

    + Or you can just try a whole bunch of kernels, kernel parameters, and training algorithm options while + using cross validation. I.e. when in doubt, use brute force :) There is an example of that kind of + thing in the model selection example program. +

    +
    + + + + + This happens when you use the radial_basis_kernel and you set the gamma value to + something highly inappropriate. To understand what's happening lets imagine your + data has just one feature and its value ranges from 0 to 7. Then what you want is a + gamma value that gives nice Gaussian bumps like the one in this graph:
    + +
    + +
    + However, if you make gamma really huge you will get this (it's zero everywhere except for one place): +
    +
    + +
    + Or if you make gamma really small then it will be 1.0 everywhere: +
    +
    + +

    + So you need to pick the gamma value so that it is scaled reasonably to your data. A good rule of + thumb (i.e. not the optimal gamma, just a heuristic guess) is the following: +

    + const double gamma = 1.0/compute_mean_squared_distance(randomly_subsample(samples, 2000)); + +
    + + + +
    + + + + + + + There are three general mistakes people make when trying to train an object detector with dlib. +
      +
    • Not labeling all the objects in each image

      + The tools for training object detectors in dlib use the Max-Margin Object Detection + loss. This loss optimizes the performance of the detector on the whole image, not on some subset of windows cropped from the training data. + That means it counts the number of missed detections and false alarms for each of the training images and tries to find a way + to minimize the sum of these two error metrics. For this to be possible, you must label all the objects in each training image. + If you leave unannotated objects in some of your training images then the loss will think any detections on these unannotated objects + are false alarms, and will therefore try to find a detector that doesn't detect them. If you have enough unannotated objects, the + most accurate detector will be the one that never detects anything. That's obviously not what you want. So make sure you annotate all the + objects in each image. +

      + Sometimes annotating all the objects in each image is too + onerous, or there are ambiguous objects you don't care about. + In these cases you should annotate these objects you don't + care about with ignore boxes so that the MMOD loss knows to + ignore them. You can do this with dlib's imglab tool by + selecting a box and pressing i. Moreover, there are two ways + the code treats ignore boxes. When a detector generates a + detection it compares it against any ignore boxes and ignores + it if the boxes "overlap". Deciding if they overlap is based + on either their intersection over union or just basic percent + coverage of one by another. You have to think about what + mode you want when you annotate things and configure the + training code appropriately. The default behavior is to use + intersection over union to measure overlap. However, if you + wanted to simply mask out large parts of an image you + wouldn't want to use intersection over union to measure + overlap since small boxes contained entirely within the large + ignored region would have small IoU with the big ignore region and thus not "overlap" + the ignore region. In this case you should change the + settings to reflect this before training. The available configuration + options are discussed in great detail in parts of dlib's documentation. +

      +
    • + +
    • Using training images that don't look like the testing images

      + This should be obvious, but needs to be pointed out. If there + is some clear difference between your training and testing + images then you have messed up. You need to show the training + algorithm real images so it can learn what to do. If instead + you only show it images that look obviously different from your + testing images don't be surprised if, when you run the detector + on the testing images, it doesn't work. As a rule of thumb, + a human should not be able to tell if an image came from the training dataset or testing dataset. + +

      + Here are some examples of bad datasets: +

        +
      • A training dataset where objects always appear with + some specific orientation but the testing images have a + diverse set of orientations.
      • +
      • A training dataset where objects are tightly cropped, but testing images that are uncropped.
      • +
      • A training dataset where objects appear only on a perfectly white background with nothing else present, but testing images where objects appear in a normal environment like living rooms or in natural scenes.
      • +
      +

      + +

      + Another way you can mess this up is when using the random_cropper to jitter your training data, which is + common when training a CNN or other deep model. In general, the random_cropper finds images + that are more or less centered on your objects of interest and it also scales the images so + the object has some user specified minimum size. That's all fine. But what can happen is + you train a model that gets 0 training error but when you go and use it it doesn't detect + any objects. Why is that? It's probably because all the objects in your normal images, the + ones you give to the random_cropper, are really small. Smaller than the min size you told + the cropper to make them. So now your testing images are really different from your training + images. Moreover, in general object detectors have some minimum size they scan and if + objects are smaller than that they will never be found. Another related issue is all your + uncropped images might show objects at the very border of the image. But the random_cropper + will center the objects in the crops, by padding with zeros if necessary. Again, make your + testing images look like the training images. Pad the edges of your images with zeros if + needed. +

      +
    • + +
    • Using a HOG based detector but not understanding the limits of HOG templates

      + The HOG detector is very fast and generally easy to train. However, you + have to be aware that HOG detectors are essentially rigid templates that are scanned over an image. So a single HOG detector + isn't going to be able to detect objects that appear in a wide range of orientations or undergo complex deformations or have complex + articulation. +

      + For example, a HOG detector isn't going to be able to learn to detect human faces that are upright as well as faces rotated 90 degrees. + If you wanted to deal with that you would be best off training 2 detectors. One for upright faces and another for 90 degree rotated faces. + You can efficiently run multiple HOG detectors at once using the evaluate_detectors function, so it's not a huge deal to do this. Dlib's imglab tool also has a --cluster option that will help you split a training dataset into clusters that can + be detected by a single HOG detector. You will still need to manually review and clean the dataset after applying --cluster, but it makes + the process of splitting a dataset into coherent poses, from the point of view of HOG, a lot easier. +

      +

      + A related issue arises because HOG is a rigid template, which is that the boxes in your training data need to all have essentially the same + aspect ratio. For instance, a single HOG filter can't possibly detect objects that are both 100x50 pixels and 50x100 pixels. To do this you + would need to split your dataset into two parts, objects with a 2:1 aspect ratio and objects with a 1:2 aspect ratio and then train two separate + HOG detectors, one for each aspect ratio. +

      +

      + However, it should be emphasized that even using multiple HOG detectors will only get you so far. So at some point you should consider + using a CNN based detection method since CNNs can generally deal with arbitrary + rotations, poses, and deformations with one unified + detector. +

      +
    • +
    +
    +
    + + + + + + You can, but you need to use Visual Studio 2015 Update 3 or newer since prior versions + had bad C++11 support. To make this as confusing as possible, + Microsoft has released multiple different versions of "Visual Studio + 2015 Update 3". As of October 2016, the version available from the + Microsoft web page has good enough C++11 support to compile the DNN + tools in dlib. So make sure you have a version no older than October + 2016. +

    + To make this even more complicated, Visual Studio 2017 had + regressions in its C++11 support. So all versions of Visual Studio + 2017 prior to December 2017 would just hang if you tried to compile + the DNN examples. Happily, the newest versions of Visual Studio + 2017 appear to have good C++11 support and will compile the DNN + codes without any issue. So make sure your Visual Studio is + fully updated. +

    +

    + Finally, it should be noted that you should give the -T host=x64 + cmake option when generating a Visual Studio project. If you don't + do this then you will get the default Visual Studio toolchain, + which runs the compiler in 32bit mode, restricting it to 2GB of + RAM, leading to compiler crashes due to it running out of RAM in some + cases. This isn't the 1990s anymore, so you should probably + run your compiler in 64bit mode so it can use your computer's RAM. + Giving -T host=x64 will let Visual Studio use as much RAM + as it needs. +

    +
    + + + A major design goal of this API is to let users create new loss + layers, computational layers, and solvers without needing to + understand or even look at the dlib internals. A lot of the API + decisions are based on what makes the interface a user needs to + implement to create new layers as simple as possible. In particular, + designing the API in this compile-time static way makes it simple for + these use cases. +

    + Here is an example of one problem it addresses. Since dlib + exposes the entire network architecture to the C++ type system we + can get automatic serialization of networks. Without this, we + would have to resort to the kind of hacky global layer registry + used in other tools that compose networks entirely at runtime. +

    +

    + Another nice feature is that we get to use C++11 alias template + statements to create network sub-blocks, which we can then use to easily + define very large networks. There are examples of this in this example program. It + should also be pointed out that it takes days or even weeks to + train one network. So it isn't as if you will be writing a + program that loops over large numbers of networks and trains them + all. This makes the time needed to recompile a program to change + the network irrelevant compared to the entire training time. + Moreover, there are plenty of compile time constructs in C++ you + can use to enumerate network architectures (e.g. loop + over filter widths) if you really wanted to do so. +

    +

    + All that said, if you think you found a compelling use case that isn't supported + by the current API feel free to post a github issue. +

    +
    +
    + + + + +
    diff --git a/ml/dlib/docs/docs/find_max_global_example.mp4 b/ml/dlib/docs/docs/find_max_global_example.mp4 new file mode 100644 index 000000000..2514abe19 Binary files /dev/null and b/ml/dlib/docs/docs/find_max_global_example.mp4 differ diff --git a/ml/dlib/docs/docs/find_max_global_example.png b/ml/dlib/docs/docs/find_max_global_example.png new file mode 100644 index 000000000..42b2eb2b6 Binary files /dev/null and b/ml/dlib/docs/docs/find_max_global_example.png differ diff --git a/ml/dlib/docs/docs/find_max_global_example.webm b/ml/dlib/docs/docs/find_max_global_example.webm new file mode 100644 index 000000000..64556b345 Binary files /dev/null and b/ml/dlib/docs/docs/find_max_global_example.webm differ diff --git a/ml/dlib/docs/docs/find_max_global_results_table.svg b/ml/dlib/docs/docs/find_max_global_results_table.svg new file mode 100644 index 000000000..8bdc46531 --- /dev/null +++ b/ml/dlib/docs/docs/find_max_global_results_table.svg @@ -0,0 +1,3398 @@ + + + +image/svg+xmlProblemHolderTableComplexHolderRosenbrockSphereDebN.1 +MaxLIPO+TR50 +( +± +28) +148 +( +± +90)23( +± +10)39( +± +8)169( +± +161) +PRS+TR +91( +± +96)1589( +± +2941) +19 +( +± +9)44( +± +9)168( +± +150) +LIPO+TR +56( +± +41)213( +± +130) +19 +( +± +9) +35 +( +± +7) +166 +( +± +126) +LIPO +140( +± +71)245( +± +116)61( +± +55)53( +± +16)- +PRS +1842( +± +1751)36421( +± +37249)118( +± +114)198274( +± +199449)529251( +± +538740) +target99% +MaxLIPO+TR51 +( +± +27) +150 +( +± +89) +149 +( +± +84) +23 +( +± +5)138( +± +160) +PRS+TR +92( +± +96)1590( +± +2942)172( +± +90)28( +± +6)139( +± +149) +LIPO+TR +58( +± +42)207( +± +131)163( +± +92) +23 +( +± +5) +135 +( +± +124) +LIPO +161( +± +81)264( +± +119)-38( +± +14)- +PRS +3043( +± +2901)75589( +± +76490)108873( +± +112572)1929( +± +2034)3933( +± +3869) + +=0 +. +1 +MaxLIPO+TR54 +( +± +27) +156 +( +± +89) +189 +( +± +88) +57 +( +± +9)175( +± +161) +PRS+TR +91( +± +91)1590( +± +3106)221( +± +99)62( +± +11) +174 +( +± +137) +LIPO+TR +61( +± +45)216( +± +129)210( +± +98)57( +± +10)181( +± +166) +LIPO +----- +PRS +31351( +± +31805)730957( +± +703887)3 +. +7 +× +10 +6 +( +± +3 +. +7 +× +10 +6 +)2 +× +10 +7 +( +± +2 +× +10 +7 +)1 +. +6 +× +10 +6 +( +± +1 +. +6 +× +10 +6 +) + +=0 +. +01 +MaxLIPO+TR57 +( +± +28) +165 +( +± +87) +219 +( +± +89)92( +± +15)199( +± +171) +PRS+TR +94( +± +87)1776( +± +3437)236( +± +97)99( +± +17)202( +± +174) +LIPO+TR +63( +± +43)221( +± +132)232( +± +103) +91 +( +± +15) +198 +( +± +152) +LIPO +----- +PRS +----- + +=0 +. +001 +MaxLIPO+TR66 +( +± +27) +169 +( +± +89) +238 +( +± +88) +400 +( +± +49)236( +± +162) +PRS+TR +102( +± +92)1548( +± +2989)264( +± +96)411( +± +52)236( +± +168) +LIPO+TR +73( +± +43)230( +± +130)260( +± +100)400( +± +53) +232 +( +± +165) +LIPO +----- +PRS +----- + +=10 + +9 + \ No newline at end of file diff --git a/ml/dlib/docs/docs/graph_tools.xml b/ml/dlib/docs/docs/graph_tools.xml new file mode 100644 index 000000000..89770fca0 --- /dev/null +++ b/ml/dlib/docs/docs/graph_tools.xml @@ -0,0 +1,678 @@ + + + + + Graph Tools + + + + + +

    + In dlib, there are two types of graph representations. On the one + hand, there are graphs based on an object which encapsulates the whole + graph, such as the graph and + directed_graph objects. On the + other hand, there are graphs which are represented as simple vectors + of edges. In this case, we use vectors of sample_pair + or ordered_sample_pair objects for undirected + and directed graphs respectively. +

    + + + + + + + + +
    + Graph Object Based Graphs + graph_contains_directed_cycle + graph_has_symmetric_edges + graph_contains_undirected_cycle + create_moral_graph + triangulate_graph_and_find_cliques + graph_contains_length_one_cycle + find_connected_nodes + graph_is_connected + is_clique + is_maximal_clique + copy_graph_structure + copy_graph + edge + is_join_tree + create_join_tree +
    + + +
    + Creating Edge List Based Graphs + sample_pair + ordered_sample_pair + find_percent_shortest_edges_randomly + find_k_nearest_neighbors + find_k_nearest_neighbors_lsh + find_approximate_k_nearest_neighbors + + Distance Functions + + negative_dot_product_distance + squared_euclidean_distance + cosine_distance + + +
    + +
    + Using Edge List Based Graphs + remove_short_edges + remove_duplicate_edges + remove_long_edges + remove_percent_longest_edges + remove_percent_shortest_edges + use_weights_of_one + use_gaussian_weights + is_ordered_by_index + find_neighbor_ranges + convert_unordered_to_ordered + order_by_index + order_by_distance + order_by_descending_distance + order_by_distance_and_index + contains_duplicate_pairs + max_index_plus_one +
    +
    +
    + + + + + + + + + + + order_by_index + dlib/graph_utils.h + dlib/graph_utils/sample_pair_abstract.h + + This function provides a total ordering of sample_pair + or ordered_sample_pair + objects that will cause pairs that represent the same edge to be adjacent + when sorted. So for example, this function can be used + with std::sort() to first sort a sequence of sample_pair objects and then + find duplicate edges. + + + + + + + order_by_distance + dlib/graph_utils.h + dlib/graph_utils/sample_pair_abstract.h + + This function provides a total ordering of sample_pair + or ordered_sample_pair objects that causes + pairs with smallest distance to be the first in a sorted list. This function + can be used with std::sort(). + + + + + + + order_by_descending_distance + dlib/graph_utils.h + dlib/graph_utils/sample_pair_abstract.h + + This function provides a total ordering of sample_pair + or ordered_sample_pair objects that causes + pairs with largest distance to be the first in a sorted list. This function + can be used with std::sort(). + + + + + + + order_by_distance_and_index + dlib/graph_utils.h + dlib/graph_utils/sample_pair_abstract.h + + This function provides a total ordering of sample_pair or + ordered_sample_pair objects that causes pairs + with smallest distance to be the first in a sorted list but also orders + samples with equal distances according to order_by_index(). This function + can be used with std::sort(). + + + + + + + contains_duplicate_pairs + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This function checks if a std::vector of sample_pair or + ordered_sample_pair objects + contains any edge more than once. + + + + + + + max_index_plus_one + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This function finds the number that is one greater than the largest index + value in a std::vector of sample_pair or + ordered_sample_pair objects. Therefore, + it is useful for finding out how many nodes are in an edge list graph (assuming + the graph contains all node indices from 0 to the largest index indicated + by an edge). + + + + + + + sample_pair + dlib/graph_utils.h + dlib/graph_utils/sample_pair_abstract.h + + This object is intended to represent an edge in an undirected graph + which has data samples at its vertices. Therefore, it is the undirected version + of ordered_sample_pair. + + + + linear_manifold_regularizer_ex.cpp.html + + + + + + + + ordered_sample_pair + dlib/graph_utils.h + dlib/graph_utils/ordered_sample_pair_abstract.h + + This object is intended to represent an edge in a directed graph + which has data samples at its vertices. Therefore, it is the directed version + of sample_pair. + + + + + + + find_percent_shortest_edges_randomly + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This function is a simple approximate form of find_k_nearest_neighbors. + Instead of checking all possible edges it randomly samples a large number of them and + then returns the best ones. + + + + linear_manifold_regularizer_ex.cpp.html + + + + + + + + find_k_nearest_neighbors + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This is a function which finds all the k nearest neighbors of a set of points and outputs + the result as a vector of sample_pair objects. It takes O(n^2) time where + n is the number of data samples. A faster approximate version is provided by + find_approximate_k_nearest_neighbors + and find_k_nearest_neighbors_lsh. + + + + + + + + find_k_nearest_neighbors_lsh + dlib/graph_utils_threaded.h + dlib/graph_utils/find_k_nearest_neighbors_lsh_abstract.h + + This function is a simple approximate form of find_k_nearest_neighbors. + It uses locality sensitive hashing + to speed up the nearest neighbor computation and is also capable of using a multi-core CPU. + + + + + + + + find_approximate_k_nearest_neighbors + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This function is a simple approximate form of find_k_nearest_neighbors. + Instead of checking all possible edges it randomly samples a large number of them and then performs + exact k-nearest-neighbors on that randomly selected subset. + + + + + + + remove_short_edges + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This is a simple function for removing edges with a small distance value from + a vector of sample_pair or ordered_sample_pair objects. + + + + + + + + remove_duplicate_edges + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This is a simple function for removing duplicate edges (i.e. edges that compare equal + according to ==) from + a vector of sample_pair or ordered_sample_pair objects. + + + + + + + remove_long_edges + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This is a simple function for removing edges with a large distance value from + a vector of sample_pair or ordered_sample_pair objects. + + + + + + + + remove_percent_longest_edges + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This is a simple function for removing edges with a large distance value from + a vector of sample_pair or ordered_sample_pair objects. + + + + + + + + remove_percent_shortest_edges + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This is a simple function for removing edges with a small distance value from + a vector of sample_pair or ordered_sample_pair objects. + + + + + + + + squared_euclidean_distance + dlib/graph_utils.h + dlib/graph_utils/function_objects_abstract.h + + This is a simple function object that computes squared euclidean distance + between two matrix objects. + + + linear_manifold_regularizer_ex.cpp.html + + + + + + + + cosine_distance + dlib/graph_utils.h + dlib/graph_utils/function_objects_abstract.h + + This is a simple function object that computes cosine of the angle between + two vectors. + + + + + + + negative_dot_product_distance + dlib/graph_utils.h + dlib/graph_utils/function_objects_abstract.h + + This is a simple function object that computes -dot(v1,v2) for two + vectors v1 and v2. + + + + + + + use_weights_of_one + dlib/graph_utils.h + dlib/graph_utils/function_objects_abstract.h + + This is a simple function object that takes a single argument + and always returns 1 + + + + + + + + use_gaussian_weights + dlib/graph_utils.h + dlib/graph_utils/function_objects_abstract.h + + This is a simple function object that takes a single argument + which should be an object similar to sample_pair. + + + linear_manifold_regularizer_ex.cpp.html + + + + + + + + is_ordered_by_index + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This function checks if a vector of sample_pair or + ordered_sample_pair objects is in sorted + order according to their index values. + + + + + + + find_neighbor_ranges + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This function takes a graph, defined by a vector of + ordered_sample_pair objects, and finds the + ranges that contain the edges for each node in the graph. The output therefore + lets you easily locate the neighbors of any node in the graph. + + + + + + + convert_unordered_to_ordered + dlib/graph_utils.h + dlib/graph_utils/edge_list_graphs_abstract.h + + This function takes a graph, defined by a vector of + sample_pair objects and converts it into the equivalent + graph defined by a vector of ordered_sample_pair objects. + + + + + + + edge + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a graph or + directed_graph object and a + pair of indices. It returns a reference to the edge object between the two nodes + with the given indices. + + + + + + + + is_join_tree + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes two graph objects and + checks if the second of the two graphs is a valid join tree (aka tree decomposition) + of the first graph. + + + + + + + + create_join_tree + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a graph object and + creates a join tree for that graph. Or in other words, this function finds a + tree decomposition of the given graph. + + + + + + + + graph_contains_directed_cycle + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function checks a directed_graph for directed cycles. + + + + + + + + graph_has_symmetric_edges + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function checks if a directed_graph + has a pair of nodes with just one edge between them. If so then it + does not have symmetric edges. + + + + + + + + triangulate_graph_and_find_cliques + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a graph and + turns it into a chordal graph. It also returns a + set that contains + all the cliques present in the chordal graph. + + + + + + + + create_moral_graph + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a directed_graph + and returns the moralized version of the graph in the form of a + graph object. + + + + + + + + graph_contains_length_one_cycle + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a graph + or directed_graph object and + returns true if and only if the graph contains a node that has an edge that + links back to itself. + + + + + + + + find_connected_nodes + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a node from a graph + or directed_graph object and a + set of unsigned longs. It finds all the + nodes in the given graph that are connected to the given node by an + undirected path and returns them in the set (also note that the + original query node is also returned in this set). + + + + + + + + graph_is_connected + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a graph or + directed_graph object and + determines if the graph is connected. That is, it returns true if and only if + there is an undirected path between any two nodes in the given graph. + + + + + + + + is_clique + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a graph and a + set of node index values and checks + if the specified set of nodes is a clique in the graph. + + + + + + + + copy_graph + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a graph or + directed_graph and + makes a copy of it. + + + + + + + + copy_graph_structure + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a graph or + directed_graph and copies + its structure to another graph or directed_graph object. The only + restriction is that you can't copy the structure of a graph into a + directed_graph. The three other possible combinations are allowed + however. + + + + + + + + is_maximal_clique + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function takes a graph and a + set of node index values and checks + if the specified set of nodes is a maximal clique in the graph. + + + + + + + + graph_contains_undirected_cycle + dlib/graph_utils.h + dlib/graph_utils/graph_utils_abstract.h + + This function checks a directed_graph for undirected cycles. + + + + + + + + + + + + + + +
    + + + diff --git a/ml/dlib/docs/docs/guipics/button.png b/ml/dlib/docs/docs/guipics/button.png new file mode 100644 index 000000000..4b6b792b7 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/button.png differ diff --git a/ml/dlib/docs/docs/guipics/check_box.png b/ml/dlib/docs/docs/guipics/check_box.png new file mode 100644 index 000000000..5dbf9181b Binary files /dev/null and b/ml/dlib/docs/docs/guipics/check_box.png differ diff --git a/ml/dlib/docs/docs/guipics/directed_graph_drawer.png b/ml/dlib/docs/docs/guipics/directed_graph_drawer.png new file mode 100644 index 000000000..4a2caf57a Binary files /dev/null and b/ml/dlib/docs/docs/guipics/directed_graph_drawer.png differ diff --git a/ml/dlib/docs/docs/guipics/image_window.jpg b/ml/dlib/docs/docs/guipics/image_window.jpg new file mode 100644 index 000000000..ed19406c7 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/image_window.jpg differ diff --git a/ml/dlib/docs/docs/guipics/label.png b/ml/dlib/docs/docs/guipics/label.png new file mode 100644 index 000000000..67be00024 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/label.png differ diff --git a/ml/dlib/docs/docs/guipics/list_box.png b/ml/dlib/docs/docs/guipics/list_box.png new file mode 100644 index 000000000..305716645 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/list_box.png differ diff --git a/ml/dlib/docs/docs/guipics/menu_bar.png b/ml/dlib/docs/docs/guipics/menu_bar.png new file mode 100644 index 000000000..0c7461a02 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/menu_bar.png differ diff --git a/ml/dlib/docs/docs/guipics/message_box.png b/ml/dlib/docs/docs/guipics/message_box.png new file mode 100644 index 000000000..2decc1065 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/message_box.png differ diff --git a/ml/dlib/docs/docs/guipics/mouse_tracker.png b/ml/dlib/docs/docs/guipics/mouse_tracker.png new file mode 100644 index 000000000..dc96d5956 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/mouse_tracker.png differ diff --git a/ml/dlib/docs/docs/guipics/named_rectangle.png b/ml/dlib/docs/docs/guipics/named_rectangle.png new file mode 100644 index 000000000..45137ff34 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/named_rectangle.png differ diff --git a/ml/dlib/docs/docs/guipics/open_existing_file_box.png b/ml/dlib/docs/docs/guipics/open_existing_file_box.png new file mode 100644 index 000000000..19a3bb4fb Binary files /dev/null and b/ml/dlib/docs/docs/guipics/open_existing_file_box.png differ diff --git a/ml/dlib/docs/docs/guipics/open_file_box.png b/ml/dlib/docs/docs/guipics/open_file_box.png new file mode 100644 index 000000000..3277bac24 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/open_file_box.png differ diff --git a/ml/dlib/docs/docs/guipics/perspective_window.png b/ml/dlib/docs/docs/guipics/perspective_window.png new file mode 100644 index 000000000..13c01f2e2 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/perspective_window.png differ diff --git a/ml/dlib/docs/docs/guipics/popup_menu.png b/ml/dlib/docs/docs/guipics/popup_menu.png new file mode 100644 index 000000000..8e2bc0286 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/popup_menu.png differ diff --git a/ml/dlib/docs/docs/guipics/radio_button.png b/ml/dlib/docs/docs/guipics/radio_button.png new file mode 100644 index 000000000..e5d7a635e Binary files /dev/null and b/ml/dlib/docs/docs/guipics/radio_button.png differ diff --git a/ml/dlib/docs/docs/guipics/save_file_box.png b/ml/dlib/docs/docs/guipics/save_file_box.png new file mode 100644 index 000000000..a1d66c6b1 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/save_file_box.png differ diff --git a/ml/dlib/docs/docs/guipics/scroll_bar.png b/ml/dlib/docs/docs/guipics/scroll_bar.png new file mode 100644 index 000000000..c08df7524 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/scroll_bar.png differ diff --git a/ml/dlib/docs/docs/guipics/tabbed_display.png b/ml/dlib/docs/docs/guipics/tabbed_display.png new file mode 100644 index 000000000..6ab0417ce Binary files /dev/null and b/ml/dlib/docs/docs/guipics/tabbed_display.png differ diff --git a/ml/dlib/docs/docs/guipics/text_box.png b/ml/dlib/docs/docs/guipics/text_box.png new file mode 100644 index 000000000..c00d008f9 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/text_box.png differ diff --git a/ml/dlib/docs/docs/guipics/text_field.png b/ml/dlib/docs/docs/guipics/text_field.png new file mode 100644 index 000000000..ff917c576 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/text_field.png differ diff --git a/ml/dlib/docs/docs/guipics/text_grid.png b/ml/dlib/docs/docs/guipics/text_grid.png new file mode 100644 index 000000000..8b823af72 Binary files /dev/null and b/ml/dlib/docs/docs/guipics/text_grid.png differ diff --git a/ml/dlib/docs/docs/heatmap.png b/ml/dlib/docs/docs/heatmap.png new file mode 100644 index 000000000..d4bb4cab4 Binary files /dev/null and b/ml/dlib/docs/docs/heatmap.png differ diff --git a/ml/dlib/docs/docs/howto_contribute.xml b/ml/dlib/docs/docs/howto_contribute.xml new file mode 100644 index 000000000..3cf4a4209 --- /dev/null +++ b/ml/dlib/docs/docs/howto_contribute.xml @@ -0,0 +1,604 @@ + + + + + How to Contribute + + + + + + + + + + +

    Contributing Money

    + + The simplest way to contribute is to donate money via Zelle or PayPal + to davis@dlib.net. Any amount is appreciated :) + + +

    Contributing Code

    + + Code contributions are welcome and can be done by submitting a pull request + to dlib's github page. + +

    If you want to make a big change or feature addition, it's probably a good idea to talk to me about it first. + Additionally, you should read over the coding guidelines below + and try to follow them. It is also probably a good idea to read the books Effective C++ and + More Effective C++ by Scott Meyers. +

    + + +

    Coding Guidelines

    + + 1. Use Design by Contract
    + 2. Use spaces instead of tabs.
    + 3. Use the standard C++ naming convention
    + 4. Use RAII
    + 5. Don't use pointers
    + 6. Don't use #define for constants.
    + 7. Don't use stack based arrays.
    + 8. Use exceptions, but don't abuse them
    + 9. Write portable code
    + 10. Setup regression tests
    + 11. Use the Boost Software License
    + + +
      + + 1 +
    • Apply Design by Contract to Your Code

      +

        + The most important part of a software library isn't the code, it is the set + of interfaces the library exposes to the user. These interfaces need to be easy + to use right, and hard to use wrong. The only way this + happens is if the interfaces are documented in a simple, consistent, and precise way. +

        +

        + The name for the way I design and document these interfaces is known as + Design by Contract. There is a lot that can be said about Design by Contract, in fact, + whole books have been written about it, and programming languages exist which + use Design by Contract as a central element. Here I will just go over some + of the basic ways it is used in dlib as well as some of the reasons why it is a + Good Thing. +

        +
      • Functions should have documented preconditions which are programmatically verifiable +
          +

          + Many functions have a set of requirements or preconditions that need to be satisfied + if they are to be used. If these requirements are not satisfied + when a function is called then the function will not do what it is supposed to do. Moreover, + any piece of software that calls a function but doesn't make sure all preconditions + are satisfied contains a bug, by definition. +

          +

          + This means all functions must precisely document their preconditions if they are to be + usable. In fact, all preconditions should be programmatically verifiable. Doing this + has a number of benefits. First, it means they are unambiguous. English + can be confusing and vague, but saying "some_predicate == true" uses a + formal language, C++, that we all should understand quite well. Second, it means + you can put checks into the code that will catch all usage errors. +

          +

          + These checks should always be implemented using + DLIB_ASSERT or + DLIB_CASSERT and they should always + cover all preconditions. + These macros take a boolean argument and if it is false they throw dlib::fatal_error. So + you can use them to check that all your preconditions are true. Also, don't forget that + a violated function precondition indicates a bug in a program. + That is, when dlib::fatal_error is thrown it means a bug has been found and the only thing + an application can do at that point is print an error message and terminate. + In fact, dlib::fatal_error has checks in it to make sure someone doesn't catch the + exception and ignore it. These checks will abruptly terminate any program that attempts + to ignore fatal errors. +

          +

          + The above considerations bring me to my next bit of advice. Developers new to Design by Contract + often confuse input validation requirements with function preconditions. When I tell them + to consider any violation of a function's preconditions a bug and terminate their application + with an error message they complain that this is not at all what an application should do when + it receives invalid user inputs. + + They are right, that would be a bad thing + and you should not write software that behaves that way. The way out of this problem is, of + course, to not consider invalid input as a bug. Instead, you should perform explicit input validation + on any + data coming into your program before it gets to any functions that have preconditions + which demand the validated inputs. Moreover, if you make your preconditions programmatically verifiable + then it should be easy to validate any inputs by simply using whatever it is you + use to check your preconditions. +

          +

          + Consider the function cross_validate_trainer as an + example. One of its requirements is that the input forms a valid binary classification problem. + This is documented in the list of preconditions as + "is_binary_classification_problem(x,y) == true". This precondition is just saying + that when you call + the is_binary_classification_problem + function on the x and y inputs it had better return true + if you want to use those inputs with the cross_validate_trainer function. + Given this information it is trivial to perform input validation. All you have to do is + call is_binary_classification_problem on your input data and you are done. +

          +

          + Using the above technique you have validated your inputs, documented your preconditions, and are + buffered by DLIB_ASSERT statements that will catch you if you accidentally forget to validate any + inputs. +

          +

          The thing to understand here is that + a violation of a function's preconditions means you have a bug on your hands. Or in other words, + you should never intentionally violate any function preconditions. But of course + it will happen from time to time because bugs are unavoidable. But at least with + this approach you will get a detailed error message early in development rather than a + mysterious segmentation fault days or weeks later. +

          +
      • +
      • Functions should have documented postconditions +

          + I don't have nearly as much to say about postconditions as I did about function requirements. You should + strive to write programmatically verifiable postconditions because that makes your postconditions + more precise. However, it is sometimes the case that this isn't practical and that is fine. + But whatever you do write needs to clearly communicate to the + user what it is your function does. +

      • +

        + Now you may be wondering why this is called Design by Contract and not Documentation + by Contract. The reason is that the process of writing down all these detailed descriptions + of what your code does becomes part of how you design software. For example, often you + will find that when you go to write down the requirements for calling a function you are unable + to do so. This may be because the requirements are so complex you can't think of a way + to describe them, or you may realize that you don't even know what they are. Alternatively, + you may know what they are but discover there isn't any way to verify them programmatically. All these + things are symptoms of a bad design and the reason you became aware of this design problem + was by attempting to apply Design by Contract. +

        +

        + After you get enough practice with this way of writing software you begin to think a lot + more about questions like "how can I design this class such that every member function + has a very simple set of requirements and postconditions?" Once you start doing this + you are well on your way to creating software components that are easy to use right, and + hard to use wrong. +

        +

        + The notation dlib uses to document preconditions and postconditions is described in + the introduction. All code that goes into dlib + must document itself using this notation. You should also separate the implementation + and specification of a component into two separate files as described in the introduction. This + way users aren't confused or distracted by implementation details when they look at the documentation. +

        +
      +
    • + + + + 2 +
    • Use spaces instead of tabs.

      +

        This is just generally good advice but + it is especially important in dlib since everything is viewable + as pretty-printed HTML. Tabs show up as 8 characters in most browsers + and this results in the HTML version being difficult to read. So + don't use tabs. Additionally, please use 4 spaces for each tab level.

        +
    • + + + + + 3 +
    • Don't use capitol letters in the names of variables, functions, or + classes. Use the _ character to separate words.

      +
        +

        + The reason dlib uses this style is because it is the style used by the + C++ standard library. But more importantly, dlib currently provides + an interface to users that has a consistent look and feel and it is + important to continue to do so. +

        +

        + As for constants, they should usually contain all upper case letters + but all lowercase is ok sometimes. +

        +
    • + + + 4 +
    • Don't use manual resource management. Use RAII + instead.

      +

        + You should not be calling new and delete in your own code. You should instead + be using objects like the std::vector, std::shared_ptr, + or any number of other objects that manage resources such as memory for you. If you want + an array use std::vector (or the checked std_vector_c). + If you want to make a lookup table use a map. If you want + a two dimensional array use matrix or + array2d. +

        +

        + These container objects are examples of what is called RAII (Resource Acquisition Is Initialization) + in C++. It is essentially a name for the fact that, in C++, you can have totally automated and + deterministic resource management by always associating resource acquisition with the construction + of an object and resource release with the destruction of an object. I say resource management + here rather than memory management + because, unlike Java, RAII can be used for more than memory management. For example, when + you use a mutex you first lock + it, do something, and then you need to remember to unlock it. The RAII way of doing this is + to use the auto_mutex which will lock a mutex and automatically + unlock it for you. Or suppose you have made a TCP connection + to another machine and you want to be certain the resources associated with that connection + are always released. You can easily accomplish this with RAII by using the std::unique_ptr as + shown in this example program. +

        +

        + RAII is a trivial technique to use. All you have to do is not call new and delete and + you will never have another memory leak. Just use the appropriate container + instead. Finally, if you don't use RAII then your code is almost certainly not exception safe. +

        +
      +
    • + + + 5 +
    • Don't use pointers

      +

        + There are a number of reasons to not use pointers. First, if you are using pointers then + you are probably not using RAII. Second, pointers are ambiguous. When I see a pointer + I don't know if it is a pointer to a single item, a pointer to nothing, or + a pointer to an array of who knows how many things. On the other hand, when I see a + std::vector I know with certainty that I'm dealing with a kind of array. Or if I see a + reference to something then I know I'm dealing with exactly one instance of some object. +

        +

        + Most importantly, it is impossible to validate the state of a pointer. Consider two + functions: + +double compute_sum_of_array_elements(const double* array, int array_size); +double compute_sum_of_array_elements(const std::vector<double>& array); + + The first function is inherently unsafe. If the user accidentally passes in an invalid pointer + or sets the size argument incorrectly then their program may crash and this will turn into a + potentially hard to find bug. This is because there is absolutely nothing you can do inside + the first function to tell the difference between a valid pointer and size pair and an invalid + pointer and size pair. Nothing. The second function has none of these difficulties. +

        +

        + If you absolutely need pointer semantics then you can usually use a smart pointer like + std::unique_ptr or std::shared_ptr. + If that still isn't good enough for you and you really need to use a normal C style pointer + then isolate your pointers inside a class or function so that they are contained in a small area of the code. + However, in practice the container classes in dlib and the STL are more than sufficient in nearly + every case where pointers would otherwise be used. +

        +
      +
    • + + + 6 +
    • Don't use #define for constants.

      +

        + dlib is meant to be integrated into other people's projects. Because of this everything + in dlib is contained inside the dlib namespace to avoid naming conflicts with user's code. + #defines don't respect namespaces at all. For example, if you #define a constant called SIZE then it + will cause a conflict with any piece of code anywhere that contains the identifier SIZE. + This means that #define based constants must be avoided and constants should be created using the + const keyword instead. +

        +
      +
    • + + + 7 +
    • Don't use stack based arrays.

      +

        + A stack based array, or C style array, is an array declared like this: + int array[200]; + Most of my criticisms of pointers also apply to stack based arrays. In particular, + if you are passing a stack based array to a function then that means you are probably + using functions similar to the unsafe compute_sum_of_array_elements() example above. +

        +

        + The only time it is OK to use this kind of array is when you use it for simple + tasks and you don't start passing pointers to the array to other parts of your code. You + should also use a constant to store the array size and use that constant in your loops + rather than hard coding the size in numerous places. +

        +

        + But even still, you should use a container class instead and preferably one with the ability to do range + checking such as the std_vector_c.

        +

        + Consider the following two bits of code: + +for (int i = 0; i < array_size; ++i) + my_c_array[i] = 4; + +for (int i = 0; i < my_std_vector.size(); ++i) + my_std_vector[i] = 4; + + The second loop clearly doesn't overflow the bounds of the my_std_vector. On the other + hand, just by looking at the code in the first loop, we can not tell if it overflows + my_c_array. We have to assume that array_size is the appropriate constant but we could be wrong. +

        +

        + Buffer overflows are probably the most common kind of bug in C and C++ code. These bugs also + lead to serious exploitable security holes in software. So please try to avoid stack based arrays. +

        +
      +
    • + + + + + 8 +
    • Use exceptions, but don't abuse them.

      +
        +

        + Exceptions are one of the great features of modern programming languages. Some + people, however, consider that to be a contentious statement. But if you accept + the notion that a software library should be hard to use wrong then it + becomes difficult to reject exceptions. +

        +

        + Most of the complaints I hear about exceptions are actually complaints + about their misuse rather than objections to the basic idea. + So before I begin to defend the above + paragraph I would like to lay out more clearly when it is appropriate to + use exceptions and when it is not. +

        +

        + There are two basic questions you should ask yourself when deciding whether to + throw an exception in response to some event. The first is (1) "should this event + occur in the normal use of my library component?" The second question is (2) "if this event + were to occur, is it likely that the user will want to place the code for dealing + with the event near the invocations of my library component?" +

        +

        + If your answers to the above two questions are "no" then you should probably + throw an exception in response to the event. On the other hand, if you answer + "yes" to either of these questions then you should probably not throw an exception. +

        + +

        + A good example of an event worth throwing exceptions for is running out of memory. + (1) It doesn't happen very often, and (2) when it does happen it is hardly ever the case that + you want to deal with the out of memory event right next to the place where you are + attempting to allocate memory. +

        +

        + Alternatively, an example of an event that shouldn't throw an exception comes to + us from the C++ I/O streams. This part of the standard library allows + you to read the contents of a file from disk. When you hit the end of file they + do not throw an exception. This is appropriate because (1) you usually want to + read a file in its entirety. So hitting EOF happens all the time. Additionally, (2) + when you hit EOF you usually want to break out of the loop you are in + and continue immediately into the next block of code. +

        +

        + Usually when someone tells me they don't like exceptions they give reasons like "they make + me put try/catch blocks all over the place and it makes the code hard to read." Or "it makes + it hard to understand the flow of a program with exceptions in it." Invariably they + have been working with bodies of software that disregard the above rules regarding questions + 1 and 2. Indeed, when exceptions are used for flow control the results are horrifying. Using + exceptions for events that occur in the normal use of a software component, especially when + the events need to be dealt with near where they happen result in a spaghetti-like mess + of throw statements and try/catch blocks. Clearly, exceptions should be used judiciously. + So please, take my advice regarding questions 1 and 2 to heart. +

        +

        + Now let's go back to my claim that exceptions are an important part of making + a library that is hard to use wrong. But first let's be honest about one thing, + many developers don't think very hard about error handling and they similarly aren't very + careful about checking function return codes. Moreover, even the most studious of + us can easily forget to add these checks. It is also easy to forget to add + appropriate exception catch blocks. +

        +

        + So what is so great about exceptions then? Well, let's imagine some error just occurred + and it caused an exception to be thrown. If you forgot to setup catch blocks to deal with + the error then your program will be aborted. Not exactly a great thing. But you will, however, + be able to easily find out what exception was thrown. Additionally, exceptions typically contain a + message telling you all about the error. Moreover, + any debugger worth its + salt will be able to show you a stack trace that lets you see exactly where the exception came from. + The exception forces you, the user, to + be aware of this potential error and to add a catch block to deal with it. + This is where the "hard to use wrong" comes from. +

        +

        + Now let's imagine that we are using return codes to communicate errors to the user and the + same error occurs. If you forgot to do all your return code checking then you will + simply be unaware of the error. Maybe your program will crash right away. But more likely, it + will continue to run for a while before crashing at some random place far away from the source + of the error. You and your debugger now get to spend a few hours of quality time + together trying to figure out what went wrong. +

        +

        + The above considerations are why I maintain that exceptions, when used properly, contribute to + the "hard to use wrong" factor of a library. There are also other reasons to use exceptions. + They free the user from needing to clutter up code with lots of return code checking. This makes + code easier to read and let's you focus more on the algorithm you are trying to implement and less + on the bookkeeping. +

        +

        + Finally, it is important to note that there is a place for return codes. When you answer "no" + to questions 1 and 2, I suggest using exceptions. However, if you answer "yes" to even one + of them then I would recommend pretty much anything other than throwing an exception. In this + case error codes are often an excellent idea. +

        + + +

        + As an aside, it is also important that your exception classes inherit from + dlib::error to maintain consistency with the rest of the library. +

        +
      +
    • + + + + 9 +
    • Write portable code

      +
        +
      • Don't complicate the build process +

          + One of dlib's design goals is to not require any installation + process before it can be used. A user should be able to copy + the dlib folder into their project and have it just work. +

          +

          + In particular, using dlib in a project should not make it difficult to + compile the project from the command line. For example, all the + example programs provided with dlib can be compiled using a single + statement on the command line. +

          +

          + Similarly, the user should be able to check the dlib folder into whatever + version control system they use without running into any difficulties. + The user should then be able to check out copies of the code on any + of the dlib supported platforms and have their project build without + needing to mess with anything. +

          +
      • +
      • Don't make assumptions about how objects are laid out in memory. +

          + If you have been following the prohibition against messing around with + pointers then this won't even be an issue for you. Moreover, just about the only + time this should even come up is when you are casting blocks of + memory into other types or dumping the contents of memory to an I/O channel. + All of these things are highly non-portable so don't do them. +

          +

          + If you want a portable way to write the state of an object to an + I/O channel then I recommend you use the serialization + capability in dlib. If that doesn't suit your needs then do + something else, but whatever you do don't just dump the contents of memory. + Convert your data into some portable format and then output that. +

          +

          + As an example of something else you might do: suppose you have a bunch of integers + you want to write to disk. Assuming all your integers are positive numbers representable + using 32 or fewer bits you could store all your numbers in + dlib::uint32 variables and then convert them + into either big or little endian byte order and then write them to an output stream. + You could do this using code similar to the following: + + +dlib::byte_orderer bo; +... +bo.host_to_big(my_uint); +my_out_stream.write((char*)&my_uint, sizeof(my_uint)); +... + +

          + There are three important things to understand about this process. First, you need + to pick variables that always have the same size on all platforms. This means you + can't use any of the built in C++ types like int, float, double, long, etc. All + of these types have different sizes depending on your platform and even compiler settings. + So you need to use something like dlib::uint32 to obtain a type of a known size. +

          +

          + Second, you need to convert each thing you write out into either big or little endian byte order. + The reason for this is, again, portability. If you don't explicitly convert to one + of these byte orders then you end up writing data out using whatever byte order + is used by your current machine. If you do this then only machines that have the same + byte order as yours will be able to read in your data. If you use the dlib::byte_orderer + object this is easy. It is very type safe. In fact, you should have a hard time even getting + it to compile if you use it wrong. +

          +

          + The third thing you should understand is that you need to write out each of your + variables one at a time. You can't write out an entire struct in a + single ostream.write() statement because the compiler is allowed to put any + kind of padding it feels like between the fields in a struct. +

          +

          + You may be aware that compilers usually provide #pragma directives that allow you + to explicitly control this padding. However, if you want to submit code to dlib + you will not use this feature. Not all compilers support it in the same way and, + more importantly, not all CPU architectures are even capable of running code that + has had the padding messed with. This is because it can result in the CPU attempting + to perform what is called an "unaligned load" which many CPUs (like the SPARC) are + incapable of doing. +

          +

          + So in summary, convert your data into a known type with a fixed size, then convert + into a specific byte order (like big endian), then write out each variable individually. + Or you could just use serialize and not worry about all + this horrible stuff. :) +

          + +

          +
        +
      • + +
      • All code that calls functions that aren't in dlib or the C++ + standard library must be isolated inside the API wrappers. +

          + If you want to contribute code to dlib which needs to use something that isn't + in the C++ standard then we need to introduce a new library component + in the API wrappers section. The new component would + provide whatever functionality you need. This new component would have + to provide at least POSIX and win32 implementations. +

          +

          + It is also worth pointing out that simple wrappers around operating system + specific calls are usually a bad solution. This is because there are + invariably subtle, if not huge, differences between what is available on different + operating systems. + So being truly portable takes a lot of work. It involves reading everything + you can find about all the APIs needed to implement the feature on each target platform. + In many cases there will be important details that are undocumented and you will + only be able to find out about them by searching the internet for other developers + complaining about bugs in API functions X, Y, and Z. All this stuff needs to be abstracted + away to put a portable and simple interface in front of it. So this is a task + that shouldn't be taken lightly. +

          +
        +
      • +
    • + + + + 10 +
    • Library components should have regression tests

      +
        +

        + dlib has a regression test suite located in + the dlib/test folder. Whenever possible, library components should have tests + associated with them. GUI components get a pass since it isn't very easy to setup + automatic tests for them but pretty much everything else should have some sort + of test. +

        +
      +
    • + + + 11 +
    • You must use the Boost Software License

      +
        +

        + Having the library use more than one open source license is confusing + so I ask that any code contributions be licensed under the Boost Software + License. +

        +
      +
    • + + +
    + + + + + + + + + + + + + + +
    + diff --git a/ml/dlib/docs/docs/imaging.xml b/ml/dlib/docs/docs/imaging.xml new file mode 100644 index 000000000..2fef49e73 --- /dev/null +++ b/ml/dlib/docs/docs/imaging.xml @@ -0,0 +1,2608 @@ + + + + + Image Processing + + + + + +

    + This page documents the functionality present in this library that deals with the + management and manipulation of images. One thing to note is that there is no + explicit image object. Instead, everything deals with + array2d objects that contain various kinds of pixels or user defined + generic image objects. +

    + + +

    +

    Pixel Types

    + Most image handling routines in dlib will accept images containing any pixel type. + This is made possible by defining a traits class, pixel_traits, for + each possible pixel type. This traits class enables image processing routines to determine + how to handle each kind of pixel and therefore only pixels which have a pixel_traits definition + may be used. The following list defines all the pixel types which come with pixel_traits definitions. +
      +
    • RGB +
        There are two RGB pixel types in dlib, rgb_pixel and bgr_pixel. + Each defines a 24bit RGB pixel type. The bgr_pixel is identical to rgb_pixel except that it lays + the color channels down in memory in BGR order rather than RGB order and is therefore useful + for interfacing with other image processing tools which expect this format (e.g. OpenCV).
      +
    • +
    • RGB Alpha +
        The rgb_alpha_pixel is an 8bit per channel RGB pixel with an 8bit alpha channel.
      +
    • +
    • HSI +
        The hsi_pixel is a 24bit pixel which represents a point in the Hue Saturation Intensity + (HSI) color space.
      +
    • +
    • LAB +
        The lab_pixel is a 24bit pixel which represents a point in the CIELab color space.
      +
    • +
    • Grayscale +
        Any built in scalar type may be used as a grayscale pixel type. For example, unsigned char, int, double, etc.
      +
    • +
    +

    + +

    +

    Object Detection

    + If you want to create object detectors then try the + scan_fhog_pyramid tool first. It is quite + easy to use and train and will, in many cases, give excellent results. If that + doesn't give good results then try the more powerful + convolutional neural network based detector. +

    + + + + + + + + +
    + Pixels + rgb_pixel + bgr_pixel + rgb_alpha_pixel + hsi_pixel + lab_pixel + pixel_traits + get_pixel_intensity + assign_pixel + assign_pixel_intensity +
    + +
    + Image I/O + jpeg_loader + load_bmp + load_dng + load_image + load_jpeg + load_png + png_loader + save_bmp + save_dng + save_png + save_jpeg +
    + +
    + Object Detection + get_frontal_face_detector + object_detector + evaluate_detectors + full_object_detection + mmod_rect + scan_image + scan_image_movable_parts + find_points_above_thresh + scan_image_pyramid + scan_image_boxes + scan_fhog_pyramid + scan_image_custom + find_candidate_object_locations + test_box_overlap + remove_unobtainable_rectangles + setup_hashed_features + correlation_tracker + + Scan Image Pyramid Tools + + compute_box_dimensions + create_single_box_detection_template + create_overlapped_2x2_detection_template + create_grid_detection_template + + determine_object_boxes + setup_grid_detection_templates + setup_grid_detection_templates_verbose + + +
    + +
    + Feature Extraction + get_surf_points + shape_predictor + + SURF Tools + + hessian_pyramid + compute_surf_descriptor + haar_x + haar_y + get_interest_points + interest_point + surf_point + compute_dominant_angle + draw_surf_points + + + + hog_image + extract_fhog_features + fine_hog_image + poly_image + hashed_feature_image + binned_vector_feature_image + nearest_neighbor_feature_image + randomly_sample_image_features + + make_uniform_lbp_image + extract_histogram_descriptors + extract_uniform_lbp_descriptors + extract_highdim_face_lbp_descriptors +
    + + +
    + Edges and Thresholds + edge_orientation + hysteresis_threshold + sobel_edge_detector + suppress_non_maximum_edges + threshold_image + auto_threshold_image + hough_transform +
    + +
    + Morphology + label_connected_blobs + segment_image + binary_dilation + binary_erosion + binary_open + binary_close + binary_intersection + binary_union + binary_difference + binary_complement + skeleton +
    + +
    + Filtering + gaussian_blur + spatially_filter_image + spatially_filter_image_separable + float_spatially_filter_image_separable + separable_3x3_filter_block_grayscale + separable_3x3_filter_block_rgb + sum_filter + sum_filter_assign + max_filter + spatially_filter_image_separable_down +
    + +
    + Scaling and Rotating + pyramid_up + pyramid_down + pyramid_disable + create_tiled_pyramid + + interpolate_nearest_neighbor + interpolate_bilinear + interpolate_quadratic + transform_image + rotate_image + resize_image + flip_image_left_right + flip_image_up_down + add_image_left_right_flips + add_image_rotations + upsample_image_dataset + flip_image_dataset_left_right + rotate_image_dataset + extract_image_chips + random_cropper + jitter_image + sub_image +
    + +
    + Visualization + randomly_color_image + heatmap + jet + render_face_detections + draw_line + draw_solid_circle + draw_surf_points + draw_rectangle + tile_images + draw_fhog + fill_rect +
    + +
    + Miscellaneous + cv_image + toMat + assign_image + assign_image_scaled + assign_all_pixels + assign_border_pixels + equalize_histogram + get_histogram + zero_border_pixels + integral_image + integral_image_generic + disturb_colors + random_color_transform + +
    + +
    +
    + + + + + + + + + + + hough_transform + dlib/image_transforms.h + dlib/image_transforms/hough_transform_abstract.h + + This object is a tool for computing the line finding version of + the Hough transform given some kind of edge detection image as + input. It also allows the edge pixels to be weighted such that + higher weighted edge pixels contribute correspondingly more to + the output of the Hough transform, allowing stronger edges to + create correspondingly stronger line detections in the final + Hough transform. + + + hough_transform_ex.cpp.html + + + + + + + + skeleton + dlib/image_transforms.h + dlib/image_transforms/morphological_operations_abstract.h + + This function computes the skeletonization of an image. That is, + given a binary image, we progressively thin the binary blobs + until only a single pixel wide skeleton of the original blobs + remains. + + + + + + + disturb_colors + dlib/image_transforms.h + dlib/image_transforms/random_color_transform_abstract.h + + Applies a random color transform an image. This is done by + creating a random_color_transform with the given parameters and then + transforming each pixel in the image with the resulting transform. + + + + + + + random_color_transform + dlib/image_transforms.h + dlib/image_transforms/random_color_transform_abstract.h + + This object generates a random color balancing and gamma correction + transform. It then allows you to apply that specific transform to as many + rgb_pixel objects as you like. + + + + + + + integral_image_generic + dlib/image_transforms.h + dlib/image_transforms/integral_image_abstract.h + + This object is an alternate way of representing image data + that allows for very fast computations of sums of pixels in + rectangular regions. To use this object you load it with a + normal image and then you can use the get_sum_of_area() + member function to compute sums of pixels in a given area in + constant time. + + + + + + + + integral_image + dlib/image_transforms.h + dlib/image_transforms/integral_image_abstract.h + + This is a specialization of the integral_image_generic + template for the case where sums of pixel values should be represented with + longs. E.g. if you use 8bit pixels in your original images then this is + the appropriate kind of integral image to use with them. + + + + + + + + haar_x + dlib/image_transforms.h + dlib/image_transforms/integral_image_abstract.h + + This is a function that operates on an integral_image + and allows you to compute the response of a Haar wavelet oriented along + the X axis. + + + + + + + + haar_y + dlib/image_transforms.h + dlib/image_transforms/integral_image_abstract.h + + This is a function that operates on an integral_image + and allows you to compute the response of a Haar wavelet oriented along + the Y axis. + + + + + + + + draw_surf_points + dlib/image_keypoint/draw_surf_points.h + dlib/image_keypoint/draw_surf_points_abstract.h + + This routine adds a bunch of surf_point objects onto + an image_window + object so they can be visualized. + + + surf_ex.cpp.html + + + + + + + + compute_dominant_angle + dlib/image_keypoint.h + dlib/image_keypoint/surf_abstract.h + + Computes and returns the dominant angle (i.e. the angle of the dominant gradient) + at a given point and scale in an image. This function is part of the + main processing of the SURF algorithm. + + + + + + + + get_surf_points + dlib/image_keypoint.h + dlib/image_keypoint/surf_abstract.h + + This function runs the complete SURF algorithm on an input image and + returns the points it found. For a description of what exactly + the SURF algorithm does you should read the following paper: +
    + SURF: Speeded Up Robust Features + By Herbert Bay, Tinne Tuytelaars, and Luc Van Gool +
    +

    + Also note that there are numerous flavors of the SURF algorithm + you can put together using the functions in dlib. The get_surf_points() + function is just an example of one way you might do so. +

    +
    + + surf_ex.cpp.html + + +
    + + + + + shape_predictor + dlib/image_processing.h + dlib/image_processing/shape_predictor_abstract.h + + This object is a tool that takes in an image region containing some object + and outputs a set of point locations that define the pose of the + object. The classic example of this is human face pose prediction, where + you take an image of a human face as input and are expected to identify the + locations of important facial landmarks such as the corners of the mouth + and eyes, tip of the nose, and so forth. For example, here is the output + of dlib's 68-face-landmark shape_predictor on an image from the HELEN dataset:

    + + +

    + + To create useful instantiations of this object you need to use the + shape_predictor_trainer object to train a + shape_predictor using a set of training images, each annotated with shapes you want to predict. + To do this, the shape_predictor_trainer uses the state-of-the-art method from the + paper: +
    + One Millisecond Face Alignment with an Ensemble of Regression Trees + by Vahid Kazemi and Josephine Sullivan, CVPR 2014 +
    +
    + + face_landmark_detection_ex.cpp.html + train_shape_predictor_ex.cpp.html + webcam_face_pose_ex.cpp.html + + train_shape_predictor.py.html + face_landmark_detection.py.html + face_alignment.py.html + + +
    + + + + + compute_surf_descriptor + dlib/image_keypoint.h + dlib/image_keypoint/surf_abstract.h + + Computes the 64 dimensional SURF descriptor vector of a box centered + at a given center point, tilted at a given angle, and sized according to + a given scale. + + + + + + + + hog_image + dlib/image_keypoint.h + dlib/image_keypoint/hog_abstract.h + + This object is a tool for performing the image feature extraction algorithm + described in the following paper: +
    + Histograms of Oriented Gradients for Human Detection + by Navneet Dalal and Bill Triggs +
    +
    + + object_detector_ex.cpp.html + train_object_detector.cpp.html + + +
    + + + + + extract_fhog_features + dlib/image_transforms.h + dlib/image_transforms/fhog_abstract.h + + This function implements the HOG feature extraction method described in + the paper: +
    + Object Detection with Discriminatively Trained Part Based Models by + P. Felzenszwalb, R. Girshick, D. McAllester, D. Ramanan + in IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. 32, No. 9, Sep. 2010 +
    + This means that it takes an input image and outputs Felzenszwalb's + 31 dimensional version of HOG features. +
    + + fhog_ex.cpp.html + + +
    + + + + + fine_hog_image + dlib/image_keypoint.h + dlib/image_keypoint/fine_hog_image_abstract.h + + This object is a version of the hog_image that + allows you to extract HOG features at a finer resolution. + + + + + + + + poly_image + dlib/image_keypoint.h + dlib/image_keypoint/poly_image_abstract.h + + This object is a tool for extracting local feature descriptors from an image. + In particular, it fits polynomials to local pixel patches and + allows you to query the coefficients of these polynomials. + + + + + + + + hashed_feature_image + dlib/image_keypoint.h + dlib/image_keypoint/hashed_feature_image_abstract.h + + This object is a tool for performing image feature extraction. In + particular, it wraps another image feature extractor and converts + the wrapped image feature vectors into sparse indicator vectors. It does + this by hashing each feature vector and then returns a new vector + which is zero everywhere except for the position determined by the + hash. + +

    + The following feature extractors can be wrapped by the hashed_feature_image: + +
    + + object_detector_ex.cpp.html + train_object_detector.cpp.html + + +
    + + + + + binned_vector_feature_image + dlib/image_keypoint.h + dlib/image_keypoint/binned_vector_feature_image_abstract.h + + This object is a tool for performing image feature extraction. In + particular, it wraps another image feature extractor and converts the + wrapped image feature vectors into a high dimensional sparse vector. For + example, if the lower level feature extractor outputs the vector [3,4,5] + and this vector is hashed into the second bin of four bins then the output + sparse vector is: +
    + [0,0,0,0, 3,4,5,1, 0,0,0,0, 0,0,0,0]. +
    + That is, the output vector has a dimensionality that is equal to the number + of hash bins times the dimensionality of the lower level vector plus one. + The value in the extra dimension concatenated onto the end of the vector is + always a constant value of of 1 and serves as a bias value. This means + that, if there are N hash bins, these vectors are capable of representing N + different linear functions, each operating on the vectors that fall into + their corresponding hash bin. + +

    + The following feature extractors can be wrapped by the binned_vector_feature_image: + +
    +
    + + + + + nearest_neighbor_feature_image + dlib/image_keypoint.h + dlib/image_keypoint/nearest_neighbor_feature_image_abstract.h + + This object is a tool for performing image feature extraction. In + particular, it wraps another image feature extractor and converts + the wrapped image feature vectors into sparse indicator vectors. It does + this by finding the nearest neighbor for each feature vector and returning an + indicator vector that is zero everywhere except for the position indicated by + the nearest neighbor. + +

    + The following feature extractors can be wrapped by the nearest_neighbor_feature_image: + +
    + +
    + + + + + hessian_pyramid + dlib/image_keypoint.h + dlib/image_keypoint/hessian_pyramid_abstract.h + + This object represents an image pyramid where each level in the + pyramid holds determinants of Hessian matrices for the original + input image. This object can be used to find stable interest + points in an image. + +

    + This object is an implementation of the fast Hessian pyramid + as described in the paper: +
    + SURF: Speeded Up Robust Features + By Herbert Bay, Tinne Tuytelaars, and Luc Van Gool +
    + + This implementation was also influenced by the very well documented + OpenSURF library and its corresponding description of how the fast + Hessian algorithm functions: +
    Notes on the OpenSURF Library by Christopher Evans
    +
    + +
    + + + + + get_interest_points + dlib/image_keypoint.h + dlib/image_keypoint/hessian_pyramid_abstract.h + + This function extracts interest points from a hessian_pyramid. + + + + + + + + interest_point + dlib/image_keypoint.h + dlib/image_keypoint/hessian_pyramid_abstract.h + + This is a simple struct used to represent the interest points returned + by the get_interest_points function. + + + + + + + + surf_point + dlib/image_keypoint.h + dlib/image_keypoint/surf_abstract.h + + This is a simple struct used to represent the SURF points returned + by the get_surf_points function. + + + surf_ex.cpp.html + + + + + + + + pixel_traits + dlib/pixel.h + dlib/pixel.h + + As the name implies, this is a traits class for pixel types. It allows you + to determine what sort of pixel type you are dealing with. + + + + + + + + hsi_pixel + dlib/pixel.h + dlib/pixel.h + + This is a simple struct that represents a HSI colored graphical pixel. + + + + + + + + lab_pixel + dlib/pixel.h + dlib/pixel.h + + This is a simple struct that represents a CIELab colored graphical pixel. + + + + + + + + rgb_alpha_pixel + dlib/pixel.h + dlib/pixel.h + + This is a simple struct that represents an RGB colored graphical pixel with an + alpha channel. + + + + + + + + rgb_pixel + dlib/pixel.h + dlib/pixel.h + + This is a simple struct that represents an RGB colored graphical pixel. + + + + + + + + bgr_pixel + dlib/pixel.h + dlib/pixel.h + + This is a simple struct that represents a BGR colored graphical pixel. +

    + The difference between this object and the rgb_pixel + is just that this struct lays its pixels down in memory in BGR order rather + than RGB order. You only care about this if you are doing something like + using the cv_image object to map an OpenCV image + into a more object oriented form. +

    +
    + +
    + + + + + cv_image + dlib/opencv.h + dlib/opencv/cv_image_abstract.h + + This object is meant to be used as a simple wrapper around the OpenCV + IplImage struct or Mat object. Using this class template you can turn + an OpenCV image into something that looks like a normal dlib style + image object. + +

    + So you should be able to use cv_image objects with many of the image + processing functions in dlib as well as the GUI tools for displaying + images on the screen. +

    + +

    + Note that you can do the reverse conversion, from dlib to OpenCV, + using the toMat routine. +

    +
    + + webcam_face_pose_ex.cpp.html + + +
    + + + + + toMat + dlib/opencv.h + dlib/opencv/to_open_cv_abstract.h + + This routine converts a dlib style image into an instance of OpenCV's cv::Mat object. + This is done by setting up the Mat object to point to the same memory as the dlib image. +

    + Note that you can do the reverse conversion, from OpenCV to dlib, + using the cv_image object. +

    +
    + +
    + + + + + heatmap + dlib/image_transforms.h + dlib/image_transforms/colormaps_abstract.h + + Converts a grayscale image into a heatmap. This is useful if you want + to display a grayscale image with more than 256 values. In particular, + this function uses the following color mapping: +
    +
    + + image_ex.cpp.html + + +
    + + + + + jet + dlib/image_transforms.h + dlib/image_transforms/colormaps_abstract.h + + Converts a grayscale image into an image using the jet color + scheme. This is useful if you want to display a grayscale image + with more than 256 values. In particular, this function uses the + following color mapping: +
    +
    + +
    + + + + + randomly_color_image + dlib/image_transforms.h + dlib/image_transforms/colormaps_abstract.h + + Randomly generates a mapping from gray level pixel values + to the RGB pixel space and then uses this mapping to create + a colored version an image. +

    + This function is useful for displaying the results of some image + segmentation. For example, the output of label_connected_blobs + or segment_image. +

    +
    + +
    + + + + + assign_pixel + dlib/pixel.h + dlib/pixel.h + + assign_pixel() is a templated function that can assign any pixel type to another pixel type. + It will perform whatever conversion is necessary to make the assignment work. (E.g. color to + grayscale conversion) + + + + + + + + assign_pixel_intensity + dlib/pixel.h + dlib/pixel.h + + assign_pixel_intensity() is a templated function that can change the + intensity of a pixel. So if the pixel in question is a grayscale pixel + then it simply assigns that pixel the given value. However, if the + pixel is not a grayscale pixel then it converts the pixel to the + HSI color space and sets the I channel to the given intensity + and then converts this HSI value back to the original pixel's + color space. + + + + + + + + get_pixel_intensity + dlib/pixel.h + dlib/pixel.h + + get_pixel_intensity() is a templated function that + returns the grayscale intensity of a pixel. If the pixel isn't a grayscale + pixel then it converts the pixel to grayscale and returns that value. + + + + + + + + png_loader + dlib/image_io.h + dlib/image_loader/png_loader_abstract.h + + This object loads a Portable Network Graphics (PNG) image file into + an array2d of pixels. +

    + Note that you must define DLIB_PNG_SUPPORT if you want to use this object. You + must also set your build environment to link to the libpng library. However, + if you use CMake and dlib's default CMakeLists.txt file then it will get setup + automatically. +

    +
    + +
    + + + + + jpeg_loader + dlib/image_io.h + dlib/image_loader/jpeg_loader_abstract.h + + This object loads a JPEG image file into + an array2d of pixels. +

    + Note that you must define DLIB_JPEG_SUPPORT if you want to use this object. You + must also set your build environment to link to the libjpeg library. However, + if you use CMake and dlib's default CMakeLists.txt file then it will get setup + automatically. +

    +
    + +
    + + + + + load_jpeg + dlib/image_io.h + dlib/image_loader/jpeg_loader_abstract.h + + This function loads a JPEG image file into + an array2d of pixels. +

    + Note that you must define DLIB_JPEG_SUPPORT if you want to use this object. You + must also set your build environment to link to the libjpeg library. However, + if you use CMake and dlib's default CMakeLists.txt file then it will get setup + automatically. +

    +
    + +
    + + + + + load_dng + dlib/image_io.h + dlib/image_loader/image_loader_abstract.h + + This global function loads a dlib DNG file (a lossless compressed image format) into + an array2d of pixels. + + + + + + + + save_dng + dlib/image_io.h + dlib/image_saver/image_saver_abstract.h + + This global function saves an image as a dlib DNG file (a lossless + compressed image format). +

    + This routine can save images containing any type of pixel. However, the DNG format + can natively store only the following pixel types: rgb_pixel, hsi_pixel, + rgb_alpha_pixel, uint8, uint16, float, and double. + All other pixel types will be converted + into one of these types as appropriate before being saved to disk. +

    + +
    + +
    + + + + + save_png + dlib/image_io.h + dlib/image_saver/save_png_abstract.h + + This global function writes an image to disk as a PNG (Portable Network Graphics) file. +

    + Note that you must define DLIB_PNG_SUPPORT if you want to use this function. You + must also set your build environment to link to the libpng library. However, + if you use CMake and dlib's default CMakeLists.txt file then it will get setup + automatically. +

    +

    + This routine can save images containing any type of pixel. However, save_png() can + only natively store the following pixel types: rgb_pixel, + rgb_alpha_pixel, uint8, and uint16. All other pixel + types will be converted into one of these types as appropriate before being + saved to disk. +

    +
    + +
    + + + + + save_jpeg + dlib/image_io.h + dlib/image_saver/save_jpeg_abstract.h + + This global function writes an image to disk as a JPEG file. +

    + Note that you must define DLIB_JPEG_SUPPORT if you want to use this function. You + must also set your build environment to link to the libjpeg library. However, + if you use CMake and dlib's default CMakeLists.txt file then it will get setup + automatically. +

    +

    + This routine can save images containing any type of pixel. However, save_jpeg() can + only natively store the following pixel types: rgb_pixel + and uint8. All other pixel types will be converted into + one of these types as appropriate before being saved to disk. +

    +
    + +
    + + + + + load_image + dlib/image_io.h + dlib/image_loader/load_image_abstract.h + + This global function takes a file name, looks at its extension, and + then loads it into an array2d of + pixels using the appropriate image + loading routine. The supported types are BMP, PNG, JPEG, GIF, and the dlib DNG file format. + +

    + Note that you can only load PNG, JPEG, and GIF files if you link against + libpng, libjpeg, and libgif respectively. You will also need to #define + DLIB_PNG_SUPPORT, DLIB_JPEG_SUPPORT, and DLIB_GIF_SUPPORT. Or use CMake and + it will do all this for you. +

    +
    + + + image_ex.cpp.html + + +
    + + + + + load_bmp + dlib/image_io.h + dlib/image_loader/image_loader_abstract.h + + This global function loads a MS Windows BMP file into an array2d of + pixels. + + + + + + + + load_png + dlib/image_io.h + dlib/image_loader/png_loader_abstract.h + + This function loads a Portable Network Graphics (PNG) image file into + an array2d of pixels. +

    + Note that you must define DLIB_PNG_SUPPORT if you want to use this object. You + must also set your build environment to link to the libpng library. However, + if you use CMake and dlib's default CMakeLists.txt file then it will get setup + automatically. +

    +
    + +
    + + + + + save_bmp + dlib/image_io.h + dlib/image_saver/image_saver_abstract.h + + This global function saves an image as a MS Windows BMP file. + +

    + This routine can save images containing any type of pixel. However, it will + convert all color pixels into rgb_pixel and grayscale pixels into + uint8 type before saving to disk. +

    + +
    + +
    + + + + + draw_line + dlib/image_transforms.h + dlib/image_transforms/draw_abstract.h + + This global function draws a line on an image. + + + + + + + + draw_solid_circle + dlib/image_transforms.h + dlib/image_transforms/draw_abstract.h + + This global function draws a solid circle on an image. + + + + + + + + render_face_detections + dlib/image_processing/render_face_detections.h + dlib/image_processing/render_face_detections_abstract.h + + This function takes a set of full_object_detections + which represent human faces annotated with 68 facial landmarks (according to the iBUG 300-W + scheme) and converts them into a form suitable for display on an + image_window. + +

    + For example, it will take the output of a shape_predictor + that uses this facial landmarking scheme and will produce visualizations like this: +

    + + +
    + + face_landmark_detection_ex.cpp.html + webcam_face_pose_ex.cpp.html + + +
    + + + + + draw_rectangle + dlib/image_transforms.h + dlib/image_transforms/draw_abstract.h + + This global function draws a rectangle on an image. + + + + + + + + tile_images + dlib/image_transforms.h + dlib/image_transforms/draw_abstract.h + + This function takes an array of images and tiles them into a single large + square image and returns this new big tiled image. Therefore, it is a useful + method to visualize many small images at once. + + + + + + + + draw_fhog + dlib/image_transforms.h + dlib/image_transforms/fhog_abstract.h + + This function takes a FHOG feature map which was created by + extract_fhog_features and + converts it into an image suitable for display on the screen. In + particular, we draw all the hog cells into a grayscale image in a + way that shows the magnitude and orientation of the gradient + energy in each cell. + + + fhog_ex.cpp.html + fhog_object_detector_ex.cpp.html + + + + + + + + fill_rect + dlib/image_transforms.h + dlib/image_transforms/draw_abstract.h + + This global function draws a solid rectangle on an image. + + + + + + + + assign_border_pixels + dlib/image_transforms.h + dlib/image_transforms/assign_image_abstract.h + + This global function assigns all the pixels in the border of an image to + a specific value. + + + + + + + + + assign_all_pixels + dlib/image_transforms.h + dlib/image_transforms/assign_image_abstract.h + + This global function assigns all the pixels in an image a specific value. + + + + + + + + assign_image + dlib/image_transforms.h + dlib/image_transforms/assign_image_abstract.h + + This global function copies one image into another and performs any + necessary color space conversions to make it work right. + + + + + + + + assign_image_scaled + dlib/image_transforms.h + dlib/image_transforms/assign_image_abstract.h + + This global function copies one image into another and performs any + necessary color space conversions to make it work right. Additionally, + if the dynamic range of the source image is too big to fit into the destination image + then it will attempt to perform the appropriate scaling. + + + + + + + + get_histogram + dlib/image_transforms.h + dlib/image_transforms/equalize_histogram_abstract.h + + This global function computes an image's histogram and returns it in the + form of a column or row matrix object. + + + + + + + + edge_orientation + dlib/image_transforms.h + dlib/image_transforms/edge_detector_abstract.h + + This global function takes horizontal and vertical gradient magnitude + values and returns the orientation of the gradient. + + + + + + + + + equalize_histogram + dlib/image_transforms.h + dlib/image_transforms/equalize_histogram_abstract.h + + This global function performs histogram equalization on an image. + + + + + + + + hysteresis_threshold + dlib/image_transforms.h + dlib/image_transforms/thresholding_abstract.h + + This global function performs hysteresis thresholding on an image. + + + + + + + + sobel_edge_detector + dlib/image_transforms.h + dlib/image_transforms/edge_detector_abstract.h + + This global function performs spatial filtering on an image using the + sobel edge detection filters. + + + + image_ex.cpp.html + + + + + + + suppress_non_maximum_edges + dlib/image_transforms.h + dlib/image_transforms/edge_detector_abstract.h + + This global function performs non-maximum suppression on a gradient + image. + + + + image_ex.cpp.html + + + + + + + + zero_border_pixels + dlib/image_transforms.h + dlib/image_transforms/assign_image_abstract.h + + This global function zeros the pixels on the border of an image. + + + + + + + + spatially_filter_image + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This global function performs spatial filtering on an image with a user + supplied filter. + + + + + + + + gaussian_blur + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This global function blurs an image by convolving it with a Gaussian filter. + + + + + + + + spatially_filter_image_separable + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This global function performs spatial filtering on an image with a user + supplied separable filter. + + + + + + + + float_spatially_filter_image_separable + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This global function performs spatial filtering on an image with a user + supplied separable filter. It is optimized to work only on float valued + images with float valued filters. + + + + + + + + spatially_filter_image_separable_down + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This global function performs spatial filtering on an image with a user + supplied separable filter. Additionally, it produces a downsampled + output. + + + + + + + + max_filter + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This function slides a rectangle over an input image and outputs a new + image which contains the maximum valued pixel found inside the rectangle at each + position in the input image. + + + + + + + + sum_filter + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This function slides a rectangle over an input image and adds the sum + of pixel values in each rectangle location to another image. + + + + + + + + sum_filter_assign + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This function slides a rectangle over an input image and outputs a new + image which contains the sum of pixels inside the rectangle at each + position in the input image. + + + + + + + + separable_3x3_filter_block_grayscale + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This routine filters part of an image with a user supplied 3x3 separable filter. + The output is a grayscale sub-image. + + + + + + + + separable_3x3_filter_block_rgb + dlib/image_transforms.h + dlib/image_transforms/spatial_filtering_abstract.h + + This routine filters part of an image with a user supplied 3x3 separable filter. + The output is a RGB sub-image. + + + + + + + + create_tiled_pyramid + dlib/image_transforms.h + dlib/image_transforms/image_pyramid_abstract.h + + This function creates an image pyramid and packs the entire pyramid into + one big image. It does this by tiling the different pyramid layers together + and outputting the result. Here is an example: +
    + +
    + + Also, you can use the + image_to_tiled_pyramid() + and + tiled_pyramid_to_image() routines + to convert between the input image coordinate space and the tiled pyramid + coordinate space. +
    + +
    + + + + + pyramid_down + dlib/image_transforms.h + dlib/image_transforms/image_pyramid_abstract.h + + This is a simple function object to help create image pyramids. It + downsamples an image by a ratio of N to N-1 where N is supplied by the + user as a template argument. + + + + + + + + pyramid_disable + dlib/image_transforms.h + dlib/image_transforms/image_pyramid_abstract.h + + This object downsamples an image at a ratio of infinity to 1. That + means it always outputs an image of size zero. This is useful because + it can be supplied to routines which take a pyramid_down function object + and it will essentially disable pyramid processing. This way, a pyramid + oriented function can be turned into a regular routine which processes + just the original undownsampled image. + + + + + + + + interpolate_nearest_neighbor + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This object is a tool for performing nearest neighbor interpolation + on an image. + + + + + + + interpolate_bilinear + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This object is a tool for performing bilinear interpolation + on an image. + + + + + + + interpolate_quadratic + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This object is a tool for performing quadratic interpolation + on an image. + + + + + + + transform_image + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This routine is a tool for transforming images using some kind of point mapping + function (e.g. point_transform_affine) + and pixel interpolation tool (e.g. interpolate_quadratic). + An example application of this routine is for image rotation. Indeed, it is how + rotate_image is implemented. + + + + + + + rotate_image + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This is a routine for rotating an image. + + + + + + + pyramid_up + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This routine upsamples an image. In particular, it takes a + pyramid_down object (or an object with a + compatible interface) as an argument and performs an upsampling + which is the inverse of the supplied pyramid_down object. + + + + + + + resize_image + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This is a routine capable of resizing or stretching an image. + + + + + + + extract_image_chips + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This function extracts "chips" from an image. That is, it takes a list of + rectangular sub-windows (i.e. chips) within an image and extracts those + sub-windows, storing each into its own image. It also allows the user to + specify the scale and rotation for the chip. + + + face_landmark_detection_ex.cpp.html + + + + + + + random_cropper + dlib/image_transforms.h + dlib/image_transforms/random_cropper_abstract.h + + This object is a tool for extracting random crops of objects from a set of + images. The crops are randomly jittered in scale, translation, and + rotation but more or less centered on objects specified by mmod_rect + objects. + + + random_cropper_ex.cpp.html + dnn_mmod_ex.cpp.html + dnn_mmod_find_cars_ex.cpp.html + dnn_mmod_train_find_cars_ex.cpp.html + + + + + + + jitter_image + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + Randomly jitters an image by slightly rotating, scaling, and translating it. + There is also a 50% chance it will be mirrored left to right. + + + dnn_metric_learning_on_images_ex.cpp.html + dnn_face_recognition_ex.cpp.html + + + + + + + sub_image + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This function returns a lightweight sub-image of another image. In particular, + the returned sub-image simply holds a pointer to the original image, meaning there + is no overhead for using or creating the sub-image. + + + + + + + add_image_left_right_flips + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This routine takes a set of images and bounding boxes within those + images and doubles the size of the dataset by adding left/right + flipped copies of each image as well as the corresponding bounding + boxes. Therefore, this function is useful if you are training and + object detector and your objects have a left/right symmetry. + + + fhog_object_detector_ex.cpp.html + + + + + + + add_image_rotations + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This routine takes a set of images and bounding boxes within those images and + grows the dataset by computing many different rotations of each image. It will + also adjust the positions of the bounding boxes so that they still fall on the + same objects in each rotated image. + + + + + + + rotate_image_dataset + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This routine takes a set of images and bounding boxes within those + images and rotates the entire dataset by a user specified angle. + This means that all images are rotated and the bounding boxes are adjusted + so that they still sit on top of the same visual objects in the new rotated images. + + + + + + + flip_image_dataset_left_right + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This routine takes a set of images and bounding boxes within those images and + mirrors the entire dataset left to right. This means that all images are + flipped left to right and the bounding boxes are adjusted so that they still + sit on top of the same visual objects in the new flipped images. + + + + + + + upsample_image_dataset + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This routine takes a set of images and bounding boxes within those images and + upsamples the entire dataset. This means that all images are upsampled and the + bounding boxes are adjusted so that they still sit on top of the same visual + objects in the new images. + + + + + + + flip_image_left_right + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This is a routine which can flip an image from left to right. (e.g. as + if viewed through a mirror). + + + + + + + flip_image_up_down + dlib/image_transforms.h + dlib/image_transforms/interpolation_abstract.h + + This routine flips an image upside down. + + + + + + + auto_threshold_image + dlib/image_transforms.h + dlib/image_transforms/thresholding_abstract.h + + This global function performs a simple binary thresholding on an image. + Instead of taking a user supplied threshold + it computes one from the image using k-means clustering. + + + + + + + + threshold_image + dlib/image_transforms.h + dlib/image_transforms/thresholding_abstract.h + + This global function performs a simple binary thresholding on an image with a user + supplied threshold. + + + + + + + + label_connected_blobs + dlib/image_transforms.h + dlib/image_transforms/label_connected_blobs_abstract.h + + This function labels each of the connected blobs in an image with a unique integer label. + + + + + + + + segment_image + dlib/image_transforms.h + dlib/image_transforms/segment_image_abstract.h + + Attempts to segment an image into regions which have some visual consistency to them. + In particular, this function implements the algorithm described in the paper: +
    + Efficient Graph-Based Image Segmentation by Felzenszwalb and Huttenlocher. +
    +
    + +
    + + + + + find_candidate_object_locations + dlib/image_transforms.h + dlib/image_transforms/segment_image_abstract.h + + This function takes an input image and generates a set of candidate + rectangles which are expected to bound any objects in the image. It does + this by running a version of the segment_image routine on the image and + then reports rectangles containing each of the segments as well as rectangles + containing unions of adjacent segments. The basic idea is described in the + paper: +
    + Segmentation as Selective Search for Object Recognition by Koen E. A. van de Sande, et al. +
    + Note that this function deviates from what is described in the paper slightly. + See the code for details. +
    + + find_candidate_object_locations.py.html + +
    + + + + + binary_dilation + dlib/image_transforms.h + dlib/image_transforms/morphological_operations_abstract.h + + This global function performs the morphological operation of dilation on an image. + + + + + + + + binary_erosion + dlib/image_transforms.h + dlib/image_transforms/morphological_operations_abstract.h + + This global function performs the morphological operation of erosion on an image. + + + + + + + + binary_open + dlib/image_transforms.h + dlib/image_transforms/morphological_operations_abstract.h + + This global function performs a morphological opening on an image. + + + + + + + + binary_close + dlib/image_transforms.h + dlib/image_transforms/morphological_operations_abstract.h + + This global function performs a morphological closing on an image. + + + + + + + + binary_intersection + dlib/image_transforms.h + dlib/image_transforms/morphological_operations_abstract.h + + This global function computes the intersection of two binary images. + + + + + + + + binary_union + dlib/image_transforms.h + dlib/image_transforms/morphological_operations_abstract.h + + This global function computes the union of two binary images. + + + + + + + + binary_difference + dlib/image_transforms.h + dlib/image_transforms/morphological_operations_abstract.h + + This global function computes the difference of two binary images. + + + + + + + + binary_complement + dlib/image_transforms.h + dlib/image_transforms/morphological_operations_abstract.h + + This global function computes the complement of a binary image. + + + + + + + + scan_image + dlib/image_processing.h + dlib/image_processing/scan_image_abstract.h + + This global function is a tool for sliding a set of rectangles + over an image space and finding the locations where the sum of pixels in + the rectangles exceeds a threshold. It is useful for implementing + certain kinds of sliding window classifiers. + + + + + + + + scan_image_movable_parts + dlib/image_processing.h + dlib/image_processing/scan_image_abstract.h + + This global function is a tool for sliding a set of rectangles + over an image space and finding the locations where the sum of pixels in + the rectangles exceeds a threshold. It is useful for implementing + certain kinds of sliding window classifiers. The behavior of this + routine is similar to scan_image except that + it can also handle movable parts in addition to rigidly placed parts + within the sliding window. + + + + + + + + find_points_above_thresh + dlib/image_processing.h + dlib/image_processing/scan_image_abstract.h + + This routine finds all points in an image with a pixel value above a + threshold. It also has the ability to produce an efficient random + subsample of such points if the number of them is very large. + + + + + + + + scan_image_pyramid + dlib/image_processing.h + dlib/image_processing/scan_image_pyramid_abstract.h + + This object is a tool for running a sliding window classifier over + an image pyramid. This object can also be understood as a general + tool for implementing the spatial pyramid models described in the paper: +
    + Beyond Bags of Features: Spatial Pyramid Matching for Recognizing + Natural Scene Categories by Svetlana Lazebnik, Cordelia Schmid, + and Jean Ponce +
    + It also includes the ability to represent movable part models. + +

    + The following feature extractors can be used with the scan_image_pyramid object: + +
    + + object_detector_ex.cpp.html + object_detector_advanced_ex.cpp.html + + +
    + + + + + scan_fhog_pyramid + dlib/image_processing.h + dlib/image_processing/scan_fhog_pyramid_abstract.h + + + This object is a tool for running a fixed sized sliding window classifier + over an image pyramid. In particular, it slides a linear classifier over + a HOG pyramid as discussed in the paper: +
    + Histograms of Oriented Gradients for Human Detection by Navneet Dalal + and Bill Triggs, CVPR 2005 +
    + However, we augment the method slightly to use the version of HOG features + from: +
    + Object Detection with Discriminatively Trained Part Based Models by + P. Felzenszwalb, R. Girshick, D. McAllester, D. Ramanan + IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. 32, No. 9, Sep. 2010 +
    + Since these HOG features have been shown to give superior performance. +
    + + fhog_object_detector_ex.cpp.html + train_object_detector.py.html + + +
    + + + + + scan_image_boxes + dlib/image_processing.h + dlib/image_processing/scan_image_boxes_abstract.h + + This object is a tool for running a classifier over an image with the goal + of localizing each object present. The localization is in the form of the + bounding box around each object of interest. + +

    + Unlike the scan_image_pyramid object which scans a + fixed sized window over an image pyramid, the scan_image_boxes tool allows you to + define your own list of "candidate object locations" which should be evaluated. + This is simply a list of rectangle objects which might contain objects of + interest. The scan_image_boxes object will then evaluate the classifier at each + of these locations and return the subset of rectangles which appear to have + objects in them. +

    + + This object can also be understood as a general tool for implementing the spatial + pyramid models described in the paper: +
    + Beyond Bags of Features: Spatial Pyramid Matching for Recognizing + Natural Scene Categories by Svetlana Lazebnik, Cordelia Schmid, + and Jean Ponce +
    + +

    + The following feature extractors can be used with the scan_image_boxes object: + +
    + +
    + + + + + scan_image_custom + dlib/image_processing.h + dlib/image_processing/scan_image_custom_abstract.h + + This object is a tool for running a classifier over an image with the goal + of localizing each object present. The localization is in the form of the + bounding box around each object of interest. + +

    + Unlike the scan_image_pyramid + and scan_image_boxes objects, this image + scanner delegates all the work of constructing the object feature vector to + a user supplied feature extraction object. That is, scan_image_custom + simply asks the supplied feature extractor what boxes in the image we + should investigate and then asks the feature extractor for the complete + feature vector for each box. That is, scan_image_custom does not apply any + kind of pyramiding or other higher level processing to the features coming + out of the feature extractor. That means that when you use + scan_image_custom it is completely up to you to define the feature vector + used with each image box. +

    +
    + +
    + + + + + full_object_detection + dlib/image_processing.h + dlib/image_processing/full_object_detection_abstract.h + + This object represents the location of an object in an image along with the + positions of each of its constituent parts. + + + + + + + + mmod_rect + dlib/image_processing.h + dlib/image_processing/full_object_detection_abstract.h + + This is a simple struct that is used to give training data and receive detections + from the Max-Margin Object Detection loss layer. + + + + + + + + get_frontal_face_detector + dlib/image_processing/frontal_face_detector.h + dlib/image_processing/frontal_face_detector_abstract.h + + This function returns an object_detector that is + configured to find human faces that are looking more or less towards the camera. + It is created using the scan_fhog_pyramid + object. + + + face_detection_ex.cpp.html + face_detector.py.html + webcam_face_pose_ex.cpp.html + + + + + + + + object_detector + dlib/image_processing.h + dlib/image_processing/object_detector_abstract.h + + This object is a tool for detecting the positions of objects in an image. In + particular, it is a simple container to aggregate an instance of an image + scanner object (either scan_fhog_pyramid, + scan_image_pyramid, scan_image_boxes, or + scan_image_custom), the weight vector + needed by one of these image scanners, and finally an instance of + test_box_overlap. The test_box_overlap object + is used to perform non-max suppression on the output of the image scanner + object. + +

    + Note that you can use the + structural_object_detection_trainer + to learn the parameters of an object_detector. See the example programs for an introduction. +

    + +

    + Also note that dlib contains more powerful CNN based object detection + tooling, which will usually run slower but produce much + more general and accurate detectors. +

    + +
    + + fhog_object_detector_ex.cpp.html + face_detection_ex.cpp.html + object_detector_ex.cpp.html + object_detector_advanced_ex.cpp.html + train_object_detector.cpp.html + + face_detector.py.html + train_object_detector.py.html + + +
    + + + + + correlation_tracker + dlib/image_processing.h + dlib/image_processing/correlation_tracker_abstract.h + + This is a tool for tracking moving objects in a video stream. You give it + the bounding box of an object in the first frame and it attempts to track the + object in the box from frame to frame. + +

    + This tool is an implementation of the method described in the following paper: +

    + Danelljan, Martin, et al. "Accurate scale estimation for robust visual + tracking." Proceedings of the British Machine Vision Conference BMVC. 2014. +
    +

    +
    + + video_tracking_ex.cpp.html + correlation_tracker.py.html + + +
    + + + + + evaluate_detectors + dlib/image_processing.h + dlib/image_processing/scan_fhog_pyramid_abstract.h + + This function allows you to efficiently run a bunch of + scan_fhog_pyramid based + object_detectors + over an image. Importantly, this function is faster than running + each detector individually because it computes the HOG features + only once and then reuses them for each detector. + + + fhog_object_detector_ex.cpp.html + + + + + + + + test_box_overlap + dlib/image_processing.h + dlib/image_processing/box_overlap_testing_abstract.h + + This object is a simple function object for determining if two + rectangles overlap. + + + + + + + + remove_unobtainable_rectangles + dlib/image_processing.h + dlib/image_processing/remove_unobtainable_rectangles_abstract.h + + Recall that the scan_image_pyramid and + scan_image_boxes objects can't produce + all possible rectangles as object detections since they only + consider a limited subset of all possible object positions. + Therefore, when training an object detector that uses these tools + you must make sure the training data does not contain any object + locations that are unobtainable by the image scanning model. + The remove_unobtainable_rectangles() routine is a tool to filter out + these unobtainable rectangles from the training. + + + + + + + + compute_box_dimensions + dlib/image_processing.h + dlib/image_processing/detection_template_tools_abstract.h + + This function is a tool for computing a rectangle with a particular + width/height ratio and area. + + + + + + + + create_single_box_detection_template + dlib/image_processing.h + dlib/image_processing/detection_template_tools_abstract.h + + This function is a tool for creating a detection template usable by + the scan_image_pyramid object. This + particular function creates a detection template with exactly one feature + extraction region. + + + + + + + + create_overlapped_2x2_detection_template + dlib/image_processing.h + dlib/image_processing/detection_template_tools_abstract.h + + This function is a tool for creating a detection template usable by + the scan_image_pyramid object. This + particular function creates a detection template with four overlapping feature + extraction regions. + + + + + + + + create_grid_detection_template + dlib/image_processing.h + dlib/image_processing/detection_template_tools_abstract.h + + This function is a tool for creating a detection template usable by + the scan_image_pyramid object. This + particular function creates a detection template with a grid of feature + extraction regions. + + + + + + + + determine_object_boxes + dlib/image_processing.h + dlib/image_processing/scan_image_pyramid_tools_abstract.h + + The scan_image_pyramid object represents a sliding + window classifier system. For it to work correctly it needs to be given a set of + object boxes which define the size and shape of each sliding window and these windows + need to be able to match the sizes and shapes of targets the user wishes to detect. + Therefore, the determine_object_boxes() routine is a tool for computing a set of object boxes + which can meet this requirement. + + + + + + + + + setup_grid_detection_templates + dlib/image_processing.h + dlib/image_processing/scan_image_pyramid_tools_abstract.h + + This routine uses determine_object_boxes to obtain a set of + object boxes and then adds them to a scan_image_pyramid object + as detection templates. It also uses create_grid_detection_template + to create each feature extraction region. Therefore, the detection templates will extract + features from a regular grid inside each object box. + + + + + + + + setup_grid_detection_templates_verbose + dlib/image_processing.h + dlib/image_processing/scan_image_pyramid_tools_abstract.h + + This function is identical to setup_grid_detection_templates + except that it also outputs information regarding the selected detection templates + to standard out. + + + object_detector_ex.cpp.html + train_object_detector.cpp.html + + + + + + + setup_hashed_features + dlib/image_processing.h + dlib/image_processing/setup_hashed_features_abstract.h + + This is a tool for configuring the hashed_feature_image + or binned_vector_feature_image object + with a random projection hash. + + + object_detector_ex.cpp.html + train_object_detector.cpp.html + + + + + + + + randomly_sample_image_features + dlib/statistics.h + dlib/statistics/image_feature_sampling_abstract.h + + Given a feature extractor such as the hog_image, + this routine selects a random subsample of local image feature vectors + from a set of images. + + + + + + + + + make_uniform_lbp_image + dlib/image_transforms.h + dlib/image_transforms/lbp_abstract.h + + This function extracts the uniform local-binary-pattern feature at every pixel + of an image and stores the output in a new image object. + We use the idea of uniform LBPs from the paper: +
    + Face Description with Local Binary Patterns: Application to Face Recognition + by Ahonen, Hadid, and Pietikainen. +
    +
    +
    + + + + + extract_histogram_descriptors + dlib/image_transforms.h + dlib/image_transforms/lbp_abstract.h + + This function extracts histograms of pixel values from a set of windows in an + image and returns the histograms. + + + + + + + extract_uniform_lbp_descriptors + dlib/image_transforms.h + dlib/image_transforms/lbp_abstract.h + + Extracts histograms of uniform local-binary-patterns from an image. The + histograms are from densely tiled windows that do not overlap and cover all + of the image. + We use the idea of uniform LBPs from the paper: +
    + Face Description with Local Binary Patterns: Application to Face Recognition + by Ahonen, Hadid, and Pietikainen. +
    +
    +
    + + + + + extract_highdim_face_lbp_descriptors + dlib/image_transforms.h + dlib/image_transforms/lbp_abstract.h + + This function extracts the high-dimensional LBP feature described in the + paper: +
    + Blessing of Dimensionality: High-dimensional Feature and Its Efficient + Compression for Face Verification by Dong Chen, Xudong Cao, Fang Wen, and + Jian Sun +
    +
    +
    + + + +
    + + + + +
    + diff --git a/ml/dlib/docs/docs/index.xml b/ml/dlib/docs/docs/index.xml new file mode 100644 index 000000000..0a4b7fa53 --- /dev/null +++ b/ml/dlib/docs/docs/index.xml @@ -0,0 +1,226 @@ + + + + + + + + + +

    + Dlib is a modern C++ toolkit containing machine learning algorithms and + tools for creating complex software in C++ to solve real world problems. + It is used in both industry and academia in a wide range of domains + including robotics, embedded devices, mobile phones, and large high + performance computing environments. Dlib's open source licensing + allows you to use it in any application, free of charge. +

    + +

    + To follow or participate in the development of dlib subscribe to dlib on github. + Also be sure to read the how to contribute page if you intend to + submit code to the project. +

    + +
    + +

    +

    Major Features

    + + +

    + + + + + + +
    diff --git a/ml/dlib/docs/docs/intro.xml b/ml/dlib/docs/docs/intro.xml new file mode 100644 index 000000000..bd21c7e5b --- /dev/null +++ b/ml/dlib/docs/docs/intro.xml @@ -0,0 +1,431 @@ + + + + + Introduction + + + + + + + + + + + + + + + + +

    Overview

    + + + + +

    + Dlib is a general purpose cross-platform open source software library written in the C++ programming + language. Its design is heavily influenced by ideas from design by contract and component-based + software engineering. This means it is, first and foremost, a collection of independent + software components, each accompanied by extensive documentation and thorough debugging modes. +

    + + +

    + Davis King has been the primary + author of dlib since development began in 2002. In that time + dlib has grown to include a wide variety of tools. In particular, + it now contains software components for dealing with networking, + threads, graphical interfaces, complex data structures, linear + algebra, statistical machine learning, image processing, data + mining, XML and text parsing, numerical optimization, Bayesian + networks, and numerous other tasks. In + recent years, much of the development has been focused on creating + a broad set of statistical machine learning tools. However, dlib + remains a general purpose library and welcomes contributions of high + quality software components useful in any domain. +

    + +

    + Core to the development philosophy of dlib is a dedication to + portability and ease of use. Therefore, all code in dlib is designed + to be as portable as possible and similarly to not require a user to + configure or install anything. To help achieve this, all platform + specific code is confined inside the API wrappers. Everything else is + either layered on top of those wrappers or is written in pure ISO + standard C++. Currently the library is known to work on OS X, MS + Windows, Linux, Solaris, the BSDs, and HP-UX. It should work on any + POSIX platform but I haven't had the opportunity to test it on any + others (if you have access to other platforms and would like to help + increase this list then let me know). +

    +

    + The rest of this page explains everything you need to know to get started using the library. It + explains where to find the documentation for each object/function and how to interpret + what you find there. For help compiling with dlib check out the how to compile + page. Or if you are having trouble finding where a particular object's documentation is located you may + be able to find it by consulting the index.

    +

    + The library is also covered by the very liberal Boost Software License + so feel free to use it any way you like. However, if you use dlib in + your research then please cite its Journal of Machine Learning Research paper when + publishing. +

    +

    + Finally, I must give some credit to the Reusable + Software Research Group at Ohio State since they taught me much + of the software engineering techniques used in the creation of this library. +

    + + + + + +

    Notation

    +

    + For the most part I try to document my code in a way that any C++ programmer would understand, + but for the sake of brevity I use some of the following uncommon notation. +

    + +
      +
    • kernel, extension, and abstract +
        + Each component of the library has a specification which defines its core behavior and interface. This + specification defines what is called the component's kernel. Additionally, each component may have any number of + extensions. An extension is essentially a specification for something that layers functionality on top of the + kernel of a component. +
        +
        In the naming of files I use the word abstract to indicate that a file + contains a specification of a kernel component or extension rather than an actual implementation. +
      + + + + +
    • /*! comments like this !*/ +
        + This is just for "formal comments." Generally these appear after a function prototype and contain + the requires/ensures stuff or at the top of a class and tell you general things about the class. +
      + + + + +
    • requires/ensures/throws +
        + These words appear in the formal comment following function prototypes and have the following meanings. +

        requires: This defines a list of requirements for calling the function. These requirements + MUST be met or a call to the function has undefined results. (note that when the checking/debugging modes + are enabled on an object then it will throw the dlib::fatal_error exception with fatal_error::type == EBROKEN_ASSERT when the requires clause is + broken rather than causing "undefined results") + +

        ensures: This defines what the function does. It is a list of conditions that will be + true after the function finishes executing. Note that if an exception is thrown then nothing in the + ensures clause is guaranteed to be true. + +

        throws: This defines what exceptions may be thrown by this function. It generally + tells you why the exception might be thrown. It also tells you what the function does in this event: + Does it have no effect at all? Does it corrupt any objects? etc. + +
        +
        + Sometimes these blocks do not appear in the formal comment. The meanings in these cases are as follows: +
        missing requires: There are no requirements, you may put anything in the function arguments. +
        missing ensures: This means that the effects of the function are unspecified. This is often used + for call backs where the client programmer implements the actual function. +
        missing throws: This doesn't mean anything. A function without a throws block + might throw exceptions or it might not. + +
        +
        + So in summary, the requires clause must always be satisfied, the ensures clause tells you what the + function does when it does not throw or return an error, and the throws clause tells you what happens when the function + does throw. + +
      + + + +
    • meaning_of_hash meaning of # symbol +
        + I use this as a prefix on identifiers to make reference to the value of the identifier "after" + some event has occurred. +

        + The most common place I use this notation is inside the formal comment following a function prototype. + If the # symbol appears in a requires/ensures/throws block then it means the value of + the identifier after the function has finished, otherwise all references to an identifier + refer to its value before the function was called. +

        + An example will make it clear. + + + int funct( + int& something +); +/*! + requires + - something > 4 + ensures + - #some_other_function() == 9 + - #something == something + 1 + - returns something +!*/ + + + This says that funct() requires that "something" be greater than 4, that funct() will increment "something" + by 1, and funct() returns the original value of something. It also says that + after the call to funct() ends a call to some_other_function() will return the value 9. + +
      + +
    • CONVENTION CONVENTION +
        + This is a section of the formal comment which appears at the top of classes which are + actual implementations (as opposed to specifications). This section of the comment contains + a list of invariants that tell you what the member variables are used for. It also relates + the state of the member variables to the class interface. +
        +
        + For example, you might see a line in this section that says "my_size == size()". This just means + that the member variable my_size always contains the value returned by the size() function. +
      + + + + +
    • "initial value for its type" +
        + I frequently say that after a function executes some variable or argument will have an + initial value for its type. This makes sense for objects with a user defined constructor, + but for anything else not so much. Therefore the initial value of a type with no user defined + constructor is undefined. +
      + +
    + + + + + + +

    Organization

    + +

    + The library can be thought of as a collection of components. Each component always consists of + at least two separate files, a specification file and an implementation file. The specification + files are the ones that end with _abstract.h. Each of these specification files don't actually + contain any code and they even have preprocessor directives that prevent any of their contents from + being included. Their purpose is purely to document a component's interface in a file that isn't + cluttered with implementation details the user shouldn't need to know about. +

    + +

    + The next important concept in dlib organization is multi-implementation components. That is, + some components provide more than one implementation of what is defined in their specification. + When you use these components you have to identify them with names like dlib::component::kernel_1a. + Often these components will have just a debugging and non-debugging implementation. However, many components + provide a large number of alternate implementations. For example, the entropy_encoder_model + has 32 different implementations you can choose from. +

    + +
      + +
    • File organization for multi-implementation components +
        + Each component gets its own folder and one file in the root of the directory tree. +

        + I will use the queue object as a typical example and + explain what each of its files contain. + Below is the directory structure and all the files related to the queue component. + +

        +
        • file tree +
            +
          • dlib/ +
              +
            • queue.h +
            • queue/ +
                +
              • queue_kernel_abstract.h +
              • queue_kernel_1.h +
              • queue_kernel_2.h +
              • queue_kernel_c.h +
              • queue_sort_abstract.h +
              • queue_sort_1.h +
              +
            +
          + + +
          + +
        • queue.h +
            This file does not contain any executable code. All it does is define the typedefs such as + kernel_1a, kernel_1a_c, etc. for the queue object. See the Creating Objects + section to learn what these typedefs are for. +
          + +
        • queue_kernel_abstract.h +
            + This file does not contain any code. It even has preprocessor directives that prevent + any of its contents from being included. +
            +
            + The purpose of this file is to define exactly what a queue object does and what its + interface is. +
          + +
        • queue_sort_abstract.h +
            + This file also doesn't contain any code. Its only purpose is to define the sort + extension to queue objects. +
          + +
        • queue_kernel_1.h +
            + This file contains an implementation of the queue kernel specification found + in queue_kernel_abstract.h +
          + +
        • queue_kernel_2.h +
            + This file contains another implementation of the queue kernel specification found + in queue_kernel_abstract.h +
          + +
        • queue_sort_1.h +
            + This file contains an implementation of the queue sort extension specification found + in queue_sort_abstract.h +
          + +
        • queue_kernel_c.h +
            + This file contains a templated class which wraps any implementation of the queue kernel + specification. It is used during debugging to check that the requires clauses are never + violated. +
          +
        +
      +
    + + + + + + + + + + + creating_objects +

    Creating Objects

    + +

    + To create many of the objects in this library you need to choose which kernel implementation you would like and if you + want the checking version or any extensions. +

    +

    + To make this easy there are header files which define typedefs of all this stuff. For + example, to create a queue of ints using queue kernel implementation 1 you would type + dlib::queue<int>::kernel_1a my_queue;. Or to get the debugging/checking version you + would type dlib::queue<int>::kernel_1a_c my_queue;. +

    +

    + There can be a lot of different typedefs for each component. You can find a list of them + in the section for the component in question. For the queue component they can be found + here. +

    +

    + None of the above applies to the single-implementation components, that is, anything that doesn't have an "implementations" + section in its documentation. These tools are designed to have only one implementation and thus do not follow the + above naming convention. For example, to create a + logger object you would simply type dlib::logger mylog("name");. + For the purposes of object creation the API components also appear to be single-implementation. That is, there is no + need to specify which implementation you want since it is automatically determined by which platform you compile under. + Note also that there are no explicit checking versions of these components. However, there are + DLIB_ASSERT statements that perform checking and you can + enable them by #defining DEBUG or ENABLE_ASSERTS. +

    + + + + + +

    Assumptions

    + There are some restrictions on the behavior of certain objects or functions. + Rather than replicating these restrictions all over the place in my documentation they + are listed here. + +
      + +
    • global swap() +
        + It is assumed that this operator does not throw. Undefined behavior results if it does. + Note that std::swap() for all intrinsics and std::string does not throw. +
      + + + +
    • operator<() +
        + It is assumed that this operator (or std::less or any similar functor supplied by you to the library) + does not throw. Undefined behavior results if it does. +
      + + + +
    • dlib::general_hash +
        + It is assumed that general_hash does not throw. Undefined behavior results if it does. + This is actually noted in the general hash spec file but I'm listing it here also for good measure. + +
      + + + +
    + + + + + + + + + thread_safety +

    Thread Safety

    + +

    + In the library there are three kinds of objects with regards to threading: +

      +
    • Objects which are completely thread safe. This means that any pattern of access from + multiple threads is safe.
    • +
    • Objects which are safe to use if no threads touch the same instance, but require access + to a particular instance to be serialized via a mutex if it is shared among threads.
    • +
    • Objects which share some kind of global resource or are reference counted. This kind of object is + extremely thread unfriendly and can only be used in a threaded program with great care.
    • +
    +

    + +

    + How do you know which components/objects are thread safe and which aren't? The rule is that if + the specification for the component doesn't mention threading or thread safety then + it is ok to use as long as you serialize access to shared instances. If the component might have + some global resources or be reference counted then the specifications will tell you this. + Lastly if the component is completely thread safe then the specification will tell you this. +

    +

    + Also note that global functions in dlib are always thread safe. +

    + + + + + + + + +
    diff --git a/ml/dlib/docs/docs/jet.png b/ml/dlib/docs/docs/jet.png new file mode 100644 index 000000000..83bd2a71d Binary files /dev/null and b/ml/dlib/docs/docs/jet.png differ diff --git a/ml/dlib/docs/docs/kernel_1a.txt b/ml/dlib/docs/docs/kernel_1a.txt new file mode 100644 index 000000000..ff9ad371e --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1a.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 86995 4.576 no +play: 125179 75430 4.82062 no +html: 24603 16209 5.27058 no +Csrc: 11150 7084 5.08269 no +list: 3721 2224 4.78151 no +Excl: 1029744 440758 3.42421 no +tech: 426754 248345 4.65552 no +poem: 481861 273394 4.53897 no +fax: 513216 75036 1.16966 no +SPRC: 38240 25660 5.3682 no +man: 4227 2663 5.03998 no + +average: 4.42981 + +time: 875ms + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 72533 5.21534 no +book1: 768771 435527 4.53219 no +book2: 610856 364597 4.7749 no +geo: 102400 72600 5.67188 no +news: 377109 244377 5.18422 no +obj1: 21504 16183 6.02046 no +obj2: 246814 189902 6.15531 no +paper1: 53161 33144 4.98772 no +paper2: 82199 47398 4.613 no +pic: 513216 75036 1.16966 no +progc: 39611 25885 5.22784 no +progl: 71646 42688 4.76655 no +progp: 49379 30180 4.88953 no +trans: 93695 64603 5.51603 no + +average: 4.9089 + +time: 1.11sec + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 7 56 no +aaa: 100000 20 0.0016 no +alphabet: 100000 58912 4.71296 no +random: 100000 75202 6.01616 no + +average: 16.6827 + +time: 93ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 1162352 2.00462 no +bible: 4047392 2194059 4.33674 no +word: 2473400 1542086 4.98774 no + +average: 3.77637 + +time: 3.766sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_1a.xml b/ml/dlib/docs/docs/kernel_1a.xml new file mode 100644 index 000000000..25aba2e14 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1a.xml @@ -0,0 +1,8 @@ + + + + + kernel_1a + + + diff --git a/ml/dlib/docs/docs/kernel_1b.txt b/ml/dlib/docs/docs/kernel_1b.txt new file mode 100644 index 000000000..2d6ff74ff --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1b.txt @@ -0,0 +1,77 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 66165 3.48033 no +play: 125179 54572 3.48761 no +html: 24603 11661 3.79173 no +Csrc: 11150 4733 3.39587 no +list: 3721 1629 3.50228 no +Excl: 1029744 343447 2.66821 no +tech: 426754 188332 3.5305 no +poem: 481861 204240 3.39085 no +fax: 513216 54127 0.843731 no +SPRC: 38240 18307 3.82992 no +man: 4227 2100 3.97445 no + +average: 3.26323 + +time: 844ms + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 48130 3.46069 no +book1: 768771 346572 3.6065 no +book2: 610856 288605 3.77968 no +geo: 102400 61124 4.77531 no +news: 377109 196085 4.15975 no +obj1: 21504 12445 4.62984 no +obj2: 246814 127142 4.12106 no +paper1: 53161 25438 3.82807 no +paper2: 82199 37295 3.62973 no +pic: 513216 54127 0.843731 no +progc: 39611 19090 3.85549 no +progl: 71646 29773 3.32446 no +progp: 49379 20795 3.36904 no +trans: 93695 40922 3.49406 no + +average: 3.6341 + +time: 1.109sec + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 7 56 no +aaa: 100000 20 0.0016 no +alphabet: 100000 83 0.00664 no +random: 100000 77775 6.222 no + +average: 15.5576 + +time: 94ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 1151913 1.98662 no +bible: 4047392 1651476 3.26428 no +word: 2473400 1133090 3.66488 no + +average: 2.97193 + +time: 3.672sec + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_1b.xml b/ml/dlib/docs/docs/kernel_1b.xml new file mode 100644 index 000000000..7642f0c14 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1b.xml @@ -0,0 +1,8 @@ + + + + + kernel_1b + + + diff --git a/ml/dlib/docs/docs/kernel_1c.txt b/ml/dlib/docs/docs/kernel_1c.txt new file mode 100644 index 000000000..0e0a78d0d --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1c.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 51810 2.72525 no +play: 125179 44002 2.8121 no +html: 24603 8602 2.79706 no +Csrc: 11150 3399 2.43874 no +list: 3721 1272 2.73475 no +Excl: 1029744 237165 1.84252 no +tech: 426754 147090 2.75737 no +poem: 481861 169981 2.82208 no +fax: 513216 54230 0.845336 no +SPRC: 38240 15190 3.17782 no +man: 4227 1763 3.33665 no + +average: 2.57179 + +time: 1.031sec + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 37264 2.67939 no +book1: 768771 280052 2.91428 no +book2: 610856 221616 2.90237 no +geo: 102400 62115 4.85273 no +news: 377109 155282 3.29416 no +obj1: 21504 11235 4.17969 no +obj2: 246814 97319 3.15441 no +paper1: 53161 19664 2.95916 no +paper2: 82199 29837 2.90388 no +pic: 513216 54230 0.845336 no +progc: 39611 14610 2.9507 no +progl: 71646 21637 2.41599 no +progp: 49379 14204 2.30122 no +trans: 93695 27848 2.37776 no + +average: 2.90936 + +time: 1.297sec + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 7 56 no +aaa: 100000 18 0.00144 no +alphabet: 100000 65 0.0052 no +random: 100000 90704 7.25632 no + +average: 15.8157 + +time: 203ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 1141437 1.96855 no +bible: 4047392 1263237 2.49689 no +word: 2473400 876621 2.83536 no + +average: 2.4336 + +time: 3.391sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_1c.xml b/ml/dlib/docs/docs/kernel_1c.xml new file mode 100644 index 000000000..1bab1d095 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1c.xml @@ -0,0 +1,8 @@ + + + + + kernel_1c + + + diff --git a/ml/dlib/docs/docs/kernel_1da.txt b/ml/dlib/docs/docs/kernel_1da.txt new file mode 100644 index 000000000..fde5871e0 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1da.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 45580 2.39754 no +play: 125179 42432 2.71176 no +html: 24603 7745 2.51839 no +Csrc: 11150 3165 2.27085 no +list: 3721 1238 2.66165 no +Excl: 1029744 194875 1.51397 no +tech: 426754 111838 2.09653 no +poem: 481861 148110 2.45897 no +fax: 513216 56075 0.874096 no +SPRC: 38240 14248 2.98075 no +man: 4227 1736 3.28555 no + +average: 2.34273 + +time: 1.812sec + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 29161 2.09676 no +book1: 768771 235667 2.4524 no +book2: 610856 165032 2.16132 no +geo: 102400 67663 5.28617 no +news: 377109 128148 2.71853 no +obj1: 21504 10750 3.99926 no +obj2: 246814 82894 2.68685 no +paper1: 53161 17398 2.61816 no +paper2: 82199 26449 2.57414 no +pic: 513216 56075 0.874096 no +progc: 39611 13188 2.6635 no +progl: 71646 17135 1.9133 no +progp: 49379 11764 1.90591 no +trans: 93695 19602 1.67369 no + +average: 2.54458 + +time: 2.36sec + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 6 48 no +aaa: 100000 19 0.00152 no +alphabet: 100000 66 0.00528 no +random: 100000 89652 7.17216 no + +average: 13.7947 + +time: 375ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 1130363 1.94945 no +bible: 4047392 871537 1.72266 no +word: 2473400 589688 1.9073 no + +average: 1.8598 + +time: 4.484sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_1da.xml b/ml/dlib/docs/docs/kernel_1da.xml new file mode 100644 index 000000000..428e2e4d7 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1da.xml @@ -0,0 +1,8 @@ + + + + + kernel_1da + + + diff --git a/ml/dlib/docs/docs/kernel_1db.txt b/ml/dlib/docs/docs/kernel_1db.txt new file mode 100644 index 000000000..648312723 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1db.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 47843 2.51658 no +play: 125179 45069 2.88029 no +html: 24603 7914 2.57334 no +Csrc: 11150 3279 2.35265 no +list: 3721 1281 2.7541 no +Excl: 1029744 210286 1.6337 no +tech: 426754 115826 2.17129 no +poem: 481861 157348 2.61234 no +fax: 513216 56477 0.880362 no +SPRC: 38240 14466 3.02636 no +man: 4227 1780 3.36882 no + +average: 2.43362 + +time: 2.296sec + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 29758 2.13969 no +book1: 768771 247700 2.57762 no +book2: 610856 168694 2.20928 no +geo: 102400 67817 5.2982 no +news: 377109 126675 2.68729 no +obj1: 21504 10871 4.04427 no +obj2: 246814 82948 2.6886 no +paper1: 53161 18113 2.72576 no +paper2: 82199 27700 2.6959 no +pic: 513216 56477 0.880362 no +progc: 39611 13622 2.75115 no +progl: 71646 17263 1.92759 no +progp: 49379 12032 1.94933 no +trans: 93695 19505 1.6654 no + +average: 2.5886 + +time: 2.86sec + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 6 48 no +aaa: 100000 20 0.0016 no +alphabet: 100000 66 0.00528 no +random: 100000 88475 7.078 no + +average: 13.7712 + +time: 531ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 1130433 1.94957 no +bible: 4047392 844807 1.66983 no +word: 2473400 504129 1.63056 no + +average: 1.74999 + +time: 5.266sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_1db.xml b/ml/dlib/docs/docs/kernel_1db.xml new file mode 100644 index 000000000..9d8a10eff --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1db.xml @@ -0,0 +1,8 @@ + + + + + kernel_1db + + + diff --git a/ml/dlib/docs/docs/kernel_1ea.txt b/ml/dlib/docs/docs/kernel_1ea.txt new file mode 100644 index 000000000..1697827d4 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1ea.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 40695 2.14059 no +play: 125179 37421 2.39152 no +html: 24603 6859 2.2303 no +Csrc: 11150 2792 2.00323 no +list: 3721 1084 2.33056 no +Excl: 1029744 156897 1.21892 no +tech: 426754 102805 1.9272 no +poem: 481861 136664 2.26894 no +fax: 513216 51109 0.796686 no +SPRC: 38240 12590 2.63389 no +man: 4227 1530 2.89567 no + +average: 2.07614 + +time: 3.062sec + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 26039 1.87228 no +book1: 768771 218772 2.27659 no +book2: 610856 151985 1.99045 no +geo: 102400 59371 4.63836 no +news: 377109 115334 2.4467 no +obj1: 21504 9832 3.65774 no +obj2: 246814 75065 2.43309 no +paper1: 53161 15263 2.29687 no +paper2: 82199 23368 2.27429 no +pic: 513216 51109 0.796686 no +progc: 39611 11549 2.33248 no +progl: 71646 15297 1.70806 no +progp: 49379 10447 1.69254 no +trans: 93695 17677 1.50932 no + +average: 2.28039 + +time: 4.172sec + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 7 56 no +aaa: 100000 17 0.00136 no +alphabet: 100000 65 0.0052 no +random: 100000 82599 6.60792 no + +average: 15.6536 + +time: 672ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 1130095 1.94899 no +bible: 4047392 848956 1.67803 no +word: 2473400 542983 1.75623 no + +average: 1.79442 + +time: 6.516sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_1ea.xml b/ml/dlib/docs/docs/kernel_1ea.xml new file mode 100644 index 000000000..3117a5092 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1ea.xml @@ -0,0 +1,8 @@ + + + + + kernel_1ea + + + diff --git a/ml/dlib/docs/docs/kernel_1eb.txt b/ml/dlib/docs/docs/kernel_1eb.txt new file mode 100644 index 000000000..0b46d59c3 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1eb.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 40220 2.1156 no +play: 125179 37451 2.39344 no +html: 24603 6819 2.21729 no +Csrc: 11150 2770 1.98744 no +list: 3721 1085 2.33271 no +Excl: 1029744 158436 1.23088 no +tech: 426754 100243 1.87917 no +poem: 481861 136151 2.26042 no +fax: 513216 50374 0.785229 no +SPRC: 38240 12387 2.59142 no +man: 4227 1528 2.89189 no + +average: 2.06232 + +time: 4.875sec + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 25279 1.81764 no +book1: 768771 216389 2.25179 no +book2: 610856 146986 1.92498 no +geo: 102400 58768 4.59125 no +news: 377109 109292 2.31852 no +obj1: 21504 9819 3.6529 no +obj2: 246814 72662 2.3552 no +paper1: 53161 15128 2.27656 no +paper2: 82199 23186 2.25657 no +pic: 513216 50374 0.785229 no +progc: 39611 11434 2.30926 no +progl: 71646 14712 1.64274 no +progp: 49379 10210 1.65414 no +trans: 93695 16885 1.4417 no + +average: 2.23418 + +time: 5.421sec + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 7 56 no +aaa: 100000 17 0.00136 no +alphabet: 100000 64 0.00512 no +random: 100000 81881 6.55048 no + +average: 15.6392 + +time: 907ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 1129161 1.94738 no +bible: 4047392 794809 1.571 no +word: 2473400 454450 1.46988 no + +average: 1.66275 + +time: 7.39sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_1eb.xml b/ml/dlib/docs/docs/kernel_1eb.xml new file mode 100644 index 000000000..e157438d3 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1eb.xml @@ -0,0 +1,8 @@ + + + + + kernel_1eb + + + diff --git a/ml/dlib/docs/docs/kernel_1ec.txt b/ml/dlib/docs/docs/kernel_1ec.txt new file mode 100644 index 000000000..488004d48 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1ec.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 40367 2.12334 no +play: 125179 37785 2.41478 no +html: 24603 6828 2.22022 no +Csrc: 11150 2710 1.94439 no +list: 3721 1084 2.33056 no +Excl: 1029744 162760 1.26447 no +tech: 426754 100488 1.88376 no +poem: 481861 139110 2.30955 no +fax: 513216 50276 0.783701 no +SPRC: 38240 12219 2.55628 no +man: 4227 1533 2.90135 no + +average: 2.06658 + +time: 5.484sec + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 25153 1.80858 no +book1: 768771 220128 2.2907 no +book2: 610856 146040 1.91259 no +geo: 102400 58737 4.58883 no +news: 377109 108774 2.30753 no +obj1: 21504 9823 3.65439 no +obj2: 246814 71148 2.30613 no +paper1: 53161 15116 2.27475 no +paper2: 82199 23346 2.27214 no +pic: 513216 50276 0.783701 no +progc: 39611 11351 2.29249 no +progl: 71646 14125 1.5772 no +progp: 49379 9966 1.61461 no +trans: 93695 16068 1.37194 no + +average: 2.21826 + +time: 6.188sec + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 7 56 no +aaa: 100000 17 0.00136 no +alphabet: 100000 65 0.0052 no +random: 100000 81869 6.54952 no + +average: 15.639 + +time: 1.109sec + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 1160064 2.00068 no +bible: 4047392 760498 1.50319 no +word: 2473400 422419 1.36628 no + +average: 1.62338 + +time: 11.203sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_1ec.xml b/ml/dlib/docs/docs/kernel_1ec.xml new file mode 100644 index 000000000..918273865 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_1ec.xml @@ -0,0 +1,8 @@ + + + + + kernel_1ec + + + diff --git a/ml/dlib/docs/docs/kernel_2a.txt b/ml/dlib/docs/docs/kernel_2a.txt new file mode 100644 index 000000000..3caff6866 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_2a.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 55655 2.9275 no +play: 125179 49648 3.17293 no +html: 24603 8345 2.71349 no +Csrc: 11150 3514 2.52126 no +list: 3721 1379 2.96479 no +Excl: 1029744 72617 0.564156 no +tech: 426754 145403 2.72575 no +poem: 481861 196109 3.25586 no +fax: 513216 49740 0.775346 no +SPRC: 38240 13253 2.77259 no +man: 4227 1857 3.51455 no + +average: 2.53711 + +time: 4.641sec + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 35687 2.566 no +book1: 768771 313140 3.2586 no +book2: 610856 210661 2.7589 no +geo: 102400 61070 4.77109 no +news: 377109 143307 3.04012 no +obj1: 21504 11004 4.09375 no +obj2: 246814 83289 2.69965 no +paper1: 53161 19433 2.9244 no +paper2: 82199 30671 2.98505 no +pic: 513216 49740 0.775346 no +progc: 39611 14142 2.85618 no +progl: 71646 17196 1.92011 no +progp: 49379 12045 1.95144 no +trans: 93695 19849 1.69478 no + +average: 2.73539 + +time: 4.89sec + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 6 48 no +aaa: 100000 58 0.00464 no +alphabet: 100000 87 0.00696 no +random: 100000 77815 6.2252 no + +average: 13.5592 + +time: 235ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 1349400 2.32721 no +bible: 4047392 1206327 2.3844 no +word: 2473400 703195 2.27442 no + +average: 2.32868 + +time: 8.75sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_2a.xml b/ml/dlib/docs/docs/kernel_2a.xml new file mode 100644 index 000000000..4b8c1182a --- /dev/null +++ b/ml/dlib/docs/docs/kernel_2a.xml @@ -0,0 +1,8 @@ + + + + + kernel_2a + + + diff --git a/ml/dlib/docs/docs/kernel_3a.txt b/ml/dlib/docs/docs/kernel_3a.txt new file mode 100644 index 000000000..927c4af7c --- /dev/null +++ b/ml/dlib/docs/docs/kernel_3a.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 118620 6.2395 no +play: 125179 103761 6.63121 no +html: 24603 14067 4.57408 no +Csrc: 11150 5805 4.16502 no +list: 3721 2142 4.60521 no +Excl: 1029744 390978 3.03748 no +tech: 426754 310347 5.81782 no +poem: 481861 435150 7.22449 no +fax: 513216 116298 1.81285 no +SPRC: 38240 23076 4.82762 no +man: 4227 3177 6.01278 no + +average: 4.99528 + +time: 547ms + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 68949 4.95764 no +book1: 768771 683415 7.11177 no +book2: 610856 435186 5.69936 no +geo: 102400 105111 8.2118 no +news: 377109 266382 5.65103 no +obj1: 21504 15714 5.84598 no +obj2: 246814 139671 4.52717 no +paper1: 53161 37476 5.63962 no +paper2: 82199 63198 6.15073 no +pic: 513216 116298 1.81285 no +progc: 39611 25398 5.12948 no +progl: 71646 31905 3.56252 no +progp: 49379 20943 3.39302 no +trans: 93695 34056 2.90782 no + +average: 5.04291 + +time: 610ms + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 9 72 no +aaa: 100000 3771 0.30168 no +alphabet: 100000 486 0.03888 no +random: 100000 112491 8.99928 no + +average: 20.335 + +time: 47ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 4752396 8.1961 no +bible: 4047392 2649303 5.23656 no +word: 2473400 1452177 4.69694 no + +average: 6.0432 + +time: 2.5sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_3a.xml b/ml/dlib/docs/docs/kernel_3a.xml new file mode 100644 index 000000000..2a95695e6 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_3a.xml @@ -0,0 +1,8 @@ + + + + + kernel_3a + + + diff --git a/ml/dlib/docs/docs/kernel_3b.txt b/ml/dlib/docs/docs/kernel_3b.txt new file mode 100644 index 000000000..69f25e132 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_3b.txt @@ -0,0 +1,78 @@ + + + + The Canterbury Corpus + +file size packed size bpb corruption + +text: 152089 111537 5.86693 no +play: 125179 101457 6.48396 no +html: 24603 13914 4.52433 no +Csrc: 11150 5760 4.13274 no +list: 3721 2160 4.64391 no +Excl: 1029744 407466 3.16557 no +tech: 426754 291483 5.46419 no +poem: 481861 417897 6.93805 no +fax: 513216 114138 1.77918 no +SPRC: 38240 23184 4.85021 no +man: 4227 3159 5.97871 no + +average: 4.89343 + +time: 734ms + + + + The Calgary Corpus + +file size packed size bpb corruption + +bib: 111261 63729 4.58231 no +book1: 768771 655227 6.81844 no +book2: 610856 409392 5.36155 no +geo: 102400 108099 8.44523 no +news: 377109 259065 5.49581 no +obj1: 21504 15768 5.86607 no +obj2: 246814 138564 4.49128 no +paper1: 53161 35901 5.40261 no +paper2: 82199 60291 5.86781 no +pic: 513216 114138 1.77918 no +progc: 39611 24984 5.04587 no +progl: 71646 31113 3.47408 no +progp: 49379 20772 3.36532 no +trans: 93695 33093 2.82559 no + +average: 4.9158 + +time: 907ms + + + + The Artificial Corpus + +file size packed size bpb corruption + +a: 1 9 72 no +aaa: 100000 3771 0.30168 no +alphabet: 100000 486 0.03888 no +random: 100000 112509 9.00072 no + +average: 20.3353 + +time: 78ms + + + + The Large Corpus + +file size packed size bpb corruption + +E.coli: 4638690 4747257 8.18724 no +bible: 4047392 2466675 4.87558 no +word: 2473400 1301805 4.21058 no + +average: 5.7578 + +time: 3.25sec + + \ No newline at end of file diff --git a/ml/dlib/docs/docs/kernel_3b.xml b/ml/dlib/docs/docs/kernel_3b.xml new file mode 100644 index 000000000..9571ab419 --- /dev/null +++ b/ml/dlib/docs/docs/kernel_3b.xml @@ -0,0 +1,8 @@ + + + + + kernel_3b + + + diff --git a/ml/dlib/docs/docs/license.xml b/ml/dlib/docs/docs/license.xml new file mode 100644 index 000000000..c47ee788b --- /dev/null +++ b/ml/dlib/docs/docs/license.xml @@ -0,0 +1,36 @@ + + + + + License +
    +
    +
    +Boost Software License - Version 1.0 - August 17th, 2003
    +
    +Permission is hereby granted, free of charge, to any person or organization
    +obtaining a copy of the software and accompanying documentation covered by
    +this license (the "Software") to use, reproduce, display, distribute,
    +execute, and transmit the Software, and to prepare derivative works of the
    +Software, and to permit third-parties to whom the Software is furnished to
    +do so, all subject to the following:
    +
    +The copyright notices in the Software and this entire statement, including
    +the above license grant, this restriction and the following disclaimer,
    +must be included in all copies of the Software, in whole or in part, and
    +all derivative works of the Software, unless such copies or derivative
    +works are solely in the form of machine-executable object code generated by
    +a source language processor.
    +
    +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
    +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
    +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
    +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
    +DEALINGS IN THE SOFTWARE.
    +
    +
    +
    +   
    +
    diff --git a/ml/dlib/docs/docs/linear_algebra.xml b/ml/dlib/docs/docs/linear_algebra.xml new file mode 100644 index 000000000..4d228376b --- /dev/null +++ b/ml/dlib/docs/docs/linear_algebra.xml @@ -0,0 +1,1382 @@ + + + + + Linear Algebra + + + + +

    + This page documents the core linear algebra tools included in dlib. + In particular, the three most important objects in this part of the library are the + matrix, vector, and + rectangle. All the other tools on this page + are functions for manipulating these three objects. A good example and introduction + can be found in + the matrix example program. +

    + +

    + Most of the linear algebra tools deal with dense matrices. However, there is also + a limited amount of support for working with sparse matrices and vectors. + In particular, the dlib tools represent sparse vectors using the containers + in the C++ STL. For details, see the notes at the top of + dlib/svm/sparse_vector_abstract.h. +

    +

    + Finally, note that all the dense matrix tools can be obtained by #including <dlib/matrix.h> + while the sparse vector tools can be obtained by #including <dlib/sparse_vector.h>. The + geometry tools can be used by #including <dlib/geometry.h>. +

    + + + + + + + + +
    + Dense Matrix Tools + matrix + + Basic Math Operators + + + exp + dlib/matrix/matrix_math_functions_abstract.h.html#exp + + + log10 + dlib/matrix/matrix_math_functions_abstract.h.html#log10 + + + log + dlib/matrix/matrix_math_functions_abstract.h.html#log + + + sqrt + dlib/matrix/matrix_math_functions_abstract.h.html#sqrt + + + pow + dlib/matrix/matrix_math_functions_abstract.h.html#pow + + + squared + dlib/matrix/matrix_math_functions_abstract.h.html#squared + + + cubed + dlib/matrix/matrix_math_functions_abstract.h.html#cubed + + + sigmoid + dlib/matrix/matrix_math_functions_abstract.h.html#sigmoid + + + abs + dlib/matrix/matrix_math_functions_abstract.h.html#abs + + + reciprocal + dlib/matrix/matrix_math_functions_abstract.h.html#reciprocal + + + reciprocal_max + dlib/matrix/matrix_math_functions_abstract.h.html#reciprocal_max + + + normalize + dlib/matrix/matrix_math_functions_abstract.h.html#normalize + + + round + dlib/matrix/matrix_math_functions_abstract.h.html#round + + + ceil + dlib/matrix/matrix_math_functions_abstract.h.html#ceil + + + floor + dlib/matrix/matrix_math_functions_abstract.h.html#floor + + + round_zeros + dlib/matrix/matrix_math_functions_abstract.h.html#round_zeros + + + conj + dlib/matrix/matrix_math_functions_abstract.h.html#conj + + + norm + dlib/matrix/matrix_math_functions_abstract.h.html#norm + + + imag + dlib/matrix/matrix_math_functions_abstract.h.html#imag + + + real + dlib/matrix/matrix_math_functions_abstract.h.html#real + + + complex_matrix + dlib/matrix/matrix_math_functions_abstract.h.html#complex_matrix + + + sin + dlib/matrix/matrix_math_functions_abstract.h.html#sin + + + cos + dlib/matrix/matrix_math_functions_abstract.h.html#cos + + + tan + dlib/matrix/matrix_math_functions_abstract.h.html#tan + + + asin + dlib/matrix/matrix_math_functions_abstract.h.html#asin + + + acos + dlib/matrix/matrix_math_functions_abstract.h.html#acos + + + atan + dlib/matrix/matrix_math_functions_abstract.h.html#atan + + + sinh + dlib/matrix/matrix_math_functions_abstract.h.html#sinh + + + cosh + dlib/matrix/matrix_math_functions_abstract.h.html#cosh + + + tanh + dlib/matrix/matrix_math_functions_abstract.h.html#tanh + + + + + Linear Algebra + + + inv + dlib/matrix/matrix_la_abstract.h.html#inv + + + pinv + dlib/matrix/matrix_la_abstract.h.html#pinv + + + svd + dlib/matrix/matrix_la_abstract.h.html#svd + + + svd2 + dlib/matrix/matrix_la_abstract.h.html#svd2 + + + svd3 + dlib/matrix/matrix_la_abstract.h.html#svd3 + + + svd_fast + dlib/matrix/matrix_la_abstract.h.html#svd_fast + + + orthogonalize + dlib/matrix/matrix_la_abstract.h.html#orthogonalize + + + det + dlib/matrix/matrix_la_abstract.h.html#det + + + trace + dlib/matrix/matrix_la_abstract.h.html#trace + + + dot + dlib/matrix/matrix_utilities_abstract.h.html#dot + + + length + dlib/matrix/matrix_utilities_abstract.h.html#length + + + length_squared + dlib/matrix/matrix_utilities_abstract.h.html#length_squared + + + trans + dlib/matrix/matrix_utilities_abstract.h.html#trans + + + diag + dlib/matrix/matrix_utilities_abstract.h.html#diag + + + diagm + dlib/matrix/matrix_utilities_abstract.h.html#diagm + + + lowerm + dlib/matrix/matrix_utilities_abstract.h.html#lowerm + + + upperm + dlib/matrix/matrix_utilities_abstract.h.html#upperm + + + chol + dlib/matrix/matrix_la_abstract.h.html#chol + + + inv_lower_triangular + dlib/matrix/matrix_la_abstract.h.html#inv_lower_triangular + + + inv_upper_triangular + dlib/matrix/matrix_la_abstract.h.html#inv_upper_triangular + + + lu_decomposition + dlib/matrix/matrix_la_abstract.h.html#lu_decomposition + + + qr_decomposition + dlib/matrix/matrix_la_abstract.h.html#qr_decomposition + + + cholesky_decomposition + dlib/matrix/matrix_la_abstract.h.html#cholesky_decomposition + + + eigenvalue_decomposition + dlib/matrix/matrix_la_abstract.h.html#eigenvalue_decomposition + + + real_eigenvalues + dlib/matrix/matrix_la_abstract.h.html#real_eigenvalues + + + + + Conversions + + mat + + matrix_cast + dlib/matrix/matrix_utilities_abstract.h.html#matrix_cast + + + pixel_to_vector + dlib/matrix/matrix_utilities_abstract.h.html#pixel_to_vector + + + vector_to_pixel + dlib/matrix/matrix_utilities_abstract.h.html#vector_to_pixel + + + + + Sub Matrix Expressions + + + range + dlib/matrix/matrix_subexp_abstract.h.html#range + + + subm + dlib/matrix/matrix_subexp_abstract.h.html#subm + + + subm_clipped + dlib/matrix/matrix_subexp_abstract.h.html#subm_clipped + + + rowm + dlib/matrix/matrix_subexp_abstract.h.html#rowm + + + colm + dlib/matrix/matrix_subexp_abstract.h.html#colm + + + set_ptrm + dlib/matrix/matrix_subexp_abstract.h.html#set_ptrm + + + set_subm + dlib/matrix/matrix_subexp_abstract.h.html#set_subm + + + set_colm + dlib/matrix/matrix_subexp_abstract.h.html#set_colm + + + set_rowm + dlib/matrix/matrix_subexp_abstract.h.html#set_rowm + + + + + Statistics + + + sum + dlib/matrix/matrix_utilities_abstract.h.html#sum + + + sum_rows + dlib/matrix/matrix_utilities_abstract.h.html#sum_rows + + + sum_cols + dlib/matrix/matrix_utilities_abstract.h.html#sum_cols + + + prod + dlib/matrix/matrix_utilities_abstract.h.html#prod + + + mean + dlib/matrix/matrix_utilities_abstract.h.html#mean + + + max + dlib/matrix/matrix_utilities_abstract.h.html#max + + + min + dlib/matrix/matrix_utilities_abstract.h.html#min + + + find_min_and_max + dlib/matrix/matrix_utilities_abstract.h.html#find_min_and_max + + + max_point + dlib/matrix/matrix_utilities_abstract.h.html#max_point + + + max_point_interpolated + dlib/matrix/matrix_utilities_abstract.h.html#max_point_interpolated + + + min_point + dlib/matrix/matrix_utilities_abstract.h.html#min_point + + + index_of_min + dlib/matrix/matrix_utilities_abstract.h.html#index_of_min + + + index_of_max + dlib/matrix/matrix_utilities_abstract.h.html#index_of_max + + + variance + dlib/matrix/matrix_utilities_abstract.h.html#variance + + + stddev + dlib/matrix/matrix_utilities_abstract.h.html#stddev + + + covariance + dlib/matrix/matrix_utilities_abstract.h.html#covariance + + + randm + dlib/matrix/matrix_utilities_abstract.h.html#randm + + + gaussian_randm + dlib/matrix/matrix_utilities_abstract.h.html#gaussian_randm + + + + + Other Utilities + + + csv + dlib/matrix/matrix_abstract.h.html#csv + + + fft + dlib/matrix/matrix_fft_abstract.h.html#fft + + + ifft + dlib/matrix/matrix_fft_abstract.h.html#ifft + + + is_col_vector + dlib/matrix/matrix_utilities_abstract.h.html#is_col_vector + + + is_row_vector + dlib/matrix/matrix_utilities_abstract.h.html#is_row_vector + + + is_vector + dlib/matrix/matrix_utilities_abstract.h.html#is_vector + + + is_finite + dlib/matrix/matrix_utilities_abstract.h.html#is_finite + + + const_temp_matrix + dlib/matrix/matrix_abstract.h.html#const_temp_matrix + + + symmetric_matrix_cache + dlib/matrix/symmetric_matrix_cache_abstract.h.html + + + conv + dlib/matrix/matrix_conv_abstract.h.html#conv + + + conv_same + dlib/matrix/matrix_conv_abstract.h.html#conv_same + + + conv_valid + dlib/matrix/matrix_conv_abstract.h.html#conv_valid + + + xcorr_fft + dlib/matrix/matrix_conv_abstract.h.html#xcorr_fft + + + xcorr + dlib/matrix/matrix_conv_abstract.h.html#xcorr + + + xcorr_same + dlib/matrix/matrix_conv_abstract.h.html#xcorr_same + + + xcorr_valid + dlib/matrix/matrix_conv_abstract.h.html#xcorr_valid + + + flip + dlib/matrix/matrix_utilities_abstract.h.html#flip + + + flipud + dlib/matrix/matrix_utilities_abstract.h.html#flipud + + + fliplr + dlib/matrix/matrix_utilities_abstract.h.html#fliplr + + + make_symmetric + dlib/matrix/matrix_utilities_abstract.h.html#make_symmetric + + + ones_matrix + dlib/matrix/matrix_utilities_abstract.h.html#ones_matrix + + + zeros_matrix + dlib/matrix/matrix_utilities_abstract.h.html#zeros_matrix + + + uniform_matrix + dlib/matrix/matrix_utilities_abstract.h.html#uniform_matrix + + + identity_matrix + dlib/matrix/matrix_utilities_abstract.h.html#identity_matrix + + + rotate + dlib/matrix/matrix_utilities_abstract.h.html#rotate + + + reshape_to_column_vector + dlib/matrix/matrix_utilities_abstract.h.html#reshape_to_column_vector + + + reshape + dlib/matrix/matrix_utilities_abstract.h.html#reshape + + + removerc + dlib/matrix/matrix_utilities_abstract.h.html#removerc + + + remove_row + dlib/matrix/matrix_utilities_abstract.h.html#remove_row + + + remove_col + dlib/matrix/matrix_utilities_abstract.h.html#remove_col + + + set_all_elements + dlib/matrix/matrix_utilities_abstract.h.html#set_all_elements + + + hash + dlib/matrix/matrix_utilities_abstract.h.html#hash + + + tmp + dlib/matrix/matrix_utilities_abstract.h.html#tmp + + + equal + dlib/matrix/matrix_utilities_abstract.h.html#equal + + + pointwise_multiply + dlib/matrix/matrix_utilities_abstract.h.html#pointwise_multiply + + + join_rows + dlib/matrix/matrix_utilities_abstract.h.html#join_rows + + + join_cols + dlib/matrix/matrix_utilities_abstract.h.html#join_cols + + + tensor_product + dlib/matrix/matrix_utilities_abstract.h.html#tensor_product + + + scale_columns + dlib/matrix/matrix_utilities_abstract.h.html#scale_columns + + + scale_rows + dlib/matrix/matrix_utilities_abstract.h.html#scale_rows + + + sort_columns + dlib/matrix/matrix_utilities_abstract.h.html#sort_columns + + + rsort_columns + dlib/matrix/matrix_utilities_abstract.h.html#rsort_columns + + + + clamp + dlib/matrix/matrix_utilities_abstract.h.html#clamp + + + lowerbound + dlib/matrix/matrix_utilities_abstract.h.html#lowerbound + + + upperbound + dlib/matrix/matrix_utilities_abstract.h.html#upperbound + + + linspace + dlib/matrix/matrix_utilities_abstract.h.html#linspace + + + linpiece + dlib/matrix/matrix_utilities_abstract.h.html#linpiece + + + logspace + dlib/matrix/matrix_utilities_abstract.h.html#logspace + + + cartesian_product + dlib/matrix/matrix_utilities_abstract.h.html#cartesian_product + + + +
    + +
    + 2D/3D Geometry + border_enumerator + rectangle + drectangle + vector + point + rotate_point + rotate_around_x + rotate_around_y + rotate_around_z + point_rotator + point_transform + camera_transform + point_transform_affine + rectangle_transform + point_transform_affine3d + find_affine_transform + find_similarity_transform + point_transform_projective + find_projective_transform + rotation_matrix + get_rect + centered_rect + set_aspect_ratio + set_rect_area + center + dcenter + shrink_rect + grow_rect + translate_rect + translate_point + resize_rect + resize_rect_width + resize_rect_height + move_rect + nearest_point + nearest_rect + distance_to_rect_edge + clip_line_to_rectangle + distance_to_line +
    + + +
    + Sparse Vector Tools + sparse_to_dense + + + dot + dlib/svm/sparse_vector_abstract.h.html#dot + + + distance_squared + dlib/svm/sparse_vector_abstract.h.html#distance_squared + + + distance + dlib/svm/sparse_vector_abstract.h.html#distance + + + assign + dlib/svm/sparse_vector_abstract.h.html#assign + + + length_squared + dlib/svm/sparse_vector_abstract.h.html#length_squared + + + length + dlib/svm/sparse_vector_abstract.h.html#length + + + scale_by + dlib/svm/sparse_vector_abstract.h.html#scale_by + + + add + dlib/svm/sparse_vector_abstract.h.html#add + + + subtract + dlib/svm/sparse_vector_abstract.h.html#subtract + + + max_index_plus_one + dlib/svm/sparse_vector_abstract.h.html#max_index_plus_one + + + add_to + dlib/svm/sparse_vector_abstract.h.html#add_to + + + subtract_from + dlib/svm/sparse_vector_abstract.h.html#subtract_from + + + min + dlib/svm/sparse_vector_abstract.h.html#min + + + max + dlib/svm/sparse_vector_abstract.h.html#max + + + make_sparse_vector + dlib/svm/sparse_vector_abstract.h.html#make_sparse_vector + + + make_sparse_vector_inplace + dlib/svm/sparse_vector_abstract.h.html#make_sparse_vector_inplace + + + sparse_matrix_vector_multiply + dlib/svm/sparse_vector_abstract.h.html#sparse_matrix_vector_multiply + +
    + +
    +
    + + + + + + + + + + + sparse_to_dense + dlib/sparse_vector.h + dlib/svm/sparse_vector_abstract.h + + This is a set of simple functions that take + sparse vectors + and converts them into equivalent dense vectors. + + + + + + + mat + dlib/matrix.h + dlib/matrix/matrix_mat_abstract.h + + This is a set of simple functions that take objects like std::vector or + array2d and convert them into + matrix objects. Note that the conversion is + done using template expressions so there is no runtime cost associated + with calling mat(). + + + + + + + + matrix + dlib/matrix.h + dlib/matrix/matrix_abstract.h + + This is a 2D matrix object that enables you to write code that deals with + matrices using a simple syntax similar to what can be written in MATLAB. It is implemented using + the expression templates technique which allows it to eliminate the + temporary matrix objects that would normally be returned from expressions + such as M = A+B+C+D; Normally each invocation of the + operator would + construct and return a temporary matrix object but using this technique + we can avoid creating all these temporary objects and receive a large speed boost. +

    + This object is also capable of using BLAS and LAPACK libraries such as ATLAS or the Intel + MKL when available. To enable BLAS support all you have to do is #define + DLIB_USE_BLAS and then make sure you link your application with your + BLAS library. Similarly, to enable LAPACK support just #define DLIB_USE_LAPACK and + link to your LAPACK library. Finally, the use of BLAS and LAPACK is transparent to + the user, that is, the dlib matrix object uses BLAS and LAPACK internally to optimize + various operations while still allowing the user to use a simple MATLAB like syntax. +

    +

    + Note that the cmake files that come with dlib will automatically link with ATLAS or the Intel + MKL if they are installed. So using cmake makes this easy, but by no means are you required + to use cmake or the dlib cmake files. +

    +

    + It is also worth noting that all the preconditions of every function + related to the matrix object are checked by DLIB_ASSERT + statements and thus can be enabled by #defining ENABLE_ASSERTS or DEBUG. Doing + this will cause your program to run slower but should catch any usage errors. +

    +
    + + + matrix_ex.cpp.html + matrix_expressions_ex.cpp.html + + + + + matrix_utilities + dlib/matrix/matrix_utilities_abstract.h + + This extension contains miscellaneous utility functions + for manipulating matrix objects. + + + + + matrix_la + dlib/matrix/matrix_la_abstract.h + + This extension contains linear algebra functions to calculate + QR, LU, Cholesky, eigenvalue, and singular value decompositions. It also + contains a few other miscellaneous functions that solve systems of + equations or calculate values derived from the above decompositions. + + + + + matrix_math_functions + dlib/matrix/matrix_math_functions_abstract.h + This extension contains mathematical functions that operate on each + element of a matrix independently. + + + + + matrix_sub_expressions + dlib/matrix/matrix_subexp_abstract.h + + This extension contains a number of functions for dealing with sub-matrices. + + + + +
    + + + + + border_enumerator + dlib/geometry.h + dlib/geometry/border_enumerator_abstract.h + + This object is an enumerator + over the border points of a + rectangle. + + + + + + + nearest_point + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle and a + point and returns the point in the given + rectangle that is nearest to the given point. + + + + + + + + nearest_rect + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a std::vector<rectangle> and a + point and identifies the rectangle that is nearest to the point. + + + + + + + + distance_to_rect_edge + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle and a + point and returns the Manhattan distance between + the rectangle's edge and the point. + + + + + + + + distance_to_line + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a line and a point + and returns the distance from the line to the point. + + + + + + + + clip_line_to_rectangle + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle and a line segment and + returns the part of the line segment that is entirely contained within the + rectangle. + + + + + + + move_rect + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle and moves + it so that it's upper left corner occupies the given location. + + + + + + + + resize_rect_height + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle and + returns a new rectangle with the given height but otherwise with the + same edge points as the original rectangle. + + + + + + + + resize_rect_width + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle and + returns a new rectangle with the given width but otherwise with the + same edge points as the original rectangle. + + + + + + + + resize_rect + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle and + returns a new rectangle with the given size but with the same upper + left corner as the original rectangle. + + + + + + + + translate_rect + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle and moves + it by a given number of units along the x and y axis relative to + where it was before the move. + + + + + + + + center + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + Returns the center point of a rectangle. + + + + + + + + dcenter + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + Returns the center point of a rectangle. This + is a version of center() which returns a double version + of the point rather than one which uses integers to represent the + result. Therefore, it is slightly more accurate. + + + + + + + + centered_rect + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + There are various overloads of this function but the basic idea is + that it returns a rectangle with a given + width and height and centered about a given point. + + + + + + + + set_aspect_ratio + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function reshapes a rectangle so that + it has a user specified aspect ratio. + + + + + + + + set_rect_area + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function reshapes a rectangle so that + it has a user specified area. + + + + + + + + shrink_rect + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle object, + shrinks its borders by a given amount, and returns the result. + + + + + + + + grow_rect + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This function takes a rectangle object, + grows its borders by a given amount, and returns the result. + + + + + + + + rotate_point + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is a function that rotates a 2D vector or + point object about a given point. + + + + + + + + point_rotator + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is an object that rotates a 2D vector or + point object about the origin. + + + + + + + + point_transform + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is an object that rotates a 2D vector or + point object about the origin and then adds a + displacement vector. + + + + + + + + point_transform_affine + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is an object that applies a 2D affine transformation to a vector or + point. Note that you can use find_affine_transform + to easily create affine transforms from sets of point correspondences. + + + + + + + + rectangle_transform + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is an object that applies a 2D affine transformation + to a rectangle or drectangle. + + + + + + + + camera_transform + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This object maps 3D points into the image plane of a camera. Therefore, + you can use it to compute 2D representations of 3D data from the point of + view of some camera in 3D space. + + + + + + + + point_transform_affine3d + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is an object that applies a 3D affine transformation to a vector. + + + + + + + + find_affine_transform + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is a routine that takes in two sets of points and finds the + best affine transformation + that maps between them. + + + + + + + + find_similarity_transform + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is a routine that takes in two sets of points and finds the + best affine transformation + that maps between them. However, it considers only rotations, translations, + and uniform scale changes in finding the mapping. Therefore, it finds + a similarity transformation rather than a general affine transform. + + + + + + + + point_transform_projective + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is an object that applies a projective transformation to a vector or + point. Note that you can use find_projective_transform + to easily create projective transforms from sets of point correspondences. + + + + + + + + find_projective_transform + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is a routine that takes in two sets of points and finds the + best projective transformation + that maps between them. + + + + + + + + rotation_matrix + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is a method for creating 2D rotation matrices. + + + + + + + + rotate_around_x + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is a method for creating a point_transform_affine3d + that rotates points around the x-axis. + + + + + + + rotate_around_y + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is a method for creating a point_transform_affine3d + that rotates points around the y-axis. + + + + + + + rotate_around_z + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is a method for creating a point_transform_affine3d + that rotates points around the z-axis. + + + + + + + translate_point + dlib/geometry.h + dlib/geometry/point_transforms_abstract.h + + This is a method for creating a point_transform_affine3d + that just translates points. + + + + + + + get_rect + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This is a simple template function that returns a rectangle + representing the size of a 2D container (e.g. matrix or + array2d). + + + + + + + drectangle + dlib/geometry.h + dlib/geometry/drectangle_abstract.h + + This object represents a rectangular region inside a Cartesian + coordinate system. It is very similar to the rectangle + except that it uses double variables instead of longs to represent the location of the rectangle. + Therefore, it can position rectangles with sub-pixel accuracy. + + + + + + + rectangle + dlib/geometry.h + dlib/geometry/rectangle_abstract.h + + This object represents a rectangular region inside a Cartesian + coordinate system. It allows you to easily represent and manipulate + rectangles. + + + + + + + vector + dlib/geometry.h + dlib/geometry/vector_abstract.h + + This object represents a two or three dimensional vector. + +

    If you + want to work with general N-dimensional column vectors then you + should the matrix object. In particular, you + should usually use a matrix with this type: + dlib::matrix<double,0,1>.

    +
    +
    + + + + + point + dlib/geometry.h + dlib/geometry/vector_abstract.h + + This object represents a point inside a Cartesian coordinate system. + Note that a point is simply a typedef for a vector + that is 2D and uses longs to represent coordinate values. + + + + + + +
    + + + + +
    + diff --git a/ml/dlib/docs/docs/main_menu.xml b/ml/dlib/docs/docs/main_menu.xml new file mode 100644 index 000000000..bbd91f975 --- /dev/null +++ b/ml/dlib/docs/docs/main_menu.xml @@ -0,0 +1,665 @@ + + + + + +
    + The Library + + Algorithms + algorithms.html + algorithms.xml + + + Graph Tools + graph_tools.html + graph_tools.xml + + + Optimization + optimization.html + optimization.xml + + + Machine Learning + ml.html + ml.xml + + + Linear Algebra + linear_algebra.html + linear_algebra.xml + + + Bayesian Nets + bayes.html + bayes.xml + + + Containers + containers.html + containers.xml + + + API Wrappers + api.html + api.xml + + + Networking + network.html + network.xml + + + Compression + compression.html + compression.xml + + + Parsing + parsing.html + parsing.xml + + + Image Processing + imaging.html + imaging.xml + + + Metaprogramming + metaprogramming.html + metaprogramming.xml + + + Miscellaneous + other.html + other.xml + +
    + + +
    + Help/Info + + + Home + index.html + + + + + Home + http://dlib.net + + + + Dlib Blog + http://blog.dlib.net + + + + Who uses dlib? + http://sourceforge.net/p/dclib/wiki/Known_users/ + + + + Introduction + intro.html + intro.xml + + + Python API + python/index.html + + + How to compile + compile.html + + + Suggested Books + books.html + + + License + license.html + + + How to contribute + howto_contribute.html + + + FAQ + faq.html + + + Index + term_index.html + + + Examples: Python + + + Global Optimization + global_optimization.py.html + + + Face Clustering + face_clustering.py.html + + + Face Jittering/Augmentation + face_jitter.py.html + + + Face Alignment + face_alignment.py.html + + + Video Object Tracking + correlation_tracker.py.html + + + Binary Classification + svm_binary_classifier.py.html + + + Face Landmark Detection + face_landmark_detection.py.html + + + Face Recognition + face_recognition.py.html + + + Find Candidate Object Locations + find_candidate_object_locations.py.html + + + Train Shape Predictor + train_shape_predictor.py.html + + + Face Detector + face_detector.py.html + + + CNN Face Detector + cnn_face_detector.py.html + + + Train Object Detector + train_object_detector.py.html + + + Sequence Segmenter + sequence_segmenter.py.html + + + Structural Support Vector Machines + svm_struct.py.html + + + SVM-Rank + svm_rank.py.html + + + Linear Assignment Problems + max_cost_assignment.py.html + + + + + Examples: C++ + + + Deep Learning Introduction Part 1 + dnn_introduction_ex.cpp.html + + + Deep Learning Introduction Part 2 + dnn_introduction2_ex.cpp.html + + + Deep Learning Imagenet Classifier + dnn_imagenet_ex.cpp.html + + + Deep Learning Imagenet Trainer + dnn_imagenet_train_ex.cpp.html + + + Deep Learning Inception + dnn_inception_ex.cpp.html + + + Deep Metric Learning Introduction + dnn_metric_learning_ex.cpp.html + + + Deep Metric Learning on Images + dnn_metric_learning_on_images_ex.cpp.html + + + Deep Face Recognition + dnn_face_recognition_ex.cpp.html + + + Deep Learning Semantic Segmentation Trainer + dnn_semantic_segmentation_train_ex.cpp.html + + + Deep Learning Semantic Segmentation + dnn_semantic_segmentation_ex.cpp.html + + + Deep Learning Vehicle Detection + dnn_mmod_find_cars_ex.cpp.html + + + Deep Learning Multi-Class Vehicle Detection + dnn_mmod_find_cars2_ex.cpp.html + + + Deep Learning Vehicle Detection Trainer + dnn_mmod_train_find_cars_ex.cpp.html + + + Deep Learning Face Detection + dnn_mmod_face_detection_ex.cpp.html + + + Deep Learning Dog Hipsterizer + dnn_mmod_dog_hipsterizer.cpp.html + + + Deep Learning Max-Margin Object Detection + dnn_mmod_ex.cpp.html + + + + Random Cropper + random_cropper_ex.cpp.html + + + + Linear Model Predictive Control + mpc_ex.cpp.html + + + Video Object Tracking + video_tracking_ex.cpp.html + + + SQLite + sqlite_ex.cpp.html + + + Hough Transform + hough_transform_ex.cpp.html + + + Webcam Face Pose Estimation + webcam_face_pose_ex.cpp.html + + + Linear Assignment Problems + max_cost_assignment_ex.cpp.html + + + Learning to Track + learning_to_track_ex.cpp.html + + + Structural Support Vector Machines + svm_struct_ex.cpp.html + + + Sequence Segmentation + sequence_segmenter_ex.cpp.html + + + Train Object Detector + train_object_detector.cpp.html + + + One Class Classifiers + one_class_classifiers_ex.cpp.html + + + Parallel For Loops + parallel_for_ex.cpp.html + + + Numerical Integration + integrate_function_adapt_simp_ex.cpp.html + + + SVM-Rank + svm_rank_ex.cpp.html + + + BSP + bsp_ex.cpp.html + + + Assignment Learning + assignment_learning_ex.cpp.html + + + Graph Labeling + graph_labeling_ex.cpp.html + + + Sequence Labeling + sequence_labeler_ex.cpp.html + + + Object Detector + object_detector_ex.cpp.html + + + Object Detector Advanced + object_detector_advanced_ex.cpp.html + + + Running Stats + running_stats_ex.cpp.html + + + Config File Reader + config_reader_ex.cpp.html + + + Member Function Pointer + member_function_pointer_ex.cpp.html + + + Empirical Kernel Map + empirical_kernel_map_ex.cpp.html + + + Linear Manifold Regularizer + linear_manifold_regularizer_ex.cpp.html + + + Kernel RLS Regression + krls_ex.cpp.html + + + Optimization + optimization_ex.cpp.html + + + Non-Linear Least Squares + least_squares_ex.cpp.html + + + Kernel RLS Filtering + krls_filter_ex.cpp.html + + + Kernel Centroid + kcentroid_ex.cpp.html + + + Kernel K-Means Clustering + kkmeans_ex.cpp.html + + + Matrix + matrix_ex.cpp.html + + + Matrix Expressions + matrix_expressions_ex.cpp.html + + + 3D Point Cloud + 3d_point_cloud_ex.cpp.html + + + Image + image_ex.cpp.html + + + FHOG Feature Extraction + fhog_ex.cpp.html + + + FHOG Object Detection + fhog_object_detector_ex.cpp.html + + + Face Detection + face_detection_ex.cpp.html + + + Face Landmark Detection + face_landmark_detection_ex.cpp.html + + + Train Shape Predictor + train_shape_predictor_ex.cpp.html + + + SURF + surf_ex.cpp.html + + + Rank Features + rank_features_ex.cpp.html + + + Relevance Vector Regression + rvm_regression_ex.cpp.html + + + Relevance Vector Classification + rvm_ex.cpp.html + + + Kernel Ridge Regression + krr_regression_ex.cpp.html + + + KRR Classification + krr_classification_ex.cpp.html + + + Nu-Support Vector Machine + svm_ex.cpp.html + + + C-Support Vector Machine + svm_c_ex.cpp.html + + + Using Custom Kernels + using_custom_kernels_ex.cpp.html + + + Support Vector Regression + svr_ex.cpp.html + + + Multiclass Classification + multiclass_classification_ex.cpp.html + + + Custom Trainers + custom_trainer_ex.cpp.html + + + Model Selection + model_selection_ex.cpp.html + + + Online SVM + svm_pegasos_ex.cpp.html + + + Sparse Vectors + svm_sparse_ex.cpp.html + + + Neural Network + mlp_ex.cpp.html + + + Bayesian Network From Disk + bayes_net_from_disk_ex.cpp.html + + + Bayesian Network GUI + bayes_net_gui_ex.cpp.html + + + Bayesian Network + bayes_net_ex.cpp.html + + + Std C++ Allocator + std_allocator_ex.cpp.html + + + HTTP Server + server_http_ex.cpp.html + + + Base64 Encoder + file_to_code_ex.cpp.html + + + Sockstreambuf + sockstreambuf_ex.cpp.html + + + IO Streams Server + server_iostream_ex.cpp.html + + + IO Socket Streams + iosockstream_ex.cpp.html + + + Logger + logger_ex.cpp.html + + + Logger Advanced + logger_ex_2.cpp.html + + + Logger Custom Output + logger_custom_output_ex.cpp.html + + + XML Parser + xml_parser_ex.cpp.html + + + Threads + threads_ex.cpp.html + + + Directory Navigation + dir_nav_ex.cpp.html + + + GUI + gui_api_ex.cpp.html + + + Sockets + sockets_ex.cpp.html + + + Queue + queue_ex.cpp.html + + + Quantum Computing + quantum_computing_ex.cpp.html + + + Bridge + bridge_ex.cpp.html + + + Pipe + pipe_ex.cpp.html + + + Pipe 2 + pipe_ex_2.cpp.html + + + Timer + timer_ex.cpp.html + + + Compress Stream + compress_stream_ex.cpp.html + + + Cmd Line Parser + compress_stream_ex.cpp.html#_top + + + Threaded Object + threaded_object_ex.cpp.html + + + Thread Pool + thread_pool_ex.cpp.html + + + Thread Function + thread_function_ex.cpp.html + + + Multithreaded Object + multithreaded_object_ex.cpp.html + + + +
    + +
    + Current Release + Version: + + Release Notes + release_notes.html + + + Change Log + change_log.html + + +
    + + + +
    Download dlib
    +
    ver.
    +
    + http://dlib.net/files/dlib-.tar.bz2 +
    +
    +
    + + + Last Modified:
    + +
    + +
    + +
    + diff --git a/ml/dlib/docs/docs/metaprogramming.xml b/ml/dlib/docs/docs/metaprogramming.xml new file mode 100644 index 000000000..93d095275 --- /dev/null +++ b/ml/dlib/docs/docs/metaprogramming.xml @@ -0,0 +1,813 @@ + + + + + Metaprogramming + + + + + +

    + This page documents library components that provide metaprogramming sorts of functionality. For + the most part they are useful for putting design by contract checks into code or doing various kinds of + clever things with templates. +

    +

    + For example, you might have a templated function that is templated on a type T and you want to + make sure that T is either a char or wchar_t type. You could place the following into your code + and it would cause the compile to error out when T was set to something other than char or wchar_t. +
    + COMPILE_TIME_ASSERT((is_same_type<T,char>::value || is_same_type<T,wchar_t>::value)); +

    + + + + + + + +
    + Objects + is_pointer_type + is_const_type + is_reference_type + is_same_type + is_float_type + is_convertible + is_complex + is_function + is_signed_type + is_unsigned_type + static_switch + noncopyable + enable_if + is_array2d + is_array + is_graph + is_rand + is_matrix + is_config_reader + is_std_vector + is_pair + is_directed_graph + is_built_in_scalar_type + promote + basic_type + unsigned_type + tabs + tmin + tmax + compile_time_integer_list + make_compile_time_integer_range +
    + +
    + Global Functions + DLIB_ASSERT + DLIB_STACK_TRACE + DLIB_STACK_TRACE_NAMED + get_stack_trace + DLIB_CASSERT + COMPILE_TIME_ASSERT + ASSERT_ARE_SAME_TYPE + DLIB_ASSERT_HAS_STANDARD_LAYOUT + ASSERT_ARE_NOT_SAME_TYPE + _dT + is_same_object + assign_zero_if_built_in_scalar_type + wrap_function + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST +
    + +
    + Other + portability_macros +
    + +
    +
    + + + + + + + + + + + tmin + dlib/algs.h + dlib/algs.h + + This is a template to compute the min of two values at compile time. + + + + + + + + tmax + dlib/algs.h + dlib/algs.h + + This is a template to compute the max of two values at compile time. + + + + + + + + compile_time_integer_list + dlib/metaprogramming.h + dlib/metaprogramming.h + + This is a variadic template that can represent a list of integers. + + + + + + + + make_compile_time_integer_range + dlib/metaprogramming.h + dlib/metaprogramming.h + + This is a variadic template that takes one number, MAX, as input + and creates a compile_time_integer_list + representing the range of integers [1,MAX] inclusive. + + + + + + + tabs + dlib/algs.h + dlib/algs.h + + This is a template to compute the absolute value a number at compile time. + + + + + + + + unsigned_type + dlib/uintn.h + dlib/uintn.h + + This is a template that allows you to obtain the unsigned version + of any integral type. For example, unsigned_type<signed short>::type == + unsigned short. + + + + + + + + static_switch + dlib/algs.h + dlib/algs.h + + To use this template you give it some number of boolean expressions and it + tells you which one of them is true. If more than one of them is true then + it causes a compile time error. It is useful for cases where you want to + specialize a template and you want to specialize it not by + the type of object it gets per say but instead according to the values of some + type traits associated with the various template arguments. A simple example of + this can be seen in the assign_pixel's + implementation which can be found at the bottom of the + dlib/pixel.h file. + + + + + + + + enable_if + dlib/enable_if.h + + This is a family of templates from the Boost C++ libraries that makes it somewhat easier to control + template specialization. For the details see + this page. Note that the header dlib/enable_if.h brings + these templates into the dlib namespace.
    +
    + +
    + + + + + noncopyable + dlib/noncopyable.h + dlib/noncopyable.h + + This is a simple class that makes it easy to declare a non-copyable object. + To use it to make your own class non-copyable just inherit from it. + + + + + + + + is_convertible + dlib/algs.h + dlib/algs.h + + This is a template that can be used to determine if one type is convertible + into another type. + + + + + + + + is_complex + dlib/matrix.h + dlib/matrix/matrix_utilities.h + + This is a template that can be used to determine if a type is a + specialization of std::complex. + + + + + + + + is_same_type + dlib/algs.h + dlib/algs.h + + This is a template where is_same_type<T,U>::value == true when T and U are + the same type and false otherwise. + + + + + + + + is_float_type + dlib/algs.h + dlib/algs.h + + This is a template where is_float_type<T>::value == true when T is + a floating point type (i.e. float, double, or long double) and false otherwise. + + + + + + + + is_same_object + dlib/algs.h + dlib/algs.h + + This is a templated function which checks if both of its arguments are actually + references to the same object. It returns true if they are and false otherwise. + + + + + + + + is_function + dlib/algs.h + dlib/algs.h + + This is a template where is_function<T>::value == true when T is + a function type. + + + + + + + + is_signed_type + dlib/algs.h + dlib/algs.h + + This is a template where is_signed_type<T>::value == true when T is + a signed scalar type and false when it is an unsigned scalar + type. + + + + + + + + is_unsigned_type + dlib/algs.h + dlib/algs.h + + This is a template where is_unsigned_type<T>::value == true when T is + an unsigned scalar type and false when it is a signed scalar + type. + + + + + + + + is_directed_graph + dlib/is_kind.h + dlib/is_kind.h + + This is a template where is_directed_graph<T>::value == true when T + is a directed_graph object. + + + + + + + + is_built_in_scalar_type + dlib/algs.h + dlib/algs.h + + This is a template where is_built_in_scalar_type<T>::value == true when T + is a built in scalar type such as int, char, float, etc. + + + + + + + + promote + dlib/algs.h + dlib/algs.h + + This is a template that takes one of the built in scalar types and gives you another + scalar type that should be big enough to hold sums of values from the original scalar + type. The new scalar type will also always be signed. + +

    + For example, promote<uint16>::type == int32 +

    +
    + +
    + + + + + basic_type + dlib/algs.h + dlib/algs.h + + This is a template that takes a type and strips off any const, volatile, or reference + qualifiers and gives you back the basic underlying type. + +

    + For example, promote<const int&>::type == int +

    +
    + +
    + + + + + + is_std_vector + dlib/is_kind.h + dlib/is_kind.h + + This is a template where is_std_vector<T>::value == true when T + is a std_vector_c or std::vector object. + + + + + + + + is_pair + dlib/is_kind.h + dlib/is_kind.h + + This is a template where is_pair<T>::value == true when T + is a std::pair object. + + + + + + + + + is_matrix + dlib/is_kind.h + dlib/is_kind.h + + This is a template where is_matrix<T>::value == true when T + is a matrix object or some kind + of matrix expression. + + + + + + + + is_config_reader + dlib/is_kind.h + dlib/is_kind.h + + This is a template where is_config_reader<T>::value == true when T + is a config_reader or + config_reader_thread_safe object. + + + + + + + + + is_graph + dlib/is_kind.h + dlib/is_kind.h + + This is a template where is_graph<T>::value == true when T + is a graph object. + + + + + + + + is_array2d + dlib/is_kind.h + dlib/is_kind.h + + This is a template where is_array2d<T>::value == true when T + is an array2d object. + + + + + + + + is_array + dlib/is_kind.h + dlib/is_kind.h + + This is a template where is_array<T>::value == true when T + is an array object. + + + + + + + + is_rand + dlib/is_kind.h + dlib/is_kind.h + + This is a template where is_rand<T>::value == true when T + is a rand object. + + + + + + + + is_reference_type + dlib/algs.h + dlib/algs.h + + This is a template where is_reference_type<T>::value == true when T is a reference + type and false otherwise. + + + + + + + + is_const_type + dlib/algs.h + dlib/algs.h + + This is a template where is_const_type<T>::value == true when T is a const + type and false otherwise. + + + + + + + + + is_pointer_type + dlib/algs.h + dlib/algs.h + + This is a template where is_pointer_type<T>::value == true when T is a pointer + type and false otherwise. + + + + + + + + ASSERT_ARE_NOT_SAME_TYPE + dlib/assert.h + dlib/assert.h + +

    + This is a macro function for debugging. Its form is ASSERT_ARE_NOT_SAME_TYPE(type1, type2). + If type1 and type2 are the same type then the compile will fail. This is sometimes useful + in validating template arguments. +

    +
    + +
    + + + + + ASSERT_ARE_SAME_TYPE + dlib/assert.h + dlib/assert.h + +

    + This is a macro function for debugging. Its form is ASSERT_ARE_SAME_TYPE(type1, type2). + If type1 and type2 are not the same type then the compile will fail. This is sometimes useful + in validating template arguments. +

    +
    + +
    + + + + + DLIB_ASSERT_HAS_STANDARD_LAYOUT + dlib/assert.h + dlib/assert.h + +

    + This macro is meant to cause a compiler error if a type doesn't have a simple + memory layout (like a C struct). In particular, types with simple layouts are + ones which can be copied via memcpy(). +

    + + This was called a POD type in C++03 and in C++0x we are looking to check if + it is a "standard layout type". Once we can use C++0x we can change this macro + to something that uses the std::is_standard_layout type_traits class. + See: http://www2.research.att.com/~bs/C++0xFAQ.html#PODs +
    + +
    + + + + + COMPILE_TIME_ASSERT + dlib/assert.h + dlib/assert.h + +

    + This is a macro function for debugging. Its form is COMPILE_TIME_ASSERT(condition that should + be true). The condition must be a compile time constant and if it is false then the compile + will fail. +

    +
    + +
    + + + + + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST + dlib/algs.h + dlib/algs.h + +

    + The DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST() macro is used to define traits templates + that tell you if a class has a certain member function. For example, to make a + test to see if a class has a public method with the signature void print(int) you + would say: +

    +
    + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, (int)) +
    + Then you can check if a class, T, has this method by looking at the boolean value: +
    + has_print<T>::value +
    + which will be true if the member function is in the T class. + +
    + +
    + + + + + + DLIB_CASSERT + dlib/assert.h + dlib/assert.h + +

    + This is a macro function that is identical to the DLIB_ASSERT macro + except that it is always enabled. Even if _DEBUG, DEBUG and ENABLE_ASSERTS are not defined. +

    +

    + Note that when this macro fails and throws an exception it also calls the global + C function dlib_assert_breakpoint(). This behavior makes it easy to set a debugging + tool to break when DLIB_CASSERT fails by setting a breakpoint on dlib_assert_breakpoint(). +

    +
    + +
    + + + + + + DLIB_ASSERT + dlib/assert.h + dlib/assert.h + +

    + This is a macro function for debugging. Its form is +DLIB_ASSERT(condition that should be true, error message) +or you can omit the error message and call it like: +DLIB_ASSERT(condition that should be true) + If the condition is false DLIB_ASSERT throws an exception of type + dlib::fatal_error with fatal_error::type == EBROKEN_ASSERT. An error message detailing + the nature of the problem is stored in the member variable info which is of type std::string. + Look in the following file for more details. The exception classes are defined + here. +

    +

    + This macro is only enabled if _DEBUG, DEBUG or ENABLE_ASSERTS is defined. Also, if this macro is + enabled then ENABLE_ASSERTS will be defined even if you didn't define it. +

    +

    + Note that when this macro fails and throws an exception it also calls the global + C function dlib_assert_breakpoint(). This behavior makes it easy to set a debugging + tool to break when DLIB_ASSERT fails by setting a breakpoint on dlib_assert_breakpoint(). +

    +
    + +
    + + + + + + DLIB_STACK_TRACE + dlib/assert.h + dlib/stack_trace.h + +

    + This is a preprocessor macro that allows you to tag a function so + that dlib will keep track of it in a function call stack. That is, + you will be able to see a stack trace by calling get_stack_trace + if you put this macro at the top of your functions. +

    +

    + This macro is only enabled if DLIB_ENABLE_STACK_TRACE is defined. If it isn't defined then + this macro doesn't do anything. Also note that when this macro is defined it will + cause DLIB_ASSERT and DLIB_CASSERT + to include a stack trace in their error messages. +

    +
    + +
    + + + + + + DLIB_STACK_TRACE_NAMED + dlib/assert.h + dlib/stack_trace.h + + This is a preprocessor macro just like DLIB_STACK_TRACE + except that it allows you to supply your own string to use as the function name + in the stack trace instead of the one deduced by DLIB_STACK_TRACE. +

    + This macro is only enabled if DLIB_ENABLE_STACK_TRACE is defined. +

    +
    + +
    + + + + + get_stack_trace + dlib/assert.h + dlib/stack_trace.h + + This function allows you to query the current stack trace. +

    + This macro is only enabled if DLIB_ENABLE_STACK_TRACE is defined. +

    +
    + +
    + + + + + + _dT + dlib/algs.h + dlib/algs.h + + This is a macro function for converting a string/character literal to either a char or wchar_t literal. + Its form is _dT(target character type,string or character literal) + + + + + + + + assign_zero_if_built_in_scalar_type + dlib/algs.h + dlib/algs.h + +

    + This function assigns its argument the value of 0 if it is a built in scalar + type according to the is_built_in_scalar_type + template. If it isn't a built in scalar type then it does nothing. +

    +

    + This function is useful for suppressing compiler warnings about uninitialized + types inside of templates that are designed to accept the built in types + as well as user defined classes. +

    + +
    + +
    + + + + + + wrap_function + dlib/algs.h + dlib/algs.h + + This is a template that allows you to turn a global function into a + function object. See the specs for more details. + + + + + + + + + portability_macros + dlib/platform.h + dlib/platform.h + + This file #defines various macros depending on the platform being compiled under. + See the file itself for the specifics. + + + + + + + +
    + + + + +
    + diff --git a/ml/dlib/docs/docs/minus.gif b/ml/dlib/docs/docs/minus.gif new file mode 100644 index 000000000..1deac2fe1 Binary files /dev/null and b/ml/dlib/docs/docs/minus.gif differ diff --git a/ml/dlib/docs/docs/ml.xml b/ml/dlib/docs/docs/ml.xml new file mode 100644 index 000000000..f97e7da57 --- /dev/null +++ b/ml/dlib/docs/docs/ml.xml @@ -0,0 +1,3957 @@ + + + + + Machine Learning + + + + + + +
    +
    +

    + Dlib contains a wide range of machine learning algorithms. All + designed to be highly modular, quick to execute, and simple to use + via a clean and modern C++ API. It is used in a wide range of + applications including robotics, embedded devices, mobile phones, and large + high performance computing environments. If you use dlib in your + research please cite: +

    +
    +Davis E. King. Dlib-ml: A Machine Learning Toolkit. 
    +   Journal of Machine Learning Research, 2009
    +
    +@Article{dlib09,
    +  author = {Davis E. King},
    +  title = {Dlib-ml: A Machine Learning Toolkit},
    +  journal = {Journal of Machine Learning Research},
    +  year = {2009},
    +  volume = {10},
    +  pages = {1755-1758},
    +}
    +         
    + + + + + + + + +

    Primary Algorithms

    +
    + Binary Classification + svm_nu_trainer + svm_c_trainer + svm_c_linear_trainer + svm_c_linear_dcd_trainer + svm_c_ekm_trainer + rvm_trainer + svm_pegasos + train_probabilistic_decision_function +
    +
    + Multiclass Classification + one_vs_one_trainer + one_vs_all_trainer + svm_multiclass_linear_trainer +
    +
    + Regression + mlp + krls + rls + krr_trainer + rr_trainer + svr_trainer + svr_linear_trainer + rvm_regression_trainer + rbf_network_trainer + random_forest_regression_trainer +
    +
    + Structured Prediction + + Problem Instances + + structural_svm_sequence_labeling_problem + structural_svm_object_detection_problem + structural_svm_assignment_problem + structural_svm_graph_labeling_problem + + + + Core Tools + + structural_svm_problem + structural_svm_problem_threaded + svm_struct_controller_node + svm_struct_processing_node + + + structural_object_detection_trainer + structural_sequence_labeling_trainer + structural_sequence_segmentation_trainer + structural_assignment_trainer + structural_track_association_trainer + structural_graph_labeling_trainer + svm_rank_trainer + shape_predictor_trainer +
    +
    + Deep Learning + + + Core Tools + + dnn_trainer + add_layer + add_loss_layer + repeat + add_tag_layer + add_skip_layer + layer + test_layer + resizable_tensor + alias_tensor + + + + Input Layers + + input + input_rgb_image + input_rgb_image_sized + input_rgb_image_pyramid + + EXAMPLE_INPUT_LAYER + dlib/dnn/input_abstract.h.html#EXAMPLE_INPUT_LAYER + + + + + Computational Layers + + + EXAMPLE_COMPUTATIONAL_LAYER + dlib/dnn/layers_abstract.h.html#EXAMPLE_COMPUTATIONAL_LAYER_ + + + fc + dlib/dnn/layers_abstract.h.html#fc_ + + + con + dlib/dnn/layers_abstract.h.html#con_ + + + cont + dlib/dnn/layers_abstract.h.html#cont_ + + + scale + dlib/dnn/layers_abstract.h.html#scale_ + + + extract + dlib/dnn/layers_abstract.h.html#extract_ + + + mult_prev + dlib/dnn/layers_abstract.h.html#mult_prev_ + + + upsample + dlib/dnn/layers_abstract.h.html#upsample_ + + + l2normalize + dlib/dnn/layers_abstract.h.html#l2normalize_ + + + dropout + dlib/dnn/layers_abstract.h.html#dropout_ + + + multiply + dlib/dnn/layers_abstract.h.html#multiply_ + + + bn + dlib/dnn/layers_abstract.h.html#bn_ + + + affine + dlib/dnn/layers_abstract.h.html#affine_ + + + max_pool + dlib/dnn/layers_abstract.h.html#max_pool_ + + + avg_pool + dlib/dnn/layers_abstract.h.html#avg_pool_ + + + relu + dlib/dnn/layers_abstract.h.html#relu_ + + + concat + dlib/dnn/layers_abstract.h.html#concat_ + + + prelu + dlib/dnn/layers_abstract.h.html#prelu_ + + + sig + dlib/dnn/layers_abstract.h.html#sig_ + + + htan + dlib/dnn/layers_abstract.h.html#htan_ + + + softmax_all + dlib/dnn/layers_abstract.h.html#softmax_all_ + + + softmax + dlib/dnn/layers_abstract.h.html#softmax_ + + + add_prev + dlib/dnn/layers_abstract.h.html#add_prev_ + + + inception + dlib/dnn/layers_abstract.h.html#inception + + + + + Loss Layers + + + EXAMPLE_LOSS_LAYER + dlib/dnn/loss_abstract.h.html#EXAMPLE_LOSS_LAYER_ + + + loss_dot + dlib/dnn/loss_abstract.h.html#loss_dot_ + + + loss_epsilon_insensitive + dlib/dnn/loss_abstract.h.html#loss_epsilon_insensitive_ + + + loss_ranking + dlib/dnn/loss_abstract.h.html#loss_ranking_ + + + loss_binary_hinge + dlib/dnn/loss_abstract.h.html#loss_binary_hinge_ + + + loss_binary_log + dlib/dnn/loss_abstract.h.html#loss_binary_log_ + + + loss_multimulticlass_log + dlib/dnn/loss_abstract.h.html#loss_multimulticlass_log_ + + + loss_multiclass_log + dlib/dnn/loss_abstract.h.html#loss_multiclass_log_ + + + loss_multiclass_log_per_pixel + dlib/dnn/loss_abstract.h.html#loss_multiclass_log_per_pixel_ + + + loss_multiclass_log_per_pixel_weighted + dlib/dnn/loss_abstract.h.html#loss_multiclass_log_per_pixel_weighted_ + + + loss_mmod + #loss_mmod_ + + + loss_metric + #loss_metric_ + + + loss_mean_squared + #loss_mean_squared_ + + + loss_mean_squared_per_pixel + dlib/dnn/loss_abstract.h.html#loss_mean_squared_per_pixel_ + + + loss_mean_squared_multioutput + dlib/dnn/loss_abstract.h.html#loss_mean_squared_multioutput_ + + + + + Solvers + + + EXAMPLE_SOLVER + dlib/dnn/solvers_abstract.h.html#EXAMPLE_SOLVER + + + sgd + dlib/dnn/solvers_abstract.h.html#sgd + + + adam + dlib/dnn/solvers_abstract.h.html#adam + + + +
    + +
    + Clustering + pick_initial_centers + kkmeans + find_clusters_using_kmeans + find_clusters_using_angular_kmeans + nearest_center + newman_cluster + spectral_cluster + chinese_whispers + bottom_up_cluster + segment_number_line + modularity +
    +
    + Unsupervised + kcentroid + linearly_independent_subset_finder + empirical_kernel_map + svm_one_class_trainer + vector_normalizer + vector_normalizer_pca + sammon_projection + cca +
    +
    + Semi-Supervised/Metric Learning + linear_manifold_regularizer + discriminant_pca + vector_normalizer_frobmetric + compute_lda_transform +
    +
    + Reinforcement Learning + lspi +
    +
    + Feature Selection + rank_features + sort_basis_vectors + rank_unlabeled_training_samples +
    + +

    Other Tools

    +
    + Validation + cross_validate_trainer + cross_validate_object_detection_trainer + cross_validate_trainer_threaded + cross_validate_multiclass_trainer + cross_validate_regression_trainer + cross_validate_sequence_labeler + cross_validate_sequence_segmenter + cross_validate_assignment_trainer + cross_validate_track_association_trainer + cross_validate_graph_labeling_trainer + cross_validate_ranking_trainer + test_binary_decision_function + test_multiclass_decision_function + test_regression_function + test_object_detection_function + test_sequence_labeler + test_sequence_segmenter + test_assignment_function + test_track_association_function + test_graph_labeling_function + test_ranking_function + test_shape_predictor + average_precision + equal_error_rate + compute_roc_curve +
    + +
    + Trainer Adapters + reduced + reduced2 + batch + probabilistic + verbose_batch + batch_cached + verbose_batch_cached + null_trainer + roc_c1_trainer + roc_c2_trainer +
    + +
    + Kernels + radial_basis_kernel + polynomial_kernel + sigmoid_kernel + linear_kernel + histogram_intersection_kernel + offset_kernel + + sparse_radial_basis_kernel + sparse_polynomial_kernel + sparse_sigmoid_kernel + sparse_linear_kernel + sparse_histogram_intersection_kernel + +
    + +
    + Function Objects + random_forest_regression_function + decision_function + projection_function + distance_function + probabilistic_decision_function + probabilistic_function + normalized_function + one_vs_one_decision_function + multiclass_linear_decision_function + one_vs_all_decision_function + sequence_labeler + sequence_segmenter + assignment_function + track_association_function + graph_labeler + policy +
    + +
    + Data IO + load_image_dataset_metadata + load_image_dataset + save_image_dataset_metadata + load_libsvm_formatted_data + save_libsvm_formatted_data + fix_nonzero_indexing + make_bounding_box_regression_training_data +
    + +
    + Miscellaneous + simplify_linear_decision_function + fill_lisf + randomize_samples + is_binary_classification_problem + is_sequence_labeling_problem + is_sequence_segmentation_problem + is_graph_labeling_problem + is_assignment_problem + is_track_association_problem + is_forced_assignment_problem + approximate_distance_function + is_learning_problem + select_all_distinct_labels + find_gamma_with_big_centroid_gap + compute_mean_squared_distance + kernel_matrix + ranking_pair + is_ranking_problem + count_ranking_inversions + learn_platt_scaling + process_sample + + + +
    + +
    +
    + + + + + + + + + + + + add_layer + dlib/dnn.h + dlib/dnn/core_abstract.h + + In dlib, a deep neural network is composed of 3 main parts. An + input layer, a bunch of + computational layers, + and optionally a + loss layer. The add_layer + class is the central object which adds a computational layer onto an + input layer or an entire network. Therefore, deep neural networks are created + by stacking many layers on top of each other using the add_layer class. +

    + For a tutorial showing how this is accomplished read + the DNN Introduction part 1 and + DNN Introduction part 2. +

    +
    + + dnn_introduction_ex.cpp.html + dnn_introduction2_ex.cpp.html + dnn_inception_ex.cpp.html + dnn_imagenet_ex.cpp.html + dnn_imagenet_train_ex.cpp.html + dnn_mmod_ex.cpp.html + dnn_mmod_find_cars_ex.cpp.html + dnn_mmod_find_cars2_ex.cpp.html + dnn_mmod_train_find_cars_ex.cpp.html + dnn_mmod_face_detection_ex.cpp.html + dnn_mmod_dog_hipsterizer.cpp.html + dnn_metric_learning_ex.cpp.html + dnn_metric_learning_on_images_ex.cpp.html + dnn_face_recognition_ex.cpp.html + dnn_semantic_segmentation_ex.cpp.html + dnn_semantic_segmentation_train_ex.cpp.html + +
    + + + + + dnn_trainer + dlib/dnn.h + dlib/dnn/trainer_abstract.h + + This object is a tool training a deep neural network. +

    + For a tutorial showing how this is accomplished read + the DNN Introduction part 1 and + DNN Introduction part 2. +

    +
    + + dnn_introduction_ex.cpp.html + dnn_introduction2_ex.cpp.html + dnn_inception_ex.cpp.html + dnn_imagenet_ex.cpp.html + dnn_imagenet_train_ex.cpp.html + dnn_mmod_ex.cpp.html + dnn_mmod_train_find_cars_ex.cpp.html + dnn_metric_learning_ex.cpp.html + dnn_metric_learning_on_images_ex.cpp.html + dnn_semantic_segmentation_train_ex.cpp.html + +
    + + + + + add_loss_layer + dlib/dnn.h + dlib/dnn/core_abstract.h + + This object is a tool for stacking a loss layer + on the top of a deep neural network. + + + dnn_introduction_ex.cpp.html + dnn_introduction2_ex.cpp.html + dnn_inception_ex.cpp.html + dnn_imagenet_ex.cpp.html + dnn_imagenet_train_ex.cpp.html + dnn_mmod_ex.cpp.html + dnn_mmod_find_cars_ex.cpp.html + dnn_mmod_train_find_cars_ex.cpp.html + dnn_metric_learning_ex.cpp.html + dnn_metric_learning_on_images_ex.cpp.html + dnn_face_recognition_ex.cpp.html + dnn_mmod_face_detection_ex.cpp.html + dnn_mmod_dog_hipsterizer.cpp.html + dnn_semantic_segmentation_train_ex.cpp.html + + + + + + + repeat + dlib/dnn.h + dlib/dnn/core_abstract.h + + This object adds N copies of a computational layer onto a deep neural network. + It is essentially the same as using add_layer N times, + except that it involves less typing, and for large N, will compile much faster. + + + dnn_introduction2_ex.cpp.html + + + + + + + add_tag_layer + dlib/dnn.h + dlib/dnn/core_abstract.h + + This object is a tool for tagging layers in a deep neural network. These tags make it + easy to refer to the tagged layer in other parts of your code. + Specifically, this object adds a new layer onto a deep neural network. + However, this layer simply performs the identity transform. + This means it is a no-op and its presence does not change the + behavior of the network. It exists solely to be used by add_skip_layer or layer() to reference a + particular part of a network. + +

    + For a tutorial showing how to use tagging see the + dnn_introduction2_ex.cpp + example program. +

    +
    + + dnn_introduction2_ex.cpp.html + +
    + + + + + add_skip_layer + dlib/dnn.h + dlib/dnn/core_abstract.h + + This object adds a new layer to a deep neural network which draws its input + from a tagged layer rather than from + the immediate predecessor layer as is normally done. + +

    + For a tutorial showing how to use tagging see the + dnn_introduction2_ex.cpp + example program. +

    +
    +
    + + + + + layer + dlib/dnn.h + dlib/dnn/core_abstract.h + + This global function references a tagged layer + inside a deep neural network object. + +

    + For a tutorial showing how to use tagging see the + dnn_introduction2_ex.cpp + example program. +

    +
    + + dnn_introduction2_ex.cpp.html + +
    + + + + + input + dlib/dnn.h + dlib/dnn/input_abstract.h + + This is a simple input layer type for use in a deep neural network which + takes some kind of image as input and loads it into a network. + + + dnn_introduction_ex.cpp.html + dnn_introduction2_ex.cpp.html + dnn_inception_ex.cpp.html + dnn_imagenet_ex.cpp.html + dnn_imagenet_train_ex.cpp.html + + + + + + + input_rgb_image + dlib/dnn.h + dlib/dnn/input_abstract.h + + This is a simple input layer type for use in a deep neural network + which takes an RGB image as input and loads it into a network. It + is very similar to the input layer except that + it allows you to subtract the average color value from each color + channel when converting an image to a tensor. + + + + + + + input_rgb_image_sized + dlib/dnn.h + dlib/dnn/input_abstract.h + + This layer has an interface and behavior identical to input_rgb_image + except that it requires input images to have a particular size. + + + + + + + input_rgb_image_pyramid + dlib/dnn.h + dlib/dnn/input_abstract.h + + This input layer works with RGB images of type matrix<rgb_pixel>. It is + identical to input_rgb_image except that it + outputs a tensor containing a tiled image pyramid + of each input image rather than a simple copy of each image. + This input layer is meant to be used with a loss layer such as the MMOD loss layer. + + + dnn_mmod_ex.cpp.html + dnn_mmod_find_cars_ex.cpp.html + dnn_mmod_find_cars2_ex.cpp.html + dnn_mmod_train_find_cars_ex.cpp.html + dnn_mmod_face_detection_ex.cpp.html + dnn_mmod_dog_hipsterizer.cpp.html + + + + + + + loss_mmod_ + dlib/dnn.h + dlib/dnn/loss_abstract.h + + This object is a loss layer + for a deep neural network. In particular, it implements the Max Margin Object Detection + loss defined in the paper: +
    Max-Margin Object Detection by Davis E. King.
    + + This means you use this loss if you want to detect the locations of objects + in images. For example, here are some videos that uses loss_mmod to find cars: + +
    +
    +
    + +
    + + dnn_mmod_ex.cpp.html + dnn_mmod_find_cars_ex.cpp.html + dnn_mmod_find_cars2_ex.cpp.html + dnn_mmod_train_find_cars_ex.cpp.html + dnn_mmod_face_detection_ex.cpp.html + dnn_mmod_dog_hipsterizer.cpp.html + cnn_face_detector.py.html + +
    + + + + + loss_metric_ + dlib/dnn.h + dlib/dnn/loss_abstract.h + + This object is a loss layer + for a deep neural network. In particular, it allows you to learn to map objects + into a vector space where objects sharing the same class label are close to + each other, while objects with different labels are far apart. + + + dnn_metric_learning_ex.cpp.html + dnn_metric_learning_on_images_ex.cpp.html + dnn_face_recognition_ex.cpp.html + face_recognition.py.html + face_clustering.py.html + + + + + + + loss_mean_squared_ + dlib/dnn.h + dlib/dnn/loss_abstract.h + + This object is a loss layer + for a deep neural network. In particular, it implements the mean squared loss, which is + appropriate for regression problems. + + + + + + + loss_mean_squared_multioutput_ + dlib/dnn.h + dlib/dnn/loss_abstract.h + + This object is a loss layer + for a deep neural network. In particular, it implements the mean squared loss, which is + appropriate for regression problems. It is identical to the loss_mean_squared_ + loss except this version supports multiple output values. + + + + + + + test_layer + dlib/dnn.h + dlib/dnn/core_abstract.h + + This is a function which tests if a layer object correctly implements + the documented contract + for a computational layer in a deep neural network. + + + + + + + resizable_tensor + dlib/dnn.h + dlib/dnn/tensor_abstract.h + + This object represents a 4D array of float values, all stored contiguously + in memory. Importantly, it keeps two copies of the floats, one on the host + CPU side and another on the GPU device side. It automatically performs the + necessary host/device transfers to keep these two copies of the data in + sync. + +

    + All transfers to the device happen asynchronously with respect to the + default CUDA stream so that CUDA kernel computations can overlap with data + transfers. However, any transfers from the device to the host happen + synchronously in the default CUDA stream. Therefore, you should perform + all your CUDA kernel launches on the default stream so that transfers back + to the host do not happen before the relevant computations have completed. +

    + +

    + If DLIB_USE_CUDA is not #defined then this object will not use CUDA at all. + Instead, it will simply store one host side memory block of floats. +

    + +

    + Finally, the convention in dlib code is to interpret the tensor as a set of + num_samples() 3D arrays, each of dimension k() by nr() by nc(). Also, + while this class does not specify a memory layout, the convention is to + assume that indexing into an element at coordinates (sample,k,nr,nc) can be + accomplished via: + host()[((sample*t.k() + k)*t.nr() + nr)*t.nc() + nc] +

    + +
    +
    + + + + + alias_tensor + dlib/dnn.h + dlib/dnn/tensor_abstract.h + + This object is a tensor that + aliases another tensor. That is, it doesn't have its own block of + memory but instead simply holds pointers to the memory of another + tensor object. It therefore allows you to efficiently break a tensor + into pieces and pass those pieces into functions. + + + + + + + modularity + dlib/clustering.h + dlib/clustering/modularity_clustering_abstract.h + + This function computes the modularity of a particular graph clustering. This + is a number that tells you how good the clustering is. In particular, it + is the measure optimized by the newman_cluster + routine. + + + + + + + + newman_cluster + dlib/clustering.h + dlib/clustering/modularity_clustering_abstract.h + + This function performs the clustering algorithm described in the paper +
    Modularity and community structure in networks by M. E. J. Newman.
    + In particular, this is a method for automatically clustering the nodes in a + graph into groups. The method is able to automatically determine the number + of clusters and does not have any parameters. In general, it is a very good + clustering technique. +
    + +
    + + + + + spectral_cluster + dlib/clustering.h + dlib/clustering/spectral_cluster_abstract.h + + This function performs the clustering algorithm described in the paper +
    On spectral clustering: Analysis and an algorithm by Ng, Jordan, and Weiss.
    +
    + + kkmeans_ex.cpp.html + + +
    + + + + + bottom_up_cluster + dlib/clustering.h + dlib/clustering/bottom_up_cluster_abstract.h + + This function runs a bottom up agglomerative clustering algorithm. + + + + + + + + segment_number_line + dlib/clustering.h + dlib/clustering/bottom_up_cluster_abstract.h + + This routine clusters real valued scalars in essentially linear time. + It uses a combination of bottom up clustering and a simple greedy scan + to try and find the most compact set of ranges that contain all + given scalar values. + + + + + + + + chinese_whispers + dlib/clustering.h + dlib/clustering/chinese_whispers_abstract.h + + This function performs the clustering algorithm described in the paper +
    Chinese Whispers - an Efficient Graph Clustering Algorithm and its + Application to Natural Language Processing Problems by Chris Biemann.
    + In particular, this is a method for automatically clustering the nodes in a + graph into groups. The method is able to automatically determine the number + of clusters. +
    + + dnn_face_recognition_ex.cpp.html + face_clustering.py.html + + +
    + + + + + find_clusters_using_kmeans + dlib/clustering.h + dlib/svm/kkmeans_abstract.h + + This is a simple linear kmeans clustering implementation. + It uses Euclidean distance to compare samples. + + + + + + + + find_clusters_using_angular_kmeans + dlib/clustering.h + dlib/svm/kkmeans_abstract.h + + This is a simple linear kmeans clustering implementation. + To compare a sample to a cluster, it measures the angle between them + with respect to the origin. Therefore, it tries to find clusters + of points that all have small angles between each cluster member. + + + + + + + + nearest_center + dlib/clustering.h + dlib/svm/kkmeans_abstract.h + + This function takes a list of cluster centers and a query vector + and identifies which cluster center is nearest to the query vector. + + + + + + + pick_initial_centers + dlib/clustering.h + dlib/svm/kkmeans_abstract.h + + This is a function that you can use to seed data clustering algorithms + like the kkmeans clustering method. What it + does is pick reasonable starting points for clustering by basically + trying to find a set of points that are all far away from each other. + + + kkmeans_ex.cpp.html + + + + + + + + ranking_pair + dlib/svm.h + dlib/svm/ranking_tools_abstract.h + + This object is used to contain a ranking example. Therefore, ranking_pair + objects are used to represent training examples for learning-to-rank tasks, + such as those used by the svm_rank_trainer. + + + svm_rank_ex.cpp.html + svm_rank.py.html + + + + + + + + kernel_matrix + dlib/svm.h + dlib/svm/kernel_matrix_abstract.h + + This is a simple set of functions that makes it easy to turn a kernel + object and a set of samples into a kernel matrix. It takes these two + things and returns a matrix expression + that represents the kernel matrix. + + + + + + + + is_ranking_problem + dlib/svm.h + dlib/svm/ranking_tools_abstract.h + + This function takes a set of training data for a learning-to-rank problem + and reports back if it could possibly be a well formed problem. + + + + + + + + count_ranking_inversions + dlib/svm.h + dlib/svm/ranking_tools_abstract.h + + Given two sets of objects, X and Y, and an ordering relationship defined + between their elements, this function counts how many times we see an element + in the set Y ordered before an element in the set X. Additionally, this + routine executes efficiently in O(n*log(n)) time via the use of quick sort. + + + + + + + + mlp + dlib/mlp.h + dlib/mlp/mlp_kernel_abstract.h + +

    + This object represents a multilayer layer perceptron network that is + trained using the back propagation algorithm. The training algorithm also + incorporates the momentum method. That is, each round of back propagation + training also adds a fraction of the previous update. This fraction + is controlled by the momentum term set in the constructor. +

    +

    + It is worth noting that a MLP is, in general, very inferior to modern + kernel algorithms such as the support vector machine. So if you haven't + tried any other techniques with your data you really should. +

    +
    + + + mlp_ex.cpp.html + + + + + mlp_kernel_1 + dlib/mlp/mlp_kernel_1.h + + This is implemented in the obvious way. + + + + + kernel_1a + is a typedef for mlp_kernel_1 + + + + + + + +
    + + + + + krls + dlib/svm.h + dlib/svm/krls_abstract.h + + This is an implementation of the kernel recursive least squares algorithm + described in the paper The Kernel Recursive Least Squares Algorithm by Yaakov Engel. +

    + The long and short of this algorithm is that it is an online kernel based + regression algorithm. You give it samples (x,y) and it learns the function + f(x) == y. For a detailed description of the algorithm read the above paper. +

    +

    + Note that if you want to use the linear kernel then you would + be better off using the rls object as it + is optimized for this case. +

    +
    + + + krls_ex.cpp.html + krls_filter_ex.cpp.html + + +
    + + + + + rls + dlib/svm.h + dlib/svm/rls_abstract.h + + This is an implementation of the linear version of the recursive least + squares algorithm. It accepts training points incrementally and, at + each step, maintains the solution to the following optimization problem: +
    + find w minimizing: 0.5*dot(w,w) + C*sum_i(y_i - trans(x_i)*w)^2 +
    + Where (x_i,y_i) are training pairs. x_i is some vector and y_i is a target + scalar value. +
    + +
    + + + + + svm_pegasos + dlib/svm.h + dlib/svm/pegasos_abstract.h + + This object implements an online algorithm for training a support + vector machine for solving binary classification problems. + +

    + The implementation of the Pegasos algorithm used by this object is based + on the following excellent paper: +

    + Pegasos: Primal estimated sub-gradient solver for SVM (2007) + by Shai Shalev-Shwartz, Yoram Singer, Nathan Srebro + In ICML +
    +

    +

    + This SVM training algorithm has two interesting properties. First, the + pegasos algorithm itself converges to the solution in an amount of time + unrelated to the size of the training set (in addition to being quite fast + to begin with). This makes it an appropriate algorithm for learning from + very large datasets. Second, this object uses the kcentroid object + to maintain a sparse approximation of the learned decision function. + This means that the number of support vectors in the resulting decision + function is also unrelated to the size of the dataset (in normal SVM + training algorithms, the number of support vectors grows approximately + linearly with the size of the training set). +

    +

    + However, if you are considering using svm_pegasos, you should also try the + svm_c_linear_trainer for linear + kernels or svm_c_ekm_trainer for non-linear + kernels since these other trainers are, usually, faster and easier to use + than svm_pegasos. +

    +
    + + + svm_pegasos_ex.cpp.html + svm_sparse_ex.cpp.html + svm_binary_classifier.py.html + +
    + + + + + + kkmeans + dlib/clustering.h + dlib/svm/kkmeans_abstract.h + + This is an implementation of a kernelized k-means clustering algorithm. + It performs k-means clustering by using the kcentroid object. +

    + If you want to use the linear kernel (i.e. do a normal k-means clustering) then you + should use the find_clusters_using_kmeans routine. +

    +
    + + + kkmeans_ex.cpp.html + + +
    + + + + + + vector_normalizer + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This object represents something that can learn to normalize a set + of column vectors. In particular, normalized column vectors should + have zero mean and a variance of one. + + + + svm_ex.cpp.html + + + + + + + + vector_normalizer_frobmetric + dlib/statistics.h + dlib/statistics/vector_normalizer_frobmetric_abstract.h + + This object is a tool for performing the FrobMetric distance metric + learning algorithm described in the following paper: +
    + A Scalable Dual Approach to Semidefinite Metric Learning + By Chunhua Shen, Junae Kim, Lei Wang, in CVPR 2011 +
    + Therefore, this object is a tool that takes as input training triplets + (anchor, near, far) of vectors and attempts to learn a linear + transformation T such that: +
    length(T*anchor-T*near) + 1 < length(T*anchor - T*far)
    + That is, you give a bunch of anchor vectors and for each anchor vector you + specify some vectors which should be near to it and some that should be far + form it. This object then tries to find a transformation matrix that makes + the "near" vectors close to their anchors while the "far" vectors are + farther away. +
    +
    + + + + + compute_lda_transform + dlib/statistics.h + dlib/statistics/lda_abstract.h + + This function performs the dimensionality reducing version of linear + discriminant analysis. That is, you give it a set of labeled vectors and it + returns a linear transform that maps the input vectors into a new space that + is good for distinguishing between the different classes. + + + + + + + discriminant_pca + dlib/statistics.h + dlib/statistics/dpca_abstract.h + + This object implements the Discriminant PCA technique described in the paper: +
    + A New Discriminant Principal Component Analysis Method with Partial Supervision (2009) + by Dan Sun and Daoqiang Zhang +
    + This algorithm is basically a straightforward generalization of the classical PCA + technique to handle partially labeled data. It is useful if you want to learn a linear + dimensionality reduction rule using a bunch of data that is partially labeled. +
    + +
    + + + + + sammon_projection + dlib/statistics.h + dlib/statistics/sammon_abstract.h + + This is a function object that computes the Sammon projection of a set + of N points in a L-dimensional vector space onto a d-dimensional space + (d < L), according to the paper: +
    + A Nonlinear Mapping for Data Structure Analysis (1969) by J.W. Sammon +
    +
    + +
    + + + + + cca + dlib/statistics.h + dlib/statistics/cca_abstract.h + + This function performs a canonical correlation analysis between two sets + of vectors. Additionally, it is designed to be very fast, even for large + datasets of over a million high dimensional vectors. + + + + + + + + vector_normalizer_pca + dlib/statistics.h + dlib/statistics/statistics_abstract.h + + This object represents something that can learn to normalize a set + of column vectors. In particular, normalized column vectors should + have zero mean and a variance of one. + + This object also uses principal component analysis for the purposes + of reducing the number of elements in a vector. + + + + + + + + linearly_independent_subset_finder + dlib/svm.h + dlib/svm/linearly_independent_subset_finder_abstract.h + +

    + This is an implementation of an online algorithm for recursively finding a + set (aka dictionary) of linearly independent vectors in a kernel induced + feature space. To use it you decide how large you would like the dictionary + to be and then you feed it sample points. +

    +

    + The implementation uses the Approximately Linearly Dependent metric described + in the paper The Kernel Recursive Least Squares Algorithm by Yaakov Engel to + decide which points are more linearly independent than others. The metric is + simply the squared distance between a test point and the subspace spanned by + the set of dictionary vectors. +

    +

    + Each time you present this object with a new sample point + it calculates the projection distance and if it is sufficiently large then this + new point is included into the dictionary. Note that this object can be configured + to have a maximum size. Once the max dictionary size is reached each new point + kicks out a previous point. This is done by removing the dictionary vector that + has the smallest projection distance onto the others. That is, the "least linearly + independent" vector is removed to make room for the new one. +

    +
    + + empirical_kernel_map_ex.cpp.html + + +
    + + + + + + fill_lisf + dlib/svm.h + dlib/svm/linearly_independent_subset_finder_abstract.h + + This is a simple function for filling a + linearly_independent_subset_finder + with data points by using random sampling. + + + empirical_kernel_map_ex.cpp.html + + + + + + + + + sort_basis_vectors + dlib/svm.h + dlib/svm/sort_basis_vectors_abstract.h + + A kernel based learning method ultimately needs to select a set of basis functions + represented by a particular choice of kernel and a set of basis vectors. + sort_basis_vectors() is a function which attempts to perform supervised + basis set selection. In particular, you give it a candidate set of basis + vectors and it sorts them according to how useful they are for solving + a particular decision problem. + + + + + + + + rank_unlabeled_training_samples + dlib/svm.h + dlib/svm/active_learning_abstract.h + + This routine implements an active learning method for selecting the most + informative data sample to label out of a set of unlabeled samples. + In particular, it implements the MaxMin Margin and Ratio Margin methods + described in the paper: +
    + Support Vector Machine Active Learning with Applications to Text Classification + by Simon Tong and Daphne Koller. +
    +
    +
    + + + + + + linear_manifold_regularizer + dlib/manifold_regularization.h + dlib/manifold_regularization/linear_manifold_regularizer_abstract.h + +

    + Many learning algorithms attempt to minimize a function that, at a high + level, looks like this: +

    +   f(w) == complexity + training_set_error
    +
    +

    + +

    + The idea is to find the set of parameters, w, that gives low error on + your training data but also is not "complex" according to some particular + measure of complexity. This strategy of penalizing complexity is + usually called regularization. +

    + +

    + In the above setting, all the training data consists of labeled samples. + However, it would be nice to be able to benefit from unlabeled data. + The idea of manifold regularization is to extract useful information from + unlabeled data by first defining which data samples are "close" to each other + (perhaps by using their 3 nearest neighbors) + and then adding a term to + the above function that penalizes any decision rule which produces + different outputs on data samples which we have designated as being close. +

    + +

    + It turns out that it is possible to transform these manifold regularized learning + problems into the normal form shown above by applying a certain kind of + preprocessing to all our data samples. Once this is done we can use a + normal learning algorithm, such as the svm_c_linear_trainer, + on just the + labeled data samples and obtain the same output as the manifold regularized + learner would have produced. +

    + +

    + The linear_manifold_regularizer is a tool for creating this preprocessing + transformation. In particular, the transformation is linear. That is, it + is just a matrix you multiply with all your samples. For a more detailed + discussion of this topic you should consult the following paper. In + particular, see section 4.2. This object computes the inverse T matrix + described in that section. +

    + Linear Manifold Regularization for Large Scale Semi-supervised Learning + by Vikas Sindhwani, Partha Niyogi, and Mikhail Belkin +
    +

    + +
    + + linear_manifold_regularizer_ex.cpp.html + +
    + + + + + + empirical_kernel_map + dlib/svm.h + dlib/svm/empirical_kernel_map_abstract.h + +

    + This object represents a map from objects of sample_type (the kind of object + a kernel function + operates on) to finite dimensional column vectors which + represent points in the kernel feature space defined by whatever kernel + is used with this object. +

    + +

    + To use the empirical_kernel_map you supply it with a particular kernel and a set of + basis samples. After that you can present it with new samples and it will project + them into the part of kernel feature space spanned by your basis samples. +

    + +

    + This means the empirical_kernel_map is a tool you can use to very easily kernelize + any algorithm that operates on column vectors. All you have to do is select a + set of basis samples and then use the empirical_kernel_map to project all your + data points into the part of kernel feature space spanned by those basis samples. + Then just run your normal algorithm on the output vectors and it will be effectively + kernelized. +

    + +

    + Regarding methods to select a set of basis samples, if you are working with only a + few thousand samples then you can just use all of them as basis samples. + Alternatively, the + linearly_independent_subset_finder + often works well for selecting a basis set. I also find that picking a + random subset typically works well. +

    +
    + + empirical_kernel_map_ex.cpp.html + linear_manifold_regularizer_ex.cpp.html + +
    + + + + + + + kcentroid + dlib/svm.h + dlib/svm/kcentroid_abstract.h + + + This object represents a weighted sum of sample points in a kernel induced + feature space. It can be used to kernelize any algorithm that requires only + the ability to perform vector addition, subtraction, scalar multiplication, + and inner products. + +

    + An example use of this object is as an online algorithm for recursively estimating + the centroid of a sequence of training points. This object then allows you to + compute the distance between the centroid and any test points. So you can use + this object to predict how similar a test point is to the data this object has + been trained on (larger distances from the centroid indicate dissimilarity/anomalous + points). +

    + +

    + The object internally keeps a set of "dictionary vectors" + that are used to represent the centroid. It manages these vectors using the + sparsification technique described in the paper The Kernel Recursive Least + Squares Algorithm by Yaakov Engel. This technique allows us to keep the + number of dictionary vectors down to a minimum. In fact, the object has a + user selectable tolerance parameter that controls the trade off between + accuracy and number of stored dictionary vectors. +

    + +
    + + + kcentroid_ex.cpp.html + + +
    + + + + + + train_probabilistic_decision_function + dlib/svm.h + dlib/svm/svm_abstract.h + +

    + Trains a probabilistic_function using + some sort of binary classification trainer object such as the svm_nu_trainer or + krr_trainer. +

    + The probability model is created by using the technique described in the following papers: +
    + Probabilistic Outputs for Support Vector Machines and + Comparisons to Regularized Likelihood Methods by + John C. Platt. March 26, 1999 +
    +
    + A Note on Platt's Probabilistic Outputs for Support Vector Machines + by Hsuan-Tien Lin, Chih-Jen Lin, and Ruby C. Weng +
    +
    + + svm_ex.cpp.html + + +
    + + + + + learn_platt_scaling + dlib/svm.h + dlib/svm/svm_abstract.h + + +

    + This function is an implementation of the algorithm described in the following + papers: +

    + Probabilistic Outputs for Support Vector Machines and Comparisons to + Regularized Likelihood Methods by John C. Platt. March 26, 1999 +
    +
    + A Note on Platt's Probabilistic Outputs for Support Vector Machines + by Hsuan-Tien Lin, Chih-Jen Lin, and Ruby C. Weng +
    +

    +

    + This function is the tool used to implement the + train_probabilistic_decision_function routine. +

    + +
    + +
    + + + + + probabilistic + dlib/svm.h + dlib/svm/svm_abstract.h + + This is a trainer adapter which simply runs the trainer it is given though the + train_probabilistic_decision_function + function. + + + + + + + + rbf_network_trainer + dlib/svm.h + dlib/svm/rbf_network_abstract.h + + Trains a radial basis function network and outputs a decision_function. + This object can be used for either regression or binary classification problems. + It's worth pointing out that this object is essentially an unregularized version + of kernel ridge regression. This means + you should really prefer to use kernel ridge regression instead. + + + + + + + random_forest_regression_trainer + dlib/random_forest.h + dlib/random_forest/random_forest_regression_abstract.h + + This object implements Breiman's classic random forest regression + algorithm. + + + + + + + random_forest_regression_function + dlib/random_forest.h + dlib/random_forest/random_forest_regression_abstract.h + + This object represents a random forest that maps objects to real numbers. You + can learn its parameters using the random_forest_regression_trainer. + + + + + + + rvm_regression_trainer + dlib/svm.h + dlib/svm/rvm_abstract.h + +

    + Trains a relevance vector machine for solving regression problems. + Outputs a decision_function that represents the learned + regression function. +

    + The implementation of the RVM training algorithm used by this library is based + on the following paper: +
    + Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation + for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings + of the Ninth International Workshop on Artificial Intelligence and Statistics, + Key West, FL, Jan 3-6. +
    +
    + + rvm_regression_ex.cpp.html + + +
    + + + + + + rvm_trainer + dlib/svm.h + dlib/svm/rvm_abstract.h + +

    + Trains a relevance vector machine for solving binary classification problems. + Outputs a decision_function that represents the learned classifier. +

    + The implementation of the RVM training algorithm used by this library is based + on the following paper: +
    + Tipping, M. E. and A. C. Faul (2003). Fast marginal likelihood maximisation + for sparse Bayesian models. In C. M. Bishop and B. J. Frey (Eds.), Proceedings + of the Ninth International Workshop on Artificial Intelligence and Statistics, + Key West, FL, Jan 3-6. +
    +
    + + rvm_ex.cpp.html + + +
    + + + + + krr_trainer + dlib/svm.h + dlib/svm/krr_trainer_abstract.h + +

    + Performs kernel ridge regression and outputs a decision_function that + represents the learned function. +

    + The implementation is done using the empirical_kernel_map and + linearly_independent_subset_finder to kernelize + the rr_trainer object. Thus it allows you to run the algorithm on large + datasets and obtain sparse outputs. It is also capable of automatically estimating its + regularization parameter using leave-one-out cross-validation. +
    + + krr_regression_ex.cpp.html + krr_classification_ex.cpp.html + + +
    + + + + + + rr_trainer + dlib/svm.h + dlib/svm/rr_trainer_abstract.h + +

    + Performs linear ridge regression and outputs a decision_function that + represents the learned function. In particular, this object can only be used with + the linear_kernel. It is optimized for the linear case where + the number of features in each sample vector is small (i.e. on the order of 1000 or less since the + algorithm is cubic in the number of features.). + If you want to use a nonlinear kernel then you should use the krr_trainer. +

    + This object is capable of automatically estimating its regularization parameter using + leave-one-out cross-validation. +
    + +
    + + + + + svr_trainer + dlib/svm.h + dlib/svm/svr_trainer_abstract.h + +

    + This object implements a trainer for performing epsilon-insensitive support + vector regression. It is implemented using the SMO algorithm, + allowing the use of non-linear kernels. + If you are interested in performing support vector regression with a linear kernel and you + have a lot of training data then you should use the svr_linear_trainer + which is highly optimized for this case. +

    + The implementation of the eps-SVR training algorithm used by this object is based + on the following paper: + +
    + + svr_ex.cpp.html + + +
    + + + + + svr_linear_trainer + dlib/svm.h + dlib/svm/svr_linear_trainer_abstract.h + + This object implements a trainer for performing epsilon-insensitive support + vector regression. It uses the oca + optimizer so it is very efficient at solving this problem when + linear kernels are used, making it suitable for use with large + datasets. + + + + + + + svm_nu_trainer + dlib/svm.h + dlib/svm/svm_nu_trainer_abstract.h + +

    + Trains a nu support vector machine for solving binary classification problems and + outputs a decision_function. + It is implemented using the SMO algorithm. +

    + The implementation of the nu-svm training algorithm used by this library is based + on the following excellent papers: +
      +
    • Chang and Lin, Training {nu}-Support Vector Classifiers: Theory and Algorithms
    • +
    • Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector + machines, 2001. Software available at + http://www.csie.ntu.edu.tw/~cjlin/libsvm
    • +
    +
    + + svm_ex.cpp.html + model_selection_ex.cpp.html + + +
    + + + + + svm_one_class_trainer + dlib/svm.h + dlib/svm/svm_one_class_trainer_abstract.h + +

    + Trains a one-class support vector classifier and outputs a decision_function. + It is implemented using the SMO algorithm. +

    + The implementation of the one-class training algorithm used by this library is based + on the following paper: + +
    + + one_class_classifiers_ex.cpp.html + + +
    + + + + + svm_c_trainer + dlib/svm.h + dlib/svm/svm_c_trainer_abstract.h + +

    + Trains a C support vector machine for solving binary classification problems + and outputs a decision_function. + It is implemented using the SMO algorithm. +

    + The implementation of the C-SVM training algorithm used by this library is based + on the following paper: + +
    + + svm_c_ex.cpp.html + + +
    + + + + + svm_c_linear_dcd_trainer + dlib/svm.h + dlib/svm/svm_c_linear_dcd_trainer_abstract.h + + This object represents a tool for training the C formulation of + a support vector machine to solve binary classification problems. + It is optimized for the case where linear kernels are used and + is implemented using the method described in the + following paper: +
    + A Dual Coordinate Descent Method for Large-scale Linear SVM + by Cho-Jui Hsieh, Kai-Wei Chang, and Chih-Jen Lin +
    + + This trainer has the ability to disable the bias term and also + to force the last element of the learned weight vector to be 1. + Additionally, it can be warm-started from the solution to a previous + training run. +
    + + one_class_classifiers_ex.cpp.html + +
    + + + + + svm_c_linear_trainer + dlib/svm.h + dlib/svm/svm_c_linear_trainer_abstract.h + + This object represents a tool for training the C formulation of + a support vector machine to solve binary classification problems. + It is optimized for the case where linear kernels are used and + is implemented using the oca + optimizer and uses the exact line search described in the + following paper: +
    + Optimized Cutting Plane Algorithm for Large-Scale Risk Minimization + by Vojtech Franc, Soren Sonnenburg; Journal of Machine Learning + Research, 10(Oct):2157--2192, 2009. +
    + + This trainer has the ability to restrict the learned weights to non-negative + values. +
    + + svm_sparse_ex.cpp.html + + +
    + + + + + svm_rank_trainer + dlib/svm.h + dlib/svm/svm_rank_trainer_abstract.h + + This object represents a tool for training a ranking support vector machine + using linear kernels. In particular, this object is a tool for training + the Ranking SVM described in the paper: +
    + Optimizing Search Engines using Clickthrough Data by Thorsten Joachims +
    + Finally, note that the implementation of this object is done using the + oca optimizer and + count_ranking_inversions method. + This means that it runs in O(n*log(n)) time, making it suitable for use + with large datasets. +
    + + svm_rank_ex.cpp.html + svm_rank.py.html + + +
    + + + + + shape_predictor_trainer + dlib/image_processing.h + dlib/image_processing/shape_predictor_trainer_abstract.h + + This object is a tool for training shape_predictors + based on annotated training images. Its implementation uses the algorithm described in: +
    + One Millisecond Face Alignment with an Ensemble of Regression Trees + by Vahid Kazemi and Josephine Sullivan, CVPR 2014 +
    + It is capable of learning high quality shape models. For example, this is an example output + for one of the faces in the HELEN face dataset:

    + + +
    + + train_shape_predictor_ex.cpp.html + train_shape_predictor.py.html + + +
    + + + + + svm_c_ekm_trainer + dlib/svm.h + dlib/svm/svm_c_ekm_trainer_abstract.h + + This object represents a tool for training the C formulation of + a support vector machine for solving binary classification problems. + It is implemented using the empirical_kernel_map + to kernelize the svm_c_linear_trainer. This makes it a very fast algorithm + capable of learning from very large datasets. + + + + + + + + + normalized_function + dlib/svm.h + dlib/svm/function_abstract.h + + This object represents a container for another function + object and an instance of the vector_normalizer object. + + It automatically normalizes all inputs before passing them + off to the contained function object. + + + svm_ex.cpp.html + + + + + + + + + probabilistic_decision_function + dlib/svm.h + dlib/svm/function_abstract.h + + This object represents a binary decision function for use with + kernel-based learning-machines. It returns an + estimate of the probability that a given sample is in the +1 class. + + + svm_ex.cpp.html + + + + + + + + probabilistic_function + dlib/svm.h + dlib/svm/function_abstract.h + + This object represents a binary decision function for use with + any kind of binary classifier. It returns an + estimate of the probability that a given sample is in the +1 class. + + + + + + + distance_function + dlib/svm.h + dlib/svm/function_abstract.h + + This object represents a point in kernel induced feature space. + You may use this object to find the distance from the point it + represents to points in input space as well as other points + represented by distance_functions. + + + + + + + + decision_function + dlib/svm.h + dlib/svm/function_abstract.h + + This object represents a classification or regression function that was + learned by a kernel based learning algorithm. Therefore, it is a function + object that takes a sample object and returns a scalar value. + + + svm_ex.cpp.html + + + + + + + + one_vs_one_decision_function + dlib/svm.h + dlib/svm/one_vs_one_decision_function_abstract.h + + This object represents a multiclass classifier built out + of a set of binary classifiers. Each binary classifier + is used to vote for the correct multiclass label using a + one vs. one strategy. Therefore, if you have N classes then + there will be N*(N-1)/2 binary classifiers inside this object. + + + multiclass_classification_ex.cpp.html + custom_trainer_ex.cpp.html + + + + + + + + one_vs_one_trainer + dlib/svm_threaded.h + dlib/svm/one_vs_one_trainer_abstract.h + + This object is a tool for turning a bunch of binary classifiers + into a multiclass classifier. It does this by training the binary + classifiers in a one vs. one fashion. That is, if you have N possible + classes then it trains N*(N-1)/2 binary classifiers which are then used + to vote on the identity of a test sample. + + + multiclass_classification_ex.cpp.html + custom_trainer_ex.cpp.html + + + + + + + + one_vs_all_decision_function + dlib/svm.h + dlib/svm/one_vs_all_decision_function_abstract.h + + This object represents a multiclass classifier built out + of a set of binary classifiers. Each binary classifier + is used to vote for the correct multiclass label using a + one vs. all strategy. Therefore, if you have N classes then + there will be N binary classifiers inside this object. + + + + + + + + sequence_labeler + dlib/svm.h + dlib/svm/sequence_labeler_abstract.h + + This object is a tool for doing sequence labeling. In particular, + it is capable of representing sequence labeling models such as + those produced by Hidden Markov SVMs or Conditional Random fields. + See the following papers for an introduction to these techniques: +
    + Hidden Markov Support Vector Machines by + Y. Altun, I. Tsochantaridis, T. Hofmann +
    + Shallow Parsing with Conditional Random Fields by + Fei Sha and Fernando Pereira +
    +
    + + sequence_labeler_ex.cpp.html + + +
    + + + + + sequence_segmenter + dlib/svm.h + dlib/svm/sequence_segmenter_abstract.h + + This object is a tool for segmenting a sequence of objects into a set of + non-overlapping chunks. An example sequence segmentation task is to take + English sentences and identify all the named entities. In this example, + you would be using a sequence_segmenter to find all the chunks of + contiguous words which refer to proper names. + +

    + Internally, the sequence_segmenter uses the BIO (Begin, Inside, Outside) or + BILOU (Begin, Inside, Last, Outside, Unit) sequence tagging model. + Moreover, it is implemented using a sequence_labeler + object and therefore sequence_segmenter objects are examples of + chain structured conditional random field style sequence + taggers. +

    +
    + + sequence_segmenter.py.html + sequence_segmenter_ex.cpp.html + +
    + + + + + assignment_function + dlib/svm.h + dlib/svm/assignment_function_abstract.h + + This object is a tool for solving the optimal assignment problem given a + user defined method for computing the quality of any particular assignment. + + + assignment_learning_ex.cpp.html + + + + + + + + track_association_function + dlib/svm.h + dlib/svm/track_association_function_abstract.h + + This object is a tool that helps you implement an object tracker. So for + example, if you wanted to track people moving around in a video then this + object can help. In particular, imagine you have a tool for detecting the + positions of each person in an image. Then you can run this person + detector on the video and at each time step, i.e. at each frame, you get a + set of person detections. However, that by itself doesn't tell you how + many people there are in the video and where they are moving to and from. + To get that information you need to figure out which detections match each + other from frame to frame. This is where the track_association_function + comes in. It performs the detection to track association. It will also do + some of the track management tasks like creating a new track when a + detection doesn't match any of the existing tracks. + +

    + Internally, this object is implemented using the + assignment_function object. + In fact, it's really just a thin wrapper around assignment_function and + exists just to provide a more convenient interface to users doing detection + to track association. +

    +
    + + learning_to_track_ex.cpp.html + +
    + + + + + lspi + dlib/control.h + dlib/control/lspi_abstract.h + + This object is an implementation of the reinforcement learning algorithm + described in the following paper: +
    + Lagoudakis, Michail G., and Ronald Parr. "Least-squares policy + iteration." The Journal of Machine Learning Research 4 (2003): + 1107-1149. +
    + +
    +
    + + + + + policy + dlib/control.h + dlib/control/approximate_linear_models_abstract.h + + This is a policy (i.e. a control law) based on a linear function approximator. + You can use a tool like lspi to learn the parameters + of a policy. + + + + + + + process_sample + dlib/control.h + dlib/control/approximate_linear_models_abstract.h + + This object holds a training sample for a reinforcement learning algorithm + (e.g. lspi). + In particular, it contains a state, action, reward, next state sample from + some process. + + + + + + + graph_labeler + dlib/graph_cuts.h + dlib/graph_cuts/graph_labeler_abstract.h + + This object is a tool for labeling each node in a graph + with a value of true or false, subject to a labeling consistency constraint between + nodes that share an edge. In particular, this object is useful for + representing a graph labeling model learned via some machine learning + method, such as the structural_graph_labeling_trainer. + + + graph_labeling_ex.cpp.html + + + + + + + multiclass_linear_decision_function + dlib/svm.h + dlib/svm/function_abstract.h + + This object represents a multiclass classifier built out of a set of + binary classifiers. Each binary classifier is used to vote for the + correct multiclass label using a one vs. all strategy. Therefore, + if you have N classes then there will be N binary classifiers inside + this object. Additionally, this object is linear in the sense that + each of these binary classifiers is a simple linear plane. + + + + + + + + one_vs_all_trainer + dlib/svm_threaded.h + dlib/svm/one_vs_all_trainer_abstract.h + + This object is a tool for turning a bunch of binary classifiers + into a multiclass classifier. It does this by training the binary + classifiers in a one vs. all fashion. That is, if you have N possible + classes then it trains N binary classifiers which are then used + to vote on the identity of a test sample. + + + + + + + + svm_multiclass_linear_trainer + dlib/svm_threaded.h + dlib/svm/svm_multiclass_linear_trainer_abstract.h + + This object represents a tool for training a multiclass support + vector machine. It is optimized for the case where linear kernels + are used and implemented using the structural_svm_problem + object. + + + + + + + + projection_function + dlib/svm.h + dlib/svm/function_abstract.h + + This object represents a function that takes a data sample and projects + it into kernel feature space. The result is a real valued column vector that + represents a point in a kernel feature space. Instances of + this object are created using the + empirical_kernel_map. + + + linear_manifold_regularizer_ex.cpp.html + + + + + + + + offset_kernel + dlib/svm.h + dlib/svm/kernel_abstract.h + + This object represents a kernel with a fixed value offset + added to it. + + + + + + + + linear_kernel + dlib/svm.h + dlib/svm/kernel_abstract.h + + This object represents a linear function kernel for use with + kernel learning machines. + + + + + + + + histogram_intersection_kernel + dlib/svm.h + dlib/svm/kernel_abstract.h + + This object represents a histogram intersection kernel for use with + kernel learning machines. + + + + + + + + sigmoid_kernel + dlib/svm.h + dlib/svm/kernel_abstract.h + + This object represents a sigmoid kernel for use with + kernel learning machines. + + + + + + + + polynomial_kernel + dlib/svm.h + dlib/svm/kernel_abstract.h + + This object represents a polynomial kernel for use with + kernel learning machines. + + + + + + + + radial_basis_kernel + dlib/svm.h + dlib/svm/kernel_abstract.h + + This object represents a radial basis function kernel for use with + kernel learning machines. + + + svm_ex.cpp.html + + + + + + + + + sparse_histogram_intersection_kernel + dlib/svm.h + dlib/svm/sparse_kernel_abstract.h + + This object represents a histogram intersection kernel kernel for use with + kernel learning machines that operate on + sparse vectors. + + + + + + + + sparse_sigmoid_kernel + dlib/svm.h + dlib/svm/sparse_kernel_abstract.h + + This object represents a sigmoid kernel for use with + kernel learning machines that operate on + sparse vectors. + + + + + + + + sparse_linear_kernel + dlib/svm.h + dlib/svm/sparse_kernel_abstract.h + + This object represents a linear kernel for use with + kernel learning machines that operate on + sparse vectors. + + + + + + + + sparse_polynomial_kernel + dlib/svm.h + dlib/svm/sparse_kernel_abstract.h + + This object represents a polynomial kernel for use with + kernel learning machines that operate on + sparse vectors. + + + + + + + + sparse_radial_basis_kernel + dlib/svm.h + dlib/svm/sparse_kernel_abstract.h + + This object represents a radial basis function kernel for use with + kernel learning machines that operate on + sparse vectors. + + + + + + + + + is_binary_classification_problem + dlib/svm.h + dlib/svm/svm_abstract.h + + This function simply takes two vectors, the first containing feature vectors and + the second containing labels, and reports back if the two could possibly + contain data for a well formed classification problem. + + + + + + + + is_sequence_labeling_problem + dlib/svm.h + dlib/svm/svm_abstract.h + + This function takes a set of training data for a sequence labeling problem + and reports back if it could possibly be a well formed sequence labeling problem. + + + + + + + + is_sequence_segmentation_problem + dlib/svm.h + dlib/svm/svm_abstract.h + + This function takes a set of training data for a sequence segmentation problem + and reports back if it could possibly be a well formed sequence segmentation problem. + + + + + + + + is_assignment_problem + dlib/svm.h + dlib/svm/svm_abstract.h + + This function takes a set of training data for an assignment problem + and reports back if it could possibly be a well formed assignment problem. + + + + + + + + is_track_association_problem + dlib/svm.h + dlib/svm/svm_abstract.h + + This function takes a set of training data for a track association learning problem + and reports back if it could possibly be a well formed track association problem. + + + + + + + + is_graph_labeling_problem + dlib/svm_threaded.h + dlib/svm/structural_svm_graph_labeling_problem_abstract.h + + This function takes a set of training data for a graph labeling problem + and reports back if it could possibly be a well formed problem. + + + + + + + + is_forced_assignment_problem + dlib/svm.h + dlib/svm/svm_abstract.h + + This function takes a set of training data for a forced assignment problem + and reports back if it could possibly be a well formed forced assignment problem. + + + + + + + + is_learning_problem + dlib/svm.h + dlib/svm/svm_abstract.h + + This function simply takes two vectors, the first containing feature vectors and + the second containing labels, and reports back if the two could possibly + contain data for a well formed learning problem. In this case it just means + that the two vectors have the same length and aren't empty. + + + + + + + + select_all_distinct_labels + dlib/svm.h + dlib/svm/multiclass_tools_abstract.h + + This is a function which determines all distinct values present in a + std::vector and returns the result. + + + + + + + + simplify_linear_decision_function + dlib/svm.h + dlib/svm/simplify_linear_decision_function_abstract.h + + This is a set of functions that takes various forms of linear decision functions + and collapses them down so that they only compute a single dot product when invoked. + + + + + + + randomize_samples + dlib/svm.h + dlib/svm/svm_abstract.h + + Randomizes the order of samples in a column vector containing sample data. + + + svm_ex.cpp.html + + + + + + + + rank_features + dlib/svm.h + dlib/svm/feature_ranking_abstract.h + + Finds a ranking of the top N (a user supplied parameter) features in a set of data + from a two class classification problem. It + does this by computing the distance between the centroids of both classes in kernel defined + feature space. Good features are then ones that result in the biggest separation between + the two centroids. + + + rank_features_ex.cpp.html + + + + + + + + load_mnist_dataset + dlib/data_io.h + dlib/data_io/mnist_abstract.h + + Loads the MNIST dataset from disk. + + + + + + + load_image_dataset + dlib/data_io.h + dlib/data_io/load_image_dataset_abstract.h + + This is a function which loads the list of images indicated by an + image dataset metadata file + as well as the box locations for each image. It makes loading the + data necessary to train an object_detector + a little more convenient. + + + fhog_object_detector_ex.cpp.html + train_object_detector.cpp.html + + + + + + + + load_image_dataset_metadata + dlib/data_io.h + dlib/data_io/image_dataset_metadata.h + + dlib comes with a graphical tool for annotating images with + labeled rectangles. The tool produces an XML file containing these + annotations. Therefore, load_image_dataset_metadata() is a routine + for parsing these XML files. Note also that this is the metadata + format used by the image labeling tool included with dlib in the + tools/imglab folder. + + + + + + + + save_image_dataset_metadata + dlib/data_io.h + dlib/data_io/image_dataset_metadata.h + + This routine is a tool for saving labeled image metadata to an + XML file. In particular, this routine saves the metadata into a + form which can be read by the load_image_dataset_metadata + routine. Note also that this is the metadata + format used by the image labeling tool included with dlib in the + tools/imglab folder. + + + + + + + + load_libsvm_formatted_data + dlib/data_io.h + dlib/data_io/libsvm_io_abstract.h + + This is a function that loads the data from a file that uses + the LIBSVM format. It loads the data into a std::vector of + sparse vectors. + If you want to load data into dense vectors (i.e. + dlib::matrix objects) then you can use the sparse_to_dense + function to perform the conversion. Also, some LIBSVM formatted files number + their features beginning with 1 rather than 0. If this bothers you, then you + can fix it by using the fix_nonzero_indexing function + on the data after it is loaded. + + + + + + + + save_libsvm_formatted_data + dlib/data_io.h + dlib/data_io/libsvm_io_abstract.h + + This is actually a pair of overloaded functions. Between the two of them + they let you save sparse + or dense data vectors to file using the LIBSVM format. + + + + + + + make_bounding_box_regression_training_data + dlib/image_processing.h + dlib/image_processing/shape_predictor_trainer_abstract.h + + Suppose you have an object detector that can roughly locate objects in an + image. This means your detector draws boxes around objects, but these are + rough boxes in the sense that they aren't positioned super accurately. For + instance, HOG based detectors usually have a stride of 8 pixels. So the + positional accuracy is going to be, at best, +/-8 pixels. + +

    + If you want to get better positional accuracy one easy thing to do is train a + shape_predictor to give you the location + of the object's box. The make_bounding_box_regression_training_data() routine + helps you do this by creating an appropriate training dataset. +

    +
    +
    + + + + + fix_nonzero_indexing + dlib/data_io.h + dlib/data_io/libsvm_io_abstract.h + + This is a simple function that takes a std::vector of + sparse vectors + and makes sure they are zero-indexed (e.g. makes sure the first index value is zero). + + + + + + + find_gamma_with_big_centroid_gap + dlib/svm.h + dlib/svm/feature_ranking_abstract.h + + This is a function that tries to pick a reasonable default value for the + gamma parameter of the radial_basis_kernel. It + picks the parameter that gives the largest separation between the centroids, in + kernel feature space, of two classes of data. + + + rank_features_ex.cpp.html + + + + + + + + compute_mean_squared_distance + dlib/svm.h + dlib/svm/feature_ranking_abstract.h + + This is a function that simply finds the average squared distance between all + pairs of a set of data samples. It is often convenient to use the reciprocal + of this value as the estimate of the gamma parameter of the + radial_basis_kernel. + + + + + + + + batch + dlib/svm.h + dlib/svm/pegasos_abstract.h + + This is a convenience function for creating + batch_trainer objects. + + + svm_pegasos_ex.cpp.html + + + + + + + + verbose_batch + dlib/svm.h + dlib/svm/pegasos_abstract.h + + This is a convenience function for creating + batch_trainer objects. This function + generates a batch_trainer that will print status messages to standard + output so that you can observe the progress of a training algorithm. + + + svm_pegasos_ex.cpp.html + + + + + + + + batch_cached + dlib/svm.h + dlib/svm/pegasos_abstract.h + + This is a convenience function for creating + batch_trainer objects that are setup + to use a kernel matrix cache. + + + + + + + + verbose_batch_cached + dlib/svm.h + dlib/svm/pegasos_abstract.h + + This is a convenience function for creating + batch_trainer objects. This function + generates a batch_trainer that will print status messages to standard + output so that you can observe the progress of a training algorithm. + It will also be configured to use a kernel matrix cache. + + + + + + + + batch_trainer + dlib/svm.h + dlib/svm/pegasos_abstract.h + + This is a batch trainer object that is meant to wrap online trainer objects + that create decision_functions. It + turns an online learning algorithm such as svm_pegasos + into a batch learning object. This allows you to use objects like + svm_pegasos with functions (e.g. cross_validate_trainer) + that expect batch mode training objects. + + + + + + + + null_trainer_type + dlib/svm.h + dlib/svm/null_trainer_abstract.h + + This object is a simple tool for turning a decision_function + (or any object with an interface compatible with decision_function) + into a trainer object that always returns the original decision + function when you try to train with it. + +

    + dlib contains a few "training post processing" algorithms (e.g. + reduced and reduced2). These tools + take in a trainer object, + tell it to perform training, and then they take the output decision + function and do some kind of post processing to it. The null_trainer_type + object is useful because you can use it to run an already + learned decision function through the training post processing + algorithms by turning a decision function into a null_trainer_type + and then giving it to a post processor. +

    +
    + +
    + + + + + null_trainer + dlib/svm.h + dlib/svm/null_trainer_abstract.h + + This is a convenience function for creating + null_trainer_type + objects. + + + + + + + + roc_c1_trainer + dlib/svm.h + dlib/svm/roc_trainer_abstract.h + + This is a convenience function for creating + roc_trainer_type objects that are + setup to pick a point on the ROC curve with respect to the +1 class. + + + + + + + + roc_c2_trainer + dlib/svm.h + dlib/svm/roc_trainer_abstract.h + + This is a convenience function for creating + roc_trainer_type objects that are + setup to pick a point on the ROC curve with respect to the -1 class. + + + + + + + + roc_trainer_type + dlib/svm.h + dlib/svm/roc_trainer_abstract.h + + This object is a simple trainer post processor that allows you to + easily adjust the bias term in a trained decision_function object. + That is, this object lets you pick a point on the ROC curve and + it will adjust the bias term appropriately. + +

    + So for example, suppose you wanted to set the bias term so that + the accuracy of your decision function on +1 labeled samples was 99%. + To do this you would use an instance of this object declared as follows: + roc_trainer_type<trainer_type>(your_trainer, 0.99, +1); +

    +
    + +
    + + + + + reduced_decision_function_trainer + dlib/svm.h + dlib/svm/reduced_abstract.h + + This is a batch trainer object that is meant to wrap other batch trainer objects + that create decision_function objects. + It performs post processing on the output decision_function objects + with the intent of representing the decision_function with fewer + basis vectors. + + + + + + + + reduced + dlib/svm.h + dlib/svm/reduced_abstract.h + + This is a convenience function for creating + reduced_decision_function_trainer + objects. + + + + + + + + reduced2 + dlib/svm.h + dlib/svm/reduced_abstract.h + + This is a convenience function for creating + reduced_decision_function_trainer2 + objects. + + + svm_ex.cpp.html + + + + + + + + reduced_decision_function_trainer2 + dlib/svm.h + dlib/svm/reduced_abstract.h + +

    + This is a batch trainer object that is meant to wrap other batch trainer objects + that create decision_function objects. + It performs post processing on the output decision_function objects + with the intent of representing the decision_function with fewer + basis vectors. +

    +

    + It begins by performing the same post processing as + the reduced_decision_function_trainer + object but it also performs a global gradient based optimization + to further improve the results. The gradient based optimization is + implemented using the approximate_distance_function routine. +

    +
    + + svm_ex.cpp.html + + +
    + + + + + + approximate_distance_function + dlib/svm.h + dlib/svm/reduced_abstract.h + + This function attempts to find a distance_function object which is close + to a target distance_function. That is, it searches for an X such that target(X) is + minimized. Critically, X may be set to use fewer basis vectors than the target. + +

    The optimization begins with an initial guess supplied by the user + and searches for an X which locally minimizes target(X). Since + this problem can have many local minima the quality of the starting point + can significantly influence the results.

    +
    + +
    + + + + + test_binary_decision_function + dlib/svm.h + dlib/svm/svm_abstract.h + + Tests a decision_function that represents a binary decision function and + returns the test accuracy. + + + + + + + + + test_multiclass_decision_function + dlib/svm.h + dlib/svm/cross_validate_multiclass_trainer_abstract.h + + Tests a multiclass decision function (e.g. one_vs_one_decision_function) + and returns a confusion matrix describing the results. + + + multiclass_classification_ex.cpp.html + custom_trainer_ex.cpp.html + + + + + + + + + cross_validate_trainer_threaded + dlib/svm_threaded.h + dlib/svm/svm_threaded_abstract.h + + Performs k-fold cross validation on a user supplied binary classification trainer object such + as the svm_nu_trainer or rbf_network_trainer. + This function does the same thing as cross_validate_trainer + except this function also allows you to specify how many threads of execution to use. + So you can use this function to take advantage of a multi-core system to perform + cross validation faster. + + + + + + + cross_validate_trainer + dlib/svm.h + dlib/svm/svm_abstract.h + + Performs k-fold cross validation on a user supplied binary classification trainer object such + as the svm_nu_trainer or rbf_network_trainer. + + + svm_ex.cpp.html + model_selection_ex.cpp.html + + + + + + + + cross_validate_multiclass_trainer + dlib/svm.h + dlib/svm/cross_validate_multiclass_trainer_abstract.h + + Performs k-fold cross validation on a user supplied multiclass classification trainer object such + as the one_vs_one_trainer. The result is described by a + confusion matrix. + + + multiclass_classification_ex.cpp.html + custom_trainer_ex.cpp.html + + + + + + + + cross_validate_regression_trainer + dlib/svm.h + dlib/svm/cross_validate_regression_trainer_abstract.h + + Performs k-fold cross validation on a user supplied regression trainer object such + as the svr_trainer and returns the mean squared error + and R-squared value. + + + svr_ex.cpp.html + + + + + + + + cross_validate_sequence_labeler + dlib/svm.h + dlib/svm/cross_validate_sequence_labeler_abstract.h + + Performs k-fold cross validation on a user supplied sequence labeling trainer object such + as the structural_sequence_labeling_trainer + and returns a confusion matrix describing the results. + + + sequence_labeler_ex.cpp.html + + + + + + + + cross_validate_sequence_segmenter + dlib/svm.h + dlib/svm/cross_validate_sequence_segmenter_abstract.h + + Performs k-fold cross validation on a user supplied sequence segmentation trainer object such + as the structural_sequence_segmentation_trainer + and returns the resulting precision, recall, and F1-score. + + + sequence_segmenter.py.html + sequence_segmenter_ex.cpp.html + + + + + + + cross_validate_assignment_trainer + dlib/svm.h + dlib/svm/cross_validate_assignment_trainer_abstract.h + + Performs k-fold cross validation on a user supplied assignment trainer object such + as the structural_assignment_trainer + and returns the fraction of assignments predicted correctly. + + + assignment_learning_ex.cpp.html + + + + + + + + cross_validate_track_association_trainer + dlib/svm_threaded.h + dlib/svm/cross_validate_track_association_trainer_abstract.h + + Performs k-fold cross validation on a user supplied track association trainer object such + as the structural_track_association_trainer + and returns the fraction of detections which were correctly associated to their tracks. + + + learning_to_track_ex.cpp.html + + + + + + + cross_validate_graph_labeling_trainer + dlib/svm_threaded.h + dlib/svm/cross_validate_graph_labeling_trainer_abstract.h + + Performs k-fold cross validation on a user supplied graph labeling trainer object such + as the structural_graph_labeling_trainer + and returns the fraction of assignments predicted correctly. + + + graph_labeling_ex.cpp.html + + + + + + + + cross_validate_ranking_trainer + dlib/svm.h + dlib/svm/ranking_tools_abstract.h + + Performs k-fold cross validation on a user supplied ranking trainer object such + as the svm_rank_trainer + and returns the fraction of ranking pairs ordered correctly as well as the mean + average precision. + + + svm_rank_ex.cpp.html + svm_rank.py.html + + + + + + + + test_sequence_labeler + dlib/svm.h + dlib/svm/cross_validate_sequence_labeler_abstract.h + + Tests a sequence_labeler on a set of data + and returns a confusion matrix describing the results. + + + sequence_labeler_ex.cpp.html + + + + + + + + test_sequence_segmenter + dlib/svm.h + dlib/svm/cross_validate_sequence_segmenter_abstract.h + + Tests a sequence_segmenter on a set of data + and returns the resulting precision, recall, and F1-score. + + + sequence_segmenter.py.html + sequence_segmenter_ex.cpp.html + + + + + + + test_assignment_function + dlib/svm.h + dlib/svm/cross_validate_assignment_trainer_abstract.h + + Tests an assignment_function on a set of data + and returns the fraction of assignments predicted correctly. + + + assignment_learning_ex.cpp.html + + + + + + + + test_track_association_function + dlib/svm_threaded.h + dlib/svm/cross_validate_track_association_trainer_abstract.h + + Tests a track_association_function on a set of data + and returns the fraction of detections which were correctly associated to their tracks. + + + learning_to_track_ex.cpp.html + + + + + + + + test_graph_labeling_function + dlib/svm_threaded.h + dlib/svm/cross_validate_graph_labeling_trainer_abstract.h + + Tests a graph_labeler on a set of data + and returns the fraction of labels predicted correctly. + + + + + + + + average_precision + dlib/statistics.h + dlib/statistics/average_precision_abstract.h + + This function computes the average precision of a ranking. + + + + + + + equal_error_rate + dlib/statistics.h + dlib/statistics/lda_abstract.h + + This function finds a threshold that best separates the elements of two + vectors by selecting the threshold with equal error rate. It also reports + the value of the equal error rate. + + + + + + + compute_roc_curve + dlib/statistics.h + dlib/statistics/lda_abstract.h + + This function computes a ROC curve (receiver operating characteristic curve). + + + + + + + test_ranking_function + dlib/svm.h + dlib/svm/ranking_tools_abstract.h + + Tests a decision_function's ability to correctly + rank a dataset and returns the resulting ranking accuracy and mean average precision metrics. + + + svm_rank_ex.cpp.html + svm_rank.py.html + + + + + + + + test_shape_predictor + dlib/image_processing.h + dlib/image_processing/shape_predictor_abstract.h + + Tests a shape_predictor's ability to correctly + predict the part locations of objects. The output is the average distance (measured in pixels) between + each part and its true location. You can optionally normalize each distance using a + user supplied scale. For example, when performing face landmarking, you might want to + normalize the distances by the interocular distance. + + + train_shape_predictor_ex.cpp.html + train_shape_predictor.py.html + + + + + + + + cross_validate_object_detection_trainer + dlib/svm.h + dlib/svm/cross_validate_object_detection_trainer_abstract.h + + Performs k-fold cross validation on a user supplied object detection trainer such + as the structural_object_detection_trainer + and returns the precision and recall. + + + object_detector_ex.cpp.html + object_detector_advanced_ex.cpp.html + train_object_detector.cpp.html + + + + + + + + test_object_detection_function + dlib/svm.h + dlib/svm/cross_validate_object_detection_trainer_abstract.h + + Tests an object detector such + as the object_detector + and returns the precision and recall. + + + fhog_object_detector_ex.cpp.html + object_detector_ex.cpp.html + object_detector_advanced_ex.cpp.html + train_object_detector.cpp.html + dnn_mmod_ex.cpp.html + dnn_mmod_train_find_cars_ex.cpp.html + + + + + + + + test_regression_function + dlib/svm.h + dlib/svm/cross_validate_regression_trainer_abstract.h + + Tests a regression function (e.g. decision_function) + and returns the mean squared error and R-squared value. + + + + + + + + structural_svm_problem + dlib/svm.h + dlib/svm/structural_svm_problem_abstract.h + + This object, when used with the oca optimizer, is a tool + for solving the optimization problem associated + with a structural support vector machine. A structural SVM is a supervised + machine learning method for learning to predict complex outputs. This is + contrasted with a binary classifier which makes only simple yes/no + predictions. A structural SVM, on the other hand, can learn to predict + complex outputs such as entire parse trees or DNA sequence alignments. To + do this, it learns a function F(x,y) which measures how well a particular + data sample x matches a label y. When used for prediction, the best label + for a new x is given by the y which maximizes F(x,y). + +
    +
    + + For an introduction to structured support vector machines you should consult + the following paper: +
    + Predicting Structured Objects with Support Vector Machines by + Thorsten Joachims, Thomas Hofmann, Yisong Yue, and Chun-nam Yu +
    + + For a more detailed discussion of the particular algorithm implemented by this + object see the following paper: +
    + T. Joachims, T. Finley, Chun-Nam Yu, Cutting-Plane Training of Structural SVMs, + Machine Learning, 77(1):27-59, 2009. +
    + Note that this object is essentially a tool for solving the 1-Slack structural + SVM with margin-rescaling. Specifically, see Algorithm 3 in the above referenced + paper. + +

    + Finally, for a very detailed introduction to this subject, you should consider the book: +
    + Structured + Prediction and Learning in Computer Vision by Sebastian Nowozin and + Christoph H. Lampert +
    + +
    + + svm_struct.py.html + svm_struct_ex.cpp.html + + +
    + + + + + structural_svm_problem_threaded + dlib/svm_threaded.h + dlib/svm/structural_svm_problem_threaded_abstract.h + + This is just a version of the structural_svm_problem + which is capable of using multiple cores/threads at a time. You should use it if + you have a multi-core CPU and the separation oracle takes a long time to compute. Or even better, if you + have multiple computers then you can use the svm_struct_controller_node + to distribute the work across many computers. + + + svm_struct_ex.cpp.html + + + + + + + + structural_svm_object_detection_problem + dlib/svm_threaded.h + dlib/svm/structural_svm_object_detection_problem_abstract.h + + This object is a tool for learning the parameter vector needed to use + a scan_fhog_pyramid, + scan_image_pyramid, + scan_image_boxes, or + scan_image_custom object. + +

    + It learns the parameter vector by formulating the problem as a structural SVM problem. + The exact details of the method are described in the paper + Max-Margin Object Detection by Davis E. King. +

    +
    + +
    + + + + + structural_svm_sequence_labeling_problem + dlib/svm_threaded.h + dlib/svm/structural_svm_sequence_labeling_problem_abstract.h + + This object is a tool for learning the weight vector needed to use + a sequence_labeler object. + + It learns the parameter vector by formulating the problem as a + structural SVM problem. + The general approach is discussed in the paper: +
    + Hidden Markov Support Vector Machines by + Y. Altun, I. Tsochantaridis, T. Hofmann +
    + While the particular optimization strategy used is the method from: +
    + T. Joachims, T. Finley, Chun-Nam Yu, Cutting-Plane Training of + Structural SVMs, Machine Learning, 77(1):27-59, 2009. +
    +
    + +
    + + + + + structural_svm_assignment_problem + dlib/svm_threaded.h + dlib/svm/structural_svm_assignment_problem_abstract.h + + This object is a tool for learning the parameters needed to use + an assignment_function object. + It learns the parameters by formulating the problem as a + structural SVM problem. + + + + + + + + structural_svm_graph_labeling_problem + dlib/svm_threaded.h + dlib/svm/structural_svm_graph_labeling_problem_abstract.h + + This object is a tool for learning the weight vectors needed to use + a graph_labeler object. + It learns the parameter vectors by + formulating the problem as a structural SVM problem. + + + + + + + + structural_object_detection_trainer + dlib/svm_threaded.h + dlib/svm/structural_object_detection_trainer_abstract.h + + This object is a tool for learning to detect objects in images based on a set of labeled images. + The training procedure produces an object_detector which + can be used to predict the locations of objects in new images. + It learns the parameter vector by formulating the problem as a structural SVM problem. + The exact details of the method are described in the paper + Max-Margin Object Detection by Davis E. King. +

    + Note that this is just a convenience wrapper around the + structural_svm_object_detection_problem + to make it look similar to all the other trainers in dlib. +

    +
    + + fhog_object_detector_ex.cpp.html + object_detector_ex.cpp.html + object_detector_advanced_ex.cpp.html + train_object_detector.cpp.html + + train_object_detector.py.html + + +
    + + + + + structural_sequence_labeling_trainer + dlib/svm_threaded.h + dlib/svm/structural_sequence_labeling_trainer_abstract.h + + This object is a tool for learning to do sequence labeling based + on a set of training data. The training procedure produces a + sequence_labeler object which can + be use to predict the labels of new data sequences. +

    + Note that this is just a convenience wrapper around the + structural_svm_sequence_labeling_problem + to make it look similar to all the other trainers in dlib. +

    +
    + + sequence_labeler_ex.cpp.html + + +
    + + + + + structural_sequence_segmentation_trainer + dlib/svm_threaded.h + dlib/svm/structural_sequence_segmentation_trainer_abstract.h + + This object is a tool for learning to do sequence segmentation based on a + set of training data. The training procedure produces a sequence_segmenter + object which can be used to identify the sub-segments of new data sequences. +

    + This object internally uses the structural_sequence_labeling_trainer + to solve the learning problem. +

    +
    + + sequence_segmenter.py.html + sequence_segmenter_ex.cpp.html + + +
    + + + + + structural_graph_labeling_trainer + dlib/svm_threaded.h + dlib/svm/structural_graph_labeling_trainer_abstract.h + + This object is a tool for learning to solve a graph labeling problem based + on a training dataset of example labeled graphs. + The training procedure produces a graph_labeler object + which can be used to predict the labelings of new graphs. + +

    + To elaborate, a graph labeling problem is a task to learn a binary classifier which + predicts the label of each node in a graph. Additionally, we have information in + the form of edges between nodes where edges are present when we believe the + linked nodes are likely to have the same label. Therefore, part of a graph labeling + problem is to learn to score each edge in terms of how strongly the edge should enforce + labeling consistency between its two nodes. +

    + +

    + Note that this is just a convenience wrapper around the + structural_svm_graph_labeling_problem + to make it look similar to all the other trainers in dlib. You might also + consider reading the book + Structured + Prediction and Learning in Computer Vision by Sebastian + Nowozin and Christoph H. Lampert since it contains a good introduction to machine learning + methods such as the algorithm implemented by the structural_graph_labeling_trainer. +

    +
    + + graph_labeling_ex.cpp.html + + +
    + + + + + structural_assignment_trainer + dlib/svm_threaded.h + dlib/svm/structural_assignment_trainer_abstract.h + + This object is a tool for learning to solve an assignment problem based + on a training dataset of example assignments. The training procedure produces an + assignment_function object which can be used + to predict the assignments of new data. + + + Note that this is just a convenience wrapper around the + structural_svm_assignment_problem + to make it look similar to all the other trainers in dlib. + + + assignment_learning_ex.cpp.html + + + + + + + + structural_track_association_trainer + dlib/svm_threaded.h + dlib/svm/structural_track_association_trainer_abstract.h + + This object is a tool for learning to solve a track association problem. That + is, it takes in a set of training data and outputs a + track_association_function + you can use to do detection to track association. + + + learning_to_track_ex.cpp.html + + + + + + + svm_struct_controller_node + dlib/svm_threaded.h + dlib/svm/structural_svm_distributed_abstract.h + + This object is a tool for distributing the work involved in solving a + structural_svm_problem across many computers. + + + + + + + + svm_struct_processing_node + dlib/svm_threaded.h + dlib/svm/structural_svm_distributed_abstract.h + + This object is a tool for distributing the work involved in solving a + structural_svm_problem across many computers. + + + + + + +
    + + + + +
    + + diff --git a/ml/dlib/docs/docs/ml_guide.dia b/ml/dlib/docs/docs/ml_guide.dia new file mode 100644 index 000000000..1d66e5391 Binary files /dev/null and b/ml/dlib/docs/docs/ml_guide.dia differ diff --git a/ml/dlib/docs/docs/ml_guide.svg b/ml/dlib/docs/docs/ml_guide.svg new file mode 100644 index 000000000..d971f5e89 --- /dev/null +++ b/ml/dlib/docs/docs/ml_guide.svg @@ -0,0 +1,4345 @@ + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + NO + + + + + + + + + + + + + + Do you want to detect + objects in images? + + + + + + + Predicting the labels + of nodes in a graph? + + + + + + + Binary labels on + nodes in a graph? + + + + + + + A chain structured graph? + (e.g. words in a sentence) + + + + + + + Are you trying to make + a BIO or BILOU tagger? + + + + + + + + + + + + Trying to solve an + assignment problem? + + + + YES + + + NO + + + NO + + + YES + + + NO + + + NO + + + YES + + + NO + + + YES + + + YES + + + YES + + + NO + + + + + + + + + + + + + + + + + + + + + + + + + + + + Structured Prediction + + + Markov Random Fields + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + YES + + + NO + + + + + + + + + + + Want to make + a tracker? + + + + + + + + structural_track_association_trainer + + + + + + + + + + YES + + + + + + + + NO + + + + + + + svm_rank_trainer + + + + + + + + + structural_object_detection_trainer + + + + + + + + + structural_sequence_segmentation_trainer + + + + + + + + + structural_sequence_labeling_trainer + + + + + + + + + structural_graph_labeling_trainer + + + + + + + + + structural_assignment_trainer + + + + + + + + + structural_svm_problem + (Used to build your own + structured precition tool!) + + + + + + + + + + + + + + + + + + + + + Predicting a + true or false label? + + + + + + + Predicting a + categorial label? + + + + + + + Predicting a + continuous quantity? + + + + + + + Do you have + labeled data? + + + + + + + Are you trying + to rank order + something? + + + + + + + Do you want + to transform + your data? + + + + YES + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Do you know how + many categories? + + + + TOO + SLOW + + + YES + + + NO + + + Clustering + + + + + + + + + + + + + + + + + + + + + + + + + + < 5K + Samples + + + + + + + + + + + + + + + + + + + YES + + + + + + + + + + + + + + + + + + + + + NO + + + + + + Do you have + a graph of "similar" + samples? + + + + + + + + + + + + + + + + + YES + + + NO + + + YES + + + NO + + + YES + + + NO + + + Data Transformations + + + + + + + + + + + + + + + Number of + features + < 100 + + + + NO + + + YES + + + + + + < 20K + Samples + + + + NOT + WORKING + + + YES + + + NO + + + NO + + + + + + + + + Do you have + labeled data? + + + + + + + Are you trying + to label things + as anomalous + vs. normal? + + + + + + + < 20K + Samples + + + + YES + + + + + + Go get + labels! + + + + NO + + + YES + + + NO + + + + + + Number of + features + < 100 + + + + NO + + + YES + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + YES + + + Classification + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + NOT + WORKING + + + + + + < 20K + Samples + + + + YES + + + NO + + + + + + Is this a time-series + or online prediction + problem? + + + + NO + + + + + + + + + + + + + + + + + + + + + + + + + + + + YES + + + Regression + + + + + + + + NO + + + YES + + + YES + + + NO + + + NO + + + YES + + + NO + + + YES + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + svm_c_trainer + with radial_basis_kernel or + histogram_intersection_kernel + + + + + + + + + svm_c_linear_trainer + + + + + + + + + svm_one_class_trainer + with radial_basis_kernel + + + + + + + + + svm_c_linear_dcd_trainer + (see one_class_classifiers_ex.cpp + example program) + + + + + + + + + svm_multiclass_linear_trainer + + + + + + + + + one_vs_one_trainer + with krr_trainer using + radial_basis_kernel + + + + + + + + + krr_trainer with + radial_basis_kernel + + + + + + + + + newman_cluster or + chinese_whispers + + + + + + + + + spectral_cluster or + find_clusters_using_kmeans + + + + + + + + + vector_normalizer_frobmetric + + + + + + + + + linear_manifold_regularizer + + + + + + + + + krr_trainer with + radial_basis_kernel + + + + + + + + + svr_trainer with + radial_basis_kernel or + histogram_intersection_kernel + + + + + + + + + svr_linear_trainer + + + + + + + + + sammon_projection + + + + + + + + + discriminant_pca + + + + + + + + + krls or rls + + + + + + + + + cca + + + + + + + + + + + + + + + + + + + + + + + + Learning a + distance metric? + + + + + + + + Do you have + two views of + your data? + + + diff --git a/ml/dlib/docs/docs/network.xml b/ml/dlib/docs/docs/network.xml new file mode 100644 index 000000000..c632f977c --- /dev/null +++ b/ml/dlib/docs/docs/network.xml @@ -0,0 +1,259 @@ + + + + + Networking + + + + +

    + This page documents tools built on top of the dlib sockets API. + Therefore, all these tools are focused on providing some kind of higher level networking + abstraction or service. +

    + + + + + + + +
    + Objects + linker + server + server_iostream + server_http + bridge + sockstreambuf + iosockstream +
    + +
    + BSP + bsp_connect + bsp_listen + bsp_listen_dynamic_port + bsp_context +
    + +
    + +
    + + + + + + + + + + + + linker + dlib/linker.h + dlib/linker/linker_kernel_abstract.h + + This object represents something that takes two connections and lets + them talk to each other. ie. any incoming data from one connection is + passed unaltered to the other and vice versa. + + + + + + + + bridge + dlib/bridge.h + dlib/bridge/bridge_abstract.h + + This object is a tool for bridging a pipe object between + two network connected applications. + +

    + The bridge object is designed to link two pipes together as efficiently as + possible. To demonstrate its speed, I ran two experiments where a bridge was + used to connect a desktop PC to a laptop, both running Ubuntu 12.04 and + connected via gigabit ethernet. The first experiment is to test its bulk + transfer speed while the second experiment measures how many separate objects + the bridge can transfer each second. +

    +

    + In the first experiment, 1-megapixel images, represented with + array<rgb_pixel> objects, were sent. The transfer rate was + 112 megabytes/second, saturating the gigabit ethernet link. The second + experiment used a pipe<char> and bridge to send individual + char variables over the network. In this experiment, I was able to + send 3.2 million objects a second (i.e. the receiving end was getting a char + back from pipe::dequeue() 3.2 million times each second). +

    +

    + For reference, these experiments were carried out on a desktop with a 2.67GHz + Intel Core-i7 CPU and a laptop with a 2.20GHz Intel Core-i7 CPU. +

    +
    + + bridge_ex.cpp.html + +
    + + + + + bsp_connect + dlib/bsp.h + dlib/bsp/bsp_abstract.h + + This function spawns a BSP job consisting of a number of network hosts + as well as the local host. + + + bsp_ex.cpp.html + + + + + + + bsp_listen + dlib/bsp.h + dlib/bsp/bsp_abstract.h + + This function listens for a TCP connection from the bsp_connect routine. + Once this connection is established, a user supplied function will be executed and it will + then be able to participate in a BSP computation as one of the processing + nodes. + + + bsp_ex.cpp.html + + + + + + + bsp_listen_dynamic_port + dlib/bsp.h + dlib/bsp/bsp_abstract.h + + This function listens for a TCP connection from the bsp_connect routine. + Once this connection is established, a user supplied function will be executed and it will + then be able to participate in a BSP computation as one of the processing + nodes. This function has the additional ability to select the listening TCP port + automatically from the set of available ports. + + + + + + + bsp_context + dlib/bsp.h + dlib/bsp/bsp_abstract.h + + This is a tool used to implement algorithms using the Bulk Synchronous + Parallel (BSP) computing model. In particular, this object defines + the API used for communication between BSP jobs. + + + bsp_ex.cpp.html + + + + + + + server + dlib/server.h + dlib/server/server_kernel_abstract.h + + This object represents a server that listens on a port and spawns new + threads to handle each new connection. It also manages the connections + and threads for you. + + + + sockets_ex.cpp.html + + + + + + + + server_iostream + dlib/server.h + dlib/server/server_iostream_abstract.h + + This is an extension of the server object that redefines + the on_connect() function so that instead of giving you a connection object you + get an istream and ostream object. + + + server_iostream_ex.cpp.html + + + + + + + server_http + dlib/server.h + dlib/server/server_http_abstract.h + + This is an extension of the server_iostream object which + turns it into a simple HTTP server. + + + + server_http_ex.cpp.html + + + + + + + iosockstream + dlib/iosockstream.h + dlib/iosockstream/iosockstream_abstract.h + + This is an iostream object that reads/writes from a TCP network connection. + + + + iosockstream_ex.cpp.html + + + + + + + + + sockstreambuf + dlib/sockstreambuf.h + dlib/sockstreambuf/sockstreambuf_abstract.h + + This object represents a stream buffer for connection objects. If you + are considering using this object then you should also take a look at + the iosockstream. + + + + sockstreambuf_ex.cpp.html + + + + + + + +
    + + + + +
    diff --git a/ml/dlib/docs/docs/old_change_log.xml b/ml/dlib/docs/docs/old_change_log.xml new file mode 100644 index 000000000..bddb0b479 --- /dev/null +++ b/ml/dlib/docs/docs/old_change_log.xml @@ -0,0 +1,7 @@ + + + + + Old Change Logs + + diff --git a/ml/dlib/docs/docs/old_release_notes.xml b/ml/dlib/docs/docs/old_release_notes.xml new file mode 100644 index 000000000..e9e7c408b --- /dev/null +++ b/ml/dlib/docs/docs/old_release_notes.xml @@ -0,0 +1,10 @@ + + + + + Old Release Notes + + + + + diff --git a/ml/dlib/docs/docs/optimization.xml b/ml/dlib/docs/docs/optimization.xml new file mode 100644 index 000000000..cd04ded75 --- /dev/null +++ b/ml/dlib/docs/docs/optimization.xml @@ -0,0 +1,1338 @@ + + + + + Optimization + + + + + +

    + This page documents library components that attempt to find the + minimum or maximum of a user supplied function. An introduction + to the general purpose non-linear optimizers in this section can be + found here. For an example + showing how to use the non-linear least squares routines look + here. +

    + + + + + + + +
    + General Purpose Optimizers + find_min + find_min_box_constrained + find_min_single_variable + find_min_using_approximate_derivatives + find_min_bobyqa + find_min_global + find_max + find_max_box_constrained + find_max_single_variable + find_max_using_approximate_derivatives + find_max_bobyqa + find_max_global + global_function_search + find_max_trust_region + find_min_trust_region +
    + +
    + Special Purpose Optimizers + find_gap_between_convex_hulls + solve_qp_box_constrained + solve_qp_box_constrained_blockdiag + solve_qp_using_smo + solve_qp2_using_smo + solve_qp3_using_smo + solve_qp4_using_smo + oca + mpc + solve_least_squares + solve_least_squares_lm + solve_trust_region_subproblem + solve_trust_region_subproblem_bounded + max_cost_assignment + max_sum_submatrix + find_max_factor_graph_nmplp + find_max_factor_graph_viterbi + find_max_factor_graph_potts + find_max_parse_cky + min_cut + elastic_net + isotonic_regression +
    + +
    + Strategies + cg_search_strategy + bfgs_search_strategy + newton_search_strategy + lbfgs_search_strategy + objective_delta_stop_strategy + gradient_norm_stop_strategy +
    + +
    + Helper Routines + derivative + negate_function + clamp_function + make_line_search_function + poly_min_extrap + lagrange_poly_min_extrap + line_search + backtracking_line_search + graph_cut_score + potts_model_score + parse_tree_to_string + find_trees_not_rooted_with_tag + upper_bound_function + call_function_and_expand_args +
    + +
    +
    + + + + + + + + + + + derivative + dlib/optimization.h + dlib/optimization/optimization_abstract.h + + This is a function that takes another function as input and returns + a function object that numerically computes the derivative of the input function. + + + + + + + + + negate_function + dlib/optimization.h + dlib/optimization/optimization_line_search_abstract.h + + This is a function that takes another function as input and returns + a function object that computes the negation of the input function. + + + + + + + + + clamp_function + dlib/optimization.h + dlib/optimization/optimization_abstract.h + + This is a function that takes another function, f(x), as input and + returns a new function object, g(x), such that + g(x) == f(clamp(x,x_lower,x_upper)) where x_lower and x_upper + are vectors of box constraints which are applied to x. + + + + + + + + make_line_search_function + dlib/optimization.h + dlib/optimization/optimization_line_search_abstract.h + + This is a function that takes another function f(x) as input and returns + a function object l(z) = f(start + z*direction). It is useful for + turning multi-variable functions into single-variable functions for + use with the line_search or + backtracking_line_search routines. + + + + + + + + + poly_min_extrap + dlib/optimization.h + dlib/optimization/optimization_line_search_abstract.h + + This function finds the 2nd or 3rd degree polynomial that interpolates a + set of points and returns the minimum of that polynomial. + + + + + + + + lagrange_poly_min_extrap + dlib/optimization.h + dlib/optimization/optimization_line_search_abstract.h + + This function finds the second order polynomial that interpolates a + set of points and returns the minimum of that polynomial. + + + + + + + + line_search + dlib/optimization.h + dlib/optimization/optimization_line_search_abstract.h + + Performs a gradient based line search on a given function and returns the input + that makes the function significantly smaller. This implements the classic + line search method using the strong Wolfe conditions with a bracketing and then + sectioning phase, both using polynomial interpolation. + + + + + + + + + backtracking_line_search + dlib/optimization.h + dlib/optimization/optimization_line_search_abstract.h + + Performs a line search on a given function and returns the input + that makes the function significantly smaller. This implementation uses a + basic Armijo backtracking search with polynomial interpolation. + + + + + + + + cg_search_strategy + dlib/optimization.h + dlib/optimization/optimization_search_strategies_abstract.h + + This object represents a strategy for determining which direction + a line search should be carried out along. This particular object + is an implementation of the Polak-Ribiere conjugate gradient method + for determining this direction. + +

    + This method uses an amount of memory that is linear in the number + of variables to be optimized. So it is capable of handling problems + with a very large number of variables. However, it is generally + not as good as the L-BFGS algorithm (see the + lbfgs_search_strategy class). +

    +
    + + optimization_ex.cpp.html + + +
    + + + + + bfgs_search_strategy + dlib/optimization.h + dlib/optimization/optimization_search_strategies_abstract.h + + This object represents a strategy for determining which direction + a line search should be carried out along. This particular object + is an implementation of the BFGS quasi-newton method for determining + this direction. + +

    + This method uses an amount of memory that is quadratic in the number + of variables to be optimized. It is generally very effective but + if your problem has a very large number of variables then it isn't + appropriate. Instead, you should try the lbfgs_search_strategy. +

    +
    + + optimization_ex.cpp.html + + +
    + + + + + newton_search_strategy + dlib/optimization.h + dlib/optimization/optimization_search_strategies_abstract.h + + This object represents a strategy for determining which direction + a line search should be carried out along. This particular routine + is an implementation of the newton method for determining this direction. + That means using it requires you to supply a method for + creating hessian matrices for the problem you are trying to optimize. + +

    + Note also that this is actually a helper function for creating + newton_search_strategy_obj objects. +

    + +
    + + optimization_ex.cpp.html + + +
    + + + + + lbfgs_search_strategy + dlib/optimization.h + dlib/optimization/optimization_search_strategies_abstract.h + + This object represents a strategy for determining which direction + a line search should be carried out along. This particular object + is an implementation of the L-BFGS quasi-newton method for determining + this direction. + +

    + This method uses an amount of memory that is linear in the number + of variables to be optimized. This makes it an excellent method + to use when an optimization problem has a large number of variables. +

    +
    + + optimization_ex.cpp.html + + +
    + + + + + objective_delta_stop_strategy + dlib/optimization.h + dlib/optimization/optimization_stop_strategies_abstract.h + + This object represents a strategy for deciding if an optimization + algorithm should terminate. This particular object looks at the + change in the objective function from one iteration to the next and + bases its decision on how large this change is. If the change + is below a user given threshold then the search stops. + + + optimization_ex.cpp.html + + + + + + + + gradient_norm_stop_strategy + dlib/optimization.h + dlib/optimization/optimization_stop_strategies_abstract.h + + This object represents a strategy for deciding if an optimization + algorithm should terminate. This particular object looks at the + norm (i.e. the length) of the current gradient vector and + stops if it is smaller than a user given threshold. + + + + + + + + find_min + dlib/optimization.h + dlib/optimization/optimization_abstract.h + + Performs an unconstrained minimization of a nonlinear function using + some search strategy (e.g. bfgs_search_strategy). + + + optimization_ex.cpp.html + + + + + + + + find_min_box_constrained + dlib/optimization.h + dlib/optimization/optimization_abstract.h + + Performs a box constrained minimization of a nonlinear function using + some search strategy (e.g. bfgs_search_strategy). + This function uses a backtracking line search along with a gradient projection + step to handle the box constraints. + + + optimization_ex.cpp.html + + + + + + + find_min_single_variable + dlib/optimization.h + dlib/optimization/optimization_line_search_abstract.h + + Performs a bound constrained minimization of a nonlinear function. The + function must be of a single variable. Derivatives are not required. + + + + + + + + solve_trust_region_subproblem_bounded + dlib/optimization.h + dlib/optimization/optimization_trust_region_abstract.h + + This function solves the following optimization problem: +
    +Minimize: f(p) == 0.5*trans(p)*B*p + trans(g)*p
    +subject to the following constraint:
    +   length(p) <= radius
    +   lower(i) <= p(i) <= upper(i), for all i
    +
    + +
    + +
    + + + + + solve_trust_region_subproblem + dlib/optimization.h + dlib/optimization/optimization_trust_region_abstract.h + + This function solves the following optimization problem: +
    +Minimize: f(p) == 0.5*trans(p)*B*p + trans(g)*p
    +subject to the following constraint:
    +   length(p) <= radius
    +
    + +
    + +
    + + + + + find_gap_between_convex_hulls + dlib/optimization.h + dlib/optimization/optimization_solve_qp_using_smo_abstract.h + + This function measures the position and size of the gap between two convex + polytopes. In particular, it solves the following quadratic program: +
    +   Minimize: f(cA,cB) == length_squared(A*cA - B*cB) 
    +   subject to the following constraints on cA and cB:
    +      - is_col_vector(cA) == true && cA.size() == A.nc()
    +      - is_col_vector(cB) == true && cB.size() == B.nc()
    +      - sum(cA) == 1 && min(cA) >= 0
    +      - sum(cB) == 1 && min(cB) >= 0
    +
    + +
    + +
    + + + + + solve_qp_using_smo + dlib/optimization.h + dlib/optimization/optimization_solve_qp_using_smo_abstract.h + + This function solves the following quadratic program: +
    +   Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha - trans(alpha)*b
    +   subject to the following constraints:
    +      sum(alpha) == C 
    +      min(alpha) >= 0 
    +   Where f is convex.  This means that Q should be symmetric and positive-semidefinite.
    +
    + +
    + +
    + + + + + solve_qp_box_constrained + dlib/optimization.h + dlib/optimization/optimization_solve_qp_using_smo_abstract.h + + This function solves the following quadratic program: +
    +   Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + trans(b)*alpha 
    +   subject to the following box constraints on alpha:
    +      0 <= min(alpha-lower)
    +      0 <= max(upper-alpha)
    +   Where f is convex.  This means that Q should be positive-semidefinite.
    +
    + +
    + +
    + + + + + solve_qp_box_constrained_blockdiag + dlib/optimization.h + dlib/optimization/optimization_solve_qp_using_smo_abstract.h + + This function solves the following quadratic program: +
    +   Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + trans(b)*alpha 
    +   subject to the following box constraints on alpha:
    +      0 <= min(alpha-lower)
    +      0 <= max(upper-alpha)
    +   Where f is convex.  This means that Q should be positive-semidefinite.
    +
    + + So it does the same thing as solve_qp_box_constrained, + except it is optimized for large Q matrices with a special block + structure. In particular, Q must be grouped into identically sized + blocks where all blocks are diagonal matrices, except those on the + main diagonal which can be dense. +
    + +
    + + + + + solve_qp2_using_smo + dlib/optimization.h + dlib/optimization/optimization_solve_qp2_using_smo_abstract.h + + This function solves the following quadratic program: +
    +   Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha 
    +   subject to the following constraints:
    +      sum(alpha) == nu*y.size() 
    +      0 <= min(alpha) && max(alpha) <= 1 
    +      trans(y)*alpha == 0
    +
    +   Where all elements of y must be equal to +1 or -1 and f is convex.  
    +   This means that Q should be symmetric and positive-semidefinite.
    +
    +
    + This object implements the strategy used by the LIBSVM tool. The following papers + can be consulted for additional details: +
      +
    • Chang and Lin, Training {nu}-Support Vector Classifiers: Theory and Algorithms
    • +
    • Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector + machines, 2001. Software available at + http://www.csie.ntu.edu.tw/~cjlin/libsvm
    • +
    + +
    + +
    + + + + + solve_qp3_using_smo + dlib/optimization.h + dlib/optimization/optimization_solve_qp3_using_smo_abstract.h + + This function solves the following quadratic program: +
    +   Minimize: f(alpha) == 0.5*trans(alpha)*Q*alpha + trans(p)*alpha
    +   subject to the following constraints:
    +        for all i such that y(i) == +1:  0 <= alpha(i) <= Cp 
    +        for all i such that y(i) == -1:  0 <= alpha(i) <= Cn 
    +        trans(y)*alpha == B 
    +
    +   Where all elements of y must be equal to +1 or -1 and f is convex.  
    +   This means that Q should be symmetric and positive-semidefinite.
    +
    +
    + This object implements the strategy used by the LIBSVM tool. The following papers + can be consulted for additional details: +
      +
    • Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector + machines, 2001. Software available at + http://www.csie.ntu.edu.tw/~cjlin/libsvm
    • +
    • Working Set Selection Using Second Order Information for Training Support Vector Machines by + Fan, Chen, and Lin. In the Journal of Machine Learning Research 2005.
    • +
    + +
    + +
    + + + + + solve_qp4_using_smo + dlib/optimization.h + dlib/optimization/optimization_solve_qp_using_smo_abstract.h + + This function solves the following quadratic program: +
    +   Minimize: f(alpha,lambda) == 0.5*trans(alpha)*Q*alpha - trans(alpha)*b + 
    +                                0.5*trans(lambda)*lambda - trans(lambda)*A*alpha - trans(lambda)*d
    +   subject to the following constraints:
    +      sum(alpha)  == C 
    +      min(alpha)  >= 0 
    +      min(lambda) >= 0
    +      max(lambda) <= max_lambda 
    +   Where f is convex.  This means that Q should be positive-semidefinite.
    +
    + +
    + +
    + + + + + max_cost_assignment + dlib/optimization.h + dlib/optimization/max_cost_assignment_abstract.h + + This function is an implementation of the Hungarian algorithm (also know as the Kuhn-Munkres algorithm) which + runs in O(N^3) time. + It solves the optimal assignment problem. For example, suppose you have an equal number of workers + and jobs and you need to decide which workers to assign to which jobs. Some workers are better at + certain jobs than others. So you would like to find out how to assign them all to jobs such that + overall productivity is maximized. You can use this routine to solve this problem and others like it. +

    + Note that dlib also contains a machine learning + method for learning the cost function needed to use the Hungarian algorithm. +

    + +
    + + max_cost_assignment_ex.cpp.html + max_cost_assignment.py.html + + +
    + + + + + max_sum_submatrix + dlib/optimization.h + dlib/optimization/max_sum_submatrix_abstract.h + + This function finds the submatrix within a user supplied matrix which has the largest sum. It then + zeros out that submatrix and repeats the process until no more maximal submatrices can + be found. + + + + + + + + find_max_factor_graph_nmplp + dlib/optimization.h + dlib/optimization/find_max_factor_graph_nmplp_abstract.h + + This function is a tool for approximately solving the MAP problem in a graphical + model or factor graph with pairwise potential functions. That is, it attempts + to solve a certain kind of optimization problem which can be defined as follows: +
    +   maximize: f(X)
    +   where X is a set of integer valued variables and f(X) can be written
    +   as the sum of functions which each involve only two variables from X.
    +
    +If the graph is tree-structured then this routine always gives the exact solution to the MAP problem. +However, for graphs with cycles, the solution may be approximate. +
    +
    + This function is an implementation of the NMPLP algorithm introduced in the + following papers: +
    + Fixing Max-Product: Convergent Message Passing Algorithms for MAP LP-Relaxations (2008) + by Amir Globerson and Tommi Jaakkola +
    +
    + Introduction to dual decomposition for inference (2011) + by David Sontag, Amir Globerson, and Tommi Jaakkola +
    + +
    + +
    + + + + + find_max_parse_cky + dlib/optimization.h + dlib/optimization/find_max_parse_cky_abstract.h + + This function implements the CKY parsing algorithm. In particular, it + finds the maximum scoring binary parse tree that parses an input sequence of tokens. + + + + + + + + parse_tree_to_string + dlib/optimization.h + dlib/optimization/find_max_parse_cky_abstract.h + + This is a set of functions useful for converting a parse tree output by + find_max_parse_cky into a bracketed string + suitable for displaying the parse tree. + + + + + + + + find_trees_not_rooted_with_tag + dlib/optimization.h + dlib/optimization/find_max_parse_cky_abstract.h + + Finds all the largest non-overlapping parse trees + in tree that are not rooted with a particular tag. +

    + This function is useful when you want to cut a parse tree + into a bunch of sub-trees and you know that the top level of the tree is all + composed of the same kind of tag. So if you want to just "slice off" the top + of the tree where this tag lives then this function is useful for doing that. +

    +
    + +
    + + + + + find_max_factor_graph_viterbi + dlib/optimization.h + dlib/optimization/find_max_factor_graph_viterbi_abstract.h + + This function is a tool for exactly solving the MAP problem in a chain-structured + graphical model or factor graph. In particular, it is an implementation of the classic Viterbi + algorithm for finding the maximizing assignment. In addition to basic first order Markov + models, this function is also capable of finding the MAP assignment for higher order + Markov models. + + + + + + + + potts_model_score + dlib/graph_cuts.h + dlib/graph_cuts/find_max_factor_graph_potts_abstract.h + + This routine computes the model score for a Potts problem and a + candidate labeling. This score is the quantity maximised + by the find_max_factor_graph_potts + routine. + + + + + + + + graph_cut_score + dlib/graph_cuts.h + dlib/graph_cuts/min_cut_abstract.h + + This routine computes the score for a candidate graph cut. This is the + quantity minimized by the min_cut algorithm. + + + + + + + + min_cut + dlib/graph_cuts.h + dlib/graph_cuts/min_cut_abstract.h + + This is a function object which can be used to find the min cut + on a graph. + The implementation is based on the method described in the following + paper: +
    + An Experimental Comparison of Min-Cut/Max-Flow Algorithms for + Energy Minimization in Vision, by Yuri Boykov and Vladimir Kolmogorov, + in PAMI 2004. +
    +
    + +
    + + + + + find_max_factor_graph_potts + dlib/graph_cuts.h + dlib/graph_cuts/find_max_factor_graph_potts_abstract.h + + This is a set of overloaded functions for exactly solving the MAP problem in a Potts + model. This type of model is useful when you have a problem which + can be modeled as a bunch of binary decisions on some variables, + but you have some kind of labeling consistency constraint. This + means that there is some penalty for giving certain pairs of variables + different labels. So in addition to trying to figure out how to best + label each variable on its own, you have to worry about making the + labels pairwise consistent in some sense. The find_max_factor_graph_potts() + routine can be used to find the most probable/highest scoring + labeling for this type of model. +

    The implementation of this routine is based on the min_cut object.

    +
    + +
    + + + + + oca + dlib/optimization.h + dlib/optimization/optimization_oca_abstract.h + + This object is a tool for solving the following optimization problem: +
    +   Minimize: f(w) == 0.5*||w||^2 + C*R(w)
    +
    +   Where R(w) is a user-supplied convex function and C > 0.  Optionally,
    +   this object can also add non-negativity constraints to some or all
    +   of the elements of w.
    +
    +Or it can alternatively solve:
    +   Minimize: f(w) == 0.5*||w-prior||^2 + C*R(w)
    +
    +   Where prior is a user supplied vector and R(w) has the same
    +   interpretation as above.
    +
    +Or it can use the elastic net regularizer:
    +   Minimize: f(w) == 0.5*(1-lasso_lambda)*length_squared(w) + lasso_lambda*sum(abs(w)) + C*R(w)
    +
    +   Where lasso_lambda is a number in the range [0, 1) and controls
    +   trade-off between doing L1 and L2 regularization.  R(w) has the same
    +   interpretation as above.
    +
    +
    +
    + + For a detailed discussion you should consult the following papers + from the Journal of Machine Learning Research: +
    + Optimized Cutting Plane Algorithm for Large-Scale Risk Minimization + by Vojtech Franc, Soren Sonnenburg; 10(Oct):2157--2192, 2009. +
    +
    + Bundle Methods for Regularized Risk Minimization + by Choon Hui Teo, S.V.N. Vishwanthan, Alex J. Smola, Quoc V. Le; 11(Jan):311-365, 2010. +
    + +
    + +
    + + + + + mpc + dlib/control.h + dlib/control/mpc_abstract.h + + This object implements a linear model predictive controller. + In particular, it solves a certain quadratic program using the method + described in the paper: +
    + A Fast Gradient method for embedded linear predictive control (2011) + by Markus Kogel and Rolf Findeisen +
    +
    + + mpc_ex.cpp.html + +
    + + + + + + find_min_bobyqa + dlib/optimization.h + dlib/optimization/optimization_bobyqa_abstract.h + + This function defines the dlib interface to the BOBYQA software developed by M.J.D Powell. + BOBYQA is a method for optimizing a function in the absence of derivative information. + Powell described it as a method that seeks the least value of a function of many + variables, by applying a trust region method that forms quadratic models by + interpolation. There is usually some freedom in the interpolation conditions, + which is taken up by minimizing the Frobenius norm of the change to the second + derivative of the model, beginning with the zero matrix. The values of the variables + are constrained by upper and lower bounds. + +

    + The following paper, published in 2009 by Powell, describes the + detailed working of the BOBYQA algorithm. + +

    + The BOBYQA algorithm for bound constrained optimization + without derivatives by M.J.D. Powell +
    +

    + +

    + Note that BOBYQA only works on functions of two or more variables. So if you need to perform + derivative-free optimization on a function of a single variable + then you should use the find_min_single_variable + function. +

    + +
    + + optimization_ex.cpp.html + + +
    + + + + + find_max_bobyqa + dlib/optimization.h + dlib/optimization/optimization_bobyqa_abstract.h + + This function is identical to the find_min_bobyqa routine + except that it negates the objective function before performing optimization. + Thus this function will attempt to find the maximizer of the objective rather than + the minimizer. +

    + Note that BOBYQA only works on functions of two or more variables. So if you need to perform + derivative-free optimization on a function of a single variable + then you should use the find_max_single_variable + function. +

    +
    + + optimization_ex.cpp.html + + +
    + + + + + isotonic_regression + dlib/optimization.h + dlib/optimization/isotonic_regression_abstract.h + + This object is a tool for performing 1-D isotonic regression. That is, it + finds the least squares fit of a non-parametric curve to some user supplied + data, subject to the constraint that the fitted curve is non-decreasing. + +

    + This is done using the fast O(n) pool adjacent violators algorithm. +

    +
    +
    + + + + + elastic_net + dlib/optimization/elastic_net.h + dlib/optimization/elastic_net_abstract.h + + This object is a tool for solving the following optimization problem: + +
    +   min_w:      length_squared(X*w - Y) + ridge_lambda*length_squared(w)
    +   such that:  sum(abs(w)) <= lasso_budget
    +
    + +

    + That is, it solves the elastic net optimization problem. This object also + has the special property that you can quickly obtain different solutions + for different settings of ridge_lambda, lasso_budget, and target Y values. +

    + +

    + This is because a large amount of work is precomputed in the constructor. + The solver will also remember the previous solution and will use that to + warm start subsequent invocations. Therefore, you can efficiently get + solutions for a wide range of regularization parameters. +

    + + + The particular algorithm used to solve it is described in the paper: +
    + Zhou, Quan, et al. "A reduction of the elastic net to support vector + machines with an application to gpu computing." arXiv preprint + arXiv:1409.1976 (2014). APA +
    + + And for the SVM solver sub-component we use the algorithm from: +
    + Hsieh, Cho-Jui, et al. "A dual coordinate descent method for large-scale + linear SVM." Proceedings of the 25th international conference on Machine + learning. ACM, 2008. +
    +
    +
    + + + + + find_min_using_approximate_derivatives + dlib/optimization.h + dlib/optimization/optimization_abstract.h + + Performs an unconstrained minimization of a nonlinear function using + some search strategy (e.g. bfgs_search_strategy). + This version doesn't take a gradient function but instead numerically approximates + the gradient. + + + optimization_ex.cpp.html + + + + + + + + solve_least_squares + dlib/optimization.h + dlib/optimization/optimization_least_squares_abstract.h + + This is a function for solving non-linear least squares problems. It uses a method + which combines the traditional Levenberg-Marquardt technique with a quasi-newton + approach. It is appropriate for large residual problems (i.e. problems where the + terms in the least squares function, the residuals, don't go to zero but remain + large at the solution) + + + + least_squares_ex.cpp.html + + + + + + + solve_least_squares_lm + dlib/optimization.h + dlib/optimization/optimization_least_squares_abstract.h + + This is a function for solving non-linear least squares problems. It uses + the traditional Levenberg-Marquardt technique. + It is appropriate for small residual problems (i.e. problems where the + terms in the least squares function, the residuals, go to zero at the solution) + + + + least_squares_ex.cpp.html + + + + + + + find_min_trust_region + dlib/optimization.h + dlib/optimization/optimization_trust_region_abstract.h + + Performs an unconstrained minimization of a nonlinear function using + a trust region method. + + + optimization_ex.cpp.html + + + + + + + + find_max_trust_region + dlib/optimization.h + dlib/optimization/optimization_trust_region_abstract.h + + Performs an unconstrained maximization of a nonlinear function using + a trust region method. + + + + + + + + find_max + dlib/optimization.h + dlib/optimization/optimization_abstract.h + + Performs an unconstrained maximization of a nonlinear function using + some search strategy (e.g. bfgs_search_strategy). + + + + + + + + find_max_box_constrained + dlib/optimization.h + dlib/optimization/optimization_abstract.h + + Performs a box constrained maximization of a nonlinear function using + some search strategy (e.g. bfgs_search_strategy). + This function uses a backtracking line search along with a gradient projection + step to handle the box constraints. + + + + + + + + find_max_single_variable + dlib/optimization.h + dlib/optimization/optimization_line_search_abstract.h + + Performs a bound constrained maximization of a nonlinear function. The + function must be of a single variable. Derivatives are not required. + + + + + + + + find_max_using_approximate_derivatives + dlib/optimization.h + dlib/optimization/optimization_abstract.h + + Performs an unconstrained maximization of a nonlinear function using + some search strategy (e.g. bfgs_search_strategy). + This version doesn't take a gradient function but instead numerically approximates + the gradient. + + + + + + + + + + global_function_search + dlib/global_optimization.h + dlib/global_optimization/global_function_search_abstract.h + + This object performs global optimization of a set of user supplied + functions. That is, given a set of functions, each of which could take a different + number of arguments, this object allows you to find which function and which arguments + produce the maximal output. + +

    + Importantly, the global_function_search object does not require the user to + supply derivatives. Moreover, the functions being optimized may contain discontinuities, + behave stochastically, and have many local maxima. The global_function_search + object will attempt to find the global optima in the face of these challenges. + It is also designed to use as few function evaluations as possible, making + it suitable for optimizing functions that are very expensive to evaluate. + It does this by alternating between two modes: a global exploration mode + and a local optima refinement mode. This is accomplished by building and + maintaining two models of the objective function: +

    +
      +
    1. + A global model that upper bounds our objective function. This is a non-parametric + piecewise linear model derived from all function evaluations ever seen by the + global_function_search object. This is based on the method described in Global + Optimization of Lipschitz Functions by Cédric Malherbe and Nicolas Vayatis in the + 2017 International Conference on Machine Learning. +
    2. +
    3. + A local quadratic model fit around the best point seen so far. This uses + a trust region method similar to what is proposed in: + The NEWUOA software for unconstrained optimization without derivatives By + M.J.D. Powell, 40th Workshop on Large Scale Nonlinear Optimization (Erice, + Italy, 2004) +
    4. +
    + + The behavior of the algorithm is illustrated in the following video, which shows the solver in action. In the video, the red line + is the function to be optimized and we are looking for the maximum point. Every time + the global_function_search samples a point from the function we note it with a little + box. The state of the solver is determined by the two models discussed above. + Therefore, we draw the upper bounding model as well as the current local quadratic model + so you can see how they evolve as the optimization proceeds. We also note the location of the + best point seen so far by a little vertical line. +

    + You can see that the optimizer is alternating between picking the maximum upper bounding + point and the maximum point according to the quadratic model. As the optimization + progresses, the upper bound becomes progressively more accurate, helping to find the + best peak to investigate, while the quadratic model quickly finds a high precision + maximizer on whatever peak it currently rests. These two things together allow the + optimizer to find the true global maximizer to high precision (within 1e-9 in this case) by the time the + video concludes. +

    +
    + +
    + +

    + Finally, note that the find_max_global routine is + essentially a simple wrapper around the global_function_search object and exists to + provide a convenient interface. Most users will therefore want to call find_max_global + rather than global_function_search. However, the API of global_function_search + is more general and allows for of a wider set of usage patterns, for example, executing + objective function evaluations in parallel. So more advanced users may want to use + global_function_search directly rather than find_max_global. But try to use find_max_global() first. +

    + +
    + +
    + + + + + find_max_global + dlib/global_optimization.h + dlib/global_optimization/find_max_global_abstract.h + + This function performs global optimization of a function, subject + to bounds constraints. This means it attempts to find the global + maximizer, not just a local maximizer. The search is performed + using the global_function_search object. + See global_function_search's documentation for details of the algorithm. Importantly, + find_max_global() does not require the user to specify derivatives + or starting guesses, all while attempting to use as few calls to + the objective function as possible. It is therefore appropriate for tasks + where evaluating the objective function is time consuming or + expensive, such as in hyper parameter optimization of machine + learning models. + + + + optimization_ex.cpp.html + model_selection_ex.cpp.html + global_optimization.py.html + + + + + + + + find_min_global + dlib/global_optimization.h + dlib/global_optimization/find_max_global_abstract.h + + This function is identical to the find_max_global routine + except it negates the objective function before performing optimization. + Thus this function will attempt to find the minimizer of the objective rather than + the maximizer. + + + optimization_ex.cpp.html + model_selection_ex.cpp.html + global_optimization.py.html + + + + + + + + upper_bound_function + dlib/global_optimization.h + dlib/global_optimization/upper_bound_function_abstract.h + + This object represents a piecewise linear non-parametric function that can + be used to define an upper bound on some more complex and unknown function. + +

    + This is based on the method described in Global Optimization of Lipschitz + Functions by Cédric Malherbe and Nicolas Vayatis in the 2017 International + Conference on Machine Learning. Here we have extended it to support modeling of + stochastic or discontinuous functions by adding a noise term. We also model separate + Lipschitz parameters for each dimension, allowing the model to handle functions with + widely varying sensitivities to each input variable. +

    +
    + +
    + + + + + call_function_and_expand_args + dlib/global_optimization.h + dlib/global_optimization/find_max_global_abstract.h + + This routine allows you to pass a dlib::matrix<double,0,1> object to + a function that takes simple double arguments. It does this by automatically + expanding the matrix elements and invoking the function. For example, suppose you had + a function like this: + +double f(double x, double y, double z); + You could then call f() like this: + +matrix<double,0,1> args = {3,4,5}; +call_function_and_expand_args(f, args); // calls: f(3,4,5) + + This kind of thing is convenient when writing optimizers like find_max_global since it allows a wide range of + input functions to be given to the optimizer, including functions with + explicitly named arguments like x,y,z as shown above. + + + + + + + +
    + + + + +
    + diff --git a/ml/dlib/docs/docs/other.xml b/ml/dlib/docs/docs/other.xml new file mode 100644 index 000000000..d8d151d87 --- /dev/null +++ b/ml/dlib/docs/docs/other.xml @@ -0,0 +1,1166 @@ + + + + + Miscellaneous + + + + + +

    + + This page documents library components that don't really fit in anywhere else. + They all follow the same conventions as the rest of the library. +

    + + + + + + + +
    + Objects + bit_stream + byte_orderer + std_allocator + memory_manager + memory_manager_global + memory_manager_stateless + default_memory_manager + sync_extension + timer + timeout + member_function_pointer + vectorstream + unserialize + bound_function_pointer + error + console_progress_indicator + pipe + copy_functor + logger + + Fixed_width_integers + + uint64 + uint32 + uint16 + uint8 + int64 + int32 + int16 + int8 + + +
    + +
    + Global Functions + ramdump + check_serialized_version + deserialize + serialize + zero_extend_cast + make_mfp + TIME_THIS + timing code blocks +
    + +
    + SQLite + database + statement + transaction + + simple_queries + + + query_object + dlib/sqlite/sqlite_tools_abstract.h.html#query_object + + + query_text + dlib/sqlite/sqlite_tools_abstract.h.html#query_text + + + query_double + dlib/sqlite/sqlite_tools_abstract.h.html#query_double + + + query_int + dlib/sqlite/sqlite_tools_abstract.h.html#query_int + + + query_int64 + dlib/sqlite/sqlite_tools_abstract.h.html#query_int64 + + + query_blob + dlib/sqlite/sqlite_tools_abstract.h.html#query_blob + + + +
    + +
    + Other + dlib_testing_suite + MATLAB + Java +
    + +
    +
    + + + + + + + + + + + + zero_extend_cast + dlib/uintn.h + dlib/uintn.h + + This is a global function that performs a zero extending cast + from one integral type to another integral type. + + + + + + + + uint32 + dlib/uintn.h + dlib/uintn.h + + This is just a typedef for a 32 bit unsigned integer. + + + + + + + + uint8 + dlib/uintn.h + dlib/uintn.h + + This is just a typedef for an 8 bit unsigned integer. + + + + + + + + + uint16 + dlib/uintn.h + dlib/uintn.h + + This is just a typedef for a 16 bit unsigned integer. + + + + + + + + int8 + dlib/uintn.h + dlib/uintn.h + + This is just a typedef for an 8 bit integer. + + + + + + + + int16 + dlib/uintn.h + dlib/uintn.h + + This is just a typedef for a 16 bit integer. + + + + + + + + int32 + dlib/uintn.h + dlib/uintn.h + + This is just a typedef for a 32 bit integer. + + + + + + + + int64 + dlib/uintn.h + dlib/uintn.h + + This is just a typedef for a 64 bit integer. + + + + + + + + std_allocator + dlib/std_allocator.h + dlib/std_allocator.h + + This object is an implementation of an allocator that conforms to the C++ standard + requirements for allocator objects. The M template argument is one of the dlib + memory manager objects and this allocator implementation will do all of its memory allocations + using whatever dlib memory manager you supply. + +

    + Thus, using this allocator object you can use any of the dlib memory manager objects with + the containers in the STL or with any other object that requires an STL style allocator object. +

    +
    + + + std_allocator_ex.cpp.html + + +
    + + + + + uint64 + dlib/uintn.h + dlib/uintn.h + + This is just a typedef for a 64 bit unsigned integer. + + + + + + + + copy_functor + dlib/algs.h + dlib/algs.h + + This is a templated function object that makes copies of something. + + + + + + + + logger + dlib/logger.h + dlib/logger/logger_kernel_abstract.h + + This component is a logging output stream in the style of the log4j + logger available for Java. + + + + logger_ex.cpp.html + logger_ex_2.cpp.html + logger_custom_output_ex.cpp.html + pipe_ex.cpp.html + + + + + extra_logger_headers + dlib/logger/extra_logger_headers.h + This extension contains additional logger headers you may chose to use instead of the + default one. + + + config_file + dlib/logger/logger_config_file.h + This extension provides the configure_loggers_from_file() function + which reads a configuration file from disk that sets up all your loggers. + + + + + + + + + + error + dlib/error.h + dlib/error.h + + This is the base exception class from which all exceptions in this + library inherit. + + + + + + + + console_progress_indicator + dlib/console_progress_indicator.h + dlib/console_progress_indicator.h + + This object is a tool for reporting how long a task will take + to complete. + + + + + + + + pipe + dlib/pipe.h + dlib/pipe/pipe_kernel_abstract.h + + This is a first in first out queue with a fixed maximum size. + It is suitable for passing objects between threads. + +

    + This object is optimized for speed, therefore, it uses + global swap() to create a zero-copy method for moving objects + around. For example, on a computer running Ubuntu 12.04 with + a 2.67GHz Intel i7 920 CPU it is possible to pass over 4.4 + million std::vector<int> objects a second between two + threads. This is regardless of the number of ints in the std::vector + objects. In particular, this test was done with 100,000 + ints in each std::vector. +

    +

    + Finally, note that you can use the pipe as an efficient method to pass + messages between two networked computers by using the bridge. +

    +
    + + + pipe_ex.cpp.html + pipe_ex_2.cpp.html + bridge_ex.cpp.html + + +
    + + + + + + bound_function_pointer + dlib/bound_function_pointer.h + dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h + + This object represents a function with all its arguments bound to specific objects. + +

    + This implementation is done using type erasure and placement new. This + means that it never allocates memory on the heap and instead stores everything + on the stack. +

    +
    + +
    + + + + + + vectorstream + dlib/vectorstream.h + dlib/vectorstream/vectorstream_abstract.h + + This is an iostream object that reads and writes from an in-memory buffer. + It functions very much the same way as the std::stringstream object. + However, while the std::stringstream holds its buffer internally and it can + only be accessed by copying it out, the vectorstream uses an external + std::vector<char> as its buffer. That is, it holds a reference to an + external vector and does not contain any internal buffers of its own. + +

    + This object is useful as a slightly more efficient alternative to the + std::stringstream since you can avoid the overhead of copying buffer + contents to and from the stream. This is particularly useful when used as + a source or target for serialization routines. +

    +
    +
    + + + + + + unserialize + dlib/vectorstream.h + dlib/vectorstream/unserialize_abstract.h + + This object effectively allows you to peek at the next serialized + object in an istream. It does this by allowing you to read an object + and then put it back. + + + + + + + + member_function_pointer + dlib/member_function_pointer.h + dlib/member_function_pointer/member_function_pointer_kernel_abstract.h + + This object represents a member function pointer. It is useful because + instances of this object can be created without needing to know the type + of object whose member function we will be calling. +

    + The implementation of this object is done using type erasure and placement new. This + means that it never allocates memory on the heap and instead stores everything + on the stack. +

    +
    + + + member_function_pointer_ex.cpp.html + + +
    + + + + + + make_mfp + dlib/member_function_pointer.h + dlib/member_function_pointer/make_mfp_abstract.h + + This function is a simple factory for creating member_function_pointer + objects without needing to know the necessary template arguments for the member_function_pointer. + + + + + + + + bit_stream + dlib/bit_stream.h + dlib/bit_stream/bit_stream_kernel_abstract.h + + This object represents a middle man between a user and the iostream classes that allows single + bits to be read/written easily from/to the iostream classes + + + + + bit_stream_kernel_1 + dlib/bit_stream/bit_stream_kernel_1.h + + This implementation is done by buffering single bits in the obvious way. + + + + + kernel_1a + is a typedef for bit_stream_kernel_1 + + + + + + + + + + + bit_stream_multi + dlib/bit_stream/bit_stream_multi_abstract.h + This extension gives a bit_stream object the ability to read/write multiple bits at a time. + + + bit_stream_multi_1 + dlib/bit_stream/bit_stream_multi_1.h + This implementation is done by calling the read/write functions in the bit_stream kernel. + + + multi_1a + is a typedef for bit_stream_kernel_1 extended by bit_stream_multi_1 + + + + + + + + + + + + + + byte_orderer + dlib/byte_orderer.h + dlib/byte_orderer/byte_orderer_kernel_abstract.h + + This object provides a simple type safe mechanism to convert data + to and from network and host byte orders. I.e. to convert things + between big and little endian byte ordering. + + + + + + + + default_memory_manager + dlib/algs.h + dlib/algs.h + + This is a memory manager object which simply calls new and delete directly (i.e. + it doesn't really do anything). It is the default memory manager used by most + of the objects in dlib. + + + + + + + memory_manager_stateless + dlib/memory_manager_stateless.h + dlib/memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + + This object represents some kind of stateless memory manager or memory pool. + Stateless means that all instances (instances of the same type that is) + of this object are identical and can be used interchangeably. Note that + implementations are allowed to have some shared global state such as a + global memory pool. This object is also thread safe. + + + + + memory_manager_stateless_kernel_1 + dlib/memory_manager_stateless/memory_manager_stateless_kernel_1.h + + This implementation just calls new and delete. So it doesn't do anything special. + + + + + kernel_1a + is a typedef for memory_manager_stateless_kernel_1 + + + + + + + memory_manager_stateless_kernel_2 + dlib/memory_manager_stateless/memory_manager_stateless_kernel_2.h + + This implementation uses a global instance of a memory_manager object + guarded by a mutex as its implementation. + + + + + kernel_2_1a + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_1a + + + kernel_2_1b + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_1b + + + kernel_2_1c + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_1c + + + kernel_2_1d + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_1d + + + kernel_2_1e + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_1e + + + kernel_2_1f + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_1f + + + kernel_2_2a + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_2a + + + kernel_2_2b + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_2b + + + kernel_2_2c + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_2c + + + kernel_2_2d + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_2d + + + kernel_2_2e + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_2e + + + + kernel_2_3a + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_3a + + + kernel_2_3b + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_3b + + + kernel_2_3c + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_3c + + + kernel_2_3d + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_3d + + + kernel_2_3e + is a typedef for memory_manager_stateless_kernel_2 that uses memory_manager_3e + + + + + + + + + + + + memory_manager_global + dlib/memory_manager_global.h + dlib/memory_manager_global/memory_manager_global_kernel_abstract.h + + This object represents some kind of global memory manager or memory pool. + + + + + memory_manager_global_kernel_1 + dlib/memory_manager_global/memory_manager_global_kernel_1.h + + This is implemented in the obvious way. See the code for details. + + + + + kernel_1a + is a typedef for memory_manager_global_kernel_1 + + + + + + + + + + + + memory_manager + dlib/memory_manager.h + dlib/memory_manager/memory_manager_kernel_abstract.h + + This object represents a memory pool. + + + + + memory_manager_kernel_1 + dlib/memory_manager/memory_manager_kernel_1.h + + This memory manager implementation allocates objects one at a time when there are + allocation requests. Then when there is a deallocate request the returning object + is placed into a list of free blocks if that list has less than max_pool_size + blocks in it. Subsequent allocation requests will be serviced by drawing from the + free list whenever it isn't empty. Array allocations, on the other hand, are not + managed at all but are passed directly on to new and delete. +

    + When this object's max_pool_size template parameter is set to 0 it simply calls + new and delete directly and doesn't function as a memory pool. +

    +
    + + + + kernel_1a + is a typedef for memory_manager_kernel_1 with a max_pool_size of 0 + + + kernel_1b + is a typedef for memory_manager_kernel_1 with a max_pool_size of 10 + + + kernel_1c + is a typedef for memory_manager_kernel_1 with a max_pool_size of 100 + + + kernel_1d + is a typedef for memory_manager_kernel_1 with a max_pool_size of 1000 + + + kernel_1e + is a typedef for memory_manager_kernel_1 with a max_pool_size of 10000 + + + kernel_1f + is a typedef for memory_manager_kernel_1 with a max_pool_size of 100000 + + + +
    + + + memory_manager_kernel_2 + dlib/memory_manager/memory_manager_kernel_2.h + + This memory manager implementation allocates memory in blocks of chunk_size*sizeof(T) + bytes. All the sizeof(T) sub-blocks are kept in a linked list of free memory blocks + and are given out whenever an allocation request occurs. Also, memory is not freed + until this object is destructed. + Also note that array allocations are not managed at all but are passed directly + on to new and delete. + + + + + kernel_2a + is a typedef for memory_manager_kernel_2 with a chunk_size of 10 + + + kernel_2b + is a typedef for memory_manager_kernel_2 with a chunk_size of 100 + + + kernel_2c + is a typedef for memory_manager_kernel_2 with a chunk_size of 1000 + + + kernel_2d + is a typedef for memory_manager_kernel_2 with a chunk_size of 10000 + + + kernel_2e + is a typedef for memory_manager_kernel_2 with a chunk_size of 100000 + + + + + + + memory_manager_kernel_3 + dlib/memory_manager/memory_manager_kernel_3.h + + This memory manager implementation allocates memory in blocks of chunk_size*sizeof(T) + bytes. All the sizeof(T) sub-blocks are kept in a linked list of free memory blocks + and are given out whenever an allocation request occurs. Note that array allocations + are managed. So this object is just like kernel_2 but it also pools memory from + array allocations (chunk_size has no effect with respect to array allocations, each array + is allocated one at a time). + Also, memory is not freed until this object is destructed. + + + + + kernel_3a + is a typedef for memory_manager_kernel_3 with a chunk_size of 10 + + + kernel_3b + is a typedef for memory_manager_kernel_3 with a chunk_size of 100 + + + kernel_3c + is a typedef for memory_manager_kernel_3 with a chunk_size of 1000 + + + kernel_3d + is a typedef for memory_manager_kernel_3 with a chunk_size of 10000 + + + kernel_3e + is a typedef for memory_manager_kernel_3 with a chunk_size of 100000 + + + + +
    + +
    + + + + + sync_extension + dlib/sync_extension.h + dlib/sync_extension/sync_extension_kernel_abstract.h + + + This object represents a general extension to any object. This object gives any object which it extends + an integrated rmutex and rsignaler object. The extended object will + then be able to be treated as if it was also a rmutex and + rsignaler. + + + + + + sync_extension_kernel_1 + dlib/sync_extension/sync_extension_kernel_1.h + + This is implemented using a rmutex + and rsignaler in the obvious way. + + + + + kernel_1a + is a typedef for sync_extension_kernel_1 + + + + + + + + + + + + + + timeout + dlib/timeout.h + dlib/timeout/timeout_abstract.h + + This object provides a simple way to implement a timeout. + + + + + + + timer + dlib/timer.h + dlib/timer/timer_abstract.h + + This object represents a timer that will call a given member function + repeatedly at regular intervals. +

    + The implementation of this object has a single master thread + that does all the waiting. This master thread creates and + dispatches threads to specific timer objects when they need + to run their action functions. When a timer object isn't + executing its action function then it doesn't have any thread + allocated to it at all. So it is fairly efficient. +

    +
    + + + timer_ex.cpp.html + + +
    + + + + + database + dlib/sqlite.h + dlib/sqlite/sqlite_abstract.h + + This object is a C++ wrapper around a SQLite database connection + handle and therefore represents a SQLite database file. + +

    + Note that this wrapper is targeted at SQLite Version 3. To use it + you must make sure you link your application with SQLite. However, + if you use CMake and dlib's default CMakeLists.txt file then it will get setup + automatically. This is assuming sqlite3 is properly installed on your system. + On ubuntu you can get it by installing the libsqlite3-dev package. Or you can always + download the SQLite source + and compile it straight into your application (download the amalgamation). +

    +
    + + sqlite_ex.cpp.html + + +
    + + + + + statement + dlib/sqlite.h + dlib/sqlite/sqlite_abstract.h + + This object represents a SQL statement which can be executed + against a database object. In particular, this object is a + C++ wrapper around a SQLite prepared statement. +

    + Note that this wrapper is targeted at SQLite Version 3. To use it + you must make sure you link your application with SQLite. +

    +
    + + sqlite_ex.cpp.html + + + +
    + + + + + transaction + dlib/sqlite.h + dlib/sqlite/sqlite_tools_abstract.h + + This object is a tool for creating exception safe + database transactions. + + + sqlite_ex.cpp.html + + + + + + + + + deserialize + dlib/serialize.h + dlib/serialize.h + + This is actually a set of overloaded functions which provide the ability to restore an object's state + from an input stream. Currently all dlib container classes, non pointer C++ intrinsics, std::string, + std::vector, std::map, std::set, std::complex, dlib::bigint, dlib::uint64, dlib::int64, C style arrays, and dlib::vector objects are serializable. +

    + You can also use serialize() and deserialize() to read/write Google protocol buffer objects. However, + note that dlib::serialize() writes additional delimiting bytes at the start of each protocol buffer message. + We do this because Google protocol buffers are not + self-delimiting + on their own. This means that you can't write more than one protocol buffer object to an output stream + unless you include some kind of delimiter between the messages. + So dlib takes care of this for you by prefixing each message with its length in bytes. In particular, + the number of bytes is encoded as a 32bit little endian integer. +

    +
    + +
    + + + + + serialize + dlib/serialize.h + dlib/serialize.h + + This is actually a set of overloaded functions which provide the ability to save an object's state + to an output stream. Currently all dlib container classes, non pointer C++ intrinsics, std::string, + std::vector, std::map, std::set, std::complex, dlib::bigint, dlib::uint64, dlib::int64, C style arrays, and dlib::vector objects are serializable. +

    + You can also use serialize() and deserialize() to read/write Google protocol buffer objects. However, + note that dlib::serialize() writes additional delimiting bytes at the start of each protocol buffer message. + We do this because Google protocol buffers are not + self-delimiting + on their own. This means that you can't write more than one protocol buffer object to an output stream + unless you include some kind of delimiter between the messages. + So dlib takes care of this for you by prefixing each message with its length in bytes. In particular, + the number of bytes is encoded as a 32bit little endian integer. +

    +
    + +
    + + + + + ramdump + dlib/serialize.h + dlib/serialize.h + + This is a type decoration used to indicate that serialization should be + done by simply dumping the memory of some object to disk as fast as + possible without any sort of conversions. This means that the data written + will be "non-portable" in the sense that the format output by a RAM dump + may depend on things like the endianness of your CPU or settings of certain + compiler switches. + +

    + You use this object like this: + +serialize("yourfile.dat") << ramdump(yourobject); +deserialize("yourfile.dat") >> ramdump(yourobject); + or + +serialize(ramdump(yourobject), out); +deserialize(ramdump(yourobject), in); + + Also, not all objects have a ramdump mode. If you try to use ramdump on an + object that does not define a serialization dump for ramdump you will get a + compiler error. +

    +
    +
    + + + + + check_serialized_version + dlib/serialize.h + dlib/serialize.h + + This function deserializes a string and checks if it matches a user supplied + string (the version). If they don't match then dlib::serialization_error is + thrown. The point of this function is to make checking version strings in + serialized files a little more convenient. + + + + + + + dlib_testing_suite + +

    + This library comes with a command line driven regression test suite. All the testing code + is located in the dlib/testdlib/test folder. If you want to build it and test the library on your + system you can use the makefile at dlib/test/makefile (you may + have to edit it to make it work on your system) or use the CMake CMakeLists.txt file at + dlib/test/CMakeLists.txt to build it. +

    +

    + What you may find more useful however is the testing framework itself. It uses a fairly simple + and modular design. Each test is contained in its own cpp file and when compiled into the + program it automatically shows up in the list of tests to run. If you want to use the + testing framework all you need to do is add the files dlib/test/tester.h, + dlib/test/tester.cpp, and dlib/test/main.cpp + to your project and then add cpp files that contain your tests (see + dlib/test/example.cpp and + dlib/test/example_args.cpp + for some examples). +

    +

    + From the command line you can choose to run all the installed tests, enable or disable the loggers, + set various logging levels, specify how many times to run the tests, or pick just one or two tests + to run at a time rather than the entire suite. + The output of the program, that is, its return value from main() is the number of + failed tests. So if every test succeeds then it returns 0. +

    +
    + +
    + + + + + MATLAB + + dlib contains a tool that makes it easy to call C++ code from MATLAB. It's + documented in the examples in the dlib/matlab folder. In particular, the + dlib/matlab/example_mex_function.cpp, + dlib/matlab/example_mex_callback.cpp, and + dlib/matlab/example_mex_struct.cpp + examples. + You can also easily compile these files using CMake. See the instructions in the README file + in the dlib/matlab folder for further details. + + + + + + + + Java + + dlib contains some CMake scripts and related tools that make calling C++ code + from Java easy. If you look in the dlib/java folder you can find a CMake + project that uses SWIG to build some C++ code and then call it from Java. In + particular, if you run the run_test.sh script it will build and run the code, + calling it from java. + +

    + The dlib/java folder also contains some SWIG aware C++ classes that make + interacting with java arrays (e.g. double[]) from C++ efficient and easy. + See the documentation at the top of the java_array.h file for details. +

    +
    + +
    + + + + + TIME_THIS + dlib/time_this.h + dlib/time_this.h + +

    + This is a macro function for timing blocks of code. Its form is TIME_THIS(whatever you want to time) + It's pretty straight forward. It just prints the time it took to std::cout. +

    +

    + There is another version of this function called TIME_THIS_TO which takes as a parameter an ostream + object to write its output to. Its form is TIME_THIS_TO(what you want to time, the output stream); +

    + +
    + +
    + + + + + + timing code blocks + dlib/timing.h + dlib/timing.h + + This is a set of set of functions for timing blocks of code. Unlike + TIME_THIS, it can be used to find the cumulative + time spent on a block which is executed multiple times. + + + + + + + + +
    + + + + +
    diff --git a/ml/dlib/docs/docs/parsing.xml b/ml/dlib/docs/docs/parsing.xml new file mode 100644 index 000000000..b993acaae --- /dev/null +++ b/ml/dlib/docs/docs/parsing.xml @@ -0,0 +1,652 @@ + + + + + Parsing + + + + + +

    + This page documents the objects and functions that in some way deal with parsing or otherwise + manipulating text. + Everything here follows the same conventions as the rest of the library. +

    + + + + + + + + + +
    + Objects + cmd_line_parser + config_reader + cpp_pretty_printer + cpp_tokenizer + tokenizer + xml_parser + base64 + unichar + ustring + basic_utf8_ifstream + +
    + +
    + Global Functions + string_cast + string_assign + cast_to_string + pad_int_with_zeros + cast_to_wstring + wrap_string + narrow + trim + ltrim + rtrim + pad + lpad + rpad + split_on_first + split_on_last + left_substr + right_substr + split + tolower + toupper + convert_utf8_to_utf32 + is_combining_char + strings_equal_ignore_case +
    +
    +
    + + + + + + + + + + + + + toupper + dlib/string.h + dlib/string/string_abstract.h + + This is a function to convert a string to all uppercase. + + + + + + + + tolower + dlib/string.h + dlib/string/string_abstract.h + + This is a function to convert a string to all lowercase. + + + + + + + + + split_on_first + dlib/string.h + dlib/string/string_abstract.h + + Breaks a string into two parts. The split point is selected based + on the first occurrence of a delimiter character. + + + + + + + split_on_last + dlib/string.h + dlib/string/string_abstract.h + + Breaks a string into two parts. The split point is selected based + on the last occurrence of a delimiter character. + + + + + + + split + dlib/string.h + dlib/string/string_abstract.h + + Breaks a string into a sequence of substrings delimited + by a user specified set of characters. + + + + + + + right_substr + dlib/string.h + dlib/string/string_abstract.h + + This is a function to return the part of a string to the right of a user supplied delimiter. + + + + + + + left_substr + dlib/string.h + dlib/string/string_abstract.h + + This is a function to return the part of a string to the left of a user supplied delimiter. + + + + + + + + rpad + dlib/string.h + dlib/string/string_abstract.h + + This is a function to pad whitespace (or user specified characters) onto the right most end of a string. + + + + + + + + lpad + dlib/string.h + dlib/string/string_abstract.h + + This is a function to pad whitespace (or user specified characters) onto the left most end of a string. + + + + + + + + pad + dlib/string.h + dlib/string/string_abstract.h + + This is a function to pad whitespace (or user specified characters) onto the ends of a string. + + + + + + + + rtrim + dlib/string.h + dlib/string/string_abstract.h + + This is a function to remove the whitespace (or user specified characters) from the right most end of a string. + + + + + + + + ltrim + dlib/string.h + dlib/string/string_abstract.h + + This is a function to remove the whitespace (or user specified characters) from the left most end of a string. + + + + + + + + trim + dlib/string.h + dlib/string/string_abstract.h + + This is a function to remove the whitespace (or user specified characters) from the ends of a string. + + + + + + + + narrow + dlib/string.h + dlib/string/string_abstract.h + + This is a function for converting a string of type std::string or std::wstring + to a plain std::string. + + + + + + + + wrap_string + dlib/string.h + dlib/string/string_abstract.h + + wrap_string is a function that takes a string and breaks it into a number of + lines of a given length. You can use this to make a string + fit nicely into a command prompt window for example. + + + + + + + + strings_equal_ignore_case + dlib/string.h + dlib/string/string_abstract.h + + This is a pair of functions to do a case insensitive comparison between strings. + + + + + + + + cast_to_wstring + dlib/string.h + dlib/string/string_abstract.h + + cast_to_string is a templated function which makes it easy to convert arbitrary objects to + std::wstring strings. The types supported are any types that can be written to std::wostream via + operator<<. + + + + + + + + cast_to_string + dlib/string.h + dlib/string/string_abstract.h + + cast_to_string is a templated function which makes it easy to convert arbitrary objects to + std::string strings. The types supported are any types that can be written to std::ostream via + operator<<. + + + + + + + + pad_int_with_zeros + dlib/string.h + dlib/string/string_abstract.h + + Converts an integer into a string and pads it with leading zeros. + + + + + + + + string_cast + dlib/string.h + dlib/string/string_abstract.h + + string_cast is a templated function which makes it easy to convert strings to + other types. The types supported are any types that can be read by the basic_istream operator>>. It + also supports casting between wstring, string, and ustring objects. + + + + + + + string_assign + dlib/string.h + dlib/string/string_abstract.h + + string_assign is an object which makes it easy to convert strings to + other types. The types supported are any types that can be read by the basic_istream operator>>. It + also supports casting between wstring, string, and ustring objects. Since + string_assign is a simple stateless object there is a global instance of it + called dlib::sa. + + + config_reader_ex.cpp.html + + + + + + + + unichar + dlib/unicode.h + dlib/unicode/unicode_abstract.h + + This is a typedef for an unsigned 32bit integer which we use to store + Unicode values. + + + + + + + + basic_utf8_ifstream + dlib/unicode.h + dlib/unicode/unicode_abstract.h + + This object represents an input file stream much like the + normal std::ifstream except that it knows how to read UTF-8 + data. So when you read characters out of this stream it will + automatically convert them from the UTF-8 multibyte encoding + into a fixed width wide character encoding. + +

    + There are also two typedefs of this object. The first is utf8_wifstream which is a + typedef for wchar_t as the wide character to read into. The second is utf8_uifstream + which uses unichar instead of wchar_t. +

    +
    + +
    + + + + + + ustring + dlib/unicode.h + dlib/unicode/unicode_abstract.h + + This is a typedef for a std::basic_string<unichar>. That is, it is a typedef + for a string object that stores unichar Unicode characters. + + + + + + + + is_combining_char + dlib/unicode.h + dlib/unicode/unicode_abstract.h + + This is a global function that can tell you if a character is a Unicode + combining character or not. + + + + + + + + convert_utf8_to_utf32 + dlib/unicode.h + dlib/unicode/unicode_abstract.h + + This is a global function that can convert UTF-8 strings into strings + of 32bit unichar characters. + + + + + + + + base64 + dlib/base64.h + dlib/base64/base64_kernel_abstract.h + + This object allows you to encode and decode data to and from + the Base64 Content-Transfer-Encoding defined in section 6.8 of + rfc2045. + + + + file_to_code_ex.cpp.html + + + + + + + cmd_line_parser + dlib/cmd_line_parser.h + dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h + + This object allows you to easily parse a command line. Note that the + documentation for the cmd_line_parser_option + (the object returned by the parser's .option() function) is in a separate file. +

    + Note also that there are standard typedefs for the ASCII and wide character versions of the + cmd_line_parser template. These are the command_line_parser and wcommand_line_parser + types respectively. +

    +
    + + + compress_stream_ex.cpp.html + train_object_detector.cpp.html + + + + + get_option + dlib/cmd_line_parser/get_option_abstract.h + This extension provides a convenience function for accessing the + options to a command line argument or a config_reader. It + is automatically #included when using the command line parser or config reader. + + + + + +
    + + + + + config_reader + dlib/config_reader.h + dlib/config_reader/config_reader_kernel_abstract.h + + This object represents something which is intended to be used to read + text configuration files. + + + + config_reader_ex.cpp.html + + + + + config_reader_thread_safe + dlib/config_reader/config_reader_thread_safe_abstract.h + + This object extends a normal config_reader by simply wrapping all + its member functions inside mutex locks to make it safe to use + in a threaded program. + + + + + + + + + + + cpp_pretty_printer + dlib/cpp_pretty_printer.h + dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h + + This object represents an HTML pretty printer for C++ source code. + + + + + cpp_pretty_printer_kernel_1 + dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h + + This is implemented by using the cpp_tokenizer object. + This is the pretty printer I use on all the source in this library. It applies a color scheme, turns + include directives such as #include "file.h" into links to file.h.html and puts HTML anchor points + on function and class declarations. It also looks for comments starting with /*!A and puts an anchor + before the comment using the word following the A as the name of the anchor. + + + + + kernel_1a + is a typedef for cpp_pretty_printer_kernel_1 + + + + + + cpp_pretty_printer_kernel_2 + dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h + + This is implemented by using the cpp_tokenizer object. + It applies a black and white color scheme suitable + for printing on a black and white printer. It also places the document title + prominently at the top of the pretty printed source file. + + + + + kernel_2a + is a typedef for cpp_pretty_printer_kernel_2 + + + + + + + + + + + + + cpp_tokenizer + dlib/cpp_tokenizer.h + dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h + + This object represents a simple tokenizer for C++ source code. + + + + + cpp_tokenizer_kernel_1 + dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h + + This is implemented by using the tokenizer object in the obvious way. + + + + + kernel_1a + is a typedef for cpp_tokenizer_kernel_1 + + + + + + + + + + + + + tokenizer + dlib/tokenizer.h + dlib/tokenizer/tokenizer_kernel_abstract.h + + This object represents a simple tokenizer for textual data. + + + + + tokenizer_kernel_1 + dlib/tokenizer/tokenizer_kernel_1.h + + This is implemented in the obvious way. + + + + + kernel_1a + is a typedef for tokenizer_kernel_1 + + + + + + + + + + + + + xml_parser + dlib/xml_parser.h + dlib/xml_parser/xml_parser_kernel_abstract.h + + + This object represents a simple SAX style event driven XML parser. + It takes its input from an input stream object and sends events to all + registered document_handler and error_handler objects. +

    + + The xml_parser object also uses the interface classes + document_handler + and + error_handler. + Subclasses of these classes are passed to the xml_parser which generates events while it's + parsing and sends them to the appropriate handler. + +
    + + + xml_parser_ex.cpp.html + +
    + + + + +
    + + + + +
    diff --git a/ml/dlib/docs/docs/plus.gif b/ml/dlib/docs/docs/plus.gif new file mode 100644 index 000000000..2d15c1417 Binary files /dev/null and b/ml/dlib/docs/docs/plus.gif differ diff --git a/ml/dlib/docs/docs/python/conf.py b/ml/dlib/docs/docs/python/conf.py new file mode 100644 index 000000000..c8a2a963b --- /dev/null +++ b/ml/dlib/docs/docs/python/conf.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +# +# dlib documentation build configuration file, created by +# sphinx-quickstart on Wed Jun 12 18:29:29 2013. +# +# This file is execfile()d with the current directory set to its containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import sys, os + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.abspath('.')) +sys.path.insert(0, os.path.abspath('../../../build/lib.linux-x86_64-2.7')) + +import generate_dlib_listing +generate_dlib_listing.make_listing_files() + +# -- General configuration ----------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +#needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = ['sphinx.ext.autodoc'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix of source filenames. +source_suffix = '.rst' + +# The encoding of source files. +#source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'dlib' +copyright = u'2013, Davis E. King' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +#version = '18.3' +# The full version, including alpha/beta/rc tags. +#release = '18.3' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +#language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +#today = '' +# Else, today_fmt is used as the format for a strftime call. +#today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ['_build'] + +# The reST default role (used for this markup: `text`) to use for all documents. +#default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +#add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +#add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +#show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting. +#modindex_common_prefix = [] + + +# -- Options for HTML output --------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = 'default' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +#html_theme_options = {} + +# Add any paths that contain custom themes here, relative to this directory. +#html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +#html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +#html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +#html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +#html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +#html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +#html_additional_pages = {} + +# If false, no module index is generated. +#html_domain_indices = True + +# If false, no index is generated. +#html_use_index = True + +# If true, the index is split into individual pages for each letter. +#html_split_index = False + +# If true, links to the reST sources are added to the pages. +#html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +#html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +#html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +#html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +#html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = 'dlibdoc' + + +# -- Options for LaTeX output -------------------------------------------------- + +latex_elements = { +# The paper size ('letterpaper' or 'a4paper'). +#'papersize': 'letterpaper', + +# The font size ('10pt', '11pt' or '12pt'). +#'pointsize': '10pt', + +# Additional stuff for the LaTeX preamble. +#'preamble': '', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [ + ('index', 'dlib.tex', u'dlib Documentation', + u'Davis', 'manual'), +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +#latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +#latex_use_parts = False + +# If true, show page references after internal links. +#latex_show_pagerefs = False + +# If true, show URL addresses after external links. +#latex_show_urls = False + +# Documents to append as an appendix to all manuals. +#latex_appendices = [] + +# If false, no module index is generated. +#latex_domain_indices = True + + +# -- Options for manual page output -------------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + ('index', 'dlib', u'dlib Documentation', + [u'Davis'], 1) +] + +# If true, show URL addresses after external links. +#man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------------ + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ('index', 'dlib', u'dlib Documentation', + u'Davis', 'dlib', 'One line description of project.', + 'Miscellaneous'), +] + +# Documents to append as an appendix to all manuals. +#texinfo_appendices = [] + +# If false, no module index is generated. +#texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +#texinfo_show_urls = 'footnote' diff --git a/ml/dlib/docs/docs/python/generate_dlib_listing.py b/ml/dlib/docs/docs/python/generate_dlib_listing.py new file mode 100644 index 000000000..2aaff2458 --- /dev/null +++ b/ml/dlib/docs/docs/python/generate_dlib_listing.py @@ -0,0 +1,32 @@ +from __future__ import print_function +import dlib +import inspect + +def print_element(name, fc, ff): + isclass = inspect.isclass(eval(name)) + ismodule = inspect.ismodule(eval(name)) + if (isclass): + print("* :class:`{0}`".format(name), file=fc) + elif (not ismodule): + print("* :func:`{0}`".format(name), file=ff) + +def make_listing_files(): + + fc = open('classes.txt', 'w') + ff = open('functions.txt', 'w') + + for obj in dir(dlib): + if obj[0] == '_': + continue + print_element('dlib.'+obj, fc, ff) + + for obj in dir(dlib.cuda): + if obj[0] == '_': + continue + print_element('dlib.cuda.'+obj, fc, ff) + + for obj in dir(dlib.image_dataset_metadata): + if obj[0] == '_': + continue + print_element('dlib.image_dataset_metadata.'+obj, fc, ff) + diff --git a/ml/dlib/docs/docs/python/index.rst b/ml/dlib/docs/docs/python/index.rst new file mode 100644 index 000000000..6ad2d3d5b --- /dev/null +++ b/ml/dlib/docs/docs/python/index.rst @@ -0,0 +1,45 @@ + +.. image:: ../dlib-logo.png + :alt: Dlib C++ Library + :target: http://dlib.net + +Dlib is principally a C++ library, however, you can use a number of its tools +from python applications. This page documents the python API for working with +these dlib tools. If you haven't done so already, you should probably look at +the python example programs first before consulting this reference. These +example programs are little mini-tutorials for using dlib from python. They +are listed on the left of the main dlib web page. + +Classes +============================================== + +.. include:: classes.txt + +Functions +============================================== + +.. include:: functions.txt + + +Detailed API Listing +============================================== + +.. toctree:: + :maxdepth: 2 + +.. automodule:: dlib + :members: + :undoc-members: + +.. automodule:: dlib.cuda + :members: + :undoc-members: + +.. automodule:: dlib.image_dataset_metadata + :members: + :undoc-members: + + + + + diff --git a/ml/dlib/docs/docs/rbf_big_gamma.gif b/ml/dlib/docs/docs/rbf_big_gamma.gif new file mode 100644 index 000000000..7885d655b Binary files /dev/null and b/ml/dlib/docs/docs/rbf_big_gamma.gif differ diff --git a/ml/dlib/docs/docs/rbf_normal.gif b/ml/dlib/docs/docs/rbf_normal.gif new file mode 100644 index 000000000..810fd52af Binary files /dev/null and b/ml/dlib/docs/docs/rbf_normal.gif differ diff --git a/ml/dlib/docs/docs/rbf_small_gamma.gif b/ml/dlib/docs/docs/rbf_small_gamma.gif new file mode 100644 index 000000000..24baaff63 Binary files /dev/null and b/ml/dlib/docs/docs/rbf_small_gamma.gif differ diff --git a/ml/dlib/docs/docs/release_notes.xml b/ml/dlib/docs/docs/release_notes.xml new file mode 100644 index 000000000..2bb03580e --- /dev/null +++ b/ml/dlib/docs/docs/release_notes.xml @@ -0,0 +1,4437 @@ + + + + + Release notes + + + + + + + + +New Features and Improvements: + - Deep Learning: + - Added scale_ layer, allowing implementation of squeeze-and-excitation networks. + - Added loss_multimulticlass_log: used for learning a collection of multi-class classifiers. + - Added a random forest regression tool. See random_forest_regression_trainer. + - Added make_bounding_box_regression_training_data() + - Added isotonic_regression + - Added momentum_filter, rect_filter, find_optimal_momentum_filter(), and + find_optimal_rect_filter(). + - Added binomial_random_vars_are_different() and event_correlation(). + - Added xcorr_fft(), a routine for efficiently performing large cross-correlations using the FFT. + - Added the ramdump type decorator for invoking faster serialization routines. + - Added check_serialized_version() + - Added max_scoring_element() and min_scoring_element() + - Made orthogonalize() faster. + - Updates to the Python API: + - Added interface to the global_function_search object. This is a more general + interface to the solver used by find_max_global(). + - Added support for variadic Python functions in find_max_global(). + - Added rect_filter and find_optimal_rect_filter(). + - Added make_bounding_box_regression_training_data() + - Added the image_dataset_metadata routines for parsing XML datasets. + - Added rvm_trainer + - Added probability_that_sequence_is_increasing() + - Added dlib.__time_compiled__ field + - Added num_threads to shape_predictor_training_options. + - Added CUDA controlling routines such as set_device() and + set_dnn_prefer_smallest_algorithms(). + +Non-Backwards Compatible Changes: + - Changed CMake so that there is only the dlib target and it isn't forced to + be static. Instead, the build type will toggle based on the state of CMake's + BUILD_SHARED_LIBS variable. So there is no longer a dlib_shared target. + - Changed the integer types used to represent sizes from 32bits to 64bits in numerous + places, such as in the tensor object. This should be a backwards compatible change + for nearly all client code. + +Bug fixes: + - Fixed memory leak in java swig array binding tool. + - Fixed windows include order problem in all/source.cpp file. + - Fixed cont_ layers not printing the correct num_filters parameter when they were + printed to std::cout or to XML. + - Fixed some code not handling OBJECT_PART_NOT_PRESENT values correctly. + - Fixed fft_inplace() not compiling for compile time sized matrices. + - The shape_predictor_trainer could have very bad runtime for some really + bad parameter settings. This has been fixed and also warning messages about + really bad training data or parameters have been added. + - Fixed the decayed running stats objects so they use unbiased estimators. + + + + + + +New Features and Improvements: + - Switched the Python API from Boost.Python to pybind11. This means Python + users don't need to install Boost anymore, making building dlib's Python API + much easier. + - Made the sparse version of svd_fast() use multiple CPU cores. + - Changed the behavior of imglab's --flip option. It will now attempt to + adjust any object part labels so that the flipped dataset has the same + average part layout as the source dataset. There is also a new --flip-basic + option that behaves like the old --flip. However, most people flipping a + dataset with part annotations will want to use --flip. For more details + see: http://blog.dlib.net/2018/01/correctly-mirroring-datasets.html + +Non-Backwards Compatible Changes: + - Removed std::auto_ptr from dlib's old (and depreciated) smart pointers. + +Bug fixes: + - Fixed global_optimization.py not working in Python 3. + + + + + + +New Features and Improvements: + - Added a global optimizer, find_max_global(), which is suitable for + optimizing expensive functions with many local optima. For example, you + can use it for hyperparameter optimization. See model_selection_ex.cpp + for an example. + - Updates to the deep learning tooling: + - Added semantic segmentation examples: dnn_semantic_segmentation_ex.cpp + and dnn_semantic_segmentation_train_ex.cpp + - New layers: loss_ranking, loss_epsilon_insensitive, softmax_all, and loss_dot. + - Made log loss layers more numerically stable. + - Upgraded the con layer so you can set the number of rows or columns to + 0 in the layer specification. Doing this means "make the filter cover + the whole input image dimension". This provides an easy way to make a + filter sized so it will have one output along that dimension, + effectively making it like a fully connected layer operating on a row + or column. + - Added support for non-scale-invariant MMOD. + - Added an optional parameter to dnn_trainer::get_net() that allows you + to call the function without forcing a state flush to disk. + - Sometimes the loss_mmod layer could experience excessively long runtime + during early training iterations. This has been optimized and is now + much faster. + - Optimized the tensor's management of GPU memory. It now uses less memory + in some cases. It will also not perform a reallocation if resized to a + smaller size. Instead, tensors now behave like std::vector in that + they just change their nominal size but keep the same memory, only + reallocating if they are resized to something larger than their + underlying memory block. This change makes some uses of dlib faster, in + particular, running networks on a large set of images of differing + sizes will now run faster since there won't be any GPU reallocations, + which are notoriously slow. + - Upgraded the input layer so you can give + input<std::array<matrix<T>,K>> types as input. Doing + this will create input tensors with K channels. + - Added disjoint_subsets_sized + - Added Python APIs: get_face_chips(), count_steps_without_decrease(), + count_steps_without_decrease_robust(), and jitter_image(). + - Various improvements to CMake scripts: e.g. improved warning and error + messages, added USE_NEON_INSTRUCTIONS option. + - chol() will use a banded Cholesky algorithm for banded matrices, making it + much faster in these cases. + - Changed the timing code to use the C++11 high resolution clock and + atomics. This makes the timing code a lot more precise. + +Non-Backwards Compatible Changes: + - Changed the random_cropper's set_min_object_size() routine to take min box + dimensions in the same format as the mmod_options object (i.e. two lengths + measured in pixels). This should make defining random_cropping strategies + that are consistent with MMOD settings more straightforward since you can + simply take the mmod_options settings and give them to the random_cropper + and it will do the right thing. + - Changed the mean squared loss layers to return a loss that's the MSE, not + 0.5*MSE. The only thing this effects is the logging messages that print + during training, which were confusing since the reported loss was half the + size you might naively expect. + - Changed the outputs of test_regression_function() and cross_validate_regression_trainer(). + These functions now output 4D rather than 2D vectors. The new output is: + mean squared error, correlation, mean absolute error, and standard + deviation of absolute error. I also made test_regression_function() take + a non-const reference to the regression function so that DNN objects can + be tested. + - Fixed shape_predictor_trainer padding so it behaves as it used to. In + dlib 19.7 the padding code was changed and accidentally doubled the size + of the applied padding in some cases. It's not a huge deal either way, but + this change reverts back to the previous behavior. + +Bug fixes: + - Fixed toMat() not compiling in some cases. + - Significantly reduced the compile time of the DNN example programs in + visual studio. + - Fixed a few image processing functions that weren't using the generic + image interface. + - Fixed a bug in the random_cropper where it might crash due to division by + 0 if small images were given as input. + - Fixed a bug in how the mmod_options automatically determines detection + window sizes. It would pick a bad size in some cases. + - Fixed load_image_dataset()'s skip_empty_images() option. It wasn't + skipping images that only have ignore boxes when you load into mmod_rect + objects. + - Fixed a bug where chinese_whispers(), when called from python, would + sometimes return a labels array that didn't include labels for all the + inputs. + - Fixed a bug in dlib's MS Windows GUI code that was introduced a little + while back when we switched everything to std::shared_ptr. This change + fixes a bug where the program crashes or hangs sometimes during program + shutdown. + - Fixed error in TIME_THIS() introduced in dlib 19.7. It was printing + seconds when it said minutes in the output. + - Adding missing implementation of tabbed_display::selected_tab. + - Changed the windows signaler and mutex code to use the C++11 thread + library instead of the old win32 functions. I did this to work around how + windows unloads dlls. In particular, during dll unload windows will kill + all threads, THEN it will destruct global objects. So this can lead to + problems when a global object that owns threads tries to tell them to + shutdown, since the threads have already vanished. The new code mitigates + some of these problems, in particular, there were some cases where + unloading dlib's python extension would deadlock. This should now be + fixed. + - Fixed compile time errors when either of these macros were enabled: + DLIB_STACK_TRACE, DLIB_ISO_CPP_ONLY. + + + + + + +New Features and Improvements: + - Deep Learning: + - The CNN+MMOD detector is now a multi-class detector. In particular, + the mmod_rect object now has a string label field which you can use to + label objects, and the loss_mmod_ layer will learn to label objects with + those labels. For an example, see: https://www.youtube.com/watch?v=OHbJ7HhbG74 + - CNN+MMOD detectors are now 2.5x faster. For instance, this example program + http://dlib.net/dnn_mmod_find_cars_ex.cpp.html now runs at 98fps instead + of 39fps. + - Added a 5 point face landmarking model that is over 10x smaller than the + 68 point model, runs faster, and works with both HOG and CNN generated + face detections. It is now the recommended landmarking model to use for + face alignment. render_face_detections() and get_face_chip_details() have been + updated to work with both 5 and 68 point models, so the new 5 point model is + a drop in replacement for the 68 point model. + - The imglab tool is slightly improved. It will display box labels with + higher relative contrast. You can also now press END or i to ignore boxes + in imglab. This is useful because it's a much less stressing hand motion + to hit END that i in most cases. + - Added overloads of sub_image() that take raw pointers so you can make + sub_images of anything. + - Changed TIME_THIS() to use std::chrono::high_resolution_clock, so now it's + much higher precision. + - Exposed Chinese whispers clustering to Python, added face clustering example. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed an error in input_rgb_image_pyramid::image_contained_point(). The + function might erroneously indicate that a point wasn't inside the original + image when really it was, causing spurious error messages. + - mmod_options would pick bad window sizes in some corner cases. This has been fixed. + - Fixed a bug in the extract layer that trigged when a tensor with a + different number of samples than the tensor used to initialize the network + was passed through the layer. + - The loss_per_missed_target parameter of the loss_mmod_ wasn't being used + exactly right when boxes were auto-ignored. There weren't any practical + user facing problems due to this, but it has nevertheless been fixed. + + + + + + +New Features and Improvements: + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fix build error in Visual Studio when CUDA is enabled. + + + + + + +New Features and Improvements: + - Deep Learning + - Added a python wrapper for using the CNN face detector. + - Added support for cuDNN v6 and v7. + - Added a simple tool to convert dlib model files to caffe models. + See the tools/convert_dlib_nets_to_caffe folder for details. + - New DNN layers + - loss_multiclass_log_per_pixel_ + - loss_multiclass_log_per_pixel_weighted_ + - loss_mean_squared_per_pixel_ + - cont_ (transpose convolution, sometimes called "deconvolution") + - mult_prev_ (like add_prev_ but multiplies instead of adds) + - extract_ (sort of like caffe's slice layer) + - upsample_ (upsamples a tensor using bilinear interpolation) + - Object Detection + - Upgraded loss_mmod_ to support objects of varying aspect ratio. This + changes the API for the mmod_options struct slightly. + - Relaxed the default non-max suppression parameters used by the + mmod_options object so that users of the deep learning MMOD tool don't + get spurious errors about impossibly labeled objects during training. + - Added missing input validation to loss_mmod_. Specifically, the loss + layer now checks if the user is giving truth boxes that can't be detected + because the non-max suppression settings would prevent them from being + output at the same time. If this happens then we print a warning message + and set one of the offending boxes to "ignore". I also changed all + the input validation errors to warning messages with auto conversion + to ignore boxes rather than exceptions. + - Changed the random_cropper's interface so that instead of talking in + terms of min and max object height, it's now min and max object size. + This way, if you have objects that are short and wide (i.e. objects where + the relevant dimension is width rather than height) you will get sensible + behavior out of the random cropper. + - Added options to input_rgb_image_pyramid that let the user set + create_tiled_pyramid()'s padding parameters. Also changed the default + outer border padding from 0 to 11. This effects even previously trained + models. So any model that doesn't explicitly set the outer patting to + something else will have a padding of 11. This should be a more + reasonable value for most networks. + - Added process() and process_batch() to add_loss_layer. These routines + let you easily pass arguments to any optional parameters of a loss + layer's to_tensor() routine. For instance, it makes it more convenient to + set loss_mmod_'s adjust_threshold parameter. + - Added visit_layers_until_tag() + - Improved how dnn_trainer synchronizes its state to disk. It now uses + two files and alternates between them. This should be more robust in + the face of random hardware failure during synchronization than the + previous synchronization method. + - Made it so you can set the number of output filters for con_ layers at runtime. + - The way cuDNN work buffers are managed has been improved, leading to + less GPU RAM usage. Therefore, users should not need to call + set_dnn_prefer_smallest_algorithms() anymore. + - Added operator<< for random_cropper and dnn_trainer to allow + easy logging of training parameters. + - Made concat_ layer a lot faster. + - Made the dnn_trainer not forget all the previous loss values it knows + about when it determines that there have been a lot of steps without + progress and shrinks the learning rate. Instead, it removes only a + small amount of the oldest values. The problem with the old way of + removing all the loss values in the history was that if you set the + steps without progress threshold to a really high number you would + often observe that the last few learning rate values were obviously not + making progress, however, since all the previous loss values were + forgotten the trainer needed to fully populate its loss history from + scratch before it would figure this out. This new style makes the + trainer not waste time running this excessive optimization of obviously + useless mini-batches. I also changed the default + get_test_iterations_without_progress_threshold() from 200 to 500. Now + that we have a better history management of loss values in the trainer + it's much more sensible to have a larger value here. + - Dlib's simd classes will now use ARM NEON instructions. This makes the + HOG based object detector faster on mobile devices running ARM processors. + - Added last_modified() method to dlib::file. Also, added + select_oldest_file() and select_newest_file(). + - Added solve_qp_box_constrained_blockdiag() + - Added an overload of mat() that takes a row stride value. + - Added cmake scripts and some related tooling that makes it easy to call + C++ code from java. See dlib/java/ folder. + - MATLAB MEX wrapper API + - Made the mex wrapper deal with cell arrays that have null elements. + - Made ctrl+c detection in a mex file work more reliably in newer versions of matlab. + - Added set_rect_area() + - Gave test_object_detection_function() an option to set how ignore box + overlap is tested. + - Added serialization support for the running_stats_decayed object. + - Additions to imglab + - Added --sort and also the ability to propagate boxes from one image to + the next using dlib::correlation_tracker. + - Made it so you can remove images by pressing alt+d. + - Made is so pressing e in imglab toggles between views of the image + where the histogram is equalized or unmodified. This way, if you are + looking at particularly dark or badly contrasted images you can toggle + this mode and maybe get a better view of what you are labeling. + - Made the attribute_list of the xml parser a little more friendly by + allowing you to ask for attributes that don't exist and get a defined + behavior (an exception being thrown) rather than it being a contract + violation. + +Non-Backwards Compatible Changes: + - DNN solver objects are now required to declare operator<<. + - Broke backwards compatibility with previous dnn_trainer serialization + format. The network serialization format has not changed however. So old + model files will still load properly. + - Changed random_cropper interface. + - Changed the XML format output by net_to_xml(). Specifically, the XML tag + for affine layers was changed to use the same conventions as other layers + that support convolutional vs fully connected modes. + - Dlib's smart pointers have been deprecated and all of dlib's code has been + changed to use the std:: version of these smart pointers. The old dlib + smart pointers are still present, allowing users to explicitly include + them if needed, but users should migrate to the C++11 standard version of + these tools. + - Changed the functions that transform between input tensor coordinates and + output tensor coordinates to use dpoint instead of point. This way, we can + obtain sub-pixel coordinates if we need them. + - Upgraded loss_mmod_ to support objects of varying aspect ratio. This + changes the API for the mmod_options struct slightly. + +Bug fixes: + - Made resize_image() and functions that use it like the pyramid objects + produce better results when run on float and double images. There was + needless rounding to integers happening in the bilinear interpolation. Now + if you work with a float image the entire process will run without integer + rounding. + - Made the input_tensor_to_output_tensor() and output_tensor_to_input_tensor() + coordinate mappings work on networks that contain skip layers. + - The input_rgb_image_sized is supposed to be convertible to + input_rgb_image, which it was in all ways except you couldn't deserialize + directly like you would expect. This has now been fixed. + - There was a bug in the concat_ layer's backward() method. It was assigning + the gradient to previous layers instead of adding the gradient, as required + by the layer interface specification. Probably no-one has been impacted + by this bug, but it's still a bug and has been fixed. + - Changed the random_cropper so that it samples background patches uniformly + across scales regardless of the input image size. Previously, if you gave + really large images or really small images it had a bias towards giving only + large patches or small patches respectively. + - Fixed name lookup problem for calls to serialize() on network objects. + - Fixed double delete in tokenizer_kernel_1. + - Fixed error in pyramid_down<2> that caused the output image to be a + little funny looking in some cases. + - Fixed the visit_layers_backwards() and visit_layers_backwards_range() + routines so they visit layers in the correct order. + - Made build scripts work on a wider range of platforms and configurations. + - Worked around global timer cleanup issues that occur on windows when dlib + is used in a dll in some situations. + - Fixed various compiler errors in obscure environments. + + + + + + +New Features: + +Non-Backwards Compatible Changes: + - CMake 2.8.12 is now required to build dlib (but only if you use CMake). + +Bug fixes: + - Fixed a slow memory leak that could occur when using cuDNN. + +Other: + + + + + +New Features: + - Deep Learning + - Added a state-of-the-art face recognition tool (99.38% accuracy on the + LFW benchmark) with C++ and Python example programs. + - Added these new loss layer types: loss_metric_, loss_mean_squared_, and + loss_mean_squared_multioutput_. + - Added the l2normalize_ computational layer. + - Added test_one_step() to the dnn_trainer. This allows you to do + automatic early stopping based on observing the loss on held out data. + - Made the dnn_trainer automatically reload from the last good state if a + loss of NaN is encountered. + - Made alias_tensor usable when it is const. + - Dlib's simd classes will now use PowerPC VSX instructions. This makes the + HOG based object detector faster on PowerPC machines. + - Added compute_roc_curve() + - Added find_gap_between_convex_hulls() + - Added serialization support for std::array. + - Added running_scalar_covariance_decayed object + - Added running_stats_decayed object + - Added min_pointwise() and max_pointwise(). + - Added a 1D clustering routine: segment_number_line(). + - Added Intel MKL FFT bindings. + - Added matlab_object to the mex wrapper. Now you can have parameters that + are arbitrary matlab objects. + - Added support for loading of RGBA JPEG images + +Non-Backwards Compatible Changes: + - Changed the loss layer interface to use two typedefs, output_label_type + and training_label_type instead of a single label_type. This way, the label + type used for training can be distinct from the type output by the network. + This change breaks backwards compatibility with the previous API. + +Bug fixes: + - Fixed compiler warnings and errors on newer compilers. + - Fixed a bug in the repeat layer that caused it to throw exceptions in some + cases. + - Fixed matlab crashing if an error message from a mex file used the % + character, since that is interpreted by matlab as part of an eventual + printf() code. + - Fixed compile time error in random_subset_selector::swap() + - Fixed missing implementation of map_input_to_output() and + map_output_to_input() in the concat_ layer. + - Made the dnn_trainer's detection and backtracking from situations with + increasing loss more robust. Now it will never get into a situation where it + backtracks over and over. Instead, it will only backtrack a few times in a + row before just letting SGD run unimpeded. + +Other: + - Usability improvements to DNN API. + - Improved C++11 detection, especially on OS X. + - Made dlib::thread_pool use std::thread and join on the threads in + thread_pool's destructor. The previous implementation used dlib's global + thread pooling to allocate threads to dlib::thread_pool, however, this + sometimes caused annoying behavior when used as part of a MATLAB mex file, + very occasionally leading to matlab crashes when mex files were unloaded. + This also means that dlib::thread_pool construction is a little bit slower + than it used to be. + + + + + +New Features: + - Updates to the deep learning API: + - Added tools for making convolutional neural network based object detectors. See + dnn_mmod_ex.cpp example program. + - Added annotation() to tensor so you can associate any object you want with a tensor. + - Made layer_details() part of the SUBNET interface so that user defined layer + details objects can access each other. Also added the input_layer() global function + for accessing the input layer specifically. + - alias_tensor can now create aliases of const tensors. + - Added set_all_bn_running_stats_window_sizes(). + - Added visit_layers_backwards(), visit_layers_backwards_range(), and + visit_layers_range(). + - Computational layers can now optionally define map_input_to_output() and + map_output_to_input() member functions. If all layers of a network provide these + functions then the new global functions input_tensor_to_output_tensor() and + output_tensor_to_input_tensor() can be used to map between the network's input and + output tensor coordinates. This is important for fully convolutional object + detectors since they need to map between the image space and final feature space. + These new functions are important for tools like the new MMOD detector. + - Added input_rgb_image_pyramid. + - Image Processing + - The imglab command line tool has these new options: --min-object-size, --rmempty, + --rmlabel, --rm-if-overlaps, and --sort-num-objects. I also changed the behavior of + --split so that it simply partitions the data and is an invertible operation. + - Added mmod_rect + - Added an overload of load_image_dataset() that outputs directly to mmod_rect + instead of rectangle. + - Added image_dataset_file::shrink_big_images(). So now load_image_dataset() can load + a dataset of high resolution files at a user requested lower resolution. + - Added box_intersection_over_union(). + - Added create_tiled_pyramid(), image_to_tiled_pyramid(), and tiled_pyramid_to_image(). + - Added random_cropper + - Upgraded dlib's mex wrapper tooling to enable easy binding of C++ classes to MATLAB + objects. + - Added nearest_rect() + - Added find_upper_quantile() + - Added count_steps_without_decrease_robust(). + - Added get_double_in_range() to dlib::rand. + +Non-Backwards Compatible Changes: + - C++11 is now required to use dlib. + - Changed pinv() so it interprets its tol argument relative to the largest singular + value of the input matrix rather than as an absolute tolerance. This should generally + improve results, but could change the output in some cases. + - Renamed the class members of test_box_overlap so they are less confusing. + - Updates to the deep learning API: + - Changed the DNN API so that sample_expansion_factor is a runtime variable rather + than a compile time constant. This also removes it from the input layer interface + since the DNN core now infers its value at runtime. Therefore, users that define their + own input layers don't need to specify it anymore. + - Changed DEFAULT_BATCH_NORM_EPS from 1e-5 to 1e-4. + - Changed the default batch normalization running stats window from 1000 to 100. + +Bug fixes: + - Made the relational operators constexpr so they don't accidentally cause compilation + errors when they get pulled into the scope of template metaprogramming expressions. + - Fixed all/source.cpp not compiling in some instances. + - CMake scripts now do a better job detecting things like C++11 support, the presence of + CUDA, and other system specific details that could cause the build to fail if not + properly configured. + - Fixed a bug in imglab's --cluster option where it would output xml files with empty + entries if the input xml file contained unannotated images. + - Fixed imglab's --cluster option not working with relative paths. + +Other: + - Made the thread local variables that hold the cudnn and cublas context objects not + destruct and recreate themselves when you switch devices. Instead, they keep a table + of context objects, for each thread and device, reusing as necessary. This prevents + churn in the context objects when you are switching back and forth between devices + inside a single thread, making things run more efficiently for some CUDA based + workflows. + - Made the message argument of the DLIB_ASSERT and DLIB_CASSERT macros optional. + - Made thread_pool and parallel_for propagate exceptions from task threads to calling + code rather than killing the application if a task thread throws. + - Changed imglab --resample so that it never changes the aspect ratio of an image. + - Made the check in dnn_trainer for convergence more robust. Previously, if we + encountered a bad mini-batch that made the loss value suddenly jump up by a larger than + normal value it could make the trainer think we converged. Now the test is robust to + transient spikes in loss value. Additionally, the dnn_trainer will now check if the + loss has been increasing before it saves the state to disk. If it detects that the loss + has been going up then instead of saving to disk it recalls the previously good state. + This way, if we hit a really bad mini-batch during training which negatively effects + the model in a significant way, the dnn_trainer will automatically revert back to an + earlier good state. + + + + + +New Features: + - Support for cuDNN 5.1 + - dlib::async() and dlib::default_thread_pool(). + - rectangle_transform + - imglab tool: added --resample, --ignore, --files, and --extract-chips + command line options. Also added convert_imglab_paths_to_relative and + copy_imglab_dataset scripts. + - Evgeniy Fominov made the shape_predictor trainer multi-threaded and faster. + - sutr90 contributed support for the CIELab color space. See the new lab_pixel. + +Non-Backwards Compatible Changes: + - All the cmake utility scripts were moved to dlib/cmake_utils. + - Code that #includes the shape_predictor can now only be compiled with + compilers that support C++11 lambda functions. + +Bug fixes: + - Made CMake scripts work in a wider range of environments. + - Fixed compile time errors on various platforms. + - Fixed bad multi-threading support in the MATLAB mex wrapper. + - Fixed bug in cuDNN binding that could sometimes cause NaN outputs. + - Fixed bad convergence testing in DNN tooling for very small datasets. + +Other: + + + + + +New Features: + - A deep learning toolkit using CPU and/or GPU hardware. Some major elements + of this are: + - Clean and fully documented C++11 API + - Clean tutorials: see dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp + - Uses cuDNN v5.0 + - Multi-GPU support + - Automatic learning rate adjustment + - A pretrained 1000 class Imagenet classifier (see dnn_imagenet_ex.cpp) + - Optimization Tools + - Added find_optimal_parameters() + - Added elastic_net class + - Added the option to use the elastic net regularizer to the OCA solver. + - Added an option to solve the L2-loss version of the SVM objective function to svm_c_linear_dcd_trainer. + - Added solve_qp_box_constrained() + - Image Processing + - Added random_color_transform, disturb_colors(), and apply_random_color_offset(). + - load_image() now supports loading GIF files. + - Many improvements to the MATLAB binding API + - Automatically link to MATLAB's Intel MKL when used on linux. + - struct support + - mex functions can have up to 20 arguments instead of 10. + - In place operation. Made column major matrices directly wrap MATLAB + matrix objects when used inside mex files. This way, if you use + matrix_colmajor or fmatrix_colmajor in a mex file it will not do any + unnecessary copying or transposing. + - Catch ctrl+c presses in MATLAB console. Allowing early termination of mex functions. + - When used inside mex files, DLIB_ASSERTS won't kill the MATLAB process, + just throw an exception. + - Made cerr print in MATLAB as a red warning message. + - load_mnist_dataset() + - Added a constructor for seeding rand with a time_t. + - Added subm_clipped() + - Added unserialize. + - Added running_gradient + +Non-Backwards Compatible Changes: + - Everything in dlib/matlab/call_matlab.h is now in the dlib namespace. + - DLIB_TEST() and DLIB_TEST_MSG() macros now require you to terminate them with a ; + +Bug fixes: + - Fixed bug in 10 argument version of call_matlab() and also cleaned up a few + minor things. + - setup.py and CMake scripts work in a few more contexts. + - Fixed compiler errors in visual studio 2015. + - Fixed a bug in gaussian_blur() that caused messed up outputs when big + sigma values were used on some pixel types. + - Fixed minor bugs in join_rows() and join_cols(). They didn't work when one + of the matrices was empty. + +Other: + - Made CMake scripts uniformly require CMake version 2.8.4. + - Faster fHOG feature extraction / face detection + - CMake scripts now enable C++11 by default + - Gave array2d and matrix move constructors and move assignment operators. Matrix + can also now be created from initializer lists. + + + + + +New Features: + - Added the set_ptrm() routine for assigning dlib::matrix objects to arbitrary + memory blocks. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug that caused cmake to not provide the correct preprocessor + definitions until cmake was run twice. This was causing some projects to + not build properly. + +Other: + - Improvements to build system: + - Ehsan Azarnasab contributed a setup.py so the dlib Python API can be + installed via the usual 'python setup.py install' command. + - Séverin Lemaignan upgraded dlib's CMake scripts so they include an + install target. Now dlib can be installed system wide by executing + 'cmake PATH_TO_DLIB; make install'. This also includes installing the + appropriate scripts for CMake's find_package(dlib) to work. + + + + + +New Features: + - More clustering tools: + - Added bottom_up_cluster() and find_clusters_using_angular_kmeans() + routines. + - Added a --cluster option to the imglab tool. This lets you cluster + objects into groups of similar appearance/pose. + - Improved the shape_predictor. In particular, it can now be learned from + datasets where some landmarks are missing. The shape_predictor also now + outputs a sparse feature vector that encodes which leafs are used on each + tree to make a prediction. + +Non-Backwards Compatible Changes: + - extract_highdim_face_lbp_descriptors() produces slightly different output. + +Bug fixes: + - Fixed a minor bug in extract_highdim_face_lbp_descriptors() which was + pointed out by Yan Xu. One of the face locations was mistakenly used twice + while another was skipped. This change breaks backwards compatibility with + the previous feature extraction output but should slightly improve + accuracy of classifiers trained using these features. + - Fixed jet() and heatmap() so they work on empty images. + - The SQLite transaction object did not function correctly when compiled + in a C++11 program. Since its destructor can throw, an exception + specification needed to be added indicating that this was possible since + destructors are now noexcept by default in C++11. + - Fixed a bug pointed out by Ernesto Tapia that could cause matrix + expressions that involve sub matrix views (e.g. colm) to produce the wrong + results when the BLAS bindings were enabled. + - Added an if to avoid a possible division by zero inside spectral_cluster(). + - Fixed a bug in parse_xml(). It failed to check if the given input stream + was valid before trying to parse it. + +Other: + + + + + + +New Features: + - Added a linear model predictive control solver. See the mpc_ex.cpp example + program for details. + - Thanks to Patrick Snape, the correlation_tracker can now be used from Python. + +Non-Backwards Compatible Changes: + - The camera_transform's second operator() method now takes 3 arguments + instead of 2. This is to allow it to output the z distance in addition to + scale. + +Bug fixes: + - Fixed a bug in the eigenvalue_decomposition which could occur when a + symmetric matrix was used along with the LAPACK bindings. + - Fixed a bug where the last column of data in a file wasn't loaded on some + OS X machines when load_libsvm_formatted_data() was called. + +Other: + - Added a hard iteration limit to a number of the SVM solvers. + - Adrian Rosebrock graciously setup an OS X machine for dlib testing, which + resulted in improved CMake python scripts on OS X machines. + - Improved the way overlapping points are rendered by the perspective_window. + + + + + + +New Features: + - Added a number of tools for working with 3D data: + - Added the perspective_window which is a tool for displaying 3D point clouds. + - Added camera_transform. It performs the 3D to 2D mapping needed to visualize 3D + data. + - Added point_transform_affine3d as well as functions for creating such transforms: + rotate_around_x(), rotate_around_y(), rotate_around_z(), and translate_point(). + - Added draw_solid_circle() for drawing on images. + - Added get_best_hough_point() to the hough_transform. + - Thanks to Jack Culpepper, the python API for object detection now outputs detection + confidences. + - Added lspi, an implementation of the least-squares policy iteration algorithm. + +Non-Backwards Compatible Changes: + - The shape_predictor and shape_predictor_trainer had a non-optimal behavior when used + with objects that have non-square bounding boxes. This has been fixed but will cause + models that were trained with the previous version of dlib to not work as accurately if + they used non-square boxes. So you might have to retrain your models when updating dlib. + +Bug fixes: + - Fixed a bug which prevented add_image_rotations() from compiling. + +Other: + - The imglab tool now allows the user to click and drag annotations around by holding + shift and right clicking. + + + + + + +New Features: + - Added spectral_cluster() + - Added sub_image() and sub_image_proxy + - Added set_all_logging_headers() + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug that caused the correlation_tracker to erroneously trigger an assert when + run in debug mode. + +Other: + - Improved the usability of the new drectanle object. + - Optimized extract_fhog_features() for the case where cell_size==1. This makes it about + 4x faster in that case. + - Made it so you can compose point transform objects via operator *. + + + + + + +New Features: + - Added the correlation_tracker object + - Added the option to force the last weight to 1 to structural_assignment_trainer. + - Added max_point_interpolated() + - Added the drectangle object + - New Python Tools: + - Patrick Snape contributed a Python binding for the face landmarking tool and + the general purpose shape prediction/training tools. + - Vinh Khuc contributed a Python binding for find_candidate_object_locations(), + dlib's implementation of the selective search object location proposal method. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug in extract_image_chips() and get_mapping_to_chip() that caused + incorrect outputs when the requested chip stretched the image unevenly + vertically or horizontally. + - Made CMake check that libpng and libjpeg actually contain the link symbols + they are supposed to since, on some systems, these libraries aren't + installed correctly and will cause linker errors if used. + - Fixed assign_border_pixels(img, rect) so that it correctly zeros an image + when an empty rectangle is supplied. Previously, it did nothing to the + image in this case. + - Fixed compute_lda_transform() so it works properly when the class + covariance matrices are singular even after performing PCA. + - Fixed a bug in find_similarity_transform(). When given just two points as + inputs it would sometimes produce a reflection rather than a similarity + transform. + - Disabled all bindings to FFTW because FFTW isn't threadsafe. + +Other: + - Added an example program for dlib's SQLite API and made a few minor + usability improvements to the API as well. + + + + + + +New Features: + - Upgraded fft() and ifft() to support 2D matrices. + - Added hough_transform + - Added skeleton() for finding the skeletonization of a binary image. + - Added distance_to_line(), clip_line_to_rectangle(), min_point(), and max_point(). + - Added a simple API for calling C++ from MATLAB. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a compile time error that could happen when calling fft() for + certain input types. + - Fixed a compile time error that prevented auto_threshold_image() from + being used. + - Fixed name clashes with new version of Boost. + - Changed Python pickling code so it works with Python 3. + - Fixed CMake compile time error related to finding fftw. + +Other: + - Made extract_image_chips() much faster when extracting unscaled image chips. + + + + + +New Features: + - Added save_jpeg() + - Added the option to use an identity matrix prior to vector_normalizer_frobmetric. + - Made the extract_image_chips() routine more flexible, in particular: Added + get_mapping_to_chip(), get_face_chip_details(), map_det_to_chip(), and also + upgraded chip_details so you can specify a chip extraction by a bunch of + point correspondences between the chip and the original image. + - Added a set of local-binary-pattern based feature extractors: + make_uniform_lbp_image(), extract_histogram_descriptors(), + extract_uniform_lbp_descriptors(), and extract_highdim_face_lbp_descriptors() + - Added compute_lda_transform() + - Added equal_error_rate() + - Added cast_to() to the type_safe_union. This allows you to get the + contents of a const type_safe_union. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Changed noncopyable.h to avoid a name clash with boost 1.56 + - On some platforms hostname_to_ip() would erroneously return 0.0.0.0. This + has been fixed. + +Other: + + + + + +New Features: + - Added find_similarity_transform() + - Added the ability to upgrade a auto_mutex_readonly from a readonly lock to a write + lock. + - Added an implementation of the paper "One Millisecond Face Alignment with an Ensemble + of Regression Trees" by Vahid Kazemi and Josephine Sullivan which appeared in this + year's CVPR conference. Therefore, dlib now includes tools for learning shape models + and also comes with a state-of-the-art face landmark locator. See the + face_landmark_detection_ex.cpp and train_shape_predictor_ex.cpp example programs for + an introduction. + +Non-Backwards Compatible Changes: + - Made the interface to all the image processing routines more generic. In particular, + it is now easier to use arbitrary image types with dlib. The new generic image + interface is defined in dlib/image_processing/generic_image.h and simply consists of + seven user defined global functions and a traits template. Any user code that was + using array2d objects to represent images will still work. However, if you had been + using your own custom image object you will need to provide implementations of the + seven functions. Instructions for how to do this are in + dlib/image_processing/generic_image.h. + +Bug fixes: + - Changed the murmur hash implementation to avoid any possibility of strict aliasing + violations in user code, even when things get inlined in unfavorable ways. + - Fixed a color space handling bug in resize_image() that caused bad looking outputs in + some cases. + - If "cmake" was a substring of the full path to your source code folder then the cmake + scripts would fail. This has been fixed. + - Fixed a compile time error that could occur when using find_max_single_variable(). + +Other: + - load_image() now uses the internal file header information to detect the + image format rather than looking at the file extension. + - Renamed unit test program to dtest avoid warnings from CMake. + - cross_validate_trainer() and cross_validate_trainer_threaded() no loner make copies + of the training data. This significantly reduces their RAM usage for large datasets. + - Changed the serialization code for C-strings so that they don't save the null + terminator byte. This makes their serialization format the same as the format for + std::string. The code should still be able to read all previously serialized data + correctly, so the change is backwards compatible with previous versions of dlib. + - Changed the evaluate_detectors() routine so that it applies non-max suppression to + each detector individually. This way one detector doesn't stomp on the output of + another detector. + - Made the version of draw_line() that draws onto a regular image use alpha blending + for drawing diagonal lines. + + + + + +New Features: + +Non-Backwards Compatible Changes: + +Bug fixes: + - The new simplified serialization API that works like serialize("filename")<<object + was not opening files in binary mode and therefore didn't work properly on Windows. + This has been fixed. + +Other: + + + + + + +New Features: + - Added the ability to set a previously trained function as a prior to the + svm_multiclass_linear_trainer, svm_c_linear_trainer, and svm_rank_trainer + objects. + - Added a user settable loss to the structural_assignment_trainer and + structural_track_association_trainer objects. + - Added evaluate_detectors(), a function for efficiently running multiple fHOG + based object detectors. + - Added the new split_on_first() and split_on_last() string manipulation functions. + - Added locally_change_current_dir, a RAII tool for switching between directories. + - You can now make a 1x1 matrix containing a single value by calling mat() on a single + scalar value. + - The point transform functions and frobmetric_training_sample are now serializable. + - Added a simplified operator << and >> based syntax for serializing to and + from files. So now you can serialize to a file using a syntax of: + serialize("myfile.dat") << myobject << another_object; + and then load those objects from disk via: + deserialize("myfile.dat") >> myobject >> another_object; + An arbitrary number of objects can be serialized or deserialized by + chaining the << and >> operators. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug pointed out by Daniel Girardeau-Montaut. The covariance() + function didn't work on non-double valued matrices. + - Fixed a bug in the backtracking_line_search() function pointed out by + Ping-Chang Shih. The function ignored the max_iter parameter. + - Fixed a compiler error encountered when using clang 3.4 on Mac OS X 10.9. + Thanks to Martin Fergie for reporting this problem. + - Fixed a potential divide by zero in draw_fhog() + +Other: + - Added an example program showing how to set a custom logger output hook. + - Made linear decision_functions which use sparse vectors much faster. + + + + + + +New Features: + - Added a Python API for working with fHOG based object detectors. See the + new python example programs train_object_detector.py and face_detector.py for + more details. + - Added the ability to use a user supplied fHOG style feature extractor with + the scan_fhog_pyramid object. So now you can define your own version of HOG + for use with these tools. + - The oca solver now supports taking a user supplied prior vector. That is, + it lets you use a regularizer like ||w-prior||^2 instead of the usual + ||w||^2 regularizer. + - Added the structural_track_association_trainer object. It is a structural + SVM tool for creating multi-target tracking algorithms. See the + learning_to_track_ex.cpp example program for an introduction. + - Added the following minor utility functions: nearest_center(), + add_image_rotations(), set_aspect_ratio(), and tile_images(). + +Non-Backwards Compatible Changes: + - Refactored the load_image_dataset() routines so they are easier to use and + more flexible. This introduces a slight backwards incompatibility in that + the version that loads full_object_detection objects now returns an ignore + rectangle set instead of a parts name list. Other than that the changes + are backwards compatible with previous versions of dlib. + - Added a bias term to the assignment_function's model so the user doesn't + need to remember, or even understand, that they should add it themselves. + However, this change breaks backwards compatibility with the previous + serialization format for assignment_function objects. + +Bug fixes: + - Fixed a number of compile time errors that could occur in rare cases. + - The stopping condition for the svr_linear_trainer was too tight, causing it + to take an excessive amount of time to converge in some cases. + - Disabled use of XIM for X11 windowing since it makes programs hang on some + systems. However, this means the wide character input methods won't work on + X11 systems anymore. + - Fixed a bug in randomize_samples() which caused the outputs to be not as + random as they should be. + - Fixed dlib's CMakeLists.txt file so that the "use FFTW" option actually + causes the build to use FFTW. + - Fixed a compile time error that triggered when trying to link with FFTW. + - mat() did not work correctly when used with std::vector<bool> objects. + This has been fixed. + +Other: + + + + + + +New Features: + - Object Detection Tools: + - Added scan_fhog_pyramid, a tool for creating Histogram of Oriented Gradient (HOG) + based object detectors. + - Added get_frontal_face_detector(), a HOG based face detector. + - Added an option to include "ignore/don't care" truth boxes to the + structural_object_detection_trainer. This allows a user to tell the trainer that + they don't care if certain objects are detected or not detected. + - Image Processing Tools: + - Added extract_image_chips() + - Added a version of draw_rectangle() for drawing on images. + - The spatial filtering routines now support even sized filters. + - Added flip_image_dataset_left_right(), upsample_image_dataset(), and + rotate_image_dataset(). + - Machine Learning Tools: + - Added a nuclear norm regularization option to the structural SVM solver. + - Added the option to learn only non-negative weights to the + svm_multiclass_linear_trainer. + - Speed Improvements: + - The svm_multiclass_linear_trainer, one_vs_one_trainer, and one_vs_all_trainer + objects are now multithreaded. This also means you have to #include + dlib/svm_threaded.h instead of dlib/svm.h to use these tools. + - A number of image processing tools can now optionally use SSE and AVX instructions + and are therefore considerably faster. In particular, the following tools have been + accelerated: extract_fhog_features, resize_image, pyramid_down, pyramid_up, + spatially_filter_image_separable, and spatially_filter_image. + - Added an inv() routine that inverts point transformation functions. + - Added a sign() routine for matrix objects. + +Non-Backwards Compatible Changes: + - The spatial image filtering functions have the following changes: + - They no longer zero the image borders when you set the add_to parameter to true. + - The spatially_filter_image_separable_down() routine now only allows grayscale + output images. + - Changed the default parameters of the test_box_overlap object. Now it defaults to + using exactly the PASCAL VOC match criterion. + - To use the svm_multiclass_linear_trainer, one_vs_one_trainer, or one_vs_all_trainer + objects you now have to #include dlib/svm_threaded.h instead of dlib/svm.h. + - pyramid_up() no longer has a levels option. + +Bug fixes: + - Fixed a compile time bug that could occur when wide character strings were + serialized. + - Fixed a compile time bug occurring in gcc 4.7.1 on SUSE Linux. Thanks to Volker + Härtel for finding this. + - Fixed compile time errors that occurred when using gcc on cygwin. + - Fixed a compile time bug that could occur when serializing mlp objects. + - Fixed a bug in the bigint object that caused division to sometimes produce incorrect + results. + - Fixed a bug which sometimes caused load_image_dataset() to erroneously report that + the dataset file could not be found. + - Fixed a bug in the structural_object_detection_trainer that caused it to erroneously + throw a impossible_labeling_error exception in certain rare cases. + - Updated find_max_factor_graph_nmplp() to use the improved version of the algorithm + from the 2011 paper Introduction to dual decomposition for inference by David Sontag, + Amir Globerson, and Tommi Jaakkola. The original algorithm presented in their 2008 + paper had an error that negatively affected its convergence. Thanks to James Gunning + for pointing this out. + +Other: + - Fixed many compiler warnings in gcc 4.8. + - Made many of the mat() converters bind the resulting matrix expressions into BLAS + functions. + - libpng and libjpeg are now included in the dlib/external folder to enable easy static + linking to these libraries on platforms that typically don't have them (e.g. Windows). + Moreover, dlib's cmake files will automatically perform this static linking when no + copy of these libraries is found on the system. + + + + + +New Features: + - Added routines for performing BFGS and L-BFGS optimization with box constraints. + See the new find_min_box_constrained() and find_max_box_constrained() routines. + - Added vector_normalizer_frobmetric. This is a tool for learning a + Mahalanobis distance metric. + - The user can now set different loss values for false alarming vs. getting a + correct detection when using the structural_sequence_segmentation_trainer. + - Added an overload of clamp() that lets you use matrix valued lower/upper bounds. + - New image processing tools: + - Added the scan_image_custom object, split_array(), and add_image_left_right_flips(). + - Added extract_fhog_features(), this is a function for computing + Felzenszwalb's 31 channel HOG image representation. + +Non-Backwards Compatible Changes: + - Refactored the image pyramid code. Now there is just one templated object called + pyramid_down and you give it the downsampling amount as a template argument. To make + old code work with this change use the following substitutions: + change pyramid_down to pyramid_down<2> + change pyramid_down_3_2 to pyramid_down<3> + change pyramid_down_4_3 to pyramid_down<4> + change pyramid_down_5_4 to pyramid_down<5> + +Bug fixes: + +Other: + - Made the structural SVM solver slightly faster. + - Moved the python C++ utility headers from tools/python/src into dlib/python. + - The PNG loader is now able to load grayscale images with an alpha channel. + - Removed checks that prevented users from using references to functions with the + optimization code and forced the use of function pointers. This was to avoid + triggering a bug in gcc 4.0. Since that compiler is no longer officially supported + by dlib I've removed these checks to increase usability. + - Made resize_image() use bilinear interpolation by default and also added a special + version of it that is optimized for this case. + - Dlib's cmake files will now automatically link to the Intel MKL on MS Windows + platforms if the MKL is installed. + + + + + +New Features: + - Added Python interfaces to dlib's structural support vector machine solver and + Hungarian algorithm implementation. + - Added running_cross_covariance + - Added order_by_descending_distance() + - Added is_finite() + - Added the csv IO manipulator that lets you print a matrix in comma separated value + format. + +Non-Backwards Compatible Changes: + - Changed the object detector testing functions to output average precision instead of + mean average precision. + - Added an option to weight the features from a hashed_feature_image relative to the + number of times they occur in an image. I also made it the default behavior to use + this relative weighting and changed the serialization format to accommodate this. + +Bug fixes: + - Fixed typo in learn_platt_scaling(). The method wasn't using the exact prior + suggested by Platt's paper. + - Fixed a bug in running_scalar_covariance that caused the covariance() and + correlation() methods to output the wrong answer if the covariance was negative. + +Other: + - Gave the image_window the ability to tie the mouse and keyboard events together such + that it is possible for a user to listen for both simultaneously. + - A number of changes were made to the structural_svm_problem's code which make it + significantly faster in some cases. + - Added Steven Van Ingelgem's patch to the HTTP server which makes operations on HTTP + headers case-insensitive. + + + + + + +New Features: + - Machine Learning: + - Added the svr_linear_trainer, a tool for solving large scale support vector + regression problems. + - Added a tool for working with BIO and BILOU style sequence taggers/segmenters. + This is the new sequence_segmenter object and its associated + structural_sequence_segmentation_trainer object. + - Added a python interface to some of the machine learning tools. These + include the svm_c_trainer, svm_c_linear_trainer, svm_rank_trainer, and + structural_sequence_segmentation_trainer objects as well as the cca() + routine. + - Added point_transform_projective and find_projective_transform(). + - Added a function for numerically integrating arbitrary functions, this is the + new integrate_function_adapt_simpson() routine which was contributed by + Steve Taylor + - Added jet(), a routine for coloring images with the jet color scheme. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug in hysteresis_threshold() that caused it to produce incorrect + outputs in some cases. + - Fixed a segmentation fault in the eigenvalue_decomposition object which + could occur when NaN valued inputs were given. + +Other: + - Made image saving routines work on matrix objects in addition to array2d objects. + - The machine learning page now contains a flow chart to help new users + select a machine learning tool appropriate for their task. + + + + + + + +New Features: + - Object Detection Tools: + - Added another image scanning tool similar to scan_image_pyramid. This + is the new scan_image_boxes object. It allows a user to easily specify + an arbitrary set of object boxes which should be checked by an object + detector rather than scanning a fixed sized window over the image as is + done by the scan_image_pyramid tool. This allows more flexible scanning + strategies. For example, it is now possible to use the selective search + method implemented by the new find_candidate_object_locations() routine. + - Added the binned_vector_feature_image. + - Upgraded the object_detector so that you can use the adjust_threshold + argument for all versions of the operator() method. + - Added remove_unobtainable_rectangles() + - Hashing Tools: + - Added a set of new locality sensitive hashing functions meant for use + with larger vectors and higher bit sizes than the current LSH tools. + These are the new hash_similar_angles_xxx objects. + - Added find_k_nearest_neighbors_lsh() and hash_samples() + - Added create_max_margin_projection_hash() + - New Matrix Routines: linpiece(), fft(), and ifft() + - Added numeric constants and additional statistics to the running_stats + object. This code was contributed by Steve Taylor. + - Added the image_window::get_next_keypress() routine. This tool allows a + user to easily find out which keyboard key the user is pressing. + +Non-Backwards Compatible Changes: + - Changed the object_detector interface slightly. In particular, it no + longer adds the adjust_threshold argument to the output scores. + - The object detector testing functions now output mean average precision in + addition to precision and recall. + - Changed how dlib does serialization in a number of ways: + - The running_stats and scan_image_pyramid objects have had their + serialization format changed in a way that breaks backwards + compatibility. This means serialized versions of these objects can't be + loaded by the new version of dlib. + - Changed the format dlib uses when it serializes floating point values. + Previously, we used an ASCII based format. Dlib now uses a much more + efficient binary format. However, the deserialization routines have + been made backwards compatible with the previous format. So dlib can + still deserialize older data but older software won't be able to read + the new format. + - Changed the serialization formats for the matrix and array2d objects so + that either object can be deserialized into the other. This was done in a + way that is backwards compatible with previous versions of dlib. That is, + we can still load data serialized by previous dlib versions. However, + older versions of dlib can't load the new serialization format. + +Bug fixes: + - Fixed a bug in save_dng() that happened sometimes when saving images with + unsigned char pixels. + - The test_ranking_function() and cross_validate_ranking_trainer() routines + computed incorrect MAP scores when the learned function output a constant + value for all samples. This has been fixed. + +Other: + - Gave load_image_dataset() the ability to skip images that don't have any + ground truth boxes. + - Changed average_precision() to use interpolated precision. So now it uses + the same metric as the one used by the Pascal VOC challenge. + - Upgraded the dng image format so it can natively store floating point + pixel types without any information loss. + + + + + + +New Features: + - Added svd_fast(), a routine for computing a singular value decomposition of very + large matrices. + - Added cca(), a routine for doing canonical correlation analysis on very large + and high-dimensional datasets. + - Added tools for creating parallel for loops, see parallel_for(). + - Added some features to the image display widgets to let the user easily + get information about where the user is clicking. This is the new + get_next_double_click() routine. + - Added an operator>> for matrix objects which allows you to read in ASCII + matrices using the format used by operator<<. + - Added serialization support for std::vector<bool>. + - Added the following new minor objects and routines: average_precision(), + make_sparse_vector_inplace(), orthogonalize(), count_bits(), draw_surf_points(), + hamming_distance(), cosine_distance, and negative_dot_product_distance. + +Non-Backwards Compatible Changes: + - Changed ranking evaluation functions to return the mean average precision + in addition to just raw ranking accuracy. This changes their return types + from double to matrix<double,1,2>. + - Generalized segment_image() so it works on any pixel type or array of + vectors. I also changed its interface slightly. In particular, I removed + the min_diff parameter and replaced it with an explicit min_size parameter. + - Changed how the SURF descriptor is computed slightly to improve its + accuracy. The interface to the user has not been changed, however, the + number and position of detected SURF points might be different than in + previous dlib versions. + +Bug fixes: + - Fixed an endianness bug in the PNG I/O functions which occurred when 16bit + grayscale PNGs were used. + - Fixed a bug which could potentially occur when empty std::vector<char> or + std::vector<unsigned char> were serialized. + - There was a bug in the version of draw_line() that draws directly onto an + array2d type image (not the one that draws onto a GUI canvas object). The + bug triggered whenever a perfectly horizontal or vertical line that extended + outside the image was drawn. This has been fixed. + - Fixed a bug in the Windows implementation of the signaler object, which + was found by Isaac Peterson. The bug caused the program to deadlock if + signal() or broadcast() was called at exactly the same time a + wait_or_timeout() function timed out. + - Fixed a bug in the image_window and image_display GUI tools which caused + them to not redraw overlay lines correctly in certain cases involving + non-default zoom levels. + - Switched randomly_color_image() to use the non-pointer based version of + murmur_hash3() to avoid violation of the strict aliasing rule. In + particular, the previous version didn't work correctly in gcc 4.7.2 when + optimizations were enabled. + - Visual Studio 2012's iostreams library has a bug which caused the + iosockstream to crash on use. This version of dlib has been changed to + avoid triggering this bug. + +Other: + - Refactored the Platt scaling code a little. Now there is a function, + learn_platt_scaling(), that allows you to directly call the Platt scaling + code without supplying a trainer object. + - Optimized the oca and structural SVM solvers. They are now a little bit faster + than in previous dlib releases. + + + + + + +New Features: + - Machine Learning + - Added svm_rank_trainer, an optimized implementation of the SVM-Rank algorithm. + - Added rank_unlabeled_training_samples(), an implementation of the SVM Active + Learning algorithm. + - Added svm_c_linear_dcd_trainer, a warm-startable SVM solver using the dual + coordinate descent algorithm used by liblinear. + - Added the ability to force the last element of a weight vector to 1 to the + following objects: svm_c_linear_trainer, svm_c_linear_dcd_trainer, + svm_rank_trainer, and oca. + - Added the ability to learn non-negative weight vectors to the + structural_sequence_labeling_trainer object. + - Networking + - Added an iosockstream object. + - Added a method to the server object that lets a user set the graceful close timeout + time to something other than the default of 500ms. + - Linear Algebra + - Added the gaussian_randm() function. + - Added the find_affine_transform() function. + - Added the mat() function. It combines the array_to_matrix(), vector_to_matrix(), + pointer_to_column_vector(), and pointer_to_matrix() methods all into one convenient + interface. mat() also works for Armadillo and Eigen matrices. + - Added STL style begin() and end() methods to matrix and matrix_exp. + - Added an overload of sparse_matrix_vector_multiply() that multiplies a dense matrix + with a sparse vector. + - Made toMat() work with the matrix object in addition to array2d style images. + - Graphical User Interface Tools + - Added draw_solid_convex_polygon(). + - Added an overload of draw_image() that's useful for drawing images and doing + interpolation at the same time. + - Added the on_view_changed() callback to zoomable_region and scrollable_region widgets. + - Added parse_trees_to_string() and parse_trees_to_string_tagged(). + - Added lambda function support to the timeout object. + - Added the vectorstream object. + - Added the parse_xml() routines. + - Added a group name feature to the command line parser. Now it is possible to make + print_options() print related options in named groups. + - Added the following new hashing functions: murmur_hash3_128bit_3(), + murmur_hash3_2(), murmur_hash3_3(), uniform_random_hash(), gaussian_random_hash() + as well as hash() overloads for uint32, uint64, and std::pair. + +Non-Backwards Compatible Changes: + - Made the svm_c_linear_trainer use the risk gap to decide when to stop. This was done + because it is how all the other OCA based SVM tools in dlib decide when to stop. + However, it might cause the outputs to be slightly different in this version of dlib. + - It is now illegal to call unlock() on a mutex when the mutex is not owned by the + calling thread. The most likely reason for doing this was to unlock early in an area + locked by an auto_mutex. Old code that does this can be fixed by calling auto_mutex's + unlock() function instead. + - Removed the structural_assignment_trainer::learns_nonnegative_weights() routine + and moved its functionality into the feature extraction interface used by this object. + +Bug fixes: + - Fixed a bug in find_max_factor_graph_nmplp() which caused it to not work properly on + some compilers. + - Fixed a bug pointed out by Joel Nelson in the version of md5() that took an istream. + The bug caused the function to crash on strings longer than 56 characters. + +Other: + - dlib now has an excellent new logo thanks to Yasser Asmi. + - Added a new documentation page for the various linear algebra tools. + - The following objects were turned into single implementation components: + sockstreambuf, timeout, member_function_pointer, xml_parser, linker, + bound_function_pointer, and timer. + + + + + + +New Features: + - Machine Learning + - Added the ability to learn non-negative weight vectors to the + structural_assignment_trainer object. + - Added two new graph clustering algorithms: Chinese Whispers and Newman's modularity + clustering. + - Added a number of new graph manipulation tools: sparse_matrix_vector_multiply(), + is_ordered_by_index(), find_neighbor_ranges(), convert_unordered_to_ordered(), + remove_duplicate_edges(), and the ordered_sample_pair object. + - Networking + - Added a set of tools for creating applications using the Bulk Synchronous Parallel + computing model. See the new bsp_ex.cpp example program for an introduction. + - Added a routine that lets a user disable Nagle's algorithm on a TCP connection. + - Added an asynchronous start routine to the server object. This is the new + start_async() method. + - Added the network_address object. + - Added connect_to() to the bridge interface. + - Added find_max_parse_cky(), a method implementing the well known CKY algorithm for + parsing probabilistic context free grammars. + - Added the ability to label parts of objects with the mouse to the image_display + widget. + - Added the ability to put overlay circles and full_object_detections into the + image_window widget. + - Added a stddev() for matrix objects. + - Added operator+() for running_stats and running_scalar_covariance. + - Added an overload of murmur_hash3_128bit() that takes 4 integers instead of a block of + memory. + - Added rand::get_random_64bit_number(). + +Non-Backwards Compatible Changes: + - Changed the image_dataset_metadata XML reading tools to use a map of strings to points + to represent object parts. This change removes the old head point from a box since + this information can now be represented in the parts map. + - The syntax for passing order_by_distance and order_by_index to std::sort() is now + slightly different since these functions are now templates. However, this change + allows them to work on any kind of sample_pair or ordered_sample_pair object. + - The default distance value of a sample_pair is now initialized to 1 instead of + infinity. + +Bug fixes: + - Added a patch, contributed by Martin Müllenhaupt, to fix a minor bug in the SQLite + bindings. + - Fixed a typo which would prevent code that called running_stats::max_n() from + compiling. + +Other: + - Added a new documentation page for the various graph tools in dlib. + - Added support for Visual Studio 2012. + - Switched the sample_pair object to use double to store its distance value instead of + float. + - Added William Sobel's patch to the web server that improves its flexibility and + security. + - Changed the server object so you don't have to use the server::kernel_1a syntax to + declare it anymore. Now you just say server, server_iostream, or server_http + depending on which one you want. + - Changed the cmd_line_parser so you don't have to use the ::kernel_1a syntax anymore. + Now it is declared like a normal single implementation object. + - Set the default max number of connections a server will accept at a time to 1000 + rather than the previous default of infinity. + + + + + + +New Features: + - Added more overloads of find_max_factor_graph_potts() to make applying it + over a Markov random field of image pixels really simple. + - Added overloads of serialize()/deserialize() so that they can serialize + Google protocol buffer objects. + - Image Processing: + - Added find_points_above_thresh() + - Added max_filter() + - Added scan_image_movable_parts() + - Added sum_filter_assign() + - Added the full_object_detection object. + - Added the ability to model objects with movable parts into the + scan_image_pyramid object. This update also includes all the needed tools + to train movable part models using the structural_object_detection_trainer. + - Machine Learning: + - Added a per node loss option to the structural_svm_graph_labeling_problem's + interface. + - Added Emanuele Cesena's implementation of Sammon's nonlinear dimensionality + reduction method. + +Non-Backwards Compatible Changes: + - To support movable part models, the serialization format of scan_image_pyramid + objects was modified. This breaks backwards compatibility with the previous + format for scan_image_pyramid objects as well as object_detector instances + that use the scan_image_pyramid. + +Bug fixes: + - Fixed a bug in auto_threshold_image() that caused it to give bad outputs + when used with very large images. + +Other: + - Updated find_max_factor_graph_potts() to correctly say you can use infinite + weights for the factor_value_disagreement() values since the code actually + supports this. + - Made integer serialization about 3 times faster. + + + + + + +New Features: + - Improvements to linear algebra tools: + - Added the lowerbound() and upperbound() routines for thresholding dense + matrices. + - Refined the tools for working with sparse vectors. In particular, + the following functions were added: min(), max(), make_sparse_vector(), + add(), and subtract(). A number of existing routines were also updated + to work with both sparse and dense vectors so that templated code which + works on both vector types is simpler to write. + - Added the += and -= operators to the set_subm(), set_rowm(), and set_colm() + tools for operating on submatrices. + - Optimization: + - Added a new quadratic program solver, solve_qp4_using_smo(). This new + solver is useful for solving quadratic programs corresponding to + non-negative constrained primal quadratic programs. + - Added an optional non-negativity constraint to the oca optimizer. + - Added the min_cut object. It provides a method to find the minimum weight + cut on a graph. + - Added tools for finding the maximum probability assignment in a Potts + style Markov random field. See the find_max_factor_graph_potts() routine + for details. + - Machine Learning: + - Added structural SVM tools for learning the parameters of a Potts style + Markov random field. See the structural_graph_labeling_trainer and + graph_labeler objects as well as their associated example program for + details. + - Added the ability to learn only non-negative weights to the + svm_c_linear_trainer. + - Improved Integration with OpenCV: + - Updated the cv_image object so it works with cv::Mat as well as IplImage. + - Added the toMat() routine for converting from a dlib style image to an + OpenCV cv::Mat image. + +Non-Backwards Compatible Changes: + - Removed the dlib::sparse_vector namespace. Everything from this namespace + was moved into the normal dlib:: namespace so that code which works with + both sparse and dense vectors is more cohesive. + +Bug fixes: + - Fixed a bug in find_max_factor_graph_viterbi() which sometimes occurred when + the model order was larger than the number of variables. + - Fixed a bug which caused a compiler error if you tried to call dot() on two + 1x1 matrices which were statically dimensioned. + +Other: + - Improved existing documentation: added pictures of the gui widgets, + added documentation of the dlib::bridge protocol, and other minor + usability improvements. + + + + + +New Features: + - Image Processing: + - Added the option to make the features generated by poly_image rotationally + invariant. + - Added a set of routines for warping, scaling, and resizing images. + See the new "Scaling and Rotating" section of the image processing + documentation for details. + - Added the heatmap() routine for converting an image into a heatmap. + - Machine Learning + - Updated the sequence labeling trainer to allow the user to set different + loss values for different labels. + - Added the rls object. It is an implementation of the linear recursive + least squares algorithm. + - Added the get_option() routines which slightly simplify option parsing + from the command line and config files. + - Added the 128bit version of Murmur hash. + - Added the kalman_filter and rls_filter objects. These are tools for + performing Kalman filtering and recursive least squares filtering. + - Added the circular_buffer object. + +Non-Backwards Compatible Changes: + - The poly_image generates slightly different features in this new release. + Therefore, classifiers trained using the previous version will need to be + retrained if they are switched to the new version of poly_image. + - Changed the xcorr() functions so they take the complex conjugate of the right + hand arguments if they are complex numbers. This way they do a proper + cross-correlation and also mirror the behavior of MATLAB. However, this + breaks backwards compatibility with the previous behavior of xcorr(). + - Previously, dlib included two versions of dlib::array. However, to + simplify the user interface, dlib now includes only the contiguous + memory implementation of dlib::array. This change should only affect + you if you wrote code which assumed dlib::array::set_max_size() only + allocated a small amount of RAM. The new behavior is similar to the + std::vector::reserve() routine. That is, dlib::array::set_max_size() + will allocate the requested amount of memory immediately. + +Bug fixes: + - Fixed a bug which caused certain matrix expressions to not compile + when the BLAS bindings were enabled. In particular, expressions which + involved a 1x1 matrix sometimes didn't compile. + +Other: + - Made the matrix routines min(), max(), sum() and mean() work with + complex numbers. + - Turned the array object into a single implementation object. Now arrays + can be created using the normal array<type> obj; syntax. Additionally, + all extensions were merged into the array object. + - Added an example program which better documents how to create training + data for the object detection tools as well as how this data can be used. + See the train_object_detector.cpp example for details. + + + + + + +New Features: + - Added tools for timing blocks of code + - Machine Learning + - Added a set of tools for learning to solve the assignment problem. + See the structural_assignment_trainer and its associated example + program for an introduction. + - Added random projection based locality sensitive hashing tools. + - Added tools to simplify the creation of scan_image_pyramid objects. + See the object_detector_ex.cpp example program for details. + - Image Processing + - Added sum_filter() and spatially_filter_image_separable_down() + - New feature extractors: poly_image, nearest_neighbor_feature_image, and + fine_hog_image + +Non-Backwards Compatible Changes: + - Changed the serialization format for rand objects. + - Changed the order of arguments for the sequence_labeler's constructor. + - Object Detection Changes + - Some parts of the object detection tools have been refactored. In particular, + the interfaces of the scan_image_pyramid and structural_object_detection_trainer + have been changed slightly to improve usability. + - Made the test_box_overlap a little more flexible. This change breaks + backwards compatibility with the previous version though. + - The hashed_feature_image object has been made more general. It now + uses a user supplied hashing function rather than its own hashing + implementation. + - Removed constness from the operator() member functions of the + object_detector. + - Fixed improper normalization in the gaussian() functions. The + normalization constant was being computed incorrectly. + - Sequence labeling feature extractors must now define a sequence_type + typedef instead of sample_type. This change allows the user to use any + type of sequence, not just std::vector objects. + +Bug fixes: + - Changed the add_probability() method of joint_probability_table so + it does a saturating add rather than a normal add. This ensures the + probability value stays exactly <= 1. Previously, floating point + rounding error could cause it to be slightly above 1 and would therefore + cause some asserts to misfire during debugging mode. + - The object_detector had code in it which limited the number of outputs + to 100 rectangles. This has been removed. + - Fixed improper normalization in the gaussian() functions. The + normalization constant was being computed incorrectly. + +Other: + - dlib::rand can now generate Gaussian random numbers. + - The structural_object_detection_trainer will now automatically setup + appropriate non-max suppression parameters if the user doesn't supply them. + - The structural_object_detection_trainer has been optimized and now runs + significantly faster than in previous dlib releases. + - The tools folder containing htmlify, imglab, and mltool is now included + in the dlib release archive files. Previously, these tools were only + available directly from source control. + + + + + + +New Features: + - Machine Learning + - Added the histogram intersection kernel for sparse and dense vectors. + - Added a set of tools to allow a user to easily learn to do sequence + labeling using dlib's structural SVM implementation. See the new + sequence_labeler object and its associated example program for an + introduction. + - Image processing: + - Added segment_image() + - Added randomly_color_image() + - Added the border_enumerator + - Added the disjoint_subsets object, it is an implementation of the + union-find algorithm/disjoint-set data structure. + - Added new matrix routines: conv(), conv_same(), conv_valid(), xcorr(), + xcorr_same(), xcorr_valid(), and flip(). + +Non-Backwards Compatible Changes: + - Changed find_max_factor_graph_viterbi() so you can use run-time + defined order and num_states parameters. + +Bug fixes: + - The last dlib release added a max_iterations parameter to the + svm_c_linear_trainer and svm_c_ekm_trainer objects. However, + there was a bug which made them only do at most 16 iterations, + which is too few to solve many problems. This has been fixed. + - Fixed a bug in is_const_type. It didn't work for reference types. + - Fixed a bug in the SQLite binding routine statement::get_column_as_text(). + It didn't work correctly if the column contained a NULL. + - Fixed a bug in find_max_factor_graph_viterbi() which occurred when a + zero order model had negative factor values. + +Other: + + + + + + +New Features: + - Two new routines for performing MAP inference in factor graphs: + - For chain-structured graphs: find_max_factor_graph_viterbi() + - For general graphs: find_max_factor_graph_nmplp() + - Image Processing + - Added more tools for creating image pyramids. See pyramid_down_5_4, + pyramid_down_4_3, and pyramid_down_3_2. + - Added more image filtering and morphology functions. + - Added a set of tools for creating sliding window classifiers: + - Added the scan_image() routine. It is a tool for sliding a set of + rectangles over an image space and finding the locations where the sum + of pixels in the rectangles exceeds a threshold. Also added + scan_image_pyramid, which is a tool for running scan_image() over an + image pyramid. + - Added the structural_object_detection_trainer. This is a tool which + formulates the sliding window classifier learning problem as an + instance of structural SVM learning. + - Added a variety of supporting tools and two object detection example + programs. + - Added the following functions for computing statistics on vectors: + mean_sign_agreement(), correlation(), covariance(), r_squared(), + and mean_squared_error() + - Added a C++ wrapper for SQLite (see the new database and statement objects) + +Non-Backwards Compatible Changes: + - Changed the interface to the ridge regression trainer objects so that they + report the entire set of leave-one-out prediction values rather than a + summary statistic like mean squared error. + - Changed the serialization routine for bgr_pixels to store the pixels in BGR + order rather than RGB. + - Changed the interface for the spatially_filter_image() routine to take the + filter as a matrix rather than C-array. Also, now it won't force signed pixel + values to 0 if they go negative. + - Changed the test_regression_function() and cross_validate_regression_trainer() + routines so they return both the MSE and R-squared values rather than just the + MSE. + - Changed suppress_non_maximum_edges() to use the L2 norm instead of L1 norm + for measuring the strength of an edge since this produces a slightly better + result. + +Bug fixes: + - The image_display didn't display overlay rectangles quite right. If you zoomed + in you could see that some of the pixels which are inside the rectangle were + outside the overlay. Specifically, the right column and bottom row was outside + the overlay rectangle. This has been fixed. Now all pixels which are supposed + to be part of a rectangle are drawn as being inside the overlay rectangle. + - Fixed a bug pointed out by Martin Müllenhaupt which caused the windows socket + code to not compile when used with the mingw-cross-env project. + - Fixed a bug in the png_loader. If you loaded an image with an alpha channel + into something without an alpha channel there were uninitialized values being + alpha blended into the image. + - Fixed a bug in the cpp_tokenizer that only shows up on newer versions of gcc. + It wasn't tokenizing double quoted strings right. + - Fixed a bug in spatially_filter_image() which showed up when using non-square + filters. The bug would cause the edges of the output image to be incorrect. + - Fixed a bug in the matrix class. Expressions of the form mat *= mat(0) would + evaluate incorrectly because the *= operator took the right hand side by reference + and thus experienced an aliasing problem. The other op= operators had similar + problems and have also been fixed. + - Fixed a bug pointed out by Justin Solomon which could cause the svr_trainer and + svm_c_trainer to produce incorrect results in certain unusual cases. + +Other: + - Added a more complete set of methods for converting between image space and + the downsampled hog grid used by hog_image. Now you can convert from image + to hog in addition to hog to image. + - Made the integral_image more general by making it templated on the type of + scalar used to store the sums. + + + + + + +New Features: + - Added the check_sub_option() method to the command line parser check + object. + - Added match_endings to the dir_nav utils. + - Added a set_current_dir() function. + - Added the distance_to_rect_edge() routine. + - Added support for user drawn rectangle overlays and selectable overlays + to the image_display widget. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug in the image_display widget. If you switched it between + images of a different size while any kind of zoom was in effect + it could cause a segmentation fault. + +Other: + + + + + + +New Features: + - You can now add tasks to a thread_pool by value, using the new + add_task_by_value() method. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug which caused multiply defined symbol errors during linking + if the PNG saving routine was #included. + +Other: + - Optimized the threaded and distributed structural svm solvers for the + case where there are many data samples and the separation oracle is + quick to evaluate. + + + + + + +New Features: + - Added a function for saving to the PNG image format. + - Added Austin Appleby's excellent MurmurHash3 hashing code and setup some + additional convenience functions. These functions are murmur_hash3() and + various overloads of hash(). + +Non-Backwards Compatible Changes: + - Made get_pixel_intensity() more efficient. However, the value returned + is slightly different than it used to be for RGB pixel types. + +Bug fixes: + - Setup proper error handling for libpng in the png_loader. Now if the PNG + file is corrupted in the middle it won't just print a message and abort + the program. + - Fixed a bug in assign_pixel_intensity() that happened when the target pixel + was an RGB pixel with an alpha channel. + +Other: + - Added a Frequently Asked Questions page + - Changed the array2d object so you don't have to say array2d<type>::kernel_1a + anymore to declare it. Now you just say array2d<type>. + + + + + + +New Features: + - Added tools for distributing the work involved in solving a structured + SVM problem over many computers and CPUs. + - Added the bridge. It allows a dlib::pipe to be used for networked + communication. + +Non-Backwards Compatible Changes: + - Removed the DLIB_REVISION macro and replaced it with DLIB_MAJOR_VERSION and + DLIB_MINOR_VERSION. + +Bug fixes: + +Other: + - dlib's version control system has switched from Subversion to Mercurial. + + + + + + +New Features: + - Added the max_sum_submatrix() function to the optimization tools. + - Upgraded the pyramid_down function object so it can create color pyramids. + Also, added some functions which define the coordinate transforms between + different layers in an image pyramid. + +Non-Backwards Compatible Changes: + - Changed the oca_problem interface to the oca optimizer. Now the + optimization_status() function reports back the current risk and risk gap + in addition to the overall objective value and objective gap. + - Changed the stopping condition for the structured svm to the one suggested + by the Joachims paper. Now it stops when the risk gap is below a user + supplied epsilon. + +Bug fixes: + +Other: + - Various usability improvements. + - Improved the feature vector caching in the structural_svm_problem object. + - Some objects were setup as multi-implementation objects but only had one + implementation. I went through dlib and switched these to single implementation + objects. So for example, to use the dlib crc32 module you used to declare an + object of type "crc32::kernel_1a" but now you can just say "crc32". Note that + I did this change in a way that maintains backwards compatibility with previous + versions. So crc32::kernel_1a is still allowed but that form is officially + deprecated. The modified objects are as follows: + - base64 + - byte_orderer + - config_reader + - crc32 + - pipe + - rand + + + + + + +New Features: + - Added a multiclass support vector machine. + - Added a tool for solving the optimization problem associated with + structural support vector machines. + - Added new functions for dealing with sparse vectors: add_to(), + subtract_from(), max_index_plus_one(), fix_nonzero_indexing(), a + more flexible dot(), and I renamed assign_dense_to_sparse() to assign() + and made it more flexible. + +Non-Backwards Compatible Changes: + - Renamed max_index_value_plus_one() (a function for working with graphs) to + max_index_plus_one() so that it uses the same name as the essentially + identical function for working with sparse vectors. + - I simplified the cross_validate_multiclass_trainer(), cross_validate_trainer(), + test_binary_decision_function(), and test_multiclass_decision_function() + routines. They now always return double matrices regardless of any other + consideration. This only breaks previous code if you had been assigning + the result into a float or long double matrix. + - Renamed assign_dense_to_sparse() to assign() + +Bug fixes: + - Fixed a bug in load_libsvm_formatted_data(). I had forgotten to clear the + contents of the labels output vector before adding the loaded label data. + - Fixed a bug in the kernel_matrix() function. It didn't compile when used + with sparse samples which were of type std::vector<std::pair<> >. + Moreover, some of the trainers have a dependency on kernel_matrix() so this + fix makes those trainers also work with this kind of sparse sample. + +Other: + - Added a value_type typedef to matrix_exp so it's easier to write templates + which operate on STL containers and matrix objects. + + + + + +New Features: + - Added an implementation of the Hungarian algorithm for solving the optimal + assignment problem (in the new max_cost_assignment() routine). + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a problem which prevented the any_function unit test from compiling + in visual studio 2008. + +Other: + - Changed the oca optimizer so that it warm starts the QP subproblem + rather than resolving it from scratch during each iteration. This + improves the speed and stability of the algorithm. + + + + + +New Features: + - Added the rr_trainer. It is a version of krr_trainer which is optimized + for use with the linear_kernel. + - Added the approximate_distance_function() routine. It is the core optimizer + behind the reduced2() trainer adapter. + - Added an any_function which supports the same functionality as std::function + from the upcoming C++0x standard. I added this so dlib can be modified to + easily support lambda functions while still being compilable with compilers + which don't support the new std::function. + - Added overloads of all the GUI event handlers so you can use general functions + as callbacks (via any_function). This way, if you have a C++0x compiler, you + can use lambda functions with the event handlers. + - Added the split() function for splitting up strings. + +Non-Backwards Compatible Changes: + - Improved the distance_function object by turning it into a properly + encapsulated class rather than just a simple struct. I also added + overloaded +, -, *, and / operators for this object so you can do the + kind of arithmetic you would expect on an object which represents a + point in a vector space. This breaks backwards compatibility with + the previous interface though as the member variables are now private. + +Bug fixes: + - Fixed a compile-time error in the kernel_matrix(). + - Fixed a bug in an assert in the spatially_filter_image() function. + - Applied a patch from Nils Labugt which fixes a runtime bug in the gui_core + component. The bug caused a crash when using X11 and Ubuntu 10.10 in + certain cases. + - Updated code so that it compiles with the clang compiler. + +Other: + - Updated the image_display widget so you can zoom in and out using the + mouse wheel. + + + + + +New Features: + - General Stuff + - Added the promote template + - Added the basic_type template + - Added the assign_image_scaled() function + - Added the unordered_pair object. + - Added the symmetric_matrix_cache() function + - Added two new quadratic program solvers. The solve_qp2_using_smo + and solve_qp3_using_smo objects. + + - Machine Learning Stuff + - Added the svm_c_trainer object. It implements C-SVM classification and + allows the user to set different C values for each class. + - Added the svm_one_class_trainer object. + - Added the svr_trainer object. It implements epsilon-insensitive + support vector regression. + - Added two new any objects. The any_decision_function for containing + decision function style objects and the any_trainer for trainers. + - Added cross_validate_regression_trainer() + - Added test_regression_function() + - Added the probabilistic() function. It is a trainer adapter that + simply calls train_probabilistic_decision_function(). + - Added tools for multiclass classification + - Added one_vs_one_trainer + - Added one_vs_all_trainer + - Added cross_validate_multiclass_trainer() + - Added test_multiclass_decision_function() + +Non-Backwards Compatible Changes: + - invalid_svm_nu_error has been renamed to invalid_nu_error. + - Changed the pixel_traits so that signed grayscale pixel types are allowed. + This involved adding a few new fields to pixel_traits. I also changed the + get_pixel_intensity() function so that its return value is of the same type + as the basic pixel type rather than always being unsigned long. + - Removed the kernel_type typedef from the normalized function since this + meta-object should be capable of working with non-kernel decision functions. + - train_probabilistic_decision_function() no longer accepts column vectors of + samples and labels. Now it only accepts std::vectors of samples and labels. + +Bug fixes: + - Fixed a bug in the deserialization code for the sparse kernels. The bug + prevented code which used the deserialize() routine from compiling. + +Other: + - Changed the image display GUI widgets to use the assign_image_scaled() + function internally. Now they will accept just about any image and + do the right thing. + - Modified the type_safe_union so that you can call apply_to_contents() on const + type_safe_unions. + - Added serialization support for std::pair objects. + - Made the train_probabilistic_decision_function() more general by making it work + with any kind of trainer object rather than only ones which produce + dlib::decision_function objects. I also made it work with trainers that only + take std::vectors. + - Added overloads to the config_reader's methods to allow it to load directly + from a file name given as a string in addition to taking istream objects. + + + + + +New Features: + - Added the ability to add/subtract scalar values to/from all the elements + of a matrix using the - and + operators. + - Added a trust region optimizer. + - Added Levenberg-Marquardt and LM/quasi-newton hybrid methods for solving + non-linear least squares problems. + - Added an any container object. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a compiler warning and also a runtime bug in sort_basis_vectors(). + The bug triggered when all the basis vectors were included in the final + answer. + +Other: + - Added a bunch of overloads to catch operations on diagonal matrices + and use more efficient code paths for them. For example, inv(diagm(d)) + turns into diagm(reciprocal(d)). Multiplication by a diagonal matrix + is now also handled efficiently. + + + + + + +New Features: + - Added a class for reading JPEG image files. + - Added scale_rows(), flipud() and fliplr() matrix functions. + - Added console_progress_indicator. It is a tool for measuring how long a + task will take. + - Added sort_basis_vectors(). It is a function for performing supervised + basis selection. + +Non-Backwards Compatible Changes: + - Renamed the linearly_independent_subset_finder's dictionary_size() member + function to size(). This way, linearly_independent_subset_finder objects + can be used in many templated functions which expect objects which look + like arrays. + +Bug fixes: + - Changed the assert macros so that they don't use __PRETTY_FUNCTION__ + with gcc 4.4.5 since, on Ubuntu at least, this version of gcc segfaults + when __PRETTY_FUNCTION__ is used within certain template constructs. + - Improved the alias detection capability of kernel_matrix() expressions. + Now statements of the form: sample = kernel_matrix(kern, *, sample) can + be used since the aliasing of sample will be handled. + +Other: + - Generally tried to make things more usable. + - Optimized matrix expressions such as mat*diagm(vect) + - Made the code in chol() more robust to indefinite matrices. + + + + + + +New Features: + - Added the running_scalar_covariance object. + - All the matrix decomposition routines now use LAPACK when DLIB_USE_LAPACK + is #defined. + +Non-Backwards Compatible Changes: + - Removed the dlib::EOTHER constant since it conflicts with visual + studio 2010. + - Changed the svd functions so you can't supply output matrices which use + both column and row major layouts. Now all the output matrices need to + use the same memory layout. + - Removed the qr_decomposition::get_householder() function. + +Bug fixes: + - Minor fixes so that dlib compiles in Visual Studio 2010 + +Other: + - Added an overloaded matrix_assign() that handles symmetric kernel_matrix() + expressions more efficiently by only evaluating the upper triangular part + of the matrix. + + + + + + +New Features: + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a compile-time bug in the matrix related to multiplication by + subm() expressions when the BLAS bindings were enabled. + - Fixed a bug in train_probabilistic_decision_function() which could + cause it to go into an infinite loop when working with very large + datasets. + +Other: + + + + + + +New Features: + - Added a reference_wrapper implementation and modified the thread_function + slightly so it works with it. + - Added an implementation of kernel ridge regression. + - Added a simple newton search strategy for optimizing functions. + +Non-Backwards Compatible Changes: + - If you have created your own matrix expressions then its possible this + new release will cause them to not compile. + +Bug fixes: + - Fixed a bug in scale_columns. It said it didn't have any destructive aliasing + when in fact it destructively aliased its second argument. + - Fixed a bug in the random number generator where setting the seed back to "" + didn't result in the object going back to its initial state. + +Other: + - Reorganized the matrix expression code. It's now much simpler and the + library includes a new example program which details the steps needed to + create new matrix expressions. + - Changed the train_probabilistic_decision_function() routine so that it uses + a more numerically stable method to perform its maximum likelihood optimization. + - Added missing get/set epsilon functions to the RVM training objects. + I also changed the default epsilon from 0.0005 to 0.001. + + + + + +New Features: + - Added the simplify_linear_decision_function() routines. + - Added the find_approximate_k_nearest_neighbors() function. + - Added the fill_lisf() function. + +Non-Backwards Compatible Changes: + - Made the sample_pair have a default distance of infinity instead of + the max floating point value. I also reworked the graph creation functions + to make them a little more versatile. Now you can use infinite distances to + indicate that certain nodes are not connected at all. + - Changed the linear_manifold_regularizer to normalize the regularization + parameter by the sum of edge weights instead of the sum of edges. + +Bug fixes: + - Fixed a bug in the timer_kernel_2 object. In certain rare cases it would + stop calling the action function and essentially shut down without being + told to do so. + +Other: + - Made the reduced() and reduced2() functions more efficient. + - Many small usability improvements here and there. + + + + + + +New Features: + - Added the svm_c_ekm_trainer. It is a kernelized version of the fast + linear trainer svm_c_linear_trainer. + - Added the linear_manifold_regularizer and some supporting tools. + - Added the sum_rows(), sum_cols(), join_rows(), join_cols(), reshape(), + and pointer_to_matrix() functions. + - Added the randomly_subsample() function. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed some minor compile time bugs on certain older compilers. + +Other: + - Updated the += and -= matrix operators to be a little more flexible. Now + if you try to apply them to a matrix of the wrong size it will automatically + resize the target matrix and just do a normal assignment. + - Removed the requirement that you load dng files into an image of the exact + pixel type that created the file. Now you can use any pixel type. I also + changed the code so that grayscale pixels with more than 16 bits get saved as + 16 bit grayscale images instead of 8 bit images. + + + + + + +New Features: + - Added the solve_qp_using_smo() function to solve certain quadratic + programs. + - Added the oca object. It is an implementation of the Optimized Cutting + Plane Algorithm. + - Added a linear SVM trainer that uses oca. + - Added an implementation of the Histogram of Oriented Gradients algorithm + - Added a simple tool for making image pyramids + - Added the running_covariance object + - Added a simple linear (i.e. non-kernelized) kmeans implementation + - Added support for serializing dlib::int64 + - Added some functions to load and save LIBSVM formatted data files. + +Non-Backwards Compatible Changes: + - Changed the definition of dlib's sparse vector format to require + unsigned integral keys. Having this requirement is nice because it + creates a simple correspondence between dense vector index values and + sparse vector keys. The previous sparse vector definition was + excessively generic. + - Renamed sparse_vector::dot_product() to sparse_vector::dot() so that + both dense and sparse vectors have a global function with the same + name (i.e. dot()). + +Bug fixes: + - Fixed a bug discovered by Mitchell Peabody. In some instances trying to + deserialize a std::vector would fail to compile. + +Other: + - Increased the number of template arguments of the type_safe_union from 10 + to 20. Additionally, I made the get_id() function public and renamed it + to get_type_id(). I also added a comment explaining the serialization + format of type_safe_union objects. + - Moved the optimization algorithms into their own page in the documentation. + - Added a Suggested Books page to the documentation + + + + + + +New Features: + - Added the ability to compute transformation matrices that map between + the representations used by different empirical_kernel_maps. Also added + the ability to compute projection error. + - Added the random_subset_selector object. + - Added the compute_mean_squared_distance() function. + +Non-Backwards Compatible Changes: + - Modified the logger's hook implementation so that it uses a special stream + buffer instead of an std::ostringstream. This way logging doesn't cause + memory allocations. This breaks backwards compatibility with the previous + hook function API but only slightly. The new hook functions must take a + const char* instead of std::string. + - Added the const_ret_type typedef to the matrix_exp. It is now required that + all matrix expressions define this type. This enables the expressions to + return elements by constant reference when appropriate rather than always + returning by value. + +Bug fixes: + - Fixed a bug in the matrix BLAS bindings that caused BLAS to return an invalid + argument error. The error occurred when general matrix multiply expressions + were transposed and didn't result in a square matrix. E.g. mat = trans(a*b) + where mat isn't square. + - Fixed potential compile time bugs in the comparison operators for futures. + - Added a missing check for division by zero in the SURF feature extractor. + - Modified the find_min_single_variable() function so that it is more + robust when working with functions that are made up of a bunch of + constant value strips. Previously, these kinds of functions could + cause the optimization to fail. + +Other: + - Changed the regression test suite so that when it sets the logging level + it now sets it for all loggers. Not just ones that start with "test." + + + + + + + +New Features: + - Added some MATLAB style thresholding relational operators to the matrix. + - Added the kernel_matrix() functions. + - Added the empirical_kernel_map object. + - Added the discriminant_pca object. + - Added the read_write_mutex object. + +Non-Backwards Compatible Changes: + - Renamed the support_vectors member of the decision_function and + distance_function classes to basis_vectors. This name more appropriately + reflects how these two classes are used within the library. + - Changed the matrix_exp interface slightly. This could only impact users + who created their own custom matrix expressions. If you don't get a + compiler error then you don't have to worry about it. + +Bug fixes: + - Fixed a minor error in the LBFGS code. + - Added a missing check for division by zero to the kcentroid, krls, + and linearly_independent_subset_finder classes. If someone added + the zero vector to them as the first training example a division by zero + could result. + - There were a few cases where the code wouldn't compile when using + matrices of complex numbers. There was also a runtime bug that triggered + when a rank 1 update was performed where one of the vectors was conjugated + and two or more transposes were used in certain positions. This bug + caused the wrong output to be computed if the BLAS bindings were used. + Both of these bugs have been fixed. + - Fixed a bug in the http server that affected cookies with certain kinds of + data. The result was invalid data being sent back to the web browser. + +Other: + - Generally improved the BLAS bindings for the matrix object. + + + + + + +New Features: + - Added the pointer_to_column_vector function. + - Added the BOBYQA algorithm for derivative-free optimization. + - Added some functions to make it easy to do a line search on a function + of a single variable when derivatives are not available. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug in the cpp pretty printer. It wasn't parsing + exponentiated numbers like 1e100 correctly. + +Other: + - Added a model selection example program using grid search + and the new BOBYQA algorithm. + + + + + + +New Features: + - Added an implementation of the L-BFGS algorithm for unconstrained non-linear + optimization. + +Non-Backwards Compatible Changes: + - Refactored the optimization code. It is now much more flexible but + this resulted in changes to the API. See the optimization example program + for a discussion of the new API. + +Bug fixes: + - Fixed a bug in the get_filesystem_roots() roots function that + prevented it from compiling. + +Other: + + + + + + +New Features: + - Added the ability to use a kernel cache to the batch_trainer object. + - svm_pegasos can now be configured to use two different lambda arguments + for use with unbalanced data. + - Added the reciprocal_max() and dot() matrix functions. + - Added the bgr_pixel and cv_image objects so that OpenCV images can + be easily used with dlib routines. + +Non-Backwards Compatible Changes: + - I changed the batch trainers so that they always call clear() on the + trainer being used before training begins. + - Modified the svm_pegasos class so that the user can set independent lambda + parameters for each class. This breaks backwards compatibility with + the previous interface slightly and changes the serialization format + of this class. + - Split the vector_normalizer into a normal normalizer and a pca normalizer + version. + - The zoomable_region widget now uses exponential rather than linear + zoom scaling since this is much more pleasing to use. There is now + a new requirement on the zoom increment that it must be between 0 + and 1. + +Bug fixes: + - Fixed a bug in the cross_validate_trainer_threaded() function. It could + deadlock if more than about 10 folds were requested. + - Fixed the serialization functions for the normalized_function object. + They will now work with custom normalizer function objects. + - Fixed a minor bug in the zoomable_region::set_min_zoom_scale() function. + It didn't always end up zooming in a smooth sensible manner after this + function was called. + +Other: + - Made the thread_function object more general. It can now handle + arbitrary functions of up to four arguments. + + + + + + +New Features: + - Added the reshape_to_column_vector() function. + - Added a hook to the logger object that lets you set a different kind of + output logging destination (in addition to the std::ostream supported + already). + - Upgraded the scoped_ptr so that it can handle array pointers as well + as customer deleter functions. + - Added overloads of the kernel_derivative object for all the kernels + in dlib. + +Non-Backwards Compatible Changes: + - Reworked the config_reader interface a little to make it easier to use. + In particular, I removed the enumerator over blocks in favor of a simple + get_blocks() function that just returns a std::vector of all the blocks. + I also removed the requires clauses on the block and key accessor functions + and instead made a request for a non-existent key/block result in a non-fatal + exception. This way users can let the config reader perform a more natural + role in config file validation (by catching this exception and acting + accordingly). + - It is now illegal to multiply matrices of size zero together. + +Bug fixes: + - Fixed the gaussian() function used by the SURF code. It wasn't computing + a properly weighted Gaussian function. + - Fixed a few things in various parts of the code to avoid compiler errors + in certain use-cases. + - Added a missing rethrow statement. The xml parser would eat exceptions + thrown by event handlers rather than letting them propagate out as + documented in the specification. + +Other: + + + + + + +New Features: + - Added an implementation of the SURF algorithm which includes the + following new objects and functions: integral_image, hessian_pyramid, + interest_point, surf_point, compute_dominant_angle(), + compute_surf_descriptor(), haar_x(), haar_y(), get_interest_points(), + and get_surf_points(). + - Added the zeros_matrix() and ones_matrix() functions. + - Added serialization support to the type_safe_union object. + - Added the grow_rect() and shrink_rect() functions. + - Added the get_files_in_directory_tree() function. + - Added the null_trainer_type object. + - Added the roc_trainer_type object. + +Non-Backwards Compatible Changes: + - Removed some extraneous get_kernel() functions from some of the + trainer adapter classes since they really aren't needed. + +Bug fixes: + - Changed the socket read/write code so that it can handle a large + number ( > 2 billion) of bytes in I/O calls. + - Added a missing type cast to the reciprocal() function to fix a compile + time error you get when you use it with complex<float> type matrices. + - Fixed a bug in the assign_border_pixels() and zero_border_pixels() functions. + Their contracts said there was no upper limit on the size of the border that + could be assigned/zeroed but the implementations failed to handle the case + where the border was bigger than the image. + +Other: + - Generally cleaned up the code and documentation here and there. + - Added in Steven Van Ingelgem's patches to improve the usability of the + HTTP server object. + - Updated the load_bmp() function so that it is capable of reading BMP + files that have been compressed with the RLE compression sometimes + used for 8bit BMP files. + - Merged in Miguel Grinberg's patch to add a non-blocking read() function to the + connection object. + + + + + +New Features: + - Added a set of kernels that can operate on sparse vectors. + - Added the image_window and image_display objects. + - Added the rotate_point() function and the point_rotator object. + +Non-Backwards Compatible Changes: + - Added Steven Van Ingelgem's patch to add the body of data posted + back to the server into the incoming data object given to the + server_http::on_request() handler. This removes the content_length + field and replaces it with a string that contains the body of content + data. + +Bug fixes: + - Fixed a compile time bug in the offset_kernel. + +Other: + - Added optimized overloads of the kcentroid object for various + linear kernels. + - Changed all the tests in the dlib test suite to use a new DLIB_TEST + macro instead of DLIB_CASSERT since the tests really aren't + technically assertions + + + + + + +New Features: + - Added the strings_equal_ignore_case() functions + +Non-Backwards Compatible Changes: + - Changed the on_request() function in the http server + - Changed the serialization format of the kcentroid and svm_pegasos + objects + - By default, the kcentroid now keeps the most linearly independent + dictionary vectors rather than the newest + +Bug fixes: + +Other: + - Split the algorithms documentation page into three pages, algorithms, + machine learning, and bayes nets. + - Merged in Steven Van Ingelgem's patch to cleanup the HTTP server and + add new functionality. This breaks backwards compatibility with the + previous on_request() interface but it is easy to update old code and + it is now much cleaner and easier to use. + - Changed the kcentroid so that you can tell it to keep the most linearly + independent vectors rather than the newest vectors. I then changed the + svm_pegasos object so that it has a max number of support vector setting + so that the user can supply an upper limit on the number of support + vectors to use. + + + + + + +New Features: + - Matrix related + - Added the find_min_and_max(), index_of_min(), index_of_max(), trace(), + randm(), linspace(), logspace(), and cartesian_product() functions. + - Machine learning related + - Added the offset_kernel + - Added some functions to the kcentroid to allow the user to compute + the inner_product of kcentroids as well as a few other useful things. + - Added a kernelized version of the Pegasos SVM training algorithm. + +Non-Backwards Compatible Changes: + - Changed the range() function so that it returns row vectors + instead of column vectors. + +Bug fixes: + - Changed threading code to avoid a potential race condition during + program termination. + - Fixed a few incorrect DLIB_ASSERT statements + - Fixed a bug in the way Content-type was handled in HTTP posts. + - Fixed a bug in subm() that showed up when statically dimensioned row + vectors were used to select a sub matrix. + +Other: + - Added some functions to the rectangle to make it easy + to get the corner points. + - The cross validation functions no longer allow invalid_svm_nu_error + exceptions to escape. Instead, they are assigned low CV scores. + - Made std_vector_c able to copy directly from std::vector objects. + - Added a get_socket_descriptor() function to the connection class. + + + + + + +New Features: + - Matrix related + - Added QR, LU, Cholesky, and eigenvalue decomposition class objects + - Added overloads for rowm() and colm() that allow you to pick out + less than an entire vector + - Added the lowerm() and upperm() functions + - Added the const_temp_matrix class + +Non-Backwards Compatible Changes: + - Renamed the cholesky_decomposition() function to chol() + +Bug fixes: + - Fixed some errors in the requirements for calling the new rowm() and + colm() functions. + - Fixed dlib::abs() so that it returns the right type when used + with complex matrices. + - Fixed a race condition in the logger object. It was missing a needed call + to unregister_thread_end_handler(). What could happen in some scenarios is, + during program termination, a global part of the logger object could be destructed + when it still had outstanding thread end handlers registered to it. + +Other: + - Added an example program that shows how to use the optimization + functions. + - Gave the matrix object the ability to factor expressions containing + trans() function calls into more efficient forms. + - Generally cleaned up the matrix code + + + + + +New Features: + - Added the multi-line text_box GUI widget. + - Added the type_safe_union object + +Non-Backwards Compatible Changes: + - Renamed the array::expand() function to array::resize() since it does + basically the same thing as std::vector::resize() and more than one + user has told me they found the name "expand" to be confusing. + +Bug fixes: + +Other: + - Added an example showing how to use the type_safe_union and pipe + together. + - Added a page to the documentation that discusses the dlib coding + standards and how to contribute to the project. + + + + + +New Features: + - Added the bound_function_pointer object. + - Added support for futures to the thread_pool object. + - Added a set of objects that makes it possible to create simulations + of quantum computing algorithms. + - Added copy and paste support to the text_field. + - matrix object stuff + - Added the range() function as well as overloads of all the various + sub-matrix selection functions so that you can pick out slices of + matrices like in Matlab. + - Added a new template argument to the matrix object that allows the + user to select the memory layout. Also added a row_major_layout + and column_major_layout. + - The matrix object can now be initialized using a comma separated + list of values. + +Non-Backwards Compatible Changes: + - Changed the fatal_error exception class so that it aborts your program + and prints a message if you try to construct it more than once since + doing so indicates that you ignored the first fatal error. + - The way matrix expressions work in the library has been changed + since the last release. So if you created custom matrix expressions + then they will need to be updated to use the new matrix expression stuff. + +Bug fixes: + - Fixed a minor bug in how the zoomable_region widget drew itself after + a resize in some cases. + - Fixed a problem with draw_line where it didn't always redraw the line + properly. + +Other: + - A lot of the matrix code has been refactored and optimized. The matrix + object will now introduce temporary objects when doing so results in + better performance. I also added a simple system for binding + arbitrary matrix expressions to optimized BLAS routines. + - Cleaned up the vector and point classes. Now there is only one class, + the vector class, and it is capable of representing everything the old + vector and point class could. I also added code to make sure the + vector class does the appropriate type promotions when vector objects + with different types are used together. + - Made the vector class inherit from matrix + + + + + +New Features: + - Added user settable styles to most of the gui widgets + - Added the diagm(), svd2() and svd3() matrix functions + - Added the thread_pool object + +Non-Backwards Compatible Changes: + - Removed the arrow_button widget and moved its functionality into the + button widget. + - Renamed the dragable class to draggable + - Removed the confusing and unnecessary hidden bool argument to the + gui widget style drawing functions. + - Changed some of the events that are about the mouse leaving a widget so + that they still trigger even if the widget has been disabled or hidden. + +Bug fixes: + - Added some missing mutex locks to the scroll_bar widget + - Fixed a bug in the fill_gradient_rounded() function. It didn't always + draw the entire rectangle. + - Fixed a compile time bug in the pinv() function. It didn't compile + when used on statically sized matrices when they weren't square. + +Other: + - The member_function_pointer object now never calls new or delete. + So it is safe to use in a real time environment. + + + + + + +New Features: + - Added the sort_columns() and rsort_columns() functions + - Added the vector_normalizer object + - Added the normalized_function object. + - Added a tensor_product() function for the matrix object. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Made it so that the gui event handler thread isn't created at all + unless some part of an application calls some of the gui_core code. + In the previous release the event handler thread was executed + briefly during program termination and could cause problems if no + windowing environment was available. + - Fixed an #include statement in the matrix utilities so that it works + even if you don't specify an include path argument to your compiler. + +Other: + + + + + + +New Features: + - Added a thread safe shared pointer object + - Added the popup_menu_region widget. + +Non-Backwards Compatible Changes: + - The on_wheel_up() and on_wheel_down() gui events now take an unsigned long + argument. + - Removed the register_program_ending_handler() function from the threading + API and also changed the dlib thread pool so that it no longer causes + a terminating program to wait for any outstanding threads to finish + before allowing the application to end. + - Changed the serialization format of the linearly_independent_subset_finder + class. + - Changed all the font pointers in the gui API's interfaces + to shared_ptr_thread_safe objects. + +Bug fixes: + - Made the kkmeans class actually use the min_change parameter. + - Fixed a bug in the linearly_independent_subset_finder object. Also + added a way to set a minimum tolerance. + - Fixed a bug in the scrollable_region widget that caused it to scroll in an + unpleasant way when the horizontal and vertical scroll increments weren't + set to the same value. + - Made one of the arguments to font::draw_string() not be a reference because + some versions of gcc don't end up doing the right thing when -O3 is + supplied. + - Fixed a bug in the covariance() function that prevented it from compiling + sometimes. + +Other: + - Changed the gui core code around so that it should be safe to make window + objects at the global scope. + - Added more control over how the scrollable_region scrolls its region. + You can now adjust how much it scrolls when the mouse wheel is scrolled + as well as enabling scrolling via a mouse drag. + - Modified the library so that it compiles with the Intel compiler. + - Added some example programs that use the relevance vector machine + + + + + + +New Features: + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug in the tooltip widget + - Fixed the cmake option to toggle the ENABLE_ASSERTS macro + - Fixed some bugs in the rvm + - Fixed the serialization code for the kkmeans object so that it actually + works + - Fixed a bug that can trigger when the thread_specific_data object is + destructed + - Fixed a bug in the directory navigation gui. If you tried to go + into a drive on windows that wasn't mounted you got an error. + This is now fixed. + +Other: + - Made the dir_nav stuff work with std::vector and dlib::std_vector_c + as well as dlib::queue objects. + - Generally cleaned up a bunch of things + + + + + + +New Features: + - Added relevance vector machine regression and classification support. + - Added the cross_validate_trainer_threaded() function + - Added the length and length_squared matrix functions. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Changed gui code a little so that windows don't popup in funny + places when used with the cygwin X windows system. + +Other: + - Made it easier to use the scoped_ptr with the TCP connection object + - Optimized the matrix object a little + + + + + + +New Features: + - Machine Learning + - Added the ability to compare kcentroid objects to each other + - Added the rank_features() function + - Added the distance_function object + - Added the reduced_decision_function_trainer object and + reduced() function + - Added the reduced_decision_function_trainer2 object and + reduced2() function + - Added a radial basis function network trainer + - Added the linearly_independent_subset_finder object + - Added the sigmoid_kernel + - Matrix Utilities + - Added the inv_upper_triangular() and inv_upper_triangular() + functions. + +Non-Backwards Compatible Changes: + - Refactored a bunch of the kernel learning code into a much cleaner form. + But this does change the interface to all the training functions. + +Bug fixes: + - Fixed a bug in the min and max calculation in the running_stats object + - Removed a bug in the sum() and variance() functions that + caused them to seg fault when they were used on certain + matrix of matrix objects. + - Added a missing check for division by zero to the conjugate gradient + optimization functions. + - Added some missing member variables to the .swap and serialization + functions for the kcentroid object. So now they should work right. + +Other: + - Added an option to the cmake file to toggle the DLIB_ASSERT macro + - Added an option to the cmake file to toggle the dlib stack trace macros + - Made the library compile in Cygwin + + + + + + +New Features: + - Merged in Keita Mochizuki's unicode patches to the GUI components. So + the dlib GUI now has better unicode support. + - Added the remove_row and remove_col matrix functions. Also made all + three of the above functions capable of taking arguments at run time + as well as compile time. + - Added the ability to cap the number of dictionary vectors used by the krls + and kcentroid object at a user specified number. + - Added the pick_initial_centers() function + - Added the running_stats object + +Non-Backwards Compatible Changes: + - Changed the interface to the krls and kcentroid objects somewhat. + - All of the style objects in the GUI part of the library now use + dlib::ustring instead of std::string. This only matters to you if + you have made your own style objects. + - Changed the serialization format of the krls, kcentroid, and + directed_graph_drawer objects. Note that is also means that the + files saved by previous versions of the bayes_net_gui_ex program + won't load with the current version. + +Bug fixes: + - Fixed an aliasing bug in the set_subm(), set_rowm(), and set_colm() + functions. It was possible that you could get incorrect results + if you used these functions to copy one part of a matrix to another + part of the same matrix if the two areas overlapped. + - Fixed a minor numerical error in the krls code so now it gets slightly + better results. + +Other: + - Made the types generated by the matrix operations a lot shorter. This + avoids some compiler warnings from visual studio and even some potential + internal compiler errors in some instances. + + + + + + +New Features: + - Added some macros that allow dlib to create a stack trace + - Added the wrap_function and is_function templates. + - Added two new events to the text_field object. One for detecting when the + user hits enter and another for detecting when input focus is lost. + - Machine Learning + - Added a kernel based centroid estimator/novelty detector + - Added a kernel based k-means clustering algorithm + - Numerical + - Added an identity_matrix() function that can take a runtime defined size. + - Added a bunch of unconstrained optimization stuff to the library. + It now has a conjugate gradient optimization algorithm as well as + a quasi-newton algorithm. + - Added the set_subm, set_rowm, and set_colm functions. + +Non-Backwards Compatible Changes: + - In the krls object: Added a requires clause to the set_tolerance() member + function and renamed clear() to clear_dictionary(). + +Bug fixes: + - Fixed a bug in the requires clause of the subm() function. It was + more restrictive than it should have been. + +Other: + - Added example programs for the krls object as well as the new + kcentroid and kkmeans objects. + + + + + + +New Features: + - Added an implementation of the kernel recursive least squares algorithm + +Non-Backwards Compatible Changes: + - Broke backwards compatibility in the directed_graph_drawer's serialization + format when I fixed the bug below. + +Bug fixes: + - Fixed two bugs in the directed_graph_drawer widget. First, it sometimes + threw a dlib::fatal_error due to a race condition. Second, the color of + the nodes wasn't being serialized when save_graph() was called. + - Made vector_to_matrix() work for std::vector objects that have non-default + allocators. + +Other: + - Added some stuff to make people get a really obvious error message + when they set up the include path incorrectly. + + + + + + +New Features: + - Added the vector_to_matrix() function. + - Added a cholesky_decomposition() function. + - Added the toggle_button GUI widget + - Added a default toggle button style as well as check box and + radio button styles. + - Added a single click event to list_box + - Added a save_file_box() and open_existing_file_box() function. + +Non-Backwards Compatible Changes: + - Changed the check_box and radio_button widgets to be specializations of + the new toggle_button object. This is a nearly backwards compatible + change except that the events registered to check_box and radio_button + clicks must now take the form void event(toggle_button&) or + void event(void) instead of the previous void event(check_box&) and + void event(radio_button&). + - Removed the is_mouse_over bool from the button_style::draw_button() + function. + +Bug fixes: + - Fixed a compiler error in mingw. + - Changed the preprocessor checks for the wchar_t overload of + is_built_in_scalar_type so that it works properly with visual studio. + +Other: + - Added a Bayesian Network GUI that allows you to create a network + and serialize it to disk. + + + + + +New Features: + - GUI Related + - Added the scrollable_region widget + - Added the text_grid widget + - Added an event to the text_field so you can tell when the + user modifies it. + - Added the fit_to_contents() function to the tabbed_display + widget. + - Bayesian Network Related + - Added the node_first_parent_assignment(), node_next_parent_assignment(), + and node_cpt_filled_out() functions. + +Non-Backwards Compatible Changes: + - Reverted the change in 17.0 where I made drawable::lastx and + drawable::lasty not match the current location of the mouse inside + the on_mouse_move() event. I changed this back to how it was before, + so now lastx and lasty represent the most current record of where + the mouse is in *all* events. + - Changed the functions that control text color in the label and text_field + widgets to use rgb_pixel objects. Also added a function to set the + background color of a text_field. + +Bug fixes: + - Fixed a bug in the bayesian_network_join_tree object that caused it to + compute incorrect results for some networks. + - GUI Related + - Fixed a minor bug in the cursor drawing of the text_field + gui widget. + - Fixed a bug in the compute_cursor_rect() function. It would return an + incorrectly positioned rectangle for 0 length strings. + - Changed the way wchar_t is handled in the serialize.h file. Now + everything should compile correctly in visual studio regardless of how + you set the /Zc:wchar_t compiler option. + - Fixed a bug in the menu_bar widget. One of the members wasn't being + initialized when it needed to be. + - Fixed a bug in the tabbed_display where it didn't redraw itself + correctly after it was moved by set_pos() + +Other: + - Changed the xml parser so that it counts line numbers + from the start of the input stream instead of from the + root tag. + - Changed the xml parser so that you will only get the fatal_error + event once if it occurs. + + + + + +New Features: + - Added a zoomable_region widget + - Added a directed_graph_drawer widget + +Non-Backwards Compatible Changes: + - Changed the first_pixel argument of the draw_string() function + to be a rectangle like all the other draw functions now use. + +Bug fixes: + - Fixed a bug in the tooltip widget that was triggered when calling + its member functions without calling set_tooltip_text(). This also + fixed a bug in the button object that triggered when calling some button + functions that referenced the tooltip widget. + - Fixed a problem in the draw_circle and draw_solid_circle functions. + They didn't draw themselves quite correctly in all cases. + +Other: + + + + + +New Features: + - Added a png_loader object + - GUI related + - Added a popup_menu widget + - Added a menu_bar widget + - Added a tooltip widget + - Added a user selectable style to the gui button. + - Added the draw_rounded_rectangle() and fill_gradient_rounded() functions + - Added the mouse_over_event object to the base_widgets and made the + button_action inherit from it. + - Added the drawable::next_free_user_event_number() function + - matrix and geometry: + - Added a size() function to matrix_exp and matrix_ref objects. + - Added a class that represents 2D points + - Added the following matrix functions: + - squared(), cubed(), get_rect(), a subm() function that takes + rectangles, and normalize() + - Added the following rectangle functions: + - area(), centered_rect(), translate_rect(), move_rect(), resize_rect(), + resize_rect_height(), resize_rect_width(), and nearest_point() + +Non-Backwards Compatible Changes: + - Renamed atom() to array_to_matrix() + - Moved the rectangle object from the gui_core into a new geometry folder + (only matters if you were directly including the rectangle file) + - Moved the vector object into the geometry folder. Also removed the kernel_1a + stuff. So there is now only one possible vector implementation. + - Changed the default position for a rectangle to (0,0) instead of (1,1) + - Added edge data to the directed_graph. This breaks backwards compatibility + with the previous serialization format for directed_graphs. + - GUI related: + - Changed the base_window::on_keydown event signature so that it now + reports more keyboard modifier keys (e.g. alt) + - Made the functions for drawing on canvas objects take points and pixels + instead of just a bunch of integers. Also changed the order of the + arguments so that the canvas comes first, followed by the location + to draw on, then what to draw. + - Moved the canvas drawing functions into the gui_widgets/canvas_drawing.h + file. + - Modified the drawable_window so that the drawable::lastx and drawable::lasty + fields are updated after calls to on_mouse_move. This way the x and y that + go into the on_mouse_move actually tell you something. + +Bug fixes: + - Fixed a bug in the floating point serialization code. It + didn't handle NaN or infinities correctly. + - Fixed a bug in the win32 version of the gui_core component. It was + possible that calling set_size(), set_pos(), or set_title() could cause + the program to deadlock. + - Made the load_bmp() function more robust in the face of weirdly + written BMP files. + - Modified the draw_circle() and draw_solid_circle() functions so that they + only touch each canvas pixel once. This avoids messing up alpha blending + if an rgb_alpha_pixel is used. + +Other: + - Removed the old win32 only gui code in the dlib/gui folder. + - Changed the default GUI font to a nicer Sans Serif font + + + + + +New Features: + - Added another constructor to the thread_function object. + Now it can take proper function objects as well as normal function + pointers. + - Added the probabilistic_decision_function object and svm_nu_train_prob() + function. + +Non-Backwards Compatible Changes: + - Changed the svm train functions so that the cache_size argument + now measures the max number of megabytes of memory to use rather + than number of kernel matrix rows to cache. It's default + value is now 200MB. + - changed the type typedef in the SVM kernel function objects to + be named sample_type instead of type. + +Bug fixes: + - Fixed a bug in the trim, rtrim, and ltrim functions. They + didn't return empty strings when the input string contained all + trim characters. + - Fixed a bug in the decision_function's copy constructor + +Other: + - Added an optimization to the working set selection for the svm training code. + Now the algorithm will prefer to select indices that are in the kernel + matrix cache when possible. + - Fixed a problem with the chm documentation file where many of the links + didn't work. + - Made the support vector functions capable of operating with floats, doubles, + and long doubles instead of just the double type. + + + + + +New Features: + - Added aversion of the draw_line() function for images. + - Added the atom(), rowm(), colm(), and subm() matrix functions. + - Added some push/pop_back() functions to the array object that are similar + to the ones in the std::vector. + - Added the std_vector_c class that wraps std::vector and checks its + function's preconditions. + - Added the polynomial_kernel object for use with the svm algorithm. + +Non-Backwards Compatible Changes: + - Changed the svm_nu_cross_validate() function to return a vector + of both the +1 and -1 cross validation accuracies. + +Bug fixes: + - Fixed a bug in the list_box that caused it to not hide itself properly + when told to do so. + - Fixed canvas::fill() gui function so that it should work right + on 64 bit platforms. + +Other: + + + + + +New Features: + - Added memory manager support to the matrix object. + +Non-Backwards Compatible Changes: + - Made the assign_pixel() function saturate grayscale values bigger + than the target pixel type can handle. Previously it would just + truncate the numbers. + - Removed rand_kernel_1 and rand_kernel_2 because they gave very + inferior results compared to rand_kernel_3. I then renamed + rand_kernel_3 to rand_kernel_1. + - Renamed rand::get_random_number() to get_random_8bit_number() and also + added a get_random_16bit_number() and get_random_32bit_number() + - Added a checksum to compress_stream_kernel_1 and kernel_2. This + breaks backwards compatibility with the previous versions. That is, + the new implementations will complain that decompression fails if + you give them data compressed with the old non-checksum version of + the compression routines. + - Removed the width() and height() functions from the array2d object. + Now only the equivalent nc() and nr() member functions remain. + - Changed array2d::set_size(width,height) to set_size(num_rows, num_cols). + That is, I switched the order of the two arguments to this function. + The reason for doing this is to make it have the same form as the + set_size() member of the matrix object. This way the usage of the + set_size() member for these two very similar data structures is + the same. Hopefully this will reduce confusion rather than + make things worse. + +Bug fixes: + - Fixed a bug in the image_widget. It didn't repaint the screen + all the way if you gave it a smaller image to display. + - Fixed a bug in the cat() function that caused the state of the queue + to be broken if you called cat with an empty queue. + - Made the queue_sort_1 use a better sorting algorithm. In particular, it + will not sort slowly for nearly sorted data. + - Fixed a bug in the queue_kernel_2 object that caused it to not work + correctly with the non-default memory managers. + +Other: + - Added example code for the member_function_pointer as well as the matrix + object. + - Added some more regression tests and made some of the longer running + ones execute a lot quicker. + - Made the unit test suite easier to use. Now tests just throw an exception + to indicate an error rather than returning an error code. + - Added an example program for the multi-layer perceptron neural network. + + + + +New Features: + - Added the is_signed_type and is_unsigned_type templates + - Image Processing stuff + - Added the assign_all_pixels() function + - Added the assign_border_pixels() function + - Added the assign_pixel_intensity() function + - Added the auto_threshold_image() function + - Added the binary_union() function + - Added the edge_orientation() function + - Added the get_histogram() function + - Added the get_pixel_intensity() function + - Added the hysteresis_threshold() function + - Added the sobel_edge_detector() function + - Added the suppress_non_maximum_edges() function + - Added the zero_border_pixels() function + - Changed the pixel_traits structure so that it can support 8, 16, and 32 + bit grayscale pixels. + +Non-Backwards Compatible Changes: + - Added more fields to the pixel_traits template so if you had defined your + own pixel types you will need to update them. + +Bug fixes: + - Fixed some compiler errors in Visual Studio 2008 + +Other: + - Generally tried to clean up the documentation and code in this release + + + + + + +New Features: + - Added the randomize_samples() function + - Added the set_main_font() and main_font() functions to the drawable object. + So now the drawable widgets can use a user provided font. + +Non-Backwards Compatible Changes: + - Made the named_rectangle object a little easier to use. It now won't + let you size it so small that it doesn't display its entire name. + +Bug fixes: + - Fixed a bug in the svm_nu_train() function that caused a crash with + some inputs. + - Fixed a compile time error that occurred when compiling the bayesian + network code in Mac OS X. + - Fixed a bug in the compute_cursor_pos() function where it would + return the incorrect value. + +Other: + - Added an example showing how to use the svm functions. + + + + + + +New Features: + - Added the left_substr() and right_substr() functions + - Added the zero_extend_cast() function + - Added the unsigned_type template + - Added the uint8 typedef + - Bayesian Network related + - Added the assignment object + - Added the bayes_node object + - Added the joint_probability_table object + - Added the conditional_probability_table object + - Added the bayesian_network_gibbs_sampler object + This object implements an algorithm that performs approximate inference + in a Bayesian Network. + - Added the bayesian_network_join_tree object + This object implements an algorithm that performs exact inference + in a Bayesian Network. + - Set related + - Added the set_intersection_size() function + - Added the set_union() function + - Added the set_intersection() function + - Added the set_difference() function + - Graph related + - Added the graph object + - Added the is_graph template + - Added the is_directed_graph template + - Added the create_moral_graph() function + - Added the triangulate_graph_and_find_cliques() function + - Added the graph_contains_length_one_cycle() function + - Added the find_connected_nodes() function + - Added the graph_is_connected() function + - Added the is_clique() function + - Added the is_maximal_clique() function + - Added the copy_graph_structure() function + - Added the create_join_tree() function + - Added the is_join_tree() function + - Added the edge() function + - GUI related + - Added the base_window::get_display_size() function + - Added message_box_blocking() + - Added the bdf_font object which is capable of loading BDF font files into + the font object used by the gui_widgets + - Better Unicode support + - Added the basic_utf8_ifstream: An input stream that can read UTF-8 files + - Added serialization support for wchar_t and std::wstring + - Added the is_combining_char() function + - Added the convert_utf8_to_utf32() function + - Modified most of the string manipulation functions in dlib/string.h + to work with any kind of character type + - The gui widgets' font object now works with Unicode text (i.e. wchar_t + and unichar strings) as well as with normal char data. + +Non-Backwards Compatible Changes: + - The dlib/all_console.cpp and dlib/all_gui.cpp files have been deprecated + in favor of a new file. Now to build with dlib you simply add + dlib/all/source.cpp to your project regardless of what type of project + you are building. + - The GUI program entry point, winmain(), has been removed. You can now use + the normal main() entry point or some other non-standard entry point + provided by your compiler. + - Renamed directed_graph::node::item to directed_graph::node::data + +Bug fixes: + - Fixed some build issues in gcc 4.2 involving some uses of the std_allocator + - Fixed some build issues in Visual Studio involving the dir_nav component + and building with NO_MAKEFILE #defined. + - Moved the #define that disables the old WinSock API into the sockets cpp + file. This should avoid conflicts with people who are using the old WinSock + API. + - Changed the tuple template slightly to avoid a bug in Visual Studio 7.1 + that caused a compile time error in some instances. + +Other: + + + + + + +New Features: + - Added a destroy() function to the map, set, hash_map, and hash_set objects. + - Added the tuple object + - Added an overload of connect() that has a timeout + - Added rand_kernel_3 as a random number generator that uses the Mersenne Twister + algorithm. + - Added the directed_graph object + - Added the graph_contains_undirected_cycle() and graph_contains_directed_cycle() + functions. + - Added the std_allocator object. It is a STL style allocator that can use + the dlib memory manager objects. + - std::string manipulation functions: + - Added the cast_to_string() function. + - Added the tolower() function + - Added the toupper() function + - Added the ltrim() function + - Added the rtrim() function + - Added the trim() function + - Added the lpad() function + - Added the rpad() function + - Added the pad() function + +Non-Backwards Compatible Changes: + - Changed the default logging level from LNONE to LERROR + - Renamed the ASSERT macro to DLIB_ASSERT and CASSERT to DLIB_CASSERT. + This rename avoids a conflict with a macro inside MFC. + - Changed the logger so that settings are inherited when a new logger + is instantiated rather than just having the new logger use the + default settings. + - Removed the logger::clear() function since it no longer really + makes sense given the above change. + - Removed the get_main_thread_id() function and replaced it with the + is_dlib_thread() function. + +Bug fixes: + - Pushed some things into cpp files because doing so avoids build and/or + runtime errors on some platforms. + +Other: + - Changed the string_cast() function so that it will recognize the words true + and false as boolean values. Also improved the error message inside the + string_cast_error exception object. + + + + + + +New Features: + - Added the covariance() function + - Added the rgb_alpha_pixel pixel type and modified all relevant functions to + support it. + +Non-Backwards Compatible Changes: + - The GUI changes that are non-backwards compatible: + - The alpha parameter is now an unsigned char instead of unsigned int + and its range is now 0 to 255 instead of 0 to 256. + - The image_widget no longer has any member functions dealing with + alpha values. If you want to use alpha blending you just give it an + image that has an alpha channel. The same goes for draw_image(). + - There are now more fields in the pixel_traits template. So if you were + defining your own pixels before you will need to update your pixel_traits + specializations. + +Bug fixes: + - Made some functions non-inline and put some things on the stack + instead of heap. Doing this avoids some problems with certain + kinds of builds in visual studio. + +Other: + - Modified the message_box() function so that it is safe to call end_program() + from within its callback event. + + + + + + +New Features: + - Modified the GUI drawing functions to take an alpha argument to allow + alpha blending. + - Added the svm_nu_cross_validate() function to perform k-fold + cross validation using the svm_nu_train() function. + - Added the boost enable_if templates + - Added the rand_float extension to the rand object. + - New matrix features: + - Added the pinv() function + - Changed round_zeros() to use the machine epsilon instead of 1e-6 as + its default epsilon. + - Modified the matrix object so that you can declare them with + a static dimension and a dynamic dimension. E.g. matrix<float,0,10> + is now legal and declares a matrix with a fixed number of columns(10) + and a variable number of rows. + - Added the equal() function to compare two matrices of floating + point numbers for near equality. + - Changed the matrix so that operator(long) works for both + column vectors and now also for row vectors. + - Added a set_size() and constructor that takes a single long for use in + sizing row and column vectors. + - Added the scale_columns() function + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed an error in svm_nu_train() where it would incorrectly + complain of incorrect nu values for some datasets. + - Added a missing std:: qualifier at two points in the dlib/vector code that + could cause a compiler error in some instances. + +Other: + - Added a term index to the documentation. + + + + + + +New Features: + - Added a nu support vector classifier training function. + - Added a multilayer neural network object. + - Added the "destructive aliasing" checks into the matrix code. Now temporary + matrices are only created during assignment if the right hand side aliases + the left hand side in a destructive way. This removes many of the previous + uses of temporary matrices. + - Made the sum() matrix function be able to sum matrices of matrices + - New matrix functions: + - acos(), asin(), atan(), ceil(), cos(), cosh(), exp(), floor(), log(), + log10(), mean(), norm(), pow(), reciprocal(), round_zeros(), sin(), + sinh(), sqrt(), tan(), tanh(), variance(), and more overloads of + uniform_matrix(). + +Non-Backwards Compatible Changes: + +Bug fixes: + - Added missing nr() and nc() functions to the uniform_matrix() and + identity_matrix() functions. + - Forgot to add a destructor for the dynamically sized matrix resulting in a + memory leak. This is now fixed. + - Fixed various potential compile time errors + +Other: + + + + + + +New Features: + - Added a copy of the boost noncopyable base class. + - added some smart pointers: + - added shared_ptr + - added weak_ptr + - added scoped_ptr + +Non-Backwards Compatible Changes: + +Bug fixes: + +Other: + - Cleaned up the assert code and removed the need for the dlib/error.ccp file + - Made the matrix take better advantage of the compile time sized + dimensions when it can. + + + + + + +New Features: + - Made it so that command line options have a default conversion to bool + and the bool tells you if they are on the command line or not. + - Added an implicit conversion to a scalar to the matrix object + when it is of dimension 1x1. + - Added the thread_function object + - Added a function to compute the singular value decomposition of a matrix. + +Non-Backwards Compatible Changes: + - Added two new arguments to the on_request() function. They allow you to + see what HTTP headers the client sends you and to control which ones + you send back. + +Bug fixes: + +Other: + + + + + + +New Features: + - matrix object additions: + - Added some functions to convert between matrix and pixel objects. + - Added the clamp() function that operates on matrix objects. + - Added the sigmoid function. + - Made the matrix object capable of being sized at runtime in addition + to its original compile time static sizing capability. + - Added 3 and 4 argument versions of pointwise_multiply() + - Added the +=, -=, *= and /= operators to the matrix object. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed the line numbering in the color pretty printer. Wasn't being + done correctly. + - Fixed a bug in the matrix round() function. + - Fixed some miscellaneous compile time errors + - Fixed a bug in the matrix removerc() function. + - Added some missing checks to catch invalid negative index inputs to + matrix objects. + - Fixed a bug in the matrix inv() function. It could sometimes + segfault if used on certain singular matrices + +Other: + - string_cast() can now convert hex strings to integers + - You can now say myarray2d.set_size(0,0) and have it do what + you would naturally expect. + - Added some #pragma statements that tell visual studio + to link the right system libraries automatically. + So now you don't have to add these things in the + project settings anymore. + + + + + + +New Features: + - Added the set_all_logging_levels(), set_all_logging_output_streams() + functions + - Added the configure_loggers_from_file() function which allows you to + easily configure all logger objects using a textual configuration + file. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Added a workaround into the code that avoids a potential compilation + error on Mac OS X systems. + +Other: + + + + + + +New Features: + +Non-Backwards Compatible Changes: + +Bug fixes: + - Fixed a bug in the POSIX version of the hostname_to_ip() function. It was + screwy if you asked for more than the first IP address (the same address + might be returned more than once). + - Fixed a bug in the pipe object's timeout functions. The timeouts weren't + working correctly. + +Other: + + + + + + +New Features: + - Added the wait_for_num_blocked_dequeues(), enable_enqueue(), + disable_enqueue(), and is_enqueue_enabled() functions to the pipe object. + - The pipe object can now be used with a zero length buffer. + +Non-Backwards Compatible Changes: + - There is no longer a pipe::kernel_1a_c typedef since the pipe + no longer has any requirements to check (due to the change of allowing + zero length buffer sizes) + +Bug fixes: + +Other: + - Made the ASSERT and CASSERT macros call dlib_assert_breakpoint() when they + fail. This way you can easily set a breakpoint on them in a debugging + tool by breaking on calls to this new function. + - Fixed some typos and unclear parts of the pipe spec. + + + + + + +New Features: + - Added a thread safe version of the config_reader object (in the form of an + extension to the config_reader) + - Added the wait_until_empty() function to the pipe object. + +Non-Backwards Compatible Changes: + - Removed the connection::close() and listener::close() functions. They have + been replaced by destructors. To upgrade old code all you have to do is + replace statements of the form "object->close();" with "delete object;". + Both statements do exactly the same thing. However, for connection objects, + you should probably be using the close_gracefully() function instead. + +Bug fixes: + - Removed a potential compile time error in the dng image format handling code. + - Fixed a bug in the bigint object. The destructor was using "delete" + when it should have been using "delete []" + - Fixed a resource leak in the POSIX version hostname_to_ip() + - Fixed a significant memory leak in memory_manager_kernel_1 + - Fixed a memory leak that could occur in memory_manager_kernel_2 + and memory_manager_kernel_3 when the constructor for the object + being constructed threw an exception. + - Added a missing delete statement to plug a memory leak + in the md5 computation code. + - Fixed an uninitialized variable warning from valgrind + (in lz77_buffer/lz77_buffer_kernel_2.h). I think this could + also potentially result in an error when decoding data but I'm not totally + sure. But either way it is fixed now. + - Changed a call to memcpy to memmove in the sockstreambuf_kernel_2 + implementation since the copy could potentially be of overlapped memory. + +Other: + - Changed the connection::read() and connection::write() functions to take + longs instead of unsigned longs as the buffer sizes. They also now + return longs instead of ints. This should be a backwards compatible change. + - Started using the valgrind tool to check for memory errors in the project and + found a few things. Very nice tool :) + + + + + + +New Features: + - Added the multithreaded_object extension to the threads API + - Added the load_dng() and save_dng() functions which can load and store + the DNG lossless compressed image format (which I just made up). + +Non-Backwards Compatible Changes: + - Changed the serialization format for bool to use a 1 byte code rather than 2 + bytes. So this breaks compatibility with the old format. + +Bug fixes: + - The serialization for bool didn't always work right. This is now fixed. + +Other: + + + + + + +New Features: + - New faster version of the bigint object (bigint_kernel_2) that uses + the Fast Fourier Transform to perform multiplications. + - The base_window can now be an "undecorated" window. This new type is suitable + for making things like popup menus and the like. + - Added the on_focus_lost() event to the base_window object + - Added the on_focus_gained() event to the base_window object + - Added the on_window_moved() event to the base_window object + - Added the get_pos() function to the base_window object + - Updated the gui_widgets's drawable interface stuff to support the three + new event types and the new window type. + - Added the drawable::draw_rectangle() function + - Added serialization support for std::complex. + - Added the assign_image() function + +Non-Backwards Compatible Changes: + - Removed the color arguments from the drawable_window object's constructor and + added a new boolean argument (if it is an undecorated window or not). This + probably won't break any code but if it does you should get a compiler error + about it. + - Made it so you must disable the events in the destructor for your + drawable gui widgets. Doing so avoids potential race conditions when + destructing drawable objects. + - Made it so that you are required to call close_window() in a window object's + destructor. This avoids a potential race condition. + +Bug fixes: + - Added a workaround for a bug in MinGW that caused the regression test suite + to crash when compiled with -O3. + - Fixed a potential bug in the X Windows version of the gui_core component. + Added an extra XFlush() to end_program() because without it a + program can crash when calling end_program() in certain instances. + - The spec for the pipe object said that objects you enqueue into it + had an "initial value for their type" after the function completes. This + is incorrect, they are swapped into the pipe so they have an undefined + value afterwards. I fixed the spec for the pipe to say this. + - Fixed a bug in the font rendering functions in the gui_widgets + component. It could cause a segmentation fault sometimes. + - Fixed some potential deadlocks in the windows version of the gui_core + component. + - Fixed a bug in the rsignaler object. When you called wait() or + wait_or_timeout() it only unlocked the associated rmutex once (it could be + locked more than once and thus might cause a deadlock since the thread + calling wait() wouldn't actually unlock the mutex in this case). + - Fixed the initialize_array() function in memory_manger_kernel_3 to be + exception safe. Previously if an exception occurred while creating + an array then a resource leak was created. + +Other: + - Changed the package format for the library somewhat. The examples are now + located in their own top level folder. Additionally, the HTML version of the + documentation also comes in the same archive as the source rather than in a + separate download. + - Started using major and minor version numbers rather than just major ones. + + + + + + +New Features: + - Added operator<< and operator>> iostream operators to the vector object. + +Non-Backwards Compatible Changes: + - Changed the xml_parser's document_handler interface: + made empty element tags (<like_this/>) trigger the end_element() callback + and removed the is_empty bool from start_element(). + +Bug fixes: + - Fixed a potential race condition between the destruction of the thread pool + and the "program ending handlers" stuff. + +Other: + - Made the xml parser more robust to different types of new line characters. + - Modified the source slightly so that it works with mingw. + + + + + + +New Features: + - The config_reader is now enumerable. + - Added the image_widget gui object. + - Added nr() and nc() to the array2d object. + - Added the shutdown_connection() function to the iostream extension + to the server object. + - Added the timer_kernel_2 implementation which is a version of the timer object + that is more efficient in its allocation of threads. + - Added the timeout object. + - There is now a CMakeLists.txt file located in the dlib folder. See + dlib/examples/CMakeLists.txt and dlib/test/CMakeLists.txt for examples + that use CMake to build projects using this library. + - Added the register_program_ending_handler() function to the threading API. + +Non-Backwards Compatible Changes: + - Removed the config_reader::get_blocks() function. Use the + new enumerable interface for the config_reader instead. + - The array2d object now uses longs instead of unsigned longs to report + its dimensions and access its elements. + - Added a uint64 to the on_connect() callback in the iostream + extension to the server object. + - timer::set_delay_time() now throws and timer::start() now may throw + std::bad_alloc. + +Bug fixes: + - Fixed a bug in end_program(). In X Windows it might not cause the + program to end right away if it was called from outside the event + handling thread. + - Fixed a bug in the implementation of the timeout part of the + close_gracefully() function. + +Other: + - The library now works on HP-UX + - The regression test suite now has command line arguments that + enable tests to send debug messages to a file. + + + + + + +New Features: + - The http server extension now supports the POST HTTP method. + - The attribute list object in the xml_parser is now enumerable. + - Added the threaded object extension + - Added the uintn.h file which defines fixed sized unsigned integral types. + +Non-Backwards Compatible Changes: + - Renamed the on_get() callback in the http extension to the server object to + on_request() + - Removed the network byte order functions from the sockets api. (They are still + really there though since they come from actual OS header files. But + officially they have been replaced by the byte_orderer component). + - Renamed dlib/uint64.h to dlib/uintn.h + +Bug fixes: + +Other: + - The command line parser will now let you declare long named options with - + characters in them. + - Made it so you can use the COMPILE_TIME_ASSERT macros anywhere rather than + just inside functions. + + + + + + +New Features: + - For dlib::matrix + - Added the tmp() function + - Added optimized specializations of inv() and det() for 1x1, 2x2, 3x3 and + 4x4 matrices. + - Added the removerc() function + - Sockets related + - Added the connect() function + - Added the is_ip_address() function. + - Added the close_gracefully() function + - Added the iostream extension to the server object. + - Added the http extension to the server object. + +Non-Backwards Compatible Changes: + - Changed the cpp_tokenizer to not convert characters to their html form. + +Bug fixes: + - Removed some potential compile time errors. See the change log for details. + +Other: + - Improved the web site + - Added some more example code + - Added more colors to cpp_pretty_printer_kernel_1. + + + + + +New Features: + - std::map is now serializable + - Added the matrix object and a bunch of supporting code. + - Added the list_box graphical widget + - Added the fill_rect_with_vertical_gradient() function to the + drawable interfaces list of drawing helpers. + - Added the open_file_box() function which provides a simple file chooser. + +Non-Backwards Compatible Changes: + +Bug fixes: + - Made timestamper::get_timestamp() be a const function like it should. Fixes + some compile errors. + - Fixed a bug in the font::draw_string() function. It didn't redraw + multi-line strings right. + - Fixed a bug in the scroll_bar object that would cause a compile + error if you tried to call its width() function. + - Fixed a bug in the array_kernel_1 object. It would cause a segmentation fault + when used sometimes. + +Other: + + + + + +New Features: + - Added the following image transformation functions: + - Added the equalize_histogram() function + - Added the spatially_filter_image() function + - Added the threshold_image() function + - Added the binary_dilation() function + - Added the binary_erosion() function + - Added the binary_open() function + - Added the binary_close() function + - Added the binary_intersection() function + - Added the binary_difference() function + - Added the binary_complement() function + - Added the clear(), load_from() and default constructor back into the + config_reader. + - Made the member_function_pointer copyable and also added operator== and != + to it. + +Non-Backwards Compatible Changes: + - Made the vector object templated so you can use types other than double with it. + But now you + have to specify what type you want to use which is slightly different. + - The asc_pair_remover and asc_remover abstract classes now take a third template + argument. I highly doubt this effects any code outside the library but it is + possible. + +Bug fixes: + - Fixed a bug in the base_window::set_size() function. If you specified a size + of (0,0) it caused your program to error out. This has now been fixed. + - Fixed a bug in the scroll_bar widget. + - Fixed a bug in save_bmp(). For some image sizes it would output a goofy + looking skewed image. + +Other: + - Switched everything that used to call operator< directly to instead use + std::less or to take a template argument that provides a compare functor that + defaults to std::less. + + + + + +New Features: + - Added the assign_pixel() function + - Added the hsi pixel type + - Added the save_bmp() function + - Added the static_switch template + +Non-Backwards Compatible Changes: + - Changed how the config_reader works. It now has a more powerful syntax and + improved interface. Converting any old code to use the new version should be + simple since the new file syntax is very nearly backwards compatible with the + old syntax. (i.e. You probably won't have to change anything about your + config files) + - Renamed the dlib/image_loader.h file to dlib/image_io.h since it now includes + the image saver stuff. + - Renamed the pixel struct to rgb_pixel + - Renamed pixel_traits::color to pixel_traits::rgb + - Renamed pixel_traits::scalar to pixel_traits::grayscale + +Bug fixes: + - Fixed a bug in the load_bmp() function. It would load 24bit bmp files + incorrectly. + - Changed the logger so that it won't deadlock if you write something similar to + my_log << LINFO << function_that_also_logs();. Although this is a + dumb thing to do. But even so, it shouldn't deadlock. + - Fixed a potential linking problem with the vector object. + +Other: + - I decided I'm not going to support Borland v5.5.1 anymore. There are just too + many bugs in this compiler. It is very old at this point so I don't see this + being a big deal. + - Made the drawable::draw_image() and load_bmp() functions able to handle images + of any type of pixel. + - Pulled the imaging, algorithmic and metaprogramming stuff out of the + miscellaneous section of the web page and gave them all their own sections. + + + + + +New Features: + - Added a logger header that prints the date and time. + - Added the LTRACE logging level + - Added a buffered implementation of sockstreambuf. + +Non-Backwards Compatible Changes: + - Changed the specs to say that sockstreambuf may be buffered. + sockstreambuf_kernel_1 is still just as it always has been though. So all old + code will still work just as it always has. But all the same, the specs have + been changed and now allow for an implementation that is not 100% backwards + compatible. + - rand_kernel_2 now emits a different string of random numbers. + +Bug fixes: + - Changed the logger object's implementation to not try to register + a thread end handler for the main program thread anymore. This was + technically a bug according to the spec but it actually did end up + working the way it was supposed to. But even so, it shouldn't have + been doing that. + - Changed binary_search_tree_kernel_1 so that it avoids a bug in the version of + gcc on SuSE Enterprise Linux 9. + - Fixed a bug in the rand_kernel_2 implementation. It wasn't giving good + random numbers. + +Other: + - Modified the code so that you don't get any warnings when -Wall is used with + GCC. + + + + + +New Features: + - Added the ASSERT_ARE_SAME_TYPE macro + - Added the is_same_type template + - Added the get_main_thread_id() function to the threading API + - Added the thread_specific_data extension to the threading API + - Added the logger object. + - Added the auto_unlock object to the threading API. + +Non-Backwards Compatible Changes: + +Bug fixes: + +Other: + - Added an example that is specifically about using threads + - Added two examples about using the logger object + + + + + +New Features: + - Added the memory_manager_stateless object and two implementations of it. + - Added the MACOSX macro to dlib/platform.h + - Added a templated version of create_new_thread() that allow you to start + a thread with a member function. + - Added the register_thread_end_handler() function to the threading kernel API. + - Added memory_manager_kernel_3 + +Non-Backwards Compatible Changes: + - Changed the meaning of the memory_manager_global::get_number_of_allocations() + function because the previous definition of it didn't really make sense for + this object. + - Changed the threading API to wait for all outstanding threads to terminate + before allowing the program to end. It used to just let the OS trash those + threads when main() ended but this isn't a safe or portable thing to do. I + used to assume the user would make sure all their threads were done and had + cleaned up whatever they were doing before main() ended but this is too much + of a burden on the end user. So now the library will wait for your threads to + end. You still need to have some way of telling them it is time to stop though. + +Bug fixes: + - Fixed a minor bug in dlib/timer/timer_kernel_1.h. Its implementation was + slightly off according to the specification of a timer object. + +Other: + - The byte_order object is now capable of flipping entire arrays. + - Made it so that the ENABLE_ASSERTS macro is defined whenever ASSERT is + on. + - Made the array container use the memory managers. + + + + + + +New Features: + - Added functions to explicitly convert to/from little and big endian to the + byte_order object. + - Added the allocate_array() and deallocate_array() functions to the + memory_manager. + - Created the memory_manager_global object + - Added the remove_last_in_order(), position_enumerator() and + remove_current_element() functions to the binary_search_tree object. + +Non-Backwards Compatible Changes: + - I put an #error directive in the old GUI component to notify anyone + trying to use it that it is deprecated. I will be removing it from the + library in a few months. + - Switched the reference_counter object back to not using the memory_manager. + I realized it isn't safe for this object to use the memory_manager since + it could result in memory_managers freeing each other's allocations. + - I redefined the pixel_traits template. It is now a lot simpler and more + convenient. + +Bug fixes: + - Fixed a minor bug in dlib/rand/rand_kernel_2.cpp + +Other: + - Added some more compile time checks to the byte_orderer object. + - Changed some includes and preprocessor macros around a little so now + everything but the GUI stuff compiles in mac OS X. + - Added inclusion guards to all the .cpp files + - Added the all_gui.cpp and all_console.cpp files. They + include all the .cpp files you need to make gui and + console applications respectively into one file. + - Made more containers use the memory_manager. + + + + + + +New Features: + - Added the enqueue_or_timeout() and dequeue_or_timeout() functions + to the pipe object. + - Gave the mouse_tracker the ability to display the mouse position + relative to a user selected point. + - Added the message_box() function to the gui_widgets component. + - Gave the label widget the ability to draw newlines in strings. + - added the close_window() and is_closed() methods to the base_window + object. + - Added the rsignaler extension to the threading API. + - You can now control the thread pool timeout time by setting the + DLIB_THREAD_POOL_TIMEOUT #define. + - Added the get_from_clipboard() and put_on_clipboard() functions + to the gui_core component. + - Added the stop_and_wait() function to the timer object. + - Added the trigger_user_event() function and on_user_event() event + to the base_window object. This new event is also forwarded + to drawable interfaces inside the receiving window. + - Added the wrap_around() function to the named_rectangle widget. + - Added the top(), left(), right(), bottom(), width() and height() + functions to the drawable interface. + +Non-Backwards Compatible Changes: + - Made the radio_button and check_box widgets pass references to themselves + when they call their click handlers. + - Switched the sync_extension to use the rmutex and rsignaler objects + rather than the normal non-reentrant ones. ( Chances are that old + code that used this will still compile fine anyway. ) + - Changed the return type of rand::get_random_number() to be an + unsigned char. I also changed both the implementations of + rand because they weren't very good at all. + - Changed the functions related to drawing strings in the font class. + - Changed the drawable's rectangle to default to being empty + rather than being a single point. Most code should not notice + the difference. + +Bug fixes: + - The event handlers in gui_widgets/drawable.h were private. They + should be protected. This is now fixed. + - Fixed a bug in the way the scroll_bar was drawn when it was + the HORIZONTAL type. + - Changed how the thread pool destructs itself when the program + is ending. Previously it was possible to get an error on + NetBSD when the program was ending. This is now fixed. + - The functions related to setting the jump size in the scroll_bar + widget were private. They are now public. + - There was a bug in the MS Windows version of the gui_core component + where the members of the base_window would not work if called from + within the on_window_close() event. This has now been fixed. + - Made the set_pos() function work right for the mouse_tracker widget. + - Fixed a bug in the base64 object where the string "" could potentially + be decoded incorrectly. + - Made the global swap function for crc32_kernel_1 inline. This fixes a + potential linker error. + - Fixed some potential deadlocking that could occur while using the + gui widgets. + +Other: + - I moved all the regression tests into the dlib/test folder and + made a nice driver program to run all of them. + - I have been using the sourceforge compile farms to test the library + on various platforms. It now works for Solaris and some of the BSDs + in addition to Linux and Windows. + + + + + + + +New Features: + - Added the array_expand extension to the array object. + - Added the cmd_line_parser_check extension to the command line parser. + - Added the pipe object. + - All applicable container classes now use the memory_manager component for + their memory allocation. + - New implementations of the memory_manager object. + - Added the copy_functor class. + +Non-Backwards Compatible Changes: + - Moved the wrap_string, narrow, and string_cast functions + to a new file. You now have to include dlib/string.h to get + them. (This makes a bunch of other things work right in gcc 2.95) + - Renamed the _L macro to _dT + - Removed the scopy class + - Simplified the interface to the memory manager. It is basically the same + though. + - Removed the max_size() methods from the hash_table and binary_search_tree + objects. + - Removed the T_is_POD template arguments from the hash_table and + binary_search_tree objects. + - Simplified the template arguments to all checking components and extensions. + They now take the class they are to extend or check as their only template + argument. This only affects you if you haven't been using the kernel_nx + typedefs. + +Bug fixes: + +Other: + - I changed a few things around and now a majority of the library + again compiles under gcc 2.95. But some things don't and I currently + don't plan on making them work because it involves hackish workarounds + of 2.95's bugs. + - Changed the compress_stream_kernel_1 object so that it will detect data + corruptions better. This change will prevent it from correctly decompressing + data that was compressed with a previous version and has an uncompressed size + greater than about 20,000 bytes. + - There is a new cpp file you need to compile: dlib/error.cpp + - Moved all the regression testing stuff into the dlib/test folder and made + a nicer test driver to run them. + + + + + + +New Features: + - Created the byte_orderer object. + - Created the mouse_tracker gui widget. + - The sliding_buffer object is now enumerable and serializable. + - Added the get_filesystem_roots() function to the dir_nar component. + - Added the create_directory() function to the misc_api component. + +Non-Backwards Compatible Changes: + - The ASSERT macro is now only enabled if DEBUG or ENABLE_ASSERTS + is defined. + +Bug fixes: + - Fixed a minor bug in the cmd_line_parser object. If you gave + an option such as --option=arg when option didn't take any + arguments it could hang your program. + - Fixed a bug in wait_or_timeout() in the posix version of the threading + api. The time to wait for was being calculated incorrectly and could + result in an excessive number of spurious returns. + - Fixed a minor bug in the on_keydown() event for windows. + I had it set such that the shift and ctrl bools would be false + if they were the actual keys being pressed. This isn't what the + specs say should happen but I had a comment in the windows code + that made it clear that I did it on purpose. Go figure :) + This is now fixed. + +Other: + - Improved the cpp_tokenizer object's ability to recognize numerical + constants. + - Improved the text_field gui widget. + - There are now two assert macros. One called ASSERT + and another CASSERT. They both do the same thing but ASSERT + is only enabled when DEBUG or ENABLE_ASSERTS is defined. + All the old ASSERT statements were changed to CASSERT statements. + + + + + +New Features: + - Added array_kernel_2 which is a simple layer on top of a C array. + - Added the tabbed_display GUI widget + - Added the widget_group GUI widget + - Added the named_rectangle GUI widget + - Added the pixel_traits template + +Non-Backwards Compatible Changes: + - The default maximum size for an array object is now 0 rather than + 20,000. +Bug fixes: + +Other: + - made the cpp_pretty_printer a little better about how it handles + C style code. Also added support for /*!A html_anchor_name !*/ + style comments. + + + + + +New Features: + - Created the array2d object. + - Created the base64 object. + - Created the pixel struct. + - Created the load_bmp() function which can load a BMP image file + into an array2d object of pixels. + - Created the drawable::draw_image() function + +Non-Backwards Compatible Changes: + - In the drawable interface I made the z order a long rather + than unsigned long. + - The cpp_tokenizer object now has a NUMBER token type. + - removed the get_ prefix from functions in the cmd_line_parser + and cmd_line_parser_option objects. Also changed the + cmd_line_parser_option::operator[] function to a normal member + function called argument(). + +Bug fixes: + +Other: + - cpp_pretty_printer now colors numeric literals a shade of yellow. + + + + + +New Features: + - Created the member_function_pointer object. + - Created the button_action object. + - Created the arrow_button object. + - Created the check_box object. + - Created the radio_button object. + - Created the scroll_bar object. + - More drawing functions to draw various things + onto a canvas object. + - Added enable/disable functions to the + drawable interface. + +Non-Backwards Compatible Changes: + - The gui widgets are no longer templated at the + class level. + - The drawable object's constructor now takes a + bit set rather than a bunch of bools to tell it + which events to enable. + - I changed the names of some of the functions in the + gui_widgets component so that they all reflected a + uniform naming style. + +Bug fixes: + - Fixed a minor bug in the cpp_tokenizer. + - Minor bug in the timer object. See change log for + details. + +Other: + - Made the timer object a little more robust + + + + + + + + diff --git a/ml/dlib/docs/docs/right.gif b/ml/dlib/docs/docs/right.gif new file mode 100644 index 000000000..50e2637de Binary files /dev/null and b/ml/dlib/docs/docs/right.gif differ diff --git a/ml/dlib/docs/docs/stylesheet.xsl b/ml/dlib/docs/docs/stylesheet.xsl new file mode 100644 index 000000000..19338feaa --- /dev/null +++ b/ml/dlib/docs/docs/stylesheet.xsl @@ -0,0 +1,1201 @@ + + + + + + + + + + + + + true + main_menu.xml + dlib C++ Library + + + + + 02MiiaFNVzS5/u0eQhsy3/knioFHsia1X3DXRpHkE6I= + DGSSJMKDomaDaDTIRJ8jDkv0YMx9Cz7OESbXHjjr6Jw + + + + _LAST_MODIFIED_DATE_ + _CURRENT_RELEASE_ + + + + + + #E3E3E3 + + + + + + abcdefghijklmnopqrstuvwxyz + ABCDEFGHIJKLMNOPQRSTUVWXYZ + '?()<> /\&~!@#$%^*_+=-[]{} + + + + + + + + + + + + + + + + <xsl:value-of select="$project_name"/> + <xsl:if test="title"> + - <xsl:value-of select="title" /> + </xsl:if> + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    +
    + +
    + + + +

    +
      + + +
    • +
      +
    +
    +
    +
    + + + + + + + + + + + + + +
    + +
    + + +
    + + + +
    + + + + + + + + + + +
      + + + + + + + + + + + + + + + + +
    +
    +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
  • + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
  • +
    + + + +
      + + + + + + +
    +
    + + + + + + + + + + + +
    +
    [top]
    +

    + +
    + +
    +
    + + + + + + + + + + +
    + +
    [top]
    +

    +

    +
    +
    + + +
    + +
    + + + + More Details... + + + More Details... + + + + +
    #include <>
    +
    +
    + + + + + + + +
    +
    +

    Extensions to

    +
    + + + +
    +
    +
    + +
    +
    + + + + More Details... + + + More Details... + + + + + + + +
    +
    +
    + +
    +
    +
    + + + + + 1 + + + + + + 1 + + + + + + + + + +
    C++ Example Programs:
    + + + + + + + + + + , + + + + + +
    Python Example Programs:
    + + + + + + + + + + , + + + + +
    + + + +

    Implementations: + + + + +
    + : +
    + + + + +
    +
    +
    + +
    + + : + +
    +
    +
    +
    +
    +
    + + + + + +
    + + + + + + + + + + + + + + + + + + + + +
    +
    +
    + +
    +
    _c
    +
    + is a typedef for that checks its preconditions. +
    +
    +
    + +
    +
    +
    + +
    +
    + + + + +

    Release

    + Release date: +
    + Major Changes in this Release: + + + + + +
    + + + +
    +

    Release

    + + Release date: + +
    + Major Changes in this Release: + + + + + +
    +
    +
    +
    +
    +
    +
    Old Release Notes
    +
    + +
    + + + + + + +

    Release

    + + Release date: + +
    + Major Changes in this Release: + + + + + +
    + +
    +
    +
    +
    +
    + + + + + + + + + + + + + + + + + + + + +
    +
    +
    + + + + +

    + +

    +
    + + + +

    + +

    + + + +
    +

    + +

    + + + + + + + +

    + +

    +
    + +
    + +
    +
    + +
    +         
    +       
    +
    + +
    + +
    +
    + +
    + + +
    + +
    + + + + + + +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    +
    + + More Details... + + +
    +
    + +
  • + +
  • +
    + + +
      + +
    +
    + +
      + +
    +
    +
    + +
      + +
    +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {@alt} + + + + + + + + + + + + +
    +
    + + +

    + + + + + + + +
    +
    #include <>
    +
    + + +
    +
    #include <>
    +
    +
    + + +
    + + + + +
    +
    #include <>
    +
    +
    +
    +
    + + + +
    +
    + + + + + +
    +
    + + + +
    +
    + [A] + [B] + [C] + [D] + [E] + [F] + [G] + [H] + [I] + [J] + [K] + [L] + [M] + [N] + [O] + [P] + [Q] + [R] + [S] + [T] + [U] + [V] + [W] + [X] + [Y] + [Z] +
    +
    + + + +
    + + + + + + + + + + + + + + + + + + + + + + Jan + Feb + Mar + Apr + May + Jun + Jul + Aug + Sep + Oct + Nov + Dec + + + + + + + , + (:: UTC) + + + + + , + (:: UTC) + + + + + + + + + + + + + + + + + + + + + Revision:
    + Author:
    + Date:
    + + +
    +
    +
    + + + + + + + + + + + + + + +
    +
    + + + + +
    +
    + +
    +
    +
    +
    +
    +
    +
    + + + + + M + Modified + black + + + A + Added + blue + + + D + Deleted + red + + + R + Deleted + red + + + + +

    +
    + + + + + + + + + + +

    Classes and Structs:

    + + + + + +

    Global Functions:

    + + +
    + + () +
    +
    + + Scope:
    +
    + File:

    +
    +
    ;
    +
    +
    +
    +
    +
    +
    + +
    + + + + +
    + + +
    +
    + + Scope:
    +
    + File:

    +
    +
    ;

    +

    +
    + + + + Protected Typedefs +
    +
      + +
    • ;
    • +
      +
    +
    +
    +
    + + + + Public Typedefs +
    +
      + +
    • ;
    • +
      +
    +
    +
    +
    + + + + Protected Variables +
    +
      + +
    • ;
    • +
      +
    +
    +
    +
    + + + + Public Variables +
    +
      + +
    • ;
    • +
      +
    +
    +
    +
    + + + + Protected Methods +
    + +
    + Method Name:

    +
    +
    ;
    +

    +
    +
    +
    +
    +
    +
    + + + + Public Methods +
    + +
    + Method Name:

    +
    +
    ;
    +

    +
    +
    +
    +
    +
    +
    + + + + Protected Inner Classes +
    + + + +
    +
    +
    + + + + Public Inner Classes +
    + + + +
    +
    +
    + +
    +
    +
    + + + + +
    +
    + + + + + + + + + +
    diff --git a/ml/dlib/docs/docs/term_index.xml b/ml/dlib/docs/docs/term_index.xml new file mode 100644 index 000000000..f3a0f08ef --- /dev/null +++ b/ml/dlib/docs/docs/term_index.xml @@ -0,0 +1,1801 @@ + + + + + Index + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ml/dlib/docs/docs/tiled_pyramid_example.jpg b/ml/dlib/docs/docs/tiled_pyramid_example.jpg new file mode 100644 index 000000000..75f611ef2 Binary files /dev/null and b/ml/dlib/docs/docs/tiled_pyramid_example.jpg differ diff --git a/ml/dlib/docs/docs/vs-cmake-gui.png b/ml/dlib/docs/docs/vs-cmake-gui.png new file mode 100755 index 000000000..bd4308a97 Binary files /dev/null and b/ml/dlib/docs/docs/vs-cmake-gui.png differ diff --git a/ml/dlib/docs/docs/vs_mode_1.png b/ml/dlib/docs/docs/vs_mode_1.png new file mode 100755 index 000000000..ae13c5fcf Binary files /dev/null and b/ml/dlib/docs/docs/vs_mode_1.png differ diff --git a/ml/dlib/docs/docs/vs_mode_2.png b/ml/dlib/docs/docs/vs_mode_2.png new file mode 100755 index 000000000..d74ffca6e Binary files /dev/null and b/ml/dlib/docs/docs/vs_mode_2.png differ diff --git a/ml/dlib/docs/docs/vs_mode_3.png b/ml/dlib/docs/docs/vs_mode_3.png new file mode 100755 index 000000000..dc5cfe26e Binary files /dev/null and b/ml/dlib/docs/docs/vs_mode_3.png differ diff --git a/ml/dlib/docs/makedocs b/ml/dlib/docs/makedocs new file mode 100755 index 000000000..e38b75a73 --- /dev/null +++ b/ml/dlib/docs/makedocs @@ -0,0 +1,282 @@ +#!/bin/bash +. bash_helper_functions + +report_failure () +{ + echo " **** failed to complete **** " + exit 1 +} + +htmlify_python_file () +{ + pygmentize -f html -O full,style=vs $1 > $1.html +} + + +add_links_between_example_programs() +{ + EXT=$3 + # Get the list of example program filenames + pushd $1 > /dev/null + FILES=`ls *.$EXT` + popd > /dev/null + + # Now run sed on all the htmlified example programs to add the links between them. + for f in $FILES + do + #escape the . in the filename + escaped_name=`echo $f | sed -e 's/\./\\\./g'` + pushd $1 > /dev/null + # get a list of all the html example files that contain the name + matching_html_files=`grep -e "\b$escaped_name\b" -l *.$EXT | sed -e "s/\.$EXT\b/.$EXT.html/g"` + popd > /dev/null + + # now actually run sed to add the links + pushd $2 > /dev/null + if [ -n "$matching_html_files" ] + then + sed -i -e "s/\b$escaped_name\b/$escaped_name<\/a>/g" $matching_html_files + fi + popd > /dev/null + done + +} + +htmlify_cmake () +{ + echo "" > $1.html; + echo $1 >> $1.html; + echo "
    " >> $1.html;
    +
    +    #  line 1: make comments green
    +    #  line 2: add links into the add_subdirectory directives
    +    #  line 3: make literal quotes red
    +    #  line 4: make the directives show up blue
    +    #  line 5: make variable names show up purple
    +    sed -e "s/^\([ ]*#.*\)/\1<\/font>/" \
    +        -e "s/add_subdirectory\([ ]*\)(\([ ]*\)\([^ ]*\)\([ ]*\)\([^ )]*\)/add_subdirectory\1(\2\3\4\5<\/a>/"  \
    +        -e "s/\"\([^\"]*\)\"/\"\1<\/font>\"/g"  \
    +        -e "s/^\([ ]*[^( ]*[ ]*\)(/\1<\/font>(/" \
    +        -e "s/{\([^}]*\)}/\{\1<\/font>}/g"  \
    +        $1 >> $1.html;
    +
    +    echo "
    " >> $1.html; +} + +htmlify_python() +{ + FILES=`\ls $1/*.py` + for i in $FILES + do + htmlify_python_file ${i} + rm ${i} + done +} + + +makedocs () +{ + + REVNUM_FILE=.logger_revnum + + + + # figure out the short number that identifies this particular changeset + get_short_revision_number `cat $REVNUM_FILE` + LOGGER_REVNUM=$RESULT + + XSLT_OPTIONS="--nodtdattr --nonet --novalid" + DATE_TODAY=`date --date= "+%b %d, %Y"`; + + + + + # The revision number we are currently at + CHANGESET_ID=`hg id -i | sed -e 's/\+//'` + get_short_revision_number $CHANGESET_ID + REVISION=$RESULT + + + if [ "$1" = "makerel" ] + then + RELEASE=${MAJOR_NUM}.${MINOR_NUM} + else + RELEASE=${MAJOR_NUM}.${MINOR_NUM}.${PATCH_NUM} + fi; + + # get XML versions of the change logs + BASE_LOGGER_REVNUM=`echo $LOGGER_REVNUM - 1000 | bc` + NEXT_LOGGER_REVNUM=`echo $LOGGER_REVNUM + 1 | bc` + echo Getting the mercurial change logs for revisions $NEXT_LOGGER_REVNUM:$REVISION + hg log -v ../dlib ../examples ../tools ../python_examples --style=xml -r$NEXT_LOGGER_REVNUM:$REVISION > docs/log.txt || report_failure + echo Getting the mercurial change logs for revisions $BASE_LOGGER_REVNUM:$LOGGER_REVNUM + hg log -v ../dlib ../examples ../tools ../python_examples --style=xml -r$BASE_LOGGER_REVNUM:$LOGGER_REVNUM > docs/old_log.txt || report_failure + + # grab a clean copy of the repository + rm -rf docs/cache + rm -rf docs/web + rm -rf docs/chm/docs + hg archive docs/cache || report_failure + # Don't need the docs folder in the cache, moreover, deleting it here avoids letting the makerel script include it in the dlib tar balls. + rm -rf docs/cache/docs + + echo "#ifndef DLIB_REVISION_H" > docs/cache/dlib/revision.h + echo "// Version: " $RELEASE >> docs/cache/dlib/revision.h + echo "// Date: " `date` >> docs/cache/dlib/revision.h + echo "// Mercurial Revision ID: " $CHANGESET_ID >> docs/cache/dlib/revision.h + echo "#define DLIB_MAJOR_VERSION " $MAJOR_NUM >> docs/cache/dlib/revision.h + echo "#define DLIB_MINOR_VERSION " $MINOR_NUM >> docs/cache/dlib/revision.h + echo "#define DLIB_PATCH_VERSION " $PATCH_NUM >> docs/cache/dlib/revision.h + echo "#endif" >> docs/cache/dlib/revision.h + + + rm -rf docs/web + rm -rf docs/chm/docs + mkdir docs/web + mkdir docs/chm/docs + + echo Creating HTML version of the source + htmlify --title "dlib C++ Library - " -i docs/cache -o htmltemp.$$ + add_links_between_example_programs docs/cache/examples htmltemp.$$/examples cpp + + echo Copying files around... + cp -r htmltemp.$$/dlib docs/web + cp -r htmltemp.$$/dlib docs/chm/docs + cp -r htmltemp.$$/examples/* docs/web + cp -r htmltemp.$$/examples/* docs/chm/docs + rm -rf htmltemp.$$ + + # create python docs unless you say ./makedocs fast + if [ "$1" != "fast" ] + then + cd .. + python setup.py build || report_failure + python setup.py build_sphinx -c docs/docs/python --build-dir docs/sphinx.$$ || report_failure + cd docs + cp -r sphinx.$$/html docs/web/python + mv sphinx.$$/html docs/chm/docs/python + rm -rf sphinx.$$ + fi; + + + cp docs/cache/dlib/test/makefile docs/web/dlib/test + cp docs/cache/dlib/test/makefile docs/chm/docs/dlib/test + + cp docs/cache/dlib/test/CMakeLists.txt docs/web/dlib/test + cp docs/cache/dlib/test/CMakeLists.txt docs/chm/docs/dlib/test + cp docs/cache/dlib/CMakeLists.txt docs/web/dlib + cp docs/cache/dlib/CMakeLists.txt docs/chm/docs/dlib + mkdir docs/web/examples || report_failure + cp docs/cache/examples/CMakeLists.txt docs/web/examples + mkdir docs/chm/docs/examples || report_failure + cp docs/cache/examples/CMakeLists.txt docs/chm/docs/examples + cp docs/cache/python_examples/*.py docs/chm/docs/ + cp docs/cache/python_examples/*.py docs/web/ + + htmlify_python docs/chm/docs/ + htmlify_python docs/web/ + add_links_between_example_programs docs/cache/python_examples docs/chm/docs py + add_links_between_example_programs docs/cache/python_examples docs/web py + + cp docs/*.gif docs/web + cp docs/*.gif docs/chm/docs + cp docs/ml_guide.svg docs/web + cp docs/ml_guide.svg docs/chm/docs + cp -r docs/guipics docs/web + cp -r docs/guipics docs/chm/docs + cp docs/*.html docs/web + cp docs/*.html docs/chm/docs + cp docs/*.css docs/web + cp docs/*.css docs/chm/docs + cp docs/*.js docs/web + cp docs/*.js docs/chm/docs + cp docs/*.png docs/web + cp docs/*.jpg docs/web + cp docs/*.webm docs/web + cp docs/*.ico docs/web + cp docs/*.png docs/chm/docs + cp docs/*.jpg docs/chm/docs + cp docs/*.webm docs/chm/docs + cp docs/*.ico docs/chm/docs + + cd docs/chm/docs || report_failure + htmlify_cmake dlib/CMakeLists.txt; + htmlify_cmake examples/CMakeLists.txt; + htmlify_cmake dlib/test/CMakeLists.txt; + cd ../../.. || report_failure + cd docs/web || report_failure + htmlify_cmake dlib/CMakeLists.txt; + htmlify_cmake examples/CMakeLists.txt; + htmlify_cmake dlib/test/CMakeLists.txt; + cd ../.. || report_failure + + find docs/web docs/chm -name "CMakeLists.txt" | xargs rm + + + + # generate the HTML docs + echo Generate HTML docs from XML and XSLT style sheet + FILES=`\ls docs/*.xml | grep -v main_menu.xml` + for i in $FILES + do + + # The last modified date for these files should always be the release date (regardless of when the actual xml files were modified). + if [ "${i}" = "docs/release_notes.xml" -o ${i} = "docs/old_release_notes.xml" \ + -o ${i} = "docs/change_log.xml" -o ${i} = "docs/old_change_log.xml" \ + -o ${i} = "docs/index.xml" ] + then + DATE=$DATE_TODAY + else + get_last_modified_date ${i} + DATE=$RESULT + fi; + + #make web version + cat docs/stylesheet.xsl | sed -e 's/"is_chm">[^<]*/"is_chm">false/' -e "s/_CURRENT_RELEASE_/$RELEASE/" -e "s/_LAST_MODIFIED_DATE_/$DATE/" \ + > docs/stylesheet.$$.xsl + OUT_FILE=$(echo ${i} | sed -e "s/\.xml/\.html/" | sed -e "s/docs\//docs\/web\//") + xsltproc $XSLT_OPTIONS -o $OUT_FILE docs/stylesheet.$$.xsl ${i} + + #make chm version + cat docs/stylesheet.xsl | sed -e 's/"is_chm">[^<]*/"is_chm">true/' -e "s/_CURRENT_RELEASE_/$RELEASE/" -e "s/_LAST_MODIFIED_DATE_/$DATE/" \ + > docs/stylesheet.$$.xsl + OUT_FILE=$(echo ${i} | sed -e "s/\.xml/\.html/" | sed -e "s/docs\//docs\/chm\/docs\//") + xsltproc $XSLT_OPTIONS -o $OUT_FILE docs/stylesheet.$$.xsl ${i} + + rm docs/stylesheet.$$.xsl + done + +# Delete doc type header stuff +# FILES=`find docs/chm docs/web -iname "*.html" -type f` +# for i in $FILES +# do +# sed -e '/ temp.$$; +# mv temp.$$ ${i}; +# done + + + echo Generating sitemap + cd docs/web || report_failure + find . -name "*.html" | awk '{ print "http://dlib.net" substr($1,2)}' > sitemap.txt + + # make the main index have a 301 redirect. Use php to do this + echo '' > index.php + cat index.html >> index.php + rm index.html + + cd ../.. +} + + +./testenv || report_failure + + + + +# build all the html documentation +makedocs $1 + +# now make the table of contents for the chm file +echo Generating the table of contents for the chm file +xsltproc -o docs/chm/Table\ of\ Contents.hhc docs/chm/htmlhelp_stylesheet.xsl docs/chm/toc.xml + diff --git a/ml/dlib/docs/makerel b/ml/dlib/docs/makerel new file mode 100755 index 000000000..8a4d2d397 --- /dev/null +++ b/ml/dlib/docs/makerel @@ -0,0 +1,91 @@ +#!/bin/bash +. bash_helper_functions + +# If the first argument to this script is the word major then the +# major version number is updated and the minor is set back to 0. + +report_failure () +{ + echo " **** failed to complete **** " + exit 1 +} + + +./testenv_rel || report_failure + + + +REVNUM_FILE=.logger_revnum +CHANGESET_ID=`hg id -i | sed -e 's/\+//'` + +rm -rf release || report_failure +mkdir release || report_failure + + +if [ "$1" = "major" ] + then + MAJOR_NUM=`echo $MAJOR_NUM+1|bc` + MINOR_NUM=0 +else + MINOR_NUM=`echo $MINOR_NUM+1|bc` +fi; +set_dlib_version MAJOR $MAJOR_NUM +set_dlib_version MINOR $MINOR_NUM +set_dlib_version PATCH 0 + +RELEASE=${MAJOR_NUM}.${MINOR_NUM} +# Commit changes to the version numbers so that the makedocs script will use them. +echo Create Mercurial tags and commit release +hg commit -m "Created release v$RELEASE" || report_failure +hg tag v$RELEASE || report_failure + +./makedocs makerel || exit 1 + +echo $CHANGESET_ID > $REVNUM_FILE +set_dlib_version PATCH 99 +hg commit -m "Record last changeset and set PATCH version to 99" + + +cd release || report_failure +RELDIR=`echo dlib-$RELEASE` +mkdir $RELDIR +cd $RELDIR || report_failure +cp -r ../../docs/cache/* . || report_failure + +echo Version: $RELEASE >> README.md +echo "Date: `date`" >> README.md +echo Mercurial Revision ID: $CHANGESET_ID >> README.md + + + +WEBPAGE=`echo dlib_webpage-$RELEASE.tar` +SOURCE_ZIP=`echo $RELDIR.zip` +SOURCE_TAR=`echo $RELDIR.tar` +tar -C ../../docs/chm -cf - docs/ documentation.html | tar -xf - || report_failure +cd .. || report_failure + +tar -cf $SOURCE_TAR $RELDIR || report_failure +# flip everything to MS-DOS line endings +#find $RELDIR -name "*.cpp" -or -name "*.h" -or -name "*.txt" -or -name "*.html" -or -name "*.py" | xargs flip -m +find $RELDIR -name "*.cpp" -or -name "*.h" -or -name "*.txt" -or -name "*.html" -or -name "*.py" | xargs unix2dos &> /dev/null + +zip -r9 $SOURCE_ZIP $RELDIR > /dev/null || report_failure +tar -C ../docs -cf $WEBPAGE web || report_failure +bzip2 $SOURCE_TAR || report_failure +bzip2 $WEBPAGE || report_failure + +rm -rf $RELDIR + +# Don't make the chm doc file since hhc.exe doesn't run in any copy of wine anymore :( +#wine ../docs/chm/htmlhelp/hhc.exe ../docs/chm/lib.hhp +#mv ../docs/chm/help.chm dlib_documentation-$RELEASE.chm || report_failure + + +mkdir v$RELEASE +#mv dlib_documentation-$RELEASE.chm v$RELEASE +mv $SOURCE_TAR.bz2 v$RELEASE +mv $SOURCE_ZIP v$RELEASE + + + + diff --git a/ml/dlib/docs/testenv b/ml/dlib/docs/testenv new file mode 100755 index 000000000..8d2ce6bc9 --- /dev/null +++ b/ml/dlib/docs/testenv @@ -0,0 +1,31 @@ +#/bin/sh +# +#This script checks to make sure all the commands we need are +#present + +return_error() +{ + echo "Error, can't run the $1 command" + exit 1 +} + + +echo Testing environment for needed utilities + +bc -h > /dev/null || return_error "bc"; +echo nothing | awk '{}' > /dev/null || return_error "awk"; +echo | sed -e "s/s/r/" > /dev/null || return_error "sed"; +htmlify > /dev/null || return_error "htmlify"; +echo | xargs > /dev/null || return_error "xargs"; +hg > /dev/null || return_error "hg"; +xsltproc -V > /dev/null || return_error "xsltproc"; +tar --help > /dev/null || return_error "tar"; +zip -h > /dev/null || return_error "zip"; +bzip2 -h &> /dev/null || return_error "bzip2"; +pygmentize -h &> /dev/null || return_error "pygmentize"; +which sphinx-build &> /dev/null || return_error "sphinx-build"; + + +echo All needed utilities found +exit 0 + diff --git a/ml/dlib/docs/testenv_rel b/ml/dlib/docs/testenv_rel new file mode 100755 index 000000000..fff4a7be1 --- /dev/null +++ b/ml/dlib/docs/testenv_rel @@ -0,0 +1,24 @@ +#/bin/sh +# +#This script checks to make sure all the commands we need are +#present + +return_error() +{ + echo "Error, can't run the $1 command" + exit 1 +} + +./testenv + +echo Testing environment for needed release building utilities + + +#flip -h > /dev/null || return_error "flip"; +unix2dos -h &> /dev/null || return_error "unix2dos"; +#wine --help &> /dev/null || return_error "wine"; + + +echo All needed utilities found +exit 0 + diff --git a/ml/dlib/examples/3d_point_cloud_ex.cpp b/ml/dlib/examples/3d_point_cloud_ex.cpp new file mode 100644 index 000000000..f64a68976 --- /dev/null +++ b/ml/dlib/examples/3d_point_cloud_ex.cpp @@ -0,0 +1,50 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + + This is an example illustrating the use of the perspective_window tool + in the dlib C++ Library. It is a simple tool for displaying 3D point + clouds on the screen. + +*/ + +#include +#include +#include + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +int main() +{ + // Let's make a point cloud that looks like a 3D spiral. + std::vector points; + dlib::rand rnd; + for (double i = 0; i < 20; i+=0.001) + { + // Get a point on a spiral + dlib::vector val(sin(i),cos(i),i/4); + + // Now add some random noise to it + dlib::vector temp(rnd.get_random_gaussian(), + rnd.get_random_gaussian(), + rnd.get_random_gaussian()); + val += temp/20; + + // Pick a color based on how far we are along the spiral + rgb_pixel color = colormap_jet(i,0,20); + + // And add the point to the list of points we will display + points.push_back(perspective_window::overlay_dot(val, color)); + } + + // Now finally display the point cloud. + perspective_window win; + win.set_title("perspective_window 3D point cloud"); + win.add_overlay(points); + win.wait_until_closed(); +} + +// ---------------------------------------------------------------------------- + diff --git a/ml/dlib/examples/CMakeLists.txt b/ml/dlib/examples/CMakeLists.txt new file mode 100644 index 000000000..5c408d74f --- /dev/null +++ b/ml/dlib/examples/CMakeLists.txt @@ -0,0 +1,250 @@ +# +# _______ _ _ _____ _____ _____ _____ +# |__ __| | | |_ _|/ ____| |_ _|/ ____| /\ +# | | | |__| | | | | (___ | | | (___ / \ +# | | | __ | | | \___ \ | | \___ \ / /\ \ +# | | | | | |_| |_ ____) | _| |_ ____) | / ____ \ +# |_|__|_|_ |_|_____|_____/__ |_____|_____/ /_/ _ \_\ +# |__ __| | | |__ __/ __ \| __ \|_ _| /\ | | +# | | | | | | | | | | | | |__) | | | / \ | | +# | | | | | | | | | | | | _ / | | / /\ \ | | +# | | | |__| | | | | |__| | | \ \ _| |_ / ____ \| |____ +# |_| \____/ |_| \____/|_| \_\_____/_/ \_\______| +# +# +# _____ ______ _____ _______ _ _ ______ +# | __ \| ____| /\ | __ \ |__ __| | | | ____| +# | |__) | |__ / \ | | | | | | | |__| | |__ +# | _ /| __| / /\ \ | | | | | | | __ | __| +# | | \ \| |____ / ____ \| |__| | | | | | | | |____ +# |_|__\_\______/_/_ __\_\_____/__ _ |_|__|_|_ |_|______|_ _ _ +# / ____/ __ \| \/ | \/ | ____| \ | |__ __/ ____| | | | | | +# | | | | | | \ / | \ / | |__ | \| | | | | (___ | | | | | +# | | | | | | |\/| | |\/| | __| | . ` | | | \___ \ | | | | | +# | |___| |__| | | | | | | | |____| |\ | | | ____) | |_|_|_|_| +# \_____\____/|_| |_|_| |_|______|_| \_| |_| |_____/ (_|_|_|_) +# +# +# +# This is a CMake makefile. CMake is a tool that helps you build C++ programs. +# You can download CMake from http://www.cmake.org. This CMakeLists.txt file +# you are reading builds dlib's example programs. +# + + +cmake_minimum_required(VERSION 2.8.12) +# Every project needs a name. We call this the "examples" project. +project(examples) + + +# Tell cmake we will need dlib. This command will pull in dlib and compile it +# into your project. Note that you don't need to compile or install dlib. All +# cmake needs is the dlib source code folder and it will take care of everything. +add_subdirectory(../dlib dlib_build) + + +# The next thing we need to do is tell CMake about the code you want to +# compile. We do this with the add_executable() statement which takes the name +# of the output executable and then a list of .cpp files to compile. Here we +# are going to compile one of the dlib example programs which has only one .cpp +# file, assignment_learning_ex.cpp. If your program consisted of multiple .cpp +# files you would simply list them here in the add_executable() statement. +add_executable(assignment_learning_ex assignment_learning_ex.cpp) +# Finally, you need to tell CMake that this program, assignment_learning_ex, +# depends on dlib. You do that with this statement: +target_link_libraries(assignment_learning_ex dlib::dlib) + + + +# To compile this program all you need to do is ask cmake. You would type +# these commands from within the directory containing this CMakeLists.txt +# file: +# mkdir build +# cd build +# cmake .. +# cmake --build . --config Release +# +# The cmake .. command looks in the parent folder for a file named +# CMakeLists.txt, reads it, and sets up everything needed to build program. +# Also, note that CMake can generate Visual Studio or XCode project files. So +# if instead you had written: +# cd build +# cmake .. -G Xcode +# +# You would be able to open the resulting Xcode project and compile and edit +# the example programs within the Xcode IDE. CMake can generate a lot of +# different types of IDE projects. Run the cmake -h command to see a list of +# arguments to -G to see what kinds of projects cmake can generate for you. It +# probably includes your favorite IDE in the list. + + + + +################################################################################# +################################################################################# +# A CMakeLists.txt file can compile more than just one program. So below we +# tell it to compile the other dlib example programs using pretty much the +# same CMake commands we used above. +################################################################################# +################################################################################# + + +# Since there are a lot of examples I'm going to use a macro to simplify this +# CMakeLists.txt file. However, usually you will create only one executable in +# your cmake projects and use the syntax shown above. +macro(add_example name) + add_executable(${name} ${name}.cpp) + target_link_libraries(${name} dlib::dlib ) +endmacro() + +# if an example requires GUI, call this macro to check DLIB_NO_GUI_SUPPORT to include or exclude +macro(add_gui_example name) + if (DLIB_NO_GUI_SUPPORT) + message("No GUI support, so we won't build the ${name} example.") + else() + add_example(${name}) + endif() +endmacro() + +# The deep learning toolkit requires a compiler with essentially complete C++11 +# support. However, versions of Visual Studio prior to October 2016 didn't +# provide enough C++11 support to compile the DNN tooling, but were good enough +# to compile the rest of dlib. So new versions of Visual Studio 2015 will +# work. However, Visual Studio 2017 had some C++11 support regressions, so it +# wasn't until December 2017 that Visual Studio 2017 had good enough C++11 +# support to compile the DNN examples. So if you are using Visual Studio, make +# sure you have an updated version if you want to compile the DNN code. +# +# Also note that Visual Studio users should give the -T host=x64 option so that +# CMake will instruct Visual Studio to use its 64bit toolchain. If you don't +# do this then by default Visual Studio uses a 32bit toolchain, WHICH RESTRICTS +# THE COMPILER TO ONLY 2GB OF RAM, causing it to run out of RAM and crash when +# compiling some of the DNN examples. So generate your project with a statement +# like this: +# cmake .. -G "Visual Studio 14 2015 Win64" -T host=x64 +if (NOT USING_OLD_VISUAL_STUDIO_COMPILER) + add_example(dnn_metric_learning_ex) + add_gui_example(dnn_face_recognition_ex) + add_example(dnn_introduction_ex) + add_example(dnn_introduction2_ex) + add_example(dnn_inception_ex) + add_gui_example(dnn_mmod_ex) + add_gui_example(dnn_mmod_face_detection_ex) + add_gui_example(random_cropper_ex) + add_gui_example(dnn_mmod_dog_hipsterizer) + add_gui_example(dnn_imagenet_ex) + add_gui_example(dnn_mmod_find_cars_ex) + add_gui_example(dnn_mmod_find_cars2_ex) + add_example(dnn_mmod_train_find_cars_ex) + add_gui_example(dnn_semantic_segmentation_ex) + add_example(dnn_imagenet_train_ex) + add_example(dnn_semantic_segmentation_train_ex) + add_example(dnn_metric_learning_on_images_ex) +endif() + + +if (DLIB_NO_GUI_SUPPORT) + message("No GUI support, so we won't build the webcam_face_pose_ex example.") +else() + find_package(OpenCV QUIET) + if (OpenCV_FOUND) + include_directories(${OpenCV_INCLUDE_DIRS}) + + add_executable(webcam_face_pose_ex webcam_face_pose_ex.cpp) + target_link_libraries(webcam_face_pose_ex dlib::dlib ${OpenCV_LIBS} ) + else() + message("OpenCV not found, so we won't build the webcam_face_pose_ex example.") + endif() +endif() + + + +#here we apply our macros +add_gui_example(3d_point_cloud_ex) +add_example(bayes_net_ex) +add_example(bayes_net_from_disk_ex) +add_gui_example(bayes_net_gui_ex) +add_example(bridge_ex) +add_example(bsp_ex) +add_example(compress_stream_ex) +add_example(config_reader_ex) +add_example(custom_trainer_ex) +add_example(dir_nav_ex) +add_example(empirical_kernel_map_ex) +add_gui_example(face_detection_ex) +add_gui_example(face_landmark_detection_ex) +add_gui_example(fhog_ex) +add_gui_example(fhog_object_detector_ex) +add_example(file_to_code_ex) +add_example(graph_labeling_ex) +add_gui_example(gui_api_ex) +add_gui_example(hough_transform_ex) +add_gui_example(image_ex) +add_example(integrate_function_adapt_simp_ex) +add_example(iosockstream_ex) +add_example(kcentroid_ex) +add_example(kkmeans_ex) +add_example(krls_ex) +add_example(krls_filter_ex) +add_example(krr_classification_ex) +add_example(krr_regression_ex) +add_example(learning_to_track_ex) +add_example(least_squares_ex) +add_example(linear_manifold_regularizer_ex) +add_example(logger_custom_output_ex) +add_example(logger_ex) +add_example(logger_ex_2) +add_example(matrix_ex) +add_example(matrix_expressions_ex) +add_example(max_cost_assignment_ex) +add_example(member_function_pointer_ex) +add_example(mlp_ex) +add_example(model_selection_ex) +add_gui_example(mpc_ex) +add_example(multiclass_classification_ex) +add_example(multithreaded_object_ex) +add_gui_example(object_detector_advanced_ex) +add_gui_example(object_detector_ex) +add_gui_example(one_class_classifiers_ex) +add_example(optimization_ex) +add_example(parallel_for_ex) +add_example(pipe_ex) +add_example(pipe_ex_2) +add_example(quantum_computing_ex) +add_example(queue_ex) +add_example(rank_features_ex) +add_example(running_stats_ex) +add_example(rvm_ex) +add_example(rvm_regression_ex) +add_example(sequence_labeler_ex) +add_example(sequence_segmenter_ex) +add_example(server_http_ex) +add_example(server_iostream_ex) +add_example(sockets_ex) +add_example(sockstreambuf_ex) +add_example(std_allocator_ex) +add_gui_example(surf_ex) +add_example(svm_c_ex) +add_example(svm_ex) +add_example(svm_pegasos_ex) +add_example(svm_rank_ex) +add_example(svm_sparse_ex) +add_example(svm_struct_ex) +add_example(svr_ex) +add_example(thread_function_ex) +add_example(thread_pool_ex) +add_example(threaded_object_ex) +add_example(threads_ex) +add_example(timer_ex) +add_gui_example(train_object_detector) +add_example(train_shape_predictor_ex) +add_example(using_custom_kernels_ex) +add_gui_example(video_tracking_ex) +add_example(xml_parser_ex) + + +if (DLIB_LINK_WITH_SQLITE3) + add_example(sqlite_ex) +endif() + + diff --git a/ml/dlib/examples/LICENSE_FOR_EXAMPLE_PROGRAMS.txt b/ml/dlib/examples/LICENSE_FOR_EXAMPLE_PROGRAMS.txt new file mode 100644 index 000000000..c69b87af3 --- /dev/null +++ b/ml/dlib/examples/LICENSE_FOR_EXAMPLE_PROGRAMS.txt @@ -0,0 +1,22 @@ +The intent of the example programs supplied with the dlib C++ library is +to both instruct users and to also provide a simple body of code they +may copy and paste from. To make this as painless as possible all the +example programs have been placed into the public domain. + + +This work is hereby released into the Public Domain. +To view a copy of the public domain dedication, visit +http://creativecommons.org/licenses/publicdomain/ or send a +letter to + Creative Commons + 171 Second Street + Suite 300, + San Francisco, California, 94105, USA. + + + +Public domain dedications are not recognized by some countries. So +if you live in an area where the above dedication isn't valid then +you can consider the example programs to be licensed under the Boost +Software License. + diff --git a/ml/dlib/examples/assignment_learning_ex.cpp b/ml/dlib/examples/assignment_learning_ex.cpp new file mode 100644 index 000000000..7a3acd013 --- /dev/null +++ b/ml/dlib/examples/assignment_learning_ex.cpp @@ -0,0 +1,325 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + + This is an example illustrating the use of the dlib machine learning tools for + learning to solve the assignment problem. + + Many tasks in computer vision or natural language processing can be thought of + as assignment problems. For example, in a computer vision application where + you are trying to track objects moving around in video, you likely need to solve + an association problem every time you get a new video frame. That is, each new + frame will contain objects (e.g. people, cars, etc.) and you will want to + determine which of these objects are actually things you have seen in previous + frames. + + The assignment problem can be optimally solved using the well known Hungarian + algorithm. However, this algorithm requires the user to supply some function + which measures the "goodness" of an individual association. In many cases the + best way to measure this goodness isn't obvious and therefore machine learning + methods are used. + + The remainder of this example will show you how to learn a goodness function + which is optimal, in a certain sense, for use with the Hungarian algorithm. To + do this, we will make a simple dataset of example associations and use them to + train a supervised machine learning method. + + Finally, note that there is a whole example program dedicated to assignment + learning problems where you are trying to make an object tracker. So if that is + what you are interested in then take a look at the learning_to_track_ex.cpp + example program. +*/ + + +#include +#include + +using namespace std; +using namespace dlib; + + +// ---------------------------------------------------------------------------------------- + +/* + In an association problem, we will talk about the "Left Hand Set" (LHS) and the + "Right Hand Set" (RHS). The task will be to learn to map all elements of LHS to + unique elements of RHS. If an element of LHS can't be mapped to a unique element of + RHS for some reason (e.g. LHS is bigger than RHS) then it can also be mapped to the + special -1 output, indicating no mapping to RHS. + + So the first step is to define the type of elements in each of these sets. In the + code below we will use column vectors in both LHS and RHS. However, in general, + they can each contain any type you like. LHS can even contain a different type + than RHS. +*/ + +typedef dlib::matrix column_vector; + +// This type represents a pair of LHS and RHS. That is, sample_type::first +// contains a left hand set and sample_type::second contains a right hand set. +typedef std::pair, std::vector > sample_type; + +// This type will contain the association information between LHS and RHS. That is, +// it will determine which elements of LHS map to which elements of RHS. +typedef std::vector label_type; + +// In this example, all our LHS and RHS elements will be 3-dimensional vectors. +const unsigned long num_dims = 3; + +void make_data ( + std::vector& samples, + std::vector& labels +); +/*! + ensures + - This function creates a training dataset of 5 example associations. + - #samples.size() == 5 + - #labels.size() == 5 + - for all valid i: + - #samples[i].first == a left hand set + - #samples[i].second == a right hand set + - #labels[i] == a set of integers indicating how to map LHS to RHS. To be + precise: + - #samples[i].first.size() == #labels[i].size() + - for all valid j: + -1 <= #labels[i][j] < #samples[i].second.size() + (A value of -1 indicates that #samples[i].first[j] isn't associated with anything. + All other values indicate the associating element of #samples[i].second) + - All elements of #labels[i] which are not equal to -1 are unique. That is, + multiple elements of #samples[i].first can't associate to the same element + in #samples[i].second. +!*/ + +// ---------------------------------------------------------------------------------------- + +struct feature_extractor +{ + /*! + Recall that our task is to learn the "goodness of assignment" function for + use with the Hungarian algorithm. The dlib tools assume this function + can be written as: + match_score(l,r) == dot(w, PSI(l,r)) + bias + where l is an element of LHS, r is an element of RHS, w is a parameter vector, + bias is a scalar value, and PSI() is a user supplied feature extractor. + + This feature_extractor is where we implement PSI(). How you implement this + is highly problem dependent. + !*/ + + // The type of feature vector returned from get_features(). This must be either + // a dlib::matrix or a sparse vector. + typedef column_vector feature_vector_type; + + // The types of elements in the LHS and RHS sets + typedef column_vector lhs_element; + typedef column_vector rhs_element; + + + unsigned long num_features() const + { + // Return the dimensionality of feature vectors produced by get_features() + return num_dims; + } + + void get_features ( + const lhs_element& left, + const rhs_element& right, + feature_vector_type& feats + ) const + /*! + ensures + - #feats == PSI(left,right) + (i.e. This function computes a feature vector which, in some sense, + captures information useful for deciding if matching left to right + is "good"). + !*/ + { + // Let's just use the squared difference between each vector as our features. + // However, it should be emphasized that how to compute the features here is very + // problem dependent. + feats = squared(left - right); + } + +}; + +// We need to define serialize() and deserialize() for our feature extractor if we want +// to be able to serialize and deserialize our learned models. In this case the +// implementation is empty since our feature_extractor doesn't have any state. But you +// might define more complex feature extractors which have state that needs to be saved. +void serialize (const feature_extractor& , std::ostream& ) {} +void deserialize (feature_extractor& , std::istream& ) {} + +// ---------------------------------------------------------------------------------------- + +int main() +{ + try + { + // Get a small bit of training data. + std::vector samples; + std::vector labels; + make_data(samples, labels); + + + structural_assignment_trainer trainer; + // This is the common SVM C parameter. Larger values encourage the + // trainer to attempt to fit the data exactly but might overfit. + // In general, you determine this parameter by cross-validation. + trainer.set_c(10); + // This trainer can use multiple CPU cores to speed up the training. + // So set this to the number of available CPU cores. + trainer.set_num_threads(4); + + // Do the training and save the results in assigner. + assignment_function assigner = trainer.train(samples, labels); + + + // Test the assigner on our data. The output will indicate that it makes the + // correct associations on all samples. + cout << "Test the learned assignment function: " << endl; + for (unsigned long i = 0; i < samples.size(); ++i) + { + // Predict the assignments for the LHS and RHS in samples[i]. + std::vector predicted_assignments = assigner(samples[i]); + cout << "true labels: " << trans(mat(labels[i])); + cout << "predicted labels: " << trans(mat(predicted_assignments)) << endl; + } + + // We can also use this tool to compute the percentage of assignments predicted correctly. + cout << "training accuracy: " << test_assignment_function(assigner, samples, labels) << endl; + + + // Since testing on your training data is a really bad idea, we can also do 5-fold cross validation. + // Happily, this also indicates that all associations were made correctly. + randomize_samples(samples, labels); + cout << "cv accuracy: " << cross_validate_assignment_trainer(trainer, samples, labels, 5) << endl; + + + + // Finally, the assigner can be serialized to disk just like most dlib objects. + serialize("assigner.dat") << assigner; + + // recall from disk + deserialize("assigner.dat") >> assigner; + } + catch (std::exception& e) + { + cout << "EXCEPTION THROWN" << endl; + cout << e.what() << endl; + } +} + +// ---------------------------------------------------------------------------------------- + +void make_data ( + std::vector& samples, + std::vector& labels +) +{ + // Make four different vectors. We will use them to make example assignments. + column_vector A(num_dims), B(num_dims), C(num_dims), D(num_dims); + A = 1,0,0; + B = 0,1,0; + C = 0,0,1; + D = 0,1,1; + + std::vector lhs; + std::vector rhs; + label_type mapping; + + // In all the assignments to follow, we will only say an element of the LHS + // matches an element of the RHS if the two are equal. So A matches with A, + // B with B, etc. But never A with C, for example. + // ------------------------ + + lhs.resize(3); + lhs[0] = A; + lhs[1] = B; + lhs[2] = C; + + rhs.resize(3); + rhs[0] = B; + rhs[1] = A; + rhs[2] = C; + + mapping.resize(3); + mapping[0] = 1; // lhs[0] matches rhs[1] + mapping[1] = 0; // lhs[1] matches rhs[0] + mapping[2] = 2; // lhs[2] matches rhs[2] + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + + // ------------------------ + + lhs[0] = C; + lhs[1] = A; + lhs[2] = B; + + rhs[0] = A; + rhs[1] = B; + rhs[2] = D; + + mapping[0] = -1; // The -1 indicates that lhs[0] doesn't match anything in rhs. + mapping[1] = 0; // lhs[1] matches rhs[0] + mapping[2] = 1; // lhs[2] matches rhs[1] + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + + // ------------------------ + + lhs[0] = A; + lhs[1] = B; + lhs[2] = C; + + rhs.resize(4); + rhs[0] = C; + rhs[1] = B; + rhs[2] = A; + rhs[3] = D; + + mapping[0] = 2; + mapping[1] = 1; + mapping[2] = 0; + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + + // ------------------------ + + lhs.resize(2); + lhs[0] = B; + lhs[1] = C; + + rhs.resize(3); + rhs[0] = C; + rhs[1] = A; + rhs[2] = D; + + mapping.resize(2); + mapping[0] = -1; + mapping[1] = 0; + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + + // ------------------------ + + lhs.resize(3); + lhs[0] = D; + lhs[1] = B; + lhs[2] = C; + + // rhs will be empty. So none of the items in lhs can match anything. + rhs.resize(0); + + mapping.resize(3); + mapping[0] = -1; + mapping[1] = -1; + mapping[2] = -1; + + samples.push_back(make_pair(lhs,rhs)); + labels.push_back(mapping); + +} + diff --git a/ml/dlib/examples/bayes_net_ex.cpp b/ml/dlib/examples/bayes_net_ex.cpp new file mode 100644 index 000000000..64f2ad957 --- /dev/null +++ b/ml/dlib/examples/bayes_net_ex.cpp @@ -0,0 +1,307 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + This is an example illustrating the use of the Bayesian Network + inference utilities found in the dlib C++ library. + + + In this example all the nodes in the Bayesian network are + boolean variables. That is, they take on either the value + 0 or the value 1. + + The network contains 4 nodes and looks as follows: + + B C + \\ // + \/ \/ + A + || + \/ + D + + + The probabilities of each node are summarized below. (The probability + of each node being 0 is not listed since it is just P(X=0) = 1-p(X=1) ) + + p(B=1) = 0.01 + + p(C=1) = 0.001 + + p(A=1 | B=0, C=0) = 0.01 + p(A=1 | B=0, C=1) = 0.5 + p(A=1 | B=1, C=0) = 0.9 + p(A=1 | B=1, C=1) = 0.99 + + p(D=1 | A=0) = 0.2 + p(D=1 | A=1) = 0.5 + +*/ + + +#include +#include +#include +#include +#include + + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +int main() +{ + try + { + // There are many useful convenience functions in this namespace. They all + // perform simple access or modify operations on the nodes of a bayesian network. + // You don't have to use them but they are convenient and they also will check for + // various errors in your bayesian network when your application is built with + // the DEBUG or ENABLE_ASSERTS preprocessor definitions defined. So their use + // is recommended. In fact, most of the global functions used in this example + // program are from this namespace. + using namespace bayes_node_utils; + + // This statement declares a bayesian network called bn. Note that a bayesian network + // in the dlib world is just a directed_graph object that contains a special kind + // of node called a bayes_node. + directed_graph::kernel_1a_c bn; + + // Use an enum to make some more readable names for our nodes. + enum nodes + { + A = 0, + B = 1, + C = 2, + D = 3 + }; + + // The next few blocks of code setup our bayesian network. + + // The first thing we do is tell the bn object how many nodes it has + // and also add the three edges. Again, we are using the network + // shown in ASCII art at the top of this file. + bn.set_number_of_nodes(4); + bn.add_edge(A, D); + bn.add_edge(B, A); + bn.add_edge(C, A); + + + // Now we inform all the nodes in the network that they are binary + // nodes. That is, they only have two possible values. + set_node_num_values(bn, A, 2); + set_node_num_values(bn, B, 2); + set_node_num_values(bn, C, 2); + set_node_num_values(bn, D, 2); + + assignment parent_state; + // Now we will enter all the conditional probability information for each node. + // Each node's conditional probability is dependent on the state of its parents. + // To specify this state we need to use the assignment object. This assignment + // object allows us to specify the state of each nodes parents. + + + // Here we specify that p(B=1) = 0.01 + // parent_state is empty in this case since B is a root node. + set_node_probability(bn, B, 1, parent_state, 0.01); + // Here we specify that p(B=0) = 1-0.01 + set_node_probability(bn, B, 0, parent_state, 1-0.01); + + + // Here we specify that p(C=1) = 0.001 + // parent_state is empty in this case since B is a root node. + set_node_probability(bn, C, 1, parent_state, 0.001); + // Here we specify that p(C=0) = 1-0.001 + set_node_probability(bn, C, 0, parent_state, 1-0.001); + + + // This is our first node that has parents. So we set the parent_state + // object to reflect that A has both B and C as parents. + parent_state.add(B, 1); + parent_state.add(C, 1); + // Here we specify that p(A=1 | B=1, C=1) = 0.99 + set_node_probability(bn, A, 1, parent_state, 0.99); + // Here we specify that p(A=0 | B=1, C=1) = 1-0.99 + set_node_probability(bn, A, 0, parent_state, 1-0.99); + + // Here we use the [] notation because B and C have already + // been added into parent state. + parent_state[B] = 1; + parent_state[C] = 0; + // Here we specify that p(A=1 | B=1, C=0) = 0.9 + set_node_probability(bn, A, 1, parent_state, 0.9); + set_node_probability(bn, A, 0, parent_state, 1-0.9); + + parent_state[B] = 0; + parent_state[C] = 1; + // Here we specify that p(A=1 | B=0, C=1) = 0.5 + set_node_probability(bn, A, 1, parent_state, 0.5); + set_node_probability(bn, A, 0, parent_state, 1-0.5); + + parent_state[B] = 0; + parent_state[C] = 0; + // Here we specify that p(A=1 | B=0, C=0) = 0.01 + set_node_probability(bn, A, 1, parent_state, 0.01); + set_node_probability(bn, A, 0, parent_state, 1-0.01); + + + // Here we set probabilities for node D. + // First we clear out parent state so that it doesn't have any of + // the assignments for the B and C nodes used above. + parent_state.clear(); + parent_state.add(A,1); + // Here we specify that p(D=1 | A=1) = 0.5 + set_node_probability(bn, D, 1, parent_state, 0.5); + set_node_probability(bn, D, 0, parent_state, 1-0.5); + + parent_state[A] = 0; + // Here we specify that p(D=1 | A=0) = 0.2 + set_node_probability(bn, D, 1, parent_state, 0.2); + set_node_probability(bn, D, 0, parent_state, 1-0.2); + + + + // We have now finished setting up our bayesian network. So let's compute some + // probability values. The first thing we will do is compute the prior probability + // of each node in the network. To do this we will use the join tree algorithm which + // is an algorithm for performing exact inference in a bayesian network. + + // First we need to create an undirected graph which contains set objects at each node and + // edge. This long declaration does the trick. + typedef dlib::set::compare_1b_c set_type; + typedef graph::kernel_1a_c join_tree_type; + join_tree_type join_tree; + + // Now we need to populate the join_tree with data from our bayesian network. The next + // function calls do this. Explaining exactly what they do is outside the scope of this + // example. Just think of them as filling join_tree with information that is useful + // later on for dealing with our bayesian network. + create_moral_graph(bn, join_tree); + create_join_tree(join_tree, join_tree); + + // Now that we have a proper join_tree we can use it to obtain a solution to our + // bayesian network. Doing this is as simple as declaring an instance of + // the bayesian_network_join_tree object as follows: + bayesian_network_join_tree solution(bn, join_tree); + + + // now print out the probabilities for each node + cout << "Using the join tree algorithm:\n"; + cout << "p(A=1) = " << solution.probability(A)(1) << endl; + cout << "p(A=0) = " << solution.probability(A)(0) << endl; + cout << "p(B=1) = " << solution.probability(B)(1) << endl; + cout << "p(B=0) = " << solution.probability(B)(0) << endl; + cout << "p(C=1) = " << solution.probability(C)(1) << endl; + cout << "p(C=0) = " << solution.probability(C)(0) << endl; + cout << "p(D=1) = " << solution.probability(D)(1) << endl; + cout << "p(D=0) = " << solution.probability(D)(0) << endl; + cout << "\n\n\n"; + + + // Now to make things more interesting let's say that we have discovered that the C + // node really has a value of 1. That is to say, we now have evidence that + // C is 1. We can represent this in the network using the following two function + // calls. + set_node_value(bn, C, 1); + set_node_as_evidence(bn, C); + + // Now we want to compute the probabilities of all the nodes in the network again + // given that we now know that C is 1. We can do this as follows: + bayesian_network_join_tree solution_with_evidence(bn, join_tree); + + // now print out the probabilities for each node + cout << "Using the join tree algorithm:\n"; + cout << "p(A=1 | C=1) = " << solution_with_evidence.probability(A)(1) << endl; + cout << "p(A=0 | C=1) = " << solution_with_evidence.probability(A)(0) << endl; + cout << "p(B=1 | C=1) = " << solution_with_evidence.probability(B)(1) << endl; + cout << "p(B=0 | C=1) = " << solution_with_evidence.probability(B)(0) << endl; + cout << "p(C=1 | C=1) = " << solution_with_evidence.probability(C)(1) << endl; + cout << "p(C=0 | C=1) = " << solution_with_evidence.probability(C)(0) << endl; + cout << "p(D=1 | C=1) = " << solution_with_evidence.probability(D)(1) << endl; + cout << "p(D=0 | C=1) = " << solution_with_evidence.probability(D)(0) << endl; + cout << "\n\n\n"; + + // Note that when we made our solution_with_evidence object we reused our join_tree object. + // This saves us the time it takes to calculate the join_tree object from scratch. But + // it is important to note that we can only reuse the join_tree object if we haven't changed + // the structure of our bayesian network. That is, if we have added or removed nodes or + // edges from our bayesian network then we must recompute our join_tree. But in this example + // all we did was change the value of a bayes_node object (we made node C be evidence) + // so we are ok. + + + + + + // Next this example will show you how to use the bayesian_network_gibbs_sampler object + // to perform approximate inference in a bayesian network. This is an algorithm + // that doesn't give you an exact solution but it may be necessary to use in some + // instances. For example, the join tree algorithm used above, while fast in many + // instances, has exponential runtime in some cases. Moreover, inference in bayesian + // networks is NP-Hard for general networks so sometimes the best you can do is + // find an approximation. + // However, it should be noted that the gibbs sampler does not compute the correct + // probabilities if the network contains a deterministic node. That is, if any + // of the conditional probability tables in the bayesian network have a probability + // of 1.0 for something the gibbs sampler should not be used. + + + // This Gibbs sampler algorithm works by randomly sampling possibles values of the + // network. So to use it we should set the network to some initial state. + + set_node_value(bn, A, 0); + set_node_value(bn, B, 0); + set_node_value(bn, D, 0); + + // We will leave the C node with a value of 1 and keep it as an evidence node. + + + // First create an instance of the gibbs sampler object + bayesian_network_gibbs_sampler sampler; + + + // To use this algorithm all we do is go into a loop for a certain number of times + // and each time through we sample the bayesian network. Then we count how + // many times a node has a certain state. Then the probability of that node + // having that state is just its count/total times through the loop. + + // The following code illustrates the general procedure. + unsigned long A_count = 0; + unsigned long B_count = 0; + unsigned long C_count = 0; + unsigned long D_count = 0; + + // The more times you let the loop run the more accurate the result will be. Here we loop + // 2000 times. + const long rounds = 2000; + for (long i = 0; i < rounds; ++i) + { + sampler.sample_graph(bn); + + if (node_value(bn, A) == 1) + ++A_count; + if (node_value(bn, B) == 1) + ++B_count; + if (node_value(bn, C) == 1) + ++C_count; + if (node_value(bn, D) == 1) + ++D_count; + } + + cout << "Using the approximate Gibbs Sampler algorithm:\n"; + cout << "p(A=1 | C=1) = " << (double)A_count/(double)rounds << endl; + cout << "p(B=1 | C=1) = " << (double)B_count/(double)rounds << endl; + cout << "p(C=1 | C=1) = " << (double)C_count/(double)rounds << endl; + cout << "p(D=1 | C=1) = " << (double)D_count/(double)rounds << endl; + } + catch (std::exception& e) + { + cout << "exception thrown: " << endl; + cout << e.what() << endl; + cout << "hit enter to terminate" << endl; + cin.get(); + } +} + + + diff --git a/ml/dlib/examples/bayes_net_from_disk_ex.cpp b/ml/dlib/examples/bayes_net_from_disk_ex.cpp new file mode 100644 index 000000000..eaab5881a --- /dev/null +++ b/ml/dlib/examples/bayes_net_from_disk_ex.cpp @@ -0,0 +1,83 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + This is an example illustrating the use of the Bayesian Network + inference utilities found in the dlib C++ library. In this example + we load a saved Bayesian Network from disk. +*/ + + +#include +#include +#include +#include +#include +#include + + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +int main(int argc, char** argv) +{ + try + { + // This statement declares a bayesian network called bn. Note that a bayesian network + // in the dlib world is just a directed_graph object that contains a special kind + // of node called a bayes_node. + directed_graph::kernel_1a_c bn; + + if (argc != 2) + { + cout << "You must supply a file name on the command line. The file should " + << "contain a serialized Bayesian Network" << endl; + return 1; + } + + ifstream fin(argv[1],ios::binary); + + // Note that the saved networks produced by the bayes_net_gui_ex.cpp example can be deserialized + // into a network. So you can make your networks using that GUI if you like. + cout << "Loading the network from disk..." << endl; + deserialize(bn, fin); + + cout << "Number of nodes in the network: " << bn.number_of_nodes() << endl; + + // Let's compute some probability values using the loaded network using the join tree (aka. Junction + // Tree) algorithm. + + // First we need to create an undirected graph which contains set objects at each node and + // edge. This long declaration does the trick. + typedef graph::compare_1b_c, dlib::set::compare_1b_c>::kernel_1a_c join_tree_type; + join_tree_type join_tree; + + // Now we need to populate the join_tree with data from our bayesian network. The next two + // function calls do this. Explaining exactly what they do is outside the scope of this + // example. Just think of them as filling join_tree with information that is useful + // later on for dealing with our bayesian network. + create_moral_graph(bn, join_tree); + create_join_tree(join_tree, join_tree); + + // Now we have a proper join_tree we can use it to obtain a solution to our + // bayesian network. Doing this is as simple as declaring an instance of + // the bayesian_network_join_tree object as follows: + bayesian_network_join_tree solution(bn, join_tree); + + + // now print out the probabilities for each node + cout << "Using the join tree algorithm:\n"; + for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) + { + // print out the probability distribution for node i. + cout << "p(node " << i <<") = " << solution.probability(i); + } + } + catch (exception& e) + { + cout << "exception thrown: " << e.what() << endl; + return 1; + } +} + + diff --git a/ml/dlib/examples/bayes_net_gui_ex.cpp b/ml/dlib/examples/bayes_net_gui_ex.cpp new file mode 100644 index 000000000..81101912c --- /dev/null +++ b/ml/dlib/examples/bayes_net_gui_ex.cpp @@ -0,0 +1,989 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + This is a rather involved example illustrating the use of the GUI api from + the dlib C++ Library. This program is a fully functional utility for + creating Bayesian Networks. It allows the user to graphically draw the network, + save/load the network to/from disk, and also to calculate the posterior + probability of any node in the network given a set of evidence. + + This is not the first dlib example program you should be looking at. If you + want to see a simpler GUI example please look at the gui_api_ex.cpp or + image_ex.cpp example. + + If you want to understand how to use the Bayesian Network utilities in the library + you should definitely look at the bayes_net_ex.cpp example program. It gives a + comprehensive introduction to creating and manipulating Bayesian Networks. If you + want to see how to load a saved network from disk and use it in a non-GUI application + then look at the bayes_net_from_disk_ex.cpp example. + + + Now all of that being said, if you have already looked at the other relevant + examples and want to see a more in-depth example then by all means, continue reading. :) +*/ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + + +using namespace std; +using namespace dlib; +using namespace dlib::bayes_node_utils; + +// ---------------------------------------------------------------------------- + +typedef directed_graph::kernel_1a_c directed_graph_type; +typedef directed_graph::kernel_1a_c::node_type node_type; +typedef graph::compare_1b_c, dlib::set::compare_1b_c>::kernel_1a_c join_tree_type; + +// ---------------------------------------------------------------------------- + +class main_window : public drawable_window +{ + /*! + INITIAL VALUE + This window starts out hidden and with an empty Bayesian Network + + WHAT THIS OBJECT REPRESENTS + This object is the main window of a utility for drawing Bayesian Networks. + It allows you to draw a directed graph and to set the conditional probability + tables up for each node in the network. It also allows you to compute the + posterior probability of each node. And finally, it lets you save and load + networks from file + !*/ +public: + main_window(); + ~main_window(); + +private: + + // Private helper methods + + void initialize_node_cpt_if_necessary ( unsigned long index ); + void load_selected_node_tables_into_cpt_grid (); + void load_selected_node_tables_into_ppt_grid (); + void no_node_selected (); + + + // Event handlers + + void on_cpt_grid_modified(unsigned long row, unsigned long col); + void on_evidence_toggled (); + void on_graph_modified (); + void on_menu_file_open (); + void on_menu_file_quit (); + void on_menu_file_save (); + void on_menu_file_save_as (); + void on_menu_help_about (); + void on_menu_help_help (); + void on_node_deleted (); + void on_node_deselected ( unsigned long n ); + void on_node_selected (unsigned long n); + void on_open_file_selected ( const std::string& file_name); + void on_save_file_selected ( const std::string& file_name); + void on_sel_node_evidence_modified (); + void on_sel_node_num_values_modified (); + void on_sel_node_text_modified (); + void on_window_resized (); + void recalculate_probabilities (); + + // Member data + + const rgb_pixel color_non_evidence; + const rgb_pixel color_default_bg; + const rgb_pixel color_evidence; + const rgb_pixel color_error; + const rgb_pixel color_gray; + bool graph_modified_since_last_recalc; + + button btn_calculate; + check_box sel_node_is_evidence; + directed_graph_drawer graph_drawer; + label sel_node_index; + label sel_node_num_values_label; + label sel_node_text_label; + label sel_node_evidence_label; + menu_bar mbar; + named_rectangle selected_node_rect; + tabbed_display tables; + text_field sel_node_num_values; + text_field sel_node_text; + text_field sel_node_evidence; + text_grid cpt_grid; + text_grid ppt_grid; + unsigned long selected_node_index; + bool node_is_selected; + widget_group cpt_group; + widget_group ppt_group; + + std::unique_ptr solution; + join_tree_type join_tree; + // The std_vector_c is an object identical to the std::vector except that it checks + // all its preconditions and throws a dlib::fatal_error if they are violated. + std_vector_c cpt_grid_assignments; + std::string graph_file_name; +}; + +// ---------------------------------------------------------------------------------------- + +int main() +{ + // create our window + main_window my_window; + + // tell our window to put itself on the screen + my_window.show(); + + // wait until the user closes this window before we let the program + // terminate. + my_window.wait_until_closed(); +} + +// ---------------------------------------------------------------------------------------- + +#ifdef WIN32 +// If you use main() as your entry point when building a program on MS Windows then +// there will be a black console window associated with your application. If you +// want your application to not have this console window then you need to build +// using the WinMain() entry point as shown below and also set your compiler to +// produce a "Windows" project instead of a "Console" project. In visual studio +// this can be accomplished by going to project->properties->general configuration-> +// Linker->System->SubSystem and selecting Windows instead of Console. +// +int WINAPI WinMain (HINSTANCE, HINSTANCE, PSTR cmds, int) +{ + main(); + return 0; +} +#endif + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Methods from the main_window object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +main_window:: +main_window( +) : + color_non_evidence(0,0,0), + color_default_bg(255,255,255), + color_evidence(100,200,100), + color_error(255,0,0), + color_gray(210,210,210), + graph_modified_since_last_recalc(true), + btn_calculate(*this), + sel_node_is_evidence(*this), + graph_drawer(*this), + sel_node_index(*this), + sel_node_num_values_label (*this), + sel_node_text_label(*this), + sel_node_evidence_label(*this), + mbar(*this), + selected_node_rect(*this), + tables(*this), + sel_node_num_values(*this), + sel_node_text(*this), + sel_node_evidence(*this), + cpt_grid(*this), + ppt_grid(*this), + selected_node_index(0), + node_is_selected(false), + cpt_group(*this), + ppt_group(*this) +{ + // Note that all the GUI widgets take a reference to the window that contains them + // as their constructor argument. This is a universal feature of GUI widgets in the + // dlib library. + + set_title("Bayesian Network Utility"); + + // position the widget that is responsible for drawing the directed graph, the graph_drawer, + // just below the mbar (menu bar) widget. + graph_drawer.set_pos(5,mbar.bottom()+5); + set_size(750,400); + + // register the event handlers with their respective widgets + btn_calculate.set_click_handler (*this, &main_window::recalculate_probabilities); + cpt_grid.set_text_modified_handler (*this, &main_window::on_cpt_grid_modified); + graph_drawer.set_graph_modified_handler (*this, &main_window::on_graph_modified); + graph_drawer.set_node_deleted_handler (*this, &main_window::on_node_deleted); + graph_drawer.set_node_deselected_handler (*this, &main_window::on_node_deselected); + graph_drawer.set_node_selected_handler (*this, &main_window::on_node_selected); + sel_node_evidence.set_text_modified_handler (*this, &main_window::on_sel_node_evidence_modified); + sel_node_is_evidence.set_click_handler (*this, &main_window::on_evidence_toggled); + sel_node_num_values.set_text_modified_handler(*this, &main_window::on_sel_node_num_values_modified); + sel_node_text.set_text_modified_handler (*this, &main_window::on_sel_node_text_modified); + + // now set the text of some of our buttons and labels + btn_calculate.set_name("Recalculate posterior probability table"); + selected_node_rect.set_name("Selected node"); + sel_node_evidence_label.set_text("evidence value:"); + sel_node_is_evidence.set_name("is evidence"); + sel_node_num_values_label.set_text("Number of values: "); + sel_node_text_label.set_text("Node label:"); + + // Now setup the tabbed display. It will have two tabs, one for the conditional + // probability table and one for the posterior probability table. + tables.set_number_of_tabs(2); + tables.set_tab_name(0,"Conditional probability table"); + tables.set_tab_name(1,"Posterior probability table"); + cpt_group.add(cpt_grid,0,0); + ppt_group.add(ppt_grid,0,0); + tables.set_tab_group(0,cpt_group); + tables.set_tab_group(1,ppt_group); + + // Now setup the menu bar. We will have two menus. A File and Help menu. + mbar.set_number_of_menus(2); + mbar.set_menu_name(0,"File",'F'); + mbar.set_menu_name(1,"Help",'H'); + + // add the entries to the File menu. + mbar.menu(0).add_menu_item(menu_item_text("Open", *this, &main_window::on_menu_file_open, 'O')); + mbar.menu(0).add_menu_item(menu_item_separator()); + mbar.menu(0).add_menu_item(menu_item_text("Save", *this, &main_window::on_menu_file_save, 'S')); + mbar.menu(0).add_menu_item(menu_item_text("Save As",*this, &main_window::on_menu_file_save_as, 'a')); + mbar.menu(0).add_menu_item(menu_item_separator()); + mbar.menu(0).add_menu_item(menu_item_text("Quit", *this, &main_window::on_menu_file_quit, 'Q')); + + // Add the entries to the Help menu. + mbar.menu(1).add_menu_item(menu_item_text("Help", *this, &main_window::on_menu_help_help, 'e')); + mbar.menu(1).add_menu_item(menu_item_text("About", *this, &main_window::on_menu_help_about, 'A')); + + + // call our helper functions and window resize event to get the widgets + // to all arrange themselves correctly in our window. + no_node_selected(); + on_window_resized(); +} + +// ---------------------------------------------------------------------------------------- + +main_window:: +~main_window( +) +{ + // You should always call close_window() in the destructor of window + // objects to ensure that no events will be sent to this window while + // it is being destructed. + close_window(); +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Private methods from the main_window object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +void main_window:: +load_selected_node_tables_into_ppt_grid ( +) +{ + // This function just takes the currently selected graph node and loads + // its posterior probabilities into the ppt_graph widget. + node_type& node = graph_drawer.graph_node(selected_node_index); + ppt_grid.set_grid_size(2,node.data.table().num_values()); + + // load the top row of the table into the grid. This row is the "title bar" row + // that tells you what each column contains. + for (unsigned long col = 0; col < node.data.table().num_values(); ++col) + { + ppt_grid.set_text(0,col,"P(node=" + cast_to_string(col) + ")"); + ppt_grid.set_background_color(0,col,rgb_pixel(150,150,250)); + ppt_grid.set_editable(0,col,false); + } + + // If we have a solution to the network on hand then load the probabilities + // from that into the table + if (solution) + { + // get the probability distribution for the currently selected node out + // of the solution. + const matrix prob = solution->probability(selected_node_index); + + // now load the probabilities into the ppt_grid so the user can see them. + for (unsigned long col = 0; col < node.data.table().num_values(); ++col) + { + ppt_grid.set_text(1,col,cast_to_string(prob(col))); + } + } + + // make the second row of the table non-editable have a color that indicates + // that to the user + for (unsigned long col = 0; col < node.data.table().num_values(); ++col) + { + ppt_grid.set_background_color(1,col,color_gray); + ppt_grid.set_editable(1,col,false); + } +} + +// ---------------------------------------------------------------------------------------- + +void main_window:: +load_selected_node_tables_into_cpt_grid ( +) +{ + // This function just takes the conditional probability table in the + // currently selected graph node and puts it into the cpt_grid widget. + + node_type& node = graph_drawer.graph_node(selected_node_index); + + initialize_node_cpt_if_necessary(selected_node_index); + cpt_grid_assignments.clear(); + + // figure out how many rows there should be in the cpt + unsigned long cpt_rows = 1; + for (unsigned long i = 0; i < node.number_of_parents(); ++i) + { + cpt_rows *= node.parent(i).data.table().num_values(); + } + + unsigned long cpt_cols = node.data.table().num_values(); + + cpt_grid.set_grid_size(cpt_rows+1, cpt_cols+ node.number_of_parents()); + const unsigned long num_cols = cpt_grid.number_of_columns(); + + // fill in the top row of the grid that shows which parent node the left hand columns go with + assignment a(node_first_parent_assignment(graph_drawer.graph(),selected_node_index)); + unsigned long col = 0; + a.reset(); + while (a.move_next()) + { + cpt_grid.set_text(0,col,cast_to_string(a.element().key()) + ": " + graph_drawer.node_label(a.element().key()) ); + cpt_grid.set_background_color(0,col,rgb_pixel(120,210,210)); + cpt_grid.set_editable(0,col,false); + ++col; + } + + // fill in the top row of the grid that shows which probability the right hand columns go with + for (col = node.number_of_parents(); col < num_cols; ++col) + { + cpt_grid.set_text(0,col,"P(node=" + cast_to_string(col-node.number_of_parents()) + ")"); + cpt_grid.set_background_color(0,col,rgb_pixel(150,150,250)); + cpt_grid.set_editable(0,col,false); + } + + // now loop over all the possible parent assignments for this node + const unsigned long num_values = node.data.table().num_values(); + unsigned long row = 1; + do + { + col = 0; + + // fill in the left side of the grid row that shows what the parent assignment is + a.reset(); + while (a.move_next()) + { + cpt_grid.set_text(row,col,cast_to_string(a.element().value())); + cpt_grid.set_background_color(row,col,rgb_pixel(180,255,255)); + cpt_grid.set_editable(row,col,false); + + ++col; + } + + // fill in the right side of the grid row that shows what the conditional probabilities are + for (unsigned long value = 0; value < num_values; ++value) + { + const double prob = node.data.table().probability(value,a); + cpt_grid.set_text(row,col,cast_to_string(prob)); + ++col; + } + + // save this assignment so we can use it later to modify the node's + // conditional probability table if the user modifies the cpt_grid + cpt_grid_assignments.push_back(a); + ++row; + } while (node_next_parent_assignment(graph_drawer.graph(),selected_node_index,a)); + +} + +// ---------------------------------------------------------------------------------------- + +void main_window:: +initialize_node_cpt_if_necessary ( + unsigned long index +) +{ + node_type& node = graph_drawer.graph_node(index); + + // if the cpt for this node isn't properly filled out then let's clear it out + // and populate it with some reasonable default values + if (node_cpt_filled_out(graph_drawer.graph(), index) == false) + { + node.data.table().empty_table(); + + const unsigned long num_values = node.data.table().num_values(); + + // loop over all the possible parent assignments for this node and fill them out + // with reasonable default values + assignment a(node_first_parent_assignment(graph_drawer.graph(), index)); + do + { + // set the first value to have probability 1 + node.data.table().set_probability(0, a, 1.0); + + // set all the other values to have probability 0 + for (unsigned long value = 1; value < num_values; ++value) + node.data.table().set_probability(value, a, 0); + + } while (node_next_parent_assignment(graph_drawer.graph(), index,a)); + } +} + +// ---------------------------------------------------------------------------------------- + +void main_window:: +no_node_selected ( +) +{ + // Make it so that no node is selected on the gui. Do this by disabling things + // and clearing out text fields and so forth. + + + node_is_selected = false; + tables.disable(); + sel_node_evidence.disable(); + sel_node_is_evidence.disable(); + sel_node_index.disable(); + sel_node_evidence_label.disable(); + sel_node_text_label.disable(); + sel_node_text.disable(); + sel_node_index.set_text("index:"); + sel_node_num_values_label.disable(); + sel_node_num_values.disable(); + cpt_grid.set_grid_size(0,0); + ppt_grid.set_grid_size(0,0); + + sel_node_is_evidence.set_unchecked(); + sel_node_text.set_text(""); + sel_node_num_values.set_text(""); + sel_node_evidence.set_text(""); + sel_node_num_values.set_background_color(color_default_bg); + sel_node_evidence.set_background_color(color_default_bg); +} + +// ---------------------------------------------------------------------------------------- + +void main_window:: +recalculate_probabilities ( +) +{ + // clear out the current solution + solution.reset(); + if (graph_is_connected(graph_drawer.graph()) == false) + { + message_box("Error","Your graph has nodes that are completely disconnected from the other nodes.\n" + "You must connect them somehow"); + } + else if (graph_drawer.graph().number_of_nodes() > 0) + { + if (graph_modified_since_last_recalc) + { + // make sure all the cpts are filled out + const unsigned long num_nodes = graph_drawer.graph().number_of_nodes(); + for (unsigned long i = 0; i < num_nodes; ++i) + { + initialize_node_cpt_if_necessary(i); + } + + // remake the join tree for this graph + create_moral_graph(graph_drawer.graph(), join_tree); + create_join_tree(join_tree, join_tree); + graph_modified_since_last_recalc = false; + } + + // create a solution to this bayesian network using the join tree algorithm + solution.reset(new bayesian_network_join_tree(graph_drawer.graph(), join_tree)); + + if (node_is_selected) + { + load_selected_node_tables_into_ppt_grid(); + } + } +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Event handling methods from the main_window object +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// This event is called when the user selects a file with a saved +// bayesian network in it. +void main_window:: +on_open_file_selected ( + const std::string& file_name +) +{ + try + { + no_node_selected(); + ifstream fin(file_name.c_str(), ios::binary); + graph_drawer.load_graph(fin); + graph_file_name = file_name; + set_title("Bayesian Network Utility - " + right_substr(file_name,"\\/")); + } + catch (...) + { + message_box("Error", "Unable to load graph file " + file_name); + } +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user selects from the menu bar File->Open +void main_window:: +on_menu_file_open ( +) +{ + // display a file chooser window and when the user choses a file + // call the on_open_file_selected() function + open_existing_file_box(*this, &main_window::on_open_file_selected); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user selects from the menu bar File->Save +void main_window:: +on_menu_file_save ( +) +{ + // if we don't currently have any file name associated with our graph + if (graph_file_name.size() == 0) + { + // display a file chooser window and when the user choses a file + // call the on_save_file_selected() function + save_file_box(*this, &main_window::on_save_file_selected); + } + else + { + // we know what file to open so just do that and save the graph to it + ofstream fout(graph_file_name.c_str(), ios::binary); + graph_drawer.save_graph(fout); + } +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user choses which file to save the graph to +void main_window:: +on_save_file_selected ( + const std::string& file_name +) +{ + ofstream fout(file_name.c_str(), ios::binary); + graph_drawer.save_graph(fout); + graph_file_name = file_name; + set_title("Bayesian Network Utility - " + right_substr(file_name,"\\/")); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user selects from the menu bar File->Save As +void main_window:: +on_menu_file_save_as ( +) +{ + // display a file chooser window and when the user choses a file + // call the on_save_file_selected() function + save_file_box(*this, &main_window::on_save_file_selected); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user selects from the menu bar File->Quit +void main_window:: +on_menu_file_quit ( +) +{ + close_window(); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user selects from the menu bar Help->Help +void main_window:: +on_menu_help_help ( +) +{ + message_box("Help", + "To create new nodes right click on the drawing area.\n" + "To create edges select the parent node and then shift+left click on the child node.\n" + "To remove nodes or edges select them by left clicking and then press the delete key."); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user selects from the menu bar Help->About +void main_window:: +on_menu_help_about ( +) +{ + message_box("About","This application is the GUI front end to the dlib C++ Library's\n" + "Bayesian Network inference utilities\n\n" + "Version 1.2\n\n" + "See http://dlib.net for updates"); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user modifies the graph_drawer widget. That is, +// when the user adds or removes an edge or node in the graph. +void main_window:: +on_graph_modified ( +) +{ + // make note of the modification + graph_modified_since_last_recalc = true; + // clear out the solution object since we will need to recalculate it + // since the graph changed + solution.reset(); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user modifies the evidence value for a node +void main_window:: +on_sel_node_evidence_modified ( +) +{ + // make a reference to the node in the graph that is currently selected + node_type& node = graph_drawer.graph_node(selected_node_index); + unsigned long value; + try + { + // get the numerical value of the new evidence value. Here we are taking + // the string from the text field and casting it to an unsigned long. + value = sa = trim(sel_node_evidence.text()); + } + catch (string_cast_error&) + { + // if the user put something that isn't an integer into the + // text field then make it have a different background color + // so that they can easily see this. + sel_node_evidence.set_background_color(color_error); + return; + } + + // validate the input from the user and store it in the selected node + // if it is ok + if (value >= node.data.table().num_values()) + { + sel_node_evidence.set_background_color(color_error); + } + else + { + node.data.set_value(value); + sel_node_evidence.set_background_color(color_default_bg); + } + + // clear out the solution to the graph since we now need + // to recalculate it. + solution.reset(); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user modifies the number of evidence values for +// a node. +void main_window:: +on_sel_node_num_values_modified ( +) +{ + // make a reference to the node in the graph that is currently selected + node_type& node = graph_drawer.graph_node(selected_node_index); + + unsigned long num_values; + try + { + // get the number of values out of the text field. + num_values = sa = trim(sel_node_num_values.text()); + } + catch (string_cast_error&) + { + sel_node_num_values.set_background_color(color_error); + return; + } + + // validate the input from the user to make sure it is something reasonable + if (num_values < 2 || num_values > 100) + { + sel_node_num_values.set_background_color(color_error); + } + else + { + // update the graph + node.data.table().set_num_values(num_values); + graph_modified_since_last_recalc = true; + sel_node_num_values.set_background_color(color_default_bg); + + on_sel_node_evidence_modified(); + // also make sure the evidence value of this node makes sense still + if (node.data.is_evidence() && node.data.value() >= num_values) + { + // just set it to zero + node.data.set_value(0); + } + + } + + solution.reset(); + + // call these functions so that the conditional and posterior probability + // tables get updated + load_selected_node_tables_into_cpt_grid(); + load_selected_node_tables_into_ppt_grid(); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user modifies the cpt_grid (i.e. the conditional +// probability table widget) +void main_window:: +on_cpt_grid_modified(unsigned long row, unsigned long col) +{ + node_type& node = graph_drawer.graph_node(selected_node_index); + solution.reset(); + + double prob; + try + { + // get the new value out of the table + prob = sa = cpt_grid.text(row,col); + } + catch (string_cast_error&) + { + cpt_grid.set_background_color(row,col,color_error); + return; + } + + // validate the value + if (prob < 0 || prob > 1) + { + cpt_grid.set_background_color(row,col,color_error); + return; + } + + // the value of this node that is having its conditional probability + // updated + const unsigned long cur_val = col-node.number_of_parents(); + + node.data.table().set_probability(cur_val, cpt_grid_assignments[row-1], prob); + + // sum the probabilities in the cpt and modify the last one such that they all + // sum to 1. We are excluding either the first or last element from the sum + // because we are going to set it equal to 1-sum below. + double sum = 0; + if (cur_val != node.data.table().num_values()-1) + { + for (unsigned long i = 0; i < node.data.table().num_values()-1; ++i) + sum += node.data.table().probability(i, cpt_grid_assignments[row-1]); + } + else + { + for (unsigned long i = 1; i < node.data.table().num_values(); ++i) + sum += node.data.table().probability(i, cpt_grid_assignments[row-1]); + } + + // make sure all the probabilities sum to 1 + if (sum > 1.0) + { + cpt_grid.set_background_color(row,cpt_grid.number_of_columns()-1,color_error); + } + else + { + // edit one of the other elements in the table to ensure that the probabilities still sum to 1 + if (cur_val == node.data.table().num_values()-1) + { + node.data.table().set_probability(0, cpt_grid_assignments[row-1], 1-sum); + cpt_grid.set_text(row,node.number_of_parents(),cast_to_string(1-sum)); + } + else + { + node.data.table().set_probability(node.data.table().num_values()-1, cpt_grid_assignments[row-1], 1-sum); + cpt_grid.set_text(row,cpt_grid.number_of_columns()-1,cast_to_string(1-sum)); + } + + cpt_grid.set_background_color(row,cpt_grid.number_of_columns()-1,color_default_bg); + cpt_grid.set_background_color(row,col,color_default_bg); + } + +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user resizes the main_window. Note that unlike the other +// events, this event is part of the drawable_window base class that main_window inherits from. +// So you won't see any statements in the constructor that say "register the main_window::on_window_resized function" +void main_window:: +on_window_resized () +{ + // when you override any of the drawable_window events you have to make sure you + // call the drawable_window's version of them because it needs to process + // the events as well. So we do that here. + drawable_window::on_window_resized(); + + // The rest of this function positions the widgets on the window + unsigned long width,height; + get_size(width,height); + + // Don't do anything if the user just made the window too small. That is, leave + // the widgets where they are. + if (width < 500 || height < 350) + return; + + // Set the size of the probability tables and the drawing area for the graph + graph_drawer.set_size(width-370,height-10-mbar.height()); + cpt_grid.set_size((width-graph_drawer.width())-35,height-237); + ppt_grid.set_size((width-graph_drawer.width())-35,height-237); + // tell the tabbed display to make itself just the right size to contain + // the two probability tables. + tables.fit_to_contents(); + + + // Now position all the widgets in the window. Note that much of the positioning + // is relative to other widgets. This part of the code I just figured out by + // trying stuff and rerunning the program to see if it looked nice. + sel_node_index.set_pos(graph_drawer.right()+14,graph_drawer.top()+18); + sel_node_text_label.set_pos(sel_node_index.left(),sel_node_index.bottom()+5); + sel_node_text.set_pos(sel_node_text_label.right()+5,sel_node_index.bottom()); + sel_node_num_values_label.set_pos(sel_node_index.left(), sel_node_text.bottom()+5); + sel_node_num_values.set_pos(sel_node_num_values_label.right(), sel_node_text.bottom()+5); + sel_node_is_evidence.set_pos(sel_node_index.left(),sel_node_num_values.bottom()+5); + sel_node_evidence_label.set_pos(sel_node_index.left(),sel_node_is_evidence.bottom()+5); + sel_node_evidence.set_pos(sel_node_evidence_label.right()+5,sel_node_is_evidence.bottom()); + tables.set_pos(sel_node_index.left(),sel_node_evidence.bottom()+5); + sel_node_evidence.set_width(tables.right()-sel_node_evidence.left()+1); + sel_node_text.set_width(tables.right()-sel_node_text.left()+1); + sel_node_num_values.set_width(tables.right()-sel_node_num_values.left()+1); + + + + // Tell the named rectangle to position itself such that it fits around the + // tabbed display that contains the probability tables and the label at the top of the + // screen. + selected_node_rect.wrap_around(sel_node_index.get_rect()+ + tables.get_rect()); + + // finally set the button to be at the bottom of the named rectangle + btn_calculate.set_pos(selected_node_rect.left(), selected_node_rect.bottom()+5); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called by the graph_drawer widget when the user selects a node +void main_window:: +on_node_selected (unsigned long n) +{ + // make a reference to the selected node + node_type& node = graph_drawer.graph_node(n); + + + // enable all the widgets related to the selected node + selected_node_index = n; + node_is_selected = true; + tables.enable(); + sel_node_is_evidence.enable(); + sel_node_index.enable(); + sel_node_evidence_label.enable(); + sel_node_text_label.enable(); + sel_node_text.enable(); + sel_node_num_values_label.enable(); + sel_node_num_values.enable(); + + // make sure the num_values field of the node's cpt is set to something valid. + // So default it to 2 if it isn't set already. + if (node.data.table().num_values() < 2) + { + node.data.table().set_num_values(2); + graph_modified_since_last_recalc = true; + } + + // setup the evidence check box and input field + sel_node_index.set_text("index: " + cast_to_string(n)); + if (graph_drawer.graph_node(n).data.is_evidence()) + { + sel_node_is_evidence.set_checked(); + sel_node_evidence.enable(); + sel_node_evidence.set_text(cast_to_string(graph_drawer.graph_node(n).data.value())); + } + else + { + sel_node_is_evidence.set_unchecked(); + sel_node_evidence.disable(); + sel_node_evidence.set_text(""); + } + + sel_node_num_values.set_text(cast_to_string(node_num_values(graph_drawer.graph(),n))); + + sel_node_text.set_text(graph_drawer.node_label(n)); + + load_selected_node_tables_into_cpt_grid(); + load_selected_node_tables_into_ppt_grid(); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user toggles the "is evidence" check box +void main_window:: +on_evidence_toggled ( +) +{ + if (sel_node_is_evidence.is_checked()) + { + graph_drawer.graph_node(selected_node_index).data.set_as_evidence(); + sel_node_evidence.enable(); + sel_node_evidence.set_text(cast_to_string(graph_drawer.graph_node(selected_node_index).data.value())); + + graph_drawer.set_node_color(selected_node_index, color_evidence); + } + else + { + graph_drawer.graph_node(selected_node_index).data.set_as_nonevidence(); + sel_node_evidence.disable(); + sel_node_evidence.set_text(""); + sel_node_evidence.set_background_color(color_default_bg); + graph_drawer.set_node_color(selected_node_index, color_non_evidence); + } + solution.reset(); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user causes no node to be selected +void main_window:: +on_node_deselected ( unsigned long ) +{ + no_node_selected(); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user causes a node to be deleted +void main_window:: +on_node_deleted ( ) +{ + no_node_selected(); +} + +// ---------------------------------------------------------------------------------------- + +// This event is called when the user changes the text in the "node label" text field +void main_window:: +on_sel_node_text_modified ( +) +{ + // set the selected node's text to match whatever the user just typed in + graph_drawer.set_node_label(selected_node_index,sel_node_text.text()); +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/examples/bridge_ex.cpp b/ml/dlib/examples/bridge_ex.cpp new file mode 100644 index 000000000..bc772ccbb --- /dev/null +++ b/ml/dlib/examples/bridge_ex.cpp @@ -0,0 +1,365 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt + + +/* + This is an example showing how to use the bridge object from from the + dlib C++ Library to send messages via TCP/IP. + + In particular, this example will walk you through four progressively + more complex use cases of the bridge object. Note that this example + program assumes you are already familiar with the pipe object and at + least the contents of the pipe_ex_2.cpp example program. +*/ + + +// =========== Example program output =========== +/* + ---- Running example 1 ---- + dequeued value: 1 + dequeued value: 2 + dequeued value: 3 + + ---- Running example 2 ---- + dequeued value: 1 + dequeued value: 2 + dequeued value: 3 + + ---- Running example 3 ---- + dequeued int: 1 + dequeued int: 2 + dequeued struct: 3 some string + + ---- Running example 4 ---- + bridge 1 status: is_connected: true + bridge 1 status: foreign_ip: 127.0.0.1 + bridge 1 status: foreign_port: 43156 + bridge 2 status: is_connected: true + bridge 2 status: foreign_ip: 127.0.0.1 + bridge 2 status: foreign_port: 12345 + dequeued int: 1 + dequeued int: 2 + dequeued struct: 3 some string + bridge 1 status: is_connected: false + bridge 1 status: foreign_ip: 127.0.0.1 + bridge 1 status: foreign_port: 12345 +*/ + + +#include +#include +#include + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +void run_example_1(); +void run_example_2(); +void run_example_3(); +void run_example_4(); + +// ---------------------------------------------------------------------------------------- + +int main() +{ + run_example_1(); + run_example_2(); + run_example_3(); + run_example_4(); +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +void run_example_1( +) +{ + cout << "\n ---- Running example 1 ---- " << endl; + + /* + The idea of the bridge is basically to allow two different dlib::pipe objects + to be connected together via a TCP connection. This is best illustrated by + the following short example. In it we create two pipes, in and out. When + an object is enqueued into the out pipe it will be automatically sent + through a TCP connection and once received at the other end it will be + inserted into the in pipe. + */ + dlib::pipe in(4), out(4); + + + // This bridge will listen on port 12345 for an incoming TCP connection. Then + // it will read data from that connection and put it into the in pipe. + bridge b2(listen_on_port(12345), receive(in)); + + // This bridge will initiate a TCP connection and then start dequeuing + // objects from out and transmitting them over the connection. + bridge b1(connect_to_ip_and_port("127.0.0.1", 12345), transmit(out)); + + // As an aside, in a real program, each of these bridges and pipes would be in a + // separate application. But to make this example self contained they are both + // right here. + + + + // Now let's put some things into the out pipe + int value = 1; + out.enqueue(value); + + value = 2; + out.enqueue(value); + + value = 3; + out.enqueue(value); + + + // Now those 3 ints can be dequeued from the in pipe. They will show up + // in the same order they were inserted into the out pipe. + in.dequeue(value); + cout << "dequeued value: "<< value << endl; + in.dequeue(value); + cout << "dequeued value: "<< value << endl; + in.dequeue(value); + cout << "dequeued value: "<< value << endl; +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +void run_example_2( +) +{ + cout << "\n ---- Running example 2 ---- " << endl; + + /* + This example makes a simple echo server on port 12345. When an object + is inserted into the out pipe it will be sent over a TCP connection, get + put into the echo pipe and then immediately read out of the echo pipe and + sent back over the TCP connection where it will finally be placed into the in + pipe. + */ + + dlib::pipe in(4), out(4), echo(4); + + // Just like TCP connections, a bridge can send data both directions. The directionality + // of a pipe is indicated by the receive() and transmit() type decorations. Also, the order + // they are listed doesn't matter. + bridge echo_bridge(listen_on_port(12345), receive(echo), transmit(echo)); + + // Note that you can also specify the ip and port as a string by using connect_to(). + bridge b1(connect_to("127.0.0.1:12345"), transmit(out), receive(in)); + + + int value = 1; + out.enqueue(value); + + value = 2; + out.enqueue(value); + + value = 3; + out.enqueue(value); + + + in.dequeue(value); + cout << "dequeued value: "<< value << endl; + in.dequeue(value); + cout << "dequeued value: "<< value << endl; + in.dequeue(value); + cout << "dequeued value: "<< value << endl; +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +struct my_example_object +{ + /* + All objects passing through a dlib::bridge must be serializable. This + means there must exist global functions called serialize() and deserialize() + which can convert an object into a bit stream and then reverse the process. + + This example object illustrates how this is done. + */ + + int value; + std::string str; +}; + +void serialize (const my_example_object& item, std::ostream& out) +{ + /* + serialize() just needs to write the state of item to the output stream. + You can do this however you like. Below, I'm using the serialize functions + for int and std::string which come with dlib. But again, you can do whatever + you want here. + */ + dlib::serialize(item.value, out); + dlib::serialize(item.str, out); +} + +void deserialize (my_example_object& item, std::istream& in) +{ + /* + deserialize() is just the inverse of serialize(). Again, you can do + whatever you want here so long as it correctly reconstructs item. This + also means that deserialize() must always consume as many bytes as serialize() + generates. + */ + dlib::deserialize(item.value, in); + dlib::deserialize(item.str, in); +} + +// ---------------------------------------------------------------------------------------- + +void run_example_3( +) +{ + cout << "\n ---- Running example 3 ---- " << endl; + + /* + In this example we will just send ints and my_example_object objects + over a TCP connection. Since we are sending more than one type of + object through a pipe we will need to use the type_safe_union. + */ + + typedef type_safe_union tsu_type; + + dlib::pipe in(4), out(4); + + // Note that we don't have to start the listening bridge first. If b2 + // fails to make a connection it will just keep trying until successful. + bridge b2(connect_to("127.0.0.1:12345"), receive(in)); + // We don't have to configure a bridge in it's constructor. If it's + // more convenient we can do so by calling reconfigure() instead. + bridge b1; + b1.reconfigure(listen_on_port(12345), transmit(out)); + + tsu_type msg; + + msg = 1; + out.enqueue(msg); + + msg = 2; + out.enqueue(msg); + + msg.get().value = 3; + msg.get().str = "some string"; + out.enqueue(msg); + + + // dequeue the three objects we sent and print them on the screen. + for (int i = 0; i < 3; ++i) + { + in.dequeue(msg); + if (msg.contains()) + { + cout << "dequeued int: "<< msg.get() << endl; + } + else if (msg.contains()) + { + cout << "dequeued struct: "<< msg.get().value << " " + << msg.get().str << endl; + } + } +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +void run_example_4( +) +{ + cout << "\n ---- Running example 4 ---- " << endl; + + /* + This final example is the same as example 3 except we will also now be getting + status messages from the bridges. These bridge_status messages tell us the + state of the TCP connection associated with a bridge. Is it connected or not? + Who it is connected to? + + The way you get these status messages is by ensuring that your receive pipe is + capable of storing bridge_status objects. If it is then the bridge will + automatically insert bridge_status messages into your receive pipe whenever + there is a status change. + + There are only two kinds of status changes. The establishment of a connection + or the closing of a connection. Also, a connection which closes due to you + calling clear(), reconfigure(), or destructing a bridge does not generate a + status message since, in this case, you already know about it and just want + the bridge to destroy itself as quickly as possible. + */ + + + typedef type_safe_union tsu_type; + + dlib::pipe in(4), out(4); + dlib::pipe b1_status(4); + + // setup both bridges to have receive pipes capable of holding bridge_status messages. + bridge b1(listen_on_port(12345), transmit(out), receive(b1_status)); + // Note that we can also use a hostname with connect_to() instead of supplying an IP address. + bridge b2(connect_to("localhost:12345"), receive(in)); + + tsu_type msg; + bridge_status bs; + + // Once a connection is established it will generate a status message from each bridge. + // Let's get those and print them. + b1_status.dequeue(bs); + cout << "bridge 1 status: is_connected: " << boolalpha << bs.is_connected << endl; + cout << "bridge 1 status: foreign_ip: " << bs.foreign_ip << endl; + cout << "bridge 1 status: foreign_port: " << bs.foreign_port << endl; + + in.dequeue(msg); + bs = msg.get(); + cout << "bridge 2 status: is_connected: " << bs.is_connected << endl; + cout << "bridge 2 status: foreign_ip: " << bs.foreign_ip << endl; + cout << "bridge 2 status: foreign_port: " << bs.foreign_port << endl; + + + + msg = 1; + out.enqueue(msg); + + msg = 2; + out.enqueue(msg); + + msg.get().value = 3; + msg.get().str = "some string"; + out.enqueue(msg); + + + // Read the 3 things we sent over the connection. + for (int i = 0; i < 3; ++i) + { + in.dequeue(msg); + if (msg.contains()) + { + cout << "dequeued int: "<< msg.get() << endl; + } + else if (msg.contains()) + { + cout << "dequeued struct: "<< msg.get().value << " " + << msg.get().str << endl; + } + } + + // cause bridge 1 to shutdown completely. This will close the connection and + // therefore bridge 2 will generate a status message indicating the connection + // just closed. + b1.clear(); + in.dequeue(msg); + bs = msg.get(); + cout << "bridge 1 status: is_connected: " << bs.is_connected << endl; + cout << "bridge 1 status: foreign_ip: " << bs.foreign_ip << endl; + cout << "bridge 1 status: foreign_port: " << bs.foreign_port << endl; +} + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/examples/bsp_ex.cpp b/ml/dlib/examples/bsp_ex.cpp new file mode 100644 index 000000000..7dffa68d6 --- /dev/null +++ b/ml/dlib/examples/bsp_ex.cpp @@ -0,0 +1,282 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + + This is an example illustrating the use of the Bulk Synchronous Parallel (BSP) + processing tools from the dlib C++ Library. These tools allow you to easily setup a + number of processes running on different computers which cooperate to compute some + result. + + In this example, we will use the BSP tools to find the minimizer of a simple function. + In particular, we will setup a nested grid search where different parts of the grid are + searched in parallel by different processes. + + + To run this program you should do the following (supposing you want to use three BSP + nodes to do the grid search and, to make things easy, you will run them all on your + current computer): + + 1. Open three command windows and navigate each to the folder containing the + compiled bsp_ex.cpp program. Let's call these window 1, window 2, and window 3. + + 2. In window 1 execute this command: + ./bsp_ex -l12345 + This will start a listening BSP node that listens on port 12345. The BSP node + won't do anything until we tell all the nodes to start running in step 4 below. + + 3. In window 2 execute this command: + ./bsp_ex -l12346 + This starts another listening BSP node. Note that since we are running this + example all on one computer you need to use different listening port numbers + for each listening node. + + 4. In window 3 execute this command: + ./bsp_ex localhost:12345 localhost:12346 + This will start a BSP node that connects to the others and gets them all running. + Additionally, as you will see when we go over the code below, it will also print + the final output of the BSP process, which is the minimizer of our test function. + Once it terminates, all the other BSP nodes will also automatically terminate. +*/ + + + + + +#include +#include +#include + +#include + +using namespace std; +using namespace dlib; + +// ---------------------------------------------------------------------------------------- + +// These are the functions executed by the BSP nodes. They are defined below. +void bsp_job_node_0 (bsp_context& bsp, double& min_value, double& optimal_x); +void bsp_job_other_nodes (bsp_context& bsp, long grid_resolution); + +// ---------------------------------------------------------------------------------------- + +int main(int argc, char** argv) +{ + try + { + // Use the dlib command_line_parser to parse the command line. See the + // compress_stream_ex.cpp example program for an introduction to the command line + // parser. + command_line_parser parser; + parser.add_option("h","Display this help message."); + parser.add_option("l","Run as a listening BSP node.",1); + parser.parse(argc, argv); + parser.check_option_arg_range("l", 1, 65535); + + + // Print a help message if the user gives -h on the command line. + if (parser.option("h")) + { + // display all the command line options + cout << "Usage: bsp_ex (-l port | )\n"; + parser.print_options(); + return 0; + } + + + // If the command line contained -l + if (parser.option("l")) + { + // Get the argument to -l + const unsigned short listening_port = get_option(parser, "l", 0); + cout << "Listening on port " << listening_port << endl; + + const long grid_resolution = 100; + + // bsp_listen() starts a listening BSP job. This means that it will wait until + // someone calls bsp_connect() and connects to it before it starts running. + // However, once it starts it will call bsp_job_other_nodes() which will then + // do all the real work. + // + // The first argument is the port to listen on. The second argument is the + // function which it should run to do all the work. The other arguments are + // optional and allow you to pass values into the bsp_job_other_nodes() + // routine. In this case, we are passing the grid_resolution to + // bsp_job_other_nodes(). + bsp_listen(listening_port, bsp_job_other_nodes, grid_resolution); + } + else + { + if (parser.number_of_arguments() == 0) + { + cout << "You must give some listening BSP nodes as arguments to this program!" << endl; + return 0; + } + + // Take the hostname:port strings from the command line and put them into the + // vector of hosts. + std::vector hosts; + for (unsigned long i = 0; i < parser.number_of_arguments(); ++i) + hosts.push_back(parser[i]); + + double min_value, optimal_x; + + // Calling bsp_connect() does two things. First, it tells all the BSP jobs + // listed in the hosts vector to start running. Second, it starts a locally + // running BSP job that executes bsp_job_node_0() and passes it any arguments + // listed after bsp_job_node_0. So in this case it passes it the 3rd and 4th + // arguments. + // + // Note also that we use dlib::ref() which causes these arguments to be passed + // by reference. This means that bsp_job_node_0() will be able to modify them + // and we will see the results here in main() after bsp_connect() terminates. + bsp_connect(hosts, bsp_job_node_0, dlib::ref(min_value), dlib::ref(optimal_x)); + + // bsp_connect() and bsp_listen() block until all the BSP nodes have terminated. + // Therefore, we won't get to this part of the code until the BSP processing + // has finished. But once we do we can print the results like so: + cout << "optimal_x: "<< optimal_x << endl; + cout << "min_value: "<< min_value << endl; + } + + } + catch (std::exception& e) + { + cout << "error in main(): " << e.what() << endl; + } +} + +// ---------------------------------------------------------------------------------------- + +/* + We are going to use the BSP tools to find the minimum of f(x). Note that + it's minimizer is at x == 2.0. +*/ +double f (double x) +{ + return std::pow(x-2.0, 2.0); +} + +// ---------------------------------------------------------------------------------------- + +void bsp_job_node_0 (bsp_context& bsp, double& min_value, double& optimal_x) +{ + // This function is called by bsp_connect(). In general, any BSP node can do anything + // you want. However, in this example we use this node as a kind of controller for the + // other nodes. In particular, since we are doing a nested grid search, this node's + // job will be to collect results from other nodes and then decide which part of the + // number line subsequent iterations should focus on. + // + // Also, each BSP node has a node ID number. You can determine it by calling + // bsp.node_id(). However, the node spawned by a call to bsp_connect() always has a + // node ID of 0 (hence the name of this function). Additionally, all functions + // executing a BSP task always take a bsp_context as their first argument. This object + // is the interface that allows BSP jobs to communicate with each other. + + + // Now let's get down to work. Recall that we are trying to find the x value that + // minimizes the f(x) defined above. The grid search will start out by considering the + // range [-1e100, 1e100] on the number line. It will progressively narrow this window + // until it has located the minimizer of f(x) to within 1e-15 of its true value. + double left = -1e100; + double right = 1e100; + + min_value = std::numeric_limits::infinity(); + double interval_width = std::abs(right-left); + + // keep going until the window is smaller than 1e-15. + while (right-left > 1e-15) + { + // At the start of each loop, we broadcast the current window to all the other BSP + // nodes. They will each search a separate part of the window and then report back + // the smallest values they found in their respective sub-windows. + // + // Also, you can send/broadcast/receive anything that has global serialize() and + // deserialize() routines defined for it. Dlib comes with serialization functions + // for a lot of types by default, so we don't have to define anything for this + // example program. However, if you want to send an object you defined then you + // will need to write your own serialization functions. See the documentation for + // dlib's serialize() routine or the bridge_ex.cpp example program for an example. + bsp.broadcast(left); + bsp.broadcast(right); + + // Receive the smallest values found from the other BSP nodes. + for (unsigned int k = 1; k < bsp.number_of_nodes(); ++k) + { + // The other nodes will send std::pairs of x/f(x) values. So that is what we + // receive. + std::pair val; + bsp.receive(val); + // save the smallest result. + if (val.second < min_value) + { + min_value = val.second; + optimal_x = val.first; + } + } + + // Now narrow the search window by half. + interval_width *= 0.5; + left = optimal_x - interval_width/2; + right = optimal_x + interval_width/2; + } +} + +// ---------------------------------------------------------------------------------------- + +void bsp_job_other_nodes (bsp_context& bsp, long grid_resolution) +{ + // This is the BSP job called by bsp_listen(). In these jobs we will receive window + // ranges from the controller node, search our sub-window, and then report back the + // location of the best x value we found. + + double left, right; + + // The try_receive() function will either return true with the next message or return + // false if there aren't any more messages in flight between nodes and all other BSP + // nodes are blocked on calls to receive or have terminated. That is, try_receive() + // only returns false if waiting for a message would result in all the BSP nodes + // waiting forever. + // + // Therefore, try_receive() serves both as a message receiving tool as well as an + // implicit form of barrier synchronization. In this case, we use it to know when to + // terminate. That is, we know it is time to terminate if all the messages between + // nodes have been received and all nodes are inactive due to either termination or + // being blocked on a receive call. This will happen once the controller node above + // terminates since it will result in all the other nodes inevitably becoming blocked + // on this try_receive() line with no messages to process. + while (bsp.try_receive(left)) + { + bsp.receive(right); + + // Compute a sub-window range for us to search. We use our node's ID value and the + // total number of nodes to select a subset of the [left, right] window. We will + // store the grid points from our sub-window in values_to_check. + const double l = (bsp.node_id()-1)/(bsp.number_of_nodes()-1.0); + const double r = bsp.node_id() /(bsp.number_of_nodes()-1.0); + const double width = right-left; + // Select grid_resolution number of points which are linearly spaced throughout our + // sub-window. + const matrix values_to_check = linspace(left+l*width, left+r*width, grid_resolution); + + // Search all the points in values_to_check and figure out which one gives the + // minimum value of f(). + double best_x = 0; + double best_val = std::numeric_limits::infinity(); + for (long j = 0; j < values_to_check.size(); ++j) + { + double temp = f(values_to_check(j)); + if (temp < best_val) + { + best_val = temp; + best_x = values_to_check(j); + } + } + + // Report back the identity of the best point we found in our sub-window. Note + // that the second argument to send(), the 0, is the node ID to send to. In this + // case we send our results back to the controller node. + bsp.send(make_pair(best_x, best_val), 0); + } +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/examples/compress_stream_ex.cpp b/ml/dlib/examples/compress_stream_ex.cpp new file mode 100644 index 000000000..502400e5e --- /dev/null +++ b/ml/dlib/examples/compress_stream_ex.cpp @@ -0,0 +1,245 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + + This is an example illustrating the use of the compress_stream and + cmd_line_parser components from the dlib C++ Library. + + This example implements a simple command line compression utility. + + + The output from the program when the -h option is given is: + + Usage: compress_stream_ex (-c|-d|-l) --in input_file --out output_file + Options: + -c Indicates that we want to compress a file. + -d Indicates that we want to decompress a file. + --in This option takes one argument which specifies the name of the + file we want to compress/decompress. + --out This option takes one argument which specifies the name of the + output file. + + Miscellaneous Options: + -h Display this help message. + -l Set the compression level [1-3], 3 is max compression, default + is 2. + +*/ + + + + +#include +#include +#include +#include +#include + +// I am making a typedefs for the versions of compress_stream I want to use. +typedef dlib::compress_stream::kernel_1da cs1; +typedef dlib::compress_stream::kernel_1ea cs2; +typedef dlib::compress_stream::kernel_1ec cs3; + + +using namespace std; +using namespace dlib; + + +int main(int argc, char** argv) +{ + try + { + command_line_parser parser; + + // first I will define the command line options I want. + // Add a -c option and tell the parser what the option is for. + parser.add_option("c","Indicates that we want to compress a file."); + parser.add_option("d","Indicates that we want to decompress a file."); + // add a --in option that takes 1 argument + parser.add_option("in","This option takes one argument which specifies the name of the file we want to compress/decompress.",1); + // add a --out option that takes 1 argument + parser.add_option("out","This option takes one argument which specifies the name of the output file.",1); + // In the code below, we use the parser.print_options() method to print all our + // options to the screen. We can tell it that we would like some options to be + // grouped together by calling set_group_name() before adding those options. In + // general, you can make as many groups as you like by calling set_group_name(). + // However, here we make only one named group. + parser.set_group_name("Miscellaneous Options"); + parser.add_option("h","Display this help message."); + parser.add_option("l","Set the compression level [1-3], 3 is max compression, default is 2.",1); + + + // now I will parse the command line + parser.parse(argc,argv); + + + // Now I will use the parser to validate some things about the command line. + // If any of the following checks fail then an exception will be thrown and it will + // contain a message that tells the user what the problem was. + + // First I want to check that none of the options were given on the command line + // more than once. To do this I define an array that contains the options + // that shouldn't appear more than once and then I just call check_one_time_options() + const char* one_time_opts[] = {"c", "d", "in", "out", "h", "l"}; + parser.check_one_time_options(one_time_opts); + // Here I'm checking that the user didn't pick both the c and d options at the + // same time. + parser.check_incompatible_options("c", "d"); + + // Here I'm checking that the argument to the l option is an integer in the range 1 to 3. + // That is, it should be convertible to an int by dlib::string_assign and be either + // 1, 2, or 3. Note that if you wanted to allow floating point values in the range 1 to + // 3 then you could give a range 1.0 to 3.0 or explicitly supply a type of float or double + // to the template argument of the check_option_arg_range() function. + parser.check_option_arg_range("l", 1, 3); + + // The 'l' option is a sub-option of the 'c' option. That is, you can only select the + // compression level when compressing. This command below checks that the listed + // sub options are always given in the presence of their parent options. + const char* c_sub_opts[] = {"l"}; + parser.check_sub_options("c", c_sub_opts); + + // check if the -h option was given on the command line + if (parser.option("h")) + { + // display all the command line options + cout << "Usage: compress_stream_ex (-c|-d|-l) --in input_file --out output_file\n"; + // This function prints out a nicely formatted list of + // all the options the parser has + parser.print_options(); + return 0; + } + + // Figure out what the compression level should be. If the user didn't supply + // this command line option then a value of 2 will be used. + int compression_level = get_option(parser,"l",2); + + + // make sure one of the c or d options was given + if (!parser.option("c") && !parser.option("d")) + { + cout << "Error in command line:\n You must specify either the c option or the d option.\n"; + cout << "\nTry the -h option for more information." << endl; + return 0; + } + + + string in_file; + string out_file; + + // check if the user told us the input file and if they did then + // get the file name + if (parser.option("in")) + { + in_file = parser.option("in").argument(); + } + else + { + cout << "Error in command line:\n You must specify an input file.\n"; + cout << "\nTry the -h option for more information." << endl; + return 0; + } + + + // check if the user told us the output file and if they did then + // get the file name + if (parser.option("out")) + { + out_file = parser.option("out").argument(); + } + else + { + cout << "Error in command line:\n You must specify an output file.\n"; + cout << "\nTry the -h option for more information." << endl; + return 0; + } + + + // open the files we will be reading from and writing to + ifstream fin(in_file.c_str(),ios::binary); + ofstream fout(out_file.c_str(),ios::binary); + + // make sure the files opened correctly + if (!fin) + { + cout << "Error opening file " << in_file << ".\n"; + return 0; + } + + if (!fout) + { + cout << "Error creating file " << out_file << ".\n"; + return 0; + } + + + + // now perform the actual compression or decompression. + if (parser.option("c")) + { + // save the compression level to the output file + serialize(compression_level, fout); + + switch (compression_level) + { + case 1: + { + cs1 compressor; + compressor.compress(fin,fout); + }break; + case 2: + { + cs2 compressor; + compressor.compress(fin,fout); + }break; + case 3: + { + cs3 compressor; + compressor.compress(fin,fout); + }break; + } + } + else + { + // obtain the compression level from the input file + deserialize(compression_level, fin); + + switch (compression_level) + { + case 1: + { + cs1 compressor; + compressor.decompress(fin,fout); + }break; + case 2: + { + cs2 compressor; + compressor.decompress(fin,fout); + }break; + case 3: + { + cs3 compressor; + compressor.decompress(fin,fout); + }break; + default: + { + cout << "Error in compressed file, invalid compression level" << endl; + }break; + } + } + + + + + } + catch (exception& e) + { + // Note that this will catch any cmd_line_parse_error exceptions and print + // the default message. + cout << e.what() << endl; + } +} + + + + + diff --git a/ml/dlib/examples/config.txt b/ml/dlib/examples/config.txt new file mode 100644 index 000000000..da21d170c --- /dev/null +++ b/ml/dlib/examples/config.txt @@ -0,0 +1,30 @@ +# This is an example config file. Note that # is used to create a comment. + +# At its most basic level a config file is just a bunch of key/value pairs. +# So for example: +key1 = value2 +dlib = a C++ library + +# You can also define "sub blocks" in your config files like so +user1 +{ + # Inside a sub block you can list more key/value pairs. + id = 42 + name = davis + + # you can also nest sub-blocks as deep as you want + details + { + editor = vim + home_dir = /home/davis + } +} +user2 { + id = 1234 + name = joe + details { + editor = emacs + home_dir = /home/joe + } +} + diff --git a/ml/dlib/examples/config_reader_ex.cpp b/ml/dlib/examples/config_reader_ex.cpp new file mode 100644 index 000000000..02ad1cc68 --- /dev/null +++ b/ml/dlib/examples/config_reader_ex.cpp @@ -0,0 +1,146 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + + This is an example illustrating the use of the config_reader component + from the dlib C++ Library. + + This example uses the config_reader to load a config file and then + prints out the values of various fields in the file. +*/ + + +#include +#include +#include +#include + + +using namespace std; +using namespace dlib; + +// ---------------------------------------------------------------------------------------- +// For reference, the contents of the config file used in this example is reproduced below: +/* + +# This is an example config file. Note that # is used to create a comment. + +# At its most basic level a config file is just a bunch of key/value pairs. +# So for example: +key1 = value2 +dlib = a C++ library + +# You can also define "sub blocks" in your config files like so +user1 +{ + # Inside a sub block you can list more key/value pairs. + id = 42 + name = davis + + # you can also nest sub-blocks as deep as you want + details + { + editor = vim + home_dir = /home/davis + } +} +user2 { + id = 1234 + name = joe + details { + editor = emacs + home_dir = /home/joe + } +} + +*/ +// ---------------------------------------------------------------------------------------- + +void print_config_reader_contents ( + const config_reader& cr, + int depth = 0 +); +/* + This is a simple function that recursively walks through everything in + a config reader and prints it to the screen. +*/ + +// ---------------------------------------------------------------------------------------- + +int main() +{ + try + { + config_reader cr("config.txt"); + + // Use our recursive function to print everything in the config file. + print_config_reader_contents(cr); + + // Now let's access some of the fields of the config file directly. You + // use [] for accessing key values and .block() for accessing sub-blocks. + + // Print out the string value assigned to key1 in the config file + cout << cr["key1"] << endl; + + // Print out the name field inside the user1 sub-block + cout << cr.block("user1")["name"] << endl; + // Now print out the editor field in the details block + cout << cr.block("user1").block("details")["editor"] << endl; + + + // Note that you can use get_option() to easily convert fields into + // non-string types. For example, the config file has an integer id + // field that can be converted into an int like so: + int id1 = get_option(cr,"user1.id",0); + int id2 = get_option(cr,"user2.id",0); + cout << "user1's id is " << id1 << endl; + cout << "user2's id is " << id2 << endl; + // The third argument to get_option() is the default value returned if + // the config reader doesn't contain a corresponding entry. So for + // example, the following prints 321 since there is no user3. + int id3 = get_option(cr,"user3.id",321); + cout << "user3's id is " << id3 << endl; + + } + catch (exception& e) + { + // Finally, note that the config_reader throws exceptions if the config + // file is corrupted or if you ask it for a key or block that doesn't exist. + // Here we print out any such error messages. + cout << e.what() << endl; + } +} + +// ---------------------------------------------------------------------------------------- + +void print_config_reader_contents ( + const config_reader& cr, + int depth +) +{ + // Make a string with depth*4 spaces in it. + const string padding(depth*4, ' '); + + // We can obtain a list of all the keys and sub-blocks defined + // at the current level in the config reader like so: + vector keys, blocks; + cr.get_keys(keys); + cr.get_blocks(blocks); + + // Now print all the key/value pairs + for (unsigned long i = 0; i < keys.size(); ++i) + cout << padding << keys[i] << " = " << cr[keys[i]] << endl; + + // Now print all the sub-blocks. + for (unsigned long i = 0; i < blocks.size(); ++i) + { + // First print the block name + cout << padding << blocks[i] << " { " << endl; + // Now recursively print the contents of the sub block. Note that the cr.block() + // function returns another config_reader that represents the sub-block. + print_config_reader_contents(cr.block(blocks[i]), depth+1); + cout << padding << "}" << endl; + } +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/examples/custom_trainer_ex.cpp b/ml/dlib/examples/custom_trainer_ex.cpp new file mode 100644 index 000000000..39af53f39 --- /dev/null +++ b/ml/dlib/examples/custom_trainer_ex.cpp @@ -0,0 +1,277 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + This example program shows you how to create your own custom binary classification + trainer object and use it with the multiclass classification tools in the dlib C++ + library. This example assumes you have already become familiar with the concepts + introduced in the multiclass_classification_ex.cpp example program. + + + In this example we will create a very simple trainer object that takes a binary + classification problem and produces a decision rule which says a test point has the + same class as whichever centroid it is closest to. + + The multiclass training dataset will consist of four classes. Each class will be a blob + of points in one of the quadrants of the cartesian plane. For fun, we will use + std::string labels and therefore the labels of these classes will be the following: + "upper_left", + "upper_right", + "lower_left", + "lower_right" +*/ + +#include + +#include +#include + +#include + +using namespace std; +using namespace dlib; + +// Our data will be 2-dimensional data. So declare an appropriate type to contain these points. +typedef matrix sample_type; + +// ---------------------------------------------------------------------------------------- + +struct custom_decision_function +{ + /*! + WHAT THIS OBJECT REPRESENTS + This object is the representation of our binary decision rule. + !*/ + + // centers of the two classes + sample_type positive_center, negative_center; + + double operator() ( + const sample_type& x + ) const + { + // if x is closer to the positive class then return +1 + if (length(positive_center - x) < length(negative_center - x)) + return +1; + else + return -1; + } +}; + +// Later on in this example we will save our decision functions to disk. This +// pair of routines is needed for this functionality. +void serialize (const custom_decision_function& item, std::ostream& out) +{ + // write the state of item to the output stream + serialize(item.positive_center, out); + serialize(item.negative_center, out); +} + +void deserialize (custom_decision_function& item, std::istream& in) +{ + // read the data from the input stream and store it in item + deserialize(item.positive_center, in); + deserialize(item.negative_center, in); +} + +// ---------------------------------------------------------------------------------------- + +class simple_custom_trainer +{ + /*! + WHAT THIS OBJECT REPRESENTS + This is our example custom binary classifier trainer object. It simply + computes the means of the +1 and -1 classes, puts them into our + custom_decision_function, and returns the results. + + Below we define the train() function. I have also included the + requires/ensures definition for a generic binary classifier's train() + !*/ +public: + + + custom_decision_function train ( + const std::vector& samples, + const std::vector& labels + ) const + /*! + requires + - is_binary_classification_problem(samples, labels) == true + (e.g. labels consists of only +1 and -1 values, samples.size() == labels.size()) + ensures + - returns a decision function F with the following properties: + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + { + sample_type positive_center, negative_center; + + // compute sums of each class + positive_center = 0; + negative_center = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (labels[i] == +1) + positive_center += samples[i]; + else // this is a -1 sample + negative_center += samples[i]; + } + + // divide by number of +1 samples + positive_center /= sum(mat(labels) == +1); + // divide by number of -1 samples + negative_center /= sum(mat(labels) == -1); + + custom_decision_function df; + df.positive_center = positive_center; + df.negative_center = negative_center; + + return df; + } +}; + +// ---------------------------------------------------------------------------------------- + +void generate_data ( + std::vector& samples, + std::vector& labels +); +/*! + ensures + - make some four class data as described above. + - each class will have 50 samples in it +!*/ + +// ---------------------------------------------------------------------------------------- + +int main() +{ + std::vector samples; + std::vector labels; + + // First, get our labeled set of training data + generate_data(samples, labels); + + cout << "samples.size(): "<< samples.size() << endl; + + // Define the trainer we will use. The second template argument specifies the type + // of label used, which is string in this case. + typedef one_vs_one_trainer, string> ovo_trainer; + + + ovo_trainer trainer; + + // Now tell the one_vs_one_trainer that, by default, it should use the simple_custom_trainer + // to solve the individual binary classification subproblems. + trainer.set_trainer(simple_custom_trainer()); + + // Next, to make things a little more interesting, we will setup the one_vs_one_trainer + // to use kernel ridge regression to solve the upper_left vs lower_right binary classification + // subproblem. + typedef radial_basis_kernel rbf_kernel; + krr_trainer rbf_trainer; + rbf_trainer.set_kernel(rbf_kernel(0.1)); + trainer.set_trainer(rbf_trainer, "upper_left", "lower_right"); + + + // Now let's do 5-fold cross-validation using the one_vs_one_trainer we just setup. + // As an aside, always shuffle the order of the samples before doing cross validation. + // For a discussion of why this is a good idea see the svm_ex.cpp example. + randomize_samples(samples, labels); + cout << "cross validation: \n" << cross_validate_multiclass_trainer(trainer, samples, labels, 5) << endl; + // This dataset is very easy and everything is correctly classified. Therefore, the output of + // cross validation is the following confusion matrix. + /* + 50 0 0 0 + 0 50 0 0 + 0 0 50 0 + 0 0 0 50 + */ + + + // We can also obtain the decision rule as always. + one_vs_one_decision_function df = trainer.train(samples, labels); + + cout << "predicted label: "<< df(samples[0]) << ", true label: "<< labels[0] << endl; + cout << "predicted label: "<< df(samples[90]) << ", true label: "<< labels[90] << endl; + // The output is: + /* + predicted label: upper_right, true label: upper_right + predicted label: lower_left, true label: lower_left + */ + + + // Finally, let's save our multiclass decision rule to disk. Remember that we have + // to specify the types of binary decision function used inside the one_vs_one_decision_function. + one_vs_one_decision_function > // This is the output of the rbf_trainer + > df2, df3; + + df2 = df; + // save to a file called df.dat + serialize("df.dat") << df2; + + // load the function back in from disk and store it in df3. + deserialize("df.dat") >> df3; + + + // Test df3 to see that this worked. + cout << endl; + cout << "predicted label: "<< df3(samples[0]) << ", true label: "<< labels[0] << endl; + cout << "predicted label: "<< df3(samples[90]) << ", true label: "<< labels[90] << endl; + // Test df3 on the samples and labels and print the confusion matrix. + cout << "test deserialized function: \n" << test_multiclass_decision_function(df3, samples, labels) << endl; + +} + +// ---------------------------------------------------------------------------------------- + +void generate_data ( + std::vector& samples, + std::vector& labels +) +{ + const long num = 50; + + sample_type m; + + dlib::rand rnd; + + + // add some points in the upper right quadrant + m = 10, 10; + for (long i = 0; i < num; ++i) + { + samples.push_back(m + randm(2,1,rnd)); + labels.push_back("upper_right"); + } + + // add some points in the upper left quadrant + m = -10, 10; + for (long i = 0; i < num; ++i) + { + samples.push_back(m + randm(2,1,rnd)); + labels.push_back("upper_left"); + } + + // add some points in the lower right quadrant + m = 10, -10; + for (long i = 0; i < num; ++i) + { + samples.push_back(m + randm(2,1,rnd)); + labels.push_back("lower_right"); + } + + // add some points in the lower left quadrant + m = -10, -10; + for (long i = 0; i < num; ++i) + { + samples.push_back(m + randm(2,1,rnd)); + labels.push_back("lower_left"); + } + +} + +// ---------------------------------------------------------------------------------------- + diff --git a/ml/dlib/examples/dir_nav_ex.cpp b/ml/dlib/examples/dir_nav_ex.cpp new file mode 100644 index 000000000..2f51f2d1b --- /dev/null +++ b/ml/dlib/examples/dir_nav_ex.cpp @@ -0,0 +1,75 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + + This is an example illustrating the use of the dir_nav component from the dlib C++ Library. + It prints a listing of all directories and files in the users + current working directory or the directory specified on the command line. + +*/ + + +#include +#include +#include +#include +#include + +using namespace std; +using namespace dlib; + + +int main(int argc, char** argv) +{ + try + { + string loc; + if (argc == 2) + loc = argv[1]; + else + loc = "."; // if no argument is given then use the current working dir. + + directory test(loc); + + + cout << "directory: " << test.name() << endl; + cout << "full path: " << test.full_name() << endl; + cout << "is root: " << ((test.is_root())?"yes":"no") << endl; + + // get all directories and files in test + std::vector dirs = test.get_dirs(); + std::vector files = test.get_files(); + + // sort the files and directories + sort(files.begin(), files.end()); + sort(dirs.begin(), dirs.end()); + + cout << "\n\n\n"; + + // print all the subdirectories + for (unsigned long i = 0; i < dirs.size(); ++i) + cout << " " << dirs[i].name() << "\n"; + + // print all the subfiles + for (unsigned long i = 0; i < files.size(); ++i) + cout << setw(13) << files[i].size() << " " << files[i].name() << "\n"; + + + cout << "\n\nnumber of dirs: " << dirs.size() << endl; + cout << "number of files: " << files.size() << endl; + + } + catch (file::file_not_found& e) + { + cout << "file not found or accessible: " << e.info << endl; + } + catch (directory::dir_not_found& e) + { + cout << "dir not found or accessible: " << e.info << endl; + } + catch (directory::listing_error& e) + { + cout << "listing error: " << e.info << endl; + } +} + + diff --git a/ml/dlib/examples/dnn_face_recognition_ex.cpp b/ml/dlib/examples/dnn_face_recognition_ex.cpp new file mode 100644 index 000000000..4c0a2a02b --- /dev/null +++ b/ml/dlib/examples/dnn_face_recognition_ex.cpp @@ -0,0 +1,220 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + This is an example illustrating the use of the deep learning tools from the dlib C++ + Library. In it, we will show how to do face recognition. This example uses the + pretrained dlib_face_recognition_resnet_model_v1 model which is freely available from + the dlib web site. This model has a 99.38% accuracy on the standard LFW face + recognition benchmark, which is comparable to other state-of-the-art methods for face + recognition as of February 2017. + + In this example, we will use dlib to do face clustering. Included in the examples + folder is an image, bald_guys.jpg, which contains a bunch of photos of action movie + stars Vin Diesel, The Rock, Jason Statham, and Bruce Willis. We will use dlib to + automatically find their faces in the image and then to automatically determine how + many people there are (4 in this case) as well as which faces belong to each person. + + Finally, this example uses a network with the loss_metric loss. Therefore, if you want + to learn how to train your own models, or to get a general introduction to this loss + layer, you should read the dnn_metric_learning_ex.cpp and + dnn_metric_learning_on_images_ex.cpp examples. +*/ + +#include +#include +#include +#include +#include +#include + +using namespace dlib; +using namespace std; + +// ---------------------------------------------------------------------------------------- + +// The next bit of code defines a ResNet network. It's basically copied +// and pasted from the dnn_imagenet_ex.cpp example, except we replaced the loss +// layer with loss_metric and made the network somewhat smaller. Go read the introductory +// dlib DNN examples to learn what all this stuff means. +// +// Also, the dnn_metric_learning_on_images_ex.cpp example shows how to train this network. +// The dlib_face_recognition_resnet_model_v1 model used by this example was trained using +// essentially the code shown in dnn_metric_learning_on_images_ex.cpp except the +// mini-batches were made larger (35x15 instead of 5x5), the iterations without progress +// was set to 10000, and the training dataset consisted of about 3 million images instead of +// 55. Also, the input layer was locked to images of size 150. +template